A Concise Introduction to VQ-VAE: Quantized Autoencoder

By 苏剑林 | June 24, 2019

I recall seeing VQ-VAE quite some time ago, but I didn't have much interest in it then. Recently, two things have reignited my interest. First, VQ-VAE-2 achieved generation results comparable to BigGAN (according to reports by Machine Heart); second, while reading the NLP paper "Unsupervised Paraphrasing without Translation", I noticed it also utilized VQ-VAE. These two events suggest that VQ-VAE is a fairly versatile and interesting model, so I decided to study it thoroughly.

My personal reproduction of VQ-VAE reconstruction on CelebA. Note that details are preserved well, though some blurriness appears upon magnification.

Model Overview

VQ-VAE (Vector Quantised - Variational AutoEncoder) first appeared in the paper "Neural Discrete Representation Learning", which, like VQ-VAE-2, is a major work from the Google team.

Interesting yet Abstruse

As an autoencoder, a distinct feature of VQ-VAE is that the vectors it encodes are discrete. In other words, every element of the final encoding vector is an integer. This is the meaning of "Quantised" (similar to "Quantum" in quantum mechanics, referring to discretization).

Despite the entire model being continuous and differentiable, the resulting encoding vectors are discrete, and the reconstruction quality appears very sharp (as shown in the image at the beginning). This implies VQ-VAE contains some interesting and valuable techniques worth learning. However, after reading the original paper, it feels somewhat difficult to grasp—not because of the same dense technicality found in the ON-LSTM paper, but because of a sense of being "intentionally abstruse."

First, once you finish the paper, you realize VQ-VAE is actually an AE (Autoencoder) rather than a VAE (Variational Autoencoder). I am not sure why the authors felt compelled to use probabilistic language to tether it to VAE, as this significantly increases the difficulty of understanding. Second, a core step in VQ-VAE is the Straight-Through Estimator, an optimization trick for discrete latent variables. The paper lacks a detailed explanation, making it necessary to look at the source code to understand what is being said. Finally, the core idea of the model isn't explained well; it feels as if the paper focuses purely on presenting the model itself without explaining the thinking behind it.

PixelCNN

To trace the ideas behind VQ-VAE, one must discuss autoregressive models. The strategy VQ-VAE uses for generative modeling stems from autoregressive models like PixelRNN and PixelCNN. These models recognize that the images we want to generate are actually discrete rather than continuous. Take a CIFAR-10 image as an example: it is a $32 \times 32$ image with 3 channels. In other words, it is a $32 \times 32 \times 3$ array where each element is an integer from 0 to 255. Thus, we can view it as a sentence with a length of $32 \times 32 \times 3 = 3072$ and a vocabulary size of 256. We can then use language modeling methods to generate the image pixel-by-pixel, recursively (predicting the next pixel given all previous pixels). This is the so-called autoregressive approach:

\begin{equation}p(x)=p(x_1)p(x_2|x_1)\dots p(x_{3n^2}|x_1,x_2,\dots,x_{3n^2-1})\end{equation}

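where $n$ is the side length of the image ($n = 32$ for CIFAR-10), so the sequence has $3n^2$ entries, and each of $p(x_1), p(x_2|x_1), \dots, p(x_{3n^2}|x_1,x_2,\dots,x_{3n^2-1})$ is a 256-way classification problem, albeit with different dependencies.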
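
To make the factorization concrete, here is a minimal sketch of how the log-probability of a flattened image is accumulated one position at a time. The names `sequence_log_prob` and `uniform_conditional` are hypothetical, and the uniform distribution is only a stand-in for a trained model such as PixelCNN.

```python
import math

def sequence_log_prob(pixels, conditional):
    """Sum log p(x_t | x_1, ..., x_{t-1}) over a flattened pixel sequence."""
    total = 0.0
    for t, value in enumerate(pixels):
        probs = conditional(pixels[:t])  # 256-way distribution for position t
        total += math.log(probs[value])
    return total

def uniform_conditional(prefix):
    # Hypothetical stand-in for a trained model: it ignores the prefix entirely.
    return [1.0 / 256] * 256

pixels = [0, 128, 255, 17]  # a toy "image" flattened to 4 values
log_p = sequence_log_prob(pixels, uniform_conditional)
```

A real model would differ only in `conditional`: each call would return a distribution that actually depends on the prefix.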

Introductory materials for PixelRNN and PixelCNN are readily available online, so I won't repeat them here. I feel one could even ride the BERT wave and create a "PixelAtt" (Attention) to do this. Research on autoregressive models mainly focuses on two aspects. The first is how to design the recursion order to better generate samples, since images aren't simple 1D sequences (they are at least 2D or 3D). Whether you go "left-to-right then top-to-bottom," "top-to-bottom then left-to-right," or "center-outwards" significantly impacts the generation. The second is accelerating the sampling process. Among the literature I've read, a relatively recent achievement in autoregressive models is the ICLR 2019 work "Generating High Fidelity Images with Subscale Pixel Networks and Multidimensional Upscaling".

The autoregressive method is robust and allows for effective probability estimation, but it has one fatal flaw: it is slow. Because it generates pixel-by-pixel, random sampling must be performed for every single value. The CIFAR-10 example is considered a small image; convincing image generation today requires at least $128 \times 128 \times 3$, which is 49,152 values, or nearly 50,000. Generating a 50,000-length "sentence" one element at a time is extremely time-consuming. Furthermore, with sequences this long, both RNN and CNN models struggle to capture long-range dependencies.

The original autoregressive methods have another issue: they sever the relationship between categories. While pixels are discrete (making 256-way classification viable), neighboring pixel values such as 99 and 100 are actually very similar, and pure classification fails to capture this. Mathematically, our cross-entropy objective is $-\log p_t$, where $p_t$ is the probability assigned to the target class. If the target pixel is 100 but the model predicts 99, then since they are different classes, $p_t$ might be near 0, making $-\log p_t$ very large. Visually, however, there is little difference between a pixel value of 100 and 99; such a large penalty shouldn't exist.
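
A small numeric sketch of this point follows. The predicted distribution is made up for illustration: it puts almost all of its mass on the value 99 and spreads the remainder evenly over the other 255 classes.

```python
import math

def cross_entropy(probs, target):
    """Cross-entropy for one position: -log of the probability given to the target."""
    return -math.log(probs[target])

# Hypothetical model output: confident in 99, tiny equal mass everywhere else.
eps = 1e-4
probs = [eps / 255] * 256
probs[99] = 1.0 - eps

loss_near_miss = cross_entropy(probs, 100)  # true value 100, one step away from 99
loss_far_miss = cross_entropy(probs, 0)     # true value 0, far away from 99
```

Under these assumptions the two losses are identical and large: the classification loss has no notion that 100 is closer to 99 than 0 is.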

VQ-VAE

To address the inherent flaws of autoregressive models, VQ-VAE proposes a solution: first perform dimension reduction, and then use PixelCNN to model the encoding vectors.

Dimension Reduction and Discretization

This solution seems natural, yet it is actually anything but.

Because PixelCNN generates discrete sequences, modeling encoding vectors with PixelCNN implies those vectors must also be discrete. However, standard dimension reduction methods, like autoencoders, produce continuous latent variables that cannot directly yield discrete variables. Furthermore, generating discrete variables often entails gradient vanishing problems. Additionally, how can we ensure the image doesn't lose significant detail during the reduction and reconstruction process? If the distortion is severe, or if it performs worse than a standard VAE, then VQ-VAE would have no value.

Fortunately, VQ-VAE provides an effective training strategy to solve these two problems.

Nearest Neighbor Reconstruction

In VQ-VAE, an $n \times n \times 3$ image $x$ is first passed into an $encoder$ to obtain a continuous encoding vector $z$:

\begin{equation}z = encoder(x)\end{equation}

Here, $z$ is a vector of size $d$. Additionally, VQ-VAE maintains an Embedding layer, which we can call a codebook, denoted as:

\begin{equation}E = [e_1, e_2, \dots, e_K]\end{equation}

where each $e_i$ is a vector of size $d$. Then, VQ-VAE uses a nearest neighbor search to map $z$ to one of these $K$ vectors:

\begin{equation}z\to e_k,\quad k = \mathop{\text{argmin}}_j \Vert z - e_j\Vert_2\end{equation}

We denote the codebook vector corresponding to $z$ as $z_q$, and we consider $z_q$ to be the final encoding result. Finally, $z_q$ is passed through a $decoder$ to reconstruct the original image $\hat{x}=decoder(z_q)$.

The entire workflow is:

\begin{equation}x\xrightarrow{encoder} z \xrightarrow{\text{Nearest Neighbor}} z_q \xrightarrow{decoder}\hat{x}\end{equation}

In this way, because $z_q$ is one of the vectors in codebook $E$, it is effectively equivalent to one of the integers $1, 2, \dots, K$. Thus, this entire process essentially encodes the image into an integer.
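
The nearest neighbor step for a single vector can be sketched in NumPy as below. The codebook values and the function name `quantize` are illustrative assumptions, and the returned index is 0-based, whereas the text counts codes from 1 to $K$.

```python
import numpy as np

def quantize(z, codebook):
    """Return (k, z_q): the index of the nearest codebook row and the row itself."""
    distances = np.linalg.norm(codebook - z, axis=1)  # ||z - e_j||_2 for every j
    k = int(np.argmin(distances))
    return k, codebook[k]

# Toy codebook with K = 3 codes of size d = 2 (values are illustrative).
E = np.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
z = np.array([0.9, 1.2])
k, z_q = quantize(z, E)
```

Storing `k` alone is enough to recover `z_q` later, which is what makes the encoding an integer.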

Of course, the process above is simplified. If we only encode into a single vector, reconstruction would inevitably lose detail and generalization would be hard to guarantee. Therefore, in practice, multiple convolutional layers are used to encode $x$ into an $m \times m$ grid of vectors of size $d$:

\begin{equation}z = \begin{pmatrix}z_{11} & z_{12} & \dots & z_{1m}\\ z_{21} & z_{22} & \dots & z_{2m}\\ \vdots & \vdots & \ddots & \vdots\\ z_{m1} & z_{m2} & \dots & z_{mm}\\ \end{pmatrix}\end{equation}

That is, the total size of $z$ is $m \times m \times d$, preserving spatial structure. Each vector is then mapped to one in the codebook using the aforementioned method to obtain $z_q$ of the same size, which is then used for reconstruction. Consequently, $z_q$ is equivalent to an $m \times m$ integer matrix, achieving discrete encoding.
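
Extending the lookup to the whole grid is mostly a matter of broadcasting. The following is a sketch with arbitrary sizes and random values standing in for encoder outputs and a trained codebook; `quantize_grid` is a hypothetical helper name.

```python
import numpy as np

def quantize_grid(z, codebook):
    """Map an (m, m, d) array to (m, m) integer codes and (m, m, d) quantized vectors."""
    # Squared distances between every grid vector and every code: shape (m, m, K).
    diff = z[:, :, None, :] - codebook[None, None, :, :]
    sq_dist = np.sum(diff ** 2, axis=-1)
    codes = np.argmin(sq_dist, axis=-1)
    return codes, codebook[codes]

rng = np.random.default_rng(0)
m, d, K = 4, 8, 16                  # illustrative sizes, not the article's
codebook = rng.normal(size=(K, d))  # stand-in for a trained codebook
z = rng.normal(size=(m, m, d))      # stand-in for the encoder output
codes, z_q = quantize_grid(z, codebook)
```

Minimizing the squared distance gives the same index as minimizing the distance itself, so the square root is skipped.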

Designing Your Own Gradients

We know that for a standard autoencoder, we can train directly using the following loss:

\begin{equation}\Vert x - decoder(z)\Vert_2^2\end{equation}

However, in VQ-VAE, we use $z_q$ rather than $z$ for reconstruction, so it seems we should use this loss instead:

\begin{equation}\Vert x - decoder(z_q)\Vert_2^2\end{equation}

The problem is that the construction of $z_q$ involves an $\text{argmin}$ operation, which is non-differentiable. Therefore, if we used the second loss, we couldn't update the $encoder$.

In other words, our actual goal is to minimize $\Vert x - decoder(z_q)\Vert_2^2$, but it's hard to optimize; meanwhile, $\Vert x - decoder(z)\Vert_2^2$ is easy to optimize but isn't our real target. What do we do? A crude method would be to use both:

\begin{equation}\Vert x - decoder(z)\Vert_2^2 + \Vert x - decoder(z_q)\Vert_2^2\end{equation}

But this isn't ideal, as minimizing $\Vert x - decoder(z)\Vert_2^2$ is not our objective and adds extra constraints.

VQ-VAE uses a very clever and direct method called the Straight-Through Estimator (STE). It originated from Bengio's paper "Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation". The VQ-VAE paper simply cites this without much explanation. In fact, the original Bengio paper is not an easy read; looking at source code is the faster route.

The core idea of Straight-Through is simple: during the forward pass, you can use the desired variable (even if it's non-differentiable), but during the backward pass, you use a gradient that you designed for it. Based on this idea, the loss function we design is:

\begin{equation}\Vert x - decoder(z + sg[z_q - z])\Vert_2^2\end{equation}

where $sg$ stands for "stop gradient," meaning its gradient is not calculated. In the forward pass for calculating the loss, this is identical to $decoder(z + z_q - z) = decoder(z_q)$. During the backward pass for gradients, since $z_q - z$ contributes no gradient, it is equivalent to $decoder(z)$, which allows us to optimize the $encoder$.
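
Since NumPy has no automatic differentiation, the sketch below writes out by hand what the stop-gradient trick does, under the simplifying assumption of a linear decoder $decoder(v) = Wv$. All values are random placeholders. In a framework with autodiff the forward expression would typically be written directly, for example with `tf.stop_gradient` in TensorFlow.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(5, 3))         # hypothetical linear decoder: decoder(v) = W @ v
x = rng.normal(size=5)              # target to reconstruct
z = rng.normal(size=3)              # encoder output
codebook = rng.normal(size=(4, 3))
z_q = codebook[np.argmin(np.linalg.norm(codebook - z, axis=1))]

def loss_at(v):
    """Reconstruction loss ||x - decoder(v)||^2 for the linear decoder."""
    return np.sum((x - W @ v) ** 2)

# Forward pass: z + sg[z_q - z] has the value z_q, so the loss is evaluated at z_q.
loss = loss_at(z_q)

# Backward pass: sg[...] is treated as a constant, so d(z + sg[z_q - z])/dz is the
# identity. The gradient at the decoder input (taken at z_q) is handed to z unchanged.
grad_wrt_decoder_input = -2.0 * W.T @ (x - W @ z_q)
grad_z = grad_wrt_decoder_input
```

The gradient that reaches `z` is therefore computed at the point `z_q`, not at `z`; it is a designed gradient rather than the true derivative of the argmin.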

As a side note, based on this idea, we can customize gradients for many functions. For example, $x + sg[\text{relu}(x) - x]$ defines the gradient of $\text{relu}(x)$ as consistently 1, while its value in the forward pass remains identical to $\text{relu}(x)$. We can use the same method to arbitrarily assign a gradient to a function; whether it has practical value depends on the specific task.

Maintaining the Codebook

Note that according to the design of the nearest neighbor search in VQ-VAE, we expect $z_q$ and $z$ to be very close (in fact, each vector in the codebook $E$ ends up acting like a cluster center for various $z$). However, this isn't guaranteed. Even if both $\Vert x - decoder(z)\Vert_2^2$ and $\Vert x - decoder(z_q)\Vert_2^2$ are small, it doesn't necessarily mean $z_q$ and $z$ are similar (just as $f(z_1)=f(z_2)$ doesn't imply $z_1 = z_2$).

So, to bring $z_q$ and $z$ closer, we can directly add $\Vert z - z_q\Vert_2^2$ to the loss:

\begin{equation}\Vert x - decoder(z + sg[z_q - z])\Vert_2^2 + \beta \Vert z - z_q\Vert_2^2\end{equation}

Beyond this, we can be even more precise. Since the codebook ($z_q$) is relatively free while $z$ must ensure reconstruction quality, we should primarily "move $z_q$ toward $z$" rather than "move $z$ toward $z_q$." Since the gradient of $\Vert z_q - z\Vert_2^2$ is the sum of the gradient with respect to $z_q$ and the gradient with respect to $z$, we can decompose it equivalently as:

\begin{equation}\Vert sg[z] - z_q\Vert_2^2 + \Vert z - sg[z_q]\Vert_2^2\end{equation}

The first term fixes $z$ and moves $z_q$ toward $z$; the second term fixes $z_q$ and moves $z$ toward $z_q$. Note this "equivalence" applies to the backward pass (gradients); for the forward pass (loss), it is twice the original. Based on our discussion, we want $z_q$ to move toward $z$ more than vice-versa, so we can adjust the ratio in the final loss:

\begin{equation}\Vert x - decoder(z + sg[z_q - z])\Vert_2^2 + \beta \Vert sg[z] - z_q\Vert_2^2 + \gamma \Vert z - sg[z_q]\Vert_2^2\end{equation}

where $\gamma < \beta$. In the original paper, $\gamma = 0.25 \beta$ is used.
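
The effect of the two weights can be seen in a hand-computed sketch of one gradient step on the last two terms only. The reconstruction term is left out, the vectors and learning rate are toy values, and $\beta = 1$ is an assumption chosen so that $\gamma = 0.25$.

```python
import numpy as np

beta, gamma = 1.0, 0.25             # gamma = 0.25 * beta, with beta = 1 assumed
z = np.array([1.0, 2.0])            # encoder output (toy values)
z_q = np.array([0.0, 0.0])          # its nearest codebook vector (toy values)

sq_dist = np.sum((z - z_q) ** 2)
codebook_term = beta * sq_dist      # beta * ||sg[z] - z_q||^2
encoder_term = gamma * sq_dist      # gamma * ||z - sg[z_q]||^2

# Because of sg[...], each term sends a gradient to only one of the two vectors.
grad_z_q = 2.0 * beta * (z_q - z)   # from the codebook term
grad_z = 2.0 * gamma * (z - z_q)    # from the encoder term

lr = 0.1                            # illustrative learning rate
z_q_new = z_q - lr * grad_z_q       # codebook vector moves toward z
z_new = z - lr * grad_z             # encoder output moves toward z_q, but less
```

With these numbers `z_q` moves to `[0.2, 0.4]` while `z` only moves to `[0.95, 1.9]`, matching the intent that the codebook does most of the moving.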

(Note: One can also update the codebook using an exponential moving average; please see the original paper for details.)

Fitting the Code Distribution

After this extensive design, we have finally encoded the image into an $m \times m$ integer matrix. Because this $m \times m$ matrix preserves the spatial information of the original input image to some extent, we can use autoregressive models like PixelCNN to fit the encoding matrix (i.e., modeling the prior distribution). Once the distribution is obtained via PixelCNN, we can randomly generate a new encoding matrix, map it to a 3D real-valued matrix $z_q$ (rows $\times$ columns $\times$ encoding dimension) using codebook $E$, and finally pass it through the $decoder$ to get an image.
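
The generation path after the prior is fitted can be sketched as follows. The uniform sampling is only a placeholder for a trained autoregressive prior, the codebook is random, and the decoder is omitted; the point is the shape of the data at each step.

```python
import numpy as np

rng = np.random.default_rng(0)
m, d, K = 32, 64, 256               # illustrative sizes; d and K are assumptions
codebook = rng.normal(size=(K, d))  # stand-in for a trained codebook E

# Placeholder for the prior: a trained PixelCNN would sample these codes one by one,
# each conditioned on the codes before it. Uniform sampling only shows the shapes.
codes = rng.integers(0, K, size=(m, m))

z_q = codebook[codes]               # (m, m, d) real-valued array fed to the decoder
```

Here `z_q` would then be passed to the trained decoder to produce the image.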

Generally, the current $m \times m$ is much smaller than the original $n \times n \times 3$. For instance, in my experiments with CelebA data, an original $128 \times 128 \times 3$ image can be encoded into a $32 \times 32$ matrix with almost no distortion. Thus, modeling the encoding matrix with an autoregressive model is much easier than modeling the original image directly.
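
As a quick check of the sizes involved in the CelebA example, the sequence an autoregressive model has to handle shrinks as follows:

```python
n, channels, m = 128, 3, 32

pixel_sequence_length = n * n * channels  # positions when modeling raw pixel values
code_sequence_length = m * m              # positions when modeling the code matrix
reduction = pixel_sequence_length // code_sequence_length
```

That is 49,152 positions versus 1,024, a 48-fold reduction in sequence length.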

Personal Implementation

This is my own VQ-VAE implementation using Keras (Python 2.7 + Tensorflow 1.8 + Keras 2.2.4, with the model part referencing this):

https://github.com/bojone/vae/blob/master/vq_vae_keras.py

The main body of this script only contains the VQ-VAE encoding and reconstruction (the image at the beginning of the post was reconstructed with this script; the results are quite good). It does not include the PixelCNN modeling of the prior distribution. However, the comments at the end include an example of using Attention to model the prior. After modeling the prior distribution with Attention, random sampling results look like this:

Randomly sampled results after modeling the prior with PixelAtt (randomly picked, not screened)

The results indicate that random sampling is feasible, but the generation quality isn't perfect. I used PixelAtt instead of PixelCNN because, in my reproduction, PixelCNN performed significantly worse than PixelAtt. PixelAtt has specific advantages, though it is memory-hungry and prone to OOM. My personal reproduction not being perfect doesn't mean the method itself is flawed; it could be due to my lack of tuning or the network not being deep enough. Personally, I am optimistic about research into discrete encoding.

Final Summary

By now, I've finally explained VQ-VAE in a way I consider clear. Looking back at the text, there's actually no hint of VAE; as I said, it is simply an AE with discrete vector encoding. The reason it can reconstruct sharp images is that it preserves a sufficiently large feature map during encoding.

Once you understand VQ-VAE, the newly released Version 2.0 isn't hard to grasp. VQ-VAE-2 has almost no fundamental technical updates compared to VQ-VAE; it just performs encoding and decoding in two layers (one global, one local), further reducing blurriness (or at least making it less noticeable, though if you look closely at VQ-VAE-2's large images, slight blurriness remains).

Nevertheless, it's worth acknowledging that the VQ-VAE model is quite interesting. Concepts like discrete encoding and using Straight-Through to assign gradients are novel and worth serious study. They deepen our understanding of deep learning models and optimization (if you can design the gradients, are you still worried about designing the model?).