Exploring multi-input einsums in JAX

The JAX documentation says:

einsum is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays.

In this blog post, we are going to explore the einsum API and use it to solve marginal-inference-type problems.

Basic Examples

Suppose we have two $ n \times n $ arrays $A$ and $B$ and we want to compute

$$ Y[i, k] = \sum_{j=1}^n A[i, j] B[j, k] $$

This is the basic definition of matrix multiplication, and we can compute it in JAX via jnp.einsum('ij,jk->ik', A, B).  Here the formula 'ij,jk->ik' defines the computation.  Each letter names an axis of an array; since we only have 2D arrays here, each part of the formula contains two axes.  The letters on the right-hand side of the expression determine which axes appear in the output, and any axis name not appearing on the right-hand side is "marginalized out" via the sum operator.  This is a very generic API: by modifying the formula we can easily express a wide variety of interesting computations.
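As a quick illustration (the shapes and seeds below are arbitrary), the formula above can be checked against the @ operator, and dropping an axis from the right-hand side sums it out:

import jax.numpy as jnp
from jax import random

A = random.normal(random.PRNGKey(0), (4, 5))
B = random.normal(random.PRNGKey(1), (5, 3))

# 'ij,jk->ik' is ordinary matrix multiplication
Y = jnp.einsum('ij,jk->ik', A, B)
print(jnp.allclose(Y, A @ B))                                # True

# Axes missing from the RHS are summed out: 'ij->i' gives row sums
print(jnp.allclose(jnp.einsum('ij->i', A), A.sum(axis=1)))   # True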

A simple one that was of interest to me recently is the formula 'ij,ik,jk->ij', which defines the computation:

$$ Y[i, j] = \sum_{k} A[i,j] B[i,k] C[j,k] $$

These types of expressions tend to appear when doing marginal inference in discrete graphical models, and I recently became interested in exploring this space to improve the efficiency of an open-source repository I maintain.  One way to do this computation is to construct a 3D array $Z$ defined by $Z[i,j,k] = A[i,j] B[i,k] C[j,k]$ and then marginalize out $k$ by computing $Y[i,j] = \sum_{k} Z[i,j,k]$.  When $Z$ is not too large this approach is fine, and on CPUs it doesn't matter much whether $Z$ is materialized or not.  On GPUs, however, the difference can matter substantially.  For example, if $n=4096$ then storing $Z$ would require $256$ GB of RAM, assuming each element is a 4-byte float.  This is infeasible on most machines, including GPUs.

However, jnp.einsum does not construct the large intermediate in this case: on an NVIDIA A100 GPU this computation takes only $0.02$ seconds.  This is because the computation reduces to a matrix multiplication and an element-wise multiplication, both of which are very efficient on GPUs.  Specifically, we can write the expression as $ Y[i,j] = A[i, j] \sum_{k} B[i, k] C[j, k] $, which we can implement in JAX as A * (B @ C.T).  This is more or less what jnp.einsum does internally in this case.  By expressing the einsum this way, we enjoy all the benefits of vectorization and highly optimized matrix multiplication kernels, together with the low memory footprint of the naive approach (nested for loops).
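To make the equivalence concrete, here is a minimal sketch (shapes and seeds chosen arbitrarily) comparing the three-input einsum against the manual matmul decomposition:

import jax.numpy as jnp
from jax import random

n = 512
k0, k1, k2 = random.split(random.PRNGKey(0), 3)
A = random.normal(k0, (n, n))
B = random.normal(k1, (n, n))
C = random.normal(k2, (n, n))

# Y[i, j] = sum_k A[i, j] * B[i, k] * C[j, k]
Y1 = jnp.einsum('ij,ik,jk->ij', A, B, C)

# Same result via an element-wise product and a matmul, with no (n, n, n) intermediate
Y2 = A * (B @ C.T)

print(jnp.allclose(Y1, Y2, rtol=1e-3, atol=1e-3))  # True (up to float32 rounding)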

A harder example

Let's generalize the example above by adding an additional axis and additional arrays defined over the new axis.  Specifically, let's consider the einsum formula 'ij,ik,il,jk,jl,kl->ij', which defines the computation:

$$ Y[i, j] = \sum_{k, l} A[i, j] B[i, k] C[i, l] D[j, k] E[j, l] F[k, l] $$

When done naively, this summation requires $O(n^2)$ space and $O(n^4)$ time.  However, when we plug this formula into JAX and give it $6$ arrays of size $ 256 \times 256 $, we get back an error:

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 17196646400 bytes.

This error indicates JAX is trying to materialize an object of roughly 16 GB, which is almost exactly $n^4$ 4-byte floats.  This is different from the 3D case, where JAX did not attempt to materialize the full intermediate.  The problem can be partially alleviated by modifying the optimize keyword argument of jnp.einsum.  In this case it brings the memory consumption down to $O(n^3)$, as demonstrated by the fact that the computation works up to $n=512$ but fails at $n=1024$ with the error:

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4311744512 bytes.

This amount of memory corresponds almost exactly to $1024^3$ 4-byte floats, i.e., an $O(n^3)$ intermediate.
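For reference, here is how the optimize argument is passed.  It accepts opt_einsum-style settings such as 'greedy', 'optimal', or an explicit contraction path; the particular value shown below is just for illustration, not necessarily the setting behind the numbers above:

import jax.numpy as jnp
from jax import random

n = 512
keys = random.split(random.PRNGKey(0), 6)
A, B, C, D, E, F = (random.normal(k, (n, n)) for k in keys)

# The contraction order (and hence the size of the largest intermediate)
# is controlled by the `optimize` argument.
Y = jnp.einsum('ij,ik,il,jk,jl,kl->ij', A, B, C, D, E, F, optimize='greedy')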

Limitations of JAX's einsum implementation

It turns out that JAX implements einsum as a sequence of pairwise "contractions", each of which it carries out with the dot_general function.  This function takes exactly two inputs at a time, and while arbitrary einsums can be broken down in this way, it does not appear to be possible to get the memory-efficient behavior of the naive implementation described above as a sequence of pairwise contractions.  So where does that leave us?  Two options are given below.  For simplicity, the remainder of the blog post will specialize to the difficult benchmark formula 'ij,ik,il,jk,jl,kl->ij', but the key ideas generalize to any formula.
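One way to see this pairwise structure is to inspect the contraction path an opt_einsum-style optimizer would choose.  Below is a sketch using numpy's einsum_path (the shapes are placeholders); the printed report lists the sequence of smaller contractions and the size of the largest intermediate:

import numpy as np

n = 256
operands = [np.empty((n, n), dtype=np.float32) for _ in range(6)]

# Ask the optimizer for a contraction path without doing the computation.
path, report = np.einsum_path('ij,ik,il,jk,jl,kl->ij', *operands,
                              optimize='optimal')
print(report)  # shows the chosen contraction steps and the largest intermediate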

Naive einsum implementation via jax.vmap and jax.lax.scan

One option is to just write the custom for loop ourselves.  Doing it in pure Python is obviously a bad idea, so we use JAX's higher-order functions jax.vmap and jax.lax.scan instead.  The code below shows what this looks like, specialized to the einsum formula 'ij,ik,il,jk,jl,kl->ij' considered above.  Essentially, it uses vmap over the axes that appear in the output and scan over the axes that should be marginalized out.  Thus the computation is vectorized only across output axes; along the other axes it is sequential.
import jax
import jax.numpy as jnp


def inner_einsum(*arrays):
  # Computes the einsum ',k,l,k,l,kl->' (A is a scalar here, F is a matrix).
  # Does not create any large intermediate arrays.

  A, B, C, D, E, F = arrays
  K, L = B.size, C.size

  def loop_body_k(partial_k, k):
    def loop_body_l(partial_l, l):
      return partial_l + C[l] * E[l] * F[k, l], ()
    inner_sum = jax.lax.scan(loop_body_l, 0.0, jnp.arange(L))[0]
    return partial_k + B[k] * D[k] * inner_sum, ()

  return A * jax.lax.scan(loop_body_k, 0.0, jnp.arange(K))[0]


@jax.jit
def vmap_einsum(*arrays):
  # Computes the einsum 'ij,ik,il,jk,jl,kl->ij' naively:
  # no large intermediates, vectorized across output cells.

  return jax.vmap(
      # The in_axes specifies which arrays contain axis j (here) and axis i (below)
      jax.vmap(inner_einsum, in_axes=(0, None, None, 0, 0, None)),
      in_axes=(0, 0, 0, None, None, None)
  )(*arrays)
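Before benchmarking, a quick correctness check against jnp.einsum at a size where both fit comfortably in memory (a small sketch, assuming the definitions above are in scope):

import jax.numpy as jnp
from jax import random

n = 64
keys = random.split(random.PRNGKey(0), 6)
A, B, C, D, E, F = (random.normal(k, (n, n)) for k in keys)

expected = jnp.einsum('ij,ik,il,jk,jl,kl->ij', A, B, C, D, E, F)
actual = vmap_einsum(A, B, C, D, E, F)
print(jnp.allclose(expected, actual, rtol=1e-3, atol=1e-3))  # True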
The benchmark below shows how the performance of this approach compares to jnp.einsum.  It scales better, handling $n=1024$ without running out of memory, but it is significantly slower than jnp.einsum in the regimes where both run.

 
n       vmap_einsum (s)    jnp.einsum (s)
-       0.14               0.0022
256     0.76               0.0180
512     4.29               0.2500
1024    35.70              XlaRuntimeError

Combining jnp.einsum with jax.lax.scan for improved memory efficiency

One reason vmap_einsum is so much slower than jnp.einsum is that it does not use optimized XLA kernels for matrix multiplication.  In this section, we propose an approach that gives the best of both worlds: it leverages jnp.einsum to do the heavy lifting but resorts to sequential computation when necessary to scale.  The basic idea is to do a scan (i.e., a for loop) over one axis and compute a series of smaller einsums sequentially, adding them up as we go.  If we loop over axis k, for instance, then the smaller einsum formula is 'ij,i,il,j,jl,l->ij', applied to the arrays A, B[:, k], C, D[:, k], E, F[k, :].  The benefit of this approach is that jnp.einsum can handle the smaller formula efficiently with a much smaller memory overhead.  In JAX, the implementation looks like this:
@jax.jit
def scan_einsum(*arrays):
  # Scan over k and build up a running sum of smaller einsums.

  A, B, C, D, E, F = arrays
  K = B.shape[1]
  zeros = jnp.zeros(A.shape)

  def add_small_einsum(partial, k):
    # The original formula with axis k stripped out: 'ij,i,il,j,jl,l->ij'
    return partial + jnp.einsum('ij,i,il,j,jl,l->ij',
                                A, B[:, k], C, D[:, k], E, F[k, :]), ()

  return jax.lax.scan(add_small_einsum, zeros, jnp.arange(K))[0]
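As before, a quick sanity check against jnp.einsum at a small size (again assuming the definitions above are in scope):

import jax.numpy as jnp
from jax import random

n = 64
keys = random.split(random.PRNGKey(1), 6)
A, B, C, D, E, F = (random.normal(k, (n, n)) for k in keys)

expected = jnp.einsum('ij,ik,il,jk,jl,kl->ij', A, B, C, D, E, F)
actual = scan_einsum(A, B, C, D, E, F)
print(jnp.allclose(expected, actual, rtol=1e-3, atol=1e-3))  # True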
Benchmarking this approach shows that it scales much more favorably than jnp.einsum on this problem, handling $n=4096$ while jnp.einsum only scaled up to $n=512$.  Somewhat surprisingly, this approach is even faster than jnp.einsum for $n=256$ and $n=512$, perhaps due to its lower memory usage.

 
n       scan_einsum (s)    jnp.einsum (s)
128     0.00346            0.00142
256     0.0115             0.0185
512     0.0468             0.238
1024    0.473              XlaRuntimeError
2048    6.00               XlaRuntimeError
4096    89.7               XlaRuntimeError

Summary

In summary, we found a simple method to greatly improve the efficiency and scalability of einsum over multiple array inputs.  It was not trivial at first to implement these algorithms in pure JAX, and it required thinking at a fairly high level of abstraction.  However, I think the end result is a nice and elegant solution that leverages JAX's higher-order functions.  Programming in JAX reminds me of programming in Haskell: a pleasure once you understand it, although getting to the point of sufficient understanding can be difficult.
