Steepest Descent on Manifolds: 3. Muon + Stiefel

By 苏剑林 | August 08, 2025

As mentioned last time, when we shift the optimization target from vector parameters to matrix parameters and adopt the spectral norm constraint better suited to matrices, the Muon optimizer emerges naturally. We then considered the steepest descent direction under an orthogonality constraint on the parameters, which splits into two cases: square matrices and non-square matrices. The square case was solved in the previous article; the non-square case remained unresolved.

The goal of this article is to fill in the solution for the non-square matrix part, allowing optimization under orthogonal constraints to be fully resolved.

Task Information

Let's briefly review the results from the previous article, "Steepest Descent on Manifolds: 2. Muon + Orthogonality". The objective we want to solve is:

\begin{equation} \max_{\boldsymbol{\Phi}} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\,\,\, \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I},\,\,\,(\boldsymbol{W} - \eta \boldsymbol{\Phi})^{\top}(\boldsymbol{W} - \eta \boldsymbol{\Phi})=\boldsymbol{I} \end{equation}

where $\boldsymbol{W},\boldsymbol{\Phi}\in\mathbb{R}^{n\times m}(n \geq m)$, and $\Vert\cdot\Vert_2$ is the spectral norm. Based on the principle that "first-order approximation is sufficient," it can be simplified to:

\begin{equation} \max_{\boldsymbol{\Phi}} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\,\,\, \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I},\,\,\,\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0} \label{eq:ori-obj} \end{equation}

The set of all $\boldsymbol{\Phi}$ satisfying $\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0}$ is also called the "tangent space" of $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$. In the previous article, we already found the general form of the solution:

\begin{equation} \boldsymbol{\Phi} = \msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) \end{equation}

where $\boldsymbol{X}\in\mathbb{R}^{m\times m}$ is a symmetric matrix to be determined.

The remaining difficulty is providing a way to compute the symmetric matrix $\boldsymbol{X}$ such that $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ is anti-symmetric; once such an $\boldsymbol{X}$ is found, the corresponding $\boldsymbol{\Phi}$ is automatically the optimal solution. For $n=m$, we already obtained the closed-form solution $\boldsymbol{X}=-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$; the truly difficult case is $n > m$, where the constraint set $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$ is known as the "Stiefel manifold." This is the open problem left at the end of the previous article.
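
As a quick numerical sanity check of the square-case formula (a minimal, self-contained sketch; the msign and sym helpers mirror those in the test code later in this post):

import numpy as np

def msign(g):
    """Polar factor via SVD: msign(M) = M (M^T M)^{-1/2}"""
    u, s, vh = np.linalg.svd(g, full_matrices=False)
    return (u * np.sign(s)) @ vh

def sym(x):
    """Symmetrization"""
    return (x + x.T) * 0.5

np.random.seed(0)
n = 4
w = np.linalg.qr(np.random.randn(n, n))[0]   # random square orthogonal W (n = m)
g = np.random.randn(n, n)                    # random gradient
x = -sym(w.T @ g)                            # closed-form X for the square case
phi = msign(g + w @ x)
print(np.abs(sym(w.T @ phi)).max())          # ~0, i.e. W^T Phi is anti-symmetric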

Equation Transformation

To put it simply, our current task is to solve the system of equations:

\begin{equation} \boldsymbol{W}^{\top}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})+\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\boldsymbol{W} = \boldsymbol{0} \label{eq:start} \end{equation}

When $n=m$, $\boldsymbol{W}^{\top}$ can be absorbed directly into the $\msign$ function, which simplifies the solution considerably. For $n > m$, however, no such absorption is possible, and this is precisely the difficulty. I tend to think there is no simple explicit solution when $n > m$, so let us look for a numerical algorithm instead.
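
To spell out the absorption step for $n=m$: since $\boldsymbol{W}$ is then a square orthogonal matrix ($\boldsymbol{W}\boldsymbol{W}^{\top}=\boldsymbol{I}$), the definition $\msign(\boldsymbol{M})=\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}$ gives

\begin{equation} \boldsymbol{W}^{\top}\msign(\boldsymbol{M}) = \boldsymbol{W}^{\top}\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2} = (\boldsymbol{W}^{\top}\boldsymbol{M})\big((\boldsymbol{W}^{\top}\boldsymbol{M})^{\top}(\boldsymbol{W}^{\top}\boldsymbol{M})\big)^{-1/2} = \msign(\boldsymbol{W}^{\top}\boldsymbol{M}) \end{equation}

whereas for $n > m$ we only have $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$, not $\boldsymbol{W}\boldsymbol{W}^{\top}=\boldsymbol{I}$, so the same manipulation is unavailable.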

According to the definition $\msign(\boldsymbol{M})=\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}$, we can write:

\begin{equation} \boldsymbol{W}^{\top}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) = \boldsymbol{W}^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})\boldsymbol{Q}^{-1} = (\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X})\boldsymbol{Q}^{-1} \end{equation}

where $\boldsymbol{Q} = ((\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}))^{1/2}$. Under this new notation, the system becomes:

\begin{equation} (\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X})\boldsymbol{Q}^{-1} + \boldsymbol{Q}^{-1}(\boldsymbol{G}^{\top}\boldsymbol{W} + \boldsymbol{X}) = \boldsymbol{0} \end{equation}

Multiplying by $\boldsymbol{Q}$ on both the left and the right, we get:

\begin{equation} \boldsymbol{Q}(\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X}) + (\boldsymbol{G}^{\top}\boldsymbol{W} + \boldsymbol{X})\boldsymbol{Q} = \boldsymbol{0} \label{eq:r-x} \end{equation}

where $\boldsymbol{Q}$ also satisfies (this follows directly from $\msign(\boldsymbol{M})=\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}$, since $\boldsymbol{M}^{\top}\msign(\boldsymbol{M})=(\boldsymbol{M}^{\top}\boldsymbol{M})^{1/2}$):

\begin{equation} \boldsymbol{Q} = (\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) \label{eq:r-q} \end{equation}

Iterative Solution

My idea is to start from some initial value of $\boldsymbol{X}$, substitute it into equation $\eqref{eq:r-q}$ to get $\boldsymbol{Q}$, then substitute $\boldsymbol{Q}$ into equation $\eqref{eq:r-x}$ to solve for a new $\boldsymbol{X}$, and iterate until convergence. Since we already know how to compute $\msign$, equation $\eqref{eq:r-q}$ can be evaluated explicitly, so the only difficulty is solving equation $\eqref{eq:r-x}$.

We can rearrange equation $\eqref{eq:r-x}$:

\begin{equation} \boldsymbol{Q}\boldsymbol{X} + \boldsymbol{X}\boldsymbol{Q} = -2[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}} \label{eq:r-xx} \end{equation}

Given $\boldsymbol{Q}$, this is actually a system of linear equations for $\boldsymbol{X}$, known as the "continuous Lyapunov equation," which can also be seen as a special case of the "Sylvester equation." If we only use the CPU for calculation, Scipy already has a built-in solver function scipy.linalg.solve_continuous_lyapunov, which can be called directly.
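
For example, the call would look like this (a minimal sketch; q and c below are random stand-ins for $\boldsymbol{Q}$ and $\boldsymbol{W}^{\top}\boldsymbol{G}$, and sym is the symmetrization helper used in the test code later on):

import numpy as np
from scipy.linalg import solve_continuous_lyapunov

def sym(x):
    """Symmetrization"""
    return (x + x.T) * 0.5

np.random.seed(0)
m = 4
a = np.random.randn(m, m)
q = a @ a.T + m * np.eye(m)          # stand-in for Q (symmetric positive definite)
c = np.random.randn(m, m)            # stand-in for W^T G

# solve Q X + X Q = -2 [Q W^T G]_sym, a continuous Lyapunov equation
x = solve_continuous_lyapunov(q, -2 * sym(q @ c))
print(np.abs(q @ x + x @ q + 2 * sym(q @ c)).max())  # ~0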

As for the choice of initial value, we can consider the solution for square matrices $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$, as this is a natural transition from square matrices to non-square matrices. We can also observe the rationality of the initial value $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$ from another equivalent form of equation $\eqref{eq:r-xx}$:

\begin{equation} \boldsymbol{Q}(\boldsymbol{X} + [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}) + (\boldsymbol{X} + [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}})\boldsymbol{Q} =[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}\boldsymbol{Q} -\boldsymbol{Q}[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}} \end{equation}

Therefore, the accuracy of $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$ depends on the degree to which $[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}$ commutes with $\boldsymbol{Q}$. The closer they are to commuting matrices, the more accurate $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$ becomes. However, subsequent experimental results show that our iterative algorithm is not particularly sensitive to the initial value; even using an all-zero matrix as an initial value works fine.

Do It Yourself

Just now we mentioned that Scipy has a function for solving the Lyapunov equation, so it can be called directly without worrying about the solving process. However, this is limited to CPU computation with Scipy. I checked, and neither Torch nor Jax has a comparable function, so if you want to run the calculation on a GPU, you have to roll up your sleeves and implement it yourself.

There are two ways to program the solution to equation $\eqref{eq:r-xx}$ yourself. One is to follow the approach in "What can the matrix sign function mcsgn calculate?" and solve using $\mcsgn$ (not $\msign$):

\begin{equation} \boldsymbol{X} = \mcsgn \left( \begin{bmatrix} -\boldsymbol{Q} & -[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}} \\ \boldsymbol{0} & \boldsymbol{Q} \end{bmatrix} \right)_{[:m,m:]} \end{equation}
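
As a quick check of this block-matrix construction (a minimal sketch; mcsgn here is computed exactly via eigendecomposition, similar to the test code below, and q, c are random stand-ins for $\boldsymbol{Q}$ and $[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$):

import numpy as np

def mcsgn(x):
    """Exact matrix sign via eigendecomposition"""
    s, v = np.linalg.eig(x)
    return (v @ np.diag(np.sign(s.real)) @ np.linalg.inv(v)).real

np.random.seed(0)
m = 4
a = np.random.randn(m, m)
q = a @ a.T + m * np.eye(m)                 # stand-in for Q (symmetric positive definite)
c = np.random.randn(m, m)
c = (c + c.T) * 0.5                         # stand-in for [Q W^T G]_sym

x = mcsgn(np.block([[-q, -c], [np.zeros((m, m)), q]]))[:m, m:]
print(np.abs(q @ x + x @ q + 2 * c).max())  # ~0, i.e. X solves Q X + X Q = -2 C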

The second is based on SVD, a method we already used in "The Derivative of msign" when computing the gradient of $\msign$. Let's go through it again in the context of equation $\eqref{eq:r-xx}$. By its definition, $\boldsymbol{Q}$ is symmetric and positive semi-definite (positive definite whenever $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$ has full rank), so it admits the eigendecomposition $\boldsymbol{Q} = \boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$, where $\boldsymbol{V}$ is an orthogonal matrix and $\boldsymbol{\Sigma}=\mathop{\text{diag}}(\sigma_1,\cdots,\sigma_m)$ is diagonal. Substituting this back into equation $\eqref{eq:r-xx}$ and rearranging gives:

\begin{equation} \boldsymbol{\Sigma}(\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V}) + (\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V})\boldsymbol{\Sigma} = -2\boldsymbol{V}^{\top}[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\boldsymbol{V} \end{equation}

The left side can be expressed as $(\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V}) \otimes \boldsymbol{S}$, where $\otimes$ is the Hadamard product and $\boldsymbol{S}_{i,j} = \sigma_i + \sigma_j$. From this, we can solve for:

\begin{equation} \boldsymbol{X} = -2\boldsymbol{V}((\boldsymbol{V}^{\top}[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\boldsymbol{V}) \oslash \boldsymbol{S})\boldsymbol{V}^{\top} \end{equation}

where $\oslash$ is the Hadamard quotient. The interesting part here is that performing an eigenvalue decomposition on $\boldsymbol{Q}$ is essentially equivalent to performing an SVD on $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$, and performing an SVD on $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$ can also be used to find $\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})$. Thus, only one SVD is needed to calculate both $\msign$ and the solution to equation $\eqref{eq:r-xx}$.
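
Spelling out that last claim: if $\boldsymbol{G}+\boldsymbol{W}\boldsymbol{X} = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$ is an SVD, then

\begin{equation} \boldsymbol{Q} = \big((\boldsymbol{G}+\boldsymbol{W}\boldsymbol{X})^{\top}(\boldsymbol{G}+\boldsymbol{W}\boldsymbol{X})\big)^{1/2} = (\boldsymbol{V}\boldsymbol{\Sigma}^2\boldsymbol{V}^{\top})^{1/2} = \boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top},\qquad \msign(\boldsymbol{G}+\boldsymbol{W}\boldsymbol{X}) = \boldsymbol{U}\boldsymbol{V}^{\top} \end{equation}

(the latter under the full-rank assumption), so the factors $\boldsymbol{U}$, $\boldsymbol{\Sigma}$, $\boldsymbol{V}$ of a single SVD give both $\msign(\boldsymbol{G}+\boldsymbol{W}\boldsymbol{X})$ and everything needed to solve equation $\eqref{eq:r-xx}$.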

Each of the two approaches has its own trade-offs. The first requires computing $\msign$ of the $n \times m$ matrix $\boldsymbol{G}+\boldsymbol{W}\boldsymbol{X}$ and then $\mcsgn$ of a $2m \times 2m$ matrix. Although both can be computed efficiently with Newton-Schulz iteration, the cost is not negligible, and we must choose iteration coefficients that both guarantee convergence and deliver high precision (refer to the results in "Newton-Schulz iteration for the msign operator (II)"); otherwise the computed $\mcsgn$ and $\msign$ will not converge, let alone $\boldsymbol{X}$.

The second approach requires SVD. Although SVD is computationally complex and often requires FP32 precision, in this problem, only one SVD per iteration is needed to simultaneously compute $\msign$ and $\boldsymbol{X}$. The overall efficiency is not too bad. If there are not many matrix parameters requiring orthogonal constraints, SVD might be the easiest choice.

Related Results

Prior to this article, @leloy, in his blog post "Heuristic Solutions for Steepest Descent on the Stiefel Manifold," proposed two heuristic solutions for the original objective $\eqref{eq:ori-obj}$. "Heuristic" here means that in most cases, it can yield a decent solution, but it doesn't guarantee the optimal one. Let's learn about them here as well.

The first method can be described as purely geometric. First, we define a projection operation:

\begin{equation} \proj_{\boldsymbol{W}}(\boldsymbol{M}) = \boldsymbol{M} - \boldsymbol{W}[\boldsymbol{W}^{\top}\boldsymbol{M}]_{\text{sym}} \end{equation}

It can be verified that $\boldsymbol{W}^{\top}\proj_{\boldsymbol{W}}(\boldsymbol{M}) = \boldsymbol{W}^{\top}\boldsymbol{M} - [\boldsymbol{W}^{\top}\boldsymbol{M}]_{\text{sym}} = [\boldsymbol{W}^{\top}\boldsymbol{M}]_{\text{skew}}$ is always anti-symmetric, meaning $\proj_{\boldsymbol{W}}(\boldsymbol{M})$ always lies in the tangent space. Thus, we treat it as the projection of an arbitrary matrix $\boldsymbol{M}$ onto the tangent space of $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$.

Starting from the gradient $\boldsymbol{G}$: $\proj_{\boldsymbol{W}}(\boldsymbol{G})$ is certainly in the tangent space, but we know that Muon's update must be an orthogonal matrix (when full rank). Since $\proj_{\boldsymbol{W}}(\boldsymbol{G})$ is not necessarily orthogonal, we use $\msign$ to find the nearest orthogonal matrix, namely $\msign(\proj_{\boldsymbol{W}}(\boldsymbol{G}))$. After applying $\msign$, however, the result may no longer lie in the tangent space, so we project it back onto the tangent space and again look for the nearest orthogonal matrix, iterating repeatedly:

\begin{equation} \boldsymbol{\Phi} = (\msign \circ \proj_{\boldsymbol{W}} \circ \cdots \circ \msign \circ \proj_{\boldsymbol{W}})(\boldsymbol{G}) \end{equation}

This is @leloy's first approach: alternately project onto the tangent space and the orthogonal space until convergence. It is quite intuitive, and under fairly random conditions it yields a result very close to the optimal solution, sometimes matching to four decimal places, which initially led me to think it was the exact solution. After searching further, however, I found cases where the deviation from the optimal solution was large enough to confirm that this was only a coincidence rather than the exact optimum.

The second method can be called a line search. Specifically, when $n > m$, we can complete $\boldsymbol{W}$ to a full $n \times n$ orthogonal matrix $[\boldsymbol{W}, \overline{\boldsymbol{W}}]$ and decompose the desired $\boldsymbol{\Phi}$ into two parts, $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ and $\overline{\boldsymbol{W}}^{\top}\boldsymbol{\Phi}$. @leloy then makes a greedy approximation: first solve for the optimal $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$, then for the optimal $\overline{\boldsymbol{W}}^{\top}\boldsymbol{\Phi}$, with a line search between the two to improve accuracy.

This sequence of operations indeed yields a reasonably good approximation that is guaranteed to be in the tangent space and satisfy orthogonality. The solution process involves calculating the spectral norm, $\msign$, and Cholesky decomposition. For details, please refer to the author's article. Furthermore, when $m=2$, it can theoretically find the optimal solution because a $2 \times 2$ anti-symmetric matrix has only one free parameter, which corresponds exactly to one degree of freedom in the line search.

Let's Test It

Below, we test several methods in Numpy. The main goal is to verify the correctness of the methods themselves, so we use SVD and eigenvalue decomposition to implement $\msign$ and $\mcsgn$.

import numpy as np
import scipy as sp

def mcsgn(x):
    """Accurate calculation of mcsgn using eigenvalue decomposition"""
    s, v = np.linalg.eig(x)
    return v @ np.diag(np.sign(s)) @ np.linalg.inv(v)

def msign(g):
    """Accurate calculation of msign using SVD"""
    u, s, vh = np.linalg.svd(g, full_matrices=False)
    return u @ np.diag(np.sign(s)) @ vh

def sym(x):
    """Symmetrization"""
    return (x + x.T) * 0.5

def skew(x):
    """Anti-symmetrization"""
    return (x - x.T) * 0.5

def proj(g, w):
    """Projection onto the orthogonal tangent space"""
    return g - w @ sym(w.T @ g)

def jianlin_by_mcsgn(g, w, steps=20):
    """Iteration using mcsgn as constructed in this article"""
    n, m = g.shape
    x = -sym(w.T @ g)
    for i in range(1, steps + 1):
        phi = msign(z := g + w @ x)
        print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
        if i == steps:
            return phi
        q = z.T @ phi
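        # solve Q X + X Q = -2 sym(Q W^T G) for X via the mcsgn block-matrix construction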
        x = mcsgn(np.block([[-q, -sym(q @ w.T @ g)], [np.zeros_like(q), q]]))[:m, m:]
        # x = -2 * sp.linalg.solve_continuous_lyapunov(q, sym(q @ w.T @ g))

def jianlin_by_svd(g, w, steps=20):
    """Iteration using SVD as constructed in this article"""
    x = -sym(w.T @ g)
    for i in range(1, steps + 1):
        u, s, vh = np.linalg.svd(z := g + w @ x, full_matrices=False)
        phi = (u * np.sign(s)) @ vh
        print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
        if i == steps:
            return phi
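        # reuse the SVD of z (Q = vh.T @ diag(s) @ vh) to solve Q X + X Q = -2 sym(Q W^T G) in Q's eigenbasis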
        x = -2 * vh.T @ (vh @ sym(z.T @ phi @ w.T @ g) @ vh.T / (s + s[:, None])) @ vh

def leloy_v1(g, w, steps=20):
    """Alternating projection onto the tangent space and orthogonal space"""
    phi = g
    for i in range(1, steps + 1):
        phi = msign(proj(phi, w))
        print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
    return phi

def leloy_v2(g, w, steps=20):
    """Partial greedy solver + line search (simplified form)"""
    n, m = g.shape
    taus = np.linspace(0, 1, steps + 2)[1:-1]
    p_max, tau_opt, phi_opt = 0, 0, None
    for tau in taus:
        b = (b := skew(w.T @ g)) * tau / max(np.linalg.norm(b, ord=2), 1e-8)
        r = np.linalg.cholesky(np.eye(m) - b.T @ b)
        c = msign((np.eye(n) - w @ w.T) @ g @ r) @ r
        phi = w @ b + c
        print('tau:', tau, ', inner product:', p := (phi * g).sum())
        if p > p_max:
            p_max, tau_opt, phi_opt = p, tau, phi
    print('best inner product:', p_max, ', tau:', tau_opt)
    return phi_opt

# Test Case 1
w = np.array([[ 0.69453734, -0.26590866, -0.44721806, 0.2753041 ],
              [-0.11738148, -0.5588003 , -0.17580748, 0.3218624 ],
              [-0.4515288 , -0.23489913, -0.26683152, -0.25739142],
              [ 0.02392521, 0.02664689, 0.48423648, 0.6193399 ],
              [ 0.45194831, -0.25206333, 0.27654836, -0.60242337],
              [ 0.21197332, -0.09174792, 0.24521762, -0.08484317],
              [-0.15496767, -0.26446804, -0.34942415, -0.01877318],
              [-0.16181251, -0.6474956 , 0.45243263, -0.01776086]])

g = np.array([[-17.85745  , -10.758921 ,  -2.9583392,   6.245008 ],
              [-28.883093 ,  19.772121 ,   8.086545 , -21.564013 ],
              [ -1.6274693, -14.96859  ,   3.4465332,   3.1070817],
              [ -7.8890743,   1.5304767,  -8.949573 ,   9.579629 ],
              [  2.246596 ,  14.46572  ,  12.8451   ,  -2.7370298],
              [ -0.9496974,   6.9879804,   2.849277 ,   1.1148484],
              [ -8.115278 , -18.054405 ,  -0.19287404,  7.0389237],
              [-15.062008 , -15.02901  ,   2.9083247,  21.706533 ]])

phi1 = jianlin_by_mcsgn(g, w, steps=100)
phi2 = jianlin_by_svd(g, w, steps=100)
phi3 = leloy_v1(g, w, steps=100)
phi4 = leloy_v2(g, w, steps=100)
assert np.allclose(phi1, phi2)

# Random Case
w = np.linalg.qr(np.random.randn(100, 50))[0]
g = np.random.randn(100, 50)

phi1 = jianlin_by_mcsgn(g, w, steps=10)
phi2 = jianlin_by_svd(g, w, steps=10)
phi3 = leloy_v1(g, w, steps=10)
phi4 = leloy_v2(g, w, steps=10)
assert np.allclose(phi1, phi2)

For the first set of $\boldsymbol{W}, \boldsymbol{G}$ in the code, the optimal $\tr(\boldsymbol{G}^{\top} \boldsymbol{\Phi})$ found by my method is roughly $90$, and the results from $\mcsgn$ and SVD are identical. @leloy's first method yields about $70$, and the second about $80$, both falling short of the optimal solution.

However, the first set of $\boldsymbol{W}, \boldsymbol{G}$ was an extreme example specifically found to highlight the differences. With relatively random values, my method and @leloy's first method are quite close, and fewer iterations (5–10 steps) are needed. In these cases, @leloy's second method deviates further from the optimal solution. Readers can test this by constructing their own examples.

Further Thoughts

That concludes the solution for the original problem $\eqref{eq:ori-obj}$. Here are a few more points worth discussing.

First, for simplicity, the iterative process I described assumes that $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$ stays full rank (rank $m$) throughout; otherwise, the matrix $\boldsymbol{S}$ would have zero entries and $\oslash\boldsymbol{S}$ would be ill-defined. This difficulty is not fundamental, however, because equation $\eqref{eq:start}$ must have a solution: whenever a denominator is zero, the corresponding numerator must also be zero, so we can simply replace any zeros in $\boldsymbol{S}$ with a small positive number and still obtain the correct result.
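
In code, this amounts to clamping the denominator; a minimal sketch of such a guarded division (the helper name and tolerance are illustrative, not part of the reference implementation above):

import numpy as np

def safe_hadamard_div(num, s, eps=1e-8):
    """Divide num elementwise by S_ij = sigma_i + sigma_j, guarding against (near-)zero entries"""
    return num / np.maximum(s + s[:, None], eps)

# drop-in replacement for the division inside jianlin_by_svd:
# x = -2 * vh.T @ safe_hadamard_div(vh @ sym(z.T @ phi @ w.T @ g) @ vh.T, s) @ vh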

From a numerical computation perspective, we rarely encounter singular values exactly equal to zero, so this issue shouldn't be a major concern—we can usually assume $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$ is full rank. Under this assumption, the retraction operation becomes simple because:

\begin{equation} (\boldsymbol{W} - \eta\boldsymbol{\Phi})^{\top}(\boldsymbol{W} - \eta\boldsymbol{\Phi}) = \boldsymbol{W}^{\top} \boldsymbol{W} - \eta(\boldsymbol{W}^{\top} \boldsymbol{\Phi} + \boldsymbol{\Phi}^{\top}\boldsymbol{W}) + \eta^2 \boldsymbol{\Phi}^{\top}\boldsymbol{\Phi} \end{equation}

According to the definition of the Stiefel manifold, the first term on the right is $\boldsymbol{I}$. Based on the tangent space condition, the second term is $\boldsymbol{0}$. Finally, if it's full rank, the result of $\msign$ is also a Stiefel manifold matrix, so the third term is $\eta^2 \boldsymbol{I}$. The total result is $(1+\eta^2)\boldsymbol{I}$. Dividing by $\sqrt{1+\eta^2}$ performs the retraction:

\begin{equation} \boldsymbol{W}\quad\leftarrow\quad\frac{\boldsymbol{W} - \eta\boldsymbol{\Phi}}{\sqrt{1+\eta^2}} \end{equation}

This raises a deeper question: whether we are dealing with the simpler orthogonal manifold or the more general Stiefel manifold, what precision should the calculations use? "Orthogonality" is an exact quantitative constraint: $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$ comprises $m(m+1)/2$ equality constraints. Running the above update in low precision will inevitably cause the parameters to drift away from orthogonality over time, not to mention the errors in solving for $\boldsymbol{\Phi}$.

Therefore, I believe that unless we periodically apply orthogonalization (i.e., $\boldsymbol{W}\leftarrow\msign(\boldsymbol{W})$) to pull the parameters back onto the orthogonal manifold, the calculation precision should be at least FP32. Since the number of parameters requiring orthogonal constraints is usually small, this is generally not too large a cost.
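
To make the update loop concrete, here is a minimal sketch of one step on the Stiefel manifold (solve_phi stands for whichever solver above is used to obtain $\boldsymbol{\Phi}$; the step counter and re-orthogonalization period are illustrative):

import numpy as np

def stiefel_step(w, g, solve_phi, lr=0.01, step=0, reortho_every=100):
    """One steepest-descent step on the Stiefel manifold, with retraction"""
    phi = solve_phi(g, w)                      # tangent, (semi-)orthogonal update direction
    w = (w - lr * phi) / np.sqrt(1 + lr**2)    # retraction back onto W^T W = I
    if reortho_every and (step + 1) % reortho_every == 0:
        u, s, vh = np.linalg.svd(w, full_matrices=False)
        w = u @ vh                             # periodic re-orthogonalization: W <- msign(W)
    return w

For instance, one might call it as w = stiefel_step(w, g, lambda g, w: jianlin_by_svd(g, w, steps=5), step=t) inside a training loop, reusing the solver from the test section above.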

Summary

This article extends the "Muon + Orthogonal manifold" from the previous post to the more general "Muon + Stiefel manifold," with the main finding being an iterative algorithm for solving for the corresponding update amount.