Google's New Work Attempts to "Resurrect" RNN: Can RNN Shine Again?

By 苏剑林 | March 28, 2023

Currently, Large Language Models (LLMs) like ChatGPT are "sweeping the globe." Some readers have noticed that almost all LLMs still use the original Multi-Head Scaled-Dot-Product Attention, while the large body of recent work on Efficient Transformers, such as Linear Attention and FLASH, has gone largely unadopted. Is it because their performance is too poor, or is there simply no need to consider efficiency? I analyzed this question in "Linear Transformer is Probably Not the Model You Are Waiting For": standard Attention only exhibits quadratic complexity once the sequence length significantly exceeds the hidden size; below that, it remains close to linear and is in practice faster than many of the "efficient" improvements. Since GPT-3 already uses a hidden size above ten thousand, this means that as long as your LLM does not target generation lengths in the tens of thousands, efficiency-oriented modifications are unnecessary: often the speed does not improve, while the performance drops.

So, when there is a real demand for processing sequences with lengths of tens or even hundreds of thousands, what model should we use? Recently, a paper from Google, "Resurrecting Recurrent Neural Networks for Long Sequences", has re-optimized the RNN model, specifically pointing out the advantages of RNNs in scenarios involving ultra-long sequences. So, can RNNs shine again?

Linearization

The RNN proposed in the article is called LRU (Linear Recurrent Unit). It is a minimalist linear RNN that can be computed both in parallel and serially, making it efficient in both training and inference. LRU shares many similarities with works like SSM (State Space Model) and RWKV. In fact, the starting point for LRU was the observation that SSM performs very well on LRA (Long Range Arena, a benchmark for testing long-range dependency capabilities), which led to the search for a way to make a native RNN perform equally well on LRA; the result is LRU. Regrettably, the original paper only reports experiments on LRA. At the end of this article, I supplement some results from my own experiments on language modeling.

The original paper begins with SSM and spends considerable space describing the connection between LRU and SSM. In this article, we will skip those descriptions and directly introduce LRU as an independent RNN model. We know that the simplest RNN can be written as: \begin{equation}x_t = f(Ax_{t-1} + u_t)\end{equation} where $x_t, u_t \in \mathbb{R}^d, A \in \mathbb{R}^{d \times d}$, and $f$ is an activation function. In general, there are projection matrices before $u_t$ and after $x_t$, but here we focus on the recurrence itself, so we won't write them explicitly.
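For concreteness, here is a minimal sketch (NumPy, with my own toy dimensions) of this recurrence, taking $\tanh$ as the activation $f$; the surrounding projection matrices are omitted, as in the equation above.

```python
import numpy as np

d, T = 4, 10                                   # toy state size and sequence length
rng = np.random.default_rng(0)
A = rng.normal(0, d ** -0.5, (d, d))           # recurrent matrix
u = rng.normal(size=(T, d))                    # input sequence u_1, ..., u_T

x = np.zeros(d)                                # start from a zero state
for t in range(T):
    x = np.tanh(A @ x + u[t])                  # x_t = f(A x_{t-1} + u_t)
print(x)
```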

Traditional wisdom holds that activation functions are non-linear, with common choices being $\text{sigmoid}, \tanh, \text{relu}$, etc. In particular, some work has shown that a single-layer RNN with $\text{sigmoid}$ or $\tanh$ activation is Turing complete, which makes people firm believers in the necessity of non-linear activation functions. However, in deep learning, experiment is the only criterion for testing truth. The authors found that if the Self-Attention in a Transformer is replaced with an RNN, the linear RNN actually performs best:

On various tasks in LRA, linear RNN is actually the best

This is a surprising piece of good news. "Surprising" because it might overturn some readers' assumptions about a model's need for non-linearity; of course, some readers might not be surprised, since work like MetaFormer has shown that, thanks to the expressive power of the FFN layer, the non-linearity of the token-mixing layers (such as Self-Attention) can be very weak, to the point that they can even be replaced with a Pooling layer. As for "good news," it is because linear RNNs admit parallel implementations, making them much faster to train than non-linear RNNs.

Therefore, the authors conducted a series of discussions around linear RNNs.

Diagonalization

Removing the activation function, the RNN simplifies to: \begin{equation}x_t = Ax_{t-1} + u_t\label{eq:lr}\end{equation} Repeated iteration yields: \begin{equation} \begin{aligned} x_0 &= u_0\\ x_1 &= Au_0 + u_1\\ x_2 &= A^2 u_0 + Au_1 + u_2\\ &\vdots \\ x_t &= \sum_{k=0}^t A^{t-k}u_k \end{aligned} \label{eq:lr-e}\end{equation} As we can see, the main computational burden is concentrated in the matrix powers of $A$. It is natural to think of matrix diagonalization, which is an efficient way to compute matrix powers. However, a general matrix may not be diagonalizable over the real field. What should we do? Broaden the view: if it cannot be done over the reals, move to the complex field! Almost every matrix is diagonalizable over the complex field, which means $A$ can (almost) always be written as: \begin{equation}A = P\Lambda P^{-1} \quad \Rightarrow \quad A^n = P\Lambda^n P^{-1}\end{equation} where $P, \Lambda \in \mathbb{C}^{d \times d}$, and $\Lambda$ is a diagonal matrix composed of the eigenvalues. Substituting into equation $\eqref{eq:lr-e}$, we get: \begin{equation}x_t = \sum_{k=0}^t P\Lambda^{t-k}P^{-1}u_k = P\left(\sum_{k=0}^t \Lambda^{t-k}(P^{-1}u_k)\right)\end{equation} As mentioned, there are in general projection matrices before $u_t$ and after $x_t$. As long as we agree that these two projection matrices are complex, then in principle $P$ and $P^{-1}$ can be absorbed into those projections. This means that if all operations are carried out over the complex field, replacing the general matrix $A$ in a linear RNN with a diagonal matrix $\Lambda$ causes no loss of model capacity! Thus, we only need to consider the following minimalist RNN: \begin{equation}x_t = \Lambda x_{t-1} + u_t \quad \Rightarrow \quad x_t = \sum_{k=0}^t \Lambda^{t-k}u_k\label{eq:lr-x}\end{equation}
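As a quick numerical sanity check (a NumPy sketch of my own, not code from the paper or the repository linked later), the recurrence with a general $A$ and the diagonal recurrence in the eigenbasis give the same result once $P^{-1}$ and $P$ are applied at the input and output:

```python
import numpy as np

d, T = 8, 16
rng = np.random.default_rng(0)
A = rng.normal(0, d ** -0.5, (d, d))      # real recurrent matrix, Glorot-style scale
u = rng.normal(size=(T, d))               # input sequence

# Diagonalize over the complex numbers: A = P @ diag(lam) @ P^{-1}
lam, P = np.linalg.eig(A)
P_inv = np.linalg.inv(P)

# (1) Recurrence with the full matrix: x_t = A x_{t-1} + u_t
x = np.zeros(d)
for t in range(T):
    x = A @ x + u[t]

# (2) Diagonal recurrence in the eigenbasis, with P^{-1} and P playing the role
#     of the input/output projections that absorb the change of basis
h = np.zeros(d, dtype=complex)
for t in range(T):
    h = lam * h + P_inv @ u[t]

print(np.allclose(x, (P @ h).real))       # True, up to floating-point error
```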

Parameterization

The benefit of a diagonal matrix is that all operations are element-wise, so the computation of each dimension can be fully parallelized. This also means that analyzing one dimension is equivalent to analyzing all dimensions; the model analysis only needs to be performed in a one-dimensional space. Let $\Lambda = \text{diag}(\lambda_1, \lambda_2, \cdots, \lambda_d)$, and let $\lambda$ represent one of $\lambda_1, \lambda_2, \cdots, \lambda_d$. At the same time, where no confusion arises, $x_t$ and $u_t$ are also used to represent the components corresponding to $\lambda$. Thus, equation $\eqref{eq:lr-x}$ simplifies to a scalar operation: \begin{equation}x_t = \lambda x_{t-1} + u_t \quad \Rightarrow \quad x_t = \sum_{k=0}^t \lambda^{t-k}u_k\label{eq:lr-xx}\end{equation} Note that $\lambda$ is complex, so we can set $\lambda = re^{i\theta}$, where $r \geq 0$ and $\theta \in [0, 2\pi)$ are real numbers: \begin{equation}x_t = \sum_{k=0}^t r^{t-k}e^{i(t-k)\theta}u_k\label{eq:lr-e-r-theta}\end{equation} During the summation, $t-k$ is always non-negative, so $r \leq 1$ must hold. Otherwise, the weight of historical terms would gradually tend toward infinity, which contradicts intuition (intuitively, reliance on historical information should gradually weaken) and risks gradient explosion. On the other hand, if $r \ll 1$, there is a risk of gradient vanishing. This poses two requirements for $r$: 1. ensure $r \in [0, 1]$; 2. $r$ should be close to 1 during the initialization phase.

To this end, we first set $r = e^{-\nu}$, then $r \in [0, 1]$ requires $\nu \geq 0$. We then set $\nu = e^{\nu^{\log}}$, so that $\nu^{\log} \in \mathbb{R}$, converting it into an unconstrained optimization. Here $\nu^{\log}$ is just a notation for a variable, not representing any special operation. Since $\nu$ is parameterized as $e^{\nu^{\log}}$, to maintain consistency, we also parameterize $\theta$ as $e^{\theta^{\log}}$.

Readers might ask, there are many ways to constrain $r \in [0, 1]$, why make it so complicated? Wouldn't adding a sigmoid directly be fine? First, after parameterizing $r$ as $e^{-\nu}$, the power operation can be combined with $\theta$, i.e., $r^k e^{ik\theta} = e^{k(-\nu + i\theta)}$, which is better from both implementation and calculation perspectives. Second, since $\nu \geq 0$, the simplest smooth function that can map any real number to a non-negative number is arguably the exponential function, leading to $\nu = e^{\nu^{\log}}$. SSM uses $\text{relu}$ activation, i.e., $r = e^{-\max(\nu, 0)}$, but this has a saturation zone, which might be unfavorable for optimization.
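In code, the parameterization amounts to a couple of lines; the sketch below (NumPy, with my own variable names nu_log and theta_log) shows that the resulting $\lambda$ automatically has modulus in $(0, 1]$:

```python
import numpy as np

def make_lambda(nu_log, theta_log):
    nu = np.exp(nu_log)             # nu >= 0, so r = exp(-nu) lies in (0, 1]
    theta = np.exp(theta_log)       # the phase is parameterized the same way
    return np.exp(-nu + 1j * theta)

lam = make_lambda(np.array([-2.0, 0.0, 2.0]), np.array([0.0, 1.0, -1.0]))
print(np.abs(lam))                  # all moduli are in (0, 1]: no exploding weights
```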

Initialization

Next, we consider the initialization problem. Returning to the original form $\eqref{eq:lr}$, for a $d \times d$ real matrix, standard Glorot initialization is a normal or uniform distribution with mean 0 and variance $1/d$ (refer to "Understanding Model Parameter Initialization Strategies from a Geometric Perspective"). Both theory and experiment show that the eigenvalues of such an initialized matrix are roughly uniformly distributed within the unit disk of the complex plane:

Eigenvalues of a Glorot-initialized matrix are uniformly distributed in the unit disk

From this, a natural "standard" initialization for $\Lambda$ is to sample points uniformly within the unit disk of the complex plane. Converting from Cartesian to polar coordinates, we have $dxdy = rdrd\theta = \frac{1}{2}d(r^2)d\theta$. This tells us that to sample uniformly within the unit disk, we only need $\theta \sim U[0, 2\pi]$ and $r^2 \sim U[0, 1]$.
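The eigenvalue claim is easy to verify numerically; a small sketch (NumPy, assuming nothing beyond the statement above):

```python
import numpy as np

d = 1000
rng = np.random.default_rng(0)
A = rng.normal(0, d ** -0.5, (d, d))     # Glorot-style: mean 0, variance 1/d
eigs = np.linalg.eigvals(A)
print(np.abs(eigs).max())                # close to 1: the eigenvalues fill the unit disk
print(np.mean(np.abs(eigs) < 0.5))       # about 0.25: uniform in area, not in radius
```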

Switching to ring initialization, performance improves on most tasks

However, we just said that to prevent gradient vanishing as much as possible, we should make $r$ close to 1 during the initialization phase. So, the improvement is to sample uniformly within a ring where $r \in [r_{\min}, r_{\max}]$. The sampling method becomes $\theta \sim U[0, 2\pi]$ and $r^2 \sim U[r_{\min}^2, r_{\max}^2]$. Experimental results from the original paper show that $r_{\min} = 0.9, r_{\max} = 0.999$ works well for most experiments.
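A sketch of this ring initialization, combined with the parameterization above (again my own NumPy code rather than the authors'):

```python
import numpy as np

def ring_init(d, r_min=0.9, r_max=0.999, seed=0):
    rng = np.random.default_rng(seed)
    theta = rng.uniform(0, 2 * np.pi, size=d)                   # theta ~ U[0, 2*pi]
    r = np.sqrt(rng.uniform(r_min ** 2, r_max ** 2, size=d))    # r^2 ~ U[r_min^2, r_max^2]
    nu_log = np.log(-np.log(r))       # inverts r = exp(-exp(nu_log))
    theta_log = np.log(theta)         # inverts theta = exp(theta_log)
    return nu_log, theta_log

nu_log, theta_log = ring_init(d=64)
lam = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
print(np.abs(lam).min(), np.abs(lam).max())   # moduli all lie within [0.9, 0.999]
```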

There is a problem here: if $r$ is initialized close to 1, and $u_t$ is close to i.i.d. in the initial stage, then equation $\eqref{eq:lr-e-r-theta}$ approximates a sum of terms of roughly constant magnitude (rather than an average), which risks blowing up. To analyze this, we first write: \begin{equation}\|x_t\|^2 = x_t x_t^* = \sum_{k=0}^t \sum_{l=0}^t r^{(t-k)+(t-l)}e^{i[(t-k)-(t-l)]\theta}u_k u_l^*\end{equation} where $*$ denotes the complex conjugate and $\|\cdot\|$ the modulus of a complex number. Taking the expectation of both sides, and assuming the $u_k$ are independently drawn from the same zero-mean distribution, we have $\mathbb{E}[u_k u_l^*] = \mathbb{E}[u_k]\mathbb{E}[u_l^*] = 0$ for $k \neq l$, so only the $k = l$ terms remain: \begin{equation}\mathbb{E}[\|x_t\|^2] = \sum_{k=0}^t r^{2(t-k)}\mathbb{E}[u_k u_k^*] = \mathbb{E}[\|u_k\|^2]\sum_{k=0}^t r^{2(t-k)} = \frac{(1 - r^{2(t+1)})\mathbb{E}[\|u_k\|^2]}{1-r^2}\end{equation} Since $r \in (0, 1)$, $r^{2(t+1)} \to 0$ as $t$ grows. This means that for large $t$, the ratio of the root-mean-square modulus of $x_t$ to that of $u_k$ is $1/\sqrt{1-r^2}$. When $r$ is very close to 1, this ratio is very large, meaning the sequence is amplified substantially after passing through the RNN, which is detrimental to training stability. The authors therefore propose a simple trick: introduce an element-wise parameter $\gamma$, initialized to $\sqrt{1-r^2}$, and change equation $\eqref{eq:lr-xx}$ to: \begin{equation}x_t = \lambda x_{t-1} + \gamma u_t \quad \Rightarrow \quad x_t = \gamma \sum_{k=0}^t \lambda^{t-k} u_k\label{eq:lr-xxx}\end{equation} In this way, the model's output is stabilized at least in the initial stage, and the model is then left to learn on its own. Combining the above results gives the LRU (Linear Recurrent Unit) model proposed in the paper, as shown below:

Schematic of the LRU model
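Pulling the pieces together, here is a minimal serial sketch of the diagonal recurrence with the $\gamma$ trick (NumPy; the input/output projections and the parallel scan discussed later are omitted, and the variable names are mine):

```python
import numpy as np

def lru_scan(u, nu_log, theta_log):
    """Serial scan x_t = lambda * x_{t-1} + gamma * u_t; u has shape (T, d)."""
    lam = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))   # diagonal of Lambda
    gamma = np.sqrt(1 - np.abs(lam) ** 2)                    # initial value of the gamma trick
    x = np.zeros(u.shape[1], dtype=complex)
    out = np.empty(u.shape, dtype=complex)
    for t in range(u.shape[0]):
        x = lam * x + gamma * u[t]
        out[t] = x
    return out

rng = np.random.default_rng(0)
T, d = 1024, 64
r = np.sqrt(rng.uniform(0.9 ** 2, 0.999 ** 2, size=d))       # ring initialization
theta = rng.uniform(0, 2 * np.pi, size=d)
nu_log, theta_log = np.log(-np.log(r)), np.log(theta)
x = lru_scan(rng.normal(size=(T, d)), nu_log, theta_log)
print(np.mean(np.abs(x[-1])))    # stays around 1 thanks to gamma, even though r is close to 1
```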

Related Work

Here we introduce two models closely related to LRU: a simplified variant, and RWKV.

SLRU

The starting point of LRU was to simplify the general linear RNN $\eqref{eq:lr}$, and in order to match the theoretical capacity of a general matrix, it had to introduce complex projection matrices and a complex diagonal matrix of eigenvalues $\Lambda$. If we give up matching the theoretical capacity of a general matrix and only care about the decay effect provided by $r$, we can simplify LRU further by assuming that both the projection matrices and the diagonal matrix of eigenvalues are real. This simplified version is called SLRU (Simpler Linear Recurrent Unit) here.

The original paper did not study SLRU, but I find it intuitively more appealing (mainly because the role of the phase $\theta$ is hard to grasp intuitively), so I also include SLRU in the experiments later.

RWKV

Speaking of RNNs, some readers may have heard of the recently popular RWKV, which can be viewed as a combination of SLRU / Hydra Attention with a GLU (Gated Linear Unit). The RNN portion of RWKV is: \begin{equation}x_t = \sigma(r_t) \times \frac{y_t + (\gamma \lambda - 1)e^{k_t}v_t}{z_t + (\gamma \lambda - 1)e^{k_t}}, \quad \begin{aligned}y_t &= \lambda y_{t-1} + e^{k_t}v_t \\ z_t &= \lambda z_{t-1} + e^{k_t}\end{aligned}\end{equation} As we can see, the recursive part consists of two SLRUs. The distinctive feature of RWKV is that it divides the outputs of the two SLRUs to obtain a built-in normalization, so the $\gamma$ trick from LRU is not needed. Additionally, perhaps to match the parameter count of Self-Attention or to further improve performance, a gate $\sigma(r_t)$ is multiplied in after the normalization. Although the author has verified the effectiveness of RWKV on LM tasks, comparative experiments against common baselines seem to be missing; this article supplements that part.

Note: RWKV here specifically refers to the RNN module responsible for token mixing, not the complete model (i.e., it doesn't use the author's Channel-Mix layer, Time Shift, etc.).
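For readers who prefer code, below is a sketch of the recurrence above in NumPy (the variable names `decay` and `bonus` are mine, standing for $\lambda$ and $\gamma\lambda$ respectively; the projections producing $r_t, k_t, v_t$ are omitted). Note that actual RWKV implementations typically compute the exponentials with a running-maximum shift for numerical stability, which the sketch ignores:

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def rwkv_mix(r, k, v, decay=0.99, bonus=1.5):
    """r, k, v: (T, d) arrays from per-token projections; returns the mixed (T, d) output."""
    y = np.zeros(k.shape[1])          # running sum of e^{k} * v  (first SLRU)
    z = np.zeros(k.shape[1])          # running sum of e^{k}      (second SLRU)
    out = np.empty_like(v)
    for t in range(k.shape[0]):
        ek = np.exp(k[t])
        y = decay * y + ek * v[t]
        z = decay * z + ek
        # current token gets an extra (bonus - 1) * e^{k_t} weight in numerator and denominator
        out[t] = sigmoid(r[t]) * (y + (bonus - 1) * ek * v[t]) / (z + (bonus - 1) * ek)
    return out

rng = np.random.default_rng(0)
T, d = 16, 8
out = rwkv_mix(rng.normal(size=(T, d)), rng.normal(size=(T, d)), rng.normal(size=(T, d)))
print(out.shape)                      # (16, 8)
```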

Implementation

In this section, we discuss the implementation of LRU. The original paper provides Jax-based LRU reference code in the appendix. Here, I also provide a Keras version:

Github: https://github.com/bojone/rnn

There are two technical difficulties in implementing LRU: complexification and parallelization.

Complexification

The projection matrices and eigenvalues of LRU are complex. The Jax code provided by the authors uses complex matrices directly. In Keras, this would mean we cannot reuse existing Dense layers, which would be a bit of a pity. In fact, from $(B+iC)u = Bu + iCu$, we can see that a complex projection matrix just doubles the projection dimension. Therefore, for the projection part, we don't use complex matrices; instead, we just use Dense layers with twice the units.

Next is the $e^{i(t-k)\theta}u_k$ part, which can either be expanded into pure real operations or implemented directly with complex arithmetic following the formula. If expanded into real operations, its form is the same as RoPE; I was quite excited when I first saw LRU, thinking "RoPE is all you need." However, after comparing speeds, I found that the version implemented directly with complex arithmetic is slightly faster, so the complex version is recommended.

Finally, to project the complex output back to a real vector, since $\Re[(B+iC)(x+iy)] = Bx - Cy = [B, -C][x, y]^\top$, we only need to concatenate the real and imaginary parts and then apply a Dense layer.
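A small sketch (NumPy, illustrative shapes only) of the two observations above: the complex input projection becomes one real matrix with doubled output size, and the real part of the complex output is recovered by a real matrix acting on the concatenation of real and imaginary parts:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d = 6, 4
B, C = rng.normal(size=(2, d, d_in))          # real and imaginary parts of the input projection
u = rng.normal(size=d_in)                     # real input vector

# Input side: (B + iC) u = Bu + iCu, i.e. one real "Dense" with 2x the units
proj = np.concatenate([B, C], axis=0)         # shape (2d, d_in)
bu_cu = proj @ u
x = bu_cu[:d] + 1j * bu_cu[d:]                # complex state fed into the recurrence

# Output side: Re[(B' + iC') x] = B' Re(x) - C' Im(x) = [B', -C'] [Re(x); Im(x)]
B2, C2 = rng.normal(size=(2, d_in, d))
real_out = np.concatenate([B2, -C2], axis=1) @ np.concatenate([x.real, x.imag])
print(np.allclose(real_out, ((B2 + 1j * C2) @ x).real))   # True
```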

Parallelization

If the serial version of the RNN is implemented directly according to the recurrence formula, training will be very slow (inference is auto-regressive and serial, so that's fine). As mentioned, an important property of linear RNNs is that they have parallel algorithms, which can significantly speed up training.

In fact, we can rewrite equation $\eqref{eq:lr-xx}$ as: \begin{equation}x_t = \lambda^t \sum_{k=0}^t \lambda^{-k} u_k\end{equation} This suggests a fast algorithm: multiply each $u_k$ by $\lambda^{-k}$, which is element-wise and parallelizable; the summation $\sum_{k=0}^t$ is then a cumsum operation, which is fast in most frameworks; finally, multiply each cumsum result by its respective $\lambda^t$, which is again element-wise and parallelizable. However, because $\|\lambda\| < 1$, $\lambda^{-k}$ is almost certain to overflow for large $k$: FP16 breaks down almost immediately, and even FP32 or FP64 cannot hold out for very long sequences. Therefore, this seemingly straightforward plan has theoretical merit but little practical value.
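To see how quickly this breaks down, a quick NumPy illustration (with an assumed modulus of 0.95):

```python
import numpy as np

lam = np.float32(0.95)
k = np.arange(2048, dtype=np.float32)
weights = lam ** (-k)              # lambda^{-k} grows exponentially in k
print(np.isinf(weights).any())     # True: FP32 already overflows around k ~ 1700
                                   # (NumPy emits an overflow warning, which is exactly the point)
```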

The key to parallel acceleration is noting the decomposition ($T > t$): \begin{equation} \begin{aligned} x_T &= \sum_{k=0}^T \lambda^{T-k} u_k \\ &= \sum_{k=0}^t \lambda^{T-k} u_k + \sum_{k=t+1}^T \lambda^{T-k} u_k \\ &= \lambda^{T-t} \sum_{k=0}^t \lambda^{t-k} u_k + \sum_{k=t+1}^T \lambda^{T-k} u_k \\ \end{aligned} \end{equation} This decomposition tells us that the result of applying $\eqref{eq:lr-xx}$ to the entire sequence is equivalent to splitting the sequence into two halves, applying $\eqref{eq:lr-xx}$ to each half, and then weighting the last result of the first half and adding it to each position of the second half, as shown in the left figure below:

Parallel recursive decomposition of linear RNN
Complete expansion of linear RNN recursive decomposition

The key point is that the two halves can be processed in parallel! By recursing, we reduce what was originally $\mathcal{O}(L)$ steps to $\mathcal{O}(\log L)$, significantly accelerating training, as shown in the right figure above.

In fact, this is the "Upper/Lower" parallel algorithm for the Prefix Sum problem. Implementation details can be found in the code linked above. Since Tensorflow 1.x does not support writing the recursion directly, I implemented it bottom-up with tf.while_loop (or a plain for loop); the training speed only barely approaches that of Self-Attention. If the loop part were rewritten as a CUDA kernel, it should be able to surpass Self-Attention (unfortunately, I don't know how to do that). The author of RWKV simply wrote the RWKV recurrence as a serial CUDA kernel without any parallelization, and even that already matches the speed of Self-Attention.
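To make the decomposition concrete, here is a recursive NumPy sketch of the Upper/Lower idea (not the Keras/Tensorflow code in the repository above, just the algorithmic skeleton; in a real tensor framework the recursion is unrolled so that the two halves become one batched operation):

```python
import numpy as np

def linear_scan(u, lam):
    """Computes x_t = lam * x_{t-1} + u_t for u of shape (T, d) and lam of shape (d,)."""
    T = u.shape[0]
    if T == 1:
        return u.copy()
    m = T // 2
    left = linear_scan(u[:m], lam)                 # scan the first half
    right = linear_scan(u[m:], lam)                # scan the second half (independently)
    # lam^1, lam^2, ..., lam^(T-m): decay of the first half's final state over the second half
    powers = lam[None, :] ** np.arange(1, T - m + 1)[:, None]
    return np.concatenate([left, right + powers * left[-1]], axis=0)

# Check against the naive serial recurrence
rng = np.random.default_rng(0)
T, d = 37, 5                                       # deliberately not a power of two
u = rng.normal(size=(T, d)) + 1j * rng.normal(size=(T, d))
lam = 0.95 * np.exp(1j * rng.uniform(0, 2 * np.pi, d))
x, serial = np.empty_like(u), np.zeros(d, dtype=complex)
for t in range(T):
    serial = lam * serial + u[t]
    x[t] = serial
print(np.allclose(linear_scan(u, lam), x))         # True
```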

Additionally, there is the "Odd/Even" parallel algorithm for Prefix Sum, which theoretically has higher computational efficiency, but its structure is more complex. If implemented in Tensorflow, it involves more steps and more reshape and concat operations. Its actual efficiency might not surpass the "Upper/Lower" algorithm, so I did not implement it (mostly because Tensorflow 1.x doesn't support recursion, making it tedious to write).

Evaluation

In this section, we demonstrate the experimental results on LRA from the original paper, as well as my own results on language modeling (LM) tasks.

In the original paper, the authors primarily used a combination of theory and experiment to show how to step-by-step optimize a common RNN until it achieved performance close to SOTA on LRA. This process of analysis and improvement is fascinating and worth savoring. However, because the experiments in the paper were repeatedly conducted on LRA, the experiments themselves don't offer too many surprises. Here I only show Table 8 from the paper:

Summary of experimental results from the LRU paper

Readers of this article might care more about its performance in NLP, especially the recently popular LM. Unfortunately, the original paper does not include this part. I conducted some comparative experiments for your reference. The models compared include GAU (same as GAU-α), SA (same as RoFormerV2), LRU, SLRU, and RWKV. For LRU, SLRU, and RWKV, I simply replaced the Self-Attention in RoFormerV2 with LRU, SLRU, and RWKV of similar parameter and computation counts. All models are "base" versions with about 100 million parameters, which is considered small nowadays. Initialization used DeepNorm, the optimizer was Tiger, and all other hyperparameters were kept consistent, basically achieving strict variable control.

LOSS curve at training length 128
ACC curve at training length 128

LOSS curve at training length 512
ACC curve at training length 512

As can be seen, the ranking in terms of performance is: $$\text{GAU} > \text{SA} > \text{RWKV} > \text{LRU} > \text{SLRU}$$

From the experimental results, we can conclude:

1. LRU is better than SLRU, indicating that introducing complex projection matrices and complex eigenvalues is indeed helpful, though there is some loss of computational efficiency (even if the parameter count is held constant);

2. As sequence length increases, the performance of the Attention series (GAU, SA) improves, while the performance of the RNN series (LRU, SLRU, RWKV) decreases. This is a fundamental difference, likely because the long-range memory capacity of RNNs is limited by the hidden_size;

3. RWKV may indeed be the best RNN model currently, but there is still a clear gap compared to Attention-type models (GAU, SA);

4. According to Point 2, for the RNN series to catch up with the Attention series, the hidden_size probably needs to be further increased. Thus, for LM tasks, the RNN series may only have an advantage at larger scales;

5. Combining Points 1 and 3, is the next improved version of RNN a complex version of RWKV?

Additionally, a few observations from the experiments. Since GAU is single-headed, it is significantly more computationally efficient than SA in long-sequence and large-scale settings, and its performance is also better than SA's. Therefore, GAU should be the best choice for language models across a considerable range; off the top of my head, for models under 10 billion parameters and sequence lengths under 5,000, I would recommend GAU. That said, it cannot be denied that RNN-style models of the same scale are superior in inference efficiency (each step costs the same amount of compute, and the cache size stays constant), and their training efficiency is not inferior to Attention-style models. Therefore, after scaling up, they should still have a chance to compete with the Attention series.

It is worth noting that while RWKV's overall performance is good, there is still a gap compared with GAU and SA. Thus, under a fair comparison, RWKV is not as flawless as rumored. In fact, the RWKV author's own implementation includes a series of rather obscure tricks said to help LM performance (according to the author, these tricks are the "essence"). They can only be discovered by reading the author's source code and were not included in my experiments. It is quite possible that these tricks do help train a better LM, but my goal here was a fair controlled comparison rather than actually training a production LM; once these tricks are introduced, there are too many variables, and with limited compute I cannot control for them all.

Of course, the above conclusions were only drawn from "small models" at the 100-million scale. I am still trying larger scales and cannot give a conclusion for now.

Conclusion

This article introduced Google's attempt to "save" RNNs, which builds, step by step, an efficient RNN model that performs close to SOTA on LRA. In addition to the LRA experiments from the original paper, I also provided my own results on language modeling tasks, including comparisons with related models like RWKV. Overall, optimized RNN models are not inferior to Attention-type models in training efficiency and offer better inference efficiency, though there is still a certain gap in language modeling performance. Perhaps models need to be larger to further demonstrate the advantages of RNNs.