By 苏剑林 | May 18, 2020
Some time ago, we released the model weights for a project called SimBERT. It is based on Google's open-source BERT model and follows Microsoft's UniLM idea to design a task that integrates retrieval and generation; after further fine-tuning, the resulting model can both generate similar questions and retrieve similar sentences. At the time of the release, we only provided a weight file and an example script, without further explaining the model's principles and training process. This article supplements that information.
Open Source Address: https://github.com/ZhuiyiTechnology/simbert
UniLM (Unified Language Model) is a Transformer model that fuses NLU (Natural Language Understanding) and NLG (Natural Language Generation) capabilities. It was proposed by Microsoft in May last year and upgraded to v2 in February this year. Our previous article, "From Language Models to Seq2Seq: Transformer is All About the Mask," gave a brief introduction to UniLM, and support for UniLM has already been integrated into bert4keras.
The core of UniLM is the use of a special Attention Mask to grant the model Seq2Seq capabilities. For example, if the input is "What do you want to eat?" and the target sentence is "White-cut chicken," UniLM concatenates the two sentences into a single sequence, "[CLS] What do you want to eat [SEP] White-cut chicken [SEP]", and then applies the Attention Mask as shown in the figure:
In other words, the tokens in "[CLS] What do you want to eat [SEP]" have bidirectional Attention among themselves, while the tokens in "White-cut chicken [SEP]" only have unidirectional Attention. This allows the model to predict the tokens of "White-cut chicken [SEP]" one after another (autoregressively), giving it text generation capabilities.
Schematic of UniLM as a Seq2Seq model. Internal bidirectional Attention for the input part, and only unidirectional Attention for the output part.
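To make the mask concrete, here is a minimal NumPy sketch (an illustration only, not the bert4keras implementation) that builds the UniLM Attention Mask from the segment ids of the concatenated sequence: the input part gets bidirectional attention, while the output part only gets left-to-right attention.

import numpy as np

def unilm_attention_mask(segment_ids):
    """Build the UniLM Seq2Seq Attention Mask from segment ids.

    segment_ids: 1D sequence, 0 for the input part, 1 for the output part.
    Returns a (seq_len, seq_len) matrix where mask[i, j] = 1 means that
    position i is allowed to attend to position j.
    """
    segment_ids = np.asarray(segment_ids)
    idxs = np.arange(len(segment_ids))
    # Position j is visible to position i if j belongs to the input part
    # (bidirectional) or if j is not later than i (left-to-right).
    mask = (segment_ids[None, :] == 0) | (idxs[None, :] <= idxs[:, None])
    return mask.astype('float32')

# "[CLS] What do you want to eat [SEP]" as the input part (6 tokens here)
# and "White-cut chicken [SEP]" as the output part (4 tokens here):
print(unilm_attention_mask([0] * 6 + [1] * 4))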
Seq2Seq only explains UniLM's NLG capability. Why did we say it possesses both NLU and NLG capabilities? Because of UniLM's special Attention Mask, the tokens of the input part "[CLS] What do you want to eat [SEP]" only attend to one another and are completely unaffected by "White-cut chicken [SEP]": even though the latter is appended, it does not change the encoding vectors of the input tokens. Put more plainly, the encoding vectors of the input part are exactly what you would get by encoding "[CLS] What do you want to eat [SEP]" on its own. So if we take the [CLS] vector as the sentence vector, it is the sentence vector of "What do you want to eat" alone, not of that sentence concatenated with "White-cut chicken."
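As a quick sanity check of this property, the toy sketch below (a single self-attention layer on random vectors, nothing to do with a real BERT; all names and sizes are made up for illustration) confirms that, under the UniLM mask, the outputs at the input positions are identical whether or not the output part is appended:

import numpy as np

def unilm_attention_mask(segment_ids):
    # Same mask construction as in the previous sketch.
    segment_ids = np.asarray(segment_ids)
    idxs = np.arange(len(segment_ids))
    mask = (segment_ids[None, :] == 0) | (idxs[None, :] <= idxs[:, None])
    return mask.astype('float32')

def masked_self_attention(x, mask):
    # Bare-bones single-head self-attention (identity projections),
    # just to show the effect of the mask.
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores = np.where(mask > 0, scores, -1e9)  # block the masked positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ x

rng = np.random.default_rng(0)
x_input = rng.normal(size=(6, 8))    # stands in for "[CLS] What do you want to eat [SEP]"
x_output = rng.normal(size=(4, 8))   # stands in for "White-cut chicken [SEP]"

enc_alone = masked_self_attention(x_input, unilm_attention_mask([0] * 6))
enc_joint = masked_self_attention(np.concatenate([x_input, x_output]),
                                  unilm_attention_mask([0] * 6 + [1] * 4))

print(np.allclose(enc_alone, enc_joint[:6]))  # True: the input encodings are unchanged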
Due to this characteristic, random [MASK] tokens can also be added to the input part, allowing the input part to perform the MLM (Masked Language Model) task while the output part performs the Seq2Seq task. MLM enhances NLU capabilities, and Seq2Seq enhances NLG capabilities, achieving two goals at once.
Once you understand UniLM, SimBERT's training method is not hard to understand. SimBERT is trained with supervision; the training corpus consists of collected pairs of similar sentences. The Seq2Seq part is built as a similar sentence generation task, with one sentence of each pair used to predict the other. And as discussed above, the [CLS] vector effectively represents the sentence vector of the input, so it can simultaneously be used to train a retrieval task, as shown below:
Schematic of SimBERT training method
Assume SENT_a and SENT_b are a pair of similar sentences. In the same batch, both "[CLS] SENT_a [SEP] SENT_b [SEP]" and "[CLS] SENT_b [SEP] SENT_a [SEP]" are added to the training to perform a similar sentence generation task; this is the Seq2Seq part.
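In code, the construction of these training samples could look roughly like the sketch below, which uses the bert4keras Tokenizer; the vocabulary path and the helper name encode_pair are placeholders for illustration, not taken from the released code:

from bert4keras.tokenizers import Tokenizer

# Placeholder path: point this at the vocab.txt of the BERT checkpoint.
tokenizer = Tokenizer('vocab.txt', do_lower_case=True)

def encode_pair(sent_a, sent_b):
    """Turn one pair of similar sentences into two Seq2Seq samples,
    a -> b and b -> a; the segment ids mark the part to be generated."""
    samples = []
    for first, second in [(sent_a, sent_b), (sent_b, sent_a)]:
        token_ids, segment_ids = tokenizer.encode(first, second)
        samples.append((token_ids, segment_ids))
    return samples

# The two samples of each pair sit next to each other in the batch,
# which is what the similarity loss sketched below relies on.
batch = encode_pair(u'SENT_a', u'SENT_b')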
On the other hand, the [CLS] vectors of the entire batch are extracted to obtain a sentence-vector matrix \(\boldsymbol{V}\in\mathbb{R}^{b\times d}\) (where \(b\) is the batch_size and \(d\) is the hidden_size). Then \(l_2\) normalization is applied along the \(d\) dimension to obtain \(\tilde{\boldsymbol{V}}\), and pairwise dot products are computed to obtain a \(b\times b\) similarity matrix \(\tilde{\boldsymbol{V}}\tilde{\boldsymbol{V}}^{\top}\). This matrix is multiplied by a scale factor (we used 30) and its diagonal is masked out; finally, a softmax is applied to each row and the result is trained as a classification task, where the target label of each sample is its corresponding similar sentence (the similarity with itself having been masked out). In essence, all non-similar samples within the batch are treated as negative samples, and the softmax increases the similarity of each similar pair while decreasing the similarity to everything else.
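A simplified sketch of this similarity loss with the Keras backend is given below. It assumes the batch is arranged so that samples 2i and 2i+1 form a similar pair (as in the pairing sketch above); the diagonal is masked by subtracting a large constant, and the scale of 30 is the one mentioned in the text:

import tensorflow as tf
from tensorflow.keras import backend as K

def similarity_loss(cls_vectors):
    """cls_vectors: (b, d) tensor of [CLS] vectors for the whole batch.
    Assumes samples 2i and 2i+1 are a pair of similar sentences."""
    # Target labels: each sample's "class" is its paired sample.
    idxs = K.arange(0, K.shape(cls_vectors)[0])
    labels = K.equal(idxs[None, :], (idxs + 1 - idxs % 2 * 2)[:, None])
    labels = K.cast(labels, K.floatx())
    # l2-normalize, pairwise dot products, mask the diagonal, then scale.
    v = K.l2_normalize(cls_vectors, axis=1)
    similarities = K.dot(v, K.transpose(v))
    similarities = similarities - tf.eye(K.shape(cls_vectors)[0]) * 1e12
    similarities = similarities * 30
    # Each row is trained as a softmax classification over the batch.
    return K.mean(K.categorical_crossentropy(labels, similarities, from_logits=True))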
Ultimately, the key point is that "the [CLS] vector effectively represents the sentence vector of the input," so it can be used for NLU-related tasks. The final loss is the sum of the Seq2Seq loss and the similar-sentence classification loss.
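For completeness, the Seq2Seq half of the loss can be sketched as below (again a rough illustration rather than the released code): the targets are the input token ids shifted by one position, and the segment ids conveniently mark which positions belong to the part being generated. The total loss is then just the sum of the two parts.

from tensorflow.keras import backend as K

def seq2seq_loss(token_ids, segment_ids, probas):
    """token_ids: (b, n) inputs; segment_ids: (b, n), 1 on the part to generate;
    probas: (b, n, vocab_size) token probabilities predicted by the model."""
    y_true = token_ids[:, 1:]                         # targets: inputs shifted left by one
    y_mask = K.cast(segment_ids[:, 1:], K.floatx())   # only count the generated part
    y_pred = probas[:, :-1]                           # align predictions with targets
    loss = K.sparse_categorical_crossentropy(y_true, y_pred)
    return K.sum(loss * y_mask) / K.sum(y_mask)

# total_loss = seq2seq_loss(token_ids, segment_ids, probas) + similarity_loss(cls_vectors)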
Since the source code has been released, more training details can be found by reading the code. The model is implemented using keras + bert4keras. The code is very clear, so most doubts can be resolved by reading it.
Effect demonstration (example usage of gen_synonyms):

>>> gen_synonyms(u'Which is better to use, WeChat or Alipay?')
[
    u'Which one is better, WeChat or Alipay?',
    u'Alipay and WeChat, which is easier to use?',
    u'Which one is more useful, WeChat or Alipay?',
    u'Between WeChat and Alipay, which one do you prefer?',
    ...
]
Many readers may be concerned about the training data, so here is a unified answer: the training data cannot be publicly released, and it is not convenient to share it privately either, so please do not ask for the data. The data was scraped from the similar questions recommended by Baidu Zhidao and then filtered with some simple algorithms. If readers have a large collection of questions of their own, they can also use common retrieval algorithms to retrieve some similar sentences to serve as training data. In short, there are no strictly rigid requirements on the training data; in principle, any data with a certain degree of similarity can be used.
Regarding training hardware: the open-sourced model was trained on a single TITAN RTX (22G of VRAM, batch_size=128) for about 4 days. Neither the amount of VRAM nor the training time is a hard requirement; adjust them to your own situation. With less VRAM, simply reduce the batch_size accordingly; and if the corpus itself is not very large, training does not need to take that long (generally, iterating over the dataset a few times is enough).
That's all I can think of for now. If there are any other questions, feel free to leave a comment for discussion.
This article introduced the training principles of the previously released SimBERT model and open-sourced the training code. SimBERT, trained based on UniLM concepts, possesses both retrieval and generation capabilities. Everyone is welcome to use and test it!