Solve it all and solve it fast: using numba for optimization in Python

When implementing optimization algorithms, we typically have to balance the following goals:

  • Speed: low execution time
  • Generality: the code should work for a class of optimization problems
  • Readability: easy to understand code

It is very hard to reach all goals at the same time. Typically, compiled languages such as C++ offer speed but lack the other two (at least for my taste). Some concrete examples of this tradeoff:

  • In scikit-learn, many classical solvers are written in C++ or Cython, for example the logistic regression solvers.
  • In Julia, it is (to my limited understanding) easier to achieve all goals: for a nice example, see the design of the ProximalOperators.jl and ProximalAlgorithms.jl packages.

However, the goal of this article is to present one approach to reaching all three goals in Python.

A simple example: proximal gradient descent

Consider problems of the form

\[\min_x f(x) + r(x)\]

where we assume $f$ to be continuously differentiable and $r$ to be a (closed, convex) regularizer. For a step size $\alpha>0$, the proximal gradient algorithm for such problems is given by the iterates

\[x^{k+1} = \mathrm{prox}_{\alpha r}(x^k- \alpha \nabla f(x^k)),\]

where $\mathrm{prox}$ is the proximal operator of a closed, convex function.
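
For reference, the proximal operator in the iteration above is commonly defined as follows. This is the standard textbook definition, stated for convenience; the symbols $z$ (input point) and $y$ (minimization variable) are introduced only for this formula:

\[\mathrm{prox}_{\alpha r}(z) = \arg\min_{y} \left\{ \alpha\, r(y) + \frac{1}{2}\|y - z\|_2^2 \right\}\]

In particular, for $r = 0$ the proximal operator is the identity and the iteration reduces to plain gradient descent.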

General Python implementation

If we implement an algorithm for problems of the above type, it would be favourable to have code that works for any functions f and r fulfilling the respective assumptions. Moreover, as we have a composite objective, we would prefer to have a solver which we can call for any combination of f and r we would like - without adapting the code of the solver.

An obvious way to achieve this is to handle both f and r as instances of classes with the following methods:

  • f needs the method grad which computes a gradient at a specific point,
  • r needs the method prox which computes the proximal operator of $\alpha\cdot r$ at a specific point.

Let us show the implementation for f being a quadratic function and r being the 1-norm.

class Quadratic:
  def __init__(self, A, b):
    self.A = A
    self.b = b

  def grad(self, x):
    # gradient of 0.5 * x^T A x + b^T x (for symmetric A)
    g = self.A @ x + self.b
    return g

The formula below for the proximal operator of the 1-norm is well known; its details are not essential for what follows.

import numpy as np

class L1Norm:
  def __init__(self, l):
    self.l = l  # regularization weight

  def prox(self, x, alpha):
    return np.sign(x) * np.maximum(np.abs(x) - alpha*self.l, 0.)
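
For readers who want to see where the formula in prox comes from: it is the soft-thresholding operator, the closed-form minimizer of the prox problem for the weighted 1-norm. The following is only a numerical sanity check, written as a sketch; the helper names soft_threshold and prox_objective and the test values are my own and are not part of the original code.

```python
import numpy as np

def soft_threshold(x, t):
    # componentwise soft-thresholding with threshold t = alpha * l
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.)

def prox_objective(y, x, t):
    # objective whose minimizer over y is the prox: t*||y||_1 + 0.5*||y - x||^2
    return t * np.sum(np.abs(y)) + 0.5 * np.sum((y - x) ** 2)

rng = np.random.default_rng(0)
x = rng.standard_normal(5)
t = 0.3
p = soft_threshold(x, t)

# the closed-form point should not be beaten by small random perturbations
best = prox_objective(p, x, t)
not_beaten = all(
    prox_objective(p + 0.01 * rng.standard_normal(5), x, t) >= best
    for _ in range(100)
)
print(not_beaten)
```

Entries with magnitude below the threshold are set to zero, while all others are shrunk towards zero by the threshold.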

Now, proximal gradient descent can be implemented generally with the following simple function:

def prox_gd(f, r, x0, alpha=0.1, max_iter=50):
  x = x0.copy()
  for i in range(max_iter):
    y = x-alpha*f.grad(x)
    x = r.prox(y,alpha)
  return x

This is general and very simple to read. To apply the algorithm to a different objective, one only needs to write the respective f and/or r. With this, a library of functions can be built and used modularly. However, the solver loop runs in pure Python and will therefore be rather slow. Our goal is to use Numba to accelerate the implementation while keeping generality and readability.
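
To make the modularity concrete, here is a small self-contained sketch that runs the same solver with two different regularizers. The class NonNegative (the indicator function of the nonnegative orthant, whose prox is a projection) and the 2-dimensional problem data are assumptions chosen for illustration; they do not appear in the original notebook.

```python
import numpy as np

class Quadratic:
    def __init__(self, A, b):
        self.A, self.b = A, b
    def grad(self, x):
        return self.A @ x + self.b

class L1Norm:
    def __init__(self, l):
        self.l = l
    def prox(self, x, alpha):
        return np.sign(x) * np.maximum(np.abs(x) - alpha * self.l, 0.)

class NonNegative:
    # hypothetical extra regularizer: indicator function of {x >= 0};
    # its prox is the projection onto the nonnegative orthant
    def prox(self, x, alpha):
        return np.maximum(x, 0.)

def prox_gd(f, r, x0, alpha=0.1, max_iter=50):
    x = x0.copy()
    for i in range(max_iter):
        y = x - alpha * f.grad(x)
        x = r.prox(y, alpha)
    return x

# small separable toy problem (illustrative values)
A = np.array([[2., 0.], [0., 1.]])
b = np.array([-2., 0.1])
f = Quadratic(A, b)
x0 = np.zeros(2)

# same solver, two different regularizers
x_l1 = prox_gd(f, L1Norm(0.5), x0, alpha=0.1, max_iter=200)
x_nn = prox_gd(f, NonNegative(), x0, alpha=0.1, max_iter=200)
print(x_l1, x_nn)
```

Because this toy problem is separable, the minimizers can be checked by hand: approximately (0.75, 0) with the 1-norm and (1, 0) with the nonnegativity constraint. The solver code is identical in both calls.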

Numba implementation

What is Numba?

Numba is a package for just-in-time (JIT) compilation. It is designed to speed up pure Python code using the decorator @njit. Numba supports many functions built into NumPy; a detailed list is available in the Numba documentation.

The speedup comes typically from for-loops - which naturally appear in optimization algorithms. Thus, we want to write our solver as a JIT-compiled numba function. The problem: everything that happens inside a JIT-compiled function must itself be compiled. Thus, if we want to make use of class methods inside the solver, the class must be such that every method is JIT-compiled. Luckily, numba offers this possibility using @jitclass.

The Jitclass decorator

When using @jitclass, it is important to specify the type of every attribute of the class. See the example below or the docs for all details. Our quadratic function class can be implemented as follows:

import numpy as np
from numba.experimental import jitclass
from numba import float64, njit

# type specification: one (name, type) pair per attribute
spec = [
    ('b', float64[:]),
    ('A', float64[:,:])
]

@jitclass(spec)
class Quadratic:
  def __init__(self, A, b):
    self.A = A
    self.b = b
  
  def grad(self, x):
    g = self.A @ x + self.b
    return g

The same works for the 1-norm:

spec_l1 = [('l', float64)]

@jitclass(spec_l1)
class L1Norm:
  def __init__(self, l):
    self.l = l
  def prox(self, x, alpha):
    return np.sign(x) * np.maximum(np.abs(x) - alpha*self.l, 0.)

Remark: @jitclass alone does not necessarily speed up the code. The main speedup will come from for-loops, typically appearing in the solver.

Final steps

After implementing Quadratic and L1Norm as specific examples for f and r, we can now implement a Numba version of proximal gradient descent. We can pretty much copy the code and simply add the @njit decorator.

@njit()
def fast_prox_gd(f, r, x0, alpha=0.1, max_iter=50):
  x = x0.copy()

  for i in range(max_iter):
    y = x-alpha*f.grad(x)
    x = r.prox(y,alpha)

  return x

Some remarks on the @njit decorator (mainly a reminder to myself):

  • When calling a JIT-function for the first time, it will take longer as the code is compiled. For any subsequent call, the runtime should be much smaller.
  • If no types are specified, the function is compiled for the argument types of the first call. If it is later called with different types (e.g. float instead of int), Numba compiles an additional specialization for them, which again takes some time.
  • Useful resources when uncertain whether some method or data type is supported in Numba: the documentation pages on supported Python features and supported NumPy features.

Runtime comparison

I implemented the pure Python and the Numba version of proximal gradient descent in this notebook.

For a simple 50-dimensional example with f being a quadratic function and r the 1-norm, we get the following result:

# Python version
%timeit prox_gd(f, r, x0, alpha=0.001, max_iter=20000)
10 loops, best of 5: 164 ms per loop
# Numba version
%timeit fast_prox_gd(f, r, x0, alpha=0.001, max_iter=20000)
10 loops, best of 5: 54.2 ms per loop

Even for this simple example, we already get a speedup factor of about 3. Of course, how much speedup is possible depends on how much of the runtime is due to loop overhead rather than heavy numerical work (e.g. matrix-vector multiplication in high dimensions).

If the gradient or prox computation involves for-loops (e.g. Condat’s algorithm for total variation regularization), using numba will result in significant speedups in my experience. The outlined approach can also be applied to stochastic algorithms where the number of iterations and thus the speedup is typically large. You can find some standard algorithms such as SGD, SAGA or SVRG in this repository.

Thanks for reading!

  • A package with a similar approach as described is copt.
  • numba and cython: http://gouthamanbalaraman.com/blog/optimizing-python-numba-vs-cython.html