【draft】add new dapo trainer with TransferQueue#82
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the DAPOSyncPPOTrainer, which integrates DAPO dynamic sampling with TransferQueue and ReplayBuffer for synchronized training. The implementation includes a new Hydra configuration and a trainer class that supports multi-batch generation and conditional KL computation. Review feedback highlights a potential dimension mismatch when processing reward scores, suggests the removal of unused variables, and recommends replacing standard print statements with proper logging. Additionally, there is a suggestion to handle excessive generation retries more gracefully to prevent the training process from crashing due to a ValueError.
| if metric_name == "seq_final_reward" and "token_level_rewards" in data: | ||
| metric_values = data["token_level_rewards"].to_padded_tensor().sum(dim=-1).numpy() | ||
| else: | ||
| metric_values = data["rm_scores"].to_padded_tensor().sum(dim=-1).numpy() |
There was a problem hiding this comment.
The sum(dim=-1) operation on rm_scores might be incorrect if the reward model provides sequence-level scores (which are already 1D). If rm_scores is 1D, sum(dim=-1) will reduce the entire batch to a single scalar, causing the subsequent zip at line 133 to fail with a TypeError. It is safer to check the dimensions before summing.
metric_values = data["rm_scores"].to_padded_tensor()
if metric_values.ndim > 1:
metric_values = metric_values.sum(dim=-1)
metric_values = metric_values.numpy()| uids = list(uid_data) | ||
|
|
||
| prompt_uid2metric_vals = defaultdict(list) | ||
| prompt_uid2key_indices = defaultdict(list) |
|
|
||
| prompt_bsz = self.config.data.train_batch_size | ||
| if num_prompt_in_batch < prompt_bsz: | ||
| print(f"{num_prompt_in_batch=} < {prompt_bsz=}") |
There was a problem hiding this comment.
Using print for logging inside the training loop can be problematic, especially if it spams the console during multiple generation retries. It is better to use the configured logger.
| print(f"{num_prompt_in_batch=} < {prompt_bsz=}") | |
| logger.info(f"{num_prompt_in_batch=} < {prompt_bsz=}") |
| print(f"{num_prompt_in_batch=} < {prompt_bsz=}") | ||
| max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches | ||
| if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: | ||
| print(f"{num_gen_batches=}. Keep generating...") |
| raise ValueError( | ||
| f"{num_gen_batches=} >= {max_num_gen_batches=}. " | ||
| "Generated too many. Please check if your data are too difficult. " | ||
| "You could also try set max_num_gen_batches=0 to enable endless trials." | ||
| ) |
There was a problem hiding this comment.
Raising a ValueError here will crash the entire training process. While this is a safety check against infinite loops, in a production or long-running environment, it might be preferable to log a critical error and either skip the batch or stop the training gracefully (e.g., by returning from fit) to allow for final checkpointing or cleaner exits.
No description provided.