Also Discussing RNN's Gradient Vanishing/Exploding Problem

By 苏剑林 | November 13, 2020

Although Transformer-based models have taken over most areas of NLP, RNN models such as LSTM and GRU still have unique value in certain scenarios, so RNNs remain worth studying thoroughly. Analyzing RNN gradients is also an excellent case study in examining a model from the optimization perspective, and it rewards careful consideration. You may have noticed that questions like "Why can LSTM solve gradient vanishing/exploding?" remain among the most popular interview questions today...

[Figure: Classic LSTM]

Many netizens have offered answers to such questions. However, after searching through quite a few articles (including Zhihu answers, columns, and classic English blog posts), I found it hard to locate a truly satisfying one: some use confusing notation in their derivations, some fail to highlight the key points in their exposition, and overall they lack clarity and self-consistency. So I will attempt to give my own understanding here, for your reference.

RNN and Its Gradient

The unified definition of an RNN is:

$$h_t = f\left(x_t, h_{t-1}; \theta\right)$$

where $h_t$ is the output at step $t$, determined jointly by the current input $x_t$ and the previous output $h_{t-1}$, and $\theta$ denotes the trainable parameters. When conducting a basic analysis, we can assume that $h_t, x_t, \theta$ are all one-dimensional; this gives the most intuitive picture, and the results remain informative for the high-dimensional case. The reason we study gradients is that our mainstream optimizers are still gradient descent and its variants, so we require the models we define to have reasonable gradients. Differentiating both sides with respect to $\theta$, we can obtain:

$$\frac{d h_t}{d\theta} = \frac{\partial h_t}{\partial h_{t-1}}\frac{d h_{t-1}}{d\theta} + \frac{\partial h_t}{\partial \theta}$$

As we can see, the gradient of an RNN is also an RNN. The gradient at the current moment $\frac{d h_t}{d\theta}$ is a function of the gradient at the previous moment $\frac{d h_{t-1}}{d\theta}$ and the current operation gradient $\frac{\partial h_t}{\partial \theta}$. At the same time, we can see from the above formula that the phenomenon of gradient vanishing or explosion is almost inevitable: when $\left\|\frac{\partial h_t}{\partial h_{t-1}}\right\| < 1$, it means the historical gradient information decays, so the gradient will inevitably vanish if there are too many steps (similar to $\lim\limits_{n\to\infty} 0.9^n \to 0$); when $\left\|\frac{\partial h_t}{\partial h_{t-1}}\right\| > 1$, it means the historical gradient information gradually increases, so the gradient will inevitably explode if there are too many steps (similar to $\lim\limits_{n\to\infty} 1.1^n \to \infty$). It is impossible for $\left\|\frac{\partial h_t}{\partial h_{t-1}}\right\| = 1$ to hold all the time, right? Of course, it is possible that it is greater than 1 at some moments and less than 1 at others, eventually stabilizing around 1, but the probability of this is very small and requires very sophisticated model design.

Therefore, as the number of steps increases, gradient vanishing or explosion is almost unavoidable; we can only mitigate this problem for a finite number of steps.
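To make the compounding concrete, here is a minimal numeric sketch; the factors 0.9 and 1.1 are just the toy values from the analogy above, not taken from any actual model:

```python
import numpy as np

# The weight on early gradient information is a product of per-step factors
# dh_t/dh_{t-1}; compounding a factor slightly below or above 1 quickly
# vanishes or explodes.
for factor in (0.9, 1.1):
    products = np.cumprod(np.full(100, factor))
    print(f"factor = {factor}: after 10 steps {products[9]:.3g}, "
          f"after 100 steps {products[99]:.3g}")
```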

Vanishing or Exploding?

Having said all that, we still haven't clearly answered one question: what exactly is RNN gradient vanishing/explosion? Gradient explosion is easy to understand: the gradient values diverge and eventually become NaN. But does gradient vanishing mean the gradient becomes zero? Not exactly. As we just said, gradient vanishing occurs when $\left\|\frac{\partial h_t}{\partial h_{t-1}}\right\|$ is consistently less than 1, causing the historical gradient to decay continuously; it does not mean the total gradient becomes 0. Specifically, if we keep iterating the recursion above, we have:

$$\frac{d h_t}{d\theta} = \frac{\partial h_t}{\partial \theta} + \frac{\partial h_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial \theta} + \frac{\partial h_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial h_{t-2}}\frac{\partial h_{t-2}}{\partial \theta} + \dots + \frac{\partial h_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial h_{t-2}}\cdots\frac{\partial h_2}{\partial h_1}\frac{\partial h_1}{\partial \theta}$$

Obviously, as long as $\frac{\partial h_t}{\partial \theta}$ is not 0, the probability of the total gradient being exactly 0 is very small. However, as the expansion shows, the coefficient in front of the earliest term $\frac{\partial h_1}{\partial \theta}$ is the product of $t-1$ factors, $\frac{\partial h_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial h_{t-2}}\cdots\frac{\partial h_2}{\partial h_1}$. If their absolute values are all less than 1, the product tends toward 0, so $\frac{d h_t}{d\theta}$ contains almost no information from the initial gradient $\frac{\partial h_1}{\partial \theta}$. This is the true meaning of gradient vanishing in RNNs: the further a time step lies from the current one, the weaker its feedback gradient signal becomes, until eventually it may have no effect at all. In other words, the RNN's ability to capture long-range semantics has failed.
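To see both facts at once (total gradient nonzero, earliest coefficient vanishing), here is a small sketch that unrolls a one-dimensional RNN of the $\tanh$ form analyzed later in the article; all parameter values are assumptions chosen purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)

# 1-D RNN h_t = tanh(W x_t + U h_{t-1} + b), unrolled for 200 steps.
W, U, b = 0.5, 0.6, 0.0
xs = rng.normal(size=200)
h, dh_dU, coeff = 0.0, 0.0, 1.0
for t, x in enumerate(xs, start=1):
    h_new = np.tanh(W * x + U * h + b)
    local = 1 - h_new ** 2              # tanh'(pre-activation)
    dh_dU = local * (h + U * dh_dU)     # recursion for the total dh_t/dU
    if t > 1:
        coeff *= local * U              # product of the factors dh_t/dh_{t-1}
    h = h_new

# The total gradient stays finite and nonzero, while the coefficient that
# multiplies the earliest term dh_1/dU has all but vanished.
print(f"dh_T/dU = {dh_dU:.4f}, coefficient of dh_1/dU = {coeff:.3g}")
```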

Put simply, if your optimization process has nothing to do with long-distance feedback, how can you guarantee the learned model can effectively capture long-distance dependencies?

A Few Mathematical Formulas

The above text provided general analysis. Next, we will analyze specific RNNs. Before that, however, we need to review several mathematical formulas that we will apply multiple times in subsequent derivations:

$$ \begin{aligned} &\tanh x = 2\sigma(2x) - 1\\ &\sigma(x) = \frac{1}{2}\left(\tanh \frac{x}{2} + 1\right)\\ &(\tanh x)' = 1 - \tanh^2 x\\ &\sigma'(x) = \sigma(x)\left(1 - \sigma(x)\right) \end{aligned} $$

where $\sigma(x) = 1/(1+e^{-x})$ is the sigmoid function. These formulas essentially say one thing: $\tanh x$ and $\sigma(x)$ are basically equivalent (each is just a rescaled and shifted version of the other), and their derivatives can be expressed in terms of the functions themselves.
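These identities are easy to verify numerically; a quick sanity check (using central finite differences for the derivatives):

```python
import numpy as np

sigma = lambda x: 1 / (1 + np.exp(-x))
x = np.linspace(-3, 3, 61)

assert np.allclose(np.tanh(x), 2 * sigma(2 * x) - 1)
assert np.allclose(sigma(x), (np.tanh(x / 2) + 1) / 2)

eps = 1e-6  # derivatives checked against central differences
assert np.allclose((np.tanh(x + eps) - np.tanh(x - eps)) / (2 * eps),
                   1 - np.tanh(x) ** 2, atol=1e-6)
assert np.allclose((sigma(x + eps) - sigma(x - eps)) / (2 * eps),
                   sigma(x) * (1 - sigma(x)), atol=1e-6)
print("all four identities check out")
```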

Analysis of Simple RNN

First to appear is the relatively primitive Simple RNN (sometimes it is literally called SimpleRNN), whose formula is:

$$h_t = \tanh\left(W x_t + U h_{t-1} + b\right)$$

where $W, U, b$ are the parameters to be optimized. Seeing this, a natural question arises: why use $\tanh$ as the activation function instead of the more popular $\text{ReLU}$? This is a good question, and we will answer it shortly.

From the previous discussion, we know that gradient vanishing or explosion depends mainly on $\left\|\frac{\partial h_t}{\partial h_{t-1}}\right\|$, so we calculate:

$$\frac{\partial h_t}{\partial h_{t-1}} = \left(1 - h_t^2\right)U$$

Since we cannot determine the range of $U$, $\left\|\frac{\partial h_t}{\partial h_{t-1}}\right\|$ may be less than 1 or greater than 1, so the risk of gradient vanishing/explosion exists. Interestingly, if $\|U\|$ is very large, then $h_t$ will correspondingly be very close to 1 or -1, so $(1-h_t^2)U$ actually becomes small. In fact, it can be rigorously proven that if we fix $h_{t-1} \neq 0$, then $(1-h_t^2)U$, viewed as a function of $U$, is bounded: no matter what value $U$ takes, its absolute value will not exceed a fixed constant.
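This boundedness is easy to check numerically. In the sketch below, the fixed values of $x_t, h_{t-1}, W, b$ are arbitrary assumptions, and we sweep $U$ over a wide range:

```python
import numpy as np

# With everything but U held fixed, (1 - h_t^2) * U stays bounded no matter
# how large |U| becomes, because 1 - tanh^2 decays exponentially.
x_t, h_prev, W, b = 1.0, 0.5, 0.8, 0.1
U = np.linspace(-100, 100, 200001)
h_t = np.tanh(W * x_t + U * h_prev + b)
grad = (1 - h_t ** 2) * U
print(f"max |(1 - h_t^2) U| over this range: {np.abs(grad).max():.4f}")
```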

In this way, we can answer why $\tanh$ is used as the activation function. Because after using $\tanh$ as the activation function, the corresponding gradient $\frac{\partial h_t}{\partial h_{t-1}}$ is bounded. Although this bound might not be 1, the probability of a bounded quantity being less than or equal to 1 is always higher than that of an unbounded quantity, so the risk of gradient explosion is lower. In contrast, if $\text{ReLU}$ activation is used, its derivative on the positive axis is always 1, and at this time $\frac{\partial h_t}{\partial h_{t-1}}=U$ is unbounded, leading to a higher risk of gradient explosion.

Therefore, the primary purpose of RNNs using $\tanh$ instead of $\text{ReLU}$ is to mitigate the risk of gradient explosion. Of course, this mitigation is relative; there is still a possibility of explosion even with $\tanh$. In fact, the most fundamental way to deal with gradient explosion is weight clipping or gradient clipping. In other words, if I artificially clip $U$ into the range $[-1, 1]$, can't I guarantee that the gradient won't explode? Of course, some readers will ask, since clipping can solve the problem, can we use $\text{ReLU}$? This is indeed the case. With a good initialization method and parameter/gradient clipping schemes, $\text{ReLU}$ versions of RNNs can also be trained well. However, we still prefer to use $\tanh$ because its corresponding $\frac{\partial h_t}{\partial h_{t-1}}$ is bounded, requiring less aggressive clipping, and the model's fitting ability may be better.
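For completeness, here is a minimal NumPy sketch of the two clipping schemes just mentioned (the function names are my own, not from any particular library):

```python
import numpy as np

def clip_weight(w: np.ndarray, lo: float = -1.0, hi: float = 1.0) -> np.ndarray:
    """Weight clipping: force each entry of w back into [lo, hi] after an update."""
    return np.clip(w, lo, hi)

def clip_grad_by_norm(grad: np.ndarray, max_norm: float) -> np.ndarray:
    """Gradient clipping: rescale grad so its L2 norm never exceeds max_norm."""
    norm = np.linalg.norm(grad)
    return grad if norm <= max_norm else grad * (max_norm / norm)
```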

LSTM Results

Of course, while clipping works, it is ultimately a last resort. Moreover, clipping only addresses gradient explosion, not gradient vanishing. It would naturally be best to solve these problems through model design, and the legendary LSTM is claimed to be one such design. Is that true? Let's analyze it right away.

The update formulas for LSTM are quite complex:

$$ \begin{aligned} f_t &= \sigma\left(W_f x_t + U_f h_{t-1} + b_f\right)\\ i_t &= \sigma\left(W_i x_t + U_i h_{t-1} + b_i\right)\\ o_t &= \sigma\left(W_o x_t + U_o h_{t-1} + b_o\right)\\ \hat{c}_t &= \tanh\left(W_c x_t + U_c h_{t-1} + b_c\right)\\ c_t &= f_t \circ c_{t-1} + i_t \circ \hat{c}_t\\ h_t &= o_t \circ \tanh\left(c_t\right) \end{aligned} $$

We could calculate $\frac{\partial h_t}{\partial h_{t-1}}$ as we did before, but as $h_{t} = o_{t} \circ \tanh \left( c_{t} \right)$, analyzing $c_{t}$ is equivalent to analyzing $h_{t}$, and calculating $\frac{\partial c_t}{\partial c_{t-1}}$ is somewhat simpler, so we proceed in that direction.

Similarly, focusing only on the one-dimensional case, we use the chain rule to obtain:

$$\frac{\partial c_t}{\partial c_{t-1}} = f_t + c_{t-1}\frac{\partial f_t}{\partial c_{t-1}} + \hat{c}_t\frac{\partial i_t}{\partial c_{t-1}} + i_t\frac{\partial \hat{c}_t}{\partial c_{t-1}}$$

The first term on the right, $f_t$, is what we call the "forget gate." From the following discussion, we can see that the other three terms are generally minor, so $f_t$ is the "main term." Since $f_t$ is between 0 and 1, this means the risk of gradient explosion will be very small. As for whether the gradient will vanish, it depends on whether $f_t$ is close to 1. Very coincidentally, there is a quite self-consistent conclusion here: if our task relies heavily on historical information, then $f_t$ will be close to 1, and at this time the historical gradient information also happens not to vanish easily; if $f_t$ is very close to 0, it means our task does not depend on historical information, and it doesn't matter if the gradient vanishes at that point.

So, the key now is to check whether the conclusion "the other three terms are minor" holds. The latter three terms all have the form "one quantity multiplied by the partial derivative of another," and the quantities being differentiated are all $\sigma$ or $\tanh$ activations. As stated in the review of mathematical formulas, $\sigma$ and $\tanh$ are basically equivalent, so the three terms behave similarly and analyzing one suffices for the others. Taking the second term as an example and substituting $h_{t-1} = o_{t-1} \tanh\left(c_{t-1}\right)$, we can calculate:

$$c_{t-1}\frac{\partial f_t}{\partial c_{t-1}} = f_t\left(1 - f_t\right) o_{t-1}\left(1 - \tanh^2 c_{t-1}\right) c_{t-1} U_f$$

Note that $f_t, 1 - f_t, o_{t-1}$ are all between 0 and 1, and it can also be proven that $\left|\left(1-\tanh^2 c_{t-1}\right)c_{t-1}\right| < 0.45$, so this factor likewise lies between -1 and 1. Therefore, $c_{t-1}\frac{\partial f_t}{\partial c_{t-1}}$ amounts to a single $U_f$ multiplied by four "gates," and the result is compressed even smaller. So, as long as the initialization is not terrible, this term will be compressed quite significantly and will not play a dominant role. Compared with the Simple RNN gradient $(1 - h_t^2)U$ derived earlier, there are three more gates here. To put it bluntly, the change amounts to this: if one gate can't suppress you, what about a few more?
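The constant 0.45 can be confirmed with a one-line numeric check (the true maximum of $\left|\left(1-\tanh^2 x\right)x\right|$ is about 0.4477):

```python
import numpy as np

x = np.linspace(-20, 20, 400001)
print(np.abs((1 - np.tanh(x) ** 2) * x).max())  # ~= 0.4477 < 0.45
```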

The conclusions for the remaining two terms are similar:

$$ \begin{aligned} \hat{c}_t\frac{\partial i_t}{\partial c_{t-1}} &= i_t\left(1 - i_t\right) o_{t-1}\left(1 - \tanh^2 c_{t-1}\right) \hat{c}_t U_i\\ i_t\frac{\partial \hat{c}_t}{\partial c_{t-1}} &= \left(1 - \hat{c}_t^2\right) o_{t-1}\left(1 - \tanh^2 c_{t-1}\right) i_t U_c \end{aligned} $$

Therefore, the gradients of the latter three terms contain more "gates," and generally speaking, they will be compressed more severely after multiplication. Thus, the dominant term is still $f_t$. The characteristic that $f_t$ is between 0 and 1 determines that the risk of gradient explosion is small, while $f_t$ indicates the model's dependence on historical information, which also happens to be the retention degree of the historical gradient. The two are self-consistent, so LSTM can also effectively mitigate the gradient vanishing problem. Therefore, LSTM mitigates both the gradient vanishing and explosion problems fairly well. Now when we train LSTMs, in most cases, we only need to call adaptive learning rate optimizers like Adam and do not need to manually adjust the gradients anymore.
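As a sanity check on "the other three terms are minor," the following sketch draws the parameters and states of a one-dimensional LSTM at random (a pure illustration; real trained values would differ) and compares the average magnitudes of the four terms:

```python
import numpy as np

rng = np.random.default_rng(42)
sigma = lambda z: 1 / (1 + np.exp(-z))

samples = []
for _ in range(10000):
    Wf, Uf, bf, Wi, Ui, bi, Wc, Uc, bc = rng.normal(0, 1, 9)
    x_t, c_prev = rng.normal(), rng.normal()
    o_prev = sigma(rng.normal())
    h_prev = o_prev * np.tanh(c_prev)
    f_t = sigma(Wf * x_t + Uf * h_prev + bf)
    i_t = sigma(Wi * x_t + Ui * h_prev + bi)
    c_hat = np.tanh(Wc * x_t + Uc * h_prev + bc)
    d = o_prev * (1 - np.tanh(c_prev) ** 2)   # dh_{t-1}/dc_{t-1}
    t1 = f_t                                  # forget gate (main term)
    t2 = f_t * (1 - f_t) * d * c_prev * Uf    # c_{t-1} * df_t/dc_{t-1}
    t3 = i_t * (1 - i_t) * d * c_hat * Ui     # c_hat_t * di_t/dc_{t-1}
    t4 = (1 - c_hat ** 2) * d * i_t * Uc      # i_t * dc_hat_t/dc_{t-1}
    samples.append([t1, abs(t2), abs(t3), abs(t4)])

# The first column (f_t) is typically several times larger than the rest.
print(np.mean(samples, axis=0))
```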

Of course, these results are "probabilistic"; if you insist on constructing an LSTM that would vanish or explode its gradient, you can certainly do so. Furthermore, even if LSTM can mitigate these two problems, it is only within a certain number of steps. If your sequence is very long, such as thousands or tens of thousands of steps, whatever should vanish will still vanish. After all, a single vector cannot cache that much information~

A Quick Look at GRU

Before ending the article, let's also analyze GRU, LSTM's strong competitor. The calculation process for GRU is:

$$ \begin{aligned} z_t &= \sigma\left(W_z x_t + U_z h_{t-1} + b_z\right)\\ r_t &= \sigma\left(W_r x_t + U_r h_{t-1} + b_r\right)\\ \hat{h}_t &= \tanh\left(W_h x_t + U_h\left(r_t \circ h_{t-1}\right) + b_h\right)\\ h_t &= \left(1 - z_t\right)\circ h_{t-1} + z_t \circ \hat{h}_t \end{aligned} $$

There is an even more extreme version in which $r_t$ and $z_t$ are merged into one:

$$ \begin{aligned} z_t &= \sigma\left(W_z x_t + U_z h_{t-1} + b_z\right)\\ \hat{h}_t &= \tanh\left(W_h x_t + U_h\left(z_t \circ h_{t-1}\right) + b_h\right)\\ h_t &= \left(1 - z_t\right)\circ h_{t-1} + z_t \circ \hat{h}_t \end{aligned} $$

Regardless of which one it is, we find that when calculating $\hat{h}_t$, $h_{t-1}$ is first multiplied by $r_t$ to become $r_t \circ h_{t - 1}$. Has any reader ever been confused by this? Wouldn't using $h_{t-1}$ directly be more concise and intuitive?

First, we observe that $h_0$ is generally initialized to all zeros, and $\hat{h}_t$ always lies between -1 and 1 due to the $\tanh$ activation; therefore $h_t$, as a weighted average of $h_{t-1}$ and $\hat{h}_t$, also stays between -1 and 1. Thus $h_t$ itself already functions somewhat like a gate. This differs from the $c_t$ in LSTM, which can in principle diverge. With this understood, we compute the derivative:

$$\frac{\partial h_t}{\partial h_{t-1}} = \left(1 - z_t\right) + \left(\hat{h}_t - h_{t-1}\right) z_t \left(1 - z_t\right) U_z + \left(1 - \hat{h}_t^2\right) z_t r_t \left[1 + \left(1 - r_t\right) h_{t-1} U_r\right] U_h$$

The result is in fact similar to LSTM's. The dominant term should be $1-z_t$, but the remaining terms carry one fewer gate than the corresponding terms in LSTM, so their magnitudes can be larger; compared with the LSTM gradient, this one is inherently less stable. The $r_t \circ h_{t-1}$ operation in particular: although it adds one more gate $r_t$ to the last term, it also introduces the extra factor $1 + (1 - r_t)h_{t-1}U_r$, and it is hard to say whether that is beneficial. On the whole, my feeling is that GRU should be less stable than LSTM and more dependent on good initialization.
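The derivative above is easy to verify against finite differences; in this sketch the parameters of a one-dimensional GRU step are drawn at random purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = lambda z: 1 / (1 + np.exp(-z))

Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh = rng.normal(0, 1, 9)
x_t = rng.normal()

def step(h_prev):
    z = sigma(Wz * x_t + Uz * h_prev + bz)
    r = sigma(Wr * x_t + Ur * h_prev + br)
    h_hat = np.tanh(Wh * x_t + Uh * r * h_prev + bh)
    return (1 - z) * h_prev + z * h_hat

h_prev, eps = 0.3, 1e-6
numeric = (step(h_prev + eps) - step(h_prev - eps)) / (2 * eps)

# Analytic form from the text.
z = sigma(Wz * x_t + Uz * h_prev + bz)
r = sigma(Wr * x_t + Ur * h_prev + br)
h_hat = np.tanh(Wh * x_t + Uh * r * h_prev + bh)
analytic = ((1 - z)
            + (h_hat - h_prev) * z * (1 - z) * Uz
            + (1 - h_hat ** 2) * z * r * (1 + (1 - r) * h_prev * Ur) * Uh)
print(numeric, analytic)  # should agree to ~1e-9
```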

Based on the above analysis, I personally believe that if we follow the GRU philosophy of simplifying LSTM but want to keep it gradient-friendly, a better approach is to move the $r_t \circ$ operation to the very end:

$$ \begin{aligned} z_t &= \sigma\left(W_z x_t + U_z h_{t-1} + b_z\right)\\ r_t &= \sigma\left(W_r x_t + U_r h_{t-1} + b_r\right)\\ \hat{h}_t &= \tanh\left(W_h x_t + U_h \hat{h}_{t-1} + b_h\right)\\ h_t &= \left(1 - z_t\right)\circ h_{t-1} + z_t \circ r_t \circ \hat{h}_t \end{aligned} $$

Of course, this would require caching an extra variable, leading to additional memory consumption.
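A minimal one-dimensional sketch of this variant (my reading of the formulas above; the parameter dict `p` is hypothetical) makes the extra state explicit: the step must now carry both `h` and `h_hat` across time, which is exactly the additional cached variable:

```python
import numpy as np

sigma = lambda z: 1 / (1 + np.exp(-z))

def step(x_t, h_prev, h_hat_prev, p):
    """One step of the proposed variant; returns BOTH recurrent states."""
    z = sigma(p["Wz"] * x_t + p["Uz"] * h_prev + p["bz"])
    r = sigma(p["Wr"] * x_t + p["Ur"] * h_prev + p["br"])
    h_hat = np.tanh(p["Wh"] * x_t + p["Uh"] * h_hat_prev + p["bh"])
    h = (1 - z) * h_prev + z * r * h_hat
    return h, h_hat
```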

Article Summary

This article discussed the gradient vanishing/explosion problem in RNNs, mainly using the boundedness of the gradient terms and the number of gates they contain to clarify the gradient flow in models such as (Simple) RNN, LSTM, and GRU, and to gauge the magnitude of their respective risks. This article is a product of my own deduction; if there are any errors or omissions, corrections are welcome.