Writing Efficient Numpy Code

In this blog post, I am going to talk about writing efficient numpy code in Python. I do a good amount of numerical linear algebra for my research and personal projects, and I typically code in Python with numpy (and scipy) because it is very easy to use and also very efficient when used correctly.

Consider the following problem, which will serve as a running example throughout this post. Suppose we have a (column) vector $x$ of length $n = 2^k$ and we want to compute $H_k x$, where $H_k$ is a “hierarchical” matrix with branching factor $2$, defined by $$ \begin{align*} H_0 &= \begin{bmatrix} 1 \end{bmatrix} \\ H_{k+1} &= \begin{bmatrix} 1 & 1 \\ H_k & 0 \\ 0 & H_k \end{bmatrix} \end{align*} $$ where the top row is a row vector of all ones (of length $2^{k+1}$) and $0$ denotes a matrix of zeros with the same size as $H_k$. For example, $H_2$ is a $7 \times 4$ matrix that looks like this: $$ \begin{bmatrix} 1 & 1 & 1 & 1 \\ 1 & 1 & 0 & 0 \\ 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 1 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ \end{bmatrix} $$ In general, $H_k$ has size $(2n-1) \times n$.

It’s easy enough to program this in Python by instantiating the matrix $H_k$ and then computing the matrix-vector product. This approach has a theoretical time and space complexity of $O(n^2)$. On my laptop, I am able to run this up to $k=13$, or equivalently $n=8192$, before it fails due to memory limitations. I will remark that when you write good numpy code, memory typically becomes an issue before time: even though the program fails to run beyond $k=13$, it only took $0.46$ seconds in that setting.
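Here is a minimal sketch of that dense baseline, assuming numpy. The names `hier_matrix` and `dense_hier` are hypothetical helpers for illustration, not necessarily the code I timed. The sketch builds $H_k$ by applying the recursive definition $k$ times and then multiplies by $x$.

```python
import numpy as np

def hier_matrix(k):
    """Dense H_k built from the recursive definition (hypothetical helper)."""
    H = np.ones((1, 1))  # H_0 = [1]
    for _ in range(k):
        rows, m = H.shape
        Z = np.zeros((rows, m))
        # Stack [row of ones; H 0; 0 H] to get the next level.
        H = np.vstack([np.ones((1, 2 * m)),
                       np.hstack([H, Z]),
                       np.hstack([Z, H])])
    return H

def dense_hier(x):
    """O(n^2) baseline: instantiate H_k, then compute the matrix-vector product."""
    n = x.shape[0]
    k = n.bit_length() - 1
    assert n == 2 ** k, "length of x must be a power of two"
    return hier_matrix(k) @ x

print(hier_matrix(2).shape)         # (7, 4)
print(dense_hier(np.arange(4.0)))   # [6. 1. 0. 1. 5. 2. 3.]
```

At $k=13$ the matrix alone holds $16383 \times 8192$ float64 entries, roughly 1 GB, which is consistent with the memory wall described above. The temporary copies made by `np.vstack` and `np.hstack` push the peak even higher.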

In the rest of this blog post, I will show how to solve this problem much faster than this baseline implementation. The final algorithm will only be limited by the size of $x$, and as long as that fits in memory it will run in a reasonable amount of time. Along the way, we will learn valuable lessons about how to write efficient numpy code.

The rest of the post is broken up into two main parts. In the first part, I will discuss how to exploit problem structure to improve the theoretical time complexity of solving this problem. When writing code you should always be looking for these types of opportunities. The famous Donald Knuth quote "premature optimization is the root of all evil in programming" is relevant here: premature optimization is a common source of bugs, so it’s important to first find an algorithm with near-optimal theoretical time complexity and implement it correctly before micro-optimizing it. In the second part, I will discuss how to micro-optimize this code, which is also important. For example, in this problem my micro-optimized code was over 25x faster than my initial implementation of the same algorithm, which is nothing to sneeze at.


Exploiting Problem Structure


One simple and natural way to improve the theoretical time complexity is to use a sparse matrix representation of $H_k$. It’s not difficult to show that $H_k$ has $O(n \log n)$ nonzero entries: each of the $k+1 = \log_2 n + 1$ levels of the hierarchy contributes exactly $n$ ones, and the remaining entries don’t need to be stored. Using scipy.sparse to represent this matrix, we can bring the theoretical time and space complexity of this problem down to $O(n \log n)$, a substantial improvement. On my laptop, I am able to run this up to $k=21$, or equivalently $n \approx 2$ million, at which point it uses about $1$ GB of memory and takes about $6$ seconds to run.
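A minimal sketch of the sparse approach, assuming `scipy.sparse` is available. The helper names are hypothetical, and this is one way to build the matrix rather than necessarily the construction I timed. The sketch also lets you check the nonzero count $n(k+1)$ directly.

```python
import numpy as np
import scipy.sparse as sp

def hier_matrix_sparse(k):
    """Sparse (CSR) H_k from the recursive definition (hypothetical helper)."""
    H = sp.csr_matrix(np.ones((1, 1)))  # H_0 = [1]
    for _ in range(k):
        m = H.shape[1]
        top = sp.csr_matrix(np.ones((1, 2 * m)))  # the row of ones
        # block_diag puts two copies of H on the diagonal, zero blocks elsewhere.
        H = sp.vstack([top, sp.block_diag([H, H])], format="csr")
    return H

def sparse_hier(x):
    """O(n log n) time and space: only the nonzero entries of H_k are stored."""
    n = x.shape[0]
    k = n.bit_length() - 1
    assert n == 2 ** k, "length of x must be a power of two"
    return hier_matrix_sparse(k) @ x

k = 10
H = hier_matrix_sparse(k)
print(H.nnz == 2 ** k * (k + 1))      # True
print(sparse_hier(np.arange(4.0)))    # [6. 1. 0. 1. 5. 2. 3.]
```

Each level of the recursion only touches the stored nonzeros, so building the matrix costs $O(n \log n)$ as well.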

However, there is additional structure to exploit in this problem, and we can get the theoretical time and space complexity down to $O(n)$. To understand the idea, it’s important to think about the structure of $H_k$ beyond its sparsity. Each row of $H_k$ computes a sum $x_a + \dots + x_b$ over a contiguous interval, and these intervals are hierarchically structured: they form a complete binary tree whose root (level $0$) covers all of $x$, whose two children cover the left and right halves, and so on down to the $n$ leaves (level $k$), which cover single entries of $x$. The number of nodes in this tree is $2n-1$, and each node corresponds to one of the rows of $H_k$. This tree helps illuminate the structure of $H_k$, which leads to the following simple bottom-up algorithm:
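The following small sketch makes the tree concrete. The function names are hypothetical. It lists each node's half-open index interval $[a, b)$ in depth-first pre-order: a node first, then its left subtree, then its right subtree. That ordering reproduces the row order of $H_k$, so summing $x$ over each interval gives $H_k x$ directly, although in $O(n \log n)$ time because every level re-reads all of $x$.

```python
import numpy as np

def intervals(lo, hi):
    """Yield the half-open interval [lo, hi) of every tree node, in pre-order."""
    yield (lo, hi)
    if hi - lo > 1:
        mid = (lo + hi) // 2
        yield from intervals(lo, mid)   # left child covers the first half
        yield from intervals(mid, hi)   # right child covers the second half

def tree_hier(x):
    """Row i of H_k x is the sum of x over the i-th interval."""
    return np.array([x[a:b].sum() for a, b in intervals(0, x.shape[0])])

print(list(intervals(0, 4)))
# [(0, 4), (0, 2), (0, 1), (1, 2), (2, 4), (2, 3), (3, 4)]
print(tree_hier(np.arange(4.0)))  # [6. 1. 0. 1. 5. 2. 3.]
```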

  1. Compute the answers at the leaves directly (these are just the entries of $x$).
  2. For each level $k-1, \dots, 0$:
          Compute the answer at each node on that level by summing the answers of its two children.
The time complexity of this algorithm is $O(n)$ because each of the $2n-1$ answers requires summing at most two items. There is no space overhead beyond the $2n-1$ answers themselves, so its space complexity is also $O(n)$, which is what we need just to store $x$. Since the output alone has $2n-1$ entries, no algorithm can do asymptotically better, and this is the best algorithm I am aware of to solve this problem.
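Here is a plain-Python sketch of the bottom-up algorithm, with hypothetical helper names. It stores the answers level by level. As an optional last step, it walks the tree in pre-order to put them into the row order of $H_k$. Note that `hier1` and `hier2` below return the same $2n-1$ values grouped by level from the leaves up, so their output is a permutation of $H_k x$ rather than identical to it.

```python
def bottom_up_levels(x):
    """Return levels[d] = answers of the nodes at depth d (root is depth 0)."""
    levels = [list(x)]                      # depth k: the leaves are just x
    while len(levels[-1]) > 1:
        child = levels[-1]
        # Each parent sums its two children: O(1) work per node.
        levels.append([child[2 * i] + child[2 * i + 1]
                       for i in range(len(child) // 2)])
    levels.reverse()                        # root first, leaves last
    return levels

def to_row_order(levels, d=0, j=0):
    """Yield node answers in pre-order, which matches the rows of H_k."""
    yield levels[d][j]
    if d + 1 < len(levels):
        yield from to_row_order(levels, d + 1, 2 * j)      # left child
        yield from to_row_order(levels, d + 1, 2 * j + 1)  # right child

levels = bottom_up_levels([0, 1, 2, 3])
print(levels)                      # [[6], [1, 5], [0, 1, 2, 3]]
print(list(to_row_order(levels)))  # [6, 1, 0, 1, 5, 2, 3]
```

The pre-order walk visits each node once, so the reordering step is also $O(n)$.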

Micro-Optimizing Numpy Code


While the algorithm I just described is theoretically optimal, a poor implementation can leave a lot of performance on the table, as I will demonstrate now. A first-pass implementation of this algorithm, which uses Python-level for loops and element-by-element indexing (things you should always try to avoid when writing numpy code!), is given below:
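```python
import numpy as np

def hier1(x):
    # Level k (the leaves) is x itself.
    ans = [x]
    m = x.shape[0]
    while m > 1:
        m = m // 2
        # Each node on the next level up sums its two children.
        y = np.zeros(m)
        for i in range(m):
            y[i] = x[2*i] + x[2*i+1]
        x = y
        ans.append(x)
    # All 2n-1 answers, ordered level by level from the leaves up to the root.
    return np.concatenate(ans)
```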
This algorithm scales up to $k=26$, or equivalently $n \approx 67$ million, in about $30$ seconds. We could scale to larger $k$ with this method if we are willing to wait longer, at least while $x$ fits into memory. This is a big improvement over the matrix-based approaches, but numpy-wise it is not a very well written piece of code. Someone who works with numpy a lot could look at it and easily improve it without having any understanding of the problem that it is solving.
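The loop body in `hier1` pairs each even-indexed entry of `x` with the odd-indexed entry right after it. The short sketch below, assuming numpy, shows that strided slices express the same pairing without a Python loop. `x[0::2]` and `x[1::2]` are views into `x`, so no data is copied, and adding them produces the next level up in one vectorized call.

```python
import numpy as np

x = np.arange(8.0)

evens, odds = x[0::2], x[1::2]
# Basic slices are views that share x's memory; no data is copied here.
print(np.shares_memory(x, evens), np.shares_memory(x, odds))  # True True

# One vectorized addition replaces the whole inner for loop of hier1.
level = evens + odds
print(level.tolist())  # [1.0, 5.0, 9.0, 13.0]

# The same pairing via reshape: each row holds one (left, right) pair of children.
print(np.array_equal(level, x.reshape(-1, 2).sum(axis=1)))  # True

# The loop from hier1, for comparison.
loop = np.zeros(4)
for i in range(4):
    loop[i] = x[2*i] + x[2*i+1]
print(np.array_equal(level, loop))  # True
```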

The key to making an efficient numpy implementation is to avoid Python-level for loops and element-by-element indexing by replacing them with built-in numpy functions, which run in compiled code. This is easier said than done, and it is often challenging to find the right numpy function for the job, but the more familiarity you get with Python and numpy, the easier this becomes. For the example in this blog post, it is quite straightforward: we only need slicing and addition of two arrays. The resulting algorithm looks the same, with the inner for loop replaced by an efficient vectorized implementation.
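```python
import numpy as np

def hier2(x):
    # Level k (the leaves) is x itself.
    ans = [x]
    m = x.shape[0]
    while m > 1:
        m = m // 2
        # Vectorized pairwise sums: even-indexed entries plus odd-indexed entries.
        x = x[0::2] + x[1::2]
        ans.append(x)
    # Same ordering as hier1: leaves first, root last.
    return np.concatenate(ans)
```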
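One possible refinement is sketched below under the hypothetical name `hier3`. I have not benchmarked it, and it is not the code timed here. It preallocates the full $(2n-1)$-length output once and writes each level directly into its slice using the `out=` argument of `np.add`, which avoids both the list of intermediate arrays and the final `np.concatenate` copy. That should lower peak memory, which appears to be the bottleneck at $k=29$. Whether it is also faster would need measuring.

```python
import numpy as np

def hier3(x):
    """Same output order as hier2 (leaves first, root last), written into one buffer."""
    x = np.asarray(x)
    n = x.shape[0]
    out = np.empty(2 * n - 1, dtype=x.dtype)
    out[:n] = x                 # the leaves are just x
    start, m = 0, n             # current level lives in out[start:start + m]
    while m > 1:
        cur = out[start:start + m]
        nxt = out[start + m:start + m + m // 2]
        # Pairwise sums go straight into the next level's slice; no temporaries.
        np.add(cur[0::2], cur[1::2], out=nxt)
        start, m = start + m, m // 2
    return out

print(hier3(np.arange(4.0)).tolist())  # [0.0, 1.0, 2.0, 3.0, 1.0, 5.0, 6.0]
```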
hier2 is $35$x faster than hier1, and it scales up to $k=28$, or $n \approx 268$ million, in less than $4$ seconds. It was also able to run for $k=29$, but it took over $25$ seconds. This is in contrast with the roughly $8$ seconds predicted by extrapolating from $k=28$ using the fact that the algorithm is linear time. The gap is likely because I am pushing the limit of what fits in main memory, at which point the I/O overhead becomes large. Nevertheless, it is a big improvement over our original implementation, and it highlights the practical importance of writing vectorized numpy code. To avoid loops and indexing in more complicated situations, you’ll need to dig into the numpy API and tinker around with things a bit.

Anyway, that’s all I have to say on this topic. Hopefully you found this blog post useful, and if you come up with an even faster implementation, leave a comment below!
