A few days ago, in "VQ the Key, and Transformer's Complexity Becomes Linear," we introduced "Transformer-VQ," a scheme that achieves linear attention complexity by applying a Vector Quantization (VQ) transformation to the Key sequence. Admittedly, Transformer-VQ provides a beautiful transition from standard Attention to linear Attention, embodying a sense of "simplicity as the ultimate sophistication." However, readers familiar with VQ might feel that as the codebook size or model parameters increase, VQ could become a performance bottleneck. This is because the gradients estimated via the Straight-Through Estimator (STE) are likely suboptimal (experimental results from FSQ also provide support for this). Furthermore, the gradient truncation used by Transformer-VQ to make training efficiency linear might also become a future performance bottleneck.
To this end, I spent some time thinking about linearization ideas that could replace VQ. The $\exp(QC^\top)$ form in Transformer-VQ reminded me of Performer, and following that trail I found that Performer can be regarded as a "soft" version of Transformer-VQ. Building on this, I attempted to re-derive Transformer-VQ using Performer's derivation approach, which may offer some reference points for subsequent optimization.
Flashback
First, let's take a moment to review Transformer-VQ. Let $Q, K \in \mathbb{R}^{n \times d_k}$ and $V \in \mathbb{R}^{n \times d_v}$. The key to Transformer-VQ is the following VQ approximation of $K$:
$$K \approx \hat{K} \triangleq \Delta C$$
Here $\Delta \in \{0,1\}^{n \times c}$ and $C \in \mathbb{R}^{c \times d_k}$ are matrices, where $C$ is a trainable parameter and $\Delta$ is defined as:
$$\Delta_{i,j} = \begin{cases} 1, & j = \text{argmin}_{k=1,2,\dots,c} \|K_i - C_k\| \\ 0, & \text{otherwise} \end{cases}$$
Simply put, VQ approximates $K_i$ using the $C_j$ that is most similar to it. Under this approximation, we have (using the Encoder as a simple example):
$$\exp(Q\hat{K}^\top)V = \exp(QC^\top\Delta^\top)V = \left[\exp(QC^\top)\Delta^\top\right]V = \exp(QC^\top)(\Delta^\top V)$$
The second equality holds because each row of $\Delta$ is one-hot, so right-multiplying by $\Delta^\top$ merely selects columns of $QC^\top$, and this selection commutes with the element-wise $\exp$. Readers familiar with linear Attention will easily recognize that the last expression has linear complexity; this is one of the main components of Transformer-VQ (specifically the numerator; the denominator is handled in the same way).
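To make this concrete, here is a minimal NumPy sketch of the (non-causal) numerator; the dimensions, random data, and codebook are all made up for illustration, and the real Transformer-VQ of course trains $C$ and handles normalization and causal masking on top of this:

```python
import numpy as np

rng = np.random.default_rng(0)
n, c, d_k, d_v = 8, 4, 16, 16           # sequence length, codebook size, head dims

Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))
V = rng.standard_normal((n, d_v))
C = rng.standard_normal((c, d_k))       # codebook (a trainable parameter in the real model)

# VQ assignment: Delta[i, j] = 1 iff C[j] is the nearest codeword to K[i]
idx = np.argmin(np.linalg.norm(K[:, None, :] - C[None, :, :], axis=-1), axis=-1)
Delta = np.eye(c)[idx]                  # (n, c) matrix with one-hot rows
K_hat = Delta @ C                       # quantized keys

# Quadratic form vs. linear form of the (unnormalized) attention numerator
quadratic = np.exp(Q @ K_hat.T) @ V               # builds an n x n matrix
linear    = np.exp(Q @ C.T) @ (Delta.T @ V)       # only n x c and c x d_v matrices
print(np.allclose(quadratic, linear))             # True
```

The point is that the $n\times n$ attention matrix never has to be formed: $\Delta^\top V$ is accumulated once and reused for every query.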
Without any elaborate derivation, linear Attention simply falls out: it feels as though we merely set out to approximate the Key and reduced Attention to linear complexity "unintentionally," which is quite elegant. Hence the assessment repeated several times already: Transformer-VQ provides a beautiful transition from standard Attention to linear Attention.
Déjà Vu
The $\exp(QC^\top)$ term in Transformer-VQ reminded me of a previous article, "The Road to Transformer Upgrading: 3. From Performer to Linear Attention." In that post, I simplified Performer's results and argued that $\exp$ is the best choice of activation function for $Q$ and $K$ in linear Attention. Since Transformer-VQ also uses $\exp$, there might be some connection between the two.
To excavate this connection, let's look at Performer, which is based on an elegant approximation:
\begin{equation}
e^{\boldsymbol{q}\cdot \boldsymbol{k}}=\mathbb{E}_{\boldsymbol{\omega}\sim \mathcal{N}(\boldsymbol{\omega};0,\boldsymbol{1}_d)}\left[e^{\boldsymbol{\omega}\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \,e^{\boldsymbol{\omega}\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\right]\approx\underbrace{\frac{1}{\sqrt{m}}\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{q}}}
\cdot \underbrace{\frac{1}{\sqrt{m}}\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}}
\label{eq:performer}\end{equation}
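As a quick sanity check on this identity, here is a small Monte Carlo sketch; the head dimension, the scaling of $\boldsymbol{q}, \boldsymbol{k}$, and the number of samples are arbitrary choices (with larger norms the estimator's variance grows quickly):

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 16, 100_000                      # head dimension, number of random features

q = 0.3 * rng.standard_normal(d)        # modest norms keep the estimator's variance low
k = 0.3 * rng.standard_normal(d)

omega = rng.standard_normal((m, d))     # omega_i ~ N(0, I_d)
q_feat = np.exp(omega @ q - q @ q / 2) / np.sqrt(m)
k_feat = np.exp(omega @ k - k @ k / 2) / np.sqrt(m)

print(np.exp(q @ k), q_feat @ k_feat)   # exact value vs. its random-feature estimate
```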
Since the final attention is normalized over all $\boldsymbol{k}$, removing the factor $1/\sqrt{m}$ and the term $-\Vert \boldsymbol{q}\Vert^2/2$ from Equation \eqref{eq:performer} does not affect the final result. Likewise, if we assume the norms of $\boldsymbol{\omega}_1, \boldsymbol{\omega}_2, \dots, \boldsymbol{\omega}_m$ are all equal (cf. the JL Lemma), then subtracting $\Vert \boldsymbol{\omega}_i \Vert^2/2$ in the exponent of the $i$-th component of $\tilde{\boldsymbol{k}}$ does not affect the outcome either. Thus, Performer is equivalent to constructing $\tilde{\boldsymbol{q}}, \tilde{\boldsymbol{k}}$ in the following form:
\begin{equation}\underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix}}_{\tilde{\boldsymbol{q}}}
\cdot \underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2-\Vert \boldsymbol{\omega}_1\Vert^2 / 2} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2-\Vert \boldsymbol{\omega}_2\Vert^2 / 2}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2-\Vert \boldsymbol{\omega}_m\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}} = \underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix}}_{\tilde{\boldsymbol{q}}}
\cdot \underbrace{\begin{pmatrix}e^{-\Vert \boldsymbol{k}-\boldsymbol{\omega}_1\Vert^2 / 2} \\
e^{-\Vert \boldsymbol{k} - \boldsymbol{\omega}_2\Vert^2 / 2}\\
\vdots\\
e^{-\Vert \boldsymbol{k} - \boldsymbol{\omega}_m\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}} \propto \underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix}}_{\tilde{\boldsymbol{q}}}
\cdot \underbrace{\text{softmax}\begin{pmatrix}-\Vert \boldsymbol{k}-\boldsymbol{\omega}_1\Vert^2 / 2 \\
-\Vert \boldsymbol{k} - \boldsymbol{\omega}_2\Vert^2 / 2\\
\vdots\\
-\Vert \boldsymbol{k} - \boldsymbol{\omega}_m\Vert^2 / 2 \end{pmatrix}}_{\tilde{\boldsymbol{k}}} \end{equation}
Comparing the final expression with the definition of Transformer-VQ, one finds many similarities: don't $\boldsymbol{\omega}_1, \boldsymbol{\omega}_2, \dots, \boldsymbol{\omega}_m$ correspond to the codebook $C$? Doesn't $\tilde{\boldsymbol{q}}$ correspond to $\exp(QC^\top)$? As for the final $\tilde{\boldsymbol{k}}$, it performs a softmax using $-\Vert \boldsymbol{k} - \boldsymbol{\omega}_i\Vert^2 / 2$ as logits; what it highlights is precisely the $\boldsymbol{\omega}_i$ most similar to $\boldsymbol{k}$. Since the limit of softmax is one-hot encoding, doesn't this correspond exactly to Transformer-VQ's $\Delta$ matrix? Therefore, while they aren't identical, they are at least "sixty to seventy percent" similar.
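Here is a tiny numerical illustration of this "soft versus hard" reading; the inverse temperature `beta` is my own device for sharpening the softmax and is not part of either model:

```python
import numpy as np

rng = np.random.default_rng(0)
c, d = 4, 16
k = rng.standard_normal(d)
omegas = rng.standard_normal((c, d))    # stand-ins for the codebook / random features

logits = -0.5 * np.linalg.norm(k - omegas, axis=-1) ** 2

def softmax(x):
    x = x - x.max()
    return np.exp(x) / np.exp(x).sum()

# Sharpening the soft assignment recovers Transformer-VQ's hard (one-hot) assignment
for beta in (1.0, 10.0, 100.0):
    print(beta, np.round(softmax(beta * logits), 3))
print("nearest omega (the index Delta would select):", np.argmax(logits))
```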
Imitating the Pattern
Of course, the above is more a suggestive analogy than an equivalence, because Performer is ultimately built on a rather different approximation logic: for instance, its $\boldsymbol{\omega}_1, \boldsymbol{\omega}_2, \dots, \boldsymbol{\omega}_m$ are randomly sampled and then fixed, so as "cluster centers" they are actually rather poor. Nevertheless, the similarity triggered a thought: can we mimic Performer's logic to re-derive Transformer-VQ? That is, as in Equation \eqref{eq:performer}, first construct an exact identity and then turn it into a sampling-style approximation to obtain the linear version.
After a few days of thought, I found a scheme that yields the desired derivation. First, use the Dirac delta function to write:
\begin{equation}e^{\boldsymbol{q}\cdot \boldsymbol{k}} = \int e^{\boldsymbol{q}\cdot \boldsymbol{\omega}}\delta(\boldsymbol{\omega} - \boldsymbol{k})d\boldsymbol{\omega}\end{equation}
This is a pure identity given by the definition of the Dirac function, involving no sophisticated operations or approximations. However, when we substitute this into Attention (the numerator), some interesting results emerge:
\begin{equation}\sum_j e^{\boldsymbol{q}\cdot \boldsymbol{k}_j} \boldsymbol{v}_j = \sum_j \boldsymbol{v}_j\int e^{\boldsymbol{q}\cdot \boldsymbol{\omega}}\delta(\boldsymbol{\omega} - \boldsymbol{k}_j)d\boldsymbol{\omega} = \int e^{\boldsymbol{q}\cdot \boldsymbol{\omega}} \left[\sum_j \delta(\boldsymbol{\omega} - \boldsymbol{k}_j) \boldsymbol{v}_j\right]d\boldsymbol{\omega}\label{eq:inf-vq}\end{equation}
The expression after the last equal sign is exactly the form of linear Attention! Of course, since it requires an integral over $\boldsymbol{\omega}$, this result, like the one in "The Road to Transformer Upgrading: 5. Linear Attention as Infinite Dimensional," is an "infinite-dimensional" linear Attention and for now has only formal value.
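Purely to convince ourselves that identity \eqref{eq:inf-vq} really does hold, here is a one-dimensional numerical check in which each $\delta(\omega - k_j)$ is replaced by a narrow Gaussian and the integral is evaluated on a grid; the scalar keys, values, and widths are arbitrary choices:

```python
import numpy as np

q = 0.7
k = np.array([-0.5, 0.2, 1.1])           # three scalar "keys"
v = np.array([1.0, -2.0, 0.5])           # three scalar "values"

exact = np.sum(np.exp(q * k) * v)        # sum_j exp(q k_j) v_j

# Replace each delta(omega - k_j) by a narrow Gaussian and integrate on a grid
sigma = 1e-3
omega = np.linspace(-5.0, 5.0, 200_001)
bumps = np.exp(-0.5 * ((omega[None, :] - k[:, None]) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
signal = (bumps * v[:, None]).sum(axis=0)                # sum_j delta(omega - k_j) v_j
numeric = np.sum(np.exp(q * omega) * signal) * (omega[1] - omega[0])

print(exact, numeric)                    # the two values should agree closely
```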
Typically, we understand $\delta(\boldsymbol{\omega} - \boldsymbol{k}_j)$ as the limit of a normal distribution $\mathcal{N}(\boldsymbol{\omega}; \boldsymbol{k}_j, \sigma^2 \boldsymbol{I})$ as $\sigma \to 0$, which means $\delta(\boldsymbol{\omega} - \boldsymbol{k}_j)$ can be read as a conditional distribution $p(\boldsymbol{\omega}|\boldsymbol{k}_j)$. From the perspective of generative models, however, the Dirac delta is a point mass: it essentially "memorizes" the training set and therefore lacks abstraction and generalization ability. To alleviate this, we approximate $p(\boldsymbol{\omega}|\boldsymbol{k}_j)$ with a GMM (Gaussian Mixture Model):
\begin{equation}p(\boldsymbol{\omega}|\boldsymbol{k}_j) \approx \sum_{y=1}^m \mathcal{N}(\boldsymbol{\omega}; \boldsymbol{c}_y, \sigma^2 \boldsymbol{I}) \, p(y|\boldsymbol{k}_j) \end{equation}
Substituting this into Equation \eqref{eq:inf-vq} and taking the limit $\sigma \to 0$, we get:
\begin{equation}\sum_j e^{\boldsymbol{q}\cdot \boldsymbol{k}_j} \boldsymbol{v}_j \approx \sum_{y=1}^m e^{\boldsymbol{q}\cdot \boldsymbol{c}_y} \left[\sum_j p(y|\boldsymbol{k}_j) \boldsymbol{v}_j\right]\end{equation}
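The $\sigma \to 0$ limit here uses nothing more than the Gaussian moment generating function: for each mixture component,
\begin{equation}\int e^{\boldsymbol{q}\cdot \boldsymbol{\omega}}\,\mathcal{N}(\boldsymbol{\omega}; \boldsymbol{c}_y, \sigma^2 \boldsymbol{I})\,d\boldsymbol{\omega} = e^{\boldsymbol{q}\cdot \boldsymbol{c}_y + \sigma^2\Vert \boldsymbol{q}\Vert^2/2}\ \xrightarrow{\ \sigma\to 0\ }\ e^{\boldsymbol{q}\cdot \boldsymbol{c}_y}\end{equation}
so in the limit each Gaussian component simply collapses onto its center $\boldsymbol{c}_y$.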
This yields a finite-dimensional linear Attention. If we take $p(y|\boldsymbol{k}_j)$ to be the one-hot distribution given by Transformer-VQ's $\Delta$, i.e., all probability mass on the codeword nearest to $\boldsymbol{k}_j$, then the formula above is exactly Transformer-VQ.
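Below is a minimal NumPy sketch of this finite-dimensional form, with the soft assignment $p(y|\boldsymbol{k}_j)$ given by GMM responsibilities. The toy codebook, the value of `sigma`, and the equal mixture weights are my own choices for illustration, and the keys are deliberately generated close to the codebook so that the approximation has a chance of being good:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d_k, d_v = 8, 4, 16, 16

C = 0.3 * rng.standard_normal((m, d_k))                 # toy codebook
assign = rng.integers(0, m, size=n)
K = C[assign] + 0.01 * rng.standard_normal((n, d_k))    # keys clustered around codewords
Q = 0.3 * rng.standard_normal((n, d_k))
V = rng.standard_normal((n, d_v))

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Soft assignment p(y|k_j): GMM responsibilities with equal mixture weights
sigma = 0.1
P = softmax(-0.5 * np.sum((K[:, None, :] - C[None, :, :]) ** 2, axis=-1) / sigma**2)

exact  = np.exp(Q @ K.T) @ V              # standard attention numerator
approx = np.exp(Q @ C.T) @ (P.T @ V)      # finite-dimensional linear form
print(np.abs(exact - approx).max() / np.abs(exact).max())   # small relative error here
```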
Summary
In this article, I shared a discovery: the early linear Attention work "Performer" can be viewed as a "soft" version of Transformer-VQ. Building on this observation, I gave a new derivation of Transformer-VQ: use the Dirac delta function to rewrite standard Attention as an infinite-dimensional linear Attention, and then apply a GMM approximation to arrive at Transformer-VQ.