Algorithm category: MCMC (non-reversible, gradient-based)
Paper reference
"A Discrete Bouncy Particle Sampler" — Chris Sherlock, Alexandre H. Thiery. Biometrika, Volume 109, Issue 2, June 2022, Pages 335–349. arXiv preprint: https://arxiv.org/abs/1707.05200
Existing implementations
- No public JAX implementation found.
- The paper's supplementary material contains R code.
Benefit and motivation
The Discrete Bouncy Particle Sampler (DBPS) is a non-reversible, discrete-time MCMC algorithm inspired by the continuous-time Bouncy Particle Sampler. Unlike the continuous-time sampler, it requires only pointwise evaluations of the target density and its gradient: no Poisson process simulation or local upper bounds on the gradient are needed. It combines a gradient-guided random walk, a partial refreshment of the velocity, and a delayed-rejection bounce step to achieve non-diffusive, directionally persistent exploration of the state space. Adding DBPS would give BlackJAX a non-reversible sampler that needs no event-time simulation, complementing the existing HMC/NUTS family and broadening the range of target geometries BlackJAX can explore efficiently.
Comparison to existing BlackJAX algorithms
- Closest existing: blackjax.hmc (gradient-based, reversible). Outside BlackJAX, the closest relative is the continuous-time Bouncy Particle Sampler (non-reversible, but requires Poisson process simulation).
- Advantage: Non-reversible dynamics give non-diffusive, persistent exploration; simpler to implement than continuous-time PDMPs (no event-time simulation or upper bounds on the gradient); the paper derives a scaling limit in a Gaussian setting, which yields a simple, principled tuning criterion for the partial-refreshment parameter.
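- Limitation: When the first proposal is rejected, the delayed-rejection bounce needs a gradient evaluation at the rejected point plus an extra target-density evaluation, so those iterations cost more than a plain random-walk step. The paper's theoretical efficiency results are derived mainly in a Gaussian setting, so any advantage over HMC is problem-dependent and would need to be checked empirically.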
Estimated JAX implementation effort
M — The algorithm consists of a guided random walk step, a partial refreshment of the velocity, and a delayed-rejection bounce step. All three components are straightforward to implement in JAX. The main challenge is the delayed-rejection logic, which can be handled with lax.cond.
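As a reference for the delayed-rejection logic, one DBPS iteration can be written as follows. This is our reading of the construction rather than a transcription of the paper: the symbols δ (step size) and ρ (partial-refreshment parameter), the second-stage proposal x'' = x' + δv', and the Gaussian velocity with autoregressive refreshment are assumptions made here, and the paper's parameterization may differ.

```latex
\begin{aligned}
&\alpha_1(a, b) = \min\Bigl\{1,\ \frac{\pi(b)}{\pi(a)}\Bigr\},
  \qquad x' = x + \delta v, \\
&g = \nabla \log \pi(x'), \qquad
  v' = v - 2\,\frac{\langle v, g\rangle}{\langle g, g\rangle}\, g, \qquad
  x'' = x' + \delta v', \\
&\alpha_2 = \min\Bigl\{1,\ \frac{\pi(x'')\,\bigl[1 - \alpha_1(x'', x')\bigr]}
  {\pi(x)\,\bigl[1 - \alpha_1(x, x')\bigr]}\Bigr\}, \\
&(x, v) \leftarrow
  \begin{cases}
    (x', v) & \text{with probability } \alpha_1(x, x'),\\
    (x'', v') & \text{else, with probability } \alpha_2,\\
    (x, -v) & \text{otherwise,}
  \end{cases} \\
&v \leftarrow \rho\, v + \sqrt{1 - \rho^2}\, z, \qquad z \sim \mathcal{N}(0, I).
\end{aligned}
```

Because the reverse path from (x'', -v') passes back through the same rejected point x', the second stage reuses the gradient at x'. lax.cond then only has to choose among three outcomes: accept x', accept x'', or stay at x with the velocity negated.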
JAX-specific implementation notes
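The guided random walk and velocity refreshment are simple arithmetic operations on position and velocity arrays, so they are fully jax.jit-compatible. The delayed-rejection bounce (reflect the velocity in the gradient at the rejected point, try a second proposal, and negate the velocity if that is also rejected) can select among its outcomes with lax.cond, avoiding Python-level control flow. The state carries (position, velocity), optionally with the cached log-density; no lax.while_loop or custom VJP is needed. The paper's Gaussian scaling-limit analysis yields a simple tuning criterion for the partial-refreshment parameter, which could form the basis of an adaptation scheme.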
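A minimal jit-compatible sketch of one transition, assuming a standard Gaussian velocity refreshed as v ← ρv + sqrt(1 − ρ²)z and a second-stage proposal x'' = x' + δv' (our reading of the paper). DBPSState, init, build_kernel, reflect, step_size and rho are hypothetical names chosen for illustration; they are not part of the BlackJAX API, and a real PR would follow BlackJAX's own kernel conventions.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class DBPSState(NamedTuple):
    # Hypothetical state layout (not the BlackJAX API).
    position: jax.Array
    velocity: jax.Array
    logdensity: jax.Array  # cached log pi(position)


def reflect(v, g):
    """Bounce: reflect v in the hyperplane orthogonal to the gradient g."""
    gg = jnp.maximum(jnp.dot(g, g), jnp.finfo(g.dtype).tiny)
    return v - 2.0 * jnp.dot(v, g) / gg * g


def init(position, logdensity_fn: Callable, rng_key) -> DBPSState:
    velocity = jax.random.normal(rng_key, position.shape, position.dtype)
    return DBPSState(position, velocity, logdensity_fn(position))


def build_kernel(logdensity_fn: Callable, step_size: float, rho: float):
    grad_fn = jax.grad(logdensity_fn)

    def kernel(rng_key, state: DBPSState) -> DBPSState:
        key_u1, key_u2, key_z = jax.random.split(rng_key, 3)
        x, v, logp_x = state

        # Stage 1: guided random-walk proposal along the current velocity.
        x1 = x + step_size * v
        logp_x1 = logdensity_fn(x1)
        log_a1 = jnp.minimum(0.0, logp_x1 - logp_x)
        accept1 = jnp.log(jax.random.uniform(key_u1)) < log_a1

        def bounce():
            # Stage 2 (delayed rejection): reflect v in the gradient at the
            # rejected point x1, then propose x2 = x1 + step_size * v2.
            v2 = reflect(v, grad_fn(x1))
            x2 = x1 + step_size * v2
            logp_x2 = logdensity_fn(x2)
            # The reverse path from (x2, -v2) first proposes x1 again, so its
            # stage-1 rejection probability is 1 - min(1, pi(x1) / pi(x2)).
            log_rej_fwd = jnp.log1p(-jnp.exp(log_a1))
            log_rej_rev = jnp.log1p(-jnp.exp(jnp.minimum(0.0, logp_x1 - logp_x2)))
            log_a2 = jnp.minimum(0.0, logp_x2 + log_rej_rev - logp_x - log_rej_fwd)
            accept2 = jnp.log(jax.random.uniform(key_u2)) < log_a2
            # If both stages reject, stay at x and negate the velocity.
            return jax.lax.cond(
                accept2, lambda: (x2, v2, logp_x2), lambda: (x, -v, logp_x)
            )

        x_new, v_new, logp_new = jax.lax.cond(
            accept1, lambda: (x1, v, logp_x1), bounce
        )

        # Assumption: partial refreshment v <- rho * v + sqrt(1 - rho^2) * z,
        # which leaves N(0, I) invariant; the paper's form may differ.
        z = jax.random.normal(key_z, v_new.shape, v_new.dtype)
        v_new = rho * v_new + jnp.sqrt(1.0 - rho**2) * z
        return DBPSState(x_new, v_new, logp_new)

    return kernel
```

The kernel can be driven with jax.lax.scan over split keys, for example on a standard Gaussian target `logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)`. Caching the log-density in the state means each iteration evaluates the target once when the first proposal is accepted, and twice plus one gradient in the bounce case.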
Willing to open a PR?
No — filing for community interest (re-filed from #386)
Re-filed from #386, which was closed as stale. Using the new structured proposal format.