By 苏剑林 | July 08, 2019
Keras can now get the effect of a large batch size out of a small one: as long as you are willing to spend $n$ times the time, you can achieve the effect of an $n$-times-larger batch size without using any more VRAM.
GitHub address: https://github.com/bojone/accum_optimizer_for_keras
Digression
A year or two ago, you rarely had to worry about OOM (out-of-memory) errors in NLP. Compared with models in computer vision, most NLP models were quite shallow and seldom exhausted GPU memory. Fortunately or unfortunately, Bert came out and became famous. Bert and its successors (GPT-2, XLNet, etc.) are all built on sufficiently large Transformer models, pre-trained on sufficiently large corpora, and then adapted to specific NLP tasks through fine-tuning.
Even if you really don't want to use Bert, the reality today is that the complex model you carefully designed may not perform as well as a simple Bert fine-tune. So, like it or not, keeping up with the times means learning to fine-tune Bert. The problem is that "you don't know until you try, and trying gives you a shock": as soon as the task is slightly complex or the sentences are slightly long, VRAM runs out and the batch size plummets. 32? 16? 8? It may well go even lower.
This is not hard to understand. Transformers are built on Attention, and in theory Attention has $\mathcal{O}(n^2)$ space and time complexity, where $n$ is the sequence length. With enough computing power, Attention's parallelism keeps it fast, but its memory cost cannot be avoided. $\mathcal{O}(n^2)$ means that when the sentence length doubles, VRAM consumption roughly quadruples, and growth like that makes OOM all too easy~
Worse still, when everyone is fine-tuning the same pre-trained Bert, the score you get with batch_size=8 may be a few tenths of a percentage point, or even a few whole points, lower than what someone else gets with batch_size=80. That is clearly hard on readers who want to climb leaderboards. Is there really no way out other than adding more GPUs?
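To put rough numbers on the quadratic growth, here is a back-of-the-envelope sketch that counts only the attention score matrices (one $n \times n$ matrix per head, per layer, per example). It assumes a Bert-base-sized model (12 layers, 12 heads), a batch size of 8 and float32 values; the function name is made up for illustration, and real memory use is higher because activations, weights and optimizer state are ignored.

```python
def attention_score_bytes(seq_len, batch_size=8, num_layers=12, num_heads=12,
                          bytes_per_value=4):
    """Bytes for the attention score matrices alone (hypothetical helper):
    one seq_len x seq_len float32 matrix per head, per layer, per example."""
    return batch_size * num_layers * num_heads * seq_len * seq_len * bytes_per_value

for n in (128, 256, 512):
    print(f"seq_len={n}: {attention_score_bytes(n) / 2**20:.0f} MiB")

# Doubling the sequence length multiplies this term by 4.
print(attention_score_bytes(512) / attention_score_bytes(256))  # 4.0
```

Under these assumptions the score matrices alone take 72, 288 and 1152 MiB for sequence lengths 128, 256 and 512.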
Getting Down to Business
There is! By caching and accumulating gradients, you can trade time for space, and the final training result is equivalent to using a larger batch size. So as long as you can run batch_size=1 and are willing to spend $n$ times the time, you can effectively train with an $n$-times-larger batch size.
The idea of gradient accumulation was already introduced in my earlier article "'Making Keras Cooooler!': Niche Custom Optimizers", where I called it "soft batch"; here I follow the mainstream terminology and call it "gradient accumulation". Gradient accumulation is actually very simple: the gradient used for a gradient-descent step is the average of the gradients computed over multiple samples. Take batch_size=128 as an example. You can compute the gradients of all 128 samples at once and average them; alternatively, you can compute the average gradient of 16 samples at a time, cache and accumulate it, and after 8 such rounds divide the accumulated total by 8 and only then update the parameters. Crucially, the parameters must stay unchanged during those 8 rounds and be updated only once, with the 8-round average; if you updated after every 16 samples, you would simply be training with batch_size=16.
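The equivalence can be checked numerically. The sketch below is a minimal pure-Python illustration, assuming a toy one-parameter least-squares model on random data (all names are hypothetical): the mean of 8 micro-batch means of 16 samples equals the mean over all 128 samples, provided the parameter `w` is held fixed across the 8 rounds.

```python
import random

def per_sample_grad(w, x, y):
    # Gradient of the squared error 0.5 * (w*x - y)**2 with respect to w.
    return (w * x - y) * x

def mean(values):
    return sum(values) / len(values)

random.seed(0)
data = [(random.uniform(-1, 1), random.uniform(-1, 1)) for _ in range(128)]
w = 0.3  # held fixed: no update happens between micro-batches

# Option A: one batch of 128 samples, averaged at once.
full_batch_grad = mean([per_sample_grad(w, x, y) for x, y in data])

# Option B: 8 micro-batches of 16; average each, accumulate, divide by 8.
micro_batch_size, num_micro_batches = 16, 8
accumulated = 0.0
for k in range(num_micro_batches):
    chunk = data[k * micro_batch_size:(k + 1) * micro_batch_size]
    accumulated += mean([per_sample_grad(w, x, y) for x, y in chunk])
accumulated_grad = accumulated / num_micro_batches

print(abs(full_batch_grad - accumulated_grad) < 1e-12)  # True
```

This relies on every micro-batch having the same size; with unequal micro-batches the per-batch means would need to be weighted by their sizes.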
However, the implementation in that earlier article was wrong. It used K.switch to decide whether to apply the update, but that does not actually control the update: K.switch only guarantees which result is returned, not which branch is executed. In other words, whatever cond is, both branches are executed and K.switch merely picks one of their outputs.
In fact, Keras and TensorFlow have "almost" no way of writing a conditional that executes only one branch (I say "almost" because it can be done under some very restrictive conditions), so this route is a dead end.
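The pitfall can be illustrated with a pure-Python analogy. This is not Keras code: `update` and `switch` below are hypothetical stand-ins, and Python's eager argument evaluation plays the role of graph ops that get run regardless of the condition. The point is only that selecting a result does not prevent the side effect of computing it.

```python
params = {"w": 1.0}
executed = []

def update(name, new_value):
    """Hypothetical stand-in for an assignment op such as K.update:
    it has a side effect on the stored parameter."""
    executed.append(name)
    params[name] = new_value
    return new_value

def switch(cond, then_value, else_value):
    """Hypothetical stand-in for a select: both arguments have already
    been computed by the time this function runs."""
    return then_value if cond else else_value

# cond is False, so we "wanted" no update...
result = switch(False, update("w", 0.5), params["w"])

print(executed)     # ['w']  (the update ran anyway)
print(params["w"])  # 0.5    (and the parameter changed)
```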
Since it can't be written that way, the only option is to work on the size of the update itself. Returning to the example above: we compute the gradient of 16 samples at every step and perform a parameter update at every step, but in 7 out of every 8 steps the update amount is 0, and only in 1 step is a real gradient-descent update applied. Conveniently, this approach slots seamlessly into existing Keras optimizers, so there is no need to rewrite the optimizer itself. For the full code, see:
https://github.com/bojone/accum_optimizer_for_keras/blob/master/accum_optimizer.py
The implementation is nothing more than a few programming tricks for "grafting" this onto an existing optimizer, with little technical depth, so I won't go through the code in detail; if you have questions, feel free to raise them in the comments.
(Note: this optimizer modification makes a small batch size equivalent to a large one only if the model contains no Batch Normalization. During training, Batch Normalization normalizes with the mean and variance of the batch actually being processed, and accumulating gradients does not change the fact that each step only sees the small batch. So if your network uses Batch Normalization and you want to reproduce the effect of a large batch size exactly, the only option for now is more VRAM or more GPUs.)
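The "zero update on most steps" idea can be sketched without Keras. The following is a minimal pure-Python illustration, not the repository's code: it uses plain SGD on a hypothetical one-parameter least-squares problem and checks that masked updates over micro-batches of 16 with 8 accumulation steps land on the same parameter as SGD with batch size 128. For a stateful optimizer such as Adam, the same idea would also need the optimizer's internal state to stay unchanged on the skipped steps; plain SGD sidesteps that here.

```python
import random

def mean_grad(w, batch):
    # Mean gradient of 0.5 * (w*x - y)**2 with respect to w over a batch.
    return sum((w * x - y) * x for x, y in batch) / len(batch)

def sgd_large_batch(w, data, lr, batch_size):
    # Reference: one real update per large batch.
    for i in range(0, len(data), batch_size):
        w -= lr * mean_grad(w, data[i:i + batch_size])
    return w

def sgd_accumulated(w, data, lr, micro_batch_size, accum_steps):
    # An "update" runs on every micro-batch, but its size is zero except on
    # every accum_steps-th step, when the averaged accumulated gradient is applied.
    acc = 0.0
    for step, i in enumerate(range(0, len(data), micro_batch_size), start=1):
        acc += mean_grad(w, data[i:i + micro_batch_size])   # cache and accumulate
        is_real_step = 1.0 if step % accum_steps == 0 else 0.0
        w -= is_real_step * lr * acc / accum_steps          # zero update otherwise
        acc *= 1.0 - is_real_step                           # reset after a real update
    return w

random.seed(0)
xs = [random.uniform(-1, 1) for _ in range(512)]
data = [(x, 2.0 * x + random.gauss(0.0, 0.1)) for x in xs]

w_large = sgd_large_batch(0.0, data, lr=0.5, batch_size=128)
w_accum = sgd_accumulated(0.0, data, lr=0.5, micro_batch_size=16, accum_steps=8)
print(w_large, w_accum)  # equal up to floating-point rounding
```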
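Why Batch Normalization breaks the equivalence can be seen from the statistics alone. This pure-Python sketch assumes random Gaussian "activations" for a single feature and only models the batch mean and variance BN would compute in the forward pass (running averages and gradients through the statistics are ignored). Each small batch is centred on its own mean, and the small-batch variances are on average smaller than the full-batch variance, by exactly the variance of the small-batch means.

```python
import random
import statistics

random.seed(1)
activations = [random.gauss(0.0, 1.0) for _ in range(100)]

# Statistics BN would use if all 100 samples formed one batch.
full_mean = statistics.mean(activations)
full_var = statistics.pvariance(activations)

# Statistics BN actually sees when the same samples arrive as 10 batches of 10.
micro_batches = [activations[i:i + 10] for i in range(0, 100, 10)]
micro_means = [statistics.mean(b) for b in micro_batches]
micro_vars = [statistics.pvariance(b) for b in micro_batches]

# For equal-sized groups: full_var = mean(micro_vars) + pvariance(micro_means).
print(full_mean, micro_means[:3])
print(full_var, statistics.mean(micro_vars))
```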
Experiment
The usage is very simple:
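```python
from accum_optimizer import AccumOptimizer  # accum_optimizer.py from the repo above
from keras.optimizers import Adam

# Assumes `model`, `x_train` and `y_train` are already defined.
opt = AccumOptimizer(Adam(lr=1e-3), 10)  # accumulate gradients over 10 steps
model.compile(loss='mse', optimizer=opt)
model.fit(x_train, y_train, batch_size=10)  # batch actually fed per step is 10
```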
This is equivalent to Adam with batch_size=100. The cost is that each epoch runs more slowly, since every step processes a smaller batch; the benefit is that you only need the VRAM required by batch_size=10.
A question readers may well ask is: how do you prove the code works? That is, how do you show the result really corresponds to batch_size=100 rather than batch_size=10? To check this, I ran a somewhat extreme experiment; the code is here:
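The trade-off is easy to quantify. The hypothetical helper below counts forward/backward steps and real parameter updates per epoch, assuming the sample count divides evenly (partial final batches are ignored); 60,000 is used as an example because it is the size of the MNIST training set.

```python
def accumulation_schedule(num_samples, batch_size, accum_steps):
    """Hypothetical helper: steps vs. real updates in one epoch."""
    steps = num_samples // batch_size   # each step holds batch_size samples in memory
    updates = steps // accum_steps      # only every accum_steps-th step changes the weights
    return {"effective_batch_size": batch_size * accum_steps,
            "steps_per_epoch": steps,
            "updates_per_epoch": updates}

print(accumulation_schedule(60000, batch_size=10, accum_steps=10))
# {'effective_batch_size': 100, 'steps_per_epoch': 6000, 'updates_per_epoch': 600}
```

The number of real updates per epoch is the same as for batch_size=100, while memory use per step matches batch_size=10.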
https://github.com/bojone/accum_optimizer_for_keras/blob/master/mnist_mlp_example.py
The code is simple: a multi-layer MLP classifies MNIST with the Adam optimizer, and fit is called with batch_size=1. There are two optimizer choices: plain Adam(), and AccumOptimizer(Adam(), 100):
- With plain Adam(), the loss hovers around 0.4 and later even increases (on the training set too), and validation accuracy never exceeds 97%;
- With AccumOptimizer(Adam(), 100), the training loss keeps decreasing, eventually reaching about 0.02, and the best validation accuracy is over 98%;
- Finally, I compared plain Adam() with batch_size=100 and found that it performed similarly to AccumOptimizer(Adam(), 100) with batch_size=1.
These results are enough to show that the code works as intended. If that is not convincing enough, here is another training result for reference: in a Bert fine-tuning experiment, plain Adam() with batch_size=12 reached 70.33% accuracy, while AccumOptimizer(Adam(), 10) with batch_size=12 (an expected equivalent batch size of 120) reached 71.00%, an improvement of about 0.7 percentage points. If you are climbing a leaderboard, those 0.7 points could be decisive.
Conclusion
I have finally implemented gradient accumulation (soft batch) properly. From now on, large batch sizes are an option even when using Bert~