By 苏剑林 | October 10, 2020
The theme of this article is "why we need a finite learning rate." By "finite," we mean the rate should be neither too large nor too small, but just right. It is easy to understand that a learning rate that is too large can lead to the divergence of the algorithm. But why is a learning rate that is too small also undesirable? An easy-to-understand answer is that a learning rate that is too small requires an excessive number of iterations, which is a waste of time and computational power. Therefore, from the perspective of "efficiency" and "acceleration," we avoid using excessively small learning rates. However, if we set aside the factors of computational power and time, would an extremely small learning rate be desirable? A recent paper published by Google on arXiv, "Implicit Gradient Regularization," attempts to answer this question. It points out that a finite learning rate implicitly introduces a gradient penalty term into the optimization process, which is beneficial for improving generalization performance. Therefore, even without considering computational power or time, one should not use a learning rate that is too small.
Regarding gradient penalty, this blog has discussed it several times. In the articles "A Brief Talk on Adversarial Training: Significance, Methods, and Reflections" and "Random Thoughts on Generalization: From Random Noise and Gradient Penalty to Virtual Adversarial Training," we analyzed how adversarial training is somewhat equivalent to a gradient penalty on the input. The article "Do We Really Need to Reduce Training Set Loss to Zero?" introduced the Flooding technique, which corresponds to a gradient penalty on the parameters. In general, whether it's a gradient penalty on inputs or parameters, it helps improve generalization capabilities.
Following this series of articles, we view the optimization process as solving a differential equation. Recalling the previous post "Optimization Algorithms from a Dynamical Perspective (III): A More Holistic View," let the loss function be $L(\boldsymbol{\theta})$, and regard $\boldsymbol{\theta}$ as a trajectory $\boldsymbol{\theta}(t)$ evolving along some time parameter $t$. Now, let's consider its rate of change:
\begin{equation}\frac{d}{dt}L(\boldsymbol{\theta}(t))=\left\langle\nabla_{\boldsymbol{\theta}}L(\boldsymbol{\theta}(t)),\, \dot{\boldsymbol{\theta}}(t)\right\rangle\end{equation}We want $L(\boldsymbol{\theta}(t))$ to decrease over time (the lower the loss, the better), so we want the above expression to be negative. When the magnitude $\Vert\dot{\boldsymbol{\theta}}(t)\Vert$ is fixed, the right-hand side is minimized (by the Cauchy-Schwarz inequality) when $\dot{\boldsymbol{\theta}}(t)$ points along the negative gradient $-\nabla_{\boldsymbol{\theta}}L(\boldsymbol{\theta}(t))$. Thus, we say the negative gradient direction is the direction of steepest descent. For simplicity, we can let
\begin{equation}\dot{\boldsymbol{\theta}}(t) = -\nabla_{\boldsymbol{\theta}}L(\boldsymbol{\theta}(t))\triangleq - \boldsymbol{g}(\boldsymbol{\theta}(t))\label{eq:odes}\end{equation}Then, solving for the parameters $\boldsymbol{\theta}$ transforms into solving the above system of ordinary differential equations (ODEs). This is the basic starting point for this series.
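As a concrete illustration (a toy example of my own choosing, not from the original paper): for a quadratic loss the gradient flow can be solved in closed form, so we can see exactly what the "exact" trajectory looks like. A minimal sketch, assuming a symmetric positive-definite matrix `A` and initial point `theta0` picked arbitrarily:

```python
import numpy as np
from scipy.linalg import expm

# Toy quadratic loss L(theta) = 1/2 * theta^T A theta (an illustrative choice),
# whose gradient is g(theta) = A @ theta, so the gradient flow is
#   dtheta/dt = -A @ theta,  with exact solution  theta(t) = expm(-A t) @ theta(0).
A = np.array([[3.0, 1.0],
              [1.0, 2.0]])
theta0 = np.array([1.0, -1.0])

def flow(t):
    """Exact gradient-flow trajectory for the toy quadratic loss."""
    return expm(-A * t) @ theta0

for t in [0.0, 0.5, 1.0, 2.0]:
    print(t, flow(t))
```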
However, the practical problem is that we cannot truly solve the system of differential equations $\eqref{eq:odes}$; we can only use numerical iteration. Using the simplest Euler method, we get:
\begin{equation}\boldsymbol{\theta}_{t+\gamma} = \boldsymbol{\theta}_{t} - \gamma \boldsymbol{g}(\boldsymbol{\theta}_t)\label{eq:gd}\end{equation}This is actually the most basic Gradient Descent (GD) method, where $\gamma$ is what we usually call the learning rate. Effectively, this is a difference equation.
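Continuing with the same toy quadratic as above, here is a minimal sketch of the Euler discretization, which is exactly the GD update; the step size $\gamma$ and iteration count are arbitrary choices for illustration:

```python
import numpy as np

# Same toy quadratic as above: g(theta) = A @ theta.
A = np.array([[3.0, 1.0],
              [1.0, 2.0]])

def grad(theta):
    return A @ theta

def gradient_descent(theta0, gamma, steps):
    """Euler discretization of the gradient flow: theta <- theta - gamma * g(theta)."""
    theta = np.array(theta0, dtype=float)
    for _ in range(steps):
        theta = theta - gamma * grad(theta)
    return theta

print(gradient_descent([1.0, -1.0], gamma=0.1, steps=20))
```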
One can imagine starting from $t=0$, the resulting points $\boldsymbol{\theta}_{\gamma},\boldsymbol{\theta}_{2\gamma},\boldsymbol{\theta}_{3\gamma},\cdots$ will differ to some extent from the exact solutions $\boldsymbol{\theta}(\gamma),\boldsymbol{\theta}(2\gamma),\boldsymbol{\theta}(3\gamma),\cdots$ of the equation system $\eqref{eq:odes}$. How do we measure the extent of this discrepancy? Imagine that $\boldsymbol{\theta}_{\gamma},\boldsymbol{\theta}_{2\gamma},\boldsymbol{\theta}_{3\gamma},\cdots$ are actually exact solutions to a system of differential equations similar to $\eqref{eq:odes}$, but with $\boldsymbol{g}(\boldsymbol{\theta}(t))$ replaced by some new $\tilde{\boldsymbol{g}}(\boldsymbol{\theta}_t)$. We can then compare the difference between $\tilde{\boldsymbol{g}}(\boldsymbol{\theta}_t)$ and $\boldsymbol{g}(\boldsymbol{\theta}(t))$.
After derivation, if we only keep terms up to the first order of $\gamma$, we have:
\begin{equation}\tilde{\boldsymbol{g}}(\boldsymbol{\theta}_t) = \boldsymbol{g}(\boldsymbol{\theta}_t) + \frac{\gamma}{4}\nabla_{\boldsymbol{\theta}}\Vert \boldsymbol{g}(\boldsymbol{\theta}_t)\Vert^2 = \nabla_{\boldsymbol{\theta}}\left(L(\boldsymbol{\theta}_t) + \frac{1}{4}\gamma\Vert \nabla_{\boldsymbol{\theta}} L(\boldsymbol{\theta}_t)\Vert^2\right)\end{equation}The derivation is given in the next section. As we can see, this is equivalent to adding a gradient penalty regularization term $\frac{1}{4}\gamma\Vert \nabla_{\boldsymbol{\theta}} L(\boldsymbol{\theta})\Vert^2$ to the loss function. Such a penalty pushes the model toward flatter regions, which helps generalization. In other words, the discretized iteration implicitly introduces a gradient penalty that benefits the model's generalization, and as $\gamma \to 0$ this implicit penalty weakens and eventually disappears.
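We can check this claim numerically on the toy quadratic from above. For $L=\frac{1}{2}\boldsymbol{\theta}^{\top}A\boldsymbol{\theta}$ the modified flow is still linear, $\dot{\boldsymbol{\theta}}=-(A+\frac{\gamma}{2}A^2)\boldsymbol{\theta}$, so both the original and the modified flow have closed-form solutions, and the GD iterates should lie much closer to the modified flow than to the original one. A sketch under the same toy assumptions (the matrix, step size, and step count are my own choices):

```python
import numpy as np
from scipy.linalg import expm

# Same toy quadratic: L(theta) = 1/2 theta^T A theta, g(theta) = A @ theta.
A = np.array([[3.0, 1.0],
              [1.0, 2.0]])
theta0 = np.array([1.0, -1.0])
gamma, steps = 0.1, 20
T = gamma * steps

# Gradient descent iterates (Euler discretization).
theta = theta0.copy()
for _ in range(steps):
    theta = theta - gamma * (A @ theta)

# Exact solution of the original flow  dtheta/dt = -A @ theta.
theta_flow = expm(-A * T) @ theta0

# Exact solution of the modified flow  dtheta/dt = -(A + gamma/2 * A @ A) @ theta,
# i.e. the flow of g_tilde = g + (gamma/4) * grad ||g||^2.
A_mod = A + 0.5 * gamma * (A @ A)
theta_mod = expm(-A_mod * T) @ theta0

print("GD vs original flow :", np.linalg.norm(theta - theta_flow))
print("GD vs modified flow :", np.linalg.norm(theta - theta_mod))  # noticeably smaller
```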
Therefore, the conclusion is that the learning rate should not be too small: a larger learning rate not only accelerates convergence but also improves the model's generalization ability. Of course, some readers might wonder: if I add the gradient penalty to the loss directly, could I then use an arbitrarily small learning rate? Theoretically, yes. The original paper refers to this practice of adding the gradient penalty to the loss directly as an "explicit gradient penalty."
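For completeness, here is a sketch of what the explicit version could look like on the same toy quadratic: we add $\frac{1}{4}\gamma\Vert\nabla_{\boldsymbol{\theta}}L\Vert^2$ to the loss and are then free to shrink the step size. In general the penalized gradient requires a Hessian-vector product; for the quadratic it can be written out directly. All constants below are illustrative choices, not values from the paper:

```python
import numpy as np

# Explicit gradient penalty on the toy quadratic:
#   L_reg(theta) = L(theta) + (gamma/4) * ||grad L(theta)||^2.
# For L = 1/2 theta^T A theta its gradient is A theta + (gamma/2) A^T A theta;
# for a general model this extra term would need a Hessian-vector product.
A = np.array([[3.0, 1.0],
              [1.0, 2.0]])
gamma_penalty = 0.1   # plays the role of the finite learning rate being imitated

def grad_penalized(theta):
    g = A @ theta
    return g + 0.5 * gamma_penalty * (A.T @ (A @ theta))

# With the penalty made explicit, we are free to use a very small step size:
# the regularizing effect now comes from the loss, not from the discretization.
theta = np.array([1.0, -1.0])
lr = 1e-3
for _ in range(10000):
    theta = theta - lr * grad_penalized(theta)
print(theta)
```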
For the conversion of difference equations to differential equations, we can use the standard "perturbation method," which has been briefly introduced in this blog before (you can check the tag "perturbation"). However, a more elegant solution is to use operator series expansion directly, referring to the previous article "The Art of Operators: Difference, Differentiation, and Bernoulli Numbers."
We expand $\boldsymbol{\theta}_{t+\gamma}$ using a Taylor series:
\begin{equation}\boldsymbol{\theta}_{t+\gamma}=\boldsymbol{\theta}_{t}+\gamma \dot{\boldsymbol{\theta}}_{t} + \frac{1}{2}\gamma^2\ddot{\boldsymbol{\theta}}_{t} + \frac{1}{6}\gamma^3\dddot{\boldsymbol{\theta}}_{t} + \cdots\end{equation}If we denote the derivative operator with respect to $t$ as $D$, then the equation is actually:
\begin{equation}\boldsymbol{\theta}_{t+\gamma} = \left(1+\gamma D + \frac{1}{2}\gamma^2 D^2 + \frac{1}{6}\gamma^3 D^3 + \cdots\right)\boldsymbol{\theta}_{t} = e^{\gamma D}\boldsymbol{\theta}_{t}\end{equation}So the difference equation $\eqref{eq:gd}$ can be written as:
\begin{equation}\left(e^{\gamma D} - 1\right)\boldsymbol{\theta}_{t} = - \gamma \boldsymbol{g}(\boldsymbol{\theta}_t)\end{equation}Manipulating the operator formally, as if it were an ordinary algebraic quantity, we have:
\begin{equation}\begin{aligned} D\boldsymbol{\theta}_{t} =& - \gamma \left(\frac{D}{e^{\gamma D} - 1}\right)\boldsymbol{g}(\boldsymbol{\theta}_t)\\ =& - \left(1 - \frac{1}{2}\gamma D + \frac{1}{12}\gamma^2 D^2 - \frac{1}{720}\gamma^4 D^4 + \cdots\right)\boldsymbol{g}(\boldsymbol{\theta}_t) \end{aligned}\end{equation}The left side is $\dot{\boldsymbol{\theta}}_{t}$, so the right side is the expression for $-\tilde{\boldsymbol{g}}(\boldsymbol{\theta}_t)$. Keeping terms up to the first order:
\begin{equation} - \left(1 - \frac{1}{2}\gamma D\right)\boldsymbol{g}(\boldsymbol{\theta}_t) = - \boldsymbol{g}(\boldsymbol{\theta}_t) + \frac{1}{2}\gamma \frac{d}{dt}\boldsymbol{g}(\boldsymbol{\theta}_t) = - \boldsymbol{g}(\boldsymbol{\theta}_t) + \frac{1}{2}\gamma \nabla_{\boldsymbol{\theta}}\boldsymbol{g}(\boldsymbol{\theta}_t)\dot{\boldsymbol{\theta}}_t \end{equation}That is:
\begin{equation}\begin{aligned} \dot{\boldsymbol{\theta}}_{t} =& - \boldsymbol{g}(\boldsymbol{\theta}_t) + \frac{1}{2}\gamma \nabla_{\boldsymbol{\theta}}\boldsymbol{g}(\boldsymbol{\theta}_t)\dot{\boldsymbol{\theta}}_t\\ =& - \boldsymbol{g}(\boldsymbol{\theta}_t) + \frac{1}{2}\gamma \nabla_{\boldsymbol{\theta}}\boldsymbol{g}(\boldsymbol{\theta}_t)\left[- \boldsymbol{g}(\boldsymbol{\theta}_t) + \frac{1}{2}\gamma \nabla_{\boldsymbol{\theta}}\boldsymbol{g}(\boldsymbol{\theta}_t)\dot{\boldsymbol{\theta}}_t\right]\\ =& - \boldsymbol{g}(\boldsymbol{\theta}_t) - \frac{1}{2}\gamma \nabla_{\boldsymbol{\theta}}\boldsymbol{g}(\boldsymbol{\theta}_t)\boldsymbol{g}(\boldsymbol{\theta}_t)\quad\text{(ignoring second-order terms)}\\ =& - \boldsymbol{g}(\boldsymbol{\theta}_t) - \frac{1}{4}\gamma \nabla_{\boldsymbol{\theta}}\Vert\boldsymbol{g}(\boldsymbol{\theta}_t)\Vert^2 \end{aligned}\end{equation}In the last step we used the fact that $\boldsymbol{g}=\nabla_{\boldsymbol{\theta}}L$, so $\nabla_{\boldsymbol{\theta}}\boldsymbol{g}$ is the (symmetric) Hessian of $L$ and hence $\nabla_{\boldsymbol{\theta}}\boldsymbol{g}(\boldsymbol{\theta}_t)\boldsymbol{g}(\boldsymbol{\theta}_t)=\frac{1}{2}\nabla_{\boldsymbol{\theta}}\Vert\boldsymbol{g}(\boldsymbol{\theta}_t)\Vert^2$. So, to first order, $\tilde{\boldsymbol{g}}(\boldsymbol{\theta}_t)=\boldsymbol{g}(\boldsymbol{\theta}_t) + \frac{1}{4}\gamma \nabla_{\boldsymbol{\theta}}\Vert\boldsymbol{g}(\boldsymbol{\theta}_t)\Vert^2$, and the derivation is complete.
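The two non-obvious ingredients of this derivation, namely the series expansion of $\frac{\gamma D}{e^{\gamma D}-1}$ and the identity $\nabla_{\boldsymbol{\theta}}\boldsymbol{g}\,\boldsymbol{g}=\frac{1}{2}\nabla_{\boldsymbol{\theta}}\Vert\boldsymbol{g}\Vert^2$ (which relies on $\boldsymbol{g}$ being a gradient), can both be checked symbolically. A small SymPy sketch, with an arbitrary two-parameter test loss of my own choosing:

```python
import sympy as sp

# 1) Series of x/(e^x - 1): should reproduce 1 - x/2 + x^2/12 - x^4/720 + ...
#    (the coefficients appearing in the operator expansion, with x = gamma*D).
x = sp.symbols('x')
print(sp.series(x / (sp.exp(x) - 1), x, 0, 6))

# 2) For any smooth L, the Hessian applied to the gradient equals half the
#    gradient of ||grad L||^2 (checked here on an arbitrary test loss).
t1, t2 = sp.symbols('theta1 theta2')
L = sp.exp(t1) * sp.sin(t2) + t1**2 * t2
g = sp.Matrix([sp.diff(L, t1), sp.diff(L, t2)])
H = g.jacobian([t1, t2])                      # Hessian of L, symmetric
lhs = H * g
rhs = sp.Matrix([sp.diff((g.T * g)[0], t1),
                 sp.diff((g.T * g)[0], t2)]) / 2
print(sp.simplify(lhs - rhs))                 # zero vector
```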
The development and popularization of deep learning are inseparable from the successful application of gradient-descent-based optimizers. However, exactly why gradient descent works so well has yet to be fully explained. In the course of "alchemy" (model tuning), researchers accumulate "tips and tricks" that work for reasons that are not well understood, such as how large the batch size should be or how to schedule the learning rate; everyone likely has their own experience.
The phenomenon that "the learning rate shouldn't be too small" is something most of us have experienced. In many cases, it is simply treated as "common sense" without much thought about the underlying principles. This paper from Google offers a possible explanation: an appropriate, rather than excessively small, learning rate brings an implicit gradient penalty term into the optimization process, helping it converge to flatter regions. I believe this style of analysis is worth learning from.