By 苏剑林 | July 21, 2025
In the previous article "Efficient Calculation of Matrix Square Root and Inverse Square Root", starting from the $\mcsgn$ operator, I proposed a rather elegant method for computing the matrix square root and inverse square root. Interestingly, after simplification the final formulas no longer resemble the original $\mcsgn$ form. This prompts deeper questions: what is the more fundamental principle behind the scheme, and can it be generalized to an arbitrary $r$-th root?
Analyzing from this angle, I found, somewhat to my surprise, that the previous iteration algorithm can be understood from a simpler perspective, and that from this new perspective it generalizes readily to the calculation of arbitrary $r$-th roots and inverse $r$-th roots. This article shares that process.
Previous Review #
Let $\boldsymbol{G}\in\mathbb{R}^{m\times n}$ be an arbitrary matrix, and $\boldsymbol{P}\in\mathbb{R}^{n\times n}$ be any matrix whose eigenvalues are all within $[0,1]$. The previous article provided:
\begin{gather}
\boldsymbol{G}_0 = \boldsymbol{G}, \quad \boldsymbol{P}_0 = \boldsymbol{P} \notag\\[6pt]
\boldsymbol{G}_{t+1} = \boldsymbol{G}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2) \label{eq:r2-rsqrt}\\[6pt]
\boldsymbol{P}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^2\boldsymbol{P}_t \label{eq:r3-rsqrt}\\[6pt]
\lim_{t\to\infty} \boldsymbol{G}_t = \boldsymbol{G}\boldsymbol{P}^{-1/2}\notag
\end{gather}
Substituting $\boldsymbol{G}=\boldsymbol{P}$ allows solving for $\boldsymbol{P}^{1/2}$, and substituting $\boldsymbol{G}=\boldsymbol{I}$ allows solving for $\boldsymbol{P}^{-1/2}$. If we observe carefully, the above iteration is actually an embodiment of the following limit:
\begin{equation} \prod_{t=0}^{\infty}(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2) = \boldsymbol{P}^{-1/2} \label{eq:prod-rsqrt} \end{equation}
Interestingly, directly proving this limit is not complicated. Applying the square root to both sides of Eq. $\eqref{eq:r3-rsqrt}$ and substituting it into the above expression yields:
\begin{equation} \prod_{t=0}^{\infty}(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2) = \prod_{t=0}^{\infty} \boldsymbol{P}_{t+1}^{1/2}\boldsymbol{P}_t^{-1/2} = \lim_{t\to\infty} \boldsymbol{P}_t^{1/2}\boldsymbol{P}_0^{-1/2} = \lim_{t\to\infty} \boldsymbol{P}_t^{1/2}\boldsymbol{P}^{-1/2} \end{equation}
From this, it can be seen that as long as the sequence $\{\boldsymbol{P}_t\}$ remains invertible and its final limit is $\boldsymbol{I}$, the limit $\eqref{eq:prod-rsqrt}$ holds automatically. As for how the iteration $\eqref{eq:r3-rsqrt}$ ensures these two conditions for $\{\boldsymbol{P}_t\}$, we will discuss that in a moment.
General Form #
Let us consider a general iteration:
\begin{gather}
\boldsymbol{G}_0 = \boldsymbol{G}, \quad \boldsymbol{P}_0 = \boldsymbol{P} \notag\\[6pt]
\boldsymbol{G}_{t+1} = \boldsymbol{G}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^s\\[6pt]
\boldsymbol{P}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^r\boldsymbol{P}_t
\end{gather}
Similarly, if the sequence $\{\boldsymbol{P}_t\}$ remains invertible and its final limit is $\boldsymbol{I}$, then it can be proven that:
\begin{equation} \lim_{t\to\infty} \boldsymbol{G}_t = \boldsymbol{G}\boldsymbol{P}^{-s/r} \end{equation}
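The proof is the same telescoping argument as before: since every factor is a polynomial in $\boldsymbol{P}$ and hence everything commutes, raising the $\boldsymbol{P}_t$ update to the power $s/r$ shows that each factor $(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^s$ equals $\boldsymbol{P}_{t+1}^{s/r}\boldsymbol{P}_t^{-s/r}$, so that
\begin{equation} \prod_{t=0}^{\infty}(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^s = \prod_{t=0}^{\infty} \boldsymbol{P}_{t+1}^{s/r}\boldsymbol{P}_t^{-s/r} = \lim_{t\to\infty} \boldsymbol{P}_t^{s/r}\boldsymbol{P}^{-s/r} = \boldsymbol{P}^{-s/r} \end{equation}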
Thus, we have obtained a general iterative form for calculating any $-s/r$ power of a matrix. Building on this result, positive roots come for free: choosing $\boldsymbol{G}=\boldsymbol{P}$ and $s=r-1$ gives $\boldsymbol{P}\cdot\boldsymbol{P}^{-(r-1)/r} = \boldsymbol{P}^{1/r}$. Therefore, it suffices to handle inverse powers $\boldsymbol{P}^{-s/r}$ with exponent $s/r$ in $(0,1]$.
Consequently, the problem becomes how to select appropriate $\{a_t, b_t, c_t\}$ so that the sequence $\{\boldsymbol{P}_t\}$ converges to $\boldsymbol{I}$ as quickly as possible. Faster convergence means we can reach the specified accuracy with fewer iteration steps.
Iteration Coefficients #
According to the assumption, $\boldsymbol{P}_0 = \boldsymbol{P}$ is a matrix with eigenvalues in $[0,1]$, while the target matrix $\boldsymbol{I}$ is a matrix with all eigenvalues equal to 1. Thus, the sequence $\{\boldsymbol{P}_t\}$ is essentially the process of transforming eigenvalues from any value in $[0,1]$ to $1$. This is exactly what $\mcsgn$ does!
If we let $\boldsymbol{X}_t = \boldsymbol{P}_t^{1/r}$, then $\boldsymbol{X}_0 = \boldsymbol{P}^{1/r}$ is also a matrix with eigenvalues in $[0,1]$. Since every matrix appearing here is a polynomial in $\boldsymbol{P}$ and therefore commutes with the others, raising the $\boldsymbol{P}_t$ update to the power $1/r$ turns the iteration into:
\begin{equation} \boldsymbol{X}_{t+1} = a_{t+1}\boldsymbol{X}_t + b_{t+1}\boldsymbol{X}_t^{r+1} + c_{t+1}\boldsymbol{X}_t^{2r+1} \end{equation}
Now the problem is how to drive $\boldsymbol{X}_t$ to $\boldsymbol{I}$ as quickly as possible. This is essentially the same problem discussed in "Newton-Schulz Iteration for the msign Operator (Part 1)" and "Newton-Schulz Iteration for the msign Operator (Part 2)"; "Part 2" gave the theoretically optimal solution for $r=2$, and its derivation and conclusions generalize to any $r$.
Specifically, we first transform the problem into a scalar iteration:
\begin{equation} x_{t+1} = f_t(x_t) = a_{t+1}x_t + b_{t+1}x_t^{r+1} + c_{t+1}x_t^{2r+1} \end{equation}
Then one shows that the greedy solution is optimal, and finding the greedy solution reduces to solving the following system of equations, where $[l_t, u_t]$ is the interval currently known to contain the eigenvalues of $\boldsymbol{X}_t$ and $\mathcal{E}$ is the approximation error being minimized:
\begin{equation}
\begin{gathered}
f_t(l_t) = 1 - \mathcal{E}, \quad f_t(u_t) = 1 + \mathcal{E} \\
f_t(x_1) = 1 + \mathcal{E}, \quad f_t(x_2) = 1 - \mathcal{E} \\
f_t'(x_1) = 0, \quad f_t'(x_2) = 0
\end{gathered}
\end{equation}
For simplicity, parameterize $f_t'$ as:
\begin{equation} f_t'(x) = k(x^r-x_1^r)(x^r-x_2^r) \end{equation}
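Since $f_t(0)=0$, integrating this parameterization term by term recovers the coefficient form directly:
\begin{equation} f_t(x) = \int_0^x f_t'(y)\,\mathrm{d}y = k\left(x_1^r x_2^r\, x - \frac{x_1^r + x_2^r}{r+1}\, x^{r+1} + \frac{1}{2r+1}\, x^{2r+1}\right) \end{equation}
so that $a_{t+1} = k\,x_1^r x_2^r$, $b_{t+1} = -k\,(x_1^r + x_2^r)/(r+1)$ and $c_{t+1} = k/(2r+1)$, up to the overall normalization applied in the solver code below.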
Then it can be solved using Mathematica, just like in "Part 2".
Initial Analysis #
Before formally solving, we also need to analyze the initialization. In the previous article "Efficient Calculation of Matrix Square Root and Inverse Square Root", we mentioned that under the assumption that the eigenvalues of $\boldsymbol{P}$ are non-negative, we can compress the eigenvalues into $[0,1]$ by dividing by $\tr(\boldsymbol{P})$. However, this compression ratio is often too large. In this article, we change it to:
\begin{equation} \boldsymbol{P}_0 = \frac{\boldsymbol{P}}{\sqrt{\tr(\boldsymbol{P}^2)}} \label{eq:trace-scaling} \end{equation}
We know that $\tr(\boldsymbol{P}^2)$ equals the sum of the squares of all eigenvalues, while $\tr(\boldsymbol{P})^2$ equals the square of their sum. For non-negative eigenvalues, $\tr(\boldsymbol{P}^2) \leq \tr(\boldsymbol{P})^2$ always holds, i.e. $\sqrt{\tr(\boldsymbol{P}^2)} \leq \tr(\boldsymbol{P})$, so the above formula shrinks the eigenvalues less and provides a tighter initialization. Notably, calculating $\tr(\boldsymbol{P}^2)$ does not require explicitly forming $\boldsymbol{P}^2$, because we have the identity:
\begin{equation} \tr(\boldsymbol{P}^2) = \langle \boldsymbol{P}, \boldsymbol{P}^{\top}\rangle_F \end{equation}
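In code this is just an elementwise product followed by a sum; the sketch below (with `tr_P2` an illustrative helper name, not part of the implementation later in this article) shows the idea:

import jax.numpy as jnp

def tr_P2(P):
    # <P, P^T>_F = sum_{i,j} P_{ij} P_{ji} = tr(P @ P), computed without forming P @ P
    return (P * P.T).sum()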
Next, we also need to analyze how small the eigenvalues we must handle can be, which parallels the initial singular-value analysis in "Newton-Schulz Iteration for the msign Operator (Part 1)". After dividing by $\sqrt{\tr(\boldsymbol{P}^2)}$, the eigenvalues of $\boldsymbol{P}_0$ form a unit vector (their squares sum to 1). If all eigenvalues were equal, each would be $1/\sqrt{n}$; by a pigeonhole-style argument, in general the smallest eigenvalue is at most $1/\sqrt{n}$. To be conservative, we support eigenvalues down to $0.01/\sqrt{n}$.
Considering that in sufficiently large LLMs $n$ has reached the order of $100^2 = 10^4$, we need to support eigenvalues down to $0.01/\sqrt{10^4} = 0.0001$. Note that this refers to the eigenvalues of $\boldsymbol{P}_0$; since $\boldsymbol{X}_0 = \boldsymbol{P}_0^{1/r}$, the eigenvalues of $\boldsymbol{X}_0$ only need to reach down to $0.0001^{1/r}$ (for example, $0.1$ when $r=4$). This is more favorable than the $\mcsgn$ or $\text{msign}$ setting: there the input itself is $\boldsymbol{X}_0$ and its small eigenvalues must be handled directly, whereas here the input is $\boldsymbol{P}_0$, and taking the $r$-th root already lifts its small eigenvalues.
Calculation Results #
Integrating the above considerations, our final solver code is as follows:
r = 4;  (* exponent; change r to generate each block of the table below *)
df[x_] = k*(x^r - x1^r) (x^r - x2^r);  (* parameterization of f_t'(x) *)
f[x_] = Integrate[df[y], {y, 0, x}];  (* f_t(x) = a x + b x^(r+1) + c x^(2r+1), with f_t(0) = 0 *)
sol[l_, u_] :=
 NSolve[{f[l] == 1 - e, f[x1] == 1 + e, f[x2] == 1 - e, f[u] == 1 + e,
   l < x1 < x2 < u, e > 0, k > 0}, {k, x1, x2, e}]  (* greedy conditions on [l, u] *)
ff[x_, l_, u_] = f[x]*2/(f[l] + f[u]) // Expand;  (* normalize so that (f(l) + f(u))/2 = 1 *)
lt = 0.0001^(1/r); ut = 1; lambda = 0.1;  (* initial eigenvalue range of X_0 *)
While[1 - lt > 0.0001,
 fff[x_] = ff[x, lt, ut] /. sol[Max[lt, lambda*ut], ut][[1]];  (* lambda caps the per-step range *)
 Print[fff[x]];  (* prints the coefficients a, b, c for this step *)
 lt = fff[lt]; ut = 2 - lt]  (* update the eigenvalue bounds *)
f[x] /. Solve[f[1] == 1, k][[1]] /. {x1 -> 1, x2 -> 1}  (* limiting step: x1 = x2 = 1, f(1) = 1 *)
The calculation results for $r=1 \sim 5$ are as follows:
| $r$ | $t$ | $a$ | $b$ | $c$ |
|:---:|:---:|:---:|:---:|:---:|
| 1 | 1 | 14.2975 | -31.2203 | 18.9214 |
| 1 | 2 | 7.12258 | -7.78207 | 2.35989 |
| 1 | 3 | 6.9396 | -7.61544 | 2.3195 |
| 1 | 4 | 5.98456 | -6.77016 | 2.12571 |
| 1 | 5 | 3.79109 | -4.18664 | 1.39555 |
| 1 | $\geq 6$ | 3 | -3 | 1 |
| 2 | 1 | 7.42487 | -18.3958 | 12.8967 |
| 2 | 2 | 3.48773 | -2.33004 | 0.440469 |
| 2 | 3 | 2.77661 | -2.07064 | 0.463023 |
| 2 | 4 | 1.99131 | -1.37394 | 0.387593 |
| 2 | $\geq 5$ | 15/8 | -5/4 | 3/8 |
| 3 | 1 | 5.05052 | -13.5427 | 10.2579 |
| 3 | 2 | 2.31728 | -1.06581 | 0.144441 |
| 3 | 3 | 1.79293 | -0.913562 | 0.186699 |
| 3 | 4 | 1.56683 | -0.786609 | 0.220008 |
| 3 | $\geq 5$ | 14/9 | -7/9 | 2/9 |
| 4 | 1 | 3.85003 | -10.8539 | 8.61893 |
| 4 | 2 | 1.80992 | -0.587778 | 0.0647852 |
| 4 | 3 | 1.50394 | -0.594516 | 0.121161 |
| 4 | $\geq 4$ | 45/32 | -9/16 | 5/32 |
| 5 | 1 | 3.11194 | -8.28217 | 6.67716 |
| 5 | 2 | 1.5752 | -0.393327 | 0.0380364 |
| 5 | 3 | 1.3736 | -0.44661 | 0.0911259 |
| 5 | $\geq 4$ | 33/25 | -11/25 | 3/25 |
The coefficients of the final (limiting) step are derived by setting $x_1=x_2=1$ and $f(1)=1$.
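As a quick sanity check on these coefficients, one can iterate the scalar map $f_t(x) = a\,x + b\,x^{r+1} + c\,x^{2r+1}$ on the smallest supported eigenvalue of $\boldsymbol{X}_0$ and watch it converge to $1$. A minimal sketch for $r=2$, with the coefficient triples copied from the table above:

r = 2
coefs_r2 = [  # (a, b, c) per step, from the r = 2 rows of the table
    (7.42487, -18.3958, 12.8967),
    (3.48773, -2.33004, 0.440469),
    (2.77661, -2.07064, 0.463023),
    (1.99131, -1.37394, 0.387593),
    (15 / 8, -5 / 4, 3 / 8),
]
x = 0.0001 ** (1 / r)  # smallest supported eigenvalue of X_0 = P_0^(1/r), here 0.01
for a, b, c in coefs_r2:
    x = a * x + b * x ** (r + 1) + c * x ** (2 * r + 1)
    print(x)  # roughly 0.074, 0.258, 0.681, 0.979, 0.99998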
Let's Test It Out #
A simple test code is as follows:
import numpy as np
import jax.numpy as jnp
# coefs[r] holds the per-step (a, b, c) triples from the table above; coefs[0] is a placeholder
coefs = [
None,
[
(14.2975, -31.2203, 18.9214),
(7.12258, -7.78207, 2.35989),
(6.9396, -7.61544, 2.3195),
(5.98456, -6.77016, 2.12571),
(3.79109, -4.18664, 1.39555),
(3, -3, 1),
],
[
(7.42487, -18.3958, 12.8967),
(3.48773, -2.33004, 0.440469),
(2.77661, -2.07064, 0.463023),
(1.99131, -1.37394, 0.387593),
(15 / 8, -5 / 4, 3 / 8),
],
[
(5.05052, -13.5427, 10.2579),
(2.31728, -1.06581, 0.144441),
(1.79293, -0.913562, 0.186699),
(1.56683, -0.786609, 0.220008),
(14 / 9, -7 / 9, 2 / 9),
],
[
(3.85003, -10.8539, 8.61893),
(1.80992, -0.587778, 0.0647852),
(1.50394, -0.594516, 0.121161),
(45 / 32, -9 / 16, 5 / 32),
],
[
(3.11194, -8.28217, 6.67716),
(1.5752, -0.393327, 0.0380364),
(1.3736, -0.44661, 0.0911259),
(33 / 25, -11 / 25, 3 / 25),
],
]
def abc(r=1, steps=None, scale=1):
    # yield per-step (a, b, c) for exponent r; extra steps repeat the limiting triple
    # dividing by scale, scale^(r+1), scale^(2r+1) amounts to evaluating the unscaled
    # polynomial at X_t / scale, a small safety margin against eigenvalues above 1
    w, steps = coefs[r], steps or len(coefs[r])
    for a, b, c in w[:steps] + w[-1:] * max(steps - len(w), 0):
        yield a / scale, b / scale**(r + 1), c / scale**(2 * r + 1)
def matmul_invroot(G, P, r, s=1, steps=None, eps=1e-5):
    """return G @ P^(-s/r)
    """
    I = jnp.eye(P.shape[0], dtype=P.dtype)
    # scale eigenvalues into [0, 1] via tr(P^2) = <P, P^T>_F, then regularize by eps * I
    P = P / (t := (P * P.mT).sum()**0.5) + eps * I
    for a, b, c in abc(r, steps, 1.001):
        W = a * I + b * P + c * P @ P
        W1, W2 = jnp.linalg.matrix_power(W, s), jnp.linalg.matrix_power(W, r)
        G, P = G @ W1, P @ W2  # G_{t+1} = G_t W^s,  P_{t+1} = P_t W^r
    return G * t**(-s / r)  # undo the initial scaling of P
def matmul_invroot_by_eigh(G, P, r, s=1):
    """return G @ P^(-s/r)  (reference implementation via eigendecomposition)
    """
    S, Q = jnp.linalg.eigh(P)
    return G @ Q @ jnp.diag(S**(-s / r)) @ jnp.linalg.inv(Q)
d = 1000
s, r = 1, 4
G = np.random.randn(2 * d, d) / d**0.5
P = (x := np.random.randn(d, d) / d**0.5) @ x.T + 0.001 * np.eye(d)  # symmetric positive definite
X1 = matmul_invroot_by_eigh(G, P, r, s)  # reference result
X2 = matmul_invroot(G, P, r, s, eps=0)
print(jnp.abs(X1 - X2).mean())  # ~= 1e-3
X2 = matmul_invroot(jnp.array(G, dtype='bfloat16'), jnp.array(P, dtype='bfloat16'), r, s, eps=0)
print(jnp.abs(X1 - X2).mean())  # ~= 2e-3
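As noted earlier, the same routine covers positive roots by choosing $\boldsymbol{G}=\boldsymbol{P}$ and $s=r-1$. For instance, reusing the `P` defined above, a sketch for $\boldsymbol{P}^{1/2}$ (i.e. $r=2$, $s=1$), checked against the eigendecomposition reference:

sqrtP_ref = matmul_invroot_by_eigh(P, P, r=2, s=1)  # P @ P^(-1/2) = P^(1/2)
sqrtP = matmul_invroot(P, P, r=2, s=1, eps=0)
print(jnp.abs(sqrtP - sqrtP_ref).mean())  # expected to be small, as in the checks above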
There are a few key points to note. First, the minimum eigenvalue of the input $\boldsymbol{P}$ cannot be too small, otherwise the iteration is extremely prone to exploding, even if we only want a positive power such as $\boldsymbol{P}^{1/2}$. This is not hard to understand: $\sqrt{x}$ is quite ill-conditioned at $x=0$, and once error "accidentally" pushes an eigenvalue onto the negative half-axis, a real solution no longer exists and the behavior of the iteration becomes unpredictable.
How small is "not too small"? Roughly, the minimum eigenvalue of $\boldsymbol{P}/\sqrt{\tr(\boldsymbol{P}^2)}$ should not be significantly smaller than the minimum eigenvalue we considered, which is $0.0001$. If this cannot be guaranteed, it is recommended to directly set:
\begin{equation} \boldsymbol{P}_0 = \frac{\boldsymbol{P}}{\sqrt{\tr(\boldsymbol{P}^2)}} + \epsilon \cdot\boldsymbol{I} \end{equation}
where $\epsilon \sim 0.0001$. This sacrifices a little precision but significantly improves numerical stability; it is exactly what the `eps` argument of `matmul_invroot` above does (default `eps=1e-5`).
Additionally, the number of iteration steps generally does not need to exceed the recommended value `len(coefs[r])`, especially in low-precision settings, because extra steps make it easier for errors to accumulate and blow up. In fact, as long as the eigenvalues stay within the range considered above, the recommended number of steps already achieves the intended accuracy; only when iterating in fp32 or higher precision might one consider setting $\epsilon=0$, `scale=1`, and using more iteration steps.
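For reference, requesting more steps than `len(coefs[r])` from the code above simply repeats the limiting coefficient triple; note that the `scale` inside `matmul_invroot` is still the hard-coded `1.001`, so setting it to `1` would require a trivial modification. A sketch reusing the test variables above:

# in fp32 or higher precision, extra steps just repeat the limiting (a, b, c) triple
X3 = matmul_invroot(G, P, r, s, steps=8, eps=0)
print(jnp.abs(X1 - X3).mean())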
Article Summary #
This article generalizes the results of the previous article to arbitrary $r$-th roots and inverse $r$-th roots, obtaining a general iterative scheme for computing matrix powers of the form $-s/r$.