
[๋…ผ๋ฌธ๋ฆฌ๋ทฐ] VAE: Auto Encoder - Auto-Encoding Variational Bayes

DrawingProcess 2024. 7. 5. 15:38
๋ฐ˜์‘ํ˜•
๐Ÿ’ก ๋ณธ ๋ฌธ์„œ๋Š” 'VAE: Auto Encoder - Auto-Encoding Variational Bayes' ๋…ผ๋ฌธ์„ ์ •๋ฆฌํ•ด๋†“์€ ๊ธ€์ž…๋‹ˆ๋‹ค.
ํ•ด๋‹น ๋…ผ๋ฌธ์€ CLIP ๊ฐ™์€ ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ๋ชจ๋ธ์˜ language embedding์„ NeRF ์•ˆ์— ์ง‘์–ด๋„ฃ์–ด NeRF๋ฅผ Multi Modal๋กœ ํ™•์žฅ ๊ฐ€๋Šฅ์„ฑ์„ ๋ณด์—ฌ์ค€ ๋…ผ๋ฌธ์ด๋‹ˆ ์ฐธ๊ณ ํ•˜์‹œ๊ธฐ ๋ฐ”๋ž๋‹ˆ๋‹ค.
 - Paper: https://arxiv.org/pdf/1312.6114

์‚ฌ์ „ ์ง€์‹: Domain Gap

VAE๋Š” ๋…ผ๋ฌธ์„ ์ดํ•ดํ•˜๋ ค๋ฉด ๊ฝค ๋งŽ์€(์ ์–ด๋„ ๋‚˜์—๊ฒŒ๋Š”) ์‚ฌ์ „์ง€์‹์ด ํ•„์š”ํ–ˆ์œผ๋ฉฐ, ๊ฐ„๋‹จํ•˜๊ฒŒ ์ •๋ฆฌํ•˜๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค. (์ž์„ธํ•œ ์„ค๋ช…์€ ์ฐธ๊ณ ๋งํฌ๋ฅผ ํ™•์ธํ•˜๊ธฐ ๋ฐ”๋ž€๋‹ค.)

[1] VAE๋Š” Generative Model์ด๋‹ค.

[2] ํ™•๋ฅ  ํ†ต๊ณ„ ์ด๋ก  (Bayseain, conditional prob, pdf etc)

  • ๋ฒ ์ด์ง€์•ˆ ํ™•๋ฅ (Bayesian probability): ์„ธ์ƒ์— ๋ฐ˜๋ณตํ•  ์ˆ˜ ์—†๋Š” ํ˜น์€ ์•Œ ์ˆ˜ ์—†๋Š” ํ™•๋ฅ ๋“ค, ์ฆ‰ ์ผ์–ด๋‚˜์ง€ ์•Š์€ ์ผ์— ๋Œ€ํ•œ ํ™•๋ฅ ์„ ์‚ฌ๊ฑด๊ณผ ๊ด€๋ จ์ด ์žˆ๋Š” ์—ฌ๋Ÿฌ ํ™•๋ฅ ๋“ค์„ ์ด์šฉํ•ด ์šฐ๋ฆฌ๊ฐ€ ์•Œ๊ณ ์‹ถ์€ ์‚ฌ๊ฑด์„ ์ถ”์ •ํ•˜๋Š” ๊ฒƒ์ด ๋ฒ ์ด์ง€์•ˆ ํ™•๋ฅ ์ด๋‹ค.
  • ๋ฒ ์ด์ง€์•ˆ ์ด๋ก  ๊ด€๋ จ ์„ค๋ช…๊ธ€

[3] ๊ด€๋ จ ์šฉ์–ด๋“ค

  • latent: 'hidden' or 'concealed'. The latent variable z here is best understood as a vector of features.
  • intractable: a problem is called intractable if the time required to solve it grows exponentially with the problem size.
  • explicit density model: the structure (distribution) of the sampling model is defined explicitly.
  • implicit density model: the structure (distribution) of the sampling model is not defined explicitly.
  • density estimation: given only observed data x, estimating the unobservable probability density function from which x was sampled.
  • Gaussian distribution: the normal distribution.
  • Bernoulli distribution: the distribution of a binary (0/1) random variable.
  • Marginal probability: the marginal probability distribution.
  • D_KL: the Kullback–Leibler divergence (KLD), a measure of the difference between two probability distributions (a tiny numeric sketch follows this section).
  • Encode / Decode: encoding (mapping data into a code) / decoding (reconstructing data from the code).
  • likelihood: the likelihood. This concept is worth reading a dedicated explainer on before continuing.

๋“ฑ์˜ ๊ฐœ๋…๋“ค์„ ์ˆ™์ง€ํ•˜๊ณ  ๋„˜์–ด๊ฐ€์•ผ ๋…ผ๋ฌธ์— ๋Œ€ํ•œ ๋‚ด์šฉ์„ ์ดํ•ด๋ฅผ ํ•  ์ˆ˜ ์žˆ๋‹ค.

[4] Auto-Encoder

  • VAE์™€ ์˜คํ† ์ธ์ฝ”๋”(AE)๋Š” ๋ชฉ์ ์ด ์ „ํ˜€ ๋‹ค๋ฅด๋‹ค.
  • ์˜คํ† ์ธ์ฝ”๋”์˜ ๋ชฉ์ ์€ ์–ด๋–ค ๋ฐ์ดํ„ฐ๋ฅผ ์ž˜ ์••์ถ•ํ•˜๋Š”๊ฒƒ, ์–ด๋–ค ๋ฐ์ดํ„ฐ์˜ ํŠน์ง•์„ ์ž˜ ๋ฝ‘๋Š” ๊ฒƒ, ์–ด๋–ค ๋ฐ์ดํ„ฐ์˜ ์ฐจ์›์„ ์ž˜ ์ค„์ด๋Š” ๊ฒƒ์ด๋‹ค.
  • ๋ฐ˜๋ฉด VAE์˜ ๋ชฉ์ ์€ Generative model์œผ๋กœ ์–ด๋–ค ์ƒˆ๋กœ์šด X๋ฅผ ๋งŒ๋“ค์–ด๋‚ด๋Š” ๊ฒƒ์ด๋‹ค.

Abstract

ํฐ ๋ฐ์ดํ„ฐ์…‹๊ณผ ๊ณ„์‚ฐ์ด ๋ถˆ๊ฐ€๋Šฅํ•œ posterior ๋ถ„ํฌ๋ฅผ ๊ฐ€์ง€๋Š” ์—ฐ์†ํ˜• ์ž ์žฌ ๋ณ€์ˆ˜๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ์„ ๋•Œ, ์–ด๋–ป๊ฒŒ directed probabilistic model์„ ํšจ์œจ์ ์œผ๋กœ ํ•™์Šตํ•˜๊ณ  ์ถ”๋ก ํ•  ์ˆ˜ ์žˆ์„๊นŒ?

์šฐ๋ฆฌ๋Š” ํฐ ๋ฐ์ดํ„ฐ์…‹์—๋„ ํ™•์žฅํ•  ์ˆ˜ ์žˆ๊ณ  ๊ฐ€๋ฒผ์šด ๋ฏธ๋ถ„๊ฐ€๋Šฅ์„ฑ ์กฐ๊ฑด์ด ์žˆ๋‹ค๋ฉด ๊ณ„์‚ฐ์ด ๋ถˆ๊ฐ€๋Šฅํ•œ ๊ฒฝ์šฐ์—๋„ ์ž‘๋™ํ•˜๋Š” stochastic variational inference and learning ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ์ œ์•ˆํ•œ๋‹ค.

์šฐ๋ฆฌ์˜ ๊ธฐ์—ฌ๋Š” ๋‘ ๊ฐ€์ง€์ด๋‹ค.

์ฒซ ๋ฒˆ์งธ, variational lower bound์˜ reparameterization์ด ํ‘œ์ค€์ ์ธ stochastic gradient ๋ฐฉ๋ฒ•๋ก ๋“ค์„ ์‚ฌ์šฉํ•˜์—ฌ ์ง์ ‘์ ์œผ๋กœ ์ตœ์ ํ™”๋  ์ˆ˜ ์žˆ๋Š” lower bound estimator๋ฅผ ๋งŒ๋“ค์–ด๋‚ธ๋‹ค๋Š” ๊ฒƒ์„ ๋ณด์˜€๋‹ค.

๋‘ ๋ฒˆ์งธ, ๊ฐ datapoint๊ฐ€ ์—ฐ์†ํ˜• ์ž ์žฌ ๋ณ€์ˆ˜๋ฅผ ๊ฐ€์ง€๋Š” i.i.d. ๋ฐ์ดํ„ฐ์…‹์— ๋Œ€ํ•ด์„œ, ์ œ์•ˆ๋œ lower bound estimator๋ฅผ ์‚ฌ์šฉํ•ด approximate inference model(๋˜๋Š” recognition model์ด๋ผ๊ณ  ๋ถˆ๋ฆผ)์„ ๊ณ„์‚ฐ์ด ๋ถˆ๊ฐ€๋Šฅํ•œ posterior์— fitting ์‹œํ‚ด์œผ๋กœ์จ posterior inference๊ฐ€ ํŠนํžˆ ํšจ์œจ์ ์œผ๋กœ ๋งŒ๋“ค์–ด์งˆ ์ˆ˜ ์žˆ๋‹ค๋Š” ์ ์„ ๋ณด์ธ๋‹ค. ์‹คํ—˜ ๊ฒฐ๊ณผ์— ์ด๋ก ์  ์ด์ ์ด ๋ฐ˜์˜๋˜์—ˆ๋‹ค.

VAE

์ด์ œ๋ถ€ํ„ฐ ๋ณธ๊ฒฉ์ ์œผ๋กœ VAE ๊ด€๋ จ๋œ ๋‚ด์šฉ๋“ค์„ ์ฝ”๋“œ์™€ ํ•จ๊ป˜ ์‚ดํŽด๋ณด์ž. ๊ธฐ์กด์˜ ๋…ผ๋ฌธ์˜ ํ๋ฆ„์€ Generative Model์ด ๊ฐ€์ง€๋Š” ๋ฌธ์ œ์ ๋“ค์„ ํ•ด์†Œํ•˜๊ธฐ ์œ„ํ•ด ์–ด๋–ค ๋ฐฉ์‹์„ ๋„์ž…ํ–ˆ๋Š”์ง€ ์ฐจ๋ก€์ฐจ๋ก€ ์„ค๋ช…ํ•˜๊ณ ์žˆ๋‹ค. ํ•˜์ง€๋งŒ ๊ด€๋ จ๋œ ์ˆ˜์‹๋„ ๋งŽ๊ณ  ์ค‘๊ฐ„์— ์ƒ๋žต๋œ ์‹๋„ ๋งŽ์•„ ๋…ผ๋ฌธ๋Œ€๋กœ ๋”ฐ๋ผ๊ฐ€๋‹ค๋ณด๋ฉด ์ „์ฒด์ ์ธ ๊ตฌ์กฐ๋ฅผ ์ดํ•ดํ•˜๊ธฐ ํž˜๋“ค๊ธฐ๋•Œ๋ฌธ์— ๋จผ์ € ๊ตฌ์กฐ๋ฅผ ์‚ดํŽด๋ณธ ๋’ค ๊ฐ ๊ตฌ์กฐ๊ฐ€ ๊ฐ€์ง€๋Š” ์˜๋ฏธ๊ฐ€ ๋ฌด์—‡์ธ์ง€(์™œ ์ด ๊ตฌ์กฐ๊ฐ€ ๋‚˜์™”๋Š”์ง€) ์‚ดํŽด๋ณด๊ณ  ์ตœ์ข…์ ์œผ๋กœ ์ •๋ฆฌํ•˜๋„๋ก ํ•  ์˜ˆ์ •์ด๋‹ค. (Top-down approach(?))

VAE GOAL

๋…ผ๋ฌธ Abstract์— ๋‚˜์™€์žˆ๋Š” ์ฒซ ๋ฌธ์žฅ์ด๋‹ค. ์ด ๋ชฉ์ ์„ ์ดํ•ดํ•˜๋Š” ๊ฒƒ์ด ๊ฐ€์žฅ ์ค‘์š”ํ•˜๋‹ˆ ์ฒœ์ฒœํžˆ ๋ณด๋ฉด์„œ ์ดํ•ดํ•˜๊ธฐ ๋ฐ”๋ž€๋‹ค.

How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets?

VAE์˜ ๋ชฉํ‘œ๋Š” Generative Model์˜ ๋ชฉํ‘œ์™€ ๊ฐ™๋‹ค. (1) data์™€ ๊ฐ™์€ ๋ถ„ํฌ๋ฅผ ๊ฐ€์ง€๋Š” sample ๋ถ„ํฌ์—์„œ sample์„ ๋ฝ‘๊ณ (2) ์–ด๋–ค ์ƒˆ๋กœ์šด ๊ฒƒ์„ ์ƒ์„ฑํ•ด๋‚ด๋Š” ๊ฒƒ์ด ๋ชฉํ‘œ๋‹ค. ์ฆ‰,

  • (1) ์ฃผ์–ด์ง„ training data๊ฐ€ p_data(x)(ํ™•๋ฅ ๋ฐ€๋„ํ•จ์ˆ˜)๊ฐ€ ์–ด๋–ค ๋ถ„ํฌ๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ๋‹ค๋ฉด, sample ๋ชจ๋ธ p_model(x) ์—ญ์‹œ ๊ฐ™์€ ๋ถ„ํฌ๋ฅผ ๊ฐ€์ง€๋ฉด์„œ, (sampling ๋ถ€๋ถ„)
  • (2) ๊ทธ ๋ชจ๋ธ์„ ํ†ตํ•ด ๋‚˜์˜จ inference ๊ฐ’์ด ์ƒˆ๋กœ์šด x๋ผ๋Š” ๋ฐ์ดํ„ฐ์ด๊ธธ ๋ฐ”๋ž€๋‹ค. (Generation ๋ถ€๋ถ„)

์˜ˆ๋ฅผ ๋“ค์–ด, ๋ช‡ ๊ฐœ์˜ ๋‹ค์ด์•„๋ชฌ๋“œ(training data)๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ๋‹ค๊ณ  ์ƒ๊ฐํ•ด๋ณด์ž. ๊ทธ๋Ÿฌ๋ฉด training ๋‹ค์ด์•„๋ชฌ๋“œ ๋ฟ๋งŒ์•„๋‹ˆ๋ผ ๋ชจ๋“  ๋‹ค์ด์•„๋ชฌ๋“œ์˜ ํ™•๋ฅ ๋ถ„ํฌ์™€ ๋˜‘๊ฐ™์€ ๋ถ„ํฌ๋ฅผ ๊ฐ€์ง„ ๋ชจ๋ธ์—์„œ ๊ฐ’์„ ๋ฝ‘์•„(1. sampling) training ์‹œ์ผฐ๋˜ ๋‹ค์ด์•„๋ชฌ๋“œ์™€๋Š” ๋‹ค๋ฅธ ๋˜ ๋‹ค๋ฅธ ๋‹ค์ด์•„๋ชฌ๋“œ(new)๋ฅผ ๋งŒ๋“œ๋Š”(generate) ๊ฒƒ์ด๋‹ค.

VAE ๊ตฌ์กฐ

๋ฐฑ๋ฌธ์ด ๋ถˆ์–ด์ผ๊ฒฌ. VAE์˜ ์ „์ฒด ๊ตฌ์กฐ๋ฅผ ํ•œ ๋„์‹์œผ๋กœ ์‚ดํŽด๋ณด์ž.

์ผ€๋ผ์Šค ๊ต์žฌ์— ๊ตฌํ˜„๋œ ์ฝ”๋“œ์™€ ๋…ผ๋ฌธ์˜ ๊ตฌ์กฐ๋Š” ์•ฝ๊ฐ„์˜ ์ฐจ์ด๊ฐ€ ์žˆ๋‹ค. ์ „์ฒด์ ์ธ ๊ตฌ์กฐ๋Š” ๋˜‘๊ฐ™์œผ๋‹ˆ ํฌ๊ฒŒ ํ—ท๊ฐˆ๋ฆด ๊ฒƒ์€ ์—†์ง€๋งŒ, ๊ทธ๋ž˜๋„ ์ฝ”๋“œ์˜ ์•ฝ๊ฐ„์˜ ๋ณ€ํ˜•๋œ ๋ถ€๋ถ„์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

๋…ผ๋ฌธ๊ณผ ๋‹ค๋ฅธ์  : Input shape, Encoder์˜ NN ๋ชจ๋ธ, Decoder์˜ NN๋ชจ๋ธ (์ฝ”๋“œ์—์„œ๋Š” ์™ผ์ชฝ์˜ ๊ฐ ๋ถ€๋ถ„๋“ค์„ DNN์„ CNN๊ตฌ์กฐ๋กœ ๋ฐ”๊ฟˆ)

์œ„์˜ ๋„์‹์€ VAE ๊ตฌ์กฐ๋ฅผ ์™„๋ฒฝํžˆ ์ •๋ฆฌํ•œ ๊ทธ๋ฆผ์ด๋‹ค. ์ด์ œ ์ด ๊ทธ๋ฆผ์„ ๋ณด๋ฉด์„œ, input ๊ทธ๋ฆผ์ด ์žˆ์„ ๋•Œ ์–ด๋–ค ์˜๋ฏธ๋ฅผ ๊ฐ€์ง„ ๊ตฌ์กฐ๋ฅผ ๊ฑฐ์ณ output์ด ๋‚˜์˜ค๊ฒŒ ๋˜๋Š”์ง€ 3 ๋‹จ๊ณ„๋กœ ๋‚˜๋ˆ„์–ด ์‚ดํŽด๋ณด์ž.

  1. input: x → q_φ(z|x) → μ_i, σ_i
  2. μ_i, σ_i, ε_i → z_i
  3. z_i → g_θ(z_i) → p_i : output

1. Encoder

input: x → q_φ(z|x) → μ_i, σ_i

import numpy as np
import keras
from keras import layers
from keras import backend as K
from keras.models import Model

img_shape = (28, 28, 1)
batch_size = 16
latent_dim = 2   # dimensionality of the latent space

input_img = keras.Input(shape=img_shape)
x = layers.Conv2D(32, 3, padding='same', activation='relu')(input_img)
x = layers.Conv2D(64, 3, padding='same', activation='relu', strides=(2, 2))(x)
x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)

shape_before_flattening = K.int_shape(x)  # tuple of integers: shape of x before flattening

x = layers.Flatten()(x)
x = layers.Dense(32, activation='relu')(x)

# Parameters of q_φ(z|x): the mean and log-variance of the latent Gaussian
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)

  • Input shape (x): (28, 28, 1)
  • q_φ(z|x) is the encoder function: given x, it outputs the mean and variance of the distribution of z.
  • In other words, the output of the q function (= the encoder) is μ_i, σ_i.

์–ด๋–ค X๋ผ๋Š” ์ž…๋ ฅ์„ ๋„ฃ์–ด ์ธ์ฝ”๋”์˜ ์•„์›ƒํ’‹์€ ๐œ‡_๐‘–,๐œŽ_๐‘– ์ด๋‹ค. ์–ด๋–ค ๋ฐ์ดํ„ฐ์˜ ํŠน์ง•์„(latent variable) X๋ฅผ ํ†ตํ•ด ์ถ”์ธกํ•œ๋‹ค. ๊ธฐ๋ณธ์ ์œผ๋กœ ์—ฌ๊ธฐ์„œ ๋‚˜์˜จ ํŠน์ง•๋“ค์˜ ๋ถ„ํฌ๋Š” ์ •๊ทœ๋ถ„ํฌ๋ฅผ ๋”ฐ๋ฅธ๋‹ค๊ณ  ๊ฐ€์ •ํ•œ๋‹ค. ์ด๋Ÿฐ ํŠน์ง•๋“ค์ด ๊ฐ€์ง€๋Š” ํ™•๋ฅ  ๋ถ„ํฌ ๐‘ž_∅ (๐‘ฅ) (์ •ํ™•ํžˆ ๋งํ•˜๋ฉด $์˜ true ๋ถ„ํฌ (= $)๋ฅผ ์ •๊ทœ๋ถ„ํฌ(=Gaussian)๋ผ ๊ฐ€์ •ํ•œ๋‹ค๋Š” ๋ง์ด๋‹ค. ๋”ฐ๋ผ์„œ latent space์˜ latent variable ๊ฐ’๋“ค์€ ๐‘ž_∅ (๐‘ฅ)์˜ true ๋ถ„ํฌ๋ฅผ approximateํ•˜๋Š” ๐œ‡_๐‘–,๐œŽ_๐‘–๋ฅผ ๋‚˜ํƒ€๋‚ธ๋‹ค.

Encoder ํ•จ์ˆ˜์˜ output์€ latent variable์˜ ๋ถ„ํฌ์˜ ๐œ‡ ์™€ ๐œŽ ๋ฅผ ๋‚ด๊ณ , ์ด output๊ฐ’์„ ํ‘œํ˜„ํ•˜๋Š” ํ™•๋ฅ ๋ฐ€๋„ํ•จ์ˆ˜๋ฅผ ์ƒ๊ฐํ•ด๋ณผ ์ˆ˜ ์žˆ๋‹ค.

2. Reparameterization Trick (Sampling)

๐œ‡_๐‘–, ๐œŽ_๐‘–, ๐œ–_๐‘– –> ๐‘ง_๐‘–

def sampling(args):
    z_mean, z_log_var = args
    # epsilon ~ N(0, 1): the randomness lives outside the network parameters
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=1.)
    # z = mu + sigma * epsilon, where sigma = exp(0.5 * log(sigma^2))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

z = layers.Lambda(sampling)([z_mean, z_log_var])

๋งŒ์•ฝ Encoder ๊ฒฐ๊ณผ์—์„œ ๋‚˜์˜จ ๊ฐ’์„ ํ™œ์šฉํ•ด decoding ํ•˜๋Š”๋ฐ sampling ํ•˜์ง€ ์•Š๋Š”๋‹ค๋ฉด ์–ด๋–ค ์ผ์ด ๋ฒŒ์–ด์งˆ๊นŒ? ๋‹น์—ฐํžˆ ๋Š” ํ•œ ๊ฐ’์„ ๊ฐ€์ง€๋ฏ€๋กœ ๊ทธ์— ๋Œ€ํ•œ decoder(NN)์—ญ์‹œ ํ•œ ๊ฐ’๋งŒ ๋ฑ‰๋Š”๋‹ค. ๊ทธ๋ ‡๊ฒŒ ๋œ๋‹ค๋ฉด ์–ด๋–ค ํ•œ variable์€ ๋ฌด์กฐ๊ฑด ๋˜‘๊ฐ™์€ ํ•œ ๊ฐ’์˜ output์„ ๊ฐ€์ง€๊ฒŒ ๋œ๋‹ค.

ํ•˜์ง€๋งŒ Generative Model, VAE๊ฐ€ ํ•˜๊ณ  ์‹ถ์€ ๊ฒƒ์€, ์–ด๋–ค data์˜ true ๋ถ„ํฌ๊ฐ€ ์žˆ์œผ๋ฉด ๊ทธ ๋ถ„ํฌ์—์„œ ํ•˜๋‚˜๋ฅผ ๋ฝ‘์•„ ๊ธฐ์กด DB์— ์žˆ์ง€ ์•Š์€ ์ƒˆ๋กœ์šด data๋ฅผ ์ƒ์„ฑํ•˜๊ณ  ์‹ถ๋‹ค. ๋”ฐ๋ผ์„œ ์šฐ๋ฆฌ๋Š” ํ•„์—ฐ์ ์œผ๋กœ ๊ทธ ๋ฐ์ดํ„ฐ์˜ ํ™•๋ฅ ๋ถ„ํฌ์™€ ๊ฐ™์€ ๋ถ„ํฌ์—์„œ ํ•˜๋‚˜๋ฅผ ๋ฝ‘๋Š” sampling์„ ํ•ด์•ผํ•œ๋‹ค. ํ•˜์ง€๋งŒ ๊ทธ๋ƒฅ sampling ํ•œ๋‹ค๋ฉด sampling ํ•œ ๊ฐ’๋“ค์„ backpropagation ํ•  ์ˆ˜ ์—†๋‹ค.(์•„๋ž˜์˜ ๊ทธ๋ฆผ์„ ๋ณด๋ฉด ์ง๊ด€์ ์œผ๋กœ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋‹ค) ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด reparmeterization trick์„ ์‚ฌ์šฉํ•œ๋‹ค.

์ •๊ทœ๋ถ„ํฌ์—์„œ z1๋ฅผ ์ƒ˜ํ”Œ๋งํ•˜๋Š” ๊ฒƒ์ด๋‚˜, ์ž…์‹ค๋ก ์„ ์ •๊ทœ๋ถ„ํฌ(์ž์„ธํžˆ๋Š” N(0,1))์—์„œ ์ƒ˜ํ”Œ๋งํ•˜๊ณ  ๊ทธ ๊ฐ’์„ ๋ถ„์‚ฐ๊ณผ ๊ณฑํ•˜๊ณ  ํ‰๊ท ์„ ๋”ํ•ด z2๋ฅผ ๋งŒ๋“ค๊ฑฐ๋‚˜ ๋‘ z1,z2 ๋Š” ๊ฐ™์€ ๋ถ„ํฌ๋ฅผ ๊ฐ€์ง€๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. ๊ทธ๋ž˜์„œ ์ฝ”๋“œ์—์„œ epsilon์„ ๋จผ์ € ์ •๊ทœ๋ถ„ํฌ์—์„œ randomํ•˜๊ฒŒ ๋ฝ‘๊ณ , ๊ทธ epsilon์„ exp(z_log_var)๊ณผ ๊ณฑํ•˜๊ณ  z_mean์„ ๋”ํ•œ๋‹ค. ๊ทธ๋ ‡๊ฒŒ ํ˜•์„ฑ๋œ ๊ฐ’์ด z๊ฐ€ ๋œ๋‹ค.

latent variable์—์„œ sample๋œ z๋ผ๋Š” value (= decoder input)์ด ๋งŒ๋“ค์–ด์ง„๋‹ค.

3. Decoder

๐‘ง_๐‘– –> ๐‘”_๐œƒ (๐‘ง_๐‘–) –> ๐‘_๐‘– : output

## Listing 8.25: VAE decoder network, mapping latent space points to images

decoder_input = layers.Input(K.int_shape(z)[1:])
# Project z up to the flattened feature-map size saved in the encoder
x = layers.Dense(np.prod(shape_before_flattening[1:]), activation='relu')(decoder_input)
x = layers.Reshape(shape_before_flattening[1:])(x)
# Transposed convolution undoes the encoder's stride-2 downsampling
x = layers.Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2, 2))(x)
# Sigmoid keeps outputs in [0, 1], matching the Bernoulli assumption
x = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(x)

decoder = Model(decoder_input, x)
z_decoded = decoder(z)

z ๊ฐ’์„ g ํ•จ์ˆ˜(decoder)์— ๋„ฃ๊ณ  deconv(์ฝ”๋“œ์—์„œ๋Š” Conv2DTranspose)๋ฅผ ํ•ด ์›๋ž˜ ์ด๋ฏธ์ง€ ์‚ฌ์ด์ฆˆ์˜ ์•„์›ƒํ’‹ z_decoded๊ฐ€ ๋‚˜์˜ค๊ฒŒ ๋œ๋‹ค. ์ด๋•Œ p_data(x)์˜ ๋ถ„ํฌ๋ฅผ Bernoulli ๋กœ ๊ฐ€์ •ํ–ˆ์œผ๋ฏ€๋กœ(์ด๋ฏธ์ง€ recognition ์—์„œ Gaussian ์œผ๋กœ ๊ฐ€์ •ํ• ๋•Œ๋ณด๋‹ค Bernoulli๋กœ ๊ฐ€์ •ํ•ด์•ผ ์˜๋ฏธ์ƒ ๊ทธ๋ฆฌ๊ณ  ๊ฒฐ๊ณผ์ƒ ๋” ์ ์ ˆํ–ˆ๊ธฐ ๋•Œ๋ฌธ) output ๊ฐ’์€ 0~1 ์‚ฌ์ด ๊ฐ’์„ ๊ฐ€์ ธ์•ผํ•˜๊ณ , ์ด๋ฅผ ์œ„ํ•ด activatino function์„ sigmoid๋กœ ์„ค์ •ํ•ด์ฃผ์—ˆ๋‹ค. (Gaussian ๋ถ„ํฌ๋ฅผ ๋”ฐ๋ฅธ๋‹ค๊ณ  ๊ฐ€์ •ํ•˜๊ณ  ํ‘ผ๋‹ค๋ฉด ์•„๋ž˜ loss๋ฅผ ๋‹ค๋ฅด๊ฒŒ ์„ค์ •ํ•ด์•ผํ•œ๋‹ค.)

VAE ํ•™์Šต

Loss Fucntion ์ดํ•ด

Loss ๋Š” ํฌ๊ฒŒ ์ด 2๊ฐ€์ง€ ๋ถ€๋ถ„์ด ์žˆ๋‹ค.

# vae_loss is written as a method of a custom layer (see the training section
# below); z_mean and z_log_var are captured from the encoder defined above.
def vae_loss(self, x, z_decoded):
    x = K.flatten(x)
    z_decoded = K.flatten(z_decoded)
    # Reconstruction loss: binary cross-entropy (Bernoulli decoder)
    xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)
    # Regularization loss: KL divergence to the N(0, I) prior,
    # down-weighted by a small factor (5e-4)
    kl_loss = -5e-4 * K.mean(
        1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return K.mean(xent_loss + kl_loss)

  • Reconstruction loss (xent_loss in the code)
  • Regularization loss (kl_loss in the code)

์ผ๋‹จ ์ง๊ด€์ ์œผ๋กœ ์ดํ•ด๋ฅผ ํ•˜์ž๋ฉด,

  1. Generative ๋ชจ๋ธ๋‹ต๊ฒŒ ์ƒˆ๋กœ์šด X๋ฅผ ๋งŒ๋“ค์–ด์•ผํ•˜๋ฏ€๋กœ X์™€ ๋งŒ๋“ค์–ด์ง„ output, New X์™€์˜ ๊ด€๊ณ„๋ฅผ ์‚ดํŽด๋ด์•ผํ•˜๊ณ , ์ด๋ฅผ Reconstruction Loss ๋ถ€๋ถ„์ด๋ผ๊ณ  ํ•œ๋‹ค. ์ด๋•Œ ๋””์ฝ”๋” ๋ถ€๋ถ„์˜ pdf๋Š” Bernoulli ๋ถ„ํฌ๋ฅผ ๋”ฐ๋ฅธ๋‹ค๊ณ  ๊ฐ€์ •ํ–ˆ์œผ๋ฏ€๋กœ ๊ทธ ๋‘˜๊ฐ„์˜ cross entropy๋ฅผ ๊ตฌํ•œ๋‹ค( ์ด ๋ถ€๋ถ„์— ๋Œ€ํ•ด์„œ ์™œ ๊ฐ™์€์ง€๋Š” ์ˆ˜์‹์„ ํฌํ•จํ•œ ํฌ์Šคํ„ฐ์—์„œ ๋” ์ƒ์„ธํžˆ ๋‹ค๋ฃฐ ๊ฒƒ์ด๋‹ค)
  2. X๊ฐ€ ์›๋ž˜ ๊ฐ€์ง€๋Š” ๋ถ„ํฌ์™€ ๋™์ผํ•œ ๋ถ„ํฌ๋ฅผ ๊ฐ€์ง€๊ฒŒ ํ•™์Šตํ•˜๊ฒŒ ํ•˜๊ธฐ์œ„ํ•ด true ๋ถ„ํฌ๋ฅผ approximate ํ•œ ํ•จ์ˆ˜์˜ ๋ถ„ํฌ์— ๋Œ€ํ•œ loss term์ด Regularization Loss๋‹ค. ์ด๋•Œ loss๋Š” true pdf ์™€ approximated pdf๊ฐ„์˜ D_kl(๋‘ ํ™•๋ฅ ๋ถ„ํฌ์˜ ์ฐจ์ด(๊ฑฐ๋ฆฌ))์„ ๊ณ„์‚ฐํ•œ๋‹ค. (์ด ๋ถ€๋ถ„๋„ ์—ญ์‹œ ์™œ ์ด๋Ÿฐ ์‹์ด ๋‚˜์™”๋Š”์ง€๋Š” ์ˆ˜์‹์„ ํฌํ•จํ•œ ํฌ์Šคํ……์„œ ๋” ์ƒ์„ธํžˆ ๋‹ค๋ฃฐ ๊ฒƒ์ด๋‹ค)

ํ•™์Šต

encoder ๋ถ€๋ถ„๊ณผ decoder ๋ถ€๋ถ„์„ ํ•ฉ์ณ ํ•œ ๋ชจ๋ธ์„ ๋งŒ๋“ค๊ณ  train ํ•˜๋ฉด ๋! ์ž์„ธํ•œ ์ฝ”๋“œ๋Š” Github์— ์˜ฌ๋ ค๋‘์—ˆ์œผ๋‹ˆ ์ฐธ๊ณ ํ•˜๊ธฐ ๋ฐ”๋ž€๋‹ค.

VAE ๊ฒฐ๊ณผ

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

n = 20           # display an n x n grid of digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# Map evenly spaced quantiles through the inverse Gaussian CDF so the
# grid points cover the N(0, 1) latent prior
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
        x_decoded = decoder.predict(z_sample, batch_size=batch_size)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

์œ„์˜ ์ฝ”๋“œ๋ฅผ ์‹คํ–‰์‹œํ‚ค๋ฉด ์œ„ ๊ทธ๋ฆผ์—์„œ ์˜ค๋ฅธ์ชฝ๊ณผ ๊ฐ™์€ ๋„์‹์ด ๋‚˜์˜ค๋Š”๋ฐ ํ•™์Šต์ด ์ž˜ ๋˜์—ˆ๋‹ค๋ฉด ์ฐจ์›์˜ manifold๋ฅผ ์ž˜ ํ•™์Šตํ–ˆ๋‹ค๋Š” ๋ง์ด๋‹ค. ๊ทธ manifold๋ฅผ 2์ฐจ์›์œผ๋กœ ์ถ•์†Œ์‹œํ‚จ ๊ฒƒ(z1,z2)์—์„œ z1 20๊ฐœ(0.05~0.95), z2 20๊ฐœ, ์ด 400๊ฐœ์˜ ์ˆœ์„œ์Œ์˜ xi,yi์—์„œ sample์„ ๋ฝ‘์•„ ์‹œ๊ฐํ™”ํ•œ๊ฒƒ์ด ์˜ค๋ฅธ์ชฝ ๊ทธ๋ฆผ์ธ๋ฐ 2D์ƒ์—์„œ ๊ฑฐ๋ฆฌ์˜ ์œ ์˜๋ฏธํ•œ ์ฐจ์ด์— ๋”ฐ๋ผ ์ˆซ์ž๋“ค์ด ๋‹ฌ๋ผ์ง€๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ, ๊ฐ ์ˆซ์ž ์ƒ์—์„œ๋„ ์„œ๋กœ ๋‹ค๋ฅธ rotation๋“ค์„ ๊ฐ€์ง€๊ณ  ์žˆ๋‹ค๋Š” ๊ฒƒ์ด ๋ณด์ธ๋‹ค.

Insight

๋งˆ์ง€๋ง‰์œผ๋กœ VAE๋Š” ์™œํ•˜๋ƒ๊ณ  ๋ฌผ์–ด๋ณธ๋‹ค๋ฉด ํฌ๊ฒŒ 2๊ฐ€์ง€๋กœ ๋‹ตํ•  ์ˆ˜ ์žˆ๋‹ค.

  1. Generative Model ๋ชฉ์  ๋‹ฌ์„ฑ
  2. Latent variable control ๊ฐ€๋Šฅ
  • Generative Model์„ ํ†ตํ•ด ์ ์€ data๋ฅผ ๊ฐ€์ง€๊ณ  ์›๋ž˜ data๊ฐ€ ๊ฐ€์ง€๋Š” ๋ถ„ํฌ๋ฅผ ๊ฝค ๊ฐ€๊น๊ฒŒ ๊ทผ์‚ฌํ•˜๊ณ  ์ด๋ฅผ ํ†ตํ•ด ์ƒˆ๋กœ์šด data๋ฅผ ์ƒ์„ฑํ•ด๋‚ผ ์ˆ˜ ์žˆ๋‹ค๋Š” ์ .
  • Latent variable์˜ ๋ถ„ํฌ๋ฅผ ๊ฐ€์ •ํ•ด ์šฐ๋ฆฌ๊ฐ€ sampling ํ•  ๊ฐ’๋“ค์˜ ๋ถ„ํฌ๋ฅผ control ํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋˜๊ณ , manifold๋„ ์ž˜ ํ•™์Šต์ด ๋œ๋‹ค๋Š”์ .
  • ์ด๋Š” data์˜ ํŠน์ง•๋“ค๋„ ์ž˜ ์•Œ ์ˆ˜ ์žˆ๊ณ , ๊ทธ ํŠน์ง•๋“ค์˜ ๋ถ„ํฌ๋“ค์€ ํฌ๊ฒŒ ๋ฒ—์–ด๋‚˜์ง€ ์•Š๊ฒŒ control ํ•˜๋ฉด์„œ ๊ทธ ์†์—์„œ ์ƒˆ๋กœ์šด ๊ฐ’์„ ๋งŒ๋“ค ์ˆ˜ ์žˆ๋‹ค๋Š” ์ .

์ •๋„๊ฐ€ ๋  ๊ฒƒ ๊ฐ™๋‹ค.

VAE์˜ ํ•œ๊ณ„์ ์„ ๊ทน๋ณตํ•˜๊ธฐ ์œ„ํ•ด, CVAE, AAE๊ฐ€ ๋‚˜์™”์œผ๋‹ˆ ๊ด€์‹ฌ์žˆ๋Š” ์‚ฌ๋žŒ์€ ๊ด€๋ จ๋œ ์ž๋ฃŒ๋ฅผ ์ฐพ์•„๋ณด๊ธฐ ๋ฐ”๋ž€๋‹ค

์ฐธ๊ณ 

๋ฐ˜์‘ํ˜•