Variational Autoencoders (I): That's What It's All About

By 苏剑林 | March 18, 2018

Although I hadn't looked into it closely in the past, I always had the impression that the Variational Autoencoder (VAE) was a fascinating tool. Taking advantage of a recent short-lived enthusiasm for Probabilistic Graphical Models, I decided to strive to understand VAE as well. As usual, I flipped through a lot of materials online and found that, without exception, they were quite vague. The general feeling was that while there were pages of formulas, everything remained blurry. When I finally thought I understood it and checked the implementation code, I felt the code and the theory were two completely different things.

Finally, by piecing things together and combining them with my recent accumulation of knowledge on probabilistic models, and repeatedly comparing everything with the original paper "Auto-Encoding Variational Bayes", I think I've finally figured it out. Actually, the real VAE is somewhat different from what many tutorials describe. Many tutorials write a lot without highlighting the key points of the model. Thus, I wrote this piece, hoping to explain VAE clearly through the following text.

Distribution Transformation

We usually compare VAE with GAN. Indeed, their goals are essentially the same: both hope to build a model that generates target data $X$ from latent variables $Z$, although they implement it differently. More precisely, they assume that $Z$ follows some common distribution (such as a normal or uniform distribution) and then train a model $X=g(Z)$ that maps this distribution to the distribution of the training set. In other words, the purpose of both is to transform one distribution into another.

Difficulty of generative models
The difficulty of generative models is judging the similarity between the generated distribution and the true distribution because we only know the sampling results of both, not their analytical distribution expressions.

Now, assume $Z$ follows a standard normal distribution. I can sample several $Z_1, Z_2, \dots, Z_n$ from it and then transform them to get $\hat{X}_1 = g(Z_1), \hat{X}_2 = g(Z_2), \dots, \hat{X}_n = g(Z_n)$. How do we judge whether the distribution of this dataset constructed through $g$ is the same as the distribution of our target dataset? Some readers might say, "Isn't there KL divergence?" Unfortunately, that does not work directly: KL divergence measures the similarity of two distributions from their probability density expressions, and we do not know those expressions here. We only have a batch of data $\{\hat{X}_1, \hat{X}_2, \dots, \hat{X}_n\}$ sampled from the constructed distribution and a batch of data $\{X_1, X_2, \dots, X_n\}$ sampled from the real distribution (that is, the training set we want to imitate). With only samples and no expressions, we have no way to calculate the KL divergence.

Despite these difficulties, a solution must be found. GAN's approach is direct and blunt: since there is no suitable metric, simply train the metric with a neural network as well. This line of thinking led to WGAN (for details, please refer to "The Art of Mutual Conflict: Direct to WGAN-GP from Zero"). VAE, on the other hand, uses a subtle and roundabout trick.

A Slow Walk Through VAE

In this part, we first review how typical tutorials introduce VAE, then examine the problems with those explanations, and from there naturally arrive at what VAE really is.

Classic Review

First, we have a batch of data samples $\{X_1, \dots, X_n\}$, collectively denoted $X$. Ideally, we would obtain the distribution $p(X)$ of $X$ from these samples. If we could, we could sample directly from $p(X)$ to obtain all possible $X$ (including ones beyond the original samples), which would be the ultimate generative model. Of course, this ideal is hard to achieve, so we rewrite the distribution as:

\[p(X)=\sum_Z p(X|Z)p(Z) \tag{1}\]

Here we do not distinguish between summation and integration; the meaning should be clear either way. In this expression, $p(X|Z)$ describes a model that generates $X$ from $Z$, and we assume $Z$ follows a standard normal distribution, i.e., $p(Z)=\mathcal{N}(0,I)$. If this could be realized, we could first sample a $Z$ from the standard normal distribution and then compute an $X$ from $Z$, which would also be an excellent generative model. The classic account then combines this with an autoencoder that performs reconstruction, to ensure no useful information is lost, adds a series of derivations, and finally arrives at the model. The schematic diagram of the framework is as follows:

Traditional understanding of VAE

Do you see the problem? If things followed this diagram, we would have no idea whether the resampled $Z_k$ still corresponds to the original $X_k$, so directly minimizing $\mathcal{D}(\hat{X}_k, X_k)^2$ (where $\mathcal{D}$ is some distance function) would make no sense. In fact, if you look at the code, you will find it is not implemented this way at all. In other words, many tutorials say a lot of plausible-sounding things, but their code does not follow their own text, and they do not seem to notice the contradiction.

VAE Emerges

In fact, in the entire VAE model, we do not use the assumption that $p(Z)$ (the distribution of the latent space) is a normal distribution; instead, we assume that $p(Z|X)$ (the posterior distribution) is a normal distribution!!

Specifically, given a real sample $X_k$, we assume there exists a distribution $p(Z|X_k)$ exclusive to $X_k$ (academically called the posterior distribution), and we further assume this distribution is an (independent-component, multivariate) normal distribution. Why emphasize "exclusive"? Because we will later train a generator $X=g(Z)$ that should restore $X_k$ from a $Z_k$ sampled from $p(Z|X_k)$. If we assumed $p(Z)$ was a normal distribution and sampled a $Z$ from $p(Z)$, how would we know which real $X$ this $Z$ corresponds to? Since $p(Z|X_k)$ is exclusive to $X_k$, we have good reason to say that a $Z$ sampled from it should be restored to $X_k$.

In fact, in the application section of the paper "Auto-Encoding Variational Bayes," this point is specifically emphasized:

In this case, we can let the variational approximate posterior be a multivariate Gaussian with a diagonal covariance structure: \[\log q_{\phi}(\boldsymbol{z}|\boldsymbol{x}^{(i)}) = \log \mathcal{N}(\boldsymbol{z} ;\boldsymbol{\mu}^{(i)},\boldsymbol{\sigma}^{2(i)}\boldsymbol{I}) \tag{9}\] (Note: This is quoted directly from the original paper, whose notation is not entirely consistent with this article's; hopefully this causes no confusion.)

Equation (9) in the paper is the key to implementing the entire model. I don't know why many tutorials skip over this when introducing VAE. Although the paper also mentions $p(Z)$ being a standard normal distribution, that is not the most essential part.

Returning to this article, each $X_k$ is now paired with an exclusive normal distribution, facilitating the generator's restoration work. But this means there are as many normal distributions as there are $X$ samples. We know that a normal distribution has two sets of parameters: mean $\mu$ and variance $\sigma^2$ (for multivariate, they are vectors). How do I find the mean and variance of the normal distribution $p(Z|X_k)$ exclusive to $X_k$? There doesn't seem to be a direct approach. Well, then I'll use a neural network to fit them! This is the philosophy of the neural network era: use neural networks to fit anything hard to calculate. We've experienced this once with WGAN, and now we experience it again.

So we construct two neural networks $\mu_k = f_1(X_k), \log \sigma_k^2 = f_2(X_k)$ to compute them. We fit $\log \sigma_k^2$ rather than $\sigma_k^2$ because $\sigma_k^2$ is always positive and would require an activation function to enforce that, whereas $\log \sigma_k^2$ can be any real number and needs none. Now we know the mean and variance exclusive to $X_k$, and therefore what its normal distribution looks like. We then sample a $Z_k$ from this exclusive distribution and pass it through a generator to get $\hat{X}_k=g(Z_k)$. Now we can safely minimize $\mathcal{D}(\hat{X}_k, X_k)^2$: since $Z_k$ was sampled from the distribution exclusive to $X_k$, the generator should restore the original $X_k$. Thus, we can draw the schematic diagram of VAE:

In reality VAE constructs exclusive distributions
In fact, VAE constructs an exclusive normal distribution for each sample and then samples to reconstruct.
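To make the per-sample pipeline concrete, here is a minimal Python sketch of one forward pass for a single sample $X_k$. It is only an illustration under stated assumptions. The functions `f1`, `f2`, and `g` are hypothetical fixed linear maps standing in for trained networks, the latent space is one-dimensional, and $\mathcal{D}$ is taken to be Euclidean distance. The sampling line is written as $\mu + \sigma\varepsilon$, which is the form justified by the reparameterization trick discussed later.

```python
import math
import random

# Hypothetical toy "networks": fixed linear maps stand in for the trained
# MLPs f1, f2 (encoder heads) and g (decoder). Real models learn these weights.
def f1(x):
    """Encoder head for the mean mu_k (one latent dimension here)."""
    return 0.5 * x[0] - 0.25 * x[1]

def f2(x):
    """Encoder head for log sigma_k^2; its output is unconstrained."""
    return -1.0 + 0.1 * x[0]

def g(z):
    """Decoder: maps a latent z back to a 2-D data point."""
    return [2.0 * z, -1.0 * z]

def vae_forward(x, rng):
    mu = f1(x)
    log_var = f2(x)
    sigma = math.exp(0.5 * log_var)           # sigma = sqrt(exp(log sigma^2))
    z = mu + sigma * rng.gauss(0.0, 1.0)      # draw from this sample's own N(mu, sigma^2)
    x_hat = g(z)
    recon = sum((a - b) ** 2 for a, b in zip(x_hat, x))  # D(x_hat, x)^2
    return mu, log_var, z, x_hat, recon

rng = random.Random(0)
x_k = [1.0, 0.4]
mu, log_var, z, x_hat, recon = vae_forward(x_k, rng)
print(f"mu={mu:.3f} log_var={log_var:.3f} z={z:.3f} recon={recon:.3f}")
```

The key point is that `mu` and `log_var` depend on `x_k`, so the noise is drawn from a distribution tied to this particular sample rather than from a shared $p(Z)$.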

Distribution Standardization

Let's think about what result we will eventually get according to the training process described above.

First, we hope to reconstruct $X$, which means minimizing $\mathcal{D}(\hat{X}_k, X_k)^2$. However, this reconstruction is affected by noise, because $Z_k$ is resampled rather than computed directly by the encoder. Obviously, noise makes reconstruction harder. Note, however, that the intensity of this noise (the variance) is itself computed by a neural network. To reconstruct better, the model will therefore do everything it can to drive the variance to zero. With zero variance there is no randomness: no matter how you sample, you always get the same deterministic result, namely the mean, which is computed by the other neural network. And fitting one result is obviously easier than fitting many.

To put it bluntly, the model will slowly degenerate into a standard AutoEncoder; the noise will no longer play a role.
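The collapse argument can be checked with a worked example. Assume, purely for illustration, a hypothetical linear decoder $g(z)=wz+b$ and squared-error reconstruction. For $Z_k \sim \mathcal{N}(\mu,\sigma^2)$, the expected error is $(w\mu+b-X_k)^2 + w^2\sigma^2$. The noise therefore only ever adds the penalty $w^2\sigma^2$, and without any other loss the best choice is $\sigma=0$. The sketch below compares this closed form with a Monte Carlo estimate.

```python
import random

def expected_recon_error(mu, sigma, w, b, x):
    """Closed form of E[(w*z + b - x)^2] for z ~ N(mu, sigma^2)."""
    return (w * mu + b - x) ** 2 + (w * sigma) ** 2

def mc_recon_error(mu, sigma, w, b, x, n=100_000, seed=0):
    """Monte Carlo estimate of the same expectation."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        z = mu + sigma * rng.gauss(0.0, 1.0)
        total += (w * z + b - x) ** 2
    return total / n

w, b, x = 2.0, 0.5, 1.5      # hypothetical decoder weights and target
mu = (x - b) / w             # the mean that the decoder maps exactly onto x
for sigma in (1.0, 0.5, 0.1, 0.0):
    print(sigma, expected_recon_error(mu, sigma, w, b, x), mc_recon_error(mu, sigma, w, b, x))
```

Even with the best possible mean, the error grows as $w^2\sigma^2$. That is why reconstruction on its own pushes the variance head toward zero.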

Isn't that a waste of effort? What about the promised "generative model"?

Don't worry. Actually, VAE also makes all $p(Z|X)$ look like a standard normal distribution, which prevents the noise from becoming zero and ensures the model has generative capabilities. How do we understand "ensuring generative capabilities"? If all $p(Z|X)$ are very close to the standard normal distribution $\mathcal{N}(0,I)$, then by definition:

\[p(Z)=\sum_X p(Z|X)p(X)=\sum_X \mathcal{N}(0,I)p(X)=\mathcal{N}(0,I) \sum_X p(X) = \mathcal{N}(0,I) \tag{2}\]

In this way, we can achieve our prior assumption: $p(Z)$ is a standard normal distribution. Then we can safely sample from $\mathcal{N}(0,I)$ to generate images.

VAE forces p(Z|X) toward normal distribution
To give the model generative power, VAE requires each $p(Z|X)$ to look like a standard normal distribution.
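Equation (2) can also be illustrated through moments. Treating $p(Z)$ as the mixture $\sum_X p(Z|X)p(X)$, the law of total variance gives $\text{Var}(Z) = \mathbb{E}_X[\sigma^2] + \text{Var}_X(\mu)$. The sketch below assumes one latent dimension, a uniform $p(X)$ over a few hypothetical samples, and made-up posterior parameters, and computes the mixture's mean and variance. Near-standard posteriors give a mixture whose mean and variance are close to those of $\mathcal{N}(0,1)$. Collapsed posteriors with spread-out means do not. Matching two moments does not by itself prove normality, so this is only a quick diagnostic.

```python
def mixture_moments(mus, sigmas, weights=None):
    """Mean and variance of the 1-D mixture sum_k w_k * N(mu_k, sigma_k^2)."""
    n = len(mus)
    if weights is None:
        weights = [1.0 / n] * n  # assume a uniform p(X) over the samples
    mean = sum(w * m for w, m in zip(weights, mus))
    # E[Z^2] = sum_k w_k (sigma_k^2 + mu_k^2); Var = E[Z^2] - E[Z]^2
    second = sum(w * (s * s + m * m) for w, m, s in zip(weights, mus, sigmas))
    return mean, second - mean * mean

# Made-up posterior parameters for four hypothetical samples.
near_standard = mixture_moments([0.05, -0.03, 0.02, -0.04], [0.98, 1.01, 0.99, 1.02])
collapsed = mixture_moments([-2.0, -0.5, 0.7, 2.1], [0.01, 0.01, 0.01, 0.01])
print(near_standard)
print(collapsed)
```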

So how do we make all $p(Z|X)$ look like $\mathcal{N}(0,I)$? Without external knowledge, the most direct method would be to add an extra loss on top of the reconstruction error:

\[\mathcal{L}_{\mu}=\Vert f_1(X_k)\Vert^2 \quad \text{and} \quad \mathcal{L}_{\sigma^2}=\Vert f_2(X_k)\Vert^2 \tag{3}\]

Because they respectively represent the mean $\mu_k$ and the logarithm of the variance $\log\sigma_k^2$, reaching $\mathcal{N}(0,I)$ means hoping both are as close to 0 as possible. However, this again faces the problem of how to choose the ratio between these two losses; if chosen poorly, the generated images will be blurry. Therefore, the original paper directly calculated the KL divergence between a general (with independent components) normal distribution and the standard normal distribution $KL\Big(N(\mu,\sigma^2)\Big\Vert N(0,I)\Big)$ as this extra loss. The calculation result is:

\[\mathcal{L}_{\mu,\sigma^2}=\frac{1}{2} \sum_{i=1}^d \Big(\mu_{(i)}^2 + \sigma_{(i)}^2 - \log \sigma_{(i)}^2 - 1\Big) \tag{4}\]

Here $d$ is the dimension of the latent variable $Z$, and $\mu_{(i)}$ and $\sigma_{(i)}^2$ represent the $i$-th components of the mean vector and variance vector of the general normal distribution, respectively. Using this formula as a supplementary loss avoids the problem of the relative ratio between mean loss and variance loss. Obviously, this loss can also be understood in two parts:

\[\begin{aligned} &\mathcal{L}_{\mu,\sigma^2}=\mathcal{L}_{\mu} + \mathcal{L}_{\sigma^2} \\ &\mathcal{L}_{\mu}=\frac{1}{2} \sum_{i=1}^d \mu_{(i)}^2=\frac{1}{2}\Vert f_1(X)\Vert^2 \\ &\mathcal{L}_{\sigma^2}=\frac{1}{2} \sum_{i=1}^d\Big(\sigma_{(i)}^2 - \log \sigma_{(i)}^2 - 1\Big) \end{aligned} \tag{5}\]
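Because the network outputs $\log\sigma^2$ rather than $\sigma^2$, it is convenient to write Eq. (4) directly in terms of $\log\sigma^2$, using $\sigma^2 = e^{\log\sigma^2}$. Below is a minimal Python sketch that returns both the total and the two parts of Eq. (5). It assumes per-dimension lists `mu` and `log_var` as inputs; these names are illustrative and do not come from any particular library.

```python
import math

def kl_to_standard_normal(mu, log_var):
    """KL(N(mu, diag(sigma^2)) || N(0, I)) per Eqs. (4)/(5).

    mu and log_var are per-dimension lists, as an encoder with heads
    f1 (mean) and f2 (log-variance) would produce for one sample.
    """
    loss_mu = 0.5 * sum(m * m for m in mu)
    # sigma^2 = exp(log_var), so sigma^2 - log sigma^2 - 1 = exp(lv) - lv - 1
    loss_sigma = 0.5 * sum(math.exp(lv) - lv - 1.0 for lv in log_var)
    return loss_mu + loss_sigma, loss_mu, loss_sigma

total, loss_mu, loss_sigma = kl_to_standard_normal([1.0, -2.0], [0.0, math.log(4.0)])
print(total, loss_mu, loss_sigma)
```

Each term of $\mathcal{L}_{\sigma^2}$ has the form $e^t - t - 1 \geq 0$ with $t=\log\sigma^2$. It is zero only at $t=0$, which corresponds to unit variance.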
Derivation
Since we are considering a multivariate normal distribution with independent components, it suffices to derive the univariate case. By definition, we can write:

\begin{align} &KL\Big(N(\mu,\sigma^2)\Big\Vert N(0,1)\Big) \\ =&\int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-(x-\mu)^2/(2\sigma^2)} \left(\log \frac{e^{-(x-\mu)^2/(2\sigma^2)}/\sqrt{2\pi\sigma^2}}{e^{-x^2/2}/\sqrt{2\pi}}\right)dx \\ =&\int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-(x-\mu)^2/(2\sigma^2)} \log \left\{\frac{1}{\sqrt{\sigma^2}}\exp\left\{\frac{1}{2}\big[x^2-(x-\mu)^2/\sigma^2\big]\right\} \right\}dx \\ =&\frac{1}{2}\int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-(x-\mu)^2/(2\sigma^2)} \Big[-\log \sigma^2+x^2-(x-\mu)^2/\sigma^2 \Big] dx \end{align}

The result splits into three integrals. The first is $-\log \sigma^2$ times the integral of the probability density (which is 1), so it equals $-\log \sigma^2$. The second is the second moment of the normal distribution; readers familiar with normal distributions will know this is $\mu^2+\sigma^2$. By definition, the third is "minus the variance divided by the variance", i.e., $-1$. So the total is:

\[KL\Big(N(\mu,\sigma^2)\Big\Vert N(0,1)\Big)=\frac{1}{2}\Big(-\log \sigma^2+\mu^2+\sigma^2-1\Big)\]
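The derivation can be sanity-checked numerically. The sketch below uses a plain midpoint-rule quadrature, chosen only for illustration. It evaluates three pieces separately: $\int p\,dx$ (expected $1$), $\int p\,x^2\,dx$ (expected $\mu^2+\sigma^2$), and $\int p\,(x-\mu)^2/\sigma^2\,dx$ (expected $1$). It also evaluates the full KL integral and compares it with the closed form. The integration range of $\pm 12$ standard deviations is an assumption that makes the truncated tails negligible.

```python
import math

def log_normal_pdf(x, mu, var):
    """log N(x; mu, var), computed in log space to avoid underflow in the tails."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mu) ** 2 / (2 * var)

def integrate(f, a, b, n=20_000):
    """Composite midpoint rule on [a, b]."""
    h = (b - a) / n
    return h * sum(f(a + (i + 0.5) * h) for i in range(n))

def check_derivation(mu, var):
    sd = math.sqrt(var)
    a, b = mu - 12 * sd, mu + 12 * sd
    p = lambda x: math.exp(log_normal_pdf(x, mu, var))
    mass = integrate(p, a, b)                                         # expect 1
    second_moment = integrate(lambda x: p(x) * x * x, a, b)           # expect mu^2 + var
    scaled_var = integrate(lambda x: p(x) * (x - mu) ** 2 / var, a, b)  # expect 1
    kl_numeric = integrate(
        lambda x: p(x) * (log_normal_pdf(x, mu, var) - log_normal_pdf(x, 0.0, 1.0)), a, b)
    kl_closed = 0.5 * (-math.log(var) + mu * mu + var - 1)
    return mass, second_moment, scaled_var, kl_numeric, kl_closed

print(check_derivation(1.5, 0.25))
```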

Reparameterization Trick

Reparameterization trick

Finally, there is a trick for implementing the model, known as the reparameterization trick. It is actually very simple. We need to sample a $Z_k$ from $p(Z|X_k)$. Although we know $p(Z|X_k)$ is a normal distribution, its mean and variance are computed by the model, and we need to backpropagate through this sampling step to optimize the networks that produce them. However, the "sampling" operation itself is not differentiable, whereas the sampled result can be. We use the identity:

\[\frac{1}{\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(z-\mu)^2}{2\sigma^2}\right)dz = \frac{1}{\sqrt{2\pi}}\exp\left[-\frac{1}{2}\left(\frac{z-\mu}{\sigma}\right)^2\right]d\left(\frac{z-\mu}{\sigma}\right) \tag{6}\]

This shows that $\varepsilon=(z-\mu)/\sigma$ follows a standard normal distribution with mean 0 and variance 1. The $dz$ must be kept because only the density multiplied by $dz$ is a probability; without $dz$ we would be talking about a probability density, not a probability. We thus obtain:

Sampling a $Z$ from $\mathcal{N}(\mu,\sigma^2)$ is equivalent to sampling an $\varepsilon$ from $\mathcal{N}(0,I)$ and then letting $Z=\mu + \varepsilon \times \sigma$.

Consequently, we've turned sampling from $\mathcal{N}(\mu,\sigma^2)$ into sampling from $\mathcal{N}(0,I)$ and then obtaining the result of sampling from $\mathcal{N}(\mu,\sigma^2)$ through parameter transformation. In this way, the "sampling" operation itself does not need to participate in gradient descent; instead, the result of the sampling does, making the entire model trainable.
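Here is a minimal Python sketch of the trick. It assumes the encoder outputs $\mu$ and $\log\sigma^2$, as in the text. Because $\varepsilon$ is drawn independently of the parameters, $z=\mu+\varepsilon\sigma$ is an ordinary differentiable function of $\mu$ and $\log\sigma^2$, with $\partial z/\partial\mu = 1$ and $\partial z/\partial\log\sigma^2 = \tfrac{1}{2}\varepsilon\sigma$. The sketch checks these derivatives with finite differences and confirms that the transformed samples have mean $\mu$ and variance $\sigma^2$. In a real framework, automatic differentiation would compute these gradients; nothing here is specific to any library.

```python
import math
import random

def reparam_sample(mu, log_var, eps):
    """Map a standard-normal eps to a sample of N(mu, sigma^2), sigma = exp(log_var / 2)."""
    return mu + eps * math.exp(0.5 * log_var)

def reparam_grads(mu, log_var, eps):
    """Analytic (dz/dmu, dz/dlog_var) for a fixed eps."""
    return 1.0, 0.5 * eps * math.exp(0.5 * log_var)

def finite_diff_grads(mu, log_var, eps, h=1e-6):
    """Central differences of z with respect to mu and log_var."""
    d_mu = (reparam_sample(mu + h, log_var, eps) - reparam_sample(mu - h, log_var, eps)) / (2 * h)
    d_lv = (reparam_sample(mu, log_var + h, eps) - reparam_sample(mu, log_var - h, eps)) / (2 * h)
    return d_mu, d_lv

mu, log_var = 0.3, math.log(0.49)   # sigma = 0.7
eps = 1.2                            # one fixed draw from N(0, 1)
print(reparam_grads(mu, log_var, eps), finite_diff_grads(mu, log_var, eps))

# Distribution check: eps ~ N(0, 1) pushed through the map should give N(mu, sigma^2).
rng = random.Random(0)
zs = [reparam_sample(mu, log_var, rng.gauss(0.0, 1.0)) for _ in range(100_000)]
mean = sum(zs) / len(zs)
var = sum((z - mean) ** 2 for z in zs) / len(zs)
print(round(mean, 3), round(var, 3))
```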

As for the concrete implementation, comparing the text above with the code should make it clear right away.

Follow-up Analysis

Even after clarifying all the content above, we may still have many questions regarding VAE.

What is the Essence?

What is the essence of VAE? Although VAE is called a type of AE (autoencoder), its approach (or its interpretation of the network) is unique. VAE has two encoders: one computes the mean and the other computes the variance. This alone is surprising: the encoder is not used to "encode" but to compute a mean and a variance. Moreover, aren't mean and variance statistics? How can they be computed by neural networks?

In fact, I think VAE starts from variational and Bayesian theories that are daunting to most people but eventually lands on a very down-to-earth model. It is essentially a conventional autoencoder with "Gaussian noise" added to the output of the encoder (the encoder here corresponds to the network that computes the mean in VAE), so that the decoder becomes robust to noise. The extra KL loss (which aims for zero mean and unit variance) is then equivalent to a regularization term on the encoder, pushing the encoder's output toward zero mean.

And the role of the other encoder (the one that computes the variance)? It dynamically adjusts the intensity of the noise. Intuitively, when the decoder is not yet well trained (the reconstruction error is much larger than the KL loss), the noise is reduced somewhat (the KL loss rises), making fitting easier (the reconstruction error starts to drop). Conversely, once the decoder is well trained (the reconstruction error is smaller than the KL loss), the noise increases (the KL loss drops), making fitting harder (the reconstruction error starts to rise again). At that point, the decoder has to find ways to improve its generative capacity.

Essential structure of VAE

To put it bluntly, the reconstruction process wants no noise, while the KL loss wants Gaussian noise; the two are in opposition. Therefore, VAE, like GAN, contains an internal adversarial process, except in VAE, the two are mixed together and evolve jointly. From this perspective, VAE's idea seems slightly more sophisticated because in GAN, when the counterfeiter evolves, the discriminator stays still, and vice versa. Of course, this is just one aspect and doesn't prove VAE is better than GAN. GAN's real brilliance is that it even trains the metric directly, and this metric is often better than what we think of manually (though GAN itself has various problems, which won't be expanded on here).

From this discussion, we can also see that each $p(Z|X)$ cannot possibly equal the standard normal distribution exactly; otherwise, $p(Z|X)$ would be independent of $X$, and the reconstruction effect would be extremely poor. The final result is that $p(Z|X)$ retains some information about $X$, the reconstruction effect is acceptable, and Equation (2) approximately holds, so generative capacity is maintained.

Normal Distribution?

Regarding the distribution of $p(Z|X)$, readers might wonder: is it necessary to choose a normal distribution? Can a uniform distribution be chosen?

It's probably not feasible, mainly because of the formula for calculating KL divergence:

\[KL\Big(p(x)\Big\Vert q(x)\Big) = \int p(x) \ln \frac{p(x)}{q(x)}dx \tag{7}\]

If $p(x) \neq 0$ and $q(x) = 0$ in some region, the KL divergence becomes infinite. The density of a normal distribution is strictly positive everywhere, so this problem never arises. A uniform distribution, however, has zero density outside a bounded interval. Unless the support of $p$ happens to lie entirely inside the support of $q$, there will be intervals where $p(x) \neq 0$ and $q(x) = 0$, and the KL divergence will be infinite. Of course, we would guard against division by zero when writing code, but that still would not stop the KL loss from dominating the total loss. The model would therefore rush to reduce the KL loss, meaning the posterior $p(Z|X)$ would quickly collapse onto the prior $p(Z)$, and noise and reconstruction could no longer act against each other. This brings us back to the problem raised at the start: we could no longer tell which $z$ corresponds to which $x$.

Of course, it's not impossible to use a uniform distribution; you'd have to calculate the KL divergence between two uniform distributions, handle division-by-zero errors well, increase the weight of the reconstruction loss, etc. But that would look quite ugly.
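For two uniform distributions the KL divergence has a simple closed form, which makes the "infinite KL" problem easy to see. Assuming $p=U(a,b)$ and $q=U(c,d)$, Eq. (7) gives $\log\frac{d-c}{b-a}$ when $[a,b]\subseteq[c,d]$ and $+\infty$ otherwise. Here is a minimal sketch:

```python
import math

def kl_uniform(a, b, c, d):
    """KL(U(a, b) || U(c, d)) from Eq. (7), assuming a < b and c < d."""
    if a < c or b > d:
        # p(x) > 0 somewhere that q(x) = 0, so ln(p/q) is infinite there
        return math.inf
    # p(x) = 1/(b-a) and q(x) = 1/(d-c) on [a, b]; the integrand is constant
    return math.log((d - c) / (b - a))

print(kl_uniform(0.0, 1.0, -1.0, 2.0))   # support of p inside support of q
print(kl_uniform(-1.0, 2.0, 0.0, 1.0))   # support of p sticks out: infinite
print(kl_uniform(0.5, 1.5, 0.0, 1.0))    # a small shift is already enough
```

Contrast this with Eq. (4), which is finite for every $\mu$ and every $\sigma^2>0$.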

Where is the "Variational"?

Another interesting (but not very important) question is: VAE is called a "Variational Autoencoder." What is its connection to the Calculus of Variations? In the VAE paper and relevant interpretations, there doesn't seem to be an obvious presence of variational calculus.

Well, actually, if the reader already accepts the KL divergence, then VAE really doesn't have much to do with variations anymore. Theoretically, for the KL divergence (7), we need to prove:

Given a fixed probability distribution $p(x)$ (or $q(x)$), for any probability distribution $q(x)$ (or $p(x)$), we have $KL\Big(p(x)\Big\Vert q(x)\Big) \geq 0$, and it equals zero only when $p(x)=q(x)$.
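The property stated above is Gibbs' inequality. A numerical experiment does not prove it, but a quick spot check on random discrete distributions can make it tangible. The sketch below uses the discrete analogue of Eq. (7) and is illustrative only. `random_distribution` is a hypothetical helper defined here.

```python
import math
import random

def kl_discrete(p, q):
    """KL(p || q) = sum_i p_i * ln(p_i / q_i), the discrete analogue of Eq. (7)."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi == 0:
            continue              # 0 * ln(0 / q) is taken as 0
        if qi == 0:
            return math.inf       # p has mass where q has none
        total += pi * math.log(pi / qi)
    return total

def random_distribution(n, rng):
    """Hypothetical helper: a random probability vector of length n."""
    w = [rng.random() + 1e-9 for _ in range(n)]
    s = sum(w)
    return [x / s for x in w]

rng = random.Random(0)
values = [kl_discrete(random_distribution(5, rng), random_distribution(5, rng))
          for _ in range(1000)]
print(min(values) >= 0.0)                              # checks nonnegativity across all trials
print(kl_discrete([0.2, 0.3, 0.5], [0.2, 0.3, 0.5]))   # identical distributions
```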

Since $KL\Big(p(x)\Big\Vert q(x)\Big)$ is actually a functional, finding its extreme value requires using the calculus of variations. Of course, the variational method here is just a parallel extension of ordinary calculus and hasn't involved truly complex variational calculus yet. Since the variational lower bound of VAE is derived directly based on KL divergence, once you accept KL divergence, there is no more "variational" business.

In short, the name "variational" in VAE is because its derivation process uses KL divergence and its properties.

Conditional VAE

Finally, since the VAE described so far is trained without supervision, a natural question arises: if we have labeled data, can label information be added to help generate samples? The usual goal is to control some variable so as to generate images of a chosen category. This is certainly possible, and the result is called a Conditional VAE, or CVAE (analogous to CGAN for GANs).

However, CVAE is not a specific model but a class of models. In short, there are many ways to integrate label information into VAE for different purposes. Based on the previous discussion, a very simple version of CVAE is presented here.

A simple CVAE structure

In the previous discussion, we hoped that after $X$ is encoded, the distribution of $Z$ would have zero mean and unit variance. This "hope" was achieved by adding a KL loss. If we now have category information $Y$, we can hope that samples of the same class all have an exclusive mean $\mu^Y$ (with the variance being the same, still unit variance), and let the model train this $\mu^Y$ itself. In this case, there are as many normal distributions as there are classes. During generation, we can control the category of the generated image by controlling the mean. In fact, this might be the solution with the least code added on top of VAE to implement CVAE, as this "new hope" only requires modifying the KL loss:

\[\mathcal{L}_{\mu,\sigma^2}=\frac{1}{2} \sum_{i=1}^d\Big[\big(\mu_{(i)}-\mu^Y_{(i)}\big)^2 + \sigma_{(i)}^2 - \log \sigma_{(i)}^2 - 1\Big] \tag{8}\]
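Since Eq. (8) only shifts the target mean from $0$ to $\mu^Y$, the change to the KL computation is a single subtraction per dimension. The sketch below assumes a hypothetical table `class_means` holding one $\mu^Y$ per class; in a real model these would be trainable parameters, for example an embedding indexed by the label. It also shows how generation would then sample $Z$ from $\mathcal{N}(\mu^Y, I)$ before decoding. The decoder itself is omitted.

```python
import math
import random

def cvae_kl(mu, log_var, mu_y):
    """Eq. (8): 0.5 * sum[(mu - mu_y)^2 + sigma^2 - log sigma^2 - 1] over latent dims."""
    return 0.5 * sum((m - my) ** 2 + math.exp(lv) - lv - 1.0
                     for m, lv, my in zip(mu, log_var, mu_y))

# Hypothetical learned class means mu^Y for a 2-D latent space (values made up).
class_means = {7: [1.0, -0.5], 9: [-0.8, 0.6]}

def sample_latent_for_class(y, rng):
    """Draw Z ~ N(mu^Y, I); a trained decoder g would then map Z to an image of class y."""
    return [m + rng.gauss(0.0, 1.0) for m in class_means[y]]

mu, log_var = [-0.7, 0.5], [-0.2, 0.1]   # encoder outputs for some sample labeled 9
print(cvae_kl(mu, log_var, class_means[9]))
print(sample_latent_for_class(9, random.Random(0)))
```

With $\mu^Y = 0$ this reduces exactly to Eq. (4), so the same code covers the unconditional case.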

The figure below shows that this simple CVAE works to some extent, but because the encoder and decoder are quite simple (pure MLPs), the control over generation is far from ideal. Readers interested in more complete CVAEs can explore them on their own. Recently there has also been work combining CVAE with GAN, called CVAE-GAN. Model tricks are ever-changing.

CVAE generating digits
Using this CVAE to control the generation of the digit 9: various styles of 9 are generated, and they gradually transition toward 7, so a preliminary observation suggests this CVAE is effective.

Code

I copied a version of the official Keras VAE code, made slight adjustments, added Chinese comments based on the content of this article, and implemented the simple CVAE mentioned at the end for readers' reference.

Code: https://github.com/bojone/vae

Final Stop

Stumbling along, we've reached the end of the article. I don't know if I've explained it clearly; I hope everyone provides more feedback.

Generally speaking, the idea of VAE is very beautiful. This is not necessarily because it provides a superior generative model (in practice, the images it generates are not that good and tend to be blurry). Rather, it provides a fantastic case study of combining probabilistic graphical models with deep learning, one with many aspects worth contemplating.