Variational Autoencoders (Part 2): From a Bayesian Perspective

By 苏剑林 | March 28, 2018

Origins

A few days ago, I wrote the blog post "Variational Autoencoders (Part 1): So That's How It Is", which approached Variational Autoencoders (VAE) from a relatively intuitive angle. From the viewpoint of that article, a VAE is not much different from a standard autoencoder, apart from the added noise and the constraints placed on that noise. However, my original motivation for studying VAEs was to see exactly how the probabilistic graphical models of the Bayesian school can be combined with Deep Learning, and a merely intuitive understanding is clearly not enough. So I kept thinking about VAEs for a few more days, trying to explain them in more general, probabilistic language. In fact, this line of thinking can answer questions that the intuitive picture cannot, such as whether MSE or cross-entropy is the better reconstruction loss, and how to balance the reconstruction loss against the KL loss.

It is recommended to read "Variational Autoencoders (Part 1): So That's How It Is" before reading this article. This post will try to avoid overlapping with the previous content as much as possible.

Preparation

Before getting into the description of VAE itself, I think it is necessary to go over a few concepts.

Numerical Computation vs. Sampling Computation

For readers who are not very familiar with probability and statistics, two easily confused concepts are numerical computation and sampling computation. Some readers have expressed similar confusion in "The Three Flavors of Capsule: Matrix Capsule and EM Routing". For example, if the probability density function $p(x)$ is known, then the expectation of $x$ is defined as:

\begin{equation} \mathbb{E}[x] = \int x p(x)dx\tag{1} \end{equation}

If you want to perform numerical computation (numerical integration) on it, you can select several representative points $x_0 < x_1 < x_2 < \dots < x_n$, and then obtain:

\begin{equation} \mathbb{E}[x] \approx \sum_{i=1}^n x_i p(x_i) \left(x_i - x_{i-1}\right)\tag{2} \end{equation}

I won't discuss what "representative" means here, nor the methods for improving numerical precision. This is written to contrast it with sampling computation. If we sample several points $x_1, x_2, \dots, x_n$ from $p(x)$, then we have:

\begin{equation} \mathbb{E}[x] \approx \frac{1}{n}\sum_{i=1}^n x_i,\quad x_i \sim p(x)\tag{3} \end{equation}

We can compare $(2)$ with $(3)$. Their main difference is that $(2)$ includes the calculation of probability, while $(3)$ only involves the calculation of $x$. This is because in $(3)$, $x_i$ is sampled according to the probability from $p(x)$; regions with higher probability result in $x_i$ appearing more frequently. Therefore, it can be said that the sampling results already contain $p(x)$ within them, so there is no need to multiply by $p(x_i)$ anymore.

More generally, we can write:

\begin{equation} \mathbb{E}_{x\sim p(x)}[f(x)] = \int f(x)p(x)dx \approx \frac{1}{n}\sum_{i=1}^n f(x_i),\quad x_i\sim p(x)\tag{4} \end{equation}

This is the foundation of Monte Carlo simulation.
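To make the contrast concrete, here is a minimal NumPy sketch of the two ways of computing an expectation; the choice of $f(x)=x^2$ and a standard normal for $p(x)$ (so the true value is $1$) is purely illustrative.

```python
import numpy as np

def f(x):
    return x ** 2

# Numerical integration, as in formula (2): evaluate f and the density p on a
# grid of "representative" points and weight by the grid spacing.
xs = np.linspace(-6.0, 6.0, 2001)
p = np.exp(-xs ** 2 / 2) / np.sqrt(2 * np.pi)        # standard normal density
numerical = np.sum(f(xs[1:]) * p[1:] * np.diff(xs))

# Sampling (Monte Carlo), as in formulas (3)/(4): just average f over samples
# drawn from p(x); the density never appears explicitly, because high-density
# regions are visited more often.
samples = np.random.randn(100_000)
monte_carlo = np.mean(f(samples))

print(numerical, monte_carlo)    # both should be close to 1
```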

KL Divergence and Variational Inference

We usually use KL divergence to measure the difference between two probability distributions $p(x)$ and $q(x)$, defined as:

\begin{equation} KL\Big(p(x)\Big\Vert q(x)\Big) = \int p(x)\ln \frac{p(x)}{q(x)} dx=\mathbb{E}_{x\sim p(x)}\left[\ln \frac{p(x)}{q(x)}\right]\tag{5} \end{equation}

The main property of KL divergence is non-negativity, with equality to zero if and only if $p(x)=q(x)$: whether we fix $p(x)$ and minimize over $q(x)$, or fix $q(x)$ and minimize over $p(x)$, minimizing the KL divergence pushes the two distributions to be as equal as possible. The rigorous proof of this requires variational methods, and in fact the "V" in VAE (Variational) is there because the derivation of VAE involves KL divergence (and hence variational calculus).
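One standard way to see the non-negativity, without invoking variational calculus, is Jensen's inequality applied to the convex function $-\ln$:

\begin{equation} KL\Big(p(x)\Big\Vert q(x)\Big) = \mathbb{E}_{x\sim p(x)}\left[-\ln \frac{q(x)}{p(x)}\right] \geq -\ln \mathbb{E}_{x\sim p(x)}\left[\frac{q(x)}{p(x)}\right] = -\ln \int q(x)dx = 0 \end{equation}

with equality exactly when $q(x)/p(x)$ is constant, i.e., when $p(x)=q(x)$.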

Of course, KL divergence has a rather obvious problem: when $q(x)$ is equal to 0 in a certain region while $p(x)$ is not zero in that region, the KL divergence goes to infinity. This is an inherent problem of KL divergence, and we can only find ways to circumvent it. For example, for the prior distribution of latent variables, we use a Gaussian distribution instead of a uniform distribution for this reason, as mentioned in the previous article "Variational Autoencoders (Part 1): So That's How It Is".

As a side note, is KL divergence the only way to measure the difference between two probability distributions? Of course not. We can look at the Statistical Distance entry on Wikipedia, which introduces many distribution distances. For example, there is a very beautiful measure called the Bhattacharyya distance, defined as:

\begin{equation} D_B\Big(p(x), q(x)\Big)=-\ln\int \sqrt{p(x)q(x)} dx\tag{6} \end{equation}

This distance is not only symmetric but also avoids the infinity problem of KL divergence. However, we still choose KL divergence because we not only want theoretical beauty but also practical feasibility. KL divergence can be written in the form of an expectation, which allows us to perform sampling calculations. Conversely, it is not as easy for the Bhattacharyya distance. If the reader tries to replace KL divergence with Bhattacharyya distance in the following calculation process, they will find it nearly impossible to proceed.
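To illustrate the point that an expectation form permits sampling, here is a small NumPy sketch that estimates the KL divergence purely by sampling from $p(x)$ and checks the result against the closed form; the two univariate Gaussians are my own illustrative choice.

```python
import numpy as np

mu1, sigma1 = 0.0, 1.0        # p(x) = N(mu1, sigma1^2)
mu2, sigma2 = 1.0, 2.0        # q(x) = N(mu2, sigma2^2)

def log_normal(x, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

# Formula (5) as an expectation: E_{x~p}[ln p(x) - ln q(x)], estimated by sampling.
x = mu1 + sigma1 * np.random.randn(200_000)
kl_sampled = np.mean(log_normal(x, mu1, sigma1) - log_normal(x, mu2, sigma2))

# Closed form for two univariate Gaussians, used here only as a sanity check.
kl_exact = np.log(sigma2 / sigma1) + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2) - 0.5

print(kl_sampled, kl_exact)   # the two numbers should agree to about 2 decimal places
```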

Table of Notation

Explaining VAE inevitably involves a large number of formulas and symbols. The meanings of some expressions are listed in advance as follows:

$$ \begin{array}{c|c} \hline x_k, z_k & \text{Represents the } k\text{-th sample of random variables } x, z\\ \hline x_{(k)}, z_{(k)} & \text{Represents the } k\text{-th component of multivariate variables } x, z\\ \hline \mathbb{E}_{x\sim p(x)}[f(x)] & \text{Represents the expectation of } f(x)\text{, where the distribution of } x \text{ is } p(x)\\ \hline KL\Big(p(x)\Big\Vert q(x)\Big)& \text{KL divergence between two distributions}\\ \hline \Vert x\Vert^2& \text{The } l^2 \text{ norm of vector } x\text{, i.e., the square of what we usually call the magnitude}\\ \hline \mathcal{L}& \text{The symbol for the loss function in this article}\\ \hline D, d & D \text{ is the dimension of the input } x, d \text{ is the dimension of the latent variable } z\\ \hline \end{array} $$

Framework

Here, a concise and quick theoretical framework for VAE is provided by directly approximating the joint distribution.

Facing the Joint Distribution Directly

The starting point remains unchanged, and I will restate it here. First, we have a set of data samples $\{x_1, \dots, x_n\}$, which is described as a whole by $x$. We hope to describe the distribution $\tilde{p}(x)$ of $x$ with the help of latent variables $z$:

\begin{equation} q(x)=\int q(x|z)q(z)dz,\quad q(x,z) = q(x|z)q(z)\tag{7} \end{equation}

Here $q(z)$ is the prior distribution (standard normal distribution), and the goal is to make $q(x)$ approach $\tilde{p}(x)$. In this way (theoretically), we have both described $\tilde{p}(x)$ and obtained the generative model $q(x|z)$, killing two birds with one stone.

The next step is to use the KL divergence to do the approximation. What I have never quite understood, though, is why VAE tutorials, starting from the original paper "Auto-Encoding Variational Bayes", focus on describing the posterior distribution $p(z|x)$. Perhaps they are influenced by the EM algorithm: EM cannot be applied directly here precisely because the posterior $p(z|x)$ is intractable, so the authors of VAE concentrated on approximating $p(z|x)$.

However, in fact, directly approximating $p(x,z)$ is the most straightforward way. Specifically, define $p(x,z)=\tilde{p}(x)p(z|x)$. Suppose we use a joint probability distribution $q(x,z)$ to approximate $p(x,z)$. Let's look at the distance between them using KL divergence:

\begin{equation} KL\Big(p(x,z)\Big\Vert q(x,z)\Big) = \iint p(x,z)\ln \frac{p(x,z)}{q(x,z)} dzdx\tag{8} \end{equation}

This KL divergence is our ultimate objective: we want the two distributions to be as close as possible, so the KL divergence should be as small as possible. Of course, since $p(x,z)$ now also contains parameters, it is not only $q(x,z)$ that approximates $p(x,z)$; $p(x,z)$ also actively moves toward $q(x,z)$, and the two approach each other.

Thus we have:

\begin{align} KL\Big(p(x,z)\Big\Vert q(x,z)\Big) =& \int \tilde{p}(x) \left[\int p(z|x)\ln \frac{\tilde{p}(x)p(z|x)}{q(x,z)} dz\right]dx\\ =& \mathbb{E}_{x\sim \tilde{p}(x)} \left[\int p(z|x)\ln \frac{\tilde{p}(x)p(z|x)}{q(x,z)} dz\right] \tag{9} \end{align}

In this way, using formula $(4)$, we can perform calculations by substituting various $x_i$. This expression can be further simplified because $\ln \frac{\tilde{p}(x)p(z|x)}{q(x,z)}=\ln \tilde{p}(x) + \ln \frac{p(z|x)}{q(x,z)}$, and:

\begin{align} \mathbb{E}_{x\sim \tilde{p}(x)} \left[\int p(z|x)\ln \tilde{p}(x)dz\right] =& \mathbb{E}_{x\sim \tilde{p}(x)} \left[\ln \tilde{p}(x)\int p(z|x)dz\right]\\ =&\mathbb{E}_{x\sim \tilde{p}(x)} \big[\ln \tilde{p}(x)\big] \tag{10} \end{align}

Note that $\tilde{p}(x)$ here is the distribution of $x$ determined by the samples $x_1, x_2, \dots, x_n$ (the data distribution). Although we may not be able to write down its exact form, it is fixed and it exists, so this term is just a constant. We can therefore write:

\begin{equation} \mathcal{L}=KL\Big(p(x,z)\Big\Vert q(x,z)\Big) - \text{Constant}= \mathbb{E}_{x\sim \tilde{p}(x)} \left[\int p(z|x)\ln \frac{p(z|x)}{q(x,z)} dz\right]\tag{11} \end{equation}

Currently, minimizing $KL\Big(p(x,z)\Big\Vert q(x,z)\Big)$ is equivalent to minimizing $\mathcal{L}$. Note that the subtracted constant is $\mathbb{E}_{x\sim \tilde{p}(x)} \big[\ln \tilde{p}(x)\big]$, so $\mathcal{L}$ has a lower bound of $-\mathbb{E}_{x\sim \tilde{p}(x)} \big[\ln \tilde{p}(x)\big]$~ Note that $\tilde{p}(x)$ is not necessarily a probability; in continuous cases, $\tilde{p}(x)$ is a probability density, which can be greater or less than 1. Therefore, $-\mathbb{E}_{x\sim \tilde{p}(x)} \big[\ln \tilde{p}(x)\big]$ is not necessarily non-negative, meaning the loss can be negative.

Your VAE has arrived

At this point, we recall our original intention—to obtain a generative model, so we write $q(x,z)$ as $q(x|z)q(z)$, which leads to:

\begin{align} \mathcal{L} =& \mathbb{E}_{x\sim \tilde{p}(x)} \left[\int p(z|x)\ln \frac{p(z|x)}{q(x|z)q(z)} dz\right]\\ =&\mathbb{E}_{x\sim \tilde{p}(x)} \left[-\int p(z|x)\ln q(x|z)dz+\int p(z|x)\ln \frac{p(z|x)}{q(z)}dz\right]\tag{12} \end{align}

To make it even clearer:

\begin{align} \mathcal{L} = & \mathbb{E}_{x\sim \tilde{p}(x)} \left[\mathbb{E}_{z\sim p(z|x)}\big[-\ln q(x|z)\big]+\mathbb{E}_{z\sim p(z|x)}\Big[\ln \frac{p(z|x)}{q(z)}\Big]\right]\\ = & \mathbb{E}_{x\sim \tilde{p}(x)} \Bigg[\mathbb{E}_{z\sim p(z|x)}\big[-\ln q(x|z)\big]+KL\Big(p(z|x)\Big\Vert q(z)\Big)\Bigg] \tag{13} \end{align}

Look, isn't what's inside the brackets exactly the loss function of VAE? We've just changed the notation. We just need to find appropriate $q(x|z)$ and $q(z)$ to minimize $\mathcal{L}$.

Looking back at the whole process, we hardly made any "difficult to think of" formal transformations, yet VAE emerged. Therefore, there is no need to analyze the posterior distribution; by facing the joint distribution directly, we can reach the end objective more quickly.

Do Not Split Them Up~

Given the form of formula $(13)$, one might be tempted to split $\mathcal{L}$ into two separate objectives, the expectation (over $x$) of $\mathbb{E}_{z\sim p(z|x)}\big[-\ln q(x|z)\big]$ and the expectation of $KL\Big(p(z|x)\Big\Vert q(z)\Big)$, and to think that the problem reduces to minimizing these two losses independently.

However, this view is inappropriate. $KL\Big(p(z|x)\Big\Vert q(z)\Big)=0$ means $p(z|x)=q(z)$ for every $x$, i.e., $z$ carries no information about $x$ and has no discriminative power, so $-\ln q(x|z)$ cannot be small (the reconstruction is poor). Conversely, if $-\ln q(x|z)$ is small, i.e., $q(x|z)$ is large and the reconstruction is accurate, then $p(z|x)$ cannot be too random, which means $KL\Big(p(z|x)\Big\Vert q(z)\Big)$ will not be small. The two parts of the loss are therefore antagonistic, and $\mathcal{L}$ cannot be viewed piece by piece; it must be viewed as a whole. The smaller the overall $\mathcal{L}$, the closer the model is to convergence, and we cannot judge progress by watching one part of the loss alone.

In fact, this is exactly what GAN models dream of: a single overall indicator of how the training of the generative model is progressing. VAE possesses this capability naturally, whereas for GANs such an indicator was not found until WGAN~

Experiment

With the above, the overall theoretical construction of VAE is complete. But to actually put it into practice, some work remains to be done. The original paper "Auto-Encoding Variational Bayes" in fact develops this part quite fully as well, but unfortunately many online VAE tutorials stop at formula $(13)$ without going into further detail.

Posterior Distribution Approximation

Currently $q(z), q(x|z), p(z|x)$ are all unknown, and even their forms haven't been determined. To conduct experiments, every term in formula $(13)$ must be explicitly written out.

First, for easier sampling, we assume $z\sim N(0,I)$, i.e., a standard multivariate normal distribution. This solves $q(z)$. What about $q(x|z), p(z|x)$?

Let's just use neural networks to fit them.

Note: Originally, if $q(x|z)$ and $q(z)$ were known, the most reasonable estimate for $p(z|x)$ would be:

\begin{equation} \hat{p}(z|x) = q(z|x) = \frac{q(x|z)q(z)}{q(x)} = \frac{q(x|z)q(z)}{\int q(x|z)q(z)dz}\tag{14} \end{equation}

This is actually the posterior-estimation step of the EM algorithm; for details, refer to "From Maximum Likelihood to EM Algorithm: A Consistent Way of Understanding". In practice, however, the integral in the denominator is almost impossible to compute, so this route is not feasible. We therefore simply use a generic neural network to approximate $p(z|x)$. This may not reach the optimum, but it is at least a usable approximation.

Specifically, we assume that $p(z|x)$ is also a normal distribution (with independent components), whose mean and variance are determined by $x$. This "determination" is a neural network:

\begin{equation} p(z|x)=\frac{1}{\prod\limits_{k=1}^d \sqrt{2\pi \sigma_{(k)}^2(x)}}\exp\left(-\frac{1}{2}\left\Vert\frac{z-\mu(x)}{\sigma(x)}\right\Vert^2\right)\tag{15} \end{equation}

Here $\mu(x), \sigma^2(x)$ are neural networks with input $x$ and outputs as mean and variance, respectively. This $\mu(x)$ plays the role of the encoder. Since we have assumed a Gaussian distribution, the KL divergence term in formula $(13)$ can be calculated first:

\begin{equation} KL\Big(p(z|x)\Big\Vert q(z)\Big)=\frac{1}{2} \sum_{k=1}^d \Big(\mu_{(k)}^2(x) + \sigma_{(k)}^2(x) - \ln \sigma_{(k)}^2(x) - 1\Big)\tag{16} \end{equation}

This is what we call the KL loss, which was already given in the previous article.
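As a concrete reference, here is a minimal sketch of formula $(16)$ as code. A common convention (assumed here, not dictated by the formula) is to have the encoder output $\ln \sigma^2(x)$ rather than $\sigma^2(x)$ itself, for numerical stability:

```python
import numpy as np

def kl_loss(mu, log_var):
    """Formula (16): KL(N(mu, diag(exp(log_var))) || N(0, I)), per sample.

    mu, log_var: arrays of shape (batch, d) produced by the encoder network.
    """
    return 0.5 * np.sum(mu ** 2 + np.exp(log_var) - log_var - 1.0, axis=-1)
```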

Generative Model Approximation

Now only the generative model part $q(x|z)$ remains. What distribution should we choose? The paper "Auto-Encoding Variational Bayes" provides two candidate schemes: Bernoulli distribution or Normal distribution.

What? Normal distribution again? Isn't that too simplified?

However, there's no choice because we need to construct a distribution, not just any function. Since it's a distribution, it must satisfy the normalization requirement. For it to satisfy normalization and be easy to calculate, we really don't have many options.

Bernoulli Distribution Model

First, let's look at the Bernoulli distribution, which, as everyone knows, is simply a two-point distribution on $\{0, 1\}$:

\begin{equation} p(\xi)=\left\{\begin{aligned}&\rho,\, \xi = 1;\\ &1-\rho,\, \xi = 0\end{aligned}\right.\tag{17} \end{equation}

So the Bernoulli distribution only applies to cases where $x$ is a multivariate binary vector, such as when $x$ is a binary image (MNIST can be seen as such a case). In this case, we use a neural network $\rho(z)$ to calculate the parameter $\rho$, thus obtaining:

\begin{equation} q(x|z)=\prod_{k=1}^D \Big(\rho_{(k)}(z)\Big)^{x_{(k)}} \Big(1 - \rho_{(k)}(z)\Big)^{1 - x_{(k)}}\tag{18} \end{equation}

At this point, we can calculate:

\begin{equation} -\ln q(x|z) = \sum_{k=1}^D \Big[- x_{(k)} \ln \rho_{(k)}(z) - (1-x_{(k)}) \ln \Big(1 -\rho_{(k)}(z)\Big)\Big]\tag{19} \end{equation}

This implies that $\rho(z)$ needs to be squashed between 0 and 1 (using sigmoid activation, for example), and then binary cross-entropy is used as the loss function. Here $\rho(z)$ plays the role of the decoder.
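In code, formula $(19)$ is just the (summed) binary cross-entropy; a minimal NumPy sketch, with a small clipping constant added purely for numerical safety:

```python
import numpy as np

def bernoulli_nll(x, rho, eps=1e-7):
    """Formula (19): -ln q(x|z) for the Bernoulli model, per sample.

    x:   binary targets, shape (batch, D)
    rho: sigmoid-activated decoder outputs rho(z), shape (batch, D)
    """
    rho = np.clip(rho, eps, 1.0 - eps)
    return np.sum(-x * np.log(rho) - (1.0 - x) * np.log(1.0 - rho), axis=-1)
```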

Normal Distribution Model

Next is the normal distribution, which is just like $p(z|x)$, but with $x$ and $z$ swapped:

\begin{equation} q(x|z)=\frac{1}{\prod\limits_{k=1}^D \sqrt{2\pi \tilde{\sigma}_{(k)}^2(z)}}\exp\left(-\frac{1}{2}\left\Vert\frac{x-\tilde{\mu}(z)}{\tilde{\sigma}(z)}\right\Vert^2\right)\tag{20} \end{equation}

Here $\tilde{\mu}(z), \tilde{\sigma}^2(z)$ are neural networks with input $z$ and outputs as mean and variance, respectively. $\tilde{\mu}(z)$ plays the role of the decoder. Thus:

\begin{equation} -\ln q(x|z) = \frac{1}{2}\left\Vert\frac{x-\tilde{\mu}(z)}{\tilde{\sigma}(z)}\right\Vert^2 + \frac{D}{2}\ln 2\pi + \frac{1}{2}\sum_{k=1}^D \ln \tilde{\sigma}_{(k)}^2(z)\tag{21} \end{equation}

Many times we fix the variance to a constant $\tilde{\sigma}^2$, in which case:

\begin{equation} -\ln q(x|z) \sim \frac{1}{2\tilde{\sigma}^2}\Big\Vert x-\tilde{\mu}(z)\Big\Vert^2\tag{22} \end{equation}

This leads to the MSE loss function.
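For completeness, here are the Gaussian counterparts of the reconstruction loss, again as a minimal NumPy sketch (the fixed-variance version drops the additive constants of formula $(21)$):

```python
import numpy as np

def gaussian_nll(x, mu, log_var):
    """Formula (21): -ln q(x|z) with a learned per-component variance."""
    D = x.shape[-1]
    return 0.5 * np.sum((x - mu) ** 2 / np.exp(log_var) + log_var, axis=-1) \
           + 0.5 * D * np.log(2 * np.pi)

def fixed_var_nll(x, mu, sigma2=1.0):
    """Formula (22): with the variance fixed to sigma2, this is just a scaled MSE."""
    return np.sum((x - mu) ** 2, axis=-1) / (2.0 * sigma2)
```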

So it's clear now: for binary data, we can use sigmoid activation in the decoder and then use cross-entropy as the loss function, which corresponds to $q(x|z)$ being a Bernoulli distribution; for general data, we use MSE as the loss function, which corresponds to $q(x|z)$ being a Normal distribution with fixed variance.

Sampling Calculation Tricks

The previous section accomplished quite a lot, all essentially in order to write formula $(13)$ out explicitly. Once we assume that $p(z|x)$ and $q(z)$ are both normal distributions, the KL divergence term of formula $(13)$ can be computed in closed form, as in formula $(16)$. Once we assume $q(x|z)$ is a Bernoulli or a Gaussian distribution, $-\ln q(x|z)$ can also be written out. So what is still missing?

Sampling!

The role of $p(z|x)$ is split into two parts: one is for calculating $KL\Big(p(z|x)\Big\Vert q(z)\Big)$, and the other is for calculating $\mathbb{E}_{z\sim p(z|x)}\big[-\ln q(x|z)\big]$. And $\mathbb{E}_{z\sim p(z|x)}\big[-\ln q(x|z)\big]$ means:

\begin{equation} -\frac{1}{n}\sum_{i=1}^n \ln q(x|z_i),\quad z_i \sim p(z|x)\tag{23} \end{equation}

We have assumed $p(z|x)$ is a normal distribution with mean and variance calculated by the model. In this way, sampling can be completed using the "reparameterization trick".
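As a reminder of what the trick amounts to, here is the arithmetic in NumPy (in a framework with automatic differentiation, writing the sample this way is what lets gradients flow back into $\mu(x)$ and $\sigma(x)$):

```python
import numpy as np

def sample_z(mu, log_var):
    """Reparameterization: z = mu + sigma * eps with eps ~ N(0, I)."""
    eps = np.random.randn(*mu.shape)
    return mu + np.exp(0.5 * log_var) * eps
```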

But how many samples are appropriate? VAE is very straightforward: one! So formula $(13)$ becomes very simple:

\begin{equation} \mathcal{L} = \mathbb{E}_{x\sim \tilde{p}(x)} \Bigg[-\ln q(x|z) + KL\Big(p(z|x)\Big\Vert q(z)\Big)\Bigg],\quad z\sim p(z|x)\tag{24} \end{equation}

Each term in this formula can be found in formulas $(16), (19), (21), (22)$. Note that for each $x$ in a batch, a $z$ "exclusive" to that $x$ must be sampled from $p(z|x)$ before computing $-\ln q(x|z)$. Because VAE draws only a single sample of $z$ here, it ends up looking very similar to an ordinary autoencoder.
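To show how the pieces fit together, here is a compact sketch of formula $(24)$ for the Bernoulli case, written in PyTorch (any framework with automatic differentiation would do); the layer sizes are arbitrary illustrative choices, and the decoder outputs logits whose sigmoid plays the role of $\rho(z)$:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, D=784, d=2, hidden=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(D, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, d)          # mu(x), the encoder mean
        self.log_var = nn.Linear(hidden, d)     # ln sigma^2(x)
        self.dec = nn.Sequential(nn.Linear(d, hidden), nn.ReLU(), nn.Linear(hidden, D))

    def loss(self, x):
        h = self.enc(x)
        mu, log_var = self.mu(h), self.log_var(h)
        # One reparameterized sample of z per x, as in formula (24).
        z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
        logits = self.dec(z)                    # sigmoid(logits) plays the role of rho(z)
        # Reconstruction term, formula (19), summed over the D components.
        recon = F.binary_cross_entropy_with_logits(logits, x, reduction='none').sum(-1)
        # KL term, formula (16), summed over the d latent components.
        kl = 0.5 * (mu ** 2 + log_var.exp() - log_var - 1).sum(-1)
        return (recon + kl).mean()              # average over the batch
```

Minimizing this quantity with any optimizer trains the encoder and decoder jointly.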

The final question is: is a single sample really enough? In practice we train for many epochs, and the latent variable is re-sampled every time a sample is visited, so as long as the number of epochs is large enough, the sampling is effectively adequate. I have also experimented with drawing several points per sample, and the generated results did not change noticeably.

Tribute

This article has walked through the whole VAE pipeline from the perspective of Bayesian theory. When thinking about it from this angle, we need to hold onto two keywords, "distribution" and "sampling": write down the distributions explicitly, and simplify the computation through sampling.

Simply put, because a complex distribution is hard to describe directly, we introduce latent variables and turn it into a superposition of conditional distributions. We can then simplify both the distribution of the latent variables and the conditional distribution appropriately (for example, by assuming they are both normal distributions), while the parameters of those distributions are computed by deep learning models. This is where the "Deep Probabilistic Graphical Model" emerges.

Let us pay tribute to the great Bayes, as well as the many masters who study probabilistic graphical models. They are the true heroes.