Entropy-Invariant Attention from the Perspective of the JL Lemma

By 苏剑林 | April 10, 2023

In the articles "Entropy Invariance: Looking at the Scaling Operation of Attention" and "A Fast Derivation of Entropy-Invariant Softmax," I proposed Entropy-Invariant Softmax. Simply put, it multiplies the Attention matrix before Softmax by an additional factor of $\log n$, where $n$ is the sequence length; in theory, this helps improve length extrapolation. This $\log n$ factor reminded me of the JL Lemma (Johnson-Lindenstrauss Lemma). The JL Lemma tells us that encoding $n$ vectors requires only $\mathcal{O}(\log n)$ dimensions. Since both involve $\log n$, is there a connection between the two?

Entropy Invariance

We know that entropy is a measure of uncertainty. In the attention mechanism, we use it to measure how focused the attention is: the lower the entropy, the more concentrated the attention. So-called entropy invariance means that regardless of the sequence length $n$, we want the attention to remain concentrated on a few key tokens rather than becoming too dispersed. To this end, the proposed form of Entropy-Invariant Attention is:

\begin{equation}Attention(Q,K,V) = softmax\left(\frac{\log_{512} n}{\sqrt{d}}QK^{\top}\right)V\label{eq:core}\end{equation}

Here $Q,K \in \mathbb{R}^{n \times d}$. Compared to conventional Attention, the scaling factor includes an extra $\log_{512} n$. The base is 512 on the assumption that all our hyperparameters (such as $d$) are tuned for a training length of 512. Even if your planned pre-training length is not 512, you can still take the base as 512; the results will remain largely unaffected.
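As a minimal sketch of equation $\eqref{eq:core}$, assuming a single head, no masking, and NumPy arrays, the scaling could be written as follows. The names `softmax`, `entropy_invariant_attention`, and `train_len` are illustrative and do not come from the original articles.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def entropy_invariant_attention(Q, K, V, train_len=512):
    """Single-head attention with the extra log_{train_len}(n) factor.

    Q, K: arrays of shape (n, d); V: array of shape (n, d_v).
    `train_len` is the logarithm base (512 in the text).
    """
    n, d = Q.shape
    # log_{512}(n) / sqrt(d), written with natural logarithms
    scale = (np.log(n) / np.log(train_len)) / np.sqrt(d)
    return softmax(scale * (Q @ K.T), axis=-1) @ V
```

When $n$ equals the base, the extra factor is 1 and the function reduces to conventional scaled dot-product attention.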

The principle behind this form is quite intuitive. When $n$ increases, there are more tokens to "dilute" the attention, so the attention becomes less focused. To counter this, we multiply by a factor that is monotonically increasing with respect to $n$. Scaling the scores by a factor $\lambda$ before Softmax is equivalent to raising the original probabilities to the power $\lambda$ and then renormalizing. Since probabilities are less than 1, for $\lambda > 1$ the smaller probabilities shrink relatively more under the power operation, which makes the attention concentrated again. As for why this factor takes a logarithmic form, one would need to refer to the derivations in the articles mentioned at the beginning.
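The power-operation argument can be checked numerically. The sketch below assumes random Gaussian scores for a single query at length 1024; the helper names and the choice of random scores are illustrative only.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax of a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()


def entropy(p):
    """Shannon entropy (natural log) of a strictly positive distribution."""
    return float(-(p * np.log(p)).sum())


rng = np.random.default_rng(0)
n = 1024
s = rng.normal(size=n)            # hypothetical attention scores for one query
lam = np.log(n) / np.log(512)     # the factor log_512(n), here greater than 1

p = softmax(s)                    # original attention probabilities
p_scaled = softmax(lam * s)       # probabilities after scaling the scores
p_power = p**lam / (p**lam).sum() # original probabilities to the power lam, renormalized

print(np.allclose(p_scaled, p_power))   # the two constructions agree
print(entropy(p), entropy(p_scaled))    # entropy before and after scaling
```

The two distributions coincide, and the scaled one has lower entropy, i.e., it is more concentrated.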

JL Lemma

The JL Lemma, in full the "Johnson-Lindenstrauss Lemma," is an important result on vector embedding. Simply put, it tells us that "to fit $n$ vectors, only $\mathcal{O}(\log n)$ dimensions are needed" (here $\log$ denotes the natural logarithm, i.e., base $e$). For a detailed introduction, please refer to "The Amazing Johnson-Lindenstrauss Lemma: Theory Edition."

Interestingly, even before I knew about the JL Lemma, I had derived a similar, and perhaps even more specific, result in "The Principle of Minimum Entropy (VI): How to Choose the Dimension of Word Vectors?": embedding $n$ word vectors requires roughly $8 \log n$ dimensions. This estimate is very close to the dimensions used in practice. For example, when $n$ equals 100,000, $8 \log n$ comes to approximately 92, and the word vector dimensions we frequently use are on the order of one or two hundred.

Additionally, the JL Lemma can be used to explain the multi-head design of the attention mechanism. Substituting $n=512$ gives $8 \log n \approx 50$, which is very close to the projection dimension commonly used for $Q$ and $K$ in Attention (i.e., the key_size, which is 64 in BERT). This tells us that for a sequence length of 512, a dimension of around 50 is sufficient for the $Q$ and $K$ used to compute Attention; there is no need to use the full hidden_size (768 for BERT base). The saved dimensions can instead be used for multi-head attention.
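The two estimates above are easy to reproduce. The snippet below is only an arithmetic check of the $8 \log n$ rule of thumb; `suggested_dim` is a hypothetical helper name.

```python
import math


def suggested_dim(n, coeff=8.0):
    """Rule-of-thumb embedding dimension coeff * ln(n) for n vectors."""
    return coeff * math.log(n)


print(round(suggested_dim(100_000), 1))  # 92.1, vocabulary of 100,000 words
print(round(suggested_dim(512), 1))      # 49.9, sequence length 512
print(768 // 64)                         # 12 heads of key_size 64 fit in hidden_size 768
```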

For more related discussions, please refer to "Availability Analysis of the Dimension Formula 'n > 8.33 log N'" and "The Amazing Johnson-Lindenstrauss Lemma: Application Edition."

Connecting Them

Now, we can attempt to connect the JL Lemma with Entropy-Invariant Attention.

Let the key_size of $Q$ and $K$ be $d$. The JL Lemma tells us that the optimal choice for $d$ should be $d_n = \lambda \log n$, where $\lambda$ is a proportionality constant whose specific value is not important. That is to say, ideally, $d$ should change as $n$ changes. However, it is obvious that such a design is not easy to implement and is not conducive to computational parallelization. Therefore, in practice, we can only use a fixed $d$.

Assuming we have selected a fixed $d$ and that this $d$ is designed for a training length of 512, we can conclude that $d = \lambda \log 512$, which means $\lambda = \frac{d}{\log 512}$, and:

\begin{equation}d_n = \frac{d}{\log 512}\log n=d\log_{512} n\end{equation}

For $n \neq 512$, ideally, a projection dimension of $d_n$ should be used. But in practice, $d$ dimensions are used. According to the definition of the inner product $\langle q,k\rangle = \sum_{i=1}^d q_i k_i$, the number of terms in the sum is exactly equal to the dimension $d$. That is, in the ideal case, there should be a sum of $d_n$ terms, but in reality, it becomes a sum of $d$ terms. Intuitively, then, if the contribution of each term is similar, multiplying the result by $\frac{d_n}{d}$ will make the result closer to the ideal case of a $d_n$-term sum. Therefore, we conclude that we should multiply $\langle q,k\rangle$ by the factor:

\begin{equation}\frac{d_n}{d} = \log_{512} n\end{equation}

to compensate for the gap between the actual and ideal situations. Multiplying the scores of conventional Scaled Dot-Product Attention by $\log_{512} n$ yields exactly the Entropy-Invariant Attention shown in equation $\eqref{eq:core}$.
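To get a feel for the sizes involved, the sketch below evaluates $d_n$ and the compensation factor $\frac{d_n}{d}$ for a few lengths, assuming $d = 64$ and a base of 512. The helper names and the chosen lengths are illustrative.

```python
import math


def compensation_factor(n, train_len=512):
    """The factor d_n / d = log_{train_len}(n)."""
    return math.log(n) / math.log(train_len)


def ideal_dim(n, d=64, train_len=512):
    """Ideal key_size d_n = d * log_{train_len}(n) suggested by the argument above."""
    return d * compensation_factor(n, train_len)


for n in (512, 1024, 4096, 512**2):
    print(n, round(ideal_dim(n), 2), round(compensation_factor(n), 3))
```

Under these assumptions, the factor is 1 at $n = 512$ and only reaches 2 (an ideal dimension of 128) at $n = 512^2$, so the correction grows slowly with length.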

In this way, we have linked the JL Lemma to Entropy-Invariant Attention. Note that this is only an intuitive, qualitative way of understanding the connection; it is difficult to formalize it rigorously in quantitative terms. In fact, there is no need for further quantification, as the JL Lemma itself is more of a qualitative conclusion.

Summary

This article constructs a simple connection between the JL Lemma and Entropy-Invariant Attention.