Monarch Matrix: Computationally Efficient Sparse Matrix Decomposition

By 苏剑林 | July 24, 2024

In the problem of matrix compression, we usually have two strategies to choose from: low-rankness and sparsity. Low-rankness reduces matrix dimensions by finding a low-rank approximation, while sparsity reduces the complexity of the matrix by decreasing the number of its non-zero elements. If SVD is aimed at the low-rank approximation of a matrix, then what is the corresponding algorithm for finding a sparse approximation of a matrix?

Next, we are going to learn from the paper "Monarch: Expressive Structured Matrices for Efficient and Accurate Training". It provides an answer to the above question—the "Monarch matrix." This is a family of matrices that can be decomposed into the product of several permutation matrices and sparse matrices, characterized by being both computationally efficient and expressive. The paper also discusses how to find the Monarch approximation of a general matrix and how to use Monarch matrices to parameterize Large Language Models (LLMs) to improve their speed.

It is worth noting that the author of this paper is also the author of the famous Flash Attention, Tri Dao. Almost all of his work is dedicated to improving the performance of LLMs. This Monarch paper is also one of the few featured papers on his homepage. From this point alone, it is a topic very much worth studying.

SVD Review

First, let's briefly review SVD (Singular Value Decomposition). For an $n \times m$ matrix $A$, SVD decomposes it into: \begin{equation}A = U\Sigma V\end{equation} where $U$ and $V$ are orthogonal matrices of shapes $n \times n$ and $m \times m$ respectively, and $\Sigma$ is an $n \times m$ diagonal matrix with non-negative diagonal elements arranged from largest to smallest. When we retain only the first $r$ diagonal elements of $\Sigma$, we obtain an approximate decomposition of $A$ with a rank not exceeding $r$: \begin{equation}A \approx U_{[:,:r]}\Sigma_{[:r,:r]} V_{[:r,:]}\end{equation} Here the subscripts follow Python slicing notation, so $U_{[:,:r]}$ has shape $n \times r$, $\Sigma_{[:r,:r]}$ has shape $r \times r$, and $V_{[:r,:]}$ has shape $r \times m$. This means the rank of $U_{[:,:r]}\Sigma_{[:r,:r]} V_{[:r,:]}$ is at most $r$.
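For concreteness, here is a quick NumPy illustration of this truncation (the variable names are my own and carry no special meaning):

import numpy as np

n, m, r = 6, 5, 2
A = np.random.randn(n, m)

# Full SVD; numpy's V is returned already transposed, matching the A = U Sigma V
# convention used in the text
U, S, V = np.linalg.svd(A)

# Rank-r truncation: U[:, :r] is n x r, diag(S[:r]) is r x r, V[:r] is r x m
A_r = U[:, :r] @ np.diag(S[:r]) @ V[:r]
print(np.linalg.matrix_rank(A_r))    # at most r
print(np.square(A - A_r).sum())      # squared Frobenius error of the truncation...
print(np.square(S[r:]).sum())        # ...equals the sum of squared discarded singular values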

Particularly, the low-rank approximation obtained by SVD is precisely the exact solution to the following optimization problem: \begin{equation}U_{[:,:r]}\Sigma_{[:r,:r]} V_{[:r,:]} = \mathop{\text{argmin}}_{rank(B)\leq r} \Vert A - B\Vert_F^2\end{equation} where $\Vert\cdot\Vert_F^2$ is the square of the Frobenius norm of the matrix, i.e., the sum of the squares of all elements in the matrix. That is to say, under the Frobenius norm, the optimal rank-$r$ approximation of matrix $A$ is $U_{[:,:r]}\Sigma_{[:r,:r]} V_{[:r,:]}$. This conclusion is known as the "Eckart-Young-Mirsky Theorem." It is because of this conclusion that we said at the beginning of the article that "SVD is aimed at the low-rank approximation of matrices."

There is far more that could be said about SVD, enough to fill an entire book, so we won't delve deeper here. Finally, it should be noted that the computational complexity of SVD is $\mathcal{O}(nm \cdot \min(m,n))$, since we must perform an eigenvalue decomposition on at least one of $A^{\top} A$ or $A A^{\top}$. If we know in advance that we only need a rank-$r$ approximation, the complexity can be reduced, which leads to Truncated SVD.

Monarch Matrix

Low-rank decomposition is widely used, but it does not always meet our needs. For example, the low-rank approximation of an invertible square matrix is necessarily non-invertible, which means low-rank approximation is unsuitable for scenarios requiring matrix inversion. In such cases, another choice is sparse approximation; sparse matrices can usually ensure that the rank does not degenerate.

Note that sparsity and low-rankness are not necessarily linked. For instance, the identity matrix is a very sparse matrix, but it is invertible (full rank). Finding a sparse approximation for a matrix isn't difficult; for example, setting all elements to zero except for the $k$ elements with the largest absolute values is a very simple sparse approximation. The problem is that it's usually not practical. Thus, the challenge lies in finding practical sparse approximations. By "practical," we mean maintaining enough expressive power or degree of approximation while achieving some degree of sparsification, and this sparsification must have an appropriate structure that helps speed up matrix operations (such as multiplication and inversion).

The Monarch matrix was created exactly for this purpose. Suppose $n=m^2$ is a perfect square. The Monarch matrices then form a subset of the set of all $n$-order matrices, denoted $\mathcal{M}^{(n)}$ and defined as the set of matrices of the following form: \begin{equation}M = PLPR\end{equation} where $P$ is an $n \times n$ permutation matrix (orthogonal matrix), and $L, R$ are block-diagonal matrices. Let's introduce them one by one.

Permutation Matrix

The effect achieved by the permutation matrix $P$ is to permute the vector $[x_1, x_2, \cdots, x_n]$ into a new vector: \begin{equation}[x_1, x_{1+m}, \cdots, x_{1+(m-1)m}, x_2, x_{2+m}, \cdots, x_{2+(m-1)m}, \cdots, x_m, x_{2m}, \cdots, x_n]\end{equation} Of course, writing it this way might still be confusing, but in fact, the implementation in code is very simple:

Px = x.reshape(m, m).transpose().reshape(n)

As shown in the figure below:

Schematic of the permutation matrix P

Readers who have previously done CV work might find this operation familiar; it is actually the "Shuffle" operation in ShuffleNet. This combined operation of first reshaping a vector, then transposing it, and finally reshaping it back, creates a "pseudo-Shuffle" effect. It can also be viewed as an $m$-ary "bit-reversal permutation." Obviously, doing this operation twice will restore the vector to its original state, so we have $P^2=I$, hence $P^{-1}=P^{\top}=P$.
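As a quick check of these properties, here is a small NumPy sketch (the dense construction of $P$ below is my own, for illustration only):

import numpy as np

m = 4
n = m * m
x = np.arange(n)

# The permutation implemented by reshape + transpose
Px = x.reshape(m, m).transpose().reshape(n)

# An explicit dense P with the same action: row i of P picks out entry perm[i] of x
perm = np.arange(n).reshape(m, m).transpose().reshape(n)
P = np.eye(n)[perm]

print(np.array_equal(P @ x, Px))          # True
print(np.array_equal(P @ P, np.eye(n)))   # True: P^2 = I, hence P^{-1} = P^T = P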

Block Diagonal

After discussing $P$, let's talk about $L$ and $R$. They are also $n \times n$ matrices, but they are block-diagonal matrices consisting of $m$ diagonal blocks, each of size $m \times m$, as shown below:

Monarch matrix form M=PLPR

When $n$ is large enough, the number of zeros in $L$ and $R$ dominates, so $L$ and $R$ are sparse matrices. This means the Monarch matrix is a matrix decomposition form with sparse characteristics. Since $P$ is fixed, the variable elements in $PLPR$ come from the non-zero elements of $L$ and $R$. Therefore, although $M$ is an $n \times n$ matrix, its actual number of free parameters does not exceed $2m^3 = 2n^{1.5}$. From the exponent $1.5$, we can glimpse the intention of the Monarch matrix: via a Monarch approximation, it hopes to reduce operations that originally required quadratic complexity down to $n^{1.5}$ complexity.

Efficiency Analysis

Can the Monarch matrix achieve this goal? In other words, can the Monarch matrix meet the "practical" standard mentioned earlier? We will discuss expressive power later; first, let's look at computational efficiency.

For example, in "matrix-vector" multiplication, the standard complexity is $\mathcal{O}(n^2)$. But for a Monarch matrix, we have $Mx = P(L(P(Rx)))$. Since multiplying by $P$ is just a simple reshape and transpose, it consumes almost no computation. The main computational load comes from $L$ or $R$ multiplying a vector. Due to the block-diagonal characteristics of $L$ and $R$, we can divide the vector into $m$ groups, thereby transforming the problem into $m$ instances of $m \times m$ matrices multiplying $m$-dimensional vectors. The total complexity is $2m \times \mathcal{O}(m^2) = \mathcal{O}(2n^{1.5})$, which is lower than $\mathcal{O}(n^2)$.

Another example is inversion. Consider $M^{-1}x$. The standard complexity for inverting an $n$-order matrix is $\mathcal{O}(n^3)$. But for a Monarch matrix, we have $M^{-1} x = R^{-1}PL^{-1}P x$. The main computational load comes from $L^{-1}$, $R^{-1}$ and the corresponding "matrix-vector" multiplications. Since both $L$ and $R$ are block-diagonal matrices, we only need to invert each block matrix on the diagonal separately. This means there are $2m$ inversions of $m \times m$ matrices. The complexity is $2m \times \mathcal{O}(m^3) = \mathcal{O}(2n^2)$, which is also lower than the standard $\mathcal{O}(n^3)$. It is also possible to write out $M^{-1}$ individually, but that requires utilizing the identity in Eq. \eqref{eq:high-m-lr} which we will discuss later.
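Continuing the previous sketch (again only an illustration of mine, not the paper's code), the block-wise inversion might look like this:

# Reusing monarch_matvec, L3, R3, x from the sketch above: M^{-1} x only needs
# 2m inversions of m x m blocks plus a few cheap permutations
def monarch_inv_matvec(L3, R3, x):
    m = L3.shape[0]
    n = m * m
    Linv = np.linalg.inv(L3)   # batched inversion of the m diagonal blocks of L
    Rinv = np.linalg.inv(R3)   # likewise for R
    y = x.reshape(m, m).T.reshape(n)                                  # P x
    y = np.einsum('kab,kb->ka', Linv, y.reshape(m, m)).reshape(n)     # L^{-1} P x
    y = y.reshape(m, m).T.reshape(n)                                  # P L^{-1} P x
    return np.einsum('kab,kb->ka', Rinv, y.reshape(m, m)).reshape(n)  # R^{-1} P L^{-1} P x

print(np.allclose(monarch_inv_matvec(L3, R3, monarch_matvec(L3, R3, x)), x))   # True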

So the conclusion is that, because the $P$ multiplication takes almost no computation and $L, R$ are block-diagonal, operations related to an $n$-order Monarch matrix can basically be transformed into $2m$ independent operations on $m \times m$ matrices, thereby reducing the total computational complexity. At least regarding computational efficiency, there is no problem with Monarch matrices. Moreover, since the non-zero elements of $L$ and $R$ themselves have a square structure, they are very convenient to implement and can fully utilize the GPU for calculation without causing unnecessary waste.

Monarch Decomposition

After confirming the validity of Monarch matrices, a key application question is: given any matrix $A$ of order $n=m^2$, how do we find its Monarch approximation? Similar to SVD, we define the following optimization problem: \begin{equation}\mathop{\text{argmin}}_{M\in\mathcal{M}^{(n)}} \Vert A - M\Vert_F^2\end{equation} Very fortunately, there is an algorithm for solving this problem with a complexity not exceeding $\mathcal{O}(n^{2.5})$, which is even more efficient than SVD's $\mathcal{O}(n^3)$.

High-Dimensional Arrays

A key step in understanding this algorithm is to transform the matrices and vectors related to Monarch into a higher-dimensional array form. Specifically, the Monarch matrix $M$ is originally a two-dimensional array, where each element is denoted as $M_{i,j}$, representing the element at the $i$-th row and $j$-th column. Now, based on the characteristics of block matrices, we will equivalently represent it as a four-dimensional array, where each element is labeled $M_{i,j,k,l}$, representing the element at the $i$-th large row, $j$-th small row, $k$-th large column, and $l$-th small column, as shown in the figure below:

Viewing Monarch-related matrices/vectors as high-dimensional arrays

Although it sounds tedious to explain, the code is just one line:

M.reshape(m, m, m, m)

Similarly, an $n$-dimensional (column) vector $x$ is converted into a two-dimensional array of $m \times m$. The code is also one line: x.reshape(m, m). Naturally, $L$ and $R$ are represented as three-dimensional arrays of $m \times m \times m$. For example, $L_{i,j,k}$ represents the element in the $i$-th block, $j$-th small row, and $k$-th small column. This is already the most efficient way to store $L$ and $R$, but for unified processing, we can also use the Kronecker delta symbol to expand them to four dimensions, such as $L_{i,j,k,l} = \delta_{i,k}L_{i,j,l}$ and $R_{i,j,k,l} = \delta_{i,k}R_{i,j,l}$.
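Here is a small NumPy check of these equivalent representations (the random block-diagonal matrix below is my own construction, purely for illustration):

import numpy as np

m = 3
n = m * m
L3 = np.random.randn(m, m, m)    # 3D form: L3[i, j, k] = block i, small row j, small column k

# The corresponding 2D block-diagonal matrix
L2d = np.zeros((n, n))
for i in range(m):
    L2d[i * m:(i + 1) * m, i * m:(i + 1) * m] = L3[i]

# 4D form via reshape, and via the Kronecker-delta expansion L_{i,j,k,l} = delta_{i,k} L_{i,j,l}
L4 = L2d.reshape(m, m, m, m)
L4_delta = np.einsum('ik,ijl->ijkl', np.eye(m), L3)
print(np.allclose(L4, L4_delta))   # True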

New Identity

Next, we will derive a new relation between $M$ and $L, R$. First, it can be proven that in the two-dimensional representation, the multiplication of matrix $P$ and vector $x$ becomes simpler; the result is just the transpose of $x$, i.e., $(Px)_{i,j} = x_{j,i}$. Therefore, we have $(PR)_{i,j,k,l} = R_{j,i,k,l} = \delta_{j,k}R_{j,i,l}$. Then, when multiplying two matrices, under the four-dimensional representation, there are two summation indices, so: \begin{equation}(L P R)_{\alpha,\beta,k,l} = \sum_{i,j} L_{\alpha,\beta,i,j}(PR)_{i,j,k,l} = \sum_{i,j} \delta_{\alpha, i} L_{\alpha,\beta,j}\delta_{j,k}R_{j,i,l} = L_{\alpha,\beta,k}R_{k,\alpha,l}\end{equation} Finally, we have $(P L P R)_{\alpha,\beta,k,l} = L_{\beta,\alpha,k}R_{k,\beta,l}$. Replacing $\alpha, \beta$ back with $i, j$, we get $(P L P R)_{i,j,k,l} = L_{j,i,k}R_{k,j,l}$. Since $M=PLPR$, we have: \begin{equation}M_{i,j,k,l} = L_{j,i,k}R_{k,j,l}\label{eq:high-m-lr}\end{equation} From this identity, it can be seen that when we fix a pair $(j,k)$, the left side is a submatrix and the right side is the outer product of two vectors. This means that if we want to find a Monarch approximation for matrix $A$, we only need to convert $A$ into a four-dimensional array in the same way and fix a pair $(j,k)$. Then the problem turns into finding a "rank-1 approximation" for the corresponding submatrix! In other words, with this identity, finding a Monarch approximation for matrix $A$ can be transformed into finding "rank-1 approximations" for $m^2$ submatrices. This can be completed using SVD, each with a complexity of no more than $\mathcal{O}(m^3)$, so the total complexity does not exceed $m^2 \times \mathcal{O}(m^3) = \mathcal{O}(n^{2.5})$.
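Before moving on to the reference implementation, here is a quick numerical check of Eq. \eqref{eq:high-m-lr} (with my own variable names): build $M = PLPR$ densely from random blocks and compare its four-dimensional view against the right-hand side.

import numpy as np

m = 3
n = m * m
L3, R3 = np.random.randn(m, m, m), np.random.randn(m, m, m)

def dense_blockdiag(B):
    X = np.zeros((n, n))
    for i in range(m):
        X[i * m:(i + 1) * m, i * m:(i + 1) * m] = B[i]
    return X

P = np.eye(n)[np.arange(n).reshape(m, m).T.reshape(n)]
M4 = (P @ dense_blockdiag(L3) @ P @ dense_blockdiag(R3)).reshape(m, m, m, m)

# Right-hand side of the identity: M[i, j, k, l] = L[j, i, k] * R[k, j, l]
rhs = np.einsum('jik,kjl->ijkl', L3, R3)
print(np.allclose(M4, rhs))   # True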

Reference Implementation

A reference implementation written by the author using NumPy is as follows:

import numpy as np

def monarch_factorize(A):
    # n = m**2
    m = int(np.sqrt(A.shape[0]))
    # 4D view with M[j, k, i, l] = A[i*m+j, k*m+l]: each M[j, k] is the
    # m x m submatrix that needs a rank-1 approximation
    M = A.reshape(m, m, m, m).transpose(1, 2, 0, 3)
    # Batched SVD over the last two axes (V is returned as V^T, i.e. numpy's Vh)
    U, S, V = np.linalg.svd(M)
    # The rank-1 approximation gives the components for L and R:
    # L[j, i, k] = sqrt(S[j, k, 0]) * U[j, k, i, 0]
    # R[k, j, l] = sqrt(S[j, k, 0]) * V[j, k, 0, l]
    L = (U[:, :, :, 0] * S[:, :, :1]**0.5).transpose(0, 2, 1)
    R = (V[:, :, 0] * S[..., :1]**0.5).transpose(1, 0, 2)
    return L, R

def convert_3D_to_2D(LR):
    # Expand the 3D block representation (m, m, m) into the dense n x n block-diagonal matrix
    m = LR.shape[0]
    n = m**2
    X = np.zeros((m, m, m, m))
    for i in range(m):
        X[i, i] += LR[i]
    return X.transpose(0, 2, 1, 3).reshape(n, n)

m = 8
n = m**2
# Random test matrix with roughly 20% non-zero entries
A = np.where(np.random.rand(n, n) > 0.8, np.random.randn(n, n), 0)

L, R = monarch_factorize(A)
L_2d = convert_3D_to_2D(L)
R_2d = convert_3D_to_2D(R)

# Left-multiplying by P permutes rows; implemented via reshape and transpose.
# Construct dense PL and PR so that M = (PL)(PR) can be compared against A.
PL = L_2d.reshape(m, m, n).transpose(1, 0, 2).reshape(n, n)
PR = R_2d.reshape(m, m, n).transpose(1, 0, 2).reshape(n, n)

U, S, V = np.linalg.svd(A)

print('Monarch error:', np.square(A - PL.dot(PR)).mean())
print('Low-Rank error:', np.square(A - (U[:, :m] * S[:m]).dot(V[:m])).mean())

I briefly compared the rank-$m$ approximation obtained via SVD (where the parameter count of the low-rank approximation is equivalent to that of the Monarch approximation) and found that for completely dense matrices, the mean squared error of the rank-$m$ approximation is often better than the Monarch approximation (but not by much). This is expected because the Monarch approximation algorithm essentially acts as a customized version of SVD. However, if the matrix to be approximated is sparse, the error of the Monarch approximation is often superior, and the sparser the matrix, the more superior it becomes.

Monarch Extension

So far, we have assumed that the matrices discussed are $n$-order square matrices and that $n=m^2$ is a perfect square. While the square-matrix condition might be acceptable, the requirement that $n$ be a perfect square is ultimately too restrictive. Therefore, it is necessary to extend the concept of the Monarch matrix at least to orders $n$ that are not perfect squares.

Non-perfect Square Order

To this end, let's first introduce some notation. Suppose $b$ is a divisor of $n$. $\mathcal{BD}^{(b,n)}$ denotes the set of all block-diagonal matrices of size $n \times n$, where each block is a $b \times b$ submatrix. Clearly, this is a generalization of the previous $L$ and $R$; according to this notation, we can write $L, R \in \mathcal{BD}^{(\sqrt{n},n)}$. In addition, we must generalize the permutation matrix $P$. As previously stated, the implementation of $P$ is Px = x.reshape(m, m).transpose().reshape(n). Now we generalize it to Px = x.reshape(n // b, b).transpose().reshape(n), denoted as $P_{(n/b, b)}$.

With these notations, we can define a general Monarch matrix (referencing the appendix of the original paper): \begin{equation}\mathcal{M}^{(b,n)} = \Bigg\{M = P_{(b, n/b)} L P_{(n/b, b)} R\,\Bigg|\, L\in\mathcal{BD}^{(n/b, n)}, R\in\mathcal{BD}^{(b,n)} \Bigg\}\end{equation} The diagram is as follows:

Extending Monarch matrices to square matrices whose order is not a perfect square

The Monarch matrix defined previously can simply be denoted here as $\mathcal{M}^{(n)} = \mathcal{M}^{(\sqrt{n},n)}$. It is not difficult to calculate that $L$ has at most $n^2/b$ non-zero elements and $R$ has at most $nb$. Added together, this gives $n^2/b + nb$, which reaches its minimum at $b=\sqrt{n}$. Thus $b=\sqrt{n}$ (when it is an integer) is among the sparsest choices.
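Following this definition, a minimal NumPy sketch of the general "matrix-vector" product might look as follows (the helper names are my own; L3 holds the $b$ diagonal blocks of $L$, each of size $n/b$, and R3 holds the $n/b$ diagonal blocks of $R$, each of size $b$):

import numpy as np

def general_monarch_matvec(L3, R3, x):
    # L3: (b, n//b, n//b); R3: (n//b, b, b); x: length n
    b, q = L3.shape[0], R3.shape[0]                               # q = n // b
    n = b * q
    y = np.einsum('kij,kj->ki', R3, x.reshape(q, b)).reshape(n)   # R x
    y = y.reshape(q, b).T.reshape(n)                              # P_{(n/b, b)} R x
    y = np.einsum('kij,kj->ki', L3, y.reshape(b, q)).reshape(n)   # L P_{(n/b, b)} R x
    return y.reshape(b, q).T.reshape(n)                           # P_{(b, n/b)} L P_{(n/b, b)} R x

b, q = 3, 5                                                       # n = 15, not a perfect square
n = b * q
L3, R3, x = np.random.randn(b, q, q), np.random.randn(q, b, b), np.random.randn(n)
print(general_monarch_matvec(L3, R3, x).shape)                    # (15,)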

Forms Only

Readers might be confused: why distinguish between $L \in \mathcal{BD}^{(n/b, n)}$ and $R \in \mathcal{BD}^{(b,n)}$? Couldn't we simply use the same block size for both? In fact, this design ensures that Eq. \eqref{eq:high-m-lr} still holds under the high-dimensional representation, which allows a similar decomposition algorithm to be derived (readers are encouraged to work this out themselves) and theoretically guarantees its expressive power.

If we don't care about these theoretical details and only wish to construct a matrix parameterization method with sparse characteristics, then the Monarch matrix can be generalized even more flexibly, for example: \begin{equation}M = \left(\prod_{i=1}^k P_i B_i\right)P_0\end{equation} where $B_1, B_2, \cdots, B_k \in \mathcal{BD}^{(b,n)}$, and $P_0, P_1, \cdots, P_k$ are all permutation matrices. Multiplying by $P_0$ at the end is for the sake of symmetry and is not mandatory. If you feel it's necessary, you can even choose different $b$ values for each $B_i$, i.e., $B_i \in \mathcal{BD}^{(b_i, n)}$.

Furthermore, you can combine low-rank decomposition forms and generalize to non-square block matrices, as shown below:

A Monarch-like matrix parameterization combining low-rankness and sparsity

Based on this analogy, we can further extend the concept of the Monarch matrix to non-square matrices. In short, if we only need a sparsified structural matrix similar to the Monarch matrix and don't care about theoretical details, the results are limited only by our imagination.

Application Examples

Currently, the most significant feature of Monarch matrices is that they are friendly to matrix multiplication. Thus, their primary use is nothing more than replacing parameter matrices in fully connected layers to improve the efficiency of those layers. This is also the main content of the experimental section in the original paper.

We can divide this into two categories: "pre-processing" and "post-processing." "Pre-processing" means switching the parameter matrices of fully connected layers to Monarch matrices before training the model. This speeds up both training and inference, and the trained model fits the Monarch structure best. "Post-processing" starts from an already trained model: we use Monarch decomposition to find a Monarch approximation of the fully connected layers' parameter matrices and replace the original matrices with it, performing light fine-tuning afterwards if necessary. The goal is to improve the fine-tuning or inference efficiency of the original model.
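As a concrete illustration of what "replacing the parameter matrix of a fully connected layer" could look like, here is a minimal NumPy sketch of mine (a batched version of the earlier matrix-vector product, not the paper's actual implementation):

import numpy as np

def monarch_linear(X, L3, R3):
    # X: (batch, n) inputs; L3, R3: (m, m, m) blocks of L and R, with n = m * m
    bs, n = X.shape
    m = L3.shape[0]
    Y = np.einsum('kab,qkb->qka', R3, X.reshape(bs, m, m))   # R x, block by block
    Y = Y.transpose(0, 2, 1)                                 # P (R x)
    Y = np.einsum('kab,qkb->qka', L3, Y)                     # L P R x
    return Y.transpose(0, 2, 1).reshape(bs, n)               # P L P R x = M x

# A dense layer of width n = 64 stores n^2 = 4096 weights; here only 2 m^3 = 1024
m, batch_size = 8, 4
n = m * m
L3, R3 = np.random.randn(m, m, m), np.random.randn(m, m, m)
X = np.random.randn(batch_size, n)
print(monarch_linear(X, L3, R3).shape)                       # (4, 64)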

Besides replacing fully connected layers, "Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture" discusses an even more extreme approach—directly replacing the Attention layer as a Token-Mixer module. However, in my view, Monarch-Mixer is not very elegant because, like MLP-Mixer, it replaces the Attention matrix with a learnable matrix. In Monarch-Mixer, it is replaced by a Monarch matrix. This pattern learns static attention, and I am personally skeptical of its generalizability.

Finally, for today’s LLMs, Monarch matrices can also be used to construct Parameter-Efficient Fine-Tuning (PEFT) schemes. We know that LoRA was designed starting from low-rank decomposition. Since low-rankness and sparsity are two parallel routes, shouldn't Monarch matrices, as a representative of sparsity, also be usable to construct a PEFT scheme? I Googled it, and someone has indeed done this; the paper title is "MoRe Fine-Tuning with 10x Fewer Parameters." It is quite fresh, being one of the ICML 2024 Workshop papers.

The King of Butterflies

Finally, let's briefly discuss the fitting capability of the Monarch matrix. "Monarch" refers to a king or sovereign, taken from the "Monarch Butterfly." It was named this way because it is aimed at the earlier "Butterfly Matrix."

What is a Butterfly matrix? Explaining this is actually quite difficult. A Butterfly matrix is a product of a series of ($\log_2 n$) Butterfly factor matrices. A Butterfly factor matrix is a block-diagonal matrix where the matrices on the diagonal are called Butterfly factors (without the word "matrix"). A Butterfly factor is a $2 \times 2$ block matrix where each block is a diagonal matrix (end of nested explanation). As shown below:

Schematic of a Butterfly matrix

For the accurate definition of a Butterfly matrix, please refer to the paper; I won't expand on it here. The name "Butterfly" comes from the author thinking the shape of each Butterfly factor looks like a butterfly—whether it does or not is up to one's own judgment, but the author thought so. From a literal standpoint, a "Monarch Butterfly" is superior to a "Butterfly" (after all, it is the "Emperor"). This implies that the Monarch matrix is stronger than the Butterfly matrix. Indeed, the Monarch paper appendix proves that regardless of what $b$ is, $\mathcal{M}^{(b,n)}$ can cover all $n$-order Butterfly matrices. Moreover, when $n > 512$, $\mathcal{M}^{(b,n)}$ is strictly larger than the set of all $n$-order Butterfly matrices. In other words, whatever a Butterfly matrix can do, a Monarch matrix can also do, but the reverse is not necessarily true.

We can also intuitively perceive the expressive power of the Monarch matrix from "matrix-vector" multiplication complexity. We know that the standard complexity for multiplying an $n \times n$ matrix with an $n$-dimensional vector is $\mathcal{O}(n^2)$. However, for certain structured matrices, it can be lower; for example, a Fourier transform can achieve $\mathcal{O}(n \log n)$, a Butterfly matrix is also $\mathcal{O}(n \log n)$, and a Monarch matrix is $\mathcal{O}(n^{1.5})$. Therefore, Monarch matrices "should" be no weaker than Butterfly matrices. Of course, Butterfly matrices have their advantages too; for example, their inverse and determinant are easier to calculate, which is more convenient for scenarios like Flow models that require inversion and determinants.

Summary

This article introduced the Monarch matrix, a family of matrices proposed by Tri Dao a few years ago that can be decomposed into products of permutation matrices and sparse matrices. It possesses the characteristic of high computational efficiency (as we all know, Tri Dao is synonymous with high performance). It can be used to speed up fully connected layers, construct parameter-efficient fine-tuning methods, and more.