By 苏剑林 | April 29, 2020
Many readers may recently have noticed the official-account article "BERT Recomputation: Saving 5x the Memory Overhead with 22.5% Extra Training Time (Code Included)". It introduced a technique called "recomputation" (better known as gradient checkpointing), which is essentially a memory-saving method: it lets you increase the batch_size several-fold at the cost of a somewhat slower average training speed. The technique was first proposed back in 2016 in the paper "Training Deep Nets with Sublinear Memory Cost", though it never seems to have become particularly popular.
Exploration
The aforementioned article mentioned that this technique has native implementations in PyTorch and PaddlePaddle, but not yet in TensorFlow. In reality, however, TensorFlow has shipped this functionality since version 1.8: at first it lived in the tf.contrib sub-library, and starting from TensorFlow 1.15 it became a built-in core function, tf.recompute_grad.
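For reference, the built-in API's calling convention is simply to wrap the forward function of the block whose activations you would rather recompute than store. The sketch below (with made-up weights and layer sizes) only illustrates that convention; it is not a claim about how well the built-in version works, which is discussed next:

import tensorflow as tf

# Hypothetical two-layer block; the sizes are arbitrary.
w1 = tf.Variable(tf.random.normal([1024, 1024]))
w2 = tf.Variable(tf.random.normal([1024, 1024]))

@tf.recompute_grad
def block(x):
    # The intermediate activation h is not kept for backprop;
    # it is recomputed from x when gradients are needed.
    h = tf.nn.relu(tf.matmul(x, w1))
    return tf.nn.relu(tf.matmul(h, w2))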
After finding tf.recompute_grad, I looked into its usage and, after some tinkering, actually got it working, raising the batch_size from 48 to 144! However, while cleaning things up and testing further, I discovered that it is in fact ineffective under TensorFlow 2.x... So I spent another two days searching through various materials and debugging repeatedly, and finally managed to fill this gap.
Here is my open-source implementation:
GitHub Address: https://github.com/bojone/keras_recompute
This implementation is already built into bert4keras. Readers using bert4keras can upgrade to the latest version (0.7.5+) to test this feature.
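For bert4keras users, nothing beyond setting the environment variable should be needed. A minimal sketch, assuming the build_transformer_model entry point from bert4keras.models and placeholder paths to a local BERT checkpoint:

# run as:  RECOMPUTE=1 python this_script.py
from bert4keras.models import build_transformer_model

model = build_transformer_model(
    config_path='bert_config.json',     # placeholder path
    checkpoint_path='bert_model.ckpt',  # placeholder path
)
# compile and fit as usual; recomputation is switched on by RECOMPUTE=1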
Usage
My implementation is likewise named recompute_grad. It is a decorator applied to the call method of a Keras layer, for example:
from keras.layers import Layer
from recompute import recompute_grad

class MyLayer(Layer):
    # activations produced inside call are recomputed during
    # backprop instead of being stored
    @recompute_grad
    def call(self, inputs):
        return inputs * 2
For existing layers, you can decorate them through inheritance:
from keras.layers import Dense
from recompute import recompute_grad

class MyDense(Dense):
    @recompute_grad
    def call(self, inputs):
        # defer to Dense's own call; only the gradient behaviour changes
        return super(MyDense, self).call(inputs)
After customizing the layers, embed them into your code and set the environment variable RECOMPUTE=1 before running the code to enable recomputation.
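For concreteness, here is a minimal sketch of wiring such decorated layers into an ordinary Keras model; the model, layer sizes, and training setup are made up for illustration and continue from the MyDense defined above:

# save as train.py and run with the flag set:  RECOMPUTE=1 python train.py
from keras.models import Sequential

model = Sequential([
    MyDense(768, activation='relu', input_shape=(768,)),  # MyDense from the snippet above
    MyDense(768, activation='relu'),
    MyDense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy')
# model.fit(...) as usual; with RECOMPUTE=1 every decorated call is recomputed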
Note: decorating the model's call as a whole with @recompute_grad will not save memory; you need to apply @recompute_grad to the call of each individual layer. Put simply, the more layers you decorate with @recompute_grad, the more memory you save. For the reason why, please think carefully about how recomputation works.
Effect
bert4keras 0.7.5+ has built-in recomputation, which is enabled by passing the environment variable RECOMPUTE=1. Readers can try it themselves; the approximate effects are:
- With the BERT Base version, the batch_size can be increased to about 3 times the original;
- With the BERT Large version, the batch_size can be increased to about 4 times the original;
- The average training time per sample increases by approximately 25%;
- Theoretically, the more layers the model has, the larger the factor by which the batch_size can be increased.
Environment
Tested and passed in the following environments:
- tensorflow 1.14 + keras 2.3.1
- tensorflow 1.15 + keras 2.3.1
- tensorflow 2.0 + keras 2.3.1
- tensorflow 2.1 + keras 2.3.1
- tensorflow 2.0 + built-in tf.keras
- tensorflow 2.1 + built-in tf.keras
Confirmed unsupported environments:
- tensorflow 1.x + built-in tf.keras
More test results are welcome.
By the way, it is strongly recommended to use Keras 2.3.1 in conjunction with TensorFlow 1.x/2.x; it is strongly discouraged to use the tf.keras bundled with TensorFlow 2.x.
References
Finally, my implementation mainly drew on the following two pieces of source code, to whose authors I would like to express my gratitude.
Reprinting notice: Please include the source address of this article: https://kexue.fm/archives/7367