[fsdp, veomni] fix: wire fused top-k distillation outputs#6737
[fsdp, veomni] fix: wire fused top-k distillation outputs#6737zhangxin81 wants to merge 2 commits into
Conversation
Forward top-k distillation teacher tensors and clamp settings into VeOmni fused log-prob forwards, and fail closed when fused auxiliary distillation outputs are missing. Tests: PYTHONPYCACHEPREFIX=/private/tmp/verl-pycache python3 -m py_compile verl/workers/engine/fsdp/transformer_impl.py verl/workers/engine/veomni/transformer_impl.py Tests: /Users/bytedance/Documents/VeOmni/.venv/bin/python /private/tmp/test_veomni_topk_integration.py Co-authored-by: OpenAI Codex <codex@openai.com> Signed-off-by: zhangxin81 <115389973+zhangxin81@users.noreply.github.com>
There was a problem hiding this comment.
Code Review
This pull request adds support for non-nested teacher tensors in the VeOmni transformer implementation, forwards the 'log_prob_min_clamp' configuration from the micro-batch to the model inputs, and introduces an assertion in the FSDP transformer implementation to ensure distillation outputs are populated when 'distillation_use_topk' is enabled. Feedback is provided regarding a potential bug where 2D teacher tensors are not unsqueezed to 3D, which would cause sequence parallel slicing to slice along the wrong dimension. A code suggestion is provided to safely handle 2D tensors.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if teacher_ids.is_nested: | ||
| teacher_topk_ids = teacher_ids.values().unsqueeze(0) | ||
| teacher_topk_log_probs = teacher_logprobs.values().unsqueeze(0) | ||
| else: | ||
| # Tensors may already be in the rmpad [1, total_nnz, K] | ||
| # layout expected by VeOmni (for example when the caller has | ||
| # preprocessed the distillation batch). Avoid assuming a | ||
| # NestedTensor-only representation. | ||
| teacher_topk_ids = teacher_ids | ||
| teacher_topk_log_probs = teacher_logprobs |
There was a problem hiding this comment.
If teacher_ids is a 2D tensor of shape [total_nnz, K] (instead of a 3D tensor of shape [1, total_nnz, K]), assigning it directly to teacher_topk_ids without unsqueezing will cause sequence parallel slicing (slice_input_tensor(..., dim=1)) to slice along the K (top-k) dimension instead of the sequence dimension (total_nnz). To prevent this critical correctness bug, ensure that 2D tensors are unsqueezed to 3D.
| if teacher_ids.is_nested: | |
| teacher_topk_ids = teacher_ids.values().unsqueeze(0) | |
| teacher_topk_log_probs = teacher_logprobs.values().unsqueeze(0) | |
| else: | |
| # Tensors may already be in the rmpad [1, total_nnz, K] | |
| # layout expected by VeOmni (for example when the caller has | |
| # preprocessed the distillation batch). Avoid assuming a | |
| # NestedTensor-only representation. | |
| teacher_topk_ids = teacher_ids | |
| teacher_topk_log_probs = teacher_logprobs | |
| if teacher_ids.is_nested: | |
| teacher_topk_ids = teacher_ids.values().unsqueeze(0) | |
| teacher_topk_log_probs = teacher_logprobs.values().unsqueeze(0) | |
| else: | |
| # Tensors may already be in the rmpad [1, total_nnz, K] | |
| # layout expected by VeOmni (for example when the caller has | |
| # preprocessed the distillation batch). Avoid assuming a | |
| # NestedTensor-only representation. | |
| teacher_topk_ids = teacher_ids.unsqueeze(0) if teacher_ids.dim() == 2 else teacher_ids | |
| teacher_topk_log_probs = teacher_logprobs.unsqueeze(0) if teacher_logprobs.dim() == 2 else teacher_logprobs |
Luosuu
left a comment
There was a problem hiding this comment.
I am not sure whether this is necessary as if the user wants to use this feature then the model engine backend should be veomni?
Cover fused distillation auxiliary outputs and VeOmni teacher top-k passthrough for both nested and pre-rmpad tensor layouts. Tests: /private/tmp/uv-cache/archive-v0/XkW9QpETjgABHVawsIv1F/ruff-0.13.2.data/scripts/ruff check tests/workers/test_distillation_topk_symmetry_on_cpu.py tests/workers/test_router_replay_engine_helpers_on_cpu.py verl/workers/engine/fsdp/transformer_impl.py verl/workers/engine/veomni/transformer_impl.py Tests: /private/tmp/uv-cache/archive-v0/XkW9QpETjgABHVawsIv1F/ruff-0.13.2.data/scripts/ruff format --check tests/workers/test_distillation_topk_symmetry_on_cpu.py tests/workers/test_router_replay_engine_helpers_on_cpu.py verl/workers/engine/fsdp/transformer_impl.py verl/workers/engine/veomni/transformer_impl.py Tests: PYTHONPYCACHEPREFIX=/private/tmp/verl-pycache python3 -m py_compile tests/workers/test_distillation_topk_symmetry_on_cpu.py tests/workers/test_router_replay_engine_helpers_on_cpu.py verl/workers/engine/fsdp/transformer_impl.py verl/workers/engine/veomni/transformer_impl.py Tests: /Users/bytedance/Documents/VeOmni/.venv/bin/python /private/tmp/test_veomni_topk_integration.py Note: git commit hook was bypassed because pre-commit hook environment installation repeatedly failed while downloading ruff==0.12.2 with ConnectionResetError. Co-authored-by: OpenAI Codex <codex@openai.com> Signed-off-by: zhangxin81 <115389973+zhangxin81@users.noreply.github.com>
What does this PR do?
This PR fixes the VeOmni fused-kernel path for top-k forward-KL distillation.
When
distillation_use_topk=Truewithuse_fused_kernels=True, VeOmni's patched causal LM loss can compute per-token top-k distillation auxiliary outputs throughchunk_topk_distill_function. This PR wires the missing pieces so those outputs are correctly produced and consumed by verl:teacher_ids/teacher_logprobsinto VeOmni model forward asteacher_topk_ids/teacher_topk_log_probs;log_prob_min_clamp;fused_linear_auxdoes not contain the expected distillation outputs;Duplicate-work check:
gh pr list --repo verl-project/verl --state open --search "veomni topk distillation"returns only this PR (#6737).gh pr list --repo verl-project/verl --state open --search "VeOmni fused top-k distillation"returns only this PR (#6737).gh pr list --repo verl-project/verl --state open --search "teacher_topk_log_probs"returns only this PR (#6737).AI assistance was used to prepare this change. I reviewed the changed lines and ran the checks listed below.
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,veomni,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,cfg,reward,fully_async,one_step_off,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
Passed:
The targeted stub test passed 4 cases:
Attempted but blocked by local environment/network:
This failed during hook environment installation while downloading
ruff==0.12.2with repeatedConnectionResetError. To partially cover this, I ran cachedruff 0.13.2on the changed files as shown above.This failed during collection because the ad-hoc local
.venvlacksray. Installing the full dependency set was also interrupted by networkConnectionResetError. The PR adds CPU regression tests that should run in the normal verl CI environment.API and Usage Example
No public API change.
This change affects the existing internal distillation path when the batch metadata enables:
and the micro-batch contains:
Design & Code Changes
verl/workers/engine/veomni/transformer_impl.pyteacher_ids/teacher_logprobsinto VeOmni kernel argument names:teacher_topk_idsteacher_topk_log_probslog_prob_min_clamp.verl/workers/engine/fsdp/transformer_impl.pyoutput.fused_linear_aux.distillation_lossesto exist.distillation_lossesstudent_massteacher_massTests
tests/workers/test_distillation_topk_symmetry_on_cpu.py.tests/workers/test_router_replay_engine_helpers_on_cpu.py.Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysruff==0.12.2withConnectionResetError. Changed files were checked with cachedruff 0.13.2andpy_compile.ci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)recipesubmodule, please also update the reference to the submodule commit viagit submodule update --remoteorcd recipe && git pull origin main.recipe.