Algorithm category
MCMC (gradient-based)
Paper reference
Title: Divide, Interact, Sample: The Two-System Paradigm
Authors: James Chok, Myung Won Lee, Daniel Paulin, Geoffrey M. Vasil
Year: 2026
Link: https://arxiv.org/abs/2509.09162
Existing implementations
Official repo: https://github.com/paulindani/MAKLA_JAX
Implementation is already in JAX.
Benefit and motivation
The paper introduces:
(1) a new sampler, MAKLA-BCSS2, a variant of GHMC with a fixed (deterministic) cost of 2 gradient evaluations and 1 likelihood evaluation per step;
(2) new ensemble adaptation methods for finding a good mass matrix, which gives better preconditioning.
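To make the fixed per-step cost concrete, here is a minimal JAX sketch of a generic Metropolis-adjusted GHMC step built on a two-stage palindromic splitting integrator. It is not the paper's MAKLA-BCSS2 and not the official MAKLA_JAX code. The coefficient `B_BCSS2 = 0.211781` is the standard two-stage value from Blanes, Casas and Sanz-Serna. The Gaussian `logdensity`, the identity diagonal `inv_mass`, the step size `h`, the refresh parameter `alpha`, and all function names are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp

# Two-stage BCSS splitting coefficient (Blanes, Casas & Sanz-Serna).
# Illustrative only: the paper's MAKLA-BCSS2 may differ in its details.
B_BCSS2 = 0.211781


def logdensity(x):
    # Hypothetical target for illustration: a standard Gaussian.
    return -0.5 * jnp.sum(x**2)


grad_fn = jax.grad(logdensity)
value_and_grad_fn = jax.value_and_grad(logdensity)


def bcss2_step(x, p, g, h, inv_mass):
    """Palindromic kick-drift-kick-drift-kick step.

    Reuses the cached gradient g at x, so each call makes exactly 2 new
    gradient evaluations; the second also returns the log density.
    """
    p = p + B_BCSS2 * h * g
    x = x + 0.5 * h * inv_mass * p
    g = grad_fn(x)
    p = p + (1.0 - 2.0 * B_BCSS2) * h * g
    x = x + 0.5 * h * inv_mass * p
    logp, g = value_and_grad_fn(x)
    p = p + B_BCSS2 * h * g
    return x, p, logp, g


def kinetic(p, inv_mass):
    # Kinetic energy 0.5 * p^T M^{-1} p for a diagonal mass matrix M.
    return 0.5 * jnp.sum(inv_mass * p**2)


def ghmc_step(state, key, h, alpha, inv_mass):
    """One Metropolis-adjusted GHMC step with partial momentum refresh."""
    x, p, logp, g = state
    key_refresh, key_accept = jax.random.split(key)
    # Partial refresh: p <- alpha * p + sqrt(1 - alpha^2) * N(0, M), M = 1 / inv_mass.
    noise = jax.random.normal(key_refresh, x.shape) / jnp.sqrt(inv_mass)
    p = alpha * p + jnp.sqrt(1.0 - alpha**2) * noise
    x_new, p_new, logp_new, g_new = bcss2_step(x, p, g, h, inv_mass)
    log_ratio = (logp_new - kinetic(p_new, inv_mass)) - (logp - kinetic(p, inv_mass))
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    # On rejection keep the position and flip the momentum, as in standard GHMC.
    x = jnp.where(accept, x_new, x)
    p = jnp.where(accept, p_new, -p)
    logp = jnp.where(accept, logp_new, logp)
    g = jnp.where(accept, g_new, g)
    return (x, p, logp, g), accept


def run_chains(key, n_chains, dim, n_steps, h=0.8, alpha=0.5):
    """Run n_chains independent chains in lockstep with vmap + scan."""
    inv_mass = jnp.ones(dim)  # identity mass matrix (assumption)
    key_init, key_run = jax.random.split(key)
    x0 = jnp.zeros((n_chains, dim))
    p0 = jax.random.normal(key_init, (n_chains, dim))
    logp0, g0 = jax.vmap(value_and_grad_fn)(x0)
    # One independent key sequence per chain: shape (n_chains, n_steps, ...).
    chain_keys = jax.random.split(key_run, n_chains)
    keys = jax.vmap(lambda k: jax.random.split(k, n_steps))(chain_keys)

    def one_chain(state, step_keys):
        def step(s, k):
            return ghmc_step(s, k, h, alpha, inv_mass)

        return jax.lax.scan(step, state, step_keys)

    final_state, accepts = jax.vmap(one_chain)((x0, p0, logp0, g0), keys)
    return final_state[0], jnp.mean(accepts)


x_final, acc_rate = run_chains(jax.random.PRNGKey(0), n_chains=512, dim=2, n_steps=300)
```

Because every chain performs the same fixed amount of work per step, `jax.lax.scan` combined with `jax.vmap` batches the chains with no lanes left idle. NUTS is different: its trajectory length varies per chain and per iteration, so batched lanes must wait for the longest tree.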
The paper (https://arxiv.org/abs/2509.09162) reports significant efficiency improvements over BlackJAX NUTS on 48 posteriors from PosteriorDB.
Comparison to existing BlackJAX algorithms
Closest existing method: NUTS
Advantage over it: a deterministic computation cost per iteration, which makes it more efficient under VMAP and easier to implement efficiently, and better statistical efficiency measured as gradient evaluations per effective sample (grad/ESS).
Disadvantage compared to NUTS: none that I am aware of.
Estimated JAX implementation effort
S — standard ops, fits existing pattern
JAX-specific implementation notes
No response
Are you willing to open a PR?
Yes — I can implement this