Add gradient-based HMC hyperparameter adaptation via Stein discrepancy (Campbell et al. 2021) #838

Description

@junpenglao

Algorithm category: Adaptation / tuning (MCMC)

Paper reference

"A Gradient Based Strategy for Hamiltonian Monte Carlo Hyperparameter Optimization" — Andrew Campbell, Wenlong Chen, Vincent Stimper, José Miguel Hernández-Lobato, Yichuan Zhang. Proceedings of the 38th International Conference on Machine Learning (ICML), PMLR 139:1238–1248, 2021. https://proceedings.mlr.press/v139/campbell21a.html

Existing implementations

Benefit and motivation

Existing HMC adaptation in BlackJAX (window adaptation, dual averaging) tunes the step size toward a target acceptance rate and estimates the mass matrix from warmup samples. Both are proxies that can be poor surrogates for actual mixing quality. Campbell et al. propose directly optimising the Sliced Kernelized Stein Discrepancy (SKSD) between the HMC chain's empirical distribution and the target, differentiating through the HMC transition operator and updating the hyperparameters with stochastic gradient descent. This gives a principled, metric-grounded adaptation scheme that is not tied to Gaussian geometry assumptions and can tune step size, number of leapfrog steps, and mass matrix jointly.

Comparison to existing BlackJAX algorithms

  • Closest existing: blackjax.window_adaptation (Stan-style dual averaging; optimises acceptance rate proxy)
  • Advantage: Directly minimises a convergence-to-target metric (SKSD) rather than a proxy; differentiates through the HMC kernel to get gradient information about hyperparameters; can tune all hyperparameters jointly.
  • Limitation: Requires differentiating through multiple HMC steps, which can be memory-intensive for long trajectories; SKSD evaluation adds per-iteration overhead that grows with sample size (quadratically for a full pairwise V-statistic); higher implementation complexity than dual averaging.

Estimated JAX implementation effort

L (large): requires differentiating through the leapfrog integrator and the accept/reject step with respect to the hyperparameters. The Metropolis-Hastings accept/reject step is an indicator function whose gradient is zero almost everywhere, so it needs jax.custom_vjp or a straight-through estimator for gradients to flow. SKSD computation is a kernel V-statistic that can be vectorised with jax.vmap, but the full adaptation loop is non-trivial to JIT-compile.
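To give a feel for the V-statistic part, here is a minimal sketch of a kernelized Stein discrepancy estimate vectorised with nested jax.vmap. It is an assumption-laden illustration, not the paper's estimator: it uses the plain (unsliced) KSD with a fixed-bandwidth RBF kernel, and `logdensity`, `rbf_kernel`, `stein_kernel` and `ksd_v_statistic` are hypothetical names, not BlackJAX API.

```python
import jax
import jax.numpy as jnp


def logdensity(x):
    # Standard normal target, chosen only for illustration.
    return -0.5 * jnp.sum(x ** 2)


def rbf_kernel(x, y, bandwidth=1.0):
    # Fixed-bandwidth RBF kernel (assumption; the paper uses a sliced construction).
    return jnp.exp(-0.5 * jnp.sum((x - y) ** 2) / bandwidth ** 2)


def stein_kernel(x, y):
    """Stein kernel u_p(x, y) for the target defined by `logdensity`."""
    score_x = jax.grad(logdensity)(x)
    score_y = jax.grad(logdensity)(y)
    k = rbf_kernel(x, y)
    grad_x_k = jax.grad(rbf_kernel, argnums=0)(x, y)
    grad_y_k = jax.grad(rbf_kernel, argnums=1)(x, y)
    # Matrix of mixed second derivatives d^2 k / (dx_j dy_i); only its trace is needed.
    mixed = jax.jacfwd(jax.grad(rbf_kernel, argnums=1), argnums=0)(x, y)
    return (
        k * jnp.dot(score_x, score_y)
        + jnp.dot(score_x, grad_y_k)
        + jnp.dot(score_y, grad_x_k)
        + jnp.trace(mixed)
    )


def ksd_v_statistic(samples):
    """Average the Stein kernel over all sample pairs, diagonal included."""
    pairwise = jax.vmap(
        lambda x: jax.vmap(lambda y: stein_kernel(x, y))(samples)
    )(samples)
    return jnp.mean(pairwise)


key = jax.random.PRNGKey(0)
samples = jax.random.normal(key, (200, 2))
ksd_on_target = jax.jit(ksd_v_statistic)(samples)        # samples drawn from the target
ksd_shifted = jax.jit(ksd_v_statistic)(samples + 3.0)    # samples far from the target
```

The pairwise matrix has n² entries, which is where the per-iteration overhead comes from; a real implementation would probably subsample or batch the pairs.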

JAX-specific implementation notes

  • Differentiating through the leapfrog integrator is straightforward via jax.grad.
  • The accept/reject step requires custom_vjp or a straight-through estimator for gradients to flow through the indicator function.
  • lax.scan handles the trajectory of L leapfrog steps within a single HMC proposal.
  • SKSD is evaluated using jax.vmap over a batch of chain samples.
  • Stopping gradients through ∇_x log p*(x) inside SKSD (as done in the paper) is implemented with jax.lax.stop_gradient.
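A rough sketch of how these pieces might fit together is shown below, under several assumptions: an identity mass matrix, a standard normal target, a log-parameterised step size, and a placeholder objective (the log density at the new state) standing in for the SKSD-based loss. The straight-through gate is one possible choice for the accept/reject step, not necessarily what the paper does, and every function name here is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def logdensity(x):
    # Standard normal target, chosen only for illustration.
    return -0.5 * jnp.sum(x ** 2)


def leapfrog(step_size, position, momentum, num_steps):
    """Run `num_steps` leapfrog steps with an identity mass matrix."""
    grad_fn = jax.grad(logdensity)

    def one_step(carry, _):
        q, p = carry
        p = p + 0.5 * step_size * grad_fn(q)  # half step in momentum
        q = q + step_size * p                 # full step in position
        p = p + 0.5 * step_size * grad_fn(q)  # half step in momentum
        return (q, p), None

    (q, p), _ = lax.scan(one_step, (position, momentum), None, length=num_steps)
    return q, p


def energy(q, p):
    return -logdensity(q) + 0.5 * jnp.sum(p ** 2)


def hmc_step_straight_through(log_step_size, key, position, num_steps=5):
    key_momentum, key_accept = jax.random.split(key)
    momentum = jax.random.normal(key_momentum, position.shape)
    step_size = jnp.exp(log_step_size)  # keeps the step size positive
    new_position, new_momentum = leapfrog(step_size, position, momentum, num_steps)

    log_ratio = energy(position, momentum) - energy(new_position, new_momentum)
    accept_prob = jnp.minimum(1.0, jnp.exp(log_ratio))
    hard = (jax.random.uniform(key_accept) < accept_prob).astype(position.dtype)
    # Straight-through gate: the forward value equals `hard` (up to rounding),
    # while the backward pass uses the gradient of `accept_prob`.
    gate = accept_prob + lax.stop_gradient(hard - accept_prob)
    return gate * new_position + (1.0 - gate) * position


def objective(log_step_size, key, position):
    # Placeholder loss; a real implementation would evaluate SKSD on a batch of chains.
    return logdensity(hmc_step_straight_through(log_step_size, key, position))


key = jax.random.PRNGKey(0)
position = jnp.array([1.0, -2.0])
grad_wrt_log_step = jax.jit(jax.grad(objective))(jnp.log(0.1), key, position)
```

Because `num_steps` is a plain Python integer, the scan length is static and the whole gradient computation can be JIT-compiled; making the number of leapfrog steps itself tunable would need a different treatment.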

Willing to open a PR?

No — filing for community interest (re-filed from #395)


Re-filed from #395, which was closed as stale. Using the new structured proposal format.
