By 苏剑林 | March 29, 2022
The comparison between Pre Norm and Post Norm is a "long-standing" topic. This blog has discussed this issue multiple times, such as in the articles "A Brief Discussion on Initialization, Parameterization, and Standardization of Transformers" and "Model Optimization Talk: Why is the Initial Standard Deviation of BERT 0.02?". Currently, the most established conclusion is that under the same settings, the Pre Norm structure is often easier to train, but its final performance is usually not as good as Post Norm. It is easy to understand why Pre Norm is easier to train because its identity path is more prominent, but why is its performance not as good?
I didn't have a good answer for this until recently when I saw a reply from @Tang Xianghao on Zhihu, which gave me a "sudden realization." It turns out there is a very intuitive understanding of this problem! Let's learn about it together in this article.
Basic Conclusion
The formulas for Pre Norm and Post Norm are as follows:
\begin{align}
\text{Pre Norm: } \quad \boldsymbol{x}_{t+1} = \boldsymbol{x}_t + F_t(\text{Norm}(\boldsymbol{x}_t))\\
\text{Post Norm: }\quad \boldsymbol{x}_{t+1} = \text{Norm}(\boldsymbol{x}_t + F_t(\boldsymbol{x}_t))
\end{align}
In a Transformer, $\text{Norm}$ here mainly refers to Layer Normalization, but in more general models it can also be Batch Normalization, Instance Normalization, etc.; the conclusions below are essentially universal.
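For concreteness, here is a minimal PyTorch sketch of the two block structures. It is only an illustration of the formulas above: the class names and the `sublayer` argument are stand-ins of my own, with `sublayer` playing the role of $F_t$ (e.g. self-attention or a feed-forward network).

```python
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Pre Norm: x_{t+1} = x_t + F_t(Norm(x_t))"""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer   # F_t, e.g. self-attention or a feed-forward network

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

class PostNormBlock(nn.Module):
    """Post Norm: x_{t+1} = Norm(x_t + F_t(x_t))"""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        return self.norm(x + self.sublayer(x))
```

The only difference is whether the normalization sits on the residual branch's input (Pre Norm) or on the output of the whole residual sum (Post Norm), but as discussed below this small difference matters.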
In the materials I have found, there are two papers showing that Post Norm is superior to Pre Norm: "Understanding the Difficulty of Training Transformers" and "RealFormer: Transformer Likes Residual Attention". Additionally, I have conducted my own comparative experiments, which showed that the transfer performance of the Post Norm structure is better. That is to say, during Pre-training, both Pre Norm and Post Norm can achieve roughly the same results, but the Fine-tuning effect of Post Norm is significantly better.
Readers might ask, doesn't "On Layer Normalization in the Transformer Architecture" show that Pre Norm is better than Post Norm? Isn't this a contradiction? In fact, that article compared Pre Norm and Post Norm under identical training settings, which only shows that Pre Norm is easier to train; Post Norm needs its own training configuration (e.g., Pre Norm can omit Warmup, whereas Post Norm usually requires it) to reach its optimal performance. Therefore, the conclusions are not contradictory.
Intuitive Understanding
Why is the performance of Pre Norm not as good as Post Norm? The answer given by @Tang Xianghao on Zhihu is: Pre Norm's depth is "diluted"! In other words, the actual equivalent depth of an $L$-layer Pre Norm model is not as deep as an $L$-layer Post Norm model, and having fewer layers leads to worse performance.
How do we understand this specifically? Quite simply, for a Pre Norm model, we iterate to obtain:
\begin{equation}\begin{aligned}
\boldsymbol{x}_{t+1} =&\, \boldsymbol{x}_t + F_t(\text{Norm}(\boldsymbol{x}_t)) \\
=&\, \boldsymbol{x}_{t-1} + F_{t-1}(\text{Norm}(\boldsymbol{x}_{t-1})) + F_t(\text{Norm}(\boldsymbol{x}_t)) \\
=&\, \cdots \\
=&\, \boldsymbol{x}_0 + F_0 (\text{Norm}(\boldsymbol{x}_0)) + \cdots + F_{t-1}(\text{Norm}(\boldsymbol{x}_{t-1})) + F_t(\text{Norm}(\boldsymbol{x}_t))
\end{aligned}\end{equation}
If each term is of the same order of magnitude, then $\boldsymbol{x}_{t+1}=\mathcal{O}(t+1)$, i.e. the magnitude of the hidden state grows roughly linearly with depth. In other words, the difference between layer $t+1$ and layer $t$ is comparable to the difference between $t+1$ and $t$ themselves: a relative difference of roughly $1/(t+1)$, which is very small when $t$ is large (about 1% at $t=100$). Therefore:
\begin{equation}\begin{aligned}
&\,F_t(\text{Norm}(\boldsymbol{x}_t)) + F_{t+1}(\text{Norm}(\boldsymbol{x}_{t+1})) \\
\approx&\,F_t(\text{Norm}(\boldsymbol{x}_t)) + F_{t+1}(\text{Norm}(\boldsymbol{x}_t)) \\
=&\, \begin{pmatrix} 1 & 1\end{pmatrix}\begin{pmatrix} F_t \\ F_{t+1}\end{pmatrix}(\text{Norm}(\boldsymbol{x}_t))
\end{aligned}\end{equation}
This means that when $t$ is large, $\boldsymbol{x}_t$ and $\boldsymbol{x}_{t+1}$ are very close, so $F_{t+1}(\text{Norm}(\boldsymbol{x}_{t+1}))$ and $F_{t+1}(\text{Norm}(\boldsymbol{x}_t))$ are also very close. Thus, a model with $t+1$ layers is approximately equivalent to a wider model with $t$ layers. Therefore, in Pre Norm, the result of stacking many layers is more about increasing width than depth: the more layers there are, the more "insubstantial" each layer becomes.
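This dilution is easy to observe numerically. Below is a rough sketch (my own illustration, not from the original post) that stacks randomly initialized Pre Norm blocks, using a small MLP as a stand-in for each $F_t$, and prints how much each layer actually changes the hidden state; the relative change $\|\boldsymbol{x}_{t+1}-\boldsymbol{x}_t\|/\|\boldsymbol{x}_t\|$ shrinks steadily with depth, which is exactly the "each layer becomes insubstantial" effect described above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, depth = 256, 48

# Randomly initialized stand-ins for the sublayers F_t (attention/FFN in a real model)
norms = nn.ModuleList(nn.LayerNorm(dim) for _ in range(depth))
subs = nn.ModuleList(
    nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
    for _ in range(depth)
)

x = torch.randn(1, dim)
with torch.no_grad():
    for t in range(depth):
        x_next = x + subs[t](norms[t](x))   # Pre Norm: x_{t+1} = x_t + F_t(Norm(x_t))
        rel = ((x_next - x).norm() / x.norm()).item()
        if t % 8 == 0:
            print(f"layer {t:2d}: |x_t| = {x.norm().item():7.2f}, relative change = {rel:.3f}")
        x = x_next
```

(In a randomly initialized stack the per-layer contributions add up less coherently than in the same-magnitude argument above, so $\|\boldsymbol{x}_t\|$ grows more slowly than $\mathcal{O}(t)$, but the qualitative trend of a shrinking relative update per layer is the same.)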
To put it simply, the Pre Norm structure invisibly increases the model's width while decreasing its effective depth. Since depth is usually more important than width, it is this hidden reduction in depth that hurts the final performance. Post Norm is just the opposite: as we analyzed in "A Brief Discussion on Initialization, Parameterization, and Standardization of Transformers", every application of Norm weakens the weight of the identity branch, so Post Norm places more emphasis on the residual branch. As a result, each layer in Post Norm carries its full weight, and once trained well, the performance is superior.
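As a rough sanity check of that claim, following the estimate in that referenced article and assuming for simplicity that $\boldsymbol{x}_t$ and $F_t(\boldsymbol{x}_t)$ have comparable magnitude and are roughly independent (so that Norm divides their sum by about $\sqrt{2}$), iterating the Post Norm recursion gives
\begin{equation}\begin{aligned}
\boldsymbol{x}_{t+1} =&\, \text{Norm}(\boldsymbol{x}_t + F_t(\boldsymbol{x}_t)) \approx \frac{\boldsymbol{x}_t + F_t(\boldsymbol{x}_t)}{\sqrt{2}} \\
\approx&\, \frac{\boldsymbol{x}_{t-1}}{2} + \frac{F_{t-1}(\boldsymbol{x}_{t-1})}{2} + \frac{F_t(\boldsymbol{x}_t)}{\sqrt{2}} \\
\approx&\, \cdots \\
\approx&\, \frac{\boldsymbol{x}_0}{2^{(t+1)/2}} + \frac{F_0(\boldsymbol{x}_0)}{2^{(t+1)/2}} + \frac{F_1(\boldsymbol{x}_1)}{2^{t/2}} + \cdots + \frac{F_t(\boldsymbol{x}_t)}{2^{1/2}}
\end{aligned}\end{equation}
so the weights of the identity branch and of the early residual branches decay exponentially with depth, while each newly added layer enters at full strength: no layer gets "diluted" the way it does in Pre Norm.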
Related Work
Many readers have likely heard of DeepNet, which claims to be able to train 1000-layer Transformers. In its paper "DeepNet: Scaling Transformers to 1,000 Layers", the description of Pre Norm is:
However, the gradients of Pre-LN at bottom layers tend to be larger than at top layers, leading to a degradation in performance compared with Post-LN.
Many readers might not have understood the logical relationship in this sentence at the time, but after reading the explanation in the previous section, you will likely have a new understanding.
In short, the statement "the gradients of Pre-LN at bottom layers tend to be larger than at top layers" means that the Pre Norm structure leans too heavily toward the identity branch (bottom layers), causing Pre Norm to degrade into a "shallow and wide" model, which eventually performs worse than a Post Norm model of the same depth. This is essentially consistent with the previous intuitive understanding.
Summary
This article mainly shared an intuitive understanding of "why the performance of Pre Norm is not as good as Post Norm."