By 苏剑林 | March 09, 2022
As is well known, Transformers keep growing larger, but this growth is usually in "width" rather than "depth." For example, although GPT-3 has hundreds of billions of parameters, it is only a 96-layer Transformer, far shallower than one might imagine. What keeps Transformers from growing "deeper"? Some readers might point to compute, but a "wide and shallow" model does not need significantly less compute than a "narrow and deep" one, so compute is not the main constraint; the root cause is the inherent difficulty of training deep Transformers. The common view is that deep models are hard to train because of vanishing or exploding gradients, yet in practice, even when the gradients are improved by various means, deep models remain hard to train.
Recent work (such as Admin) points out that the fundamental difficulty in training deep models is "increment explosion": the deeper the model, the larger the perturbation to its output. Last week's paper, "DeepNet: Scaling Transformers to 1,000 Layers," follows this line of thought with a magnitude analysis and adjusts the model's normalization and initialization schemes accordingly, ultimately training a 1,000-layer Transformer. The whole analysis is instructive, so let us work through it.
The full analysis in the original paper is quite long, and some assumptions or descriptions are not entirely reasonable upon closer inspection. In this post, I will try to correct these issues and attempt to derive similar results in a more logical manner.
Suppose the loss function is $\mathcal{L}(\boldsymbol{\theta})$, where $\boldsymbol{\theta}$ represents its parameters. Consider the increment of the loss function when the parameters change from $\boldsymbol{\theta}$ to $\boldsymbol{\theta}+\Delta\boldsymbol{\theta}$:
\begin{equation}\Delta\mathcal{L} = \mathcal{L}(\boldsymbol{\theta}+\Delta\boldsymbol{\theta}) - \mathcal{L}(\boldsymbol{\theta}) \approx \langle\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}),\Delta\boldsymbol{\theta}\rangle\end{equation}For SGD, we have $\Delta\boldsymbol{\theta}=-\eta \nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})$, so $\Delta\mathcal{L} \approx -\eta\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert^2$. Suppose the model has $N$ layers, each with $K$ parameter matrices ($K$ roughly constant). With Xavier initialization and the usual normalization techniques, the gradient norm of each parameter matrix can be kept at $\mathcal{O}(1)$. Since $\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert^2$ sums the squared norms of all $NK$ matrix gradients, we get $\Delta\mathcal{L}=\mathcal{O}(\eta NK)$. In other words, the change per step grows in proportion to the depth $N$: the deeper the model, the larger each update, so early in training the model is more likely to fall into a poor local optimum, causing training to stall or even collapse. This is the "increment explosion" problem.
There are two solutions at this point. One is to use a smaller learning rate in the initial stage (no more than the order of $\eta/N$) and then gradually increase it; this is the Warmup technique. The second is to adjust the initialization scheme so that the gradient of the parameters is of the order $\mathcal{O}(1/\sqrt{N})$, which automatically offsets the impact of the model depth.
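The scaling $\Delta\mathcal{L}=\mathcal{O}(\eta NK)$ and the two remedies can be made concrete with a toy sketch. This is a hypothetical illustration, not from the paper: it fabricates per-matrix gradients with unit Frobenius norm (the $\mathcal{O}(1)$ assumption), uses an arbitrary $K=4$ matrices per layer, and evaluates the first-order SGD loss change $-\eta\sum\Vert\boldsymbol{g}\Vert^2$.

```python
import numpy as np

def first_order_loss_change(eta, grads):
    # SGD, first order: Delta L ≈ -eta * (sum of squared Frobenius norms of all gradients)
    return -eta * sum(float(np.sum(g * g)) for g in grads)

def synthetic_grads(num_layers, per_layer=4, shape=(8, 8), norm=1.0, seed=0):
    # Hypothetical stand-in gradients: num_layers * per_layer matrices,
    # each rescaled to Frobenius norm `norm`
    rng = np.random.default_rng(seed)
    grads = []
    for _ in range(num_layers * per_layer):
        g = rng.standard_normal(shape)
        grads.append(norm * g / np.linalg.norm(g))
    return grads

eta, K = 0.1, 4
for N in (6, 24, 96):
    plain = first_order_loss_change(eta, synthetic_grads(N, K))
    warmup = first_order_loss_change(eta / N, synthetic_grads(N, K))  # remedy 1: lr ~ eta/N
    rescaled = first_order_loss_change(eta, synthetic_grads(N, K, norm=1 / np.sqrt(N)))  # remedy 2
    print(f"N={N:3d}  plain={plain:8.2f}  lr=eta/N: {warmup:.2f}  grad~1/sqrt(N): {rescaled:.2f}")
```

The "plain" column grows as $-\eta NK$, while both remedies keep the first-order change at $-\eta K$ regardless of depth.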
How do we achieve the second solution? We can try to analyze the gradient of the Transformer. However, computing the exact gradient is tedious, and we do not actually need it: a magnitude analysis suffices. We can therefore use the following "magnitude decomposition" trick to turn the problem into scalar derivatives.
For a matrix $\boldsymbol{W}$, we decompose it into the form $\boldsymbol{W}=\lambda \boldsymbol{U}$, where
\begin{equation}\lambda = \mathop{\text{argmin}}_{\kappa > 0} \Vert \boldsymbol{W}\boldsymbol{W}^{\top}/\kappa^2 - \boldsymbol{I}\Vert\end{equation}In simple terms, we factor the matrix into a scalar $\lambda$ times a matrix $\boldsymbol{U}$ that is as close to orthogonal as possible. Since $\boldsymbol{U}$ is nearly orthogonal, it serves as a standard reference, and $\lambda$ captures the magnitude of $\boldsymbol{W}$. If $\boldsymbol{W}$ is Xavier-initialized, $\lambda$ plays the role of a "gain": the Xavier-initialized matrix is further multiplied by $\lambda$. This works because a Xavier-initialized matrix is already close to orthogonal; see "Understanding Model Parameter Initialization Strategy from a Geometric Perspective".
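The post does not say which norm appears in the argmin. Assuming the Frobenius norm, $\lambda$ has a closed form: with $\boldsymbol{A}=\boldsymbol{W}\boldsymbol{W}^{\top}$ and $c=1/\kappa^2$, the squared objective $c^2\,\text{tr}(\boldsymbol{A}^2) - 2c\,\text{tr}(\boldsymbol{A}) + n$ is minimized at $c=\text{tr}(\boldsymbol{A})/\text{tr}(\boldsymbol{A}^2)$. A minimal sketch under that assumption (function names are our own):

```python
import numpy as np

def magnitude_decompose(W):
    """Split W into lam * U with U as close to orthogonal as possible.

    Assumes the Frobenius norm in the argmin, which gives
    kappa^2 = tr(A^2) / tr(A) with A = W W^T.
    """
    A = W @ W.T
    lam = np.sqrt(np.trace(A @ A) / np.trace(A))
    return lam, W / lam

def orthogonality_gap(W, kappa):
    # The objective inside the argmin: ||W W^T / kappa^2 - I||_F
    return np.linalg.norm(W @ W.T / kappa**2 - np.eye(W.shape[0]))

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((64, 64)))  # an exactly orthogonal matrix
lam, U = magnitude_decompose(3.0 * Q)
print(lam)  # ≈ 3.0: the gain is recovered

# Xavier (Glorot) normal init: std = sqrt(2 / (fan_in + fan_out))
n_in, n_out = 64, 256
W_xavier = rng.normal(0.0, np.sqrt(2.0 / (n_in + n_out)), size=(n_in, n_out))
lam_x, U_x = magnitude_decompose(W_xavier)
print(lam_x, orthogonality_gap(W_xavier, lam_x))  # gap shows how far from exactly orthogonal
```

For a random matrix the decomposition is only approximate; the gap quantifies how well "close to orthogonal" holds for a given shape.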
Under this decomposition, we have
\begin{equation}\frac{\partial \mathcal{L}(\lambda \boldsymbol{U})}{\partial \lambda} = \left\langle\frac{\partial \mathcal{L}(\lambda \boldsymbol{U})}{\partial (\lambda \boldsymbol{U})}, \boldsymbol{U}\right\rangle = \left\langle\frac{\partial \mathcal{L}(\boldsymbol{W})}{\partial \boldsymbol{W}}, \boldsymbol{U}\right\rangle\end{equation}This means that $\frac{\partial \mathcal{L}}{\partial \lambda}$ is proportional to $\frac{\partial \mathcal{L}}{\partial \boldsymbol{W}}$ in terms of magnitude. Therefore, performing a magnitude analysis on $\frac{\partial \mathcal{L}}{\partial \lambda}$ is equivalent to performing a magnitude analysis on $\frac{\partial \mathcal{L}}{\partial \boldsymbol{W}}$. In this way, $\frac{\partial \mathcal{L}}{\partial \lambda}$ acts as a simple "probe" for the magnitude of $\frac{\partial \mathcal{L}}{\partial \boldsymbol{W}}$, converting matrix differentiation into scalar differentiation and reducing the difficulty of the analysis.
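The probe identity $\frac{\partial \mathcal{L}}{\partial \lambda}=\left\langle\frac{\partial \mathcal{L}}{\partial \boldsymbol{W}}, \boldsymbol{U}\right\rangle$ is just the chain rule and can be checked numerically. A minimal sketch using a toy least-squares loss (a hypothetical choice; any smooth loss works) and a random orthogonal $\boldsymbol{U}$:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((5, 6))
y = rng.standard_normal((5, 6))

def loss(W):
    # Toy loss L(W) = 0.5 * ||x W - y||_F^2 (hypothetical)
    return 0.5 * np.sum((x @ W - y) ** 2)

def grad_W(W):
    # Analytic matrix gradient dL/dW of the toy loss
    return x.T @ (x @ W - y)

U, _ = np.linalg.qr(rng.standard_normal((6, 6)))  # orthogonal "reference frame"
lam, h = 0.7, 1e-5

# Scalar derivative dL/dlam by central differences (exact up to rounding: L is quadratic in lam)
probe_fd = (loss((lam + h) * U) - loss((lam - h) * U)) / (2 * h)
# Right-hand side: Frobenius inner product <dL/dW, U> evaluated at W = lam * U
probe_identity = float(np.sum(grad_W(lam * U) * U))
print(probe_fd, probe_identity)
```

The identity itself holds for any $\boldsymbol{U}$; near-orthogonality only matters for reading $\lambda$ as the magnitude of $\boldsymbol{W}$.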
Many experimental results show that although Pre Norm is easier to train than Post Norm, the final performance of Post Norm is often better. Therefore, the original paper retains the Post Norm structure and considers a more general form (DeepNorm):
\begin{equation}\boldsymbol{x}_{l+1} = \text{LN}(\alpha\boldsymbol{x}_l + F(\boldsymbol{x}_l)) = \text{LN}(\boldsymbol{x}_l + F(\boldsymbol{x}_l)/\alpha)\end{equation}where $\alpha > 0$ is a constant, and the second equality holds because LN is invariant to rescaling its input (ignoring its small $\epsilon$ term). For simplicity, let's first consider the FFN layer, in which case:
\begin{equation}\boldsymbol{x}_{l+1} = \text{LN}(\boldsymbol{x}_l + \phi(\boldsymbol{x}_l \boldsymbol{W}_1)\boldsymbol{W}_2/\alpha)\end{equation}Here $\phi$ is the activation function, usually ReLU or its variants (Swish, GeLU, etc.), which (approximately) satisfy $\phi(\lambda x) = \lambda \phi(x), \forall \lambda > 0$. Using the magnitude decomposition probe from the previous section, we get:
\begin{equation}\boldsymbol{x}_{l+1} = \text{LN}(\underbrace{\boldsymbol{x}_l + \lambda_1 \lambda_2 \phi(\boldsymbol{x}_l \boldsymbol{U}_1)\boldsymbol{U}_2/\alpha}_{\text{denoted as } \boldsymbol{z}_{l+1}})\label{eq:ffn}\end{equation}Calculating the gradients of $\lambda$:
\begin{equation}\begin{aligned} \frac{\partial \mathcal{L}}{\partial \lambda_1} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\partial \boldsymbol{z}_{l+1}}{\partial \lambda_1} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\lambda_2 \phi(\boldsymbol{x}_l \boldsymbol{U}_1)\boldsymbol{U}_2}{\alpha} \\ \frac{\partial \mathcal{L}}{\partial \lambda_2} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\partial \boldsymbol{z}_{l+1}}{\partial \lambda_2} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\lambda_1 \phi(\boldsymbol{x}_l \boldsymbol{U}_1)\boldsymbol{U}_2}{\alpha} \end{aligned}\end{equation}We assert that $\frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}$ and $\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}$ are both $\mathcal{O}(1)$, and since $\boldsymbol{U}_1$ and $\boldsymbol{U}_2$ are both close to orthogonal matrices, $\phi(\boldsymbol{x}_l \boldsymbol{U}_1)\boldsymbol{U}_2$ is also $\mathcal{O}(1)$. Therefore, we ultimately have:
\begin{equation}\frac{\partial \mathcal{L}}{\partial \lambda_1} = \mathcal{O}\left(\frac{\lambda_2}{\alpha}\right),\quad \frac{\partial \mathcal{L}}{\partial \lambda_2} = \mathcal{O}\left(\frac{\lambda_1}{\alpha}\right)\end{equation}Now consider Self-Attention. For magnitude analysis, we consider single-head attention, which takes the form:
\begin{equation}\boldsymbol{x}_{l+1} = \text{LN}(\boldsymbol{x}_l + \sigma(\boldsymbol{x}_l \boldsymbol{W}_q\boldsymbol{W}_k^{\top}\boldsymbol{x}_l^{\top})\boldsymbol{x}_l\boldsymbol{W}_v\boldsymbol{W}_o/\alpha)\end{equation}where $\sigma(\cdot)$ is shorthand for the softmax operation; the Attention scale operation is omitted here. The magnitude decomposition form of the above equation is:
\begin{equation}\boldsymbol{x}_{l+1} = \text{LN}(\underbrace{\boldsymbol{x}_l + \lambda_v\lambda_o \sigma (\lambda_q\lambda_k\boldsymbol{x}_l \boldsymbol{U}_q\boldsymbol{U}_k^{\top}\boldsymbol{x}_l^{\top})\boldsymbol{x}_l\boldsymbol{U}_v\boldsymbol{U}_o/\alpha}_{\text{denoted as } \boldsymbol{z}_{l+1}})\label{eq:sa}\end{equation}Now we can compute the gradient for each $\lambda$. Because $\lambda_q, \lambda_k$ act only through the softmax, their gradients are themselves small and do not significantly affect the overall update size, so it suffices to consider $\lambda_v, \lambda_o$:
\begin{equation}\begin{aligned} \frac{\partial \mathcal{L}}{\partial \lambda_v} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\partial \boldsymbol{z}_{l+1}}{\partial \lambda_v} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\lambda_o \sigma (\lambda_q\lambda_k\boldsymbol{x}_l \boldsymbol{U}_q\boldsymbol{U}_k^{\top}\boldsymbol{x}_l^{\top})\boldsymbol{x}_l\boldsymbol{U}_v\boldsymbol{U}_o}{\alpha} \\ \frac{\partial \mathcal{L}}{\partial \lambda_o} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\partial \boldsymbol{z}_{l+1}}{\partial \lambda_o} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\lambda_v \sigma (\lambda_q\lambda_k\boldsymbol{x}_l \boldsymbol{U}_q\boldsymbol{U}_k^{\top}\boldsymbol{x}_l^{\top})\boldsymbol{x}_l\boldsymbol{U}_v\boldsymbol{U}_o}{\alpha} \end{aligned}\end{equation}Similarly, we assert that $\frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}$ and $\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}$ are both $\mathcal{O}(1)$. Note that the softmax output is a probability distribution that performs a weighted average of the tokens in $\boldsymbol{x}_l$. Generally speaking, the vector before and after averaging will be in the same order of magnitude, so we assume $\sigma (\lambda_q\lambda_k\boldsymbol{x}_l \boldsymbol{U}_q\boldsymbol{U}_k^{\top}\boldsymbol{x}_l^{\top})\boldsymbol{x}_l\boldsymbol{U}_v\boldsymbol{U}_o$ is also $\mathcal{O}(1)$. Therefore, the results are similar to those of the FFN layer:
\begin{equation}\frac{\partial \mathcal{L}}{\partial \lambda_v} = \mathcal{O}\left(\frac{\lambda_o}{\alpha}\right),\quad \frac{\partial \mathcal{L}}{\partial \lambda_o} = \mathcal{O}\left(\frac{\lambda_v}{\alpha}\right)\end{equation}So FFN and Self-Attention lead to similar conclusions. For simplicity, assume all parameters have the same magnitude (at least at initialization), i.e. all $\lambda$ take the same value. The overall conclusion is then:
\begin{equation}\frac{\partial \mathcal{L}}{\partial \lambda} = \mathcal{O}\left(\frac{\lambda}{\alpha}\right)\end{equation}Thus, the gradient has magnitude $\mathcal{O}(\lambda/\alpha)$. On the other hand, an $N$-layer Transformer generally has $N$ Self-Attention layers plus $N$ FFN layers, so strictly speaking it has $2N$ layers. Following the increment explosion analysis at the beginning (with $2N$ in place of $N$), we need to bring the gradient down to $\mathcal{O}(1/\sqrt{2N})$, which the equation above achieves with $\lambda/\alpha=1/\sqrt{2N}$. The original paper's derivation is slightly looser and arrives at $\lambda/\alpha = 1/\sqrt{4N}$, which has the same order of magnitude.
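The FFN result $\frac{\partial \mathcal{L}}{\partial \lambda_1} = \mathcal{O}(\lambda_2/\alpha)$ can be sanity-checked with a toy sketch of our own (not from the paper): LN without gain/bias, ReLU as $\phi$, random orthogonal $\boldsymbol{U}_1, \boldsymbol{U}_2$, and a hypothetical quadratic loss on the block output. Since $\boldsymbol{z}_{l+1}$ depends on $(\lambda_1,\lambda_2,\alpha)$ only through $\lambda_1\lambda_2/\alpha$, holding that product fixed while doubling $\lambda_2/\alpha$ should exactly double $\partial\mathcal{L}/\partial\lambda_1$, and halving $\lambda_2/\alpha$ should halve it.

```python
import numpy as np

def layer_norm(v, eps=1e-12):
    # LayerNorm over the last axis, without learnable gain/bias (simplifying assumption)
    mu = v.mean(axis=-1, keepdims=True)
    var = v.var(axis=-1, keepdims=True)
    return (v - mu) / np.sqrt(var + eps)

def ffn_block_loss(lam1, lam2, alpha, x, U1, U2, target):
    # Post Norm FFN block with W1 = lam1*U1, W2 = lam2*U2 and the DeepNorm 1/alpha on the branch
    z = x + np.maximum(x @ (lam1 * U1), 0.0) @ (lam2 * U2) / alpha
    # Hypothetical downstream loss on the block output
    return 0.5 * np.sum((layer_norm(z) - target) ** 2)

def dloss_dlam1(lam1, lam2, alpha, *data, h=1e-6):
    # Central finite difference in lam1
    return (ffn_block_loss(lam1 + h, lam2, alpha, *data)
            - ffn_block_loss(lam1 - h, lam2, alpha, *data)) / (2 * h)

rng = np.random.default_rng(0)
d = 16
x = layer_norm(rng.standard_normal((4, d)))  # block input, already normalized
U1, _ = np.linalg.qr(rng.standard_normal((d, d)))
U2, _ = np.linalg.qr(rng.standard_normal((d, d)))
target = rng.standard_normal((4, d))
data = (x, U1, U2, target)

# All three calls share the same branch weight lam1*lam2/alpha = 0.125
g = dloss_dlam1(0.5, 0.5, 2.0, *data)
g_double_lam2 = dloss_dlam1(0.25, 1.0, 2.0, *data)  # lam2/alpha doubled
g_double_alpha = dloss_dlam1(1.0, 0.5, 4.0, *data)  # lam2/alpha halved
print(g_double_lam2 / g, g_double_alpha / g)
```

This isolates the $\lambda_2/\alpha$ factor; in a full model the branch weight also changes, which is why the text needs the additional $\mathcal{O}(1)$ assumptions.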
We now have the ratio between $\lambda$ and $\alpha$, but not their individual values. Following the paper, which appeals to symmetry and sets $\lambda=1/\alpha$, we obtain:
\begin{equation}\alpha = (2N)^{1/4},\quad \lambda = (2N)^{-1/4}\label{eq:result}\end{equation}However, a symmetry argument alone is not very convincing; we need to understand what different choices actually lead to. To that end, compare two other solutions:
Alternative 1: $\alpha=1, \lambda=(2N)^{-1/2}$. Here the initial parameters are scaled down to $(2N)^{-1/2}$ times their original value, and the gradient ($\propto\lambda/\alpha$) also shrinks to $(2N)^{-1/2}$ times. Under SGD's $\Delta\boldsymbol{\theta}=-\eta \nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})$, the update per step is therefore also $(2N)^{-1/2}$ times the original, so the update relative to the parameters is unchanged by the adjustment. Consequently, although $\lambda=\mathcal{O}((2N)^{-1/2})$ at the start, it may drift away from this magnitude after only a few training steps.
Alternative 2: $\alpha=(2N)^{1/2}, \lambda=1$. Here the initial parameters are not scaled, but the gradient shrinks to $(2N)^{-1/2}$ times, and by SGD's update rule so does each step. The update relative to the parameters is thus greatly reduced, so learning may be very slow.
Both alternatives have drawbacks, so Equation $\eqref{eq:result}$, which sits between them, is reasonable: it still scales the gradient by $(2N)^{-1/2}$, but because the parameters start at $(2N)^{-1/4}$, the update relative to the parameters is only $(2N)^{-1/4}$ times the original. Learning starts somewhat slower, but not too slow, which effectively acts as an implicit Warmup.
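The three SGD choices can be compared through the scale factors used in the argument: initial parameter size $\lambda$, gradient $\lambda/\alpha$, SGD step (same as the gradient), and step relative to the parameters, which works out to $1/\alpha$. A small sketch under the post's magnitude assumptions; the factors are relative to an unscaled baseline ($\alpha=\lambda=1$) and the function names are our own:

```python
import math

def sgd_scales(alpha, lam):
    # Magnitude factors relative to the baseline alpha = lam = 1, per the SGD analysis
    grad = lam / alpha   # dL/dlam = O(lam / alpha)
    step = grad          # SGD: Delta theta = -eta * grad
    return {"param": lam, "grad": grad, "step": step, "relative_step": step / lam}

def sgd_choices(N):
    two_n = 2 * N
    return {
        "symmetric": sgd_scales(two_n ** 0.25, two_n ** -0.25),
        "alternative 1": sgd_scales(1.0, 1 / math.sqrt(two_n)),
        "alternative 2": sgd_scales(math.sqrt(two_n), 1.0),
    }

for name, s in sgd_choices(50).items():
    print(f"{name:>13}: grad x{s['grad']:.4f}, relative step x{s['relative_step']:.4f}")
```

All three share the gradient factor $(2N)^{-1/2}$; they differ only in the relative step, which is $1$ for Alternative 1, $(2N)^{-1/2}$ for Alternative 2, and $(2N)^{-1/4}$ for the symmetric choice.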
The derivation so far assumed SGD, but NLP models are rarely trained with plain SGD. Most use adaptive-learning-rate optimizers, which fall into two main categories: those that calibrate the learning rate with the second moment (Adam, AdamW, etc.), and those that additionally calibrate it by the parameter norm, such as LAMB and AdaFactor. The original paper essentially says "we derived it for SGD, verified it on Adam, and it also works," but theoretically the two cases are not interchangeable. This section analyzes them separately.
For Adam-type optimizers, the update per step is approximately $\Delta\boldsymbol{\theta}=-\eta\,\text{sign}(\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}))$, so $\Delta\mathcal{L} \approx -\eta\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert_1$, which is proportional to the first power of the gradient rather than its square. Therefore, for the loss change to be independent of depth, the gradient should be scaled to $1/(2N)$ times the original, i.e. $\lambda/\alpha=1/(2N)$. If we again set $\lambda=1/\alpha$, we get:
\begin{equation}\alpha = (2N)^{1/2},\quad \lambda = (2N)^{-1/2}\end{equation}For LAMB-type optimizers, the update per step is approximately $\Delta\boldsymbol{\theta}=-\eta\Vert\boldsymbol{\theta}\Vert\,\text{sign}(\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}))$, so $\Delta\mathcal{L} \approx -\eta\Vert\boldsymbol{\theta}\Vert\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert_1$. Since the parameters scale by $\lambda$ and the gradients by $\lambda/\alpha$, we get $\Delta\mathcal{L}=\mathcal{O}(2N\lambda^2/\alpha)$, and hence need $\lambda^2/\alpha=1/(2N)$. With this type of optimizer, the relative update per step is the same (equal to the learning rate $\eta$) no matter how $\alpha$ and $\lambda$ are chosen, so we can simply set $\alpha=1, \lambda=(2N)^{-1/2}$.
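The sign approximation for Adam is exact on the very first step: with standard bias correction, $\hat{m}_1 = g$ and $\hat{v}_1 = g^2$, so the step is $-\eta\, g/(|g|+\epsilon)$. A minimal sketch of that first step (textbook Adam with default hyperparameters; later steps only approximate a sign update when the gradient direction is stable):

```python
import numpy as np

def adam_first_step(grad, eta=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Standard Adam with bias correction, evaluated at step t = 1 from zero moments
    m = (1 - beta1) * grad
    v = (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1)   # = grad
    v_hat = v / (1 - beta2)   # = grad ** 2
    return -eta * m_hat / (np.sqrt(v_hat) + eps)

g = np.array([0.3, -2.0, 1e-3, -5e-4])
step = adam_first_step(g)
print(step, -1e-3 * np.sign(g))  # each entry has magnitude ≈ eta regardless of |g|
```

Because every coordinate moves by about $\eta$, the loss change is governed by $\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}\Vert_1$ rather than $\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}\Vert^2$.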
The results are summarized below:
| Optimizer | Update $\Delta\boldsymbol{\theta}$ (approx.) | Loss change $\Delta\mathcal{L}$ (approx.) | $\alpha$ | $\lambda$ |
|---|---|---|---|---|
| SGD | $-\eta \nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})$ | $-\eta\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert^2$ | $(2N)^{1/4}$ | $(2N)^{-1/4}$ |
| Adam | $-\eta\,\text{sign}(\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}))$ | $-\eta\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert_1$ | $(2N)^{1/2}$ | $(2N)^{-1/2}$ |
| LAMB | $-\eta\Vert\boldsymbol{\theta}\Vert\,\text{sign}(\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}))$ | $-\eta\Vert\boldsymbol{\theta}\Vert\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert_1$ | $1$ | $(2N)^{-1/2}$ |
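The table can be reproduced from the three constraints derived above. A small sketch (the function name is ours; $N$ counts Transformer blocks, so there are $2N$ sublayers):

```python
import math

def deepnet_gains(optimizer, N):
    """(alpha, lam) for a Post Norm stack with N blocks (2N sublayers), per the table."""
    two_n = 2 * N
    if optimizer == "sgd":    # lam/alpha = (2N)^{-1/2} with lam = 1/alpha
        alpha = two_n ** 0.25
        return alpha, 1 / alpha
    if optimizer == "adam":   # lam/alpha = (2N)^{-1} with lam = 1/alpha
        alpha = math.sqrt(two_n)
        return alpha, 1 / alpha
    if optimizer == "lamb":   # lam^2/alpha = (2N)^{-1} with alpha = 1
        return 1.0, 1 / math.sqrt(two_n)
    raise ValueError(f"unknown optimizer: {optimizer}")

for opt in ("sgd", "adam", "lamb"):
    alpha, lam = deepnet_gains(opt, 100)
    print(f"{opt:>4}: alpha={alpha:.4f}, lambda={lam:.4f}")
```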
The derivations in the previous two sections relied on the assertion that "$\frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}$ and $\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}$ are both $\mathcal{O}(1)$." Does it actually hold? Let us check it in hindsight.
It is actually quite simple. After the adjustments above, in both the FFN layer $\eqref{eq:ffn}$ and the Self-Attention layer $\eqref{eq:sa}$, each residual branch is initially weighted by $\lambda^2/\alpha$ relative to the original. For every optimizer in the table, $\lambda^2/\alpha$ is small: $(2N)^{-3/4}$ for SGD, $(2N)^{-3/2}$ for Adam, and $(2N)^{-1}$ for LAMB. Hence, at initialization the whole model is close to an identity function, so $\frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}$ and $\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}$ are naturally both $\mathcal{O}(1)$, and the conclusion is consistent with the assertion.
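The initial residual-branch weight $\lambda_1\lambda_2/\alpha=\lambda^2/\alpha$ can be tabulated directly from the summary table. A small sketch (function name is ours) confirming that it is well below 1 and shrinks with depth for all three optimizers:

```python
import math

def initial_branch_weight(optimizer, N):
    # lam^2 / alpha at initialization, using the (alpha, lam) pairs from the summary table
    two_n = 2 * N
    alpha, lam = {
        "sgd": (two_n ** 0.25, two_n ** -0.25),
        "adam": (math.sqrt(two_n), 1 / math.sqrt(two_n)),
        "lamb": (1.0, 1 / math.sqrt(two_n)),
    }[optimizer]
    return lam ** 2 / alpha

for N in (6, 50, 500):
    weights = {opt: initial_branch_weight(opt, N) for opt in ("sgd", "adam", "lamb")}
    print(N, {k: f"{v:.2e}" for k, v in weights.items()})
```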
Some readers may also wonder whether the same analysis applies to the Pre Norm structure. It does, and the conclusions are essentially the same; the only difference is that, because the Norm is placed before the residual branch, the $\alpha$ parameter is unnecessary. In other words, take all the Post Norm results above, set $\alpha=1$, and recompute the corresponding $\lambda$.
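Concretely, setting $\alpha=1$ in each constraint gives the Pre Norm gains below. This is our own arithmetic applying the stated recipe to the three constraints, not values quoted from the paper:

```python
import math

def pre_norm_lambda(optimizer, N):
    # Pre Norm: alpha fixed to 1, then re-solve each Post Norm constraint for lam
    two_n = 2 * N
    if optimizer == "sgd":    # lam / 1 = (2N)^{-1/2}
        return 1 / math.sqrt(two_n)
    if optimizer == "adam":   # lam / 1 = (2N)^{-1}
        return 1 / two_n
    if optimizer == "lamb":   # lam^2 / 1 = (2N)^{-1}, same as the Post Norm row
        return 1 / math.sqrt(two_n)
    raise ValueError(f"unknown optimizer: {optimizer}")

for opt in ("sgd", "adam", "lamb"):
    print(opt, pre_norm_lambda(opt, 100))
```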
Finally, readers may ask whether all this effort to make models deeper really matters. It does: the original paper reports a striking result in which a 200-layer "deep and narrow" model (3.2 billion parameters) beats the previous 48-layer "shallow and wide" SOTA model (12 billion parameters):
Figure: The "deep and narrow" model outperforms the "shallow and wide" model.
This article analyzed the bottleneck in making Transformers "deep" and presented corresponding solutions. Its main ideas come from Microsoft's recent DeepNet paper, whose analysis it simplifies and refines.