Efficient Calculation of Matrix Square Root and Inverse Square Root

By 苏剑林 | July 19, 2025

Let \(\boldsymbol{P}\in\mathbb{R}^{n\times n}\) be a square matrix of order \(n\) whose eigenvalues are all non-negative real numbers. In this article, we discuss the calculation of its square root \(\boldsymbol{P}^{1/2}\) and its inverse square root \(\boldsymbol{P}^{-1/2}\).

Basic Concepts

The square root of a matrix \(\boldsymbol{P}\) is a matrix \(\boldsymbol{X}\) satisfying \(\boldsymbol{X}^2=\boldsymbol{P}\). Since even a positive number has two square roots, it is easy to see that matrix square roots are generally not unique. The "arithmetic square root", however, is unique: the arithmetic square root of a positive number is its positive square root, and analogously we define the arithmetic square root of \(\boldsymbol{P}\) as the square root whose eigenvalues are all non-negative. Unless stated otherwise, the square roots computed in this article are arithmetic square roots.
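
For instance, the identity matrix \(\boldsymbol{I}_2\) already has infinitely many square roots, such as

\begin{equation}\begin{bmatrix}1 & 0 \\ 0 & 1\end{bmatrix},\qquad \begin{bmatrix}-1 & 0 \\ 0 & -1\end{bmatrix},\qquad \begin{bmatrix}1 & 0 \\ 0 & -1\end{bmatrix},\qquad \begin{bmatrix}0 & 1 \\ 1 & 0\end{bmatrix}\end{equation}

all of which square to \(\boldsymbol{I}_2\); only the first has all eigenvalues non-negative, so it alone is the arithmetic square root.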

The calculations in this article rely on the matrix sign function \(\text{mcsgn}\) discussed in "What can the matrix sign function mcsgn calculate?":

\begin{equation}\newcommand{mcsgn}{\mathop{\text{mcsgn}}}\mcsgn(\boldsymbol{M}) = (\boldsymbol{M}^2)^{-1/2}\boldsymbol{M}= \boldsymbol{M}(\boldsymbol{M}^2)^{-1/2}\end{equation}

Simply put, it is a new matrix obtained by transforming the eigenvalues of any matrix \(\boldsymbol{M}\in\mathbb{R}^{n\times n}\) into their corresponding sign function values. Assuming the eigenvalues of \(\boldsymbol{M}\) are all real numbers, \(\mcsgn\) can be efficiently calculated via the Newton-Schulz iteration:

\begin{equation}\newcommand{tr}{\mathop{\text{tr}}}\boldsymbol{X}_0 = \frac{\boldsymbol{M}}{\sqrt{\tr(\boldsymbol{M}^2)}},\qquad \boldsymbol{X}_{t+1} = a_{t+1}\boldsymbol{X}_t + b_{t+1}\boldsymbol{X}_t^3 + c_{t+1}\boldsymbol{X}_t^5\end{equation}

where \(\frac{\boldsymbol{M}}{\sqrt{\tr(\boldsymbol{M}^2)}}\) is used to scale all eigenvalues of \(\boldsymbol{X}_0\) into the range \([-1,1]\), and \(a_t, b_t, c_t\) are the coefficients derived in "Newton-Schulz Iteration for the msign Operator (Part II)":

\begin{array}{c|ccc}
\hline
t & a\times 1.01 & b\times 1.01^3 & c\times 1.01^5 \\
\hline
1 & 8.28721 & -23.5959 & 17.3004 \\
2 & 4.10706 & -2.94785 & 0.544843 \\
3 & 3.94869 & -2.9089 & 0.551819 \\
4 & 3.31842 & -2.48849 & 0.510049 \\
5 & 2.30065 & -1.6689 & 0.418807 \\
6 & 1.8913 & -1.268 & 0.376804 \\
7 & 1.875 & -1.25 & 0.375 \\
8 & 1.875 & -1.25 & 0.375 \\
\hline
\end{array}

In fact, when the eigenvalues of \(\boldsymbol{M}\) are all real, the principle behind computing \(\mcsgn\) coincides with that of another matrix sign function, \(\newcommand{msign}{\mathop{\text{msign}}}\msign\).
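
To make the iteration concrete, here is a minimal NumPy sketch of \(\mcsgn\). For simplicity it applies the fixed-point coefficients \((1.875, -1.25, 0.375)\) at every step instead of the tuned schedule from the table, so it needs more steps to converge but reaches the same limit; the function name mcsgn_ns and the step count are illustrative choices.

import numpy as np

def mcsgn_ns(M, steps=25):
    # Scale so that all eigenvalues of X lie in [-1, 1].
    X = M / np.trace(M @ M)**0.5
    for _ in range(steps):
        # One quintic Newton-Schulz step: X <- 1.875 X - 1.25 X^3 + 0.375 X^5
        X2 = X @ X
        X = 1.875 * X - 1.25 * X2 @ X + 0.375 * X2 @ X2 @ X
    return X

A = np.random.randn(50, 50)
M = (A + A.T) / 2  # symmetric, so all eigenvalues are real
S = mcsgn_ns(M)
np.abs(S @ S - np.eye(50)).mean()  # should be close to 0 since mcsgn(M)^2 = I; increase steps if M has eigenvalues near 0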

Computational Principle

The starting point for the following calculation is the identity:

\begin{equation}\mcsgn\left(\begin{bmatrix}\boldsymbol{0} & \boldsymbol{A} \\ \boldsymbol{B} & \boldsymbol{0}\end{bmatrix}\right)=\begin{bmatrix}\boldsymbol{0} & \boldsymbol{A}(\boldsymbol{B}\boldsymbol{A})^{-1/2} \\ \boldsymbol{B}(\boldsymbol{A}\boldsymbol{B})^{-1/2} & \boldsymbol{0}\end{bmatrix}\label{eq:core}\end{equation}

This can be verified by substituting the definition of \(\mcsgn\) directly (Note: \(\boldsymbol{A}, \boldsymbol{B}\) need not be square matrices). Next, we need to determine under what conditions the eigenvalues of the matrix on the left side are all real numbers. Let \(\lambda\) be a non-zero eigenvalue; then:

\begin{equation}0=\det\left(\lambda\boldsymbol{I} - \begin{bmatrix}\boldsymbol{0} & \boldsymbol{A} \\ \boldsymbol{B} & \boldsymbol{0} \end{bmatrix}\right) = \det\left(\begin{bmatrix}\lambda\boldsymbol{I} & -\boldsymbol{A} \\ -\boldsymbol{B} & \lambda\boldsymbol{I} \end{bmatrix}\right) = \det(\lambda^2 \boldsymbol{I} - \boldsymbol{A}\boldsymbol{B})\end{equation}

That is, \(\lambda^2\) is an eigenvalue of the matrix \(\boldsymbol{A}\boldsymbol{B}\) (when \(\boldsymbol{A},\boldsymbol{B}\) are not square, the last determinant above carries an extra non-zero power of \(\lambda\), which does not affect this conclusion). This means that all eigenvalues of the above block matrix are real if and only if all eigenvalues of \(\boldsymbol{A}\boldsymbol{B}\) are non-negative.
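
As an illustrative numerical check of identity \eqref{eq:core}, the sketch below takes a square \(\boldsymbol{A}\) and \(\boldsymbol{B}=\boldsymbol{A}^{\top}\), so that \(\boldsymbol{A}\boldsymbol{B}\) and \(\boldsymbol{B}\boldsymbol{A}\) have non-negative eigenvalues, and compares \(\mcsgn\) of the block matrix, computed directly from the definition \((\boldsymbol{M}^2)^{-1/2}\boldsymbol{M}\) via scipy.linalg.fractional_matrix_power, against the right-hand side (the matrix sizes here are arbitrary choices):

import numpy as np
from scipy.linalg import fractional_matrix_power

d = 50
A = np.random.randn(d, d) / d**0.5
B = A.T  # then AB = A A^T and BA = A^T A have non-negative eigenvalues
O = np.zeros((d, d))

M = np.block([[O, A], [B, O]])
lhs = fractional_matrix_power(M @ M, -0.5) @ M  # mcsgn(M) by its definition
rhs = np.block([
    [O, A @ fractional_matrix_power(B @ A, -0.5)],
    [B @ fractional_matrix_power(A @ B, -0.5), O],
])
np.abs(lhs - rhs).mean()  # should be close to 0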

Iterating directly on the original matrix is possible, but it is computationally wasteful. We can exploit its anti-diagonal structure to reduce complexity. Since:

\begin{equation} \begin{bmatrix}\boldsymbol{0} & \boldsymbol{Y} \\ \boldsymbol{Z} & \boldsymbol{0}\end{bmatrix}^3 = \begin{bmatrix}\boldsymbol{0} & (\boldsymbol{Y}\boldsymbol{Z})\boldsymbol{Y} \\ \boldsymbol{Z}(\boldsymbol{Y}\boldsymbol{Z}) & \boldsymbol{0}\end{bmatrix},\quad \begin{bmatrix}\boldsymbol{0} & \boldsymbol{Y} \\ \boldsymbol{Z} & \boldsymbol{0}\end{bmatrix}^5 = \begin{bmatrix}\boldsymbol{0} & (\boldsymbol{Y}\boldsymbol{Z})^2\boldsymbol{Y} \\ \boldsymbol{Z}(\boldsymbol{Y}\boldsymbol{Z})^2 & \boldsymbol{0}\end{bmatrix} \end{equation}

We can obtain the iteration:

\begin{gather} \boldsymbol{Y}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)\boldsymbol{Y}_t \label{eq:r1} \\[6pt] \boldsymbol{Z}_{t+1} = \boldsymbol{Z}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2) \label{eq:r2} \end{gather}

Then \(\boldsymbol{Y}_t\to \boldsymbol{A}(\boldsymbol{B}\boldsymbol{A})^{-1/2}\) and \(\boldsymbol{Z}_t\to \boldsymbol{B}(\boldsymbol{A}\boldsymbol{B})^{-1/2}\) (with \(\boldsymbol{Y}_0,\boldsymbol{Z}_0\) being the correspondingly normalized \(\boldsymbol{A},\boldsymbol{B}\)). In particular, multiplying the two equations yields a self-contained recursion for \(\boldsymbol{Y}_t\boldsymbol{Z}_t\):

\begin{equation}\boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t\label{eq:r3}\end{equation}
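
Before specializing, here is a minimal sketch of the coupled iteration \eqref{eq:r1}/\eqref{eq:r2} for general \(\boldsymbol{A},\boldsymbol{B}\). For simplicity it again uses the fixed-point coefficients \((1.875,-1.25,0.375)\) rather than the tuned schedule; the function name mcsgn_pair and the step count are illustrative choices.

import numpy as np

def mcsgn_pair(A, B, steps=30):
    # Computes A (BA)^{-1/2} and B (AB)^{-1/2}, assuming the eigenvalues
    # of AB are non-negative (and bounded away from zero).
    s = (2 * np.trace(A @ B))**0.5  # sqrt(tr(M^2)) for the block matrix M
    Y, Z = A / s, B / s
    I = np.eye(A.shape[0])
    for _ in range(steps):
        YZ = Y @ Z
        W = 1.875 * I - 1.25 * YZ + 0.375 * YZ @ YZ  # a I + b YZ + c (YZ)^2
        Y, Z = W @ Y, Z @ W
    return Y, Z

d = 50
A = np.random.randn(d, d) / d**0.5
Y, Z = mcsgn_pair(A, A.T)
np.abs(Y.T @ Y - np.eye(d)).mean()  # with B = A^T, Y is the orthogonal polar factor of A, so Y^T Y ~ I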

Calculating the Square Root

Now we turn to the calculation of the square root. Since the eigenvalues of \(\boldsymbol{P}\) are assumed to be non-negative, we can always compress them into the interval \([0,1]\) by dividing \(\boldsymbol{P}\) by \(\tr(\boldsymbol{P})\). Therefore, without loss of generality, we assume the eigenvalues of \(\boldsymbol{P}\) lie in \([0,1]\), which allows the Newton-Schulz iteration above to be used directly to compute \(\mcsgn\).

Substituting \(\boldsymbol{A}=\boldsymbol{P}, \boldsymbol{B}=\boldsymbol{I}\) into equation \eqref{eq:core}, we get:

\begin{equation}\mcsgn\left(\begin{bmatrix}\boldsymbol{0} & \boldsymbol{P} \\ \boldsymbol{I} & \boldsymbol{0}\end{bmatrix}\right)=\begin{bmatrix}\boldsymbol{0} & \boldsymbol{P}^{1/2} \\ \boldsymbol{P}^{-1/2} & \boldsymbol{0}\end{bmatrix}\end{equation}

This is quite remarkable: in theory, just one \(\mcsgn\) operation allows both the square root and the inverse square root to be calculated simultaneously following iterations \eqref{eq:r1} and \eqref{eq:r2}!

However, in practice this is not always ideal. If \(\boldsymbol{P}\) has eigenvalues very close to 0, then \(\boldsymbol{P}^{-1/2}\) blows up numerically (essentially a \(1/\sqrt{0}\) situation), whereas \(\boldsymbol{P}^{1/2}\) does not. Thus, if we only care about \(\boldsymbol{P}^{1/2}\), computing \(\boldsymbol{P}^{1/2}\) and \(\boldsymbol{P}^{-1/2}\) simultaneously may introduce unnecessary numerical instability. A better approach is to iterate with equations \eqref{eq:r1} and \eqref{eq:r3}, computing only \(\boldsymbol{P}^{1/2}\):

\begin{gather} \boldsymbol{Y}_0 = \boldsymbol{P}, \quad \boldsymbol{Y}_0\boldsymbol{Z}_0 = \boldsymbol{P} \notag\\[6pt] \boldsymbol{Y}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)\boldsymbol{Y}_t \\[6pt] \boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t \\[6pt] \lim_{t\to\infty} \boldsymbol{Y}_t = \boldsymbol{P}^{1/2}\notag \end{gather}

Since the limit of \(\boldsymbol{Z}_t\) is \(\boldsymbol{P}^{-1/2}\), the limit of \(\boldsymbol{Y}_t\boldsymbol{Z}_t\) is \(\boldsymbol{I}\), making the iteration for \(\boldsymbol{Y}_t\boldsymbol{Z}_t\) less prone to numerical risks. Reference code is as follows:

import numpy as np

def abc(steps):
    # Coefficients (a, b, c) of the Newton-Schulz iteration; the stored values
    # are the table entries, i.e. a*1.01, b*1.01^3, c*1.01^5.
    coefs = [
        (8.287212018145622, -23.59588651909882, 17.300387312530923),
        (4.107059111542197, -2.9478499167379084, 0.54484310829266),
        (3.9486908534822938, -2.908902115962947, 0.5518191394370131),
        (3.3184196573706055, -2.488488024314878, 0.5100489401237208),
        (2.3006520199548186, -1.6689039845747518, 0.4188073119525678),
        (1.8913014077874002, -1.2679958271945908, 0.37680408948524996),
        (1.875, -1.25, 0.375)
    ]
    # Beyond 7 steps, keep reusing the last (fixed-point) coefficients.
    for a, b, c in coefs[:steps] + max(steps - 7, 0) * [coefs[-1]]:
        yield a / 1.01, b / 1.01**3, c / 1.01**5

def msqrt(P, steps=6):
    # Iterates Y_t -> (P/t)^{1/2} while Y_t Z_t -> I, as derived above.
    Y = YZ = P / (t := np.trace(P))  # normalize so the eigenvalues lie in [0, 1]
    I = np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W = a * I + b * YZ + c * YZ @ YZ
        Y, YZ = W @ Y, W @ W @ YZ
    return Y * t**0.5  # undo the normalization: (P/t)^{1/2} * sqrt(t) = P^{1/2}

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
np.abs(msqrt(P) @ msqrt(P) - P).mean() # ~= 2e-4

Inverse Square Root

If we must explicitly compute the inverse square root \(\boldsymbol{P}^{-1/2}\), there is no clever workaround: the numerical blow-up will occur wherever it must. In this case, whether we combine equations \eqref{eq:r2} and \eqref{eq:r1} or equations \eqref{eq:r2} and \eqref{eq:r3}, the outcome should be similar, though the latter might be slightly more stable:

\begin{gather} \boldsymbol{Z}_0 = \boldsymbol{I}, \quad \boldsymbol{Y}_0\boldsymbol{Z}_0 = \boldsymbol{P} \notag\\[6pt] \boldsymbol{Z}_{t+1} = \boldsymbol{Z}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)\label{eq:r2-rsqrt} \\[6pt] \boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t\label{eq:r3-rsqrt} \\[6pt] \lim_{t\to\infty} \boldsymbol{Z}_t = \boldsymbol{P}^{-1/2}\notag \end{gather}

Reference code is as follows:

def mrsqrt(P, steps=6):
    # Iterates Z_t -> (P/t)^{-1/2} while Y_t Z_t -> I.
    YZ = P / (t := np.trace(P))
    Z = I = np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W = a * I + b * YZ + c * YZ @ YZ
        Z, YZ = Z @ W, W @ W @ YZ
    return Z / t**0.5  # (P/t)^{-1/2} / sqrt(t) = P^{-1/2}

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
np.abs(mrsqrt(P) @ mrsqrt(P) @ P - np.eye(d)).mean() # ~= 5e-4

Multiplying by a Matrix

However, most of the time \(\boldsymbol{P}^{-1/2}\) is only an intermediate step, usually followed by a multiplication with another matrix. Let that matrix be \(\boldsymbol{G}\in\mathbb{R}^{m\times n}\), so we need to compute \(\boldsymbol{G}\boldsymbol{P}^{-1/2}\). If we can treat \(\boldsymbol{G}\boldsymbol{P}^{-1/2}\) as a single object of the iteration, this often gives better numerical stability than computing \(\boldsymbol{P}^{-1/2}\) separately and then multiplying.

Observing equations \eqref{eq:r2-rsqrt} and \eqref{eq:r3-rsqrt} closely, it is not difficult to see that when \(\boldsymbol{Y}_t\boldsymbol{Z}_t\) is viewed as a whole, its iteration in \eqref{eq:r3-rsqrt} is independent of \(\boldsymbol{Z}_t\). Thus, equation \eqref{eq:r2-rsqrt} for \(\boldsymbol{Z}_t\) is essentially just a linear recursion! Multiplying it by a matrix on the left does not change the form of the iteration; we only need to modify the initial value:

\begin{gather} \boldsymbol{Z}_0 = \boldsymbol{G}, \quad \boldsymbol{Y}_0\boldsymbol{Z}_0 = \boldsymbol{P} \notag\\[6pt] \boldsymbol{Z}_{t+1} = \boldsymbol{Z}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2) \label{eq:r2-final} \\[6pt] \boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t \label{eq:r3-final}\\[6pt] \lim_{t\to\infty} \boldsymbol{Z}_t = \boldsymbol{G}\boldsymbol{P}^{-1/2}\notag \end{gather}

Reference code:

import scipy as sp
import scipy.linalg  # ensure sp.linalg is available

def matmul_mrsqrt(G, P, steps=6):
    # Iterates Z_t -> G (P/t)^{-1/2}, with Z_0 = G instead of the identity.
    YZ = P / (t := np.trace(P))
    Z, I = G, np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W = a * I + b * YZ + c * YZ @ YZ
        Z, YZ = Z @ W, W @ W @ YZ
    return Z / t**0.5

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
G = np.random.randn(2 * d, d) / d**0.5
X = matmul_mrsqrt(G, P)
np.abs(X @ sp.linalg.sqrtm(P) - G).mean() # ~= 1e-4

Looking back at the square root algorithm, it is easy to see that it is just an equivalent rewriting of this section's iteration with \(\boldsymbol{G}=\boldsymbol{P}\), i.e., \(\boldsymbol{P}^{1/2}=\boldsymbol{P}\boldsymbol{P}^{-1/2}\). Thus, although we have discussed three iterations in separate sections, they are all special cases of the iteration in this section!
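
As a purely illustrative sanity check of this equivalence: every factor in the iteration is a polynomial in \(\boldsymbol{P}\), so all the matrices involved commute, and matmul_mrsqrt with \(\boldsymbol{G}=\boldsymbol{P}\) should reproduce msqrt up to floating-point error:

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
np.abs(matmul_mrsqrt(P, P) - msqrt(P)).mean()  # should be tiny: same product, different multiplication order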

Ultimate Generalization

Finally, we can generalize this to the calculation of \(\boldsymbol{Q}^{-1/2}\boldsymbol{G}\boldsymbol{P}^{-1/2}\), where \(\boldsymbol{Q}\in\mathbb{R}^{m\times m}\) is another matrix with non-negative eigenvalues. The result is as follows:

\begin{gather} \boldsymbol{G}_0 = \boldsymbol{G}, \quad \boldsymbol{Q}_0 = \boldsymbol{Q},\quad \boldsymbol{P}_0 = \boldsymbol{P} \notag\\[6pt] \boldsymbol{G}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Q}_t + c_{t+1}\boldsymbol{Q}_t^2)\boldsymbol{G}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2) \\[6pt] \boldsymbol{Q}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Q}_t + c_{t+1}\boldsymbol{Q}_t^2)^2\boldsymbol{Q}_t \\[6pt] \boldsymbol{P}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^2\boldsymbol{P}_t \\[6pt] \lim_{t\to\infty} \boldsymbol{G}_t = \boldsymbol{Q}^{-1/2}\boldsymbol{G}\boldsymbol{P}^{-1/2}\notag \end{gather}

Reference code:

def mrsqrt_matmul_mrsqrt(Q, G, P, steps=6):
    # Iterates G_t -> (Q/t1)^{-1/2} G (P/t2)^{-1/2}, applying the left and
    # right Newton-Schulz factors simultaneously.
    Q = Q / (t1 := np.trace(Q))
    P = P / (t2 := np.trace(P))
    I1, I2 = np.eye(Q.shape[0]), np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W1 = a * I1 + b * Q + c * Q @ Q
        W2 = a * I2 + b * P + c * P @ P
        G, Q, P = W1 @ G @ W2, W1 @ W1 @ Q, W2 @ W2 @ P
    return G / (t1 * t2)**0.5

d = 100
Q = (x := np.random.randn(2 * d, 2 * d) / (2 * d)**0.5) @ x.T
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
G = np.random.randn(2 * d, d) / d**0.5
X = mrsqrt_matmul_mrsqrt(Q, G, P)
np.abs(sp.linalg.sqrtm(Q) @ X @ sp.linalg.sqrtm(P) - G).mean() # ~= 2e-3

Readers are invited to complete the proof using the previous sections' results.

For the Shampoo optimizer, we need to compute \(\boldsymbol{Q}^{-1/4}\boldsymbol{G}\boldsymbol{P}^{-1/4}\). At present, a feasible plan is to first compute \(\boldsymbol{Q}^{1/2}\) and \(\boldsymbol{P}^{1/2}\), and then substitute them into the above iteration to obtain \((\boldsymbol{Q}^{1/2})^{-1/2}\boldsymbol{G}(\boldsymbol{P}^{1/2})^{-1/2}\). Although the computational cost looks high, in an optimizer's update step compute is often not the bottleneck as long as the algorithm is fully parallelizable. Conveniently, \(\boldsymbol{Q}^{1/2}\) and \(\boldsymbol{P}^{1/2}\) can be computed in parallel, as can the two matrices W1 and W2 within each iteration step, so the cost should be acceptable.
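
A minimal sketch of this plan, reusing the msqrt and mrsqrt_matmul_mrsqrt routines above (the helper name shampoo_quarter_root and the check via two nested sqrtm calls are illustrative choices):

def shampoo_quarter_root(Q, G, P, steps=6):
    # Q^{-1/4} G P^{-1/4} = (Q^{1/2})^{-1/2} G (P^{1/2})^{-1/2}
    Q_half = msqrt(Q, steps)  # these two square roots are independent and can run in parallel
    P_half = msqrt(P, steps)
    return mrsqrt_matmul_mrsqrt(Q_half, G, P_half, steps)

d = 100
Q = (x := np.random.randn(2 * d, 2 * d) / (2 * d)**0.5) @ x.T
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
G = np.random.randn(2 * d, d) / d**0.5
X = shampoo_quarter_root(Q, G, P)
np.abs(sp.linalg.sqrtm(sp.linalg.sqrtm(Q)) @ X @ sp.linalg.sqrtm(sp.linalg.sqrtm(P)) - G).mean()  # should be small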

Of course, it is certainly slower than Muon, as the complexity of Shampoo is significantly higher; such a cost is inevitable (see "Efficient calculation of matrix r-th roots and inverse r-th roots" for follow-up).

Summary

This article recasts the computation of the matrix square root and inverse square root in terms of \(\mcsgn\), so that the Newton-Schulz iteration for \(\mcsgn\) yields efficient algorithms for both.