By 苏剑林 | March 13, 2025
As is well known, fully training a large LLM is extremely expensive, which means we cannot repeatedly trial hyperparameters directly on large models. A natural idea is to carefully search for hyperparameters on small models with the same structure and then transfer the optimal combination directly to the large model. Although this idea is simple, realizing it is non-trivial: it requires us to understand the scaling laws connecting common hyperparameters and model scale. muP is precisely a practical implementation of this idea.
muP, sometimes written as $\mu P$, stands for Maximal Update Parametrization. It originated from the paper "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer". With the popularization of LLM training, it has gradually become one of the standard tools of scientific model training.
Before diving into the main topic, I must vent a bit: the original muP paper is written in an extremely obscure manner, and the presentation of the conclusions is cluttered, which adds a lot of difficulty to understanding. Therefore, I will try to reproduce the conclusions of muP in a (self-perceived) concise and clear way.
To state the conclusion first: muP primarily studies the transfer laws of hyperparameters across model scales. There are a few keywords here:
1. Hyperparameters: Currently, this mainly refers to the learning rate;
2. Model Scale: Currently, this mainly refers to model width;
3. The core is "transfer".
Please note that muP does not study what the optimal hyperparameters are; it only studies how the optimal hyperparameters vary with the model scale. Therefore, we need to search for the optimal hyperparameter combination on some small model and then transfer it to the large model. This is the use case and application method of muP.
The principle of deriving muP is to ensure that the model's forward propagation, backward propagation, loss increment, and feature changes do not change significantly with model scale:
1. The specific method is to analyze the order of magnitude at initialization, assuming these conclusions can represent the laws of subsequent optimization;
2. Simply put, it assumes that if initialization is done well, the rest will automatically follow the correct trajectory (a good start is half the battle?);
3. Of course, one can tell stories about the Law of Large Numbers or the Central Limit Theorem to support this assumption, but I believe it is not strictly necessary.
We start with forward propagation because it is a relatively simple and mature part. First, consider a linear layer $\boldsymbol{Y}=\boldsymbol{X}\boldsymbol{W}$, where $\boldsymbol{X}\in\mathbb{R}^{b\times d_{in}}, \boldsymbol{W}\in\mathbb{R}^{d_{in}\times d_{out}}$. We use RMS (Root Mean Square) as an indicator of matrix scale, for example:
\begin{equation}\text{RMS}(\boldsymbol{W}) = \sqrt{\frac{1}{d_{in} d_{out}}\sum_{i=1}^{d_{in}} \sum_{j=1}^{d_{out}} W_{i,j}^2}\end{equation}We know that to keep the RMS of $\boldsymbol{X}$ roughly equal to the RMS of $\boldsymbol{Y}$ during the initialization phase (referred to as "stable"), $\boldsymbol{W}$ should use:
LeCun Initialization: Random initialization with "mean 0, variance $1/d_{in}$".
This is already considered one of the fundamental conclusions of deep learning, so we won't expand on the derivation. Readers who are not familiar with it can refer to previous blog posts such as "Understanding Model Parameter Initialization Strategies from a Geometric Perspective" and "A Brief Discussion on Initialization, Parameterization, and Standardization of Transformers".
Next, we consider a non-linear layer $\boldsymbol{Y}=\phi(\boldsymbol{X}\boldsymbol{W})$, where $\phi$ is an element-wise activation function. If we still want to maintain the RMS of $\boldsymbol{X}$ approximately equal to the RMS of $\boldsymbol{Y}$, the result will be slightly different. For example, with $\text{relu}$ activation, we get:
Kaiming Initialization: Random initialization with "mean 0, variance $2/d_{in}$".
It is easy to see that compared to LeCun initialization, Kaiming initialization only differs by a constant factor of 2 in variance (which is independent of model scale). It can be proven that other activation functions yield similar results. Thus, we can conclude:
fan_in Initialization: To ensure the stability of forward propagation, a random initialization with "mean 0, and variance proportional to $1/d_{in}$" should be used.
This conclusion can also be understood as "the influence of the activation function is independent of the model scale." Therefore, if we only want to analyze the effect of model scale, we can ignore the existence of (element-wise) activation functions and directly obtain the scaling law $\propto 1/d_{in}$ from LeCun initialization.
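As a quick sanity check (my own sketch in numpy, not from the original post; the widths and the 0.02 constant are illustrative), the snippet below compares fan_in initialization against a width-independent standard deviation: the former keeps the RMS of $\boldsymbol{Y}=\phi(\boldsymbol{X}\boldsymbol{W})$ roughly constant as $d_{in}$ grows, the latter does not.

```python
import numpy as np

def rms(a):
    return np.sqrt(np.mean(a ** 2))

rng = np.random.default_rng(0)
b, d_out = 64, 256
for d_in in [128, 512, 2048, 8192]:
    X = rng.normal(size=(b, d_in))                                        # RMS(X) ≈ 1
    W_fan_in = rng.normal(scale=np.sqrt(2.0 / d_in), size=(d_in, d_out))  # Kaiming (fan_in) init
    W_fixed = rng.normal(scale=0.02, size=(d_in, d_out))                  # width-independent std
    Y1 = np.maximum(X @ W_fan_in, 0.0)                                    # phi = relu
    Y2 = np.maximum(X @ W_fixed, 0.0)
    print(d_in, round(rms(Y1), 3), round(rms(Y2), 3))
# RMS(Y1) stays near 1 across widths; RMS(Y2) grows roughly like sqrt(d_in).
```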
Now let's continue analyzing the backward propagation (gradients). Note that we assume variables and their gradients have the same shape. We calculate:
\begin{align} \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}} =&\, \boldsymbol{X}^{\top}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\otimes \phi'(\boldsymbol{X}\boldsymbol{W})\right) \label{eq:grad-w} \\[5pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{X}} =&\, \left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\otimes \phi'(\boldsymbol{X}\boldsymbol{W})\right)\boldsymbol{W}^{\top} \label{eq:grad-x} \end{align}The first formula is the gradient of parameters within the current layer, and the second formula is the gradient propagated back to the previous layer. $\otimes$ is the Hadamard product, and $\phi'$ is the derivative of $\phi$.
Note a fact: the derivatives of the activation functions we commonly use are bounded by a (scale-independent) constant. So, at least in terms of magnitude, we can write:
\begin{align} \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}} \sim&\, \boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}} \\[5pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{X}} \sim&\, \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\boldsymbol{W}^{\top} \end{align}Looking at the second formula, compared to $\boldsymbol{Y}=\boldsymbol{X}\boldsymbol{W}$, the matrix multiplied on the right has changed to $\boldsymbol{W}^{\top}$. Following the conclusion of the previous section, if we want to maintain the RMS stability of backward propagation, the initialization of $\boldsymbol{W}$ should be:
fan_out Initialization: Random initialization with "mean 0, variance $1/d_{out}$".
When $d_{in} \neq d_{out}$, the requirements for forward and backward propagation conflict. At this point, some proposed a compromise strategy:
Xavier Initialization: Random initialization with "mean 0, variance $2/(d_{in} + d_{out})$".
This is also called "fan_avg initialization" because it simply takes the arithmetic mean of $d_{in}$ and $d_{out}$. Other averaging methods can also be considered; refer to "Thinking about Dimension Averaging Strategies in Initialization Methods for Non-Square Matrices". Xavier initialization seems to take both forward and backward into account, but one could equally say it does justice to neither. A better approach is to design the model such that most parameters are square matrices, as in the model family $\eqref{eq:model}$ discussed later.
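To make the forward/backward tension concrete, here is a small numerical comparison (again my own sketch in numpy, with a deliberately rectangular shape): fan_in stabilizes the forward RMS, fan_out stabilizes the backward RMS, and fan_avg sits in between.

```python
import numpy as np

def rms(a):
    return np.sqrt(np.mean(a ** 2))

rng = np.random.default_rng(0)
b, d_in, d_out = 64, 4096, 256
X = rng.normal(size=(b, d_in))       # forward input, RMS ≈ 1
G_Y = rng.normal(size=(b, d_out))    # stand-in for dL/dY, RMS ≈ 1

for name, var in [("fan_in ", 1 / d_in),
                  ("fan_out", 1 / d_out),
                  ("fan_avg", 2 / (d_in + d_out))]:
    W = rng.normal(scale=np.sqrt(var), size=(d_in, d_out))
    print(name,
          "forward RMS:", round(rms(X @ W), 3),       # stable only for fan_in
          "backward RMS:", round(rms(G_Y @ W.T), 3))  # stable only for fan_out
```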
With the foundations of forward and backward propagation, we can try to analyze the increment of the loss function. Consider the change in the loss function when $\boldsymbol{W} \to \boldsymbol{W} + \Delta\boldsymbol{W}$:
\begin{equation}\Delta \mathcal{L} = \mathcal{L}(\boldsymbol{W} + \Delta\boldsymbol{W}) - \mathcal{L}(\boldsymbol{W})\approx \left\langle\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}, \Delta\boldsymbol{W}\right\rangle_F\end{equation}Here $\langle\cdot,\cdot\rangle_F$ is the Frobenius inner product, which is essentially the vector inner product after flattening the matrices. Considering gradient descent $\Delta\boldsymbol{W} = -\eta \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}$, where $\eta$ is the learning rate, and combining this with Eq. $\eqref{eq:grad-w}$, we have:
\begin{equation}\Delta \mathcal{L}\approx -\eta\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_F^2\sim -\eta \left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2\end{equation}In fact, this formula already tells us why the same learning rate $\eta$ cannot be used across model scales:
1. $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is a $d_{in}\times d_{out}$ matrix;
2. $\left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2$ is the sum of squares of $d_{in}\times d_{out}$ numbers;
3. $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is exactly the product of the forward and backward passes;
4. If both the forward and backward passes are stable, each element of $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is $\mathcal{O}(1)$;
5. Therefore, $\left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2$ is $\mathcal{O}(d_{in} d_{out})$.
The 4th point might require more commentary. $\boldsymbol{X}^{\top}$ is a $d_{in}\times b$ matrix, and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is a $b\times d_{out}$ matrix. Their product is the inner product of two $b$-dimensional vectors at each of the $d_{in} d_{out}$ positions. The inner product is a sum of $b$ terms, and the loss $\mathcal{L}$ is usually an average over samples (containing a $1/b$ operation). So if both $\boldsymbol{X}^{\top}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ are scale-independent, their product is basically scale-independent (i.e., the RMS is $\mathcal{O}(1)$).
The final conclusion shows that if we directly use the learning rate of a small model for a larger model, the loss increment per step for a sufficiently large model will **explode** as the parameter scale (i.e., $d_{in} d_{out}$) increases. This means the convergence process of the small model cannot be replicated, and it might even fail to converge because the steps are too large.
At this point, one might think of making $\eta \propto 1/(d_{in} d_{out})$ to scale $\Delta\mathcal{L}$. This thought actually aligns with muP's reasoning. However, in practice, because of the incompatibility between forward and backward passes mentioned earlier, the 4th point "if forward and backward are both stable, each element of $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is $\mathcal{O}(1)$" does not always hold. Thus, the actual situation is more complex.
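The width dependence of the per-step loss increment is easy to see numerically. The snippet below (my own sketch) stands in for $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ with a random square matrix whose per-element RMS is $\mathcal{O}(1)$, and shows that a fixed $\eta$ makes $|\Delta\mathcal{L}|$ grow roughly like $\eta\, d^2$.

```python
import numpy as np

rng = np.random.default_rng(0)
eta = 1e-2                                # learning rate tuned at small width, illustrative
for d in [256, 1024, 4096]:
    G = rng.normal(size=(d, d))           # stand-in for X^T dL/dY, per-element RMS ≈ 1
    delta_L = eta * np.sum(G ** 2)        # |ΔL| ≈ η ||G||_F^2
    print(d, f"{delta_L:.1f}")            # grows roughly like η d^2
```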
Now let's consider a scenario closer to practice. Our task is to train a model $\mathbb{R}^{d_{in}}\mapsto \mathbb{R}^{d_{out}}$, where $d_{in}, d_{out}$ are determined by the data and are fixed. As stated at the beginning, muP aims to study the scaling law of hyperparameters with model scale. Therefore, all fixed quantities are treated as constants or $\mathcal{O}(1)$. For example, an initialization variance of $1/d_{in}$ is equivalent to saying the initialization variance is $\mathcal{O}(1)$.
What we can change is the model's architecture, parameter count, and so on, but muP primarily considers scaling with width. So let us pin down the model architecture. The model family considered here is:
\begin{equation}\begin{gathered} \boldsymbol{Y}_{in} = \boldsymbol{X} \boldsymbol{W}_{in} \\[5pt] \boldsymbol{Y}_{out} = \text{NN}(\boldsymbol{Y}_{in},\boldsymbol{\Theta}) \\[5pt] \boldsymbol{Z} = \boldsymbol{Y}_{out} \boldsymbol{W}_{out} \end{gathered}\label{eq:model}\end{equation}Where:
1. $\boldsymbol{X}\in\mathbb{R}^{b\times d_{in}}$ (includes batch size);
2. $\boldsymbol{W}_{in} \in \mathbb{R}^{d_{in}\times d}, \boldsymbol{W}_{out} \in \mathbb{R}^{d\times d_{out}}$;
3. $\text{NN}$ is any neural network mapping $\mathbb{R}^d\mapsto \mathbb{R}^d$;
4. $d$ is what we usually call the hidden size;
5. We can freely increase $d$ to scale up the model's parameter count and capacity;
6. muP aims to study the variation laws of hyperparameters with respect to $d$.
More specifically, the $\text{NN}$ we consider is a $K$-layer MLP:
\begin{equation}\begin{aligned} \boldsymbol{Y}_0 =&\, \boldsymbol{Y}_{in} \\[5pt] \boldsymbol{Y}_{k+1} =&\, \phi(\boldsymbol{Y}_k \boldsymbol{W}_{k+1}) \\[5pt] \boldsymbol{Y}_{out} =&\, \boldsymbol{Y}_K \end{aligned}\label{eq:mlp}\end{equation}Here $\boldsymbol{\Theta}=\{\boldsymbol{W}_1, \boldsymbol{W}_2, \cdots, \boldsymbol{W}_K\}$, and $\boldsymbol{W}_k\in\mathbb{R}^{d\times d}$. Since they are all square matrices, they all use fan_in initialization (equivalently, fan_out initialization).
As a supplementary remark, assuming all parameter matrices are $d\times d$ square matrices is purely to simplify the analysis, not a mandatory requirement. The real assumption is that no dimension of any parameter in $\text{NN}$ is a scale-independent constant. For instance, a shape like $d\times 64$ is not allowed because 64 is a constant, whereas a shape like $d\times 4d$ is allowed, because then fan_in, fan_out, and fan_avg initialization all give a variance proportional to $1/d$.
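For concreteness, here is a minimal numpy sketch of this model family; the specific widths, the relu activation, and the function names are illustrative choices of mine, and $\boldsymbol{W}_{out}$ is initialized with plain fan_in here (the analysis below will argue for a $1/d^2$ variance instead).

```python
import numpy as np

def build_model(d_in, d_out, d, K, rng):
    # fan_in initialization throughout; the muP analysis below will argue
    # for variance ∝ 1/d^2 (std 1/d) for W_out instead of fan_in.
    return {
        "W_in": rng.normal(scale=np.sqrt(1 / d_in), size=(d_in, d)),
        "Theta": [rng.normal(scale=np.sqrt(1 / d), size=(d, d)) for _ in range(K)],
        "W_out": rng.normal(scale=np.sqrt(1 / d), size=(d, d_out)),
    }

def forward(params, X, phi=lambda t: np.maximum(t, 0.0)):  # phi = relu, illustrative
    Y = X @ params["W_in"]
    for W_k in params["Theta"]:
        Y = phi(Y @ W_k)
    return Y @ params["W_out"]

rng = np.random.default_rng(0)
params = build_model(d_in=32, d_out=10, d=1024, K=4, rng=rng)
Z = forward(params, rng.normal(size=(8, 32)))
print(Z.shape)  # (8, 10)
```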
With a specific model established, we can assemble the previous conclusions. The parameters to be updated are divided into three parts: $\boldsymbol{W}_{in}, \boldsymbol{\Theta}, \boldsymbol{W}_{out}$. Calculate the gradients respectively:
\begin{align} \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}} =&\, \boldsymbol{Y}_{out}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}} \\[6pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k} =&\, \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k} \cdot\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{out}} = \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k} \cdot\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}\right) \\[6pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}} =&\, \boldsymbol{X}^{\top} \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{in}} = \boldsymbol{X}^{\top} \left(\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}\cdot\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{out}}\right) = \boldsymbol{X}^{\top} \left(\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}\cdot\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}\right)\right) \end{align}The $\cdot$ operation needs a brief explanation: $\boldsymbol{Y}_{in}, \boldsymbol{Y}_{out}$ are matrices, so $\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}$ is, in principle, a fourth-order tensor. The chain rule $\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}\cdot\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{out}}$ is actually a higher-order tensor multiplication. However, I don't intend to expand on this here, so I simply use $\cdot$ as a placeholder; readers only need to know it's a generalization of matrix multiplication.
Now let's observe the laws:
1. All three equations contain $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$;
2. The last two equations include $\boldsymbol{W}_{out}^{\top}$;
3. Since $\boldsymbol{W}_k$ consists of square matrices, $\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}$ and $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ are stable (RMS is $\mathcal{O}(1)$);
4. If $\boldsymbol{W}_{in}$ also uses fan_in initialization, then $\boldsymbol{Y}_{out}$ is stable;
5. For $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}$ to be stable, $\boldsymbol{W}_{out}$ would need fan_out initialization, i.e. variance $1/d_{out}$; but $d_{out}$ is scale-independent, so this amounts to a constant ($\mathcal{O}(1)$) variance.
As a result:
1. The RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}$ is $\mathcal{O}(1)$. $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right\Vert_F^2$ is the sum of squares of $d\times d_{out}$ numbers, so its magnitude is $\mathcal{O}(d\times d_{out})$. Since $d_{out}$ is constant, it is effectively $\mathcal{O}(d)$. Thus, to obtain $\mathcal{O}(1)$ for $\Delta\mathcal{L}$, its learning rate must satisfy $\eta_{out}\propto 1/d$;
2. $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right\Vert_F^2$ is a sum of $d^2$ numbers. Since the RMS of $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$ are both $\mathcal{O}(1)$, if we directly set the initialization variance of $\boldsymbol{W}_{out}$ to $\propto 1/d^2$, then the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$ will be $\mathcal{O}(1/d)$. After squaring and summing, it is exactly $\mathcal{O}(1)$, so the learning rate does not need to change;
3. At this point, the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}$ is also $\mathcal{O}(1/d)$, but $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right\Vert_F^2$ is only the sum of squares of $d_{in}\times d$ numbers, so the result is $\mathcal{O}(1/d)$. To get $\mathcal{O}(1)$ for $\Delta\mathcal{L}$, the learning rate instead needs to be increased by $d$ times to cancel this effect, i.e., $\eta_{in}\propto d$.
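These three magnitudes are easy to verify numerically. The snippet below (my own check, with illustrative constants) plugs in random stand-ins with the per-element RMS values just derived and prints the corresponding squared Frobenius norms for several widths.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 32, 10                      # data-determined, treated as constants
for d in [256, 1024, 4096]:
    G_out = rng.normal(scale=1.0,     size=(d, d_out))   # per-element RMS O(1)
    G_k   = rng.normal(scale=1.0 / d, size=(d, d))       # per-element RMS O(1/d)
    G_in  = rng.normal(scale=1.0 / d, size=(d_in, d))    # per-element RMS O(1/d)
    print(d,
          round(float(np.sum(G_out ** 2)), 1),  # ~ d * d_out  -> eta_out ∝ 1/d
          round(float(np.sum(G_k ** 2)), 3),    # ~ O(1)       -> eta_k unchanged
          round(float(np.sum(G_in ** 2)), 4))   # ~ d_in / d   -> eta_in ∝ d
```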
The above results are correct, but if we think carefully, there is a problem in the derivation: the 2nd and 3rd points are built on the setting "directly set the initialization variance of $\boldsymbol{W}_{out}$ to $\propto 1/d^2$". However, there is currently no direct basis for this setting. Without further explanation, the derivation is incomplete.
In fact, the requirement $\Delta \mathcal{L}=\mathcal{O}(1)$ alone cannot rule out other possibilities. For example, if the initialization variance of $\boldsymbol{W}_{out}$ is set to $\propto 1/d$, then the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$ is $\mathcal{O}(1/\sqrt{d})$. After squaring and summing, it becomes $\mathcal{O}(d)$, so as long as the learning rate $\eta\propto 1/d$, we can still achieve $\Delta \mathcal{L}=\mathcal{O}(1)$. Therefore, to explain the necessity of "setting the initialization variance of $\boldsymbol{W}_{out}$ to $\propto 1/d^2$", one needs to introduce a new condition.
The loss function $\mathcal{L}$ is a macroscopic indicator of the model, or an external indicator. Looking at its change alone is insufficient to explain all results, so we need to look into the model's interior. Specifically, we hope that the variation of each layer's output (usually called features, sometimes activations) also possesses scale invariance. For instance, in a linear layer $\boldsymbol{Y}_k = \boldsymbol{Y}_{k-1} \boldsymbol{W}_k$, the change in output caused by parameters $\boldsymbol{W}_k\to \boldsymbol{W}_k + \Delta \boldsymbol{W}_k$ is:
\begin{equation}\Delta\boldsymbol{Y}_k = \boldsymbol{Y}_{k-1} (\boldsymbol{W}_k + \Delta \boldsymbol{W}_k) - \boldsymbol{Y}_{k-1} \boldsymbol{W}_k = \boldsymbol{Y}_{k-1} \Delta\boldsymbol{W}_k\end{equation}Note that $\boldsymbol{Y}_{k-1}\in\mathbb{R}^{b\times d}, \Delta\boldsymbol{W}_k\in\mathbb{R}^{d\times d}$, so each of the $b\times d$ entries of $\boldsymbol{Y}_{k-1} \Delta\boldsymbol{W}_k$ is an inner product of two $d$-dimensional vectors. Also note that $\Delta\boldsymbol{W}_k$ is a carefully designed update; unlike at initialization, it is unlikely to be independent of $\boldsymbol{Y}_{k-1}$. Therefore, each such $d$-dimensional inner product is more likely to be $\mathcal{O}(d)$ (a $d$-dimensional inner product is a sum of $d$ terms). Thus, if the RMS of $\boldsymbol{Y}_{k-1}$ is $\mathcal{O}(1)$, we can take the RMS of $\Delta\boldsymbol{Y}_k$ to be $\mathcal{O}(d\times \text{RMS}(\Delta \boldsymbol{W}_k))$.
Consequently, to have the RMS of $\Delta\boldsymbol{Y}_k$ be $\mathcal{O}(1)$, we get an additional requirement for $\Delta \boldsymbol{W}_k$:
\begin{equation}\text{RMS}(\Delta \boldsymbol{W}_k) = \mathcal{O}(1 / d)\label{eq:dw-rms}\end{equation}Combining $\Delta \boldsymbol{W}_k = -\eta\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$ and $\Delta\mathcal{L}=\mathcal{O}(1)$, we can derive the result of "setting the initialization variance of $\boldsymbol{W}_{out}$ to $\propto 1/d^2$".
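To spell out that last step (this is my own bookkeeping, using the same order-of-magnitude conventions as the rest of the post), write the initialization variance of $\boldsymbol{W}_{out}$ as $\propto d^{-\alpha}$ and combine the two requirements:
\begin{equation}\begin{aligned} \text{RMS}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right) \sim&\, \text{RMS}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}\right) = \mathcal{O}(d^{-\alpha/2}) \\[5pt] \Delta\mathcal{L}=\mathcal{O}(1)\,\Rightarrow&\,\eta_k\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right\Vert_F^2 = \eta_k\,\mathcal{O}(d^{2-\alpha}) = \mathcal{O}(1)\,\Rightarrow\,\eta_k = \mathcal{O}(d^{\alpha-2}) \\[5pt] \text{RMS}(\Delta \boldsymbol{W}_k)=\mathcal{O}(1/d)\,\Rightarrow&\,\mathcal{O}(d^{\alpha-2})\cdot\mathcal{O}(d^{-\alpha/2}) = \mathcal{O}(d^{-1})\,\Rightarrow\,\alpha = 2 \end{aligned}\end{equation}which is exactly "initialization variance of $\boldsymbol{W}_{out}$ $\propto 1/d^2$" and simultaneously gives $\eta_k=\mathcal{O}(1)$ for SGD. The alternative $\alpha=1$ from the previous paragraph would instead force $\eta_k\propto 1/d$ and hence $\text{RMS}(\Delta\boldsymbol{W}_k)=\mathcal{O}(d^{-3/2})$, so the feature updates would shrink with $d$ rather than stay $\mathcal{O}(1)$.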
(Note: This section benefited from pointers by Chenyu Zheng. Thank you very much!)
The above is muP for SGD. For Adam, we usually use SignSGD as an approximation for magnitude analysis:
1. $\Delta \boldsymbol{W} = -\eta \mathop{\text{sign}}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right)$ ;
2. $\Delta \mathcal{L} \approx -\eta \left\|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\|_1$ ;
3. Here $\|\cdot\|_1$ refers to summing the absolute values of each element.
Regarding the SignSGD approximation itself, readers can refer to "How Should Learning Rate Change as Batch Size Increases?" and "How Does Adam's Epsilon Affect the Learning Rate's Scaling Law?". In short, SignSGD is a commonly used approximation method when analyzing scaling laws related to Adam.
Now we can mimic the analysis process for SGD:
1. The RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}$ is $\mathcal{O}(1)$. $\left\|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right\|_1$ is a sum of $d\times d_{out}$ numbers, so the magnitude is $\mathcal{O}(d\times d_{out}) = \mathcal{O}(d)$. Thus, its learning rate must satisfy $\eta_{out}\propto 1/d$ to cancel the scale effect;
2. $\left\|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right\|_1$ is a sum of $d^2$ numbers. Since the RMS of $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$ are both $\mathcal{O}(1)$, we set the initialization variance of $\boldsymbol{W}_{out}$ to $\propto 1/d^2$. Then the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$ is $\mathcal{O}(1/d)$. The sum of $d^2$ numbers is then $\mathcal{O}(d)$, so the learning rate changes according to $\eta_k\propto 1/d$ to cancel the scale effect;
3. At this point, the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}$ is also $\mathcal{O}(1/d)$, but $\left\|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right\|_1$ is only a sum of $d_{in}\times d$ numbers, so it is already $\mathcal{O}(1)$; conveniently, the learning rate therefore does not need to change with scale.
(Note: The reader can verify for themselves that Eq. $\eqref{eq:dw-rms}$ is satisfied.)
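The same numerical stand-ins as in the SGD check above also illustrate these three points; for SignSGD the relevant quantity is the entrywise $L_1$ norm rather than the squared $F$ norm (again, just my own sanity check with illustrative constants).

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 32, 10
for d in [256, 1024, 4096]:
    G_out = rng.normal(scale=1.0,     size=(d, d_out))   # per-element RMS O(1)
    G_k   = rng.normal(scale=1.0 / d, size=(d, d))       # per-element RMS O(1/d)
    G_in  = rng.normal(scale=1.0 / d, size=(d_in, d))    # per-element RMS O(1/d)
    print(d,
          round(float(np.abs(G_out).sum()), 1),  # ~ d * d_out -> eta_out ∝ 1/d
          round(float(np.abs(G_k).sum()), 1),    # ~ d         -> eta_k   ∝ 1/d
          round(float(np.abs(G_in).sum()), 2))   # ~ d_in      -> eta_in unchanged
```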
Next is naturally the analysis for Muon. Regarding Muon itself, we have provided detailed introductions in "Muon Optimizer Appreciation: The Essential Leap from Vectors to Matrices" and "Muon Sequel: Why Did We Choose to Try Muon?". Similar to using SignSGD for Adam, we use MSignSGD to approximate Muon:
1. $\Delta \boldsymbol{W} = -\eta \mathop{\text{msign}}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right)$ ;
2. $\Delta \mathcal{L} \approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_*$ (a proof is given here);
3. Here $\Vert\cdot\Vert_*$ refers to the Nuclear Norm, which is the sum of all singular values of the matrix;
4. The nuclear norm is not easy to calculate, but the $F$ norm is easy. It is the square root of the sum of squares of all singular values of the matrix;
5. We use the $F$ norm as an approximation of the nuclear norm, so $\Delta \mathcal{L} \approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_* \approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_F$;
6. The $F$ norm is also equal to the square root of the sum of squares of all elements of the matrix.
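Before moving on, here is a tiny numpy illustration (my own sketch) of the quantities in this list: $\mathop{\text{msign}}(\boldsymbol{G}) = \boldsymbol{U}\boldsymbol{V}^{\top}$ from the reduced SVD, the Frobenius inner product $\langle \boldsymbol{G}, \mathop{\text{msign}}(\boldsymbol{G})\rangle_F$ recovers the nuclear norm, and the $F$ norm is the cruder surrogate used in step 5.

```python
import numpy as np

rng = np.random.default_rng(0)
G = rng.normal(size=(512, 256))               # stand-in gradient matrix

U, S, Vt = np.linalg.svd(G, full_matrices=False)
msign_G = U @ Vt                              # orthogonalized "matrix sign"

nuclear = S.sum()                             # ||G||_* = sum of singular values
frob = np.linalg.norm(G)                      # ||G||_F = sqrt(sum of squared singular values)
inner = np.sum(G * msign_G)                   # <G, msign(G)>_F

print(round(inner, 2), round(nuclear, 2), round(frob, 2))
# inner equals nuclear (up to float error); frob is the approximation used in step 5.
```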
We can then start the analysis:
1. The RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}$ is $\mathcal{O}(1)$, so the magnitude of $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right\Vert_*$ is $\mathcal{O}(\sqrt{d\times d_{out}}) = \mathcal{O}(\sqrt{d})$. To eliminate the scale effect, its learning rate must satisfy $\eta_{out}\propto 1/\sqrt{d}$;
2. $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right\Vert_F$ is the square root of the sum of squares of $d^2$ numbers. Since the RMS of $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$ are both $\mathcal{O}(1)$, we set the initialization variance of $\boldsymbol{W}_{out}$ to $\propto 1/d^2$. Then the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$ is $\mathcal{O}(1/d)$. After squaring and summing, then taking the square root, the result is $\mathcal{O}(1)$, so the learning rate does not change;
3. At this point, the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}$ is also $\mathcal{O}(1/d)$, but $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right\Vert_F$ is only the square root of the sum of squares of $d_{in}\times d$ numbers, so it is $\mathcal{O}(1/\sqrt{d})$. The learning rate actually needs to be multiplied by $\sqrt{d}$ to cancel this influence, i.e., $\eta_{in}\propto \sqrt{d}$.
(Note: The conclusion for Muon here is correct, but it does not satisfy condition $\eqref{eq:dw-rms}$, because $\eqref{eq:dw-rms}$ actually relies on the update being element-wise, which Muon is not. We won't expand on this here; instead we directly adopt the conclusion that "the initialization variance of $\boldsymbol{W}_{out}$ is set to $\propto 1/d^2$", bypassing Eq. $\eqref{eq:dw-rms}$.)
Summing up the above conclusions:
| | $\boldsymbol{W}_{in}$ Variance | $\boldsymbol{W}_{in}$ LR | $\boldsymbol{W}_k$ Variance | $\boldsymbol{W}_k$ LR | $\boldsymbol{W}_{out}$ Variance | $\boldsymbol{W}_{out}$ LR |
|---|---|---|---|---|---|---|
| SGD | $1/d_{in}$ | $d$ | $1 / d$ | 1 | $1/d^2$ | $1 / d$ |
| Adam | $1/d_{in}$ | 1 | $1 / d$ | $1 / d$ | $1/d^2$ | $1 / d$ |
| Muon | $1/d_{in}$ | $\sqrt{d}$ | $1 / d$ | 1 | $1/d^2$ | $1 / \sqrt{d}$ |
The $\boldsymbol{W}_k$ here refers to all parameters except $\boldsymbol{W}_{in}$ and $\boldsymbol{W}_{out}$. It must be emphasized that the relationships here are "proportional to" rather than "equal to". Additionally, slight adjustments can be made according to specific needs. For instance, in practice, when using Muon, $\boldsymbol{W}_{in}$ and $\boldsymbol{W}_{out}$ are usually optimized with Adam rather than Muon, which results in two changes:
1. $\eta_{out}\propto 1/d$;
2. $\eta_{in}$ remains unchanged.
If we further combine this with the "Adjust LR" trick mentioned in our "Muon is Scalable for LLM Training", the learning rate is additionally multiplied by $\sqrt{\max(n, m)}$, where $n\times m$ is the shape of the parameter matrix. Since we have already assumed that the parameter shapes in the $\text{NN}$ part scale proportionally with $d$, we have $\sqrt{\max(n, m)}\propto \sqrt{d}$. Therefore, to cancel the scale effect brought by Adjust LR, we would also need:
3. $\eta_k\propto 1/\sqrt{d}$.
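To make the summary table and these practical tweaks concrete, here is a hedged sketch (my own wiring, with hypothetical function and argument names, not an official recipe) of how the rules might translate into per-group initialization standard deviations and learning-rate multipliers when transferring from a tuned proxy width $d_0$ to a target width $d$.

```python
import math

def mup_groups(d, d0, d_in, base_lr, optimizer="muon_practical"):
    """Per-group init std and learning rate when scaling a tuned proxy width d0 up to d.
    A single base_lr is reused across groups purely to keep the sketch short."""
    s = d / d0
    lr_mult = {
        "sgd":  {"W_in": s,             "W_k": 1.0,                "W_out": 1.0 / s},
        "adam": {"W_in": 1.0,           "W_k": 1.0 / s,            "W_out": 1.0 / s},
        "muon": {"W_in": math.sqrt(s),  "W_k": 1.0,                "W_out": 1.0 / math.sqrt(s)},
        # W_in / W_out on Adam, hidden W_k on Muon with Adjust LR ∝ sqrt(max(n, m)):
        "muon_practical": {"W_in": 1.0, "W_k": 1.0 / math.sqrt(s), "W_out": 1.0 / s},
    }[optimizer]
    init_std = {"W_in": (1 / d_in) ** 0.5, "W_k": (1 / d) ** 0.5, "W_out": 1.0 / d}
    lr = {name: base_lr * m for name, m in lr_mult.items()}
    return init_std, lr

print(mup_groups(d=4096, d0=256, d_in=32, base_lr=3e-4))
```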
This article introduces muP (Maximal Update Parametrization) in a way that is as clear and simple as possible. muP is a body of work aimed at studying the transfer laws of hyperparameters across model scales. Based on muP, we can search for hyperparameters (primarily learning rate and initialization) on small models at a relatively low cost and then transfer them to large models, reducing the cost of "alchemy" for large models.
Objectively speaking, the introduction and analysis here are still preliminary. For instance, Bias terms were not considered, nor was the universality of conclusions for architectures beyond MLP, and the role of normalization and residuals was not carefully examined. Excluding the Bias term was purely out of laziness, so consider it an exercise for the reader. As for muP in different architectures, analysis is generally complex, but due to the similarity of neural networks, the conclusions are roughly the same, and we can use them without proof. I believe more critical points for improvement are the influences of normalization and residuals—especially normalization, which allows for stable forward propagation without relying on specific initializations, bringing more freedom and possibilities.
Of course, all of these are left for subsequent analysis.