Skip to content

【draft】add new dapo trainer with TransferQueue#82

Draft
ji-huazhong wants to merge 1 commit into
verl-project:mainfrom
ji-huazhong:feat/dapo-tq
Draft

【draft】add new dapo trainer with TransferQueue#82
ji-huazhong wants to merge 1 commit into
verl-project:mainfrom
ji-huazhong:feat/dapo-tq

Conversation

@ji-huazhong

Copy link
Copy Markdown

No description provided.

@ji-huazhong ji-huazhong marked this pull request as draft April 12, 2026 11:41

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the DAPOSyncPPOTrainer, which integrates DAPO dynamic sampling with TransferQueue and ReplayBuffer for synchronized training. The implementation includes a new Hydra configuration and a trainer class that supports multi-batch generation and conditional KL computation. Review feedback highlights a potential dimension mismatch when processing reward scores, suggests the removal of unused variables, and recommends replacing standard print statements with proper logging. Additionally, there is a suggestion to handle excessive generation retries more gracefully to prevent the training process from crashing due to a ValueError.

Comment thread dapo/main_dapo_sync.py
if metric_name == "seq_final_reward" and "token_level_rewards" in data:
metric_values = data["token_level_rewards"].to_padded_tensor().sum(dim=-1).numpy()
else:
metric_values = data["rm_scores"].to_padded_tensor().sum(dim=-1).numpy()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The sum(dim=-1) operation on rm_scores might be incorrect if the reward model provides sequence-level scores (which are already 1D). If rm_scores is 1D, sum(dim=-1) will reduce the entire batch to a single scalar, causing the subsequent zip at line 133 to fail with a TypeError. It is safer to check the dimensions before summing.

            metric_values = data["rm_scores"].to_padded_tensor()
            if metric_values.ndim > 1:
                metric_values = metric_values.sum(dim=-1)
            metric_values = metric_values.numpy()

Comment thread dapo/main_dapo_sync.py
uids = list(uid_data)

prompt_uid2metric_vals = defaultdict(list)
prompt_uid2key_indices = defaultdict(list)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable prompt_uid2key_indices is initialized but never used in the function. It should be removed to keep the code clean.

Comment thread dapo/main_dapo_sync.py

prompt_bsz = self.config.data.train_batch_size
if num_prompt_in_batch < prompt_bsz:
print(f"{num_prompt_in_batch=} < {prompt_bsz=}")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using print for logging inside the training loop can be problematic, especially if it spams the console during multiple generation retries. It is better to use the configured logger.

Suggested change
print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
logger.info(f"{num_prompt_in_batch=} < {prompt_bsz=}")

Comment thread dapo/main_dapo_sync.py
print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
print(f"{num_gen_batches=}. Keep generating...")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using print for logging inside the training loop can be problematic. It is better to use the configured logger.

Suggested change
print(f"{num_gen_batches=}. Keep generating...")
logger.info(f"{num_gen_batches=}. Keep generating...")

Comment thread dapo/main_dapo_sync.py
Comment on lines +335 to +339
raise ValueError(
f"{num_gen_batches=} >= {max_num_gen_batches=}. "
"Generated too many. Please check if your data are too difficult. "
"You could also try set max_num_gen_batches=0 to enable endless trials."
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Raising a ValueError here will crash the entire training process. While this is a safety check against infinite loops, in a production or long-running environment, it might be preferable to log a critical error and either skip the batch or stop the training gracefully (e.g., by returning from fit) to allow for final checkpointing or cleaner exits.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant