Asymptotic Estimation of Weight RMS for AdamW

By 苏剑林 | October 01, 2025

In "Why Adam's Update RMS is 0.2?", we used the mean-field approximation to estimate the Update RMS of Adam. Shortly after, reader @EIFY pointed out that the same result already appears in the paper "Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks". After reading it, I found that it contains not only the estimation of Update RMS but also the estimation of Weight RMS.

That is to say, for a model trained with AdamW, the RMS of its weights can be estimated asymptotically in advance. Do you find this conclusion a bit surprising? I was quite surprised when I first saw it; intuitively, the weight magnitude should be something the model learns based on the training set, but it turns out to be hidden within the optimizer's hyperparameters, which is rather counter-intuitive.

In this article, we will again use the mean-field approximation method to reproduce the asymptotic estimation of the Weight RMS.

A Sliding-Average View

First, let's review the update rules of AdamW:

\begin{equation} \text{Adam}\color{skyblue}{\text{W}}:=\left\{\begin{aligned} &\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + \left(1 - \beta_1\right) \boldsymbol{g}_t\\ &\boldsymbol{v}_t = \beta_2 \boldsymbol{v}_{t-1} + \left(1 - \beta_2\right) \boldsymbol{g}_t^2\\ &\hat{\boldsymbol{m}}_t = \boldsymbol{m}_t\left/\left(1 - \beta_1^t\right)\right.\\ &\hat{\boldsymbol{v}}_t = \boldsymbol{v}_t\left/\left(1 - \beta_2^t\right)\right.\\ &\boldsymbol{u}_t =\hat{\boldsymbol{m}}_t\left/\left(\sqrt{\hat{\boldsymbol{v}}_t} + \epsilon\right)\right.\\ &\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t (\boldsymbol{u}_t \color{skyblue}{ + \lambda_t \boldsymbol{\theta}_{t-1}}) \end{aligned}\right. \end{equation} Once again, bold symbols represent vectors in $\mathbb{R}^d$, and vector multiplication and division (including squares and square roots) are element-wise Hadamard products/quotients.
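For reference, a single step of these update rules can be written out in a few lines of NumPy. This is an illustrative sketch; the function name and default hyperparameters below are ours, not from the original post:

import numpy as np

def adamw_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.95, eps=1e-8, wd=0.1):
    # all operations are element-wise, matching the Hadamard products/quotients above
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    m_hat = m / (1 - beta1**t)                 # bias correction
    v_hat = v / (1 - beta2**t)
    u = m_hat / (np.sqrt(v_hat) + eps)         # the update u_t
    theta = theta - lr * (u + wd * theta)      # decoupled weight decay
    return theta, m, v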

Similar to "Why Adam's Update RMS is 0.2?", we consider the case where $t \to \infty$ (relative to $\beta_1, \beta_2$) and $\epsilon \to 0$, so $\boldsymbol{u}_t = \boldsymbol{m}_t / \sqrt{\boldsymbol{v}_t}$. For now, let's consider the case where $\eta_t, \lambda_t$ are constants, so their subscripts can be omitted. Letting $\beta_3 = 1 - \eta \lambda$, we have:

\begin{equation}\boldsymbol{\theta}_t = \beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\boldsymbol{u}_t/\lambda)\label{eq:ema-wd}\end{equation}
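To verify this rewriting, simply substitute $\beta_3 = 1 - \eta\lambda$ into the last line of AdamW:

\begin{equation}\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta (\boldsymbol{u}_t + \lambda \boldsymbol{\theta}_{t-1}) = (1 - \eta\lambda)\boldsymbol{\theta}_{t-1} - \eta\,\boldsymbol{u}_t = \beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\boldsymbol{u}_t/\lambda)\end{equation}

where the last step uses $\eta = (1-\beta_3)/\lambda$.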

This equation indicates that we can understand Weight Decay from the perspective of an Exponential Moving Average (EMA) of the updates. This transition in perspective is very meaningful and serves as the foundation for works such as "How to set AdamW’s weight decay as you scale model and dataset size" and "Power Lines: Scaling Laws for Weight Decay and Batch Size in LLM Pre-training".

Weighted Average

According to equation $\eqref{eq:ema-wd}$, we can expand $\boldsymbol{\theta}_t$ into a weighted average form:

\begin{equation}\boldsymbol{\theta}_t = \beta_3^t\boldsymbol{\theta}_0 + (1-\beta_3)\sum_{i=1}^t \beta_3^{t-i} (-\boldsymbol{u}_i/\lambda)\label{eq:theta-t}\end{equation}

Similarly, $\boldsymbol{m}_t$ and $\boldsymbol{v}_t$ can also be expanded as:

\begin{equation}\boldsymbol{m}_t = (1 - \beta_1)\sum_{i=1}^t \beta_1^{t-i}\boldsymbol{g}_i,\qquad \boldsymbol{v}_t = (1 - \beta_2)\sum_{i=1}^t \beta_2^{t-i}\boldsymbol{g}_i^2\label{eq:mv-roll}\end{equation}

There is a small detail: in the expression for $\boldsymbol{\theta}_t$, we retained $\boldsymbol{\theta}_0$, but in the expressions for $\boldsymbol{m}_t$ and $\boldsymbol{v}_t$, we did not retain $\boldsymbol{m}_0$ and $\boldsymbol{v}_0$. There are two reasons: 1. The initialization of $\boldsymbol{m}$ and $\boldsymbol{v}$ is usually zero; 2. Even if their initialization is not zero, the corresponding $\beta_1^t$ and $\beta_2^t$ will become sufficiently close to zero, so the influence of initialization can be ignored.

However, $\boldsymbol{\theta}$ is the model weight, and its initialization is usually not zero. Furthermore, $\beta_3$ is often very close to 1, and for the entire training cycle, $\beta_3^t$ may not necessarily become sufficiently close to zero. Therefore, we explicitly retain $\beta_3^t$ and $\boldsymbol{\theta}_0$ and choose whether to keep them as needed.
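As a quick sanity check of the unrolled form $\eqref{eq:theta-t}$, the following small NumPy script (illustrative; the variable names and sizes are ours) compares the recursion $\eqref{eq:ema-wd}$ against the closed-form weighted average for random stand-in updates:

import numpy as np

np.random.seed(0)
d, T = 8, 200
eta, lam = 1e-3, 0.1
beta3 = 1 - eta * lam

u = np.random.randn(T + 1, d)            # stand-in updates u_1, ..., u_T (row 0 unused)
theta0 = np.random.randn(d) * 0.02       # theta_0

# recursion: theta_t = beta3 * theta_{t-1} + (1 - beta3) * (-u_t / lam)
theta = theta0.copy()
for t in range(1, T + 1):
    theta = beta3 * theta + (1 - beta3) * (-u[t] / lam)

# closed-form expansion: beta3^T * theta_0 + (1 - beta3) * sum_i beta3^(T - i) * (-u_i / lam)
theta_closed = beta3**T * theta0 + (1 - beta3) * sum(
    beta3**(T - i) * (-u[i] / lam) for i in range(1, T + 1)
)
print(np.allclose(theta, theta_closed))  # expected: True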

Fast Estimation

Our task is to estimate the Weight RMS, denoted as $\Vert\boldsymbol{\theta}_t\Vert_{RMS}$. As the name suggests, it is the Root Mean Square of the individual components:

\begin{equation}\Vert\boldsymbol{\theta}\Vert_{RMS} = \sqrt{\frac{1}{d}\sum_{i=1}^d \theta_i^2},\qquad\qquad \text{where } \boldsymbol{\theta} = (\theta_1,\theta_2,\cdots,\theta_d)\end{equation}

The difference between it and the norm is just the additional division by $\sqrt{d}$, so most properties of the norm also apply to the RMS. For $\Vert\boldsymbol{\theta}_t\Vert_{RMS}$, we have a fast but not entirely accurate derivation method: by directly taking $\Vert\cdot\Vert_{RMS}^2$ on both sides of equation $\eqref{eq:ema-wd}$, we get:

\begin{equation}\begin{aligned} \Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 =&\, \Vert\beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\boldsymbol{u}_t/\lambda)\Vert_{RMS}^2 \\[5pt] =&\, \beta_3^2\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2 + (1-\beta_3)^2\Vert\boldsymbol{u}_t\Vert_{RMS}^2/\lambda^2 - 2\beta_3(1-\beta_3)\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t/(\lambda d) \end{aligned}\end{equation}

Assuming $\boldsymbol{\theta}_{t-1}$ and $\boldsymbol{u}_t$ are nearly orthogonal, then $\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t\approx 0$. This is usually a good approximation in high-dimensional spaces (see "The distribution of the angle between two random vectors in n-dimensional space"). Since $\Vert\boldsymbol{u}_t\Vert_{RMS}$ has already been calculated as approximately $\sqrt{\frac{1-\beta_1}{1+\beta_1}}$, and considering the steady-state result where $\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2=\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2$, we have:

\begin{equation}(1-\beta_3^2)\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx (1-\beta_3)^2 \frac{1-\beta_1}{1+\beta_1} /\lambda^2\qquad\Rightarrow\qquad \Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{1-\beta_1}{1+\beta_1}\frac{\eta}{2\lambda}}\end{equation}

Going from the left-hand side to the right-hand side, we also used the approximation $\beta_3 \approx 1$. The final result carries some error because $\boldsymbol{\theta}_{t-1} \cdot \boldsymbol{u}_t \approx 0$ does not hold exactly, but the conclusion $\Vert\boldsymbol{\theta}_t\Vert_{RMS} \propto \sqrt{\eta/\lambda}$ is correct. A similar derivation appears in "Why Gradients Rapidly Increase Near the End of Training".
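As a concrete illustration (the numbers are chosen only for the example), take $\beta_1 = 0.9$, $\eta = 10^{-3}$, $\lambda = 0.1$. The fast estimate gives

\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{0.1}{1.9}\cdot\frac{10^{-3}}{0.2}} \approx 0.016\end{equation}

whereas the more careful mean-field estimate derived below (and the simulation at the end of this article) lands near $\sqrt{\eta/(2\lambda)} \approx 0.071$; in other words, the orthogonality assumption costs roughly a factor of $\sqrt{(1+\beta_1)/(1-\beta_1)}$ here.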

Better Approximation

In many cases, knowing that $\Vert\boldsymbol{\theta}_t\Vert_{RMS} \propto \sqrt{\eta/\lambda}$ is already enough, and this is a fairly universal conclusion. However, for readers who want a more accurate result, the mean-field method yields a better approximation. The cost is a more involved calculation; the benefit is more, and clearer, insight.

Step One

Starting from equation $\eqref{eq:theta-t}$, the summation term itself takes the form of a weighted average, so we first apply the first mean-field approximation:

\begin{equation}\bar{\boldsymbol{u}}_t \triangleq \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i} \boldsymbol{u}_i = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i} \frac{\hat{\boldsymbol{m}}_i}{\sqrt{\hat{\boldsymbol{v}}_i}}\approx \frac{\bar{\boldsymbol{m}}_t}{\sqrt{\bar{\boldsymbol{v}}_t}},\qquad \bar{\boldsymbol{m}}_t \triangleq \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{m}}_i,\qquad \bar{\boldsymbol{v}}_t \triangleq \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{v}}_i\label{eq:u-bar}\end{equation}

Now returning to equation $\eqref{eq:theta-t}$, since $\boldsymbol{\theta}_0$ is a random initialization vector, we can assume $\boldsymbol{\theta}_0$ is orthogonal to $\bar{\boldsymbol{u}}_t$, hence we have:

\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \lambda^{-2}\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2\end{equation}

Now we need to find $\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$. Based on previous experience, we assume that $\boldsymbol{g}_j$ are independently and identically distributed following $\mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2)$, and then calculate:

\begin{equation}\mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \mathbb{E}\left[\frac{\bar{\boldsymbol{m}}_t^2}{\bar{\boldsymbol{v}}_t}\right] \approx \frac{\mathbb{E}[\bar{\boldsymbol{m}}_t^2]}{\mathbb{E}[\bar{\boldsymbol{v}}_t]}\end{equation}

Finally, by averaging $\mathbb{E}[\bar{\boldsymbol{u}}_t^2]$ across components, we can use it as an approximation for $\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$.

Step Two

Combining with equation $\eqref{eq:mv-roll}$, we get:

\begin{gather} \sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{m}}_i = (1 - \beta_1)\sum_{i=1}^t \beta_3^{t-i} \sum_{j=1}^i \beta_1^{i-j}\boldsymbol{g}_j = (1 - \beta_1)\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\boldsymbol{g}_j\\ \sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{v}}_i = (1 - \beta_2)\sum_{i=1}^t \beta_3^{t-i} \sum_{j=1}^i \beta_2^{i-j}\boldsymbol{g}_j^2 = (1 - \beta_2)\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}\boldsymbol{g}_j^2 \end{gather}

If you lack ideas for simplifying the final double summation, you can refer to this link. From the equations above, we see that $\bar{\boldsymbol{m}}_t$ and $\bar{\boldsymbol{v}}_t$ are the weighted averages of the gradient and the gradient squared, respectively. Therefore, calculating $\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$ is essentially the same as calculating $\Vert \boldsymbol{u}_t\Vert_{RMS}^2$ in "Why Adam's Update RMS is 0.2?", only with different weighting coefficients.
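For readers who prefer to see the double-sum simplification spelled out, it is just a change of summation order followed by a geometric series (shown here for the $\hat{\boldsymbol{m}}$ case; the $\hat{\boldsymbol{v}}$ case is identical with $\beta_1 \to \beta_2$):

\begin{equation}\sum_{i=1}^t \beta_3^{t-i} \sum_{j=1}^i \beta_1^{i-j}\boldsymbol{g}_j = \sum_{j=1}^t\left(\sum_{i=j}^t \beta_3^{t-i}\beta_1^{i-j}\right)\boldsymbol{g}_j = \sum_{j=1}^t\left(\sum_{k=0}^{t-j} \beta_3^{t-j-k}\beta_1^{k}\right)\boldsymbol{g}_j = \sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\boldsymbol{g}_j\end{equation}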

Step Three

We first calculate the denominator:

\begin{equation}\begin{aligned} \mathbb{E}[\bar{\boldsymbol{v}}_t] =&\, \frac{(1 - \beta_3)(1 - \beta_2)}{1 - \beta_3^t}\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}\mathbb{E}[\boldsymbol{g}_j^2] \\ =&\, \frac{(1 - \beta_3)(1 - \beta_2)}{1 - \beta_3^t}\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}(\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2) \\ =&\, \frac{(1 - \beta_3)(1 - \beta_2)}{(1 - \beta_3^t)(\beta_3 - \beta_2)}\left(\frac{\beta_3 - \beta_3^{t+1}}{1 - \beta_3} - \frac{\beta_2 - \beta_2^{t+1}}{1 - \beta_2}\right)(\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2) \\[5pt] \approx &\, \boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2 \end{aligned}\end{equation}

The final approximation holds because, in actual training, $\beta_3$ is sufficiently close to 1 and $\beta_2^{t+1}$ is sufficiently close to 0, while $\beta_3^{t+1}$ might not be. Therefore, we set $\beta_2^{t+1}$ to zero, replaced stand-alone factors of $\beta_3$ by 1 after simplification (while keeping $1-\beta_3$ and $\beta_3^t$), and finally used the approximation $\beta_3^{t+1} \approx \beta_3^t$.

Step Four

Next is $\mathbb{E}[\bar{\boldsymbol{m}}_t^2] = \mathbb{E}[\bar{\boldsymbol{m}}_t]^2 + \mathbb{V}ar[\bar{\boldsymbol{m}}_t]$. The calculation of $\mathbb{E}[\bar{\boldsymbol{m}}_t]$ is similar to $\mathbb{E}[\bar{\boldsymbol{v}}_t]$, resulting in $\boldsymbol{\mu}$. For $\mathbb{V}ar[\bar{\boldsymbol{m}}_t]$, we utilize the additivity of variance:

\begin{equation}\begin{aligned} \mathbb{V}ar[\bar{\boldsymbol{m}}_t] =&\, \frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2}\sum_{j=1}^t \left(\frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\right)^2\mathbb{V}ar[\boldsymbol{g}_j] \\ =&\, \frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2}\sum_{j=1}^t \left(\frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\right)^2 \boldsymbol{\sigma}^2 \\ =&\, \frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2(\beta_3 - \beta_1)^2}\left(\frac{\beta_3^2 - \beta_3^{2(t+1)}}{1 - \beta_3^2} + \frac{\beta_1^2 - \beta_1^{2(t+1)}}{1 - \beta_1^2} - 2\frac{\beta_1\beta_3 - \beta_1^{t+1}\beta_3^{t+1}}{1 - \beta_1\beta_3}\right) \boldsymbol{\sigma}^2 \\[5pt] \approx &\, \frac{(1 - \beta_3)(1 + \beta_3^t)}{2(1 - \beta_3^t)}\boldsymbol{\sigma}^2 \end{aligned}\end{equation}

The reasoning for the approximation is the same as above.
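As a quick numerical check of this last approximation (the hyperparameter values below are illustrative, not from the original post), one can compare the exact coefficient of $\boldsymbol{\sigma}^2$ on the third line of the display above with the approximate one on the last line:

beta1, eta, lam, t = 0.9, 1e-3, 0.1, 10_000
beta3 = 1 - eta * lam

# exact coefficient of sigma^2 (third line of the display above)
pref = (1 - beta3)**2 * (1 - beta1)**2 / ((1 - beta3**t)**2 * (beta3 - beta1)**2)
bracket = ((beta3**2 - beta3**(2 * (t + 1))) / (1 - beta3**2)
           + (beta1**2 - beta1**(2 * (t + 1))) / (1 - beta1**2)
           - 2 * (beta1 * beta3 - (beta1 * beta3)**(t + 1)) / (1 - beta1 * beta3))
exact = pref * bracket

# approximate coefficient (last line): (1 - beta3)(1 + beta3^t) / (2 (1 - beta3^t))
approx = (1 - beta3) * (1 + beta3**t) / (2 * (1 - beta3**t))
print(exact, approx)   # the two printed numbers should nearly coincide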

Step Five

Substituting the results of the previous two steps, we have:

\begin{equation}\mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \frac{\boldsymbol{\mu}^2 + \frac{(1 - \beta_3)(1 + \beta_3^t)}{2(1 - \beta_3^t)}\boldsymbol{\sigma}^2}{\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2}\end{equation}

Then:

\begin{equation}\Vert\bar{\boldsymbol{u}}_t\Vert_{RMS}^2 \approx \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + \frac{(1 - \beta_3)(1 + \beta_3^t)}{2(1 - \beta_3^t)}}{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + 1} \end{equation}

Ultimately we get:

\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + \frac{(1 - \beta_3)(1 + \beta_3^t)}{2(1 - \beta_3^t)}}{\lambda^2\left(\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + 1\right)}\label{eq:theta-rms}\end{equation}

Analysis of Results

Formula $\eqref{eq:theta-rms}$ looks relatively complex; let's observe a few special cases. First, consider the case where $\boldsymbol{\mu}=\boldsymbol{0}$. Here:

\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^{2t}) \frac{1 - \beta_3}{2\lambda^2} = \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^{2t}) \frac{\eta}{2\lambda}\label{eq:theta-rms-mu0}\end{equation}

In particular, if we consider $t\to\infty$, or if $\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2$ is initialized to $\frac{\eta}{2\lambda}$, then we have:

\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}}\label{eq:theta-rms-simple}\end{equation}

This is the result given in the paper "Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks". Consistent with the original paper's assumptions, it is the steady-state result of a random walk under zero mean. If we do not consider $t\to\infty$ but instead consider the limit $\lambda\to 0$, equation $\eqref{eq:theta-rms-mu0}$ gives:

\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + \eta^2 t\end{equation}

This indicates that in the absence of weight decay, $\Vert\boldsymbol{\theta}_t\Vert_{RMS}$ grows roughly like $\eta\sqrt{t}$. It also suggests that without weight decay, the Weight RMS could still be kept stable by choosing a suitable learning-rate schedule. On the other hand, if the batch size is large enough that the signal-to-noise ratio term $\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2$ dominates, then equation $\eqref{eq:theta-rms}$ gives:

\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2}{\lambda^2(\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + 1)}\end{equation}

This might apply to special situations where the model needs to actively increase Weight RMS. However, from experience, this situation is generally unlikely to occur.

Simulation Experiment

We can use the following simulation script to quickly check the accuracy of the above estimates:

import numpy as np

N, T = 10000, 100000            # number of parameters, number of steps
beta1, beta2 = 0.9, 0.95
eta, lam = 0.001, 0.1           # learning rate and weight decay
m, v = 0, 0
w = np.random.randn(N) * 0.1    # weight initialization with RMS ~ 0.1
for i in range(T):
    g = np.random.randn(N)      # zero-mean, unit-variance gradients
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    w = w - eta * (m / v**0.5 + lam * w)   # AdamW step (bias correction and epsilon omitted)

weight_rms = (w**2).mean()**0.5
print(weight_rms)
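For the settings in this script ($\boldsymbol{\mu}=\boldsymbol{0}$, $\boldsymbol{\sigma}^2=\boldsymbol{1}$, $\eta=10^{-3}$, $\lambda=0.1$, and $\beta_3^T=(1-10^{-4})^{10^5}\approx e^{-10}\approx 0$), estimate $\eqref{eq:theta-rms-simple}$ predicts $\Vert\boldsymbol{\theta}_T\Vert_{RMS}\approx\sqrt{10^{-3}/0.2}\approx 0.0707$, and the printed value should come out close to this.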

You can try changing the weight initialization or the mean and variance of the gradients to see how the final results match equation $\eqref{eq:theta-rms}$. I tried a few cases myself, and overall, it is quite reliable.

Sign Version

With slight adjustments, the preceding derivation also applies to the combination of "SignSGDM + Weight Decay":

\begin{equation}\text{SignSGDM}\color{skyblue}{\text{W}}:=\left\{\begin{aligned} &\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + \left(1 - \beta_1\right) \boldsymbol{g}_t\\ &\boldsymbol{u}_t = \newcommand{\sign}{\mathop{\text{sign}}}\sign(\boldsymbol{m}_t)\\ &\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t (\boldsymbol{u}_t \color{skyblue}{ + \lambda_t \boldsymbol{\theta}_{t-1}}) \end{aligned}\right.\end{equation}

The modification is due to $\sign(\boldsymbol{m}_t)=\boldsymbol{m}_t/\sqrt{\boldsymbol{m}_t^2}$, so the definition of $\bar{\boldsymbol{v}}_t$ should be changed to:

\begin{equation}\bar{\boldsymbol{v}}_t \triangleq \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\boldsymbol{m}_i^2\end{equation}

Then:

\begin{equation}\mathbb{E}[\bar{\boldsymbol{v}}_t] = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\mathbb{E}[\boldsymbol{m}_i^2] \approx \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\left(\boldsymbol{\mu}^2 + \frac{1-\beta_1}{1 + \beta_1}\boldsymbol{\sigma}^2\right) = \boldsymbol{\mu}^2 + \frac{1-\beta_1}{1 + \beta_1}\boldsymbol{\sigma}^2\end{equation}

Here, the calculation of $\mathbb{E}[\boldsymbol{m}_i^2]$ can be found in "Why Adam's Update RMS is 0.2?" or "Rethinking Learning Rate and Batch Size (Part IV): EMA". Using these results, we get:

\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + \frac{(1 - \beta_3)(1 + \beta_3^t)}{2(1 - \beta_3^t)}}{\lambda^2\left(\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + \frac{1-\beta_1}{1 + \beta_1}\right)}\end{equation}

In particular, taking the limit $\boldsymbol{\mu}=\boldsymbol{0}$, $t\to\infty$, we have:

\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}\cdot\frac{1+\beta_1}{1 - \beta_1}}\end{equation}

This result is also reasonable because the Update RMS of SignSGDMW is $\sqrt{\frac{1+\beta_1}{1 - \beta_1}}$ times that of AdamW. Therefore, for the same $\eta, \lambda$, its Weight RMS is also $\sqrt{\frac{1+\beta_1}{1 - \beta_1}}$ times larger.
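The earlier simulation script can be adapted to probe this case as well. The following variant is illustrative (it is not from the original post), and since the sign nonlinearity is handled only at the mean-field level, the agreement with the prediction should be read as approximate:

import numpy as np

N, T = 10000, 100000
beta1 = 0.9
eta, lam = 0.001, 0.1
m = 0
w = np.random.randn(N) * 0.1
for _ in range(T):
    g = np.random.randn(N)                  # zero-mean, unit-variance gradients
    m = beta1 * m + (1 - beta1) * g
    w = w - eta * (np.sign(m) + lam * w)    # SignSGDM update plus decoupled weight decay

print((w**2).mean()**0.5)                                   # simulated Weight RMS
print((eta / (2 * lam) * (1 + beta1) / (1 - beta1))**0.5)   # mean-field prediction, ~0.31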

Related Analysis

As mentioned earlier, result $\eqref{eq:theta-rms-simple}$ is consistent with the paper "Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks", but our derivation method is completely different and leads to the more general $\eqref{eq:theta-rms}$. That said, the original paper also has some interesting ideas of its own, such as the concept of **Total Update Contribution (TUC)**, which is worth appreciating here.

The idea of TUC is as follows: because of momentum, the current gradient $\boldsymbol{g}_t$ does not act only at the current step; it also affects future steps (with a decay). Thus, assuming the number of training steps tends to infinity, we can consider the **total contribution** of the current gradient $\boldsymbol{g}_t$ to the entire training process. Specifically, for Adam we have $\boldsymbol{u}_t = \boldsymbol{m}_t / \sqrt{\boldsymbol{v}_t}$, and the contribution of the current $\boldsymbol{g}_t$ to $\boldsymbol{u}_t$ is $(1-\beta_1)\boldsymbol{g}_t / \sqrt{\boldsymbol{v}_t}$. At the next step, this contribution is decayed (multiplied by $\beta_1$) and the denominator becomes $\sqrt{\boldsymbol{v}_{t+1}}$, and so on. Thus, the total contribution can be defined as:

\begin{equation}\tilde{\boldsymbol{u}}_t = \sum_{k=t}^{\infty} (1-\beta_1)\beta_1^{k-t}\frac{\boldsymbol{g}_t}{\sqrt{\boldsymbol{v}_k}}\end{equation}
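To see in what sense this gives a decomposition, note that (up to the bias-correction factors we have already dropped) summing the updates over an infinite horizon and swapping the order of summation recovers exactly the same total:

\begin{equation}\sum_{t=1}^{\infty}\boldsymbol{u}_t = \sum_{t=1}^{\infty}\sum_{j=1}^{t}(1-\beta_1)\beta_1^{t-j}\frac{\boldsymbol{g}_j}{\sqrt{\boldsymbol{v}_t}} = \sum_{j=1}^{\infty}\sum_{t=j}^{\infty}(1-\beta_1)\beta_1^{t-j}\frac{\boldsymbol{g}_j}{\sqrt{\boldsymbol{v}_t}} = \sum_{j=1}^{\infty}\tilde{\boldsymbol{u}}_j\end{equation}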

In this way, we decompose the updates $\boldsymbol{u}_1, \boldsymbol{u}_2, \boldsymbol{u}_3, \cdots$ into contributions from $\tilde{\boldsymbol{u}}_1, \tilde{\boldsymbol{u}}_2, \tilde{\boldsymbol{u}}_3, \cdots$. The advantage is that each $\tilde{\boldsymbol{u}}_t$ only contains the gradient from a single step, so we can repeat the derivation from the Fast Estimation section:

\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 = \Vert\beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\tilde{\boldsymbol{u}}_t/\lambda)\Vert_{RMS}^2 \approx \beta_3^2\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2 + (1-\beta_3)^2\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS}^2/\lambda^2 \label{eq:tilde-u-rms}\end{equation}

The final approximation relies on $\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t\approx 0$. We claim that $\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t$ is closer to zero than $\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t$, because $\tilde{\boldsymbol{u}}_t$ depends only on the current gradient $\boldsymbol{g}_t$, while $\boldsymbol{\theta}_{t-1}$ has not yet encountered $\boldsymbol{g}_t$. They are therefore (approximately) independent, and when $\boldsymbol{g}_t$ has zero mean, $\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t\approx 0$ is likely to hold. To estimate $\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS}^2$, the original paper directly assumes that the vectors $\boldsymbol{g}_t/\sqrt{\boldsymbol{v}_k}$ (for all $k \geq t$) share the same direction and have unit RMS, so:

\begin{equation}\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS} = \sum_{k=t}^{\infty} (1-\beta_1)\beta_1^{k-t}\left\Vert\frac{\boldsymbol{g}_t}{\sqrt{\boldsymbol{v}_k}}\right\Vert_{RMS} = \sum_{k=t}^{\infty} (1-\beta_1)\beta_1^{k-t} = 1\end{equation}

Substituting this into equation $\eqref{eq:tilde-u-rms}$ and applying the same approximations as in the Fast Estimation section, we solve for:

\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}}\end{equation}

However, viewed purely through the lens of the original paper, many of these approximations seem rather mysterious. For example, $\boldsymbol{v}_t$ also contains $\boldsymbol{g}_t$, so saying that $\tilde{\boldsymbol{u}}_t$ only contains the influence of the current $\boldsymbol{g}_t$ is not entirely accurate; likewise, the assertion $\Vert\boldsymbol{g}_t/\sqrt{\boldsymbol{v}_k}\Vert_{RMS}=1$ seems somewhat forced. But placed in the context of this article, these steps look very reasonable under the mean-field approximation. In other words, the original paper was implicitly using the mean-field method.

Summary

In this article, we used the mean-field approximation to derive an interesting and perhaps surprising conclusion: for a model trained with AdamW, the RMS of its weights can be estimated asymptotically. In general, it primarily depends on the learning rate and Weight Decay.