Exploring multi-input einsums in JAX
einsum is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays.
In this blog post, we will explore the einsum API and use it to solve marginal-inference-type problems.
Basic Examples
Suppose we have two $ n \times n $ arrays $A$ and $B$ and we want to compute
$$ Y[i, k] = \sum_{j=1}^n A[i, j] B[j, k] $$
This is the basic definition of matrix multiplication, and we can compute it in JAX via jnp.einsum('ij,jk->ik', A, B). Here the formula 'ij,jk->ik' defines the computation. Each letter names an axis of an input array; since both inputs here are 2D, each part of the formula contains two letters. The letters on the RHS of the expression determine which axes are present in the output, and any axis names not appearing on the RHS are "marginalized out" via the sum operator. This is a very generic API: by modifying the formula we can easily express a wide variety of interesting computations.
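For example, here is a minimal check (a sketch with small random inputs; the sizes and seed are arbitrary) that the formula reproduces ordinary matrix multiplication:

```python
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key1, (4, 4))
B = jax.random.normal(key2, (4, 4))

# 'ij,jk->ik' sums out j, which is exactly A @ B.
Y = jnp.einsum('ij,jk->ik', A, B)
assert jnp.allclose(Y, A @ B)
```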
A simple one that was of interest to me recently is the formula 'ij,ik,jk->ij', which defines the computation:
$$ Y[i, j] = \sum_{k} A[i,j] B[i,k] C[j,k] $$
These types of expressions tend to appear when doing marginal inference in discrete graphical models, and I recently became interested in exploring this space to improve the efficiency of an open source repository I maintain. One way to do this computation is to construct a 3D array $Z$ defined by $Z[i,j,k] = A[i,j] B[i,k] C[j,k]$ and then marginalize out $k$ by computing $Y[i,j] = \sum_{k} Z[i,j,k]$. When $Z$ is not too large this approach is fine, and on CPUs it doesn't matter much whether $Z$ is materialized or not. On GPUs, however, the difference can matter substantially. For example, if $n = 4096$ then storing $Z$ would require $256$ GB of RAM, assuming each element is represented using a 4-byte float. This is infeasible on most machines, including GPUs.

However, jnp.einsum (which does not construct the large intermediate in this case) completes this computation in only $0.02$ seconds on an NVIDIA A100 GPU. This is because the computation reduces to a matrix multiplication and an element-wise multiplication, both of which are very efficient on GPUs. Specifically, we can rewrite the expression as $Y[i,j] = A[i, j] \sum_{k} B[i, k] C[j, k]$, which we can implement in JAX as A * (B @ C.T). This is more or less what jnp.einsum does internally in this case. By expressing the einsum in this manner, we enjoy all the benefits of vectorization and highly optimized matrix multiplication kernels, together with the low memory footprint of the naive approach (nested for loops).
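As a sanity check (a small sketch; $n$, the seed, and the tolerance are arbitrary), the two formulations agree:

```python
import jax
import jax.numpy as jnp

n = 256
keys = jax.random.split(jax.random.PRNGKey(0), 3)
A, B, C = (jax.random.normal(k, (n, n)) for k in keys)

# The factored form never builds anything larger than n x n.
Y_einsum = jnp.einsum('ij,ik,jk->ij', A, B, C)
Y_factored = A * (B @ C.T)
assert jnp.allclose(Y_einsum, Y_factored, rtol=1e-4, atol=1e-4)
```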
A harder example
Now consider the formula 'ij,ik,il,jk,jl,kl->ij' over six $n \times n$ arrays, which defines the computation
$$ Y[i, j] = \sum_{k} \sum_{l} A[i,j] B[i,k] C[i,l] D[j,k] E[j,l] F[k,l] $$
Unlike the previous example, jnp.einsum cannot avoid materializing a huge intermediate here, and at larger problem sizes it fails with an out-of-memory error.
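A minimal reproduction (a sketch: the array contents are arbitrary, and $n = 1024$ is where the benchmarks below first hit the error):

```python
import jax
import jax.numpy as jnp

n = 1024
keys = jax.random.split(jax.random.PRNGKey(0), 6)
A, B, C, D, E, F = (jax.random.normal(k, (n, n)) for k in keys)

# This attempts to allocate an O(n^3) intermediate and exhausts GPU memory:
Y = jnp.einsum('ij,ik,il,jk,jl,kl->ij', A, B, C, D, E, F)
```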
```
XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 17196646400 bytes.
```
Limitations of JAX's einsum implementation
Why does jnp.einsum run out of memory here when it handled the three-input example so gracefully? As I understand it, jnp.einsum (like NumPy's einsum) uses opt_einsum to break a multi-input expression into a sequence of pairwise contractions. In 'ij,ik,jk->ij', the axis $k$ appears in exactly two operands, so a single pairwise contraction (a matrix multiplication) sums it out without ever building anything larger than $n \times n$. In 'ij,ik,il,jk,jl,kl->ij', by contrast, $k$ and $l$ each appear in three operands, so no pairwise step can sum either of them out immediately: every contraction path must materialize at least one three-axis intermediate of size $n^3$.
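We can ask the path optimizer directly what it plans to do. A small sketch using NumPy's einsum_path (only the shapes matter, not the array contents):

```python
import numpy as np

n = 1024
operands = [np.empty((n, n), dtype=np.float32) for _ in range(6)]

# Plan the contraction without executing it.
path, report = np.einsum_path('ij,ik,il,jk,jl,kl->ij', *operands,
                              optimize='optimal')
print(report)  # the report's "Largest intermediate" is ~n^3 elements here
```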
Naive einsum implementation via jax.vmap and jax.lax.scan
One way to sidestep the large intermediate entirely is to compute each output cell with explicit reductions over $k$ and $l$ (creating no intermediate arrays), then vectorize over the output axes $i$ and $j$ with jax.vmap.
```python
import jax
import jax.numpy as jnp

def inner_einsum(*arrays):
    # Computes the einsum ',k,l,k,l,kl->' for a single output cell:
    #   A * sum_k B[k] * D[k] * (sum_l C[l] * E[l] * F[k, l])
    # Does not create any intermediate arrays.
    A, B, C, D, E, F = arrays
    K, L = B.size, C.size
    # A concrete zero of the right dtype keeps the scan carry type stable.
    zero = jnp.zeros((), F.dtype)

    def loop_body_k(partial1, k):
        def loop_body_l(partial2, l):
            return partial2 + C[l] * E[l] * F[k, l], ()
        inner_sum = jax.lax.scan(loop_body_l, zero, jnp.arange(L))[0]
        return partial1 + B[k] * D[k] * inner_sum, ()

    return A * jax.lax.scan(loop_body_k, zero, jnp.arange(K))[0]
```
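A quick correctness check against jnp.einsum on small inputs (a sketch; the sizes and seed are arbitrary):

```python
keys = jax.random.split(jax.random.PRNGKey(0), 6)
shapes = [(), (3,), (4,), (3,), (4,), (3, 4)]
arrays = [jax.random.normal(k, s) for k, s in zip(keys, shapes)]

expected = jnp.einsum(',k,l,k,l,kl->', *arrays)
assert jnp.allclose(inner_einsum(*arrays), expected, atol=1e-5)
```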
```python
@jax.jit
def vmap_einsum(*arrays):
    # Computes the einsum 'ij,ik,il,jk,jl,kl->ij' naively.
    # No memory overhead: vectorized across output cells.
    return jax.vmap(
        # in_axes marks which arrays carry axis j (inner vmap) and axis i (outer vmap).
        jax.vmap(inner_einsum, in_axes=(0, None, None, 0, 0, None)),
        in_axes=(0, 0, 0, None, None, None),
    )(*arrays)
```
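Before benchmarking, a sanity check at a small size (a sketch; $n$, the seed, and the tolerances are arbitrary):

```python
n = 64
keys = jax.random.split(jax.random.PRNGKey(0), 6)
A, B, C, D, E, F = (jax.random.normal(k, (n, n)) for k in keys)

expected = jnp.einsum('ij,ik,il,jk,jl,kl->ij', A, B, C, D, E, F)
assert jnp.allclose(vmap_einsum(A, B, C, D, E, F), expected,
                    rtol=1e-3, atol=1e-3)
```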
Benchmark experiments show how the performance characteristics of this approach compare to jnp.einsum. It scales further (handling $n = 1024$ without OOM), but it is significantly slower than jnp.einsum in the regimes where both run. Times are in seconds:

n | vmap_einsum | jnp.einsum |
---|---|---|
128 | 0.14 | 0.0022 |
256 | 0.76 | 0.0180 |
512 | 4.29 | 0.2500 |
1024 | 35.70 | XlaRuntimeError |
Combining jnp.einsum with jax.lax.scan for improved memory efficiency
A middle ground between the two approaches is to scan over a single axis ($k$), accumulating a running sum, and let jnp.einsum handle the remaining, now memory-light, contraction at each step.
```python
@jax.jit
def scan_einsum(*arrays):
    # Scan over k, building up a running sum of per-k einsums.
    A, B, C, D, E, F = arrays
    K = B.shape[1]
    zeros = jnp.zeros(A.shape)

    def add_small_einsum(partial, k):
        # The original einsum with axis k stripped out: 'ij,i,il,j,jl,l->ij'.
        # Each step only contracts over l, so intermediates stay O(n^2).
        step = jnp.einsum('ij,i,il,j,jl,l->ij', A, B[:, k], C, D[:, k], E, F[k, :])
        return partial + step, ()

    return jax.lax.scan(add_small_einsum, zeros, jnp.arange(K))[0]
```
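For reference, a rough sketch of how such timings can be measured (not necessarily the exact harness behind the tables; the benchmark helper and $n$ are illustrative):

```python
import time

def benchmark(fn, *args):
    fn(*args).block_until_ready()   # first call includes compilation
    start = time.perf_counter()
    fn(*args).block_until_ready()   # JAX dispatch is async; wait for the result
    return time.perf_counter() - start

n = 1024
keys = jax.random.split(jax.random.PRNGKey(0), 6)
A, B, C, D, E, F = (jax.random.normal(k, (n, n)) for k in keys)
print(benchmark(scan_einsum, A, B, C, D, E, F))
```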
Benchmarking this approach shows that it scales much more favorably than jnp.einsum on this problem, handling $n = 4096$ while jnp.einsum only scaled up to $n = 512$. Somewhat surprisingly, this approach is even faster than jnp.einsum at $n = 256$ and $n = 512$, perhaps due to its lower memory usage. Times are in seconds:

n | scan_einsum | jnp.einsum |
---|---|---|
128 | 0.00346 | 0.00142 |
256 | 0.0115 | 0.0185 |
512 | 0.0468 | 0.238 |
1024 | 0.473 | XlaRuntimeError |
2048 | 6.00 | XlaRuntimeError |
4096 | 89.7 | XlaRuntimeError |