diagnostics: add rank-normalized + folded R-hat (Vehtari et al. 2021) #912

Description

@junpenglao

Summary

blackjax.diagnostics.potential_scale_reduction currently implements the classic Gelman-Rubin (1992) R-hat: a plain between-chain / within-chain variance ratio with no rank normalization and no folding. The docstring cites gelman1992inference directly. This is correct as far as it goes, but it has practical downsides that modern MCMC tooling (Stan, ArviZ, PyMC) has long since addressed.
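
For reference, the classic statistic is usually written as below for M chains of N draws each. The notation follows common textbook presentations and is an assumption here; it does not necessarily match the variable names in the BlackJAX source.

```latex
\begin{aligned}
\bar\theta_{m} &= \frac{1}{N}\sum_{n=1}^{N}\theta_{nm},
\qquad
\bar\theta = \frac{1}{M}\sum_{m=1}^{M}\bar\theta_{m},\\
B &= \frac{N}{M-1}\sum_{m=1}^{M}\left(\bar\theta_{m}-\bar\theta\right)^{2},\\
W &= \frac{1}{M}\sum_{m=1}^{M}s_{m}^{2},
\qquad
s_{m}^{2} = \frac{1}{N-1}\sum_{n=1}^{N}\left(\theta_{nm}-\bar\theta_{m}\right)^{2},\\
\widehat{R} &= \sqrt{\frac{\frac{N-1}{N}\,W + \frac{1}{N}\,B}{W}}.
\end{aligned}
```

Only the chain means enter B, so two chains that agree in mean but differ in spread leave the ratio close to 1.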

Problem

The widely-used R̂ < 1.01 convergence threshold (Vehtari et al. 2021) was calibrated for the rank-normalized split-R-hat statistic, which is more sensitive than the classic G-R version in two specific ways:

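  1. Rank normalization makes the statistic robust to heavy tails. Classic R-hat assumes finite mean and variance, so with heavy-tailed draws (e.g. Cauchy) a few extreme values can dominate the variance estimates and the ratio can sit near 1 even when the chains disagree. Replacing the draws by normal scores of their pooled ranks removes that dependence on the tails.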
  2. Folding (z = |x - median(x)|) detects scale discrepancies between chains. Two chains with the same mean but different variances (e.g. one stuck in the neck of a funnel, one exploring the bulk) will look fine to classic G-R but folded R-hat will correctly flag them.
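
A minimal, dependency-free sketch of the folding point, in plain Python rather than JAX. The helper names `classic_rhat` and `fold` are hypothetical and not part of BlackJAX, and the sketch deliberately skips chain splitting and rank normalization to isolate the effect of folding.

```python
import math
import random
import statistics


def classic_rhat(chains):
    """Classic Gelman-Rubin R-hat for a list of equal-length chains."""
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    within = statistics.fmean(statistics.variance(c) for c in chains)  # W
    between_over_n = statistics.variance(chain_means)  # B / N
    return math.sqrt(((n - 1) / n * within + between_over_n) / within)


def fold(chains):
    """Fold all draws around the pooled median: z = |x - median(x)|."""
    med = statistics.median(x for c in chains for x in c)
    return [[abs(x - med) for x in c] for c in chains]


rng = random.Random(0)
# Two chains with the same mean (0) but different scales (1 vs 3).
narrow = [rng.gauss(0.0, 1.0) for _ in range(2000)]
wide = [rng.gauss(0.0, 3.0) for _ in range(2000)]
chains = [narrow, wide]

rhat_plain = classic_rhat(chains)
rhat_folded = classic_rhat(fold(chains))
print(f"classic R-hat on x:            {rhat_plain:.3f}")
print(f"classic R-hat on |x - median|: {rhat_folded:.3f}")
```

On the raw draws the chain means agree, so the statistic stays below 1.01. After folding, the wider chain has a larger mean absolute deviation, which shows up as between-chain variance and pushes the statistic well above the threshold.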

A user relying on blackjax.diagnostics.potential_scale_reduction and applying the 1.01 threshold from modern references is under-diagnosing scale discrepancies and heavy-tailed mixing failures. ArviZ's az.rhat correctly computes the Vehtari 2021 statistic; BlackJAX users who want the modern statistic currently have to round-trip through ArviZ.

Reference: Vehtari, Gelman, Simpson, Carpenter, Bürkner (2021), Rank-normalization, folding, and localization: An improved R-hat for assessing convergence of MCMC, Bayesian Analysis.

Proposed API

Add a new function in `blackjax/diagnostics.py`:

def rank_normalized_split_potential_scale_reduction(
    input_array: ArrayLike,
    chain_axis: int = 0,
    sample_axis: int = 1,
) -> Array:
    """Rank-normalized split-R̂ (Vehtari et al. 2021).

    Computes `max(R̂_rank-normalized, R̂_folded-rank-normalized)` on
    split chains (each chain halved). This is the recommended convergence
    statistic; the threshold `R̂ < 1.01` from the Vehtari 2021 paper
    applies to this statistic, not to the classic Gelman-Rubin version.

    [...]
    """

Implementation outline (validated against ArviZ `az.rhat`):

  1. Split each chain in half along `sample_axis` (so num_chains doubles, num_draws halves).
  2. Compute rank-normalized R-hat: rank the draws pooled over all chains and samples (average ranks for ties), map each rank `r` to a normal score with `scipy.stats.norm.ppf((r - 3/8) / (N + 1/4))` (the Blom transform, with `N` the total number of pooled draws), then apply the classic G-R formula to the normal scores.
  3. Compute folded-rank-normalized R-hat: same as step 2 but on `|x - median(x)|` instead of `x`, with the median taken over the pooled draws.
  4. Return `max` of the two.

Plan

  1. Add the new function alongside the existing `potential_scale_reduction`.
  2. Update the existing function's docstring to note it returns the classic 1992 statistic, with a pointer to the new function for the modern recommendation.
  3. Do not deprecate `potential_scale_reduction` — the classic G-R has its place (it's cheap, intuitive, and matches many textbook references). But document the trade-off clearly.
  4. Add tests in `tests/test_diagnostics.py` covering:
    • Match with ArviZ `az.rhat` to <1e-6 on a synthetic chain.
    • Folded R-hat correctly catches a same-mean-different-variance failure that classic G-R misses.
    • Rank-normalized R-hat flags non-convergence in heavy-tailed (e.g. Cauchy) chains that classic G-R passes.

Why this matters operationally

This issue surfaced from a downstream usage gap: a diagnostics playbook elsewhere in the BlackJAX ecosystem was telling readers "BlackJAX's `potential_scale_reduction` computes the modern Vehtari 2021 version". That is not true, and it resulted in the 1.01 threshold being applied to the wrong statistic. The playbook has been corrected (it now recommends BlackJAX's classic G-R for the 1992 statistic and `az.rhat` for the modern one), but the right long-term fix is for BlackJAX to expose the modern statistic natively.

Filing this so the gap is tracked upstream and so a future contributor (or me in a follow-up) can pick up the implementation. Happy to PR it if there's no objection to the API shape above.
