
Add AMAGOLD: amortized Metropolis-adjusted stochastic gradient MCMC #836

Description

@junpenglao

Algorithm category: SGMCMC

Paper reference

"AMAGOLD: Amortized Metropolis Adjustment for Efficient Stochastic Gradient MCMC" — Ruqi Zhang, A. Feder Cooper, Christopher De Sa. Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics (AISTATS), PMLR 108:2142–2152, 2020. https://arxiv.org/abs/2003.00193

Existing implementations

Benefit and motivation

Standard SGHMC and SGLD use mini-batch gradients and, with a fixed step size, converge to a biased stationary distribution. Removing the bias requires a decreasing step size, which slows mixing. AMAGOLD corrects this bias with infrequent Metropolis-Hastings (MH) correction steps. It amortizes the full-dataset cost of each MH step across many cheap stochastic-gradient updates, and it provably converges to the exact target distribution with a fixed step size. BlackJAX already implements SGHMC and SGLD, so adding AMAGOLD would give users a bias-free drop-in alternative for only a modest increase in implementation complexity.
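One outer iteration of the amortized correction can be written out as follows. This is a sketch of our reading of the paper's update, not a quote of it. The notation is assumed: ε is the step size, β the friction, K the number of mini-batch steps between corrections, U the full-data potential (negative log posterior) and Ũ its mini-batch estimate. Check the exact discretization against the algorithm box in the paper before implementing.

```latex
\begin{aligned}
\theta_{1/2} &= \theta_0 + \tfrac{\epsilon}{2}\, r_0, \qquad \rho = 0, \\
r_t &= \frac{(1-\epsilon\beta)\, r_{t-1} - \epsilon\, \nabla\tilde U(\theta_{t-1/2}) + \eta_t}{1+\epsilon\beta},
  \qquad \eta_t \sim \mathcal{N}(0,\, 4\epsilon\beta\, I), \qquad t = 1,\dots,K, \\
\rho &\leftarrow \rho + \tfrac{\epsilon}{2}\,(r_{t-1} + r_t)^{\top} \nabla\tilde U(\theta_{t-1/2}), \\
\theta_{t+1/2} &= \theta_{t-1/2} + \epsilon\, r_t \quad (t < K), \qquad
  \theta_K = \theta_{K-1/2} + \tfrac{\epsilon}{2}\, r_K, \\
a &= \min\!\bigl(1,\ \exp\bigl(U(\theta_0) - U(\theta_K) + \rho\bigr)\bigr).
\end{aligned}
```

With probability a, the chain moves to (θ_K, r_K). Otherwise it returns to (θ_0, −r_0). As a consistency check, set β = 0 and use exact gradients. Each inner step then satisfies r_{t−1} − r_t = ε∇U(θ_{t−1/2}), so ρ telescopes to ½‖r_0‖² − ½‖r_K‖² and a becomes the standard HMC acceptance probability for K leapfrog steps. With β > 0, choosing the noise variance 4εβ is what makes the friction and noise terms cancel from the MH ratio. Only the full-data potential difference and the mini-batch accumulator ρ remain.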

Comparison to existing BlackJAX algorithms

  • Closest existing: blackjax.sghmc (biased, no MH correction)
  • Advantage: Exact (asymptotically unbiased) posterior samples with a fixed step size; the convergence rate is at most a constant factor slower than a full-batch HMC baseline; well suited to Bayesian deep learning, where exact posteriors matter.
  • Limitation: Requires periodic full-dataset evaluations for the MH step, which can be expensive for very large datasets; the amortization frequency is an additional hyperparameter to tune.

Estimated JAX implementation effort

M: The core algorithm is an SGHMC-style inner loop with an additional periodic MH correction. The main JAX complexity is tracking the accumulated log-acceptance correction over multiple mini-batch steps and branching at the correction interval. Both can be handled with lax.cond and a step counter in the kernel state.
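Below is a minimal sketch of the counter-based design. It assumes a BlackJAX-style kernel that is called once per mini-batch, `step(rng_key, state, minibatch)`. The names `AmagoldCountedState`, `grad_estimator`, `full_potential` and `correction_interval` are hypothetical and are not existing BlackJAX API. The update rule (friction-damped momentum step, noise variance 4εβ, energy accumulator ρ) is our reading of the paper, not a reference implementation. The current sample is `state.anchor`; `theta_half` is only an intermediate leapfrog position.

```python
from functools import partial
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class AmagoldCountedState(NamedTuple):
    theta_half: jax.Array        # theta_{t-1/2}: leapfrog half-step position (not a sample)
    momentum: jax.Array          # r_{t-1}: momentum entering the next inner step
    anchor: jax.Array            # theta_0: current sample (last MH-accepted position)
    anchor_momentum: jax.Array   # r_0: momentum at the start of this trajectory
    anchor_potential: jax.Array  # cached full-data U(theta_0), reused by the next MH test
    rho: jax.Array               # energy accumulator built from mini-batch gradients only
    count: jax.Array             # inner steps taken since the last MH correction


def init(theta, momentum, full_potential, step_size):
    """Start a trajectory at (theta, momentum) with an empty accumulator."""
    return AmagoldCountedState(
        theta + 0.5 * step_size * momentum, momentum, theta, momentum,
        full_potential(theta), jnp.zeros((), theta.dtype), jnp.zeros((), jnp.int32))


def step(rng_key, state, minibatch, *, grad_estimator: Callable, full_potential: Callable,
         step_size: float, friction: float, correction_interval: int):
    """One mini-batch step; every `correction_interval`-th call also runs the MH test."""
    eps, beta = step_size, friction
    noise_key, accept_key = jax.random.split(rng_key)
    g = grad_estimator(state.theta_half, minibatch)
    noise = jnp.sqrt(4.0 * eps * beta) * jax.random.normal(noise_key, state.momentum.shape)
    r = ((1.0 - eps * beta) * state.momentum - eps * g + noise) / (1.0 + eps * beta)
    rho = state.rho + 0.5 * eps * jnp.vdot(state.momentum + r, g)
    theta_half = state.theta_half + eps * r
    count = state.count + 1

    def mh_correction():
        theta_end = theta_half - 0.5 * eps * r  # last position update is only a half step
        u_end = full_potential(theta_end)       # the only full-data evaluation per interval
        log_accept = state.anchor_potential - u_end + rho
        accept = jnp.log(jax.random.uniform(accept_key)) < log_accept
        theta = jnp.where(accept, theta_end, state.anchor)
        mom = jnp.where(accept, r, -state.anchor_momentum)  # flip momentum on rejection
        u = jnp.where(accept, u_end, state.anchor_potential)
        # Reset the counter and accumulator and open the next trajectory.
        return AmagoldCountedState(theta + 0.5 * eps * mom, mom, theta, mom, u,
                                   jnp.zeros_like(rho), jnp.zeros_like(count))

    def continue_trajectory():
        return AmagoldCountedState(theta_half, r, state.anchor, state.anchor_momentum,
                                   state.anchor_potential, rho, count)

    return jax.lax.cond(count == correction_interval, mh_correction, continue_trajectory)


# Toy usage: 2-D standard normal; the "mini-batch" is additive gradient noise.
def potential(x):
    return 0.5 * jnp.sum(x ** 2)


kernel = jax.jit(partial(step, grad_estimator=lambda x, batch: x + batch,
                         full_potential=potential, step_size=0.05, friction=0.0,
                         correction_interval=20))
```

With `friction=0.0` and a zero "mini-batch" (exact gradients), the toy kernel reduces to plain HMC with 20 leapfrog steps, which makes it a convenient sanity check. Under `jit`, `lax.cond` runs only the branch that is taken, so `full_potential` is evaluated once per `correction_interval` calls. If the counter is batched under `vmap`, `cond` is lowered to a select and both branches run on every call.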

JAX-specific implementation notes

The MH correction accumulates an energy term from the mini-batch gradients over K mini-batch steps, then applies lax.cond to accept or reject using the full-data potential. The step counter and the accumulated term are carried in the State NamedTuple. lax.scan naturally handles the inner loop of K SGHMC-style steps before each correction. The JaxSGMC implementation provides a working JAX reference.
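Below is a minimal sketch of the scan-based layout described above: `lax.scan` drives the K inner steps, and `lax.cond` makes a single accept/reject decision. It makes several assumptions. `minibatch_grad(key, theta)` is a hypothetical callable that draws its own mini-batch and returns a stochastic estimate of ∇U. `full_potential` is the full-data negative log-density. Momentum is carried across outer iterations and flipped on rejection, following our reading of the paper. The toy usage replaces real mini-batching with additive Gaussian gradient noise, so the MH step can be exercised without a dataset.

```python
from functools import partial
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class AmagoldState(NamedTuple):
    position: jax.Array  # theta: the current sample
    momentum: jax.Array  # r: carried across outer iterations, flipped on rejection


def amagold_outer_step(rng_key, state, *, minibatch_grad: Callable, full_potential: Callable,
                       step_size: float, friction: float, num_inner_steps: int):
    """K friction-damped mini-batch steps via lax.scan, then one full-data MH test."""
    eps, beta = step_size, friction
    scan_key, accept_key = jax.random.split(rng_key)
    theta0, r0 = state.position, state.momentum

    def inner(carry, key):
        theta_half, r_prev, rho = carry
        batch_key, noise_key = jax.random.split(key)
        g = minibatch_grad(batch_key, theta_half)  # stochastic estimate of grad U
        noise = jnp.sqrt(4.0 * eps * beta) * jax.random.normal(noise_key, r_prev.shape)
        r = ((1.0 - eps * beta) * r_prev - eps * g + noise) / (1.0 + eps * beta)
        rho = rho + 0.5 * eps * jnp.vdot(r_prev + r, g)
        return (theta_half + eps * r, r, rho), None

    carry0 = (theta0 + 0.5 * eps * r0, r0, jnp.zeros((), r0.dtype))
    keys = jax.random.split(scan_key, num_inner_steps)
    (theta_over, r_end, rho), _ = jax.lax.scan(inner, carry0, keys)
    theta_end = theta_over - 0.5 * eps * r_end  # the final position update is a half step

    log_accept = full_potential(theta0) - full_potential(theta_end) + rho
    accept = jnp.log(jax.random.uniform(accept_key)) < log_accept
    new_state = jax.lax.cond(
        accept,
        lambda: AmagoldState(theta_end, r_end),
        lambda: AmagoldState(theta0, -r0),  # reject: restore theta_0 and flip momentum
    )
    return new_state, accept, log_accept


# Toy usage: 2-D standard normal with artificially noisy gradients.
def potential(x):
    return 0.5 * jnp.sum(x ** 2)


def noisy_grad(key, x):
    return x + 0.5 * jax.random.normal(key, x.shape)


kernel = jax.jit(partial(amagold_outer_step, minibatch_grad=noisy_grad,
                         full_potential=potential, step_size=0.2, friction=0.5,
                         num_inner_steps=5))


def run(rng_key, state, num_samples):
    def body(s, key):
        s, accept, _ = kernel(key, s)
        return s, (s.position, accept)
    return jax.lax.scan(body, state, jax.random.split(rng_key, num_samples))
```

For readability, this layout recomputes `full_potential(theta0)` on every outer iteration. Caching it in the state would leave one full-data evaluation per K mini-batch steps. `num_inner_steps` plays the role of the amortization interval. A larger value spreads the full-data cost over more cheap steps, but the gradient noise accumulated in ρ grows with it, which tends to lower the acceptance rate.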

Willing to open a PR?

No — filing for community interest (re-filed from #375)


Re-filed from #375, which was closed as stale. Using the new structured proposal format.

Metadata

Labels: sampler (Issue related to samplers)
