By 苏剑林 | November 20, 2023
Broadly speaking, current Transformer length extrapolation techniques fall into two categories. The first is post-hoc modification, such as NTK-RoPE, YaRN, and ReRoPE. These methods modify the model directly at inference time and achieve some degree of length extrapolation without fine-tuning; their drawback is that none of them keeps the model's behavior within the original training length exactly unchanged. The second is pre-training modification, such as ALIBI, KERPLE, XPOS, and HWFA. These extrapolate without any further modification, but the changes must be introduced before training. Consequently, they cannot be applied to existing models without fine-tuning, and it is not yet widely accepted that they scale up effectively.
In this article, I introduce a length extrapolation scheme that I discovered unexpectedly, "KeyNorm", which applies $L_2$ normalization to the Key sequence in Attention. It clearly belongs to the pre-training modification category, but its change to the Attention mechanism is very small, which makes it look very promising for scaling up.
Initial Motivation
I call it an "unexpected discovery" because the original motivation for this change was not length extrapolation, but an attempt to replace the scaling method in Scaled Dot-Product Attention. Recall that the standard definition of Attention (this article mainly considers the causal case) is:
\begin{equation}\boldsymbol{o}_i = \frac{\sum_{j = 1}^i\exp\left(\frac{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}{\sqrt{d}}\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\frac{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}{\sqrt{d}}\right)},\quad \boldsymbol{q}_i,\boldsymbol{k}_j\in\mathbb{R}^d\label{eq:sdpa}\end{equation}
The scale factor $\frac{1}{\sqrt{d}}$ has been explained and even generalized multiple times, such as in "On the Initialization, Parameterization, and Standardization of Transformer", "Attention Scaling from the Invariance of Entropy", and "Attention Scaling from the Perspective of Gradient Maximization". The standard derivation is performed under the assumption that "$\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ are independently sampled from a distribution with mean 0 and variance 1." Under this assumption, we also have:
\begin{equation}\Vert\boldsymbol{q}_i\Vert\approx \sqrt{d},\quad \Vert\boldsymbol{k}_j\Vert\approx \sqrt{d}\end{equation}
This is because:
\begin{equation}\Vert\boldsymbol{x}\Vert^2 = \sum_{i=1}^d x_i^2 = d\times\frac{1}{d}\sum_{i=1}^d x_i^2\approx d\,\mathbb{E}_{x\sim\mathcal{N}(0,1)}[x^2] = d\end{equation}
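This approximation is easy to check numerically. The sketch below uses a hypothetical helper, `mean_norm_ratio`, written with the Python standard library only, to estimate $\mathbb{E}\Vert\boldsymbol{x}\Vert/\sqrt{d}$ for vectors with i.i.d. $\mathcal{N}(0,1)$ entries; the ratio should approach 1 as $d$ grows.

```python
import math
import random


def mean_norm_ratio(d, trials=200, seed=0):
    """Monte Carlo estimate of E[||x||] / sqrt(d) for x with i.i.d. N(0, 1) entries."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        squared_norm = sum(rng.gauss(0.0, 1.0) ** 2 for _ in range(d))
        total += math.sqrt(squared_norm / d)
    return total / trials


for d in (16, 64, 256, 1024):
    # The ratio should get closer to 1 as d grows
    print(d, round(mean_norm_ratio(d), 3))
```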
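Related generalizations can also be found in "The Amazing Johnson-Lindenstrauss Lemma: Theoretical Part". Since $\Vert\boldsymbol{q}_i\Vert/\sqrt{d}\approx 1$ and $\Vert\boldsymbol{k}_j\Vert/\sqrt{d}\approx 1$, dividing the dot product by $\sqrt{d}$ is roughly the same as dividing it by either norm. This implies that in the initial stage of training, equation $\eqref{eq:sdpa}$ behaves the same as the following two variants: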
\begin{align}\color{red}{\text{Q}}\text{uery}\color{red}{\text{N}}\text{orm:}\quad\boldsymbol{o}_i =&\, \frac{\sum_{j = 1}^i\exp\left(\tilde{\boldsymbol{q}}_i\cdot \boldsymbol{k}_j\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\tilde{\boldsymbol{q}}_i\cdot \boldsymbol{k}_j\right)},\qquad \tilde{\boldsymbol{q}}_i = \frac{\boldsymbol{q}_i}{\Vert\boldsymbol{q}_i\Vert} \label{eq:qna}\\[5pt]
\color{red}{\text{K}}\text{ey}\color{red}{\text{N}}\text{orm:}\quad\boldsymbol{o}_i =&\, \frac{\sum_{j = 1}^i\exp\left(\boldsymbol{q}_i\cdot \tilde{\boldsymbol{k}}_j\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\boldsymbol{q}_i\cdot \tilde{\boldsymbol{k}}_j\right)},\qquad \tilde{\boldsymbol{k}}_j = \frac{\boldsymbol{k}_j}{\Vert\boldsymbol{k}_j\Vert} \label{eq:kna}
\end{align}
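To make the three definitions concrete, here is a minimal pure-Python sketch of causal attention with a pluggable score function. The helper names (`causal_attention`, `sdpa_score`, `qna_score`, `kna_score`) are illustrative and not taken from any library. With randomly sampled $\boldsymbol{q}_i,\boldsymbol{k}_j$ as in the assumption above, the three variants should produce nearly identical outputs.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def norm(a):
    return math.sqrt(dot(a, a))


def causal_attention(q, k, v, score):
    """o_i = sum_{j<=i} softmax_j(score(q_i, k_j)) * v_j, computed row by row."""
    out = []
    for i in range(len(q)):
        logits = [score(q[i], k[j]) for j in range(i + 1)]
        m = max(logits)  # subtract the max for numerical stability
        w = [math.exp(s - m) for s in logits]
        z = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / z
                    for c in range(len(v[0]))])
    return out


def sdpa_score(qi, kj):  # standard: q_i . k_j / sqrt(d)
    return dot(qi, kj) / math.sqrt(len(qi))


def qna_score(qi, kj):  # QueryNorm: (q_i / ||q_i||) . k_j
    return dot(qi, kj) / norm(qi)


def kna_score(qi, kj):  # KeyNorm: q_i . (k_j / ||k_j||)
    return dot(qi, kj) / norm(kj)


# Random "initial-stage" inputs with i.i.d. N(0, 1) entries
rng = random.Random(0)
d, n, dv = 512, 8, 4
q = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
k = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
v = [[rng.gauss(0.0, 1.0) for _ in range(dv)] for _ in range(n)]

base = causal_attention(q, k, v, sdpa_score)
for name, fn in (("QNA", qna_score), ("KNA", kna_score)):
    other = causal_attention(q, k, v, fn)
    gap = max(abs(a - b) for ra, rb in zip(base, other) for a, b in zip(ra, rb))
    print(name, "max |output difference| vs SDPA:", round(gap, 4))
```

The gaps stay small because $\Vert\boldsymbol{q}_i\Vert/\sqrt{d}$ and $\Vert\boldsymbol{k}_j\Vert/\sqrt{d}$ deviate from 1 by only about $1/\sqrt{2d}$ at this dimension; once training moves the norms away from $\sqrt{d}$, the three variants are no longer interchangeable.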
This gave me the idea of testing whether either variant outperforms the standard equation $\eqref{eq:sdpa}$. For convenience, we call them "Query/Key-Normalized Dot-Product Attention", abbreviated "QNA" and "KNA" respectively.
Furthermore, since we can have QueryNorm and KeyNorm, it is natural to consider normalizing both. Thus, we also experimented with the following "Scaled Cosine Attention (CosA)":
\begin{equation}\boldsymbol{o}_i = \frac{\sum_{j = 1}^i\exp\left(\lambda\,\tilde{\boldsymbol{q}}_i\cdot \tilde{\boldsymbol{k}}_j\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\lambda\,\tilde{\boldsymbol{q}}_i\cdot \tilde{\boldsymbol{k}}_j\right)} = \frac{\sum_{j = 1}^i\exp\left(\lambda\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\lambda\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)\right)} \label{eq:cosa}
\end{equation}
where $\lambda$ follows the result of "Attention Scaling from the Perspective of Gradient Maximization", namely $\lambda = 4\log n$ (that article suggests 3.5, but since the training length here is relatively short, 4 is more accurate), with $n$ either fixed at half the training length or set dynamically to the position ID plus 1.
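A minimal sketch of the CosA logits and both choices of $n$, assuming $\log$ is the natural logarithm and a training length of 512; the function names are illustrative. Because the cosine lies in $[-1,1]$, every CosA logit is bounded by $\lambda$ in absolute value.

```python
import math


def cosa_lambda(position, train_len=512, dynamic=False):
    """lambda = 4 * log(n): n = train_len // 2 (fixed) or position + 1 (dynamic)."""
    n = position + 1 if dynamic else train_len // 2
    return 4.0 * math.log(n)


def cosa_logits(qi, keys, lam):
    """Scaled cosine logits lambda * cos(q_i, k_j); always within [-lambda, lambda]."""
    qn = math.sqrt(sum(x * x for x in qi))
    logits = []
    for kj in keys:
        kn = math.sqrt(sum(x * x for x in kj))
        logits.append(lam * sum(a * b for a, b in zip(qi, kj)) / (qn * kn))
    return logits


lam_fixed = cosa_lambda(100)              # 4 * log(256), about 22.18
lam_dyn = cosa_lambda(100, dynamic=True)  # 4 * log(101)
print(round(lam_fixed, 2), round(lam_dyn, 2))
print(cosa_logits([1.0, 2.0, 3.0], [[2.0, 4.0, 6.0], [3.0, 0.0, -1.0]], lam_fixed))
```

In the usage line, the first key is parallel to the query, so its logit equals $\lambda$; the second is orthogonal, so its logit is 0.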
Results First
Following the same experimental setup for length extrapolation used previously: small models with 100 million parameters, GAU architecture, trained for the same number of steps (given time constraints, the models are not yet fully trained at this step count), a training length of 512, and evaluating extrapolation to a length of 4096. The experimental results are shown in the table below. "Baseline" refers to equation $\eqref{eq:sdpa}$, and "$-\log n$" refers to the addition of the length-dependent scaling factor introduced in "Attention Scaling from the Invariance of Entropy". The evaluation metric is the per-token accuracy of the language model (higher is better).
\[\begin{array}{c|ccc}
\hline
\text{Test Length} & 512(\text{Train}) & 4096(\text{Repeated}) & 4096(\text{Non-repeated}) \\
\hline
\text{Baseline} & 49.41\% & 24.17\% & 23.16\% \\
\text{Baseline-}\log n & 49.40\% & 24.60\% & 24.02\% \\
\hline
\text{QNA} & 49.55\% & 22.45\% & 22.18\% \\
\text{QNA-}\log n & 49.42\% & 19.55\% & 18.74\% \\
\text{KNA} & 49.60\% & 61.08\% & 47.69\% \\
\text{KNA-}\log n & 49.58\% & 63.17\% & 46.40\%\\
\text{CosA} & 49.73\% & 58.90\% & 46.98\% \\
\text{CosA-}\log n & 49.67\% & 64.74\% & 48.95\% \\
\hline
\end{array}\]
From the table we can observe: 1. Both QueryNorm and KeyNorm achieve better results at the training length. The advantage is slight and will probably become negligible as training continues, but it is very consistent, which hints that training may be more stable. 2. KeyNorm brings a very significant boost to length extrapolation; this is the "unexpected gift" of the experiment!
Note that unlike NTK-RoPE and YaRN, which require modifying the model during the inference stage, the length extrapolation for KNA and CosA here is achieved without any changes during inference. Therefore, some readers might wonder: since KNA and CosA already have such good extrapolation performance without modifications, would the effect be even better if combined with extrapolation techniques like NTK-RoPE or YaRN? To investigate this, I also conducted tests, and the results are shown in the table below:
\[\begin{array}{c|ccc}
\hline
\text{Test Length} & 512(\text{Train}) & 4096(\text{Repeated}) & 4096(\text{Non-repeated}) \\
\hline
\text{Baseline} & 49.41\% & 24.17\% & 23.16\% \\
\text{Baseline-NTK} & 49.41\% & 60.57\% & 42.20\% \\
\text{Baseline-YaRN} & 49.41\% & 80.10\% & 47.45\% \\
\text{Baseline-ReRoPE} & 49.41\% & 76.11\% & 47.82\% \\
\hline
\text{Baseline-}\log n & 49.40\% & 24.60\% & 24.02\% \\
\text{Baseline-}\log n\text{-NTK} & 49.40\% & 75.86\% & 47.06\% \\
\text{Baseline-}\log n\text{-YaRN} & 49.40\% & 82.57\% & 46.52\% \\
\text{Baseline-}\log n\text{-ReRoPE} & 49.40\% & 85.47\% & 48.87\% \\
\hline
\text{QNA} & 49.55\% & 22.45\% & 22.18\% \\
\text{QNA-NTK} & 49.55\% & 52.28\% & 39.88\% \\
\text{QNA-YaRN} & 49.55\% & 82.53\% & 47.50\% \\
\text{QNA-ReRoPE} & 49.55\% & 78.22\% & 47.72\% \\
\hline
\text{QNA-}\log n & 49.42\% & 19.55\% & 18.74\% \\
\text{QNA-}\log n\text{-NTK} & 49.42\% & 57.44\% & 41.56\% \\
\text{QNA-}\log n\text{-YaRN} & 49.42\% & 80.08\% & 45.16\% \\
\text{QNA-}\log n\text{-ReRoPE} & 49.42\% & 84.71\% & 48.31\% \\
\hline
\text{KNA} & 49.60\% & 61.08\% & 47.69\% \\
\text{KNA-NTK} & 49.60\% & 64.44\% & 43.02\% \\
\text{KNA-YaRN} & 49.60\% & 84.19\% & 47.44\% \\
\text{KNA-ReRoPE} & 49.60\% & 77.76\% & 47.73\% \\
\hline
\text{KNA-}\log n & 49.58\% & 63.17\% & 46.40\%\\
\text{KNA-}\log n\text{-NTK} & 49.58\% & 79.05\% & 47.43\%\\
\text{KNA-}\log n\text{-YaRN} & 49.58\% & 83.95\% & 47.16\%\\
\text{KNA-}\log n\text{-ReRoPE} & 49.58\% & 85.48\% & 48.78\%\\
\hline
\text{CosA} & 49.73\% & 58.90\% & 46.98\% \\
\text{CosA-NTK} & 49.73\% & 62.50\% & 42.77\% \\
\text{CosA-YaRN} & 49.73\% & 83.40\% & 47.80\% \\
\text{CosA-ReRoPE} & 49.73\% & 77.82\% & 47.80\% \\
\hline
\text{CosA-}\log n & 49.67\% & 64.74\% & 48.39\% \\
\text{CosA-}\log n\text{-NTK} & 49.67\% & 78.97\% & 47.46\% \\
\text{CosA-}\log n\text{-YaRN} & 49.67\% & 82.28\% & 45.72\% \\
\text{CosA-}\log n\text{-ReRoPE} & 49.67\% & 85.67\% & 48.39\% \\
\hline
\end{array}\]
This table is somewhat verbose; it is mainly meant to give a comprehensive picture of how the mainstream length extrapolation techniques compare. You can compare along whatever dimension interests you, but when judging length extrapolation, focus on the "Non-repeated" column and treat the "Repeated" column as secondary. The results are quite surprising: KeyNorm seems to be "immune" to existing RoPE extrapolation techniques. Stacking NTK or YaRN on top of it brings no significant gain in the "Non-repeated" column and can even hurt it (for example, KNA drops from 47.69% to 43.02% with NTK), although the "Repeated" column still improves markedly. These results suggest that while KeyNorm still cannot reliably identify positions beyond the training length (hence its "Repeated" scores are not high), it effectively avoids the PPL explosion problem (hence its "Non-repeated" scores are decent).
This may be good news for those working on long context. On the one hand, unlike ALIBI or KERPLE, KeyNorm's length extrapolation requires no local constraints and no modification after training; it is a pure "free lunch", and training quality even appears to improve with KeyNorm. On the other hand, because it is non-local, it can be used for continued training on longer texts, and you no longer have to agonize over choosing between PI and ABF: with KeyNorm, you don't need to change anything.
Principle Analysis
Although this discovery was unexpected, we should still try to explain it; otherwise it remains nothing more than a fluke. In this section, let us consider why KeyNorm helps length extrapolation.
Let's return to equation $\eqref{eq:sdpa}$. The correlation score between the $i$-th token and the $j$-th token is calculated by the dot product:
\begin{equation}s(j|i) = \boldsymbol{q}_i\cdot \boldsymbol{k}_j = \Vert\boldsymbol{q}_i\Vert \Vert\boldsymbol{k}_j\Vert \cos(\boldsymbol{q}_i,\boldsymbol{k}_j),\quad p(j|i) = \frac{\exp\left(\frac{s(j|i)}{\sqrt{d}}\right)}{\sum_{j'=1}^i \exp\left(\frac{s(j'|i)}{\sqrt{d}}\right)}\end{equation}
In the second equality, following its geometric meaning, we decompose the dot product into the product of the two norms and the cosine of the angle between the vectors. Attention $p(j|i)$ is a conditional probability over $j$. $\Vert\boldsymbol{q}_i\Vert$ depends only on the current position $i$: it does not change the relative sizes of the attention weights, only their sparsity. $\Vert\boldsymbol{k}_j\Vert$ can change the relative sizes of $p(j|i)$, but it involves no interaction between $i$ and $j$, so it can be used to express absolute signals. For example, Scissorhands shows that tokens at certain absolute positions consistently receive high attention, which could be expressed through $\Vert\boldsymbol{k}_j\Vert$. The remaining term, $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$, expresses the interaction between $i$ and $j$ and has the most degrees of freedom.
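A small numerical illustration of this decomposition, using hypothetical cosines for one query and four keys (not measured from any model) and folding the $1/\sqrt{d}$ factor into $\Vert\boldsymbol{q}_i\Vert$ for simplicity. Scaling $\Vert\boldsymbol{q}_i\Vert$ keeps the ranking of $p(j|i)$ but lowers its entropy (sharper, sparser attention), whereas enlarging a single $\Vert\boldsymbol{k}_j\Vert$ can reorder the weights without changing any angle.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    e = [math.exp(x - m) for x in logits]
    z = sum(e)
    return [x / z for x in e]


def entropy(p):
    return -sum(x * math.log(x) for x in p if x > 0)


def attention_row(q_norm, k_norms, cosines):
    """p(.|i) from s(j|i) = ||q_i|| ||k_j|| cos(q_i, k_j); 1/sqrt(d) folded into q_norm."""
    return softmax([q_norm * kn * c for kn, c in zip(k_norms, cosines)])


def ranking(p):
    """Key indices sorted from highest to lowest attention weight."""
    return sorted(range(len(p)), key=lambda j: -p[j])


cosines = [0.30, 0.10, -0.20, 0.05]  # hypothetical cos(q_i, k_j) values
unit = [1.0, 1.0, 1.0, 1.0]

p_base = attention_row(1.0, unit, cosines)
p_sharp = attention_row(4.0, unit, cosines)                # larger ||q_i||
p_key = attention_row(1.0, [1.0, 6.0, 1.0, 1.0], cosines)  # larger ||k_j|| for j = 1

print("ranking:", ranking(p_base), ranking(p_sharp), ranking(p_key))
print("entropy:", round(entropy(p_base), 3), round(entropy(p_sharp), 3))
```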
Obviously, to increase the relative importance of a certain position $j$, the model has two choices: 1. increase the norm $\Vert\boldsymbol{k}_j\Vert$; 2. increase $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$, i.e., reduce the angle between $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$. However, due to the existence of the "curse of dimensionality," significantly changing the angle in a high-dimensional space is relatively difficult. Therefore, if the task can be completed by increasing the norm $\Vert\boldsymbol{k}_j\Vert$, the model will prioritize doing so. The direct consequence is that the training of $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$ might be insufficient.
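One familiar face of the "curse of dimensionality" is concentration of measure: for independent random directions in $\mathbb{R}^d$, the cosine between them clusters around 0 with a standard deviation of $1/\sqrt{d}$. The sketch below (illustrative helper names, Gaussian vectors as in the earlier assumption) checks this numerically. It only illustrates how narrow the typical range of angles is in high dimensions; it is not a proof of the claim above.

```python
import math
import random


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))


def cosine_spread(d, pairs=500, seed=0):
    """Mean and standard deviation of cos(a, b) for independent Gaussian vectors in R^d."""
    rng = random.Random(seed)
    vals = []
    for _ in range(pairs):
        a = [rng.gauss(0.0, 1.0) for _ in range(d)]
        b = [rng.gauss(0.0, 1.0) for _ in range(d)]
        vals.append(cosine(a, b))
    mean = sum(vals) / pairs
    std = math.sqrt(sum((x - mean) ** 2 for x in vals) / pairs)
    return mean, std


for d in (4, 64, 256):
    mean, std = cosine_spread(d)
    # std * sqrt(d) stays near 1, i.e. the spread shrinks like 1 / sqrt(d)
    print(d, round(mean, 3), round(std, 3), round(std * math.sqrt(d), 2))
```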
Here, I make an assertion (conjecture):
The insufficient training of $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$ is the main reason why Attention cannot perform length extrapolation.
Insufficient training of $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$ means that the angles between $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ seen during training cover only a limited set. During length extrapolation, the model faces a larger set of angles and therefore cannot make correct predictions. Thinking carefully about the derivation in the YaRN paper, one finds that NTK and YaRN work because they modify the RoPE implementation at inference time so that the angles between $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ stay within the limited set seen during training, avoiding the larger, unseen set and turning extrapolation into interpolation. ReRoPE is even more direct: it truncates relative positions beyond a window, so the position encodings encountered at inference are never "unfamiliar". To some extent, these techniques indirectly support the assertion.
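To make the "unfamiliar positions" argument concrete, here is a minimal pure-Python RoPE sketch that rotates consecutive coordinate pairs with a base of 10000 (an assumed configuration). It checks that the RoPE score depends only on the relative position $m-n$, then applies a ReRoPE-style clip of the relative position to a window, so that no relative position larger than the window is ever used. The `window` parameter and helper names are illustrative; this shows only the position mapping, not ReRoPE's actual inference procedure.

```python
import math


def rope(x, pos, base=10000.0):
    """Rotate each pair (x[2t], x[2t+1]) by the angle pos * base ** (-2t / d)."""
    d = len(x)
    out = list(x)
    for t in range(d // 2):
        theta = pos * base ** (-2.0 * t / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * t], x[2 * t + 1]
        out[2 * t] = a * c - b * s
        out[2 * t + 1] = a * s + b * c
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def rope_score(q, k, m, n):
    """Score between a query at position m and a key at position n."""
    return dot(rope(q, m), rope(k, n))


def rerope_score(q, k, m, n, window):
    """ReRoPE-style score: relative positions beyond `window` are clipped to `window`."""
    rel = min(m - n, window)
    # rope(q, m) . rope(k, n) == rope(q, m - n) . k, so only the relative position matters
    return dot(rope(q, rel), k)


q = [0.3, -1.2, 0.8, 0.5, -0.4, 1.1, 0.9, -0.7]
k = [1.0, 0.2, -0.6, 0.4, 0.7, -0.3, 0.1, 0.5]
print(rope_score(q, k, 10, 3), rope_score(q, k, 107, 100))  # equal up to rounding
print(rerope_score(q, k, 4000, 0, window=512), rope_score(q, k, 512, 0))
```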
Starting from this assertion, the reason KeyNorm extrapolates is easy to see. Both KNA, which applies only KeyNorm, and CosA, which applies both QueryNorm and KeyNorm, remove $\Vert\boldsymbol{k}_j\Vert$ from the definition of Attention. Consequently, to change the relative importance of $j$, the model has only one option: adjust $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$. This forces the model to train and use $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$ more thoroughly, which indirectly promotes length extrapolation. I also experimented with the combination "KeyNorm + NoPE" and found no length extrapolation capability, which suggests that RoPE also plays an important role in KeyNorm's extrapolation. This is not hard to understand: RoPE rotates $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$, which widens the range of $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$ seen during training and thus makes its training more thorough.
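A minimal sketch of a KNA logit with RoPE, assuming RoPE is applied to $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ before the dot product as usual (helper names are illustrative). Because RoPE is a rotation, it preserves vector norms, so normalizing the key before or after applying RoPE gives the same result. Rescaling a key no longer changes its logit, which leaves the angle (together with $\Vert\boldsymbol{q}_i\Vert$) as the only way to reweight position $j$.

```python
import math


def rope(x, pos, base=10000.0):
    """Standard RoPE: rotate each pair (x[2t], x[2t+1]) by pos * base ** (-2t / d)."""
    d = len(x)
    out = list(x)
    for t in range(d // 2):
        theta = pos * base ** (-2.0 * t / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * t], x[2 * t + 1]
        out[2 * t] = a * c - b * s
        out[2 * t + 1] = a * s + b * c
    return out


def l2_normalize(x):
    n = math.sqrt(sum(v * v for v in x))
    return [v / n for v in x]


def kna_logit(q, k, m, n):
    """KNA logit with RoPE: rope(q, m) . normalize(rope(k, n)), with no 1/sqrt(d) factor."""
    return sum(a * b for a, b in zip(rope(q, m), l2_normalize(rope(k, n))))


q = [0.3, -1.2, 0.8, 0.5, -0.4, 1.1, 0.9, -0.7]
k = [1.0, 0.2, -0.6, 0.4, 0.7, -0.3, 0.1, 0.5]
big_k = [100.0 * x for x in k]  # same direction, 100x the norm

print(kna_logit(q, k, 9, 2), kna_logit(q, big_k, 9, 2))  # equal up to rounding
```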
Has any prior work tried QueryNorm and KeyNorm? Yes. The 2020 paper "Query-Key Normalization for Transformers" experimented with CosA and proposed a similar scale factor based on the logarithm of the length, but did not discuss length extrapolation. In addition, Google's paper "Scaling Vision Transformers to 22 Billion Parameters" from earlier this year also normalized Query and Key, but with LayerNorm. Both LayerNorm and RMSNorm include learnable gamma parameters, so the norms of the normalized vectors are not necessarily constant. Therefore, it is hard to say whether they would achieve the same length extrapolation effect as in this article.
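To see why a learnable gamma matters, here is a sketch comparing plain $L_2$ normalization with RMSNorm (LayerNorm additionally subtracts the mean, which is omitted here). The gamma values are hypothetical stand-ins for learned parameters. With a non-uniform gamma, the output norm of RMSNorm depends on the input's direction, so the key norm can still carry information.

```python
import math


def norm(x):
    return math.sqrt(sum(v * v for v in x))


def l2_normalize(x):
    n = norm(x)
    return [v / n for v in x]


def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm: x / sqrt(mean(x^2) + eps), scaled elementwise by a learnable gamma."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gamma, x)]


gamma = [1.0, 3.0, 0.5, 2.0]  # hypothetical learned scales
keys = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.5, -0.5, 0.5, -0.5]]

for key in keys:
    # L2 norm is always 1; the RMSNorm output norm varies with the key's direction
    print(round(norm(l2_normalize(key)), 4), round(norm(rms_norm(key, gamma)), 4))
```

With `gamma` set to all ones, RMSNorm output would have norm close to $\sqrt{d}$ for every input, so the constant-norm property is lost only through the learned scales.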
Summary
This article introduced a length extrapolation scheme discovered by accident, "KeyNorm": applying $L_2$ normalization to the Key sequence in Attention. It achieves slightly better results at the training length and significantly improves length extrapolation. It belongs to the "pre-training modification" category; compared with other schemes in that category, such as ALIBI and KERPLE, it imposes no local constraints and is therefore more promising for scaling up. Compared with "post-hoc modification" schemes such as NTK-RoPE and YaRN, it does not sacrifice performance within the training length in exchange for extrapolation.