Modifying Transformer Architecture to Design a Faster and Better MLM Model

By 苏剑林 | August 07, 2020

As is well-known, MLM (Masked Language Model) is the pre-training method for BERT and RoBERTa. As the name suggests, it involves masking some tokens from the original sequence and then letting the model predict these masked tokens. As research has deepened, it has been discovered that MLM is not only valuable as a pre-training method but also has a wealth of practical applications. For instance, I previously found that directly loading the MLM weights of BERT allows it to be used as a UniLM for Seq2Seq tasks (refer here). Another example is the paper published in ACL 2020, "Spelling Error Correction with Soft-Masked BERT", which applies the MLM model to text error correction.

However, anyone who has read the BERT paper carefully or tried it personally knows that the training efficiency of the original MLM is relatively low. This is because only a small fraction of tokens are masked in each pass. The ACL 2020 paper "Fast and Accurate Deep Bidirectional Language Representations for Unsupervised Learning" also addresses this issue and proposes a new MLM model design that offers higher training efficiency and better performance.

MLM Model

Suppose the original sequence is $\boldsymbol{x}=[x_1,x_2,\dots,x_T]$, and $\boldsymbol{x}\backslash \{x_i\}$ denotes the sequence with the $i$-th token replaced by $\text{[MASK]}$ (and similarly when several tokens are masked). The MLM model then models: \begin{equation}p\big(x_i, x_j, x_k, \dots\,\big|\,\boldsymbol{x}\backslash \{x_i,x_j,x_k,\dots\}\big)\end{equation} We say its efficiency is low because only a small portion of tokens can be chosen for masking each time (for example, 15%). This means only 15% of the tokens in each sample provide a training signal, so the same sample has to be trained on multiple times. In BERT, each sample is masked multiple times in advance and saved as tfrecords, which is inefficient and inflates disk usage.

MLM Task Illustration
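To make the masking procedure concrete, here is a minimal NumPy sketch of the selection step (all names are illustrative rather than from the BERT codebase, and BERT's additional 80/10/10 replacement rule is omitted):

```python
import numpy as np

MASK_ID = 103        # hypothetical [MASK] token id
MASK_PROB = 0.15     # fraction of tokens selected as prediction targets

def make_mlm_example(token_ids, rng):
    """Randomly mask ~15% of tokens; only those positions contribute to the loss."""
    token_ids = np.asarray(token_ids)
    labels = np.full_like(token_ids, -100)          # -100 = position ignored by the loss
    picked = rng.random(token_ids.shape) < MASK_PROB
    labels[picked] = token_ids[picked]              # targets only at masked positions
    masked = np.where(picked, MASK_ID, token_ids)   # replace selected tokens with [MASK]
    return masked, labels

rng = np.random.default_rng(0)
masked, labels = make_mlm_example([2023, 2003, 1037, 7953, 6251, 1012], rng)
# Each call picks a different ~15% subset, which is why BERT pre-generates several
# masked copies of every sample and stores them as tfrecords.
```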

If every token in a sample could serve as a prediction target during training, efficiency would naturally improve. While unidirectional language models like GPT can achieve this, MLM is a bidirectional model and cannot do so directly. To reach this goal, we need to simplify the above formula. Suppose we mask only one token at a time; the distribution we want to construct is: \begin{equation}p\big(x_i\,\big|\,\boldsymbol{x}\backslash \{x_i\}\big),\quad i=1,2,\dots,T\end{equation} We then hope to obtain $p(x_1\,|\,\boldsymbol{x}\backslash \{x_1\}),p(x_2\,|\,\boldsymbol{x}\backslash \{x_2\}),\dots,p(x_T\,|\,\boldsymbol{x}\backslash \{x_T\})$ simultaneously in a single model pass. How can this be achieved? This brings us to the paper introduced in this article, which proposes a design called T-TA (Transformer-based Text Autoencoder) that lets us predict the distributions of all tokens at once.

T-TA Introduction

T-TA Attention Mask Pattern

First, we know that the core operation of the Transformer is $Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})$. In BERT, $\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}$ are all the same, i.e., Self-Attention. In MLM, since we are modeling $p(x_i\,|\,\boldsymbol{x}\backslash \{x_i\})$, the $i$-th output must not contain information from the $i$-th token. Therefore, the first modification is to remove the token input from $\boldsymbol{Q}$. That is to say, the $\boldsymbol{Q}$ in the first layer of Attention cannot contain token information; it can only contain position vectors. This is because we aggregate information from $\boldsymbol{K}$ and $\boldsymbol{V}$ through $\boldsymbol{Q}$; if $\boldsymbol{Q}$ itself contains token information, it causes data leakage. Furthermore, we must prevent information leakage from $\boldsymbol{K}$ and $\boldsymbol{V}$. This requires modifying the Attention Mask to mask out the diagonal parts (the token's own attention), as shown in the figure.
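As a minimal sketch of this first modification, the following NumPy code implements a single attention layer with the diagonal masked out and feeds it position vectors only as $\boldsymbol{Q}$ (the function and variable names are mine, not the paper's):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask_diag=True):
    """Scaled dot-product attention; mask_diag removes each token's attention to itself."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])   # (T, T) attention logits
    if mask_diag:
        np.fill_diagonal(scores, -1e9)        # sim(q_i, k_i) -> 0 after softmax
    return softmax(scores) @ V

T, d = 6, 8
rng = np.random.default_rng(0)
E = rng.normal(size=(T, d))      # token embeddings
P = rng.normal(size=(T, d))      # position embeddings
H1 = attention(P, E + P, E + P)  # first layer: Q carries position information only
```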

If this is still unclear, we can understand it from the general form of Attention. The general definition of Attention is: \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\label{eq:gen-att}\end{equation} Clearly, $Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i$ inevitably depends on $\boldsymbol{q}_i$, so $\boldsymbol{q}_i$ absolutely must not contain information about the $i$-th token. It need not depend on $\boldsymbol{k}_i,\boldsymbol{v}_i$, however: if $\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_i)=0$, then $\boldsymbol{k}_i,\boldsymbol{v}_i$ effectively cease to exist. Thus, we only need to mask the diagonal portion of the Attention matrix.

However, this leak-proof Attention Mask only holds for a single layer! After the first layer, the output $Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_j$ at every position $j\neq i$ has already integrated information from the $i$-th token. If the next layer again takes the previous layer's output as $\boldsymbol{K}$ and $\boldsymbol{V}$, the $i$-th query can recover its own token's information through those positions, so leakage occurs from the second layer onward even with the aforementioned Attention Mask.
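Continuing the sketch above, a quick numerical check makes the leak visible: perturb the embedding of a single token $i$ and see whether the output at position $i$ changes after two layers. It does when the previous layer's output is reused as $\boldsymbol{K},\boldsymbol{V}$, but not when $\boldsymbol{K},\boldsymbol{V}$ stay fixed to the original input (the T-TA scheme described next):

```python
def two_layers(E, P, reuse_output_as_kv):
    H1 = attention(P, E + P, E + P)               # first layer, diagonal masked
    KV = H1 if reuse_output_as_kv else E + P      # second layer's K and V
    return attention(H1, KV, KV)

E2 = E.copy()
E2[2] += 1.0                                      # perturb only token i = 2

leak = two_layers(E, P, True)[2]  - two_layers(E2, P, True)[2]
safe = two_layers(E, P, False)[2] - two_layers(E2, P, False)[2]
print(np.abs(leak).max())   # clearly non-zero: x_2 leaks back into position 2
print(np.abs(safe).max())   # 0: position 2 never sees its own token
```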

The original paper's solution is crude but seems to be the only way: every Attention layer shares the original input as $\boldsymbol{K}$ and $\boldsymbol{V}$! Thus, let $\boldsymbol{E}$ be the token embedding sequence and $\boldsymbol{P}$ be the corresponding position vectors. The calculation processes for T-TA and BERT can be simplified as: \begin{equation} \begin{array}{c}\bbox[border: 1px dashed red; padding: 5px]{\begin{aligned}&\boldsymbol{Q}_0 = \boldsymbol{E}+\boldsymbol{P}\\ &\boldsymbol{Q}_1 = Attention(\boldsymbol{Q}_0,\boldsymbol{Q}_0,\boldsymbol{Q}_0) \\ &\boldsymbol{Q}_2 = Attention(\boldsymbol{Q}_1,\boldsymbol{Q}_1,\boldsymbol{Q}_1) \\ &\qquad\vdots\\ &\boldsymbol{Q}_n = Attention(\boldsymbol{Q}_{n-1},\boldsymbol{Q}_{n-1},\boldsymbol{Q}_{n-1}) \end{aligned}} \\ \text{BERT Schematic}\quad\end{array}\qquad \begin{array}{c}\bbox[border: 1px dashed red; padding: 5px]{\begin{aligned}&\boldsymbol{Q}_0 = \boldsymbol{P}\\ &\boldsymbol{Q}_1 = Attention(\boldsymbol{Q}_0,\boldsymbol{E}+\boldsymbol{P},\boldsymbol{E}+\boldsymbol{P}) \\ &\boldsymbol{Q}_2 = Attention(\boldsymbol{Q}_1,\boldsymbol{E}+\boldsymbol{P},\boldsymbol{E}+\boldsymbol{P}) \\ &\qquad\vdots\\ &\boldsymbol{Q}_n = Attention(\boldsymbol{Q}_{n-1},\boldsymbol{E}+\boldsymbol{P},\boldsymbol{E}+\boldsymbol{P}) \end{aligned}} \\ \text{T-TA Schematic}\quad\end{array}\end{equation} Of course, details like residuals and FFN are omitted, retaining only the core operations. During the pre-training phase, T-TA's Attention employs a diagonal Attention Mask. For downstream task fine-tuning, this can be removed.
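The two schematics translate almost line for line into code. Reusing the `attention` sketch from above, and again omitting residuals, FFN and layer normalization, a rough rendering might look like this:

```python
def bert_stack(E, P, n_layers):
    Q = E + P
    for _ in range(n_layers):
        Q = attention(Q, Q, Q, mask_diag=False)   # ordinary self-attention
    return Q

def tta_stack(E, P, n_layers, mask_diag=True):
    Q, KV = P, E + P                              # Q starts from positions only
    for _ in range(n_layers):
        Q = attention(Q, KV, KV, mask_diag)       # K, V stay fixed to the input
    return Q                                      # row i predicts x_i without ever seeing it

# Pre-training uses the diagonal mask; for downstream fine-tuning it can be dropped:
reps = tta_stack(E, P, n_layers=3, mask_diag=False)
```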

Experimental Results

One of the experimental tables from the original paper. It shows that T-TA has unique advantages in semantic representation.

Based on this design, T-TA can predict all tokens at once, making training efficient. Moreover, it does not require the extra $\text{[MASK]}$ symbol, so pre-training and fine-tuning stay consistent. However, it is not hard to see that T-TA is essentially a simplification of the standard Transformer, so in theory its fitting capacity is weaker. Given this trade-off, is there still a net gain in performance? Naturally, the paper's experimental results say yes. The original paper conducted several experiments, and the results show that the T-TA design can generally match or even exceed models trained with standard MLM with the same number of parameters. The authors also generously open-sourced their code so that the results can be reproduced (Link).

When it comes to modifying the Transformer structure, one might picture banks of GPUs and TPUs running in parallel. In fact, although the authors did not detail their experimental hardware, the paper suggests the setup was not "luxurious." For this reason, they only trained a 3-layer T-TA and, for comparison, reproduced a 3-layer MLM and a 3-layer GPT (unidirectional language model) under the same settings. That's right: all T-TA results in the paper come from 3-layer models, some of which even outperform the Base version of BERT. The authors thus vividly taught us a lesson: you don't need "tycoon" equipment to modify the Transformer structure or to publish at ACL; what matters is a genuinely effective idea.

Personal Analysis

Finally, let's briefly discuss why T-TA is effective. Readers might question how effectiveness can be guaranteed at deeper layers if the authors only performed 3-layer experiments. Well, let's look at this model from another perspective.

From the design, for T-TA, once the input is given, $\boldsymbol{K}$ and $\boldsymbol{V}$ remain constant across all Attention layers and only $\boldsymbol{Q}$ changes, so it is unsurprising that readers might doubt its effectiveness. However, don't forget that Google recently proposed Synthesizer (refer to "Google's New Synthesizer: We Don't Understand Self-Attention Well Enough"), which explored several Attention variants. One variant, abbreviated "R", is equivalent to fixing $\boldsymbol{Q}$ and $\boldsymbol{K}$ as constants, and it actually worked quite well! Note that in "R", $\boldsymbol{Q}$ and $\boldsymbol{K}$ are absolute constants with no connection to the input.

Therefore, since the results are decent even when $\boldsymbol{Q}$ and $\boldsymbol{K}$ are constants, why couldn't $\boldsymbol{K}$ and $\boldsymbol{V}$ be constants? Moreover, the $\boldsymbol{K}$ and $\boldsymbol{V}$ in T-TA still depend dynamically on the input (they are only held constant across layers once the input is fixed), so in theory T-TA's fitting capacity is stronger than that of Synthesizer's "R" model. Since "R" can perform well, it is not surprising that T-TA can too.

Of course, we still look forward to experimental results with deeper models in the future.