Revisit SSM (II): Remaining Issues of HiPPO

By 苏剑林 | June 5, 2024

Picking up from where we left off, in the previous article "Revisit SSM (I): Linear Systems and the HiPPO Matrix", we discussed in detail the derivation of the HiPPO matrix within the HiPPO approximation framework. Its principle is to dynamically approximate a real-time updated function using orthogonal function bases. The dynamics of the projection coefficients happen to be a linear system, and if orthogonal polynomials are used as the basis, the core matrix of the linear system can be solved analytically; this matrix is known as the HiPPO matrix.

Of course, the previous article focused on the derivation of the HiPPO matrix and did not further analyze its properties. Additionally, questions such as "how to discretize it for application to actual data" and "whether bases other than polynomial bases can be solved analytically" were not discussed in detail. Next, we will supplement the discussion on these related issues.

Discretization Schemes

Assuming the reader has read and understood the content of the previous article, we will not provide excessive preamble here. In the previous article, we derived two types of linear ODE systems, which are:

\begin{align} &\text{HiPPO-LegT:}\quad x'(t) = Ax(t) + Bu(t) \label{eq:legt-ode}\\[5pt] &\text{HiPPO-LegS:}\quad x'(t) = \frac{A}{t}x(t) + \frac{B}{t}u(t) \label{eq:legs-ode}\end{align}

where $A, B$ are constant matrices independent of time $t$. The HiPPO matrix primarily refers to matrix $A$. In this section, we discuss the discretization of these two ODEs.

Input Transformation

In practical scenarios, the input data points are discrete sequences $u_0, u_1, u_2, \dots, u_k, \dots$, such as streaming audio signals or text vectors. We hope to use the above ODE systems to memorize these discrete points in real-time. To this end, we first define

\begin{equation}u(t) = u_k,\quad \text{if } t \in [k\epsilon, (k + 1)\epsilon)\end{equation}

where $\epsilon$ is the step size of discretization. This definition means that within the interval $[k\epsilon, (k + 1)\epsilon)$, $u(t)$ is a constant function equal to the value $u_k$. Obviously, defining $u(t)$ this way does not lose any information from the original $u_k$ sequence; therefore, memorizing $u(t)$ is equivalent to memorizing the $u_k$ sequence.

Transforming from $u_k$ to $u(t)$ turns the input signal back into a function on a continuous interval, which facilitates later operations such as integration; moreover, keeping it constant within each discretization interval simplifies the resulting discrete scheme.
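As a small illustration, here is a minimal numpy sketch of this zero-order-hold construction (the function name `zero_order_hold` is our own):

```python
import numpy as np

def zero_order_hold(u_seq, eps):
    """Turn the discrete sequence u_0, u_1, ... into the piecewise-constant
    function u(t) = u_k for t in [k*eps, (k+1)*eps)."""
    u_seq = np.asarray(u_seq)
    def u(t):
        return u_seq[int(t // eps)]  # index of the interval containing t
    return u

u = zero_order_hold([0.3, -1.2, 0.7], eps=0.5)
print(u(0.0), u(0.49), u(0.5))       # 0.3 0.3 -1.2
```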

LegT Version

Let's take the LegT-type ODE \eqref{eq:legt-ode} as an example and integrate both sides over $[t, t+\epsilon]$:

\begin{equation}x(t+\epsilon) - x(t) = A\int_t^{t+\epsilon} x(s)ds + B\int_t^{t+\epsilon}u(s)ds\end{equation}

where $t=k\epsilon$. According to the definition of $u(t)$, it is constant $u_k$ in the interval $[t, t + \epsilon)$, so the integral of $u(s)$ can be calculated directly:

\begin{equation}x(t+\epsilon) - x(t) = A\int_t^{t+\epsilon} x(s)ds + \epsilon B u_k\end{equation}

The subsequent result depends on how we approximate the integral of $x(s)$. If we assume $x(s)$ is approximately constant at $x(t)$ in the interval $[t, t + \epsilon)$, we get the Forward Euler method:

\begin{equation}x(t+\epsilon) - x(t) = \epsilon A x(t) + \epsilon B u_k \quad\Rightarrow\quad x(t+\epsilon) = (I + \epsilon A)x(t) + \epsilon B u_k\end{equation}

If we assume $x(s)$ is approximately constant at $x(t+\epsilon)$ in the interval $[t, t + \epsilon)$, we get the Backward Euler method:

\begin{equation}x(t+\epsilon) - x(t) = \epsilon A x(t+\epsilon) + \epsilon B u_k \quad\Rightarrow\quad x(t+\epsilon) = (I - \epsilon A)^{-1}(x(t) + \epsilon B u_k)\end{equation}

Forward and Backward Euler both have first-order accuracy, but the backward method usually enjoys better numerical stability. To be more accurate, if we assume $x(s)$ is approximately constant at $\frac{1}{2}[x(t) + x(t+\epsilon)]$ in the interval $[t, t + \epsilon)$, we obtain the bilinear form:

\begin{equation}\begin{gathered} x(t+\epsilon) - x(t) = \frac{1}{2}\epsilon A [x(t) + x(t+\epsilon)] + \epsilon B u_k \\ \Downarrow \\ x(t+\epsilon) = (I - \epsilon A/2)^{-1}[(I + \epsilon A/2) x(t) + \epsilon B u_k] \end{gathered}\end{equation}

This is also equivalent to first taking half a step with Forward Euler and then another half step with Backward Euler. More generally, we could assume $x(s)$ is approximately constant at $\alpha x(t) + (1 - \alpha) x(t+\epsilon)$ for some $\alpha \in [0,1]$, which we will not expand upon here. In fact, we can avoid approximation entirely: since $u(s)$ is the constant $u_k$ on $[t, t + \epsilon)$, Equation \eqref{eq:legt-ode} can be solved exactly by the "variation of parameters" method. The result is:

\begin{equation}x(t+\epsilon) = e^{\epsilon A} x(t) + A^{-1} (e^{\epsilon A} - I) B u_k\label{eq:legt-ode-sol}\end{equation}

The matrix exponential here is defined by its power series; you can refer to "Appreciating the identity det(exp(A)) = exp(Tr(A))".
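To make the four schemes concrete, here is a minimal numpy/scipy sketch that precomputes the per-step transition matrices for each (the function name and dictionary interface are our own; we assume $A$ is invertible, as it is for the HiPPO matrices):

```python
import numpy as np
from scipy.linalg import expm

def legt_discretize(A, B, eps):
    """For x'(t) = A x(t) + B u(t) with step size eps, return, per scheme,
    matrices (M, N) such that x_{k+1} = M x_k + N u_k."""
    I = np.eye(A.shape[0])
    out = {}
    # Forward Euler: x_{k+1} = (I + eps*A) x_k + eps*B u_k
    out["forward_euler"] = (I + eps * A, eps * B)
    # Backward Euler: x_{k+1} = (I - eps*A)^{-1} (x_k + eps*B u_k)
    Pb = np.linalg.inv(I - eps * A)
    out["backward_euler"] = (Pb, Pb @ (eps * B))
    # Bilinear: x_{k+1} = (I - eps*A/2)^{-1} [(I + eps*A/2) x_k + eps*B u_k]
    Pt = np.linalg.inv(I - eps * A / 2)
    out["bilinear"] = (Pt @ (I + eps * A / 2), Pt @ (eps * B))
    # Exact (zero-order hold): x_{k+1} = e^{eps A} x_k + A^{-1}(e^{eps A} - I) B u_k
    E = expm(eps * A)
    out["exact"] = (E, np.linalg.solve(A, (E - I) @ B))
    return out
```

Since all four pairs $(M, N)$ are constant in $t$, each needs to be computed only once, after which every step is a single matrix-vector product.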

LegS Version

Now for the LegS-type ODE. The logic is basically consistent with LegT, and the results are quite similar. First, integrate both sides of Equation \eqref{eq:legs-ode}:

\begin{equation}x(t+\epsilon) - x(t) = A\int_t^{t+\epsilon} \frac{x(s)}{s}ds + B\int_t^{t+\epsilon}\frac{u(s)}{s}ds\end{equation}

By the definition of $u(t)$, the $u(s)$ in the second integral is the constant $u_k$ on $[t, t+\epsilon)$, so that integral reduces to the integral of $1/s$, which evaluates to $\ln\frac{t+\epsilon}{t}$. Replacing it with the first-order approximation $\frac{\epsilon}{t}$ (as we do below) is also fine: the transformation from $u_k$ to $u(t)$ leaves plenty of freedom, and this error is negligible. As for the first integral, we directly adopt the higher-precision trapezoidal approximation, averaging the endpoint values of the integrand as in the bilinear scheme above, obtaining:

\begin{equation}\begin{gathered} x(t+\epsilon) - x(t) = \frac{1}{2}\epsilon A\left(\frac{x(t)}{t}+\frac{x(t+\epsilon)}{t+\epsilon}\right) + \frac{\epsilon}{t} B u_k \\[5pt] \Downarrow \\[5pt] x(t+\epsilon) = \left(I - \frac{\epsilon A}{2(t+\epsilon)}\right)^{-1}\left[\left(I + \frac{\epsilon A}{2t}\right)x(t) + \frac{\epsilon}{t} B u_k\right] \end{gathered}\label{eq:legs-ode-bilinear}\end{equation}

In fact, Equation \eqref{eq:legs-ode} can also be solved exactly by noting it is equivalent to:

\begin{equation}Ax(t) + Bu(t) = t x'(t) = \frac{d}{d\ln t} x(t)\end{equation}

This means by performing a change of variable $\tau = \ln t$, the LegS-type ODE can be converted into a LegT-type ODE:

\begin{equation}\frac{d}{d\tau} x(e^{\tau}) = Ax(e^{\tau}) + Bu(e^{\tau})\end{equation}

Using Equation \eqref{eq:legt-ode-sol} we obtain (due to the variable substitution, the time interval changes from $\epsilon$ to $\ln(t+\epsilon) - \ln t$):

\begin{equation}x(t+\epsilon) = e^{(\ln(t+\epsilon) - \ln t) A} x(t) + A^{-1} \big(e^{(\ln(t+\epsilon) - \ln t) A} - I\big) B u_k\label{eq:legs-ode-sol}\end{equation}

However, although the above is an exact solution, it is not as practical as the exact solution in Equation \eqref{eq:legt-ode-sol}. In Equation \eqref{eq:legt-ode-sol}, the matrix exponential part is $e^{\epsilon A}$, which is independent of time $t$ and can be computed once. But in the equation above, $t$ is inside the matrix exponential, meaning the matrix exponential must be recalculated repeatedly during iteration, which is computationally unfriendly. Thus, for the LegS-type ODE, we generally only use Equation \eqref{eq:legs-ode-bilinear} for discretization.

Excellent Properties

Next, LegS is our main focus. The reason for focusing on LegS is easy to guess: based on the derivation assumptions, it is currently the only solved ODE system capable of memorizing the entire history, which is crucial for many scenarios like multi-turn dialogues. Additionally, it possesses other favorable and practical properties.

Timescale Equivariance

For instance, the discretization scheme \eqref{eq:legs-ode-bilinear} of LegS is step-size independent. We only need to substitute $t=k\epsilon$ and write $x(k\epsilon)=x_k$ to find that:

\begin{equation} x_{k+1} = \left(I - \frac{A}{2(k + 1)}\right)^{-1}\left[\left(I + \frac{A}{2k}\right)x_k + \frac{1}{k} B u_k\right]\end{equation}

The step size $\epsilon$ cancels automatically, which removes one hyperparameter to tune, clearly good news for model practitioners. Note that step-size independence is an inherent property of LegS-type ODEs and is not tied to the specific discretization method. For example, the exact solution \eqref{eq:legs-ode-sol} is also step-size independent:

\begin{equation}x_{k+1} = e^{(\ln(k+1) - \ln k) A} x_k + A^{-1} \big(e^{(\ln(k+1) - \ln k) A} - I\big) B u_k\label{eq:legs-ode-sol-2}\end{equation}

The underlying reason is that LegS-type ODEs satisfy "timescale equivariance": if we substitute $t=\alpha\tau$ into the LegS ODE, we get:

\begin{equation}Ax(\alpha\tau) + Bu(\alpha\tau) = (\alpha\tau)\times \frac{d}{d(\alpha\tau)} x(\alpha\tau) = \tau \frac{d}{d\tau}x(\alpha\tau)\end{equation}

This means that when we replace $u(t)$ with $u(\alpha t)$, the form of the LegS ODE remains unchanged, while the corresponding solution changes from $x(t)$ to $x(\alpha t)$. The direct consequence is that choosing a larger step size requires no change to the recursion, because the timescale of the resulting $x_k$ rescales automatically. This is the fundamental reason why the discretization of LegS-type ODEs is step-size independent.
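As a concrete check, here is a minimal numpy sketch of the recursion for $x_{k+1}$ above; note that no step size appears anywhere in the code (the function name is ours, and we start at $k=1$ to sidestep the $1/k$ singularity at $k=0$, an initialization choice the derivation leaves open):

```python
import numpy as np

def legs_recurrence(A, B, u_seq, x0=None):
    """Run x_{k+1} = (I - A/(2(k+1)))^{-1} [(I + A/(2k)) x_k + (1/k) B u_k].
    No step size appears: the recursion is identical whatever sampling
    rate the sequence u_seq came from."""
    d = A.shape[0]
    I = np.eye(d)
    x = np.zeros(d) if x0 is None else x0
    history = []
    for k, u_k in enumerate(u_seq, start=1):  # start at k=1: 1/k is singular at 0
        rhs = (I + A / (2 * k)) @ x + (B / k) * u_k
        x = np.linalg.solve(I - A / (2 * (k + 1)), rhs)
        history.append(x)
    return np.stack(history)
```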

Polynomial Decay

Another excellent property of the LegS-type ODE is that its memory of historical signals undergoes polynomial decay. This is slower than the exponential decay of conventional RNNs, theoretically allowing longer histories to be remembered and making it less prone to vanishing gradients. To understand this, we can start from the exact solution \eqref{eq:legs-ode-sol-2}: the decay of historical information per recursion step is described by the matrix exponential $e^{(\ln(k+1) - \ln k) A}$. Thus, the total decay from step $m$ to step $n$ is:

\begin{equation}\prod_{k=m}^{n-1} e^{(\ln(k+1) - \ln k) A} = e^{(\ln n - \ln m) A}\end{equation}

Review the form of $A$ in HiPPO-LegS:

\begin{equation}A_{n,k} = -\left\{\begin{array}{ll}\sqrt{(2n+1)(2k+1)}, &k < n \\ n+1, &k = n \\ 0, &k > n\end{array}\right.\end{equation}

As seen from the definition, $A$ is a lower triangular matrix with diagonal elements $-1, -2, -3, \dots$. We know that the diagonal elements of a triangular matrix are exactly its eigenvalues (refer to Triangular matrix). From this we see that the $d\times d$ matrix $A$ has $d$ distinct eigenvalues $-1, -2, \dots, -d$, which implies $A$ is diagonalizable, i.e., there exists an invertible matrix $P$ such that $A = P^{-1}\Lambda P$, where $\Lambda = \text{diag}(-1, -2, \dots, -d)$. Thus, we have:

\begin{equation}\begin{aligned} e^{(\ln n - \ln m) A} =&\, e^{(\ln n - \ln m) P^{-1}\Lambda P} \\ =&\, P^{-1} e^{(\ln n - \ln m) \Lambda}P \\ =&\, P^{-1}\,\text{diag}(e^{-(\ln n - \ln m)}, e^{-2(\ln n - \ln m)}, \dots, e^{-d(\ln n - \ln m)})\,P \\ =&\, P^{-1}\,\text{diag}\Big(\frac{m}{n}, \frac{m^2}{n^2}, \dots, \frac{m^d}{n^d}\Big)\,P \\ \end{aligned}\end{equation}

Evidently, the total decay factor is a linear combination of the powers $\frac{m}{n}, \frac{m^2}{n^2}, \dots, \frac{m^d}{n^d}$. Therefore, the memory of LegS-type ODEs decays at most polynomially (like $m/n$ for the slowest mode), which is more "long-tailed" than exponential decay, leading to theoretically better memory retention.
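The following numpy/scipy sketch checks both claims numerically: the eigenvalues of the HiPPO-LegS matrix are $-1, \dots, -d$, and the slowest mode of the decay factor $e^{(\ln n - \ln m)A}$ behaves like $m/n$ (the helper `hippo_legs_A` is our own construction of the matrix above):

```python
import numpy as np
from scipy.linalg import expm

def hippo_legs_A(d):
    """HiPPO-LegS matrix: A[n,k] = -sqrt((2n+1)(2k+1)) for k < n,
    A[n,n] = -(n+1), zero above the diagonal."""
    n = np.arange(d)
    A = np.tril(-np.sqrt(np.outer(2 * n + 1, 2 * n + 1)), -1)
    return A + np.diag(-(n + 1.0))

d = 8
A = hippo_legs_A(d)
print(np.sort(np.linalg.eigvals(A).real))   # [-8. -7. ... -2. -1.]

m, n = 100, 1000
decay = expm((np.log(n) - np.log(m)) * A)
# The largest eigenvalue magnitude of the decay matrix is (m/n)^1 = 0.1,
# i.e. polynomial rather than exponential decay in n.
print(np.max(np.abs(np.linalg.eigvals(decay))))
```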

Computational Efficiency

Finally, we point out that the $A$ matrix of HiPPO-LegS is computationally efficient. Specifically, while a naive implementation of matrix multiplication for a $d\times d$ matrix multiplied by a $d\times 1$ column vector requires $d^2$ multiplications, multiplication of LegS's $A$ matrix with a vector can be reduced to $\mathcal{O}(d)$. Furthermore, we can prove that the discretized version \eqref{eq:legs-ode-bilinear} can also be completed in $\mathcal{O}(d)$.

To understand this, we first rewrite the $A$ matrix of HiPPO-LegS equivalently as:

\begin{equation}A_{n,k} = \left\{\begin{array}{ll}n\delta_{n,k} - \sqrt{2n+1}\sqrt{2k+1}, &k \leq n \\ 0, &k > n\end{array}\right.\end{equation}

For a vector $v = [v_0, v_1, \dots, v_{d-1}]$, we have:

\begin{equation}\begin{aligned} (Av)_n = \sum_{k=0}^n A_{n,k}v_k =&\, \sum_{k=0}^n \left(n\delta_{n,k} - \sqrt{2n+1}\sqrt{2k+1}\right)v_k \\ =&\, n v_n -\sqrt{2n+1}\sum_{k=0}^n \sqrt{2k+1}v_k \end{aligned}\end{equation}

This involves three kinds of operations: the first term $n v_n$ is the element-wise product of the vector $[0, 1, 2, \dots, d-1]$ with $v$; the second term multiplies $v$ element-wise by the vector $[1, \sqrt{3}, \sqrt{5}, \dots, \sqrt{2d-1}]$, applies a $\text{cumsum}$, and then multiplies the result element-wise by the same vector $[1, \sqrt{3}, \sqrt{5}, \dots, \sqrt{2d-1}]$ (whose $n$-th entry is $\sqrt{2n+1}$). Each step costs $\mathcal{O}(d)$, so the total complexity is $\mathcal{O}(d)$.
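In code, this is just two element-wise products and a cumsum; here is a minimal numpy sketch with a sanity check against the dense $\mathcal{O}(d^2)$ product (the function name is ours):

```python
import numpy as np

def legs_matvec(v):
    """Compute A v for the HiPPO-LegS matrix in O(d):
    (Av)_n = n*v_n - sqrt(2n+1) * sum_{k<=n} sqrt(2k+1)*v_k."""
    n = np.arange(len(v))
    r = np.sqrt(2 * n + 1)          # [1, sqrt(3), sqrt(5), ...]
    return n * v - r * np.cumsum(r * v)

# Sanity check against the dense O(d^2) matrix-vector product.
d = 6
n = np.arange(d)
A = np.tril(-np.sqrt(np.outer(2 * n + 1, 2 * n + 1)), -1) + np.diag(-(n + 1.0))
v = np.random.randn(d)
assert np.allclose(A @ v, legs_matvec(v))
```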

Now let's look at Equation \eqref{eq:legs-ode-bilinear}. It contains two "matrix-vector" multiplications. One is $(I+\lambda A)v$, where $\lambda$ is an arbitrary real number; we just proved $Av$ is efficient, so naturally $(I+\lambda A)v$ is too. The second is $(I-\lambda A)^{-1}v$. Next, we will prove this is also efficient. This only requires noting that finding $z=(I-\lambda A)^{-1}v$ is equivalent to solving the equation $v = (I-\lambda A)z$. Using the expression for $Av$ given above, we get:

\begin{equation}v_n = z_n - \lambda \left(n z_n - \sqrt{2n+1}\sum_{k=0}^n \sqrt{2k+1}z_k\right)\end{equation}

Let $S_n = \sum_{k=0}^n \sqrt{2k+1}\,z_k$, with $S_{-1} = 0$; then $z_n = \frac{S_n - S_{n-1}}{\sqrt{2n+1}}$. Substituting this into the equation above yields:

\begin{equation}v_n = \frac{S_n - S_{n-1}}{\sqrt{2n+1}} - \lambda \left(n \frac{S_n - S_{n-1}}{\sqrt{2n+1}} - \sqrt{2n+1}S_n\right)\end{equation}

Rearranging terms gives:

\begin{equation}S_n = \frac{1 - \lambda n}{1+\lambda n + \lambda}S_{n-1} + \frac{\sqrt{2n+1}}{1+\lambda n + \lambda}v_n\end{equation}

This is a scalar recursion that can be computed fully serially, or parallelized with prefix-sum (scan) algorithms. The computational complexity is $\mathcal{O}(d)$ (serial) or $\mathcal{O}(d\log d)$ (parallel scan), far more efficient than $\mathcal{O}(d^2)$.
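Here is a minimal serial numpy sketch of this recursion, checked against a dense solve (the function name is ours; a parallel version would push the per-$n$ coefficient pairs through an associative scan instead of the Python loop):

```python
import numpy as np

def legs_solve(v, lam):
    """Solve (I - lam*A) z = v in O(d) for the HiPPO-LegS matrix A, via
      S_n = (1 - lam*n)/(1 + lam*n + lam) * S_{n-1}
            + sqrt(2n+1)/(1 + lam*n + lam) * v_n,   S_{-1} = 0,
    and then z_n = (S_n - S_{n-1}) / sqrt(2n+1)."""
    d = len(v)
    S = np.zeros(d + 1)             # S[0] plays the role of S_{-1} = 0
    z = np.zeros(d)
    for n in range(d):
        denom = 1 + lam * n + lam
        S[n + 1] = ((1 - lam * n) * S[n] + np.sqrt(2 * n + 1) * v[n]) / denom
        z[n] = (S[n + 1] - S[n]) / np.sqrt(2 * n + 1)
    return z

# Sanity check against a dense O(d^2)-per-step solve.
d, lam = 6, 0.3
n = np.arange(d)
A = np.tril(-np.sqrt(np.outer(2 * n + 1, 2 * n + 1)), -1) + np.diag(-(n + 1.0))
v = np.random.randn(d)
assert np.allclose(np.linalg.solve(np.eye(d) - lam * A, v), legs_solve(v, lam))
```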

Fourier Basis

Finally, we conclude with a derivation using the Fourier basis. In the previous article, we used Fourier series to introduce linear systems but only derived results for the sliding window form. For the Legendre polynomial basis, we derived both sliding window and full interval versions (LegT and LegS). So, can the Fourier basis derive a version equivalent to LegS? What difficulties would be involved? We explore this below.

Similarly, we will not repeat the preamble. Following the notation of the previous article, the coefficients for the Fourier basis are:

\begin{equation}c_n(T) = \int_0^1 u(t_{\leq T}(s)) e^{-2i\pi n s}ds\end{equation}

Like LegS, to memorize the signal over the entire $[0, T]$ interval, we need a mapping $[0, 1] \mapsto [0, T]$. Choosing the simplest $t_{\leq T}(s)=sT$ and taking the derivative of both sides with respect to $T$ gives:

\begin{equation}\frac{d}{dT}c_n(T) = \int_0^1 u'(sT) s e^{-2i\pi n s}ds\end{equation}

Integrating by parts yields:

\begin{equation}\begin{aligned} \frac{d}{dT}c_n(T) =&\, \frac{1}{T}\int_0^1 s e^{-2i\pi n s}d u(sT) \\ =&\, \frac{1}{T} u(sT) s e^{-2i\pi n s}\Big|_{s=0}^{s=1} - \frac{1}{T}\int_0^1 u(sT) d(s e^{-2i\pi n s})\\ =&\, \frac{1}{T} u(T) - \frac{1}{T}\int_0^1 u(sT) e^{-2i\pi n s} ds + \frac{2i\pi n}{T}\int_0^1 u(sT) s e^{-2i\pi n s} ds\\ =&\, \frac{1}{T} u(T) - \frac{1}{T}c_n(T) + \frac{2i\pi n}{T}\int_0^1 u(sT) s e^{-2i\pi n s} ds\\ \end{aligned}\end{equation}

As mentioned in the previous article, one of the important reasons HiPPO selects Legendre polynomials as the basis is that $(s+1)p_n'(s)$ can be decomposed into a linear combination of $p_0(s), p_1(s), \dots, p_n(s)$, whereas $s e^{-2i\pi n s}$ for the Fourier basis cannot. However, if we allow some error, this assertion no longer holds, since we can likewise expand $s$ itself as a Fourier series:

\begin{equation}s = \frac{1}{2} + \frac{i}{2\pi}\sum_{k\neq 0} \frac{1}{k} e^{2i\pi k s}\end{equation}

This series has infinitely many terms, and truncating it to finitely many introduces error. But we can ignore this for now and substitute it directly into the equation:

\begin{equation}\begin{aligned} &\, \frac{2i\pi n}{T}\int_0^1 u(sT) s e^{-2i\pi n s} ds \\ =&\, \frac{2i\pi n}{T}\int_0^1 u(sT) \left(\frac{1}{2} + \frac{i}{2\pi}\sum_{k\neq 0} \frac{1}{k} e^{2i\pi k s}\right) e^{-2i\pi n s} ds \\ =&\, \frac{i\pi n}{T}\int_0^1 u(sT) e^{-2i\pi n s} ds - \frac{1}{T}\sum_{k\neq 0} \frac{n}{k}\int_0^1 u(sT) e^{-2i\pi (n - k) s} ds \\ =&\, \frac{i\pi n}{T}c_n(T) - \frac{1}{T}\sum_{k\neq 0} \frac{n}{k}c_{n-k}(T) \\ =&\, \frac{i\pi n}{T}c_n(T) - \frac{1}{T}\sum_{k\neq n} \frac{n}{n - k}c_k(T) \\ \end{aligned}\end{equation}

Thus:

\begin{equation} \frac{d}{dT}c_n(T) = \frac{1}{T} u(T) + \frac{i\pi n - 1}{T}c_n(T) - \frac{1}{T}\sum_{k\neq n} \frac{n}{n - k}c_k(T)\end{equation}

So it can be written as:

\begin{equation}\begin{aligned} x'(t) =&\, \frac{A}{t}x(t) + \frac{B}{t}u(t)\\[8pt] \quad A_{n,k} =&\, \left\{\begin{array}{ll}-\frac{n}{n-k}, &k \neq n \\[5pt] i\pi n - 1, &k = n\end{array}\right.\\[8pt] B_n =&\, 1 \end{aligned}\end{equation}

In practice, we only need to truncate to $|n|, |k| \leq N$ to obtain a $(2N+1)\times (2N+1)$ matrix. The truncation error is not really a concern: we already introduced finite-series approximations when deriving HiPPO-LegT, without accounting for the error there either. Put the other way around, for a specific task we choose an appropriate scale $N$, and part of what "appropriate" means is that the truncation error is negligible for that task.
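For concreteness, here is a sketch that materializes this truncated complex matrix (the helper name and the indexing convention $n, k \in \{-N, \dots, N\}$ are our own choices):

```python
import numpy as np

def hippo_fourier_A(N):
    """Truncated A matrix of the Fourier-basis HiPPO variant derived above:
    A[n,k] = i*pi*n - 1 on the diagonal, -n/(n-k) off the diagonal,
    with n, k running over -N, ..., N."""
    idx = np.arange(-N, N + 1)
    n, k = np.meshgrid(idx, idx, indexing="ij")
    denom = np.where(n == k, 1, n - k)      # dummy 1 on the diagonal avoids 0/0
    A = np.where(n == k, 0.0, -n / denom).astype(complex)
    A[np.diag_indices_from(A)] = 1j * np.pi * idx - 1
    return A

A = hippo_fourier_A(5)
B = np.ones(A.shape[0])                     # B_n = 1 for every n
print(A.shape, A.dtype)                     # (11, 11) complex128
```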

For most people, this derivation for the Fourier basis might even be easier to understand because Legendre polynomials are unfamiliar to many readers, especially the identities used in the LegT and LegS derivations. Readers are usually more or less familiar with Fourier series. However, from the results, this Fourier basis version might not be as practical as LegS. Firstly, it introduces complex numbers, which increases implementation complexity. Secondly, the derived $A$ matrix is not a relatively simple lower triangular matrix like in LegS, making theoretical analysis more complex. Therefore, one can treat it as an exercise to deepen the understanding of HiPPO.

Article Summary

In this article, we supplemented the discussion on remaining issues of HiPPO introduced in the previous article. This included how to discretize ODEs, some excellent properties of the LegS-type ODE, and the derivation of results for memorizing the entire historical interval using the Fourier basis (the Fourier version of LegS), aiming to gain a more comprehensive understanding of HiPPO.