An Attempt at a Theoretical Analysis of the Repetition Problem in Seq2Seq

By 苏剑林 | January 26, 2021

Last year, the author wrote a blog post "How to Deal with the 'Never-Ending' Problem in Seq2Seq?", which introduced a paper's strategy for handling the phenomenon where Seq2Seq decoding fails to stop, and pointed out that the paper only provided strategies without providing a theoretical understanding of the problem. Recently, the author read a paper titled "A Theoretical Analysis of the Repetition Problem in Text Generation" from AAAI 2021 on arXiv, which analyzes the repetition decoding phenomenon in Seq2Seq from a theoretical perspective. In essence, repetition decoding and the "never-ending" decoding problem are cut from the same cloth, so this new paper can be seen as filling the gap left by the previous one.

Upon studying it, the author found that the paper indeed has many commendable points and is well worth reading. This article refines, corrects, and generalizes the original paper's analysis and records the results for reference. Furthermore, setting the problem background aside, readers can also treat this article as a matrix analysis exercise with which to review their linear algebra.

Basic Idea

So-called repetition decoding refers to the appearance of repeated fragments in the decoding result. For example, if the decoding result is "A B C D B C D B C D E F", then "B C D" is a repeated fragment, and thus the result exhibits the repetition decoding phenomenon. For simplicity, if a subsequence $s=[w_1, w_2, \cdots, w_n]$ is followed by $t=[w_1, w_2, \cdots, w_n, w_1]$ during decoding, we call $[w_1, w_2, \cdots, w_n]$ a "repetition subsequence." What we want to do now is to analyze the probability of a repetition subsequence appearing during the decoding process.

Some readers might wonder why an extra $w_1$ is added at the end of $t$? From the subsequent process, we can see that this is essentially for analytical convenience and is not strictly necessary. What we hope to obtain is a representative quantitative indicator to measure this repetition decoding problem, preferably one that provides insights into improvement strategies. As for the specific details of this indicator, we need not worry too much. Quantifying the research objective is crucial; only after quantification can we better grasp the direction of improvement and compare the merits of different methods. Otherwise, even if we argue until we are red in the face, we will never reach a conclusion.

To obtain such an indicator, we will start from simple bigram decoding to get some representative results and then see whether they can be generalized to general autoregressive decoders.

Bigram Decoding

A general autoregressive model takes the form: \begin{equation}p(\boldsymbol{y}|\boldsymbol{x}) = \prod_{t=1}^l p(y_t|\boldsymbol{y}_{< t}, \boldsymbol{x})\end{equation} That is to say, the decoding at position $t$ depends not only on the input $\boldsymbol{x}$ but also on all previously obtained decoding results before $t$. For simplicity, let's first consider a simple case where we assume each step of decoding only depends on the result of the previous moment: \begin{equation}p(\boldsymbol{y}|\boldsymbol{x}) = \prod_{t=1}^l p(y_t|y_{t-1}, \boldsymbol{x})\end{equation} In this way, for a fixed input $\boldsymbol{x}$, the decoder is effectively just an $n \times n$ transition matrix $\boldsymbol{P}=(P_{i,j})$, where $P_{i,j}$ represents the probability of $j$ following $i$, and $n$ represents the vocabulary size. Such a decoder is called a bigram model, a 2-gram model, a Markov model, and so on. We also need a termination token <eos>; decoding stops upon encountering <eos>. So actually, the transition matrix should be $(n+1) \times (n+1)$, but since we consider repetition decoding before termination, we only need to consider the $n \times n$ part excluding <eos>.
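To make this setup concrete, here is a minimal NumPy sketch (with a made-up toy vocabulary; none of the numbers come from a real model) of a bigram decoder represented as a row-stochastic transition matrix $\boldsymbol{P}$, from which sequences can be rolled out step by step:

```python
import numpy as np

rng = np.random.default_rng(0)

# For a fixed input x, a bigram decoder is just an n x n row-stochastic
# transition matrix P, with P[i, j] = p(j | i, x).
n = 5                                     # toy vocabulary size (hypothetical)
logits = rng.normal(size=(n, n))
P = np.exp(logits)
P /= P.sum(axis=1, keepdims=True)         # each row sums to 1

def sample_sequence(P, start, length, rng):
    """Roll out a sequence by repeatedly sampling the next token from P."""
    seq = [start]
    for _ in range(length - 1):
        seq.append(int(rng.choice(len(P), p=P[seq[-1]])))
    return seq

print(sample_sequence(P, start=0, length=12, rng=rng))
```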

We want to calculate the probability of a repetition subsequence appearing. If $[i, j, k]$ is a trigram repetition subsequence, then its probability of occurrence is the probability of the sequence $[i, j, k, i, j, k, i]$ appearing: \begin{equation}P_{i,j}P_{j,k}P_{k,i}P_{i,j}P_{j,k}P_{k,i}=P_{i,j}^2 P_{j,k}^2 P_{k,i}^2\end{equation} Therefore, the total probability over all trigram repetition subsequences is: \begin{equation}\sum_{i,j,k} P_{i,j}^2 P_{j,k}^2 P_{k,i}^2 = \text{Tr}\,(\boldsymbol{P}\otimes\boldsymbol{P})^3\end{equation} Here $\otimes$ denotes element-wise multiplication (the Hadamard product), and $\text{Tr}$ denotes the trace of a matrix, i.e., the sum of its diagonal elements. Finally, we sum the probabilities of repetition subsequences of all lengths: \begin{equation}R = \sum_{k=1}^{\infty}\text{Tr}\,(\boldsymbol{P}\otimes\boldsymbol{P})^k = \text{Tr}\,\left(\sum_{k=1}^{\infty}(\boldsymbol{P}\otimes\boldsymbol{P})^k\right)\label{eq:r}\end{equation} This is the probability of repetition decoding occurring in a bigram decoder. While this is for now just a formal expression, it is our key starting point; we will derive lower and upper bounds for it to obtain more interpretable results.
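As a sanity check on Eq. \eqref{eq:r}, the series can be evaluated numerically by truncating it at some maximum repetition length. The sketch below (a toy example with a random row-stochastic matrix, not the output of any real model) does exactly that:

```python
import numpy as np

def repetition_measure(P, max_len=100):
    """Truncated version of R = sum_k Tr((P ∘ P)^k), where ∘ is the Hadamard product."""
    H = P * P                      # (P ∘ P)[i, j] = P[i, j]^2
    Hk, total = np.eye(len(P)), 0.0
    for _ in range(max_len):
        Hk = Hk @ H                # running power (P ∘ P)^k
        total += np.trace(Hk)
    return total

rng = np.random.default_rng(1)
P = rng.random((6, 6))
P /= P.sum(axis=1, keepdims=True)  # random row-stochastic transition matrix (toy)
print(repetition_measure(P))
```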

A Lower Bound

It is difficult to see much directly from Eq. \eqref{eq:r}; we can first derive a more intuitive lower bound. Still taking the trigram repetition subsequence as an example, using the power mean inequality (the mean of squares is at least the square of the mean), we get: \begin{equation} \sum_{i,j,k} P_{i,j}^2 P_{j,k}^2 P_{k,i}^2 = n^3\times\frac{\sum_{i,j,k} P_{i,j}^2 P_{j,k}^2 P_{k,i}^2}{n^3}\geq n^3\times\left(\frac{\sum_{i,j,k} P_{i,j} P_{j,k} P_{k,i}}{n^3}\right)^2 = \frac{(\text{Tr}\, \boldsymbol{P}^3)^2}{n^3} \end{equation} In fact, we can be more precise. Suppose some elements of the matrix $\boldsymbol{P}$ are 0; then the number of non-zero terms among the products $P_{i,j}^2 P_{j,k}^2 P_{k,i}^2$ is less than $n^3$. Denote this number by $N_3(\boldsymbol{P}) < n^3$. When applying the mean inequality, we can restrict it to the non-zero terms, which amounts to replacing $n^3$ with $N_3(\boldsymbol{P})$: \begin{equation} \sum_{i,j,k} P_{i,j}^2 P_{j,k}^2 P_{k,i}^2 \geq \frac{(\text{Tr}\, \boldsymbol{P}^3)^2}{N_3(\boldsymbol{P})} \end{equation} Computing $N_3(\boldsymbol{P})$ directly is difficult and there is no general formula, but we can make a simple estimate: let $\zeta$ be the proportion of non-zero elements in $\boldsymbol{P}$, so the number of non-zero elements is $\zeta n^2$. We can then assume that the proportion of non-zero terms among the products $P_{i,j}^2 P_{j,k}^2 P_{k,i}^2$ is approximately $\zeta^3$; since the total number of index combinations is $n^3$, we take $N_3(\boldsymbol{P})\sim \zeta^3 n^3$, or more generally $N_k(\boldsymbol{P})\sim \zeta^k n^k$. Note that one can construct examples showing this estimate is neither an upper bound nor a lower bound in general, so after replacing $N_3(\boldsymbol{P})$ with $\zeta^3 n^3$ we can no longer guarantee that the inequality holds. However, if we are willing to believe that $\zeta^3 n^3$ is a good enough approximation, we can still (with a mix of trepidation and firm conviction) write down: \begin{equation} \sum_{i,j,k} P_{i,j}^2 P_{j,k}^2 P_{k,i}^2 \geq \frac{(\text{Tr}\, \boldsymbol{P}^3)^2}{\zeta^3 n^3} \end{equation} and \begin{equation}R = \sum_{k=1}^{\infty}\text{Tr}\,(\boldsymbol{P}\otimes\boldsymbol{P})^k \geq \sum_{k=1}^{\infty} \frac{(\text{Tr}\, \boldsymbol{P}^k)^2}{\zeta^k n^k}\label{eq:r-2}\end{equation} Alternatively, we can simply stop worrying about the inequality and treat the rightmost expression as an estimate of $R$.
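A quick numerical check of this estimate may be reassuring. The sketch below builds a toy sparse transition matrix (keeping only a few transitions per row, purely as an assumed example) and compares the exact term $\text{Tr}\,(\boldsymbol{P}\otimes\boldsymbol{P})^k$ with the estimate $(\text{Tr}\,\boldsymbol{P}^k)^2/(\zeta^k n^k)$ for a few values of $k$:

```python
import numpy as np

rng = np.random.default_rng(2)
n, keep = 8, 3                              # toy sizes (assumed, not from any experiment)

# Build a sparse row-stochastic matrix: keep only `keep` transitions per row.
P = rng.random((n, n))
drop = np.argsort(P, axis=1)[:, :-keep]     # indices of the entries to zero out
np.put_along_axis(P, drop, 0.0, axis=1)
P /= P.sum(axis=1, keepdims=True)

zeta = np.count_nonzero(P) / n**2           # proportion of non-zero transitions (= keep / n)
H = P * P                                   # P ∘ P

for k in (2, 3, 4):
    exact = np.trace(np.linalg.matrix_power(H, k))
    estimate = np.trace(np.linalg.matrix_power(P, k)) ** 2 / (zeta**k * n**k)
    print(k, exact, estimate)               # estimate need not bound exact, but tracks it
```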

The Original Paper's Lower Bound

Readers hoping to follow along with the original paper might be a bit confused at this point because neither Eq. \eqref{eq:r} nor Eq. \eqref{eq:r-2} can be found in the original paper. In fact, the original paper does not provide an exact Eq. \eqref{eq:r} nor the estimate Eq. \eqref{eq:r-2}; instead, it provides another estimate, which can also be derived as a lower bound of Eq. \eqref{eq:r}.

Similarly, using the power mean inequality, we have: \begin{equation}\begin{aligned} \sum_{i,j,k} P_{i,j}^2 P_{j,k}^2 P_{k,i}^2 =&\, \sum_{i} \sum_{j,k} P_{i,j}^2 P_{j,k}^2 P_{k,i}^2= \sum_{i} n^2\times\frac{\sum_{j,k} P_{i,j}^2 P_{j,k}^2 P_{k,i}^2}{n^2}\\ \geq&\, \sum_{i} n^2\times\left(\frac{\sum_{j,k} P_{i,j} P_{j,k} P_{k,i}}{n^2}\right)^2 = \frac{\text{Tr}\, (\boldsymbol{P}^3\otimes \boldsymbol{P}^3)}{n^2} \end{aligned}\end{equation} As before, we introduce the non-zero-count trick to improve the accuracy of the estimate: the proportion of non-zero terms is still about $\zeta^3$, and the inner sum this time ranges over $n^2$ index combinations, so the number of non-zero terms is about $\zeta^3 n^2$. Thus, we (still with trepidation and conviction) write: \begin{equation} \sum_{i,j,k} P_{i,j}^2 P_{j,k}^2 P_{k,i}^2 \geq \frac{\text{Tr}\, (\boldsymbol{P}^3\otimes \boldsymbol{P}^3)}{\zeta^3 n^2}\end{equation} and \begin{equation}R = \sum_{k=1}^{\infty}\text{Tr}\,(\boldsymbol{P}\otimes\boldsymbol{P})^k \geq \sum_{k=1}^{\infty}\frac{\text{Tr}\, (\boldsymbol{P}^k\otimes \boldsymbol{P}^k)}{\zeta^k n^{k-1}}\label{eq:r-3}\end{equation} This is essentially "Definition 2.3" in the original paper, with the following differences:

1. The original paper calculates the probability averaged over each word, so it needs to divide by an extra $n$, hence its denominator is $n^k$;

2. The original paper takes the trace of $\boldsymbol{P}^{2k}$ instead of $\boldsymbol{P}^k\otimes \boldsymbol{P}^k$. This is actually a mistake in the original paper, which treated $(\boldsymbol{P}^k)_{i,i}^2$ as $(\boldsymbol{P}^{2k})_{i,i}$ during the derivation; in fact, the two are generally unequal (a quick numerical check is given below). Eq. \eqref{eq:r-3} in this article is the correct result.
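For the record, here is a small numerical check (on a random toy matrix, purely illustrative) that $\text{Tr}\,(\boldsymbol{P}^k\otimes\boldsymbol{P}^k)$ and $\text{Tr}\,\boldsymbol{P}^{2k}$ generally differ, i.e., that $(\boldsymbol{P}^k)_{i,i}^2 \neq (\boldsymbol{P}^{2k})_{i,i}$:

```python
import numpy as np

rng = np.random.default_rng(3)
P = rng.random((5, 5))
P /= P.sum(axis=1, keepdims=True)       # random row-stochastic matrix (toy example)

k = 3
Pk = np.linalg.matrix_power(P, k)
print(np.trace(Pk * Pk))                            # Tr(P^k ∘ P^k) = sum_i (P^k)_{ii}^2
print(np.trace(np.linalg.matrix_power(P, 2 * k)))   # Tr(P^{2k})   = sum_i (P^{2k})_{ii}
# The two values differ in general, which is why Eq. (r-3) uses Tr(P^k ∘ P^k).
```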

Initial Conclusion

Actually, whether it's Eq. \eqref{eq:r-2} or Eq. \eqref{eq:r-3}, the forms are similar, and we can use them to draw some conclusions. At this point, some readers might wonder: the probability distributions used by our models are usually Softmax outputs, and Softmax results are never equal to 0, so $\zeta$ should be identically 1. Does introducing $\zeta$ have any value?

Not so. Indeed, the probabilities output by Softmax are never strictly equal to 0, but our decoding algorithms often force them to be 0! In the article "How to Deal with the 'Never-Ending' Problem in Seq2Seq?", we listed common decoding algorithms for text generation, mainly divided into random sampling and deterministic decoding. Random sampling includes direct random sampling, Top-k sampling, and Top-p sampling; deterministic decoding includes Greedy Search and Beam Search. Among these five, except for the rarely used direct random sampling, the other four all forcibly keep only a few top-scoring candidates as possible next tokens, which is equivalent to directly truncating the transition matrix and greatly reducing the non-zero proportion $\zeta$.

Take the most extreme case, Greedy Search. It is easy to see that it corresponds to the smallest possible non-zero proportion, $\zeta=1/n$. Since $\zeta$ appears in the denominator, a decrease in $\zeta$ means an increase in the repetition rate $R$. This tells us that the risk of repetition decoding under Greedy Search is quite high. Although this conclusion is derived only under the bigram decoding assumption, repetition under Greedy Search is indeed a phenomenon we frequently observe, so the conclusion and its explanation are representative.
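The effect is easy to reproduce numerically. The sketch below (a toy random transition matrix, with top-k truncation standing in for the decoding algorithm; none of this comes from a real model) truncates each row to its $k$ largest transitions, renormalizes, and evaluates the truncated series for $R$: as $k$ shrinks (so $\zeta = k/n$ shrinks), the repetition measure grows, and for $k=1$ (greedy) the series actually diverges, since a deterministic transition map must eventually enter a cycle.

```python
import numpy as np

def truncated_R(P, max_len=200):
    """Truncated series sum_k Tr((P ∘ P)^k)."""
    H, Hk, total = P * P, np.eye(len(P)), 0.0
    for _ in range(max_len):
        Hk = Hk @ H
        total += np.trace(Hk)
    return total

def topk_truncate(P, k):
    """Keep only the k largest transitions per row and renormalize (mimics top-k decoding)."""
    Q = P.copy()
    drop = np.argsort(Q, axis=1)[:, :-k]
    np.put_along_axis(Q, drop, 0.0, axis=1)
    return Q / Q.sum(axis=1, keepdims=True)

rng = np.random.default_rng(4)
n = 16
P = rng.random((n, n))
P /= P.sum(axis=1, keepdims=True)

for k in (n, 8, 4, 2, 1):                  # k = 1 mimics Greedy Search, zeta = 1/n
    print(k, k / n, truncated_R(topk_truncate(P, k)))
```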

An Upper Bound

Since we have a lower bound, how can we not have an upper bound? While the lower bound helps us explain experimental phenomena, an upper bound can provide us with ideas for improvement.

To derive the upper bound, we utilize the following two conclusions:

1. The trace of a matrix is equal to the sum of all its eigenvalues;

2. If $\lambda_1(\boldsymbol{A})\geq\lambda_2(\boldsymbol{A})\geq\cdots\geq\lambda_n(\boldsymbol{A})$ are all the eigenvalues of matrix $\boldsymbol{A}$, then $\lambda_1^k(\boldsymbol{A}),\lambda_2^k(\boldsymbol{A}),\cdots,\lambda_n^k(\boldsymbol{A})$ are all the eigenvalues of matrix $\boldsymbol{A}^k$.

Therefore, we can derive: \begin{equation}\begin{aligned} R =&\, \sum_{k=1}^{\infty}\text{Tr}\,(\boldsymbol{P}\otimes\boldsymbol{P})^k = \sum_{k=1}^{\infty}\sum_{i=1}^n\lambda_i\left((\boldsymbol{P}\otimes\boldsymbol{P})^k\right)\\ =&\, \sum_{k=1}^{\infty}\sum_{i=1}^n\lambda_i^k\left(\boldsymbol{P}\otimes\boldsymbol{P}\right) = \sum_{i=1}^n \sum_{k=1}^{\infty}\lambda_i^k\left(\boldsymbol{P}\otimes\boldsymbol{P}\right) \\ =&\, \sum_{i=1}^n \frac{\lambda_i \left(\boldsymbol{P}\otimes\boldsymbol{P}\right)}{1 - \lambda_i \left(\boldsymbol{P}\otimes\boldsymbol{P}\right)} \end{aligned}\label{eq:r-4}\end{equation} The above process uses the series $\frac{x}{1-x}=\sum_{k=1}^{\infty} x^k$, which converges only when $|x| < 1$. Conveniently, we can prove that the eigenvalues of $\boldsymbol{P}\otimes\boldsymbol{P}$ have absolute value at most 1, and usually strictly less than 1. Since $\boldsymbol{P}$ is a transition matrix, each of its rows sums to 1, so each row of $\boldsymbol{P}\otimes\boldsymbol{P}$ sums to at most 1. Let $\lambda$ and $\boldsymbol{x}$ be an eigenvalue and eigenvector of $\boldsymbol{P}\otimes\boldsymbol{P}$, so that $(\boldsymbol{P}\otimes\boldsymbol{P})\boldsymbol{x}=\lambda \boldsymbol{x}$. Without loss of generality, let $x_1$ be the element of $\boldsymbol{x}$ with the largest absolute value, and let $\boldsymbol{q}_1^{\top}$ be the first row of $\boldsymbol{P}\otimes\boldsymbol{P}$. Then $|\lambda| |x_1| = |\boldsymbol{q}_1^{\top}\boldsymbol{x}| \leq \sum_j (q_1)_j |x_j| \leq |x_1| \sum_j (q_1)_j \leq |x_1|$, since the entries of $\boldsymbol{q}_1$ are non-negative and sum to at most 1 while $|x_j| \leq |x_1|$ for all $j$; hence $|\lambda| \leq 1$. Moreover, the conditions for equality are quite strict, so usually $|\lambda| < 1$.

Note that the function $\frac{x}{1-x}$ is monotonically increasing in the interval $[-1,1)$, so in Eq. \eqref{eq:r-4}, the dominant term is the first term $\frac{\lambda_1 \left(\boldsymbol{P}\otimes\boldsymbol{P}\right)}{1 - \lambda_1 \left(\boldsymbol{P}\otimes\boldsymbol{P}\right)}$. If we must set an overall upper bound, it could be $\frac{n \lambda_1 \left(\boldsymbol{P}\otimes\boldsymbol{P}\right)}{1 - \lambda_1 \left(\boldsymbol{P}\otimes\boldsymbol{P}\right)}$.
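Eq. \eqref{eq:r-4} is also convenient numerically: instead of summing the series, one can evaluate $R$ directly from the eigenvalues of $\boldsymbol{P}\otimes\boldsymbol{P}$. The sketch below (again a toy random transition matrix) compares the closed form with the truncated series and also prints the crude upper bound $n\lambda_1/(1-\lambda_1)$:

```python
import numpy as np

rng = np.random.default_rng(5)
n = 8
P = rng.random((n, n))
P /= P.sum(axis=1, keepdims=True)
H = P * P                                     # P ∘ P

# Closed form from Eq. (r-4): R = sum_i lambda_i / (1 - lambda_i).
lam = np.linalg.eigvals(H)
R_closed = np.sum(lam / (1 - lam)).real       # imaginary parts cancel for a real matrix

# Truncated series sum_k Tr((P ∘ P)^k) for comparison.
Hk, R_series = np.eye(n), 0.0
for _ in range(500):
    Hk = Hk @ H
    R_series += np.trace(Hk)

lam1 = lam.real.max()                         # Perron root (real for a non-negative matrix)
print(R_closed, R_series, n * lam1 / (1 - lam1))   # last value: the crude upper bound
```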

Further Conclusion

Therefore, to reduce the repetition rate $R$, we need to find ways to reduce the largest eigenvalue of the matrix $\boldsymbol{P}\otimes\boldsymbol{P}$. $\boldsymbol{P}\otimes\boldsymbol{P}$ is a non-negative matrix, and by the Frobenius row-sum bounds for non-negative matrices (from Perron-Frobenius theory), we have: \begin{equation}\min_i \sum_j P_{i,j}^2 \leq \lambda_1 (\boldsymbol{P}\otimes\boldsymbol{P}) \leq \max_i \sum_j P_{i,j}^2\end{equation} This result is covered in almost any book on matrix analysis; it states that the largest eigenvalue of a non-negative matrix lies between the minimum and the maximum of its row sums. So to reduce the largest eigenvalue of $\boldsymbol{P}\otimes\boldsymbol{P}$, we need to reduce its row sums, i.e., $\sum_j P_{i,j}^2$. Furthermore, by the mean inequality: \begin{equation}\sum_j P_{i,j}^2\geq n\left(\frac{\sum_j P_{i,j}}{n}\right)^2 = \frac{1}{n}\end{equation} the minimum value is $1/n$, attained when $P_{i,1}=P_{i,2}=\cdots=P_{i,n}$. Therefore, we ultimately arrive at a conclusion: to reduce the largest eigenvalue, each row of the matrix $\boldsymbol{P}$ should be as close to uniform as possible; in other words, the variance of each row of $\boldsymbol{P}$ should be reduced.
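To see this bound and the role of row uniformity in numbers, the sketch below (a purely made-up toy comparison) builds one transition matrix with near-uniform rows and one with highly peaked rows, and prints the row-sum bounds together with the largest eigenvalue of $\boldsymbol{P}\otimes\boldsymbol{P}$ in each case:

```python
import numpy as np

def perron_and_bounds(P):
    """Row-sum bounds and largest eigenvalue of P ∘ P."""
    H = P * P
    lam1 = np.linalg.eigvals(H).real.max()   # Perron root is real for a non-negative matrix
    return H.sum(axis=1).min(), lam1, H.sum(axis=1).max()

rng = np.random.default_rng(6)
n = 8

# Near-uniform rows: row sums of P ∘ P stay close to the theoretical minimum 1/n.
P_flat = np.full((n, n), 1.0 / n) + 0.01 * rng.random((n, n))
P_flat /= P_flat.sum(axis=1, keepdims=True)

# Peaked rows: one transition per row carries almost all the mass.
P_peaked = np.eye(n)[rng.permutation(n)] * 0.9 + 0.1 / n
P_peaked /= P_peaked.sum(axis=1, keepdims=True)

print(perron_and_bounds(P_flat))     # min row sum <= lambda_1 <= max row sum, all near 1/n
print(perron_and_bounds(P_peaked))   # lambda_1 much closer to 1
print(1.0 / n)                       # theoretical minimum of each row sum
```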

How can we reduce the variance? Simply put, excessively high probability values must not appear. For example, if a row is close to one-hot, then after squaring it is still close to one-hot, and its sum will be close to 1, far greater than the theoretical minimum $1/n$. Under what circumstances do excessively high probability values appear? It is not hard to see: when very few candidate words can follow a certain word, or even only one. For example, "忐" (tan) can almost only be followed by "忑" (te), so $P_{i=\text{忐},j=\text{忑}}$ is very high; "矩" (ju) is mostly followed by "阵" (zhen) or "形" (xing), so the variance of the "矩" row is also not small. How can we avoid such excessively high probability values? Simply merge the high-probability transitions into a new word. For example, if "忐忑" (tante) is merged into a single word, then the "忐" row no longer exists, and the large-variance problem disappears with it. Similarly, "矩形" (rectangle) and "矩阵" (matrix) should ideally also be merged into single words.
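Here is a tiny made-up illustration of the merging idea (all numbers are invented for the example; "tan" and "te" stand for 忐 and 忑). Merging the near-deterministic bigram into a single token removes the nearly one-hot row, and the largest row sum of $\boldsymbol{P}\otimes\boldsymbol{P}$, and hence the bound on its largest eigenvalue, drops accordingly:

```python
import numpy as np

# Character-level vocabulary: [tan, te, A, B]; "tan" is almost always followed by "te",
# so its row of P is nearly one-hot. All numbers below are made up for illustration.
P_char = np.array([
    [0.00, 0.98, 0.01, 0.01],   # tan -> te with probability 0.98
    [0.10, 0.00, 0.45, 0.45],   # te
    [0.25, 0.25, 0.25, 0.25],   # A
    [0.25, 0.25, 0.25, 0.25],   # B
])

# Word-level vocabulary after merging: [tante, A, B]; the merged token roughly inherits
# the outgoing distribution of "te", and the nearly one-hot row disappears.
P_word = np.array([
    [0.10, 0.45, 0.45],         # tante
    [1/3,  1/3,  1/3],          # A
    [1/3,  1/3,  1/3],          # B
])

for P in (P_char, P_word):
    H = P * P
    print(H.sum(axis=1).max(), np.linalg.eigvals(H).real.max())
```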

So, in plain language, this tells us that for text generation tasks, using words as units is more reliable than using characters (i.e., less prone to repetition decoding), and that appropriately merging strongly correlated words into new vocabulary entries, so as to reduce the variance of the transition matrix, helps lower the risk of repetition decoding. The original paper even gave this operation a rather grand name, the "Rebalanced Encoding Algorithm", but this is essentially all it means. The fact that our previous word-level WoBERT performs better on generation tasks than character-level BERT can be seen as a verification of this conclusion (refer to "Speed Up Without Dropping: Chinese WoBERT Based on Word-level Granularity").

General Decoding

Is this analysis easy to generalize to general autoregressive models? Unfortunately, it is not. For a general autoregressive model, the transition matrix is effectively different at every step, so as long as the model performs well enough, repetition decoding basically will not occur; indeed, generative models that have been thoroughly pre-trained rarely exhibit repetition decoding. Nevertheless, we can still observe that even with general autoregressive decoding, the repetition phenomenon occasionally appears, especially in models that have not been pre-trained. How do we explain this?

The previous sections were based on the bigram decoding model and concluded that such models are indeed prone to repetition. Perhaps, then, we can think in reverse: the reason a general autoregressive model exhibits repetition decoding is that, at that moment, it has degenerated into a bigram decoding model. For difficult inputs, the model may fail to capture the transition probabilities at each step precisely, so that the step-wise transition matrices degenerate into something close to a fixed bigram transition matrix; this is entirely possible.

How did the original paper handle this part? It was essentially similar. The original paper assumed that the transition matrix of a general autoregressive model is just the bigram transition matrix $\boldsymbol{P}$ plus a time-specific perturbation $\tilde{\boldsymbol{P}}_t=\boldsymbol{P}+\boldsymbol{Q}_t$. It then pointed out that when $\boldsymbol{Q}_t$ is small enough, the gap with the bigram model is also small enough (which is almost a truism), thus the results of the bigram model can represent general autoregressive models. So, for general autoregressive models, we are indeed quite powerless and can only use this idea to establish a link.

Article Summary

This article is an attempt at a theoretical analysis of the repetition decoding phenomenon in Seq2Seq. Most of the space is devoted to obtaining quantitative results for bigram decoding models; these results turn out to explain some observed phenomena and suggest some ideas for improvement. Finally, a somewhat "strained" connection was made between bigram decoding and general autoregressive models. This article was inspired by the paper "A Theoretical Analysis of the Repetition Problem in Text Generation", but the derivations were worked out independently, so the formula definitions differ slightly from the original paper; the overall conclusions, however, are consistent. Readers are invited to judge for themselves; if there are errors, corrections are welcome.