Newton-Schulz Iteration for the msign Operator (Part 1)

By 苏剑林 | May 11, 2025

In previous articles such as "Appreciating the Muon Optimizer: A Fundamental Leap from Vectors to Matrices" and "Muon Sequel: Why We Choose to Try Muon?", we introduced "Muon," a highly promising emerging optimizer with the potential to replace Adam. As research deepens, Muon is attracting more and more attention.

Readers familiar with Muon know that its core operation is the $\msign$ operator, and finding more efficient ways to compute it is an ongoing goal of the community. This article summarizes the latest progress.

Preliminaries

The definition of $\msign$ is closely tied to the singular value decomposition (SVD). Given a matrix $\boldsymbol{M}\in\mathbb{R}^{n\times m}$, we have \begin{equation}\boldsymbol{U},\boldsymbol{\Sigma},\boldsymbol{V}^{\top} = \text{SVD}(\boldsymbol{M}) \quad\Rightarrow\quad \msign(\boldsymbol{M}) = \boldsymbol{U}_{[:,:r]}\boldsymbol{V}_{[:,:r]}^{\top}\end{equation} where $\boldsymbol{U}\in\mathbb{R}^{n\times n}, \boldsymbol{\Sigma}\in\mathbb{R}^{n\times m}, \boldsymbol{V}\in\mathbb{R}^{m\times m}$, and $r$ is the rank of $\boldsymbol{M}$. Simply put, $\msign$ produces the matrix obtained by replacing every non-zero singular value of $\boldsymbol{M}$ with 1. From the SVD we can also prove: \begin{equation}\text{msign}(\boldsymbol{M}) = (\boldsymbol{M}\boldsymbol{M}^{\top})^{-1/2}\boldsymbol{M} = \boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}\end{equation} where $^{-1/2}$ denotes the matrix inverse square root. This form closely parallels the scalar $\mathop{\text{sign}}(x) = x / \sqrt{x^2}$, which is why the author uses the name $\msign$. Note, however, that it is not exactly the "Matrix Sign" concept on Wikipedia: that concept applies only to square matrices, and the two coincide when $\boldsymbol{M}$ is a symmetric matrix.
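
For concreteness, here is a small sketch (ours, not part of the original post) of computing $\msign$ exactly through the SVD; the relative tolerance tol used to decide which singular values count as non-zero is an arbitrary illustrative choice:

import jax.numpy as jnp

def msign_exact(m, tol=1e-7):
    # Thin SVD: m = u @ diag(s) @ vt, singular values sorted in descending order.
    u, s, vt = jnp.linalg.svd(m, full_matrices=False)
    # Set every (numerically) non-zero singular value to 1 and rebuild the matrix,
    # i.e. keep only the first r columns of u and rows of vt.
    mask = (s > tol * s[0]).astype(m.dtype)
    return (u * mask) @ vt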

When $m=n=r$, $\text{msign}(\boldsymbol{M})$ can also be interpreted as the "optimal orthogonal approximation": \begin{equation}\text{msign}(\boldsymbol{M}) = \mathop{\text{argmin}}_{\boldsymbol{O}^{\top}\boldsymbol{O} = \boldsymbol{I}}\Vert \boldsymbol{M} - \boldsymbol{O}\Vert_F^2\end{equation} The proof can be found in "Appreciating the Muon Optimizer: A Fundamental Leap from Vectors to Matrices". Because of this property, $\msign$ is also known as "symmetric orthogonalization," a name that first appeared in "On the Nonorthogonality Problem" (see the Wikipedia entry on "Orthogonalization").

Finally, in "Higher-order muP: Simpler but Smarter Spectral Condition Scaling", $\msign$ was viewed by the author as the limit version of "singular value clipping."

Iterative Computation

Since $\msign$ is defined by SVD, it can naturally be calculated precisely using SVD. However, precise SVD computation is computationally expensive, so in practice, "Newton-Schulz iteration" is used for approximate calculation.

The Newton-Schulz iteration is a commonly used iterative algorithm for computing matrix functions. For $\msign$, the iteration takes the form: \begin{equation}\boldsymbol{X}_0 = \frac{\boldsymbol{M}}{\Vert\boldsymbol{M}\Vert_F},\qquad \boldsymbol{X}_{t+1} = a\boldsymbol{X}_t + b\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t) + c\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t)^2+\cdots\end{equation} where $\Vert\boldsymbol{M}\Vert_F$ is the Frobenius norm of $\boldsymbol{M}$ (the square root of the sum of squares of all its entries), and $(a,b,c,\cdots)$ are coefficients to be determined. In practice we truncate to a finite number of terms, commonly 2 or 3, i.e. we choose one of the following: \begin{gather}\boldsymbol{X}_{t+1} = a\boldsymbol{X}_t + b\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t) \\[8pt] \boldsymbol{X}_{t+1} = a\boldsymbol{X}_t + b\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t) + c\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t)^2\label{eq:power-5}\end{gather} After $T$ iterations, $\boldsymbol{X}_T$ is returned as an approximation of $\msign(\boldsymbol{M})$. Consequently, the coefficients $(a,b,c)$ and the number of iterations $T$ constitute all the hyperparameters of the Newton-Schulz iteration. The reference choice provided by Muon's author KellerJordan is: \begin{equation}(a,b,c)=(3.4445, -4.7750, 2.0315),\qquad T = 5\end{equation} Our goal in what follows is to understand this choice and then attempt to improve it.

Reference Implementation

A minimalist reference implementation is provided here:

def msign(x, steps=5, eps=1e-20):
    a, b, c, y = 3.4445, -4.7750, 2.0315, x.astype('bfloat16')
    # Work with the "wide" orientation so that y @ y.mT is the smaller Gram matrix
    y = y.mT if x.shape[-2] > x.shape[-1] else y
    # X_0 = M / ||M||_F, so all singular values fall into (0, 1]
    y /= ((y**2).sum(axis=(-2, -1), keepdims=True) + eps)**0.5
    for _ in range(steps):
        # X_{t+1} = a*X_t + b*X_t(X_t^T X_t) + c*X_t(X_t^T X_t)^2, written via X_t X_t^T
        y = a * y + (b * (y2 := y @ y.mT) + c * y2 @ y2) @ y
    return y.mT if x.shape[-2] > x.shape[-1] else y

This implementation already supports batch processing (it applies $\msign$ only over the last two dimensions) and runs in JAX as written; changing x.astype('bfloat16') to x.to(torch.bfloat16) makes it run in PyTorch, and changing x.astype('bfloat16') to just x makes it run in NumPy.
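
As a quick sanity check (ours), one can compare the iterative approximation with the exact SVD route on a random matrix; msign_exact below refers to the sketch from the Preliminaries section, and the errors one observes reflect both the iteration's approximation quality and the bfloat16 arithmetic:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (256, 512))

o = msign(m).astype('float32')  # Newton-Schulz approximation (reference implementation above)
o_exact = msign_exact(m)        # exact computation via SVD (sketch from the Preliminaries)

# For a full-rank wide matrix, msign(M) satisfies O @ O^T = I up to the
# iteration's approximation error and bfloat16 rounding.
print(jnp.abs(o @ o.mT - jnp.eye(256)).max())
print(jnp.abs(o - o_exact).max())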

Theoretical Analysis

To understand the principle of the Newton-Schulz iteration, we analyze its steps one by one. First comes $\boldsymbol{X}_0 = \boldsymbol{M}/\Vert\boldsymbol{M}\Vert_F$. Substituting the SVD of $\boldsymbol{M}$: \begin{equation}\boldsymbol{X}_0 = \frac{\boldsymbol{M}}{\Vert\boldsymbol{M}\Vert_F} = \boldsymbol{U}_{[:,:r]}\left(\frac{\boldsymbol{\Sigma}_{[:r,:r]}}{\Vert\boldsymbol{M}\Vert_F}\right)\boldsymbol{V}_{[:,:r]}^{\top} = \boldsymbol{U}_{[:,:r]}\underbrace{\left(\frac{\boldsymbol{\Sigma}_{[:r,:r]}}{\Vert\boldsymbol{\Sigma}_{[:r,:r]}\Vert_F}\right)}_{\boldsymbol{S}_0}\boldsymbol{V}_{[:,:r]}^{\top}\end{equation} The last equality holds because the squared $F$-norm equals both the sum of squares of all entries and the sum of squares of all singular values. The result shows that $\boldsymbol{S}_0$ is a diagonal matrix with entries in $(0,1]$; in other words, none of the singular values of $\boldsymbol{X}_0=\boldsymbol{U}_{[:,:r]}\boldsymbol{S}_0\boldsymbol{V}_{[:,:r]}^{\top}$ exceed 1. This is the purpose of the first step $\boldsymbol{X}_0 = \boldsymbol{M}/\Vert\boldsymbol{M}\Vert_F$.

Next, substituting $\boldsymbol{X}_t = \boldsymbol{U}_{[:,:r]}\boldsymbol{S}_t\boldsymbol{V}_{[:,:r]}^{\top}$ into equation \eqref{eq:power-5}, we obtain: \begin{equation}\boldsymbol{X}_{t+1} = \boldsymbol{U}_{[:,:r]}\left(a\boldsymbol{S}_t + b\boldsymbol{S}_t^3 + c\boldsymbol{S}_t^5\right)\boldsymbol{V}_{[:,:r]}^{\top}\end{equation} That is, the iteration leaves the factors $\boldsymbol{U}_{[:,:r]}$ and $\boldsymbol{V}_{[:,:r]}^{\top}$ unchanged; it is essentially an iteration on the diagonal matrix: \begin{equation}\boldsymbol{S}_{t+1} = a\boldsymbol{S}_t + b\boldsymbol{S}_t^3 + c\boldsymbol{S}_t^5\end{equation} Moreover, a power of a diagonal matrix acts element-wise on its diagonal entries, so this is equivalent to the scalar iteration: \begin{equation}x_{t+1} = a x_t + b x_t^3 + c x_t^5\end{equation} Since $\boldsymbol{X}_0 = \boldsymbol{M}/\Vert\boldsymbol{M}\Vert_F$ has already compressed the singular values into $(0,1]$, we want $x_T$ to be as close to 1 as possible after $T$ iterations starting from any $x_0\in(0,1]$; then iteration \eqref{eq:power-5} approximates $\msign$ well. In this way we have reduced the matrix iteration to a scalar iteration, which greatly simplifies the analysis.
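
Because everything now reduces to a scalar recursion, its behavior is easy to probe numerically. The short sketch below (ours) simply runs $x_{t+1} = a x_t + b x_t^3 + c x_t^5$ with KellerJordan's coefficients for a few starting values:

# Scalar view of the Newton-Schulz step x_{t+1} = a*x_t + b*x_t^3 + c*x_t^5,
# run with KellerJordan's coefficients; x0 plays the role of a normalized singular value.
a, b, c, T = 3.4445, -4.7750, 2.0315, 5

for x0 in [0.5, 0.1, 0.01, 0.001]:
    x = x0
    for _ in range(T):
        x = a * x + b * x**3 + c * x**5
    print(f'x0 = {x0}: x_T = {x:.4f}')

For very small $x_0$ the update is approximately $x_{t+1}\approx a x_t$, so such values at first grow only by a factor of about $a$ per step, which is why very small singular values are the hard case discussed below.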

Optimization Solution

In fact, solving for $a,b,c$ was briefly discussed when we first introduced Muon in "Appreciating the Muon Optimizer: A Fundamental Leap from Vectors to Matrices". The basic idea is to treat $a,b,c$ as optimization parameters, construct a loss from the difference between $x_T$ and 1, and optimize it with SGD.

The approach in this article is largely the same, with slight adjustments. Obviously, the optimization result depends on the distribution of singular values. The author's earlier idea was to use the SVD of random matrices to simulate a realistic singular-value distribution, but SVD is costly, and the result also depends on the shape of the matrix. In hindsight this is unnecessary: instead, we take uniformly spaced points in $[0,1]$ and build the loss from the $k$ points with the largest $|x_T-1|$. This turns the fit into a $\min\text{-}\max$ problem, minimizing the influence of the singular-value distribution as much as possible:

import jax
import jax.numpy as jnp
from tqdm import tqdm

def loss(w, x, k=50):
    # Run the scalar iteration `iters` times with shared coefficients w = (a, b, c),
    # then average the k largest errors |x_T - 1| (the min-max construction).
    for a, b, c in [w] * iters:
        x = a * x + b * x**3 + c * x**5
    return jnp.abs(x - 1).sort()[-k:].mean()

@jax.jit
def grad(w, x, tol=0.1):
    # Normalized gradient, mixed with the gradients at perturbed coefficients w ± tol/2
    # to make the solution more robust to noise.
    G = lambda w, x: (g := jax.grad(loss)(w, x)) / jnp.fmax(jnp.linalg.norm(g), 1)
    return 0.6 * G(w, x) + 0.2 * (G(w + tol / 2, x) + G(w - tol / 2, x))

iters = 5  # number of Newton-Schulz steps T
x = jnp.linspace(0, 1, 10001)[1:]  # uniform grid on (0, 1]
w = jnp.array([1.5, -0.5, 0])  # initial (a, b, c)
m, v = jnp.zeros_like(w), jnp.zeros_like(w)  # Adam state
lr = 1e-3
pbar = tqdm(range(20000), ncols=0, desc='Adam')

for i in pbar:
    l, g = loss(w, x), grad(w, x)
    # Plain Adam update (without bias correction)
    m = 0.9 * m + 0.1 * g
    v = 0.999 * v + 0.001 * g**2
    w = w - lr * m / jnp.sqrt(v + 1e-20)
    pbar.set_description(f'Loss: {l:.6f}, LR: {lr:.6f}')
    if i in [10000]:
        lr *= 0.1  # decay the learning rate halfway through

Additionally, the optimizer has been changed from SGD to Adam, which makes it easier to control the magnitude of the parameter updates. To make the solution more robust to noise, we also add a small perturbation to $a,b,c$ and mix in the gradients at the perturbed points. The optimization result of the above script is: \begin{equation}(a,b,c)=(3.3748, -4.6969, 2.1433)\end{equation} As can be seen, this is not far from KellerJordan's solution. Comparing the two in the figures below:

[Figure: approximation effect on [0, 1]]

[Figure: approximation effect on [0, 0.01]]

Globally, our solution has a slightly smaller average error, while the advantage of KellerJordan's solution is its larger slope on $[0, 0.01]$, which makes it more favorable for small singular values.
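
These error statistics can also be checked numerically without plotting. Below is a minimal sketch (ours) that applies the scalar iteration to a uniform grid and reports the mean and maximum of $|x_T - 1|$ for both coefficient triples:

import jax.numpy as jnp

def err_curve(coeffs, T=5, n=10001):
    # Apply the scalar iteration T times to a uniform grid on (0, 1]
    # and return the pointwise error |x_T - 1|.
    x = jnp.linspace(0, 1, n)[1:]
    a, b, c = coeffs
    for _ in range(T):
        x = a * x + b * x**3 + c * x**5
    return jnp.abs(x - 1)

for name, coeffs in [('KellerJordan', (3.4445, -4.7750, 2.0315)),
                     ('Ours', (3.3748, -4.6969, 2.1433))]:
    e = err_curve(coeffs)
    print(name, 'mean error:', e.mean(), 'max error:', e.max())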

Initial Value Distribution

Before going further, we need to answer one question: how small a singular value do we actually care about? This brings us back to the distribution of $\boldsymbol{S}_0$. Since $\boldsymbol{S}_0$ is normalized by the $F$-norm, $\mathop{\text{diag}}(\boldsymbol{S}_0)$ is in fact an $r$-dimensional unit vector. If all singular values were equal, each would be $1/\sqrt{r}$.

Consequently, by the pigeonhole principle, in any non-uniform case there must exist singular values smaller than $1/\sqrt{r}$. To be safe, we allow some margin, say a factor of 10, which means we should account for singular values as small as $0.1/\sqrt{r}$. In practice, the probability that a matrix is strictly low-rank (i.e., has singular values exactly equal to 0) is negligible, so we generally assume the matrix is full-rank, i.e., $r = \min(n,m)$. Therefore, we should handle singular values at least as small as $0.1/\sqrt{\min(n,m)}$.

Considering that the hidden_size of today's largest LLMs is on the order of $8192\sim 10^4$, this estimate gives $0.1/\sqrt{10^4} = 0.001$, so a general-purpose $\msign$ routine for Muon should handle singular values as small as $0.001$; that is, it should be able to map $0.001$ to a value close to 1. From this perspective, both KellerJordan's solution and the one we just found fall short.

Note: For a discussion on the initial value distribution, one can also refer to "Iterative Orthogonalization Scaling Laws".
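
To get a concrete feel for these magnitudes, one can inspect the normalized singular values $\sigma_i/\Vert\boldsymbol{M}\Vert_F$ of a random matrix directly; a small sketch (ours), with the matrix shape chosen arbitrarily:

import jax
import jax.numpy as jnp

# Normalized singular values sigma_i / ||M||_F of a random Gaussian matrix:
# these are exactly the diagonal entries of S_0 that the iteration starts from.
m = jax.random.normal(jax.random.PRNGKey(0), (1024, 4096))
s = jnp.linalg.svd(m, compute_uv=False) / jnp.linalg.norm(m)
print('min:', s.min(), 'max:', s.max(), '1/sqrt(r):', 1 / jnp.sqrt(1024.0))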

Unlocking Constraints

At this point, @YouJiacheng (one of Muon's main advocates) proposed a very clever idea on Twitter: we can use different coefficients at each iteration step! That is, we change the iteration to: \begin{equation}\boldsymbol{X}_{t+1} = a_{t+1}\boldsymbol{X}_t + b_{t+1}\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t) + c_{t+1}\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t)^2\end{equation} The advantage is that, with $T$ fixed, the total computational cost does not change at all, while from a fitting perspective there are now $3T$ trainable parameters instead of 3, so the fitting capacity is greatly enhanced. He provided a reference result for a 6-step iteration:

$t$ $a$ $b$ $c$
1 3955/1024 -8306/1024 5008/1024
2 3735/1024 -6681/1024 3463/1024
3 3799/1024 -6499/1024 3211/1024
4 4019/1024 -6385/1024 2906/1024
5 2677/1024 -3029/1024 1162/1024
6 2172/1024 -1833/1024 682/1024

We can plot and compare them:

[Figure: approximation effect on [0, 1]]

[Figure: approximation effect on [0, 0.01]]

For fairness, the KellerJordan and Ours solutions in these plots were also run with $T=6$. It can be seen that, both in smoothness and in overall approximation quality, YouJiacheng's solution is a very significant improvement, which fully demonstrates the extra capacity unlocked by removing the parameter sharing across steps.
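
Adopting such a schedule in practice only requires a small change to the reference implementation: instead of a single $(a,b,c)$, it takes one coefficient triple per step. Below is a sketch (ours, with the hypothetical name msign_scheduled) that plugs in YouJiacheng's 6-step coefficients from the table above:

# Newton-Schulz iteration with per-step coefficients (a_t, b_t, c_t);
# the triples below are YouJiacheng's 6-step schedule quoted above.
COEFFS = [(3955 / 1024, -8306 / 1024, 5008 / 1024),
          (3735 / 1024, -6681 / 1024, 3463 / 1024),
          (3799 / 1024, -6499 / 1024, 3211 / 1024),
          (4019 / 1024, -6385 / 1024, 2906 / 1024),
          (2677 / 1024, -3029 / 1024, 1162 / 1024),
          (2172 / 1024, -1833 / 1024,  682 / 1024)]

def msign_scheduled(x, coeffs=COEFFS, eps=1e-20):
    y = x.astype('bfloat16')
    y = y.mT if x.shape[-2] > x.shape[-1] else y
    y /= ((y**2).sum(axis=(-2, -1), keepdims=True) + eps)**0.5
    for a, b, c in coeffs:  # one coefficient triple per iteration step
        y = a * y + (b * (y2 := y @ y.mT) + c * y2 @ y2) @ y
    return y.mT if x.shape[-2] > x.shape[-1] else y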

Try it Yourself

How was YouJiacheng's solution obtained? He shared his code here. The idea is also to solve with Adam, but his script mixes several different losses, which makes it a bit tedious to follow. In fact, our previous script, combined with the initialization he provided, yields equally good results:

$t$ $a$ $b$ $c$
1 4140/1024 -7553/1024 3571/1024
2 3892/1024 -6637/1024 2973/1024
3 3668/1024 -6456/1024 3021/1024
4 3248/1024 -6211/1024 3292/1024
5 2792/1024 -5759/1024 3796/1024
6 3176/1024 -5507/1024 4048/1024

Reference code:

import jax
import jax.numpy as jnp
from tqdm import tqdm

def loss(w, x, k=50):
    # w now has shape (iters, 3): one (a, b, c) triple per iteration step.
    for a, b, c in w:
        x = a * x + b * x**3 + c * x**5
    return jnp.abs(x - 1).sort()[-k:].mean()

@jax.jit
def grad(w, x, tol=0.1):
    G = lambda w, x: (g := jax.grad(loss)(w, x)) / jnp.fmax(jnp.linalg.norm(g), 1)
    return 0.6 * G(w, x) + 0.2 * (G(w + tol / 2, x) + G(w - tol / 2, x))

iters = 6
x = jnp.linspace(0, 1, 10001)[1:]
w = jnp.array([[3.5, -6.04444444444, 2.84444444444]] * iters)  # the same initial triple for every step
m, v = jnp.zeros_like(w), jnp.zeros_like(w)
lr = 1e-3
pbar = tqdm(range(20000), ncols=0, desc='Adam')

for i in pbar:
    l, g = loss(w, x), grad(w, x)
    m = 0.9 * m + 0.1 * g
    v = 0.999 * v + 0.001 * g**2
    w = w - lr * m / jnp.sqrt(v + 1e-20)
    pbar.set_description(f'Loss: {l:.6f}, LR: {lr:.6f}')
    if i in [10000]:
        lr *= 0.1

Comparison below (labeled "Ours-X"):

[Figure: approximation effect on [0, 1]]

[Figure: approximation effect on [0, 0.01]]

As shown in the figures, compared to YouJiacheng's solution, our results oscillate more but achieve a larger slope in the $[0,0.001]$ region.

Other Solution Sets

If the reader wants a solution with fewer oscillations, simply increase the value of $k$. For example, the result for $k=200$ is:

$t$ $a$ $b$ $c$
1 4059/1024 -7178/1024 3279/1024
2 3809/1024 -6501/1024 2925/1024
3 3488/1024 -6308/1024 3063/1024
4 2924/1024 -5982/1024 3514/1024
5 2439/1024 -5439/1024 4261/1024
6 3148/1024 -5464/1024 4095/1024

At this point, it is almost identical to YouJiacheng's solution (Ours-X2):

[Figure: approximation effect on [0, 1]]

[Figure: approximation effect on [0, 0.01]]

Additionally, here is a 5-step solution for easy comparison with the original solution:

$t$ $a$ $b$ $c$
1 4.6182 -12.9582 9.3299
2 3.8496 -7.9585 4.3052
3 3.5204 -7.2918 4.0606
4 3.2067 -6.8243 4.2802
5 3.2978 -5.7848 3.8917

Effect diagram (Ours-X3):

[Figure: approximation effect on [0, 1]]

[Figure: approximation effect on [0, 0.01]]

Improving the Initial Value

This concludes our search for the coefficients $a,b,c$. Overall, using different $a,b,c$ at each step can significantly improve the convergence of the Newton-Schulz iteration without adding any computational cost, making it a free lunch.

Besides optimizing the coefficients of Newton-Schulz iteration, are there other ways to improve the convergence properties of the iteration? There are indeed. @johanwind, @YouJiacheng, @ZhangRuichong, and others discovered that we can leverage the characteristics of Newton-Schulz iteration to improve initial value quality almost for free, thereby increasing convergence speed. @leloykun provided a reference implementation here.

Specifically, the main efforts to improve the Newton-Schulz iteration can be summarized as "maximizing the convergence speed of near-zero singular values while ensuring convergence." If we can slightly amplify these near-zero singular values beforehand, the convergence speed increases without any change to the iteration itself. To compress the singular values into $[0,1]$, we currently use $F$-norm normalization $\boldsymbol{M}/\Vert\boldsymbol{M}\Vert_F$, which maps the singular values as: \begin{equation}\sigma_i \quad\to\quad \frac{\sigma_i}{\Vert\boldsymbol{M}\Vert_F} = \frac{\sigma_i}{\sqrt{\sum\limits_{j=1}^r \sigma_j^2}} \in [0, 1]\end{equation} This achieves the goal but over-compresses. The tightest compression would be $\sigma_i\to \sigma_i/\sigma_1$, i.e., spectral normalization; the problem is that the spectral norm is not as cheap to compute as the $F$-norm, which is why we settled for the $F$-norm. However, we have: \begin{equation}\sigma_1 \quad\leq\quad \underbrace{\sqrt[\uproot{10}8]{\sum_{j=1}^r \sigma_j^8}}_{\sqrt[4]{\Vert(\boldsymbol{M}^{\top}\boldsymbol{M})^2\Vert_F}} \quad\leq\quad \underbrace{\sqrt[\uproot{10}4]{\sum_{j=1}^r \sigma_j^4}}_{\sqrt{\Vert\boldsymbol{M}^{\top}\boldsymbol{M}\Vert_F}} \quad\leq\quad \underbrace{\sqrt{\sum_{j=1}^r \sigma_j^2}}_{\Vert\boldsymbol{M}\Vert_F} \end{equation} This means that using $\sqrt[4]{\Vert(\boldsymbol{M}^{\top}\boldsymbol{M})^2\Vert_F}$ or $\sqrt{\Vert\boldsymbol{M}^{\top}\boldsymbol{M}\Vert_F}$ as the normalization factor is theoretically better than $\Vert\boldsymbol{M}\Vert_F$. Very conveniently, under the Newton-Schulz iteration their computation is almost free! To see this, write out the first step of the iteration: \begin{equation}\boldsymbol{X}_0 = \frac{\boldsymbol{M}}{\Vert\boldsymbol{M}\Vert_F},\qquad \boldsymbol{X}_1 = a\boldsymbol{X}_0 + b\boldsymbol{X}_0(\boldsymbol{X}_0^{\top}\boldsymbol{X}_0) + c\boldsymbol{X}_0(\boldsymbol{X}_0^{\top}\boldsymbol{X}_0)^2\end{equation} Since $\boldsymbol{X}_0^{\top}\boldsymbol{X}_0$ and $(\boldsymbol{X}_0^{\top}\boldsymbol{X}_0)^2$ must be computed anyway, we can simply take their $F$-norms and re-normalize. Reference code:

def msign(x, steps=5, eps=1e-20):
    a, b, c, y = 3.4445, -4.7750, 2.0315, x.astype('bfloat16')
    y = y.mT if x.shape[-2] > x.shape[-1] else y
    # Initial F-norm normalization, as before
    y /= ((y**2).sum(axis=[-2, -1], keepdims=True) + eps)**0.5
    for i in range(steps):
        y4 = (y2 := y @ y.mT) @ y2
        if i == 0:
            # Re-normalize by ||(X_0 X_0^T)^2||_F^{1/4}, a tighter upper bound on sigma_1,
            # reusing the y2 and y4 that the iteration needs anyway
            n = ((y4**2).sum(axis=[-2, -1], keepdims=True) + eps)**0.125
            y, y2, y4 = y / n, y2 / n**2, y4 / n**4
        y = a * y + (b * y2 + c * y4) @ y
    return y.mT if x.shape[-2] > x.shape[-1] else y

Empirically, for a $100\times 100$ random Gaussian matrix, the minimum singular value of the output after this improvement is mostly more than twice that of the original version, and the average singular value is also closer to 1. However, Muon's author has noted that it may introduce additional instability, so it has not yet been adopted in the official code.
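
For reference, such a check might look like the following sketch (ours); it assumes the two implementations above have been kept under the separate names msign and msign_improved, which are our labels rather than anything from the original code:

import jax
import jax.numpy as jnp

# Compare the output singular values with and without the extra renormalization.
# `msign` and `msign_improved` are assumed to be the two implementations above,
# saved under these (our) names; an ideal msign output has all singular values equal to 1.
m = jax.random.normal(jax.random.PRNGKey(0), (100, 100))
for f in (msign, msign_improved):
    s = jnp.linalg.svd(f(m).astype('float32'), compute_uv=False)
    print(f.__name__, 'min:', float(s.min()), 'mean:', float(s.mean()))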

Summary

This article introduced ideas for optimizing the computation of $\msign$ via the Newton-Schulz iteration. Compared with Muon's official coefficients, the solutions obtained here noticeably improve the iteration's convergence speed and approximation quality.

Finally, it should be noted that, for Muon, small-scale experiments suggest that the accuracy of the $\msign$ computation is not necessarily tied to final model performance: in small models, improving the precision of $\msign$ only appears to accelerate convergence slightly in the early stages, while the final result remains unchanged. Whether this conclusion holds at larger scales is still unclear.