Beyond MuP: 3. Special Cases, Special Treatment

By 苏剑林 | March 02, 2026

After several blog posts on the subject, many readers are likely familiar with the Muon optimizer—even if the theoretical details aren't perfectly clear, you probably have the impression that it is an "optimizer specifically customized for matrix parameters." However, this statement isn't entirely accurate. For example, for the input-side Embedding layer and the output-side LM Head, although their parameters are matrices, they are not suitable for Muon (see "Muon Optimizer Guide: Quick Start and Key Details").

Why should they be "treated differently"? This article will follow the three stability indicators proposed in the first post to explore the initialization patterns of different types of layers and their corresponding steepest descent directions, thereby answering this question.

Previous Review

In the first article "Beyond MuP: 1. Three Characteristics of a Good Model", we proposed three stability indicators:

\begin{align} &\text{Forward Stability:}\quad\max_{\boldsymbol{x}} \Vert\boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega})\Vert_{RMS}\\[5pt] &\text{Dependency Stability:}\quad\max_{\boldsymbol{x}_1,\boldsymbol{x}_2} \Vert\boldsymbol{f}(\boldsymbol{x}_1;\boldsymbol{\omega}) - \boldsymbol{f}(\boldsymbol{x}_2;\boldsymbol{\omega})\Vert_{RMS}\\[5pt] &\text{Update Stability:}\quad\max_{\boldsymbol{x}} \Vert\boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega}+\Delta\boldsymbol{\omega}) - \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega})\Vert_{RMS}\end{align}

The three indicators share a unified format: take the RMS norm of the output, then take the $\max$ over the input. Here $\boldsymbol{x}$ represents the input, $\boldsymbol{\omega}$ represents the parameters, and $\boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega})$ can represent a layer, a block, or even the entire model, depending on our ability to solve for the $\max$.

Since the range of $\boldsymbol{x}$ is not restricted, the maximum value does not always exist. Therefore, sometimes we need to add certain operations to the model, which in turn guides the model's design. For instance, in the previous article "Beyond MuP: 2. Linear Layers and Steepest Descent", to calculate the stability indicators for a linear layer, we added Input Normalization (In Norm). Additionally, combined with the idea of steepest descent, we reproduced the derivation of the Muon optimizer.

Steepest descent is not a new concept; it answers the question of "what optimizer to use given stability indicators." The core contribution of the "Beyond MuP" series is answering the question of "what stability indicators should be used," providing a calculation formula for stability indicators applicable to any layer.

The Embedding Layer

Now we consider the Embedding layer, which is arguably the simplest layer. The input is an index $i$, and the output is the corresponding vector, i.e., $\boldsymbol{f}(i; \boldsymbol{E}) = \boldsymbol{E}_i$, where $\boldsymbol{E}$ is a $|V| \times d$ matrix and $\boldsymbol{E}_i \triangleq \boldsymbol{E}_{i,:}$ denotes the $i$-th row of $\boldsymbol{E}$. It is easy to calculate:

\begin{align} &\text{Forward Stability:}\quad\max_i \Vert\boldsymbol{E}_i\Vert_{RMS} = \Theta(1)\\[5pt] &\text{Dependency Stability:}\quad\max_{i,j} \Vert\boldsymbol{E}_i - \boldsymbol{E}_j\Vert_{RMS} = \Theta(1) \\[5pt] &\text{Update Stability:}\quad\max_i \Vert \Delta \boldsymbol{E}_i\Vert_{RMS} = \Theta(1) \label{eq:ec3} \end{align}

Note that $\max_{i,j} \Vert\boldsymbol{E}_i - \boldsymbol{E}_j\Vert_{RMS} \leq 2 \max_i \Vert\boldsymbol{E}_i\Vert_{RMS}$, so they are essentially both the maximum row norm of $\boldsymbol{E}$ or $\Delta\boldsymbol{E}$ (multiplied by $1/\sqrt{d}$). Forward stability and dependency stability are only used to guide initialization; they tell us to initialize $\boldsymbol{E}$ with zero mean and $\Theta(1)$ variance.

As for update stability, equation \eqref{eq:ec3} tells us that although it is a matrix, the metric for "stability" for the Embedding layer should not be the spectral norm, but rather the maximum row norm. This results in its steepest descent direction not being Muon. To find the steepest descent for the Embedding layer, we need to solve the optimization problem:

\begin{equation}\min_{\Delta \boldsymbol{E}} \langle\boldsymbol{G},\Delta\boldsymbol{E}\rangle \qquad \text{s.t.}\qquad \max_i \underbrace{\Vert\Delta\boldsymbol{E}_i\Vert_{RMS}}_{\Vert\Delta\boldsymbol{E}_i\Vert_2/\sqrt{d}}\leq\eta\end{equation}

This problem is not difficult to solve; we just need to use the Cauchy-Schwarz inequality:

\begin{equation}\langle\boldsymbol{G},\Delta\boldsymbol{E}\rangle = \sum_{i=1}^{|V|}\langle\boldsymbol{G}_i,\Delta\boldsymbol{E}_i\rangle \geq -\sum_{i=1}^{|V|}\Vert\boldsymbol{G}_i\Vert_2 \times \Vert\Delta\boldsymbol{E}_i\Vert_2 \geq -\eta\sqrt{d}\sum_{i=1}^{|V|}\Vert\boldsymbol{G}_i\Vert_2\end{equation}

Equality holds when $\Delta\boldsymbol{E}_i = - \eta\boldsymbol{G}_i / \Vert\boldsymbol{G}_i\Vert_{RMS}$. That is to say, the steepest descent applicable to the Embedding layer is to perform row-wise RMS Norm on the gradient (Normalized SGD).
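
As a concrete illustration, here is a minimal NumPy sketch of this row-wise Normalized SGD update; the function name, the sizes, and the small $\epsilon$ guard against zero gradients are our additions, not part of the derivation:

```python
import numpy as np

def embedding_update(E, G, eta=0.01, eps=1e-8):
    """Row-wise Normalized SGD: Delta E_i = -eta * G_i / ||G_i||_RMS."""
    d = E.shape[1]
    # RMS norm of each gradient row: ||G_i||_2 / sqrt(d)
    row_rms = np.linalg.norm(G, axis=1, keepdims=True) / np.sqrt(d) + eps
    return E - eta * G / row_rms

rng = np.random.default_rng(0)
V, d = 100, 64
E = rng.normal(0.0, 1.0, size=(V, d))   # zero mean, Theta(1) variance init
G = rng.normal(size=(V, d))
E_new = embedding_update(E, G, eta=0.01)
# every row of the update has RMS norm ~ eta, as the constraint requires
delta_rms = np.linalg.norm(E_new - E, axis=1) / np.sqrt(d)
print(delta_rms.max())  # ≈ 0.01
```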

The Output Head

Next, let's look at the LM Head. On the surface, this is also a linear layer: the input is $\boldsymbol{x}\in\mathbb{R}^d$, the weights are $\boldsymbol{W}\in\mathbb{R}^{d\times |V|}$, and the output is $\boldsymbol{x}\boldsymbol{W}\in\mathbb{R}^{|V|}$. $\boldsymbol{x}$ usually also carries an RMS Norm. In every respect, it looks like a linear layer, so why is it not suitable for Muon?

Responsible for the Loss

The answer is that the LM Head needs to be "responsible" for the Loss.

It is important to note that steepest descent serves training. From an inference perspective, the model takes several tokens and predicts the next one; from a training perspective, however, the true "model" takes both the tokens and the next token as input and outputs the Loss. That is to say, the data and the labels are both inputs, and the true output is the Loss. For earlier layers we can ignore the labels and the Loss, but the LM Head, as the final layer "adjacent" to the Loss, must account for their influence.

So, the input of the LM Head becomes $\boldsymbol{x}$ and the index of the next token $t$, and the output becomes the cross-entropy loss, namely:

\begin{equation}\ell(\boldsymbol{x},t;\boldsymbol{W}) = \log\sum_{i=1}^{|V|} e^{\langle \boldsymbol{x},\boldsymbol{w}_i\rangle} - \langle \boldsymbol{x},\boldsymbol{w}_t\rangle = \log\sum_{i=1}^{|V|} e^{\langle \boldsymbol{x},\boldsymbol{w}_i - \boldsymbol{w}_t\rangle}\end{equation}

where $\boldsymbol{w}_i\triangleq \boldsymbol{W}_{:, i}$ is the $i$-th column of $\boldsymbol{W}$. Since $\ell$ is a complicated nonlinear function of $\boldsymbol{x}, t, \boldsymbol{W}$, its three indicators cannot be computed exactly; our goal is to find a reasonably tight upper bound.
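
The equality between the two forms of the loss above is just the shift invariance of log-sum-exp, which a quick NumPy check confirms (the sizes are arbitrary):

```python
import numpy as np

def lse(v):
    """Numerically stable log-sum-exp."""
    m = v.max()
    return m + np.log(np.exp(v - m).sum())

rng = np.random.default_rng(0)
d, V, t = 16, 50, 3
x = rng.normal(size=d)
W = rng.normal(size=(d, V))
logits = x @ W
loss1 = lse(logits) - logits[t]   # log-sum-exp minus the target logit
loss2 = lse(logits - logits[t])   # log-sum-exp of the shifted logits
print(loss1, loss2)  # the two forms agree
```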

Forward Stability

First is the relatively simple forward stability. By a simple relaxation:

\begin{equation}\begin{aligned} \ell(\boldsymbol{x},t;\boldsymbol{W}) = \log\sum_{i=1}^{|V|} e^{\langle \boldsymbol{x},\boldsymbol{w}_i - \boldsymbol{w}_t\rangle} \leq&\, \log \left(|V| \max_i e^{\langle \boldsymbol{x},\boldsymbol{w}_i - \boldsymbol{w}_t\rangle}\right) \\ =&\, \log |V| + \max_i \langle \boldsymbol{x},\boldsymbol{w}_i - \boldsymbol{w}_t\rangle \\ \leq &\, \log |V| + \max_i \Vert\boldsymbol{x}\Vert_2 \Vert\boldsymbol{w}_i - \boldsymbol{w}_t\Vert_2 \end{aligned}\end{equation}

Therefore:

\begin{equation}\begin{aligned} \text{Forward Stability:}\quad\max_{t, \Vert\boldsymbol{x}\Vert_{RMS}=1} \ell(\boldsymbol{x},t;\boldsymbol{W}) \leq&\, \log |V| + d\max_{i,t} \Vert\boldsymbol{w}_i - \boldsymbol{w}_t\Vert_{RMS} \\ \leq&\, \log |V| + 2d\max_i \Vert\boldsymbol{w}_i\Vert_{RMS} \end{aligned}\end{equation}

Dropping the constant $\log|V|$ turns this upper bound into a lower bound as well, so the bound is asymptotically tight. To make it $\Theta(1)$, the initialization variance of the LM Head should be chosen as $\Theta(1/d^2)$.
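
A quick numerical sanity check of this scaling, under the assumption of Gaussian initialization with standard deviation $1/d$ (i.e., variance $1/d^2$):

```python
import numpy as np

rng = np.random.default_rng(0)
V = 1000
for d in (64, 256, 1024):
    W = rng.normal(0.0, 1.0 / d, size=(d, V))   # std 1/d, i.e. variance 1/d^2
    x = rng.normal(size=d)
    x = x * np.sqrt(d) / np.linalg.norm(x)      # normalize so ||x||_RMS = 1
    logits = x @ W
    # stable cross-entropy with target t = 0: logsumexp(logits) - logits[0]
    m = logits.max()
    loss = m + np.log(np.exp(logits - m).sum()) - logits[0]
    print(d, loss - np.log(V))  # stays O(1) as d grows
```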

An Important Inequality

As for the remaining two indicators, since they involve differences, the calculation becomes more complex. We first prove an inequality we will need:

\begin{equation}\left|\log\sum_{i=1}^n e^{a_i} - \log\sum_{i=1}^n e^{b_i}\right| \leq \max_i |a_i - b_i|\label{leq:lse-ab}\end{equation}

The proof is not difficult but requires a trick: denote the right side as $M$. Then, by the monotonicity of $\log, \sum, \exp$:

\begin{equation}\log\sum_{i=1}^n e^{a_i} = \log\sum_{i=1}^n e^{(a_i - b_i)+b_i} \leq \log\sum_{i=1}^n e^{M + b_i} = M + \log\sum_{i=1}^n e^{b_i}\end{equation}

This proves that:

\begin{equation}\log\sum_{i=1}^n e^{a_i} - \log\sum_{i=1}^n e^{b_i} \leq M\end{equation}

By symmetry, swapping $a_i$ and $b_i$ also holds, thus proving the original inequality.
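
The inequality is also easy to verify numerically; a throwaway NumPy check (the sizes and the tiny rounding slack are our choices):

```python
import numpy as np

def lse(v):
    """Numerically stable log-sum-exp."""
    m = v.max()
    return m + np.log(np.exp(v - m).sum())

rng = np.random.default_rng(0)
violations = 0
for _ in range(1000):
    a = 5 * rng.normal(size=10)
    b = 5 * rng.normal(size=10)
    # |lse(a) - lse(b)| <= max_i |a_i - b_i|, up to floating-point rounding
    if abs(lse(a) - lse(b)) > np.abs(a - b).max() + 1e-12:
        violations += 1
print(violations)  # → 0
```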

Dependency Stability

Using inequality \eqref{leq:lse-ab} and Cauchy-Schwarz:

\begin{equation}\begin{aligned} |\ell(\boldsymbol{x}_1,t_1;\boldsymbol{W}) - \ell(\boldsymbol{x}_2,t_2;\boldsymbol{W})| \leq&\, \max_i |\langle \boldsymbol{x}_1,\boldsymbol{w}_i - \boldsymbol{w}_{t_1}\rangle - \langle \boldsymbol{x}_2,\boldsymbol{w}_i - \boldsymbol{w}_{t_2}\rangle| \\ \leq&\, \max_i (\Vert\boldsymbol{x}_1\Vert_2 \Vert\boldsymbol{w}_i - \boldsymbol{w}_{t_1}\Vert_2 + \Vert\boldsymbol{x}_2\Vert_2 \Vert\boldsymbol{w}_i - \boldsymbol{w}_{t_2}\Vert_2) \\ =&\, d\max_i (\Vert\boldsymbol{x}_1\Vert_{RMS} \Vert\boldsymbol{w}_i - \boldsymbol{w}_{t_1}\Vert_{RMS} + \Vert\boldsymbol{x}_2\Vert_{RMS} \Vert\boldsymbol{w}_i - \boldsymbol{w}_{t_2}\Vert_{RMS}) \\ \end{aligned}\end{equation}

Therefore:

\begin{equation}\begin{aligned} \text{Dependency Stability:}\quad\max_{\begin{gathered}t_1, t_2, \\ \Vert\boldsymbol{x}_1\Vert_{RMS}=1 \\ \Vert\boldsymbol{x}_2\Vert_{RMS}=1\end{gathered}} |\ell(\boldsymbol{x}_1,t_1;\boldsymbol{W}) - \ell(\boldsymbol{x}_2,t_2;\boldsymbol{W})| \leq&\, d\max_{i,t_1,t_2} (\Vert\boldsymbol{w}_i - \boldsymbol{w}_{t_1}\Vert_{RMS} + \Vert\boldsymbol{w}_i - \boldsymbol{w}_{t_2}\Vert_{RMS}) \\ \leq&\, 4d\max_i \Vert\boldsymbol{w}_i\Vert_{RMS} \end{aligned}\end{equation}

This result mirrors the forward stability calculation.

Update Stability

Finally, for update stability, again using inequality \eqref{leq:lse-ab} and Cauchy-Schwarz:

\begin{equation}\begin{aligned} |\ell(\boldsymbol{x},t;\boldsymbol{W} + \Delta\boldsymbol{W}) - \ell(\boldsymbol{x},t;\boldsymbol{W})| \leq&\, \max_i |\langle \boldsymbol{x},\boldsymbol{w}_i + \Delta\boldsymbol{w}_i - \boldsymbol{w}_t - \Delta\boldsymbol{w}_t\rangle - \langle \boldsymbol{x},\boldsymbol{w}_i - \boldsymbol{w}_t\rangle| \\ =&\, \max_i |\langle \boldsymbol{x},\Delta\boldsymbol{w}_i - \Delta\boldsymbol{w}_t\rangle| \\ \leq &\, d \max_i \Vert\boldsymbol{x}\Vert_{RMS} \Vert\Delta\boldsymbol{w}_i - \Delta\boldsymbol{w}_t\Vert_{RMS} \end{aligned}\end{equation}

Therefore:

\begin{equation}\begin{aligned} \text{Update Stability:}\quad\max_{t,\Vert\boldsymbol{x}\Vert_{RMS}=1} |\ell(\boldsymbol{x},t;\boldsymbol{W} + \Delta\boldsymbol{W}) - \ell(\boldsymbol{x},t;\boldsymbol{W})| \leq&\, d \max_{i, t} \Vert\Delta\boldsymbol{w}_i - \Delta\boldsymbol{w}_t\Vert_{RMS} \\ \leq&\, 2d\max_i \Vert\Delta\boldsymbol{w}_i\Vert_{RMS} \end{aligned}\end{equation}

It is not hard to see that the three stability indicators of the LM Head are essentially the same as those of the Embedding layer: each is the maximum row/column norm of the parameter matrix or its increment. This means the steepest descent for the LM Head is also Normalized SGD, except that it is normalized column-wise. Furthermore, the LM Head's three indicators all carry a factor of $d$, so its initialization standard deviation and learning rate scale as $\Theta(1/d)$, whereas the Embedding layer's scale as $\Theta(1)$. The two therefore differ slightly under cross-width migration.
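
A minimal NumPy sketch of this column-wise Normalized SGD with the extra $1/d$ factor (the function name and sizes are illustrative):

```python
import numpy as np

def lm_head_update(W, G, eta=0.01, eps=1e-8):
    """Column-wise Normalized SGD with the 1/d factor from update stability."""
    d = W.shape[0]
    # RMS norm of each gradient column: ||G_{:,i}||_2 / sqrt(d)
    col_rms = np.linalg.norm(G, axis=0, keepdims=True) / np.sqrt(d) + eps
    return W - (eta / d) * G / col_rms

rng = np.random.default_rng(0)
d, V = 64, 100
W = rng.normal(0.0, 1.0 / d, size=(d, V))   # variance Theta(1/d^2)
G = rng.normal(size=(d, V))
W_new = lm_head_update(W, G)
# each column moves by RMS norm ~ eta/d, so d * max_i ||Delta w_i||_RMS ~ eta
delta_rms = np.linalg.norm(W_new - W, axis=0) / np.sqrt(d)
print(d * delta_rms.max())  # ≈ 0.01
```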

Other Modules

Besides linear layers, Embeddings, and LM Heads, common Transformer models usually contain other parameters or layers that require individual analysis. Let's go through them one by one.

Hadamard Product

We know that after RMS Norm, there is usually a multiplication by a $\boldsymbol{\gamma}$ vector (Hadamard product), i.e., $(\boldsymbol{x} / \Vert\boldsymbol{x}\Vert_{RMS}) \odot \boldsymbol{\gamma}$, to adjust the output scale. This parameter is not a matrix, so it is not suitable for the aforementioned Muon or Normalized SGD.

We could follow the original definitions to calculate the three stability indicators for $\boldsymbol{\gamma}$ and then analyze its initialization and steepest descent, but there is a cleverer way: note that $\newcommand{diag}{\mathop{\text{diag}}}\boldsymbol{x} \odot \boldsymbol{\gamma} = \boldsymbol{x} \diag(\boldsymbol{\gamma})$. That is, the Hadamard product of $\boldsymbol{x}$ and $\boldsymbol{\gamma}$ is equal to the matrix multiplication of $\boldsymbol{x}$ and the diagonal matrix $\diag(\boldsymbol{\gamma})$. This turns it into a special linear layer where $\boldsymbol{W} = \diag(\boldsymbol{\gamma})$, and we can reuse the conclusions for linear layers.

According to the previous article, the initial spectral norm of $\boldsymbol{W}$ should be $\Theta(\sqrt{d_{out}/d_{in}})$. Since $\boldsymbol{W}$ is square here, it is exactly $\Theta(1)$. Since $\boldsymbol{W}$ is a diagonal matrix, we can simply initialize $\boldsymbol{W}$ as the identity matrix to satisfy this requirement, which corresponds to initializing $\boldsymbol{\gamma}$ as all-ones.

As for the optimizer, if the gradient of $\boldsymbol{\gamma}$ is $\boldsymbol{g}$, then the gradient of $\boldsymbol{W}$ is $\boldsymbol{G} = \diag(\boldsymbol{g})$. We know the steepest descent direction for a linear layer is Muon, namely $\newcommand{msign}{\mathop{\text{msign}}}\Delta\boldsymbol{W} = -\eta\msign(\boldsymbol{G})$. For a diagonal matrix, $\newcommand{sign}{\mathop{\text{sign}}}\msign(\boldsymbol{G}) = \sign(\boldsymbol{G}) = \diag(\sign(\boldsymbol{g}))$. Thus, the steepest descent for the $\boldsymbol{\gamma}$ parameter is SignSGD.
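
This reduction can be checked numerically. Below, $\msign$ is computed via its SVD definition $\msign(\boldsymbol{G}) = \boldsymbol{U}\boldsymbol{V}^{\top}$ (in practice Muon approximates this with a Newton-Schulz iteration, but the SVD form is convenient for a check):

```python
import numpy as np

def msign(G):
    """Matrix sign via SVD: for G = U S V^T, msign(G) = U V^T."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(0)
g = rng.normal(size=8)
G = np.diag(g)
# for a diagonal matrix, msign reduces to elementwise sign on the diagonal
print(np.allclose(msign(G), np.diag(np.sign(g))))  # → True
```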

Linear Biases

Traditional linear layers usually include a bias vector $\boldsymbol{b}$, i.e., the complete linear operation is $\boldsymbol{f}(\boldsymbol{x};\boldsymbol{W},\boldsymbol{b}) = \boldsymbol{x}\boldsymbol{W}+\boldsymbol{b}$. However, most open-source models in recent years have dropped the bias term, so it is rarely encountered nowadays. We include it here for completeness.

With the bias vector included, the three stability indicators are:

\begin{align} &\text{Forward Stability:}\quad\max_{\Vert\boldsymbol{x}\Vert_{RMS}=1} \Vert \boldsymbol{x}\boldsymbol{W} + \boldsymbol{b}\Vert_{RMS} \\[5pt] &\text{Dependency Stability:}\quad\max_{\Vert\boldsymbol{x}_1\Vert_{RMS}=\Vert\boldsymbol{x}_2\Vert_{RMS}=1} \Vert \boldsymbol{x}_1\boldsymbol{W} - \boldsymbol{x}_2\boldsymbol{W}\Vert_{RMS}\\[5pt] &\text{Update Stability:}\quad\max_{\Vert\boldsymbol{x}\Vert_{RMS}=1} \Vert \boldsymbol{x} \Delta\boldsymbol{W} + \Delta\boldsymbol{b}\Vert_{RMS} \end{align}

The dependency stability is the same as without a bias, so we only need to examine forward and update stability. For simplicity, we use the triangle inequality $\Vert \boldsymbol{x}\boldsymbol{W} + \boldsymbol{b}\Vert_{RMS} \leq \Vert \boldsymbol{x}\boldsymbol{W}\Vert_{RMS} + \Vert\boldsymbol{b}\Vert_{RMS}$. Assuming $\boldsymbol{W}$ follows the original initialization, the $\Vert \boldsymbol{x}\boldsymbol{W}\Vert_{RMS}$ part already achieves $\Theta(1)$, so we only need $\Vert\boldsymbol{b}\Vert_{RMS}=\mathcal{O}(1)$. In practice, $\boldsymbol{b}$ is usually initialized to zero.

Similarly, $\Vert \boldsymbol{x}\Delta\boldsymbol{W} + \Delta\boldsymbol{b}\Vert_{RMS} \leq \Vert \boldsymbol{x}\Delta\boldsymbol{W}\Vert_{RMS} + \Vert\Delta\boldsymbol{b}\Vert_{RMS}$. If we let $\Vert\Delta\boldsymbol{b}\Vert_{RMS} = \mathcal{O}(1)$, the bias parameter $\boldsymbol{b}$ will follow steepest descent based on $\Vert\Delta\boldsymbol{b}\Vert_{RMS}$ as the stability indicator, which again results in Normalized SGD.
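
For completeness, the corresponding bias update as a tiny sketch (the function name and the $\epsilon$ guard are our additions):

```python
import numpy as np

def bias_update(b, g, eta=0.01, eps=1e-8):
    """Normalized SGD on the bias: Delta b = -eta * g / ||g||_RMS."""
    rms = np.linalg.norm(g) / np.sqrt(g.size) + eps
    return b - eta * g / rms

rng = np.random.default_rng(0)
b = np.zeros(64)              # zero initialization, as discussed above
g = rng.normal(size=64)
b_new = bias_update(b, g)
print(np.linalg.norm(b_new - b) / np.sqrt(64))  # ||Delta b||_RMS ≈ eta = 0.01
```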

Attention Scaling

Using the forward stability indicator, we can also re-derive the scaling factor of the Attention mechanism. Let $\boldsymbol{q}=\boldsymbol{x}\boldsymbol{W}_q, \boldsymbol{k}=\boldsymbol{x}\boldsymbol{W}_k$. If $\boldsymbol{W}_q, \boldsymbol{W}_k$ are treated as linear layers, we can assume $\Vert\boldsymbol{q}\Vert_{RMS}=\Theta(1)$ and $\Vert\boldsymbol{k}\Vert_{RMS}=\Theta(1)$ have been achieved. Then, according to Cauchy-Schwarz:

\begin{equation}|\langle\boldsymbol{q},\boldsymbol{k}\rangle| \leq \Vert\boldsymbol{q}\Vert_2 \Vert\boldsymbol{k}\Vert_2 = d\Vert\boldsymbol{q}\Vert_{RMS} \Vert\boldsymbol{k}\Vert_{RMS} \end{equation}

Here $d$ is the dimension of $\boldsymbol{q}$ and $\boldsymbol{k}$, i.e., Head Dim. Clearly, the above expression is $\Theta(d)$. To make it $\Theta(1)$, one must multiply $\boldsymbol{q}\cdot\boldsymbol{k}$ by a scaling factor on the order of $\Theta(1/d)$, which differs from the previous $1/\sqrt{d}$ (refer to "On Initialization, Parameterization, and Standardization of Transformers").

Which one is correct? Actually, both are: $1/\sqrt{d}$ is the average-case result under random initialization, whereas $\Theta(1/d)$ is the limiting value that applies throughout training. This does not mean we must change the scaling factor directly to $1/d$; rather, scaling it inversely with $d$ may transfer better across widths, and the two views are compatible. For example, if the factor $1/\sqrt{128}$ works well at $d=128$, then when migrating to $d=256$ one might consider $1/(2\sqrt{128})$ instead of $1/\sqrt{256}$.
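
This transfer rule can be written as a one-line helper; the name `attn_scale` and the anchoring at `d_base` are our illustrative conventions:

```python
import numpy as np

def attn_scale(d, d_base=128):
    """Transfer the tuned base factor 1/sqrt(d_base) inversely with head dim d."""
    return (d_base / d) / np.sqrt(d_base)

print(attn_scale(128))  # 1/sqrt(128), the usual starting point
print(attn_scale(256))  # 1/(2*sqrt(128)), not 1/sqrt(256)
```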

In fact, constrained by Flash Attention, the choice of Head Dim is not very flexible: it is usually 128 and generally at most 256. In practice, therefore, parameter transfer across Head Dims rarely arises, and this result is mostly of theoretical interest.

Article Summary

Finally, the main results of these two articles are summarized below:

\begin{array}{|c|c|c|c|c|c|} \hline & \text{Input} & \text{Parameter} & \text{Output} & \text{Initialization} & \text{Steepest Descent} \\ \hline \text{Linear} & \boldsymbol{x} & \begin{aligned}\boldsymbol{W}\in&\,\mathbb{R}^{d_{in}\times d_{out}} \\ \boldsymbol{b}\in&\,\mathbb{R}^{d_{out}} \end{aligned} & \boldsymbol{x}\boldsymbol{W} + \boldsymbol{b} & \begin{aligned} \boldsymbol{W}:&\, \small{\text{std }\sqrt{\frac{d_{out}}{d_{in}}}\frac{1}{\sqrt{d_{in}} + \sqrt{d_{out}}}} \\ \boldsymbol{b}:&\, 0\end{aligned}& \begin{aligned} \Delta\boldsymbol{W} =&\, \small{-\eta\sqrt{\frac{d_{out}}{d_{in}}}\msign(\boldsymbol{G})} \\ \Delta\boldsymbol{b} =&\, \small{-\eta \frac{\boldsymbol{g}}{\Vert\boldsymbol{g}\Vert_{RMS}}}\end{aligned} \\ \hline \text{Embedding} & i & \boldsymbol{E} \in\mathbb{R}^{|V|\times d} & \boldsymbol{E}_{i,:} & \text{variance } 1 & \Delta\boldsymbol{E}_{i,:} = -\eta \frac{\boldsymbol{G}_{i,:}}{\Vert\boldsymbol{G}_{i,:} \Vert_{RMS}} \\ \hline \text{LM Head} & \boldsymbol{x}, t & \boldsymbol{W} \in\mathbb{R}^{d \times |V|} & \log\sum\limits_{i=1}^{|V|} e^{\langle \boldsymbol{x},\boldsymbol{W}_{:,i} - \boldsymbol{W}_{:,t}\rangle} & \text{variance } \frac{1}{d^2} & \Delta\boldsymbol{W}_{:,i} = -\frac{\eta}{d} \frac{\boldsymbol{G}_{:,i}}{\Vert\boldsymbol{G}_{:,i}\Vert_{RMS}} \\ \hline \text{RMS Norm} & \boldsymbol{x} & \boldsymbol{\gamma} \in\mathbb{R}^d & \frac{\boldsymbol{x}}{\Vert\boldsymbol{x}\Vert_{RMS}}\odot\boldsymbol{\gamma} & \text{all ones} & \Delta\boldsymbol{\gamma} = -\eta \sign(\boldsymbol{g})\\ \hline \end{array}

Among these, the steepest descent directions for the Embedding and LM Head are row/column-wise Normalized SGD, respectively, which is consistent with works like Scion. As for the transfer laws of variance and learning rate, they are consistent with the conclusions of MuP. In these two articles, they are derived based on our proposed "three stability indicators," which shows that we have indeed found a unified form of stability measurement for arbitrary layers.