[recipe, algo] feat: add DMPO algorithm #105
Code Review
This pull request introduces the DMPO (Distribution Matching for Diverse Reasoning) recipe, which adds group-wise distribution-matching objectives over rollouts sharing the same prompt. Key feedback points out critical issues: potential AttributeErrors when calling .get() on policy_loss if it is instantiated as a standard Python dataclass, an incorrect sequence-level probability calculation caused by dividing sequence log probabilities by sequence length, and a biased MSE loss calculation that should be averaged group-wise rather than across all samples to prevent biasing towards larger groups.
def _get_dmpo_params(config: Optional[DictConfig | ActorConfig]) -> tuple[float, float]:
    beta = 1.0
    temperature = 1.0 / 15.0
    if config is None:
        return beta, temperature

    policy_loss_cfg = getattr(config, "policy_loss", None)
    if policy_loss_cfg is None:
        return beta, temperature

    return float(policy_loss_cfg.get("dmpo_beta", beta)), float(policy_loss_cfg.get("dmpo_temperature", temperature))
The policy_loss_cfg object can be a pure Python dataclass instance (such as DMPOPolicyLossConfig) rather than a Hydra DictConfig or dict. Standard Python dataclasses do not have a .get() method, which will cause an AttributeError when attempting to retrieve dmpo_beta or dmpo_temperature. Use a robust helper that checks for the existence of .get() or falls back to getattr.
def _get_dmpo_params(config: Optional[DictConfig | ActorConfig]) -> tuple[float, float]:
    beta = 1.0
    temperature = 1.0 / 15.0
    if config is None:
        return beta, temperature
    policy_loss_cfg = getattr(config, "policy_loss", None)
    if policy_loss_cfg is None:
        return beta, temperature
    if hasattr(policy_loss_cfg, "get"):
        beta_val = policy_loss_cfg.get("dmpo_beta", beta)
        temp_val = policy_loss_cfg.get("dmpo_temperature", temperature)
    else:
        beta_val = getattr(policy_loss_cfg, "dmpo_beta", beta)
        temp_val = getattr(policy_loss_cfg, "dmpo_temperature", temperature)
    return float(beta_val), float(temp_val)
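The same lookup pattern can be factored into one small accessor so that every call site handles both config styles. The sketch below is only an illustration: `cfg_get` and `ExamplePolicyLossConfig` are hypothetical names, not part of the recipe or of verl, and a plain `dict` stands in for a Hydra `DictConfig`.

```python
from dataclasses import dataclass
from typing import Any


def cfg_get(cfg: Any, key: str, default: Any = None) -> Any:
    """Read `key` from a mapping-like config or a plain object (hypothetical helper)."""
    if cfg is None:
        return default
    if hasattr(cfg, "get"):
        # Mapping-like configs (dict, DictConfig) expose .get(key, default).
        return cfg.get(key, default)
    # Plain dataclass instances only expose their fields as attributes.
    return getattr(cfg, key, default)


@dataclass
class ExamplePolicyLossConfig:
    """Stand-in for a dataclass-based policy loss config (assumed fields)."""

    dmpo_beta: float = 0.5
    loss_mode: str = "grpo_dmpo"


dict_cfg = {"dmpo_beta": 2.0}
dataclass_cfg = ExamplePolicyLossConfig()

beta_from_dict = float(cfg_get(dict_cfg, "dmpo_beta", 1.0))
beta_from_dataclass = float(cfg_get(dataclass_cfg, "dmpo_beta", 1.0))
# A key missing from the dataclass falls back to the supplied default.
temperature_default = float(cfg_get(dataclass_cfg, "dmpo_temperature", 1.0 / 15.0))
```

A single accessor like this would also cover the `loss_mode` lookup, so the `hasattr` branch does not have to be repeated at each call site.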
num_groups = int(group_index.max().item()) + 1
response_mask = response_mask.to(log_prob.dtype)
seq_len = response_mask.sum(dim=-1).clamp(min=1)
seq_log_prob = (log_prob * response_mask).sum(dim=-1) / seq_len
Dividing the sequence log probability by seq_len computes the token-average log probability, which corresponds to the log of the geometric mean of token probabilities. seq_log_prob should be the sum of token log probabilities, not the average.
seq_log_prob = (log_prob * response_mask).sum(dim=-1)
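The difference between the two quantities can be checked with a few made-up token probabilities. This is a minimal sketch in plain Python rather than tensors; the function names and numbers are illustrative assumptions, not code from the recipe.

```python
import math

# Token probabilities for two responses of different length (made-up numbers).
short_probs = [0.5, 0.5]
long_probs = [0.5, 0.5, 0.5, 0.5]


def seq_log_prob_sum(token_probs):
    """Log-probability of the whole sequence: the sum of token log-probabilities."""
    return sum(math.log(p) for p in token_probs)


def seq_log_prob_mean(token_probs):
    """Token-average log-probability: the log of the geometric mean of token probabilities."""
    return seq_log_prob_sum(token_probs) / len(token_probs)


# Summing recovers the joint probability of each sequence.
short_joint = math.exp(seq_log_prob_sum(short_probs))  # 0.5 ** 2
long_joint = math.exp(seq_log_prob_sum(long_probs))    # 0.5 ** 4

# Averaging yields the same value for both responses, so the length difference disappears.
short_geo = math.exp(seq_log_prob_mean(short_probs))
long_geo = math.exp(seq_log_prob_mean(long_probs))
```

Whether length normalization is intended is a design decision for the recipe authors; the sketch only shows what each quantity measures.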
metrics = {}

loss_mode = config.policy_loss.get("loss_mode", "vanilla")
Similar to the issue in dmpo_core_algos.py, config.policy_loss can be a pure Python dataclass instance which does not have a .get() method. Calling .get() directly on it will raise an AttributeError. Use a robust check to safely retrieve loss_mode.
policy_loss_cfg = config.policy_loss
if hasattr(policy_loss_cfg, "get"):
    loss_mode = policy_loss_cfg.get("loss_mode", "vanilla")
else:
    loss_mode = getattr(policy_loss_cfg, "loss_mode", "vanilla")

model_dist = _group_softmax(seq_log_prob, group_index, num_groups, eps=eps)
if divergence == "mse":
    loss_per_sample = (target_dist - model_dist).square()
    loss = (loss_per_sample * sample_weight).sum() / valid_samples.clamp(min=1)
Computing the MSE loss by averaging over all valid samples across the entire batch biases the loss towards larger groups (prompts with more valid rollouts). To ensure that each prompt group contributes equally to the distribution matching objective (consistent with the JS divergence implementation below), the MSE loss should be averaged group-wise.
model_dist = _group_softmax(seq_log_prob, group_index, num_groups, eps=eps)
if divergence == "mse":
    loss_per_sample = (target_dist - model_dist).square()
    mse_per_group = torch.zeros(num_groups, device=log_prob.device, dtype=log_prob.dtype)
    mse_per_group = mse_per_group.index_add(0, group_index, loss_per_sample * sample_weight)
    group_weight = torch.zeros(num_groups, device=log_prob.device, dtype=log_prob.dtype)
    group_weight = group_weight.index_add(0, group_index, sample_weight)
    loss = mse_per_group[group_weight > 0].mean()
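To make the normalization choices concrete, the following sketch compares three ways of reducing per-sample squared errors, using plain Python lists instead of tensors. All numbers are made up, and the variable names mirror the snippet above only for readability; the third variant (a per-group mean) is an assumption added for comparison and does not appear in the PR.

```python
# Per-sample squared errors and the prompt group each sample belongs to (made-up values).
loss_per_sample = [0.4, 0.4, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
group_index = [0, 0, 1, 1, 1, 1, 1, 1]
sample_weight = [1.0] * len(loss_per_sample)  # 1.0 marks a valid sample

num_groups = max(group_index) + 1

# Accumulate weighted loss and weight per group (the list analogue of index_add).
group_sum = [0.0] * num_groups
group_weight = [0.0] * num_groups
for loss_value, weight, group in zip(loss_per_sample, sample_weight, group_index):
    group_sum[group] += loss_value * weight
    group_weight[group] += weight

active = [g for g in range(num_groups) if group_weight[g] > 0]

# 1) Batch-wise: one average over every valid sample.
batch_loss = sum(group_sum) / max(sum(sample_weight), 1.0)

# 2) Sum within each group, then average over non-empty groups.
sum_then_mean = sum(group_sum[g] for g in active) / len(active)

# 3) Mean within each group, then average over non-empty groups (assumed variant).
mean_then_mean = sum(group_sum[g] / group_weight[g] for g in active) / len(active)
```

In this sketch the first two variants share the same numerator and differ only in the normalizer (valid samples versus non-empty groups). Dividing each group's sum by its weight before averaging, as in the third variant, is what changes the relative contribution of groups of different sizes.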
What does this PR do?
This PR adds the official DMPO recipe implementation for training LLMs on diverse reasoning tasks. The recipe hosts the community implementation for the paper "Beyond Mode Collapse: Distribution Matching for Diverse Reasoning".
DMPO introduces a group-wise distribution-matching objective over rollouts that share the same prompt uid. This recipe registers multiple DMPO policy loss variants and provides a Qwen2.5-Math GRPO+DMPO training example.
Test
Validated locally with:
Also smoke-tested all registered DMPO policy loss modes with forward/backward:
- grpo_dmpo
- grpo_dmpo_zero
- grpo_dmpo_js
- pure_dmpo
Verified that dmpo_ppo_loss correctly passes prompt uid groups into the DMPO policy loss.
API and Usage Example
Run from a verl checkout with this repository mounted as the recipe submodule:
Default config:
Design & Code Changes
- recipe/dmpo/dmpo_core_algos.py with DMPO policy loss registration.
- recipe/dmpo/dmpo_losses.py to pass prompt uid groups into the policy loss.
- recipe/dmpo/dmpo_patch.py, dmpo_worker.py, and main_dmpo.py for recipe integration.
- recipe/dmpo/config/dmpo_trainer.yaml.
- uid dispatch.
Loss Variants
- grpo_dmpo: default GRPO + DMPO objective.
- grpo_dmpo_zero: skips zero-advantage groups during training.
- grpo_dmpo_js: uses Jensen-Shannon divergence between current and target distributions.
- pure_dmpo: updates with the DMPO objective without the GRPO policy loss.
- dmpo: DMPO-only variant.
- dmpo_zero: DMPO-only variant that skips zero-advantage groups.
- dmpo_js: DMPO-only variant with Jensen-Shannon divergence.
Checklist Before Submitting