By 苏剑林 | June 20, 2024
In the previous two articles, "Revisiting SSM (I): Linear Systems and the HiPPO Matrix" and "Revisiting SSM (II): Some Remaining Issues of HiPPO", we introduced the core ideas and derivation of HiPPO—using orthogonal function bases to approximate continuously updating functions in real-time. The dynamics of the fitting coefficients can be expressed as a linear ODE system, and for specific bases and approximation methods, we can precisely calculate the key matrices of the linear system. Additionally, we discussed the discretization and properties of HiPPO. These contents laid the theoretical foundation for subsequent SSM (State Space Model) work.
Next, we will introduce the follow-up application paper, "Efficiently Modeling Long Sequences with Structured State Spaces" (abbreviated as S4). It utilizes the derivation results of HiPPO as a basic tool for sequence modeling and explores efficient computation and training methods from a new perspective. Finally, it validates its effectiveness on many long-sequence modeling tasks, becoming one of the representative works in the revival of SSMs and RNNs.
The sequence modeling framework used by S4 is the following linear ODE system:
\begin{equation}\begin{aligned} x'(t) =&\, A x(t) + B u(t) \\ y(t) =&\, C^* x(t) + D u(t) \end{aligned}\end{equation}Here $u, y, D \in \mathbb{R}$; $x \in \mathbb{R}^d$; $A \in \mathbb{R}^{d \times d}$; $B, C \in \mathbb{R}^{d \times 1}$, and ${}^*$ denotes the conjugate transpose operation (if it is a real matrix, it is simply the transpose). Since the complete model usually includes residual structures, the last term $D u(t)$ can be integrated into the residual. Therefore, we can directly assume $D=0$ to slightly simplify the form without reducing the model's capability.
This system possesses Similarity Invariance. If $\tilde{A}$ is a matrix similar to $A$, i.e., $A = P^{-1}\tilde{A}P$, then by substituting and rearranging, we get:
\begin{equation}\begin{aligned} Px'(t) =&\, \tilde{A} Px(t) + PB u(t) \\ y(t) =&\, ((P^{-1})^* C)^* P x(t) \end{aligned}\end{equation}By treating $Px(t)$ as a whole and replacing the original $x(t)$, the changes in the new system are $(A, B, C) \to (\tilde{A}, PB, (P^{-1})^*C)$, but the output remains completely unchanged. This means that if there exists a matrix $\tilde{A}$ similar to $A$ that makes computation simpler, we can analyze everything within the $\tilde{A}$ framework without changing the results. This is the core strategy for the series of analyses that follow.
Specifically, S4 selects the matrix $A$ as the HiPPO-LegS matrix, namely:
\begin{equation}A_{n,k} = -\left\{\begin{array}{l}\sqrt{(2n+1)(2k+1)}, &k < n \\ n+1, &k = n \\ 0, &k > n\end{array}\right.\end{equation}The peculiarity of this choice is that the ODE we previously derived for LegS was in the form $x'(t) = \frac{A}{t} x(t) + \frac{B}{t} u(t)$, while the ODE for LegT was $x'(t) = A x(t) + B u(t)$. Now, we are using the LegT-style ODE paired with the LegS $A$ matrix. This leads to the question: what effect does this combination have? For example, is its memory of history still complete and "equitable" like LegS?
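Before answering that, for concreteness, here is a minimal NumPy sketch of this matrix (the helper name hippo_legs is ours, not from the paper); we will reuse it in later sketches.

```python
import numpy as np

def hippo_legs(d):
    """HiPPO-LegS matrix A (d x d), following the piecewise definition above."""
    n = np.arange(d)
    # k < n entries: -sqrt((2n+1)(2k+1)); keep only the strictly lower triangle
    A = np.tril(-np.sqrt((2 * n[:, None] + 1) * (2 * n[None, :] + 1)), k=-1)
    # k = n entries: -(n+1)
    A += np.diag(-(n + 1.0))
    return A
```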
The answer is no—the ODE system selected for S4 has a memory of history that decays exponentially. We can understand this from two angles.
The first angle starts from the transformation discussed in "Revisiting SSM (II): Some Remaining Issues of HiPPO". The LegS-type ODE can be equivalently written as:
\begin{equation}Ax(t) + Bu(t) = t x'(t) = \frac{d}{d\ln t} x(t)\end{equation}So, setting $\tau = \ln t$ transforms the LegS-type ODE into a LegT-type ODE with time variable $\tau$, which is exactly what S4 uses. We know that LegS treats every historical point equally, but this is under the premise that the input is $u(t) = u(e^{\tau})$. However, S4's ODE uses $u(\tau)$ directly as input. If we perform uniform discretization on $\tau$, the result is that the weights for each point are not equal—assuming $t \in [0, T]$, writing it in terms of probability density gives $dt/T = \rho(\tau)d\tau$, meaning $\rho(\tau) = e^{\tau}/T$. Thus, the weight is an exponential function of $\tau$: the more recent the history, the greater its weight.
The second angle requires a bit more linear algebra. As we mentioned in "Revisiting SSM (II): Some Remaining Issues of HiPPO", the HiPPO-LegS matrix $A$ can theoretically be diagonalized, and its eigenvalues are $[-1, -2, -3, \cdots]$. Thus, there exists an invertible matrix $P$ such that $A = P^{-1}\Lambda P$, where $\Lambda = \text{diag}(-1, -2, \cdots, -d)$. By similarity invariance, the original system is equivalent to the new system:
\begin{equation}\begin{aligned} x'(t) =&\, \Lambda x(t) + PB u(t) \\ y(t) =&\, C^* P^{-1} x(t) \end{aligned}\end{equation}After discretization (taking Forward Euler as an example):
\begin{equation}x(t+\epsilon) = (I + \epsilon\Lambda) x(t) + \epsilon P B u(t)\end{equation}Here, $I + \epsilon\Lambda$ is a diagonal matrix whose entries $1-\epsilon, 1-2\epsilon, \dots, 1-d\epsilon$ are all less than 1 (and positive for sufficiently small $\epsilon$). This means that with each iteration, the historical information is multiplied by a number less than 1, which compounds over many steps into an exponential decay.
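As a quick numerical illustration of this decay (a sketch under our own choice of $d$ and $\epsilon$, tracking only the $(I + \epsilon\Lambda)$ factor):

```python
import numpy as np

# Feed a unit impulse at step 0 into the forward-Euler recursion of the diagonalized system
# x <- (I + eps*Lambda) x, with Lambda = diag(-1, ..., -d) as stated above.
d, eps, steps = 8, 0.01, 500
lam = -np.arange(1, d + 1)
x = np.ones(d)                 # stand-in for the state right after the impulse
for _ in range(steps):
    x = (1 + eps * lam) * x    # component i is multiplied by 1 - eps*i < 1 each step
# every component has decayed roughly like (1 - eps*i)**steps, i.e. exponentially
```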
Although exponential decay might seem less elegant than LegS's equitable treatment of history, there is no free lunch. For a state $x(t)$ of fixed size, as the memory interval grows larger, the LegS approach of treating history equally can cause the representation of history to become increasingly blurred. In scenarios following the "recency effect" (near is clear, far is fuzzy), this can be counterproductive. Furthermore, the right-hand side of the S4-style ODE does not explicitly contain time $t$, which helps improve training efficiency.
Once we have an understanding of the memory properties of the S4 ODE, we can move to the next step. To process discrete sequences in practice, we first need to discretize the system. In the previous article, we provided two discretization formats with higher precision. One is the bilinear form (Tustin transform):
\begin{equation}x_{k+1} = (I - \epsilon A/2)^{-1}[(I + \epsilon A/2) x_k + \epsilon B u_k] \end{equation}It has second-order accuracy. S4 adopts this discretization format, and it is the format we will focus on in this article. The other format is based on the exact solution for constant-input ODEs:
\begin{equation}x_{k+1} = e^{\epsilon A} x_k + A^{-1} (e^{\epsilon A} - I) B u_k\end{equation}The author's subsequent works, including Mamba, use this format. In that case, $A$ is generally assumed to be a diagonal matrix, because for the LegS matrix $A$, the matrix exponential is not friendly to compute.
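A minimal sketch of this exact-solution discretization, using SciPy's matrix exponential (the function name discretize_zoh is ours):

```python
import numpy as np
from scipy.linalg import expm

def discretize_zoh(A, B, dt):
    """x_{k+1} = exp(dt*A) x_k + A^{-1} (exp(dt*A) - I) B u_k, assuming A is invertible."""
    Ad = expm(dt * A)
    Bd = np.linalg.solve(A, (Ad - np.eye(A.shape[0])) @ B)
    return Ad, Bd
```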
Now we define:
\begin{equation}\bar{A}=(I - \epsilon A/2)^{-1}(I + \epsilon A/2),\quad\bar{B}=\epsilon(I - \epsilon A/2)^{-1}B,\quad\bar{C}=C\end{equation}Then we obtain a linear RNN:
\begin{equation}\begin{aligned} x_{k+1} =&\, \bar{A} x_k + \bar{B} u_k \\ y_{k+1} =&\, \bar{C}^* x_{k+1} \\ \end{aligned}\label{eq:s4-r}\end{equation}Where $\epsilon > 0$ is the discretization step size, which is a manually selected hyperparameter.
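Putting the bilinear discretization and the recurrence together, here is a rough NumPy sketch (helper names are ours; $B$ and $C$ are treated as $(d,1)$ arrays as in the formulas above):

```python
import numpy as np

def discretize_bilinear(A, B, dt):
    """A_bar = (I - dt*A/2)^{-1}(I + dt*A/2), B_bar = dt (I - dt*A/2)^{-1} B."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - dt / 2 * A)
    return inv @ (I + dt / 2 * A), dt * inv @ B

def ssm_recurrent(A_bar, B_bar, C, u):
    """Linear RNN: x_{k+1} = A_bar x_k + B_bar u_k, y_{k+1} = C^* x_{k+1}, with x_0 = 0."""
    x = np.zeros(A_bar.shape[0], dtype=A_bar.dtype)
    ys = []
    for uk in u:
        x = A_bar @ x + B_bar[:, 0] * uk
        ys.append(np.conj(C[:, 0]) @ x)
    return np.array(ys)
```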
In the previous article, we also mentioned that the HiPPO-LegS matrix $A$ has efficient computational properties. Specifically, multiplying $A$ or $\bar{A}$ by a vector $x$ allows for an algorithm with $O(d)$ complexity instead of the usual $O(d^2)$. However, this only means that the recursive calculation of equation \eqref{eq:s4-r} is more efficient than a typical RNN. For efficient training, simple recursion is not enough; we need to explore parallel computation methods.
There are two paths for parallelizing linear RNNs: one is the Prefix Sum approach (as mentioned in "Recurrent Neural Networks Again? Google's New Work Revival"), which utilizes Associative Scan algorithms like Upper/Lower, Odd/Even, or Ladner-Fischer (refer to "Prefix Sums and Their Applications"). The other is transforming it into a convolution between matrix and vector sequences, using the Fast Fourier Transform (FFT) for acceleration—this is S4's approach. Regardless of the method, they face the same bottleneck: the calculation of the power matrix $\bar{A}^k$.
Specifically, we usually set the initial state $x_0$ to 0, allowing us to write:
\begin{equation}\begin{aligned} y_1 =&\, \bar{C}^*\bar{B}u_0\\ y_2 =&\, \bar{C}^*(\bar{A}x_1 + \bar{B}u_1) = \bar{C}^*\bar{A}\bar{B}u_0 + \bar{C}^*\bar{B}u_1\\ y_3 =&\, \bar{C}^*(\bar{A}x_2 + \bar{B}u_2) = \bar{C}^*\bar{A}^2 \bar{B}u_0 + \bar{C}^*\bar{A}\bar{B}u_1 + \bar{C}^*\bar{B}u_2\\[5pt] \vdots \\ y_L =&\, \bar{C}^*(\bar{A} x_{L-1}+\bar{B}u_{L-1}) = \sum_{k=0}^{L-1} \bar{C}^*\bar{A}^k \bar{B}u_{L-1-k} = \bar{K}_{< L} * u_{< L} \end{aligned}\end{equation}Where $*$ represents the convolution operation, and
\begin{equation}\bar{K}_k = \bar{C}^*\bar{A}^k\bar{B},\quad \bar{K}_{< L} = \big(\bar{K}_0,\bar{K}_1,\cdots,\bar{K}_{L-1}\big),\quad u_{< L} = (u_0,u_1,\cdots,u_{L-1})\end{equation}Note that according to the current convention, $\bar{C}^*\bar{A}^k \bar{B}$ and $u_k$ are scalars, so $\bar{K}_{< L}, u_{< L} \in \mathbb{R}^L$. We know that convolution can be converted into frequency domain multiplication via the (Discrete) Fourier Transform and then transformed back via the inverse. Its complexity is $O(L \log L)$, where $L$ is the sequence length. Although this complexity seems higher than the $O(L)$ of direct recursion, the Fourier Transform is highly parallelizable, making it much faster in practice.
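To make the kernel view concrete, here is a naive sketch that builds $\bar{K}_{< L}$ by repeated matrix-vector products and notes the convolution identity (continuing the hypothetical helpers above); this is the slow baseline that the rest of the article works to avoid:

```python
import numpy as np

def ssm_kernel_naive(A_bar, B_bar, C, L):
    """K_k = C^* A_bar^k B_bar for k = 0, ..., L-1, via repeated matrix-vector products."""
    x = B_bar[:, 0].copy()
    K = []
    for _ in range(L):
        K.append(np.conj(C[:, 0]) @ x)
        x = A_bar @ x
    return np.array(K)

# The convolution view: y = K_{<L} * u_{<L} should match the recurrent outputs, e.g.
#   np.allclose(np.convolve(ssm_kernel_naive(A_bar, B_bar, C, L), u)[:L],
#               ssm_recurrent(A_bar, B_bar, C, u))
```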
Thus, the problem now is how to efficiently calculate the convolution kernel $\bar{K}_{< L}$, which requires computing the power matrix $\bar{A}^k$. Calculating it by definition involves significant complexity. Of course, if we were only calculating $\bar{A}^k$, it wouldn't be a problem because $A$ is a constant matrix; once $\epsilon$ is chosen, $\bar{A}$ is also constant. No matter how hard the power is to compute, it could be pre-calculated and stored. However, $\bar{A}^k$ is just an intermediate step; we need $\bar{C}^*\bar{A}^k\bar{B}$, and S4 treats $\bar{C}, \bar{B}$ as trainable parameters. Therefore, $\bar{C}^*\bar{A}^k\bar{B}$ cannot be pre-calculated, and even pre-calculating $\bar{A}^k$ is not efficient enough.
Before proceeding further, let's insert the concept of generating functions. This is one of the foundational steps for subsequent efficient computation. For readers who are not very familiar with convolution and DFT, this can also serve as a conceptual guide to understanding how Fourier transforms accelerate convolution.
For a given sequence $a = (a_0, a_1, a_2, \cdots)$, its generating function works by treating each component as a coefficient of a power series:
\begin{equation}\mathcal{G}(z|a) = \sum_{k=0}^{\infty} a_k z^k\end{equation}If there are two sequences $a = (a_0, a_1, a_2, \cdots)$ and $b = (b_0, b_1, b_2, \cdots)$, then the product of their generating functions is:
\begin{equation}\mathcal{G}(z|a)\mathcal{G}(z|b) = \left(\sum_{k=0}^{\infty} a_k z^k\right)\left(\sum_{l=0}^{\infty} b_l z^l\right) = \sum_{k=0}^{\infty}\sum_{l=0}^{\infty}a_k b_l z^{k+l} = \sum_{l=0}^{\infty}\left(\sum_{k=0}^l a_k b_{l-k}\right) z^l \end{equation}Do you notice? The coefficient of $z^l$ in $\mathcal{G}(z|a)\mathcal{G}(z|b)$ is exactly the convolution of $a_{\leq l} = (a_0, \cdots, a_l)$ and $b_{\leq l} = (b_0, \cdots, b_l)$. If we have a way to quickly compute the values of the generating function and quickly extract its coefficients, we can transform convolution into generating functions, perform simple multiplication, and then extract the corresponding coefficients.
The Discrete Fourier Transform (DFT) is exactly such an approach to building generating functions. First, note that if we only need to perform convolution on the first $L$ terms of $a$ and $b$, the summation in the generating function doesn't need to go to infinity; capping it at $L-1$ is sufficient. For this requirement, DFT does not compute the generating function for all $z$, but selects specific $z = e^{-2i\pi l/L}, l = 0, 1, 2, \dots, L-1$:
\begin{equation}\hat{a}_l = \sum_{k=0}^{L-1} a_k \left(e^{-2i\pi l/L}\right)^k = \sum_{k=0}^{L-1} a_k e^{-2i\pi kl/L}\end{equation}The Inverse DFT (IDFT) to extract coefficients is:
\begin{equation}a_k = \frac{1}{L}\sum_{l=0}^{L-1} \hat{a}_l e^{2i\pi kl/L}\end{equation}Both DFT and IDFT can be efficiently computed using the Fast Fourier Transform (FFT), which is built into most numerical computing frameworks, so there is no efficiency issue. However, note that using the DFT to compute a convolution directly requires a minor tweak. Since $z = e^{-2i\pi l/L}$ satisfies $z^L = 1$, the powers $z^k$ and $z^{k+L}$ cannot be distinguished. When we multiply two $L$-term DFTs, powers $z^m$ with $m \geq L$ appear; they "wrap around" onto $z^{m-L}$ and mix with the lower-order terms, so after the IDFT each coefficient is the sum of two convolution terms (a circular convolution), which is not the standard linear convolution we want.
The solution is to change the $L$ in $e^{-2i\pi l/L}$ to $2L$ (while still summing over $L$ terms). This increases the period such that the multiplication results remain within a single period. The definition of DFT becomes:
\begin{equation}\hat{a}_l = \sum_{k=0}^{L-1} a_k e^{-i\pi kl/L}\end{equation}However, since most standard FFT functions do not support adjusting the period independently of the array length, the equivalent approach is to pad $(a_0, a_1, \dots, a_{L-1})$ with $L$ zeros to make it length $2L$, perform a standard DFT, multiply the results, do an IDFT, and take the first $L$ results.
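A small sketch of this zero-padding recipe for real-valued sequences (np.fft.fft pads with zeros when asked for a larger transform length):

```python
import numpy as np

def fft_linear_conv(a, b):
    """Linear convolution of two length-L real sequences via zero-padding to 2L and the FFT."""
    L = len(a)
    A = np.fft.fft(a, n=2 * L)            # zero-pad to length 2L, then DFT
    B = np.fft.fft(b, n=2 * L)
    return np.fft.ifft(A * B)[:L].real    # multiply, invert, keep the first L terms

# sanity check: np.allclose(fft_linear_conv(a, b), np.convolve(a, b)[:len(a)])
```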
For the convolution kernel $\bar{K}$, we have:
\begin{equation}\mathcal{G}(z|\bar{K}) = \sum_{k=0}^{\infty} \bar{C}^*\bar{A}^k \bar{B}z^k = \bar{C}^*\left(I - z\bar{A}\right)^{-1}\bar{B}\label{eq:k-gen}\end{equation}We discover that the generating function not only accelerates the computation of the convolution, but also transforms the originally complex calculation of power matrices $\bar{A}^k$ into the calculation of matrix inverses $\left(I - z\bar{A}\right)^{-1}$.
What kind of matrix $\bar{A}$ makes $\left(I - z\bar{A}\right)^{-1}$ easy to compute? First, a diagonal matrix is certainly fine. If $\bar{A}$ is diagonal, then $I - z\bar{A}$ is also diagonal, and its inverse is simply the inverse of each diagonal element. Second, if $\bar{A}$ can be diagonalized as $\bar{\Lambda}$, i.e., $\bar{A} = P^{-1}\bar{\Lambda} P$, then $\left(I - z\bar{A}\right)^{-1}$ is also easy to compute because:
\begin{equation}\left(I - z\bar{A}\right)^{-1} = \left(P^{-1}(I - z\bar{\Lambda})P\right)^{-1} = P^{-1}\left(I - z\bar{\Lambda}\right)^{-1} P\end{equation}Can $\bar{A}$ be diagonalized? This depends on whether $A$ can be diagonalized. If $A = P^{-1}\Lambda P$, according to similarity invariance, we can switch entirely to a new system where $A = \Lambda$. By definition, the new $\bar{A}$ would be:
\begin{equation}\begin{aligned} \bar{A}=&\,(I - \epsilon A/2)^{-1}(I + \epsilon A/2) \\ =&\,(I - \epsilon\Lambda/2)^{-1}(I + \epsilon\Lambda/2) \end{aligned}\end{equation}which is clearly a diagonal matrix.
Can $A$ be diagonalized? The answer is: theoretically yes, but practically no. Theoretically, almost all matrices can be diagonalized over the complex field, and we even gave the eigenvalues of the LegS matrix $A$ as $[-1, -2, -3, \cdots]$ in the previous article, so we already know what the diagonal matrix should look like. "Practically no" means the diagonalization is problematic numerically: numerical computation must respect limits on precision, memory, and time, and if any of these is exceeded, a theoretically feasible algorithm fails in practice.
For matrix $A$, the practical difficulty is that the matrix $P$ required to diagonalize $A$ suffers from numerical instability issues—specifically, issues caused by finite floating-point precision. In the original paper, the authors provide the analytical solution for $P$ without much explanation and then verify it, which is not conducive to reader understanding. Below, let's look at this from the perspective of eigenvector calculation.
Diagonalizing $A$ is equivalent to diagonalizing $-A$. Since the eigenvalues of $A$ are all negative, for simplicity, we consider the diagonalization of $-A$, which has $d$ distinct eigenvalues $\lambda = 1, 2, \dots, d$. The matrix $P$ required for diagonalization is the stack of its eigenvectors, so finding $P$ is essentially finding the eigenvectors. For a matrix with known eigenvalues, the direct method is to solve the equation $-Av = \lambda v$.
In the "Efficient Computation" section of the previous article, we already provided the result for the $n$-th component of $Av$:
\begin{equation}(Av)_n = n v_n -\sqrt{2n+1}\sum_{k=0}^n \sqrt{2k+1}v_k \end{equation}So $-Av = \lambda v$ means:
\begin{equation}\sqrt{2n+1}\sum_{k=0}^n \sqrt{2k+1}v_k - n v_n = \lambda v_n\end{equation}Let $S_n = \sum\limits_{k=0}^n \sqrt{2k+1}v_k$, then $\sqrt{2n+1}v_n = S_n - S_{n-1}$. Rearranging slightly gives:
\begin{equation}S_{n-1} = \frac{\lambda - n - 1}{\lambda + n}S_n\end{equation}Note that $-Av = \lambda v$ is an indeterminate equation, giving us some flexibility (eigenvectors are not unique). Since the maximum value of $n$ is $d-1$, we can set $S_{d-1} = 1$ and recurse backwards. When $n = \lambda - 1$ the factor $\lambda - n - 1$ vanishes, so $S_{\lambda-2} = 0$, and consequently $S_n = 0$ for all $n \leq \lambda - 2$. For $n \geq \lambda - 1$, we have:
\begin{equation}S_n = (-1)^{d-n-1}\frac{(d-\lambda)! (n+\lambda)!}{(d+\lambda-1)! (n-\lambda + 1)!}\end{equation}Since we want to prove the numerical instability of $P$, observing one eigenvector is enough. Let's take $n = \lambda = d/3$ (if $d$ is not a multiple of 3, just take the integer part; the conclusion holds). Then:
\begin{equation}|S_{d/3}| = \frac{\left(\frac{2d}{3}\right)! \left(\frac{2d}{3}\right)!}{\left(\frac{4d}{3}-1\right)!} \sim \mathcal{O}(\sqrt{d}\,2^{-4d/3})\end{equation}The final $\sim$ is obtained using Stirling's formula. From this result, we see that for the eigenvalue $d/3$, there is an exponential decay from $S_{d-1}$ down to $S_{d/3}$ (or an explosion in the other direction). Likewise, the components of the eigenvector $v_{d-1}$ down to $v_{d/3}$ exhibit similar decay. Within the finite precision of floating-point numbers, it is extremely difficult to accurately process such eigenvectors. Consequently, directly diagonalizing $A$ via matrix $P$ is numerically unstable.
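The instability is easy to observe numerically. A rough check (reusing the hypothetical hippo_legs helper from earlier) is to look at the condition number of the eigenvector matrix returned by a standard solver, which we expect to blow up rapidly with $d$, mirroring the decay of $S_n$ above:

```python
import numpy as np

for d in (8, 16, 32, 64):
    A = hippo_legs(d)                 # hypothetical helper defined earlier
    eigvals, P = np.linalg.eig(A)     # columns of P are numerically computed eigenvectors
    print(d, np.linalg.cond(P))       # condition number of the diagonalizing matrix
```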
Besides diagonal matrices, if $\bar{A}$ can be decomposed as a low-rank update, we can also simplify the calculation of $\left(I - z\bar{A}\right)^{-1}$. This is because we have the following Woodbury Identity:
\begin{equation}(I - UV^*)^{-1} = \sum_{k=0}^{\infty} (UV^*)^k = I + U\left(\sum_{k=0}^{\infty}(V^* U)^k\right)V^* = I + U(I - V^* U)^{-1} V^*\end{equation}Here $U, V \in \mathbb{R}^{d \times r}$. The derivation uses the identity $(UV^*)^k = U(V^* U)^{k-1}V^*$. If $d \gg r$, then calculating $(I - V^* U)^{-1}$ is theoretically much cheaper than $(I - UV^*)^{-1}$, thus accelerating computation. In particular, if $r=1$, then $(I - V^* U)^{-1}$ is just the inverse of a scalar.
However, we know that $A$ is a lower triangular matrix with non-zero diagonal elements, so it must be full rank. Combined with the conclusion of the previous section, $A$ is neither low-rank nor practically diagonalizable. So these don't apply, right? Wrong! Using the Woodbury Identity, we can derive its more general version:
\begin{equation}\begin{aligned} (M - UV^*)^{-1} =&\, (M(I - (M^{-1}U)V^*))^{-1} = (I - (M^{-1}U)V^*)^{-1}M^{-1} \\ =&\, (I + M^{-1}U(I - V^*M^{-1}U)^{-1} V^*)M^{-1} \\ =&\, M^{-1} + M^{-1}U(I - V^*M^{-1}U)^{-1} V^*M^{-1} \\ \end{aligned}\end{equation}This result tells us that if the inverse of $M$ is easy to compute, then adding or subtracting a low-rank matrix from $M$ still yields an easily computable inverse. And what kind of matrix has an easy-to-compute inverse? We go back to the previous answer: diagonal matrices. Thus, we can look for a way to express $A$ or $\bar{A}$ in the form "Diagonal + Low-Rank".
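As a sketch of how this pays off in the rank-1 case (helper name ours), solving $(D + uv^*)x = b$ with a diagonal $D$ costs only $O(d)$:

```python
import numpy as np

def solve_diag_plus_rank1(d_diag, u, v, b):
    """Solve (diag(d_diag) + u v^*) x = b in O(d) using the Woodbury identity above."""
    Dinv_b = b / d_diag
    Dinv_u = u / d_diag
    denom = 1.0 + np.vdot(v, Dinv_u)          # the scalar 1 + v^* D^{-1} u
    return Dinv_b - Dinv_u * (np.vdot(v, Dinv_b) / denom)
```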
In fact, if you look closely, the $A$ matrix itself has the shadows of "Diagonal + Low-Rank". In the previous article, we reformatted the definition of $A$ as:
\begin{equation}A_{n,k} = \left\{\begin{array}{l}n\delta_{n,k} - \sqrt{2n+1}\sqrt{2k+1}, &k \leq n \\ 0, &k > n\end{array}\right.\end{equation}Here $n\delta_{n,k}$ is essentially the diagonal matrix $\text{diag}(0, 1, 2, \dots)$, while $\sqrt{2n+1}\sqrt{2k+1}$ can be rewritten as a low-rank matrix $vv^*$, where $v = [1, \sqrt{3}, \sqrt{5}, \dots]^* \in \mathbb{R}^{d \times 1}$. In other words, if not for the constraint $k > n \Rightarrow A_{n,k} = 0$, $A$ would already be in "Diagonal minus Low-Rank" form.
Although this pattern doesn't hold once the lower-triangular constraint is included, we can cleverly use the $vv^*$ structure to help construct a new diagonalizable matrix. It must be said that this trick is ingenious, truly a "finishing touch" that makes one marvel at the original authors. Specifically, let's consider $A + \frac{1}{2}vv^*$:
\begin{equation}\left(A + \frac{1}{2}v v^*\right)_{n,k} = \left\{\begin{array}{l}n\delta_{n,k} - \frac{1}{2}\sqrt{2n+1}\sqrt{2k+1}, &k \leq n \\ \frac{1}{2}\sqrt{2n+1}\sqrt{2k+1}, &k > n\end{array}\right.\end{equation}The diagonal elements of this new matrix are exactly $n - \frac{1}{2}(2n+1) = -\frac{1}{2}$. If we add $\frac{1}{2}I$ to this, we get:
\begin{equation}\left(A + \frac{1}{2}v v^*+\frac{1}{2}I\right)_{n,k} = \left\{\begin{array}{l} - \frac{1}{2}\sqrt{2n+1}\sqrt{2k+1}, &k < n \\ 0, &k=n \\ \frac{1}{2}\sqrt{2n+1}\sqrt{2k+1}, &k > n\end{array}\right.\end{equation}Here is the key: this is a skew-symmetric matrix! Therefore, it is guaranteed to be diagonalizable in the complex field. Thus, we have decomposed $A$ into a diagonalizable matrix and a low-rank matrix. Some might ask: if $A$ was already diagonalizable but had numerical issues, why is diagonalizing this skew-symmetric matrix fine? Because a skew-symmetric matrix is normal, so it can always be diagonalized by a unitary matrix! Unitary matrices are perfectly conditioned and numerically very stable, so the issues we faced earlier disappear. This is why we don't diagonalize $A$ directly, but instead take this detour to construct a skew-symmetric matrix.
Now we have that there exists a diagonal matrix $\Lambda$ and a unitary matrix $U$ such that $A + \frac{1}{2}vv^* + \frac{1}{2}I = U^*\Lambda U$, which leads to:
\begin{equation}A = U^*\Lambda U - \frac{1}{2}I - \frac{1}{2}v v^* = U^*\left(\Lambda - \frac{1}{2}I - \frac{1}{2}(Uv)(Uv)^*\right) U\end{equation}Stripping away the scaffolding, the conclusion simplifies to "$A$ is unitarily similar to a diagonal matrix minus a rank-1 matrix": there exist a unitary matrix $U$, a diagonal matrix $\Lambda$, and column vectors $u, v$ such that:
\begin{equation}A = U^*\left(\Lambda - uv^*\right) U\end{equation}Note that "Diagonal + Low-Rank" matrices are efficient to multiply by vectors. For instance:
\begin{equation}\left(\Lambda - uv^*\right)x = \Lambda x - u(v^*x)\end{equation}$\Lambda x$ is just a component-wise multiplication, and $u(v^*x)$ is a dot product between $v$ and $x$ yielding a scalar multiplied by $u$. All of these can be done in $O(d)$.
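The decomposition itself can be recovered numerically by diagonalizing the skew-symmetric matrix through its Hermitian counterpart $i\,(A + \frac{1}{2}vv^* + \frac{1}{2}I)$. Below is a sketch (helper names ours, building on hippo_legs); note the matrix W returned here plays the role of $U^*$ in the notation above:

```python
import numpy as np

def hippo_dplr(d):
    """Return Lam, w, W with A ≈ W (diag(Lam) - w w^*) W^* (W unitary), via the
    skew-symmetric matrix S = A + (1/2) q q^* + (1/2) I discussed above."""
    A = hippo_legs(d)                                   # hypothetical helper from earlier
    q = np.sqrt(2 * np.arange(d) + 1.0)                 # q = [1, sqrt(3), sqrt(5), ...]
    S = A + 0.5 * np.outer(q, q) + 0.5 * np.eye(d)      # skew-symmetric (up to rounding)
    mu, W = np.linalg.eigh(1j * S)                      # i*S is Hermitian, so eigh applies
    Lam = -1j * mu - 0.5                                # diagonal part, with the -1/2 I folded in
    w = (W.conj().T @ q) / np.sqrt(2)                   # rank-1 factor expressed in the eigenbasis
    # check: W @ (np.diag(Lam) - np.outer(w, w.conj())) @ W.conj().T should reproduce A
    return Lam, w, W

def dplr_matvec(Lam, u, v, x):
    """(diag(Lam) - u v^*) x in O(d): an element-wise product plus a rank-1 correction."""
    return Lam * x - u * np.vdot(v, x)
```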
With $A = U^*(\Lambda - uv^*)U$, using similarity invariance again, all our subsequent calculations can be transferred to $A = \Lambda - uv^*$. Let's assume $A = \Lambda - uv^*$ from now on. First, regarding $\bar{A}$:
\begin{equation}\bar{A}=\big(I - \epsilon (\Lambda - uv^*)/2\big)^{-1}\big(I + \epsilon (\Lambda - uv^*)/2\big)\end{equation}Notice that $I - \epsilon (\Lambda - uv^*)/2 = \frac{\epsilon}{2}(D + uv^*)$, where $D = \frac{2}{\epsilon}I - \Lambda$ is a diagonal matrix. Using the Woodbury Identity, we get:
\begin{equation}\big(I - \epsilon (\Lambda - uv^*)/2\big)^{-1} = \frac{2}{\epsilon}(D + uv^*)^{-1} = \frac{2}{\epsilon}\left[D^{-1} - D^{-1}u(I + v^*D^{-1}u)^{-1} v^*D^{-1}\right]\end{equation}This is also in "Diagonal + Low-Rank" form. Multiplying this by $\big(I + \epsilon (\Lambda - uv^*)/2\big)$ completes the calculation for $\bar{A}$, resulting in the product of two "Diagonal + Low-Rank" matrices, which preserves computational efficiency for recursive inference.
Finally, for the convolution kernel required for parallel training, we have already converted it into the generating function in equation \eqref{eq:k-gen}. Now let's calculate it. Using a "common denominator" algebraic trick, we can prove:
\begin{equation}\begin{aligned} \mathcal{G}(z|\bar{K}) = \bar{C}^* \left(I - \bar{A}z\right)^{-1}\bar{B} =&\, \bar{C}^* \left(I - (I - \epsilon A/2)^{-1}(I + \epsilon A/2)z\right)^{-1}\bar{B} \\ =&\, \bar{C}^* \left[(I - \epsilon A/2)^{-1}\big((I - \epsilon A/2)-(I + \epsilon A/2)z\big)\right]^{-1}\bar{B} \\ =&\, \bar{C}^* \big[(I - \epsilon A/2)-(I + \epsilon A/2)z\big]^{-1}(I - \epsilon A/2)\bar{B} \\ =&\, \bar{C}^* \big[(I - \epsilon A/2)-(I + \epsilon A/2)z\big]^{-1}B\epsilon \\ =&\, \bar{C}^* \big[(1-z)I - (1+z)\epsilon A / 2\big]^{-1}B\epsilon \\ =&\, \frac{2}{1+z}\bar{C}^* \left[\frac{2}{\epsilon}\frac{1-z}{1+z}I - A\right]^{-1}B \\ \end{aligned}\end{equation}Substituting $A = \Lambda - uv^*$, we get:
\begin{equation}\mathcal{G}(z|\bar{K}) = \frac{2}{1+z}\bar{C}^* \left[\frac{2}{\epsilon}\frac{1-z}{1+z}I - (\Lambda - uv^*)\right]^{-1}B = \frac{2}{1+z}\bar{C}^* (R_z + uv^*)^{-1}B\end{equation}Here $R_z = \frac{2}{\epsilon}\frac{1-z}{1+z}I - \Lambda$ is a diagonal matrix. Using the Woodbury Identity again:
\begin{equation}\mathcal{G}(z|\bar{K}) = \frac{2}{1+z}\bar{C}^* \left[R_z^{-1} - R_z^{-1}u(I + v^*R_z^{-1}u)^{-1} v^*R_z^{-1}\right]B\end{equation}This is a scalar function of $z$. One detail: what FFT needs is actually the "Truncated Generating Function":
\begin{equation}\mathcal{G}_L(z|\bar{K}) = \sum_{k=0}^{L-1} \bar{C}^*\bar{A}^k \bar{B}z^k = \bar{C}^*(I - z^L\bar{A}^L)\left(I - z\bar{A}\right)^{-1}\bar{B}\end{equation}This means $\bar{C}^*$ in $\mathcal{G}(z|\bar{K})$ is replaced by $\bar{C}^*(I - z^L\bar{A}^L)$, where $L$ is the maximum training length. By substituting $z=e^{-2i\pi l/L}, l=0, 1, 2, \dots, L-1$, we get the DFT of $\bar{K}$, and applying IDFT gives the kernel $\bar{K}$. Furthermore, for $z=e^{-2i\pi l/L}$, $z^L=1$. In this case, we just replace $\bar{C}^*$ with $\bar{C}^*(I - \bar{A}^L)$. Since S4 treats $\bar{C}$ as a trainable parameter, we can treat $\bar{C}^*(I - \bar{A}^L)$ as a single trainable parameter during training, solving for $\bar{C}$ afterward for inference. This avoids calculating $\bar{A}^L$ during training.
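Pulling these pieces together, here is a rough sketch of the kernel computation under the "diagonal minus rank-1" assumption $A = \Lambda - uv^*$ (helper names ours; Lam, u, v, B, C_tilde are plain 1-D arrays, and C_tilde plays the role of the trained $\bar{C}^*(I - \bar{A}^L)$):

```python
import numpy as np

def ssm_kernel_genfunc(Lam, u, v, B, C_tilde, dt, L):
    """Evaluate G_L(z|K) at the L roots of unity via the rank-1 Woodbury formula, then IDFT.
    Each evaluation costs O(d), so the whole kernel costs O(L d) plus one FFT."""
    z = np.exp(-2j * np.pi * np.arange(L) / L)        # z_l = e^{-2i pi l / L}
    g = (2.0 / dt) * (1 - z) / (1 + z)                # the scalar (2/eps)(1-z)/(1+z)
    R = g[:, None] - Lam[None, :]                     # diagonal of R_z for every z, shape (L, d)
    Rinv_B = B[None, :] / R                           # R_z^{-1} B
    Rinv_u = u[None, :] / R                           # R_z^{-1} u
    corr = (Rinv_B @ np.conj(v)) / (1.0 + Rinv_u @ np.conj(v))   # (v^* R^{-1} B)/(1 + v^* R^{-1} u)
    resolvent_B = Rinv_B - Rinv_u * corr[:, None]     # (R_z + u v^*)^{-1} B via Woodbury
    K_hat = (2.0 / (1 + z)) * (resolvent_B @ np.conj(C_tilde))   # G_L(z_l | K)
    return np.fft.ifft(K_hat).real                    # IDFT recovers K_0, ..., K_{L-1}
```

As a sanity check, one can set C_tilde to $(I - \bar{A}^L)^*\bar{C}$ computed by brute force and compare the result against the naive power-based kernel from earlier.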
After a long and arduous mathematical trek, we have finally combed through the key mathematical details of S4. We hope this is helpful to readers interested in understanding S4. As we can see, S4 is a further refinement and completion of HiPPO. Its crucial contribution is the "Diagonal + Low-Rank" matrix form, which provides the foundation for efficient parallel computation. Before this, $A$ was defined piecewise rather than in a matrix operator form, which made it difficult to apply general linear algebra tools for analysis.
Since the original HiPPO derivation was for a 1D function $u(t)$, the $u_k$ in S4 so far are also scalars. How does S4 handle vector sequence inputs? Quite brute-force: it applies the aforementioned linear RNN to each component independently, with different $\epsilon, B, C$ parameters per component, then concatenates the results. This approach persists even in the author's latest work, Mamba. Of course, there are simplified versions, like S5 (not by Albert Gu), which process vector inputs in a single RNN by changing $B, C$ to matrices. S5 essentially borrows the linear RNN form and HiPPO matrix $A$ from S4 while shedding some of the more complex mathematical details, also achieving good results.
It's somewhat ironic that S4 proposed numerous exquisite mathematical tricks to simplify and accelerate the calculation of $A$, yet subsequent work starting from "Diagonal State Spaces are as Effective as Structured State Spaces", and including Mamba, basically abandoned these parts. They directly assume $A$ to be a diagonal matrix, which makes the RNN part almost the same as the LRU introduced in "Recurrent Neural Networks Again? Google's New Work Revival". Therefore, from today's perspective on SSMs and linear RNNs, the HiPPO and S4 line of work is somewhat "outdated." Many articles explaining Mamba start from HiPPO and S4; in retrospect, that might be "unnecessary."
Of course, for me, spending so much time learning HiPPO and S4 isn't just about understanding or using the latest SSM and RNN models. It's about learning the assumptions and derivations behind HiPPO, understanding the memory mechanisms and bottlenecks of linear systems, and accumulating ideas for building new models and methods in the future. Moreover, the many clever mathematical techniques in HiPPO and S4 are aesthetically pleasing and serve as excellent exercises for improving mathematical skills.
This article introduced S4, the successor to HiPPO. Its key contribution is the "Diagonal Matrix + Low-Rank Matrix" decomposition, which enables efficient parallel computation of the HiPPO matrix. This article primarily focused on the introduction and derivation of the more challenging mathematical details.