Gradient Flow: Exploring the Path to the Minimum

By 苏剑林 | June 16, 2023

In this article, we will explore a concept known as "Gradient Flow." Put simply, gradient flow connects the various points we encounter during the process of finding a minimum using gradient descent, forming a trajectory that changes over (virtual) time. This trajectory is what we call "gradient flow." In the second half of the article, we will focus on how to extend the concept of gradient flow to probability spaces, resulting in "Wasserstein Gradient Flow," which provides a new perspective for understanding the continuity equation, the Fokker-Planck equation, and other related topics.

Gradient Descent

Suppose we want to search for the minimum of a smooth function $f(\boldsymbol{x})$. A common approach is Gradient Descent, which iterates according to the following update rule: \begin{equation}\boldsymbol{x}_{t+1} = \boldsymbol{x}_t -\alpha \nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\label{eq:gd-d}\end{equation} If $f(\boldsymbol{x})$ is convex in $\boldsymbol{x}$, gradient descent can usually find the minimum point; otherwise, it typically only converges to a "stationary point", i.e., a point where the gradient is $\boldsymbol{0}$, and in the ideal case this is a local minimum. Here, we do not strictly distinguish between local and global minima because, in deep learning, even reaching a local minimum is a significant achievement.
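As a concrete illustration (a minimal sketch, not from the original article; it assumes NumPy, and the objective $f$, starting point, and learning rate are arbitrary choices), here is gradient descent applied to a simple convex quadratic:

```python
import numpy as np

def f(x):
    # an arbitrary convex quadratic, used only for illustration
    return 0.5 * np.sum(x ** 2)

def grad_f(x):
    # gradient of the quadratic above
    return x

alpha = 0.1                      # learning rate
x = np.array([3.0, -2.0])        # arbitrary starting point
for t in range(100):
    x = x - alpha * grad_f(x)    # the gradient descent update x_{t+1} = x_t - alpha * grad f(x_t)

print(x, f(x))                   # x is close to the minimizer 0, f(x) close to 0
```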

If we denote $\alpha$ as $\Delta t$ and $\boldsymbol{x}_{t+1}$ as $\boldsymbol{x}_{t+\Delta t}$, and consider the limit as $\Delta t \to 0$, then Equation $\eqref{eq:gd-d}$ becomes an ODE: \begin{equation}\frac{d\boldsymbol{x}_t}{dt} = -\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\label{eq:gd-c}\end{equation} The trajectory $\boldsymbol{x}_t$ obtained by solving this ODE is what we call "Gradient Flow." That is to say, gradient flow is the trajectory of gradient descent in the process of searching for the minimum. Under the premise that Equation $\eqref{eq:gd-c}$ holds, we also have: \begin{equation}\frac{df(\boldsymbol{x}_t)}{dt} = \left\langle\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\frac{d\boldsymbol{x}_t}{dt}\right\rangle = -\Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert^2 \leq 0\end{equation} This means that as long as $\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t) \neq \boldsymbol{0}$, gradient descent will always move in a direction that decreases $f(\boldsymbol{x})$, provided the learning rate is small enough.
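To see the continuous-time picture numerically, the following sketch (again an illustrative example with assumed ingredients: NumPy, and $f(\boldsymbol{x}) = \frac{1}{2}\Vert\boldsymbol{x}\Vert^2$, for which the gradient flow has the closed-form solution $\boldsymbol{x}_t = \boldsymbol{x}_0 e^{-t}$) checks that gradient descent with step size $\Delta t$ is just the Euler discretization of Equation $\eqref{eq:gd-c}$, with the discrepancy shrinking as $\Delta t \to 0$:

```python
import numpy as np

def grad_f(x):
    # f(x) = 0.5 * ||x||^2, so the gradient is simply x
    return x

x0 = np.array([3.0, -2.0])     # arbitrary starting point
T = 2.0                        # total amount of "virtual time"

for dt in [0.5, 0.1, 0.01]:    # shrinking step size (learning rate)
    x = x0.copy()
    for _ in range(int(T / dt)):
        x = x - dt * grad_f(x)             # one gradient descent step = one Euler step
    exact = x0 * np.exp(-T)                # closed-form solution of dx/dt = -x
    print(dt, np.max(np.abs(x - exact)))   # discrepancy shrinks as dt -> 0
```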

For more related discussions, you can refer to previous series on optimization algorithms, such as "Optimization Algorithms from a Dynamical Perspective (I): From SGD to Momentum Acceleration" and "Optimization Algorithms from a Dynamical Perspective (III): A More Holistic View."

Steepest Direction

Why use gradient descent? A mainstream explanation is that "the negative gradient direction is the direction of steepest local descent," and searching for this phrase turns up plenty of material. This statement is not wrong, but it is slightly imprecise, because it omits the preconditions: "steepest" inherently involves a quantitative comparison, and only after fixing the metric used for that comparison can we say which direction is actually the "steepest."

If we only care about the direction of steepest descent, the gradient descent update should be: \begin{equation}\boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x},\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert = \epsilon} f(\boldsymbol{x})\label{eq:gd-min-co}\end{equation} Assuming a first-order approximation is sufficient, we have: \begin{equation}\begin{aligned} f(\boldsymbol{x})&\,=f(\boldsymbol{x}_t) + \langle \nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\boldsymbol{x} - \boldsymbol{x}_t\rangle\\ &\,\geq f(\boldsymbol{x}_t) - \Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert \Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert\\ &\,= f(\boldsymbol{x}_t) - \Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert \epsilon\\ \end{aligned}\end{equation} where the inequality is the Cauchy-Schwarz inequality. Equality holds precisely when: \begin{equation}\boldsymbol{x} - \boldsymbol{x}_t = -\epsilon\frac{\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)}{\Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert}\quad\Rightarrow\quad\boldsymbol{x}_{t+1} = \boldsymbol{x}_t - \epsilon\frac{\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)}{\Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert}\label{eq:gd-d-norm} \end{equation} As we can see, the update direction is exactly the negative gradient direction, so it is indeed the direction of steepest local descent. However, do not forget that this is obtained under the constraint condition $\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert = \epsilon$, where $\Vert\cdot\Vert$ is the Euclidean norm. If we change the definition of the norm, or simply change the constraint condition, the result will be different. Therefore, more strictly speaking, it should be "In Euclidean space, the negative gradient direction is the direction of steepest local descent."
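A quick numerical check of the "steepest direction" claim (a hypothetical sketch: the test function, random seed, and radius $\epsilon$ are arbitrary choices, and NumPy is assumed): among all unit-norm directions, the negative normalized gradient minimizes the first-order change $\langle\nabla f, \boldsymbol{d}\rangle$, and hence, for small $\epsilon$, approximately minimizes $f$ on the sphere $\Vert\boldsymbol{x}-\boldsymbol{x}_t\Vert=\epsilon$:

```python
import numpy as np

rng = np.random.default_rng(0)

def f(x):
    # an arbitrary smooth test function
    return np.sum(x ** 4) + np.sum(x)

def grad_f(x):
    return 4 * x ** 3 + 1

x_t = rng.normal(size=3)
g = grad_f(x_t)
d_star = -g / np.linalg.norm(g)          # negative normalized gradient direction

# among unit-norm directions d, d_star minimizes the first-order change <grad f, d>
for _ in range(10000):
    d = rng.normal(size=3)
    d /= np.linalg.norm(d)
    assert g @ d_star <= g @ d + 1e-12

# hence, for a small radius eps, x_t + eps * d_star approximately minimizes f
# on the sphere ||x - x_t|| = eps, cf. the normalized update above
eps = 1e-3
print(f(x_t), f(x_t + eps * d_star))     # the second value is smaller
```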

Optimization Perspective

Equation $\eqref{eq:gd-min-co}$ is a constrained optimization problem, which is troublesome to generalize and solve. Furthermore, the result of solving Equation $\eqref{eq:gd-min-co}$ is Equation $\eqref{eq:gd-d-norm}$, which is not the original gradient descent $\eqref{eq:gd-d}$. In fact, it can be proven that the optimization objective corresponding to Equation $\eqref{eq:gd-d}$ is: \begin{equation}\boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert^2}{2\alpha} + f(\boldsymbol{x})\label{eq:gd-min}\end{equation} In other words, by including the constraint as a penalty term in the optimization objective, we avoid having to solve a constrained problem, and it becomes easier to generalize. Moreover, even with the addition of the extra $\frac{\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert^2}{2\alpha}$, it is guaranteed that the optimization will not move in a worse direction, because substituting $\boldsymbol{x} = \boldsymbol{x}_t$ clearly makes the objective function equal to $f(\boldsymbol{x}_t)$, so the result of $\min_{\boldsymbol{x}}$ will at least not be greater than $f(\boldsymbol{x}_t)$.

When $\alpha$ is small enough, the first term dominates, so the minimizer must keep $\Vert \boldsymbol{x} - \boldsymbol{x}_t \Vert$ small; that is, the optimal point lies very close to $\boldsymbol{x}_t$. We can therefore expand $f(\boldsymbol{x})$ to first order around $\boldsymbol{x}_t$ to get: \begin{equation}\boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert^2}{2\alpha} + f(\boldsymbol{x}_t)+\langle\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\boldsymbol{x}-\boldsymbol{x}_t\rangle\end{equation} This is now just a quadratic minimization problem, whose solution is exactly Equation $\eqref{eq:gd-d}$.
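The following sketch verifies this numerically (an illustrative example, assuming NumPy and SciPy are available; the test function and step size are arbitrary): for small $\alpha$, the minimizer of the penalized objective $\eqref{eq:gd-min}$ essentially coincides with the plain gradient descent step $\eqref{eq:gd-d}$.

```python
import numpy as np
from scipy.optimize import minimize

def f(x):
    # an arbitrary smooth convex test function
    return np.sum(np.cosh(x))

def grad_f(x):
    return np.sinh(x)

x_t = np.array([1.0, -0.5])
alpha = 1e-3

# the penalized objective: ||x - x_t||^2 / (2 alpha) + f(x)
obj = lambda x: np.sum((x - x_t) ** 2) / (2 * alpha) + f(x)
x_next = minimize(obj, x_t).x            # numerical minimizer, starting from x_t

print(x_next)                            # for small alpha, nearly identical to...
print(x_t - alpha * grad_f(x_t))         # ...the plain gradient descent step
```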

Obviously, besides the squared norm, we can consider other regularization terms, thus forming different gradient descent schemes. For example, Natural Gradient Descent uses the KL divergence as the regularization term: \begin{equation}\boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{KL(p(\boldsymbol{y}\|\boldsymbol{x})\Vert p(\boldsymbol{y}\|\boldsymbol{x}_t))}{\alpha} + f(\boldsymbol{x})\end{equation} where $p(\boldsymbol{y}\|\boldsymbol{x})$ is some probability distribution related to $f(\boldsymbol{x})$. To solve the above, we similarly expand around $\boldsymbol{x}_t$. While $f(\boldsymbol{x})$ is expanded to first order, the KL divergence is special: its first-order expansion vanishes, so it must be expanded to at least second order. Altogether, the result is: \begin{equation}\boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{(\boldsymbol{x}-\boldsymbol{x}_t)^{\top}\boldsymbol{F}(\boldsymbol{x}-\boldsymbol{x}_t)}{2\alpha} + f(\boldsymbol{x}_t)+\langle\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\boldsymbol{x}-\boldsymbol{x}_t\rangle\end{equation} Here, $\boldsymbol{F}$ is the Fisher Information Matrix; we will not expand on the calculation details. The above is again essentially a quadratic minimization problem, and the result is: \begin{equation}\boldsymbol{x}_{t+1} = \boldsymbol{x}_t -\alpha \boldsymbol{F}^{-1}\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\end{equation} This is what is called "Natural Gradient Descent."
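As a toy illustration of natural gradient descent (a hypothetical example, not from the original article: the model is taken to be a Gaussian $p(y\|x) = \mathcal{N}(y; x, \sigma^2)$ with known $\sigma$, $f$ is the average negative log-likelihood of some synthetic data, and the Fisher information is then the scalar $1/\sigma^2$):

```python
import numpy as np

rng = np.random.default_rng(0)

# model: p(y|x) = N(y; x, sigma^2) with known sigma; the optimization variable
# is the mean x, and f(x) is the average negative log-likelihood of the data
sigma = 2.0
data = rng.normal(loc=1.5, scale=sigma, size=1000)

def grad_f(x):
    # d/dx of the average negative log-likelihood: (x - mean(data)) / sigma^2
    return (x - data.mean()) / sigma ** 2

fisher = 1.0 / sigma ** 2    # Fisher information of N(x, sigma^2) w.r.t. x

alpha, x = 0.5, 0.0
for _ in range(20):
    x = x - alpha * grad_f(x) / fisher   # natural gradient step: x - alpha * F^{-1} grad f
print(x, data.mean())                    # x converges to the maximum-likelihood estimate
```

Note that after preconditioning by $\boldsymbol{F}^{-1}$ the update no longer depends on $\sigma$; this insensitivity to how the model is parameterized is precisely what motivates natural gradient descent.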

Introduction to Functionals

Equation $\eqref{eq:gd-min}$ can generalize not only the regularization term but also the target of optimization, for example, extending it to functionals.

"Functionals" (泛函) might sound intimidating, but in fact, regular readers of this site should have encountered them many times. Simply put, a multivariate function takes a vector as input and outputs a scalar, whereas a functional takes a function as input and outputs a scalar, like the definite integral operation: \begin{equation}\mathcal{I}[f] = \int_a^b f(x)dx\end{equation} For any function $f$, the result of calculating $\mathcal{I}[f]$ is a scalar, so $\mathcal{I}[f]$ is a functional. Another example is the KL divergence mentioned earlier, which is defined as: \begin{equation}KL(p\Vert q) = \int p(\boldsymbol{x})\log \frac{p(\boldsymbol{x})}{q(\boldsymbol{x})}d\boldsymbol{x}\end{equation} Here the integral is implicitly over the entire space. If $p(\boldsymbol{x})$ is fixed, then it is a functional of $q(\boldsymbol{x})$ because $q(\boldsymbol{x})$ is a function; inputting a function that satisfies the conditions, $KL(p\Vert q)$ will output a scalar. More generally, the $f$-divergence introduced in "Introduction to f-GAN: The Production Workshop of GAN Models" is also a type of functional. These are relatively simple functionals; more complex functionals might include derivatives of the input function, such as the principle of least action in theoretical physics.

Below, we will focus primarily on functionals whose domain is the set of all probability density functions—that is, studying functionals that take a probability density as input and output a scalar.

Flow of Probability

Suppose we have a functional $\mathcal{F}[q]$ and we want to calculate its minimum. Following the logic of gradient descent, as long as we can find some kind of gradient for it, we can iterate in its negative direction.

To determine the iteration scheme, we follow our previous thoughts and consider a generalization of Equation $\eqref{eq:gd-min}$, where $f(\boldsymbol{x})$ is naturally replaced by $\mathcal{F}[q]$. What should the first regularization term be replaced with? In Equation $\eqref{eq:gd-min}$, it is the square of the Euclidean distance, so it is natural to think it should be replaced by the square of some distance here. For probability distributions, a distance with good properties is the Wasserstein distance (specifically, the "2-Wasserstein distance"): \begin{equation}\mathcal{W}_2[p,q]=\sqrt{\inf_{\gamma\in \Pi[p,q]} \iint \gamma(\boldsymbol{x},\boldsymbol{y}) \Vert\boldsymbol{x}-\boldsymbol{y}\Vert^2 d\boldsymbol{x}d\boldsymbol{y}}\end{equation} We will not introduce it in detail here; interested readers can refer to "From Wasserstein Distance and Duality Theory to WGAN." If we further replace the Euclidean distance in Equation $\eqref{eq:gd-min}$ with the Wasserstein distance, the final objective is: \begin{equation}q_{t+1} = \mathop{\text{argmin}}_{q} \frac{\mathcal{W}_2^2[q,q_t]}{2\alpha} + \mathcal{F}[q]\end{equation} Regrettably, I cannot concisely present the derivation for this objective, as I do not fully understand it myself. I can only state the result directly, following literature such as "Introduction to Gradient Flows in the 2-Wasserstein Space" and "{Euclidean, Metric, and Wasserstein} Gradient Flows: an overview": \begin{equation}q_{t+1}(\boldsymbol{x}) = q_t(\boldsymbol{x}) + \alpha \nabla_{\boldsymbol{x}}\cdot\left(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\delta \mathcal{F}[q_t(\boldsymbol{x})]}{\delta q_t(\boldsymbol{x})}\right)\end{equation} Or after taking the limit: \begin{equation}\frac{\partial q_t(\boldsymbol{x})}{\partial t} = \nabla_{\boldsymbol{x}}\cdot\left(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\delta \mathcal{F}[q_t(\boldsymbol{x})]}{\delta q_t(\boldsymbol{x})}\right)\end{equation} This is the "Wasserstein Gradient Flow," where $\frac{\delta \mathcal{F}[q]}{\delta q}$ is the variational derivative of $\mathcal{F}[q]$. For functionals given by a definite integral whose integrand involves no derivatives of $q$, the variational derivative is simply the partial derivative of the integrand: \begin{equation}\mathcal{F}[q] = \int F(q(\boldsymbol{x}))d\boldsymbol{x} \quad\Rightarrow\quad \frac{\delta \mathcal{F}[q(\boldsymbol{x})]}{\delta q(\boldsymbol{x})} = \frac{\partial F(q(\boldsymbol{x}))}{\partial q(\boldsymbol{x})}\end{equation}
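As a quick sanity check of these formulas (a standard worked example, not taken from the original text), consider the negative-entropy functional: \begin{equation}\mathcal{F}[q] = \int q(\boldsymbol{x})\log q(\boldsymbol{x})\, d\boldsymbol{x}\quad\Rightarrow\quad \frac{\delta \mathcal{F}[q(\boldsymbol{x})]}{\delta q(\boldsymbol{x})} = \log q(\boldsymbol{x}) + 1\end{equation} Then $q_t\nabla_{\boldsymbol{x}}\frac{\delta \mathcal{F}[q_t]}{\delta q_t} = q_t \cdot \frac{\nabla_{\boldsymbol{x}} q_t}{q_t} = \nabla_{\boldsymbol{x}} q_t$, and the Wasserstein gradient flow reduces to the heat equation: \begin{equation}\frac{\partial q_t(\boldsymbol{x})}{\partial t} = \nabla_{\boldsymbol{x}}\cdot\nabla_{\boldsymbol{x}} q_t(\boldsymbol{x})\end{equation} recovering the well-known fact that the heat equation is the gradient flow of (negative) entropy in Wasserstein space.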

Some Examples

According to "Introduction to f-GAN: The Production Workshop of GAN Models," the $f$-divergence is defined as: \begin{equation}\mathcal{D}_f(p\Vert q) = \int q(\boldsymbol{x}) f\left(\frac{p(\boldsymbol{x})}{q(\boldsymbol{x})}\right)d\boldsymbol{x}\end{equation} Fixing $p$ and setting $\mathcal{F}[q]=\mathcal{D}_f(p\Vert q)$, we obtain: \begin{equation}\frac{\partial q_t(\boldsymbol{x})}{\partial t} = \nabla_{\boldsymbol{x}}\cdot\Big(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\big(f(r_t(\boldsymbol{x})) - r_t(\boldsymbol{x}) f'(r_t(\boldsymbol{x}))\big)\Big)\label{eq:wgd}\end{equation} where $r_t(\boldsymbol{x}) = \frac{p(\boldsymbol{x})}{q_t(\boldsymbol{x})}$. Based on the content of "Deriving the Continuity Equation and Fokker-Planck Equation Using the Test Function Method," the above equation takes the form of a continuity equation. Therefore, through the ODE: \begin{equation}\frac{d\boldsymbol{x}}{dt} = - \nabla_{\boldsymbol{x}}\big(f(r_t(\boldsymbol{x})) - r_t(\boldsymbol{x}) f'(r_t(\boldsymbol{x}))\big)\end{equation} one can sample from the distribution $q_t$. According to our previous discussion, Equation $\eqref{eq:wgd}$ is the Wasserstein gradient flow that minimizes the $f$-divergence between $p$ and $q$. When $t\to\infty$, the $f$-divergence is zero, meaning $q_t=p$. Thus, as $t\to\infty$, the ODE above achieves sampling from distribution $p$. However, this result currently only has formal significance and no practical use, because it implies that we need to know the expression for distribution $p$ and solve Equation $\eqref{eq:wgd}$ for the expression of $q_t$, and then calculate the right side of the ODE to complete the sampling. This calculation is extremely difficult and usually cannot be completed.

A relatively simple example is the (reverse) KL divergence, where $f=-\log$. Substituting into Equation $\eqref{eq:wgd}$, we get: \begin{equation}\begin{aligned}\frac{\partial q_t(\boldsymbol{x})}{\partial t} =&\, - \nabla_{\boldsymbol{x}}\cdot\left(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log \frac{p(\boldsymbol{x})}{q_t(\boldsymbol{x})}\right)\\ =&\, - \nabla_{\boldsymbol{x}}\cdot\Big(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\big(\log p(\boldsymbol{x}) - \log q_t(\boldsymbol{x})\big)\Big)\\ =&\, - \nabla_{\boldsymbol{x}}\cdot\big(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x})\big) + \nabla_{\boldsymbol{x}}\cdot\nabla_{\boldsymbol{x}} q_t(\boldsymbol{x}) \end{aligned}\end{equation} Comparing again with the results from "Deriving the Continuity Equation and Fokker-Planck Equation Using the Test Function Method," this is exactly a Fokker-Planck equation, which corresponds to the SDE: \begin{equation}d\boldsymbol{x} = \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) dt + \sqrt{2}dw\end{equation} In other words, if we know $\log p(\boldsymbol{x})$, we can use the equation above to sample from $p(\boldsymbol{x})$. Compared to the previous ODE, this avoids the process of solving for $q_t(\boldsymbol{x})$, making it a relatively usable scheme.
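As an illustrative sketch of this sampling scheme (not code from the original article; it assumes NumPy, takes the standard Gaussian as the target so that $\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) = -\boldsymbol{x}$, and uses arbitrary step size and particle count), here is an Euler-Maruyama simulation of the SDE above:

```python
import numpy as np

rng = np.random.default_rng(0)

def grad_log_p(x):
    # target p = standard Gaussian, so grad log p(x) = -x (chosen for illustration)
    return -x

dt = 1e-2                                      # time step
x = 3.0 * rng.normal(size=(5000, 2))           # arbitrary initial particles
for _ in range(2000):
    noise = rng.normal(size=x.shape)
    # Euler-Maruyama step for dx = grad log p(x) dt + sqrt(2) dw
    x = x + grad_log_p(x) * dt + np.sqrt(2 * dt) * noise

print(x.mean(axis=0), x.std(axis=0))           # approximately [0, 0] and [1, 1]
```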

Summary

This article introduced the concept of "gradient flow" in the process of seeking the minimum via gradient descent, including the extension from gradient flow in vector spaces to Wasserstein gradient flow in probability spaces, and their connections with the continuity equation, Fokker-Planck equation, and ODE/SDE sampling.