[recipe][gkd/fsdp] feat: add TCAD recipe with Ray+FSDP distributed training #51

Open
aslyc wants to merge 2 commits into verl-project:main from aslyc:main

aslyc commented Feb 26, 2026

This commit adds an FSDP-based GKD training recipe with the official implementation of
Confidence-Aligned On-Policy Distillation (TCAD), enabling large-scale distributed training.

Specifically, it:

  • Introduces a new gkd/fsdp recipe with a modular config system (actor/critic/ref/reward/rollout/engine/optim)
  • Adds a Ray-based distributed training driver and FSDP workers for scalable rollout + training
  • Implements TCAD core components:
    • Teacher-confidence reweighting for on-policy trajectory selection
    • Support-Symmetric Truncation for stable token support construction
    • Coverage-Aware Regularization to reduce probability mass leakage and stabilize optimization
  • Provides runnable entrypoints (main_opkd.py, ray_trainer.py, run_opkd.sh) and example configs for experiments
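
As a rough illustration of the "Support-Symmetric Truncation" item above, the sketch below builds a token support from the union of the teacher's and the student's top-k indices and computes a KL term on that shared support. This is an assumption based on the changelog's mention of merging teacher and student indices and computing union logits; the function names, the renormalization step, and the choice of forward KL are hypothetical and are not taken from the PR's code.

```python
import math


def top_k_indices(logits, k):
    """Return the indices of the k largest logits (ties broken by lower index)."""
    order = sorted(range(len(logits)), key=lambda i: (-logits[i], i))
    return order[:k]


def log_softmax(values):
    """Numerically stable log-softmax over a list of floats."""
    m = max(values)
    log_z = m + math.log(sum(math.exp(v - m) for v in values))
    return [v - log_z for v in values]


def union_support_kl(teacher_logits, student_logits, k):
    """KL(teacher || student) restricted to the union of both top-k sets.

    Both distributions are renormalized over the same shared support
    (an illustrative choice), so neither side is truncated by the
    other's top-k selection.
    """
    support = sorted(
        set(top_k_indices(teacher_logits, k)) | set(top_k_indices(student_logits, k))
    )
    t_logp = log_softmax([teacher_logits[i] for i in support])
    s_logp = log_softmax([student_logits[i] for i in support])
    kl = sum(math.exp(t) * (t - s) for t, s in zip(t_logp, s_logp))
    return support, kl


# Toy vocabulary of 5 tokens: the two top-2 sets overlap only on token 1.
teacher = [4.0, 3.0, 0.5, -1.0, -2.0]
student = [0.0, 3.5, 0.2, 3.0, -2.0]
support, kl = union_support_kl(teacher, student, k=2)
print(support, kl)
```

Taking the union rather than only the teacher's top-k keeps tokens that the student currently favors inside the loss, which is one plausible reading of "symmetric" here.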

What does this PR do?

This PR brings an end-to-end FSDP distributed training backend for GKD and integrates TCAD into the on-policy
distillation pipeline. The new recipe is isolated under gkd/fsdp to minimize coupling with existing backends
and to support flexible experimentation via configuration.

Checklist Before Starting

  • Format the PR title as [{modules}] {type}: {description}

Test

This PR introduces training-level functionality that existing CI tests do not fully cover.

Validation performed:

  • Ran a smoke training run with the FSDP recipe to verify end-to-end execution (rollout -> loss -> optimizer step)
  • Verified distributed initialization / worker launch under Ray
  • Sanity-checked TCAD components (trajectory reweighting / truncation support / regularization terms) for numerical stability

API and Usage Example

No public API is changed. This PR adds an internal training recipe.

Example:

bash gkd/fsdp/run_opkd.sh
# or
python gkd/fsdp/main_opkd.py --config gkd/fsdp/config/opkd_trainer.yaml

Design & Code Changes

  • Added gkd/fsdp/ recipe directory with:

    • Configs: gkd/fsdp/config/**.yaml
    • Distributed runtime: fsdp_workers.py, ray_trainer.py
    • Entrypoint: main_opkd.py
  • Implemented TCAD algorithm components and integrated them into the training loop

Checklist Before Submitting

  • Read the Contribute Guide
  • Apply pre-commit checks (basic formatting and lint checks applied)
  • Add / Update documentation (added gkd/fsdp/README.md)
  • Add unit or end-to-end tests (training-level validation only)

This commit introduces the official implementation of Confidence-Aligned On-Policy Distillation (TCAD),
a novel method for improving on-policy knowledge distillation through teacher-confidence reweighting.

Key components included:
- Distributed training pipeline with Ray and FSDP support
- Configurable architecture for flexible experimentation
- Implementation of TCAD core algorithms:
  - Teacher-confidence reweighting for trajectory selection
  - Support-Symmetric Truncation for token selection
  - Coverage-Aware Regularization for distribution stabilization
- Comprehensive configuration system for hyperparameter tuning

The implementation supports large-scale distributed training and provides robust foundations for
reproducing results from the associated research paper 'Learning from What the Teacher Trusts:
Confidence-Aligned On-Policy Distillation'.

Testing: Basic functionality verified with script execution

Co-authored-by: process-cxr <fantastique0910@gmail.com>
Co-authored-by: monster119120 <285514559@qq.com>
Co-authored-by: jiaminchen-1031 <jiaminchen.sophie@gmail.con>

@gemini-code-assist

Summary of Changes

Hello @aslyc, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request extends the distributed training capabilities by integrating the Confidence-Aligned On-Policy Distillation (TCAD) algorithm. It introduces a new, modular recipe that uses Ray for distributed orchestration and FSDP for sharded data-parallel training. The changes incorporate the TCAD distillation techniques and provide a flexible configuration framework for the components of the training pipeline.

Highlights

  • New TCAD Recipe Integration: Introduced a new gkd/fsdp recipe for Confidence-Aligned On-Policy Distillation (TCAD), featuring a modular configuration system for actor, critic, ref, reward, rollout, engine, and optimizer components.
  • Distributed Training with Ray and FSDP: Implemented a Ray-based distributed training driver and FSDP workers to enable scalable rollout and training, supporting large-scale distributed training for GKD.
  • Core TCAD Algorithm Components: Added key TCAD components including teacher-confidence reweighting for on-policy trajectory selection, Support-Symmetric Truncation for stable token support construction, and Coverage-Aware Regularization to prevent probability mass leakage.
  • Runnable Entrypoints and Configurations: Provided runnable entrypoints (main_opkd.py, ray_trainer.py, run_opkd.sh) and example configurations to facilitate experimentation and quick starts.
  • Modular and Isolated Design: Designed the new recipe to be isolated under gkd/fsdp to minimize coupling with existing backends and ensure flexible experimentation through configuration.
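
To make the "probability mass leakage" wording in the highlights concrete, here is a minimal sketch of one way a coverage term could be measured: the student's total probability on a chosen token support, and the negative log of that mass as a penalty. The PR text does not spell out the actual regularizer, so the formula and the names `softmax` and `coverage_penalty` are assumptions for illustration only.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def coverage_penalty(student_logits, support):
    """Return (mass_on_support, -log(mass_on_support)).

    The penalty is near zero when the student keeps almost all of its
    probability on the support and grows as mass leaks outside it.
    """
    probs = softmax(student_logits)
    mass = sum(probs[i] for i in support)
    return mass, -math.log(mass)


# Uniform student over 4 tokens with a 2-token support: half the mass leaks.
mass, penalty = coverage_penalty([0.0, 0.0, 0.0, 0.0], support=[0, 1])
print(mass, penalty)
```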

Changelog

  • gkd/fsdp/README.md
    • Added a new README file detailing the TCAD method, its features, project structure, and quick start instructions.
  • gkd/fsdp/config/actor/actor.yaml
    • Added a new configuration file for the base actor, defining parameters for PPO, policy loss, checkpointing, optimizer, and profiler settings.
  • gkd/fsdp/config/actor/dp_actor.yaml
    • Added a new configuration file for the FSDP actor, inheriting from the base actor config and specifying FSDP-specific settings like strategy, gradient clipping, and entropy calculation.
  • gkd/fsdp/config/actor/megatron_actor.yaml
    • Added a new configuration file for the Megatron actor, inheriting from the base actor config and defining Megatron-specific settings.
  • gkd/fsdp/config/algorithm/rollout_correction.yaml
    • Added a new configuration file for rollout correction, including settings for importance sampling (IS) and rejection sampling (RS) aggregation levels and thresholds.
  • gkd/fsdp/config/critic/critic.yaml
    • Added a new configuration file for the base critic, defining parameters for optimizer, model, PPO, loss aggregation, checkpointing, and profiler settings.
  • gkd/fsdp/config/critic/dp_critic.yaml
    • Added a new configuration file for the FSDP critic, inheriting from the base critic config and specifying FSDP-specific model and parallelism settings.
  • gkd/fsdp/config/critic/megatron_critic.yaml
    • Added a new configuration file for the Megatron critic, inheriting from the base critic config and defining Megatron-specific model and parallelism settings, including LoRA configuration.
  • gkd/fsdp/config/data/legacy_data.yaml
    • Added a new configuration file for data loading, including tokenizer settings, file paths, prompt/response lengths, batch sizes, and data augmentation options.
  • gkd/fsdp/config/engine/fsdp.yaml
    • Added a new configuration file for the FSDP engine, defining wrapping policies, offloading options, FSDP size, prefetching, and mixed precision settings.
  • gkd/fsdp/config/engine/megatron.yaml
    • Added a new configuration file for the Megatron engine, defining offloading, tensor/pipeline/expert parallelism sizes, sequence parallelism, and recompute configurations.
  • gkd/fsdp/config/engine/veomni.yaml
    • Added a new configuration file for the VeOmni engine, specifying data/tensor/expert/pipeline parallelism modes and sizes, mixed precision, and FSDP offload options.
  • gkd/fsdp/config/model/hf_model.yaml
    • Added a new configuration file for Hugging Face models, including paths, remote code trust, gradient checkpointing, activation offload, LoRA, and fused kernel settings.
  • gkd/fsdp/config/npu_profile/npu_profile.yaml
    • Added a new configuration file for NPU profiler options, including save path, roles, collection level, and memory/shape/stack recording.
  • gkd/fsdp/config/opkd_trainer.yaml
    • Added the main OPKD trainer configuration file, orchestrating defaults for actor, data, ref, rollout, critic, and reward models, along with trainer and profiler settings.
  • gkd/fsdp/config/optim/fsdp.yaml
    • Added a new configuration file for FSDP optimizer settings, including learning rate, warmup, weight decay, betas, gradient clipping, and LR scheduler type.
  • gkd/fsdp/config/optim/megatron.yaml
    • Added a new configuration file for Megatron optimizer settings, including learning rate, warmup, weight decay, betas, gradient clipping, and LR decay styles.
  • gkd/fsdp/config/optim/veomni.yaml
    • Added a new configuration file for VeOmni optimizer settings, including learning rate, warmup, weight decay, betas, gradient clipping, and LR scheduler type.
  • gkd/fsdp/config/ref/dp_ref.yaml
    • Added a new configuration file for the FSDP reference model, inheriting from the base ref config and specifying FSDP engine settings for forward-only mode.
  • gkd/fsdp/config/ref/megatron_ref.yaml
    • Added a new configuration file for the Megatron reference model, inheriting from the base ref config and defining Megatron-specific parallelism and forward-only settings.
  • gkd/fsdp/config/ref/ref.yaml
    • Added a new base configuration file for the reference model, including strategy, torch.compile settings, log-prob batch sizes, and profiler configurations.
  • gkd/fsdp/config/reward_model/dp_reward_loop.yaml
    • Added a new configuration file for the FSDP reward loop, enabling reward loop usage and defining rollout parameters for the reward model.
  • gkd/fsdp/config/reward_model/dp_reward_model.yaml
    • Added a new configuration file for the FSDP reward model, inheriting from the base reward model config and specifying FSDP-specific model and engine settings.
  • gkd/fsdp/config/reward_model/megatron_reward_loop.yaml
    • Added a new configuration file for the Megatron reward loop, enabling reward loop usage and defining rollout parameters for the reward model.
  • gkd/fsdp/config/reward_model/megatron_reward_model.yaml
    • Added a new configuration file for the Megatron reward model, inheriting from the base reward model config and defining Megatron-specific parallelism and engine settings.
  • gkd/fsdp/config/reward_model/reward_model.yaml
    • Added a new base configuration file for the reward model, including enablement, resource pooling, model paths, batch sizes, and profiler settings.
  • gkd/fsdp/config/rollout/rollout.yaml
    • Added a new configuration file for rollout settings, including generation parameters (temperature, top-k/p), prompt/response lengths, GPU memory utilization, and multi-turn interaction options.
  • gkd/fsdp/config/runtime_env.yaml
    • Added a new configuration file for the Ray runtime environment, specifying working directory, excludes, and various environment variables for Torch, CUDA, NVTE, Ray, NCCL, Python, and VLLM.
  • gkd/fsdp/dp_actor.py
    • Implemented DataParallelOPKDActor, extending DataParallelPPOActor with methods for merging teacher and student indices, computing union logits, and updating policy using token-level KD loss and per-prompt power-weighting.
  • gkd/fsdp/fsdp_workers.py
    • Implemented OPKDWorker and AsyncOPKDWorker, extending ActorRolloutRefWorker to handle FSDP-based distributed training for OPKD, including custom model initialization and RPCs for teacher log-prob and student top-k index computation.
  • gkd/fsdp/main_opkd.py
    • Implemented the main entry point for OPKD training, including Ray initialization, TaskRunner setup, dataset creation, and trainer instantiation and execution.
  • gkd/fsdp/ray_trainer.py
    • Implemented RayOPKDTrainer, extending RayPPOTrainer to manage the distributed OPKD training loop, including dataloader creation, rollout generation, student/teacher statistics computation, and actor updates.
  • gkd/fsdp/run_opkd.sh
    • Added a new shell script to launch the OPKD training process using ray job submit, configuring various data, model, actor, rollout, ref, and trainer parameters.
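
The dp_actor.py entry above mentions per-prompt power-weighting. A minimal sketch of what such a weighting could look like follows, assuming each rollout carries a mean teacher log-probability as its confidence score and that weights are normalized within the group of rollouts sharing a `uid`. The exponent `alpha`, the score definition, and the function name are hypothetical, not the PR's implementation.

```python
import math
from collections import defaultdict


def power_weights(uids, teacher_mean_logprobs, alpha=1.0):
    """Per-prompt weights proportional to confidence ** alpha.

    Confidence is taken to be exp(mean teacher log-prob), so
    confidence ** alpha == exp(alpha * mean_logprob). Weights are
    normalized within each uid group and sum to 1 per prompt.
    """
    groups = defaultdict(list)
    for idx, uid in enumerate(uids):
        groups[uid].append(idx)

    weights = [0.0] * len(uids)
    for members in groups.values():
        scores = [alpha * teacher_mean_logprobs[i] for i in members]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        for i, e in zip(members, exps):
            weights[i] = e / z
    return weights


# Two prompts ("a" and "b") with two rollouts each.
print(power_weights(["a", "a", "b", "b"], [-1.0, -2.0, -0.5, -0.5], alpha=1.0))
```

Because the normalization runs over all rollouts of one prompt, a scheme like this only behaves as intended when every rollout for a prompt lands in the same batch.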

Activity

  • The author (aslyc) has introduced a new FSDP-based GKD training recipe with the official implementation of Confidence-Aligned On-Policy Distillation (TCAD).
  • The author verified the end-to-end execution of the FSDP recipe through a smoke training run.
  • Distributed initialization and worker launch under Ray were successfully verified.
  • Sanity checks were performed on TCAD components, including trajectory reweighting, truncation support, and regularization terms, confirming their numerical stability.

gemini-code-assist Bot left a comment

Code Review

This pull request introduces a comprehensive new recipe for Confidence-Aligned On-Policy Distillation (TCAD) using Ray and FSDP. The changes are well-structured, adding new configuration files, Python modules for the training logic, and a launch script. The overall implementation follows the existing patterns in the repository.

My review has identified a few high-severity issues that need to be addressed. These include a hardcoded, user-specific PYTHONPATH in runtime_env.yaml, a critical configuration mismatch in run_opkd.sh that violates an assumption in the distillation logic, and a typo in main_opkd.py that could prevent a configuration from being applied. I have also noted several medium-severity issues, such as leftover debug code and comments, which should be cleaned up to improve code clarity and maintainability. The core algorithm implementation appears sound.

Comment thread gkd/fsdp/config/runtime_env.yaml Outdated
NVTE_FLASH_ATTN: "1"
RAY_DEBUG: "legacy"
NCCL_DEBUG: "WARN"
PYTHONPATH: "/home/work/cxr/verl-opkd"

high

The PYTHONPATH is hardcoded to a user-specific local directory. This will cause the script to fail for other users or in different environments. This line should be removed from the version-controlled configuration. Users can set this environment variable locally if needed.
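
One way to follow this advice is to build the environment variables at launch time and only forward a PYTHONPATH that the user has set locally. The sketch below assumes the resulting dictionary would be passed as the `env_vars` entry of a Ray runtime environment; the helper name `build_env_vars` and the example path are hypothetical.

```python
import os


def build_env_vars(base_env, environ=None):
    """Merge fixed runtime settings with an optional, user-provided PYTHONPATH."""
    environ = os.environ if environ is None else environ
    env_vars = dict(base_env)
    # Never ship a machine-specific path from the versioned config.
    env_vars.pop("PYTHONPATH", None)
    user_path = environ.get("PYTHONPATH")
    if user_path:
        env_vars["PYTHONPATH"] = user_path
    return env_vars


base = {"NVTE_FLASH_ATTN": "1", "RAY_DEBUG": "legacy", "NCCL_DEBUG": "WARN"}
# Placeholder path standing in for whatever the user exported locally.
runtime_env = {"env_vars": build_env_vars(base, environ={"PYTHONPATH": "/path/to/checkout"})}
print(runtime_env)
```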

Comment thread gkd/fsdp/main_opkd.py Outdated

# [Optional] get the path of the timeline trace file from the configuration, default to None
# This file is used for performance analysis
timeline_json_file = config.ray_kwargs.get("timeline_json_file", None)

high

The configuration key ray_kwargs seems to be incorrect. Based on the opkd_trainer.yaml config structure, it should be ray_init. Using the wrong key will prevent the timeline_json_file setting from being applied.

Suggested change
- timeline_json_file = config.ray_kwargs.get("timeline_json_file", None)
+ timeline_json_file = config.ray_init.get("timeline_json_file", None)
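
A small helper can make this kind of lookup tolerant of a missing section and easier to test. The sketch below uses plain dictionaries as a stand-in for the real config object, which is an assumption; the helper name `get_timeline_file` is hypothetical.

```python
def get_timeline_file(config):
    """Read timeline_json_file from the ray_init section.

    Returns None when the section or the key is absent, so a missing
    setting simply disables the timeline trace.
    """
    ray_init = config.get("ray_init") or {}
    return ray_init.get("timeline_json_file")


config = {"ray_init": {"num_cpus": None, "timeline_json_file": "/tmp/timeline.json"}}
print(get_timeline_file(config))
```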

Comment thread gkd/fsdp/run_opkd.sh Outdated
Comment on lines +54 to +55
actor_rollout_ref.actor.ppo_mini_batch_size=8 \
actor_rollout_ref.actor.ppo_micro_batch_size=128 \

high

There's an inconsistency in batch size configuration that violates an assumption in the dp_actor.py code. The actor's update_policy method assumes that ppo_mini_batch_size is equal to rollout.n to correctly process all rollouts for a single prompt together.

Here, actor_rollout_ref.actor.ppo_mini_batch_size is set to 8, while actor_rollout_ref.rollout.n is 16. This mismatch could lead to incorrect behavior in the power-weighting logic.

Additionally, ppo_micro_batch_size is deprecated and should be removed. You should use ppo_micro_batch_size_per_gpu instead if needed.

To fix this, please align ppo_mini_batch_size with rollout.n and remove the deprecated ppo_micro_batch_size.

Suggested change
- actor_rollout_ref.actor.ppo_mini_batch_size=8 \
- actor_rollout_ref.actor.ppo_micro_batch_size=128 \
+ actor_rollout_ref.actor.ppo_mini_batch_size=16 \
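
A launch-time check could catch this class of mismatch before training starts. The following is a minimal sketch, assuming the overrides are available as `key=value` strings like those in run_opkd.sh; the function name and message wording are hypothetical and not part of the recipe.

```python
DEPRECATED_KEY = "actor_rollout_ref.actor.ppo_micro_batch_size"


def check_opkd_batch_config(ppo_mini_batch_size, rollout_n, overrides=()):
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    if ppo_mini_batch_size != rollout_n:
        problems.append(
            f"ppo_mini_batch_size={ppo_mini_batch_size} must equal rollout.n={rollout_n} "
            "so that all rollouts of one prompt are processed together"
        )
    for item in overrides:
        key = item.split("=", 1)[0].strip()
        if key == DEPRECATED_KEY:
            problems.append(f"{key} is deprecated; use {key}_per_gpu instead")
    return problems


# The values flagged in the review comment above.
for problem in check_opkd_batch_config(
    8, 16, ["actor_rollout_ref.actor.ppo_micro_batch_size=128"]
):
    print("config error:", problem)
```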

Comment thread gkd/fsdp/README.md
title = {Learning from What the Teacher Trusts: Confidence-Aligned On-Policy Distillation},
author = {Chen, Xinran and Chen, Jiamin and Kong, Rui and Li, Yuchen and Wang, Yu and Li, Lei and Wu, Hui and Xu, Han and Cai, Hengyi and Wang, Shuaiqiang and Yin, Dawei},
journal ={arXiv preprint},
year = {2026},

medium

The year in the bibtex reference appears to be a typo. It's set to 2026, but the corresponding paper "Learning from What the Teacher Trusts: Confidence-Aligned On-Policy Distillation" was published on arXiv in 2024. Please correct the year to 2024 for accuracy.

Suggested change
- year = {2026},
+ year = {2024},

# model dtype of fsdp
model_dtype: fp32

# Whether to use original parameters in fsdp. Only avaiable in fsdp1

medium

There is a typo in the comment. avaiable should be available.

# Whether to use original parameters in fsdp. Only available in fsdp1


shuffle: False

# profile the actor model in `update_policy`

medium

This comment appears to be a leftover from development and is not associated with any configuration key. It should be removed to avoid confusion.

Comment thread gkd/fsdp/dp_actor.py
Comment on lines +15 to +18
if torch.cuda.is_available():
    pass
else:
    pass

medium

This if/else block is empty and serves no purpose. It should be removed to improve code clarity.

Comment thread gkd/fsdp/dp_actor.py Outdated
Comment on lines +322 to +396
# import hashlib

# def _rank_info():
#     if torch.distributed.is_available() and torch.distributed.is_initialized():
#         rank = torch.distributed.get_rank()
#         world = torch.distributed.get_world_size()
#     else:
#         rank, world = 0, 1
#     dev = torch.cuda.current_device() if torch.cuda.is_available() else -1
#     return rank, world, dev

# def _hash_prompt(input_ids, attention_mask, response_mask):
#     """
#     input_ids: [B, S] (prompt + response padded)
#     attention_mask: [B, S]
#     response_mask: [B, R] (mask over response span; 1=valid response token, 0=response padding)

#     return:
#         hashes: list[str] length B, short hash for prompt tokens
#         prompt_len: list[int] length B, estimated prompt lengths
#         resp_len: list[int] length B, effective response lengths (non-pad) from response_mask
#     """
#     # Estimate prompt effective length: total_effective_len - response_effective_len.
#     total_len = attention_mask.sum(dim=-1).to(torch.long) # [B]
#     resp_len = response_mask.sum(dim=-1).to(torch.long) # [B]
#     prompt_len = (total_len - resp_len).clamp(min=0) # [B]

#     hashes = []
#     for b in range(input_ids.size(0)):
#         plen = int(prompt_len[b].item())
#         toks = input_ids[b, :plen].detach().cpu().to(torch.int32).numpy().tobytes()
#         h = hashlib.sha1(toks).hexdigest()[:10]
#         hashes.append(h)
#     return hashes, prompt_len.detach().cpu().tolist(), resp_len.detach().cpu().tolist()

# rank, world, dev = _rank_info()
# print(
#     f"[DBG][rank {rank}/{world}][cuda:{dev}] micro_batches={len(micro_batches)} "
#     f"mini_batch_size={len(mini_batch.batch)} micro_bsz={self.config.ppo_micro_batch_size_per_gpu}"
# )
# uids = None
# try:
#     uids = mini_batch.non_tensor_batch.get("uid", None)
# except Exception:
#     uids = None

# if uids is not None:
#     uid_list = uids.tolist()
#     total = len(uid_list)

#     from collections import Counter

#     cnt = Counter(uid_list)
#     uniq = list(cnt.keys())
#     n_unique = len(uniq)

#     print(f"[DBG][rank {rank}][cuda:{dev}] mini_batch uid summary: total={total}, unique={n_unique}")
#     uid_cnt_preview = {k: cnt[k] for k in uniq[:5]}
#     print(f"[DBG][rank {rank}][cuda:{dev}] mini_batch uid counts (preview): {uid_cnt_preview}")

# # Check whether prompts are identical inside a micro-batch (should be, if grouped by rollouts).
# for i, mb in enumerate(micro_batches[: min(4, len(micro_batches))]): # show up to 4
#     try:
#         inp = mb.batch["input_ids"]
#         am = mb.batch["attention_mask"]
#         rm = mb.batch["response_mask"]
#         hs, plen, rlen = _hash_prompt(inp, am, rm)
#         uniq = len(set(hs))
#         print(
#             f"[DBG][rank {rank}][cuda:{dev}] micro[{i}] B={inp.size(0)} "
#             f"prompt_hash_unique={uniq} prompt_len={plen} resp_len={rlen} hashes={hs}"
#         )
#     except Exception as e:
#         print(f"[DBG][rank {rank}][cuda:{dev}] micro[{i}] prompt check failed: {repr(e)}")

medium

This large block of commented-out code appears to be for debugging. It should be removed before merging to keep the codebase clean and maintainable.
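
If the prompt-grouping check in that block is still useful, it could live in a small tested helper instead of commented-out code. The sketch below is a standard-library stand-in that hashes the prompt part of each row and reports whether a micro-batch contains a single prompt; it works on plain lists of token ids rather than tensors, and the names `prompt_hash` and `micro_batch_shares_prompt` are hypothetical.

```python
import hashlib
import logging
from array import array

logger = logging.getLogger("opkd.debug")


def prompt_hash(token_ids, prompt_len):
    """Short SHA-1 digest of the first prompt_len token ids."""
    raw = array("i", token_ids[:prompt_len]).tobytes()
    return hashlib.sha1(raw).hexdigest()[:10]


def micro_batch_shares_prompt(rows):
    """rows: iterable of (token_ids, prompt_len) pairs.

    Returns True when every row hashes to the same prompt, which is
    what grouping a micro-batch by rollouts of one prompt should give.
    """
    hashes = {prompt_hash(ids, plen) for ids, plen in rows}
    logger.debug("prompt_hash_unique=%d hashes=%s", len(hashes), sorted(hashes))
    return len(hashes) <= 1


# Two rollouts of the same 3-token prompt with different responses.
rows = [([11, 12, 13, 40, 41], 3), ([11, 12, 13, 50, 51], 3)]
print(micro_batch_shares_prompt(rows))
```

Routing the details through `logger.debug` keeps the check silent in normal runs while leaving it available when debug logging is turned on.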

Comment thread gkd/fsdp/ray_trainer.py
one_attention_mask = batch.batch["attention_mask"][0].to(torch.bool)
one_sentence = batch.batch["input_ids"][0]
print("INFO:", "generate text done.")
print("DEBUG:", self.tokenizer.decode(one_sentence[one_attention_mask].tolist()))

medium

This print statement appears to be for debugging purposes. It can generate a lot of verbose output and should be removed or placed behind a debug flag.
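
One common pattern for this is the standard `logging` module with an explicit level check, so that the decode work is skipped entirely unless debug output is requested. The sketch below assumes a `decode` callable standing in for `self.tokenizer.decode`; the logger name and the function `log_sample` are hypothetical.

```python
import logging

logger = logging.getLogger("opkd.trainer")


def log_sample(decode, token_ids, attention_mask):
    """Decode and log one generated sample only when DEBUG logging is enabled."""
    logger.info("generate text done.")
    if not logger.isEnabledFor(logging.DEBUG):
        return None  # skip the decode cost entirely
    # Keep only the positions that the attention mask marks as valid.
    kept = [tok for tok, keep in zip(token_ids, attention_mask) if keep]
    text = decode(kept)
    logger.debug("sample: %s", text)
    return text
```

With the logger left at INFO or above, `log_sample` returns without touching the tokenizer; setting the logger level to `logging.DEBUG` restores the sample output.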

Comment thread gkd/fsdp/run_opkd.sh
# huggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json

# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main
# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path

medium

The comment mentions DIST_CKPT_PATH, but this variable is not used anywhere in the script. To avoid confusion, please remove the mention of DIST_CKPT_PATH.

Suggested change
- # change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path
+ # change the HF_MODEL_PATH to your own path
