
Add Kernelized Stein Discrepancy (KSD) as a convergence diagnostic #839

Description

@junpenglao

Algorithm category: Diagnostic

Paper reference

Two concurrent 2016 ICML papers establishing the KSD goodness-of-fit test:

  1. "A Kernelized Stein Discrepancy for Goodness-of-fit Tests and Model Evaluation" — Qiang Liu, Jason D. Lee, Michael I. Jordan. ICML, PMLR 48, 2016. https://arxiv.org/abs/1602.03253

  2. "A Kernel Test of Goodness of Fit" — Kacper Chwialkowski, Heiko Strathmann, Arthur Gretton. ICML, PMLR 48:2606–2615, 2016. https://arxiv.org/abs/1602.02964

Existing implementations

Benefit and motivation

KSD is a sample-quality measure that quantifies how well a set of samples approximates the target distribution. It depends on the target only through the score ∇_x log p(x), so the normalising constant is never needed. Unlike R-hat or effective sample size, KSD directly measures distributional discrepancy and applies equally to MCMC and variational inference outputs. It is already used internally in BlackJAX-related work (TESS, MAMBA adaptation) and is the natural companion diagnostic for SVGD, which BlackJAX already implements. Adding KSD as an official diagnostic would give users a principled convergence check and enable downstream work such as automated stopping criteria.
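A short worked statement of the normalising-constant claim, using the standard form of the Stein kernel built from a base kernel k(x, y) (the choice of k is left open here). Writing the target as an unnormalised density divided by a constant Z, the constant vanishes under the gradient, so the V-statistic estimate needs only the score:

```latex
% Target p(x) = \tilde{p}(x) / Z, where Z does not depend on x
$$ s_p(x) = \nabla_x \log p(x) = \nabla_x \log \tilde{p}(x) - \underbrace{\nabla_x \log Z}_{=\,0} = \nabla_x \log \tilde{p}(x) $$

% Stein kernel built from a base kernel k(x, y)
$$ k_p(x, y) = s_p(x)^\top s_p(y)\, k(x, y) + s_p(x)^\top \nabla_y k(x, y) + s_p(y)^\top \nabla_x k(x, y) + \operatorname{tr}\big(\nabla_x \nabla_y k(x, y)\big) $$

% V-statistic estimate from samples x_1, ..., x_n
$$ \widehat{\mathrm{KSD}}^2 = \frac{1}{n^2} \sum_{i=1}^{n} \sum_{j=1}^{n} k_p(x_i, x_j) $$
```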

Comparison to existing BlackJAX algorithms

  • Closest existing: blackjax.diagnostics, which currently provides only R-hat, ESS, and MCSE; BlackJAX has no distribution-level diagnostic.
  • Advantage: Distribution-level convergence measure; does not require running multiple chains (unlike R-hat); applicable to VI and MCMC outputs alike; differentiable, enabling gradient-based adaptation (as in MAMBA and Campbell et al. 2021).
  • Limitation: Quadratic O(n²) cost in sample size (random Fourier feature approximations can reduce this; sliced variants mainly address loss of power in high dimensions); requires choosing a kernel and bandwidth, which can affect sensitivity.
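As a minimal sketch of the "differentiable" point, assuming an RBF base kernel with a fixed bandwidth and an unnormalised standard-Gaussian target: the V-statistic estimate is an ordinary JAX function of the sample positions, so jax.grad can differentiate it with respect to any parameter that generates those samples. The scalar shift `mu` is a hypothetical stand-in for a sampler tuning parameter (it is not MAMBA's actual objective), and the function names are illustrative rather than a BlackJAX API. The gradient is checked against a central finite difference; float64 is enabled for that check, which is a global JAX setting.

```python
import jax
import jax.numpy as jnp

# float64 keeps the finite-difference comparison tight (global JAX setting).
jax.config.update("jax_enable_x64", True)


def ksd_squared(samples, score_fn, h=1.0):
    """Hypothetical helper: V-statistic KSD^2 with an RBF kernel, via broadcasting."""
    d = samples.shape[-1]
    s = jax.vmap(score_fn)(samples)                  # (n, d) scores
    r = samples[:, None, :] - samples[None, :, :]    # (n, n, d), r_ij = x_i - x_j
    sq = jnp.sum(r**2, axis=-1)                      # (n, n) squared distances
    k = jnp.exp(-sq / (2 * h**2))                    # RBF base kernel matrix
    inner = (
        s @ s.T                                      # s_i^T s_j
        + jnp.einsum("id,ijd->ij", s, r) / h**2      # s_i^T grad_y k, divided by k
        - jnp.einsum("jd,ijd->ij", s, r) / h**2      # s_j^T grad_x k, divided by k
        + d / h**2
        - sq / h**4                                  # trace term, divided by k
    )
    return jnp.mean(k * inner)


# Score of the unnormalised N(0, I) log density.
score = jax.grad(lambda x: -0.5 * jnp.sum(x**2))
z = jax.random.normal(jax.random.PRNGKey(1), (200, 2))


def ksd_of_shift(mu):
    # Hypothetical tuning parameter: every sample is shifted by mu.
    return ksd_squared(z + mu, score)


mu0, eps = 0.5, 1e-5
grad_autodiff = float(jax.grad(ksd_of_shift)(mu0))
grad_fd = float((ksd_of_shift(mu0 + eps) - ksd_of_shift(mu0 - eps)) / (2 * eps))
print(grad_autodiff, grad_fd)
```

The same pattern applies to a step size or mass matrix, as long as the samples are produced by differentiable JAX code.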

Estimated JAX implementation effort

S (small): KSD is a V-statistic (a double sum over all pairs of samples) of a closed-form Stein kernel. The Stein kernel involves the score ∇_x log p(x), which JAX computes automatically. The whole computation is a single jax.vmap-over-jax.vmap evaluation of the n × n kernel matrix followed by a mean, and is fully JIT-compatible.
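The structure described above can be sketched as follows, assuming an RBF base kernel exp(-||x - y||² / (2h²)) with a fixed, user-chosen bandwidth h (the median heuristic is a common alternative). `rbf_stein_kernel` and `ksd_squared` are hypothetical names, not an existing BlackJAX API, and the function returns the squared discrepancy. Scores are computed once per sample (n gradient evaluations); only the kernel matrix is O(n²).

```python
import jax
import jax.numpy as jnp


def rbf_stein_kernel(x, y, score_x, score_y, bandwidth):
    """Stein kernel k_p(x, y) for the RBF base kernel exp(-||x - y||^2 / (2 h^2))."""
    d = x.shape[-1]
    h2 = bandwidth**2
    r = x - y
    sq_dist = jnp.sum(r**2)
    k = jnp.exp(-sq_dist / (2.0 * h2))
    # k_p = s(x)^T s(y) k + s(x)^T grad_y k + s(y)^T grad_x k + tr(grad_x grad_y k),
    # using grad_y k = (x - y) k / h^2 and grad_x k = -(x - y) k / h^2.
    return k * (
        jnp.dot(score_x, score_y)
        + jnp.dot(score_x, r) / h2
        - jnp.dot(score_y, r) / h2
        + d / h2
        - sq_dist / h2**2
    )


def ksd_squared(samples, logdensity_fn, bandwidth=1.0):
    """V-statistic estimate of KSD^2 for samples of shape (n, d)."""
    # n score evaluations; an unnormalised log density is enough.
    scores = jax.vmap(jax.grad(logdensity_fn))(samples)
    # vmap-over-vmap builds the (n, n) Stein kernel matrix; the mean is the double sum / n^2.
    gram = jax.vmap(
        lambda x, s_x: jax.vmap(
            lambda y, s_y: rbf_stein_kernel(x, y, s_x, s_y, bandwidth)
        )(samples, scores)
    )(samples, scores)
    return jnp.mean(gram)


def logdensity(x):
    # Unnormalised standard Gaussian: the -d/2 log(2 pi) term is dropped.
    return -0.5 * jnp.sum(x**2)


key_good, key_bad = jax.random.split(jax.random.PRNGKey(0))
good = jax.random.normal(key_good, (500, 2))
bad = jax.random.normal(key_bad, (500, 2)) + 1.0  # samples with the wrong mean

ksd_fn = jax.jit(lambda s: ksd_squared(s, logdensity))
ksd_good = float(ksd_fn(good))
ksd_bad = float(ksd_fn(bad))
print(ksd_good, ksd_bad)
```

For correctly distributed samples the estimate is small but positive, because the V-statistic includes the diagonal terms k_p(x_i, x_i); the shifted samples give a clearly larger value.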

JAX-specific implementation notes

jax.vmap over sample pairs computes the Stein kernel matrix efficiently, and jax.grad provides ∇_x log p(x) without manual derivation. No lax.while_loop or custom_vjp is needed. A sliced variant (SKSD) can be added later to improve power in high dimensions; the O(n²) scaling is better addressed by random-feature approximations. The existing implementation in albcab/TESS/mcmc_utils.py is a ready-to-adapt JAX reference.
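Autodiff can supply the base-kernel derivatives as well as the score. The sketch below builds the Stein kernel from any differentiable base kernel using jax.grad and jax.jacfwd. It assumes an inverse multiquadric (IMQ) base kernel (c² + ||x - y||²)^β with c = 1 and β = -1/2, a common choice in the KSD literature. All names are hypothetical. It also checks that including the log normalising constant in the log density leaves the estimate unchanged.

```python
import math

import jax
import jax.numpy as jnp


def imq_kernel(x, y, c=1.0, beta=-0.5):
    """Inverse multiquadric base kernel (c^2 + ||x - y||^2)^beta."""
    return (c**2 + jnp.sum((x - y) ** 2)) ** beta


def make_stein_kernel(base_kernel):
    """Build k_p(x, y, s_x, s_y) from a differentiable base kernel via autodiff."""
    grad_x = jax.grad(base_kernel, argnums=0)
    grad_y = jax.grad(base_kernel, argnums=1)
    # Jacobian of grad_y w.r.t. x; its trace is sum_a d^2 k / (dx_a dy_a).
    cross = jax.jacfwd(grad_y, argnums=0)

    def stein_kernel(x, y, s_x, s_y):
        return (
            jnp.dot(s_x, s_y) * base_kernel(x, y)
            + jnp.dot(s_x, grad_y(x, y))
            + jnp.dot(s_y, grad_x(x, y))
            + jnp.trace(cross(x, y))
        )

    return stein_kernel


def ksd_squared(samples, logdensity_fn, base_kernel=imq_kernel):
    """V-statistic estimate of KSD^2; only the score of logdensity_fn is used."""
    scores = jax.vmap(jax.grad(logdensity_fn))(samples)
    k_p = make_stein_kernel(base_kernel)
    # Inner vmap: one x against every y; outer vmap: every x.
    row = jax.vmap(k_p, in_axes=(None, 0, None, 0))
    gram = jax.vmap(row, in_axes=(0, None, 0, None))(samples, samples, scores, scores)
    return jnp.mean(gram)


def unnormalised(x):
    return -0.5 * jnp.sum(x**2)


def normalised(x):
    # Same N(0, I) target with the log normalising constant included.
    return unnormalised(x) - 0.5 * x.shape[-1] * math.log(2 * math.pi)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
good = jax.random.normal(k1, (500, 2))
bad = jax.random.normal(k2, (500, 2)) + 1.0

ksd_unnorm = float(ksd_squared(good, unnormalised))
ksd_norm = float(ksd_squared(good, normalised))
ksd_bad = float(ksd_squared(bad, unnormalised))
print(ksd_unnorm, ksd_norm, ksd_bad)
```

The jacfwd cross-derivative adds roughly O(d²) work per pair on top of the O(n²) pairs, so a closed-form Stein kernel is preferable when the base kernel admits one.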

Willing to open a PR?

No — filing for community interest (re-filed from #384)


Re-filed from #384, which was closed as stale, using the new structured proposal format.

Metadata


Assignees

No one assigned

    Labels

    enhancement (New feature or request)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
