diff --git a/dmpo/README.md b/dmpo/README.md new file mode 100644 index 00000000..f3f2478f --- /dev/null +++ b/dmpo/README.md @@ -0,0 +1,43 @@ +# Beyond Mode Collapse: Distribution Matching for Diverse Reasoning + +See [`REQUIRED_VERL.txt`](REQUIRED_VERL.txt) for the upstream repository, install mode (rolling `main`, pinned release tag, or pinned git commit), and copy-pastable `pip` / `git` instructions where they exist. + + +This repository hosts the community implementation for the paper [Beyond Mode Collapse: Distribution Matching for Diverse Reasoning](https://arxiv.org/pdf/2605.19461). + +DMPO adds a group-wise distribution-matching objective over rollouts that share the same prompt `uid`. + +The base implementation is `grpo_dmpo`, which combines the standard GRPO policy loss with the DMPO +distribution-matching loss (note that the bundled `config/dmpo_trainer.yaml` and the example script select `grpo_dmpo_zero`). The recipe also provides these variants: + +- `grpo_dmpo_zero`: skips zero-advantage groups during training, so groups without a useful advantage signal do not + contribute to the DMPO term. +- `grpo_dmpo_js`: computes the gap between the current distribution and the target distribution with + Jensen-Shannon divergence instead of the default MSE objective. +- `pure_dmpo`: updates only with the DMPO objective and does not include the GRPO policy loss. + +## Usage + +Run from a verl checkout that has this repository mounted as the `recipe` submodule: + +```bash +bash recipe/dmpo/run_qwen2.5-7b_math_grpo_dmpo_zero.sh +``` + + +## 🖊️ Citation + +If you find this work helpful, please consider **starring🌟** this repo and citing this paper. Thanks for your support! 
+ +```bib +@misc{li2026modecollapsedistributionmatching, + title={Beyond Mode Collapse: Distribution Matching for Diverse Reasoning}, + author={Xiaozhe Li and Yang Li and Xinyu Fang and Shengyuan Ding and Peiji Li and Yongkang Chen and Yichuan Ma and Tianyi Lyu and Linyang Li and Dahua Lin and Qipeng Guo and Qingwen Liu and Kai Chen}, + year={2026}, + eprint={2605.19461}, + archivePrefix={arXiv}, + primaryClass={cs.AI}, + url={https://arxiv.org/abs/2605.19461}, +} +``` + diff --git a/dmpo/REQUIRED_VERL.txt b/dmpo/REQUIRED_VERL.txt new file mode 100644 index 00000000..3c81a9c2 --- /dev/null +++ b/dmpo/REQUIRED_VERL.txt @@ -0,0 +1,13 @@ +# dmpo — rolling; exact commits refreshed from this workspace +UPSTREAM=https://github.com/verl-project/verl.git +MODE=rolling +BRANCH=main +# Exact upstream verl commit this file was refreshed against (core library). +VERL_COMMIT=bcb638649a50e58494a8ddd92085ad1174f674b8 +PIP_INSTALL=pip install verl@git+https://github.com/verl-project/verl.git@bcb638649a50e58494a8ddd92085ad1174f674b8 +GIT_SETUP=git clone https://github.com/verl-project/verl.git && cd verl && git checkout bcb638649a50e58494a8ddd92085ad1174f674b8 && git submodule update --init --recursive recipe +# Recipe submodule snapshot at the same verl checkout (see `git ls-tree HEAD recipe` in verl). +RECIPE_SUBMODULE_COMMIT=ba246418f4de12b845a09bba975f1a5242adc898 +RECIPE_FOLDER=dmpo +NOTES=DMPO relies on the model-engine PPO path and patches the actor loss wrapper to pass prompt uid groups to the registered policy loss. +REFRESH=Recompute: (cd verl && git rev-parse HEAD); (cd verl/recipe && git rev-parse HEAD); (cd verl/recipe && git log -1 --format=%H -- dmpo) diff --git a/dmpo/__init__.py b/dmpo/__init__.py new file mode 100644 index 00000000..b1c60949 --- /dev/null +++ b/dmpo/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DMPO recipe.""" diff --git a/dmpo/config.py b/dmpo/config.py new file mode 100644 index 00000000..79341465 --- /dev/null +++ b/dmpo/config.py @@ -0,0 +1,32 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass, field + +from verl.workers.config import FSDPActorConfig, PolicyLossConfig + + +@dataclass +class DMPOPolicyLossConfig(PolicyLossConfig): + """Policy loss config with DMPO-specific hyperparameters.""" + + # Multiplier on the DMPO term in the grpo_dmpo* and pure_dmpo loss modes; the bare dmpo* modes return the unscaled DMPO loss and only report this value as a metric. + dmpo_beta: float = 1.0 + # Divisor applied to the per-sequence advantages before the per-group softmax that builds the DMPO target distribution. + dmpo_temperature: float = 1.0 / 15.0 + + +@dataclass +class DMPOActorConfig(FSDPActorConfig): + """Actor config that accepts DMPO policy loss fields.""" + + # default_factory gives every actor config its own DMPOPolicyLossConfig instance. + policy_loss: DMPOPolicyLossConfig = field(default_factory=DMPOPolicyLossConfig) diff --git a/dmpo/config/dmpo_trainer.yaml b/dmpo/config/dmpo_trainer.yaml new file mode 100644 index 00000000..e7d18340 --- /dev/null +++ b/dmpo/config/dmpo_trainer.yaml @@ -0,0 +1,25 @@ +# DMPO config overrides for verl PPO trainer. + +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +actor_rollout_ref: + model: + external_lib: recipe.dmpo.dmpo_patch + + actor: + _target_: recipe.dmpo.config.DMPOActorConfig + + policy_loss: + _target_: recipe.dmpo.config.DMPOPolicyLossConfig + loss_mode: grpo_dmpo_zero + dmpo_beta: 1.0 + dmpo_temperature: 0.06666666666666667 + +algorithm: + adv_estimator: grpo diff --git a/dmpo/dmpo_core_algos.py b/dmpo/dmpo_core_algos.py new file mode 100644 index 00000000..8dd05643 --- /dev/null +++ b/dmpo/dmpo_core_algos.py @@ -0,0 +1,347 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Any, Optional

import numpy as np
import torch
from omegaconf import DictConfig

from verl.trainer.ppo.core_algos import compute_policy_loss_vanilla, register_policy_loss
from verl.utils import as_torch_index
from verl.workers.config import ActorConfig

# Names of the DMPO loss modes. The PPO loss wrapper checks membership in this
# set to decide whether the uid group index must be forwarded to the loss.
DMPO_POLICY_LOSS_MODES = frozenset(
    {
        "dmpo_zero",
        "dmpo",
        "dmpo_js",
        "grpo_dmpo_zero",
        "grpo_dmpo",
        "grpo_dmpo_js",
        "pure_dmpo",
    }
)


def _get_dmpo_params(config: Optional[DictConfig | ActorConfig]) -> tuple[float, float]:
    """Return ``(beta, temperature)`` read from ``config.policy_loss``.

    Falls back to ``beta=1.0`` and ``temperature=1/15`` when ``config`` is
    ``None``, when it has no ``policy_loss`` attribute, or when the individual
    ``dmpo_beta`` / ``dmpo_temperature`` keys are missing.
    """
    beta = 1.0
    temperature = 1.0 / 15.0
    if config is None:
        return beta, temperature

    policy_loss_cfg = getattr(config, "policy_loss", None)
    if policy_loss_cfg is None:
        return beta, temperature

    return (
        float(policy_loss_cfg.get("dmpo_beta", beta)),
        float(policy_loss_cfg.get("dmpo_temperature", temperature)),
    )


def _group_softmax(
    logits: torch.Tensor, group_index: torch.Tensor, num_groups: int, eps: float = 1e-10
) -> torch.Tensor:
    """Softmax over ``logits`` computed independently inside each group.

    Args:
        logits: One value per sample.
        group_index: Integer group id per sample, used to index tensors of
            length ``num_groups``.
        num_groups: Size of the per-group scratch tensors.
        eps: Added to each group's normaliser before dividing.

    Returns:
        A tensor shaped like ``logits`` holding each sample's within-group probability.
    """
    # Subtract the per-group maximum before exponentiating for numerical stability.
    group_max = torch.full((num_groups,), float("-inf"), device=logits.device, dtype=logits.dtype)
    group_max = group_max.scatter_reduce(0, group_index, logits, reduce="amax", include_self=False)

    exp_val = torch.exp(logits - group_max[group_index])
    group_sum = torch.zeros(num_groups, device=logits.device, dtype=logits.dtype)
    group_sum = group_sum.index_add(0, group_index, exp_val)
    return exp_val / (group_sum[group_index] + eps)


def _compute_dmpo_distribution_loss(
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    index: torch.Tensor | np.ndarray | None,
    temperature: float,
    divergence: str = "mse",
    filter_zero_signal: bool = False,
    eps: float = 1e-10,
    zero_signal_tol: float = 1e-5,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """Distribution-matching loss between policy and advantage-derived target.

    Token log-probs and advantages are reduced to one masked mean per sequence.
    Within every uid group the target distribution is the softmax of
    ``seq_advantages / temperature`` (built without gradient) and the model
    distribution is the softmax of the per-sequence mean log-prob.

    Args:
        log_prob: Token log-probs; the first dimension is the batch.
        advantages: Token advantages, same shape as ``log_prob``.
        response_mask: Mask selecting the tokens that count, same shape.
        index: uid group id per sample; converted with ``as_torch_index``.
        temperature: Positive divisor applied to the sequence advantages.
        divergence: ``"mse"`` (squared difference averaged over valid samples)
            or ``"js"`` (Jensen-Shannon terms summed per group, then averaged
            over the groups that have weight).
        filter_zero_signal: Drop groups whose sequence advantages span at most
            ``zero_signal_tol``.
        eps: Stabiliser for the softmax normaliser and the JS logarithms.
        zero_signal_tol: Spread threshold used by ``filter_zero_signal``.

    Returns:
        ``(loss, metrics)`` where ``metrics`` holds ``actor/dmpo_valid_samples``
        and ``actor/dmpo_group_count``.

    Raises:
        ValueError: On an unsupported ``divergence``, a non-positive (or NaN)
            ``temperature``, a missing ``index``, or an ``index`` whose length
            differs from the batch size.
    """
    # Validate the configuration first so a bad setting fails on every path,
    # including the early returns below that never reach the divergence branch.
    if divergence not in ("mse", "js"):
        raise ValueError(f"Unsupported DMPO divergence: {divergence}.")
    # `not (t > 0)` also rejects NaN; zero would divide by zero and a negative
    # value would silently invert the target ranking.
    if not temperature > 0:
        raise ValueError(f"DMPO temperature must be positive: got {temperature}.")
    if index is None:
        raise ValueError("DMPO policy loss requires uid group index in the training batch.")

    group_index = as_torch_index(index, device=log_prob.device)
    if group_index.numel() != log_prob.shape[0]:
        raise ValueError(
            f"DMPO group index length must match batch size: got {group_index.numel()} and {log_prob.shape[0]}."
        )
    if group_index.numel() == 0:
        # Multiply by zero so the result stays attached to the autograd graph.
        return log_prob.sum() * 0.0, {"actor/dmpo_valid_samples": 0.0, "actor/dmpo_group_count": 0.0}

    # NOTE(review): this relies on as_torch_index yielding non-negative integer
    # ids so that max + 1 bounds the group count — confirm against its contract.
    num_groups = int(group_index.max().item()) + 1
    response_mask = response_mask.to(log_prob.dtype)
    seq_len = response_mask.sum(dim=-1).clamp(min=1)  # avoid 0/0 on fully masked rows
    seq_log_prob = (log_prob * response_mask).sum(dim=-1) / seq_len
    seq_advantages = (advantages * response_mask).sum(dim=-1) / seq_len

    with torch.no_grad():
        target_dist = _group_softmax(seq_advantages / temperature, group_index, num_groups, eps=eps)
        active_groups = torch.zeros(num_groups, device=log_prob.device, dtype=torch.bool)
        active_groups[group_index] = True

        sample_weight = torch.ones_like(seq_advantages)
        if filter_zero_signal:
            group_min = torch.full((num_groups,), float("inf"), device=log_prob.device, dtype=seq_advantages.dtype)
            group_max = torch.full((num_groups,), float("-inf"), device=log_prob.device, dtype=seq_advantages.dtype)
            group_min = group_min.scatter_reduce(0, group_index, seq_advantages, reduce="amin", include_self=False)
            group_max = group_max.scatter_reduce(0, group_index, seq_advantages, reduce="amax", include_self=False)
            # A group whose advantages are all (nearly) equal carries no ranking signal.
            has_signal = (group_max - group_min) > zero_signal_tol
            sample_weight = has_signal[group_index].to(seq_advantages.dtype)

    valid_samples = sample_weight.sum()
    if filter_zero_signal and valid_samples.item() <= 0:
        return seq_log_prob.sum() * 0.0, {
            "actor/dmpo_valid_samples": 0.0,
            "actor/dmpo_group_count": active_groups.sum().detach().item(),
        }

    model_dist = _group_softmax(seq_log_prob, group_index, num_groups, eps=eps)
    if divergence == "mse":
        loss_per_sample = (target_dist - model_dist).square()
        loss = (loss_per_sample * sample_weight).sum() / valid_samples.clamp(min=1)
    else:  # "js" — the only other value accepted by the validation above
        target_safe = torch.clamp(target_dist, min=eps)
        model_safe = torch.clamp(model_dist, min=eps)
        mixture = 0.5 * (target_safe + model_safe)
        js_per_sample = 0.5 * (
            target_safe * (torch.log(target_safe) - torch.log(mixture))
            + model_safe * (torch.log(model_safe) - torch.log(mixture))
        )
        js_per_group = torch.zeros(num_groups, device=log_prob.device, dtype=log_prob.dtype)
        js_per_group = js_per_group.index_add(0, group_index, js_per_sample * sample_weight)
        group_weight = torch.zeros(num_groups, device=log_prob.device, dtype=log_prob.dtype)
        group_weight = group_weight.index_add(0, group_index, sample_weight)
        # Average only over groups that contributed at least one weighted sample.
        loss = js_per_group[group_weight > 0].mean()

    metrics = {
        "actor/dmpo_valid_samples": valid_samples.detach().item(),
        "actor/dmpo_group_count": active_groups.sum().detach().item(),
    }
    return loss, metrics


def _dmpo_metrics(loss: torch.Tensor, beta: float, temperature: float, metrics: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``metrics`` extended with the DMPO loss value, beta and temperature."""
    return {
        **metrics,
        "actor/dmpo_loss": loss.detach().item(),
        "actor/dmpo_beta": beta,
        "actor/dmpo_temperature": temperature,
    }


@register_policy_loss("dmpo_zero")  # type: ignore[arg-type]
def compute_policy_loss_dmpo_zero(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
    index: torch.Tensor | np.ndarray | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """Compute the DMPO distribution-matching loss while ignoring zero-signal groups.

    ``old_log_prob``, ``loss_agg_mode`` and ``rollout_is_weights`` are accepted
    for signature compatibility and discarded. The returned loss is not
    multiplied by beta; beta is only reported in the metrics.
    """
    del old_log_prob, loss_agg_mode, rollout_is_weights
    beta, temperature = _get_dmpo_params(config)
    dmpo_loss, metrics = _compute_dmpo_distribution_loss(
        log_prob=log_prob,
        advantages=advantages,
        response_mask=response_mask,
        index=index,
        temperature=temperature,
        filter_zero_signal=True,
    )
    return dmpo_loss, _dmpo_metrics(dmpo_loss, beta, temperature, metrics)

+@register_policy_loss("dmpo") # type: ignore[arg-type] +def compute_policy_loss_dmpo( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, + index: torch.Tensor | np.ndarray | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute the DMPO distribution-matching loss with MSE divergence.""" + del old_log_prob, loss_agg_mode, rollout_is_weights + beta, temperature = _get_dmpo_params(config) + dmpo_loss, metrics = _compute_dmpo_distribution_loss( + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + index=index, + temperature=temperature, + ) + return dmpo_loss, _dmpo_metrics(dmpo_loss, beta, temperature, metrics) + + +@register_policy_loss("dmpo_js") # type: ignore[arg-type] +def compute_policy_loss_dmpo_js( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, + index: torch.Tensor | np.ndarray | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute the DMPO distribution-matching loss with Jensen-Shannon divergence.""" + del old_log_prob, loss_agg_mode, rollout_is_weights + beta, temperature = _get_dmpo_params(config) + dmpo_loss, metrics = _compute_dmpo_distribution_loss( + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + index=index, + temperature=temperature, + divergence="js", + ) + return dmpo_loss, _dmpo_metrics(dmpo_loss, beta, temperature, metrics) + + +def _compute_grpo_dmpo_loss( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str, + config: Optional[DictConfig | ActorConfig], + 
rollout_is_weights: torch.Tensor | None, + index: torch.Tensor | np.ndarray | None, + divergence: str = "mse", + filter_zero_signal: bool = False, + pure_dmpo: bool = False, +) -> tuple[torch.Tensor, dict[str, Any]]: + pg_loss, pg_metrics = compute_policy_loss_vanilla( # type: ignore[call-arg] + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + loss_agg_mode=loss_agg_mode, + config=config, + rollout_is_weights=rollout_is_weights, + ) + beta, temperature = _get_dmpo_params(config) + dmpo_loss, dmpo_metrics = _compute_dmpo_distribution_loss( + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + index=index, + temperature=temperature, + divergence=divergence, + filter_zero_signal=filter_zero_signal, + ) + total_loss = beta * dmpo_loss if pure_dmpo else pg_loss + beta * dmpo_loss + pg_metrics.update(_dmpo_metrics(dmpo_loss, beta, temperature, dmpo_metrics)) + return total_loss, pg_metrics + + +@register_policy_loss("grpo_dmpo_zero") # type: ignore[arg-type] +def compute_policy_loss_grpo_dmpo_zero( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, + index: torch.Tensor | np.ndarray | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute clipped GRPO plus DMPO over non-degenerate uid groups.""" + return _compute_grpo_dmpo_loss( + old_log_prob, + log_prob, + advantages, + response_mask, + loss_agg_mode, + config, + rollout_is_weights, + index, + filter_zero_signal=True, + ) + + +@register_policy_loss("grpo_dmpo") # type: ignore[arg-type] +def compute_policy_loss_grpo_dmpo( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | ActorConfig] = None, + 
rollout_is_weights: torch.Tensor | None = None, + index: torch.Tensor | np.ndarray | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute clipped GRPO plus DMPO with MSE divergence.""" + return _compute_grpo_dmpo_loss( + old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, config, rollout_is_weights, index + ) + + +@register_policy_loss("grpo_dmpo_js") # type: ignore[arg-type] +def compute_policy_loss_grpo_dmpo_js( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, + index: torch.Tensor | np.ndarray | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute clipped GRPO plus DMPO with Jensen-Shannon divergence.""" + return _compute_grpo_dmpo_loss( + old_log_prob, + log_prob, + advantages, + response_mask, + loss_agg_mode, + config, + rollout_is_weights, + index, + divergence="js", + ) + + +@register_policy_loss("pure_dmpo") # type: ignore[arg-type] +def compute_policy_loss_pure_dmpo( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, + index: torch.Tensor | np.ndarray | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute beta-scaled DMPO while still reporting vanilla PPO diagnostics.""" + return _compute_grpo_dmpo_loss( + old_log_prob, + log_prob, + advantages, + response_mask, + loss_agg_mode, + config, + rollout_is_weights, + index, + pure_dmpo=True, + ) diff --git a/dmpo/dmpo_losses.py b/dmpo/dmpo_losses.py new file mode 100644 index 00000000..e6baaf79 --- /dev/null +++ b/dmpo/dmpo_losses.py @@ -0,0 +1,112 @@ +# Copyright 2025 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tensordict import TensorDict + +from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty +from verl.utils import tensordict_utils as tu +from verl.utils.metric import AggregationType, Metric +from verl.workers.config import ActorConfig +from verl.workers.utils.padding import no_padding_2_padding + +from .dmpo_core_algos import DMPO_POLICY_LOSS_MODES + + +def dmpo_ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None): + """PPO loss wrapper that passes the prompt uid group index to DMPO policy losses.""" + del dp_group + log_prob = no_padding_2_padding(model_output["log_probs"], data) + entropy = model_output.get("entropy", None) + if entropy is not None: + entropy = no_padding_2_padding(entropy, data) + + config.global_batch_info["dp_size"] = data["dp_size"] + config.global_batch_info["batch_num_tokens"] = data["batch_num_tokens"] + config.global_batch_info["global_batch_size"] = data["global_batch_size"] + config.global_batch_info["loss_scale_factor"] = config.loss_scale_factor + + if ( + data["dp_size"] > 1 + or data["batch_num_tokens"] is not None + or data["global_batch_size"] is not None + or config.loss_scale_factor is not None + ): + metric_aggregation = AggregationType.SUM + else: + metric_aggregation = AggregationType.MEAN + + metrics = {} + + loss_mode = config.policy_loss.get("loss_mode", "vanilla") + group_index = None + if loss_mode 
in DMPO_POLICY_LOSS_MODES: + group_index = tu.get(data, "uid", None) + if group_index is None: + raise ValueError("DMPO policy losses require uid in the training batch.") + + fields = ["response_mask", "old_log_probs", "advantages"] + if "rollout_is_weights" in data: + fields.append("rollout_is_weights") + if "ref_log_prob" in data: + fields.append("ref_log_prob") + data = data.select(*fields).to_padded_tensor() + + response_mask = data["response_mask"].to(bool) + old_log_prob = data["old_log_probs"] + advantages = data["advantages"] + rollout_is_weights = data.get("rollout_is_weights", None) + + policy_loss_fn = get_policy_loss_fn(loss_mode) + policy_loss_kwargs = { + "old_log_prob": old_log_prob, + "log_prob": log_prob, + "advantages": advantages, + "response_mask": response_mask, + "loss_agg_mode": config.loss_agg_mode, + "config": config, + "rollout_is_weights": rollout_is_weights, + } + if loss_mode in DMPO_POLICY_LOSS_MODES: + policy_loss_kwargs["index"] = group_index + pg_loss, pg_metrics = policy_loss_fn(**policy_loss_kwargs) + + metrics.update(Metric.from_dict(pg_metrics, aggregation=AggregationType.MEAN)) + metrics["actor/pg_loss"] = Metric(value=pg_loss, aggregation=metric_aggregation) + policy_loss = pg_loss + + if entropy is not None: + entropy_loss = agg_loss( + loss_mat=entropy, + loss_mask=response_mask, + loss_agg_mode=config.loss_agg_mode, + **config.global_batch_info, + ) + policy_loss -= config.entropy_coeff * entropy_loss + metrics["actor/entropy_loss"] = Metric(value=entropy_loss, aggregation=metric_aggregation) + + if config.use_kl_loss: + ref_log_prob = data["ref_log_prob"] + kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=config.kl_loss_type) + kl_loss = agg_loss( + loss_mat=kld, + loss_mask=response_mask, + loss_agg_mode=config.loss_agg_mode, + **config.global_batch_info, + ) + + policy_loss += kl_loss * config.kl_loss_coef + metrics["kl_loss"] = Metric(value=kl_loss, aggregation=metric_aggregation) + 
metrics["kl_coef"] = config.kl_loss_coef + + return policy_loss, metrics diff --git a/dmpo/dmpo_patch.py b/dmpo/dmpo_patch.py new file mode 100644 index 00000000..fb246400 --- /dev/null +++ b/dmpo/dmpo_patch.py @@ -0,0 +1,34 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Register DMPO losses and patch actor PPO loss dispatch.""" + +from . import dmpo_core_algos # noqa: F401 +from .dmpo_losses import dmpo_ppo_loss + + +def apply_dmpo_patch() -> None: + import verl.workers.utils.losses as losses + + losses.ppo_loss = dmpo_ppo_loss + + try: + import verl.workers.engine_workers as engine_workers + + engine_workers.ppo_loss = dmpo_ppo_loss + except ImportError: + pass + + +apply_dmpo_patch() diff --git a/dmpo/dmpo_worker.py b/dmpo/dmpo_worker.py new file mode 100644 index 00000000..6ba83b33 --- /dev/null +++ b/dmpo/dmpo_worker.py @@ -0,0 +1,27 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from verl.single_controller.base.decorator import Dispatch, register +from verl.workers.engine_workers import ActorRolloutRefWorker + +from .dmpo_patch import apply_dmpo_patch + + +class DMPOActorRolloutRefWorker(ActorRolloutRefWorker): + """Actor rollout worker with DMPO policy losses registered before model initialization.""" + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + apply_dmpo_patch() + return super().init_model() diff --git a/dmpo/main_dmpo.py b/dmpo/main_dmpo.py new file mode 100644 index 00000000..1fe669b2 --- /dev/null +++ b/dmpo/main_dmpo.py @@ -0,0 +1,61 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import hydra +import ray + +from verl.experimental.reward_loop import migrate_legacy_reward_impl +from verl.trainer.main_ppo import TaskRunner, run_ppo +from verl.trainer.ppo.utils import need_reference_policy +from verl.utils.device import auto_set_device + +from .dmpo_patch import apply_dmpo_patch + + +class DMPOTaskRunner(TaskRunner): + """Task runner that uses the DMPO actor rollout worker.""" + + def add_actor_rollout_worker(self, config): + from verl.single_controller.ray import RayWorkerGroup + from verl.trainer.ppo.ray_trainer import Role + + from .dmpo_worker import DMPOActorRolloutRefWorker + + actor_rollout_cls = DMPOActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0) + if lora_rank <= 0: + lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0) + ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None + if need_reference_policy(config) and not ref_in_actor: + role = Role.ActorRolloutRef + else: + role = Role.ActorRollout + + self.role_worker_mapping[role] = ray.remote(actor_rollout_cls) + self.mapping[role] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + + +@hydra.main(config_path="config", config_name="dmpo_trainer", version_base=None) +def main(config): + apply_dmpo_patch() + auto_set_device(config) + config = migrate_legacy_reward_impl(config) + run_ppo(config, task_runner_class=ray.remote(num_cpus=1)(DMPOTaskRunner)) + + +if __name__ == "__main__": + main() diff --git a/dmpo/run_qwen2.5-7b_math_grpo_dmpo_zero.sh b/dmpo/run_qwen2.5-7b_math_grpo_dmpo_zero.sh new file mode 100644 index 00000000..2693a2fc --- /dev/null +++ b/dmpo/run_qwen2.5-7b_math_grpo_dmpo_zero.sh @@ -0,0 +1,76 @@ +# Example DMPO training script for Qwen2.5-Math on OpenR1-Math. 
+set -x +# set -x echoes each command after expansion, so the final override list appears in the log. + +export HYDRA_FULL_ERROR=${HYDRA_FULL_ERROR:-1} + +# Each path below uses its default only when the matching environment variable (TRAIN_FILES, VAL_FILES, MODEL_PATH, OUTPUT_DIR) is unset or empty. +train_files=${TRAIN_FILES:-"['$HOME/data/openr1_math/train.parquet']"} +test_files=${VAL_FILES:-"['$HOME/data/aime2024/test.parquet']"} +model_path=${MODEL_PATH:-"$HOME/models/Qwen2.5-Math-7B-16k-think"} +output_dir=${OUTPUT_DIR:-"$PWD/outputs/qwen2_5_math_grpo_dmpo_zero_openR1_46k_beta_2.0"} + +# Arguments given to this script are appended after the fixed overrides via "$@". +python3 -m recipe.dmpo.main_dmpo \ + actor_rollout_ref.actor.policy_loss.loss_mode=grpo_dmpo_zero \ + actor_rollout_ref.actor.policy_loss.dmpo_beta=2.0 \ + actor_rollout_ref.actor.policy_loss.dmpo_temperature=0.06666666666666667 \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=128 \ + data.val_batch_size=128 \ + data.truncation=error \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.filter_overlong_prompts_workers=8 \ + data.max_prompt_length=2048 \ + data.max_response_length=8192 \ + actor_rollout_ref.model.path="$model_path" \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=10240 \ + actor_rollout_ref.actor.loss_agg_mode=token-mean \ + actor_rollout_ref.actor.clip_ratio_low=0.2 \ + actor_rollout_ref.actor.clip_ratio_high=0.28 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=10240 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.response_length=8192 \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=10240 \ + 
 actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.max_num_batched_tokens=10240 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=1.0 \ + actor_rollout_ref.rollout.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \ + actor_rollout_ref.rollout.val_kwargs.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + algorithm.use_kl_in_reward=False \ + algorithm.kl_ctrl.kl_coef=0.0 \ + trainer.critic_warmup=0 \ + trainer.default_hdfs_dir=null \ + trainer.default_local_dir="$output_dir" \ + trainer.rollout_data_dir="$output_dir/rollout" \ + trainer.logger='["console","tensorboard"]' \ + trainer.project_name=qwen2_5_math_openR1_46k \ + trainer.experiment_name=qwen2_5_math_grpo_dmpo_zero_openR1_46k_beta_2.0 \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=50 \ + trainer.test_freq=-1 \ + trainer.total_training_steps=300 \ + trainer.total_epochs=1 \ + "$@" diff --git a/dmpo/tests/test_dmpo_loss.py b/dmpo/tests/test_dmpo_loss.py new file mode 100644 index 00000000..be2d8096 --- /dev/null +++ b/dmpo/tests/test_dmpo_loss.py @@ -0,0 +1,82 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
from tensordict import TensorDict
from tensordict.tensorclass import NonTensorData, NonTensorStack

from dmpo.config import DMPOActorConfig
from dmpo.dmpo_core_algos import DMPO_POLICY_LOSS_MODES
from dmpo.dmpo_losses import dmpo_ppo_loss
from verl.trainer.ppo.core_algos import get_policy_loss_fn


def _config():
    """Build a small actor config with non-default DMPO beta and temperature."""
    config = DMPOActorConfig(strategy="fsdp", rollout_n=2, ppo_micro_batch_size=2)
    # object.__setattr__ is used instead of plain assignment to set the fields
    # directly on the config dataclass.
    object.__setattr__(config.policy_loss, "dmpo_beta", 2.0)
    object.__setattr__(config.policy_loss, "dmpo_temperature", 1.0)
    return config


def test_dmpo_policy_losses_backward():
    """Every registered DMPO mode yields a finite scalar loss and finite gradients."""
    config = _config()
    old_log_prob = torch.zeros(4, 2)
    log_prob_values = [[-0.1, -0.2], [-1.0, -1.1], [-0.3, -0.4], [-0.8, -0.9]]
    advantages = torch.tensor([[1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [1.0, 1.0]])
    response_mask = torch.ones(4, 2, dtype=torch.bool)
    index = np.array(["prompt-a", "prompt-a", "prompt-b", "prompt-b"], dtype=object)

    # sorted(): iteration order of a frozenset of strings changes with hash
    # randomisation, which would make a failure hard to reproduce.
    for name in sorted(DMPO_POLICY_LOSS_MODES):
        # A fresh leaf per mode keeps .grad specific to that mode; a shared leaf
        # accumulates gradients and hides a mode that produces none.
        log_prob = torch.tensor(log_prob_values, requires_grad=True)
        loss_fn = get_policy_loss_fn(name)
        loss, metrics = loss_fn(
            old_log_prob, log_prob, advantages, response_mask, "token-mean", config, None, index=index
        )
        assert loss.ndim == 0, name
        assert torch.isfinite(loss), name
        assert "actor/dmpo_loss" in metrics, name

        loss.backward()
        assert log_prob.grad is not None, name
        assert torch.isfinite(log_prob.grad).all(), name


def test_dmpo_ppo_loss_passes_uid():
    """``dmpo_ppo_loss`` reads ``uid`` from the batch and reaches the DMPO loss."""
    config = _config()
    object.__setattr__(config.policy_loss, "loss_mode", "grpo_dmpo")
    bsz, prompt_len, response_len = 4, 1, 2
    uid = NonTensorStack.from_list([NonTensorData(item) for item in ["a", "a", "b", "b"]])
    data = TensorDict(
        {
            "prompts": torch.ones(bsz, prompt_len, dtype=torch.long),
            "responses": torch.ones(bsz, response_len, dtype=torch.long),
            "attention_mask": torch.ones(bsz, prompt_len + response_len, dtype=torch.long),
            "response_mask": torch.ones(bsz, response_len, dtype=torch.bool),
            "old_log_probs": torch.zeros(bsz, response_len),
            "advantages": torch.tensor([[1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [1.0, 1.0]]),
            "uid": uid,
            "dp_size": NonTensorData(1),
            "batch_num_tokens": NonTensorData(None),
            "global_batch_size": NonTensorData(None),
        },
        batch_size=[bsz],
    )
    # Flat (unpadded) log-probs, one per prompt and response token of every sample.
    flat_log_probs = torch.linspace(-1.0, -0.1, bsz * (prompt_len + response_len), requires_grad=True)

    loss, metrics = dmpo_ppo_loss(config, {"log_probs": flat_log_probs}, data)

    assert torch.isfinite(loss)
    assert "actor/dmpo_loss" in metrics
    loss.backward()
    assert flat_log_probs.grad is not None