From Hard Truncation and Softening of the Loss to Focal Loss

By 苏剑林 | December 25, 2017

Preface

In a discussion in a QQ group today, I came across "Focal Loss." After some searching, I found it is a loss function proposed by Kaiming He's team in their paper "Focal Loss for Dense Object Detection," which they used to improve the results of image object detection. I rarely work on image-related tasks myself and don't pay much attention to image applications, but essentially Focal Loss is a loss function designed to solve the problems of class imbalance and differences in classification difficulty in classification tasks, so it is by no means limited to images. In short, this work has been widely acclaimed. You can also refer to the discussion on Zhihu: "How to evaluate Kaiming He's Focal Loss for Dense Object Detection?"

When I first saw this loss, it felt quite magical and seemed very useful. This is because in Natural Language Processing (NLP), there are also many tasks with significant class imbalance. The most classic example is sequence labeling, where categories are highly unbalanced. For instance, in Named Entity Recognition (NER), it's obvious that entities are much scarcer than non-entities within a sentence. I tried applying it to my sequence-labeling-based QA model and saw a slight improvement. Indeed, it is a good loss.

As I compared it more carefully, I realized that this loss shares a striking similarity with a loss I was conceptualizing last night! This prompted me to write this blog post. I will start from my own perspective to analyze the issue, ultimately arriving at Focal Loss, while also presenting the similar loss I derived last night.

Hard Truncation

The entire article is based on binary classification; the same ideas can be applied to multi-class problems. The standard loss for binary classification is cross-entropy: \[L_{ce} = -y\log \hat{y} - (1-y)\log(1-\hat{y})=\left\{\begin{aligned}&-\log(\hat{y}),\,\text{when }y=1\\ &-\log(1-\hat{y}),\,\text{when }y=0\end{aligned}\right.\] where $y\in\{0,1\}$ is the true label and $\hat{y}$ is the predicted value. Of course, for binary classification, we almost always use the sigmoid function for activation $\hat{y}=\sigma(x)$, so it is equivalent to: \[L_{ce} = -y\log \sigma(x) - (1-y)\log\sigma(-x)=\left\{\begin{aligned}&-\log \sigma(x),\,\text{when }y=1\\ &-\log\sigma(-x),\,\text{when }y=0\end{aligned}\right.\] (Note that $1-\sigma(x)=\sigma(-x)$.)
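As a quick sanity check, here is a minimal PyTorch sketch (my own illustration, not from the original post) verifying that the two-branch logit form above agrees with the library's built-in binary cross-entropy; the example tensors are arbitrary:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([2.0, -1.0, 0.5])  # logits (pre-sigmoid scores), arbitrary values
y = torch.tensor([1.0, 0.0, 1.0])   # true labels in {0, 1}

# Cross-entropy written directly from the two-branch form above:
# -y*log σ(x) - (1-y)*log σ(-x), using 1 - σ(x) = σ(-x).
ce_manual = -y * F.logsigmoid(x) - (1 - y) * F.logsigmoid(-x)

# The same quantity via PyTorch's numerically stable built-in.
ce_builtin = F.binary_cross_entropy_with_logits(x, y, reduction="none")

assert torch.allclose(ce_manual, ce_builtin)
```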

In a blog post from the first half of the year, "Text Sentiment Classification (IV): Better Loss Functions," I proposed a "hard truncation" loss aimed at "focusing energy on hard-to-classify samples," taking the form: \[L^\cdot = \lambda(y,\hat{y})\cdot L_{ce}\] where \[\lambda(y,\hat{y})=\left\{\begin{aligned}&0,\,(y=1\text{ and }\hat{y} > 0.5)\text{ or }(y=0\text{ and }\hat{y} < 0.5)\\ &1,\,\text{other cases}\end{aligned}\right.\] The idea is: once a positive sample's prediction exceeds 0.5, or a negative sample's falls below 0.5, I simply stop updating on that sample and focus attention on the ones that are still predicted incorrectly. (The 0.5 threshold can of course be adjusted.) This method partially achieves the goal, but the number of iterations required increases significantly; a minimal sketch is given below.
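Here is the sketch of this hard truncation in PyTorch (the function name and masking details are mine, assuming logits as input):

```python
import torch
import torch.nn.functional as F

def hard_truncated_loss(x, y, threshold=0.5):
    """Hard truncation: zero the cross-entropy for samples that are
    already classified correctly (hypothetical helper, for illustration)."""
    y_hat = torch.sigmoid(x)
    ce = F.binary_cross_entropy_with_logits(x, y, reduction="none")
    # λ(y, ŷ): 0 when (y=1 and ŷ > t) or (y=0 and ŷ < t), else 1.
    correct = ((y == 1) & (y_hat > threshold)) | ((y == 0) & (y_hat < threshold))
    lam = (~correct).float()
    # Note: lam comes from comparisons, so no gradient flows through it --
    # exactly the non-differentiability discussed below.
    return (lam * ce).mean()
```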

The reason for the extra iterations is as follows: taking positive samples as an example, I only tell the model not to update once the prediction exceeds 0.5, but I never tell it to keep the prediction above 0.5. Consequently, in the next stage the prediction may easily fall back below 0.5, at which point it gets updated again in the following round. This back-and-forth can theoretically reach the goal, but it greatly increases the number of iterations. The key to improving it is therefore: don't just tell the model to stop updating when a positive prediction is above 0.5; tell it that once the prediction is above 0.5, it only needs to maintain that state. (It's like a teacher ignoring a student once they pass: that doesn't work. If a student has passed, the goal should be to help them maintain or even improve, rather than ignoring them.)

Softening the Loss

Hard truncation fails because the factor $\lambda(y,\hat{y})$ is non-differentiable, or rather, we treat its derivative as 0. Thus, this term provides no help for the gradient, and we cannot get reasonable feedback from it (meaning the model doesn't know what "maintaining" means).

One way to solve this is to "soften" the loss. "Softening" involves approximating non-differentiable functions with differentiable ones—mathematically known as "smoothing." After this treatment, things that were previously non-differentiable become differentiable. A similar example can be found in the K-means section of "Gradient Descent and EM Algorithm: From the Same Origin." First, let's rewrite $L^\cdot$: \[L^\cdot =\left\{\begin{aligned}&-\theta(0.5-\hat{y})\log(\hat{y}),\,\text{when }y=1\\ &-\theta(\hat{y}-0.5)\log(1-\hat{y}),\,\text{when }y=0\end{aligned}\right.\] Here, $\theta$ is the unit step function: \[\theta(x) = \left\{\begin{aligned}&1, x > 0\\ &\frac{1}{2}, x = 0\\ &0, x < 0\end{aligned}\right.\] This $L^\cdot$ is completely equivalent to the original. It is also equivalent to (since $\sigma(0)=0.5$): \[L^\cdot =\left\{\begin{aligned}&-\theta(-x)\log \sigma(x),\,\text{when }y=1\\ &-\theta(x)\log\sigma(-x),\,\text{when }y=0\end{aligned}\right.\] The logic becomes clear: to "soften" this loss, we must "soften" $\theta(x)$. Softening it is easy; it's the sigmoid function! We have: \[\theta(x) = \lim_{K\to +\infty} \sigma(Kx)\] So, it is obvious that we can replace $\theta(x)$ with $\sigma(Kx)$: \[L^{\cdot \cdot }=\left\{\begin{aligned}&-\sigma(-Kx)\log \sigma(x),\,\text{when }y=1\\ &-\sigma(Kx)\log\sigma(-x),\,\text{when }y=0\end{aligned}\right.\] This is the loss I arrived at last night, and it is obviously easy to implement.
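A minimal sketch of this softened loss in PyTorch (my own code; the function name and the default $K$ are illustrative):

```python
import torch
import torch.nn.functional as F

def softened_loss(x, y, K=1.0):
    """L··: the step function θ is replaced by σ(K·x), so the weighting
    factor itself is differentiable and tells the model to 'maintain'
    correct predictions rather than just ignoring them."""
    pos = -torch.sigmoid(-K * x) * F.logsigmoid(x)   # y = 1 branch
    neg = -torch.sigmoid(K * x) * F.logsigmoid(-x)   # y = 0 branch
    return (y * pos + (1 - y) * neg).mean()
```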

Now, let's compare it with Focal Loss.

Focal Loss

The form of Focal Loss proposed by Kaiming He is: \[L_{fl}=\left\{\begin{aligned}&-(1-\hat{y})^{\gamma}\log \hat{y},\,\text{when }y=1\\ &-\hat{y}^{\gamma}\log (1-\hat{y}),\,\text{when }y=0\end{aligned}\right.\] If we use the prediction $\hat{y}=\sigma(x)$, then: \[L_{fl}=\left\{\begin{aligned}&-\sigma^{\gamma}(-x)\log \sigma(x),\,\text{when }y=1\\ &-\sigma^{\gamma}(x)\log\sigma(-x),\,\text{when }y=0\end{aligned}\right.\] In particular, if both $K$ and $\gamma$ are set to 1, then $L^{\cdot \cdot}=L_{fl}$!
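In the same notation, a minimal sketch of binary Focal Loss from logits (again my own illustration; $\gamma$ is the paper's focusing parameter):

```python
import torch
import torch.nn.functional as F

def focal_loss(x, y, gamma=2.0):
    """Binary Focal Loss: -σ(-x)^γ log σ(x) for y=1, -σ(x)^γ log σ(-x) for y=0."""
    pos = -torch.sigmoid(-x).pow(gamma) * F.logsigmoid(x)   # (1-ŷ)^γ = σ(-x)^γ
    neg = -torch.sigmoid(x).pow(gamma) * F.logsigmoid(-x)   # ŷ^γ = σ(x)^γ
    return (y * pos + (1 - y) * neg).mean()
```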

In fact, the roles of $K$ and $\gamma$ are identical: they both adjust the steepness of the weight curve, just through different mechanisms. Note that $L^{\cdot \cdot}$ or $L_{fl}$ already contains the solution for unbalanced samples; or rather, class imbalance is essentially a manifestation of differences in classification difficulty. For example, if negative samples far outnumber positive samples, the model will certainly lean toward the majority negative class (one can imagine predicting all samples as negative). In this case, the $\hat{y}^{\gamma}$ or $\sigma(Kx)$ for the negative class will be very small, while the $(1-\hat{y})^{\gamma}$ or $\sigma(-Kx)$ for the positive class will be very large. At this point, the model will start to focus its energy on the positive samples.

Of course, Kaiming He also found that applying a class weight to $L_{fl}$ yields a slight further improvement: \[L_{fl}=\left\{\begin{aligned}&-\alpha(1-\hat{y})^{\gamma}\log \hat{y},\,\text{when }y=1\\ &-(1-\alpha)\hat{y}^{\gamma}\log (1-\hat{y}),\,\text{when }y=0\end{aligned}\right.\] Through a series of tuning experiments, he found that $\alpha=0.25, \gamma=2$ worked best (on his model). Note that in his task the positive samples are the minority class; that is, while positive samples originally couldn't "compete" with negative samples, after being reweighted by $(1-\hat{y})^{\gamma}$ and $\hat{y}^{\gamma}$ the situation may even have reversed, so the positive samples' weight needs to be reduced. However, I believe such adjustments are empirical results; theoretically, it is difficult to find a principled strategy for choosing $\alpha$. Without massive computing power for tuning, it might be better to simply set $\alpha=0.5$ (equal weighting).
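The $\alpha$-balanced variant only adds a class weight to the sketch above ($\alpha=0.25, \gamma=2$ being the values reported in the paper; the function itself is my illustration):

```python
import torch
import torch.nn.functional as F

def alpha_focal_loss(x, y, alpha=0.25, gamma=2.0):
    """α-balanced Focal Loss: α weights the positive term, (1-α) the negative."""
    pos = -alpha * torch.sigmoid(-x).pow(gamma) * F.logsigmoid(x)
    neg = -(1 - alpha) * torch.sigmoid(x).pow(gamma) * F.logsigmoid(-x)
    return (y * pos + (1 - y) * neg).mean()
```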

Multi-class

The multi-class form of Focal Loss is also easy to obtain: \[L_{fl}=-(1-\hat{y}_t)^{\gamma}\log \hat{y}_t\] where $\hat{y}_t$ is the predicted probability of the target class, usually the output of a softmax. How can the $L^{\cdot \cdot}$ I conceived be generalized to the multi-class case? That is simple too: \[L^{\cdot \cdot }=-\text{softmax}(-Kx_t)\log \text{softmax}(x_t)\] Here, $x_t$ is again the prediction for the target class, but taken before the softmax (i.e., the logit); $\text{softmax}(\cdot)$ is understood componentwise, so $\text{softmax}(-Kx_t)$ is the target-class component of the softmax of $-Kx$.
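Both multi-class forms are straightforward to sketch as well, assuming logits of shape (batch, num_classes) and integer class targets (the function names, and my componentwise reading of the softmax notation, are mine):

```python
import torch
import torch.nn.functional as F

def multiclass_focal_loss(logits, target, gamma=2.0):
    """Multi-class Focal Loss: -(1 - p_t)^γ log p_t, with p_t the softmax
    probability of the target class."""
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(-1, target.unsqueeze(-1)).squeeze(-1)  # log p_t
    pt = log_pt.exp()
    return (-(1 - pt).pow(gamma) * log_pt).mean()

def multiclass_softened_loss(logits, target, K=1.0):
    """The L·· generalization: weight -log softmax(x)_t by softmax(-K·x)_t."""
    log_pt = F.log_softmax(logits, dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1)
    w = F.softmax(-K * logits, dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1)
    return (-w * log_pt).mean()
```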

Conclusion

What? You came up with something identical to Kaiming He's idea? No, no, no. This article is merely an introduction to Kaiming He's Focal Loss. More accurately, it is an introduction to some schemes for dealing with classification imbalance and difficulty differences, while providing my own perspective as much as possible. Of course, writing it this way might seem a bit pretentious or like a poor imitation; I ask for the readers' understanding.