Skip to content

feat: Dapo predictor#100

Open
ml1019-lmx wants to merge 11 commits into
verl-project:mainfrom
ml1019-lmx:dapo_predictor
Open

feat: Dapo predictor#100
ml1019-lmx wants to merge 11 commits into
verl-project:mainfrom
ml1019-lmx:dapo_predictor

Conversation

@ml1019-lmx

@ml1019-lmx ml1019-lmx commented May 13, 2026

Copy link
Copy Markdown

Predictor-Driven Prompt Reordering for DAPO

Summary

Add predictor-driven prompt reordering to the DAPO training pipeline for the release/v0.7.1 branch. Before each rollout generation, prompts are scored by a lightweight linear predictor (trained via ListMLE loss to predict response length), then reordered with serpentine (snake) DP packing so that high-scoring prompts are distributed evenly across DP ranks.

Key Files

  • predictor_worker.py — Predictor actor worker: hidden state extraction, score computation, and ListMLE training
  • predictor_dapo_trainer.py — Trainer that injects predictor scoring + snake-sort reorder into the training loop
  • predictor_utils.py — Serpentine sort utility for balanced DP packing

Experimental Setup

Parameter Value
Model Qwen3-30B-A3B-Instruct-2507
DataLoader Seed 1
Global Batch Size (GBS) 32
N Samples per Prompt 8
Max Num Sequences 16
Generation TP 4
Sequence Parallel (SP) 4 (ulysses)
Max Model Len 22528
Prompt Length ~2k
Response Length ~20k
NPU Count 32
Training Steps 57

Experimental Results

Critic Score

Metric Reorder (this PR) Baseline
Average 0.6179 0.6137
First 10 steps avg 0.4383 0.4391
Last 10 steps avg 0.6680 0.6680

Conclusion: Critic Score is essentially identical — reordering does not degrade training quality.

Step Time (s/it)

Metric Reorder Baseline
Average 638.98 668.40
First 10 steps avg 616.14 621.21
Last 10 steps avg 616.23 711.55

Conclusion: Reorder version stays stable at ~616 s/it throughout training. Baseline degrades from 621 s/it to 711 s/it — the gap widens from 5.08s to 95.33s.

Generation Time (s)

Metric Reorder Baseline
Average 471.66 504.67
First 10 steps avg 439.39 461.45
Last 5 steps avg 421.14 522.81
Trend -18.25s (decreasing) +61.37s (increasing)

Conclusion: Reorder version gen time continuously decreases during training; baseline gen time significantly increases. The gap widens from 22s to 101.67s.

Actor Entropy

Metric Reorder Baseline
Average 0.2664 0.2626
First 10 steps avg 0.2571 0.2577
Last 5 steps avg 0.2619 0.2611
Trend +0.0048 (slight increase) +0.0034 (slight increase)

Conclusion: Both versions have basically the same entropy — no significant difference.

Key Findings

  1. No quality loss — Critic Score unchanged (0.6179 vs 0.6137), confirming reordering preserves training quality.
  2. Stable step time — Reorder stays at ~616 s/it; baseline degrades from 621s to 711s.
  3. Faster generation — Reorder gen time drops from 439s to 421s; baseline rises from 461s to 523s.
  4. Entropy basically consistent — Both versions show similar entropy (0.2664 vs 0.2626), no significant difference.
  5. Widening gap — The gen time advantage grows from 22s to 102s as training progresses.

How to Enable

python3 -m recipe.dapo_predictor.main_dapo_predictor_reorder \
...
...
+trainer.predictor_reorder.enable=True \
+trainer.predictor_reorder.epochs=10 \
+trainer.predictor_reorder.batch_size=32 \
+trainer.predictor_reorder.lr=3e-5

Co-authored-by @Mind-s

@ml1019-lmx ml1019-lmx marked this pull request as draft May 13, 2026 01:23

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a predictor-driven rollout reordering mechanism for DAPO training, including a new trainer, worker extensions, and actor modifications to score prompts and reorder them using a serpentine packing strategy. The feedback identifies several critical issues, including potential KeyError and ZeroDivisionError risks, and a weight divergence risk caused by unsynchronized shuffling across ranks during predictor training. Additionally, the reviewer pointed out significant code duplication in the fit method, performance bottlenecks in driver-side hydration logic, and various code quality issues such as improper indentation, global seed usage, and leftover debug prints.

Comment thread dapo_predictor/predictor_dapo_trainer.py Outdated
Comment thread dapo_predictor/predictor_worker.py Outdated
Comment on lines +413 to +414
label_group_size= self.config.rollout.response_length // 40
response_lengths = response_lengths // label_group_size

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Potential ZeroDivisionError. If self.config.rollout.response_length is less than 40, label_group_size will be 0, causing a crash in the next line. It's safer to ensure a minimum value of 1.

Suggested change
label_group_size= self.config.rollout.response_length // 40
response_lengths = response_lengths // label_group_size
label_group_size = max(1, self.config.rollout.response_length // 40)
response_lengths = response_lengths // label_group_size

Comment thread dapo_predictor/predictor_worker.py Outdated
Comment on lines +423 to +442
dataloader = DataLoader(dataset, batch_size=cfg.get("batch_size", 32), shuffle=True, drop_last=False)
epochs = cfg.get("epochs", 100)

metrics = {}
with Timer(name="predictor_update", logger=None) as timer:
for epoch in range(epochs):
epoch_loss = 0.0
num_batches = 0
epoch_kendall_taus = []
for batch_hidden, batch_lengths in dataloader:
# batch_hidden = batch_hidden.float().requires_grad_(True)
# batch_lengths = batch_lengths.float()
print('Training')
preds = predictor(batch_hidden).squeeze(-1).unsqueeze(0)
labels = batch_lengths.unsqueeze(0)
loss = self.actor.listmle_loss(preds, labels)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(predictor.parameters(), max_norm=1.0)
optimizer.step()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Weight divergence risk. The predictor training loop runs on all ranks using a DataLoader with shuffle=True. Unless the random seed is perfectly synchronized across all ranks before this loop, the ranks will process batches in different orders, leading to divergent weights in the predictor_scorer. This will cause inconsistent scoring across workers during the next rollout. Consider setting a fixed seed before the loop or performing the update on rank 0 only and broadcasting the updated weights.

Comment on lines +142 to +168
for i, msg in enumerate(messages):
multi_modal_data = {"image": [], "video": []}
if multi_modal_batch is not None:
mm_val = multi_modal_batch[i] if isinstance(multi_modal_batch, np.ndarray) else multi_modal_batch
if isinstance(mm_val, dict):
multi_modal_data.update(mm_val)
tools = None
if tool_schema_batch is not None:
tool_schema_val = tool_schema_batch[i] if isinstance(tool_schema_batch, np.ndarray) else tool_schema_batch
if tool_schema_val:
tools = [tool.model_dump() if hasattr(tool, "model_dump") else tool for tool in tool_schema_val]
request = AsyncRolloutRequest.model_validate(
{
"request_id": str(uuid.uuid4()),
"state": AsyncRolloutRequestStateEnum.PENDING,
"messages": msg,
"multi_modal_data": multi_modal_data,
"tool_schemas": tools,
"reward_scores": {},
"max_prompt_len": max_prompt_len,
"max_response_len": max_response_len,
"max_model_len": max_model_len,
"use_inference_chat_template": False,
"tokenization_sanity_check_mode": TokenizationSanityCheckModeEnum.DISABLE,
"processing_class": self.tokenizer,
}
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Performing tokenization and AsyncRolloutRequest.model_validate in a loop on the driver process for every prompt in the batch can become a performance bottleneck as the batch size or model length increases. Consider offloading this processing to the workers or optimizing the hydration logic to avoid repeated driver-side overhead.

Comment on lines +295 to +308
# Dynamically extract prompt length
prompt_length = batch.batch["prompts"].shape[-1]
prompt_input_ids = batch.batch["prompts"] # [batch_size, prompt_length]
prompt_attention_mask = batch.batch["attention_mask"][:, :prompt_length] # dynamic slice
prompt_position_ids = batch.batch["position_ids"][:, :prompt_length] # dynamic slice

# Build a standalone prompt batch for predictor training
prompt_batch = DataProto.from_dict({
"input_ids": prompt_input_ids,
"attention_mask": prompt_attention_mask,
"position_ids": prompt_position_ids
}, meta_info=batch.meta_info )

predictor_output = self.actor_rollout_wg.update_predictor(prompt_batch, batch)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Incorrect indentation. These lines appear to be indented relative to a commented-out else block. They should be aligned with the with statement at line 291 to follow standard Python style guidelines (PEP 8).

Comment on lines +311 to +704
def fit(self):
"""
The training loop of PPO.
The driver process only need to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
if not self._predictor_enabled():
return super().fit()
from omegaconf import OmegaConf

from verl.utils.tracking import Tracking

logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)

self.global_steps = 0
self.gen_steps = 0
self.max_steps_duration = 0

# load checkpoint before doing anything
self._load_checkpoint()
self.checkpoint_manager.update_weights()

# perform validation before training
# currently, we only support validation using the reward_function.
if self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return

if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
rollout_skip = RolloutSkip(self.config, self.async_rollout_manager)
rollout_skip.wrap_generate_sequences()

# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

# we start from step 1
self.global_steps += 1
self.gen_steps += 1
last_val_metrics = None

prev_step_profile = False
curr_step_profile = (
self.global_steps in self.config.global_profiler.steps
if self.config.global_profiler.steps is not None
else False
)
next_step_profile = False

timing_raw = defaultdict(float)
batch = None
num_prompt_in_batch = 0
num_gen_batches = 0
current_epoch = self.global_steps // len(self.train_dataloader)

for epoch in range(current_epoch, self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"):
self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False)
metrics = {}

with marked_timer("start_profile", timing_raw):
self._start_profiling(
not prev_step_profile and curr_step_profile
if self.config.global_profiler.profile_continuous_steps
else curr_step_profile
)

new_batch: DataProto = DataProto.from_single_dict(batch_dict)
new_batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
num_gen_batches += 1
# print(f"new_batch{new_batch}")
gen_batch = self._get_gen_batch(new_batch)
# print(f"gen_batch{gen_batch}")
gen_batch_output = gen_batch.repeat(
repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
)

is_last_step = self.global_steps >= self.total_training_steps

with marked_timer("step", timing_raw):

with marked_timer("predictor_score", timing_raw, "purple"):
with marked_timer("predictor_hydrate", timing_raw, "purple"):
predictor_input_batch = gen_batch_output.select(deepcopy=True)
predictor_input_batch = self._hydrate_gen_batch_model_inputs(predictor_input_batch)
predictor_order = self._build_predictor_order(predictor_input_batch)
# predictor_scores = self.actor_rollout_wg.compute_predictor_score(predictor_input_batch)
self._apply_predictor_order(gen_batch_output, predictor_order)
# print(f'predictor_scores{predictor_scores}')
# generate a batch
with marked_timer("gen", timing_raw, "red"):
# print(f'gen_batch_output{gen_batch_output}')
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)

if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with marked_timer("gen_max", timing_raw, "red"):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)

new_batch = new_batch.union(gen_baseline_output)
# compute reward model score on new_batch
rm_scores = None
if self.use_rm and "rm_scores" not in new_batch.batch.keys():
rm_scores = self._compute_reward_colocate(new_batch)
new_batch = new_batch.union(rm_scores)
reward_baseline_tensor, _ = extract_reward(new_batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

keys_to_pop = set(gen_baseline_output.batch.keys())
if rm_scores is not None:
keys_to_pop.update(rm_scores.batch.keys())
new_batch.pop(batch_keys=list(keys_to_pop))

new_batch.batch["reward_baselines"] = reward_baseline_tensor

del rm_scores, gen_baseline_batch, gen_baseline_output

new_batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
new_batch = new_batch.union(gen_batch_output)

if self.config.algorithm.use_kl_in_reward:
# We need these metrics for apply_kl_penalty if using kl in reward
new_batch = self.compute_kl_related_metrics(new_batch, metrics, timing_raw)
# otherwise, we will compute those after dynamic sampling

with marked_timer("reward", timing_raw, "yellow"):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm and "rm_scores" not in new_batch.batch.keys():
# we first compute reward model score
batch_reward = self._compute_reward_colocate(new_batch)
new_batch = new_batch.union(batch_reward)

# we combine with rule-based rm
reward_tensor, reward_extra_infos_dict = extract_reward(new_batch)

new_batch.batch["token_level_scores"] = reward_tensor

if reward_extra_infos_dict:
new_batch.non_tensor_batch.update(
{k: np.array(v) for k, v in reward_extra_infos_dict.items()}
)

# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
new_batch, kl_metrics = apply_kl_penalty(
new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(
kl_metrics
) # TODO: This will be cleared if we use multiple genenration batches
else:
new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]

if not self.config.algorithm.filter_groups.enable:
batch = new_batch
else: # NOTE: When prompts after filtering is less than train batch size,
# we skip to the next generation batch
metric_name = self.config.algorithm.filter_groups.metric
if metric_name == "seq_final_reward":
# Turn to numpy for easier filtering
new_batch.non_tensor_batch["seq_final_reward"] = (
new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
)
elif metric_name == "seq_reward":
new_batch.non_tensor_batch["seq_reward"] = (
new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
)

# Collect the sequence reward for each trajectory
prompt_uid2metric_vals = defaultdict(list)
for uid, metric_val in zip(
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
):
prompt_uid2metric_vals[uid].append(metric_val)

prompt_uid2metric_std = {}
for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)

kept_prompt_uids = [
uid
for uid, std in prompt_uid2metric_std.items()
if std > 0 or len(prompt_uid2metric_vals[uid]) == 1
]
num_prompt_in_batch += len(kept_prompt_uids)

kept_traj_idxs = []
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
if traj_from_prompt_uid in kept_prompt_uids:
kept_traj_idxs.append(idx)

new_batch = new_batch[kept_traj_idxs]
batch = new_batch if batch is None else DataProto.concat([batch, new_batch])

prompt_bsz = self.config.data.train_batch_size
if num_prompt_in_batch < prompt_bsz:
print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
print(f"{num_gen_batches=}. Keep generating...")
self.gen_steps += 1
is_last_step = self.global_steps >= self.total_training_steps
continue
else:
raise ValueError(
f"{num_gen_batches=} >= {max_num_gen_batches=}."
+ " Generated too many. Please check if your data are too difficult."
+ " You could also try set max_num_gen_batches=0 to enable endless trials."
)
else:
# Align the batch
traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
batch = batch[:traj_bsz]

self.checkpoint_manager.sleep_replicas()

# === Updating ===
# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
# which won't affect the advantage calculation (since it's based on uid),
# but might affect the loss calculation (due to the change of mini-batching).
# TODO: Decouple the DP balancing and mini-batching.
# print(f'self.config.trainer.balance_batch{self.config.trainer.balance_batch}')
# if self.config.trainer.balance_batch:
# self._balance_batch(batch, metrics=metrics)


if self.config.trainer.balance_batch:
uid_before_balance = batch.non_tensor_batch["uid"].copy()
self._balance_batch(batch, metrics=metrics)
reverse_idx = self._build_reverse_idx_from_uid(
uid_before_balance,
batch.non_tensor_batch["uid"],
)
else:
reverse_idx = None
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

if not self.config.algorithm.use_kl_in_reward:
batch = self.compute_kl_related_metrics(batch, metrics, timing_raw)

# compute values
if self.use_critic:
with marked_timer("values", timing_raw, "cyan"):
values = self._compute_values(batch)
batch = batch.union(values)

# Compute rollout correction weights and off-policy metrics (inherited from RayPPOTrainer)
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch

rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
if rollout_corr_config is not None and "rollout_log_probs" in batch.batch:
batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
# IS and off-policy metrics already have rollout_corr/ prefix
metrics.update(is_metrics)

with marked_timer("adv", timing_raw, "brown"):
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
config=self.config.algorithm,
)

# update critic
if self.use_critic:
with marked_timer("update_critic", timing_raw, "pink"):
critic_output = self._update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)

# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with marked_timer("update_actor", timing_raw, "red"):
actor_output = self._update_actor(batch)

# Update predictor after actor update
if reverse_idx is not None:
batch.reorder(reverse_idx)
self._maybe_update_predictor(gen_batch, batch, metrics, timing_raw)

# Check if ESI/training plan is close to expiration
esi_close_to_expiration = should_save_ckpt_esi(
max_steps_duration=self.max_steps_duration,
redundant_time=self.config.trainer.esi_redundant_time,
)
if self.config.trainer.save_freq > 0 and (
is_last_step
or self.global_steps % self.config.trainer.save_freq == 0
or esi_close_to_expiration
):
if esi_close_to_expiration:
print("Force saving checkpoint: ESI instance expiration approaching.")
with marked_timer("save_checkpoint", timing_raw, "green"):
self._save_checkpoint()

with marked_timer("update_weights", timing_raw, "red"):
self.checkpoint_manager.update_weights()
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)

# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)

# validate
if self.config.trainer.test_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.test_freq == 0
):
with marked_timer("testing", timing_raw, "green"):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)

with marked_timer("stop_profile", timing_raw):
next_step_profile = (
self.global_steps + 1 in self.config.global_profiler.steps
if self.config.global_profiler.steps is not None
else False
)
self._stop_profiling(
curr_step_profile and not next_step_profile
if self.config.global_profiler.profile_continuous_steps
else curr_step_profile
)
prev_step_profile = curr_step_profile
curr_step_profile = next_step_profile

steps_duration = timing_raw.get("step", 0)
self.max_steps_duration = max(self.max_steps_duration, steps_duration)

# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: implement actual tflpo and theoretical tflpo
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
timing_raw = defaultdict(float) # clear timing

metrics["train/num_gen_batches"] = num_gen_batches
batch = None
num_prompt_in_batch = 0
num_gen_batches = 0

# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)

if is_last_step:
if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"):
self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True)
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return

progress_bar.update(1)
self.global_steps += 1
self.gen_steps += 1
# check if last step checkpint exists
checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}")
if not os.path.exists(checkpoint_dir):
# save last step checkpoint
timing_raw = defaultdict(float)
with marked_timer("save_checkpoint", timing_raw, "green"):
self._save_checkpoint()
metrics = {f"timing/{k}": v for k, v in timing_raw.items()}
logger.log(data=metrics, step=self.global_steps)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The fit method is a near-complete duplication of the base RayDAPOTrainer.fit loop. This creates a significant maintenance burden, as any bug fixes or improvements in the upstream trainer will not be reflected here. Consider refactoring the base class to provide hooks (e.g., pre_rollout_hook, post_actor_update_hook) that can be overridden in this subclass to inject the predictor logic.

Comment thread dapo_predictor/predictor_worker.py Outdated
self.actor_module.eval()

# Deterministic behavior for reproducible hidden states
torch.manual_seed(42)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Setting a global seed using torch.manual_seed(42) inside a worker method is discouraged as it can have unintended side effects on other parts of the system that rely on randomness. If determinism is required for hidden state extraction, consider using a local torch.Generator if applicable, although eval() mode should already ensure deterministic behavior for most models.

Comment thread dapo_predictor/predictor_worker.py Outdated
)
dataset = TensorDataset(hidden_states.float(), response_lengths.float())
dataloader = DataLoader(dataset, batch_size=cfg.get("batch_size", 32), shuffle=True, drop_last=False)
epochs = cfg.get("epochs", 100)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The default value of 100 epochs for training the linear predictor on every training step might be excessive and could slow down the overall training process. Consider a more conservative default (e.g., 10) as used in the experimental setup described in the PR.

Comment thread dapo_predictor/predictor_worker.py Outdated
for batch_hidden, batch_lengths in dataloader:
# batch_hidden = batch_hidden.float().requires_grad_(True)
# batch_lengths = batch_lengths.float()
print('Training')

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Leftover debug print statement. Please remove it or replace it with proper logging to avoid cluttering the console output during training.

Comment thread dapo_predictor/predictor_worker.py Outdated
Comment on lines +451 to +452
from scipy.stats import kendalltau
kendall_tau, p_value = kendalltau(pred_numpy, true_numpy)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The scipy dependency is used here for kendalltau. Ensure that scipy is included in the project's dependencies, or wrap this in a try-except block to avoid crashing the worker if the library is missing in the environment.

@ml1019-lmx ml1019-lmx marked this pull request as ready for review May 13, 2026 03:34
@ml1019-lmx ml1019-lmx marked this pull request as draft May 14, 2026 07:02
@ml1019-lmx ml1019-lmx marked this pull request as ready for review May 14, 2026 07:03
@ml1019-lmx ml1019-lmx changed the title Dapo predictor feat: Dapo predictor May 14, 2026
Comment thread dapo_predictor/README.md
@@ -0,0 +1,28 @@
# DAPO Predictor Reorder (Portable Copy)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

readme说明过于简略,需要加上一些说明文档和使用文档

predictor_input_batch = gen_batch_output.select(deepcopy=True)
predictor_input_batch = self._hydrate_gen_batch_model_inputs(predictor_input_batch)
predictor_order = self._build_predictor_order(predictor_input_batch)
# predictor_scores = self.actor_rollout_wg.compute_predictor_score(predictor_input_batch)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

无关注释可以删除

Removed the seed parameter from the predictor reorder configuration.
Removed commented-out print statements and unnecessary checks.
Removed commented-out print statements for cleaner code.
Add check for predictor configuration enablement before training.
Removed unnecessary blank lines in predictor_worker.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants