Using Mixed Precision and XLA to Accelerate Training in bert4keras

By 苏剑林 | April 28, 2022

Previously, I have always focused on model conception and implementation, rarely paying attention to model training acceleration. Although I had heard of technologies like mixed precision and XLA, I had never truly put them into practice. Over the past two days, after some tinkering, I successfully used mixed precision and XLA to accelerate training in bert4keras. Here is a brief summary for your reference.

Most of the empirical conclusions in this article are not limited to use within bert4keras. The reason bert4keras is emphasized in the title is simply that the model implementations in bert4keras are relatively structured, making the modifications required to enable these acceleration techniques relatively minimal.

Experimental Environment

The graphics card used for the experiments in this article is an RTX 3090, and the Docker image used is nvcr.io/nvidia/tensorflow:21.09-tf1-py3, which comes with TensorFlow 1.15.5. The version of bert4keras used in the experiments is 0.11.3. Other environments can be set up along the same lines, but keep an experimental mindset: don't expect everything to work if you just call it blindly.

As a side note, cards like the 3090 and A100 can only use CUDA 11, and Google's official TensorFlow 1.15 does not support CUDA 11. If you still want to use TensorFlow 1.x, your only options are nvidia-tensorflow, maintained by NVIDIA itself, or the Docker images NVIDIA builds. Using NVIDIA's TensorFlow instead of Google's not only lets you run 1.x on the latest graphics cards but also brings NVIDIA-specific additional optimizations. For details, refer to the documentation here.

Don't say things like "TensorFlow is already at 2.8, why are you still using 1.15?". Your graphics card is made by NVIDIA, so who gets to say which version of TensorFlow is best to use? Not you or me, and not even Google; NVIDIA has the final word. Since NVIDIA is still maintaining 1.15, that means 1.15 is still the GOAT (forever the best).

Mixed Precision

First, let's look at mixed precision training. Simply put, model calculations use FP16, while parameter updates and storage use FP32. The representation range of FP16 is roughly $6 \times 10^{-8} \sim 65504$. Both its upper and lower bounds are limits we might hit when implementing a model. Therefore, the biggest problems introduced by FP16 are overflow and precision loss. For a more detailed introduction to the principles, you can search for them yourself; this article focuses on how to use it.
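To make this range concrete, here is a quick NumPy check (purely illustrative, not part of the original experiments) showing FP16 overflowing past 65504 and flushing tiny values to zero:

import numpy as np
print(np.finfo(np.float16).max)  # 65504.0, the largest representable FP16 value
print(np.float16(70000.0))       # inf: overflow beyond the FP16 range
print(np.float16(1e-8))          # 0.0: a gradient this small underflows to zero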

NVIDIA-TensorFlow's help documentation provides an introduction to mixed precision training here. The simplest way to start mixed precision training is to add an environment variable at the beginning of the script:

import os  # set this at the very beginning of the training script
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE'] = '1'

Readers might notice that most tutorials introduce TF_ENABLE_AUTO_MIXED_PRECISION, while I use TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE. The difference is that the former automatically adds "Dynamic Loss Scaling" while the latter does not. However, my tests found that "Dynamic Loss Scaling" is no substitute for manually scaling the loss, so I simply skip that feature altogether.
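For completeness, the variant with automatic Dynamic Loss Scaling mentioned above is enabled in the same way; if you want to compare the two behaviours yourself, it would look like this:

os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'  # graph rewrite plus automatic Dynamic Loss Scaling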

After adding the environment variable, you can restart the training script to see the situation. If NaN occurs as soon as training begins, you can adjust infinity and epsilon:

# For example (K is the keras backend exported by bert4keras):
from bert4keras.backend import K
K.set_infinity(1e4)  # helper provided by recent bert4keras versions
K.set_epsilon(1e-4)

After adjustment, it usually won't NaN immediately (if it still does, check whether other parts of the model use an infinity or epsilon not controlled by these functions, and modify them). However, what might happen is that the loss drops at first, then rises, and finally NaNs. The reason is that some parameters, whether due to poor initialization or by deliberate design (as in DeepNet), have extremely small gradients (below $10^{-8}$). At FP16 precision these gradients are effectively zero, so those parameters are never updated; equivalently, the gradients are inaccurate, and updating with inaccurate gradients over a long period easily leads to non-convergence.

At this point, the solution is "Loss Scaling." We can directly multiply the loss function by a relatively large amplification factor (e.g., 1000; you can tune this yourself, and the larger the better as long as no NaN appears), so that originally tiny gradients are amplified into the FP16 range instead of being flushed to zero, thus avoiding the loss of gradient precision. For the optimizers we usually use, such as Adam or LAMB, which normalize the gradient, multiplying the loss by a constant does not change the resulting parameter updates (the constant cancels out in the update rule), meaning they are perfectly compatible with "Loss Scaling."
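As a minimal sketch (the factor of 1000, the loss function, and the optimizer below are placeholders for your own setup, not the exact code used in my experiments), constant loss scaling just means wrapping the loss before compiling the model:

from bert4keras.backend import keras, K

loss_scale = 1000.0  # tune: the larger the better, as long as no NaN appears

def scaled_loss(y_true, y_pred):
    # the constant factor lifts tiny gradients into the representable FP16 range
    return loss_scale * K.sparse_categorical_crossentropy(y_true, y_pred)

# `model` is your existing bert4keras/Keras model
model.compile(loss=scaled_loss, optimizer=keras.optimizers.Adam(2e-5))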

In fact, I've found that the "Loss Scaling" trick is effective not just in mixed precision training scenarios, but even in full FP32 training: in full FP32 training, if loss scaling is not performed, the model might stay at a certain loss value for a while before starting to drop; if loss scaling is applied, the model maintains a slow downward trend from the beginning, resulting in relatively faster convergence.

XLA Acceleration

Now let's look at XLA, which stands for "Accelerated Linear Algebra" and is designed specifically to speed up linear algebra operations. Simply put, XLA compiles and optimizes the computation graph, fusing operators that can be fused (reducing intermediate variables to save memory) and parallelizing operators that can be parallelized (increasing computation speed).

In NVIDIA-TensorFlow, the simplest way to enable XLA is still by adding an environment variable:

os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'

However, note that XLA is not guaranteed to give an improvement. As mentioned, XLA tries to parallelize as many operators as possible, which is clearly a strategy of trading space for time. Therefore, enabling XLA may consume more VRAM, leading to OOM (Out of Memory), and if the fused operator clusters grow too large, it may even degrade performance. The official documentation provides a detailed analysis of possible anomalies and corresponding suggestions. My recommended remedy is to supplement with the --tf_xla_enable_lazy_compilation=false flag:

os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit --tf_xla_enable_lazy_compilation=false'

If that doesn't solve it, switch to XLA Lite:

os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=fusible'

If even XLA Lite doesn't solve the issue, then it basically means XLA is not suitable for your model.

Performance Comparison

On a 3090, the speedup from enabling mixed precision training is a bit over 10%. This improvement is probably smaller than most people expect. I suspect the reason is that on newer cards like the 3090 and A100, what runs by default as FP32 is actually a format called TF32 (refer here). TF32 is, in a sense, already a "half-precision" format and is faster than FP32. In other words, FP32 on the 3090 already comes with some half-precision-style optimization and is naturally faster, so the additional boost from switching to mixed precision is relatively small.
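If you want to verify this explanation on your own machine, NVIDIA's math libraries expose an NVIDIA_TF32_OVERRIDE environment variable (to the best of my knowledge; check the TF32 documentation for your container version) that turns TF32 off globally, so you can time true FP32 against the default:

import os  # set before any CUDA library is initialized
os.environ['NVIDIA_TF32_OVERRIDE'] = '0'  # assumption: 0 disables TF32 in cuBLAS/cuDNN, forcing real FP32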

As for the boost from XLA, it is roughly around 15%. In my training script, directly setting the environment variable TF_XLA_FLAGS to --tf_xla_auto_jit=2 resulted in OOM, and so did supplementing it with --tf_xla_enable_lazy_compilation=false; however, changing it to --tf_xla_auto_jit=fusible allowed training to proceed normally.

Finally, the most crucial point is that mixed precision and XLA can be used together! Using both together brings a total speedup of about 30%, and the addition of mixed precision training basically offsets the increased memory consumption brought by XLA. The two truly complement each other perfectly.
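As a reference starting point (a sketch rather than a guaranteed recipe; adjust the tf_xla_auto_jit setting to your memory budget as discussed above), enabling both at the top of the training script looks like this:

import os
# mixed precision graph rewrite (manual loss scaling is still recommended)
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE'] = '1'
# XLA; fall back to '--tf_xla_auto_jit=fusible' (XLA Lite) if you hit OOM
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit --tf_xla_enable_lazy_compilation=false'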

Summary

This article introduced an attempt to use mixed precision and XLA to accelerate training in bert4keras. Enabling both simultaneously can achieve a speedup of about 30% on an RTX 3090.