By 苏剑林 | July 04, 2020
As is well known, although Attention-based Transformer models parallelize very well, both their time and space complexity are $\mathcal{O}(n^2)$, where $n$ is the sequence length. When $n$ is large, the computational cost of Transformer models therefore becomes prohibitive. Recently, many works have been devoted to reducing this cost, either through compression techniques such as pruning, quantization, and distillation, or by modifying the Attention structure to reduce its complexity to $\mathcal{O}(n \log n)$ or even $\mathcal{O}(n)$.
A few days ago, I read the paper "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" and learned about the explorations of Linear Attention. I then read some related literature and gained a few useful insights. This article summarizes my understanding of Linear Attention.
The most popular Attention mechanism today is Scaled-Dot Attention, which has the form: \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}\boldsymbol{K}^{\top}\right)\boldsymbol{V}\label{eq:std-att}\end{equation} where $\boldsymbol{Q}\in\mathbb{R}^{n\times d_k}, \boldsymbol{K}\in\mathbb{R}^{m\times d_k}, \boldsymbol{V}\in\mathbb{R}^{m\times d_v}$. For simplicity, we omit the scaling factor $1/\sqrt{d_k}$ inside the Softmax. This article is mainly concerned with Self-Attention, so for convenience we take $\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}\in\mathbb{R}^{n\times d}$. In typical settings, $n > d$ or even $n \gg d$ (in BERT-base, $d=64$ per head). For background, see my "Attention is All You Need" Brief Reading (Introduction + Code), as well as improvement works such as "Breaking the Bottleneck: Building More Powerful Transformers" and "Google's Synthesizer: We Don't Know Enough About Self-Attention Yet". We won't go into further detail here.
Readers may not have realized that the key factor limiting the efficiency of Attention is actually the Softmax in its definition! A simple derivation shows why. The step $\boldsymbol{Q}\boldsymbol{K}^{\top}$ produces an $n\times n$ matrix, and it is this step that makes the complexity of Attention $\mathcal{O}(n^2)$. Without Softmax, we would have a plain product of three matrices, $\boldsymbol{Q}\boldsymbol{K}^{\top}\boldsymbol{V}$. Since matrix multiplication is associative, we can first compute $\boldsymbol{K}^{\top}\boldsymbol{V}$ to obtain a $d\times d$ matrix, and then left-multiply it by $\boldsymbol{Q}$. Each of these two products costs $\mathcal{O}(nd^2)$, so since $d \ll n$, the overall complexity is roughly $\mathcal{O}(n)$.
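A quick numerical check of the associativity argument, as a minimal NumPy sketch (the random Gaussian $\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}$ and the sizes $n=1000$, $d=64$ are arbitrary choices for illustration): both orders of multiplication give the same result, but only the second avoids forming an $n\times n$ matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 1000, 64
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

# Left-to-right: Q K^T is an n x n matrix, so this costs O(n^2 d) time and O(n^2) memory.
out_quadratic = (Q @ K.T) @ V

# Right-to-left: K^T V is only d x d, so both products cost O(n d^2).
KV = K.T @ V
out_linear = Q @ KV

print(KV.shape)                                # (64, 64)
print(np.allclose(out_quadratic, out_linear))  # True
```

Once a Softmax is inserted between $\boldsymbol{Q}\boldsymbol{K}^{\top}$ and $\boldsymbol{V}$, this reordering is no longer possible, which is the whole problem.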
In other words, removing Softmax can reduce the complexity of Attention to the ideal linear level $\mathcal{O}(n)$! This is clearly what we are ultimately after: Linear Attention, that is, Attention with linear complexity. The theme of this article is therefore to explore Linear Attention without Softmax.
The question is: can it still be called Attention once Softmax is removed, and can it still work as well as standard Attention? To answer this, let's rewrite the definition of Scaled-Dot Attention \eqref{eq:std-att} in an equivalent form (vectors in this article are column vectors): \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\boldsymbol{v}_j}{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}}\label{eq:std-att-2}\end{equation} So Scaled-Dot Attention is essentially a weighted average of the $\boldsymbol{v}_j$ with weights $e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}$. This suggests a generalized definition of Attention: \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\label{eq:gen-att}\end{equation} that is, $e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}$ is replaced by a general function $\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)$ of $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$. To keep Attention's character as a weighted average, so that the normalized weights still form a probability distribution, we require $\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0$ to hold at all times. In other words, any new type of Attention we define must retain the form of Equation \eqref{eq:gen-att} and satisfy $\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0$.
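To make Equation \eqref{eq:gen-att} concrete, here is a minimal NumPy sketch that evaluates it for an arbitrary similarity function passed in as a Python callable (the helper names `general_attention`, `softmax_attention`, and `exp_sim` are illustrative, not from any library). Plugging in $\text{sim}(\boldsymbol{q},\boldsymbol{k})=e^{\boldsymbol{q}^{\top}\boldsymbol{k}}$ recovers standard (unscaled) Softmax Attention.

```python
import numpy as np

def softmax_attention(Q, K, V):
    """Standard attention softmax(Q K^T) V, without the 1/sqrt(d) scaling."""
    scores = Q @ K.T
    scores = scores - scores.max(axis=1, keepdims=True)  # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V

def general_attention(Q, K, V, sim):
    """Row i is the sim(q_i, k_j)-weighted average of the rows v_j."""
    out = np.zeros((Q.shape[0], V.shape[1]))
    for i, q in enumerate(Q):
        w = np.array([sim(q, k) for k in K])  # weights, required to be >= 0
        out[i] = (w[:, None] * V).sum(axis=0) / w.sum()
    return out

def exp_sim(q, k):
    return np.exp(q @ k)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))
print(np.allclose(general_attention(Q, K, V, exp_sim), softmax_attention(Q, K, V)))  # True
```

A raw inner product `lambda q, k: q @ k` would also run here, but its weights can be negative, so the result is no longer a weighted average.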
In Computer Vision (CV), this general form of Attention is also known as a Non-Local network, after the paper "Non-local Neural Networks".
If Softmax is simply removed, then $\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \boldsymbol{q}_i^{\top}\boldsymbol{k}_j$. The problem is that an inner product is not guaranteed to be non-negative, so this is not a reasonable choice. Below we briefly introduce several viable alternatives.
Note that the first two Linear Attentions introduced below come from CV, and the third is my own idea, so none of them has yet been extensively tested on NLP tasks. For NLPers working on model improvements, they are possible directions for experiments (^_^). Incidentally, the CV field has many works on improving Attention (besides those introduced below, there is also EMANet, among others), and much of that work is worth learning from for those of us in NLP.
A natural idea: if every element of $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ is non-negative, their inner product is automatically non-negative. To achieve this, we can apply activation functions $\phi$ and $\varphi$ to $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ respectively, i.e., \begin{equation}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\label{eq:gen-att-2}\end{equation} where $\phi(\cdot), \varphi(\cdot)$ are activation functions with non-negative ranges. The paper mentioned at the beginning, "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention", chooses $\phi(x)=\varphi(x)=\text{elu}(x)+1$.
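A minimal NumPy sketch of Equation \eqref{eq:gen-att} with the similarity \eqref{eq:gen-att-2}, using $\phi=\varphi=\text{elu}(x)+1$ as in the paper cited above (the function names are illustrative). The point is that $\sum_j \varphi(\boldsymbol{k}_j)\boldsymbol{v}_j^{\top}$ and $\sum_j \varphi(\boldsymbol{k}_j)$ are computed once and shared by all queries, so no $n\times n$ matrix is formed; the quadratic version is included only to check the result.

```python
import numpy as np

def elu_plus_one(x):
    # elu(x) + 1 equals x + 1 for x > 0 and exp(x) for x <= 0, so it is always positive
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V, phi=elu_plus_one, varphi=elu_plus_one):
    """Non-causal attention with sim(q, k) = phi(q)^T varphi(k), linear in n."""
    Qf, Kf = phi(Q), varphi(K)   # n x d, all entries positive
    KV = Kf.T @ V                # d x d_v: sum_j varphi(k_j) v_j^T
    z = Kf.sum(axis=0)           # d:       sum_j varphi(k_j)
    return (Qf @ KV) / (Qf @ z)[:, None]

def quadratic_reference(Q, K, V, phi=elu_plus_one, varphi=elu_plus_one):
    """Same attention, but materializing the n x n weight matrix."""
    W = phi(Q) @ varphi(K).T     # n x n, all entries positive
    return (W @ V) / W.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((500, 16)) for _ in range(3))
print(np.allclose(linear_attention(Q, K, V), quadratic_reference(Q, K, V)))  # True
```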
Equation \eqref{eq:gen-att-2} is reminiscent of kernel methods. In particular, when $\phi=\varphi$, $\phi$ plays the role of a kernel's feature map, and $\langle \phi(\boldsymbol{q}_i), \phi(\boldsymbol{k}_j)\rangle$ is the inner product defined by the corresponding kernel function. For more on this line of thinking, see the paper "Transformer dissection: An unified understanding for transformer’s attention via the lens of kernel". We won't expand further here.
An earlier paper, "Efficient Attention: Attention with Linear Complexities", makes a more interesting choice. It observes that in $\boldsymbol{Q}\boldsymbol{K}^{\top}$, where $\boldsymbol{Q}, \boldsymbol{K} \in\mathbb{R}^{n\times d}$, if "$\boldsymbol{Q}$ is normalized along the $d$ dimension and $\boldsymbol{K}$ is normalized along the $n$ dimension," then $\boldsymbol{Q}\boldsymbol{K}^{\top}$ is automatically normalized, i.e., each of its rows sums to 1. The choice it proposes is therefore: \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax_2\left(\boldsymbol{Q}\right)softmax_1(\boldsymbol{K})^{\top}\boldsymbol{V}\end{equation} where $softmax_1$ and $softmax_2$ denote Softmax along the first ($n$) and second ($d$) dimensions, respectively. In other words, Softmax is applied to $\boldsymbol{Q}$ and $\boldsymbol{K}$ separately, instead of to $\boldsymbol{Q}\boldsymbol{K}^{\top}$ after it is computed. The normalization follows directly: writing $\tilde{\boldsymbol{Q}}=softmax_2(\boldsymbol{Q})$ and $\tilde{\boldsymbol{K}}=softmax_1(\boldsymbol{K})$, row $i$ of $\tilde{\boldsymbol{Q}}\tilde{\boldsymbol{K}}^{\top}$ sums to $\sum_{j}\sum_{l}\tilde{Q}_{il}\tilde{K}_{jl}=\sum_{l}\tilde{Q}_{il}\sum_{j}\tilde{K}_{jl}=\sum_{l}\tilde{Q}_{il}=1$.
If we take $\phi(\boldsymbol{q}_i)=softmax(\boldsymbol{q}_i)$ and $\varphi(\boldsymbol{k}_j)=softmax(\boldsymbol{k}_j)$, this form is clearly also a special case of Equation \eqref{eq:gen-att-2}. Moreover, this design has appeared more than once in CV; A2-Nets, for example, uses the same approach.
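The automatic row normalization is easy to check numerically. Below is a minimal NumPy sketch (helper names are illustrative, and the paper's actual implementation may differ in details such as scaling): it applies $softmax_2$ to $\boldsymbol{Q}$ and $softmax_1$ to $\boldsymbol{K}$, computes the $d\times d$ product $softmax_1(\boldsymbol{K})^{\top}\boldsymbol{V}$ first, and builds the $n\times n$ matrix only to confirm that its rows sum to 1.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def efficient_attention(Q, K, V):
    q = softmax(Q, axis=1)  # softmax_2: each row of Q sums to 1 over the d features
    k = softmax(K, axis=0)  # softmax_1: each column of K sums to 1 over the n positions
    return q @ (k.T @ V)    # d x d context first, so the cost is linear in n

rng = np.random.default_rng(0)
n, d = 300, 32
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))

W = softmax(Q, axis=1) @ softmax(K, axis=0).T  # n x n, built only for this check
print(np.allclose(W.sum(axis=1), 1.0))                   # True
print(np.allclose(efficient_attention(Q, K, V), W @ V))  # True
```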
Here I present an idea of my own. Its starting point is not Equation \eqref{eq:gen-att-2} but a first-order Taylor approximation of the original definition \eqref{eq:std-att-2}: \begin{equation}e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \approx 1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\end{equation} As long as $\boldsymbol{q}_i^{\top}\boldsymbol{k}_j \geq -1$, the right-hand side is non-negative, so we can take $\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)=1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j$. Readers may already see that to guarantee $\boldsymbol{q}_i^{\top}\boldsymbol{k}_j \geq -1$, it suffices to apply $l_2$ normalization to $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ separately, since by the Cauchy-Schwarz inequality the inner product of two unit vectors lies in $[-1, 1]$. The scheme I finally propose is therefore: \begin{equation}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = 1 + \left( \frac{\boldsymbol{q}_i}{\Vert \boldsymbol{q}_i\Vert}\right)^{\top}\left(\frac{\boldsymbol{k}_j}{\Vert \boldsymbol{k}_j\Vert}\right)\end{equation} This differs from the form of Equation \eqref{eq:gen-att-2}, but in theory it is closer to the original Scaled-Dot Attention.
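Because $1+\hat{\boldsymbol{q}}_i^{\top}\hat{\boldsymbol{k}}_j$ is linear in $\hat{\boldsymbol{k}}_j$ (writing $\hat{\boldsymbol{q}}_i,\hat{\boldsymbol{k}}_j$ for the normalized vectors), the sums over $j$ can still be precomputed, so this scheme also runs in linear time. A minimal NumPy sketch with illustrative function names: the numerator splits into $\sum_j \boldsymbol{v}_j$ plus $\hat{\boldsymbol{q}}_i^{\top}\sum_j \hat{\boldsymbol{k}}_j\boldsymbol{v}_j^{\top}$, and the denominator into $n+\hat{\boldsymbol{q}}_i^{\top}\sum_j \hat{\boldsymbol{k}}_j$. It assumes no $\boldsymbol{q}_i$ or $\boldsymbol{k}_j$ is the zero vector.

```python
import numpy as np

def l2_normalize(X):
    # Assumes no row of X is exactly zero.
    return X / np.linalg.norm(X, axis=1, keepdims=True)

def cosine_linear_attention(Q, K, V):
    """Attention with sim(q, k) = 1 + q_hat^T k_hat, computed without an n x n matrix."""
    Qh, Kh = l2_normalize(Q), l2_normalize(K)
    n = K.shape[0]
    num = V.sum(axis=0)[None, :] + Qh @ (Kh.T @ V)  # sum_j (1 + q_hat^T k_hat_j) v_j
    den = n + Qh @ Kh.sum(axis=0)                   # sum_j (1 + q_hat^T k_hat_j)
    return num / den[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((200, 16)) for _ in range(3))

W = 1.0 + l2_normalize(Q) @ l2_normalize(K).T  # n x n reference weights
print(W.min() >= 0)  # True
print(np.allclose(cosine_linear_attention(Q, K, V), (W @ V) / W.sum(axis=1, keepdims=True)))  # True
```

The weights are non-negative by construction, but the denominator can vanish in the degenerate case where every $\hat{\boldsymbol{k}}_j=-\hat{\boldsymbol{q}}_i$; the sketch does not guard against that.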
There are many works that modify the form of Attention to reduce its computational complexity; a few are briefly listed here.
We have previously introduced OpenAI's Sparse Attention, which reduces the computation of Attention by "keeping only the values within a small region and forcing most attention weights to zero." With this design, most elements of the Attention matrix are 0, so in theory it saves both memory and computation. Later works in a similar vein include "Explicit Sparse Transformer: Concentrated Attention Through Explicit Selection" and "Longformer: The Long-Document Transformer".
However, this approach clearly has two shortcomings:
1. Which attention regions to retain is decided by hand, so the pattern is fixed in advance rather than adaptive;
2. An efficient implementation requires dedicated low-level engineering, so the approach is not easy to adopt widely.
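As a concrete example of a hand-designed sparsity pattern, here is a minimal NumPy sketch of a sliding-window pattern (each token attends only to tokens within `window` positions), similar in spirit to Longformer's local attention. It is not OpenAI's exact strided pattern, and `local_window_attention` is an illustrative name. Note that it applies the pattern as a dense mask, so it still builds the full $n\times n$ score matrix and saves nothing; realizing the savings requires the kind of specialized implementation described in the second shortcoming.

```python
import numpy as np

def local_window_attention(Q, K, V, window):
    """Softmax attention restricted to |i - j| <= window, implemented with a dense mask."""
    n = Q.shape[0]
    idx = np.arange(n)
    allowed = np.abs(idx[:, None] - idx[None, :]) <= window  # banded n x n pattern, fixed in advance
    scores = np.where(allowed, Q @ K.T, -np.inf)             # masked entries get weight exp(-inf) = 0
    scores = scores - scores.max(axis=1, keepdims=True)      # diagonal is always allowed, so max is finite
    w = np.exp(scores)
    w /= w.sum(axis=1, keepdims=True)
    return w @ V, w

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
out, w = local_window_attention(Q, K, V, window=2)
print(np.count_nonzero(w, axis=1))  # at most 5 nonzero weights per row
```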
Reformer is another representative work, reducing the complexity of Attention to $\mathcal{O}(n\log n)$. In a sense, Reformer is also a kind of Sparse Attention, but its sparsity pattern is not predetermined: it uses LSH (Locality Sensitive Hashing) to quickly (and approximately) find the largest few Attention values and then computes only those. In addition, Reformer reduces memory consumption by replacing the original FFN (Feedforward Network) with a reversible one and redesigning the backpropagation process accordingly.
So, compared with the Sparse Attention above, Reformer solves the first shortcoming but still has the second: high implementation complexity. LSH-based Attention is much more complicated to implement than standard Attention, and rewriting backpropagation for a reversible network is even further beyond the reach of most readers.
The work most similar to the Linear Attention discussed in this article is Linformer, recently released by Facebook. It retains the original Scaled-Dot Attention form, but before computing Attention it projects $\boldsymbol{K}$ and $\boldsymbol{V}$ with two $m\times n$ matrices $\boldsymbol{E}$ and $\boldsymbol{F}$ respectively, giving: \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}(\boldsymbol{E}\boldsymbol{K})^{\top}\right)\boldsymbol{F}\boldsymbol{V}\end{equation} Now $\boldsymbol{Q}(\boldsymbol{E}\boldsymbol{K})^{\top}$ is only an $n\times m$ matrix, and the authors claim that $m$ can stay at a suitable constant even for very large sequence lengths $n$, which makes this Attention linear. Ideas similar to Linformer also appeared earlier in the CV paper "Asymmetric Non-local Neural Networks for Semantic Segmentation".
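A minimal NumPy sketch of the Linformer-style projection, mainly to make the shapes concrete. In Linformer, $\boldsymbol{E}$ and $\boldsymbol{F}$ are learned parameters; here they are replaced by scaled random matrices purely as an illustrative assumption, and `linformer_attention` is a hypothetical name.

```python
import numpy as np

def softmax_rows(x):
    x = x - x.max(axis=1, keepdims=True)  # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)

def linformer_attention(Q, K, V, E, F):
    """softmax(Q (E K)^T) (F V) with E, F of shape (m, n)."""
    K_proj = E @ K                  # m x d: sequence axis shortened from n to m
    V_proj = F @ V                  # m x d
    W = softmax_rows(Q @ K_proj.T)  # n x m instead of n x n
    return W @ V_proj               # n x d

rng = np.random.default_rng(0)
n, m, d = 4096, 128, 64
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
# Stand-ins for the learned projections (assumption for illustration only).
E = rng.standard_normal((m, n)) / np.sqrt(n)
F = rng.standard_normal((m, n)) / np.sqrt(n)
out = linformer_attention(Q, K, V, E, F)
print(out.shape)  # (4096, 64)
```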
However, I think the claim that "$m$ can stay constant for very long sequences" is questionable. For long sequences, the original paper only ran MLM tasks, and MLM clearly does not depend heavily on long-range dependencies, so this experiment is not very convincing. Whether Linformer is truly linear therefore remains open to debate.
Another drawback of Linformer is that $\boldsymbol{E}\boldsymbol{K}$ and $\boldsymbol{F}\boldsymbol{V}$ directly "mix" information from the entire sequence, so future information cannot simply be masked out (Causal Masking). As a result, it cannot be used for autoregressive generation tasks such as language modeling or Seq2Seq, which is also why the original authors only ran MLM tasks. By contrast, the Linear Attentions introduced in this article can achieve this. Taking Equation \eqref{eq:gen-att} with the similarity \eqref{eq:gen-att-2} as an example, masking future information only requires changing the sum $\sum\limits_{j=1}^n$ to $\sum\limits_{j=1}^i$: \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^i \left(\phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\right)\boldsymbol{v}_j}{\sum\limits_{j=1}^i \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)}=\frac{ \phi(\boldsymbol{q}_i)^{\top} \sum\limits_{j=1}^i\varphi(\boldsymbol{k}_j)\boldsymbol{v}_j^{\top}}{ \phi(\boldsymbol{q}_i)^{\top} \sum\limits_{j=1}^i\varphi(\boldsymbol{k}_j)}\end{equation}
There are two ways to implement this equation. The first is to set $\boldsymbol{S}_i = \sum\limits_{j=1}^i \varphi(\boldsymbol{k}_j)\boldsymbol{v}_j^{\top}$ and $\boldsymbol{z}_i = \sum\limits_{j=1}^i \varphi(\boldsymbol{k}_j)$. Then: \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i =\frac{ \phi(\boldsymbol{q}_i)^{\top} \boldsymbol{S}_i}{ \phi(\boldsymbol{q}_i)^{\top} \boldsymbol{z}_i},\quad \begin{aligned}&\boldsymbol{S}_i=\boldsymbol{S}_{i-1}+\varphi(\boldsymbol{k}_i)\boldsymbol{v}_i^{\top}\\ &\boldsymbol{z}_i=\boldsymbol{z}_{i-1}+\varphi(\boldsymbol{k}_i) \end{aligned}\end{equation} This shows that this type of Attention can be implemented recursively as an RNN. Its space complexity is the lowest, but it must be computed sequentially, which makes it suitable for decoding at inference time.
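The recursive form can be sketched in a few lines of NumPy, assuming $\phi=\varphi=\text{elu}(x)+1$ (the function names are illustrative). The state $(\boldsymbol{S}_i,\boldsymbol{z}_i)$ has a fixed size of $d\times d_v$ plus $d$ regardless of $i$, which is where the low memory footprint comes from; a masked quadratic version is included only as a check.

```python
import numpy as np

def elu_plus_one(x):
    return np.where(x > 0, x + 1.0, np.exp(x))

def causal_linear_attention_rnn(Q, K, V, phi=elu_plus_one):
    """Causal linear attention via S_i = S_{i-1} + phi(k_i) v_i^T, z_i = z_{i-1} + phi(k_i)."""
    n, d = Q.shape
    S = np.zeros((d, V.shape[1]))  # running sum of outer products phi(k_j) v_j^T
    z = np.zeros(d)                # running sum of phi(k_j)
    out = np.empty((n, V.shape[1]))
    for i in range(n):
        kf = phi(K[i])
        S += np.outer(kf, V[i])
        z += kf
        qf = phi(Q[i])
        out[i] = (qf @ S) / (qf @ z)
    return out

def causal_reference(Q, K, V, phi=elu_plus_one):
    """Quadratic version: zero out j > i in the n x n weight matrix."""
    W = np.tril(phi(Q) @ phi(K).T)
    return (W @ V) / W.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((128, 16)) for _ in range(3))
print(np.allclose(causal_linear_attention_rnn(Q, K, V), causal_reference(Q, K, V)))  # True
```

At decoding time only `S` and `z` need to be carried from step to step, much like an RNN hidden state.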
The second way is to directly take the outer products of the corresponding rows of $\varphi(\boldsymbol{K}), \boldsymbol{V} \in\mathbb{R}^{n\times d}$, giving an $n\times d\times d$ tensor, and then apply a $cumsum$ along the $n$ dimension, which yields $\boldsymbol{S}_1, \boldsymbol{S}_2, \dots, \boldsymbol{S}_n$ all at once. This is the fastest approach but uses the most memory, so it suits training. However, this tensor has $nd^2$ entries, versus $n^2$ for the standard attention matrix, and since $d^2 \gg n$ in many cases, this memory cost is usually hard to afford during training; in practice the RNN form is therefore still mostly used.
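The parallel $cumsum$ form, as a minimal NumPy sketch under the same assumption $\phi=\varphi=\text{elu}(x)+1$ (function names are illustrative). `np.einsum` builds the $n\times d\times d_v$ tensor of outer products $\varphi(\boldsymbol{k}_j)\boldsymbol{v}_j^{\top}$, and `np.cumsum` along axis 0 turns it into $\boldsymbol{S}_1,\dots,\boldsymbol{S}_n$ in one shot; that tensor is exactly the $nd^2$ memory cost discussed above.

```python
import numpy as np

def elu_plus_one(x):
    return np.where(x > 0, x + 1.0, np.exp(x))

def causal_linear_attention_cumsum(Q, K, V, phi=elu_plus_one):
    """Parallel form: build all outer products, then prefix-sum along the sequence axis."""
    Qf, Kf = phi(Q), phi(K)
    S = np.cumsum(np.einsum('nd,ne->nde', Kf, V), axis=0)  # n x d x d_v: S_1, ..., S_n
    z = np.cumsum(Kf, axis=0)                              # n x d:       z_1, ..., z_n
    num = np.einsum('nd,nde->ne', Qf, S)                   # phi(q_i)^T S_i for every i
    den = np.einsum('nd,nd->n', Qf, z)                     # phi(q_i)^T z_i for every i
    return num / den[:, None]

rng = np.random.default_rng(0)
n, d = 256, 16
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))

W = np.tril(elu_plus_one(Q) @ elu_plus_one(K).T)  # masked n x n reference
ref = (W @ V) / W.sum(axis=1, keepdims=True)
out = causal_linear_attention_cumsum(Q, K, V)
print(np.allclose(out, ref))  # True
```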
In effect, Linformer's $\boldsymbol{E}\boldsymbol{K}, \boldsymbol{F}\boldsymbol{V}$ simply shorten the sequence (downsampling). The most basic way to shorten a sequence is Pooling, so I previously tried introducing Pooling into Transformers. Similar works have recently appeared as well, such as IBM's "PoWER-BERT: Accelerating BERT Inference via Progressive Word-vector Elimination" and Google's "Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing". Besides Pooling, there are other downsampling techniques, such as 1D convolution with stride > 1. Following this idea, perhaps we could replace the position-wise fully connected layers of the FFN with 1D convolutions of stride > 1? In short, there is plenty of room for variation here, but as with Linformer, autoregressive generation becomes difficult once the sequence has been mixed in this way.
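One way to see the connection between Linformer and Pooling: average pooling along the sequence axis (kernel size equal to stride $s$) is exactly a fixed, parameter-free choice of the projection matrix $\boldsymbol{E}$ (or $\boldsymbol{F}$). A minimal NumPy sketch, assuming $n$ is divisible by $s$; the function names and the block-averaging matrix are illustrative, not taken from any of the papers above.

```python
import numpy as np

def avg_pool_seq(X, s):
    """Average-pool the rows of X in non-overlapping groups of s (kernel = stride = s)."""
    n, d = X.shape
    return X.reshape(n // s, s, d).mean(axis=1)

def pooled_kv_attention(Q, K, V, s):
    """softmax(Q pool(K)^T) pool(V): the score matrix is n x (n/s) instead of n x n."""
    Kp, Vp = avg_pool_seq(K, s), avg_pool_seq(V, s)
    scores = Q @ Kp.T
    scores = scores - scores.max(axis=1, keepdims=True)  # for numerical stability
    w = np.exp(scores)
    w /= w.sum(axis=1, keepdims=True)
    return w @ Vp

rng = np.random.default_rng(0)
n, d, s = 64, 8, 4
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))

# The same pooling written as a Linformer-style (n/s) x n projection matrix.
E = np.kron(np.eye(n // s), np.ones((1, s)) / s)
print(E.shape)                                 # (16, 64)
print(np.allclose(E @ K, avg_pool_seq(K, s)))  # True
print(pooled_kv_attention(Q, K, V, s).shape)   # (64, 8)
```

Each pooled key averages $s$ consecutive positions, so a query inside a block would see keys from later positions in that block; this is the mixing that makes causal masking awkward.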
This article introduced some works that modify the structure of Attention to reduce its computational complexity. The main idea is that removing the Softmax from standard Attention reduces its complexity to the ideal $\mathcal{O}(n)$ level (Linear Attention). Compared with other similar modifications, this one reduces the complexity to $\mathcal{O}(n)$ while retaining all "token-token" attention and keeping autoregressive generation possible.