By 苏剑林 | August 09, 2021
In this blog, we have discussed Linear Attention several times. The usual motivation goes like this: standard Attention has quadratic complexity $\mathcal{O}(n^2)$, which is one of its major "pain points," so we introduce improved models with linear complexity $\mathcal{O}(n)$, known as Linear Attention. After reading these introductions, some readers have been eagerly waiting for us to release pre-trained models based on Linear Attention, hoping they would relieve the agonizing computational cost of BERT.
However, what this article aims to say is: readers holding onto this idea might be disappointed. The conversion from standard Attention to Linear Attention is likely to fall far short of your expectations, and the reason BERT is so slow is not actually because of the quadratic complexity of standard Attention.
Intuitively, shouldn't replacing quadratic complexity with linear complexity bring a "massive leap"? Why would it "fall far short of expectations"? The main source of this confusion is that we have never carefully evaluated the overall computational cost of a standard Transformer model such as BERT.
Many readers already know that the Transformer structure generally consists of an Embedding layer plus several Transformer layers. The computational load of the Embedding layer is minimal; we mainly care about the Transformer layers. Ignoring layers with relatively small computational loads like residuals and Layer Normalization, each Transformer layer is mainly composed of two sub-layers: Self Attention (SA) and the FeedForward Network (FFN). Although the seminal Transformer paper claimed that "Attention is all you need," many subsequent works have demonstrated the necessity of modules like residuals and FFN, such as "Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth".
Now, a question for everyone:
Which do you think has a larger computational load, SA or FFN?
Admittedly, SA has complexity $\mathcal{O}(n^2)$ while FFN has complexity $\mathcal{O}(n)$. But if you conclude from this alone that SA costs more than FFN, you would be wrong!
Multiplication is generally more expensive than addition, so when estimating computational cost we mainly count the number of multiplications. In neural networks the dominant operations are matrix multiplications, and by definition, multiplying an $a \times b$ matrix by a $b \times c$ matrix takes $abc$ multiplications. We therefore take $abc$ as the cost of multiplying two matrices; this is the basis for our estimate of the Transformer's complexity.
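As a minimal illustration of this counting rule, here is a pure-Python sketch (the name `matmul_count` is just for illustration) of the naive triple-loop product. It performs exactly one multiplication per $(i, j, k)$ triple, i.e. $abc$ in total.

```python
def matmul_count(A, B):
    """Naive triple-loop matrix product that also counts scalar multiplications."""
    a, b = len(A), len(A[0])
    b2, c = len(B), len(B[0])
    assert b == b2, "inner dimensions must match"
    C = [[0] * c for _ in range(a)]
    mults = 0
    for i in range(a):
        for j in range(c):
            s = 0
            for k in range(b):
                s += A[i][k] * B[k][j]
                mults += 1  # one multiplication per (i, j, k) triple
            C[i][j] = s
    return C, mults

A = [[1, 2, 3], [4, 5, 6]]    # 2 x 3
B = [[1, 0], [0, 1], [1, 1]]  # 3 x 2
C, mults = matmul_count(A, B)
print(C, mults)  # [[4, 5], [10, 11]] 12, and 12 = 2 * 3 * 2
```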
Let $n$ be the sequence length, $d$ the head_size (64 in the base version), and $h$ the number of heads (12 in the base version), so that $hd$ is what we usually call the "hidden_size" (768 in the base version). For SA, first come the $Q, K, V$ projections: an $n \times hd$ matrix multiplied by an $hd \times hd$ matrix, performed 3 times, for a cost of $3n(hd)^2$. Next come the $h$ Attention heads. Each head multiplies an $n \times d$ matrix $Q$ by a $d \times n$ matrix $K^{\top}$ to get an $n \times n$ Attention matrix (ignoring softmax and normalization for now), and then multiplies this $n \times n$ matrix by an $n \times d$ matrix $V$ to get an $n \times d$ matrix. Each of these two steps costs $n^2 d$, so all heads together cost $h(n^2 d + n^2 d)$. Finally, the output goes through one more projection, an $n \times hd$ matrix multiplied by an $hd \times hd$ matrix, costing $n(hd)^2$. The total cost of SA is therefore:
\begin{equation}3n(hd)^2 + h(n^2 d + n^2 d) + n(hd)^2 = 4nh^2 d^2 + 2n^2 hd\end{equation}
FFN is simpler. It consists of two fully connected layers, i.e. two matrix multiplications (ignoring the cost of the activation function). With the usual setting, the first layer multiplies an $n \times hd$ matrix by an $hd \times 4hd$ matrix, and the second multiplies an $n \times 4hd$ matrix by a $4hd \times hd$ matrix. The total cost is therefore:
\begin{equation}n\times hd\times 4hd + n\times 4hd\times hd = 8nh^2 d^2\end{equation}
Hence, for SA to cost more than FFN, we would need:
\begin{equation}4nh^2 d^2 + 2n^2 hd > 8nh^2 d^2\quad\Leftrightarrow\quad n > 2hd\end{equation}
For the base version, this means $n > 1536$! In other words, only when the sequence length exceeds 1536 does SA cost more than FFN; below that, the linear-complexity FFN dominates!
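The two formulas can be checked numerically with a short sketch. The function names are illustrative, the defaults are the base settings $h=12$, $d=64$, and softmax, biases, and activations are ignored exactly as in the derivation. The SA/FFN ratio reaches 1 precisely at $n = 2hd = 1536$.

```python
def sa_cost(n, h=12, d=64):
    """Multiplications in one Self Attention sub-layer (softmax etc. ignored)."""
    hd = h * d
    projections = 3 * n * hd ** 2         # Q, K, V projections
    heads = h * (n * n * d + n * n * d)   # Q K^T and (attention matrix) V, per head
    output = n * hd ** 2                  # output projection
    return projections + heads + output

def ffn_cost(n, h=12, d=64):
    """Multiplications in one FFN sub-layer with inner size 4*h*d."""
    hd = h * d
    return n * hd * 4 * hd + n * 4 * hd * hd

# SA/FFN ratio: about 0.67 at n=512, exactly 1.0 at n=1536, about 1.83 at n=4096
for n in (512, 1536, 4096):
    print(n, sa_cost(n) / ffn_cost(n))
```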
There's more. Adding the two results, the total cost of one Transformer layer is:
\begin{equation}4nh^2 d^2 + 2n^2 hd + 8nh^2 d^2 = 12nh^2 d^2 + 2n^2 hd\end{equation}
This is the sum of a linear term and a quadratic term in $n$. When $n$ is large enough, the complexity is naturally $\mathcal{O}(n^2)$. However, the quadratic term only dominates when:
\begin{equation}2n^2 hd > 12nh^2 d^2\quad\Leftrightarrow\quad n > 6hd\end{equation}
For the base version, this means $n > 4608$! That is, only when the sequence length approaches 5000 does the Transformer's quadratic complexity truly begin to show!
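To see how gradually the quadratic term takes over, the sketch below (function name illustrative, base settings assumed) splits the per-layer cost $12nh^2d^2 + 2n^2hd$ into its two parts and prints the quadratic share, which simplifies to $n/(6hd+n)$.

```python
def layer_cost_terms(n, h=12, d=64):
    """Split one Transformer layer's multiplications into parts linear and quadratic in n."""
    hd = h * d
    linear = 12 * n * hd ** 2   # Q/K/V/output projections (4) + FFN (8)
    quadratic = 2 * n * n * hd  # Q K^T and (attention matrix) V over all heads
    return linear, quadratic

# Quadratic share: 10.0% at 512, 25.0% at 1536, 50.0% at 4608, 78.0% at 16384
for n in (512, 1536, 4608, 16384):
    lin, quad = layer_cost_terms(n)
    print(f"n={n:6d}  quadratic share = {quad / (lin + quad):.1%}")
```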
Putting these results together: for the base version, when the sequence length does not exceed 1536, the Transformer's complexity is nearly linear. Beyond 1536, Attention gradually takes over the computational cost and the complexity slowly trends toward quadratic, until beyond 4608 the quadratic term truly dominates. Of course, these boundaries are estimates and actual numbers may deviate somewhat; read them as indicating the range and order of magnitude.
I have previously suggested to many readers that for "long text" tasks of length up to about 2000, they should simply try models without a fixed length limit, such as NEZHA or RoFormer, rather than agonize over tricks. The reason is the same: however many tricks you use, at best they bring the complexity down to linear, and within this length range the model is already nearly linear, so the tricks save little.
For readers dutifully using BERT-base, the maxlen generally does not exceed 512, which is far below the aforementioned boundaries. Therefore, please stop complaining about Attention's quadratic complexity taxing the hardware, because the truth is:
BERT is slow primarily because it is truly large, not because of the quadratic complexity of Attention.
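For a rough sense of scale, the sketch below applies the per-layer formula to a 12-layer base encoder at $n = 512$. It is only an estimate under the same assumptions as above (encoder layers only; embeddings, softmax, LayerNorm, and biases ignored; the function name is hypothetical). The linear part comes out at about 43.5 G multiplications and the quadratic part at about 4.8 G, a ratio of $6hd/n = 9$, so even eliminating the quadratic part entirely would save only about 10%.

```python
def bert_base_mults(n, layers=12, h=12, d=64):
    """Rough multiplication count for the encoder stack only (no embeddings, softmax, LayerNorm)."""
    hd = h * d
    linear = 12 * n * hd ** 2   # projections + FFN, per layer
    quadratic = 2 * n * n * hd  # attention scores and weighted sum, per layer
    return layers * linear, layers * quadratic

lin, quad = bert_base_mults(512)
print(f"linear part:    {lin / 1e9:.1f} G multiplications")   # 43.5
print(f"quadratic part: {quad / 1e9:.1f} G multiplications")  # 4.8
```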
The other reason Linear Attention "falls far short of expectations" is that its computational cost is rarely analyzed from a practical standpoint, which leads to overly high expectations.
For an introduction to Linear Attention, see "Exploration of Linear Attention: Must Attention have a Softmax?"; I will not repeat it here. In short, Linear Attention computes attention in the order $Q(K^{\top} V)$. Following the estimation method above, each head of Linear Attention costs $2nd^2$, versus $2n^2 d$ for standard Attention. So whenever $n > d$, Linear Attention is cheaper than standard Attention. (Note: there is more than one way to make Attention linear, but their complexities are generally similar, so the conclusions below are representative.)
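The effect of the reordering can be checked with a small pure-Python sketch. It deliberately omits softmax, kernel feature maps, and normalization (a simplification for illustration, not a full Linear Attention implementation), so that $(QK^{\top})V$ and $Q(K^{\top}V)$ are exactly equal by associativity; integer entries keep the comparison exact. The sizes $n=32$, $d=8$ are arbitrary small values with $n > d$.

```python
import random

def matmul(A, B):
    """Plain matrix product; also returns the number of scalar multiplications."""
    a, b, c = len(A), len(B), len(B[0])
    C = [[sum(A[i][k] * B[k][j] for k in range(b)) for j in range(c)] for i in range(a)]
    return C, a * b * c

def transpose(M):
    return [list(row) for row in zip(*M)]

random.seed(0)
n, d = 32, 8  # sequence length and head_size (small, illustrative)
Q = [[random.randint(-3, 3) for _ in range(d)] for _ in range(n)]
K = [[random.randint(-3, 3) for _ in range(d)] for _ in range(n)]
V = [[random.randint(-3, 3) for _ in range(d)] for _ in range(n)]
Kt = transpose(K)

# Standard order (softmax omitted): (Q K^T) V, via an n x n intermediate
S, m1 = matmul(Q, Kt)
out_std, m2 = matmul(S, V)
# Linear-Attention order: Q (K^T V), via a d x d intermediate
KtV, m3 = matmul(Kt, V)
out_lin, m4 = matmul(Q, KtV)

print(out_std == out_lin)  # True: same result without softmax in between
print(m1 + m2, m3 + m4)    # 16384 4096, i.e. 2*n*n*d vs 2*n*d*d
```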
For the base version, that means $n > 64$, a boundary that is easily surpassed. So some readers might think, "Every little bit counts," or "Might as well use it." However, this is based on the assumption that standard Attention and Linear Attention both use the same $d$. Readers who have carefully pondered "Performer: Linearizing Attention Complexity with Random Projections" and "Transformer Upgrade Road: 3. From Performer to Linear Attention" know that Linear Attention suffers from a more severe "low-rank bottleneck" than standard Attention. Thus, if you switch to Linear Attention while keeping the same $d$, the performance will drop significantly. To maintain roughly the same effect, Linear Attention needs a larger $d$ (generally about 4 times the original).
In that case, each head of Linear Attention costs $2n(4d)^2 = 32nd^2$, and for it to be cheaper than standard Attention we need $32nd^2 < 2n^2 d$, i.e. $n > 16d$. For the base version this means $n > 1024$, which is again beyond the range most readers deal with. Moreover, after switching to Linear Attention, the earlier conclusions about SA versus FFN still hold: for most sequence lengths, the dominant cost is still FFN and other linear operations, so switching to Linear Attention will not bring a noticeable speedup. In summary:
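The two break-even points can be found by brute force with the sketch below. The `widen` parameter is a hypothetical name for the factor by which Linear Attention's head_size is enlarged (about 4 per the text above); only the per-head attention cost is compared, not the projections or FFN.

```python
def standard_head_cost(n, d):
    """Per-head multiplications of standard Attention: 2 n^2 d."""
    return 2 * n * n * d

def linear_head_cost(n, d, widen=1):
    """Per-head multiplications of Linear Attention, Q (K^T V), with head_size widen * d."""
    dd = widen * d
    return 2 * n * dd * dd

def break_even_length(d, widen=1):
    """Smallest n at which Linear Attention is strictly cheaper per head."""
    n = 1
    while linear_head_cost(n, d, widen) >= standard_head_cost(n, d):
        n += 1
    return n

print(break_even_length(64, widen=1))  # 65, i.e. n > d
print(break_even_length(64, widen=4))  # 1025, i.e. n > 16d
```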
Unless you are dealing with sequence lengths in the tens of thousands, don't worry about switching to Linear Attention.
In fact, even without the analysis above, any reader who has seriously read the work on efficient Attention would reach a similar conclusion from the figures in those papers: so-called "efficient" Attention mechanisms generally target sequence lengths of thousands or tens of thousands, and only in such settings do they deliver a significant speedup.
For example, in the earlier work Sparse Transformers, there is a chart showing that the sequence lengths processed are all 3000+:
Sparse Transformer processes lengths that are all 3000+
Take the famous Reformer; the sequence lengths used to demonstrate performance are all in units of K (thousands):
Reformer demonstrates performance with sequence lengths in units of K
The highly praised Longformer is the same:
Longformer demonstrates performance with sequence lengths of several thousands or even tens of thousands
And then there is Google's classic work on Linear Attention, Performer, which shows that even if the sequence length is $2^{12}=4096$, the gap between Performer and Transformer isn't particularly significant:
Performer performance curve
Finally, the relatively newer work Luna provides a fairly comprehensive comparison table, which also supports our conclusion:
Comparison of performance for various improved Attention mechanisms in Luna
Taken together, existing efficient-Attention works show that the sequence lengths they target are mainly in the thousands, and a clear gain in computational efficiency generally requires lengths of several thousand. Of course, our discussion here is mainly about time complexity. For space complexity, i.e. memory usage, the reduction is typically larger than the speedup in time, but overall it is still only worthwhile for long sequences.
Therefore, if your sequence length is only one or two hundred, do not expect anything from improving Attention itself; just switch to a smaller model. You can hope that future smaller models will achieve equally good results, but do not expect a model of the same size to become much faster by modifying Attention. To put it bluntly, even if you removed Attention entirely, it would not speed things up by much.