Representing Graphical Model Factors in Numpy

I've been working with graphical models lately for my research. For a while I used a library called pgmpy for its implementations of factor arithmetic and inference algorithms. However, as my research has progressed I have needed more control than pgmpy offers, so I decided to re-implement and extend the algorithms I needed. If you are in a similar situation and find yourself implementing your own graphical model algorithms, this post is for you.

In my setting, I have a Markov random field with only a small number of variables (no more than a few dozen, often far fewer). However, each variable can take on a potentially large number of categorical values, and the graph may contain relatively large cliques that make exact inference computationally challenging (although still feasible).

Suppose the model has $d$ variables and variable $i$ can take on $n_i$ possible values. A factor defined over $k$ variables is usually represented as a $k$-dimensional array. However, if we instead represent every factor as a $d$-dimensional array, we can exploit numpy's built-in broadcasting to simplify the factor math. In this representation, axis $i$ has size $n_i$ if variable $i$ is in the factor's scope and size $1$ otherwise. The size-$1$ axes add no extra entries, so the factor stores exactly as many values as its $k$-dimensional counterpart.
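To make the representation concrete, here is a minimal sketch of converting a conventional $k$-dimensional factor into the $d$-dimensional form. It assumes variables are identified by integer axis indices $0, \dots, d-1$; the helper expand_factor, its arguments, and the example domain sizes are hypothetical illustrations, not part of numpy or pgmpy.

```python
import numpy as np

def expand_factor(values, scope, domain):
    """Convert a k-dimensional factor into the d-dimensional representation.

    Hypothetical helper for illustration.
    values: array whose axes correspond, in order, to the variables in `scope`.
    scope:  list of integer variable indices (0..d-1) the factor depends on.
    domain: tuple (n_0, ..., n_{d-1}) of domain sizes for all d variables.
    """
    values = np.asarray(values)
    if values.shape != tuple(domain[v] for v in scope):
        raise ValueError("factor shape does not match the domain sizes of its scope")
    # Permute the factor's axes into increasing variable-index order.
    values = np.transpose(values, tuple(np.argsort(scope)))
    # Variables outside the scope get a size-1 axis so numpy can broadcast them.
    shape = [domain[i] if i in scope else 1 for i in range(len(domain))]
    return values.reshape(shape)


# Hypothetical model: A, B, C are variables 0, 1, 2 with 2, 3, 4 values.
domain = (2, 3, 4)
AB = expand_factor(np.ones((2, 3)), [0, 1], domain)
CA = expand_factor(np.arange(8).reshape(4, 2), [2, 0], domain)  # stored as (C, A)
print(AB.shape, CA.shape)  # (2, 3, 1) (2, 1, 4)
```

Note that a factor whose axes arrive in a different variable order (CA above) is aligned once, at construction time; after that, every factor shares the same axis layout and no further alignment is needed.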

When we want to multiply two factors, say $AB$ and $AC$, we can just write $ABC = AB * AC$ and numpy will automatically broadcast the size-$1$ axes so that the result has the shape of the combined factor $ABC$. If the factors are in log space, we can write $ABC = AB + AC$ and the same broadcasting happens with addition. To marginalize a factor, we can call np.sum(ABC, axis=ax, keepdims=True); keepdims=True preserves the dimensionality of the factor, so the result can be used directly in downstream factor computations. If we are working in log space, we can replace np.sum with scipy.special.logsumexp (formerly available as scipy.misc.logsumexp), which accepts the same axis and keepdims arguments.
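The operations described above can be sketched as follows. This assumes a hypothetical model with three variables A, B, C on axes 0, 1, 2 with 2, 3, and 4 values; the factor entries are random placeholders drawn away from zero so that taking logs is safe.

```python
import numpy as np
from scipy.special import logsumexp

# Hypothetical model: variables A, B, C on axes 0, 1, 2 with 2, 3, 4 values.
# Factors use the d-dimensional form: out-of-scope variables get size-1 axes.
rng = np.random.default_rng(0)
AB = rng.uniform(0.5, 1.5, size=(2, 3, 1))  # factor over A, B
AC = rng.uniform(0.5, 1.5, size=(2, 1, 4))  # factor over A, C

# Factor product: broadcasting expands the size-1 axes to shape (2, 3, 4).
ABC = AB * AC

# Marginalize out B (axis 1); keepdims=True leaves a size-1 axis in its place,
# so the result is still a valid d-dimensional factor over A and C.
AC_marg = np.sum(ABC, axis=1, keepdims=True)  # shape (2, 1, 4)

# The marginal can be combined with other factors directly.
ABC_again = AB * AC_marg  # shape (2, 3, 4)

# Log space: products become sums, and sum-marginalization becomes logsumexp.
log_AB, log_AC = np.log(AB), np.log(AC)
log_ABC = log_AB + log_AC
log_AC_marg = logsumexp(log_ABC, axis=1, keepdims=True)  # shape (2, 1, 4)

print(np.allclose(np.exp(log_AC_marg), AC_marg))  # True
```

Marginalizing several variables at once works the same way with np.sum by passing a tuple of axes, for example np.sum(ABC, axis=(1, 2), keepdims=True).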

Using this representation, I was able to avoid the annoying problem of aligning factors' axes before adding or multiplying them. As a result, the code is cleaner, and because there is no explicit alignment step (transposing and reshaping factors into a common variable order), it also avoids that overhead.

