
Add discrete bouncy particle sampler (Sherlock & Thiery) #837

Description

@junpenglao

Algorithm category: MCMC (non-reversible, gradient-based)

Paper reference

"A Discrete Bouncy Particle Sampler" — Chris Sherlock, Alexandre H. Thiery. Biometrika, Volume 109, Issue 2, June 2022, Pages 335–349. arXiv preprint: https://arxiv.org/abs/1707.05200

Existing implementations

  • No public JAX implementation found.
  • The paper's supplementary material contains R code.

Benefit and motivation

The Discrete Bouncy Particle Sampler (DBPS) is a non-reversible, discrete-time MCMC algorithm inspired by the continuous-time Bouncy Particle Sampler. Unlike the latter, it requires only pointwise evaluations of the target density and its gradient: no Poisson-process simulation or local upper bounds on the gradient are needed. It combines a guided random walk, a partial refreshment of the direction (velocity), and a gradient-based delayed-rejection bounce step, which together give non-diffusive, directionally persistent exploration of the state space. Adding DBPS would give BlackJAX a non-reversible discrete-time sampler that complements the existing HMC/NUTS family and broadens the set of geometries BlackJAX can explore efficiently.
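As we read the paper, one iteration on the state (x, v) can be summarised as below. The symbols δ (step size) and ρ (partial-refreshment parameter) are our notation, and the AR(1) form of the refreshment, its placement at the start of the iteration, and the use of the gradient at the rejected point x' are assumptions to check against the paper before implementing. The second-stage probability follows from the fact that the reverse move from (x'', -v') first proposes x' and then bounces back to x.

```latex
\begin{aligned}
&\text{Refresh: } v \leftarrow \rho\, v + \sqrt{1-\rho^2}\,\xi, \qquad \xi \sim \mathcal{N}(0, I)\\
&\text{Stage 1: } x' = x + \delta v, \qquad \alpha_1(x, x') = \min\{1,\ \pi(x')/\pi(x)\}\\
&\text{Bounce: } v' = v - 2\,\frac{\langle v, \nabla\log\pi(x')\rangle}{\|\nabla\log\pi(x')\|^2}\,\nabla\log\pi(x'), \qquad x'' = x' + \delta v'\\
&\text{Stage 2: } \alpha_2 = \min\Big\{1,\ \frac{\pi(x'')\,[1 - \alpha_1(x'', x')]}{\pi(x)\,[1 - \alpha_1(x, x')]}\Big\}\\
&\text{Output: } (x', v) \text{ w.p. } \alpha_1;\ \text{else } (x'', v') \text{ w.p. } \alpha_2;\ \text{else } (x, -v)
\end{aligned}
```

Negating the velocity on final rejection, instead of keeping it, is what makes the chain non-reversible.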

Comparison to existing BlackJAX algorithms

  • Closest existing: blackjax.hmc (gradient-based, reversible) or the continuous bouncy particle sampler (continuous-time, non-reversible, requires Poisson process simulation)
  • Advantage: Non-reversible dynamics give non-diffusive, persistent exploration; simpler to implement than continuous-time piecewise-deterministic Markov processes (PDMPs), since no event-time simulation or gradient upper bounds are needed; a scaling limit established in a Gaussian setting yields a simple, principled criterion for tuning the partial-refreshment parameter.
  • Limitation: In the bounce case, the delayed-rejection step costs one gradient evaluation (at the rejected proposal) plus an extra target-density evaluation; the performance advantage over HMC is problem-dependent and most pronounced for near-Gaussian targets.

Estimated JAX implementation effort

M — The algorithm consists of a guided random walk step, a partial refreshment of the velocity, and a delayed-rejection bounce step. All three components are straightforward to implement in JAX. The main challenge is the delayed-rejection logic, which can be handled with lax.cond.
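The delayed-rejection logic reduces to two small pieces: reflecting the velocity in the hyperplane orthogonal to the gradient, and computing the second-stage acceptance probability. The sketch below assumes the bounce uses the gradient at the rejected first-stage proposal (our reading of the paper); `reflect` and `log_accept_bounce` are hypothetical helper names, not BlackJAX API. The demo checks, on an isotropic Gaussian, that the reverse path retraces the forward one, which is what makes the second-stage ratio valid.

```python
import jax
import jax.numpy as jnp


def reflect(velocity, grad):
    """Reflect `velocity` in the hyperplane orthogonal to `grad` (identity if grad == 0)."""
    sq_norm = jnp.sum(grad**2)
    safe_sq_norm = jnp.where(sq_norm > 0, sq_norm, 1.0)  # avoid 0/0 at stationary points
    coef = jnp.where(sq_norm > 0, 2.0 * jnp.vdot(velocity, grad) / safe_sq_norm, 0.0)
    return velocity - coef * grad


def log_accept_bounce(logp_x, logp_x1, logp_x2):
    """Log second-stage acceptance probability, used only after stage 1 rejects.

    x1 = x + step * v is the rejected proposal and x2 = x1 + step * v2 with
    v2 = reflect(v, grad log pi(x1)). The reverse move from (x2, -v2) first
    proposes x1, so its stage-1 rejection probability enters the numerator.
    """
    log_a1_fwd = jnp.minimum(0.0, logp_x1 - logp_x)
    log_a1_rev = jnp.minimum(0.0, logp_x1 - logp_x2)
    log_ratio = (logp_x2 + jnp.log1p(-jnp.exp(log_a1_rev))
                 - logp_x - jnp.log1p(-jnp.exp(log_a1_fwd)))
    return jnp.minimum(0.0, log_ratio)


# Demo: isotropic Gaussian target, one forward bounce.
def logdensity(x):
    return -0.5 * jnp.sum(x**2)


step = 0.5
x = jnp.array([1.0, 0.0])
v = jnp.array([0.6, 0.8])
x1 = x + step * v                  # first-stage proposal (assumed rejected)
g1 = jax.grad(logdensity)(x1)      # gradient at the rejected point
v2 = reflect(v, g1)                # bounced velocity, same norm as v
x2 = x1 + step * v2                # second-stage proposal
log_a2 = log_accept_bounce(logdensity(x), logdensity(x1), logdensity(x2))

# Reverse path from (x2, -v2): proposes x1 first, then bounces back to x.
assert jnp.allclose(x2 - step * v2, x1, atol=1e-6)
assert jnp.allclose(x1 + step * reflect(-v2, g1), x, atol=1e-6)
```

For a spherically symmetric target such as this one, `x2` lies at the same radius as `x` and the two stage-1 terms coincide, so `log_a2` is (numerically) 0 and the bounce is always accepted; this does not hold for general targets.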

JAX-specific implementation notes

The guided random walk and velocity refreshment are simple arithmetic operations on position and velocity arrays and are fully jax.jit-compatible. The delayed-rejection bounce can use lax.cond to choose between accepting the first proposal and attempting the bounce (reflecting the velocity and, if the second stage also rejects, negating it), avoiding Python-level control flow. The state carries (position, velocity); no lax.while_loop or custom VJP is needed. The partial-refreshment parameter can be tuned online using the scaling-limit result from the paper.
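A minimal end-to-end sketch of such a jit-compatible kernel follows. `DBPSState`, `dbps_step` and `run_chain` are hypothetical names, not BlackJAX API; the AR(1) refreshment `rho * v + sqrt(1 - rho**2) * xi`, applied at the start of each iteration, and the reflection at the rejected point are assumptions based on our reading of the paper. The chain is driven by `lax.scan` and checked on a standard 2-D Gaussian.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class DBPSState(NamedTuple):
    """Hypothetical sampler state: the (position, velocity) pair carried between steps."""
    position: jax.Array
    velocity: jax.Array


def _reflect(velocity, grad):
    # Reflection in the hyperplane orthogonal to `grad`; identity if grad == 0.
    sq_norm = jnp.sum(grad**2)
    safe = jnp.where(sq_norm > 0, sq_norm, 1.0)
    coef = jnp.where(sq_norm > 0, 2.0 * jnp.vdot(velocity, grad) / safe, 0.0)
    return velocity - coef * grad


def dbps_step(rng_key, state, logdensity_fn, step_size, rho):
    """One DBPS-style iteration (hypothetical helper, not a BlackJAX API)."""
    key_refresh, key_u1, key_u2 = jax.random.split(rng_key, 3)
    x = state.position
    # Partial refreshment (assumed AR(1) form); leaves N(0, I) invariant.
    xi = jax.random.normal(key_refresh, x.shape, x.dtype)
    v = rho * state.velocity + jnp.sqrt(1.0 - rho**2) * xi

    # Stage 1: guided random-walk proposal along the current velocity.
    logp_x = logdensity_fn(x)
    x1 = x + step_size * v
    logp_x1 = logdensity_fn(x1)
    log_a1 = jnp.minimum(0.0, logp_x1 - logp_x)
    accept1 = jnp.log(jax.random.uniform(key_u1)) < log_a1

    def bounce():
        # Stage 2: reflect v using the gradient at the rejected point x1.
        v2 = _reflect(v, jax.grad(logdensity_fn)(x1))
        x2 = x1 + step_size * v2
        logp_x2 = logdensity_fn(x2)
        log_a1_rev = jnp.minimum(0.0, logp_x1 - logp_x2)  # reverse path proposes x1 first
        log_a2 = jnp.minimum(
            0.0,
            logp_x2 + jnp.log1p(-jnp.exp(log_a1_rev))
            - logp_x - jnp.log1p(-jnp.exp(log_a1)),
        )
        accept2 = jnp.log(jax.random.uniform(key_u2)) < log_a2
        # If both stages reject, stay put and negate the velocity.
        return DBPSState(jnp.where(accept2, x2, x), jnp.where(accept2, v2, -v))

    # The gradient is only evaluated when stage 1 rejects (under vmap both branches run).
    return jax.lax.cond(accept1, lambda: DBPSState(x1, v), bounce)


def run_chain(rng_key, logdensity_fn, x0, num_steps, step_size=0.5, rho=0.9):
    key_init, key_chain = jax.random.split(rng_key)
    v0 = jax.random.normal(key_init, x0.shape, x0.dtype)

    def one_step(state, key):
        new_state = dbps_step(key, state, logdensity_fn, step_size, rho)
        return new_state, new_state.position

    keys = jax.random.split(key_chain, num_steps)
    _, positions = jax.lax.scan(one_step, DBPSState(x0, v0), keys)
    return positions


# Usage: standard 2-D Gaussian target.
def logdensity(x):
    return -0.5 * jnp.sum(x**2)


samples = run_chain(jax.random.PRNGKey(0), logdensity, jnp.zeros(2), num_steps=50_000)
print(samples.mean(axis=0), samples.var(axis=0))
```

Using `lax.cond` rather than computing both stages with `jnp.where` skips the gradient evaluation whenever the first stage accepts; under `jax.vmap` a batched `lax.cond` is lowered to a select, so that saving is lost for vectorised chains.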

Willing to open a PR?

No — filing for community interest (re-filed from #386)


Re-filed from #386, which was closed as stale. Using the new structured proposal format.

Metadata


Assignees

No one assigned

    Labels

    sampler (Issue related to samplers)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
