By 苏剑林 | September 01, 2025
In the previous articles "How Should the Learning Rate Change as the Batch Size Increases?" and "How Does Adam's epsilon Affect the Learning Rate Scaling Law?", we discussed, from a theoretical angle, how the learning rate should change with the batch size. The most classic part of that analysis is the second-order expansion approach proposed by OpenAI. However, for optimizers other than SGD, the calculations this method requires often become so involved that it is hard to even know where to start.
In the next few articles, I will reorganize and rethink the relevant details in the aforementioned posts, attempting to simplify some of the derivation steps. I aim to provide a more general and lightweight derivation path and explore the possibility of extending it to the Muon optimizer.
General Idea
First, let's review the previous analysis method. In "How Should the Learning Rate Change as the Batch Size Increases?", we introduced various perspectives for analyzing the relationship between learning rate and batch size. The second-order approximate analysis proposed by OpenAI in "An Empirical Model of Large-Batch Training" occupied the main portion of that discussion, and this article continues to follow that same line of reasoning.
Next, we need to introduce some notation. Let the loss function be $\mathcal{L}(\boldsymbol{w})$, where $\boldsymbol{w}\in\mathbb{R}^N$ is the parameter vector, and let $\boldsymbol{g}$ be its gradient. Note that the ideal loss function is an expectation over the entire training set, but in practice we can only estimate it on a sampled batch, which makes the gradient random. We denote the gradient of a single sample by $\tilde{\boldsymbol{g}}$; its mean is $\boldsymbol{g}$ and its covariance matrix is $\boldsymbol{\Sigma}$. When the batch size is $B$, the batch gradient (the average of $B$ per-sample gradients) is denoted $\tilde{\boldsymbol{g}}_B$; its mean is still $\boldsymbol{g}$, but its covariance matrix becomes $\boldsymbol{\Sigma}/B$.
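As a quick numerical illustration of the $\boldsymbol{\Sigma}/B$ scaling, here is a minimal sketch with a made-up per-sample gradient distribution (Gaussian purely for convenience; averaging $B$ i.i.d. per-sample gradients shrinks the covariance by a factor of $B$ regardless of the distribution):

```python
import numpy as np

rng = np.random.default_rng(0)
N, B, num_batches = 4, 64, 20000

# Made-up per-sample gradient distribution with mean g and covariance Sigma.
g = rng.normal(size=N)
A = rng.normal(size=(N, N))
Sigma = A @ A.T                                   # an arbitrary PSD covariance matrix

# Draw per-sample gradients and average them within batches of size B.
samples = rng.multivariate_normal(g, Sigma, size=(num_batches, B))
g_B = samples.mean(axis=1)                        # batch gradients, shape (num_batches, N)

print("max |mean(g_B) - g|      :", np.abs(g_B.mean(axis=0) - g).max())
print("max |cov(g_B) - Sigma/B| :", np.abs(np.cov(g_B.T) - Sigma / B).max())
```

Both printed deviations should be small, consistent with the batch gradient having mean $\boldsymbol{g}$ and covariance $\boldsymbol{\Sigma}/B$.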
Furthermore, let the current learning rate be $\eta$ and the update vector be $\tilde{\boldsymbol{\varphi}}_B$. Then the loss function after the update will be:
\begin{equation}\begin{aligned}
\mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{\varphi}}_B) \approx&\, \mathcal{L}(\boldsymbol{w}) - \eta \tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2\tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{\varphi}}_B \\
=&\, \mathcal{L}(\boldsymbol{w}) - \eta \tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2\newcommand{\tr}{\mathop{\text{tr}}}\tr(\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{H})
\end{aligned}\end{equation}
On the right-hand side we have Taylor-expanded to second order, where $\boldsymbol{H}$ is the Hessian matrix and $\tr$ denotes the trace of a matrix. The second line uses the identity $\tr(\boldsymbol{A}\boldsymbol{B})=\tr(\boldsymbol{B}\boldsymbol{A})$. To obtain a deterministic result, we take expectations on both sides:
\begin{equation}\mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{\varphi}}_B)] \approx \mathcal{L}(\boldsymbol{w}) - \eta\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \tr(\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]\boldsymbol{H})\end{equation}
We treat the right side as a quadratic function of $\eta$. Assuming the coefficient of the second-order term is positive (a stronger assumption is that the $\boldsymbol{H}$ matrix is positive definite), we can find the minimum point:
\begin{equation}\eta^* \approx \frac{\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]^{\top}\boldsymbol{g}}{\tr(\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]\boldsymbol{H})}\end{equation}
This is the learning rate that, on average, decreases the loss fastest, i.e. the theoretically optimal learning rate. Our task is to compute $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$ and $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$ for a specific $\tilde{\boldsymbol{\varphi}}_B$, and then read off from the formula above how $\eta^*$ depends on the batch size $B$.
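Before specializing to any particular optimizer, here is a minimal numerical sketch of this recipe on a made-up quadratic toy problem ($\boldsymbol{H}$, $\boldsymbol{g}$ and $\boldsymbol{\Sigma}$ below are all synthetic, and the batch gradient is drawn as a Gaussian purely for convenience): estimate the two expectations by Monte Carlo and plug them into the formula for $\eta^*$.

```python
import numpy as np

rng = np.random.default_rng(0)
N, B, trials = 8, 32, 50000

# Toy quadratic loss L(w) = 0.5 * w^T H w, so the Hessian H and gradient g are known exactly.
Q = rng.normal(size=(N, N))
H = Q @ Q.T + N * np.eye(N)                       # positive-definite Hessian
w = rng.normal(size=N)
g = H @ w                                         # exact gradient at w
Sigma = np.diag(rng.uniform(1.0, 5.0, size=N))    # made-up per-sample gradient covariance

def eta_star(update_fn):
    """Monte Carlo estimates of E[phi_B] and E[phi_B phi_B^T], plugged into the eta* formula."""
    g_B = rng.multivariate_normal(g, Sigma / B, size=trials)   # batch gradients
    phi = update_fn(g_B)                                       # update directions, shape (trials, N)
    E_phi = phi.mean(axis=0)
    E_phi_phiT = phi.T @ phi / trials
    return (E_phi @ g) / np.trace(E_phi_phiT @ H)

print("SGD     eta*:", eta_star(lambda x: x))     # phi_B = g_B
print("SignSGD eta*:", eta_star(np.sign))         # phi_B = sign(g_B), discussed later in the article
```

The same function works for any componentwise update rule, which is exactly why the formula for $\eta^*$ is a convenient starting point.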
Warm-up Exercise
As the first example, we naturally consider the simplest case, SGD. In this case, $\tilde{\boldsymbol{\varphi}}_B=\tilde{\boldsymbol{g}}_B$. We can easily obtain $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]=\boldsymbol{g}$ and $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]=\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B$. Thus, we have:
\begin{equation}\eta^* \approx \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\tr((\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B)\boldsymbol{H})} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \tr(\boldsymbol{\Sigma}\boldsymbol{H})/B} = \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B}\label{eq:eta-sgd}\end{equation}
where
\begin{equation}\eta_{\max} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}},\qquad\mathcal{B}_{\text{noise}} = \frac{\tr(\boldsymbol{\Sigma}\boldsymbol{H})}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}\end{equation}
The result $\eqref{eq:eta-sgd}$ can be read in several ways. First, it is a monotonically increasing function of $B$ with upper bound $\eta_{\max}$, which says that the optimal learning rate cannot grow without limit as the batch size increases; compared with a plain linear or square-root scaling rule, this matches intuition better. When $B \ll \mathcal{B}_{\text{noise}}$, we have:
\begin{equation}\eta^* \approx \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} \approx \frac{\eta_{\max}}{\mathcal{B}_{\text{noise}}/B} = \eta_{\max} B / \mathcal{B}_{\text{noise}}\end{equation}
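The two regimes, proportional to $B$ when $B \ll \mathcal{B}_{\text{noise}}$ and saturating at $\eta_{\max}$ when $B \gg \mathcal{B}_{\text{noise}}$, are easy to see numerically; the constants below are arbitrary and chosen only to show the shape of the curve:

```python
# Arbitrary made-up constants, just to visualize eta*(B) = eta_max / (1 + B_noise / B).
eta_max, B_noise = 0.01, 512.0

for B in [8, 32, 128, 512, 2048, 8192, 32768]:
    exact = eta_max / (1 + B_noise / B)   # the closed form above
    linear = eta_max * B / B_noise        # the small-B linear approximation
    print(f"B={B:6d}  eta*={exact:.6f}  linear={linear:.6f}  (eta_max={eta_max})")
```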
So when the batch size is small, the optimal SGD learning rate is indeed approximately linear in the batch size, and $\mathcal{B}_{\text{noise}}$ emerges as the key statistic. However, the definition of $\mathcal{B}_{\text{noise}}$ involves the Hessian matrix $\boldsymbol{H}$, which is essentially impossible to compute exactly for LLMs. In practice we therefore usually assume it is (a multiple of) the identity matrix, which yields the simplified form:
\begin{equation}\mathcal{B}_{\text{simple}} = \frac{\tr(\boldsymbol{\Sigma})}{\boldsymbol{g}^{\top}\boldsymbol{g}}\end{equation}
This takes the form of noise intensity ($\tr(\boldsymbol{\Sigma})$) divided by signal intensity ($\boldsymbol{g}^{\top}\boldsymbol{g}$), i.e. the reciprocal of a signal-to-noise ratio (SNR). It says that the lower the SNR, the larger the batch size needed to use the same $\eta_{\max}$, which again matches intuition. Since $\tr(\boldsymbol{\Sigma})$ depends only on the diagonal elements of $\boldsymbol{\Sigma}$, we only need to estimate the mean and variance of each parameter's gradient independently, which is practically feasible.
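Concretely, $\mathcal{B}_{\text{simple}}$ can be estimated from a stack of per-sample (or per-micro-batch) gradients. The sketch below is a minimal version of such an estimator and is not taken from any particular paper; the only non-obvious step is that, when averaging $n$ per-sample gradients, the squared norm of the averaged gradient overestimates $\boldsymbol{g}^{\top}\boldsymbol{g}$ by $\tr(\boldsymbol{\Sigma})/n$ on average, which is corrected for:

```python
import numpy as np

def estimate_B_simple(per_sample_grads):
    """Estimate B_simple = tr(Sigma) / (g^T g) from an (n, N) stack of per-sample gradients."""
    n = per_sample_grads.shape[0]
    g_hat = per_sample_grads.mean(axis=0)
    tr_Sigma = per_sample_grads.var(axis=0, ddof=1).sum()   # sum of per-parameter variances
    # E[||g_hat||^2] = ||g||^2 + tr(Sigma)/n, so subtract the bias term.
    g_sq = g_hat @ g_hat - tr_Sigma / n
    return tr_Sigma / g_sq

# Synthetic check: gradients drawn around a known mean with known per-parameter variances.
rng = np.random.default_rng(0)
N = 1000
g = 0.1 * rng.normal(size=N)
sigma2 = rng.uniform(0.5, 2.0, size=N)
grads = g + rng.normal(size=(4096, N)) * np.sqrt(sigma2)

print("estimated B_simple:", estimate_B_simple(grads))
print("true      B_simple:", sigma2.sum() / (g @ g))
```

In a real training run the per-sample gradients would of course come from backpropagation rather than a synthetic Gaussian.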
Data Efficiency
Besides the direct relationship between learning rate and batch size, I find the resulting trade-off between training data volume and number of training steps just as essential and interesting to study. Notably, this conclusion appears to be more general than the learning-rate formula $\eqref{eq:eta-sgd}$: as we will see later, SignSGD leads to a result of exactly the same form, even though its learning-rate law is not given by $\eqref{eq:eta-sgd}$.
The original paper's treatment of this part is rather involved; the derivation below is simplified. Since a quadratic $a - b\eta + \frac{1}{2}c\eta^2$ with $c > 0$ attains its minimum value $a - \frac{b^2}{2c}$ at $\eta = b/c$, substituting $\eta^*$ back into the quadratic approximation of $\mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{g}}_B)]$ gives:
\begin{equation}\overline{\Delta\mathcal{L}} = \mathcal{L}(\boldsymbol{w}) - \mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta^*\tilde{\boldsymbol{g}}_B)] \approx \frac{\Delta\mathcal{L}_{\max}}{1 + \mathcal{B}_{\text{noise}}/B}\end{equation}
where $\Delta\mathcal{L}_{\max} = \frac{(\boldsymbol{g}^{\top}\boldsymbol{g})^2}{2\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}$. How should we read this result? First, it is a monotonically increasing function of $B$ that approaches $\Delta\mathcal{L}_{\max}$ as $B\to\infty$. In other words, with an infinitely large batch size the loss would decrease by $\Delta\mathcal{L}_{\max}$ per step, and the number of training steps required would be minimal; denote this minimum by $S_{\min}$.
If the batch size is finite, the average loss decrease per step is only $\overline{\Delta\mathcal{L}}$. This means that, on average, we need to take $1 + \mathcal{B}_{\text{noise}}/B$ steps to achieve the same loss decrease as one step with an infinite batch size. Thus, to reach the same loss, we need to train for $S = (1 + \mathcal{B}_{\text{noise}}/B)S_{\min}$ steps.
Since the batch size is $B$, it is easy to see that the total volume of training data consumed is $E = BS = (B + \mathcal{B}_{\text{noise}})S_{\min}$. This result shows that if we increase the batch size, to achieve the same effect, we also need to appropriately increase the data volume $E$. As $B\to 0$, the required data volume is minimized at $E_{\min} = \mathcal{B}_{\text{noise}}S_{\min}$. Using these notations, we can write:
\begin{equation}\left(\frac{S}{S_{\min}} - 1\right)\left(\frac{E}{E_{\min}} - 1\right) = 1\end{equation}
This is the classic relationship between training data volume and number of training steps. It has two parameters, $S_{\min}$ and $E_{\min}$. We can fit this equation to several experimental points $(S, E)$ (obtained, say, by training with different batch sizes), thereby estimating $S_{\min}$ and $E_{\min}$, and in turn $\mathcal{B}_{\text{noise}} = E_{\min} / S_{\min}$. For more details, please refer back to the previous article "How Should the Learning Rate Change as the Batch Size Increases?" or OpenAI's original paper "An Empirical Model of Large-Batch Training".
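As a side note, the fit itself is easy: a little algebra shows the relation is equivalent to $\frac{S_{\min}}{S} + \frac{E_{\min}}{E} = 1$, which is linear in $(S_{\min}, E_{\min})$, so ordinary least squares suffices. A minimal sketch on synthetic $(S, E)$ points (generated from the relation itself plus noise; in practice they would come from real training runs at different batch sizes):

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up ground truth, used only to generate synthetic (S, E) pairs.
S_min_true, E_min_true = 1.0e4, 2.0e7
B_noise_true = E_min_true / S_min_true

# Pretend we trained at several batch sizes and recorded (steps, examples) needed to
# reach a fixed target loss, with a little multiplicative noise on the step counts.
batch_sizes = np.array([64.0, 128, 256, 512, 1024, 2048, 4096])
S = (1 + B_noise_true / batch_sizes) * S_min_true * rng.lognormal(0.0, 0.02, batch_sizes.size)
E = batch_sizes * S                                  # data consumed: E = B * S

# (S/S_min - 1)(E/E_min - 1) = 1  <=>  S_min * (1/S) + E_min * (1/E) = 1.
A = np.stack([1.0 / S, 1.0 / E], axis=1)
(S_min_fit, E_min_fit), *_ = np.linalg.lstsq(A, np.ones_like(S), rcond=None)

print(f"S_min ~ {S_min_fit:.3g}, E_min ~ {E_min_fit:.3g}, B_noise ~ {E_min_fit / S_min_fit:.3g}")
```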
Difficulty Analysis
So far everything has been confined to SGD, which is computationally trivial. The real complexity arises when $\tilde{\boldsymbol{\varphi}}_B$ depends non-linearly on $\tilde{\boldsymbol{g}}_B$. For instance, SignSGD corresponds to $\newcommand{\sign}{\mathop{\text{sign}}}\tilde{\boldsymbol{\varphi}}_B=\sign(\tilde{\boldsymbol{g}}_B)$ and is often used in theoretical analyses as an approximation of Adam. A more accurate approximation is SoftSignSGD, which takes $\epsilon$ into account; we attempted to analyze it in "How Does Adam's epsilon Affect the Learning Rate Scaling Law?".
Under these non-linear scenarios, calculating $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$ and $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$ is often quite difficult, even if we assume that $\tilde{\boldsymbol{g}}_B$ follows a simple normal distribution (note that in the analysis of SGD, we did not need to make any assumptions about the form of its distribution). For example, in previous articles, to calculate $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$ for SignSGD where $\tilde{\boldsymbol{\varphi}}_B=\sign(\tilde{\boldsymbol{g}}_B)$, we went through the following steps:
1. Assume the components of $\tilde{\boldsymbol{g}}_B$ are independent, reducing the problem to the expectation of a single component $\tilde{\varphi}_B=\sign(\tilde{g}_B)$ (now scalars, hence no bold);
2. Assume $\tilde{g}_B$ (now a scalar) follows a normal distribution, allowing us to calculate $\mathbb{E}[\tilde{\varphi}_B]$, with the answer expressed using the $\newcommand{\erf}{\mathop{\text{erf}}}\erf$ function;
3. Approximate the $\erf$ function by a function of the form $x/\sqrt{x^2+c}$ to simplify the result (the two are compared numerically in the sketch below).
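As a quick check of steps 2 and 3, the sketch below compares the exact result under the normal assumption, $\mathbb{E}[\sign(\tilde{g}_B)] = \erf\big(g/(\sqrt{2}\,\sigma_B)\big)$, with an approximation of the form $x/\sqrt{x^2+c}$. The choice $c = \pi/4$ here simply matches the slope of $\erf$ at the origin and is mine for illustration; it is not necessarily the constant used in the papers referenced here.

```python
import numpy as np
from scipy.special import erf

# If g_B ~ N(g, sigma_B^2) componentwise, then E[sign(g_B)] = erf(g / (sqrt(2) * sigma_B)).
x = np.linspace(-4, 4, 9)              # x stands for g / (sqrt(2) * sigma_B)
c = np.pi / 4                          # matches the slope of erf at the origin (erf'(0) = 2/sqrt(pi))
for xi in x:
    print(f"x={xi:+.1f}  erf={erf(xi):+.4f}  x/sqrt(x^2+c)={xi / np.sqrt(xi**2 + c):+.4f}")

# Monte Carlo cross-check of the erf formula at one arbitrary point.
rng = np.random.default_rng(0)
g, sigma_B = 0.7, 1.3
samples = rng.normal(g, sigma_B, size=200_000)
print("MC  E[sign]:", np.sign(samples).mean())
print("erf formula:", erf(g / (np.sqrt(2) * sigma_B)))
```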
In other words, we had to go through a series of roundabout steps just to barely compute an approximate result for analysis (this process first appeared in Tencent's paper "Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling"). And this was considered the simple case because SoftSignSGD is even more complicated:
1. Assume the components of $\tilde{\boldsymbol{g}}_B$ are independent, simplifying the problem to the expectation of a single component's $\tilde{\varphi}_B=\newcommand{\softsign}{\mathop{\text{softsign}}}\softsign(\tilde{g}_B, \epsilon)$;
2. Approximate the $\softsign$ function by a piecewise linear function so that the expectation integral in the next step becomes tractable;
3. Assume $\tilde{g}_B$ follows a normal distribution, and combine with the approximation from step 2 to calculate $\mathbb{E}[\tilde{\varphi}_B]$, which results in a complex function containing $\erf$;
4. Approximate the complex function with a function of the form $x/\sqrt{x^2+c}$ to simplify the result.
The struggle does not end there. After all that effort and all those assumptions, we have only just computed $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$. Next comes $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$, which is usually much harder (SignSGD is an exception, since $\sign(x)^2$ is always 1, which keeps things simple). The computational burden, however, is secondary; the main issue is that these steps do not seem to reveal any generalizable pattern: it feels as though every specific problem demands its own bespoke analysis, which is mentally exhausting.
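That said, once the normal assumption is in place, both componentwise moments can always be obtained by direct numerical integration, which at least provides a way to sanity-check any closed-form approximation. Below is a minimal sketch; $\softsign(x,\epsilon) = x/(|x|+\epsilon)$ is only a stand-in definition for illustration and may differ from the exact form used in the referenced article.

```python
import numpy as np
from scipy.integrate import quad
from scipy.stats import norm

def softsign(x, eps):
    # Stand-in definition for illustration; the exact form in the referenced article may differ.
    return x / (np.abs(x) + eps)

def gaussian_moment(f, mean, std):
    """E[f(X)] for X ~ N(mean, std^2), by numerical integration."""
    integrand = lambda x: f(x) * norm.pdf(x, loc=mean, scale=std)
    value, _ = quad(integrand, mean - 10 * std, mean + 10 * std)
    return value

g, sigma_B, eps = 0.3, 1.0, 0.1
m1 = gaussian_moment(lambda x: softsign(x, eps), g, sigma_B)
m2 = gaussian_moment(lambda x: softsign(x, eps) ** 2, g, sigma_B)
print("numerical   E[softsign], E[softsign^2]:", m1, m2)

# Monte Carlo cross-check.
rng = np.random.default_rng(0)
s = softsign(rng.normal(g, sigma_B, size=500_000), eps)
print("Monte Carlo E[softsign], E[softsign^2]:", s.mean(), (s ** 2).mean())
```

Of course, numerical values are no substitute for the closed-form scaling laws we are after; they merely make it easier to check the approximations along the way.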
To Be Continued
To avoid making this article too long, we will stop here for now. We have mainly reviewed existing analytical results and computational difficulties. In the next article, I will introduce some of my attempts to reduce the mental burden during the derivation process.