Skip to content

# [model, trainer, engine] feat: prefix-tree MAGI attention for verl SFT and RL (draft)#6689

Draft
arvyanh wants to merge 37 commits into
verl-project:mainfrom
meituan-search:verl_prefix_tree_full
Draft

# [model, trainer, engine] feat: prefix-tree MAGI attention for verl SFT and RL (draft)#6689
arvyanh wants to merge 37 commits into
verl-project:mainfrom
meituan-search:verl_prefix_tree_full

Conversation

@arvyanh

@arvyanh arvyanh commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

The RFC: #6401

Adds prefix-tree shared-prefix deduplication for verl SFT and GRPO training using MAGI attention. Token-by-token trie detection discovers shared prefixes across rollout samples without requiring rollout-side metadata. The flat token layout + attention rectangle spec is dispatched to MAGI calc_attn, which computes correct prefix-tree attention patterns while internally deduplicating shared KV tokens.

Key features

  • Dynamic trie detection: build_tree_dynamic builds a compressed trie from input tokens — no prior knowledge of turn boundaries needed.
  • Multi-level tree support: handles multi-turn rollout sharing with zero-length leaf nodes for samples that terminate at intermediate levels.
  • MAGI integration: patches Megatron-LM's TEDotProductAttention through the SelfAttention → TransformerLayer → GPTModel chain to inject MAGI/flex attention.
  • prefix-length aware dynamic bsz grouping and mbs reordering
  • Configurable old-log-prob backend: Currently the old_log_prob's log_prob seems to diverge

Design

This PR implements the design described in RFC #6401. All n GRPO rollout samples are packed into a single flat [prefix | leaf_0 | ... | leaf_{n-1}] sequence, run through one transformer forward pass, with cross-leaf attention blocked. The result is mathematically equivalent to n independent forwards.

Attention rectangle spec

The attention pattern is encoded as rectangle specs:

          k: prefix    k: leaf0    k: leaf1
q: prefix   causal       ✗           ✗
q: leaf0     full      causal         ✗
q: leaf1     full        ✗         causal

MAGI interprets these rectangles natively.

Prefix detection (dynamic trie)

Prefix detection runs entirely at train time — no rollout-side metadata required. build_tree_dynamic performs token-by-token compressed trie insertion over the micro-batch's input sequences. The trie is converted to a TreeNode tree, then fed to build_prefix_tree_attention_spec, which emits (q_ranges, k_ranges, mask_types) for every node.

but in case we could retrieve the tree at runtime, we could provide the trie data to the traininer

The tree generalizes to arbitrary-depth multi-level trees. For example, in multi-turn agent RL where responses branch into sub-groups sharing a turn-2 prefix:

S0: turn0 + turn1_A + turn2_B1
S1: turn0 + turn1_A + turn2_B2
S2: turn0 + turn1_C + turn2_D1
S3: turn0 + turn1_C + turn2_D2

This produces a depth-3 tree with turn0 as root, turn1_A/turn1_C as intermediate nodes, and four leaves. Zero-length leaf nodes are inserted when a sample terminates at an intermediate node.

Megatron integration

The patch chain injects magi_attention_key through upstream Megatron-LM:

GPTModel.forward → TransformerBlock → TransformerLayer
    → SelfAttention → TEDotProductAttention (magi_attn_forward: dispatch → calc_attn → undispatch)

File summary

Area Path
Trie + load balancing verl/utils/prefix_tree/dynamic.py
Layout + attention spec verl/utils/prefix_tree/utils.py
MAGI key + batch verl/utils/prefix_tree/magi.py
Trainer helpers verl/utils/prefix_tree/trainer.py
Megatron patch verl/models/mcore/prefix_tree_merge.py

Result

[TODO]

Test

CPU unit tests (59 tests, all pass)

python3 prefix_script/test/test_prefix_tree_full.py  # temporary location

GPU training

Validated on Megatron-LM training with GRPO on CoQA — prefix-tree MAGI attention runs correctly with CP=2, SP=1, PP=1.

Checklist Before Starting

arvyanh and others added 25 commits May 20, 2026 10:56
Resolves four classes of issues blocking pre-commit on this branch:

1. Device-API usage (check_device_api_usage):
   - Replace torch.cuda.* with verl/utils/device.py's get_torch_device()
     across model_forward.py (14 occurrences) and prefix_tree_merge.py
     (4 occurrences).
   - Required by verl rule that device handles go through the unified
     API so NPU/CUDA stay portable.

2. License headers (check_license):
   - prefix_tree_params.py and prefix_tree_utils.py had a single-line
     "Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES" header which
     doesn't match any of the recognised license blocks.
   - magi_patch.py had no header at all.
   - All three now carry the standard Bytedance Apache 2.0 header.

3. Ruff lint:
   - F841: drop unused local `node_start_in_sample` in prefix_tree_magi.py
   - B905: zip-without-explicit-strict (auto-fixed by ruff)
   - UP037: quoted-annotation (auto-fixed by ruff)
   - E501: split an over-long log line in model_forward.py

4. Ruff format:
   - Auto-formatted yumingxuan-added files: model_forward.py,
     prefix_tree_merge.py, prefix_tree_magi.py, prefix_tree_params.py,
     prefix_tree_utils.py, magi_patch.py, ray_trainer.py, sft_trainer.py,
     sft_trainer_ray.py, multiturn_sft_dataset.py, test_prefix_tree_magi.py.
   - No semantic change; whitespace / quote style / import order only.

After this commit `pre-commit run --all-files` is clean on the branch.
Fix pre-commit failures on meituan/verl_prefix_tree_full base
Refactors prefix-tree detection into a symmetric two-path dispatcher and
adds a new token-by-token trie path supporting arbitrary tree depth.

Before: build_prefix_tree_micro_batch inlined the hash-based detection
and used a custom _build_multilevel_prefix_tree_params that produced a
PrefixTreeParams directly. The dynamic-trie path (if added inline) would
need different glue code, leaving the two paths asymmetric and harder to
extend.

After: both paths produce the same (TreeNode, leaf_to_sample) contract
and share a single layout helper:

    samples = _unpack_nested_to_list(input_ids)
    if dynamic_trie:
        result = build_tree_dynamic(samples)
    else:
        result = build_tree_hash_based(samples, prefix_segments_batch=...)
    if result is None:
        return None
    tree_root, leaf_to_sample = result
    params = build_layout_from_tree_node(samples, tree_root, leaf_to_sample, ...)
    return _finalize_prefix_tree_batch(params, ...)

File layout:
  - verl/utils/prefix_tree_dynamic.py (new): trie algorithm (TrieNode,
    greedy_build_tries, convert_trie_to_tree_node) + build_tree_dynamic
    returning (TreeNode, leaf_to_sample).
  - verl/utils/prefix_tree_hash_based.py (new): hash detection extracted
    from prefix_tree_magi.py. Exposes build_tree_hash_based with the same
    (TreeNode, leaf_to_sample) return type, plus _hash_prefix and
    build_prefix_segments_single_turn for dataset / trainer callers.
  - verl/utils/prefix_tree_magi.py: now holds dispatcher
    (build_prefix_tree_micro_batch), restore_flat_to_nested,
    _finalize_prefix_tree_batch, and the MAGI/flex key builders.
    Re-exports _hash_prefix and build_prefix_segments_single_turn for
    backwards compatibility with existing callers (multiturn_sft_dataset,
    ray_trainer, tests).
  - verl/utils/prefix_tree_utils.py: add build_layout_from_tree_node helper
    used by both paths to realise a TreeNode into a PrefixTreeParams; add
    __all__ entry.
  - verl/models/mcore/model_forward.py: thread prefix_tree_dynamic flag
    from logits_processor_args through to build_prefix_tree_micro_batch.

Trade-offs of the trie path: 10-300x slower at detection but
<250 ms absolute on typical RL workloads, and supports arbitrary depth
(required for MCTS-based RL where the hash path's depth-2 cap loses
sharing structure). No rollout-side metadata required.

Tests:
  - tests/utils/test_prefix_tree_dynamic.py (new, 13 tests): trie
    construction, TreeNode conversion, end-to-end token-conservation
    checks at depth-2 and depth-3.

Algorithm originally derived from AReaL (inclusionAI/AReaL).
Add dynamic-trie prefix-tree detection + symmetrize hash-based path
- Remove 8 dead flat-layout functions from utils.py
- Rename build_multilevel_flex_spec → build_prefix_tree_attention_spec
- Remove multilevel field from PrefixTreeParams
- Add trie_dfs_leaf_order, trie_to_leaf_ids, mbs_groups_from_trie, prune_trie
- Build trie once in prepare_micro_batches, DFS sort, store leaf_id metadata
- Reuse trie in prepare_prefix_tree_micro_batches via mbs_groups_from_trie
- Update test_magi.py for trie-based API
- Add test_trie.py with 20 unit tests
- 1/2/4 3-layer tree: 7 leaves, 25 attention rects verified
- prune_trie to 4 leaves: reduced rect count, child_b has 1 leaf
- A/AB/ABC nested prefixes: zero-length leaves skipped correctly
- dispatch/undispatch are CP scatter/gather, not token merge
- Skipping them in middle layers breaks CP>1
- Remove no_expand_middle from magi_attn_forward and _tl_forward
- Fix TensorDict.non_tensor_batch (does not exist) → set_non_tensor_data
- Delete hash_based.py: move _hash_prefix + build_prefix_segments_single_turn → magi.py
- Delete params.py: merge PrefixTreeParams + RangeSpec → utils.py
- Delete balancing.py: merge load balancing functions → dynamic.py
- Remove _hash_prefix_ids usage from multiturn_sft_dataset (trie handles detection)
- Update all external imports
- Keep trainer.py as-is
- Remove no_expand_middle from config (hf_model.yaml, model.py),
  PrefixTreeMagiBatch, get_prefix_tree_kwargs, build_prefix_tree_batch,
  engine_workers.py
- Remove inject_prefix_segments from trainer.py + ray_trainer.py
- Remove _hash_prefix, build_prefix_segments_single_turn from magi.py
- Remove hash_prefix, build_prefix_segments_single_turn from __init__.py
- Remove TestPrefixSegmentsPrior class from test_magi.py (hash-based tests)
- Regenerate _generated_*.yaml configs
- prepare_prefix_tree_micro_batches prunes global trie per micro-batch
  and stores subtree via non_tensor_data
- forward_step threads subtree through logits_processor_args
- build_prefix_tree_batch passes cached_result to build_prefix_tree_micro_batch
- build_prefix_tree_micro_batch accepts cached_result, skips build_tree_dynamic
  when provided (backward-compatible)
When a sample's tokens are entirely covered by ancestor nodes (e.g.
[1,2] as a prefix of [1,2,3,4]), the sample terminated mid-tree but
was not included in leaf_to_sample. Now adds zero-length leaf nodes
for such samples.
@arvyanh arvyanh requested a review from PeterSH6 as a code owner June 11, 2026 03:04

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces prefix-tree shared-prefix deduplication for actor and SFT training, adding dynamic-trie prefix-tree building, MAGI/flex attention integration, and trainer load-balancing utilities. Feedback on the changes highlights a critical logic modification in the non-prefix-tree fallback path within model_forward.py where the condition for passing position_ids was altered, as well as the high maintainability risks associated with the extensive monkey-patching of Megatron-LM core components in prefix_tree_merge.py.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread verl/models/mcore/model_forward.py Outdated
output_orig = model(
input_ids=input_ids_rmpad,
attention_mask=attention_mask,
position_ids=position_ids_rmpad if not vision_model else None,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The condition for passing position_ids to the model has been changed from mtp_enable_train to not vision_model. This is a significant logic change in the non-prefix-tree fallback path and appears unrelated to the main feature of this PR. This could have unintended consequences on model correctness or performance for non-vision models that are not using MTP, as they will now receive position_ids where they previously did not. Furthermore, the original comment # position_ids is only needed for MTP is now outdated and misleading. This change needs to be justified, and if it is an intentional bug fix, it should be explained and ideally handled in a separate PR.

Suggested change
position_ids=position_ids_rmpad if not vision_model else None,
position_ids=position_ids_rmpad if mtp_enable_train else None, # position_ids is only needed for MTP

Comment on lines +116 to +121
def apply_prefix_tree_patch() -> None:
"""Monkey-patch upstream Megatron-LM classes to support prefix-tree attention (flex and MAGI).

Safe to call multiple times — subsequent calls are no-ops (checks for the
``_magi_patched`` sentinel attribute).
"""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The extensive monkey-patching of multiple Megatron-LM core components (GPTModel, TransformerBlock, TransformerLayer, SelfAttention, TEDotProductAttention, RotaryEmbedding) introduces a significant maintainability risk. This approach is fragile and highly susceptible to breaking with future updates to the upstream Megatron-LM library. While the implementation is carefully written with sentinels to prevent re-patching and try...finally blocks to restore methods, a breaking change in the upstream API could lead to silent failures or hard-to-debug issues.

Consider exploring less intrusive integration methods, such as subclassing or dependency injection where possible.

@wuxibin89 wuxibin89 marked this pull request as draft June 12, 2026 06:30
arvyanh added 12 commits June 15, 2026 10:35
The MAGI kernel computes RoPE from RotaryEmbedding.forward() which
creates
sequential cos/sin table for positions 0..T-1. But per-sample positions
should
be prefix:0..P-1, each leaf:P..P+L_i (starting from prefix_end, matching
restore_flat_to_nested order).
The MAGI-CMP diagnostic runs a second full FA3 forward after MAGI OLP,
causing ~2x OLP time. Keep the code as an unstaged patch for debugging.

Co-authored-by: Claude

To re-apply: git apply /tmp/magi_cmp.patch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants