Another VQ Trick: Adding a Linear Transformation to the Codebook

By 苏剑林 | November 06, 2024

In "The Rotation Trick for VQ: A General Extension of Straight-Through Estimation", we introduced the Rotation Trick for VQ (Vector Quantization). Its idea is to design better gradients for VQ by generalizing the STE (Straight-Through Estimator), thereby alleviating issues such as codebook collapse and low codebook utilization.

Coincidentally, a paper released yesterday on arXiv titled "Addressing Representation Collapse in Vector Quantized Models with One Linear Layer" proposed another trick to improve VQ: adding a linear transformation to the codebook. This trick simply changes the parameterization of the codebook without altering the underlying theoretical framework of VQ, yet the experimental results are remarkably good, a classic example of "simple and effective."

Foundation

Since we have already introduced VQ and VQ-VAE multiple times in articles like "A Simple Introduction to VQ-VAE: Quantized Autoencoders" and "Embarrassingly Simple FSQ: 'Rounding' Surpasses VQ-VAE", we won't start from scratch here. Let's directly provide the mathematical forms of a standard AE and VQ-VAE:

\begin{align} \text{AE:}&\qquad z = \text{encoder}(x),\quad \hat{x}=\text{decoder}(z),\quad \mathcal{L}=\Vert x - \hat{x}\Vert^2 \\[12pt] \text{VQ-VAE:}&\qquad\left\{\begin{aligned} z =&\, \text{encoder}(x)\\[5pt] z_q =&\, z + \text{sg}[q - z],\quad q = \mathop{\text{argmin}}_{e\in\{e_1,e_2,\cdots,e_K\}} \Vert z - e\Vert\\ \hat{x} =&\, \text{decoder}(z_q)\\[5pt] \mathcal{L} =&\, \Vert x - \hat{x}\Vert^2 + \beta\Vert q - \text{sg}[z]\Vert^2 + \gamma\Vert z - \text{sg}[q]\Vert^2 \end{aligned}\right.\label{eq:vqvae} \end{align}

To reiterate a common point: VQ-VAE is not a VAE; it is simply an AE with VQ added, and it lacks the generative capabilities of a VAE. VQ is the operation of mapping an arbitrary vector to the nearest vector in the codebook. This operation itself is non-differentiable, so the STE is used to supply gradients to the encoder, and the two loss terms weighted by $\beta$ and $\gamma$ are added to provide gradients for the codebook and to regularize the encoder's representation.
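
To make the notation concrete, here is a minimal PyTorch sketch of the VQ step above (my own illustration, not the original VQ-VAE code; class and argument names are mine): nearest-code lookup, the straight-through estimator, and the two auxiliary losses weighted by $\beta$ and $\gamma$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    """Plain VQ layer: nearest-code lookup + straight-through estimator."""
    def __init__(self, num_codes, dim, beta=0.25, gamma=0.25):
        super().__init__()
        self.codebook = nn.Embedding(num_codes, dim)
        self.beta, self.gamma = beta, gamma

    def forward(self, z):                            # z: (batch, dim)
        d = torch.cdist(z, self.codebook.weight)     # distances to all codes
        idx = d.argmin(dim=-1)                       # nearest-code index
        q = self.codebook(idx)                       # quantized vectors
        z_q = z + (q - z).detach()                   # straight-through estimator
        vq_loss = (self.beta * F.mse_loss(q, z.detach())      # codebook loss
                   + self.gamma * F.mse_loss(z, q.detach()))  # commitment loss
        return z_q, vq_loss, idx
```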

Modification

The paper refers to its proposed method as SimVQ. While they don't explicitly define "Sim," I suspect it stands for "Simple," because the modification in SimVQ is indeed very simple:

\begin{equation} \text{SimVQ-VAE:}\qquad\left\{\begin{aligned} z =&\, \text{encoder}(x)\\[5pt] z_q =&\, z + \text{sg}[q\color{red}{W} - z],\quad q = \mathop{\text{argmin}}_{e\in\{e_1,e_2,\cdots,e_K\}} \Vert z - e\color{red}{W}\Vert\\ \hat{x} =&\, \text{decoder}(z_q)\\[5pt] \mathcal{L} =&\, \Vert x - \hat{x}\Vert^2 + \beta\Vert q\color{red}{W} - \text{sg}[z]\Vert^2 + \gamma\Vert z - \text{sg}[q\color{red}{W}]\Vert^2\end{aligned}\right. \end{equation}

That's it—just multiplying the codebook by a matrix $W$, leaving everything else unchanged.

If you were already training VQ in the form of Eq. $\eqref{eq:vqvae}$, SimVQ can be dropped in directly. If you were instead updating the codebook with EMA (i.e., $\beta=0$, with the codebook updated by a separate exponential-moving-average process, as in VQ-VAE-2 and subsequent models; this is mathematically equivalent to optimizing the codebook loss with SGD while the other losses use a non-SGD optimizer such as Adam), then you need to disable the EMA and reintroduce the $\beta$ term so that everything is optimized end-to-end.
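
For concreteness, here is a minimal sketch of the SimVQ parameterization in the same style as the VQ sketch above (again my own illustration; the names, and the identity initialization of $W$, are assumptions rather than details from the paper's code). The only difference from the plain VQ layer is that the effective codebook is $qW$, the stored codes multiplied by a learnable matrix $W$; following the paper's default, the codes $q$ themselves can be frozen after random initialization (more on this in the Analysis section).

```python
class SimVQuantizer(nn.Module):
    """VQ layer with the codebook reparameterized as q @ W (SimVQ-style)."""
    def __init__(self, num_codes, dim, beta=0.25, gamma=0.25, freeze_codes=True):
        super().__init__()
        # Codes are randomly initialized; with freeze_codes=True only W is trained.
        self.q = nn.Parameter(torch.randn(num_codes, dim),
                              requires_grad=not freeze_codes)
        # Identity initialization is my assumption, not necessarily the paper's choice.
        self.W = nn.Parameter(torch.eye(dim))
        self.beta, self.gamma = beta, gamma

    def effective_codebook(self):
        return self.q @ self.W                       # effective codes e_i = q_i W

    def forward(self, z):                            # z: (batch, dim)
        codebook = self.effective_codebook()
        idx = torch.cdist(z, codebook).argmin(dim=-1)
        qW = codebook[idx]
        z_q = z + (qW - z).detach()                  # straight-through estimator
        vq_loss = (self.beta * F.mse_loss(qW, z.detach())
                   + self.gamma * F.mse_loss(z, qW.detach()))
        return z_q, vq_loss, idx
```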

Some readers might immediately object: isn't this just changing the codebook parameterization from $E$ to $EW$? Since $EW$ can be merged into a single matrix, equivalent to a new $E$, shouldn't the model's theoretical capacity remain unchanged? Indeed, SimVQ does not change the model's theoretical capacity, but it does change the learning dynamics under optimizers such as SGD or Adam, and thereby the quality of the solution that training actually finds.

Experiments

Before diving deeper into the analysis, let's look at SimVQ's experimental results. SimVQ was tested on vision and audio tasks. Table 1 is particularly representative:

SimVQ experimental results (Table 1 of the paper)

According to the paper, the SimVQ code was modified from the VQGAN code shown in the first row, with the only change being the linear transformation inserted into the VQ layer, yet the improvement is substantial: at the same codebook size SimVQ achieves the best reconstruction quality, and its reconstruction quality keeps improving as the codebook size increases. This is the charm of SimVQ: simple yet effective.

I also tried it on my own previous VQ-VAE code. Actual tests showed that adding this linear transformation significantly accelerated the convergence speed of VQ-VAE, and the final reconstruction loss was also lower. I also experimented with a variant where $W$ is a diagonal matrix, which is equivalent to each code vector being element-wise multiplied by a parameter vector (initialized to all ones). The results showed that this variant also provided similar benefits, falling somewhere between standard VQ and SimVQ.
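
As a rough sketch of that diagonal variant (building on the hypothetical SimVQuantizer class above and sharing its assumptions), only the construction of the effective codebook changes: $W$ becomes a learnable per-dimension scale vector initialized to ones.

```python
class DiagSimVQuantizer(SimVQuantizer):
    """Variant where W is diagonal: each code is scaled element-wise."""
    def __init__(self, num_codes, dim, **kwargs):
        super().__init__(num_codes, dim, **kwargs)
        self.W = nn.Parameter(torch.ones(dim))       # diagonal of W, initialized to ones

    def effective_codebook(self):
        return self.q * self.W                       # element-wise scaling per dimension
```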

Analysis

Intuitively, the codebook updates in standard VQ are quite "isolated": if a sample $z$ is quantized to $q$, the gradient from this sample affects only $q$ and touches no other vector in the codebook. SimVQ is different: it updates not just $q$ but also $W$. Geometrically, $W$ acts as the "basis" of the codebook, so once $W$ is updated, the entire codebook moves with it. SimVQ therefore couples the whole codebook more tightly, giving it a better chance of finding a superior solution instead of falling into local optima where each code evolves separately in isolation.

But why does SimVQ improve codebook utilization? This is also fairly easy to understand, again using the interpretation of $W$ as the codebook's basis. If codebook utilization is too low, $W$ becomes "anisotropic": the basis leans towards the codes that are actually being used. But once the basis leans that way, all of its linear combinations, i.e. the codes themselves, also lean towards those utilized codes, so even rarely used codes are pulled towards the region the encoder's outputs actually occupy and become more likely to be selected, which keeps utilization from dropping too low. Simply put, the learnable basis automatically adjusts itself to increase its own utilization, thereby pulling up the utilization of the entire codebook.

We can also describe this process through mathematical formulas. Assuming the optimizer is SGD, the update for code $e_i$ in standard VQ is:

\begin{equation}e_i^{(t+1)} = e_i^{(t)} - \eta\frac{\partial \mathcal{L}}{\partial e_i^{(t)}}\end{equation}

In this case, if $e_i$ is not selected in the current batch, $\frac{\partial \mathcal{L}}{\partial e_i^{(t)}}$ is zero, and that entry in the codebook remains unchanged. But if $e_i$ is parameterized as $q_i W$, then:

\begin{equation}\begin{aligned} q_i^{(t+1)} =&\, q_i^{(t)} - \eta\frac{\partial \mathcal{L}}{\partial q_i^{(t)}} = q_i^{(t)} - \eta \frac{\partial \mathcal{L}}{\partial e_i^{(t)}} W^{(t)}{}^{\top}\\ W^{(t+1)} =&\, W^{(t)} - \eta\frac{\partial \mathcal{L}}{\partial W^{(t)}} = W^{(t)} - \eta \sum_j q_j^{(t)}{}^{\top}\frac{\partial \mathcal{L}}{\partial e_j^{(t)}} \\ e_i^{(t+1)}=&\,q_i^{(t+1)}W^{(t+1)}\approx e_i^{(t)} - \eta\left(\frac{\partial \mathcal{L}}{\partial e_i^{(t)}} W^{(t)}{}^{\top}W^{(t)} + q_i^{(t)}\sum_j q_j^{(t)}{}^{\top}\frac{\partial \mathcal{L}}{\partial e_j^{(t)}}\right) \end{aligned}\end{equation}

We can see that:

1. $W$ is updated based on the sum of gradients from all selected codes, so it naturally leans towards high-utilization directions;

2. Due to the presence of $q_i^{(t)}\sum_j q_j^{(t)}{}^{\top}\frac{\partial \mathcal{L}}{\partial e_j^{(t)}}$, the update for code $i$ is almost never zero, regardless of whether it was selected or not;

3. The term $q_i^{(t)}\sum_j q_j^{(t)}{}^{\top}\frac{\partial \mathcal{L}}{\partial e_j^{(t)}}$ acts as a projection onto the high-utilization direction, making every code move towards that direction.
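
These three observations can be checked numerically. The toy snippet below is my own sanity check under the SGD assumption above: it takes one SGD step on $q$ and $W$, compares the new effective codebook with the first-order formula, and confirms that a code not selected in the batch still receives a nonzero update.

```python
import torch

torch.manual_seed(0)
K, d, eta = 8, 4, 1e-4
Q = torch.randn(K, d, requires_grad=True)        # reparameterized codes q_i
W = torch.eye(d, requires_grad=True)             # shared linear transformation

E = Q @ W                                        # effective codebook, e_i = q_i W
E.retain_grad()
z = torch.randn(2, d)
loss = ((E[:2] - z) ** 2).sum()                  # toy loss touching only 2 "selected" codes
loss.backward()

G = E.grad                                       # dL/dE: nonzero only for the selected rows
with torch.no_grad():
    E_new = (Q - eta * Q.grad) @ (W - eta * W.grad)        # one SGD step on q and W
    E_approx = E - eta * (G @ W.T @ W + Q @ Q.T @ G)       # first-order formula above
    print((E_new - E_approx).abs().max())        # O(eta^2), i.e. negligible
    print((E_new - E)[5].abs().max())            # an unselected code still moves (point 2)
```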

However, too much of a good thing can be bad. If all codes move aggressively toward the high-utilization direction, it might lead to "codebook collapse." Therefore, SimVQ defaults to a very conservative strategy: it only updates $W$, and all $q$ vectors are frozen after random initialization. This almost completely eliminates the possibility of codebook collapse. The good news is that, with an appropriate code dimension, experiments show that updating both $q$ and $W$ performs similarly to updating only $W$, so readers can choose the specific form according to their preference.

Extensions

Setting VQ aside, the practice of introducing extra parameters in a mathematically redundant way, i.e. without changing the model's theoretical fitting capacity but only altering its optimization dynamics, is known as "overparameterization."

Overparameterization is common in neural networks. For example, the mainstream architecture today is Pre-Norm, i.e., $x + f(\text{RMSNorm}(x))$. The gain vector $\gamma$ multiplied at the end of RMSNorm is itself such an overparameterization, because the first layer of $f$ is typically a linear transformation (e.g., the Q, K, V projections of Attention, or the up-projection of an FFN). At inference time, $\gamma$ can be fully merged into that linear transformation, yet we rarely see anyone remove $\gamma$ during training.
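
As a quick illustration of this point, here is a minimal sketch (my own, not taken from any particular library) of folding $\gamma$ into the following linear layer at inference time: since $(x \odot \gamma)W^{\top} = x(W \odot \gamma)^{\top}$ with $\gamma$ broadcast over the input dimension, the two forms are exactly equivalent.

```python
import torch
import torch.nn as nn

d_in, d_out = 8, 16
gamma = torch.randn(d_in)                        # RMSNorm gain vector
linear = nn.Linear(d_in, d_out, bias=False)      # first linear layer of f (e.g. a Q/K/V projection)

x = torch.randn(4, d_in)                         # stand-in for the unit-scaled RMSNorm output
y1 = linear(x * gamma)                           # gamma applied explicitly

folded = nn.Linear(d_in, d_out, bias=False)
with torch.no_grad():
    folded.weight.copy_(linear.weight * gamma)   # absorb gamma into each input column of W
y2 = folded(x)                                   # same result, gamma gone

print(torch.allclose(y1, y2, atol=1e-6))         # True: mathematically equivalent
```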

This is because many believe overparameterization plays an indispensable role in making deep learning models "easy to train," so removing a proven overparameterization carries significant risk. Here, "easy to train" mainly refers to the fact that gradient descent, which in theory should get stuck in local optima, routinely finds solutions with excellent practical performance, which is itself quite remarkable. Works such as "On the Optimization of Deep Networks: Implicit Acceleration by Overparameterization" suggest that overparameterization implicitly accelerates training, playing a role similar to momentum in SGD.

Finally, since VQ can essentially be understood as a sparse training scheme, the insights and modifications from SimVQ might also be applicable to other sparse training models, such as MoE (Mixture of Experts). In current MoE schemes, updates between experts are also quite independent; only the experts selected by the Router have their parameters updated. Is it possible that, like SimVQ, all experts could be followed by a shared linear transformation to improve expert utilization? Of course, MoE has many differences from VQ, so this remains just a hypothesis.

Summary

This article introduced another training trick for VQ (Vector Quantization)—SimVQ. By simply adding a linear transformation to the VQ codebook with no other changes, one can accelerate convergence, improve codebook utilization, and reduce reconstruction loss. It is remarkably simple and effective.