By 苏剑林 | July 15, 2022
Some readers might have noticed that this update has been delayed for quite a while. In fact, I started preparing this article last weekend. However, I underestimated the difficulty of this problem; I spent nearly a whole week on derivations and still haven't reached a perfect result. What I am posting now is still a failed attempt, and I hope experienced readers can provide some guidance.
In the article "Generalizing 'Softmax + Cross-Entropy' to Multi-label Classification", we proposed a multi-label classification loss function that automatically adjusts for the imbalance between positive and negative classes. Later, in "The Soft Label Version of Multi-label 'Softmax + Cross-Entropy'", we further derived its "soft label" version. Essentially, multi-label classification is an "n-way 2-class" problem. So then, what would the loss function for an "n-way m-class" problem look like?
This is the question that this article intends to explore.
In the soft label generalization article "The Soft Label Version of Multi-label 'Softmax + Cross-Entropy'", we obtained the final result by directly applying a first-order truncation within the $\log$ of the "n-way 2-class" sigmoid cross-entropy loss. The same process can indeed be generalized to the "n-way m-class" softmax cross-entropy loss. This was my first attempt.
Let $\text{softmax}(s_{i,j}) = \frac{e^{s_{i,j}}}{\sum\limits_k e^{s_{i,k}}}$, where $s_{i,j}$ are the predicted scores and $t_{i,j}$ are the labels. Then
\begin{equation}\begin{aligned}-\sum_i\sum_j t_{i,j}\log \text{softmax}(s_{i,j}) =& \,\sum_i\sum_j t_{i,j}\log \left(1 + \sum_{k\neq j} e^{s_{i,k} - s_{i,j}}\right)\\ =& \,\sum_j \log \prod_i\left(1 + \sum_{k\neq j} e^{s_{i,k} - s_{i,j}}\right)^{t_{i,j}}\\ =& \,\sum_j \log \left(1 + \sum_i t_{i,j}\sum_{k\neq j} e^{s_{i,k} - s_{i,j}}+\cdots\right)\\ \end{aligned}\end{equation}The summation over $i$ defaults to $1 \sim n$, and over $j$ defaults to $1 \sim m$. The last line uses the first-order expansion $\prod_i (1+A_{i,j})^{t_{i,j}} = 1 + \sum_i t_{i,j}A_{i,j} + \cdots$ with $A_{i,j}=\sum\limits_{k\neq j} e^{s_{i,k}-s_{i,j}}$, where $\cdots$ collects all terms of second or higher order in the $A_{i,j}$. Truncating these higher-order terms $\cdots$, we get
\begin{equation}l = \sum_j \log \left(1 + \sum_{i,k\neq j} t_{i,j}e^{- s_{i,j} + s_{i,k}}\right)\label{eq:loss-1}\end{equation}This is the loss I initially obtained, and it is a natural generalization of the previous result to "n-way m-class". In fact, if $t_{i,j}$ are hard labels, this loss is essentially fine. However, I hoped that, like in "The Soft Label Version of Multi-label 'Softmax + Cross-Entropy'", I could derive a corresponding analytical solution for soft labels. To this end, I took its derivative:
\begin{equation}\frac{\partial l}{\partial s_{i,j}} = \frac{- t_{i,j}e^{- s_{i,j}}\sum\limits_{k\neq j} e^{s_{i,k}}}{1 + \sum\limits_{i',k\neq j} t_{i',j}e^{- s_{i',j} + s_{i',k}}} + \sum_{h\neq j} \frac{t_{i,h}e^{- s_{i,h}}e^{s_{i,j}}}{1 + \sum\limits_{i',k\neq h} t_{i',h}e^{- s_{i',h} + s_{i',k}}}\end{equation}(here $i'$ is the summation variable in the denominators, kept distinct from the fixed sample index $i$). The analytical solution in question would be obtained by solving the equation $\frac{\partial l}{\partial s_{i,j}}=0$. However, I tried for several days and could not solve it; I suspect it has no simple closed-form solution. So the first attempt failed.
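Before moving on, here is a minimal NumPy sketch of Eq. $\eqref{eq:loss-1}$ together with a finite-difference check of the derivative above. The function names and the $(n, m)$ array layout are purely illustrative choices, not anything from the referenced articles.

```python
import numpy as np

def loss_eq1(s, t):
    """Sketch of Eq. (loss-1): sum_j log(1 + sum_{i, k != j} t[i, j] * exp(s[i, k] - s[i, j]))."""
    diff = s[:, None, :] - s[:, :, None]                             # diff[i, j, k] = s[i, k] - s[i, j]
    mask = 1.0 - np.eye(s.shape[1])                                  # drop the k == j terms
    inner = (t[:, :, None] * np.exp(diff) * mask).sum(axis=(0, 2))   # one value per class j
    return np.log1p(inner).sum()

def grad_eq1(s, t):
    """Analytic gradient of loss_eq1, written term by term as in the derivative above."""
    n, m = s.shape
    diff = s[:, None, :] - s[:, :, None]
    mask = 1.0 - np.eye(m)
    denom = 1.0 + (t[:, :, None] * np.exp(diff) * mask).sum(axis=(0, 2))   # one denominator per class
    g = np.zeros_like(s)
    for i in range(n):
        for j in range(m):
            first = -t[i, j] * sum(np.exp(s[i, k] - s[i, j]) for k in range(m) if k != j) / denom[j]
            rest = sum(t[i, h] * np.exp(s[i, j] - s[i, h]) / denom[h] for h in range(m) if h != j)
            g[i, j] = first + rest
    return g

# finite-difference check on random scores and random hard labels
rng = np.random.default_rng(0)
s = rng.normal(size=(5, 4))
t = np.eye(4)[rng.integers(0, 4, size=5)]
num = np.zeros_like(s)
for i in range(5):
    for j in range(4):
        sp, sm = s.copy(), s.copy()
        sp[i, j] += 1e-6; sm[i, j] -= 1e-6
        num[i, j] = (loss_eq1(sp, t) - loss_eq1(sm, t)) / 2e-6
assert np.allclose(num, grad_eq1(s, t), atol=1e-5)
```

In a real implementation the exponentials should of course be handled with a logsumexp-style trick for numerical stability; the direct form above is only meant to mirror the formulas.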
After several days of failed attempts, I thought about it from the opposite direction: since the loss obtained by direct analogy cannot be solved, I might as well work backward from the answer, that is, first decide what the solution should be, and then reverse-engineer what the loss should look like. Thus began my second attempt.
First, I observed that both the original multi-label loss and the loss in Eq. $\eqref{eq:loss-1}$ obtained above take the following form:
\begin{equation}l = \sum_j \log \left(1 + \sum_i t_{i,j}e^{- f(s_{i,j})}\right)\label{eq:loss-2}\end{equation}We take this form as our starting point and compute the derivative:
\begin{equation}\frac{\partial l}{\partial s_{i,k}} = \sum_j \frac{- t_{i,j}e^{- f(s_{i,j})}\frac{\partial f(s_{i,j})}{\partial s_{i,k}}}{1 + \sum\limits_{i'} t_{i',j}e^{- f(s_{i',j})}}\end{equation}We hope that $t_{i,j}=\text{softmax}(f(s_{i,j}))=e^{f(s_{i,j})}/Z_i$ is the analytical solution for $\frac{\partial l}{\partial s_{i,k}}=0$, where $Z_i=\sum\limits_k e^{f(s_{i,k})}$. Substituting this in, we get
\begin{equation}0=\frac{\partial l}{\partial s_{i,k}} = \sum_j \frac{- (1/Z_i)\frac{\partial f(s_{i,j})}{\partial s_{i,k}}}{1 + \sum\limits_{i'} 1/Z_{i'}} = \frac{- (1/Z_i)\frac{\partial \left(\sum\limits_j f(s_{i,j})\right)}{\partial s_{i,k}}}{1 + \sum\limits_{i'} 1/Z_{i'}}\end{equation}So, for the above equation to hold automatically, we only need $\sum\limits_j f(s_{i,j})$ to be a constant that does not depend on the scores. For simplicity, let
\begin{equation}f(s_{i,j})=s_{i,j}- \bar{s}_i,\qquad \bar{s}_i=\frac{1}{m}\sum_j s_{i,j}\end{equation}This naturally gives $\sum\limits_j f(s_{i,j})=0$. The corresponding optimization target is
\begin{equation}l = \sum_j \log \left(1 + \sum_i t_{i,j}e^{- s_{i,j} + \bar{s}_i}\right)\label{eq:loss-3}\end{equation}Since $\bar{s}_i$ does not affect the normalization result, its theoretical optimal solution is $t_{i,j}=\text{softmax}(s_{i,j})$.
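As a quick numerical illustration (same assumed array layout as in the earlier sketch), here is Eq. $\eqref{eq:loss-3}$ together with a finite-difference check that $t_{i,j}=\text{softmax}(s_{i,j})$ is indeed a stationary point:

```python
import numpy as np

def loss_eq3(s, t):
    """Sketch of Eq. (loss-3): sum_j log(1 + sum_i t[i, j] * exp(-(s[i, j] - s_bar[i])))."""
    s_bar = s.mean(axis=1, keepdims=True)                   # \bar{s}_i
    inner = (t * np.exp(-(s - s_bar))).sum(axis=0)          # one value per class j
    return np.log1p(inner).sum()

# if t = softmax(s) row-wise, every partial derivative of loss_eq3 w.r.t. s should vanish
rng = np.random.default_rng(0)
t = rng.random((5, 4)) + 0.1
t /= t.sum(axis=1, keepdims=True)                           # random soft labels, rows sum to 1
s = np.log(t)                                               # then softmax(s) == t exactly
for i in range(5):
    for j in range(4):
        sp, sm = s.copy(), s.copy()
        sp[i, j] += 1e-5; sm[i, j] -= 1e-5
        assert abs(loss_eq3(sp, t) - loss_eq3(sm, t)) / 2e-5 < 1e-6
```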
However, while it looks promising, its actual performance is quite poor. Although $t_{i,j}=\text{softmax}(s_{i,j})$ is indeed the theoretical optimum, in practice the performance gets worse the closer the labels are to hard labels. The reason is that for the loss in Eq. $\eqref{eq:loss-3}$, as soon as $s_{i,j} \gg \bar{s}_i$, the corresponding term is already very close to 0. But satisfying $s_{i,j} \gg \bar{s}_i$ does not require $s_{i,j}$ to be the maximum of $s_{i,1},s_{i,2},\cdots,s_{i,m}$: for example, with $m=4$, a target score of $10$ and the remaining scores $20,-50,-50$, we have $\bar{s}_i=-17.5$, so the target term is about $e^{-27.5}\approx 0$ even though the target is not the largest score. Driving the loss to zero therefore does not guarantee correct classification.
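This failure mode can be reproduced directly with the sketches above (a hypothetical $m=4$ example, reusing `loss_eq1` and `loss_eq3` from the code above):

```python
s = np.array([[10., 20., -50., -50.]])   # the labeled class scores 10, but another class scores 20
t = np.array([[1., 0., 0., 0.]])         # hard label on the first class
print(loss_eq3(s, t))                    # ~1.1e-12: Eq. (loss-3) is already ~0 despite the wrong argmax
print(loss_eq1(s, t))                    # ~10.0:    Eq. (loss-1) still penalizes this configuration
```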
We now have two results. Eq. $\eqref{eq:loss-1}$ is an analogical generalization of the original multi-label cross-entropy. It performs reasonably well with hard labels; however, since its analytical solution for soft labels cannot be found, the soft-label case cannot be evaluated theoretically. Eq. $\eqref{eq:loss-3}$ is reverse-engineered from the desired solution: in theory its analytical solution is simply the softmax, but because practical optimization only drives the loss toward zero rather than to the exact stationary point, its performance with hard labels is usually very poor, and it cannot even guarantee that the target logits end up being the largest. Notably, when $m=2$, both Eq. $\eqref{eq:loss-1}$ and Eq. $\eqref{eq:loss-3}$ reduce to the multi-label cross-entropy.
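As a quick check of the $m=2$ claim for Eq. $\eqref{eq:loss-1}$: with hard labels, write $s_i = s_{i,1}-s_{i,2}$ and let $t_{i,1}\in\{0,1\}$ (with $t_{i,2}=1-t_{i,1}$) indicate the positive class; since the inner sum over $k\neq j$ then has a single term, Eq. $\eqref{eq:loss-1}$ becomes
\begin{equation}\log\left(1 + \sum_{i:\,t_{i,1}=1} e^{-s_i}\right) + \log\left(1 + \sum_{i:\,t_{i,2}=1} e^{s_i}\right)\end{equation}which is exactly the multi-label softmax cross-entropy. Eq. $\eqref{eq:loss-3}$ reduces to the same form with $s_i$ replaced by $s_i/2$, since $s_{i,1}-\bar{s}_i = -(s_{i,2}-\bar{s}_i) = (s_{i,1}-s_{i,2})/2$.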
We know that the multi-label cross-entropy automatically mitigates the imbalance between positive and negative samples. Similarly, although we have not yet obtained a perfect generalization, an "n-way m-class" generalization should, in theory, also automatically mitigate the imbalance among the $m$ classes. How does this balancing mechanism work? It is not hard to see: whether in the analogical generalization of Eq. $\eqref{eq:loss-1}$ or in the general form of Eq. $\eqref{eq:loss-2}$, the summation over $i$ sits inside the $\log$. Originally, each class's loss contribution was roughly proportional to the number of samples in that class; moving the summation inside the $\log$ makes each class's contribution grow roughly like the logarithm of its sample count, which narrows the loss gap between classes and automatically alleviates the imbalance.
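To put a rough number on this (a back-of-the-envelope estimate, assuming hard labels and comparable margins within a class): if class $j$ contains $n_j$ samples and each of them has $f(s_{i,j})\approx\delta$, then the class-$j$ term of Eq. $\eqref{eq:loss-2}$ is
\begin{equation}\log\left(1 + n_j e^{-\delta}\right)\approx \log n_j - \delta\qquad (n_j e^{-\delta}\gg 1)\end{equation}so each class contributes on the order of $\log n_j$, whereas a per-sample cross-entropy summed outside the $\log$ would contribute roughly $n_j$ times the average per-sample loss.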
Regrettably, this article has not reached a perfect "n-way m-class" generalization, one that should have two properties: 1. it automatically regulates class imbalance by moving the summation over samples inside the $\log$; 2. it admits an analytical solution for soft labels. For hard labels, using Eq. $\eqref{eq:loss-1}$ directly should be sufficient; but for soft labels, I am truly at a loss. I welcome interested readers to think about and discuss this together.
This article attempted to generalize the earlier multi-label cross-entropy to "n-way m-class" classification. Unfortunately, the generalization was not entirely successful. I am sharing the results here for now, hoping that interested readers will join in improving them.