Newton-Schulz Iteration for the msign Operator (Part 2)

By 苏剑林 | June 05, 2025

In the previous article "Newton-Schulz Iteration for the msign Operator (Part 1)", we tried to find better Newton-Schulz iterations for the $\mathop{\text{msign}}$ operator, aiming for the best possible approximation within a limited number of iteration steps. This problem can be reduced to finding polynomial iterations of the same form for the scalar function $\mathop{\text{sign}}(x)$. At that time, our approach was to use the Adam optimizer to find a local optimum end to end, which, while effective, was somewhat crude.

A few days ago, a new paper appeared on arXiv titled "The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm". The authors combine a series of elegant mathematical results into a more principled and more hardcore answer. In this article, let's study and appreciate this brilliant paper.

Problem Description

We will not repeat the relevant background and transformation process. The problem we want to solve directly is:

\begin{equation}\mathop{\text{argmin}}_f d(f(x),1)\end{equation}

where $f = f_T \circ \dots \circ f_2 \circ f_1$, $\circ$ denotes function composition, each $f_t(x)$ is an odd polynomial in $x$ (containing only odd powers of $x$), and $d(f(x),1)$ is a metric measuring the distance between the function $f(x)$ and $1$. In the previous article, we sampled a finite number of evenly spaced points in $[0,1]$ and used the average of the $k$ largest values of $|f(x)-1|$ as the metric. The paper instead uses the maximum of $|f(x)-1|$ over the interval, i.e.,

\begin{equation}\mathop{\text{argmin}}_f \max_{x\in[l,u]} |f(x) - 1| \label{eq:opt}\end{equation}

where $[l,u]\subset [0,1]$. Note that $u$ can be taken as 1, but $l$ cannot be 0. Since $f(0)$ is always 0, the objective would always be at least 1 and could never decrease. Therefore, $l$ must be a small positive number close to 0. Following the analysis in the previous article, to be safe we should account for singular values as small as $0.001$, so we take $l=0.001$.
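To make the objective concrete, here is a minimal NumPy sketch that estimates $\max_{x\in[l,u]}|f(x)-1|$ for a composition of odd polynomials on a dense grid. The helper names `odd_poly` and `max_error` are hypothetical and do not come from the paper. A finite grid only approximates the true maximum, but it is enough to see why $l$ must be positive.

```python
import numpy as np


def odd_poly(coeffs, x):
    """Evaluate coeffs[0]*x + coeffs[1]*x**3 + coeffs[2]*x**5 + ..."""
    return sum(c * x ** (2 * j + 1) for j, c in enumerate(coeffs))


def max_error(coeff_list, l, u, num=100_001):
    """Grid estimate of max_{x in [l, u]} |f(x) - 1| for f = f_T o ... o f_1.

    coeff_list[t] holds the odd-power coefficients of f_{t+1}; f_1 is applied
    first. A finite grid can slightly underestimate the true maximum.
    """
    y = np.linspace(l, u, num)
    for coeffs in coeff_list:
        y = odd_poly(coeffs, y)
    return float(np.max(np.abs(y - 1.0)))


# f(x) = x on [0.001, 1]: the error is 1 - l, attained at x = l
print(max_error([[1.0]], 1e-3, 1.0))
# Five classic Newton-Schulz steps 1.5x - 0.5x^3: the error shrinks when l > 0 ...
print(max_error([[1.5, -0.5]] * 5, 0.1, 1.0))
# ... but with l = 0 it stays at 1, because every odd polynomial maps 0 to 0
print(max_error([[1.5, -0.5]] * 5, 0.0, 1.0))
```

With $l=0$ the estimate stays at 1 no matter how many steps are composed, because the grid contains $x=0$.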

Before starting the analysis, let's briefly explain the meaning of the word "Polar" in the paper's title. It actually represents the "Polar Decomposition" of a matrix:

Polar Decomposition: For a square matrix $\boldsymbol{M}\in\mathbb{R}^{n\times n}$, its polar decomposition is $\boldsymbol{M}=\boldsymbol{Q}\boldsymbol{S}$, where $\boldsymbol{Q}$ is an orthogonal matrix and $\boldsymbol{S}$ is a positive semi-definite matrix.

If the SVD of $\boldsymbol{M}$ is $\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$, then we exactly have:

\begin{equation}\boldsymbol{M} = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} = (\boldsymbol{U}\boldsymbol{V}^{\top})(\boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top})\end{equation}

Here $\boldsymbol{Q}=\boldsymbol{U}\boldsymbol{V}^{\top}$ and $\boldsymbol{S}=\boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$ give one polar decomposition. We know that when $\boldsymbol{M}$ is full rank, $\boldsymbol{U}\boldsymbol{V}^{\top}$ is exactly $\mathop{\text{msign}}(\boldsymbol{M})$. This is why $\mathop{\text{msign}}$ is associated with "Polar": computing it yields the polar decomposition of the matrix. In other words, the core difficulty of polar decomposition is computing $\mathop{\text{msign}}$, which is the same goal as Muon's.
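The link between the SVD and the polar decomposition can be checked numerically. The following small NumPy sketch forms $\boldsymbol{Q}=\boldsymbol{U}\boldsymbol{V}^{\top}$ and $\boldsymbol{S}=\boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$ directly from `np.linalg.svd`. The function name `polar_via_svd` is only illustrative. Computing the full SVD is the expensive step that Newton-Schulz iterations try to avoid.

```python
import numpy as np


def polar_via_svd(M):
    """Polar decomposition M = Q S of a square matrix, computed from its SVD.

    Q = U V^T (which is msign(M) when M is full rank) and S = V Sigma V^T.
    """
    U, s, Vt = np.linalg.svd(M)          # M = U @ diag(s) @ Vt
    Q = U @ Vt
    S = Vt.T @ np.diag(s) @ Vt
    return Q, S


rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))
Q, S = polar_via_svd(M)
print(np.allclose(Q @ S, M))                     # M = Q S
print(np.allclose(Q.T @ Q, np.eye(5)))           # Q is orthogonal
print(np.allclose(S, S.T), np.linalg.eigvalsh(S).min() >= -1e-12)  # S is symmetric PSD
```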

Greedy is Sufficient

Back to the point. Regarding problem $\eqref{eq:opt}$, the first conclusion drawn in the original paper, which is also the core conclusion of the entire paper, is: Its greedy solution is exactly its global optimal solution! In formulaic terms, it means the solution to problem $\eqref{eq:opt}$ can be transformed into:

\begin{equation}\begin{gathered} f^* = f_T^* \circ \dots \circ f_2^* \circ f_1^* \\[12pt] f_1^* = \mathop{\text{argmin}}_{f_1} \max_{x\in[l_1,u_1]} |f_1(x) - 1| \\ f_2^* = \mathop{\text{argmin}}_{f_2} \max_{x\in[l_2,u_2]} |f_2(x) - 1| \\ \vdots \\ f_T^* = \mathop{\text{argmin}}_{f_T} \max_{x\in[l_T,u_T]} |f_T(x) - 1| \\[24pt] l_1 = l,\quad u_1 = u, \\[8pt] l_{t+1} = \min_{x\in[l_t,u_t]} f_t^*(x),\quad u_{t+1} = \max_{x\in[l_t,u_t]} f_t^*(x) \end{gathered}\end{equation}

I believe this conclusion will surprise many readers; I was also quite amazed when I first saw it and found it extraordinary. It not only greatly reduces the difficulty of the solution—transforming the original $T$-step composite function problem into solving single polynomials step-by-step—but also allows us to push the solution forward step by step while maintaining optimality (i.e., the optimal solution for $T+1$ steps only requires calculating one more step based on the optimal solution for $T$ steps, rather than starting from scratch).

It is worth noting that this conclusion allows each $f_t$ to have a different degree (here "degree" refers to the highest power of the polynomial). For example, $f_1$ could be degree 3, $f_2$ could be degree 5, and so on, yet the conclusion that the "greedy solution is the global optimal solution" still holds. For simplicity, however, we will keep all $f_t$ at the same degree and primarily consider the results for degrees 3 and 5.

The complete proof of this conclusion is slightly complex. We will put it at the end and first complete the subsequent operations based on this conclusion.

Equioscillation

Since we have transformed the original problem into finding greedy solutions, we now only need to focus on solving:

\begin{equation}\mathop{\text{argmin}}_{f_t} \max_{x\in[l_t,u_t]} |f_t(x) - 1| \label{eq:local}\end{equation}

To solve the above equation, we first need to understand the "Equioscillation Theorem" regarding odd polynomials introduced in "Equioscillation Theorem: Necessary and Sufficient Conditions for Optimal Polynomial Approximation":

Equioscillation Theorem - Odd: Let $f(x)$ be an odd polynomial of degree at most $2n+1$, and $g(x)$ be a continuous function on the interval $[a,b]\subset (0,\infty)$. Then \begin{equation}f^* = \mathop{\text{argmin}}_f \max_{x\in[a,b]} |f(x) - g(x)|\end{equation} holds if and only if there exist $a\leq x_0 < x_1 < \dots < x_{n+1} \leq b$ and $\sigma\in\{0,1\}$ such that \begin{equation}f^*(x_k) - g(x_k) = (-1)^{k+\sigma} \max_{x\in[a,b]} |f^*(x) - g(x)|\end{equation}

Now, the approximating function is $f_t$, and the target $g$ is the constant 1. The equioscillation theorem tells us that $|f_t^*(x)-1|$ attains the maximum error (denoted $\mathcal{E}$) at $n+2$ points of $[l_t,u_t]$, with alternating signs. The maximum points of $|f_t^*(x)-1|$ can only be boundary points or extreme points of $f_t^*(x)$. The derivative of an odd polynomial of degree $2n+1$ is a polynomial of degree $n$ in $x^2$, so $f_t^*$ has at most $n$ extreme points in $(0,\infty)$. Therefore, to "make up" $n+2$ points, we "have to" include both boundary points. This gives $x_0 = l_t$ and $x_{n+1}=u_t$, while $x_1,\dots,x_n$ are the zeros of $\frac{d}{dx}f_t^*(x)$ in $(l_t,u_t)$.

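Furthermore, the optimal error satisfies $\mathcal{E} < 1$, since the choice $f_t(x) = x/u_t$ already gives error $1 - l_t/u_t < 1$. Hence $f_t^* > 0$ on $[l_t,u_t]$. We also have $f_t^*(0)=0$, and all $n$ extreme points $x_1,\dots,x_n$ lie inside $(l_t,u_t)$. So $f_t^*$ is increasing on $[0, x_1]$, and $f_t^*(l_t) < f_t^*(x_1)$. This means $l_t$ must be a minimum point with $f_t^*(l_t) = 1-\mathcal{E}$, i.e., $\sigma=1$. Combining these results, we are actually solving the following system of equations: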

\begin{equation}f_t(l_t) = 1 - \mathcal{E}, \quad f_t(u_t) = 1 + (-1)^n \mathcal{E},\quad f_t(x_i) = 1 + (-1)^{i+1}\mathcal{E}, \quad f_t'(x_i) = 0\end{equation}

where $i=1,2,3,\dots,n$. It can be seen that there are $2n+2$ equations and $2n+2$ unknowns. Adding the constraints $l_t < x_1 < \dots < x_n < u_t$ and $\mathcal{E} > 0$, the solution can theoretically be determined.

Solving the System of Equations

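For a degree 3 odd polynomial ($n=1$), the original paper provides an analytical solution. For a degree 5 odd polynomial ($n=2$), the paper provides an iterative algorithm with two alternating steps. First, fix $x_1, x_2$ and solve the resulting linear system for the coefficients $a, b, c$ of $f_t(x) = ax + bx^3 + cx^5$ together with $\mathcal{E}$. Then fix $a, b, c$ and recompute $x_1, x_2$ as the critical points of $f_t$. These two steps are repeated until convergence. This is essentially a simplified version of the Remez algorithm.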
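As a rough illustration of this alternating scheme, here is a NumPy sketch for a general degree $2n+1$. It is not the paper's code. Instead of a closed-form root formula, it finds the critical points numerically with `np.roots`, treating $f_t'(x)$ as a polynomial in $x^2$. The function name `remez_odd`, the evenly spaced initial points, and the stopping tolerance are all assumptions.

```python
import numpy as np


def remez_odd(l, u, n, iters=100, tol=1e-14):
    """Equioscillating odd polynomial of degree 2n+1 approximating 1 on [l, u].

    Alternates two steps:
      1. with interior points x_1..x_n fixed, solve the linear system
         f(p_k) + (-1)^k E = 1 for the coefficients and E, where
         p = (l, x_1, ..., x_n, u), so f(l) = 1 - E, f(x_1) = 1 + E, ...;
      2. move x_1..x_n to the positive critical points of the new f.
    Returns (coeffs, E, xs); coeffs[j] multiplies x^(2j+1).
    """
    powers = 2 * np.arange(n + 1) + 1              # 1, 3, ..., 2n+1
    signs = (-1.0) ** np.arange(n + 2)             # +1, -1, +1, ...
    xs = np.linspace(l, u, n + 2)[1:-1]            # initial guess: evenly spaced
    E_old = np.inf
    for _ in range(iters):
        pts = np.concatenate(([l], xs, [u]))
        A = np.column_stack([pts[:, None] ** powers, signs])
        sol = np.linalg.solve(A, np.ones(n + 2))
        coeffs, E = sol[:-1], sol[-1]
        # f'(x) = sum_j (2j+1) coeffs[j] x^(2j) is a degree-n polynomial in y = x^2;
        # np.roots expects the highest-degree coefficient first.
        ys = np.roots((powers * coeffs)[::-1])
        ys = np.sort(ys[np.abs(np.imag(ys)) < 1e-9].real)
        ys = ys[(ys > l * l) & (ys < u * u)]
        if len(ys) != n:
            raise RuntimeError("lost the n interior critical points; try a better initial guess")
        xs = np.sqrt(ys)
        if abs(E - E_old) < tol:
            break
        E_old = E
    return coeffs, E, xs


coeffs, E, xs = remez_odd(1e-3, 1.0, n=2)
print(coeffs, E, xs)
```

For $n=1$, the first linear solve already forces $f_t(l_t)=f_t(u_t)$, so the critical point lands on the closed-form value immediately. For $n=2$ with $l=0.001$ and $u=1$, it should reproduce, up to rounding, the coefficients $8.4703, -25.1081, 18.6293$ quoted later in the article.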

However, the iteration in the original paper relies on the root formula to find $x_1, x_2$. For $n=2$, $f_t'(x)=0$ is a quadratic in $x^2$, but this approach becomes impractical for larger $n$. Therefore, here I take a different approach: first parameterize $f_t'(x)$ by its zeros $x_1, x_2, \dots, x_n$, i.e., define:

\begin{equation}f_t'(x) = k(x^2-x_1^2)(x^2-x_2^2)\dots (x^2-x_n^2)\end{equation}

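Then $f_t(x) = \int_0^x f_t'(s)\,ds$. This expresses $f_t(x)$ in terms of $k$ and $x_1, x_2, \dots, x_n$, and it satisfies $f_t'(x_i) = 0$ automatically. We then only need to solve the following $n+2$ equations for the $n+2$ unknowns $k, x_1, \dots, x_n, \mathcal{E}$: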

\begin{equation}f_t(l_t) = 1 - \mathcal{E}, \quad f_t(u_t) = 1 + (-1)^n \mathcal{E},\quad f_t(x_i) = 1 + (-1)^{i+1}\mathcal{E}\end{equation}

This avoids solving the equation $f_t'(x) = 0$. When $n=1$, we can solve for:

\begin{equation}x_1 = \sqrt{\frac{l_t^2 + l_t u_t + u_t^2}{3}}, \quad k = -\frac{6}{l_t^2 u_t + l_t u_t^2 + 2x_1^3}\end{equation}
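The $n=1$ formula translates directly into code. The range of $f_t^*$ is $[1-\mathcal{E}, 1+\mathcal{E}]$, so the greedy recursion becomes $l_{t+1} = f_t^*(l_t)$ and $u_{t+1} = 2 - l_{t+1}$. Following it gives an optimal sequence of cubic iterations. The sketch below uses $f_t(x) = k(x^3/3 - x_1^2 x)$, obtained by integrating $f_t'$. The names `optimal_cubic` and `greedy_cubics` are hypothetical.

```python
import math


def optimal_cubic(l, u):
    """Closed-form optimal odd cubic f(x) = a*x + b*x**3 on [l, u] (the n = 1 case).

    Integrating f'(x) = k (x^2 - x1^2) gives f(x) = k (x^3/3 - x1^2 x),
    with x1 and k taken from the formulas in the text.
    Returns (a, b, E), where E = 1 - f(l) is the maximum error.
    """
    x1 = math.sqrt((l * l + l * u + u * u) / 3)
    k = -6 / (l * l * u + l * u * u + 2 * x1**3)
    a, b = -k * x1**2, k / 3
    E = 1 - (a * l + b * l**3)
    return a, b, E


def greedy_cubics(l, steps):
    """Greedy composition: l_{t+1} = f_t(l_t), u_{t+1} = 2 - l_{t+1}, starting from u = 1."""
    u, coeffs = 1.0, []
    for _ in range(steps):
        a, b, _err = optimal_cubic(l, u)
        coeffs.append((a, b))
        l = a * l + b * l**3
        u = 2 - l
    return coeffs, l


print(optimal_cubic(1.0, 1.0))   # the classic step 1.5x - 0.5x^3, with zero error
coeffs, l_final = greedy_cubics(1e-3, 12)
print(coeffs[0], l_final)
```

With $l_t = u_t = 1$, the formula returns $1.5x - 0.5x^3$, the classic cubic Newton-Schulz step. Scaling $l_t$ and $u_t$ by a common factor leaves $\mathcal{E}$ unchanged, which is the $n=1$ case of the scaling argument used in the proof at the end.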

When $n > 1$, we can hand it over to Mathematica. For example, when $n=2$:

df[x_] = k*(x^2 - x1^2) (x^2 - x2^2);
f[x_] = Integrate[df[x], {x, 0, x}];
sol = NSolve[{f[l] == 1 - e, f[x1] == 1 + e, f[x2] == 1 - e,
 f[u] == 1 + e, l < x1 < x2 < u, e > 0} /. {l -> 0.001,
 u -> 1}, {k, x1, x2, e}, Reals]
f[x] /. sol

Finite Precision

So far, have we solved the original problem? In theory, yes, but only under infinite precision. In practice, computation uses finite precision. The Muon optimizer in particular uses bfloat16, where precision loss is more severe, and this leads to several problems.

The first problem is that each $f_t^*$ is theoretically only responsible for the interval $[l_t, u_t]$, but under finite precision, singular values may drift outside this interval. When $n$ is even (i.e., $f_t^*$ is an odd polynomial of degree 5, 9, ...), drifting above $u_t$ risks divergence. For $x > u_t$, $f_t^*(x)$ increases monotonically to infinity, and the excess can be amplified over subsequent iterations. There are two remedies. One is to leave a slightly wider margin for $[l_t, u_t]$ when solving for $f_t^*$. The other is to keep the interval unchanged and, after solving, divide the input of $f_t^*$ by a number slightly greater than 1. The original paper uses the latter, replacing $f_t^*(x)$ with $f_t^*(x / 1.01)$.

The second problem is more subtle. Let's introduce it with a specific example. Suppose $n=2, l_1=0.001, u_1=1$. We can solve for $f_1^*$:

\begin{equation}f_1^*(x) = 8.4703 x - 25.1081 x^3 + 18.6293 x^5\end{equation}

where $x_1 = 0.3674, x_2 = 0.8208, \mathcal{E}=0.9915$. What is the problem with this solution? According to the equioscillation theorem, we know $f_1^*(x_2) = 1-\mathcal{E} = 0.0085$, meaning it maps $0.8208$ to $0.0085$. However, our ultimate goal is to turn all numbers in $(0,1]$ into 1. Thus, $f_1^*$ maps a value already close to the target, $0.8208$, to $0.0085$, which is very far from the target. Although $f_2^*, f_3^*, \dots$ will theoretically pull it back eventually, in finite precision, repeatedly shrinking and expanding a number can lead to significant accumulated error.
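These numbers are easy to check. The snippet below is a sanity check that uses the rounded coefficients quoted above, so its residuals are only accurate to about $10^{-4}$. It verifies the alternation $1\mp\mathcal{E}$ at $l, x_1, x_2, u$ and the mapping $0.8208 \mapsto 0.0085$. It also shows the effect of the $x/1.01$ safety factor on a 1% overshoot past $u$.

```python
import numpy as np

# f_1^* for n = 2, l = 0.001, u = 1, with the rounded values quoted above
a, b, c = 8.4703, -25.1081, 18.6293
E = 0.9915


def f(x):
    return a * x + b * x**3 + c * x**5


def df(x):
    return a + 3 * b * x**2 + 5 * c * x**4


points = np.array([0.001, 0.3674, 0.8208, 1.0])    # l, x_1, x_2, u
targets = 1 + np.array([-1.0, 1.0, -1.0, 1.0]) * E  # 1 - E, 1 + E, 1 - E, 1 + E
print(f(points) - targets)      # small residuals, limited by the 4-decimal rounding
print(df(points[1:3]))          # f' is (nearly) zero at x_1 and x_2
print(f(0.8208))                # a value already near 1 is sent to about 0.0085
print(f(1.01), f(1.01 / 1.01))  # 1% past u: well above 1 + E without the factor, 1 + E with it
```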

Of course, the equioscillation theorem tells us this oscillating behavior is unavoidable. We can only hope that the maximum error $\mathcal{E}$ is not too close to 1, which slows down the accumulation of error. The wider the interval $[l_t, u_t]$ relative to its endpoints (i.e., the smaller $l_t/u_t$), the harder it is to fit, and the closer $\mathcal{E}$ gets to 1. Therefore, the paper introduces a hyperparameter $\lambda \in (0, 1)$ that changes the optimization interval from $[l_t, u_t]$ to $[\max(l_t, \lambda u_t), u_t]$. Restricting the interval this way keeps $\mathcal{E}$ from getting too large. (Note: the paper uses $\lambda=0.1$ in the main text explanation, but the appendix code uses $\lambda=0.024$.)

However, doesn't this mean the original $l_t$, especially our initial $l$, is effectively ignored? To address this, the paper introduces a "Recenter" trick. If the optimization interval is $[l_t, u_t]$, then for even $n$ (such as the quintic case here) we have $f_t^*(l_t) + f_t^*(u_t) = (1-\mathcal{E}) + (1+\mathcal{E}) = 2$. This may no longer hold after changing the interval to $[\max(l_t, \lambda u_t), u_t]$. In that case, we multiply $f_t^*$ by a factor $\gamma$ chosen so that the equation holds again:

\begin{equation}\gamma f_t^*(l_t) + \gamma f_t^*(u_t) = 2 \quad \Rightarrow \quad \gamma = \frac{2}{f_t^*(l_t) + f_t^*(u_t)}\end{equation}

This brings the original $l_t$ back into consideration.
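Here is a minimal sketch of the recentering step. It assumes coefficients are stored by ascending odd power, and the helper name `recenter` is hypothetical. As a check, consider the $t=1$ row of the table in the next section, which was produced with the cushion and recentering. It should already satisfy $f(l)+f(u)\approx 2$ for $l=0.001$ and $u=1$, so its $\gamma$ should come out close to 1.

```python
def recenter(coeffs, l, u):
    """Rescale an odd polynomial so that gamma*p(l) + gamma*p(u) = 2.

    coeffs[j] multiplies x^(2j+1). Returns (rescaled coefficients, gamma).
    """
    def p(x):
        return sum(c * x ** (2 * j + 1) for j, c in enumerate(coeffs))

    gamma = 2 / (p(l) + p(u))
    return [gamma * c for c in coeffs], gamma


# Toy example: p(x) = x on [0.5, 1] gives p(l) + p(u) = 1.5, so gamma = 4/3
new, gamma = recenter([1.0], 0.5, 1.0)
print(new, gamma)

# t = 1 row of the table (before the 1.01 safety factor): already recentered
_, gamma1 = recenter([8.28721, -23.5959, 17.3004], 1e-3, 1.0)
print(gamma1)   # close to 1
```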

Reference Code

This is the complete Mathematica code for $n=2$:

df[x_] = k*(x^2 - x1^2) (x^2 - x2^2);
f[x_] = Integrate[df[x], {x, 0, x}];
sol[l_, u_] :=
 NSolve[{f[l] == 1 - e, f[x1] == 1 + e, f[x2] == 1 - e, f[u] == 1 + e,
 l < x1 < x2 < u, e > 0, k > 0}, {k, x1, x2, e}]
ff[x_, l_, u_] = f[x]*2/(f[l] + f[u]) // Expand;
lt = 0.001; ut = 1; lambda = 0.02407327424182761;
While[1 - lt > 0.0001,
 fff[x_] = ff[x, lt, ut] /. sol[Max[lt, lambda*ut], ut][[1]];
 Print[fff[x]];
 lt = fff[lt]; ut = 2 - lt]

Results are as follows ($f_t(x) = a_t x + b_t x^3 + c_t x^5$):

\begin{array}{c|ccc} \hline t & a \times 1.01 & b \times 1.01^3 & c \times 1.01^5 \\ \hline \quad 1 \quad & 8.28721 & -23.5959 & 17.3004 \\ 2 & 4.10706 & -2.94785 & 0.544843 \\ 3 & 3.94869 & -2.9089 & 0.551819 \\ 4 & 3.31842 & -2.48849 & 0.510049 \\ 5 & 2.30065 & -1.6689 & 0.418807 \\ 6 & 1.8913 & -1.268 & 0.376804 \\ 7 & 1.875 & -1.25 & 0.375 \\ 8 & 1.875 & -1.25 & 0.375 \\ \hline \end{array}

Note that the results given here are before the $f_t^*(x / 1.01)$ processing, so the actual $a, b, c$ are the tabulated values divided by $1.01$, $1.01^3$, and $1.01^5$, respectively. I did not give the results after dividing by $1.01$ directly because the limiting values $1.875, -1.25, 0.375$ (for $t \geq 7$), i.e. $15/8, -10/8, 3/8$, are cleaner and easier to recognize.
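To see the table in action, the sketch below divides each row by the corresponding powers of $1.01$. It then applies the resulting quintics as the matrix iteration $\boldsymbol{X} \leftarrow a\boldsymbol{X} + b(\boldsymbol{X}\boldsymbol{X}^{\top})\boldsymbol{X} + c(\boldsymbol{X}\boldsymbol{X}^{\top})^2\boldsymbol{X}$, which acts on each singular value as $x \mapsto ax+bx^3+cx^5$. A few assumptions apply. The sketch runs in float64 rather than bfloat16, it normalizes by the Frobenius norm so that all singular values are at most 1, and the function name `msign_newton_schulz` is hypothetical.

```python
import numpy as np

# Rows of the table above: (a*1.01, b*1.01^3, c*1.01^5) for t = 1..8
TABLE = [
    (8.28721, -23.5959, 17.3004),
    (4.10706, -2.94785, 0.544843),
    (3.94869, -2.9089, 0.551819),
    (3.31842, -2.48849, 0.510049),
    (2.30065, -1.6689, 0.418807),
    (1.8913, -1.268, 0.376804),
    (1.875, -1.25, 0.375),
    (1.875, -1.25, 0.375),
]
# Undo the display scaling: f_t(x/1.01) has coefficients a/1.01, b/1.01^3, c/1.01^5
COEFFS = [(a / 1.01, b / 1.01**3, c / 1.01**5) for a, b, c in TABLE]


def msign_newton_schulz(M, coeffs=COEFFS):
    """Approximate msign(M) by X <- a X + b (X X^T) X + c (X X^T)^2 X.

    Each step maps every singular value x of X to a x + b x^3 + c x^5.
    """
    X = M / np.linalg.norm(M)   # Frobenius norm >= spectral norm, so singular values <= 1
    for a, b, c in coeffs:
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X


rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((8, 8)))
V, _ = np.linalg.qr(rng.standard_normal((8, 8)))
s = np.geomspace(0.01, 1.0, 8)          # singular values spanning two decades
M = U @ np.diag(s) @ V.T
X = msign_newton_schulz(M)
print(np.abs(X - U @ V.T).max())        # distance to the exact msign(M) = U V^T
```

The result should be close to $\boldsymbol{U}\boldsymbol{V}^{\top}$ when the normalized singular values lie in $[0.001, 1]$. Singular values below $0.001$ fall outside the designed range and may not converge within 8 steps.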

The code from the paper's appendix is as follows:

import numpy as np

def optimal_quintic(l, u):
    assert 0 <= l <= u
    if 1 - 5e-6 <= l / u:
        # Above this threshold, the equioscillating polynomial
        # is numerically equal to...
        return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
    
    # This initialization becomes exact as l -> u
    q = (3 * l + 1) / 4
    r = (l + 3) / 4
    E, old_E = np.inf, None
    while not old_E or abs(old_E - E) > 1e-15:
        old_E = E
        LHS = np.array([
            [l, l**3, l**5, 1],
            [q, q**3, q**5, -1],
            [r, r**3, r**5, 1],
            [u, u**3, u**5, -1],
        ])
        a, b, c, E = np.linalg.solve(LHS, np.ones(4))
        q, r = np.sqrt(
            (-3 * b + np.array([-1, 1]) * np.sqrt(9 * b**2 - 20 * a * c)) /
            (10 * c)
        )
    return float(a), float(b), float(c)

def optimal_composition(l, num_iters, cushion=0.02407327424182761):
    u = 1
    coefficients = []
    for _ in range(num_iters):
        a, b, c = optimal_quintic(max(l, cushion * u), u)
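        # Due to cushioning, this polynomial may be centered around 1 with
        # respect to [cushion*u, u] rather than [l, u]. Recenter it around 1
        # with respect to [l, u], i.e. find the scale factor r (rescalar) such
        # that 1 - r*p(l) = r*p(u) - 1, which gives r = 2 / (p(l) + p(u)):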
        pl = a * l + b * l**3 + c * l**5
        pu = a * u + b * u**3 + c * u**5
        rescalar = 2 / (pl + pu)
        a *= rescalar
        b *= rescalar
        c *= rescalar
        # Optionally incorporate safety factor here:
        # a /= 1.01; b /= 1.01**3; c /= 1.01**5
        coefficients.append((a, b, c))
        l = a * l + b * l**3 + c * l**5
        u = 2 - l
    return coefficients

print(*optimal_composition(1e-3, 10), sep="\n")

Completing the Proof

In the final section, we complete the proof that "the greedy solution is exactly the global optimal solution."

By the equioscillation theorem, the range of $f_t^*$ on $[l_t,u_t]$ is $[l_{t+1}, u_{t+1}]$, where $l_{t+1}=f_t^*(l_t)$ and $u_{t+1}=2-l_{t+1}$. Hence the maximum error of the $T$-step greedy solution is $\mathcal{E}_T = 1 - l_{T+1} = 1 - f_T^*(l_T)$. To conclude that "the greedy solution is the global optimal solution," we only need to prove that no $T$-step solution can achieve a maximum error below $1 - f_T^*(l_T)$.

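The proof strategy is mathematical induction. For $T=1$, the greedy problem and the global problem are the same, so the claim holds trivially. Suppose the conclusion holds for $t=1, 2, \dots, T-1$. Then $\hat{f} = f_{T-1}^* \circ \dots \circ f_2^* \circ f_1^*$ is the global optimal solution for $T-1$ steps, with range $[l_T, u_T]$ and maximum error $\mathcal{E}_{T-1} = 1 - l_T = u_T - 1$. On the other hand, let $\tilde{f} = \tilde{f}_{T-1} \circ \dots \circ \tilde{f}_2 \circ \tilde{f}_1$ be any $T-1$ step solution with range $[a,b]$. Let $c = \frac{2}{a+b}$; then the range of $c\tilde{f}$ is $[ca, cb]$. Moreover, $c\tilde{f}$ is itself a $T-1$ step solution, since the constant $c$ can be absorbed into the odd polynomial $\tilde{f}_{T-1}$. Clearly $ca \leq 1 \leq cb$, and in fact $1 - ca = cb - 1 = \frac{b-a}{a+b}$, which is exactly the maximum error of $c\tilde{f}$. By the inductive hypothesis, this error is no smaller than $\mathcal{E}_{T-1}$, so: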

\begin{equation}\begin{aligned} 1 - ca \geq \mathcal{E}_{T-1} \\ cb - 1 \geq \mathcal{E}_{T-1} \end{aligned} \quad \Rightarrow \quad \frac{a}{b} \leq \frac{1 - \mathcal{E}_{T-1}}{1 + \mathcal{E}_{T-1}} = \frac{l_T}{u_T} \end{equation}

That is, the ratio $a/b$ for any $T-1$ step solution is at most $l_T/u_T$. Relative to its scale, the range of any $T-1$ step solution is at least as wide as the range $[l_T, u_T]$ of the optimal $T-1$ step solution. Then we have:

\begin{equation}\begin{aligned} \min_{f_T} \max_{x\in[l,u]} |f_T(\tilde{f}(x)) - 1| =& \min_{f_T} \max_{x\in[a,b]} |f_T(x) - 1| \\ =& \min_{f_T} \max_{x\in[a/b,1]} |f_T(x) - 1| \\ \geq & \min_{f_T} \max_{x\in[l_T/u_T,1]} |f_T(x) - 1| \\ =& \min_{f_T} \max_{x\in[l_T,u_T]} |f_T(x) - 1| \\ =& \mathcal{E}_T \end{aligned}\end{equation}

In other words, no matter which other $T-1$ step solution is used, the final maximum error can at best match that of the greedy solution. Therefore, the maximum error of the greedy solution is already globally optimal, which completes the induction. The key step in the above chain is:

\begin{equation}\min_{f_T} \max_{x\in[a,b]} |f_T(x) - 1| = \min_{f_T} \max_{x\in[a/b,1]} |f_T(x) - 1|\end{equation}

This is because we can always set $g_T(y) = f_T(b y)$. $g_T$ still represents any odd polynomial of the same degree. Thus, $g_T$ and $f_T$ are in the same function space, and the notation can be substituted, i.e.:

\begin{equation}\min_{f_T} \max_{x\in[a,b]} |f_T(x) - 1| = \min_{g_T} \max_{y\in[a/b,1]} |g_T(y) - 1| = \min_{f_T} \max_{x\in[a/b,1]} |f_T(x) - 1|\end{equation}

Summary

This article introduced the latest progress in finding better Newton-Schulz iterations for the $\mathop{\text{msign}}$ operator. By using the equioscillation theorem and greedy transformation, it directly derives the theoretically optimal solution. The entire process is quite hardcore and well worth learning.