Algorithm category: SGMCMC
Paper reference
"AMAGOLD: Amortized Metropolis Adjustment for Efficient Stochastic Gradient MCMC" — Ruqi Zhang, A. Feder Cooper, Christopher De Sa. Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics (AISTATS), PMLR 108:2142–2152, 2020. https://arxiv.org/abs/2003.00193
Existing implementations
Benefit and motivation
Standard SGHMC and SGLD use mini-batch gradients without a Metropolis-Hastings (MH) correction, so with a fixed step size they converge to a biased stationary distribution; removing the bias requires a diminishing step size, which slows mixing. AMAGOLD corrects this bias by applying infrequent MH correction steps, amortizing the full-dataset cost of each MH step across many cheap stochastic gradient updates, and provably converges to the exact target distribution with a fixed step size. BlackJAX already implements SGHMC and SGLD; adding AMAGOLD would give users a bias-free drop-in alternative with only a marginal increase in implementation complexity.
Comparison to existing BlackJAX algorithms
- Closest existing: blackjax.sghmc (biased, no MH correction)
- Advantage: Asymptotically exact posterior samples with a fixed step size; the convergence rate is at most a constant factor slower than a full-batch HMC baseline; well suited to Bayesian deep learning, where exact posteriors matter.
- Limitation: Requires periodic full-dataset evaluations for the MH step, which can be expensive for very large datasets; the amortization frequency is an additional hyperparameter to tune.
Estimated JAX implementation effort
M — The core algorithm is SGHMC with an additional periodic MH correction. The main JAX complexity is tracking the accumulated log-acceptance ratio over multiple mini-batch steps and branching at the correction interval, which can be handled with lax.cond and a step counter in the kernel state.
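To make the "accumulated log-acceptance ratio" concrete, here is a worked identity for the frictionless special case only (no friction, no injected noise). The symbols are introduced here for illustration: U is the full-batch negative log posterior, Ũ_t the mini-batch estimate used at inner step t, ε the step size, r the momentum and K the correction interval. The running sum ρ reuses the mini-batch gradients the inner steps already compute, so U itself is needed only at the start and end of each block of K steps. How AMAGOLD's friction and noise terms enter ρ is not shown here and should be checked against the algorithm box in the paper.

```math
\begin{aligned}
\theta_{t-\frac12} &= \theta_{t-1} + \tfrac{\epsilon}{2}\, r_{t-1}, \qquad
\tilde g_t = \nabla \tilde U_t\big(\theta_{t-\frac12}\big), \\
r_t &= r_{t-1} - \epsilon\, \tilde g_t, \qquad
\theta_t = \theta_{t-\frac12} + \tfrac{\epsilon}{2}\, r_t, \\
\rho &= \sum_{t=1}^{K} \tfrac{\epsilon}{2}\,(r_{t-1} + r_t)^\top \tilde g_t
      = \sum_{t=1}^{K} \tfrac12\,(r_{t-1} + r_t)^\top (r_{t-1} - r_t)
      = \tfrac12\|r_0\|^2 - \tfrac12\|r_K\|^2, \\
\log a &= U(\theta_0) - U(\theta_K) + \rho .
\end{aligned}
```

Because the sum telescopes, log a equals the total energy change H(θ_0, r_0) − H(θ_K, r_K) with H = U + ‖r‖²/2, i.e. the usual HMC acceptance log-ratio evaluated with the exact full-batch U at the endpoints.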
JAX-specific implementation notes
The MH correction accumulates a running log-acceptance term over K mini-batch steps, then applies lax.cond to accept or reject the resulting proposal. The step counter and accumulated ratio are carried in the State NamedTuple. lax.scan naturally handles the inner loop of K SGHMC steps before each correction. The JaxSGMC implementation provides a working JAX reference.
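A minimal sketch of the lax.scan / lax.cond structure described above, under loudly stated assumptions: it is not BlackJAX API and not the full AMAGOLD update. The names AmortizedMHState, make_kernel, full_energy and minibatch_grad are hypothetical. The inner step is a frictionless mini-batch leapfrog step, so AMAGOLD's friction and injected-noise terms are deliberately omitted. In that special case the accumulated term equals the kinetic-energy change exactly, and the sketch reduces to HMC with mini-batch gradients plus a full-batch MH test every num_inner steps. The toy Gaussian-mean problem at the bottom is illustrative only.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class AmortizedMHState(NamedTuple):
    """Hypothetical kernel state; the field names are illustrative, not BlackJAX API."""
    position: jax.Array      # current sample theta
    step: jax.Array          # total number of mini-batch steps taken so far
    num_accepted: jax.Array  # number of accepted MH corrections


def make_kernel(full_energy: Callable, minibatch_grad: Callable,
                step_size: float, num_inner: int) -> Callable:
    """One outer step = `num_inner` mini-batch leapfrog steps + one full-batch MH test.

    full_energy(theta)          -> U(theta), full-batch negative log posterior
    minibatch_grad(theta, key)  -> unbiased mini-batch estimate of grad U(theta)
    """

    def inner_step(carry, key):
        theta, r, rho = carry
        # Drift-kick-drift leapfrog step driven by a mini-batch gradient (no friction).
        theta_half = theta + 0.5 * step_size * r
        g = minibatch_grad(theta_half, key)
        r_new = r - step_size * g
        theta_new = theta_half + 0.5 * step_size * r_new
        # Accumulated term: without friction it equals |r|^2/2 - |r_new|^2/2 exactly.
        rho = rho + 0.5 * step_size * jnp.vdot(r + r_new, g)
        return (theta_new, r_new, rho), None

    def kernel(key, state):
        key_r, key_inner, key_u = jax.random.split(key, 3)
        theta0 = state.position
        r0 = jax.random.normal(key_r, theta0.shape, theta0.dtype)  # fresh momentum
        rho0 = jnp.zeros((), theta0.dtype)
        keys = jax.random.split(key_inner, num_inner)
        (theta_k, _, rho), _ = jax.lax.scan(inner_step, (theta0, r0, rho0), keys)
        # Amortized MH test: two full-batch evaluations per block of num_inner steps.
        log_accept = full_energy(theta0) - full_energy(theta_k) + rho
        accept = jnp.log(jax.random.uniform(key_u)) < log_accept
        position = jax.lax.cond(accept, lambda new, old: new, lambda new, old: old,
                                theta_k, theta0)
        return AmortizedMHState(position, state.step + num_inner,
                                state.num_accepted + accept.astype(jnp.int32))

    return kernel


# Toy usage (assumed): x_i ~ N(mu, 1) with a flat prior, so the posterior is
# N(mean(x), 1/N). The small data spread keeps mini-batch gradient noise low.
N, batch_size = 1000, 50
data = 2.0 + 0.1 * jax.random.normal(jax.random.PRNGKey(0), (N,))


def full_energy(theta):
    return 0.5 * jnp.sum((data[:, None] - theta) ** 2)


def minibatch_grad(theta, key):
    idx = jax.random.choice(key, N, (batch_size,), replace=False)
    # Rescale by N / batch_size so the estimate is unbiased for grad U(theta).
    return -(N / batch_size) * jnp.sum(data[idx][:, None] - theta, axis=0)


kernel = jax.jit(make_kernel(full_energy, minibatch_grad, step_size=0.01, num_inner=10))
state = AmortizedMHState(jnp.array([2.0]), jnp.array(0), jnp.array(0))
key = jax.random.PRNGKey(1)
samples = []
for _ in range(500):
    key, subkey = jax.random.split(key)
    state = kernel(subkey, state)
    samples.append(state.position[0])
samples = jnp.stack(samples)
```

Because a whole block of K steps runs inside one kernel call here, the accumulated term lives in the scan carry and is reset every outer step. A one-step-per-call kernel would instead keep the accumulated term, the step counter and the block's starting point in the State NamedTuple and branch with lax.cond when the counter reaches K.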
Willing to open a PR?
No — filing for community interest (re-filed from #375)
Re-filed from #375, which was closed as stale. Using the new structured proposal format.