Chapter of Space and Time: Viewing Attention as an RNN with Squared Complexity

By 苏剑林 | March 18, 2024

In recent years, Recurrent Neural Networks (RNNs) have regained significant interest among researchers and users due to their linear training and inference efficiency, suggesting a sort of "renaissance." Representative works include RWKV, RetNet, and Mamba. When RNNs are used for language modeling, their typical characteristic is that each generation step has constant space and time complexity; for the entire sequence, this results in constant space complexity and linear time complexity. Of course, everything has two sides. Compared to the dynamically growing KV Cache of Attention, the constant space complexity of RNNs often leads people to suspect a finite memory capacity, making it difficult for them to match Attention's performance on Long Context tasks.

In this article, we show that Causal Attention can be rewritten in the form of an RNN, and its generation at each step can theoretically be performed with $\mathcal{O}(1)$ space complexity (at the cost of extremely high time complexity, far exceeding squared complexity). This indicates that the advantage of Attention (if any) is achieved through computational "stacking" rather than an intuitive "stacking" of memory; like RNNs, it essentially possesses a constant-magnitude memory capacity (memory bottleneck).

RNNs Beyond Linearity

RNN supporters often pose a seemingly undeniable point: think about your brain—is it an RNN or Attention?

Intuitively, the space complexity of RNN inference is constant, while Attention's KV cache grows dynamically. Given that human brain capacity is finite, RNNs indeed seem closer to the human brain from this perspective. However, even if we reasonably believe that brain capacity limits the space complexity of each inference step to a constant, it does not limit the time complexity of each step to be constant. Or to put it another way, even if each time step for a human is constant, humans do not necessarily scan a sequence of length $L$ only once (like "flipping through a book"). Thus, the total number of inference steps might significantly exceed $L$, leading to non-linear time complexity.

Considering this, the author had a "sudden epiphany": can we generalize RNN models to consider constant space complexity but non-linear time complexity to compensate for the abilities mainstream RNNs lack (such as the aforementioned "page-flipping")? For a language modeling task, assuming the sample is "a b c d e," the training task is to input "a b c d" and predict "b c d e." A common RNN is shown below:

Figure 1: Common RNN

The problem with this kind of RNN is the lack of "page-flipping" ability; each input is discarded after being read. The characteristic of Attention is that for every token read, it completely "flips through" the entire history. While this approach may have efficiency issues, it is undoubtedly the simplest and most brutal way to introduce page-flipping capability. To give RNNs this ability, we can perfectly imitate Attention's approach to using RNNs:

Figure 2: Continuously "page-flipping" RNN

Like Attention, for every new token read, it flips through the complete history once. Of course, one could argue this isn't a new RNN design but rather a new way of using RNNs, simply modifying the input—whether it's RWKV or Mamba, they can all be adapted to this. Under this usage, decoding can still be completed within constant space complexity, but the time complexity of each inference step grows linearly, resulting in a total time cost of $\mathcal{O}(L^2)$.
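To make this usage pattern concrete, below is a minimal numpy sketch (the cell `rnn_cell`, the toy embeddings and all names are illustrative rather than any particular RNN architecture, and the exact wiring of Figure 2 may differ in detail): before each prediction, the full history is re-scanned with a single fixed-size state.

```python
import numpy as np

d = 8                                        # state size: constant, independent of sequence length
rng = np.random.default_rng(0)
W_h = rng.normal(size=(d, d)) / np.sqrt(d)
W_x = rng.normal(size=(d, d)) / np.sqrt(d)

def rnn_cell(state, x):
    """A generic recurrent cell: new state from previous state and current input."""
    return np.tanh(state @ W_h + x @ W_x)

def state_after_rescan(history):
    """Before predicting the next token, re-read ("flip through") the entire history,
    carrying only one fixed-size state: O(len(history)) time, O(1) extra space."""
    state = np.zeros(d)
    for x in history:
        state = rnn_cell(state, x)
    return state

# Toy embeddings for the sample "a b c d": each prediction step re-scans its prefix,
# so the whole sequence costs O(L^2) time while the working memory stays constant.
tokens = rng.normal(size=(4, d))
states = [state_after_rescan(tokens[:i + 1]) for i in range(len(tokens))]
```

The working set is one $d$-dimensional state no matter how long the history grows, while the total time over a length-$L$ sequence is $1 + 2 + \cdots + L = \mathcal{O}(L^2)$.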

Attention is Also an RNN

In fact, the model represented by Figure 2 is very broad; even Attention is merely a special case of it, as shown below:

Figure 3: RNN corresponding to Causal Attention

Compared to Figure 2, several arrows in Figure 3 are faded, representing that these positions are actually disconnected, which illustrates how Attention is a special case of Figure 2. Specifically, the calculation formula for Attention is:

\begin{equation}o_i = \sum_{j=1}^i a_{i,j}v_j = \frac{\sum_{j=1}^i e^{q_i\cdot k_j} v_j}{\sum_{j=1}^i e^{q_i\cdot k_j}}\end{equation}
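As a concrete reference, equation (1) can be evaluated directly as in the following numpy sketch (single head, no $1/\sqrt{d}$ scaling, exactly as the formula is written; the function name is ours):

```python
import numpy as np

def causal_attention_naive(Q, K, V):
    """Equation (1): o_i is the exp(q_i . k_j)-weighted average of v_1..v_i."""
    L = Q.shape[0]
    O = np.zeros_like(V)
    for i in range(L):
        w = np.exp(Q[i] @ K[:i + 1].T)        # unnormalized weights over the causal prefix
        O[i] = (w @ V[:i + 1]) / w.sum()      # weighted average of v_1..v_i
    return O
```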

Evidently, the sums in the numerator and denominator can be written in recursive form (with initial values $y_i^{(0)} = z_i^{(0)} = 0$):

\begin{equation} \begin{pmatrix} y_i^{(t)} \\ z_i^{(t)} \end{pmatrix} = \begin{pmatrix} y_i^{(t-1)} \\ z_i^{(t-1)} \end{pmatrix} + e^{q_i\cdot k_{i-t+1}}\begin{pmatrix} v_{i-t+1} \\ 1 \end{pmatrix}\quad,\quad o_i = \frac{y_i^{(i)}}{z_i^{(i)}} \end{equation}

As far as the author is aware, the first paper to propose the above recursion and use it to optimize Attention computation is "Self-attention Does Not Need $O(n^2)$ Memory"; the blockwise version of this recursion is the theoretical foundation of mainstream acceleration techniques such as Flash Attention. Since in Self Attention, $Q, K, V$ are all obtained from the same input through token-wise operations, this recursive form can be represented exactly as Figure 3.
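The same outputs can be computed per equation (2), keeping only two fixed-size accumulators per query instead of the whole score row. The sketch below reuses `causal_attention_naive` from the previous snippet only for the final check; practical implementations such as the cited paper's additionally track a running maximum of $q_i\cdot k_j$ for numerical stability, which is omitted here.

```python
import numpy as np

def causal_attention_recursive(Q, K, V):
    """Equation (2): accumulate (y, z) over the history, then output o_i = y / z.
    Per query, only two fixed-size accumulators are kept, not the full score row."""
    L, d = V.shape
    O = np.zeros_like(V)
    for i in range(L):
        y, z = np.zeros(d), 0.0              # y_i^{(0)} = 0, z_i^{(0)} = 0
        for j in range(i, -1, -1):           # the t-loop of equation (2): k_{i-t+1} walks the history newest to oldest
            w = np.exp(Q[i] @ K[j])
            y += w * V[j]
            z += w
        O[i] = y / z
    return O

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 6, 4))         # three (L=6, d=4) arrays
assert np.allclose(causal_attention_recursive(Q, K, V),
                   causal_attention_naive(Q, K, V))
```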

Of course, Figure 3 only illustrates a single layer of Attention. Multiple layers can naturally be drawn, though the connections become somewhat complex. For instance, the case for two layers is shown below:

Figure 4: RNN corresponding to two-layer Attention

Constant Space Complexity

As mentioned at the beginning of this article, a common advantage of RNNs is the ability to perform inference with constant space complexity and linear time complexity. Since Attention can also be written as an RNN, the natural question is: does it also possess these two advantages in this form?

Clearly, since the RNN corresponding to Attention has a sequence length increased to $\mathcal{O}(L^2)$, linear time complexity is out of the question. The only thing worth considering is whether it can achieve constant space complexity. One's first reaction might be "no," because it is well known that Attention decoding requires a dynamically growing KV cache. However, this is only the case for typical efficient implementations. If we trade time for space regardless of cost, how far can we reduce the space complexity?

The answer might be surprising: If the space-time tradeoff is pushed to the limit, the space complexity can indeed be reduced to $\mathcal{O}(1)$!

This conclusion is not hard to imagine. First, for the single-layer Attention shown in Figure 3, the form is no different from an ordinary single-layer RNN; thus, it can obviously be inferred using a fixed amount of storage space. Next, looking at the multi-layer Attention in Figure 4, the connections between layers are complex, so history K and V are usually cached for efficient computation. But if we resolutely refuse to store KV cache, the $K, V$ inputs for every layer and every inference step can be recomputed entirely from the original input (recomputation). This leads to a massive amount of redundant calculation, causing the total time complexity to far exceed squared complexity—very "un-eco-friendly"—but the space complexity can indeed be maintained at $\mathcal{O}(1)$.

Taking two-layer Attention as an example: the second layer of Attention uses the outputs of the first layer as its input. Every output of the first layer can be computed in $\mathcal{O}(1)$ space. Therefore, as long as we are willing to sacrifice efficiency for recomputation, the second layer of Attention can also be completed in $\mathcal{O}(1)$ space. By extension, the third layer uses the second layer's output, and the $N$-th layer uses the $(N-1)$-th layer's output. Since each previous layer can be completed in $\mathcal{O}(1)$ space through recomputation, every layer and consequently the entire model can be completed in $\mathcal{O}(1)$ space.
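The recomputation argument can be written down almost literally. The sketch below (random toy weights, single-head attention only, no scaling, FFN, residuals or normalization; all names are ours) computes any position of any layer without ever storing a KV cache: the live memory is a few $d$-dimensional vectors per layer of recursion, independent of $L$, while the time blows up far beyond quadratic.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d, num_layers = 5, 4, 2
tokens = rng.normal(size=(L, d))                     # stand-in for the embedded input
Wq = rng.normal(size=(num_layers, d, d)) / np.sqrt(d)
Wk = rng.normal(size=(num_layers, d, d)) / np.sqrt(d)
Wv = rng.normal(size=(num_layers, d, d)) / np.sqrt(d)

def output(layer, i):
    """Output of attention layer `layer` at position i, with no KV cache:
    every lower-layer activation is recomputed on demand, so live memory is
    only a handful of d-dimensional vectors per recursion level."""
    if layer == 0:
        return tokens[i]                             # layer 0 = the raw input
    q = Wq[layer - 1] @ output(layer - 1, i)
    y, z = np.zeros(d), 0.0                          # the two accumulators of equation (2)
    for j in range(i + 1):
        x = output(layer - 1, j)                     # recomputation instead of a KV cache
        w = np.exp(q @ (Wk[layer - 1] @ x))
        y += w * (Wv[layer - 1] @ x)
        z += w
    return y / z

o_last = output(num_layers, L - 1)                   # top-layer output at the final position
```

For a single top-layer position the running time grows roughly like $L^{\text{num\_layers}}$, which is exactly the "far exceeding squared complexity" price mentioned above, while the space stays constant in $L$.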

This returns to the point made at the start: if Attention has any advantage over RNNs, it is only achieved through more computation. The intuitive feeling of "expanding memory" is just a superficial manifestation of trading space for time; like RNNs, it essentially has a memory bottleneck of constant capacity.

Of course, some readers might think: isn't trading time for space a very common practice? Is this a valuable conclusion? Indeed, trading time for space is common, but it is not always possible; not every problem can have its space complexity reduced to $\mathcal{O}(1)$ by trading away time. So, although it may look mundane, this is a non-trivial property of Attention.

Reflections on Model Capabilities

The reason for pointing out this characteristic of Attention is not to actually use it for inference, but to help us further reflect on the bottlenecks of Attention's capabilities.

First, if we really get into the details, $\mathcal{O}(1)$ is not strictly correct; more accurately it should be $\mathcal{O}(L)$. This is because an RNN with squared complexity needs to repeatedly scan the history sequence, which at minimum requires storing the original input and the outputs generated along the way, i.e. at least $L$ integer token IDs, taking $\mathcal{O}(L)$ space. For large enough $L$, this $\mathcal{O}(L)$ term necessarily dominates the $\mathcal{O}(1)$ term. However, the $\mathcal{O}(1)$ here mainly refers to the minimum space required by the LLM's internal computation layers, equivalent to the `hidden_state` when the model is viewed as an RNN, which has at least `hidden_size * num_layers * 2` components, while the $\mathcal{O}(L)$ space lives in the input and output. A visual analogy is to treat Attention as a computer with an infinite hard drive but fixed RAM: it continuously reads data from the hard drive, computes in RAM, and writes results back to the hard drive.

We know that if the RAM is very large while the data to be processed is small, we tend to be more "extravagant" when programming, perhaps loading all the data into RAM so that the computation is completely independent of hard-drive I/O. Similarly, LLMs trained in the "Large Model, Short Sequence" regime tend to rely on the $\mathcal{O}(1)$ fixed "RAM" provided by model scale rather than the dynamic "Hard Drive" provided by sequence length. At current LLM scales, the former is large enough that Stochastic Gradient Descent (SGD) "lazily" treats the model as a machine with infinite static memory (since for short sequences the RAM is always sufficient). But in reality the model's static memory is finite, so for tasks that cannot be completed in $\mathcal{O}(1)$ space, Attention-based models cannot generalize to inputs of arbitrary length.

For example, if we want to calculate the decimal representation $y$ of $2^x$ using Attention for conditional modeling $p(y|x)$, the training corpus would be $\{x, \color{red}{[sep]}, y\}$ concatenated, calculating the loss only for $y$. Note that $y$ is uniquely determined by $x$, so in theory, 100% accuracy should be learnable. However, if there is no Chain of Thought (CoT) to dynamically increase sequence length, the model can only place all calculations implicitly into "RAM," which is effective for short inputs. But in fact, RAM is finite, while the space required to calculate $2^x$ increases as $x$ increases. Thus, there must exist a sufficiently large $x$ where the accuracy of $p(y|x)$ cannot reach 100% (even training accuracy). This is different from the length extrapolation issues discussed in "Transformer Upgrade Road: 16. 'Reviewing' Length Extrapolation Techniques"; it is not caused by the OOD of position encoding, but rather a capacity defect brought about by "Large Model, Short Sequence" training without enough CoT guidance.
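A quick sanity check of the claim that the required space grows with $x$: the decimal representation of $2^x$ has about $x\log_{10}2 \approx 0.301x$ digits, so the answer alone, never mind the intermediate working, outgrows any fixed "RAM":

```python
# Digits of 2**x grow linearly in x (about 0.301 * x), so no fixed-size internal
# state can hold the full computation for arbitrarily large x without CoT.
for x in (10, 100, 1000, 10000):
    print(x, len(str(2 ** x)))   # -> 4, 31, 302, 3011 digits
```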

So why is the mainstream scale-up direction still increasing the LLM's RAM (i.e., increasing `hidden_size` and `num_layers`) rather than researching schemes like CoT that increase `seq_len`? The latter is certainly a mainstream research direction, but the core issue is that once RAM becomes a bottleneck, the model's learning efficiency and universality suffer. It is like having little RAM but a lot of data: we have to frequently save intermediate results to the hard drive and clear the RAM, so the algorithm must be more ingenious, harder to write, and possibly tailored to specific tasks. Under what circumstances does the RAM bottleneck appear? Taking Llama2-70B as an example, its `num_layers` is 80 and `hidden_size` is 8192; multiplying them gives about 640K, and multiplying by 2 gives roughly 1.3M, on the order of 1M. In other words, when the input length reaches the 1M-token level, Llama2-70B's "RAM" might become a bottleneck. Although training LLMs on 1M-token contexts is still not easy, it is no longer out of reach; for instance, Kimi has already opened internal testing of a model at the 1M level.
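As a back-of-the-envelope check of that estimate (following the `hidden_size * num_layers * 2` count from the previous section):

```python
hidden_size, num_layers = 8192, 80           # Llama2-70B
print(hidden_size * num_layers)              # 655360  ~ 640K
print(hidden_size * num_layers * 2)          # 1310720 ~ 1.3M, i.e. the ~1M-token scale above
```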

Therefore, continuously increasing the model's context length (Hard Drive) to accommodate more input and CoT, while simultaneously increasing the model's own scale so that "RAM" is not a bottleneck, has become the main theme for current LLMs.

At the same time, this also negates a previous thought of the author: could we achieve the same effect as large models by shrinking model size and increasing `seq_len`? The answer is likely no, because small models have RAM bottlenecks. To compensate for this using the "Hard Drive" of `seq_len`, one would need to provide sufficiently long CoT for every sample, which is harder than training the large model directly. If `seq_len` is increased through simple schemes like repetition, it brings no substantial benefit because no additional information is introduced. However, if the increase in `seq_len` is achieved through prefix tuning, it might bridge the gap in space complexity because prefix parameters are not computed from the input sequence but are trained separately. This is equivalent to inserting extra "RAM sticks," thereby increasing the model's memory.

Final Summary

In this article, we examined Attention from the perspective of squared-complexity RNNs and discovered it has a constant space complexity bottleneck. This indicates that Attention does not essentially increase "memory" compared to RNNs; it only increases the amount of computation. The existence of this bottleneck suggests that Attention may face theoretical difficulties in length generalization for certain tasks (insufficient RAM). Guiding models to better utilize the dynamic "Hard Drive" provided by the `seq_len` dimension may be the key to solving this difficulty.