When Batch Size Increases, How Should the Learning Rate Change Accordingly?

By 苏剑林 | November 14, 2024

With the rapid advancement of computing power, more and more scenarios hope to achieve "trading compute for time"—that is, shortening model training time by piling on more computing resources. Ideally, we hope that if we invest $n$ times the compute, the time to reach the same level of performance will be reduced to $1/n$, while the total compute cost remains identical. This "hope" seems reasonable and natural, but in reality, it is not trivial. Even if we ignore communication bottlenecks, when compute exceeds a certain scale or models are smaller than a certain scale, increasing compute often only results in increasing the Batch Size. However, does increasing the Batch Size necessarily shorten training time while maintaining performance?

This is the topic we are about to discuss: when the Batch Size increases, how should various hyperparameters, especially the learning rate, be adjusted to maintain the original training effect and maximize training efficiency? We can also call this the Scaling Law between Batch Size and Learning Rate.

Variance Perspective

Intuitively, when the Batch Size increases, the gradient for each batch will be more accurate, so the step can be larger—that is, the learning rate can be increased—in order to reach the destination faster and shorten training time. Most people can generally think of this point. The question is, how much of an increase is most appropriate?

Square Root Scaling

The earliest answer to this question might be square root scaling, where if the Batch Size is expanded by a factor of $n$, the learning rate is expanded by a factor of $\sqrt{n}$, originating from the 2014 paper "One weird trick for parallelizing convolutional neural networks". The derivation principle is to keep the variance of the SGD increments constant. Specifically, we denote the gradient of a randomly sampled example as $\tilde{\boldsymbol{g}}$, with its mean and covariance denoted as $\boldsymbol{g}$ and $\boldsymbol{\Sigma}$, respectively, where $\boldsymbol{g}$ is the gradient of all examples. When we increase the sample size to $B$, we have

\begin{equation}\tilde{\boldsymbol{g}}_B \triangleq \frac{1}{B}\sum_{i=1}^B \tilde{\boldsymbol{g}}^{(i)},\quad \mathbb{E}[\tilde{\boldsymbol{g}}_B] = \boldsymbol{g},\quad \mathbb{E}[(\tilde{\boldsymbol{g}}_B-\boldsymbol{g})(\tilde{\boldsymbol{g}}_B-\boldsymbol{g})^{\top}]=\frac{\boldsymbol{\Sigma}}{B}\end{equation}

This means that increasing the sample size does not change the mean, while the covariance shrinks to $1/B$ of the single-example covariance. For the SGD optimizer, the increment is $-\eta \tilde{\boldsymbol{g}}_B$, whose covariance is proportional to $\eta^2/B$. We believe that a moderate amount of noise (neither too much nor too little) is necessary during the optimization process. Therefore, when the Batch Size $B$ changes, we adjust the learning rate $\eta$ so that the noise intensity of the increment, i.e., its covariance matrix, remains constant, leading to

\begin{equation}\frac{\eta^2}{B} = \text{constant}\quad\Rightarrow\quad \eta\propto \sqrt{B}\end{equation}

This gives the square root scaling law for the learning rate and Batch Size. Later, "Train longer, generalize better: closing the generalization gap in large batch training of neural networks" also agreed with this choice.
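As a quick numerical sanity check of this variance argument, the following minimal numpy sketch (not from the cited papers; all names and values are illustrative) estimates the covariance of the mini-batch gradient at several batch sizes and confirms that square root scaling keeps the variance of the SGD increment constant, whereas linear scaling does not:

```python
import numpy as np

rng = np.random.default_rng(0)
dim, n_trials = 8, 5_000

# Toy per-example gradient distribution: mean g, covariance Sigma.
g = rng.normal(size=dim)
A = rng.normal(size=(dim, dim))
Sigma = A @ A.T / dim

def update_variance(batch_size, lr):
    """Empirical Tr(Cov) of the SGD increment -lr * g_batch."""
    grads = rng.multivariate_normal(g, Sigma, size=(n_trials, batch_size))
    g_batch = grads.mean(axis=1)                 # mini-batch gradients
    return (lr ** 2) * g_batch.var(axis=0).sum()

base_B, base_lr = 8, 0.1
for n in [1, 4, 16]:
    v_sqrt = update_variance(base_B * n, base_lr * np.sqrt(n))  # sqrt scaling
    v_lin = update_variance(base_B * n, base_lr * n)            # linear scaling
    print(f"n={n:2d}  sqrt-scaled var={v_sqrt:.4f}  linear-scaled var={v_lin:.4f}")
```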

Linear Scaling

Interestingly, linear scaling, i.e., $\eta\propto B$, often performs better in practice. Even the author of the aforementioned "One weird trick for parallelizing convolutional neural networks", which first proposed square root scaling, pointed this out in the paper and stated that he could not provide a reasonable explanation for it.

To some extent, linear scaling is more in line with our intuitive understanding, especially as described in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": assuming the gradient directions of $n$ consecutive batches do not change much, then linear scaling is almost self-evident. However, this assumption is obviously too strong. Relaxing this assumption requires linking SGD with SDEs (Stochastic Differential Equations), which was accomplished by "Stochastic Modified Equations and Dynamics of Stochastic Gradient Algorithms I: Mathematical Foundations". However, the first paper to point out the scaling relationship between the learning rate and Batch Size using this link should be "On the Generalization Benefit of Noise in Stochastic Gradient Descent".

In hindsight, the establishment of this connection is not difficult to understand. Let the model parameters be $\boldsymbol{\theta}$, then the SGD update rule can be rewritten as

\begin{equation}\boldsymbol{\theta}_{t+1} =\boldsymbol{\theta}_t - \eta \tilde{\boldsymbol{g}}_{B,t} =\boldsymbol{\theta}_t - \eta \boldsymbol{g}_t - \eta (\tilde{\boldsymbol{g}}_{B,t} - \boldsymbol{g}_t)\end{equation}

where $\tilde{\boldsymbol{g}}_{B,t} - \boldsymbol{g}_t$ is the gradient noise. So far, we have not made any assumptions about the distribution of this noise, only that its mean is $\boldsymbol{0}$ and its covariance is $\boldsymbol{\Sigma}_t/B$. Next, we assume that the distribution of this noise is a normal distribution $\mathcal{N}(\boldsymbol{0},\boldsymbol{\Sigma}_t/B)$. Then the above iteration can be further rewritten as

\begin{equation}\begin{aligned} \boldsymbol{\theta}_{t+1} =&\, \boldsymbol{\theta}_t - \eta \boldsymbol{g}_t - \eta (\tilde{\boldsymbol{g}}_{B,t} - \boldsymbol{g}_t) \\[5pt] =&\, \boldsymbol{\theta}_t - \eta \boldsymbol{g}_t - \eta \sqrt{\frac{\boldsymbol{\Sigma}_t}{B}}\boldsymbol{z},\quad \boldsymbol{z}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I}) \\[5pt] =&\, \boldsymbol{\theta}_t - \eta \boldsymbol{g}_t - \sqrt{\eta} \sqrt{\frac{\eta\boldsymbol{\Sigma}_t}{B}}\boldsymbol{z},\quad \boldsymbol{z}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I}) \end{aligned}\end{equation}

This means that the SGD iteration format $\boldsymbol{\theta}_{t+1} =\boldsymbol{\theta}_t - \eta \tilde{\boldsymbol{g}}_{B,t}$ is actually approximately solving the SDE:

\begin{equation}d\boldsymbol{\theta} = - \boldsymbol{g}_t dt - \sqrt{\frac{\eta\boldsymbol{\Sigma}_t}{B}}d\boldsymbol{w}\end{equation}

Therefore, for the running results to remain largely unchanged when $B$ changes, the form of the SDE above should be invariant, which leads to linear scaling $\eta\propto B$. The most critical step in this process is that the step size of the noise term in the SDE is the square root of the non-noise term, which separates out a factor of $\sqrt{\eta}$. We also commented on this point in "Generating Diffusion Models (5): General Framework—SDE Section"; simply put, zero-mean Gaussian noise has a certain cancellation effect over the long term, so the step size must be increased to manifest the noise effect.
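To illustrate the SDE argument numerically (a toy sketch under strong simplifications: a quadratic loss with additive Gaussian gradient noise, all constants made up), we can run noisy SGD with $(\eta, B)$ and with $(n\eta, nB)$ on the same total number of samples and compare the final losses; under linear scaling they should be roughly comparable:

```python
import numpy as np

rng = np.random.default_rng(1)
dim = 10
H = np.diag(np.linspace(0.5, 2.0, dim))   # Hessian of the toy quadratic loss
sigma = 1.0                               # per-example gradient noise scale

def run_sgd(lr, batch_size, n_steps, tail=200):
    """Noisy SGD on L(theta) = 0.5 * theta^T H theta; average loss over the last `tail` steps."""
    theta = np.ones(dim)
    losses = []
    for step in range(n_steps):
        grad = H @ theta + sigma * rng.normal(size=dim) / np.sqrt(batch_size)
        theta -= lr * grad
        if step >= n_steps - tail:
            losses.append(0.5 * theta @ H @ theta)
    return np.mean(losses)

base_lr, base_B, total_samples = 0.05, 8, 80_000
for n in [1, 2, 4, 8]:
    B, lr = base_B * n, base_lr * n       # linear scaling: lr grows with B
    steps = total_samples // B            # same total number of samples consumed
    print(f"B={B:3d}  lr={lr:.2f}  steps={steps:5d}  loss={run_sgd(lr, B, steps):.4e}")
```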

The above conclusions are all derived based on the SGD optimizer. The paper "On the SDEs and Scaling Rules for Adaptive Gradient Algorithms" generalized it to optimizers like RMSProp and Adam, resulting in square root scaling. Coincidentally, the slightly earlier "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" also applied square root scaling when testing Adam and its variant LAMB. For more details, refer to the blog "How to Scale Hyperparameters as Batch Size Increases".

Facing the Loss Directly

What is certain is that whether it is square root scaling or linear scaling, they can only be approximately true within a local range, because they both imply the conclusion that "as long as the Batch Size is large enough, the learning rate can be arbitrarily large," which is obviously impossible. Furthermore, the work in the previous two sections focused on variance, but our fundamental task is to reduce the loss function. Therefore, taking a loss-function-oriented approach may be more essential.

Monotonic and Bounded

The classic work in this perspective is OpenAI's "An Empirical Model of Large-Batch Training", which analyzes the optimal learning rate of SGD using a second-order approximation of the loss function, concluding that "the learning rate increases monotonically with the Batch Size but has an upper bound." The same idea appeared in the slightly earlier "Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients", although there it was not used to analyze the effect of Batch Size.

The most important concept in the entire derivation process is to treat the learning rate as an optimization parameter: let the loss function be $\mathcal{L}(\boldsymbol{\theta})$, and the gradient of the current batch be $\tilde{\boldsymbol{g}}_B$, then the loss function after SGD is $\mathcal{L}(\boldsymbol{\theta} - \eta\tilde{\boldsymbol{g}}_B)$. We regard solving for the optimal learning rate as an optimization problem:

\begin{equation}\eta^* = \mathop{\text{argmin}}_{\eta} \mathbb{E}[\mathcal{L}(\boldsymbol{\theta} - \eta\tilde{\boldsymbol{g}}_B)]\end{equation}

This goal is clearly intuitive—choosing the learning rate that makes the training efficiency highest (the loss function decreases fastest) on average. To solve this problem, we approximate the loss function by expanding it to the second order:

\begin{equation}\mathcal{L}(\boldsymbol{\theta} - \eta\tilde{\boldsymbol{g}}_B) \approx \mathcal{L}(\boldsymbol{\theta}) - \eta\tilde{\boldsymbol{g}}_B^{\top}\underbrace{\frac{\partial \mathcal{L}(\boldsymbol{\theta})}{\partial\boldsymbol{\theta}}}_{\text{which is }\boldsymbol{g}} + \frac{1}{2}\eta^2 \tilde{\boldsymbol{g}}_B^{\top}\underbrace{\frac{\partial^2 \mathcal{L}(\boldsymbol{\theta})}{\partial\boldsymbol{\theta}^2}}_{\text{denoted as }\boldsymbol{H}}\tilde{\boldsymbol{g}}_B = \mathcal{L}(\boldsymbol{\theta}) - \eta\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{g}}_B\end{equation}

where $\boldsymbol{H}$ is the Hessian matrix, and $\frac{\partial \mathcal{L}(\boldsymbol{\theta})}{\partial\boldsymbol{\theta}}$ is the gradient of the loss function. The ideal objective function is calculated based on all samples, which is why its gradient is the mean $\boldsymbol{g}$ of $\tilde{\boldsymbol{g}}_B$. Next, taking the expectation, we get

\begin{equation}\mathbb{E}[\mathcal{L}(\boldsymbol{\theta} - \eta\tilde{\boldsymbol{g}}_B)] \approx \mathbb{E}[\mathcal{L}(\boldsymbol{\theta}) - \eta\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{g}}_B] = \mathcal{L}(\boldsymbol{\theta}) - \eta\boldsymbol{g}^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \mathbb{E}[\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{g}}_B]\end{equation}

The last term involves a small trick:

\begin{equation}\begin{aligned} \mathbb{E}[\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{g}}_B] =&\, \mathbb{E}[\text{Tr}(\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{g}}_B)]= \mathbb{E}[\text{Tr}(\tilde{\boldsymbol{g}}_B\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H})] = \text{Tr}(\mathbb{E}[\tilde{\boldsymbol{g}}_B\tilde{\boldsymbol{g}}_B^{\top}]\boldsymbol{H}) \\[5pt] =&\, \text{Tr}((\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B)\boldsymbol{H}) = \boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \text{Tr}(\boldsymbol{\Sigma}\boldsymbol{H})/B \end{aligned}\end{equation}

The transformation process mainly utilizes $\text{Tr}(\boldsymbol{A}\boldsymbol{B}) = \text{Tr}(\boldsymbol{B}\boldsymbol{A})$. Now, assuming the positive definiteness of $\boldsymbol{H}$, the problem becomes finding the minimum of a quadratic function, which yields:

\begin{equation}\eta^* \approx \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \text{Tr}(\boldsymbol{\Sigma}\boldsymbol{H})/B} = \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B}\label{eq:eta-opt}\end{equation}

This results in the conclusion "monotonically increasing as $B$ increases with an upper bound," where

\begin{equation}\eta_{\max} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}},\qquad\mathcal{B}_{\text{noise}} = \frac{\text{Tr}(\boldsymbol{\Sigma}\boldsymbol{H})}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}\end{equation}
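To make these formulas concrete, here is a small numpy sketch (purely illustrative: the gradient, Hessian, and covariance are random toy values) that computes $\eta_{\max}$, $\mathcal{B}_{\text{noise}}$, and the resulting $\eta^*$ from Eq $\eqref{eq:eta-opt}$ for a few Batch Sizes:

```python
import numpy as np

rng = np.random.default_rng(2)
dim = 6

g = rng.normal(size=dim)                      # full-batch gradient
A = rng.normal(size=(dim, dim))
H = A @ A.T + dim * np.eye(dim)               # a positive-definite Hessian
Sigma = np.diag(rng.uniform(0.5, 2.0, dim))   # per-example gradient covariance

eta_max = (g @ g) / (g @ H @ g)
B_noise = np.trace(Sigma @ H) / (g @ H @ g)

for B in [1, 8, 64, 512, 4096]:
    eta_star = eta_max / (1 + B_noise / B)    # optimal SGD learning rate
    print(f"B={B:5d}  eta*={eta_star:.5f}")
print(f"eta_max={eta_max:.5f}  B_noise={B_noise:.2f}")
```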

Practical Analysis

When $B \ll \mathcal{B}_{\text{noise}}$, $1 + \mathcal{B}_{\text{noise}}/B\approx \mathcal{B}_{\text{noise}}/B$, so $\eta^* \approx \eta_{\max}B/\mathcal{B}_{\text{noise}}\propto B$, i.e., linear scaling. This again demonstrates that linear scaling is only a local approximation valid for small Batch Sizes. When $B > \mathcal{B}_{\text{noise}}$, $\eta^*$ gradually approaches the saturation value $\eta_{\max}$, and further increases in Batch Size bring diminishing returns: the extra training cost far outweighs the gain in training efficiency. Therefore, $\mathcal{B}_{\text{noise}}$ acts as a watershed; once the Batch Size exceeds this value, there is no need to keep investing compute to increase it further.

For practice, the most critical question is undoubtedly how to estimate $\eta_{\max}$ and $\mathcal{B}_{\text{noise}}$, especially since $\mathcal{B}_{\text{noise}}$ is directly related to the scaling law of the learning rate and the saturation of training efficiency. Direct calculation involves the Hessian matrix $\boldsymbol{H}$, whose computational complexity is proportional to the square of the parameter count. In an era where models with hundreds of millions of parameters are considered small, calculating the Hessian matrix is clearly unrealistic, so a more effective calculation method must be found.

Let's look at $\mathcal{B}_{\text{noise}}$ first. Its formula is $\frac{\text{Tr}(\boldsymbol{\Sigma}\boldsymbol{H})}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}$. Since both the numerator and the denominator have $\boldsymbol{H}$, there is an urge to "cancel" them out. In fact, the simplification follows exactly this logic. Assuming $\boldsymbol{H}$ is approximately a multiple of the identity matrix, we get

\begin{equation}\mathcal{B}_{\text{noise}} = \frac{\text{Tr}(\boldsymbol{\Sigma}\boldsymbol{H})}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}\approx \frac{\text{Tr}(\boldsymbol{\Sigma})}{\boldsymbol{g}^{\top}\boldsymbol{g}}\triangleq \mathcal{B}_{\text{simple}}\end{equation}

$\mathcal{B}_{\text{simple}}$ is more computationally feasible, and experiments have found it is usually a good approximation of $\mathcal{B}_{\text{noise}}$, so we choose to estimate $\mathcal{B}_{\text{simple}}$ instead of $\mathcal{B}_{\text{noise}}$. Note that $\text{Tr}(\boldsymbol{\Sigma})$ only requires the diagonal elements, so there is no need to calculate the full covariance matrix; one only needs to calculate the variance of each gradient component individually and then sum them. In data-parallel scenarios, the gradients calculated on each device can be used directly to estimate the gradient variance.
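A minimal sketch of this estimator is given below, assuming the per-micro-batch (e.g. per-device) gradients are available as rows of a matrix; the function name and shapes are our own illustration, not from the original paper:

```python
import numpy as np

def estimate_B_simple(per_device_grads, micro_batch_size):
    """Estimate B_simple = Tr(Sigma) / |g|^2 from K per-micro-batch gradients.

    per_device_grads: array of shape (K, num_params), each row being the mean
    gradient over one micro-batch of size `micro_batch_size`.
    """
    g_hat = per_device_grads.mean(axis=0)      # estimate of the full gradient g
    # The variance across micro-batches estimates Sigma/micro_batch_size per
    # component, so multiplying by the micro-batch size recovers Tr(Sigma).
    trace_sigma = micro_batch_size * per_device_grads.var(axis=0, ddof=1).sum()
    # Caution: g_hat @ g_hat overestimates |g|^2 by about Tr(Sigma)/(K*b),
    # so use enough micro-batches K (or debias this term) in practice.
    return trace_sigma / (g_hat @ g_hat)
```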

It should be noted that the results like Eq $\eqref{eq:eta-opt}$ are actually dynamic—theoretically, $\eta_{\max}$, $\mathcal{B}_{\text{noise}}$, and $\mathcal{B}_{\text{simple}}$ are different for every training step. Therefore, if we hope to obtain a static law, we need to train for a period of time until the model's training is back on the "right track" before calculating a reliable $\mathcal{B}_{\text{simple}}$, or we can continuously monitor $\mathcal{B}_{\text{simple}}$ during the training process to judge the gap between the current setting and the optimum.

As for $\eta_{\max}$, there is no need to estimate it using the formula; one can simply perform a grid search for the learning rate at a certain small Batch Size to find an approximate $\eta^*$, and then combine it with the estimated $\mathcal{B}_{\text{simple}}$ to back-calculate $\eta_{\max}$.
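In code, the back-calculation is a one-liner; the numbers below are placeholders for illustration:

```python
# Suppose a grid search at a small Batch Size B0 found the best learning rate
# eta_star, and B_simple was estimated as above (all numbers are placeholders).
B0, eta_star, B_simple = 64, 3e-4, 2048

# Invert eta* = eta_max / (1 + B_noise/B), using B_simple as a proxy for B_noise.
eta_max = eta_star * (1 + B_simple / B0)
print(f"estimated eta_max = {eta_max:.2e}")
```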

Data Efficiency

Starting from these results, we can also derive an asymptotic relationship between the training data volume and the number of training steps. The derivation process is simple: substituting $\eqref{eq:eta-opt}$ into the loss function shows that under the optimal learning rate, the reduction in the loss function per iteration is:

\begin{equation}\Delta\mathcal{L} = \mathcal{L}(\boldsymbol{\theta}) - \mathbb{E}[\mathcal{L}(\boldsymbol{\theta} - \eta^*\tilde{\boldsymbol{g}}_B)] \approx \frac{\Delta\mathcal{L}_{\max}}{1 + \mathcal{B}_{\text{noise}}/B}\label{eq:Delta-L-sgd}\end{equation}

where $\Delta\mathcal{L}_{\max} = \frac{(\boldsymbol{g}^{\top}\boldsymbol{g})^2}{2\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}$. The focus now is on the interpretation of this result.

When $B\to\infty$, which is full-batch SGD, the reduction in the loss function per step reaches the maximum $\Delta\mathcal{L}_{\max}$. At this point, the goal can be achieved with the minimum number of training steps (denoted as $S_{\min}$). When $B$ is finite, the average loss reduction per step is only $\Delta\mathcal{L}$, which means we need $1 + \mathcal{B}_{\text{noise}}/B$ steps to achieve the reduction of a single step of full-batch SGD. Thus, the total number of training steps is roughly $S = (1 + \mathcal{B}_{\text{noise}}/B)S_{\min}$.

Since the Batch Size is $B$, the total number of samples consumed during the training process is $E = BS = (B + \mathcal{B}_{\text{noise}})S_{\min}$, which is an increasing function of $B$. When $B\to 0$, $E_{\min} = \mathcal{B}_{\text{noise}}S_{\min}$. This indicates that as long as we use a small enough Batch Size to train the model, the total training samples $E$ required will also decrease accordingly, at the cost of very many training steps $S$. Further, using these notations, we can write the relationship between them as:

\begin{equation}\left(\frac{S}{S_{\min}} - 1\right)\left(\frac{E}{E_{\min}} - 1\right) = 1\label{eq:E-S}\end{equation}

This is the scaling law between the amount of training data and the number of training steps. It indicates that the smaller the data volume, the more the Batch Size should be reduced—allowing for more training steps—to increase the chance of reaching a more optimal solution. The derivation here has been simplified by the author, assuming the invariance of $\mathcal{B}_{\text{noise}}$ and $\Delta\mathcal{L}_{\max}$ throughout the training process. If necessary, one could follow the appendix of the original paper using integration to more precisely handle the dynamics (but this requires introducing the assumption $B = \sqrt{r\mathcal{B}_{\text{noise}}}$), which we won't expand on here.

Additionally, since $\mathcal{B}_{\text{noise}} = E_{\min}/S_{\min}$, the above equation also provides another scheme for estimating $\mathcal{B}_{\text{noise}}$: get multiple $(S,E)$ pairs through multiple experiments and grid searches, and then fit the above equation to estimate $E_{\min}, S_{\min}$, and subsequently calculate $\mathcal{B}_{\text{noise}}$.
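Rearranging Eq $\eqref{eq:E-S}$ gives $E = E_{\min}S/(S - S_{\min})$, which can be fitted directly. Here is a minimal fitting sketch with synthetic $(S,E)$ pairs (the data and noise level are made up):

```python
import numpy as np
from scipy.optimize import curve_fit

def E_of_S(S, S_min, E_min):
    """Total samples E implied by (S/S_min - 1)(E/E_min - 1) = 1."""
    return E_min * S / (S - S_min)

# Synthetic (S, E) measurements generated with S_min=1000, E_min=5e6 plus noise.
S_data = np.array([1500., 2000., 4000., 8000., 20000.])
E_data = E_of_S(S_data, 1000., 5e6)
E_data *= 1 + 0.03 * np.random.default_rng(3).normal(size=S_data.size)

(S_min, E_min), _ = curve_fit(E_of_S, S_data, E_data, p0=[800., 4e6])
print(f"S_min={S_min:.0f}  E_min={E_min:.3g}  B_noise={E_min / S_min:.1f}")
```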

Adaptive Version

It must be said that OpenAI is indeed one of the pioneers of various Scaling Laws. The previous analysis is quite brilliant and provides rich results. More importantly, the derivation process is not complex, giving a sense of "finding the essence in simplicity" (大道至简). However, current conclusions are based on SGD. Their applicability to adaptive learning rate optimizers like Adam is unclear. This part of the content was completed by "Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling".

Sign Approximation

The logic for analyzing Adam is the same as for SGD, based on second-order expansion. The difference is that the direction vector $\tilde{\boldsymbol{g}}_B$ is replaced by a general vector $\tilde{\boldsymbol{u}}_B$. Here, we have

\begin{equation}\mathbb{E}[\mathcal{L}(\boldsymbol{\theta} - \eta\tilde{\boldsymbol{u}}_B)] \approx \mathcal{L}(\boldsymbol{\theta}) - \eta\mathbb{E}[\tilde{\boldsymbol{u}}_B]^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \text{Tr}(\mathbb{E}[\tilde{\boldsymbol{u}}_B\tilde{\boldsymbol{u}}_B^{\top}]\boldsymbol{H})\end{equation}

Now we need to determine $\tilde{\boldsymbol{u}}_B$ and calculate the corresponding $\mathbb{E}[\tilde{\boldsymbol{u}}_B]$ and $\mathbb{E}[\tilde{\boldsymbol{u}}_B\tilde{\boldsymbol{u}}_B^{\top}]$. Since we only need an asymptotic relationship, just as in "Configuring different learning rates, can LoRA gain a bit more?", we choose SignSGD, i.e., $\tilde{\boldsymbol{u}}_B = \text{sign}(\tilde{\boldsymbol{g}}_B)$, as an approximation for Adam. This approach likely first appeared in "Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients". The rationality of this approximation is reflected in two points:

1. Regardless of the values of $\beta_1, \beta_2$, the update vector of the first step of Adam is $\text{sign}(\tilde{\boldsymbol{g}}_B)$;
2. When $\beta_1=\beta_2=0$, the update vector of Adam is always $\text{sign}(\tilde{\boldsymbol{g}}_B)$.

To calculate $\mathbb{E}[\tilde{\boldsymbol{u}}_B]$ and $\mathbb{E}[\tilde{\boldsymbol{u}}_B\tilde{\boldsymbol{u}}_B^{\top}]$, we again need to assume, just as in the "Linear Scaling" section, that $\tilde{\boldsymbol{g}}_B$ follows the distribution $\mathcal{N}(\boldsymbol{g},\boldsymbol{\Sigma}/B)$; to simplify the calculation, we further assume $\boldsymbol{\Sigma}$ is a diagonal matrix $\text{diag}(\sigma_1^2,\sigma_2^2,\sigma_3^2,\dots)$, i.e., that the components are independent. In this way, we can treat each component separately. By reparameterization, $\tilde{g}_B\sim \mathcal{N}(g, \sigma^2/B)$ is equivalent to $\tilde{g}_B=g + \sigma z/\sqrt{B}$ with $z\sim\mathcal{N}(0,1)$, thus

\begin{equation}\begin{aligned} \mathbb{E}[\tilde{u}_B] =&\, \mathbb{E}[\text{sign}(g + \sigma z/\sqrt{B})] = \mathbb{E}[\text{sign}(g\sqrt{B}/\sigma + z)] \\[5pt] =&\, \frac{1}{\sqrt{2\pi}}\int_{-\infty}^{\infty} \text{sign}(g\sqrt{B}/\sigma + z) e^{-z^2/2}dz \\[5pt] =&\, \frac{1}{\sqrt{2\pi}}\int_{-\infty}^{-g\sqrt{B}/\sigma} (-1)\times e^{-z^2/2}dz + \frac{1}{\sqrt{2\pi}}\int_{-g\sqrt{B}/\sigma}^{\infty} 1\times e^{-z^2/2}dz \\[5pt] =&\, \text{erf}\left(\frac{g}{\sigma}\sqrt{\frac{B}{2}}\right) \end{aligned}\end{equation}

Here $\text{erf}$ is the error function, which is an S-shaped function with a value range of $(-1,1)$, similar to $\tanh$. It can serve as a smooth approximation of $\text{sign}$. However, since $\text{erf}$ itself does not have an elementary function expression, we'd better find an elementary function approximation to observe the variation law more intuitively. We discussed this topic in "Where did the two elementary function approximations of GELU come from?", but the approximations there are still too complex (incorporating exponential operations). Let's use something simpler here:

\begin{equation}\text{erf}(x)\approx \text{sign}(x) = \frac{x}{|x|} = \frac{x}{\sqrt{x^2}}\approx \frac{x}{\sqrt{x^2+c}}\end{equation}

We choose $c=\pi/4$ so that the first-order approximation of this function at $x=0$ is equal to that of $\text{erf}$. Of course, having made so many heavy approximations, the value of $c$ is no longer critical; we only need to know that such a $c > 0$ exists. Based on this approximation, we get

\begin{equation}\mathbb{E}[\tilde{u}_B] \approx \frac{g/\sigma}{\sqrt{\pi/2B+(g/\sigma)^2}}\quad\Rightarrow\quad\mathbb{E}[\tilde{\boldsymbol{u}}_B]_i \approx \frac{g_i/\sigma_i}{\sqrt{\pi/2B+(g_i/\sigma_i)^2}}\triangleq \mu_i\end{equation}
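Both the erf formula and the elementary approximation are easy to verify numerically; the sketch below (with arbitrary toy values of $g$ and $\sigma$) compares them against a Monte Carlo estimate of $\mathbb{E}[\text{sign}(g+\sigma z/\sqrt{B})]$:

```python
import numpy as np
from scipy.special import erf

rng = np.random.default_rng(4)
g, sigma = 0.3, 1.5                       # arbitrary toy values

for B in [1, 4, 16, 64, 256]:
    z = rng.normal(size=1_000_000)
    mc = np.mean(np.sign(g + sigma * z / np.sqrt(B)))     # Monte Carlo estimate
    x = (g / sigma) * np.sqrt(B / 2)
    exact = erf(x)                                        # erf formula
    approx = x / np.sqrt(x ** 2 + np.pi / 4)              # elementary approximation
    print(f"B={B:4d}  MC={mc:+.4f}  erf={exact:+.4f}  approx={approx:+.4f}")
```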

We can find that a clear difference between Adam and SGD is that $\mathbb{E}[\tilde{\boldsymbol{u}}_B]$ is already related to $B$ at this step. Fortunately, the second moment is simpler now because the square of $\text{sign}(x)$ must be 1, therefore

\begin{equation}\mathbb{E}[\tilde{u}_B^2] = 1\quad\Rightarrow\quad\mathbb{E}[\tilde{\boldsymbol{u}}_B\tilde{\boldsymbol{u}}_B^{\top}]_{i,j} \to\left\{\begin{aligned}&=1, & i = j \\ &\approx\mu_i \mu_j,&\,i\neq j\end{aligned}\right.\end{equation}

Using these results, we can calculate:

\begin{gather}\eta^* \approx \frac{\mathbb{E}[\tilde{\boldsymbol{u}}_B]^{\top}\boldsymbol{g}}{\text{Tr}(\mathbb{E}[\tilde{\boldsymbol{u}}_B\tilde{\boldsymbol{u}}_B^{\top}]\boldsymbol{H})} \approx \frac{\sum_i \mu_i g_i}{\sum_i H_{i,i} + \sum_{i\neq j} \mu_i \mu_j H_{i,j}}\label{eq:eta-opt-sign} \\[5pt] \Delta \mathcal{L} = \mathcal{L}(\boldsymbol{\theta}) - \mathbb{E}[\mathcal{L}(\boldsymbol{\theta} - \eta^*\tilde{\boldsymbol{u}}_B)] \approx \frac{1}{2}\frac{(\sum_i \mu_i g_i)^2}{\sum_i H_{i,i} + \sum_{i\neq j} \mu_i \mu_j H_{i,j}}\label{eq:Delta-L-sign}\end{gather}

Two Special Cases

Compared to Eq $\eqref{eq:eta-opt}$ for SGD, Eq $\eqref{eq:eta-opt-sign}$ for Adam is more complex, and its dependence on $B$ cannot be read off directly, so we start with a few special cases.

First, consider $B\to\infty$. At this point $\mu_i = \text{sign}(g_i)$, so

\begin{equation}\eta^* \approx \frac{\sum_i |g_i|}{\sum_i H_{i,i} + \sum_{i\neq j} \text{sign}(g_i g_j) H_{i,j}}\end{equation}

The difference from SGD's $\eta_{\max}$ is that this expression is not invariant under a rescaling of the gradient: whereas $\eta_{\max}$ is homogeneous of degree zero in $\boldsymbol{g}$, here $\eta^*$ is proportional to the overall scale of the gradient.

Next, we consider the case where $\boldsymbol{H}$ is a diagonal matrix, i.e., $H_{i,j}=0$ when $i\neq j$. Then

\begin{equation}\eta^* \approx \frac{\sum_i \mu_i g_i}{\sum_i H_{i,i}}=\frac{1}{\sum_i H_{i,i}}\sum_i \frac{g_i^2/\sigma_i}{\sqrt{\pi/2B+(g_i/\sigma_i)^2}}\end{equation}

Every term in this sum is monotonically increasing in $B$ and bounded above, so the sum as a whole is too. To capture the most essential law, we can simplify $\mu_i$ further (taking a different route from the original paper):

\begin{equation}\mu_i = \frac{g_i/\sigma_i}{\sqrt{\pi/2B+(g_i/\sigma_i)^2}} = \frac{\text{sign}(g_i)}{\sqrt{1 + \pi(\sigma_i/g_i)^2/2B}} \approx \frac{\text{sign}(g_i)}{\sqrt{1 + \pi\kappa^2/2B}}\label{eq:mu-approx}\end{equation}

The assumption here is that there exists a constant $\kappa^2$, independent of $i$, such that replacing every $(\sigma_i/g_i)^2$ with $\kappa^2$ is a good approximation (for example, $\kappa^2$ could be taken as some average of the $(\sigma_i/g_i)^2$; in fact, $\kappa^2$ plays a role similar to the earlier $\mathcal{B}_{\text{simple}}$ and can be estimated according to its definition). Thus

\begin{equation}\eta^* \approx \frac{\sum_i \mu_i g_i}{\sum_i H_{i,i}}\approx \frac{\sum_i |g_i|}{\sum_i H_{i,i}}\frac{1}{\sqrt{1 + \pi\kappa^2/2B}}\label{eq:eta-opt-sign-diag}\end{equation}

When $\pi\kappa^2\gg 2B$, i.e., $B \ll \pi\kappa^2/2$, we can further write the approximation:

\begin{equation}\eta^* \approx \frac{\sum_i |g_i|}{\kappa\sum_i H_{i,i}}\sqrt{\frac{2B}{\pi}} \propto \sqrt{B}\end{equation}

This shows that when the Batch Size itself is small, Adam indeed follows the square root scaling law.
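The shape of this scaling is easy to see numerically. The snippet below (with toy values for $\kappa^2$ and for the prefactor $\sum_i|g_i|/\sum_i H_{i,i}$) evaluates Eq $\eqref{eq:eta-opt-sign-diag}$ together with its small-$B$ square-root asymptote:

```python
import numpy as np

kappa2 = 1e4        # toy value of kappa^2
prefactor = 1.0     # stands in for sum_i |g_i| / sum_i H_ii

for B in [16, 64, 256, 1024, 4096, 16384, 65536]:
    eta_star = prefactor / np.sqrt(1 + np.pi * kappa2 / (2 * B))   # diagonal-H formula
    eta_sqrt = prefactor * np.sqrt(2 * B / (np.pi * kappa2))       # small-B sqrt(B) asymptote
    print(f"B={B:6d}  eta*={eta_star:.4f}  sqrt asymptote={eta_sqrt:.4f}")
```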

Emergent Behavior

If we apply approximation $\eqref{eq:mu-approx}$ to the original equation $\eqref{eq:eta-opt-sign}$, we will find that it possesses some entirely new characteristics. Specifically, we have

\begin{equation}\eta^* \approx \frac{\sum_i \mu_i g_i}{\sum_i H_{i,i} + \sum_{i\neq j} \mu_i \mu_j H_{i,j}} \approx \frac{\eta_{\max}}{\frac{1}{2}\left(\frac{\beta_{\text{noise}}}{\beta} + \frac{\beta}{\beta_{\text{noise}}}\right)}\label{eq:eta-opt-beta}\end{equation}

where $\beta = (1 + \pi\kappa^2/2B)^{-1/2}$, and

\begin{equation}\beta_{\text{noise}} = \sqrt{\frac{\sum_i H_{i,i}}{\sum_{i\neq j}\text{sign}(g_i g_j) H_{i,j}}},\quad \eta_{\max} = \frac{\sum_i |g_i|}{2\sqrt{\left(\sum_i H_{i,i}\right)\left(\sum_{i\neq j} \text{sign}(g_i g_j) H_{i,j}\right)}}\end{equation}

Note that $\beta$ is a monotonically increasing function of $B$, but the last approximation in Eq $\eqref{eq:eta-opt-beta}$ is not a monotonically increasing function of $\beta$. It increases first and then decreases, with its maximum achieved at $\beta=\beta_{\text{noise}}$. This means there is a corresponding $\mathcal{B}_{\text{noise}}$. When the Batch Size exceeds this $\mathcal{B}_{\text{noise}}$, the optimal learning rate should not increase but rather decrease! This is the "Surge phenomenon" mentioned in the title of the original paper. (Of course, there is an additional constraint: $\beta$ is always less than $1$. If $\beta_{\text{noise}} \geq 1$, the relationship between the optimal learning rate and Batch Size remains monotonically increasing.)
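The Surge phenomenon itself is easy to reproduce from Eq $\eqref{eq:eta-opt-beta}$; in the sketch below (all values illustrative) we pick $\beta_{\text{noise}} < 1$ and evaluate $\eta^*$ over a range of Batch Sizes, confirming the peak at $B = \mathcal{B}_{\text{noise}}$:

```python
import numpy as np

kappa2 = 1e4          # toy value of kappa^2
beta_noise = 0.6      # < 1, so a Surge should appear
eta_max = 1.0         # illustrative scale

B_noise = np.pi * kappa2 * beta_noise ** 2 / (2 * (1 - beta_noise ** 2))
print(f"predicted optimum at B_noise = {B_noise:.0f}")

for B in [500, 2000, 8000, int(B_noise), 40000, 200000]:
    beta = (1 + np.pi * kappa2 / (2 * B)) ** -0.5
    eta_star = eta_max / (0.5 * (beta_noise / beta + beta / beta_noise))
    print(f"B={B:7d}  beta={beta:.3f}  eta*={eta_star:.4f}")
```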

How do we intuitively understand the Surge phenomenon? The author believes this is essentially a manifestation of the sub-optimality of adaptive learning rate strategies. Still taking the approximation $\tilde{\boldsymbol{u}}_B = \text{sign}(\tilde{\boldsymbol{g}}_B)$ as an example: the larger $B$ is, the more accurate $\tilde{\boldsymbol{g}}_B$ is. $B\to \infty$ yields $\text{sign}(\boldsymbol{g})$. However, is $\text{sign}(\boldsymbol{g})$ the most scientific update direction? Not necessarily, especially in the later stages of training where such adaptive strategies may also have negative effects. Therefore, when $B$ takes an appropriate value, the noise of $\text{sign}(\tilde{\boldsymbol{g}}_B)$ may actually correct this sub-optimality. As $B$ continues to increase, noise decreases, thereby reducing the opportunity for correction, which necessitates more cautious lowering of the learning rate.

Efficiency Relationship

As with the analysis of SGD, finally we can consider $\Delta\mathcal{L}$. By substituting Eq $\eqref{eq:eta-opt-beta}$ into Eq $\eqref{eq:Delta-L-sign}$, restoring the notation $B$ and then simplifying (the simplification process does not require any approximation), we obtain

\begin{equation}\Delta \mathcal{L} \approx \frac{\Delta \mathcal{L}_{\max}}{1 + \mathcal{B}_{\text{noise-2}}/B}\label{eq:Delta-L-sign-2}\end{equation}

where

\begin{equation}\Delta \mathcal{L}_{\max} = \frac{\beta_{\text{noise}}\eta_{\max}\sum_i|g_i|}{1 + \beta_{\text{noise}}^2},\quad \mathcal{B}_{\text{noise-2}} = \frac{\pi\kappa^2\beta_{\text{noise}}^2}{2(1 + \beta_{\text{noise}}^2)}\label{eq:beta-B-noise}\end{equation}

Note that $\mathcal{B}_{\text{noise-2}}$ here is a new notation; it is not the same as $\mathcal{B}_{\text{noise}}$, the latter being the theoretically optimal Batch Size obtained by solving $\beta=\beta_{\text{noise}}$, which gives

\begin{equation}\mathcal{B}_{\text{noise}} = \frac{\pi\kappa^2\beta_{\text{noise}}^2}{2(1 - \beta_{\text{noise}}^2)}\end{equation}

The relationship between them is

\begin{equation}\frac{1}{\mathcal{B}_{\text{noise-2}}} - \frac{1}{\mathcal{B}_{\text{noise}}} = \frac{4}{\pi\kappa^2}\quad\Rightarrow\quad \mathcal{B}_{\text{noise}} = \left(\frac{1}{\mathcal{B}_{\text{noise-2}}} - \frac{4}{\pi\kappa^2}\right)^{-1}\label{eq:B-1-2}\end{equation}

Since the form of Eq $\eqref{eq:Delta-L-sign-2}$ is the same as SGD's Eq $\eqref{eq:Delta-L-sgd}$, the analysis of that section applies equally, thus we can also derive Eq $\eqref{eq:E-S}$:

\begin{equation}\left(\frac{S}{S_{\min}} - 1\right)\left(\frac{E}{E_{\min}} - 1\right) = 1\end{equation}

except now $E_{\min}/S_{\min} = \mathcal{B}_{\text{noise-2}}$. This gives us a scheme to estimate $\beta_{\text{noise}}$ and $\mathcal{B}_{\text{noise}}$: get multiple $(S,E)$ pairs through multiple experiments. During the experiment, $\kappa^2$ can also be estimated. Then fit the equation to get $E_{\min}, S_{\min}$, and subsequently estimate $\mathcal{B}_{\text{noise-2}}$, finally solving $\beta_{\text{noise}}$ from Eq $\eqref{eq:beta-B-noise}$.
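Here is a sketch of this final solve, assuming $\mathcal{B}_{\text{noise-2}} = E_{\min}/S_{\min}$ and $\kappa^2$ have already been estimated (the numbers below are placeholders): invert Eq $\eqref{eq:beta-B-noise}$ for $\beta_{\text{noise}}$ and then apply Eq $\eqref{eq:B-1-2}$:

```python
import numpy as np

kappa2 = 1e4        # estimated during training (placeholder value)
B_noise_2 = 2.5e3   # = E_min / S_min from the (S, E) fit (placeholder value)

# Invert B_noise_2 = pi*kappa2*beta^2 / (2*(1 + beta^2)) for beta^2.
c = 2 * B_noise_2 / (np.pi * kappa2)
beta_noise = np.sqrt(c / (1 - c))          # requires c < 1

if beta_noise < 1:
    # Eq (B-1-2): 1/B_noise = 1/B_noise_2 - 4/(pi*kappa2)
    B_noise = 1 / (1 / B_noise_2 - 4 / (np.pi * kappa2))
    print(f"beta_noise={beta_noise:.3f}, optimal Batch Size B_noise = {B_noise:.0f}")
else:
    print(f"beta_noise={beta_noise:.3f} >= 1: eta* increases monotonically with B")
```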

If $\beta_{\text{noise}} \geq 1$, then there is no optimal $\mathcal{B}_{\text{noise}}$. If $\beta_{\text{noise}} \gg 1$, it indicates the diagonal elements of the Hessian matrix dominate, so the scaling law $\eqref{eq:eta-opt-sign-diag}$ applies—increasing Batch Size can always moderately increase the learning rate. When $\beta_{\text{noise}} < 1$, the optimal $\mathcal{B}_{\text{noise}}$ can be solved from $\eqref{eq:B-1-2}$. If Batch Size exceeds this value, the learning rate should decrease.

Supplementary Remarks

It should be pointed out that the starting point and final conclusion of the analysis in the previous sections are largely consistent with the original paper "Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling", but the handling of approximations in the intermediate process is different.

Most of the conclusions in the original paper are approximate results under the assumption $B \ll \pi(\sigma_i/g_i)^2/2$, which leads to the conclusion that the Surge phenomenon will almost always appear. This is not entirely scientific. The most obvious problem is the form of the assumption $B \ll \pi(\sigma_i/g_i)^2/2$ itself: its right side depends on $i$. We cannot assign a separate Batch Size to each component. Therefore, to get a global result, it would have to be $B \ll \min_i \pi(\sigma_i/g_i)^2/2$, which is somewhat harsh.

The approach in this article is to introduce the approximation $\eqref{eq:mu-approx}$, which can be seen as a mean-field approximation. Intuitively, it is more reasonable than the point-by-point assumption $B \ll \pi(\sigma_i/g_i)^2/2$, so its conclusions should in principle be more accurate. For example, it yields the conclusion that even if the off-diagonal elements of the Hessian matrix are not negligible, the Surge phenomenon does not necessarily occur (it depends on $\beta_{\text{noise}}$). Moreover, this accuracy does not come at the cost of simplicity: Eq $\eqref{eq:eta-opt-beta}$ remains simple and clear, and Eq $\eqref{eq:Delta-L-sign-2}$ has exactly the same form as in the original paper without requiring additional approximation assumptions.

Finally, a small reflection: OpenAI's analysis of SGD is actually 2018 work, while the paper on the Surge phenomenon was only released in the middle of this year. It is quite surprising that going from SGD to Adam took six years. This is probably due to OpenAI's "prestige" and the guess it offered in passing for the Adam case, which made everyone think there was nothing left to do for Adam; no one expected that Adam might exhibit some new characteristics. Of course, in the author's opinion, how reasonable $\tilde{\boldsymbol{u}}_B = \text{sign}(\tilde{\boldsymbol{g}}_B)$ is as an approximation for Adam, and to what extent it reflects practical situations, still deserves further thought.

Conclusion

This article has discussed the classic "AI alchemy" problem of the "Scaling Law between Batch Size and Learning Rate" from multiple perspectives. It has focused on the derivation and conclusion of OpenAI's analysis based on a second-order approximation of the loss function, as well as subsequent work using the same idea to analyze the Adam optimizer.