By 苏剑林 | November 09, 2023
Efficient Transformer refers generally to all work dedicated to reducing the quadratic complexity of the Transformer. Initially, this specifically targeted improvements to Attention, but later, more general ideas such as Fourier transforms and linear RNNs were also included in this category. It must be said that in order to reduce the quadratic complexity of the Transformer, various experts have truly "shown their special prowess like the Eight Immortals crossing the sea," and various magical ideas have "bloomed in abundance," from which I have learned a lot of theoretical knowledge. However, although Efficient Transformers are brilliant in theory, in practice, the field has remained somewhat lukewarm, with no exceptionally performing models. In today's LLM-dominated era, they have even gradually faded from public view and from my own personal interests.
However, a recent paper titled "Transformer-VQ: Linear-Time Transformers via Vector Quantization" has truly impressed me. The authors insightfully observed that by simply applying VQ (Vector Quantization) to the Keys of standard Attention, the complexity automatically drops to linear! This linearization approach preserves the form of standard Attention, serving as a perfect transition from standard Attention to Linear Attention, while retaining the capabilities of standard Attention to the greatest extent possible.
Speaking of which, this site was among the early followers of Efficient Transformer work, dating back to a 2019 blog post interpreting Sparse Transformer: "Born for Savings: From Standard Attention to Sparse Attention". Subsequently, other blog posts written about Efficient Transformers include:
"Exploring Linear Attention: Does Attention Must Have a Softmax?"
"Performer: Linearizing Attention Complexity with Random Projections"
"Nyströmformer: A Linearized Attention Scheme Based on Matrix Decomposition"
"The Road to Transformer Upgrade: 3. From Performer to Linear Attention"
"Linear Transformer Might Not Be the Model You Are Waiting For"
"FLASH: Possibly the Most Interesting Efficient Transformer Design Lately"
"Google's New Work Attempts to 'Resurrect' RNNs: Can RNNs Shine Again?"
However, as mentioned at the beginning of this article, although there has been a lot of work on Efficient Transformers and they were once highly anticipated, the field hasn't produced many "breakout" works. The reasons for this might be:
- Many Efficient Transformers sacrifice performance for speed;
- The complexity reduction in many Efficient Transformers is purely theoretical, with little noticeable improvement in actual use;
- Some Efficient Transformers are difficult to use for training Causal LMs, making them less useful in today's era of popular LLMs;
- The emergence of Flash Attention shows that even standard Transformers still have significant room for speedup.
So, why does Transformer-VQ have "breakout" potential?
In simple terms, Transformer-VQ "clusters" the Key vector sequence of Attention and approximates the original vectors with their assigned cluster centers, which then makes the Attention complexity linear. In other words, Transformer-VQ only changes the form of the Key, while the rest remains (theoretically) completely unchanged. Thus, this is a linearization scheme with minimal changes to Attention, and it clearly demonstrates exactly where the precision is lost after linearization (i.e., the difference between the original vector and its cluster center approximation).
Having laid the groundwork, let's formally introduce Transformer-VQ. First, let's assume $Q, K \in \mathbb{R}^{n \times d_k}$ and $V \in \mathbb{R}^{n \times d_v}$. Standard Attention is:
\begin{equation}softmax\left(QK^{\top}\right)V\end{equation}For simplicity, the scale factor is omitted here. Transformer-VQ changes this to:
\begin{equation}softmax\left(Q\hat{K}^{\top}\right)V,\quad \hat{K} = \color{skyblue}{\mathcal{VQ}}(K, C)\label{eq:vq-att}\end{equation}Where $C \in \mathbb{R}^{c \times d_k}$ is a trainable parameter, namely the VQ Codebook. By the way, "VQ" here refers to the VQ in VQ-VAE; readers unfamiliar with it can refer to "A Concise Introduction to VQ-VAE: Quantized Autoencoders" and "Embarrassingly Simple FSQ: 'Rounding' Surpasses VQ-VAE", and I won't repeat the details here. In short, the most direct consequence of $\color{skyblue}{\mathcal{VQ}}$ is that every vector in $K$ is replaced by the vector in $C$ most similar to it. This means every vector in $\hat{K}$ is one of the vectors in $C$; in mathematical terms, $K \in \mathbb{R}^{n \times d_k}$ becomes $\hat{K} \in C^n$.
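For concreteness, here is a minimal sketch (in JAX-flavored Python, my own illustration rather than the author's code) of the forward pass of $\mathcal{VQ}(K, C)$: each Key row is simply replaced by its nearest codeword. Codebook learning and the gradient treatment are discussed later.

```python
import numpy as np
import jax.numpy as jnp

def vq(K, C):
    """Nearest-codeword quantization: each row of K is replaced by the row of the
    codebook C closest to it (squared Euclidean distance, as in VQ-VAE)."""
    d2 = ((K[:, None, :] - C[None, :, :]) ** 2).sum(-1)   # (n, c) pairwise squared distances
    codes = d2.argmin(-1)                                 # (n,) index of the nearest codeword
    return C[codes], codes

rng = np.random.default_rng(0)
K = jnp.asarray(rng.normal(size=(10, 4)))   # n = 10 Keys of dimension d_k = 4
C = jnp.asarray(rng.normal(size=(6, 4)))    # codebook with c = 6 codewords
K_hat, codes = vq(K, C)
assert K_hat.shape == K.shape               # every row of K_hat is one of the rows of C
```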
Of course, if we were to implement Transformer-VQ directly according to equation \eqref{eq:vq-att}, the complexity would still be quadratic. However, since every vector in $\hat{K}$ is one of the vectors in $C$, we can first compute $\exp(QC^\top)$ and then "pick out" the entries corresponding to $\exp(Q\hat{K}^\top)$. Since the size of $C$ is fixed, the cost of the key operation $QC^\top$ is linear in $n$. This is the principle behind how Transformer-VQ achieves linearization (which we might call the "picking" trick).
As preparation, let's first consider the Encoder case with bidirectional attention. Since:
\begin{equation}softmax\left(QK^{\top}\right)V = \frac{\exp\left(QK^{\top}\right)V}{\exp\left(QK^{\top}\right)1_{n\times 1}}\label{eq:softmax-qkv}\end{equation}Here $1_{n\times 1}$ refers to an $n \times 1$ matrix of all ones. The denominator can be viewed as a special form of the numerator, so we only need to consider the numerator $\exp(QK^\top)V$. Since every vector in $\hat{K}$ is one of those in $C$, we can construct a one-hot matrix $\Delta \in \{0,1\}^{n \times c}$, where $\Delta_i \in \{0,1\}^c$ is a one-hot vector; if the dimension where the 1 is located is $j$, then $\hat{K}_i = C_j$, hence $\hat{K} = \Delta C$.
Thus, for Transformer-VQ, we have:
\begin{equation}\exp\left(Q\hat{K}{}^{\top}\right)V = \exp\left(QC^{\top}\Delta^{\top}\right)V = \exp\left(QC^{\top}\right)\Delta^{\top}V = \exp\left(QC^{\top}\right)(\Delta^{\top}V)\end{equation}Clearly, the most crucial part here is the second equals sign! For the one-hot matrix $\Delta$, right-multiplying by its transpose can be separated from the $\exp$ operation. This is the mathematical expression of the "picking" trick. Once separated, due to the associative property of matrix multiplication, $\Delta^\top$ can be multiplied by $V$ first, yielding a $c \times d_v$ matrix. Since $\exp(QC^\top)$ is an $n \times c$ matrix, multiplying it by $\Delta^\top V$ results in an $n \times d_v$ matrix. The total theoretical complexity is $\mathcal{O}(ncd_k + ncd_v + ncd_v) = \mathcal{O}(n)$.
Finally, according to equation \eqref{eq:softmax-qkv}, by substituting the result of $\exp(Q\hat{K}^\top)V$, the complete Attention result can be calculated (possibly needing some details to avoid overflow), and the entire process can be completed within linear complexity.
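The whole Encoder case is easy to verify numerically. The following toy sketch (my own, with arbitrary shapes, and without the overflow safeguards mentioned above) checks that the linearized computation, which never forms the $n \times n$ attention matrix, reproduces $softmax(Q\hat{K}^{\top})V$ exactly:

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(1)
n, c, d_k, d_v = 12, 5, 4, 3
Q, K, V, C = (jnp.asarray(rng.normal(size=s))
              for s in [(n, d_k), (n, d_k), (n, d_v), (c, d_k)])

codes = ((K[:, None] - C[None]) ** 2).sum(-1).argmin(-1)   # nearest codeword per Key
K_hat = C[codes]                                           # quantized Keys
Delta = jnp.eye(c)[codes]                                  # one-hot assignments, K_hat == Delta @ C

# Quadratic reference: softmax(Q K_hat^T) V via the full n x n attention matrix
A = jnp.exp(Q @ K_hat.T)
ref = (A @ V) / A.sum(-1, keepdims=True)

# Linearized version: the n x n matrix is never formed
num = jnp.exp(Q @ C.T) @ (Delta.T @ V)            # numerator   = exp(QC^T)(Delta^T V)
den = jnp.exp(Q @ C.T) @ Delta.sum(0)[:, None]    # denominator = exp(QC^T)(Delta^T 1)
assert jnp.allclose(num / den, ref, atol=1e-5)
```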
Now let's consider the Decoder with unidirectional attention, which is the key to training generative models and the foundation of current LLMs. With the Encoder case as preparation, the Decoder is not hard to understand. Let $Q_i, \hat{K}_j \in \mathbb{R}^{1 \times d_k}$ and $V_j \in \mathbb{R}^{1 \times d_v}$ denote row vectors of the sequences $Q, \hat{K}, V$. Then, for the Decoder's numerator, we have:
\begin{equation}\begin{aligned} O_i =&\, \sum_{j\leq i}\exp\left(Q_i\hat{K}{}_j^{\top}\right)V_j = \sum_{j\leq i}\exp\left(Q_i C^{\top}\Delta_j^{\top}\right)V_j \\ =&\, \sum_{j\leq i}\exp\left(Q_i C^{\top}\right)\Delta_j^{\top}V_j = \exp\left(Q_i C^{\top}\right)\sum_{j\leq i}\Delta_j^{\top}V_j \end{aligned}\end{equation}If $c \times d_v$ is not large, the final expression can be computed directly with the $\text{cumsum}$ operator. In the general case, however, especially with Multi-Head Attention, it is usually converted into an RNN and computed recursively to save memory, as in the "Autoregressive Generation" section of "Exploring Linear Attention: Must Attention Have a Softmax?". Specifically, let $U_i = \sum_{j \leq i} \Delta_j^\top V_j \in \mathbb{R}^{c \times d_v}$; then:
\begin{equation}O_i = \exp\left(Q_i C^{\top}\right)U_i,\quad U_i = U_{i-1} + \Delta_i^{\top}V_i \end{equation}During the inference phase, this step-by-step recursive calculation works fine. However, training step-by-step might be slow. We can change it to block-by-block for acceleration: without loss of generality, let $n=lm$, where $l$ represents the $block\_size$ and $m$ represents the number of blocks. The block slice $[il:(i+1)l]$ is abbreviated as $[i]$, then:
\begin{equation}\begin{aligned} O_{[i]} =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \sum_{j < i}\exp\left(Q_{[i]}\hat{K}{}_{[j]}^{\top}\right)V_{[j]} \\ =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \sum_{j < i}\exp\left(Q_{[i]}C^{\top}\Delta_{[j]}^{\top}\right)V_{[j]} \\ =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \exp\left(Q_{[i]}C^{\top}\right)\sum_{j < i}\Delta_{[j]}^{\top}V_{[j]} \\ \end{aligned}\end{equation}Where $M \in \{-\infty, 0\}^{l \times l}$ is the lower-triangular Attention Mask, i.e., $M_{i,j}=0$ when $i \geq j$, otherwise $M_{i,j}=-\infty$. Thus, writing $U_i = \sum_{j \leq i} \Delta_{[j]}^\top V_{[j]}$ as before (so that the last sum above is exactly $U_{i-1}$), we have:
\begin{equation}O_{[i]} = \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \exp\left(Q_{[i]}C^{\top}\right)U_{i-1},\quad U_i = U_{i-1} + \Delta_{[i]}^{\top}V_{[i]} \end{equation}Thus we have reduced the recursion steps to $m$, allowing us to better utilize hardware parallelism while maintaining linear efficiency. The same method can be used to calculate the denominator, and finally, division yields the full Attention result.
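Putting the pieces together, here is a toy sketch of the block-wise recursion above (again my own reconstruction, not the author's implementation). The denominator reuses the same state as the numerator by appending a column of ones to $V$, and the result is checked against the quadratic computation with the quantized Keys:

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(2)
n, l, c, d_k, d_v = 16, 4, 8, 5, 6        # toy sizes, n = l * m
m = n // l
Q, K, V, C = (jnp.asarray(rng.normal(size=s))
              for s in [(n, d_k), (n, d_k), (n, d_v), (c, d_k)])

codes = ((K[:, None] - C[None]) ** 2).sum(-1).argmin(-1)
K_hat, Delta = C[codes], jnp.eye(c)[codes]

# Reference: full quadratic causal attention with the quantized Keys
A = jnp.where(jnp.tril(jnp.ones((n, n), bool)), jnp.exp(Q @ K_hat.T), 0.0)
ref = (A @ V) / A.sum(-1, keepdims=True)

# Block-wise linear-time recursion; numerator and denominator share one state
V1 = jnp.concatenate([V, jnp.ones((n, 1))], axis=-1)             # append a ones column to V
M = jnp.where(jnp.tril(jnp.ones((l, l), bool)), 0.0, -jnp.inf)   # intra-block causal mask
U = jnp.zeros((c, d_v + 1))               # running sum of Delta^T V; equals U_{i-1} at block i
outs = []
for i in range(m):
    s = slice(i * l, (i + 1) * l)
    blk = (jnp.exp(Q[s] @ K_hat[s].T + M) @ V1[s]     # current block, masked
           + jnp.exp(Q[s] @ C.T) @ U)                 # all earlier blocks enter only via the codebook
    outs.append(blk[:, :d_v] / blk[:, -1:])
    U = U + Delta[s].T @ V1[s]

assert jnp.allclose(jnp.concatenate(outs), ref, atol=1e-5)
```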
Is that all? Not quite. If it were just this, Transformer-VQ might not be much different from previous kernelized attention models based on matrix decomposition, such as the Performer. When the sequence length $n$ is much larger than the codebook size $c$, we know from the pigeonhole principle that some code vectors will inevitably recur. We can even reasonably guess that all code vectors should be uniformly distributed across the entire sequence. Consequently, Attention for nearby tokens would be the same as Attention for certain distant tokens. In other words, the model cannot distinguish between near and far, which is essentially the low-rank problem inherent in all Kernelized Attention models.
Existing experience tells us that for language models, nearby tokens are usually more important than distant ones, so a good language-model architecture should be able to distinguish near from far. To this end, Transformer-VQ adds a Sliding-Window-shaped Attention Bias (denoted $B$) to $Q\hat{K}^{\top}$ to upweight nearby tokens, as shown below:

From the last diagram, it can be seen that if the window size is set to exactly the block size $l$, i.e., $B_{i,j}=0$ whenever $i < j$ or $i - j \geq l$, then in block-wise computation $B$ affects at most the two nearest blocks; all earlier blocks can still be linearized with the "picking" trick. For convenience in the derivation below, let $B_{[i,j]} = B_{[il:(i+1)l, jl:(j+1)l]}$; then:
\begin{equation}\begin{aligned} O_{[i]} =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \sum_{j < i-1}\exp\left(Q_{[i]}\hat{K}{}_{[j]}^{\top}\right)V_{[j]} \\ =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \sum_{j < i-1}\exp\left(Q_{[i]}C^{\top}\Delta_{[j]}^{\top}\right)V_{[j]} \\ =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \exp\left(Q_{[i]}C^{\top}\right)\sum_{j < i-1}\Delta_{[j]}^{\top}V_{[j]} \\ \end{aligned}\end{equation}So obviously, we have (assuming $V_{[-1]}, U_{[-1]}, U_{[-2]}$ are all matrices of zeros):
\begin{equation}\begin{aligned} O_{[i]} =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \exp\left(Q_{[i]}C^{\top}\right)U_{i-2}\\[5pt] U_i =&\, U_{i-1} + \Delta_{[i]}^{\top}V_{[i]} \end{aligned}\label{eq:tvq}\end{equation}In my opinion, the introduction of $B$ is the key reason Transformer-VQ pulls away from other Kernelized Attention models. To reduce the number of parameters and support variable-length generation, we constrain the non-zero part of $B$ to be a "Toeplitz matrix," i.e., $B_{i,j}$ is a function of $i-j$. In this case, $B$ is equivalent to additive relative position encoding. Instead of this approach, one could also consider replacing it with ReRoPE, which I previously proposed; it is a windowed version of RoPE and shares the same relative position encoding shape as $B$.
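The structure of $B$ is easy to see in code. The sketch below (in the paper the bias values are learned parameters; here they are replaced by an arbitrary decreasing vector purely for illustration) builds the banded Toeplitz $B$ for window size $= l$ and confirms that, block-wise, only $B_{[i,i]}$ and $B_{[i,i-1]}$ are non-zero:

```python
import jax.numpy as jnp

l = 4                                            # block size == window size
n = 3 * l
b = jnp.arange(l, 0, -1).astype(jnp.float32)     # stand-in bias values b[t] for distance t = 0 .. l-1
                                                 # (learned parameters in the paper; arbitrary here)

# Toeplitz, banded bias: B[i, j] = b[i - j] if 0 <= i - j < l, else 0
dist = jnp.arange(n)[:, None] - jnp.arange(n)[None, :]
B = jnp.where((dist >= 0) & (dist < l), b[jnp.clip(dist, 0, l - 1)], 0.0)

# With window size equal to block size, only two block pairs of B are ever non-zero
B_ii  = B[l:2 * l, l:2 * l]    # B_{[i,i]}:   lower-triangular part, distances 0 .. l-1
B_ii1 = B[l:2 * l, 0:l]        # B_{[i,i-1]}: strictly upper-triangular part, distances 1 .. l-1
assert jnp.count_nonzero(B[l:2 * l, 2 * l:]) == 0   # nothing leaks to future blocks
assert jnp.count_nonzero(B[2 * l:, 0:l]) == 0       # blocks more than one apart receive no bias
```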
Wait, did we forget something? Readers familiar with VQ-VAE know that "every vector in $\hat{K}$ is one of the vectors in $C$" describes only the forward pass. The backward pass uses the original $K$, which means that even if $\hat{K}_j$ at different positions equal the same $C_k$, their gradients are not necessarily equal; this is the STE (Straight-Through Estimator). Because of STE, the "picking" trick strictly applies only to the inference phase; the training phase cannot be linearized exactly.
Is there no other way? Indeed, if we insist on exact gradients, there is no way to keep linear efficiency. However, since the gradient of VQ is itself an approximation, obtaining exact gradients for Attention does not seem all that critical. So the author devised a compromise: still compute recursively following equation \eqref{eq:tvq}, but apply STE only in the first two terms (so the Key sequence receives gradients), while the gradient of $U_{i-2}$ is simply stopped (using the $\text{stop\_gradient}$ operator). In this way, the model stays linear while the most important gradients (those of the two most recent blocks) are preserved, which is a reasonable approximation. In this respect, Transformer-VQ is very similar to Transformer-XL, which also lets the historical window participate in the recursion without passing gradients back through it.
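A minimal sketch of this compromise (my own toy version with made-up shapes; `U_hist` merely stands in for $U_{i-2}$): STE lets the gradient flow back to the recent Keys as if VQ were the identity, while the older history state is wrapped in $\text{stop\_gradient}$:

```python
import numpy as np
import jax, jax.numpy as jnp

def quantize(K, C):
    codes = ((K[:, None] - C[None]) ** 2).sum(-1).argmin(-1)
    return C[codes]

def vq_ste(K, C):
    # Straight-through estimator: the forward value is the quantized Key,
    # but the gradient w.r.t. K is passed through as if VQ were the identity.
    return K + jax.lax.stop_gradient(quantize(K, C) - K)

rng = np.random.default_rng(0)
q = jnp.asarray(rng.normal(size=(4,)))          # one query row
K = jnp.asarray(rng.normal(size=(6, 4)))        # Keys of the two most recent blocks
V = jnp.asarray(rng.normal(size=(6, 2)))
C = jnp.asarray(rng.normal(size=(3, 4)))        # codebook
U_hist = jnp.asarray(rng.normal(size=(3, 2)))   # stands in for U_{i-2}, the older history state

def output(K):
    K_hat = vq_ste(K, C)                        # recent Keys: STE, so K gets gradients
    hist = jax.lax.stop_gradient(U_hist)        # older history: gradient simply stopped
    return (jnp.exp(q @ K_hat.T) @ V + jnp.exp(q @ C.T) @ hist).sum()

g = jax.grad(output)(K)
assert g.shape == K.shape and bool(jnp.any(g != 0))   # the Key sequence does receive gradients
```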
After resolving the gradient backpropagation issue, the auxiliary loss brought by VQ is added to the autoregressive cross-entropy to form the complete training objective. For the codebook itself, Transformer-VQ updates it directly with an exponential moving average (EMA) scheme, so only the commitment-style auxiliary loss on the Key needs to be added. Once you are familiar with VQ-VAE, these details are easy to follow in the original paper.
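For reference, a generic VQ-VAE-style sketch of what this looks like (the exact decay rate, statistics initialization, and loss weight used in the paper may differ; $\gamma$ and $\beta$ here are made-up values):

```python
import numpy as np
import jax, jax.numpy as jnp

rng = np.random.default_rng(3)
n, c, d_k = 32, 8, 4
K = jnp.asarray(rng.normal(size=(n, d_k)))      # Keys of the current batch
C = jnp.asarray(rng.normal(size=(c, d_k)))      # current codebook

codes = ((K[:, None] - C[None]) ** 2).sum(-1).argmin(-1)
Delta = jnp.eye(c)[codes]
K_hat = Delta @ C

# EMA codebook update (sketch): keep smoothed per-codeword usage counts and Key sums;
# the new codewords are their ratio. The running statistics would normally persist
# across training steps; here they are trivially initialized for the sketch.
gamma = 0.99                                    # hypothetical decay rate
ema_count = gamma * jnp.ones(c) + (1 - gamma) * Delta.sum(0)
ema_sum   = gamma * C           + (1 - gamma) * Delta.T @ K
C_new = ema_sum / jnp.maximum(ema_count, 1e-6)[:, None]

# Commitment-style auxiliary loss pulling each Key towards its (frozen) codeword;
# this is the only VQ loss added to the cross-entropy, since the codebook itself
# is maintained by the EMA above. beta is a hypothetical weight.
beta = 0.25
aux_loss = beta * ((K - jax.lax.stop_gradient(K_hat)) ** 2).mean()
```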
In this section, let's look at the experimental results from the original paper. The author has open-sourced the code at:
It is worth noting that the base architecture the author used for VQ is not the conventional MHA (Multi-Head Attention), but the GAU (Gated Attention Unit) + Softmax, which I have always strongly advocated for. Transformer-VQ would be more accurately named "GAU-VQ." Readers unfamiliar with GAU can refer to "FLASH: Possibly the Most Interesting Efficient Transformer Design Lately" and "It Seems Attention and Softmax Pair Better Together." In brief, GAU itself is more efficient than MHA, and with the VQ technique, it becomes even "more powerful."
In terms of experiments, the author performed language modeling (enwik8, PG-19) and image generation (ImageNet64). In all experiments, the codebook size was $c=512$. The maximum parameter count for the model was 1.3B, which, while not as large as mainstream big models, is not small for research purposes. The experimental results are generally excellent:


Finally, it is surprising that there is only one author for Transformer-VQ, and their identity is "Independent Researcher."
I have found that starting from Transformer-VQ, one can connect to many research topics, which is one of the reasons I appreciate it so much.
First, once again, kudos to the author's remarkable insight. The discovery that "simply VQing the Key makes Transformer complexity linear" is truly wonderful: it achieves a natural transition from standard Attention to Linear Attention and, with the added Attention Bias, can be made more effective than many Kernelized Attention schemes. Furthermore, "clustering" via VQ is also more refined than the approaches of Linformer, Nyströmformer, etc., because it does not leak future information and can therefore be used naturally for Causal language models.
We know that VQ is essentially an operation that converts a sequence into discrete IDs, which is very similar to what a Tokenizer does. From this perspective, Transformer-VQ, like models such as MegaByte, builds the Tokenizer directly into the model; compared to MegaByte, the VQ operation is much closer to our traditional notion of a Tokenizer and more intuitive. Thus, Transformer-VQ is actually very suitable for training "No Tokenizer" models that take raw Byte input. In fact, the enwik8 experiment mentioned above used Byte input, and Transformer-VQ's performance was significantly better than MegaByte's.
Compared to the recently released RetNet, Transformer-VQ has no explicit long-range decay, so its Long Context capability might be better. At the same time, because the Keys are VQ'd and come from a finite set, there are no "unseen" Keys, so its length extrapolation ability is also likely to be better. Although Transformer-VQ's base architecture, GAU, is Single-Head, its recurrent memory state $U_i \in \mathbb{R}^{c \times d_v}$ is, under the default settings, larger than that of the Multi-Head RetNet (RetNet's memory state size is $d_k^2$, and by default $d_v = 2d_k$), so its memory capacity is theoretically sufficient.
Since the previous article happened to be "Embarrassingly Simple FSQ: 'Rounding' Surpasses VQ-VAE," some readers may wonder whether the simpler FSQ could replace VQ here. I think it would be difficult, for reasons already given in that article: first, $c=512$ is still within the range of codebook sizes where VQ beats FSQ, so switching to FSQ would likely hurt performance; second, since the Key of the Attention in every layer must be VQ'd, the "Encoder" and "Decoder" surrounding each quantization are, on average, not very strong, which is precisely the situation where VQ's approximation is more accurate; third, Transformer-VQ needs the quantized center vector of the Key rather than just an ID, whereas FSQ yields an ID directly, making it harder to recover an approximate center vector.
Besides this, using VQ instead of FSQ gives Transformer-VQ the hope of being fine-tuned from existing pre-trained models like Llama 2, rather than just being trained from scratch. Because VQ has clear geometric meaning and many similarities with K-Means, we can start with an existing pre-trained model, calculate the Keys for some samples, perform K-Means on the Keys to get center vectors as the initialization for the codebook, and then fine-tune the original model with the added VQ. However, Transformer-VQ is not well-suited for RoPE, so as mentioned before, it would be better to switch RoPE models to ReRoPE before VQ, in which case the Bias may not be necessary.
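Purely to illustrate the idea (this recipe is my own speculation, not something done in the paper; all names and shapes are invented), the initialization could look like:

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical recipe: collect Key vectors by running an existing pre-trained model on
# some samples, run a few K-Means (Lloyd) steps on them, and use the resulting centers
# to initialize the codebook before fine-tuning with VQ inserted.
rng = np.random.default_rng(4)
keys = jnp.asarray(rng.normal(size=(1024, 32)))   # stand-in for the collected Keys
c = 64                                            # toy codebook size (the paper uses c = 512)

C = keys[rng.choice(keys.shape[0], size=c, replace=False)]   # initialize centers from the Keys
for _ in range(10):                                          # a few Lloyd iterations
    codes = ((keys[:, None] - C[None]) ** 2).sum(-1).argmin(-1)
    Delta = jnp.eye(c)[codes]
    counts = Delta.sum(0)[:, None]
    # move each center to the mean of its cluster; empty clusters keep their old center
    C = jnp.where(counts > 0, (Delta.T @ keys) / jnp.maximum(counts, 1), C)
# C is now a K-Means-style initialization for the VQ codebook
```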
In short, in my eyes, Transformer-VQ is among the most unique, outstanding, and deeply potential works among numerous Efficient Transformer efforts.
This article introduced an Efficient Transformer scheme named Transformer-VQ. It expands on the observation that "simply VQing the Key makes Transformer complexity linear." I personally find it a very unique and striking linearization strategy, with excellent experimental results. It can be understood both as a more sophisticated linear Attention/RNN model and as an Attention model with a "trainable Tokenizer."