Algorithm category: Adaptation / tuning (MCMC)
Paper reference
"A Gradient Based Strategy for Hamiltonian Monte Carlo Hyperparameter Optimization" — Andrew Campbell, Wenlong Chen, Vincent Stimper, José Miguel Hernández-Lobato, Yichuan Zhang. Proceedings of the 38th International Conference on Machine Learning (ICML), PMLR 139:1238–1248, 2021. https://proceedings.mlr.press/v139/campbell21a.html
Existing implementations
Benefit and motivation
Existing HMC adaptation in BlackJAX (window adaptation with dual averaging) tunes the step size toward a target acceptance rate and estimates the mass matrix from the empirical variance of warm-up samples. These are proxies, and acceptance rate in particular can be a poor surrogate for actual mixing quality. Campbell et al. propose directly optimising the Sliced Kernelized Stein Discrepancy (SKSD) between the HMC chain's empirical distribution and the target, minimising it with stochastic gradient descent by differentiating through the HMC transition operator. This gives a principled, metric-grounded adaptation scheme that is not tied to Gaussian geometry assumptions and can tune step size, number of leapfrog steps, and mass matrix jointly.
Comparison to existing BlackJAX algorithms
- Closest existing: blackjax.window_adaptation (Stan-style dual averaging; optimises an acceptance-rate proxy)
- Advantage: Directly minimises a convergence-to-target metric (SKSD) rather than a proxy; differentiates through the HMC kernel to get gradient information about hyperparameters; can jointly tune all hyperparameters simultaneously.
- Limitation: Requires differentiating through multiple HMC steps, which can be memory-intensive for long trajectories; SKSD evaluation adds per-iteration overhead proportional to sample size; higher implementation complexity than dual averaging.
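One standard way to bound the memory cost of backpropagating through long trajectories is rematerialisation. You wrap the leapfrog body in `jax.checkpoint` inside `lax.scan`, so that per-step intermediates are recomputed during the backward pass instead of being stored. This is a generic JAX technique, not something taken from the paper or from BlackJAX. The sketch below makes several assumptions: a toy standard-Gaussian target, an identity mass matrix, and hypothetical helper names (`leapfrog_step`, `trajectory_end`, `loss`).

```python
import jax
import jax.numpy as jnp


def logdensity(x):
    # Hypothetical target: standard Gaussian log-density, up to a constant.
    return -0.5 * jnp.sum(x ** 2)


def leapfrog_step(carry, _, step_size):
    # One leapfrog step with an identity mass matrix.
    x, p = carry
    p = p + 0.5 * step_size * jax.grad(logdensity)(x)
    x = x + step_size * p
    p = p + 0.5 * step_size * jax.grad(logdensity)(x)
    return (x, p), None


def trajectory_end(step_size, x0, p0, num_steps, remat):
    # With remat=True, each step's intermediates are recomputed on the
    # backward pass instead of being stored for all num_steps steps.
    step = jax.checkpoint(leapfrog_step) if remat else leapfrog_step
    body = lambda carry, xs: step(carry, xs, step_size)
    (x, _), _ = jax.lax.scan(body, (x0, p0), None, length=num_steps)
    return x


def loss(step_size, x0, p0, num_steps, remat=True):
    # Toy scalar objective: target log-density at the trajectory end point.
    return logdensity(trajectory_end(step_size, x0, p0, num_steps, remat))


x0 = jnp.array([1.0, -0.5])
p0 = jnp.array([0.3, 0.2])
grad_remat = jax.grad(loss)(0.1, x0, p0, 50, True)
grad_plain = jax.grad(loss)(0.1, x0, p0, 50, False)
```

Both calls compute the same gradient with respect to the step size. The checkpointed version trades extra forward computation for memory that no longer grows with stored per-step activations.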
Estimated JAX implementation effort
L: requires differentiating through the leapfrog integrator and the accept/reject step with respect to the hyperparameters. Because the MH accept/reject indicator is piecewise constant in the hyperparameters, its exact gradient is zero almost everywhere. JAX's custom_vjp machinery (or a straight-through estimator) is therefore needed to define useful gradients through that step. SKSD computation is a kernel V-statistic that can be vectorised with jax.vmap, but the full adaptation loop is non-trivial to JIT-compile.
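The sketch below shows one differentiable HMC transition, to make the accept/reject point concrete. It wraps the hard MH indicator in `jax.custom_vjp` with a straight-through backward rule. That rule substitutes the derivative of a sigmoid whose temperature is a hypothetical parameter. This surrogate is an assumption for illustration and is not necessarily the paper's treatment. The target, the toy objective (mean log-density after a few transitions) and all function names (`st_accept`, `hmc_step`, `toy_objective`) are also hypothetical. The proposal's actual objective is SKSD.

```python
from functools import partial

import jax
import jax.numpy as jnp


def logdensity(x):
    # Hypothetical target: standard Gaussian log-density, up to a constant.
    return -0.5 * jnp.sum(x ** 2)


@partial(jax.custom_vjp, nondiff_argnums=(2,))
def st_accept(log_u, log_alpha, temperature):
    # Forward pass: the hard MH indicator 1[log u < log alpha].
    return (jnp.asarray(log_u) < jnp.asarray(log_alpha)).astype(jnp.float32)


def _st_accept_fwd(log_u, log_alpha, temperature):
    z = (log_alpha - log_u) / temperature
    return st_accept(log_u, log_alpha, temperature), z


def _st_accept_bwd(temperature, z, g):
    # Backward pass (straight-through, assumed surrogate): use the
    # derivative of sigmoid(z) in place of the indicator's zero gradient.
    s = jax.nn.sigmoid(z)
    d = g * s * (1.0 - s) / temperature
    return (-d, d)  # cotangents for (log_u, log_alpha)


st_accept.defvjp(_st_accept_fwd, _st_accept_bwd)


def hmc_step(step_size, x, key, num_steps=10, temperature=0.1):
    k_mom, k_acc = jax.random.split(key)
    p0 = jax.random.normal(k_mom, x.shape)

    def leapfrog(carry, _):
        q, p = carry
        p = p + 0.5 * step_size * jax.grad(logdensity)(q)
        q = q + step_size * p
        p = p + 0.5 * step_size * jax.grad(logdensity)(q)
        return (q, p), None

    (q, p), _ = jax.lax.scan(leapfrog, (x, p0), None, length=num_steps)
    # log u < log_alpha is equivalent to u < min(1, exp(log_alpha)) since u < 1.
    log_alpha = (logdensity(q) - 0.5 * p @ p) - (logdensity(x) - 0.5 * p0 @ p0)
    log_u = jnp.log(jax.random.uniform(k_acc))
    a = st_accept(log_u, log_alpha, temperature)
    # Gradients reach step_size through q (weighted by a) and through a itself.
    return a * q + (1.0 - a) * x


def toy_objective(log_step_size, x0_batch, key, num_iters=5):
    # Toy scalar objective (hypothetical, NOT SKSD): mean target log-density
    # of a batch of chains after num_iters HMC transitions.
    step_size = jnp.exp(log_step_size)

    def run_chain(x0, chain_key):
        def body(x, k):
            return hmc_step(step_size, x, k), None
        x_final, _ = jax.lax.scan(body, x0, jax.random.split(chain_key, num_iters))
        return x_final

    keys = jax.random.split(key, x0_batch.shape[0])
    finals = jax.vmap(run_chain)(x0_batch, keys)
    return jnp.mean(jax.vmap(logdensity)(finals))


key = jax.random.PRNGKey(0)
x0_batch = 3.0 * jax.random.normal(key, (8, 2))
grad_log_eps = jax.grad(toy_objective)(jnp.log(0.2), x0_batch, jax.random.PRNGKey(1))
```

You can make the backward rule return zeros instead. That recovers the simpler option of ignoring the accept/reject gradient entirely, and in that case `custom_vjp` is not needed at all.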
JAX-specific implementation notes
Differentiating through the leapfrog integrator is straightforward via jax.grad. The accept/reject step requires custom_vjp or a straight-through estimator for gradients to flow through the indicator function. lax.scan handles the trajectory of L leapfrog steps within a single HMC proposal. SKSD is evaluated using jax.vmap over a batch of chain samples. Stopping gradients through ∇_x log p*(x) inside SKSD (as done in the paper) is implemented with jax.lax.stop_gradient.
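Below is a minimal sketch of an SKSD V-statistic. It is vectorised over test directions with `jax.vmap`, and the score is wrapped in `jax.lax.stop_gradient`. The sketch rests on several assumptions:

- The test directions r are the standard basis vectors.
- The unit-norm projection directions g_r are fixed and passed in as rows of `directions`. The maximisation over g_r from the original SKSD paper is omitted.
- The kernel is a 1-D RBF kernel with a fixed bandwidth.

`sksd_vstat` and its arguments are hypothetical names. Write a = g_r·x, b = g_r·x', s_r(x) = r·∇_x log p*(x) and c_r = r·g_r. Each direction then contributes the mean, over all sample pairs, of s_r(x) s_r(x') k(a,b) + c_r s_r(x) ∂_b k(a,b) + c_r s_r(x') ∂_a k(a,b) + c_r² ∂_a∂_b k(a,b).

```python
import jax
import jax.numpy as jnp


def sksd_vstat(samples, score_fn, directions, bandwidth=1.0):
    """V-statistic estimate of a sliced KSD with fixed projection directions.

    samples:    (n, d) chain samples.
    score_fn:   x -> grad_x log p*(x) for a single (d,) point.
    directions: (d, d); row r is the unit projection direction g_r paired
                with test direction r = e_r.
    """
    # Score values are treated as constants for differentiation.
    scores = jax.lax.stop_gradient(jax.vmap(score_fn)(samples))  # (n, d)
    h2 = bandwidth ** 2

    def per_direction(s_r, g_r, c_r):
        a = samples @ g_r                      # (n,) projections g_r . x
        diff = a[:, None] - a[None, :]         # a - b for every pair
        k = jnp.exp(-0.5 * diff ** 2 / h2)     # 1-D RBF kernel k(a, b)
        dk_da = -diff / h2 * k
        dk_db = diff / h2 * k
        d2k = (1.0 / h2 - diff ** 2 / h2 ** 2) * k
        u = (s_r[:, None] * s_r[None, :] * k
             + c_r * s_r[:, None] * dk_db
             + c_r * s_r[None, :] * dk_da
             + c_r ** 2 * d2k)
        return jnp.mean(u)                     # average over all n^2 pairs

    # Vectorise over the d test directions r = e_1, ..., e_d.
    per_r = jax.vmap(per_direction)(scores.T, directions, jnp.diagonal(directions))
    return jnp.sum(per_r)


key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (500, 2))
score = lambda x: -x  # score of a standard Gaussian target
on_target = sksd_vstat(z, score, jnp.eye(2))
off_target = sksd_vstat(z + 2.0, score, jnp.eye(2))
```

Samples from the target give a value close to zero. Shifted samples give a clearly larger value. Because only the score is stop-gradiented, gradients with respect to hyperparameters still flow through the samples via the kernel terms.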
Willing to open a PR?
No — filing for community interest (re-filed from #395)
Re-filed from #395 which was closed as stale. Using the new structured proposal format.