When working with PyMC3 you often find yourself looking at a trace that you know isn’t going to converge anyway. Wouldn’t it be great if there was an automatic way to detect that and stop the sampling process?
I was recently discussing probabilistic programming with PyMC3 during one of my trainings. We were testing out several different models, one of which sampled very slowly and had a lot of divergences. While waiting, we discussed whether PyMC3 could detect these divergences and stop once they reached a certain threshold.
Unfortunately, PyMC3 didn’t offer anything like this built in, so I started hacking away. This resulted in a pull request that was recently merged, which adds a callback parameter to pm.sample. In this post, we’ll explore why I chose a callback and how you can use it.
There are a few ways to add early stopping to a package like PyMC3. The obvious one might be to add arguments to the pm.sample function. Using these arguments, we could set a threshold on the maximum number of divergences or on the minimum number of effective samples. PyMC3 would then check these thresholds after each draw and stop sampling as soon as one of them is reached. The signature of pm.sample would look something like this:
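```python
def sample(
    draws=500,
    tune=500,
    # ... all the other existing arguments ...
    max_divergences=None,        # stop once this many divergences have occurred
    min_effective_samples=None,  # stop once the effective sample size reaches this
):
    ...
```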
There’s an issue with the above implementation, though. For every possible early stopping criterion, we need to add an extra argument to the function. This might not seem like a big problem, but consider that the set of all possible early stopping criteria is likely a lot bigger than we know, and it grows over time.
Ines Montani (one of the developers behind spacy.io) summarized this problem during her terrific keynote at PyCon India. The gist of the talk is:
Make your tools programmable and let your users write code.
By allowing users of your tools to interact with your library programmatically, you suddenly give them all the power of the programming language, rather than just the power that exists in your library. A practical tip she gives for achieving this is to use callbacks. A callback is a function that gets passed to another function, which then calls it at a later time with a (partial) result. A pseudo-implementation of a callback in PyMC3’s sample function looks something like this:
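```python
import logging
import random

logger = logging.getLogger(__name__)

def sample(draws=1000, on_draw=None):
    trace = []
    for _ in range(draws):
        draw = random.gauss(0, 1)  # Stand-in for a real MCMC step
        trace.append(draw)
        if on_draw is not None:
            on_draw(trace, draw)  # Execute the callback
    return trace

def log_result(trace, draw):
    logger.info(f'current draw is {draw}')

sample(on_draw=log_result)
```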
Of course, there’s a lot more subtlety involved, but the essential parts are all there. What’s still missing, though, is the ability to stop sampling from within our callback. It turns out PyMC3 already catches KeyboardInterrupt exceptions (the exception Python raises when you press Ctrl+C) to wrap up the sampling process gracefully. We can use that exception as a cancellation signal inside our callback functions.
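To make the mechanism concrete, here is a toy, pure-Python sampler loop that behaves the way we want pm.sample to: it calls a callback after every draw and treats a KeyboardInterrupt raised by that callback as a request to stop, returning the draws collected so far. `toy_sample`, `stop_at_50`, and the random-walk "draws" are hypothetical stand-ins for illustration only; this is not PyMC3’s implementation, which is considerably more involved.

```python
import random


def toy_sample(draws, callback=None, seed=0):
    """Hypothetical stand-in for a sampler loop; NOT PyMC3's implementation."""
    rng = random.Random(seed)
    trace = []
    position = 0.0
    try:
        for _ in range(draws):
            # A fake "draw": one step of a Gaussian random walk
            position += rng.gauss(0.0, 1.0)
            trace.append(position)
            if callback is not None:
                callback(trace, position)
    except KeyboardInterrupt:
        # Same exception as Ctrl+C: stop early but keep the draws collected so far
        pass
    return trace


def stop_at_50(trace, draw):
    # Hypothetical callback: request a stop once 50 draws exist
    if len(trace) >= 50:
        raise KeyboardInterrupt()


short_trace = toy_sample(draws=500, callback=stop_at_50)
print(len(short_trace))  # 50
```

Because the loop catches the exception itself, the caller gets an ordinary return value instead of a traceback, which is what makes KeyboardInterrupt usable as a cancellation signal.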
This callback allows every PyMC3 user to build whatever stopping criteria they need, instead of relying on what is available out of the box.
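As an example, the divergence threshold from the training discussion fits in a few lines. This is a minimal sketch under assumptions: it supposes that each draw exposes a `tuning` flag and a `stats` list holding one dictionary of sampler statistics per step method, with NUTS reporting a boolean `'diverging'` entry (check this against your PyMC3 version). `StopAfterDivergences` and `FakeDraw` are hypothetical names; `FakeDraw` exists only so the sketch runs without PyMC3.

```python
from collections import namedtuple

# Hypothetical stand-in for PyMC3's draw object, with only the fields used below
FakeDraw = namedtuple("FakeDraw", ["chain", "tuning", "stats"])


class StopAfterDivergences:
    """Stop sampling once max_divergences post-tuning divergences occur (all chains)."""

    def __init__(self, max_divergences=10):
        self.max_divergences = max_divergences
        self.divergences = 0

    def __call__(self, trace, draw):
        if draw.tuning:
            return  # skip tuning draws, where the step size is still being adapted
        # Assumption: one stats dict per step method, NUTS reporting 'diverging'
        if any(step_stats.get("diverging", False) for step_stats in draw.stats):
            self.divergences += 1
        if self.divergences >= self.max_divergences:
            raise KeyboardInterrupt()


callback = StopAfterDivergences(max_divergences=2)
draws = [
    FakeDraw(chain=0, tuning=True, stats=[{"diverging": True}]),   # ignored: tuning
    FakeDraw(chain=0, tuning=False, stats=[{"diverging": True}]),  # 1st divergence
    FakeDraw(chain=1, tuning=False, stats=[{"diverging": False}]),
    FakeDraw(chain=1, tuning=False, stats=[{"diverging": True}]),  # 2nd: stop here
]
stopped_at = None
for i, draw in enumerate(draws):
    try:
        callback(trace=None, draw=draw)
    except KeyboardInterrupt:
        stopped_at = i
        break
print(stopped_at)  # 3
```

With a real model, you would pass an instance such as `StopAfterDivergences(10)` as the `callback` argument of pm.sample.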
So how do we use this callback in practice? Some of the use-cases that feel obvious include:
We’ll start by defining a simple model.
A callback that stops sampling when we reach 100 samples is then added like this:
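```python
import pymc3 as pm

def my_callback(trace, draw):
    # trace holds the draws of the current chain so far
    if len(trace) >= 100:
        raise KeyboardInterrupt()

with model:  # the simple model defined earlier
    trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1)

assert len(trace) == 100
```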
Something subtle about this callback is that it gets called with the trace of a single chain. That means that if we sample multiple chains, we need some machinery to keep track of the other chains as well. We can build a callback that stops whenever we reach n total samples like this:
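```python
from collections import Counter

class StopNSamples:
    def __init__(self, n_samples):
        self.lengths = Counter()  # number of draws seen per chain
        self.n_samples = n_samples

    def __call__(self, trace, draw):
        self.lengths[draw.chain] += 1
        # Stop once the draws across all chains add up to n_samples
        if sum(self.lengths.values()) >= self.n_samples:
            raise KeyboardInterrupt()
```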
Here we use the fact that a draw has a chain attribute that tells us which chain the draw belongs to. We store the length of each chain and stop sampling once the total length reaches n_samples.
A more useful callback, which stops once the largest r_hat across all variables drops below a predefined threshold, can look like this:
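```python
import pymc3 as pm

class MaxRHatCallback:
    def __init__(self, every=1000, max_rhat=1.05):
        self.every = every        # check whenever a chain's length is a multiple of this
        self.max_rhat = max_rhat  # stop once every r_hat is below this value
        self.traces = {}          # latest trace for each chain

    def __call__(self, trace, draw):
        # Ignore tuning draws
        if draw.tuning:
            return
        self.traces[draw.chain] = trace
        if len(trace) % self.every == 0:
            # Combine the per-chain traces so r_hat can compare them
            multitrace = pm.backends.base.MultiTrace(list(self.traces.values()))
            if pm.stats.rhat(multitrace).to_array().max() < self.max_rhat:
                raise KeyboardInterrupt()
```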
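For readers wondering what the max_rhat threshold measures: R-hat compares the variance between chains with the variance within them, and approaches 1 as the chains agree. The sketch below computes the classic Gelman-Rubin version for one scalar parameter in pure Python. It is a simplification, and `gelman_rubin_rhat` is a hypothetical helper: the r_hat that PyMC3 reports (via ArviZ) defaults to a rank-normalized, split-chain variant, so its numbers will differ.

```python
import math
import statistics


def gelman_rubin_rhat(chains):
    """Classic (non-split) Gelman-Rubin R-hat for one scalar parameter.

    chains: list of equal-length lists of draws, one per chain (m >= 2, n >= 2).
    """
    m = len(chains)
    n = len(chains[0])
    chain_means = [statistics.mean(c) for c in chains]
    grand_mean = statistics.mean(chain_means)
    # Between-chain variance B and mean within-chain variance W
    between = n / (m - 1) * sum((mu - grand_mean) ** 2 for mu in chain_means)
    within = statistics.mean(statistics.variance(c) for c in chains)
    # Pooled estimate of the posterior variance
    var_plus = (n - 1) / n * within + between / n
    return math.sqrt(var_plus / within)


agree = [[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]]
shifted = [[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]
print(round(gelman_rubin_rhat(agree), 2))    # 0.89
print(round(gelman_rubin_rhat(shifted), 2))  # 4.56
```

Chains that agree land at or slightly below 1 (the classic estimator can dip below 1 for very short chains), while the shifted pair lands far above the default threshold of 1.05 used in MaxRHatCallback.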
I’m looking forward to exploring more use-cases for this callback. Let me know on Twitter how you’re using it!