Revisiting Random Tokenization: From Viterbi Sampling to Perfect Sampling Algorithms

By 苏剑林 | October 16, 2023

In the article "An Exploration of Random Tokenization: From Viterbi Decoding to Viterbi Sampling," I proposed a random tokenization algorithm called "Viterbi Sampling." It is a minor modification of the Viterbi Decoding algorithm, which seeks the optimal solution. It retains the simple and fast characteristics of the Viterbi algorithm and is significantly more efficient than existing methods like Subword Regularization. However, a reader on Zhihu, @鶴舞, pointed out that the current sampling algorithm might "dilute" the occurrence probability of certain schemes through multiple rounds of 1-on-1 "challenges." The direct consequence is that the segmentation with the highest original score might not appear with the highest probability.

After careful consideration, I realized that this issue does indeed exist. At the time, in my eagerness to derive a new sampling algorithm, I handled the details somewhat carelessly. This article therefore refines the Viterbi Sampling algorithm and proves that the refined algorithm is equivalent in effect to Subword Regularization.

Problem Analysis

First, let's look at the original comment:

In subword regularization, it can be guaranteed that the data follows a specific probability (possessing a temperature hyperparameter). In the proposed method, for each $e$, the first calculated route is subjected to multiple 1v1 "challenges." Will the final probability distribution end up being quite different from existing algorithms? For example, if there are three ways to segment "watching": "watch ing", "wat ching", and "w atching", and their probabilities are all one-third, in the proposed scheme, the sampling probabilities for the first two would become one-fourth, and the third would become one-half. Is that correct?

The comment is already quite clear, but let me expand on it slightly. Suppose there are three segmentation schemes with identical scores. We would naturally expect each of them to appear with probability $1/3$ during sampling. However, Viterbi Sampling turns the multi-way choice into a sequence of binary choices:

\begin{equation} r_i = \left\{\begin{aligned}&\,1\,, \,\, s_i > s_{i-1} \\ &\,0\,, \,\, \text{else}\end{aligned}\right.\qquad\longrightarrow\qquad r_i = \left\{\begin{aligned}&\,1\,, \,\, \varepsilon < \sigma(\alpha(s_i - s_{i-1})) \\ &\,0\,, \,\, \text{else}\end{aligned}\right. \end{equation}

In this way, the first two segmentation schemes first compete in a binary choice, each winning with probability $\frac{1/3}{1/3+1/3}=1/2$. The winner then competes with the third scheme in another binary choice. Because this second choice again compares only the two schemes' own scores, each still wins with probability $1/2$. Over the complete sampling process, therefore, each of the first two schemes appears with probability $1/4$, while the last one appears with probability $1/2$. Schemes that arrive later "benefit," while the probabilities of earlier schemes are diluted more severely. Unfortunately, in the order in which BytePiece's AC automaton returns matches, longer words (which usually have higher scores) tend to come first. As a result, in the old Viterbi Sampling, the higher-scoring schemes are precisely the ones most likely to have their probabilities diluted.

The Solution

In hindsight, the solution is quite simple: after each binary choice, we also cache the cumulative probability. From the second step on, the new candidate competes not against the current winner's own score but against the cumulative score of all candidates seen so far. This is the well-known "reservoir sampling" algorithm (specifically, its weighted variant with a reservoir of size one).

Using the previous example: two segmentation schemes come in first. One is chosen with a probability of $\frac{1/3}{1/3+1/3}=1/2$, and their total cumulative probability is $2/3$. Next, the winner is compared with the new scheme. The probability of the new scheme being selected should be $\frac{1/3}{2/3+1/3}=1/3$. That is, it must be compared against the cumulative probability, not just the individual probability of the previous winner. In this way, throughout the complete sampling flow, each segmentation scheme has an equal probability of $1/3$ of appearing.
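The arithmetic above can be checked exactly. The sketch below assumes the candidates are given as plain weights $w_i = e^{\alpha s_i}$ in arrival order and tracks, with exact fractions, the distribution of the currently kept candidate under the old rule (challenge the current winner) and the new rule (challenge the cumulative weight). The function names are hypothetical; this is an illustration, not BytePiece's code.

```python
from fractions import Fraction


def old_rule_probs(weights):
    """Exact selection probabilities when each new candidate only challenges
    the current winner (old Viterbi Sampling)."""
    w = [Fraction(x) for x in weights]
    dist = {0: Fraction(1)}  # distribution over the index of the current winner
    for j in range(1, len(w)):
        new_dist = {}
        for c, p in dist.items():
            q = w[j] / (w[j] + w[c])  # sigma(alpha * (s_j - s_c)) in weight form
            new_dist[j] = new_dist.get(j, 0) + p * q
            new_dist[c] = new_dist.get(c, 0) + p * (1 - q)
        dist = new_dist
    return [dist.get(i, Fraction(0)) for i in range(len(w))]


def new_rule_probs(weights):
    """Exact selection probabilities when each new candidate challenges the
    cumulative weight Z_j (new Viterbi Sampling / weighted reservoir sampling)."""
    w = [Fraction(x) for x in weights]
    dist = {0: Fraction(1)}
    z = w[0]
    for j in range(1, len(w)):
        z += w[j]
        q = w[j] / z  # probability that candidate j replaces the current pick
        dist = {c: p * (1 - q) for c, p in dist.items()}
        dist[j] = q
    return [dist[i] for i in range(len(w))]
```

With three equal weights the old rule yields $1/4, 1/4, 1/2$ and the new rule $1/3, 1/3, 1/3$, matching the hand calculation.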

In Viterbi Sampling, each end position has multiple candidate segmentation schemes, and we need to make a multi-way choice among them. The probability of each being selected is built from its score as $p_i = e^{\alpha s_i}/Z$, where $Z$ is the normalization factor. Because the candidates arrive one at a time during the scan, we know neither how many options the multi-way choice has nor the final $Z$. This does not matter: knowing $e^{\alpha s_i}$ is sufficient, because the conditional sampling probability at each step requires not the full $Z$ but only the running sum $Z_i$:

Viterbi Decoding: $r_i = \left\{\begin{aligned}&\,1\,, \,\, s_i > s_{i-1} \\ &\,0\,, \,\, \text{else}\end{aligned}\right.$

Old Viterbi Sampling: $r_i = \left\{\begin{aligned}&\,1\,, \,\, \varepsilon < \sigma(\alpha(s_i - s_{i-1})) \\ &\,0\,, \,\, \text{else}\end{aligned}\right.$

New Viterbi Sampling: $\begin{aligned}Z_i =&\, Z_{i - 1} + e^{\alpha s_i} \\[1pt] r_i =&\, \left\{\begin{aligned}&\,1\,, \,\, \varepsilon < e^{\alpha s_i} / Z_i \\ &\,0\,, \,\, \text{else}\end{aligned}\right.\end{aligned}$
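The three rules can be read as three ways of updating the current pick as the candidates for one end position arrive. Below is a minimal sketch of that loop, assuming `scores` lists the candidates' scores in arrival order and reading $s_{i-1}$ in the first two rules as the score of the candidate currently kept; `pick` and its `rule` names are hypothetical, not BytePiece's API.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def pick(scores, rule, alpha=1.0, rng=random):
    """Process candidate scores one at a time and return the index kept.

    rule: "decode" (Viterbi Decoding), "old" (old Viterbi Sampling),
          "new" (new Viterbi Sampling).
    """
    best = 0
    z = math.exp(alpha * scores[0])  # running sum Z_i, only used by "new"
    for i in range(1, len(scores)):
        if rule == "decode":
            # keep the candidate with the strictly higher score
            take = scores[i] > scores[best]
        elif rule == "old":
            # binary choice against the current winner's score only
            take = rng.random() < sigmoid(alpha * (scores[i] - scores[best]))
        elif rule == "new":
            # binary choice against the cumulative weight of all candidates so far
            z += math.exp(alpha * scores[i])
            take = rng.random() < math.exp(alpha * scores[i]) / z
        else:
            raise ValueError(f"unknown rule: {rule}")
        if take:
            best = i
    return best
```

With three equal scores, repeated calls with `rule="new"` select each index about one third of the time, whereas `rule="old"` selects the last index about half of the time.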

In practice, $e^{\alpha s_i}$ can be extremely large or small, so caching $Z_i$ directly easily causes overflow (or underflow). We therefore usually cache its logarithm $Z^{\log}_i$ and update it with the numerically stable $\text{logsumexp}$ function:

\begin{equation} \begin{aligned}&\,Z^{\log}_i = \text{logsumexp}(Z^{\log}_{i-1}, \alpha s_i) \\ &\qquad e^{\alpha s_i} / Z_i \to e^{\alpha s_i - Z^{\log}_i} \end{aligned},\qquad \text{logsumexp}(x,y) = \left\{\begin{aligned}&\,x + \log(1+e^{y-x}),\,\, x \geq y \\ &\,y + \log(1 + e^{x-y}),\,\,x < y \end{aligned}\right. \end{equation}
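A minimal sketch of this log-space update, assuming `z_log_prev` starts at $-\infty$ (an empty running sum) and `eps` is a uniform draw from $[0, 1)$; `logsumexp2` and `new_rule_step` are hypothetical names. `math.log1p` computes the $\log(1+\cdot)$ term, and two guards handle $-\infty$ inputs such as out-of-vocabulary scores.

```python
import math


def logsumexp2(x, y):
    """Numerically stable log(e^x + e^y), following the two-case formula."""
    if x == -math.inf:
        return y
    if y == -math.inf:
        return x
    if x >= y:
        return x + math.log1p(math.exp(y - x))
    return y + math.log1p(math.exp(x - y))


def new_rule_step(z_log_prev, score, alpha, eps):
    """One binary choice of the new Viterbi Sampling, done in log space.

    Returns (z_log, take): the updated log cumulative weight and whether the
    new candidate replaces the current pick.
    """
    z_log = logsumexp2(z_log_prev, alpha * score)
    # e^{alpha s_i} / Z_i becomes e^{alpha s_i - Z^log_i}
    take = eps < math.exp(alpha * score - z_log)
    return z_log, take
```

Even for large arguments such as `logsumexp2(1000.0, 1000.0)` the result stays finite ($1000 + \log 2$), whereas exponentiating directly would overflow.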

The corresponding implementation is already built into bytepiece>=0.5.0.

Perfect Sampling

In short, the flaw in the old Viterbi Sampling came from haste. This time, I give a careful mathematical proof for the new version of Viterbi Sampling. Interestingly, it turns out that the updated Viterbi Sampling is a "perfect sampling" algorithm, just like Subword Regularization.

As introduced previously, Subword Regularization is quite "brute-force": it directly finds the top $k$ segmentation schemes and then computes each one's selection probability as $p_i = e^{\alpha s_i}/Z$, where $s_i$ is the score of the $i$-th scheme. Apart from its high complexity, there is nothing wrong with this method. When $k$ is unrestricted (i.e., all possible segmentation schemes are found), every segmentation can be sampled, and the probability of sampling any scheme is proportional to $e^{\alpha s_i}$, a monotonically increasing function of the score $s_i$. This means the sampling probabilities are ordered exactly as the scores are. An algorithm satisfying these two conditions (every segmentation can be sampled, and the sampling probabilities follow the score order) is what I call "perfect sampling."
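For reference, here is a brute-force sketch of the target distribution of perfect sampling: enumerate every segmentation and normalize $e^{\alpha s_i}$. The toy vocabulary and its scores are made up for illustration (they echo the "watching" example from the comment), and exhaustive enumeration is exponential in the string length, so this only serves as a ground truth on short strings.

```python
import math

# Hypothetical toy vocabulary: token -> score (think of it as a log-probability).
VOCAB = {
    "w": -3.0, "a": -3.0, "t": -3.0, "c": -3.0,
    "h": -3.0, "i": -3.0, "n": -3.0, "g": -3.0,
    "watch": -4.0, "ing": -2.5, "wat": -4.5, "ching": -4.0, "atching": -6.0,
}


def all_segmentations(text, vocab):
    """Enumerate every way to split text into in-vocabulary tokens."""
    if not text:
        return [[]]
    results = []
    for end in range(1, len(text) + 1):
        head = text[:end]
        if head in vocab:
            for rest in all_segmentations(text[end:], vocab):
                results.append([head] + rest)
    return results


def score(seg, vocab):
    """Score of a segmentation: the sum of its token scores."""
    return sum(vocab[t] for t in seg)


def subword_regularization_probs(text, vocab, alpha):
    """Brute-force distribution p_i = exp(alpha * s_i) / Z over all segmentations."""
    segs = all_segmentations(text, vocab)
    weights = [math.exp(alpha * score(s, vocab)) for s in segs]
    z = sum(weights)
    return {tuple(s): w / z for s, w in zip(segs, weights)}
```

With this vocabulary, "watch ing" has the highest score ($-6.5$) and also the highest probability, and any higher-scoring segmentation always receives a higher probability.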

Decoding

To prove that the new Viterbi Sampling is also "perfect sampling," let's first review Viterbi Decoding. Consider a byte string $c_1, c_2, \dots, c_l$ of length $l$, and let $S^*(c_1, c_2, \dots, c_l)$ denote the score of its optimal segmentation. If we know that the optimal segmentation places a cut between $c_k$ and $c_{k+1}$, then necessarily:

\begin{equation}S^*(c_1, c_2, \dots, c_l) = S^*(c_1, c_2, \dots, c_k) + S^*(c_{k+1}, c_{k+2}, \dots, c_l)\end{equation}

That is, each part of an optimal segmentation must itself be an optimal segmentation of the corresponding sub-byte string. This is the fundamental basis of dynamic programming. Of course, we cannot know in advance where the cuts occur, so we enumerate every possible last token:

\begin{equation}S^*(c_1, c_2, \dots, c_l) = \max\left\{\begin{aligned} &\,{\color{green}s\left(\overline{c_1,\dots,c_l}\right)} \\ {\color{red}S^*(c_1)} \,+&\, {\color{green}s\left(\overline{c_2,\dots,c_l}\right)} \\ {\color{red}S^*(c_1,c_2)} \,+&\, {\color{green}s\left(\overline{c_3,\dots,c_l}\right)} \\ \vdots \\ {\color{red}S^*(c_1,\dots,c_{l-2})} \,+&\, {\color{green}s\left(\overline{c_{l-1},c_l}\right)} \\ {\color{red}S^*(c_1,\dots,c_{l-1})} \,+&\, {\color{green}s\left(\overline{c_l}\right)} \end{aligned}\right\}\label{eq:core}\end{equation}

Here $s\left(\overline{c_1,\dots,c_l}\right)$ denotes the score of the byte string $c_1, \dots, c_l$ taken as a single token (recorded as $-\infty$ if it is not in the vocabulary). The calculation of $S^*(c_1, c_2, \dots, c_l)$ thus reduces to that of $S^*(c_1), S^*(c_1, c_2), \dots, S^*(c_1, \dots, c_{l-1})$. Likewise, $S^*(c_1, c_2, \dots, c_{l-1})$ reduces to $S^*(c_1), S^*(c_1, c_2), \dots, S^*(c_1, \dots, c_{l-2})$, and so on, so the values of $S^*$ can be reused. The entire process can therefore be summarized in one sentence:

When scanning to each position, record the optimal segmentation scheme and its score up to that position.
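A compact sketch of this forward pass with back-pointers, assuming a hypothetical toy vocabulary stored as a Python dict (token to score). For simplicity it finds candidate tokens by checking every substring up to the longest vocabulary entry instead of using a trie or AC automaton, so it illustrates the recursion rather than an efficient tokenizer.

```python
import math

# Hypothetical toy vocabulary: token -> score.
VOCAB = {
    "w": -3.0, "a": -3.0, "t": -3.0, "c": -3.0,
    "h": -3.0, "i": -3.0, "n": -3.0, "g": -3.0,
    "watch": -4.0, "ing": -2.5, "wat": -4.5, "ching": -4.0, "atching": -6.0,
}


def viterbi_decode(text, vocab):
    """Return (best_score, best_tokens) following the recursion for S*.

    best[j] holds S*(c_1..c_j); back[j] holds the start of the last token
    of the optimal segmentation of the first j characters.
    """
    m = max(map(len, vocab))  # longest token length
    n = len(text)
    best = [-math.inf] * (n + 1)
    best[0] = 0.0  # empty prefix
    back = [0] * (n + 1)
    for end in range(1, n + 1):
        for start in range(max(0, end - m), end):
            token = text[start:end]
            if token in vocab and best[start] + vocab[token] > best[end]:
                best[end] = best[start] + vocab[token]
                back[end] = start
    if best[n] == -math.inf:
        return -math.inf, []  # no valid segmentation
    tokens, end = [], n
    while end > 0:  # follow back-pointers from the last position
        tokens.append(text[back[end]:end])
        end = back[end]
    return best[n], tokens[::-1]
```

With this toy vocabulary, `viterbi_decode("watching", VOCAB)` returns `(-6.5, ['watch', 'ing'])`.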

Of course, following the recursion in formula $\eqref{eq:core}$ directly has a theoretical complexity of $\mathcal{O}(l^2)$. In practice, however, most sub-byte strings are not tokens in the vocabulary, so we can use a trie or an Aho-Corasick (AC) automaton to pre-scan all vocabulary tokens that occur in the string. The complexity then becomes proportional to the number of candidate tokens found, which is linear in $l$. For a rough estimate, if the longest token in the vocabulary has length $m$, the number of candidate tokens for a byte string of length $l \geq m$ is at most:

\begin{equation}l + (l - 1) + \dots + (l - m + 1) = lm - \frac{1}{2}m(m-1) = \mathcal{O}(lm)\end{equation}
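The bound can be checked by brute force: the sketch below counts every span of length $1$ to $m$ in a string of length $l$ (the worst case, reached when every such span is in the vocabulary) and compares the count with the closed form. The function names are illustrative.

```python
def count_spans(l, m):
    """Enumerate every (start, end) span of length 1..m in a string of length l."""
    return sum(1 for start in range(l) for end in range(start + 1, min(start + m, l) + 1))


def closed_form(l, m):
    """l*m - m*(m-1)/2; m*(m-1) is always even, so integer division is exact."""
    return l * m - m * (m - 1) // 2
```

For example, with $l = 8$ and $m = 7$ both functions give $35$.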

Sampling

With Decoding as a foundation, Sampling is easier to understand. The key is again formula $\eqref{eq:core}$. Let $Z(c_1, c_2, \dots, c_l)$ denote the normalization factor of perfect sampling over all segmentation schemes of the byte string $c_1, c_2, \dots, c_l$, that is, the sum of $e^{\alpha s}$ over all of its segmentations. Then:

\begin{equation}Z(c_1, c_2, \dots, c_l) = \sum\left\{\begin{aligned} &\,{\color{green}e^{\alpha\cdot s\left(\overline{c_1,\dots,c_l}\right)}} \\ {\color{red}Z(c_1)} &\, {\color{green}e^{\alpha\cdot s\left(\overline{c_2,\dots,c_l}\right)}} \\ {\color{red}Z(c_1,c_2)} &\, {\color{green}e^{\alpha\cdot s\left(\overline{c_3,\dots,c_l}\right)}} \\ \vdots \\ {\color{red}Z(c_1,\dots,c_{l-2})} &\, {\color{green}e^{\alpha\cdot s\left(\overline{c_{l-1},c_l}\right)}} \\ {\color{red}Z(c_1,\dots,c_{l-1})} &\, {\color{green}e^{\alpha\cdot s\left(\overline{c_l}\right)}} \end{aligned}\right\}\label{eq:core-2} \end{equation}

This equality also tells us how to sample from all segmentation schemes of $c_1, c_2, \dots, c_l$ with weights $e^{\alpha s}$: sample one segmentation of $c_1, \dots, c_{l-1}$ (with the same weighting) and append the token $\overline{c_l}$; sample one segmentation of $c_1, \dots, c_{l-2}$ and append the token $\overline{c_{l-1}, c_l}$; sample one segmentation of $c_1, \dots, c_{l-3}$ and append the token $\overline{c_{l-2}, c_{l-1}, c_l}$; and so on, down to the whole string $\overline{c_1,\dots,c_l}$ as a single token. Having obtained these $l$ candidate results, we select one of them with weights $Z(c_1, \dots, c_{l-1}) e^{\alpha\cdot s\left(\overline{c_l}\right)}$, $Z(c_1, \dots, c_{l-2}) e^{\alpha\cdot s\left(\overline{c_{l-1}, c_l}\right)}$, $Z(c_1, \dots, c_{l-3}) e^{\alpha\cdot s\left(\overline{c_{l-2}, c_{l-1}, c_l}\right)}$, etc.
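As a sanity check on this recursion, the sketch below computes $Z$ with the forward recursion, taking $Z$ of the empty prefix to be $1$ so that the first row (the whole string as one token) fits the same pattern, and compares the result with explicit enumeration on a hypothetical toy vocabulary. It works in linear space for readability, which is fine for short toy strings but would overflow or underflow in practice.

```python
import math

# Hypothetical toy vocabulary: token -> score.
VOCAB = {
    "w": -3.0, "a": -3.0, "t": -3.0, "c": -3.0,
    "h": -3.0, "i": -3.0, "n": -3.0, "g": -3.0,
    "watch": -4.0, "ing": -2.5, "wat": -4.5, "ching": -4.0, "atching": -6.0,
}


def partition_recursive(text, vocab, alpha):
    """Z(c_1..c_n) via Z(c_1..c_j) = sum_i Z(c_1..c_i) * exp(alpha * s(c_{i+1}..c_j))."""
    n = len(text)
    z = [0.0] * (n + 1)
    z[0] = 1.0  # Z of the empty prefix
    for j in range(1, n + 1):
        for i in range(j):
            tok = text[i:j]
            if tok in vocab:
                z[j] += z[i] * math.exp(alpha * vocab[tok])
    return z[n]


def partition_brute_force(text, vocab, alpha):
    """Sum of exp(alpha * score) over every segmentation, by enumeration."""
    def segs(t):
        if not t:
            yield []
            return
        for k in range(1, len(t) + 1):
            if t[:k] in vocab:
                for rest in segs(t[k:]):
                    yield [t[:k]] + rest

    return sum(math.exp(alpha * sum(vocab[x] for x in s)) for s in segs(text))
```

Both functions agree (up to floating-point rounding) for any string and $\alpha$, which is exactly the content of the equality above.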

Similar to the Decoding scenario, the calculation of $Z(c_1, \dots, c_{l-1})$ can reuse the results of $Z(c_1), Z(c_1, c_2), \dots, Z(c_1, \dots, c_{l-2})$, and the calculation of $Z(c_1, \dots, c_{l-2})$ can reuse $Z(c_1), Z(c_1, c_2), \dots, Z(c_1, \dots, c_{l-3})$, and so forth. The sampling results themselves are also reusable. Thus, the entire Sampling algorithm can be summarized in one sentence:

When scanning to each position, perform sampling for all segmentation schemes ending at the current position according to the $e^{\alpha s}$ weights, and record the sampling result along with the cumulative weight $Z$.
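Putting the pieces together, here is a minimal sketch of the whole procedure under stated assumptions: a hypothetical toy vocabulary stored as a dict, candidate tokens found by scanning substrings rather than with an AC automaton, and `z_log[j]` caching $Z^{\log}(c_1,\dots,c_j)$. Following the recursion, a candidate whose last token spans `start` to `end` has log-weight $Z^{\log}(c_1,\dots,c_{\text{start}}) + \alpha s(\text{token})$ and challenges the cumulative weight of the candidates already seen at `end`. This is an illustration, not BytePiece's implementation.

```python
import math
import random

# Hypothetical toy vocabulary: token -> score.
VOCAB = {
    "w": -3.0, "a": -3.0, "t": -3.0, "c": -3.0,
    "h": -3.0, "i": -3.0, "n": -3.0, "g": -3.0,
    "watch": -4.0, "ing": -2.5, "wat": -4.5, "ching": -4.0, "atching": -6.0,
}


def logsumexp2(x, y):
    """Numerically stable log(e^x + e^y), with guards for -inf."""
    if x == -math.inf:
        return y
    if y == -math.inf:
        return x
    if x >= y:
        return x + math.log1p(math.exp(y - x))
    return y + math.log1p(math.exp(x - y))


def viterbi_sample(text, vocab, alpha, rng=random):
    """Draw one segmentation with probability proportional to exp(alpha * score)."""
    m = max(map(len, vocab))
    n = len(text)
    z_log = [-math.inf] * (n + 1)
    z_log[0] = 0.0  # log Z of the empty prefix, i.e. log 1
    back = [0] * (n + 1)  # sampled start of the last token ending at each position
    for end in range(1, n + 1):
        for start in range(max(0, end - m), end):
            token = text[start:end]
            if token not in vocab or z_log[start] == -math.inf:
                continue
            cand = z_log[start] + alpha * vocab[token]  # log of this candidate's weight
            z_log[end] = logsumexp2(z_log[end], cand)   # running log cumulative weight
            if rng.random() < math.exp(cand - z_log[end]):
                back[end] = start  # candidate wins the binary choice
    if z_log[n] == -math.inf:
        return None  # no valid segmentation
    tokens, end = [], n
    while end > 0:
        tokens.append(text[back[end]:end])
        end = back[end]
    return tokens[::-1]
```

Repeated calls to `viterbi_sample("watching", VOCAB, 1.0)` should produce each segmentation with a frequency close to its brute-force probability $e^{\alpha s}/Z$.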

If we take the logarithm on both sides, formula $\eqref{eq:core-2}$ can be equivalently rewritten as:

\begin{equation}Z^{\log}(c_1, c_2, \dots, c_l) = \text{logsumexp}\left\{\begin{aligned} &\,{\color{green}\alpha\cdot s\left(\overline{c_1,\dots,c_l}\right)} \\ {\color{red}Z^{\log}(c_1)} \,+&\, {\color{green}\alpha\cdot s\left(\overline{c_2,\dots,c_l}\right)} \\ {\color{red}Z^{\log}(c_1,c_2)} \,+&\, {\color{green}\alpha\cdot s\left(\overline{c_3,\dots,c_l}\right)} \\ \vdots \\ {\color{red}Z^{\log}(c_1,\dots,c_{l-2})} \,+&\, {\color{green}\alpha\cdot s\left(\overline{c_{l-1},c_l}\right)} \\ {\color{red}Z^{\log}(c_1,\dots,c_{l-1})} \,+&\, {\color{green}\alpha\cdot s\left(\overline{c_l}\right)} \end{aligned}\right\} \end{equation}

The only differences from the Viterbi Decoding formula $\eqref{eq:core}$ are that $Z^{\log}$ replaces $S^*$ and $\text{logsumexp}$ replaces $\max$. Since $\text{logsumexp}$ is a smooth approximation of $\max$ (more precisely, $\frac{1}{\alpha}\text{logsumexp}(\alpha x_1, \dots, \alpha x_n) \to \max_i x_i$ as $\alpha \to \infty$), the algorithm degenerates into Viterbi Decoding when $\alpha \to \infty$. On the other hand, in actual computation, the segmentation schemes sharing the same end position arrive one by one rather than all at once, so the single-step multi-way choice must be converted into a sequence of binary choices, which is exactly what was discussed in "The Solution" section. We have thus proven (or rather, re-derived from Viterbi Decoding) that the modified Viterbi Sampling is indeed a perfect sampling algorithm, identical in principle to Subword Regularization.
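A small numerical illustration of this limit, assuming the scores are plain floats. Because $Z^{\log}$ scales roughly linearly with $\alpha$, the comparison with $\max$ is made after dividing by $\alpha$; the sketch checks the standard bounds $\max_i x_i \le \frac{1}{\alpha}\text{logsumexp}(\alpha x) \le \max_i x_i + \frac{\log n}{\alpha}$. The function name is hypothetical.

```python
import math


def scaled_logsumexp(xs, alpha):
    """(1/alpha) * log(sum(exp(alpha * x))), computed stably by factoring out max(xs)."""
    m = max(xs)
    return m + math.log(sum(math.exp(alpha * (x - m)) for x in xs)) / alpha
```

For the scores $-6.5, -8.5, -9$ the value is about $-6.30$ at $\alpha = 1$ and is indistinguishable from the maximum $-6.5$ at $\alpha = 100$, which is the sense in which sampling collapses onto decoding.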

Summary

This article refines the previously proposed random tokenization algorithm, Viterbi Sampling, and mathematically proves that it is a "perfect sampling" algorithm, just like Subword Regularization, while being significantly more efficient in practical use.