# [model, trainer, engine] feat: prefix-tree MAGI attention for verl SFT and RL (draft)#6689
# [model, trainer, engine] feat: prefix-tree MAGI attention for verl SFT and RL (draft)#6689arvyanh wants to merge 37 commits into
Conversation
Resolves four classes of issues blocking pre-commit on this branch:
1. Device-API usage (check_device_api_usage):
- Replace torch.cuda.* with verl/utils/device.py's get_torch_device()
across model_forward.py (14 occurrences) and prefix_tree_merge.py
(4 occurrences).
- Required by verl rule that device handles go through the unified
API so NPU/CUDA stay portable.
2. License headers (check_license):
- prefix_tree_params.py and prefix_tree_utils.py had a single-line
"Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES" header which
doesn't match any of the recognised license blocks.
- magi_patch.py had no header at all.
- All three now carry the standard Bytedance Apache 2.0 header.
3. Ruff lint:
- F841: drop unused local `node_start_in_sample` in prefix_tree_magi.py
- B905: zip-without-explicit-strict (auto-fixed by ruff)
- UP037: quoted-annotation (auto-fixed by ruff)
- E501: split an over-long log line in model_forward.py
4. Ruff format:
- Auto-formatted yumingxuan-added files: model_forward.py,
prefix_tree_merge.py, prefix_tree_magi.py, prefix_tree_params.py,
prefix_tree_utils.py, magi_patch.py, ray_trainer.py, sft_trainer.py,
sft_trainer_ray.py, multiturn_sft_dataset.py, test_prefix_tree_magi.py.
- No semantic change; whitespace / quote style / import order only.
After this commit `pre-commit run --all-files` is clean on the branch.
Fix pre-commit failures on meituan/verl_prefix_tree_full base
Refactors prefix-tree detection into a symmetric two-path dispatcher and
adds a new token-by-token trie path supporting arbitrary tree depth.
Before: build_prefix_tree_micro_batch inlined the hash-based detection
and used a custom _build_multilevel_prefix_tree_params that produced a
PrefixTreeParams directly. The dynamic-trie path (if added inline) would
need different glue code, leaving the two paths asymmetric and harder to
extend.
After: both paths produce the same (TreeNode, leaf_to_sample) contract
and share a single layout helper:
samples = _unpack_nested_to_list(input_ids)
if dynamic_trie:
result = build_tree_dynamic(samples)
else:
result = build_tree_hash_based(samples, prefix_segments_batch=...)
if result is None:
return None
tree_root, leaf_to_sample = result
params = build_layout_from_tree_node(samples, tree_root, leaf_to_sample, ...)
return _finalize_prefix_tree_batch(params, ...)
File layout:
- verl/utils/prefix_tree_dynamic.py (new): trie algorithm (TrieNode,
greedy_build_tries, convert_trie_to_tree_node) + build_tree_dynamic
returning (TreeNode, leaf_to_sample).
- verl/utils/prefix_tree_hash_based.py (new): hash detection extracted
from prefix_tree_magi.py. Exposes build_tree_hash_based with the same
(TreeNode, leaf_to_sample) return type, plus _hash_prefix and
build_prefix_segments_single_turn for dataset / trainer callers.
- verl/utils/prefix_tree_magi.py: now holds dispatcher
(build_prefix_tree_micro_batch), restore_flat_to_nested,
_finalize_prefix_tree_batch, and the MAGI/flex key builders.
Re-exports _hash_prefix and build_prefix_segments_single_turn for
backwards compatibility with existing callers (multiturn_sft_dataset,
ray_trainer, tests).
- verl/utils/prefix_tree_utils.py: add build_layout_from_tree_node helper
used by both paths to realise a TreeNode into a PrefixTreeParams; add
__all__ entry.
- verl/models/mcore/model_forward.py: thread prefix_tree_dynamic flag
from logits_processor_args through to build_prefix_tree_micro_batch.
Trade-offs of the trie path: 10-300x slower at detection but
<250 ms absolute on typical RL workloads, and supports arbitrary depth
(required for MCTS-based RL where the hash path's depth-2 cap loses
sharing structure). No rollout-side metadata required.
Tests:
- tests/utils/test_prefix_tree_dynamic.py (new, 13 tests): trie
construction, TreeNode conversion, end-to-end token-conservation
checks at depth-2 and depth-3.
Algorithm originally derived from AReaL (inclusionAI/AReaL).
Add dynamic-trie prefix-tree detection + symmetrize hash-based path
- Remove 8 dead flat-layout functions from utils.py - Rename build_multilevel_flex_spec → build_prefix_tree_attention_spec - Remove multilevel field from PrefixTreeParams - Add trie_dfs_leaf_order, trie_to_leaf_ids, mbs_groups_from_trie, prune_trie - Build trie once in prepare_micro_batches, DFS sort, store leaf_id metadata - Reuse trie in prepare_prefix_tree_micro_batches via mbs_groups_from_trie - Update test_magi.py for trie-based API - Add test_trie.py with 20 unit tests
- 1/2/4 3-layer tree: 7 leaves, 25 attention rects verified - prune_trie to 4 leaves: reduced rect count, child_b has 1 leaf - A/AB/ABC nested prefixes: zero-length leaves skipped correctly
- dispatch/undispatch are CP scatter/gather, not token merge - Skipping them in middle layers breaks CP>1 - Remove no_expand_middle from magi_attn_forward and _tl_forward - Fix TensorDict.non_tensor_batch (does not exist) → set_non_tensor_data
- Delete hash_based.py: move _hash_prefix + build_prefix_segments_single_turn → magi.py - Delete params.py: merge PrefixTreeParams + RangeSpec → utils.py - Delete balancing.py: merge load balancing functions → dynamic.py - Remove _hash_prefix_ids usage from multiturn_sft_dataset (trie handles detection) - Update all external imports - Keep trainer.py as-is
- Remove no_expand_middle from config (hf_model.yaml, model.py), PrefixTreeMagiBatch, get_prefix_tree_kwargs, build_prefix_tree_batch, engine_workers.py - Remove inject_prefix_segments from trainer.py + ray_trainer.py - Remove _hash_prefix, build_prefix_segments_single_turn from magi.py - Remove hash_prefix, build_prefix_segments_single_turn from __init__.py - Remove TestPrefixSegmentsPrior class from test_magi.py (hash-based tests) - Regenerate _generated_*.yaml configs
- prepare_prefix_tree_micro_batches prunes global trie per micro-batch and stores subtree via non_tensor_data - forward_step threads subtree through logits_processor_args - build_prefix_tree_batch passes cached_result to build_prefix_tree_micro_batch - build_prefix_tree_micro_batch accepts cached_result, skips build_tree_dynamic when provided (backward-compatible)
When a sample's tokens are entirely covered by ancestor nodes (e.g. [1,2] as a prefix of [1,2,3,4]), the sample terminated mid-tree but was not included in leaf_to_sample. Now adds zero-length leaf nodes for such samples.
There was a problem hiding this comment.
Code Review
This pull request introduces prefix-tree shared-prefix deduplication for actor and SFT training, adding dynamic-trie prefix-tree building, MAGI/flex attention integration, and trainer load-balancing utilities. Feedback on the changes highlights a critical logic modification in the non-prefix-tree fallback path within model_forward.py where the condition for passing position_ids was altered, as well as the high maintainability risks associated with the extensive monkey-patching of Megatron-LM core components in prefix_tree_merge.py.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| output_orig = model( | ||
| input_ids=input_ids_rmpad, | ||
| attention_mask=attention_mask, | ||
| position_ids=position_ids_rmpad if not vision_model else None, |
There was a problem hiding this comment.
The condition for passing position_ids to the model has been changed from mtp_enable_train to not vision_model. This is a significant logic change in the non-prefix-tree fallback path and appears unrelated to the main feature of this PR. This could have unintended consequences on model correctness or performance for non-vision models that are not using MTP, as they will now receive position_ids where they previously did not. Furthermore, the original comment # position_ids is only needed for MTP is now outdated and misleading. This change needs to be justified, and if it is an intentional bug fix, it should be explained and ideally handled in a separate PR.
| position_ids=position_ids_rmpad if not vision_model else None, | |
| position_ids=position_ids_rmpad if mtp_enable_train else None, # position_ids is only needed for MTP |
| def apply_prefix_tree_patch() -> None: | ||
| """Monkey-patch upstream Megatron-LM classes to support prefix-tree attention (flex and MAGI). | ||
|
|
||
| Safe to call multiple times — subsequent calls are no-ops (checks for the | ||
| ``_magi_patched`` sentinel attribute). | ||
| """ |
There was a problem hiding this comment.
The extensive monkey-patching of multiple Megatron-LM core components (GPTModel, TransformerBlock, TransformerLayer, SelfAttention, TEDotProductAttention, RotaryEmbedding) introduces a significant maintainability risk. This approach is fragile and highly susceptible to breaking with future updates to the upstream Megatron-LM library. While the implementation is carefully written with sentinels to prevent re-patching and try...finally blocks to restore methods, a breaking change in the upstream API could lead to silent failures or hard-to-debug issues.
Consider exploring less intrusive integration methods, such as subclassing or dependency injection where possible.
The MAGI kernel computes RoPE from RotaryEmbedding.forward() which creates sequential cos/sin table for positions 0..T-1. But per-sample positions should be prefix:0..P-1, each leaf:P..P+L_i (starting from prefix_end, matching restore_flat_to_nested order).
The MAGI-CMP diagnostic runs a second full FA3 forward after MAGI OLP, causing ~2x OLP time. Keep the code as an unstaged patch for debugging. Co-authored-by: Claude To re-apply: git apply /tmp/magi_cmp.patch
…p fallback params
What does this PR do?
The RFC: #6401
Adds prefix-tree shared-prefix deduplication for verl SFT and GRPO training using MAGI attention. Token-by-token trie detection discovers shared prefixes across rollout samples without requiring rollout-side metadata. The flat token layout + attention rectangle spec is dispatched to MAGI
calc_attn, which computes correct prefix-tree attention patterns while internally deduplicating shared KV tokens.Key features
build_tree_dynamicbuilds a compressed trie from input tokens — no prior knowledge of turn boundaries needed.TEDotProductAttentionthrough theSelfAttention → TransformerLayer → GPTModelchain to inject MAGI/flex attention.Design
This PR implements the design described in RFC #6401. All
nGRPO rollout samples are packed into a single flat[prefix | leaf_0 | ... | leaf_{n-1}]sequence, run through one transformer forward pass, with cross-leaf attention blocked. The result is mathematically equivalent tonindependent forwards.Attention rectangle spec
The attention pattern is encoded as rectangle specs:
MAGI interprets these rectangles natively.
Prefix detection (dynamic trie)
Prefix detection runs entirely at train time — no rollout-side metadata required.
build_tree_dynamicperforms token-by-token compressed trie insertion over the micro-batch's input sequences. The trie is converted to aTreeNodetree, then fed tobuild_prefix_tree_attention_spec, which emits(q_ranges, k_ranges, mask_types)for every node.but in case we could retrieve the tree at runtime, we could provide the trie data to the traininer
The tree generalizes to arbitrary-depth multi-level trees. For example, in multi-turn agent RL where responses branch into sub-groups sharing a turn-2 prefix:
This produces a depth-3 tree with
turn0as root,turn1_A/turn1_Cas intermediate nodes, and four leaves. Zero-length leaf nodes are inserted when a sample terminates at an intermediate node.Megatron integration
The patch chain injects
magi_attention_keythrough upstream Megatron-LM:File summary
verl/utils/prefix_tree/dynamic.pyverl/utils/prefix_tree/utils.pyverl/utils/prefix_tree/magi.pyverl/utils/prefix_tree/trainer.pyverl/models/mcore/prefix_tree_merge.pyResult
[TODO]
Test
CPU unit tests (59 tests, all pass)
python3 prefix_script/test/test_prefix_tree_full.py # temporary locationGPU training
Validated on Megatron-LM training with GRPO on CoQA — prefix-tree MAGI attention runs correctly with CP=2, SP=1, PP=1.
Checklist Before Starting
[{modules}] {type}: {description}