Add VerlBB recipe #99

Draft
ISEEKYAN wants to merge 1 commit into verl-project:main from ISEEKYAN:verlbb

Conversation

@ISEEKYAN

Collaborator

No description provided.

@gemini-code-assist (bot) left a comment

Contributor

Code Review

This pull request introduces VerlBB, a recipe integrating the verl framework with the Bumblebee training engine for SFT and GRPO. It includes model-specific scripts, Hydra configurations, and a bridge engine implementation. Feedback highlights critical issues in the BumblebeeEngine, including incorrect token count reduction across parallel groups, flawed loss scaling for variable-length micro-batches in THD mode, and logic errors in loss mask alignment. A compatibility concern regarding Python 3.10 syntax was also identified.

Comment on lines +209 to +215
        token_mask = data["loss_mask"] if "loss_mask" in data.keys() else data["response_mask"]
        batch_num_tokens = token_mask.sum().to(get_device_id())
        torch.distributed.all_reduce(
            batch_num_tokens,
            op=torch.distributed.ReduceOp.SUM,
            group=self.get_data_parallel_group(),
        )

high

The all_reduce on batch_num_tokens runs without checking that the data parallel group is valid or that the data parallel size is greater than 1. If self.get_data_parallel_group() returns None (which can happen when the DP size is 1), torch.distributed.all_reduce falls back to the default process group, i.e. the whole world. In a multi-parallel setup (e.g., TP > 1), the token counts are then summed across every rank in the world, which inflates batch_num_tokens and makes the subsequent loss scaling incorrect.

        token_mask = data.get("loss_mask", data.get("response_mask"))
        if token_mask is None:
            raise KeyError("Data must contain 'loss_mask' or 'response_mask' for token counting.")
        batch_num_tokens = token_mask.sum().to(get_device_id())
        dp_group = self.get_data_parallel_group()
        if self.get_data_parallel_size() > 1 and dp_group is not None:
            torch.distributed.all_reduce(
                batch_num_tokens,
                op=torch.distributed.ReduceOp.SUM,
                group=dp_group,
            )
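
To make the inflation concrete, here is a small pure-Python simulation rather than a real distributed run. It assumes ranks are laid out as (dp_rank, tp_rank) and that all TP ranks in one DP replica hold the same local token count; `summed_token_count` and the numbers are hypothetical and exist only for illustration.

```python
from itertools import product


def summed_token_count(dp_size, tp_size, tokens_per_dp_rank, reduce_over_world):
    """Simulate a SUM all-reduce of per-rank token counts.

    Ranks are indexed as (dp_rank, tp_rank). All TP ranks inside one DP
    replica see the same data, so they start from the same local count.
    """
    ranks = list(product(range(dp_size), range(tp_size)))
    local = {(dp, tp): tokens_per_dp_rank[dp] for dp, tp in ranks}
    reduced = {}
    for dp, tp in ranks:
        if reduce_over_world:
            # Fallback behaviour: every rank in the world takes part.
            group = ranks
        else:
            # Intended behaviour: only ranks with the same TP index.
            group = [(d, tp) for d in range(dp_size)]
        reduced[(dp, tp)] = sum(local[r] for r in group)
    return reduced


# DP size 1, TP size 4, 100 masked tokens in the batch.
dp_only = summed_token_count(1, 4, [100], reduce_over_world=False)
world = summed_token_count(1, 4, [100], reduce_over_world=True)
print(set(dp_only.values()), set(world.values()))  # {100} {400}
```

With DP size 1 and TP size 4, the world-wide sum reports four times the real token count, so any loss scale derived from it would be four times too small.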

            raise ValueError("BumblebeeEngine PP/CP SFT requires batch_num_tokens for VERL-compatible loss scaling.")
        if batch_num_tokens <= 0:
            raise ValueError(f"batch_num_tokens must be positive, got {batch_num_tokens}.")
        loss_scale = self.get_data_parallel_size() * num_micro_batches / float(batch_num_tokens)

high

The loss_scale is calculated as a constant for all micro-batches in the batch. This calculation assumes that every micro-batch contains the same number of tokens (i.e., batch_num_tokens / (DP_SIZE * num_micro_batches)). However, since this engine explicitly supports no_padding mode (THD), micro-batches will likely have varying numbers of tokens.

If the loss_function returns a mean loss (which is standard in verl), using a constant scale will overweight micro-batches with fewer tokens and underweight those with more, leading to mathematically incorrect gradients. The scale for each micro-batch should ideally be tokens_in_micro_batch / total_tokens_in_global_batch.
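
A minimal numeric sketch of the weighting problem, assuming a single data-parallel rank, a loss function that returns a per-micro-batch mean, and made-up per-token losses. The helper names are hypothetical and it does not model the engine's PP/CP scaling.

```python
import math


def global_token_mean(micro_batches):
    """Reference value: mean loss over every token in the batch."""
    total_loss = sum(sum(mb) for mb in micro_batches)
    total_tokens = sum(len(mb) for mb in micro_batches)
    return total_loss / total_tokens


def equal_weight_loss(micro_batches):
    """Each micro-batch mean gets the same weight, regardless of its length."""
    n = len(micro_batches)
    return sum((sum(mb) / len(mb)) / n for mb in micro_batches)


def token_weighted_loss(micro_batches):
    """Each micro-batch mean is weighted by tokens_in_micro_batch / total_tokens."""
    total_tokens = sum(len(mb) for mb in micro_batches)
    return sum((sum(mb) / len(mb)) * (len(mb) / total_tokens) for mb in micro_batches)


# Two unpadded micro-batches: 2 tokens with loss 1.0, 6 tokens with loss 4.0.
micro_batches = [[1.0, 1.0], [4.0] * 6]

print(global_token_mean(micro_batches))    # 3.25
print(equal_weight_loss(micro_batches))    # 2.5
print(token_weighted_loss(micro_batches))  # 3.25
```

The short micro-batch pulls the equal-weight result down to 2.5, while the token-weighted sum reproduces the global per-token mean of 3.25.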

Comment on lines +658 to +659
            if response_tokens:
                full_mask[-response_tokens:] = row_mask[:response_tokens]

high

The logic full_mask[-response_tokens:] = row_mask[:response_tokens] is likely incorrect for standard verl datasets. In verl, the loss_mask is typically aligned with the input_ids. If row_mask is a padded tensor, the response tokens (where the mask is 1) are usually at the end of the valid sequence length, not at the beginning of the padded row.

For example, if row_mask is [0, 0, 1, 1, 1] (3 response tokens), row_mask[:3] would be [0, 0, 1], resulting in a full_mask of [0, 0, 0, 0, 1] instead of the expected [0, 0, 1, 1, 1]. A more robust approach is to extract the non-zero elements of the mask.

            active_mask = row_mask[row_mask != 0]
            if active_mask.numel() > seq_len:
                raise ValueError(
                    f"response loss mask has {active_mask.numel()} tokens but packed input sequence has {seq_len} tokens"
                )
            full_mask = torch.zeros(seq_len, dtype=row_mask.dtype, device=row_mask.device)
            if active_mask.numel():
                full_mask[-active_mask.numel():] = active_mask
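
The example above can be reproduced with plain Python lists standing in for tensors. This is only a sketch: `align_mask_prefix` and `align_mask_active` are hypothetical names, and list slicing is used in place of the torch indexing in the suggestion.

```python
def align_mask_prefix(row_mask, seq_len, response_tokens):
    """Current behaviour: copy the first `response_tokens` entries to the tail."""
    full_mask = [0] * seq_len
    if response_tokens:
        full_mask[-response_tokens:] = row_mask[:response_tokens]
    return full_mask


def align_mask_active(row_mask, seq_len):
    """Suggested behaviour: keep only the non-zero entries and right-align them."""
    active = [m for m in row_mask if m != 0]
    if len(active) > seq_len:
        raise ValueError(
            f"response loss mask has {len(active)} tokens "
            f"but packed input sequence has {seq_len} tokens"
        )
    full_mask = [0] * seq_len
    if active:
        full_mask[-len(active):] = active
    return full_mask


row_mask = [0, 0, 1, 1, 1]  # 3 response tokens at the end of the row
print(align_mask_prefix(row_mask, seq_len=5, response_tokens=3))  # [0, 0, 0, 0, 1]
print(align_mask_active(row_mask, seq_len=5))                     # [0, 0, 1, 1, 1]
```

Filtering the non-zero entries also drops any zeros that sit between response tokens, so this sketch assumes the response span is contiguous.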

            return loss_mask

        rows = []
        for seq_ids, row_mask in zip(input_ids.unbind(0), loss_mask, strict=True):

medium

The use of strict=True in zip() is only supported in Python 3.10 and later. While many modern LLM training environments use 3.10+, verl and its dependencies sometimes target Python 3.8 or 3.9. Please ensure that the minimum supported Python version for this recipe is consistent with the project's requirements, or remove strict=True for broader compatibility.

Suggested change
-        for seq_ids, row_mask in zip(input_ids.unbind(0), loss_mask, strict=True):
+        for seq_ids, row_mask in zip(input_ids.unbind(0), loss_mask):
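
If the length check is worth keeping on older interpreters, one option is a small helper instead of dropping the check entirely. The `zip_strict` function below is a hypothetical sketch, not part of verl; it materializes its inputs, which is fine for a batch of rows but not for unbounded iterators.

```python
def zip_strict(*iterables):
    """Zip the inputs, raising ValueError if their lengths differ.

    Works on Python versions that predate zip(..., strict=True).
    """
    seqs = [list(it) for it in iterables]
    lengths = sorted({len(s) for s in seqs})
    if len(lengths) > 1:
        raise ValueError(f"zip_strict() arguments have different lengths: {lengths}")
    return zip(*seqs)


pairs = list(zip_strict([1, 2], ["a", "b"]))
print(pairs)  # [(1, 'a'), (2, 'b')]
```

Unlike the built-in strict mode, which raises only once iteration reaches the mismatch, this helper raises before yielding any pair.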

@ISEEKYAN marked this pull request as draft on May 11, 2026 14:50