From a853c33b4f02d01452330c3e93f0710446ec35d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=89=BE=E5=84=92?= Date: Tue, 9 Dec 2025 20:31:14 +0800 Subject: [PATCH 1/6] add async comm --- mplang/v1/core/async_comm.py | 365 +++++++++++++ mplang/v1/core/expr/ast.py | 39 +- mplang/v1/core/expr/async_evaluator.py | 514 ++++++++++++++++++ mplang/v1/core/expr/visitor.py | 48 ++ mplang/v1/runtime/communicator.py | 83 +++ mplang/v1/runtime/simulation.py | 170 +++--- tests/v1/core/test_async_comm.py | 88 +++ tests/v1/core/test_async_simulation.py | 233 ++++++++ .../v1/device/02_simulation_and_driver.py | 6 +- 9 files changed, 1471 insertions(+), 75 deletions(-) create mode 100644 mplang/v1/core/async_comm.py create mode 100644 mplang/v1/core/expr/async_evaluator.py create mode 100644 tests/v1/core/test_async_comm.py create mode 100644 tests/v1/core/test_async_simulation.py diff --git a/mplang/v1/core/async_comm.py b/mplang/v1/core/async_comm.py new file mode 100644 index 00000000..c5dc6eef --- /dev/null +++ b/mplang/v1/core/async_comm.py @@ -0,0 +1,365 @@ +# Copyright 2025 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +from abc import ABC, abstractmethod +from typing import Any + +from mplang.v1.core.mask import Mask + + +class IAsyncCommunicator(ABC): + """Base class for asynchronous communicators.""" + + @property + @abstractmethod + def rank(self) -> int: + """Get the rank of this process""" + + @property + @abstractmethod + def world_size(self) -> int: + """Get the world size of this process""" + + @abstractmethod + def new_id(self) -> str: + """Must be implemented by mixing class""" + raise NotImplementedError + + @abstractmethod + async def send(self, to: int, key: str, data: Any) -> None: + """Send data to peer with the given key asynchronously""" + + @abstractmethod + async def recv(self, frm: int, key: str) -> Any: + """Receive data from peer with the given key asynchronously""" + + @abstractmethod + def onSent(self, frm: int, key: str, data: Any) -> None: + """Called when a key is sent to self. + + This is typically called by the underlying transport layer (possibly from another thread). + It should be non-blocking and thread-safe. + """ + + +class IAsyncCollective(ABC): + """Interface for asynchronous collective communication""" + + @abstractmethod + async def p2p(self, frm: int, to: int, data: Any) -> Any: + """Perform point-to-point communication""" + + @abstractmethod + async def gather(self, root: int, data: Any) -> list[Any]: + """Gather data from all processes to root""" + + @abstractmethod + async def gather_m(self, pmask: int, root: int, data: Any) -> list[Any]: + """Gather data from parties in pmask to root""" + + @abstractmethod + async def scatter(self, root: int, args: list[Any]) -> Any: + """Scatter data from root to all processes""" + + @abstractmethod + async def scatter_m(self, pmask: int, root: int, args: list[Any]) -> Any: + """Scatter data from root to parties in pmask""" + + @abstractmethod + async def allgather(self, arg: Any) -> list[Any]: + """Gather data from all processes to all processes""" + + @abstractmethod + async def allgather_m(self, pmask: int, arg: Any) -> list[Any]: + """Gather data from parties in pmask to all processes""" + + @abstractmethod + async def bcast(self, root: int, arg: Any) -> Any: + """Broadcast data from root to all processes""" + + @abstractmethod + async def bcast_m(self, pmask: int, root: int, arg: Any) -> Any: + """Broadcast data from root to parties in pmask""" + + +class AsyncCollectiveMixin(IAsyncCommunicator, IAsyncCollective): + """Mixin class providing default implementations of asynchronous collective communication algorithms""" + + # Note: These will be provided by mixing classes as properties + @property + def rank(self) -> int: + raise NotImplementedError + + @property + def world_size(self) -> int: + raise NotImplementedError + + async def send(self, to: int, key: str, data: Any) -> None: + raise NotImplementedError + + async def recv(self, frm: int, key: str) -> Any: + raise NotImplementedError + + def new_id(self) -> str: + raise NotImplementedError + + async def p2p(self, frm: int, to: int, data: Any) -> Any: + assert 0 <= frm < self.world_size + assert 0 <= to < self.world_size + + cid = self.new_id() + + send_coro = None + if self.rank == frm: + send_coro = self.send(to, cid, data) + + recv_coro = None + if self.rank == to: + recv_coro = self.recv(frm, cid) + + if send_coro and recv_coro: + _, res = await asyncio.gather(send_coro, recv_coro) + return res + elif send_coro: + await send_coro + return None + elif recv_coro: + return await recv_coro + else: + return None + + async def gather_m(self, pmask: int, root: int, data: Any) -> list[Any]: + assert 0 <= root < self.world_size + cid = self.new_id() + mask = Mask(pmask) + + # 1. Send if we are in mask + if self.rank in mask: + await self.send(root, cid, data) + + # 2. Recv if we are root + if self.rank == root: + # Create futures for all expected receives + futures = [] + for idx in mask: + futures.append(self.recv(idx, cid)) + + # Wait for all concurrently + results = await asyncio.gather(*futures) + return results + else: + return [None] * mask.num_parties() + + async def gather(self, root: int, data: Any) -> list[Any]: + pmask = Mask.all(self.world_size) + return await self.gather_m(pmask.value, root, data) + + async def scatter_m(self, pmask: int, root: int, args: list[Any]) -> Any: + logging.debug( + f"[{self.rank}]: scatter_m: pmask={pmask}, root={root}, args={args}" + ) + assert 0 <= root < self.world_size + mask = Mask(pmask) + assert len(args) == mask.num_parties(), f"{len(args)} != {mask.num_parties()}" + + cid = self.new_id() + + if self.rank == root: + # Send to all targets concurrently + send_futures = [] + for idx, arg in zip(mask, args, strict=True): + send_futures.append(self.send(idx, cid, arg)) + await asyncio.gather(*send_futures) + + if self.rank in mask: + data = await self.recv(root, cid) + else: + data = None + + return data + + async def scatter(self, root: int, args: list[Any]) -> Any: + pmask = Mask.all(self.world_size) + return await self.scatter_m(pmask.value, root, args) + + async def allgather_m(self, pmask: int, arg: Any) -> list[Any]: + logging.debug(f"allgather_m: pmask={pmask}, arg={arg}") + cid = self.new_id() + mask = Mask(pmask) + + # 1. Send to all other parties in mask + if self.rank in mask: + send_futures = [] + for idx in mask: + send_futures.append(self.send(idx, cid, arg)) + await asyncio.gather(*send_futures) + + # 2. Recv from all parties in mask + recv_futures = [] + for idx in mask: + recv_futures.append(self.recv(idx, cid)) + + res = await asyncio.gather(*recv_futures) + return res + else: + return [None] * mask.num_parties() + + async def allgather(self, arg: Any) -> list[Any]: + pmask = Mask.all(self.world_size) + return await self.allgather_m(pmask.value, arg) + + async def bcast_m(self, pmask: int, root: int, arg: Any) -> Any: + logging.debug(f"bcast_m: pmask={pmask}, root={root}, arg={arg}") + assert 0 <= root < self.world_size + mask = Mask(pmask) + cid = self.new_id() + + if self.rank == root: + send_futures = [] + for idx in mask: + send_futures.append(self.send(idx, cid, arg)) + await asyncio.gather(*send_futures) + + if self.rank in mask: + return await self.recv(root, cid) + else: + return None + + async def bcast(self, root: int, arg: Any) -> Any: + pmask = Mask.all(self.world_size) + return await self.bcast_m(pmask.value, root, arg) + + +class AsyncCommunicatorBase(IAsyncCommunicator): + """Base implementation providing message box functionality for local communication using asyncio""" + + def __init__( + self, rank: int, world_size: int, loop: asyncio.AbstractEventLoop | None = None + ): + self._rank = rank + self._world_size = world_size + # Map (frm, key) -> Future or Data + self._msgboxes: dict[tuple[int, str], Any | asyncio.Future] = {} + self._counter = 0 + self._loop = loop + + @property + def rank(self) -> int: + return self._rank + + @property + def world_size(self) -> int: + return self._world_size + + def _get_loop(self) -> asyncio.AbstractEventLoop: + if self._loop is None: + try: + self._loop = asyncio.get_running_loop() + except RuntimeError as e: + raise RuntimeError( + "AsyncCommunicatorBase must be used within an asyncio event loop or loop must be provided in init" + ) from e + return self._loop + + def new_id(self) -> str: + # Simple counter, assuming single-threaded access to this method within the loop + res = self._counter + self._counter += 1 + return str(res) + + async def recv(self, frm: int, key: str) -> Any: + """Wait until the key is set, returns the value""" + mkey = (frm, key) + + # Check if data is already there + if mkey in self._msgboxes: + val = self._msgboxes[mkey] + if isinstance(val, asyncio.Future): + # Already waiting? This shouldn't happen in normal logic unless multiple recvs for same key + return await val + else: + # Data arrived before recv + del self._msgboxes[mkey] + return val + + # Not there, create a future + loop = self._get_loop() + fut = loop.create_future() + self._msgboxes[mkey] = fut + try: + return await fut + finally: + if mkey in self._msgboxes and self._msgboxes[mkey] is fut: + del self._msgboxes[mkey] + + def onSent(self, frm: int, key: str, data: Any) -> None: + """Called when a key is sent to self. + + This method must be thread-safe as it might be called from network threads. + """ + loop = self._get_loop() + # Use call_soon_threadsafe to handle calls from other threads (e.g. network callbacks) + # If called from the same loop, it just schedules it for next iteration. + loop.call_soon_threadsafe(self._on_sent_internal, frm, key, data) + + def _on_sent_internal(self, frm: int, key: str, data: Any) -> None: + mkey = (frm, key) + if mkey in self._msgboxes: + val = self._msgboxes[mkey] + if isinstance(val, asyncio.Future): + if not val.done(): + val.set_result(data) + # Future is done, we can remove it from msgboxes? + # No, recv needs to await it. But recv will remove it after await. + # Wait, if we remove it here, recv might fail if it hasn't awaited yet? + # Actually, once set_result is called, the future holds the value. + # We should remove it from _msgboxes so it doesn't grow forever? + # But recv uses mkey to find the future. + # So we leave it there. recv will remove it. + else: + raise RuntimeError(f"Duplicate message for {mkey}") + else: + self._msgboxes[mkey] = data + + async def send(self, to: int, key: str, data: Any) -> None: + # Base implementation for local simulation: directly call peer's onSent + # In a real distributed setting, this would put data on wire. + raise NotImplementedError( + "Must be implemented by subclass or mixin with peer awareness" + ) + + +class AsyncThreadCommunicator(AsyncCommunicatorBase, AsyncCollectiveMixin): + """Thread-based async communicator for in-memory communication (simulation)""" + + def __init__( + self, rank: int, world_size: int, loop: asyncio.AbstractEventLoop | None = None + ): + super().__init__(rank, world_size, loop) + self.peers: list[AsyncThreadCommunicator] = [] + + def set_peers(self, peers: list[AsyncThreadCommunicator]) -> None: + assert self.world_size == len(peers) + self.peers = peers + + async def send(self, to: int, key: str, data: Any) -> None: + assert 0 <= to < self.world_size + # In local simulation, we can directly call peer's onSent. + # Since we are all in the same process (and likely same loop for simulation), + # we can just call it. + self.peers[to].onSent(self.rank, key, data) diff --git a/mplang/v1/core/expr/ast.py b/mplang/v1/core/expr/ast.py index 39413af3..09548b99 100644 --- a/mplang/v1/core/expr/ast.py +++ b/mplang/v1/core/expr/ast.py @@ -34,7 +34,7 @@ from mplang.v1.core.tensor import TensorType if TYPE_CHECKING: - from mplang.v1.core.expr.visitor import ExprVisitor + from mplang.v1.core.expr.visitor import AsyncExprVisitor, ExprVisitor class Expr(ABC): @@ -84,6 +84,10 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: """Accept a visitor for the visitor pattern.""" + @abstractmethod + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + """Accept an async visitor with environment.""" + # ============================================================================ # Concrete Expression Classes @@ -161,6 +165,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_eval(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_eval(self, env) + class TupleExpr(Expr): """Expression for creating a tuple from multiple single-output expressions. @@ -204,6 +211,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_tuple(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_tuple(self, env) + class CondExpr(Expr): """Expression for conditional execution. @@ -240,6 +250,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_cond(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_cond(self, env) + class WhileExpr(Expr): """Expression for while loop.""" @@ -266,6 +279,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_while(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_while(self, env) + class ConvExpr(Expr): """Expression for convergence of multiple variables.""" @@ -321,6 +337,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_conv(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_conv(self, env) + class ShflSExpr(Expr): """Expression for static shuffle operation. @@ -403,6 +422,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_shfl_s(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_shfl_s(self, env) + class ShflExpr(Expr): """Expression for dynamic shuffle operation.""" @@ -427,6 +449,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_shfl(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_shfl(self, env) + class AccessExpr(Expr): """Expression for accessing a specific output of a multi-output expression. @@ -457,6 +482,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_access(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_access(self, env) + class VariableExpr(Expr): """Expression for variable reference/lookup.""" @@ -473,6 +501,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_variable(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_variable(self, env) + class FuncDefExpr(Expr): """Expression representing a function definition with parameters and body. @@ -522,6 +553,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_func_def(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_func_def(self, env) + class CallExpr(Expr): """Expression for function call.""" @@ -540,3 +574,6 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_call(self) + + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_call(self, env) diff --git a/mplang/v1/core/expr/async_evaluator.py b/mplang/v1/core/expr/async_evaluator.py new file mode 100644 index 00000000..a9e03256 --- /dev/null +++ b/mplang/v1/core/expr/async_evaluator.py @@ -0,0 +1,514 @@ +# Copyright 2025 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from concurrent.futures import Executor +from dataclasses import dataclass +from typing import Any + +from mplang.v1.core.async_comm import IAsyncCommunicator +from mplang.v1.core.expr.ast import ( + AccessExpr, + CallExpr, + CondExpr, + ConvExpr, + EvalExpr, + Expr, + FuncDefExpr, + ShflExpr, + ShflSExpr, + TupleExpr, + VariableExpr, + WhileExpr, +) +from mplang.v1.core.expr.evaluator import EvalSemantic +from mplang.v1.core.expr.visitor import AsyncExprVisitor +from mplang.v1.core.expr.walk import walk_dataflow +from mplang.v1.core.mask import Mask +from mplang.v1.core.pfunc import PFunction +from mplang.v1.kernels.value import Value + + +@dataclass +class AsyncEvalSemantic(EvalSemantic): + """Async version of EvalSemantic. + + Reuses pure computation logic from EvalSemantic but overrides I/O bound methods + to use IAsyncCommunicator. + """ + + comm: IAsyncCommunicator # Override type hint + executor: Executor | None = None + + async def _exec_pfunc_async(self, pfunc: PFunction, args: list[Any]) -> list[Any]: + # Check if any args are None - if so, this rank shouldn't participate + # This prevents None values from reaching kernel validation + if any(arg is None for arg in args): + return [None] * len(pfunc.outs_info) + + if self.executor: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + self.executor, self._exec_pfunc, pfunc, args + ) + else: + return self._exec_pfunc(pfunc, args) + + async def _eval_eval_node_async( + self, expr: EvalExpr, arg_vals: list[Any] + ) -> list[Any]: + assert isinstance(expr.pfunc, PFunction) + if not self._should_run(expr.rmask, arg_vals): + return [None] * len(expr.mptypes) + return await self._exec_pfunc_async(expr.pfunc, arg_vals) + + async def _eval_shfl_s_node_async( + self, expr: ShflSExpr, src_value: Any + ) -> list[Any]: + pmask = expr.pmask + src_ranks = expr.src_ranks + dst_ranks = list(Mask(pmask)) + assert len(src_ranks) == len(dst_ranks) + cid = self.comm.new_id() + + # Prepare send and recv operations separately + send_tasks = [] + recv_futures = [] + + # Send phase + for src, dst in zip(src_ranks, dst_ranks, strict=True): + if self.comm.rank == src: + send_tasks.append(self.comm.send(dst, cid, src_value)) + + # Recv phase + for src, dst in zip(src_ranks, dst_ranks, strict=True): + if self.comm.rank == dst: + recv_futures.append(self.comm.recv(src, cid)) + + # Execute all operations concurrently to avoid deadlock + all_tasks = send_tasks + recv_futures + if all_tasks: + results = await asyncio.gather(*all_tasks) + # Return only the recv results + recv_results = results[len(send_tasks) :] + if self.comm.rank in dst_ranks: + assert len(recv_results) == 1 + return recv_results + else: + # Should not happen, but handle gracefully + return [None] + else: + # This party is neither sending nor receiving + if self.comm.rank in dst_ranks: + # Destination rank but no src_ranks match? + return [None] + else: + # Not involved in this shuffle + return [None] + + async def _eval_shfl_node_async( + self, expr: ShflExpr, data: Any, idx: Any + ) -> list[Any]: + # Async version of shuffle implementation + # allgather index via send/recv + indices = [None] * self.comm.world_size + cid = self.comm.new_id() + + # Send index to all other ranks + send_tasks = [] + for dst_rank in range(self.comm.world_size): + if dst_rank != self.comm.rank: + send_tasks.append(self.comm.send(dst_rank, cid, idx)) + + # Receive index from all ranks + recv_tasks = [] + for src_rank in range(self.comm.world_size): + if src_rank != self.comm.rank: + recv_tasks.append(self.comm.recv(src_rank, cid)) + + # Wait for all operations + if send_tasks: + await asyncio.gather(*send_tasks) + if recv_tasks: + recv_results = await asyncio.gather(*recv_tasks) + for i, src_rank in enumerate([ + r for r in range(self.comm.world_size) if r != self.comm.rank + ]): + indices[src_rank] = recv_results[i] + + # Set own index + indices[self.comm.rank] = idx + + # Process indices + indices_int: list[int | None] = [self._as_optional_int(val) for val in indices] + send_pairs: list[tuple[int, int]] = [] + for dst_idx, src_idx in enumerate(indices_int): + if src_idx is not None: + send_pairs.append((src_idx, dst_idx)) + send_pairs.sort() + + # Second phase: send data according to pairs + cid = self.comm.new_id() + received_data = None + + # Send data + data_send_tasks = [] + for src_rank, dst_rank in send_pairs: + if self.comm.rank == src_rank: + data_send_tasks.append(self.comm.send(dst_rank, cid, data)) + + # Receive data + data_recv_tasks = [] + for src_rank, dst_rank in send_pairs: + if self.comm.rank == dst_rank: + data_recv_tasks.append(self.comm.recv(src_rank, cid)) + + # Wait for data operations + if data_send_tasks: + await asyncio.gather(*data_send_tasks) + if data_recv_tasks: + recv_data = await asyncio.gather(*data_recv_tasks) + # Should receive exactly one data item + received_data = recv_data[0] + + return [received_data] + + def _as_optional_int(self, val: Any) -> int | None: + """Convert a value to int if possible, preserving None.""" + val = EvalSemantic._unwrap_value(val) + if val is None: + return None + return int(val) + + async def _verify_uniform_predicate_async(self, pred: Any) -> None: + # For now, just pass + # Would need proper async implementation for uniform verification + pass + + @staticmethod + def _as_optional_int(val: Any) -> int | None: + if isinstance(val, int): + return val + if isinstance(val, Value): + if hasattr(val, "value"): + return int(val.value) + # Try to convert TensorValue using to_numpy + to_numpy = getattr(val, "to_numpy", None) + if callable(to_numpy): + arr = to_numpy() + import numpy as np + + if isinstance(arr, np.ndarray) and arr.size == 1: + return int(arr.item()) + return None + + +class AsyncRecursiveEvaluator(AsyncExprVisitor): + """Original async evaluator using recursive visitor pattern. + + This evaluator can cause stack overflow with deeply nested control flow. + Kept for reference and fallback. + """ + + def __init__(self, semantic: AsyncEvalSemantic): + self.semantic = semantic + + def _first(self, vals: list[Any]) -> Any: + if not isinstance(vals, list): + return vals + if len(vals) == 0: + return None + return vals[0] + + async def evaluate(self, expr: Expr, env: dict[str, Any] | None = None) -> Any: + evaluation_env = env if env is not None else self.semantic.env + return await expr.accept_async(self, evaluation_env) + + async def visit_cond(self, expr: CondExpr, env: dict[str, Any]) -> Any: + pred_res = await expr.pred.accept_async(self, env) + pred = self._first(pred_res) + + args_results = await self._spawn_and_gather(expr.args, env) + flat_args = [self._first(res) for res in args_results] + + if expr.verify_uniform: + await self.semantic._verify_uniform_predicate_async(pred) + + if isinstance(pred, Value): + pred_bool = pred.to_bool() + else: + pred_bool = bool(self.semantic._unwrap_value(pred)) + + if pred_bool: + new_env = {**env, **dict(zip(expr.then_fn.params, flat_args, strict=True))} + res = await expr.then_fn.body.accept_async(self, new_env) + else: + new_env = {**env, **dict(zip(expr.else_fn.params, flat_args, strict=True))} + res = await expr.else_fn.body.accept_async(self, new_env) + return res + + async def visit_call(self, expr: CallExpr, env: dict[str, Any]) -> Any: + args_results = await self._spawn_and_gather(expr.args, env) + flat_args = [self._first(res) for res in args_results] + # Bind arguments + new_env = {**env, **dict(zip(expr.fn.params, flat_args, strict=True))} + res = await expr.fn.body.accept_async(self, new_env) + return res + + async def visit_while(self, expr: WhileExpr, env: dict[str, Any]) -> Any: + curr_vals_results = await self._spawn_and_gather(expr.args, env) + curr_vals = [self._first(res) for res in curr_vals_results] + + # Determine split between state and captures + num_state = expr.body_fn.num_outputs + + # Initial state and captures + curr_state = curr_vals[:num_state] + captures = curr_vals[num_state:] + + while True: + # Reconstruct full arguments: state + captures + full_args = curr_state + captures + + # Check condition + cond_env = {**env, **dict(zip(expr.cond_fn.params, full_args, strict=True))} + cond_res = await expr.cond_fn.body.accept_async(self, cond_env) + + # Validate condition + cond_val = self.semantic._check_while_predicate(cond_res) + + if not cond_val: + break + + # Execute body + body_env = {**env, **dict(zip(expr.body_fn.params, full_args, strict=True))} + body_res = await expr.body_fn.body.accept_async(self, body_env) + + # Update state - body_res is already a list + curr_state = body_res + + return curr_state + + async def _spawn_and_gather( + self, exprs: list[Expr], env: dict[str, Any] + ) -> list[Any]: + """Spawn async tasks for multiple expressions and gather results.""" + tasks = [expr.accept_async(self, env) for expr in exprs] + return await asyncio.gather(*tasks) + + +class AsyncIterativeEvaluator(AsyncExprVisitor): + """Async evaluator using iterative traversal to avoid stack overflow. + + This evaluator follows the same pattern as the synchronous IterativeEvaluator: + 1. Uses local symbols dictionary instead of instance state + 2. Directly recurses via method calls (not Python call stack) + 3. Processes nodes in dependency order + """ + + def __init__(self, semantic: AsyncEvalSemantic): + self.semantic = semantic + + async def evaluate(self, expr: Expr, env: dict[str, Any] | None = None) -> Any: + """Entry point for evaluation.""" + evaluation_env = env if env is not None else self.semantic.env + result = await self._iter_eval_graph(expr, evaluation_env) + return result + + async def _iter_eval_graph(self, root: Expr, env: dict[str, Any]) -> list[Any]: + """Main evaluation loop using iterative traversal (async version of sync pattern).""" + symbols: dict[int, list[Any]] = {} + + # Process all nodes in dependency order + for node in walk_dataflow(root, traversal="dfs_post_iter"): + if isinstance(node, VariableExpr): + if node.name not in env: + raise ValueError( + f"Variable '{node.name}' not found in evaluator environment" + ) + symbols[id(node)] = [env[node.name]] + + elif isinstance(node, TupleExpr): + vals = [self._first(symbols[id(a)]) for a in node.args] + symbols[id(node)] = vals + + elif isinstance(node, AccessExpr): + src_vals = symbols[id(node.src)] + symbols[id(node)] = [src_vals[node.index]] + + elif isinstance(node, CallExpr): + arg_vals = [self._first(symbols[id(a)]) for a in node.args] + assert isinstance(node.fn, FuncDefExpr) + sub_env = dict(zip(node.fn.params, arg_vals, strict=True)) + # Recursive method call - not Python call stack recursion! + res = await self._iter_eval_graph(node.fn.body, {**env, **sub_env}) + symbols[id(node)] = res + + elif isinstance(node, CondExpr): + pred_val = self._first(symbols[id(node.pred)]) + arg_vals = [self._first(symbols[id(a)]) for a in node.args] + + if pred_val is None: + symbols[id(node)] = [None] * len(node.mptypes) + else: + # Optional uniform verification + if node.verify_uniform: + await self.semantic._verify_uniform_predicate_async(pred_val) + + # Convert to bool + if isinstance(pred_val, Value): + pred = pred_val.to_bool() + else: + pred = bool(self.semantic._unwrap_value(pred_val)) + + if pred: + sub_env = dict(zip(node.then_fn.params, arg_vals, strict=True)) + # Recursive method call + res = await self._iter_eval_graph( + node.then_fn.body, {**env, **sub_env} + ) + symbols[id(node)] = res + else: + sub_env = dict(zip(node.else_fn.params, arg_vals, strict=True)) + # Recursive method call + res = await self._iter_eval_graph( + node.else_fn.body, {**env, **sub_env} + ) + symbols[id(node)] = res + + elif isinstance(node, WhileExpr): + state = [self._first(symbols[id(a)]) for a in node.args] + while True: + cond_env = dict(zip(node.cond_fn.params, state, strict=True)) + # Recursive method call for condition + cond_vals = await self._iter_eval_graph( + node.cond_fn.body, {**env, **cond_env} + ) + cond_val = self.semantic._check_while_predicate(cond_vals) + if not bool(cond_val): + break + + body_env = dict(zip(node.body_fn.params, state, strict=True)) + # Recursive method call for body + new_state = await self._iter_eval_graph( + node.body_fn.body, {**env, **body_env} + ) + state = self._merge_state(state, new_state) + symbols[id(node)] = state[0 : len(node.body_fn.mptypes)] + + elif isinstance(node, EvalExpr): + arg_vals = [self._first(symbols[id(a)]) for a in node.args] + symbols[id(node)] = await self.semantic._eval_eval_node_async( + node, arg_vals + ) + + elif isinstance(node, ConvExpr): + vars_vals = [self._first(symbols[id(v)]) for v in node.vars] + # ConvExpr needs async implementation + symbols[id(node)] = await self._eval_conv_node_async(node, vars_vals) + + elif isinstance(node, ShflSExpr): + value = self._first(symbols[id(node.src_val)]) + symbols[id(node)] = await self.semantic._eval_shfl_s_node_async( + node, value + ) + + elif isinstance(node, ShflExpr): + data = self._first(symbols[id(node.src)]) + index = self._first(symbols[id(node.index)]) + symbols[id(node)] = await self.semantic._eval_shfl_node_async( + node, data, index + ) + + elif isinstance(node, FuncDefExpr): + # FuncDefExpr should not be directly evaluated + raise RuntimeError("FuncDefExpr should not be directly evaluated") + else: + raise NotImplementedError(f"Unsupported expression type: {type(node)}") + + return symbols[id(root)] + + @staticmethod + def _first(vals: list[Any]) -> Any: + """Get first value from list (matches sync evaluator).""" + if not isinstance(vals, list): + return vals + if len(vals) == 0: + return None + return vals[0] + + def _merge_state(self, old: list[Any], new: list[Any]) -> list[Any]: + """Merge state for while loops (matches sync evaluator).""" + assert len(new) <= len(old) + return new + old[len(new) :] + + async def _eval_conv_node_async( + self, expr: ConvExpr, vars_vals: list[Any] + ) -> list[Any]: + """Async version of conv node evaluation.""" + # Implement the same logic as sync _eval_conv_node + assert len(vars_vals) > 0, "pconv called with empty vars list." + filtered = [v for v in vars_vals if v is not None] + if len(filtered) == 0: + return [None] + if len(filtered) == 1: + return [filtered[0]] + raise ValueError(f"pconv called with multiple vars={filtered}.") + + # Implement all required AsyncExprVisitor methods + async def visit_variable(self, expr: VariableExpr, env: dict[str, Any]) -> Any: + """Visit VariableExpr - not used in new implementation.""" + raise NotImplementedError("Use _iter_eval_graph instead") + + async def visit_eval(self, expr: EvalExpr, env: dict[str, Any]) -> Any: + """Visit EvalExpr - not used in new implementation.""" + raise NotImplementedError("Use _iter_eval_graph instead") + + async def visit_tuple(self, expr: TupleExpr, env: dict[str, Any]) -> Any: + """Visit TupleExpr - not used in new implementation.""" + raise NotImplementedError("Use _iter_eval_graph instead") + + async def visit_cond(self, expr: CondExpr, env: dict[str, Any]) -> Any: + """Visit CondExpr - not used in new implementation.""" + raise NotImplementedError("Use _iter_eval_graph instead") + + async def visit_while(self, expr: WhileExpr, env: dict[str, Any]) -> Any: + """Visit WhileExpr - not used in new implementation.""" + raise NotImplementedError("Use _iter_eval_graph instead") + + async def visit_call(self, expr: CallExpr, env: dict[str, Any]) -> Any: + """Visit CallExpr - not used in new implementation.""" + raise NotImplementedError("Use _iter_eval_graph instead") + + async def visit_conv(self, expr: ConvExpr, env: dict[str, Any]) -> Any: + """Visit ConvExpr - not used in new implementation.""" + raise NotImplementedError("Use _iter_eval_graph instead") + + async def visit_shfl_s(self, expr: ShflSExpr, env: dict[str, Any]) -> Any: + """Visit ShflSExpr - not used in new implementation.""" + raise NotImplementedError("Use _iter_eval_graph instead") + + async def visit_shfl(self, expr: ShflExpr, env: dict[str, Any]) -> Any: + """Visit ShflExpr - not used in new implementation.""" + raise NotImplementedError("Use _iter_eval_graph instead") + + async def visit_access(self, expr: AccessExpr, env: dict[str, Any]) -> Any: + """Visit AccessExpr - not used in new implementation.""" + raise NotImplementedError("Use _iter_eval_graph instead") + + async def visit_func_def(self, expr: FuncDefExpr, env: dict[str, Any]) -> Any: + """Visit FuncDefExpr - not used in new implementation.""" + raise NotImplementedError("Use _iter_eval_graph instead") diff --git a/mplang/v1/core/expr/visitor.py b/mplang/v1/core/expr/visitor.py index c63e8055..a601a00f 100644 --- a/mplang/v1/core/expr/visitor.py +++ b/mplang/v1/core/expr/visitor.py @@ -83,3 +83,51 @@ def visit_access(self, expr: AccessExpr) -> Any: @abstractmethod def visit_func_def(self, expr: FuncDefExpr) -> Any: pass + + +class AsyncExprVisitor(ABC): + """Async visitor interface that supports environment passing.""" + + @abstractmethod + async def visit_eval(self, expr: EvalExpr, env: dict[str, Any]) -> Any: + pass + + @abstractmethod + async def visit_variable(self, expr: VariableExpr, env: dict[str, Any]) -> Any: + pass + + @abstractmethod + async def visit_tuple(self, expr: TupleExpr, env: dict[str, Any]) -> Any: + pass + + @abstractmethod + async def visit_cond(self, expr: CondExpr, env: dict[str, Any]) -> Any: + pass + + @abstractmethod + async def visit_call(self, expr: CallExpr, env: dict[str, Any]) -> Any: + pass + + @abstractmethod + async def visit_while(self, expr: WhileExpr, env: dict[str, Any]) -> Any: + pass + + @abstractmethod + async def visit_conv(self, expr: ConvExpr, env: dict[str, Any]) -> Any: + pass + + @abstractmethod + async def visit_shfl_s(self, expr: ShflSExpr, env: dict[str, Any]) -> Any: + pass + + @abstractmethod + async def visit_shfl(self, expr: ShflExpr, env: dict[str, Any]) -> Any: + pass + + @abstractmethod + async def visit_access(self, expr: AccessExpr, env: dict[str, Any]) -> Any: + pass + + @abstractmethod + async def visit_func_def(self, expr: FuncDefExpr, env: dict[str, Any]) -> Any: + pass diff --git a/mplang/v1/runtime/communicator.py b/mplang/v1/runtime/communicator.py index bc51a3b6..9f66f5ec 100644 --- a/mplang/v1/runtime/communicator.py +++ b/mplang/v1/runtime/communicator.py @@ -23,6 +23,7 @@ import httpx +from mplang.v1.core.async_comm import AsyncCommunicatorBase from mplang.v1.core.comm import CommunicatorBase from mplang.v1.kernels.value import Value, decode_value, encode_value @@ -105,3 +106,85 @@ def recv(self, frm: int, key: str) -> Any: f"Received data: from_rank={frm}, to_rank={self._rank}, key={key}" ) return result + + +class AsyncHttpCommunicator(AsyncCommunicatorBase): + """Async version of HttpCommunicator.""" + + def __init__( + self, + session_name: str, + rank: int, + endpoints: list[str], + loop=None, + ): + # Validate endpoints + if not endpoints: + raise ValueError("endpoints cannot be empty") + + if not all(endpoint for endpoint in endpoints): + raise ValueError("endpoints cannot contain empty elements") + + super().__init__(rank, len(endpoints), loop) + self.session_name = session_name + # Ensure all endpoints have protocol prefix + self.endpoints = [ + endpoint + if endpoint.startswith(("http://", "https://")) + else f"http://{endpoint}" + for endpoint in endpoints + ] + logging.info( + f"AsyncHttpCommunicator initialized: session={session_name}, rank={rank}, endpoints={self.endpoints}" + ) + + async def send(self, to: int, key: str, data: Any) -> None: + """Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint.""" + target_endpoint = self.endpoints[to] + url = f"{target_endpoint}/sessions/{self.session_name}/comm/{key}/from/{self._rank}" + logging.debug( + f"Async sending data: from_rank={self._rank}, to_rank={to}, key={key}, target_url={url}" + ) + + try: + # Serialize data using Value envelope. + if not isinstance(data, Value): + raise TypeError( + f"Communicator requires Value instance, got {type(data).__name__}. " + "Wrap data in TensorValue or custom Value subclass." + ) + data_bytes = encode_value(data) + data_b64 = base64.b64encode(data_bytes).decode("utf-8") + + request_data = { + "data": data_b64, + } + + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.put(url, json=request_data) + logging.debug(f"Async send response: status={response.status_code}") + if response.status_code != 200: + logging.error(f"Async send failed: {response.text}") + response.raise_for_status() + + except httpx.RequestError as e: + logging.error( + f"Async send failed with exception: from_rank={self._rank}, to_rank={to}, key={key}, error={e}" + ) + raise OSError(f"Failed to send data to rank {to}") from e + + async def recv(self, frm: int, key: str) -> Any: + """Wait until the key is set, returns the value.""" + logging.debug( + f"Async waiting to receive: from_rank={frm}, to_rank={self._rank}, key={key}" + ) + data_b64 = await super().recv(frm, key) + + data_bytes = base64.b64decode(data_b64) + # Deserialize using Value envelope + result = decode_value(data_bytes) + + logging.debug( + f"Async received data: from_rank={frm}, to_rank={self._rank}, key={key}" + ) + return result diff --git a/mplang/v1/runtime/simulation.py b/mplang/v1/runtime/simulation.py index 56ed912e..3cf44d46 100644 --- a/mplang/v1/runtime/simulation.py +++ b/mplang/v1/runtime/simulation.py @@ -14,12 +14,11 @@ from __future__ import annotations -import concurrent.futures -import faulthandler +import asyncio import logging -import sys -import traceback +import os from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor from typing import Any, cast import spu.libspu as libspu @@ -38,8 +37,14 @@ PFunction, # for spu.seed_env kernel seeding TensorLike, ) +from mplang.v1.core.async_comm import AsyncThreadCommunicator from mplang.v1.core.expr.ast import Expr -from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator +from mplang.v1.core.expr.async_evaluator import ( + AsyncEvalSemantic, + AsyncRecursiveEvaluator, + AsyncIterativeEvaluator, +) +from mplang.v1.core.expr.evaluator import IEvaluator from mplang.v1.kernels.context import RuntimeContext from mplang.v1.runtime.link_comm import LinkCommunicator from mplang.v1.utils.spu_utils import parse_field, parse_protocol @@ -146,6 +151,9 @@ def __init__( self._spu_world = spu_mask.num_parties() self._spu_mask = spu_mask + # Executor for CPU-bound tasks + self._executor = ThreadPoolExecutor(max_workers=os.cpu_count()) + # Persistent per-rank RuntimeContext instances (reused across evaluates). # We no longer pre-create evaluators since each evaluate has different env bindings. # Build per-rank runtime contexts. @@ -210,90 +218,110 @@ def fetch(self, obj: MPObject) -> list[TensorLike]: raise ValueError(f"Expected SimVar, got {type(obj)}") return [v.to_numpy() if hasattr(v, "to_numpy") else v for v in obj._values] + def _ensure_spu_init(self, rank: int) -> None: + """Ensure SPU environment is initialized for the given rank.""" + runtime = self._runtimes[rank] + spu_meta = runtime.state.setdefault("_spu", {}) + if not spu_meta.get("inited", False): + link_ctx = self._spu_link_ctxs[rank] + seed_fn = PFunction( + fn_type="spu.seed_env", + ins_info=(), + outs_info=(), + config=self._spu_runtime_cfg, + world=self._spu_world, + link=link_ctx, + ) + runtime.run_kernel(seed_fn, []) # type: ignore[arg-type] + spu_meta["inited"] = True + # override def evaluate(self, expr: Expr, bindings: dict[str, MPObject]) -> Sequence[MPObject]: - # sanity check for bindings. + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + # Case A: Inside an existing loop (e.g., Jupyter) + try: + import nest_asyncio + + nest_asyncio.apply() + return loop.run_until_complete(self._evaluate_async(expr, bindings)) + except ImportError as e: + raise RuntimeError( + "Running in an active event loop (e.g. Jupyter). " + "Please install 'nest_asyncio' or use 'await simulator.evaluate_async(...)'." + ) from e + else: + # Case B: Standard script + return asyncio.run(self._evaluate_async(expr, bindings)) + + async def _evaluate_async( + self, expr: Expr, bindings: dict[str, MPObject] + ) -> Sequence[MPObject]: + """Async evaluation entry point.""" + # 1. Setup Async Communicators + world_size = self.world_size() + async_comms = [ + AsyncThreadCommunicator(rank, world_size) for rank in range(world_size) + ] + for comm in async_comms: + comm.set_peers(async_comms) + + # 2. Prepare Environment + # Validate that all variables belong to this simulator context for name, var in bindings.items(): + if not isinstance(var, SimVar): + raise ValueError( + f"Expected SimVar for variable '{name}', got {type(var)}" + ) if var.ctx is not self: - raise ValueError(f"Variable {name} not in this context, got {var.ctx}.") + raise ValueError(f"Variable '{name}' not in this context") pts_env = [ {name: cast(SimVar, var)._values[rank] for name, var in bindings.items()} - for rank in range(self.world_size()) + for rank in range(world_size) ] - # Build per-rank evaluators with the per-party environment (runtime reused) - pts_evaluators: list[IEvaluator] = [] - for rank in range(self.world_size()): + # 3. Create Evaluators + evaluators = [] + for rank in range(world_size): runtime = self._runtimes[rank] - ev = create_evaluator( - rank, - pts_env[rank], - self._comms[rank], - runtime, - None, + # Initialize SPU if needed (same logic as sync) + self._ensure_spu_init(rank) + + semantic = AsyncEvalSemantic( + rank=rank, + env=pts_env[rank], + comm=async_comms[rank], + runtime=runtime, + executor=self._executor, ) - # Seed SPU once per runtime (idempotent logical requirement) - # Use setdefault to both retrieve and create metadata dict in one step. - spu_meta = runtime.state.setdefault("_spu", {}) - if not spu_meta.get("inited", False): - link_ctx = self._spu_link_ctxs[rank] - seed_fn = PFunction( - fn_type="spu.seed_env", - ins_info=(), - outs_info=(), - config=self._spu_runtime_cfg, - world=self._spu_world, - link=link_ctx, - ) - ev.runtime.run_kernel(seed_fn, []) # type: ignore[arg-type] - spu_meta["inited"] = True - pts_evaluators.append(ev) - - # Collect evaluation results from all parties - pts_results: list[Any] = [] - - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit(self._do_evaluate, expr, evaluator) - for evaluator in pts_evaluators - ] - - # Collect results with proper exception handling - for i, future in enumerate(futures): - try: - result = future.result(100) # 100 second timeout - pts_results.append(result) - except concurrent.futures.TimeoutError: - faulthandler.dump_traceback(file=sys.stderr, all_threads=True) - raise - except Exception as e: - print( - f"Exception in party {i}: {type(e).__name__}: {e}", - file=sys.stderr, - ) - traceback.print_exc(file=sys.stderr) - executor.shutdown(wait=False, cancel_futures=True) - raise - - # Convert results to SimVar objects - # pts_results is a list of party results, where each party result is a list of values - # We need to transpose this to get (n_outputs, n_parties) structure - assert len(pts_results) == self.world_size() - - # Ensure all parties returned the same number of outputs (matrix validation) + ev = AsyncIterativeEvaluator(semantic) + evaluators.append(ev) + + # 4. Run Evaluation concurrently + # We need to run all evaluators.evaluate(expr) concurrently. + tasks = [ev.evaluate(expr) for ev in evaluators] + pts_results = await asyncio.gather(*tasks) + + # Ensure results are lists if expr has single output + if expr.num_outputs == 1: + # If each evaluator already returns a list (as async evaluators do), don't wrap again + if pts_results and not isinstance(pts_results[0], list): + pts_results = [[res] for res in pts_results] + + # 5. Process Results (Transpose and Wrap) + assert len(pts_results) == world_size if pts_results and not all( len(row) == len(pts_results[0]) for row in pts_results ): raise ValueError("Inconsistent number of outputs across parties") - # Transpose: (n_parties, n_outputs) -> (n_outputs, n_parties) output_values = list(zip(*pts_results, strict=False)) - - # Get the output types from the expression output_types = expr.mptypes - - # Create SimVar objects for each output sim_vars = [] for values, mptype in zip(output_values, output_types, strict=False): sim_var = SimVar(self, mptype, list(values)) diff --git a/tests/v1/core/test_async_comm.py b/tests/v1/core/test_async_comm.py new file mode 100644 index 00000000..be5cc4ab --- /dev/null +++ b/tests/v1/core/test_async_comm.py @@ -0,0 +1,88 @@ +import asyncio + +import pytest + +from mplang.v1.core.async_comm import AsyncThreadCommunicator + + +@pytest.mark.asyncio +async def test_async_p2p(): + world_size = 2 + comms = [AsyncThreadCommunicator(i, world_size) for i in range(world_size)] + for comm in comms: + comm.set_peers(comms) + + # P0 sends to P1 + async def p0_task(): + await comms[0].p2p(0, 1, "hello") + return "done" + + async def p1_task(): + data = await comms[1].p2p(0, 1, None) + return data + + results = await asyncio.gather(p0_task(), p1_task()) + assert results[1] == "hello" + + +@pytest.mark.asyncio +async def test_async_gather(): + world_size = 3 + comms = [AsyncThreadCommunicator(i, world_size) for i in range(world_size)] + for comm in comms: + comm.set_peers(comms) + + async def task(rank): + data = f"data-{rank}" + return await comms[rank].gather(0, data) + + results = await asyncio.gather(*[task(i) for i in range(world_size)]) + + # Rank 0 should get all data + assert results[0] == ["data-0", "data-1", "data-2"] + # Others get None list + assert results[1] == [None, None, None] + assert results[2] == [None, None, None] + + +@pytest.mark.asyncio +async def test_async_scatter(): + world_size = 3 + comms = [AsyncThreadCommunicator(i, world_size) for i in range(world_size)] + for comm in comms: + comm.set_peers(comms) + + data_to_scatter = ["d0", "d1", "d2"] + + async def task(rank): + if rank == 0: + return await comms[rank].scatter(0, data_to_scatter) + else: + return await comms[rank].scatter(0, [None] * 3) # args ignored for non-root + + results = await asyncio.gather(*[task(i) for i in range(world_size)]) + + assert results[0] == "d0" + assert results[1] == "d1" + assert results[2] == "d2" + + +@pytest.mark.asyncio +async def test_async_bcast(): + world_size = 3 + comms = [AsyncThreadCommunicator(i, world_size) for i in range(world_size)] + for comm in comms: + comm.set_peers(comms) + + async def task(rank): + if rank == 0: + return await comms[rank].bcast(0, "broadcast_data") + else: + return await comms[rank].bcast(0, None) + + results = await asyncio.gather(*[task(i) for i in range(world_size)]) + + # bcast returns the data for everyone in the mask, including the root + assert results[0] == "broadcast_data" + assert results[1] == "broadcast_data" + assert results[2] == "broadcast_data" diff --git a/tests/v1/core/test_async_simulation.py b/tests/v1/core/test_async_simulation.py new file mode 100644 index 00000000..6bd8db32 --- /dev/null +++ b/tests/v1/core/test_async_simulation.py @@ -0,0 +1,233 @@ +# Copyright 2025 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax.numpy as jnp +import numpy as np +import pytest + +from mplang.v1.core import ClusterSpec, Mask +from mplang.v1.core.context_mgr import with_ctx +from mplang.v1.core.primitive import function +from mplang.v1.core.tracer import TraceContext, trace +from mplang.v1.ops import jax_cc +from mplang.v1.runtime.simulation import Simulator, SimVar +from mplang.v1.simp.api import constant, run + + +def add(x, y): + return run(None, jax_cc.run_jax, lambda a, b: jnp.add(a, b), x, y) + + +@pytest.mark.asyncio +async def test_async_simulation_basic(): + """Test basic async simulation.""" + + # 1. Define a function + @function + def add_func(x, y): + return add(x, y) + + # 2. Trace it + cluster = ClusterSpec.simple(2) + ctx = TraceContext(cluster, mask=Mask(3)) + + # Let's trace a function that adds two constants. + @function + def simple_add(): + a = constant(10) + b = constant(20) + return add(a, b) + + traced = trace(ctx, simple_add) + expr = traced.make_expr() + + # 3. Create Simulator + sim = Simulator(cluster) + + # 4. Evaluate Async + # simple_add takes no args, so bindings is empty + # make_expr() returns a FuncDefExpr. We want to evaluate its body. + results = await sim._evaluate_async(expr.body, {}) + + # 5. Verify + assert len(results) == 1 + sim_var = results[0] + assert isinstance(sim_var, SimVar) + values = sim_var.values + assert len(values) == 2 + assert values[0] == 30 + assert values[1] == 30 + + +@pytest.mark.asyncio +async def test_async_simulation_args(): + """Test async simulation with arguments.""" + + @function + def add_func(x, y): + return add(x, y) + + cluster = ClusterSpec.simple(2) + ctx = TraceContext(cluster, mask=Mask(3)) + + # Let's manually create TraceVars for inputs. + from mplang.v1.core.dtypes import INT32 + from mplang.v1.core.expr.ast import VariableExpr + from mplang.v1.core.mptype import MPType + from mplang.v1.core.tracer import TraceVar + from mplang.v1.kernels.value import TensorValue + + x_type = MPType.tensor(INT32, (), Mask(3)) + y_type = MPType.tensor(INT32, (), Mask(3)) + + with with_ctx(ctx): + x = TraceVar(ctx, VariableExpr("x", x_type)) + y = TraceVar(ctx, VariableExpr("y", y_type)) + + traced = trace(ctx, add_func, x, y) + expr = traced.make_expr() + + # Now evaluate with bindings + sim = Simulator(cluster) + + # Create input values + x_val = SimVar( + sim, + x_type, + [ + TensorValue(np.array(10, dtype=np.int32)), + TensorValue(np.array(10, dtype=np.int32)), + ], + ) + y_val = SimVar( + sim, + y_type, + [ + TensorValue(np.array(20, dtype=np.int32)), + TensorValue(np.array(20, dtype=np.int32)), + ], + ) + + bindings = {"x": x_val, "y": y_val} + # expr is FuncDefExpr. We need to evaluate its body. + # But wait, the body refers to parameters. + # FuncDefExpr params are generated names usually? + # Or they match the names we gave? + + # When we use `trace(ctx, func, x, y)`, `x` and `y` are passed as arguments. + # `trace` captures them. + + # If `x` and `y` are TraceVars with VariableExpr, they are treated as inputs? + # `trace` logic: + # It calls the function with arguments. + # If arguments are TraceVars, they are used. + + # The resulting FuncDefExpr will have parameters corresponding to the inputs. + # But `x` and `y` are captured from the outer scope if we pass them? + # No, they are passed as arguments. + + # Let's check `traced.make_expr()` logic. + # It creates a FuncDefExpr. + # The parameters of FuncDefExpr correspond to the arguments of the traced function. + + # But we need to know the parameter names to bind them. + # `traced.make_expr()` might generate parameter names. + + # Let's inspect `expr.params`. + + # For now, let's assume we can bind by position if we construct a CallExpr. + # But `evaluate` takes `bindings` which is a dict. + + # If we evaluate `expr.body`, it contains `VariableExpr`s. + # These `VariableExpr`s refer to the parameter names. + + # So we need to map our input values to these parameter names. + + # `traced.make_expr()` returns `FuncDefExpr`. + # `expr.params` gives the list of parameter names. + + # So we should map `expr.params` to our values. + + param_names = expr.params + assert len(param_names) == 2 + + bindings = {param_names[0]: x_val, param_names[1]: y_val} + + results = await sim._evaluate_async(expr.body, bindings) + + assert len(results) == 1 + assert results[0].values == [30, 30] + + +def test_sync_evaluate(): + """Test synchronous evaluate.""" + + @function + def simple_add(): + a = constant(10) + b = constant(20) + return add(a, b) + + cluster = ClusterSpec.simple(2) + ctx = TraceContext(cluster, mask=Mask(3)) + traced = trace(ctx, simple_add) + expr = traced.make_expr() + + sim = Simulator(cluster) + # This calls evaluate (sync) which calls asyncio.run(_evaluate_async) + results = sim.evaluate(expr.body, {}) + + assert len(results) == 1 + values = results[0].values + assert values[0] == 30 + assert values[1] == 30 + + +@pytest.mark.asyncio +async def test_evaluate_in_loop(): + """Test evaluate inside an async loop (should fail without nest_asyncio or work with it).""" + + @function + def simple_add(): + a = constant(10) + b = constant(20) + return add(a, b) + + cluster = ClusterSpec.simple(2) + ctx = TraceContext(cluster, mask=Mask(3)) + traced = trace(ctx, simple_add) + expr = traced.make_expr() + + sim = Simulator(cluster) + + # We are in an async test, so there is a running loop. + # Calling sim.evaluate() should raise RuntimeError or work if nest_asyncio is present. + import asyncio + + await asyncio.sleep(0) # Silence RUF029 and ensure loop is running + + import importlib.util + + nest_asyncio_installed = importlib.util.find_spec("nest_asyncio") is not None + + if nest_asyncio_installed: + # If nest_asyncio is installed, it should work (assuming apply() is called inside evaluate) + results = sim.evaluate(expr.body, {}) + assert len(results) == 1 + values = results[0].values + assert values[0] == 30 + else: + # If not installed, it should raise RuntimeError + with pytest.raises(RuntimeError, match="nest_asyncio"): + sim.evaluate(expr.body, {}) diff --git a/tutorials/v1/device/02_simulation_and_driver.py b/tutorials/v1/device/02_simulation_and_driver.py index b04466a7..df118405 100644 --- a/tutorials/v1/device/02_simulation_and_driver.py +++ b/tutorials/v1/device/02_simulation_and_driver.py @@ -127,13 +127,13 @@ def cmd_main(): Usage: 1. Simulator (local, no setup needed): - uv run tutorials/device/02_simulation_and_driver.py sim + uv run tutorials/v1/device/02_simulation_and_driver.py sim 2. Driver (distributed): Step 1: Start cluster in separate terminal: - uv run python -m mplang.runtime.cli up -c examples/v1/conf/3pc.yaml + uv run python -m mplang.v1.runtime.cli up -c examples/v1/conf/3pc.yaml Step 2: Run computation: - uv run tutorials/device/02_simulation_and_driver.py run + uv run tutorials/v1/device/02_simulation_and_driver.py run """ cmd_main() From d2ee6f6d38c02f78ead0d50443e15c8c0b1190c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=89=BE=E5=84=92?= Date: Wed, 10 Dec 2025 15:39:37 +0800 Subject: [PATCH 2/6] update --- mplang/v1/core/async_comm.py | 76 ++++++++++----------- mplang/v1/core/expr/async_evaluator.py | 91 +++++++------------------- mplang/v1/runtime/communicator.py | 6 +- mplang/v1/runtime/server.py | 15 +++-- mplang/v1/runtime/session.py | 68 +++++++++++++++++-- mplang/v1/runtime/simulation.py | 3 +- 6 files changed, 135 insertions(+), 124 deletions(-) diff --git a/mplang/v1/core/async_comm.py b/mplang/v1/core/async_comm.py index c5dc6eef..f6a6ced6 100644 --- a/mplang/v1/core/async_comm.py +++ b/mplang/v1/core/async_comm.py @@ -19,43 +19,21 @@ from abc import ABC, abstractmethod from typing import Any +from mplang.v1.core.comm import ICommunicator from mplang.v1.core.mask import Mask -class IAsyncCommunicator(ABC): +class IAsyncCommunicator(ICommunicator): """Base class for asynchronous communicators.""" - @property - @abstractmethod - def rank(self) -> int: - """Get the rank of this process""" - - @property - @abstractmethod - def world_size(self) -> int: - """Get the world size of this process""" - @abstractmethod - def new_id(self) -> str: - """Must be implemented by mixing class""" - raise NotImplementedError - - @abstractmethod - async def send(self, to: int, key: str, data: Any) -> None: + async def async_send(self, to: int, key: str, data: Any) -> None: """Send data to peer with the given key asynchronously""" @abstractmethod - async def recv(self, frm: int, key: str) -> Any: + async def async_recv(self, frm: int, key: str) -> Any: """Receive data from peer with the given key asynchronously""" - @abstractmethod - def onSent(self, frm: int, key: str, data: Any) -> None: - """Called when a key is sent to self. - - This is typically called by the underlying transport layer (possibly from another thread). - It should be non-blocking and thread-safe. - """ - class IAsyncCollective(ABC): """Interface for asynchronous collective communication""" @@ -109,10 +87,16 @@ def rank(self) -> int: def world_size(self) -> int: raise NotImplementedError - async def send(self, to: int, key: str, data: Any) -> None: + def send(self, to: int, key: str, data: Any) -> None: + raise NotImplementedError + + def recv(self, frm: int, key: str) -> Any: + raise NotImplementedError + + async def async_send(self, to: int, key: str, data: Any) -> None: raise NotImplementedError - async def recv(self, frm: int, key: str) -> Any: + async def async_recv(self, frm: int, key: str) -> Any: raise NotImplementedError def new_id(self) -> str: @@ -126,11 +110,11 @@ async def p2p(self, frm: int, to: int, data: Any) -> Any: send_coro = None if self.rank == frm: - send_coro = self.send(to, cid, data) + send_coro = self.async_send(to, cid, data) recv_coro = None if self.rank == to: - recv_coro = self.recv(frm, cid) + recv_coro = self.async_recv(frm, cid) if send_coro and recv_coro: _, res = await asyncio.gather(send_coro, recv_coro) @@ -150,14 +134,14 @@ async def gather_m(self, pmask: int, root: int, data: Any) -> list[Any]: # 1. Send if we are in mask if self.rank in mask: - await self.send(root, cid, data) + await self.async_send(root, cid, data) # 2. Recv if we are root if self.rank == root: # Create futures for all expected receives futures = [] for idx in mask: - futures.append(self.recv(idx, cid)) + futures.append(self.async_recv(idx, cid)) # Wait for all concurrently results = await asyncio.gather(*futures) @@ -183,11 +167,11 @@ async def scatter_m(self, pmask: int, root: int, args: list[Any]) -> Any: # Send to all targets concurrently send_futures = [] for idx, arg in zip(mask, args, strict=True): - send_futures.append(self.send(idx, cid, arg)) + send_futures.append(self.async_send(idx, cid, arg)) await asyncio.gather(*send_futures) if self.rank in mask: - data = await self.recv(root, cid) + data = await self.async_recv(root, cid) else: data = None @@ -206,13 +190,13 @@ async def allgather_m(self, pmask: int, arg: Any) -> list[Any]: if self.rank in mask: send_futures = [] for idx in mask: - send_futures.append(self.send(idx, cid, arg)) + send_futures.append(self.async_send(idx, cid, arg)) await asyncio.gather(*send_futures) # 2. Recv from all parties in mask recv_futures = [] for idx in mask: - recv_futures.append(self.recv(idx, cid)) + recv_futures.append(self.async_recv(idx, cid)) res = await asyncio.gather(*recv_futures) return res @@ -232,11 +216,11 @@ async def bcast_m(self, pmask: int, root: int, arg: Any) -> Any: if self.rank == root: send_futures = [] for idx in mask: - send_futures.append(self.send(idx, cid, arg)) + send_futures.append(self.async_send(idx, cid, arg)) await asyncio.gather(*send_futures) if self.rank in mask: - return await self.recv(root, cid) + return await self.async_recv(root, cid) else: return None @@ -282,7 +266,7 @@ def new_id(self) -> str: self._counter += 1 return str(res) - async def recv(self, frm: int, key: str) -> Any: + async def async_recv(self, frm: int, key: str) -> Any: """Wait until the key is set, returns the value""" mkey = (frm, key) @@ -336,13 +320,23 @@ def _on_sent_internal(self, frm: int, key: str, data: Any) -> None: else: self._msgboxes[mkey] = data - async def send(self, to: int, key: str, data: Any) -> None: + async def async_send(self, to: int, key: str, data: Any) -> None: # Base implementation for local simulation: directly call peer's onSent # In a real distributed setting, this would put data on wire. raise NotImplementedError( "Must be implemented by subclass or mixin with peer awareness" ) + def send(sefl, to: int, key: str, data: Any) -> None: + raise NotImplementedError( + "Synchronous send not supported in AsyncCommunicatorBase" + ) + + def recv(self, frm: int, key: str) -> Any: + raise NotImplementedError( + "Synchronous recv not supported in AsyncCommunicatorBase" + ) + class AsyncThreadCommunicator(AsyncCommunicatorBase, AsyncCollectiveMixin): """Thread-based async communicator for in-memory communication (simulation)""" @@ -357,7 +351,7 @@ def set_peers(self, peers: list[AsyncThreadCommunicator]) -> None: assert self.world_size == len(peers) self.peers = peers - async def send(self, to: int, key: str, data: Any) -> None: + async def async_send(self, to: int, key: str, data: Any) -> None: assert 0 <= to < self.world_size # In local simulation, we can directly call peer's onSent. # Since we are all in the same process (and likely same loop for simulation), diff --git a/mplang/v1/core/expr/async_evaluator.py b/mplang/v1/core/expr/async_evaluator.py index a9e03256..3926dee1 100644 --- a/mplang/v1/core/expr/async_evaluator.py +++ b/mplang/v1/core/expr/async_evaluator.py @@ -39,6 +39,7 @@ from mplang.v1.core.expr.walk import walk_dataflow from mplang.v1.core.mask import Mask from mplang.v1.core.pfunc import PFunction +from mplang.v1.kernels.context import RuntimeContext from mplang.v1.kernels.value import Value @@ -91,12 +92,12 @@ async def _eval_shfl_s_node_async( # Send phase for src, dst in zip(src_ranks, dst_ranks, strict=True): if self.comm.rank == src: - send_tasks.append(self.comm.send(dst, cid, src_value)) + send_tasks.append(self.comm.async_send(dst, cid, src_value)) # Recv phase for src, dst in zip(src_ranks, dst_ranks, strict=True): if self.comm.rank == dst: - recv_futures.append(self.comm.recv(src, cid)) + recv_futures.append(self.comm.async_recv(src, cid)) # Execute all operations concurrently to avoid deadlock all_tasks = send_tasks + recv_futures @@ -131,13 +132,13 @@ async def _eval_shfl_node_async( send_tasks = [] for dst_rank in range(self.comm.world_size): if dst_rank != self.comm.rank: - send_tasks.append(self.comm.send(dst_rank, cid, idx)) + send_tasks.append(self.comm.async_send(dst_rank, cid, idx)) # Receive index from all ranks recv_tasks = [] for src_rank in range(self.comm.world_size): if src_rank != self.comm.rank: - recv_tasks.append(self.comm.recv(src_rank, cid)) + recv_tasks.append(self.comm.async_recv(src_rank, cid)) # Wait for all operations if send_tasks: @@ -168,13 +169,13 @@ async def _eval_shfl_node_async( data_send_tasks = [] for src_rank, dst_rank in send_pairs: if self.comm.rank == src_rank: - data_send_tasks.append(self.comm.send(dst_rank, cid, data)) + data_send_tasks.append(self.comm.async_send(dst_rank, cid, data)) # Receive data data_recv_tasks = [] for src_rank, dst_rank in send_pairs: if self.comm.rank == dst_rank: - data_recv_tasks.append(self.comm.recv(src_rank, cid)) + data_recv_tasks.append(self.comm.async_recv(src_rank, cid)) # Wait for data operations if data_send_tasks: @@ -310,7 +311,7 @@ async def _spawn_and_gather( return await asyncio.gather(*tasks) -class AsyncIterativeEvaluator(AsyncExprVisitor): +class AsyncIterativeEvaluator(AsyncEvalSemantic): """Async evaluator using iterative traversal to avoid stack overflow. This evaluator follows the same pattern as the synchronous IterativeEvaluator: @@ -319,12 +320,19 @@ class AsyncIterativeEvaluator(AsyncExprVisitor): 3. Processes nodes in dependency order """ - def __init__(self, semantic: AsyncEvalSemantic): - self.semantic = semantic + def __init__( + self, + rank: int, + env: dict[str, Any], + comm: IAsyncCommunicator, + runtime: RuntimeContext, + executor: Executor, + ): + super().__init__(rank, env, comm, runtime, executor) async def evaluate(self, expr: Expr, env: dict[str, Any] | None = None) -> Any: """Entry point for evaluation.""" - evaluation_env = env if env is not None else self.semantic.env + evaluation_env = env if env is not None else self.env result = await self._iter_eval_graph(expr, evaluation_env) return result @@ -366,13 +374,13 @@ async def _iter_eval_graph(self, root: Expr, env: dict[str, Any]) -> list[Any]: else: # Optional uniform verification if node.verify_uniform: - await self.semantic._verify_uniform_predicate_async(pred_val) + await self._verify_uniform_predicate_async(pred_val) # Convert to bool if isinstance(pred_val, Value): pred = pred_val.to_bool() else: - pred = bool(self.semantic._unwrap_value(pred_val)) + pred = bool(self._unwrap_value(pred_val)) if pred: sub_env = dict(zip(node.then_fn.params, arg_vals, strict=True)) @@ -397,7 +405,7 @@ async def _iter_eval_graph(self, root: Expr, env: dict[str, Any]) -> list[Any]: cond_vals = await self._iter_eval_graph( node.cond_fn.body, {**env, **cond_env} ) - cond_val = self.semantic._check_while_predicate(cond_vals) + cond_val = self._check_while_predicate(cond_vals) if not bool(cond_val): break @@ -411,9 +419,7 @@ async def _iter_eval_graph(self, root: Expr, env: dict[str, Any]) -> list[Any]: elif isinstance(node, EvalExpr): arg_vals = [self._first(symbols[id(a)]) for a in node.args] - symbols[id(node)] = await self.semantic._eval_eval_node_async( - node, arg_vals - ) + symbols[id(node)] = await self._eval_eval_node_async(node, arg_vals) elif isinstance(node, ConvExpr): vars_vals = [self._first(symbols[id(v)]) for v in node.vars] @@ -422,16 +428,12 @@ async def _iter_eval_graph(self, root: Expr, env: dict[str, Any]) -> list[Any]: elif isinstance(node, ShflSExpr): value = self._first(symbols[id(node.src_val)]) - symbols[id(node)] = await self.semantic._eval_shfl_s_node_async( - node, value - ) + symbols[id(node)] = await self._eval_shfl_s_node_async(node, value) elif isinstance(node, ShflExpr): data = self._first(symbols[id(node.src)]) index = self._first(symbols[id(node.index)]) - symbols[id(node)] = await self.semantic._eval_shfl_node_async( - node, data, index - ) + symbols[id(node)] = await self._eval_shfl_node_async(node, data, index) elif isinstance(node, FuncDefExpr): # FuncDefExpr should not be directly evaluated @@ -467,48 +469,3 @@ async def _eval_conv_node_async( if len(filtered) == 1: return [filtered[0]] raise ValueError(f"pconv called with multiple vars={filtered}.") - - # Implement all required AsyncExprVisitor methods - async def visit_variable(self, expr: VariableExpr, env: dict[str, Any]) -> Any: - """Visit VariableExpr - not used in new implementation.""" - raise NotImplementedError("Use _iter_eval_graph instead") - - async def visit_eval(self, expr: EvalExpr, env: dict[str, Any]) -> Any: - """Visit EvalExpr - not used in new implementation.""" - raise NotImplementedError("Use _iter_eval_graph instead") - - async def visit_tuple(self, expr: TupleExpr, env: dict[str, Any]) -> Any: - """Visit TupleExpr - not used in new implementation.""" - raise NotImplementedError("Use _iter_eval_graph instead") - - async def visit_cond(self, expr: CondExpr, env: dict[str, Any]) -> Any: - """Visit CondExpr - not used in new implementation.""" - raise NotImplementedError("Use _iter_eval_graph instead") - - async def visit_while(self, expr: WhileExpr, env: dict[str, Any]) -> Any: - """Visit WhileExpr - not used in new implementation.""" - raise NotImplementedError("Use _iter_eval_graph instead") - - async def visit_call(self, expr: CallExpr, env: dict[str, Any]) -> Any: - """Visit CallExpr - not used in new implementation.""" - raise NotImplementedError("Use _iter_eval_graph instead") - - async def visit_conv(self, expr: ConvExpr, env: dict[str, Any]) -> Any: - """Visit ConvExpr - not used in new implementation.""" - raise NotImplementedError("Use _iter_eval_graph instead") - - async def visit_shfl_s(self, expr: ShflSExpr, env: dict[str, Any]) -> Any: - """Visit ShflSExpr - not used in new implementation.""" - raise NotImplementedError("Use _iter_eval_graph instead") - - async def visit_shfl(self, expr: ShflExpr, env: dict[str, Any]) -> Any: - """Visit ShflExpr - not used in new implementation.""" - raise NotImplementedError("Use _iter_eval_graph instead") - - async def visit_access(self, expr: AccessExpr, env: dict[str, Any]) -> Any: - """Visit AccessExpr - not used in new implementation.""" - raise NotImplementedError("Use _iter_eval_graph instead") - - async def visit_func_def(self, expr: FuncDefExpr, env: dict[str, Any]) -> Any: - """Visit FuncDefExpr - not used in new implementation.""" - raise NotImplementedError("Use _iter_eval_graph instead") diff --git a/mplang/v1/runtime/communicator.py b/mplang/v1/runtime/communicator.py index 9f66f5ec..1bad4d7e 100644 --- a/mplang/v1/runtime/communicator.py +++ b/mplang/v1/runtime/communicator.py @@ -138,7 +138,7 @@ def __init__( f"AsyncHttpCommunicator initialized: session={session_name}, rank={rank}, endpoints={self.endpoints}" ) - async def send(self, to: int, key: str, data: Any) -> None: + async def async_send(self, to: int, key: str, data: Any) -> None: """Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint.""" target_endpoint = self.endpoints[to] url = f"{target_endpoint}/sessions/{self.session_name}/comm/{key}/from/{self._rank}" @@ -173,12 +173,12 @@ async def send(self, to: int, key: str, data: Any) -> None: ) raise OSError(f"Failed to send data to rank {to}") from e - async def recv(self, frm: int, key: str) -> Any: + async def async_recv(self, frm: int, key: str) -> Any: """Wait until the key is set, returns the value.""" logging.debug( f"Async waiting to receive: from_rank={frm}, to_rank={self._rank}, key={key}" ) - data_b64 = await super().recv(frm, key) + data_b64 = await super().async_recv(frm, key) data_bytes = base64.b64decode(data_b64) # Deserialize using Value envelope diff --git a/mplang/v1/runtime/server.py b/mplang/v1/runtime/server.py index 0062aedb..48fd5299 100644 --- a/mplang/v1/runtime/server.py +++ b/mplang/v1/runtime/server.py @@ -18,7 +18,9 @@ """ import base64 +from concurrent.futures import Executor, ThreadPoolExecutor import logging +import os import re from typing import Any @@ -55,6 +57,7 @@ # per-server global state _sessions: dict[str, Session] = {} _global_symbols: dict[str, Symbol] = {} +_executor: Executor = ThreadPoolExecutor(max_workers=os.cpu_count()) def register_session(session: Session) -> Session: # pragma: no cover - test helper @@ -271,7 +274,9 @@ def create_session(session_name: str, request: CreateSessionRequest) -> SessionR sess = _sessions[session_name] else: spec = ClusterSpec.from_dict(request.cluster_spec) - sess = create_session_from_spec(name=session_name, rank=request.rank, spec=spec) + sess = create_session_from_spec( + name=session_name, rank=request.rank, spec=spec, async_mode=True + ) _sessions[session_name] = sess return SessionResponse(name=sess.name) @@ -300,7 +305,7 @@ def delete_session(session_name: str) -> dict[str, str]: "/sessions/{session_name}/computations/{computation_id}", response_model=ComputationResponse, ) -def create_and_execute_computation( +async def create_and_execute_computation( session_name: str, computation_id: str, request: CreateComputationRequest ) -> ComputationResponse: graph_proto = mpir_pb2.GraphProto() @@ -325,12 +330,14 @@ def create_and_execute_computation( if not comp: comp = Computation(name=computation_id, expr=expr) sess.add_computation(comp) - sess.execute(comp, request.input_names, request.output_names) + await sess.async_execute( + comp, request.input_names, request.output_names, executor=_executor + ) return ComputationResponse(name=computation_id) @app.delete("/sessions/{session_name}/computations/{computation_id}") -def delete_computation(session_name: str, computation_id: str) -> dict[str, str]: +async def delete_computation(session_name: str, computation_id: str) -> dict[str, str]: """Delete a specific computation.""" sess = _sessions.get(session_name) if sess and sess.delete_computation(computation_id): diff --git a/mplang/v1/runtime/session.py b/mplang/v1/runtime/session.py index 27e3a369..bc35f6cd 100644 --- a/mplang/v1/runtime/session.py +++ b/mplang/v1/runtime/session.py @@ -27,6 +27,7 @@ import logging import time +from concurrent.futures import Executor from dataclasses import dataclass, field from functools import cached_property from typing import TYPE_CHECKING, Any, cast @@ -34,15 +35,17 @@ import spu.libspu as libspu +from mplang.v1.core.async_comm import IAsyncCommunicator from mplang.v1.core.cluster import ClusterSpec from mplang.v1.core.comm import ICommunicator from mplang.v1.core.expr.ast import Expr +from mplang.v1.core.expr.async_evaluator import AsyncIterativeEvaluator from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator from mplang.v1.core.mask import Mask from mplang.v1.kernels.context import RuntimeContext from mplang.v1.kernels.spu import PFunction # type: ignore from mplang.v1.kernels.value import Value -from mplang.v1.runtime.communicator import HttpCommunicator +from mplang.v1.runtime.communicator import AsyncHttpCommunicator, HttpCommunicator from mplang.v1.runtime.exceptions import ResourceNotFound from mplang.v1.runtime.link_comm import LinkCommunicator from mplang.v1.utils.spu_utils import parse_field, parse_protocol @@ -281,17 +284,68 @@ def execute( ) self.add_symbol(Symbol(name=name, mptype={}, data=val)) + async def async_execute( + self, + computation: Computation, + input_names: list[str], + output_names: list[str], + executor: Executor, + ) -> None: + if not isinstance(self.communicator, IAsyncCommunicator): + raise RuntimeError("Session.async_execute requires an async communicator") + + env: dict[str, Any] = {} + for in_name in input_names: + sym = self.get_symbol(in_name) + if sym is None: + raise ResourceNotFound( + f"Input symbol '{in_name}' not found in session '{self.name}'" + ) + env[in_name] = sym.data + rt = self.ensure_runtime() + self.ensure_spu_env() + evaluator = AsyncIterativeEvaluator( + rank=self.rank, + env=env, + comm=self.communicator, + runtime=rt, + executor=executor, + ) + results = await evaluator.evaluate(computation.expr, env) + if results and len(results) != len(output_names): + raise RuntimeError( + f"Expected {len(output_names)} results, got {len(results)}" + ) + for name, val in zip(output_names, results, strict=True): + # In pure SIMP model, all nodes should have the same symbol table. + # Non-participating nodes get None values. + if val is not None and not isinstance(val, Value): + raise TypeError( + "Session executions must produce kernel Value outputs; " + f"got {type(val).__name__} for symbol '{name}'" + ) + self.add_symbol(Symbol(name=name, mptype={}, data=val)) + # --- Convenience constructor use HttpCommunicator--- -def create_session_from_spec(name: str, rank: int, spec: ClusterSpec) -> Session: +def create_session_from_spec( + name: str, rank: int, spec: ClusterSpec, async_mode: bool = False +) -> Session: if len(spec.get_devices_by_kind("SPU")) == 0: raise RuntimeError("No SPU device found in cluster_spec") # Create HttpCommunicator for the session - communicator = HttpCommunicator( - session_name=name, - rank=rank, - endpoints=spec.endpoints, - ) + if async_mode: + communicator: ICommunicator = AsyncHttpCommunicator( + session_name=name, + rank=rank, + endpoints=spec.endpoints, + ) + else: + communicator = HttpCommunicator( + session_name=name, + rank=rank, + endpoints=spec.endpoints, + ) return Session(name=name, rank=rank, cluster_spec=spec, communicator=communicator) diff --git a/mplang/v1/runtime/simulation.py b/mplang/v1/runtime/simulation.py index 3cf44d46..1bc469e7 100644 --- a/mplang/v1/runtime/simulation.py +++ b/mplang/v1/runtime/simulation.py @@ -292,14 +292,13 @@ async def _evaluate_async( # Initialize SPU if needed (same logic as sync) self._ensure_spu_init(rank) - semantic = AsyncEvalSemantic( + ev = AsyncIterativeEvaluator( rank=rank, env=pts_env[rank], comm=async_comms[rank], runtime=runtime, executor=self._executor, ) - ev = AsyncIterativeEvaluator(semantic) evaluators.append(ev) # 4. Run Evaluation concurrently From bf10554042709c59c3c37d0e63af3d5fff289062 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=89=BE=E5=84=92?= Date: Wed, 10 Dec 2025 15:49:09 +0800 Subject: [PATCH 3/6] cleanup --- mplang/v1/core/expr/async_evaluator.py | 103 ++----------------------- mplang/v1/core/expr/visitor.py | 48 ------------ mplang/v1/runtime/session.py | 1 - mplang/v1/runtime/simulation.py | 2 - 4 files changed, 5 insertions(+), 149 deletions(-) diff --git a/mplang/v1/core/expr/async_evaluator.py b/mplang/v1/core/expr/async_evaluator.py index 3926dee1..2086f57c 100644 --- a/mplang/v1/core/expr/async_evaluator.py +++ b/mplang/v1/core/expr/async_evaluator.py @@ -35,7 +35,6 @@ WhileExpr, ) from mplang.v1.core.expr.evaluator import EvalSemantic -from mplang.v1.core.expr.visitor import AsyncExprVisitor from mplang.v1.core.expr.walk import walk_dataflow from mplang.v1.core.mask import Mask from mplang.v1.core.pfunc import PFunction @@ -47,13 +46,15 @@ class AsyncEvalSemantic(EvalSemantic): """Async version of EvalSemantic. - Reuses pure computation logic from EvalSemantic but overrides I/O bound methods - to use IAsyncCommunicator. + Reuses pure computation logic from EvalSemantic """ - comm: IAsyncCommunicator # Override type hint executor: Executor | None = None + def __post_init__(self) -> None: + if not isinstance(self.comm, IAsyncCommunicator): + raise TypeError("AsyncEvalSemantic requires an IAsyncCommunicator instance") + async def _exec_pfunc_async(self, pfunc: PFunction, args: list[Any]) -> list[Any]: # Check if any args are None - if so, this rank shouldn't participate # This prevents None values from reaching kernel validation @@ -217,100 +218,6 @@ def _as_optional_int(val: Any) -> int | None: return None -class AsyncRecursiveEvaluator(AsyncExprVisitor): - """Original async evaluator using recursive visitor pattern. - - This evaluator can cause stack overflow with deeply nested control flow. - Kept for reference and fallback. - """ - - def __init__(self, semantic: AsyncEvalSemantic): - self.semantic = semantic - - def _first(self, vals: list[Any]) -> Any: - if not isinstance(vals, list): - return vals - if len(vals) == 0: - return None - return vals[0] - - async def evaluate(self, expr: Expr, env: dict[str, Any] | None = None) -> Any: - evaluation_env = env if env is not None else self.semantic.env - return await expr.accept_async(self, evaluation_env) - - async def visit_cond(self, expr: CondExpr, env: dict[str, Any]) -> Any: - pred_res = await expr.pred.accept_async(self, env) - pred = self._first(pred_res) - - args_results = await self._spawn_and_gather(expr.args, env) - flat_args = [self._first(res) for res in args_results] - - if expr.verify_uniform: - await self.semantic._verify_uniform_predicate_async(pred) - - if isinstance(pred, Value): - pred_bool = pred.to_bool() - else: - pred_bool = bool(self.semantic._unwrap_value(pred)) - - if pred_bool: - new_env = {**env, **dict(zip(expr.then_fn.params, flat_args, strict=True))} - res = await expr.then_fn.body.accept_async(self, new_env) - else: - new_env = {**env, **dict(zip(expr.else_fn.params, flat_args, strict=True))} - res = await expr.else_fn.body.accept_async(self, new_env) - return res - - async def visit_call(self, expr: CallExpr, env: dict[str, Any]) -> Any: - args_results = await self._spawn_and_gather(expr.args, env) - flat_args = [self._first(res) for res in args_results] - # Bind arguments - new_env = {**env, **dict(zip(expr.fn.params, flat_args, strict=True))} - res = await expr.fn.body.accept_async(self, new_env) - return res - - async def visit_while(self, expr: WhileExpr, env: dict[str, Any]) -> Any: - curr_vals_results = await self._spawn_and_gather(expr.args, env) - curr_vals = [self._first(res) for res in curr_vals_results] - - # Determine split between state and captures - num_state = expr.body_fn.num_outputs - - # Initial state and captures - curr_state = curr_vals[:num_state] - captures = curr_vals[num_state:] - - while True: - # Reconstruct full arguments: state + captures - full_args = curr_state + captures - - # Check condition - cond_env = {**env, **dict(zip(expr.cond_fn.params, full_args, strict=True))} - cond_res = await expr.cond_fn.body.accept_async(self, cond_env) - - # Validate condition - cond_val = self.semantic._check_while_predicate(cond_res) - - if not cond_val: - break - - # Execute body - body_env = {**env, **dict(zip(expr.body_fn.params, full_args, strict=True))} - body_res = await expr.body_fn.body.accept_async(self, body_env) - - # Update state - body_res is already a list - curr_state = body_res - - return curr_state - - async def _spawn_and_gather( - self, exprs: list[Expr], env: dict[str, Any] - ) -> list[Any]: - """Spawn async tasks for multiple expressions and gather results.""" - tasks = [expr.accept_async(self, env) for expr in exprs] - return await asyncio.gather(*tasks) - - class AsyncIterativeEvaluator(AsyncEvalSemantic): """Async evaluator using iterative traversal to avoid stack overflow. diff --git a/mplang/v1/core/expr/visitor.py b/mplang/v1/core/expr/visitor.py index a601a00f..c63e8055 100644 --- a/mplang/v1/core/expr/visitor.py +++ b/mplang/v1/core/expr/visitor.py @@ -83,51 +83,3 @@ def visit_access(self, expr: AccessExpr) -> Any: @abstractmethod def visit_func_def(self, expr: FuncDefExpr) -> Any: pass - - -class AsyncExprVisitor(ABC): - """Async visitor interface that supports environment passing.""" - - @abstractmethod - async def visit_eval(self, expr: EvalExpr, env: dict[str, Any]) -> Any: - pass - - @abstractmethod - async def visit_variable(self, expr: VariableExpr, env: dict[str, Any]) -> Any: - pass - - @abstractmethod - async def visit_tuple(self, expr: TupleExpr, env: dict[str, Any]) -> Any: - pass - - @abstractmethod - async def visit_cond(self, expr: CondExpr, env: dict[str, Any]) -> Any: - pass - - @abstractmethod - async def visit_call(self, expr: CallExpr, env: dict[str, Any]) -> Any: - pass - - @abstractmethod - async def visit_while(self, expr: WhileExpr, env: dict[str, Any]) -> Any: - pass - - @abstractmethod - async def visit_conv(self, expr: ConvExpr, env: dict[str, Any]) -> Any: - pass - - @abstractmethod - async def visit_shfl_s(self, expr: ShflSExpr, env: dict[str, Any]) -> Any: - pass - - @abstractmethod - async def visit_shfl(self, expr: ShflExpr, env: dict[str, Any]) -> Any: - pass - - @abstractmethod - async def visit_access(self, expr: AccessExpr, env: dict[str, Any]) -> Any: - pass - - @abstractmethod - async def visit_func_def(self, expr: FuncDefExpr, env: dict[str, Any]) -> Any: - pass diff --git a/mplang/v1/runtime/session.py b/mplang/v1/runtime/session.py index bc35f6cd..45fa3247 100644 --- a/mplang/v1/runtime/session.py +++ b/mplang/v1/runtime/session.py @@ -195,7 +195,6 @@ def ensure_spu_env(self) -> None: spu_addrs: list[str] = [] for r, addr in enumerate(self.cluster_spec.endpoints): if r in self.spu_mask: - # TODO(oeqqwq): addr may contain other schema like grpc:// if not addr.startswith(("http://", "https://")): addr = f"http://{addr}" parsed = urlparse(addr) diff --git a/mplang/v1/runtime/simulation.py b/mplang/v1/runtime/simulation.py index 1bc469e7..ef55ab63 100644 --- a/mplang/v1/runtime/simulation.py +++ b/mplang/v1/runtime/simulation.py @@ -40,8 +40,6 @@ from mplang.v1.core.async_comm import AsyncThreadCommunicator from mplang.v1.core.expr.ast import Expr from mplang.v1.core.expr.async_evaluator import ( - AsyncEvalSemantic, - AsyncRecursiveEvaluator, AsyncIterativeEvaluator, ) from mplang.v1.core.expr.evaluator import IEvaluator From 9b5d338381970154aed2d833b17cd07b099cb51f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=89=BE=E5=84=92?= Date: Wed, 10 Dec 2025 15:49:40 +0800 Subject: [PATCH 4/6] cleanup --- tests/v1/core/test_async_simulation.py | 233 ------------------------- 1 file changed, 233 deletions(-) delete mode 100644 tests/v1/core/test_async_simulation.py diff --git a/tests/v1/core/test_async_simulation.py b/tests/v1/core/test_async_simulation.py deleted file mode 100644 index 6bd8db32..00000000 --- a/tests/v1/core/test_async_simulation.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright 2025 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax.numpy as jnp -import numpy as np -import pytest - -from mplang.v1.core import ClusterSpec, Mask -from mplang.v1.core.context_mgr import with_ctx -from mplang.v1.core.primitive import function -from mplang.v1.core.tracer import TraceContext, trace -from mplang.v1.ops import jax_cc -from mplang.v1.runtime.simulation import Simulator, SimVar -from mplang.v1.simp.api import constant, run - - -def add(x, y): - return run(None, jax_cc.run_jax, lambda a, b: jnp.add(a, b), x, y) - - -@pytest.mark.asyncio -async def test_async_simulation_basic(): - """Test basic async simulation.""" - - # 1. Define a function - @function - def add_func(x, y): - return add(x, y) - - # 2. Trace it - cluster = ClusterSpec.simple(2) - ctx = TraceContext(cluster, mask=Mask(3)) - - # Let's trace a function that adds two constants. - @function - def simple_add(): - a = constant(10) - b = constant(20) - return add(a, b) - - traced = trace(ctx, simple_add) - expr = traced.make_expr() - - # 3. Create Simulator - sim = Simulator(cluster) - - # 4. Evaluate Async - # simple_add takes no args, so bindings is empty - # make_expr() returns a FuncDefExpr. We want to evaluate its body. - results = await sim._evaluate_async(expr.body, {}) - - # 5. Verify - assert len(results) == 1 - sim_var = results[0] - assert isinstance(sim_var, SimVar) - values = sim_var.values - assert len(values) == 2 - assert values[0] == 30 - assert values[1] == 30 - - -@pytest.mark.asyncio -async def test_async_simulation_args(): - """Test async simulation with arguments.""" - - @function - def add_func(x, y): - return add(x, y) - - cluster = ClusterSpec.simple(2) - ctx = TraceContext(cluster, mask=Mask(3)) - - # Let's manually create TraceVars for inputs. - from mplang.v1.core.dtypes import INT32 - from mplang.v1.core.expr.ast import VariableExpr - from mplang.v1.core.mptype import MPType - from mplang.v1.core.tracer import TraceVar - from mplang.v1.kernels.value import TensorValue - - x_type = MPType.tensor(INT32, (), Mask(3)) - y_type = MPType.tensor(INT32, (), Mask(3)) - - with with_ctx(ctx): - x = TraceVar(ctx, VariableExpr("x", x_type)) - y = TraceVar(ctx, VariableExpr("y", y_type)) - - traced = trace(ctx, add_func, x, y) - expr = traced.make_expr() - - # Now evaluate with bindings - sim = Simulator(cluster) - - # Create input values - x_val = SimVar( - sim, - x_type, - [ - TensorValue(np.array(10, dtype=np.int32)), - TensorValue(np.array(10, dtype=np.int32)), - ], - ) - y_val = SimVar( - sim, - y_type, - [ - TensorValue(np.array(20, dtype=np.int32)), - TensorValue(np.array(20, dtype=np.int32)), - ], - ) - - bindings = {"x": x_val, "y": y_val} - # expr is FuncDefExpr. We need to evaluate its body. - # But wait, the body refers to parameters. - # FuncDefExpr params are generated names usually? - # Or they match the names we gave? - - # When we use `trace(ctx, func, x, y)`, `x` and `y` are passed as arguments. - # `trace` captures them. - - # If `x` and `y` are TraceVars with VariableExpr, they are treated as inputs? - # `trace` logic: - # It calls the function with arguments. - # If arguments are TraceVars, they are used. - - # The resulting FuncDefExpr will have parameters corresponding to the inputs. - # But `x` and `y` are captured from the outer scope if we pass them? - # No, they are passed as arguments. - - # Let's check `traced.make_expr()` logic. - # It creates a FuncDefExpr. - # The parameters of FuncDefExpr correspond to the arguments of the traced function. - - # But we need to know the parameter names to bind them. - # `traced.make_expr()` might generate parameter names. - - # Let's inspect `expr.params`. - - # For now, let's assume we can bind by position if we construct a CallExpr. - # But `evaluate` takes `bindings` which is a dict. - - # If we evaluate `expr.body`, it contains `VariableExpr`s. - # These `VariableExpr`s refer to the parameter names. - - # So we need to map our input values to these parameter names. - - # `traced.make_expr()` returns `FuncDefExpr`. - # `expr.params` gives the list of parameter names. - - # So we should map `expr.params` to our values. - - param_names = expr.params - assert len(param_names) == 2 - - bindings = {param_names[0]: x_val, param_names[1]: y_val} - - results = await sim._evaluate_async(expr.body, bindings) - - assert len(results) == 1 - assert results[0].values == [30, 30] - - -def test_sync_evaluate(): - """Test synchronous evaluate.""" - - @function - def simple_add(): - a = constant(10) - b = constant(20) - return add(a, b) - - cluster = ClusterSpec.simple(2) - ctx = TraceContext(cluster, mask=Mask(3)) - traced = trace(ctx, simple_add) - expr = traced.make_expr() - - sim = Simulator(cluster) - # This calls evaluate (sync) which calls asyncio.run(_evaluate_async) - results = sim.evaluate(expr.body, {}) - - assert len(results) == 1 - values = results[0].values - assert values[0] == 30 - assert values[1] == 30 - - -@pytest.mark.asyncio -async def test_evaluate_in_loop(): - """Test evaluate inside an async loop (should fail without nest_asyncio or work with it).""" - - @function - def simple_add(): - a = constant(10) - b = constant(20) - return add(a, b) - - cluster = ClusterSpec.simple(2) - ctx = TraceContext(cluster, mask=Mask(3)) - traced = trace(ctx, simple_add) - expr = traced.make_expr() - - sim = Simulator(cluster) - - # We are in an async test, so there is a running loop. - # Calling sim.evaluate() should raise RuntimeError or work if nest_asyncio is present. - import asyncio - - await asyncio.sleep(0) # Silence RUF029 and ensure loop is running - - import importlib.util - - nest_asyncio_installed = importlib.util.find_spec("nest_asyncio") is not None - - if nest_asyncio_installed: - # If nest_asyncio is installed, it should work (assuming apply() is called inside evaluate) - results = sim.evaluate(expr.body, {}) - assert len(results) == 1 - values = results[0].values - assert values[0] == 30 - else: - # If not installed, it should raise RuntimeError - with pytest.raises(RuntimeError, match="nest_asyncio"): - sim.evaluate(expr.body, {}) From 481e42d074d68baa1aff3290fe84e4e59e700c40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=89=BE=E5=84=92?= Date: Wed, 10 Dec 2025 15:51:51 +0800 Subject: [PATCH 5/6] cleanup --- mplang/v1/runtime/simulation.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/mplang/v1/runtime/simulation.py b/mplang/v1/runtime/simulation.py index ef55ab63..1b212717 100644 --- a/mplang/v1/runtime/simulation.py +++ b/mplang/v1/runtime/simulation.py @@ -235,26 +235,7 @@ def _ensure_spu_init(self, rank: int) -> None: # override def evaluate(self, expr: Expr, bindings: dict[str, MPObject]) -> Sequence[MPObject]: - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop and loop.is_running(): - # Case A: Inside an existing loop (e.g., Jupyter) - try: - import nest_asyncio - - nest_asyncio.apply() - return loop.run_until_complete(self._evaluate_async(expr, bindings)) - except ImportError as e: - raise RuntimeError( - "Running in an active event loop (e.g. Jupyter). " - "Please install 'nest_asyncio' or use 'await simulator.evaluate_async(...)'." - ) from e - else: - # Case B: Standard script - return asyncio.run(self._evaluate_async(expr, bindings)) + return asyncio.run(self._evaluate_async(expr, bindings)) async def _evaluate_async( self, expr: Expr, bindings: dict[str, MPObject] From b0dc33a0aa3cd488af353d76a0d95b835227664b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=89=BE=E5=84=92?= Date: Wed, 10 Dec 2025 16:29:27 +0800 Subject: [PATCH 6/6] update --- mplang/v1/core/async_comm.py | 2 +- mplang/v1/core/expr/async_evaluator.py | 87 ++++++++++++++++++-------- tests/v1/core/test_async_comm.py | 14 +++++ 3 files changed, 76 insertions(+), 27 deletions(-) diff --git a/mplang/v1/core/async_comm.py b/mplang/v1/core/async_comm.py index f6a6ced6..1f829107 100644 --- a/mplang/v1/core/async_comm.py +++ b/mplang/v1/core/async_comm.py @@ -327,7 +327,7 @@ async def async_send(self, to: int, key: str, data: Any) -> None: "Must be implemented by subclass or mixin with peer awareness" ) - def send(sefl, to: int, key: str, data: Any) -> None: + def send(self, to: int, key: str, data: Any) -> None: raise NotImplementedError( "Synchronous send not supported in AsyncCommunicatorBase" ) diff --git a/mplang/v1/core/expr/async_evaluator.py b/mplang/v1/core/expr/async_evaluator.py index 2086f57c..da1c7fbc 100644 --- a/mplang/v1/core/expr/async_evaluator.py +++ b/mplang/v1/core/expr/async_evaluator.py @@ -188,34 +188,69 @@ async def _eval_shfl_node_async( return [received_data] - def _as_optional_int(self, val: Any) -> int | None: - """Convert a value to int if possible, preserving None.""" - val = EvalSemantic._unwrap_value(val) - if val is None: - return None - return int(val) + async def _simple_allgather_async(self, value: Any) -> list[Any]: + """Async all-gather emulation using async send/recv. + + This implements an O(P^2) pairwise exchange (each rank sends its value to all + other ranks) and collects values in rank order. Suitable for small P (typical + controller / simulation sizes) and control metadata like a single bool. + + Returns a list of length world_size with entries ordered by rank. + """ + ws = self.comm.world_size + value = self._unwrap_value(value) + # Trivial fast-path + if ws == 1: + return [value] + cid = self.comm.new_id() + gathered: list[Any] = [None] * ws # type: ignore + gathered[self.comm.rank] = value + + # Create async tasks for all send and receive operations + tasks = [] + # Fan-out: send to all other ranks + for dst in range(ws): + if dst != self.comm.rank: + tasks.append(self.comm.async_send(dst, cid, value)) + # Fan-in: receive from all other ranks + for src in range(ws): + if src != self.comm.rank: + tasks.append(self.comm.async_recv(src, cid)) + + # Wait for all operations to complete + results = await asyncio.gather(*tasks) + + # Process results: first half are sends (which return None), second half are receives + recv_results = results[len(results) // 2:] + for i, src in enumerate([r for r in range(ws) if r != self.comm.rank]): + gathered[src] = recv_results[i] + + return gathered async def _verify_uniform_predicate_async(self, pred: Any) -> None: - # For now, just pass - # Would need proper async implementation for uniform verification - pass + """Async version of uniform predicate verification using async collective communication. + + Verifies that the predicate value is uniform across all parties by performing + an async all-gather operation and checking that all values are identical. + """ + # Use Value.to_bool() if available, otherwise unwrap and convert + if isinstance(pred, Value): + pred_bool = pred.to_bool() + else: + pred_bool = bool(self._unwrap_value(pred)) - @staticmethod - def _as_optional_int(val: Any) -> int | None: - if isinstance(val, int): - return val - if isinstance(val, Value): - if hasattr(val, "value"): - return int(val.value) - # Try to convert TensorValue using to_numpy - to_numpy = getattr(val, "to_numpy", None) - if callable(to_numpy): - arr = to_numpy() - import numpy as np - - if isinstance(arr, np.ndarray) and arr.size == 1: - return int(arr.item()) - return None + # Use async allgather to collect predicate values from all parties + vals = await self._simple_allgather_async(pred_bool) + + if not vals: + raise ValueError("uniform_cond: empty gather for predicate") + + first = vals[0] + for v in vals[1:]: + if v != first: + raise ValueError( + "uniform_cond: predicate is not uniform across parties" + ) class AsyncIterativeEvaluator(AsyncEvalSemantic): @@ -365,7 +400,7 @@ def _merge_state(self, old: list[Any], new: list[Any]) -> list[Any]: return new + old[len(new) :] async def _eval_conv_node_async( - self, expr: ConvExpr, vars_vals: list[Any] + self, _expr: ConvExpr, vars_vals: list[Any] ) -> list[Any]: """Async version of conv node evaluation.""" # Implement the same logic as sync _eval_conv_node diff --git a/tests/v1/core/test_async_comm.py b/tests/v1/core/test_async_comm.py index be5cc4ab..71fecad1 100644 --- a/tests/v1/core/test_async_comm.py +++ b/tests/v1/core/test_async_comm.py @@ -1,3 +1,17 @@ +# Copyright 2025 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import asyncio import pytest