Variational Autoencoders (Part 2): From a Bayesian Perspective

By 苏剑林 | April 03, 2018

I feel like my recent articles have been becoming quite long, and they seem to come in clusters! After writing three consecutive posts introducing Capsules, it's now VAE's turn. This is the third exploration of VAE, and who knows, there might even be a fourth. Regardless, quantity isn't what matters; what's important is thinking through the problems clearly. Especially for a novel modeling approach like VAE, it's worth scrutinizing the details. The question we want to address this time is: Why does VAE actually work?

I imagine readers of VAE go through several stages. Stage one: you have just read an introduction and are lost in a fog; it seems like an autoencoder, yet somehow isn't, and only after several readings and a look at the source code do you get the general idea. Stage two: building on that, you read into the principles in depth, latent variable models, KL divergence, variational inference, and so on, and after all the twists and turns you feel you finally understand it.

At this point, readers might enter the third stage. Here, various questions arise, particularly regarding feasibility: "Why is it that after all this maneuvering, the resulting model actually works? I have many ideas too—why don't mine work?"

Review of Previous Content


At the risk of repeating ourselves, let's review the principles of VAE once more. VAE aims to describe the distribution of the data $X$ via a latent variable decomposition:

$$p(x)=\int p(x|z)p(z)dz,\quad p(x,z) = p(x|z)p(z)\tag{1}$$

Then, we fit $p(x|z)$ with the model $q(x|z)$ and $p(z)$ with the model $q(z)$. To give the model generative capabilities, $q(z)$ is defined as a standard normal distribution. Theoretically, we can solve the model using maximum likelihood of the marginal probability:

$$\begin{aligned}q(x|z)=&\mathop{\text{argmax}}_{q(x|z)} \int \tilde{p}(x)\ln\left(\int q(x|z)q(z)dz\right)dx\\ =&\mathop{\text{argmax}}_{q(x|z)} \mathbb{E}_{x\sim\tilde{p}(x)}\left[\ln\left(\int q(x|z)q(z)dz\right)\right] \end{aligned}\tag{2}$$

However, because the integral inside the parentheses has no closed form, we instead introduce the KL divergence between the joint distributions $\tilde{p}(x)p(z|x)$ and $q(x|z)q(z)$ to measure the gap. The final objective function becomes:

$$\begin{aligned}\mathcal{L} =&\mathbb{E}_{x\sim \tilde{p}(x)} \left[-\int p(z|x)\ln q(x|z)dz+\int p(z|x)\ln \frac{p(z|x)}{q(z)}dz\right]\\ = & \mathbb{E}_{x\sim \tilde{p}(x)} \left[\mathbb{E}_{z\sim p(z|x)}\big[-\ln q(x|z)\big]+\mathbb{E}_{z\sim p(z|x)}\Big[\ln \frac{p(z|x)}{q(z)}\Big]\right]\end{aligned}\tag{3}$$

By minimizing $\mathcal{L}$, we obtain both $p(z|x)$ and $q(x|z)$. The previous post, "Variational Autoencoders (Part 2): Starting from the Bayesian Perspective," also showed that $\mathcal{L}$ has the lower bound $-\mathbb{E}_{x\sim \tilde{p}(x)}\big[\ln \tilde{p}(x)\big]$, so how close $\mathcal{L}$ gets to $-\mathbb{E}_{x\sim \tilde{p}(x)}\big[\ln \tilde{p}(x)\big]$ can serve as a measure of the relative quality of the generator.
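Concretely, the identity behind this lower bound (established in the previous post) is that $\mathcal{L}$ differs from a non-negative KL divergence between the two joint distributions only by a constant:

$$\mathcal{L} = KL\Big(\tilde{p}(x)p(z|x)\,\Big\Vert\, q(x|z)q(z)\Big) - \mathbb{E}_{x\sim \tilde{p}(x)}\big[\ln \tilde{p}(x)\big]\geq -\mathbb{E}_{x\sim \tilde{p}(x)}\big[\ln \tilde{p}(x)\big]$$

since a KL divergence is never negative.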

The Puzzle of Sampling


In this section, we attempt to ask detailed questions about VAE principles to answer why it's done this way and why it works.

One Point is Enough

In practice, equation $(3)$ is then handled as follows:

  1. Notice that $\mathbb{E}_{z\sim p(z|x)}\Big[\ln \frac{p(z|x)}{q(z)}\Big]$ is exactly the KL divergence between $p(z|x)$ and $q(z)$, denoted as $KL\Big(p(z|x)\Big\Vert q(z)\Big)$. Since both are assumed to be normal distributions, this term can be calculated analytically;
  2. For the term $\mathbb{E}_{z\sim p(z|x)}\big[-\ln q(x|z)\big]$, we believe that sampling just one point is representative enough, so this term becomes $-\ln q(x|z),\, z\sim p(z|x)$.

With this treatment, the entire loss can be explicitly written as:

$$\mathcal{L}=\mathbb{E}_{x\sim \tilde{p}(x)} \left[-\ln q(x|z) + KL\Big(p(z|x)\Big\Vert q(z)\Big)\right],\quad z\sim p(z|x)\tag{4}$$
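To make this concrete, here is a rough sketch of loss $(4)$ in code. This is a PyTorch-style sketch under simplifying assumptions (diagonal-Gaussian encoder, Gaussian decoder with variance fixed to 1, additive constants dropped); the names `Encoder`, `Decoder`, and `vae_loss` are illustrative, not from the original post.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Outputs the mean and log-variance of p(z|x)."""
    def __init__(self, x_dim=784, z_dim=2):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(x_dim, 256), nn.ReLU())
        self.mu = nn.Linear(256, z_dim)
        self.log_var = nn.Linear(256, z_dim)
    def forward(self, x):
        h = self.body(x)
        return self.mu(h), self.log_var(h)

class Decoder(nn.Module):
    """Outputs the mean mu(z) of q(x|z); the variance is fixed."""
    def __init__(self, x_dim=784, z_dim=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(z_dim, 256), nn.ReLU(),
                                 nn.Linear(256, x_dim))
    def forward(self, z):
        return self.net(z)

def vae_loss(x, encoder, decoder):
    mu, log_var = encoder(x)
    eps = torch.randn_like(mu)                  # a single reparameterized sample
    z = mu + eps * torch.exp(0.5 * log_var)     # z ~ p(z|x)
    x_mu = decoder(z)
    # -ln q(x|z): one Monte Carlo sample, decoder variance 1, constants dropped
    recon = 0.5 * ((x - x_mu) ** 2).sum(dim=1)
    # KL( N(mu, sigma^2) || N(0, I) ): computed analytically, no sampling at all
    kl = 0.5 * (mu ** 2 + torch.exp(log_var) - log_var - 1).sum(dim=1)
    return (recon + kl).mean()

# e.g. loss = vae_loss(torch.rand(32, 784), Encoder(), Decoder())
```

Note how the reconstruction term uses exactly one sample while the KL term is exact, which is precisely the "unfair treatment" discussed next.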

Wait, some readers might find this questionable: calculating $KL\Big(p(z|x)\Big\Vert q(z)\Big)$ analytically is equivalent to sampling infinitely many points to estimate it, yet for $\mathbb{E}_{z\sim p(z|x)}\big[-\ln q(x|z)\big]$, we only sample one point. Given that both are parts of the loss, is this "unfair treatment" really okay?

In fact, $\mathbb{E}_{z\sim p(z|x)}\Big[\ln \frac{p(z|x)}{q(z)}\Big]$ can also be estimated with just one sample. That is, drawing a single $z$ for the entire expression turns equation $(3)$ into:

$$\begin{aligned}\mathcal{L} =&\mathbb{E}_{x\sim \tilde{p}(x)} \left[-\ln q(x|z)+\ln \frac{p(z|x)}{q(z)}\right]\\ =&\mathbb{E}_{x\sim \tilde{p}(x)} \Big[-\ln q(x|z)+\ln p(z|x) - \ln q(z)\Big]\,,\quad z\sim p(z|x) \end{aligned}\tag{5}$$

Although this loss is different from the standard VAE, it actually converges to similar results.
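One way to see why this is plausible: the extra piece $\ln p(z|x) - \ln q(z)$ in $(5)$ is an unbiased single-sample estimate of the analytic KL term in $(4)$, so the two losses agree in expectation. A quick NumPy check with toy numbers (not from the post):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = np.array([0.5, -1.0]), np.array([0.3, 0.8])   # toy p(z|x) = N(mu, sigma^2)

def log_normal(z, m, s):
    # log density of a diagonal Gaussian N(m, s^2)
    return -0.5 * np.sum(np.log(2 * np.pi * s**2) + ((z - m) / s)**2, axis=-1)

# analytic KL( p(z|x) || q(z) ) with q(z) = N(0, I)
kl_analytic = 0.5 * np.sum(mu**2 + sigma**2 - np.log(sigma**2) - 1)

# Monte Carlo: average of ln p(z|x) - ln q(z) over samples z ~ p(z|x)
z = mu + sigma * rng.standard_normal((100_000, 2))
kl_mc = np.mean(log_normal(z, mu, sigma) - log_normal(z, 0.0, 1.0))

print(kl_analytic, kl_mc)   # the two values should be close
```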

Why is One Point Enough?

So why is sampling one point enough? Under what circumstances is one point sufficient?

First, let me give an example where "sampling one point is not enough." Let's look back at equation $(2)$, which can be rewritten as:

$$q(x|z)=\mathop{\text{argmax}}_{q(x|z)} \mathbb{E}_{x\sim\tilde{p}(x)}\Bigg[\ln\Big(\mathbb{E}_{z\sim q(z)}\big[q(x|z)\big]\Big)\Bigg]\tag{6}$$

If one point were enough—actually, let's be cautious and sample $k$ points—we could write:

$$q(x|z)=\mathop{\text{argmax}}_{q(x|z)} \mathbb{E}_{x\sim\tilde{p}(x)}\Bigg[\ln\left(\frac{1}{k}\sum_{i=1}^k q(x|z_i)\right)\Bigg],\quad z_1,\dots,z_k \sim q(z)\tag{7}$$

Then we could train with gradient descent. However, this strategy fails. In practice, the number of samples $k$ is usually smaller than the batch size, and in that case maximizing $\ln\left(\frac{1}{k}\sum\limits_{i=1}^k q(x|z_i)\right)$ leads to a "resource scramble": in each iteration, the various $x_i$ in a batch compete for the shared samples $z_1, z_2, \dots, z_k$. Whichever $x_i$ successfully "claims" a $z_j$ (that is, finds a $z_j$ that generates it and nothing else) gets a large $q(x_i|z_j)$. But all samples are equal and the sampling is random, so we cannot predict the outcome of this scramble. It's complete chaos!
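For concreteness, objective $(7)$ amounts to maximizing something like the following (a hedged PyTorch-style sketch reusing the `Decoder` interface from the earlier sketch; unit decoder variance and dropped constants are assumptions):

```python
import math
import torch

def prior_sampling_objective(x, decoder, k=16, z_dim=2):
    # z_1..z_k are drawn from the prior q(z) = N(0, I) and shared by the whole batch
    z = torch.randn(k, z_dim)
    x_mu = decoder(z)                                                    # (k, x_dim) decoder means
    # log q(x | z_i) for every (x, z_i) pair (unit variance, constants dropped)
    log_q = -0.5 * ((x.unsqueeze(1) - x_mu.unsqueeze(0)) ** 2).sum(-1)   # (batch, k)
    # ln( (1/k) * sum_i q(x|z_i) ), averaged over the batch; this is what gets maximized
    return (torch.logsumexp(log_q, dim=1) - math.log(k)).mean()
```

Every $x$ in the batch is scored against the same handful of shared $z_i$'s, which is exactly where the scramble comes from.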

If the dataset is just MNIST, this might be tolerable, because MNIST samples have clear clustering tendencies: once the number of samples $k$ exceeds 10 or so, there are enough $z$'s to go around. But for datasets like faces or ImageNet, which lack obvious clusters and have large intra-class variance, the $z$'s are simply not enough to share. One moment $x_i$ grabs $z_j$, the next moment $x_{i+1}$ grabs it, and training simply fails.

Therefore, it is this "too many monks, too little porridge" situation that causes model $(7)$ to fail. But why does VAE succeed with just one point?

One Point Really is Enough

We need to take a closer look at how we treat $q(x|z)$, which we call the generative (decoder) part of the model. Generally, we assume it is a Bernoulli or a Gaussian distribution. Since Bernoulli distributions have limited use cases, we'll assume it is Gaussian:

$$q(x|z)=\frac{1}{\prod\limits_{k=1}^D \sqrt{2\pi \sigma_{(k)}^2(z)}}\exp\left(-\frac{1}{2}\left\Vert\frac{x-\mu(z)}{\sigma(z)}\right\Vert^2\right)\tag{8}$$

where $\mu(z)$ is the network that computes the mean and $\sigma^2(z)$ is the network that computes the variance. Often, the variance is fixed, leaving only the mean network to be trained.

Note that $q(x|z)$ is a probability distribution. After sampling $z$ from $q(z)$ and plugging it into $q(x|z)$, we get the specific form of the distribution. Theoretically, we should sample from $q(x|z)$ again to get $x$. However, we don't do that. We directly treat the result of the mean network $\mu(z)$ as $x$. Being able to do this indicates that $q(x|z)$ is a Gaussian distribution with very low variance (if the variance is fixed, it must be set low before training; if it's a Bernoulli distribution, this isn't an issue). Each sampling result is nearly identical (always the mean $\mu(z)$). At this point, $x$ and $z$ have an "almost" one-to-one relationship, approaching the deterministic function $x=\mu(z)$.

[Figure: a standard normal distribution (blue) vs. a low-variance normal distribution (orange)]
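As a sanity check on why the decoder variance must be small, here is what $-\ln q(x|z)$ from $(8)$ looks like when $\sigma(z)$ is frozen at a constant (a NumPy sketch; the function name and the value $0.1$ are illustrative):

```python
import numpy as np

def neg_log_q_x_given_z(x, x_mu, sigma=0.1):
    # -ln q(x|z) for the diagonal Gaussian in (8) with sigma(z) fixed to a constant:
    #   (D/2) * ln(2*pi*sigma^2) + ||x - mu(z)||^2 / (2*sigma^2)
    d = x.shape[-1]
    return 0.5 * d * np.log(2 * np.pi * sigma ** 2) \
         + 0.5 * np.sum((x - x_mu) ** 2, axis=-1) / sigma ** 2
```

The smaller $\sigma$ is, the more the squared-error term dominates, so minimizing this loss forces $\mu(z)$ to reproduce $x$ almost exactly; that is what justifies treating $\mu(z)$ itself as the reconstruction.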

For the posterior distribution $p(z|x)$, we also assume it's Gaussian. Since we said $z$ and $x$ are almost one-to-one, this property also applies to the posterior distribution $p(z|x)$. This implies that the posterior distribution will also be a Gaussian with very small variance (readers can verify this by checking the encoder results on MNIST). This means that each sample taken from $p(z|x)$ is almost the same.

As such, there is little difference between sampling once and sampling multiple times, because the samples are essentially the same anyway. This explains why we can start from equation $(3)$ and calculate using only one sample, resulting in equations $(4)$ or $(5)$.
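A tiny numerical illustration of "one sample is as good as many" (toy numbers; any smooth test function would do):

```python
import numpy as np

rng = np.random.default_rng(1)
f = lambda z: np.sum(z ** 2, axis=-1)        # an arbitrary smooth test function
mu = np.array([0.7, -0.2])

for sigma in (1.0, 0.1, 0.01):               # shrink the posterior's standard deviation
    z = mu + sigma * rng.standard_normal((10_000, 2))
    single = f(z)                            # 10,000 independent one-sample estimates
    print(sigma, np.std(single))             # their spread collapses as sigma shrinks
```

As the standard deviation of $p(z|x)$ shrinks, every one-sample estimate of $\mathbb{E}_{z\sim p(z|x)}[f(z)]$ lands essentially on the same value, so averaging more samples buys almost nothing.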

The Magic of the Posterior


We previously explained why sampling from the prior distribution $q(z)$ is bad, while one point from the posterior $p(z|x)$ is enough. In fact, using KL divergence to introduce a posterior distribution into latent variable models is a magical trick. In this section, we'll organize these thoughts and provide a new example applying this idea.

Prior Knowledge of the Posterior

Readers might feel a bit logically confused: "You say $q(x|z)$ and $p(z|x)$ both end up as low-variance Gaussians, but that's just the final training result. Theoretically, we shouldn't know how large the variances are before modeling, so how can we sample just one point beforehand?"

I believe this is our prior knowledge of the problem. When we decide to use a dataset $X$ for VAE, the dataset itself carries strong constraints. For example, MNIST has 784 pixels, but its independent dimensions are much fewer than 784. Most obviously, some edge pixels are always 0. Compared to all $28\times 28$ images, MNIST is a very small subset. Another example: the "Tang Poetry" corpus is a tiny subset of all possible sentences. Even for ImageNet with thousands of categories, it's a small subset of the infinite space of images.

Consequently, we think this dataset $X$ can be projected onto a low-dimensional space (latent variable space), where the latent variables correspond one-to-one with the original set $X$. Readers might notice: isn't this just a standard autoencoder? Yes. In the case of a standard autoencoder, we can achieve a one-to-one mapping between latent variables and the original dataset (meaning the variances of $p(z|x)$ and $q(x|z)$ are zero). After introducing the Gaussian prior distribution $q(z)$, roughly speaking, this only applies translation and scaling to the latent space, so the variance doesn't have to be large.

Thus, we are essentially guessing beforehand that the variances of $q(x|z)$ and $p(z|x)$ are small and making the model realize this estimate. Simply put, the operation of "sampling one point" is our prior knowledge of the data and the model—a prior about the posterior. We use this prior knowledge to hope the model will move towards it.

The entire logic is:

  1. Have the original corpus;
  2. Observe the corpus and hypothesize that it can be mapped one-to-one to a latent variable space;
  3. Use "sampling one point" to let the model learn this mapping.

This part is a bit messy—it feels almost redundant. I hope readers aren't confused by me. If it feels confusing, feel free to ignore this part!

The Straightforward IWAE

The following example, called "Importance Weighted Autoencoders" (IWAE), demonstrates the utility of the posterior even more directly. In some ways, it can be seen as an upgrade to VAE.

IWAE starts from equation $(2)$ and introduces the posterior distribution to rewrite it as:

$$\int q(x|z)q(z)dz = \int p(z|x)\frac{q(x|z)q(z)}{p(z|x)}dz=\mathbb{E}_{z\sim p(z|x)}\left[\frac{q(x|z)q(z)}{p(z|x)}\right]\tag{9}$$

In this way, equation $(2)$ changes from sampling from $q(z)$ to sampling from $p(z|x)$. Since we've argued that $p(z|x)$ has low variance, sampling a few points is sufficient:

$$\int q(x|z)q(z)dz \approx \frac{1}{k}\sum_{i=1}^k \frac{q(x|z_i)q(z_i)}{p(z_i|x)},\quad z_1,\dots,z_k\sim p(z|x)\tag{10}$$

Substituting into equation $(2)$ gives:

$$q(x|z)=\mathop{\text{argmax}}_{q(x|z)} \mathbb{E}_{x\sim\tilde{p}(x)}\Bigg[\ln\left(\frac{1}{k}\sum_{i=1}^k \frac{q(x|z_i)q(z_i)}{p(z_i|x)}\right)\Bigg],\quad z_1,\dots,z_k \sim p(z|x)\tag{11}$$

This is IWAE. To align with equations $(4)$ and $(5)$, it can be equivalently written as:

$$\begin{aligned}&q(x|z) = \mathop{\text{argmin}}_{q(x|z),p(z|x)} \mathcal{L}_k,\\ \mathcal{L}_k = \mathbb{E}_{x\sim\tilde{p}(x)}\Bigg[&-\ln\left(\frac{1}{k}\sum_{i=1}^k \frac{q(x|z_i)q(z_i)}{p(z_i|x)}\right)\Bigg],\quad z_1,\dots,z_k \sim p(z|x)\end{aligned}\tag{12}$$

When $k=1$, the above equation is exactly the same as equation $(5)$. So from this perspective, IWAE is an upgrade of VAE.
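A hedged PyTorch-style sketch of $\mathcal{L}_k$ in $(12)$, assuming the same `Encoder`/`Decoder` interfaces as the earlier sketch (unit decoder variance; additive constants, which do not affect gradients, are dropped):

```python
import math
import torch

def iwae_loss(x, encoder, decoder, k=5):
    mu, log_var = encoder(x)                      # parameters of p(z|x)
    std = torch.exp(0.5 * log_var)
    eps = torch.randn(k, *mu.shape)               # k reparameterized samples
    z = mu + eps * std                            # (k, batch, z_dim)
    x_mu = decoder(z)

    # log importance weight: ln q(x|z) + ln q(z) - ln p(z|x), up to a constant
    log_q_x_given_z = -0.5 * ((x - x_mu) ** 2).sum(-1)
    log_q_z = -0.5 * (z ** 2).sum(-1)
    log_p_z_given_x = -0.5 * ((((z - mu) / std) ** 2) + log_var).sum(-1)
    log_w = log_q_x_given_z + log_q_z - log_p_z_given_x        # (k, batch)

    # L_k = E_x[ -ln (1/k) sum_i w_i ], computed stably with logsumexp
    return -(torch.logsumexp(log_w, dim=0) - math.log(k)).mean()
```

With `k=1` this reduces to the single-sample loss $(5)$, up to the dropped constants.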

From the construction process, replacing $p(z|x)$ in equation $(9)$ with any distribution over $z$ would still give a valid identity; $p(z|x)$ is chosen simply because it is concentrated and easy to sample from. When $k$ is large enough, the specific form of $p(z|x)$ actually becomes unimportant. This suggests that in IWAE, the role of the encoder model $p(z|x)$ is weakened in exchange for improving the generative model $q(x|z)$. In VAE, we assume $p(z|x)$ is a Gaussian, which is just an approximation for easy computation; the validity of this approximation simultaneously affects the quality of the generative model $q(x|z)$. It can be proven that $\mathcal{L}_k$ is closer to the lower bound $-\mathbb{E}_{x\sim \tilde{p}(x)} \left[\ln \tilde{p}(x)\right]$ than $\mathcal{L}$, so the generator's quality is superior.
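As a quick toy check of this tightening claim, one can generate synthetic importance weights $w$ with $\mathbb{E}[w]=1$ (so the target value $-\ln \mathbb{E}[w]$ is $0$) and watch the toy $\mathcal{L}_k$ approach it as $k$ grows (log-normal weights, not a real model):

```python
import numpy as np

rng = np.random.default_rng(0)
# ln w ~ N(-0.5, 1) gives E[w] = exp(-0.5 + 0.5) = 1, so the target -ln E[w] is 0
log_w = rng.normal(-0.5, 1.0, size=100_000)

for k in (1, 5, 50):
    w = np.exp(log_w).reshape(-1, k)                 # group the weights k at a time
    L_k = -np.mean(np.log(w.mean(axis=1)))           # toy version of the IWAE bound
    print(k, L_k)                                    # decreases toward 0 as k grows
```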

Intuitively, in IWAE, the degree of approximation of $p(z|x)$ isn't as critical, resulting in a better generative model. However, the cost is that the encoder's quality decreases because the importance of $p(z|x)$ is reduced, so the model doesn't focus as much effort on training $p(z|x)$. Therefore, if we want a good encoder, IWAE is not the way to go.

There is also a paper, "Tighter Variational Bounds are Not Necessarily Better," which supposedly improves both the encoder and decoder quality, though I haven't fully understood it yet.

The Magic of Reparameterization


If the introduction of the posterior distribution successfully outlines the VAE blueprint, the reparameterization trick is the "finishing touch" that brings it to life.

Previously, we said VAE introduces the posterior distribution to shift sampling from the loose standard normal $q(z)$ to the compact normal $p(z|x)$. However, though both are normal distributions, their implications are very different. Let's write out:

$$p(z|x)=\frac{1}{\prod\limits_{k=1}^d \sqrt{2\pi \sigma_{(k)}^2(x)}}\exp\left(-\frac{1}{2}\left\Vert\frac{z-\mu(x)}{\sigma(x)}\right\Vert^2\right)\tag{13}$$

That is, the mean and variance of $p(z|x)$ are both computed by networks to be trained. Now imagine the model reaching this step: it computes $\mu(x)$ and $\sigma(x)$, builds the normal distribution, and draws a sample. But what comes out of the sampling is just a concrete vector, with no visible dependence on $\mu(x)$ or $\sigma(x)$; as far as differentiation is concerned it is a constant, so its derivative with respect to $\mu(x)$ and $\sigma(x)$ is zero. During gradient descent, $\mu(x)$ and $\sigma(x)$ therefore receive no feedback to update them.

[Figure: the reparameterization trick]

This is where the reparameterization trick makes its grand entrance. It tells us directly:

$$z = \mu(x) + \varepsilon \times \sigma(x),\quad \varepsilon\sim \mathcal{N}(0,I).$$

Nothing could be simpler. It looks like a minor transformation, but it makes the dependence of $z$ on $\mu(x)$ and $\sigma(x)$ explicit. As a result, the gradients with respect to $\mu(x)$ and $\sigma(x)$ are no longer zero, and they finally receive the feedback they deserve. With this, the model is complete; all that's left is writing the code.
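A minimal check of the difference this makes (PyTorch, toy numbers): sampling "from outside" gives no gradient path back to $\mu(x)$ and $\sigma(x)$, while the reparameterized sample does.

```python
import torch

mu = torch.tensor([0.5, -1.0], requires_grad=True)
log_var = torch.tensor([0.1, 0.2], requires_grad=True)
std = torch.exp(0.5 * log_var)

# Without reparameterization: the sample is produced outside the graph,
# so it carries no dependence on mu or log_var.
z_naive = torch.normal(mu.detach(), std.detach())
print(z_naive.requires_grad)        # False: no gradient can flow

# With reparameterization: z = mu + eps * sigma, eps ~ N(0, I)
eps = torch.randn(2)
z = mu + eps * std
z.sum().backward()
print(mu.grad, log_var.grad)        # both now receive (generally non-zero) gradients
```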

Clearly, the "reparameterization trick" is a masterstroke.

Conclusion


Rambling on, I've filled another post. This post was meant to clear up some minor details of VAE, especially how VAE solves the sampling hurdle (and thus the training hurdle) by cleverly introducing the posterior, and to introduce IWAE along the way.

Seeking an intuitive understanding inevitably sacrifices some rigor; you can't have both. So, for any flaws in the article, I hope experienced readers will be lenient and I welcome criticisms and suggestions.

Original link: https://kexue.fm/archives/5383