UniVAE: A Transformer-based Single-Model Multi-Scale VAE Model

By Su Jianlin (苏剑林) | June 29, 2021

As everyone knows, the $\mathcal{O}(n^2)$ complexity of the Transformer is one of its "fatal flaws." However, every coin has two sides. Because attention explicitly computes interactions between all $n^2$ pairs of positions, it also leaves a lot of room for maneuvering: by customizing the attention mask, we can design Transformer models for various purposes, such as UniLM and K-BERT.

This article introduces a "UniVAE" model for text conceived by the author. It follows an idea similar to UniLM, integrating a VAE into a single Transformer model while also incorporating multi-scale characteristics.

UniAE

Variational Autoencoder (VAE) will not be introduced here from scratch; this site already has several articles on it, and you can search for them yourself. A VAE can be understood as an AE (Autoencoder) with a regularization term. In general, the Encoder is responsible for encoding the input into a vector that satisfies a certain distribution, while the Decoder is responsible for reconstructing the input from that encoding vector. Therefore, to implement UniVAE, we must first implement the corresponding UniAE.

In "From Language Models to Seq2Seq: Transformer is all about the Mask," we introduced UniLM (Unified Language Model), which uses the Attention Mask shown on the left of the figure below to let a single Transformer perform Seq2Seq tasks. However, UniLM is not the UniAE we are looking for, because the Decoder part of UniLM attends to the entire encoded sequence of the input rather than to a single vector.

UniLM-style Attention Mask (Left) vs. UniAE-style Attention Mask (Right)

However, we can further modify UniLM's Attention Mask into the pattern shown on the right of the figure above. With this mask, each decoding step can rely only on the [CLS] vector of the encoding part and on the tokens decoded so far. This is the UniAE-style Attention Mask we are looking for. Since the output depends on the input only through the [CLS] vector, and the [CLS] vector has a fixed size, the only source information available during generation is a fixed-size vector. In other words, the input is compressed into a fixed-size vector and reconstructed from it, which is exactly what an AE does.
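As a rough illustration of the two masks in the figure, the following Python sketch builds them as boolean matrices where `mask[i][j]` is True when position i may attend to position j. It assumes the concatenated sample consists of `n_src` input positions (position 0 being [CLS]) followed by `n_tgt` output positions; the function names are hypothetical, and a real model would build such masks as tensors inside the attention layer rather than as Python lists.

```python
def unilm_mask(n_src, n_tgt):
    """UniLM-style mask: mask[i][j] is True if position i may attend to position j."""
    n = n_src + n_tgt
    return [
        [
            # input positions: bidirectional over the input only;
            # output positions: the whole input plus the decoded prefix
            j < n_src or (i >= n_src and j <= i)
            for j in range(n)
        ]
        for i in range(n)
    ]


def uniae_mask(n_src, n_tgt):
    """UniAE-style mask: the output sees the input only through [CLS] (position 0)."""
    n = n_src + n_tgt
    return [
        [
            (j < n_src) if i < n_src          # encoder part: bidirectional over the input
            else (j == 0 or n_src <= j <= i)  # decoder part: [CLS] plus decoded prefix
            for j in range(n)
        ]
        for i in range(n)
    ]
```

The two functions differ only in the decoder branch: in `uniae_mask` an output position can see the input only at position 0, which is what turns the UniLM setup into an autoencoder.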

UniAE-style Attention Relationship Schematic

Multi-scale

In other words, through the UniAE-style Attention Mask, we can implement a Seq2Seq model similar to UniLM, which is equivalent to the Encoder encoding the input into a fixed-length vector and then the Decoder decoding that vector. If this is still not clear enough, we can analyze it by decomposing it into an Encoder-Decoder architecture, as shown in the figure below:

Decomposed into an Encoder-Decoder structure for understanding

The difference from a conventional Seq2Seq architecture is that the Encoder and Decoder here share weights. The figure above also shows that if this Mask is applied in every Attention layer, the Decoder depends on the [CLS] vector of the input sequence at every layer. This means that with $L$ Attention layers, the complete encoding vector of the input text is the concatenation of the [CLS] vectors of the input sequences of all $L$ layers; the [CLS] vector of any single layer is not the complete encoding. (The first layer can actually be dropped: its [CLS] input is just the [CLS] token embedding, which is the same constant vector for every input.)
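The resulting encoding vector can be sketched as a simple concatenation. This hypothetical helper assumes `layer_inputs[l]` is the sequence fed into the l-th Attention layer, given as a list of d-dimensional vectors with [CLS] at position 0, and drops the first layer by default for the reason just given.

```python
def collect_code(layer_inputs, skip_first=True):
    """Concatenate the [CLS] vector (position 0) of each layer's input sequence.

    layer_inputs[l] is the sequence fed into Attention layer l, as a list of
    d-dimensional vectors. The first layer's [CLS] input is only the [CLS]
    token embedding, identical for every sentence, so it is skipped by default.
    """
    start = 1 if skip_first else 0
    code = []
    for seq in layer_inputs[start:]:
        code.extend(seq[0])  # append this layer's d-dimensional [CLS] vector
    return code
```

With BERT-base and the first layer dropped, such a code would have 11 × 768 = 8448 dimensions.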

For the Decoder, every layer of Attention has a [CLS] vector passed in. This actually forms a multi-scale structure. In Computer Vision (CV), the most advanced generative models are basically multi-scale structures, such as StyleGAN, Glow, NVAE, etc., but this seems rare in NLP. It is not difficult to imagine that in a multi-scale structure, different levels of input variables have different degrees of control over the generation results. Variables closer to the input layer control "inconsequential" parts, while variables closer to the output layer control key information of the generation results. Therefore, ideally, after training a multi-scale model, we can achieve control over different levels of generation results by editing input variables at different levels.

Reducing Dimensionality

Some readers might think: if each layer has a dimension of $d$ and there are $L$ layers, then all [CLS] vectors concatenated together would be $Ld$ dimensions. For BERT-base, that's $12 \times 768 = 9216$ dimensions. Isn't this encoding vector dimension too large? Indeed, for an ordinary AE or VAE, an encoding vector of nearly ten thousand dimensions is too large.

Dimensionality reduction process schematic

Actually, the solution is very simple. For each layer, we first use a fully connected layer to project the [CLS] vector down to a low dimension, then use another fully connected layer to project it back up to $d$ dimensions, and only then concatenate it with the remaining $n-1$ $d$-dimensional vectors of that layer's input sequence, as shown in the figure above. In this way, although each layer's input sequence is still of size $n \times d$ (where $n$ is the sequence length), the [CLS] vector is effectively determined by a much lower-dimensional vector. The total encoding vector is then just the concatenation of these low-dimensional vectors across layers.

Encoder-Decoder schematic after dimensionality reduction
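The per-layer bottleneck described above can be sketched as follows. This is a minimal pure-Python illustration, not the reference implementation: `W_down` (r × d) and `W_up` (d × r) stand in for the two fully connected layers, biases and activations are omitted, and all names are hypothetical.

```python
def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def reduce_and_restore(seq, W_down, W_up):
    """Pass one layer's [CLS] vector through a low-dimensional bottleneck.

    seq: n vectors of dimension d, with [CLS] at position 0.
    W_down: r x d matrix, W_up: d x r matrix, with r much smaller than d.
    Returns (z, new_seq): z is this layer's r-dimensional code, and new_seq
    is still n x d, with [CLS] replaced by its reconstruction from z.
    """
    z = matvec(W_down, seq[0])          # project d -> r
    cls_restored = matvec(W_up, z)      # project r -> d
    return z, [cls_restored] + seq[1:]  # splice back with the other n - 1 vectors
```

The total encoding vector is then the concatenation of the `z` returned for each contributing layer, so its size is (number of contributing layers) × r rather than (number of layers) × d.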

Disentanglement Ability

The design and discussion so far apply to an ordinary AE. Turning an AE into a VAE only requires adding a reparameterization step to the encoding vector and a KL divergence term to the loss function. So, technically, once UniAE is designed, UniVAE is designed as well.
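For concreteness, here is a sketch of those two additions for the standard diagonal-Gaussian VAE. Note the assumption: the reference implementation mentioned later uses a vMF-VAE, whose sampling procedure and KL term differ; the Gaussian version is shown only because it is the most familiar. `mu` and `log_var` are assumed to be produced from the (reduced) [CLS] vectors.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), elementwise."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]


def gaussian_kl(mu, log_var):
    """KL(N(mu, sigma^2) || N(0, I)), summed over dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0
                     for m, lv in zip(mu, log_var))
```

The VAE loss is then the reconstruction loss plus `gaussian_kl(mu, log_var)`; in an autodiff framework, writing the sample as `mu + sigma * eps` is what keeps it differentiable with respect to `mu` and `log_var`.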

In practice, however, there is room for improvement. In theory, a well-trained VAE has a certain degree of disentanglement ability: each dimension of the latent variable is independent and controls one aspect of the generated result, so it can be varied on its own. Disentanglement is clearly a very challenging task. If the VAE's Encoder is to produce a disentangled vector, it must have fairly strong fitting ability; in other words, its structure needs a certain level of complexity.

Looking back at the Encoder of UniAE, its encoding vector is the concatenation of the [CLS] vectors (or the corresponding low-dimensional vectors) from each layer. For the earlier layers, their [CLS] vectors are only the outputs of a few Transformer layers. Their encoding capacity is very weak and insufficient to encode disentangled vectors. Therefore, using them as latent variables for a VAE is inappropriate.

So, when designing UniVAE in practice, we shouldn't use all [CLS] vectors of the UniAE as the encoding vector. We should set a starting layer number; the Decoder only uses [CLS] vectors from layers greater than this number, while [CLS] vectors from layers less than or equal to this number are not used. This corresponds to using the Attention Mask on the right in the figure below:

UniAE-style Mask near the output; Independent Attention Mask near the input

At this point, it is equivalent to the following Encoder-Decoder structure:

Schematic of the effect using Independent Masks in the first two Attention layers
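The per-layer choice can be expressed as a list of masks, one per Attention layer. In this hypothetical sketch the first `k` layers get an independent mask (the input and output segments cannot see each other at all) and the remaining `num_layers - k` layers get the UniAE-style mask; the sequence layout is assumed to be `n_src` input positions with [CLS] at position 0, followed by `n_tgt` output positions.

```python
def independent_mask(n_src, n_tgt):
    """Input and output segments never attend to each other."""
    n = n_src + n_tgt
    return [[(i < n_src and j < n_src) or (n_src <= j <= i)
             for j in range(n)] for i in range(n)]


def uniae_mask(n_src, n_tgt):
    """Output sees the input only through [CLS] at position 0."""
    n = n_src + n_tgt
    return [[(j < n_src) if i < n_src else (j == 0 or n_src <= j <= i)
             for j in range(n)] for i in range(n)]


def layer_masks(num_layers, k, n_src, n_tgt):
    """Independent masks for the first k layers, UniAE-style masks afterwards."""
    return [independent_mask(n_src, n_tgt) if layer < k
            else uniae_mask(n_src, n_tgt)
            for layer in range(num_layers)]
```

With 12 layers and k = 8, only the last 4 layers pass a [CLS] vector to the Decoder, so only those 4 contribute to the latent variable.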

Other Details

So far, the key parts of UniVAE have been introduced. Next, I will share some important details encountered during the implementation process.

First is the issue of length leakage. Whether it's UniLM or UniVAE, because the Encoder and Decoder are integrated into a single model, we concatenate the input and output as a single sample for training. In this way, the starting position of the Decoder part for each sample is different, depending on the length of the input text. This means the input length is also passed as an input condition to the Decoder, which is length leakage.

There are two solutions to this problem. The first is to make all inputs the same length by truncation or padding, so that no length leakage can occur. The second is simpler: do nothing. That is, accept that the length is an input condition, and at decoding time control the generation length by choosing the Decoder's starting position. The possible downside is that the length information may not be fully disentangled from the encoding vector, so combining the same encoding vector with different lengths may produce unreasonable results.
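The first solution amounts to a preprocessing step. In this sketch (the helper names and `pad_id=0` are assumptions), every input segment is cut or padded to `max_src_len`, so the output always starts at the same index. In a real model the padded positions should also be excluded from the encoder's attention, which is not shown here.

```python
def fix_length(token_ids, max_len, pad_id=0):
    """Truncate or right-pad an input segment to exactly max_len tokens."""
    ids = token_ids[:max_len]
    return ids + [pad_id] * (max_len - len(ids))


def build_sample(src_ids, tgt_ids, max_src_len, pad_id=0):
    """Concatenate a fixed-length input with the output sequence.

    Returns (ids, start): the output always starts at index max_src_len,
    whatever the true input length, so the Decoder cannot read it off.
    """
    src = fix_length(src_ids, max_src_len, pad_id)
    return src + tgt_ids, max_src_len
```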

Then there is the issue of choosing the number of layers and dimensions. As mentioned earlier, to give latent variables better disentanglement ability, we add Independent Attention Masks to the first $k$ layers of Attention, and UniAE-style Attention Masks to the remaining $L-k$ layers. How should we choose $k$? This is a hyperparameter that requires careful adjustment. A smaller $k$ preserves more information, which is conducive to reconstruction but detrimental to disentanglement; conversely, a larger $k$ is better for disentanglement but detrimental to reconstruction. In the author's experiments, $k=8$ was used.

A similar issue occurs in choosing dimensions for dimensionality reduction. Larger dimensions are naturally conducive to reconstruction but also detrimental to disentanglement, and vice versa. This parameter needs to be specifically adjusted based on the complexity of the task itself. The general direction for adjustment is to observe random sampling effects and reconstruction effects. If most randomly sampled samples are readable and the reconstruction of natural sentences is good, then the dimension is appropriate; otherwise, it needs to be adjusted accordingly.

Finally, it is worth mentioning that the UniAE design can be used not only for VAEs but also for building VQ-VAE. You only need to perform quantization on each [CLS] vector, and it becomes a VQ-VAE model that encodes variable-length sentences into fixed-length discrete sequences.
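A sketch of the quantization step: each [CLS] vector (or its low-dimensional version) is replaced by the index of its nearest codebook entry under squared Euclidean distance, giving one discrete code per contributing layer. Using one codebook per layer is an assumption here, and the training-side details of VQ-VAE (straight-through gradients, codebook and commitment losses) are omitted.

```python
def quantize(vec, codebook):
    """Return (index, entry) of the codebook entry nearest to vec (squared L2)."""
    best = min(range(len(codebook)),
               key=lambda i: sum((a - b) ** 2 for a, b in zip(vec, codebook[i])))
    return best, codebook[best]


def encode_discrete(cls_vectors, codebooks):
    """Map one [CLS] vector per contributing layer to a discrete index.

    A sentence of any length thus becomes a fixed-length sequence of
    len(cls_vectors) integers.
    """
    return [quantize(v, cb)[0] for v, cb in zip(cls_vectors, codebooks)]
```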

Reference Implementation

Here is a reference implementation of UniVAE:

Github: https://github.com/bojone/univae

The code uses a vMF-VAE variant, implemented based on bert4keras, with RoFormer as the basic architecture (which can also be replaced with BERT). Below is a demonstration of the effects of UniVAE trained on question sentences.

Random sampling effects:

我在steam下载的游戏,怎样能在电脑上玩啊??? (The game I downloaded on Steam, how can I play it on my computer???)

呼市男科医院哪家比较好实惠 (Which andrology hospital in Hohhot is better and affordable?)

我血压高,我妈妈手脚麻木,是怎么回事呀 (I have high blood pressure, and my mother's hands and feet are numb, what is going on?)

怎样查询交通违章记录和处罚 (How to check traffic violation records and penalties?)

为什么我提问的问题有点卡顿 (Why are the questions I asked a bit laggy?)

小米2s用的是移动卡还是联通卡 (Does Xiaomi 2s use Mobile SIM or Unicom SIM?)

幼儿园怎么发展幼儿教育 (How should kindergartens develop early childhood education?)

英国读研学校排名对于英国留学生来说重要吗 (Is the UK graduate school ranking important for students studying in the UK?)

有专业的关于excel表格数据库的培训机构吗? (Are there professional training institutions for Excel sheet databases?)

为什么一到晚上就容易咳嗽,不睡觉就不咳 (Why do I cough easily at night, but don't cough if I don't sleep?)

Reconstruction results:

Original: 数字电视机顶盒坏了,可以免费维修吗 (The digital TV set-top box is broken, can it be repaired for free?)

Reconstruction: 数字电视机顶盒坏了可以换吗? (The digital TV set-top box is broken, can it be replaced?)


Original: 青椒跟什么炒好吃 (What is good to stir-fry with green peppers?)

Reconstruction: 青椒跟什么炒好吃 (What is good to stir-fry with green peppers?)


Original: 王者荣耀carryyou什么意思 (What does carryyou mean in Honor of Kings?)

Reconstruction: 王者荣耀carry芈月什么意思 (What does carry Mi Yue mean in Honor of Kings?)


Original: 没感冒老是咳嗽要吃什么药好 (What medicine is good for a persistent cough without a cold?)

Reconstruction: 没感冒老是咳嗽要吃什么药好 (What medicine is good for a persistent cough without a cold?)


Original: 沁园(金科西城大院店)怎么样,好不好的默认点评 (How is Qinyuan (Jinke Xicheng Dayuan Store), good or bad default review)

Reconstruction: 沁园(金源店)怎么样,好不好的默认点评 (How is Qinyuan (Jinyuan Store), good or bad default review)

Randomly replacing the first 32 dimensions of the latent variables:

Original: 牙龈出血要吃什么药? (What medicine to take for bleeding gums?)

Results: 牙龈出血还出血吃什么消炎药好 (Gums are still bleeding, what anti-inflammatory medicine is good to take?)

牙龈出血吃阿莫西林有效吗 (Is taking Amoxicillin effective for bleeding gums?)

牙龈出血是肝火旺吗? (Is bleeding gums due to excessive "liver fire"?)

牙龈出血去医院检查大概要多少钱? (About how much does it cost to go to the hospital for a check-up for bleeding gums?)

牙龈出血去牙科看什么科室 (What department should I visit at the dentist for bleeding gums?)

牙龈出血去深圳哪里看牙科好 (Where in Shenzhen is a good place for dental services regarding bleeding gums?)


Original: 广州和深圳哪个更好玩? (Which is more fun, Guangzhou or Shenzhen?)

Results: 广州和深圳哪个城市发展得好? 薪资高? (Which city, Guangzhou or Shenzhen, is developed better? Higher salary?)

广州和深圳,哪个发达?深圳到广州的飞机票贵吗? (Guangzhou or Shenzhen, which is more developed? Is the flight ticket from Shenzhen to Guangzhou expensive?)

广州和深圳比哪个好 (Comparing Guangzhou and Shenzhen, which is better?)

广州和深圳哪个人均gdp高 (Which has higher GDP per capita, Guangzhou or Shenzhen?)

广州和深圳房价涨幅 (The house price increase in Guangzhou and Shenzhen)

广州和深圳自考一样吗 (Are self-study exams the same in Guangzhou and Shenzhen?)

Randomly replacing the last 16 dimensions of the latent variables:

Original: 牙龈出血要吃什么药? (What medicine to take for bleeding gums?)

Results: 未来21年做什么生意好? (What business is good for the next 21 years?)

湿疹给身体有什么伤害? (What harm does eczema do to the body?)

朗逸现在要买什么配置? (What configuration should I buy for Lavida now?)

马来西亚签证要多少钱? (How much does a Malaysian visa cost?)

早上给孩子吃什么水果好? (What fruit is good for children in the morning?)

头晕发热去医院看什么科? (What department should I visit at the hospital for dizziness and fever?)


Original: 广州和深圳哪个更好玩? (Which is more fun, Guangzhou or Shenzhen?)

Results: 99和98相差多少呢? (How much is the difference between 99 and 98?)

微信和支付宝怎么更换手机号 (How to change the phone number for WeChat and Alipay?)

我的指甲和肉很不一样怎么回事? (My nails and flesh are very different, what is going on?)

吃了甲硝唑多久才能喝酒? (How long after taking Metronidazole can I drink alcohol?)

桂圆和红枣可以一起泡茶吗? (Can longan and red dates be steeped together for tea?)

小米和华为哪个更好点? (Which is better, Xiaomi or Huawei?)

As can be seen, both random sampling and reconstruction work well. Randomly replacing different parts of the latent variable also roughly reveals the multi-scale structure: replacing the earlier dimensions largely keeps the subject words unchanged, while replacing the later dimensions largely keeps the sentence structure unchanged. Of course, natural language is only weakly structured, so the examples usually contain some exceptions.
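The dimension-replacement experiment above can be mimicked with a small helper that resamples a slice of the latent vector while keeping the rest fixed. This is a hypothetical sketch assuming a standard Gaussian prior; since the reference code uses a vMF latent, its actual resampling step would likely differ (for example, by having to respect a unit-norm constraint).

```python
import random


def replace_dims(z, start, end, rng):
    """Return a copy of z with dimensions [start, end) resampled from N(0, 1).

    Assumes a standard Gaussian prior; all other dimensions are kept fixed.
    """
    z_new = list(z)
    for i in range(start, end):
        z_new[i] = rng.gauss(0.0, 1.0)
    return z_new
```

`replace_dims(z, 0, 32, rng)` corresponds to the first experiment and `replace_dims(z, len(z) - 16, len(z), rng)` to the second; each modified vector would then be fed to the Decoder.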

Article Summary

This article introduced the UniVAE design conceived by the author. It follows an idea similar to UniLM, integrating a VAE into a single Transformer model through specific Attention Masks, while also featuring multi-scale characteristics. Besides regular VAE models, this design can also be applied to models like VQ-VAE.