diff --git a/dapo_predictor/README.md b/dapo_predictor/README.md new file mode 100644 index 00000000..003063e5 --- /dev/null +++ b/dapo_predictor/README.md @@ -0,0 +1,176 @@ +# DAPO Predictor Reorder + +This directory is a portable copy of `recipe/dapo_predictor`. You can copy it into a local `recipe/` tree, for example when adapting the predictor reorder flow around a local verl `0.7.1` environment. + +The feature adds predictor-driven prompt reordering to DAPO. Before rollout generation, prompts are scored by a lightweight predictor head and then reordered with serpentine packing so prompts with similar predicted response length are spread across data-parallel ranks. After actor update, the predictor head is trained from the observed rollout response lengths and is used again in the next step. + +## What It Changes + +- Uses `PredictorAsyncActorRolloutRefWorker` instead of the default actor-rollout worker for FSDP/FSDP2 actor rollout. +- Adds a linear predictor head on the actor worker: `nn.Linear(hidden_size, 1, bias=False)`. +- Scores one sample per prompt group before generation, expands the score to all `rollout.n` samples, and applies `snake_sort_indices`. +- Restores order after DP batch balancing before training the predictor, so labels still correspond to the original prompt groups. +- Trains only the predictor head during `update_predictor`; the actor update still follows the normal DAPO/PPO path. + +Prompt-reorder patch examples were removed from this branch; this package documents predictor-driven reorder only. + +## Entry Points + +- `main_dapo_predictor_reorder.py` + - Main DAPO entrypoint with predictor score + snake-sort reorder enabled. +- `main_dapo_reorder.py` + - Backward-compatible alias to the predictor-driven reorder entrypoint. + +## Implementation Modules + +- `predictor_dapo_trainer.py` + - Injects predictor scoring before rollout generation. + - Builds and applies predictor reorder indices. + - Reverses DP balancing order before predictor training. + - Calls `actor_rollout_wg.update_predictor(prompt_batch, batch)` after actor update. +- `predictor_worker.py` + - Adds `PredictorDataParallelPPOActor` with the linear predictor head. + - Implements `compute_predictor_score` and `update_predictor` worker RPCs. + - Extracts last-token hidden states from the actor model for scoring and training. +- `predictor_utils.py` + - Provides `snake_sort_indices` for prompt-level serpentine DP packing. + +## Runtime Flow + +1. Build `gen_batch` from the training batch and repeat each prompt `rollout.n` times. +2. Hydrate the predictor input with `input_ids`, `attention_mask`, and `position_ids` when needed. +3. Run `compute_predictor_score` on actor workers: + - sample one item from each prompt group, + - extract the last-token hidden state, + - score it with the predictor head, + - broadcast that score back to all samples from the same prompt. +4. Sort prompt groups by predictor score and apply serpentine DP packing through `snake_sort_indices`. +5. Generate rollouts with the reordered batch. +6. Continue normal reward, KL, advantage, critic, and actor update logic. +7. If DP batch balancing changed row order, restore the pre-balance order. +8. Train the predictor head using the latest prompt hidden states and observed response lengths. + +## Predictor Head Training + +The predictor head is trained online after each actor update. The training data comes from the same rollout step: + +- Inputs: prompt-side last-token hidden states extracted from `prompt_batch`. +- Labels: observed generated response lengths from `response_batch.batch["responses"]`. +- Prompt grouping: response lengths are reshaped by `rollout.n`, and the max response length in each prompt group is used as the label. +- Label scaling: response lengths are bucketed by `max(1, rollout.response_length // 40)` to keep label values in a stable range. +- Loss: ListMLE ranking loss, so the head learns the relative ordering of prompts by response length rather than an exact length regression target. +- Optimizer: AdamW over the linear predictor head only. +- Determinism: the predictor dataloader and ListMLE shuffle use `trainer.predictor_reorder.seed`. + +The update path gathers hidden states and labels across distributed ranks. When sequence parallelism is enabled, only SP rank 0 data from each DP group is used to avoid duplicated prompt samples. + +Metrics emitted by the predictor update include: + +- `predictor/epoch_0_loss` +- `predictor/epoch_0_kendall_tau` +- `predictor/epoch_{last}_loss` +- `predictor/epoch_{last}_kendall_tau` +- `predictor/final_loss` +- `predictor/epochs` +- `predictor/update_time_s` +- `predictor/total_samples` + +`scipy` is optional. If it is unavailable, Kendall tau metrics fall back to `0.0` instead of failing the worker. + +## Configuration + +Enable predictor reorder with Hydra overrides under `trainer.predictor_reorder`. The entrypoint mirrors this config to `actor_rollout_ref.predictor_reorder` so the worker can read it. + +Common options: + +| Option | Default | Description | +| ------ | ------- | ----------- | +| `enable` | `False` | Enables predictor scoring, reorder, and predictor head training. | +| `epochs` | `10` | Number of predictor-head training epochs per actor update. | +| `batch_size` | `32` | Batch size for predictor-head training. | +| `lr` | `3e-5` | AdamW learning rate for the predictor head. | +| `weight_decay` | `1e-4` | AdamW weight decay for the predictor head. | +| `seed` | `1` | Local seed used by predictor dataloader/ListMLE shuffling. | +| `predictor_keep_actor_loaded` | `False` | Keeps actor parameters on GPU across actor update when predictor training immediately follows. Useful when offload overhead is high. | + +## Launch Example + +```bash +PYTHONPATH=/workspace/verl python recipe/dapo_predictor/main_dapo_predictor_reorder.py \ + +trainer.predictor_reorder.enable=True \ + +trainer.predictor_reorder.epochs=10 \ + +trainer.predictor_reorder.batch_size=32 \ + +trainer.predictor_reorder.lr=3e-5 \ + +trainer.predictor_reorder.weight_decay=1e-4 \ +``` + +Use the same DAPO data, model, rollout, critic, and trainer overrides as the normal `recipe.dapo` entrypoint. This package only adds predictor reorder-specific overrides. + +## Experimental Setup and Effects + +The PR experiment used a long-response DAPO workload where generation time can become unbalanced across DP ranks: + +| Parameter | Value | +| --------- | ----- | +| Model | Qwen3-30B-A3B-Instruct-2507 | +| DataLoader seed | 1 | +| Global batch size | 32 | +| Samples per prompt | 8 | +| Max num sequences | 16 | +| Generation TP | 4 | +| Sequence parallel | 4, ulysses | +| Max model length | 22528 | +| Prompt length | about 2k | +| Response length | about 20k | +| NPU count | 32 | +| Training steps | 57 | + +### Critic Score + +| Metric | Reorder | Baseline | +| ------ | ------- | -------- | +| Average | 0.6179 | 0.6137 | +| First 10 steps avg | 0.4383 | 0.4391 | +| Last 10 steps avg | 0.6680 | 0.6680 | + +The critic score is essentially unchanged, so predictor reorder did not degrade training quality in this run. + +### Step Time + +| Metric | Reorder | Baseline | +| ------ | ------- | -------- | +| Average | 638.98 s/it | 668.40 s/it | +| First 10 steps avg | 616.14 s/it | 621.21 s/it | +| Last 10 steps avg | 616.23 s/it | 711.55 s/it | + +The reorder run stayed around 616 s/it, while the baseline degraded from about 621 s/it to 711 s/it. The step-time gap grew from 5.08s to 95.33s. + +### Generation Time + +| Metric | Reorder | Baseline | +| ------ | ------- | -------- | +| Average | 471.66s | 504.67s | +| First 10 steps avg | 439.39s | 461.45s | +| Last 5 steps avg | 421.14s | 522.81s | +| Trend | -18.25s | +61.37s | + +Generation time decreased during the reorder run but increased in the baseline. The generation-time advantage grew from about 22s to 101.67s as training progressed. + +### Actor Entropy + +| Metric | Reorder | Baseline | +| ------ | ------- | -------- | +| Average | 0.2664 | 0.2626 | +| First 10 steps avg | 0.2571 | 0.2577 | +| Last 5 steps avg | 0.2619 | 0.2611 | +| Trend | +0.0048 | +0.0034 | + +Actor entropy stayed comparable between the reorder and baseline runs. + +### Summary + +- No quality loss was observed: critic score was unchanged. +- Step time stayed stable with predictor reorder, while baseline step time increased late in training. +- Generation became faster and more stable in the reorder run. +- Actor entropy remained similar, suggesting the reorder did not materially change policy entropy. +- The benefit widened over time, especially for generation latency. diff --git a/dapo_predictor/__init__.py b/dapo_predictor/__init__.py new file mode 100644 index 00000000..3339e5ee --- /dev/null +++ b/dapo_predictor/__init__.py @@ -0,0 +1 @@ +"""Portable copy of recipe/dapo_predictor for local transfer.""" diff --git a/dapo_predictor/main_dapo_predictor_reorder.py b/dapo_predictor/main_dapo_predictor_reorder.py new file mode 100644 index 00000000..1e9d0577 --- /dev/null +++ b/dapo_predictor/main_dapo_predictor_reorder.py @@ -0,0 +1,110 @@ +"""DAPO entrypoint with legacy predictor-driven reorder support.""" + +import os +import socket + +import hydra +import ray +from recipe.dapo.main_dapo import DAPOTaskRunner +from recipe.dapo_predictor.predictor_dapo_trainer import PredictorRayDAPOTrainer +from recipe.dapo_predictor.predictor_worker import PredictorAsyncActorRolloutRefWorker + +from verl.experimental.reward_loop import migrate_legacy_reward_impl +from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler, run_ppo +from verl.trainer.ppo.utils import need_critic, need_reference_policy +from verl.utils.config import validate_config +from verl.utils.device import auto_set_device + + +class PredictorDAPOTaskRunner(DAPOTaskRunner): + def add_actor_rollout_worker(self, config): + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + from verl.single_controller.ray import RayWorkerGroup + from verl.trainer.ppo.ray_trainer import Role + + actor_rollout_cls = PredictorAsyncActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) + self.mapping[Role.ActorRollout] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + return super().add_actor_rollout_worker(config) + + def run(self, config): + from pprint import pprint + + from omegaconf import OmegaConf, open_dict + + from verl.utils import hf_processor, hf_tokenizer + from verl.utils.dataset.rl_dataset import collate_fn + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + trainer_predictor_cfg = OmegaConf.select(config, "trainer.predictor_reorder", default=None) + if trainer_predictor_cfg is not None: + with open_dict(config.actor_rollout_ref): + config.actor_rollout_ref.predictor_reorder = OmegaConf.create( + OmegaConf.to_container(trainer_predictor_cfg, resolve=True) + ) + actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) + self.add_critic_worker(config) + self.add_reward_model_resource_pool(config) + self.add_ref_policy_worker(config, actor_rollout_cls) + + validate_config( + config=config, + use_reference_policy=need_reference_policy(config), + use_critic=need_critic(config), + ) + + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + resource_pool_manager = self.init_resource_pool_mgr(config) + train_dataset = create_rl_dataset( + config.data.train_files, + config.data, + tokenizer, + processor, + is_train=True, + max_samples=config.data.get("train_max_samples", -1), + ) + val_dataset = create_rl_dataset( + config.data.val_files, + config.data, + tokenizer, + processor, + is_train=False, + max_samples=config.data.get("val_max_samples", -1), + ) + train_sampler = create_rl_sampler(config.data, train_dataset) + trainer = PredictorRayDAPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=self.role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + trainer.init_workers() + trainer.fit() + + +@hydra.main(config_path="../dapo/config", config_name="dapo_trainer", version_base=None) +def main(config): + auto_set_device(config) + config = migrate_legacy_reward_impl(config) + run_ppo(config, task_runner_class=ray.remote(num_cpus=1)(PredictorDAPOTaskRunner)) + + +if __name__ == "__main__": + main() diff --git a/dapo_predictor/predictor_dapo_trainer.py b/dapo_predictor/predictor_dapo_trainer.py new file mode 100644 index 00000000..c2dde9c9 --- /dev/null +++ b/dapo_predictor/predictor_dapo_trainer.py @@ -0,0 +1,692 @@ +"""Recipe-side DAPO trainer with predictor-driven rollout reordering.""" + +from __future__ import annotations + +import os +import uuid +from collections import defaultdict +from copy import deepcopy +from pprint import pprint + +import numpy as np +import torch +from recipe.dapo.dapo_ray_trainer import RayDAPOTrainer +from tqdm import tqdm + +from verl import DataProto +from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics +from verl.trainer.ppo.ray_trainer import ( + AdvantageEstimator, + apply_kl_penalty, + compute_advantage, +) +from verl.trainer.ppo.reward import extract_reward +from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi +from verl.utils.metric import reduce_metrics +from verl.utils.profiler import marked_timer +from verl.utils.rollout_skip import RolloutSkip + +from .predictor_utils import snake_sort_indices + + +class PredictorRayDAPOTrainer(RayDAPOTrainer): + """DAPO trainer that only injects the predictor-specific reorder steps. + + Most heavy lifting is still reused from the current `RayDAPOTrainer` / `RayPPOTrainer` stack: + reward computation, KL/ref/value computation, actor/critic updates, checkpointing, metrics, + and rollout manager orchestration all stay on the upstream path. + """ + + def _predictor_cfg(self): + return self.config.trainer.get("predictor_reorder", {}) + + def _predictor_enabled(self) -> bool: + return self._predictor_cfg().get("enable", False) + + def _build_predictor_order(self, gen_batch: DataProto) -> torch.Tensor: + """Compute predictor scores and build snake-sort reorder indices.""" + predictor_scores = self.actor_rollout_wg.compute_predictor_score(gen_batch) + gen_batch = gen_batch.union(predictor_scores) + dp_world_size = self._get_dp_size(self.actor_rollout_wg, "actor") + return torch.tensor( + snake_sort_indices( + gen_batch.batch["predictor_scores"].tolist(), + n_samples_per_prompt=self.config.actor_rollout_ref.rollout.n, + dp_world_size=dp_world_size, + ), + dtype=torch.long, + ) + + def _apply_predictor_order(self, batch: DataProto, predictor_order: torch.Tensor | None) -> DataProto: + """Apply predictor-derived reorder indices to a DataProto batch.""" + if predictor_order is not None: + if batch.batch is not None: + batch.reorder(predictor_order) + else: + indices_np = predictor_order.detach().cpu().numpy() + batch.non_tensor_batch = {k: v[indices_np] for k, v in batch.non_tensor_batch.items()} + return batch + + def _repeat_and_tag_uid(self, batch: DataProto) -> DataProto: + """Tag each row with a UID and repeat the batch n times per prompt.""" + batch.non_tensor_batch["uid"] = np.array([str(i) for i in range(len(batch.batch))], dtype=object) + return batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + + @staticmethod + def _ensure_gen_batch_has_tensors(gen_batch: DataProto, source_batch: DataProto) -> DataProto: + if gen_batch.batch is None and source_batch.batch is not None: + gen_batch.batch = source_batch.batch + return gen_batch + + def _hydrate_gen_batch_model_inputs(self, gen_batch: DataProto) -> DataProto: + """Ensure gen_batch has input_ids, attention_mask, and position_ids tensors. + + Tokenizes from raw prompts/messages if the model inputs are missing. + """ + import uuid + + from tensordict import TensorDict + + from verl.workers.rollout.schemas import ( + AsyncRolloutRequest, + AsyncRolloutRequestStateEnum, + TokenizationSanityCheckModeEnum, + ) + + if gen_batch.batch is None: + gen_batch.batch = {} + + batch_keys = set(gen_batch.batch.keys()) + if {"input_ids", "attention_mask", "position_ids"}.issubset(batch_keys): + return gen_batch + + if "input_ids" not in batch_keys and "prompts" in batch_keys: + gen_batch.batch["input_ids"] = gen_batch.batch["prompts"] + batch_keys.add("input_ids") + + if "input_ids" not in batch_keys: + seqs = None + messages = None + if "raw_prompt_ids" in gen_batch.non_tensor_batch: + raw_prompt_ids = gen_batch.non_tensor_batch["raw_prompt_ids"] + seqs = raw_prompt_ids.tolist() if isinstance(raw_prompt_ids, np.ndarray) else raw_prompt_ids + if "messages" in gen_batch.non_tensor_batch: + messages = gen_batch.non_tensor_batch["messages"] + elif "raw_prompt" in gen_batch.non_tensor_batch: + messages = gen_batch.non_tensor_batch["raw_prompt"] + if messages is not None and len(messages) == 0: + messages = None + if messages is not None: + multi_modal_batch = gen_batch.non_tensor_batch.get("multi_modal_data", None) + tool_schema_batch = gen_batch.non_tensor_batch.get("tool_schemas", None) + input_id_list = [] + attn_mask_list = [] + pos_id_list = [] + max_prompt_len = int(self.config.data.get("max_prompt_length", 32768)) + max_response_len = int(self.config.data.get("max_response_length", 8192)) + max_model_len = int(self.config.actor_rollout_ref.rollout.get("max_model_len") or 32768) + for i, msg in enumerate(messages): + multi_modal_data = {"image": [], "video": []} + if multi_modal_batch is not None: + mm_val = ( + multi_modal_batch[i] if isinstance(multi_modal_batch, np.ndarray) else multi_modal_batch + ) + if isinstance(mm_val, dict): + multi_modal_data.update(mm_val) + tools = None + if tool_schema_batch is not None: + tool_schema_val = ( + tool_schema_batch[i] if isinstance(tool_schema_batch, np.ndarray) else tool_schema_batch + ) + if tool_schema_val: + tools = [ + tool.model_dump() if hasattr(tool, "model_dump") else tool for tool in tool_schema_val + ] + request = AsyncRolloutRequest.model_validate( + { + "request_id": str(uuid.uuid4()), + "state": AsyncRolloutRequestStateEnum.PENDING, + "messages": msg, + "multi_modal_data": multi_modal_data, + "tool_schemas": tools, + "reward_scores": {}, + "max_prompt_len": max_prompt_len, + "max_response_len": max_response_len, + "max_model_len": max_model_len, + "use_inference_chat_template": False, + "tokenization_sanity_check_mode": TokenizationSanityCheckModeEnum.DISABLE, + "processing_class": self.tokenizer, + } + ) + input_ids = request.input_ids.squeeze(0) + attention_mask = request.attention_mask.squeeze(0) + position_ids = request.position_ids + if position_ids.dim() == 2 and position_ids.shape[0] == 1: + position_ids = position_ids.squeeze(0) + input_id_list.append(input_ids) + attn_mask_list.append(attention_mask) + pos_id_list.append(position_ids) + + if input_id_list: + max_len = max(x.shape[-1] for x in input_id_list) + pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 + input_ids = torch.full((len(input_id_list), max_len), fill_value=pad_token_id, dtype=torch.long) + attention_mask = torch.zeros((len(attn_mask_list), max_len), dtype=torch.long) + is_3d_pos = pos_id_list[0].dim() == 2 + if is_3d_pos: + pos_channels = pos_id_list[0].shape[0] + position_ids = torch.zeros((len(pos_id_list), pos_channels, max_len), dtype=torch.long) + else: + position_ids = torch.zeros((len(pos_id_list), max_len), dtype=torch.long) + + for i, (iid, am, pid) in enumerate(zip(input_id_list, attn_mask_list, pos_id_list, strict=True)): + input_ids[i, : iid.shape[-1]] = iid + attention_mask[i, : am.shape[-1]] = am + if is_3d_pos: + position_ids[i, :, : pid.shape[-1]] = pid + else: + position_ids[i, : pid.shape[-1]] = pid + + gen_batch.batch = TensorDict( + source={ + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + }, + batch_size=(len(input_id_list),), + ) + batch_keys.update({"input_ids", "attention_mask", "position_ids"}) + + if seqs is not None: + pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 + max_len = max((len(s) for s in seqs), default=0) + input_ids = torch.full((len(seqs), max_len), fill_value=pad_token_id, dtype=torch.long) + for i, seq in enumerate(seqs): + if len(seq) > 0: + input_ids[i, : len(seq)] = torch.as_tensor(seq, dtype=torch.long) + gen_batch.batch["input_ids"] = input_ids + batch_keys.add("input_ids") + + if "attention_mask" not in batch_keys and "input_ids" in batch_keys: + pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 + gen_batch.batch["attention_mask"] = (gen_batch.batch["input_ids"] != pad_token_id).long() + batch_keys.add("attention_mask") + + if "position_ids" not in batch_keys and "attention_mask" in batch_keys: + gen_batch.batch["position_ids"] = ( + (torch.cumsum(gen_batch.batch["attention_mask"], dim=-1) - 1).clamp_min(0).long() + ) + + return gen_batch + + @staticmethod + def _build_reverse_idx_from_uid(before_uid: np.ndarray, after_uid: np.ndarray) -> torch.Tensor: + """Build a reverse index mapping original positions to post-reorder positions. + + Used to restore data order after DP balancing. Tracks duplicate UIDs + by counting occurrences so that repeated prompts can be matched correctly. + """ + before_counts = defaultdict(int) + before_slots: dict[tuple[str, int], int] = {} + for idx, uid in enumerate(before_uid.tolist()): + key = (uid, before_counts[uid]) + before_slots[key] = idx + before_counts[uid] += 1 + + after_counts = defaultdict(int) + orig_pos_of_after = [] + for uid in after_uid.tolist(): + key = (uid, after_counts[uid]) + if key not in before_slots: + raise ValueError(f"Cannot restore predictor order: missing uid key {key}") + orig_pos_of_after.append(before_slots[key]) + after_counts[uid] += 1 + + reverse_idx = torch.empty(len(orig_pos_of_after), dtype=torch.long) + for after_pos, orig_pos in enumerate(orig_pos_of_after): + reverse_idx[orig_pos] = after_pos + return reverse_idx + + @staticmethod + def _prepare_predictor_gen_batch(source_batch: DataProto) -> DataProto: + """Pop a lightweight gen batch containing only predictor-required keys.""" + batch_keys = [] + if source_batch.batch is not None: + preferred_batch_keys = ["input_ids", "attention_mask", "position_ids", "prompts"] + source_batch_keys = set(source_batch.batch.keys()) + batch_keys = [k for k in preferred_batch_keys if k in source_batch_keys] + if not batch_keys: + batch_keys = list(source_batch_keys) + + preferred_non_tensor_keys = ["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"] + non_tensor_batch_keys = [k for k in preferred_non_tensor_keys if k in source_batch.non_tensor_batch] + + gen_batch = source_batch.pop(batch_keys=batch_keys, non_tensor_batch_keys=non_tensor_batch_keys) + return gen_batch + + @staticmethod + def _extract_restore_keys(batch: DataProto) -> np.ndarray: + """Extract unique identifier keys from non_tensor_batch for order restoration.""" + if "uid" in batch.non_tensor_batch: + return np.asarray(batch.non_tensor_batch["uid"], dtype=object) + if "extra_info" in batch.non_tensor_batch: + extra_info = batch.non_tensor_batch["extra_info"] + keys = [] + for item in extra_info: + if isinstance(item, dict) and "index" in item: + keys.append(item["index"]) + else: + keys.append(str(item)) + return np.asarray(keys, dtype=object) + raise ValueError("Cannot restore order: neither `uid` nor `extra_info.index` found in non_tensor_batch") + + def _maybe_update_predictor(self, gen_batch: DataProto, batch: DataProto, metrics, timing_raw): + with marked_timer("update_predictor", timing_raw, "orange"): + prompt_length = batch.batch["prompts"].shape[-1] + prompt_input_ids = batch.batch["prompts"] + prompt_attention_mask = batch.batch["attention_mask"][:, :prompt_length] + prompt_position_ids = batch.batch["position_ids"][:, :prompt_length] + + prompt_batch = DataProto.from_dict( + { + "input_ids": prompt_input_ids, + "attention_mask": prompt_attention_mask, + "position_ids": prompt_position_ids, + }, + meta_info=batch.meta_info, + ) + + predictor_output = self.actor_rollout_wg.update_predictor(prompt_batch, batch) + metrics.update(reduce_metrics(predictor_output.meta_info.get("metrics", {}))) + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + if not self._predictor_enabled(): + return super().fit() + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + self.gen_steps = 0 + self.max_steps_duration = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + self.checkpoint_manager.update_weights() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.async_rollout_manager) + rollout_skip.wrap_generate_sequences() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + self.gen_steps += 1 + last_val_metrics = None + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + timing_raw = defaultdict(float) + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + current_epoch = self.global_steps // len(self.train_dataloader) + + for epoch in range(current_epoch, self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False) + metrics = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + + new_batch: DataProto = DataProto.from_single_dict(batch_dict) + new_batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature + num_gen_batches += 1 + # print(f"new_batch{new_batch}") + gen_batch = self._get_gen_batch(new_batch) + # print(f"gen_batch{gen_batch}") + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + with marked_timer("predictor_score", timing_raw, "purple"): + with marked_timer("predictor_hydrate", timing_raw, "purple"): + predictor_input_batch = gen_batch_output.select(deepcopy=True) + predictor_input_batch = self._hydrate_gen_batch_model_inputs(predictor_input_batch) + predictor_order = self._build_predictor_order(predictor_input_batch) + self._apply_predictor_order(gen_batch_output, predictor_order) + # print(f'predictor_scores{predictor_scores}') + # generate a batch + with marked_timer("gen", timing_raw, "red"): + # print(f'gen_batch_output{gen_batch_output}') + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with marked_timer("gen_max", timing_raw, "red"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + + new_batch = new_batch.union(gen_baseline_output) + # compute reward model score on new_batch + rm_scores = None + if self.use_rm and "rm_scores" not in new_batch.batch.keys(): + rm_scores = self._compute_reward_colocate(new_batch) + new_batch = new_batch.union(rm_scores) + reward_baseline_tensor, _ = extract_reward(new_batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + keys_to_pop = set(gen_baseline_output.batch.keys()) + if rm_scores is not None: + keys_to_pop.update(rm_scores.batch.keys()) + new_batch.pop(batch_keys=list(keys_to_pop)) + + new_batch.batch["reward_baselines"] = reward_baseline_tensor + + del rm_scores, gen_baseline_batch, gen_baseline_output + + new_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object + ) + # repeat to align with repeated responses in rollout + new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + new_batch = new_batch.union(gen_batch_output) + + if self.config.algorithm.use_kl_in_reward: + # We need these metrics for apply_kl_penalty if using kl in reward + new_batch = self.compute_kl_related_metrics(new_batch, metrics, timing_raw) + # otherwise, we will compute those after dynamic sampling + + with marked_timer("reward", timing_raw, "yellow"): + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. + if self.use_rm and "rm_scores" not in new_batch.batch.keys(): + # we first compute reward model score + batch_reward = self._compute_reward_colocate(new_batch) + new_batch = new_batch.union(batch_reward) + + # we combine with rule-based rm + reward_tensor, reward_extra_infos_dict = extract_reward(new_batch) + + new_batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + new_batch.non_tensor_batch.update( + {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + ) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + new_batch, kl_metrics = apply_kl_penalty( + new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update( + kl_metrics + ) # TODO: This will be cleared if we use multiple genenration batches + else: + new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] + + if not self.config.algorithm.filter_groups.enable: + batch = new_batch + else: # NOTE: When prompts after filtering is less than train batch size, + # we skip to the next generation batch + metric_name = self.config.algorithm.filter_groups.metric + if metric_name == "seq_final_reward": + # Turn to numpy for easier filtering + new_batch.non_tensor_batch["seq_final_reward"] = ( + new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() + ) + elif metric_name == "seq_reward": + new_batch.non_tensor_batch["seq_reward"] = ( + new_batch.batch["token_level_scores"].sum(dim=-1).numpy() + ) + + # Collect the sequence reward for each trajectory + prompt_uid2metric_vals = defaultdict(list) + for uid, metric_val in zip( + new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True + ): + prompt_uid2metric_vals[uid].append(metric_val) + + prompt_uid2metric_std = {} + for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): + prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) + + kept_prompt_uids = [ + uid + for uid, std in prompt_uid2metric_std.items() + if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 + ] + num_prompt_in_batch += len(kept_prompt_uids) + + kept_traj_idxs = [] + for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]): + if traj_from_prompt_uid in kept_prompt_uids: + kept_traj_idxs.append(idx) + + new_batch = new_batch[kept_traj_idxs] + batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) + + prompt_bsz = self.config.data.train_batch_size + if num_prompt_in_batch < prompt_bsz: + print(f"{num_prompt_in_batch=} < {prompt_bsz=}") + max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: + print(f"{num_gen_batches=}. Keep generating...") + self.gen_steps += 1 + is_last_step = self.global_steps >= self.total_training_steps + continue + else: + raise ValueError( + f"{num_gen_batches=} >= {max_num_gen_batches=}." + + " Generated too many. Please check if your data are too difficult." + + " You could also try set max_num_gen_batches=0 to enable endless trials." + ) + else: + # Align the batch + traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + batch = batch[:traj_bsz] + + self.checkpoint_manager.sleep_replicas() + + # === Updating === + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + # print(f'self.config.trainer.balance_batch{self.config.trainer.balance_batch}') + # if self.config.trainer.balance_batch: + # self._balance_batch(batch, metrics=metrics) + + if self.config.trainer.balance_batch: + uid_before_balance = batch.non_tensor_batch["uid"].copy() + self._balance_batch(batch, metrics=metrics) + reverse_idx = self._build_reverse_idx_from_uid( + uid_before_balance, + batch.non_tensor_batch["uid"], + ) + else: + reverse_idx = None + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + if not self.config.algorithm.use_kl_in_reward: + batch = self.compute_kl_related_metrics(batch, metrics, timing_raw) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, "cyan"): + values = self._compute_values(batch) + batch = batch.union(values) + + # Compute rollout correction weights and off-policy metrics (inherited from RayPPOTrainer) + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + + with marked_timer("adv", timing_raw, "brown"): + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, "pink"): + critic_output = self._update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, "red"): + actor_output = self._update_actor(batch) + + # Update predictor after actor update + if reverse_idx is not None: + batch.reorder(reverse_idx) + self._maybe_update_predictor(gen_batch, batch, metrics, timing_raw) + + # Check if ESI/training plan is close to expiration + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + if self.config.trainer.save_freq > 0 and ( + is_last_step + or self.global_steps % self.config.trainer.save_freq == 0 + or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, "green"): + self._save_checkpoint() + + with marked_timer("update_weights", timing_raw, "red"): + self.checkpoint_manager.update_weights() + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if self.config.trainer.test_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.test_freq == 0 + ): + with marked_timer("testing", timing_raw, "green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + steps_duration = timing_raw.get("step", 0) + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + timing_raw = defaultdict(float) # clear timing + + metrics["train/num_gen_batches"] = num_gen_batches + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + if is_last_step: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True) + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + progress_bar.update(1) + self.global_steps += 1 + self.gen_steps += 1 + # check if last step checkpint exists + checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}") + if not os.path.exists(checkpoint_dir): + # save last step checkpoint + timing_raw = defaultdict(float) + with marked_timer("save_checkpoint", timing_raw, "green"): + self._save_checkpoint() + metrics = {f"timing/{k}": v for k, v in timing_raw.items()} + logger.log(data=metrics, step=self.global_steps) diff --git a/dapo_predictor/predictor_utils.py b/dapo_predictor/predictor_utils.py new file mode 100644 index 00000000..a83a3834 --- /dev/null +++ b/dapo_predictor/predictor_utils.py @@ -0,0 +1,36 @@ +"""Utilities for predictor-driven prompt reordering in DAPO.""" + +from __future__ import annotations + + +def snake_sort_indices(scores: list[float], n_samples_per_prompt: int, dp_world_size: int) -> list[int]: + """Return sample indices after prompt-level score sort with serpentine DP packing. + + The input scores are expected to be repeated per prompt group. We first collapse them to + prompt-level scores, sort prompts descending by score, assign prompt groups to DP ranks in + snake/serpentine order, then expand back to sample indices. + """ + if n_samples_per_prompt <= 0: + raise ValueError("n_samples_per_prompt must be positive") + if dp_world_size <= 0: + raise ValueError("dp_world_size must be positive") + if len(scores) % n_samples_per_prompt != 0: + raise ValueError("scores length must be divisible by n_samples_per_prompt") + + num_prompts = len(scores) // n_samples_per_prompt + prompt_scores = [scores[i * n_samples_per_prompt] for i in range(num_prompts)] + sorted_prompt_indices = sorted(range(num_prompts), key=lambda idx: prompt_scores[idx], reverse=True) + + dp_buckets: list[list[int]] = [[] for _ in range(dp_world_size)] + for sorted_pos, prompt_idx in enumerate(sorted_prompt_indices): + block = sorted_pos // dp_world_size + offset = sorted_pos % dp_world_size + dp_rank = offset if block % 2 == 0 else dp_world_size - 1 - offset + dp_buckets[dp_rank].append(prompt_idx) + + ordered_prompt_indices = [prompt_idx for bucket in dp_buckets for prompt_idx in bucket] + sample_indices: list[int] = [] + for prompt_idx in ordered_prompt_indices: + start = prompt_idx * n_samples_per_prompt + sample_indices.extend(range(start, start + n_samples_per_prompt)) + return sample_indices diff --git a/dapo_predictor/predictor_worker.py b/dapo_predictor/predictor_worker.py new file mode 100644 index 00000000..da46d3ef --- /dev/null +++ b/dapo_predictor/predictor_worker.py @@ -0,0 +1,469 @@ +"""Recipe-side worker extensions for predictor-driven prompt reordering.""" + +from __future__ import annotations + +import numpy as np +import torch +from codetiming import Timer +from omegaconf import OmegaConf +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +from verl import DataProto +from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register +from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input +from verl.utils.device import get_device_id +from verl.utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu +from verl.utils.profiler import DistProfiler +from verl.utils.seqlen_balancing import prepare_dynamic_batch +from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs +from verl.workers.actor.dp_actor import DataParallelPPOActor +from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker + + +class PredictorDataParallelPPOActor(DataParallelPPOActor): + """PPO actor with an attached linear predictor scorer for prompt reordering. + + The predictor scorer maps last-token hidden states to a scalar score. + It is initialized from rank 0 and broadcast across all processes. + """ + + def __init__(self, config, actor_module, actor_optimizer=None): + super().__init__(config=config, actor_module=actor_module, actor_optimizer=actor_optimizer) + hidden_size = getattr(getattr(actor_module, "config", None), "hidden_size", None) + if hidden_size is None and hasattr(actor_module, "module"): + hidden_size = getattr(getattr(actor_module.module, "config", None), "hidden_size", None) + hidden_size = hidden_size or 4096 + self.predictor_scorer = nn.Linear(hidden_size, 1, bias=False).to(next(actor_module.parameters()).device) + if torch.distributed.is_initialized(): + for param in self.predictor_scorer.parameters(): + torch.distributed.broadcast(param.data, src=0) + + def extract_hidden_states(self, data: DataProto): + """Extract last-token hidden states from the actor model for predictor scoring.""" + self.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info["temperature"] + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + + select_keys = ["input_ids", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, batch_idx_list = prepare_dynamic_batch( + data, + max_token_len=max_token_len, + dp_group=torch.distributed.group.WORLD, + ) + else: + micro_batches = data.split(micro_batch_size) + + hidden_states_list = [] + for micro_batch in micro_batches: + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + with torch.no_grad(): + hidden_states = self._forward_predictor_micro_batch(model_inputs, temperature) + # hidden_states = self. _forward_micro_batch(model_inputs, temperature) + hidden_states_list.append(hidden_states) + + # Concatenate hidden states from all micro batches + all_hidden_states = torch.concat(hidden_states_list, dim=0) # [total_batch_size, hidden_size] + + return all_hidden_states + + def _forward_predictor_micro_batch(self, micro_batch, temperature): + """Process a single micro batch following the full dp_actor forward pass.""" + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + + # Handle multi-modal inputs + multi_modal_inputs = {} + if "multi_modal_inputs" in micro_batch.keys(): + if "image_bound" in micro_batch["multi_modal_inputs"][0]: # minicpm-o logic + for key in micro_batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]] + else: + for key in micro_batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat( + [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 + ) + + # Handle position_ids dimensionality for Qwen2VL mrope + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) + + if self.use_remove_padding: + # Apply remove-padding optimization + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad position_ids + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + # Unpad multi-modal inputs + if "image_bound" in multi_modal_inputs: + from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo + + multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( + input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs + ) + + # Ulysses sequence parallel processing + if self.use_ulysses_sp: + is_vlm_model = "multi_modal_inputs" in micro_batch.keys() + if is_vlm_model: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + else: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + output_hidden_states=True, + use_cache=False, + **extra_args, + ) + full_hidden_states = output.hidden_states[-1] + + if hasattr(output, "hidden_states") and output.hidden_states is not None: + last_hidden_states = output.hidden_states[-1].squeeze(0) # (total_nnz, hidden_size) + + if self.use_ulysses_sp: + last_hidden_states = gather_outputs_and_unpad( + last_hidden_states, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + + full_hidden_states = pad_input( + hidden_states=last_hidden_states.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + full_hidden_states = full_hidden_states.squeeze(-1) # [batch_size, seq_len, hidden_size] + else: + output = self.actor_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + output_hidden_states=True, + use_cache=False, + ) + full_hidden_states = output.hidden_states[-1] + if hasattr(output, "hidden_states") and output.hidden_states is not None: + full_hidden_states = output.hidden_states[-1] # [batch_size, seq_len, hidden_size] + # extract the hidden states of last token + eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (batch_size,) + last_token_hidden = full_hidden_states[torch.arange(batch_size), eos_mask_idx] + + return last_token_hidden + + @staticmethod + def listmle_loss( + y_pred: torch.Tensor, + y_true: torch.Tensor, + eps: float = 1e-10, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + """Compute ListMLE loss with random feature shuffling for regularization. + + Shuffles hidden dimensions, sorts by true labels, then computes the + ListMLE loss to train the predictor scorer to predict response length. + """ + random_indices = torch.randperm(y_pred.shape[-1], generator=generator).to(y_pred.device) + y_pred_shuffled = y_pred[:, random_indices].float() + y_true_shuffled = y_true[:, random_indices].float() + _, indices = y_true_shuffled.sort(descending=True, dim=-1) + preds_sorted_by_true = torch.gather(y_pred_shuffled, dim=1, index=indices) + max_pred_values, _ = preds_sorted_by_true.max(dim=1, keepdim=True) + preds_sorted_by_true_minus_max = preds_sorted_by_true - max_pred_values + cumsums = torch.cumsum(preds_sorted_by_true_minus_max.exp().flip(dims=[1]), dim=1).flip(dims=[1]) + observation_loss = torch.log(cumsums + eps) - preds_sorted_by_true_minus_max + return torch.mean(torch.sum(observation_loss, dim=1)) + + +class PredictorAsyncActorRolloutRefWorker(AsyncActorRolloutRefWorker): + """Async worker that adds predictor score computation and update RPCs. + + Extends the standard worker with three new entry points: + - compute_predictor_score: forward pass to score prompts + - update_predictor: train the predictor with ListMLE loss + - update_actor: overridden to keep actor loaded for predictor when needed + """ + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + super().init_model() + self._pending_offload_param_restore = None + if self._is_actor: + self.actor = PredictorDataParallelPPOActor( + config=self.actor.config, + actor_module=self.actor_module_fsdp, + actor_optimizer=self.actor_optimizer, + ) + + def _predictor_cfg(self): + """Resolve predictor config from worker or trainer level.""" + # Worker-side config is usually rooted at `actor_rollout_ref`, so `trainer.*` + # is not always available here. + cfg = OmegaConf.select(self.config, "predictor_reorder", default=None) + if cfg is None: + cfg = OmegaConf.select(self.config, "trainer.predictor_reorder", default=None) + return cfg or {} + + def _actor_params_are_offloaded(self) -> bool: + """Check whether FSDP actor parameters are currently on CPU.""" + return next(self.actor_module_fsdp.parameters()).device.type == "cpu" + + def _sync_predictor_scorer_device(self): + """Ensure predictor scorer lives on the same device as the actor.""" + actor_device = next(self.actor_module_fsdp.parameters()).device + if next(self.actor.predictor_scorer.parameters()).device != actor_device: + self.actor.predictor_scorer = self.actor.predictor_scorer.to(actor_device) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="red", role="actor_update") + def update_actor(self, data: DataProto): + """Override to keep actor loaded on GPU when predictor needs it afterward.""" + cfg = self._predictor_cfg() + keep_actor_loaded = bool(cfg.get("predictor_keep_actor_loaded", False)) + + if not (keep_actor_loaded and self._is_offload_param): + return super().update_actor(data) + load_fsdp_model_to_gpu(self.actor_module_fsdp) + self._sync_predictor_scorer_device() + + original_is_offload_param = self._is_offload_param + self._is_offload_param = False + self._pending_offload_param_restore = original_is_offload_param + + return super().update_actor(data) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="purple", role="predictor_compute_score") + def compute_predictor_score(self, data: DataProto): + """Score each prompt by running predictor scorer on sampled hidden states. + + Takes one sample per prompt group (stride=n), extracts hidden states, + scores them, and broadcasts the score back to all samples of that prompt. + """ + assert self._is_actor + loaded_actor_for_predictor = self._is_offload_param and self._actor_params_are_offloaded() + if loaded_actor_for_predictor: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + self._sync_predictor_scorer_device() + + data = data.to(get_device_id()) + data.meta_info["micro_batch_size"] = self.config.ref.log_prob_micro_batch_size_per_gpu + data.meta_info["temperature"] = self.config.rollout.temperature + data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz + + n = self.config.rollout.n + batch_size = data.batch["input_ids"].shape[0] + sample_indices = list(range(0, batch_size, n)) + sampled_non_tensors = {} + for key, val in data.non_tensor_batch.items(): + sampled_non_tensors[key] = val[sample_indices] if isinstance(val, np.ndarray) else val + # Create sampled data proto for one sample per prompt group + sampled_data = DataProto.from_dict( + { + "input_ids": data.batch["input_ids"][sample_indices], + "attention_mask": data.batch["attention_mask"][sample_indices], + "position_ids": data.batch["position_ids"][sample_indices], + }, + non_tensors=sampled_non_tensors, + ) + sampled_data.meta_info = data.meta_info.copy() + with self.ulysses_sharding_manager: + sampled_hidden_states = self.actor.extract_hidden_states(data=sampled_data) + + scores = self.actor.predictor_scorer(sampled_hidden_states).squeeze(-1) + + predictor_scores = torch.zeros(batch_size, device=scores.device, dtype=scores.dtype) + for i, sample_idx in enumerate(sample_indices): + predictor_scores[sample_idx : min(sample_idx + n, batch_size)] = scores[i] + + output = DataProto.from_dict(tensors={"predictor_scores": predictor_scores}).to("cpu") + if self._pending_offload_param_restore is not None: + self._is_offload_param = self._pending_offload_param_restore + self._pending_offload_param_restore = None + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="orange", role="predictor_update") + def update_predictor(self, prompt_batch: DataProto, response_batch: DataProto): + """Train the predictor scorer to predict response length via ListMLE loss. + + Gathers hidden states and response lengths across all GPUs, + trains the predictor for `epochs` steps, and returns metrics. + """ + assert self._is_actor + cfg = self._predictor_cfg() + if not cfg.get("enable", False): + return DataProto(meta_info={"metrics": {}}) + + loaded_actor_for_predictor = self._is_offload_param and self._actor_params_are_offloaded() + if loaded_actor_for_predictor: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + self._sync_predictor_scorer_device() + + prompt_batch = prompt_batch.to(get_device_id()) + prompt_batch.meta_info["micro_batch_size"] = self.config.ref.log_prob_micro_batch_size_per_gpu + prompt_batch.meta_info["temperature"] = self.config.rollout.temperature + prompt_batch.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu + prompt_batch.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz + + n = self.config.rollout.n + batch_size = prompt_batch.batch["input_ids"].shape[0] + sp_size = self.config.actor.ulysses_sequence_parallel_size + sample_indices = list(range(0, batch_size, n)) + sampled_non_tensors = {} + for key, val in prompt_batch.non_tensor_batch.items(): + sampled_non_tensors[key] = val[sample_indices] if isinstance(val, np.ndarray) else val + sampled_prompt = DataProto( + batch=prompt_batch.batch[sample_indices], + non_tensor_batch=sampled_non_tensors, + meta_info=prompt_batch.meta_info.copy(), + ) + + with self.ulysses_sharding_manager: + hidden_states = self.actor.extract_hidden_states(sampled_prompt) + + response_batch = response_batch.to(get_device_id()) + pad_token_id = ( + self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + ) + response_lengths = (response_batch.batch["responses"] != pad_token_id).sum(dim=1) + response_lengths = response_lengths.view(-1, n).max(dim=1).values.float() + + if torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + + gathered_hidden = [torch.empty_like(hidden_states) for _ in range(world_size)] + gathered_lengths = [torch.empty_like(response_lengths) for _ in range(world_size)] + torch.distributed.all_gather(gathered_hidden, hidden_states) + torch.distributed.all_gather(gathered_lengths, response_lengths) + hidden_states = torch.cat(gathered_hidden, dim=0) + response_lengths = torch.cat(gathered_lengths, dim=0) + + if sp_size > 1: + rank_data_len = len(response_lengths) // world_size + dp_world_size = world_size // sp_size + + # Reshape and extract SP rank 0 data + reshaped_response = response_lengths.view(dp_world_size, sp_size, rank_data_len) + reshaped_hidden = hidden_states.view(dp_world_size, sp_size, rank_data_len, -1) + + response_lengths = reshaped_response[:, 0, :].flatten() + hidden_states = reshaped_hidden[:, 0, :, :].flatten(0, 1) + + label_group_size = max(1, self.config.rollout.response_length // 40) + response_lengths = response_lengths // label_group_size + + predictor = self.actor.predictor_scorer.float() + optimizer = torch.optim.AdamW( + predictor.parameters(), + lr=cfg.get("lr", 3e-5), + weight_decay=cfg.get("weight_decay", 1e-4), + ) + dataset = TensorDataset(hidden_states.float(), response_lengths.float()) + predictor_seed = int(cfg.get("seed", 1)) + dataloader_generator = torch.Generator() + dataloader_generator.manual_seed(predictor_seed) + dataloader = DataLoader( + dataset, + batch_size=cfg.get("batch_size", 32), + shuffle=True, + drop_last=False, + generator=dataloader_generator, + ) + listmle_generator = torch.Generator() + listmle_generator.manual_seed(predictor_seed) + epochs = cfg.get("epochs", 10) + + kendalltau = None + try: + from scipy.stats import kendalltau + except ImportError: + pass + + metrics = {} + with Timer(name="predictor_update", logger=None) as timer: + for epoch in range(epochs): + epoch_loss = 0.0 + num_batches = 0 + epoch_kendall_taus = [] + for batch_hidden, batch_lengths in dataloader: + # batch_hidden = batch_hidden.float().requires_grad_(True) + # batch_lengths = batch_lengths.float() + preds = predictor(batch_hidden).squeeze(-1).unsqueeze(0) + labels = batch_lengths.unsqueeze(0) + loss = self.actor.listmle_loss(preds, labels, generator=listmle_generator) + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(predictor.parameters(), max_norm=1.0) + optimizer.step() + epoch_loss += loss.item() + num_batches += 1 + + if epoch == 0 or epoch == epochs - 1: + pred_numpy = preds.squeeze().detach().cpu().float().numpy() + true_numpy = labels.squeeze().detach().cpu().float().numpy() + + if len(pred_numpy) > 1 and kendalltau is not None: + kendall_tau, _ = kendalltau(pred_numpy, true_numpy) + if not np.isnan(kendall_tau): + epoch_kendall_taus.append(kendall_tau) + + if epoch == 0 or epoch == epochs - 1: + avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else 0.0 + avg_kendall_tau = np.mean(epoch_kendall_taus) if epoch_kendall_taus else 0.0 + metrics[f"predictor/epoch_{epoch}_loss"] = avg_epoch_loss + metrics[f"predictor/epoch_{epoch}_kendall_tau"] = avg_kendall_tau + + metrics["predictor/final_loss"] = epoch_loss / max(num_batches, 1) + metrics["predictor/epochs"] = epochs + metrics["predictor/update_time_s"] = timer.last + metrics["predictor/total_samples"] = len(dataset) + + output = DataProto(meta_info={"metrics": metrics}).to("cpu") + if loaded_actor_for_predictor: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + return output diff --git a/dapo_predictor/tests/test_review_regressions.py b/dapo_predictor/tests/test_review_regressions.py new file mode 100644 index 00000000..c1398c1c --- /dev/null +++ b/dapo_predictor/tests/test_review_regressions.py @@ -0,0 +1,63 @@ +import ast +import unittest +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +WORKER = REPO_ROOT / "dapo_predictor" / "predictor_worker.py" + + +def _source(path: Path) -> str: + return path.read_text(encoding="utf-8") + + +def _function_source(path: Path, name: str) -> str: + source = _source(path) + module = ast.parse(source) + for node in ast.walk(module): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == name: + return ast.get_source_segment(source, node) + raise AssertionError(f"{name} not found in {path}") + + +class PredictorReviewRegressionTests(unittest.TestCase): + def test_predictor_worker_training_guards_small_response_length_and_uses_stable_shuffle(self): + update_source = _function_source(WORKER, "update_predictor") + + self.assertIn("max(1,", update_source) + self.assertIn("torch.Generator()", update_source) + self.assertIn("generator=", update_source) + self.assertIn("listmle_generator", update_source) + self.assertIn('cfg.get("epochs", 10)', update_source) + + def test_predictor_worker_avoids_global_seed_and_debug_prints(self): + source = _source(WORKER) + tree = ast.parse(source) + + manual_seed_calls = [ + node + for node in ast.walk(tree) + if isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and node.func.attr == "manual_seed" + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "torch" + ] + print_calls = [ + node + for node in ast.walk(tree) + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "print" + ] + + self.assertEqual(manual_seed_calls, []) + self.assertEqual(print_calls, []) + + def test_kendall_tau_is_optional_when_scipy_is_missing(self): + update_source = _function_source(WORKER, "update_predictor") + + self.assertIn("try:", update_source) + self.assertIn("ImportError", update_source) + self.assertIn("kendalltau", update_source) + + +if __name__ == "__main__": + unittest.main()