By 苏剑林 | November 08, 2021
A few days ago in a group discussion, the question "How does the Transformer solve gradient vanishing?" came up. Some answers mentioned residual connections, others mentioned LN (Layer Norm). Are these the correct answers? In fact, this is a very interesting and comprehensive question that touches on many model details, such as "Why does BERT need warmup?", "Why is BERT's initial standard deviation 0.02?", "Why add an extra Dense layer before MLM prediction?", and so on. This article focuses on discussing these questions.
What Does "Gradient Vanishing" Mean?
In the article "Also Talking About the RNN Gradient Vanishing/Explosion Problem", we discussed the gradient vanishing problem in RNNs. In fact, the phenomenon of gradient vanishing in general models is similar: it refers to the fact that (mainly in the initial stage of the model) the closer a layer is to the input, the smaller its gradient becomes, tending toward zero or even equalling zero. Since we mainly use gradient-based optimizers, gradient vanishing means we lack a good signal to adjust and optimize the earlier layers.
In other words, the earlier layers might receive almost no updates and remain in a state of random initialization; only the layers closer to the output are updated well. However, the input to these layers is the output of the earlier, poorly updated layers, so the input quality might be terrible (having passed through a nearly random transformation). Therefore, even if the later layers are optimized, the overall effect is poor. Ultimately, we observe a counter-intuitive phenomenon: the deeper the model, the worse the performance, even on the training set.
A standard method to solve gradient vanishing is residual connections, formally proposed in ResNet. The idea of residuals is simple and direct: worried that the input's gradient will vanish? Then simply add a term whose gradient is constant. Concretely, the model becomes:
\begin{equation}y = x + F(x)\end{equation}
In this way, because there is a "direct path" $x$, even if the gradient of $x$ through $F(x)$ vanishes, the gradient through $x$ is essentially preserved, allowing deep models to be trained effectively.
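To make this concrete, here is a minimal PyTorch sketch (my own illustration, not from the original article) comparing the gradient that reaches the input of a plain deep stack versus a residual stack at random initialization; the depth, width, and activation below are arbitrary choices.

```python
# A minimal sketch: compare the gradient reaching the input of a plain deep
# stack vs. a residual stack at random initialization.
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, depth = 64, 50
plain = [nn.Sequential(nn.Linear(dim, dim), nn.Tanh()) for _ in range(depth)]
resid = [nn.Sequential(nn.Linear(dim, dim), nn.Tanh()) for _ in range(depth)]

x = torch.randn(8, dim, requires_grad=True)
h = x
for layer in plain:
    h = layer(h)                 # y = F(x): no direct path
h.sum().backward()
print("plain    input-grad norm:", x.grad.norm().item())

x = torch.randn(8, dim, requires_grad=True)
h = x
for layer in resid:
    h = h + layer(h)             # y = x + F(x): identity path preserves gradient
h.sum().backward()
print("residual input-grad norm:", x.grad.norm().item())
```

In the residual stack, every layer contributes an identity term to the Jacobian, so the gradient arriving at the input stays at a healthy scale.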
Does LN Really Alleviate Gradient Vanishing?
However, in BERT and the original Transformer, a Post-Norm design is used, where the Norm operation is added after the residual:
\begin{equation}x_{t+1} = \text{Norm}(x_t + F_t(x_t))\end{equation}
In fact, the specific Norm method is not very important; whether it's Batch Norm or Layer Norm, the conclusion is similar. In the article "Brief Talk on Transformer Initialization, Parameterization, and Standardization", we analyzed this Norm structure; let's repeat it here.
In the initialization phase, since all parameters are randomly initialized, we can consider $x$ and $F(x)$ as two independent random vectors. If we assume each has a variance of 1, then the variance of $x+F(x)$ is 2. The $\text{Norm}$ operation is responsible for resetting the variance to 1. Thus, in the initialization phase, the $\text{Norm}$ operation is equivalent to "dividing by $\sqrt{2}$":
\begin{equation}x_{t+1} = \frac{x_t + F_t(x_t)}{\sqrt{2}}\end{equation}
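This is easy to verify numerically; below is a small NumPy sketch under the same assumptions (independent, unit-variance $x$ and $F(x)$ at initialization).

```python
# Quick numerical check: x and F(x) independent with unit variance implies
# Var(x + F(x)) ≈ 2, so normalizing back to unit variance ≈ dividing by sqrt(2).
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)
Fx = rng.standard_normal(1_000_000)   # stands in for F(x) at random init

s = x + Fx
print(np.var(s))                # ≈ 2
print(np.var(s / np.sqrt(2)))   # ≈ 1, i.e. what Norm does at initialization
```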
Recursively, this becomes:
\begin{equation}\begin{aligned}
x_l =&\, \frac{x_{l-1}}{\sqrt{2}} + \frac{F_{l-1}(x_{l-1})}{\sqrt{2}} \\
=&\, \frac{x_{l-2}}{2} + \frac{F_{l-2}(x_{l-2})}{2} + \frac{F_{l-1}(x_{l-1})}{\sqrt{2}} \\
=&\, \cdots \\
=&\, \frac{x_0}{2^{l/2}} + \frac{F_0(x_0)}{2^{l/2}} + \frac{F_1(x_1)}{2^{(l-1)/2}} + \frac{F_2(x_2)}{2^{(l-2)/2}} + \cdots + \frac{F_{l-1}(x_{l-1})}{2^{1/2}}
\end{aligned}\end{equation}
We know that residuals help solve gradient vanishing, but in Post-Norm, the residual path is severely weakened. The closer to the input, the more severe the weakening—the residual exists "in name only." Therefore, in Post-Norm BERT models, LN not only fails to alleviate gradient vanishing, it is actually one of the "culprits" behind it.
So Why Do We Still Add LN?
Then the question naturally arises: since LN exacerbates gradient vanishing, why not just remove it?
It can be removed, but as mentioned before, the variance of $x+F(x)$ is 2, and the more residuals there are, the larger the variance becomes. So, a Norm operation is still needed. We can add it to the input of each module, i.e., $x+F(\text{Norm}(x))$, with a final $\text{Norm}$ at the very end; this is the Pre-Norm structure. In this case, each residual branch is weighted equally, rather than having an exponential decay as in Post-Norm. Of course, there are also methods that omit Norm entirely but require special initialization for $F(x)$ to keep its initial output close to 0, such as ReZero, Skip Init, Fixup, etc. These were also introduced in "Brief Talk on Transformer Initialization, Parameterization, and Standardization".
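For reference, here is a minimal PyTorch sketch (illustrative only, not BERT's actual implementation) of the two block structures discussed above, with `sublayer` standing in for $F$ (an attention or feed-forward sublayer):

```python
# Illustrative sketch: Post-Norm vs. Pre-Norm residual blocks.
import torch.nn as nn

class PostNormBlock(nn.Module):
    def __init__(self, dim, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # x_{t+1} = Norm(x_t + F(x_t)): the residual path is also normalized
        return self.norm(x + self.sublayer(x))

class PreNormBlock(nn.Module):
    def __init__(self, dim, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # x_{t+1} = x_t + F(Norm(x_t)): the residual path is left untouched,
        # so a final LayerNorm is applied after the last block instead
        return x + self.sublayer(self.norm(x))
```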
But setting these improvements aside, does Post-Norm have no merits? Could it be that Transformer and BERT started with a completely failed design?
That is highly unlikely. Although Post-Norm brings some gradient vanishing issues, it also has advantages in other areas. Most obviously, it stabilizes the numerical values of forward propagation and maintains consistency across each module. For example, in BERT-base, we can attach a Dense layer to the last layer for classification, or take the 6th layer and attach a Dense layer; but if you use Pre-Norm, after taking out the middle layer, you need to add an LN yourself before the Dense layer, otherwise the variance increases in later layers, which is unfavorable for optimization.
Secondly, gradient vanishing is not entirely "bad"; for the Fine-tuning stage, it is actually a benefit. During Fine-tuning, we usually hope to prioritize adjusting the parameters closer to the output layer and avoid over-adjusting the parameters closer to the input layer to prevent severely damaging the pre-training effects. Gradient vanishing means that layers closer to the input have a weaker impact on the final output, which is exactly what is desired during Fine-tuning. Therefore, pre-trained Post-Norm models often have better Fine-tuning effects than Pre-Norm models, as we also mentioned in "RealFormer: Moving Residuals to the Attention Matrix".
Are We Really Worried About Gradient Vanishing?
Actually, the most critical reason is that under current adaptive optimization techniques, we are no longer very worried about gradient vanishing.
This is because the mainstream optimizers in current NLP are Adam and its variants. For Adam, because it includes momentum and second-moment correction, its update amount is approximately:
\begin{equation}\Delta \theta = -\eta\frac{\mathbb{E}_t[g_t]}{\sqrt{\mathbb{E}_t[g_t^2]}}\end{equation}
As can be seen, the numerator and denominator are of the same dimension, so the result of the fraction is actually of the order $\mathcal{O}(1)$, and the update amount is of the order $\mathcal{O}(\eta)$. This means that, theoretically, as long as the absolute value of the gradient is greater than the random error, the corresponding parameters will receive a constant-order update. This is different from SGD, whose update amount is proportional to the gradient: if the gradient is small, the update is also small, and if the gradient is too small, the parameters are hardly updated at all.
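A toy experiment makes this concrete (my own sketch; the learning rate, step count, and gradient scales below are arbitrary): feed a constant gradient of varying magnitude to Adam and SGD and see how far the parameter moves.

```python
# Toy illustration: with Adam, the displacement stays ≈ lr per step even when
# the raw gradient shrinks by orders of magnitude; with SGD it shrinks with it.
import torch

for scale in (1.0, 1e-3, 1e-6):
    p_adam = torch.nn.Parameter(torch.zeros(1))
    p_sgd = torch.nn.Parameter(torch.zeros(1))
    adam = torch.optim.Adam([p_adam], lr=1e-3)
    sgd = torch.optim.SGD([p_sgd], lr=1e-3)
    for _ in range(100):
        p_adam.grad = torch.full((1,), scale)   # pretend this is the gradient
        p_sgd.grad = torch.full((1,), scale)
        adam.step()
        sgd.step()
    print(f"grad={scale:g}  Adam moved {abs(p_adam.item()):.4f}, "
          f"SGD moved {abs(p_sgd.item()):.1e}")
```

Adam moves the parameter by roughly $\eta$ per step regardless of the gradient's scale, while SGD's displacement shrinks in proportion to it.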
Therefore, although the residuals in Post-Norm are severely weakened, in base and large-scale models, they are not yet weakened to the point of being smaller than the random error. Thus, combined with optimizers like Adam, they can still be updated effectively and can potentially be trained successfully. Of course, this is only a possibility; in fact, deeper Post-Norm models are indeed harder to train, requiring careful adjustment of learning rates and Warmup.
How Does Warmup Work?
You may have heard that Warmup is a key step in training Transformers; without it, the model may not converge or may converge to a poor position. Why is this? Didn't we say that with Adam, we are no longer afraid of gradient vanishing?
It is important to note that Adam solves the problem of small parameter update amounts caused by gradient vanishing—meaning the update amount won't be too small regardless of whether gradients vanish. However, for Post-Norm models, gradient vanishing still exists, but its meaning has changed. According to the Taylor expansion:
\begin{equation}f(x+\Delta x) \approx f(x) + \langle\nabla_x f(x), \Delta x\rangle\end{equation}
To first order, the increment $f(x+\Delta x) - f(x)$ is proportional to the gradient. In other words, the gradient measures how strongly the output depends on the input. If the gradient vanishes, it means the model's output has become less dependent on the input.
Warmup means slowly increasing the learning rate from 0 to the target value at the start of training, rather than training at the target learning rate from the very first step. Without Warmup, the model learns quickly from the beginning. Due to gradient vanishing, the model is more sensitive to the later layers, so they learn faster. However, the later layers take the output of the earlier layers as input, and since the earlier layers have not learned well yet, the later layers, though learning fast, are built on a poor input foundation.
Quickly, the later layers reach a poor local optimum based on poor input. At this point, their learning begins to slow down (because they have reached what they consider an optimum), while the gradient signal back-propagated to the earlier layers further weakens. This leads to inaccurate gradients for the earlier layers. But as we said, Adam's update amount is of a constant order; if the gradient is inaccurate but the update is still of that order, it might effectively be a constant-order random noise. Thus, the learning direction becomes unreasonable, and the output of the earlier layers begins to collapse, leading the later layers to collapse as well.
Therefore, if a Post-Norm model is trained without Warmup, the typical observation is that the loss quickly converges to near a constant, and then, after some further training, starts to diverge until it becomes NaN. With Warmup, the model is given enough time to "preheat": during this process, it is mainly the learning speed of the later layers that is suppressed, giving the earlier layers more time to be optimized and promoting synchronized optimization across all layers.
The discussion here assumes gradient vanishing exists; in cases like Pre-Norm where there is no obvious gradient vanishing, successful training is often possible without Warmup.
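For completeness, here is a minimal sketch of a linear Warmup schedule of the kind described above (the peak learning rate, warmup length, and post-warmup decay are arbitrary choices, not BERT's exact settings):

```python
# Sketch of a linear Warmup schedule: ramp the learning rate from 0 to the peak
# over `warmup_steps` steps, then (here) decay it linearly back toward 0.
def warmup_lr(step, peak_lr=1e-4, warmup_steps=10_000, total_steps=1_000_000):
    if step < warmup_steps:
        return peak_lr * step / warmup_steps          # ramp up from 0
    # after warmup: linear decay (other decay schemes also work)
    return peak_lr * (total_steps - step) / (total_steps - warmup_steps)

print(warmup_lr(0), warmup_lr(5_000), warmup_lr(10_000), warmup_lr(500_000))
```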
Why is the Initial Standard Deviation 0.02?
Readers who pay attention to details will notice that BERT's default initialization method is a truncated normal distribution with a standard deviation of 0.02. In "Brief Talk on Transformer Initialization, Parameterization, and Standardization", we also mentioned that because it is a truncated normal distribution, the actual standard deviation is smaller, approximately $0.02/1.1368472 \approx 0.0176$. Is this standard deviation large or small? For Xavier initialization, an $n \times n$ matrix should be initialized with a variance of $1/n$. For BERT-base, $n$ is 768, so the corresponding standard deviation is $1/\sqrt{768} \approx 0.0361$. In other words, the default initialization standard deviation is significantly smaller, roughly half of the usual initialization standard deviation.
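These numbers are easy to reproduce; below is a small sketch using SciPy's truncated normal (768 is BERT-base's hidden size).

```python
# Reproducing the numbers above: the std of a normal truncated at ±2σ, and the
# Xavier standard deviation for an n×n matrix with n = 768.
import numpy as np
from scipy.stats import truncnorm

shrink = truncnorm(-2, 2).std()   # ≈ 0.8796, i.e. 1 / 1.1368472
print(0.02 * shrink)              # ≈ 0.0176, the effective BERT init std
print(1 / np.sqrt(768))           # ≈ 0.0361, Xavier std for n = 768
```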
Why does BERT use a smaller standard deviation for initialization? In fact, this is still related to the Post-Norm design. A smaller standard deviation makes the output of each residual branch $F(x)$ smaller overall, so the Post-Norm design stays closer to an identity function in the initialization phase, which is more conducive to optimization. Specifically, following the previous assumptions, if the variance of $x$ is 1 and the variance of $F(x)$ is $\sigma^2$, then in the initialization stage the $\text{Norm}$ operation is equivalent to dividing by $\sqrt{1+\sigma^2}$. If $\sigma$ is small, the weight of the "direct path" in the residual is closer to 1, so the model is closer to an identity function in the initial stage and thus less prone to gradient vanishing.
As the saying goes, "We aren't afraid of gradient vanishing, but we don't want it either." Simply setting the initialization standard deviation smaller can make $\sigma$ smaller, thereby alleviating gradient vanishing while maintaining Post-Norm. Why not do it? Can it be set even smaller or even to zero? Generally, too small an initialization leads to a loss of diversity and narrows the model's trial-and-error space, which can bring negative effects. On balance, shrinking it to 1/2 of the standard is a relatively reliable choice.
Of course, some people do like to push the limits. Recently, I saw a paper that attempts a nearly all-zero initialization of the entire model and still achieves decent results. If you're interested, you can read "ZerO Initialization: Initializing Residual Networks with only Zeros and Ones".
Why Add an Extra Dense Layer for MLM?
Finally, regarding a detail of BERT's MLM model: why does BERT add an extra Dense layer and LN layer before the MLM probability prediction? Is it okay not to add them?
The answers I saw previously generally suggested that layers closer to the output are more Task-Specific. By adding an extra Dense layer, we hope this Dense layer is MLM-Specific, and then discard it during downstream fine-tuning because it's no longer MLM-Specific. This explanation seems somewhat reasonable, but it feels a bit metaphysical, as "Task-Specific" is not easy to analyze quantitatively.
Here I provide a more concrete explanation. In fact, it is still directly related to BERT's use of 0.02 standard deviation for initialization. As we just said, this initialization is quite small. If we predict the probability distribution by multiplying with Embeddings without adding an extra Dense layer, then the resulting distribution would be too uniform (before Softmax, every logit is close to 0), so the model wants to scale the values up. Now the model has two choices: first, scale up the values of the Embedding layer, but updating the Embedding layer is sparse and scaling them one by one is troublesome; second, scale up the input. We know the last layer of the BERT encoder is LN, and LN has a gamma parameter initialized to 1; just scaling up that parameter works.
Model optimization uses gradient descent, and we know it will choose the fastest path. Clearly, the second choice is faster, so the model will prioritize the second path. This leads to a phenomenon: the gamma value of the last LN layer will be relatively large. If a Dense+LN is not added before predicting the MLM probability distribution, then the gamma value of the last LN layer of the BERT encoder will be large, causing the variance of the last layer to be significantly larger than that of other layers, which is clearly not elegant. By adding an extra Dense+LN, the larger gamma is transferred to the new LN, while each layer of the encoder maintains consistency.
In fact, readers can observe the gamma values of each LN layer in BERT themselves, and they will find that the gamma value of the last LN layer is indeed significantly larger. This verifies our hypothesis!
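If you want to run this check yourself, the sketch below is one way to do it, assuming the HuggingFace `transformers` library and the `bert-base-uncased` checkpoint (parameter names may differ across library versions):

```python
# Print the mean |gamma| of every LayerNorm in a pretrained BERT MLM model.
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
for name, param in model.named_parameters():
    if name.endswith("LayerNorm.weight"):   # the gamma of each LayerNorm
        print(f"{name:60s} mean |gamma| = {param.abs().mean().item():.3f}")
```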
Conclusion
This article has attempted to answer several questions related to the model optimization of Transformer and BERT. Some are results I found in my own pre-training work, and some are intuitive imaginations combined with my experience. In any case, consider this a reference answer. If there are any inaccuracies, I hope you will be patient and offer corrections.