Analysis of the AdaX Optimizer (with Open-Source Implementation)

By 苏剑林 | May 11, 2020

This article briefly introduces an optimizer called AdaX, from the paper "AdaX: Adaptive Gradient Descent with Exponential Long Term Memory". The reason for covering it is that it once again confirms a conclusion from my earlier post "Analysis of the AdaFactor Optimizer (with Open-Source Implementation)"; the two articles can be read side by side.

Adam & AdaX

The update format of AdaX is:

\begin{equation}\left\{\begin{aligned}&g_t = \nabla_{\theta} L(\theta_t)\\ &m_t = \beta_1 m_{t-1} + \left(1 - \beta_1\right) g_t\\ &v_t = (1 + \beta_2) v_{t-1} + \beta_2 g_t^2\\ &\hat{v}_t = v_t\left/\left(\left(1 + \beta_2\right)^t - 1\right)\right.\\ &\theta_t = \theta_{t-1} - \alpha_t m_t\left/\sqrt{\hat{v}_t + \epsilon}\right. \end{aligned}\right.\end{equation}

Where the default value of $\beta_2$ is $0.0001$. By the way, here is my Keras implementation: https://github.com/bojone/adax
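To make the update rule concrete, here is a minimal NumPy sketch of a single AdaX step on one parameter tensor, transcribed directly from the equations above. It is only an illustration (the function name `adax_step` and the `state` dictionary are my own conventions), not the Keras implementation linked above:

```python
import numpy as np

def adax_step(theta, grad, state, lr=1e-3, beta1=0.9, beta2=1e-4, eps=1e-8):
    """One AdaX update, written from the equations above (illustrative sketch only)."""
    m, v, t = state["m"], state["v"], state["t"] + 1
    m = beta1 * m + (1 - beta1) * grad       # momentum, no bias correction
    v = (1 + beta2) * v + beta2 * grad**2    # "exponential long term memory" accumulator
    v_hat = v / ((1 + beta2)**t - 1)         # bias correction for v
    theta = theta - lr * m / np.sqrt(v_hat + eps)
    state.update(m=m, v=v, t=t)
    return theta, state

# usage: state = dict(m=np.zeros_like(theta), v=np.zeros_like(theta), t=0)
```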

For comparison, the update format of Adam is:

\begin{equation}\left\{\begin{aligned}&g_t = \nabla_{\theta} L(\theta_t)\\ &m_t = \beta_1 m_{t-1} + \left(1 - \beta_1\right) g_t\\ &v_t = \beta_2 v_{t-1} + \left(1 - \beta_2\right) g_t^2\\ &\hat{m}_t = m_t\left/\left(1 - \beta_1^t\right)\right.\\ &\hat{v}_t = v_t\left/\left(1 - \beta_2^t\right)\right.\\ &\theta_t = \theta_{t-1} - \alpha_t \hat{m}_t\left/\sqrt{\hat{v}_t + \epsilon}\right. \end{aligned}\right.\end{equation}

Where the default value of $\beta_2$ is $0.999$.
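The same kind of sketch for Adam, again written straight from the equations rather than taken from any particular library, makes the comparison easier:

```python
import numpy as np

def adam_step(theta, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update, written from the equations above (illustrative sketch only)."""
    m, v, t = state["m"], state["v"], state["t"] + 1
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    m_hat = m / (1 - beta1**t)               # bias-corrected momentum
    v_hat = v / (1 - beta2**t)               # bias-corrected second moment
    theta = theta - lr * m_hat / np.sqrt(v_hat + eps)
    state.update(m=m, v=v, t=t)
    return theta, state
```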

Equivalent Form Transformation

As can be seen, the first difference is that AdaX removes the bias correction for momentum (the step $\hat{m}_t = m_t\left/\left(1 - \beta_1^t\right)\right.$), but this actually has little impact. The key modification in AdaX is in $v_t$. The original $v_t = \beta_2 v_{t-1} + \left(1 - \beta_2\right) g_t^2$ is a moving average, whereas $v_t = (1 + \beta_2) v_{t-1} + \beta_2 g_t^2$ does not look like one, and since $1 + \beta_2 > 1$ there seems to be a risk of exponential explosion. This is what the paper's title means by "Exponential Long Term Memory": because $1 + \beta_2 > 1$, the weights of the historically accumulated squared gradients do not decay but actually grow, which is the long-term memory property.

In fact, what corrects the learning rate is $\hat{v}_t$, so to judge whether anything really explodes we should look at $\hat{v}_t$ rather than $v_t$. For Adam, we have:

\begin{equation}\begin{aligned} \hat{v}_t =& v_t\left/\left(1 - \beta_2^t\right)\right.\\ =&\frac{\beta_2 v_{t-1} + (1-\beta_2) g_t^2}{1 - \beta_2^t}\\ =&\frac{\beta_2 \hat{v}_{t-1}\left(1 - \beta_2^{t-1}\right) + (1-\beta_2) g_t^2}{1 - \beta_2^t}\\ =&\beta_2\frac{1 - \beta_2^{t-1}}{1 - \beta_2^t}\hat{v}_{t-1} + \left(1 - \beta_2\frac{1 - \beta_2^{t-1}}{1 - \beta_2^t}\right)g_t^2 \end{aligned}\end{equation}

So if we set $\hat{\beta}_{2,t}=\beta_2\frac{1 - \beta_2^{t-1}}{1 - \beta_2^t}$, then the update formula is:

\begin{equation}\hat{v}_t =\hat{\beta}_{2,t}\hat{v}_{t-1} + \left(1 - \hat{\beta}_{2,t}\right)g_t^2\end{equation}

Based on the same logic, if we set $\hat{\beta}_{2,t}=1 - \frac{\beta_2}{(1 + \beta_2)^t - 1}$, then the update formula for AdaX's $\hat{v}_t$ can also be written in the above form.
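Both closed forms are easy to verify numerically. The quick check below (my own sketch, not from the paper) runs a random sequence of squared gradients through the original recursions with bias correction and through the equivalent moving-average form with $\hat{\beta}_{2,t}$, and confirms that the two agree at every step:

```python
import numpy as np

rng = np.random.default_rng(0)
g2 = rng.random(50)                      # a random sequence of squared gradients g_t^2
beta2_adam, beta2_adax = 0.999, 1e-4

v_adam = v_adax = 0.0                    # raw accumulators v_t
vh_adam = vh_adax = 0.0                  # v_hat computed via the moving-average form
for t, g2t in enumerate(g2, start=1):
    # original recursions plus bias correction
    v_adam = beta2_adam * v_adam + (1 - beta2_adam) * g2t
    v_adax = (1 + beta2_adax) * v_adax + beta2_adax * g2t
    target_adam = v_adam / (1 - beta2_adam**t)
    target_adax = v_adax / ((1 + beta2_adax)**t - 1)
    # equivalent moving averages with the effective decay coefficients
    b_adam = beta2_adam * (1 - beta2_adam**(t - 1)) / (1 - beta2_adam**t)
    b_adax = 1 - beta2_adax / ((1 + beta2_adax)**t - 1)
    vh_adam = b_adam * vh_adam + (1 - b_adam) * g2t
    vh_adax = b_adax * vh_adax + (1 - b_adax) * g2t
    assert np.isclose(vh_adam, target_adam) and np.isclose(vh_adax, target_adax)
```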

Comparison of Decay Strategies

Therefore, looking at $\hat{v}_t$, the quantity actually used to correct the learning rate, both Adam and AdaX update it in the moving average format; they differ only in the decay coefficient $\hat{\beta}_{2,t}$.

For Adam, when $t=1$ we have $\hat{\beta}_{2,t}=0$, so $\hat{v}_t = g_t^2$: the learning rate is corrected by the real-time gradient alone, which is the strongest possible correction. When $t\to\infty$, $\hat{\beta}_{2,t}\to \beta_2$, and $\hat{v}_t$ becomes a weighted average of the accumulated squared gradients and the current squared gradient. Since $\beta_2 < 1$, the weight $1 - \beta_2$ of the current gradient never vanishes, which may cause instability: in the later stages of training, gradients are small and training is already close to stable, so the learning rate correction matters less and its strength should shrink accordingly. Ideally, as $t\to\infty$ the learning rate should become essentially constant (effectively degrading to SGD), which requires that $\hat{\beta}_{2,t}\to 1$ as $t\to\infty$.

For AdaX, when $t=1$, $\hat{\beta}_{2,t}=0$, and when $t\to\infty$, $\hat{\beta}_{2,t}\to 1$, satisfying the ideal property mentioned above. Therefore, from this perspective, AdaX is indeed an improvement over Adam. AdaFactor uses $\hat{\beta}_{2,t} = 1 - \frac{1}{t^c}$, which was also designed from this perspective. As for whether the strategy of AdaX or AdaFactor is superior, I believe it is hard to explain clearly from a theoretical standpoint and can probably only be determined through experiments.
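To see the difference concretely, the short script below (my own illustration; the AdaFactor exponent is taken as $c=0.8$ here purely for the example) prints the effective $\hat{\beta}_{2,t}$ of the three schemes at a few time steps; Adam's coefficient plateaus at $\beta_2=0.999$, while AdaX's and AdaFactor's approach $1$:

```python
beta2_adam, beta2_adax, c = 0.999, 1e-4, 0.8   # c = 0.8 chosen for illustration

for t in [1, 10, 100, 1000, 10000, 100000]:
    b_adam = beta2_adam * (1 - beta2_adam**(t - 1)) / (1 - beta2_adam**t)
    b_adax = 1 - beta2_adax / ((1 + beta2_adax)**t - 1)
    b_adafactor = 1 - 1 / t**c
    print(f"t={t:>6}  Adam {b_adam:.6f}  AdaX {b_adax:.6f}  AdaFactor {b_adafactor:.6f}")

# All three start at 0 when t = 1; as t grows, Adam stays near beta_2 = 0.999,
# whereas AdaX and AdaFactor climb towards 1.
```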

Conclusion

Well, that's all for this article. As stated at the beginning, this post is just a brief introduction to AdaX, because it once again confirms an earlier conclusion: $\hat{\beta}_{2,t}$ should satisfy "$\hat{\beta}_{2,1}=0,\ \hat{\beta}_{2,\infty}=1$". This may well become one of the basic requirements for future improvements to optimizers.

Original Address: https://kexue.fm/archives/7387