Early stopping with PyMC3’s sampling callback

When working with PyMC3 you often find yourself looking at a trace that you know isn’t going to converge anyway. Wouldn’t it be great if there was an automatic way to detect that and stop the sampling process?

Matthijs Brouns
01-26-2020

I was recently discussing probabilistic programming with PyMC3 during a training I was giving. We were testing out different models, one of which sampled very slowly and had a lot of divergences. While waiting, we discussed whether PyMC3 could detect these divergences and stop once they reached a certain threshold.

Unfortunately, PyMC3 didn’t offer anything like this out of the box, so I started hacking away. This resulted in a pull request that was recently merged. The pull request adds a callback parameter to pm.sample. In this post, we’ll explore why I chose a callback and how you can use it.

The implementation

There are a few options for adding early stopping to a package like PyMC3. The obvious way might be to add arguments to the pm.sample function. With these arguments, we could set thresholds such as a maximum number of divergences or a minimum number of effective samples. PyMC3 would then check these thresholds after each draw and stop sampling as soon as one of them is reached. The signature of pm.sample would look something like this:


def sample(
    draws=500,
    tune=500,
    # ... all the other existing arguments ...
    max_divergences=None,
    min_effective_samples=None,
):
    ...

There’s an issue with this approach, though: for every possible early stopping criterion, we need to add an extra argument to the function. This might not seem like a big problem, but the set of all useful early stopping criteria is probably much larger than we can anticipate, and it grows over time.

Ines Montani (one of the developers behind spacy.io) summarized this problem in her terrific keynote at PyCon India. The gist of the talk is:

Make your tools programmable and let your users write code.

By allowing users to interact with your library programmatically, you give them all the power of the programming language, rather than just the power that exists in your library. A practical tip she gives to achieve this is to use callbacks. A callback is a function that gets passed to another function, which then calls it at a later time with a (partial) result. A stripped-down sketch of a sampler that accepts a callback looks something like this:


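import logging
import random

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def new_draw(previous_draw):
    # Stand-in for one sampler step: a simple random walk
    return previous_draw + random.gauss(0, 1)


def sample(on_draw, draws=1000):
    trace = []  # stand-in for PyMC3's trace object
    draw = 0.0
    for _ in range(draws):
        draw = new_draw(draw)
        trace.append(draw)
        on_draw(trace, draw)  # Execute the callback with the partial result
    return trace


def log_result(trace, draw):
    logger.info(f"current draw is {draw}")


sample(on_draw=log_result, draws=10)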

Of course, the real implementation involves a lot more subtlety, but the essential parts are all there. What is still missing is a way to stop sampling from within our callback. It turns out PyMC3 already catches KeyboardInterrupt exceptions to wrap up the sampling process, so we can raise that exception inside our callback as a cancellation signal.
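To make the mechanism concrete, here is a minimal sketch in plain Python (not PyMC3’s actual code) of a sampling loop that treats a KeyboardInterrupt raised inside the callback as a request to stop, and returns whatever it has collected so far. The names sample_with_early_stopping and stop_after are hypothetical helpers for illustration.

```python
import random


def sample_with_early_stopping(draws, callback):
    """Toy sampler: a random walk whose loop can be cancelled from the callback."""
    trace = []
    value = 0.0
    try:
        for _ in range(draws):
            value += random.gauss(0, 1)  # stand-in for a real MCMC step
            trace.append(value)
            callback(trace, value)
    except KeyboardInterrupt:
        # Treat the interrupt as "stop now" and keep the partial trace,
        # similar in spirit to how PyMC3 wraps up an interrupted run.
        pass
    return trace


def stop_after(n):
    """Return a (hypothetical) callback that cancels sampling at n draws."""
    def callback(trace, value):
        if len(trace) >= n:
            raise KeyboardInterrupt()
    return callback


trace = sample_with_early_stopping(draws=500, callback=stop_after(100))
print(len(trace))  # 100
```

Because the sampler only has to catch an exception it already handles, the callback needs no special return value or extra API to signal “stop”.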

This callback allows every PyMC3 user to build whatever stopping criteria they need, instead of relying on what is available out of the box.

Usage

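So how do we use this callback in practice? Some use cases that feel obvious include stopping after a fixed number of draws, stopping when there are too many divergences, and stopping once the chains have converged according to r_hat.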

We’ll start by defining a simple model.

A callback that stops sampling when we reach 100 samples is then added like this:


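import pymc3 as pm


def my_callback(trace, draw):
    # `trace` holds the draws of a single chain so far
    if len(trace) >= 100:
        raise KeyboardInterrupt()  # tells PyMC3 to stop sampling


# `model` is assumed to be a pm.Model defined earlier
with model:
    trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1)

assert len(trace) == 100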

A subtle point is that the callback is called with the trace of a single chain. If we sample multiple chains, we need some extra bookkeeping to keep track of the other chains as well. A callback that stops once every chain has more than n samples can look like this:


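from collections import Counter


class StopNSamples:
    def __init__(self, n_samples):
        self.lengths = Counter()  # number of draws seen per chain
        self.n_samples = n_samples

    def __call__(self, trace, draw):
        self.lengths[draw.chain] += 1
        # Stop only once the shortest chain has more than n_samples draws
        if min(self.lengths.values()) > self.n_samples:
            raise KeyboardInterrupt()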

Here we use the fact that a draw has a chain attribute that tells us which chain the draw belongs to. We count the draws of each chain and only stop sampling once the shortest chain exceeds n_samples. A more useful callback stops as soon as the largest r_hat across all variables drops below a predefined threshold. It can look like this:


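import pymc3 as pm


class MaxRHatCallback:
    def __init__(self, every=1000, max_rhat=1.05):
        self.every = every  # only compute r_hat every `every` draws
        self.max_rhat = max_rhat
        self.traces = {}  # latest trace of each chain, keyed by chain id

    def __call__(self, trace, draw):
        # Skip tuning draws; by default they are discarded anyway
        if draw.tuning:
            return

        self.traces[draw.chain] = trace
        if len(trace) % self.every == 0:
            # Combine the per-chain traces so r_hat can compare the chains
            multitrace = pm.backends.base.MultiTrace(list(self.traces.values()))
            if pm.stats.rhat(multitrace).to_array().max() < self.max_rhat:
                raise KeyboardInterrupt()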
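The divergences that started this whole discussion can be handled the same way. Below is a minimal sketch of a callback that stops once the number of divergent draws across all chains exceeds a threshold. It assumes that draw.stats is a list of dictionaries (one per step method, or None for samplers without stats) and that NUTS reports a boolean "diverging" entry there; check this against the PyMC3 version you use. The class name MaxDivergencesCallback and its parameter are hypothetical.

```python
from collections import Counter, namedtuple


class MaxDivergencesCallback:
    """Stop sampling once more than `max_divergences` divergent draws are seen.

    Assumes `draw.stats` is a list of dicts (one per step method), or None,
    and that NUTS stores a boolean under the "diverging" key.
    """

    def __init__(self, max_divergences=10):
        self.max_divergences = max_divergences
        self.divergences = Counter()  # divergent draws counted per chain

    def __call__(self, trace, draw):
        if draw.tuning:
            return  # ignore divergences while the sampler is still adapting
        stats = draw.stats or []
        if any(s.get("diverging", False) for s in stats):
            self.divergences[draw.chain] += 1
        # Compare the total across all chains against the threshold
        if sum(self.divergences.values()) > self.max_divergences:
            raise KeyboardInterrupt()


# Exercise the callback without a sampler, using a stand-in for PyMC3's Draw
FakeDraw = namedtuple("FakeDraw", ["chain", "tuning", "stats"])

callback = MaxDivergencesCallback(max_divergences=2)
stopped_at = None
for i in range(10):
    draw = FakeDraw(chain=i % 2, tuning=False, stats=[{"diverging": True}])
    try:
        callback(trace=None, draw=draw)
    except KeyboardInterrupt:
        stopped_at = i
        break
print(stopped_at)  # 2
```

With a model in scope, you would pass an instance as pm.sample(callback=MaxDivergencesCallback(max_divergences=20)). Counting per chain costs nothing extra and lets you inspect afterwards which chains were misbehaving.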

I’m looking forward to exploring more use cases for this callback. Let me know on Twitter how you end up using it!