Hierarchical Decomposition Position Encoding: Enabling BERT to Handle Ultra-Long Text

By 苏剑林 | December 04, 2020

As is well known, mainstream BERT models can handle at most 512 tokens. The root cause of this bottleneck is that BERT uses absolute position encodings trained from random initialization. The maximum position is usually set to 512, so the model can process at most 512 tokens: anything beyond that has no position encoding available. Another important reason is the $\mathcal{O}(n^2)$ complexity of Attention, which makes VRAM usage grow sharply for long sequences, so that fine-tuning becomes infeasible on typical GPUs.

This article primarily addresses the former cause. Assuming one has sufficient VRAM, how can we simply modify a BERT model with an existing maximum length of 512 so that it can directly process longer texts? The main idea is to hierarchically decompose the pre-trained absolute position encodings, allowing them to extend to longer positions.

Position Encoding

BERT uses learned absolute position encodings. This encoding method is simple and effective, but since each position vector is learned by the model itself, we cannot infer encoding vectors for other positions, hence the length limit.

A mainstream approach to solving this is switching to relative position encoding. This is a viable method; for instance, Huawei's NEZHA model is a BERT variant that uses relative position encoding. Relative position encoding generally truncates position differences to keep them within a finite range, thus it is not restricted by sequence length. However, relative position encoding is not a perfect solution. First, methods like NEZHA's increase computational overhead (unlike T5's approach). Second, Linear Attention cannot use relative position encoding, making it less universal.

Readers might recall that "Attention is All You Need" proposed a Sinusoidal absolute position encoding using $\sin$ and $\cos$. Wouldn't using that directly remove the length limit? Theoretically, yes, but the problem is that there aren't many open-sourced models using Sinusoidal position encoding. Do we really have to train a model from scratch? That is clearly unrealistic.

Hierarchical Decomposition

Therefore, under limited resources, the most ideal solution is to find a way to extend the pre-trained BERT position encodings without retraining the model. Below, I present a hierarchical decomposition scheme I devised.

Figure: Schematic of the hierarchical decomposition of position encoding

Specifically, let the pre-trained absolute position encoding vectors be $\boldsymbol{p}_1, \boldsymbol{p}_2, \cdots, \boldsymbol{p}_n$. We hope to construct a new set of encoding vectors $\boldsymbol{q}_1, \boldsymbol{q}_2, \cdots, \boldsymbol{q}_m$ based on these, where $m > n$. To do this, we set:

\begin{equation}\boldsymbol{q}_{(i-1)\times n + j} = \alpha \boldsymbol{u}_i + (1 - \alpha) \boldsymbol{u}_j\label{eq:fenjie}\end{equation}

where $\alpha\in (0, 1)$ and $\alpha\neq 0.5$ is a hyperparameter, and $\boldsymbol{u}_1, \boldsymbol{u}_2, \cdots, \boldsymbol{u}_n$ are the "bases" of this set of position encodings. The meaning of this representation is very clear: the position $(i - 1)\times n + j$ is hierarchically represented as $(i, j)$, and the corresponding position encodings for $i$ and $j$ are $\alpha \boldsymbol{u}_i$ and $(1 - \alpha) \boldsymbol{u}_j$ respectively. The final encoding vector for $(i - 1)\times n + j$ is the superposition of the two. We require $\alpha\neq 0.5$ to distinguish between $(i, j)$ and $(j, i)$.
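To make the indexing concrete, here is a small Python sketch of the mapping between a 1-indexed position and its pair $(i, j)$. The function names are my own, chosen for illustration only.

```python
def to_hierarchical(k: int, n: int = 512) -> tuple:
    """Map a 1-indexed position k to the pair (i, j) with k = (i - 1) * n + j.

    Both i and j are 1-indexed; j runs over 1..n inside each block of n positions.
    """
    if k < 1:
        raise ValueError("positions are 1-indexed")
    block, offset = divmod(k - 1, n)  # 0-indexed block and in-block offset
    return block + 1, offset + 1


def from_hierarchical(i: int, j: int, n: int = 512) -> int:
    """Inverse mapping: (i, j) -> (i - 1) * n + j."""
    return (i - 1) * n + j


for k in (1, 512, 513, 2048):
    print(k, to_hierarchical(k))
```

With $n=512$, positions 1 to 512 all have $i=1$, position 513 starts the second block as $(2, 1)$, and position 2048 is $(4, 512)$.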

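We want the position vectors for positions not exceeding $n$ to stay identical to the original ones, ensuring compatibility with the pre-trained model; that is, $\boldsymbol{q}_1=\boldsymbol{p}_1, \boldsymbol{q}_2=\boldsymbol{p}_2, \cdots, \boldsymbol{q}_n=\boldsymbol{p}_n$. These positions all correspond to $i=1$, so we need $\alpha\boldsymbol{u}_1 + (1-\alpha)\boldsymbol{u}_j = \boldsymbol{p}_j$ for $j=1,\cdots,n$. Taking $j=1$ gives $\boldsymbol{u}_1=\boldsymbol{p}_1$, and substituting this back lets us solve for each basis vector: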

\begin{equation}\boldsymbol{u}_i = \frac{\boldsymbol{p}_i - \alpha\boldsymbol{p}_1}{1 - \alpha},\quad i = 1,2,\cdots,n\end{equation}
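Substituting these bases back into the decomposition gives an equivalent closed form that some readers may find easier to interpret. This rewriting is simply my own algebra on the two equations above:

\begin{equation}\boldsymbol{q}_{(i-1)\times n + j} = \frac{\alpha}{1-\alpha}\left(\boldsymbol{p}_i - \alpha\boldsymbol{p}_1\right) + \left(\boldsymbol{p}_j - \alpha\boldsymbol{p}_1\right) = \boldsymbol{p}_j + \frac{\alpha}{1-\alpha}\left(\boldsymbol{p}_i - \boldsymbol{p}_1\right)\end{equation}

In this form, position $(i-1)\times n + j$ reuses the original vector $\boldsymbol{p}_j$ and adds a block-level offset proportional to $\boldsymbol{p}_i - \boldsymbol{p}_1$, which vanishes for $i=1$. The scale factor $\frac{\alpha}{1-\alpha}$ is below 1 exactly when $\alpha < 0.5$.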

In this way, our parameters are still $\boldsymbol{p}_1, \boldsymbol{p}_2, \cdots, \boldsymbol{p}_n$, but we can represent $n^2$ position encodings, and the first $n$ position encodings are compatible with the original model.
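The construction is easy to check numerically. The NumPy sketch below is my own illustration, not the bert4keras implementation: the function name `hierarchical_position_table` is hypothetical, and a random matrix stands in for BERT's trained position table. Arrays are 0-indexed, so row `k` holds the encoding of position `k + 1`.

```python
import numpy as np


def hierarchical_position_table(p: np.ndarray, m: int, alpha: float = 0.4) -> np.ndarray:
    """Extend an (n, d) table of pre-trained position vectors p to m <= n**2 rows.

    Row k corresponds to the 1-indexed position k + 1 = (i - 1) * n + j and is
    alpha * u_i + (1 - alpha) * u_j, with u_i = (p_i - alpha * p_1) / (1 - alpha).
    """
    n, _ = p.shape
    if not (0.0 < alpha < 1.0) or alpha == 0.5:
        raise ValueError("alpha must lie in (0, 1) and differ from 0.5")
    if m > n * n:
        raise ValueError("at most n**2 positions can be represented")
    # Bases u_1..u_n, stored 0-indexed; u[0] equals p[0].
    u = (p - alpha * p[0:1]) / (1.0 - alpha)
    # 0-indexed block index (i - 1) and in-block index (j - 1) for every row.
    i, j = np.divmod(np.arange(m), n)
    return alpha * u[i] + (1.0 - alpha) * u[j]


rng = np.random.default_rng(0)
p = rng.normal(size=(512, 768))  # random stand-in for a trained 512 x 768 table
q = hierarchical_position_table(p, m=1536, alpha=0.4)
print(q.shape)
print(np.allclose(q[:512], p))  # the first n rows reproduce p
```

Only `p` is a parameter here; every extra row is a fixed linear combination of its rows, which is why the extension adds no new weights.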

Self-Analysis

In fact, once they understand it, readers may feel that this decomposition has little technical depth and is little more than the product of a brainstorm. That is indeed the case.

So why might it work? First, because the hierarchical decomposition is highly interpretable, we can expect the result to extrapolate to some extent; at the very least, it provides a good initialization for positions beyond $n$. Second, the experiments in the next section support it, and experiments are the only real test of whether a trick works. In essence, what we have done is simple: we constructed a position encoding extension that agrees with the first $n$ encodings and extrapolates to more positions, and left the rest to the model to adapt to. There are countless ways to do this. I chose this one because I find it relatively interpretable; it is one possibility, not necessarily the optimal one, and not guaranteed to work.

Next, the choice of $\alpha$. My default is $\alpha=0.4$. In theory any $\alpha\in (0, 1)$ with $\alpha\neq 0.5$ is valid, but in practice I recommend $0 < \alpha < 0.5$. The reason is that we rarely encounter sequences tens of thousands of tokens long; being able to handle 2048 tokens is already quite a "luxury" on a personal GPU. With $n=512$ and a maximum length of 2048, $i$ only takes the values $1, 2, 3, 4$, while $j$ ranges over $1,2,\cdots,512$. If $\alpha > 0.5$, then by equation $\eqref{eq:fenjie}$ the term $\alpha \boldsymbol{u}_i$ dominates, so the position encodings differ less from one another (there are only 4 candidates for $i$). This makes positions harder for the model to distinguish and slows convergence. If $\alpha < 0.5$, the term $(1-\alpha) \boldsymbol{u}_j$ dominates, giving better position discrimination (512 candidates for $j$) and helping the model converge faster.
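A rough numerical illustration of this argument, under a strong assumption: real BERT position vectors are not i.i.d. Gaussian, but here a random matrix stands in for them. The sketch (the function name `term_norm_ratio` is hypothetical) compares the average norm of the block term $\alpha \boldsymbol{u}_i$ with that of the in-block term $(1-\alpha) \boldsymbol{u}_j$ over positions 1 to 2048 with $n=512$.

```python
import numpy as np


def term_norm_ratio(p: np.ndarray, max_len: int, alpha: float) -> float:
    """Mean norm of alpha*u_i divided by mean norm of (1-alpha)*u_j over positions 1..max_len."""
    n = p.shape[0]
    u = (p - alpha * p[0:1]) / (1.0 - alpha)  # bases u_1..u_n, 0-indexed
    i, j = np.divmod(np.arange(max_len), n)    # 0-indexed (i - 1, j - 1) per position
    coarse = np.linalg.norm(alpha * u[i], axis=1).mean()
    fine = np.linalg.norm((1.0 - alpha) * u[j], axis=1).mean()
    return float(coarse / fine)


rng = np.random.default_rng(0)
p = rng.normal(size=(512, 768))  # random stand-in, NOT real BERT position vectors
for alpha in (0.4, 0.6):
    print(alpha, round(term_norm_ratio(p, 2048, alpha), 2))
```

With such random stand-ins the ratio comes out below 1 for $\alpha=0.4$ (the fine-grained $j$ term dominates) and above 1 for $\alpha=0.6$ (the coarse $i$ term dominates). On actual pre-trained vectors the exact numbers will differ.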

Practical Testing

In summary, at almost no cost we can extend BERT's absolute position encoding so that its maximum length reaches $n^2=512^2=262,144 \approx 260,000$ tokens, which should be more than enough for practical needs. This modification has been integrated into bert4keras >= 0.9.5: just pass `hierarchical_position=True` to `build_transformer_model` to enable it. `True` can also be replaced with a float between 0 and 1, which sets the value of $\alpha$ above (`True` means the default of 0.4).

To check effectiveness, I first tested the MLM task. I set the maximum length directly to 1536 and loaded the pre-trained RoBERTa weights. The MLM accuracy came out at around 38% (about 55% when truncated to 512). After fine-tuning on MLM, the accuracy quickly recovered to over 55%, within about 3000 steps. This indicates that position encodings extended this way work for the MLM task. If you have spare computing power, it is best to continue MLM pre-training for a while before moving on to other tasks. I also tried different values of $\alpha$; as the figure below shows, $\alpha=0.4$ is indeed a good default.

Figure: MLM training accuracy under different values of $\alpha$

I then tested two long-text classification tasks, fine-tuning directly (without first continuing MLM pre-training) with the length set to 512 and to 1024, all other parameters unchanged. On one dataset there was no obvious difference; on the other, the 1024-length model scored about 0.5% higher on the validation set than the 512-length model. This again suggests that the hierarchically decomposed position encoding proposed here works. So if your GPU has enough VRAM, give it a try, especially on long-text sequence labeling tasks, for which it seems well suited. In bert4keras it takes just one extra line of code: if it helps, you win; if not, little effort is wasted. Everyone is welcome to share their results.

Finally, I provide a reference table for maximum sequence length and maximum batch_size during the training stage (RoBERTa Base version, 24G TITAN RTX):

\[\begin{array}{c|c} \hline \text{Sequence Length} & \text{batch\_size}\\ \hline 512 & 22\\ 1024 & 9\\ 1536 & 5\\ \hline \end{array}\]

From this table we can see that when the sequence length doubles, VRAM consumption roughly doubles as well (somewhat more), which does not look like the legendary $\mathcal{O}(n^2)$ complexity. In fact, $\mathcal{O}(n^2)$ only takes over for "long enough" sequences, on the order of thousands or tens of thousands of tokens. For sequences of up to 2048 tokens, BERT's cost still grows nearly linearly. Therefore, in these scenarios, "BERT + extended position encoding" is much more convenient than designs such as "split into sentences + BERT + LSTM".
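The scaling claim can be checked with a little arithmetic on the table. The sketch below (the helper `scaling_rows` is my own, hypothetical name) uses the inverse of the maximum batch size as a rough proxy for per-sample memory; this is an assumption, since batch sizes are whole numbers and the proxy ignores fixed overheads such as the model weights.

```python
def scaling_rows(max_batch, base_len=512):
    """Return (seq_len, tokens_per_batch, per_sample_ratio, linear_ratio, quadratic_ratio).

    per_sample_ratio approximates memory per sample relative to base_len as
    base_batch / batch; linear and quadratic ratios are what O(n) and O(n^2)
    scaling would predict for the same length increase.
    """
    base_bs = max_batch[base_len]
    rows = []
    for seq_len, bs in sorted(max_batch.items()):
        scale = seq_len / base_len
        rows.append((seq_len, seq_len * bs, base_bs / bs, scale, scale ** 2))
    return rows


# Max batch sizes from the table (RoBERTa Base, 24G TITAN RTX).
for row in scaling_rows({512: 22, 1024: 9, 1536: 5}):
    print("len={} tokens/batch={} per-sample={:.2f}x linear={:.0f}x quadratic={:.0f}x".format(*row))
```

By this proxy, per-sample memory grows about 2.44x at length 1024 and 4.4x at 1536, against 2x and 3x for purely linear scaling and 4x and 9x for purely quadratic scaling.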

Summary

In this article, I shared a hierarchical decomposition scheme for extending position encodings. With this extension, BERT can in theory handle texts of up to about 260,000 tokens. As long as VRAM is sufficient, there is no long text that BERT cannot handle.

So, are you ready with your VRAM?