feat: Dapo predictor#100
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a predictor-driven rollout reordering mechanism for DAPO training, including a new trainer, worker extensions, and actor modifications to score prompts and reorder them using a serpentine packing strategy. The feedback identifies several critical issues, including potential KeyError and ZeroDivisionError risks, and a weight divergence risk caused by unsynchronized shuffling across ranks during predictor training. Additionally, the reviewer pointed out significant code duplication in the fit method, performance bottlenecks in driver-side hydration logic, and various code quality issues such as improper indentation, global seed usage, and leftover debug prints.
| label_group_size= self.config.rollout.response_length // 40 | ||
| response_lengths = response_lengths // label_group_size |
There was a problem hiding this comment.
Potential ZeroDivisionError. If self.config.rollout.response_length is less than 40, label_group_size will be 0, causing a crash in the next line. It's safer to ensure a minimum value of 1.
| label_group_size= self.config.rollout.response_length // 40 | |
| response_lengths = response_lengths // label_group_size | |
| label_group_size = max(1, self.config.rollout.response_length // 40) | |
| response_lengths = response_lengths // label_group_size |
| dataloader = DataLoader(dataset, batch_size=cfg.get("batch_size", 32), shuffle=True, drop_last=False) | ||
| epochs = cfg.get("epochs", 100) | ||
|
|
||
| metrics = {} | ||
| with Timer(name="predictor_update", logger=None) as timer: | ||
| for epoch in range(epochs): | ||
| epoch_loss = 0.0 | ||
| num_batches = 0 | ||
| epoch_kendall_taus = [] | ||
| for batch_hidden, batch_lengths in dataloader: | ||
| # batch_hidden = batch_hidden.float().requires_grad_(True) | ||
| # batch_lengths = batch_lengths.float() | ||
| print('Training') | ||
| preds = predictor(batch_hidden).squeeze(-1).unsqueeze(0) | ||
| labels = batch_lengths.unsqueeze(0) | ||
| loss = self.actor.listmle_loss(preds, labels) | ||
| optimizer.zero_grad() | ||
| loss.backward() | ||
| torch.nn.utils.clip_grad_norm_(predictor.parameters(), max_norm=1.0) | ||
| optimizer.step() |
There was a problem hiding this comment.
Weight divergence risk. The predictor training loop runs on all ranks using a DataLoader with shuffle=True. Unless the random seed is perfectly synchronized across all ranks before this loop, the ranks will process batches in different orders, leading to divergent weights in the predictor_scorer. This will cause inconsistent scoring across workers during the next rollout. Consider setting a fixed seed before the loop or performing the update on rank 0 only and broadcasting the updated weights.
| for i, msg in enumerate(messages): | ||
| multi_modal_data = {"image": [], "video": []} | ||
| if multi_modal_batch is not None: | ||
| mm_val = multi_modal_batch[i] if isinstance(multi_modal_batch, np.ndarray) else multi_modal_batch | ||
| if isinstance(mm_val, dict): | ||
| multi_modal_data.update(mm_val) | ||
| tools = None | ||
| if tool_schema_batch is not None: | ||
| tool_schema_val = tool_schema_batch[i] if isinstance(tool_schema_batch, np.ndarray) else tool_schema_batch | ||
| if tool_schema_val: | ||
| tools = [tool.model_dump() if hasattr(tool, "model_dump") else tool for tool in tool_schema_val] | ||
| request = AsyncRolloutRequest.model_validate( | ||
| { | ||
| "request_id": str(uuid.uuid4()), | ||
| "state": AsyncRolloutRequestStateEnum.PENDING, | ||
| "messages": msg, | ||
| "multi_modal_data": multi_modal_data, | ||
| "tool_schemas": tools, | ||
| "reward_scores": {}, | ||
| "max_prompt_len": max_prompt_len, | ||
| "max_response_len": max_response_len, | ||
| "max_model_len": max_model_len, | ||
| "use_inference_chat_template": False, | ||
| "tokenization_sanity_check_mode": TokenizationSanityCheckModeEnum.DISABLE, | ||
| "processing_class": self.tokenizer, | ||
| } | ||
| ) |
There was a problem hiding this comment.
Performing tokenization and AsyncRolloutRequest.model_validate in a loop on the driver process for every prompt in the batch can become a performance bottleneck as the batch size or model length increases. Consider offloading this processing to the workers or optimizing the hydration logic to avoid repeated driver-side overhead.
| # Dynamically extract prompt length | ||
| prompt_length = batch.batch["prompts"].shape[-1] | ||
| prompt_input_ids = batch.batch["prompts"] # [batch_size, prompt_length] | ||
| prompt_attention_mask = batch.batch["attention_mask"][:, :prompt_length] # dynamic slice | ||
| prompt_position_ids = batch.batch["position_ids"][:, :prompt_length] # dynamic slice | ||
|
|
||
| # Build a standalone prompt batch for predictor training | ||
| prompt_batch = DataProto.from_dict({ | ||
| "input_ids": prompt_input_ids, | ||
| "attention_mask": prompt_attention_mask, | ||
| "position_ids": prompt_position_ids | ||
| }, meta_info=batch.meta_info ) | ||
|
|
||
| predictor_output = self.actor_rollout_wg.update_predictor(prompt_batch, batch) |
| def fit(self): | ||
| """ | ||
| The training loop of PPO. | ||
| The driver process only need to call the compute functions of the worker group through RPC | ||
| to construct the PPO dataflow. | ||
| The light-weight advantage computation is done on the driver process. | ||
| """ | ||
| if not self._predictor_enabled(): | ||
| return super().fit() | ||
| from omegaconf import OmegaConf | ||
|
|
||
| from verl.utils.tracking import Tracking | ||
|
|
||
| logger = Tracking( | ||
| project_name=self.config.trainer.project_name, | ||
| experiment_name=self.config.trainer.experiment_name, | ||
| default_backend=self.config.trainer.logger, | ||
| config=OmegaConf.to_container(self.config, resolve=True), | ||
| ) | ||
|
|
||
| self.global_steps = 0 | ||
| self.gen_steps = 0 | ||
| self.max_steps_duration = 0 | ||
|
|
||
| # load checkpoint before doing anything | ||
| self._load_checkpoint() | ||
| self.checkpoint_manager.update_weights() | ||
|
|
||
| # perform validation before training | ||
| # currently, we only support validation using the reward_function. | ||
| if self.config.trainer.get("val_before_train", True): | ||
| val_metrics = self._validate() | ||
| assert val_metrics, f"{val_metrics=}" | ||
| pprint(f"Initial validation metrics: {val_metrics}") | ||
| logger.log(data=val_metrics, step=self.global_steps) | ||
| if self.config.trainer.get("val_only", False): | ||
| return | ||
|
|
||
| if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): | ||
| rollout_skip = RolloutSkip(self.config, self.async_rollout_manager) | ||
| rollout_skip.wrap_generate_sequences() | ||
|
|
||
| # add tqdm | ||
| progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") | ||
|
|
||
| # we start from step 1 | ||
| self.global_steps += 1 | ||
| self.gen_steps += 1 | ||
| last_val_metrics = None | ||
|
|
||
| prev_step_profile = False | ||
| curr_step_profile = ( | ||
| self.global_steps in self.config.global_profiler.steps | ||
| if self.config.global_profiler.steps is not None | ||
| else False | ||
| ) | ||
| next_step_profile = False | ||
|
|
||
| timing_raw = defaultdict(float) | ||
| batch = None | ||
| num_prompt_in_batch = 0 | ||
| num_gen_batches = 0 | ||
| current_epoch = self.global_steps // len(self.train_dataloader) | ||
|
|
||
| for epoch in range(current_epoch, self.config.trainer.total_epochs): | ||
| for batch_dict in self.train_dataloader: | ||
| if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): | ||
| self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False) | ||
| metrics = {} | ||
|
|
||
| with marked_timer("start_profile", timing_raw): | ||
| self._start_profiling( | ||
| not prev_step_profile and curr_step_profile | ||
| if self.config.global_profiler.profile_continuous_steps | ||
| else curr_step_profile | ||
| ) | ||
|
|
||
| new_batch: DataProto = DataProto.from_single_dict(batch_dict) | ||
| new_batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature | ||
| num_gen_batches += 1 | ||
| # print(f"new_batch{new_batch}") | ||
| gen_batch = self._get_gen_batch(new_batch) | ||
| # print(f"gen_batch{gen_batch}") | ||
| gen_batch_output = gen_batch.repeat( | ||
| repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True | ||
| ) | ||
|
|
||
| is_last_step = self.global_steps >= self.total_training_steps | ||
|
|
||
| with marked_timer("step", timing_raw): | ||
|
|
||
| with marked_timer("predictor_score", timing_raw, "purple"): | ||
| with marked_timer("predictor_hydrate", timing_raw, "purple"): | ||
| predictor_input_batch = gen_batch_output.select(deepcopy=True) | ||
| predictor_input_batch = self._hydrate_gen_batch_model_inputs(predictor_input_batch) | ||
| predictor_order = self._build_predictor_order(predictor_input_batch) | ||
| # predictor_scores = self.actor_rollout_wg.compute_predictor_score(predictor_input_batch) | ||
| self._apply_predictor_order(gen_batch_output, predictor_order) | ||
| # print(f'predictor_scores{predictor_scores}') | ||
| # generate a batch | ||
| with marked_timer("gen", timing_raw, "red"): | ||
| # print(f'gen_batch_output{gen_batch_output}') | ||
| gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) | ||
| timing_raw.update(gen_batch_output.meta_info["timing"]) | ||
| gen_batch_output.meta_info.pop("timing", None) | ||
|
|
||
| if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: | ||
| with marked_timer("gen_max", timing_raw, "red"): | ||
| gen_baseline_batch = deepcopy(gen_batch) | ||
| gen_baseline_batch.meta_info["do_sample"] = False | ||
| gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) | ||
|
|
||
| new_batch = new_batch.union(gen_baseline_output) | ||
| # compute reward model score on new_batch | ||
| rm_scores = None | ||
| if self.use_rm and "rm_scores" not in new_batch.batch.keys(): | ||
| rm_scores = self._compute_reward_colocate(new_batch) | ||
| new_batch = new_batch.union(rm_scores) | ||
| reward_baseline_tensor, _ = extract_reward(new_batch) | ||
| reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) | ||
|
|
||
| keys_to_pop = set(gen_baseline_output.batch.keys()) | ||
| if rm_scores is not None: | ||
| keys_to_pop.update(rm_scores.batch.keys()) | ||
| new_batch.pop(batch_keys=list(keys_to_pop)) | ||
|
|
||
| new_batch.batch["reward_baselines"] = reward_baseline_tensor | ||
|
|
||
| del rm_scores, gen_baseline_batch, gen_baseline_output | ||
|
|
||
| new_batch.non_tensor_batch["uid"] = np.array( | ||
| [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object | ||
| ) | ||
| # repeat to align with repeated responses in rollout | ||
| new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) | ||
| new_batch = new_batch.union(gen_batch_output) | ||
|
|
||
| if self.config.algorithm.use_kl_in_reward: | ||
| # We need these metrics for apply_kl_penalty if using kl in reward | ||
| new_batch = self.compute_kl_related_metrics(new_batch, metrics, timing_raw) | ||
| # otherwise, we will compute those after dynamic sampling | ||
|
|
||
| with marked_timer("reward", timing_raw, "yellow"): | ||
| # compute scores. Support both model and function-based. | ||
| # We first compute the scores using reward model. Then, we call reward_fn to combine | ||
| # the results from reward model and rule-based results. | ||
| if self.use_rm and "rm_scores" not in new_batch.batch.keys(): | ||
| # we first compute reward model score | ||
| batch_reward = self._compute_reward_colocate(new_batch) | ||
| new_batch = new_batch.union(batch_reward) | ||
|
|
||
| # we combine with rule-based rm | ||
| reward_tensor, reward_extra_infos_dict = extract_reward(new_batch) | ||
|
|
||
| new_batch.batch["token_level_scores"] = reward_tensor | ||
|
|
||
| if reward_extra_infos_dict: | ||
| new_batch.non_tensor_batch.update( | ||
| {k: np.array(v) for k, v in reward_extra_infos_dict.items()} | ||
| ) | ||
|
|
||
| # compute rewards. apply_kl_penalty if available | ||
| if self.config.algorithm.use_kl_in_reward: | ||
| new_batch, kl_metrics = apply_kl_penalty( | ||
| new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty | ||
| ) | ||
| metrics.update( | ||
| kl_metrics | ||
| ) # TODO: This will be cleared if we use multiple genenration batches | ||
| else: | ||
| new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] | ||
|
|
||
| if not self.config.algorithm.filter_groups.enable: | ||
| batch = new_batch | ||
| else: # NOTE: When prompts after filtering is less than train batch size, | ||
| # we skip to the next generation batch | ||
| metric_name = self.config.algorithm.filter_groups.metric | ||
| if metric_name == "seq_final_reward": | ||
| # Turn to numpy for easier filtering | ||
| new_batch.non_tensor_batch["seq_final_reward"] = ( | ||
| new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() | ||
| ) | ||
| elif metric_name == "seq_reward": | ||
| new_batch.non_tensor_batch["seq_reward"] = ( | ||
| new_batch.batch["token_level_scores"].sum(dim=-1).numpy() | ||
| ) | ||
|
|
||
| # Collect the sequence reward for each trajectory | ||
| prompt_uid2metric_vals = defaultdict(list) | ||
| for uid, metric_val in zip( | ||
| new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True | ||
| ): | ||
| prompt_uid2metric_vals[uid].append(metric_val) | ||
|
|
||
| prompt_uid2metric_std = {} | ||
| for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): | ||
| prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) | ||
|
|
||
| kept_prompt_uids = [ | ||
| uid | ||
| for uid, std in prompt_uid2metric_std.items() | ||
| if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 | ||
| ] | ||
| num_prompt_in_batch += len(kept_prompt_uids) | ||
|
|
||
| kept_traj_idxs = [] | ||
| for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]): | ||
| if traj_from_prompt_uid in kept_prompt_uids: | ||
| kept_traj_idxs.append(idx) | ||
|
|
||
| new_batch = new_batch[kept_traj_idxs] | ||
| batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) | ||
|
|
||
| prompt_bsz = self.config.data.train_batch_size | ||
| if num_prompt_in_batch < prompt_bsz: | ||
| print(f"{num_prompt_in_batch=} < {prompt_bsz=}") | ||
| max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches | ||
| if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: | ||
| print(f"{num_gen_batches=}. Keep generating...") | ||
| self.gen_steps += 1 | ||
| is_last_step = self.global_steps >= self.total_training_steps | ||
| continue | ||
| else: | ||
| raise ValueError( | ||
| f"{num_gen_batches=} >= {max_num_gen_batches=}." | ||
| + " Generated too many. Please check if your data are too difficult." | ||
| + " You could also try set max_num_gen_batches=0 to enable endless trials." | ||
| ) | ||
| else: | ||
| # Align the batch | ||
| traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n | ||
| batch = batch[:traj_bsz] | ||
|
|
||
| self.checkpoint_manager.sleep_replicas() | ||
|
|
||
| # === Updating === | ||
| # Balance the number of valid tokens across DP ranks. | ||
| # NOTE: This usually changes the order of data in the `batch`, | ||
| # which won't affect the advantage calculation (since it's based on uid), | ||
| # but might affect the loss calculation (due to the change of mini-batching). | ||
| # TODO: Decouple the DP balancing and mini-batching. | ||
| # print(f'self.config.trainer.balance_batch{self.config.trainer.balance_batch}') | ||
| # if self.config.trainer.balance_batch: | ||
| # self._balance_batch(batch, metrics=metrics) | ||
|
|
||
|
|
||
| if self.config.trainer.balance_batch: | ||
| uid_before_balance = batch.non_tensor_batch["uid"].copy() | ||
| self._balance_batch(batch, metrics=metrics) | ||
| reverse_idx = self._build_reverse_idx_from_uid( | ||
| uid_before_balance, | ||
| batch.non_tensor_batch["uid"], | ||
| ) | ||
| else: | ||
| reverse_idx = None | ||
| # compute global_valid tokens | ||
| batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() | ||
|
|
||
| if not self.config.algorithm.use_kl_in_reward: | ||
| batch = self.compute_kl_related_metrics(batch, metrics, timing_raw) | ||
|
|
||
| # compute values | ||
| if self.use_critic: | ||
| with marked_timer("values", timing_raw, "cyan"): | ||
| values = self._compute_values(batch) | ||
| batch = batch.union(values) | ||
|
|
||
| # Compute rollout correction weights and off-policy metrics (inherited from RayPPOTrainer) | ||
| from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch | ||
|
|
||
| rollout_corr_config = self.config.algorithm.get("rollout_correction", None) | ||
| if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: | ||
| batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) | ||
| # IS and off-policy metrics already have rollout_corr/ prefix | ||
| metrics.update(is_metrics) | ||
|
|
||
| with marked_timer("adv", timing_raw, "brown"): | ||
| # compute advantages, executed on the driver process | ||
| norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) | ||
| batch = compute_advantage( | ||
| batch, | ||
| adv_estimator=self.config.algorithm.adv_estimator, | ||
| gamma=self.config.algorithm.gamma, | ||
| lam=self.config.algorithm.lam, | ||
| num_repeat=self.config.actor_rollout_ref.rollout.n, | ||
| norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, | ||
| config=self.config.algorithm, | ||
| ) | ||
|
|
||
| # update critic | ||
| if self.use_critic: | ||
| with marked_timer("update_critic", timing_raw, "pink"): | ||
| critic_output = self._update_critic(batch) | ||
| critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) | ||
| metrics.update(critic_output_metrics) | ||
|
|
||
| # implement critic warmup | ||
| if self.config.trainer.critic_warmup <= self.global_steps: | ||
| # update actor | ||
| with marked_timer("update_actor", timing_raw, "red"): | ||
| actor_output = self._update_actor(batch) | ||
|
|
||
| # Update predictor after actor update | ||
| if reverse_idx is not None: | ||
| batch.reorder(reverse_idx) | ||
| self._maybe_update_predictor(gen_batch, batch, metrics, timing_raw) | ||
|
|
||
| # Check if ESI/training plan is close to expiration | ||
| esi_close_to_expiration = should_save_ckpt_esi( | ||
| max_steps_duration=self.max_steps_duration, | ||
| redundant_time=self.config.trainer.esi_redundant_time, | ||
| ) | ||
| if self.config.trainer.save_freq > 0 and ( | ||
| is_last_step | ||
| or self.global_steps % self.config.trainer.save_freq == 0 | ||
| or esi_close_to_expiration | ||
| ): | ||
| if esi_close_to_expiration: | ||
| print("Force saving checkpoint: ESI instance expiration approaching.") | ||
| with marked_timer("save_checkpoint", timing_raw, "green"): | ||
| self._save_checkpoint() | ||
|
|
||
| with marked_timer("update_weights", timing_raw, "red"): | ||
| self.checkpoint_manager.update_weights() | ||
| actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) | ||
| metrics.update(actor_output_metrics) | ||
|
|
||
| # Log rollout generations if enabled | ||
| rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) | ||
| if rollout_data_dir: | ||
| self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) | ||
|
|
||
| # validate | ||
| if self.config.trainer.test_freq > 0 and ( | ||
| is_last_step or self.global_steps % self.config.trainer.test_freq == 0 | ||
| ): | ||
| with marked_timer("testing", timing_raw, "green"): | ||
| val_metrics: dict = self._validate() | ||
| if is_last_step: | ||
| last_val_metrics = val_metrics | ||
| metrics.update(val_metrics) | ||
|
|
||
| with marked_timer("stop_profile", timing_raw): | ||
| next_step_profile = ( | ||
| self.global_steps + 1 in self.config.global_profiler.steps | ||
| if self.config.global_profiler.steps is not None | ||
| else False | ||
| ) | ||
| self._stop_profiling( | ||
| curr_step_profile and not next_step_profile | ||
| if self.config.global_profiler.profile_continuous_steps | ||
| else curr_step_profile | ||
| ) | ||
| prev_step_profile = curr_step_profile | ||
| curr_step_profile = next_step_profile | ||
|
|
||
| steps_duration = timing_raw.get("step", 0) | ||
| self.max_steps_duration = max(self.max_steps_duration, steps_duration) | ||
|
|
||
| # collect metrics | ||
| metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) | ||
| metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) | ||
| # TODO: implement actual tflpo and theoretical tflpo | ||
| n_gpus = self.resource_pool_manager.get_n_gpus() | ||
| metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) | ||
| timing_raw = defaultdict(float) # clear timing | ||
|
|
||
| metrics["train/num_gen_batches"] = num_gen_batches | ||
| batch = None | ||
| num_prompt_in_batch = 0 | ||
| num_gen_batches = 0 | ||
|
|
||
| # TODO: make a canonical logger that supports various backend | ||
| logger.log(data=metrics, step=self.global_steps) | ||
|
|
||
| if is_last_step: | ||
| if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): | ||
| self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True) | ||
| pprint(f"Final validation metrics: {last_val_metrics}") | ||
| progress_bar.close() | ||
| return | ||
|
|
||
| progress_bar.update(1) | ||
| self.global_steps += 1 | ||
| self.gen_steps += 1 | ||
| # check if last step checkpint exists | ||
| checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}") | ||
| if not os.path.exists(checkpoint_dir): | ||
| # save last step checkpoint | ||
| timing_raw = defaultdict(float) | ||
| with marked_timer("save_checkpoint", timing_raw, "green"): | ||
| self._save_checkpoint() | ||
| metrics = {f"timing/{k}": v for k, v in timing_raw.items()} | ||
| logger.log(data=metrics, step=self.global_steps) |
There was a problem hiding this comment.
The fit method is a near-complete duplication of the base RayDAPOTrainer.fit loop. This creates a significant maintenance burden, as any bug fixes or improvements in the upstream trainer will not be reflected here. Consider refactoring the base class to provide hooks (e.g., pre_rollout_hook, post_actor_update_hook) that can be overridden in this subclass to inject the predictor logic.
| self.actor_module.eval() | ||
|
|
||
| # Deterministic behavior for reproducible hidden states | ||
| torch.manual_seed(42) |
There was a problem hiding this comment.
Setting a global seed using torch.manual_seed(42) inside a worker method is discouraged as it can have unintended side effects on other parts of the system that rely on randomness. If determinism is required for hidden state extraction, consider using a local torch.Generator if applicable, although eval() mode should already ensure deterministic behavior for most models.
| ) | ||
| dataset = TensorDataset(hidden_states.float(), response_lengths.float()) | ||
| dataloader = DataLoader(dataset, batch_size=cfg.get("batch_size", 32), shuffle=True, drop_last=False) | ||
| epochs = cfg.get("epochs", 100) |
There was a problem hiding this comment.
| for batch_hidden, batch_lengths in dataloader: | ||
| # batch_hidden = batch_hidden.float().requires_grad_(True) | ||
| # batch_lengths = batch_lengths.float() | ||
| print('Training') |
| from scipy.stats import kendalltau | ||
| kendall_tau, p_value = kendalltau(pred_numpy, true_numpy) |
| @@ -0,0 +1,28 @@ | |||
| # DAPO Predictor Reorder (Portable Copy) | |||
|
|
|||
There was a problem hiding this comment.
readme说明过于简略,需要加上一些说明文档和使用文档
| predictor_input_batch = gen_batch_output.select(deepcopy=True) | ||
| predictor_input_batch = self._hydrate_gen_batch_model_inputs(predictor_input_batch) | ||
| predictor_order = self._build_predictor_order(predictor_input_batch) | ||
| # predictor_scores = self.actor_rollout_wg.compute_predictor_score(predictor_input_batch) |
Removed the seed parameter from the predictor reorder configuration.
Removed commented-out print statements and unnecessary checks.
Removed commented-out print statements for cleaner code.
Add check for predictor configuration enablement before training.
Removed unnecessary blank lines in predictor_worker.py
Predictor-Driven Prompt Reordering for DAPO
Summary
Add predictor-driven prompt reordering to the DAPO training pipeline for the release/v0.7.1 branch. Before each rollout generation, prompts are scored by a lightweight linear predictor (trained via ListMLE loss to predict response length), then reordered with serpentine (snake) DP packing so that high-scoring prompts are distributed evenly across DP ranks.
Key Files
predictor_worker.py— Predictor actor worker: hidden state extraction, score computation, and ListMLE trainingpredictor_dapo_trainer.py— Trainer that injects predictor scoring + snake-sort reorder into the training looppredictor_utils.py— Serpentine sort utility for balanced DP packingExperimental Setup
Experimental Results
Critic Score
Conclusion: Critic Score is essentially identical — reordering does not degrade training quality.
Step Time (s/it)
Conclusion: Reorder version stays stable at ~616 s/it throughout training. Baseline degrades from 621 s/it to 711 s/it — the gap widens from 5.08s to 95.33s.
Generation Time (s)
Conclusion: Reorder version gen time continuously decreases during training; baseline gen time significantly increases. The gap widens from 22s to 101.67s.
Actor Entropy
Conclusion: Both versions have basically the same entropy — no significant difference.
Key Findings
How to Enable
Co-authored-by @Mind-s