Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
7666c99
Add prefix-tree MAGI/flex attention integration for verl SFT and RL
arvyanh May 20, 2026
b70cb3d
Fix CP>1 RoPE slicing bug in MAGI/flex prefix-tree attention
arvyanh May 20, 2026
9f3d937
restore accident changes
arvyanh May 20, 2026
27fb3e6
Fix pre-commit failures on meituan/verl_prefix_tree_full base
xiefan46 Jun 2, 2026
3d7f54b
Merge pull request #56 from xiefan46/pr/precommit-fix
arvyanh Jun 2, 2026
caef40b
Add dynamic-trie prefix-tree detection + symmetrize hash-based path
xiefan46 Jun 2, 2026
202d4bc
Merge pull request #57 from xiefan46/pr/dynamic-trie-rebased
arvyanh Jun 2, 2026
21cdda3
remove debug message
arvyanh Jun 4, 2026
cc1716b
code cleanup
arvyanh Jun 4, 2026
92c251c
grpo Megatron-LM
arvyanh Jun 5, 2026
40bbb8e
magi with grpo
arvyanh Jun 8, 2026
f9bf149
merge
arvyanh Jun 8, 2026
c5cc891
setup: move prefix_tree, revert unrelated changes
arvyanh Jun 8, 2026
4710d01
feat: dynamic trie, DFS pre-sort, trainer/balance helpers
arvyanh Jun 8, 2026
71ce316
fix: balance refactor, config scope, old log prob
arvyanh Jun 8, 2026
1e51c6f
fix: per-layer merge/spread with config and defaults
arvyanh Jun 9, 2026
876ffbf
refactor: trie-based everything, remove flat layout
arvyanh Jun 9, 2026
461c21c
test: mask spec for 3-layer tree, prune, zero-length leaves
arvyanh Jun 9, 2026
f57b497
fix: always dispatch/undispatch for CP correctness
arvyanh Jun 9, 2026
6212117
refactor: consolidate prefix_tree from 9→6 files
arvyanh Jun 10, 2026
bbf09e7
chore: remove dead backward-compat alias in seqlen_balancing
arvyanh Jun 10, 2026
dbf3be7
cleanup: remove dead no_expand_middle config + prefix_segments injection
arvyanh Jun 10, 2026
5d59b53
feat: eliminate trie rebuild in build_prefix_tree_micro_batch
arvyanh Jun 10, 2026
cc2254a
fix: convert_trie_to_tree_node loses samples at intermediate nodes
arvyanh Jun 10, 2026
0f414b4
add sorting for prefix-tree micro-batches
arvyanh Jun 10, 2026
9281ee0
fix mask
arvyanh Jun 15, 2026
c11b848
fixes
arvyanh Jun 15, 2026
0870163
fix: prefix_tree_for_olp=None now inherits use_prefix_tree for on-pol…
arvyanh Jun 16, 2026
dec10e4
fix: MAGI fallbacks log ERROR; padding.py handles DP-duplicate sequen…
arvyanh Jun 16, 2026
02b831f
fix: per-sample roll in MAGI forward_prefix_tree; refactor try_forwar…
arvyanh Jun 16, 2026
9c3689a
fix: expand_flat_to_per_sample for autograd-safe logits expansion
arvyanh Jun 16, 2026
34d9c72
fix: per-sample RoPE positions for prefix tree multi-level support
arvyanh Jun 16, 2026
28de3dd
revert: disable MAGI-CMP OLP double-forward in production
arvyanh Jun 17, 2026
405e161
use layout builder flat_position_ids for RoPE
arvyanh Jun 17, 2026
2c63cea
add license
arvyanh Jun 17, 2026
ebb5a1a
remove debug FA3 call
arvyanh Jun 17, 2026
753bbb1
handle 3 edge cases causing silent FA3 fallback; add metrics; drop ol…
arvyanh Jun 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added prefix_script/data/coqa/coqa_grpo.parquet
Binary file not shown.
90 changes: 90 additions & 0 deletions prefix_script/data/coqa/coqa_reward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reward function for CoQA (Conversational Question Answering).

Score = F1 overlap between normalized model answer and ground truth.
Returns 1.0 for exact match, partial credit for partial overlap, 0.0 for no overlap.

Usage in verl config:
reward.custom_reward_function.path: ~/dataset/reward/coqa_reward.py
reward.custom_reward_function.name: compute_score
"""

import re
import string


def _normalize(text: str) -> str:
    """Canonicalize an answer string before token comparison.

    Applies, in this order: lowercasing, removal of every character in
    ``string.punctuation``, replacement of the standalone articles
    "a" / "an" / "the" with a space, and collapsing of all whitespace runs
    into single spaces (leading/trailing whitespace is dropped).

    Args:
        text: raw answer text.

    Returns:
        The normalized string; may be empty (e.g. for ``"the"`` or ``"!!!"``).
    """
    lowered = text.lower()
    no_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    # Articles are removed only after punctuation so that "the," also matches.
    no_articles = re.sub(r"\b(a|an|the)\b", " ", no_punct)
    return " ".join(no_articles.split())


def _f1(pred: str, gold: str) -> float:
    """Return the token-level F1 between ``pred`` and ``gold``.

    Both strings are normalized with ``_normalize`` and split on whitespace.
    The overlap is counted as a multiset intersection, so a repeated token is
    credited as many times as it appears in *both* strings. Counting distinct
    shared tokens instead (a plain ``set`` intersection) would divide a
    deduplicated numerator by non-deduplicated lengths and score an exact
    match such as "cat cat" vs "cat cat" at 0.5 instead of 1.0.

    Args:
        pred: predicted answer text.
        gold: reference answer text.

    Returns:
        A float in [0, 1]. If either side normalizes to no tokens, the result
        is 1.0 when both are empty and 0.0 otherwise.
    """
    # Function-scope import keeps this fix self-contained within the block.
    from collections import Counter

    pred_tokens = _normalize(pred).split()
    gold_tokens = _normalize(gold).split()
    if not pred_tokens or not gold_tokens:
        # No tokens on at least one side: score by exact (empty) equality.
        return float(pred_tokens == gold_tokens)

    # Counter & Counter keeps the minimum count per token (multiset overlap).
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def _extract_answer(response: str) -> str:
    """Pull the candidate answer out of a raw model response.

    Resolution order:
        1. The markers 'Answer:', 'answer:' and 'ANSWER:' are tried in that
           order. For the first marker present anywhere in the response, the
           text after its first occurrence is trimmed and only its first line
           is returned (trimmed again).
        2. Otherwise, the last non-blank line of the response, trimmed.
        3. Otherwise (no non-blank line), the trimmed response itself.
    """
    for marker in ("Answer:", "answer:", "ANSWER:"):
        _, found, tail = response.partition(marker)
        if found:
            first_line, _, _ = tail.strip().partition("\n")
            return first_line.strip()

    stripped_lines = (raw.strip() for raw in response.strip().splitlines())
    non_blank = [candidate for candidate in stripped_lines if candidate]
    if non_blank:
        return non_blank[-1]
    return response.strip()


def compute_score(
solution_str: str,
ground_truth: str,
data_source: str = "coqa",
extra_info: dict | None = None,
**kwargs,
) -> float:
"""Compute F1-based reward for CoQA.

Args:
solution_str: model's raw response text.
ground_truth: correct answer string (from reward_model.ground_truth).
data_source: ignored (always "coqa").
extra_info: optional dict; if ground_truth is empty, falls back to
extra_info["answer"].

Returns:
float in [0, 1] — token-level F1 between extracted answer and gold.
"""
gold = ground_truth
if not gold and extra_info:
gold = extra_info.get("answer", "")
if not gold:
return 0.0

pred = _extract_answer(solution_str)
return _f1(pred, gold)
49 changes: 49 additions & 0 deletions prefix_script/data/coqa/prepare_coqa_grpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ~/dataset/coqa.parquet to verl GRPO format.

Input columns: prompt (JSON str), data_source, extra_info (JSON str)
Output adds: reward_model (dict), converts prompt/extra_info to objects
"""

import json
from pathlib import Path

import pandas as pd

# Input and output both live under ~/dataset.
SRC = Path.home() / "dataset" / "coqa.parquet"
DST = Path.home() / "dataset" / "coqa_grpo.parquet"

df = pd.read_parquet(SRC)
print(f"Loaded {len(df)} rows from {SRC}")

# Build one output record per input row. "prompt" and "extra_info" are stored
# as JSON strings in the source table and are decoded into Python objects here.
records = []
for data_source, prompt_json, extra_json in zip(df["data_source"], df["prompt"], df["extra_info"]):
    prompt = json.loads(prompt_json)  # list[dict]
    extra = json.loads(extra_json)  # dict
    records.append(
        {
            "data_source": data_source,
            "prompt": prompt,
            "extra_info": extra,
            # The reference answer is copied from extra_info["answer"]
            # (empty string when that key is absent).
            "reward_model": {"ground_truth": extra.get("answer", ""), "style": "rule"},
        }
    )

out = pd.DataFrame(records)
out.to_parquet(DST, index=False)
print(f"Saved {len(out)} rows → {DST}")
# Guard the sample print: out.iloc[0] raises IndexError on an empty table.
if len(out) > 0:
    print("Sample reward_model:", out.iloc[0]["reward_model"])
else:
    print("Sample reward_model: <none — input had no rows>")
Binary file added prefix_script/data/gsm8k_sft_10240/test.parquet
Binary file not shown.
Binary file added prefix_script/data/gsm8k_sft_10240/train.parquet
Binary file not shown.
89 changes: 89 additions & 0 deletions prefix_script/run_grpo_coqa_magi.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#!/usr/bin/env bash
# GRPO CoQA — Megatron actor + MAGI prefix tree, 50 training steps.
# Default model is Qwen3-4B-Base; set MODEL_PATH to use a different checkpoint.
# Actor training deduplicates shared prompt prefix across rollout.n=4 responses
# Expected prefix_sharing_ratio ~0.6-0.7 (prompt ~700 tok shared / total ~800 tok)
#
# Environment overrides: MODEL_PATH, VERL_DIR, TRAIN_FILES, REWARD_FN, TS, OUTDIR.

set -xeuo pipefail
export PATH=/usr/local/miniconda3/bin:$PATH
export OMP_NUM_THREADS=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export VLLM_ATTENTION_BACKEND=TORCH_SDPA
export HYDRA_FULL_ERROR=1

# Repository root. Defined once, up front, because the dataset and reward
# defaults below are derived from it.
VERL_DIR="${VERL_DIR:-/home/hadoop-djst-algoplat/prefix-tree/verl_prefix_tree}"

MODEL_BASE="/mnt/dolphinfs/ssd_pool/docker/user/hadoop-ai-search/deepsearch_files_ssd/LLMbasemodels/huggingface.co/Qwen"
MODEL_PATH="${MODEL_PATH:-$MODEL_BASE/Qwen3-4B-Base}"
# Dataset: coqa_grpo.parquet prepared from CoQA via prefix_script/data/coqa/prepare_coqa_grpo.py
TRAIN_FILES="${TRAIN_FILES:-$VERL_DIR/prefix_script/data/coqa/coqa_grpo.parquet}"
REWARD_FN="${REWARD_FN:-$VERL_DIR/prefix_script/data/coqa/coqa_reward.py}"

TS="${TS:-$(date +%Y%m%d_%H%M%S)}"
OUTDIR="${OUTDIR:-/tmp/claude/grpo_coqa_4b/magi/${TS}}"
mkdir -p "$OUTDIR"

cd "$VERL_DIR"

# Banner values mirror the overrides passed below (actor TP=4, rollout TP=1,
# rollout.n=4, 50 steps); keep them in sync when editing the arguments.
echo "================================================================"
echo "GRPO CoQA — Megatron actor + MAGI prefix tree"
echo " model : $MODEL_PATH steps: 50"
echo " actor : Megatron TP=4, use_prefix_tree=True, attn=magi"
echo " rollout: vllm TP=1, n=4"
echo " Watch: train/prefix_sharing_ratio — expect >0.5"
echo "================================================================"

# stderr is merged into stdout so the tee'd run.log captures both streams.
TENSORBOARD_DIR="$OUTDIR/tb" python3 -m verl.trainer.main_ppo \
    model_engine=megatron \
    algorithm.adv_estimator=grpo \
    algorithm.use_kl_in_reward=False \
    \
    data.train_files="$TRAIN_FILES" \
    data.val_files="$TRAIN_FILES" \
    data.val_max_samples=32 \
    data.train_batch_size=128 \
    data.max_prompt_length=1024 \
    data.max_response_length=128 \
    data.filter_overlong_prompts=True \
    data.truncation=left \
    data.prompt_key=prompt \
    \
    actor_rollout_ref.model.path="$MODEL_PATH" \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.model.trust_remote_code=True \
    actor_rollout_ref.model.use_prefix_tree=True \
    actor_rollout_ref.model.prefix_tree_attention=magi \
    \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.actor.ppo_epochs=1 \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \
    actor_rollout_ref.actor.megatron.use_mbridge=True \
    actor_rollout_ref.actor.megatron.vanilla_mbridge=True \
    actor_rollout_ref.actor.megatron.use_megatron_fsdp=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=False \
    \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.n=4 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.max_model_len=2048 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    \
    reward.custom_reward_function.path="$REWARD_FN" \
    reward.num_workers=2 \
    \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.total_training_steps=50 \
    trainer.logger='["console","tensorboard"]' \
    trainer.project_name=grpo_coqa_4b \
    trainer.experiment_name=coqa_magi_4b \
    trainer.save_freq=-1 \
    trainer.test_freq=-1 \
    trainer.val_before_train=False \
    trainer.balance_batch=True \
    2>&1 | tee "$OUTDIR/run.log"
121 changes: 121 additions & 0 deletions prefix_script/run_grpo_coqa_magi_profile.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#!/usr/bin/env bash
# GRPO CoQA magi — actor-only profiling run (MAGI_TIMING + optional nsys)
# Steps: 30 only. Use PROFILE_MODE=timing (default) or PROFILE_MODE=nsys
#
# PROFILE_MODE=timing → MAGI_TIMING=1 prints dispatch/calc_attn/undispatch breakdown
# PROFILE_MODE=nsys → nsys wraps actor update_actor only at steps 20,25
# Any other PROFILE_MODE value is reported on stderr and treated as "timing".
#
# Usage:
# bash run_grpo_coqa_magi_profile.sh # timing mode
# PROFILE_MODE=nsys bash run_grpo_coqa_magi_profile.sh # nsys mode

set -xeuo pipefail
export PATH=/usr/local/miniconda3/bin:$PATH
export OMP_NUM_THREADS=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export VLLM_ATTENTION_BACKEND=TORCH_SDPA
export HYDRA_FULL_ERROR=1

# Timing: print dispatch/calc_attn/undispatch for layer 0 on every micro-batch
export MAGI_TIMING=1

# Diag: additional MAGI diagnostic output
export MAGI_DIAG=1

MODEL_PATH="${MODEL_PATH:-/mnt/dolphinfs/ssd_pool/docker/user/hadoop-ai-search/deepsearch_files_ssd/LLMbasemodels/huggingface.co/Qwen/Qwen3-4B-Base}"
# Dataset: coqa_grpo.parquet prepared from CoQA via prefix_script/data/coqa/prepare_coqa_grpo.py
TRAIN_FILES="${TRAIN_FILES:-/home/hadoop-djst-algoplat/prefix-tree/verl_prefix_tree/prefix_script/data/coqa/coqa_grpo.parquet}"
REWARD_FN="${REWARD_FN:-/home/hadoop-djst-algoplat/prefix-tree/verl_prefix_tree/prefix_script/data/coqa/coqa_reward.py}"
PROFILE_MODE="${PROFILE_MODE:-timing}"

# Normalise the mode up front. Previously an unrecognised value silently ran
# the timing path but skipped the final "Extract timing" hint.
case "$PROFILE_MODE" in
    timing | nsys) ;;
    *)
        echo "WARNING: unknown PROFILE_MODE='$PROFILE_MODE'; falling back to 'timing'" >&2
        PROFILE_MODE=timing
        ;;
esac

TS=$(date +%Y%m%d_%H%M%S)
OUTDIR="/tmp/verl_submit/profiles/magi/${TS}"
mkdir -p "$OUTDIR"

# TensorBoard events go under $HOME, not under $OUTDIR.
export TENSORBOARD_DIR="$HOME/profiles/magi_${TS}/tb"
mkdir -p "$TENSORBOARD_DIR"

VERL_DIR="${VERL_DIR:-/home/hadoop-djst-algoplat/prefix-tree/verl_prefix_tree}"
cd "$VERL_DIR"

echo "================================================================"
echo "MAGI profile run mode=$PROFILE_MODE TS=$TS"
echo " output: $OUTDIR"
echo " tensorboard: $TENSORBOARD_DIR"
echo "================================================================"

# Base args: derived from run_grpo_coqa_magi.sh, but with micro-batch size 8,
# rollout.n=8, 30 training steps, a separate project/experiment name, and no
# rollout.calculate_log_probs override.
BASE_ARGS=(
    model_engine=megatron
    algorithm.adv_estimator=grpo
    algorithm.use_kl_in_reward=False
    data.train_files="$TRAIN_FILES"
    data.val_files="$TRAIN_FILES"
    data.val_max_samples=32
    data.train_batch_size=128
    data.max_prompt_length=1024
    data.max_response_length=128
    data.filter_overlong_prompts=True
    data.truncation=left
    data.prompt_key=prompt
    actor_rollout_ref.model.path="$MODEL_PATH"
    actor_rollout_ref.model.use_remove_padding=True
    actor_rollout_ref.model.enable_gradient_checkpointing=True
    actor_rollout_ref.model.trust_remote_code=True
    actor_rollout_ref.model.use_prefix_tree=True
    actor_rollout_ref.model.prefix_tree_attention=magi
    actor_rollout_ref.actor.use_kl_loss=False
    actor_rollout_ref.actor.ppo_mini_batch_size=128
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8
    actor_rollout_ref.actor.ppo_epochs=1
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1
    actor_rollout_ref.actor.megatron.use_mbridge=True
    actor_rollout_ref.actor.megatron.vanilla_mbridge=True
    actor_rollout_ref.actor.megatron.use_megatron_fsdp=True
    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=False
    actor_rollout_ref.rollout.name=vllm
    actor_rollout_ref.rollout.n=8
    actor_rollout_ref.rollout.tensor_model_parallel_size=1
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5
    actor_rollout_ref.rollout.max_model_len=2048
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8
    reward.custom_reward_function.path="$REWARD_FN"
    reward.num_workers=2
    trainer.n_gpus_per_node=8
    trainer.nnodes=1
    trainer.total_training_steps=30
    trainer.logger='["console","tensorboard"]'
    trainer.project_name=grpo_coqa_4b_profile
    trainer.experiment_name="magi_profile_${TS}"
    trainer.save_freq=-1
    trainer.test_freq=-1
    trainer.val_before_train=False
    trainer.balance_batch=True
)

# Start from the base args; nsys mode appends profiler overrides. Timing mode
# needs nothing extra because MAGI_TIMING=1 is already exported above.
RUN_ARGS=("${BASE_ARGS[@]}")
if [ "$PROFILE_MODE" = "nsys" ]; then
    # nsys: profile actor update_actor only at steps 20,25
    RUN_ARGS+=(
        +global_profiler.tool=nsys
        "+global_profiler.steps=[20,25]"
        ++global_profiler.save_path="$OUTDIR/nsys"
        ++global_profiler.global_tool_config.nsys.discrete=True
        ++actor_rollout_ref.actor.profiler.enable=True
        ++actor_rollout_ref.actor.profiler.all_ranks=True
    )
fi

# Single launch site for both modes; stderr is merged so run.log has everything.
python3 -m verl.trainer.main_ppo \
    "${RUN_ARGS[@]}" 2>&1 | tee "$OUTDIR/run.log"

echo "================================================================"
echo "Profile done → $OUTDIR"
if [ "$PROFILE_MODE" = "timing" ]; then
    echo "Extract timing: grep 'MAGI-TIMING' $OUTDIR/run.log"
fi
echo "================================================================"
Loading