By 苏剑林 | May 13, 2024
A few days ago, the release of DeepSeek-V2 by High-Flyer Quant sparked heated discussions. First, the most shocking aspect was the price—1 RMB per million tokens—which is generally two orders of magnitude cheaper than existing competitors' APIs. This led some to joke, "At this price, even if it outputs gibberish, I would consider that gibberish a form of art." Secondly, according to the technical report, one of the key technologies behind this low price is the newly proposed MLA (Multi-head Latent Attention). It is an improvement over GQA, claimed to be both more efficient and effective, which has drawn extensive attention from readers.
In this article, we will trace the evolution from MHA, MQA, and GQA to MLA, with a particular focus on the design philosophy behind MLA.
MHA
MHA (Multi-Head Attention) is the form of attention proposed in the pioneering work "Attention is all you need", and it remains the fundamental building block of today's mainstream LLMs. Mathematically, MHA is equivalent to concatenating several independent single-head attention operations. Assuming the input is a sequence of (row) vectors $\boldsymbol{x}_1, \boldsymbol{x}_2, \cdots, \boldsymbol{x}_l$, where $\boldsymbol{x}_i \in \mathbb{R}^d$, MHA can be written as:
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\
\boldsymbol{k}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d\times d_k} \\
\boldsymbol{v}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d\times d_v}
\end{gathered}
\end{equation}
For simplicity, the scaling factor of the attention matrix is omitted here. In practice, common settings are $d_k = d_v = d / h$. For Llama2-7b, $d=4096, h=32, d_k = d_v = 128$; for Llama2-70b, $d=8192, h=64, d_k = d_v = 128$.
Since we only consider Causal Attention used in mainstream autoregressive LLMs, when generating tokens one by one, the newly predicted $(t+1)$-th token does not affect the previously computed $\boldsymbol{k}_{\leq t}^{(s)}, \boldsymbol{v}_{\leq t}^{(s)}$. Therefore, these results can be cached for subsequent generation steps to avoid unnecessary redundant calculations. This is known as the KV Cache.
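To make the caching concrete, here is a minimal NumPy sketch of one causal decoding step with a KV Cache, using the Llama2-7b sizes quoted above. It is illustrative only (random weights, the attention scale omitted as in the formulas); none of the names come from an actual codebase.

```python
# Minimal NumPy sketch of one causal decoding step with a KV Cache (illustrative only).
import numpy as np

d, h, d_k, d_v = 4096, 32, 128, 128              # Llama2-7b-like sizes
rng = np.random.default_rng(0)
W_q = rng.standard_normal((h, d, d_k)) * 0.02    # W_q^{(s)} for s = 1..h
W_k = rng.standard_normal((h, d, d_k)) * 0.02    # W_k^{(s)}
W_v = rng.standard_normal((h, d, d_v)) * 0.02    # W_v^{(s)}

def decode_step(x_t, k_cache, v_cache):
    """x_t: (d,) hidden state of the newest token; caches hold one entry per past token."""
    q_t = np.einsum('d,hdk->hk', x_t, W_q)                 # q_t^{(s)} for every head
    k_cache.append(np.einsum('d,hdk->hk', x_t, W_k))       # only the new k_t, v_t are computed,
    v_cache.append(np.einsum('d,hdv->hv', x_t, W_v))       # then appended to the cache
    K = np.stack(k_cache, axis=1)                          # (h, t, d_k): k_{<=t}, never recomputed
    V = np.stack(v_cache, axis=1)                          # (h, t, d_v): v_{<=t}
    scores = np.einsum('hk,htk->ht', q_t, K)
    attn = np.exp(scores - scores.max(-1, keepdims=True))
    attn /= attn.sum(-1, keepdims=True)
    o_t = np.einsum('ht,htv->hv', attn, V)
    return o_t.reshape(-1)                                 # concatenate the h heads

k_cache, v_cache = [], []
for _ in range(5):                                         # five steps with random stand-in inputs
    o_t = decode_step(rng.standard_normal(d), k_cache, v_cache)
print(len(k_cache), o_t.shape)                             # 5 (4096,): the cache grows one entry per step
```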
The subsequent MQA, GQA, and MLA are all products evolved around the theme of "how to reduce KV Cache while ensuring performance as much as possible."
Bottlenecks
A natural question is: why is reducing the size of the KV Cache so important?
As is well known, LLM inference mainly runs on GPUs, and a single GPU's video memory (VRAM) is limited. One part of it stores the model parameters and the activations of the forward pass; this part depends on the model size and is fixed once the model is chosen. The other part stores the KV Cache, which depends not only on the model size but also on the input length, i.e., it grows dynamically during inference. When the context is long enough, the KV Cache size dominates the memory usage and may exceed the capacity of a single card or even a single machine (8 cards).
The principle of model deployment on GPUs is: if it can be deployed on one card, do not span multiple cards; if it can be deployed on one machine, do not span multiple machines. This is because "intra-card bandwidth > inter-card bandwidth > inter-machine bandwidth." Due to the "bottleneck effect," the more devices a model spans during deployment, the more it is "dragged down" by inter-device communication bandwidth. In fact, even though the bandwidth between SRAM and HBM within a single H100 card has reached 3TB/s, this speed is still a bottleneck for short-context inference, let alone the much slower inter-card and inter-machine communication.
Therefore, the purpose of reducing KV Cache is to enable the inference of longer contexts on fewer devices, or to allow a larger batch size for the same context length, thereby achieving faster inference speed or higher total throughput. Ultimately, the goal is to achieve lower inference costs.
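A rough back-of-envelope calculation shows how quickly this adds up. The head count and head sizes are the Llama2-7b values quoted earlier; the layer count (32) and 2-byte bf16 storage are assumptions added here purely for the estimate.

```python
# Back-of-envelope KV Cache size for the Llama2-7b-style MHA above. The layer count (32)
# and 2-byte bf16 storage are assumptions added here for the estimate only.
h, d_k, d_v, n_layers, bytes_per_value = 32, 128, 128, 32, 2

def kv_cache_bytes(seq_len, batch=1):
    per_token = n_layers * h * (d_k + d_v) * bytes_per_value   # h keys + h values per layer
    return batch * seq_len * per_token

for L in (4096, 32768, 131072):
    print(f"{L:>7} tokens: {kv_cache_bytes(L) / 2**30:5.1f} GiB per sequence")
# ~2 GiB at 4K, ~16 GiB at 32K, ~64 GiB at 128K -- quickly rivaling the weights themselves.
```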
To understand this issue in more detail, readers can further read "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness", "A guide to LLM inference and performance", and "LLM inference speed of light". I won't expand further here (mainly because my own level is limited, and I fear that saying too much would lead to mistakes).
MQA
MQA, or "Multi-Query Attention," is a very naive attempt to reduce KV Cache, first proposed in "Fast Transformer Decoding: One Write-Head is All You Need". This paper is already from 2019, which means that long before the LLM craze, reducing KV Cache was already a topic of great concern to researchers.
The idea of MQA is simple: directly let all attention heads share the same K and V. In formula terms, this means removing the superscript ${}^{(s)}$ from all $\boldsymbol{k}$ and $\boldsymbol{v}$ in MHA:
\begin{equation}\require{cancel}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{\color{#ccc}{\smash{\bcancel{(s)}}}} ,\boldsymbol{v}_{\leq t}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}}{}^{\top}\right)\boldsymbol{v}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\
\boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}} = \boldsymbol{x}_i\boldsymbol{W}_k^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_k} \\
\boldsymbol{v}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}} = \boldsymbol{x}_i\boldsymbol{W}_v^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_v}
\end{gathered}
\end{equation}
Models using MQA include PaLM, StarCoder, Gemini, and others. Obviously, MQA directly reduces the KV Cache to $1/h$ of the original size, which is very significant and effectively the ceiling for memory saving.
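As a minimal sketch (same illustrative setup as the MHA snippet earlier), the only change is that there is now a single shared $\boldsymbol{W}_k$ and $\boldsymbol{W}_v$, so the cache holds one key and one value per token regardless of $h$:

```python
# MQA sketch: all h query heads share one K/V head, so the per-token cache is 1/h of MHA's.
import numpy as np

d, h, d_k, d_v = 4096, 32, 128, 128
rng = np.random.default_rng(0)
W_q = rng.standard_normal((h, d, d_k)) * 0.02    # per-head query projections W_q^{(s)}
W_k = rng.standard_normal((d, d_k)) * 0.02       # a single W_k (superscript (s) removed)
W_v = rng.standard_normal((d, d_v)) * 0.02       # a single W_v

X = rng.standard_normal((10, d))                 # hidden states of 10 tokens already seen
K, V = X @ W_k, X @ W_v                          # KV Cache: (10, d_k) and (10, d_v) in total

q_t = np.einsum('d,hdk->hk', X[-1], W_q)         # queries of the latest token, all heads
scores = np.einsum('hk,tk->ht', q_t, K)          # every head attends against the same K
attn = np.exp(scores - scores.max(-1, keepdims=True))
attn /= attn.sum(-1, keepdims=True)
o_t = np.einsum('ht,tv->hv', attn, V)            # and reads from the same V
print(K.shape, V.shape, o_t.shape)               # (10, 128) (10, 128) (32, 128)
```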
As for performance, the loss observed on most tasks so far appears to be limited, and supporters of MQA believe it can be recovered through further training. It is also worth noting that since MQA shares K and V, the parameter count of the Attention part is nearly halved; to keep the total parameter count unchanged, the FFN/GLU is usually widened accordingly, which also helps offset the performance loss.
GQA
However, there are concerns that MQA compresses the KV Cache too aggressively, which may hurt the model's learning efficiency and final performance. To address this, GQA (Grouped-Query Attention) was introduced as a transitional version between MHA and MQA. It comes from the 2023 paper "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints".
In hindsight, the idea of GQA is also very simple. It divides all heads into $g$ groups ($g$ must divide $h$), and each group shares the same pair of K and V. Represented mathematically:
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{\color{red}{(\lceil sg/h\rceil)}} ,\boldsymbol{v}_{\leq t}^{\color{red}{(\lceil sg/h\rceil)}}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{red}{(\lceil sg/h\rceil)}}{}^{\top}\right)\boldsymbol{v}_i^{\color{red}{(\lceil sg/h\rceil)}}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{red}{(\lceil sg/h\rceil)}}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\
\boldsymbol{k}_i^{\color{red}{(\lceil sg/h\rceil)}} = \boldsymbol{x}_i\boldsymbol{W}_k^{\color{red}{(\lceil sg/h\rceil)}}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{\color{red}{(\lceil sg/h\rceil)}}\in\mathbb{R}^{d\times d_k} \\
\boldsymbol{v}_i^{\color{red}{(\lceil sg/h\rceil)}} = \boldsymbol{x}_i\boldsymbol{W}_v^{\color{red}{(\lceil sg/h\rceil)}}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{\color{red}{(\lceil sg/h\rceil)}}\in\mathbb{R}^{d\times d_v}
\end{gathered}
\end{equation}
Here $\lceil\cdot\rceil$ is the ceiling operator. GQA provides a natural transition from MHA to MQA: when $g=h$ it is MHA; when $g=1$ it is MQA; and when $1 < g < h$ it compresses the KV Cache to $g/h$ of the original. The compression rate is not as high as MQA's, but it offers greater freedom and better-guaranteed performance. The best-known users of GQA are Meta's open-source Llama2-70B and the entire Llama3 series; other models using GQA include TigerBot, DeepSeek-V1, StarCoder2, Yi, ChatGLM2, ChatGLM3, etc. Overall, more models use GQA than MQA (although ChatGLM describes itself as MQA, it is actually GQA with $g=2$).
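In code, the grouping amounts to an index map from query heads to KV groups. The sketch below is illustrative only (0-indexed heads, Llama2-70b-like sizes) and makes the transition explicit: $g=h$ recovers MHA and $g=1$ recovers MQA.

```python
# GQA sketch: head s (0-indexed here) reads KV group s // (h // g), the 0-indexed
# counterpart of the ceil(s*g/h) superscript in the formulas above.
import numpy as np

d, h, g, d_k, d_v = 8192, 64, 8, 128, 128        # Llama2-70b-like sizes
rng = np.random.default_rng(0)
W_q = rng.standard_normal((h, d, d_k)) * 0.02
W_k = rng.standard_normal((g, d, d_k)) * 0.02    # only g distinct key projections
W_v = rng.standard_normal((g, d, d_v)) * 0.02    # only g distinct value projections

X = rng.standard_normal((10, d))                 # hidden states of 10 tokens already seen
K = np.einsum('td,gdk->gtk', X, W_k)             # KV Cache: (g, 10, d_k) ...
V = np.einsum('td,gdv->gtv', X, W_v)             # ... and (g, 10, d_v), i.e. g/h of MHA's

q_t = np.einsum('d,hdk->hk', X[-1], W_q)         # latest token's queries
group = np.arange(h) // (h // g)                 # head -> group map; g=h gives MHA, g=1 gives MQA
scores = np.einsum('hk,htk->ht', q_t, K[group])  # K[group] replicates each group h/g times
attn = np.exp(scores - scores.max(-1, keepdims=True))
attn /= attn.sum(-1, keepdims=True)
o_t = np.einsum('ht,htv->hv', attn, V[group])
print(K.shape, o_t.shape)                        # (8, 10, 128) (64, 128)
```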
In Llama2/3-70B, $g=8$, and other GQA models of a similar scale largely keep this setting. This is not accidental; it too is motivated by inference efficiency. A model of 70B scale cannot be deployed on a single card (A100/H100 80G) unless it is aggressively quantized, so the next best option is single-machine deployment, and a machine typically holds 8 cards. As we just said, each Attention Head is computed independently and then concatenated; with $g=8$, each card can host exactly the group of Attention Heads that shares one set of K and V. This preserves as much K/V diversity as possible while minimizing inter-card communication.
MLA
With the foundations of MHA, MQA, and GQA, understanding MLA (Multi-head Latent Attention) becomes relatively easier. DeepSeek-V2's technical report introduces MLA from the perspective of low-rank projection, which led some readers to ask, "Why did it take until MLA to use low-rank decomposition for KV Cache if LoRA has been around for so long?"
However, I believe the perspective of low-rank projection is not the closest to the essence. If we speak of low-rank projection, in fact, as long as we stack all Ks and Vs of GQA together, we will find that GQA is also equivalent to performing a low-rank projection:
\begin{equation}\underbrace{\left[\boldsymbol{k}_i^{(1)},\cdots,\boldsymbol{k}_i^{(g)},\boldsymbol{v}_i^{(1)},\cdots,\boldsymbol{v}_i^{(g)}\right]}_{\boldsymbol{c}_i\in\mathbb{R}^{g(d_k+d_v)}} = \boldsymbol{x}_i \underbrace{\left[\boldsymbol{W}_k^{(1)},\cdots,\boldsymbol{W}_k^{(g)},\boldsymbol{W}_v^{(1)},\cdots,\boldsymbol{W}_v^{(g)}\right]}_{\boldsymbol{W}_c\in\mathbb{R}^{d\times g(d_k+d_v)}}\end{equation}
Here we concatenate all $\boldsymbol{k}_i^{(s)}, \boldsymbol{v}_i^{(s)}$ together and denote them as $\boldsymbol{c}_i$. The corresponding projection matrices are also concatenated and denoted as $\boldsymbol{W}_c$. Note that generally $d_c = g(d_k+d_v) < d$, so the transformation from $\boldsymbol{x}_i$ to $\boldsymbol{c}_i$ is indeed a low-rank projection. Thus, the essential improvement of MLA is not the low-rank projection itself, but the work done after the low-rank projection.
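This equivalence is easy to check numerically; the snippet below simply verifies that stacking the per-group projections is the same as one projection to $\mathbb{R}^{d_c}$ with $d_c = g(d_k + d_v) < d$ (sizes are the Llama2-70b-like ones used above).

```python
# Check: stacking GQA's K/V projections is one low-rank projection x_i -> c_i.
import numpy as np

d, g, d_k, d_v = 8192, 8, 128, 128
rng = np.random.default_rng(0)
W_k = [rng.standard_normal((d, d_k)) for _ in range(g)]   # W_k^{(1)}, ..., W_k^{(g)}
W_v = [rng.standard_normal((d, d_v)) for _ in range(g)]   # W_v^{(1)}, ..., W_v^{(g)}
W_c = np.concatenate(W_k + W_v, axis=1)                   # (d, g*(d_k+d_v)) = (8192, 2048)

x_i = rng.standard_normal(d)
c_concat = np.concatenate([x_i @ W for W in W_k + W_v])   # [k^(1),...,k^(g),v^(1),...,v^(g)]
c_single = x_i @ W_c                                      # one projection to d_c = 2048 < d
print(np.allclose(c_concat, c_single), W_c.shape)         # True (8192, 2048)
```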
Part 1
What does GQA do after this projection? First, it splits the vector into two halves to serve as K and V respectively; then each half is further divided into $g$ parts, and each part is replicated $h/g$ times to "make up" the K and V needed for all $h$ Attention Heads. Splitting and replicating are just simple linear transformations, so MLA's first idea is to replace them with general linear transformations to increase the model's capacity:
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} , \boldsymbol{v}_{\leq t}^{(s)}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\
\boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_c\times d_k} \\
\boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt]
\boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c}
\end{gathered}
\end{equation}
Theoretically, this can increase the model's capacity. But remember that GQA's whole purpose is to reduce the KV Cache. To save computation and communication, we generally cache the projected $\boldsymbol{k}_i, \boldsymbol{v}_i$ rather than the pre-projection $\boldsymbol{c}_i$ or $\boldsymbol{x}_i$. Under the formulation above, however, the distinct projection matrices make every K and V head different again, so the KV Cache grows back to the size of MHA's, defeating the original purpose of GQA.
MLA's observation is that, by exploiting the specific form of dot-product attention, this problem can be bypassed with a simple yet ingenious identity transformation. Training proceeds exactly as usual (there is little room for optimization there); it is in the inference stage that we make use of:
\begin{equation}\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top} = \left(\boldsymbol{x}_t\boldsymbol{W}_q^{(s)}\right) \left(\boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\right){}^{\top} = \boldsymbol{x}_t\left(\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}\right)\boldsymbol{c}_i^{\top} \end{equation}
This means that in the inference stage, we can merge $\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}$ as the projection matrix for Q. Then $\boldsymbol{c}_i$ takes the place of the original $\boldsymbol{k}_i$. Similarly, after $\boldsymbol{o}_t$ there is another projection matrix, so the $\boldsymbol{W}_v^{(s)}$ in $\boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}$ can also be absorbed into the subsequent projection matrix. Thus, $\boldsymbol{v}_i$ can equivalently be replaced by $\boldsymbol{c}_i$. In other words, at this point, the KV Cache only needs to store all $\boldsymbol{c}_i$ instead of all $\boldsymbol{k}_i^{(s)}, \boldsymbol{v}_i^{(s)}$. Note that $\boldsymbol{c}_i$ is independent of ${}^{(s)}$, meaning it is shared across all heads. Thus, MLA can be identically transformed into an MQA during inference.
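The identity is easy to verify numerically. The sketch below uses illustrative sizes and shows that, once $\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}$ is merged, the attention scores can be computed from the cached $\boldsymbol{c}_i$ alone.

```python
# Check of the identity: (x_t W_q)(c_i W_k)^T == x_t (W_q W_k^T) c_i^T, so at inference
# time the cache can hold c_i instead of the per-head k_i. Sizes are illustrative.
import numpy as np

d, d_c, d_k = 5120, 512, 128
rng = np.random.default_rng(0)
W_q = rng.standard_normal((d, d_k))              # W_q^{(s)} for one head
W_k = rng.standard_normal((d_c, d_k))            # W_k^{(s)} for the same head
x_t = rng.standard_normal(d)                     # current query-side input
C = rng.standard_normal((10, d_c))               # cached c_i for 10 past tokens

scores_train = (x_t @ W_q) @ (C @ W_k).T         # "training view": materialize k_i = c_i W_k
W_merged = W_q @ W_k.T                           # absorb W_k into the query projection
scores_infer = (x_t @ W_merged) @ C.T            # "inference view": only c_i is needed
print(np.allclose(scores_train, scores_infer))   # True (up to floating-point error)
```

In practice the merged product is usually not materialized explicitly, as the supplementary notes below point out, but the cached quantity is the same either way.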
Re-emphasizing: the theme of this article has always been reducing KV Cache. What has MLA achieved so far? The answer is that it has enhanced GQA's capacity through different projection matrices while maintaining the same KV Cache size during inference. Conversely, if we only need similar capability to GQA, could we then further reduce the KV Cache? In other words, $d_c$ doesn't need to be $g(d_k+d_v)$ but can take a smaller value (DeepSeek-V2 set it to 512). This further compression of KV Cache is the core idea of MLA.
Supplementary Notes:
1. The identity transformation that merges $\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}$ into one matrix theoretically only holds under infinite precision. In practice, if we use single precision, especially BF16, the precision loss after transformation is often quite obvious and can accumulate through multiple layers to a noticeable extent;
2. In practice, we generally do not calculate Q according to $\boldsymbol{x}_t\left(\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}\right)$, but rather calculate it as $\left(\boldsymbol{x}_t\boldsymbol{W}_q^{(s)}\right)\boldsymbol{W}_k^{(s)}{}^{\top}$. Although this is serial, under the low-rank assumption, the computation is smaller and the theoretical precision loss is also smaller. However, for the purpose of this article, we still describe it as merging $\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}$ into one matrix.
Part 2
Everything seems perfect so far, as if an ideal design that is both efficient and effective is within reach. But on closer inspection, the MLA described up to this point has an unavoidable flaw: it is incompatible with RoPE (Rotary Position Embedding).
As we just said, the key step that lets MLA keep the same KV Cache size as GQA is "merging $\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}$ into a single (position-independent) matrix that serves as the projection matrix for Q". With RoPE, however, this step is no longer possible. RoPE applies a position-dependent $d_k\times d_k$ block-diagonal matrix $\boldsymbol{\mathcal{R}}_m$ satisfying $\boldsymbol{\mathcal{R}}_m\boldsymbol{\mathcal{R}}_n^{\top}=\boldsymbol{\mathcal{R}}_{m-n}$, so adding RoPE to MLA inserts an extra term $\boldsymbol{\mathcal{R}}_{t-i}$ between $\boldsymbol{W}_q^{(s)}$ and $\boldsymbol{W}_k^{(s)}{}^{\top}$:
\begin{equation}
\begin{gathered}
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\quad,\quad\boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i} \\[10pt]
\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top} = \left(\boldsymbol{x}_t\boldsymbol{W}_q^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_t}\right) \left(\boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right){}^{\top} = \boldsymbol{x}_t\left(\boldsymbol{W}_q^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_{t-i}}\boldsymbol{W}_k^{(s)}{}^{\top}\right)\boldsymbol{c}_i^{\top}
\end{gathered}
\end{equation}
Here $\boldsymbol{W}_q^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_{t-i}}\boldsymbol{W}_k^{(s)}{}^{\top}$ cannot be merged into a fixed projection matrix (it depends on the position difference $t-i$), so MLA's identity-transformation idea cannot be directly combined with RoPE.
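The snippet below builds the block-diagonal RoPE matrices for a toy head size, checks the relative-position identity, and shows that the would-be merged matrix changes with the offset $t-i$ (the matrix sizes are illustrative).

```python
# RoPE as a block-diagonal rotation: R_m R_n^T = R_{m-n}, so the would-be merged matrix
# W_q R_{t-i} W_k^T changes with the offset t - i and cannot be precomputed once.
import numpy as np

def rope_matrix(pos, d_k, base=10000.0):
    R = np.zeros((d_k, d_k))
    for j in range(d_k // 2):                     # one 2x2 rotation block per pair of dims
        theta = pos * base ** (-2 * j / d_k)
        c, s = np.cos(theta), np.sin(theta)
        R[2*j:2*j+2, 2*j:2*j+2] = [[c, s], [-s, c]]
    return R

d_k, d = 8, 16                                    # tiny illustrative sizes
print(np.allclose(rope_matrix(5, d_k) @ rope_matrix(3, d_k).T,
                  rope_matrix(5 - 3, d_k)))       # True: only the relative position survives

rng = np.random.default_rng(0)
W_q, W_k = rng.standard_normal((d, d_k)), rng.standard_normal((d, d_k))
M = lambda t, i: W_q @ rope_matrix(t - i, d_k) @ W_k.T
print(np.allclose(M(5, 1), M(7, 3)), np.allclose(M(5, 1), M(5, 2)))  # True False
```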
I had the honor of discussing this issue with the DeepSeek team some time ago, but the problem is quite fundamental, and at the time I could not offer any truly effective suggestion. The simplest way out would be to abandon RoPE in favor of positional encodings based on an Attention Bias, such as ALIBI, but DeepSeek's experiments showed these were noticeably inferior to RoPE (note that it is not that MLA cannot use RoPE, but that with RoPE there is no identity-transformation trick to reduce the KV Cache). I also suggested Sandwich, which, unlike ALIBI, does not decay monotonically to negative infinity, and guessed it might work better, but that felt like treating the symptom rather than the root cause. Another compromise is to change the input of $\boldsymbol{q}_i$ to $\boldsymbol{c}_i$ and then apply RoPE after $\boldsymbol{c}_i$, i.e.,
\begin{equation}\boldsymbol{q}_i^{(s)} = \boldsymbol{c}_i\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\boldsymbol{W}_q^{(s)},\quad\boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\boldsymbol{W}_k^{(s)}\end{equation}
This way $\boldsymbol{\mathcal{R}}_i$ can be absorbed into $\boldsymbol{c}_i$. However, the $\boldsymbol{\mathcal{R}}_m\boldsymbol{\mathcal{R}}_n^{\top}=\boldsymbol{\mathcal{R}}_{m-n}$ cancellation then never appears in the attention score: RoPE would no longer realize relative position through absolute position, but would merely attach absolute position information to Q and K, leaving the model to work out on its own how to extract relative positions from it.
The MLA that was finally released adopts a hybrid approach: each Attention Head's Q and K gain $d_r$ extra dimensions that carry RoPE, and the extra dimensions of K are shared across all heads:
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} , \boldsymbol{v}_{\leq t}^{(s)}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \left[\boldsymbol{x}_i\boldsymbol{W}_{qc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{qr}^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k + d_r},\quad \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d\times d_r}\\
\boldsymbol{k}_i^{(s)} = \left[\boldsymbol{c}_i\boldsymbol{W}_{kc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k+d_r},\quad \boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k}, \boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\
\boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt]
\boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c}
\end{gathered}
\end{equation}
In this way, the dimensions without RoPE can repeat the operations of "Part 1". During inference, the KV Cache only needs to store $\boldsymbol{c}_i$. The new dimensions with RoPE are used to provide positional information, and since they are shared among all heads, they only add $d_r$ dimensions to the K Cache. The original paper took $d_r = d_k / 2 = 64$, which is not a large increase compared to the original $d_c = 512$.
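The sketch below spells out what ends up in the cache per token under this hybrid scheme. The values $d_c=512$ and $d_r=64$ are the ones quoted above; the remaining sizes and the projection names ($\boldsymbol{W}_{kc}$, $\boldsymbol{W}_{kr}$, etc.) just mirror the formulas and are illustrative.

```python
# What MLA caches per token under the hybrid scheme: the shared latent c_i (d_c dims)
# plus one RoPE-carrying key slice shared by all heads (d_r dims). Names are illustrative.
import numpy as np

d, d_c, d_k, d_r, h = 5120, 512, 128, 64, 128
rng = np.random.default_rng(0)
W_c  = rng.standard_normal((d, d_c)) * 0.02       # x_i -> c_i (cached, shared by all heads)
W_kr = rng.standard_normal((d, d_r)) * 0.02       # x_i -> RoPE part of k_i (also shared)
W_kc = rng.standard_normal((h, d_c, d_k)) * 0.02  # per-head "content" key maps (not cached)

x_i = rng.standard_normal(d)
c_i  = x_i @ W_c                                   # goes into the KV Cache
kr_i = x_i @ W_kr                                  # goes into the K Cache (RoPE rotation omitted here)
print(c_i.size + kr_i.size)                        # 576 = d_c + d_r per token, independent of h

# At decode time a head's full key could be reconstructed (or, better, absorbed into q):
k_i_head0 = np.concatenate([c_i @ W_kc[0], kr_i])  # [c_i W_kc^{(s)}, x_i W_kr R_i] minus the R_i
print(k_i_head0.shape)                             # (192,) = d_k + d_r
```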
Part 3
Finally, there is a detail: the final version of MLA also changes the input of Q to a low-rank projection form. This is unrelated to reducing KV Cache; it is primarily used to reduce the memory occupied by parameters and their gradients (the original paper said activation values, which I personally don't quite understand) during training:
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} , \boldsymbol{v}_{\leq t}^{(s)}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \left[\boldsymbol{c}_i'\boldsymbol{W}_{qc}^{(s)}, \boldsymbol{c}_i'\boldsymbol{W}_{qr}^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k + d_r},\quad \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c'\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c'\times d_r}\\
\boldsymbol{k}_i^{(s)} = \left[\boldsymbol{c}_i\boldsymbol{W}_{kc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k+d_r},\quad \boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k}, \boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\
\boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt]
\boldsymbol{c}_i' = \boldsymbol{x}_i \boldsymbol{W}_c'\in\mathbb{R}^{d_c'},\quad \boldsymbol{W}_c'\in\mathbb{R}^{d\times d_c'} \\
\boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\
\end{gathered}
\label{eq:mla-mha}\end{equation}
Note that in the second term of $\boldsymbol{k}_i^{(s)}$, the part with RoPE uses $\boldsymbol{x}_i$ as input rather than $\boldsymbol{c}_i$. This follows the setting in the original paper and is not a typo. The original paper took $d_c' = 1536$, which is different from $d_c = 512$. At the same time, let's place the MHA with RoPE below for comparison:
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} , \boldsymbol{v}_{\leq t}^{(s)}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\in\mathbb{R}^{d_k},\quad\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\
\boldsymbol{k}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_k^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\in\mathbb{R}^{d_k},\quad\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d\times d_k} \\
\boldsymbol{v}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d\times d_v}
\end{gathered}
\end{equation}
We can see that in the training stage, aside from an extra low-rank projection and RoPE being applied only to certain dimensions, MLA is basically identical to an MHA where the Q and K Head Size is changed from $d_k$ to $d_k + d_r$.
In the decoding stage, MLA is converted into an MQA form:
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}\boldsymbol{W}_v^{(1)}, \boldsymbol{o}_t^{(2)}\boldsymbol{W}_v^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\boldsymbol{W}_v^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{\color{#ccc}{\smash{\bcancel{(s)}}}} , \boldsymbol{c}_{\leq t}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}}{}^{\top}\right)\boldsymbol{c}_i}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \left[\boldsymbol{c}_i'\boldsymbol{W}_{qc}^{(s)}\boldsymbol{W}_{kc}^{(s)}{}^{\top}, \boldsymbol{c}_i'\boldsymbol{W}_{qr}^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_c + d_r}\\
\boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}} = \left[\boldsymbol{c}_i, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_c+d_r}\\
\boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c'\times d_k}, \boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k}, \boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c'\times d_r}, \boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\[10pt]
\boldsymbol{c}_i' = \boldsymbol{x}_i \boldsymbol{W}_c'\in\mathbb{R}^{d_c'},\quad \boldsymbol{W}_c'\in\mathbb{R}^{d\times d_c'} \\
\boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\
\end{gathered}
\label{eq:mla-mqa}\end{equation}
At this point, the Head Size of Q and K becomes $d_c + d_r$, and the Head Size of V becomes $d_c$; under the original paper's settings these are roughly 4 times $d_k$ and $d_v$. So although MLA's conversion at the decoding stage effectively reduces the KV Cache, it increases the computation per decoding step.
Why can it still improve inference efficiency? This goes back to the discussion in the "Bottleneck" section. We can divide LLM inference into two parts: the generation of the first token (Prefill) and the generation of each subsequent token (Generation). The Prefill stage involves parallel calculation for all tokens of the input and then saving the corresponding KV Cache; this stage is a bottleneck for computation, bandwidth, and memory. We can use MLA's MHA form $\eqref{eq:mla-mha}$ for this. However, in the Generation stage, since only one token is calculated per step, it is more of a bandwidth and memory bottleneck. At this time, we can use MLA's MQA form $\eqref{eq:mla-mqa}$, thereby significantly increasing the speed of Generation.
There is another detail that fully reflects this characteristic. General LLM architecture parameters satisfy $h \times d_k = d$, i.e., num_heads * head_size = hidden_size. But DeepSeek-V2 is different: $d_k=128, d=5120$, but $h=128$, which is three times the usual setting! This is because the KV Cache size of MLA is independent of $h$. Increasing $h$ only increases computation and improves model capability, but does not increase KV Cache, and thus does not bring about speed bottlenecks.
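Plugging in the sizes quoted in this article gives a feel for the savings; the GQA row uses $g=8$ purely as a hypothetical baseline for comparison, not as a DeepSeek-V2 setting.

```python
# Per-token, per-layer KV Cache elements under the sizes quoted in this article.
# The GQA row uses g = 8 purely as a hypothetical baseline, not a DeepSeek-V2 setting.
h, g, d_k, d_v, d_c, d_r = 128, 8, 128, 128, 512, 64

mha = h * (d_k + d_v)      # every head caches its own k_i and v_i
gqa = g * (d_k + d_v)      # g shared KV groups
mla = d_c + d_r            # one shared c_i plus the shared RoPE key slice
print(mha, gqa, mla)       # 32768 2048 576
print(f"MLA is {gqa / mla:.1f}x smaller than GQA(g=8) and {mha / mla:.1f}x smaller than MHA")
```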
Summary
This article briefly summarized the evolution of Multi-Head Attention, especially the shift in concepts from MHA to MQA, GQA, and finally to MLA. In this article, MLA is regarded as a generalization of GQA, replacing GQA's splitting and repeating with projection matrices, and introducing an identity transformation trick that can further compress the KV Cache while using a hybrid method to remain compatible with RoPE. Overall, MLA is a very practical variant of Attention.