By 苏剑林 | February 21, 2025
In the previous article "MoE Grand Tour: 1. Starting from Geometric Significance", we introduced a geometric interpretation of MoE, aiming to derive and understand MoE starting from the best approximation of a Dense model. At the end of that article, we also mentioned that giving the MoE calculation formula is only the beginning. Training a practical and effective MoE model requires many details to be filled in, such as the load balancing problem discussed in this article.
Load balancing, or "not worrying about scarcity but rather about inequality," simply means making sure every Expert is working, and that they are all doing as equal an amount of work as possible to avoid wasting computing power on certain Experts. Load balancing is both a requirement for fully utilizing training computing power and a requirement for maximizing the potential of MoE's large parameter count.
We know that the basic form of MoE is: \begin{equation}\boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho}} \rho_i \boldsymbol{e}_i\end{equation} For traditional MoE, $\boldsymbol{\rho}$ is a probability distribution (Router), $\boldsymbol{e}_i=\boldsymbol{v}_i$, where $\boldsymbol{v}_i$ is the output of a small FFN (Expert). For the geometric MoE we derived in the previous post, $\boldsymbol{\rho}$ has no normalization requirement; it predicts the magnitude of the Expert, while $\boldsymbol{e}_i=\boldsymbol{v}_i/\Vert\boldsymbol{v}_i\Vert$ predicts the direction of the Expert.
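As a minimal sketch of the formula above for a single Token, the following pure-Python snippet computes both the traditional and the geometric variant. The helper names `argtop_k` and `moe_forward` and the toy numbers are assumptions for illustration only. For simplicity, all Expert outputs are precomputed here, whereas a real implementation would only evaluate the selected Experts.

```python
import math

def argtop_k(scores, k):
    """Indices of the k largest scores (ties broken by lower index)."""
    return sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]

def moe_forward(rho, expert_outputs, k, geometric=False):
    """y = sum over the top-k Experts of rho_i * e_i, for one Token.

    rho: list of n router scores
    expert_outputs: list of n vectors v_i (each a list of floats)
    geometric=False -> e_i = v_i; geometric=True -> e_i = v_i / ||v_i||
    """
    dim = len(expert_outputs[0])
    y = [0.0] * dim
    for i in argtop_k(rho, k):
        v = expert_outputs[i]
        if geometric:
            norm = math.sqrt(sum(x * x for x in v))
            v = [x / norm for x in v]
        for d in range(dim):
            y[d] += rho[i] * v[d]
    return y

rho = [0.1, 0.6, 0.3]                      # toy router scores
v = [[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]]   # toy Expert outputs
y_traditional = moe_forward(rho, v, k=2)   # selects Experts 1 and 2
y_geometric = moe_forward(rho, v, k=2, geometric=True)
```

In the geometric variant only the direction of each $\boldsymbol{v}_i$ matters, so the magnitude of the result is carried entirely by $\boldsymbol{\rho}$.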
Whichever form of MoE is used, the actual performance is similar; the difference lies only in the interpretative perspective. However, it should be noted that although the MoE formula gives the impression of "finding the corresponding Expert to calculate every time a Token is encountered," actual training works the other way around: first, computing power is allocated to each Expert, and then Tokens are dispatched (Routed) to their respective Experts for parallel calculation. This is why the $\boldsymbol{\rho}$ responsible for scoring is called the Router.
Consequently, if the Expert distribution is uneven, the following situation may occur: certain Experts (Dead Experts) are almost always idle, wasting computing power; while other Experts have too many Tokens to process and cannot keep up, resulting in Token Drop (i.e., giving up on processing some Tokens). Theoretically, the appearance of Dead Experts means that MoE has not reached its expected parameter capacity; that is, it consumes the VRAM of a large parameter model but yields the training effect of a small parameter model.
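The situation described above can be illustrated with a small sketch. It assumes a fixed per-Expert capacity and Top-1 routing, and it drops overflow Tokens in arrival order. These choices and the function name `dispatch` are assumptions for illustration, not a description of any particular framework.

```python
def dispatch(top1_choices, n_experts, capacity):
    """Assign Tokens to Experts in order; overflow Tokens are dropped.

    top1_choices: Expert index chosen by the Router for each Token (k = 1)
    capacity: max Tokens each Expert processes per batch (assumed fixed)
    """
    buckets = [[] for _ in range(n_experts)]
    dropped = []
    for token, expert in enumerate(top1_choices):
        if len(buckets[expert]) < capacity:
            buckets[expert].append(token)
        else:
            dropped.append(token)
    return buckets, dropped

# 8 Tokens, with a Router that heavily favors Expert 0.
choices = [0, 0, 0, 0, 0, 1, 0, 0]
buckets, dropped = dispatch(choices, n_experts=4, capacity=2)
idle = [i for i, b in enumerate(buckets) if not b]  # Experts with no work
```

With this skewed routing, Experts 2 and 3 receive nothing while five of the eight Tokens overflow Expert 0 and are dropped, which is exactly the combination of wasted capacity and Token Drop that load balancing tries to avoid.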
Therefore, whether from the perspective of training or performance, we want to ensure Expert load balance.
A conventional approach to promoting load balance is to add a related loss function, which we usually call "Aux Loss (Auxiliary Loss)." The current mainstream Aux Loss can be traced back to the 2020 paper "GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding".
Before introducing Aux Loss, we need to introduce some new concepts. First, as we mentioned, for a general MoE, $\boldsymbol{\rho}$ is not necessarily a probability distribution. We denote the normalized version of $\boldsymbol{\rho}$ as $\boldsymbol{p}=[p_1,p_2,\dots,p_n]$, and its Top-$k$ version as $\boldsymbol{f}=[f_1,f_2,\dots,f_n]$, where: \begin{equation}p_i = \frac{\rho_i}{\sum_{j=1}^n \rho_j},\qquad f_i = \left\{\begin{aligned}1/k, \quad i\in \mathop{\text{argtop}}_k \boldsymbol{\rho} \\ 0, \quad i\not\in \mathop{\text{argtop}}_k \boldsymbol{\rho}\end{aligned}\right.\end{equation} Next, we define $\boldsymbol{P}=\mathbb{E}[\boldsymbol{p}]$ and $\boldsymbol{F}=\mathbb{E}[\boldsymbol{f}]$, where $\mathbb{E}$ refers to the average over all Tokens of all samples. It is not difficult to see that $\boldsymbol{F}$ is the current load distribution of the Experts, while $\boldsymbol{P}$ serves as a smooth approximation of $\boldsymbol{F}$.
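To make the definitions of $\boldsymbol{P}$ and $\boldsymbol{F}$ concrete, here is a small sketch that computes both from a batch of router scores. It assumes non-negative scores with a positive sum per Token, and the names `load_statistics` and `rhos` are made up for this example.

```python
def argtop_k(scores, k):
    """Indices of the k largest scores (ties broken by lower index)."""
    return sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]

def load_statistics(rhos, k):
    """Return (P, F): Token-averaged normalized scores and Top-k loads."""
    n = len(rhos[0])
    num_tokens = len(rhos)
    P = [0.0] * n
    F = [0.0] * n
    for rho in rhos:
        total = sum(rho)                 # assumes rho >= 0 and total > 0
        chosen = set(argtop_k(rho, k))
        for i in range(n):
            P[i] += (rho[i] / total) / num_tokens
            if i in chosen:
                F[i] += (1.0 / k) / num_tokens
    return P, F

# 4 Tokens, 3 Experts, Top-1 routing.
rhos = [[0.7, 0.2, 0.1],
        [0.6, 0.3, 0.1],
        [0.5, 0.1, 0.4],
        [0.2, 0.5, 0.3]]
P, F = load_statistics(rhos, k=1)
```

Here three of the four Tokens pick Expert 0, so $\boldsymbol{F}$ is heavily skewed and Expert 2 is never selected, while $\boldsymbol{P}$ shows the same ordering in a smoother form.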
With these notations, we can write the Aux Loss as: \begin{equation}\mathcal{L}_{\text{aux}} = \boldsymbol{F}\cdot \boldsymbol{P} = \sum_{i=1}^n F_i P_i\label{eq:aux-loss}\end{equation} Much of the literature defines Aux Loss with an additional factor of $n$, meaning their Aux Loss is equal to $n \mathcal{L}_{\text{aux}}$ here. Additionally, some large-scale MoEs calculate Aux Loss per device to achieve intra-device balance and reduce inter-device communication; these are adaptations of the same idea. However, some recent experiments suggest that forcing local balance may well hurt the model's final performance.
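As a quick numerical illustration of this definition, the sketch below evaluates $\boldsymbol{F}\cdot\boldsymbol{P}$ with and without the factor of $n$. The function name `aux_loss`, its `scale_by_n` flag, and the example vectors are assumptions made for this example.

```python
def aux_loss(F, P, scale_by_n=False):
    """Aux Loss F . P; optionally multiplied by n, as in much of the literature."""
    value = sum(f * p for f, p in zip(F, P))
    return len(F) * value if scale_by_n else value

# A skewed load (made-up numbers) versus a perfectly uniform one.
F = [0.75, 0.25, 0.0]
P = [0.5, 0.275, 0.225]
skewed = aux_loss(F, P)

uniform = [1.0 / 3] * 3
balanced = aux_loss(uniform, uniform)
balanced_scaled = aux_loss(uniform, uniform, scale_by_n=True)
```

One convenience of the $n$-scaled convention is that its value at perfect balance is $1$ regardless of the number of Experts, which makes logged values easier to compare across configurations.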
I wonder if anyone has noticed a strange phenomenon: whether in the earliest source, subsequent literature, or popular science articles, basically every resource I have read cites this Aux Loss without proof, as if everyone agrees that the above Aux Loss promoting balance is a self-evident fact. But is it really that obvious?
At least I couldn't see it immediately. Therefore, in the following sections, I will provide a derivation for Eq. $\eqref{eq:aux-loss}$. From this thought process, we can also customize other forms of Aux Loss. First, define a uniform distribution $\boldsymbol{Q}=(1/n, 1/n, \dots, 1/n)$. As we said, $\boldsymbol{F}$ is the current load distribution; therefore, load balancing is equivalent to $\boldsymbol{F}=\boldsymbol{Q}$. Thus, the following formula is a relatively intuitive Aux Loss: \begin{equation}\mathcal{L}_{\text{aux}} = \frac{1}{2}\Vert\boldsymbol{F} - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (F_i - 1/n)^2\label{eq:aux-loss-2}\end{equation} The problem is that $\boldsymbol{F}$ is produced by $\mathop{\text{argtop}}_k$, which means that the above formula is not a directly usable differentiable objective. How to solve this? The answer is the STE (Straight-Through Estimator) trick, which involves designing different functions for forward and backward propagation. Specifically, $\boldsymbol{F}$ is non-differentiable, but $\boldsymbol{P}$ as its smooth approximation is differentiable. Thus, we can replace $\boldsymbol{F}$ with $\boldsymbol{P}$ during backward propagation: \begin{equation}\mathcal{L}_{\text{aux}} = \frac{1}{2}\Vert \boldsymbol{P} + \text{sg}[\boldsymbol{F}-\boldsymbol{P}] - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (P_i + \text{sg}[F_i - P_i] - 1/n)^2\label{eq:aux-loss-3}\end{equation} where $\text{sg}[]$ is the stop gradient operator, which keeps the forward output unchanged but forces the gradient to be zero. After this modification, $\mathcal{L}_{\text{aux}}$ becomes a feasible Aux Loss. 
Let's try to find its gradient: \begin{equation}\begin{aligned} \nabla_{\boldsymbol{\theta}}\mathcal{L}_{\text{aux}} =&\, \frac{1}{2}\nabla_{\boldsymbol{\theta}}\sum_{i=1}^n (P_i + \text{sg}[F_i - P_i] - 1/n)^2 \\ =&\, \sum_{i=1}^n (P_i + \text{sg}[F_i - P_i] - 1/n) \nabla_{\boldsymbol{\theta}}(P_i + \text{sg}[F_i - P_i] - 1/n)\\ =&\, \sum_{i=1}^n (F_i - 1/n) \nabla_{\boldsymbol{\theta}}P_i = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n (F_i - 1/n) P_i\\ =&\, \nabla_{\boldsymbol{\theta}}\left(\sum_{i=1}^n F_i P_i\right) \end{aligned}\end{equation} where $\boldsymbol{\theta}$ represents the model parameters and $\boldsymbol{F}$ is treated as a constant when differentiating. The final result shows that the gradient of Eq. $\eqref{eq:aux-loss-3}$ is equal to the gradient of Eq. $\eqref{eq:aux-loss}$, which means using Eq. $\eqref{eq:aux-loss}$ as Aux Loss is equivalent to Eq. $\eqref{eq:aux-loss-3}$ in terms of gradients. This is where the Aux Loss in Eq. $\eqref{eq:aux-loss}$ comes from.
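The gradient equivalence can also be checked numerically. The sketch below is a toy setup under stated assumptions: the parameters $\boldsymbol{\theta}$ are per-Token logits, $\boldsymbol{p}$ is their softmax, routing is Top-1, and the stop gradient is emulated by freezing $\boldsymbol{F}$ and $\boldsymbol{F}-\boldsymbol{P}$ at their initial values. Gradients are estimated by central finite differences, so no autodiff library is needed.

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(x - m) for x in z]
    s = sum(e)
    return [x / s for x in e]

def mean_probs(logits):
    """P = average over Tokens of softmax(logits[t])."""
    probs = [softmax(row) for row in logits]
    return [sum(p[i] for p in probs) / len(probs) for i in range(len(logits[0]))]

def top1_load(logits):
    """F = fraction of Tokens whose argmax is each Expert (k = 1)."""
    F = [0.0] * len(logits[0])
    for row in logits:
        F[row.index(max(row))] += 1.0 / len(logits)
    return F

def numeric_grad(loss, logits, eps=1e-6):
    """Central finite differences of loss w.r.t. every logit."""
    grad = []
    for t in range(len(logits)):
        grad_row = []
        for i in range(len(logits[t])):
            plus = [r[:] for r in logits]
            minus = [r[:] for r in logits]
            plus[t][i] += eps
            minus[t][i] -= eps
            grad_row.append((loss(plus) - loss(minus)) / (2 * eps))
        grad.append(grad_row)
    return grad

logits0 = [[2.0, 0.5, -1.0], [1.5, 1.0, 0.0], [0.2, 0.1, 0.9], [1.0, -0.5, 0.3]]
n = 3
F0 = top1_load(logits0)                          # frozen, as sg[] would do
P0 = mean_probs(logits0)
offset = [F0[i] - P0[i] for i in range(n)]       # frozen sg[F - P]

def ste_loss(logits):
    """0.5 * ||P + sg[F - P] - Q||^2 with uniform Q."""
    P = mean_probs(logits)
    return 0.5 * sum((P[i] + offset[i] - 1.0 / n) ** 2 for i in range(n))

def dot_loss(logits):
    """F . P with F held constant."""
    P = mean_probs(logits)
    return sum(F0[i] * P[i] for i in range(n))

g_ste = numeric_grad(ste_loss, logits0)
g_dot = numeric_grad(dot_loss, logits0)
max_gap = max(abs(a - b) for ra, rb in zip(g_ste, g_dot) for a, b in zip(ra, rb))
```

Up to finite-difference error, `max_gap` should be essentially zero, even though the two loss values themselves are different.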
However, Eq. $\eqref{eq:aux-loss}$ only has the meaning of an equivalent gradient but does not have the meaning of a "loss." It is not a true Loss. For example, at perfect balance, when $\boldsymbol{F} = \boldsymbol{P} = \boldsymbol{Q}$, we can calculate that Eq. $\eqref{eq:aux-loss}$ equals $1/n$, but in fact, we can construct an $\boldsymbol{F}$ not equal to $\boldsymbol{P}$ to make it smaller than $1/n$. Thus, Eq. $\eqref{eq:aux-loss}$ is not like a normal Loss that is "the smaller, the better," nor is its minimum value achieved at the balanced point.
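To make this concrete, here is a small constructed example with $n=2$ Experts and Top-1 routing. The per-Token scores are made up for illustration. Three Tokens barely prefer Expert 0 and two Tokens put all their score on Expert 1, so the load is unbalanced, yet $\boldsymbol{F}\cdot\boldsymbol{P}$ drops below the value $1/n$ obtained at perfect balance.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

n = 2
# Normalized router scores p for 5 Tokens (k = 1: each Token goes to its argmax).
tokens = [[0.51, 0.49]] * 3 + [[0.0, 1.0]] * 2

F = [0.0] * n
P = [0.0] * n
for p in tokens:
    F[p.index(max(p))] += 1.0 / len(tokens)   # load distribution
    for i in range(n):
        P[i] += p[i] / len(tokens)            # mean normalized score

balanced = dot([1.0 / n] * n, [1.0 / n] * n)  # F = P = Q gives 1/n = 0.5
unbalanced = dot(F, P)                        # F = (0.6, 0.4), P = (0.306, 0.694)
```

Here `unbalanced` is about $0.4612$, which is smaller than $0.5$ even though the load $(0.6, 0.4)$ is not uniform.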
The derivation above actually provides a general idea for constructing Aux Loss: first, construct a loss based on $\boldsymbol{F}$ that meets the requirements, and then replace $\boldsymbol{F}$ with $\boldsymbol{P} + \text{sg}[\boldsymbol{F}-\boldsymbol{P}]$ in the implementation. For example, we know that maximizing entropy also pushes a distribution toward balance, so we can also use the negative entropy to construct an Aux Loss: \begin{equation}\mathcal{L}_{\text{aux}} = \sum_{i=1}^n (P_i + \text{sg}[F_i - P_i])\log(P_i + \text{sg}[F_i - P_i])\end{equation} The above formula can be directly used for code implementation. Of course, if we seek simplification, we can similarly calculate the gradient, and the result will be: \begin{equation}\nabla_{\boldsymbol{\theta}}\mathcal{L}_{\text{aux}} = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n(P_i + \text{sg}[F_i - P_i]) \log(P_i + \text{sg}[F_i - P_i]) = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n P_i \log F_i\end{equation} In both gradient simplifications, we used the following identity: \begin{equation}\sum_{i=1}^n \nabla_{\boldsymbol{\theta}}P_i = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n P_i = \nabla_{\boldsymbol{\theta}}1 = \boldsymbol{0}\end{equation} This relies on the fact that $\boldsymbol{P}$ is a probability distribution and the target distribution $\boldsymbol{Q}$ is uniform. If we do not seek a simplified equivalent result but directly use the Aux Loss in the $\boldsymbol{F} \to \boldsymbol{P} + \text{sg}[\boldsymbol{F}-\boldsymbol{P}]$ form, then we are not restricted by these two constraints.
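The entropy-based simplification can be checked in the same spirit. The sketch below uses a deliberately simplified toy parametrization, $\boldsymbol{P}=\text{softmax}(\boldsymbol{\theta})$ for a single parameter vector, and a hand-picked constant $\boldsymbol{F}$ standing in for the frozen load. Both are assumptions for illustration. Every $F_i$ is chosen strictly positive, since $\log F_i$ is undefined for an Expert that receives no Tokens.

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(x - m) for x in z]
    s = sum(e)
    return [x / s for x in e]

def numeric_grad(loss, theta, eps=1e-6):
    """Central finite differences of loss w.r.t. each parameter."""
    grad = []
    for i in range(len(theta)):
        plus, minus = theta[:], theta[:]
        plus[i] += eps
        minus[i] -= eps
        grad.append((loss(plus) - loss(minus)) / (2 * eps))
    return grad

theta0 = [0.3, -0.2, 0.5, 0.0]
F0 = [0.4, 0.1, 0.3, 0.2]                    # hand-picked load, frozen like sg[]
P0 = softmax(theta0)
offset = [f - p for f, p in zip(F0, P0)]     # frozen sg[F - P]

def neg_entropy_ste(theta):
    """sum_i x_i log x_i with x = P + sg[F - P]."""
    x = [p + o for p, o in zip(softmax(theta), offset)]
    return sum(v * math.log(v) for v in x)

def simplified(theta):
    """sum_i P_i log F_i with F held constant."""
    return sum(p * math.log(f) for p, f in zip(softmax(theta), F0))

g_full = numeric_grad(neg_entropy_ste, theta0)
g_simple = numeric_grad(simplified, theta0)
max_gap = max(abs(a - b) for a, b in zip(g_full, g_simple))
```

As before, the two objectives take different values but their gradients at `theta0` should agree up to finite-difference error, because the extra $\sum_i \nabla_{\boldsymbol{\theta}}P_i$ term vanishes for a normalized $\boldsymbol{P}$.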
For instance, regarding $\boldsymbol{P}$ as a smooth approximation of $\boldsymbol{F}$, we only used the property that "if $P_i$ is large, $F_i$ is usually large." Therefore, using a non-normalized $\mathbb{E}[\boldsymbol{\rho}]$ as $\boldsymbol{P}$ is usually fine. This point might be critical in some special scenarios (e.g., when $\boldsymbol{\rho}$ has both positive and negative values), because normalization into a probability distribution might be impossible. Furthermore, the goal $\Vert\boldsymbol{F} - \boldsymbol{Q}\Vert^2$ can obviously push $\boldsymbol{F}$ toward any target distribution $\boldsymbol{Q}$ we want, not necessarily a uniform one.
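As a sketch of what changes for a general target, we can repeat the gradient computation for $\frac{1}{2}\Vert \boldsymbol{P} + \text{sg}[\boldsymbol{F}-\boldsymbol{P}] - \boldsymbol{Q}\Vert^2$ without assuming that $\boldsymbol{Q}$ is uniform or that $\boldsymbol{P}$ is normalized, treating $\boldsymbol{F}$ and $\boldsymbol{Q}$ as constants: \begin{equation}\nabla_{\boldsymbol{\theta}}\mathcal{L}_{\text{aux}} = \sum_{i=1}^n (F_i - Q_i) \nabla_{\boldsymbol{\theta}}P_i = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n (F_i - Q_i) P_i\end{equation} In this case the $Q_i$ term no longer drops out, so the gradient-equivalent form is $(\boldsymbol{F}-\boldsymbol{Q})\cdot\boldsymbol{P}$ rather than $\boldsymbol{F}\cdot\boldsymbol{P}$. It reduces to the latter only when $\boldsymbol{Q}$ is uniform and $\sum_{i=1}^n \nabla_{\boldsymbol{\theta}}P_i = \boldsymbol{0}$.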
This article introduced the load balancing problem in MoE and presented a general approach for constructing Aux Loss. Besides Aux Loss, there are other schemes for promoting load balance, which we will discuss next time.