Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 212 additions & 0 deletions tests/trainer/ppo/test_binary_kl_rejection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
#!/usr/bin/env python3
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for the ``binary_kl`` (KPop) rejection sampling option.

KPop applies a hard trust region using the bidirectional Bernoulli KL divergence
between the training policy and the rollout policy: a token is kept only when
``max(KL(train||rollout), KL(rollout||train)) <= phi``.

This covers:
- ``compute_binary_kl_divergence`` numerics (self-KL = 0, no NaN at the p=1 boundary,
agreement with a manual reference value).
- ``binary_kl`` registration as a token-level rejection option.
- ``compute_rollout_rejection_mask`` directly: high-divergence tokens are rejected,
matched tokens are kept, and missing log-probs raise a clear error.
- The unified ``compute_rollout_correction_and_rejection_mask`` entry point and the
``RolloutCorrectionConfig.decoupled_token_kpop`` preset.

Usage:
python test_binary_kl_rejection.py
"""

import math

import pytest
import torch

from verl.trainer.config.algorithm import RolloutCorrectionConfig
from verl.trainer.ppo.rollout_corr_helper import (
SUPPORTED_ROLLOUT_RS_OPTIONS,
TOKEN_LEVEL_ROLLOUT_RS_OPTIONS,
compute_binary_kl_divergence,
compute_rollout_correction_and_rejection_mask,
compute_rollout_rejection_mask,
)


def test_binary_kl_self_divergence_is_zero():
"""KL(P||P) must be exactly zero for identical distributions."""
device = "cuda" if torch.cuda.is_available() else "cpu"
log_p = torch.log(torch.tensor([0.1, 0.5, 0.9, 0.99], device=device))
kl = compute_binary_kl_divergence(log_p, log_p)
assert torch.allclose(kl, torch.zeros_like(kl), atol=1e-6)


def test_binary_kl_no_nan_at_probability_one_boundary():
"""log_q == 0 (q == 1.0) must not produce NaN/Inf.

Without upcasting to float32 and clamping with eps, ``1 - q`` rounds to exactly
0.0 in float32 and the KL term ``log((1 - p) / (1 - q))`` becomes NaN. This is the
exact failure mode flagged in review for the original implementation.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
log_p = torch.tensor([-0.1, -2.0, 0.0, -5.0], device=device)
log_q = torch.tensor([0.0, -0.1, 0.0, -0.2], device=device) # includes q == 1.0
kl_fwd = compute_binary_kl_divergence(log_p, log_q)
kl_rev = compute_binary_kl_divergence(log_q, log_p)
for kl in (kl_fwd, kl_rev):
assert not torch.isnan(kl).any(), "binary KL produced NaN at the p=1 boundary"
assert not torch.isinf(kl).any(), "binary KL produced Inf at the p=1 boundary"
assert (kl >= 0).all(), "Bernoulli KL must be non-negative"


def test_binary_kl_matches_reference_value():
"""Spot-check against a hand-computed Bernoulli KL value."""
device = "cuda" if torch.cuda.is_available() else "cpu"
p, q = 0.8, 0.5
expected = p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))
log_p = torch.log(torch.tensor([p], device=device))
log_q = torch.log(torch.tensor([q], device=device))
kl = compute_binary_kl_divergence(log_p, log_q)
assert kl.item() == pytest.approx(expected, abs=1e-5)


def test_binary_kl_dtype_preserved():
"""Output dtype matches the input dtype even though math runs in float32."""
log_p = torch.log(torch.tensor([0.6, 0.4], dtype=torch.bfloat16))
log_q = torch.log(torch.tensor([0.5, 0.5], dtype=torch.bfloat16))
kl = compute_binary_kl_divergence(log_p, log_q)
assert kl.dtype == torch.bfloat16


def test_binary_kl_registered_as_token_level_option():
"""binary_kl must be a recognized, token-level rejection option."""
assert "binary_kl" in SUPPORTED_ROLLOUT_RS_OPTIONS
assert "binary_kl" in TOKEN_LEVEL_ROLLOUT_RS_OPTIONS


def test_binary_kl_rejects_high_divergence_tokens():
"""A token whose bidirectional KL exceeds phi is masked; a matched token is kept."""
device = "cuda" if torch.cuda.is_available() else "cpu"

# Token 0: identical policies -> KL = 0 -> kept.
# Token 1: p=0.99 vs q=0.5 -> max(KL_fwd, KL_rev) ~= 1.61 -> rejected at phi=1.0.
old_log_prob = torch.log(torch.tensor([[0.5, 0.99]], device=device))
rollout_log_prob = torch.log(torch.tensor([[0.5, 0.50]], device=device))
response_mask = torch.ones_like(old_log_prob)
log_ratio = old_log_prob - rollout_log_prob

modified_mask, metrics = compute_rollout_rejection_mask(
log_ratio=log_ratio,
response_mask=response_mask,
rollout_rs="binary_kl",
rollout_rs_threshold=1.0,
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
)

assert modified_mask[0, 0].item() == 1, "matched token should be kept"
assert modified_mask[0, 1].item() == 0, "high-divergence token should be rejected"
assert metrics["rollout_rs_binary_kl_masked_fraction"] == pytest.approx(0.5, abs=1e-6)


def test_binary_kl_keeps_everything_under_loose_threshold():
"""With a large phi every token survives and the mask is unchanged."""
device = "cuda" if torch.cuda.is_available() else "cpu"
old_log_prob = torch.randn(3, 7, device=device)
rollout_log_prob = old_log_prob + torch.randn(3, 7, device=device) * 0.1
response_mask = torch.ones_like(old_log_prob)

modified_mask, _ = compute_rollout_rejection_mask(
log_ratio=old_log_prob - rollout_log_prob,
response_mask=response_mask,
rollout_rs="binary_kl",
rollout_rs_threshold=1e6,
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
)
assert torch.equal(modified_mask, response_mask)


def test_binary_kl_requires_logprobs():
"""binary_kl needs the raw log-probs; omitting them must raise a clear error."""
device = "cuda" if torch.cuda.is_available() else "cpu"
log_ratio = torch.randn(2, 4, device=device)
response_mask = torch.ones_like(log_ratio)

with pytest.raises(ValueError, match="binary_kl"):
compute_rollout_rejection_mask(
log_ratio=log_ratio,
response_mask=response_mask,
rollout_rs="binary_kl",
rollout_rs_threshold=2.0,
)


def test_binary_kl_through_unified_entrypoint():
"""End-to-end through compute_rollout_correction_and_rejection_mask."""
device = "cuda" if torch.cuda.is_available() else "cpu"
old_log_prob = torch.randn(4, 8, device=device)
rollout_log_prob = old_log_prob + torch.randn(4, 8, device=device) * 0.15
response_mask = torch.ones_like(old_log_prob)

_, modified_mask, metrics = compute_rollout_correction_and_rejection_mask(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
rollout_is=None,
rollout_rs="binary_kl",
rollout_rs_threshold=2.0,
)

assert modified_mask.shape == response_mask.shape
assert "rollout_corr/rollout_rs_binary_kl_mean" in metrics
assert "rollout_corr/rollout_rs_binary_kl_masked_fraction" in metrics


def test_decoupled_token_kpop_preset():
"""The convenience preset wires binary_kl with phi and disables IS weights."""
cfg = RolloutCorrectionConfig.decoupled_token_kpop(phi=2.5)
assert cfg.rollout_rs == "binary_kl"
assert cfg.rollout_rs_threshold == 2.5
assert cfg.rollout_is is None


if __name__ == "__main__":
print("=" * 60)
print("Binary KL (KPop) Rejection Sampling Test Suite")
print("=" * 60)

try:
test_binary_kl_self_divergence_is_zero()
test_binary_kl_no_nan_at_probability_one_boundary()
test_binary_kl_matches_reference_value()
test_binary_kl_dtype_preserved()
test_binary_kl_registered_as_token_level_option()
test_binary_kl_rejects_high_divergence_tokens()
test_binary_kl_keeps_everything_under_loose_threshold()
test_binary_kl_requires_logprobs()
test_binary_kl_through_unified_entrypoint()
test_decoupled_token_kpop_preset()
print("\n" + "=" * 60)
print("ALL TESTS PASSED ✓")
print("=" * 60)
except Exception as e:
print(f"\n✗ Test failed with error: {e}")
import traceback

traceback.print_exc()
exit(1)
18 changes: 18 additions & 0 deletions verl/trainer/config/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,24 @@ def decoupled_token_icepop(
"""
return cls(rollout_is="token", rollout_is_threshold=f"{threshold_lower}_{threshold}", rollout_rs=None)

@classmethod
def decoupled_token_kpop(
cls,
phi: float = 2.0,
) -> "RolloutCorrectionConfig":
"""Decoupled Mode with token-level KPop.

Bidirectional Bernoulli-KL rejection sampling: keeps response_mask only for
tokens where max(KL(train||rollout), KL(rollout||train)) <= phi.

Args:
phi (float): Upper bound on the bidirectional binary KL. Default: 2.0

Returns:
RolloutCorrectionConfig configured for decoupled mode with token-level KPop
"""
return cls(rollout_is=None, rollout_rs="binary_kl", rollout_rs_threshold=phi)

@classmethod
def decoupled_seq_is_rs(
cls,
Expand Down
47 changes: 46 additions & 1 deletion verl/trainer/ppo/rollout_corr_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@
"seq_mean_k3",
"seq_max_k2",
"seq_max_k3",
"binary_kl",
}
TOKEN_LEVEL_ROLLOUT_RS_OPTIONS: set[str] = {"token_k1", "token_k2", "token_k3"}
TOKEN_LEVEL_ROLLOUT_RS_OPTIONS: set[str] = {"token_k1", "token_k2", "token_k3", "binary_kl"}


def _parse_rollout_is_threshold(threshold_spec: str | float) -> tuple[float, Optional[float]]:
Expand Down Expand Up @@ -192,11 +193,42 @@ def _parse_rollout_rs_thresholds(
return thresholds


def compute_binary_kl_divergence(
log_p: torch.Tensor,
log_q: torch.Tensor,
eps: float = 1e-6,
) -> torch.Tensor:
"""Bernoulli KL divergence KL(P||Q) parameterized by log-probabilities.

Treats each token as a Bernoulli distribution P = [p, 1-p], Q = [q, 1-q]:
KL(P||Q) = p * log(p / q) + (1 - p) * log((1 - p) / (1 - q))

The inputs are upcast to float32 and clamped with ``eps = 1e-6`` so that
``1 - q`` cannot round to exactly 0 in float32 (machine epsilon ~1.19e-7),
which would otherwise produce NaNs during training.

Args:
log_p: Log-probabilities of distribution P, any shape.
log_q: Log-probabilities of distribution Q, broadcastable to ``log_p``.
eps: Clamp bound keeping probabilities strictly inside (0, 1).

Returns:
Token-level Bernoulli KL divergence, cast back to the input dtype.
"""
orig_dtype = log_p.dtype
p = torch.clamp(torch.exp(log_p.float()), eps, 1.0 - eps)
q = torch.clamp(torch.exp(log_q.float()), eps, 1.0 - eps)
kl = p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))
return kl.to(orig_dtype)


def compute_rollout_rejection_mask(
log_ratio: torch.Tensor,
response_mask: torch.Tensor,
rollout_rs: str = "token_k1",
rollout_rs_threshold: Optional[str | float] = None,
old_log_prob: Optional[torch.Tensor] = None,
rollout_log_prob: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, dict[str, float]]:
"""Compute hard trust region mask using divergence estimators.

Expand All @@ -212,6 +244,9 @@ def compute_rollout_rejection_mask(
- "seq_sum_k{1,2,3}": Sum of token divergences per sequence.
- "seq_mean_k{1,2,3}": Mean of token divergences per sequence.
- "seq_max_k{2,3}": Maximum token divergence per sequence.
- "binary_kl": KPop. Token-level bidirectional Bernoulli KL between the training and
rollout policies; keeps tokens where max(KL_fwd, KL_rev) <= upper bound (phi).
Requires ``old_log_prob`` and ``rollout_log_prob`` to be passed in.

Args:
log_ratio: Log ratio of training policy probability to rollout policy probability,
Expand Down Expand Up @@ -318,6 +353,14 @@ def _sequence_max(values: torch.Tensor) -> torch.Tensor:
elif option_name == "token_k3":
per_token_stat = token_k3
token_keep_bool = per_token_stat <= upper_value
elif option_name == "binary_kl":
# KPop: bidirectional Bernoulli KL between training (old) and rollout policies.
if old_log_prob is None or rollout_log_prob is None:
raise ValueError("rollout_rs option 'binary_kl' requires both old_log_prob and rollout_log_prob.")
kl_fwd = compute_binary_kl_divergence(old_log_prob, rollout_log_prob)
kl_rev = compute_binary_kl_divergence(rollout_log_prob, old_log_prob)
per_token_stat = torch.maximum(kl_fwd, kl_rev)
Comment on lines +360 to +362

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In bidirectional KL rejection sampling, calling compute_binary_kl_divergence twice results in redundant computations. Specifically, both old_log_prob and rollout_log_prob are exponentiated (torch.exp) and clamped (torch.clamp) twice. Since transcendental operations are relatively expensive on GPUs, we can optimize this by computing the clamped probabilities once and then calculating both forward and reverse KL divergences directly.

            eps = 1e-6
            orig_dtype = old_log_prob.dtype
            p = torch.clamp(torch.exp(old_log_prob.float()), eps, 1.0 - eps)
            q = torch.clamp(torch.exp(rollout_log_prob.float()), eps, 1.0 - eps)
            kl_fwd = p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))
            kl_rev = q * torch.log(q / p) + (1 - q) * torch.log((1 - q) / (1 - p))
            per_token_stat = torch.maximum(kl_fwd, kl_rev).to(orig_dtype)

token_keep_bool = per_token_stat <= upper_value
elif option_name.startswith("seq_sum"):
if option_name.endswith("k1"):
if lower_log is None:
Expand Down Expand Up @@ -867,6 +910,8 @@ def compute_rollout_correction_and_rejection_mask(
response_mask=response_mask,
rollout_rs=rollout_rs,
rollout_rs_threshold=rollout_rs_threshold,
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
)
metrics.update(rs_metrics)

Expand Down
Loading