NBCE: Using Naive Bayes to Extend LLM Context Handling Length

By 苏剑林 | May 23, 2023

Playing with Naive Bayes in the era of LLMs?

This might be the first thought many readers have on seeing the title. Indeed, when ancient Naive Bayes meets cutting-edge LLMs, the result is surprising: we can directly extend the context length that existing LLMs can handle, without fine-tuning, regardless of model architecture, at linear cost, and with results that look quite good. This is the NBCE (Naive Bayes-based Context Extension) method proposed in this article.

Crossing the River by Feeling the Stones

Assume $T$ is the token sequence to be generated, and $S_1, S_2, \cdots, S_n$ are several given Contexts that are relatively independent of one another (e.g., $n$ different paragraphs; at the very least, not a single sentence split into two fragments). Suppose their total length exceeds the training length, while any single $S_k$ plus $T$ still fits within the training length. We need to generate $T$ based on $S_1, S_2, \cdots, S_n$, that is, estimate $p(T|S_1, S_2, \cdots, S_n)$.

Simply put, Naive Bayes is "Bayes' theorem + independence assumption." According to Bayes' theorem:

\begin{equation}p(T|S_1, S_2,\cdots,S_n) \propto p(S_1, S_2,\cdots,S_n|T)p(T)\end{equation}

Here $\propto$ indicates that constant factors unrelated to $T$ are omitted. According to the (conditional) independence assumption:

\begin{equation}p(S_1, S_2,\cdots,S_n|T) = \prod_{k=1}^n p(S_k|T)\end{equation}

So we have

\begin{equation}p(T|S_1, S_2,\cdots,S_n) \propto p(T)\prod_{k=1}^n p(S_k|T)\end{equation}

Applying Bayes' theorem again, $p(S_k|T) \propto \frac{p(T|S_k)}{p(T)}$, we get

\begin{equation}p(T|S_1, S_2,\cdots,S_n) \propto \frac{1}{p^{n-1}(T)}\prod_{k=1}^n p(T|S_k)\end{equation}
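The intermediate step, written out: Bayes' theorem gives $p(S_k|T) = p(T|S_k)p(S_k)/p(T)$, and substituting this into $p(T)\prod_{k=1}^n p(S_k|T)$ yields

\begin{equation}p(T)\prod_{k=1}^n p(S_k|T) = p(T)\prod_{k=1}^n \frac{p(T|S_k)\,p(S_k)}{p(T)} = \frac{\prod_{k=1}^n p(S_k)}{p^{n-1}(T)}\prod_{k=1}^n p(T|S_k)\end{equation}

The factor $\prod_{k=1}^n p(S_k)$ does not depend on $T$, so it is absorbed into $\propto$.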

Or in log space:

\begin{equation}\log p(T|S_1, S_2,\cdots,S_n) = \color{red}{\sum_{k=1}^n \log p(T|S_k)} - \color{green}{(n-1)\log p(T)} + \color{skyblue}{\text{constant}}\label{eq:nbce-1}\end{equation}

Here $\color{red}{p(T|S_k)}$ and $\color{green}{p(T)}$ can both be computed directly with an existing LLM: any language model works, regardless of architecture, and no long-text fine-tuning is needed. Specifically, $\color{red}{p(T|S_k)}$ is the prediction given a single Context, and $\color{green}{p(T)}$ is the prediction given no Context (or an empty Context). Multiple Contexts can be placed in the same batch and computed in parallel, so the computational cost grows linearly with the number of Contexts.
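A minimal sketch of equation $\eqref{eq:nbce-1}$ for a single decoding step, in plain Python. It assumes, as is natural for autoregressive generation, that the formula is applied to one next token at a time, with the log-probabilities over the vocabulary already computed by the LLM (one row per Context plus one empty-Context row). The function names are illustrative and not taken from the reference implementation; renormalizing at the end takes care of the omitted constant.

```python
import math


def log_softmax(scores):
    """Turn arbitrary scores into normalized log-probabilities (numerically stable)."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def naive_bayes_combine(ctx_logprobs, uncond_logprobs):
    """Naive Bayes combination for one next-token step.

    ctx_logprobs: n lists; the k-th holds log p(token | S_k) for every vocabulary token.
    uncond_logprobs: log p(token) for every vocabulary token, with an empty Context.
    """
    n = len(ctx_logprobs)
    scores = [
        sum(lp[v] for lp in ctx_logprobs) - (n - 1) * uncond_logprobs[v]
        for v in range(len(uncond_logprobs))
    ]
    # The "constant" in the equation is recovered by renormalizing over the vocabulary.
    return log_softmax(scores)


# Toy example: vocabulary of 3 tokens, two Contexts, uniform empty-Context prediction.
ctx = [[math.log(p) for p in (0.7, 0.2, 0.1)],
       [math.log(p) for p in (0.6, 0.3, 0.1)]]
uncond = [math.log(1 / 3)] * 3
print([round(math.exp(x), 4) for x in naive_bayes_combine(ctx, uncond)])
# [0.8571, 0.1224, 0.0204]
```

With a uniform $p(T)$ the subtraction shifts every score equally, so the result is just the normalized product $0.7\cdot0.6 : 0.2\cdot0.3 : 0.1\cdot0.1$.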

Peeling the Onion

Of course, Naive Bayes relies on the independence assumption, which limits its actual performance. To "surpass the master," we can further "peel the onion" of equation $\eqref{eq:nbce-1}$ to achieve better results.

First, let's denote $\log p(T|S) = [\log p(T|S_1), \cdots, \log p(T|S_n)]$, and

\begin{equation}\overline{\log p(T|S)} = \frac{1}{n}\sum_{k=1}^n \log p(T|S_k)\end{equation}

Setting $\beta = n - 1$, equation $\eqref{eq:nbce-1}$ can be rewritten as

\begin{equation}\log p(T|S_1, S_2,\cdots,S_n) = \color{red}{(\beta + 1)\overline{\log p(T|S)}} - \color{green}{\beta\log p(T)} + \color{skyblue}{\text{constant}}\label{eq:nbce-2}\end{equation}

Rewriting it in this form naturally leads to two questions:

1. If $\beta$ is treated as a hyperparameter to be tuned, is it possible to achieve better effects?

2. $\overline{\log p(T|S)}$ is just the Average Pooling of $\log p(T|S)$. Would changing it to other Pooling methods (denoted as $\mathcal{P}$) yield better results? i.e., \begin{equation}\log p(T|S_1, S_2,\cdots,S_n) = \color{red}{(\beta + 1)\mathcal{P}[\log p(T|S)]} - \color{green}{\beta\log p(T)} + \color{skyblue}{\text{constant}}\label{eq:nbce-3}\end{equation}

I then experimented with both questions on a 7B model. The preliminary conclusion: in reading comprehension scenarios, Max Pooling combined with $\beta=0.25$ generally performed well under Greedy Search, but the output of Random Sampling was essentially unreadable.
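The generalized equation $\eqref{eq:nbce-3}$, with a tunable $\beta$ and a pluggable Pooling $\mathcal{P}$, can be sketched the same way. This is again an illustrative single-step sketch on precomputed next-token log-probabilities, not the reference code; `average_pool`, `max_pool`, and `pooled_combine` are hypothetical names, and Max Pooling is assumed to be taken elementwise over the vocabulary.

```python
import math


def log_softmax(scores):
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def average_pool(ctx_logprobs):
    """Elementwise mean over Contexts: the Naive Bayes choice."""
    n = len(ctx_logprobs)
    return [sum(col) / n for col in zip(*ctx_logprobs)]


def max_pool(ctx_logprobs):
    """Elementwise max over Contexts."""
    return [max(col) for col in zip(*ctx_logprobs)]


def pooled_combine(ctx_logprobs, uncond_logprobs, beta, pool):
    """(beta + 1) * P[log p(T|S)] - beta * log p(T), renormalized over the vocabulary."""
    pooled = pool(ctx_logprobs)
    scores = [(beta + 1) * c - beta * u for c, u in zip(pooled, uncond_logprobs)]
    return log_softmax(scores)


def greedy(logprobs):
    """Index of the highest-probability token."""
    return max(range(len(logprobs)), key=logprobs.__getitem__)


ctx = [[math.log(p) for p in (0.7, 0.2, 0.1)],
       [math.log(p) for p in (0.6, 0.3, 0.1)]]
uncond = [math.log(1 / 3)] * 3

# Average Pooling with beta = n - 1 reproduces the plain Naive Bayes formula.
nb = pooled_combine(ctx, uncond, beta=len(ctx) - 1, pool=average_pool)
# The setting reported for reading comprehension: Max Pooling, beta = 0.25, Greedy Search.
print(greedy(pooled_combine(ctx, uncond, beta=0.25, pool=max_pool)))  # 0
```

Passing `beta=len(ctx) - 1` with `average_pool` is exactly equation $\eqref{eq:nbce-2}$; any other `beta` or `pool` is the generalization the two questions ask about.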

Final Solution

Why does Greedy Search work well while Random Sampling fails? Random Sampling draws tokens according to the distribution, so its poor performance indicates that Max Pooling does not produce a reasonable distribution. Greedy Search only cares about the highest-probability token, not whether the distribution as a whole is reasonable, so its success tells us that the highest-probability token is usually accurate.

Higher probability indicates lower uncertainty. Therefore, to improve the performance of Random Sampling, we change the Pooling method to directly output the distribution with the lowest uncertainty:

\begin{equation}\begin{aligned} &\mathcal{P}[\log p(T|S)] = \log p(T|S_{\color{red}{k}}) \\[5pt] &\color{red}{k} = \mathop{\text{argmin}} \big\{H_1,H_2,\cdots,H_n\big\} \\[5pt] &H_i = -\sum_T p(T|S_i)\log p(T|S_i) \end{aligned}\end{equation}

Substituting this into equation $\eqref{eq:nbce-3}$ gives the final NBCE (Naive Bayes-based Context Extension).
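A sketch of the final NBCE step for one next token, combining the minimum-entropy Pooling with equation $\eqref{eq:nbce-3}$. It works on precomputed log-probabilities as before; the function names are hypothetical, and the $\beta=0.25$ in the example is only borrowed from the Max Pooling experiment for illustration, not a recommended value for this Pooling.

```python
import math


def log_softmax(scores):
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def entropy(logprobs):
    """H = -sum_T p(T) log p(T); tokens with zero probability contribute nothing."""
    return -sum(math.exp(lp) * lp for lp in logprobs if lp != -math.inf)


def min_entropy_pool(ctx_logprobs):
    """Return the log-probs of the Context whose prediction is least uncertain."""
    k = min(range(len(ctx_logprobs)), key=lambda i: entropy(ctx_logprobs[i]))
    return ctx_logprobs[k]


def nbce_step(ctx_logprobs, uncond_logprobs, beta):
    """(beta + 1) * log p(T|S_k) - beta * log p(T), with k = argmin of the entropies."""
    pooled = min_entropy_pool(ctx_logprobs)
    scores = [(beta + 1) * c - beta * u for c, u in zip(pooled, uncond_logprobs)]
    return log_softmax(scores)


ctx = [[math.log(p) for p in (0.7, 0.2, 0.1)],   # H is about 0.80 nats
       [math.log(p) for p in (0.4, 0.4, 0.2)]]   # H is about 1.05 nats
uncond = [math.log(1 / 3)] * 3
probs = [math.exp(x) for x in nbce_step(ctx, uncond, beta=0.25)]
```

Unlike elementwise Max Pooling, the pooled vector here is one Context's actual predicted distribution, which is what motivates the choice.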

It is worth pointing out that although our starting point was Naive Bayes, the generalized equation $\eqref{eq:nbce-3}$ already exceeds the scope of conventional Naive Bayes, while retaining its interpretability. It is not hard to see that the form of equation $\eqref{eq:nbce-3}$ is very intuitive:

1. Prediction results from different Contexts are aggregated (or voted) together through method $\mathcal{P}$ (with weight $\beta+1$), and the result without Context is subtracted (with weight $\beta$);

2. Subtracting the no-Context prediction makes the model more inclined to draw on the Context rather than answer purely from its own knowledge (note: "Trusting Your Evidence: Hallucinate Less with Context-aware Decoding", a paper posted on arXiv three days later, proposed the same trick to reduce hallucinations);

3. Different $\beta$ can be chosen for different scenarios. For example, for reading comprehension requiring context integration, a larger $\beta$ can be considered; for creative writing, a smaller $\beta$ can be chosen. I believe $\beta \geq -1$ is reasonable.
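A toy illustration of how $\beta$ trades the Context off against the model's own prior, using invented two-token numbers: the pooled Context prediction is 50/50, while the no-Context prediction strongly prefers the first token. `contrast` is a hypothetical helper implementing the red-minus-green combination.

```python
import math


def log_softmax(scores):
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def contrast(pooled_logprobs, uncond_logprobs, beta):
    """(beta + 1) * pooled - beta * unconditional, renormalized."""
    return log_softmax([(beta + 1) * c - beta * u
                        for c, u in zip(pooled_logprobs, uncond_logprobs)])


pooled = [math.log(0.5), math.log(0.5)]   # what the Context suggests
uncond = [math.log(0.9), math.log(0.1)]   # what the model believes on its own

for beta in (-1.0, 0.0, 1.0):
    probs = [round(math.exp(x), 3) for x in contrast(pooled, uncond, beta)]
    print(beta, probs)
# -1.0 [0.9, 0.1]
# 0.0 [0.5, 0.5]
# 1.0 [0.1, 0.9]
```

At $\beta=-1$ the Contexts drop out entirely and the result is just $p(T)$; at $\beta=0$ it is the pooled Context prediction; larger $\beta$ pushes further away from what the model would say without Context. This is one way to read why $\beta \geq -1$ is the sensible range.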

Reference Implementation

Below is a reference implementation of NBCE:

Github: https://github.com/bojone/NBCE

As the demo code shows, NBCE is very simple to implement: it only changes how the logits are constructed in the decoding function, so it is compatible with any choice of decoding algorithm.
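To illustrate that only the logits construction changes, here is a self-contained toy decoding loop. `toy_lm` is a hypothetical stand-in for an LLM forward pass (it simply favors tokens that appear in its Context and ends after two tokens), and the NBCE step uses the minimum-entropy Pooling described above with an illustrative $\beta=0.25$. None of this is the code in the repository. The same `decode` loop accepts either a greedy or a sampling `pick` function.

```python
import math
import random

VOCAB = ["a", "b", "<eos>"]


def log_softmax(scores):
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def toy_lm(context, prefix):
    """Hypothetical LLM call: log p(next token | context + prefix) over VOCAB."""
    scores = [float(context.count(tok)) for tok in VOCAB]
    if len(prefix) >= 2:
        scores[VOCAB.index("<eos>")] += 10.0  # make the toy model stop after two tokens
    return log_softmax(scores)


def nbce_logprobs(contexts, prefix, beta):
    """The only NBCE-specific part: how the next-token log-probs are built."""
    per_ctx = [toy_lm(c, prefix) for c in contexts]  # one batch row per Context
    uncond = toy_lm([], prefix)                       # the empty-Context row
    entropies = [-sum(math.exp(x) * x for x in lp) for lp in per_ctx]
    best = per_ctx[entropies.index(min(entropies))]
    return log_softmax([(beta + 1) * c - beta * u for c, u in zip(best, uncond)])


def greedy(logprobs):
    return max(range(len(logprobs)), key=logprobs.__getitem__)


def make_sampler(seed):
    rng = random.Random(seed)
    return lambda logprobs: rng.choices(range(len(logprobs)),
                                        weights=[math.exp(x) for x in logprobs])[0]


def decode(next_logprobs, pick, max_len=10):
    """Ordinary decoding loop; it does not know whether NBCE is used."""
    prefix = []
    while len(prefix) < max_len:
        tok = VOCAB[pick(next_logprobs(prefix))]
        if tok == "<eos>":
            break
        prefix.append(tok)
    return prefix


contexts = [["a", "a", "b"], ["b"]]


def step(prefix):
    return nbce_logprobs(contexts, prefix, beta=0.25)


print(decode(step, greedy))           # ['a', 'a']
print(decode(step, make_sampler(0)))  # depends on the random draws
```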

[Figure: Naive Bayes-based Context Extension (NBCE) diagram]

The provided demo feeds 12 different Context segments, totaling over 9,000 characters, into the model together with 8 questions at once (the model has 7B parameters and a training length of 2048, and is available for download from OpenBuddy). The model correctly answers all 8 questions, one by one, based on the given Contexts. Note that the Contexts, questions, and answers together exceed 10,000 characters! In addition, some friends have briefly tried applications such as resume matching and essay scoring, with acceptable results. I strongly encourage you to try it yourself.

Related Work

There are already several methods for extending the Context length of LLMs, but most of them shorten a sample's long context through retrieval or summarization, as Unlimiformer does. Because they do not process the long context directly, they usually cannot perform fine-grained reading comprehension, and they often have to be built in at the training stage rather than plugged into an existing LLM afterwards.

Before NBCE, the existing approach that could extend the Context length without fine-tuning was Parallel Context Window (hereafter PCW), from the papers "Parallel Context Windows for Large Language Models" and "Structured Prompting: Scaling In-Context Learning to 1,000 Examples". The two papers appeared around the same time from different authors and propose methods with only subtle differences, so both are referred to here as PCW.

PCW is applicable to Self-Attention models. The main modifications include Position Encoding and Attention Mask, as shown in the figure below:

[Figure: Parallel Context Window]

First, the maximum Context length $L$ is determined (6 in the figure). Then the last position of each Context is encoded as $L-1$, the second to last as $L-2$, and so on; this encoding is called "right-aligned" (or "left-indented"). The Task Tokens (Prompt + generated content) are encoded as $L, L+1, L+2, \cdots$. Because each Context is encoded separately, the corresponding Attention Mask is block-diagonal, and since this is a Language Model (LM), each block is lower-triangular. The Task Tokens need to integrate all Contexts, so they attend to every Context (and, causally, to themselves). As a result, if any single Context is taken out and concatenated with the Task Tokens, its Attention pattern is consistent with that of the original LM.
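A minimal sketch of these position ids and this Attention Mask, assuming the Contexts are laid out one after another and followed by the Task Tokens in a single flat sequence. `pcw_layout` is a hypothetical helper, not code from either paper; `mask[q][k]` is `True` when token `q` may attend to token `k`.

```python
def pcw_layout(context_lengths, num_task_tokens):
    """Right-aligned position ids and the PCW attention mask for one flat sequence."""
    L = max(context_lengths)  # maximum Context length
    positions, segment = [], []  # segment: Context index, or -1 for Task Tokens
    for c, n in enumerate(context_lengths):
        positions += range(L - n, L)  # last token of every Context gets L - 1
        segment += [c] * n
    positions += range(L, L + num_task_tokens)  # Task Tokens continue at L, L+1, ...
    segment += [-1] * num_task_tokens

    total = len(positions)
    mask = [[False] * total for _ in range(total)]
    for q in range(total):
        for k in range(q + 1):  # causal (lower-triangular)
            # Context tokens see only their own Context (block-diagonal);
            # Task Tokens see every Context and the earlier Task Tokens.
            if segment[q] == -1 or segment[k] == segment[q]:
                mask[q][k] = True
    return positions, mask


positions, mask = pcw_layout([4, 6], num_task_tokens=2)
print(positions)  # [2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 7]
```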

Some readers may notice that NBCE and PCW share very similar characteristics: both ignore the order of the Contexts and treat all Contexts equally. In fact, applying NBCE to a single-layer, single-head attention model yields roughly PCW. To show this, write a single-layer, single-head attention language model as:

\begin{equation}p(x_t|x_{< t}) = \text{softmax}\left(\sum_{i=1}^t a_{t,i}v_i W\right)\end{equation}

So, ignoring the normalizer of the softmax, roughly $\log p(x_t|x_{< t}) \sim \sum\limits_{i=1}^t a_{t,i}v_i W$. Substituting this into equation $\eqref{eq:nbce-2}$ and setting $\beta=0$, we get:

\begin{equation}\log p(T|S_1, S_2,\cdots,S_n) \sim \frac{1}{n}\sum_{k=1}^n\left(\sum_{i\in S_k} a_{T,i}v_i\right) W = \left(\sum_{i\in S_1\oplus\cdots\oplus S_n} \frac{a_{T,i}}{n}v_i\right) W \end{equation}

This assumes $T$ is a single token, without loss of generality; $\oplus$ denotes concatenation. In the formula above, each $S_k \oplus T$ is processed as one contiguous segment (the NBCE setup), so their position encodings are adjacent, and the weights $a_{T,i}/n$ form a single attention distribution of $T$ over all the $S_k$ (they still sum to 1). These are exactly the characteristics of PCW; PCW simply builds this into every layer, elegantly, through the Attention Mask.
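A quick numerical check of the identity above, under the same simplification: a single attention head whose query (the token $T$) attends only to the Context tokens, with made-up keys, values, and output matrix $W$. Averaging the per-Context readouts and doing one attention over the concatenation with weights $a_{T,i}/n$ give the same logits, and those weights sum to 1. All names and numbers are illustrative.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def matvec(vec, W):
    """Row vector times matrix (vec @ W)."""
    return [sum(vec[i] * W[i][j] for i in range(len(vec))) for j in range(len(W[0]))]


def weights_for(query, keys):
    """a_{T,i}: dot-product attention of the query over one Context's keys."""
    return softmax([sum(q * k for q, k in zip(query, key)) for key in keys])


def averaged_logits(query, contexts, W):
    """Left side: (1/n) sum_k (sum_{i in S_k} a_{T,i} v_i) W, i.e. NBCE with beta = 0."""
    n, dim = len(contexts), len(W)
    total = [0.0] * dim
    for keys, values in contexts:
        for a, v in zip(weights_for(query, keys), values):
            total = [t + a * x / n for t, x in zip(total, v)]
    return matvec(total, W)


def concatenated_logits(query, contexts, W):
    """Right side: one attention over S_1 + ... + S_n with weights a_{T,i} / n."""
    n = len(contexts)
    weights, values = [], []
    for keys, vals in contexts:
        weights += [a / n for a in weights_for(query, keys)]
        values += vals
    readout = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(W))]
    return matvec(readout, W), weights


query = [1.0, 0.5]
contexts = [([[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]),  # S_1: keys, values
            ([[0.5, 0.5]], [[-1.0, 0.0]])]                          # S_2: keys, values
W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # value dim 2 -> vocabulary of 3
lhs = averaged_logits(query, contexts, W)
rhs, weights = concatenated_logits(query, contexts, W)
```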

Therefore, PCW is roughly the Average Pooling version of NBCE. In our tests it also shares the same weakness as the Average Pooling version of NBCE: as the number of Contexts grows, the output becomes inaccurate, staying on topic but giving the wrong answer to the question.

Further Thoughts

A major disadvantage of NBCE is its lack of order; it cannot recognize the input order of the Context, which may lead to poor performance in scenarios like continuing a story. To alleviate this, one might consider adding a prefix that indicates order information before each Context, analogous to "Chapter 1" and "Chapter 2" in a novel.
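A tiny sketch of that workaround, assuming plain-text Contexts; `add_order_prefixes` is a hypothetical helper and the "Chapter {i}" template is only an example marker. Each prefixed string would then be passed to the model as its own $S_k$.

```python
def add_order_prefixes(contexts, template="Chapter {i}\n"):
    """Prepend an ordinal marker to each Context so its position is visible inside it."""
    return [template.format(i=i) + text for i, text in enumerate(contexts, start=1)]


print(add_order_prefixes(["The storm began.", "The ship sank."]))
```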

Overall, my tests of NBCE so far are limited to "reading comprehension" scenarios, i.e., "understanding" long texts. Whether this method can also be used to "generate" long texts is still unknown, and I look forward to everyone's test results.

Additionally, an interesting question is:

Since Naive Bayes can be useful in the LLM field, can other traditional probabilistic models (like HMM) also find a place in the LLM field?

Summary

This article proposes NBCE (Naive Bayes-based Context Extension). It is based on the Naive Bayes idea to extend the Context handling length of LLMs. It has the advantages of being plug-and-play, model-agnostic, requiring no fine-tuning, linear efficiency, and simple implementation, and the results appear promising. Everyone is welcome to test it.


Reprint address: https://kexue.fm/archives/9617


Citation: 苏剑林, "NBCE: Using Naive Bayes to Extend LLM Context Handling Length", May 2023, https://kexue.fm/archives/9617