By 苏剑林 | March 26, 2026
In the first article of this series, "Muon Implementation Based on Streaming Power Iteration: 1. Initial Contact", I singled out streaming power iteration as a new way to implement Muon. Because the new scheme directly approximates the SVD, it leaves more room for extensions than the standard implementation based on Newton-Schulz iteration, which makes it worth studying further.
From a computational perspective, the main change in the new scheme is replacing Newton-Schulz iteration with \(\newcommand{QR}{\mathop{\text{QR}}}\QR\) decomposition, which makes it slower. The previous post discussed some basic acceleration methods, but the result still lagged behind the standard implementation. In this article, we continue to work on accelerating \(\QR\), aiming to close the gap as much as possible.
Streaming Iteration
We follow all the concepts and notation of the first article; readers unfamiliar with them should consult it first. Recall that the Muon update is:
\begin{equation}
\newcommand{msign}{\mathop{\text{msign}}}
\begin{aligned}
\boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt]
\boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t [\msign(\boldsymbol{M}_t) + \lambda \boldsymbol{W}_{t-1}] \\
\end{aligned}
\end{equation}
where the standard implementation of \(\msign\) is Newton-Schulz iteration, which is also the most expensive calculation in the Muon optimizer. In contrast, the update formulas for the Streaming Power Iteration scheme are:
\begin{equation}
\newcommand{ColNorm}{\mathop{\text{ColNorm}}}
\begin{aligned}
\boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt]
\boldsymbol{V}_t =&\, \QR(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}) \\[5pt]
\boldsymbol{U}_t =&\, \ColNorm(\boldsymbol{M}_t\boldsymbol{V}_t) \\[5pt]
\boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t (\boldsymbol{U}_t\boldsymbol{V}_t^{\top} + \lambda \boldsymbol{W}_{t-1}) \\
\end{aligned}
\end{equation}
If we repeatedly execute $\boldsymbol{V}_t = \QR(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1})$ with $\boldsymbol{M}_t$ held fixed, we obtain standard power iteration: the result converges to the right singular vectors of $\boldsymbol{M}_t$, which gives the SVD of $\boldsymbol{M}_t$ and hence $\msign(\boldsymbol{M}_t)$. Running power iteration to convergence at every training step is too costly, however, so we instead cache the previous step's $\boldsymbol{V}_{t-1}$ as a warm start and perform only a single $\QR$ iteration per step as an approximation; this is what "streaming" refers to.
The most expensive operation is now the $\QR$ decomposition. The most naive implementation is to call the framework's built-in QR function, which is typically based on Householder transformations: numerically stable, but quite slow.
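As a minimal sketch of the update above, the following Jax code performs one streaming step with the built-in QR and, with $\boldsymbol{M}$ held fixed, repeats it to show convergence to $\msign$ computed via SVD. The helper names (col_norm, streaming_step, msign_svd), the initial $\boldsymbol{V}=\boldsymbol{I}$, and the synthetic test matrix with well-separated singular values are illustrative assumptions, not the article's code.

```python
import jax
import jax.numpy as jnp

def col_norm(X, eps=1e-12):
    """ColNorm: scale each column of X to unit L2 norm."""
    return X / (jnp.linalg.norm(X, axis=0, keepdims=True) + eps)

def streaming_step(M, V_prev):
    """One streaming power-iteration step using the built-in (Householder) QR."""
    V = jnp.linalg.qr(M.T @ (M @ V_prev))[0]  # V_t = QR(M^T M V_{t-1})
    U = col_norm(M @ V)                       # U_t = ColNorm(M V_t)
    return U @ V.T, V                         # approximate msign(M), cached V_t

def msign_svd(M):
    """Reference msign(M) = U V^T from a thin SVD."""
    U, _, Vt = jnp.linalg.svd(M, full_matrices=False)
    return U @ Vt

# Synthetic 8x4 matrix with singular values 8, 4, 2, 1 (well separated).
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U0 = jnp.linalg.qr(jax.random.normal(k1, (8, 4)))[0]
V0 = jnp.linalg.qr(jax.random.normal(k2, (4, 4)))[0]
M = U0 @ jnp.diag(jnp.array([8.0, 4.0, 2.0, 1.0])) @ V0.T

V = jnp.eye(4)  # initial V (an assumption)
for _ in range(30):
    O, V = streaming_step(M, V)
err = jnp.abs(O - msign_svd(M)).max()
print(float(err))
```

In real training $\boldsymbol{M}_t$ changes every step, so one step per iteration only tracks the singular vectors approximately; fixing $\boldsymbol{M}$ here just isolates the power-iteration behavior.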
Preliminary Speedup
To speed this up, the previous post introduced Cholesky QR, which splits the QR decomposition of a matrix $\boldsymbol{A}$ into two steps: 1. perform a Cholesky decomposition of $\boldsymbol{A}^{\top}\boldsymbol{A}$ to obtain an upper triangular matrix $\boldsymbol{R}$; 2. solve the triangular system $\boldsymbol{Q}\boldsymbol{R}=\boldsymbol{A}$ to obtain the orthogonal matrix $\boldsymbol{Q}$. In theory both steps are very efficient, but in practice the Cholesky step breaks down when the condition number is too large (in Jax the result then contains NaNs), and forming $\boldsymbol{A}^{\top}\boldsymbol{A}$ already squares the condition number of $\boldsymbol{A}$. To address this, we introduced the Shift technique, which adds $\lambda \boldsymbol{I}$ ($\lambda=\epsilon \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F$) to $\boldsymbol{A}^{\top}\boldsymbol{A}$ to reduce its condition number.
We refer to the combination of the two as "SCQR" (Shifted Cholesky QR). A reference implementation in Jax is as follows:
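```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax import lax

def scqr(A, eps=1e-9):
    """Try Shifted Cholesky QR first; fall back to the default QR if it fails.
    """
    # Gram matrix A^T A plus the shift lambda * I, with lambda = eps * ||A^T A||_F
    B, I = A.mT @ A, jnp.eye(A.shape[-1])
    B += eps * jnp.linalg.matrix_norm(B, keepdims=True) * I
    # Upper triangular R with R^T R = A^T A + lambda * I
    R = jnp.linalg.cholesky(B, upper=True)
    # Solve Q R = A, i.e. R^T Q^T = A^T
    Q = solve_triangular(R.mT, A.mT, lower=True).mT
    # A failed Cholesky yields NaNs; in that case use the built-in (Householder) QR
    return lax.cond(jnp.isfinite(Q).all(), lambda: Q, lambda: jnp.linalg.qr(A)[0])
```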
Note that the smaller $\lambda$ is, the more likely SCQR is to fail, while the larger $\lambda$ is, the further the result deviates from orthogonality, which hurts training results. $\lambda$ therefore has to be "just right", and even then this scheme still falls back to standard QR fairly often. In addition, the previous post mentioned that adding $\ColNorm$ to the power iteration (i.e., changing it to $\boldsymbol{V}_t = \QR(\boldsymbol{M}_t^{\top}\ColNorm(\boldsymbol{M}_t\boldsymbol{V}_{t-1}))$) stabilizes training, and this effect is even more pronounced with SCQR.
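The orthogonality side of this trade-off can be made quantitative. If $\boldsymbol{R}^{\top}\boldsymbol{R} = \boldsymbol{A}^{\top}\boldsymbol{A} + \lambda\boldsymbol{I}$ and $\boldsymbol{Q} = \boldsymbol{A}\boldsymbol{R}^{-1}$, then $\boldsymbol{Q}^{\top}\boldsymbol{Q} = \boldsymbol{R}^{-\top}(\boldsymbol{R}^{\top}\boldsymbol{R} - \lambda\boldsymbol{I})\boldsymbol{R}^{-1} = \boldsymbol{I} - \lambda(\boldsymbol{R}\boldsymbol{R}^{\top})^{-1}$. Since $\boldsymbol{R}\boldsymbol{R}^{\top}$ has the same eigenvalues $\sigma_i^2 + \lambda$ as $\boldsymbol{R}^{\top}\boldsymbol{R}$ (with $\sigma_i$ the singular values of $\boldsymbol{A}$), the orthogonality error is $\Vert\boldsymbol{Q}^{\top}\boldsymbol{Q} - \boldsymbol{I}\Vert_2 = \lambda/(\sigma_{\min}^2 + \lambda)$. The sketch below checks this numerically; the helper names and the synthetic test matrix are assumptions for illustration, and no fallback is included.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_default_matmul_precision", "highest")  # full FP32 matmuls

def shifted_cholesky_qr(A, eps):
    """Shifted Cholesky QR without fallback; returns Q and the shift lambda."""
    B = A.T @ A
    lam = eps * jnp.linalg.norm(B)                          # lambda = eps * ||A^T A||_F
    L = jnp.linalg.cholesky(B + lam * jnp.eye(B.shape[0]))  # B + lambda I = L L^T, R = L^T
    Q = solve_triangular(L, A.T, lower=True).T              # Q = A R^{-1}
    return Q, lam

def orth_error(Q):
    """Spectral norm of Q^T Q - I."""
    return jnp.linalg.norm(Q.T @ Q - jnp.eye(Q.shape[1]), ord=2)

# Illustrative 64x16 matrix with singular values from 1 down to 0.1.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U0 = jnp.linalg.qr(jax.random.normal(k1, (64, 16)))[0]
V0 = jnp.linalg.qr(jax.random.normal(k2, (16, 16)))[0]
s = jnp.logspace(0, -1, 16)
A = U0 @ jnp.diag(s) @ V0.T

results = {}
for eps in (1e-2, 1e-4):
    Q, lam = shifted_cholesky_qr(A, eps)
    predicted = lam / (s.min() ** 2 + lam)  # lambda / (sigma_min^2 + lambda)
    results[eps] = (float(orth_error(Q)), float(predicted))
    print(f"eps={eps:g}: measured {results[eps][0]:.3e}, predicted {results[eps][1]:.3e}")
```

A larger $\epsilon$ makes the Cholesky step safer but leaves $\boldsymbol{Q}$ visibly non-orthogonal, which is exactly the tension described above.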
Enabling Full Precision
The above essentially summarizes the first article: it showed that the full pipeline works, and SCQR brought some speedup over calling the framework's built-in QR decomposition directly. However, it was still significantly slower than computing $\msign$ with Newton-Schulz iteration, so further speedups are needed.
This section introduces the first acceleration trick: enabling "true" FP32 matrix multiplication. Note first that the new steps in streaming power iteration are computed in FP32. However, since the A100 introduced the TF32 format, some frameworks (such as Jax, which I use for small experiments, or certain versions of Torch) by default convert FP32 operands to TF32 inside matrix multiplications for speed. True FP32 multiplication has to be enabled manually.
Readers might wonder: shouldn't higher multiplication precision slow things down? This is indeed counter-intuitive, but not hard to understand. Lower-precision multiplication introduces larger rounding errors, which effectively worsens the condition number of the matrix handed to the Cholesky step. This raises the probability that SCQR fails and falls back to standard QR, which is by far the most time-consuming part. Conversely, higher precision raises the success rate of SCQR, so the total time actually goes down.
From what I could find, Jax has always defaulted to TF32 for FP32 matrix multiplication (on GPUs that support it), so full precision must be enabled manually via jax.config.update('jax_default_matmul_precision', 'highest'). Torch is more complicated: versions 1.7 to 1.11 defaulted to TF32 multiplication, but starting from 1.12 the default reverted to FP32. Since Torch is now at version 2.11, most users probably do not need to enable it manually.
Double Orthogonalization
The second acceleration trick I thought of is changing the power iteration step to:
\begin{equation}
\boldsymbol{V}_t = \QR(\boldsymbol{M}_t^{\top}\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1}))
\end{equation}
This is also counter-intuitive: adding an extra $\QR$ makes the whole step faster. The reason is similar to the previous section: it reduces the condition number of the matrices being decomposed and thus raises the success rate of SCQR. Since SCQR itself is very fast, running it twice adds little time, while the sharp drop in fallbacks to standard QR yields a significant net speedup.
Understanding this trick involves two points: first, that the extra $\QR$ does not change the power iteration in theory; second, that it does reduce the condition number. The first point is easy to see: if $\boldsymbol{A}=\boldsymbol{Q}\boldsymbol{R}$, then $\boldsymbol{Q}=\boldsymbol{A}\boldsymbol{R}^{-1}$, where $\boldsymbol{R}^{-1}$ is also upper triangular. That is, taking the $\QR$ of a matrix amounts to right-multiplying it by an upper triangular matrix, so:
\begin{equation}
\boldsymbol{M}_t^{\top}\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1}) = \boldsymbol{M}_t^{\top}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}\times \text{some upper triangular matrix}) = \boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}\times \text{some upper triangular matrix}
\end{equation}
Since the QR decomposition is unique once $\boldsymbol{R}$ is required to have a positive diagonal, right-multiplying by an invertible upper triangular matrix does not change the result of \(\QR\) (at most it flips the signs of some columns, which is irrelevant here), so this is theoretically equivalent to $\QR(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1})$.
As for the condition number, it is the ratio of the largest to the smallest singular value. With a single $\QR$, the matrix passed to the Cholesky decomposition is $\boldsymbol{V}_{t-1}^{\top}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t)^2\boldsymbol{V}_{t-1}$. Since orthogonal transformations do not change singular values, and hence do not change the condition number, the condition number of this matrix can reach the 4th power of that of $\boldsymbol{M}_t$! With two $\QR$s, the first Cholesky decomposition acts on $\boldsymbol{V}_{t-1}^{\top}\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}$ and the second on $\boldsymbol{Q}_t^{\top}(\boldsymbol{M}_t\boldsymbol{M}_t^{\top})\boldsymbol{Q}_t$, where $\boldsymbol{Q}_t$ is the orthogonal matrix from the first $\QR$. Both have a condition number of only the square of that of $\boldsymbol{M}_t$, which is significantly lower.
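The equivalence can also be checked numerically. A minimal sketch, assuming a random Gaussian $\boldsymbol{M}$ and a random orthogonal $\boldsymbol{V}_{t-1}$ (both illustrative) and using the built-in QR for both variants: if the two results agree up to column signs, the elementwise absolute value of $\boldsymbol{V}_{\text{single}}^{\top}\boldsymbol{V}_{\text{double}}$ is the identity.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "highest")

def q_factor(A):
    """Q factor of the built-in (reduced) QR decomposition."""
    return jnp.linalg.qr(A)[0]

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
M = jax.random.normal(k1, (32, 8))                # illustrative momentum matrix
V_prev = q_factor(jax.random.normal(k2, (8, 8)))  # illustrative cached V

V_single = q_factor(M.T @ (M @ V_prev))           # QR(M^T M V)
V_double = q_factor(M.T @ q_factor(M @ V_prev))   # QR(M^T QR(M V))

# Equal up to column signs  <=>  |V_single^T V_double| is the identity.
overlap = jnp.abs(V_single.T @ V_double)
print(float(jnp.abs(overlap - jnp.eye(8)).max()))
```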
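To see these condition numbers concretely, the sketch below builds an $\boldsymbol{M}$ with prescribed singular values from 1 down to 0.1 (condition number 10; an illustrative choice) and computes the condition number of each matrix handed to a Cholesky decomposition. With one $\QR$ it should be about $10^4$; with two, both Cholesky inputs should be about $10^2$. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "highest")

def orthonormal(key, shape):
    """Random matrix with orthonormal columns."""
    return jnp.linalg.qr(jax.random.normal(key, shape))[0]

def gram_cond(A):
    """Condition number of A^T A, the matrix handed to the Cholesky step."""
    return float(jnp.linalg.cond(A.T @ A))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
n = 8
s = jnp.logspace(0, -1, n)  # singular values 1 ... 0.1, so cond(M) = 10
M = orthonormal(k1, (32, n)) @ jnp.diag(s) @ orthonormal(k2, (n, n)).T
V_prev = orthonormal(k3, (n, n))

# One QR: Cholesky acts on V^T (M^T M)^2 V
cond_single = gram_cond(M.T @ M @ V_prev)

# Two QRs: Cholesky acts on V^T M^T M V, then on Q^T M M^T Q
cond_first = gram_cond(M @ V_prev)
Q = jnp.linalg.qr(M @ V_prev)[0]
cond_second = gram_cond(M.T @ Q)

print(cond_single, cond_first, cond_second)  # roughly 1e4, 1e2, 1e2
```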
Shift Invariance
The third acceleration trick came out of a discussion with @YouJiacheng and exploits the shift invariance of eigenvectors. The power iteration $\boldsymbol{V}_t = \QR(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1})$ can also be understood as computing the eigenvectors of the positive semidefinite matrix $\boldsymbol{M}_t^{\top}\boldsymbol{M}_t$, and adding a multiple of the identity matrix to a matrix does not change its eigenvectors: it only shifts every eigenvalue by the same amount.
In other words, $\boldsymbol{M}_t^{\top}\boldsymbol{M}_t$ and $\boldsymbol{M}_t^{\top}\boldsymbol{M}_t + \lambda \boldsymbol{I}$ have the same eigenvectors, so we can change the power iteration to:
\begin{equation}
\boldsymbol{V}_t = \QR((\boldsymbol{M}_t^{\top}\boldsymbol{M}_t + \lambda \boldsymbol{I})\boldsymbol{V}_{t-1}) = \QR(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1} + \lambda \boldsymbol{V}_{t-1})
\end{equation}
without changing the result the power iteration converges to. What is the benefit of adding $\lambda \boldsymbol{I}$ to $\boldsymbol{M}_t^{\top}\boldsymbol{M}_t$? Again, it reduces the condition number: writing $\sigma$ for the singular values (equivalently, the eigenvalues) of $\boldsymbol{M}_t^{\top}\boldsymbol{M}_t$, we have $(\sigma_{\max} + \lambda)/(\sigma_{\min} + \lambda) < \sigma_{\max}/\sigma_{\min}$ for any $\lambda > 0$, so the success rate of Cholesky QR also improves. Note that we are talking about plain Cholesky QR here rather than SCQR: an appropriately chosen external $\lambda$ already keeps the condition number under control, which makes the internal Shift unnecessary, and the output is then orthogonal up to rounding error, which is also a nice property.
But don't celebrate too early. A larger $\lambda$ does make Cholesky QR more likely to succeed, but it also slows down the convergence of the power iteration! The convergence speed of power iteration depends on the ratios of adjacent singular values: with singular values sorted in descending order, the smaller $\sigma_{i+1}/\sigma_i$ is, the faster the convergence. Since $(\sigma_{i+1} + \lambda)/(\sigma_i + \lambda) > \sigma_{i+1}/\sigma_i$, the larger $\lambda$ is, the slower the power iteration converges, and the worse the final result.
Therefore, $\lambda$ has to be tuned carefully to balance the success rate of Cholesky QR against the convergence speed of the power iteration. In my tests, $\lambda = \epsilon\Vert\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\Vert_F$ with $\epsilon=10^{-4}$ worked relatively well. Another approach is to use a larger $\lambda$ to guarantee that Cholesky QR succeeds, and then perform two iterations per step to compensate for the slower convergence, i.e.:
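A minimal sketch of one shifted power-iteration step with plain (unshifted) Cholesky QR, using $\epsilon=10^{-4}$, the value found to work well in the tests reported in this section. The helper names and the synthetic test matrix are assumptions for illustration; repeating the step with a fixed $\boldsymbol{M}$ shows that the shift leaves the converged result, and hence $\msign$, unchanged.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_default_matmul_precision", "highest")

def cholesky_qr(A):
    """Plain Cholesky QR (no internal shift): A^T A = L L^T, Q = A L^{-T}."""
    L = jnp.linalg.cholesky(A.T @ A)
    return solve_triangular(L, A.T, lower=True).T

def shifted_power_step(M, V_prev, eps=1e-4):
    """V_t = QR(M^T M V_{t-1} + lam V_{t-1}) with lam = eps * ||M^T M||_F."""
    MtM = M.T @ M
    lam = eps * jnp.linalg.norm(MtM)
    return cholesky_qr(MtM @ V_prev + lam * V_prev)

# Illustrative 16x4 matrix with singular values 2, 1.5, 1.2, 1.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U0 = jnp.linalg.qr(jax.random.normal(k1, (16, 4)))[0]
V0 = jnp.linalg.qr(jax.random.normal(k2, (4, 4)))[0]
M = U0 @ jnp.diag(jnp.array([2.0, 1.5, 1.2, 1.0])) @ V0.T

V = jnp.eye(4)
for _ in range(60):
    V = shifted_power_step(M, V)
MV = M @ V
O = (MV / jnp.linalg.norm(MV, axis=0, keepdims=True)) @ V.T  # ColNorm(M V) V^T
```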
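A toy calculation with illustrative numbers (not measurements) shows both sides of the trade-off. Suppose the singular values of $\boldsymbol{M}_t^{\top}\boldsymbol{M}_t$ range from $\sigma_{\max}=1$ down to $\sigma_{\min}=0.01$, with the two smallest being $0.02$ and $0.01$, and take $\lambda=0.01$:
\begin{equation}
\frac{1}{0.01} = 100 \;\longrightarrow\; \frac{1+0.01}{0.01+0.01} = 50.5,\qquad \frac{0.01}{0.02} = 0.5 \;\longrightarrow\; \frac{0.01+0.01}{0.02+0.01} \approx 0.667
\end{equation}
The shift roughly halves the condition number, but the error along the slowest direction now shrinks by a factor of about $0.667$ per iteration instead of $0.5$: after 10 iterations that is about $0.017$ versus $0.001$.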
\begin{equation}
\boldsymbol{V}_t = \QR(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\tilde{\boldsymbol{V}}_t + \lambda \tilde{\boldsymbol{V}}_t),\qquad \tilde{\boldsymbol{V}}_t = \QR(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1} + \lambda \boldsymbol{V}_{t-1})
\end{equation}
This takes care of both the success rate of Cholesky QR and the convergence of the power iteration, at the cost of two $\QR$s per step.
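The two-iteration variant can be sketched the same way. Here $\epsilon=10^{-2}$ is an assumed "larger" value chosen purely for illustration (the text does not specify one), and since $\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V} + \lambda\boldsymbol{V} = (\boldsymbol{M}_t^{\top}\boldsymbol{M}_t + \lambda\boldsymbol{I})\boldsymbol{V}$, the shifted matrix is formed once and reused for both QRs. Helper names and the test matrix are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_default_matmul_precision", "highest")

def cholesky_qr(A):
    """Plain Cholesky QR: A^T A = L L^T, Q = A L^{-T}."""
    L = jnp.linalg.cholesky(A.T @ A)
    return solve_triangular(L, A.T, lower=True).T

def shifted_power_step_x2(M, V_prev, eps=1e-2):
    """Two shifted power iterations per step (eps=1e-2 is an assumed 'larger' value)."""
    MtM = M.T @ M
    shifted = MtM + eps * jnp.linalg.norm(MtM) * jnp.eye(MtM.shape[0])  # M^T M + lam I
    V_tilde = cholesky_qr(shifted @ V_prev)  # first QR
    return cholesky_qr(shifted @ V_tilde)    # second QR

# Illustrative 16x4 matrix with singular values 2, 1.5, 1.2, 1.
k1, k2 = jax.random.split(jax.random.PRNGKey(1))
U0 = jnp.linalg.qr(jax.random.normal(k1, (16, 4)))[0]
V0 = jnp.linalg.qr(jax.random.normal(k2, (4, 4)))[0]
M = U0 @ jnp.diag(jnp.array([2.0, 1.5, 1.2, 1.0])) @ V0.T

V = jnp.eye(4)
for _ in range(40):
    V = shifted_power_step_x2(M, V)
MV = M @ V
O = (MV / jnp.linalg.norm(MV, axis=0, keepdims=True)) @ V.T  # ColNorm(M V) V^T
```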
Multi-step Correction
The fourth acceleration trick is called "SCQR2," which is a general correction technique for SCQR. Let's review the two steps of SCQR (given a matrix $\boldsymbol{A}$ to be decomposed):
\begin{align}
1) &\, \boldsymbol{R}^{\top}\boldsymbol{R}= \boldsymbol{A}^{\top}\boldsymbol{A} + \lambda \boldsymbol{I} \quad (\text{Cholesky decomposition of } \boldsymbol{A}^{\top}\boldsymbol{A}+\lambda\boldsymbol{I}) \\[5pt]
2) &\, \boldsymbol{Q} = \boldsymbol{A}\boldsymbol{R}^{-1} \quad (\text{Solve the triangular linear equation } \boldsymbol{Q}\boldsymbol{R}=\boldsymbol{A})
\end{align}
The problem with SCQR is that the larger $\lambda$ is, the easier the Cholesky decomposition becomes, but the less orthogonal $\boldsymbol{Q} = \boldsymbol{A}\boldsymbol{R}^{-1}$ becomes. The idea of SCQR2 is to first run SCQR with a large $\lambda$: the result is not orthogonal, but it is much closer to orthogonal than the original $\boldsymbol{A}$, i.e., its condition number is smaller. We can then run SCQR again on this result with a smaller $\lambda$ to restore orthogonality. A rough implementation is as follows:
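```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax import lax

def scqr(A, eps=1e-9):
    """Shifted Cholesky QR (no fallback).
    """
    B, I = A.mT @ A, jnp.eye(A.shape[-1])
    B += eps * jnp.linalg.matrix_norm(B, keepdims=True) * I
    R = jnp.linalg.cholesky(B, upper=True)
    return solve_triangular(R.mT, A.mT, lower=True).mT

def scqr2(A, eps1=1e-4, eps2=1e-8):
    """Run SCQR twice; fall back to the default QR if it fails.
    """
    # Large shift first (Cholesky succeeds, result roughly orthogonal),
    # then a small shift to restore orthogonality
    Q = scqr(scqr(A, eps1), eps2)
    return lax.cond(jnp.isfinite(Q).all(), lambda: Q, lambda: jnp.linalg.qr(A)[0])
```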
Why does this two-pass correction work? Suppose the first SCQR yields $\boldsymbol{Q}_1 = \boldsymbol{A}\boldsymbol{R}_1^{-1}$. Although $\boldsymbol{Q}_1$ is not orthogonal, it retains the form "$\boldsymbol{A}\times \text{upper triangular matrix}$". As noted earlier, right-multiplying by an upper triangular matrix does not change the result of QR, so a second SCQR on $\boldsymbol{Q}_1$, giving $\boldsymbol{Q}_2 = \boldsymbol{Q}_1\boldsymbol{R}_2^{-1} = \boldsymbol{A}(\boldsymbol{R}_2\boldsymbol{R}_1)^{-1}$, is still a QR decomposition of $\boldsymbol{A}$. In principle, even more correction passes could be chained.
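The composition can be checked numerically. The sketch below uses a hypothetical helper scqr_with_r that also returns the triangular factor, applies it twice with the shifts used in scqr2, and verifies that $\boldsymbol{Q}_2$ is orthogonal and that $\boldsymbol{A} = \boldsymbol{Q}_2(\boldsymbol{R}_2\boldsymbol{R}_1)$. The ill-conditioned test matrix (singular values from 1 down to $10^{-3}$) is an illustrative assumption.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_default_matmul_precision", "highest")

def scqr_with_r(A, eps):
    """Shifted Cholesky QR returning Q = A R^{-1} and the upper triangular R."""
    B = A.T @ A
    B = B + eps * jnp.linalg.norm(B) * jnp.eye(B.shape[0])  # A^T A + lam I
    R = jnp.triu(jnp.linalg.cholesky(B).T)                  # upper triangular, positive diagonal
    Q = solve_triangular(R.T, A.T, lower=True).T            # Q = A R^{-1}
    return Q, R

# Illustrative ill-conditioned 64x16 matrix: singular values from 1 down to 1e-3.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U0 = jnp.linalg.qr(jax.random.normal(k1, (64, 16)))[0]
V0 = jnp.linalg.qr(jax.random.normal(k2, (16, 16)))[0]
A = U0 @ jnp.diag(jnp.logspace(0, -3, 16)) @ V0.T

Q1, R1 = scqr_with_r(A, 1e-4)   # large shift: Cholesky is safe, Q1 only roughly orthogonal
Q2, R2 = scqr_with_r(Q1, 1e-8)  # small shift on the better-conditioned Q1
R = R2 @ R1                     # Q2 = A (R2 R1)^{-1}, i.e. A = Q2 R

print(float(jnp.linalg.cond(A)), float(jnp.linalg.cond(Q1)))
print(float(jnp.abs(Q2.T @ Q2 - jnp.eye(16)).max()), float(jnp.abs(A - Q2 @ R).max()))
```

The first pass cuts the condition number from about $10^3$ to roughly 10, which is what makes the nearly unshifted second pass safe.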
Method Summary
We have discussed four acceleration tricks; here is a brief summary of their characteristics.
The first trick, using full-precision FP32 matrix multiplication, is universally applicable: Jax requires enabling it manually, while newer versions of Torch use it by default. The second, third, and fourth tricks are mutually exclusive and cannot be stacked. Intuitively, the second trick has the highest ceiling: the third and fourth tricks take $\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}$ as input, whose condition number has already been amplified, and try to remedy it afterwards, whereas the second trick changes the input to $\boldsymbol{M}_t^{\top}\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1})$, reducing the condition number at the source.
Interestingly, tricks two, three, and four all seem to point toward two $\QR$s per step. Except for trick three, which can get by with a single $\QR$ if $\lambda$ is tuned carefully, all of them require at least two, so this seems to be the most reliable choice. In terms of speed, trick three is the fastest if it can be tuned to use only one $\QR$; otherwise it is about as fast as trick two. Trick four is somewhat unstable: applying SCQR2 to $\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}$ is fast but gives poor results, while switching to $\boldsymbol{M}_t^{\top}\ColNorm(\boldsymbol{M}_t\boldsymbol{V}_{t-1})$ ensures good results but is slower.
I recommend combining tricks one and two, which provides guarantees for both effectiveness and efficiency. Benchmarked in isolation, it runs at about half the speed of the Newton-Schulz iteration for $\msign$. Some readers might think, "all that effort for only half the speed?" This is actually quite good, considering that we compute in FP32 and perform two \(\QR\)s. Moreover, the $\msign$ step accounts for only about 1% of end-to-end training time; doubling it adds only another 1% or so to the total, which is acceptable.
Additionally, the cost of Newton-Schulz iteration depends on the number of iteration steps. If Polar Express coefficients are used with more steps to improve accuracy, the speed gap between it and our method narrows further. In short, streaming power iteration is indeed slower, but it yields richer and more accurate information (an approximate SVD), which opens up more possibilities.
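Putting tricks one and two together, a minimal sketch of the recommended step might look as follows. It reuses the SCQR-with-fallback idea from earlier; the function names, the default learning rate, momentum and weight-decay values (wd plays the role of $\lambda$ in the Muon update), and the identity initialization of $\boldsymbol{V}$ are assumptions, not settings from the article.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.linalg import solve_triangular

# Trick one: full FP32 matmuls instead of TF32.
jax.config.update("jax_default_matmul_precision", "highest")

def scqr(A, eps=1e-9):
    """Shifted Cholesky QR; falls back to the built-in QR if the result is not finite."""
    B = A.T @ A
    B = B + eps * jnp.linalg.norm(B) * jnp.eye(B.shape[0])
    L = jnp.linalg.cholesky(B)
    Q = solve_triangular(L, A.T, lower=True).T
    return lax.cond(jnp.isfinite(Q).all(), lambda: Q, lambda: jnp.linalg.qr(A)[0])

def col_norm(X, eps=1e-12):
    """ColNorm: scale each column to unit L2 norm."""
    return X / (jnp.linalg.norm(X, axis=0, keepdims=True) + eps)

@jax.jit
def streaming_msign(M, V_prev):
    """Trick two: V_t = QR(M^T QR(M V_{t-1})), U_t = ColNorm(M V_t)."""
    V = scqr(M.T @ scqr(M @ V_prev))
    return col_norm(M @ V) @ V.T, V

def muon_step(W, M_prev, G, V_prev, lr=0.02, beta=0.95, wd=0.0):
    """One Muon update with the streaming msign (hyperparameters are placeholders)."""
    M = beta * M_prev + G
    O, V = streaming_msign(M, V_prev)
    W = W - lr * (O + wd * W)
    return W, M, V

# Sanity check with a fixed matrix (singular values 2, 1.5, 1.2, 1): repeated
# streaming steps should approach msign computed by SVD.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U0 = jnp.linalg.qr(jax.random.normal(k1, (16, 4)))[0]
V0 = jnp.linalg.qr(jax.random.normal(k2, (4, 4)))[0]
M = U0 @ jnp.diag(jnp.array([2.0, 1.5, 1.2, 1.0])) @ V0.T
V = jnp.eye(4)
for _ in range(60):
    O, V = streaming_msign(M, V)
U_, _, Vt_ = jnp.linalg.svd(M, full_matrices=False)
err = float(jnp.abs(O - U_ @ Vt_).max())
print(err)
```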
Conclusion
This article introduced further acceleration tricks for streaming power iteration; in essence, they all try to reduce the condition number of the matrix being decomposed so as to improve the success rate of Cholesky QR.