Algorithm category: Diagnostic
Paper reference
Two concurrent 2016 ICML papers establishing the KSD goodness-of-fit test:
- "A Kernelized Stein Discrepancy for Goodness-of-fit Tests and Model Evaluation", Qiang Liu, Jason D. Lee, Michael I. Jordan. ICML, PMLR 48, 2016. https://arxiv.org/abs/1602.03253
- "A Kernel Test of Goodness of Fit", Kacper Chwialkowski, Heiko Strathmann, Arthur Gretton. ICML, PMLR 48:2606–2615, 2016. https://arxiv.org/abs/1602.02964
Existing implementations
- ksddescent (PyTorch/JAX hybrid): https://github.com/pierreablin/ksddescent
Benefit and motivation
KSD is a sample-quality measure that quantifies how well a set of MCMC samples approximates the target distribution without requiring the normalising constant: it depends on the target only through the score ∇_x log p(x), which is unchanged when log p is shifted by a constant. Unlike R-hat or effective sample size, KSD directly measures distributional discrepancy, and it applies equally to MCMC and variational inference outputs. It is already used internally in BlackJAX-related work (TESS, MAMBA adaptation) and is the natural companion diagnostic for SVGD, which BlackJAX already provides. Adding KSD as an official diagnostic would give users a principled convergence check and enable downstream work such as automated stopping criteria.
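The normalising-constant point can be checked directly: the score of an unnormalised log density equals the score of the normalised one, because the constant log Z has zero gradient. A minimal JAX check, using a standard Gaussian as a hypothetical example target:

```python
import jax
import jax.numpy as jnp


def logdensity_unnormalised(x):
    # Hypothetical example target: standard Gaussian with log Z omitted.
    return -0.5 * jnp.sum(x**2)


def logdensity_normalised(x):
    # Same target with its normalising constant included.
    d = x.shape[-1]
    return logdensity_unnormalised(x) - 0.5 * d * jnp.log(2.0 * jnp.pi)


x = jnp.array([0.3, -1.2, 2.0])
score_unnorm = jax.grad(logdensity_unnormalised)(x)
score_norm = jax.grad(logdensity_normalised)(x)
# Both scores equal -x: the constant log Z vanishes under differentiation.
print(jnp.allclose(score_unnorm, score_norm))  # True
```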
Comparison to existing BlackJAX algorithms
- Closest existing: none. BlackJAX has no distribution-level sample-quality diagnostic; blackjax.diagnostics currently offers only R-hat, ESS, and MCSE.
- Advantage: Distribution-level convergence measure; does not require running multiple chains (unlike R-hat); applicable to VI and MCMC outputs alike; differentiable, enabling gradient-based adaptation (as in MAMBA and Campbell et al. 2021).
- Limitation: Quadratic cost O(n²) in sample size (random Fourier feature approximations reduce this; sliced variants such as SKSD instead target loss of sensitivity in high dimensions); requires choosing a kernel and bandwidth, which can affect sensitivity.
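For the bandwidth part of the kernel choice, one common default is the median heuristic, which sets the bandwidth from the pairwise distances between samples. A minimal sketch, assuming a Gaussian RBF kernel k(x, y) = exp(-‖x − y‖² / (2h²)); the helper name is hypothetical and not an existing BlackJAX API:

```python
import jax
import jax.numpy as jnp


def median_heuristic_bandwidth(samples):
    """Hypothetical helper: sqrt of the median squared pairwise distance.

    samples: array of shape (n, d). Returns a scalar bandwidth h for the
    Gaussian RBF kernel k(x, y) = exp(-||x - y||^2 / (2 h^2)).
    """
    diffs = samples[:, None, :] - samples[None, :, :]  # (n, n, d)
    sq_dists = jnp.sum(diffs**2, axis=-1)              # (n, n)
    n = samples.shape[0]
    # Keep only the strict upper triangle so that zero self-distances and
    # duplicated (i, j)/(j, i) pairs do not bias the median.
    iu = jnp.triu_indices(n, k=1)
    return jnp.sqrt(jnp.median(sq_dists[iu]))


samples = jax.random.normal(jax.random.PRNGKey(0), (200, 2))
h = median_heuristic_bandwidth(samples)
print(h)
```

The bandwidth is only half of the choice: Gorham and Mackey (2017, "Measuring Sample Quality with Kernels") show that KSD with a Gaussian kernel can fail to detect non-convergence in dimension d ≥ 3 and recommend an inverse multiquadric (IMQ) kernel instead.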
Estimated JAX implementation effort
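S: the squared KSD is estimated by a V-statistic (a double sum over all pairs of samples) of a closed-form Stein kernel; dropping the i = j terms gives the unbiased U-statistic variant. The Stein kernel involves the score ∇_x log p(x) and first and second derivatives of the base kernel, all of which JAX can compute automatically. The entire computation is a single jax.vmap-over-jax.vmap evaluation of the kernel matrix plus a sum, and is fully JIT-compatible.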
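For reference, the Stein kernel in the form given by Liu et al. (2016), written with the score s_p(x) = ∇_x log p(x) and a base kernel k, together with the two standard estimators over samples x_1, ..., x_n:

```latex
k_p(x, y) = s_p(x)^\top s_p(y)\, k(x, y)
          + s_p(x)^\top \nabla_y k(x, y)
          + s_p(y)^\top \nabla_x k(x, y)
          + \operatorname{tr}\!\left(\nabla_x \nabla_y^\top k(x, y)\right),
\qquad s_p(x) = \nabla_x \log p(x)

\widehat{\mathrm{KSD}}^2_V = \frac{1}{n^2} \sum_{i=1}^{n} \sum_{j=1}^{n} k_p(x_i, x_j),
\qquad
\widehat{\mathrm{KSD}}^2_U = \frac{1}{n(n-1)} \sum_{i \neq j} k_p(x_i, x_j)
```

Only s_p enters the formula, so p may be unnormalised.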
JAX-specific implementation notes
jax.vmap over sample pairs computes the Stein kernel matrix efficiently, and jax.grad provides ∇_x log p(x) (and the base-kernel derivatives) without manual derivation. No lax.while_loop or custom_vjp is needed. A sliced variant (SKSD) could be added later for high-dimensional targets; the O(n²) cost is instead addressed by random-feature approximations. The existing implementation in albcab/TESS/mcmc_utils.py is a ready-to-adapt JAX reference.
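A minimal JAX sketch of the vmap-over-vmap computation, assuming a Gaussian RBF base kernel with a fixed bandwidth and the V-statistic estimator. The function names are hypothetical and not an existing BlackJAX API; this is an independent illustration of the technique, not the TESS code:

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, bandwidth):
    # Gaussian RBF base kernel k(x, y) = exp(-||x - y||^2 / (2 h^2)).
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * bandwidth**2))


def ksd_v_statistic(logdensity_fn, samples, bandwidth=1.0):
    """Hypothetical helper: squared KSD (V-statistic) of `samples`, shape (n, d).

    `logdensity_fn` may be unnormalised; only its gradient (the score) is used.
    """

    def k(x, y):
        return rbf_kernel(x, y, bandwidth)

    grad_x_k = jax.grad(k, argnums=0)  # grad_x k(x, y)
    grad_y_k = jax.grad(k, argnums=1)  # grad_y k(x, y)
    # Cross second derivatives d^2 k / (dy_j dx_i) as a (d, d) matrix.
    cross_hessian = jax.jacfwd(grad_y_k, argnums=0)

    # Scores s_p(x_i) = grad_x log p(x_i), computed once per sample.
    scores = jax.vmap(jax.grad(logdensity_fn))(samples)

    def stein_kernel(x, sx, y, sy):
        # Langevin Stein kernel k_p(x, y).
        return (
            jnp.dot(sx, sy) * k(x, y)
            + jnp.dot(sx, grad_y_k(x, y))
            + jnp.dot(sy, grad_x_k(x, y))
            + jnp.trace(cross_hessian(x, y))
        )

    # Inner vmap over (y, sy), outer vmap over (x, sx): an (n, n) Gram matrix.
    def row(x, sx):
        return jax.vmap(stein_kernel, in_axes=(None, None, 0, 0))(x, sx, samples, scores)

    gram = jax.vmap(row)(samples, scores)
    # V-statistic: average over all n^2 pairs, including i == j.
    return jnp.mean(gram)


def std_normal_logdensity(x):
    # Unnormalised log density of N(0, I); log Z is deliberately omitted.
    return -0.5 * jnp.sum(x**2)


key = jax.random.PRNGKey(0)
good = jax.random.normal(key, (500, 2))  # draws from the target
bad = good + 1.0                         # same draws shifted by +1 per coordinate

ksd = jax.jit(lambda s: ksd_v_statistic(std_normal_logdensity, s))
print(float(ksd(good)), float(ksd(bad)))
```

Scores are computed once per sample rather than once per pair, so only the base-kernel terms are evaluated n² times. For target N(0, I) and samples from N(μ, I), the population KSD² under this kernel is ‖μ‖² E[k(x, y)] = 2/3 for the shift used here, so the shifted estimate should land near that while the unshifted one stays close to zero.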
Willing to open a PR?
No — filing for community interest (re-filed from #384)
Re-filed from #384, which was closed as stale, using the new structured proposal format.