"Behind Closed Doors" Thoughts on Multimodal Approaches (II): Autoregression

By 苏剑林 | July 08, 2024

In this article, we continue to "build wheels behind closed doors" and share some of my recent thoughts on multimodal learning. In the previous post "Thoughts on Multimodal Approaches (I): Lossless Input", we emphasized the importance of lossless input for an ideal multimodal model. If this viewpoint holds, then the current mainstream approaches, which discretize images with VQ-VAE, VQ-GAN, and similar tokenizers, have a performance bottleneck, because a simple entropy calculation shows that discretization inevitably causes severe information loss. A more promising long-term solution is therefore to use continuous features as input, for example by directly "patchifying" the raw pixels of an image before feeding them into the model.

However, while continuous input is naturally simple for image understanding, it introduces additional difficulties for image generation. This is because non-discretized data cannot directly use the autoregressive framework applied to text; it requires incorporating new elements like diffusion. This leads us to the subject of this article—how to perform multimodal autoregressive learning and generation. Of course, non-discretization is only a surface-level difficulty; the more demanding parts lie ahead...

The Meaning of "Lossless"

First, let us clarify the meaning of "lossless." Lossless does not mean that there can be no loss whatsoever throughout the entire computation process. That is unrealistic and contradicts our understanding of the essence of deep learning—as mentioned in the 2015 article "Chatting: Neural Networks and Deep Learning", the key to deep learning's success is information loss. Thus, the meaning of lossless here is simple: we hope that the input to the model, as an initial stage, is as lossless as possible.

The mainstream architecture of current multimodal models is still the Transformer. Many works "preprocess" images before feeding them into the Transformer, for example by simply splitting images into pixel patches, extracting features through a VAE, or discretizing them via Vector Quantization (VQ). What they have in common is that they transform an image from a $w \times h \times 3$ array into an $s \times t \times d$ array (where $s < w, t < h$), which can be broadly termed "Patchify." Different Patchify methods lose different amounts of information, and VQ usually causes the most severe and explicit loss. For example, ByteDance's recent TiTok compresses a $256\times 256$ image into 32 tokens. One doesn't even need to compute an entropy to see the information loss: with a codebook of only 4096 entries, it can represent at most $4096^{32}$ distinct images. Yet there are more than 4096 Chinese characters, so if an image showed a string of 32 Chinese characters, the number of possible strings alone would exceed what this encoding can express.
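
To put rough numbers on the TiTok example, the sketch below compares the most information that 32 tokens from a 4096-entry codebook can carry (32 × log2 4096 = 384 bits) with the size of an uncompressed 8-bit $256\times 256$ RGB image. The raw size is only a storage upper bound, not the true entropy of an image (natural images compress well), and the 5000-character inventory is a hypothetical figure chosen only because it exceeds 4096.

```python
import math

def vq_capacity_bits(num_tokens: int, codebook_size: int) -> float:
    """Upper bound on the information carried by a sequence of VQ token IDs."""
    return num_tokens * math.log2(codebook_size)

def raw_rgb_bits(width: int, height: int, channels: int = 3, bits_per_channel: int = 8) -> int:
    """Bits needed to store an uncompressed image with 8-bit channels."""
    return width * height * channels * bits_per_channel

titok_bits = vq_capacity_bits(32, 4096)  # 32 tokens * 12 bits each
raw_bits = raw_rgb_bits(256, 256)
print(titok_bits, raw_bits, raw_bits / titok_bits)  # 384.0 1572864 4096.0

# Hypothetical character inventory, chosen only to exceed the codebook size.
num_characters = 5000
print(num_characters ** 32 > 4096 ** 32)  # True
```

The ratio says the tokens can carry at most 1/4096 of the raw bit budget; whether that suffices depends on how much of the image's content is task-relevant, which is exactly the point the OCR example below makes.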

If an image has significant information loss before entering the model, it will inevitably limit the model's image understanding capabilities. For instance, if TiTok's 32 tokens are fed into a model, it becomes essentially impossible to perform OCR tasks. Furthermore, this bottleneck in VQ is fundamental; even increasing it to $32\times 32$ tokens rarely yields significant improvements unless the token count reaches the scale of the original RGB pixels, at which point the purpose of VQ disappears. Therefore, to better adapt to various image understanding tasks, the ideal image input method for multimodal models should be continuous features that are as lossless as possible, allowing the model itself to decide what to discard during computation based on context.

The Autoregressive Form

As mentioned at the beginning of this article, continuous features for image understanding are actually a very reasonable and natural approach; they simply introduce extra difficulty for autoregressive (AR) image generation. At this point, readers might ask: why must images be generated using autoregression? Don't we already have better generation methods like diffusion models?

First, "autoregressive model + Teacher Forcing training" is itself a very general learning recipe, a classic form of step-by-step "hand-holding" supervision, so it has ample potential. Second, diffusion models themselves illustrate why autoregression is needed in image generation. Take DDPM: it is essentially an autoregressive model, and in "Talk on Generative Diffusion Models (II): DDPM = Autoregressive VAE" we explicitly called it one. It deconstructs a single image into a sequence $x_T, x_{T-1}, \cdots, x_1, x_0$ and then models $p(x_{t-1}|x_t)$. Its training method is essentially Teacher Forcing (so it also suffers from Exposure Bias). One could say DDPM is not only autoregressive but also the simplest, 2-gram version of it.

In fact, from early works like PixelRNN, PixelCNN, and NVAE to today's popular diffusion models and the trend of training language models on VQ'ed image tokens, the signal being sent is increasingly clear: for images, the question is not whether to use autoregression, but in what way to do it better.

Its role is not only to grant multimodal models image generation capabilities but also to serve as an important unsupervised learning pathway. As my idol Feynman famously said, "What I cannot create, I do not understand." This holds true for large models as well; "if you cannot generate, you cannot understand." Of course, this statement might seem a bit bold, as supervised learning on various image-text pairs seems capable of providing sufficient understanding. However, learning image understanding purely through supervised methods might have limited coverage on one hand, and is restricted by human levels of understanding on the other. Therefore, we need unsupervised generative pre-training to obtain more comprehensive image understanding capabilities, which is consistent with the "Pretrain + SFT" pipeline for text.

Squared Error

Some readers might think: after patchifying an image and ordering it, can't we just predict the next patch like we do with text? Even if the input is continuous features rather than discrete tokens, doesn't it just require swapping cross-entropy loss for squared error? It seems there shouldn't be much difficulty in autoregressive learning for images? While the logic seems sound, the two key issues mentioned—"Patchification Ordering" and the "Loss Function"—are both difficult problems to solve.

In this section, let's look at the loss function problem. Assuming the image has been patchified and ordered in some way, the image becomes a one-dimensional sequence of patches, and autoregressive learning is indeed the prediction of the next patch, as shown below:

The simplest idea for image AR learning: use squared error to predict the next patch

However, the loss function here cannot simply be Mean Squared Error (MSE, i.e., the squared Euclidean or L2 distance), because the distribution assumption behind squared error is the Gaussian distribution:

\begin{equation}\frac{1}{2\sigma^2}\Vert x_t - f(x_{< t})\Vert^2 = -\log \mathcal{N}(x_t;f(x_{< t}),\sigma^2) + \text{constant depending only on } \sigma\end{equation}

That is, the negative log-likelihood of the Gaussian distribution $\mathcal{N}(x_t;f(x_{< t}),\sigma^2)$ is, up to scale and an additive constant, exactly the squared error (with $\sigma$ a constant). So using squared error amounts to assuming that $p(x_t|x_{< t})$ is $\mathcal{N}(x_t;f(x_{< t}),\sigma^2)$. On reflection, this assumption is far from reality. If it held, $x_t$ could be sampled as $x_t = f(x_{< t}) + \sigma\varepsilon$, where $\varepsilon$ is standard Gaussian noise, so every generated patch would be speckled with independent per-pixel noise, which real images clearly are not.
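
The correspondence in the equation above can be checked numerically. This minimal NumPy sketch uses random vectors as hypothetical stand-ins for a flattened patch $x_t$ and a prediction $f(x_{< t})$; it confirms that the Gaussian negative log-density minus its constant equals the scaled squared error, and shows what sampling under that assumption means: the prediction plus independent per-pixel noise.

```python
import numpy as np

def gaussian_nll(x, mean, sigma):
    """Negative log-density of an isotropic Gaussian N(mean, sigma^2 I) evaluated at x."""
    d = x.size
    return 0.5 * np.sum((x - mean) ** 2) / sigma**2 + 0.5 * d * np.log(2 * np.pi * sigma**2)

rng = np.random.default_rng(0)
sigma = 0.3
x = rng.normal(size=16)     # hypothetical flattened patch x_t
pred = rng.normal(size=16)  # hypothetical model prediction f(x_{<t})

scaled_sq_error = np.sum((x - pred) ** 2) / (2 * sigma**2)
constant = 0.5 * x.size * np.log(2 * np.pi * sigma**2)  # depends on sigma and dimension only
print(np.isclose(scaled_sq_error, gaussian_nll(x, pred, sigma) - constant))  # True

# Sampling under the Gaussian assumption: prediction plus white noise on every pixel.
sample = pred + sigma * rng.standard_normal(pred.shape)
```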

Another reader might counter: why must we understand it from the perspective of probabilistic likelihood? Can't I just view it purely as a regression fitting problem? Likely not. Understanding it from a probabilistic perspective serves two main purposes: first, generative modeling ultimately faces sampling, and writing out the probability distribution is necessary to construct a sampling method; second, from a pure regression standpoint, we still need to justify the rationality of squared error, as there are many other losses we could use, such as L1 distance (MAE), Hinge Loss, etc. These losses are not equivalent to each other, nor are they necessarily reasonable (in fact, none of these losses are entirely reasonable, as they are defined from a purely mathematical metric distance perspective and do not completely align with human visual perception).
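
To see concretely that these losses are not interchangeable, consider the best constant prediction for a single pixel under each loss: squared error is minimized by the mean and L1 by the median. For a hypothetical pixel that is dark 30% of the time and bright 70% of the time, the MSE-optimal prediction is a gray value near 0.7 that never actually occurs, which is one way to see why regressing images with squared error tends to produce blur. The data below are synthetic.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic bimodal target: dark (0.0) 30% of the time, bright (1.0) 70%, plus slight jitter.
bright = rng.random(10_000) < 0.7
targets = np.where(bright, 1.0, 0.0) + 0.01 * rng.standard_normal(10_000)

candidates = np.linspace(-0.5, 1.5, 2001)  # grid of constant predictions, step 0.001
mse = [np.mean((targets - c) ** 2) for c in candidates]
mae = [np.mean(np.abs(targets - c)) for c in candidates]

best_mse = candidates[int(np.argmin(mse))]  # lands near the mean (a gray that never occurs)
best_mae = candidates[int(np.argmin(mae))]  # lands near the median (the bright mode)
print(best_mse, best_mae)
```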

The Wonder of Noise

Since it is the nature of the image features themselves that makes squared error unreasonable, the only way out is to change the input format of the image so that its conditional distribution becomes more Gaussian-like. There are currently two concrete schemes to consider.

The first scheme is to encode the image using a pre-trained Encoder. When training the Encoder, regularization terms like the VAE's KL divergence are usually added to reduce variance. To put it more intuitively, the features are compressed near a sphere (refer to "An Attempt to Understand VAE from a Geometric Perspective"). Using these features as image input makes the assumption that $p(x_t|x_{< t})$ is a Gaussian distribution more reasonable, allowing us to use squared error for autoregressive training. After training, we also need to train a separate Decoder to decode the sampled image features back into an image. This is roughly the approach adopted by Emu2. The downside is that the pipeline seems too long and not sufficiently end-to-end.

The second scheme might surprise many: adding noise. This is the author's own "behind closed doors" idea. We just said that if $p(x_t|x_{< t})$ were truly Gaussian, $x_t$ itself should contain noise, but it doesn't. So, to satisfy this condition, why not simply add some noise ourselves? Adding noise might not make $p(x_t|x_{< t})$ exactly Gaussian, but it can bring it closer, especially when the noise is added progressively, as shown below:

Expanding each Patch by adding noise to make squared error a viable loss function
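
A minimal sketch of this expansion, assuming a DDPM-style forward chain $x_k = \sqrt{1-\beta_k}\,x_{k-1} + \sqrt{\beta_k}\,\varepsilon$ run separately on each flattened patch: every patch becomes $T-1$ noisy versions, ordered from noisiest to least noisy, followed by the clean patch itself. The schedule `betas`, the patch size, and the function names are illustrative assumptions, not part of the original design.

```python
import numpy as np

def noisy_chain(patch, betas, rng):
    """Run x_k = sqrt(1 - b_k) x_{k-1} + sqrt(b_k) eps on one patch, starting from the clean patch.

    Returns the len(betas) noisy versions ordered noisiest first, i.e. in generation order.
    """
    versions = []
    x = patch
    for b in betas:
        x = np.sqrt(1.0 - b) * x + np.sqrt(b) * rng.standard_normal(patch.shape)
        versions.append(x)
    return versions[::-1]

def expand_sequence(patches, betas, seed=0):
    """Turn [p_1, ..., p_N] into [noisy(p_1)..., p_1, noisy(p_2)..., p_2, ...]."""
    rng = np.random.default_rng(seed)
    seq = []
    for p in patches:
        seq.extend(noisy_chain(p, betas, rng))
        seq.append(p)  # the clean patch closes each group of T entries
    return np.stack(seq)

T = 5
betas = np.linspace(0.1, 0.5, T - 1)                # hypothetical noise schedule
patches = np.random.default_rng(1).random((3, 48))  # 3 hypothetical flattened patches
seq = expand_sequence(patches, betas)
print(seq.shape)  # (15, 48)
```

The sequence is $T$ times longer than the patch sequence, which is the source of the computational cost discussed under "Efficiency Issues".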

Readers familiar with diffusion models will naturally think: constructing a progressive sequence by adding noise and then training a recursive denoising model with squared error as the loss, isn't that just a diffusion model? Exactly. The core idea of diffusion models is "using progressive noise addition to make squared error a reasonable loss function," and the scheme above borrows this idea. Of course, the differences from conventional diffusion models are obvious: diffusion models add noise to the entire image, while here it is applied patch by patch; diffusion models model $p(x_{t-1}|x_t)$, conditioning only on the previous step, while here we model $p(x_t|x_{< t})$, conditioning on everything generated so far; and so on. In the end, what is proposed here is a scheme for autoregressive image learning that incorporates diffusion.

Efficiency Issues

Using noise to extend the sequence and make naive squared error usable—aligning the image's AR learning with our initial design (just one extra noise step in the input)—is undoubtedly a very comforting result. However, things are not that optimistic. This scheme has at least two major problems, both of which can be summarized under one word: efficiency.

First is learning efficiency. We discussed this in the first article of our diffusion series, "Talk on Generative Diffusion Models (I): DDPM = Demolition + Construction". The gist is that predicting the step-$(t-1)$ noisy image from the step-$t$ noisy image requires sampling noise twice, once to construct the target and once more to construct the input, which raises the variance of training, so more training steps are needed to average it out. After applying a series of variance-reduction techniques, we find that it is more efficient to predict the original image directly (or, equivalently, the noise that was added to it):

Directly predicting the original image is more efficient than predicting the next step's noise map
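
For reference, this is how an $x_0$-prediction is typically turned back into a single denoising step in standard DDPM: the predicted clean image $\hat{x}_0$ is plugged into the Gaussian posterior $q(x_{t-1}|x_t, x_0)$, whose mean is $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}\hat{x}_0 + \frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}x_t$. The sketch uses 0-based step indices and leaves the schedule to the caller; it illustrates the textbook DDPM formula rather than any specific model in this article.

```python
import numpy as np

def posterior_step(x_t, x0_hat, t, betas):
    """Mean and variance of the DDPM posterior q(x_{t-1} | x_t, x_0), with x_0 := x0_hat.

    Steps are 0-based: step t uses betas[t], and abar[t] = prod_{s <= t} (1 - betas[s]).
    At t = 0 the previous state is the clean image itself, so abar_prev = 1.
    """
    alphas = 1.0 - betas
    abar = np.cumprod(alphas)
    abar_prev = abar[t - 1] if t > 0 else 1.0
    coef_x0 = np.sqrt(abar_prev) * betas[t] / (1.0 - abar[t])
    coef_xt = np.sqrt(alphas[t]) * (1.0 - abar_prev) / (1.0 - abar[t])
    mean = coef_x0 * x0_hat + coef_xt * x_t
    var = betas[t] * (1.0 - abar_prev) / (1.0 - abar[t])
    return mean, var
```

Because the model is only asked for $\hat{x}_0$ (a single regression target per noisy input), the second noise sample needed to build a step-$(t-1)$ target disappears from training.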

Some readers might be confused: didn't we just say that the original image contains no noise and does not follow a Gaussian distribution, so squared error cannot be used? This is indeed hard to explain intuitively. Roughly speaking, minimizing squared error trains the model to output the conditional mean $\mathbb{E}[x_0|x_t]$, and because the noising process is Gaussian, this mean is exactly what is needed to construct the Gaussian denoising step; squared error thus stays usable thanks to a convenient property of the Gaussian distribution. A more formal explanation can be found in "Talk on Generative Diffusion Models (III): DDPM = Bayes + Denoising" and "Talk on Generative Diffusion Models (IV): DDIM = DDPM from a High Perspective".

Second is the issue of computational efficiency. This is easy to understand: if each patch becomes $T$ patches through noise addition, the sequence length becomes $T$ times longer. This significantly increases both training and inference costs. Furthermore, theoretically, the noisy patches do not substantially help with image understanding; in principle, keeping only the clean, non-noisy patch can achieve the same effect. In other words, this scheme contains a large amount of redundant input and computation for image understanding.
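
The cost growth is easy to quantify. Assuming a hypothetical $16\times 16$ grid of patches and $T=10$, the sequence grows $T$-fold, while the number of query-key pairs in full self-attention grows $T^2$-fold (position-wise layers such as the MLP grow only $T$-fold).

```python
def attention_pairs(seq_len: int) -> int:
    """Number of query-key pairs scored by full self-attention (quadratic in length)."""
    return seq_len * seq_len

num_patches = 256  # hypothetical: a 16 x 16 grid of clean patches
T = 10             # hypothetical: each patch expanded into T entries

base = attention_pairs(num_patches)
expanded = attention_pairs(num_patches * T)
print(num_patches * T, expanded // base)  # 2560 100
```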

There are two ways to solve this problem, which we will introduce one by one.

Decoupled Diffusion

If we must solve this within a single Transformer, we can add an Attention Mask with two parts: 1. Diffusion theory and practice tell us that to predict the next (less noisy) version of a patch, its current noisy version is sufficient and earlier versions can be ignored, so different noisy versions of the same patch don't need to attend to each other; 2. To reduce redundancy, when predicting other patches and subsequent text tokens, the model only needs to attend to the clean patches. The resulting Attention Mask looks roughly like this:

An Attention Mask designed for model simplification and redundancy reduction

Since this Attention Mask has a fixed sparse pattern, there is significant room for speedup. Moreover, because the attention for noisy patches is independent, we don't have to calculate all $T-1$ noisy patches at once during training; we can just sample a portion for calculation. Of course, this is only a prototype; in practice, some details need careful consideration, such as the fact that noisy patches have almost no correlation, so their position encodings might need separate designs, etc. We won't expand on that here.
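
One possible reading of this mask can be written down explicitly. The sketch assumes each patch occupies $T$ consecutive positions, its $T-1$ noisy versions followed by its clean version, and encodes the two rules above: a clean patch attends to the clean patches up to and including its own, while a noisy patch attends only to earlier clean patches and to itself. Text tokens and position encodings are omitted; the layout and the name `build_mask` are illustrative assumptions.

```python
import numpy as np

def build_mask(num_patches: int, T: int) -> np.ndarray:
    """Boolean attention mask (True = query row may attend to key column).

    Assumed layout per patch: [noisy_{T-1}, ..., noisy_1, clean], repeated num_patches times.
    """
    L = num_patches * T
    patch_id = np.repeat(np.arange(num_patches), T)           # patch index of each position
    is_clean = np.tile(np.arange(T) == T - 1, num_patches)    # last slot of each group is clean

    q_patch = patch_id[:, None]
    k_patch = patch_id[None, :]
    k_clean = is_clean[None, :]

    # Clean queries: clean keys from this patch and all earlier patches.
    clean_rows = k_clean & (k_patch <= q_patch)
    # Noisy queries: clean keys from strictly earlier patches, plus the token itself.
    noisy_rows = (k_clean & (k_patch < q_patch)) | np.eye(L, dtype=bool)
    return np.where(is_clean[:, None], clean_rows, noisy_rows)

print(build_mask(2, 3).astype(int))
# [[1 0 0 0 0 0]
#  [0 1 0 0 0 0]
#  [0 0 1 0 0 0]
#  [0 0 1 1 0 0]
#  [0 0 1 0 1 0]
#  [0 0 1 0 0 1]]
```

Each noisy row depends only on earlier clean patches and itself, so its noisy-patch rows can be sampled independently during training, as the paragraph above notes.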

If we allow two different models to be concatenated (but still trainable end-to-end), we can separate the diffusion model. The Transformer is only responsible for processing patches without noise, and the Transformer's output serves as the condition for the diffusion model, as shown below:

Separating the diffusion model, using the Transformer output as the conditional input for diffusion

This is roughly the scheme proposed in Kaiming’s new work "Autoregressive Image Generation without Vector Quantization", but it was proposed earlier in "Denoising Autoregressive Representation Learning". The benefit is that it keeps the Transformer part pure and elegant while saving computation because: 1) The separated diffusion model can be made smaller; 2) For the diffusion part, we can follow the standard training strategy of sampling only one noise step per calculation. From the perspective of the loss function, it uses an extra diffusion model as the loss for predicting the next patch, thereby overcoming the disadvantages of squared error.
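
A minimal sketch of the per-patch diffusion loss in this decoupled setup, in the spirit of the scheme described rather than the papers' actual code: for each target patch one noise level is sampled, the patch is noised, and a small denoiser conditioned on the Transformer output $z$ predicts the added noise. `denoiser` is a hypothetical callable (in practice a small network trained jointly with the Transformer); the linear stand-in below is untrained and only shows the shapes.

```python
import numpy as np

def diffusion_loss(x0, z, denoiser, alphas_bar, rng):
    """Noise-prediction diffusion loss for a batch of target patches.

    x0: (N, d) ground-truth next patches; z: (N, c) Transformer outputs used as conditions.
    One noise level per patch is sampled, as in standard diffusion training.
    denoiser(x_t, t, z) -> (N, d) is a hypothetical small network predicting the noise.
    """
    n = x0.shape[0]
    t = rng.integers(0, len(alphas_bar), size=n)  # one random step per patch
    a = alphas_bar[t][:, None]
    eps = rng.standard_normal(x0.shape)
    x_t = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps
    eps_hat = denoiser(x_t, t, z)
    return float(np.mean(np.sum((eps - eps_hat) ** 2, axis=1)))

rng = np.random.default_rng(0)
alphas_bar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))  # hypothetical schedule
x0 = rng.standard_normal((4, 16))  # 4 hypothetical target patches
z = rng.standard_normal((4, 32))   # 4 hypothetical Transformer outputs
W = 0.01 * rng.standard_normal((32, 16))

def toy_denoiser(x_t, t, z):
    """Untrained linear placeholder that only looks at the condition z."""
    return z @ W

loss = diffusion_loss(x0, z, toy_denoiser, alphas_bar, rng)
print(loss)
```

Since gradients of this loss flow through $z$ back into the Transformer, the two models remain trainable end to end, and the squared error is applied to Gaussian noise rather than to the raw patch.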

Generation Direction

Earlier, we mentioned two key aspects of autoregressive learning for images: "Patchification Ordering" and the "Loss Function." We have just spent four sections barely sorting out the loss function, and that only gets us to the threshold. Next comes an even more pessimistic result: for "Patchification Ordering," we can hardly even find the threshold.

Ultimately, "Patchification Ordering" defines the generation order and direction for AR learning. It involves two steps: "Patchifying" and "Ordering." Narrowly defined, Patchify is a simple reshaping and transposition of the pixel array: a $w \times h \times 3$ array is reshaped into $s \times (w/s) \times t \times (h/t) \times 3$, then transposed to $s \times t \times (w/s) \times (h/t) \times 3$, and finally reshaped to $s \times t \times (3wh/st)$. Broadly speaking, though, Patchify refers to any scheme that turns a $w \times h \times 3$ array into an $s \times t \times d$ array, such as Stable Diffusion's Encoder encoding images into Latents, or various VQ-Tokenizers turning images into discrete IDs. All of these are broad forms of Patchify.
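
The narrow-sense Patchify described above is just two reshapes and a transpose. A minimal NumPy sketch, where `s` and `t` are the numbers of patches along the $w$ and $h$ axes (assumed to divide them evenly); the function name is illustrative.

```python
import numpy as np

def patchify(img: np.ndarray, s: int, t: int) -> np.ndarray:
    """Narrow-sense Patchify: (w, h, c) -> (s, t, c*w*h/(s*t)) via reshape + transpose."""
    w, h, c = img.shape
    if w % s or h % t:
        raise ValueError("w and h must be divisible by s and t")
    x = img.reshape(s, w // s, t, h // t, c)  # split each axis into (patch index, offset)
    x = x.transpose(0, 2, 1, 3, 4)            # -> (s, t, w/s, h/t, c)
    return x.reshape(s, t, -1)                # flatten each patch into one vector

img = np.arange(8 * 6 * 3).reshape(8, 6, 3)   # toy 8 x 6 RGB "image"
print(patchify(img, 4, 3).shape)  # (4, 3, 12)
```

Patch `[i, j]` is exactly the pixel block `img[i*w/s:(i+1)*w/s, j*h/t:(j+1)*h/t]` flattened, so no information is lost; only the layout changes.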

"Ordering" is easier to understand. We know images have two dimensions: length and width. Most Patchify methods produce features that retain this 2D property, while AR generation is unidirectional, so a generation order must be specified. Common orders include: 1) Left-to-right then top-to-bottom; 2) Spiral from center to edges; 3) "Z-shape" starting from the top-left corner, etc. These ordering designs have been around for a long time, tracing back to the first generation of image AR models that operated directly on pixels (e.g., PixelRNN/PixelCNN).

Naive Patchification and two different ordering methods
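
The orderings listed above can be sketched as functions that return grid coordinates (row, column) in generation order. Two interpretive assumptions: the "Z-shape" is implemented as the Morton (Z-order) curve, which here requires a square power-of-two grid, and the center-outward spiral is obtained by reversing a clockwise inward spiral. The function names are hypothetical.

```python
def raster_order(n_rows, n_cols):
    """Left to right within a row, rows from top to bottom."""
    return [(r, c) for r in range(n_rows) for c in range(n_cols)]

def morton_order(n):
    """Z-order (Morton) curve on an n x n grid, n a power of two."""
    def key(rc):
        r, c = rc
        k = 0
        for b in range(n.bit_length()):
            k |= ((c >> b) & 1) << (2 * b)      # column bits in even positions
            k |= ((r >> b) & 1) << (2 * b + 1)  # row bits in odd positions
        return k
    return sorted(raster_order(n, n), key=key)

def spiral_out_order(n_rows, n_cols):
    """Center-outward spiral, built by reversing a clockwise inward spiral."""
    top, bottom, left, right = 0, n_rows - 1, 0, n_cols - 1
    inward = []
    while top <= bottom and left <= right:
        inward += [(top, c) for c in range(left, right + 1)]
        inward += [(r, right) for r in range(top + 1, bottom + 1)]
        if top < bottom:
            inward += [(bottom, c) for c in range(right - 1, left - 1, -1)]
        if left < right:
            inward += [(r, left) for r in range(bottom - 1, top, -1)]
        top, bottom, left, right = top + 1, bottom - 1, left + 1, right - 1
    return inward[::-1]

print(morton_order(4)[:8])
# [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (0, 3), (1, 2), (1, 3)]
print(spiral_out_order(3, 3))
# [(1, 1), (1, 0), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (0, 1), (0, 0)]
```

All three are permutations of the same grid cells; they differ only in which neighbors are available as context when each patch is generated.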

In short, "Patchification Ordering" is the process of deconstructing an image into a 1D sequence for AR learning, i.e., converting 2D data into a 1D sequence. In the broadest sense, the sequences of images at different noise levels constructed by diffusion models, as well as the multi-scale sequences in "Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction", all fall into this category. Having listed several ways to deconstruct images, the natural question is: which scheme is better, and on what basis do we judge?

World Models

To answer this, we must first understand what the fundamental difficulty of visual generation is. In "Thoughts on Multimodal Approaches (I): Lossless Input", we briefly mentioned that the difficulty of image generation lies in modeling continuous probability distributions. But this is actually a very superficial judgment. If that were all, the situation would be far more optimistic, since we have already developed many continuous generative models such as diffusion models. In reality, the difficulty runs much deeper than we imagined...

The images we discuss can generally be divided into two types: those created by humans and those captured by cameras. Since the proliferation of cameras and phones, images on the internet are predominantly photos. Thus, image generation is essentially equivalent to photo generation. What is a photo? It is a record of light, a projection of the 3D world onto a 2D plane. And what is light? Light is an electromagnetic wave, and electromagnetic waves are solutions to Maxwell's equations! From this reflection, we discover an undeniable fact: a real natural photo is essentially a solution to Maxwell's equations. This means that perfect image generation inevitably touches upon the laws of physics—the origin of the world that countless theoretical physicists tirelessly pursue!

Coincidentally, since the appearance of Sora, we often evaluate the quality of model-generated videos based on "whether they conform to real-world physical laws." In fact, even seemingly simpler image generation can be evaluated on the dimension of "conforming to physical laws," such as the distribution of light and shadow, except that video adds dynamics on top of optics (electromagnetism). Following this chain of thought, it becomes increasingly shocking, because it implies that a perfect visual generation model is actually numerically simulating various physical laws. Or more exaggeratedly, it is simulating the evolution of the entire world, the entire universe. It is essentially a World Model. This isn't just "hell-level" difficulty; it's "Creation-level" difficulty.

Some readers might object: Maxwell's equations are hard, but weren't they discovered by humans? We have discovered even harder physical laws, such as quantum mechanics and general relativity, and are constantly approaching a final theory (the Theory of Everything). So the difficulty doesn't seem that high? No, let's not get confused. Even if we can discover perfectly correct physical laws, being able to "numerically simulate using those laws" is a completely different matter. For example, we can write an equation by hand but might not be able to solve it manually. Thus, our ability to discover physical laws does not mean we can use them to deduce or simulate the real world. To be more specific, we can stand here and say the essence of a photo is a solution to Maxwell's equations, but no one can hand-draw a photo from scratch.

(Note: The above series of reflections originated from the idea that "an image is essentially a solution to Maxwell's equations," which was shared with me by my leader, Zhou Xinyu, during a technical exchange. When I first heard this seemingly absurd but undeniably true viewpoint, I was both astonished and shocked. In an instant, I felt I had gained clarity on the fundamental difficulty of multimodal models.)

Human Values

In summary, the point this brainstorm wants to convey is that a perfect visual generation model is a true "World Model," and its difficulty is at the level of creation. So, do we want to create a world? Do we have the ability to create a world? I believe the answer is negative for the foreseeable future; after all, that would feel like using human power to fight the entire universe. Therefore, the key is to abandon the concept of "perfection," much like a human cannot hand-draw a photo but can still paint and convey information through drawing. Or, for instance, when we evaluate whether a model-generated video follows physical laws, we aren't actually measuring the trajectory and plugging it into physical formulas; we are simply observing with our eyes and using our intuitive perception of physical laws.

To put it bluntly, loss is acceptable, as long as it is lossless relative to human values. What does this have to do with the "Patchification Ordering" mentioned earlier? We stated from the beginning that AR learning is intended not only to grant the model generation capabilities but also to serve as an unsupervised learning pathway to improve understanding (if you can't generate, you can't understand). If we had a model with truly infinite fitting capability (Creation capability), then all "Patchification Orderings" would be equivalent, because the precise joint distribution does not depend on the decomposition method of random variables. But unfortunately, we do not, so we must make trade-offs.
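
The claim that an exact joint distribution does not depend on how it is factorized is just the chain rule. The sketch below checks it exactly, with rational arithmetic, on a hypothetical joint distribution over three binary variables: multiplying the exact conditionals in any of the $3! = 6$ orders recovers the same probability. Differences between orderings only appear once the conditionals are approximated by a model of finite capacity.

```python
import itertools
from fractions import Fraction

# Hypothetical joint distribution over three binary variables (weights sum to 36).
weights = [1, 2, 3, 4, 5, 6, 7, 8]
joint = {xs: Fraction(w, 36) for xs, w in zip(itertools.product([0, 1], repeat=3), weights)}

def marginal(assign):
    """P(X_i = v for every (i, v) in assign), summing the joint over the other variables."""
    return sum(p for xs, p in joint.items() if all(xs[i] == v for i, v in assign.items()))

def chain_rule(xs, order):
    """Product of exact conditionals p(x_{o1}) p(x_{o2} | x_{o1}) ... in the given order."""
    prob, seen = Fraction(1), {}
    for i in order:
        prev = marginal(seen)
        seen[i] = xs[i]
        prob *= marginal(seen) / prev
    return prob

print(chain_rule((1, 0, 1), (2, 0, 1)))  # 1/6
```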

Note that we hope to promote understanding through learning generation; this "understanding" must align with human visual understanding. This is the purpose of training AI. However, of the "Patchification Ordering" methods listed earlier, none conform to how humans visually understand things. More directly, human understanding of an image doesn't go from left to right or top to bottom, nor from center outward or in a "Z-shape." Humans don't even understand images in units of patches, nor in a way like progressive denoising in Diffusion. If we use existing "Patchification Ordering" schemes for AR learning, we might grant the model some level of visual generation capability, but since these methods of deconstructing images do not align with human visual understanding modes, it's hard to believe this AR learning can promote the model's visual understanding—more accurately, its ability to mimic human visual understanding.

The main reason for this difficulty is that an image is a "result" and does not contain a "process." Take human-created images for example: whether hand-drawn or photoshopped, the process is a series of operations, but what is finally presented is an image that hides its creation process. This is different from text. Although we don't know how an author conceived a passage, we know most people write from left to right, so the text itself already records its creation (writing) process. But what about a painting? Looking at a finished painting, do we know which part or stroke the painter drew first? Not at all. This is also why most people can write yet cannot reproduce a painting. Thinking further, it seems humans are better at imitating along the temporal dimension than along the spatial dimensions, because there is only one temporal dimension but three spatial ones, which gives the latter far too many degrees of freedom.

It seems there is only one very "compromised" method to solve this: use as many image-text pairs as possible that align with human values (and other valuable supervisory signals for images) to train a "Patchification Ordering" model in a supervised manner. Note that unlike common Vision Encoders, this model should directly output a 1D sequence instead of retaining the 2D nature of the image, thus eliminating the ordering step. The model design could refer to TiTok (this is the third time we've mentioned TiTok); it essentially uses Cross Attention to transform 2D sequences into 1D. Alternatively, a Q-Former could achieve a similar effect. In short, there isn't much difficulty in model design; the core work becomes the data engineering of image-text pairs.
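
A minimal single-head sketch of the Cross Attention idea: $K$ learnable query vectors attend over the flattened $s \times t$ grid of patch features and produce a 1D sequence of $K$ tokens directly. The weights are random placeholders, and normalization, MLP layers, and multiple heads are omitted; this is not TiTok's actual architecture. Without position embeddings the output is invariant to how the grid is flattened, so a real model would add them to keep spatial information.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attend_to_1d(patch_grid, queries, Wq, Wk, Wv):
    """Map a 2D grid of patch features (s, t, d) to a 1D sequence (K, d_v) via K queries."""
    s, t, d = patch_grid.shape
    tokens = patch_grid.reshape(s * t, d)               # keys/values: all grid cells
    q, k, v = queries @ Wq, tokens @ Wk, tokens @ Wv
    attn = softmax(q @ k.T / np.sqrt(k.shape[-1]))      # (K, s*t), each row sums to 1
    return attn @ v

rng = np.random.default_rng(0)
grid = rng.standard_normal((8, 8, 64))  # hypothetical 8x8 grid of 64-dim patch features
K = 32                                  # output length, echoing TiTok's 32 tokens
queries = rng.standard_normal((K, 64))  # stand-ins for learnable queries
Wq, Wk, Wv = (rng.standard_normal((64, 64)) / 8 for _ in range(3))
seq = cross_attend_to_1d(grid, queries, Wq, Wk, Wv)
print(seq.shape)  # (32, 64)
```

The order of the $K$ output tokens is fixed by the queries rather than by any scan over the grid, which is what lets such an encoder skip the separate "Ordering" step.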

But then, it's hard to say how far the model can go. We originally hoped that AR learning would promote the model's understanding ability, but now AR learning itself depends on an Encoder trained through understanding tasks. Ideally, these two models would mutually reinforce and evolve together; but in a non-ideal scenario, the model's capability would be limited by the quantity and quality of supervised training data for the Encoder, failing to achieve true unsupervised learning.

Article Summary

This article continued to "build wheels behind closed doors" with some thoughts on multimodal learning, centered around the autoregressive learning of vision. The main points are:

  1. Autoregressive learning grants models generation capability and serves as an unsupervised pathway to promote understanding through generation.
  2. For images, the question is not whether to do AR, but in what way to do it better.
  3. When images are input as continuous features, AR learning faces two major challenges: Patchification Ordering and the Loss Function.
  4. The loss function cannot simply be squared error; instead, one can consider appending a small diffusion model to predict the next patch.
  5. "Patchification Ordering" is the fundamental challenge of image AR learning; its choice determines whether AR learning can truly promote understanding.
  6. Perfect image/visual generation inevitably involves physical laws, forming a "World Model."
  7. However, World Models are difficult to achieve, so choosing a "Patchification Ordering" that aligns with human values is especially important.
  8. Finally, it seems we must "compromise" by using supervised learning to obtain a "Patchification Ordering" model aligned with human values.

There may be many "bold statements" and "fallacies" here; readers are encouraged to judge for themselves and bear with me. The main purpose of writing down these thoughts is so that one day in the future, I can look back and see which parts of my original ideas were feasible and which were laughable.