Chatting about Multi-Task Learning (I): In the Name of Loss

By 苏剑林 | January 18, 2022

There are many methods to improve model performance, and Multi-Task Learning (MTL) is one of them. Simply put, MTL aims to train multiple related tasks together, hoping that the tasks can complement and promote each other to achieve better results (accuracy, robustness, etc.) than a single task. However, MTL is not as simple as stacking all tasks together; how to balance the training of each task so that each task obtains a beneficial improvement remains a subject worthy of research.

Recently, by coincidence, I have also made some attempts at multi-task learning and took the opportunity to learn about it. I have selected some results to exchange and discuss with everyone here.

Weighted Sum

From the perspective of the loss function, multi-task learning involves multiple loss functions $\mathcal{L}_1,\mathcal{L}_2,\cdots,\mathcal{L}_n$. Generally, they have a large number of shared parameters and a small number of independent parameters. Our goal is to make each loss function as small as possible. To this end, we introduce weights $\alpha_1,\alpha_2,\cdots,\alpha_n\geq 0$ and convert it into a single-task learning problem using a weighted sum as follows:

\begin{equation}\mathcal{L} = \sum_{i=1}^n \alpha_i \mathcal{L}_i\label{eq:w-loss}\end{equation}

From this perspective, the main difficulty of multi-task learning is how to determine the weights $\alpha_i$ for each task.
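As a minimal sketch of this weighted combination (PyTorch; the per-task losses and weights are placeholder inputs, not from the original post):

```python
import torch

def weighted_multitask_loss(losses, alphas):
    """Combine per-task losses into one scalar, as in the weighted sum above.

    losses: list of scalar tensors [L_1, ..., L_n]
    alphas: list of non-negative weights [alpha_1, ..., alpha_n]
    """
    return sum(a * l for a, l in zip(alphas, losses))

# Usage with hypothetical losses:
#   weighted_multitask_loss([loss_cls, loss_reg], [0.5, 0.5]).backward()
```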

Initial State

In theory, without task priors or biases, the most natural choice is to treat each task equally, i.e., $\alpha_i=1/n$. However, in reality, tasks can differ significantly: for example, a mixture of classification tasks with different numbers of categories, classification mixed with regression, classification mixed with generation, and so on. From a physical perspective, the dimensions and magnitudes of the loss functions differ, so adding them directly is not particularly meaningful.

If we treat each loss function as a physical quantity with different dimensions, then starting from the idea of "nondimensionalization," we can use the inverse of the initial value of the loss function as the weight, i.e.,

\begin{equation}\mathcal{L} = \sum_{i=1}^n \frac{\mathcal{L}_i}{\mathcal{L}_i^{(\text{init})}}\label{eq:init}\end{equation}

where $\mathcal{L}_i^{(\text{init})}$ represents the initial loss value of task $i$. This formula is "homogeneous" with respect to each $\mathcal{L}_i$, so one obvious advantage is scale invariance. That is, if you multiply the loss of task $i$ by a constant, the result will not change. Furthermore, since each loss is divided by its own initial value, larger losses are scaled down and smaller losses are scaled up, allowing each loss to be roughly balanced.

So, how do we estimate $\mathcal{L}_i^{(\text{init})}$? The most direct method is, of course, to estimate it using a few batches of data. Alternatively, we can derive a theoretical value based on certain assumptions. For instance, under mainstream initialization schemes, we can assume the initial model output (before the activation function) is a zero vector; applying softmax then yields a uniform distribution. Thus, for a "$K$-category classification + cross-entropy" problem, the initial loss is $\log K$. For a "regression + L2 loss" problem, the initial loss can be estimated using the zero vector: $\mathbb{E}_{y\sim \mathcal{D}}[\Vert y-0\Vert^2] = \mathbb{E}_{y\sim \mathcal{D}}[\Vert y\Vert^2]$, where $\mathcal{D}$ denotes the distribution of labels in the training set.
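A small sketch of these theoretical estimates (the function names and the `labels` tensor are illustrative, not from the original post):

```python
import math
import torch

def init_loss_classification(num_classes):
    # "K-category classification + cross-entropy": a uniform prediction gives log K
    return math.log(num_classes)

def init_loss_regression(labels):
    # "regression + L2 loss": predicting the zero vector gives E[||y||^2]
    # labels: tensor of shape (num_samples, dim) sampled from the training set
    return labels.pow(2).sum(dim=-1).mean().item()
```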

Prior State

One issue with using the initial loss is that the initial state may not accurately reflect the current task's learning difficulty. A better approach is to change the "initial state" to a "prior state":

\begin{equation}\mathcal{L} = \sum_{i=1}^n \frac{\mathcal{L}_i}{\mathcal{L}_i^{(\text{prior})}}\label{eq:prior}\end{equation}

For example, if the class frequencies in a $K$-class classification task are $[p_1,p_2,\dots,p_K]$ (the prior distribution), then although the initial state predicts a uniform distribution, we can reasonably assume the model will quickly learn to predict $[p_1,p_2,\dots,p_K]$ for every sample. At that point, the model's loss is the entropy of the prior:

\begin{equation}\mathcal{L}_i^{(\text{prior})}=\mathcal{H} = -\sum_{k=1}^K p_k\log p_k\end{equation}

In a sense, the "prior distribution" captures the idea of "initial" better than the "initial distribution" does: it represents what the model knows even before learning anything else, namely that it should output results according to the prior distribution. The loss value at this point therefore better reflects the initial difficulty of the task, so replacing $\mathcal{L}_i^{(\text{init})}$ with $\mathcal{L}_i^{(\text{prior})}$ should be more reasonable. Similarly, for a "regression + L2 loss" problem, the prior prediction should be the expectation of all labels, $\mu = \mathbb{E}_{y\sim \mathcal{D}}[y]$, so we use $\mathcal{L}_i^{(\text{prior})}=\mathbb{E}_{y\sim \mathcal{D}}[\Vert y-\mu\Vert^2]$ in place of $\mathcal{L}_i^{(\text{init})}=\mathbb{E}_{y\sim \mathcal{D}}[\Vert y\Vert^2]$, which should yield more reasonable results.
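The prior-state counterparts can be estimated analogously; a sketch under the same assumptions (illustrative names):

```python
import torch

def prior_loss_classification(class_freqs):
    # Entropy of the prior class distribution [p_1, ..., p_K]
    p = torch.as_tensor(class_freqs, dtype=torch.float32)
    return -(p * p.log()).sum().item()

def prior_loss_regression(labels):
    # L2 loss of always predicting the label mean mu = E[y]
    mu = labels.mean(dim=0, keepdim=True)
    return (labels - mu).pow(2).sum(dim=-1).mean().item()
```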

Dynamic Adjustment

Regardless of whether formula $\eqref{eq:init}$ or formula $\eqref{eq:prior}$ is used, the task weights remain fixed once determined, and the method for determining them does not depend on the learning process. However, although we can roughly perceive task difficulty through prior distributions, the true difficulty can only be known during actual learning. Therefore, a more reasonable approach should dynamically adjust weights according to the training process.

Real-time State

Reviewing the previous content, the core idea of formulas $\eqref{eq:init}$ and $\eqref{eq:prior}$ is to use the reciprocal of (some state of) the loss as the task weight. Can we then simply use the reciprocal of the "real-time" loss value to achieve dynamic weight adjustment? That is,

\begin{equation}\mathcal{L} = \sum_{i=1}^n \frac{\mathcal{L}_i}{\mathcal{L}_i^{(\text{sg})}}\label{eq:sg}\end{equation}

Here, $\mathcal{L}_i^{(\text{sg})}$ is shorthand for $\text{stop\_gradient}(\mathcal{L}_i)$. In this scheme, the loss of each task is rescaled to be identically 1, so the tasks are consistent in both dimension and magnitude. Due to the presence of the $\text{stop\_gradient}$ operator, although the loss is constant at 1, the gradient is not zero:

\begin{equation}\nabla_{\theta}\left(\frac{\mathcal{L}_i}{\mathcal{L}_i^{(\text{sg})}}\right) = \frac{\nabla_{\theta}\mathcal{L}_i}{\mathcal{L}_i^{(\text{sg})}} = \frac{\nabla_{\theta}\mathcal{L}_i}{\mathcal{L}_i}\label{eq:sg-grad}\end{equation}

Simply put, when a function is wrapped by the $\text{stop\_gradient}$ operator, it becomes a new function whose value is identical to the original function, but its derivative is forced to zero. The final result is that the gradient proportions are adjusted in real-time using the dynamic weight $1/\mathcal{L}_i$. Many "informal experiments" indicate that formula $\eqref{eq:sg}$ serves as a very good baseline in most cases.
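In frameworks such as PyTorch, $\text{stop\_gradient}$ corresponds to `detach()`, so formula $\eqref{eq:sg}$ can be written in one line; a minimal sketch (the small `eps` is my own addition to guard against division by zero):

```python
import torch

def sg_weighted_loss(losses, eps=1e-12):
    """Eq. (sg): divide each task loss by its own detached (stop-gradient) value.

    The total is numerically constant (= n), but each task still contributes
    the gradient grad(L_i) / L_i, i.e. the gradient of log L_i.
    """
    return sum(l / (l.detach() + eps) for l in losses)

# Usage with hypothetical losses: sg_weighted_loss([loss_cls, loss_reg]).backward()
```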

Equivalent Gradient

We can look at this scheme from another angle. From formula $\eqref{eq:sg-grad}$, we get:

\begin{equation}\nabla_{\theta}\left(\frac{\mathcal{L}_i}{\mathcal{L}_i^{(\text{sg})}}\right) = \frac{\nabla_{\theta}\mathcal{L}_i}{\mathcal{L}_i} = \nabla_{\theta} \log \mathcal{L}_i\end{equation}

Therefore, in terms of gradients, formula $\eqref{eq:sg}$ is essentially no different from $\mathcal{L} = \sum\limits_{i=1}^n \log \mathcal{L}_i$. Furthermore, we have:

\begin{equation}\mathcal{L} = \sum_{i=1}^n \log \mathcal{L}_i = n\log \sqrt[n]{\prod_{i=1}^n\mathcal{L}_i}\end{equation}

Since $\log$ is monotonically increasing, formula $\eqref{eq:sg}$ is consistent in gradient direction with the following:

\begin{equation}\mathcal{L} = \sqrt[n]{\prod_{i=1}^n\mathcal{L}_i}\end{equation}

Generalized Mean

Clearly, the formula above is the "geometric mean" of $\mathcal{L}_1,\mathcal{L}_2,\cdots,\mathcal{L}_n$. If we fix $\alpha_i = 1/n$, the original formula $\eqref{eq:w-loss}$ is the "arithmetic mean" of $\mathcal{L}_1,\mathcal{L}_2,\cdots,\mathcal{L}_n$. In other words, this series of derivations actually hides a transition from the arithmetic mean to the geometric mean, which inspires us to consider the "generalized mean":

\begin{equation}\mathcal{L}(\gamma) = \sqrt[\gamma]{\frac{1}{n}\sum_{i=1}^n\mathcal{L}_i^{\gamma}}\end{equation}

This involves raising each loss function to the power of $\gamma$, averaging them, and then taking the $\gamma$-th root. $\gamma$ can be any real number. The arithmetic mean corresponds to $\gamma=1$, and the geometric mean corresponds to $\gamma=0$ (taking the limit). It can be proven that $\mathcal{L}(\gamma)$ is a monotonically increasing function of $\gamma$, and we have:

\begin{equation}\min(\mathcal{L}_1,\cdots,\mathcal{L}_n)=\lim_{\gamma\to-\infty} \mathcal{L}(\gamma) \leq\cdots\leq \mathcal{L}(\gamma) \leq\cdots\leq \lim_{\gamma\to+\infty} \mathcal{L}(\gamma)=\max(\mathcal{L}_1,\cdots,\mathcal{L}_n)\end{equation}

This means that as $\gamma$ increases, the model pays more and more attention to the largest of the losses; conversely, as $\gamma$ decreases, it pays more attention to the smallest. In this way, although there is still a hyperparameter $\gamma$ to tune, the number of hyperparameters drops from the $n$ of the original formula $\eqref{eq:w-loss}$ to just 1, simplifying the tuning process.
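A sketch of the generalized mean as a loss combiner, treating $\gamma \to 0$ as the geometric-mean limit (function name illustrative; all losses are assumed positive):

```python
import torch

def generalized_mean_loss(losses, gamma):
    """Power mean of per-task losses: gamma=1 is arithmetic, gamma->0 geometric."""
    L = torch.stack(losses)
    if abs(gamma) < 1e-6:
        # Limit gamma -> 0: geometric mean = exp(mean(log L_i))
        return L.log().mean().exp()
    return L.pow(gamma).mean().pow(1.0 / gamma)
```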

Translation Invariance

Reviewing formulas $\eqref{eq:init}$, $\eqref{eq:prior}$, and $\eqref{eq:sg}$, they all adjust weights by dividing each task loss by some state of itself, achieving scale invariance. However, while they possess scale invariance, they lose the more fundamental "translation invariance." That is, if a constant is added to each loss, the gradient directions of $\eqref{eq:init}$, $\eqref{eq:prior}$, and $\eqref{eq:sg}$ might change. This is not good news for optimization because, in principle, constants do not bring any meaningful information, and the optimization result should not change as a result.

Ideal Goal

On one hand, we use the reciprocal of the loss function (or some state of it) as the current task weight, but the loss function itself is not translation invariant. On the other hand, the loss function can be understood as the distance between the current model and the target state, while gradient descent essentially looks for points where the gradient is 0, so the gradient magnitude can play a similar role. We can therefore replace the loss function with the gradient magnitude, turning formula $\eqref{eq:sg}$ into:

\begin{equation}\mathcal{L} = \sum_{i=1}^n \frac{\mathcal{L}_i}{\Vert\nabla_{\theta}\mathcal{L}_i\Vert^{(\text{sg})}}\label{eq:grad}\end{equation}

An obvious difference from the loss function is that the gradient magnitude clearly possesses translation invariance. Furthermore, the numerator and denominator are still homogeneous with respect to $\mathcal{L}_i$, so the above formula also retains scale invariance. Thus, this is an ideal goal that simultaneously possesses translation and scale invariance.

Gradient Normalization

Taking the gradient of formula $\eqref{eq:grad}$, we get:

\begin{equation}\nabla_{\theta}\mathcal{L} = \sum_{i=1}^n \frac{\nabla_{\theta}\mathcal{L}_i}{\Vert\nabla_{\theta}\mathcal{L}_i\Vert}\label{eq:grad-norm}\end{equation}

As can be seen, formula $\eqref{eq:grad}$ essentially normalizes the gradient of each task loss and then accumulates the gradients. It also provides an implementation scheme: we can train each task sequentially, training only one task each time, then accumulate the normalized gradient of each task before updating. This eliminates the trouble of having to calculate gradients while defining the loss function.
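A sketch of this "one task at a time, normalize, then accumulate" update (PyTorch; `model`, `optimizer`, and the per-task loss callables are placeholders, not part of the original post):

```python
import torch

def step_with_normalized_grads(model, optimizer, task_loss_fns, eps=1e-12):
    """One update following eq. (grad-norm): sum of unit-norm per-task gradients."""
    params = [p for p in model.parameters() if p.requires_grad]
    accum = [torch.zeros_like(p) for p in params]

    for loss_fn in task_loss_fns:
        optimizer.zero_grad()
        loss_fn().backward()  # gradients for this task only
        norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in params if p.grad is not None))
        for a, p in zip(accum, params):
            if p.grad is not None:
                a += p.grad / (norm + eps)  # normalize this task's gradient, then accumulate

    for p, a in zip(params, accum):
        p.grad = a  # install the accumulated gradient
    optimizer.step()
```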

Regarding gradient normalization, the related work I found is "GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks". It is essentially a hybrid of formula $\eqref{eq:init}$ and formula $\eqref{eq:grad-norm}$, also including the idea of re-scaling gradient magnitudes, but it requires additional optimization to determine task weights, which I personally find cumbersome and redundant.

Summary

From the perspective of loss functions, the key issue in multi-task learning is how to adjust the weight of each task to balance their respective losses. This article introduced some reference practices from the viewpoints of scale invariance and translation invariance, and supplemented them with the concept of the "generalized mean," which turns the weight adjustment of multiple tasks into a single-parameter tuning problem and thus reduces the difficulty of hyperparameter search.