By 苏剑林 | July 1, 2025
From the article "A Brief History of Linear Attention: From Imitation and Innovation to Feedback", we can see that DeltaNet and the linear attention models that followed it almost all involve the inverse matrix $(\boldsymbol{I} + \boldsymbol{K}\boldsymbol{K}^{\top}\odot\boldsymbol{M}^-)^{-1}$. This post is dedicated to computing the inverse of such triangular matrices with a "diagonal + low-rank" structure.
Basic Results
We define the problem generally as follows:
Given matrices $\boldsymbol{Q}, \boldsymbol{K} \in \mathbb{R}^{n \times d}$ and a diagonal matrix $\boldsymbol{\Lambda} \in \mathbb{R}^{n \times n}$, satisfying $n \gg d$, define
\begin{equation}\boldsymbol{T} = \boldsymbol{\Lambda} + \boldsymbol{Q}\boldsymbol{K}^{\top}\odot\boldsymbol{M}^-\end{equation}
where $\boldsymbol{M}^- = \boldsymbol{M} - \boldsymbol{I}$, and the matrix $\boldsymbol{M}$ is defined as
\begin{equation}M_{i,j} = \left\{\begin{aligned} &1, &i \geq j \\ &0, &i < j\end{aligned}\right.\end{equation}
The goal is to compute the inverse matrix $\boldsymbol{T}^{-1}$ and to show that this can be done with $\mathcal{O}(n^2)$ complexity.
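For concreteness, here is a small NumPy sketch of how $\boldsymbol{T}$ is assembled from the definition above (the sizes, the random diagonal, and the variable names are arbitrary, chosen purely for illustration); the $\odot\boldsymbol{M}^-$ mask simply keeps the strictly lower-triangular part of $\boldsymbol{Q}\boldsymbol{K}^{\top}$:

import numpy as np
n, d = 8, 3  # toy sizes, only to illustrate the definition
Q = np.random.randn(n, d) / d**0.5
K = np.random.randn(n, d) / d**0.5
lam = 1 + np.random.rand(n)                # diagonal entries of Lambda (taken positive here)
M_minus = np.tril(np.ones((n, n)), -1)     # M^- = M - I: ones strictly below the diagonal
T = np.diag(lam) + (Q @ K.T) * M_minus     # T = Lambda + (Q K^T) masked to the strict lower triangle
print(np.allclose(T, np.tril(T)))          # T is lower triangular; its diagonal is lam, so it is invertible whenever lam has no zeros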
First, if there were no lower triangular constraint imposed by $\odot\boldsymbol{M}^-$, it could be solved directly using the "Woodbury Matrix Identity":
\begin{equation}(\boldsymbol{\Lambda} + \boldsymbol{Q}\boldsymbol{K}^{\top})^{-1} = \boldsymbol{\Lambda}^{-1} - \boldsymbol{\Lambda}^{-1} \boldsymbol{Q}(\boldsymbol{I} + \boldsymbol{K}^{\top}\boldsymbol{\Lambda}^{-1}\boldsymbol{Q})^{-1}\boldsymbol{K}^{\top}\boldsymbol{\Lambda}^{-1}\end{equation}
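As a quick numerical sanity check of this identity (a sketch with an arbitrary diagonal $\boldsymbol{\Lambda}$, kept well away from zero so the matrix is comfortably invertible), note that the only matrix inverted explicitly is $d \times d$, and every other product costs at most $\mathcal{O}(n^2 d)$:

import numpy as np
n, d = 1000, 100
Q = np.random.randn(n, d) / d**0.5
K = np.random.randn(n, d) / d**0.5
lam = 10 + np.random.rand(n)                  # diagonal of Lambda, chosen well-conditioned
A = np.diag(lam) + Q @ K.T                    # diagonal + low rank, without the mask
Li_Q = Q / lam[:, None]                       # Lambda^{-1} Q, shape (n, d)
K_Li = K.T / lam                              # K^T Lambda^{-1}, shape (d, n)
core = np.linalg.inv(np.eye(d) + K.T @ Li_Q)  # the only explicit inverse: d x d
A_inv = np.diag(1 / lam) - Li_Q @ core @ K_Li
print(np.allclose(A_inv @ A, np.eye(n)))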
It is easy to verify that the computational complexity of the right-hand side of this identity is $\mathcal{O}(n^2)$. However, once $\odot\boldsymbol{M}^-$ is applied, $\boldsymbol{T}$ itself no longer has the "diagonal + low-rank" structure, so the identity cannot be used directly. Since $\boldsymbol{T}$ is lower triangular, a natural strategy is recursion, based on the block matrix identity:
\begin{equation}\begin{bmatrix}\boldsymbol{A} & \boldsymbol{0} \\ \boldsymbol{C} & \boldsymbol{B}\end{bmatrix}^{-1} = \begin{bmatrix}\boldsymbol{A}^{-1} & \boldsymbol{0} \\ -\boldsymbol{B}^{-1}\boldsymbol{C}\boldsymbol{A}^{-1} & \boldsymbol{B}^{-1}\end{bmatrix}\end{equation}
This allows us to transform $\boldsymbol{T}^{-1}$ into a recursive form (convention: in the absence of parentheses, slicing has the highest priority):
\begin{equation}\boldsymbol{T}_{[:l+1,:l+1]}^{-1} = \begin{bmatrix}\boldsymbol{T}_{[:l,:l]}^{-1} & \boldsymbol{0} \\ -\boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\boldsymbol{T}_{[l:l+1,:l]}\boldsymbol{T}_{[:l,:l]}^{-1} & \boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\end{bmatrix}\end{equation}
The main calculation here is $\boldsymbol{T}_{[l:l+1,:l]}\boldsymbol{T}_{[:l,:l]}^{-1}$, which is a product of a $1 \times l$ and an $l \times l$ matrix with a complexity of $\mathcal{O}(l^2)$. This means the complexity of each iteration grows quadratically, resulting in a total complexity of $\mathcal{O}(n^3)$.
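To make the scaling concrete, here is a minimal sketch of this naive recursion (taking $\boldsymbol{\Lambda} = \boldsymbol{I}$ and block size 1, matching the setup of the test code later in this post); each step multiplies a $1 \times l$ row by an $l \times l$ matrix, hence the overall $\mathcal{O}(n^3)$:

import numpy as np
n, d = 300, 20  # smaller sizes: this version is cubic in n
Q = np.random.randn(n, d) / d**0.5
K = np.random.randn(n, d) / d**0.5
T = np.tril(Q @ K.T, -1) + np.eye(n)
Tinv = np.zeros((n, n))
for l in range(n):
    Tinv[l, l] = 1 / T[l, l]
    # new bottom row of the recursion: -T_{[l,l]}^{-1} T_{[l,:l]} T_{[:l,:l]}^{-1}, an O(l^2) product
    Tinv[l, :l] = -Tinv[l, l] * (T[l, :l] @ Tinv[:l, :l])
print(np.allclose(Tinv @ T, np.eye(n)))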
Low-Rank Structure
Of course, this is because we have not yet exploited the low-rank structure of $\boldsymbol{T}$ (i.e., of $\boldsymbol{Q}\boldsymbol{K}^{\top}$ before the $\odot\boldsymbol{M}^-$ mask). Since the block $\boldsymbol{T}_{[l:l+1,:l]}$ lies strictly below the diagonal, neither $\boldsymbol{\Lambda}$ nor the mask affects it, and thus $\boldsymbol{T}_{[l:l+1,:l]} = \boldsymbol{Q}_{[l:l+1]}\boldsymbol{K}_{[:l]}^{\top}$. Substituting this into the equation above yields:
\begin{equation}\boldsymbol{T}_{[:l+1,:l+1]}^{-1} = \begin{bmatrix}\boldsymbol{T}_{[:l,:l]}^{-1} & \boldsymbol{0} \\ -\boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\boldsymbol{Q}_{[l:l+1]}\boldsymbol{K}_{[:l]}^{\top}\boldsymbol{T}_{[:l,:l]}^{-1} & \boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\end{bmatrix}\end{equation}
Note that $\boldsymbol{K}_{[:l]}^{\top}\boldsymbol{T}_{[:l,:l]}^{-1} \in \mathbb{R}^{d \times l}$. If we can use this as our recursive variable, the complexity of each step would be only $\mathcal{O}(l)$, successfully reducing the total complexity to $\mathcal{O}(n^2)$. Based on this idea, we have:
\begin{equation}\begin{aligned}
\boldsymbol{K}_{[:l+1]}^{\top}\boldsymbol{T}_{[:l+1,:l+1]}^{-1} =&\, \begin{bmatrix}\boldsymbol{K}_{[:l]}^{\top} & \boldsymbol{K}_{[l:l+1]}^{\top}\end{bmatrix}\begin{bmatrix}\boldsymbol{T}_{[:l,:l]}^{-1} & \boldsymbol{0} \\ -\boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\boldsymbol{Q}_{[l:l+1]}\boldsymbol{K}_{[:l]}^{\top}\boldsymbol{T}_{[:l,:l]}^{-1} & \boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\end{bmatrix} \\[6pt]
=&\, \begin{bmatrix}\boldsymbol{K}_{[:l]}^{\top}\boldsymbol{T}_{[:l,:l]}^{-1} & \boldsymbol{0}\end{bmatrix} + \boldsymbol{K}_{[l:l+1]}^{\top}\underbrace{\begin{bmatrix}-\boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\boldsymbol{Q}_{[l:l+1]}\boldsymbol{K}_{[:l]}^{\top}\boldsymbol{T}_{[:l,:l]}^{-1} & \boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\end{bmatrix}}_{\text{which is simply } (\boldsymbol{T}^{-1})_{[l:l+1,:l+1]}}\end{aligned}\end{equation}
As we can see, each step of this recursion involves only $\mathcal{O}(l)$ operations (for fixed $d$) rather than $\mathcal{O}(l^2)$, so the approach is feasible: one just needs to introduce a new variable to cache $\boldsymbol{K}_{[:l]}^{\top}\boldsymbol{T}_{[:l,:l]}^{-1}$. Replacing $l+1$ with $l+c$ gives the corresponding chunked recursion with chunk size $c$.
The test code is as follows:
import numpy as np
n, d, c = 1000, 100, 200
Q = np.random.randn(n, d) / d**0.5
K = np.random.randn(n, d) / d**0.5
T = np.tril(Q @ K.T, -1) + np.eye(n)        # the Lambda = I case: I + (Q K^T) masked below the diagonal
Y, Z = np.zeros((n, n)), np.zeros((d, n))   # Y accumulates T^{-1}; Z caches K_{[:l]}^T T_{[:l,:l]}^{-1}
for l in range(0, n, c):
    Y[l:l + c, l:l + c] = np.linalg.inv(T[l:l + c, l:l + c])         # diagonal block
    Y[l:l + c, :l] = -Y[l:l + c, l:l + c] @ Q[l:l + c] @ Z[:, :l]    # off-diagonal block via the cache
    Z[:, :l + c] += K[l:l + c].T @ Y[l:l + c, :l + c]                # update the cache
print(np.allclose(Y @ T, np.eye(n)))
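Note that the recursion itself never relies on $\boldsymbol{\Lambda} = \boldsymbol{I}$: only the construction of $\boldsymbol{T}$ and the diagonal-block inverses see $\boldsymbol{\Lambda}$. A sketch of the same test with a generic diagonal (taken strictly positive here just to guarantee invertibility):

import numpy as np
n, d, c = 1000, 100, 200
Q = np.random.randn(n, d) / d**0.5
K = np.random.randn(n, d) / d**0.5
lam = 1 + np.random.rand(n)                  # general positive diagonal Lambda
T = np.tril(Q @ K.T, -1) + np.diag(lam)      # Lambda + (Q K^T) masked below the diagonal
Y, Z = np.zeros((n, n)), np.zeros((d, n))
for l in range(0, n, c):
    Y[l:l + c, l:l + c] = np.linalg.inv(T[l:l + c, l:l + c])
    Y[l:l + c, :l] = -Y[l:l + c, l:l + c] @ Q[l:l + c] @ Z[:, :l]
    Z[:, :l + c] += K[l:l + c].T @ Y[l:l + c, :l + c]
print(np.allclose(Y @ T, np.eye(n)))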
Multiplication Calculation
Based on the same logic, we can also prove:
For any matrix $\boldsymbol{V} \in \mathbb{R}^{n \times d}$, the product $\boldsymbol{T}^{-1}\boldsymbol{V}$ can be computed with only $\mathcal{O}(n)$ complexity.
The proof only requires a slight modification of the previous process. First, we have:
\begin{equation}\begin{aligned}
(\boldsymbol{T}^{-1}\boldsymbol{V})_{[:l+1]} =&\, \boldsymbol{T}_{[:l+1,:l+1]}^{-1}\boldsymbol{V}_{[:l+1]} \\[6pt]
=&\, \begin{bmatrix}\boldsymbol{T}_{[:l,:l]}^{-1} & \boldsymbol{0} \\ -\boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\boldsymbol{Q}_{[l:l+1]}\boldsymbol{K}_{[:l]}^{\top}\boldsymbol{T}_{[:l,:l]}^{-1} & \boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\end{bmatrix}\begin{bmatrix}\boldsymbol{V}_{[:l]} \\ \boldsymbol{V}_{[l:l+1]}\end{bmatrix} \\[6pt]
=&\, \begin{bmatrix}\boldsymbol{T}_{[:l,:l]}^{-1}\boldsymbol{V}_{[:l]} \\ -\boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\boldsymbol{Q}_{[l:l+1]}\boldsymbol{K}_{[:l]}^{\top}\boldsymbol{T}_{[:l,:l]}^{-1}\boldsymbol{V}_{[:l]} + \boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}\boldsymbol{V}_{[l:l+1]}\end{bmatrix} \\[6pt]
=&\, \begin{bmatrix}(\boldsymbol{T}^{-1}\boldsymbol{V})_{[:l]} \\ \boldsymbol{T}_{[l:l+1,l:l+1]}^{-1}(\boldsymbol{V}_{[l:l+1]} - \boldsymbol{Q}_{[l:l+1]}\boldsymbol{K}_{[:l]}^{\top}(\boldsymbol{T}^{-1}\boldsymbol{V})_{[:l]})\end{bmatrix}
\end{aligned}\end{equation}
Then
\begin{equation}\begin{aligned}
\boldsymbol{K}_{[:l+1]}^{\top}(\boldsymbol{T}^{-1}\boldsymbol{V})_{[:l+1]} =&\, \begin{bmatrix}\boldsymbol{K}_{[:l]}^{\top} & \boldsymbol{K}_{[l:l+1]}^{\top}\end{bmatrix}\begin{bmatrix}(\boldsymbol{T}^{-1}\boldsymbol{V})_{[:l]} \\ (\boldsymbol{T}^{-1}\boldsymbol{V})_{[l:l+1]} \end{bmatrix} \\[8pt]
=&\, \boldsymbol{K}_{[:l]}^{\top}(\boldsymbol{T}^{-1}\boldsymbol{V})_{[:l]} + \boldsymbol{K}_{[l:l+1]}^{\top}(\boldsymbol{T}^{-1}\boldsymbol{V})_{[l:l+1]}
\end{aligned}\end{equation}
Thus, by simply caching $\boldsymbol{K}_{[:l]}^{\top}(\boldsymbol{T}^{-1}\boldsymbol{V})_{[:l]} \in \mathbb{R}^{d \times d}$, the computational complexity of each step becomes independent of $l$, and the total complexity is therefore $\mathcal{O}(n)$. Likewise, replacing $l+1$ with $l+c$ provides the chunked format.
The test code is as follows:
import numpy as np
n, d, c = 1000, 100, 200
Q = np.random.randn(n, d) / d**0.5
K = np.random.randn(n, d) / d**0.5
V = np.random.randn(n, d) / d**0.5
T = np.tril(Q @ K.T, -1) + np.eye(n)        # the Lambda = I case
Y, Z = np.zeros((n, d)), np.zeros((d, d))   # Y accumulates T^{-1} V; Z caches K_{[:l]}^T (T^{-1} V)_{[:l]}
for l in range(0, n, c):
    X = np.linalg.inv(T[l:l + c, l:l + c])          # inverse of the c x c diagonal block
    Y[l:l + c] = X @ (V[l:l + c] - Q[l:l + c] @ Z)  # new chunk of T^{-1} V
    Z += K[l:l + c].T @ Y[l:l + c]                  # update the d x d cache
print(np.allclose(T @ Y, V))
Summary
This article discussed the inversion of lower triangular matrices with a "diagonal + low-rank" structure, which commonly appear in modern linear attention models such as DeltaNet: the full inverse $\boldsymbol{T}^{-1}$ can be computed in $\mathcal{O}(n^2)$, and the product $\boldsymbol{T}^{-1}\boldsymbol{V}$ in $\mathcal{O}(n)$.