By 苏剑林 | July 19, 2020
Class imbalance, also known as the "long-tail problem," is one of the common challenges faced in machine learning, especially with datasets derived from real-world scenarios, which are almost always imbalanced. About two years ago, I also thought about this problem. At the time, I happened to have some insights into "mutual information," so I conceived a solution based on that idea. However, upon further reflection, the logic seemed too commonplace, so I didn't pursue it further. Yet, a few days ago, I came across an article from Google on arXiv titled "Long-tail learning via logit adjustment" and was surprised to find that it contained almost exactly the same method I had originally conceived. This made me realize that my initial idea could actually achieve SOTA (state-of-the-art) performance. Therefore, combining that paper with my original thought process, I have organized the following content, hoping readers won't dismiss it as "Monday morning quarterbacking."
Problem Description
The primary concern here is the single-label multi-class classification problem. Suppose there are $K$ candidate categories $1, 2, \dots, K$, the training data is $(x, y) \sim \mathcal{D}$, and the modeled distribution is $p_{\theta}(y|x)$. Our optimization goal is maximum likelihood, i.e., minimizing the cross-entropy:
\begin{equation}\mathop{\text{argmin}}_{\theta}\,\mathbb{E}_{(x,y)\sim\mathcal{D}}[-\log p_{\theta}(y|x)]\end{equation}
Typically, the final step of our probability model is a softmax. Assuming the result before softmax is $f(x;\theta)$ (i.e., logits), then:
\begin{equation}-\log p_{\theta}(y|x)=-\log \frac{e^{f_y(x;\theta)}}{\sum\limits_{i=1}^K e^{f_i(x;\theta)}}=\log\left[1 + \sum_{i\neq y}e^{f_i(x;\theta) - f_y(x;\theta)}\right]\label{eq:loss-1}\end{equation}
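As a quick numerical sanity check of this identity, the following toy snippet (with made-up logits) verifies that the two sides agree:

import numpy as np

f = np.array([2.0, -1.0, 0.5])  # toy logits f(x; theta)
y = 0                           # toy target class
lhs = -np.log(np.exp(f[y]) / np.exp(f).sum())         # -log softmax_y
rhs = np.log1p(np.exp(np.delete(f, y) - f[y]).sum())  # log[1 + sum_{i!=y} e^{f_i - f_y}]
assert np.isclose(lhs, rhs)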
The so-called class imbalance refers to certain categories having an exceptionally large number of samples, much like "20% of the people hold 80% of the wealth." The remaining categories are numerous, but their total sample count is small. If sorted from high to low, it looks like a long "tail," hence the name long-tail phenomenon. In this situation, when we sample a batch during training, there is rarely an opportunity to sample low-frequency categories; thus, the model easily ignores these classes. However, during evaluation, we usually care more about the recognition performance of these low-frequency categories. This is the root of the contradiction.
Common Approaches
You may well have heard of the common approaches; they generally fall into three directions:
1. Starting from data: Using techniques like over-sampling or down-sampling to make each batch more balanced.
2. Starting from loss: A classic approach is to divide the loss of a sample of category $y$ by the frequency $p(y)$ of that category (a minimal sketch follows this list).
3. Starting from results: Modifying the prediction phase of a normally trained model to favor low-frequency categories. For example, if there are far fewer positive samples than negative samples, we might treat any prediction result greater than 0.2 (instead of 0.5) as positive.
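For direction 2, a minimal standalone-Keras sketch of this conventional re-weighting might look as follows (the 4-class prior is a hypothetical placeholder; y_pred are logits):

import numpy as np
import keras.backend as K

prior = np.array([0.7, 0.2, 0.07, 0.03])  # hypothetical class frequencies p(y)

def reweighted_sparse_categorical_crossentropy(y_true, y_pred):
    # Ordinary per-sample cross-entropy, then divided by the frequency
    # p(y) of each sample's class -- re-weighting "outside the log".
    loss = K.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
    weights = K.gather(K.constant(1.0 / prior), K.cast(K.flatten(y_true), 'int32'))
    return loss * weights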
Google's original paper lists many references for these three directions. Interested readers can read the original paper directly. Additionally, the Zhihu article "Long-Tailed Classification (2): Latest Research on Classification under Long-Tailed Distribution" also introduces this problem, and readers may refer to it as well.
Learning Mutual Information
Recall how we determine that a classification problem is imbalanced. Obviously, the general approach is to calculate the frequency $p(y)$ of each category from the entire training set and discover that $p(y)$ is concentrated in a few categories. Therefore, the key to solving the class imbalance problem is how to integrate this prior knowledge $p(y)$ into the model.
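Estimating $p(y)$ itself is just counting. As a minimal sketch (train_labels being a hypothetical array of integer labels for the whole training set):

import numpy as np

num_classes = 10                                          # hypothetical
train_labels = np.random.randint(0, num_classes, 100000)  # stand-in for real labels
prior = np.bincount(train_labels, minlength=num_classes) / len(train_labels)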
When I was previously conceiving word vector models (as in the article "A More Elegant Word Vector Model (2): Modeling Language"), I emphasized that, compared to fitting the conditional probability, a model that directly fits mutual information learns more essential knowledge, since mutual information is the indicator that reveals the core associations. However, mutual information is not easy to fit directly; what is easy to train is the conditional probability, via the cross-entropy $-\log p_{\theta}(y|x)$. So the ideal would be: keep cross-entropy as the loss, while having the model essentially fit mutual information.
In Equation $\eqref{eq:loss-1}$, we modeled
\begin{equation}p_{\theta}(y|x)=\frac{e^{f_y(x;\theta)}}{\sum\limits_{i=1}^K e^{f_i(x;\theta)}}\end{equation}
Now we change to modeling mutual information, which means we hope
\begin{equation}\log \frac{p_{\theta}(y|x)}{p(y)}\sim f_y(x;\theta)\quad \Leftrightarrow\quad \log p_{\theta}(y|x)\sim f_y(x;\theta) + \log p(y)\end{equation}
According to the form on the right, normalizing with softmax again, we have $p_{\theta}(y|x)=\frac{e^{f_y(x;\theta)+\log p(y)}}{\sum\limits_{i=1}^K e^{f_i(x;\theta)+\log p(i)}}$, or written in loss form:
\begin{equation}-\log p_{\theta}(y|x)=-\log \frac{e^{f_y(x;\theta)+\log p(y)}}{\sum\limits_{i=1}^K e^{f_i(x;\theta)+\log p(i)}}=\log\left[1 + \sum_{i\neq y}\frac{p(i)}{p(y)}e^{f_i(x;\theta) - f_y(x;\theta)}\right]\label{eq:loss-2}\end{equation}
The original paper calls this the "logit adjustment loss." Intuitively, when $y$ is a rare class, the ratio $p(i)/p(y)$ is large for frequent classes $i$, so making this loss small requires the margin $f_y(x;\theta) - f_i(x;\theta)$ to exceed roughly $\log\frac{p(i)}{p(y)}$: rare classes are pushed to have larger margins. More generally, an adjustment factor $\tau$ can be added:
\begin{equation}-\log p_{\theta}(y|x)=-\log \frac{e^{f_y(x;\theta)+\tau\log p(y)}}{\sum\limits_{i=1}^K e^{f_i(x;\theta)+\tau\log p(i)}}=\log\left[1 + \sum_{i\neq y}\left(\frac{p(i)}{p(y)}\right)^{\tau}e^{f_i(x;\theta) - f_y(x;\theta)}\right]\label{eq:loss-3}\end{equation}
Generally, $\tau=1$ already achieves near-optimal results. If the final layer producing $f_y(x;\theta)$ has a bias term, the simplest implementation is to initialize that bias to $\tau\log p(y)$ (a sketch of this appears after the code below). The adjustment can also be written into the loss function:
import numpy as np
import keras.backend as K


def categorical_crossentropy_with_prior(y_true, y_pred, tau=1.0):
    """Cross-entropy with prior distribution
    Note: y_pred should not have softmax applied
    """
    prior = xxxxxx  # Define your own prior, shape [num_classes]
    log_prior = K.constant(np.log(prior + 1e-8))
    for _ in range(K.ndim(y_pred) - 1):
        log_prior = K.expand_dims(log_prior, 0)
    y_pred = y_pred + tau * log_prior
    return K.categorical_crossentropy(y_true, y_pred, from_logits=True)


def sparse_categorical_crossentropy_with_prior(y_true, y_pred, tau=1.0):
    """Sparse cross-entropy with prior distribution
    Note: y_pred should not have softmax applied
    """
    prior = xxxxxx  # Define your own prior, shape [num_classes]
    log_prior = K.constant(np.log(prior + 1e-8))
    for _ in range(K.ndim(y_pred) - 1):
        log_prior = K.expand_dims(log_prior, 0)
    y_pred = y_pred + tau * log_prior
    return K.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
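For the bias-initialization route mentioned above, a minimal sketch might look as follows (standalone Keras; prior and tau are the same hypothetical placeholders as before, and only the bias of the final Dense layer is touched):

import numpy as np
import keras.backend as K
from keras.layers import Dense

prior = np.array([0.7, 0.2, 0.07, 0.03])  # hypothetical class frequencies p(y)
tau = 1.0

def prior_bias_initializer(shape, dtype=None):
    # Start the output bias at tau * log p(y), so training begins from
    # the logit-adjusted parameterization.
    return K.constant(tau * np.log(prior + 1e-8), dtype=dtype)

logits_layer = Dense(len(prior), bias_initializer=prior_bias_initializer)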
Result Analysis
Clearly, logit adjustment loss belongs to the category of loss adjustment schemes. The difference is that it adjusts weights inside the $\log$, whereas conventional ideas adjust outside the $\log$. As for its advantages, they are the advantages of mutual information: mutual information reveals truly important associations. By supplementing the logits with the bias of the prior distribution, the model is encouraged to "let the prior resolve what it can, and let the model resolve the essential parts that the prior cannot."
In the prediction phase, we can formulate different prediction schemes according to different evaluation metrics. From "Talk on Function Smoothing: Differentiable Approximation of Non-differentiable Functions", we know that for overall accuracy, we have the approximation:
\begin{equation}\text{Overall Accuracy} \approx \frac{1}{N}\sum_{i=1}^N p_{\theta}(y_i|x_i)\end{equation}
where $\{(x_i, y_i)\}_{i=1}^N$ is the validation set. So, if we do not consider class imbalance and pursue higher overall accuracy, for each $x$, we can simply output the category with the largest $p_{\theta}(y|x)$. But if we hope that the accuracy of each class is as high as possible, we rewrite the above formula as:
\begin{equation}\text{Overall Accuracy} \approx \frac{1}{N}\sum_{i=1}^N \frac{p_{\theta}(y_i|x_i)}{p(y_i)}\times p(y_i)=\sum_{y=1}^K p(y)\left(\frac{1}{N}\sum_{x_i\in\Omega_y} \frac{p_{\theta}(y|x_i)}{p(y)}\right)\end{equation}
where $\Omega_y=\{x_i|y_i=y, i=1,2,\dots,N\}$ is the set of samples labeled $y$. The right-hand side is simply the result of grouping terms with the same $y$. We know that "Overall Accuracy = weighted average of per-class accuracies," and the above formula has exactly this form, so the term in parentheses, $\frac{1}{N}\sum\limits_{x_i\in\Omega_y} \frac{p_{\theta}(y|x_i)}{p(y)}$, is an approximation of the "per-class accuracy." Therefore, if we want the (unweighted) per-class accuracies to be as high as possible, we should output the category that maximizes $\frac{p_{\theta}(y|x)}{p(y)}$. Combined with the form of $p_{\theta}(y|x)$, we arrive at the conclusion:
\begin{equation}y^{*}=\left\{\begin{aligned}&\mathop{\text{argmax}}\limits_y\, f_y(x;\theta)+\tau\log p(y),\quad(\text{pursuing overall accuracy})\\
&\mathop{\text{argmax}}\limits_y\, f_y(x;\theta),\quad(\text{pursuing balanced per-class accuracy})
\end{aligned}\right.\end{equation}
The first strategy outputs the class with maximum conditional probability, while the second outputs the class with maximum mutual information; which one to use depends on the specific need.
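In code, the two decision rules differ only in whether the $\tau\log p(y)$ term is added back at prediction time; a small numpy sketch (logits being the model's $f(x;\theta)$ outputs, prior and tau as before):

import numpy as np

def adjusted_predict(logits, prior, tau=1.0, balanced=True):
    # balanced=True : argmax_y f_y(x)                    (balanced per-class accuracy)
    # balanced=False: argmax_y f_y(x) + tau * log p(y)   (overall accuracy)
    if balanced:
        return np.argmax(logits, axis=-1)
    return np.argmax(logits + tau * np.log(prior), axis=-1)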
As for detailed experimental results, readers can look at the paper. In short, it is surprisingly good:
[Figure: experimental results from the original paper]
Article Summary
This article briefly introduced a mutual-information-based method for handling class imbalance. I had conceived this scheme previously but did not pursue it; a recent Google paper proposes the same method, so I have recorded and analyzed it here. The experimental results reported by Google show that this method can reach SOTA levels.