The Road to Transformer Upgrade: 3. From Performer to Linear Attention

By 苏剑林 | April 22, 2021

Readers of my previous articles "Exploration of Linear Attention: Does Attention Need a Softmax?" and "Performer: Linearizing Attention Complexity with Random Projections" might find the title of this post a bit unnatural. Since Linear Attention came before the Performer, the relationship is typically framed as "the Performer is an implementation of Linear Attention that approximates standard Attention while maintaining linear complexity." Thus, normally, it should be "From Linear Attention to Performer." However, this post does not intend to trace the historical development of Linear Attention, but rather to reflect on the insights Performer brings to Linear Attention. Hence, "From Performer to Linear Attention."

Activation Functions

The common form of Linear Attention is:

\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)} = \frac{\sum\limits_{j=1}^n \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)}\end{equation}

where $\phi(\cdot)$ and $\varphi(\cdot)$ are non-negative activation functions. How should we choose this activation function? Performer tells us that we should choose the exponential function:

\begin{equation}\phi(x)=\varphi(x)=e^x\end{equation}

First, let's look at how this differs from existing results. In "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention", the given choice is:

\begin{equation}\phi(x)=\varphi(x)=1 + \text{elu}(x) = \begin{cases}1 + x, & x \geq 0\\ e^x, & x < 0\end{cases}\end{equation}

We know that $1+x$ is exactly the first-order Taylor expansion of $e^x$ at $x=0$, so the choice of $1+\text{elu}(x)$ is actually quite close to $e^x$. Furthermore, the scheme $\phi(x)=\varphi(x)=e^x$ is very similar to the dual-softmax design for building Linear Attention introduced in "Efficient Attention: Attention with Linear Complexities". In that design, $\phi(\boldsymbol{q})=softmax(\boldsymbol{q})$ and $\varphi(\boldsymbol{k})=e^{\boldsymbol{k}}$. Compared to directly using $\phi(x)=\varphi(x)=e^x$, the difference is merely the position of normalization.
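To make the computational pattern concrete, here is a toy NumPy sketch of non-causal Linear Attention with a pluggable feature map (shapes and random inputs are arbitrary and purely illustrative). The point is that the $n\times n$ matrix is never formed: $\varphi(\boldsymbol{K})$ is contracted with $\boldsymbol{V}$ first, which is where the linear complexity comes from.

```python
import numpy as np

def elu_plus_one(x):
    # 1 + elu(x): equals 1 + x for x >= 0 and e^x for x < 0
    return np.where(x >= 0, 1.0 + x, np.exp(x))

def linear_attention(Q, K, V, feature_map=np.exp):
    """Non-causal Linear Attention in O(n * m * d_v), never forming the n x n matrix.

    Q, K: (n, m) queries and keys, V: (n, d_v) values.
    """
    Qf, Kf = feature_map(Q), feature_map(K)   # non-negative features phi(Q), varphi(K)
    KV = Kf.T @ V                             # (m, d_v): sum_j varphi(k_j) v_j^T
    Z = Qf @ Kf.sum(axis=0)                   # (n,):     sum_j phi(q_i) . varphi(k_j)
    return (Qf @ KV) / Z[:, None]

# Toy check against the explicit quadratic form.
rng = np.random.default_rng(0)
n, d = 8, 4
Q, K, V = rng.normal(size=(3, n, d))
A = np.exp(Q) @ np.exp(K).T                   # explicit n x n similarity matrix
ref = (A @ V) / A.sum(axis=1, keepdims=True)
assert np.allclose(linear_attention(Q, K, V), ref)
print(linear_attention(Q, K, V, feature_map=elu_plus_one).shape)   # (8, 4): the 1 + elu variant
```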

Simple Derivation

Why do we say Performer tells us the best activation function is $e^x$? Let's look at the mapping found by Performer to linearize standard Attention:

\begin{equation}\begin{aligned} e^{\boldsymbol{q}\cdot \boldsymbol{k}}&=\mathbb{E}_{\boldsymbol{\omega}\sim \mathcal{N}(\boldsymbol{\omega};0,\boldsymbol{1}_d)}\left[e^{\boldsymbol{\omega}\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \times e^{\boldsymbol{\omega}\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\right]\\[6pt] &\approx\underbrace{\frac{1}{\sqrt{m}}\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{q}}} \cdot \underbrace{\frac{1}{\sqrt{m}}\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}} \end{aligned}\end{equation}
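As a quick numerical sanity check of this identity, the following toy sketch (dimensions and the 0.3 scale are arbitrary, chosen only to keep the Monte Carlo variance small) compares the random-feature estimate against the exact value of $e^{\boldsymbol{q}\cdot \boldsymbol{k}}$:

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 16, 100000                   # a large m keeps the Monte Carlo variance small
q = 0.3 * rng.normal(size=d)        # the 0.3 scale is arbitrary, just to keep norms moderate
k = 0.3 * rng.normal(size=d)

W = rng.normal(size=(m, d))         # rows are omega_1, ..., omega_m ~ N(0, I_d)
q_feat = np.exp(W @ q - q @ q / 2) / np.sqrt(m)
k_feat = np.exp(W @ k - k @ k / 2) / np.sqrt(m)

print(np.exp(q @ k))                # exact e^{q.k}
print(q_feat @ k_feat)              # random-feature estimate; agrees closely for large m
```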

Simply put, Performer found a mapping that transforms $d$-dimensional vectors $\boldsymbol{q}, \boldsymbol{k}$ into $m$-dimensional vectors $\tilde{\boldsymbol{q}}, \tilde{\boldsymbol{k}}$, satisfying the approximation $e^{\boldsymbol{q}\cdot \boldsymbol{k}}\approx \tilde{\boldsymbol{q}}\cdot\tilde{\boldsymbol{k}}$. In this case:

\begin{equation}a_{i,j} = \frac{e^{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}}{\sum\limits_j e^{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}}\approx \frac{\tilde{\boldsymbol{q}}_i\cdot\tilde{\boldsymbol{k}}_j}{\sum\limits_j \tilde{\boldsymbol{q}}_i\cdot\tilde{\boldsymbol{k}}_j} = \frac{(\lambda(\tilde{\boldsymbol{q}}_i)\tilde{\boldsymbol{q}}_i)\cdot\tilde{\boldsymbol{k}}_j}{\sum\limits_j (\lambda(\tilde{\boldsymbol{q}}_i)\tilde{\boldsymbol{q}}_i)\cdot\tilde{\boldsymbol{k}}_j}\end{equation}

The last equality shows that multiplying $\tilde{\boldsymbol{q}}$ by a constant (even if this constant depends on $\tilde{\boldsymbol{q}}$) does not change the Performer's result at all. This means if we change the mapping to:

\begin{equation} \tilde{\boldsymbol{q}} = \begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix},\qquad \tilde{\boldsymbol{k}}=\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \end{pmatrix} \end{equation}

The Performer's result will not change. Of course, the $\Vert \boldsymbol{k}\Vert^2$ term cannot be removed yet. However, if we assume that $\Vert \boldsymbol{k}\Vert^2$ does not fluctuate too much and is not the primary factor of Attention, then this term also acts as a constant. Thus, the final mapping is (approximately) equivalent to:

\begin{equation} \tilde{\boldsymbol{q}} = \begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix},\qquad \tilde{\boldsymbol{k}}=\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}} \end{pmatrix} \end{equation}

How should we understand this much-simplified mapping? The $m$ random vectors $\boldsymbol{\omega}_1,\boldsymbol{\omega}_2,\cdots,\boldsymbol{\omega}_m$ form a $d\times m$ matrix that maps the $d$-dimensional $\boldsymbol{q}, \boldsymbol{k}$ to $m$-dimensional vectors, which are then passed through the activation function $e^x$ to obtain $\tilde{\boldsymbol{q}}, \tilde{\boldsymbol{k}}$. Recall that $\boldsymbol{q}, \boldsymbol{k}$ in Attention are themselves the outputs of a fully connected layer. If we fold this $d\times m$ mapping matrix into that fully connected layer, all that remains is the activation function $e^x$!

So this is the source of the optimal activation function $e^x$: as long as we change the output dimension of $\boldsymbol{q}, \boldsymbol{k}$ from $d$ to $m$ and pair it with the activation function $e^x$, then theoretically we obtain the fitting capability of Performer, or even stronger. This is because Performer's $d\times m$ matrix is a fixed random matrix, whereas here we effectively make that matrix trainable as well, allowing a larger space of models than Performer.
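As a toy sketch of this takeaway (the random matrices below merely stand in for trained weights, and all shapes are illustrative): the $d\times m$ matrix is absorbed into the query/key projections, the feature dimension is enlarged from $d$ to $m$, and only the activation $e^x$ appears explicitly.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, m, d_v = 8, 32, 128, 32        # m plays the role of an enlarged key_size

# Random matrices standing in for trained projections; the d x m omega-matrix
# has been folded into W_q and W_k, so only the e^x activation remains explicit.
W_q = 0.05 * rng.normal(size=(d_model, m))
W_k = 0.05 * rng.normal(size=(d_model, m))
W_v = 0.05 * rng.normal(size=(d_model, d_v))

X = rng.normal(size=(n, d_model))          # token representations
Qf, Kf, V = np.exp(X @ W_q), np.exp(X @ W_k), X @ W_v

out = (Qf @ (Kf.T @ V)) / (Qf @ Kf.sum(axis=0))[:, None]
print(out.shape)                           # (8, 32)
```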

The Low-Rank Problem

Whether it is the Performer discussed here or the Nyströmformer introduced previously, their approach is to "seek a Linear Attention that can approximate standard Attention." A natural question arises: what is so good about standard Attention? Why is it worth aligning with?

From the perspective of information loss, the "rank" of the standard Attention matrix may be larger, meaning it is closer to an invertible matrix, which in turn means it can retain more information. Specifically, the Attention matrix is an $n\times n$ matrix derived from $\boldsymbol{Q},\boldsymbol{K}\in\mathbb{R}^{n\times d}$ via $softmax(\boldsymbol{Q}\boldsymbol{K}^{\top})$. Note that the $d$ here is the key_size of the Attention; in BERT-base, for instance, it is only 64, while $n$ is often quite large. This implies the rank of $\boldsymbol{Q}\boldsymbol{K}^{\top}$ does not exceed $d$, and since $d\ll n$, it is far from full rank. However, the key operation inside $softmax$ is the elementwise exponential $e^{\boldsymbol{Q}\boldsymbol{K}^{\top}}$, and exponentiating every element of a matrix can increase its rank! Thus, the standard Attention matrix has the potential for rank expansion, meaning it inherently possesses a greater capacity for processing information.

In contrast, the Linear Attention matrix takes the form $\tilde{\boldsymbol{Q}}\tilde{\boldsymbol{K}}^{\top}$, so its rank cannot exceed $m$. To compensate for the loss of rank, $m > d$ is generally set. In Performer experiments, $m = 4d$ was chosen, meaning the key_size was quadrupled—the importance of rank is evident. Of course, increasing the key_size has the direct consequence that Linear Attention may be slower than standard Attention when processing short sequences, which is an inherent bottleneck of Linear Attention.
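These rank claims are easy to check numerically. The toy sizes below are arbitrary: the elementwise exponential lifts the rank of a rank-$d$ product, while the Linear Attention form stays capped at $m$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, m = 64, 8, 16                        # toy sizes with n >> m > d

Q = rng.normal(size=(n, d)) / d ** 0.25    # scaled so that q.k is O(1)
K = rng.normal(size=(n, d)) / d ** 0.25
Qt = np.exp(rng.normal(size=(n, m)))       # stand-ins for the m-dimensional Linear Attention features
Kt = np.exp(rng.normal(size=(n, m)))

print(np.linalg.matrix_rank(Q @ K.T))          # capped at d (here 8)
print(np.linalg.matrix_rank(np.exp(Q @ K.T)))  # elementwise exp: the cap disappears (generically full rank)
print(np.linalg.matrix_rank(Qt @ Kt.T))        # capped at m (here 16): the hard ceiling of Linear Attention
```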

Regarding theoretical analysis of the rank of the Attention matrix, there are papers for reference. For example, "Low-Rank Bottleneck in Multi-head Attention Models" points out that even in standard Attention, low-rankness is a serious bottleneck, and increasing key_size can improve performance. A paper from last month, "Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth", points out that without residuals and FFNs, standard Attention runs a high risk of decaying into a simple transformation with a rank of 1. If even standard Attention, which has "rank-expanding potential," faces low-rank issues, then Linear Attention—which has a hard cap on its rank—is even more vulnerable.

In short: Linear Attention requires a larger key_size to maintain its rank.

Concentrating Attention

We can also understand the benefits of standard Attention from the perspective of sparsity. Intuitively, since it is an "Attention Mechanism," it must "concentrate attention." If it is too dispersed, it might be equivalent to average pooling. "Concentrating attention" means each token should only significantly correlate with a few other tokens. Mathematically, this means the Attention matrix is sparse, or at least has the potential to become sparse.

Standard Attention normalizes through softmax:

\begin{equation}a_{i,j} = \frac{e^{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}}{\sum\limits_j e^{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}}\end{equation}

The exponential function $e^x$ plays an amplifying role. As long as the individual $\boldsymbol{q}_i\cdot \boldsymbol{k}_j$ values establish some gap, $e^{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}$ will widen that gap further, so that after normalization the probabilities at all but the largest positions become very close to 0. This shows standard Attention has the potential to "concentrate attention." Linear Attention, by contrast, is a direct inner product without the amplification of $e^x$. Consequently, its attention is denser, and for long sequences it often approaches average pooling. To mitigate this, one still needs to increase the key_size to amplify the differences; intuitively, $n$ vectors are too "crowded" in a low-dimensional space, and moving to a higher-dimensional space makes them "looser."
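Both points are easy to see on toy numbers (purely illustrative): the exponential turns a modest gap in scores into a dominant weight, while direct normalization of non-negative scores over a long sequence stays close to the uniform weight $1/n$.

```python
import numpy as np

# (a) The exponential widens gaps between scores.
s = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(s / s.sum())                    # direct normalization: fairly flat, max weight 1/3
print(np.exp(s) / np.exp(s).sum())    # softmax: the top score takes roughly 64% of the mass

# (b) Without the exponential amplification, long sequences drift toward average pooling.
rng = np.random.default_rng(0)
n, d = 4096, 64
q, K = rng.normal(size=d), rng.normal(size=(n, d))
w = np.exp(q) @ np.exp(K).T           # Linear-Attention-style scores phi(q).varphi(k) with phi = e^x
w = w / w.sum()
print(w.max(), 1.0 / n)               # the largest weight is only a small multiple of the uniform 1/n
```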

How can we verify the importance of sparsity? I once tried calculating the Linear Attention matrix and then manually truncating it (i.e., each token only attends to a few tokens before and after it, turning it into a local Attention). The results showed that this truncated Linear Attention performed significantly better than the full-matrix version. This confirms the importance of sparsity. Of course, explicitly calculating the Attention matrix before truncating it means the complexity is no longer linear, so it lacks practical value and was used only for theoretical validation.

Another experimental phenomenon supports the importance of sparsity: when Linear Attention is used for language models or decoders, its performance is nearly identical to standard Attention. In these cases, Linear Attention becomes a unidirectional RNN (refer to "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"), which is equivalent to the Attention matrix becoming lower triangular and thus sparser. In contrast, if non-sparse bidirectional Linear Attention is used directly for an MLM (Masked Language Model), the performance drop is quite significant.
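For concreteness, here is a toy NumPy sketch of causal Linear Attention run as a unidirectional RNN (illustrative only, not an optimized implementation), together with a check against the explicit lower-triangular form.

```python
import numpy as np

def causal_linear_attention(Q, K, V, feature_map=np.exp):
    """Causal Linear Attention run as a unidirectional RNN.

    The state S accumulates sum_{j<=i} varphi(k_j) v_j^T and z accumulates
    sum_{j<=i} varphi(k_j), so each step costs O(m * d_v) and the whole pass
    is linear in the sequence length n.
    """
    n, d_v = V.shape
    m = Q.shape[1]
    Qf, Kf = feature_map(Q), feature_map(K)
    S, z = np.zeros((m, d_v)), np.zeros(m)
    out = np.zeros((n, d_v))
    for i in range(n):
        S += np.outer(Kf[i], V[i])
        z += Kf[i]
        out[i] = (Qf[i] @ S) / (Qf[i] @ z)
    return out

# Toy check against the explicit lower-triangular (causal) attention matrix.
rng = np.random.default_rng(0)
n, d = 6, 4
Q, K, V = rng.normal(size=(3, n, d))
A = np.tril(np.exp(Q) @ np.exp(K).T)
ref = (A @ V) / A.sum(axis=1, keepdims=True)
assert np.allclose(causal_linear_attention(Q, K, V), ref)
```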

More importantly, sparsity and the rank mentioned in the previous section are closely linked—two sides of the same coin: appropriate sparsification methods can increase the rank of a matrix! For instance, in the lower triangular Attention matrix of a language model, as long as the diagonal elements are non-zero (which they usually are), the matrix becomes full-rank and invertible! Similarly, the local Attention truncation I experimented with also increases the rank; in an extreme case, if each token only attends to itself, the Attention matrix becomes the full-rank Identity matrix!
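This, too, is easy to check on toy sizes: a positive rank-$m$ matrix, once masked to lower-triangular form, has a nonzero diagonal and therefore becomes full rank.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 16, 4                               # small toy sizes
Qf = np.exp(rng.normal(size=(n, m)))       # positive features, so the diagonal of A is positive
Kf = np.exp(rng.normal(size=(n, m)))

A = Qf @ Kf.T
L = np.tril(A)                             # causal (lower-triangular) masking
print(np.linalg.matrix_rank(A))            # capped at m (here 4)
print(np.linalg.matrix_rank(L))            # n (here 16): triangular with a positive diagonal, hence invertible
```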

Article Summary

Starting from the Performer, this article reflects on several issues regarding Linear Attention, including the choice of activation function and the bottlenecks (low-rankness, sparsity). The overall conclusion is that the optimal activation function for Linear Attention should be the exponential function, and an effective Attention mechanism should possess higher rank and greater sparsity.