Can we losslessly enlarge a Transformer model? (Part 1)

By 苏剑林 | June 02, 2021

Looking at the title, readers might find it strange: shouldn't everyone be thinking about how to shrink large models? Why are you thinking about enlarging a small one? The background is this: Generally speaking, larger models with more data do indeed yield better results. However, when computing power is limited, the time cost of pre-training a large model from scratch is too high. If you need to tune hyper-parameters a few times, months might pass by.

This is where "poor man's logic" comes in (those with infinite resources can ignore this): Can we first train a small model with the same number of layers, then enlarge it and continue training? In this way, the weights of the pre-trained small model, after being enlarged, serve as a very high starting point for the large model's initialization. Consequently, the number of training steps for the large model stage can be reduced, thereby shortening the overall training time.

So, can a small model be losslessly enlarged into a large model? This article will analyze this problem from a theoretical perspective.

Meaning

Some readers might think: "Of course this is possible, the fitting capacity of a large model is definitely greater than that of a small model." Indeed, from the perspective of fitting capacity, this is certainly achievable, but that's not the full meaning of the "lossless enlargement" we are concerned with here.

Taking BERT as an example, the pre-training stage is mainly an MLM (Masked Language Model). The meaning of "lossless enlargement" is:

Is it possible to directly transform a small model into a large model through some kind of transformation such that the output remains completely unchanged?

Here, the transformation refers to deterministic transformations performed directly on the weights, without any further training via gradient descent. "Output remains completely unchanged" means that for the same input, the small model and the large model give identical prediction results. In other words, although they look different on the surface, mathematically they represent exactly the same function. Since the enlargement is lossless, the large model is guaranteed to be at least as good as the small one, so continuing pre-training theoretically offers a positive gain. As for whether this "small-then-large" pre-training strategy can match training a large model from scratch in terms of final performance, that requires experimental validation and is not the focus of this article.

Intuitively, this kind of enlargement doesn't seem difficult. For example, operations like "repetition" or "zero-padding" can achieve a natural enlargement of model weights. In fact, these are the directions one would try, but the difficulty lies in carefully analyzing the consequences of enlarging each module of the model to ensure the final result is truly lossless.

Attempt

Below, taking "enlarging BERT by a factor of 2" as an example, we analyze each module and work out the final form of the transformation. Here, "enlargement" refers only to increasing the dimension of the hidden vectors, without changing the number of model layers or the number of heads in the multi-head attention mechanism.

Embedding

First, the input layer is the Embedding layer, so we must solve the enlargement problem there first. This is one of the simplest parts: we just need to double the dimension of each token's embedding vector. The two natural operations are "repetition" and "zero-padding":

\begin{align*} \text{Repetition:} &\quad [x_1, x_2, x_3, x_4] \to [x_1, x_1, x_2, x_2, x_3, x_3, x_4, x_4] \\ \text{Zero-padding:} &\quad [x_1, x_2, x_3, x_4] \to [x_1, x_2, x_3, x_4, 0, 0, 0, 0] \end{align*}

Both schemes could be candidates, but intuitively, zero-padding introduces too many zeros, leading to excessive sparsity and too many repetitions of the same value, which is unfavorable for weight diversity. Therefore, we choose the repetition scheme. However, even with repetition, there is more than one way. For example, $[x_1, x_2, x_3, x_4, x_1, x_2, x_3, x_4]$ is also a scheme, but subsequent analysis of the Attention layer shows that this latter scheme is not advisable.

In addition, we usually hope the transformation is orthogonal, which generally ensures model stability to the greatest extent. Specifically, the most basic property of an orthogonal transformation is that it does not change the norm of the vector. Therefore, we adjust the final repetition transformation to:

\begin{equation} \begin{pmatrix} x_1 \\ x_2 \\ \vdots \\ x_d \end{pmatrix} \to \begin{pmatrix} \tilde{x}_1 \\ \tilde{x}_2 \\ \tilde{x}_3 \\ \tilde{x}_4 \\ \vdots \\ \tilde{x}_{2d-1} \\ \tilde{x}_{2d} \end{pmatrix} = \frac{1}{\sqrt{2}} \begin{pmatrix} x_1 \\ x_1 \\ x_2 \\ x_2 \\ \vdots \\ x_d \\ x_d \end{pmatrix} \label{eq:repeat-sqrt2} \end{equation}

Or simplified as $\tilde{x}_i = x_{\lceil i/2 \rceil} / \sqrt{2}$, where $\lceil \cdot \rceil$ is the ceiling operation. We call this "repeat and divide by $\sqrt{2}$".
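
For concreteness, here is a minimal NumPy sketch of "repeat and divide by $\sqrt{2}$" (the helper name `repeat_and_rescale` is just for illustration); the final check confirms the norm-preserving property mentioned above:

```python
import numpy as np

def repeat_and_rescale(x, k=2):
    """'Repeat k times and divide by sqrt(k)' along the last axis."""
    return np.repeat(x, k, axis=-1) / np.sqrt(k)

x = np.random.randn(4)            # a toy 4-dimensional token embedding
x_tilde = repeat_and_rescale(x)   # 8-dimensional enlarged embedding
assert np.isclose(np.linalg.norm(x_tilde), np.linalg.norm(x))  # norm preserved
```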

LayerNorm

The layer following the Embedding is LayerNorm. Before the transformation, the computation for LayerNorm is:

\begin{align} y_i &= \frac{x_i - \mu}{\sigma} \times \gamma_i + \beta_i \\ \mu &= \frac{1}{d} \sum_{i=1}^d x_i \\ \sigma &= \sqrt{\frac{1}{d} \sum_{i=1}^d (x_i - \mu)^2} \end{align}

After the transformation, we have:

\begin{align} \tilde{\mu} &= \frac{1}{2d} \sum_{i=1}^{2d} \tilde{x}_i = \frac{1}{d} \sum_{i=1}^d \frac{x_i}{\sqrt{2}} = \frac{\mu}{\sqrt{2}} \\ \tilde{\sigma} &= \sqrt{\frac{1}{2d} \sum_{i=1}^{2d} (\tilde{x}_i - \tilde{\mu})^2} = \sqrt{\frac{1}{d} \sum_{i=1}^d \left(\frac{x_i}{\sqrt{2}} - \frac{\mu}{\sqrt{2}}\right)^2} = \frac{\sigma}{\sqrt{2}} \\ \frac{\tilde{x}_i - \tilde{\mu}}{\tilde{\sigma}} &= \frac{x_{\lceil i/2 \rceil}/\sqrt{2} - \mu/\sqrt{2}}{\sigma/\sqrt{2}} = \frac{x_{\lceil i/2 \rceil} - \mu}{\sigma} \end{align}

This means that the "subtract mean and divide by standard deviation" step automatically cancels out the factor $1/\sqrt{2}$, and the result is a direct repetition of the result before enlargement. If we also transform the parameter vectors $\beta, \gamma$ according to Formula \eqref{eq:repeat-sqrt2}, the result will be $\tilde{y}_i = y_{\lceil i/2 \rceil} / \sqrt{2}$, which is consistent with the transformation result of the Embedding layer. Our goal is to ensure that the "net transformation" of each layer is the same simple transformation: "repeat and divide by $\sqrt{2}$".
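
As a quick sanity check of this cancellation, the following sketch (with a bare-bones `layer_norm` that omits the usual $\epsilon$ term, so the identity is exact) verifies numerically that transforming the input and the $\gamma, \beta$ parameters by "repeat and divide by $\sqrt{2}$" yields an output that is again "repeat and divide by $\sqrt{2}$" of the original:

```python
import numpy as np

def layer_norm(x, gamma, beta):
    # standard LayerNorm without the epsilon term, for an exact comparison
    mu, sigma = x.mean(), x.std()
    return (x - mu) / sigma * gamma + beta

rep = lambda v: np.repeat(v, 2) / np.sqrt(2)   # "repeat and divide by sqrt(2)"

d = 4
x, gamma, beta = np.random.randn(d), np.random.randn(d), np.random.randn(d)

y       = layer_norm(x, gamma, beta)
y_tilde = layer_norm(rep(x), rep(gamma), rep(beta))

# the enlarged output is again "repeat and divide by sqrt(2)" of the original
assert np.allclose(y_tilde, rep(y))
```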

FeedForward

Following the layer order, we should analyze the Attention layer next. However, the FeedForward layer is simpler, and analyzing it first helps in understanding the transformation of the Attention layer, so we consider the FeedForward layer first.

The FeedForward layer is just a composition of two fully connected layers, so we only need to analyze a single fully connected layer:

\begin{equation} y_j = A\left(\sum_{i=1}^d x_i w_{i,j} + b_j\right) \end{equation}

where $A(\cdot)$ is the activation function. Based on previous experience, we try the following transformations:

\begin{equation} \tilde{w}_{i,j} = \frac{1}{2} w_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}_j = \frac{1}{\sqrt{2}} b_{\lceil j/2 \rceil} \label{eq:linear-trans} \end{equation}

That is, $b_j$ is transformed according to Formula \eqref{eq:repeat-sqrt2}, while for $w_{i,j}$, we attempt the following transformation:

\begin{equation} \begin{pmatrix} w_{1,1} & w_{1,2} & \cdots & w_{1,D} \\ w_{2,1} & w_{2,2} & \cdots & w_{2,D} \\ \vdots & \vdots & \ddots & \vdots \\ w_{d,1} & w_{d,2} & \cdots & w_{d,D} \end{pmatrix} \to \frac{1}{2} \begin{pmatrix} w_{1,1} & w_{1,1} & w_{1,2} & w_{1,2} & \cdots & w_{1,D} & w_{1,D} \\ w_{1,1} & w_{1,1} & w_{1,2} & w_{1,2} & \cdots & w_{1,D} & w_{1,D} \\ w_{2,1} & w_{2,1} & w_{2,2} & w_{2,2} & \cdots & w_{2,D} & w_{2,D} \\ w_{2,1} & w_{2,1} & w_{2,2} & w_{2,2} & \cdots & w_{2,D} & w_{2,D} \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ w_{d,1} & w_{d,1} & w_{d,2} & w_{d,2} & \cdots & w_{d,D} & w_{d,D} \\ w_{d,1} & w_{d,1} & w_{d,2} & w_{d,2} & \cdots & w_{d,D} & w_{d,D} \end{pmatrix} \label{eq:matrix-trans} \end{equation}

Here $D$ is the size of the output dimension; we assume $D$ also doubles after the model is enlarged by 2 times. It is easy to see that this transformation is actually performing transformation \eqref{eq:repeat-sqrt2} on both the rows and columns of the weight matrix $w_{i,j}$. At this time,

\begin{equation} \sum_{i=1}^{2d} \tilde{x}_i \tilde{w}_{i,j} + \tilde{b}_j = \sum_{i=1}^{2d} \frac{x_{\lceil i/2 \rceil}}{\sqrt{2}} \frac{w_{\lceil i/2 \rceil, \lceil j/2 \rceil}}{2} + \frac{b_{\lceil j/2 \rceil}}{\sqrt{2}} = \frac{1}{\sqrt{2}} \left(\sum_{i=1}^d x_i w_{i, \lceil j/2 \rceil} + b_{\lceil j/2 \rceil}\right) \end{equation}

This shows that transformation \eqref{eq:matrix-trans} satisfies our ideal goal for the linear transformation layer—the enlarged result is "repeat and divide by $\sqrt{2}$". However, this is not enough because the fully connected layer also has an activation function $A(\cdot)$. The problem now is that $A(x/\sqrt{2})$ does not necessarily equal $A(x)/\sqrt{2}$. If they are not equal, we cannot make the overall transformation equivalent to "repeat and divide by $\sqrt{2}$".

In fact, the GeLU activation function used by BERT does not satisfy this identity. Linear activation functions (no activation) obviously satisfy it, and a common non-linear activation function that satisfies this is ReLU (including LeakyReLU). Therefore, a direct solution is to switch the FeedForward layer to the ReLU activation function. In fact, this is already a common choice for pre-trained models; Baidu's Ernie and Google's T5 models both use ReLU for their FeedForward activation functions.
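
For the ReLU case, the following sketch (helper names are illustrative only) checks numerically that transforming a single fully connected layer as above reproduces the "repeat and divide by $\sqrt{2}$" output, precisely because $\text{ReLU}(z/\sqrt{2}) = \text{ReLU}(z)/\sqrt{2}$:

```python
import numpy as np

rep   = lambda v: np.repeat(v, 2, axis=-1) / np.sqrt(2)           # "repeat and divide by sqrt(2)"
rep_w = lambda w: np.repeat(np.repeat(w, 2, axis=0), 2, axis=1)   # repeat rows and columns
relu  = lambda z: np.maximum(z, 0)

d, D = 4, 6
x, w, b = np.random.randn(d), np.random.randn(d, D), np.random.randn(D)

# enlarge the weights: repeat w along both axes and divide by 2,
# repeat b and divide by sqrt(2)
w_tilde = rep_w(w) / 2
b_tilde = rep(b)

y       = relu(x @ w + b)
y_tilde = relu(rep(x) @ w_tilde + b_tilde)
assert np.allclose(y_tilde, rep(y))   # holds because ReLU commutes with positive scaling
```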

So are FeedForward layers with non-ReLU activations, such as BERT's, out of luck? Not quite. Since the FeedForward layer is a composition of two fully connected layers, we can simply divide by one fewer factor of $\sqrt{2}$ when transforming the first fully connected layer, and by one extra factor of $\sqrt{2}$ when transforming the second. Specifically, the weights of the first fully connected layer become:

\begin{equation} \tilde{w}_{i,j} = \frac{1}{\sqrt{2}} w_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}_j = b_{\lceil j/2 \rceil} \label{eq:ffn1-trans} \end{equation}

In this case,

\begin{equation} A\left(\sum_{i=1}^{2d} \tilde{x}_i \tilde{w}_{i,j} + \tilde{b}_j\right) = A\left(\sum_{i=1}^d x_i w_{i, \lceil j/2 \rceil} + b_{\lceil j/2 \rceil} \right) \end{equation}

The result is a direct repetition of the original result without dividing by $\sqrt{2}$. Since this is the case, the subsequent fully connected layer should be divided by an additional factor of $\sqrt{2}$. That is, the transformation for the second fully connected layer's weights is:

\begin{equation} \tilde{w}_{i,j} = \frac{1}{2\sqrt{2}} w_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}_j = \frac{1}{\sqrt{2}} b_{\lceil j/2 \rceil} \label{eq:ffn2-trans} \end{equation}

Thus, the effect of the entire FeedForward layer is equivalent to "repeat and divide by $\sqrt{2}$".
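
A small numerical sketch of this two-step scheme (using the common tanh approximation of GeLU; the helper names are illustrative) confirms that the whole FeedForward block again comes out as "repeat and divide by $\sqrt{2}$":

```python
import numpy as np

def gelu(x):
    # tanh approximation of GeLU, as commonly used in BERT implementations
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

rep   = lambda v: np.repeat(v, 2, axis=-1) / np.sqrt(2)
rep_w = lambda w: np.repeat(np.repeat(w, 2, axis=0), 2, axis=1)

d, D = 4, 16
x = np.random.randn(d)
w1, b1 = np.random.randn(d, D), np.random.randn(D)
w2, b2 = np.random.randn(D, d), np.random.randn(d)

# first FC: divide by one fewer sqrt(2); its pre-activation is then a pure repetition
w1_t, b1_t = rep_w(w1) / np.sqrt(2), np.repeat(b1, 2)
# second FC: divide the kernel by one extra sqrt(2) to restore "repeat and divide by sqrt(2)"
w2_t, b2_t = rep_w(w2) / (2 * np.sqrt(2)), np.repeat(b2, 2) / np.sqrt(2)

y       = gelu(x @ w1 + b1) @ w2 + b2
y_tilde = gelu(rep(x) @ w1_t + b1_t) @ w2_t + b2_t
assert np.allclose(y_tilde, rep(y))
```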

Attention

Now we come to the toughest nut to crack: the transformation of the Attention layer. The Attention layer first transforms each input vector into $q, k, v$ through three linear layers:

\begin{equation} q_j = \sum_{i=1}^d x_i w^{(q)}_{i,j} + b^{(q)}_j, \quad k_j = \sum_{i=1}^d x_i w^{(k)}_{i,j} + b^{(k)}_j, \quad v_j = \sum_{i=1}^d x_i w^{(v)}_{i,j} + b^{(v)}_j \end{equation}

According to the previous analysis of the FeedForward layer, if we want $q, k, v$ to all achieve the effect of "repeat and divide by $\sqrt{2}$", we only need to follow transformation \eqref{eq:matrix-trans}. But the Attention layer is not a simple fully connected layer. After the transformation, we need to check if the Attention matrix remains unchanged. Let's calculate the inner product:

\begin{equation} \sum_{i=1}^{2d'} \tilde{q}_i \tilde{k}_i = \sum_{i=1}^{2d'} \frac{q_{\lceil i/2 \rceil}}{\sqrt{2}} \frac{k_{\lceil i/2 \rceil}}{\sqrt{2}} = \sum_{i=1}^{d'} q_i k_i \end{equation}

where $d'$ is the corresponding head_size. This result tells us that the above transformation keeps the inner product unchanged, so it should also keep the Attention matrix unchanged. But there is a trap! For a model like T5, there is no scaling after the inner product, so this would indeed be fine. However, for a model like BERT, the inner product is divided by $\sqrt{d'}$ before applying Softmax. Once the model is enlarged, the divisor $\sqrt{d'}$ becomes $\sqrt{2d'}$. Keeping the inner product constant no longer preserves the Attention matrix. Therefore, we also need to multiply the weights of $q$ and $k$ by $\sqrt[4]{2}$. So the final transformation should be:

\begin{align} \tilde{w}^{(q)}_{i,j} = \frac{\sqrt[4]{2}}{2} w^{(q)}_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}^{(q)}_j = \frac{\sqrt[4]{2}}{\sqrt{2}} b^{(q)}_{\lceil j/2 \rceil} \nonumber\\ \tilde{w}^{(k)}_{i,j} = \frac{\sqrt[4]{2}}{2} w^{(k)}_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}^{(k)}_j = \frac{\sqrt[4]{2}}{\sqrt{2}} b^{(k)}_{\lceil j/2 \rceil} \label{eq:qkv-trans}\\ \tilde{w}^{(v)}_{i,j} = \frac{1}{2} w^{(v)}_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}^{(v)}_j = \frac{1}{\sqrt{2}} b^{(v)}_{\lceil j/2 \rceil} \nonumber \end{align}

After this transformation, the Attention matrix remains unchanged, and $\tilde{v}_i = v_{\lceil i/2 \rceil} / \sqrt{2}$, so the final output result also follows $\tilde{o}_i = o_{\lceil i/2 \rceil} / \sqrt{2}$.
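
The following single-head sketch (biases included; the helper names are illustrative and the multi-head reshape and output projection are omitted) checks both claims numerically: the Attention matrix is unchanged and the output follows "repeat and divide by $\sqrt{2}$":

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rep   = lambda v: np.repeat(v, 2, axis=-1) / np.sqrt(2)           # "repeat and divide by sqrt(2)"
rep_w = lambda w: np.repeat(np.repeat(w, 2, axis=0), 2, axis=1)   # repeat rows and columns

n, d, dh = 5, 8, 4            # sequence length, hidden size, head_size
X = np.random.randn(n, d)
wq, wk, wv = (np.random.randn(d, dh) for _ in range(3))
bq, bk, bv = (np.random.randn(dh) for _ in range(3))

def head(X, wq, bq, wk, bk, wv, bv, dh):
    Q, K, V = X @ wq + bq, X @ wk + bk, X @ wv + bv
    A = softmax(Q @ K.T / np.sqrt(dh))      # BERT-style scaling by sqrt(head_size)
    return A, A @ V

# q, k weights and biases carry the extra 2**(1/4) to compensate for sqrt(2 * dh)
wq_t, bq_t = rep_w(wq) / 2 * 2**0.25, rep(bq) * 2**0.25
wk_t, bk_t = rep_w(wk) / 2 * 2**0.25, rep(bk) * 2**0.25
wv_t, bv_t = rep_w(wv) / 2,           rep(bv)

A1, O1 = head(X, wq, bq, wk, bk, wv, bv, dh)
A2, O2 = head(rep(X), wq_t, bq_t, wk_t, bk_t, wv_t, bv_t, 2 * dh)
assert np.allclose(A1, A2)        # Attention matrix unchanged
assert np.allclose(O2, rep(O1))   # output is "repeat and divide by sqrt(2)"
```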

The content above is only an analysis for a single Attention head. In reality, Attention has multiple heads, whose outputs are concatenated and then passed through another fully connected layer. Since each head is transformed identically and independently, the above conclusion remains basically the same, and the final fully connected layer only needs to be transformed according to Formula \eqref{eq:matrix-trans} to achieve the intended transformation effect for the whole Attention layer. However, one consequence of the multi-head structure is that the repetition must be done locally, head by head.

Specifically, when implementing multi-head, we don't actually perform multiple fully connected operations. Instead, we perform one large fully connected operation and then reshape. Consequently, we can compare the results of two different repetition methods after reshaping:

\begin{align*} [x_1, x_2, x_3, x_4, x_5, x_6] &\xrightarrow{\text{Method 1}} [x_1, x_1, x_2, x_2, x_3, x_3, x_4, x_4, x_5, x_5, x_6, x_6] \xrightarrow{\text{Reshape}} \begin{pmatrix} x_1, x_1, x_2, x_2 \\ x_3, x_3, x_4, x_4 \\ x_5, x_5, x_6, x_6 \end{pmatrix} \\ [x_1, x_2, x_3, x_4, x_5, x_6] &\xrightarrow{\text{Method 2}} [x_1, x_2, x_3, x_4, x_5, x_6, x_1, x_2, x_3, x_4, x_5, x_6] \xrightarrow{\text{Reshape}} \begin{pmatrix} x_1, x_2, x_3, x_4 \\ x_5, x_6, x_1, x_2 \\ x_3, x_4, x_5, x_6 \end{pmatrix} \end{align*}

Note that before enlargement, reshaping gives $(x_1, x_2; x_3, x_4; x_5, x_6)$. Comparing the reshaped results of the two repetition methods, we see that the second method completely scrambles the heads and is not equivalent to repeating each head individually. Therefore, we must choose the first repetition method.
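
The difference between the two repetition methods is easy to see in code; in the sketch below, `np.repeat` corresponds to the local (element-wise) repetition of Method 1 and `np.tile` to the whole-vector repetition of Method 2:

```python
import numpy as np

x = np.arange(1, 7)                  # [x1, ..., x6], e.g. 3 heads of size 2

m1 = np.repeat(x, 2).reshape(3, 4)   # Method 1: local repetition, heads stay intact
m2 = np.tile(x, 2).reshape(3, 4)     # Method 2: repeat the whole vector, heads get scrambled

print(m1)   # [[1 1 2 2] [3 3 4 4] [5 5 6 6]]  -> each original head, repeated
print(m2)   # [[1 2 3 4] [5 6 1 2] [3 4 5 6]]  -> no longer a head-wise repetition
```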

Output Probability Distribution

Through the above analysis, we can make the entire Encoder achieve the "repeat and divide by $\sqrt{2}$" effect after being enlarged by a factor of 2. Finally, what remains is the output part, which transforms the Encoder's output vector into a probability distribution over tokens. There are several cases here.

Models like GPT and T5 directly multiply the Encoder output by the transpose of the Embedding matrix as the logits for the probability distribution (possibly with an added bias). Since the Embedding matrix itself already contains the "repeat and divide by $\sqrt{2}$" operation and the Encoder's output is also "repeat and divide by $\sqrt{2}$", the two factors combined cancel out exactly. Therefore, from the perspective of the probability distribution, the output is completely unchanged.
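
A minimal sketch of this cancellation for the tied-embedding case (GPT/T5-style; the bias on the logits is omitted for brevity):

```python
import numpy as np

rep = lambda v: np.repeat(v, 2, axis=-1) / np.sqrt(2)   # "repeat and divide by sqrt(2)"

vocab, d = 10, 8
E = np.random.randn(vocab, d)        # Embedding matrix
h = np.random.randn(d)               # Encoder output at one position

E_tilde = rep(E)                     # each embedding row enlarged by "repeat / sqrt(2)"
h_tilde = rep(h)                     # Encoder output is also "repeat / sqrt(2)"

logits       = E @ h                 # tied output projection: E h
logits_tilde = E_tilde @ h_tilde
assert np.allclose(logits, logits_tilde)   # the two 1/sqrt(2) factors cancel exactly
```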

However, BERT adds another fully connected layer. That is, it first connects to a GeLU-activated fully connected layer, then multiplies by the transpose of the Embedding matrix and adds a bias to get the logits. As discussed in the "FeedForward" section, non-ReLU activated fully connected layers cannot achieve the "repeat and divide by $\sqrt{2}$" effect, but can only achieve a pure "repetition" effect through transformation \eqref{eq:ffn1-trans}. So, to achieve the "divide by $\sqrt{2}$" effect again, the LayerNorm that follows it must be transformed with an additional division by $\sqrt{2}$.

Of course, if the activation is ReLU, then transforming according to Formula \eqref{eq:matrix-trans} allows for a completely unchanged result. Additionally, if, like mT5, the transformation matrix for the logits is not shared with the Embedding layer, then the output can also be kept completely unchanged by adjusting the final transformation matrix.

RoPE Position Encoding

The previous analysis applies only to cases where each neuron is independent, meaning there is no inherent correlation between any two components $x_i, x_j$ of the vector. However, if we use "Rotary Positional Embedding (RoPE)" in the model, this assumption no longer holds because RoPE transforms in groups of two components, i.e., $[x_1, x_2]$ is a group, $[x_3, x_4]$ is a group, and so on.

If we still follow the repetition transformation of \eqref{eq:repeat-sqrt2}, it becomes $[x_1, x_1]$ as a group, $[x_2, x_2]$ as a group, etc., which is inconsistent with the original grouping and will bring significant bias. In this case, the repetition should also be done in groups of two:

\begin{equation} [x_1, x_2, x_3, x_4, \dots, x_{d-1}, x_d] \xrightarrow{\frac{1}{\sqrt{2}}} [x_1, x_2, x_1, x_2, x_3, x_4, x_3, x_4, \dots, x_{d-1}, x_d, x_{d-1}, x_d] \label{eq:rope-repeat} \end{equation}

Of course, since the default RoPE has no trainable weights and its rotation angles follow a fixed, gradually varying pattern across dimensions, even with this grouped repetition, exact consistency cannot be guaranteed. In other words, if RoPE is used, a truly lossless enlargement is essentially impossible. However, actual test results show that after repeating and enlarging in this way, the performance loss of the corresponding RoFormer is small and can be quickly recovered through continued training.
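
A sketch of this pairwise repetition (the helper name `repeat_pairwise` is illustrative):

```python
import numpy as np

def repeat_pairwise(x, k=2):
    """Repeat in groups of two components (for RoPE) and divide by sqrt(k):
    [x1, x2, x3, x4] -> [x1, x2, x1, x2, x3, x4, x3, x4] / sqrt(2) for k = 2."""
    pairs = x.reshape(-1, 2)                       # group components two by two
    return np.repeat(pairs, k, axis=0).reshape(-1) / np.sqrt(k)

x = np.arange(1.0, 5.0)                  # [1, 2, 3, 4]
print(repeat_pairwise(x) * np.sqrt(2))   # [1. 2. 1. 2. 3. 4. 3. 4.]
```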

Conclusion

We can now confirm that for BERT, if the non-linear activation function is ReLU, then BERT can be directly losslessly enlarged. If the non-linear activation function is not ReLU, a transformation can achieve an enlargement with lossless MLM accuracy (in fact, through more delicate adjustments, completely lossless enlargement can also be achieved, but the transformation for each layer would be inconsistent and less elegant). For models like GPT and T5, regardless of the activation function used (including the GLU activation used by mT5, which can be appropriately customized), lossless enlargement is achievable.

Among these, the transformations to enlarge BERT weights by a factor of 2 are summarized as follows:

Embedding: $\tilde{x}_i = \frac{1}{\sqrt{2}} x_{\lceil i/2 \rceil}$
LayerNorm: $\tilde{\beta}_i = \frac{1}{\sqrt{2}} \beta_{\lceil i/2 \rceil}, \quad \tilde{\gamma}_i = \frac{1}{\sqrt{2}} \gamma_{\lceil i/2 \rceil}$
Attention: $\tilde{w}^{(q)}_{i,j} = \frac{\sqrt[4]{2}}{2} w^{(q)}_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}^{(q)}_j = \frac{\sqrt[4]{2}}{\sqrt{2}} b^{(q)}_{\lceil j/2 \rceil}$
$\tilde{w}^{(k)}_{i,j} = \frac{\sqrt[4]{2}}{2} w^{(k)}_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}^{(k)}_j = \frac{\sqrt[4]{2}}{\sqrt{2}} b^{(k)}_{\lceil j/2 \rceil}$
$\tilde{w}^{(v)}_{i,j} = \frac{1}{2} w^{(v)}_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}^{(v)}_j = \frac{1}{\sqrt{2}} b^{(v)}_{\lceil j/2 \rceil}$
$\tilde{w}^{(o)}_{i,j} = \frac{1}{2} w^{(o)}_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}^{(o)}_j = \frac{1}{\sqrt{2}} b^{(o)}_{\lceil j/2 \rceil}$
FeedForward: $\tilde{w}^{(1)}_{i,j} = \frac{1}{\sqrt{2}} w^{(1)}_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}^{(1)}_j = b^{(1)}_{\lceil j/2 \rceil}$
$\tilde{w}^{(2)}_{i,j} = \frac{1}{2\sqrt{2}} w^{(2)}_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}^{(2)}_j = \frac{1}{\sqrt{2}} b^{(2)}_{\lceil j/2 \rceil}$
Output Logits: $\tilde{w}_{i,j} = \frac{1}{\sqrt{2}} w_{\lceil i/2 \rceil, \lceil j/2 \rceil}, \quad \tilde{b}_j = b_{\lceil j/2 \rceil}$

If it is another slightly different model, just follow the same logic as the previous analysis. If it is RoPE, change the repetition scheme to Formula \eqref{eq:rope-repeat}. If enlarging by $k$ times, replace the number 2 in most parts of the table with $k$. Simply put, if the Attention has no scaling (division by $\sqrt{d'}$) and the FeedForward activation is ReLU (or LeakyReLU), then the transformation for enlarging by $k$ times is the simplest: apply "repeat $k$ times and divide by $\sqrt{k}$" to every dimension of the weights.
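
For that simplest case, the whole recipe reduces to one helper, sketched below under exactly those assumptions (no Attention scaling, ReLU or LeakyReLU FeedForward activation; the function name `enlarge` is illustrative): apply "repeat $k$ times and divide by $\sqrt{k}$" along every axis of a weight array.

```python
import numpy as np

def enlarge(w, k=2):
    """'Repeat k times and divide by sqrt(k)' along every axis of a weight array."""
    for axis in range(w.ndim):
        w = np.repeat(w, k, axis=axis) / np.sqrt(k)
    return w

# a (d, D) kernel becomes (k*d, k*D) and is divided by sqrt(k) per repeated axis,
# i.e. by k overall; a bias of shape (D,) becomes (k*D,) divided by sqrt(k)
kernel, bias = np.random.randn(4, 6), np.random.randn(6)
print(enlarge(kernel).shape, enlarge(bias).shape)   # (8, 12) (12,)
```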

Summary

This article analyzed, from a mathematical perspective, the possibility of directly enlarging a Transformer model. Several usable transformations were obtained, confirming the feasibility of lossless enlargement for Transformer models and providing a reference for progressively growing models during training.