Mind-Bending: Nonlinear RNNs Can Actually Be Computed in Parallel?

By 苏剑林 | September 26, 2023

In recent years, linear RNNs have attracted significant attention from researchers (such as in my previous post "Google's New Work Attempts to 'Resurrect' RNN: Can RNNs Shine Again?") due to their ability to be trained in parallel and their constant inference cost. This gives RNNs "a foothold" even in the current trend where Transformers are blooming everywhere. However, at present, it seems this "foothold" belongs only to linear RNNs, as nonlinear RNNs cannot be trained efficiently in parallel, making them "unable to keep up" in the architecture wars.

However, a paper titled "Parallelizing Non-Linear Sequential Models over the Sequence Length" takes a different view. It proposes an iterative algorithm that claims to achieve parallel training for nonlinear RNNs! Is it really that magical? Let's take a closer look.

Finding Fixed Points

The original paper presents its method in a very general way, focusing primarily on PDEs and ODEs. Here, we will start directly with RNNs. Consider a common simple nonlinear RNN:

\begin{equation}x_t = \tanh(Ax_{t-1} + u_t)\label{eq:rnn}\end{equation}

Due to the presence of $\tanh$, it can only be computed serially. Now, let's subtract $Ax_{t-1}$ from both sides:

\begin{equation}x_t - Ax_{t-1} = \tanh(Ax_{t-1} + u_t) - Ax_{t-1}\end{equation}

Of course, this does not change the fact that it is still a nonlinear RNN. However, notice that if the $x_{t-1}$ on the right-hand side were a known quantity (just as $u_t$ is), the equation would become a linear RNN, which, according to the results in "Google's New Work Attempts to 'Resurrect' RNN: Can RNNs Shine Again?", can be computed in parallel. At this point, quick-witted readers might have already guessed the next step: iterative solving!

First, modify the above RNN to:

\begin{equation}x_t^{(n)} - Ax_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - Ax_{t-1}^{(n-1)}\label{eq:rnn-iter}\end{equation}

Starting from a given $x_t^{(0)}$, we iterate the above equation repeatedly. Ideally it converges to a fixed point $x_t^*$, which is exactly the result of the original nonlinear RNN. In theory, the total amount of computation needed to iterate formula $\eqref{eq:rnn-iter}$ is larger than that of the direct recursion $\eqref{eq:rnn}$. However, since each iteration is a parallelizable linear RNN, as long as convergence is fast enough that only a few iterations are needed, the total wall-clock time is usually shorter than that of the direct nonlinear recursion (especially when the sequence length is very large).
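To make the idea concrete, here is a minimal sketch (my own, not the paper's code; function names, shapes, and the zero initial state are assumptions) of the fixed-point iteration in JAX: each iteration forms the right-hand side of $\eqref{eq:rnn-iter}$ from the previous iterate and then solves the resulting linear RNN in parallel with an associative scan, with the serial recursion $\eqref{eq:rnn}$ included as the reference.

```python
import jax
import jax.numpy as jnp

def linear_rnn_parallel(A, b):
    """Solve x_t = A x_{t-1} + b_t (with x_0 = 0) for all t via a parallel scan."""
    T, d = b.shape
    As = jnp.broadcast_to(A, (T, d, d))  # constant A tiled over time (wasteful, but simple)

    def combine(left, right):
        # Compose the affine maps x -> A1 x + b1 (earlier) and x -> A2 x + b2 (later).
        A1, b1 = left
        A2, b2 = right
        return jnp.einsum('tij,tjk->tik', A2, A1), jnp.einsum('tij,tj->ti', A2, b1) + b2

    _, xs = jax.lax.associative_scan(combine, (As, b))
    return xs  # xs[t-1] = x_t

def nonlinear_rnn_serial(A, u):
    """Reference: the original serial recursion x_t = tanh(A x_{t-1} + u_t)."""
    def step(x, u_t):
        x = jnp.tanh(A @ x + u_t)
        return x, x
    _, xs = jax.lax.scan(step, jnp.zeros(A.shape[0]), u)
    return xs

def fixed_point_rnn(A, u, n_iters=5):
    """Iterate Eq. (rnn-iter) starting from x^{(0)} = 0."""
    x = jnp.zeros_like(u)
    for _ in range(n_iters):
        prev = jnp.concatenate([jnp.zeros((1, u.shape[1])), x[:-1]], axis=0)  # x_{t-1}^{(n-1)}
        rhs = jnp.tanh(prev @ A.T + u) - prev @ A.T  # right-hand side of Eq. (rnn-iter)
        x = linear_rnn_parallel(A, rhs)
    return x
```

Note that carrying a full $d\times d$ matrix per time step in the scan is wasteful for a constant $A$; the diagonalized form in the next section removes this overhead.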

Simplified Form

In fact, the inability to be computed in parallel is only a secondary reason why nonlinear RNNs are slow; the most crucial factor is that they contain a large number of non-element-wise operations, such as the matrix multiplication $Ax_{t-1}$ inside the $\tanh$ in formula $\eqref{eq:rnn}$. Linear RNNs are fast not only because they allow parallel training but, more importantly, because they can be diagonalized, turning the matrix multiplication into element-wise multiplication; and element-wise recursion is not too slow even when computed serially.

When we transform the nonlinear RNN into an iteration of linear RNNs via formula $\eqref{eq:rnn-iter}$, we get to enjoy the same "treatment": the linear RNN can be diagonalized to speed up computation. Specifically, diagonalize $A$ as $P\Lambda P^{-1}$ over the complex field; formula $\eqref{eq:rnn-iter}$ then becomes:

\begin{equation}x_t^{(n)} - P\Lambda P^{-1} x_{t-1}^{(n)} = \tanh(P\Lambda P^{-1} x_{t-1}^{(n-1)} + u_t) - P\Lambda P^{-1} x_{t-1}^{(n-1)}\end{equation}

Multiply both sides by $P^{-1}$ from the left:

\begin{equation}P^{-1} x_t^{(n)} - \Lambda P^{-1} x_{t-1}^{(n)} = P^{-1}\tanh(P\Lambda P^{-1} x_{t-1}^{(n-1)} + u_t) - \Lambda P^{-1} x_{t-1}^{(n-1)}\end{equation}

Letting $y_t = P^{-1} x_t$, the above equation simplifies to:

\begin{equation}y_t^{(n)} - \Lambda y_{t-1}^{(n)} = P^{-1}\tanh(P\Lambda y_{t-1}^{(n-1)} + u_t) - \Lambda y_{t-1}^{(n-1)}\end{equation}

Since an RNN is generally followed by a projection layer, the $P$ in $x_t = P y_t$ can in principle be absorbed into that external projection. In other words, the above equation theoretically has the same expressive power as the original $\eqref{eq:rnn}$, yet because $\Lambda$ is diagonal, the cost of the recursion is greatly reduced. The equation still contains the inverse matrix $P^{-1}$, however, which is both computationally heavy and unfriendly to optimization. Therefore, we might as well replace $P^{-1}$ and $P\Lambda$ with two independent parameter matrices, reusing the symbols $P$ and $Q$ for them:

\begin{equation}y_t^{(n)} - \Lambda y_{t-1}^{(n)} = P\tanh(Q y_{t-1}^{(n-1)} + u_t) - \Lambda y_{t-1}^{(n-1)}\end{equation}

provided the initialization satisfies $PQ=\Lambda$.
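Below is a rough sketch of one step of this simplified iteration (the real-valued $\Lambda$, the shapes, and the initialization recipe are my assumptions; the text allows a complex diagonalization): because $\Lambda$ is diagonal, the parallel scan only needs to carry pairs of vectors rather than matrices.

```python
import jax
import jax.numpy as jnp

def diag_linear_rnn(lam, b):
    """Solve y_t = lam ∘ y_{t-1} + b_t (element-wise, y_0 = 0) via a parallel scan."""
    lams = jnp.broadcast_to(lam, b.shape)

    def combine(left, right):
        l1, b1 = left
        l2, b2 = right
        return l1 * l2, l2 * b1 + b2

    _, ys = jax.lax.associative_scan(combine, (lams, b))
    return ys

def simplified_iteration_step(lam, P, Q, u, y_prev):
    """One step of the simplified iteration; y_prev holds the trajectory y^{(n-1)}."""
    shifted = jnp.concatenate([jnp.zeros_like(y_prev[:1]), y_prev[:-1]], axis=0)  # y_{t-1}^{(n-1)}
    rhs = jnp.tanh(shifted @ Q.T + u) @ P.T - lam * shifted
    return diag_linear_rnn(lam, rhs)

# One possible initialization satisfying PQ = Λ: pick P, then set Q = P^{-1} Λ.
d = 8
P = jax.random.normal(jax.random.PRNGKey(0), (d, d)) / jnp.sqrt(d)
lam = jax.random.uniform(jax.random.PRNGKey(1), (d,), minval=-0.9, maxval=0.9)
Q = jnp.linalg.solve(P, jnp.diag(lam))
```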

The Idea of Perturbation

Assuming $x_t^{(0)}=0$, formula $\eqref{eq:rnn-iter}$ actually decomposes the original nonlinear RNN into a series of linear RNNs:

\begin{equation}\begin{array}{c} x_t^{(1)} - Ax_{t-1}^{(1)} = \tanh(u_t)\\ x_t^{(2)} - Ax_{t-1}^{(2)} = \tanh(Ax_{t-1}^{(1)} + u_t) - Ax_{t-1}^{(1)} \\ \vdots \\ x_t^{(n)} - Ax_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - Ax_{t-1}^{(n-1)} \\ \vdots \\ \end{array}\label{eq:rnns}\end{equation}

Assuming $x_{t-1}, u_t$ are small quantities, applying $\tanh x \approx x$ to the right side of formula $\eqref{eq:rnn}$ gives:

\begin{equation}x_t = \tanh(Ax_{t-1} + u_t) \approx Ax_{t-1} + u_t \approx Ax_{t-1} + \tanh(u_t)\label{eq:rnn-approx}\end{equation}

This is exactly the first equation in $\eqref{eq:rnns}$. Therefore, if the assumption holds, $x_t^{(1)}$ may already be sufficiently close to the ideal $x_t^*$, and each subsequent iteration brings us rapidly closer to it. From this we can see that "subtracting $Ax_{t-1}$ from both sides" is the key step: it makes the first iteration of $\eqref{eq:rnn-iter}$ already close to the first-order linear approximation of the original nonlinear RNN, which speeds up convergence. This is a classic technique in mathematical physics known as "perturbation".
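A quick numerical check of this claim, reusing `nonlinear_rnn_serial` and `fixed_point_rnn` from the sketch in the "Finding Fixed Points" section (the scale factors below are arbitrary assumptions): with small inputs, the first iterate is already close to the exact serial result, and the error should keep dropping with more iterations.

```python
import jax
import jax.numpy as jnp

d, T = 16, 256
A = 0.5 * jax.random.normal(jax.random.PRNGKey(0), (d, d)) / jnp.sqrt(d)  # keep the recursion stable
u = 0.1 * jax.random.normal(jax.random.PRNGKey(1), (T, d))                # "small" inputs
x_star = nonlinear_rnn_serial(A, u)   # exact serial result
for n in (1, 2, 4):
    err = jnp.abs(fixed_point_rnn(A, u, n_iters=n) - x_star).max()
    print(f"n = {n}: max |x^(n) - x*| = {float(err):.2e}")
```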

Speeding Up Convergence

According to the principles of perturbation theory, the key to speeding up convergence is to improve the accuracy of the approximate expansion. For example, a relatively simple improvement is to assume that only $x_{t-1}$ is small. Then, by a first-order Taylor expansion (here $u_t$ is a column vector, and $\circ$ denotes the Hadamard product, with $\text{sech}^2 u_t$ broadcast so that it scales the rows of $A$):

\begin{equation}x_t = \tanh(Ax_{t-1} + u_t) \approx \tanh(u_t) + (\text{sech}^2 u_t\circ A)x_{t-1}\end{equation}

The resulting improvement to formula $\eqref{eq:rnn-iter}$ is:

\begin{equation}x_t^{(n)} - A_t x_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - A_t x_{t-1}^{(n-1)}\label{eq:iter-plus1}\end{equation}

where $A_t = \text{sech}^2 u_t\circ A$. A more refined improvement is to Taylor-expand, at each iteration step, around the result of the previous iteration:

\begin{equation}\begin{aligned} x_t =&\, \tanh(Ax_{t-1} + u_t) \\ \approx&\, \tanh(Ax_{t-1}^{(n-1)} + u_t) + (\text{sech}^2 (Ax_{t-1}^{(n-1)} + u_t)\circ A)(x_{t-1} - x_{t-1}^{(n-1)}) \end{aligned}\end{equation}

So formula $\eqref{eq:rnn-iter}$ becomes:

\begin{equation}x_t^{(n)} - A_t^{(n)} x_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - A_t^{(n)} x_{t-1}^{(n-1)}\label{eq:iter-plus2}\end{equation}

where $A_t^{(n)}=\text{sech}^2 (Ax_{t-1}^{(n-1)} + u_t)\circ A$. This last iteration scheme is essentially "Newton's method" for numerically solving equations, and it enjoys quadratic convergence.
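The following sketch (not from the paper; names and shapes are assumptions) shows one Newton-style step of $\eqref{eq:iter-plus2}$. Since $A_t^{(n)}$ now varies with $t$, each scan element has to carry a (matrix, vector) pair, which is exactly the extra cost discussed in the next section.

```python
import jax
import jax.numpy as jnp

def time_varying_linear_rnn(As, b):
    """Solve x_t = A_t x_{t-1} + b_t (x_0 = 0) via a parallel associative scan."""
    def combine(left, right):
        A1, b1 = left
        A2, b2 = right
        return jnp.einsum('tij,tjk->tik', A2, A1), jnp.einsum('tij,tj->ti', A2, b1) + b2
    _, xs = jax.lax.associative_scan(combine, (As, b))
    return xs

def newton_iteration_step(A, u, x_prev):
    """One step of Eq. (iter-plus2); x_prev holds the trajectory x^{(n-1)}."""
    shifted = jnp.concatenate([jnp.zeros_like(x_prev[:1]), x_prev[:-1]], axis=0)  # x_{t-1}^{(n-1)}
    z = shifted @ A.T + u                       # A x_{t-1}^{(n-1)} + u_t
    sech2 = 1.0 / jnp.cosh(z) ** 2
    A_t = sech2[:, :, None] * A                 # A_t^{(n)} = sech^2(z) ∘ A (scales the rows of A)
    rhs = jnp.tanh(z) - jnp.einsum('tij,tj->ti', A_t, shifted)
    return time_varying_linear_rnn(A_t, rhs)
```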

Why Converge?

Theoretically, the two improvements $\eqref{eq:iter-plus1}$ and $\eqref{eq:iter-plus2}$ can indeed speed up convergence. However, they make the matrix $A$ in each linear recursion depend on $t$, or even on $n$, which significantly complicates parallelization and rules out the diagonalization trick from the "Simplified Form" section. On the other hand, if we stick to an iteration of the form $\eqref{eq:rnn-iter}$, we keep all the efficiency benefits, but convergence is not well guaranteed.

Is the contradiction between these two really irreconcilable? In fact, from my point of view, the most direct approach is to "stop worrying about it." After deriving $\eqref{eq:rnn-iter}$ with the help of nonlinear RNNs, just forget the original nonlinear RNN and treat formula $\eqref{eq:rnn-iter}$ as the basic model. That is to say, why worry whether formula $\eqref{eq:rnn-iter}$ will converge to the original nonlinear RNN? Wouldn't it be better to just treat it as a new starting point? Whatever result gradient descent learns is the result. If gradient descent doesn't learn a result that converges to the original nonlinear RNN, it simply means that not converging to the original RNN is more suitable.

Once you cast off this layer of concern, many problems become clear. First, even though formula $\eqref{eq:iter-plus2}$ has a very good convergence rate in theory, that rate is conditional, and in a deep learning setting ensuring those conditions hold would be a luxury. In other words, even the convergence of formula $\eqref{eq:iter-plus2}$ is not absolutely guaranteed, so criticizing formula $\eqref{eq:rnn-iter}$ on those grounds is rather the pot calling the kettle black. Second, once we treat formula $\eqref{eq:rnn-iter}$ as a new starting point, we can simply view it as a new way of using linear RNNs, or as a way of remedying the drawbacks of linear RNNs (such as not being Turing complete), which is a much more practical perspective.

Overall, setting aside the question of convergence seems more likely to break the mental deadlock and lead to more general results.

General Cases

The preceding "long discourse" centered only on the simple nonlinear RNN, formula $\eqref{eq:rnn}$. What about the more commonly used LSTM and GRU?

Using GRU as an example, its original form is:

\begin{equation}\begin{aligned} z_{t} & = \sigma \left( W_{z} x_{t} + U_{z} h_{t - 1} + b_{z} \right) \\ r_{t} & = \sigma \left( W_{r} x_{t} + U_{r} h_{t - 1} + b_{r} \right) \\ \hat{h}_t & = \tanh \left( W_{h} x_{t} + U_{h} (r_t \circ h_{t - 1}) + b_{c} \right)\\ h_{t} & = \left(1 - z_{t}\right) \circ h_{t - 1} + z_{t} \circ \hat{h}_t \end{aligned}\end{equation}

In the initial stage of training, the gate pre-activations are close to zero, so all gates can be roughly treated as $\sigma(0)=\frac{1}{2}$. Then, imitating $\eqref{eq:rnn-approx}$, we have:

\begin{equation}\begin{aligned} h_{t} &\, = \left(1 - z_{t}\right) \circ h_{t - 1} + z_{t} \circ \hat{h}_t \\ &\, \approx \frac{1}{2} h_{t - 1} + \frac{1}{2} \hat{h}_t \\ &\, \approx \frac{1}{2} h_{t - 1} + \frac{1}{2} \left(\tanh ( W_{h} x_{t} + b_{c} ) + \frac{1}{2}U_{h} h_{t - 1}\right) \\ &\, = \frac{1}{2} \left(I + \frac{1}{2}U_{h}\right)h_{t - 1} + \frac{1}{2} \tanh ( W_{h} x_{t} + b_{c} ) \\ \end{aligned}\end{equation}

So we can choose $A=\frac{1}{2} \left(I + \frac{1}{2}U_{h}\right)$ and rewrite GRU as an iteration:

\begin{equation}\begin{aligned} z_{t}^{(n)} & = \sigma \left( W_{z} x_{t} + U_{z} h_{t - 1}^{(n-1)} + b_{z} \right) \\ r_{t}^{(n)} & = \sigma \left( W_{r} x_{t} + U_{r} h_{t - 1}^{(n-1)} + b_{r} \right) \\ \hat{h}_t^{(n)} & = \tanh \left( W_{h} x_{t} + U_{h} (r_t^{(n)} \circ h_{t - 1}^{(n-1)}) + b_{c} \right)\\ h_{t}^{(n)} & = Ah_{t-1}^{(n)} - Ah_{t-1}^{(n - 1)} + \left(1 - z_{t}^{(n)}\right) \circ h_{t - 1}^{(n-1)} + z_{t}^{(n)} \circ \hat{h}_t^{(n)} \end{aligned}\end{equation}
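A sketch of one such GRU iteration step follows (my own reading of the equations above; parameter shapes, the zero initial state, and the transposed weight layout are assumptions): given the whole trajectory $h^{(n-1)}$, all gates and the right-hand side are computed in parallel over $t$, and $h^{(n)}$ then follows from a linear recursion with the constant matrix $A$.

```python
import jax
import jax.numpy as jnp

def gru_iteration_step(params, x, h_prev):
    """x: inputs (T, d_in); h_prev: previous iterate h^{(n-1)}, shape (T, d)."""
    Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bc = params
    d = Uh.shape[0]
    A = 0.5 * (jnp.eye(d) + 0.5 * Uh)
    shifted = jnp.concatenate([jnp.zeros((1, d)), h_prev[:-1]], axis=0)  # h_{t-1}^{(n-1)}
    z = jax.nn.sigmoid(x @ Wz.T + shifted @ Uz.T + bz)
    r = jax.nn.sigmoid(x @ Wr.T + shifted @ Ur.T + br)
    h_hat = jnp.tanh(x @ Wh.T + (r * shifted) @ Uh.T + bc)
    # h_t^{(n)} = A h_{t-1}^{(n)} + b_t with
    # b_t = -A h_{t-1}^{(n-1)} + (1 - z_t) ∘ h_{t-1}^{(n-1)} + z_t ∘ ĥ_t^{(n)}
    rhs = -shifted @ A.T + (1.0 - z) * shifted + z * h_hat
    As = jnp.broadcast_to(A, (x.shape[0], d, d))

    def combine(left, right):
        A1, b1 = left
        A2, b2 = right
        return jnp.einsum('tij,tjk->tik', A2, A1), jnp.einsum('tij,tj->ti', A2, b1) + b2

    _, hs = jax.lax.associative_scan(combine, (As, rhs))
    return hs
```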

In general, from a practical perspective, this conversion of a nonlinear RNN into an iteration of linear RNNs amounts to using the nonlinear RNN as a guide for how to share and combine the parameters of a multi-layer linear RNN: if it iterates $n$ times, it carries the computational cost of $n$ layers of linear RNNs. This naturally prompts a thought: unless nonlinear RNNs such as GRU and LSTM can be shown to have a decisive advantage, wouldn't it be better to simply stack several layers of "Linear RNN + MLP"?

Summary

This article has briefly explored the parallel computation of nonlinear RNNs: via the "perturbation" idea from mathematical physics, a nonlinear RNN can be transformed into an iteration of linear RNNs, so that the parallelizability of linear RNNs can be leveraged to parallelize the nonlinear one.