
eca.py: ensemble_execute_fn closes over sharded JAX arrays in shard_map — NotImplementedError on N>1 devices #932

Description

@junpenglao

Summary

ensemble_execute_fn in eca.py does not pass args to shard_map(F, ..., in_specs=(p, p)) as an explicit argument with a matching in_specs entry. Instead, F captures args from the enclosing scope. On a single device this is benign, because the mesh has only one shard. On N>1 devices, args can be a JAX array carrying a NamedSharding over the mesh (see the crash path below). shard_map then raises NotImplementedError because it encounters a closed-over array with a NamedSharding whose layout is not declared in in_specs.

Crash path (LAPS burn-in, 4 devices)

```python
# laps_burn_in.py:134
initial_state, equipartition = ensemble_execute_fn(
    sequential_init, key1, num_chains, mesh, ...
)
# equipartition has out_specs=pscalar → NamedSharding(mesh, PartitionSpec())
# It is a JAX array committed to every device in the mesh (fully replicated).

flat_equi, _ = ravel_pytree(equipartition)
signs = -2.0 * (flat_equi < 1.0) + 1.0
# signs inherits the sharding of equipartition, so it is still a mesh-wide JAX array.

initial_state, _ = ensemble_execute_fn(
    ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs, ...
)
# signs, a mesh-wide JAX array, is now passed as args.
```
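The sharding inheritance in the crash path can be checked in isolation. The following is a minimal sketch. It assumes a stand-in `equipartition` dict with made-up values (not the real LAPS output) and simulates 4 CPU devices via XLA_FLAGS. After ravel_pytree and the arithmetic, `signs` is still a committed jax.Array spread over every mesh device rather than a host-side value.

```python
import os

# Simulate 4 CPU devices; this must happen before JAX initializes its backends.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("chains",))
replicated = NamedSharding(mesh, PartitionSpec())  # the layout pscalar produces

# Hypothetical stand-in for `equipartition`; the values are illustrative only.
equipartition = {"x": jax.device_put(jnp.array([0.5, 2.0, 0.9]), replicated)}

flat_equi, _ = ravel_pytree(equipartition)
signs = -2.0 * (flat_equi < 1.0) + 1.0

# signs is still a jax.Array laid out over every device in the mesh.
print(signs, signs.sharding.is_fully_replicated, len(signs.sharding.device_set))
```

On a machine with a GPU backend, jax.devices() lists the GPUs instead of the simulated CPU devices, so the device count will differ.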

Inside ensemble_execute_fn (eca.py:301–306):

```python
def F(x, keys):
    # args is read from the enclosing scope, not passed as an input of F
    y, summary_statistics = _F((x, args), (None, keys, None))[0]
    return y, summary_statistics

parallel_execute = shard_map(F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar))
# in_specs covers only (x, keys): args is NOT in in_specs!
```

On N=4 devices, shard_map encounters signs (a JAX array with NamedSharding(mesh, PartitionSpec())) captured in the closure of F. This triggers:

```
NotImplementedError: shard_map does not support closed-over JAX arrays with NamedSharding
```

The traceback would point to ensemble_init (which is called inside F) because that's where the actual compute using args/signs happens.

Secondary gap: degenerate test_laps hides the multi-device bug

The only LAPS test uses jax.devices()[:1] (always a 1-device mesh):

```python
mesh=jax.sharding.Mesh(jax.devices()[:1], "chains"),
```

On a 1-device mesh, shard_map partitions nothing: every block is the whole array, so the closed-over sharding issue is never exercised. A real multi-device test would catch both the keys_sampling.T bug (fixed in #931) and this closure issue.
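The reason the 1-device test is blind can be shown directly. Inside shard_map, the body receives the per-device block, and on a 1-device mesh that block is the entire array. Shape-dependent mistakes in the body therefore only surface when the block is smaller than the global array. The following is a sketch, assuming simulated CPU devices via XLA_FLAGS and a hypothetical helper `local_block_shape` that records the block shape at trace time.

```python
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases


def local_block_shape(devices, num_chains=8, dim=3):
    """Return the per-device block shape that the shard_map body sees."""
    mesh = Mesh(np.array(devices), ("chains",))
    seen = []

    def body(x):
        seen.append(x.shape)  # recorded while tracing: the local block shape
        return x

    f = jax.jit(
        shard_map(body, mesh=mesh, in_specs=(P("chains"),), out_specs=P("chains"))
    )
    f(jnp.zeros((num_chains, dim)))
    return seen[-1]


print(local_block_shape(jax.devices()[:1]))  # 1-device mesh, as in test_laps
print(local_block_shape(jax.devices()))  # every available device
```

With 4 simulated devices, the second call sees blocks of 2 chains instead of 8. This is the situation test_laps never reaches.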

Proposed fix

Pass args as an explicit shard_map argument with a matching PartitionSpec. Since args (= signs) is a global-broadcast array (not per-chain), it should use pscalar:

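```python
has_args = args is not None  # plain Python bool, fixed before tracing

def F(x, keys, args_local):
    # Restore None for callers that passed args=None (see the placeholder below).
    args_local = args_local if has_args else None
    y, summary_statistics = _F((x, args_local), (None, keys, None))[0]
    return y, summary_statistics

# Handle args=None with a scalar placeholder; F swaps it back to None.
args_for_shard = args if has_args else jnp.array(0.0)
args_spec = pscalar  # signs is broadcast (same for all devices)

parallel_execute = shard_map(
    F, mesh=mesh,
    in_specs=(p, p, args_spec),
    out_specs=(p, pscalar),
)
return parallel_execute(X, keys, args_for_shard)
```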

Complications to address:

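  1. args is typed Any: callers may pass None, a numpy array, a NamedTuple, or a JAX array. Because each in_specs entry is a pytree prefix, one PartitionSpec already covers every leaf of a NamedTuple. What is missing is a way to choose the right spec, via type narrowing or an explicit args_partitionspec parameter.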
  2. The None sentinel requires special handling (cannot be a shard_map argument directly).
  3. When args is per-chain data, args_spec=p would be appropriate instead of pscalar.
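The proposed fix and all three complications can be exercised together in a self-contained toy. This is a sketch, not the eca.py code. `toy_execute`, `toy_step`, and `LapsArgs` are hypothetical stand-ins for ensemble_execute_fn, _F, and whatever container callers pass as args. It assumes 4 simulated CPU devices.

The sketch does the following:

- It passes args explicitly with pscalar as the default spec.
- It substitutes a scalar placeholder for None.
- It converts numpy leaves with tree_map.
- It accepts `args_spec=p` for per-chain args.

```python
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

p = PartitionSpec("chains")  # per-chain data, split along the chain axis
pscalar = PartitionSpec()  # broadcast data, one full copy per device


class LapsArgs(NamedTuple):  # hypothetical container for args
    signs: jax.Array


def toy_step(x, keys, args):
    """Hypothetical per-device body standing in for _F."""
    noise = jax.vmap(lambda k: jax.random.normal(k, x.shape[1:]))(keys)
    y = x + noise
    if args is not None:
        y = y * args.signs  # broadcasts for (dim,) and (local_chains, dim)
    # Average over all chains so the statistic is identical on every device,
    # which out_specs=pscalar requires.
    stat = jax.lax.pmean(jnp.mean(y**2), "chains")
    return y, stat


def toy_execute(step, X, keys, mesh, args: Any = None, args_spec=pscalar):
    """Hypothetical sketch of ensemble_execute_fn with args as an explicit input."""
    has_args = args is not None  # plain Python bool, fixed before tracing

    def F(x, k, args_local):
        return step(x, k, args_local if has_args else None)

    # Following the issue, pass a scalar placeholder instead of None; F ignores it.
    # tree_map converts numpy leaves (also inside NamedTuples) to JAX arrays.
    if has_args:
        args_for_shard = jax.tree_util.tree_map(jnp.asarray, args)
    else:
        args_for_shard, args_spec = jnp.zeros(()), pscalar  # a scalar cannot be split

    parallel_execute = shard_map(
        F, mesh=mesh, in_specs=(p, p, args_spec), out_specs=(p, pscalar)
    )
    return jax.jit(parallel_execute)(X, keys, args_for_shard)


mesh = Mesh(np.array(jax.devices()), ("chains",))
num_chains, dim = 2 * len(jax.devices()), 3
X = jnp.ones((num_chains, dim))
keys = jax.random.split(jax.random.PRNGKey(0), num_chains)

# args=None, broadcast args (pscalar), and per-chain args (p), same keys each time.
y_none, s_none = toy_execute(toy_step, X, keys, mesh)
y_bcast, s_bcast = toy_execute(
    toy_step, X, keys, mesh, args=LapsArgs(-np.ones(dim, np.float32))
)
y_chain, _ = toy_execute(
    toy_step, X, keys, mesh, args=LapsArgs(-jnp.ones((num_chains, dim))), args_spec=p
)
```

Wrapping the shard_map call in jax.jit is a design choice of this sketch, not something the issue prescribes. It lets JAX place uncommitted inputs according to in_specs before the per-device body runs.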

Multi-device test approach

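```python
# Put this at the very top of a new test module (e.g. test_laps_multidevice)
# or in conftest.py. XLA_FLAGS is read when JAX initializes its backends, so it
# must be set before jax is first imported; importing jax again later is a no-op.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import jax

mesh = jax.sharding.Mesh(jax.devices(), "chains")  # all 4 simulated devices
# ... rest of test identical to test_laps ...
```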

This simulates 4 devices on CPU and would catch both the keys_sampling.T shape corruption (Issue #931, now fixed) and this shard_map closure issue.
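Inside a shared pytest session, another test module may already have initialized JAX with a single device, and then XLA_FLAGS is silently ignored. One way around this is to run the multi-device check in a fresh interpreter. The following is a sketch with a hypothetical helper `run_with_simulated_devices`. `CHILD_SCRIPT` only reports the device count here; a real test would put the multi-device LAPS checks in it.

```python
import os
import subprocess
import sys
import textwrap

# Code executed in a fresh interpreter (placeholder for the real LAPS checks).
CHILD_SCRIPT = textwrap.dedent(
    """
    import jax
    print(jax.device_count())
    """
)


def run_with_simulated_devices(script: str, n: int = 4) -> str:
    """Run `script` in a child process that sees `n` simulated CPU devices."""
    env = dict(os.environ)
    flag = f"--xla_force_host_platform_device_count={n}"
    env["XLA_FLAGS"] = f"{env.get('XLA_FLAGS', '')} {flag}".strip()
    env["JAX_PLATFORMS"] = "cpu"  # keep the child on the host backend
    result = subprocess.run(
        [sys.executable, "-c", script],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # a crash in the child fails the calling test
    )
    return result.stdout


if __name__ == "__main__":
    print(run_with_simulated_devices(CHILD_SCRIPT, n=4))
```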
