The Annotated S4

Efficiently Modeling Long Sequences with Structured State Spaces

Albert Gu, Karan Goel, and Christopher Ré.

Blog Post and Library by Sasha Rush and Sidd Karamcheti

The Structured State Space for Sequence Modeling (S4) architecture is a new approach to very long-range sequence modeling tasks for vision, language, and audio, showing a capacity to capture dependencies over tens of thousands of steps. Especially impressive are the model’s results on the challenging Long Range Arena benchmark, showing an ability to reason over sequences of up to 16,000+ elements with high accuracy.

The paper is also a refreshing departure from Transformers, taking a very different approach to an important problem space. However, several of our colleagues have also noted privately (and on Twitter!) the difficulty of gaining intuition for the model. This blog post is a first step towards that goal, linking concrete code implementations with explanations from the S4 paper – very much in the style of the Annotated Transformer. Hopefully this combination of code and literate explanations helps you follow the details of the model.


Note that this project uses JAX with the Flax NN library. While we personally mainly use Torch, the functional nature of JAX is a good fit for some of the complexities of S4. We make heavy use of vmap, scan, their NN cousins, and most importantly jax.jit to compile fast and efficient S4 layers.

from functools import partial
import jax
import jax.numpy as np
from flax import linen as nn
from jax.nn.initializers import lecun_normal
from jax.numpy.linalg import eig, inv, matrix_power
from jax.scipy.signal import convolve
rng = jax.random.PRNGKey(1)

Part 1: State Space Models

Let’s get started! Our goal is the efficient modeling of long sequences. To do this, we are going to build a new neural network layer based on State Space Models. By the end of this section we will be able to build and run a model with this layer. However, we are going to need some technical background. Let’s work our way through the background of the paper.

The state space model is defined by this simple equation. It maps a 1-D input signal $u(t)$ to an $N$-D latent state $x(t)$ before projecting to a 1-D output signal $y(t)$.

$$\begin{aligned} x'(t) &= \boldsymbol{A}x(t) + \boldsymbol{B}u(t) \\ y(t) &= \boldsymbol{C}x(t) + \boldsymbol{D}u(t) \end{aligned}$$

Our goal is to simply use the SSM as a black-box representation in a deep sequence model, where $\boldsymbol{A}, \boldsymbol{B}, \boldsymbol{C}, \boldsymbol{D}$ are parameters learned by gradient descent. For the remainder, we will omit the parameter $\boldsymbol{D}$ for exposition (or equivalently, assume $\boldsymbol{D} = 0$), because the term $\boldsymbol{D}u$ can be viewed as a skip connection and is easy to compute.

An SSM maps an input $u(t)$ to a state representation vector $x(t)$ and an output $y(t)$. For simplicity, we assume the input and output are one-dimensional, and the state representation is $N$-dimensional. The first equation defines the change in $x(t)$ over time.

Our SSMs will be defined by three matrices – $\boldsymbol{A}, \boldsymbol{B}, \boldsymbol{C}$ – which we will learn. For now we begin with a random SSM, to define sizes,

def random_SSM(rng, N):
    a_r, b_r, c_r = jax.random.split(rng, 3)
    A = jax.random.uniform(a_r, (N, N))
    B = jax.random.uniform(b_r, (N, 1))
    C = jax.random.uniform(c_r, (1, N))
    return A, B, C

Discrete-time SSM: The Recurrent Representation

To be applied on a discrete input sequence $(u_0, u_1, \dots)$ instead of a continuous function $u(t)$, the SSM must be discretized by a step size $\Delta$ that represents the resolution of the input. Conceptually, the inputs $u_k$ can be viewed as sampling an implicit underlying continuous signal $u(t)$, where $u_k = u(k\Delta)$.

To discretize the continuous-time SSM, we use the bilinear method, which converts the state matrix $\boldsymbol{A}$ into an approximation $\boldsymbol{\overline{A}}$. The discrete SSM is:

$$\begin{aligned} \boldsymbol{\overline{A}} &= (\boldsymbol{I} - \Delta/2 \cdot \boldsymbol{A})^{-1}(\boldsymbol{I} + \Delta/2 \cdot \boldsymbol{A}) \\ \boldsymbol{\overline{B}} &= (\boldsymbol{I} - \Delta/2 \cdot \boldsymbol{A})^{-1} \Delta \boldsymbol{B} \\ \boldsymbol{\overline{C}} &= \boldsymbol{C} \end{aligned}$$

def discretize(A, B, C, step):
    I = np.eye(A.shape[0])
    BL = inv(I - (step / 2.0) * A)
    Ab = BL @ (I + (step / 2.0) * A)
    Bb = (BL * step) @ B
    return Ab, Bb, C

This equation is now a sequence-to-sequence map $u_k \mapsto y_k$ instead of function-to-function. Moreover the state equation is now a recurrence in $x_k$, allowing the discrete SSM to be computed like an RNN. Concretely, $x_k \in \mathbb{R}^N$ can be viewed as a hidden state with transition matrix $\boldsymbol{\overline{A}}$.

$$\begin{aligned} x_{k} &= \boldsymbol{\overline{A}} x_{k-1} + \boldsymbol{\overline{B}} u_k \\ y_k &= \boldsymbol{\overline{C}} x_k \end{aligned}$$

As the paper says, this “step” function does look superficially like that of an RNN. We can implement this with a scan in JAX,

def scan_SSM(Ab, Bb, Cb, u, x0):
    def step(x_k_1, u_k):
        x_k = Ab @ x_k_1 + Bb @ u_k
        y_k = Cb @ x_k
        return x_k, y_k

    return jax.lax.scan(step, x0, u)[1]

Putting everything together, we can run the SSM by first discretizing, then iterating step by step,

def run_SSM(A, B, C, u):
    L = u.shape[0]
    N = A.shape[0]
    Ab, Bb, Cb = discretize(A, B, C, step=1.0 / L)

    # Run recurrence
    return scan_SSM(Ab, Bb, Cb, u[:, np.newaxis], np.zeros((N,)))

Tangent: A Mechanics Example

To gain some more intuition and test our SSM implementation, we pause from machine learning to implement a classic example from mechanics.

In this example, we consider the forward position $y(t)$ of a mass attached to a wall with a spring. Over time, a varying force $u(t)$ is applied to this mass. The system is parameterized by the mass ($m$), spring constant ($k$), and friction constant ($b$). We can relate these with the following differential equation:

$$my''(t) = u(t) - by'(t) - ky(t)$$

Rewriting this in matrix form yields an SSM in the following form:

$$\begin{aligned} \boldsymbol{A} &= \begin{bmatrix} 0 & 1 \\ -k/m & -b/m \end{bmatrix} & \boldsymbol{B} &= \begin{bmatrix} 0 \\ 1/m \end{bmatrix} & \boldsymbol{C} &= \begin{bmatrix} 1 & 0 \end{bmatrix} \end{aligned}$$

def example_mass(k, b, m):
    A = np.array([[0, 1], [-k / m, -b / m]])
    B = np.array([[0], [1.0 / m]])
    C = np.array([[1.0, 0]])
    return A, B, C

Looking at $\boldsymbol{C}$, we should be able to convince ourselves that the first dimension of the hidden state is the position (since that becomes $y(t)$). The second dimension is the velocity, as it is impacted by $u(t)$ through $\boldsymbol{B}$. The transition $\boldsymbol{A}$ relates these terms.

We’ll set $u$ to be a continuous function of $t$,

@partial(np.vectorize, signature="()->()")
def example_force(t):
    x = np.sin(10 * t)
    return x * (x > 0.5)

Let’s run this SSM through our code.

def example_ssm():
    # SSM
    ssm = example_mass(k=40, b=5, m=1)

    # L samples of u(t).
    L = 100
    step = 1.0 / L
    ks = np.arange(L)
    u = example_force(ks * step)

    # Approximation of y(t).
    y = run_SSM(*ssm, u)

    # Plotting ---
    import matplotlib.pyplot as plt
    import seaborn
    from celluloid import Camera

    seaborn.set_context("paper")
    fig, (ax1, ax2, ax3) = plt.subplots(3)
    camera = Camera(fig)
    ax1.set_title("Force $u_k$")
    ax2.set_title("Position $y_k$")
    ax3.set_title("Object")
    ax1.set_xticks([], [])
    ax2.set_xticks([], [])

    # Animate plot over time
    for k in range(0, L, 2):
        ax1.plot(ks[:k], u[:k], color="red")
        ax2.plot(ks[:k], y[:k], color="blue")
        ax3.boxplot(
            [[y[k, 0] - 0.04, y[k, 0], y[k, 0] + 0.04]],
            showcaps=False,
            whis=False,
            vert=False,
            widths=10,
        )
        camera.snap()
    anim = camera.animate()
    anim.save("line.gif", dpi=150, writer="imagemagick")
example_ssm()

Neat! And that was just one SSM, with 2 hidden states, over 100 steps. The final model will have hundreds of stacked SSMs over thousands of steps. But first – we need to make these models practical to train.

Training SSMs: The Convolutional Representation

The punchline of this section is that we can turn the “RNN” above into a “CNN” by unrolling. Let’s go through the derivation.

The recurrent SSM is not practical for training on modern hardware due to its sequential nature. Instead, there is a well-known connection between linear time-invariant (LTI) SSMs and continuous convolutions. Correspondingly, the recurrent SSM can actually be written as a discrete convolution.

For simplicity let the initial state be $x_{-1} = 0$. Then unrolling explicitly yields:

$$\begin{aligned} x_0 &= \boldsymbol{\overline{B}} u_0 & x_1 &= \boldsymbol{\overline{A}} \boldsymbol{\overline{B}} u_0 + \boldsymbol{\overline{B}} u_1 & x_2 &= \boldsymbol{\overline{A}}^2 \boldsymbol{\overline{B}} u_0 + \boldsymbol{\overline{A}} \boldsymbol{\overline{B}} u_1 + \boldsymbol{\overline{B}} u_2 & \dots \\ y_0 &= \boldsymbol{\overline{C}} \boldsymbol{\overline{B}} u_0 & y_1 &= \boldsymbol{\overline{C}} \boldsymbol{\overline{A}} \boldsymbol{\overline{B}} u_0 + \boldsymbol{\overline{C}} \boldsymbol{\overline{B}} u_1 & y_2 &= \boldsymbol{\overline{C}} \boldsymbol{\overline{A}}^2 \boldsymbol{\overline{B}} u_0 + \boldsymbol{\overline{C}} \boldsymbol{\overline{A}} \boldsymbol{\overline{B}} u_1 + \boldsymbol{\overline{C}} \boldsymbol{\overline{B}} u_2 & \dots \end{aligned}$$

This can be vectorized into a convolution with an explicit formula for the convolution kernel.

$$\begin{aligned} y_k &= \boldsymbol{\overline{C}} \boldsymbol{\overline{A}}^k \boldsymbol{\overline{B}} u_0 + \boldsymbol{\overline{C}} \boldsymbol{\overline{A}}^{k-1} \boldsymbol{\overline{B}} u_1 + \dots + \boldsymbol{\overline{C}} \boldsymbol{\overline{A}} \boldsymbol{\overline{B}} u_{k-1} + \boldsymbol{\overline{C}}\boldsymbol{\overline{B}} u_k \\ y &= \boldsymbol{\overline{K}} \ast u \end{aligned}$$

$$\boldsymbol{\overline{K}} \in \mathbb{R}^L = (\boldsymbol{\overline{C}}\boldsymbol{\overline{B}}, \boldsymbol{\overline{C}}\boldsymbol{\overline{A}}\boldsymbol{\overline{B}}, \dots, \boldsymbol{\overline{C}}\boldsymbol{\overline{A}}^{L-1}\boldsymbol{\overline{B}})$$

We call $\boldsymbol{\overline{K}}$ the SSM convolution kernel or filter.

Note that this is a giant filter. It is the size of the entire sequence!

def K_conv(Ab, Bb, Cb, L):
    return np.array(
        [(Cb @ matrix_power(Ab, l) @ Bb).reshape() for l in range(L)]
    )

Warning: this implementation is naive and unstable. In practice it will fail to work for more than very small lengths. However, we are going to replace it with S4 in Part 2, so for now we just keep it around as a placeholder.

We can compute the result of applying this filter either with a standard direct convolution or with a padded (non-circular) Fast Fourier Transform (FFT). As the length gets longer the second method will be more efficient,

def non_circular_convolution(u, K, nofft=False):
    if nofft:
        return convolve(u, K, mode="full")[: u.shape[0]]
    else:
        assert K.shape[0] == u.shape[0]
        ud = np.fft.rfft(np.pad(u, (0, K.shape[0])))
        Kd = np.fft.rfft(np.pad(K, (0, u.shape[0])))
        out = ud * Kd
        return np.fft.irfft(out)[: u.shape[0]]

The CNN method and the RNN method yield (roughly) the same result,

def test_cnn_is_rnn(N=4, L=16, step=1.0 / 16):
    ssm = random_SSM(rng, N)
    u = jax.random.uniform(rng, (L,))

    # "RNN"
    rec = run_SSM(*ssm, u)

    # "CNN"
    ssmb = discretize(*ssm, step=step)
    conv = non_circular_convolution(u, K_conv(*ssmb, L))

    # Check
    assert np.isclose(rec.ravel(), conv.ravel(), rtol=1e-2, atol=1e-4).all()

At this point we have all of the machinery used for SSM training. The next steps are about 1) making these models stable to train, and 2) making them fast.

Addressing Long-Range Dependencies with HiPPO

Prior work found that the basic SSM actually performs very poorly in practice. Intuitively, one explanation is that they suffer from gradients scaling exponentially in the sequence length (i.e., the vanishing/exploding gradients problem). To address this problem, previous work developed the HiPPO theory of continuous-time memorization.

HiPPO specifies a class of matrices $\boldsymbol{A} \in \mathbb{R}^{N \times N}$ that, when incorporated, allow the state $x(t)$ to memorize the history of the input $u(t)$. The most important member of this class is the HiPPO matrix.

$$(\text{HiPPO Matrix}) \qquad \boldsymbol{A}_{nk} = \begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}$$

Previous work found that simply modifying an SSM from a random matrix $\boldsymbol{A}$ to HiPPO improved its performance on the sequential MNIST classification benchmark from 50% to 98%.

This matrix is going to be really important, but it is a bit of magic. For our purposes we mainly need to know that: 1) we only need to calculate it once, and 2) it has a nice, simple structure (which we will exploit in Part 2). Without going into the ODE math, the main takeaway is that this matrix aims to remember the past history in the state in a timescale-invariant manner,

def make_HiPPO(N):
    def v(n, k):
        if n > k:
            return np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
        elif n == k:
            return n + 1
        else:
            return 0

    # Do it slow so we don't mess it up :)
    mat = [[v(n, k) for k in range(1, N + 1)] for n in range(1, N + 1)]
    return np.array(mat)
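As a quick illustration (this snippet is ours, not from the original post), printing the $N=4$ case makes the structure visible: the matrix is lower triangular with $n+1$ on the diagonal,

print(make_HiPPO(4))
# Approximately (values rounded):
# [[2.    0.    0.    0.  ]
#  [3.873 3.    0.    0.  ]
#  [4.583 5.916 4.    0.  ]
#  [5.196 6.708 7.937 5.  ]]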

Diving a bit deeper, the intuitive explanation of this matrix is that it produces a hidden state that memorizes its history. It does this by keeping track of the coefficients of a Legendre polynomial. These coefficients let it approximate all of the previous history. Let us look at an example,

def example_legendre(N=8):
    # Random hidden state as coefficients
    import numpy as np
    import numpy.polynomial.legendre

    x = (np.random.rand(N) - 0.5) * 2
    t = np.linspace(-1, 1, 100)
    f = numpy.polynomial.legendre.Legendre(x)(t)

    # Plot
    import matplotlib.pyplot as plt
    import seaborn

    seaborn.set_context("talk")
    fig = plt.figure(figsize=(20, 10))
    ax = fig.gca(projection="3d")
    ax.plot(
        np.linspace(-25, (N - 1) * 100 + 25, 100),
        [0] * 100,
        zs=-1,
        zdir="x",
        color="black",
    )
    ax.plot(t, f, zs=N * 100, zdir="y", c="r")
    for i in range(N):
        coef = [0] * N
        coef[N - i - 1] = 1
        ax.set_zlim(-4, 4)
        ax.set_yticks([])
        ax.set_zticks([])
        # Plot basis function.
        f = numpy.polynomial.legendre.Legendre(coef)(t)
        ax.bar(
            [100 * i],
            [x[i]],
            zs=-1,
            zdir="x",
            label="x%d" % i,
            color="brown",
            fill=False,
            width=50,
        )
        ax.plot(t, f, zs=100 * i, zdir="y", c="b", alpha=0.5)
    ax.view_init(elev=40.0, azim=-45)
    fig.savefig("images/leg.png")

The red line represents the curve we are approximating, while the black bars represent the values of our hidden state. Each is a coefficient for one element of the Legendre series, shown as blue functions. The intuition is that the HiPPO matrix updates these coefficients at each step.

An SSM Neural Network.

We now have everything we need to build an SSM neural network layer. As defined above, the discrete SSM defines a map from $\mathbb{R}^L \to \mathbb{R}^L$, i.e. a 1-D sequence map. We assume that we are going to be learning the parameters $\boldsymbol{B}$ and $\boldsymbol{C}$, as well as a step size $\Delta$ and a scalar $\boldsymbol{D}$ parameter. The HiPPO matrix is used for the transition $\boldsymbol{A}$. We learn the step size in log space.

def log_step_initializer(dt_min=0.001, dt_max=0.1):
    def init(key, shape):
        return jax.random.uniform(key, shape) * (
            np.log(dt_max) - np.log(dt_min)
        ) + np.log(dt_min)

    return init

For the SSM layer, most of the work is in building the filter. The actual call to the network is just the (huge) convolution we specified above.

Note for Torch users: setup in Flax is called each time the parameters are updated. This is similar to the Torch parameterizations.

class SSMLayer(nn.Module):
    A: np.DeviceArray  # HiPPO
    N: int
    l_max: int

    def setup(self):
        # SSM parameters
        self.B = self.param("B", lecun_normal(), (self.N, 1))
        self.C = self.param("C", lecun_normal(), (1, self.N))
        self.D = self.param("D", nn.initializers.ones, (1,))

        # Step parameter
        self.log_step = self.param("log_step", log_step_initializer(), (1,))

        step = np.exp(self.log_step)
        ssm = discretize(self.A, self.B, self.C, step=step)
        self.K = K_conv(*ssm, self.l_max)

    def __call__(self, u):
        return non_circular_convolution(u, self.K) + self.D * u

Since our SSMs operate on scalars, we make $H$ different, stacked copies ($H$ different SSMs!) with different parameters. Here we use the Flax vmap method to easily define these copies,

def cloneLayer(layer):
    return nn.vmap(
        layer,
        in_axes=1,
        out_axes=1,
        variable_axes={"params": 1},
        split_rngs={"params": True},
    )

We then initialize $\boldsymbol{A}$ with the HiPPO matrix, and pass it into the stack of modules above,

def SSMInit(N):
    return partial(cloneLayer(SSMLayer), A=make_HiPPO(N), N=N)

This SSM layer can then be put into a standard NN. For instance, here we have a Transformer-style stack of residual blocks, each containing the $H$ stacked SSMs.

class SeqInternal(nn.Module):
    layer: nn.Module
    l_max: int
    dropout: float
    d_model: int
    training: bool = True

    def setup(self):
        self.seq = self.layer(l_max=self.l_max)
        self.norm = nn.LayerNorm()
        self.out = nn.Dense(self.d_model)
        self.drop = nn.Dropout(
            self.dropout,
            broadcast_dims=[0],
            deterministic=not self.training,
        )

    def __call__(self, x, blank):
        x2 = self.seq(x)
        z = self.drop(self.out(self.drop(nn.gelu(x2))))
        return self.norm(z + x)
class SeqModel(nn.Module):
    layer: nn.Module
    d_output: int
    d_model: int
    l_max: int
    n_layers: int
    dropout: float = 0.2
    training: bool = True
    classification: bool = False

    def setup(self):
        self.encoder = nn.Dense(self.d_model)
        self.decoder = nn.Dense(self.d_output)
        self.layers = [
            SeqInternal(
                layer=self.layer,
                d_model=self.d_model,
                dropout=self.dropout,
                training=self.training,
                l_max=self.l_max,
            )
            for _ in range(self.n_layers)
        ]

    def __call__(self, x):
        x = self.encoder(x)
        for layer in self.layers:
            x = layer(x, None)
        if self.classification:
            x = np.mean(x, axis=0)
        x = self.decoder(x)
        return nn.log_softmax(x, axis=-1)
BatchSeqModel = nn.vmap(
    SeqModel,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None, "dropout": None},
    split_rngs={"params": False, "dropout": True},
)

Overall, this defines a sequence-to-sequence map of shape (batch size, sequence length, hidden dimension), exactly the signature exposed by related sequence models such as Transformers, RNNs, and CNNs.
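As a sanity check on that signature, here is a minimal usage sketch (our addition, with illustrative sizes rather than the hyperparameters used in training.py): initialize the batched model on a dummy input and confirm the output shape,

def check_shapes(N=8, H=4, L=16, batch=2, n_classes=10):
    model = BatchSeqModel(
        layer=SSMInit(N), d_output=n_classes, d_model=H, l_max=L, n_layers=2
    )
    u = np.zeros((batch, L, 1))
    init_rng, drop_rng = jax.random.split(rng)
    # Dropout is active by default, so both rng streams are needed.
    params = model.init({"params": init_rng, "dropout": drop_rng}, u)
    out = model.apply(params, u, rngs={"dropout": drop_rng})
    assert out.shape == (batch, L, n_classes)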

Full code for training is defined in training.py. While we now have our main model, it is not fast enough to actually use. The next section is all about making this SSM Layer faster – a lot faster!

Part 2: Implementing S4

Warning: this section has a lot of math. Roughly it boils down to finding a way to compute the filter from Part 1 with a “HiPPO-like” matrix really fast. If you are interested, the details are really neat. If not, skip to Part 3 for some cool applications like MNIST completion.


The fundamental bottleneck in computing the discrete-time SSM is that it involves repeated matrix multiplication by $\boldsymbol{\overline{A}}$. For example, computing the convolution kernel $\boldsymbol{\overline{K}}$ naively involves $L$ successive multiplications by $\boldsymbol{\overline{A}}$, requiring $O(N^2 L)$ operations and $O(NL)$ space.

Specifically, recall this function here:

def K_conv_(Ab, Bb, Cb, L):
    return np.array(
        [(Cb @ matrix_power(Ab, l) @ Bb).reshape() for l in range(L)]
    )

The contribution of S4 is a stable method for speeding up this particular operation. To do this we are going to focus on the case where the SSM has special structure. Specifically, Diagonal Plus Low-Rank (DPLR) in complex space.

DPLR: The SSM is $(\boldsymbol{\Lambda} - \boldsymbol{p}\boldsymbol{q}^*, \boldsymbol{B}, \boldsymbol{C})$ for some diagonal $\boldsymbol{\Lambda}$ and vectors $\boldsymbol{p}, \boldsymbol{q}, \boldsymbol{B}, \boldsymbol{C} \in \mathbb{C}^{N \times 1}$.

Under this DPLR assumption, S4 overcomes the speed bottleneck in three steps:

  1. Instead of computing $\boldsymbol{\overline{K}}$ directly, we compute its spectrum by evaluating its truncated generating function. This now involves a matrix inverse instead of a matrix power.
  2. We show that the diagonal matrix case is equivalent to the computation of a Cauchy kernel $\frac{1}{\omega_j - \zeta_k}$.
  3. We show that the low-rank term can then be corrected by applying the Woodbury identity, which expresses $(\boldsymbol{\Lambda} + \boldsymbol{p}\boldsymbol{q}^*)^{-1}$ in terms of $\boldsymbol{\Lambda}^{-1}$, truly reducing to the diagonal case.

Step 1: SSM Generating Functions

The main step will be switching from computing the sequence to computing its generating function. From the paper’s appendix:

To address the problem of computing powers of $\boldsymbol{\overline{A}}$, we introduce another technique. Instead of computing the SSM convolution filter $\boldsymbol{\overline{K}}$ directly, we introduce a generating function on its coefficients and compute evaluations of it.

The truncated SSM generating function at node $z$ with truncation $L$ is

$$\hat{\mathcal{K}}_L(z; \boldsymbol{\overline{A}}, \boldsymbol{\overline{B}}, \boldsymbol{\overline{C}}) \in \mathbb{C} := \sum_{i=0}^{L-1} \boldsymbol{\overline{C}} \boldsymbol{\overline{A}}^i \boldsymbol{\overline{B}} z^i$$

def K_gen_simple(Ab, Bb, Cb, L):
    K = K_conv(Ab, Bb, Cb, L)

    def gen(z):
        return np.sum(K * (z ** np.arange(L)))

    return gen

The generating function essentially converts the SSM convolution filter from the time domain to the frequency domain. Importantly, it preserves the same information: the desired SSM convolution filter can be recovered stably in $O(L \log L)$ operations by evaluating the generating function at the roots of unity $\Omega = \{ \exp(2\pi i \tfrac{k}{L}) : k \in [L] \}$ and applying an inverse FFT,

def conv_from_gen(gen, L):
    # Evaluate at roots of unity
    Omega_L = np.exp((2j * np.pi / L) * np.arange(L))
    atRoots = jax.vmap(gen)(Omega_L)
    # Inverse FFT
    out = np.fft.ifft(atRoots, L).reshape(L)
    # Numpy returns the values out of order.
    order = np.array([i if i == 0 else L - i for i in range(L)])
    return out[order].real

More importantly, in the generating function we can replace the matrix power with an inverse!

$$\hat{\mathcal{K}}_L(z) = \sum_{i=0}^{L-1} \boldsymbol{\overline{C}} \boldsymbol{\overline{A}}^i \boldsymbol{\overline{B}} z^i = \boldsymbol{\overline{C}} (\boldsymbol{I} - \boldsymbol{\overline{A}}^L z^L) (\boldsymbol{I} - \boldsymbol{\overline{A}} z)^{-1} \boldsymbol{\overline{B}} = \boldsymbol{\tilde{C}} (\boldsymbol{I} - \boldsymbol{\overline{A}} z)^{-1} \boldsymbol{\overline{B}}$$

And for all $z \in \Omega_L$, we have $z^L = 1$, so that term is removed. We then pull this constant term into a new $\boldsymbol{\tilde{C}}$. Critically, this function does not call K_conv,

def K_gen_inverse(Ab, Bb, Cb, L):
    I = np.eye(Ab.shape[0])
    Ab_L = matrix_power(Ab, L)
    Ct = Cb @ (I - Ab_L)
    return lambda z: (Ct @ inv(I - Ab * z) @ Bb).reshape()

But it does output the same values,

def test_gen_inverse(L=16, N=4):
    ssm = random_SSM(rng, N)
    b = K_conv(*ssm, L=L)

    a = conv_from_gen(K_gen_inverse(*ssm, L=L), L)
    assert np.isclose(a, b, rtol=1e-2, atol=1e-4).all()

In summary, Step 1 allows us to replace the matrix power with an inverse by utilizing a truncated generating function. However this inverse still needs to be calculated $L$ times (for each of the roots of unity).

Step 2: Diagonal Case

The next step is to assume special structure on the matrix $\boldsymbol{A}$ to avoid the inverse. To begin, let us first convert the equation above to use the original SSM matrices. With some algebra you can expand the discretization and show:

$$\boldsymbol{\tilde{C}}\left(\boldsymbol{I} - \boldsymbol{\overline{A}} z \right)^{-1} \boldsymbol{\overline{B}} = \frac{2\Delta}{1+z} \boldsymbol{\tilde{C}} \left[ \frac{2(1-z)}{1+z} \boldsymbol{I} - \Delta \boldsymbol{A} \right]^{-1} \boldsymbol{B}$$

Now imagine $\boldsymbol{A} = \boldsymbol{\Lambda}$ for a diagonal $\boldsymbol{\Lambda}$. Substituting in the discretization formula, the authors show that the generating function can be written in the following manner:

$$\boldsymbol{\hat{K}}_{\boldsymbol{\Lambda}}(z) = c(z) \sum_i \frac{\tilde{C}_i B_i}{g(z) - \Lambda_i} = c(z) \cdot k_{z, \boldsymbol{\Lambda}}(\boldsymbol{\tilde{C}}, \boldsymbol{B})$$

where $c$ is a constant, and $g$ is a function of $z$.

We have effectively replaced an inverse with a weighted dot product. Let's write a small helper function to compute this weighted dot product. Here vectorize is a decorator that lets us broadcast this function automatically,

@partial(np.vectorize, signature="(c),(),(c)->()")
def cauchy_dot(v, omega, lambd):
    return (v / (omega - lambd)).sum()

While not important for our implementation, it is worth noting that this is a Cauchy kernel and is the subject of many fast implementations. On a GPU though, it is efficient enough just to compute it directly.
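As a quick (and entirely optional) sanity check of our own, we can verify that for a diagonal matrix this weighted dot product really does agree with the explicit inverse from Step 1,

def test_cauchy_is_diag_inverse(N=4):
    # Our own check: for diagonal Lambda, Ct (g I - Lambda)^{-1} B is exactly
    # the weighted dot product computed by cauchy_dot.
    l_r, b_r, c_r = jax.random.split(rng, 3)
    Lambda = jax.random.uniform(l_r, (N,))
    B = jax.random.uniform(b_r, (N,))
    Ct = jax.random.uniform(c_r, (N,))
    g = 2.0 + 1.0j  # an arbitrary evaluation point
    direct = Ct @ inv(g * np.eye(N) - np.diag(Lambda)) @ B
    assert np.isclose(direct, cauchy_dot(Ct * B, g, Lambda), rtol=1e-4, atol=1e-4)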

Step 3: Diagonal Plus Low-Rank

The final step is to relax the diagonal assumption. In addition to the diagonal term we allow a low-rank component with $\boldsymbol{p}, \boldsymbol{q} \in \mathbb{C}^{N\times 1}$ such that:

$$\boldsymbol{A} = \boldsymbol{\Lambda} + \boldsymbol{p} \boldsymbol{q}^*$$

The Woodbury identity tells us that the inverse of a diagonal plus rank-1 term is equal to the inverse of the diagonal plus a rank-1 term. Or in math:

$$(\boldsymbol{\Lambda} + \boldsymbol{p} \boldsymbol{q}^*)^{-1} = \boldsymbol{\Lambda}^{-1} - \boldsymbol{\Lambda}^{-1} \boldsymbol{p} \left(1 + \boldsymbol{q}^* \boldsymbol{\Lambda}^{-1} \boldsymbol{p}\right)^{-1} \boldsymbol{q}^* \boldsymbol{\Lambda}^{-1}$$
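Here is a quick numeric check of that identity (ours, for the real rank-1 case) against a direct inverse,

def check_woodbury(N=4):
    # Our own check of the (real, rank-1) Woodbury identity above.
    l_r, p_r, q_r = jax.random.split(rng, 3)
    Lambda = 1.0 + jax.random.uniform(l_r, (N,))  # keep the diagonal away from 0
    p = jax.random.uniform(p_r, (N, 1))
    q = jax.random.uniform(q_r, (N, 1))
    direct = inv(np.diag(Lambda) + p @ q.T)
    Li = np.diag(1.0 / Lambda)
    woodbury = Li - Li @ p @ inv(1.0 + q.T @ Li @ p) @ q.T @ Li
    assert np.allclose(direct, woodbury, atol=1e-4)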

There is a bunch of algebra not shown, but it mostly consists of substituting this component in for $\boldsymbol{A}$, applying the Woodbury identity, and distributing terms. We end up with four terms that all look like Step 2 above:

$$\boldsymbol{\hat{K}}_{DPLR}(z) = c(z) \left[ k_{z, \Lambda}(\boldsymbol{\tilde{C}}, \boldsymbol{B}) - k_{z, \Lambda}(\boldsymbol{\tilde{C}}, \boldsymbol{p}) \left(1 + k_{z, \Lambda}(\boldsymbol{q}^*, \boldsymbol{p}) \right)^{-1} k_{z, \Lambda}(\boldsymbol{q}^*, \boldsymbol{B}) \right]$$

The code consists of collecting up the terms and applying 4 weighted dot products,

def K_gen_DPLR(Lambda, p, q, B, Ct, step):
    aterm = (Ct.conj().ravel(), q.conj().ravel())
    bterm = (B.ravel(), p.ravel())

    def gen(o):
        g = (2.0 / step) * ((1.0 - o) / (1.0 + o))
        c = 2.0 / (1.0 + o)

        def k(a):
            return cauchy_dot(a, g, Lambda)

        k00 = k(aterm[0] * bterm[0])
        k01 = k(aterm[0] * bterm[1])
        k10 = k(aterm[1] * bterm[0])
        k11 = k(aterm[1] * bterm[1])
        return c * (k00 - k01 * (1.0 / (1.0 + k11)) * k10)

    return gen

This is our final version of the $\boldsymbol{K}$ function. Now we can check whether it worked. First, let's generate a random Diagonal Plus Low-Rank (DPLR) matrix,

def random_DPLR(rng, N):
    l_r, p_r, q_r, b_r, c_r = jax.random.split(rng, 5)
    Lambda = jax.random.uniform(l_r, (N,))
    p = jax.random.uniform(p_r, (N,))
    q = jax.random.uniform(q_r, (N,))
    B = jax.random.uniform(b_r, (N, 1))
    C = jax.random.uniform(c_r, (1, N))
    return Lambda, p, q, B, C

We can check that the DPLR method yields the same filter as computing A\boldsymbol{A} directly,

def test_gen_dplr(L=16, N=4):
    I = np.eye(N)

    # Create a DPLR A matrix and discretize
    Lambda, p, q, B, C = random_DPLR(rng, N)
    A = np.diag(Lambda) - p[:, np.newaxis] * q[np.newaxis, :]
    Ab, Bb, Cb = discretize(A, B, C, 1.0 / L)
    a = K_conv(Ab, Bb, Cb, L=L)

    # Compare to the DPLR generating function approach.
    Ct = (I - matrix_power(Ab, L)).conj().T @ Cb.ravel()
    b = conv_from_gen(K_gen_DPLR(Lambda, p, q, B, Ct, step=1.0 / L), L)
    assert np.isclose(a, b, rtol=1e-2, atol=1e-4).all()

Turning HiPPO to DPLR

This approach applies to DPLR matrices, but remember we would like it to also apply to the HiPPO matrix. While not DPLR in its current form, the HiPPO matrix does have special structure. It is Normal Plus Low-Rank (NPLR). The paper argues that this is just as good as DPLR for the purposes of learning an SSM network.

The S4 techniques can apply to any matrix $\boldsymbol{A}$ that can be decomposed as Normal Plus Low-Rank (NPLR):

$$\boldsymbol{A} = \boldsymbol{V} \boldsymbol{\Lambda} \boldsymbol{V}^* - \boldsymbol{p} \boldsymbol{q}^\top = \boldsymbol{V} \left( \boldsymbol{\Lambda} - \boldsymbol{V}^* \boldsymbol{p} (\boldsymbol{V}^*\boldsymbol{q})^* \right) \boldsymbol{V}^*$$

for unitary $\boldsymbol{V} \in \mathbb{C}^{N \times N}$, diagonal $\boldsymbol{\Lambda}$, and low-rank factorization $\boldsymbol{p}, \boldsymbol{q} \in \mathbb{R}^{N \times r}$. An NPLR SSM is therefore unitarily equivalent to some DPLR matrix.

For S4, we need to work with a HiPPO matrix for A\boldsymbol{A}. This requires extracting Λ\boldsymbol{\Lambda} from this decomposition. The appendix of the paper shows this by getting it into a skew-symmetric (normal) + low-rank form. We can use this math to get out the DPLR terms,

def make_NPLR_HiPPO(N):
    # Make -HiPPO
    nhippo = -make_HiPPO(N)

    # Add in a rank 1 term. Makes it Normal.
    p = 0.5 * np.sqrt(2 * np.arange(1, N + 1) + 1.0)
    q = 2 * p
    S = nhippo + p[:, np.newaxis] * q[np.newaxis, :]

    # Diagonalize S to V \Lambda V^*
    Lambda, V = jax.jit(eig, backend="cpu")(S)
    return nhippo, Lambda, p, q, V

Final sanity check just to make sure those identities hold,

def test_nplr(N=8):
    A2, Lambda, p, q, V = make_NPLR_HiPPO(N)
    p, q = p[:, np.newaxis], q[:, np.newaxis]
    Lambda = np.diag(Lambda)
    Vc = V.conj().T
    A3 = V @ (Lambda - (Vc @ p) @ (Vc @ q.conj()).conj().T) @ Vc
    A4 = V @ Lambda @ Vc - (p @ q.T)
    assert np.allclose(A2, A3, atol=1e-2, rtol=1e-2)
    assert np.allclose(A2, A4, atol=1e-2, rtol=1e-2)

Part 3: S4 in Practice

That was a lot of work, but now the actual model is concise. In fact we are only using four functions:

  1. discretize → Convert SSM to discrete form.
  2. K_gen_DPLR → Truncated generating function when $\boldsymbol{A}$ is DPLR (the S4 part)
  3. conv_from_gen → Convert generating function to filter
  4. non_circular_convolution → Run convolution

A full S4 layer is very similar to the simple SSM layer above. The only difference is in the computation of $\boldsymbol{K}$. Additionally, instead of learning $\boldsymbol{C}$, we learn $\boldsymbol{\tilde{C}}$, so we avoid computing powers of $\boldsymbol{A}$. Note as well that in the original paper $\boldsymbol{\Lambda}, \boldsymbol{p}, \boldsymbol{q}$ are also learned. However, in this post, we leave them fixed for simplicity.

class S4Layer(nn.Module):
    A: np.DeviceArray
    p: np.DeviceArray
    q: np.DeviceArray
    Lambda: np.DeviceArray

    N: int
    l_max: int

    def setup(self):
        self.B = self.param("B", lecun_normal(), (self.N, 1))
        self.D = self.param("D", nn.initializers.ones, (1,))
        self.Ct = self.param(
            "Ct", lecun_normal(dtype=jax.numpy.complex64), (1, self.N)
        )
        self.log_step = self.param("log_step", log_step_initializer(), (1,))
        step = np.exp(self.log_step)

        K_gen = K_gen_DPLR(
            self.Lambda, self.p, self.q, self.B, self.Ct, step[0]
        )
        self.K = conv_from_gen(K_gen, self.l_max)

    def __call__(self, u):
        return non_circular_convolution(u, self.K) + self.D * u
S4Layer = cloneLayer(S4Layer)

We initialize the model by computing a DPLR initializer similar to HiPPO,

def S4LayerInit(N):
    _, Lambda, p, q, V = make_NPLR_HiPPO(N)
    Vc = V.conj().T
    p = Vc @ p
    q = Vc @ q.conj()
    A = np.diag(Lambda) - p[:, np.newaxis] @ q[:, np.newaxis].conj().T
    return partial(S4Layer, N=N, A=A, p=p, q=q, Lambda=Lambda)

Experiments

Now that we have the model, we can try it out on some MNIST experiments. For these experiments we linearize MNIST and just treat each image as a sequence of pixels.

The first experiments we ran were on MNIST classification. While not in theory a hard problem, treating MNIST as a linear sequence classification task is a bit strange. However, in practice the model with $H=256$ and four layers seems to get up near 99% right away.
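For reference, a classifier along those lines might be configured as below (a sketch of ours; the sizes are illustrative and the exact training hyperparameters live in training.py),

def mnist_classifier(N=64, d_model=256, n_layers=4, n_classes=10):
    # Flattened MNIST: each 28 x 28 image is a length-784 pixel sequence, and
    # classification=True averages over the sequence before the final Dense.
    return BatchSeqModel(
        layer=S4LayerInit(N=N),
        d_output=n_classes,
        d_model=d_model,
        l_max=784,
        n_layers=n_layers,
        classification=True,
    )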

A more visually interesting task is generating MNIST digits, by predicting entire sequences of pixels! Here, we simply feed a sequence of pixels into the model and have it predict the next one, like language modeling. With a little tweaking, we are able to get the model to an NLL of 0.52 on this task with size 512 and 6 layers (~2m parameters).

The metric usually used for this task is bits per dimension which is NLL in base 2 for MNIST. A score of 0.52 is ~0.76 BPD which is near PixelCNN++.

We can sample from the model using the CNN implementation. Ideally we would use the RNN form, but that would require a bit more plumbing,

def sample_mnist():
    import matplotlib.pyplot as plt
    from flax.training import checkpoints

    model = S4LayerInit(N=64)
    model = partial(
        BatchSeqModel,
        layer=model,
        d_output=256,
        d_model=256,
        n_layers=6,
        l_max=783,
    )
    rng = jax.random.PRNGKey(0)
    state = checkpoints.restore_checkpoint("models/best_84", None)
    model = model(training=False)
    start = np.zeros((1, 784, 1))

    def loop(i, cur):
        cur, rng = cur
        r, rng = jax.random.split(rng)
        out = model.apply({"params": state["params"]}, cur[:, :-1])
        p = jax.random.categorical(rng, out[0, i])
        cur = jax.ops.index_update(cur, (0, i + 1, 0), p)
        return cur, rng

    out = jax.lax.fori_loop(0, 783, jax.jit(loop), (start, rng))[0]
    plt.imshow(out.reshape(28, 28))
    plt.savefig("sample.png")
sample_mnist()

We can also do prefix-samples – given the first 300 pixels, try to complete the image. S4 is on the left, true on the right.

Next we tried training a model to generate drawings. For this we used the QuickDraw dataset. The dataset includes a version downsampled to MNIST size, so we can use roughly the same model as above. The dataset is much larger though (5M images) and more complex. We only trained for 1 epoch with an $H=256$, 4-layer model. Still, the approach was able to generate relatively coherent completions. These are prefix samples with 500 pixels given.

Our full code base contains more examples and infrastructure for training models for generation and classification.

Conclusion

Putting together this post inspired lots of thoughts about future work in this area. One obvious conclusion is that long-range models have all sorts of future applications from acoustic modeling to genomic sequences to trajectories (not to mention our shared area of NLP). Another is some surprise that linear models can be so effective here, while also opening up a range of efficient techniques. Finally from a practical level, the transformations in JAX make it really nice to implement complex models like this in a very concise way (~200 LoC), with similar efficiency and performance!

We end by thanking the authors Albert Gu and Karan Goel, who were super helpful in putting this together, and pointing you again to their paper and codebase. We’re also grateful to Conner Vercellino and Laurel Orr for providing helpful feedback on this post.

/ Cheers – Sasha & Sidd