By 苏剑林 | September 29, 2018
Today, I will introduce a classic piece of work titled f-GAN (Nowozin et al., 2016). In this paper, the authors provide a scheme for constructing general GANs from general $f$-divergences. It is no exaggeration to say that this paper is a "production workshop" for GAN models: it generalizes many GAN variants and can inspire us to quickly construct new ones (whether they are valuable is another matter, but theoretically, it is possible).
Local Variation
The entire article's treatment of $f$-divergence is actually based on what is known in machine learning as the "local variational method." This is a very classic and useful estimation technique. In fact, most of this article will be spent introducing the results of applying this estimation technique to $f$-divergence. As for GANs, they are simply a basic application of these results.
f-Divergence
First, let's provide a basic introduction to $f$-divergence. The so-called $f$-divergence is a generalization of the KL divergence:
\begin{equation}\mathcal{D}_f(P\Vert Q) = \int q(x) f\left(\frac{p(x)}{q(x)}\right)dx\label{eq:f-div}\end{equation}
Note that according to general convention, the term inside the parentheses is $p/q$ rather than $q/p$. Do not naturally assume it is $q/p$ based on the form of the KL divergence.
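As a quick check of this convention: substituting $f(u)=u\log u$ into $\eqref{eq:f-div}$ gives $\int q(x)\,\frac{p(x)}{q(x)}\log\frac{p(x)}{q(x)}\,dx = \int p(x)\log\frac{p(x)}{q(x)}\,dx$, which is exactly the KL divergence $KL(P\Vert Q)$.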
It turns out that this form covers many of the measures between probability distributions that we have seen. Here, I will directly reproduce (part of) the table from the paper:
| Name of Divergence | Formula | Corresponding $f$ |
|---|---|---|
| Total Variation | $\frac{1}{2}\int \vert p(x) - q(x)\vert\, dx$ | $\frac{1}{2}\vert u - 1\vert$ |
| KL Divergence | $\int p(x)\log \frac{p(x)}{q(x)}\, dx$ | $u \log u$ |
| Reverse KL | $\int q(x)\log \frac{q(x)}{p(x)}\, dx$ | $- \log u$ |
| Pearson $\chi^2$ | $\int \frac{(q(x) - p(x))^{2}}{p(x)}\, dx$ | $\frac{(1 - u)^{2}}{u}$ |
| Neyman $\chi^2$ | $\int \frac{(p(x) - q(x))^{2}}{q(x)}\, dx$ | $(u - 1)^{2}$ |
| Hellinger Distance | $\int \left(\sqrt{p(x)} - \sqrt{q(x)}\right)^{2}\, dx$ | $(\sqrt{u} - 1)^{2}$ |
| Jeffrey Distance | $\int (p(x) - q(x))\log \left(\frac{p(x)}{q(x)}\right)\, dx$ | $(u - 1)\log u$ |
| JS Divergence | $\frac{1}{2}\int p(x)\log \frac{2 p(x)}{p(x) + q(x)} + q(x)\log \frac{2 q(x)}{p(x) + q(x)}\, dx$ | $-\frac{u + 1}{2}\log \frac{1 + u}{2} + \frac{u}{2} \log u$ |
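For the discrete case, all of these are straightforward to evaluate numerically. The snippet below (a minimal NumPy sketch; the toy distributions are my own) checks two rows of the table by comparing $\sum_x q(x) f(p(x)/q(x))$ against the direct formulas:

```python
import numpy as np

# Toy check of the table: for discrete p, q, the f-divergence
# sum_x q(x) f(p(x)/q(x)) should match the direct formula of each row.
p = np.array([0.5, 0.3, 0.2])  # hypothetical toy distribution P
q = np.array([0.2, 0.5, 0.3])  # hypothetical toy distribution Q

def f_divergence(p, q, f):
    u = p / q
    return np.sum(q * f(u))

# KL row: f(u) = u log u  vs  direct formula sum_x p log(p/q)
kl_via_f = f_divergence(p, q, lambda u: u * np.log(u))
kl_direct = np.sum(p * np.log(p / q))
print(kl_via_f, kl_direct)  # equal up to floating point

# JS row: f(u) = -(u+1)/2 log((1+u)/2) + u/2 log u
js_via_f = f_divergence(p, q,
    lambda u: -(u + 1) / 2 * np.log((1 + u) / 2) + u / 2 * np.log(u))
m = (p + q) / 2  # JS = (KL(p||m) + KL(q||m)) / 2 with m the mixture
js_direct = 0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m))
print(js_via_f, js_direct)  # equal up to floating point
```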
Convex Functions
Having listed a bunch of distribution measures and their corresponding $f$, a natural question arises: what are the common characteristics of these $f$?
The answer is:
1. They are all mappings from the non-negative real numbers to the real numbers ($\mathbb{R}_{\geq 0} \to \mathbb{R}$);
2. $f(1)=0$;
3. They are all convex functions.
The first point is natural, since the argument $p(x)/q(x)$ is non-negative. The second point, $f(1)=0$, ensures that $\mathcal{D}_f(P\Vert P)=0$. How should we understand the third point, convexity? It is actually one of the most fundamental applications of convex function properties, namely Jensen's inequality:
\begin{equation}\mathbb{E}\big[f(x)\big]\geq f\big(\mathbb{E}[x]\big)\label{eq:tuhanshu-xingzhi}\end{equation}
That is, "the average of the function is greater than or equal to the function of the average." Some textbooks use this property directly as the definition of a convex function. If $f(u)$ is a smooth function, we usually determine whether it is convex by checking whether the second derivative $f''(u)$ is everywhere greater than or equal to 0.
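For instance, for $f(u)=u\log u$ (the KL case), $f''(u)=1/u > 0$ on $\mathbb{R}_+$, so it is not just convex but strictly convex.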
Using $\eqref{eq:tuhanshu-xingzhi}$, we have:
\begin{equation}\begin{aligned}\int q(x) f\left(\frac{p(x)}{q(x)}\right)dx =& \mathbb{E}_{x\sim q(x)} \left[f\left(\frac{p(x)}{q(x)}\right)\right]\\
\geq& f\left(\mathbb{E}_{x\sim q(x)} \left[\frac{p(x)}{q(x)}\right]\right)\\
=& f\left(\int q(x) \frac{p(x)}{q(x)}dx\right)\\
=& f\left(\int p(x)dx\right)\\
=& f(1) = 0
\end{aligned}\end{equation}
In other words, these three conditions guarantee that the $f$-divergence is non-negative, and it equals 0 when the two distributions are identical, making $\mathcal{D}_f$ a simple way to measure the difference between distributions. Of course, $f$-divergence does not strictly guarantee that $\mathcal{D}_f(P\Vert Q) \gt 0$ when $P \neq Q$. However, we typically choose strictly convex $f$ (i.e., $f''(u) \gt 0$ everywhere), which ensures that $\mathcal{D}_f(P\Vert Q) \gt 0$ when $P \neq Q$, meaning $\mathcal{D}_f(P\Vert Q)=0 \iff P=Q$. (Note: Even so, in general, $\mathcal{D}_f(P\Vert Q)$ is still not a "distance" satisfying the axiomatic definition, but that is irrelevant to the main topic here.)
Convex Conjugate
Now let's discuss convex functions from a more mathematical perspective. Generally, let the domain of a convex function be $\mathbb{D}$ (for this article, $\mathbb{D}=\mathbb{R}_+$). Choosing any point $\xi$, we find the tangent line of $y=f(u)$ at $u=\xi$, which results in:
\begin{equation}y = f(\xi) + f'(\xi)(u - \xi)\end{equation}
Consider the difference function:
\begin{equation}h(u) = f(u) - f(\xi) - f'(\xi)(u - \xi)\end{equation}
A convex function can be understood intuitively as its graph always lying above (any of) its tangent lines. Therefore, for a convex function, the following inequality always holds:
\begin{equation}f(u) - f(\xi) - f'(\xi)(u - \xi)\geq 0\end{equation}
Rearranging this:
\begin{equation}f(u) \geq f(\xi) - f'(\xi) \xi + f'(\xi)u\end{equation}
Since the inequality holds for every $\xi$, and equality is attained at $\xi=u$, we can derive:
\begin{equation}f(u) = \max_{\xi\in\mathbb{D}}\big\{f(\xi) - f'(\xi) \xi + f'(\xi)u\big\}\end{equation}
Changing notation, let $t=f'(\xi)$ and solve for $\xi$ in terms of $t$ (for a strictly convex, differentiable $f$ this is always possible, since $f'$ is then monotonically increasing; the reader can try to prove this), and denote:
\begin{equation}g(t) = - f(\xi) + f'(\xi) \xi\end{equation}
Then we have:
\begin{equation}f(u) = \max_{t\in f'(\mathbb{D})}\big\{t u - g(t)\big\}\end{equation}
Here $g(t)$ is known as the conjugate function of $f(u)$. Note that the expression inside the braces is linear in $u$, and once $f$ is given, $g$ is also determined. So overall, we have done the following:
We have provided a linear approximation to a convex function, whereby the original value can be recovered by maximizing over the internal parameter.
Note that for a given $u$, we must carry out the maximization over $t$ to get as close as possible to $f(u)$; substituting an arbitrary $t$ only guarantees a lower bound, with no control over the size of the error. Hence the name "local variational method": the maximization (variation) must be performed at every point (locally). In this way, we can regard $t$ as effectively a function of $u$, i.e.:
\begin{equation}f(u) = \max_{T\text{ is a function with range } f'(\mathbb{D})}\big\{T(u) u - g(T(u))\big\}\label{eq:max-conj}\end{equation}
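For example, take $f(u) = u\log u$ (the KL case). Then $t = f'(\xi) = \log\xi + 1$, so $\xi = e^{t-1}$, and
\begin{equation}g(t) = -f(\xi) + f'(\xi)\,\xi = \xi\,(t - \log\xi) = e^{t-1}\end{equation}
so that $u\log u = \max_{t\in\mathbb{R}}\big\{t u - e^{t-1}\big\}$, with the maximum attained at $t = 1 + \log u$.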
The above discussion actually provides a method for calculating the convex conjugate. Here, we directly provide the conjugate functions corresponding to the convex functions in the previous table.
| $f(u)$ | Corresponding Conjugate $g(t)$ | $f'(\mathbb{D})$ | Activation Function |
|---|---|---|---|
| $\frac{1}{2}\vert u - 1\vert$ | $t$ | $\left[-\frac{1}{2},\frac{1}{2}\right]$ | $\frac{1}{2}\tanh(x)$ |
| $u \log u$ | $e^{t-1}$ | $\mathbb{R}$ | $x$ |
| $- \log u$ | $-1 - \log(-t)$ | $\mathbb{R}_-$ | $-e^{x}$ |
| $\frac{(1 - u)^{2}}{u}$ | $2 - 2\sqrt{1-t}$ | $(-\infty, 1)$ | $1-e^x$ |
| $(u - 1)^{2}$ | $\frac{1}{4}t^2+t$ | $(-2,+\infty)$ | $e^x-2$ |
| $(\sqrt{u} - 1)^{2}$ | $\frac{t}{1-t}$ | $(-\infty, 1)$ | $1-e^x$ |
| $(u - 1)\log u$ | $W(e^{1-t})+\frac{1}{W(e^{1-t})}+t-2$ | $\mathbb{R}$ | $x$ |
| $-\frac{u + 1}{2}\log \frac{1 + u}{2} + \frac{u}{2} \log u$ | $-\frac{1}{2}\log(2-e^{2t})$ | $\left(-\infty,\frac{\log 2}{2}\right)$ | $\frac{\log 2}{2}-\frac{1}{2}\log(1+e^{-x})$ |
(Note: $W$ here is the Lambert W function.)
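A quick numerical sanity check of this table may be reassuring. The sketch below (plain NumPy, a brute-force grid search of my own devising) verifies $f(u) = \max_{t\in f'(\mathbb{D})}\{tu - g(t)\}$ for the KL and Hellinger rows:

```python
import numpy as np

# Check f(u) = max_t { t*u - g(t) } over t in f'(D) by grid search, at u = 2.
u = 2.0

# KL row: f(u) = u log u, g(t) = e^(t-1), f'(D) = R (grid truncated to [-10, 10])
t = np.linspace(-10.0, 10.0, 200001)
print(u * np.log(u), np.max(t * u - np.exp(t - 1)))  # both ~1.3863

# Hellinger row: f(u) = (sqrt(u)-1)^2, g(t) = t/(1-t), f'(D) = (-inf, 1)
t = np.linspace(-10.0, 1.0 - 1e-6, 200001)
print((np.sqrt(u) - 1) ** 2, np.max(t * u - t / (1 - t)))  # both ~0.1716
```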
f-GAN
From the above derivation, we can provide the estimation formula for $f$-divergence and further provide a general framework for $f$-GAN.
f-Divergence Estimation
What is the difficulty in calculating $f$-divergence? According to the definition $\eqref{eq:f-div}$, we need to know both probability distributions $P$ and $Q$ simultaneously to calculate the $f$-divergence between them. But in machine learning this is hard to achieve: sometimes we know the analytical form of at most one of the distributions, while for the other we only have samples; in many cases we know neither distribution and only have samples from each (meaning we want to compare the similarity between two sets of samples). Therefore, we cannot directly calculate the $f$-divergence from $\eqref{eq:f-div}$.
Combining $\eqref{eq:f-div}$ and $\eqref{eq:max-conj}$, we get:
\begin{equation}\begin{aligned}\mathcal{D}_f(P\Vert Q) =& \max_{T}\int q(x) \left[\frac{p(x)}{q(x)}T\left(\frac{p(x)}{q(x)}\right)-g\left(T\left(\frac{p(x)}{q(x)}\right)\right)\right]dx\\
=& \max_{T}\int\left[p(x)\cdot T\left(\frac{p(x)}{q(x)}\right)-q(x)\cdot g\left(T\left(\frac{p(x)}{q(x)}\right)\right)\right]dx\end{aligned}\end{equation}
Since $p(x)/q(x)$ is itself a function of $x$, we can absorb the composition $T\left(\frac{p(x)}{q(x)}\right)$ into a single function and simply write $T(x)$; then:
\begin{equation}\mathcal{D}_f(P\Vert Q) = \max_{T}\Big(\mathbb{E}_{x\sim p(x)}[T(x)]-\mathbb{E}_{x\sim q(x)}[g(T(x))]\Big)\label{eq:f-div-e}\end{equation}
Equation $\eqref{eq:f-div-e}$ is the basic formula for estimating $f$-divergence. It means that by sampling from two distributions, calculating the average values of $T(x)$ and $g(T(x))$, and optimizing $T$ to make their difference as large as possible, we arrive at an approximate value for the $f$-divergence. Clearly, $T(x)$ can be fitted by a sufficiently complex neural network; we just need to optimize its parameters.
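As an illustration of $\eqref{eq:f-div-e}$ (a minimal sketch of my own, not code from the paper): we can estimate $KL(P\Vert Q)$ for $P=\mathcal{N}(0,1)$ and $Q=\mathcal{N}(1,1)$ purely from samples. From the table, the KL row has $g(t)=e^{t-1}$ and $f'(\mathbb{D})=\mathbb{R}$, so $T$ needs no output activation; instead of a neural network, $T$ is parametrized here as a quadratic $a+bx+cx^2$, a family that contains the optimal $T$ for two Gaussians:

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
xp = rng.normal(0.0, 1.0, 100000)  # samples from P = N(0, 1)
xq = rng.normal(1.0, 1.0, 100000)  # samples from Q = N(1, 1)

def neg_objective(theta):
    # negative of E_p[T(x)] - E_q[exp(T(x) - 1)], the KL instance of the bound
    a, b, c = theta
    Tp = a + b * xp + c * xp ** 2
    Tq = a + b * xq + c * xq ** 2
    return -(np.mean(Tp) - np.mean(np.exp(Tq - 1)))

res = minimize(neg_objective, x0=np.zeros(3), method="Nelder-Mead")
print(-res.fun)  # ~0.5, the exact KL(N(0,1) || N(1,1))
```

The exact value here is $KL = \frac{1}{2}$, and the sample-based maximization lands close to it; with a more expressive $T$ (e.g. a small neural network), the same recipe applies to distributions without a convenient parametric form.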
Note that in our discussion of convex functions, while maximizing the objective, there are constraints on the range of $T$. Thus, in the final layer of $T$, we must design an appropriate activation function to ensure $T$ satisfies the required range. Of course, the choice of activation function is not unique. Reference activation functions have been listed in the previous table. Although theoretically any activation function of the right range will do, for ease of optimization, certain principles should be followed:
1. The domain should be $\mathbb{R}$, and the range should be the required range (boundary points can be ignored);
2. It is best to choose a globally smooth function, rather than a simple truncation. For example, if the required range is $\mathbb{R}_+$, do not use $\mathrm{relu}(x)$ directly; consider $e^x$ or $\log(1+e^x)$ instead (see the sketch after this list);
3. Note that the second term of formula $\eqref{eq:f-div-e}$ contains $g(T(x))$, which is the composite calculation of $g$ and $T$. Thus, the activation function should ideally make this composite calculation simple.
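To illustrate these principles (a minimal sketch; the function names are my own), here are smooth activations realizing some of the ranges in the table, using $\log(1+e^x)$ (softplus) rather than a hard truncation:

```python
import numpy as np

def softplus(x):
    # numerically stable log(1 + e^x): smooth, with range R_+
    return np.logaddexp(0.0, x)

def act_reverse_kl(x):
    # maps R onto R_-, as required for f(u) = -log u
    return -np.exp(x)

def act_js(x):
    # maps R onto (-inf, log(2)/2), the JS row of the table:
    # log(2)/2 - (1/2) log(1 + e^{-x})
    return 0.5 * np.log(2.0) - 0.5 * softplus(-x)
```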
GAN Wholesale
Well, having said all this, we are almost at the end of the article, and yet we haven't formally mentioned GANs. In fact, GANs can be considered a mere byproduct of this entire process.
A GAN aims to train a generator that maps a Gaussian distribution to the distribution of our dataset, which requires measuring the difference between the two distributions. After the preceding preparation, this becomes quite simple: pick any $f$-divergence and estimate it with formula $\eqref{eq:f-div-e}$; once estimated, we have a model of the $f$-divergence. Since the generator wants to make the two distributions close, it simply minimizes this $f$-divergence. Written as a single expression:
\begin{equation}\min_G\max_{T}\Big(\mathbb{E}_{x\sim p(x)}[T(x)]-\mathbb{E}_{x=G(z),z\sim q(z)}[g(T(x))]\Big)\label{eq:f-div-gan}\end{equation}
Or, with the roles of the two distributions swapped:
\begin{equation}\min_G\max_{T}\Big(\mathbb{E}_{x=G(z),z\sim q(z)}[T(x)]-\mathbb{E}_{x\sim p(x)}[g(T(x))]\Big)\label{eq:f-div-gan-2}\end{equation}
And that's it~
Need some examples? Okay, let's look at JS divergence first. If you substitute all the formulas step by step, you will find the final result (omitting the $\log 2$ constant term) is:
\begin{equation}\min_G\max_{D}\Big(\mathbb{E}_{x\sim p(x)}[\log D(x)] + \mathbb{E}_{x=G(z),z\sim q(z)}[\log(1-D(x))]\Big)\end{equation}
where $D$ uses the sigmoid activation $\sigma(x)=1/(1+e^{-x})$. This is the original version of GAN.
How about using Hellinger distance? The result is:
\begin{equation}\min_G\max_{D}\Big(-\mathbb{E}_{x\sim p(x)}[e^{D(x)}] - \mathbb{E}_{x=G(z),z\sim q(z)}[e^{-D(x)}]\Big)\end{equation}
Here $D(x)$ uses a linear activation. This variant does not seem to have been given a name yet, but the paper already reports experiments with it.
What about KL divergence? Since KL divergence is asymmetric, there are two results:
\begin{equation}\min_G\max_{D}\Big(\mathbb{E}_{x\sim p(x)}[D(x)] - \mathbb{E}_{x=G(z),z\sim q(z)}[e^{D(x)-1}]\Big)\end{equation}
Or:
\begin{equation}\min_G\max_{D}\Big(\mathbb{E}_{x=G(z),z\sim q(z)}[D(x)] - \mathbb{E}_{x\sim p(x)}[e^{D(x)-1}]\Big)\end{equation}
Here $D(x)$ also uses linear activation.
Alright, no more examples. Actually, these $f$-divergences are essentially similar, and it is hard to see significant differences in performance. However, one can note that JS divergence and Hellinger distance are symmetric and bounded, which are very good properties that we will use later.
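Although no more divergence examples are needed, the "wholesale" recipe itself fits in a few lines of code. Below is a hedged sketch (my own naming and structure, not from the paper) that turns rows of the conjugate table into pluggable objectives following $\eqref{eq:f-div-gan}$, assuming a discriminator network that outputs a raw score $V(x)$ which the activation maps into $f'(\mathbb{D})$ to give $T(x)$:

```python
import math
import torch
import torch.nn.functional as F

# Each row supplies (output activation for T = act(V), conjugate g).
F_GAN_ROWS = {
    "kl":         (lambda v: v,                 lambda t: torch.exp(t - 1)),
    "reverse_kl": (lambda v: -torch.exp(v),     lambda t: -1 - torch.log(-t)),
    "hellinger":  (lambda v: 1 - torch.exp(v),  lambda t: t / (1 - t)),
    "js":         (lambda v: 0.5 * math.log(2.0) - 0.5 * F.softplus(-v),
                   lambda t: -0.5 * torch.log(2 - torch.exp(2 * t))),
}

def discriminator_loss(name, v_real, v_fake):
    # maximize E_p[T(x)] - E_q[g(T(x))]  <=>  minimize its negative
    act, g = F_GAN_ROWS[name]
    return -(act(v_real).mean() - g(act(v_fake)).mean())

def generator_loss(name, v_fake):
    # only the second term of the objective depends on G, so the
    # generator minimizes -E_q[g(T(x))]
    act, g = F_GAN_ROWS[name]
    return -g(act(v_fake)).mean()
```

For the "js" row, this reduces (up to constants) to the original GAN losses above, which serves as a convenient consistency check.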
Conclusion
To put it bluntly, the main purpose of this article is to introduce $f$-divergence and its estimation via local variation. Therefore, most of it consists of theoretical text, with GANs occupying only a small part.
Of course, after this effort, we indeed arrive at the result of a "GAN production workshop" (depending on how many $f$-divergences you have). These newly devised GANs may not look exactly like the GANs we imagine, but they are indeed optimizing $f$-divergence. However, the problems inherent in standard GANs (corresponding to JS divergence) will still exist in $f$-divergences. Therefore, the greater value of the f-GAN work lies in "unification"; from the perspective of generative models, it doesn't represent a major breakthrough.