By 苏剑林 | January 03, 2020
While developing bert4keras, I promised to gradually migrate the examples previously implemented with keras-bert, one of which was triple extraction. The examples in bert4keras are becoming quite rich, but until now they lacked anything related to sequence labeling and information extraction. Since triple extraction fits this category perfectly, I have now added it.
Schematic Diagram of BERT-based Triple Extraction Model Structure
Model Introduction
Regarding the data format and the basic logic of the model, these were detailed in the article "A Lightweight Information Extraction Model Based on DGCNN and Probabilistic Graphs", so I will not repeat them here. The dataset has been made public by Baidu and can be downloaded here.
Following the same strategy as before, the model is still based on the "semi-pointer, semi-annotation" method: first extract the subject ($s$), then pass $s$ into the model to extract the object ($o$) and predicate ($p$). The only difference is that the overall architecture has been replaced with BERT:
- The original sequence is converted into token IDs and passed through the BERT encoder to obtain the encoded sequence;
- The encoded sequence is connected to two binary classifiers to predict $s$;
- Based on the passed $s$, the encoding vectors corresponding to the start and end of $s$ are extracted from the encoded sequence;
- Using the encoding vector of $s$ as the condition, a Conditional Layer Norm is applied to the encoded sequence (a rough sketch of this layer is given after the list);
- The sequence after Conditional Layer Norm is used to predict the $o$ and $p$ corresponding to that $s$.
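To make the conditioning step concrete, here is a minimal tf.keras sketch of a conditional Layer Normalization layer. This is an illustration of the idea rather than the bert4keras implementation: the class name and the zero-initialization choice are mine, and details such as how the start/end vectors of $s$ are combined into the condition vector may differ from the released code.

```python
import tensorflow as tf
from tensorflow.keras import layers


class ConditionalLayerNorm(layers.Layer):
    """Layer Normalization whose gain and bias are shifted by a condition vector.

    Here the subject's encoding vector serves as the condition, so the same
    encoded sequence is normalized differently for different subjects.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        hidden_size = int(input_shape[0][-1])
        cond_size = int(input_shape[1][-1])
        # Base gain/bias, as in ordinary Layer Norm.
        self.gamma = self.add_weight('gamma', shape=(hidden_size,), initializer='ones')
        self.beta = self.add_weight('beta', shape=(hidden_size,), initializer='zeros')
        # Condition projections, zero-initialized so training starts from plain Layer Norm.
        self.gamma_kernel = self.add_weight('gamma_kernel', shape=(cond_size, hidden_size),
                                            initializer='zeros')
        self.beta_kernel = self.add_weight('beta_kernel', shape=(cond_size, hidden_size),
                                           initializer='zeros')

    def call(self, inputs):
        x, cond = inputs  # x: (batch, seq_len, hidden), cond: (batch, cond_size)
        gamma = self.gamma + tf.matmul(cond, self.gamma_kernel)  # (batch, hidden)
        beta = self.beta + tf.matmul(cond, self.beta_kernel)
        gamma = tf.expand_dims(gamma, 1)  # broadcast over the sequence dimension
        beta = tf.expand_dims(beta, 1)
        mean = tf.reduce_mean(x, axis=-1, keepdims=True)
        var = tf.reduce_mean(tf.square(x - mean), axis=-1, keepdims=True)
        return gamma * (x - mean) / tf.sqrt(var + self.epsilon) + beta
```

Zero-initializing the condition projections means the layer behaves exactly like a standard Layer Norm at the start of training, and the subject-dependent modulation is learned gradually.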
Class Imbalance
It is easy to imagine that when using a "semi-pointer, semi-annotation" structure for entity extraction, one faces an issue of class imbalance. This is because target entity words are usually much rarer than non-target words, so label 1 will be much scarcer than label 0. Conventional methods for handling imbalance, such as Focal Loss or manual label weighting, can be used, but after applying these methods, thresholds become difficult to set. Here, I used a method I find quite appropriate: raising the probability values to the $n$-th power.
Specifically, if the original output is a probability value $p$ for class 1, I now replace it with $p^n$ (with $n > 1$); that is, the probability of class 1 is taken to be $p^n$. Everything else remains unchanged, and the loss is still the standard binary cross-entropy. Since $0 \leq p \leq 1$, $p^n$ is overall closer to 0, so the model's initial predictions are already biased toward the dominant 0 label, which accelerates convergence.
The difference between the two can also be seen from the perspective of the loss function. Assuming the label is $t \in \{0, 1\}$, the original loss is:
\begin{equation}- t \log p - (1 - t) \log (1 - p)\end{equation}
The loss after the $n$-th power modification becomes:
\begin{equation}- t \log p^n - (1 - t) \log (1 - p^n)\end{equation}
Notice that $- t \log p^n = -nt \log p$. Therefore, when the label is 1, the loss is amplified by a factor of $n$. When the label is 0, $1 - p^n$ is closer to 1 than $1 - p$, so the corresponding loss $-\log(1 - p^n)$ (and its gradient) is smaller. This can therefore be viewed as an adaptive scheme for adjusting the loss weights (gradient weights).
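For concreteness, here is a hedged Keras sketch of the modified binary cross-entropy described by the equations above. The function name and the example value of $n$ are mine; the released training script may implement the trick differently.

```python
from tensorflow.keras import backend as K


def powered_binary_crossentropy(n=2):
    """Binary cross-entropy that treats p**n as the class-1 probability.

    n > 1 pushes the initial predictions toward 0, matching the skewed
    label distribution discussed above. n=2 is only an example value.
    """
    def loss(y_true, y_pred):
        # y_pred is the sigmoid output p; clip it for numerical stability
        p = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        p_n = K.pow(p, n)
        return K.mean(- y_true * K.log(p_n) - (1.0 - y_true) * K.log(1.0 - p_n))
    return loss


# Usage sketch: model.compile(loss=powered_binary_crossentropy(n=2), optimizer='adam')
```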
Compared to Focal Loss or manual weighting, the advantage of this method is that it makes the distribution closer to the target without changing the distribution of the original inner product (where $p$ is usually obtained via an inner product plus a sigmoid). Maintaining the inner product distribution is generally more optimization-friendly.
Source Code and Performance
Github: task_relation_extraction.py
Without any pre-processing or post-processing, the final F1 on the validation set is 0.822, which is generally better than the previous DGCNN-based models. With some pre- or post-processing added, the F1 would likely reach 0.83.
At the same time, we find many errors and omissions in the labels of the training and validation sets. When we originally participated in the competition, the labeling quality of the online test set was higher (more standardized and complete) than that of the training and validation sets. Back then, the F1 on the submitted test set was generally 4%–5% higher than the offline validation set F1. In other words, with some rule-based corrections, if this result were submitted to the original leaderboard, a single model would likely have an F1 of around 0.87.
Worth Noting
As mentioned at the beginning, I previously wrote an example of using BERT for triple extraction using keras-bert. Here, I will discuss the differences between the current model and the previous one, as well as some points worth noting.
The first difference is that the previous version was a quick attempt: the encoding vector of $s$ was simply added to the encoded sequence to predict $o$ and $p$, rather than being fed through the Conditional Layer Norm used in this article. Conditional Layer Norm has stronger representational power and improves the results slightly.
The second difference, and a point worth being aware of, is that the model in this article uses BERT's standard tokenizer, whereas the previous example used direct character splitting. Sequences produced by the standard tokenizer are not simply split by character, especially in cases involving English and numbers. The output tokenization results do not align perfectly with the characters of the original sequence. Therefore, one must be very careful when constructing training samples and outputting results.
Readers might ask: Why not go back to original character splitting? I believe that since BERT is used, one should follow BERT's tokenizer; even if they don't align, there are ways to handle it. The previous character splitting was due to my lack of familiarity with BERT at the time, leading to non-standard usage that should not be encouraged. Following BERT's tokenizer also has the potential to achieve better fine-tuning results than forced character splitting.
Furthermore, I discovered a somewhat surprising fact: the vocabulary (vocab.txt) provided with the Chinese BERT is incomplete. For example, the character "箓" in "符箓" (fulu/talisman) is not in BERT's vocab.txt. Therefore, when outputting final results, it is best not to use the tokenizer's own decode method, but rather map back to the original sequence and output slices from the original string.
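As an illustration of mapping the tokenizer's output back to the original string, here is a simplified alignment sketch. It is not necessarily the utility used in the released script; the function name is mine, and the sketch ignores whitespace handling and some WordPiece corner cases (it also assumes a lowercased vocabulary and that each [UNK] covers exactly one original character).

```python
def token_to_char_spans(text, tokens):
    """Greedily map WordPiece tokens back to character spans of the original
    string, so extracted entities can be sliced from `text` instead of being
    decoded from tokens (which would turn out-of-vocab characters into [UNK]).
    """
    text_lower = text.lower()
    spans, offset = [], 0
    for token in tokens:
        if token in ('[CLS]', '[SEP]'):
            spans.append((offset, offset))      # zero-width span for special tokens
        elif token == '[UNK]':
            spans.append((offset, offset + 1))  # assume one original character
            offset += 1
        else:
            piece = token[2:] if token.startswith('##') else token
            start = text_lower.find(piece, offset)
            spans.append((start, start + len(piece)))
            offset = start + len(piece)
    return spans


# An entity covering tokens i..j is then output as text[spans[i][0]:spans[j][1]].
```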
Finally, the training of this version includes a weight moving average, which stabilizes model training and may even slightly improve performance. For an introduction to weight moving averages, please refer here.
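For reference, a weight moving average can be implemented as a simple Keras callback along the following lines. This is a generic sketch rather than the exact code used here; the decay value, the per-batch update via get_weights (simple but not the most efficient), and the explicit apply/restore calls are illustrative choices.

```python
from tensorflow.keras.callbacks import Callback


class ExponentialMovingAverage(Callback):
    """Maintain an exponential moving average (EMA) of the model weights and
    swap it in for evaluation."""

    def __init__(self, decay=0.999):
        super().__init__()
        self.decay = decay

    def on_train_begin(self, logs=None):
        # Initialize the EMA with a copy of the current weights.
        self.ema_weights = [w.copy() for w in self.model.get_weights()]

    def on_train_batch_end(self, batch, logs=None):
        # ema <- decay * ema + (1 - decay) * current
        for ema_w, w in zip(self.ema_weights, self.model.get_weights()):
            ema_w *= self.decay
            ema_w += (1.0 - self.decay) * w

    def apply_ema_weights(self):
        """Back up the raw weights and load the averaged ones (call before evaluating)."""
        self.backup_weights = self.model.get_weights()
        self.model.set_weights(self.ema_weights)

    def restore_raw_weights(self):
        """Restore the raw training weights (call after evaluating)."""
        self.model.set_weights(self.backup_weights)
```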
Summary
This article has provided an example of using bert4keras to perform triple extraction and pointed out several things worth noting. Everyone is welcome to refer to and try it out.