Revisiting SSM (Part 1): Linear Systems and HiPPO Matrices

By 苏剑林 | May 24, 2024

A few days ago, I read several articles introducing SSM (State Space Models) and realized that I had never seriously understood the core of SSM. So I decided to study SSM in depth and, along the way, start this new series to record what I learn.

The concept of SSM has a long history, but here we specifically mean SSM in the context of deep learning, where the seminal work is generally considered to be S4 from 2021, so this line of work is still fairly young. The most recent and popular SSM variant is probably last year's Mamba. Of course, "SSM" can also refer broadly to all linear RNN models; in this sense, RWKV, RetNet, and the LRU we introduced in "Google's New Work Attempts to 'Revive' RNN: Can RNN Shine Again?" all fall into this category. Many SSM variants aim to compete with the Transformer. Although I do not believe they will replace it completely, the elegant mathematical properties of SSM are well worth studying in their own right.

Although we say that SSM originated with S4, before S4, there was a very powerful foundational work for SSM called "HiPPO: Recurrent Memory with Optimal Polynomial Projections" (referred to as HiPPO for short). So, this article begins with HiPPO.

Basic Form

As a side note, the first author of the representative SSM works mentioned above—HiPPO, S4, and Mamba—is Albert Gu. He has many other SSM-related works. It is no exaggeration to say that these efforts built the foundation of the SSM edifice. Regardless of the prospects for SSM, the spirit of unremittingly delving into the same subject is truly admirable.

Let's get back to the point. Readers with some background in SSM probably know that an SSM is built on a linear ODE (Ordinary Differential Equation) system: \begin{equation}\begin{aligned} x'(t) =&\, A x(t) + B u(t) \\ y(t) =&\, C x(t) + D u(t) \end{aligned}\label{eq:ode}\end{equation} where $u(t)\in\mathbb{R}^{d_i}, x(t)\in\mathbb{R}^{d}, y(t)\in\mathbb{R}^{d_o}, A\in\mathbb{R}^{d\times d}, B\in\mathbb{R}^{d\times d_i}, C\in\mathbb{R}^{d_o\times d}, D\in\mathbb{R}^{d_o\times d_i}$. Of course, we can discretize it, in which case it becomes a linear RNN model; we will expand on this in later articles. Whether discretized or not, the keyword is "linear." A natural question then arises: why a linear system, and is a linear system enough?

We can answer this question from two angles: linear systems are both simple enough and complex enough. Simple, because in theory linearization is often the most basic approximation of a complex system, so linear systems are usually an unavoidable starting point. Complex, because even such a simple system can fit remarkably complex functions. To see this, consider a simple example in $\mathbb{R}^4$: \[ x'(t) =\begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & -1 & 0 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & -1 & 0 \end{pmatrix}x(t) \] One solution of this example is $x(t) = (e^t, e^{-t}, \sin t, \cos t)$. What does this imply? It means that, as long as $d$ is large enough, a linear system can fit sufficiently complex functions through combinations of exponential and trigonometric functions. The powerful Fourier series is just a combination of trigonometric functions, and adding exponentials clearly makes the family even richer. One can therefore expect linear systems to have quite strong fitting capability.
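As a quick numerical sanity check of this example, here is a minimal sketch, assuming NumPy. The classical RK4 integrator, the step count, and the initial state $x(0)=(1,1,0,1)$ are illustrative choices not taken from the original; the sketch integrates the 4-dimensional system and compares the result with the closed-form solution.

```python
import numpy as np

# The 4x4 system from the example: x'(t) = A x(t).
A = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.0, -1.0, 0.0, 0.0],
              [0.0, 0.0, 0.0, 1.0],
              [0.0, 0.0, -1.0, 0.0]])


def rk4(f, x0, t1, steps):
    """Integrate x' = f(x) from t = 0 to t = t1 with the classical Runge-Kutta method."""
    h = t1 / steps
    x = np.asarray(x0, dtype=float)
    for _ in range(steps):
        k1 = f(x)
        k2 = f(x + 0.5 * h * k1)
        k3 = f(x + 0.5 * h * k2)
        k4 = f(x + h * k3)
        x = x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return x


t1 = 2.0
# x(0) = (1, 1, 0, 1) selects the solution (e^t, e^{-t}, sin t, cos t).
x_num = rk4(lambda x: A @ x, [1.0, 1.0, 0.0, 1.0], t1, steps=2000)
x_exact = np.array([np.exp(t1), np.exp(-t1), np.sin(t1), np.cos(t1)])
print(np.max(np.abs(x_num - x_exact)))  # small: RK4's global error is O(h^4)
```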

Of course, these explanations are somewhat "post-hoc." The results provided by HiPPO are more fundamental: when we attempt to use orthogonal bases to approximate a dynamically updated function, the result is precisely the linear system above. This means that HiPPO not only tells us that linear systems can approximate sufficiently complex functions but also tells us how to approximate them and even what the degree of approximation is.

Finite Compression

Next, we only consider the special case $d_i=1$; the case $d_i > 1$ is simply a parallel generalization of it. In this case, $u(t)$ is scalar-valued. Furthermore, as a starting point, let's assume $t\in[0, 1]$. The goal of HiPPO is to use a finite-dimensional vector to store the information of $u(t)$ over this interval.

This seems like an impossible requirement, because on $t\in[0,1]$ the function $u(t)$ is effectively a vector with infinitely many entries, and compressing it into a finite-dimensional vector could cause severe distortion. However, if we make some assumptions about $u(t)$ and tolerate some loss, the compression becomes possible, and most readers have already seen it done. For example, when $u(t)$ is $(n+1)$-times differentiable at a certain point, its $n$-th order Taylor expansion is often a good approximation of $u(t)$. We can then store only the $n+1$ coefficients of the expansion as an approximate representation of $u(t)$, compressing $u(t)$ into an $(n+1)$-dimensional vector.

Of course, for real data, being "$(n+1)$-times differentiable" is a very restrictive condition. We usually prefer expansions in an orthogonal function basis, which only require square integrability; the Fourier series is one example. Its coefficients are computed as: \begin{equation}c_n = \int_0^1 u(t) e^{-2i\pi n t}dt \label{eq:fourier-coef-1}\end{equation} By choosing a sufficiently large integer $N$ and keeping only the coefficients with $|n|\leq N$, we compress $u(t)$ into a $(2N + 1)$-dimensional vector.
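The compression step can be sketched in a few lines. This is a minimal sketch assuming NumPy; the helper names `fourier_compress` and `fourier_reconstruct`, the $M$-point rectangle rule used to approximate the integral, and the test function $u(t)=t^2$ are all illustrative assumptions.

```python
import numpy as np


def fourier_compress(u, N, M=4096):
    """Approximate c_n = ∫_0^1 u(t) e^{-2iπnt} dt for n = -N..N with an M-point rectangle rule."""
    t = np.arange(M) / M
    n = np.arange(-N, N + 1)
    # Row n holds u(t_j) e^{-2iπ n t_j}; averaging over j approximates the integral.
    return (u(t)[None, :] * np.exp(-2j * np.pi * n[:, None] * t[None, :])).mean(axis=1)


def fourier_reconstruct(c, t):
    """Evaluate the truncated series sum_{|n|<=N} c_n e^{2iπnt} at the points t."""
    N = (len(c) - 1) // 2
    n = np.arange(-N, N + 1)
    return (c[None, :] * np.exp(2j * np.pi * t[:, None] * n[None, :])).sum(axis=1)


def u(t):
    return t ** 2                     # an arbitrary, non-periodic test signal


t_grid = np.arange(4096) / 4096
errors = {}
for N in (4, 32):
    c = fourier_compress(u, N)        # u is now stored as a (2N+1)-dimensional complex vector
    recon = fourier_reconstruct(c, t_grid)
    errors[N] = np.sqrt(np.mean(np.abs(u(t_grid) - recon) ** 2))
print(errors)                         # the RMS error shrinks as N grows
```

Because the rectangle rule on a uniform grid is a discrete Fourier transform, the squared RMS error on that grid equals the total squared magnitude of the discarded discrete coefficients, so it can only decrease as $N$ grows.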

Next, the difficulty level increases. Earlier we said $t\in[0,1]$, which is a static interval. In practice, $u(t)$ represents a continuously collected signal, so new data is constantly entering. For example, if we have approximated the data for the interval $[0,1]$, data for $[1,2]$ will arrive immediately. You need to update the approximation result to try to remember the entire $[0,2]$ interval, followed by $[0,3], [0,4]$, and so on. We call this "online function approximation." The Fourier coefficient formula $\eqref{eq:fourier-coef-1}$ only applies to the interval $[0,1]$, so it needs to be generalized.

To this end, let $t\in[0,T]$, and let $s\mapsto t_{\leq T}(s)$ be a mapping from $[0,1]$ to $[0,T]$. Then, when $u(t_{\leq T}(s))$ is treated as a function of $s$, its domain of definition is $[0,1]$, so we can reuse equation $\eqref{eq:fourier-coef-1}$: \begin{equation}c_n(T) = \int_0^1 u(t_{\leq T}(s)) e^{-2i\pi n s}ds \label{eq:fourier-coef-2}\end{equation} Here, we have added the marker $(T)$ to the coefficients to indicate that the coefficients will change as $T$ changes.

Linearity Emerges

There are infinitely many functions that can map $[0,1]$ to $[0,T]$, and the final result varies with $t_{\leq T}(s)$. Some relatively intuitive and simple choices are as follows:

1. $t_{\leq T}(s) = sT$, which maps $[0,1]$ uniformly to $[0,T]$;

2. Note that $t_{\leq T}(s)$ does not necessarily have to be surjective. So, something like $t_{\leq T}(s)=s + T - 1$ is allowed. This means only the information from the most recent window $[T-1,T]$ is retained, and earlier parts are discarded. More generally, we could have $t_{\leq T}(s)=sw + T - w$, where $w$ is a constant, meaning information before $T-w$ is discarded;

3. One can also choose a non-uniform mapping, such as $t_{\leq T}(s) = T\sqrt{s}$. This is also a surjective mapping from $[0,1]$ to $[0,T]$, but when $s=1/4$, it maps to $T/2$. This means that while we pay attention to the global history, we simultaneously place more emphasis on information near time $T$.

Now, let's take $t_{\leq T}(s)=sw + T - w$ as an example and substitute it into equation $\eqref{eq:fourier-coef-2}$: \[c_n(T) = \int_0^1 u(sw + T - w) e^{-2i\pi n s}ds\] Now we take the derivative with respect to $T$ on both sides: \begin{equation}\begin{aligned} \frac{d}{dT}c_n(T) =&\, \int_0^1 u'(sw + T - w) e^{-2i\pi n s}ds \\ =&\, \left.\frac{1}{w} u(sw + T - w) e^{-2i\pi n s}\right|_{s=0}^{s=1} + \frac{2i\pi n}{w}\int_0^1 u(sw + T - w) e^{-2i\pi n s}ds \\ =&\, \frac{1}{w} u(T) - \frac{1}{w} u(T-w) + \frac{2i\pi n}{w} c_n(T) \\ \end{aligned}\label{eq:fourier-dc}\end{equation} In the second equality, we used integration by parts. Since we keep only the coefficients $|n|\leq N$, according to the Fourier series formula, we can consider the following to be a good approximation of $u(sw + T - w)$: \[u(sw + T - w) \approx \sum_{k=-N}^{k=N} c_k(T) e^{2i\pi k s}\] Then $u(T - w) = u(sw + T - w)|_{s=0}\approx \sum\limits_{k=-N}^{k=N} c_k(T)$. Substituting this into equation $\eqref{eq:fourier-dc}$ gives: \[\frac{d}{dT}c_n(T) \approx \frac{1}{w} u(T) - \frac{1}{w} \sum_{k=-N}^{k=N} c_k(T) + \frac{2i\pi n}{w} c_n(T)\] Replacing $T$ with $t$, and grouping all $c_n(t)$ together as $x(t) = (c_{-N},c_{-(N-1)},\dots,c_0,\dots,c_{N-1},c_N)$, and not distinguishing between $\approx$ and $=$, we can write: \[x'(t) = Ax(t) + Bu(t),\quad A_{n,k} = \left\{\begin{array}{l}(2i\pi n - 1)/w, &k=n \\ -1/w,&k\neq n\end{array}\right.,\quad B_n = 1/w\] This results in a linear ODE system as shown in equation $\eqref{eq:ode}$. That is, when we attempt to use a Fourier series to remember the state within the most recent window of a real-time function, the result naturally leads to a linear ODE system.
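The derivative formula $\eqref{eq:fourier-dc}$ is exact before the truncation step, and this can be checked numerically. The sketch below assumes NumPy; the test signal, the window length $w=1$, $N=3$, the use of Gauss-Legendre quadrature, and the helper names are illustrative assumptions. It computes $c_n(T)$ by quadrature, compares a finite-difference $dc_n/dT$ with the exact right-hand side, and then builds the truncated $A$ and $B$.

```python
import numpy as np

# Gauss-Legendre nodes and weights mapped from [-1, 1] to [0, 1].
x_gl, w_gl = np.polynomial.legendre.leggauss(200)
s_nodes, s_weights = (x_gl + 1) / 2, w_gl / 2


def window_coeffs(u, T, w, N):
    """c_n(T) = ∫_0^1 u(sw + T - w) e^{-2iπns} ds for n = -N..N, by quadrature."""
    n = np.arange(-N, N + 1)
    vals = s_weights * u(s_nodes * w + T - w)
    return np.exp(-2j * np.pi * n[:, None] * s_nodes[None, :]) @ vals


def fourier_window_ssm(N, w):
    """Truncated system: A_nn = (2iπn - 1)/w, A_nk = -1/w for k != n, B_n = 1/w."""
    n = np.arange(-N, N + 1)
    A = np.full((2 * N + 1, 2 * N + 1), -1.0 / w, dtype=complex)
    A[np.diag_indices(2 * N + 1)] = (2j * np.pi * n - 1) / w
    B = np.full(2 * N + 1, 1.0 / w, dtype=complex)
    return A, B


def u(t):
    return np.sin(3 * t) + 0.5 * t    # arbitrary smooth test signal


T, w, N, h = 2.0, 1.0, 3, 1e-5
n = np.arange(-N, N + 1)
c = window_coeffs(u, T, w, N)
dc_fd = (window_coeffs(u, T + h, w, N) - window_coeffs(u, T - h, w, N)) / (2 * h)
# Right-hand side of the derivative formula before any truncation: exact.
dc_exact = (u(T) - u(T - w)) / w + 2j * np.pi * n / w * c
print(np.max(np.abs(dc_fd - dc_exact)))

# The linear system replaces u(T - w) by sum_k c_k(T), so it is only approximate.
A, B = fourier_window_ssm(N, w)
dc_ssm = A @ c + B * u(T)
print(np.max(np.abs(dc_ssm - dc_exact)))   # equals |u(T - w) - sum_k c_k(T)| / w
```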

General Framework

Of course, this relied on one specific choice of $t_{\leq T}(s)$; a different choice might not lead to such a simple result. Furthermore, the Fourier result lives in the complex domain; it can be converted to real form, but only at the cost of a more complicated expression. Therefore, we should generalize the procedure of the previous section into a general framework, in order to obtain more general and simpler purely real conclusions.

Let $t\in[a,b]$. Given a target function $u(t)$ and a function basis $\{g_n(t)\}_{n=0}^N$, we wish to approximate the former with a linear combination of the latter by minimizing the $L_2$ distance: \begin{equation}\mathop{\text{argmin}}_{c_0,\dots,c_N}\int_a^b \left[u(t) - \sum_{n=0}^N c_n g_n(t)\right]^2 dt\end{equation} We work mainly over the reals, so squaring the bracket suffices and no absolute value is needed. A more general objective could also include a weight function $\rho(t)$, but we won't consider that here, as the main conclusions of HiPPO do not really rely on the weight function.

Expanding the objective, we get: \[\int_a^b u^2(t) dt - 2\sum_{n=0}^N c_n \int_a^b u(t) g_n(t)dt + \sum_{m=0}^N\sum_{n=0}^N c_m c_n \int_a^b g_m(t) g_n(t) dt\] Here we only consider orthonormal function bases, defined by $\int_a^b g_m(t) g_n(t) dt = \delta_{m,n}$, where $\delta_{m,n}$ is the Kronecker delta. In this case, the above expression simplifies to: \[\int_a^b u^2(t) dt - 2\sum_{n=0}^N c_n \int_a^b u(t) g_n(t)dt + \sum_{n=0}^N c_n^2 \] This is just a quadratic function of the $c_n$, and its minimum has an analytical solution: \begin{equation}c^*_n = \int_a^b u(t) g_n(t)dt\end{equation} This is exactly the inner product of $u(t)$ and $g_n(t)$, the natural generalization of the inner product of finite-dimensional vectors to function spaces. For simplicity, when no confusion arises, we write $c_n$ for $c^*_n$.
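To see that the inner product really is the least-squares minimizer, here is a minimal NumPy sketch on $[a,b]=[0,1]$. The cosine basis $g_0=1$, $g_n=\sqrt{2}\cos(n\pi t)$ (orthonormal on $[0,1]$), the test function, and the quadrature rule are illustrative assumptions; the integrals are approximated with Gauss-Legendre quadrature, and the least-squares problem is the corresponding discretized objective.

```python
import numpy as np

# Gauss-Legendre rule mapped to [a, b] = [0, 1].
x, wts = np.polynomial.legendre.leggauss(400)
t, wts = (x + 1) / 2, wts / 2

N = 10
n = np.arange(N + 1)
# Illustrative orthonormal basis on [0, 1]: g_0 = 1, g_n = sqrt(2) cos(n π t). Column n holds g_n(t_j).
G = np.where(n[None, :] == 0, 1.0, np.sqrt(2) * np.cos(np.pi * n[None, :] * t[:, None]))
u = np.exp(t) * np.sin(5 * t)        # arbitrary target function sampled at the nodes

gram = G.T @ (wts[:, None] * G)      # ∫ g_m g_n dt, should be the identity
c_inner = G.T @ (wts * u)            # closed form: c*_n = ∫ u g_n dt

# Minimize the discretized objective Σ_j wts_j (u_j - Σ_n c_n g_n(t_j))^2 directly.
sw = np.sqrt(wts)
c_lstsq, *_ = np.linalg.lstsq(sw[:, None] * G, sw * u, rcond=None)
print(np.max(np.abs(c_inner - c_lstsq)))   # agreement up to rounding
```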

Next, the processing is exactly the same as in the previous section. We want to consider the approximation of $u(t)$ for a general $t\in[0, T]$. So, we find a mapping $s\mapsto t_{\leq T}(s)$ from $[a,b]$ to $[0,T]$, and then calculate the coefficients: \begin{equation}c_n(T) = \int_a^b u(t_{\leq T}(s)) g_n(s) ds\end{equation} Similarly, we take the derivative with respect to $T$ on both sides and use integration by parts: \begin{equation}\scriptsize\begin{aligned} \frac{d}{dT}c_n(T) =&\, \int_a^b u'(t_{\leq T}(s)) \frac{\partial t_{\leq T}(s)}{\partial T} g_n(s) ds = \int_a^b \left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s) d u(t_{\leq T}(s)) \\ =&\, \left.u(t_{\leq T}(s))\left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s)\right|_{s=a}^{s=b} - \int_a^b u(t_{\leq T}(s)) \,d\left[\left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s)\right] \end{aligned}\label{eq:hippo-base}\end{equation}

Invoking Legendre

The subsequent calculations depend on the specific forms of $g_n(t)$ and $t_{\leq T}(s)$. The full name of HiPPO is High-order Polynomial Projection Operators. The first P stands for Polynomial, so the key to HiPPO is choosing polynomials as the basis. Now we call on another great mathematician after Fourier, namely Legendre: the function basis we select next is precisely the "Legendre polynomials" named after him.

The Legendre polynomial $p_n(t)$ is a polynomial of degree $n$ in $t$, defined on the interval $[-1,1]$ and satisfying: \begin{equation}\int_{-1}^1 p_m(t) p_n(t) dt = \frac{2}{2n+1}\delta_{m,n}\end{equation} So the $p_n(t)$ are orthogonal but not normalized (the integral of each square is $\frac{2}{2n+1}$ rather than 1); $g_n(t)=\sqrt{\frac{2n+1}{2}} p_n(t)$ is the corresponding orthonormal basis.

Applying the Gram-Schmidt process to the function basis $\{1,t,t^2,\dots, t^n\}$ on $[-1,1]$ yields, up to normalization, precisely the Legendre polynomials. Compared to the Fourier basis, Legendre polynomials have the advantage of being purely real, and their polynomial form simplifies the derivation for some choices of $t_{\leq T}(s)$, as we will see later. Legendre polynomials have many equivalent definitions and properties, which we won't expand on here; interested readers can consult the Wikipedia article on Legendre polynomials.
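The Gram-Schmidt claim can be checked numerically. This sketch assumes NumPy's `numpy.polynomial.legendre` module; $N=6$ and 32 quadrature nodes are arbitrary choices. It orthonormalizes sampled monomials under the $[-1,1]$ inner product and compares the result with $g_n=\sqrt{\frac{2n+1}{2}}p_n$. Gauss-Legendre quadrature with 32 nodes integrates polynomials up to degree 63 exactly, so the inner products here are exact up to rounding.

```python
import numpy as np
from numpy.polynomial import legendre as L

M, N = 32, 6
x, w = L.leggauss(M)          # exact for polynomials of degree <= 2M - 1 = 63


def inner(f, g):
    """∫_{-1}^{1} f(t) g(t) dt for polynomials represented by their values at the nodes."""
    return np.sum(w * f * g)


# Classical Gram-Schmidt on the monomials 1, t, ..., t^N.
basis = []
for k in range(N + 1):
    v = x ** k
    for q in basis:
        v = v - inner(v, q) * q
    basis.append(v / np.sqrt(inner(v, v)))

# p_n at the nodes (row n), their Gram matrix, and the orthonormal g_n = sqrt((2n+1)/2) p_n.
P = L.legvander(x, N).T
gram = (P * w) @ P.T
G = np.sqrt((2 * np.arange(N + 1) + 1) / 2)[:, None] * P
gs_error = np.max(np.abs(np.array(basis) - G))
print(np.round(np.diag(gram), 6), gs_error)   # diagonal is 2/(2n+1); gs_error is tiny
```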

Next, we use two recurrence relations to derive an identity. These two recurrence relations are: \begin{align} p_{n+1}'(t) - p_{n-1}'(t) = (2n+1)p_n(t) \label{eq:leg-r1}\\[5pt] p_{n+1}'(t) = (n + 1)p_n(t) + t p_n'(t) \label{eq:leg-r2} \end{align} Iterating from the first formula $\eqref{eq:leg-r1}$, we get: \begin{equation}\begin{aligned} p_{n+1}'(t) =&\, (2n+1)p_n(t) + (2n-3)p_{n-2}(t) + (2n-7)p_{n-4}(t) + \dots \\ =&\, \sum_{k=0}^n (2k+1) \chi_{n-k} p_k(t) \end{aligned}\label{eq:leg-dot}\end{equation} where $\chi_k=1$ if $k$ is even and $\chi_k=0$ otherwise. Substituting this into the second formula $\eqref{eq:leg-r2}$ gives: \[t p_n'(t) = n p_n(t) + (2n-3)p_{n-2}(t) + (2n-7)p_{n-4}(t) + \dots\] Hence: \begin{equation}\begin{aligned} (t+1) p_n'(t) =&\, n p_n(t) + (2n-1)p_{n-1}(t) + (2n-3)p_{n-2}(t) + \dots\\ =&\,-(n+1) p_n(t) + \sum_{k=0}^n (2k + 1) p_k(t) \end{aligned}\label{eq:leg-dot-t1}\end{equation} These are the identities we will use shortly. Additionally, Legendre polynomials satisfy $p_n(1)=1$ and $p_n(-1)=(-1)^n$, boundary values which will also be used later.
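The two recurrences and the identities derived from them are easy to verify exactly on polynomial coefficients. A minimal sketch, assuming NumPy's `numpy.polynomial.Legendre` class and checking $n=1,\dots,9$ (an arbitrary range):

```python
import numpy as np
from numpy.polynomial import Legendre

p = [Legendre.basis(n) for n in range(12)]   # p_0, ..., p_11
t = Legendre([0, 1])                         # the polynomial t, since p_1(t) = t


def chi(k):
    """chi_k = 1 for even k, 0 otherwise."""
    return 1 if k % 2 == 0 else 0


def is_zero(poly, tol=1e-10):
    return np.allclose(poly.coef, 0, atol=tol)


for n in range(1, 10):
    # Recurrences (leg-r1) and (leg-r2).
    assert is_zero(p[n + 1].deriv() - p[n - 1].deriv() - (2 * n + 1) * p[n])
    assert is_zero(p[n + 1].deriv() - (n + 1) * p[n] - t * p[n].deriv())
    # Identity (leg-dot): p'_{n+1} = sum_{k<=n} (2k+1) chi_{n-k} p_k.
    rhs = sum((2 * k + 1) * chi(n - k) * p[k] for k in range(n + 1))
    assert is_zero(p[n + 1].deriv() - rhs)
    # Identity (leg-dot-t1): (t+1) p'_n = -(n+1) p_n + sum_{k<=n} (2k+1) p_k.
    rhs = -(n + 1) * p[n] + sum((2 * k + 1) * p[k] for k in range(n + 1))
    assert is_zero((t + 1) * p[n].deriv() - rhs)
    # Boundary values p_n(1) = 1 and p_n(-1) = (-1)^n.
    assert np.isclose(p[n](1.0), 1.0) and np.isclose(p[n](-1.0), (-1) ** n)
print("all identities hold for n = 1..9")
```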

Just as an $n$-dimensional space has more than one orthogonal basis, there is more than one family of orthogonal polynomials. If we allow a weighted objective (i.e., $\rho(t)\not\equiv 1$), there are also Chebyshev polynomials (orthogonal under the weight $(1-t^2)^{-1/2}$), Laguerre polynomials, and others. These are mentioned in the original paper, but the main conclusions of HiPPO are based on the Legendre polynomial expansion, so the others won't be discussed further here.

Sliding Window

After completing the preparations, we can substitute the specific $t_{\leq T}(s)$ to perform calculations. The calculation process is largely similar to the Fourier series example, except the basis functions are replaced by the orthonormal basis $g_n(t)=\sqrt{\frac{2n+1}{2}} p_n(t)$ constructed from Legendre polynomials. As our first example, we similarly consider only retaining information from the most recent window. In this case, $t_{\leq T}(s) = (s + 1)w / 2 + T - w$ maps $[-1,1]$ to $[T-w,T]$. The original paper calls this case "LegT (Translated Legendre)."

Substituting directly into equation $\eqref{eq:hippo-base}$, we immediately obtain: \[\small\frac{d}{dT}c_n(T) = \frac{\sqrt{2(2n+1)}}{w}\left[u(T) - (-1)^n u(T-w)\right] - \frac{2}{w}\int_{-1}^1 u((s + 1)w / 2 + T - w) g_n'(s) ds\] We first handle the $u(T-w)$ term. Following the same idea as with the Fourier series, we truncate the expansion at $n \leq N$ as an approximation of $u((s + 1)w / 2 + T - w)$: \begin{equation}u((s + 1)w / 2 + T - w)\approx \sum_{k=0}^N c_k(T)g_k(s)\end{equation} Thus $u(T-w)\approx \sum\limits_{k=0}^N c_k(T)g_k(-1) = \sum\limits_{k=0}^N (-1)^k c_k(T) \sqrt{\frac{2k+1}{2}}$. Next, applying equation $\eqref{eq:leg-dot}$ with $n$ replaced by $n-1$, we get: \begin{equation}\begin{aligned} &\,\int_{-1}^1 u((s + 1)w / 2 + T - w) g_n'(s) ds \\ =&\,\int_{-1}^1 u((s + 1)w / 2 + T - w) \sqrt{\frac{2n+1}{2}} p_n'(s) ds \\ =&\, \int_{-1}^1 u((s + 1)w / 2 + T - w)\sqrt{\frac{2n+1}{2}}\left[\sum_{k=0}^{n-1} (2k+1) \chi_{n-1-k} p_k(s)\right]ds \\ =&\, \int_{-1}^1 u((s + 1)w / 2 + T - w)\sqrt{\frac{2n+1}{2}}\left[\sum_{k=0}^{n-1} \sqrt{2(2k+1)} \chi_{n-1-k} g_k(s)\right]ds \\ =&\, \sqrt{2n+1}\sum_{k=0}^{n-1} \sqrt{2k+1} \chi_{n-1-k} c_k(T) \end{aligned}
\end{equation} Combining these results, we have: \begin{equation}\begin{aligned} \frac{d}{dT}c_n(T) \approx &\, \frac{\sqrt{2(2n+1)}}{w}u(T) - \frac{\sqrt{2(2n+1)}}{w} (-1)^n \overbrace{\sum\limits_{k=0}^N (-1)^k c_k(T) \sqrt{\frac{2k+1}{2}}}^{u(T-w)} \\ &\quad- \frac{2}{w}\overbrace{\sqrt{2n+1}\sum_{k=0}^{n-1} \sqrt{2k+1} \chi_{n-1-k} c_k(T)}^{\int_{-1}^1 u((s + 1)w / 2 + T - w) g_n'(s) ds} \\[12pt] = &\, \frac{\sqrt{2(2n+1)}}{w}u(T) - \frac{\sqrt{2n+1}}{w} \sum\limits_{k=0}^N (-1)^{n-k} c_k(T) \sqrt{2k+1} \\ &\quad- \frac{2}{w}\sqrt{2n+1}\sum_{k=0}^{n-1} \sqrt{2k+1} \chi_{n-1-k} c_k(T) \\[12pt] = &\, \frac{\sqrt{2(2n+1)}}{w}u(T) - \frac{\sqrt{2n+1}}{w} \sum\limits_{k=n}^N (-1)^{n-k} c_k(T) \sqrt{2k+1} \\ &\quad- \frac{\sqrt{2n+1}}{w}\sum_{k=0}^{n-1} \sqrt{2k+1} \underbrace{\left(2\chi_{n-1-k} + (-1)^{n-k}\right)}_{\equiv 1}c_k(T) \\ \end{aligned}\label{eq:leg-t}\end{equation} In the last step, the underbraced factor equals 1 for every $k<n$: when $n-k$ is odd, $\chi_{n-1-k}=1$ and $(-1)^{n-k}=-1$; when $n-k$ is even, $\chi_{n-1-k}=0$ and $(-1)^{n-k}=1$. Again, replacing $T$ with $t$ and grouping all $c_n(t)$ together as $x(t) = (c_0,c_1,\dots,c_N)$, we can write the above equation as: \begin{equation}\begin{aligned} x'(t) =&\, Ax(t) + Bu(t)\\[8pt] \quad A_{n,k} =&\, -\frac{1}{w}\left\{\begin{array}{ll}\sqrt{(2n+1)(2k+1)}, &k < n \\ (-1)^{n-k}\sqrt{(2n+1)(2k+1)}, &k \geq n\end{array}\right.\\[8pt] B_n =&\, \frac{1}{w}\sqrt{2(2n+1)} \end{aligned}\label{eq:leg-t-hippo-1}\end{equation} We can also introduce a scaling factor for each $c_n(T)$ to generalize the results. 
For example, if we let $c_n(T) = \lambda_n \tilde{c}_n(T)$, substituting into equation $\eqref{eq:leg-t}$ and rearranging yields: \begin{equation}\begin{aligned} \frac{d}{dT}\tilde{c}_n(T) \approx &\, \frac{\sqrt{2(2n+1)}}{w\lambda_n}u(T) - \frac{\sqrt{2n+1}}{w} \sum\limits_{k=n}^N (-1)^{n-k} \tilde{c}_k(T) \frac{\lambda_k\sqrt{2k+1}}{\lambda_n} \\ &\quad- \frac{\sqrt{2n+1}}{w}\sum_{k=0}^{n-1} \frac{\lambda_k\sqrt{2k+1}}{\lambda_n} \tilde{c}_k(T) \\ \end{aligned}\end{equation} If we take $\lambda_n = \sqrt{2}$, then $A$ remains unchanged and $B_n = \frac{1}{w}\sqrt{2n+1}$, which agrees with the results in the original paper. If we take $\lambda_n = \sqrt{\frac{2}{2n+1}}$, we get the result found in Legendre Memory Units: \begin{equation}\begin{aligned} x'(t) =&\, Ax(t) + Bu(t)\\[8pt] \quad A_{n,k} =&\, -\frac{1}{w}\left\{\begin{array}{ll}2n+1, &k < n \\ (-1)^{n-k}(2n+1), &k \geq n\end{array}\right.\\[8pt] B_n =&\, \frac{1}{w}(2n+1) \end{aligned}\label{eq:leg-t-hippo-2}\end{equation} These forms are theoretically equivalent but may differ in numerical stability. For example, as long as $u(t)$ is not too badly behaved, we can generally expect $|c_n|$ to become relatively smaller as $n$ grows. Using $c_n$ directly may therefore leave the components of $x(t)$ on very different scales, and such a system is prone to numerical stability issues in actual computation. Switching to $\tilde{c}_n = c_n\sqrt{\frac{2n+1}{2}}$ (i.e., $\lambda_n = \sqrt{\frac{2}{2n+1}}$) amplifies the components with smaller values, which may help mitigate multi-scale issues and make numerical computation more stable.
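A minimal NumPy sketch of the LegT system $\eqref{eq:leg-t-hippo-1}$ and its rescaled form follows; the helper names (`legt`, `rescale`, `legt_coeffs`) and the test values are hypothetical. The only approximation in the LegT derivation is the truncated expansion used for $u(T-w)$, so for a polynomial input of degree at most $N$ the system should be exact. The sketch checks this with a finite-difference derivative, and also checks that $\lambda_n=\sqrt{\frac{2}{2n+1}}$ produces the integer-valued form $\eqref{eq:leg-t-hippo-2}$.

```python
import numpy as np
from numpy.polynomial import legendre as L


def legt(N, w):
    """A, B of eq. (leg-t-hippo-1) for the state x = (c_0, ..., c_N) and window length w."""
    n = np.arange(N + 1)[:, None]
    k = np.arange(N + 1)[None, :]
    r = np.sqrt((2 * n + 1) * (2 * k + 1))
    A = -np.where(k < n, r, (-1.0) ** (n - k) * r) / w
    B = np.sqrt(2 * (2 * np.arange(N + 1) + 1)) / w
    return A, B


def rescale(A, B, lam):
    """System for x~ when x = diag(lam) x~: A_nk -> A_nk lam_k / lam_n, B_n -> B_n / lam_n."""
    return A * lam[None, :] / lam[:, None], B / lam


def legt_coeffs(u, T, w, N, M=64):
    """c_n(T) = ∫_{-1}^{1} u((s+1)w/2 + T - w) g_n(s) ds by Gauss-Legendre quadrature."""
    s, wt = L.leggauss(M)
    G = np.sqrt((2 * np.arange(N + 1) + 1) / 2)[:, None] * L.legvander(s, N).T
    return G @ (wt * u((s + 1) * w / 2 + T - w))


def u(t):
    return 1 - 2 * t + 0.5 * t ** 3       # hypothetical input of degree 3 <= N


N, w, T, h = 5, 2.0, 3.0, 1e-4
A, B = legt(N, w)
c = legt_coeffs(u, T, w, N)
dc_fd = (legt_coeffs(u, T + h, w, N) - legt_coeffs(u, T - h, w, N)) / (2 * h)
print(np.max(np.abs(dc_fd - (A @ c + B * u(T)))))   # ~0: exact for low-degree polynomials

n = np.arange(N + 1)
A2, B2 = rescale(A, B, np.sqrt(2 / (2 * n + 1)))
print(np.round(A2 * w, 6))   # entries -(2n+1) below the diagonal, -(-1)^(n-k)(2n+1) elsewhere
print(B2 * w)                # 2n+1
```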

Entire Interval

Now let's continue with another example: $t_{\leq T}(s) = (s + 1)T / 2$. This maps $[-1,1]$ uniformly to $[0,T]$, which means we have not discarded any historical information and treat all history equally. The original paper calls this case "LegS (Scaled Legendre)."

Similarly, substituting into equation $\eqref{eq:hippo-base}$ yields: \[\frac{d}{dT}c_n(T) = \frac{\sqrt{2(2n+1)}}{T}u(T) - \frac{1}{T}\int_{-1}^1 u((s + 1)T / 2) \left[g_n(s) + (s+1) g_n'(s)\right] ds\] Using formula $\eqref{eq:leg-dot-t1}$, we get: \begin{equation}\begin{aligned} &\,\int_{-1}^1 u((s + 1)T / 2) \left[g_n(s) + (s+1) g_n'(s)\right] ds \\ =&\,c_n(T) + \int_{-1}^1 u((s + 1)T / 2) (s+1) g_n'(s) ds \\ =&\, c_n(T) + \int_{-1}^1 u((s + 1)T / 2)(s+1) \sqrt{\frac{2n+1}{2}} p_n'(s) ds \\ =&\, c_n(T) + \int_{-1}^1 u((s + 1)T / 2)\sqrt{\frac{2n+1}{2}}\left[-(n+1) p_n(s) + \sum_{k=0}^n (2k + 1) p_k(s)\right] ds \\ =&\, c_n(T) + \int_{-1}^1 u((s + 1)T / 2)\left[-(n+1) g_n(s) + \sum_{k=0}^n \sqrt{(2n+1)(2k + 1)} g_k(s)\right] ds \\ =&\, -n c_n(T) + \sum_{k=0}^n \sqrt{(2n+1)(2k + 1)} c_k(T) \\ \end{aligned}\end{equation} Therefore, we have: \begin{equation}\frac{d}{dT}c_n(T) = \frac{\sqrt{2(2n+1)}}{T}u(T) - \frac{1}{T}\left(-n c_n(T) + \sum_{k=0}^n \sqrt{(2n+1)(2k + 1)} c_k(T)\right)\label{eq:leg-s}\end{equation} Replacing $T$ with $t$ and grouping all $c_n(t)$ together as $x(t) = (c_0,c_1,\dots,c_N)$, we can write the above equation as: \begin{equation}\begin{aligned} x'(t) =&\, \frac{A}{t}x(t) + \frac{B}{t}u(t)\\[8pt] \quad A_{n,k} =&\, -\left\{\begin{array}{ll}\sqrt{(2n+1)(2k+1)}, &k < n \\ n+1, &k = n \\ 0, &k > n\end{array}\right.\\[8pt] B_n =&\, \sqrt{2(2n+1)} \end{aligned}\label{eq:leg-s-hippo}\end{equation} Introducing scaling factors to generalize the result is also feasible: let $c_n(T) = \lambda_n \tilde{c}_n(T)$. Substituting into equation $\eqref{eq:leg-s}$ and rearranging: \[\frac{d}{dT}\tilde{c}_n(T) = \frac{\sqrt{2(2n+1)}}{T\lambda_n}u(T) - \frac{1}{T}\left(-n \tilde{c}_n(T) + \sum_{k=0}^n \frac{\sqrt{(2n+1)(2k + 1)}\lambda_k}{\lambda_n} \tilde{c}_k(T)\right)\] Taking $\lambda_n=\sqrt{2}$ leaves $A$ unchanged and turns $B$ into $B_n = \sqrt{2n+1}$, which agrees with the paper. 
Taking $\lambda_n=\sqrt{\frac{2}{2n+1}}$ removes the square roots, just as in the LegT case: \begin{equation}\begin{aligned} x'(t) =&\, \frac{A}{t}x(t) + \frac{B}{t}u(t)\\[8pt] \quad A_{n,k} =&\, -\left\{\begin{array}{ll}2n+1, &k < n \\ n+1, &k = n \\ 0, &k > n\end{array}\right.\\[8pt] B_n =&\, 2n+1 \end{aligned}\label{eq:leg-s-hippo-2}\end{equation} For reasons that are unclear, the original paper does not consider this variant.
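A similar sketch for LegS follows, assuming NumPy; the function names are hypothetical and $u=\sin$, $N=6$, $T=3$ are arbitrary choices. Because $(s+1)g_n'(s)$ lies exactly in the span of $g_0,\dots,g_n$, the derivation of $\eqref{eq:leg-s}$ involves no truncation, so $\frac{d}{dT}c(T)=(Ac(T)+Bu(T))/T$ should hold for any smooth input and any $N$. The sketch checks this with a finite difference and confirms the integer form $\eqref{eq:leg-s-hippo-2}$.

```python
import numpy as np
from numpy.polynomial import legendre as L


def legs(N):
    """A, B of eq. (leg-s-hippo); the dynamics are x'(t) = (A x(t) + B u(t)) / t."""
    n = np.arange(N + 1)[:, None]
    k = np.arange(N + 1)[None, :]
    A = -np.where(k < n, np.sqrt((2 * n + 1) * (2 * k + 1)), np.where(k == n, n + 1.0, 0.0))
    B = np.sqrt(2 * (2 * np.arange(N + 1) + 1))
    return A, B


def legs_coeffs(u, T, N, M=64):
    """c_n(T) = ∫_{-1}^{1} u((s+1)T/2) g_n(s) ds by Gauss-Legendre quadrature."""
    s, wt = L.leggauss(M)
    G = np.sqrt((2 * np.arange(N + 1) + 1) / 2)[:, None] * L.legvander(s, N).T
    return G @ (wt * u((s + 1) * T / 2))


N, T, h = 6, 3.0, 1e-4
A, B = legs(N)
c = legs_coeffs(np.sin, T, N)
dc_fd = (legs_coeffs(np.sin, T + h, N) - legs_coeffs(np.sin, T - h, N)) / (2 * h)
print(np.max(np.abs(dc_fd - (A @ c + B * np.sin(T)) / T)))   # ~0 although sin is not a polynomial

# Rescaling with lam_n = sqrt(2/(2n+1)) gives the integer form of eq. (leg-s-hippo-2).
lam = np.sqrt(2 / (2 * np.arange(N + 1) + 1))
A2, B2 = A * lam[None, :] / lam[:, None], B / lam
print(np.round(A2, 6))
print(np.round(B2, 6))
```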

Further Reflections

Looking back at the entire LegS derivation, a key step was decomposing $(s+1) g_n'(s)$ into a linear combination of $g_0(s), g_1(s), \dots, g_n(s)$. For a polynomial basis, $(s+1) g_n'(s)$ is a polynomial of degree $n$, so this decomposition holds exactly. For the Fourier series, however, the $g_n(s)$ are exponential functions, and a similar decomposition is impossible, or at least not exact. Thus, choosing orthogonal polynomials as the basis primarily serves to simplify the subsequent derivations.

It is worth noting that HiPPO is a bottom-up framework. It does not assume from the outset that the system must be linear. Instead, it starts from orthogonal basis approximation and derives from there that the coefficients evolve according to a linear ODE system. This way, we can be sure that, as long as we accept its assumptions, the capacity of a linear ODE system is sufficient, and there is no need to worry that linearity itself limits performance.

Moreover, for each of its solutions, HiPPO's assumptions and their physical meaning are very clear. So anyone reusing a HiPPO matrix in an SSM can read off, from the underlying assumptions, how the model stores history and how much it can store. For example, LegT only retains the information in the most recent window of size $w$, so using the LegT HiPPO matrix is similar to sliding-window attention. LegS can in theory capture the entire history, but it has a resolution issue: the dimension of $x(t)$ is the order of the fit, which is fixed, and fitting with a basis of the same order is naturally more accurate on short intervals than on long ones. It is like trying to take in a large picture all at once: we must stand further back, and thus see fewer details.
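The resolution trade-off of LegS can be illustrated directly: project a fixed signal over $[0,T]$ onto the same $N+1$ orthonormal Legendre functions and watch the reconstruction error grow with $T$. This is an illustrative sketch assuming NumPy, with $u=\sin$, $N=8$, and the listed values of $T$ chosen arbitrarily; the specific numbers depend on these choices.

```python
import numpy as np
from numpy.polynomial import legendre as L


def legs_fit_error(u, T, N, M=400):
    """RMS error of the degree-N orthonormal Legendre projection of u over [0, T]."""
    s, wt = L.leggauss(M)
    G = np.sqrt((2 * np.arange(N + 1) + 1) / 2)[:, None] * L.legvander(s, N).T
    vals = u((s + 1) * T / 2)
    c = G @ (wt * vals)                  # the LegS coefficients c_n(T)
    resid = vals - c @ G                 # u minus its reconstruction, at the nodes
    return np.sqrt(np.sum(wt * resid ** 2) / 2)   # RMS over s in [-1, 1]


errors = [legs_fit_error(np.sin, T, N=8) for T in (2.0, 10.0, 100.0)]
print(errors)   # same number of coefficients, longer history: the error grows
```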

Models such as RWKV and LRU do not reuse the HiPPO matrix but switch to trainable matrices. In principle, this has more potential to break through such bottlenecks, but the analysis above suggests that different linear ODE matrices simply correspond to different function bases, and in essence they may all be the coefficient dynamics of a finite-order basis approximation. If so, high resolution and long memory still cannot be had at the same time. To remember longer inputs while maintaining quality, one can only enlarge the model as a whole (i.e., the equivalent of increasing hidden_size), which is likely a characteristic of all linear systems.

Article Summary

This article has reproduced, as simply as possible, the main derivations of "HiPPO: Recurrent Memory with Optimal Polynomial Projections" (HiPPO for short). HiPPO derives a linear ODE system bottom-up from suitable memory assumptions and, for Legendre polynomials, obtains the corresponding analytical solutions (the HiPPO matrices). Its results have been used by many subsequent SSMs (State Space Models), making it an important foundational work for SSM.