By 苏剑林 | March 23, 2017
PS: This article summarizes the relationship between gradient descent and the EM algorithm. Using a unified perspective, it derives ordinary gradient descent, the EM algorithm in pLSA, and the EM algorithm in K-Means. The goal is to demonstrate that they are essentially different aspects of the same concept—much like "viewing a mountain range from different angles, where it looks different from near, far, high, or low."
In machine learning, we usually represent the problem we want to solve as a loss function with unknown parameters, such as Mean Squared Error (MSE). We then find ways to minimize this function to obtain the optimal parameter values and complete the model. Since multiplying a function by -1 turns a maximum into a minimum, we generally speak in terms of minimization. In the field of machine learning, two major directions are commonly taught for finding the minimum of a function: 1. Gradient Descent; 2. The EM algorithm (Expectation-Maximization), which is typically used for solving complex maximum likelihood problems.
Standard tutorials often describe these two methods as vastly different, as if they were two rival systems, and the EM algorithm is often portrayed as being particularly mysterious. In reality, however, both methods are simply different examples of the same underlying concept. As the saying goes, "they are born from the same root"—they share a common lineage.
Let us begin with the ancient Newton's method.
Newton's Iterative Method
Given a complex non-linear function $f(x)$ for which we want to find the minimum, we can generally proceed as follows. Assuming it is sufficiently smooth, its minimum point is also a local minimum point, which satisfies $f'(x_0)=0$. Thus, the problem is transformed into finding the root of the equation $f'(x)=0$. For the roots of non-linear equations, we have Newton's method:
\begin{equation}x_{n+1} = x_{n} - \frac{f'(x_n)}{f''(x_n)}\end{equation}
However, this approach lacks geometric intuition and doesn't reveal deeper secrets. We prefer the following thought process: at the point $x=x_n$ on the curve $y=f(x)$, we can approximate the original function with a simpler curve. If the approximate curve is easy to minimize, we can use the minimum of that approximate curve to substitute for the minimum of the original curve:
Figure: Approximation-Iteration
Obviously, the requirements for the approximate curve are:
1. It must approximate the real curve to a certain extent; generally, at least first-order accuracy is required.
2. It must have a minimum point, and that minimum point must be easy to solve.
Naturally, we can choose a "tangent parabola" for the approximation:
\begin{equation}f(x)\approx g(x) = f(x_n)+f'(x_n)(x-x_n)+\frac{1}{2}f''(x_n)(x-x_n)^2\end{equation}
This parabola has second-order accuracy. For this parabola, the extremum point is:
\begin{equation}x_n - \frac{f'(x_n)}{f''(x_n)}\end{equation}
Thus, we arrive again at the iterative formula for Newton's method:
\begin{equation}x_{n+1} = x_n - \frac{f'(x_n)}{f''(x_n)}\end{equation}
If $f(x)$ is sufficiently smooth and has only one global extremum, Newton's method converges rapidly (quadratically near the solution, so the number of correct digits roughly doubles with each step). However, real-world functions are rarely so ideal, so its weaknesses become apparent:
1. It requires calculating the second derivative, which can be extremely complex for some functions.
2. Since the sign of $f''(x_n)$ is uncertain, the direction in which $g(x)$ opens is uncertain, so there is no guarantee whether the result is a maximum or a minimum.
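The second weakness is easy to see numerically. The following is a minimal sketch, assuming the illustrative function $f(x)=x^4-2x^2$ (not from the article; it has minima at $x=\pm 1$ and a local maximum at $x=0$); the helper name `newton_minimize` is hypothetical.

```python
def newton_minimize(df, d2f, x0, steps=20):
    """Iterate x_{n+1} = x_n - f'(x_n) / f''(x_n) starting from x0."""
    x = x0
    for _ in range(steps):
        x = x - df(x) / d2f(x)
    return x

# Illustrative function (an assumption, not from the article): f(x) = x^4 - 2x^2.
df = lambda x: 4 * x**3 - 4 * x    # f'(x)
d2f = lambda x: 12 * x**2 - 4      # f''(x)

# Starting where f'' > 0, the tangent parabola opens upwards: we reach the minimum at 1.
x_from_right = newton_minimize(df, d2f, x0=1.5)
# Starting where f'' < 0, the parabola opens downwards: we land on the maximum at 0.
x_from_middle = newton_minimize(df, d2f, x0=0.2)

print(x_from_right, x_from_middle)
```

The same iteration formula is used in both runs; only the sign of $f''$ near the starting point decides which kind of extremum it finds.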
Gradient Descent
In many problems, these two weaknesses are fatal. Therefore, to solve these issues, we abandon second-order accuracy. That is, we remove $f''(x_n)$ and replace it with a fixed positive constant $1/h$:
\begin{equation}g(x) = f(x_n)+f'(x_n)(x-x_n)+\frac{1}{2h}(x-x_n)^2\end{equation}
This approximate curve has only first-order accuracy, but it eliminates the need for second derivatives and ensures that the parabola always opens upwards. Iterating with it, provided $h$ is sufficiently small, leads to a minimum (at least a local one). The minimum point of $g(x)$ above is:
\begin{equation}x_n - h f'(x_n)\end{equation}
So we get the iterative formula:
\begin{equation}x_{n+1} = x_n - h f'(x_n)\end{equation}
In high-dimensional spaces, this is:
\begin{equation}\boldsymbol{x}_{n+1} = \boldsymbol{x}_n - h \nabla f(\boldsymbol{x}_n)\end{equation}
This is the famous Gradient Descent method. Of course, it has its own set of problems, but many improved algorithms revolve around it, such as Stochastic Gradient Descent.
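As a small illustration of the one-dimensional update $x_{n+1} = x_n - h f'(x_n)$, here is a sketch that again assumes the illustrative function $f(x)=x^4-2x^2$ (not from the article). The step size `h=0.05`, the number of steps, and the name `gradient_descent` are arbitrary choices made for this example.

```python
def gradient_descent(df, x0, h=0.05, steps=500):
    """Iterate x_{n+1} = x_n - h * f'(x_n) starting from x0."""
    x = x0
    for _ in range(steps):
        x = x - h * df(x)
    return x

# Illustrative function (an assumption): f(x) = x^4 - 2x^2, with f'(x) = 4x^3 - 4x.
df = lambda x: 4 * x**3 - 4 * x

# Because the surrogate parabola always opens upwards, both runs move downhill,
# even though they start near the local maximum at x = 0.
x_right = gradient_descent(df, x0=0.2)
x_left = gradient_descent(df, x0=-0.2)

print(x_right, x_left)
```

With this step size both runs settle at a minimum ($x=1$ and $x=-1$ respectively) rather than at the nearby maximum, which is the behaviour the fixed positive curvature $1/h$ is meant to buy.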
Here we understand gradient descent as the result obtained by approximating using a parabola. Given this view, readers might wonder: why must I use a parabola? Is it not possible to use other curves for approximation? Of course it is. For many problems, gradient descent might even complicate the issue—meaning the parabola is no longer effective. In such cases, we consider other forms of approximation. In fact, other approximation schemes are essentially what is called the EM algorithm. It seems strange to exclude gradient descent from this, given they share the same origin.
Maximum Likelihood
When estimating probabilities, our optimization target is usually the maximum likelihood function rather than MSE. For example, if we are building a language model, we need to estimate the co-occurrence probability $p(x,y)$ of any two words $x, y$. Suppose in a corpus of size $N$, the number of co-occurrences of $x, y$ is $\#(x,y)$. We can then obtain the statistical result:
\begin{equation}\tilde{p}(x,y)=\frac{\#(x,y)}{N}\end{equation}
While this is a basic result, it suffers from sparsity issues, and storing results for every possible word pair would consume prohibitive amounts of memory.
A better solution is to assume that $p(x,y)$ can be represented by a function $p(x,y;\theta)$ with unknown parameters $\theta$ (where $\theta$ might be a vector), such as a neural network, and then find ways to optimize the parameters. So the question arises: what is the optimization objective? Note that if MSE is used, the biggest problem is that we cannot guarantee the resulting probability is non-negative, whereas probabilities must be non-negative.
For probability problems, statisticians proposed a more natural scheme—the Maximum Likelihood Function. Philosophers often say "existence is reasonable." The idea of the maximum likelihood function is even more thorough: it says "existence is most reasonable." If the number of co-occurrences of $x, y$ is $\#(x,y)$, then since this event has occurred, it must be the most reasonable outcome. Therefore, the probability function
\begin{equation}\prod_{x,y} p(x,y;\theta)^{\#(x,y)}\end{equation}
should be maximized. Taking the logarithm gives:
\begin{equation}\sum_{x,y} \#(x,y)\log p(x,y;\theta)\end{equation}
We should maximize this function, which is called the Log-Likelihood. Clearly, $\#(x,y)$ can be replaced by the statistical frequency $\tilde{p}(x,y)$, since the two differ only by the constant factor $N$, which does not change the location of the maximum:
\begin{equation}\sum_{x,y} \tilde{p}(x,y)\log p(x,y;\theta)\end{equation}
In fact, multiplying by -1 yields:
\begin{equation}S=-\sum_{x,y} \tilde{p}(x,y)\log p(x,y;\theta)\end{equation}
We give this a special name—Cross-Entropy. It is a common loss function in machine learning. In other words, maximizing the likelihood is equivalent to minimizing the cross-entropy. If we don't pre-specify the form of $p(x,y;\theta)$ and directly estimate it, it is easy to calculate that $p(x,y;\theta)=\tilde{p}(x,y)$, which is exactly what we expect and proves the rationality of using maximum likelihood as an optimization target.
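The claim that direct estimation gives $p=\tilde{p}$ can be checked with a short Lagrange-multiplier calculation. This is a sketch that treats every $p(x,y)$ as a free parameter subject only to normalization; the multiplier $\lambda$ and the symbol $\mathcal{L}$ are introduced here for illustration.

\begin{equation}\mathcal{L} = -\sum_{x,y}\tilde{p}(x,y)\log p(x,y) + \lambda\Big(\sum_{x,y}p(x,y)-1\Big)\end{equation}
\begin{equation}\frac{\partial \mathcal{L}}{\partial p(x,y)} = -\frac{\tilde{p}(x,y)}{p(x,y)}+\lambda = 0 \quad\Rightarrow\quad p(x,y)=\frac{\tilde{p}(x,y)}{\lambda}\end{equation}

Summing the last relation over all $x,y$ and using $\sum_{x,y}p(x,y)=\sum_{x,y}\tilde{p}(x,y)=1$ gives $\lambda=1$, hence $p(x,y)=\tilde{p}(x,y)$.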
EM Algorithm
For cross-entropy optimization, we usually try gradient descent first. However, in many cases, gradient descent is not effective, and we prefer to use the EM algorithm, often referred to as "The Algorithm of God."
Taking the previous language model as an example, to make the estimated results generalize better, we transform $p(x,y)$ into $p(x|y)p(y)$ and then decompose $p(x|y)$ as $p(x|y)=\sum_z p(x|z)p(z|y)$. The significance of this decomposition, as mentioned in the article "SVD Decomposition (II): Why SVD implies Clustering?", is that $z$ can be understood as a category or a topic. $p(x|y)$ is the probability of $x$ following $y$, while $p(z|y)$ is the probability that $y$ belongs to topic $z$, and $p(x|z)$ can be understood as the probability that $x$ appears within topic $z$. Generally, the number of $z$ topics is much smaller than the number of $x,y$ tokens, thereby reducing the total number of parameters.
In this case, the cross-entropy is:
\begin{equation}S=-\sum_{x,y} \tilde{p}(x,y)\log \sum_z p(x|z)p(z|y)p(y)\end{equation}
Assuming $p(y)$ can be accurately estimated by $\tilde{p}(y)$, this term only contributes a constant. Therefore, the equivalent optimization target is:
\begin{equation}S=-\sum_{x,y} \tilde{p}(x,y)\log \sum_z p(x|z)p(z|y)\end{equation}
Here $p(x|z)$ and $p(z|y)$ are the parameters to be determined (traversing all possible combinations of $x,y,z$).
Its gradients are:
\begin{equation}\begin{aligned}&\frac{\partial S}{\partial p(x|z)} = -\sum_{y} \frac{\tilde{p}(x,y)}{\sum_z p(x|z)p(z|y)}p(z|y)\\
&\frac{\partial S}{\partial p(z|y)} = -\sum_{x} \frac{\tilde{p}(x,y)}{\sum_z p(x|z)p(z|y)}p(x|z)\end{aligned}\end{equation}
Direct gradient descent is infeasible because $p(x|z)$ and $p(z|y)$ must be non-negative and are subject to the normalization constraints:
\begin{equation}\sum_x p(x|z) = 1,\quad \sum_z p(z|y)=1\end{equation}
A plain gradient descent step cannot guarantee that the parameters stay non-negative and normalized, which dooms it for such problems. Recall our derivation of gradient descent at the beginning: we iterate using an approximate curve. Gradient descent uses a parabolic approximation, but we will not use it here (because it cannot respect the non-negativity constraints); instead, we find a way to construct a new approximation. Suppose we have performed $n$ iterations and obtained estimates $p_n(x|z)$ and $p_n(z|y)$. According to the gradient formula, the gradient for this round is:
\begin{equation}\begin{aligned}&\frac{\partial S}{\partial p_n(x|z)} = -\sum_{y} \frac{\tilde{p}(x,y)}{\sum_z p_n(x|z)p_n(z|y)}p_n(z|y)\\
&\frac{\partial S}{\partial p_n(z|y)} = -\sum_{x} \frac{\tilde{p}(x,y)}{\sum_z p_n(x|z)p_n(z|y)}p_n(x|z)\end{aligned}\end{equation}
The difficulty of the original problem lies in the fact that the sum is inside the $\log$. If we could pull the sum outside, it would be simpler. Therefore, we might consider an approximate function like this:
\begin{equation}S_n'=-\sum_{x,y} \tilde{p}(x,y)\sum_z C_{x,y,z,n} \log \big[p(x|z)p(z|y)\big]\end{equation}
Where $C_{x,y,z,n}$ is a constant within the current iteration, so $S_n'$ has a minimum that can be solved for exactly. Naturally, we want the gradient of $S_n'$ to match the gradient of the original $S$ at the current estimate (first-order accuracy). The gradient of $S_n'$ is:
\begin{equation}\begin{aligned}&\frac{\partial S_n'}{\partial p_n(x|z)} = -\sum_{y} \frac{\tilde{p}(x,y)C_{x,y,z,n}}{p_n(x|z)}\\
&\frac{\partial S_n'}{\partial p_n(z|y)} = -\sum_{x} \frac{\tilde{p}(x,y)C_{x,y,z,n}}{p_n(z|y)}\end{aligned}\end{equation}
Comparing the two sets of gradients, we get:
\begin{equation}C_{x,y,z,n}=\frac{p_n(x|z)p_n(z|y)}{\sum_z p_n(x|z)p_n(z|y)}\end{equation}
That is, once we have a set of initial parameters, we can substitute them into the formula above to solve for $C_{x,y,z,n}$. Then, we find the parameters that minimize $S_n'$ and use them as $p_{n+1}(x|z)$ and $p_{n+1}(z|y)$, iterating in this manner.
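The iteration just described can be sketched in a few lines of plain Python. Several things here are assumptions rather than part of the article: the function name `plsa_em`, the toy count matrix, the random initialization, and the closed form used for the minimizer of $S_n'$. That closed form (normalize the accumulated weights $\tilde{p}(x,y)C_{x,y,z,n}$ over $x$ for $p(x|z)$ and over $z$ for $p(z|y)$) comes from a Lagrange-multiplier calculation under the two sum-to-one constraints, which the text does not spell out.

```python
import math
import random

def plsa_em(p_tilde, num_topics, steps=50, seed=0):
    """Iterate on p(x|y) = sum_z p(x|z) p(z|y); p_tilde[x][y] is the empirical joint."""
    rng = random.Random(seed)
    X, Y, Z = len(p_tilde), len(p_tilde[0]), num_topics

    def normalized(values):
        total = sum(values)
        return [v / total for v in values]

    # Strictly positive random start; px_z[z][x] = p(x|z), pz_y[y][z] = p(z|y).
    px_z = [normalized([rng.random() + 0.1 for _ in range(X)]) for _ in range(Z)]
    pz_y = [normalized([rng.random() + 0.1 for _ in range(Z)]) for _ in range(Y)]

    def cross_entropy():
        s = 0.0
        for x in range(X):
            for y in range(Y):
                if p_tilde[x][y] > 0:
                    mix = sum(px_z[z][x] * pz_y[y][z] for z in range(Z))
                    s -= p_tilde[x][y] * math.log(mix)
        return s

    history = [cross_entropy()]
    for _ in range(steps):
        new_px_z = [[0.0] * X for _ in range(Z)]
        new_pz_y = [[0.0] * Z for _ in range(Y)]
        for x in range(X):
            for y in range(Y):
                mix = sum(px_z[z][x] * pz_y[y][z] for z in range(Z))
                for z in range(Z):
                    # C_{x,y,z,n} from the gradient-matching condition.
                    c = px_z[z][x] * pz_y[y][z] / mix
                    new_px_z[z][x] += p_tilde[x][y] * c
                    new_pz_y[y][z] += p_tilde[x][y] * c
        # Minimizer of S_n' under the constraints: normalize the accumulated weights.
        px_z = [normalized(row) for row in new_px_z]
        pz_y = [normalized(row) for row in new_pz_y]
        history.append(cross_entropy())
    return px_z, pz_y, history

# Toy co-occurrence counts (made up for illustration): 4 words x, 3 contexts y.
counts = [[8, 1, 1], [6, 2, 1], [1, 7, 2], [1, 1, 9]]
total = sum(map(sum, counts))
p_tilde = [[c / total for c in row] for row in counts]
px_z, pz_y, history = plsa_em(p_tilde, num_topics=2)
print(history[0], history[-1])
```

The parameters stay non-negative and normalized at every step by construction, which is exactly what a plain gradient step could not promise. The sketch assumes every entry of the count matrix is positive; handling empty rows or columns would need extra care.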
This part basically describes the solving process of the pLSA model; further details can be found in "Natural Language Processing: PLSA". For the purposes of this article, the point is to show that the EM algorithm, like gradient descent, comes from the same source: both are based on approximation curves. It is not some mysterious method, and the search for this approximate function is well-founded. Nearly all tutorials online directly present the expression for $S'$ (usually called the Q-function in standard tutorials) in a way that feels like "metaphysics" to me, which I find quite off-putting.
K-Means
K-Means clustering is easy to understand: given the coordinates of $N$ points $\boldsymbol{x}_i, i=1,\dots,N$, find a way to divide these points into $K$ categories. Each category has a centroid $\boldsymbol{c}_j, j=1,\dots,K$. Naturally, the category a point belongs to is the one whose centroid $\boldsymbol{c}_j$ is closest to it, with distance defined as Euclidean distance.
Therefore, the main task of K-Means clustering is to find the centroids $\boldsymbol{c}_j$. We naturally hope that each centroid is exactly in the "center" of its category. Expressing this as a function, we want to minimize the following function $L$:
\begin{equation}L=\sum_{i=1}^N \min\bigg\{\|\boldsymbol{x}_i-\boldsymbol{c}_1\|^2,\|\boldsymbol{x}_i-\boldsymbol{c}_2\|^2,\dots,\|\boldsymbol{x}_i-\boldsymbol{c}_K\|^2\bigg\}\end{equation}
Where the $\min$ operation ensures that each point only belongs to the category nearest to it.
If we optimized $L$ directly using gradient descent, we would face great difficulties. This is not because the $\min$ operation is hard to differentiate, but because finding the global minimum of $L$ is an NP-hard problem: no known algorithm is guaranteed to find it without a running time that, in the worst case, grows exponentially with the problem size. In this case, we also use the EM algorithm, which manifests as:
1. Randomly select $K$ points as the initial centroids;
2. Given $K$ centroids, determine which category each point belongs to, and then use the average coordinates of all points in the same category as the new centroid.
This method generally converges within a few iterations. But what is the reason for doing this?
We still follow the idea of approximation curves. But the problem is, what do we do about the non-differentiable $\min$? We can consider a smooth approximation and then take the limit. The answer is found here: "Seeking a Smooth Maximum Function." Taking a sufficiently large $M$, we can use the approximation (since the minimum is the negative of the maximum of the negatives):
\begin{equation}\begin{aligned}&\min\bigg\{\|\boldsymbol{x}_i-\boldsymbol{c}_1\|^2,\|\boldsymbol{x}_i-\boldsymbol{c}_2\|^2,\dots,\|\boldsymbol{x}_i-\boldsymbol{c}_K\|^2\bigg\}\\
\approx&-\frac{1}{M}\ln\bigg(e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_1\|^2}+e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_2\|^2}+\dots+e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_K\|^2}\bigg)\end{aligned}\end{equation}
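To get a feel for how quickly this smooth expression approaches the true minimum as $M$ grows, here is a small numerical sketch. The helper name `soft_min` and the sample values are made up for illustration; the true minimum is factored out before exponentiating purely for numerical stability, which leaves the value of the expression unchanged.

```python
import math

def soft_min(values, M):
    """-(1/M) * ln(sum_j exp(-M * v_j)), with the true minimum factored out for stability."""
    m = min(values)
    return m - math.log(sum(math.exp(-M * (v - m)) for v in values)) / M

# Stand-ins for the squared distances ||x_i - c_j||^2 (made-up numbers).
values = [2.0, 3.5, 2.4]
approximations = {M: soft_min(values, M) for M in (1, 10, 100)}

for M, approx in approximations.items():
    print(M, approx, min(values) - approx)
```

The smooth value never exceeds the true minimum, and the gap is at most $\ln K / M$ for $K$ values, so it vanishes as $M\to\infty$.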
With this approximation,
\begin{equation}L\approx-\sum_{i=1}^N \frac{1}{M}\ln\bigg(e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_1\|^2}+e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_2\|^2}+\dots+e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_K\|^2}\bigg)\end{equation}
and its gradient can be calculated:
\begin{equation}\frac{\partial L}{\partial \boldsymbol{c}_j}=\sum_{i=1}^N \frac{2e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_j\|^2 } }{e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_1\|^2}+e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_2\|^2}+\dots+e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_K\|^2}}(\boldsymbol{c}_j-\boldsymbol{x}_i)\end{equation}
Let the result of the $n$-th iteration be $\boldsymbol{c}^{(n)}_j$. The gradient for this round is:
\begin{equation}\frac{\partial L}{\partial \boldsymbol{c}^{(n)}_j}=\sum_{i=1}^N \frac{2e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}^{(n)}_j\|^2 } }{e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}^{(n)}_1\|^2}+e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}^{(n)}_2\|^2}+\dots+e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_K^{(n)}\|^2}}(\boldsymbol{c}^{(n)}_j-\boldsymbol{x}_i)\end{equation}
Based on the characteristics of this formula, we can look for such an approximate curve (it should be a hypersurface):
\begin{equation}L'=\sum_{i=1}^N \sum_{j=1}^K C^{(n)} _{i,j} \|\boldsymbol{x}_i-\boldsymbol{c}_j\|^2 \end{equation}
Where $C^{(n)} _{i,j}$ is to be determined. It is a constant in each round, so this is just a quadratic function, and its minimum point is easily found:
\begin{equation}\boldsymbol{c}_j = \frac{\sum_{i=1}^N C^{(n)} _{i,j}\boldsymbol{x}_i}{\sum_{i=1}^N C^{(n)} _{i,j}}\end{equation}
which is the weighted average of $\boldsymbol{x}_i$.
Similarly, we want this approximate curve to have at least first-order accuracy compared with the original function. Thus, we take its derivative:
\begin{equation}\frac{\partial L'}{\partial \boldsymbol{c}_j}=\sum_{i=1}^N 2C^{(n)} _{i,j} (\boldsymbol{c}_j-\boldsymbol{x}_i)\end{equation}
Comparing with the derivative of the original function, it is easy to see that:
\begin{equation}C^{(n)} _{i,j} = \frac{e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}^{(n)}_j\|^2 } }{e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}^{(n)}_1\|^2}+e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}^{(n)}_2\|^2}+\dots+e^{-M\|\boldsymbol{x}_i-\boldsymbol{c}_K^{(n)}\|^2}}\end{equation}
At this point, we get the iteration formula:
\begin{equation}\boldsymbol{c}^{(n+1)}_j = \frac{\sum_{i=1}^N C^{(n)} _{i,j}\boldsymbol{x}_i}{\sum_{i=1}^N C^{(n)} _{i,j}}\end{equation}
So far, we have derived all the steps. However, we used a continuous approximation, and finally, we need to take the limit as $M\to\infty$. After taking the limit, the problem simplifies. From the above formula, it can be inferred that:
\begin{equation}\lim_{M\to\infty} C^{(n)} _{i,j} = \Delta^{(n)} _{i,j} = \left\{\begin{aligned}&1,\ \text{if } \boldsymbol{c}^{(n)}_j \text{ is the centroid nearest to } \boldsymbol{x}_i\\
&0,\ \text{otherwise}\end{aligned}\right.\end{equation}
In short, if we view $\Delta^{(n)} _{i,j}$ as an $N \times K$ matrix, each row of the matrix contains exactly one 1 (barring ties in distance), and the rest are 0. If the $j$-th element in the $i$-th row is 1, it means that the nearest centroid to $\boldsymbol{x}_i$ is $\boldsymbol{c}_j$. At this point, the iteration formula becomes:
\begin{equation}\boldsymbol{c}^{(n+1)}_j = \frac{\sum_{i=1}^N \Delta^{(n)} _{i,j}\boldsymbol{x}_i}{\sum_{i=1}^N \Delta^{(n)} _{i,j}}\end{equation}
According to the meaning of $\Delta^{(n)} _{i,j}$, this simply says:
$\boldsymbol{c}^{(n+1)}_j$ is the average of all the points whose nearest centroid is $\boldsymbol{c}^{(n)}_j$.
This yields the iterative algorithm we typically use to solve K-Means problems, which is also called the EM algorithm.
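The whole derivation can be tried out directly: compute the weights $C^{(n)}_{i,j}$ for a finite $M$, take the weighted average, and observe that a large $M$ reproduces the usual hard-assignment update. The sketch below is illustrative only. The function names, the toy points, and the rule of keeping an old centroid when its cluster is empty are all assumptions, and the initial centroids are fixed rather than random so that the result is reproducible.

```python
import math

def soft_weights(points, centroids, M):
    """C_{i,j}: for each point, a softmax over j of -M * ||x_i - c_j||^2."""
    weights = []
    for p in points:
        d2 = [sum((a - b) ** 2 for a, b in zip(p, c)) for c in centroids]
        d_min = min(d2)  # factor out the smallest distance for numerical stability
        e = [math.exp(-M * (d - d_min)) for d in d2]
        total = sum(e)
        weights.append([v / total for v in e])
    return weights

def update_centroids(points, weights, old_centroids):
    """c_j <- weighted average of all points, using the weights C_{i,j}."""
    dim = len(points[0])
    new_centroids = []
    for j, old in enumerate(old_centroids):
        w_sum = sum(w[j] for w in weights)
        if w_sum == 0:  # empty cluster: keep the old centroid (a choice made here)
            new_centroids.append(old)
            continue
        new_centroids.append(tuple(
            sum(w[j] * p[k] for w, p in zip(weights, points)) / w_sum
            for k in range(dim)
        ))
    return new_centroids

def k_means(points, centroids, M=1e3, steps=20):
    """Repeat the weight computation and the weighted-average update."""
    for _ in range(steps):
        weights = soft_weights(points, centroids, M)
        centroids = update_centroids(points, weights, centroids)
    return centroids

# Two well-separated toy clusters (made-up data).
points = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0),
          (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
initial = [(0.0, 0.0), (10.0, 10.0)]
result = k_means(points, initial, M=1e3)
print(result)
```

With $M=10^3$ the exponentials for the farther centroid underflow to zero, so each weight row is effectively the indicator $\Delta^{(n)}_{i,j}$ and each centroid ends up at the plain average of its three nearest points.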
Summary
As we can see, the so-called EM algorithm is not a specific method, but a class of methods or a strategy. Gradient descent is a special case within this class of methods; they are fundamentally the same thing and, strictly speaking, it shouldn't be excluded. The so-called "Algorithm of God" is ultimately an iterative method. By updating itself through iteration, it can approach perfection (ideally the global optimum, although in general only a local optimum is guaranteed). This is as perfect as biological evolution, almost as if carefully designed by a creator, which is no doubt why it's called the Algorithm of God.