
Feature request: extend the standard warmup & inference utilities to be more vmap-friendly (multi-chain) #927

Description

@junpenglao

Feature request

Make the standard outer user-facing utilities, window_adaptation.run (warmup) and blackjax.util.run_inference_algorithm (inference), vmap-friendly for multi-chain use. Users would then get correct multi-chain behavior and a working progress_bar without wiring up jax.vmap by hand.

Motivation / current friction

Running multiple chains is ubiquitous (it is needed for R-hat and similar diagnostics). Today, though, users have to reach for jax.vmap manually, and both natural ways of doing so are unsatisfying:

1. vmap the outer function. The intuitive move is jax.vmap(run_inference_algorithm) or jax.vmap(window_adaptation.run). JAX's API invites this, but the docs note it's "not how SIMD vectorization works". Concretely, it breaks progress_bar: the bar's io_callback (installed via gen_scan_fn) ends up inside the vmap, so JAX raises NotImplementedError: IO effect not supported in vmap-of-cond.

2. vmap the kernel inside a hand-rolled scan. This is the intended scan(vmap(kernel)) pattern from howto_sample_multiple_chains. It is correct, but it leaves two options. One is reimplementing the loop, forgoing the helper and its progress bar. The other is discovering the non-obvious trick of driving run_inference_algorithm with vmapped inputs: a batched initial_state plus a SamplingAlgorithm whose .step splits the key per chain and applies jax.vmap(kernel.step). That trick is undocumented. window_adaptation.run has no equivalent escape hatch at all, because it owns its internal scan.
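For reference, the hand-rolled scan(vmap(kernel)) loop from option 2 looks roughly like the sketch below. It is written in pure JAX so it runs without blackjax. A hypothetical toy random-walk Metropolis kernel, `rwm_step`, stands in for `nuts.step`, and `sample_multichain` is an illustrative name, not a blackjax function. The structure is the point: `jax.lax.scan` iterates over steps on the outside, and `jax.vmap` maps the single-chain kernel over chains on the inside.

```python
import jax
import jax.numpy as jnp


def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)


def rwm_step(key, position, step_size=0.5):
    # Hypothetical toy single-chain kernel (random-walk Metropolis), standing in for nuts.step.
    key_prop, key_acc = jax.random.split(key)
    proposal = position + step_size * jax.random.normal(key_prop, position.shape)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(position)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, position)


def sample_multichain(key, init_positions, num_steps):
    # scan over steps (outer, unbatched), vmap over chains (inner).
    num_chains = init_positions.shape[0]
    batched_step = jax.vmap(rwm_step)  # maps over (keys, positions)

    def one_step(positions, step_key):
        chain_keys = jax.random.split(step_key, num_chains)  # one key per chain
        new_positions = batched_step(chain_keys, positions)
        return new_positions, new_positions

    step_keys = jax.random.split(key, num_steps)
    _, trace = jax.lax.scan(one_step, init_positions, step_keys)
    return trace  # stacked by scan along axis 0: (num_steps, num_chains, dim)


trace = sample_multichain(jax.random.key(0), jnp.zeros((4, 2)), 200)
print(trace.shape)  # (200, 4, 2)
```

Note that `scan` stacks outputs step-major, giving `(num_steps, num_chains, ...)`. Users who want chain-major output have to transpose it themselves.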

Net: getting multi-chain right requires manual vmap gymnastics, and the most natural attempt crashes as soon as progress_bar=True is set.

Example (the progress_bar break)

```python
import jax
import jax.numpy as jnp
import blackjax
from blackjax.util import run_inference_algorithm
from blackjax.base import SamplingAlgorithm

def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)

n_chains, n_steps = 4, 200
keys = jax.random.split(jax.random.key(0), n_chains)
positions = jnp.zeros((n_chains, 2))
nuts = blackjax.nuts(logdensity_fn, step_size=0.5, inverse_mass_matrix=jnp.ones(2))
init_states = jax.vmap(nuts.init)(positions)

# (A) vmap the OUTER inference fn + progress_bar -> CRASH
jax.vmap(
    lambda k, s: run_inference_algorithm(k, nuts, n_steps, initial_state=s, progress_bar=True)
)(keys, init_states)
# NotImplementedError: IO effect not supported in vmap-of-cond.

# (B) vmap the OUTER warmup fn + progress_bar -> same CRASH
wa = blackjax.window_adaptation(blackjax.nuts, logdensity_fn, progress_bar=True)
jax.vmap(lambda k, p: wa.run(k, p, n_steps))(keys, positions)
# NotImplementedError: IO effect not supported in vmap-of-cond.

# (C) the non-obvious working pattern for sampling: vmapped INPUTS (scan(vmap(step)))
def vmapped_step(key, states):
    # split one key per chain, then advance every chain in parallel
    return jax.vmap(nuts.step)(jax.random.split(key, n_chains), states)

run_inference_algorithm(
    jax.random.key(1),
    SamplingAlgorithm(jax.vmap(nuts.init), vmapped_step),
    n_steps,
    initial_state=init_states,
    progress_bar=True,
)  # OK: the progress bar's callback stays in the unbatched outer scan
```

(Both (A) and (B) run fine with progress_bar=False, and single-chain runs work fine with progress_bar=True.)
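Pattern (C) works because the progress callback is called from the outer, unbatched scan body rather than from inside the vmap. The pure-JAX sketch below illustrates that placement. It rests on a few assumptions: `jax.debug.callback` stands in for blackjax's io_callback-based bar, a hypothetical toy random-walk Metropolis kernel `rwm_step` stands in for NUTS, and `run_with_progress` is an illustrative name, not a blackjax API. The host callback fires once per outer step, not once per chain.

```python
import jax
import jax.numpy as jnp


def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)


def rwm_step(key, position, step_size=0.5):
    # Hypothetical toy single-chain kernel (random-walk Metropolis), standing in for nuts.step.
    key_prop, key_acc = jax.random.split(key)
    proposal = position + step_size * jax.random.normal(key_prop, position.shape)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(position)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, position)


def run_with_progress(key, init_positions, num_steps, print_rate=50):
    num_chains = init_positions.shape[0]
    batched_step = jax.vmap(rwm_step)
    reported = []  # host-side record of which outer steps were reported

    def report(step):
        # Host callback: runs once per outer scan iteration, never once per chain.
        step = int(step)
        reported.append(step)
        if (step + 1) % print_rate == 0 or step + 1 == num_steps:
            print(f"step {step + 1}/{num_steps}")

    def one_step(positions, xs):
        step_idx, step_key = xs
        chain_keys = jax.random.split(step_key, num_chains)
        new_positions = batched_step(chain_keys, positions)
        # The callback sits in the unbatched scan body, outside jax.vmap,
        # so no IO effect ends up under a batched cond.
        jax.debug.callback(report, step_idx)
        return new_positions, new_positions

    xs = (jnp.arange(num_steps), jax.random.split(key, num_steps))
    _, trace = jax.lax.scan(one_step, init_positions, xs)
    trace.block_until_ready()
    jax.effects_barrier()  # wait until every host callback has run
    return trace, reported


trace, reported = run_with_progress(jax.random.key(0), jnp.zeros((4, 2)), 200)
```

An outer utility that owned this loop (as proposed below) could keep a single shared bar for all chains.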

Suggested fix

A blackjax-native, vmap-friendly story for all outer user-facing functions (run_inference_algorithm, window_adaptation.run, and the other adaptation utilities) that:

  1. handles vmap internally: for example, accept num_chains (or batched inputs) and run scan(vmap(step)) under the hood, so the user never writes vmap themselves;
  2. dispatches correctly: single-chain and multi-chain are handled transparently, with the expected (num_chains, num_steps, ...) shapes in and out;
  3. doesn't break progress_bar: one shared progress bar is driven from the non-vmapped outer scan.
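A minimal sketch of points 1 and 2 might look like the helper below. Everything in it is an assumption for illustration. `run_chains` and its `num_chains=None` convention are hypothetical and not an existing or planned blackjax signature. A toy random-walk Metropolis kernel `rwm_step` stands in for a real blackjax kernel. The helper dispatches on `num_chains`: `None` runs a plain single-chain scan, while an integer runs scan(vmap(step)) and transposes the result to chain-major `(num_chains, num_steps, ...)`.

```python
import jax
import jax.numpy as jnp


def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)


def rwm_step(key, position, step_size=0.5):
    # Hypothetical toy single-chain kernel (random-walk Metropolis).
    key_prop, key_acc = jax.random.split(key)
    proposal = position + step_size * jax.random.normal(key_prop, position.shape)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(position)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, position)


def run_chains(key, step_fn, init_position, num_steps, num_chains=None):
    """Hypothetical outer helper: single chain if num_chains is None, else scan(vmap(step_fn))."""
    step_keys = jax.random.split(key, num_steps)

    if num_chains is None:
        def one_step(position, step_key):
            new_position = step_fn(step_key, position)
            return new_position, new_position

        _, trace = jax.lax.scan(one_step, init_position, step_keys)
        return trace  # (num_steps, ...)

    if init_position.shape[0] != num_chains:
        raise ValueError("leading axis of init_position must equal num_chains")
    batched_step = jax.vmap(step_fn)

    def one_step(positions, step_key):
        chain_keys = jax.random.split(step_key, num_chains)
        new_positions = batched_step(chain_keys, positions)
        return new_positions, new_positions

    _, trace = jax.lax.scan(one_step, init_position, step_keys)
    # scan stacks step-major (num_steps, num_chains, ...); swap to chain-major.
    return jnp.swapaxes(trace, 0, 1)


single = run_chains(jax.random.key(0), rwm_step, jnp.zeros(2), 100)  # shape (100, 2)
multi = run_chains(jax.random.key(0), rwm_step, jnp.zeros((4, 2)), 100, num_chains=4)  # shape (4, 100, 2)
```

The sketch keeps the single-chain path separate rather than running `num_chains=1` and squeezing, so single-chain output keeps its current shape. A progress callback, as in pattern (C), would go in the outer `one_step` of either branch.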

Scope

The crash is confirmed on window_adaptation.run (warmup) and run_inference_algorithm (sampling); both route through gen_scan_fn. window_adaptation is only the one tested example. The other gen_scan_fn-based utilities (window_adaptation_low_rank, pathfinder / multipathfinder, meads, chees, mclmc / adjusted_mclmc tuning) have not been tested yet but very likely need the same treatment.

Environment

  • jax 0.10.0
  • blackjax 1.6.dev81+ga8998469a
  • CPU
