By 苏剑林 | February 25, 2022
Efficient Transformers, which generally refer to works that improve the efficiency of Transformer models, are something I have been following for a long time. My earliest blog post on this topic dates back to 2019, "Born for Economy: From Standard Attention to Sparse Attention", at a time when there were very few works in this area. Later, such works gradually increased, and I followed several of them, such as Linear Attention, Performer, and Nyströmformer, and even performed some explorations myself, such as the "Transformer Upgrade Path". Subsequently, as related works became more numerous and often quite repetitive, I stopped paying close attention.
Recently, like a welcome rain after a long drought, a very interesting work on efficient Transformers appeared: Google's "Transformer Quality in Linear Time". After reading it carefully, I believe it is truly full of surprises.
What kind of results deserve the description "surprising"? Is it an exaggeration? Let's first look at what the paper achieves:
1. It proposes a new Transformer variant that still has quadratic complexity but, compared to standard Transformers, offers faster speeds, lower memory usage, and better performance.
2. It proposes a new linearization scheme for Transformers that not only improves the performance of existing Linear Attention but also retains the ability to act as a Decoder, and even keeps high training parallelism when used as a Decoder.
To be honest, I feel that achieving either of these points would be rare and commendable, but this paper achieves both simultaneously. More importantly, the improvements in the paper are generally natural and elegant, unlike many similar works that feel forced. Furthermore, I have conducted simple replication experiments, and the results show that the paper's reproducibility is quite good. I really feel a sense of "Transformer is in danger."
Without further ado, let's get to the main topic. We know that standard Transformers are built by alternating Attention layers and FFN layers. The core of this paper is the proposal of a new design called GAU (Gated Attention Unit), which merges the two. This is the key to the new model being faster, more economical, and better. Additionally, it results in a model with only one type of layer, which is more elegant.
How do we merge Attention and FFN? First, a standard FFN is a two-layer MLP model:
\begin{equation}\boldsymbol{O}=\phi(\boldsymbol{X}\boldsymbol{W}_u)\boldsymbol{W}_o\end{equation}Here $\boldsymbol{X}\in\mathbb{R}^{n\times d}, \boldsymbol{W}_u\in\mathbb{R}^{d\times e}, \boldsymbol{W}_o\in\mathbb{R}^{e\times d}$ and $\phi$ is an activation function. Later, "GLU Variants Improve Transformer" found that FFNs using GLU (Gated Linear Unit) perform better, which was adopted by mT5. Its form is:
\begin{equation}\boldsymbol{O}=(\boldsymbol{U}\odot\boldsymbol{V})\boldsymbol{W}_o,\quad \boldsymbol{U}=\phi_u(\boldsymbol{X}\boldsymbol{W}_u),\quad\boldsymbol{V}=\phi_v(\boldsymbol{X}\boldsymbol{W}_v)\end{equation}Where $\boldsymbol{W}_u, \boldsymbol{W}_v\in\mathbb{R}^{d\times e}$ and $\odot$ is the element-wise (Hadamard) product. It is not surprising that GLU is more effective; it played a key role in Facebook's 2017 paper "Convolutional Sequence to Sequence Learning", and my own research on DGCNN also confirmed the effectiveness of GLU.
In the usual GLU, $\boldsymbol{U}$ has no activation function while $\boldsymbol{V}$ uses a Sigmoid. In this paper, however, both $\boldsymbol{U}$ and $\boldsymbol{V}$ use the Swish activation function (also known as SiLU, Sigmoid Linear Unit), as can be seen from the reference code in the appendix. Since this differs slightly from mainstream GLU usage, I point it out specifically.
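To make the shapes concrete, here is a minimal NumPy sketch of this GLU-style FFN, assuming (as the appendix code suggests) Swish on both branches; the sizes, initialization, and function names below are purely illustrative.

```python
import numpy as np

def swish(x):
    """Swish / SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def glu_ffn(X, Wu, Wv, Wo):
    """GLU-style FFN: O = (U ⊙ V) Wo, with Swish on both branches.

    X:  (n, d) token representations
    Wu: (d, e), Wv: (d, e), Wo: (e, d)
    """
    U = swish(X @ Wu)          # (n, e)
    V = swish(X @ Wv)          # (n, e)
    return (U * V) @ Wo        # (n, d)

# toy usage with illustrative sizes
n, d, e = 8, 16, 32
rng = np.random.default_rng(0)
X = rng.normal(size=(n, d))
Wu, Wv, Wo = (rng.normal(size=s) * 0.02 for s in [(d, e), (d, e), (e, d)])
print(glu_ffn(X, Wu, Wv, Wo).shape)  # (8, 16)
```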
Since the GLU-style FFN is more effective, let's start from it and modify it. Note that an FFN alone cannot replace Attention, because it has no interaction between tokens: each row of $\boldsymbol{U}, \boldsymbol{V}$ is computed independently. To compensate for this, a natural idea is to add token-to-token relationships to $\boldsymbol{U}$ and $\boldsymbol{V}$. To reflect the combination with Attention, a natural design is:
\begin{equation}\boldsymbol{O}=(\boldsymbol{U}\odot\boldsymbol{A}\boldsymbol{V})\boldsymbol{W}_o\label{eq:mix}\end{equation}Where $\boldsymbol{A}\in\mathbb{R}^{n\times n}$ is the Attention matrix, responsible for merging information between tokens. The resulting $\boldsymbol{O}$ contains token interactions, and in principle, it can replace Attention. As for how $\boldsymbol{A}$ is calculated, we will discuss that in a moment.
In formula $\eqref{eq:mix}$, if $\boldsymbol{A}$ equals the identity matrix $\boldsymbol{I}$, it becomes a GLU-style FFN; if $\boldsymbol{U}$ is an all-ones matrix, it becomes the standard attention mechanism. Therefore, $\eqref{eq:mix}$ is a simple and natural fusion of Attention and FFN. We expect it to replace both Attention and FFN simultaneously, perhaps with better performance.
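As a sanity check on this structure, here is a small NumPy sketch of formula $\eqref{eq:mix}$ (the function name is my own); the two assertions confirm the limiting cases just described, $\boldsymbol{A}=\boldsymbol{I}$ and $\boldsymbol{U}$ all ones.

```python
import numpy as np

def gau_mix(U, V, A, Wo):
    """O = (U ⊙ (A V)) Wo.

    U, V: (n, e) gate / value branches (as in the GLU above)
    A:    (n, n) attention matrix mixing information across tokens
    Wo:   (e, d) output projection
    """
    return (U * (A @ V)) @ Wo

# structural sanity checks with illustrative shapes
n, d, e = 8, 16, 32
rng = np.random.default_rng(0)
U, V = rng.normal(size=(n, e)), rng.normal(size=(n, e))
Wo = rng.normal(size=(e, d))

# A = I  ->  reduces to the GLU-style FFN (no token mixing)
assert np.allclose(gau_mix(U, V, np.eye(n), Wo), (U * V) @ Wo)
# U = all-ones  ->  reduces to plain attention-weighted values
A = rng.random(size=(n, n))
assert np.allclose(gau_mix(np.ones_like(U), V, A, Wo), (A @ V) @ Wo)
```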
As mentioned, GLU is already very strong—otherwise, Facebook wouldn't have achieved SOTA in Seq2Seq using CNN+GLU. Since GLU is so strong, one hypothesis is that it weakens the dependence on Attention. That is, although $\boldsymbol{A}$ is indispensable in formula $\eqref{eq:mix}$, we might be able to simplify its form. Indeed, the original paper uses the following simplified Attention matrix:
\begin{equation}\boldsymbol{A}=\frac{1}{n}\text{relu}^2\left(\frac{\mathcal{Q}(\boldsymbol{Z})\mathcal{K}(\boldsymbol{Z})^{\top}}{\sqrt{s}}\right)=\frac{1}{ns}\text{relu}^2\left(\mathcal{Q}(\boldsymbol{Z})\mathcal{K}(\boldsymbol{Z})^{\top}\right),\quad \boldsymbol{Z}=\phi_z(\boldsymbol{X}\boldsymbol{W}_z)\label{eq:relu-att}\end{equation}Here $\boldsymbol{W}_z\in\mathbb{R}^{d\times s}$, and $s$ is the attention's head_size ($s=128$ is used). $\mathcal{Q}, \mathcal{K}$ are simple affine transformations (like multiplying by $\gamma$ and adding $\beta$ in Layer Norm), and $\text{relu}^2$ is $\text{relu}$ followed by a square.
Similar to standard Scaled-Dot Self Attention, the attention matrix here is derived from the inner product of $\boldsymbol{Q}$ and $\boldsymbol{K}$ divided by the square root of the dimension, and the complexity is still $\mathcal{O}(n^2)$. The difference is that the transformation source for $\boldsymbol{Q}, \boldsymbol{K}$ is simplified, and the activation function is changed to $\text{relu}^2$. This activation function might be unfamiliar; it was actually discovered via NAS by the authors in their previous paper "Primer: Searching for Efficient Transformers for Language Modeling". The final $1/n$ is a simple normalization factor to eliminate the effect of length.
Note that according to the reference code in the appendix, the simplified scaling factor is actually $\frac{1}{n^2}$ rather than the $\frac{1}{ns}$ shown above. I believe $\frac{1}{ns}$ is more reasonable; otherwise, as $n$ grows large, each attention weight becomes too small. By comparison, the softmax denominator in standard attention is only of order $\mathcal{O}(n)$, so dividing by $n^2$ seems hard to justify. A simple comparison at sequence length 512 showed the $\frac{1}{ns}$ version performing slightly better, so I will stick with my intuition here.
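For concreteness, here is a sketch of how $\eqref{eq:relu-att}$ might be computed, using the $\frac{1}{ns}$ scaling argued for above; the per-dimension scale/offset pairs stand in for the affine transforms $\mathcal{Q},\mathcal{K}$, and taking $\phi_z$ to be Swish is my assumption based on the appendix code.

```python
import numpy as np

def swish(x):
    return x / (1.0 + np.exp(-x))

def relu2_attention(X, Wz, gamma_q, beta_q, gamma_k, beta_k):
    """Single-head GAU attention matrix: relu(Q(Z) K(Z)^T)^2 / (n s).

    X: (n, d); Wz: (d, s)
    gamma_*, beta_*: (s,) per-dimension scale/offset, playing the role of
    the cheap affine transforms Q(.) and K(.).
    """
    n, s = X.shape[0], Wz.shape[1]
    Z = swish(X @ Wz)                    # phi_z assumed to be Swish here
    Q = Z * gamma_q + beta_q
    K = Z * gamma_k + beta_k
    return np.maximum(Q @ K.T, 0.0) ** 2 / (n * s)   # relu^2, normalized by n*s
```

Feeding this $\boldsymbol{A}$ into the gau_mix sketch above gives a complete single-head GAU layer.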
Next, ladies and gentlemen, don't blink—the real "heavyweight" is arriving! Perhaps GLU is truly so powerful that its dependence on Attention is incredibly weak, to the point where the authors discovered: Only one head is enough!
We know standard Transformers use multi-head attention mechanisms, which require generating matrices of size $bhn^2$ during computation ($b$ is batch_size, $h$ is the number of heads). Imagine when $n=1000, n=2000$, or even larger, $n^2$ is already painful, and multiplying it by $h$ makes time and space complexity even worse. Now, a single-headed GAU can achieve the same or even better results, increasing calculation speed and reducing memory usage—it's essentially a "free lunch."
When GAU has only one head, the number of parameters for $\boldsymbol{W}_z$ is very small. The main parameter counts are in $\boldsymbol{W}_u, \boldsymbol{W}_v, \boldsymbol{W}_o$, so the total GAU parameters are approximately $3de$. In a standard Transformer, Attention has $4d^2$ parameters and FFN has $8d^2$ (assuming $e=4d$), for a total of $12d^2$. Thus, from a parameter standpoint, when $e=2d$, two GAU layers are roughly equivalent to one Attention+FFN layer.
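Spelling out this bookkeeping with the counts already stated:
\begin{equation}2\times\underbrace{3de}_{\text{one GAU}}\Big|_{e=2d}=12d^2=\underbrace{4d^2}_{\text{Attention}}+\underbrace{8d^2}_{\text{FFN},\ e=4d}\end{equation}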
In the GAU experiments, the authors fixed $e=2d$. Thus, a standard Transformer with "$n$ Attention layers + $n$ FFN layers" corresponds to a "$2n$ GAU layer" model, which we call FLASH-Quad. "Quad" is short for "Quadratic," indicating the complexity remains quadratic. As for the meaning of FLASH, we'll talk about that next.
In fact, FLASH-Quad is already an excellent replacement for the standard Transformer. However, the authors weren't satisfied with quadratic complexity and proposed FLASH (Fast Linear Attention with a Single Head) with linear complexity. To do this, they proposed a "Mixed Chunk Attention" scheme, which can be used both in GAU and standard Attention as a general linearization trick.
Mainstream efficient Transformer efforts generally fall into two categories: "Sparsification" and "Linearization."
The post "Born for Economy" mentioned earlier is a "Sparsification" work, followed by things like Reformer. Works combined with pooling like Linformer also fall into this category. The característica of these works is introducing certain inductive biases to force most attention to zero, theoretically reducing computation. However, the downside is that they often require specialized programming optimization to achieve speedup, or they are difficult to use as Decoders (pooling-based works), and their effectiveness depends heavily on the introduced inductive bias.
As for "Linearization," introduced in "Exploring Linear Attention", many researchers have looked into this, including cosFormer and Flowformer. Simply put, these works change standard Attention $\phi(\boldsymbol{Q}\boldsymbol{K}^{\top})\boldsymbol{V}$ to $(\phi_q(\boldsymbol{Q})\phi_k(\boldsymbol{K})^{\top})\boldsymbol{V}=\phi_q(\boldsymbol{Q})(\phi_k(\boldsymbol{K})^{\top}\boldsymbol{V})$, achieving linear complexity. The benefit is ease of implementation, but there are two major problems: one is that low-rankness significantly degrades performance, and the other is that when used as a Decoder (Causal), it sacrifices training parallelism because it must be calculated as an RNN. Alternatively, to keep parallelism, it requires $bhns^2$ space complexity, which only has an advantage over standard Attention's $bhn^2$ when $n \gg s^2$ (even if $s=64$, then $n \gg 4096$), which is unrealistic in most cases.
FLASH takes a "local-global" mixed chunking approach, combining the advantages of "Sparsification" and "Linearization." First, for an input sequence of length $n$, we divide it into $n/c$ non-overlapping chunks of length $c$ (assuming $c$ divides $n$; the paper uses $c=256$). Let $\boldsymbol{U}_g, \boldsymbol{V}_g \in \mathbb{R}^{c \times e}, \boldsymbol{Z}_g \in \mathbb{R}^{c \times s}$ be the $g$-th chunk. As in formula $\eqref{eq:relu-att}$, we pass $\boldsymbol{Z}_g$ through four simple affine transformations to obtain $\boldsymbol{Q}_g^{\text{quad}}, \boldsymbol{K}_g^{\text{quad}}, \boldsymbol{Q}_g^{\text{lin}}, \boldsymbol{K}_g^{\text{lin}}$.
We use $\boldsymbol{Q}_g^{\text{quad}}, \boldsymbol{K}_g^{\text{quad}}$ to calculate intra-chunk self-attention:
\begin{equation}\hat{\boldsymbol{V}}_g^{\text{quad}}=\frac{1}{cs}\text{relu}^2\left(\boldsymbol{Q}_g^{\text{quad}}{\boldsymbol{K}_g^{\text{quad}}}^{\top}\right)\boldsymbol{V}_g\end{equation}This represents internal interaction within each chunk, which is essentially a form of "Sparsification." Its complexity is approximately $\mathcal{O}(n/c \times c^2) = \mathcal{O}(nc)$, which is linear with respect to $n$. Implementation is equivalent to multi-head attention with $n/c$ heads and sequence length $c$, enabling full parallelism. To make it a Decoder, simply mask the upper triangle of the attention matrix.
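A minimal sketch of this intra-chunk part, treating the chunk index as a batch dimension; the function name is my own, and the causal branch applies the upper-triangle mask just mentioned.

```python
import numpy as np

def local_quad_attention(Qq, Kq, V, causal=False):
    """Per-chunk quadratic attention: relu(Qq Kq^T)^2 V / (c s).

    Qq, Kq: (n/c, c, s) chunked "quad" queries/keys
    V:      (n/c, c, e) chunked values
    Returns (n/c, c, e).
    """
    g, c, s = Qq.shape
    A = np.maximum(np.einsum('gis,gjs->gij', Qq, Kq), 0.0) ** 2 / (c * s)
    if causal:
        # mask the upper triangle so position i only sees j <= i within its chunk
        A = A * np.tril(np.ones((c, c)))
    return np.einsum('gij,gje->gie', A, V)
```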
The remaining $\boldsymbol{Q}_g^{\text{lin}}, \boldsymbol{K}_g^{\text{lin}}$ are used for global Attention, which we compute using Linear Attention:
\begin{equation}\hat{\boldsymbol{V}}_g^{\text{lin}}=\frac{1}{n}\boldsymbol{Q}_g^{\text{lin}}\sum_{h=1}^{n/c} {\boldsymbol{K}_h^{\text{lin}}}^{\top}\boldsymbol{V}_h\end{equation}Note that this operation is perfectly equivalent to using full matrices $\boldsymbol{Q}^{\text{lin}}, \boldsymbol{K}^{\text{lin}} \in \mathbb{R}^{n \times s}$ with $\boldsymbol{V}$ for linear attention; writing it this way just highlights the connection to chunking. For a Decoder, to prevent future information leakage, change it to a cumsum form:
\begin{equation}\hat{\boldsymbol{V}}_g^{\text{lin}}=\frac{1}{(g-1)c}\boldsymbol{Q}_g^{\text{lin}}\sum_{h=1}^{g-1} {\boldsymbol{K}_h^{\text{lin}}}^{\top}\boldsymbol{V}_h\end{equation}Here the normalization is over the $(g-1)c$ tokens that are actually summed. In this case, to maintain parallelism, we only need $b(n/c)se$ space complexity. Without chunking, linear attention requires $bns^2$ (or $bhns^2$ with multi-head). Under the current parameters, where $e/c \ll s$, this is much more memory efficient.
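And a sketch of the global linear part in both the non-causal and causal (exclusive chunk-level cumsum) forms, using the per-count normalization discussed above; again the names are illustrative.

```python
import numpy as np

def global_lin_attention(Ql, Kl, V, causal=False):
    """Chunk-level global linear attention.

    Ql, Kl: (n/c, c, s); V: (n/c, c, e).  Returns (n/c, c, e).
    """
    num_chunks, c, s = Ql.shape
    KV = np.einsum('gcs,gce->gse', Kl, V)            # per-chunk K^T V summaries, (n/c, s, e)
    if not causal:
        # every chunk sees the same global s x e summary, normalized by n = (n/c) * c
        return np.einsum('gcs,se->gce', Ql, KV.sum(axis=0)) / (num_chunks * c)
    # causal: chunk g only sees chunks 1..g-1 (exclusive cumsum over chunk summaries),
    # normalized by the (g-1)*c tokens actually summed
    KV_prev = np.concatenate([np.zeros_like(KV[:1]), np.cumsum(KV, axis=0)[:-1]], axis=0)
    counts = np.maximum(np.arange(num_chunks), 1)[:, None, None] * c  # avoid 0-division at g=1
    return np.einsum('gcs,gse->gce', Ql, KV_prev / counts)
```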
Finally, combine the two attention results and integrate them into GAU to obtain the linear version of GAU:
\begin{equation}\boldsymbol{O}_g=\left[\boldsymbol{U}_g\odot\left(\hat{\boldsymbol{V}}_g^{\text{quad}} + \hat{\boldsymbol{V}}_g^{\text{lin}}\right)\right]\boldsymbol{W}_o\end{equation}The Transformer model built based on this linear version of GAU is the FLASH model.
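Putting the two pieces together per the formula above, assuming the illustrative helpers sketched earlier (local_quad_attention, global_lin_attention) are in scope:

```python
def chunked_gau(U, V, Qq, Kq, Ql, Kl, Wo, causal=False):
    """O_g = [U_g ⊙ (V̂_g^quad + V̂_g^lin)] Wo, computed chunk-wise.

    U, V: (n/c, c, e); Qq, Kq, Ql, Kl: (n/c, c, s); Wo: (e, d).
    """
    V_hat = (local_quad_attention(Qq, Kq, V, causal)
             + global_lin_attention(Ql, Kl, V, causal))
    return (U * V_hat) @ Wo      # (n/c, c, d); reshape back to (n, d) afterwards
```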
I believe the reason for this "local-global" mixed chunking, besides reducing computational cost, is that it produces an attention distribution that better fits reality. According to our empirical understanding of NLP, natural language correlations are primarily local. While global, extremely long-distance correlations exist, they are not dominant. Therefore, this mixed attention design helps the model highlight local correlations without discarding long-range ones. The original paper also performed ablation experiments showing that local attention is generally more important than global attention, but the hybrid version performs best.
Additionally, some readers might worry that non-overlapping chunking might disadvantage boundary words. The paper addresses this, stating that introducing more complex overlapping local attention does improve results, but it also increases calculation cost. For the same increase in cost, adding more layers of current non-overlapping GAU provides more gain than shifting to overlapping attention. Thus, the current non-overlapping design strikes a good balance between speed and effect.
Finally, this "Mixed Chunking" linearization is essentially universal. It can be used not only in GAU but also in standard Transformers (keeping standard Attention+FFN but linearizing Attention with mixed chunking). The paper calls this "MC-TFM" and performed comparisons showing that GAU is more advantageous in linearization.
Regarding GAU and FLASH, two experimental results are most noteworthy.
The first is the comparison between the new GAU and standard Multi-Head Attention (MHSA), which is effectively comparing FLASH-Quad with standard Transformer. The results show that regardless of model size, GAU outperforms standard multi-head attention models in both speed and effect (being closer to the top-right of the plot).
The second is the experiment table for the FLASH model, which directly shows:
1. Although both FLASH-Quad and Transformer have quadratic complexity, FLASH-Quad has better results and faster speeds;
2. When the sequence is long enough, the linear-complexity FLASH is faster than FLASH-Quad with similar performance.
The speed increase of FLASH-Quad (still quadratic) is something many so-called linear complexity works fail to achieve, demonstrating the power of GAU. The paper also points out that RoPE (Rotary Positional Embedding), which I proposed previously, significantly improves the performance of both Transformer and FLASH, so Transformer+, Transformer++, FLASH-Quad, and FLASH in the experiments all use RoPE. I'm quite proud of this.
Furthermore, though the table doesn't explicitly compare memory usage, my tests found that at the base-scale with sequence length 1024, the maximum usable batch_size for FLASH-Quad is nearly double that of a standard Transformer. I also attempted building a small version of FLASH-Quad for Chinese pre-training and found it performed slightly better than RoFormer (RoPE+Transformer), confirming the paper's reported results.
The introduction to GAU and FLASH is basically complete. As of this post, the authors have not yet released the full source code on GitHub, but the appendix contains almost usable pseudo-code (TensorFlow version). Implementation should not be difficult. Those with interest and compute can experiment with it.
Now for the "nitpicking"—areas where I feel the paper is not perfect.
Firstly, I feel that FLASH-Quad and FLASH are not decoupled well enough. As stated at the start, both are blockbuster results. To me, FLASH-Quad is even more valuable because the quadratic complexity of self-attention provides enough degrees of freedom for various tricks like UniLM. FLASH-Quad is a very independent and worthy model, but in the paper, it feels like a transition product for FLASH. Fortunately, the authors isolated the GAU concept, which alleviates this.
Secondly, GAU can replace both Attention and FFN. From its design, it seems intended to replace Self-Attention. The authors don't seem concerned with its ability to replace Cross-Attention, and the paper lacks experiments on this. Is it possible for GAU to replace Cross-Attention? Theoretically, formula $\eqref{eq:mix}$ suggests it is, but it's unclear if it can remain single-headed, which is GAU's greatest advantage. Also, the paper only conducted LM and MLM language modeling experiments; it didn't do "pre-training + fine-tuning" experiments, so the transfer performance of GAU is uncertain.
Lastly, something I don't fully understand is why GAU/FLASH-Quad/FLASH use additive absolute position encoding, additive relative position encoding, and RoPE all at once. Theoretically, one of them should suffice; in my GAU experiments, using only RoPE worked well. As for the appendix code, the authors didn't handle padding very carefully, and the recursive normalization factor for the Decoder (summing up to position $t$ should be normalized by $t$, not by $n$) also needs improvement. These are minor details, though, and the authors' actual code may well be correct, with the appendix simplified for readability.
This article introduced Google's new efficient Transformer work, which merges Attention and FFN into a new GAU layer to create FLASH-Quad and then further proposes a "Mixed Chunking" linearization to create FLASH. Experimental results show that compared to standard Transformers, both are faster, more memory-efficient, and better. Perhaps in the near future, what "You Need" will no longer be Attention, but GAU.