Variational Autoencoder (VII): VAE on the Sphere (vMF-VAE)

By 苏剑林 | May 17, 2021

In "Variational Autoencoder (V): VAE + BN = Better VAE", we discussed the common phenomenon of KL divergence vanishing when training VAEs in NLP and mentioned using Batch Normalization (BN) to give the KL divergence term a positive lower bound, thereby ensuring it does not vanish. In fact, back in 2018, work based on similar ideas was proposed. These approaches involve using new prior and posterior distributions in the VAE to ensure the KL divergence term has a fixed positive lower bound.

This idea appeared in two similar papers in 2018, namely "Hyperspherical Variational Auto-Encoders" and "Spherical Latent Spaces for Stable Variational Autoencoders". Both use the von Mises–Fisher (vMF) distribution, defined on a hypersphere, to construct the prior and posterior distributions. In some ways, this distribution is even simpler and more interesting than the Gaussian distribution we commonly use!

KL Divergence Vanishing

We know that the training objective of a VAE is:

\begin{equation}\mathcal{L} = \mathbb{E}_{x\sim \tilde{p}(x)} \Big[\mathbb{E}_{z\sim p(z|x)}\big[-\log q(x|z)\big]+KL\big(p(z|x)\big\Vert q(z)\big)\Big]\end{equation}

where the first term is the reconstruction term and the second is the KL divergence term. As we noted in "Variational Autoencoder (I): What it's All About", these two terms are "adversarial" in some sense. The existence of the KL divergence term increases the difficulty for the decoder to utilize the encoded information. If the KL divergence term becomes zero, it means the decoder completely ignores the information from the encoder.

In NLP, the input and the objects of reconstruction are sentences. To ensure performance, decoders generally use autoregressive models. However, autoregressive models are extremely powerful—powerful enough that they can complete training even without any input (degenerating into an unconditional language model). As we mentioned, since the KL divergence term makes it harder for the decoder to use encoded information, the decoder simply stops using it altogether, which results in the phenomenon of KL divergence vanishing.

An early common solution was to gradually increase the weight of the KL term (KL annealing) to guide the decoder to use encoded information. A more popular modern solution is to introduce certain modifications so that the KL divergence term inherently has a positive lower bound. Replacing the prior and posterior distributions with the vMF distribution is a classic example of this approach.

vMF Distribution

The vMF distribution is defined on a $(d-1)$-dimensional hypersphere. Its sample space is $S^{d-1}=\{x|x\in\mathbb{R}^d, \Vert x\Vert=1\}$, and its probability density function is:

\begin{equation}p(x) = \frac{e^{\langle\xi,x\rangle}}{Z_{d, \Vert\xi\Vert}},\quad Z_{d, \Vert\xi\Vert}=\int_{S^{d-1}}e^{\langle\xi,x\rangle} dS^{d-1}\end{equation}

where $\xi\in\mathbb{R}^d$ is a pre-specified parameter vector. As you might imagine, this is a distribution on $S^{d-1}$ centered at $\xi$. Writing the normalization factor as $Z_{d, \Vert\xi\Vert}$ implies it only depends on the magnitude of $\xi$ due to isotropy. Because of this property, a more common notation for the vMF distribution is to set $\mu=\xi/\Vert\xi\Vert, \kappa=\Vert\xi\Vert, C_{d,\kappa}=1/Z_{d, \Vert\xi\Vert}$, resulting in:

\begin{equation}p(x) = C_{d,\kappa} e^{\kappa\langle\mu,x\rangle}\end{equation}

In this case, $\langle\mu,x\rangle$ is the cosine of the angle between $\mu$ and $x$. Thus, the vMF distribution is essentially a distribution that uses cosine similarity as its metric. Since we frequently use the cosine to measure the similarity between two vectors, models built on the vMF distribution are often a better match for this practice. When $\kappa=0$, the vMF distribution reduces to the uniform distribution on the sphere.

From its integral form, the normalization factor $Z_{d, \Vert\xi\Vert}$ acts as a generating function for the vMF distribution, so the moments of vMF can be expressed through $Z_{d, \Vert\xi\Vert}$. For instance, the first moment is:

\begin{equation}\mathbb{E}_{x\sim p(x)} [x] = \nabla_{\xi} \log Z_{d, \Vert\xi\Vert}=\frac{d \log Z_{d,\Vert\xi\Vert}}{d\Vert\xi\Vert}\frac{\xi}{\Vert\xi\Vert}\end{equation}
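For completeness, the first equality follows from differentiating the integral form of $Z_{d, \Vert\xi\Vert}$ under the integral sign:

\begin{equation}\nabla_{\xi} Z_{d,\Vert\xi\Vert} = \int_{S^{d-1}} x\, e^{\langle\xi,x\rangle} dS^{d-1} = Z_{d,\Vert\xi\Vert}\,\mathbb{E}_{x\sim p(x)}[x]\end{equation}

and then dividing both sides by $Z_{d,\Vert\xi\Vert}$; the second equality is just the chain rule, since $Z_{d,\Vert\xi\Vert}$ depends on $\xi$ only through $\Vert\xi\Vert$.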

It can be seen that the direction of $\mathbb{E}_{x\sim p(x)} [x]$ is consistent with $\xi$. The exact form of $Z_{d, \Vert\xi\Vert}$ can be calculated, but it is quite complex, and in many cases, we do not need to know the exact normalization factor, so we will skip it here.

As for the meaning of the parameter $\kappa$, it is easier to understand after setting $\tau=1/\kappa$, so that $p(x)\sim e^{\langle\mu,x\rangle/\tau}$. Readers familiar with energy-based models will recognize $\tau$ as a temperature parameter: if $\tau$ is small ($\kappa$ is large), the distribution is highly concentrated around $\mu$; if $\tau$ is large ($\kappa$ is small), the distribution becomes more dispersed (closer to the uniform distribution on the sphere). For this reason, $\kappa$ is vividly called the "concentration" parameter.
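For instance, comparing a sample aligned with $\mu$ (cosine 1) to one at $60°$ from it (cosine 0.5), the density ratio is $e^{\kappa/2}$; a tiny numerical sketch (not from the original post):

import numpy as np

# Ratio of unnormalized vMF densities at cosine 1 vs cosine 0.5
for kappa in [1.0, 20.0]:
    ratio = np.exp(kappa * 1.0 - kappa * 0.5)
    print(kappa, ratio)  # ~1.65 for kappa=1, ~22026 for kappa=20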

Sampling from vMF

For the vMF distribution, the first major problem to solve is how to sample concrete instances from it. This step is crucial if we want to apply it to a VAE.

Uniform Distribution

The simplest case is when $\kappa=0$, which is the uniform distribution on a $(d-1)$-dimensional sphere. Since the standard normal distribution is isotropic and its probability density is proportional to $e^{-\Vert x\Vert^2/2}$ (depending only on the magnitude), we only need to sample a vector $z$ from a $d$-dimensional standard normal distribution and then let $x=z/\Vert z\Vert$ to obtain a uniform sampling result on the sphere.
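As a minimal NumPy sketch (the function name is ours, not part of any library):

import numpy as np

def sample_uniform_on_sphere(size, dims):
    # Sample from a d-dimensional standard normal, then normalize to unit length
    z = np.random.randn(size, dims)
    return z / np.linalg.norm(z, axis=1, keepdims=True)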

Special Direction

Next, for the case where $\kappa > 0$, we write $x=[x_1,x_2,\cdots,x_d]$. First, consider the special case where $\mu = [1, 0, \cdots, 0]$. Due to isotropy, it suffices to handle this special case; the result then carries over directly to the general case.

In this case, the probability density is proportional to $e^{\kappa x_1}$. We transform to spherical coordinates:

\begin{equation} \left\{\begin{aligned} x_1 &= \cos\varphi_1\\ x_2 &= \sin\varphi_1 \cos\varphi_2 \\ x_3 &= \sin\varphi_1 \sin\varphi_2 \cos\varphi_3 \\ &\,\,\vdots \\ x_{d-1} &= \sin\varphi_1 \cdots \sin\varphi_{d-2} \cos\varphi_{d-1}\\ x_d &= \sin\varphi_1 \cdots \sin\varphi_{d-2} \sin\varphi_{d-1} \end{aligned}\right. \end{equation}

Then (refer to Wikipedia for the integral transformation of hyperspherical coordinates):

\begin{equation}\begin{aligned} e^{\kappa x_1}dS^{d-1} =& e^{\kappa\cos\varphi_1}\sin^{d-2}\varphi_1 \sin^{d-3}\varphi_2 \cdots \sin\varphi_{d-2} d\varphi_1 d\varphi_2 \cdots d\varphi_{d-1} \\ =& \left(e^{\kappa\cos\varphi_1}\sin^{d-2}\varphi_1 d\varphi_1\right)\left(\sin^{d-3}\varphi_2 \cdots \sin\varphi_{d-2} d\varphi_2 \cdots d\varphi_{d-1}\right) \\ =& \left(e^{\kappa\cos\varphi_1}\sin^{d-2}\varphi_1 d\varphi_1\right)dS^{d-2} \\ \end{aligned}\end{equation}

This decomposition indicates that sampling from this vMF distribution is equivalent to first sampling $\varphi_1$ from a distribution whose probability density is proportional to $e^{\kappa\cos\varphi_1}\sin^{d-2}\varphi_1$, then sampling a $(d-1)$-dimensional vector $\varepsilon = [\varepsilon_2,\varepsilon_3,\cdots,\varepsilon_d]$ uniformly from a $(d-2)$-dimensional hypersphere, and finally combining the two as follows:

\begin{equation}x = [\cos\varphi_1, \varepsilon_2\sin\varphi_1, \varepsilon_3\sin\varphi_1, \cdots, \varepsilon_d\sin\varphi_1]\end{equation}

Let $w=\cos\varphi_1\in[-1,1]$, then:

\begin{equation}\left|e^{\kappa\cos\varphi_1}\sin^{d-2}\varphi_1 d\varphi_1\right| = \left|e^{\kappa w} (1-w^2)^{(d-3)/2}dw\right|\end{equation}

Thus, we primarily focus on sampling from a distribution with probability density proportional to $e^{\kappa w} (1-w^2)^{(d-3)/2}$.

However, what the author finds puzzling is that most papers involving the vMF distribution adopt the rejection sampling scheme based on the Beta distribution proposed in the 1994 paper "Simulation of the von Mises Fisher distribution". The entire sampling flow is quite complex. But it is now 2021; for a one-dimensional distribution, why do we still need an inefficient scheme like rejection sampling?

In fact, for any one-dimensional distribution $p(w)$, given its cumulative distribution function (CDF) $\Phi(w)$, the transformation $w=\Phi^{-1}(\varepsilon), \varepsilon\sim U[0,1]$ is the most convenient and universal sampling scheme. Some readers might protest that "the CDF is hard to compute" or "its inverse is even harder to compute." However, when implementing sampling in code, we don't need to know what $\Phi(w)$ looks like analytically; we can just compute it numerically. A reference implementation is as follows:

import numpy as np

def sample_from_pw(size, kappa, dims, epsilon=1e-7):
    # Discretize w over (-1, 1)
    x = np.arange(-1 + epsilon, 1, epsilon)
    # Unnormalized log-density: kappa*w + (d-3)/2 * log(1 - w^2)
    y = kappa * x + np.log(1 - x**2) * (dims - 3) / 2
    # Subtract the max before exponentiating (to avoid overflow), then
    # accumulate into an unnormalized CDF
    y = np.cumsum(np.exp(y - y.max()))
    # Normalize the CDF, then invert it by interpolation (inverse transform sampling)
    y = y / y[-1]
    return np.interp(np.random.random(size), y, x)

In this implementation, the most computationally intensive part is calculating the variable y. Once calculated, it can be cached, and subsequent sampling only requires the last step, which is very fast. This is arguably much simpler and more convenient than rejection sampling from a Beta distribution. Incidentally, a trick is used here: first calculate the log values, subtract the maximum, and then take the exponent to prevent overflow, allowing successful computation even when $\kappa$ is in the thousands.
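For example, a quick sanity check (a usage sketch only, not from the original post):

# Larger kappa concentrates w (the cosine between x and mu) closer to 1
for kappa in [1, 10, 100]:
    w = sample_from_pw(100000, kappa, 64)
    print(kappa, w.mean(), w.std())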

General Case

Now that we have implemented sampling from a vMF distribution with $\mu=[1,0,\cdots,0]$, we can decompose the sampling result as:

\begin{equation}x = w\times\underbrace{[1,0,\cdots,0]}_{\text{parameter vector } \mu} + \sqrt{1-w^2}\times\underbrace{[0,\varepsilon_2,\cdots,\varepsilon_d]}_{\begin{array}{c}\text{uniform sampling on a } d-2 \text{ dimensional}\\ \text{hypersphere orthogonal to } \mu\end{array}}\end{equation}

Similarly, due to isotropy, for a general $\mu$, the sampling result still has the same form:

\begin{equation}\begin{aligned} &x = w\mu + \sqrt{1-w^2}\nu\\ &w\sim e^{\kappa w} (1-w^2)^{(d-3)/2}\\ &\nu\sim \text{uniform distribution on a } d-2 \text{ dimensional hypersphere orthogonal to } \mu \end{aligned}\end{equation}

The key to sampling $\nu$ is ensuring it is orthogonal to $\mu$. This is not hard to achieve: first sample a $d$-dimensional vector $\varepsilon$ from a standard normal distribution, then keep only the component orthogonal to $\mu$ and normalize it:

\begin{equation}\nu = \frac{\varepsilon - \langle \varepsilon,\mu\rangle \mu}{\Vert \varepsilon - \langle \varepsilon,\mu\rangle \mu\Vert},\quad \varepsilon\sim\mathcal{N}(0,1_d)\end{equation}
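Putting the two steps together, a NumPy sketch of the general-case sampler might look as follows (it reuses sample_from_pw from earlier; the function name sample_from_vmf is ours):

import numpy as np

def sample_from_vmf(mu, size, kappa):
    # mu: unit vector in R^d giving the mean direction
    dims = len(mu)
    # Step 1: sample w with density proportional to exp(kappa*w) * (1 - w^2)^((d-3)/2)
    w = sample_from_pw(size, kappa, dims)[:, None]
    # Step 2: sample nu uniformly from the unit sphere orthogonal to mu
    eps = np.random.randn(size, dims)
    nu = eps - eps.dot(mu)[:, None] * mu
    nu = nu / np.linalg.norm(nu, axis=1, keepdims=True)
    # Combine: x = w * mu + sqrt(1 - w^2) * nu
    return w * mu + np.sqrt(1 - w**2) * nu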

vMF-VAE

At this point, we have completed the most difficult part of the article. Constructing the vMF-VAE follows naturally. vMF-VAE chooses a uniform distribution on the sphere ($\kappa=0$) as the prior distribution $q(z)$ and chooses the vMF distribution as the posterior distribution:

\begin{equation}p(z|x) = C_{d, \kappa} e^{\kappa\langle\mu(x),z\rangle}\end{equation}

For simplicity, we treat $\kappa$ as a hyperparameter (which can be understood as updating this parameter manually rather than via gradient descent). Thus, the only source of parameters for $p(z|x)$ is $\mu(x)$. Now we can calculate the KL divergence term:

\begin{equation}\begin{aligned} \int p(z|x) \log\frac{p(z|x)}{q(z)} dz =&\, \int C_{d,\kappa} e^{\kappa\langle\mu(x),z\rangle}\left(\kappa\langle\mu(x),z\rangle + \log C_{d,\kappa} - \log C_{d,0}\right)dz\\ =&\, \kappa\left\langle\mu(x),\mathbb{E}_{z\sim p(z|x)}[z]\right\rangle + \log C_{d,\kappa} - \log C_{d,0} \end{aligned}\end{equation}

As we discussed earlier, the mean of the vMF distribution points in the direction of $\mu(x)$, and its magnitude depends only on $d$ and $\kappa$. Substituting this into the expression above, we find that the KL divergence term depends only on $d$ and $\kappa$: once these two parameters are fixed, it is a constant (which, by the properties of KL divergence, must be greater than 0 whenever $\kappa \neq 0$), so the phenomenon of KL divergence vanishing simply cannot occur.
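In symbols, using the first-moment formula from earlier and writing $\bar{r}_{d,\kappa}=\frac{d\log Z_{d,\kappa}}{d\kappa}$ for the length of the mean (which is independent of direction), the KL divergence term is

\begin{equation}KL\big(p(z|x)\big\Vert q(z)\big) = \kappa\,\bar{r}_{d,\kappa} + \log C_{d,\kappa} - \log C_{d,0}\end{equation}

which is indeed a constant determined by $d$ and $\kappa$ alone.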

Now only the reconstruction term remains. We need "reparameterization" to perform sampling while preserving gradients. Since we have already worked out how to sample from the vMF distribution, the implementation is straightforward. The complete flow is:

\begin{equation}\begin{aligned} &\mathcal{L} = \Vert x - g(z)\Vert^2\\ &z = w\mu(x) + \sqrt{1-w^2}\nu\\ &w\sim e^{\kappa w} (1-w^2)^{(d-3)/2}\\ &\nu=\frac{\varepsilon - \langle \varepsilon,\mu\rangle \mu}{\Vert \varepsilon - \langle \varepsilon,\mu\rangle \mu\Vert}\\ &\varepsilon\sim\mathcal{N}(0,1_d) \end{aligned}\end{equation}

Here, the reconstruction loss uses MSE as an example; for sentence reconstruction, use cross-entropy. In this flow, $\mu(x)$ is the encoder and $g(z)$ is the decoder. Since the KL divergence is constant and does not affect optimization, vMF-VAE is simply an autoencoder with a slightly more complex reparameterization operation (and manual adjustment of $\kappa$), significantly simpler than a standard Gaussian-based VAE.

Furthermore, from this process, we can see that beyond "simplicity," another reason for not making $\kappa$ trainable is that $\kappa$ is involved in the sampling of $w$, and preserving the gradient of $\kappa$ during the sampling of $w$ is relatively difficult.

Reference Implementation

The difficulty of implementing vMF-VAE lies in the reparameterization part, which is the sampling from the vMF distribution, specifically the sampling of $w$. We previously provided a NumPy implementation for sampling $w$, but TensorFlow lacks a direct equivalent to np.interp, making it difficult to convert to a pure TF implementation. Of course, with dynamic graph frameworks like PyTorch or TF2, mixing NumPy code is fine, but we want a more universal solution here.

Actually, it's not hard. Since $w$ is just a one-dimensional variable and each training step only requires batch_size samples, we can pre-calculate and store a large number (say, hundreds of thousands) of $w$ values using NumPy. During training, we can simply sample randomly from these values. A reference implementation is as follows:

def sampling(mu):
    """vMF distribution reparameterization
    (K is the Keras backend, e.g. from keras import backend as K; kappa is a
    global hyperparameter; mu is assumed to be L2-normalized by the encoder)
    """
    dims = K.int_shape(mu)[-1]
    # Pre-calculate a batch of w
    epsilon = 1e-7
    x = np.arange(-1 + epsilon, 1, epsilon)
    y = kappa * x + np.log(1 - x**2) * (dims - 3) / 2
    y = np.cumsum(np.exp(y - y.max()))
    y = y / y[-1]
    W = K.constant(np.interp(np.random.random(10**6), y, x))
    # Sample w in real-time
    idxs = K.random_uniform(K.shape(mu[:, :1]), 0, 10**6, dtype='int32')
    w = K.gather(W, idxs)
    # Sample z in real-time
    eps = K.random_normal(K.shape(mu))
    nu = eps - K.sum(eps * mu, axis=1, keepdims=True) * mu
    nu = K.l2_normalize(nu, axis=-1)
    return w * mu + (1 - w**2)**0.5 * nu
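In a Keras model, this function would typically be wrapped in a Lambda layer; a usage sketch (assuming the encoder output z_mean has already been L2-normalized; see the linked example below for the full model):

from keras.layers import Lambda

# z_mean: encoder output, L2-normalized so that it lies on the unit hypersphere
z = Lambda(sampling)(z_mean)
# z is fed to the decoder; since the KL term is a constant, no extra loss term is needed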

A complete example based on MNIST can be found at:

https://github.com/bojone/vae/blob/master/vae_vmf_keras.py

As for examples of using vMF-VAE in NLP, we will share those when we have the opportunity. This article mainly serves as a theoretical introduction and simple demonstration.

Summary

This article introduced a VAE implementation based on the vMF distribution, with the primary difficulty being sampling from the vMF distribution. Overall, the vMF distribution is built on the metric of cosine similarity, and its properties in certain respects align more closely with our intuitive understanding. Using it in a VAE allows the KL divergence term to be a constant, thereby preventing KL divergence vanishing and simplifying the VAE structure.