Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 299 additions & 0 deletions gkd_ascend/README.md

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions gkd_ascend/REQUIRED_VERL.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
VERL_COMMIT=56d18ec631dcb125c7818ad381886aefd761ae23
PIP_INSTALL=pip install verl@git+https://github.com/verl-project/verl.git@56d18ec631dcb125c7818ad381886aefd761ae23
GIT_SETUP=git clone https://github.com/verl-project/verl.git && cd verl && git checkout 56d18ec631dcb125c7818ad381886aefd761ae23
64 changes: 64 additions & 0 deletions gkd_ascend/agent_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os

import ray

from verl.experimental.agent_loop.agent_loop import AgentLoopManager
from verl.protocol import DataProto

# Module-level logger named after this file; its threshold is read from the
# VERL_LOGGING_LEVEL environment variable and falls back to "WARN" when unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))


class GKDAgentLoopManager(AgentLoopManager):
    """Agent loop manager with coroutine-based generation and replica control.

    Adds an awaitable ``generate_sequences_async`` and awaitable
    ``wake_up`` / ``sleep`` / ``clear_kv_cache`` helpers that fan out to every
    entry of ``self.rollout_replicas``.
    """

    async def generate_sequences_async(self, prompts: DataProto) -> DataProto:
        """Split input batch and dispatch to agent loop workers (async version).

        The batch is chunked into one piece per entry of
        ``self.agent_loop_workers``; each piece is submitted to its worker and
        the per-worker results are concatenated into a single batch.

        Args:
            prompts (DataProto): Input batch.

        Returns:
            DataProto: Output batch. Its ``meta_info`` holds the ``"timing"``
            entry produced by ``self._performance_metrics`` plus whatever
            remains in the first worker output's ``meta_info`` once its
            ``"metrics"`` entry has been popped.
        """
        chunks = prompts.chunk(len(self.agent_loop_workers))

        # ``ray.get`` is a blocking call, so each one is pushed onto a thread via
        # ``asyncio.to_thread`` and awaited together to keep the event loop free.
        outputs = await asyncio.gather(
            *[
                asyncio.to_thread(ray.get, worker.generate_sequences.remote(chunk))
                for worker, chunk in zip(self.agent_loop_workers, chunks, strict=True)
            ]
        )
        output = DataProto.concat(outputs)

        # calculate performance metrics
        # "metrics" is removed from every worker output here, so it is not
        # carried over into the merged ``meta_info`` below.
        metrics = [
            worker_output.meta_info.pop("metrics") for worker_output in outputs
        ]  # List[List[Dict[str, str]]]
        timing = self._performance_metrics(metrics, output)

        output.meta_info = {"timing": timing, **outputs[0].meta_info}
        return output

    async def wake_up(self):
        """Await ``wake_up()`` on every rollout replica concurrently."""
        await asyncio.gather(*[replica.wake_up() for replica in self.rollout_replicas])

    async def sleep(self):
        """Await ``sleep()`` on every rollout replica concurrently."""
        await asyncio.gather(*[replica.sleep() for replica in self.rollout_replicas])

    async def clear_kv_cache(self):
        """Await ``clear_kv_cache()`` on every rollout replica concurrently."""
        await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])
26 changes: 26 additions & 0 deletions gkd_ascend/config/on_policy_distill_megatron_trainer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
hydra:
searchpath:
- file://verl/trainer/config

defaults:
- ppo_megatron_trainer
- _self_

# config for the rollout (only for resource isolation)
rollout:
# Number of nodes used in the rollout
nnodes: 1
# Number of accelerator devices (GPUs, or NPUs on Ascend) per node
n_gpus_per_node: 8

actor_rollout_ref:
hybrid_engine: False

teacher:
server_ip: localhost
server_port: 15555
overlap_rollout: False
n_server_workers: 1

trainer:
scheduler: one_step_off
26 changes: 26 additions & 0 deletions gkd_ascend/config/on_policy_distill_trainer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
hydra:
searchpath:
- file://verl/trainer/config

defaults:
- ppo_trainer
- _self_

# config for the rollout (only for resource isolation)
rollout:
# Number of nodes used in the rollout
nnodes: 1
# Number of accelerator devices (GPUs, or NPUs on Ascend) per node
n_gpus_per_node: 8

actor_rollout_ref:
hybrid_engine: False

teacher:
server_ip: localhost
server_port: 15555
overlap_rollout: False
n_server_workers: 1

trainer:
scheduler: one_step_off
41 changes: 41 additions & 0 deletions gkd_ascend/distributed_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from verl.utils.device import is_npu_available


def vllm_stateless_init_process_group(master_address, master_port, rank, world_size, device):
    """Build a device communicator on top of a vLLM ``StatelessProcessGroup``.

    A ``StatelessProcessGroup`` is created independently of the global
    ``torch.distributed`` process group. It is then handed to a collective
    communicator that carries the data-plane traffic between the external
    (training) processes and the vLLM workers.

    Args:
        master_address: Host passed to ``StatelessProcessGroup.create``.
        master_port: Port passed to ``StatelessProcessGroup.create``.
        rank: Rank of the calling process within the new group.
        world_size: Total number of processes in the new group.
        device: Device forwarded to the communicator constructor.

    Returns:
        The communicator instance: ``PyHcclCommunicator`` when
        ``is_npu_available`` is truthy, otherwise ``PyNcclCommunicator``.
    """
    # NOTE: Should weight synchronization with the sglang backend be needed
    # later, the sglang equivalents could be imported instead:
    #   from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
    #   from sglang.srt.distributed.utils import StatelessProcessGroup
    # Imports are deferred to call time so that only the backend actually in
    # use has to be installed.
    if is_npu_available:
        from vllm_ascend.distributed.device_communicators.pyhccl import (
            PyHcclCommunicator as communicator_cls,
        )
    else:
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator as communicator_cls,
        )
    from vllm.distributed.utils import StatelessProcessGroup

    process_group = StatelessProcessGroup.create(
        host=master_address,
        port=master_port,
        rank=rank,
        world_size=world_size,
    )
    return communicator_cls(process_group, device=device)
77 changes: 77 additions & 0 deletions gkd_ascend/fsdp_kl_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2025 Individual Contributor: furunding
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FSDP-adapted version of the KL distillation loss.

Key difference between FSDP and Megatron: FSDP does not shard the vocab
dimension across tensor-parallel ranks, so the logits tensor on every rank
already contains the full vocab dimension. We can therefore run the standard
softmax / KL directly. The semantics are kept consistent with
``recipe.gkd.megatron_kl_loss.vocab_parallel_kl_divergence``:

* Use KL(P||Q), where P is the teacher (target) and Q is the student (source).
* Only compute on the top-k indices provided by the teacher (top-k distillation).
* The output is a per-token loss with shape equal to ``logits.shape[:-1]``.
"""

from __future__ import annotations

import torch


def topk_kl_divergence(
    logits: torch.Tensor,
    teacher_topk_logps: torch.Tensor,
    teacher_topk_indices: torch.Tensor,
) -> torch.Tensor:
    """Compute the per-token KL(P||Q) loss restricted to the teacher's top-k.

    P is the teacher distribution (given as log-probabilities on its top-k
    entries) and Q is the student distribution obtained from a full-vocab
    log-softmax of ``logits``. Top-k slots whose teacher log-probability is
    ``-inf`` (zero probability mass) contribute exactly zero, following the
    ``0 * log 0 = 0`` convention, instead of turning the token's loss into NaN.

    Args:
        logits: Student model logits with shape ``(..., vocab_size)``.
        teacher_topk_logps: Teacher model log-probabilities on the top-k
            indices, with shape ``(..., top_k)``.
        teacher_topk_indices: Vocab indices corresponding to the teacher's
            top-k entries, with shape ``(..., top_k)``; cast to ``long``
            before gathering.

    Returns:
        Per-token KL loss with shape ``logits.shape[:-1]``, computed in
        float32.

    Raises:
        AssertionError: If the leading dimensions of ``logits`` and
            ``teacher_topk_logps`` differ, or if ``teacher_topk_logps`` and
            ``teacher_topk_indices`` have different shapes.
    """
    assert logits.shape[:-1] == teacher_topk_logps.shape[:-1], (
        f"logits/teacher_topk_logps leading dims mismatch: {logits.shape} vs {teacher_topk_logps.shape}"
    )
    assert teacher_topk_logps.shape == teacher_topk_indices.shape, (
        f"teacher_topk_logps/teacher_topk_indices shape mismatch: "
        f"{teacher_topk_logps.shape} vs {teacher_topk_indices.shape}"
    )

    # Compute the student log-softmax once over the full vocab (in float32) to
    # avoid repeated logsumexp evaluations.
    student_logps = torch.nn.functional.log_softmax(logits.float(), dim=-1)

    # Gather the student log-probs at the teacher's top-k indices; the
    # resulting shape equals ``teacher_topk_logps``.
    student_topk_logps = torch.gather(
        student_logps,
        dim=-1,
        index=teacher_topk_indices.long(),
    )

    teacher_topk_logps = teacher_topk_logps.to(student_topk_logps.dtype)
    teacher_topk_probs = torch.exp(teacher_topk_logps)

    # A teacher log-prob of -inf gives probability 0, and 0 * (-inf - log Q)
    # evaluates to NaN. Replace those log-probs with 0 inside the difference so
    # the product is a clean 0; finite entries are left untouched.
    zero_mass = torch.isneginf(teacher_topk_logps)
    finite_teacher_logps = teacher_topk_logps.masked_fill(zero_mass, 0.0)

    # KL(P||Q) = sum_k P_k * (log P_k - log Q_k)
    per_token_kl = torch.sum(
        teacher_topk_probs * (finite_teacher_logps - student_topk_logps),
        dim=-1,
    )
    return per_token_kl
Loading
Loading