Re-exploring Shared Embeddings at the Output of Language Models

By 苏剑林 | July 20, 2023

In the early days of pre-training, it was common practice to reuse the Embedding weights at the output of a language model; BERT, the first version of T5, and early versions of GPT all did this. The reason was that when the model backbone is small and the vocabulary is large, the Embedding layer accounts for a significant share of the parameters, and adding an independent weight matrix of the same size at the output would sharply increase VRAM consumption. As models have grown, however, the Embedding layer's share of the parameters has become relatively small. Furthermore, studies such as "Rethinking embedding coupling in pre-trained language models" have suggested that sharing Embeddings may have some negative effects. Consequently, the practice of sharing Embeddings has become increasingly rare.

The purpose of this article is to analyze the problems that may arise when sharing Embedding weights and to explore how to initialize and parameterize them more effectively. Although shared Embeddings may seem "outdated," they remain an interesting research topic.

Shared Weights

The practice of reusing the Embedding weights at the output of a language model is called "Tied Embeddings" or "Coupled Embeddings." The idea is that the Embedding matrix and the output projection matrix that maps hidden states to logits have the same shape (up to a transpose). Since this matrix is quite large, the same weights are shared directly to avoid unnecessary waste, as shown in the diagram below:

Figure: Schematic of a Transformer with shared Embedding weights

The most direct consequence of sharing Embeddings is that the initial loss during pre-training can be very large. This is because we typically use techniques such as DeepNorm to ease training, which initialize the model's residual branches close to zero. In other words, the model is approximately an identity function at the start of training, so the initial model is equivalent to a 2-gram model with shared Embeddings. Below, we derive why such a 2-gram model has a high loss and analyze some solutions.

Preparations

Before formally starting the derivation, we need to prepare some basic conclusions.

First, note that we mainly analyze the initial stage of training. At this point, the weights are sampled independently and identically (i.i.d.) from a distribution with mean 0 and variance $\sigma^2$, which lets us estimate certain sums by their expected values. For example, for a vector $\boldsymbol{w}=(w_1,w_2,\cdots,w_d)$, we have:

\begin{equation}\mathbb{E}\left[\Vert \boldsymbol{w}\Vert^2\right] = \mathbb{E}\left[\sum_i w_i^2\right] = \sum_i \mathbb{E}\left[w_i^2\right] = d\sigma^2\label{eq:norm}\end{equation}

Thus, we can take $\Vert \boldsymbol{w}\Vert\approx \sqrt{d}\sigma$. How large is the error? We can gauge it through the variance of $\Vert \boldsymbol{w}\Vert^2$, for which we first need its second moment:

\[\begin{aligned}\mathbb{E}\left[\Vert \boldsymbol{w}\Vert^4\right] =&\, \mathbb{E}\left[\left(\sum_i w_i^2\right)^2\right] = \mathbb{E}\left[\sum_i w_i^4 + \sum_{i,j|i\neq j} w_i^2 w_j^2\right] \\ =&\, \sum_i \mathbb{E}\left[w_i^4\right] + \sum_{i,j|i\neq j} \mathbb{E}\left[w_i^2\right] \mathbb{E}\left[w_j^2\right] \\ =&\, d\,\mathbb{E}\left[w^4\right] + d(d-1) \sigma^4 \\ \end{aligned}\]

If the sampling distribution is a normal distribution, we can directly calculate $\mathbb{E}\left[w^4\right]=3\sigma^4$, so:

\begin{equation}\mathbb{V}ar\left[\Vert \boldsymbol{w}\Vert^2\right] = \mathbb{E}\left[\Vert \boldsymbol{w}\Vert^4\right] - \mathbb{E}\left[\Vert \boldsymbol{w}\Vert^2\right]^2 = 2d\sigma^4\end{equation}

This variance measures how accurate the approximation $\Vert \boldsymbol{w}\Vert\approx \sqrt{d}\sigma$ is: the smaller it is, the better the approximation. In absolute terms it scales with $\sigma^4$, but relative to the squared mean $(d\sigma^2)^2$ it equals $2/d$ for any $\sigma$, so what really tightens the approximation is a larger dimension. In particular, a common sampling variance is $1/d$ (corresponding to $\Vert \boldsymbol{w}\Vert\approx 1$, i.e., approximately a unit vector); substituting it into the formula above gives a variance of $2/d$, so the higher the dimension, the better the approximation. If the sampling distribution is not normal, one can recompute $\mathbb{E}\left[w^4\right]$ for that distribution or simply use the normal result as a reference; after all, this is only an estimate.
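As a quick sanity check of the two formulas above, the following sketch estimates the mean and variance of $\Vert \boldsymbol{w}\Vert^2$ by Monte Carlo. It assumes normal sampling (the case in which $\mathbb{E}[w^4]=3\sigma^4$ was used), $\sigma^2=1/d$, and an illustrative $d=256$; the function name `norm_sq_stats` is hypothetical, and only the Python standard library is used.

```python
import math
import random

def norm_sq_stats(d, sigma, trials, seed=0):
    """Monte Carlo mean and (unbiased) variance of ||w||^2 for w with i.i.d. N(0, sigma^2) entries."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        w = [rng.gauss(0.0, sigma) for _ in range(d)]
        samples.append(sum(x * x for x in w))
    mean = sum(samples) / trials
    var = sum((s - mean) ** 2 for s in samples) / (trials - 1)
    return mean, var

d = 256
sigma = 1 / math.sqrt(d)  # the common choice sigma^2 = 1/d
mean, var = norm_sq_stats(d, sigma, trials=2000)
print(f"E[||w||^2]   ~ {mean:.4f}   (theory d*sigma^2   = {d * sigma**2:.4f})")
print(f"Var[||w||^2] ~ {var:.5f}  (theory 2d*sigma^4 = {2 * d * sigma**4:.5f})")
```

With $\sigma^2=1/d$ the two theoretical values are $1$ and $2/d\approx 0.0078$; the Monte Carlo estimates should land close to them, up to sampling noise.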

If $\boldsymbol{v}=(v_1,v_2,\cdots,v_d)$ is another vector sampled independently from the same distribution, we can estimate the inner product in the same way. The result is:

\begin{equation}\mathbb{E}\left[\boldsymbol{w}\cdot\boldsymbol{v}\right] = \mathbb{E}\left[\sum_i w_i v_i\right] = \sum_i \mathbb{E}\left[w_i\right] \mathbb{E}\left[v_i\right] = 0\label{eq:dot}\end{equation}

And:

\[\begin{aligned}\mathbb{E}\left[(\boldsymbol{w}\cdot\boldsymbol{v})^2\right] =&\, \mathbb{E}\left[\left(\sum_i w_i v_i\right)^2\right] = \mathbb{E}\left[\sum_i w_i^2 v_i^2 + \sum_{i,j|i\neq j} w_i v_i w_j v_j\right] \\ =&\, \sum_i \mathbb{E}\left[w_i^2\right]\mathbb{E}\left[v_i^2\right] + \sum_{i,j|i\neq j} \mathbb{E}\left[w_i\right]\mathbb{E}\left[v_i\right]\mathbb{E}\left[w_j\right]\mathbb{E}\left[v_j\right] \\ =&\, d \sigma^4 \\ \end{aligned}\]

Since the mean is 0, this second moment is also the variance of $\boldsymbol{w}\cdot\boldsymbol{v}$. Similarly, taking $\sigma^2=1/d$, the variance is $d\sigma^4=1/d$, so the approximation $\boldsymbol{w}\cdot\boldsymbol{v}\approx 0$ becomes more accurate as the dimension grows. These two results can be considered statistical versions of the conclusions in "The distribution of the angle between two random vectors in n-dimensional space" and "The Amazing Johnson-Lindenstrauss Lemma: Theoretical Part".
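The inner-product estimates can be checked the same way. This sketch draws independent Gaussian pairs $\boldsymbol{w},\boldsymbol{v}$ with $\sigma^2=1/d$ (normal sampling and $d=256$ are illustrative assumptions) and compares the empirical mean and second moment of $\boldsymbol{w}\cdot\boldsymbol{v}$ with $0$ and $d\sigma^4$; `dot_stats` is a hypothetical helper name.

```python
import math
import random

def dot_stats(d, sigma, trials, seed=1):
    """Monte Carlo mean and second moment of w.v for independent w, v with i.i.d. N(0, sigma^2) entries."""
    rng = random.Random(seed)
    dots = []
    for _ in range(trials):
        w = [rng.gauss(0.0, sigma) for _ in range(d)]
        v = [rng.gauss(0.0, sigma) for _ in range(d)]
        dots.append(sum(a * b for a, b in zip(w, v)))
    mean = sum(dots) / trials
    second_moment = sum(x * x for x in dots) / trials
    return mean, second_moment

d = 256
sigma = 1 / math.sqrt(d)  # sigma^2 = 1/d, so d * sigma^4 = 1/d
mean, m2 = dot_stats(d, sigma, trials=2000)
print(f"E[w.v]     ~ {mean:+.4f}  (theory 0)")
print(f"E[(w.v)^2] ~ {m2:.5f}  (theory d*sigma^4 = {d * sigma**4:.5f})")
```

Compared with $\Vert\boldsymbol{w}\Vert^2\approx 1$, the typical size of $\boldsymbol{w}\cdot\boldsymbol{v}$ is only about $1/\sqrt{d}$, which is what the loss analysis below relies on.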

Loss Analysis

For a language model, the goal is to output a distribution over $n$ tokens at each position, where $n$ is the vocabulary size. If we simply output the uniform distribution, assigning probability $1/n$ to every token, the cross-entropy loss is $\log n$ whatever the target. A reasonable initialization should therefore not produce an initial loss much larger than $\log n$, since $\log n$ is the loss of the simplest baseline, the uniform distribution. A loss far above $\log n$ means the model is much worse than uniform guessing, as if it were making mistakes on purpose, which is unreasonable.

So why does this happen with shared Embeddings? Suppose the initial Embeddings are $\{\boldsymbol{w}_1,\boldsymbol{w}_2,\cdots,\boldsymbol{w}_n\}$. As mentioned above, the residual branches are close to zero at initialization, so if the input token is $i$, the model output is simply the Embedding $\boldsymbol{w}_i$ after Normalization. Common Normalization methods are Layer Norm and RMS Norm; since the initialization distribution has zero mean, the two are roughly equivalent, and the output is:

\begin{equation}\frac{\boldsymbol{w}_i}{\Vert\boldsymbol{w}_i\Vert \big/\sqrt{d}} \approx \frac{\boldsymbol{w}_i}{\sigma}\end{equation}
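The claim that Layer Norm and RMS Norm nearly coincide for a zero-mean initialization, and that the normalizer $\Vert\boldsymbol{w}_i\Vert/\sqrt{d}$ is close to $\sigma$, can be checked numerically. This sketch uses parameter-free versions of both norms (no learnable gain or bias, no epsilon), which is a simplification; $d=1024$ and $\sigma=0.02$ are arbitrary illustrative choices.

```python
import math
import random

def layer_norm(x):
    """Layer Norm without learnable gain/bias: (x - mean) / std, using the population std."""
    mu = sum(x) / len(x)
    std = math.sqrt(sum((a - mu) ** 2 for a in x) / len(x))
    return [(a - mu) / std for a in x]

def rms_norm(x):
    """RMS Norm without learnable gain: x / sqrt(mean(x_i^2)) = x / (||x|| / sqrt(d))."""
    rms = math.sqrt(sum(a * a for a in x) / len(x))
    return [a / rms for a in x]

d = 1024
sigma = 0.02  # illustrative; the comparison does not depend on the scale
rng = random.Random(0)
w = [rng.gauss(0.0, sigma) for _ in range(d)]

rms = math.sqrt(sum(a * a for a in w) / d)  # ||w|| / sqrt(d), expected to be close to sigma
max_gap = max(abs(a - b) for a, b in zip(layer_norm(w), rms_norm(w)))
print(f"||w||/sqrt(d) = {rms:.5f} vs sigma = {sigma}")
print(f"max |LayerNorm(w) - RMSNorm(w)| = {max_gap:.3f}")
```

The gap between the two norms comes mainly from the sample mean of $\boldsymbol{w}$, which is of order $\sigma/\sqrt{d}$ and therefore negligible after normalization when $d$ is large.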

Next, the same Embedding matrix is reused at the output: we take the inner product with every Embedding and apply Softmax. The resulting distribution is essentially:

\begin{equation}p(j|i) = \frac{e^{\boldsymbol{w}_i\cdot \boldsymbol{w}_j / \sigma}}{\sum\limits_k e^{\boldsymbol{w}_i\cdot \boldsymbol{w}_k / \sigma}}\end{equation}

The corresponding loss function is:

\begin{equation}-\log p(j|i) = \log \sum\limits_k e^{\boldsymbol{w}_i\cdot \boldsymbol{w}_k / \sigma} - \boldsymbol{w}_i\cdot \boldsymbol{w}_j \big/ \sigma\end{equation}

The language modeling task is to predict the next token, and in natural text the next token rarely equals the current one, so we can essentially assume $j\neq i$. By \eqref{eq:dot}, $\boldsymbol{w}_i\cdot \boldsymbol{w}_j\approx 0$, so the initial loss is:

\begin{equation}-\log p(j|i) \approx \log \sum_k e^{\boldsymbol{w}_i\cdot \boldsymbol{w}_k / \sigma}=\log \left(e^{\boldsymbol{w}_i\cdot \boldsymbol{w}_i / \sigma} + \sum\limits_{k|k\neq i} e^{\boldsymbol{w}_i\cdot \boldsymbol{w}_k / \sigma}\right)\approx\log \left(e^{d \sigma} + (n-1)\right)\label{eq:loss}\end{equation}

The second $\approx$ again uses \eqref{eq:norm} and \eqref{eq:dot}. Common choices of the initialization variance $\sigma^2$ are either a constant or $1/d$ (in which case $e^{d \sigma}=e^{\sqrt{d}}$); either way, when $d$ is large, $e^{d \sigma}$ dominates. The loss is then on the order of $\log e^{d\sigma}=d\sigma$, which can easily exceed the uniform baseline $\log n$. For instance, with $d=4096$ and $\sigma^2=1/d$ we get $d\sigma=64$, whereas $\log n\approx 10.8$ for a vocabulary of $n=50000$.
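To see the effect end to end, the sketch below builds the initial "identity backbone" model directly: token $i$ is mapped to $\boldsymbol{w}_i$, passed through RMS Norm, and scored against the same Embedding matrix; the average $-\log p(j|i)$ over random pairs with $j\neq i$ is then compared with the estimate $\log(e^{d\sigma}+n-1)$ and with $\log n$. The sizes $n=500$, $d=128$ are deliberately tiny so that it runs in pure Python, targets are drawn uniformly rather than from real text, and all function names are hypothetical.

```python
import math
import random

def rms_norm(x):
    """RMS Norm without a learnable gain: x / (||x|| / sqrt(d))."""
    rms = math.sqrt(sum(a * a for a in x) / len(x))
    return [a / rms for a in x]

def logsumexp(xs):
    """Numerically stable log(sum(exp(x)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def tied_bigram_loss(n, d, sigma, pairs, seed=0):
    """Average -log p(j|i) of the initial model with an identity backbone and tied embeddings.

    Token i -> embedding w_i -> RMS Norm -> logits against the SAME embedding matrix.
    Targets j are drawn uniformly with j != i (a stand-in for real next tokens).
    """
    rng = random.Random(seed)
    emb = [[rng.gauss(0.0, sigma) for _ in range(d)] for _ in range(n)]
    total = 0.0
    for _ in range(pairs):
        i = rng.randrange(n)
        j = rng.randrange(n - 1)
        j = j if j < i else j + 1  # uniform over all j != i
        h = rms_norm(emb[i])
        logits = [sum(a * b for a, b in zip(h, w_k)) for w_k in emb]
        total += logsumexp(logits) - logits[j]
    return total / pairs

n, d = 500, 128
sigma = 1 / math.sqrt(d)
simulated = tied_bigram_loss(n, d, sigma, pairs=40)
estimate = math.log(math.exp(d * sigma) + (n - 1))  # log(e^{d*sigma} + n - 1)
print(f"simulated {simulated:.2f} | estimate {estimate:.2f} | uniform log n {math.log(n):.2f}")
```

With these sizes the estimate is about $11.3$ while $\log n\approx 6.2$, so even this toy model starts far above the uniform baseline; the self term $e^{\boldsymbol{w}_i\cdot\boldsymbol{w}_i/\sigma}$ accounts for almost all of it.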

Some Countermeasures

Based on the derivation above, we can design targeted countermeasures. The most direct one is to adjust the initialization: according to \eqref{eq:loss}, it suffices to set $e^{d\sigma}=n$, which brings the initial loss down to the order of $\log n$ (more precisely, about $\log(2n-1)$). This means the initialization standard deviation should be changed to $\sigma=(\log n)/d$.
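The rescaled initialization can be evaluated with plain arithmetic. This sketch computes $\sigma=(\log n)/d$ and the resulting estimate $\log(e^{d\sigma}+n-1)=\log(2n-1)$ for a few vocabulary and hidden sizes, next to the common choice $1/\sqrt{d}$; the $(n, d)$ pairs are illustrative and not taken from any particular model, and `rescaled_init` is a hypothetical name.

```python
import math

def rescaled_init(n, d):
    """sigma chosen so that e^{d*sigma} = n, and the resulting loss estimate log(e^{d*sigma} + n - 1)."""
    sigma = math.log(n) / d
    est_loss = math.log(math.exp(d * sigma) + (n - 1))  # equals log(2n - 1) up to rounding
    return sigma, est_loss

results = []
for n, d in [(32000, 4096), (50000, 1024), (150000, 8192)]:  # illustrative sizes only
    sigma, loss = rescaled_init(n, d)
    results.append((n, d, sigma, loss))
    print(f"n={n:>6} d={d:>5}: sigma={sigma:.5f} vs 1/sqrt(d)={1 / math.sqrt(d):.5f}; "
          f"loss ~ {loss:.2f} vs log n = {math.log(n):.2f}")
```

The estimated loss sits only $\log 2\approx 0.69$ above $\log n$, but for all three pairs the rescaled $\sigma$ is roughly 3 to 8 times smaller than $1/\sqrt{d}$.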

Generally, we want the initialization variance of the parameters to be reasonably large so that gradients are less likely to underflow, and $\sigma=(\log n)/d$ can sometimes be too small. We can therefore consider another approach. The loss in \eqref{eq:loss} is too large because of the term $e^{\boldsymbol{w}_i\cdot \boldsymbol{w}_i / \sigma}$: since both vectors are the same $\boldsymbol{w}_i$, their inner product is the squared norm, which is very large. If the two vectors can be made different, this dominant term disappears.

The simplest method, therefore, is not to share the Embedding at all. The term $e^{\boldsymbol{w}_i\cdot \boldsymbol{w}_i / \sigma}$ in \eqref{eq:loss} then becomes $e^{\boldsymbol{w}_i\cdot \boldsymbol{v}_i / \sigma}$, where $\boldsymbol{v}_i$ is the independently initialized output Embedding of token $i$. Approximating it with \eqref{eq:dot} instead of \eqref{eq:norm}, the loss in \eqref{eq:loss} approaches $\log n$. If we still want to keep shared Embeddings, we can follow the final Normalization with an orthogonally initialized projection layer $\boldsymbol{P}$, which turns $e^{\boldsymbol{w}_i\cdot \boldsymbol{w}_i / \sigma}$ into $e^{(\boldsymbol{w}_i\boldsymbol{P})\cdot \boldsymbol{w}_i / \sigma}$. According to the Johnson-Lindenstrauss Lemma, a randomly projected vector is approximately independent of the original vector, so this approximates the non-shared case. This is in fact BERT's solution. More generally, the projection layer can also include a bias and an activation function.
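The following sketch compares the three self-interaction terms at initialization: $\boldsymbol{w}_i\cdot\boldsymbol{w}_i$ (tied), $\boldsymbol{w}_i\cdot\boldsymbol{v}_i$ (untied, with $\boldsymbol{v}_i$ an independent output Embedding), and $(\boldsymbol{w}_i\boldsymbol{P})\cdot\boldsymbol{w}_i$ (tied plus an orthogonal projection). It assumes $\boldsymbol{P}$ is drawn as a random orthogonal matrix via Gram-Schmidt on a Gaussian matrix, which is one way to realize an orthogonal initialization, and uses $\sigma^2=1/d$ with an illustrative $d=64$; the helper names are hypothetical.

```python
import math
import random

def random_orthogonal(d, rng):
    """Random d x d orthogonal matrix via modified Gram-Schmidt on Gaussian rows."""
    rows = []
    for _ in range(d):
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for q in rows:  # strip the components along rows already accepted
            proj = sum(a * b for a, b in zip(v, q))
            v = [a - proj * b for a, b in zip(v, q)]
        norm = math.sqrt(sum(a * a for a in v))
        rows.append([a / norm for a in v])
    return rows

def row_times_matrix(w, P):
    """(wP)_j = sum_i w_i * P[i][j] for a row vector w."""
    d = len(w)
    return [sum(w[i] * P[i][j] for i in range(d)) for j in range(d)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def mean(xs):
    return sum(xs) / len(xs)

d, trials = 64, 30
sigma = 1 / math.sqrt(d)
rng = random.Random(0)
P = random_orthogonal(d, rng)

tied, untied, projected = [], [], []
for _ in range(trials):
    w = [rng.gauss(0.0, sigma) for _ in range(d)]  # shared (input) embedding row
    v = [rng.gauss(0.0, sigma) for _ in range(d)]  # independent output embedding row (untied case)
    tied.append(dot(w, w))                                 # ~ d * sigma^2 = 1
    untied.append(abs(dot(w, v)))                          # order 1/sqrt(d) in magnitude
    projected.append(abs(dot(row_times_matrix(w, P), w)))  # order 1/sqrt(d) in magnitude

print(f"w.w ~ {mean(tied):.3f} | |w.v| ~ {mean(untied):.3f} | |(wP).w| ~ {mean(projected):.3f}")
```

The tied term should come out near $d\sigma^2=1$, while the other two stay on the order of $1/\sqrt{d}$; after division by $\sigma$, this is the gap between a self logit of about $\sqrt{d}$ and one of order 1.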

If one does not want to introduce any extra parameters at all, one can consider "shuffling" the various dimensions of $\boldsymbol{w}_i$ after Normalization, for example:

\begin{equation}\mathcal{S}[\boldsymbol{w}] = \boldsymbol{w}[d/2:]\circ\boldsymbol{w}[:d/2]\end{equation}

where $\circ$ denotes concatenation. Then $\mathcal{S}[\boldsymbol{w}_i]$ and $\boldsymbol{w}_i$ are nearly orthogonal, so their inner product is close to 0. At initialization, this is equivalent to splitting the original $n\times d$ Embedding matrix into two $n\times (d/2)$ matrices and building a 2-gram model without shared Embeddings. Other shuffling operations can also be used, such as the Reshape-Transpose-Reshape channel shuffle in ShuffleNet.
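The shuffle $\mathcal{S}$ needs no parameters, so it is easy to sketch. Below, `swap_halves` implements $\mathcal{S}[\boldsymbol{w}]=\boldsymbol{w}[d/2:]\circ\boldsymbol{w}[:d/2]$ on a Python list, and `channel_shuffle` is one reading of the ShuffleNet-style Reshape-Transpose-Reshape permutation; both names are hypothetical, and the script only checks the inner products with the original vector under $\sigma^2=1/d$ and an illustrative $d=256$.

```python
import math
import random

def swap_halves(w):
    """S[w] = w[d/2:] concatenated with w[:d/2]; assumes d is even."""
    half = len(w) // 2
    return w[half:] + w[:half]

def channel_shuffle(w, groups):
    """Reshape to (groups, d/groups), transpose, flatten; assumes groups divides d."""
    per = len(w) // groups
    return [w[k * per + j] for j in range(per) for k in range(groups)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def mean(xs):
    return sum(xs) / len(xs)

d, trials = 256, 50
sigma = 1 / math.sqrt(d)
rng = random.Random(0)
self_dots, swap_dots, shuffle_dots = [], [], []
for _ in range(trials):
    w = [rng.gauss(0.0, sigma) for _ in range(d)]  # plays the role of w_i; Normalization only rescales it
    self_dots.append(dot(w, w))                              # ~ d * sigma^2 = 1
    swap_dots.append(abs(dot(swap_halves(w), w)))            # close to 0
    shuffle_dots.append(abs(dot(channel_shuffle(w, 4), w)))  # close to 0, small fixed-point bias

print(f"w.w ~ {mean(self_dots):.3f} | |S[w].w| ~ {mean(swap_dots):.3f} | "
      f"|shuffle(w).w| ~ {mean(shuffle_dots):.3f}")
```

Unlike the half swap, which moves every coordinate, the channel shuffle leaves a few coordinates in place (for $d=256$ and 4 groups there are 4 fixed points, including the first and last), so its inner product with $\boldsymbol{w}$ keeps a small positive bias of roughly $4\sigma^2$.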

In the author's experiments, directly changing the initialization standard deviation to $\sigma=(\log n)/d$ gave the slowest convergence, while the other methods converged at similar speeds. In terms of final performance, all methods appeared roughly equal.

Summary

This article revisited the practice of sharing Embedding weights at the output of a language model, derived why directly reusing the Embeddings for the output projection can lead to an excessively large initial loss, and explored some solutions.