Feature request
Make the standard outer user-facing utilities — window_adaptation.run (warmup) and blackjax.util.run_inference_algorithm (inference) — vmap-friendly for multi-chain, so users get correct multi-chain behavior and a working progress_bar without hand-wiring jax.vmap themselves.
Motivation / current friction
Multi-chain is ubiquitous (R-hat, etc.), but today the user has to reach for jax.vmap manually, and both natural ways are unsatisfying:
1. vmap the outer function. This is the intuitive move: jax.vmap(run_inference_algorithm) / jax.vmap(window_adaptation.run). JAX's API invites this, but the docs note it's "not how SIMD vectorization works", and concretely it breaks progress_bar: the bar's io_callback (via gen_scan_fn) lands inside the vmap, so JAX raises NotImplementedError: IO effect not supported in vmap-of-cond.
2. vmap the kernel inside a hand-rolled scan. This is the intended scan(vmap(kernel)) pattern from howto_sample_multiple_chains. It is correct, but it means either reimplementing the loop (forgoing the helper and its progress bar) or discovering the non-obvious trick of driving run_inference_algorithm with vmapped inputs (a batched initial_state plus a SamplingAlgorithm whose .step is jax.vmap(kernel.step)). That trick is undocumented, and window_adaptation.run has no equivalent escape hatch at all (it owns its internal scan).
Net: getting multi-chain right requires manual vmap gymnastics, and the most natural attempt crashes as soon as progress_bar=True.
Example (the progress_bar break)
import jax, jax.numpy as jnp, blackjax
from blackjax.util import run_inference_algorithm
from blackjax.base import SamplingAlgorithm
def logdensity_fn(x): return -0.5 * jnp.sum(x ** 2)
n_chains, n_steps = 4, 200
keys = jax.random.split(jax.random.key(0), n_chains)
positions = jnp.zeros((n_chains, 2))
nuts = blackjax.nuts(logdensity_fn, step_size=0.5, inverse_mass_matrix=jnp.ones(2))
init_states = jax.vmap(nuts.init)(positions)
# (A) vmap the OUTER inference fn + progress_bar -> CRASH
jax.vmap(lambda k, s: run_inference_algorithm(k, nuts, n_steps, initial_state=s, progress_bar=True))(keys, init_states)
# NotImplementedError: IO effect not supported in vmap-of-cond.
# (B) vmap the OUTER warmup fn + progress_bar -> same CRASH
wa = blackjax.window_adaptation(blackjax.nuts, logdensity_fn, progress_bar=True)
jax.vmap(lambda k, p: wa.run(k, p, n_steps))(keys, positions)
# NotImplementedError: IO effect not supported in vmap-of-cond.
# (C) the non-obvious working pattern for sampling: vmapped INPUTS (scan(vmap(step)))
def vmapped_step(key, states):
    return jax.vmap(nuts.step)(jax.random.split(key, n_chains), states)
run_inference_algorithm(jax.random.key(1), SamplingAlgorithm(nuts.init, vmapped_step),
                        n_steps, initial_state=init_states, progress_bar=True)  # OK
((A) and (B) both run fine with progress_bar=False; single-chain runs fine with progress_bar=True.)
Suggested fix
A blackjax-native, vmap-friendly story for all outer user-facing functions (run_inference_algorithm, window_adaptation.run, and the other adaptation utilities) that:
- handles vmap internally, e.g. accept num_chains (or batched inputs) and do scan(vmap(step)) under the hood, so the user never writes vmap themselves;
- dispatches correctly: single-chain vs multi-chain transparently, with the expected (num_chains, num_steps, ...) shapes in/out;
- doesn't break progress_bar: drive one shared progress bar from the non-vmapped outer scan.
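To make the shape of this concrete, here is a minimal JAX-only sketch of the scan(vmap(step)) layout with a single host callback in the outer scan. It is not blackjax code and not a proposed API: run_chains, rw_step, and progress_ticks are hypothetical names, the random-walk Metropolis step is a toy stand-in for kernel.step, and jax.debug.callback stands in for the real progress bar.

```python
import jax
import jax.numpy as jnp


def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)


def rw_step(key, position, scale=0.5):
    """Toy random-walk Metropolis step for ONE chain (stand-in for kernel.step)."""
    key_prop, key_acc = jax.random.split(key)
    proposal = position + scale * jax.random.normal(key_prop, position.shape)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(position)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, position)


progress_ticks = []  # hypothetical stand-in for a progress bar: one entry per step


def run_chains(key, step, initial_positions, num_steps):
    """Hypothetical helper: scan over steps, vmap over chains inside the scan body."""
    num_chains = initial_positions.shape[0]  # assumes a leading chain axis

    def one_step(positions, xs):
        i, step_key = xs
        chain_keys = jax.random.split(step_key, num_chains)
        positions = jax.vmap(step)(chain_keys, positions)
        # The callback sits in the outer scan, outside the vmap, so it fires
        # once per step for all chains together.
        jax.debug.callback(lambda it: progress_ticks.append(int(it)), i)
        return positions, positions

    step_keys = jax.random.split(key, num_steps)
    _, history = jax.lax.scan(
        one_step, initial_positions, (jnp.arange(num_steps), step_keys)
    )
    # scan stacks outputs as (num_steps, num_chains, dim); put chains first.
    return jnp.swapaxes(history, 0, 1)


samples = run_chains(jax.random.key(0), rw_step, jnp.zeros((4, 2)), 50)
jax.effects_barrier()  # wait for pending host callbacks before reading progress_ticks
print(samples.shape, len(progress_ticks))
```

Because the host callback never ends up under vmap, the vmap-of-cond restriction from examples (A) and (B) does not come into play here. A real implementation would also need to handle general pytree states and the single-chain case, which this sketch leaves out.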
Scope
Confirmed on window_adaptation.run (warmup) and run_inference_algorithm (sampling); both route through gen_scan_fn. window_adaptation is just one tested example — the other gen_scan_fn-based utilities (window_adaptation_low_rank, pathfinder / multipathfinder, meads, chees, mclmc / adjusted_mclmc tuning) are not yet tested but very likely want the same treatment.
Environment
- jax 0.10.0
- blackjax 1.6.dev81+ga8998469a
- CPU