By 苏剑林 | March 11, 2022
Soon after "What are the difficulties in training a 1000-layer Transformer?" was published, a reader asked: what happens if the same analysis is applied to the "Gated Attention Unit (GAU)" from "FLASH: Probably the Most Interesting Efficient Transformer Design Recently"? How do the results differ from those for the standard Transformer? This article discusses that question.
Conclusion First
In fact, GAU is a very easy-to-train model. Even using "Post Norm + Xavier initialization" directly, without any adjustments, we can comfortably train a GAU with dozens of layers, and no Warmup is needed. As a result, many of the training tricks developed for the standard Transformer simply have no role to play in GAU.
Why can GAU achieve this? Quite simply, because under default settings, theoretically $\text{GAU}(\boldsymbol{x}_l)$ is nearly two orders of magnitude smaller than $\boldsymbol{x}_l$, so:
\begin{equation}\boldsymbol{x}_{l+1} = \text{LN}(\boldsymbol{x}_l + \text{GAU}(\boldsymbol{x}_l))\approx \boldsymbol{x}_l\end{equation}
Consequently, when paired with the residual connection, GAU is already very close to an identity function under standard initialization. Models with this property are very easy to train and typically do not require Warmup. Mapped onto the conclusions of "What are the difficulties in training a 1000-layer Transformer?", these two orders of magnitude correspond to $\lambda=1, \alpha=100$, meaning that GAU automatically builds in an effect equivalent to applying the DeepNorm operation for a hundred-layer model. Thus, in theory, we can directly train a GAU model with hundreds of layers without any special adjustment techniques.
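As a quick illustration of this near-identity behavior, here is a minimal numpy sketch (my own illustration, not code from the original post): when the residual branch is only about 1% of the input in magnitude, the Post Norm output barely moves.

```python
# Minimal sketch: if the residual branch is ~1% of the input's scale,
# LN(x + branch) differs from LN(x) by only about 1%.
import numpy as np

def layer_norm(x):
    return (x - x.mean(-1, keepdims=True)) / x.std(-1, keepdims=True)

rng = np.random.default_rng(0)
d = 768
x = rng.standard_normal(d)                 # a token vector with unit-scale components
branch = 0.01 * rng.standard_normal(d)     # stand-in for GAU(x), ~0.01 of x's scale

y = layer_norm(x + branch)
rel_change = np.linalg.norm(y - layer_norm(x)) / np.linalg.norm(layer_norm(x))
print(rel_change)                          # roughly 0.01, i.e. nearly an identity map
```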
Model Assumptions
We only need to perform a magnitude analysis on the inputs and outputs of the GAU. The standard GAU operations are as follows:
\begin{equation}\begin{aligned}
&\boldsymbol{O}=(\boldsymbol{U}\odot\boldsymbol{A}\boldsymbol{V})\boldsymbol{W}_o,\quad \boldsymbol{A}=\frac{1}{ns}\text{relu}^2\left(\mathcal{Q}(\boldsymbol{Z})\mathcal{K}(\boldsymbol{Z})^{\top}\right)\\
&\boldsymbol{U}=\phi(\boldsymbol{X}\boldsymbol{W}_u),\quad\boldsymbol{V}=\phi(\boldsymbol{X}\boldsymbol{W}_v),\quad\boldsymbol{Z}=\phi(\boldsymbol{X}\boldsymbol{W}_z)
\end{aligned}\end{equation}
Here $\boldsymbol{X}\in\mathbb{R}^{n\times d}$, $\boldsymbol{W}_u,\boldsymbol{W}_v\in\mathbb{R}^{d\times e}$, $\boldsymbol{W}_z\in\mathbb{R}^{d\times s}$, $\boldsymbol{W}_o\in\mathbb{R}^{e\times d}$, $\mathcal{Q},\mathcal{K}$ are simple affine transformations, and $\phi$ is an activation function, defaulting to Swish. If any part is unclear, refer to "FLASH: Probably the Most Interesting Efficient Transformer Design Recently".
We assume that the components of $\boldsymbol{X}$ are independently sampled from the standard normal distribution $\mathcal{N}(0,1)$. We further assume that $\boldsymbol{W}_u,\boldsymbol{W}_v, \boldsymbol{W}_z$ are initialized from $\mathcal{N}(0,1/d)$ and $\boldsymbol{W}_o$ from $\mathcal{N}(0,1/e)$, with all entries sampled independently. This scheme is known as LeCun initialization; its key property is that it keeps the output mean at 0 and preserves the second moment from input to output. For related background, see my earlier article "A Brief Discussion on Initialization, Parameterization, and Normalization of Transformers".
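As a quick numerical check of this property, here is a small numpy sketch (the shapes are illustrative assumptions, not a fixed configuration):

```python
# Check that LeCun initialization N(0, 1/d) keeps the output mean near 0
# and the second moment of X W close to that of X.
import numpy as np

rng = np.random.default_rng(0)
n, d, e = 512, 768, 1536                         # illustrative sizes only

X = rng.standard_normal((n, d))                  # components of X ~ N(0, 1)
W = rng.standard_normal((d, e)) / np.sqrt(d)     # LeCun initialization N(0, 1/d)

Y = X @ W
print(Y.mean(), (Y ** 2).mean())                 # approximately 0 and 1
```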
Basic Integrals
Under these assumptions, let us estimate the distribution after each operation in turn. Since LeCun initialization preserves the second moment, each entry of $\boldsymbol{X}\boldsymbol{W}$ can be approximated as standard normal. We can therefore use the following integrals to estimate the mean and second moment after applying the activation function $\phi$:
\begin{equation}\begin{aligned}
\mu\triangleq\mathbb{E}[\phi(\varepsilon)] =&\, \int_{-\infty}^{\infty} \frac{1}{\sqrt{2\pi}}\exp\left(-\frac{1}{2}\varepsilon^2\right)\phi(\varepsilon)d\varepsilon = 0.2066\cdots \\
\nu^2\triangleq\mathbb{E}[\phi(\varepsilon)^2] =&\, \int_{-\infty}^{\infty} \frac{1}{\sqrt{2\pi}}\exp\left(-\frac{1}{2}\varepsilon^2\right)\phi(\varepsilon)^2d\varepsilon = 0.3557\cdots
\end{aligned}\end{equation}
In other words, the mean and second moment of the components of $\boldsymbol{U}, \boldsymbol{V}, \boldsymbol{Z}$ are $\mu$ and $\nu^2$ respectively. In fact, only the second moment $\nu^2$ is used later. For a simple estimation, one can take $\nu=0.6$.
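Both integrals are easy to verify numerically; the short sketch below (assuming $\phi$ is Swish and using scipy's quad) reproduces the constants above.

```python
# Numerically evaluate mu = E[phi(eps)] and nu^2 = E[phi(eps)^2]
# for eps ~ N(0, 1) and phi = Swish, i.e. phi(x) = x * sigmoid(x).
import numpy as np
from scipy.integrate import quad
from scipy.special import expit   # numerically stable sigmoid

swish = lambda x: x * expit(x)
gauss = lambda x: np.exp(-0.5 * x ** 2) / np.sqrt(2 * np.pi)

mu,  _ = quad(lambda x: gauss(x) * swish(x),      -np.inf, np.inf)
nu2, _ = quad(lambda x: gauss(x) * swish(x) ** 2, -np.inf, np.inf)
print(mu, nu2)   # approximately 0.2066 and 0.3557
```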
Self-Attention
In the initial stage the affine transformations reduce to the identity, so $\mathcal{Q}(\boldsymbol{Z})=\mathcal{K}(\boldsymbol{Z})=\boldsymbol{Z}$ and hence $\boldsymbol{A}=\frac{1}{ns}\text{relu}^2\left(\boldsymbol{Z}\boldsymbol{Z}^{\top}\right)$, which gives (with $i\neq j$ in the second line):
\begin{equation}\begin{aligned}
&\boldsymbol{A}_{i,i} = \frac{1}{ns}\text{relu}^2\big(\left\langle\boldsymbol{Z}_i, \boldsymbol{Z}_i\right\rangle\big) \approx \frac{1}{ns}\text{relu}^2\big(s\mathbb{E}[\phi(\varepsilon)^2]\big) = \frac{s\nu^4}{n} \\
&\boldsymbol{A}_{i,j} = \frac{1}{ns}\text{relu}^2\big(\left\langle\boldsymbol{Z}_i, \boldsymbol{Z}_j\right\rangle\big) \approx \frac{1}{ns}\text{relu}^2\big(s\mathbb{E}[\phi(\varepsilon)]^2\big) = \frac{s\mu^4}{n}
\end{aligned}\end{equation}
Notice that $\boldsymbol{A}_{i,i} / \boldsymbol{A}_{i,j} \approx \nu^4 / \mu^4 \approx 69 \gg 1$, i.e., the diagonal entries are much larger than the off-diagonal ones. Therefore, in the initial stage $\boldsymbol{A}$ is very close to $\frac{s\nu^4}{n}$ times the identity matrix, i.e., $\boldsymbol{A}\approx \frac{s\nu^4}{n}\boldsymbol{I}$. Consequently:
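This diagonal dominance is easy to check by Monte Carlo; the following sketch (same assumptions as above, with illustrative sizes) estimates the ratio of the average diagonal entry to the average off-diagonal entry.

```python
# Monte Carlo check that, at initialization, the diagonal of A dominates the
# off-diagonal entries. The idealized ratio is nu^4/mu^4 ~ 69; finite-size
# fluctuations typically make the measured ratio somewhat smaller.
import numpy as np

def swish(x):
    return x / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
n, d, s = 512, 768, 128                          # illustrative sizes only

X  = rng.standard_normal((n, d))
Wz = rng.standard_normal((d, s)) / np.sqrt(d)    # LeCun initialization
Z  = swish(X @ Wz)                               # Q(Z) = K(Z) = Z initially

A = np.maximum(Z @ Z.T, 0.0) ** 2 / (n * s)      # relu^2 with the 1/(ns) factor

diag = np.diag(A).mean()
off  = (A.sum() - np.diag(A).sum()) / (n * n - n)
print(diag / off)                                # a ratio in the tens
```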
\begin{equation}\boldsymbol{O}=(\boldsymbol{U}\odot\boldsymbol{A}\boldsymbol{V})\boldsymbol{W}_o\approx \frac{s\nu^4}{n}(\boldsymbol{U}\odot\boldsymbol{V})\boldsymbol{W}_o\end{equation}
Remaining Parts
Each component of $\boldsymbol{U}\odot\boldsymbol{V}$ is approximately $\phi(\varepsilon_i)\phi(\varepsilon_j)$ for two independent, identically distributed standard normal variables $\varepsilon_i, \varepsilon_j$. Thus:
\begin{equation}\mathbb{E}[(\boldsymbol{U}\odot\boldsymbol{V})^2] \approx \mathbb{E}[\phi(\varepsilon_i)^2\phi(\varepsilon_j)^2] = \mathbb{E}[\phi(\varepsilon_i)^2]\mathbb{E}[\phi(\varepsilon_j)^2] = \nu^4\end{equation}
Thus we have (multiplication by $\boldsymbol{W}_o$ does not change the second moment):
\begin{equation}\mathbb{E}[\boldsymbol{O}^2] \approx \mathbb{E}\left[\left(\frac{s\nu^4}{n}\boldsymbol{U}\odot\boldsymbol{V}\right)^2\right] = \frac{s^2\nu^8}{n^2}\mathbb{E}[(\boldsymbol{U}\odot\boldsymbol{V})^2] \approx \frac{s^2\nu^{12}}{n^2}\end{equation}
Therefore, the magnitude of $\boldsymbol{O}$ is:
\begin{equation}\boldsymbol{O} = \mathcal{O}\left(\sqrt{\frac{s^2\nu^{12}}{n^2}}\right) = \mathcal{O}\left(\frac{s\nu^{6}}{n}\right) \end{equation}
Taking the conventional pre-training settings $s=128, n=512$ as an example, $s\nu^6/n\approx 0.01$. Thus, in the initial stage, the result coming out of $\text{GAU}(\boldsymbol{x}_l)$ is roughly at the level of $0.01\boldsymbol{x}_l$, which is two orders of magnitude smaller. Of course, this is a theoretical result; the actual result might be larger or smaller due to random error. However, even if it were larger, there is no need to worry, because GAU also possesses the following "Crazy Scale" property.
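This estimate can be sanity-checked with a self-contained Monte Carlo sketch of one GAU forward pass at initialization (LeCun initialization, illustrative shapes $d=768$, $e=1536$, $s=128$, $n=512$). The printed ratio should be far below 1; because the estimate above treats $\boldsymbol{A}$ as purely diagonal, the empirical value may deviate somewhat from 0.01, consistent with the caveat just mentioned.

```python
# Monte Carlo sketch of one GAU block at initialization. The printed RMS
# ratio of output to input should be far below 1 (the idealized estimate
# above gives about 0.01; the empirical value can differ somewhat since the
# estimate treats A as purely diagonal).
import numpy as np

def swish(x):
    return x / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
n, d, e, s = 512, 768, 1536, 128                 # illustrative shapes

X  = rng.standard_normal((n, d))
Wu = rng.standard_normal((d, e)) / np.sqrt(d)    # LeCun initialization
Wv = rng.standard_normal((d, e)) / np.sqrt(d)
Wz = rng.standard_normal((d, s)) / np.sqrt(d)
Wo = rng.standard_normal((e, d)) / np.sqrt(e)

U, V, Z = swish(X @ Wu), swish(X @ Wv), swish(X @ Wz)
A = np.maximum(Z @ Z.T, 0.0) ** 2 / (n * s)      # Q(Z) = K(Z) = Z at this stage
O = (U * (A @ V)) @ Wo                           # O = (U ⊙ A V) W_o

print(np.sqrt((O ** 2).mean()) / np.sqrt((X ** 2).mean()))
```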
Crazy Scale
In the reference code in the appendix of the GAU paper, the initialization method used by the authors is not LeCun initialization, but a normal distribution with a standard deviation of 0.02. For BERT-base, $d=768$, and the standard deviation given by LeCun initialization is $1/\sqrt{d}\approx 0.036$, which means the standard deviation used in the appendix is only about half that of LeCun initialization.
When we replace all weight matrices $\boldsymbol{W}$ in the GAU with $\lambda \boldsymbol{W}$, we have:
\begin{equation}\begin{aligned}
&\tilde{\boldsymbol{U}}=\phi(\boldsymbol{X}\lambda\boldsymbol{W}_u) \approx \lambda\phi(\boldsymbol{X}\boldsymbol{W}_u)=\lambda \boldsymbol{U}\\
&\tilde{\boldsymbol{V}}=\phi(\boldsymbol{X}\lambda\boldsymbol{W}_v) \approx \lambda\phi(\boldsymbol{X}\boldsymbol{W}_v)=\lambda \boldsymbol{V}\\
&\tilde{\boldsymbol{Z}}=\phi(\boldsymbol{X}\lambda\boldsymbol{W}_z) \approx \lambda\phi(\boldsymbol{X}\boldsymbol{W}_z)=\lambda \boldsymbol{Z}\\
&\tilde{\boldsymbol{A}}=\frac{1}{ns}\text{relu}^2\left(\lambda^2\mathcal{Q}(\boldsymbol{Z})\mathcal{K}(\boldsymbol{Z})^{\top}\right) = \lambda^4 \boldsymbol{A}\\
&\tilde{\boldsymbol{O}}=(\tilde{\boldsymbol{U}}\odot\tilde{\boldsymbol{A}}\tilde{\boldsymbol{V}})\lambda\boldsymbol{W}_o \approx \lambda^7 \boldsymbol{O}
\end{aligned}\end{equation}
That is to say, if all initialization scales are shrunk to $\lambda$ times their original value, the output of the GAU shrinks to roughly $\lambda^7$ times its original value! This is GAU's rather "crazy" scaling property. With $\lambda=1/2$, $\lambda^7$ is again at the 0.01 level, shrinking the output by another two orders of magnitude. Therefore, if we follow the initialization choices of the original paper, we could in theory directly train a GAU model with tens of thousands of layers!
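The $\lambda^7$ scaling can likewise be checked by Monte Carlo. The sketch below (same illustrative shapes as before) shrinks every weight matrix by $\lambda=1/2$ and compares the two outputs; since $\phi(\lambda x)\approx\lambda\phi(x)$ holds only approximately for Swish, the measured factor will not match $\lambda^7$ exactly, but it should be of a comparably tiny magnitude.

```python
# Monte Carlo check of the lambda^7 scaling: shrink every weight matrix by
# lam = 1/2 and compare the GAU output with the unscaled one. The comparison
# is approximate because swish(lam * x) ~ lam * swish(x) is only approximate.
import numpy as np

def swish(x):
    return x / (1.0 + np.exp(-x))

def gau(X, Wu, Wv, Wz, Wo):
    n, s = X.shape[0], Wz.shape[1]
    U, V, Z = swish(X @ Wu), swish(X @ Wv), swish(X @ Wz)
    A = np.maximum(Z @ Z.T, 0.0) ** 2 / (n * s)
    return (U * (A @ V)) @ Wo

rng = np.random.default_rng(0)
n, d, e, s = 512, 768, 1536, 128                 # illustrative shapes
lam = 0.5

X  = rng.standard_normal((n, d))
Ws = [rng.standard_normal((d, e)) / np.sqrt(d),  # Wu
      rng.standard_normal((d, e)) / np.sqrt(d),  # Wv
      rng.standard_normal((d, s)) / np.sqrt(d),  # Wz
      rng.standard_normal((e, d)) / np.sqrt(e)]  # Wo

rms = lambda M: np.sqrt((M ** 2).mean())
O1 = gau(X, *Ws)                                 # LeCun-scale weights
O2 = gau(X, *[lam * W for W in Ws])              # all weights shrunk by lam

print(rms(O2) / rms(O1), lam ** 7)               # measured shrinkage vs. lambda^7
```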
Article Summary
This article briefly analyzes the magnitude of GAU's output in the initial stage and concludes that, under standard initialization, a GAU block with its residual connection is already close to an identity function. GAU is therefore quite easy to train, and essentially no additional adjustments are needed even when training a GAU model with a hundred layers.