By 苏剑林 | September 17, 2018
Since VAEs contain both an encoder and a decoder (generator), and the distribution of latent variables is approximated as a standard normal distribution, a VAE is both a generative model and a feature extractor. In the field of images, because VAE-generated images tend to be blurry, people are usually more interested in the role of the VAE as an image feature extractor. Feature extraction is done to prepare for subsequent tasks, which can include many types, such as classification or clustering. This article focuses on the "clustering" task.
Generally speaking, using an AE or VAE for clustering is done in separate steps: first train an ordinary VAE, then obtain the latent variables of the original data, and finally apply K-Means or a GMM to those latent variables. However, this approach clearly lacks end-to-end cohesion, and the choice of clustering method can be perplexing. This article introduces a "one-stop" clustering idea based on the VAE, which lets us complete clustering and conditional generation simultaneously, in an unsupervised manner.
Theory
General Framework
Recalling the loss of VAE (if you don't remember, please refer to "Variational Autoencoders (II): Starting from the Bayesian Perspective"):
$$KL\Big(p(x,z)\Big\Vert q(x,z)\Big) = \iint p(z|x)\tilde{p}(x)\ln \frac{p(z|x)\tilde{p}(x)}{q(x|z)q(z)} dzdx\tag{1}$$
Usually, we assume that $q(z)$ is the standard normal distribution, and $p(z|x), q(x|z)$ are conditional normal distributions. By substituting these and calculating, we obtain the loss for an ordinary VAE.
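For reference, carrying out that substitution (and dropping the constant entropy of $\tilde{p}(x)$) yields the familiar per-sample VAE objective
$$\mathbb{E}_{x\sim\tilde{p}(x)}\Big[\mathbb{E}_{z\sim p(z|x)}\big[-\ln q(x|z)\big] + KL\big(p(z|x)\big\Vert q(z)\big)\Big].$$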
However, no one stipulates that latent variables must be continuous, right? Here, we define the latent variables as $(z, y)$, where $z$ is a continuous variable representing the encoding vector, and $y$ is a discrete variable representing the category. Directly replacing $z$ in $(1)$ with $(z, y)$, we get:
$$KL\Big(p(x,z,y)\Big\Vert q(x,z,y)\Big) = \sum_y \iint p(z,y|x)\tilde{p}(x)\ln \frac{p(z,y|x)\tilde{p}(x)}{q(x|z,y)q(z,y)} dzdx\tag{2}$$
This is the loss for a VAE used for clustering.
Step-by-step Hypotheses
Wait, is that it? Well, yes; if we only care about the general framework, $(2)$ indeed covers it.
However, when it comes to practice, $(2)$ can have many different implementation schemes. Here, we introduce a simpler one. First, we must clarify that in $(2)$, we only know $\tilde{p}(x)$ (the empirical distribution given by a batch of data); everything else is not explicitly defined. To solve $(2)$, we need to set some forms. One selection scheme is:
$$p(z,y|x)=p(y|z)p(z|x),\quad q(x|z,y)=q(x|z),\quad q(z,y)=q(z|y)q(y)\tag{3}$$
Substituting into $(2)$, we get:
$$KL\Big(p(x,z,y)\Big\Vert q(x,z,y)\Big) = \sum_y \iint p(y|z)p(z|x)\tilde{p}(x)\ln \frac{p(y|z)p(z|x)\tilde{p}(x)}{q(x|z)q(z|y)q(y)} dzdx\tag{4}$$
Equation $(4)$ is actually quite intuitive; it describes the encoding and generation processes separately:
1. Sample $x$ from the original data, then obtain the encoding feature $z$ through $p(z|x)$, and then classify the encoding feature using the classifier $p(y|z)$ to obtain the category;
2. Select a category $y$ from the distribution $q(y)$, then select a random latent variable $z$ from the distribution $q(z|y)$, and finally decode it into an original sample through the generator $q(x|z)$.
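The second (generation) process is also exactly how conditional sampling is done after training. As a minimal sketch, assuming a trained Keras decoder `decoder` (mapping $z$ to $x$) and an array `mu_y` of learned class means — both names are placeholders, not the exact objects in the repository:

```python
import numpy as np

def sample_from_class(decoder, mu_y, y, n_samples=10):
    """Sketch of process 2: pick class y, draw z ~ q(z|y) = N(mu_y, I), then decode."""
    latent_dim = mu_y.shape[1]
    # shift standard normal noise by the chosen class mean
    z = np.random.randn(n_samples, latent_dim) + mu_y[y]
    # q(x|z) has mean G(z) and constant variance, so we simply take the decoder output
    return decoder.predict(z)
```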
Concrete Model
Equation $(4)$ is already quite specific. We only need to follow the usual VAE approach: $p(z|x)$ is assumed to be a normal distribution with mean $\mu(x)$ and variance $\sigma^2(x)$; $q(x|z)$ is assumed to be a normal distribution with mean $G(z)$ and constant variance (equivalent to using MSE as the reconstruction loss); $q(z|y)$ can be assumed to be a normal distribution with mean $\mu_y$ and variance 1. As for the remaining $q(y)$ and $p(y|z)$: $q(y)$ can be taken as a uniform distribution (it is just a constant), which expresses the hope that the classes are roughly balanced, and $p(y|z)$ is a classifier over the latent variables, which can be fitted with any softmax network.
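To make these assumptions concrete, here is a minimal Keras sketch of the components. The layer sizes and the names `latent_dim`, `num_classes`, `img_dim` are illustrative choices, not the exact architecture in the repository's code:

```python
from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K

img_dim, latent_dim, num_classes = 784, 20, 10  # illustrative sizes

# p(z|x): an encoder producing the mean and log-variance of a diagonal Gaussian
x_in = Input(shape=(img_dim,))
h = Dense(256, activation='relu')(x_in)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

# reparameterization: z = mu(x) + sigma(x) * eps, eps ~ N(0, I)
def sampling(args):
    mean, log_var = args
    eps = K.random_normal(shape=K.shape(mean))
    return mean + K.exp(0.5 * log_var) * eps

z = Lambda(sampling)([z_mean, z_log_var])

# p(y|z): a softmax classifier on the latent variable
y = Dense(num_classes, activation='softmax')(z)

# q(x|z): a decoder whose output is the Gaussian mean G(z); constant variance = MSE loss
x_recon = Dense(img_dim, activation='sigmoid')(Dense(256, activation='relu')(z))

# q(z|y) = N(mu_y, I): the class means mu_y form a trainable
# (num_classes, latent_dim) weight matrix, introduced when building the loss
model = Model(x_in, [x_recon, y])
```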
Finally, $(4)$ can be vividly rewritten as:
$$\mathbb{E}_{x\sim\tilde{p}(x)}\Big[-\log q(x|z) + \sum_y p(y|z) \log \frac{p(z|x)}{q(z|y)} + KL\big(p(y|z)\big\Vert q(y)\big)\Big],\quad z\sim p(z|x) \tag{5}$$
where $z\sim p(z|x)$ is the reparameterization operation, and the three loss terms in the brackets each have their own meaning:
1. $-\log q(x|z)$ hopes that the reconstruction error is as small as possible, meaning $z$ preserves complete information as much as possible;
2. $\sum_y p(y|z) \log \frac{p(z|x)}{q(z|y)}$ hopes that $z$ can align as much as possible with the "exclusive" normal distribution of a certain category; this is the step that performs clustering;
3. $KL\big(p(y|z)\big\Vert q(y)\big)$ hopes that the distribution of each class is as balanced as possible, avoiding cases where two classes almost overlap (collapsing into one class). Of course, sometimes this prior requirement may not be necessary, in which case this term can be removed.
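As a concrete note on the third term: with a uniform $q(y)=1/K$ over $K$ classes, $KL\big(p(y|z)\big\Vert q(y)\big)=\sum_y p(y|z)\log p(y|z) + \log K$, i.e. the negative entropy of the classifier output plus a constant. A one-line sketch in Keras, reusing the tensor `y` and the size `num_classes` from the sketch above:

```python
import numpy as np
from keras import backend as K

# KL(p(y|z) || uniform q(y)): negative entropy of the classifier output plus log K
# (the constant log K can be dropped during optimization)
kl_y = K.sum(y * K.log(y + K.epsilon()), axis=-1) + np.log(num_classes)
```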
Experiments
The experimental code was naturally completed using Keras (^_^). Experiments were conducted on MNIST and Fashion-MNIST, and the performance was quite good. Experimental environment: Keras 2.2 + TensorFlow 1.8 + Python 2.7.
Code Implementation
Code is located at: https://github.com/bojone/vae/blob/master/vae_keras_cluster.py
Actually, the comments should be quite clear, and it doesn't deviate much from a standard VAE. What might be slightly difficult is how to implement $\sum_y p(y|z) \log \frac{p(z|x)}{q(z|y)}$. First, we substitute:
$$\begin{aligned}
p(z|x)&=\frac{1}{\prod\limits_{i=1}^d\sqrt{2\pi\sigma_i^2(x)}}\exp\left\{-\frac{1}{2}\left\Vert\frac{z - \mu(x)}{\sigma(x)}\right\Vert^2\right\}\\
q(z|y)&=\frac{1}{(2\pi)^{d/2}}\exp\left\{-\frac{1}{2}\left\Vert z - \mu_y\right\Vert^2\right\}
\end{aligned}\tag{6}$$
to get
$$\log \frac{p(z|x)}{q(z|y)}=-\frac{1}{2}\sum_{i=1}^d \log \sigma_i^2(x)-\frac{1}{2}\left\Vert\frac{z - \mu(x)}{\sigma(x)}\right\Vert^2 + \frac{1}{2}\left\Vert z - \mu_y\right\Vert^2 \tag{7}$$
Note that the second term is actually redundant, because the reparameterization operation tells us that $z = \varepsilon\otimes \sigma(x) + \mu(x)$ with $\varepsilon\sim \mathcal{N}(0,I)$, so the second term is simply $-\Vert \varepsilon\Vert^2/2$, which is independent of the parameters. Therefore:
$$\log \frac{p(z|x)}{q(z|y)}\sim -\frac{1}{2}\sum_{i=1}^d \log \sigma_i^2(x) + \frac{1}{2}\left\Vert z - \mu_y\right\Vert^2 \tag{8}$$
Since $y$ is discrete, $\sum_y p(y|z) \log \frac{p(z|x)}{q(z|y)}$ is essentially a matrix multiplication (multiplying and then summing over a common variable is the general form of matrix multiplication), which is implemented using K.batch_dot.
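As a rough sketch of that step (not the exact code from the repository; it reuses the tensors `z`, `z_log_var`, `y` and the sizes from the earlier sketch, and treats the class means `mu_y` as a stand-in variable that in practice would be a trainable layer weight):

```python
import numpy as np
from keras import backend as K

# stand-in for the trainable class means mu_y, shape (num_classes, latent_dim)
mu_y = K.variable(np.zeros((num_classes, latent_dim)))

# broadcast z against every class mean: (batch, 1, d) - (1, num_classes, d)
z_expand = K.expand_dims(z, 1)
mu_expand = K.expand_dims(mu_y, 0)

# equation (8): -1/2 * sum_i log sigma_i^2(x) + 1/2 * ||z - mu_y||^2, one value per class
log_ratio = (-0.5 * K.sum(z_log_var, axis=-1, keepdims=True)           # (batch, 1)
             + 0.5 * K.sum(K.square(z_expand - mu_expand), axis=-1))   # (batch, num_classes)

# sum_y p(y|z) * log_ratio: a batch-wise dot product over the class axis
cluster_loss = K.batch_dot(y, log_ratio, axes=1)  # equivalent to K.sum(y * log_ratio, axis=-1)
```

The batch mean of `cluster_loss` then joins the reconstruction term and the $KL\big(p(y|z)\big\Vert q(y)\big)$ term to form the full loss $(5)$.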
As for the rest, readers should already be familiar with how a standard VAE is implemented before reading this article and its code; otherwise, it may be quite confusing.
MNIST
Here are the experimental results for MNIST, including samples within each cluster and per-class (conditional) sampling. Finally, I made a rough estimate: if each cluster is assigned the most frequent ground-truth label among its samples, the final test accuracy is about 83%, which is comparable to the best result of roughly 84% reported in "Unsupervised Deep Embedding for Clustering Analysis" and feels quite good.
Clustering Visualization
[Figures: samples from clustering categories 0–9]
Per-Class Sampling
[Figures: conditional samples for categories 0–9]
Fashion-MNIST
Here are the experimental results for Fashion-MNIST, including intra-class samples and per-class sampling. The final test accuracy is approximately 58.5%.
Clustering Visualization
[Figures: samples from clustering categories 0–9]
Per-Class Sampling
[Figures: conditional samples for categories 0–9]
Conclusion
This article briefly implemented a clustering algorithm based on VAE. The characteristic of the algorithm is its "one-stop" nature, completing the three tasks of "encoding," "clustering," and "generation" simultaneously. The core idea is the generalization of the VAE loss.
There is still room for improvement. For example, Equation $(4)$ is just one example of Equation $(2)$, and more general cases could be considered. The encoder and decoder in the code have not been carefully tuned; they were merely used to verify the idea.