Talk on Generative Diffusion Models (8): Optimal Diffusion Variance Estimation (Part 2)

By 苏剑林 | August 18, 2022

In the previous article "Talk on Generative Diffusion Models (7): Optimal Diffusion Variance Estimation (Part 1)", we introduced and derived the optimal variance estimation results in Analytic-DPM. It provides an analytical estimate of the optimal variance for a pre-trained generative diffusion model, and experiments showed that this estimation effectively improves generation quality.

In this article, we continue with the upgraded version of Analytic-DPM, from the same team's paper "Estimating the Optimal Covariance with Imperfect Mean in Diffusion Probabilistic Models", which is referred to as "Extended-Analytic-DPM" in the official GitHub repository; we will use this name below as well.

Review of Results

The previous article derived, based on DDIM, that the optimal variance for the DDIM generation process should be \[\sigma_t^2 + \gamma_t^2\bar{\sigma}_t^2\] where $\bar{\sigma}_t^2$ is the variance of the distribution $p(\boldsymbol{x}_0|\boldsymbol{x}_t)$, which has the following estimation result (using the result of "Variance Estimation 2"): \begin{equation}\bar{\sigma}_t^2 = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\left(1 - \frac{1}{d}\mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}\left[ \Vert\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\Vert^2\right]\right)\label{eq:basic}\end{equation}

In hindsight, the logic of the estimation isn't particularly difficult. Assuming \begin{equation}\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) = \frac{1}{\bar{\alpha}_t}\left(\boldsymbol{x}_t - \bar{\beta}_t \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right)\label{eq:bar-mu}\end{equation} has accurately predicted the mean vector of the distribution $p(\boldsymbol{x}_0|\boldsymbol{x}_t)$, then according to the definition, the covariance is \begin{equation}\begin{aligned} \boldsymbol{\Sigma}(\boldsymbol{x}_t)=&\, \mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left(\boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\right)\left(\boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\right)^{\top}\right] \\ =&\, \frac{1}{\bar{\alpha}_t^2}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left(\boldsymbol{x}_t - \bar{\alpha}_t\boldsymbol{x}_0\right)\left(\boldsymbol{x}_t - \bar{\alpha}_t\boldsymbol{x}_0\right)^{\top}\right] - \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2} \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)^{\top}\\ \end{aligned}\label{eq:full-cov}\end{equation} Averaging both sides over $\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)$ eliminates the dependence on $\boldsymbol{x}_t$; since under the joint distribution $\boldsymbol{x}_t - \bar{\alpha}_t\boldsymbol{x}_0 = \bar{\beta}_t\boldsymbol{\varepsilon}$ with $\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I})$, the expectation in the first term averages to $\bar{\beta}_t^2\boldsymbol{I}$, giving \begin{equation} \boldsymbol{\Sigma}_t = \mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}[\boldsymbol{\Sigma}(\boldsymbol{x}_t)] = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\left(\boldsymbol{I} - \mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}\left[ \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)^{\top}\right]\right)\label{eq:uncond-var-2}\end{equation} Finally, taking the average of the diagonal elements to make it a scalar (or, equivalently, assuming the covariance is a multiple of the identity matrix), i.e., $\bar{\sigma}_t^2 = \text{Tr}(\boldsymbol{\Sigma}_t)/d$, yields the estimation formula $\eqref{eq:basic}$.
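To make the review concrete, here is a minimal Monte Carlo sketch of estimate $\eqref{eq:basic}$ (my own illustration, not code from either paper). The interface is assumed: `eps_model(x_t, t)` stands for a pretrained noise predictor $\boldsymbol{\epsilon}_{\boldsymbol{\theta}}$, `alpha_bar` and `beta_bar` are arrays holding $\bar{\alpha}_t$ and $\bar{\beta}_t$, and `x0_batch` is a batch of training samples used to draw $\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)$.

```python
import torch

@torch.no_grad()
def analytic_dpm_sigma2(eps_model, x0_batch, alpha_bar, beta_bar, t):
    """Monte Carlo estimate of sigma_bar_t^2 = (beta_bar_t/alpha_bar_t)^2 * (1 - E[||eps_theta||^2]/d)."""
    d = x0_batch[0].numel()                               # data dimension d
    noise = torch.randn_like(x0_batch)                    # epsilon ~ N(0, I)
    x_t = alpha_bar[t] * x0_batch + beta_bar[t] * noise   # x_t ~ p(x_t), sampled via x_0
    eps_pred = eps_model(x_t, t)                          # eps_theta(x_t, t)
    mean_sq_norm = eps_pred.flatten(1).pow(2).sum(dim=1).mean()  # E[||eps_theta(x_t, t)||^2]
    return (beta_bar[t] / alpha_bar[t]) ** 2 * (1.0 - mean_sq_norm / d)
```

In practice one would average over many batches and cache the result for each $t$, since the estimate does not depend on the particular $\boldsymbol{x}_t$ drawn at generation time.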

How to Improve

Before formally introducing Extended-Analytic-DPM, we can think about what room for improvement Analytic-DPM has.

Actually, several ideas quickly come to mind. For example, Analytic-DPM approximates $p(\boldsymbol{x}_0|\boldsymbol{x}_t)$ with a normal distribution whose covariance matrix takes the form $\bar{\sigma}_t^2\boldsymbol{I}$, i.e., a diagonal matrix with identical diagonal elements. A direct improvement is to allow the diagonal elements to differ, i.e., $\text{diag}(\bar{\boldsymbol{\sigma}}_t^2)$, where squaring a vector is understood as the Hadamard product, e.g., $\boldsymbol{x}^2=\boldsymbol{x}\otimes \boldsymbol{x}$. The corresponding result keeps only the diagonal part of $\boldsymbol{\Sigma}_t$; starting from equation $\eqref{eq:uncond-var-2}$, the estimate is: \begin{equation}\bar{\boldsymbol{\sigma}}_t^2 = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\left(\boldsymbol{1}_d - \mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}\left[ \boldsymbol{\epsilon}_{\boldsymbol{\theta}}^2(\boldsymbol{x}_t, t)\right]\right) \end{equation} where $\boldsymbol{1}_d$ is the $d$-dimensional vector of ones. A further improvement is to preserve the dependence of $\bar{\boldsymbol{\sigma}}_t^2$ on $\boldsymbol{x}_t$, i.e., to consider $\bar{\boldsymbol{\sigma}}_t^2(\boldsymbol{x}_t)$; like $\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)$, this would require a model taking $\boldsymbol{x}_t$ as input to learn it.
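Under the same assumed interface as the sketch above, the per-dimension version changes only the reduction: the squared prediction is averaged over samples but not over coordinates.

```python
import torch

@torch.no_grad()
def analytic_dpm_sigma2_diag(eps_model, x0_batch, alpha_bar, beta_bar, t):
    """Element-wise estimate: (beta_bar_t/alpha_bar_t)^2 * (1 - E[eps_theta^2]), per coordinate."""
    noise = torch.randn_like(x0_batch)
    x_t = alpha_bar[t] * x0_batch + beta_bar[t] * noise
    eps_pred = eps_model(x_t, t)
    mean_sq = eps_pred.pow(2).mean(dim=0)   # average over the batch only, keeping the per-pixel shape
    return (beta_bar[t] / alpha_bar[t]) ** 2 * (1.0 - mean_sq)
```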

So, could we consider the full $\boldsymbol{\Sigma}_t$? Theoretically yes, but in practice it is largely infeasible, because a full $\boldsymbol{\Sigma}_t$ is a $d\times d$ matrix. For images, $d$ is the total number of pixel values (height $\times$ width $\times$ channels). Even for CIFAR-10, $d=32^2\times 3=3072$, let alone higher-resolution images. Under realistic experimental constraints, a $d\times d$ matrix is simply too costly in both storage and computation.

Besides that, there is a problem many readers might not have noticed: the analytical derivations above rely on the assumption that $\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) = \mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}[\boldsymbol{x}_0]$ holds exactly. In reality, $\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)$ is learned by a model and might not exactly equal the true mean $\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}[\boldsymbol{x}_0]$. This is the "Imperfect Mean" in the title of the Extended-Analytic-DPM paper, and improving the variance estimate under an imperfect mean is of greater practical significance.

Maximum Likelihood

Assuming the mean model $\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)$ has been pre-trained, the only remaining parameter for the distribution $\mathcal{N}(\boldsymbol{x}_0;\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t),\bar{\sigma}_t^2\boldsymbol{I})$ is $\bar{\sigma}_t^2$. The corresponding negative log-likelihood is \begin{equation}\begin{aligned} &\, \mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[-\log \mathcal{N}(\boldsymbol{x}_0;\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t),\bar{\sigma}_t^2\boldsymbol{I})\right] \\ =&\, \frac{\mathbb{E}_{\boldsymbol{x}_t,\boldsymbol{x}_0\sim p(\boldsymbol{x}_t|\boldsymbol{x}_0)\tilde{p}(\boldsymbol{x}_0)}\left[\Vert\boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\Vert^2\right]}{2\bar{\sigma}_t^2} + \frac{d}{2}\log \bar{\sigma}_t^2 + \frac{d}{2}\log 2\pi \\ \end{aligned}\label{eq:neg-log}\end{equation} The minimum (obtained by setting the derivative with respect to $\bar{\sigma}_t^2$ to zero) is attained at \begin{equation}\bar{\sigma}_t^2 = \frac{1}{d}\mathbb{E}_{\boldsymbol{x}_t,\boldsymbol{x}_0\sim p(\boldsymbol{x}_t|\boldsymbol{x}_0)\tilde{p}(\boldsymbol{x}_0)}\left[\Vert\boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\Vert^2\right]\end{equation} The key point here is that $\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)$ is not necessarily the exact mean, so the second equals sign in equation $\eqref{eq:full-cov}$ does not hold; only the first one does. Substituting equation $\eqref{eq:bar-mu}$ gives: \begin{equation}\bar{\sigma}_t^2 = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2 d}\mathbb{E}_{\boldsymbol{x}_0\sim \tilde{p}(\boldsymbol{x}_0),\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I})}\left[\left\Vert\boldsymbol{\varepsilon} - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bar{\alpha}_t\boldsymbol{x}_0 + \bar{\beta}_t\boldsymbol{\varepsilon}, t)\right\Vert^2\right]\end{equation} Of course, this only covers the simple case where the covariance matrix is $\bar{\sigma}_t^2\boldsymbol{I}$. We can also consider a more general diagonal covariance matrix, $\mathcal{N}(\boldsymbol{x}_0;\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t),\text{diag}(\bar{\boldsymbol{\sigma}}_t^2))$, which gives: \begin{equation}\bar{\boldsymbol{\sigma}}_t^2 = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\mathbb{E}_{\boldsymbol{x}_0\sim \tilde{p}(\boldsymbol{x}_0),\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I})}\left[\left(\boldsymbol{\varepsilon} - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bar{\alpha}_t\boldsymbol{x}_0 + \bar{\beta}_t\boldsymbol{\varepsilon}, t)\right)^2\right]\end{equation}
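Still under the assumed interface of the earlier sketches, the imperfect-mean estimate simply replaces $1 - \frac{1}{d}\mathbb{E}\left[\Vert\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\Vert^2\right]$ with $\frac{1}{d}\mathbb{E}\left[\Vert\boldsymbol{\varepsilon} - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}\Vert^2\right]$; the diagonal version differs only in not summing over coordinates.

```python
import torch

@torch.no_grad()
def mle_sigma2(eps_model, x0_batch, alpha_bar, beta_bar, t):
    """Maximum-likelihood estimate with an imperfect mean:
    sigma_bar_t^2 = (beta_bar_t/alpha_bar_t)^2 / d * E[||eps - eps_theta(x_t, t)||^2]."""
    d = x0_batch[0].numel()
    noise = torch.randn_like(x0_batch)                    # eps ~ N(0, I)
    x_t = alpha_bar[t] * x0_batch + beta_bar[t] * noise
    err = noise - eps_model(x_t, t)                       # eps - eps_theta(x_t, t)
    return (beta_bar[t] / alpha_bar[t]) ** 2 * err.flatten(1).pow(2).sum(dim=1).mean() / d
```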

Conditional Variance

If we want to obtain the covariance $\text{diag}(\bar{\boldsymbol{\sigma}}_t^2(\boldsymbol{x}_t))$ conditioned on $\boldsymbol{x}_t$, it is equivalent to calculating each component independently, which removes the $\mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}$ averaging step: \begin{equation}\bar{\boldsymbol{\sigma}}_t^2(\boldsymbol{x}_t) = \mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[(\boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t))^2\right] = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left(\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right)^2\right] \end{equation} where $\boldsymbol{\epsilon}_t = \frac{\boldsymbol{x}_t - \bar{\alpha}_t \boldsymbol{x}_0}{\bar{\beta}_t}$. Similar to the previous article, using \begin{equation}\mathbb{E}_{\boldsymbol{x}}[\boldsymbol{x}] = \mathop{\text{argmin}}_{\boldsymbol{\mu}}\mathbb{E}_{\boldsymbol{x}}\left[\Vert \boldsymbol{x} - \boldsymbol{\mu}\Vert^2\right]\label{eq:mean-opt}\end{equation} we obtain \begin{equation}\begin{aligned} \bar{\boldsymbol{\sigma}}_t^2(\boldsymbol{x}_t) =&\, \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left(\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right)^2\right] \\ =&\, \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\mathop{\text{argmin}}_{\boldsymbol{g}}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left\Vert\left(\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right)^2-\boldsymbol{g}\right\Vert^2\right] \\ =&\, \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\mathop{\text{argmin}}_{\boldsymbol{g}(\boldsymbol{x}_t)}\mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left\Vert\left(\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right)^2-\boldsymbol{g}(\boldsymbol{x}_t)\right\Vert^2\right] \\ =&\, \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\mathop{\text{argmin}}_{\boldsymbol{g}(\boldsymbol{x}_t)}\mathbb{E}_{\boldsymbol{x}_t,\boldsymbol{x}_0\sim p(\boldsymbol{x}_t|\boldsymbol{x}_0)\tilde{p}(\boldsymbol{x}_0)}\left[\left\Vert\left(\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right)^2-\boldsymbol{g}(\boldsymbol{x}_t)\right\Vert^2\right] \\ \end{aligned}\label{eq:npr-loss}\end{equation} This is the "NPR-DPM" scheme for learning conditional variance in Extended-Analytic-DPM. Additionally, the original paper proposed an "SN-DPM" scheme based on the Perfect Mean hypothesis rather than the Imperfect Mean. However, the paper's experimental results showed SN-DPM performing better than NPR-DPM. In other words, while the paper claims to solve the Imperfect Mean problem, its results suggest that the scheme based on the Perfect Mean hypothesis is better, which conversely implies that the Imperfect Mean problem can practically be considered non-existent.
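The following is one way the final objective in $\eqref{eq:npr-loss}$ could be turned into a training step; it is a sketch under assumed interfaces rather than the paper's actual code. `g_model(x_t, t)` denotes a hypothetical variance head (e.g. an extra output reusing the U-Net trunk), and `eps_model` is the frozen, pretrained noise predictor.

```python
import torch

def npr_style_loss(g_model, eps_model, x0_batch, alpha_bar, beta_bar, t):
    """Regress g(x_t, t) onto (eps_t - eps_theta(x_t, t))^2 in the mean-squared sense."""
    noise = torch.randn_like(x0_batch)                # eps_t, from the joint sampling of (x_0, x_t)
    x_t = alpha_bar[t] * x0_batch + beta_bar[t] * noise
    with torch.no_grad():                             # the mean model is not updated
        eps_pred = eps_model(x_t, t)
    target = (noise - eps_pred) ** 2                  # regression target (eps_t - eps_theta)^2
    return ((g_model(x_t, t) - target) ** 2).mean()

# At generation time, the learned g would give the conditional variance up to the schedule factor:
#   sigma_bar_t^2(x_t) = (beta_bar[t] / alpha_bar[t]) ** 2 * g_model(x_t, t)
```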

Two Stages

Readers might wonder: didn't we mention earlier that, in "Improved Denoising Diffusion Probabilistic Models", making the variance learnable increased the training difficulty? Why, then, does Extended-Analytic-DPM go back to training a variance model?

We know that DDPM offers two fixed choices for the variance: $\sigma_t = \frac{\bar{\beta}_{t-1}}{\bar{\beta}_t}\beta_t$ and $\sigma_t = \beta_t$, and these two simple schemes already work quite well. This indirectly suggests that finer adjustments to the variance do not have a huge impact on the final result (at least for a full $T$-step diffusion); the primary factor is the learning of $\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)$, while the variance is just "icing on the cake." If the variance is treated as a learnable parameter or model and trained jointly with the mean model $\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)$, the constantly changing variance will seriously interfere with the learning of $\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)$, violating the principle that "$\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)$ is primary and the variance is auxiliary."
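For concreteness, both fixed choices can be read off directly from the noise schedule; the linear schedule below is only an assumed example, written in this series' notation where $\alpha_t,\beta_t$ are standard deviations with $\alpha_t^2+\beta_t^2=1$.

```python
import numpy as np

T = 1000
beta = np.sqrt(np.linspace(1e-4, 2e-2, T))   # assumed linear schedule for beta_t^2
alpha = np.sqrt(1.0 - beta ** 2)
alpha_bar = np.cumprod(alpha)
beta_bar = np.sqrt(1.0 - alpha_bar ** 2)

# DDPM's two fixed choices for the reverse-process std sigma_t (for t >= 1):
sigma_lower = beta_bar[:-1] / beta_bar[1:] * beta[1:]   # sigma_t = (beta_bar_{t-1} / beta_bar_t) * beta_t
sigma_upper = beta[1:]                                   # sigma_t = beta_t
```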

The cleverness of Extended-Analytic-DPM lies in its two-stage training scheme: first, train the mean model $\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)$ with the original fixed variance as usual; then freeze this model and reuse most of its parameters to learn a variance model (a minimal sketch follows the list below). This accomplishes "three things at once":

1. Reduces the number of parameters and training costs;
2. Allows the reuse of already trained mean models;
3. Makes the training process more stable.
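Here is a minimal sketch of that two-stage recipe, again under assumed interfaces (a `loader` yielding batches of $\boldsymbol{x}_0$, an `eps_model` mean network, and a `g_model` variance head) and reusing `npr_style_loss` from the sketch above; it illustrates the idea, not the authors' implementation.

```python
import torch

def ddpm_mean_loss(eps_model, x0, alpha_bar, beta_bar, t):
    """Stage 1: the usual noise-prediction loss, with the reverse variance kept at a fixed schedule."""
    noise = torch.randn_like(x0)
    x_t = alpha_bar[t] * x0 + beta_bar[t] * noise
    return ((eps_model(x_t, t) - noise) ** 2).mean()

def train_two_stage(eps_model, g_model, loader, alpha_bar, beta_bar, T, lr=1e-4):
    """Stage 1 trains the mean model alone; stage 2 freezes it and fits the variance head.
    One pass over the loader per stage, for brevity."""
    opt = torch.optim.Adam(eps_model.parameters(), lr=lr)
    for x0 in loader:                                   # stage 1: mean model, fixed variance
        t = torch.randint(1, T, (1,)).item()
        loss = ddpm_mean_loss(eps_model, x0, alpha_bar, beta_bar, t)
        opt.zero_grad(); loss.backward(); opt.step()

    for p in eps_model.parameters():                    # stage 2: freeze the mean model
        p.requires_grad_(False)
    opt = torch.optim.Adam(g_model.parameters(), lr=lr)
    for x0 in loader:
        t = torch.randint(1, T, (1,)).item()
        loss = npr_style_loss(g_model, eps_model, x0, alpha_bar, beta_bar, t)
        opt.zero_grad(); loss.backward(); opt.step()
```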

Personal Reflections

At this point, the introduction of Extended-Analytic-DPM is essentially complete. Attentive readers may feel that, if the results of Analytic-DPM were "stunning," then those of Extended-Analytic-DPM seem relatively plain, with nothing particularly heart-stirring; one could say it is a fairly routine generalization of Analytic-DPM. Although experiments show it still brings a decent improvement, the overall impression is rather flat. This is probably because Analytic-DPM was a "hard act to follow," making this work look dimmer by comparison, even though it is still quite a solid piece of work.

Furthermore, as mentioned earlier, the experimental results indicated that SN-DPM (based on the Perfect Mean hypothesis) performed better than NPR-DPM (based on the Imperfect Mean hypothesis). This outcome makes the original paper's title feel somewhat "unfitting": if the Perfect-Mean scheme is better, it implies the Imperfect Mean problem does not really matter. The original paper did not analyze or evaluate this result any further. I wonder whether it has something to do with the bias of the variance estimator: as we know, the "divide by $n$" formula for estimating variance is biased, and NPR-DPM is built on that kind of estimate, whereas SN-DPM estimates the second moment directly, and second-moment estimation is unbiased. There seems to be some logic to this, but it does not fully explain everything; it remains a bit of a mystery~
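As a toy numerical illustration of this bias point (a deliberately simplified setting, not an analysis of NPR-DPM or SN-DPM themselves): around an estimated mean, the "divide by $n$" variance estimate is biased low, whereas the plain second moment around a known mean is unbiased.

```python
import numpy as np

rng = np.random.default_rng(0)
n, trials, sigma2 = 5, 200_000, 1.0
x = rng.normal(0.0, np.sqrt(sigma2), size=(trials, n))

# "Divide by n" variance estimate around the sample mean: biased, E = (n-1)/n * sigma^2.
var_div_n = ((x - x.mean(axis=1, keepdims=True)) ** 2).mean(axis=1)
# Second-moment estimate around the known mean (0 here): unbiased for sigma^2.
second_moment = (x ** 2).mean(axis=1)

print(var_div_n.mean())      # ~0.8, i.e. (n-1)/n for n = 5
print(second_moment.mean())  # ~1.0
```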

Finally, I wonder whether readers share a question of mine: given $\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)$, why not simply use the negative log-likelihood (as in equation $\eqref{eq:neg-log}$) as the loss function for learning the variance, instead of designing new MSE-form objectives such as NPR-DPM and SN-DPM? Is there a special benefit to the MSE form? For now, I have not come up with an answer either.

Summary

This article introduced "Extended-Analytic-DPM," the upgraded version of Analytic-DPM's optimal variance estimation. It mainly derived the corresponding results for the case of an imperfect mean and presented a scheme for learning the conditional variance.