[๋ ผ๋ฌธ๋ฆฌ๋ทฐ] VAE: Auto Encoder - Auto-Encoding Variational Bayes
๐ก ๋ณธ ๋ฌธ์๋ 'VAE: Auto Encoder - Auto-Encoding Variational Bayes' ๋ ผ๋ฌธ์ ์ ๋ฆฌํด๋์ ๊ธ์ ๋๋ค.
ํด๋น ๋ ผ๋ฌธ์ CLIP ๊ฐ์ ๋ฉํฐ๋ชจ๋ฌ ๋ชจ๋ธ์ language embedding์ NeRF ์์ ์ง์ด๋ฃ์ด NeRF๋ฅผ Multi Modal๋ก ํ์ฅ ๊ฐ๋ฅ์ฑ์ ๋ณด์ฌ์ค ๋ ผ๋ฌธ์ด๋ ์ฐธ๊ณ ํ์๊ธฐ ๋ฐ๋๋๋ค.
- Paper: https://arxiv.org/pdf/1312.6114
์ฌ์ ์ง์: Domain Gap
VAE๋ ๋ ผ๋ฌธ์ ์ดํดํ๋ ค๋ฉด ๊ฝค ๋ง์(์ ์ด๋ ๋์๊ฒ๋) ์ฌ์ ์ง์์ด ํ์ํ์ผ๋ฉฐ, ๊ฐ๋จํ๊ฒ ์ ๋ฆฌํ๋ฉด ๋ค์๊ณผ ๊ฐ๋ค. (์์ธํ ์ค๋ช ์ ์ฐธ๊ณ ๋งํฌ๋ฅผ ํ์ธํ๊ธฐ ๋ฐ๋๋ค.)
[1] VAE๋ Generative Model์ด๋ค.
- Generative Model์ด๋ training data๊ฐ ์ฃผ์ด์ก์ ๋ ์ด training data๊ฐ ๊ฐ์ง๋ real ๋ถํฌ์ ๊ฐ์ ๋ถํฌ์์ sampling๋ ๊ฐ์ผ๋ก new data๋ฅผ ์์ฑํ๋ model์ ๋งํ๋ค.
- ์ด์ ์๋์ Generative model์ ๊ดํ ์ค๋ช ๊ธ
[2] ํ๋ฅ ํต๊ณ ์ด๋ก (Bayseain, conditional prob, pdf etc)
- ๋ฒ ์ด์ง์ ํ๋ฅ (Bayesian probability): ์ธ์์ ๋ฐ๋ณตํ ์ ์๋ ํน์ ์ ์ ์๋ ํ๋ฅ ๋ค, ์ฆ ์ผ์ด๋์ง ์์ ์ผ์ ๋ํ ํ๋ฅ ์ ์ฌ๊ฑด๊ณผ ๊ด๋ จ์ด ์๋ ์ฌ๋ฌ ํ๋ฅ ๋ค์ ์ด์ฉํด ์ฐ๋ฆฌ๊ฐ ์๊ณ ์ถ์ ์ฌ๊ฑด์ ์ถ์ ํ๋ ๊ฒ์ด ๋ฒ ์ด์ง์ ํ๋ฅ ์ด๋ค.
- ๋ฒ ์ด์ง์ ์ด๋ก ๊ด๋ จ ์ค๋ช ๊ธ
[3] ๊ด๋ จ ์ฉ์ด๋ค
- latent : ‘์ ์ฌํ๋’, ‘์จ์ด์๋’, ‘hidden’์ ๋ป์ ๊ฐ์ง ๋จ์ด. ์ฌ๊ธฐ์ ๋งํ๋ latent variable z๋ ํน์ง(feature)๋ฅผ ๊ฐ์ง vector๋ก ์ดํดํ๋ฉด ์ข๋ค.
- intractable : ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ํ์ํ ์๊ฐ์ด ๋ฌธ์ ์ ํฌ๊ธฐ์ ๋ฐ๋ผ ์ง์์ ์ผ๋ก (exponential) ์ฆ๊ฐํ๋ค๋ฉด ๊ทธ ๋ฌธ์ ๋ ๋ํด (intractable) ํ๋ค๊ณ ํ๋ค.
- explicit density model : ์ํ๋ง ๋ชจ๋ธ์ ๊ตฌ์กฐ(๋ถํฌ)๋ฅผ ๋ช ํํ ์ ์
- implicit density model : ์ํ๋ง ๋ชจ๋ธ์ ๊ตฌ์กฐ(๋ถํฌ)๋ฅผ explicitํ๊ฒ ์ ์ํ์ง ์์
- density estimation : x๋ผ๋ ๋ฐ์ดํฐ๋ง ๊ด์ฐฐํ ์ ์์ ๋, ๊ด์ฐฐํ ์ ์๋ x๊ฐ ์ํ๋ ํ๋ฅ ๋ฐ๋ํจ์(probability density function)์ estimateํ๋ ๊ฒ
- Gaussian distribution : ์ ๊ท๋ถํฌ
- Bernoulli distribution : ๋ฒ ๋ฅด๋์ด๋ถํฌ
- Marginal Probability : ์ฃผ๋ณ ํ๋ฅ ๋ถํฌ
- D_kl : ์ฟจ๋ฐฑ-๋ผ์ด๋ธ๋ฌ ๋ฐ์ฐ(Kullback–Leibler divergence, KLD), ๋ ํ๋ฅ ๋ถํฌ์ ์ฐจ์ด
- Encode / Decode: ์ํธํ,๋ถํธํ / ์ํธํํด์ ,๋ถํธํํด์
- likelihood : ๊ฐ๋ฅ๋. ์ด์ ๋ํ ์ค๋ช ์ ๊ผญ ์๋ ๋งํฌ์ ๋ค์ด๊ฐ ์ฝ์ด๋ณด๊ธธ ๋ฐ๋๋ค.
- likellihood์ ๋ํ ์ค๋ช ๊ธ
๋ฑ์ ๊ฐ๋ ๋ค์ ์์งํ๊ณ ๋์ด๊ฐ์ผ ๋ ผ๋ฌธ์ ๋ํ ๋ด์ฉ์ ์ดํด๋ฅผ ํ ์ ์๋ค.
[4] Auto-Encoder
- VAE์ ์คํ ์ธ์ฝ๋(AE)๋ ๋ชฉ์ ์ด ์ ํ ๋ค๋ฅด๋ค.
- ์คํ ์ธ์ฝ๋์ ๋ชฉ์ ์ ์ด๋ค ๋ฐ์ดํฐ๋ฅผ ์ ์์ถํ๋๊ฒ, ์ด๋ค ๋ฐ์ดํฐ์ ํน์ง์ ์ ๋ฝ๋ ๊ฒ, ์ด๋ค ๋ฐ์ดํฐ์ ์ฐจ์์ ์ ์ค์ด๋ ๊ฒ์ด๋ค.
- ๋ฐ๋ฉด VAE์ ๋ชฉ์ ์ Generative model์ผ๋ก ์ด๋ค ์๋ก์ด X๋ฅผ ๋ง๋ค์ด๋ด๋ ๊ฒ์ด๋ค.
Abstract
ํฐ ๋ฐ์ดํฐ์ ๊ณผ ๊ณ์ฐ์ด ๋ถ๊ฐ๋ฅํ posterior ๋ถํฌ๋ฅผ ๊ฐ์ง๋ ์ฐ์ํ ์ ์ฌ ๋ณ์๋ฅผ ๊ฐ์ง๊ณ ์์ ๋, ์ด๋ป๊ฒ directed probabilistic model์ ํจ์จ์ ์ผ๋ก ํ์ตํ๊ณ ์ถ๋ก ํ ์ ์์๊น?
์ฐ๋ฆฌ๋ ํฐ ๋ฐ์ดํฐ์ ์๋ ํ์ฅํ ์ ์๊ณ ๊ฐ๋ฒผ์ด ๋ฏธ๋ถ๊ฐ๋ฅ์ฑ ์กฐ๊ฑด์ด ์๋ค๋ฉด ๊ณ์ฐ์ด ๋ถ๊ฐ๋ฅํ ๊ฒฝ์ฐ์๋ ์๋ํ๋ stochastic variational inference and learning ์๊ณ ๋ฆฌ์ฆ์ ์ ์ํ๋ค.
์ฐ๋ฆฌ์ ๊ธฐ์ฌ๋ ๋ ๊ฐ์ง์ด๋ค.
์ฒซ ๋ฒ์งธ, variational lower bound์ reparameterization์ด ํ์ค์ ์ธ stochastic gradient ๋ฐฉ๋ฒ๋ก ๋ค์ ์ฌ์ฉํ์ฌ ์ง์ ์ ์ผ๋ก ์ต์ ํ๋ ์ ์๋ lower bound estimator๋ฅผ ๋ง๋ค์ด๋ธ๋ค๋ ๊ฒ์ ๋ณด์๋ค.
๋ ๋ฒ์งธ, ๊ฐ datapoint๊ฐ ์ฐ์ํ ์ ์ฌ ๋ณ์๋ฅผ ๊ฐ์ง๋ i.i.d. ๋ฐ์ดํฐ์ ์ ๋ํด์, ์ ์๋ lower bound estimator๋ฅผ ์ฌ์ฉํด approximate inference model(๋๋ recognition model์ด๋ผ๊ณ ๋ถ๋ฆผ)์ ๊ณ์ฐ์ด ๋ถ๊ฐ๋ฅํ posterior์ fitting ์ํด์ผ๋ก์จ posterior inference๊ฐ ํนํ ํจ์จ์ ์ผ๋ก ๋ง๋ค์ด์ง ์ ์๋ค๋ ์ ์ ๋ณด์ธ๋ค. ์คํ ๊ฒฐ๊ณผ์ ์ด๋ก ์ ์ด์ ์ด ๋ฐ์๋์๋ค.
VAE
์ด์ ๋ถํฐ ๋ณธ๊ฒฉ์ ์ผ๋ก VAE ๊ด๋ จ๋ ๋ด์ฉ๋ค์ ์ฝ๋์ ํจ๊ป ์ดํด๋ณด์. ๊ธฐ์กด์ ๋ ผ๋ฌธ์ ํ๋ฆ์ Generative Model์ด ๊ฐ์ง๋ ๋ฌธ์ ์ ๋ค์ ํด์ํ๊ธฐ ์ํด ์ด๋ค ๋ฐฉ์์ ๋์ ํ๋์ง ์ฐจ๋ก์ฐจ๋ก ์ค๋ช ํ๊ณ ์๋ค. ํ์ง๋ง ๊ด๋ จ๋ ์์๋ ๋ง๊ณ ์ค๊ฐ์ ์๋ต๋ ์๋ ๋ง์ ๋ ผ๋ฌธ๋๋ก ๋ฐ๋ผ๊ฐ๋ค๋ณด๋ฉด ์ ์ฒด์ ์ธ ๊ตฌ์กฐ๋ฅผ ์ดํดํ๊ธฐ ํ๋ค๊ธฐ๋๋ฌธ์ ๋จผ์ ๊ตฌ์กฐ๋ฅผ ์ดํด๋ณธ ๋ค ๊ฐ ๊ตฌ์กฐ๊ฐ ๊ฐ์ง๋ ์๋ฏธ๊ฐ ๋ฌด์์ธ์ง(์ ์ด ๊ตฌ์กฐ๊ฐ ๋์๋์ง) ์ดํด๋ณด๊ณ ์ต์ข ์ ์ผ๋ก ์ ๋ฆฌํ๋๋ก ํ ์์ ์ด๋ค. (Top-down approach(?))
VAE GOAL
๋ ผ๋ฌธ Abstract์ ๋์์๋ ์ฒซ ๋ฌธ์ฅ์ด๋ค. ์ด ๋ชฉ์ ์ ์ดํดํ๋ ๊ฒ์ด ๊ฐ์ฅ ์ค์ํ๋ ์ฒ์ฒํ ๋ณด๋ฉด์ ์ดํดํ๊ธฐ ๋ฐ๋๋ค.
How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets?
VAE์ ๋ชฉํ๋ Generative Model์ ๋ชฉํ์ ๊ฐ๋ค. (1) data์ ๊ฐ์ ๋ถํฌ๋ฅผ ๊ฐ์ง๋ sample ๋ถํฌ์์ sample์ ๋ฝ๊ณ (2) ์ด๋ค ์๋ก์ด ๊ฒ์ ์์ฑํด๋ด๋ ๊ฒ์ด ๋ชฉํ๋ค. ์ฆ,
- (1) ์ฃผ์ด์ง training data๊ฐ p_data(x)(ํ๋ฅ ๋ฐ๋ํจ์)๊ฐ ์ด๋ค ๋ถํฌ๋ฅผ ๊ฐ์ง๊ณ ์๋ค๋ฉด, sample ๋ชจ๋ธ p_model(x) ์ญ์ ๊ฐ์ ๋ถํฌ๋ฅผ ๊ฐ์ง๋ฉด์, (sampling ๋ถ๋ถ)
- (2) ๊ทธ ๋ชจ๋ธ์ ํตํด ๋์จ inference ๊ฐ์ด ์๋ก์ด x๋ผ๋ ๋ฐ์ดํฐ์ด๊ธธ ๋ฐ๋๋ค. (Generation ๋ถ๋ถ)
์๋ฅผ ๋ค์ด, ๋ช ๊ฐ์ ๋ค์ด์๋ชฌ๋(training data)๋ฅผ ๊ฐ์ง๊ณ ์๋ค๊ณ ์๊ฐํด๋ณด์. ๊ทธ๋ฌ๋ฉด training ๋ค์ด์๋ชฌ๋ ๋ฟ๋ง์๋๋ผ ๋ชจ๋ ๋ค์ด์๋ชฌ๋์ ํ๋ฅ ๋ถํฌ์ ๋๊ฐ์ ๋ถํฌ๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์์ ๊ฐ์ ๋ฝ์(1. sampling) training ์์ผฐ๋ ๋ค์ด์๋ชฌ๋์๋ ๋ค๋ฅธ ๋ ๋ค๋ฅธ ๋ค์ด์๋ชฌ๋(new)๋ฅผ ๋ง๋๋(generate) ๊ฒ์ด๋ค.
VAE ๊ตฌ์กฐ
๋ฐฑ๋ฌธ์ด ๋ถ์ด์ผ๊ฒฌ. VAE์ ์ ์ฒด ๊ตฌ์กฐ๋ฅผ ํ ๋์์ผ๋ก ์ดํด๋ณด์.
์ผ๋ผ์ค ๊ต์ฌ์ ๊ตฌํ๋ ์ฝ๋์ ๋ ผ๋ฌธ์ ๊ตฌ์กฐ๋ ์ฝ๊ฐ์ ์ฐจ์ด๊ฐ ์๋ค. ์ ์ฒด์ ์ธ ๊ตฌ์กฐ๋ ๋๊ฐ์ผ๋ ํฌ๊ฒ ํท๊ฐ๋ฆด ๊ฒ์ ์์ง๋ง, ๊ทธ๋๋ ์ฝ๋์ ์ฝ๊ฐ์ ๋ณํ๋ ๋ถ๋ถ์ ๋ค์๊ณผ ๊ฐ๋ค.
๋ ผ๋ฌธ๊ณผ ๋ค๋ฅธ์ : Input shape, Encoder์ NN ๋ชจ๋ธ, Decoder์ NN๋ชจ๋ธ (์ฝ๋์์๋ ์ผ์ชฝ์ ๊ฐ ๋ถ๋ถ๋ค์ DNN์ CNN๊ตฌ์กฐ๋ก ๋ฐ๊ฟ)
์์ ๋์์ VAE ๊ตฌ์กฐ๋ฅผ ์๋ฒฝํ ์ ๋ฆฌํ ๊ทธ๋ฆผ์ด๋ค. ์ด์ ์ด ๊ทธ๋ฆผ์ ๋ณด๋ฉด์, input ๊ทธ๋ฆผ์ด ์์ ๋ ์ด๋ค ์๋ฏธ๋ฅผ ๊ฐ์ง ๊ตฌ์กฐ๋ฅผ ๊ฑฐ์ณ output์ด ๋์ค๊ฒ ๋๋์ง 3 ๋จ๊ณ๋ก ๋๋์ด ์ดํด๋ณด์.
- input: x –> ๐_∅ (๐ฅ)–> ๐_๐,๐_๐
- ๐_๐, ๐_๐, ๐_๐ –> ๐ง_๐
- ๐ง_๐ –> ๐_๐ (๐ง_๐) –> ๐_๐ : output
1. Encoder
input: x –> ๐_∅ (๐ฅ)–> ๐_๐,๐_๐
img_shape = (28,28,1)
batch_size = 16
latent_dim = 2
input_img = keras.Input(shape = img_shape)
x = layers.Conv2D(32,3,padding='same',activation='relu')(input_img)
x = layers.Conv2D(64,3,padding='same',activation='relu',strides=(2,2))(x)
x = layers.Conv2D(64,3,padding='same',activation='relu')(x)
x = layers.Conv2D(64,3,padding='same',activation='relu')(x)
shape_before_flattening = K.int_shape(x) # return tuple of integers of shape of x
x = layers.Flatten()(x)
x = layers.Dense(32,activation='relu')(x)
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)
- Input shape(x) : (28,28,1)
- ๐_∅ (๐ฅ) ๋ encoder ํจ์์ธ๋ฐ, x๊ฐ ์ฃผ์ด์ก์๋(given) z๊ฐ์ ๋ถํฌ์ ํ๊ท ๊ณผ ๋ถ์ฐ์ ์์ํ์ผ๋ก ๋ด๋ ํจ์์ด๋ค.
- ๋ค์๋งํด q ํจ์(=Encoder)์ output์ ๐_๐,๐_๐ ์ด๋ค.
์ด๋ค X๋ผ๋ ์ ๋ ฅ์ ๋ฃ์ด ์ธ์ฝ๋์ ์์ํ์ ๐_๐,๐_๐ ์ด๋ค. ์ด๋ค ๋ฐ์ดํฐ์ ํน์ง์(latent variable) X๋ฅผ ํตํด ์ถ์ธกํ๋ค. ๊ธฐ๋ณธ์ ์ผ๋ก ์ฌ๊ธฐ์ ๋์จ ํน์ง๋ค์ ๋ถํฌ๋ ์ ๊ท๋ถํฌ๋ฅผ ๋ฐ๋ฅธ๋ค๊ณ ๊ฐ์ ํ๋ค. ์ด๋ฐ ํน์ง๋ค์ด ๊ฐ์ง๋ ํ๋ฅ ๋ถํฌ ๐_∅ (๐ฅ) (์ ํํ ๋งํ๋ฉด $์ true ๋ถํฌ (= $)๋ฅผ ์ ๊ท๋ถํฌ(=Gaussian)๋ผ ๊ฐ์ ํ๋ค๋ ๋ง์ด๋ค. ๋ฐ๋ผ์ latent space์ latent variable ๊ฐ๋ค์ ๐_∅ (๐ฅ)์ true ๋ถํฌ๋ฅผ approximateํ๋ ๐_๐,๐_๐๋ฅผ ๋ํ๋ธ๋ค.
Encoder ํจ์์ output์ latent variable์ ๋ถํฌ์ ๐ ์ ๐ ๋ฅผ ๋ด๊ณ , ์ด output๊ฐ์ ํํํ๋ ํ๋ฅ ๋ฐ๋ํจ์๋ฅผ ์๊ฐํด๋ณผ ์ ์๋ค.
2. Reparameterization Trick (Sampling)
๐_๐, ๐_๐, ๐_๐ –> ๐ง_๐
def sampling(args):
z_mean, z_log_var = args
epsilon = K.random_normal(shape=(K.shape(z_mean)[0],latent_dim),mean=0., stddev=1.)
return z_mean + K.exp(z_log_var) * epsilon
z = layers.Lambda(sampling)([z_mean, z_log_var])
๋ง์ฝ Encoder ๊ฒฐ๊ณผ์์ ๋์จ ๊ฐ์ ํ์ฉํด decoding ํ๋๋ฐ sampling ํ์ง ์๋๋ค๋ฉด ์ด๋ค ์ผ์ด ๋ฒ์ด์ง๊น? ๋น์ฐํ ๋ ํ ๊ฐ์ ๊ฐ์ง๋ฏ๋ก ๊ทธ์ ๋ํ decoder(NN)์ญ์ ํ ๊ฐ๋ง ๋ฑ๋๋ค. ๊ทธ๋ ๊ฒ ๋๋ค๋ฉด ์ด๋ค ํ variable์ ๋ฌด์กฐ๊ฑด ๋๊ฐ์ ํ ๊ฐ์ output์ ๊ฐ์ง๊ฒ ๋๋ค.
ํ์ง๋ง Generative Model, VAE๊ฐ ํ๊ณ ์ถ์ ๊ฒ์, ์ด๋ค data์ true ๋ถํฌ๊ฐ ์์ผ๋ฉด ๊ทธ ๋ถํฌ์์ ํ๋๋ฅผ ๋ฝ์ ๊ธฐ์กด DB์ ์์ง ์์ ์๋ก์ด data๋ฅผ ์์ฑํ๊ณ ์ถ๋ค. ๋ฐ๋ผ์ ์ฐ๋ฆฌ๋ ํ์ฐ์ ์ผ๋ก ๊ทธ ๋ฐ์ดํฐ์ ํ๋ฅ ๋ถํฌ์ ๊ฐ์ ๋ถํฌ์์ ํ๋๋ฅผ ๋ฝ๋ sampling์ ํด์ผํ๋ค. ํ์ง๋ง ๊ทธ๋ฅ sampling ํ๋ค๋ฉด sampling ํ ๊ฐ๋ค์ backpropagation ํ ์ ์๋ค.(์๋์ ๊ทธ๋ฆผ์ ๋ณด๋ฉด ์ง๊ด์ ์ผ๋ก ์ดํดํ ์ ์๋ค) ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด reparmeterization trick์ ์ฌ์ฉํ๋ค.
์ ๊ท๋ถํฌ์์ z1๋ฅผ ์ํ๋งํ๋ ๊ฒ์ด๋, ์ ์ค๋ก ์ ์ ๊ท๋ถํฌ(์์ธํ๋ N(0,1))์์ ์ํ๋งํ๊ณ ๊ทธ ๊ฐ์ ๋ถ์ฐ๊ณผ ๊ณฑํ๊ณ ํ๊ท ์ ๋ํด z2๋ฅผ ๋ง๋ค๊ฑฐ๋ ๋ z1,z2 ๋ ๊ฐ์ ๋ถํฌ๋ฅผ ๊ฐ์ง๊ธฐ ๋๋ฌธ์ด๋ค. ๊ทธ๋์ ์ฝ๋์์ epsilon์ ๋จผ์ ์ ๊ท๋ถํฌ์์ randomํ๊ฒ ๋ฝ๊ณ , ๊ทธ epsilon์ exp(z_log_var)๊ณผ ๊ณฑํ๊ณ z_mean์ ๋ํ๋ค. ๊ทธ๋ ๊ฒ ํ์ฑ๋ ๊ฐ์ด z๊ฐ ๋๋ค.
latent variable์์ sample๋ z๋ผ๋ value (= decoder input)์ด ๋ง๋ค์ด์ง๋ค.
3. Decoder
๐ง_๐ –> ๐_๐ (๐ง_๐) –> ๐_๐ : output
## 8.25 VAE decoder network, mapping latent space points to imgaes
decoder_input = layers.Input(K.int_shape(z)[1:])
x = layers.Dense(np.prod(shape_before_flattening[1:]),activation='relu')(decoder_input)
x = layers.Reshape(shape_before_flattening[1:])(x)
x = layers.Conv2DTranspose(32,3,padding='same',activation='relu',strides=(2,2))(x)
x = layers.Conv2D(1,3,padding='same',activation='sigmoid')(x)
decoder = Model(decoder_input, x)
z_decoded = decoder(z)
z ๊ฐ์ g ํจ์(decoder)์ ๋ฃ๊ณ deconv(์ฝ๋์์๋ Conv2DTranspose)๋ฅผ ํด ์๋ ์ด๋ฏธ์ง ์ฌ์ด์ฆ์ ์์ํ z_decoded๊ฐ ๋์ค๊ฒ ๋๋ค. ์ด๋ p_data(x)์ ๋ถํฌ๋ฅผ Bernoulli ๋ก ๊ฐ์ ํ์ผ๋ฏ๋ก(์ด๋ฏธ์ง recognition ์์ Gaussian ์ผ๋ก ๊ฐ์ ํ ๋๋ณด๋ค Bernoulli๋ก ๊ฐ์ ํด์ผ ์๋ฏธ์ ๊ทธ๋ฆฌ๊ณ ๊ฒฐ๊ณผ์ ๋ ์ ์ ํ๊ธฐ ๋๋ฌธ) output ๊ฐ์ 0~1 ์ฌ์ด ๊ฐ์ ๊ฐ์ ธ์ผํ๊ณ , ์ด๋ฅผ ์ํด activatino function์ sigmoid๋ก ์ค์ ํด์ฃผ์๋ค. (Gaussian ๋ถํฌ๋ฅผ ๋ฐ๋ฅธ๋ค๊ณ ๊ฐ์ ํ๊ณ ํผ๋ค๋ฉด ์๋ loss๋ฅผ ๋ค๋ฅด๊ฒ ์ค์ ํด์ผํ๋ค.)
VAE ํ์ต
Loss Fucntion ์ดํด
Loss ๋ ํฌ๊ฒ ์ด 2๊ฐ์ง ๋ถ๋ถ์ด ์๋ค.
def vae_loss(self, x, z_decoded):
x = K.flatten(x)
z_decoded = K.flatten(z_decoded)
xent_loss = keras.metrics.binary_crossentropy(x,z_decoded)
kl_loss = -5e-4*K.mean(1+z_log_var-K.square(z_mean)-K.exp(z_log_var),axis=-1)
return K.mean(xent_loss + kl_loss)
- Reconstruction Loss(code์์๋ xent_loss)
- Regularization Loss(code์์๋ kl_loss)
์ผ๋จ ์ง๊ด์ ์ผ๋ก ์ดํด๋ฅผ ํ์๋ฉด,
- Generative ๋ชจ๋ธ๋ต๊ฒ ์๋ก์ด X๋ฅผ ๋ง๋ค์ด์ผํ๋ฏ๋ก X์ ๋ง๋ค์ด์ง output, New X์์ ๊ด๊ณ๋ฅผ ์ดํด๋ด์ผํ๊ณ , ์ด๋ฅผ Reconstruction Loss ๋ถ๋ถ์ด๋ผ๊ณ ํ๋ค. ์ด๋ ๋์ฝ๋ ๋ถ๋ถ์ pdf๋ Bernoulli ๋ถํฌ๋ฅผ ๋ฐ๋ฅธ๋ค๊ณ ๊ฐ์ ํ์ผ๋ฏ๋ก ๊ทธ ๋๊ฐ์ cross entropy๋ฅผ ๊ตฌํ๋ค( ์ด ๋ถ๋ถ์ ๋ํด์ ์ ๊ฐ์์ง๋ ์์์ ํฌํจํ ํฌ์คํฐ์์ ๋ ์์ธํ ๋ค๋ฃฐ ๊ฒ์ด๋ค)
- X๊ฐ ์๋ ๊ฐ์ง๋ ๋ถํฌ์ ๋์ผํ ๋ถํฌ๋ฅผ ๊ฐ์ง๊ฒ ํ์ตํ๊ฒ ํ๊ธฐ์ํด true ๋ถํฌ๋ฅผ approximate ํ ํจ์์ ๋ถํฌ์ ๋ํ loss term์ด Regularization Loss๋ค. ์ด๋ loss๋ true pdf ์ approximated pdf๊ฐ์ D_kl(๋ ํ๋ฅ ๋ถํฌ์ ์ฐจ์ด(๊ฑฐ๋ฆฌ))์ ๊ณ์ฐํ๋ค. (์ด ๋ถ๋ถ๋ ์ญ์ ์ ์ด๋ฐ ์์ด ๋์๋์ง๋ ์์์ ํฌํจํ ํฌ์คํ ์ ๋ ์์ธํ ๋ค๋ฃฐ ๊ฒ์ด๋ค)
ํ์ต
encoder ๋ถ๋ถ๊ณผ decoder ๋ถ๋ถ์ ํฉ์ณ ํ ๋ชจ๋ธ์ ๋ง๋ค๊ณ train ํ๋ฉด ๋! ์์ธํ ์ฝ๋๋ Github์ ์ฌ๋ ค๋์์ผ๋ ์ฐธ๊ณ ํ๊ธฐ ๋ฐ๋๋ค.
VAE ๊ฒฐ๊ณผ
import matplotlib.pyplot as plt
from scipy.stats import norm
n=20
digit_size = 28
figure = np.zeros((digit_size*n,digit_size*n))
grid_x = norm.ppf(np.linspace(0.05,0.95,n))
grid_y = norm.ppf(np.linspace(0.05,0.95,n))
for i, yi in enumerate(grid_x):
for j, xi in enumerate(grid_y):
z_sample = np.array([[xi,yi]])
z_sample = np.tile(z_sample,batch_size).reshape(batch_size,2)
x_decoded = decoder.predict(z_sample, batch_size = batch_size)
digit = x_decoded[0].reshape(digit_size, digit_size)
figure[i * digit_size: (i+1)*digit_size, j*digit_size:(j+1)*digit_size] = digit
plt.figure(figsize=(10,10))
plt.imshow(figure, cmap ='Greys_r')
plt.show()
์์ ์ฝ๋๋ฅผ ์คํ์ํค๋ฉด ์ ๊ทธ๋ฆผ์์ ์ค๋ฅธ์ชฝ๊ณผ ๊ฐ์ ๋์์ด ๋์ค๋๋ฐ ํ์ต์ด ์ ๋์๋ค๋ฉด ์ฐจ์์ manifold๋ฅผ ์ ํ์ตํ๋ค๋ ๋ง์ด๋ค. ๊ทธ manifold๋ฅผ 2์ฐจ์์ผ๋ก ์ถ์์ํจ ๊ฒ(z1,z2)์์ z1 20๊ฐ(0.05~0.95), z2 20๊ฐ, ์ด 400๊ฐ์ ์์์์ xi,yi์์ sample์ ๋ฝ์ ์๊ฐํํ๊ฒ์ด ์ค๋ฅธ์ชฝ ๊ทธ๋ฆผ์ธ๋ฐ 2D์์์ ๊ฑฐ๋ฆฌ์ ์ ์๋ฏธํ ์ฐจ์ด์ ๋ฐ๋ผ ์ซ์๋ค์ด ๋ฌ๋ผ์ง๋ ๊ฒ์ ํ์ธํ ์ ์์ผ๋ฉฐ, ๊ฐ ์ซ์ ์์์๋ ์๋ก ๋ค๋ฅธ rotation๋ค์ ๊ฐ์ง๊ณ ์๋ค๋ ๊ฒ์ด ๋ณด์ธ๋ค.
Insight
๋ง์ง๋ง์ผ๋ก VAE๋ ์ํ๋๊ณ ๋ฌผ์ด๋ณธ๋ค๋ฉด ํฌ๊ฒ 2๊ฐ์ง๋ก ๋ตํ ์ ์๋ค.
- Generative Model ๋ชฉ์ ๋ฌ์ฑ
- Latent variable control ๊ฐ๋ฅ
- Generative Model์ ํตํด ์ ์ data๋ฅผ ๊ฐ์ง๊ณ ์๋ data๊ฐ ๊ฐ์ง๋ ๋ถํฌ๋ฅผ ๊ฝค ๊ฐ๊น๊ฒ ๊ทผ์ฌํ๊ณ ์ด๋ฅผ ํตํด ์๋ก์ด data๋ฅผ ์์ฑํด๋ผ ์ ์๋ค๋ ์ .
- Latent variable์ ๋ถํฌ๋ฅผ ๊ฐ์ ํด ์ฐ๋ฆฌ๊ฐ sampling ํ ๊ฐ๋ค์ ๋ถํฌ๋ฅผ control ํ ์ ์๊ฒ ๋๊ณ , manifold๋ ์ ํ์ต์ด ๋๋ค๋์ .
- ์ด๋ data์ ํน์ง๋ค๋ ์ ์ ์ ์๊ณ , ๊ทธ ํน์ง๋ค์ ๋ถํฌ๋ค์ ํฌ๊ฒ ๋ฒ์ด๋์ง ์๊ฒ control ํ๋ฉด์ ๊ทธ ์์์ ์๋ก์ด ๊ฐ์ ๋ง๋ค ์ ์๋ค๋ ์ .
์ ๋๊ฐ ๋ ๊ฒ ๊ฐ๋ค.
VAE์ ํ๊ณ์ ์ ๊ทน๋ณตํ๊ธฐ ์ํด, CVAE, AAE๊ฐ ๋์์ผ๋ ๊ด์ฌ์๋ ์ฌ๋์ ๊ด๋ จ๋ ์๋ฃ๋ฅผ ์ฐพ์๋ณด๊ธฐ ๋ฐ๋๋ค
์ฐธ๊ณ
- [Paper] VAE: https://arxiv.org/pdf/1312.6114.pdf
- [PPT] ๋ค์ด๋ฒ ์ดํ์๋์ ์ฌ๋ผ์ด๋๋ ธํธ: https://www.slideshare.net/NaverEngineering/ss-96581209
- [Youtube] ๋ค์ด๋ฒ ์ดํ์๋์ ์คํ ์ธ์ฝ๋์ ๋ชจ๋ ๊ฒ ๊ฐ์ in Youtube: https://www.youtube.com/watch?v=rNh2CrTFpm4&t=2206s
- [Blog] ์ด์ ์๋์ VAE ๊ด๋ จ tutorial ๊ธ: https://dnddnjs.github.io/paper/2018/06/19/vae/
- [Github] ์ฝ๋ in Github : Deeplearning-with-Python-์ผ๋ผ์ค ๋ฅ๋ฌ๋ ๊ต์ฌ ์์: https://github.com/Taeu/FirstYear_RealWorld/blob/master/GoogleStudy/Keras_week8_2/8.4%20VAE.ipynb
- [Blog] VAE(Auto-Encoding Variational Bayes) ์ง๊ด์ ์ดํด: https://taeu.github.io/paper/deeplearning-paper-vae/