Why Are Residuals Needed? A Perspective from DeepNet

By 苏剑林 | March 19, 2022

In "What Are the Difficulties in Training a 1000-Layer Transformer?", we introduced the DeepNet technology proposed by Microsoft, which enables training a 1000-layer Transformer. Readers generally have two reactions to DeepNet: one is to be amazed and give it a thumbs up, while the other is to feel it’s just "old wine in a new bottle" and uninteresting. The latter reaction often stems from the fact that the two improvements proposed by DeepNet—increasing the weight of the identity path and reducing the residual branch initialization—are quite common, and similar conclusions have appeared in other works. Therefore, it is hard to find anything particularly novel in them.

Admittedly, from the perspective of conclusions alone, DeepNet is indeed not that interesting. However, I believe that the process of DeepNet is far more important than the conclusions. Its interesting aspect lies in providing a concise and effective gradient magnitude analysis approach that can be used to analyze many related issues. For example, the question we will discuss in this article—"Why are residuals needed?"—can be given a more fundamental answer through this lens.

Incremental Explosion

Why are residuals needed? The answer is that with residuals, deep models (which might be hundreds, thousands, or even tens of thousands of layers) become easier to train. So the question becomes: why is it difficult to train deep models without residuals?

Many readers might answer "vanishing or exploding gradients." These are indeed two very important problems. However, by using specific initialization methods and Normalization techniques, we have already managed to make the gradients of ordinary feed-forward neural networks quite stable. Yet even then, training deep feed-forward neural networks remains difficult. This suggests that the reason is not just vanishing/exploding gradients but another problem, which we discussed in "What Are the Difficulties in Training a 1000-Layer Transformer?": "Incremental Explosion."

Understanding incremental explosion is not difficult. Suppose the loss function is $\mathcal{L}(\boldsymbol{\theta})$, where $\boldsymbol{\theta}$ represents its parameters. When the parameters change from $\boldsymbol{\theta}$ to $\boldsymbol{\theta}+\Delta\boldsymbol{\theta}$:

\begin{equation}\Delta\mathcal{L} = \mathcal{L}(\boldsymbol{\theta}+\Delta\boldsymbol{\theta}) - \mathcal{L}(\boldsymbol{\theta}) \approx \langle\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}),\Delta\boldsymbol{\theta}\rangle\end{equation}

For SGD, we have $\Delta\boldsymbol{\theta}=-\eta \nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})$, so $\Delta\mathcal{L} \approx -\eta\|\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\|^2$. Let the model have $N$ layers with an average of $K$ parameters per layer. If the gradient vanishing/explosion problem has been solved, we can assume each parameter's gradient is of magnitude $\mathcal{O}(1)$, which gives $\Delta\mathcal{L}=\mathcal{O}(\eta NK)$. In other words, the change in the loss caused by each update step is proportional to the model depth $N$ (width is not discussed here): the deeper the model, the larger each update's effect. This means that in the initial stage, a deeper model is more likely to fall into a poor local optimum, causing training to stagnate or even collapse. This is the "incremental explosion" problem.
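As a rough numerical illustration (my own sketch, using the same assumptions: one $\mathcal{O}(1)$ gradient component per parameter, with hypothetical values for $\eta$ and $K$), the first-order SGD estimate of $\Delta\mathcal{L}$ indeed grows linearly with the depth $N$:

```python
import numpy as np

# Sketch of the "incremental explosion" estimate: each parameter's gradient
# component is assumed O(1) and drawn from a standard normal; eta and K are
# hypothetical values chosen only for illustration.
rng = np.random.default_rng(0)
eta, K = 1e-3, 1000                       # learning rate, parameters per layer

for N in (10, 100, 1000):                 # model depth
    g = rng.normal(size=N * K)            # surrogate gradient, one entry per parameter
    delta_L = -eta * np.sum(g**2)         # first-order SGD estimate: dL ~ -eta * ||grad||^2
    print(N, delta_L)                     # |dL| grows roughly linearly with N
```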

Treating the Symptoms

Simply put, "incremental explosion" means that as the number of layers increases, tiny changes in parameters lead to large changes in the loss function. This is particularly unfavorable for model training, especially in the initial stage. A direct coping technique for this is Warmup, where an extremely small learning rate is used in the initial stage and then gradually increased to avoid learning too fast at the beginning. Once the model has safely passed this initial "danger period," it can be trained normally.

However, although Warmup helps to some extent, it only "treats the symptoms, not the root cause." The fact that "tiny changes in parameters lead to large changes in the loss function" means the model itself is highly sensitive to parameter perturbations; in more professional terms, its loss landscape is extremely non-smooth, which is not a property a good model should have. Therefore, we should solve the problem by modifying the model itself, rather than through the "superficial" remedy of lowering the learning rate.

By "modifying the model," we mean adjusting the model structure or initialization method to naturally counteract the effect of the layer count $N$ on the update amount. Based on the previous results $\Delta\mathcal{L} \approx -\eta\|\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\|^2$ and $\Delta\mathcal{L}=\mathcal{O}(\eta NK)$, to counteract the effect of the layer count, the gradient $\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})$ must be brought to a magnitude of $\mathcal{O}(1/\sqrt{N})$. In other words, the gradient of each parameter must decrease as the number of layers increases.

Stable Propagation

If we purely wanted to shrink the gradient, that would be easy: just make the initialization variance small enough. In reality, however, while shrinking the gradient we must also keep forward propagation stable, because stable forward propagation encodes our prior knowledge of the task and gives the model a better starting point. In "A Brief Talk on Transformer Initialization, Parameterization, and Normalization", we discussed that the stability of forward propagation can be measured by the second moment. For a simple linear layer:

\begin{equation}\boldsymbol{y} = \boldsymbol{x}\boldsymbol{W}, \quad \boldsymbol{x}\in\mathbb{R}^n, \boldsymbol{W}\in\mathbb{R}^{n\times m}\end{equation}

We already know that to keep the second moment of $\boldsymbol{y}$ equal to that of $\boldsymbol{x}$, we need an initialization method with a mean of zero and a variance of $1/n$. If activation functions are considered, a constant scale is added; for example, for the $\text{relu}$ activation function, the variance is changed to $2/n$. For backward propagation, we have:

\begin{equation}\frac{\partial\mathcal{L}}{\partial \boldsymbol{x}} = \frac{\partial\mathcal{L}}{\partial \boldsymbol{y}}\frac{\partial\boldsymbol{y}}{\partial \boldsymbol{x}} = \frac{\partial\mathcal{L}}{\partial \boldsymbol{y}}\boldsymbol{W}^{\top}\end{equation}

As can be seen, backward propagation is just the opposite. To stabilize the second moment of backward propagation, an initialization method with a mean of zero and a variance of $1/m$ is needed. Xavier initialization takes the average of the two, $2/(n+m)$. More details can be found in "Thinking on Dimension Averaging Strategies for Non-Square Matrices in Initialization Methods".
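Both directions can be checked numerically. The sketch below uses made-up sizes $n$, $m$ and a Gaussian stand-in for $\partial\mathcal{L}/\partial\boldsymbol{y}$: variance $1/n$ preserves the forward second moment but multiplies the backward second moment by $m/n$ (the point made in the next paragraph), while variance $1/m$ stabilizes the backward pass instead.

```python
import numpy as np

# Sketch: second moments of y = x W (forward) and dL/dx = (dL/dy) W^T (backward)
# under zero-mean initialization. Sizes n, m and the batch are made up.
rng = np.random.default_rng(0)
n, m, batch = 256, 1024, 4096

x = rng.normal(size=(batch, n))
grad_y = rng.normal(size=(batch, m))        # stand-in for dL/dy

W_f = rng.normal(scale=np.sqrt(1.0 / n), size=(n, m))   # forward-stabilizing init
W_b = rng.normal(scale=np.sqrt(1.0 / m), size=(n, m))   # backward-stabilizing init

print(np.mean((x @ W_f) ** 2))                  # ~ 1: forward second moment preserved
print(np.mean((grad_y @ W_f.T) ** 2), m / n)    # ~ m/n = 4: backward inflated by m/n
print(np.mean((grad_y @ W_b.T) ** 2))           # ~ 1: variance 1/m stabilizes backward
```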

In other words, if we stabilize forward propagation with initialization variance $1/n$, the second moment in backward propagation gets multiplied by $m/n$ at each layer. Since $m$ and $n$ are pre-chosen hyperparameters with no necessary connection to the number of layers, we cannot use them to meet the requirement of scaling the gradient down to $\mathcal{O}(1/\sqrt{N})$. This means that for a deep feed-forward neural network without residuals:

\begin{equation}\phi_l(\phi_{l-1}(\phi_{l-2}(\cdots\phi_1(\boldsymbol{x}\boldsymbol{W}_1 + \boldsymbol{b}_1)\cdots)\boldsymbol{W}_{l-1} + \boldsymbol{b}_{l-1})\boldsymbol{W}_l + \boldsymbol{b}_l)\end{equation}

Once its forward propagation is stable, backward propagation is also fixed, making it impossible to scale the gradient relative to the layer count. Therefore, at most, we can solve the gradient vanishing and explosion problems of deep feed-forward neural networks, but we cannot solve the "incremental explosion" problem mentioned at the beginning of this article. Thus, deep feed-forward neural networks are inherently difficult to train.
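To make this concrete, the toy experiment below (my own construction: a deep purely linear network with square layers and the forward-stabilizing variance $1/n$) shows that the per-parameter gradient magnitude stays roughly constant as the depth $N$ grows, so the total squared gradient norm, and hence the update size, still scales with $N$:

```python
import numpy as np

# Sketch: deep linear network y = x W_1 W_2 ... W_N with variance-1/n init.
# Forward and backward passes stay stable, but the mean squared gradient per
# parameter does not shrink with depth, so ||grad||^2 grows roughly like N.
rng = np.random.default_rng(0)
n, batch = 256, 128

for N in (4, 16, 64):
    Ws = [rng.normal(scale=np.sqrt(1.0 / n), size=(n, n)) for _ in range(N)]
    xs = [rng.normal(size=(batch, n))]
    for W in Ws:                                   # forward pass
        xs.append(xs[-1] @ W)
    delta = xs[-1] / batch                         # dL/dy for L = sum(y**2) / (2 * batch)
    sq_sum = 0.0
    for W, x in zip(reversed(Ws), reversed(xs[:-1])):
        grad_W = x.T @ delta                       # dL/dW of this layer
        sq_sum += np.sum(grad_W ** 2)
        delta = delta @ W.T                        # backpropagate to the previous layer
    print(N, sq_sum / (N * n * n), sq_sum)         # per-parameter ~ constant, total ~ N
```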

The Emergence of Residuals

This is where residuals come onto the stage! Without loss of generality, assuming input and output dimensions are equal, we consider:

\begin{equation}\boldsymbol{y} = \boldsymbol{x} + \varepsilon \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\theta})\end{equation}

Clearly, as long as $\varepsilon$ is small enough, forward propagation is inevitably stable; and

\begin{equation}\frac{\partial \boldsymbol{y}}{\partial \boldsymbol{x}} = \boldsymbol{I} + \varepsilon\frac{\partial \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\theta})}{\partial \boldsymbol{x}}\label{eq:bp}\end{equation}

So it can also be seen that as long as $\varepsilon$ is small enough, backward propagation is also stable. As for the parameter gradient:

\begin{equation}\frac{\partial \mathcal{L}}{\partial \boldsymbol{\theta}} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{y}}\frac{\partial \boldsymbol{y}}{\partial \boldsymbol{\theta}} = \varepsilon\frac{\partial \mathcal{L}}{\partial \boldsymbol{y}}\frac{\partial \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\theta})}{\partial \boldsymbol{\theta}}\end{equation}

This shows that we can control $\varepsilon$ to achieve layer-count-related gradient scaling! For example, if we want to scale the gradient to $1/\sqrt{N}$, we simply set $\varepsilon=1/\sqrt{N}$.
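A small numerical confirmation (a sketch in which $\boldsymbol{f}$ is simplified to a plain linear map and $\partial\mathcal{L}/\partial\boldsymbol{y}$ is a Gaussian stand-in): the squared parameter gradient of the residual branch scales as $\varepsilon^2$, so $\varepsilon=1/\sqrt{N}$ indeed gives per-parameter gradients of magnitude $\mathcal{O}(1/\sqrt{N})$.

```python
import numpy as np

# Sketch: y = x + eps * f(x; W) with f(x; W) = x W as a simplification.
# Then dL/dW = eps * x^T (dL/dy), so its squared norm scales with eps**2.
rng = np.random.default_rng(0)
n, batch = 256, 128
x = rng.normal(size=(batch, n))
grad_y = rng.normal(size=(batch, n))            # stand-in for dL/dy

for eps in (1.0, 0.1, 0.01):
    grad_W = eps * (x.T @ grad_y)               # gradient of the residual branch
    print(eps, np.sum(grad_W ** 2))             # drops by ~100x per step: ~ eps**2
```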

With this result, we can answer why residuals are used:

Because the residual structure can simultaneously stabilize forward and backward propagation and scale down the parameter gradients, thereby solving incremental explosion and helping us train deeper models.

Small Enough

We just said "$\varepsilon$ is small enough" twice, but how small counts as small enough? Is $\varepsilon=1/\sqrt{N}$ enough?

Suppose it is a 1D model, then $\frac{\partial y}{\partial x} = 1 + \varepsilon\frac{\partial f}{\partial x}$. Generally, we assume $\frac{\partial f}{\partial x}$ is $\mathcal{O}(1)$, so we can approximately use $\frac{\partial y}{\partial x}=1+\varepsilon$ for magnitude estimation. After propagating through $N$ layers, the "expansion coefficient" is approximately $(1+\varepsilon)^N$. And we know:

\begin{equation}\left(1 + \frac{1}{N}\right)^N < \lim_{N\to\infty} \left(1 + \frac{1}{N}\right)^N = e\end{equation}

That is to say, for a 1D model, to prevent backward propagation from exploding as the number of layers grows, $\varepsilon$ must be as small as $\mathcal{O}(1/N)$. So $\varepsilon=1/\sqrt{N}$ is indeed not quite enough.
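A quick numerical look at the two candidate scalings:

```python
# Sketch: 1D expansion coefficient (1 + eps)^N for eps = 1/sqrt(N) and eps = 1/N.
for N in (10, 100, 1000):
    print(N, (1 + N ** -0.5) ** N, (1 + 1.0 / N) ** N)
    # eps = 1/sqrt(N): grows like e^sqrt(N)     eps = 1/N: stays below e ~ 2.718
```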

However, for high-dimensional models, the situation changes. Let's multiply both sides of Equation $\eqref{eq:bp}$ by an arbitrary vector $\boldsymbol{v}$:

\begin{equation}\boldsymbol{v}\frac{\partial \boldsymbol{y}}{\partial \boldsymbol{x}} = \boldsymbol{v} + \varepsilon\boldsymbol{v}\frac{\partial \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\theta})}{\partial \boldsymbol{x}}\end{equation}

Note that in the initial stage, $\frac{\partial \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\theta})}{\partial \boldsymbol{x}}$ also acts like a zero-mean randomly initialized matrix. In "Understanding Parameter Initialization Strategies from a Geometric Perspective", we discussed that such a matrix is close to (a multiple of) an orthogonal matrix. Therefore, in the initial stage, $\boldsymbol{v}$ and $\varepsilon\boldsymbol{v}\frac{\partial \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\theta})}{\partial \boldsymbol{x}}$ are nearly orthogonal. Thus:

\begin{equation}\left\Vert\boldsymbol{v}\frac{\partial \boldsymbol{y}}{\partial \boldsymbol{x}}\right\Vert^2 = \mathcal{O}\big((1 + \varepsilon^2)\Vert\boldsymbol{v}\Vert^2\big)\end{equation}

Simply put, in high-dimensional cases, the expansion coefficient of each layer is closer to $1+\varepsilon^2$ rather than $1+\varepsilon$. According to the discussion of the 1D case, we only need $\varepsilon^2=\mathcal{O}(1/N)$, so $\varepsilon=1/\sqrt{N}$ is basically sufficient.
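As a final sketch, the snippet below stacks $N$ random residual Jacobians $\boldsymbol{I} + \varepsilon\boldsymbol{A}$ with $\varepsilon = 1/\sqrt{N}$, using zero-mean Gaussian matrices $\boldsymbol{A}$ as stand-ins for $\frac{\partial \boldsymbol{f}}{\partial \boldsymbol{x}}$ at initialization; the squared norm of the propagated vector indeed stays close to $(1+\varepsilon^2)^N < e$:

```python
import numpy as np

# Sketch: repeatedly apply v <- v (I + eps * A), with A a zero-mean random
# matrix (variance 1/n) standing in for df/dx at initialization.
rng = np.random.default_rng(0)
n, N = 256, 1000
eps = 1.0 / np.sqrt(N)

v = rng.normal(size=n)
v0_sq = np.sum(v ** 2)
for _ in range(N):
    A = rng.normal(scale=np.sqrt(1.0 / n), size=(n, n))
    v = v + eps * (v @ A)
print(np.sum(v ** 2) / v0_sq, np.e)     # ~ (1 + eps^2)^N <= e: stays bounded
```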

Summary

This article discussed the question "Why are residuals needed?". Inspired by DeepNet, the conclusion is that residuals can simultaneously stabilize forward and backward propagation and solve incremental explosion, making deep models easier to train. In contrast, ordinary feed-forward neural networks without residuals cannot meet all three requirements (stable forward propagation, stable backward propagation, and controlled update increments) at once, which is why they become difficult to train as they get deeper.