Reflections from Spectral Norm Gradients to a New Type of Weight Decay

By Su Jianlin (苏剑林) | December 25, 2024

In the article "Appreciating the Muon Optimizer: A Fundamental Leap from Vectors to Matrices," we introduced a new optimizer named "Muon." One way to understand it is as steepest gradient descent under spectral norm regularization, which seems to reveal a more fundamental optimization direction for matrix parameters. As is well known, we routinely apply weight decay to matrix parameters, and this weight decay can be understood as the gradient of the squared $F$-norm. From Muon's perspective, would a new weight decay built from the gradient of the squared spectral norm yield better results?

The question then arises: what does the gradient—or rather, the derivative—of the spectral norm look like? And what would a new weight decay designed with it look like? Next, we will explore these questions.

Basic Review

The spectral norm, also known as the "$2$-norm," is one of the most commonly used matrix norms. Compared to the simpler $F$-norm (Frobenius norm), it often captures more fundamental properties related to matrix multiplication, because its very definition is tied to matrix multiplication: for a matrix parameter $\boldsymbol{W} \in \mathbb{R}^{n \times m}$, its spectral norm is defined as

\begin{equation}\Vert\boldsymbol{W}\Vert_2 \triangleq \max_{\Vert\boldsymbol{x}\Vert=1} \Vert\boldsymbol{W}\boldsymbol{x}\Vert\end{equation}

Here $\boldsymbol{x} \in \mathbb{R}^m$ is a column vector, and $\Vert\cdot\Vert$ on the right-hand side is the vector (Euclidean) norm. Equivalently, the spectral norm is the smallest constant $C$ such that the following inequality holds for all $\boldsymbol{x} \in \mathbb{R}^m$:

\begin{equation}\Vert\boldsymbol{W}\boldsymbol{x}\Vert \leq C\Vert\boldsymbol{x}\Vert\end{equation}

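It is not difficult to prove that the inequality also holds when $C$ is taken to be the $F$-norm $\Vert \boldsymbol{W} \Vert_F$: applying the Cauchy-Schwarz inequality to each row $\boldsymbol{w}_i^{\top}$ of $\boldsymbol{W}$ gives $\Vert\boldsymbol{W}\boldsymbol{x}\Vert^2 = \sum_i (\boldsymbol{w}_i^{\top}\boldsymbol{x})^2 \leq \sum_i \Vert\boldsymbol{w}_i\Vert^2\Vert\boldsymbol{x}\Vert^2 = \Vert\boldsymbol{W}\Vert_F^2\Vert\boldsymbol{x}\Vert^2$. Therefore $\Vert \boldsymbol{W}\Vert_2 \leq \Vert \boldsymbol{W}\Vert_F$, because $\Vert \boldsymbol{W}\Vert_F$ is merely one $C$ that makes the inequality hold, while $\Vert \boldsymbol{W}\Vert_2$ is the smallest such $C$. This also indicates that if we want to control the magnitude of the output, regularizing the spectral norm is more precise than regularizing the $F$-norm.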
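As a quick numerical sanity check of these definitions, here is a minimal NumPy sketch (NumPy, the random seed, and the $6\times 4$ shape are illustrative assumptions, not part of the original). It approximates the maximum in the definition by sampling random unit vectors, then compares the result with the largest singular value and with the $F$-norm.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 6, 4                              # arbitrary illustrative shape
W = rng.standard_normal((n, m))

spec = np.linalg.norm(W, 2)              # spectral norm (largest singular value)
fro = np.linalg.norm(W, "fro")           # Frobenius norm
sigma_1 = np.linalg.svd(W, compute_uv=False)[0]

# Approximate the max in the definition by sampling many random unit vectors x
X = rng.standard_normal((m, 10000))
X /= np.linalg.norm(X, axis=0, keepdims=True)
gains = np.linalg.norm(W @ X, axis=0)    # ||W x|| for each sampled unit x

# Expected ordering: gains.max() <= spec (== sigma_1) <= fro
print(gains.max(), spec, sigma_1, fro)
```

The sampled maximum can only approach $\Vert\boldsymbol{W}\Vert_2$ from below; it never exceeds it.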

As early as six years ago, in "Lipschitz Constraints in Deep Learning: Generalization and Generative Models," we discussed the spectral norm in two application scenarios: first, WGAN explicitly proposed a Lipschitz constraint for the discriminator, and one way to implement it was normalization based on the spectral norm; second, some work showed that spectral norm regularization performs better than $F$-norm regularization.

Gradient Derivation

Now let us get to the main point and attempt to derive the gradient of the spectral norm $\nabla_{\boldsymbol{W}} \Vert\boldsymbol{W}\Vert_2$. We know that the spectral norm is numerically equal to its largest singular value, as we proved in the "Matrix Norms" section of "The Road to Low-Rank Approximation (II): SVD." This means that if $\boldsymbol{W}$ can be decomposed via SVD as $\sum\limits_{i=1}^{\min(n,m)}\sigma_i \boldsymbol{u}_i\boldsymbol{v}_i^{\top}$, then

\begin{equation}\Vert\boldsymbol{W}\Vert_2 = \sigma_1 = \boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1\end{equation}

where $\sigma_1 \geq \sigma_2 \geq \cdots \geq \sigma_{\min(n,m)} \geq 0$ are the singular values of $\boldsymbol{W}$. Taking the differential of both sides, we get

\begin{equation}d\Vert\boldsymbol{W}\Vert_2 = d\boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1 + \boldsymbol{u}_1^{\top}d \boldsymbol{W}\boldsymbol{v}_1 + \boldsymbol{u}_1^{\top}\boldsymbol{W}d\boldsymbol{v}_1\end{equation}

Note that

\begin{equation}d\boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1 = d\boldsymbol{u}_1^{\top}\sum_{i=1}^{\min(n,m)}\sigma_i \boldsymbol{u}_i\boldsymbol{v}_i^{\top}\boldsymbol{v}_1 = d\boldsymbol{u}_1^{\top}\sigma_1 \boldsymbol{u}_1 = \frac{1}{2}\sigma_1 d(\Vert\boldsymbol{u}_1\Vert^2)=0\end{equation}

Similarly, $\boldsymbol{u}_1^{\top}\boldsymbol{W}d\boldsymbol{v}_1=0$, so

\begin{equation}d\Vert\boldsymbol{W}\Vert_2 = \boldsymbol{u}_1^{\top}d\boldsymbol{W}\boldsymbol{v}_1 = \text{Tr}((\boldsymbol{u}_1 \boldsymbol{v}_1^{\top})^{\top} d\boldsymbol{W}) \quad\Rightarrow\quad \nabla_{\boldsymbol{W}}\Vert\boldsymbol{W}\Vert_2 = \boldsymbol{u}_1 \boldsymbol{v}_1^{\top}\end{equation}

Note that this proof relies on a key condition: $\sigma_1 > \sigma_2$. If $\sigma_1 = \sigma_2$, $\Vert\boldsymbol{W}\Vert_2$ can be written both as $\boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1$ and as $\boldsymbol{u}_2^{\top}\boldsymbol{W}\boldsymbol{v}_2$, and the same method would yield the gradients $\boldsymbol{u}_1 \boldsymbol{v}_1^{\top}$ and $\boldsymbol{u}_2 \boldsymbol{v}_2^{\top}$, respectively. Since the result is not unique, the gradient does not exist at such points. In practice, however, the probability that the two largest singular values are exactly equal is negligible, so this case can be ignored.

(Note: The proof process here refers to an answer on Stack Exchange, but that answer did not prove $d\boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1=0$ and $\boldsymbol{u}_1^{\top}\boldsymbol{W}d\boldsymbol{v}_1=0$; these parts were completed by the author.)
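The closed-form gradient $\nabla_{\boldsymbol{W}}\Vert\boldsymbol{W}\Vert_2 = \boldsymbol{u}_1\boldsymbol{v}_1^{\top}$ can be checked against central finite differences. The sketch below assumes NumPy and a random $5\times 3$ matrix, for which $\sigma_1 > \sigma_2$ holds with probability one; the step size `eps` is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((5, 3))

# Closed-form gradient from the derivation: u_1 v_1^T
U, S, Vt = np.linalg.svd(W)
grad_closed = np.outer(U[:, 0], Vt[0])

# Central finite differences of ||W||_2, perturbing one entry at a time
eps = 1e-6
grad_fd = np.zeros_like(W)
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        E = np.zeros_like(W)
        E[i, j] = eps
        grad_fd[i, j] = (np.linalg.norm(W + E, 2) - np.linalg.norm(W - E, 2)) / (2 * eps)

# Maximum absolute discrepancy; should be tiny
print(np.max(np.abs(grad_closed - grad_fd)))
```

The product $\boldsymbol{u}_1\boldsymbol{v}_1^{\top}$ is unaffected by the sign ambiguity of singular vectors, since flipping both signs leaves it unchanged.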

Weight Decay

Based on this result and the chain rule, we have

\begin{equation}\nabla_{\boldsymbol{W}}\left(\frac{1}{2}\Vert\boldsymbol{W}\Vert_2^2\right) = \Vert\boldsymbol{W}\Vert_2\nabla_{\boldsymbol{W}}\Vert\boldsymbol{W}\Vert_2 = \sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top}\label{eq:grad-2-2}\end{equation}

Contrast this with the result under the $F$-norm:

\begin{equation}\nabla_{\boldsymbol{W}}\left(\frac{1}{2}\Vert\boldsymbol{W}\Vert_F^2\right) = \boldsymbol{W} = \sum_{i=1}^{\min(n,m)}\sigma_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top}\end{equation}

Placing them side by side makes the difference clear: the weight decay derived from the squared $F$-norm penalizes all singular values simultaneously, whereas the weight decay derived from the squared spectral norm penalizes only the largest singular value. If our goal is to limit the magnitude of the output, shrinking the largest singular value is "just right." Shrinking all singular values might achieve a similar goal, but it may also reduce the expressive capacity of the parameters.
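To see the contrast concretely, the following NumPy sketch applies one step of each kind of decay and compares singular values. The shape, seed, and decay coefficient `lam` are illustrative choices, not values from the original.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((6, 4))
lam = 0.1                                     # illustrative decay coefficient

U, S, Vt = np.linalg.svd(W, full_matrices=False)
spec_grad = S[0] * np.outer(U[:, 0], Vt[0])   # gradient of ||W||_2^2 / 2

W_fro = W - lam * W            # F-norm decay: the gradient of ||W||_F^2 / 2 is W itself
W_spec = W - lam * spec_grad   # spectral decay: subtract only the top rank-1 part

print(S)
print(np.linalg.svd(W_fro, compute_uv=False))   # all singular values scaled by (1 - lam)
print(np.linalg.svd(W_spec, compute_uv=False))  # sigma_1 scaled by (1 - lam), the rest unchanged
```

If $(1-\lambda)\sigma_1$ drops below $\sigma_2$, the sorted spectrum of `W_spec` is reordered, but as a set it is still $\{(1-\lambda)\sigma_1, \sigma_2, \ldots\}$.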

According to the "Eckart-Young-Mirsky Theorem," the result on the far right of equation $\eqref{eq:grad-2-2}$ also has another meaning: it is the "optimal rank-1 approximation" of matrix $\boldsymbol{W}$. In other words, spectral norm weight decay changes the operation of subtracting a portion of itself at each step to subtracting a portion of its optimal rank-1 approximation at each step, weakening the penalty intensity and, to some extent, making the penalty "hit the essence" more directly.
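The Eckart-Young-Mirsky claim can also be illustrated numerically. This hedged NumPy sketch compares the Frobenius error of $\sigma_1\boldsymbol{u}_1\boldsymbol{v}_1^{\top}$ with that of many random rank-1 candidates $s\,\boldsymbol{a}\boldsymbol{b}^{\top}$, each given its optimal scale $s=\boldsymbol{a}^{\top}\boldsymbol{W}\boldsymbol{b}$ for unit $\boldsymbol{a},\boldsymbol{b}$. Random sampling only illustrates the theorem; it does not prove it.

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.standard_normal((6, 4))
U, S, Vt = np.linalg.svd(W, full_matrices=False)
best = S[0] * np.outer(U[:, 0], Vt[0])      # sigma_1 u_1 v_1^T

# Should equal sqrt(sigma_2^2 + ... + sigma_m^2)
err_best = np.linalg.norm(W - best, "fro")

# Random rank-1 candidates s * a b^T with the optimal scale s = a^T W b
errs = []
for _ in range(1000):
    a = rng.standard_normal(6)
    a /= np.linalg.norm(a)
    b = rng.standard_normal(4)
    b /= np.linalg.norm(b)
    s = a @ W @ b
    errs.append(np.linalg.norm(W - s * np.outer(a, b), "fro"))

print(err_best, np.sqrt(np.sum(S[1:] ** 2)), min(errs))
```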

Numerical Calculation

In practice, the most critical question is how to compute $\sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top}$. SVD is the most straightforward approach, but it is also the most expensive, so we need a more efficient route.

Without loss of generality, let $n \geq m$. First note that

\begin{equation}\sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top} = \sum_{i=1}^m\sigma_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top} \boldsymbol{v}_1 \boldsymbol{v}_1^{\top} = \boldsymbol{W}\boldsymbol{v}_1 \boldsymbol{v}_1^{\top}\end{equation}

It can be seen that computing $\sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top}$ only requires $\boldsymbol{v}_1$. As discussed in the "Singular Value Decomposition" section of "The Road to Low-Rank Approximation (II): SVD," $\boldsymbol{v}_1$ is the eigenvector corresponding to the largest eigenvalue of $\boldsymbol{W}^{\top}\boldsymbol{W}$. We have thus turned the SVD of a general matrix $\boldsymbol{W}$ into the eigendecomposition of the $m\times m$ real symmetric matrix $\boldsymbol{W}^{\top}\boldsymbol{W}$, which already reduces the cost, since eigendecomposition is usually significantly faster than SVD.
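A minimal NumPy sketch (illustrative sizes with $n\geq m$, as assumed above) confirming both routes agree: $\boldsymbol{W}\boldsymbol{v}_1\boldsymbol{v}_1^{\top}$, with $\boldsymbol{v}_1$ taken from `np.linalg.eigh` of $\boldsymbol{W}^{\top}\boldsymbol{W}$, matches $\sigma_1\boldsymbol{u}_1\boldsymbol{v}_1^{\top}$ from a full SVD, and the top eigenvalue equals $\sigma_1^2$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 8, 5                              # n >= m, as assumed in the text
W = rng.standard_normal((n, m))

# Route 1: full SVD
U, S, Vt = np.linalg.svd(W, full_matrices=False)
target = S[0] * np.outer(U[:, 0], Vt[0])

# Route 2: eigendecomposition of the m x m symmetric matrix W^T W
evals, evecs = np.linalg.eigh(W.T @ W)   # eigh returns eigenvalues in ascending order
v1 = evecs[:, -1]                        # eigenvector of the largest eigenvalue
via_eig = W @ np.outer(v1, v1)           # W v1 v1^T; the sign of v1 cancels out

print(np.allclose(target, via_eig), np.isclose(evals[-1], S[0] ** 2))
```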

If that is still too slow, we can turn to the principle underlying many eigenvalue algorithms, "Power Iteration":

When $\sigma_1 > \sigma_2$, the iteration

\begin{equation}\boldsymbol{x}_{t+1} = \frac{\boldsymbol{W}^{\top}\boldsymbol{W}\boldsymbol{x}_t}{\Vert\boldsymbol{W}^{\top}\boldsymbol{W}\boldsymbol{x}_t\Vert}\end{equation}

converges to $\boldsymbol{v}_1$ at a rate of $(\sigma_2/\sigma_1)^{2t}$.

Each step of power iteration requires only two matrix-vector multiplications (first $\boldsymbol{W}\boldsymbol{x}_t$, then $\boldsymbol{W}^{\top}$ times the result, so $\boldsymbol{W}^{\top}\boldsymbol{W}$ never needs to be formed), each costing $\mathcal{O}(nm)$. The total cost of $t$ iterations is therefore $\mathcal{O}(tnm)$, which is very cheap. The drawback is that convergence can be slow when $\sigma_1$ and $\sigma_2$ are close. In practice, however, power iteration often performs better than the theory suggests; some early works used only a single iteration and still obtained decent results. The reason is that when $\sigma_1$ and $\sigma_2$ are close, they and their eigenvectors are to some extent interchangeable: even if power iteration has not fully converged, it yields a mixture of the two eigenvectors, which is entirely sufficient.
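The iteration above can be sketched in a few lines of NumPy. The function name `power_iteration` and its arguments are hypothetical; the point is that $\boldsymbol{W}^{\top}\boldsymbol{W}$ is never formed explicitly, and each step costs two matrix-vector products.

```python
import numpy as np


def power_iteration(W, num_iters, x0):
    """Estimate v_1, the top right singular vector of W, by power iteration on W^T W.

    W^T W is never formed: each step computes W @ x and then W.T @ (W @ x),
    i.e. two matrix-vector products costing O(nm) each.
    """
    x = x0 / np.linalg.norm(x0)
    for _ in range(num_iters):
        y = W.T @ (W @ x)
        x = y / np.linalg.norm(y)
    return x


rng = np.random.default_rng(0)
W = rng.standard_normal((8, 5))
v1_est = power_iteration(W, num_iters=100, x0=rng.standard_normal(5))

_, S, Vt = np.linalg.svd(W)
# |cosine| with the exact v_1 (the sign is irrelevant), and the per-step contraction factor
print(abs(v1_est @ Vt[0]), (S[1] / S[0]) ** 2)
```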

Iteration Proof

In this section, we complete the proof for power iteration. It is not hard to see that power iteration can be equivalently written as

\begin{equation}\lim_{t\to\infty} \frac{(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0}{\Vert(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0\Vert} = \boldsymbol{v}_1\end{equation}

To prove this limit, we start from $\boldsymbol{W}=\sum\limits_{i=1}^m\sigma_i \boldsymbol{u}_i\boldsymbol{v}_i^{\top}$. Substituting it into the calculation, we get

\begin{equation}\boldsymbol{W}^{\top}\boldsymbol{W} = \sum_{i=1}^m\sigma_i^2 \boldsymbol{v}_i\boldsymbol{v}_i^{\top},\qquad(\boldsymbol{W}^{\top}\boldsymbol{W})^t = \sum_{i=1}^m\sigma_i^{2t} \boldsymbol{v}_i\boldsymbol{v}_i^{\top}\end{equation}

Since $\boldsymbol{v}_1, \boldsymbol{v}_2, \cdots, \boldsymbol{v}_m$ form an orthonormal basis for $\mathbb{R}^m$, $\boldsymbol{x}_0$ can be written as $\sum\limits_{j=1}^m c_j \boldsymbol{v}_j$. Thus we have

\begin{equation}(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0 = \sum_{i=1}^m\sigma_i^{2t} \boldsymbol{v}_i\boldsymbol{v}_i^{\top}\sum_{j=1}^m c_j \boldsymbol{v}_j = \sum_{i=1}^m\sum_{j=1}^m c_j\sigma_i^{2t} \boldsymbol{v}_i\underbrace{\boldsymbol{v}_i^{\top} \boldsymbol{v}_j}_{=\delta_{i,j}} = \sum_{i=1}^m c_i\sigma_i^{2t} \boldsymbol{v}_i\end{equation}

and

\begin{equation}\Vert(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0\Vert = \left\Vert \sum_{i=1}^m c_i\sigma_i^{2t} \boldsymbol{v}_i\right\Vert = \sqrt{\sum_{i=1}^m c_i^2\sigma_i^{4t}}\end{equation}

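If $\boldsymbol{x}_0$ is initialized randomly, the probability that $c_1=0$ is negligible (zero for a continuous distribution), so we can assume $c_1 \neq 0$. Dividing the numerator and denominator by $|c_1|\sigma_1^{2t}$ then gives

\begin{equation}\frac{(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0}{\Vert(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0\Vert} = \frac{\sum\limits_{i=1}^m c_i\sigma_i^{2t} \boldsymbol{v}_i}{\sqrt{\sum\limits_{i=1}^m c_i^2\sigma_i^{4t}}} = \text{sign}(c_1)\,\frac{\boldsymbol{v}_1 + \sum\limits_{i=2}^m (c_i/c_1)(\sigma_i/\sigma_1)^{2t} \boldsymbol{v}_i}{\sqrt{1 + \sum\limits_{i=2}^m (c_i/c_1)^2(\sigma_i/\sigma_1)^{4t}}}\end{equation}

When $\sigma_1 > \sigma_2$, every ratio $\sigma_i/\sigma_1$ ($i \geq 2$) is less than 1, so as $t \to \infty$ the corresponding terms vanish and the limit is $\text{sign}(c_1)\boldsymbol{v}_1$. Since singular vectors are only determined up to sign anyway, and only $\boldsymbol{v}_1\boldsymbol{v}_1^{\top}$ enters $\boldsymbol{W}\boldsymbol{v}_1\boldsymbol{v}_1^{\top}$, this is $\boldsymbol{v}_1$ for our purposes; the leading error term decays like $(\sigma_2/\sigma_1)^{2t}$, matching the rate stated earlier.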
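The predicted rate $(\sigma_2/\sigma_1)^{2t}$ can be observed directly. This sketch assumes a diagonal $\boldsymbol{W}$ with singular values $1, 0.9, 0.5$ (a hypothetical example chosen so that $\boldsymbol{v}_1=\boldsymbol{e}_1$ is known exactly) and an initial vector with $c_1=c_2=c_3$. By the expansion above, the error divided by $(\sigma_2/\sigma_1)^{2t}$ should level off near $|c_2/c_1|=1$.

```python
import numpy as np

# Diagonal W: the singular vectors are the standard basis vectors, so v_1 = e_1
sigmas = np.array([1.0, 0.9, 0.5])
W = np.diag(sigmas)
v1 = np.array([1.0, 0.0, 0.0])

x = np.ones(3) / np.sqrt(3)          # c_1 = c_2 = c_3 > 0, so the limit is +v_1
rate = (sigmas[1] / sigmas[0]) ** 2  # (sigma_2 / sigma_1)^2 per iteration
history = {}
for t in range(1, 41):
    y = W.T @ (W @ x)
    x = y / np.linalg.norm(y)
    if t % 10 == 0:
        err = np.linalg.norm(x - v1)
        history[t] = err / rate ** t   # should level off near |c_2 / c_1| = 1
        print(t, err, history[t])
```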

Related Work

The first paper to propose spectral norm regularization is probably the 2017 paper "Spectral Norm Regularization for Improving the Generalizability of Deep Learning." It compared weight decay, adversarial training, spectral norm regularization, and other methods, and found that spectral norm regularization gave the best generalization performance.

That paper did not derive $\nabla_{\boldsymbol{W}}\Vert\boldsymbol{W}\Vert_2^2 = 2\sigma_1\boldsymbol{u}_1 \boldsymbol{v}_1^{\top}$ as we did here. Instead, it estimated $\Vert\boldsymbol{W}\Vert_2$ directly by power iteration, added the weighted $\Vert\boldsymbol{W}\Vert_2^2$ to the loss function, and let the optimizer compute the gradient. This is slightly less efficient and harder to decouple from the optimizer as a weight decay. The approach in this article is more flexible: much like AdamW, it lets us decouple weight decay from the optimization of the main loss.

Admittedly, from today's LLM perspective, the biggest problem with those early experiments is that their scale was too small to be fully convincing. Still, with the spectral-norm-based Muon optimizer as a strong precedent, I believe spectral norm weight decay deserves to be revisited and tried. That said, whether it is $F$-norm or spectral norm weight decay, such "generalization-oriented" techniques often involve an element of luck, so expectations should be kept modest.

Preliminary experiments on a few language models suggest a possible slight improvement in loss (hopefully not an illusion; at the very least, no degradation has been observed). In these experiments, power iteration was used to approximate $\boldsymbol{v}_1$ (initialized as the all-ones vector and run for 10 iterations), and the original weight decay $-\lambda \boldsymbol{W}$ was replaced by $-\lambda \boldsymbol{W}\boldsymbol{v}_1\boldsymbol{v}_1^{\top}$, with $\lambda$ unchanged.
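For concreteness, here is a hedged NumPy sketch of the modified weight decay as described: power iteration from the all-ones vector for 10 iterations, then subtracting $\lambda\boldsymbol{W}\boldsymbol{v}_1\boldsymbol{v}_1^{\top}$ instead of $\lambda\boldsymbol{W}$. The function name `spectral_weight_decay` and the single coefficient `lam` (standing in for whatever combination of learning rate and decay rate the optimizer uses) are illustrative assumptions, not the author's actual code.

```python
import numpy as np


def spectral_weight_decay(W, lam, num_iters=10):
    """Return W - lam * W v1 v1^T, with v1 estimated by power iteration (hypothetical helper).

    Power iteration on W^T W starts from the all-ones vector and runs
    num_iters steps (10 by default, as in the experiments described above).
    """
    x = np.ones(W.shape[1])
    x /= np.linalg.norm(x)
    for _ in range(num_iters):
        y = W.T @ (W @ x)             # two matrix-vector products per step
        x = y / np.linalg.norm(y)
    Wv = W @ x                        # approximately sigma_1 u_1
    return W - lam * np.outer(Wv, x)  # spectral decay; the usual decay would be W - lam * W


rng = np.random.default_rng(0)
W = rng.standard_normal((8, 5))
W_new = spectral_weight_decay(W, lam=0.01)
print(np.linalg.svd(W, compute_uv=False))
print(np.linalg.svd(W_new, compute_uv=False))  # mainly the largest value shrinks, by roughly (1 - lam)
```

In a decoupled (AdamW-style) setup, such a step would be applied to each matrix parameter separately from the gradient update of the main loss.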

Article Summary

This article derived the gradient of the spectral norm, leading to a new type of weight decay, and shared the author's reflections on it.