Multi-Task Learning (Part 3): The Order of Primary and Secondary

By 苏剑林 | February 14, 2022

Multi-task learning (MTL) is a very broad topic, and the objectives of MTL vary in different scenarios. In "Multi-Task Learning (Part 1): In the Name of Loss" and "Multi-Task Learning (Part 2): By Way of Gradient", we understood the goal of MTL as "doing every task well," which translates to "treating every task as equally as possible." We can call this "Parallel Multi-Task Learning." However, not all MTL objectives are like this. In many scenarios, our primary goal is to learn a specific main task well, while the other tasks serve only as auxiliary tasks, added in the hope of improving the performance of the main task. We can call this type of scenario "Primary-Secondary Multi-Task Learning."

In this context, if we continue to use the learning scheme of Parallel MTL aimed at "doing every task well," it may significantly degrade the performance of the main task. Therefore, this article continues the idea of "acting by gradient" to explore training schemes for Primary-Secondary MTL.

Objective Form

In this article, we assume the reader has already read and basically understood the ideas and methods in "Multi-Task Learning (Part 2): By Way of Gradient". From the gradient perspective, a necessary condition for a loss function to keep decreasing is that the angle between the update $\Delta\boldsymbol{\theta}$ and the loss's gradient exceeds 90 degrees. This is the design philosophy running through this article.

Constraint Optimization

Now assume that in a Primary-Secondary MTL scenario, we have $n+1$ task loss functions, denoted as $\mathcal{L}_0, \mathcal{L}_1, \dots, \mathcal{L}_n$, where $\mathcal{L}_0$ is the primary task loss that we want to minimize as much as possible. $\mathcal{L}_1, \dots, \mathcal{L}_n$ are auxiliary losses, acting like regularization terms; we only hope they do not increase during the training process, but they don't necessarily have to "strive to become smaller."

Following the notation from the previous article, we denote the update at each step as $\Delta\boldsymbol{\theta}=-\eta\boldsymbol{u}$. Since $\mathcal{L}_0$ is the primary task, we naturally want to maximize the inner product between $\boldsymbol{u}$ and $\boldsymbol{g}_0$, so we can design the optimization objective as: \begin{equation}\max_{\boldsymbol{u}} \langle\boldsymbol{u},\boldsymbol{g}_0\rangle - \frac{1}{2}\Vert\boldsymbol{u}\Vert^2\end{equation} where $\boldsymbol{g}_i = \nabla_{\boldsymbol{\theta}}\mathcal{L}_i$ is the gradient of the corresponding loss. Without other constraints, the solution would be $\boldsymbol{u} = \boldsymbol{g}_0$, which is ordinary gradient descent. However, we also have the auxiliary tasks $\mathcal{L}_1, \dots, \mathcal{L}_n$, and we hope the update does not push them in an increasing direction, so at the very least we must ensure $\langle\boldsymbol{u},\boldsymbol{g}_1\rangle\geq 0, \dots, \langle\boldsymbol{u},\boldsymbol{g}_n\rangle\geq 0$. These become the constraints of the optimization, leading to the total objective: \begin{equation}\max_{\boldsymbol{u}} \langle\boldsymbol{u},\boldsymbol{g}_0\rangle - \frac{1}{2}\Vert\boldsymbol{u}\Vert^2\quad\text{s.t.}\,\, \langle\boldsymbol{u},\boldsymbol{g}_1\rangle\geq 0,\dots,\langle\boldsymbol{u},\boldsymbol{g}_n\rangle\geq 0\end{equation} Solving this constrained optimization problem yields an update direction that satisfies all the requirements.
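To spell out the unconstrained step, note the completing-the-square identity, which we will reuse later when eliminating $\boldsymbol{u}$: \begin{equation}\langle\boldsymbol{u},\boldsymbol{v}\rangle - \frac{1}{2}\Vert\boldsymbol{u}\Vert^2 = \frac{1}{2}\Vert\boldsymbol{v}\Vert^2 - \frac{1}{2}\Vert\boldsymbol{u}-\boldsymbol{v}\Vert^2\end{equation} The maximum over $\boldsymbol{u}$ is therefore attained at $\boldsymbol{u}=\boldsymbol{v}$, with value $\frac{1}{2}\Vert\boldsymbol{v}\Vert^2$; taking $\boldsymbol{v}=\boldsymbol{g}_0$ recovers ordinary gradient descent.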

Lagrange Multipliers

The standard way to solve such constrained optimization problems is the method of Lagrange multipliers. It integrates the constraints into the objective function, transforming it into a min-max problem: \begin{equation}\max_{\boldsymbol{u}} \min_{\lambda_i\geq 0}\langle\boldsymbol{u},\boldsymbol{g}_0\rangle - \frac{1}{2}\Vert\boldsymbol{u}\Vert^2 + \sum_i \lambda_i \langle\boldsymbol{u},\boldsymbol{g}_i\rangle\label{eq:q-1}\end{equation}

Here, the summation over $i$ is from $1$ to $n$. How do we understand this transformation? If $\langle\boldsymbol{u},\boldsymbol{g}_i\rangle > 0$, then the $\min_{\lambda_i \geq 0}$ step will force $\lambda_i = 0$ because only $\lambda_i = 0$ minimizes the expression, making $\lambda_i \langle\boldsymbol{u},\boldsymbol{g}_i\rangle = 0$. If $\langle\boldsymbol{u},\boldsymbol{g}_i\rangle = 0$, then naturally $\lambda_i \langle\boldsymbol{u},\boldsymbol{g}_i\rangle = 0$. If $\langle\boldsymbol{u},\boldsymbol{g}_i\rangle < 0$, the $\min_{\lambda_i \geq 0}$ step would yield $\lambda_i \to \infty$, making $\lambda_i \langle\boldsymbol{u},\boldsymbol{g}_i\rangle \to -\infty$. But remember, the optimization of $\boldsymbol{u}$ is a $\max$ operation, so between $0$ and $-\infty$, it will naturally choose $0$. That is, after completing this min-max optimization, we will automatically have $\langle\boldsymbol{u},\boldsymbol{g}_i\rangle \geq 0$ and $\lambda_i \langle\boldsymbol{u},\boldsymbol{g}_i\rangle = 0$. This means the result of the min-max optimization is exactly equivalent to the original constrained max optimization.
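In optimization language, what the min-max enforces is precisely the KKT complementary-slackness conditions: at the optimum, \begin{equation}\lambda_i \geq 0,\qquad \langle\boldsymbol{u},\boldsymbol{g}_i\rangle\geq 0,\qquad \lambda_i\langle\boldsymbol{u},\boldsymbol{g}_i\rangle = 0,\qquad i=1,\dots,n\end{equation}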

To facilitate subsequent derivations, we introduce notation similar to the previous article: \begin{equation}\mathbb{Q}^n=\left\{(\lambda_1,\dots,\lambda_n)\,\middle|\,\lambda_1,\dots,\lambda_n\geq 0\right\},\quad\tilde{\boldsymbol{g}}(\lambda) = \sum_i \lambda_i \boldsymbol{g}_i\end{equation} Then Eq. $\eqref{eq:q-1}$ can be written as: \begin{equation}\max_{\boldsymbol{u}} \min_{\lambda\in\mathbb{Q}^n}\langle\boldsymbol{u},\boldsymbol{g}_0 + \tilde{\boldsymbol{g}}(\lambda)\rangle - \frac{1}{2}\Vert\boldsymbol{u}\Vert^2\label{eq:q-2}\end{equation}

Solving Algorithm

At this point, we have converted the problem of finding the update direction for Primary-Secondary MTL into a min-max problem $\eqref{eq:q-2}$. Next, similar to the previous article, we will use the Minimax theorem to swap the order of $\max$ and $\min$, then use the Frank-Wolfe algorithm to provide a solution method, and finally compare it with the results of the previous article.

Swapping Order

Note that the $\max$ and $\min$ in problem $\eqref{eq:q-2}$ are ordered. Normally, the $\min$ step must be completed before the $\max$ step; rashly swapping them might lead to incorrect results. However, the $\min$ step is a constrained optimization, while the $\max$ step is unconstrained, making the $\max$ step relatively simpler. If we can swap the order and perform the $\max$ step first, the problem will be simplified.

We first need to determine if the order can be swapped. Fortunately, von Neumann proposed the beautiful Minimax Theorem, which tells us that if the parameter domains of both $\min$ and $\max$ are convex sets, and the objective function is convex with respect to the $\min$ parameters and concave with respect to the $\max$ parameters, then the order can be swapped. Even better, it's easy to see that problem $\eqref{eq:q-2}$ satisfies the conditions of the Minimax theorem, so it is equivalent to: \begin{equation}\min_{\lambda\in\mathbb{Q}^n}\max_{\boldsymbol{u}} \langle\boldsymbol{u},\boldsymbol{g}_0 + \tilde{\boldsymbol{g}}(\lambda)\rangle - \frac{1}{2}\Vert\boldsymbol{u}\Vert^2 =\min_{\lambda\in\mathbb{Q}^n}\frac{1}{2}\Vert\boldsymbol{g}_0 + \tilde{\boldsymbol{g}}(\lambda)\Vert^2\label{eq:q-3}\end{equation} Thus, we simplify the problem to a single $\min$ operation. The equality on the right holds because the objective function is just a quadratic function of $\boldsymbol{u}$, and its maximum is reached at $\boldsymbol{u}^* = \boldsymbol{g}_0 + \tilde{\boldsymbol{g}}(\lambda)$. Substituting this gives the result on the right side.

Simple Case

The problem has now become minimizing the magnitude of the weighted superposition of $\boldsymbol{g}_0$ and $\boldsymbol{g}_1, \dots, \boldsymbol{g}_n$. As usual, let's first solve the simplest case, $n=1$, which is $\min_{\gamma \geq 0} \Vert\boldsymbol{g}_0 + \gamma\boldsymbol{g}_1\Vert^2$. This has a clear geometric meaning and a simple analytical solution.

[Figure: Exact solution of the simple example]

As shown in the figure above, there are two cases: The first case is $\langle\boldsymbol{g}_0,\boldsymbol{g}_1\rangle \geq 0$, which means $\boldsymbol{g}_0$ and $\boldsymbol{g}_1$ are not in conflict, so we can just set $\gamma=0$. The second case is $\langle\boldsymbol{g}_0,\boldsymbol{g}_1\rangle < 0$. As seen in the right figure, the minimum value of $\Vert\boldsymbol{g}_0 + \gamma\boldsymbol{g}_1\Vert^2$ is reached when $\boldsymbol{g}_0 + \gamma\boldsymbol{g}_1$ is perpendicular to $\boldsymbol{g}_1$. Thus, solving $\langle \boldsymbol{g}_0 + \gamma\boldsymbol{g}_1,\boldsymbol{g}_1\rangle=0$ gives $\gamma = -\frac{\langle \boldsymbol{g}_0,\boldsymbol{g}_1\rangle}{\Vert\boldsymbol{g}_1\Vert^2}$. Finally, when $\Vert\boldsymbol{g}_1\Vert\neq 0$, this can be written uniformly as: \begin{equation}\gamma = \frac{\text{relu}(-\langle \boldsymbol{g}_0,\boldsymbol{g}_1\rangle)}{\Vert\boldsymbol{g}_1\Vert^2}\label{eq:gamma}\end{equation}
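For concreteness, here is a minimal NumPy sketch of this closed-form correction (the function name and the flat-vector representation of the gradients are illustrative choices, not part of the original derivation):

```python
import numpy as np

def resolve_conflict(g0, g1, eps=1e-12):
    """n=1 case: return u = g0 + gamma * g1 with
    gamma = relu(-<g0, g1>) / ||g1||^2, so that <u, g1> >= 0."""
    gamma = max(-np.dot(g0, g1), 0.0) / (np.dot(g1, g1) + eps)
    return g0 + gamma * g1

# Conflicting case: <g0, g1> < 0, so u gets projected to be orthogonal to g1
g0, g1 = np.array([1.0, 0.0]), np.array([-1.0, 1.0])
u = resolve_conflict(g0, g1)
print(np.dot(u, g1))  # ~0: the auxiliary loss no longer increases along u
```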

Iterative Solution

Next, we handle the general case. The idea is still derived from the Frank-Wolfe algorithm.

First, we pick the direction for the next update by finding $\tau = \mathop{\text{argmin}}_i \langle \boldsymbol{g}_i, \boldsymbol{g}_0 + \tilde{\boldsymbol{g}}(\lambda^{(k)})\rangle$, where $e_{\tau}$ denotes the corresponding one-hot basis vector. Next, we perform a one-dimensional search; however, unlike in the previous article, we do not search by interpolating between $\lambda^{(k)}$ and $e_{\tau}$. Instead, we directly re-determine the coefficient of $\boldsymbol{g}_{\tau}$: we remove the $\boldsymbol{g}_{\tau}$ component from $\tilde{\boldsymbol{g}}(\lambda^{(k)})$ and recompute its coefficient using the closed-form solution of the $n=1$ case.

From this, we obtain the following iterative process: \begin{equation}\left\{\begin{aligned} &\tau = \mathop{\text{argmin}}_i \langle \boldsymbol{g}_i, \boldsymbol{g}_0+\tilde{\boldsymbol{g}}(\lambda^{(k)})\rangle\\ &\gamma = \mathop{\text{argmin}}_{\gamma} \left\Vert\boldsymbol{g}_0 + \tilde{\boldsymbol{g}}(\lambda^{(k)} - \lambda^{(k)}_{\tau} e_{\tau} + \gamma e_{\tau})\right\Vert^2 = \mathop{\text{argmin}}_{\gamma} \left\Vert\boldsymbol{g}_0 + \tilde{\boldsymbol{g}}(\lambda^{(k)}) - \lambda^{(k)}_{\tau}\boldsymbol{g}_{\tau} + \gamma \boldsymbol{g}_{\tau}\right\Vert^2\\ &\lambda^{(k+1)} = \lambda^{(k)} - \lambda^{(k)}_{\tau} e_{\tau} + \gamma e_{\tau} \end{aligned}\right.\end{equation}
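The whole iteration translates to a few lines of NumPy. This is an illustrative sketch under my own conventions: `G` stacks $\boldsymbol{g}_0,\dots,\boldsymbol{g}_n$ as rows, the iteration count is a hyperparameter, and the early exit is a simple "no constraint violated" test:

```python
import numpy as np

def solve_lambda(G, num_iters=20, eps=1e-12):
    """Frank-Wolfe-style coordinate update for
    min_{lambda >= 0} ||g0 + sum_i lambda_i * g_i||^2,
    where G[0] = g0 (primary) and G[1:] = g_1..g_n (auxiliary)."""
    g0, gs = G[0], G[1:]
    lam = np.zeros(len(gs))
    for _ in range(num_iters):
        u = g0 + lam @ gs                 # current g0 + g_tilde(lambda)
        tau = int(np.argmin(gs @ u))      # most conflicting auxiliary gradient
        if gs[tau] @ u >= 0:              # all <g_i, u> >= 0: stop early
            break
        rest = u - lam[tau] * gs[tau]     # strip the g_tau component
        # re-solve the n=1 case for the coefficient of g_tau (relu formula)
        lam[tau] = max(-(rest @ gs[tau]), 0.0) / (gs[tau] @ gs[tau] + eps)
    return lam

# update direction: u = G[0] + solve_lambda(G) @ G[1:]
```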

Comparison

At this point, we have completed the solution for the Primary-Secondary MTL explored in this article. For readers who have carefully worked through the derivations of both articles, the methods and results of Parallel MTL and Primary-Secondary MTL will look strikingly similar. Indeed, they share a great deal, though they differ in subtle ways.

To deepen everyone's understanding, we can compare the similarities and differences between these two types of MTL as follows:

$$\small \begin{array}{c|c|c} \hline & \text{Parallel MTL (Previous Part)} & \text{Primary-Secondary MTL (Current Part)} \\ \hline \text{Objective Overview} & \text{Do every task well} & \text{Do primary task well, don't let auxiliaries worsen} \\ \hline \text{Incremental Format} & \Delta\boldsymbol{\theta} = -\eta\boldsymbol{u} & \Delta\boldsymbol{\theta} = -\eta\boldsymbol{u} \\ \hline \text{Mathematical Definition} & \max\limits_{\boldsymbol{u}}\min\limits_i \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle - \frac{1}{2}\Vert \boldsymbol{u}\Vert^2 & {\begin{array}{l}\max\limits_{\boldsymbol{u}} \langle\boldsymbol{u},\boldsymbol{g}_0\rangle - \frac{1}{2}\Vert\boldsymbol{u}\Vert^2 \\ \text{s.t.}\,\, \langle\boldsymbol{u},\boldsymbol{g}_1\rangle\geq 0,\dots,\langle\boldsymbol{u},\boldsymbol{g}_n\rangle\geq 0\end{array}} \\ \hline \text{Dual Result} & \min\limits_{\alpha\in\mathbb{P}^n}\Vert\tilde{\boldsymbol{g}}(\alpha)\Vert^2 & \min\limits_{\lambda\in\mathbb{Q}^n}\Vert\boldsymbol{g}_0 + \tilde{\boldsymbol{g}}(\lambda)\Vert^2 \\ \hline \text{Direction Vector} & \boldsymbol{u}=\tilde{\boldsymbol{g}}(\alpha)=\sum\limits_i \alpha_i \boldsymbol{g}_i & \boldsymbol{u}=\boldsymbol{g}_0+\tilde{\boldsymbol{g}}(\lambda)=\boldsymbol{g}_0 + \sum\limits_i \lambda_i \boldsymbol{g}_i \\ \hline \text{Feasible Space} & \mathbb{P}^n = \left\{(\alpha_1,\dots,\alpha_n)\,\middle|\,\forall\alpha_i\geq 0, \sum\limits_i \alpha_i = 1\right\} & \mathbb{Q}^n=\left\{(\lambda_1,\dots,\lambda_n)\,\middle|\,\forall\lambda_i\geq 0\right\} \\ \hline \text{Iterative Steps} & \left\{\begin{aligned} &\tau = \mathop{\text{argmin}}_i \langle \boldsymbol{g}_i, \tilde{\boldsymbol{g}}(\alpha^{(k)})\rangle\\ &\gamma = \mathop{\text{argmin}}_{\gamma} \left\Vert(1-\gamma)\tilde{\boldsymbol{g}}(\alpha^{(k)}) + \gamma \boldsymbol{g}_{\tau}\right\Vert^2\\ &\alpha^{(k+1)} = (1-\gamma)\alpha^{(k)} + \gamma e_{\tau} \end{aligned}\right. & \left\{\begin{aligned} &\tau = \mathop{\text{argmin}}_i \langle \boldsymbol{g}_i, \boldsymbol{g}_0+\tilde{\boldsymbol{g}}(\lambda^{(k)})\rangle\\ &\gamma = \mathop{\text{argmin}}_{\gamma} \left\Vert\boldsymbol{g}_0 + \tilde{\boldsymbol{g}}(\lambda^{(k)}) - \lambda^{(k)}_{\tau}\boldsymbol{g}_{\tau} + \gamma \boldsymbol{g}_{\tau}\right\Vert^2\\ &\lambda^{(k+1)} = \lambda^{(k)} - \lambda^{(k)}_{\tau} e_{\tau} + \gamma e_{\tau} \end{aligned}\right. \\ \hline \end{array}$$

Through this comparison, it is not difficult to generalize the results to a "hybrid" MTL with $n$ primary tasks and $m$ auxiliary tasks. The dual result would be: \begin{equation}\min_{\alpha\in\mathbb{P}^n,\lambda\in\mathbb{Q}^m}\Vert\tilde{\boldsymbol{g}}(\alpha) + \tilde{\boldsymbol{g}}(\lambda)\Vert^2\end{equation} As for the specific iterative algorithm, please think about it yourself~

Application Thoughts

In this section, we use several examples to show that many common problems can be mapped to MTL with a distinction between primary and secondary. In some sense, Primary-Secondary MTL may be even more common than Parallel MTL.

Regularization Loss

The most common example might be regularization terms added to the task loss function, such as L2 regularization: \begin{equation}\mathcal{L}(\boldsymbol{\theta}) + \frac{\lambda}{2}\Vert\boldsymbol{\theta}\Vert^2\end{equation} If we treat $\mathcal{L}(\boldsymbol{\theta})$ and $\frac{1}{2}\Vert\boldsymbol{\theta}\Vert^2$ as losses for two tasks, this can also be viewed as an MTL problem. Clearly, we don't necessarily want $\frac{1}{2}\Vert\boldsymbol{\theta}\Vert^2$ to be as small as possible; we only hope it improves the generalization performance of $\mathcal{L}(\boldsymbol{\theta})$. Thus, it doesn't fit Parallel MTL but is much closer to Primary-Secondary MTL.

The gradient of the L2 regularization term $\frac{1}{2}\Vert\boldsymbol{\theta}\Vert^2$ is simple: it is just $\boldsymbol{\theta}$. Then, applying Eq. $\eqref{eq:gamma}$ from this article, we can modify the optimizer to change the gradient term to: \begin{equation}\boldsymbol{g} + \frac{\text{relu}(-\langle \boldsymbol{g},\boldsymbol{\theta}\rangle)}{\Vert\boldsymbol{\theta}\Vert^2}\boldsymbol{\theta}\end{equation} In this way, we can add L2 regularization to the model without having to tune the regularization coefficient $\lambda$. Of course, one could also decouple the regularization from the gradient, as AdamW does with weight decay, and apply the correction to the optimizer's final update instead.
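As a hedged sketch of this optimizer tweak (all parameters are treated as one flattened vector here; a per-tensor variant would follow the same pattern):

```python
import numpy as np

def l2_corrected_grad(g, theta, eps=1e-12):
    """Apply the n=1 primary-secondary correction with the L2 term
    0.5 * ||theta||^2 as the auxiliary task: replace the task gradient g
    by g + relu(-<g, theta>) / ||theta||^2 * theta."""
    gamma = max(-np.dot(g, theta), 0.0) / (np.dot(theta, theta) + eps)
    return g + gamma * theta
```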

Besides direct parameter regularization, there are many other forms of auxiliary losses, such as adding contrastive learning loss to classification models, or adding length penalties to generative models, and so on. These practices can more or less be mapped to Primary-Secondary MTL, so one can try to apply the results of this article. If calculating the full gradient is computationally expensive, one can use the approximation for the "Shared Encoding" case mentioned in the previous article to reduce the computational cost.

Learning with Noise

In addition, there is a common training scenario that people might not realize is an MTL problem, but it essentially can be understood as such: "Learning with Noise."

Suppose for the same task, we only have a small amount of accurately labeled clean data, but also a large amount of noisy data. Because there is more noisy data, we tend to learn primarily from it, assuming the corresponding loss is $\mathcal{L}_0$. However, since the data contains noise, purely minimizing $\mathcal{L}_0$ might not result in an ideal model; it might memorize the incorrect labels. This is where the clean data comes in handy—we can calculate a loss $\mathcal{L}_1$ using the clean data. Since the clean data has less noise, we can assume $\mathcal{L}_1$ better reflects the model's true performance. We can then add a restriction:

No matter how you minimize $\mathcal{L}_0$, you cannot let $\mathcal{L}_1$ increase. In other words, you can train with noisy data, but you cannot let the performance on clean data get worse.

This is exactly a Primary-Secondary MTL problem with $\mathcal{L}_0$ as primary and $\mathcal{L}_1$ as secondary!

Coincidentally, a Google paper from last year, "Gradient-guided Loss Masking for Neural Machine Translation", presented a similar approach, though the details differ slightly. It computes the gradient of each noisy sample with respect to the parameters and keeps only the samples whose gradient makes an angle of less than 90 degrees (inner product greater than 0) with the clean-data gradient $\nabla_{\boldsymbol{\theta}} \mathcal{L}_1$. In other words, both approaches use the inner product with the clean-data gradient as the criterion; the difference is that Primary-Secondary MTL corrects the update when the inner product is negative, whereas Google's paper simply discards the offending samples.
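To make the contrast concrete, here is an illustrative sketch of both criteria (`per_sample_grads` and `g_clean` are assumed inputs; this is a schematic reading of the masking idea, not the paper's exact recipe):

```python
import numpy as np

def masked_grad(per_sample_grads, g_clean):
    """Masking criterion: average only the noisy-sample gradients whose
    inner product with the clean-data gradient is positive."""
    keep = per_sample_grads @ g_clean > 0
    if not keep.any():
        return np.zeros_like(g_clean)
    return per_sample_grads[keep].mean(axis=0)

def corrected_grad(g_noisy, g_clean, eps=1e-12):
    """Primary-secondary alternative: keep every sample, but correct the
    aggregated noisy gradient with the relu formula so that the clean
    loss L_1 does not increase."""
    gamma = max(-np.dot(g_noisy, g_clean), 0.0) / (np.dot(g_clean, g_clean) + eps)
    return g_noisy + gamma * g_clean
```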

Summary

This article generalizes the results of Parallel MTL from the previous post to "Primary-Secondary" MTL. In this case, the goal of MTL is no longer to do all tasks well, but to focus on one main task while using others as auxiliaries. The results have many similarities with the original Parallel MTL but differ in subtle ways. Finally, some classic examples of Primary-Secondary MTL, such as regularization terms and learning with noise, were introduced.