Deriving Model Scaling Laws Based on the Quantization Hypothesis

By 苏剑林 | May 18, 2023

Scaling Law refers to the asymptotic relationship between a model's performance and its scale. Specifically, model performance can be simplified as the model's loss function, while model scale can refer to the number of parameters, the amount of training data, or the number of training steps. Research on Scaling Laws investigates the general relationship between the loss function and variables such as parameters, data volume, and training steps. Experimental results from works like "Scaling Laws for Neural Language Models" and "Training Compute-Optimal Large Language Models" show that the Scaling Laws for neural networks mostly take the form of a "Power Law."

Why do they follow a power law? Can it be explained theoretically? The paper "The Quantization Model of Neural Scaling" provides a very interesting derivation based on a "Quantization" hypothesis. Let's explore it in this article.

Derivation Hypotheses

First, we assume that for a specific task, there exists a "perfect model," and the models we train are approximations of this "perfect model." Furthermore, we assume that the "perfect model" is composed of "Quanta," where each quantum represents a specific capability (note that "quanta" here refers mainly to virtual units of capability, not necessarily specific skills we can name).

To complete a task, multiple capabilities are usually required. Therefore, without loss of generality, we assume the "perfect model" contains infinitely many capability quanta. Different quanta are responsible for solving samples of varying difficulty. Generally, simple samples account for the majority, while difficult samples are in the minority. Thus, these capability quanta can be sorted from highest to lowest frequency of appearance, labeled as $1, 2, \dots, k, \dots$, with corresponding frequencies $p_1, p_2, \dots, p_k, \dots$.

Finally, we assume the frequencies of these capability quanta follow "Zipf's Law", namely: \begin{equation}p_k = \frac{k^{-\gamma - 1}}{Z_{\gamma}}\end{equation} where $\gamma > 0$ and $Z_{\gamma}$ is the normalization factor $\sum_{k=1}^{\infty} k^{-\gamma - 1}$.
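As a quick numerical illustration (a minimal sketch of my own, not code from the paper), the normalization factor $Z_{\gamma}$ can be approximated by truncating the infinite sum, since $k^{-\gamma-1}$ is summable for any $\gamma > 0$:

```python
import numpy as np

# Truncated approximation of the Zipf frequencies p_k = k^(-gamma-1) / Z_gamma.
# The true Z_gamma is an infinite sum; since k^(-gamma-1) is summable for
# gamma > 0, the tail beyond a large cutoff K is negligible.
def zipf_probs(gamma, K=10**6):
    k = np.arange(1, K + 1, dtype=np.float64)
    w = k ** (-gamma - 1.0)
    return w / w.sum()

p = zipf_probs(gamma=0.5)
print(p[:5])     # frequencies of the five most common quanta
print(p.sum())   # 1.0 by construction
```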

Zipf's Law

The reader might ask: why Zipf's Law? Zipf's Law is an empirical law published by the linguist George Zipf in 1949. His original finding was that a word's frequency is approximately inversely proportional to its rank in the frequency table. Later it was generalized to being inversely proportional to a power of the rank, and it has been observed in many different fields.

Zipf himself and some successors attempted to derive Zipf's Law from assumptions closer to the essence of things; related work can be found on Wikipedia and will not be expanded on here. For me, the most important reason for choosing Zipf's Law is more mundane: there weren't many other choices.

Don't forget that $p_k$ is sorted from highest to lowest, so $p_k$ is monotonically decreasing in $k$. What non-negative, monotonically decreasing functions come to mind? Essentially just exponential functions and power-law functions. Exponential functions decay very quickly and thus have no long tail, whereas power laws decay more slowly and are comparatively long-tailed. Choosing between them depends on our prior belief about how much the tail matters. Under the capability-quanta hypothesis, every capability is critical, so we must choose a power law, which leads to Zipf's Law.
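To make the contrast concrete, here is a small illustrative comparison (my own example; the cutoff and constants are arbitrary) of the tail mass left beyond rank $n$ under power-law versus exponential decay:

```python
import numpy as np

k = np.arange(1, 10**6 + 1, dtype=np.float64)
power = k ** -1.5                 # p_k ~ k^(-gamma-1) with gamma = 0.5
power /= power.sum()
expo = 0.5 ** k                   # an exponential decay for comparison
expo /= expo.sum()

for n in (10, 100, 1000):
    # index n onward corresponds to ranks n+1, n+2, ...
    print(n, power[n:].sum(), expo[n:].sum())
# The power-law tail beyond rank 1000 still carries roughly 2% of the total
# mass, while the exponential tail is numerically zero almost immediately.
```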

Basic Results

Back to the topic. Previously, we assumed the ideal model has infinitely many capability quanta. A realistic model with limited capacity can only learn $n$ of them, and to cover as many samples as possible it should learn the first $n$ quanta. Assuming each learned quantum reduces the loss of its corresponding samples from $b$ to $a$, the average loss of the model can be estimated as: \begin{equation}L = a \sum_{k=1}^n p_k + b \sum_{k=n+1}^{\infty} p_k\end{equation} The first $n$ quanta have been learned, so the loss on that portion of the samples is $a$; the remaining quanta have not been learned, so their loss remains $b$. This assumption is admittedly a bit strong (letting $a$ and $b$ depend on $k$ might be more reasonable), but the current result is already representative (see the appendix of the original paper). For the above equation, we can carry out an asymptotic estimate: \begin{equation}\begin{aligned} L =&\, a \sum_{k=1}^{\infty} p_k + (b - a) \sum_{k=n+1}^{\infty} p_k \\ =&\, a + (b - a) \sum_{k=n+1}^{\infty} \frac{k^{-\gamma-1}}{Z_{\gamma}} \\ \sim&\, a + (b - a) \int_n^{\infty} \frac{k^{-\gamma-1}}{Z_{\gamma}} dk \\ =&\, a + \frac{b - a}{\gamma Z_{\gamma}} n^{-\gamma} \\ \end{aligned}\end{equation} This shows that the model's loss relates to the number of capability quanta $n$ through the power law $n^{-\gamma}$. Clearly, $a$ is the minimum value of the loss function; if $a = 0$, then $L \sim \mathcal{O}(n^{-\gamma})$. In what follows, we take $a = 0$.
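The asymptotic above is easy to check numerically (a sketch with arbitrarily chosen $a$, $b$, $\gamma$, and a finite cutoff $K$ standing in for the infinite sums):

```python
import numpy as np

gamma, a, b, K = 0.5, 0.0, 1.0, 10**7   # arbitrary constants; K truncates the sums
k = np.arange(1, K + 1, dtype=np.float64)
w = k ** (-gamma - 1.0)
Z = w.sum()                              # approximates Z_gamma
p = w / Z

for n in (10, 100, 1000):
    exact = a * p[:n].sum() + b * p[n:].sum()      # the two-term loss above
    approx = a + (b - a) / (gamma * Z) * n ** -gamma
    print(n, exact, approx)
# The two columns converge as n grows (the small residual difference
# comes from truncating the sums at K).
```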

Scaling Laws

In the basic result, $n$ is the number of capability quanta learned by the model, which so far is still a virtual concept. Next, we will link it to common variables in models.

Parameter Count: Suppose the model has $N$ parameters and that, on average, learning one capability quantum takes $C$ parameters (with $C$ a constant). Then $n = N/C \propto N$, and: \begin{equation}L \sim \mathcal{O}(N^{-\gamma})\end{equation}

Data Volume: Suppose the training set contains $D$ samples in total. Since we assume different quanta solve samples of different difficulties, it is natural to say that quantum $1$ accounts for $D p_1$ samples, quantum $2$ for $D p_2$, quantum $3$ for $D p_3$, and so on. If learning one quantum requires at least $\tau$ samples, then any quantum with $D p_k < \tau$ cannot be learned. Thus, from $\tau = D p_n = D n^{-\gamma-1}/Z_{\gamma}$, we can solve for $n = (D/(\tau Z_{\gamma}))^{1/(\gamma+1)} \propto D^{1/(\gamma + 1)}$. Substituting this back gives: \begin{equation}L \sim \mathcal{O}(D^{-\gamma/(\gamma + 1)})\end{equation}
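The cutoff condition can also be checked numerically (a sketch; $\tau$ and $\gamma$ are arbitrary):

```python
import numpy as np

gamma, tau = 0.5, 100.0                  # tau: samples needed to learn one quantum
Z = (np.arange(1, 10**7 + 1, dtype=np.float64) ** (-gamma - 1.0)).sum()

# Largest n with D * p_n >= tau, i.e. n = (D / (tau * Z)) ** (1 / (gamma + 1)).
def n_learned(D):
    return (D / (tau * Z)) ** (1.0 / (gamma + 1.0))

for D in (10**6, 10**8, 10**10):
    n = n_learned(D)
    print(D, n, n ** -gamma)             # loss term ~ n^-gamma ~ D^(-gamma/(gamma+1))
# Each 100x increase in D multiplies n by 100**(2/3) ≈ 21.5 and shrinks the
# loss term by a factor of 100**(-1/3) ≈ 0.215, matching the exponent above.
```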

Training Amount: Suppose the model's parameters and the training samples are unlimited, so that the number of quanta $n$ learned depends on the number of training steps $S$. With batch size $B$, each step supplies on average $B p_1$ samples for quantum $1$, $B p_2$ for quantum $2$, and so on, so after $S$ steps of training, quantum $n$ has been trained on $S B p_n$ samples. Again assuming that learning a quantum requires at least $\tau$ samples, from $\tau = S B p_n$ we can solve for $n \propto S^{1/(\gamma + 1)}$ (holding $B$ fixed). Substituting this back gives: \begin{equation}L \sim \mathcal{O}(S^{-\gamma/(\gamma + 1)})\end{equation}
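Putting the three results together, we can recover all the exponents by fitting slopes in log-log space (a sketch with made-up constants $C$, $\tau$, $B$; since the losses are generated from the formulas above, this confirms the algebra rather than testing the hypothesis):

```python
import numpy as np

gamma, C, tau, B = 0.5, 10**4, 100.0, 512         # made-up constants
Z = (np.arange(1, 10**7 + 1, dtype=np.float64) ** (-gamma - 1.0)).sum()
loss = lambda n: n ** -gamma / (gamma * Z)        # the a = 0, b = 1 case

x = np.logspace(6, 10, 20)                        # sweep N, D, or S over 4 decades
slopes = {
    "N": np.polyfit(np.log(x), np.log(loss(x / C)), 1)[0],
    "D": np.polyfit(np.log(x), np.log(loss((x / (tau * Z)) ** (1 / (gamma + 1)))), 1)[0],
    "S": np.polyfit(np.log(x), np.log(loss((x * B / (tau * Z)) ** (1 / (gamma + 1)))), 1)[0],
}
print(slopes)   # expect N: -gamma = -0.5; D and S: -gamma/(gamma+1) = -1/3
```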

As we can see, although all three results are power laws, $\gamma/(\gamma + 1) \in (0, 1)$ and $\gamma > \gamma/(\gamma + 1)$, so the loss decays fastest in the parameter count; in other words, the parameter count has the largest impact on model capability.

Emergence Phenomena

Some readers might ask: can the capability quantization hypothesis be used to explain the "Emergence" phenomenon in large models?

To some extent, yes. Previously, we assumed the perfect model has infinitely many capability quanta. If we change this infinity to a finite number, then as the parameter count grows, the model will eventually be able to cover all capability quanta and reach the theoretically optimal perfect model; this is emergence. Alternatively, even if the perfect model has infinitely many quanta, human "resolution" of intelligence may only span a finite number of them (humans themselves might not be perfect), so once a large model has learned enough capability quanta, it looks, from the human perspective, like a perfect model has "emerged".

Article Summary

This article introduced the process of deriving model Scaling Laws from the quantization hypothesis. Specifically, it looked at the asymptotic relationship between the model's loss function and its parameters, data volume, and training amount, and briefly analyzed its possible connection with emergence phenomena.

