A Brief Discussion on the Initialization, Parameterization, and Normalization of the Transformer

By 苏剑林 | August 17, 2021

A few days ago, while training a new Transformer model, I found that it wouldn't converge no matter what. After some debugging, I realized that I had forgotten to divide $\boldsymbol{Q}\boldsymbol{K}^{\top}$ by $\sqrt{d}$ during Self-Attention. This prompted me to revisit why dividing by $\sqrt{d}$ is so important. Of course, Google's T5 does not divide by $\sqrt{d}$, yet it still converges normally; that is because it adjusted its initialization strategy. Thus, this matter is closely related to initialization.

Taking this opportunity, this article will summarize model initialization, parameterization, and normalization, with the discussion primarily centered around the Transformer.

Sampling Distributions

Initialization is naturally a process of random sampling, so we start by introducing commonly used sampling distributions. Generally, we sample from a random distribution with a specified mean and variance. Three random distributions are commonly used: Normal, Uniform, and Truncated Normal.

Clearly, Normal and Uniform are very common. The Normal distribution is usually denoted as $\mathcal{N}(\mu, \sigma^2)$, with mean $\mu$ and variance $\sigma^2$. The Uniform distribution on the interval $[a, b]$ is denoted as $U[a, b]$, with a mean of $\frac{a+b}{2}$ and a variance of $\frac{(b-a)^2}{12}$. Therefore, if you specify a mean $\mu$ and variance $\sigma^2$, the corresponding uniform distribution is $U[\mu - \sqrt{3}\sigma, \mu + \sqrt{3}\sigma]$.

Generally, samples from a Normal distribution are more diverse, but the distribution is theoretically unbounded, and samples with excessively large absolute values may hinder optimization. Conversely, the Uniform distribution is bounded, but its samples are typically less diverse. Consequently, the "Truncated Normal Distribution" emerged, combining the advantages of both. A truncated normal distribution specifies a mean $\mu$ and variance $\sigma^2$ as well as an interval $[a, b]$. It samples from $\mathcal{N}(\mu, \sigma^2)$; if the result falls within $[a, b]$, it is kept; otherwise, it resamples until the result does fall within $[a, b]$.

In TensorFlow's built-in tf.random.truncated_normal, the interval is hard-coded to $a = \mu - 2\sigma, b = \mu + 2\sigma$. A direct computation shows that the mean of samples from this function is still $\mu$, but their actual variance is $\gamma\sigma^2$, where:

\begin{equation}\gamma = \frac{\int_{-2}^2 e^{-x^2/2}x^2 dx}{\int_{-2}^2 e^{-x^2/2} dx} = 0.7737413\dots\end{equation}

If you want to obtain sampling results with a variance of $\sigma^2$, the standard deviation passed to the function should be $\frac{\sigma}{\sqrt{\gamma}} = 1.1368472\dots\sigma$.
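As a quick numerical check of these constants (an illustrative sketch of my own, not part of the original post, assuming NumPy and SciPy are available):

```python
import numpy as np
from scipy.integrate import quad

# gamma = second moment of a standard normal truncated to [-2, 2]
num, _ = quad(lambda x: np.exp(-x**2 / 2) * x**2, -2, 2)
den, _ = quad(lambda x: np.exp(-x**2 / 2), -2, 2)
gamma = num / den
print(gamma)               # ~0.7737413
print(1 / np.sqrt(gamma))  # ~1.1368472, the factor to scale the std by

# empirical check: mimic tf.random.truncated_normal's +/- 2 std cut-off
rng = np.random.default_rng(0)
sigma, std_in = 1.0, 1.0 / np.sqrt(gamma)   # pass sigma / sqrt(gamma) to the sampler
z = rng.normal(0.0, std_in, size=1_000_000)
z = z[np.abs(z) <= 2 * std_in]              # rejection step of the truncation
print(z.var())                              # close to sigma**2 = 1
```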

Stable Second Moment

In a previous article, "Understanding Model Parameter Initialization Strategies from a Geometric Perspective", I analyzed existing initialization methods from a geometric angle. The general idea is that specific random matrices approximate orthogonal matrices, thereby ensuring model stability during the initial stage. However, while the geometric perspective is intuitive, it is often difficult to generalize. Therefore, we will now understand initialization methods from an algebraic perspective.

In typical tutorials, the derivation of initialization methods aims to keep the input and output at the same mean and variance. Usually, it is assumed the input is a random vector with mean 0 and variance 1, and the method tries to keep the output mean at 0 and variance at 1. However, I believe this is not strictly necessary. For certain non-negative activation functions, it is impossible to achieve a mean of 0. In fact, we only need an indicator to measure whether a signal "vanishes" or "explodes". Zero mean and unit variance are not essential. Here, we use the second (raw) moment instead. It can be viewed as a variant of the L2 norm, serving a purpose similar to variance in gauging signal "vanishing" or "explosion," but it is generally more versatile and simpler.

Now, consider a fully connected layer without an activation function (let the number of input nodes be $m$ and output nodes be $n$):

\begin{equation} y_j = b_j + \sum_i x_i w_{i,j}\end{equation}

For simplicity, we usually initialize the bias term $b_j$ to zero and choose the distribution of $w_{i,j}$ to have mean $\mathbb{E}[w_{i,j}] = 0$. This is not mandatory; it simply keeps the results clean. We now compute the second moment:

\begin{equation}\begin{aligned} \mathbb{E}[y_j^2] =&\, \mathbb{E}\left[\left(\sum_i x_i w_{i,j}\right)^2\right] = \mathbb{E}\left[\left(\sum_{i_1} x_{i_1} w_{i_1,j}\right)\left(\sum_{i_2} x_{i_2} w_{i_2,j}\right)\right] \\ =&\, \mathbb{E}\left[\sum_{i_1, i_2} (x_{i_1}x_{i_2}) (w_{i_1,j} w_{i_2,j})\right] = \sum_{i_1, i_2} \mathbb{E}[x_{i_1}x_{i_2}] \mathbb{E}[w_{i_1,j} w_{i_2,j}] \end{aligned}\end{equation}

Note that $w_{i_1,j}, w_{i_2,j}$ are independent and identically distributed, so when $i_1 \neq i_2$, $\mathbb{E}[w_{i_1,j}w_{i_2,j}] = \mathbb{E}[w_{i_1,j}]\mathbb{E}[w_{i_2,j}] = 0$. Thus, we only need to consider the case $i_1 = i_2 = i$. Assuming the second moment of the input is 1, then:

\begin{equation} \mathbb{E}[y_j^2] = \sum_{i} \mathbb{E}[x_i^2] \mathbb{E}[w_{i,j}^2] = m\mathbb{E}[w_{i,j}^2] \label{eq:m} \end{equation}

So to make $\mathbb{E}[y_j^2]$ equal to 1, we need $\mathbb{E}[w_{i,j}^2] = 1/m$. Combined with the zero-mean assumption, this gives the initialization strategy for $w_{i,j}$: sample it independently and repeatedly from a random distribution with mean 0 and variance $1/m$. This is LeCun initialization. Note that nowhere in this derivation did we make any assumption about the mean of the input, so it does not matter even if the input is entirely non-negative.
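A minimal numerical sketch of this (my own illustration with arbitrary layer sizes): with variance-$1/m$ weights, the output second moment stays near 1 even when the input is not zero-mean.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, batch = 1024, 768, 4096

# inputs with second moment 1 but non-zero mean: E[x^2] = 0.5^2 + 0.75 = 1
x = rng.normal(0.5, np.sqrt(0.75), size=(batch, m))
# LeCun initialization: mean 0, variance 1/m
w = rng.normal(0.0, np.sqrt(1.0 / m), size=(m, n))
y = x @ w

print(np.mean(x**2))  # ~1.0
print(np.mean(y**2))  # ~1.0, the second moment is preserved
```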

Activation Functions

Of course, this holds only in the absence of an activation function. Once an activation function is involved, each case needs its own analysis. For example, if the activation function is $\text{relu}$, we can assume roughly half of the $y_j$ are zeroed out, so the estimated second moment is half of Eq. $\eqref{eq:m}$:

\begin{equation} \mathbb{E}[y_j^2] = \frac{m}{2}\mathbb{E}[w_{i,j}^2] \end{equation}

Thus, the initialization variance to keep the second moment stable is $2/m$, which is He initialization, specifically designed for $\text{relu}$ networks.
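The same kind of check for $\text{relu}$ (again just an illustration, not part of the original post): LeCun-style $1/m$ weights leave a second moment of about 0.5 after $\text{relu}$, while the $2/m$ variance of He initialization restores it to about 1.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, batch = 1024, 768, 4096
x = rng.normal(0.0, 1.0, size=(batch, m))
relu = lambda t: np.maximum(t, 0.0)

w_lecun = rng.normal(0.0, np.sqrt(1.0 / m), size=(m, n))
w_he = rng.normal(0.0, np.sqrt(2.0 / m), size=(m, n))

print(np.mean(relu(x @ w_lecun)**2))  # ~0.5: relu zeroes out half the second moment
print(np.mean(relu(x @ w_he)**2))     # ~1.0: He initialization compensates
```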

However, if the activation function is $\text{elu}$, $\text{gelu}$, etc., the analysis isn't so simple. And if the activation function is $\tanh$ or $\text{sigmoid}$, no initialization exists that can keep the second moment at 1. In such cases, if we still want to keep the second moment stable, a possible solution is "micro-tuning the definition of the activation function."

Using $\text{sigmoid}$ as an example, assume the input has mean 0 and variance 1, and we use the "mean 0, variance 1/m" initialization. The output before activation also has mean 0 and variance 1. We can then use the standard normal distribution to estimate the second moment after $\text{sigmoid}$:

\begin{equation}\int_{-\infty}^{\infty} \frac{e^{-x^2/2}}{\sqrt{2\pi}}\text{sigmoid}(x)^2 dx = 0.2933790\dots\end{equation}

In other words, under this assumption, the second moment after activation is roughly 0.293379. Thus, if we want the output second moment to stay roughly constant, we can divide the output by $\sqrt{0.293379}$; that is, the activation function changes from $\text{sigmoid}(x)$ to $\frac{\text{sigmoid}(x)}{\sqrt{0.293379}}$. This is the "micro-tuned" activation function. If you feel it necessary, you can also subtract a constant to make the output mean zero.
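The constant can be reproduced with a couple of lines of numerical integration (an illustrative sketch, assuming SciPy):

```python
import numpy as np
from scipy.integrate import quad
from scipy.special import expit  # numerically stable sigmoid

density = lambda x: np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)  # standard normal pdf

# second moment of sigmoid(X) for X ~ N(0, 1)
m2, _ = quad(lambda x: density(x) * expit(x)**2, -np.inf, np.inf)
print(m2, np.sqrt(m2))  # ~0.293379 and ~0.5416; divide sigmoid(x) by the latter
```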

Recall that in 2017, a "sensational" paper "Self-Normalizing Neural Networks" proposed the $\text{selu}$ activation function. It is actually a "micro-tuned" $\text{elu}$ function based on the same logic. Its form is as follows:

\begin{equation}\text{selu}(x)=\lambda\left\{\begin{aligned} &x,& (x > 0) \\ &\alpha e^{x}-\alpha, &(x\leq 0) \end{aligned}\right.\end{equation}

where $\lambda=1.0507\dots, \alpha=1.6732\dots$. It was "sensational" first because it claimed to achieve automatic network normalization without using Batch Normalization, and second because the dozens of pages of mathematical derivation it included were quite "intimidating." However, from the perspective above, it's just introducing two parameters to micro-tune the $\text{elu}$ function so that when a standard normal distribution is the input, the output activation values have a mean of 0 and a variance of 1. Thus, at best, it can be considered a good initialization, which explains why it was only sensational for "a while." We can similarly solve for its two parameters numerically using Mathematica:

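A rough equivalent of that computation, sketched here in Python with SciPy rather than Mathematica (the setup and starting point below are my own assumptions, not the original code):

```python
import numpy as np
from scipy.integrate import quad
from scipy.optimize import fsolve

def selu(x, lam, alpha):
    # lam * x for x > 0, lam * alpha * (exp(x) - 1) for x <= 0
    return lam * x if x > 0 else lam * alpha * (np.exp(x) - 1.0)

density = lambda x: np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)  # standard normal pdf

def conditions(params):
    lam, alpha = params
    m1, _ = quad(lambda x: density(x) * selu(x, lam, alpha), -np.inf, np.inf)
    m2, _ = quad(lambda x: density(x) * selu(x, lam, alpha)**2, -np.inf, np.inf)
    return [m1, m2 - 1.0]  # want mean 0 and variance 1 under standard normal input

lam, alpha = fsolve(conditions, x0=[1.0, 1.5])
print(lam, alpha)  # ~1.0507 and ~1.6732
```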

Direct Normalization

Compared with such simple "micro-tuning," a more direct treatment is the family of Normalization techniques, such as Batch Normalization, Instance Normalization, and Layer Normalization. Rather than relying on integral estimates made in advance, these methods normalize the output using the mean and variance computed directly from the current data, hence the name "normalization." The three methods are broadly similar: apart from Batch Normalization maintaining moving averages of the mean and variance for use at prediction time, they differ only in the dimensions over which the statistics are computed. For instance, Layer Normalization, commonly used in NLP and in Transformer models in particular, is:

\begin{equation}y_{i,j,k} = \frac{x_{i,j,k} - \mu_{i,j}}{\sqrt{\sigma_{i,j}^2 + \epsilon}} \times \gamma_k + \beta_k, \quad \mu_{i,j} = \frac{1}{d}\sum_{k=1}^d x_{i,j,k}, \quad \sigma_{i,j}^2 = \frac{1}{d}\sum_{k=1}^d (x_{i,j,k}-\mu_{i,j})^2\end{equation}
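For a tensor of shape (batch, length, d), this is only a few lines in code (a plain NumPy sketch of the formula above; the $\epsilon$ value is an arbitrary choice):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    # x: (batch, length, d); statistics are taken over the last axis only
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * gamma + beta

x = np.random.randn(2, 5, 8)
y = layer_norm(x, gamma=np.ones(8), beta=np.zeros(8))
print(y.mean(-1))  # ~0 for every position
print(y.var(-1))   # ~1 for every position
```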

I won't repeat the descriptions for the others. For the principle behind such methods, interested readers can refer to my previous post "What Does BN Actually Do? A 'Closed-Door' Analysis."

An interesting phenomenon worth noting here: Normalization generally consists of two steps, subtracting the mean (center) and dividing by the standard deviation (scale). Some recent works, however, have tried removing the center step, and some results even show that performance improves slightly once center is removed.

For example, the 2019 paper "Root Mean Square Layer Normalization" studied Layer Normalization with the center step removed, which it calls RMS Norm; it takes the following form:

\begin{equation}y_{i,j,k} = \frac{x_{i,j,k}}{\sqrt{\sigma_{i,j}^2 + \epsilon}} \times \gamma_k, \quad \sigma_{i,j}^2 = \frac{1}{d}\sum_{k=1}^d x_{i,j,k}^2\end{equation}

It can be seen that RMS Norm is just a simple variant of L2 normalization. Yet the paper's overall results show that RMS Norm is faster than Layer Normalization while being essentially on par in effectiveness.
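A matching sketch of RMS Norm, which simply drops the centering (and, following the formula above, the $\beta$ shift):

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    # x: (batch, length, d); no centering, only root-mean-square scaling
    ms = np.mean(x**2, axis=-1, keepdims=True)
    return x / np.sqrt(ms + eps) * gamma

x = np.random.randn(2, 5, 8)
print(np.mean(rms_norm(x, gamma=np.ones(8))**2, axis=-1))  # ~1 per position
```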

Beyond this paper, RMS Norm was used by Google in T5. Furthermore, another article "Do Transformer Modifications Transfer Across Implementations and Applications?" conducted comprehensive comparison experiments, demonstrating the superiority of RMS Norm. In this light, it's very likely that RMS Norm will replace Layer Normalization as the standard for Transformers in the future.

Coincidentally, another 2019 paper, "Analyzing and Improving the Image Quality of StyleGAN," which proposed StyleGAN2, found that the Instance Normalization it used caused "water droplet" artifacts in some generated images. The authors eventually removed Instance Normalization and replaced it with something called "Weight demodulation," but they also found that keeping Instance Normalization while removing only the center operation alleviated the artifacts. This provides evidence that the center operation in Normalization may have negative effects.

An intuitive guess is that the center operation, similar to the bias term in fully connected layers, stores prior distribution information about the pre-training task. Storing this prior distribution information directly in the model might instead cause a decrease in the model's transferability. This is why T5 not only removed the center operation from Layer Normalization but also removed the bias terms from every layer.

NTK Parameterization

Returning to the LeCun initialization for fully connected layers: it tells us to use a "random distribution with mean 0 and variance 1/m." Beyond applying this initialization directly, however, we can use another parameterization: initialize with a "random distribution with mean 0 and variance 1," but divide the output by $\sqrt{m}$. The model then becomes:

\begin{equation} y_j = b_j + \frac{1}{\sqrt{m}}\sum_i x_i w_{i,j}\end{equation}

In the context of Gaussian processes, this is called "NTK parameterization." Reference papers include "Neural Tangent Kernel: Convergence and Generalization in Neural Networks" and "On the infinite width limit of neural networks with a standard parameterization." For me, the first time I saw this operation was in the PGGAN paper "Progressive Growing of GANs for Improved Quality, Stability, and Variation."

Clearly, with NTK parameterization we can initialize every parameter from a distribution with variance 1 while still keeping the second moment stable. Even the "micro-tuned activation functions" mentioned earlier can be seen as a form of NTK parameterization. A natural question arises: what does NTK parameterization buy us compared with using LeCun initialization directly?

Theoretically, there is a small benefit. With NTK parameterization, all parameters are initialized from a variance-1 distribution, so every parameter's magnitude is at roughly the same $\mathcal{O}(1)$ level. This allows us to set a relatively large learning rate, such as $10^{-2}$. With an adaptive optimizer, whose update is roughly $\frac{\text{gradient}}{\sqrt{\text{gradient} \otimes \text{gradient}}} \times \text{learning rate}$, a learning rate of $10^{-2}$ then adjusts each parameter by roughly 1% per step. Overall, NTK parameterization lets us treat every parameter more equally and gives us an intuitive sense of the update magnitude during training, which helps in tuning hyperparameters.
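A minimal sketch of the idea (my own illustration, not tied to any particular library): the weight is stored with unit variance and the $1/\sqrt{m}$ factor is applied in the forward pass.

```python
import numpy as np

class NTKDense:
    """Fully connected layer under NTK parameterization (illustrative sketch)."""
    def __init__(self, m, n, rng):
        self.w = rng.normal(0.0, 1.0, size=(m, n))  # unit-variance initialization
        self.b = np.zeros(n)
        self.scale = 1.0 / np.sqrt(m)               # scaling moved into the forward pass

    def __call__(self, x):
        return x @ self.w * self.scale + self.b

rng = np.random.default_rng(0)
layer = NTKDense(1024, 768, rng)
x = rng.normal(0.0, 1.0, size=(32, 1024))
print(np.mean(layer(x)**2))  # ~1, the same second moment as with LeCun initialization
```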

At this point, we can return to the question raised at the beginning of the article: why is dividing by $\sqrt{d}$ in Attention so important? For two $d$-dimensional vectors $\boldsymbol{q}, \boldsymbol{k}$, assuming their components are independently sampled from a distribution with mean 0 and variance 1, the second moment of their dot product is:

\begin{equation}\begin{aligned} \mathbb{E}[(\boldsymbol{q}\cdot \boldsymbol{k})^2] =&\, \mathbb{E}\left[\left(\sum_{i=1}^d q_i k_i\right)^2\right] = \mathbb{E}\left[\left(\sum_i q_i k_i\right)\left(\sum_j q_j k_j\right)\right] \\ =&\, \mathbb{E}\left[\sum_{i,j} (q_i q_j) (k_i k_j)\right] = \sum_{i,j} \mathbb{E}[q_i q_j] \mathbb{E}[k_i k_j] \\ =&\, \sum_i \mathbb{E}[q_i^2] \mathbb{E}[k_i^2] = d \end{aligned}\end{equation}

In other words, the second moment of the dot product is $d$; since the mean is 0, this is also its variance. Attention applies a softmax after the dot product, which mainly involves $e^{\boldsymbol{q}\cdot \boldsymbol{k}}$, and we can roughly take the pre-softmax values to lie in the range $-3\sqrt{d}$ to $3\sqrt{d}$. Since $d$ is usually at least 64, $e^{3\sqrt{d}}$ is very large and $e^{-3\sqrt{d}}$ is very small, so after the softmax the Attention distribution becomes very close to one-hot. This causes severe gradient vanishing and poor training results.

Accordingly, there are two solutions. One is in the spirit of NTK parameterization: divide by $\sqrt{d}$ after the dot product so that the variance of $\boldsymbol{q}\cdot \boldsymbol{k}$ becomes 1. The pre-softmax values then correspond to $e^{3}, e^{-3}$, which are neither too large nor too small, so the softmax does not collapse to one-hot and gradients do not vanish; this is the common approach in Transformers such as BERT. The alternative is to skip the division by $\sqrt{d}$ and instead divide the initialization variance of the fully connected layers producing $\boldsymbol{q}$ and $\boldsymbol{k}$ by $\sqrt{d}$, which likewise makes the initial variance of $\boldsymbol{q}\cdot \boldsymbol{k}$ equal to 1; this is the approach adopted by T5.
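A small sketch contrasting the unscaled and scaled cases (illustrative only: a single head, random $\boldsymbol{q}/\boldsymbol{k}$, no masking); the T5-style alternative is noted in the comments.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n, d = 16, 64
q = rng.normal(0.0, 1.0, size=(n, d))
k = rng.normal(0.0, 1.0, size=(n, d))

scores = q @ k.T                            # entries have variance ~ d = 64
attn_raw = softmax(scores)                  # rows are nearly one-hot
attn_scaled = softmax(scores / np.sqrt(d))  # the BERT-style fix

print(attn_raw.max(axis=-1).mean())     # close to 1
print(attn_scaled.max(axis=-1).mean())  # noticeably smaller, softer distribution

# T5-style alternative: keep the plain dot product, but divide the initialization
# variance of the q/k projection layers by sqrt(d), so that q·k already has
# variance ~1 at the start of training.
```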

Residual Connections

Finally, we must discuss designs related to the residual $x + F(x)$. It is easy to prove that if the variance (or second moment) of $x$ is $\sigma_1^2$ and the variance of $F(x)$ is $\sigma_2^2$, and assuming they are independent, the variance of $x + F(x)$ is $\sigma_1^2 + \sigma_2^2$. In other words, residuals further amplify the variance, so we must adopt strategies to reduce it.

A naive solution is to add a Normalization operation directly after the residual:

\begin{equation}x_{t+1} = \text{Norm}(x_t + F_t(x_t))\end{equation}

This can be called the Post Norm structure, which is the design used in the original Transformer and BERT. However, although this stabilizes the variance of forward propagation, it actually significantly weakens the identity branch of the residual, causing it to lose the "ease of training" advantage. Usually, it requires a warmup and a sufficiently small learning rate to converge.

How do we understand this? Assume that at the initial state, the variances of $x$ and $F(x)$ are both 1. Then the variance of $x + F(x)$ is 2, and the Normalization operation is responsible for bringing the variance back down to 1. This implies that at the initial stage, Post Norm is equivalent to:

\begin{equation}x_{t+1} = \frac{x_t + F_t(x_t)}{\sqrt{2}}\end{equation}

Recursively expanding this, we get:

\begin{equation}\begin{aligned} x_l =&\, \frac{x_{l-1}}{\sqrt{2}} + \frac{F_{l-1}(x_{l-1})}{\sqrt{2}} \\ =&\, \frac{x_{l-2}}{2} + \frac{F_{l-2}(x_{l-2})}{2} + \frac{F_{l-1}(x_{l-1})}{\sqrt{2}} \\ =&\, \dots \\ =&\, \frac{x_0}{2^{l/2}} + \frac{F_0(x_0)}{2^{l/2}} + \frac{F_1(x_1)}{2^{(l-1)/2}} + \frac{F_2(x_2)}{2^{(l-2)/2}} + \dots + \frac{F_{l-1}(x_{l-1})}{2^{1/2}} \end{aligned}\end{equation}

Do you see the problem? The original intention of the residual was to create a "green channel" for the preceding layers, allowing gradients to propagate back more directly. In Post Norm, however, this "green channel" is severely weakened. The closer the channel is to the beginning, the smaller its weight. The residual exists "in name only," making training difficult. Related analysis can also be found in the paper "On Layer Normalization in the Transformer Architecture."

A targeted improvement is called Pre Norm. Its idea is to "normalize only when needed," taking the form:

\begin{equation}x_{t+1} = x_t + F_t(\text{Norm}(x_t))\end{equation}

Similarly, after iterative expansion, we can assume that at the initial stage:

\begin{equation} x_l = x_0 + F_0(x_0) + F_1(x_1/\sqrt{2}) + F_2(x_2/\sqrt{3}) + \dots + F_{l-1}(x_{l-1}/\sqrt{l}) \end{equation}

In this way, at least every residual channel is equally weighted, and the effect of the residual is more pronounced than in Post Norm, making it easier to optimize. Of course, the variance of the final $x_l$ will be very large, so we still need to add a Normalization before the prediction layer.
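To make the contrast concrete, here is a toy sketch (my own illustration; a parameter-free RMS-style norm and random linear maps stand in for real Transformer sublayers). It also shows why Pre Norm needs that final Normalization: the second moment of its output grows with depth.

```python
import numpy as np

def norm(x, eps=1e-6):
    # parameter-free RMS-style normalization, standing in for any Norm layer
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)

def post_norm_stack(x, blocks):
    for f in blocks:              # original Transformer / BERT style
        x = norm(x + f(x))
    return x

def pre_norm_stack(x, blocks):
    for f in blocks:              # "normalize only when needed"
        x = x + f(norm(x))
    return x                      # a final norm is still needed before the output layer

rng = np.random.default_rng(0)
d, depth = 64, 12
# toy sublayers: random linear maps initialized to preserve the second moment
blocks = [(lambda w: (lambda x: x @ w))(rng.normal(0.0, d**-0.5, size=(d, d)))
          for _ in range(depth)]
x0 = rng.normal(0.0, 1.0, size=(8, d))

print(np.mean(post_norm_stack(x0, blocks)**2))  # ~1 at every depth
print(np.mean(pre_norm_stack(x0, blocks)**2))   # grows roughly like depth + 1
```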

In my view, neither Post Norm nor Pre Norm is perfect because neither can maintain an identity function at the initial stage. The most elegant approach, in my opinion, should involve introducing a scalar parameter $\alpha_t$ initialized to 0, such that:

\begin{equation}x_{t+1} = x_t + \alpha_t F_t(x_t)\end{equation}

Then, $\alpha_t$ is gradually updated. This way, the model is guaranteed to be an identity function at the initial stage, avoiding the variance problem altogether. This trick later appeared in two papers: in "Batch Normalization Biases Residual Blocks Towards the Identity Function in Deep Networks" it was called SkipInit, while in "ReZero is All You Need: Fast Convergence at Large Depth" it was called ReZero. The two papers were released less than a month apart, and their results both showed that with this treatment, the Normalization operation in the residual branch can essentially be dropped. In addition, "Fixup Initialization: Residual Learning Without Normalization" proposed a method called Fixup, which initializes the last layer of each residual branch to zero and shares similarities with SkipInit and ReZero.

Regarding the update of $\alpha_t$, both SkipInit and ReZero treat it as a model parameter updated along with the others. I originally thought the same, but later realized that $\alpha_t$ does not have the same status as the other parameters and should not be treated identically. For instance, under the NTK parameterization introduced earlier we can use a fairly large learning rate for the other parameters, but $\alpha_t$ clearly should not use a large one. Moreover, we know that when training succeeds, both Post Norm and Pre Norm give very good results (corresponding to $\alpha_t = 1$), so the choice of residual mode is purely an initialization problem rather than a question of fitting capacity. With these points in mind, I later simply let $\alpha_t$ increase gradually with a fixed, very small step size until it reaches $\alpha_t = 1$, after which it stays fixed. In my experiments, this update mode gave the best results.
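A minimal sketch of this residual mode (my own illustration of the scheme described above; the step size, the update schedule, and the np.tanh stand-in sublayer are arbitrary assumptions):

```python
import numpy as np

class ScaledResidual:
    """Residual block x + alpha_t * F(x), with alpha_t ramped from 0 to 1 on a fixed schedule."""
    def __init__(self, f, step=1e-3):
        self.f = f
        self.alpha = 0.0   # identity mapping at initialization
        self.step = step   # fixed, very small increment

    def update_alpha(self):
        # called once per training step, instead of learning alpha by gradient descent
        self.alpha = min(1.0, self.alpha + self.step)

    def __call__(self, x):
        return x + self.alpha * self.f(x)

block = ScaledResidual(lambda x: np.tanh(x))   # np.tanh stands in for a real sublayer F
x = np.ones(4)
print(block(x))            # equals the input exactly at initialization
for _ in range(2000):      # after ~1/step updates, alpha is capped at 1
    block.update_alpha()
print(block(x))            # now the ordinary residual x + F(x)
```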

The Long Road to Training

This article discussed issues related to model initialization, parameterization, and normalization, hoping to provide some reference value for everyone's training and tuning. The road to training is vast and endless. Beyond these topics, there are many other things to tune, such as learning rates, optimizers, and data augmentation. May all readers have smooth sailing on their training journey~