When BERT Meets Keras: This May Be the Simplest Way to Get Started with BERT

By 苏剑林 | June 18, 2019

BERT probably needs no introduction by now. Although I am not particularly fond of it, I have to admit that it has caused a huge stir in the NLP world. Right now, explanations and interpretations of BERT are everywhere, in both Chinese and English, and its popularity seems to have surpassed even the initial wave Word2Vec made when it first came out. Interestingly, BERT was developed by Google, and the original Word2Vec was also developed by Google. Whichever one you use, you are essentially following in the footsteps of the big boss Google~

Shortly after BERT was released, some readers suggested I write an interpretation, but I ultimately didn't. Firstly, there are already many interpretations of BERT out there. Secondly, BERT is essentially a large-scale pre-trained model based on Attention. It isn't particularly innovative in terms of technology, and since I've already written an interpretation of Google's Attention, I couldn't quite find the motivation.

[Figure: BERT's pre-training and fine-tuning (image from the original BERT paper)]

Overall, I personally had little interest in BERT until the end of last month, when I tried it for the first time in an information extraction competition. Later I figured that even if I wasn't interested, I would have to learn it sooner or later; after all, choosing whether to use it is one thing, but knowing how to use it is another. Besides, there didn't seem to be many articles on how to use (fine-tune) BERT within Keras, so I decided to share my own experience.

When BERT meets Keras

Fortunately, an expert has already wrapped up a Keras version of BERT, which lets us load the official pre-trained weights directly. For readers who already have some Keras background, this is probably the simplest way to call BERT. "Standing on the shoulders of giants" perfectly captures how Keras fans like us feel right now.

keras-bert

In my opinion, the best Keras wrapper for BERT at the moment is:

keras-bert: https://github.com/CyberZHG/keras-bert

This article is built upon this foundation.

Incidentally, in addition to keras-bert, CyberZHG has also encapsulated many other valuable Keras modules, such as keras-gpt-2 (so you can use GPT-2 just like BERT), keras-lr-multiplier (for layer-wise learning rates), keras-ordered-neurons (the ON-LSTM introduced recently), and more. A summary can be found here. He seems to be a die-hard Keras fan~ Salutations to the expert.

In fact, with keras-bert, a little bit of basic Keras knowledge, and the sufficiently complete demos provided by keras-bert, calling and fine-tuning BERT has become a task with almost no technical hurdle. Therefore, I will simply provide a few Chinese examples to help readers get started with the basic usage of keras-bert.

Tokenizer

Before moving on to the examples, it is necessary to say a few words about the Tokenizer. We import BERT's Tokenizer and override part of it:

from keras_bert import load_trained_model_from_checkpoint, Tokenizer
import codecs

config_path = '../bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '../bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '../bert/chinese_L-12_H-768_A-12/vocab.txt'

token_dict = {}
with codecs.open(dict_path, 'r', 'utf8') as reader:
    for line in reader:
        token = line.strip()
        token_dict[token] = len(token_dict)  # map each token to its row index in vocab.txt

class OurTokenizer(Tokenizer):
    def _tokenize(self, text):
        R = []
        for c in text:
            if c in self._token_dict:
                R.append(c)
            elif self._is_space(c):
                R.append('[unused1]') # use untrained [unused1] for space characters
            else:
                R.append('[UNK]') # use [UNK] for remaining characters
        return R

tokenizer = OurTokenizer(token_dict)
tokenizer.tokenize(u'今天天气不错')
# Output: ['[CLS]', u'今', u'天', u'天', u'气', u'不', u'错', '[SEP]']

A brief explanation of the Tokenizer output: first, by default, [CLS] and [SEP] markers are added to the beginning and end of the tokenized sentence, respectively. The output vector at the [CLS] position is meant to represent the whole sentence (at least, that is how BERT was designed), while [SEP] is the separator between sentences. The rest is character-by-character output (for Chinese).

Tokenizer originally has its own _tokenize method, which I override here to make sure the tokenized result has the same length as the original string (or the length plus 2, counting the two markers). Tokenizer's built-in _tokenize automatically drops spaces, and some characters get merged into a single token, so the length of the tokenized list no longer equals the length of the original string, which causes a lot of trouble for sequence labeling tasks. To avoid this, it is better to rewrite it yourself: mainly, [unused1] stands in for space-like characters, and everything else not in the vocabulary becomes [UNK]. The [unused*] tokens are untrained (randomly initialized) placeholders that BERT reserves for adding new vocabulary later, so we can use them to represent any character we need.
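Incidentally, besides tokenize, the Tokenizer also has an encode method that converts text directly into the two index sequences (token ids and segment ids) that the Keras BERT model takes as input. A small sketch based on keras-bert's current API (the example sentences below are placeholders of my own):

# encode() returns the token-id sequence and the segment-id sequence that
# later become the two inputs (x1, x2) of the Keras BERT model
x1, x2 = tokenizer.encode(u'今天天气不错')
# x1: ids of [CLS] 今 天 天 气 不 错 [SEP]; x2: all zeros for a single sentence

# for sentence pairs, pass both texts; the second sentence gets segment id 1
x1, x2 = tokenizer.encode(u'今天天气不错', u'出去玩吧')  # placeholder sentences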

Three Examples

Here are three examples of using keras-bert: text classification, relation extraction, and event subject extraction. All of them are fine-tuned on top of the officially released pre-trained weights.

Official BERT Github: https://github.com/google-research/bert
Official Chinese Pre-trained Weights: chinese_L-12_H-768_A-12.zip
Github for Examples: https://github.com/bojone/bert_in_keras/

According to the official introduction, this weight set was trained using the Chinese Wikipedia corpus.

(Update June 20, 2019: The HIT-iFLYTEK Joint Laboratory has released a new set of weights, which can also be loaded using keras-bert. For details, please see here.)

Text Classification

As our first example, we will perform a basic text classification task. Once you are familiar with this basic task, the remaining ones will become quite simple. This time, we'll use the Sentiment Classification task we've discussed many times before, using the annotated data I organized previously.

Let's look at the full model section (complete code can be found here):

# Standard Keras imports that this snippet relies on:
from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras.optimizers import Adam

# Note: although seq_len=None can be set, ensure the sequence length does not exceed 512
bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)

for l in bert_model.layers:
    l.trainable = True

x1_in = Input(shape=(None,))
x2_in = Input(shape=(None,))

x = bert_model([x1_in, x2_in])
x = Lambda(lambda x: x[:, 0])(x) # Extract the vector corresponding to [CLS] for classification
p = Dense(1, activation='sigmoid')(x)

model = Model([x1_in, x2_in], p)
model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(1e-5), # Use a sufficiently small learning rate
    metrics=['accuracy']
)
model.summary()

That’s it! Calling BERT in Keras for a sentiment classification task is finished just like that.

Feeling like it ended before it even started? Invoking BERT in Keras is just that short. In truth, the only line truly calling BERT is load_trained_model_from_checkpoint; the rest are standard Keras operations (thanks again to CyberZHG). So if you have already started with Keras, calling BERT will be effortless.
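The remaining work is just feeding data into this model. Below is a rough sketch of the kind of batch generator one can use; it is simplified for illustration, assuming the training data is a list of (text, label) pairs, and seq_padding / data_generator are illustrative names rather than anything keras-bert provides:

import numpy as np

maxlen = 100     # truncation length; smaller maxlen means less VRAM
batch_size = 32  # likewise, reduce this (e.g. to 24 on a 6GB card) if memory is tight

def seq_padding(X, padding=0):
    # pad every sequence in the batch to the length of the longest one
    L = max(len(x) for x in X)
    return np.array([x + [padding] * (L - len(x)) for x in X])

def data_generator(data):
    # data is assumed to be a list of (text, label) pairs
    while True:
        X1, X2, Y = [], [], []
        for text, label in data:
            x1, x2 = tokenizer.encode(text[:maxlen])
            X1.append(x1); X2.append(x2); Y.append([label])
            if len(X1) == batch_size:
                yield [seq_padding(X1), seq_padding(X2)], seq_padding(Y)
                X1, X2, Y = [], [], []

Training is then a plain model.fit_generator(data_generator(train_data), steps_per_epoch=..., epochs=5) call.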

Can such a simple call really achieve good accuracy? After 5 epochs of fine-tuning, the best accuracy on the validation set exceeded 95.5%! Previously, in "Text Sentiment Classification (3): To Segment or Not to Segment", we tweaked and tweaked and only got around 90% accuracy; with BERT, a few lines of code raise the accuracy by more than 5 percentage points! No wonder BERT has created such a wave in the NLP community...

Here, based on my personal experience, I'll answer two questions readers might care about.

The first question is probably: "How much VRAM is enough?" There is no standard answer: VRAM usage depends on three factors, namely sequence length, batch size, and model complexity. The sentiment analysis example above runs on my GTX 1060 with 6GB of VRAM; you just need to set the batch size to 24. So if your VRAM is not large enough, try reducing maxlen and the batch size. Of course, if the task is too complex, even the smallest maxlen and batch size may still hit an OOM (out of memory) error, and then your only option is to upgrade the GPU.

The second question is: "What principles should guide what to add after BERT?" The answer is: complete your task with as few layers as possible. For instance, the sentiment analysis above is just a binary classification task: you simply take the first output vector and add a Dense(1). Don't even think about stacking several Dense layers, and certainly don't think about inserting an LSTM before the Dense. If you're doing sequence labeling (like NER), just add Dense+CRF. In short, keep the extra parts to a minimum. Firstly, BERT is complicated enough and has the capacity to handle many tasks; secondly, whatever you add is randomly initialized, and adding too much will severely perturb BERT's pre-trained weights, which easily hurts performance or even prevents the model from converging.
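To make "Dense+CRF" concrete, a sequence-labeling head would look roughly like the sketch below. This is only an illustration: the CRF layer is assumed to come from the keras-contrib package, and num_labels is a placeholder for the size of your tag set.

from keras_contrib.layers import CRF  # assumption: CRF layer from keras-contrib

num_labels = 7  # placeholder: e.g. the number of BIO tags in your NER scheme

x = bert_model([x1_in, x2_in])         # [batch, seq_len, 768] encoded sequence
crf = CRF(num_labels, sparse_target=True)
p = crf(x)                             # per-token tag scores with CRF transition constraints

ner_model = Model([x1_in, x2_in], p)
ner_model.compile(
    loss=crf.loss_function,
    optimizer=Adam(1e-5),              # again, a sufficiently small learning rate
    metrics=[crf.accuracy]
)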

Relation Extraction

If readers already have a foundation in Keras, then after the first example, you should have fully mastered fine-tuning BERT, as it is so simple there's hardly anything left to say. Therefore, the next two examples mainly provide reference patterns to help you feel how to "use as few layers as possible to complete your task."

In the second example, we introduce a minimalist relation extraction model based on BERT. Its labeling principle is the same as described in "A Lightweight Information Extraction Model Based on DGCNN and Probabilistic Graphs", but thanks to BERT's powerful encoding capability, our custom part can be greatly simplified. In one of the reference implementations I provided, the model part is as follows (complete model found here):

# t1, t2 (token ids and segment ids), k1, k2 (the subject's start/end positions) and the
# other *_in tensors are Input placeholders defined in the complete code
t = bert_model([t1, t2])
ps1 = Dense(1, activation='sigmoid')(t)   # per-position probability of a subject start
ps2 = Dense(1, activation='sigmoid')(t)   # per-position probability of a subject end

subject_model = Model([t1_in, t2_in], [ps1, ps2]) # Model for predicting subject

k1v = Lambda(seq_gather)([t, k1])         # encoder vector at the subject's start position
k2v = Lambda(seq_gather)([t, k2])         # encoder vector at the subject's end position
kv = Average()([k1v, k2v])                # average the two subject vectors
t = Add()([t, kv])                        # add the subject representation to the sequence
po1 = Dense(num_classes, activation='sigmoid')(t) # object start, one column per relation type
po2 = Dense(num_classes, activation='sigmoid')(t) # object end, one column per relation type

object_model = Model([t1_in, t2_in, k1_in, k2_in], [po1, po2]) # Input text and subject, predict object and its relation

train_model = Model([t1_in, t2_in, s1_in, s2_in, k1_in, k2_in, o1_in, o2_in],
                    [ps1, ps2, po1, po2])

If readers have read "A Lightweight Information Extraction Model Based on DGCNN and Probabilistic Graphs" and understand the architecture without BERT, they will see how concise and clear the above implementation is.

As you can see, we introduce BERT as the encoder to obtain the encoded sequence $t$, then simply attach two Dense(1) layers to complete the subject labeling model. Next, we take the encoder vectors at the start and end positions of the given subject $s$, average them, and add the result to the encoded sequence $t$; two Dense(num_classes) layers then complete the object labeling model (which labels the relation at the same time).
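For completeness, the seq_gather used above is just a small helper that, for each sample, picks out the encoder vector at a given position. A rough sketch, close in spirit to the helper in the complete code and assuming a TensorFlow backend:

import tensorflow as tf
from keras import backend as K

def seq_gather(x):
    # sketch of a seq_gather-style helper (TensorFlow backend assumed);
    # seq: [batch, seq_len, dim], idxs: [batch, 1];
    # for the i-th sample, return the idxs[i]-th vector of seq, giving [batch, dim]
    seq, idxs = x
    idxs = K.cast(idxs, 'int32')
    batch_idxs = K.arange(0, K.shape(seq)[0])
    batch_idxs = K.expand_dims(batch_idxs, 1)
    idxs = K.concatenate([batch_idxs, idxs], 1)  # (sample index, position) pairs
    return tf.gather_nd(seq, idxs)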

How high an F1 score can such a simple design reach? The answer: nearly 82% on the offline dev set, and over 85% when I submitted it once online (both single-model scores)! By comparison, the model in "A Lightweight Information Extraction Model Based on DGCNN and Probabilistic Graphs" needed CNNs, global features, an LSTM to encode the passed-in $s$, relative position vectors, and all sorts of other ad-hoc modules fused together, and its single model was only slightly better (about 82.5%). Bear in mind that I wrote this BERT-based model in just an hour, while the DGCNN model with its various tricks and fusions took me nearly two months of tuning! BERT's power is evident from this.

(Note: fine-tuning this model is best done with over 8GB of VRAM. Also, because I only encountered BERT and wrote this model a few days before the competition ended, I didn't spend much time tuning it, so the final submission didn't include BERT.)

A notable difference between this relation extraction example and the previous sentiment analysis example is the change in learning rate.

In the sentiment analysis example, we only used a constant learning rate ($10^{-5}$) for a few epochs, and the results were decent. In this relation extraction example, during the first epoch, the learning rate gradually increases from $0$ to $5 \times 10^{-5}$ (this is called warmup), and during the second epoch, it decreases from $5 \times 10^{-5}$ back to $10^{-5}$. Overall, it increases then decreases. BERT itself was trained with a similar learning rate curve; this training method is more stable, less likely to collapse, and generally yields better results.
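If you want to reproduce this kind of schedule, a small Keras callback is enough. The sketch below is my own illustration rather than the exact implementation from the complete code; the class name and the choice to hold the rate at min_lr afterwards are arbitrary:

from keras import backend as K
from keras.callbacks import Callback

class WarmupThenDecay(Callback):
    # illustrative schedule: linear warmup from 0 to max_lr over the first epoch,
    # then linear decay towards min_lr over the second epoch (held afterwards)
    def __init__(self, steps_per_epoch, max_lr=5e-5, min_lr=1e-5):
        super(WarmupThenDecay, self).__init__()
        self.steps_per_epoch = steps_per_epoch
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.step = 0

    def on_batch_begin(self, batch, logs=None):
        self.step += 1
        if self.step <= self.steps_per_epoch:   # first epoch: 0 -> max_lr
            lr = self.max_lr * self.step / float(self.steps_per_epoch)
        else:                                   # second epoch: max_lr -> min_lr
            progress = min(1., (self.step - self.steps_per_epoch) / float(self.steps_per_epoch))
            lr = self.max_lr - (self.max_lr - self.min_lr) * progress
        K.set_value(self.model.optimizer.lr, lr)

# usage: train_model.fit_generator(..., callbacks=[WarmupThenDecay(steps_per_epoch)])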

Event Subject Extraction

The final example comes from the CCKS 2019 Financial Event Subject Extraction competition. This competition is still ongoing, but I no longer have much motivation or interest in continuing, so I'm releasing my current model (accuracy 89%+) for reference. I wish the remaining participants the best of luck.

To briefly introduce the competition data, it looks something like this:

Input: "Company A's product contained additives; its subsidiaries Company B and Company C were investigated", "Product problem occurs"
Output: "Company A"

In other words, it is a dual-input, single-output model: the input is a query and an event type, and the output is an entity (exactly one, and always a fragment of the query). In fact, this task can be seen as a simplified version of SQuAD 1.0. Given the form of the output, a pointer structure (two softmaxes over positions, predicting the start and the end) is the best fit. The remaining question is how to handle the dual input.

The previous two examples differ in complexity, but both are single-input models. What do we do with a dual input? Of course, since the event types are finite, embedding them directly is one option. However, I use a solution that better shows off BERT's simple, brute-force, and robust nature: just join the two inputs into one sentence with a connector, turning it into a single-input problem! For example, the sample above is processed into:

Input: "___Product problem occurs___Company A's product contained additives; its subsidiaries Company B and Company C were investigated"
Output: "Company A"

Then it becomes an ordinary single-input extraction problem. Speaking of which, there is not much else to say about this model code; it's just a few simple lines (complete code here):

x = bert_model([x1, x2])
ps1 = Dense(1, use_bias=False)(x)  # start-position score for every token
# subtract a large constant at padded positions (x_mask == 0) so the softmax ignores them
ps1 = Lambda(lambda x: x[0][..., 0] - (1 - x[1][..., 0]) * 1e10)([ps1, x_mask])
ps2 = Dense(1, use_bias=False)(x)  # end-position score for every token
ps2 = Lambda(lambda x: x[0][..., 0] - (1 - x[1][..., 0]) * 1e10)([ps2, x_mask])

model = Model([x1_in, x2_in], [ps1, ps2])
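As for how these masked scores turn into the "two softmaxes over positions" described earlier, one possible way to wire up the loss is sketched below. This is only an illustration, not necessarily how the complete code organizes it; it assumes the targets are one-hot vectors over positions:

from keras import backend as K

def pointer_loss(y_true, y_pred):
    # sketch only: y_true is assumed to be a one-hot vector over positions;
    # y_pred holds the masked per-position logits, and the softmax over positions
    # is applied inside the cross-entropy via from_logits=True
    return K.mean(K.categorical_crossentropy(y_true, y_pred, from_logits=True))

model.compile(loss=pointer_loss, optimizer=Adam(1e-5))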

Additionally, by adding some decoding tricks and model fusion, submitting this could achieve 89%+. Looking at the current leaderboard, the best result is just over 90%, so I suspect everyone is doing something similar...

This example mainly teaches us that when implementing your own tasks with BERT, it's best if you can organize them into a single-input mode. This is simpler and more efficient.

For instance, consider a sentence similarity model that takes two sentences as input and outputs a similarity score. There are two obvious approaches. The first is to pass each sentence through the same BERT separately and classify using their respective [CLS] features. The second is the one described above: concatenate the two sentences into one with a marker, pass it through BERT once, and classify from the output features. The latter trains faster and lets the features of the two sentences interact more thoroughly.
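As a rough sketch of the second approach: the keras-bert Tokenizer already handles sentence pairs, producing [CLS] sentence_a [SEP] sentence_b [SEP] with segment ids 0/1, which can be fed to the same "[CLS] vector + Dense(1)" classifier from the first example (the sentences below are placeholders):

# encode the two sentences as one pair; the second sentence gets segment id 1
x1, x2 = tokenizer.encode(u'我喜欢这部电影', u'这部电影很好看')  # placeholder sentences
# [x1, x2], padded into batches, goes into the same classification model as in the
# sentiment example, now predicting whether the two sentences are similar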

Article Summary

This article introduced the basic way to call BERT under Keras, mainly through three reference examples to help everyone gradually become familiar with the steps and principles of fine-tuning BERT. Much of this is just my own experience from tinkering on my own; if anything is off, I hope readers will point it out.

In fact, with the keras-bert implementation by CyberZHG, using BERT in Keras is a piece of cake; once you've tinkered with it for half a day, you'll have the hang of it. Finally, I wish everyone a happy experience using it~