Summary
ensemble_execute_fn in eca.py passes args to shard_map(F, ..., in_specs=(p, p)) via closure rather than as an explicit shard_map argument with a matching in_specs. On a single device this is benign (trivial single-shard mapping). On N>1 devices, args can be a sharded JAX array (see crash path below) and shard_map raises NotImplementedError because it encounters a closed-over array with NamedSharding that doesn't have explicit sharding declared in in_specs.
Crash path (LAPS burn-in, 4 devices)
# laps_burn_in.py:134
initial_state, equipartition = ensemble_execute_fn(
    sequential_init, key1, num_chains, mesh, ...
)
# equipartition has out_specs=pscalar → NamedSharding(mesh, PartitionSpec())
# It is a *sharded JAX array*.
flat_equi, _ = ravel_pytree(equipartition)
signs = -2.0 * (flat_equi < 1.0) + 1.0
# signs inherits NamedSharding from equipartition, so it is still sharded.
initial_state, _ = ensemble_execute_fn(
    ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs, ...
)
# args=signs: signs is now a sharded JAX array passed as args.
Inside ensemble_execute_fn (eca.py:301–306):
def F(x, keys):
    y, summary_statistics = _F((x, args), (None, keys, None))[0]
    return y, summary_statistics

parallel_execute = shard_map(F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar))
# in_specs=(p, p) covers only x and keys; args is NOT in in_specs!
On N=4 devices, shard_map encounters signs (a JAX array with NamedSharding(mesh, PartitionSpec())) captured in the closure of F. This triggers:
NotImplementedError: shard_map does not support closed-over JAX arrays with NamedSharding
The traceback would point to ensemble_init (which is called inside F) because that's where the actual compute using args/signs happens.
Secondary gap: degenerate test_laps hides the multi-device bug
The only LAPS test uses jax.devices()[:1] (always a 1-device mesh):
mesh=jax.sharding.Mesh(jax.devices()[:1], "chains"),
On a 1-device mesh, shard_map with NamedSharding is trivially a no-op — no partitioning occurs and the closed-over sharding issue is never exercised. A real multi-device test would catch both the keys_sampling.T bug (fixed in #931) and this closure issue.
Proposed fix
Pass args as an explicit shard_map argument with a matching PartitionSpec. Since args (= signs) is a global-broadcast array (not per-chain), it should use pscalar:
# args=None cannot be a shard_map argument, so substitute a scalar placeholder.
use_placeholder = args is None
args_for_shard = jnp.array(0.0) if use_placeholder else args
args_spec = pscalar  # signs is broadcast (same for all devices)

def F(x, keys, args_local):
    # Hand _F the original None rather than the placeholder value.
    args_in = None if use_placeholder else args_local
    y, summary_statistics = _F((x, args_in), (None, keys, None))[0]
    return y, summary_statistics

parallel_execute = shard_map(
    F, mesh=mesh,
    in_specs=(p, p, args_spec),
    out_specs=(p, pscalar),
)
return parallel_execute(X, keys, args_for_shard)
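The same pattern can be tried in isolation with a toy kernel. The following is a minimal sketch, not the eca.py code: toy_step, run_with_explicit_args, and the array shapes are hypothetical stand-ins, and the import fallback assumes shard_map is available either as jax.shard_map (newer JAX) or as jax.experimental.shard_map.shard_map (older JAX). It passes the broadcast array through in_specs with an empty PartitionSpec instead of closing over it.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), "chains")
p = P("chains")  # per-chain data, split along the leading axis
pscalar = P()    # replicated data, identical on every device


def toy_step(x, signs):
    # Hypothetical stand-in for ensemble_init: flip coordinates by sign.
    return x * signs


def run_with_explicit_args(x, args):
    def F(x_local, args_local):
        y = toy_step(x_local, args_local)
        # Per-shard mean, averaged across devices so the result is replicated.
        summary = jax.lax.pmean(jnp.mean(y), axis_name="chains")
        return y, summary

    parallel_execute = shard_map(
        F, mesh=mesh, in_specs=(p, pscalar), out_specs=(p, pscalar)
    )
    return parallel_execute(x, args)


num_chains = 4 * len(jax.devices())  # divisible by the mesh size
x = jnp.ones((num_chains, 3))
signs = jnp.array([1.0, -1.0, 1.0])
y, summary = run_with_explicit_args(x, signs)
```

Because signs enters through in_specs, shard_map knows how it is laid out on the mesh and F no longer captures an array from the enclosing scope.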
Complications to address:
- The args type is Any: callers may pass None, a numpy array, a NamedTuple, or a JAX array. This needs type narrowing or an explicit args_partitionspec parameter.
- The None sentinel requires special handling (cannot be a shard_map argument directly).
- When args is per-chain data, args_spec=p would be appropriate instead of pscalar.
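One way to handle these three cases together is sketched below. It is an illustration under assumptions, not a proposed patch: prepare_args, execute, and toy_kernel are hypothetical names, the args_spec keyword plays the role of the args_partitionspec parameter mentioned above, and the import fallback assumes one of the two known shard_map locations exists in the installed JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), "chains")
p = P("chains")  # per-chain: split along the leading axis
pscalar = P()    # broadcast: same value on every device


def prepare_args(args, args_spec=None):
    """Return (value, spec, is_placeholder) suitable for shard_map."""
    if args is None:
        # Scalar placeholder; the wrapped function must ignore it.
        return jnp.zeros(()), pscalar, True
    # Convert every leaf (numpy array, Python scalar, ...) to a JAX array.
    value = jax.tree_util.tree_map(jnp.asarray, args)
    return value, (pscalar if args_spec is None else args_spec), False


def execute(kernel, x, args=None, args_spec=None):
    value, spec, is_placeholder = prepare_args(args, args_spec)

    def F(x_local, args_local):
        # The kernel sees None again when the caller passed no args.
        return kernel(x_local, None if is_placeholder else args_local)

    return shard_map(F, mesh=mesh, in_specs=(p, spec), out_specs=p)(x, value)


def toy_kernel(x_local, args):
    # Hypothetical stand-in for the per-chain update.
    return x_local if args is None else x_local * args


num_chains = 4 * len(jax.devices())
x = jnp.ones((num_chains, 3))
per_chain = jnp.arange(num_chains, dtype=jnp.float32)[:, None]

y_none = execute(toy_kernel, x)
y_broadcast = execute(toy_kernel, x, args=np.array([1.0, -1.0, 1.0]))
y_per_chain = execute(toy_kernel, x, args=per_chain, args_spec=p)
```

Defaulting to the broadcast spec keeps existing callers working, while per-chain callers opt in explicitly. The spec is compared with "is None" rather than by truthiness because an empty PartitionSpec may evaluate as falsy.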
Multi-device test approach
# Add to test_laps or a new test_laps_multidevice:
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import jax  # must be imported only after XLA_FLAGS is set
mesh = jax.sharding.Mesh(jax.devices(), "chains") # all 4 simulated devices
# ... rest of test identical to test_laps ...
This simulates 4 devices on CPU and would catch both the keys_sampling.T shape corruption (Issue #931, now fixed) and this shard_map closure issue.
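The flag only takes effect if it is in the environment before JAX initialises its backend, which is hard to guarantee when other test modules in the same process have already used JAX. One way around this, sketched below, is to run the multi-device check in a fresh interpreter. run_in_simulated_devices is a hypothetical helper name, and the one-line snippet is a placeholder for the real test body.

```python
import os
import subprocess
import sys


def run_in_simulated_devices(snippet, num_devices=4):
    """Run `snippet` in a fresh interpreter that sees `num_devices` CPU devices."""
    env = dict(os.environ)
    # Overwrites any XLA_FLAGS already set in the parent environment.
    env["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={num_devices}"
    env["JAX_PLATFORMS"] = "cpu"  # keep the run on the simulated host devices
    return subprocess.run(
        [sys.executable, "-c", snippet],
        env=env, capture_output=True, text=True,
    )


# Placeholder body: the real test would build the mesh and run the LAPS burn-in here.
result = run_in_simulated_devices("import jax; print(jax.device_count())")
```

A test can then assert on result.returncode and inspect result.stderr for the traceback if the child process fails.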
Related
- keys_sampling.T → jnp.swapaxes(keys_sampling, 0, 1) (old-style PRNGKey array shape corruption)
- test_laps because it uses a 1-device mesh that exercises no real multi-device sharding.