By 苏剑林 | May 20, 2019
Generally speaking, neural networks deal with continuous floating-point numbers, and standard outputs are also continuous values. However, in practical problems, we often need discrete results. For example, in classification problems, we want to output a correct category—the "category" is discrete, while the "category probability" is continuous. Similarly, the evaluation metrics for many tasks are actually discrete, such as accuracy and F1 score for classification, BLEU for machine translation, and so on.
Taking classification as an example again: the common evaluation metric is accuracy, while the common loss function is cross-entropy. Although a decrease in cross-entropy is indeed correlated with an increase in accuracy, they do not have an absolutely monotonic relationship. In other words, if cross-entropy decreases, accuracy does not necessarily increase. Obviously, it would be ideal to use the negative of the accuracy as the loss function, but accuracy is non-differentiable (involving operations like $\text{argmax}$), so it cannot be used directly.
In such cases, there are generally two solutions: one is to use reinforcement learning, treating accuracy as the reward function—this is a "sledgehammer to crack a nut" approach. The other is to try to find a smooth, differentiable approximation formula for accuracy. This article explores common smooth approximations of non-differentiable functions, sometimes referred to as "smoothing" and sometimes as "softening."
The foundation for most of the content discussed later is the smooth approximation of the $\max$ operation. We have:
\begin{equation}\max(x_1,x_2,\dots,x_n) = \lim_{K\to +\infty}\frac{1}{K}\log\left(\sum_{i=1}^n e^{K x_i}\right)\end{equation}By choosing a constant $K$, we get the approximation:
\begin{equation}\max(x_1,x_2,\dots,x_n) \approx \frac{1}{K}\log\left(\sum_{i=1}^n e^{K x_i}\right)\end{equation}In many models, we can set $K=1$, which is equivalent to incorporating $K$ into the model itself. Thus, the simplest version is:
\begin{equation}\begin{aligned}\max(x_1,x_2,\dots,x_n) \approx&\, \log\left(\sum_{i=1}^n e^{x_i}\right) \\ \triangleq&\, \text{logsumexp}(x_1, x_2, \dots, x_n)\end{aligned}\label{eq:max-approx}\end{equation}This is the well-known $\text{logsumexp}$ operator, which here serves as a smooth approximation of the $\max$ function. Note that it is $\text{logsumexp}$, not the similarly named $\text{softmax}$, that smoothly approximates $\max$. For related derivations, you can also refer to my previous post "Seeking a Smooth Maximum Function".
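As a quick numerical check, here is a NumPy sketch of the approximation with the constant $K$ made explicit (the function name `logsumexp_max` is my own; the max-shift inside is just the standard numerical-stability trick):

```python
import numpy as np

def logsumexp_max(x, K=1.0):
    """Smooth approximation of max(x): (1/K) * log(sum(exp(K * x))).
    Larger K gives a tighter approximation."""
    x = np.asarray(x, dtype=float)
    # Subtract the true max before exponentiating for numerical stability:
    # (1/K) log sum exp(K x) = m + (1/K) log sum exp(K (x - m)).
    m = x.max()
    return m + np.log(np.sum(np.exp(K * (x - m)))) / K

x = [2.0, 1.0, 4.0, 5.0, 3.0]
print(logsumexp_max(x, K=1))   # ≈ 5.45, a loose upper bound on max = 5
print(logsumexp_max(x, K=10))  # ≈ 5.000005, much tighter
```

Note that the approximation always overshoots the true maximum slightly, and the gap shrinks as $K$ grows.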
I just mentioned that $\text{softmax}$ is not a smooth approximation of $\max$. So what is it an approximation of? It is actually a smooth approximation of $\text{onehot}(\text{argmax}(\boldsymbol{x}))$, which first finds the position of the maximum value and then generates a vector of equal length where the maximum value's position is set to 1 and all other positions are set to 0. For example:
\begin{equation}[2, 1, 4, 5, 3]\quad \to \quad [0, 0, 0, 1, 0]\end{equation}We can easily provide a derivation from $\text{logsumexp}$ to $\text{softmax}$. Consider a vector $\boldsymbol{x}=[x_1, x_2, \dots, x_n]$, and then consider
\begin{equation}\boldsymbol{x}'=[x_1, x_2, \dots, x_n] - \max(x_1, x_2, \dots, x_n)\end{equation}That is, subtracting the overall maximum value from each element. The position of the maximum value in this new vector is the same as in the original vector, meaning $\text{onehot}(\text{argmax}(\boldsymbol{x}))=\text{onehot}(\text{argmax}(\boldsymbol{x}'))$. Without loss of generality, consider the case where $x_1, x_2, \dots, x_n$ are pairwise distinct. Then the maximum value of the new vector is clearly 0, and all other elements are negative. Given this, we can consider
\begin{equation}e^{\boldsymbol{x}'}=[e^{x_1 - \max(x_1, x_2, \dots, x_n)}, e^{x_2 - \max(x_1, x_2, \dots, x_n)}, \dots, e^{x_n - \max(x_1, x_2, \dots, x_n)}]\end{equation}as an approximation for $\text{onehot}(\text{argmax}(\boldsymbol{x}'))$. Since the maximum value is 0, the corresponding position becomes $e^0=1$, while the other entries are negative, so their exponentials are close to 0.
Finally, substituting the approximation $\eqref{eq:max-approx}$ into the above formula and simplifying, we get:
\begin{equation}\begin{aligned}\text{onehot}(\text{argmax}(\boldsymbol{x}))=&\, \text{onehot}(\text{argmax}(\boldsymbol{x}'))\\ \approx&\, \left(\frac{e^{x_1}}{\sum\limits_{i=1}^n e^{x_i}}, \frac{e^{x_2}}{\sum\limits_{i=1}^n e^{x_i}}, \dots, \frac{e^{x_n}}{\sum\limits_{i=1}^n e^{x_i}}\right)\\ \triangleq&\, \text{softmax}(x_1, x_2, \dots, x_n) \end{aligned}\end{equation}$\text{argmax}$ refers to directly giving the index (an integer) where the vector's maximum value is located. For example:
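The approximation can be verified directly: with a sharpness constant $K$ applied before the softmax (again, $K$ can be absorbed into the model), the output concentrates on the position of the maximum as $K$ grows. A NumPy sketch:

```python
import numpy as np

def softmax(x, K=1.0):
    """softmax(K*x): a smooth approximation of onehot(argmax(x)).
    As K grows, the output concentrates on the max position."""
    x = np.asarray(x, dtype=float)
    z = np.exp(K * (x - x.max()))  # shift by max for numerical stability
    return z / z.sum()

x = [2, 1, 4, 5, 3]
print(np.round(softmax(x, K=1), 3))   # mass spread, peaked at index of 5
print(np.round(softmax(x, K=10), 3))  # ≈ [0, 0, 0, 1, 0]
```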
\begin{equation}[2, 1, 4, 5, 3]\quad \to \quad 4\end{equation}Here we follow general usage where indexing starts from 1, so the result is 4; however, in programming languages, it usually starts from 0, so the result would typically be 3.
If we want a smooth approximation of $\text{argmax}$, we naturally hope to output a floating-point number close to 4. To construct such an approximation, we first observe that $\text{argmax}$ is actually equal to:
\begin{equation}\text{sum}\Big(\underbrace{[1, 2, 3, 4, 5]}_{\text{Sequence vector } [1, 2, ..., n]}\, \otimes\, \underbrace{[0, 0, 0, 1, 0]}_{\text{onehot}(\text{argmax}(\boldsymbol{x}))}\Big)\end{equation}That is, the inner product of the vector $[1, 2, \dots, n]$ with $\text{onehot}(\text{argmax}(\boldsymbol{x}))$. Constructing a softened version of $\text{argmax}$ is then simple: replace $\text{onehot}(\text{argmax}(\boldsymbol{x}))$ with $\text{softmax}(\boldsymbol{x})$:
\begin{equation}\text{argmax} (\boldsymbol{x}) \approx \sum_{i=1}^n i\times \text{softmax}(\boldsymbol{x})_i\end{equation}The various approximations discussed above are basically derived by finding the correct form based on one-hot vectors and then using softmax to approximate the one-hot vector. Using this logic, smooth approximations for many other operators can be derived, such as accuracy.
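This soft argmax is just the expected index under the softmax distribution. A minimal sketch (the name `soft_argmax` is my own; indices are 1-based to match the article's convention):

```python
import numpy as np

def soft_argmax(x, K=1.0):
    """Differentiable approximation of argmax: the expected index
    under softmax(K*x), using 1-based indices as in the article."""
    x = np.asarray(x, dtype=float)
    z = np.exp(K * (x - x.max()))  # stable softmax
    p = z / z.sum()
    idx = np.arange(1, len(x) + 1)  # [1, 2, ..., n]
    return float(np.dot(idx, p))

print(soft_argmax([2, 1, 4, 5, 3], K=1))   # ≈ 3.73: blurred toward index 3
print(soft_argmax([2, 1, 4, 5, 3], K=10))  # ≈ 4.0: essentially the argmax
```

With small $K$ the output is pulled toward neighboring large entries (here the 4 at index 3), which is exactly the smoothing at work.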
For simplicity, let's introduce the notation $\boldsymbol{1}_k$, which represents a one-hot vector with a 1 at the $k$-th position. Suppose in a classification problem, the target category is $i$ and the predicted category is $j$. We can consider the one-hot vectors $\boldsymbol{1}_i$ and $\boldsymbol{1}_j$ and then look at their inner product:
\begin{equation}\langle \boldsymbol{1}_i, \boldsymbol{1}_j\rangle = \left\{\begin{aligned}&1,\,\,(i=j)\\ &0,\,\,(i\neq j)\end{aligned}\right.\end{equation}In other words, when the two categories are the same, the inner product is exactly 1, and when they are different, the inner product is exactly 0. Thus, the inner product of the one-hot vectors corresponding to the target category and the predicted category defines exactly a "predicted correctly" counting function. With a counting function, we can calculate accuracy:
\begin{equation}\text{Accuracy}=\frac{1}{|\mathcal{B}|}\sum_{\boldsymbol{x}\in\mathcal{B}}\langle \boldsymbol{1}_i(\boldsymbol{x}), \boldsymbol{1}_j(\boldsymbol{x})\rangle\end{equation}where $\mathcal{B}$ denotes the current batch. The formula above is the function for calculating accuracy within a batch. In a neural network, to ensure differentiability, the final output can only be a probability distribution (the result after softmax). Thus, a smooth approximation of accuracy is obtained by replacing the predicted category's one-hot vector with the probability distribution:
\begin{equation}\text{Accuracy}\approx \frac{1}{|\mathcal{B}|}\sum_{\boldsymbol{x}\in\mathcal{B}}\langle \boldsymbol{1}_i(\boldsymbol{x}), p(\boldsymbol{x})\rangle\end{equation}Similarly, smooth approximations for recall, F1, and other metrics can be derived. Taking binary classification as an example, if $p(\boldsymbol{x})$ is the probability of the positive class and $t(\boldsymbol{x})$ is the label of sample $\boldsymbol{x}$ (0 or 1), then the smooth approximation for the positive class F1 is:
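Concretely, the inner product $\langle \boldsymbol{1}_i(\boldsymbol{x}), p(\boldsymbol{x})\rangle$ just picks out the probability assigned to the true class, so the smoothed accuracy is the batch mean of those probabilities. A NumPy sketch (the name `soft_accuracy` is my own):

```python
import numpy as np

def soft_accuracy(probs, labels):
    """Smooth accuracy over a batch: mean of <onehot(label), p>, i.e. the
    average probability assigned to the true class.  `probs` has shape
    (batch, classes) with rows summing to 1; `labels` are integer classes."""
    probs = np.asarray(probs, dtype=float)
    return float(probs[np.arange(len(labels)), labels].mean())

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4]])
labels = [0, 1, 2]
print(soft_accuracy(probs, labels))  # (0.7 + 0.8 + 0.4) / 3 ≈ 0.633
```

Negating this quantity gives a differentiable surrogate loss; note that as the distributions sharpen toward one-hot, it converges to the true accuracy.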
\begin{equation}\text{Positive F1}\approx\frac{2 \sum\limits_{\boldsymbol{x}\in\mathcal{B}}t(\boldsymbol{x}) p(\boldsymbol{x})}{\sum\limits_{\boldsymbol{x}\in\mathcal{B}}\big[t(\boldsymbol{x}) + p(\boldsymbol{x})\big]}\end{equation}The F1 approximation derived this way is differentiable, so its negative can be used directly as a loss. However, under mini-batch sampling it is a biased estimate of the full-dataset F1 (the denominator also contains a summation over the batch), which can sometimes affect the optimization trajectory or even lead to divergence. Therefore, it is generally best not to use it from the start; instead, train with standard cross-entropy until nearly converged, and then fine-tune using the negative of F1 as the loss.
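The smoothed F1 is a direct transcription of the formula above; a minimal sketch (the name `soft_f1` is my own):

```python
import numpy as np

def soft_f1(p, t):
    """Smooth positive-class F1 over a batch:
        2 * sum(t * p) / sum(t + p)
    where `p` are predicted positive-class probabilities and `t` the
    0/1 labels.  Negate the result to use it as a loss."""
    p = np.asarray(p, dtype=float)
    t = np.asarray(t, dtype=float)
    return float(2.0 * np.sum(t * p) / np.sum(t + p))

p = np.array([0.9, 0.2, 0.8, 0.1])
t = np.array([1,   0,   1,   0])
print(soft_f1(p, t))  # 2*(0.9 + 0.8) / (1.9 + 0.2 + 1.8 + 0.1) = 0.85
```

If the hard predictions `p.round()` were used instead, this would reduce to the usual (non-differentiable) F1 counting formula, which is exactly the relationship the smoothing exploits.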
$\text{softmax}$ is a smooth approximation of "setting the maximum position to 1 and others to 0." What if we want a smooth approximation of "setting the positions of the top $k$ values to 1 and others to 0"? We might call this $\text{soft-}k\text{-max}$.
I haven't constructed a simple analytical form for $\text{soft-}k\text{-max}$, but it can be constructed recursively:
Input is $\boldsymbol{x}$, initialize $\boldsymbol{p}^{(0)}$ as an all-zero vector;
Execute $\boldsymbol{x} = \boldsymbol{x} - \min(\boldsymbol{x})$ (to ensure all elements are non-negative)
For $i=1,2,\dots,k$, execute:
$\boldsymbol{y} = (1 - \boldsymbol{p}^{(i-1)})\otimes\boldsymbol{x}$;
$\boldsymbol{p}^{(i)} = \boldsymbol{p}^{(i-1)} + \text{softmax}(\boldsymbol{y})$
Return $\boldsymbol{p}^{(k)}$.
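The recursion above can be sketched in NumPy as follows. I add an explicit sharpness constant $K$ (which, as noted earlier, can be absorbed into the model) so that the demo output is visibly close to the hard top-$k$ indicator:

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())  # shift by max for numerical stability
    return z / z.sum()

def soft_k_max(x, k, K=1.0):
    """Recursive smooth approximation of the top-k indicator vector:
    ~1 at the positions of the k largest entries, ~0 elsewhere."""
    x = np.asarray(x, dtype=float)
    x = x - x.min()              # make all elements non-negative
    p = np.zeros_like(x)
    for _ in range(k):
        y = (1 - p) * x          # suppress already-selected positions
        p = p + softmax(K * y)   # add (approximately) one more one-hot
    return p

print(np.round(soft_k_max([2, 1, 4, 5, 3], k=2, K=10), 3))  # ≈ [0, 0, 1, 1, 0]
```

With $K=1$ the output is still a rather diffuse vector summing to $k$; sharpening comes entirely from the scale of the inputs (or the explicit $K$), as with all the softmax-based approximations in this article.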
As for the principle, if you replace $\text{softmax}(\boldsymbol{y})$ with $\text{onehot}(\text{argmax}(\boldsymbol{y}))$ and look at the recursion, it becomes clear: it essentially finds the max, then effectively removes that max so the second-largest value becomes the new max, then applies softmax again, and repeats $k$ times.
Function smoothing is an interesting mathematical topic that frequently appears in machine learning. On one hand, it is a technique to make certain operations differentiable, allowing models to be solved directly using backpropagation without "resorting" to reinforcement learning. On the other hand, in some cases, it can enhance the interpretability of a model because the corresponding non-differentiable functions often have better interpretability. After training with the smoothed version, it may be possible to revert to the non-differentiable version to explain the model's outputs.
Of course, appreciating it as a piece of pure mathematical beauty is also quite rewarding.