By 苏剑林 | February 8, 2022
In "Multitask Learning Chat (Part 1): In the Name of Loss", we initially explored the multitask learning problem from the perspective of loss functions. We found that if we want the result to be invariant to both scaling and translation, taking the reciprocal of the gradient norm as each task's weight is a relatively simple option. We further showed that this design is equivalent to normalizing each task's gradient individually and then summing them. This means the "battlefield" of multitask learning has shifted from the loss function to the gradient: what appears to be designing the loss function is, in fact, designing better gradients, or in our words, "In the name of loss, acting via gradients."
So, what are the standards for "better" gradients? How can we design them? In this article, we will understand multitask learning from the perspective of gradients, attempting to build multitask learning algorithms directly from the idea of designing gradients.
Overall Idea
We know that for single-task learning, the commonly used optimization method is gradient descent. How is it derived? Can the same logic be directly applied to multitask learning? This section aims to answer these questions.
Descent Direction
We actually answered the first question in "Optimization Algorithms from a Dynamical Perspective (3): A More Holistic View". Suppose the loss function is $\mathcal{L}$ and the current parameter is $\boldsymbol{\theta}$. We want to design a parameter increment $\Delta\boldsymbol{\theta}$ that makes the loss function smaller, i.e., $\mathcal{L}(\boldsymbol{\theta}+\Delta\boldsymbol{\theta}) < \mathcal{L}(\boldsymbol{\theta})$. To this end, we consider the first-order Taylor expansion:
\begin{equation}\mathcal{L}(\boldsymbol{\theta}+\Delta\boldsymbol{\theta})\approx \mathcal{L}(\boldsymbol{\theta}) + \langle \nabla_{\boldsymbol{\theta}}\mathcal{L}, \Delta\boldsymbol{\theta}\rangle \label{eq:approx-1}\end{equation}
Assuming this approximation is accurate enough, $\mathcal{L}(\boldsymbol{\theta}+\Delta\boldsymbol{\theta}) < \mathcal{L}(\boldsymbol{\theta})$ implies $\langle \nabla_{\boldsymbol{\theta}}\mathcal{L}, \Delta\boldsymbol{\theta}\rangle < 0$, i.e., the angle between the update and the gradient must exceed 90 degrees. The most natural choice among such updates is:
\begin{equation}\Delta\boldsymbol{\theta} = -\eta \nabla_{\boldsymbol{\theta}}\mathcal{L}\end{equation}
This is gradient descent, where the update is taken in the direction opposite to the gradient, and $\eta > 0$ is the learning rate.
No Exceptions
Returning to multitask learning, if we assume every task is equally important, we can interpret this assumption as requiring $\mathcal{L}_1, \mathcal{L}_2, \cdots, \mathcal{L}_n$ to all decrease or remain unchanged at every update step. If, after reaching $\boldsymbol{\theta}^*$, any further change would increase at least one $\mathcal{L}_i$, then $\boldsymbol{\theta}^*$ is called a Pareto optimal solution. Simply put, under this requirement no task may be improved at the expense of another, so there is no "involution" (zero-sum internal competition) between tasks.
Assuming the approximation $\eqref{eq:approx-1}$ still holds, searching for Pareto optimality means we need to find $\Delta\boldsymbol{\theta}$ satisfying:
\begin{equation}\left\{\begin{aligned}
&\langle \nabla_{\boldsymbol{\theta}}\mathcal{L}_1, \Delta\boldsymbol{\theta}\rangle \leq 0\\
&\langle \nabla_{\boldsymbol{\theta}}\mathcal{L}_2, \Delta\boldsymbol{\theta}\rangle \leq 0\\
&\quad \vdots \\
&\langle \nabla_{\boldsymbol{\theta}}\mathcal{L}_n, \Delta\boldsymbol{\theta}\rangle \leq 0\\
\end{aligned}\right.\end{equation}
Note that this system has the trivial solution $\Delta\boldsymbol{\theta}=\boldsymbol{0}$, so its feasible region is never empty. What we care about is whether it contains non-zero solutions: if so, we pick one as the update direction; if not, we may already have reached Pareto optimality (the absence of non-zero solutions is a necessary but not sufficient condition for it), and we call such a point a Pareto stationary point.
Solution Algorithms
For convenience, we denote $\boldsymbol{g}_i=\nabla_{\boldsymbol{\theta}}\mathcal{L}_i$. We seek a vector $\boldsymbol{u}$ such that $\langle \boldsymbol{g}_i, \boldsymbol{u}\rangle \geq 0$ for all $i$; we can then take $\Delta\boldsymbol{\theta}=-\eta\boldsymbol{u}$ as the update. When there are only two tasks, it can be verified that $\boldsymbol{u}=\boldsymbol{g}_1/\|\boldsymbol{g}_1\| + \boldsymbol{g}_2/\|\boldsymbol{g}_2\|$ automatically satisfies $\langle \boldsymbol{g}_1, \boldsymbol{u}\rangle \geq 0$ and $\langle \boldsymbol{g}_2, \boldsymbol{u}\rangle \geq 0$. In other words, in two-task learning, the gradient normalization mentioned earlier always yields a direction along which neither loss increases, and can therefore reach a Pareto stationary point.
When there are more than two tasks, the problem becomes more complex. Here we introduce two methods: the first is the author's own derivation, and the second is the "standard answer" given in "Multi-Task Learning as Multi-Objective Optimization".
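Spelling out this verification, with $\phi$ denoting the angle between $\boldsymbol{g}_1$ and $\boldsymbol{g}_2$:
\begin{equation}\langle \boldsymbol{g}_1, \boldsymbol{u}\rangle = \Vert\boldsymbol{g}_1\Vert + \frac{\langle \boldsymbol{g}_1, \boldsymbol{g}_2\rangle}{\Vert\boldsymbol{g}_2\Vert} = \Vert\boldsymbol{g}_1\Vert\left(1+\cos\phi\right) \geq 0\end{equation}
and by symmetry $\langle \boldsymbol{g}_2, \boldsymbol{u}\rangle = \Vert\boldsymbol{g}_2\Vert(1+\cos\phi)\geq 0$. Equality holds only when the two gradients point in exactly opposite directions ($\phi=180^{\circ}$), in which case $\boldsymbol{u}=\boldsymbol{0}$.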
Problem Transformation
First, we further transform the problem. Note that:
\begin{equation}\forall i, \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle \geq 0\quad\Leftrightarrow\quad \min_i \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle \geq 0\label{eq:q-0}\end{equation}
So we only need to maximize the minimum $\langle \boldsymbol{g}_i, \boldsymbol{u}\rangle$ as much as possible to find the ideal $\boldsymbol{u}$, transforming the problem into:
\begin{equation}\max_{\boldsymbol{u}}\min_i \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle \end{equation}
However, this is a bit dangerous because if a non-zero $\boldsymbol{u}$ exists such that $\min_i \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle > 0$, letting the norm of $\boldsymbol{u}$ go to infinity would make the maximum value go to infinity. To ensure stability, we need a regularization term. Consider:
\begin{equation}\max_{\boldsymbol{u}}\min_i \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle - \frac{1}{2}\Vert \boldsymbol{u}\Vert^2\label{eq:q-1}\end{equation}
In this way, a $\boldsymbol{u}$ with an infinite norm cannot be the optimal solution. Note that substituting $\boldsymbol{u}=0$ gives $\min_i \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle - \frac{1}{2}\Vert \boldsymbol{u}\Vert^2=0$. Thus, assuming the optimal solution for $\boldsymbol{u}$ is $\boldsymbol{u}^*$, we must have:
\begin{equation}\min_i \langle \boldsymbol{g}_i, \boldsymbol{u}^*\rangle - \frac{1}{2}\Vert \boldsymbol{u}^*\Vert^2\geq 0\quad\Leftrightarrow\quad \min_i \langle \boldsymbol{g}_i, \boldsymbol{u}^*\rangle \geq \frac{1}{2}\Vert \boldsymbol{u}^*\Vert^2\geq 0\end{equation}
Therefore, the solution to problem $\eqref{eq:q-1}$ necessarily satisfies condition $\eqref{eq:q-0}$. If it is non-zero, then its opposite direction necessarily decreases the loss of every task (to first order).
Smooth Approximation
Now we introduce the first solution scheme for problem $\eqref{eq:q-1}$, aimed at readers who, like the author, are unfamiliar with min-max problems. We can replace the $\min$ with a smooth approximation (refer to "Seeking a Smooth Maximum Function"):
\begin{equation}\min_i \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle \approx -\frac{1}{\lambda}\log\sum_i e^{-\lambda \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle}\,\,\big(\text{for sufficiently large }\lambda\big)\end{equation}
Thus, we can first solve:
\begin{equation}\max_{\boldsymbol{u}}-\frac{1}{\lambda}\log\sum_i e^{-\lambda \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle} - \frac{1}{2}\Vert \boldsymbol{u}\Vert^2\end{equation}
And then let $\lambda\to\infty$. We have converted the problem into a single unconstrained maximization problem. Taking the gradient and setting it to zero gives:
\begin{equation}\frac{\sum\limits_i e^{-\lambda \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle} \boldsymbol{g}_i}{\sum\limits_i e^{-\lambda \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle}} = \boldsymbol{u}\end{equation}
Assuming the gaps between the different values $\langle \boldsymbol{g}_i, \boldsymbol{u}\rangle$ are much larger than $\mathcal{O}(1/\lambda)$, the weights on the left-hand side concentrate on the smallest term as $\lambda\to\infty$, and the above expression becomes:
\begin{equation}\boldsymbol{u} = \boldsymbol{g}_{\tau},\quad \tau = \mathop{\text{argmin}}_i \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle\end{equation}
However, iterating directly as $\boldsymbol{u}^{(k+1)} = \boldsymbol{g}_{\tau},\ \tau = \mathop{\text{argmin}}_i \langle \boldsymbol{g}_i, \boldsymbol{u}^{(k)}\rangle$ will likely oscillate. The rule sets $\boldsymbol{u}^{(k+1)}$ to the $\boldsymbol{g}_i$ with the smallest $\langle \boldsymbol{g}_i, \boldsymbol{u}^{(k)}\rangle$. Once $\boldsymbol{u}^{(k+1)}=\boldsymbol{g}_{\tau}$, the index minimizing $\langle \boldsymbol{g}_i, \boldsymbol{u}^{(k+1)}\rangle=\langle \boldsymbol{g}_i, \boldsymbol{g}_{\tau}\rangle$ at the next step will very likely no longer be $\tau$; in fact, $\langle \boldsymbol{g}_{\tau}, \boldsymbol{g}_{\tau}\rangle=\Vert\boldsymbol{g}_{\tau}\Vert^2$ may well be the largest of them.
Intuitively, although the algorithm oscillates, it should oscillate around the optimal point $\boldsymbol{u}^*$. Therefore, averaging all the iterates produced during the oscillation should give the optimal point. This leads to the following iteration, which converges to the optimal point:
\begin{equation}\boldsymbol{u}^{(k+1)} = \frac{k \boldsymbol{u}^{(k)} + \boldsymbol{g}_{\tau}}{k + 1},\quad \tau = \mathop{\text{argmin}}_i \langle \boldsymbol{g}_i, \boldsymbol{u}^{(k)}\rangle\label{eq:sol-1}\end{equation}
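A minimal NumPy sketch of iteration $\eqref{eq:sol-1}$ follows. It assumes the task gradients are stacked as the rows of a dense array `grads` and starts from $\boldsymbol{u}^{(0)}=\boldsymbol{0}$; the function name is hypothetical, and in a real model the rows would be flattened parameter gradients.

```python
import numpy as np


def averaged_argmin_iteration(grads: np.ndarray, num_iters: int = 1000) -> np.ndarray:
    """Running-average iteration u^{(k+1)} = (k u^{(k)} + g_tau) / (k + 1).

    grads: array of shape (n, d); row i is the task gradient g_i.
    Returns the averaged vector, an approximation of the optimal u*.
    """
    u = np.zeros(grads.shape[1])  # u^{(0)} = 0
    for k in range(num_iters):
        # tau = argmin_i <g_i, u^{(k)}>
        tau = int(np.argmin(grads @ u))
        # Average in the selected gradient instead of jumping to it (damps the oscillation).
        u = (k * u + grads[tau]) / (k + 1)
    return u


# Toy case: the min-norm point of the convex hull of these gradients is (0.5, 0.5).
grads = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
u = averaged_argmin_iteration(grads)
```

Here the iteration keeps alternating between $\boldsymbol{g}_1$ and $\boldsymbol{g}_2$, and the running average settles between them, which is exactly the averaging intuition above.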
Note that since the added term each time is some $\boldsymbol{g}_i$, the final $\boldsymbol{u}^*$ must be a weighted average of various $\boldsymbol{g}_i$, i.e., there exist $\alpha_1,\alpha_2,\cdots,\alpha_n\geq 0$ and $\alpha_1 + \alpha_2 + \cdots + \alpha_n =1$, such that:
\begin{equation}\boldsymbol{u}^* = \sum_i \alpha_i \boldsymbol{g}_i\end{equation}
We can also understand $\alpha_1,\alpha_2,\cdots,\alpha_n$ as the current optimal weight distribution for each $\mathcal{L}_i$.
Dual Problem
The benefit of the smooth approximation technique is that it is simple and intuitive, requiring little foundation in optimization algorithms. However, it is ultimately a "non-mainstream" approach with many informalities (though the result is correct). Next, let's introduce the "standard answer" based on dual theory.
First, define $\mathbb{P}^n$ as the set of all $n$-element discrete distributions:
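\begin{equation}\mathbb{P}^n = \left\{(\alpha_1,\alpha_2,\cdots,\alpha_n)\,\middle|\,\alpha_1,\alpha_2,\cdots,\alpha_n\geq 0, \sum_i \alpha_i = 1\right\}\end{equation}
Then, since the right-hand side below is linear in $\alpha$ and therefore attains its minimum over $\mathbb{P}^n$ at a vertex (a one-hot $\alpha$), it is easy to verify that: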
\begin{equation}\min_i \langle \boldsymbol{g}_i, \boldsymbol{u}\rangle = \min_{\alpha\in\mathbb{P}^n}\left\langle \tilde{\boldsymbol{g}}(\alpha), \boldsymbol{u}\right\rangle,\quad \tilde{\boldsymbol{g}}(\alpha) = \sum_i \alpha_i \boldsymbol{g}_i\end{equation}
Thus, problem $\eqref{eq:q-1}$ is equivalent to:
\begin{equation}\max_{\boldsymbol{u}}\min_{\alpha\in\mathbb{P}^n}\left\langle \tilde{\boldsymbol{g}}(\alpha), \boldsymbol{u}\right\rangle - \frac{1}{2}\Vert \boldsymbol{u}\Vert^2\label{eq:q-2}\end{equation}
The objective above is concave in $\boldsymbol{u}$ and convex (in fact linear) in $\alpha$; the feasible region of $\boldsymbol{u}$ is convex, and that of $\alpha$ is a compact convex set. Therefore, by von Neumann's minimax theorem (in the generalized form due to Sion), the $\min$ and $\max$ in $\eqref{eq:q-2}$ can be interchanged, making it equivalent to:
\begin{equation}\min_{\alpha\in\mathbb{P}^n}\max_{\boldsymbol{u}}\left\langle \tilde{\boldsymbol{g}}(\alpha), \boldsymbol{u}\right\rangle - \frac{1}{2}\Vert \boldsymbol{u}\Vert^2 = \min_{\alpha\in\mathbb{P}^n}\frac{1}{2}\left\Vert\tilde{\boldsymbol{g}}(\alpha)\right\Vert^2\label{eq:q-3}\end{equation}
The equality holds because the inner $\max$ is an unconstrained maximization of a concave quadratic in $\boldsymbol{u}$, whose solution is $\boldsymbol{u}^* = \tilde{\boldsymbol{g}}(\alpha)$. Only the $\min$ then remains, and the problem becomes finding the weighted average of $\boldsymbol{g}_1,\boldsymbol{g}_2,\cdots,\boldsymbol{g}_n$ with the smallest norm.
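Completing the square makes the inner maximization explicit:
\begin{equation}\left\langle \tilde{\boldsymbol{g}}(\alpha), \boldsymbol{u}\right\rangle - \frac{1}{2}\Vert \boldsymbol{u}\Vert^2 = \frac{1}{2}\left\Vert\tilde{\boldsymbol{g}}(\alpha)\right\Vert^2 - \frac{1}{2}\left\Vert\boldsymbol{u} - \tilde{\boldsymbol{g}}(\alpha)\right\Vert^2\end{equation}
so the maximum over $\boldsymbol{u}$ equals $\frac{1}{2}\Vert\tilde{\boldsymbol{g}}(\alpha)\Vert^2$ and is attained at $\boldsymbol{u}=\tilde{\boldsymbol{g}}(\alpha)$.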
When $n=2$, the solution to the problem is relatively simple, equivalent to constructing the altitude of a triangle, as shown below:
(Geometric interpretation: Solving algorithm when n = 2)
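For $n=2$, this altitude construction has a closed form: setting the derivative of $\Vert(1-\gamma)\boldsymbol{a}+\gamma\boldsymbol{b}\Vert^2$ with respect to $\gamma$ to zero gives $\gamma = \langle\boldsymbol{a},\boldsymbol{a}-\boldsymbol{b}\rangle/\Vert\boldsymbol{a}-\boldsymbol{b}\Vert^2$, which is then clipped to $[0,1]$ in case the foot of the altitude falls outside the segment. A minimal NumPy sketch (the function name is hypothetical):

```python
import numpy as np


def min_norm_two(a: np.ndarray, b: np.ndarray) -> float:
    """Return gamma in [0, 1] minimizing ||(1 - gamma) a + gamma b||^2."""
    diff = a - b
    denom = float(diff @ diff)
    if denom == 0.0:  # a == b: every gamma gives the same point
        return 0.0
    gamma = float(a @ diff) / denom  # foot of the perpendicular from the origin
    return min(max(gamma, 0.0), 1.0)  # clip to the segment [a, b]


# Orthogonal unit vectors: the closest point to the origin is the midpoint.
gamma = min_norm_two(np.array([1.0, 0.0]), np.array([0.0, 1.0]))
```

The clipping covers the case where the triangle is obtuse at one of the endpoints; the minimizer is then that endpoint itself.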
When $n > 2$, we can use the Frank-Wolfe algorithm, which reduces the problem to a sequence of $n=2$ subproblems. The Frank-Wolfe algorithm can be understood as a constrained gradient descent method for problems whose feasible region is a convex set. A thorough description would take too much space, so I will not detail it here; interested readers are encouraged to consult other materials. In short, Frank-Wolfe first linearizes the objective to obtain the next update direction $e_{\tau}$, where $\tau = \mathop{\text{argmin}}_i \langle \boldsymbol{g}_i, \tilde{\boldsymbol{g}}(\alpha)\rangle$ (note that $\langle \boldsymbol{g}_i, \tilde{\boldsymbol{g}}(\alpha)\rangle$ is exactly the partial derivative of $\frac{1}{2}\Vert\tilde{\boldsymbol{g}}(\alpha)\Vert^2$ with respect to $\alpha_i$) and $e_{\tau}$ is the one-hot vector with a 1 at position $\tau$; it then performs a line search on the segment between $\alpha$ and $e_{\tau}$ and takes the best point as the next iterate. Its iteration is therefore:
\begin{equation}\left\{\begin{aligned}
&\tau = \mathop{\text{argmin}}_i \langle \boldsymbol{g}_i, \tilde{\boldsymbol{g}}(\alpha^{(k)})\rangle\\
&\gamma = \mathop{\text{argmin}}_{\gamma\in[0,1]} \left\Vert\tilde{\boldsymbol{g}}((1-\gamma)\alpha^{(k)} + \gamma e_{\tau})\right\Vert^2 = \mathop{\text{argmin}}_{\gamma\in[0,1]} \left\Vert(1-\gamma)\tilde{\boldsymbol{g}}(\alpha^{(k)}) + \gamma \boldsymbol{g}_{\tau}\right\Vert^2\\
&\alpha^{(k+1)} = (1-\gamma)\alpha^{(k)} + \gamma e_{\tau}
\end{aligned}\right.\end{equation}
Solving for $\gamma$ is precisely the $n=2$ special case and can be done with the method described earlier. If $\gamma$ is not searched for but instead fixed at $1/(k+1)$, then with $\boldsymbol{u}^{(k)}=\tilde{\boldsymbol{g}}(\alpha^{(k)})$ the iteration is identical to $\eqref{eq:sol-1}$; this is a simplified version of the Frank-Wolfe algorithm. In other words, the result of the smooth approximation technique coincides with that of the simplified Frank-Wolfe algorithm.
Removing the Constraints
In principle, problem $\eqref{eq:q-3}$ could also be solved by plain unconstrained gradient descent once the constraints are removed. For example, we could introduce parameters $\beta_1,\beta_2,\cdots,\beta_n\in\mathbb{R}$ and define:
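The iteration above can be sketched in NumPy as follows. This is a minimal sketch, not the reference implementation of the cited paper: it starts from uniform weights, uses the closed-form $n=2$ line search, and stops early once the Frank-Wolfe gap vanishes. All inner products are read from the Gram matrix $\langle\boldsymbol{g}_i,\boldsymbol{g}_j\rangle$, computed once; the function name and the tolerance are hypothetical choices.

```python
import numpy as np


def frank_wolfe_min_norm(grads, num_iters=100, tol=1e-12):
    """Frank-Wolfe with exact line search for min over the simplex of ||sum_i alpha_i g_i||^2.

    grads: array of shape (n, d); row i is the task gradient g_i.
    Returns the task weights alpha, shape (n,).
    """
    n = grads.shape[0]
    gram = grads @ grads.T              # G_ij = <g_i, g_j>, computed once
    alpha = np.full(n, 1.0 / n)         # start from uniform weights
    for _ in range(num_iters):
        g_alpha = gram @ alpha          # (G alpha)_i = <g_i, g~(alpha)>
        tau = int(np.argmin(g_alpha))   # linearized step picks the vertex e_tau
        vv = float(alpha @ g_alpha)     # <g~, g~>
        vt = float(g_alpha[tau])        # <g~, g_tau>
        tt = float(gram[tau, tau])      # <g_tau, g_tau>
        denom = vv - 2.0 * vt + tt      # ||g~ - g_tau||^2
        if denom <= tol:
            break                       # g~ already coincides with g_tau
        # n = 2 subproblem: closest point to the origin on the segment [g~, g_tau]
        gamma = min(max((vv - vt) / denom, 0.0), 1.0)
        if gamma <= tol:
            break                       # (vv - vt) is the Frank-Wolfe gap; ~0 means optimal
        e_tau = np.zeros(n)
        e_tau[tau] = 1.0
        alpha = (1.0 - gamma) * alpha + gamma * e_tau
    return alpha


# Toy case: g_3 = g_1 + g_2, so the optimal weights are (0.5, 0.5, 0).
grads = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
alpha = frank_wolfe_min_norm(grads, num_iters=2000)
```

Replacing the line search with `gamma = 1.0 / (k + 1)` (using a loop counter `k` from 1) would give the simplified variant that matches $\eqref{eq:sol-1}$.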
\begin{equation}\alpha_i = \frac{e^{\beta_i}}{Z},\quad Z = \sum_i e^{\beta_i}\end{equation}
The problem then becomes:
\begin{equation}\min_{\beta} \frac{1}{2Z^2}\left\Vert \sum_i e^{\beta_i} \boldsymbol{g}_i\right\Vert^2\end{equation}
This is an unconstrained optimization problem that can be solved with conventional gradient descent. However, for some reason, the author has not seen this approach used in practice (perhaps to avoid tuning a learning rate?).
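A minimal sketch of this reparameterized approach, assuming plain gradient descent on $\beta$; the learning rate and iteration count are hypothetical and would need tuning. With $f = \frac{1}{2}\alpha^{\top}\boldsymbol{G}\alpha$ and $\boldsymbol{G}_{ij}=\langle\boldsymbol{g}_i,\boldsymbol{g}_j\rangle$, the chain rule through the softmax gives $\partial f/\partial\beta_j = \alpha_j\big((\boldsymbol{G}\alpha)_j - \alpha^{\top}\boldsymbol{G}\alpha\big)$:

```python
import numpy as np


def softmax(beta):
    z = np.exp(beta - beta.max())  # shift by the max for numerical stability
    return z / z.sum()


def min_norm_softmax_gd(grads, lr=0.5, num_iters=2000):
    """Minimize 0.5 * ||sum_i alpha_i g_i||^2 with alpha = softmax(beta), via gradient descent on beta."""
    gram = grads @ grads.T             # G_ij = <g_i, g_j>
    beta = np.zeros(grads.shape[0])    # beta = 0 gives uniform alpha
    for _ in range(num_iters):
        alpha = softmax(beta)
        g_alpha = gram @ alpha         # gradient of the objective w.r.t. alpha
        # chain rule through softmax: alpha_j * ((G alpha)_j - alpha^T G alpha)
        grad_beta = alpha * (g_alpha - alpha @ g_alpha)
        beta = beta - lr * grad_beta
    return softmax(beta)


# g_1 = (2, 0), g_2 = (0, 1): minimizing 4 a^2 + (1 - a)^2 by hand gives alpha = (0.2, 0.8).
alpha = min_norm_softmax_gd(np.array([[2.0, 0.0], [0.0, 1.0]]))
```

One caveat of this parameterization: when the optimum lies on the boundary of $\mathbb{P}^n$ (some $\alpha_i = 0$), the corresponding $\beta_i$ must diverge to $-\infty$, so convergence there is slow.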
Some Techniques
In the previous section, we provided two schemes for finding the update direction of a Pareto stationary point. Both require that during each step of training, one must first determine the weight of each task through multiple additional iterations before the model parameters can be updated. Consequently, the computational load during actual implementation is quite high, so we need some techniques to reduce it.
Gradient Inner Product
As can be seen, both schemes hinge on the step $\mathop{\text{argmin}}_i \langle \boldsymbol{g}_i, \tilde{\boldsymbol{g}}(\alpha)\rangle$, which requires the inner product of every task gradient with $\tilde{\boldsymbol{g}}(\alpha)$. In deep learning, however, models often have a huge number of parameters, so each gradient is an extremely high-dimensional vector, and recomputing these inner products at every iteration is very expensive. Instead, we can use the expansion:
\begin{equation}\langle \boldsymbol{g}_i, \tilde{\boldsymbol{g}}(\alpha)\rangle = \left\langle \boldsymbol{g}_i, \sum_j \alpha_j \boldsymbol{g}_j \right\rangle = \sum_j \alpha_j \langle \boldsymbol{g}_i, \boldsymbol{g}_j \rangle\end{equation}
In each iteration, only $\alpha$ changes. Therefore, at each step of training, $\langle \boldsymbol{g}_i, \boldsymbol{g}_j \rangle$ only needs to be calculated once and stored; there is no need to repeat these high-dimensional vector inner product calculations.
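A small NumPy illustration of this trick, with random vectors standing in for real gradients (the sizes are hypothetical): both functions return the same inner products, but the Gram-matrix version works with an $n\times n$ matrix per iteration instead of $n$ vectors of dimension $d$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 100_000                    # hypothetical: 4 tasks, 100k parameters
grads = rng.standard_normal((n, d))  # row i stands in for the flattened gradient g_i

gram = grads @ grads.T               # <g_i, g_j>: computed once per training step, O(n^2 d)


def inner_products_direct(alpha):
    # Build g~(alpha) in R^d, then take n inner products: O(n d) per call.
    return grads @ (alpha @ grads)


def inner_products_gram(alpha):
    # sum_j alpha_j <g_i, g_j> from the cached Gram matrix: O(n^2) per call.
    return gram @ alpha
```

Inside the Frank-Wolfe or averaging loops, only `inner_products_gram` would be called, so the per-iteration cost no longer depends on the number of model parameters.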
Shared Encoder
However, once the model reaches a certain size, computing each task's gradient separately and then running these iterations becomes difficult. If we assume that all tasks share the same encoder, we can simplify the algorithm further with an approximation.
Specifically, assume the batch size is $b$ and the encoder output for the $j$-th sample is $\boldsymbol{h}_j$. From the chain rule, we know:
\begin{equation}\boldsymbol{g}_i = \nabla_{\boldsymbol{\theta}}\mathcal{L}_i = \sum_j (\nabla_{\boldsymbol{h}_j}\mathcal{L}_i)(\nabla_{\boldsymbol{\theta}}\boldsymbol{h}_j) = \underbrace{\big(\nabla_{\boldsymbol{h}_1}\mathcal{L}_i , \cdots , \nabla_{\boldsymbol{h}_b}\mathcal{L}_i\big)}_{\nabla_{\boldsymbol{H}}\mathcal{L}_i}\underbrace{\begin{pmatrix}\nabla_{\boldsymbol{\theta}}\boldsymbol{h}_1 \\ \vdots \\ \nabla_{\boldsymbol{\theta}}\boldsymbol{h}_b\end{pmatrix}}_{\nabla_{\boldsymbol{\theta}}\boldsymbol{H}}\end{equation}
Denoting $\boldsymbol{H} = (\boldsymbol{h}_1,\cdots,\boldsymbol{h}_b)$, we get $\boldsymbol{g}_i = (\nabla_{\boldsymbol{H}}\mathcal{L}_i) (\nabla_{\boldsymbol{\theta}}\boldsymbol{H})$. By the submultiplicativity of matrix norms, we get:
\begin{equation}\left\Vert\sum_i \alpha_i \boldsymbol{g}_i\right\Vert^2 = \left\Vert\sum_i \alpha_i (\nabla_{\boldsymbol{H}}\mathcal{L}_i) (\nabla_{\boldsymbol{\theta}}\boldsymbol{H})\right\Vert^2 \leq \left\Vert\sum_i \alpha_i \nabla_{\boldsymbol{H}}\mathcal{L}_i\right\Vert^2 \big\Vert \nabla_{\boldsymbol{\theta}}\boldsymbol{H}\big\Vert^2 \end{equation}
It is easy to see that minimizing $\left\Vert\sum_i\alpha_i \nabla_{\boldsymbol{H}}\mathcal{L}_i\right\Vert^2$ instead greatly reduces the computational load, because it only requires gradients with respect to the encoder output $\boldsymbol{H}$ rather than with respect to all parameters. The formula above tells us that minimizing $\left\Vert\sum_i\alpha_i \nabla_{\boldsymbol{H}}\mathcal{L}_i\right\Vert^2$ amounts to minimizing an upper bound of the objective in $\eqref{eq:q-3}$. As with many problems that are hard to optimize directly, we hope that minimizing the upper bound will lead to similar results.
However, while this upper bound is cheaper to compute, it also has limitations. It is generally suitable only for multitask learning in which each sample carries labels for multiple tasks. It is not suitable when the training data of the tasks do not overlap (i.e., different tasks are annotated on different samples, and each sample has only one type of annotation). In the latter case, the various $\nabla_{\boldsymbol{H}}\mathcal{L}_i$ are mutually orthogonal, since each is non-zero only on the columns of $\boldsymbol{H}$ belonging to its own samples; the bound then cannot reflect any correlation between tasks and becomes too loose to be meaningful.
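As a numerical sanity check of this bound (not a training procedure), the sketch below uses random stand-ins for $\nabla_{\boldsymbol{H}}\mathcal{L}_i$ (flattened into vectors) and for $\nabla_{\boldsymbol{\theta}}\boldsymbol{H}$; all sizes are hypothetical. Taking $\Vert\nabla_{\boldsymbol{\theta}}\boldsymbol{H}\Vert$ to be the spectral norm, the inequality follows from $\Vert\boldsymbol{v}\boldsymbol{J}\Vert\leq\Vert\boldsymbol{v}\Vert\,\Vert\boldsymbol{J}\Vert_2$:

```python
import numpy as np

rng = np.random.default_rng(0)
n, dim_h, dim_theta = 3, 8, 50                 # hypothetical sizes
grad_h = rng.standard_normal((n, dim_h))       # row i: flattened grad_H L_i
jac = rng.standard_normal((dim_h, dim_theta))  # stand-in for grad_theta H, shared by all tasks
grads_theta = grad_h @ jac                     # row i: g_i = (grad_H L_i)(grad_theta H)

alpha = np.array([0.2, 0.3, 0.5])
lhs = float(np.sum((alpha @ grads_theta) ** 2))           # ||sum_i alpha_i g_i||^2
spectral = np.linalg.norm(jac, 2)                         # largest singular value of the Jacobian
rhs = float(np.sum((alpha @ grad_h) ** 2)) * spectral**2  # ||sum_i alpha_i grad_H L_i||^2 ||grad_theta H||^2
print(lhs <= rhs)
```

Only `grad_h` (an $n\times \dim\boldsymbol{H}$ array) is needed to minimize the right-hand side over $\alpha$, which is where the savings come from.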
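The orthogonality can be seen in a toy case. The sketch below assumes a batch of four samples in which (hypothetically) task 1 is annotated only on samples 0 and 1 and task 2 only on samples 2 and 3; since $\mathcal{L}_i$ does not depend on the other task's samples, the corresponding rows of $\nabla_{\boldsymbol{H}}\mathcal{L}_i$ are zero, and the inner product vanishes whatever the remaining values are:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, dim = 4, 3
grad_h_task1 = np.zeros((batch, dim))
grad_h_task2 = np.zeros((batch, dim))
grad_h_task1[:2] = rng.standard_normal((2, dim))  # L_1 depends only on h_0, h_1
grad_h_task2[2:] = rng.standard_normal((2, dim))  # L_2 depends only on h_2, h_3

inner = float(np.sum(grad_h_task1 * grad_h_task2))  # <grad_H L_1, grad_H L_2>
print(inner)
```

With a diagonal Gram matrix, the weights from the bound depend only on the individual gradient norms, not on how the tasks interact.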
A Flawed Proof
The "standard answer" and the upper-bound optimization for a shared encoder both come from the paper "Multi-Task Learning as Multi-Objective Optimization". The original paper then attempts to prove that when $\nabla_{\boldsymbol{\theta}}\boldsymbol{H}$ is full rank, optimizing the upper bound also finds a Pareto stationary point. Unfortunately, the proof in the original paper is incorrect.
The proof is located in Appendix A of the original paper, where an incorrect conclusion is used:
If $\boldsymbol{M}$ is a symmetric positive-definite matrix, then $\boldsymbol{x}^{\top}\boldsymbol{y}\geq 0$ if and only if $\boldsymbol{x}^{\top}\boldsymbol{M}\boldsymbol{y}\geq 0$.
It is easy to find a counterexample proving this conclusion is wrong. For instance, $\boldsymbol{x}=\begin{pmatrix}1 \\ -2\end{pmatrix},\boldsymbol{y}=\begin{pmatrix}1 \\ 1\end{pmatrix},\boldsymbol{M}=\begin{pmatrix}3 & 0\\ 0 & 1\end{pmatrix}$. Here, $\boldsymbol{x}^{\top}\boldsymbol{y} < 0$ but $\boldsymbol{x}^{\top}\boldsymbol{M}\boldsymbol{y} > 0$.
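The counterexample is easy to check numerically:

```python
import numpy as np

x = np.array([1.0, -2.0])
y = np.array([1.0, 1.0])
M = np.array([[3.0, 0.0], [0.0, 1.0]])  # symmetric, eigenvalues 3 and 1, so positive definite

print(x @ y)      # -1.0
print(x @ M @ y)  # 1.0
```

The sign flips because $\boldsymbol{M}$ rescales the first coordinate, changing the balance between the positive and negative terms of the inner product.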
Upon reflection, the author believes the proof in the original paper is difficult to salvage. That is, the paper's conclusion does not hold in general: even if $\nabla_{\boldsymbol{\theta}}\boldsymbol{H}$ is full rank, the update direction obtained by optimizing the upper bound might not be one along which no task loss increases, and thus might not lead to a Pareto stationary point. As for why optimizing the upper bound worked well in the original paper's experiments, we can only say that the parameter space of deep learning models is so large, leaving so much room to maneuver, that the upper bound approximation can still yield decent results.
Summary
In this article, we understood multitask learning from the perspective of gradients. From this perspective, the main task of multitask learning is to find an update direction that is, as much as possible, in the opposite direction of the gradient of each task. This ensures that the loss of each task can decrease as much as possible, without sacrificing one task for another. This represents an ideal state of "zero involution" between tasks.
Original Address: https://kexue.fm/archives/8896