Abnormalities and Countermeasures for Embeddings under Lion/Tiger Optimizer Training

By 苏剑林 | August 28, 2023

Ever since I proposed the Tiger optimizer in "Tiger: An Extremely Frugal Optimizer", Tiger has become my "standard" optimizer for training models. Recently, I attempted to apply Tiger to the pre-training of a model with 7 billion parameters. The initial results looked promising, preliminarily suggesting that Tiger is capable of scaling up. However, upon inspecting the weights of the trained model, I discovered some abnormalities in the Embeddings: certain components of the Embedding had reached values of \(\pm 100\).

After analysis, I found that similar phenomena do not occur with Adam; this is an issue specific to optimizers such as Tiger and Lion that use the \(\text{sign}\) function. At the end of this article, I provide two reference solutions for the issue; here I record the analysis process for everyone's reference.

Phenomenon

Next, we will use the Tiger optimizer as an example for our analysis, but the process and conclusions apply equally to Lion.

First, the phenomena I observed were as follows:

1. Some Embedding components for certain tokens became \(\pm 100\);
2. A small portion of other tokens' Embedding components were trending towards \(\pm 100\);
3. These tokens appeared to be very low-frequency tokens;
4. The maximum value of the entire Embedding matrix was 100, and the minimum was -100;
5. Except for the Embedding, other weights did not exhibit this problem;
6. The overall performance of the model (such as training loss and generation tests) was normal.

Some readers might wonder: since the model performance is normal, why bother? In my view, there are at least two reasons. First, if you want to perform fine-tuning later, some low-frequency tokens might become high-frequency again; if the Embeddings for these tokens are too poor, fine-tuning might not be able to save them. Second, some capabilities are not reflected in the loss. For instance, in Chinese-English pre-trained models, because the training data contains a small amount of data in other languages, the model usually exhibits some multilingual capability. Clearly, this capability relies on the quality of the low-frequency token Embeddings; if the optimizer compromises it, that would be a real loss.

Of course, regardless of the optimizer, it is always possible for the training to collapse, which is not surprising and often difficult to investigate deeply. However, the most intriguing thing here is how "regularly" it collapsed—it reached exactly \(\pm 100\). This compelled me to further investigate the underlying cause.

Thinking

Based on the observations above, we can initially conclude that these outliers only appear in the "Embeddings of low-frequency tokens." This reminded me of the problem discussed in "Implementing Two Optimizers in Keras: Lookahead and LazyOptimizer", where optimizers with momentum can lead to the over-optimization of the Embedding layer.

Specifically, as long as a token appears once, the momentum corresponding to that token's Embedding is updated to a non-zero value (assuming the gradient is not exactly zero). Consequently, in subsequent updates, even if the token does not appear in the current sample (its gradient is zero), its Embedding is still updated (because the momentum is non-zero). This is the over-optimization problem for low-frequency tokens. It occurs in all optimizers with momentum, including Adam and Tiger. In Adam, however, it is barely noticeable, because the update amount is proportional to the momentum: if a token does not reappear for a long time, the momentum decays exponentially and quickly approaches zero, so the update amount also quickly tends to zero and the over-updating soon stops.

However, the situation is different in Tiger. The update amount in Tiger is proportional to the sign of the momentum, \(\text{sign}(m_t)\). Although the momentum \(m_t\) decays exponentially, its sign does not: until \(m_t\) is rounded down to exactly zero, \(\text{sign}(m_t)\) remains \(\pm 1\), so the update amount stays constant. The over-updating problem for Embeddings is therefore much more severe in Tiger. To make matters worse, once a token's Embedding has drifted in a certain direction due to over-updating, its gradient may adapt to and reinforce this change: the next time the token appears, the gradient may point in the same direction rather than the opposite one. This leads to long-term over-updating in the same direction, eventually producing the outliers.
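To make this concrete, here is a toy numerical sketch (my own illustration, with an assumed momentum decay rate) contrasting the two behaviours after a token's last appearance, when its gradient stays zero: an Adam-style update shrinks together with the momentum, while a sign-based update keeps its full size until the momentum underflows to exactly zero.

```python
# Toy comparison (not real optimizer code): after a token's last appearance its
# gradient stays zero, so the momentum m only decays step by step.
beta = 0.9   # assumed momentum decay rate, for illustration only
m0 = 1.0     # momentum right after the token's last appearance

for step in (10, 100, 1000):
    m_t = m0 * beta ** step                  # momentum after `step` zero-gradient updates
    adam_like = abs(m_t)                     # Adam-style: update size proportional to |m_t|
    sign_based = 1.0 if m_t != 0 else 0.0    # Tiger/Lion: |sign(m_t)| stays at 1
    print(step, adam_like, sign_based)
# adam_like quickly approaches 0, while sign_based remains 1.0 throughout.
```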

Calculation

So, why is the outlier value exactly \(\pm 100\)? This is where weight decay comes into play. The total update formula for Tiger is:

\begin{equation} \theta_t = \theta_{t-1} - \eta_t [\text{sign}(m_t) + \lambda \theta_{t-1}] \label{eq:tiger_update} \end{equation}

In other words, in addition to the sign of the momentum, there is a weight decay term. In the abnormal experiment mentioned at the beginning, the decay rate \(\lambda\) was set to 0.01.

It is not hard to see that if \(\text{sign}(m_t)\) remains constant for a long time, the iteration formula above will have an equilibrium point. It occurs when \(\text{sign}(m_t) + \lambda \theta^* = 0\), which is:

\begin{equation} \theta^* = -\frac{\text{sign}(m_t)}{\lambda} \label{eq:equilibrium} \end{equation}

Since \(\lambda = 0.01\), the equilibrium is exactly a vector whose elements are \(\pm 1/0.01 = \pm 100\), which explains why the outliers are \(\pm 100\). If interested, readers can also assume \(\eta_t\) is constant and solve for the analytical form of \(\theta_t\) to further analyze the convergence speed, etc.; I will not expand on that here.
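As a quick check, the following sketch iterates the Tiger update with \(\text{sign}(m_t)\) fixed at \(+1\) and \(\lambda = 0.01\) (the setting from the abnormal run); the learning rate is an assumed constant used only for illustration. The iterate converges to \(-1/\lambda = -100\), matching the observed outliers.

```python
# Iterate theta_t = theta_{t-1} - eta * (sign(m_t) + lam * theta_{t-1})
# with sign(m_t) held at +1. eta is an assumed constant for illustration.
eta, lam = 0.01, 0.01
theta = 0.0
for _ in range(200_000):
    theta -= eta * (1.0 + lam * theta)
print(theta)  # ~ -100.0, i.e. the equilibrium -sign(m_t) / lam
```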

Countermeasures

Since the problem arises from the over-updating of low-frequency token Embeddings, a natural solution is to "Lazy-ify" the Embedding updates, as suggested in "Implementing Two Optimizers in Keras: Lookahead and LazyOptimizer": only update an Embedding row when the corresponding token actually appears. If we can access the set of all input token IDs, we can update only those Embeddings; if not, we can decide whether an Embedding row needs to be updated by checking whether its gradient norm is non-zero.
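Below is a minimal sketch of this idea (my own NumPy illustration, not the implementation from the cited article): a Tiger-like step for the Embedding matrix that only touches rows whose gradient is non-zero, i.e. tokens that actually appeared in the current batch.

```python
import numpy as np

def lazy_tiger_embedding_step(emb, grad, m, lr=1e-3, beta=0.965, lam=0.01):
    """One Tiger-like update restricted to embedding rows with non-zero gradient.

    emb, grad, m: arrays of shape [vocab_size, dim]. The hyperparameter values
    are assumptions for illustration; whether to freeze or decay the momentum of
    absent rows is a separate design choice (frozen here).
    """
    active = np.any(grad != 0, axis=1, keepdims=True)      # rows seen in this batch
    m = np.where(active, beta * m + (1 - beta) * grad, m)  # refresh momentum of active rows only
    update = np.sign(m) + lam * emb                        # sign-of-momentum plus weight decay
    emb = np.where(active, emb - lr * update, emb)         # rows of absent tokens are left untouched
    return emb, m
```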

On the other hand, from a more general perspective, this problem is a common defect of the Lion/Tiger optimizers for parameters with sparse gradients, including but not limited to the Embedding layer. Therefore, another approach is to make the Embedding gradients non-sparse. To this end, we can consider Tied Embeddings, i.e., sharing the input and output Embeddings. Since the output end reuses the entire Embedding matrix, the whole matrix receives non-zero gradients, preventing \(\text{sign}(m_t)\) from staying constant for long periods. Of course, Tied Embeddings may bring other problems; corresponding solutions can be found in "Re-exploration of Shared Embeddings at the Output End of Language Models". In my experiments, using Tied Embeddings that swap half of the model's feature channels solved the problem above, and the result even seemed better than with untied Embeddings.
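For reference, here is a rough PyTorch sketch of Tied Embeddings with the half-channel swap mentioned above; the exact recipe in the cited article may differ, and the class and method names are my own.

```python
import torch
import torch.nn as nn

class TiedEmbeddingLMHead(nn.Module):
    """Input embedding and output projection share a single weight matrix."""

    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)

    def embed(self, token_ids: torch.Tensor) -> torch.Tensor:
        return self.emb(token_ids)

    def logits(self, hidden: torch.Tensor) -> torch.Tensor:
        w = self.emb.weight
        half = w.shape[1] // 2
        # Swap the two halves of the feature channels before reusing the weights,
        # so the output projection is not simply the transpose of the input table.
        w = torch.cat([w[:, half:], w[:, :half]], dim=-1)
        # Every vocabulary row enters the logits, so the whole embedding matrix
        # receives a dense (non-zero) gradient at every step.
        return hidden @ w.t()
```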

Finally, I also consulted the authors of the Lion optimizer about this issue. Their reply was that they had noticed the problem before, and their solution was a hybrid optimizer: for example, using Adam for the Embedding layer and Lion/Tiger for the other layers. Well, this was a solution I hadn't thought of. It doesn't feel particularly elegant, but it certainly solves the problem. Readers can choose for themselves.
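A rough sketch of that hybrid setup is given below (assumptions: a toy model, arbitrary hyperparameters, and plain SGD standing in for Lion/Tiger, since PyTorch ships neither).

```python
import torch
import torch.nn as nn

# Toy two-part model: an Embedding layer plus "everything else".
model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 1000))

opt_emb = torch.optim.Adam(model[0].parameters(), lr=1e-4)   # Adam for the Embedding layer
opt_rest = torch.optim.SGD(model[1].parameters(), lr=1e-4)   # replace with your Lion/Tiger implementation

tokens = torch.randint(0, 1000, (8, 16))
loss = model(tokens).pow(2).mean()    # dummy loss, just to drive one optimization step
loss.backward()
opt_emb.step(); opt_rest.step()
opt_emb.zero_grad(); opt_rest.zero_grad()
```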

Summary

This article introduced the phenomenon of Embedding abnormalities during training with the Lion/Tiger optimizers, analyzed the underlying causes, and provided reference solutions.