GPLinker: Entity-Relation Joint Extraction based on GlobalPointer

By 苏剑林 | January 30, 2022

Nearly three years ago, during Baidu's "2019 Language and Intelligence Technology Competition" (hereafter LIC2019), I proposed a new relation extraction model (see "A Lightweight Information Extraction Model Based on DGCNN and Probabilistic Graphs"), which was later developed further, published under the name "CasRel," and regarded as the SOTA for relation extraction at the time. However, I was still new to the field when I proposed it, so in hindsight CasRel has many imperfections. I had thought about improving it further, but never came up with a particularly good design.

Later, I proposed GlobalPointer and, more recently, Efficient GlobalPointer, and felt that I finally had enough "materials" to build a new relation extraction model. Starting from the probabilistic graph perspective and drawing on several SOTA designs that followed CasRel, I eventually arrived at a model similar to TPLinker.

Basic Idea

At first glance, relation extraction is the extraction of triplets $(s, p, o)$ (subject, predicate, object). In practice, however, it is the extraction of "quintuplets" $(s_h, s_t, p, o_h, o_t)$, where $s_h, s_t$ are the start and end positions of $s$, and $o_h, o_t$ are the start and end positions of $o$.

From the perspective of a probabilistic graph, we can construct the model as follows:

1. Design a scoring function for quintuplets $S(s_h, s_t, p, o_h, o_t)$;
2. During training, ensure that labeled quintuplets $S(s_h, s_t, p, o_h, o_t) > 0$, and other quintuplets $S(s_h, s_t, p, o_h, o_t) < 0$;
3. During prediction, enumerate all possible quintuplets and output the parts where $S(s_h, s_t, p, o_h, o_t) > 0$.

However, there are far too many quintuplets to enumerate directly. Assuming the sentence length is $l$ and the number of predicates $p$ is $n$, even with the constraints $s_h \leq s_t$ and $o_h \leq o_t$, the total number of quintuplets is:

\begin{equation}n\times \frac{l(l+1)}{2}\times \frac{l(l+1)}{2}=\frac{1}{4}nl^2(l+1)^2\end{equation}

This is $\mathcal{O}(nl^4)$, i.e., quartic in the sentence length, which is impractical to compute, so some simplification is necessary.
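As a quick sanity check of the count above, the following Python sketch enumerates all quintuplets for a toy sentence and compares the total with the closed form. It then evaluates the formula at $l=128, n=50$ (the example values used later in this post), which gives about 3.4 billion candidate quintuplets per sentence. The function names are illustrative only.

```python
from itertools import product


def count_quintuplets_brute(l, n):
    """Count (s_h, s_t, p, o_h, o_t) with s_h <= s_t and o_h <= o_t by enumeration."""
    count = 0
    for s_h, s_t, p, o_h, o_t in product(range(l), range(l), range(n), range(l), range(l)):
        if s_h <= s_t and o_h <= o_t:
            count += 1
    return count


def count_quintuplets_formula(l, n):
    """Closed form n * [l(l+1)/2]^2 = n * l^2 * (l+1)^2 / 4 (always an integer)."""
    return n * l**2 * (l + 1) ** 2 // 4


print(count_quintuplets_brute(6, 3), count_quintuplets_formula(6, 3))  # 1323 1323
print(count_quintuplets_formula(128, 50))  # 3408076800
```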

Simplified Decomposition

With current computing power, we can generally only afford complexity that is quadratic in the sentence length, which means we can score at most one pair of positions at a time. To this end, we use the following decomposition:

\begin{equation}S(s_h,s_t,p,o_h,o_t) = S(s_h,s_t) + S(o_h,o_t) + S(s_h,o_h| p) + S(s_t, o_t| p)\label{eq:factor}\end{equation}

Note that this equation is a modeling assumption based on our understanding of the task and on hardware constraints, not a theoretical derivation. Each term has an intuitive meaning. $S(s_h, s_t)$ and $S(o_h, o_t)$ are the start-end scores of the subject and object, respectively, so all subjects and objects can be extracted via $S(s_h, s_t) > 0$ and $S(o_h, o_t) > 0$. The latter two terms handle predicate matching. $S(s_h, o_h|p)$ matches subject and object under predicate $p$, using their start tokens as their representations. If the subjects and objects contain no nested entities, then in theory $S(s_h, o_h|p) > 0$ alone is sufficient to extract all predicates. Nested entities, however, may share a start position, in which case the start alone cannot tell them apart; so we also match the end positions of the entities, which is the $S(s_t, o_t|p)$ term.
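To make the bookkeeping of the decomposition concrete, here is a small NumPy sketch with random scores standing in for model outputs (the array names and toy sizes are assumptions for illustration). It assembles the full five-axis score tensor by broadcasting purely to check that each entry equals the sum of the four terms; the point of the decomposition is that a model only ever needs the four $\mathcal{O}(nl^2)$ tables, never this $\mathcal{O}(nl^4)$ tensor.

```python
import numpy as np

rng = np.random.default_rng(0)
l, n = 5, 3  # toy sentence length and number of predicates

# Random stand-ins for the four score tables (hypothetical names)
S_sub = rng.normal(size=(l, l))      # S(s_h, s_t), indexed [s_h, s_t]
S_obj = rng.normal(size=(l, l))      # S(o_h, o_t), indexed [o_h, o_t]
S_head = rng.normal(size=(n, l, l))  # S(s_h, o_h | p), indexed [p, s_h, o_h]
S_tail = rng.normal(size=(n, l, l))  # S(s_t, o_t | p), indexed [p, s_t, o_t]

# Full score with axes [s_h, s_t, p, o_h, o_t], assembled by broadcasting.
S_full = (
    S_sub[:, :, None, None, None]
    + S_obj[None, None, None, :, :]
    + S_head.transpose(1, 0, 2)[:, None, :, :, None]  # -> [s_h, 1, p, o_h, 1]
    + S_tail.transpose(1, 0, 2)[None, :, :, None, :]  # -> [1, s_t, p, 1, o_t]
)

s_h, s_t, p, o_h, o_t = 1, 3, 2, 0, 4
direct = S_sub[s_h, s_t] + S_obj[o_h, o_t] + S_head[p, s_h, o_h] + S_tail[p, s_t, o_t]
print(S_full.shape, np.isclose(S_full[s_h, s_t, p, o_h, o_t], direct))  # (5, 5, 3, 5, 5) True
```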

At this point, the training and prediction process becomes:

1. During training, let the labeled quintuplets satisfy $S(s_h,s_t) > 0, S(o_h,o_t) > 0, S(s_h,o_h| p) > 0, S(s_t, o_t| p) > 0$, and other quintuplets satisfy $S(s_h,s_t) < 0, S(o_h,o_t) < 0, S(s_h,o_h| p) < 0, S(s_t, o_t| p) < 0$;
2. During prediction, enumerate all possible quintuplets, collect those satisfying each of $S(s_h,s_t) > 0$, $S(o_h,o_t) > 0$, $S(s_h,o_h| p) > 0$, and $S(s_t, o_t| p) > 0$, and output their intersection as the final result (i.e., the quintuplets satisfying all 4 conditions simultaneously).
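A minimal decoding sketch in NumPy, assuming the model outputs an entity score array `ent` of shape $(2, l, l)$ (channel 0 for subjects, channel 1 for objects) and two relation arrays `head`, `tail` of shape $(n, l, l)$ indexed as $[p, s_h, o_h]$ and $[p, s_t, o_t]$. These names and layouts are assumptions, not necessarily those of the reference script. Rather than enumerating every quintuplet, it first collects the spans that pass the entity conditions and then checks the two predicate conditions only for those subject-object pairs; the result is the same intersection of all four conditions.

```python
import numpy as np


def decode_triples(ent, head, tail, threshold=0.0):
    """Return {((s_h, s_t), p, (o_h, o_t))} satisfying all four score conditions."""
    subjects, objects = set(), set()
    # Conditions S(s_h, s_t) > 0 and S(o_h, o_t) > 0
    for kind, start, end in zip(*np.where(ent > threshold)):
        (subjects if kind == 0 else objects).add((int(start), int(end)))
    triples = set()
    for s_h, s_t in subjects:
        for o_h, o_t in objects:
            # Conditions S(s_h, o_h | p) > 0 and S(s_t, o_t | p) > 0
            p_head = set(np.where(head[:, s_h, o_h] > threshold)[0].tolist())
            p_tail = set(np.where(tail[:, s_t, o_t] > threshold)[0].tolist())
            for p in p_head & p_tail:
                triples.add(((s_h, s_t), p, (o_h, o_t)))
    return triples


l, n = 6, 2
ent, head, tail = -np.ones((2, l, l)), -np.ones((n, l, l)), -np.ones((n, l, l))
ent[0, 0, 1] = 1.0   # subject span (0, 1)
ent[1, 3, 4] = 1.0   # object span (3, 4)
head[1, 0, 3] = 1.0  # predicate 1 links the starts
tail[1, 1, 4] = 1.0  # predicate 1 links the ends
head[0, 0, 3] = 1.0  # predicate 0 links the starts only, so it is rejected
print(decode_triples(ent, head, tail))  # {((0, 1), 1, (3, 4))}
```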

In implementation, since $S(s_h, s_t)$ and $S(o_h, o_t)$ identify the entities serving as subjects and objects, they amount to an NER task with two entity types, so a single GlobalPointer can handle both. $S(s_h, o_h|p)$ identifies $(s_h, o_h)$ pairs for predicate $p$. Unlike NER, there is no $s_h \leq o_h$ constraint here; we again use GlobalPointer, but to identify pairs with $s_h > o_h$ we must remove GlobalPointer's default lower-triangular mask. Finally, $S(s_t, o_t|p)$ is handled in the same way as $S(s_h, o_h|p)$.
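The following NumPy sketch shows the token-pair view of GlobalPointer and where the lower-triangular mask comes in. It is a simplification under stated assumptions: random weights, one query and key projection per head, and none of the RoPE position encoding or padding mask that the actual layer uses. Masked positions receive a large negative constant rather than $-\infty$ so that a downstream loss stays finite.

```python
import numpy as np


def token_pair_scores(h, Wq, Wk, tril_mask=True):
    """Simplified GlobalPointer-style scores (no RoPE, no padding mask).

    h: (l, d) token features; Wq, Wk: (heads, d, head_size) per-head projections.
    Returns (heads, l, l) scores with scores[k, i, j] = q_i . k_j / sqrt(head_size).
    """
    q = np.einsum('ld,kde->kle', h, Wq)
    k = np.einsum('ld,kde->kle', h, Wk)
    scores = np.einsum('kie,kje->kij', q, k) / np.sqrt(Wq.shape[-1])
    if tril_mask:
        l = h.shape[0]
        below_diag = np.tril(np.ones((l, l), dtype=bool), k=-1)  # positions with i > j
        scores = np.where(below_diag[None], -1e12, scores)
    return scores


rng = np.random.default_rng(0)
l, d, head_size, n = 8, 16, 4, 3
h = rng.normal(size=(l, d))


def random_proj(heads):
    """Hypothetical random projection weights of shape (heads, d, head_size)."""
    return rng.normal(size=(heads, d, head_size))


ent = token_pair_scores(h, random_proj(2), random_proj(2), tril_mask=True)     # S(s_h,s_t), S(o_h,o_t)
head = token_pair_scores(h, random_proj(n), random_proj(n), tril_mask=False)   # S(s_h,o_h|p), s_h > o_h allowed
tail = token_pair_scores(h, random_proj(n), random_proj(n), tril_mask=False)   # S(s_t,o_t|p)
```

The only difference between the NER-style module and the two relation modules is the `tril_mask` flag: entity spans must satisfy start $\leq$ end, whereas a subject may appear after its object.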

To recap: as an NER module, GlobalPointer can uniformly identify nested and non-nested entities, and it achieves this through token-pair recognition. We should therefore understand GlobalPointer more generally as a token-pair recognition model rather than confining it to NER. Seen this way, $S(s_h,s_t), S(o_h,o_t), S(s_h,o_h| p), S(s_t, o_t|p)$ can all be implemented with GlobalPointer, and whether to apply the lower-triangular mask depends on the specific task.

Loss Function

Now that we have designed the scoring functions, the only thing left for training is the loss function. Here, we continue to use the multi-label cross-entropy proposed in "Generalizing 'Softmax + Cross-Entropy' to Multi-Label Classification Problems", which is the default for GlobalPointer. Its general form is:

\begin{equation}\log \left(1 + \sum\limits_{i\in \mathcal{P}} e^{-S_i}\right) + \log \left(1 + \sum\limits_{i\in \mathcal{N}} e^{S_i}\right)\label{eq:loss-1}\end{equation}

where $\mathcal{P}, \mathcal{N}$ are the sets of positive and negative classes, respectively. In previous articles, we used "multi-hot" vectors to mark positive and negative classes: if there are $K$ classes in total, we use a $K$-dimensional vector whose positive-class positions are 1 and whose negative-class positions are 0. However, $S(s_h,o_h| p)$ and $S(s_t, o_t|p)$ each require an $n \times l \times l$ label matrix. Adding the two together and including the batch size $b$, the total size is $2bnl^2$. Taking $b=64, n=50, l=128$ as an example, $2bnl^2 \approx 100$ million. This means that if we insist on representing the labels in "multi-hot" form, we must build a matrix with 100 million entries at every training step and transfer it to the GPU, making both creation and transfer very expensive.
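For reference, here is a dense ("multi-hot") version of the loss above as a NumPy sketch. The `1e12` offsets are one common way to exclude classes from a log-sum-exp and are an implementation choice, not part of the formula; the appended zero column supplies the "1 +" inside each logarithm. The last lines reproduce the label-size arithmetic from this paragraph.

```python
import numpy as np


def logsumexp(x, axis=-1):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def multilabel_ce(y_true, y_pred):
    """Dense ("multi-hot") multi-label cross-entropy.

    y_true: (..., K) array of 0/1 labels; y_pred: (..., K) real-valued scores S_i.
    """
    y_pred = (1 - 2 * y_true) * y_pred      # positives -> -S_i, negatives -> S_i
    y_neg = y_pred - y_true * 1e12          # keep only negative classes
    y_pos = y_pred - (1 - y_true) * 1e12    # keep only positive classes
    zeros = np.zeros_like(y_pred[..., :1])  # the "1 +" inside each log
    neg_loss = logsumexp(np.concatenate([y_neg, zeros], axis=-1))
    pos_loss = logsumexp(np.concatenate([y_pos, zeros], axis=-1))
    return neg_loss + pos_loss


# Size of the two multi-hot label tensors for S(s_h, o_h | p) and S(s_t, o_t | p)
b, n, l = 64, 50, 128
label_entries = 2 * b * n * l**2
print(label_entries)              # 104857600
print(label_entries * 4 / 2**20)  # 400.0 (MiB, if stored as float32)
```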

Therefore, to improve training speed, we need to implement a "sparse version" of multi-label cross-entropy, meaning we only transmit the indices corresponding to positive classes. Since positive classes are far fewer than negative classes, the size of the label matrix is greatly reduced. Implementing the "sparse version" means we need to implement Equation \eqref{eq:loss-1} knowing only $\mathcal{P}$ and $\mathcal{A}=\mathcal{P}\cup\mathcal{N}$. For this, the implementation we use is:

\begin{equation}\begin{aligned} &\,\log \left(1 + \sum\limits_{i\in \mathcal{N}} e^{S_i}\right) = \log \left(1 + \sum\limits_{i\in \mathcal{A}} e^{S_i} - \sum\limits_{i\in \mathcal{P}} e^{S_i}\right) \\ =&\, \log \left(1 + \sum\limits_{i\in \mathcal{A}} e^{S_i}\right) + \log \left(1 - \left(\sum\limits_{i\in \mathcal{P}} e^{S_i}\right)\Bigg/\left(1 + \sum\limits_{i\in \mathcal{A}} e^{S_i}\right)\right) \end{aligned}\end{equation}

If we let $a = \log \left(1 + \sum\limits_{i\in \mathcal{A}} e^{S_i}\right)$ and $b = \log \left(\sum\limits_{i\in \mathcal{P}} e^{S_i}\right)$, it can be written as:

\begin{equation}\log \left(1 + \sum\limits_{i\in \mathcal{N}} e^{S_i}\right) = a + \log\left(1 - e^{b - a}\right)\end{equation}

In this way, the loss for the negative classes is calculated using $\mathcal{P}$ and $\mathcal{A}$, while the loss for the positive class part remains unchanged.
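A quick numerical check of this identity, as a NumPy sketch with random scores and an arbitrary choice of positive classes (both assumptions for illustration). Since $\mathcal{P}\subset\mathcal{A}$ and the sum inside $a$ also contains the extra 1, we always have $b < a$, so $1 - e^{b-a}$ is positive and the logarithm is well defined.

```python
import numpy as np

rng = np.random.default_rng(1)
S = rng.normal(size=10)                      # scores S_i for all classes in A
pos = np.array([2, 7])                       # indices of the positive classes P
neg = np.setdiff1d(np.arange(S.size), pos)   # indices of N = A \ P

direct = np.log1p(np.sum(np.exp(S[neg])))      # log(1 + sum_{i in N} e^{S_i})
a = np.logaddexp(0.0, np.logaddexp.reduce(S))  # log(1 + sum_{i in A} e^{S_i})
b = np.logaddexp.reduce(S[pos])                # log(sum_{i in P} e^{S_i})
via_identity = a + np.log1p(-np.exp(b - a))    # a + log(1 - e^{b - a})
print(np.isclose(direct, via_identity))  # True
```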

Finally, in general multi-label classification tasks the number of positive classes varies from sample to sample. In this case, we can index classes starting from 1 and use 0 as a padding label so that every sample's label matrix has the same size, then mask out class 0 in the loss implementation. The corresponding implementation is built into bert4keras; for details, see "sparse_multilabel_categorical_crossentropy".
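Putting the pieces together, the following is a NumPy sketch of a padded sparse loss. It is written for illustration and is not the bert4keras code: the function name, the `(batch, max_pos)` label layout, and the clipping constant `eps` are all assumptions. Class $c$ is stored as $c+1$ so that 0 can serve as padding, and padded entries are masked out of both terms.

```python
import numpy as np


def logsumexp(x, axis=-1):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def sparse_multilabel_ce(pos_idx, scores, eps=1e-7):
    """Sparse multi-label cross-entropy (a NumPy sketch).

    pos_idx: (batch, max_pos) ints; class c is stored as c + 1, and 0 means padding.
    scores:  (batch, K) real-valued scores S_i for classes 0..K-1.
    """
    valid = pos_idx > 0
    idx = np.maximum(pos_idx - 1, 0)  # padding points at class 0 and is masked below
    pos_scores = np.take_along_axis(scores, idx, axis=1)
    zeros = np.zeros_like(scores[:, :1])

    # Positive term: log(1 + sum_{i in P} e^{-S_i}); padding contributes e^{-inf} = 0.
    pos_term = logsumexp(np.concatenate([np.where(valid, -pos_scores, -np.inf), zeros], axis=1))

    # a = log(1 + sum_{i in A} e^{S_i})
    a = logsumexp(np.concatenate([scores, zeros], axis=1))
    # e^{b - a} = sum_{i in P} e^{S_i - a}; every exponent is <= 0, so nothing overflows.
    ratio = np.sum(np.where(valid, np.exp(pos_scores - a[:, None]), 0.0), axis=1)
    # Negative term: a + log(1 - e^{b - a}), with 1 - e^{b - a} clipped away from 0.
    neg_term = a + np.log(np.clip(1.0 - ratio, eps, 1.0))
    return pos_term + neg_term
```

Clipping $1 - e^{b-a}$ away from zero guards against the case where the positive scores dominate and the subtraction rounds to 0 in floating point.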

Experimental Results

For ease of reference, we temporarily refer to the above model as GPLinker (GlobalPointer-based Linking). A reference implementation based on bert4keras is as follows:

Script Link: task_relation_extraction_gplinker.py

The experimental results on LIC2019 are as follows (the code for CasRel is task_relation_extraction.py):

\begin{array}{c|c} \hline \text{Model} & \text{F1} \\ \hline \text{CasRel} & 0.8220 \\ \text{GPLinker (Standard)} & 0.8272\\ \text{GPLinker (Efficient)} & 0.8268\\ \hline \end{array}

The pre-trained model is BERT base. "Standard" and "Efficient" refer to using the standard GlobalPointer and Efficient GlobalPointer, respectively. These results show two things: first, GPLinker is indeed more effective than CasRel; second, Efficient GlobalPointer can indeed match the standard GlobalPointer with far fewer parameters. Specifically, on LIC2019 the GlobalPointer layers that GPLinker adds on top of BERT contain nearly 10 million parameters with the standard GlobalPointer, but only about 300,000 with Efficient GlobalPointer.
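The parameter counts above can be reproduced with back-of-the-envelope arithmetic. The sketch below assumes a hidden size of 768 (BERT base), a head size of 64, $n = 50$ predicates (the example value used earlier, not necessarily the exact LIC2019 count), Dense layers with biases, and the following layouts: the standard GlobalPointer maps each token to a query and key of size `head_size` per head with one Dense layer, while the Efficient variant shares one Dense layer of width `2 * head_size` and adds a small Dense layer producing two bias terms per head. GPLinker is taken to use one two-type entity module plus two relation modules.

```python
def gp_params(hidden, heads, head_size):
    """Standard GlobalPointer: Dense(hidden -> heads * 2 * head_size) with bias."""
    out = heads * 2 * head_size
    return hidden * out + out


def efficient_gp_params(hidden, heads, head_size):
    """Efficient GlobalPointer: shared Dense(hidden -> 2 * head_size) + Dense(2 * head_size -> 2 * heads)."""
    shared = hidden * 2 * head_size + 2 * head_size
    per_head = 2 * head_size * 2 * heads + 2 * heads
    return shared + per_head


hidden, head_size, n = 768, 64, 50
# One entity module (subject and object types) plus two relation modules (head pairs, tail pairs)
standard = gp_params(hidden, 2, head_size) + 2 * gp_params(hidden, n, head_size)
efficient = efficient_gp_params(hidden, 2, head_size) + 2 * efficient_gp_params(hidden, n, head_size)
print(standard, efficient)  # 10040064 321612
```

Almost all of the standard count comes from the relation modules, whose projection width grows with $n \times$ `head_size`; the Efficient variant makes that growth scale only with $n$.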

In addition, on a 3090 GPU, compared to the "multi-hot" version of the multi-label cross-entropy, the model using the sparse version of multi-label cross-entropy can increase training speed by 1.5 times without losing accuracy. Compared with CasRel, GPLinker using the sparse multi-label cross-entropy is only 15% slower in training, but nearly twice as fast in decoding, making it both fast and effective.

Related Work

Readers who have followed SOTA relation extraction models over the past two years will find, once they understand the model above, that it is very similar to TPLinker. Indeed, the model was designed with significant reference to TPLinker, and its final results are also very close to TPLinker's.

Broadly speaking, the differences between TPLinker and GPLinker are as follows:

1. TPLinker obtains its token-pair classification features by concatenating the start and end features and applying a Dense transformation, an approach that originates from Additive Attention; GPLinker uses GlobalPointer, which originates from Scaled Dot-Product Attention. In general, the latter uses less memory and computes faster.
2. GPLinker identifies subject and object entities separately, while TPLinker mixes subject and object for unified identification. I also tried mixed identification in GPLinker and found no significant difference in final performance compared to separate identification.
3. For $S(s_h,o_h|p)$ and $S(s_t,o_t|p)$, TPLinker casts the task as $l(l+1)/2$ three-class classification problems, which suffer from obvious class imbalance; GPLinker uses the multi-label cross-entropy I proposed, which avoids the imbalance issue and is therefore easier to train. In fact, the TPLinker authors later recognized this problem and proposed TPLinker-plus, which also uses this multi-label cross-entropy.
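To illustrate point 1, the sketch below contrasts the two ways of building token-pair features in NumPy. Both are heavily simplified (random weights, a single head, no classifier on top), so they only show the shapes involved, not either model's actual architecture: the additive style materializes an $(l, l, 2d)$ tensor of concatenated pairs, whereas the dot-product style projects each token once and needs only an $(l, l)$ score matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
l, d = 16, 32
h = rng.normal(size=(l, d))  # token features

# Additive style (simplified, TPLinker-like): concat [h_i; h_j] for every pair, then Dense + tanh.
W_add = rng.normal(size=(2 * d, d)) / np.sqrt(2 * d)
pairs = np.concatenate([np.repeat(h[:, None, :], l, axis=1),
                        np.repeat(h[None, :, :], l, axis=0)], axis=-1)  # (l, l, 2d)
pair_feats = np.tanh(pairs @ W_add)                                    # (l, l, d)

# Dot-product style (simplified, GlobalPointer-like): project each token once, one inner product per pair.
Wq = rng.normal(size=(d, d)) / np.sqrt(d)
Wk = rng.normal(size=(d, d)) / np.sqrt(d)
pair_scores = (h @ Wq) @ (h @ Wk).T / np.sqrt(d)                       # (l, l)

print(pairs.shape, pair_feats.shape, pair_scores.shape)  # (16, 16, 64) (16, 16, 32) (16, 16)
```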

Of course, in my view, the main contribution of this article is not the specific changes in GPLinker but a "top-down" understanding of joint entity-relation extraction models: start from the full quintuplet score $S(s_h,s_t,p,o_h,o_t)$, analyze why it is hard to compute, and then "divide and conquer" through the decomposition in Equation \eqref{eq:factor}. I hope this top-down process offers readers some ideas when designing models for more complex tasks.

Summary

In this post, I shared "GPLinker," an entity-relation joint extraction model based on GlobalPointer, and provided a "top-down" derivation of it for reference.