Exploring multi-input einsums in JAX
The JAX documentation says: "einsum is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays." In this blog post, we are going to explore the einsum API and use it to solve marginal inference-type problems.

Basic Examples

Suppose we have two $n \times n$ arrays $A$ and $B$ and we want to compute

$$ Y[i, k] = \sum_{j=1}^n A[i, j] B[j, k] $$

This is the basic definition of matrix multiplication, and we can compute it in JAX via jnp.einsum('ij,jk->ik', A, B). Here the formula 'ij,jk->ik' defines the computation. Each letter is a name for an axis of an array. In this case we have only 2D arrays, so each input part of the formula contains two axes. The letters on the RHS of the expression determine which axes will be defined for the output. Any axis names not appearing in the RHS of the expression will be "marginalized out" via th...
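To make the formula concrete, here is a minimal sketch. The small arrays A and B are made-up illustrative values, not taken from the post. It checks 'ij,jk->ik' against ordinary matrix multiplication and an explicit loop over $j$. It then drops $k$ from the RHS as well, which shows that every axis missing from the output is summed over.

```python
import jax.numpy as jnp

# Illustrative 3 x 3 inputs (hypothetical values chosen for easy checking)
n = 3
A = jnp.arange(n * n, dtype=jnp.float32).reshape(n, n)  # [[0,1,2],[3,4,5],[6,7,8]]
B = jnp.ones((n, n), dtype=jnp.float32)

# 'ij,jk->ik': j appears in the inputs but not on the RHS, so it is summed out
Y = jnp.einsum('ij,jk->ik', A, B)

# The same result via ordinary matrix multiplication
Y_matmul = A @ B

# Explicit version of Y[i, k] = sum_j A[i, j] B[j, k], for comparison
Y_loop = jnp.array(
    [[sum(A[i, j] * B[j, k] for j in range(n)) for k in range(n)] for i in range(n)]
)

# Removing k from the RHS too sums over both j and k:
# 'ij,jk->i' computes sum_{j,k} A[i, j] B[j, k]
row_totals = jnp.einsum('ij,jk->i', A, B)
```

Because B is all ones here, each column of Y equals the row sums of A, and row_totals adds those up across the three columns.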