An Appreciation of the Muon Optimizer: A Fundamental Leap from Vectors to Matrices

By 苏剑林 | December 10, 2024

With the arrival of the LLM era, research enthusiasm for optimizers in academia seems to have waned. This is primarily because the current mainstream AdamW is already sufficient to meet most needs, and making major changes to an optimizer involves massive verification costs. Therefore, current developments in optimizers are mostly "minor tweaks" to AdamW made by the industry based on their own training experiences.

However, an optimizer named "Muon" has recently gained considerable attention on Twitter. It claims to be more efficient than AdamW and is not just a "minor adjustment" on top of Adam; rather, it reflects some profound principles regarding the differences between vectors and matrices. In this article, let us appreciate it together.

[Figure: Comparison of Muon and AdamW performance (source: Twitter @Yuchenj_UW)]

Initial Exploration of the Algorithm

Muon stands for "MomentUm Orthogonalized by Newton-Schulz." It is applicable to matrix parameters $\boldsymbol{W} \in \mathbb{R}^{n \times m}$, and its update rule is:

\begin{equation} \begin{aligned} \boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt] \boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t [\text{msign}(\boldsymbol{M}_t) + \lambda \boldsymbol{W}_{t-1}] \\ \end{aligned} \end{equation}

Here, $\text{msign}$ is the matrix sign function. It is not simply a component-wise $\text{sign}$ operation on the matrix, but a matrix generalization of the $\text{sign}$ function. Its relationship with SVD is:

\begin{equation}\boldsymbol{U},\boldsymbol{\Sigma},\boldsymbol{V}^{\top} = \text{SVD}(\boldsymbol{M}) \quad\Rightarrow\quad \text{msign}(\boldsymbol{M}) = \boldsymbol{U}_{[:,:r]}\boldsymbol{V}_{[:,:r]}^{\top}\end{equation}

where $\boldsymbol{U} \in \mathbb{R}^{n \times n}, \boldsymbol{\Sigma} \in \mathbb{R}^{n \times m}, \boldsymbol{V} \in \mathbb{R}^{m \times m}$, and $r$ is the rank of $\boldsymbol{M}$. We will dive into more theoretical details later, but for now, let’s try to intuitively perceive the following fact:

Muon is an adaptive learning rate optimizer similar to Adam.

The characteristic of adaptive learning rate optimizers like Adagrad, RMSprop, and Adam is that they adjust the update amount of each parameter by dividing by the square root of the moving average of the gradient squared. This achieves two effects: 1. Constant scaling of the loss function does not affect the optimization trajectory; 2. The update magnitude of each parameter component is as consistent as possible. Muon satisfies exactly these two properties:

1. If the loss function is multiplied by a positive constant $c$, then $\boldsymbol{M}$ is also multiplied by $c$, and consequently so is $\boldsymbol{\Sigma}$; but since Muon's final update replaces $\boldsymbol{\Sigma}$ with the identity matrix, the optimization trajectory is unaffected;

2. When $\boldsymbol{M}$ is decomposed via SVD as $\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$, the different singular values in $\boldsymbol{\Sigma}$ reflect the "anisotropy" of $\boldsymbol{M}$. Setting them all to one makes the update more "isotropic," which serves to synchronize the update magnitudes.

By the way, regarding the second point, does it remind any readers of BERT-whitening? It should also be noted that Muon has a Nesterov version, which simply replaces $\text{msign}(\boldsymbol{M}_t)$ with $\text{msign}(\beta\boldsymbol{M}_t + \boldsymbol{G}_t)$ in the update rule, with everything else remaining consistent. For simplicity, we won't expand on that here.
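To make the above concrete, here is a minimal JAX sketch (in the same style as the reference code later in this post) of $\text{msign}$ computed exactly via SVD, together with one step of the update rule. The function names, default hyperparameters, and the exact-SVD route are my own illustrative choices, not the official implementation (which uses the Newton-Schulz iteration discussed below).

import jax.numpy as jnp

def msign(M, eps=1e-7):
    # msign(M) = U_{[:,:r]} V_{[:,:r]}^T, computed exactly from the thin SVD.
    U, S, Vh = jnp.linalg.svd(M, full_matrices=False)
    mask = (S > eps * S[0]).astype(M.dtype)   # keep only the top-r singular directions
    return (U * mask) @ Vh

def muon_step(W, G, M, lr=1e-3, beta=0.95, wd=0.0):
    # M_t = beta * M_{t-1} + G_t
    M = beta * M + G
    # W_t = W_{t-1} - lr * (msign(M_t) + wd * W_{t-1})
    W = W - lr * (msign(M) + wd * W)
    return W, M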

(Archival note: It was later discovered that the 2015 paper "Stochastic Spectral Descent for Restricted Boltzmann Machines" already proposed an optimization algorithm nearly identical to Muon, then called "Stochastic Spectral Descent.")

The Sign Function

Using SVD, we can also prove the identity:

\begin{equation}\text{msign}(\boldsymbol{M}) = (\boldsymbol{M}\boldsymbol{M}^{\top})^{-1/2}\boldsymbol{M}= \boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}\label{eq:msign-id}\end{equation}

where ${}^{-1/2}$ is the inverse square root of the matrix; if it is not invertible, the pseudoinverse is taken. This identity helps us better understand why $\text{msign}$ is a matrix generalization of $\text{sign}$: for a scalar $x$, we have $\text{sign}(x)=x(x^2)^{-1/2}$, which is a special case of the above equation (when $\boldsymbol{M}$ is a $1 \times 1$ matrix). This special case can also be generalized to a diagonal matrix $\boldsymbol{M}=\text{diag}(\boldsymbol{m})$:

\begin{equation}\text{msign}(\boldsymbol{M}) = \text{diag}(\boldsymbol{m})[\text{diag}(\boldsymbol{m})^2]^{-1/2} = \text{diag}(\text{sign}(\boldsymbol{m}))=\text{sign}(\boldsymbol{M})\end{equation}

where $\text{sign}(\boldsymbol{m})$ and $\text{sign}(\boldsymbol{M})$ refer to taking the component-wise $\text{sign}$ of the vector/matrix. This implies that when $\boldsymbol{M}$ is a diagonal matrix, Muon degenerates into a momentum-based SignSGD (Signum) or the Tiger optimizer I proposed, both of which are classic approximations of Adam. Conversely, the difference between Muon and Signum/Tiger is that the element-wise $\text{sign}(\boldsymbol{M})$ is replaced by the matrix version $\text{msign}(\boldsymbol{M})$.

For an $n$-dimensional vector, we can also view it as an $n \times 1$ matrix, in which case $\text{msign}(\boldsymbol{m}) = \boldsymbol{m}/\Vert\boldsymbol{m}\Vert_2$ is exactly $l_2$ normalization. Thus, within the Muon framework, we have two perspectives for vectors: one is as a diagonal matrix (e.g., the gamma parameter in LayerNorm), resulting in taking the $\text{sign}$ of momentum; the other is as an $n \times 1$ matrix, resulting in $l_2$ normalization of momentum. Furthermore, although input and output Embeddings are matrices, they are used sparsely, so the more reasonable approach is to treat them as multiple independent vectors.
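Both vector perspectives are easy to check numerically, reusing the msign helper sketched above (exact SVD; the examples below are full rank with probability 1):

import jax
import jax.numpy as jnp

m = jax.random.normal(jax.random.key(0), (8,))

# Viewed as a diagonal matrix: msign(diag(m)) = diag(sign(m)).
print(jnp.allclose(msign(jnp.diag(m)), jnp.diag(jnp.sign(m)), atol=1e-5))

# Viewed as an n-by-1 matrix: msign(m) = m / ||m||_2.
col = m[:, None]
print(jnp.allclose(msign(col), col / jnp.linalg.norm(col), atol=1e-5))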

When $m=n=r$, $\text{msign}(\boldsymbol{M})$ also has the meaning of the "optimal orthogonal approximation":

\begin{equation}\text{msign}(\boldsymbol{M}) = \mathop{\text{argmin}}_{\boldsymbol{O}^{\top}\boldsymbol{O} = \boldsymbol{I}}\Vert \boldsymbol{M} - \boldsymbol{O}\Vert_F^2 \label{eq:nearest-orth}\end{equation}

Similarly, for $\text{sign}(\boldsymbol{M})$ we can write (assuming $\boldsymbol{M}$ has no zero elements):

\begin{equation}\text{sign}(\boldsymbol{M}) = \mathop{\text{argmin}}_{\boldsymbol{O}\in\{-1,1\}^{n\times m}}\Vert \boldsymbol{M} - \boldsymbol{O}\Vert_F^{2}\end{equation}

Whether $\boldsymbol{O}^{\top}\boldsymbol{O} = \boldsymbol{I}$ or $\boldsymbol{O}\in\{-1,1\}^{n\times m}$, we can view this as a regularization constraint on the update amount. Thus, Muon, Signum, and Tiger can be seen as optimizers following the same logic, building update amounts starting from momentum $\boldsymbol{M}$, just choosing different regularization methods for the update amount.

Proof of Equation $\eqref{eq:nearest-orth}$: For an orthogonal matrix $\boldsymbol{O}$, we have

\begin{equation} \begin{aligned} \Vert \boldsymbol{M} - \boldsymbol{O}\Vert_F^2 =&\, \Vert \boldsymbol{M}\Vert_F^2 + \Vert \boldsymbol{O}\Vert_F^2 - 2\langle\boldsymbol{M},\boldsymbol{O}\rangle_F \\[5pt] =&\, \Vert \boldsymbol{M}\Vert_F^2 + n - 2\text{Tr}(\boldsymbol{M}\boldsymbol{O}^{\top})\\[5pt] =&\, \Vert \boldsymbol{M}\Vert_F^2 + n - 2\text{Tr}(\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}\boldsymbol{O}^{\top})\\[5pt] =&\, \Vert \boldsymbol{M}\Vert_F^2 + n - 2\text{Tr}(\boldsymbol{\Sigma}\boldsymbol{V}^{\top}\boldsymbol{O}^{\top}\boldsymbol{U})\\ =&\, \Vert \boldsymbol{M}\Vert_F^2 + n - 2\sum_{i=1}^n \boldsymbol{\Sigma}_{i,i}(\boldsymbol{V}^{\top}\boldsymbol{O}^{\top}\boldsymbol{U})_{i,i} \end{aligned} \end{equation}

The algebraic manipulations involved here were introduced in the pseudoinverse post. Since $\boldsymbol{U},\boldsymbol{V},\boldsymbol{O}$ are all orthogonal matrices, $\boldsymbol{V}^{\top}\boldsymbol{O}^{\top}\boldsymbol{U}$ is also orthogonal, and no entry of an orthogonal matrix can exceed 1 in absolute value. Since each $\boldsymbol{\Sigma}_{i,i} > 0$, minimizing the expression above amounts to maximizing each $(\boldsymbol{V}^{\top}\boldsymbol{O}^{\top}\boldsymbol{U})_{i,i}$, i.e., setting $(\boldsymbol{V}^{\top}\boldsymbol{O}^{\top}\boldsymbol{U})_{i,i}=1$, which forces $\boldsymbol{V}^{\top}\boldsymbol{O}^{\top}\boldsymbol{U}=\boldsymbol{I}$, i.e., $\boldsymbol{O}=\boldsymbol{U}\boldsymbol{V}^{\top}$.

This conclusion can be carefully generalized to cases where $m, n, r$ are not equal, but we won't expand further here.

Iterative Solution

In practice, performing an SVD at every step to solve for $\text{msign}(\boldsymbol{M})$ would be computationally expensive. Therefore, the authors proposed using Newton-Schulz iteration to approximate $\text{msign}(\boldsymbol{M})$.

The starting point for the iteration is identity $\eqref{eq:msign-id}$. Without loss of generality, we assume $n \geq m$ and consider the Taylor expansion of $(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}$ at $\boldsymbol{M}^{\top}\boldsymbol{M}=\boldsymbol{I}$. The expansion method directly applies the result of the scalar function $t^{-1/2}$ to the matrix:

\begin{equation}t^{-1/2} = 1 - \frac{1}{2}(t-1) + \frac{3}{8}(t-1)^2 - \frac{5}{16}(t-1)^3 + \cdots\end{equation}

Retaining up to the second order, the result is $(15 - 10t + 3t^2)/8$, so we have

\begin{equation}\text{msign}(\boldsymbol{M}) = \boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}\approx \frac{15}{8}\boldsymbol{M} - \frac{5}{4}\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M}) + \frac{3}{8}\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^2\end{equation}

If $\boldsymbol{X}_t$ is some approximation of $\text{msign}(\boldsymbol{M})$, we assume that substituting it into the above expression yields a better approximation, thus obtaining a usable iterative format:

\begin{equation}\boldsymbol{X}_{t+1} = \frac{15}{8}\boldsymbol{X}_t - \frac{5}{4}\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t) + \frac{3}{8}\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t)^2\end{equation}

However, checking Muon's official code, we find that while the Newton-Schulz iteration follows this form, the three coefficients are actually $(3.4445, -4.7750, 2.0315)$. The author did not provide a mathematical derivation, only a vague comment:

[Figure: the Newton-Schulz iteration in the official Muon code, together with the author's comment]
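For concreteness, here is a JAX sketch of this quintic Newton-Schulz iteration, parameterized by the three coefficients; the default uses the official values quoted above, but the function name, transpose handling, and normalization constant are my own choices rather than a copy of the official code:

import jax.numpy as jnp

def newton_schulz_msign(M, steps=5, coefs=(3.4445, -4.7750, 2.0315)):
    a, b, c = coefs
    X = M / (jnp.linalg.norm(M) + 1e-7)   # Frobenius normalization: singular values fall into [0, 1]
    transposed = X.shape[0] < X.shape[1]  # work with n >= m, as assumed in the derivation
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X.T @ X                       # shares U, V with X; only the singular values change
        B = b * A + c * A @ A
        X = a * X + X @ B                 # X <- a X + b X(X^T X) + c X(X^T X)^2
    return X.T if transposed else X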

Convergence Acceleration

To guess the origin of the official iteration algorithm, we consider a general iterative process:

\begin{equation}\boldsymbol{X}_{t+1} = a\boldsymbol{X}_t + b\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t) + c\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t)^2\label{eq:iteration}\end{equation}

where $a, b, c$ are three coefficients to be solved. If a higher-order iterative algorithm is desired, we could successively add terms like $\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t)^3$ or $\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t)^4$. The following analysis is general.

We choose the initial value $\boldsymbol{X}_0=\boldsymbol{M}/\Vert\boldsymbol{M}\Vert_F$, where $\Vert\cdot\Vert_F$ is the Frobenius norm. The rationale for this choice is that dividing by $\Vert\boldsymbol{M}\Vert_F$ does not change the $\boldsymbol{U}$ and $\boldsymbol{V}$ of the SVD, but ensures that all singular values of $\boldsymbol{X}_0$ lie in $[0, 1]$, making the initial singular values more "standard." Now assume $\boldsymbol{X}_t$ has the SVD $\boldsymbol{U}\boldsymbol{\Sigma}_t\boldsymbol{V}^{\top}$. Substituting into the above equation, we get:

\begin{equation}\boldsymbol{X}_{t+1} = \boldsymbol{U}_{[:,:r]}(a \boldsymbol{\Sigma}_{t,[:r,:r]} + b \boldsymbol{\Sigma}_{t,[:r,:r]}^3 + c \boldsymbol{\Sigma}_{t,[:r,:r]}^5)\boldsymbol{V}_{[:,:r]}^{\top}\end{equation}

Therefore, Equation $\eqref{eq:iteration}$ is essentially an iteration on the diagonal matrix of singular values $\boldsymbol{\Sigma}_{[:r,:r]}$. If we write $\boldsymbol{X}_t=\boldsymbol{U}_{[:,:r]}\boldsymbol{\Sigma}_{t,[:r,:r]}\boldsymbol{V}_{[:,:r]}^{\top}$, then $\boldsymbol{\Sigma}_{t+1,[:r,:r]} = g(\boldsymbol{\Sigma}_{t,[:r,:r]})$, where $g(x) = ax + bx^3 + cx^5$. Since a power of a diagonal matrix just raises its diagonal entries to that power, the problem reduces to the iteration of a single singular value $\sigma$. Our target is $\boldsymbol{U}_{[:,:r]}\boldsymbol{V}_{[:,:r]}^{\top}$; in other words, we want the iteration to drive every singular value in $\boldsymbol{\Sigma}_{[:r,:r]}$ to 1, which reduces to asking that the scalar iteration $\sigma_{t+1} = g(\sigma_t)$ converge to 1.
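As a quick sanity check on a single singular value, we can iterate $g$ with the Taylor coefficients $(15/8, -5/4, 3/8)$ derived in the previous section; the starting values below are arbitrary:

import jax.numpy as jnp

a, b, c = 15/8, -5/4, 3/8
g = lambda x: a * x + b * x**3 + c * x**5

x = jnp.array([0.05, 0.3, 0.7, 0.99])   # initial singular values in (0, 1]
for _ in range(10):
    x = g(x)
print(x)   # every entry approaches 1; smaller starting values need more iterations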

Inspired by @leloykun, we treat the choice of $a, b, c$ as an optimization problem, aiming to make the iteration converge as quickly as possible for any initial singular value. First, we reparameterize $g(x)$ as:

\begin{equation}g(x) = x + \kappa x(x^2 - x_1^2)(x^2 - x_2^2)\end{equation}

where $x_1 \leq x_2$. This parameterization makes the five fixed points of the iteration explicit: $0, \pm x_1, \pm x_2$. Since our target is to converge to 1, we initialize with $x_1 < 1$ and $x_2 > 1$; the idea is that whether the iteration drifts toward $x_1$ or toward $x_2$, the result stays near 1.

Next, we fix the number of iterations $T$, making the iterative process a deterministic function. By specifying the matrix shape (i.e., $n, m$), we can sample a batch of matrices and calculate their singular values via SVD. Finally, we treat these singular values as inputs and the target output as 1, with the loss function being the mean squared error. This entire model is differentiable and can be solved using gradient descent (@leloykun assumed $x_1 + x_2 = 2$ and used grid search).

Some calculation results:

| n | m | T | $\kappa$ | $x_1$ | $x_2$ | a | b | c | mse | $\text{mse}_o$ |
|---|---|---|---|---|---|---|---|---|---|---|
| 1024 | 1024 | 3 | 7.020 | 0.830 | 0.830 | 4.328 | -9.666 | 7.020 | 0.10257 | 0.18278 |
| 1024 | 1024 | 5 | 1.724 | 0.935 | 1.235 | 3.297 | -4.136 | 1.724 | 0.02733 | 0.04431 |
| 2048 | 1024 | 3 | 7.028 | 0.815 | 0.815 | 4.095 | -9.327 | 7.028 | 0.01628 | 0.06171 |
| 2048 | 1024 | 5 | 1.476 | 0.983 | 1.074 | 2.644 | -3.128 | 1.476 | 0.00038 | 0.02954 |
| 4096 | 1024 | 3 | 6.948 | 0.802 | 0.804 | 3.886 | -8.956 | 6.948 | 0.00371 | 0.02574 |
| 4096 | 1024 | 5 | 1.214 | 1.047 | 1.048 | 2.461 | -2.663 | 1.214 | 0.00008 | 0.02563 |
| 2048 | 2048 | 3 | 11.130 | 0.767 | 0.767 | 4.857 | -13.103 | 11.130 | 0.10739 | 0.24410 |
| 2048 | 2048 | 5 | 1.779 | 0.921 | 1.243 | 3.333 | -4.259 | 1.779 | 0.03516 | 0.04491 |
| 4096 | 4096 | 3 | 18.017 | 0.705 | 0.705 | 5.460 | -17.929 | 18.017 | 0.11303 | 0.33404 |
| 4096 | 4096 | 5 | 2.057 | 0.894 | 1.201 | 3.373 | -4.613 | 2.057 | 0.04700 | 0.06372 |
| 8192 | 8192 | 3 | 30.147 | 0.643 | 0.643 | 6.139 | -24.893 | 30.147 | 0.11944 | 0.44843 |
| 8192 | 8192 | 5 | 2.310 | 0.871 | 1.168 | 3.389 | -4.902 | 2.310 | 0.05869 | 0.07606 |

Here $\text{mse}_o$ is the loss obtained with the $a, b, c$ used by the Muon author. As the table shows, the results depend noticeably on the matrix shape and the number of iterations; judging from the loss values, non-square matrices converge more easily than square ones. The $a, b, c$ chosen by the Muon author appear to be roughly the optimal solution for a square matrix with $T=5$ iterations. For a fixed number of iterations, the result depends on the matrix size, which ultimately comes down to the distribution of the singular values; a noteworthy result about this distribution as $n, m \to \infty$ is the Marchenko–Pastur distribution.

Reference code:

import jax
import jax.numpy as jnp
from tqdm import tqdm

n, m, T = 1024, 1024, 5

# Sample singular values of random Gaussian matrices, normalized by the
# Frobenius norm (the same normalization used for the initial value X_0).
key, data = jax.random.key(42), jnp.array([])
for _ in tqdm(range(1000), ncols=0, desc='SVD'):
    key, subkey = jax.random.split(key)
    M = jax.random.normal(subkey, shape=(n, m))
    S = jnp.linalg.svd(M, full_matrices=False)[1]
    data = jnp.concatenate([data, S / (S**2).sum()**0.5])

@jax.jit
def f(w, x):
    # Run T steps of g(x) = x + k*x*(x^2 - x1^2)*(x^2 - x2^2) and measure
    # how far the resulting singular values are from 1 (mean squared error).
    k, x1, x2 = w
    for _ in range(T):
        x = x + k * x * (x**2 - x1**2) * (x**2 - x2**2)
    return ((x - 1)**2).mean()

f_grad = jax.grad(f)

# Minimize the MSE over (k, x1, x2) with momentum SGD.
w, u = jnp.array([1, 0.9, 1.1]), jnp.zeros(3)
for _ in tqdm(range(100000), ncols=0, desc='SGD'):
    u = 0.9 * u + f_grad(w, data)  # Momentum acceleration
    w = w - 0.01 * u

# Recover the coefficients (a, b, c) of g(x) = a*x + b*x^3 + c*x^5 from (k, x1, x2).
k, x1, x2 = w
a, b, c = 1 + k * x1**2 * x2**2, -k * (x1**2 + x2**2), k
print(f'{n} & {m} & {T} & {k:.3f} & {x1:.3f} & {x2:.3f} & {a:.3f} & {b:.3f} & {c:.3f} & {f(w, data):.5f}')

Some Reflections

If we follow the default choice of $T=5$, then for an $n \times n$ matrix parameter, each Muon update step requires at least 15 multiplications of $n \times n$ matrices (three per Newton-Schulz iteration). This computational overhead is undoubtedly much larger than Adam's, which has led some readers to worry about whether Muon is practical.

In fact, this concern is unnecessary. Although Muon's update is more expensive to compute than Adam's, the additional wall-clock time per step is small: in my own tests it stays within 5%, and the Muon author claims it can be as low as 2%. The reason is that Muon's matrix multiplications happen after the current gradient computation has finished and before the next one starts, a window in which almost all of the compute would otherwise sit idle; these multiplications also have static shapes and parallelize well, so they do not significantly increase the step time. Moreover, Muon keeps one fewer set of cached state variables than Adam, so its memory cost is lower.

The most thought-provoking aspect of Muon is actually the intrinsic difference between vectors and matrices and its impact on optimization. Common optimizers like SGD, Adam, and Tiger have element-wise update rules; whether parameters are vectors or matrices, they are treated as one large vector, with components updated independently following the same rules. Optimizers with this property are often easier to analyze theoretically and facilitate tensor parallelism, as splitting a large matrix into two small matrices to be handled independently doesn't change the optimization trajectory.

But Muon is different: it takes the matrix as the basic unit and exploits properties that are unique to matrices. Some readers might find this strange: aren't matrices and vectors both just arrangements of numbers? What difference could there be? Take the "trace" as an example: it is the sum of the diagonal elements. This concept was not chosen arbitrarily; it is invariant under similarity transformations and equals the sum of all eigenvalues. This already shows that the diagonal and off-diagonal entries of a matrix do not have perfectly equal status. Muon's superior performance stems precisely from taking this unequal status into account.

Of course, this also leads to some negative impacts. If a matrix is partitioned across different devices, then using Muon would require gathering their gradients before calculating the update amount, rather than having each device update independently, which increases communication costs. Even ignoring parallelism, this problem exists. For example, Multi-Head Attention is generally projected into $Q$ (and $K, V$) via a single large matrix, then reshaped into multiple heads. While there is only one matrix in the model parameters, it is essentially multiple small matrices; thus, theoretically, we should split the large matrix into multiple small matrices to update them independently.

In short, Muon's non-element-wise update rule effectively captures the fundamental differences between vectors and matrices but also introduces some minor issues, which might not suit the aesthetic tastes of some readers.

(Supplement: Almost concurrently with the publication of this blog, Keller Jordan, the author of Muon, also released his own blog post: "Muon: An optimizer for hidden layers in neural networks".)

The Norm Perspective

Theoretically, what key characteristic of matrices does Muon capture? Perhaps the norm perspective can provide the answer.

The discussion in this section is primarily based on the papers "Stochastic Spectral Descent for Discrete Graphical Models" and "Old Optimizer, New Norm: An Anthology," particularly the latter. The starting point is not new, however; we briefly touched upon it in "Gradient Flow: Exploring the Path to the Minimum": for a vector parameter $\boldsymbol{w} \in \mathbb{R}^n$, we define the update from $\boldsymbol{w}_t$ to $\boldsymbol{w}_{t+1}$ as:

\begin{equation}\boldsymbol{w}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{w}} \frac{\Vert\boldsymbol{w} - \boldsymbol{w}_t\Vert^2}{2\eta_t} + \mathcal{L}(\boldsymbol{w})\end{equation}

where $\Vert\cdot\Vert$ is some vector norm; this is known as "steepest gradient descent" under a specific norm constraint. Assuming $\eta_t$ is sufficiently small, the first term dominates, meaning $\boldsymbol{w}_{t+1}$ will be very close to $\boldsymbol{w}_t$. If we assume a first-order approximation of $\mathcal{L}(\boldsymbol{w})$ is sufficient, the problem simplifies to:

\begin{equation}\boldsymbol{w}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{w}} \frac{\Vert\boldsymbol{w} - \boldsymbol{w}_t\Vert^2}{2\eta_t} + \mathcal{L}(\boldsymbol{w}_t) + \nabla_{\boldsymbol{w}_t}\mathcal{L}(\boldsymbol{w}_t)^{\top}(\boldsymbol{w}-\boldsymbol{w}_t)\end{equation}

Let $\Delta\boldsymbol{w}_{t+1} = \boldsymbol{w}_{t+1}-\boldsymbol{w}_t$ and $\boldsymbol{g}_t = \nabla_{\boldsymbol{w}_t}\mathcal{L}(\boldsymbol{w}_t)$. This can be rewritten as:

\begin{equation}\Delta\boldsymbol{w}_{t+1} = \mathop{\text{argmin}}_{\Delta\boldsymbol{w}} \frac{\Vert\Delta\boldsymbol{w}\Vert^2}{2\eta_t} + \boldsymbol{g}_t^{\top}\Delta\boldsymbol{w}\end{equation}

A general way to calculate $\Delta\boldsymbol{w}_{t+1}$ is through differentiation, but "Old Optimizer, New Norm: An Anthology" provides a unified solution without differentiation: decompose $\Delta\boldsymbol{w}$ into norm $\gamma = \Vert\Delta\boldsymbol{w}\Vert$ and a direction vector $\boldsymbol{\varphi} = -\Delta\boldsymbol{w}/\Vert\Delta\boldsymbol{w}\Vert$. Thus:

\begin{equation}\min_{\Delta\boldsymbol{w}} \frac{\Vert\Delta\boldsymbol{w}\Vert^2}{2\eta_t} + \boldsymbol{g}_t^{\top}\Delta\boldsymbol{w} = \min_{\gamma\geq 0, \Vert\boldsymbol{\varphi}\Vert=1} \frac{\gamma^2}{2\eta_t} - \gamma\boldsymbol{g}_t^{\top}\boldsymbol{\varphi} = \min_{\gamma\geq 0} \frac{\gamma^2}{2\eta_t} - \gamma\bigg(\underbrace{\max_{\Vert\boldsymbol{\varphi}\Vert=1}\boldsymbol{g}_t^{\top}\boldsymbol{\varphi}}_{\text{denoted as }\Vert \boldsymbol{g}_t\Vert^{\dagger}}\bigg)\end{equation}

Since $\gamma$ is just a scalar similar to the learning rate, it's easy to find the optimal value as $\eta_t\Vert \boldsymbol{g}_t\Vert^{\dagger}$, while the update direction $\boldsymbol{\varphi}^*$ maximizes $\boldsymbol{g}_t^{\top}\boldsymbol{\varphi}$ subject to $\Vert\boldsymbol{\varphi}\Vert=1$. Substituting the Euclidean norm $\Vert\boldsymbol{\varphi}\Vert_2 = \sqrt{\boldsymbol{\varphi}^{\top}\boldsymbol{\varphi}}$, we have $\Vert \boldsymbol{g}_t\Vert^{\dagger}=\Vert \boldsymbol{g}_t\Vert_2$ and $\boldsymbol{\varphi}^* = \boldsymbol{g}_t/\Vert \boldsymbol{g}_t\Vert_2$. In this case, $\Delta\boldsymbol{w}_{t+1}=-\eta_t \boldsymbol{g}_t$, which is Gradient Descent (SGD). Generally, for a $p$-norm:

\begin{equation}\Vert\boldsymbol{\varphi}\Vert_p = \sqrt[p]{\sum_{i=1}^n |\varphi_i|^p}\end{equation}

Hölder's inequality gives $\boldsymbol{g}^{\top}\boldsymbol{\varphi} \leq \Vert \boldsymbol{g}\Vert_q \Vert \boldsymbol{\varphi}\Vert_p$, where $1/p + 1/q = 1$. Utilizing this, we obtain:

\begin{equation}\max_{\Vert\boldsymbol{\varphi}\Vert_p=1}\boldsymbol{g}^{\top}\boldsymbol{\varphi} = \Vert \boldsymbol{g}\Vert_q\end{equation}

The equality holds when:

\begin{equation}\boldsymbol{\varphi}^* = \frac{1}{\Vert\boldsymbol{g}\Vert_q^{q/p}}\Big[\text{sign}(g_1) |g_1|^{q/p},\text{sign}(g_2) |g_2|^{q/p},\cdots,\text{sign}(g_n) |g_n|^{q/p}\Big]\end{equation}

The optimizer using this direction is called pbSGD, see "pbSGD: Powered Stochastic Gradient Descent Methods for Accelerated Non-Convex Optimization." Notably, as $p \to \infty$, we have $q \to 1$ and $|g_i|^{q/p} \to 1$, which degenerates into SignSGD. This means SignSGD is actually steepest gradient descent under the $\Vert\cdot\Vert_{\infty}$ norm.
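For illustration, here is a small sketch of this steepest-descent direction under the $p$-norm; the function name is mine, and it assumes $p > 1$:

import jax.numpy as jnp

def steepest_direction_p(g, p):
    # phi*_i = sign(g_i) |g_i|^{q/p} / ||g||_q^{q/p},  with 1/p + 1/q = 1
    q = p / (p - 1.0)
    num = jnp.sign(g) * jnp.abs(g) ** (q / p)
    return num / (jnp.sum(jnp.abs(g) ** q) ** (1.0 / q)) ** (q / p)

g = jnp.array([0.5, -2.0, 1.0])
print(steepest_direction_p(g, 2.0))      # g / ||g||_2  (the SGD direction)
print(steepest_direction_p(g, 100.0))    # close to sign(g)  (the SignSGD limit)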

Matrix Norms

Now let's switch our focus to matrix parameters $\boldsymbol{W} \in \mathbb{R}^{n \times m}$. Similarly, we define its update rule as:

\begin{equation}\boldsymbol{W}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{W}} \frac{\Vert\boldsymbol{W} - \boldsymbol{W}_t\Vert^2}{2\eta_t} + \mathcal{L}(\boldsymbol{W})\end{equation}

where $\Vert\cdot\Vert$ is some matrix norm. Using a first-order approximation again, we get:

\begin{equation}\Delta\boldsymbol{W}_{t+1} = \mathop{\text{argmin}}_{\Delta\boldsymbol{W}} \frac{\Vert\Delta\boldsymbol{W}\Vert^2}{2\eta_t} + \text{Tr}(\boldsymbol{G}_t^{\top}\Delta\boldsymbol{W})\end{equation}

where $\Delta\boldsymbol{W}_{t+1} = \boldsymbol{W}_{t+1}-\boldsymbol{W}_t$ and $\boldsymbol{G}_t = \nabla_{\boldsymbol{W}_t}\mathcal{L}(\boldsymbol{W}_t)$. Again using the "norm-direction" decoupling, with $\gamma = \Vert\Delta\boldsymbol{W}\Vert$ and $\boldsymbol{\Phi} = -\Delta\boldsymbol{W}/\Vert\Delta\boldsymbol{W}\Vert$, we get:

\begin{equation}\min_{\Delta\boldsymbol{W}} \frac{\Vert\Delta\boldsymbol{W}\Vert^2}{2\eta_t} + \text{Tr}(\boldsymbol{G}_t^{\top}\Delta\boldsymbol{W}) = \min_{\gamma\geq 0} \frac{\gamma^2}{2\eta_t} - \gamma\bigg(\underbrace{\max_{\Vert\boldsymbol{\Phi}\Vert=1}\text{Tr}(\boldsymbol{G}_t^{\top}\boldsymbol{\Phi})}_{\text{denoted as }\Vert \boldsymbol{G}_t\Vert^{\dagger}}\bigg)\end{equation}

Then we analyze specific norms. There are two commonly used matrix norms: one is the Frobenius norm (F-norm), which is essentially the Euclidean norm of the flattened matrix. In this case, the conclusion is the same as for vectors: the answer is SGD. The other is the 2-norm induced by vector norms, also known as the spectral norm:

\begin{equation}\Vert \boldsymbol{\Phi}\Vert_2 = \max_{\Vert \boldsymbol{x}\Vert_2 = 1} \Vert \boldsymbol{\Phi}\boldsymbol{x}\Vert_2\end{equation}

Note that the $\Vert\cdot\Vert_2$ on the right side applies to vectors, so the definition is clear. More discussion on the 2-norm can be found in "Lipschitz Continuity in Deep Learning" and "The Path to Low Rank Approximation (2): SVD." Since the 2-norm is induced by "matrix-vector" multiplication, it more appropriately fits matrix multiplication, and it always holds that $\Vert\boldsymbol{\Phi}\Vert_2\leq \Vert\boldsymbol{\Phi}\Vert_F$, meaning the 2-norm provides a tighter measure than the F-norm.
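Both norms are available directly in jnp.linalg.norm, so the inequality is easy to spot-check (ord=2 gives the spectral norm for a matrix; the default is the Frobenius norm):

import jax
import jax.numpy as jnp

Phi = jax.random.normal(jax.random.key(0), (512, 256))
print(jnp.linalg.norm(Phi, ord=2), jnp.linalg.norm(Phi))   # spectral norm <= Frobenius norm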

So, we proceed with the calculation for the 2-norm. Let the SVD of $\boldsymbol{G}$ be $\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} = \sum_{i=1}^r \sigma_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top}$. We then have:

\begin{equation}\text{Tr}(\boldsymbol{G}^{\top}\boldsymbol{\Phi})=\text{Tr}\Big(\sum_{i=1}^r \sigma_i \boldsymbol{v}_i \boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\Big) = \sum_{i=1}^r \sigma_i \boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\boldsymbol{v}_i\end{equation}

By definition, when $\Vert\boldsymbol{\Phi}\Vert_2=1$, $\Vert\boldsymbol{\Phi}\boldsymbol{v}_i\Vert_2\leq \Vert\boldsymbol{v}_i\Vert_2=1$. Thus $\boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\boldsymbol{v}_i\leq 1$, and therefore:

\begin{equation}\text{Tr}(\boldsymbol{G}^{\top}\boldsymbol{\Phi})\leq \sum_{i=1}^r \sigma_i\end{equation}

The equality is reached when all $\boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\boldsymbol{v}_i$ are equal to 1, in which case:

\begin{equation}\boldsymbol{\Phi} = \sum_{i=1}^r \boldsymbol{u}_i \boldsymbol{v}_i^{\top} = \boldsymbol{U}_{[:,:r]}\boldsymbol{V}_{[:,:r]}^{\top} = \text{msign}(\boldsymbol{G})\end{equation}

Thus, we have proved that steepest descent under the 2-norm penalty is exactly the Muon optimizer with $\beta=0$ (and no weight decay)! When $\beta > 0$, the moving average takes effect; we can regard the momentum as a more accurate estimate of the gradient and therefore apply $\text{msign}$ to it instead. Overall, Muon is equivalent to gradient descent under the 2-norm constraint. The 2-norm better captures the intrinsic nature of matrices (rather than treating them as flattened vectors), allowing each update step to be more precise and more fundamental.
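The equality condition is also easy to spot-check numerically: with $\boldsymbol{\Phi} = \text{msign}(\boldsymbol{G})$ the trace attains $\sum_i \sigma_i$ (the nuclear norm of $\boldsymbol{G}$), while a random matrix rescaled to unit spectral norm does worse. A minimal sketch:

import jax
import jax.numpy as jnp

G = jax.random.normal(jax.random.key(1), (6, 4))
U, S, Vh = jnp.linalg.svd(G, full_matrices=False)

Phi_star = U @ Vh                                  # msign(G)
R = jax.random.normal(jax.random.key(2), G.shape)
Phi_rand = R / jnp.linalg.norm(R, ord=2)           # random matrix rescaled to spectral norm 1

print(jnp.trace(G.T @ Phi_star), S.sum())          # equal: the nuclear norm of G
print(jnp.trace(G.T @ Phi_rand))                   # smaller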

Tracing the Origins

There is a much earlier related work to Muon called "Shampoo: Preconditioned Stochastic Tensor Optimization." This 2018 paper proposed an optimizer named Shampoo, which shares similarities with Muon.

The strategy of adapting learning rates via an average of squared gradients was first proposed in the Adagrad paper "Adaptive Subgradient Methods for Online Learning and Stochastic Optimization," which originally suggested accumulating the squared gradients directly, i.e., a uniform average over the whole history. RMSProp and Adam later replaced this with a momentum-style design, the moving average, which was found to perform better in practice.

Furthermore, Adagrad originally proposed accumulating the outer product $\boldsymbol{g}\boldsymbol{g}^{\top}$. However, because caching the outer product consumes too much spatial cost, the Hadamard product $\boldsymbol{g}\odot\boldsymbol{g}$ was used in practice. What is the theoretical basis for accumulating outer products? As derived in "Adaptive Learning Rate Optimizers from the Perspective of Hessian Approximation," the answer is that the long-term average of the gradient outer product $\mathbb{E}[\boldsymbol{g}\boldsymbol{g}^{\top}]$ approximates the square of the Hessian matrix $\sigma^2\boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}^2$, so this is actually an approximation of the second-order Newton method.

Shampoo inherited the idea of caching outer products from Adagrad but struck a compromise considering cost. Like Muon, it optimizes matrices (and higher-order tensors) specifically. Its strategy is to cache matrix products $\boldsymbol{G}\boldsymbol{G}^{\top}$ and $\boldsymbol{G}^{\top}\boldsymbol{G}$ instead of outer products, bringing spatial cost to $\mathcal{O}(n^2 + m^2)$ rather than $\mathcal{O}(n^2 m^2)$:

\begin{equation} \begin{aligned} \boldsymbol{L}_t =&\, \beta\boldsymbol{L}_{t-1} + \boldsymbol{G}_t\boldsymbol{G}_t^{\top} \\[5pt] \boldsymbol{R}_t =&\, \beta\boldsymbol{R}_{t-1} + \boldsymbol{G}_t^{\top}\boldsymbol{G}_t \\[5pt] \boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t \boldsymbol{L}_t^{-1/4}\boldsymbol{G}_t\boldsymbol{R}_t^{-1/4} \\ \end{aligned} \end{equation}

The $\beta$ here was added by me for illustration; Shampoo defaults $\beta=1$. The ${}^{-1/4}$ exponent refers to matrix power operations, which can be completed via SVD. Since Shampoo did not propose an approximate scheme like Newton-Schulz iteration and directly used SVD, to save computing costs, it does not calculate $\boldsymbol{L}_t^{-1/4}$ and $\boldsymbol{R}_t^{-1/4}$ at every step, but only updates them at fixed step intervals.
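A minimal sketch of one such update for a matrix parameter, with the $\beta$ written in as above: here the $-1/4$ powers are computed exactly via eigendecomposition, the damping $\epsilon$ and the function names are my own additions, and the periodic amortization of the root computation is omitted.

import jax.numpy as jnp

def inv_fourth_root(P, eps=1e-6):
    # (P + eps*I)^{-1/4} for a symmetric PSD matrix P, via eigendecomposition.
    w, Q = jnp.linalg.eigh(P + eps * jnp.eye(P.shape[0]))
    return (Q * w ** -0.25) @ Q.T

def shampoo_step(W, G, L, R, lr=1e-3, beta=1.0):
    L = beta * L + G @ G.T        # L_t = beta * L_{t-1} + G_t G_t^T
    R = beta * R + G.T @ G        # R_t = beta * R_{t-1} + G_t^T G_t
    W = W - lr * inv_fourth_root(L) @ G @ inv_fourth_root(R)
    return W, L, R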

Specifically, when $\beta=0$, the update vector in Shampoo is $(\boldsymbol{G}\boldsymbol{G}^{\top})^{-1/4}\boldsymbol{G}(\boldsymbol{G}^{\top}\boldsymbol{G})^{-1/4}$. By performing SVD on $\boldsymbol{G}$, we can prove that:

\begin{equation}(\boldsymbol{G}\boldsymbol{G}^{\top})^{-1/4}\boldsymbol{G}(\boldsymbol{G}^{\top}\boldsymbol{G})^{-1/4} = (\boldsymbol{G}\boldsymbol{G}^{\top})^{-1/2}\boldsymbol{G}= \boldsymbol{G}(\boldsymbol{G}^{\top}\boldsymbol{G})^{-1/2}=\text{msign}(\boldsymbol{G})\end{equation}

This shows that when $\beta=0$, Shampoo and Muon are theoretically equivalent! Therefore, Shampoo and Muon share a common ground in terms of update amount design.
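This identity is easy to verify numerically; in the sketch below the fractional (pseudo-)powers are assembled directly from the SVD of a full-rank $\boldsymbol{G}$:

import jax
import jax.numpy as jnp

G = jax.random.normal(jax.random.key(3), (5, 3))   # full rank with probability 1
U, S, Vh = jnp.linalg.svd(G, full_matrices=False)

GGt_m14 = (U * S ** -0.5) @ U.T      # (G G^T)^{-1/4} = U diag(S^{-1/2}) U^T
GtG_m14 = (Vh.T * S ** -0.5) @ Vh    # (G^T G)^{-1/4} = V diag(S^{-1/2}) V^T

print(jnp.allclose(GGt_m14 @ G @ GtG_m14, U @ Vh, atol=1e-5))   # equals msign(G) = U V^T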

Summary

This article introduced the Muon optimizer, which has recently become a hot topic on Twitter. It is tailored specifically to matrix parameters and currently appears to be more efficient than AdamW. Moreover, it seems to embody some fundamental differences between treating parameters as vectors and treating them as matrices, making it worth studying and reflecting on.