By 苏剑林 | March 12, 2026
The core operation of Muon is msign, and its current standard implementation is the Newton-Schulz iteration. It is undeniably an efficient, GPU-friendly algorithm, and a large part of Muon's popularity is arguably owed to it. However, it also feels rather single-purpose: it seems tailored strictly to computing msign. As soon as we want to modify Muon (for example, replacing msign with mclip), the corresponding computations become cumbersome.
This article proposes a different implementation: approximating the SVD through streaming power iteration. The idea is not entirely new, since it has appeared in several previous optimizer works, but here we extract it as a standalone algorithm.
Review of Content #
We won't go into the specifics of Muon here; readers can refer to previous articles such as "Appreciating the Muon Optimizer: The Essential Leap from Vectors to Matrices", "Muon Sequel: Why Do We Choose to Try Muon?", and "Muon Optimizer Guide: Quick Start and Key Details". Here, we directly provide its formulas:
\begin{equation}
\begin{aligned}
\boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt]
\boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t [\text{msign}(\boldsymbol{M}_t) + \lambda \boldsymbol{W}_{t-1}] \\
\end{aligned}
\end{equation}
where $\text{msign}$ is defined as:
\begin{equation}
\text{msign}(\boldsymbol{M})=\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}=\boldsymbol{U}_{[:, :r]}\boldsymbol{V}_{[:, :r]}^{\top}
\end{equation}
Here $\boldsymbol{M}\in\mathbb{R}^{n\times m}$. Without loss of generality, we assume $n\geq m$, and for simplicity, we mostly assume $r=m$ (full rank), discussing rank-deficiency only when strictly necessary.
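As a quick sanity check on the definition, the following minimal JAX sketch computes msign both ways: as $\boldsymbol{U}\boldsymbol{V}^{\top}$ from a thin SVD, and as $\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}$ via an eigendecomposition of $\boldsymbol{M}^{\top}\boldsymbol{M}$. The function names `msign_svd` and `msign_eigh` are hypothetical, and both assume the full-rank case $r=m$.

```python
import jax
import jax.numpy as jnp

def msign_svd(M):
    """msign(M) = U V^T from a thin SVD; assumes full column rank (r = m)."""
    U, _, Vt = jnp.linalg.svd(M, full_matrices=False)
    return U @ Vt

def msign_eigh(M):
    """msign(M) = M (M^T M)^{-1/2}; needs M^T M to be invertible (r = m)."""
    evals, Q = jnp.linalg.eigh(M.T @ M)       # M^T M = Q diag(evals) Q^T
    inv_sqrt = (Q / jnp.sqrt(evals)) @ Q.T    # Q diag(evals^{-1/2}) Q^T
    return M @ inv_sqrt

# Usage: a random 6 x 3 Gaussian matrix (n >= m) has full column rank almost surely.
M = jax.random.normal(jax.random.PRNGKey(0), (6, 3))
print(jnp.max(jnp.abs(msign_svd(M) - msign_eigh(M))))  # maximum elementwise difference
```

The eigendecomposition route squares the condition number of $\boldsymbol{M}$, so in low precision the SVD form is the safer reference.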
Since SVD is expensive, in most cases we use Newton-Schulz iteration to calculate msign, which we have discussed in detail in "Newton-Schulz Iteration for the msign Operator (Part 1)" and "Newton-Schulz Iteration for the msign Operator (Part 2)". Overall, Newton-Schulz iteration is very clever and is the main contributor to Muon's success, but its extensibility is relatively weak.
To broaden the applicability of Newton-Schulz iteration, I have previously done some related work, such as "Calculating Singular Value Clipping mclip via msign (Part 1)", "Calculating Singular Value Clipping mclip via msign (Part 2)", "Efficient Calculation of Matrix Square Root and Inverse Square Root", and "Efficient Calculation of Matrix r-th Root and Inverse r-th Root". Even so, what can be achieved this way remains quite limited.
Obviously, the once-and-for-all method is to directly solve for the SVD, which is the approach we will focus on next.
Power Iteration #
In articles like "Lipschitz Constraint in Deep Learning: Generalization and Generative Models" and "Reflections from Spectral Norm Gradients to New Forms of Weight Decay", we have briefly encountered Power Iteration. We use it to find the dominant eigenvector of $\boldsymbol{M}^{\top}\boldsymbol{M}$, or the dominant right singular vector of $\boldsymbol{M}$, with the following iteration format:
\begin{equation}
\boldsymbol{v}_1^{(t)} = \frac{\boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{v}_1^{(t-1)}}{\Vert\boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{v}_1^{(t-1)}\Vert_2}
\end{equation}
Assuming we have obtained the dominant eigenvector $\boldsymbol{v}_1$, we can add orthogonalization to the power iteration to find the next eigenvector:
\begin{equation}
\boldsymbol{v}_2^{(t)} = \frac{\tilde{\boldsymbol{v}}_2^{(t)} - \langle\tilde{\boldsymbol{v}}_2^{(t)},\boldsymbol{v}_1\rangle\boldsymbol{v}_1}{\Vert\tilde{\boldsymbol{v}}_2^{(t)} - \langle\tilde{\boldsymbol{v}}_2^{(t)},\boldsymbol{v}_1\rangle\boldsymbol{v}_1\Vert_2},\qquad \tilde{\boldsymbol{v}}_2^{(t)} = \boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{v}_2^{(t-1)}
\end{equation}
By keeping the iterate orthogonal to $\boldsymbol{v}_1$, this iteration converges to the second eigenvector $\boldsymbol{v}_2$. Similarly, given $\boldsymbol{v}_1, \boldsymbol{v}_2, \dots, \boldsymbol{v}_{k-1}$, we can use Gram-Schmidt orthogonalization to find the $k$-th eigenvector:
\begin{equation}
\boldsymbol{v}_k^{(t)} = \frac{\tilde{\boldsymbol{v}}_k^{(t)} - \langle\tilde{\boldsymbol{v}}_k^{(t)},\boldsymbol{v}_1\rangle\boldsymbol{v}_1 - \cdots - \langle\tilde{\boldsymbol{v}}_k^{(t)},\boldsymbol{v}_{k-1}\rangle\boldsymbol{v}_{k-1}}{\Vert\tilde{\boldsymbol{v}}_k^{(t)} - \langle\tilde{\boldsymbol{v}}_k^{(t)},\boldsymbol{v}_1\rangle\boldsymbol{v}_1 - \cdots - \langle\tilde{\boldsymbol{v}}_k^{(t)},\boldsymbol{v}_{k-1}\rangle\boldsymbol{v}_{k-1}\Vert_2},\qquad \tilde{\boldsymbol{v}}_k^{(t)} = \boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{v}_k^{(t-1)}
\label{eq:vk-pi}
\end{equation}
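A minimal JAX sketch of this one-vector-at-a-time scheme follows, assuming a random starting vector for each $\boldsymbol{v}_k$ and a fixed iteration count; `deflated_power_iteration` is an illustrative name, not an existing API. It is deliberately naive and only meant to mirror Eq. $\eqref{eq:vk-pi}$ term by term.

```python
import jax
import jax.numpy as jnp

def deflated_power_iteration(M, k, num_iters=200, seed=0):
    """Top-k eigenvectors of M^T M, found one at a time.

    Each v_j is power-iterated while being Gram-Schmidt-orthogonalized
    against the previously found v_1, ..., v_{j-1}.
    """
    A = M.T @ M
    m = A.shape[0]
    keys = jax.random.split(jax.random.PRNGKey(seed), k)
    found = []
    for j in range(k):
        v = jax.random.normal(keys[j], (m,))   # random start for v_j
        for _ in range(num_iters):
            v = A @ v                          # tilde v = M^T M v
            for u in found:                    # subtract <tilde v, v_i> v_i
                v = v - jnp.dot(v, u) * u
            v = v / jnp.linalg.norm(v)         # normalize to unit L2 norm
        found.append(v)
    return jnp.stack(found, axis=1)            # (m, k), columns v_1, ..., v_k

# Usage on a random 8 x 4 matrix: the first two right singular vectors.
V2 = deflated_power_iteration(jax.random.normal(jax.random.PRNGKey(0), (8, 4)), k=2)
print(V2.shape)
```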
In practice, we don't need to wait for $\boldsymbol{v}_1, \boldsymbol{v}_2, \dots, \boldsymbol{v}_{k-1}$ to be calculated before calculating $\boldsymbol{v}_k$; the entire $\boldsymbol{V}=[\boldsymbol{v}_1, \boldsymbol{v}_2, \dots, \boldsymbol{v}_m]$ can be iterated in parallel. Specifically, starting from an existing approximation $\boldsymbol{V}_{t-1}$, we batch calculate $\boldsymbol{M}^{\top} \boldsymbol{M} \boldsymbol{V}_{t-1}$, and then re-orthogonalize the column vectors (e.g., using QR decomposition). This yields a better approximation, which we denote as $\boldsymbol{V}_t$. Repeated iteration will ultimately converge to our target $\boldsymbol{V}$:
\begin{equation}
\newcommand{\QR}{\mathop{\text{QR}}}\boldsymbol{V}_t = \QR(\boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{V}_{t-1})
\end{equation}
Here $\text{QR}$ denotes the orthogonal factor $\boldsymbol{Q}$ of the QR decomposition. Once we have $\boldsymbol{V}$, note that $\boldsymbol{M}\boldsymbol{V}=\boldsymbol{U}\boldsymbol{\Sigma}$, so $\newcommand{\ColNorm}{\mathop{\text{ColNorm}}}\boldsymbol{U} = \ColNorm(\boldsymbol{M}\boldsymbol{V})$, where $\text{ColNorm}$ means L2-normalizing each column (axis=0), and $\newcommand{\diag}{\mathop{\text{diag}}}\boldsymbol{\Sigma}=\diag(\boldsymbol{U}^{\top}\boldsymbol{M}\boldsymbol{V})$. This gives an approximate SVD scheme based on power iteration and QR decomposition. Of course, when $n > m$ it only yields the thin decomposition, with $\boldsymbol{U}\in\mathbb{R}^{n\times m}$ and $\boldsymbol{\Sigma}, \boldsymbol{V}\in\mathbb{R}^{m \times m}$, but that is sufficient for our purposes.
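To make the batched scheme concrete, here is a minimal JAX sketch of the approximate SVD just described: repeat $\boldsymbol{V}\leftarrow\text{QR}(\boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{V})$ starting from $\boldsymbol{V}=\boldsymbol{I}$, then read off $\boldsymbol{U}$ and $\boldsymbol{\Sigma}$. The names `colnorm` and `power_svd` and the fixed iteration count are illustrative assumptions, and `jnp.linalg.qr` simply stands in for the QR step.

```python
import jax
import jax.numpy as jnp

def colnorm(X, eps=1e-12):
    """ColNorm: L2-normalize each column (reduce over axis 0)."""
    return X / (jnp.linalg.norm(X, axis=0, keepdims=True) + eps)

def power_svd(M, num_iters=100):
    """Approximate thin SVD of M (n x m, n >= m) via V <- QR(M^T M V).

    Returns U (n x m), singular values S (m,), and V (m x m).
    """
    m = M.shape[1]
    V = jnp.eye(m, dtype=M.dtype)             # start from the identity
    for _ in range(num_iters):
        V, _ = jnp.linalg.qr(M.T @ (M @ V))   # keep only the orthogonal factor
    MV = M @ V                                # equals U Sigma at convergence
    U = colnorm(MV)
    S = jnp.sum(U * MV, axis=0)               # diag(U^T M V) without the full product
    return U, S, V

# Usage: compare with the framework SVD on a random 8 x 4 matrix.
M = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
U, S, V = power_svd(M)
print(S, jnp.linalg.svd(M, compute_uv=False))
```

Sign flips in the columns of $\boldsymbol{V}$ are harmless: they propagate to $\boldsymbol{U}$ through $\boldsymbol{M}\boldsymbol{V}$ and cancel in $\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$.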
Streaming Update #
In practice, however, computing the SVD by power iteration is very inefficient, far slower than directly calling the SVD routine provided by the framework, so on its own it is impractical. But training is itself a long iterative process, so we can assume that $\boldsymbol{V}$ changes little from one step to the next. We can therefore cache $\boldsymbol{V}$ from the previous step as the initialization for the current step and perform only one power iteration per step, i.e.,
\begin{equation}
\begin{aligned}
\boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt]
\boldsymbol{V}_t =&\, \QR(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}) \\[5pt]
\boldsymbol{U}_t =&\, \ColNorm(\boldsymbol{M}_t\boldsymbol{V}_t) \\[5pt]
\boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t (\boldsymbol{U}_t\boldsymbol{V}_t^{\top} + \lambda \boldsymbol{W}_{t-1}) \\
\end{aligned}
\label{eq:muon-qr}
\end{equation}
where $\boldsymbol{V}_0=\boldsymbol{I}$. In empirical tests, the LM loss curve of Muon implemented via this streaming power iteration nearly overlaps with that of the Newton-Schulz version, indicating that the scheme is indeed viable. This is largely thanks to momentum and small learning rates, which make the assumption that "$\boldsymbol{V}$ changes little at each step" approximately hold, so that the cost of power iteration can be amortized across steps.
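As a rough illustration, one step of the streaming update above might look like the following JAX sketch. The function name, the hyperparameter names `lr`, `beta`, `wd`, and the default `beta=0.95` are assumptions for illustration; the caller is expected to keep `V` between steps (starting from the identity) and to pass matrices with $n\geq m$.

```python
import jax
import jax.numpy as jnp

def colnorm(X, eps=1e-12):
    """L2-normalize each column (reduce over axis 0)."""
    return X / (jnp.linalg.norm(X, axis=0, keepdims=True) + eps)

def streaming_muon_step(W, M, V, G, lr, beta=0.95, wd=0.0):
    """One Muon step with msign approximated by a single warm-started power iteration.

    W, M, G: (n, m) with n >= m; V: (m, m), carried across steps.
    """
    M = beta * M + G                       # momentum: M_t = beta M_{t-1} + G_t
    V, _ = jnp.linalg.qr(M.T @ (M @ V))    # V_t = QR(M_t^T M_t V_{t-1})
    U = colnorm(M @ V)                     # U_t = ColNorm(M_t V_t)
    W = W - lr * (U @ V.T + wd * W)        # W_t = W_{t-1} - lr (U_t V_t^T + wd W_{t-1})
    return W, M, V

# Usage: a few steps with random stand-in gradients.
key = jax.random.PRNGKey(0)
W, M, V = jnp.zeros((8, 4)), jnp.zeros((8, 4)), jnp.eye(4)
for _ in range(3):
    key, sub = jax.random.split(key)
    G = jax.random.normal(sub, (8, 4))     # placeholder for a real gradient
    W, M, V = streaming_muon_step(W, M, V, G, lr=0.02)
```

Because `V` is carried over, each call costs two matrix products with $\boldsymbol{M}_t$ and one QR of an $m\times m$ matrix, instead of a full iterative SVD.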
Since we now approximate the SVD directly, we can also manipulate the singular values and incorporate them into the optimizer, for example:
\begin{equation}
\begin{aligned}
\boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt]
\boldsymbol{V}_t =&\, \QR(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t \boldsymbol{V}_{t-1}) \\[5pt]
\boldsymbol{U}_t =&\, \ColNorm(\boldsymbol{M}_t\boldsymbol{V}_t) \\[5pt]
\boldsymbol{\Sigma}_t =&\, \diag(\boldsymbol{U}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_t) \\[5pt]
\boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t (\boldsymbol{U}_t f(\boldsymbol{\Sigma}_t)\boldsymbol{V}_t^{\top} + \lambda \boldsymbol{W}_{t-1}) \\
\end{aligned}
\end{equation}
This makes it much easier to implement mclip or Muon variants based on general Schatten norms. In short, having explicit (even if approximate) $\boldsymbol{U}_t, \boldsymbol{\Sigma}_t, \boldsymbol{V}_t$ makes small modifications easy to try, which significantly improves extensibility and leaves much more room for experimentation.
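For example, a generic spectral update $\boldsymbol{U}_t f(\boldsymbol{\Sigma}_t)\boldsymbol{V}_t^{\top}$ can be sketched in JAX as below. Here `spectral_step` is a hypothetical name, momentum and weight decay are omitted to keep the focus on $f$, and `mclip_f` assumes that mclip means clipping each singular value at 1, i.e. $f(\sigma)=\min(\sigma,1)$; taking $f\equiv 1$ recovers the msign update.

```python
import jax
import jax.numpy as jnp

def colnorm(X, eps=1e-12):
    """L2-normalize each column (reduce over axis 0)."""
    return X / (jnp.linalg.norm(X, axis=0, keepdims=True) + eps)

def spectral_step(M, V, f):
    """One streaming power-iteration step, then the update U f(Sigma) V^T."""
    V, _ = jnp.linalg.qr(M.T @ (M @ V))    # V_t = QR(M_t^T M_t V_{t-1})
    MV = M @ V
    U = colnorm(MV)                        # U_t = ColNorm(M_t V_t)
    S = jnp.sum(U * MV, axis=0)            # Sigma_t = diag(U_t^T M_t V_t)
    return (U * f(S)) @ V.T, V             # scale column k of U by f(sigma_k)

msign_f = jnp.ones_like                    # f(sigma) = 1: the msign update
mclip_f = lambda s: jnp.minimum(s, 1.0)    # assumed mclip: clip singular values at 1

# Usage: a few warm-started steps on a fixed random matrix.
M = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
V = jnp.eye(4)
for _ in range(5):
    update, V = spectral_step(M, V, mclip_f)
```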
Accelerating Decomposition #
The bottleneck now shifts to the QR decomposition, the most time-consuming step in Eq. $\eqref{eq:muon-qr}$. The standard implementation is Householder QR, which, although already much faster than SVD, is still slower than msign computed via Newton-Schulz iteration (which consists only of polynomial iterations and can run its matrix multiplications in BF16, practically "cheating"). Therefore, to make this new scheme competitive, we need to speed up the QR decomposition.
For a given matrix $\boldsymbol{A}\in\mathbb{R}^{n\times m}$ ($n\geq m$), QR decomposition seeks an orthogonal matrix $\boldsymbol{Q}\in\mathbb{R}^{n\times m}$ and an upper triangular matrix $\boldsymbol{R}\in\mathbb{R}^{m\times m}$ such that $\boldsymbol{A}=\boldsymbol{Q}\boldsymbol{R}$ (here the "orthogonal" matrix only needs to satisfy $\boldsymbol{Q}^{\top}\boldsymbol{Q}=\boldsymbol{I}$; it is more accurately called a Stiefel matrix). Since $\boldsymbol{Q}^{\top}\boldsymbol{Q}=\boldsymbol{I}$, we have $\boldsymbol{A}^{\top}\boldsymbol{A}=\boldsymbol{R}^{\top}\boldsymbol{Q}^{\top}\boldsymbol{Q}\boldsymbol{R}=\boldsymbol{R}^{\top}\boldsymbol{R}$, so to obtain $\boldsymbol{R}$ we only need to factor $\boldsymbol{A}^{\top}\boldsymbol{A}$ into a lower triangular matrix times its transpose, which is exactly what Cholesky decomposition does!
Cholesky decomposition is extremely efficient. So, as a first step, we can use it to get $\boldsymbol{R}$, and then solve the equation $\boldsymbol{Q}\boldsymbol{R}=\boldsymbol{A}$ to get $\boldsymbol{Q}$. The equation can also be written as $\boldsymbol{R}^{\top}\boldsymbol{Q}^{\top}=\boldsymbol{A}^{\top}$, which can be solved using solve_triangular, which is also very efficient. Combining these two steps constitutes the QR decomposition algorithm known as "Cholesky QR." Disregarding numerical stability, it is perhaps the fastest QR decomposition method.
Unfortunately, compared to standard QR decomposition, Cholesky QR is very unstable; it is extremely sensitive to the condition number of $\boldsymbol{A}^{\top}\boldsymbol{A}$. To address this, "Shifted CholeskyQR for computing the QR factorization of ill-conditioned matrices" (abbreviated as "SCQR") proposes adding $\lambda \boldsymbol{I}$ ($\lambda=\epsilon \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F$) to $\boldsymbol{A}^{\top}\boldsymbol{A}$ to mitigate this issue. However, this is a double-edged sword: the larger $\epsilon$ is, the more stable SCQR becomes, but the less orthogonal the resulting $\boldsymbol{Q}$ will be, leading to worse final performance.
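To see this trade-off numerically, here is a small JAX sketch of shifted Cholesky QR without any fallback, reporting $\Vert\boldsymbol{Q}^{\top}\boldsymbol{Q}-\boldsymbol{I}\Vert_F$ for several shifts. With $\boldsymbol{R}^{\top}\boldsymbol{R}=\boldsymbol{A}^{\top}\boldsymbol{A}+\lambda\boldsymbol{I}$ and $\boldsymbol{Q}=\boldsymbol{A}\boldsymbol{R}^{-1}$, one gets $\boldsymbol{Q}^{\top}\boldsymbol{Q}=\boldsymbol{I}-\lambda(\boldsymbol{R}\boldsymbol{R}^{\top})^{-1}$, so the loss of orthogonality grows with $\lambda$. The helper names are hypothetical and the test matrix is random Gaussian, so only the trend is meaningful, not the exact numbers.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def shifted_cholesky_qr(A, eps):
    """Shifted Cholesky QR, no fallback: B = A^T A + eps ||A^T A||_F I = R^T R, Q = A R^{-1}."""
    B = A.T @ A
    B = B + eps * jnp.linalg.norm(B) * jnp.eye(B.shape[0], dtype=A.dtype)  # Frobenius norm
    L = jnp.linalg.cholesky(B)                   # B = L L^T, i.e. R = L^T
    return solve_triangular(L, A.T, lower=True).T  # solve R^T Q^T = A^T

def orthogonality_error(Q):
    """Frobenius distance of Q^T Q from the identity."""
    return jnp.linalg.norm(Q.T @ Q - jnp.eye(Q.shape[1], dtype=Q.dtype))

A = jax.random.normal(jax.random.PRNGKey(0), (64, 8))
for eps in (1e-9, 1e-6, 1e-3, 1e-1):
    print(eps, orthogonality_error(shifted_cholesky_qr(A, eps)))
```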
Furthermore, even with the shift, SCQR is not guaranteed to succeed, so we need an extra check that falls back to standard QR when it fails.
Reference Implementation #
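A simple reference implementation in JAX is as follows:
```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax import lax

def scqr(A, eps=1e-9):
    """Compute with Shifted Cholesky QR first; fall back to the default QR if it fails."""
    B, I = A.mT @ A, jnp.eye(A.shape[-1])
    # Shift: B <- A^T A + eps * ||A^T A||_F * I
    B += eps * jnp.linalg.matrix_norm(B, keepdims=True) * I
    # Cholesky factor: B = R^T R with R upper triangular
    R = jnp.linalg.cholesky(B, upper=True)
    # Solve R^T Q^T = A^T for Q
    Q = solve_triangular(R.mT, A.mT, lower=True).mT
    # JAX's Cholesky yields NaNs when B is not numerically positive definite
    return lax.cond(jnp.isfinite(Q).all(), lambda: Q, lambda: jnp.linalg.qr(A)[0])
```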
Simple tests show that when it succeeds, SCQR is about as fast as the Newton-Schulz version of msign. However, to preserve approximation quality, $\epsilon$ cannot be too large, or performance drops significantly. In empirical tests, guaranteeing performance usually requires $\epsilon=10^{-9}$, at which point SCQR still falls back to standard QR fairly often, so the overall speed still lags behind Newton-Schulz iteration.
Besides improving the QR algorithm itself, there are other acceleration techniques, such as keeping only the top $k$ eigenvectors. In that case $\boldsymbol{V}$ only needs to be initialized as $m \times k$ instead of $m \times m$, which also reduces computation. Further acceleration of the QR decomposition is left for readers to explore; I won't expand on it here.
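A sketch of the rank-$k$ variant in JAX, assuming the same streaming step as above but with $\boldsymbol{V}$ initialized as the first $k$ columns of the identity; `topk_streaming_step` is a hypothetical name. Note that with $k<m$ the resulting $\boldsymbol{U}\boldsymbol{V}^{\top}$ is no longer the full msign, only its component along the top-$k$ singular directions.

```python
import jax
import jax.numpy as jnp

def colnorm(X, eps=1e-12):
    """L2-normalize each column (reduce over axis 0)."""
    return X / (jnp.linalg.norm(X, axis=0, keepdims=True) + eps)

def topk_streaming_step(M, V):
    """Streaming step with V of shape (m, k): only the top-k triplets are tracked."""
    V, _ = jnp.linalg.qr(M.T @ (M @ V))   # reduced QR keeps V at shape (m, k)
    U = colnorm(M @ V)                    # (n, k)
    return U @ V.T, V                     # rank-k update of shape (n, m)

# Usage: track k = 2 directions of an 8 x 4 matrix.
M = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
V = jnp.eye(4, 2)                         # m x k initialization
for _ in range(5):
    update, V = topk_streaming_step(M, V)
```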
Other Details #
Additionally, there are some details to pay special attention to, as they are closely related to training stability and final performance.
First, following the convention $\boldsymbol{M}_t\in\mathbb{R}^{n\times m}$, we need $n\geq m$; otherwise, transpose the matrix first. If $n < m$, then $\boldsymbol{M}_t^{\top}\boldsymbol{M}_t$ has rank at most $n < m$, so $\boldsymbol{M}_t^{\top}\boldsymbol{M}_t \boldsymbol{V}_{t-1}$ is inherently rank-deficient. QR decomposition of a rank-deficient matrix is ill-posed, and SCQR in particular is prone to various pathological behaviors that hurt performance. Ensuring $n\geq m$ therefore improves numerical stability and performance while also reducing computation, several benefits at once.
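One way to handle the orientation is sketched below in JAX, under the assumption that the optimizer state `V` is sized for the transposed matrix when $n<m$: transpose wide matrices, run the step, and transpose the update back, using $\text{msign}(\boldsymbol{M}^{\top})^{\top}=\text{msign}(\boldsymbol{M})$. The name `oriented_step` is hypothetical.

```python
import jax
import jax.numpy as jnp

def colnorm(X, eps=1e-12):
    """L2-normalize each column (reduce over axis 0)."""
    return X / (jnp.linalg.norm(X, axis=0, keepdims=True) + eps)

def oriented_step(M, V):
    """Streaming msign step that transposes wide matrices so the QR input has full rank.

    V must have shape (min(n, m), min(n, m)).
    """
    wide = M.shape[0] < M.shape[1]             # static shape check
    A = M.T if wide else M                     # A is tall or square
    V, _ = jnp.linalg.qr(A.T @ (A @ V))
    update = colnorm(A @ V) @ V.T              # approximates msign(A)
    return (update.T if wide else update), V   # msign(M^T)^T = msign(M)

# Usage: a wide 4 x 8 matrix keeps a 4 x 4 state.
M = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
update, V = oriented_step(M, jnp.eye(4))
print(update.shape)
```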
Second, empirical tests found that adding an extra ColNorm to the $\text{QR}$ step is very helpful for performance:
\begin{equation}
\boldsymbol{V}_t = \QR(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}) \qquad\to\qquad \boldsymbol{V}_t = \QR(\boldsymbol{M}_t^{\top}\ColNorm(\boldsymbol{M}_t\boldsymbol{V}_{t-1}))
\end{equation}
This is merely equivalent to changing $\tilde{\boldsymbol{v}}_k^{(t)} = \boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{v}_k^{(t-1)}$ in Eq. $\eqref{eq:vk-pi}$ to $\tilde{\boldsymbol{v}}_k^{(t)} = \boldsymbol{M}^{\top}(\boldsymbol{M}\boldsymbol{v}_k^{(t-1)} / \Vert\boldsymbol{M}\boldsymbol{v}_k^{(t-1)}\Vert_2)$, which does not change the convergence of the power iteration itself. However, experiments show that this extra ColNorm step significantly helps training performance, especially under SCQR, narrowing the gap between it and standard QR.
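The following JAX sketch contrasts the two QR steps; the function names are hypothetical. Because $\text{ColNorm}$ only rescales each column of $\boldsymbol{M}\boldsymbol{V}_{t-1}$ by a positive factor, the QR input changes from $\boldsymbol{X}$ to $\boldsymbol{X}\boldsymbol{D}$ for a positive diagonal $\boldsymbol{D}$, and since $\boldsymbol{X}\boldsymbol{D}=\boldsymbol{Q}(\boldsymbol{R}\boldsymbol{D})$ with $\boldsymbol{R}\boldsymbol{D}$ still upper triangular, both variants return the same $\boldsymbol{Q}$ up to column signs. The empirical benefit reported in the text therefore lies in the floating-point behavior of the QR routine, which a small, well-conditioned float32 example like this one is not expected to show.

```python
import jax
import jax.numpy as jnp

def colnorm(X, eps=1e-12):
    """L2-normalize each column (reduce over axis 0)."""
    return X / (jnp.linalg.norm(X, axis=0, keepdims=True) + eps)

def qr_step_plain(M, V):
    """V_t = QR(M^T M V_{t-1})."""
    Q, _ = jnp.linalg.qr(M.T @ (M @ V))
    return Q

def qr_step_colnorm(M, V):
    """V_t = QR(M^T ColNorm(M V_{t-1})): a positive per-column rescaling of the QR input."""
    Q, _ = jnp.linalg.qr(M.T @ colnorm(M @ V))
    return Q

# Usage: one step of each from the same random orthogonal V.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
M = jax.random.normal(k1, (16, 6))
V, _ = jnp.linalg.qr(jax.random.normal(k2, (6, 6)))
print(jnp.abs(qr_step_plain(M, V).T @ qr_step_colnorm(M, V)))  # close to the identity
```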
In exact arithmetic, this operation provably leaves the power iteration unchanged and only affects the numerics of the QR step. Experimentally, it actually makes SCQR fall back to standard QR somewhat more often (but not by much, so it is not significantly slower), and it also helps standard QR slightly. It thus seems to improve the properties of the matrix being decomposed, yielding better QR quality.
Related Work #
As mentioned at the beginning, the concept of streaming power iteration has appeared several times in various optimizer works, such as "4-bit Shampoo for Memory-Efficient Network Training", "SOAP: Improving and Stabilizing Shampoo using Adam", "COSMOS: A Hybrid Adaptive Optimizer for Memory-Efficient Training of LLMs", and "Dion: Distributed Orthonormalized Updates".
In fact, taking an algorithm that needs many iterations to converge and, exploiting the long-running nature of model training, turning it into a streaming version with one iteration per step to amortize the cost is not a hard idea to come up with. We tried this ourselves earlier in "Steepest Descent on Manifolds: 5. Dual Gradient Descent". So the existence of so much related work is not surprising.
This article mainly draws on the paper "ARO: A New Lens On Matrix Optimization For Large Models", released last month, which covers most of the content of this article and generalizes it. Its generalization is also worth a look. Note that the current Muon update can be written as:
\begin{equation}
\ColNorm(\boldsymbol{M}_t\boldsymbol{V}_t)\boldsymbol{V}_t^{\top}
\end{equation}
where $\ColNorm(\boldsymbol{M}_t)$ alone can be viewed as a base optimizer that merely column-normalizes the momentum, which is not competitive by itself. We then find a new orthonormal basis $\boldsymbol{V}_t$ for it, apply the base optimizer in the new basis, and transform back; the result is exactly Muon, which is significantly stronger than plain $\ColNorm$. The natural next question is whether $\ColNorm$ can be replaced with other base optimizers. To this end, ARO proposes a general optimizer framework (Rotation Steepest Descent):
\begin{equation}
\begin{aligned}
\boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt]
\boldsymbol{R}_t =&\, \QR(\boldsymbol{M}_t^{\top}f(\boldsymbol{M}_t\boldsymbol{R}_{t-1})) \\[5pt]
\boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t (f(\boldsymbol{M}_t\boldsymbol{R}_t)\boldsymbol{R}_t^{\top} + \lambda \boldsymbol{W}_{t-1}) \\
\end{aligned}
\end{equation}
where $f$ represents any matrix function, and the original notation $\boldsymbol{V}$ is replaced by $\boldsymbol{R}$ (Rotation). We will discuss more about Rotation Steepest Descent when the opportunity arises.
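A minimal JAX sketch of this template, with the base map `f` passed in as a function, is shown below; the function and parameter names are hypothetical. Plugging in a column-normalizing `f` recovers the streaming Muon update with the ColNorm-augmented QR step described earlier, while other choices of `f` give other members of the family.

```python
import jax
import jax.numpy as jnp

def colnorm(X, eps=1e-12):
    """L2-normalize each column (reduce over axis 0)."""
    return X / (jnp.linalg.norm(X, axis=0, keepdims=True) + eps)

def rotation_sd_step(W, M, R, G, f, lr, beta=0.95, wd=0.0):
    """One step of the Rotation Steepest Descent template for a given base map f."""
    M = beta * M + G                          # M_t = beta M_{t-1} + G_t
    R, _ = jnp.linalg.qr(M.T @ f(M @ R))      # R_t = QR(M_t^T f(M_t R_{t-1}))
    W = W - lr * (f(M @ R) @ R.T + wd * W)    # base map in the rotated basis, rotated back
    return W, M, R

# Usage: f = colnorm gives streaming Muon; R is carried across steps from the identity.
G = jax.random.normal(jax.random.PRNGKey(0), (8, 4))   # stand-in gradient
W, M, R = rotation_sd_step(jnp.zeros((8, 4)), jnp.zeros((8, 4)), jnp.eye(4), G, colnorm, lr=0.02)
```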
Summary #
This article mainly introduced the idea of implementing Muon by calculating SVD via Streaming Power Iteration. It only requires one QR decomposition per step. Compared to the standard Newton-Schulz iteration implementation, this approach offers more flexible extensibility.