L2 projection onto the probability simplex with bound constraints

Projecting onto the probability simplex is a common problem that arises in frequency estimation and related tasks, and it can be solved efficiently with an $O(n \log{(n)})$ algorithm. In this blog post I consider the related problem of projecting onto the probability simplex with bound constraints. Given a vector $ \mathbf{r} \in \mathbb{R}^n $, our goal is to find a vector $\mathbf{p}^*$ that solves the following optimization problem.

$$ \begin{equation*} \begin{aligned} & \underset{\mathbf{p}}{\text{minimize}} & & \frac{1}{2} \lVert \mathbf{p} - \mathbf{r} \rVert_2^2 \\ & \text{subject to} & & \mathbf{1}^T \mathbf{p} = 1 \\ & & & \mathbf{a} \leq \mathbf{p} \leq \mathbf{b} \\ \end{aligned} \end{equation*} $$

This problem generalizes the standard probability simplex projection problem (the special case $\mathbf{a} = \mathbf{0}$, $\mathbf{b} = \mathbf{1}$) by introducing arbitrary bound constraints $ \mathbf{a} \leq \mathbf{p} \leq \mathbf{b}$, where $\mathbf{a}, \mathbf{b} \in \mathbb{R}^n $ and the inequalities hold elementwise. For the problem to be feasible, we need $\mathbf{a} \leq \mathbf{b}$ and $ \mathbf{1}^T \mathbf{a} \leq 1 \leq \mathbf{1}^T \mathbf{b} $. I will now state an efficient $O(n \log{(n)})$ algorithm that solves the problem exactly, and then argue that it is correct.
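```python
import numpy as np

def project(r, a, b):
    """Project r onto {p : sum(p) = 1, a <= p <= b} by sweeping the breakpoints of lambda."""
    if a.sum() > 1 or b.sum() < 1 or np.any(a > b):
        raise ValueError('not feasible')
    n = r.size
    total = np.sum(a)                  # sum of p when every entry sits at its lower bound
    # lambda values at which entry i reaches a_i (first n) or b_i (last n)
    lambdas = np.append(a - r, b - r)
    # stable sort: on ties, lower-bound breakpoints come before upper-bound ones
    idx = np.argsort(lambdas, kind='stable')
    lambdas = lambdas[idx]
    active = 1                         # lambdas[0] is always a lower-bound breakpoint
    for i in range(1, 2*n):
        # total = sum(clip(r + lambdas[i], a, b)), updated incrementally
        total += active*(lambdas[i] - lambdas[i-1])
        if total >= 1:
            # the root lies in [lambdas[i-1], lambdas[i]]; interpolate linearly
            lam = (1-total) / active + lambdas[i]
            return np.clip(r + lam, a, b)
        elif idx[i] < n:
            active += 1                # entry idx[i] leaves its lower bound
        else:
            active -= 1                # entry idx[i] - n reaches its upper bound
    # only reached when sum(b) == 1 up to rounding: every entry is at its upper bound
    return np.clip(r + lambdas[-1], a, b)
```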


I will first argue that the solution to the projection problem is $$ \mathbf{p}^* = \text{clip}(\mathbf{r} + \lambda, \mathbf{a}, \mathbf{b}) $$ where clip acts elementwise, $\text{clip}(x_i, a_i, b_i) = \min(\max(x_i, a_i), b_i)$, raising values below $a_i$ to $a_i$ and lowering values above $b_i$ to $b_i$, and $\lambda$ is a scalar chosen so that $\mathbf{1}^T \mathbf{p}^* = 1$. First note that the problem is convex: the objective is a strictly convex quadratic and the constraints are linear. For such a problem the KKT conditions are both necessary and sufficient for global optimality, and strict convexity makes the minimizer unique, so it suffices to find any point that satisfies them. To do so, let's form the Lagrangian function: $$ \mathcal{L}(\mathbf{p}, \lambda, \mathbf{\alpha}, \mathbf{\beta}) = \frac{1}{2} \lVert \mathbf{p} - \mathbf{r} \rVert_2^2 + \lambda (1 - \mathbf{1}^T \mathbf{p}) + \mathbf{\alpha}^T(\mathbf{a} - \mathbf{p}) + \mathbf{\beta}^T (\mathbf{p} - \mathbf{b}) $$ The KKT conditions are listed below.
Stationarity: $$ \frac{\partial \mathcal{L}}{\partial p_i} = p_i - r_i - \lambda - \alpha_i + \beta_i = 0 \quad \Rightarrow \quad p_i = r_i + \lambda + \alpha_i - \beta_i $$

Primal Feasibility: $$ a_i \leq p_i \leq b_i, \qquad \sum_i p_i = 1 $$

Dual Feasibility: $$ \alpha_i \geq 0, \qquad \beta_i \geq 0 $$

Complementary Slackness: $$ \alpha_i (a_i - p_i) = 0, \qquad \beta_i (p_i - b_i) = 0 $$
Studying the above conditions, we see that the solution has the form $p_i = \text{clip}(r_i + \lambda, a_i, b_i)$. This is because for any fixed $\lambda$, we can choose $\alpha_i$ and $\beta_i$ coordinate by coordinate so that every condition other than $\mathbf{1}^T \mathbf{p} = 1$ holds. When $ a_i < r_i + \lambda < b_i $, both constraints are inactive, so $\alpha_i = \beta_i = 0$ and $p_i = r_i + \lambda$. When $r_i + \lambda \leq a_i$, we set $p_i = a_i$, $\alpha_i = a_i - r_i - \lambda \geq 0$, and $\beta_i = 0$. When $r_i + \lambda \geq b_i$, we set $p_i = b_i$, $\beta_i = r_i + \lambda - b_i \geq 0$, and $\alpha_i = 0$. In each case stationarity, dual feasibility, complementary slackness, and the bounds $a_i \leq p_i \leq b_i$ are satisfied, and $p_i$ is exactly $r_i + \lambda$ clipped to $[a_i, b_i]$. Thus, we can solve the problem by finding the value of $\lambda$ such that $ \mathbf{1}^T \mathbf{p} = 1$, where $ \mathbf{p} = \text{clip}(\mathbf{r} + \lambda, \mathbf{a}, \mathbf{b}) $.
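The case analysis can be checked numerically. The sketch below uses a hypothetical helper, `kkt_certificate`, which for a given $\lambda$ builds $\mathbf{p}$, $\mathbf{\alpha}$ and $\mathbf{\beta}$ exactly as described (assuming $\mathbf{a} \leq \mathbf{b}$) and asserts stationarity, dual feasibility, and complementary slackness for several values of $\lambda$. The example vectors are arbitrary.

```python
import numpy as np

def kkt_certificate(r, a, b, lam):
    """Given a scalar lam, return p = clip(r + lam, a, b) and multipliers
    alpha, beta chosen by the case analysis above (assumes a <= b elementwise)."""
    z = r + lam
    p = np.clip(z, a, b)
    alpha = np.maximum(a - z, 0.0)   # > 0 only where z < a, i.e. p_i = a_i
    beta = np.maximum(z - b, 0.0)    # > 0 only where z > b, i.e. p_i = b_i
    return p, alpha, beta

r = np.array([0.9, 0.05, 0.05])
a = np.array([0.0, 0.2, 0.2])
b = np.array([0.5, 1.0, 1.0])
for lam in (-0.5, 0.0, 0.2, 1.0):
    p, alpha, beta = kkt_certificate(r, a, b, lam)
    assert np.allclose(p, r + lam + alpha - beta)                  # stationarity
    assert np.all(alpha >= 0) and np.all(beta >= 0)                # dual feasibility
    assert np.allclose(alpha * (a - p), 0)                         # complementary slackness
    assert np.allclose(beta * (p - b), 0)
    print(lam, p.sum())   # only the constraint sum(p) = 1 depends on lam
```

For these four values of $\lambda$ the sums are approximately 0.8, 0.9, 1, and 2.5, so only $\lambda = 0.2$ satisfies the remaining constraint.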
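We can find the required $\lambda$ with a generic root-finding method such as binary search (bisection) on the function $f(\lambda) = 1 - \mathbf{1}^T \text{clip}(\mathbf{r} + \lambda, \mathbf{a}, \mathbf{b}) $. This works because $f$ is continuous and nonincreasing in $\lambda$, with $f(\lambda) = 1 - \mathbf{1}^T \mathbf{a} \geq 0$ for $\lambda \leq \min_i (a_i - r_i)$ and $f(\lambda) = 1 - \mathbf{1}^T \mathbf{b} \leq 0$ for $\lambda \geq \max_i (b_i - r_i)$, so a root is always bracketed. Each evaluation of $f$ costs $O(n)$, so this approach requires $O(k n)$ time, where $k$ is the number of iterations needed to reach the desired precision; it approximates $\lambda$ rather than computing it exactly.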
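A minimal sketch of the bisection approach, assuming NumPy arrays for `r`, `a`, `b`. The name `project_bisect` and the parameters `tol` and `max_iter` are hypothetical choices for illustration. The bracket is exactly the interval described above, where $f$ changes sign.

```python
import numpy as np

def project_bisect(r, a, b, tol=1e-12, max_iter=200):
    """Approximate the bounded simplex projection by bisection on
    f(lam) = 1 - sum(clip(r + lam, a, b)), which is continuous and nonincreasing."""
    if a.sum() > 1 or b.sum() < 1 or np.any(a > b):
        raise ValueError("not feasible")
    lo = np.min(a - r)   # here clip(r + lo, a, b) = a, so f(lo) = 1 - sum(a) >= 0
    hi = np.max(b - r)   # here clip(r + hi, a, b) = b, so f(hi) = 1 - sum(b) <= 0
    for _ in range(max_iter):          # each iteration costs O(n)
        mid = 0.5 * (lo + hi)
        if np.clip(r + mid, a, b).sum() < 1:
            lo = mid                   # total too small: the root is to the right
        else:
            hi = mid
        if hi - lo < tol:
            break
    return np.clip(r + 0.5 * (lo + hi), a, b)
```

With `tol=1e-12` the loop runs roughly $\log_2((\text{hi} - \text{lo}) / \text{tol})$ times; `max_iter` only guards against wide brackets where the floating-point spacing of $\lambda$ exceeds `tol`.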

Alternatively, we can use the first Python function above, project, which computes $\lambda$ exactly (up to floating-point rounding) in $O(n \log{(n)})$ time: the sort dominates, and the loop itself is $O(n)$. Intuitively, the algorithm begins with an initial estimate of $\lambda$ that makes $\mathbf{p} = \mathbf{a}$, which gives the smallest possible value of $\mathbf{1}^T \mathbf{p}$. This initial value is $\lambda = \min_{i} (a_i - r_i)$. It then increases $\lambda$ until $ \mathbf{1}^T \mathbf{p} \geq 1 $, stepping through the breakpoints at which some entry reaches its lower or upper bound (i.e., the sorted entries of $ \mathbf{a} - \mathbf{r} $ and $\mathbf{b} - \mathbf{r}$). Between consecutive breakpoints, $\mathbf{1}^T \mathbf{p}$ is linear in $\lambda$, and its slope is the active counter in the algorithm: the number of entries that increase with a small increase in $\lambda$, which are exactly those with $a_i \leq r_i + \lambda < b_i$. Once the total reaches 1, we know that the true value of $\lambda$ lies between lambdas[i-1] and lambdas[i], and we can compute it exactly by solving $ 1 = \text{total} + \text{active} \cdot \delta $ for $\delta$ (note that $\delta \leq 0$) and setting $\lambda = $ lambdas[i] $+ \delta$.
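To see the sweep in action, the sketch below replays it and records, at each sorted breakpoint, the incrementally updated total alongside a direct evaluation of $\mathbf{1}^T \text{clip}(\mathbf{r} + \lambda, \mathbf{a}, \mathbf{b})$. The helper name `sweep_trace` and the example vectors are hypothetical. It sorts with `kind="stable"` so that, on ties, lower-bound breakpoints are processed first, and the recorded active count is the slope on the interval that follows each breakpoint.

```python
import numpy as np

def sweep_trace(r, a, b):
    """Replay the breakpoint sweep, recording (breakpoint, running total, active count).

    The total is updated incrementally; at every breakpoint it should equal
    sum(clip(r + breakpoint, a, b)).
    """
    n = r.size
    lambdas = np.append(a - r, b - r)
    idx = np.argsort(lambdas, kind="stable")  # on ties, lower-bound breakpoints first
    lambdas = lambdas[idx]
    total, active = a.sum(), 0
    rows = []
    for i in range(2 * n):
        if i > 0:
            # slope of the total on (lambdas[i-1], lambdas[i]) is the active count
            total += active * (lambdas[i] - lambdas[i - 1])
        # entry idx[i] % n starts moving (lower bound) or stops moving (upper bound)
        active += 1 if idx[i] < n else -1
        rows.append((lambdas[i], total, active))
    return rows

r = np.array([0.9, 0.05, 0.05])
a = np.array([0.0, 0.2, 0.2])
b = np.array([0.5, 1.0, 1.0])
for lam, total, active in sweep_trace(r, a, b):
    # incremental total vs. direct evaluation, then the slope on the next interval
    print(f"{lam:6.2f} {total:5.2f} {np.clip(r + lam, a, b).sum():5.2f} {active}")
```

On this instance the running totals at the six breakpoints are 0.4, 0.9, 0.9, 0.9, 2.5, 2.5. The total first reaches 1 at the breakpoint $\lambda = 0.95$, where two entries were active on the preceding interval, so $\lambda = 0.95 + (1 - 2.5)/2 = 0.2$ and $\mathbf{p}^* = (0.5, 0.25, 0.25)$.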

A final note: while project has good asymptotics, its Python for loop makes it slow in practice. The equivalent implementation below is vectorized: it replaces the loop with NumPy cumulative sums and uses np.searchsorted to find the interval on which the total first reaches 1.
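```python
import numpy as np

def project_fast(r, a, b):
    """Vectorized version of project(): same breakpoints, cumulative sums instead of a loop."""
    if a.sum() > 1 or b.sum() < 1 or np.any(a > b):
        raise ValueError('not feasible')
    lambdas = np.append(a - r, b - r)
    idx = np.argsort(lambdas, kind='stable')   # on ties, lower-bound breakpoints first
    lambdas = lambdas[idx]
    # active[k]: number of entries that move on the interval (lambdas[k], lambdas[k+1])
    active = np.cumsum((idx < r.size)*2 - 1)[:-1]
    diffs = np.diff(lambdas, n=1)
    # totals[k] = sum(clip(r + lambdas[k+1], a, b))
    totals = a.sum() + np.cumsum(active*diffs)
    # first interval on which the total reaches 1
    i = np.searchsorted(totals, 1.0)
    if i == totals.size:
        # sum(b) == 1 up to rounding: every entry is at its upper bound
        return np.clip(r + lambdas[-1], a, b)
    lam = (1 - totals[i]) / active[i] + lambdas[i+1]
    return np.clip(r + lam, a, b)
```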
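As a sanity check, the snippet below compares the vectorized projection against two independent references on random feasible instances: bisection on $f(\lambda)$, and, for $\mathbf{a} = \mathbf{0}$, $\mathbf{b} = \mathbf{1}$ (where the upper bound is redundant), a standard sort-based projection onto the ordinary probability simplex. It includes its own copy of project_fast (using a stable sort and a guard for the $\mathbf{1}^T \mathbf{b} = 1$ case) so that it runs on its own; `project_bisect`, `project_simplex`, and `random_instance` are hypothetical helpers written for this check, and the instance generator is an arbitrary choice.

```python
import numpy as np

def project_fast(r, a, b):
    # vectorized breakpoint sweep (stable sort, guard for sum(b) == 1)
    if a.sum() > 1 or b.sum() < 1 or np.any(a > b):
        raise ValueError("not feasible")
    lambdas = np.append(a - r, b - r)
    idx = np.argsort(lambdas, kind="stable")
    lambdas = lambdas[idx]
    active = np.cumsum((idx < r.size) * 2 - 1)[:-1]
    totals = a.sum() + np.cumsum(active * np.diff(lambdas))
    i = np.searchsorted(totals, 1.0)
    if i == totals.size:
        return np.clip(r + lambdas[-1], a, b)
    lam = (1 - totals[i]) / active[i] + lambdas[i + 1]
    return np.clip(r + lam, a, b)

def project_bisect(r, a, b, iters=200):
    # reference: bisection on f(lam) = 1 - sum(clip(r + lam, a, b)), which is nonincreasing
    lo, hi = np.min(a - r), np.max(b - r)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if np.clip(r + mid, a, b).sum() < 1:
            lo = mid
        else:
            hi = mid
    return np.clip(r + 0.5 * (lo + hi), a, b)

def project_simplex(r):
    # standard sort-based projection onto {p >= 0, sum(p) = 1}
    u = np.sort(r)[::-1]                       # entries in decreasing order
    css = np.cumsum(u) - 1.0
    k = np.arange(1, r.size + 1)
    rho = np.nonzero(u - css / k > 0)[0][-1]   # last index with a positive shifted entry
    return np.maximum(r - css[rho] / (rho + 1), 0.0)

def random_instance(rng, n):
    # feasible by construction: sum(a) = 0.5, a <= b, sum(b) >= 0.5 + 0.5 * n
    a = rng.uniform(-1.0, 1.0, n)
    a -= (a.sum() - 0.5) / n
    b = a + rng.uniform(0.5, 1.5, n)
    r = 2.0 * rng.normal(size=n)
    return r, a, b

rng = np.random.default_rng(0)
for _ in range(200):
    r, a, b = random_instance(rng, int(rng.integers(1, 30)))
    p = project_fast(r, a, b)
    assert np.allclose(p, project_bisect(r, a, b), atol=1e-8)
    assert abs(p.sum() - 1) < 1e-9 and np.all((a <= p) & (p <= b))

    # with a = 0 and b = 1 the upper bound is redundant: ordinary simplex projection
    z = np.zeros_like(r)
    assert np.allclose(project_fast(r, z, z + 1), project_simplex(r), atol=1e-9)
```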
