Summary
`blackjax.diagnostics.potential_scale_reduction` currently implements the classic Gelman-Rubin (1992) R-hat: a plain between-chain / within-chain variance ratio with no rank normalization and no folding. The docstring cites `gelman1992inference` directly. This is correct as far as it goes, but it has practical downsides that modern MCMC tooling (Stan, ArviZ, PyMC) has long since addressed.
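For reference, this is the classic statistic in its usual textbook form, for $M$ chains of $N$ draws each, where $\theta_{nm}$ is draw $n$ of chain $m$. The notation is mine, and this is the standard definition rather than a transcription of BlackJAX's code:

```latex
\begin{aligned}
\bar\theta_m &= \frac{1}{N}\sum_{n=1}^{N}\theta_{nm}, \qquad
\bar\theta = \frac{1}{M}\sum_{m=1}^{M}\bar\theta_m, \\
W &= \frac{1}{M}\sum_{m=1}^{M}\frac{1}{N-1}\sum_{n=1}^{N}\left(\theta_{nm}-\bar\theta_m\right)^2, \\
B &= \frac{N}{M-1}\sum_{m=1}^{M}\left(\bar\theta_m-\bar\theta\right)^2, \\
\hat R &= \sqrt{\frac{\frac{N-1}{N}\,W + \frac{1}{N}\,B}{W}}.
\end{aligned}
```

Because $B$ only sees the chain means, chains that agree in mean but differ in scale leave $\hat R$ close to 1.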
Problem
The widely-used R̂ < 1.01 convergence threshold (Vehtari et al. 2021) was calibrated for the rank-normalized split-R-hat statistic, which is more sensitive than the classic G-R version in two specific ways:
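- Rank normalization makes the statistic robust to heavy tails: the classic variance ratio assumes finite mean and variance and can be dominated by a few extreme draws, whereas the rank-normalized version depends only on the ordering of the draws, so it can still flag chains whose draws sit at systematically different ranks.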
- Folding (`z = |x - median(x)|`) detects scale discrepancies between chains. Two chains with the same mean but different variances (e.g. one stuck in the neck of a funnel, one exploring the bulk) will look fine to classic G-R but folded R-hat will correctly flag them.
A user relying on `blackjax.diagnostics.potential_scale_reduction` and applying the 1.01 threshold from modern references is under-diagnosing scale discrepancies and heavy-tailed mixing failures. ArviZ's `az.rhat` correctly computes the Vehtari 2021 statistic; BlackJAX users who want the modern statistic currently have to round-trip through ArviZ.
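To make the scale-discrepancy case concrete, here is a minimal NumPy sketch. `classic_rhat` is a hypothetical stand-in written for this issue (it is not BlackJAX's function), the input is assumed to have shape `(num_chains, num_draws)`, and the synthetic chains are invented for illustration. Folding is applied without rank normalization, purely to isolate its effect:

```python
import numpy as np


def classic_rhat(x):
    """Classic Gelman-Rubin R-hat for draws of shape (num_chains, num_draws)."""
    n = x.shape[1]
    within = x.var(axis=1, ddof=1).mean()      # mean within-chain variance
    between = n * x.mean(axis=1).var(ddof=1)   # between-chain variance
    return np.sqrt(((n - 1) / n * within + between / n) / within)


rng = np.random.default_rng(0)
# Two chains stuck at a small scale, two exploring a much wider one;
# all four share the same mean.
narrow = rng.normal(0.0, 0.1, size=(2, 4000))
wide = rng.normal(0.0, 3.0, size=(2, 4000))
chains = np.concatenate([narrow, wide], axis=0)

rhat_raw = classic_rhat(chains)
# Folding: the absolute deviation from the pooled median turns a scale
# difference into a location difference.
rhat_folded = classic_rhat(np.abs(chains - np.median(chains)))

print(f"classic R-hat on raw draws:    {rhat_raw:.3f}")
print(f"classic R-hat on folded draws: {rhat_folded:.3f}")
```

With draws like these I would expect the first value to sit very close to 1 and the second to land well above the 1.01 threshold.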
Reference: Vehtari, Gelman, Simpson, Carpenter, Bürkner (2021), Rank-normalization, folding, and localization: An improved R-hat for assessing convergence of MCMC, Bayesian Analysis.
Proposed API
Add a new function in `blackjax/diagnostics.py`:

    def rank_normalized_split_potential_scale_reduction(
        input_array: ArrayLike,
        chain_axis: int = 0,
        sample_axis: int = 1,
    ) -> Array:
        """Rank-normalized split-R̂ (Vehtari et al. 2021).

        Computes `max(R̂_rank-normalized, R̂_folded-rank-normalized)` on
        split chains (each chain halved). This is the recommended convergence
        statistic; the threshold `R̂ < 1.01` from the Vehtari 2021 paper
        applies to this statistic, not to the classic Gelman-Rubin version.

        [...]
        """

Implementation outline (validated against ArviZ `az.rhat`):
1. Split each chain in half along `sample_axis` (so `num_chains` doubles and `num_draws` halves).
2. Compute rank-normalized R-hat: rank the draws jointly across all chains and samples (average ranks for ties), apply the Blom transform `scipy.stats.norm.ppf((r - 3/8) / (N + 1/4))`, where `r` is the rank and `N` the total number of draws, then apply the classic G-R formula to the normalized ranks.
3. Compute folded-rank-normalized R-hat: same as step 2 but on `|x - median(x)|` instead of `x`.
4. Return the `max` of the two.
Plan
- Add the new function alongside the existing `potential_scale_reduction`.
- Update the existing function's docstring to note it returns the classic 1992 statistic, with a pointer to the new function for the modern recommendation.
- Do not deprecate `potential_scale_reduction` — the classic G-R has its place (it's cheap, intuitive, and matches many textbook references). But document the trade-off clearly.
- Add tests in `tests/test_diagnostics.py` covering:
  - Match with ArviZ `az.rhat` to <1e-6 on a synthetic chain.
  - Folded R-hat correctly catches a same-mean-different-variance failure that classic G-R misses.
  - Rank-normalized R-hat is robust to a heavy-tailed chain that classic G-R passes.
Why this matters operationally
This issue surfaced from a downstream usage gap: a diagnostics playbook elsewhere in the BlackJAX ecosystem was telling readers "BlackJAX's `potential_scale_reduction` computes the modern Vehtari 2021 version" — which is not true, and resulted in a 1.01 threshold being applied to the wrong statistic. The playbook has been corrected (BlackJAX's classic G-R + `az.rhat` for the modern statistic is the new recommendation), but the right long-term fix is for BlackJAX to expose the modern statistic natively.
Filing this so the gap is tracked upstream and so a future contributor (or me in a follow-up) can pick up the implementation. Happy to PR it if there's no objection to the API shape above.