Variational Autoencoders (Part 5): VAE + BN = Better VAE

By 苏剑林 | May 06, 2020

In this article, we continue our previous Variational Autoencoder series and analyze how to prevent the "KL Vanishing" phenomenon in NLP VAE models. This article is inspired by the ACL 2020 paper "A Batch Normalized Inference Network Keeps the KL Vanishing Away", with further refinements added by the author.

It is worth mentioning that the final solution derived in this article is quite concise—simply adding BN (Batch Normalization) to the encoder output followed by a simple scale—but it is indeed very effective and worth a try for readers researching related issues. At the same time, the conclusions are also applicable to general VAE models (including those for CV); in my view, it could even be considered a "standard configuration" for VAE models.

Finally, a reminder to readers that this is an advanced VAE paper, so please ensure you have a certain understanding of VAEs before reading further.

A Simple Review of VAE

Here we briefly review the VAE model and discuss the difficulties VAE faces in NLP. For a more detailed introduction to VAE, please refer to the author's previous works: "Variational Autoencoders (1): So That's What It Is" and "Variational Autoencoders (2): Starting from a Bayesian Perspective".

VAE Training Process

The training process of a VAE can be roughly illustrated as:

(VAE Training Flow Diagram)

Written as a formula, it is:

\begin{equation}\mathcal{L} = \mathbb{E}_{x\sim \tilde{p}(x)} \Big[\mathbb{E}_{z\sim p(z|x)}\big[-\log q(x|z)\big]+KL\big(p(z|x)\big\Vert q(z)\big)\Big]\end{equation}

where the first term is the reconstruction term, whose inner expectation $\mathbb{E}_{z\sim p(z|x)}$ is estimated via the reparameterization trick; the second term is known as the KL divergence term. This term is the explicit difference between a VAE and a standard autoencoder; without it, the model essentially degenerates into a conventional AE. For more detailed symbol definitions, please refer to "Variational Autoencoders (2): Starting from a Bayesian Perspective".
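As a quick reminder (this is also exactly what the `sampling` function in the reference implementation below does), the reparameterization trick rewrites the sampling step as

\begin{equation}z = \mu(x) + \sigma(x)\cdot\varepsilon,\qquad \varepsilon\sim\mathcal{N}(0,1)\end{equation}

so that gradients can flow back through $\mu(x)$ and $\sigma(x)$.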

VAE in NLP

In NLP, sentences are encoded as discrete integer IDs, so $q(x|z)$ is a discrete distribution that can be implemented using the versatile "Conditional Language Model." Therefore, theoretically, $q(x|z)$ can precisely fit the generative distribution. The problem arises because $q(x|z)$ is too strong. During training, the reparameterization operation introduces noise. If the noise is significant, utilizing $z$ becomes difficult. Consequently, the model simply ignores $z$, degenerating into an unconditional language model (which is still very strong), and $KL(p(z|x)\Vert q(z))$ subsequently drops to 0. This is the KL vanishing phenomenon.

A VAE model in this state has little value: a KL divergence of 0 means the posterior $p(z|x)$ has collapsed to the prior, so the encoder output carries no information about the input, and the decoder is merely a standard language model. We typically use VAEs precisely for their ability to construct encoding vectors without supervision, so to apply VAEs effectively we must solve the KL vanishing problem. In fact, since 2016 a considerable amount of work has addressed this issue, proposing various solutions such as annealing strategies and changing the prior distribution. Readers can find many publications by Googling "KL Vanishing."

The Ingenuity and Subtlety of BN

The solution in this article directly addresses the KL divergence term, being simple, effective, and free of hyper-parameters. The core idea is simple:

Isn't KL vanishing just the KL divergence term becoming 0? If I adjust the encoder output such that the KL divergence has a lower bound greater than zero, won't it definitely not vanish?

The direct result of this simple idea is: add a BN layer after $\mu$, as shown in the diagram:

(Adding BN to the VAE)

Brief Derivation

Why is there a connection with BN? Let's look at the form of the KL divergence term:

\begin{equation}\mathbb{E}_{x\sim\tilde{p}(x)}\left[KL\big(p(z|x)\big\Vert q(z)\big)\right] = \frac{1}{b} \sum_{i=1}^b \sum_{j=1}^d \frac{1}{2}\Big(\mu_{i,j}^2 + \sigma_{i,j}^2 - \log \sigma_{i,j}^2 - 1\Big)\end{equation}
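Each summand over $j$ is the standard closed-form KL divergence between the Gaussian posterior component $\mathcal{N}(\mu_{i,j},\sigma_{i,j}^2)$ and the standard normal prior. Dropping the indices, it can be derived in one line:

\begin{equation}KL\big(\mathcal{N}(\mu,\sigma^2)\,\big\Vert\,\mathcal{N}(0,1)\big) = \mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}\left[\frac{z^2}{2} - \frac{(z-\mu)^2}{2\sigma^2} - \frac{1}{2}\log\sigma^2\right] = \frac{1}{2}\Big(\mu^2 + \sigma^2 - \log\sigma^2 - 1\Big)\end{equation}

using $\mathbb{E}[z^2]=\mu^2+\sigma^2$ and $\mathbb{E}[(z-\mu)^2]=\sigma^2$.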

The outer average over $i$ estimates the expectation over $x\sim\tilde{p}(x)$ from a batch of $b$ samples, and $d$ is the dimension of the encoding vector. Since we always have $e^x \geq x + 1$ (take $x = \log \sigma_{i,j}^2$), it follows that $\sigma_{i,j}^2 - \log \sigma_{i,j}^2 - 1 \geq 0$. Therefore:

\begin{equation}\mathbb{E}_{x\sim\tilde{p}(x)}\left[KL\big(p(z|x)\big\Vert q(z)\big)\right] \geq \frac{1}{b} \sum_{i=1}^b \sum_{j=1}^d \frac{1}{2}\mu_{i,j}^2 = \frac{1}{2}\sum_{j=1}^d \left(\frac{1}{b} \sum_{i=1}^b \mu_{i,j}^2\right)\label{eq:kl}\end{equation}

Notice the term in the parentheses in Eq. $\eqref{eq:kl}$: it is exactly the second moment of $\mu$ within the batch. If we add a BN layer to $\mu$, we can generally ensure that, across the batch, $\mu$ has mean $\beta$ and variance $\gamma^2$ (where $\beta, \gamma$ are the trainable shift and scale parameters of BN).
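Making the resulting step explicit: for each component $j$, the batch second moment is then determined by these two parameters alone,

\begin{equation}\frac{1}{b}\sum_{i=1}^b \mu_{i,j}^2 \approx \mathbb{E}_{x\sim\tilde{p}(x)}\big[\mu_j^2(x)\big] = \text{Var}[\mu_j] + \big(\mathbb{E}[\mu_j]\big)^2 = \gamma^2 + \beta^2\end{equation}

Substituting this into Eq. $\eqref{eq:kl}$ gives: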

\begin{equation}\mathbb{E}_{x\sim\tilde{p}(x)}\left[KL\big(p(z|x)\big\Vert q(z)\big)\right] \geq \frac{d}{2}\left(\beta^2 + \gamma^2\right)\label{eq:kl-lb}\end{equation}

Thus, as long as we control $\beta, \gamma$ (primarily by fixing $\gamma$ as a certain constant), we can ensure the KL divergence term has a positive lower bound, thus preventing the KL vanishing phenomenon. In this way, KL vanishing and BN are ingeniously linked, with BN "eliminating" the possibility of KL vanishing.
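As a concrete illustration of this simpler recipe, here is a minimal Keras sketch under my reading of it (not code from the paper): BN is applied to $\mu$ only, $\beta$ is left trainable, and $\gamma$ is replaced by a fixed constant. The names `e_outputs`, `hidden_dims`, and `fixed_gamma` are assumed placeholders (the first two reappear in the reference implementation below).

# Minimal sketch: BN on z_mean with gamma fixed to a constant.
from keras.layers import Dense, BatchNormalization, Lambda

fixed_gamma = 0.5  # assumed constant; the paper determines this value experimentally
z_mean = Dense(hidden_dims)(e_outputs)
z_mean = BatchNormalization(scale=False)(z_mean)    # beta trainable, gamma removed
z_mean = Lambda(lambda x: fixed_gamma * x)(z_mean)  # fixed gamma guarantees the lower bound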

Why not LN?

Readers skilled in derivation might think that, according to the logic above, Layer Normalization (LN) could also provide a positive lower bound for the KL divergence term—specifically by normalizing along the $j$ dimension in Equation $\eqref{eq:kl}$.

So why use BN instead of LN?

The answer lies in the subtlety of BN. Intuitively, KL vanishing occurs because the noise in $z \sim p(z|x)$ is so large that the decoder cannot separate the signal in $z$ from the noise, and therefore ignores $z$ entirely. Adding BN to $\mu(x)$ effectively stretches the distances between the $z$ vectors of different samples, so that even after the noise is added they remain distinguishable. Consequently, the decoder is more willing to use the information in $z$, which alleviates the problem. In contrast, LN normalizes within a single sample and does not widen the gap between different samples, which is why it is not as effective as BN.

Further Results

The original paper's derivation essentially ends here; what follows are the experiments, including how to choose the value of $\gamma$. However, I believe the current conclusions still leave room for improvement. For one, they do not offer a deeper understanding of why adding BN works; it reads more like an engineering trick. Furthermore, BN is applied only to $\mu(x)$ and not to $\sigma(x)$, which feels somewhat asymmetrical.

Through my own derivation, I found that the above conclusions can be further refined.

Relating to the Prior Distribution

For a VAE, it is desired that the distribution of latent variables in the trained model matches the prior $q(z)=\mathcal{N}(z;0,1)$, while the posterior is $p(z|x)=\mathcal{N}(z; \mu(x),\sigma^2(x))$. Therefore, a VAE aims for the following to hold:

\begin{equation}q(z) = \int \tilde{p}(x)p(z|x)dx = \int \tilde{p}(x)\mathcal{N}(z; \mu(x),\sigma^2(x))dx\end{equation}

Multiplying both sides by $z$ and integrating with respect to $z$, we get:

\begin{equation}0 = \int \tilde{p}(x)\mu(x)dx = \mathbb{E}_{x\sim \tilde{p}(x)}[\mu(x)]\end{equation}

Multiplying both sides by $z^2$ and integrating with respect to $z$, we get:

\begin{equation}1 = \int \tilde{p}(x)\left[\mu^2(x) + \sigma^2(x)\right]dx = \mathbb{E}_{x\sim \tilde{p}(x)}\left[\mu^2(x)\right] + \mathbb{E}_{x\sim \tilde{p}(x)}\left[\sigma^2(x)\right]\end{equation}
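Both steps use the standard Gaussian moment identities (applied inside the outer integral over $x$):

\begin{equation}\int z\, \mathcal{N}(z;\mu,\sigma^2)\,dz = \mu,\qquad \int z^2\, \mathcal{N}(z;\mu,\sigma^2)\,dz = \mu^2 + \sigma^2\end{equation}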

If we add BN to both $\mu(x)$ and $\sigma(x)$, we have:

\begin{equation}\begin{aligned} &0 = \mathbb{E}_{x\sim \tilde{p}(x)}[\mu(x)] = \beta_{\mu}\\ &1 = \mathbb{E}_{x\sim \tilde{p}(x)}\left[\mu^2(x)\right] + \mathbb{E}_{x\sim \tilde{p}(x)}\left[\sigma^2(x)\right] = \beta_{\mu}^2 + \gamma_{\mu}^2 + \beta_{\sigma}^2 + \gamma_{\sigma}^2 \end{aligned}\end{equation}

Thus, we know that $\beta_{\mu}$ must be 0. If we also fix $\beta_{\sigma}=0$, we have the constraint:

\begin{equation}1 = \gamma_{\mu}^2 + \gamma_{\sigma}^2\label{eq:gamma2}\end{equation}

Reference Implementation

Through this derivation, we find that BN can be added to both $\mu(x)$ and $\sigma(x)$, and we can fix $\beta_{\mu}=\beta_{\sigma}=0$, provided that the constraint in $\eqref{eq:gamma2}$ is satisfied. It should be noted that this discussion is a general analysis of VAEs and does not yet address the KL vanishing problem. Even if these conditions are met, there is no guarantee that the KL term won't tend toward 0. Combined with Eq. $\eqref{eq:kl-lb}$, we know the key to preventing KL vanishing is ensuring $\gamma_{\mu} > 0$. Therefore, the final strategy I propose is:

\begin{equation}\begin{aligned} &\beta_{\mu}=\beta_{\sigma}=0\\ &\gamma_{\mu} = \sqrt{\tau + (1-\tau)\cdot\text{sigmoid}(\theta)}\\ &\gamma_{\sigma} = \sqrt{(1-\tau)\cdot\text{sigmoid}(-\theta)} \end{aligned}\end{equation}

where $\tau\in(0,1)$ is a constant (I took $\tau=0.5$ in my experiments), and $\theta$ is a trainable parameter. The formula uses the identity $\text{sigmoid}(-\theta) = 1-\text{sigmoid}(\theta)$.
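It is easy to check that this parameterization satisfies the constraint $\eqref{eq:gamma2}$ while keeping the lower bound $\eqref{eq:kl-lb}$ strictly positive:

\begin{equation}\gamma_{\mu}^2 + \gamma_{\sigma}^2 = \tau + (1-\tau)\,\text{sigmoid}(\theta) + (1-\tau)\big(1-\text{sigmoid}(\theta)\big) = 1,\qquad \gamma_{\mu}^2 \geq \tau > 0\end{equation}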

Key implementation code (Keras):

# Standard Keras imports assumed by the snippet below
from keras.layers import Layer, Dense, Lambda, BatchNormalization
from keras import backend as K


class Scaler(Layer):
    """Special scale layer: multiplies its input by sqrt(tau + (1-tau)*sigmoid(theta))
    in 'positive' mode (used for z_mean) or by sqrt((1-tau)*sigmoid(-theta)) in
    'negative' mode (used for z_std), where theta is the trainable weight `scale`.
    """
    def __init__(self, tau=0.5, **kwargs):
        super(Scaler, self).__init__(**kwargs)
        self.tau = tau

    def build(self, input_shape):
        super(Scaler, self).build(input_shape)
        self.scale = self.add_weight(
            name='scale', shape=(input_shape[-1],), initializer='zeros'
        )

    def call(self, inputs, mode='positive'):
        if mode == 'positive':
            scale = self.tau + (1 - self.tau) * K.sigmoid(self.scale)
        else:
            scale = (1 - self.tau) * K.sigmoid(-self.scale)
        return inputs * K.sqrt(scale)

    def get_config(self):
        config = {'tau': self.tau}
        base_config = super(Scaler, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

def sampling(inputs):
    """Reparameterization trick: z = z_mean + z_std * noise.
    Note that z_std may be negative after BN; since the Gaussian noise is symmetric,
    this does not change the distribution of the samples.
    """
    z_mean, z_std = inputs
    noise = K.random_normal(shape=K.shape(z_mean))
    return z_mean + z_std * noise

# Assume e_outputs is the encoder's output vector and hidden_dims is the latent dimension d
scaler = Scaler()  # a single shared Scaler, so gamma_mu and gamma_sigma use the same theta
z_mean = Dense(hidden_dims)(e_outputs)
z_mean = BatchNormalization(scale=False, center=False, epsilon=1e-8)(z_mean)
z_mean = scaler(z_mean, mode='positive')
z_std = Dense(hidden_dims)(e_outputs)
z_std = BatchNormalization(scale=False, center=False, epsilon=1e-8)(z_std)
z_std = scaler(z_std, mode='negative')
z = Lambda(sampling, name='Sampling')([z_mean, z_std])
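For completeness, here is a minimal sketch (my own addition, not part of the original reference code) of how the corresponding KL term could be computed from these outputs, matching the batched expression derived earlier; the `1e-8` guards the logarithm, and wiring this loss into a full model together with a decoder and reconstruction loss is left to the reader's own setup.

# KL term matching the batch estimate above, with sigma^2 = z_std^2
# (z_std may be negative after BN, so it is squared); a sketch, not the original code.
kl_loss = K.mean(
    0.5 * K.sum(
        K.square(z_mean) + K.square(z_std) - K.log(K.square(z_std) + 1e-8) - 1,
        axis=-1
    )
)
# e.g. attach it via model.add_loss(kl_loss) on a Model that also carries the reconstruction loss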

Summary

In this article, we briefly analyzed the KL vanishing phenomenon of VAEs in NLP and introduced a method that uses BN layers to prevent it and stabilize training. It is a concise and effective solution: beyond the original paper, my own experiments have also confirmed its effectiveness, so it is well worth trying. Because the derivation is general, it can even be applied to VAE models in any scenario (such as CV).