Steepest Descent on Manifolds: 2. Muon + Orthogonality

By 苏剑林 | August 06, 2025

This article continues our series on constrained optimization. In the previous post, "Steepest Descent on Manifolds: 1. SGD + Hypersphere", we revisited the "least action principle" for optimizers, proposing that the core difference between various optimizers lies in the different constraints imposed on the update magnitude. If this constraint is the Euclidean norm, then the corresponding steepest descent is SGD. Furthermore, we discussed the result of adding a magnitude constraint to the parameters, which constitutes steepest descent on a hypersphere manifold.

However, the previous article was merely a "warm-up," since it dealt with the relatively simple case of vector parameters. This article formally enters the more challenging part: the parameters being optimized change from vectors to matrices, and the constraint on the increment changes to the spectral norm, which gives rise to the Muon optimizer. We then add an orthogonality constraint to the parameters, which leads to the Muon optimizer on the orthogonal manifold.

Proposition Description

Let the parameters to be optimized have a matrix form $\boldsymbol{W}\in\mathbb{R}^{n\times m}$; without loss of generality, let $n\geq m$. According to the "least action principle" from the previous article, we conclude that the steepest descent increment $\Delta\boldsymbol{W}$ should satisfy

\begin{equation}\min_{\Delta \boldsymbol{W}} \mathcal{L}(\boldsymbol{W} +\Delta\boldsymbol{W}) \qquad \text{s.t.}\qquad \rho(\Delta\boldsymbol{W})\leq \eta\end{equation}

If $\rho$ is taken to be the $F$-norm (Frobenius Norm), we obtain the same result as in the previous article, because the $F$-norm simply flattens the matrix into a vector and takes its L2 norm, so the result is equivalent to SGD applied to the flattened matrix. To obtain a result that more deeply reveals and fits the nature of matrices, the norm we choose here is the Spectral Norm, also known as the "2-norm," denoted $\Vert\cdot\Vert_2$.

As for why we choose the spectral norm, readers can refer to "An Appreciation of the Muon Optimizer: A Qualitative Leap from Vector to Matrix", "Muon Sequel: Why We Chose to Try Muon?", and "Higher-Order muP: Simple but Sophisticated Spectral Condition Scaling"; I will not repeat the introduction here. Simply put, the spectral norm gives the tightest bound on how much a linear layer's output can change, making it the more suitable measure of "stability" for matrices.

Following the previous steps, applying a first-order approximation to $\mathcal{L}(\boldsymbol{W} +\Delta\boldsymbol{W})$ yields $\mathcal{L}(\boldsymbol{W}) + \langle \boldsymbol{G}, \Delta\boldsymbol{W}\rangle_F$, where $\boldsymbol{G}=\nabla_{\boldsymbol{W}}\mathcal{L}(\boldsymbol{W})$. Here $\langle\cdot,\cdot\rangle_F$ is the inner product of the two matrices after flattening them into vectors, which is equal to $\mathop{\text{tr}}(\boldsymbol{G}^{\top}\Delta\boldsymbol{W})$. Letting $\Delta\boldsymbol{W} = -\eta \boldsymbol{\Phi}$, the original proposition can be simplified to

\begin{equation}\max_{\boldsymbol{\Phi}} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2 = 1\label{eq:muon-obj}\end{equation}

Up to this point, these transformation steps are general. If you have forgotten the details, please refer to the previous article.

Basic Result

The solution process for the objective $\eqref{eq:muon-obj}$ was already given in the "Matrix Norm" section of "An Appreciation of the Muon Optimizer: A Qualitative Leap from Vector to Matrix", but for the sake of completeness, I will repeat it here. Let the SVD of $\boldsymbol{G}$ be $\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} = \sum\limits_{i=1}^r \sigma_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top}$, where $r$ is the rank of $\boldsymbol{G}$. We have

\begin{equation}\tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi})=\tr\left(\sum_{i=1}^r \sigma_i \boldsymbol{v}_i \boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\right) = \sum_{i=1}^r \sigma_i \boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\boldsymbol{v}_i\end{equation}

By definition, when $\Vert\boldsymbol{\Phi}\Vert_2=1$ we have $\Vert\boldsymbol{\Phi}\boldsymbol{v}_i\Vert_2\leq \Vert\boldsymbol{v}_i\Vert_2=1$, and then by the Cauchy-Schwarz inequality $\boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\boldsymbol{v}_i\leq \Vert\boldsymbol{u}_i\Vert_2\Vert\boldsymbol{\Phi}\boldsymbol{v}_i\Vert_2\leq 1$. Therefore

\begin{equation}\tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi})\leq \sum_{i=1}^r \sigma_i = \Vert \boldsymbol{G}\Vert_*\end{equation}

Here $\Vert\cdot\Vert_*$ is called the Nuclear Norm. Equality holds when all $\boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\boldsymbol{v}_i$ are equal to 1, at which point

\begin{equation}\boldsymbol{\Phi} = \sum_{i=1}^r \boldsymbol{u}_i \boldsymbol{v}_i^{\top} = \boldsymbol{U}_{[:,:r]}\boldsymbol{V}_{[:,:r]}^{\top} = \mathop{\text{msign}}(\boldsymbol{G})\end{equation}

Note that if $r < m$, then adding $\boldsymbol{u}_{r+1} \boldsymbol{v}_{r+1}^{\top}, \boldsymbol{u}_{r+2} \boldsymbol{v}_{r+2}^{\top}, \dots$ on top also keeps the equality, so the solution is not unique. However, since the singular vectors with indices greater than $r$ are themselves not uniquely determined, the formula above is the deterministic and minimal choice. Additionally, interested readers can try bringing out the "big gun" of the von Neumann trace inequality to find the general solution under the Schatten-$p$ norm, of which the spectral norm is the $p\to\infty$ case.
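As a concrete sanity check of this result, here is a minimal NumPy sketch (my own, not from the original post). The helper name `msign` is mine; it computes $\mathop{\text{msign}}(\boldsymbol{G})$ via the reduced SVD in the generic full-rank case and verifies that the solution has spectral norm 1 and attains the nuclear-norm bound:

```python
import numpy as np

def msign(G):
    """Matrix sign via reduced SVD: msign(G) = U V^T (generic full-rank case)."""
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(0)
G = rng.standard_normal((5, 3))          # n >= m, generically full rank

Phi = msign(G)
spec = np.linalg.norm(Phi, 2)            # spectral norm of the solution
inner = np.trace(G.T @ Phi)              # objective value tr(G^T Phi)
nuclear = np.linalg.svd(G, compute_uv=False).sum()

print(np.isclose(spec, 1.0))             # ||Phi||_2 = 1
print(np.isclose(inner, nuclear))        # tr(G^T Phi) attains ||G||_*
```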

Orthogonal Manifold

Thus far, we have proved that for matrix parameters, the direction of steepest descent under the spectral norm constraint is not the negative gradient direction $-\boldsymbol{G}$, but rather requires an additional $\text{msign}$ operator, i.e., $-\mathop{\text{msign}}(\boldsymbol{G})$. This is exactly the Muon optimizer used to train Kimi K2, which is currently one of the most competitive optimizers. This, in turn, suggests that the spectral norm is a very appropriate stability constraint for matrices.
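In practice, Muon implementations avoid computing an SVD at every step and instead approximate $\mathop{\text{msign}}$ with a few Newton-Schulz iterations. The sketch below is only an illustration using the classical cubic iteration (the function name `msign_newton_schulz` and the step count are mine); production implementations use tuned polynomial coefficients and combine the operator with momentum, both of which are omitted here:

```python
import numpy as np

def msign_newton_schulz(G, steps=30):
    """Approximate msign(G) with the classical cubic Newton-Schulz iteration.
    Dividing by the Frobenius norm puts every singular value in (0, 1], and the
    scalar map x -> x * (3 - x^2) / 2 then drives each singular value to 1
    (assuming G has full column rank)."""
    X = G / np.linalg.norm(G)             # Frobenius-norm normalization
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X

# quick check against the exact SVD-based msign
rng = np.random.default_rng(3)
G = rng.standard_normal((5, 3))
U, _, Vt = np.linalg.svd(G, full_matrices=False)
print(np.allclose(msign_newton_schulz(G), U @ Vt, atol=1e-5))
```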

Of course, the results so far are already known. Now let's start something new: adding an orthogonality constraint $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$ to the parameters $\boldsymbol{W}$ (see "Orthogonal manifold"). There are two cases: first, $n=m$, where $\boldsymbol{W}$ is a proper orthogonal matrix satisfying $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{W}\boldsymbol{W}^{\top}=\boldsymbol{I}$; second, $n > m$, where $\boldsymbol{W}\boldsymbol{W}^{\top}=\boldsymbol{I}$ cannot also hold. Such a matrix is typically called a semi-orthogonal matrix, and the corresponding space is the Stiefel manifold.

Specifically, the problem we want to solve is:

\begin{equation}\max_{\boldsymbol{\Phi}} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\,\, \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I},\,\,(\boldsymbol{W} - \eta \boldsymbol{\Phi})^{\top}(\boldsymbol{W} - \eta \boldsymbol{\Phi})=\boldsymbol{I}\end{equation}

Still adhering to the principle that "a first-order approximation is sufficient," we expand $(\boldsymbol{W} - \eta \boldsymbol{\Phi})^{\top}(\boldsymbol{W} - \eta \boldsymbol{\Phi}) = \boldsymbol{W}^{\top}\boldsymbol{W} - \eta(\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W}) + \eta^2\boldsymbol{\Phi}^{\top}\boldsymbol{\Phi}$ and drop the $\mathcal{O}(\eta^2)$ term; the last condition then simplifies to $\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0}$, meaning $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ is a skew-symmetric matrix:

\begin{equation}\max_{\boldsymbol{\Phi}} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\,\, \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I},\,\,\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0}\label{eq:muon-obj-orth}\end{equation}

When is an orthogonality constraint used? Such scenarios are in fact not uncommon. For example, in classification problems, if the categories are known to be uncorrelated, one can consider imposing an orthogonality constraint on the category matrix; usually, though, we only achieve approximate orthogonality by adding a regularization term $\Vert\boldsymbol{W}^{\top}\boldsymbol{W}-\boldsymbol{I}\Vert_F^2$. Another example is LoRA, where the $\boldsymbol{A}\boldsymbol{B}$ parametrization has redundant degrees of freedom, and the redundancy can be reduced through orthogonality constraints, and so on.
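For reference, the soft-constraint route mentioned above is essentially a one-liner; a minimal NumPy sketch of my own (the helper name `orth_penalty` and the weight matrix `W` are hypothetical):

```python
import numpy as np

def orth_penalty(W):
    """Soft orthogonality regularizer ||W^T W - I||_F^2."""
    m = W.shape[1]
    R = W.T @ W - np.eye(m)
    return np.sum(R * R)

# e.g. add `lam * orth_penalty(W)` to the training loss as a soft constraint
```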

Solution Process

To solve the objective $\eqref{eq:muon-obj-orth}$, similar to the previous article, we introduce a matrix of Lagrange multipliers $\boldsymbol{\Lambda}\in\mathbb{R}^{m\times m}$, yielding

\begin{equation}\begin{aligned} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) =&\, \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) + \tr(\boldsymbol{\Lambda}^{\top}(\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W})) \\ =&\, \tr((\boldsymbol{G} + \boldsymbol{W}(\boldsymbol{\Lambda} + \boldsymbol{\Lambda}^{\top}))^{\top}\boldsymbol{\Phi}) \\ \leq &\,\Vert\boldsymbol{G} + \boldsymbol{W}(\boldsymbol{\Lambda} + \boldsymbol{\Lambda}^{\top})\Vert_* \end{aligned}\end{equation}

The second equality uses the trace identity $\tr(\boldsymbol{A}\boldsymbol{B}) = \tr(\boldsymbol{B}\boldsymbol{A}) = \tr(\boldsymbol{A}^{\top}\boldsymbol{B}^{\top}) = \tr(\boldsymbol{B}^{\top}\boldsymbol{A}^{\top})$. According to the previous Muon result, the condition for equality is

\begin{equation}\boldsymbol{\Phi} = \mathop{\text{msign}}(\boldsymbol{G} + \boldsymbol{W}(\boldsymbol{\Lambda} + \boldsymbol{\Lambda}^{\top}))\end{equation}

The remaining problem is to find a symmetric matrix $\boldsymbol{X} = \boldsymbol{\Lambda} + \boldsymbol{\Lambda}^{\top}$ such that $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ is skew-symmetric. This is easy to solve when $n=m$, because $\boldsymbol{W}$ is then a square orthogonal matrix and $\mathop{\text{msign}}(\boldsymbol{Q}\boldsymbol{M}) = \boldsymbol{Q}\mathop{\text{msign}}(\boldsymbol{M})$ for any orthogonal $\boldsymbol{Q}$, so $\boldsymbol{W}^{\top}$ can be absorbed into the $\mathop{\text{msign}}$ operator:

\begin{equation}\boldsymbol{W}^{\top}\boldsymbol{\Phi} = \boldsymbol{W}^{\top}\mathop{\text{msign}}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) = \mathop{\text{msign}}(\boldsymbol{W}^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})) = \mathop{\text{msign}}(\boldsymbol{W}^{\top}\boldsymbol{G} +\boldsymbol{X})\end{equation}

Note that $\mathop{\text{msign}}$ has another property: it preserves skew-symmetry. That is, if a square matrix $\boldsymbol{M}$ is skew-symmetric, then $\mathop{\text{msign}}(\boldsymbol{M})$ is also skew-symmetric (the proof is left as an exercise for the reader). Thus, to make $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ skew-symmetric, we just need to make $\boldsymbol{W}^{\top}\boldsymbol{G} +\boldsymbol{X}$ skew-symmetric. Since $\boldsymbol{X}$ is symmetric, this is equivalent to decomposing $\boldsymbol{W}^{\top}\boldsymbol{G}$ into the sum of a symmetric matrix and a skew-symmetric matrix, so that $\boldsymbol{X}$ can cancel the symmetric part; this decomposition has a standard form:

\begin{equation}\boldsymbol{W}^{\top}\boldsymbol{G} = \underbrace{\frac{1}{2}(\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{G}^{\top}\boldsymbol{W})}_{[\boldsymbol{W}^{\top}\boldsymbol{G}] _{\text{sym}}} + \underbrace{\frac{1}{2}(\boldsymbol{W}^{\top}\boldsymbol{G} - \boldsymbol{G}^{\top}\boldsymbol{W})}_{[\boldsymbol{W}^{\top}\boldsymbol{G}] _{\text{skew}}} \end{equation}

where $[\boldsymbol{M}]_{\text{sym}} = (\boldsymbol{M}+\boldsymbol{M}^{\top})/2$ and $[\boldsymbol{M}]_{\text{skew}} = (\boldsymbol{M}-\boldsymbol{M}^{\top})/2$. Based on the above identity, we can directly conclude $\boldsymbol{X} = -[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$. The solution for $n > m$ is more complex and will be discussed in detail in the next article; this post aims to fully solve the $n=m$ case.
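The two facts used here, that $\mathop{\text{msign}}$ preserves skew-symmetry and that $\boldsymbol{X} = -[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$ indeed makes $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ skew-symmetric, are easy to check numerically. Below is a NumPy sketch of my own for the $n=m$ case (generic full-rank matrices, with $n$ even so that a random skew-symmetric matrix is invertible):

```python
import numpy as np

def msign(G):
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(1)
n = 6
W, _ = np.linalg.qr(rng.standard_normal((n, n)))    # random orthogonal W
G = rng.standard_normal((n, n))                      # stand-in gradient

# msign preserves skew-symmetry
A = rng.standard_normal((n, n))
S = (A - A.T) / 2                                    # skew-symmetric test matrix
print(np.allclose(msign(S), -msign(S).T))

# X = -[W^T G]_sym makes W^T Phi skew-symmetric
sym = (W.T @ G + G.T @ W) / 2
Phi = msign(G - W @ sym)                             # Phi = msign(G + W X) with X = -sym
WtPhi = W.T @ Phi
print(np.allclose(WtPhi, -WtPhi.T))
```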

Retraction Operation

In summary, for $n=m$, the final result we obtained is

\begin{equation}\begin{aligned} \boldsymbol{\Phi} =&\, \mathop{\text{msign}}(\boldsymbol{G} - \boldsymbol{W}[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}) \\ =&\, \boldsymbol{W}\boldsymbol{W}^{\top}\mathop{\text{msign}}(\boldsymbol{G} - \boldsymbol{W}[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}) \\ =&\, \boldsymbol{W}\mathop{\text{msign}}(\boldsymbol{W}^{\top}\boldsymbol{G} - [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}) \\ =&\, \boldsymbol{W}\mathop{\text{msign}}([\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}) \\ \end{aligned}\end{equation}

Therefore, the updated parameters are

\begin{equation}\boldsymbol{W} - \eta \boldsymbol{\Phi} = \boldsymbol{W}(\boldsymbol{I} - \eta\,\underbrace{\mathop{\text{msign}}([\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}})}_{\text{denoted as }\boldsymbol{O}})\label{eq:updated-W}\end{equation}

This is not exactly an orthogonal matrix, but it is accurate to $\mathcal{O}(\eta^2)$, which is consistent with our "first-order approximation is sufficient" principle. To observe this, we only need to verify

\begin{equation}\begin{aligned} (\boldsymbol{I} - \eta\boldsymbol{O})^{\top}\boldsymbol{W}^{\top}\boldsymbol{W}(\boldsymbol{I} - \eta\boldsymbol{O}) =&\,(\boldsymbol{I} - \eta\boldsymbol{O})^{\top}(\boldsymbol{I} - \eta\boldsymbol{O}) \\ =&\,\boldsymbol{I} - \eta(\boldsymbol{O}^{\top} + \boldsymbol{O}) + \eta^2\boldsymbol{O}^{\top}\boldsymbol{O} \\ =&\,\boldsymbol{I} + \eta^2\boldsymbol{O}^{\top}\boldsymbol{O} \\ \end{aligned}\label{eq:orth-check}\end{equation}

If $[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}$ is full rank, then $\boldsymbol{O}$ is an orthogonal matrix, i.e., $\boldsymbol{O}^{\top}\boldsymbol{O}=\boldsymbol{I}$. In this case, simply dividing by $\sqrt{1+\eta^2}$ would satisfy the orthogonality of $\eqref{eq:updated-W}$. However, when $[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}$ is not full rank, there is no simple transformation to satisfy orthogonality. In such cases, a general approach is to find the nearest orthogonal matrix, which is exactly what $\mathop{\text{msign}}$ does! Therefore, the complete update rule is

\begin{equation}\boldsymbol{W} \quad \leftarrow\quad \mathop{\text{msign}}(\boldsymbol{W} - \eta \boldsymbol{\Phi}) = \mathop{\text{msign}}(\boldsymbol{W}(\boldsymbol{I} - \eta\boldsymbol{O})) = \boldsymbol{W}\mathop{\text{msign}}(\boldsymbol{I} - \eta\boldsymbol{O})\end{equation}

But this requires calculating $\mathop{\text{msign}}$ twice; let's try to simplify it. According to the definition and Eq. $\eqref{eq:orth-check}$, we have

\begin{equation}\mathop{\text{msign}}(\boldsymbol{I} - \eta\boldsymbol{O}) = (\boldsymbol{I} - \eta\boldsymbol{O})(\boldsymbol{I} + \eta^2\boldsymbol{O}^{\top}\boldsymbol{O})^{-1/2}\end{equation}

Note that regardless of whether $\boldsymbol{O}$ is full rank, $\boldsymbol{O}^{\top}\boldsymbol{O}$ is an orthogonal projection, so $(\boldsymbol{O}^{\top}\boldsymbol{O})^2 = \boldsymbol{O}^{\top}\boldsymbol{O}$. Let $(1+\eta^2 x)^{-1/2}=1 + a_1 x + a_2 x^2 + a_3 x^3 + \dots$, then

\begin{equation}\begin{aligned} (\boldsymbol{I} + \eta^2\boldsymbol{O}^{\top}\boldsymbol{O})^{-1/2} =&\, \boldsymbol{I} + a_1 (\boldsymbol{O}^{\top}\boldsymbol{O}) + a_2 (\boldsymbol{O}^{\top}\boldsymbol{O})^2 + a_3 (\boldsymbol{O}^{\top}\boldsymbol{O})^3 + \dots \\ =&\, \boldsymbol{I} + a_1 (\boldsymbol{O}^{\top}\boldsymbol{O}) + a_2 (\boldsymbol{O}^{\top}\boldsymbol{O}) + a_3 (\boldsymbol{O}^{\top}\boldsymbol{O}) + \dots \\ =&\, \boldsymbol{I} - \boldsymbol{O}^{\top}\boldsymbol{O} + \underbrace{(1 + a_1 + a_2 + a_3 + \dots)}_{(1+\eta^2 x)^{-1/2}\text{ at }x=1}\boldsymbol{O}^{\top}\boldsymbol{O} \\ \end{aligned}\end{equation}

This eliminates one $\mathop{\text{msign}}$ calculation. The simplified final result is

\begin{equation}\boldsymbol{W} \quad \leftarrow\quad \boldsymbol{W}(\boldsymbol{I} - \eta\boldsymbol{O})\left(\boldsymbol{I} - \boldsymbol{O}^{\top}\boldsymbol{O} + \frac{\boldsymbol{O}^{\top}\boldsymbol{O}}{\sqrt{1+\eta^2}}\right)\end{equation}
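To close the loop, here is a small end-to-end NumPy sketch of my own (the function name `orthogonal_muon_step`, the step size, and the test matrices are all arbitrary choices of mine) that performs one update with the rule above and verifies that the result is, numerically, an exact orthogonal matrix:

```python
import numpy as np

def msign(G):
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def orthogonal_muon_step(W, G, eta=0.1):
    """One steepest-descent step on the orthogonal manifold (n = m case)."""
    skew = (W.T @ G - G.T @ W) / 2        # [W^T G]_skew
    O = msign(skew)                        # O = msign([W^T G]_skew), skew-symmetric
    OtO = O.T @ O                          # orthogonal projection: (O^T O)^2 = O^T O
    n = W.shape[0]
    retraction = np.eye(n) - OtO + OtO / np.sqrt(1 + eta**2)
    return W @ (np.eye(n) - eta * O) @ retraction

rng = np.random.default_rng(2)
n = 6
W, _ = np.linalg.qr(rng.standard_normal((n, n)))     # start from an orthogonal W
G = rng.standard_normal((n, n))                       # stand-in gradient

W_new = orthogonal_muon_step(W, G, eta=0.1)
print(np.allclose(W_new.T @ W_new, np.eye(n)))        # stays exactly orthogonal
```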

Summary

In this article, we revisited the conclusion that adding a spectral norm constraint to matrix parameter updates yields the Muon optimizer. We then explored the form of the Muon optimizer when an orthogonality constraint is added. If you want your parameters to always remain as orthogonal matrices during updates, this article may be of some reference value.