[recipe, algo] feat: add DMPO algorithm #105

Open
OliverLeeXZ wants to merge 2 commits into verl-project:main from OliverLeeXZ:DMPO

@OliverLeeXZ

What does this PR do?

This PR adds the official DMPO recipe for training LLMs on diverse reasoning tasks. The recipe hosts the community implementation of the paper Beyond Mode Collapse: Distribution Matching for Diverse Reasoning.

DMPO introduces a group-wise distribution-matching objective over rollouts that share the same prompt uid. This recipe registers multiple DMPO policy loss variants and provides a Qwen2.5-Math GRPO+DMPO training example.
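
To make "group-wise" concrete, here is a minimal plain-Python sketch of a softmax that is normalized separately within each prompt uid group. It is an illustration only: the helper name `group_softmax` and the sample values are hypothetical, and it is not the `_group_softmax` from this PR (which works on tensors and is not shown here). The temperature and the target distribution are left out.

```python
import math
from collections import defaultdict


def group_softmax(scores, uids):
    """Softmax over `scores`, normalized separately within each uid group."""
    # Collect the positions of all rollouts that share a prompt uid.
    groups = defaultdict(list)
    for i, uid in enumerate(uids):
        groups[uid].append(i)

    probs = [0.0] * len(scores)
    for idx in groups.values():
        # Subtract the group maximum for numerical stability.
        m = max(scores[i] for i in idx)
        exps = [math.exp(scores[i] - m) for i in idx]
        z = sum(exps)
        for i, e in zip(idx, exps):
            probs[i] = e / z
    return probs


# Hypothetical batch: three rollouts for prompt "a", two for prompt "b".
uids = ["a", "a", "a", "b", "b"]
scores = [-1.0, -2.0, -3.0, -0.5, -0.5]
probs = group_softmax(scores, uids)
```

Each group's probabilities sum to 1 independently of the other groups, so a rollout is only compared against rollouts of the same prompt.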

Test

Validated locally with:

ruff check dmpo
ruff format --check dmpo
python3 -m compileall -q dmpo

Also smoke-tested all registered DMPO policy loss modes with forward/backward:

  • grpo_dmpo
  • grpo_dmpo_zero
  • grpo_dmpo_js
  • pure_dmpo

Verified that dmpo_ppo_loss correctly passes prompt uid groups into the DMPO policy loss.


API and Usage Example

Run from a verl checkout with this repository mounted as the recipe submodule:

bash recipe/dmpo/run_qwen2.5-7b_math_grpo_dmpo_zero.sh

Default config:

actor_rollout_ref:
  actor:
    policy_loss:
      loss_mode: grpo_dmpo
      dmpo_beta: 1.0
      dmpo_temperature: 0.06666666666666667

Design & Code Changes

  • Add recipe/dmpo/dmpo_core_algos.py with DMPO policy loss registration.
  • Add recipe/dmpo/dmpo_losses.py to pass prompt uid groups into the policy loss.
  • Add recipe/dmpo/dmpo_patch.py, dmpo_worker.py, and main_dmpo.py for recipe integration.
  • Add recipe/dmpo/config/dmpo_trainer.yaml.
  • Add a Qwen2.5-Math GRPO+DMPO example script.
  • Add unit tests for DMPO loss and uid dispatch.

Loss Variants

  • grpo_dmpo: default GRPO + DMPO objective.
  • grpo_dmpo_zero: skips zero-advantage groups during training.
  • grpo_dmpo_js: uses Jensen-Shannon divergence between current and target distributions.
  • pure_dmpo: updates with the DMPO objective without the GRPO policy loss.
  • dmpo: DMPO-only variant.
  • dmpo_zero: DMPO-only variant that skips zero-advantage groups.
  • dmpo_js: DMPO-only variant with Jensen-Shannon divergence.

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply formatting/lint checks.
  • Add / update documentation.
  • Add unit tests.
  • Request CI when ready.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces the DMPO (Distribution Matching for Diverse Reasoning) recipe, which adds group-wise distribution-matching objectives over rollouts sharing the same prompt. Key feedback points out critical issues: potential AttributeErrors when calling .get() on policy_loss if it is instantiated as a standard Python dataclass, an incorrect sequence-level probability calculation caused by dividing sequence log probabilities by sequence length, and a biased MSE loss calculation that should be averaged group-wise rather than across all samples to prevent biasing towards larger groups.

Comment thread dmpo/dmpo_core_algos.py
Comment on lines +38 to +48
def _get_dmpo_params(config: Optional[DictConfig | ActorConfig]) -> tuple[float, float]:
    beta = 1.0
    temperature = 1.0 / 15.0
    if config is None:
        return beta, temperature

    policy_loss_cfg = getattr(config, "policy_loss", None)
    if policy_loss_cfg is None:
        return beta, temperature

    return float(policy_loss_cfg.get("dmpo_beta", beta)), float(policy_loss_cfg.get("dmpo_temperature", temperature))

high

The policy_loss_cfg object can be a pure Python dataclass instance (such as DMPOPolicyLossConfig) rather than a Hydra DictConfig or dict. Standard Python dataclasses do not have a .get() method, which will cause an AttributeError when attempting to retrieve dmpo_beta or dmpo_temperature. Use a robust helper that checks for the existence of .get() or falls back to getattr.

Suggested change
-def _get_dmpo_params(config: Optional[DictConfig | ActorConfig]) -> tuple[float, float]:
-    beta = 1.0
-    temperature = 1.0 / 15.0
-    if config is None:
-        return beta, temperature
-    policy_loss_cfg = getattr(config, "policy_loss", None)
-    if policy_loss_cfg is None:
-        return beta, temperature
-    return float(policy_loss_cfg.get("dmpo_beta", beta)), float(policy_loss_cfg.get("dmpo_temperature", temperature))
+def _get_dmpo_params(config: Optional[DictConfig | ActorConfig]) -> tuple[float, float]:
+    beta = 1.0
+    temperature = 1.0 / 15.0
+    if config is None:
+        return beta, temperature
+    policy_loss_cfg = getattr(config, "policy_loss", None)
+    if policy_loss_cfg is None:
+        return beta, temperature
+    if hasattr(policy_loss_cfg, "get"):
+        beta_val = policy_loss_cfg.get("dmpo_beta", beta)
+        temp_val = policy_loss_cfg.get("dmpo_temperature", temperature)
+    else:
+        beta_val = getattr(policy_loss_cfg, "dmpo_beta", beta)
+        temp_val = getattr(policy_loss_cfg, "dmpo_temperature", temperature)
+    return float(beta_val), float(temp_val)

Comment thread dmpo/dmpo_core_algos.py
num_groups = int(group_index.max().item()) + 1
response_mask = response_mask.to(log_prob.dtype)
seq_len = response_mask.sum(dim=-1).clamp(min=1)
seq_log_prob = (log_prob * response_mask).sum(dim=-1) / seq_len

high

Dividing the sequence log probability by seq_len yields the token-average log probability, i.e. the log of the geometric mean of the token probabilities ($P(y|x)^{1/T}$). This severely distorts the probability distribution over sequences by artificially pushing the scores of longer sequences closer to 1 (since $P^{1/T}$ is much larger than $P$ for large $T$). To correctly match the sequence-level probability distribution as defined in DMPO, seq_log_prob should be the sum of token log probabilities, not the average.

Suggested change
-seq_log_prob = (log_prob * response_mask).sum(dim=-1) / seq_len
+seq_log_prob = (log_prob * response_mask).sum(dim=-1)
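
A small worked example of the difference between summing and averaging token log probabilities. This is a plain-Python sketch with made-up token probabilities, not code from the PR; the helper name `seq_scores` is hypothetical.

```python
import math


def seq_scores(token_probs):
    """Return (sequence probability, geometric-mean token probability)."""
    log_probs = [math.log(p) for p in token_probs]
    total = sum(log_probs)          # sum of token log probs -> log P(y|x)
    mean = total / len(log_probs)   # token-average log prob -> log P(y|x)^(1/T)
    return math.exp(total), math.exp(mean)


# Two hypothetical responses whose tokens all have probability 0.5.
short_tokens = [0.5] * 2    # T = 2
long_tokens = [0.5] * 10    # T = 10

p_short, g_short = seq_scores(short_tokens)
p_long, g_long = seq_scores(long_tokens)

# Summed: the short response is 2**8 = 256 times more likely than the long one.
ratio_sum = p_short / p_long
# Averaged: both responses get the same score, so the length effect is gone.
ratio_mean = g_short / g_long
```

With the sum, the two responses differ by a factor of 256; with the average, they are indistinguishable, so a softmax over the averaged scores would treat them as equally likely.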

Comment thread dmpo/dmpo_losses.py

metrics = {}

loss_mode = config.policy_loss.get("loss_mode", "vanilla")

high

Similar to the issue in dmpo_core_algos.py, config.policy_loss can be a pure Python dataclass instance which does not have a .get() method. Calling .get() directly on it will raise an AttributeError. Use a robust check to safely retrieve loss_mode.

    policy_loss_cfg = config.policy_loss
    if hasattr(policy_loss_cfg, "get"):
        loss_mode = policy_loss_cfg.get("loss_mode", "vanilla")
    else:
        loss_mode = getattr(policy_loss_cfg, "loss_mode", "vanilla")
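
Since the same `.get()` / `getattr` fallback is suggested in both files, it could be factored into one small helper. The sketch below is only an illustration: `cfg_get` and `ExamplePolicyLossConfig` are hypothetical names that do not exist in the PR, and a plain dict stands in for a Hydra DictConfig.

```python
from dataclasses import dataclass
from typing import Any


def cfg_get(cfg: Any, key: str, default: Any = None) -> Any:
    """Read `key` from a mapping-like config or a plain attribute object."""
    if cfg is None:
        return default
    if hasattr(cfg, "get"):
        # dict, DictConfig and other mapping-like objects.
        return cfg.get(key, default)
    # Plain dataclass instance or any other attribute-based object.
    return getattr(cfg, key, default)


@dataclass
class ExamplePolicyLossConfig:  # hypothetical stand-in for a dataclass config
    loss_mode: str = "grpo_dmpo"
    dmpo_beta: float = 1.0


dict_cfg = {"loss_mode": "grpo_dmpo_js"}
dataclass_cfg = ExamplePolicyLossConfig()

mode_from_dict = cfg_get(dict_cfg, "loss_mode", "vanilla")
mode_from_dataclass = cfg_get(dataclass_cfg, "loss_mode", "vanilla")
temperature = cfg_get(dataclass_cfg, "dmpo_temperature", 1.0 / 15.0)
```

The missing `dmpo_temperature` field falls back to the default instead of raising an AttributeError.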

Comment thread dmpo/dmpo_core_algos.py
Comment on lines +111 to +114
model_dist = _group_softmax(seq_log_prob, group_index, num_groups, eps=eps)
if divergence == "mse":
    loss_per_sample = (target_dist - model_dist).square()
    loss = (loss_per_sample * sample_weight).sum() / valid_samples.clamp(min=1)

medium

Computing the MSE loss by averaging over all valid samples across the entire batch biases the loss towards larger groups (prompts with more valid rollouts). To ensure that each prompt group contributes equally to the distribution matching objective (consistent with the JS divergence implementation below), the MSE loss should be averaged group-wise.

Suggested change
-model_dist = _group_softmax(seq_log_prob, group_index, num_groups, eps=eps)
-if divergence == "mse":
-    loss_per_sample = (target_dist - model_dist).square()
-    loss = (loss_per_sample * sample_weight).sum() / valid_samples.clamp(min=1)
+model_dist = _group_softmax(seq_log_prob, group_index, num_groups, eps=eps)
+if divergence == "mse":
+    loss_per_sample = (target_dist - model_dist).square()
+    mse_per_group = torch.zeros(num_groups, device=log_prob.device, dtype=log_prob.dtype)
+    mse_per_group = mse_per_group.index_add(0, group_index, loss_per_sample * sample_weight)
+    group_weight = torch.zeros(num_groups, device=log_prob.device, dtype=log_prob.dtype)
+    group_weight = group_weight.index_add(0, group_index, sample_weight)
+    loss = mse_per_group[group_weight > 0].mean()
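
To compare the reductions discussed here, the following plain-Python sketch computes three of them on a made-up batch. It is not the tensor code from the PR: the function names and the per-sample squared errors are hypothetical, and all sample weights are assumed to be 1.

```python
from collections import defaultdict


def batch_average(losses):
    """Original reduction: total loss divided by the number of samples."""
    return sum(losses) / len(losses)


def _group_sums(losses, group_ids):
    sums, counts = defaultdict(float), defaultdict(int)
    for loss, gid in zip(losses, group_ids):
        sums[gid] += loss
        counts[gid] += 1
    return sums, counts


def group_sum_average(losses, group_ids):
    """Suggested reduction: sum within each group, then mean across groups."""
    sums, _ = _group_sums(losses, group_ids)
    return sum(sums.values()) / len(sums)


def group_mean_average(losses, group_ids):
    """Alternative: mean within each group, then mean across groups."""
    sums, counts = _group_sums(losses, group_ids)
    return sum(sums[g] / counts[g] for g in sums) / len(sums)


# Hypothetical batch: group 0 has 4 rollouts, group 1 has 2 rollouts.
group_ids = [0, 0, 0, 0, 1, 1]
losses = [0.01, 0.01, 0.01, 0.01, 0.09, 0.09]

per_sample = batch_average(losses)
per_group_sum = group_sum_average(losses, group_ids)
per_group_mean = group_mean_average(losses, group_ids)
```

One thing the sketch makes visible: summing within each group and then averaging across groups equals the batch total divided by the number of groups, so it differs from the batch average only in the normalizer. If each group is meant to carry equal weight regardless of its size, dividing each group's sum by its weight before averaging (the third function) may be closer to that intent.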
