diff --git a/reward_queue/README.md b/reward_queue/README.md new file mode 100644 index 00000000..1acea798 --- /dev/null +++ b/reward_queue/README.md @@ -0,0 +1,167 @@ +# Reward Queue: Decoupled Inference and Reward Computation + +## Required `verl` version + +See [`REQUIRED_VERL.txt`](REQUIRED_VERL.txt) for the upstream repository, install mode (rolling `main`, pinned release tag, or pinned git commit), and copy-pastable `pip` / `git` instructions where they exist. + +## Overview + +Reward Queue decouples inference (generation) from reward computation in VERL's fully asynchronous training pipeline. It introduces an intermediate queue between the two stages, enabling concurrent execution and maximizing GPU utilization. + +When reward computation involves slow external LLM judges or complex scoring functions, the traditional tightly-coupled pipeline wastes GPU cycles waiting for scores. This recipe solves that bottleneck. + +## Architecture + +![Reward Queue Architecture](./images/reward_queue_architecture.png) + +**Core pipeline:** + +``` +┌─────────────────┐ ┌──────────────┐ ┌─────────────────┐ +│ Generation │────▶│ RewardQueue │────▶│ Reward Compute │ +│ (async) │ │ │ │ (concurrent) │ +└─────────────────┘ └──────────────┘ └─────────────────┘ +``` + +**Key components:** + +| Component | File | Role | +| ------------------- | ----------------- | ------------------------------------------------------------ | +| `RewardQueue` | `reward_queue.py` | Ray actor-based async queue with producer-consumer semantics | +| `SubRewardDataItem` | `utils.py` | Data item passed through the queue | +| `SampleAggregator` | `utils.py` | Accumulates scored sub-items per sample | +| `Rollouter` | `rollouter.py` | Extended FullyAsyncRollouter with reward queue support | +| `Trainer` | `trainer.py` | Extended FullyAsyncTrainer with timing metadata | + +**Processing flow:** + +1. `_processor_worker` launches async generation for each sub-item +2. Generated outputs are immediately buffered into `RewardQueue` (no waiting for scores) +3. `_reward_consumer_worker` pulls from queue and distributes scoring across workers +4. `SampleAggregator` accumulates scored sub-items per sample +5. `_finalize_sample` assembles complete batch and publishes to `MessageQueue` for trainer + +## Quick Start + +### Enable the Feature + +Set `async_training.enable_reward_queue: true` in your config: + +```yaml +async_training: + enable_reward_queue: true + reward_queue_size: null # Uses default: max_required_samples * rollout_n +``` + +Or via command line: + +```bash +python -m recipe.reward_queue.main \ + --config-path=config \ + --config-name='fully_async' \ + async_training.enable_reward_queue=true \ + # ... other config +``` + +### Run Training + +```bash +# Single node (8 GPUs) +NNODES=1 NGPUS_PER_NODE=8 \ +MODEL_PATH=Qwen3.5-9B \ +TRAIN_FILE=./gsm8k/train/gsm8k_tra.jsonl \ +VAL_FILE=./gsm8k/eval/gsm8k_ev.jsonl \ +bash recipe/reward_queue/train_async.sh +``` + +### Run with Custom Settings + +```bash +NNODES=2 \ +NGPUS_PER_NODE=8 \ +MODEL_PATH=Qwen3.5-9B \ +TRAIN_BATCH_SIZE=8 \ +N_SAMPLE=8 \ +TOTAL_TRAINING_STEPS=500 \ +ASYNC_STALENESS=0.3 \ +ASYNC_SYNC_STEP=2 \ +ASYNC_REQUIRE_BATCHES=4 \ +bash recipe/reward_queue/train_async.sh +``` + +## Configuration + +### Async Training Config + +| Parameter | Default | Description | +| ------------------------------------ | ------- | ------------------------------------------------------------ | +| `async_training.enable_reward_queue` | `false` | Enable/disable reward queue decoupling | +| `async_training.reward_queue_size` | `null` | Max queue size. `null` means `max_required_samples * rollout_n` | + +### Environment Variables + +| Variable | Default | Description | +| ----------------------- | ------------------------------- | -------------------------------- | +| `NNODES` | `1` | Number of nodes | +| `NGPUS_PER_NODE` | `8` | GPUs per node | +| `MODEL_PATH` | `Qwen3.5-9B` | Model path | +| `TRAIN_FILE` | `./gsm8k/train/gsm8k_tra.jsonl` | Training data | +| `VAL_FILE` | `./gsm8k/eval/gsm8k_ev.jsonl` | Validation data | +| `TRAIN_BATCH_SIZE` | `8` | Training batch size | +| `N_SAMPLE` | `8` | Responses per prompt (rollout_n) | +| `TOTAL_TRAINING_STEPS` | `500` | Total training steps | +| `ASYNC_STALENESS` | `0.3` | Staleness threshold | +| `ASYNC_SYNC_STEP` | `2` | Parameter sync trigger step | +| `ASYNC_REQUIRE_BATCHES` | `4` | Required batches | + +## Monitoring Metrics + +The reward queue exports the following metrics to W&B: + +| Metric | Description | +| --------------------------------- | --------------------------------------- | +| `monitor/queue/reward_queue_size` | Current reward queue size | +| `reward_queue/total_produced` | Total items produced to queue | +| `reward_queue/total_consumed` | Total items consumed from queue | +| `reward_queue/dropped_samples` | Samples dropped due to queue overflow | +| `static/max_reward_queue_size` | Maximum configured queue size | +| `timing_s/reward_compute/mean` | Mean reward computation time | +| `timing_s/reward_compute/max` | Max reward computation time | +| `timing_s/reward_compute/tp95` | 95th percentile reward computation time | +| `aggregator/pending_groups_count` | Number of samples awaiting completion | +| `aggregator/total_pending` | Total sub-items awaiting scoring | + +## Use Cases + +1. **External LLM Judges**: When reward computation calls external LLM APIs (e.g., LLM-as-a-Judge), network latency is overlapped with generation. + +2. **Complex Scoring Functions**: Multi-step reward pipelines with multiple model calls benefit from overlapping generation with scoring. + +3. **Variable Reward Latency**: When computation time varies significantly across samples, the queue buffers fast results while waiting for slow ones. + +4. **Throughput Optimization**: Maximizes GPU utilization by keeping either generation or scoring always active. + +## File Layout + +``` +reward_queue/ +├── REQUIRED_VERL.txt +├── README.md +├── reward_queue_architecture.drawio +├── main.py # Hydra entry point with TaskRunner +├── rollouter.py # Extended FullyAsyncRollouter +├── trainer.py # Extended FullyAsyncTrainer +├── reward_queue.py # RewardQueue and RewardQueueClient +├── utils.py # SubRewardDataItem, SampleAggregator, etc. +├── train_async.sh # Launch script +├── agent_loop/ +│ └── agent_loop.py # AgentLoopWorkerForRewardQueue +└── config/ + └── fully_async.yaml # Base config +``` + +## Design Notes + +- **Backpressure control**: When `MessageQueue` is full, scoring pauses automatically to prevent resource exhaustion. +- **Concurrent scoring**: Multiple reward workers score sub-items in parallel, throttled by `max_concurrent_samples * rollout_n`. +- **Temporal decoupling**: Inference output and reward computation run at their own pace via the queue buffer. \ No newline at end of file diff --git a/reward_queue/REQUIRED_VERL.txt b/reward_queue/REQUIRED_VERL.txt new file mode 100644 index 00000000..868f1100 --- /dev/null +++ b/reward_queue/REQUIRED_VERL.txt @@ -0,0 +1,6 @@ +UPSTREAM=https://github.com/verl-project/verl.git +MODE=pinned_tag +TAG=v0.8.0 +COMMIT=bbca85c9b8131bacf6efcb4892af882e8ade4b5d +PIP_INSTALL=pip install verl==0.8.0 +NOTES=reward_queue recipe is developed and tested against verl v0.8.0. diff --git a/reward_queue/__init__.py b/reward_queue/__init__.py new file mode 100644 index 00000000..6dcb1ac5 --- /dev/null +++ b/reward_queue/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Huawei Technologies Co., Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/reward_queue/agent_loop/agent_loop.py b/reward_queue/agent_loop/agent_loop.py new file mode 100644 index 00000000..7cde8d5c --- /dev/null +++ b/reward_queue/agent_loop/agent_loop.py @@ -0,0 +1,299 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +AgentLoopWorkerForRewardQueue extends AgentLoopWorker for reward queue workflow. +""" + +import logging +import os +from typing import Any + +import hydra +import numpy as np +import torch +from tensordict import TensorDict + +from verl.experimental.agent_loop.agent_loop import ( + AgentLoopOutput, + AgentLoopWorker, + DictConfigWrap, + ToolListWrap, + _InternalAgentLoopOutput, + _agent_loop_registry, + get_trajectory_info, + rollout_trace_attr, +) +from verl.protocol import DataProto + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class AgentLoopWorkerForRewardQueue(AgentLoopWorker): + async def _run_reward_agent_loop( + self, + sampling_params: dict[str, Any], + trajectory: dict[str, Any], + *, + agent_name: str, + trace: bool = True, + **kwargs, + ) -> _InternalAgentLoopOutput: + with rollout_trace_attr( + step=trajectory["step"], + sample_index=trajectory["sample_index"], + rollout_n=trajectory["rollout_n"], + validate=trajectory["validate"], + name="agent_loop", + trace=trace, + ): + assert agent_name in _agent_loop_registry, ( + f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}" + ) + + agent_loop_config = _agent_loop_registry[agent_name] + agent_loop = hydra.utils.instantiate( + config=agent_loop_config, + trainer_config=DictConfigWrap(config=self.config), + server_manager=self.llm_client, + tokenizer=self.tokenizer, + processor=self.processor, + dataset_cls=self.dataset_cls, + data_config=DictConfigWrap(self.config.data), + tools=ToolListWrap(self.tools), + ) + + import time as _time + inference_start_timestamp = _time.time() + output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs) + inference_end_timestamp = _time.time() + + output.extra_fields["inference_start_timestamp"] = inference_start_timestamp + output.extra_fields["inference_end_timestamp"] = inference_end_timestamp + output.extra_fields["inference_duration"] = inference_end_timestamp - inference_start_timestamp + + return output + + def _padding_postprocess(self, output, validate, **kwargs) -> _InternalAgentLoopOutput: + output.extra_fields["raw_prompt"] = kwargs.get("raw_prompt", None) + + self.tokenizer.padding_side = "left" + prompt_output = self.tokenizer.pad( + {"input_ids": output.prompt_ids}, + padding="max_length", + max_length=self.rollout_config.prompt_length, + return_tensors="pt", + return_attention_mask=True, + ) + if prompt_output["input_ids"].dim() == 1: + prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0) + prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0) + + self.tokenizer.padding_side = "right" + response_output = self.tokenizer.pad( + {"input_ids": output.response_ids}, + padding="max_length", + max_length=self.rollout_config.response_length, + return_tensors="pt", + return_attention_mask=True, + ) + if response_output["input_ids"].dim() == 1: + response_output["input_ids"] = response_output["input_ids"].unsqueeze(0) + response_output["attention_mask"] = response_output["attention_mask"].unsqueeze(0) + + response_mask_output = self.tokenizer.pad( + {"input_ids": output.response_mask}, + padding="max_length", + max_length=self.rollout_config.response_length, + return_tensors="pt", + return_attention_mask=False, + ) + if response_mask_output["input_ids"].dim() == 1: + response_mask_output["input_ids"] = response_mask_output["input_ids"].unsqueeze(0) + + response_logprobs = None + if output.response_logprobs is not None: + pad_size = self.rollout_config.response_length - len(output.response_logprobs) + response_logprobs = torch.tensor(output.response_logprobs + [0.0] * pad_size).unsqueeze(0) + + response_mask = response_mask_output["input_ids"] * response_output["attention_mask"] + attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1) + input_ids = torch.cat([prompt_output["input_ids"], response_output["input_ids"]], dim=1) + + routed_experts = None + if output.routed_experts is not None: + total_length = input_ids.shape[1] + length, layer_num, topk_num = output.routed_experts.shape + if isinstance(output.routed_experts, np.ndarray): + routed_experts_array = output.routed_experts + if not routed_experts_array.flags.writeable: + routed_experts_array = routed_experts_array.copy() + experts_tensor = torch.from_numpy(routed_experts_array) + elif isinstance(output.routed_experts, torch.Tensor): + experts_tensor = output.routed_experts + else: + raise TypeError(f"Unsupported type for routed_experts: {type(output.routed_experts)}") + routed_experts = torch.zeros(1, total_length, layer_num, topk_num, dtype=experts_tensor.dtype) + start_pos = prompt_output["input_ids"].shape[1] - len(output.prompt_ids) + end_pos = min(start_pos + length, total_length) + routed_experts[:, start_pos:end_pos] = experts_tensor.unsqueeze(0) + + multi_modal_inputs = self._compute_multi_modal_inputs(output, input_ids) + position_ids = self._compute_position_ids(input_ids, attention_mask, multi_modal_inputs) + + return _InternalAgentLoopOutput( + prompt_ids=prompt_output["input_ids"], + response_ids=response_output["input_ids"], + input_ids=input_ids, + position_ids=position_ids, + response_mask=response_mask, + attention_mask=attention_mask, + response_logprobs=response_logprobs, + routed_experts=routed_experts, + multi_modal_inputs=multi_modal_inputs, + multi_modal_data=output.multi_modal_data, + teacher_logprobs=None, + teacher_ids=None, + reward_score=None, + num_turns=output.num_turns, + metrics=output.metrics, + extra_fields=output.extra_fields, + ) + + def _build_reward_input( + self, padded_output: _InternalAgentLoopOutput, kwargs: dict, validate: bool + ) -> DataProto: + prompts = padded_output.prompt_ids + responses = padded_output.response_ids + attention_mask = padded_output.attention_mask + input_ids = padded_output.input_ids + position_ids = padded_output.position_ids + + batch = TensorDict( + { + "prompts": prompts, + "responses": responses, + "attention_mask": attention_mask, + "input_ids": input_ids, + "position_ids": position_ids, + }, + batch_size=1, + ) + + meta_info = {"validate": validate} + non_tensor_dict = {k: np.array([v]) for k, v in kwargs.items()} + non_tensor_dict["__num_turns__"] = np.array([padded_output.num_turns]) + non_tensor_dict["tool_extra_fields"] = np.array([padded_output.extra_fields], dtype=object) + + return DataProto(batch=batch, non_tensor_batch=non_tensor_dict, meta_info=meta_info) + + async def _impl_generate_single_for_reward_queue( + self, batch: DataProto + ) -> tuple[DataProto, DataProto, float, float]: + config = self.rollout_config + sampling_params = dict( + temperature=config.temperature, + top_p=config.top_p, + top_k=config.top_k, + repetition_penalty=1.0, + logprobs=config.calculate_log_probs, + ) + + if batch.meta_info.get("validate", False): + sampling_params["top_p"] = config.val_kwargs.top_p + sampling_params["top_k"] = config.val_kwargs.top_k + sampling_params["temperature"] = config.val_kwargs.temperature + + if "agent_name" not in batch.non_tensor_batch: + default_agent_loop = config.agent.default_agent_loop + batch.non_tensor_batch["agent_name"] = np.array([default_agent_loop] * len(batch), dtype=object) + + if "index" in batch.non_tensor_batch: + index_val = batch.non_tensor_batch["index"] + index = [index_val[0] if isinstance(index_val, (list, np.ndarray)) else index_val] + else: + index = [0] + + trajectory_info = await get_trajectory_info( + batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False) + ) + + kwargs = {k: v[0] for k, v in batch.non_tensor_batch.items()} + validate = batch.meta_info.get("validate", False) + + output = await self._run_reward_agent_loop( + sampling_params, trajectory_info[0], **kwargs + ) + + padded_output = self._padding_postprocess(output, validate=validate, **kwargs) + reward_input = self._build_reward_input(padded_output, kwargs, validate=validate) + + batch_td = TensorDict( + { + "prompts": padded_output.prompt_ids, + "responses": padded_output.response_ids, + "response_mask": padded_output.response_mask, + "input_ids": padded_output.input_ids, + "attention_mask": padded_output.attention_mask, + "position_ids": padded_output.position_ids, + }, + batch_size=1, + ) + if padded_output.response_logprobs is not None: + batch_td["rollout_log_probs"] = padded_output.response_logprobs + if padded_output.routed_experts is not None: + batch_td["routed_experts"] = padded_output.routed_experts + + non_tensor_batch = {} + non_tensor_batch["__num_turns__"] = np.array([padded_output.num_turns], dtype=np.int32) + + default_extra_keys = { + "turn_scores", "tool_rewards", "min_global_steps", "max_global_steps", + "extras", "inference_start_timestamp", "inference_end_timestamp", "inference_duration", + } + all_keys = set(padded_output.extra_fields.keys()) | default_extra_keys + for key in all_keys: + val = padded_output.extra_fields.get(key) + if isinstance(val, (list, tuple, dict)): + arr = np.empty(1, dtype=object) + arr[0] = val + else: + arr = np.array([val], dtype=object) + non_tensor_batch[key] = arr + + for k, v in batch.non_tensor_batch.items(): + if isinstance(v, (list, np.ndarray)): + non_tensor_batch[k] = np.array([v[0]], dtype=object) if len(v) == 1 else v[:1] + else: + non_tensor_batch[k] = np.array([v], dtype=object) + + if padded_output.multi_modal_inputs is not None: + non_tensor_batch["multi_modal_inputs"] = np.array([padded_output.multi_modal_inputs], dtype=object) + + metrics = [padded_output.metrics.model_dump()] + meta_info = {"metrics": metrics} + + padded_dp = DataProto(batch=batch_td, non_tensor_batch=non_tensor_batch, meta_info=meta_info) + + inf_start = padded_output.extra_fields.get("inference_start_timestamp", 0.0) + inf_end = padded_output.extra_fields.get("inference_end_timestamp", 0.0) + + return padded_dp, reward_input, inf_start, inf_end + + async def generate_sequences(self, batch: DataProto) -> tuple[DataProto, DataProto, float, float] | DataProto: + if batch.meta_info.get("validate", False): + return await super().generate_sequences(batch) + return await self._impl_generate_single_for_reward_queue(batch) + + diff --git a/reward_queue/config/fully_async.yaml b/reward_queue/config/fully_async.yaml new file mode 100644 index 00000000..47ef6aa2 --- /dev/null +++ b/reward_queue/config/fully_async.yaml @@ -0,0 +1,75 @@ +hydra: + searchpath: + - pkg://verl.trainer.config + +defaults: + - ppo_trainer + - _self_ + +async_training: + + # Maximum samples staleness threshold + staleness_threshold: 0.1 + + # Frequency of parameter synchronization between rollouter and trainer, + # One step means trainer obtains a batch of required samples + trigger_parameter_sync_step: 4 + + # The number of ppo_mini_batches that the FullyAsyncTrainer obtains once + require_batches: 1 + + # When synchronizing parameters, Whether to resume generation when rollout is interrupted. + # If True, LLMServerClient auto resume generation, making rollout interruption invisible to the AgentLoop. + partial_rollout: True + + # whether to use trainer do_validate + use_trainer_do_validate: False + + # 推理-打分解耦异步队列 + enable_reward_queue: false + + # RewardQueue 最大队列大小(仅enable_reward_queue=true时生效,null表示使用默认值) + reward_queue_size: null + +# Rollout config +rollout: + + # Number of nodes used in the rollout + nnodes: 1 + + # Number of GPUs per node + n_gpus_per_node: 8 + + # number of responses (i.e. num sample times). > 1 for grpo + n: 4 + + # total rollout samples # TODO rename to total_rollout_samples + total_rollout_steps: 100 + +data: + # Number of samples generated, currently only support 1 + gen_batch_size: 1 + +actor_rollout_ref: + + rollout: + # Must be enabled! Otherwise, log_probs cannot be calculated. + calculate_log_probs: True + + checkpoint_engine: + backend: "nccl" + + actor: + # Must use rollout log probs for training + use_rollout_log_probs: True + + model: + # To use remove padding (thd) + use_remove_padding: True + + +# Only then will the use of log probs be correct. +# And it can be used in conjunction with other rollout_correction algorithms. +algorithm: + rollout_correction: + bypass_mode: True diff --git a/reward_queue/images/reward_queue_architecture.png b/reward_queue/images/reward_queue_architecture.png new file mode 100644 index 00000000..898f18b0 Binary files /dev/null and b/reward_queue/images/reward_queue_architecture.png differ diff --git a/reward_queue/main.py b/reward_queue/main.py new file mode 100644 index 00000000..032fae4c --- /dev/null +++ b/reward_queue/main.py @@ -0,0 +1,142 @@ +# Copyright 2026 Huawei Technologies Co., Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import traceback + +import hydra +import ray +from omegaconf import OmegaConf +from verl.experimental.fully_async_policy.fully_async_main import FullyAsyncTaskRunner +from verl.experimental.reward_loop import migrate_legacy_reward_impl +from verl.utils.device import auto_set_device +from .reward_queue import create_reward_queue, RewardQueueClient +from .rollouter import Rollouter +from .trainer import Trainer +from verl.trainer.ppo.utils import Role +from verl.experimental.separation.utils import create_resource_pool_manager + +ray_metadata = getattr(FullyAsyncTaskRunner, "__ray_metadata__") +OriginalTaskRunner = ray_metadata.modified_class + +@ray.remote(num_cpus=1) +class TaskRunner(OriginalTaskRunner): + """ + Ray remote class for executing distributed PPO training tasks. + """ + def run(self, config): + print("[ASYNC MAIN] Starting fully async PPO training...") + self._initialize_components(config) + enable_reward_queue = config.async_training.get("enable_reward_queue", False) + if enable_reward_queue: + self._init_reward_queue_components(config) + self._run_training_loop() + + def _init_reward_queue_components(self, config): + max_reward_queue_size = ray.get(self.components["rollouter"].get_max_reward_queue_size.remote()) + print(f"[ASYNC MAIN] Creating RewardQueue... max_reward_queue_size {max_reward_queue_size}") + reward_queue = create_reward_queue(config, max_reward_queue_size) + reward_queue_client = RewardQueueClient(reward_queue) + self.components["reward_queue"] = reward_queue + self.components["reward_queue_client"] = reward_queue_client + + ray.get(self.components["rollouter"].set_reward_queue_client.remote(self.components["reward_queue_client"])) + + reward_loop_worker_handles = ray.get(self.components["rollouter"].get_reward_loop_worker_handles.remote()) + if reward_loop_worker_handles: + ray.get(self.components["rollouter"].set_reward_loop_worker_handles.remote(reward_loop_worker_handles)) + print(f"[ASYNC MAIN] Set reward_loop_worker_handles: {len(reward_loop_worker_handles)} workers") + else: + print("[ASYNC MAIN] WARNING: No reward_loop_worker_handles available, scoring will not work") + + def _create_rollouter(self, config) -> None: + print("[ASYNC MAIN] Starting create rollouter...") + rollouter = Rollouter.remote( + config=config, + tokenizer=self.components["tokenizer"], + processor=self.components["processor"], + device_name=config.trainer.device, + ) + + # set_hybrid_worker_group must be called BEFORE init_workers() so that + # _init_async_rollout_manager can pass the hybrid WG to ALM.create(). + if "hybrid_worker_group" in self.components: + ray.get(rollouter.set_hybrid_worker_group.remote(self.components["hybrid_worker_group"])) + print("[ASYNC MAIN] Hybrid worker group injected into rollouter") + + ray.get(rollouter.init_workers.remote()) + ray.get(rollouter.set_max_required_samples.remote()) + + self.components["rollouter"] = rollouter + print("[ASYNC MAIN] Rollouter created and initialized successfully") + + def _create_trainer(self, config) -> None: + print("[ASYNC MAIN] Starting create trainer...") + trainer_role_mapping = { + role: worker_cls + for role, worker_cls in self.components["role_worker_mapping"].items() + if role != Role.Rollout + } + + trainer = Trainer.remote( + config=config, + tokenizer=self.components["tokenizer"], + role_worker_mapping=trainer_role_mapping, + resource_pool_manager=create_resource_pool_manager(config, roles=list(trainer_role_mapping.keys())), + ray_worker_group_cls=self.components["ray_worker_group_cls"], + device_name=config.trainer.device, + ) + + ray.get(trainer.init_workers.remote()) + self.components["trainer"] = trainer + print("[ASYNC MAIN] FullyAsyncTrainer created and initialized successfully") + + +@hydra.main(config_path="config", config_name="fully_async", version_base=None) +def main(config): + from verl.trainer.main_ppo import run_ppo + + # Ensure async training config exists + if not hasattr(config, "async_training"): + raise RuntimeError("must set async_training config") + + from time import time + + start_time = time() + auto_set_device(config) + # TODO: unify rollout config with actor_rollout_ref + config.actor_rollout_ref.rollout.nnodes = config.rollout.nnodes + config.actor_rollout_ref.rollout.n_gpus_per_node = config.rollout.n_gpus_per_node + config = migrate_legacy_reward_impl(config) + + is_train_over = False + auto_resume_on_error = config.trainer.get("auto_resume_on_error", False) + while not is_train_over: + try: + run_ppo(config, task_runner_class=TaskRunner) + is_train_over = True + except (ray.exceptions.RayTaskError, Exception) as e: + print(e, str(traceback.format_exc())) + # raise e + finally: + ray.shutdown() + if not auto_resume_on_error: + break + + print("training process successfully!") + print(f"total time: {time() - start_time:.2f} seconds") + + +if __name__ == "__main__": + main() + diff --git a/reward_queue/reward_queue.py b/reward_queue/reward_queue.py new file mode 100644 index 00000000..458ef7a7 --- /dev/null +++ b/reward_queue/reward_queue.py @@ -0,0 +1,29 @@ +# Copyright 2026 Huawei Technologies Co., Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from omegaconf import DictConfig + +from verl.experimental.fully_async_policy.message_queue import MessageQueue, MessageQueueClient + +logger = logging.getLogger(__name__) + + +def create_reward_queue(config: DictConfig, max_queue_size: int = 1000): + # return MessageQueue.remote(config, max_queue_size, name="RewardQueue") + return MessageQueue.options(name="RewardQueue").remote(config, max_queue_size) + + +class RewardQueueClient(MessageQueueClient): + pass diff --git a/reward_queue/rollouter.py b/reward_queue/rollouter.py new file mode 100644 index 00000000..ff6ce9c5 --- /dev/null +++ b/reward_queue/rollouter.py @@ -0,0 +1,799 @@ +# Copyright 2026 Huawei Technologies Co., Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import os +import random +import time +from pprint import pformat + +import numpy as np +import ray +import torch +from omegaconf import DictConfig, OmegaConf +from .agent_loop.agent_loop import AgentLoopWorkerForRewardQueue +from .utils import ( + SampleAggregator, + SubRewardDataItem, + _ScoredSubItem, +) +from verl.experimental.fully_async_policy.detach_utils import ( + RolloutSample, + safe_create_task, +) +from verl.experimental.fully_async_policy.fully_async_rollouter import ( + FullyAsyncAgentLoopManager, + FullyAsyncRollouter, + FullyAsyncLLMServerManager, + FullyAsyncLLMServerClient +) +from .reward_queue import RewardQueueClient +from verl.protocol import DataProto +from verl.utils.profiler import marked_timer +from verl.utils.ray_utils import auto_await +from verl.utils.tracking import Tracking + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class AsyncAgentLoopManager(FullyAsyncAgentLoopManager): + def __init__(self, *args, **kwargs): + config = kwargs.get('config') + self.enable_reward_queue = config.async_training.get("enable_reward_queue", False) if config else False + if self.enable_reward_queue: + self.agent_loop_workers_class = ray.remote(AgentLoopWorkerForRewardQueue) + super().__init__(*args, **kwargs) + + @auto_await + async def generate_single_for_reward_queue( + self, batch: DataProto + ) -> tuple[DataProto, DataProto, float, float]: + worker = self._select_best_worker() + output_future = worker.generate_sequences.remote(batch) + return await asyncio.wrap_future(output_future.future()) + + +ray_metadata = getattr(FullyAsyncRollouter, "__ray_metadata__") +OriginalRollouter = ray_metadata.modified_class + +@ray.remote(num_cpus=10, max_concurrency=100) +class Rollouter(OriginalRollouter): + def __init__( + self, + config, + tokenizer, + processor=None, + device_name=None, + ): + super().__init__(config=config, tokenizer=tokenizer, processor=processor, device_name=device_name) + + self.reward_queue_client = None + self.reward_loop_worker_handles = None + + self.sample_aggregator = SampleAggregator() + self.rollout_n: int | None = None + self.scoring_paused = False + + self.enable_reward_queue = config.async_training.get("enable_reward_queue", False) + + def _init_async_objects(self): + self.lock = asyncio.Lock() + self._resume_event = asyncio.Event() + self._resume_event.set() + # `_scoring_resume_event` signals that the scoring is currently running (scoring_paused == False). + self._scoring_resume_event = asyncio.Event() + self._scoring_resume_event.set() + + async def set_reward_queue_client(self, reward_queue_client: RewardQueueClient): + async with self.lock: + self.reward_queue_client = reward_queue_client + + async def set_max_required_samples(self): + async with self.lock: + self.max_required_samples = int( + self.required_samples + * (self.staleness_threshold + 1) + * self.config.async_training.trigger_parameter_sync_step + ) + self.total_train_steps = int( + self.total_rollout_steps + / (self.required_samples * self.config.async_training.trigger_parameter_sync_step) + ) + + self.max_concurrent_samples = len(self.llm_server_manager.get_replicas()) * 16 + self.max_concurrent_samples = min(self.max_concurrent_samples, self.max_required_samples) + self.max_queue_size = self.max_required_samples + + self.rollout_n = self.config.actor_rollout_ref.rollout.n + if self.enable_reward_queue: + self._init_reward_queue_size() + + print( + f"[FullyAsyncRollouter] required_samples : {self.required_samples} " + f"max_required_samples: {self.max_required_samples} " + f"max_queue_size: {self.max_queue_size} " + f"total_train_steps: {self.total_train_steps} " + f"total_rollout_steps: {self.total_rollout_steps} " + f"max_concurrent_samples: {self.max_concurrent_samples} " + ) + + def _init_reward_queue_size(self): + reward_queue_size = self.config.async_training.get("reward_queue_size", None) + if reward_queue_size is not None: + self.max_reward_queue_size = reward_queue_size * self.rollout_n + else: + self.max_reward_queue_size = self.max_required_samples * self.rollout_n + + async def _init_async_rollout_manager(self): + enable_agent_reward_loop = not self.use_rm or self.config.reward.reward_model.enable_resource_pool + reward_loop_worker_handles = self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None + assert self.config.actor_rollout_ref.rollout.mode == "async" + + self.async_rollout_mode = True + self.llm_server_manager = await FullyAsyncLLMServerManager.create( + config=self.config, + worker_group=self.get_hybrid_worker_group(), + ) + self.async_rollout_manager = await AsyncAgentLoopManager.create( + config=self.config, + llm_client=self.llm_server_manager.get_client(client_cls=FullyAsyncLLMServerClient), + reward_loop_worker_handles=reward_loop_worker_handles, + teacher_client=self.teacher_model_manager.get_client() if self.teacher_model_manager else None, + ) + + def get_max_reward_queue_size(self): + return getattr(self, "max_reward_queue_size", None) + + def get_reward_loop_worker_handles(self): + if hasattr(self, "reward_loop_manager") and self.reward_loop_manager is not None: + if hasattr(self.reward_loop_manager, "reward_loop_workers"): + return self.reward_loop_manager.reward_loop_workers + return None + + def set_reward_loop_worker_handles(self, handles): + self.reward_loop_worker_handles = handles + + async def reset_staleness(self): + """ + Reset staleness samples after parameter update. + Returns timing_raw dictionary for metrics. + """ + async with self.lock: + self.paused = False + self.scoring_paused = False + # Wake the drain loop in _processor_worker so it can exit early and resume submitting + # new samples to idle replicas instead of waiting for long-tail in-flight tasks. + self._resume_event.set() + self._scoring_resume_event.set() + # every time param change, reset staleness_samples + self.staleness_samples = len(self.active_tasks) + await self.message_queue_client.get_queue_size() + timing_raw = {} + rollout_version_time = max(time.time() - self.step_start_time, 1e-6) + if self.idle_start_time > self.step_start_time: + rollout_active_time = self.idle_start_time - self.step_start_time + idle_ratio = 1 - rollout_active_time / rollout_version_time + else: + rollout_active_time = rollout_version_time + idle_ratio = 0 + timing_raw["fully_async/rollouter/active_time"] = rollout_active_time + timing_raw["fully_async/rollouter/version_time"] = rollout_version_time + timing_raw["fully_async/rollouter/idle_ratio"] = idle_ratio + + print( + f"[FullyAsyncRollouter][Public][reset_staleness] " + f"reset staleness_samples to: {self.staleness_samples} " + f"idle_ratio: {timing_raw['fully_async/rollouter/idle_ratio']:.4f}" + ) + self.step_start_time = time.time() + + return timing_raw + + def do_validate(self): + """Run validation and return metrics""" + timing_raw = {} + with marked_timer("rollouter/validate_time", timing_raw, color="green"): + val_metrics: dict = self._validate() + return timing_raw | val_metrics + + async def _processor_worker(self): + """ + Streaming worker coroutines, a sample is submitted for processing without waiting for batches + """ + while True: + if self.paused or await self._should_pause_generation(): + print( + "[FullyAsyncRollouter][Processor] Received pause signal, waiting for remaining tasks to return..." + ) + async with self.lock: + self.paused = True + self._resume_event.clear() + + resume_future = asyncio.ensure_future(self._resume_event.wait()) + try: + # Drain: wait for either (a) at least one active task to finish, or + # (b) a resume signal (reset_staleness / monitor flipping paused=False) to + # break the drain early so new samples can be submitted to free replicas. + # We do NOT hold the lock during the wait, so publishers can acquire it to + # update paused / staleness_samples concurrently. + while self.active_tasks and not resume_future.done(): + wait_set = set(self.active_tasks) | {resume_future} + done, _pending = await asyncio.wait(wait_set, return_when=asyncio.FIRST_COMPLETED) + actual_done = done - {resume_future} + if actual_done: + async with self.lock: + for task in actual_done: + self.active_tasks.discard(task) + await task + if resume_future in done: + print( + "[FullyAsyncRollouter][Processor] " + "Drain interrupted by resume signal, resuming generation early " + f"(active tasks remaining: {len(self.active_tasks)})" + ) + break + + # block until resuming + if not resume_future.done(): + self.idle_start_time = time.time() + await resume_future + finally: + if not resume_future.done(): + resume_future.cancel() + await asyncio.gather(resume_future, return_exceptions=True) + continue + # Get sample from appropriate queue and immediately mark task as done + rollout_sample = await self.pending_queue.get() + self.pending_queue.task_done() + self.staleness_samples += 1 + + if rollout_sample is None: + print( + "[FullyAsyncRollouter][Processor] Received end signal, waiting for remaining tasks to complete..." + ) + while self.active_tasks: + async with self.lock: + if self.active_tasks: + done_tasks, self.active_tasks = await asyncio.wait( + self.active_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done_tasks: + await task + # Signal RewardQueue that no more items will be produced + if self.enable_reward_queue: + await self.reward_queue_client.shutdown() + break + + # Check whether the number of concurrent tasks exceeds the limit + while len(self.active_tasks) >= self.max_concurrent_samples: + async with self.lock: + if self.active_tasks: + done_tasks, self.active_tasks = await asyncio.wait( + self.active_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done_tasks: + await task + + # Submit single sample processing + if self.paused: + await self._resume_event.wait() + async with self.lock: + if self.enable_reward_queue: + task = safe_create_task( + self._process_sample_with_reward_queue(rollout_sample), + name=rollout_sample.sample_id, + task_set=self.active_tasks, + ) + else: + task = safe_create_task( + self._process_single_sample_streaming(rollout_sample), + name=rollout_sample.sample_id, + task_set=self.active_tasks, + ) + + async def _process_sample_with_reward_queue(self, rollout_sample: RolloutSample): + batch = rollout_sample.full_batch + n = len(batch) + sample_id = rollout_sample.sample_id + generate_start = time.time() + + inference_tasks = {} + for i in range(n): + sub_batch = batch[i: i + 1] + task = asyncio.create_task( + self.async_rollout_manager.generate_single_for_reward_queue(sub_batch) + ) + inference_tasks[task] = i + + pending = set(inference_tasks.keys()) + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + i = inference_tasks[task] + try: + padded_dp, reward_input_dp, inf_start, inf_end = task.result() + except Exception as e: + print( + f"[FullyAsyncRollouter][Processor] Inference FAILED for {sample_id}[{i}]: {e}" + ) + self.dropped_stale_samples += 1 + continue + + sub_item = SubRewardDataItem( + reward_input=reward_input_dp, + padded_output=padded_dp, + sample_id=sample_id, + epoch=rollout_sample.epoch, + sub_index=i, + total_count=n, + inference_start_timestamp=inf_start, + inference_end_timestamp=inf_end, + ) + + success = await self.reward_queue_client.put_sample(sub_item) + if not success: + self.dropped_stale_samples += 1 + + generate_end = time.time() + + self.processed_sample_count += 1 + if self.processed_sample_count % 10 == 0: + rq_stats = await self.reward_queue_client.get_statistics() + print( + f"[FullyAsyncRollouter][Processor] Inference progress: " + f"processed={self.processed_sample_count} " + f"rq_size={rq_stats['queue_size']} sample_id={sample_id} " + f"generate_time={generate_end - generate_start:.2f}s" + ) + + async def _reward_consumer_worker(self): + active_reward_tasks = set() + max_concurrent_rewards = getattr(self, "max_concurrent_samples", 16) * (self.rollout_n or 1) + self.scoring_paused = False + diagnostic_counter = 0 + + while True: + diagnostic_counter += 1 + if diagnostic_counter % 50 == 0: + try: + rq_stats = await self.reward_queue_client.get_statistics() + mq_stats = await self.message_queue_client.get_statistics() + print( + f"[FullyAsyncRollouter][RewardConsumer] diagnostic: " + f"rq_size={rq_stats['queue_size']} rq_produced={rq_stats['total_produced']} " + f"rq_consumed={rq_stats['total_consumed']} " + f"mq_size={mq_stats['queue_size']} mq_produced={mq_stats['total_produced']} " + f"active_reward_tasks={len(active_reward_tasks)} " + f"aggregator_pending={self.sample_aggregator.pending_groups_count} " + f"scoring_paused={self.scoring_paused} " + f"staleness_samples={self.staleness_samples}" + ) + except Exception as e: + print(f"[FullyAsyncRollouter][RewardConsumer] diagnostic failed: {e}") + + if self.scoring_paused: + await self._scoring_resume_event.wait() + if self.scoring_paused or await self._should_pause_scoring(): + mq_stats = await self.message_queue_client.get_statistics() + print( + f"[FullyAsyncRollouter][RewardConsumer] Pausing scoring: " + f"mq_size={mq_stats['queue_size']} >= max_queue_size={self.max_queue_size}" + ) + async with self.lock: + self.scoring_paused = True + self._scoring_resume_event.clear() + + scoring_resume_future = asyncio.ensure_future(self._scoring_resume_event.wait()) + try: + while active_reward_tasks and not scoring_resume_future.done(): + wait_set = set(active_reward_tasks) | {scoring_resume_future} + done, _pending = await asyncio.wait(wait_set, return_when=asyncio.FIRST_COMPLETED) + actual_done = done - {scoring_resume_future} + if actual_done: + for task in actual_done: + active_reward_tasks.discard(task) + await task + if scoring_resume_future in done: + print( + "[FullyAsyncRollouter][RewardConsumer] " + "Drain interrupted by resume signal, resuming scoring early " + f"(active_reward_tasks remaining: {len(active_reward_tasks)})" + ) + break + + if not scoring_resume_future.done(): + await scoring_resume_future + finally: + if not scoring_resume_future.done(): + scoring_resume_future.cancel() + await asyncio.gather(scoring_resume_future, return_exceptions=True) + continue + + while len(active_reward_tasks) >= max_concurrent_rewards: + if not active_reward_tasks: + break + done_tasks, active_reward_tasks = await asyncio.wait( + active_reward_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done_tasks: + await task + + result = await self.reward_queue_client.get_sample() + if result is None: + while active_reward_tasks: + done_tasks, active_reward_tasks = await asyncio.wait( + active_reward_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done_tasks: + await task + break + + item, _ = result + + task = safe_create_task( + self._process_single_sub_reward_item(item), + name=f"reward_{item.sample_id}_{item.sub_index}", + task_set=active_reward_tasks, + ) + + async def _process_single_sub_reward_item(self, item: SubRewardDataItem): + sample_id = item.sample_id + + if not self.reward_loop_worker_handles: + print(f"[FullyAsyncRollouter][RewardConsumer] No reward_loop_worker_handles, discard: {sample_id}[{item.sub_index}]") + self.dropped_stale_samples += 1 + return + + reward_start = time.time() + try: + selected_worker = random.choice(self.reward_loop_worker_handles) + score_result = await selected_worker.compute_score.remote(item.reward_input) + score = score_result["reward_score"] + reward_extra_info = score_result.get("reward_extra_info", {}) + reward_compute_time = score_result.get("reward_compute_time", 0.0) + except Exception as e: + reward_compute_time = 0.0 + score = -1.0 + reward_extra_info = {} + print(f"[FullyAsyncRollouter][RewardConsumer] Scoring FAILED for {sample_id}[{item.sub_index}]: {e}") + reward_end = time.time() + + scored_item = _ScoredSubItem( + sub_index=item.sub_index, + padded_output=item.padded_output, + score=score, + reward_extra_info=reward_extra_info, + reward_start=reward_start, + reward_end=reward_end, + reward_compute_time=reward_compute_time, + inference_start=item.inference_start_timestamp, + inference_end=item.inference_end_timestamp, + ) + is_complete = self.sample_aggregator.add_scored_item( + sample_id=sample_id, + total_count=item.total_count, + epoch=item.epoch, + scored_item=scored_item, + ) + + if is_complete: + await self._finalize_sample(sample_id) + + async def _finalize_sample(self, sample_id: str): + group = self.sample_aggregator.get_and_remove(sample_id) + sorted_items = [group.items[i] for i in sorted(group.items.keys())] + total_count = group.total_count + + scores = [si.score for si in sorted_items] + all_reward_extra_info = [si.reward_extra_info for si in sorted_items] + reward_start_ts_list = [si.reward_start for si in sorted_items] + reward_end_ts_list = [si.reward_end for si in sorted_items] + reward_compute_time_list = [si.reward_compute_time for si in sorted_items] + inference_start_list = [si.inference_start for si in sorted_items] + inference_end_list = [si.inference_end for si in sorted_items] + + failed_count = sum(1 for s in scores if s == -1.0) + print( + f"[FullyAsyncRollouter][RewardConsumer] Score summary for {sample_id}: " + f"bsz={total_count} scores={scores} failed={failed_count} " + f"abnormal_score_threshold={self.config.trainer.get('dapo_threshold', 0)}" + ) + + final_batch = DataProto.concat([si.padded_output for si in sorted_items]) + + prompt_length = final_batch.batch["prompts"].shape[1] + response_mask = final_batch.batch["response_mask"] + attention_mask = final_batch.batch["attention_mask"] + valid_response_length = attention_mask[:, prompt_length:].sum(dim=1) - 1 + rm_scores = torch.zeros_like(response_mask, dtype=torch.float32) + rm_scores[torch.arange(response_mask.size(0)), valid_response_length] = torch.tensor( + scores, dtype=torch.float32 + ) + final_batch.batch["rm_scores"] = rm_scores + + final_batch.non_tensor_batch["score"] = np.array(scores) + reward_extra_keys = [] + if all_reward_extra_info and isinstance(all_reward_extra_info[0], dict): + reward_extra_keys = list(all_reward_extra_info[0].keys()) + for key in reward_extra_keys: + final_batch.non_tensor_batch[key] = np.array( + [info.get(key) for info in all_reward_extra_info], dtype=object + ) + final_batch.non_tensor_batch["reward_compute_time"] = np.array(reward_compute_time_list) + final_batch.non_tensor_batch["reward_start_timestamp"] = np.array(reward_start_ts_list) + final_batch.non_tensor_batch["reward_end_timestamp"] = np.array(reward_end_ts_list) + final_batch.non_tensor_batch["inference_start_timestamp"] = np.array(inference_start_list) + final_batch.non_tensor_batch["inference_end_timestamp"] = np.array(inference_end_list) + + if reward_extra_keys: + final_batch.meta_info["reward_extra_keys"] = reward_extra_keys + + final_batch.non_tensor_batch["uid"] = np.array([f"uid_{sample_id}"] * len(final_batch), dtype=object) + rollout_status = await self.get_statistics() + + rollout_sample = RolloutSample( + full_batch=final_batch, + sample_id=sample_id, + epoch=group.epoch, + rollout_status=rollout_status, + ) + + success = await self.message_queue_client.put_sample( + sample=ray.cloudpickle.dumps(rollout_sample), + ) + + if success: + self.total_generated_samples += 1 + if self.total_generated_samples % 10 == 0: + mq_stats = await self.message_queue_client.get_statistics() + print( + f"[FullyAsyncRollouter][RewardConsumer] Put to MQ success: " + f"total_generated={self.total_generated_samples} " + f"mq_size={mq_stats['queue_size']} sample_id={sample_id}" + ) + else: + self.dropped_stale_samples += 1 + print( + f"[FullyAsyncRollouter][RewardConsumer] Put to MQ DROPPED: " + f"sample_id={sample_id} total_dropped={self.dropped_stale_samples}" + ) + + async def _should_pause_scoring(self) -> bool: + mq_stats = await self.message_queue_client.get_statistics() + mq_size = mq_stats["queue_size"] + if mq_size >= self.max_queue_size: + return True + return False + + def _maybe_create_reward_consumer_task(self): + if self.enable_reward_queue: + return safe_create_task(self._reward_consumer_worker(), name="reward_consumer_task") + return None + + async def _maybe_wait_reward_consumer(self): + if self.reward_consumer_task: + await self.reward_consumer_task + print("[FullyAsyncRollouter] Reward consumer completed") + + async def _maybe_cancel_reward_consumer(self): + if self.reward_consumer_task and not self.reward_consumer_task.done(): + self.reward_consumer_task.cancel() + await asyncio.gather(self.reward_consumer_task, return_exceptions=True) + + async def _streaming_generation_main(self): + """The main entry method for stream processing""" + + if self.async_rollout_manager is None: + await self._init_async_rollout_manager() + + # Start the streaming loop + print(f"[FullyAsyncRollouter] Start streaming mode, maximum concurrent samples: {self.max_concurrent_samples}" + + (" (with RewardQueue)" if self.enable_reward_queue else "")) + + # Start sample feed coroutine, streaming process coroutine + self.feed_task = safe_create_task(self._feed_samples(), name="feed_task") + self.processor_task = safe_create_task(self._processor_worker(), name="processor_task") + self.reward_consumer_task = self._maybe_create_reward_consumer_task() + + try: + # Wait for sample feed to complete + # Use asyncio.wait to monitor all tasks. If processor exits early, + # detect it instead of blocking on feed_task (it might be stuck on a full queue). + tasks_to_wait = [self.feed_task, self.processor_task] + if self.reward_consumer_task: + tasks_to_wait.append(self.reward_consumer_task) + + done, pending = await asyncio.wait( + tasks_to_wait, return_when=asyncio.FIRST_COMPLETED + ) + + for task in done: + if task.exception(): + raise task.exception() + + if self.feed_task not in done: + raise RuntimeError("Processor task exited prematurely") + + print("[FullyAsyncRollouter] Sample feed completed") + + # Wait for streaming to complete + await self.processor_task + print("[FullyAsyncRollouter] Streaming process completed") + + await self.pending_queue.join() + print("[FullyAsyncRollouter] pending_queue joined") + + await self._maybe_wait_reward_consumer() + + except Exception as e: + print(f"[FullyAsyncRollouter] Streaming process exception: {e}") + raise e + + finally: + if self.feed_task and not self.feed_task.done(): + self.feed_task.cancel() + await asyncio.gather(self.feed_task, return_exceptions=True) + + if self.processor_task and not self.processor_task.done(): + self.processor_task.cancel() + await asyncio.gather(self.processor_task, return_exceptions=True) + + await self._maybe_cancel_reward_consumer() + + self.feed_task = None + self.processor_task = None + self.reward_consumer_task = None + + # Send a finish signal + await self.message_queue_client.put_sample(sample=None) + + async with self.lock: + self.running = False + + async def fit(self): + """ + Start the async rollouter - entry point that sets up and runs async tasks + Main async fit method that coordinates all coroutines + """ + + print("[FullyAsyncRollouter] Starting FullyAsyncRollouter...") + + if self.message_queue_client is None: + raise ValueError("MessageQueue client not set. Call set_message_queue_client() first.") + + # Set the running status flag + async with self.lock: + self.paused = False + self.scoring_paused = False + self.running = True + self._resume_event.set() + self._scoring_resume_event.set() + + # Create the main asynchronous task + generation_task = safe_create_task(self._streaming_generation_main(), name="generation_task") + monitor_task = safe_create_task(self._async_monitor_loop(), name="monitor_task") + + try: + # Run build and monitoring tasks concurrently + await asyncio.gather(generation_task, monitor_task, return_exceptions=True) + except Exception as e: + print(f"[FullyAsyncRollouter] Asynchronous task execution error: {e}") + finally: + if not generation_task.done(): + generation_task.cancel() + if not monitor_task.done(): + monitor_task.cancel() + + # Wait for the task to complete + await asyncio.gather(generation_task, monitor_task, return_exceptions=True) + + print("[FullyAsyncRollouter] Rollouter fit completed") + + async def _async_monitor_loop(self): + """ + Async coroutine for monitoring: + Function 1: Log information output + Function 2: Trigger rollout recovery + """ + last_stats_time = time.time() + stats_interval = 60.0 + check_interval = 10.0 + + while True: + async with self.lock: + if not self.running: + break + await asyncio.sleep(check_interval) + # Print statistics periodically + current_time = time.time() + if current_time - last_stats_time >= stats_interval: + stats = await self.get_statistics() + print(f"[FullyAsyncRollouter][MonitorLoop][Statistics] {pformat(stats)}") + last_stats_time = current_time + + # Trigger rollout recovery + if self.paused and not await self._should_pause_generation(): + async with self.lock: + self.paused = False + print("[FullyAsyncRollouter][ShouldPause] resume rollouter.") + self._resume_event.set() + + # Trigger scoring recovery + if self.enable_reward_queue: + await self._maybe_resume_scoring() + + async def _maybe_resume_scoring(self): + if self.scoring_paused and not await self._should_pause_scoring(): + async with self.lock: + self.scoring_paused = False + print("[FullyAsyncRollouter][Monitor] Resuming scoring, calling _scoring_resume_event.set()") + self._scoring_resume_event.set() + + async def _should_pause_reward_queue(self) -> bool: + reward_queue_stats = await self.reward_queue_client.get_statistics() + reward_queue_size = reward_queue_stats["queue_size"] + + if reward_queue_size >= self.max_reward_queue_size: + if not self.paused: + print( + f"[FullyAsyncRollouter][ShouldPause] " + f"due to RewardQueue full: size={reward_queue_size}, max={self.max_reward_queue_size}" + ) + return True + return False + + async def _should_pause_generation(self) -> bool: + """Determine whether the build should be paused""" + if self.enable_reward_queue: + return await self._should_pause_reward_queue() + + return super()._should_pause_generation() + + async def get_statistics(self) -> dict: + queue_stats = await self.message_queue_client.get_statistics() + reward_queue_stats = await self.reward_queue_client.get_statistics() if self.reward_queue_client else None + + stats = { + # monitor stats + "monitor/active_tasks_size": len(self.active_tasks), + "monitor/queue/pending_queue_size": self.pending_queue.qsize(), + "monitor/queue/mq_queue_size": queue_stats["queue_size"], + # counting stats + "count/total_generated_samples": self.total_generated_samples, + "count/staleness_samples": self.staleness_samples, + "count/dropped_stale_samples": self.dropped_stale_samples, + # static stats + "static/max_required_samples": self.max_required_samples, + "static/required_samples": self.required_samples, + "static/staleness_threshold": self.staleness_threshold, + "static/max_queue_size": self.max_queue_size, + "static/max_concurrent_samples": self.max_concurrent_samples, + } + + if reward_queue_stats: + stats["monitor/queue/reward_queue_size"] = reward_queue_stats["queue_size"] + stats["reward_queue/total_produced"] = reward_queue_stats["total_produced"] + stats["reward_queue/total_consumed"] = reward_queue_stats["total_consumed"] + stats["reward_queue/dropped_samples"] = reward_queue_stats["dropped_samples"] + + if hasattr(self, "max_reward_queue_size"): + stats["static/max_reward_queue_size"] = self.max_reward_queue_size + if self.rollout_n is not None: + stats["static/rollout_n"] = self.rollout_n + if self.sample_aggregator is not None: + stats["aggregator/pending_groups_count"] = self.sample_aggregator.pending_groups_count + stats["aggregator/total_pending"] = self.sample_aggregator.total_pending + + return stats diff --git a/reward_queue/train_async.sh b/reward_queue/train_async.sh new file mode 100644 index 00000000..e82eabb5 --- /dev/null +++ b/reward_queue/train_async.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash + +set -xeuo pipefail + +NNODES=${NNODES:-1} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +PROJECT_NAME=${PROJECT_NAME:-'reward_queue'} +EXP_NAME=${EXP_NAME:-'verl_async_train'} +MODEL_PATH=${MODEL_PATH:-'Qwen3.5-9B'} +TRAIN_FILE=${TRAIN_FILE:-'./gsm8k/train/gsm8k_tra.jsonl'} +VAL_FILE=${VAL_FILE:-'./gsm8k/eval/gsm8k_ev.jsonl'} +CKPTS_DIR=${CKPTS_DIR:-"./ckpts/${project_name}/${exp_name}"} + +TRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-8} +N_SAMPLE=${N_SAMPLE:-8} +VAL_N_SAMPLE=${VAL_N_SAMPLE:-5} + +MAX_PROMPT_LENGTH=${MAX_PROMPT_LENGTH:-1000} +MAX_RESPONSE_LENGTH=${MAX_RESPONSE_LENGTH:-2000} + +LR=${LR:-1e-6} +TOTAL_TRAINING_STEPS=${TOTAL_TRAINING_STEPS:-500} +TEST_FREQ=${TEST_FREQ:-5} + +ASYNC_STALENESS=${ASYNC_STALENESS:-0.3} +ASYNC_SYNC_STEP=${ASYNC_SYNC_STEP:-2} +ASYNC_REQUIRE_BATCHES=${ASYNC_REQUIRE_BATCHES:-4} + +TENSOR_MODEL_PARALLEL_SIZE=${TENSOR_MODEL_PARALLEL_SIZE:-8} + +python -m recipe.reward_queue.main \ + --config-path=config \ + --config-name='fully_async' \ + algorithm.adv_estimator=grpo \ + data.max_prompt_length=${MAX_PROMPT_LENGTH} \ + data.max_response_length=${MAX_RESPONSE_LENGTH} \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${VAL_FILE}" \ + rollout.nnodes=${NNODES} \ + rollout.n_gpus_per_node=${NGPUS_PER_NODE} \ + rollout.total_rollout_steps=${TOTAL_TRAINING_STEPS} \ + rollout.test_freq=${TEST_FREQ} \ + async_training.staleness_threshold=${ASYNC_STALENESS} \ + async_training.trigger_parameter_sync_step=${ASYNC_SYNC_STEP} \ + async_training.require_batches=${ASYNC_REQUIRE_BATCHES} \ + async_training.partial_rollout=true \ + async_training.use_trainer_do_validate=false \ + async_training.enable_reward_queue=true \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${TENSOR_MODEL_PARALLEL_SIZE} \ + actor_rollout_ref.rollout.n=${N_SAMPLE} \ + actor_rollout_ref.rollout.val_kwargs.n=${VAL_N_SAMPLE} \ + actor_rollout_ref.actor.optim.lr=${LR} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${TRAIN_BATCH_SIZE} \ + trainer.nnodes=${NNODES} \ + trainer.n_gpus_per_node=${NGPUS_PER_NODE} \ + trainer.test_freq=${TEST_FREQ} \ + trainer.save_freq=${TEST_FREQ} \ + trainer.project_name="${PROJECT_NAME}" \ + trainer.experiment_name="${EXP_NAME}" \ + trainer.default_local_dir="${CKPTS_DIR}" \ + "$@" + diff --git a/reward_queue/trainer.py b/reward_queue/trainer.py new file mode 100644 index 00000000..c1135abe --- /dev/null +++ b/reward_queue/trainer.py @@ -0,0 +1,111 @@ +# Copyright 2026 Huawei Technologies Co., Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from typing import Any + +import ray +from omegaconf import OmegaConf, open_dict + +from .utils import assemble_batch_from_rollout_samples +from verl.experimental.fully_async_policy.fully_async_trainer import FullyAsyncTrainer +from verl.single_controller.ray import RayWorkerGroup +from verl.trainer.ppo.ray_trainer import ResourcePoolManager +from verl.trainer.ppo.utils import Role, WorkerType + +logger = logging.getLogger(__name__) + +ray_metadata = getattr(FullyAsyncTrainer, "__ray_metadata__") +OriginalTrainer = ray_metadata.modified_class + +@ray.remote(num_cpus=10) +class Trainer(OriginalTrainer): + """ + A fully asynchronous PPO trainer that obtains samples from a MessageQueue for training. + Based on an improved implementation of OneStepOffRayTrainer + """ + + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + device_name=None, + ): + super().__init__(config, tokenizer, role_worker_mapping, resource_pool_manager, ray_worker_group_cls, device_name) + + self.enable_reward_queue = config.async_training.get("enable_reward_queue", False) + + async def _get_samples_from_queue(self) -> tuple[None, None] | tuple[int, Any]: + """ + Get samples from message queue and compose gen_batch_output + Uses a loop to continuously collect samples until enough are gathered + + Returns: + tuple: (epoch, batch_dict, gen_batch_output) + """ + print( + f"[FullyAsyncTrainer] Requesting {self.required_samples} samples from queue", + flush=True, + ) + + # Collect samples using a simple loop calling get_sample + consumer_start = time.time() + queue_samples = [] + queue_len = 0 + while len(queue_samples) < self.required_samples: + # Get a single sample and wait until there is a sample or None is received + sample, queue_len = await self.message_queue_client.get_sample() + + if sample is None: + print( + f"[FullyAsyncTrainer] Detected termination signal (None), stopping sample collection. " + f"Collected {len(queue_samples)}/{self.required_samples} samples" + ) + break + + queue_samples.append(sample) + + if len(queue_samples) % 64 == 0: + print( + f"[FullyAsyncTrainer] Collected {len(queue_samples)}/{self.required_samples} samples. " + f"mq_len: {queue_len}" + ) + + consumer_end = time.time() + + if not queue_samples or len(queue_samples) < self.required_samples: + print("[FullyAsyncTrainer] not enough samples collected after loop") + return None, None + total_wait_time = consumer_end - consumer_start + + print( + f"[FullyAsyncTrainer] Loop collection completed: {len(queue_samples)}/{self.required_samples} samples, " + f"total wait time: {total_wait_time:.2f} seconds. " + f"mq_len: {queue_len}" + ) + + queue_samples = [ray.cloudpickle.loads(x) for x in queue_samples] + # Assemble batch - now working directly with RolloutSample objects + if self.config.trainer.balance_batch: + batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, self._balance_batch, enable_reward_queue=self.enable_reward_queue) + else: + batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, None, + enable_reward_queue=self.enable_reward_queue) + + batch.meta_info["fully_async/total_wait_time"] = total_wait_time + return 0, batch diff --git a/reward_queue/utils.py b/reward_queue/utils.py new file mode 100644 index 00000000..e691be2a --- /dev/null +++ b/reward_queue/utils.py @@ -0,0 +1,295 @@ +# Copyright 2026 Huawei Technologies Co., Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import torch + +from verl import DataProto +from verl.trainer.ppo.ray_trainer import compute_response_mask +from verl.experimental.fully_async_policy.detach_utils import RolloutSample + + +@dataclass +class SubRewardDataItem: + reward_input: DataProto + padded_output: DataProto + sample_id: str + epoch: int + sub_index: int + total_count: int + inference_start_timestamp: float + inference_end_timestamp: float + generate_data: Any = None + enqueue_data: Any = None + + +@dataclass +class _ScoredSubItem: + sub_index: int + padded_output: DataProto + score: float + reward_extra_info: dict + reward_start: float + reward_end: float + reward_compute_time: float + inference_start: float + inference_end: float + + +@dataclass +class _AggregationGroup: + total_count: int + epoch: int + items: dict[int, _ScoredSubItem] = field(default_factory=dict) + + def add(self, item: _ScoredSubItem): + self.items[item.sub_index] = item + + @property + def is_complete(self) -> bool: + return len(self.items) == self.total_count + + +class SampleAggregator: + + def __init__(self): + self._groups: dict[str, _AggregationGroup] = {} + + def add_scored_item( + self, + sample_id: str, + total_count: int, + epoch: int, + scored_item: _ScoredSubItem, + ) -> bool: + if sample_id not in self._groups: + self._groups[sample_id] = _AggregationGroup(total_count=total_count, epoch=epoch) + group = self._groups[sample_id] + group.add(scored_item) + return group.is_complete + + def get_and_remove(self, sample_id: str) -> _AggregationGroup: + return self._groups.pop(sample_id) + + @property + def total_pending(self) -> int: + return sum(len(g.items) for g in self._groups.values()) + + @property + def pending_sample_ids(self) -> list[str]: + return list(self._groups.keys()) + + @property + def pending_groups_count(self) -> int: + return len(self._groups) + + +def addition_process(output: DataProto, enable_reward_queue: bool = False): + """collect metirics""" + metrics = output.meta_info.pop("metrics") # List[Dict[str, str]] + if enable_reward_queue: + return _addition_process_with_reward_queue(output, metrics) + + processing_times_list = [item["generate_sequences"] for item in metrics] + tool_calls_times_list = [item["tool_calls"] for item in metrics] + + # Collect reward_compute_time if available + reward_compute_times_list = [] + for item in metrics: + if "reward_compute_time" in item: + reward_compute_times_list.append(item["reward_compute_time"]) + + output.non_tensor_batch["processing_times"] = processing_times_list + output.non_tensor_batch["tool_calls_times"] = tool_calls_times_list + + # Store reward_compute_time if any sample has it + if reward_compute_times_list: + output.non_tensor_batch["reward_compute_times"] = reward_compute_times_list + + return output + + +def _addition_process_with_reward_queue(output, metrics): + metrics = _normalize_metrics(metrics) + processing_times_list = [ + item.generate_sequences if hasattr(item, "generate_sequences") else item.get("generate_sequences", 0.0) + for item in metrics + ] + tool_calls_times_list = [ + item.tool_calls if hasattr(item, "tool_calls") else item.get("tool_calls", 0.0) + for item in metrics + ] + reward_compute_times_list = [] + for item in metrics: + rct = item.reward_compute_time if hasattr(item, "reward_compute_time") else item.get("reward_compute_time") + if rct is not None: + reward_compute_times_list.append(rct) + output.non_tensor_batch["processing_times"] = np.array(processing_times_list, dtype=np.float64) + output.non_tensor_batch["tool_calls_times"] = np.array(tool_calls_times_list, dtype=np.float64) + if reward_compute_times_list: + output.non_tensor_batch["reward_compute_times"] = reward_compute_times_list + + _flatten_non_tensor_batch(output) + return output + + +def _normalize_metrics(metrics): + if not metrics: + return [] + if isinstance(metrics, dict): + keys = list(metrics.keys()) + n = len(metrics[keys[0]]) if keys else 0 + return [{k: metrics[k][i] for k in keys} for i in range(n)] + if isinstance(metrics, list): + flat = [] + for item in metrics: + if isinstance(item, list): + flat.extend(item) + else: + flat.append(item) + return flat + return [metrics] + + +def _flatten_non_tensor_batch(data_proto: DataProto) -> None: + batch_size = len(data_proto) + for key in list(data_proto.non_tensor_batch.keys()): + arr = data_proto.non_tensor_batch[key] + if not isinstance(arr, np.ndarray): + arr = np.array(arr, dtype=object) + if arr.dtype == np.object_ and arr.ndim == 1: + continue + new_arr = np.empty(batch_size, dtype=object) + if arr.ndim == 1: + new_arr[:] = list(arr) + elif arr.ndim == 2: + for i in range(batch_size): + new_arr[i] = arr[i] + else: + for i in range(batch_size): + new_arr[i] = arr[i] + data_proto.non_tensor_batch[key] = new_arr + + + +def assemble_batch_from_rollout_samples( + rollout_samples: list[RolloutSample], tokenizer, config, balance_batch=None, enable_reward_queue: bool = False +) -> DataProto: + """ + Assemble gen_batch_output from RolloutSample objects + Assembles batches from RolloutSample objects, similar to the _post_generate_batch logic in ray_trainer. + + Args: + rollout_samples: List of RolloutSample objects + tokenizer: Tokenizer instance + config: Configuration object containing trainer settings + balance_batch: Whether to balance the batch (simplified version) + + Returns: + DataProto: Assembled gen_batch_output + + Raises: + ValueError: If rollout_samples is empty + """ + start_time = time.time() + + if not rollout_samples: + raise ValueError("Empty rollout_samples provided for batch assembly") + + print(f"[BatchUtils] Assembling batch from {len(rollout_samples)} RolloutSample objects") + + rollout_samples_batch = [] + rollout_status = rollout_samples[0].rollout_status + # Add a prefix to all rollout_status keys + rollout_status = {f"fully_async/{key}": value for key, value in rollout_status.items()} + + for rs in rollout_samples: + batch = addition_process(rs.full_batch, enable_reward_queue=enable_reward_queue) + rollout_samples_batch.append(batch) + final_batch = DataProto.concat(rollout_samples_batch) + + # Calculate response_mask (if not present) + if "response_mask" not in final_batch.batch.keys(): + final_batch.batch["response_mask"] = compute_response_mask(final_batch) + + if balance_batch: + balance_batch(final_batch, metrics={}) + + # Calculate the global valid token number + if "attention_mask" in final_batch.batch: + final_batch.meta_info["global_token_num"] = torch.sum(final_batch.batch["attention_mask"], dim=-1).tolist() + + processing_times = final_batch.non_tensor_batch["processing_times"] + tool_calls = final_batch.non_tensor_batch["tool_calls_times"] + # Collect statistics + processing_time_stats = { + "processing_time/avg": np.mean(processing_times), + "processing_time/max": np.max(processing_times), + "processing_time/min": np.min(processing_times), + "processing_time/tp50": np.percentile(processing_times, 50), + "processing_time/tp99": np.percentile(processing_times, 99), + "processing_time/tp95": np.percentile(processing_times, 95), + } + tool_calls_stats = {} + if len(tool_calls) > 0: + tool_calls_stats = { + "timing_s/agent_loop/tool_calls/max": np.max(tool_calls), + "timing_s/agent_loop/tool_calls/min": np.min(tool_calls), + "timing_s/agent_loop/tool_calls/mean": np.mean(tool_calls), + } + + # Collect reward_compute_time statistics if available + reward_compute_stats = {} + if "reward_compute_times" in final_batch.non_tensor_batch: + reward_compute_times = final_batch.non_tensor_batch["reward_compute_times"] + if len(reward_compute_times) > 0: + reward_compute_stats = { + "timing_s/reward_compute/max": np.max(reward_compute_times), + "timing_s/reward_compute/min": np.min(reward_compute_times), + "timing_s/reward_compute/mean": np.mean(reward_compute_times), + "timing_s/reward_compute/tp50": np.percentile(reward_compute_times, 50), + "timing_s/reward_compute/tp95": np.percentile(reward_compute_times, 95), + } + processing_time_stats = {f"fully_async/{key}": value for key, value in processing_time_stats.items()} + + param_version_start = final_batch.non_tensor_batch["min_global_steps"] + param_version_end = final_batch.non_tensor_batch["max_global_steps"] + param_version_diff = [abs(a - b) for a, b in zip(param_version_end, param_version_start, strict=False)] + num_diff0 = param_version_diff.count(0) + partial_stats = { + "fully_async/partial/total_partial_num": len(param_version_diff) - num_diff0, + "fully_async/partial/partial_ratio": (len(param_version_diff) - num_diff0) / len(param_version_diff), + "fully_async/partial/max_partial_span": max(param_version_diff), + } + # add meta_info + trajectory_param_versions = final_batch.non_tensor_batch["max_global_steps"] + + final_batch.meta_info.update( + { + "param_version_diversity": len(set(trajectory_param_versions)), + "trajectory_param_versions": trajectory_param_versions, + **processing_time_stats, + **rollout_status, + **partial_stats, + **tool_calls_stats, + **reward_compute_stats, + } + ) + + print(f"[BatchUtils] Batch assembly completed in {time.time() - start_time:.2f}s") + + return final_batch