Generative Diffusion Model Talk (21): Accelerated ODE Sampling via the Mean Value Theorem

By 苏剑林 | December 07, 2023

In the history of generative diffusion models, DDIM and Yang Song's concurrent work on diffusion SDEs are both considered milestones, because they established a close connection between diffusion models and the mathematical machinery of Stochastic Differential Equations (SDEs) and Ordinary Differential Equations (ODEs). This allows us to bring existing mathematical tools from SDEs and ODEs to bear on analyzing, solving, and extending diffusion models. For instance, a vast amount of subsequent work on accelerated sampling builds on this foundation, effectively opening a completely new perspective on generative diffusion models.

In this article, we focus on ODEs. In our previous blogs—Part (6), (12), (14), (15), and (17)—we already derived the relationship between ODEs and diffusion models. This article provides a brief introduction to sampling acceleration for diffusion ODEs and highlights a clever new acceleration scheme called "AMED," which utilizes the spirit of the "Mean Value Theorem."

Euler's Method

As mentioned, since we have already derived the connection between diffusion models and ODEs in several articles, we will not repeat the derivation here. Instead, we directly define the sampling of a diffusion ODE as solving the following ODE:

\begin{equation}\frac{d\boldsymbol{x}_t}{dt} = \boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\label{eq:dm-ode}\end{equation}

where $t \in [0, T]$, the initial condition is $\boldsymbol{x}_T$, and the result to be returned is $\boldsymbol{x}_0$. In principle, we do not care about the intermediate values $\boldsymbol{x}_t$ for $t \in (0, T)$; we only need the final $\boldsymbol{x}_0$. For numerical solving, we need to select nodes $0 = t_0 < t_1 < t_2 < \cdots < t_N = T$. A common choice is:

\begin{equation}t_n=\left(t_1^{1 / \rho}+\frac{n-1}{N-1}\left(t_N^{1 / \rho}-t_1^{1 / \rho}\right)\right)^\rho\end{equation}

where $\rho > 0$. This form comes from "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM). AMED also adopts this scheme. Personally, I believe the choice of nodes is not an essential element; therefore, this article will not delve deeply into it.
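
For concreteness, here is a minimal sketch of this node schedule in Python (NumPy only). The default values $t_1=0.002$, $t_N=80$, $\rho=7$ follow EDM's recommendations and are assumptions of this sketch, not part of the formula itself:

```python
import numpy as np

def edm_time_steps(N, t1=0.002, tN=80.0, rho=7.0):
    """EDM-style nodes t_1 < ... < t_N, with t_0 = 0 appended at the front.

    The defaults (t1=0.002, tN=80, rho=7) follow the EDM paper's recommendations.
    """
    n = np.arange(1, N + 1)
    t = (t1**(1 / rho) + (n - 1) / (N - 1) * (tN**(1 / rho) - t1**(1 / rho)))**rho
    return np.concatenate([[0.0], t])  # [t_0, t_1, ..., t_N]
```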

The simplest solver is "Euler's method", which uses the finite-difference approximation:

\begin{equation}\left.\frac{d\boldsymbol{x}_t}{dt}\right|_{t=t_{n+1}}\approx \frac{\boldsymbol{x}_{t_{n+1}} - \boldsymbol{x}_{t_n}}{t_{n+1} - t_n}\end{equation}

We can obtain:

\begin{equation}\boldsymbol{x}_{t_n}\approx \boldsymbol{x}_{t_{n+1}} - \boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t_{n+1}}, t_{n+1})(t_{n+1} - t_n)\end{equation}

This is often directly referred to as the DDIM method, as DDIM was the first to notice that its sampling process corresponds to the Euler method of an ODE, subsequently deriving the corresponding ODE.
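
As a reference, here is a minimal sketch of this iteration in Python, where `v_theta(x, t)` stands for the trained network (a hypothetical callable) and `ts` is the node array from above:

```python
def euler_sample(v_theta, x_T, ts):
    """Euler (DDIM) sampling. ts = [t_0=0, t_1, ..., t_N=T]; iterate from t_N down to t_0."""
    x = x_T
    for n in reversed(range(len(ts) - 1)):             # n = N-1, ..., 0
        t_next, t_cur = ts[n + 1], ts[n]               # step from t_{n+1} down to t_n
        x = x - v_theta(x, t_next) * (t_next - t_cur)  # 1 NFE per step
    return x  # approximately x_0
```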

Higher-Order Methods

From the perspective of numerical solving, Euler's method is a first-order approximation. It is simple and fast, but its downside is poor precision unless the step size is very small. This means that relying solely on Euler's method is unlikely to significantly reduce the number of sampling steps while maintaining sampling quality. Therefore, subsequent sampling acceleration works have applied higher-order methods.

For example, intuitively, the difference quotient $\frac{\boldsymbol{x}_{t_{n+1}} - \boldsymbol{x}_{t_n}}{t_{n+1} - t_n}$ should be closer to the derivative at the midpoint of the interval than to the derivative at either endpoint. Therefore, replacing the single endpoint derivative on the right-hand side with the average of the derivatives at $t_n$ and $t_{n+1}$ should yield higher precision:

\begin{equation}\frac{\boldsymbol{x}_{t_{n+1}} - \boldsymbol{x}_{t_n}}{t_{n+1} - t_n}\approx \frac{1}{2}\left[\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t_n}, t_n) + \boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t_{n+1}}, t_{n+1})\right]\label{eq:heun-0}\end{equation}

From this, we get:

\begin{equation}\boldsymbol{x}_{t_n}\approx \boldsymbol{x}_{t_{n+1}} - \frac{1}{2}\left[\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t_n}, t_n) + \boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t_{n+1}}, t_{n+1})\right](t_{n+1} - t_n) \end{equation}

However, $\boldsymbol{x}_{t_n}$ appears on the right-hand side, and it is exactly the quantity we are trying to compute, so this equation cannot be used directly for iteration. For this reason, we first use Euler's method to "predict" $\boldsymbol{x}_{t_n}$ and then substitute the prediction into the equation above:

\begin{equation}\begin{aligned} \tilde{\boldsymbol{x}}_{t_n}=&\, \boldsymbol{x}_{t_{n+1}} - \boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t_{n+1}}, t_{n+1})(t_{n+1} - t_n) \\ \boldsymbol{x}_{t_n}\approx&\, \boldsymbol{x}_{t_{n+1}} - \frac{1}{2}\left[\boldsymbol{v}_{\boldsymbol{\theta}}(\tilde{\boldsymbol{x}}_{t_n}, t_n) + \boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t_{n+1}}, t_{n+1})\right](t_{n+1} - t_n) \end{aligned}\label{eq:heun}\end{equation}

This is "Heun's method", as used by EDM; it is a second-order method. Each iteration step now requires two evaluations of $\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)$, but the precision is significantly improved, so the number of iteration steps can be reduced and the total computational cost lowered.
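
In code, one Heun step of Eq. $\eqref{eq:heun}$ looks as follows (a sketch under the same hypothetical `v_theta` interface as above); note the two network evaluations per step:

```python
def heun_step(v_theta, x, t_next, t_cur):
    """One step of Heun's method from t_{n+1} (t_next) down to t_n (t_cur): 2 NFE."""
    h = t_next - t_cur
    v1 = v_theta(x, t_next)          # first evaluation, at the current point
    x_pred = x - v1 * h              # Euler "prediction" of x_{t_n}
    v2 = v_theta(x_pred, t_cur)      # second evaluation, at the predicted point
    return x - 0.5 * (v1 + v2) * h   # corrected update using the averaged direction
```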

There are many variants of second-order methods. For example, we can directly replace the right-hand side of Eq. $\eqref{eq:heun-0}$ with the function value at the midpoint $t=(t_n+t_{n+1})/2$ (where $\boldsymbol{x}_{(t_n+t_{n+1})/2}$ is likewise predicted by Euler's method), yielding:

\begin{equation}\boldsymbol{x}_{t_n}\approx \boldsymbol{x}_{t_{n+1}} - \boldsymbol{v}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{(t_n+t_{n+1})/2}, \frac{t_n+t_{n+1}}{2}\right)(t_{n+1} - t_n) \end{equation}

There are also different ways to find the midpoint. In addition to the arithmetic mean $(t_n+t_{n+1})/2$, one can also consider the geometric mean:

\begin{equation}\boldsymbol{x}_{t_n}\approx \boldsymbol{x}_{t_{n+1}} - \boldsymbol{v}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{\sqrt{t_n t_{n+1}}}, \sqrt{t_n t_{n+1}}\right)(t_{n+1} - t_n) \label{eq:dpm-solver-2}\end{equation}

In fact, Eq. $\eqref{eq:dpm-solver-2}$ is a special case of DPM-Solver-2.
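
A sketch of the corresponding midpoint step, again with the midpoint state predicted by Euler's method; `geometric=True` selects the geometric-mean variant of Eq. $\eqref{eq:dpm-solver-2}$ (same hypothetical `v_theta` as above):

```python
def midpoint_step(v_theta, x, t_next, t_cur, geometric=True):
    """One midpoint step from t_{n+1} down to t_n: 2 NFE.

    geometric=True uses s = sqrt(t_n * t_{n+1}) (the DPM-Solver-2-like variant);
    otherwise the arithmetic mean s = (t_n + t_{n+1}) / 2 is used.
    """
    s = (t_cur * t_next) ** 0.5 if geometric else 0.5 * (t_cur + t_next)
    v1 = v_theta(x, t_next)
    x_s = x - v1 * (t_next - s)          # Euler prediction of the midpoint state x_s
    v_mid = v_theta(x_s, s)
    return x - v_mid * (t_next - t_cur)  # update using the midpoint direction
```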

Beyond second-order methods, there are many higher-order methods for solving ODEs, such as the "Runge-Kutta methods" and the "linear multistep methods". However, whether second-order or higher-order, although these methods can accelerate diffusion ODE sampling to some extent, they are "general-purpose" methods: since they are not customized for the background and structure of diffusion models, it is difficult for them to push the number of sampling steps down to the extreme (single digits).

Mean Value Theorem

Now, the protagonist of this article, AMED, makes its appearance. Its paper, "Fast ODE-based Sampling for Diffusion Models in Around 5 Steps", was posted on arXiv just two days ago, fresh out of the oven. Instead of simply increasing theoretical precision like traditional ODE solvers, AMED cleverly draws an analogy to the "Mean Value Theorem" and adds a very small distillation cost to obtain a customized high-speed solver for diffusion ODEs.

Schematic of several diffusion ODE-Solvers

First, we integrate both sides of Eq. $\eqref{eq:dm-ode}$, allowing us to write an exact equality:

\begin{equation} \boldsymbol{x}_{t_{n+1}} - \boldsymbol{x}_{t_n} = \int_{t_n}^{t_{n+1}}\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)dt\end{equation}

If $\boldsymbol{v}_{\boldsymbol{\theta}}$ were a scalar (one-dimensional) function, then by the "Mean Value Theorem for Integrals" there would exist a point $s_n \in (t_n, t_{n+1})$ such that:

\begin{equation}\frac{1}{t_{n+1} - t_n}\int_{t_n}^{t_{n+1}}\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)dt = \boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{s_n}, s_n) \end{equation}

Unfortunately, the mean value theorem does not hold in general for vector-valued functions. However, provided that $t_{n+1}-t_n$ is not too large, and under certain assumptions (analyzed later in this article), we can still write the analogous approximation:

\begin{equation}\frac{1}{t_{n+1} - t_n}\int_{t_n}^{t_{n+1}}\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)dt \approx \boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{s_n}, s_n) \end{equation}

Thus we obtain:

\begin{equation} \boldsymbol{x}_{t_n}\approx \boldsymbol{x}_{t_{n+1}} - \boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{s_n}, s_n)(t_{n+1}-t_n)\end{equation}

Of course, this is currently just a formal solution; how to obtain $s_n$ and $\boldsymbol{x}_{s_n}$ remains unresolved. For $\boldsymbol{x}_{s_n}$, we still use Euler's method to predict it, i.e., $\tilde{\boldsymbol{x}}_{s_n}= \boldsymbol{x}_{t_{n+1}} - \boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t_{n+1}}, t_{n+1})(t_{n+1} - s_n)$. For $s_n$, we use a small neural network to estimate it:

\begin{equation}s_n = g_{\boldsymbol{\phi}}(\boldsymbol{h}_{t_{n+1}}, t_{n+1})\end{equation}

where $\boldsymbol{\phi}$ are trainable parameters, and $\boldsymbol{h}_{t_{n+1}}$ is an intermediate feature of the U-Net model $\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t_{n+1}}, t_{n+1})$. Finally, to solve for the parameters $\boldsymbol{\phi}$, we adopt a distillation approach: more accurate trajectory point pairs $(\boldsymbol{x}_{t_n}, \boldsymbol{x}_{t_{n+1}})$ are pre-computed with a solver using more steps, and the estimation error is then minimized. This is the AMED-Solver (Approximate MEan-Direction Solver) of the paper. It has the form of a conventional ODE solver but requires an extra distillation cost; however, this cost is almost negligible compared to other distillation-based acceleration methods, which is why I think of it as a "customized" solver.
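
To make the iteration concrete, here is a minimal sketch of one AMED-Solver step. The interface is an assumption of this sketch: `v_theta(x, t)` is taken to return both the direction and the U-Net bottleneck feature `h` that $g_{\boldsymbol{\phi}}$ consumes, and `g_phi(h, t)` returns the intermediate time $s_n$ (the paper's $g_{\boldsymbol{\phi}}$ involves further details omitted here):

```python
def amed_step(v_theta, g_phi, x, t_next, t_cur):
    """One AMED-Solver step from t_{n+1} down to t_n: 2 NFE (g_phi is negligible).

    Assumed (hypothetical) interfaces: v_theta(x, t) -> (direction, unet_feature h);
    g_phi(h, t) -> intermediate time s_n in (t_n, t_{n+1}).
    """
    v1, h = v_theta(x, t_next)            # first evaluation, also yields the feature h
    s = g_phi(h, t_next)                  # predicted "mean value" time s_n
    x_s = x - v1 * (t_next - s)           # Euler prediction of x_{s_n}
    v_mid, _ = v_theta(x_s, s)            # second evaluation at (x_{s_n}, s_n)
    return x - v_mid * (t_next - t_cur)   # mean-direction update
```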

The word "customized" is crucial. Research on sampling acceleration for diffusion ODEs has existed for a long time. With the combined contributions of many researchers, non-training solvers have likely gone very far, but they still fail to bring the number of sampling steps to the extreme. Unless we have further breakthroughs in the theoretical understanding of diffusion models in the future, I do not believe non-training solvers have significant room for improvement. Therefore, AMED's approach of carrying a small training cost for acceleration is both "unconventional" and "natural as things evolve."

Experimental Results

Before looking at the experimental results, let us first introduce a concept called "NFE", short for "Number of Function Evaluations". Put simply, it is the number of times the model $\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)$ is executed, which is directly tied to the computational cost. For example, a first-order method costs 1 NFE per iteration step, because it evaluates $\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)$ only once; a second-order method costs 2 NFE per step. The computational cost of AMED-Solver's $g_{\boldsymbol{\phi}}$ is negligibly small, so each AMED-Solver step is also counted as 2 NFE. For a fair comparison, the total NFE of the whole sampling process must be held fixed when comparing different solvers: for example, a budget of 10 NFE allows 10 Euler steps but only 5 Heun steps.

The basic experimental results are shown in Table 2 of the original paper:

Experimental results of AMED (Table 2)

There are several points worth noting in this table. First, when the NFE budget does not exceed 5, the second-order DPM-Solver and EDM are actually worse than the first-order DDIM. This is because a solver's error depends not only on the order but also on the step size $t_{n+1}-t_n$, scaling roughly as $\mathcal{O}((t_{n+1}-t_n)^m)$, where $m$ is the "order"; when the total NFE is small, higher-order methods can only afford larger step sizes, so their actual precision, and hence their results, can be worse. Second, the AMED-Solver (also a second-order method) achieves SOTA results across the board at low NFE, which fully demonstrates the importance of "customization". Third, "AMED-Plugin" refers to the paper's proposal to use the AMED idea as a "plugin" for other ODE solvers, a slightly more complex approach that yields even better results.

Some readers might wonder: since each iteration of a second-order method costs 2 NFE, how can odd NFE values appear in the table? This is because the authors used a technique called "AFS (Analytical First Step)" to save one NFE. The technique comes from "GENIE: Higher-Order Denoising Diffusion Solvers". Specifically, in the context of diffusion models, it is observed that $\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t_N}, t_N)$ is very close to $\boldsymbol{x}_{t_N}/t_N$ (different diffusion models may behave differently, but the core idea is that the first step can be solved analytically). Thus, in the first step of sampling, $\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t_N}, t_N)$ is replaced directly by this analytical approximation, saving one NFE. Tables 8, 9, and 10 in the paper's appendix evaluate the impact of AFS in more detail; interested readers can refer to them.
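
As a sketch of how AFS saves one NFE, assuming a VE/EDM-style parameterization where $\boldsymbol{v}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t) \approx \boldsymbol{x}_t/t$ at large $t$ (other parameterizations would need a different analytical first direction); `v_theta` is the same hypothetical callable as before:

```python
def euler_sample_afs(v_theta, x_T, ts):
    """Euler sampling with an Analytical First Step (AFS), saving one NFE.

    Assumes v_theta(x, t) ~ x / t at the largest t (VE/EDM-style), so the first
    network call is replaced by this analytical direction.
    """
    x = x_T
    for n in reversed(range(len(ts) - 1)):
        t_next, t_cur = ts[n + 1], ts[n]
        v = x / t_next if n == len(ts) - 2 else v_theta(x, t_next)  # AFS on step 1
        x = x - v * (t_next - t_cur)
    return x
```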

Finally, since AMED uses distillation to train $g_{\boldsymbol{\phi}}$, some readers may wonder how it compares with other distillation-based acceleration schemes. Unfortunately, the paper does not provide such a comparison, so I consulted the author by email. The author stated that AMED's distillation cost is extremely low: CIFAR-10 requires less than 20 minutes of training on a single A100, and 256×256 images require only a few hours on four A100s, whereas other distillation-based acceleration approaches require days or even tens of days. Therefore, the author regards AMED as solver work rather than distillation work, but also said they would try to add a comparison with distillation methods if given the opportunity later.

Assumption Analysis

Previously, when discussing the extension of the mean value theorem to vector-valued functions, we said "under certain assumptions". So what exactly are these assumptions, and do they actually hold?

It is not difficult to find counterexamples showing that even for two-dimensional functions the integral mean value theorem does not hold universally; in general, it is only guaranteed for one-dimensional functions. Roughly speaking, for the integral mean value theorem to hold for a high-dimensional function, the spatial trajectory described by that function must be a straight line; that is, all the points $\boldsymbol{x}_{t_0}, \boldsymbol{x}_{t_1}, \cdots, \boldsymbol{x}_{t_N}$ of the sampling process would have to lie on a straight line. This assumption is naturally very strong and virtually never holds exactly in practice. However, it also suggests that for the integral mean value theorem to hold as closely as possible in high-dimensional space, the sampling trajectory should stay within a subspace of as low a dimension as possible.

To verify this, the paper's authors increased the number of sampling steps to obtain a relatively accurate sampling trajectory, and then performed Principal Component Analysis (PCA) on the trajectory. The results are shown below:

Principal Component Analysis of diffusion ODE sampling trajectories

The PCA results show that keeping only the top principal component already preserves most of the trajectory's precision, and keeping the first two components makes the residual error almost negligible. This tells us that the sampling trajectories are almost entirely concentrated on a two-dimensional sub-plane, and are even very close to a straight line within that sub-plane. Consequently, when $t_{n+1}-t_n$ is not particularly large, the integral mean value theorem holds approximately for diffusion ODE trajectories in high-dimensional space.
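
As an illustration of this kind of check, here is a minimal PCA sketch (NumPy only); this is my own toy reconstruction, not the paper's code. The trajectory is assumed to be stored as an `(N+1, D)` array with one flattened $\boldsymbol{x}_t$ per node:

```python
import numpy as np

def trajectory_pca_explained(traj):
    """PCA on a sampling trajectory: traj has shape (N+1, D), one flattened x_t per row.

    Returns the fraction of variance explained by each principal component; per the
    paper's observation, the top one or two components should capture almost everything.
    """
    X = traj - traj.mean(axis=0, keepdims=True)   # center the trajectory points
    _, s, _ = np.linalg.svd(X, full_matrices=False)
    var = s ** 2                                  # variance along each component
    return var / var.sum()
```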

This result might be surprising, but in hindsight it can be explained. In "Generative Diffusion Model Talk (15): General Steps to Construct ODEs (Part 2)" and "(17): General Steps to Construct ODEs (Part 3)", we introduced the general procedure of first specifying a "pseudo-trajectory" from $\boldsymbol{x}_T$ to $\boldsymbol{x}_0$ and then constructing the corresponding diffusion ODE. In practical applications, the "pseudo-trajectories" we construct are all linear interpolations between $\boldsymbol{x}_T$ and $\boldsymbol{x}_0$ (possibly non-linear in $t$, but linear in $\boldsymbol{x}_T$ and $\boldsymbol{x}_0$). The constructed "pseudo-trajectories" are thus all straight lines, which in turn encourages the real diffusion trajectory to be close to a straight line. This explains the PCA results.

Conclusion

This article briefly reviewed sampling acceleration methods for diffusion ODEs and highlighted a novel accelerated sampling scheme called "AMED", released just a couple of days ago. This solver draws an analogy to the integral mean value theorem to construct its iteration rule, improving solver performance at low NFE with an extremely low distillation cost.