Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions verl/workers/engine/fsdp/transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,14 +1102,19 @@ def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict, lo
# same model_output keys as the eager logit-processor path.
if distillation_use_topk:
aux_outputs = getattr(output, "fused_linear_aux", None)
if aux_outputs is not None and aux_outputs.distillation_losses is not None:
cu_seqlens = input_ids.offsets()
for field_name in ("distillation_losses", "student_mass", "teacher_mass"):
v = getattr(aux_outputs, field_name).squeeze(0)
if self.use_ulysses_sp:
pad_size = output_args["pad_size"]
v = gather_outputs_and_unpad(v, gather_dim=0, unpad_dim=0, padding_size=pad_size)
model_output[field_name] = torch.nested.nested_tensor_from_jagged(v, cu_seqlens)
assert aux_outputs is not None and aux_outputs.distillation_losses is not None, (
"distillation_use_topk=True requires the model.forward to be invoked with "
"teacher_topk_ids/teacher_topk_log_probs (see VeOmniEngineWithLMHead."
"prepare_model_inputs) so VeOmni's chunk_topk_distill kernel populates "
"fused_linear_aux.distillation_losses / student_mass / teacher_mass."
)
cu_seqlens = input_ids.offsets()
for field_name in ("distillation_losses", "student_mass", "teacher_mass"):
v = getattr(aux_outputs, field_name).squeeze(0)
if self.use_ulysses_sp:
pad_size = output_args["pad_size"]
v = gather_outputs_and_unpad(v, gather_dim=0, unpad_dim=0, padding_size=pad_size)
model_output[field_name] = torch.nested.nested_tensor_from_jagged(v, cu_seqlens)
else:
logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
logits_rmpad.div_(temperature_rmpad.clamp(min=1e-8).unsqueeze(-1).to(logits_rmpad.dtype))
Expand Down
17 changes: 15 additions & 2 deletions verl/workers/engine/veomni/transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,8 +876,18 @@ def prepare_model_inputs(self, micro_batch: TensorDict):
"both must be provided together for fused top-K distillation."
)
# Kernel kwarg names follow veomni's chunk_topk_distill_function API.
teacher_topk_ids = micro_batch["teacher_ids"].values().unsqueeze(0)
teacher_topk_log_probs = micro_batch["teacher_logprobs"].values().unsqueeze(0)
teacher_ids = micro_batch["teacher_ids"]
teacher_logprobs = micro_batch["teacher_logprobs"]
if teacher_ids.is_nested:
teacher_topk_ids = teacher_ids.values().unsqueeze(0)
teacher_topk_log_probs = teacher_logprobs.values().unsqueeze(0)
else:
# Tensors may already be in the rmpad [1, total_nnz, K]
# layout expected by VeOmni (for example when the caller has
# preprocessed the distillation batch). Avoid assuming a
# NestedTensor-only representation.
teacher_topk_ids = teacher_ids
teacher_topk_log_probs = teacher_logprobs
Comment on lines +881 to +890

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

If teacher_ids is a 2D tensor of shape [total_nnz, K] (instead of a 3D tensor of shape [1, total_nnz, K]), assigning it directly to teacher_topk_ids without unsqueezing will cause sequence parallel slicing (slice_input_tensor(..., dim=1)) to slice along the K (top-k) dimension instead of the sequence dimension (total_nnz). To prevent this critical correctness bug, ensure that 2D tensors are unsqueezed to 3D.

Suggested change
if teacher_ids.is_nested:
teacher_topk_ids = teacher_ids.values().unsqueeze(0)
teacher_topk_log_probs = teacher_logprobs.values().unsqueeze(0)
else:
# Tensors may already be in the rmpad [1, total_nnz, K]
# layout expected by VeOmni (for example when the caller has
# preprocessed the distillation batch). Avoid assuming a
# NestedTensor-only representation.
teacher_topk_ids = teacher_ids
teacher_topk_log_probs = teacher_logprobs
if teacher_ids.is_nested:
teacher_topk_ids = teacher_ids.values().unsqueeze(0)
teacher_topk_log_probs = teacher_logprobs.values().unsqueeze(0)
else:
# Tensors may already be in the rmpad [1, total_nnz, K]
# layout expected by VeOmni (for example when the caller has
# preprocessed the distillation batch). Avoid assuming a
# NestedTensor-only representation.
teacher_topk_ids = teacher_ids.unsqueeze(0) if teacher_ids.dim() == 2 else teacher_ids
teacher_topk_log_probs = teacher_logprobs.unsqueeze(0) if teacher_logprobs.dim() == 2 else teacher_logprobs

# SP-slice along seqlen (dim=1); teacher tensors are 3D
# (1, total_nnz, K) so use slice_input_tensor directly —
# ulysses_pad_and_slice_inputs hardcodes 2D.
Expand All @@ -888,6 +898,9 @@ def prepare_model_inputs(self, micro_batch: TensorDict):
teacher_topk_log_probs = slice_input_tensor(teacher_topk_log_probs, dim=1, padding=True)
model_inputs["teacher_topk_ids"] = teacher_topk_ids
model_inputs["teacher_topk_log_probs"] = teacher_topk_log_probs
clamp = tu.get_non_tensor_data(data=micro_batch, key="log_prob_min_clamp", default=None)
if clamp is not None:
model_inputs["log_prob_min_clamp"] = clamp

# Router replay plumbing. Two responsibilities:
# (1) snapshot the ulysses pad_size for this micro-batch so
Expand Down