[recipe][gkd/fsdp] feat: add TCAD recipe with Ray+FSDP distributed training #51
aslyc wants to merge 2 commits
This commit introduces the official implementation of Confidence-Aligned On-Policy Distillation (TCAD), a novel method for improving on-policy knowledge distillation through teacher-confidence reweighting.

Key components included:
- Distributed training pipeline with Ray and FSDP support
- Configurable architecture for flexible experimentation
- Implementation of TCAD core algorithms:
  * Teacher-confidence reweighting for trajectory selection
  * Support-Symmetric Truncation for token selection
  * Coverage-Aware Regularization for distribution stabilization
- Comprehensive configuration system for hyperparameter tuning

The implementation supports large-scale distributed training and provides robust foundations for reproducing results from the associated research paper 'Learning from What the Teacher Trusts: Confidence-Aligned On-Policy Distillation'.

Testing: Basic functionality verified with script execution

Co-authored-by: process-cxr <fantastique0910@gmail.com>
Co-authored-by: monster119120 <285514559@qq.com>
Co-authored-by: jiaminchen-1031 <jiaminchen.sophie@gmail.con>
Summary of Changes

Hello @aslyc, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the distributed training capabilities by integrating the Confidence-Aligned On-Policy Distillation (TCAD) algorithm. It introduces a new, modular recipe that leverages Ray for distributed orchestration and FSDP for efficient model parallelism. The changes enable more robust and scalable training by incorporating advanced distillation techniques and providing a flexible configuration framework for various components of the training pipeline.
Code Review
This pull request introduces a comprehensive new recipe for Confidence-Aligned On-Policy Distillation (TCAD) using Ray and FSDP. The changes are well-structured, adding new configuration files, Python modules for the training logic, and a launch script. The overall implementation follows the existing patterns in the repository.
My review has identified a few high-severity issues that need to be addressed. These include a hardcoded, user-specific PYTHONPATH in runtime_env.yaml, a critical configuration mismatch in run_opkd.sh that violates an assumption in the distillation logic, and a typo in main_opkd.py that could prevent a configuration from being applied. I have also noted several medium-severity issues, such as leftover debug code and comments, which should be cleaned up to improve code clarity and maintainability. The core algorithm implementation appears sound.
NVTE_FLASH_ATTN: "1"
RAY_DEBUG: "legacy"
NCCL_DEBUG: "WARN"
PYTHONPATH: "/home/work/cxr/verl-opkd"
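The review summary flags the hard-coded, user-specific PYTHONPATH in runtime_env.yaml. A minimal sketch of one way to avoid it is to build the environment at launch time from a configurable root. The function name `build_runtime_env` and the `OPKD_REPO_ROOT` variable are hypothetical and not part of the PR; only the four environment keys come from the quoted file.

```python
import os
from pathlib import Path


def build_runtime_env(repo_root=None):
    """Return a runtime_env-style dict whose PYTHONPATH is not user-specific.

    Assumption: the repository root is passed in, read from the hypothetical
    OPKD_REPO_ROOT variable, or defaults to the current working directory.
    """
    root = repo_root or os.environ.get("OPKD_REPO_ROOT") or Path.cwd()
    return {
        "env_vars": {
            # Values copied from the quoted runtime_env.yaml lines.
            "NVTE_FLASH_ATTN": "1",
            "RAY_DEBUG": "legacy",
            "NCCL_DEBUG": "WARN",
            # Resolved at launch time instead of "/home/work/cxr/verl-opkd".
            "PYTHONPATH": str(Path(root).resolve()),
        }
    }


env = build_runtime_env(".")
print(env["env_vars"]["PYTHONPATH"])
```

A dict of this shape could then be passed as the `runtime_env` argument of `ray.init`, assuming the launcher is changed to build it in code rather than read the YAML file directly.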
# [Optional] get the path of the timeline trace file from the configuration, default to None
# This file is used for performance analysis
timeline_json_file = config.ray_kwargs.get("timeline_json_file", None)
The configuration key ray_kwargs seems to be incorrect. Based on the opkd_trainer.yaml config structure, it should be ray_init. Using the wrong key will prevent the timeline_json_file setting from being applied.
Suggested change:
- timeline_json_file = config.ray_kwargs.get("timeline_json_file", None)
+ timeline_json_file = config.ray_init.get("timeline_json_file", None)
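To illustrate the lookup the reviewer suggests, here is a small sketch that reads the optional setting from a `ray_init` section and fails loudly if that section is missing. It uses a plain dict as a stand-in for the real config object, and the helper name `get_timeline_file` is hypothetical.

```python
def get_timeline_file(config):
    """Read the optional timeline trace path from the `ray_init` section.

    Assumption: `config` behaves like a nested mapping. The real trainer
    config object may differ; a plain dict is used here for illustration.
    """
    ray_init = config.get("ray_init")
    if ray_init is None:
        # Surface a misnamed or missing section instead of silently ignoring it.
        raise KeyError("config has no 'ray_init' section")
    return ray_init.get("timeline_json_file", None)


config = {"ray_init": {"timeline_json_file": "/tmp/timeline.json"}}
print(get_timeline_file(config))
```

Raising on a missing section is a design choice for this sketch: it makes a wrong key name visible at startup rather than quietly disabling the timeline trace.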
actor_rollout_ref.actor.ppo_mini_batch_size=8 \
actor_rollout_ref.actor.ppo_micro_batch_size=128 \
There's an inconsistency in batch size configuration that violates an assumption in the dp_actor.py code. The actor's update_policy method assumes that ppo_mini_batch_size is equal to rollout.n to correctly process all rollouts for a single prompt together.
Here, actor_rollout_ref.actor.ppo_mini_batch_size is set to 8, while actor_rollout_ref.rollout.n is 16. This mismatch could lead to incorrect behavior in the power-weighting logic.
Additionally, ppo_micro_batch_size is deprecated and should be removed. You should use ppo_micro_batch_size_per_gpu instead if needed.
To fix this, please align ppo_mini_batch_size with rollout.n and remove the deprecated ppo_micro_batch_size.
Suggested change:
- actor_rollout_ref.actor.ppo_mini_batch_size=8 \
- actor_rollout_ref.actor.ppo_micro_batch_size=128 \
+ actor_rollout_ref.actor.ppo_mini_batch_size=16 \
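The assumption described in this comment could also be checked before launch. The sketch below is a hypothetical pre-flight check, not code from the PR: it takes the command-line overrides as a plain dict and reports the two problems the reviewer names. The key names are the ones quoted in the comment; the function name `check_batch_config` is illustrative.

```python
MINI = "actor_rollout_ref.actor.ppo_mini_batch_size"
MICRO_DEPRECATED = "actor_rollout_ref.actor.ppo_micro_batch_size"
ROLLOUT_N = "actor_rollout_ref.rollout.n"


def check_batch_config(overrides):
    """Return a list of human-readable problems found in the overrides."""
    problems = []
    mini = overrides.get(MINI)
    n = overrides.get(ROLLOUT_N)
    # The review states that update_policy assumes one mini-batch holds
    # exactly the rollouts of a single prompt.
    if mini != n:
        problems.append(f"{MINI}={mini} must equal {ROLLOUT_N}={n}")
    # The review states this key is deprecated in favour of the per-GPU variant.
    if MICRO_DEPRECATED in overrides:
        problems.append(f"{MICRO_DEPRECATED} is deprecated; use ppo_micro_batch_size_per_gpu")
    return problems


original = {MINI: 8, MICRO_DEPRECATED: 128, ROLLOUT_N: 16}
fixed = {MINI: 16, ROLLOUT_N: 16}
print(check_batch_config(original))
print(check_batch_config(fixed))
```

With the values from run_opkd.sh as quoted above, the first call reports both problems and the second reports none.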
title = {Learning from What the Teacher Trusts: Confidence-Aligned On-Policy Distillation},
author = {Chen, Xinran and Chen, Jiamin and Kong, Rui and Li, Yuchen and Wang, Yu and Li, Lei and Wu, Hui and Xu, Han and Cai, Hengyi and Wang, Shuaiqiang and Yin, Dawei},
journal ={arXiv preprint},
year = {2026},
The year in the bibtex reference appears to be a typo. It's set to 2026, but the corresponding paper "Learning from What the Teacher Trusts: Confidence-Aligned On-Policy Distillation" was published on arXiv in 2024. Please correct the year to 2024 for accuracy.
Suggested change:
- year = {2026},
+ year = {2024},
# model dtype of fsdp
model_dtype: fp32

# Whether to use original parameters in fsdp. Only avaiable in fsdp1

shuffle: False

# profile the actor model in `update_policy`
if torch.cuda.is_available():
    pass
else:
    pass
# import hashlib

# def _rank_info():
#     if torch.distributed.is_available() and torch.distributed.is_initialized():
#         rank = torch.distributed.get_rank()
#         world = torch.distributed.get_world_size()
#     else:
#         rank, world = 0, 1
#     dev = torch.cuda.current_device() if torch.cuda.is_available() else -1
#     return rank, world, dev

# def _hash_prompt(input_ids, attention_mask, response_mask):
#     """
#     input_ids: [B, S] (prompt + response padded)
#     attention_mask: [B, S]
#     response_mask: [B, R] (mask over response span; 1=valid response token, 0=response padding)

#     return:
#     hashes: list[str] length B, short hash for prompt tokens
#     prompt_len: list[int] length B, estimated prompt lengths
#     resp_len: list[int] length B, effective response lengths (non-pad) from response_mask
#     """
#     # Estimate prompt effective length: total_effective_len - response_effective_len.
#     total_len = attention_mask.sum(dim=-1).to(torch.long) # [B]
#     resp_len = response_mask.sum(dim=-1).to(torch.long) # [B]
#     prompt_len = (total_len - resp_len).clamp(min=0) # [B]

#     hashes = []
#     for b in range(input_ids.size(0)):
#         plen = int(prompt_len[b].item())
#         toks = input_ids[b, :plen].detach().cpu().to(torch.int32).numpy().tobytes()
#         h = hashlib.sha1(toks).hexdigest()[:10]
#         hashes.append(h)
#     return hashes, prompt_len.detach().cpu().tolist(), resp_len.detach().cpu().tolist()

# rank, world, dev = _rank_info()
# print(
#     f"[DBG][rank {rank}/{world}][cuda:{dev}] micro_batches={len(micro_batches)} "
#     f"mini_batch_size={len(mini_batch.batch)} micro_bsz={self.config.ppo_micro_batch_size_per_gpu}"
# )
# uids = None
# try:
#     uids = mini_batch.non_tensor_batch.get("uid", None)
# except Exception:
#     uids = None

# if uids is not None:
#     uid_list = uids.tolist()
#     total = len(uid_list)

#     from collections import Counter

#     cnt = Counter(uid_list)
#     uniq = list(cnt.keys())
#     n_unique = len(uniq)

#     print(f"[DBG][rank {rank}][cuda:{dev}] mini_batch uid summary: total={total}, unique={n_unique}")
#     uid_cnt_preview = {k: cnt[k] for k in uniq[:5]}
#     print(f"[DBG][rank {rank}][cuda:{dev}] mini_batch uid counts (preview): {uid_cnt_preview}")

# # Check whether prompts are identical inside a micro-batch (should be, if grouped by rollouts).
# for i, mb in enumerate(micro_batches[: min(4, len(micro_batches))]): # show up to 4
#     try:
#         inp = mb.batch["input_ids"]
#         am = mb.batch["attention_mask"]
#         rm = mb.batch["response_mask"]
#         hs, plen, rlen = _hash_prompt(inp, am, rm)
#         uniq = len(set(hs))
#         print(
#             f"[DBG][rank {rank}][cuda:{dev}] micro[{i}] B={inp.size(0)} "
#             f"prompt_hash_unique={uniq} prompt_len={plen} resp_len={rlen} hashes={hs}"
#         )
#     except Exception as e:
#         print(f"[DBG][rank {rank}][cuda:{dev}] micro[{i}] prompt check failed: {repr(e)}")
one_attention_mask = batch.batch["attention_mask"][0].to(torch.bool)
one_sentence = batch.batch["input_ids"][0]
print("INFO:", "generate text done.")
print("DEBUG:", self.tokenizer.decode(one_sentence[one_attention_mask].tolist()))
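The review summary lists leftover debug code as a cleanup item. As a rough sketch of one alternative to unconditional `print` calls like the ones quoted above, the standard `logging` module can gate the sample dump behind the DEBUG level so the decode only runs when it is wanted. The logger name and the helper `log_first_sample` are hypothetical, and the lambda stands in for the tokenizer decode call.

```python
import logging

logger = logging.getLogger("opkd.rollout")  # hypothetical logger name


def log_first_sample(decode_first_sample):
    """Log that generation finished; decode a sample only at DEBUG level.

    `decode_first_sample` is a zero-argument callable, so the potentially
    expensive decode is skipped unless DEBUG logging is enabled.
    """
    logger.info("generate text done.")
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("sample: %s", decode_first_sample())


logging.basicConfig(level=logging.INFO)
log_first_sample(lambda: "decoded prompt and response")
```

At the INFO level shown here only the completion message is emitted; switching the logger to DEBUG would also print the decoded sample.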
# huggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json

# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main
# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path
This commit adds an FSDP-based GKD training recipe with the official implementation of
Confidence-Aligned On-Policy Distillation (TCAD), enabling large-scale distributed training.
Specifically, it:
- Adds the gkd/fsdp recipe with a modular config system (actor/critic/ref/reward/rollout/engine/optim)
- Adds the training scripts (main_opkd.py, ray_trainer.py, run_opkd.sh) and example configs for experiments

What does this PR do?
This PR brings an end-to-end FSDP distributed training backend for GKD and integrates TCAD into the on-policy distillation pipeline. The new recipe is isolated under gkd/fsdp to minimize coupling with existing backends and to support flexible experimentation via configuration.
Checklist Before Starting
[{modules}] {type}: {description}

Test
This PR introduces training-level functionality which is not fully covered by existing CI tests.
Validation performed:
API and Usage Example
No public API is changed. This PR adds an internal training recipe.
Example:
bash gkd/fsdp/run_opkd.sh
# or
python gkd/fsdp/main_opkd.py --config gkd/fsdp/config/opkd_trainer.yaml

Design & Code Changes
Added gkd/fsdp/ recipe directory with:
- gkd/fsdp/config/**.yaml
- fsdp_workers.py, ray_trainer.py
- main_opkd.py

Implemented TCAD algorithm components and integrated them into the training loop
Checklist Before Submitting
gkd/fsdp/README.md)