Looking at the Scale Operation of Attention from the Perspective of Gradient Maximization

By 苏剑林 | October 22, 2023

We know that the Scale factor of Scaled Dot-Product Attention is $\frac{1}{\sqrt{d}}$, where $d$ is the dimension of $\boldsymbol{q}, \boldsymbol{k}$. The usual explanation for this factor is that, without dividing by $\sqrt{d}$, the initial Attention would be very close to a one-hot distribution, which causes vanishing gradients and keeps the model from training. However, one can show that a Scale of 0 also leads to vanishing gradients, so a Scale that is either too large or too small is problematic.

So, how large should the Scale be? Is $\frac{1}{\sqrt{d}}$ the optimal Scale? This article attempts to answer this question from the perspective of gradients.

Existing Results

In "A Brief Discussion on Transformer Initialization, Parameterization, and Standardization", we derived the standard Scale factor $\frac{1}{\sqrt{d}}$. The logic is simple: assuming that at initialization the components of $\boldsymbol{q}, \boldsymbol{k} \in \mathbb{R}^d$ are sampled independently from a distribution with "mean 0 and variance 1", we can compute \begin{equation}\mathbb{V}ar[\boldsymbol{q}\cdot\boldsymbol{k}] = d\end{equation} and then divide $\boldsymbol{q}\cdot\boldsymbol{k}$ by $\sqrt{d}$ so that the variance of the Attention Score becomes 1. In other words, the earlier derivation rested purely on the belief that "mean 0 and variance 1" is better. It neither explained why the Attention Score should have variance 1 nor checked whether $\frac{1}{\sqrt{d}}$ actually solves the vanishing gradient problem.
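The variance formula is easy to check by Monte Carlo. The NumPy sketch below assumes i.i.d. standard normal components for $\boldsymbol{q}$ and $\boldsymbol{k}$ (one concrete instance of "mean 0 and variance 1"). It estimates the variance of $\boldsymbol{q}\cdot\boldsymbol{k}$ before and after dividing by $\sqrt{d}$. The helper `score_variances` is a hypothetical name used only for illustration.

```python
import numpy as np

def score_variances(d, num_samples=50_000, seed=0):
    """Empirical Var[q.k] and Var[q.k / sqrt(d)] for q, k with i.i.d. N(0, 1) entries."""
    rng = np.random.default_rng(seed)
    q = rng.standard_normal((num_samples, d))
    k = rng.standard_normal((num_samples, d))
    scores = np.einsum("ij,ij->i", q, k)  # one dot product per sampled (q, k) pair
    return scores.var(), (scores / np.sqrt(d)).var()

raw_var, scaled_var = score_variances(d=64)
print(raw_var, scaled_var)  # roughly 64 and 1
```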

Of course, existing experiments show that $\frac{1}{\sqrt{d}}$ alleviates this problem at least to some extent, but that is an empirical observation. We would still like to understand theoretically how much "to some extent" really is.

Calculating Gradient

Since the question is about gradients, the most direct approach is to compute the gradient and define an optimization objective. Let $p_i = e^{\alpha s_i}/Z$, $i \in \{1,2,...,n\}$, where $Z=\sum_i e^{\alpha s_i}$ is the normalization factor. A direct computation gives \begin{equation}\frac{\partial p_i}{\partial s_j} = \left\{\begin{aligned} \alpha(p_i - p_i^2),&\quad i=j\\ -\alpha p_i p_j,&\quad i\neq j \end{aligned}\right.\end{equation} or, equivalently, $\partial p_i/\partial s_j = \alpha(p_i\delta_{i,j} - p_i p_j)$. Clearly, when $\alpha\to 0$ the gradient tends to 0. When $\alpha\to\infty$, $p$ becomes one-hot: one $p_i$ equals 1 and the rest are 0, assuming $s_i$ has a unique maximum. In that limit the gradient again tends to 0.
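The Jacobian formula can be checked against central finite differences. The NumPy sketch below does this and then evaluates the Jacobian at a very small and a very large $\alpha$. The scores are a hypothetical fixed vector with a unique maximum, and the function names are illustrative. At both extremes the Jacobian should be essentially zero.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())  # subtract the max for numerical stability
    return z / z.sum()

def softmax_jacobian(s, alpha):
    """Analytic Jacobian dp_i/ds_j = alpha * (p_i * delta_ij - p_i * p_j) of p = softmax(alpha * s)."""
    p = softmax(alpha * s)
    return alpha * (np.diag(p) - np.outer(p, p))

def finite_difference_jacobian(s, alpha, eps=1e-6):
    """Central-difference estimate of the same Jacobian, one column (one s_j) at a time."""
    n = len(s)
    jac = np.zeros((n, n))
    for j in range(n):
        e = np.zeros(n)
        e[j] = eps
        jac[:, j] = (softmax(alpha * (s + e)) - softmax(alpha * (s - e))) / (2 * eps)
    return jac

s = np.linspace(-1.0, 1.0, 8)  # hypothetical scores with a unique maximum
err = np.abs(softmax_jacobian(s, 1.5) - finite_difference_jacobian(s, 1.5)).max()
print("max |analytic - numeric|:", err)
for alpha in (1e-4, 1.0, 1e3):
    print(alpha, np.abs(softmax_jacobian(s, alpha)).max())  # small, moderate, essentially zero
```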

To make optimization easier, we should choose $\alpha$ to maximize the gradient. We use the L1 norm to measure the size of the gradient: \begin{equation}\frac{1}{2}\left\Vert\frac{\partial p}{\partial s}\right\Vert_1=\frac{1}{2}\sum_{i,j}\left|\frac{\partial p_i}{\partial s_j}\right|=\frac{1}{2}\sum_i \alpha(p_i - p_i^2) + \frac{1}{2}\sum_{i\neq j} \alpha p_i p_j = \alpha\left(1 - \sum_i p_i^2\right)\label{eq:target}\end{equation} The last step uses $\sum_i p_i = 1$. Both $\sum_i (p_i - p_i^2)$ and $\sum_{i\neq j} p_i p_j = \left(\sum_i p_i\right)^2 - \sum_i p_i^2$ equal $1 - \sum_i p_i^2$. As the final result suggests, the main reason for choosing the L1 norm over other norms is that it yields such a simple expression. Note that $\sum_i p_i^2$ appears here. Its negative logarithm, $-\log\sum_i p_i^2$, is the (order-2) "Rényi entropy" introduced in "How to Measure Data Sparsity?". Like information entropy, it is a measure of uncertainty.
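A short numerical check confirms that summing the absolute Jacobian entries really reduces to $\alpha(1-\sum_i p_i^2)$. The random scores below are arbitrary test data, and the function names are illustrative.

```python
import numpy as np

def softmax_probs(s, alpha):
    z = np.exp(alpha * s - np.max(alpha * s))  # stabilized softmax of alpha * s
    return z / z.sum()

def l1_objective_direct(s, alpha):
    """(1/2) * sum_{i,j} |dp_i/ds_j|, computed from the full Jacobian."""
    p = softmax_probs(s, alpha)
    jac = alpha * (np.diag(p) - np.outer(p, p))
    return 0.5 * np.abs(jac).sum()

def l1_objective_closed(s, alpha):
    """The closed form alpha * (1 - sum_i p_i^2)."""
    p = softmax_probs(s, alpha)
    return alpha * (1.0 - np.sum(p**2))

rng = np.random.default_rng(1)
s = rng.standard_normal(16)
for alpha in (0.5, 2.0, 8.0):
    print(alpha, l1_objective_direct(s, alpha), l1_objective_closed(s, alpha))
```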

With an optimization target, we can proceed with maximization. Note that the definition of $p_i$ also contains $\alpha$, so this is a complex nonlinear target with respect to $\alpha$. Although obtaining an analytical solution seems impossible, we can find approximate solutions for some specific cases.

Normal Distribution

First, we can build on the previous result. After dividing by $\sqrt{d}$ so that the Attention Score has mean 0 and variance 1, we can approximately assume $s_i\sim\mathcal{N}(0,1)$ and then solve for the optimal $\alpha$. If $\alpha=1$, the original $\frac{1}{\sqrt{d}}$ is the optimal Scale factor. Otherwise, $\frac{\alpha}{\sqrt{d}}$ is.

We approximate the sums by expectations: \begin{equation}\sum_i p_i^2 = \frac{\sum_i e^{2\alpha s_i}}{\left(\sum_i e^{\alpha s_i}\right)^2} = \frac{\frac{1}{n}\sum_i e^{2\alpha s_i}}{n\left(\frac{1}{n}\sum_i e^{\alpha s_i}\right)^2} \approx \frac{\mathbb{E}_s[e^{2\alpha s}]}{n\left(\mathbb{E}_s[e^{\alpha s}]\right)^2}\label{eq:approx}\end{equation} For $s$ following the standard normal distribution, we have \begin{equation}\mathbb{E}_s[e^{\alpha s}] = \int \frac{1}{\sqrt{2\pi}}e^{-s^2/2}e^{\alpha s} ds = e^{\alpha^2 / 2}\label{eq:normal}\end{equation} Substituting this into the formula above (so that $\mathbb{E}_s[e^{2\alpha s}] = e^{2\alpha^2}$), and then into equation $\eqref{eq:target}$, gives \begin{equation}\alpha\left(1 - \sum_i p_i^2\right)\approx\alpha\left(1 - \frac{e^{\alpha^2}}{n}\right)\end{equation} Although this final approximation is simple, its maximizer has no elementary closed form. That is fine. We can sweep over a range of $n$ and solve numerically for the maximizer $\alpha^*$, which gives a rough picture of how $\alpha^*$ depends on $n$. The reference Mathematica code is as follows:

(* Define function *)
f[a_, n_] := a*(1 - Exp[a^2]/n)
(* Find the point a corresponding to the maximum of the function *)
FindArg[n_] :=
 Module[{a}, a = a /. Last@NMaximize[{f[a, n], a > 0}, a][[2]]; a]
(* Given the range of n *)
nRange = 40*Range[1, 500];
(* Solve for a corresponding to each n *)
args = FindArg /@ nRange;
(* Plot the relationship between a and n *)
ListLinePlot[{args, 0.84*Log[nRange]^0.5},
 DataRange -> {40, 20000}, AxesLabel -> {"n", "a"},
 PlotLegends -> {Row[{"a", Superscript["", "*"]}],
 TraditionalForm[HoldForm[0.84*Sqrt[Log[n]]]]}]
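For readers without Mathematica, a rough Python equivalent is sketched below. It uses SciPy's bounded scalar minimizer in place of `NMaximize`, and the search interval is an assumption. The stationarity condition $e^{\alpha^2}(1+2\alpha^2)=n$ can also be solved with the Lambert W function. Substituting $y=\alpha^2+\tfrac{1}{2}$ gives $y e^y = \tfrac{\sqrt{e}}{2}n$, so $\alpha^*=\sqrt{W\left(\tfrac{\sqrt{e}}{2}n\right)-\tfrac{1}{2}}$. This closed form is included only as a cross-check, not as the method used to produce the figure.

```python
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.special import lambertw

def objective_normal(alpha, n):
    """Approximate target alpha * (1 - e^{alpha^2} / n) for s ~ N(0, 1)."""
    return alpha * (1.0 - np.exp(alpha**2) / n)

def alpha_star_numeric(n):
    # Bounded maximization (minimize the negative), mirroring NMaximize with a > 0.
    # The upper bound 5 is an assumption; it contains the optimum for n below roughly 1e12.
    res = minimize_scalar(lambda a: -objective_normal(a, n), bounds=(0.0, 5.0), method="bounded")
    return res.x

def alpha_star_lambert(n):
    # Zero derivative: e^{a^2} (1 + 2 a^2) = n. With y = a^2 + 1/2 this is
    # y e^y = n sqrt(e) / 2, so y = W(n sqrt(e) / 2) on the principal branch.
    y = lambertw(n * np.sqrt(np.e) / 2).real
    return np.sqrt(y - 0.5)

for n in (40, 1000, 20000):
    print(n, alpha_star_numeric(n), alpha_star_lambert(n), 0.84 * np.sqrt(np.log(n)))
```

The last column prints the fitted curve $0.84\sqrt{\log n}$ for comparison.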

By fitting, the author found that within a certain range the optimal $\alpha^*$ and $n$ roughly satisfy $\alpha^*\approx 0.84\sqrt{\log n}$. This fitted curve is therefore plotted alongside the numerical results:

Relationship between optimal alpha and n for standard normal distribution

Over a fairly wide range of $n$, the optimal $\alpha^*$ lies between 2 and 3. Therefore, as a compromise, simply taking $\frac{2.5}{\sqrt{d}}$ as the Attention Scale factor should, in theory, be more conducive to optimization.

Cosine Distribution

Now we consider another less common example: when we apply $l_2$ normalization to both $\boldsymbol{q}, \boldsymbol{k}$ to make them unit vectors, their inner product becomes the cosine of the angle between them. That is, $s_i$ approximately follows the distribution of the cosine of the angle between two random vectors in $d$-dimensional space. Some readers may not be familiar with this distribution, but we previously explored it in "Distribution of the angle between two random vectors in n-dimensional space". Its probability density has the form: \begin{equation}p(s)\propto (1-s^2)^{(d-3)/2}\end{equation}
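The density can be checked numerically. The sketch below normalizes it with the Beta function, $\int_{-1}^{1}(1-s^2)^{(d-3)/2}ds = B\left(\tfrac{1}{2},\tfrac{d-1}{2}\right)$. It then compares its second moment, which equals $1/d$, with cosines of randomly sampled unit vectors. The sample size, seed, and function names are arbitrary choices for illustration.

```python
import numpy as np
from scipy.integrate import quad
from scipy.special import beta

def cosine_density(s, d):
    """Density of the cosine between two independent, uniformly random directions in R^d."""
    return (1.0 - s**2) ** ((d - 3) / 2) / beta(0.5, (d - 1) / 2)

def sample_cosines(d, num_samples=20_000, seed=0):
    rng = np.random.default_rng(seed)
    u = rng.standard_normal((num_samples, d))
    v = rng.standard_normal((num_samples, d))
    u /= np.linalg.norm(u, axis=1, keepdims=True)  # l2-normalize to unit vectors
    v /= np.linalg.norm(v, axis=1, keepdims=True)
    return np.einsum("ij,ij->i", u, v)  # cosine of the angle for each pair

d = 128
total_mass, _ = quad(cosine_density, -1, 1, args=(d,))
second_moment, _ = quad(lambda s: s**2 * cosine_density(s, d), -1, 1)
cos = sample_cosines(d)
print(total_mass, second_moment, np.mean(cos**2), 1 / d)
```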

This does not look complicated, but it is much harder to handle than the normal distribution. The main reason is that $\mathbb{E}_s[e^{\alpha s}]$ can no longer be expressed in elementary functions as in equation $\eqref{eq:normal}$; it involves a modified Bessel function. However, this poses no problem for a numerical solution in Mathematica. Following the same logic as the previous section, the approximation $\eqref{eq:approx}$ still applies. We first solve for the maximizer numerically and then fit it. The results are as follows (the figure uses $d=128$, and $\alpha^*$ depends on $d$):

Relationship between optimal alpha and n for cosine distribution

$\alpha^*$ is also well fitted by $3.5\log n$ (the coefficient $3.5$ changes if $d$ changes). Over a considerable range, $\alpha^*$ lies between 25 and 35. Therefore, if the $\cos$ value is used as the Attention Score, it needs to be multiplied by a Scale between 25 and 35 for the model to train easily. This also explains why, when $\cos$ values are used to build a Softmax distribution (as in AM-Softmax, SimCSE, etc.), the $\cos$ must be multiplied by a Scale of about 30. Without it, the model is very hard to train.
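The sketch below illustrates this with the gradient proxy $\eqref{eq:target}$ rather than with actual training. It draws random unit queries and keys with $d=128$ and $n=1000$, which mimics the analysis but is not real embeddings. It then compares the average of $\alpha(1-\sum_i p_i^2)$ at Scale 1 and at Scale 30. The function names are illustrative.

```python
import numpy as np

def unit_rows(x):
    """l2-normalize each row so that dot products become cosines."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def mean_gradient_objective(cos_scores, scale):
    """Row average of scale * (1 - sum_i p_i^2), with p = softmax(scale * cos) per row."""
    logits = scale * cos_scores
    logits = logits - logits.max(axis=1, keepdims=True)  # stabilize each row
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)
    return np.mean(scale * (1.0 - (p**2).sum(axis=1)))

rng = np.random.default_rng(0)
d, n, batch = 128, 1000, 16
queries = unit_rows(rng.standard_normal((batch, d)))
keys = unit_rows(rng.standard_normal((n, d)))
cos_scores = queries @ keys.T  # (batch, n) matrix of cosine similarities

for scale in (1.0, 30.0):
    print(scale, mean_gradient_objective(cos_scores, scale))
```

With Scale 1 the objective cannot exceed 1, because $1-\sum_i p_i^2<1$. At Scale 30 it is several times larger.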

For different $d$ and $n$, readers can modify the following code to calculate the optimal $\alpha$:

(* Define function *)
h[a_] :=
 Integrate[Exp[a*s]*(1 - s^2)^((d - 3)/2), {s, -1, 1},
 Assumptions -> {d > 10}]
g[a_] = h[a]/h[0] // FullSimplify;
f[a_, n_] := a (1 - g[2*a]/g[a]^2/n) /. {d -> 128}
(* Find the point a corresponding to the maximum of the function *)
FindArg[n_] :=
 Module[{a}, a = a /. Last@NMaximize[{f[a, n], a > 0}, a][[2]]; a]
(* Given range of n *)
nRange = 40*Range[1, 500];
(* Solve for a corresponding to each n *)
args = FindArg /@ nRange;
(* Plot relation between a and n *)
ListLinePlot[{args, 3.5*Log[nRange]},
 DataRange -> {40, 20000}, AxesLabel -> {"n", "a"},
 PlotLegends -> {Row[{"a", Superscript["", "*"]}],
 TraditionalForm[HoldForm[3.5*Log[n]]]}]
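A rough Python counterpart of this Mathematica code is sketched below. Instead of the symbolic `Integrate`, it uses the known closed form $\mathbb{E}_s[e^{\alpha s}]=\Gamma\left(\tfrac{d}{2}\right)\left(\tfrac{2}{\alpha}\right)^{d/2-1}I_{d/2-1}(\alpha)$, where $I_\nu$ is the modified Bessel function of the first kind. It cross-checks that form against direct quadrature. The search bounds and function names are assumptions for illustration.

```python
import numpy as np
from scipy.integrate import quad
from scipy.optimize import minimize_scalar
from scipy.special import beta, gammaln, ive

def log_mgf_cos(alpha, d):
    """log E[exp(alpha * s)] for the cosine density p(s) ∝ (1 - s^2)^((d-3)/2).

    Evaluated in log space with the scaled Bessel function ive(v, x) = I_v(x) * exp(-x).
    Requires alpha > 0; for very large d and tiny alpha, ive may underflow.
    """
    nu = d / 2 - 1
    return gammaln(nu + 1) + nu * np.log(2 / alpha) + np.log(ive(nu, alpha)) + alpha

def mgf_cos_quad(alpha, d):
    """The same expectation by direct numerical integration, for cross-checking."""
    integrand = lambda s: np.exp(alpha * s) * (1 - s**2) ** ((d - 3) / 2)
    value, _ = quad(integrand, -1, 1)
    return value / beta(0.5, (d - 1) / 2)

def objective_cos(alpha, n, d):
    # alpha * (1 - E[e^{2 alpha s}] / (n E[e^{alpha s}]^2)), matching f[a, n] above
    log_ratio = log_mgf_cos(2 * alpha, d) - 2 * log_mgf_cos(alpha, d)
    return alpha * (1 - np.exp(log_ratio) / n)

def alpha_star_cos(n, d=128, bounds=(1.0, 100.0)):
    # The bounds are an assumption chosen to cover the d = 128 case discussed in the text.
    res = minimize_scalar(lambda a: -objective_cos(a, n, d), bounds=bounds, method="bounded")
    return res.x

print(np.exp(log_mgf_cos(5.0, 16)), mgf_cos_quad(5.0, 16))  # the two should agree closely
for n in (40, 1000, 20000):
    print(n, alpha_star_cos(n), 3.5 * np.log(n))
```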

Related Thoughts

The title and results of this article readily bring to mind another article on the Attention Scale: "Looking at Attention's Scale Operation from the Perspective of Entropy Invariance". This is especially true of the cosine-case result that $\alpha^*$ is approximately proportional to $\log n$. There is indeed a connection between the two articles. Our optimization target $\eqref{eq:target}$ contains the "Rényi entropy", while the entropy in "entropy invariance" is the Shannon information entropy, and the two behave in largely the same way. Maximizing $\eqref{eq:target}$ places the model in a "slowly changing" region where the Rényi entropy varies very slowly with $n$. Hence the Shannon entropy also varies very slowly with $n$, which roughly amounts to entropy invariance.

In addition, for bidirectional Attention (Encoder), if all training samples have the same length, $n$ is a constant. We can then compute the corresponding optimal $\alpha$ from $n$ and fix it in the model. For unidirectional Attention (Decoder), however, each token sees a different $n$ (its position id plus 1). In theory, therefore, no single $\alpha$ can maximize $\eqref{eq:target}$ for all tokens at once. But since $\alpha^*$ changes slowly with $n$, an approximately constant value is fine. For example, one could take $n=L_{\max} / 2$, which should be reasonably friendly to the gradients of most tokens.
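Under the normal-distribution approximation, one can see how slowly $\alpha^*$ moves across decoder positions. The sketch below solves the stationarity condition $e^{\alpha^2}(1+2\alpha^2)=n$ of $\alpha(1-e^{\alpha^2}/n)$ with the Lambert W function for every $n$. It then compares the result with the fixed choice $n=L_{\max}/2$. `L_max = 4096` is a hypothetical training length.

```python
import numpy as np
from scipy.special import lambertw

def alpha_star_normal(n):
    """Optimal alpha under the normal approximation: solves e^{a^2} (1 + 2 a^2) = n."""
    y = lambertw(np.asarray(n, dtype=float) * np.sqrt(np.e) / 2).real
    return np.sqrt(np.maximum(y - 0.5, 0.0))  # clip tiny negatives (n = 1 gives alpha* = 0)

L_max = 4096                      # hypothetical maximum training length
n_per_token = np.arange(L_max) + 1  # token at position t attends to n = t + 1 keys
alpha_fixed = alpha_star_normal(L_max / 2)
per_token = alpha_star_normal(n_per_token)
print("fixed alpha:", alpha_fixed)
for n in (L_max // 8, L_max // 2, L_max):
    print(n, per_token[n - 1])
```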

Summary

This article explored the selection of the Attention Scale factor from the perspective of gradients. As is well known, the "standard answer" regarding this Scale factor is $\frac{1}{\sqrt{d}}$, but its derivation did not discuss its optimality. Therefore, the author defined an optimization target for the Softmax gradient and explored the optimal value of the Scale factor from the perspective of maximizing this target. The relevant results can be used both to improve the Scale factor of Attention and to explain the temperature parameters in contrastive learning using $\cos$ similarity.