Skip to content

[rollout, algo] feat: add binary_kl (KPop) bidirectional KL rejection sampling#6800

Open
yan-sun-x wants to merge 1 commit into
verl-project:mainfrom
yan-sun-x:feat/kpop-binary-kl-rejection
Open

[rollout, algo] feat: add binary_kl (KPop) bidirectional KL rejection sampling#6800
yan-sun-x wants to merge 1 commit into
verl-project:mainfrom
yan-sun-x:feat/kpop-binary-kl-rejection

Conversation

@yan-sun-x

Copy link
Copy Markdown

What does this PR do?

Adds KPop as a new rejection-sampling option (rollout_rs="binary_kl") for off-policy /
rollout correction. KPop applies a hard trust region using the bidirectional Bernoulli KL
divergence
between the training policy and the rollout (behavior) policy: a token is kept only
when max(KL(train‖rollout), KL(rollout‖train)) <= phi.

This complements the existing IcePop support (the lower_upper IS-weight threshold, e.g.
decoupled_token_icepop), which already covers ratio-based double-sided masking. KPop adds a
KL-based criterion that is symmetric and bounded per token.

Inspired by the IcePop/KPop masking strategies from AReaL
(areal-project/AReaL#1405). Since IcePop already exists in verl, this PR ports only the missing
half (KPop) and wires it into verl's existing rejection-sampling framework rather than adding new
top-level switches.

Checklist Before Starting

Test

Added tests/trainer/ppo/test_binary_kl_rejection.py (CPU-runnable, mirrors the style of
tests/trainer/ppo/test_rollout_corr.py). 10 cases, all passing locally:

$ python -m pytest tests/trainer/ppo/test_binary_kl_rejection.py -v
...
================= 10 passed in 18.83s =================

Coverage:

  • compute_binary_kl_divergence: self-KL is exactly 0, matches a hand-computed reference value,
    and preserves input dtype (math runs in float32 internally).
  • NaN regression at the p = 1 boundary: when log_q == 0 the term log((1 - p)/(1 - q))
    would produce NaN in float32 without the upcast + eps clamp. There is an explicit test that
    asserts no NaN/Inf and non-negative KL for this case.
  • binary_kl is registered as a token-level rejection option.
  • Direct compute_rollout_rejection_mask: high-divergence tokens are rejected, matched tokens are
    kept, a loose threshold keeps everything unchanged, and missing log-probs raise a clear error.
  • End-to-end through compute_rollout_correction_and_rejection_mask (correct metric keys emitted)
    and the decoupled_token_kpop preset wiring.

The masking applies through the existing, already-tested rejection-sampling pathway, so behavior
outside the new binary_kl branch is unchanged. Happy to add a training-curve comparison
(KPop vs. no-correction vs. IcePop) on a small task if reviewers would like end-to-end evidence.

API and Usage Example

New rejection-sampling option binary_kl with a single upper bound phi:

from verl.trainer.config.algorithm import RolloutCorrectionConfig
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_rejection_mask

# Option 1: convenience preset (decoupled mode, token-level KPop)
cfg = RolloutCorrectionConfig.decoupled_token_kpop(phi=2.0)

# Option 2: configure directly
_, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask(
    old_log_prob=old_log_prob,          # training policy log-probs
    rollout_log_prob=rollout_log_prob,  # rollout / behavior policy log-probs
    response_mask=response_mask,
    rollout_is=None,
    rollout_rs="binary_kl",
    rollout_rs_threshold=2.0,           # phi
)
# Emits rollout_corr/rollout_rs_binary_kl_{mean,max,masked_fraction}, etc.

binary_kl can also be chained with other options, e.g.
rollout_rs="binary_kl,seq_max_k3", rollout_rs_threshold="2.0,3.0".

Design & Code Changes

Rather than introduce new top-level config flags, KPop reuses verl's existing
rejection-sampling framework in rollout_corr_helper.py. The only conceptually new piece is a
bidirectional Bernoulli-KL metric; the threshold parsing (single upper bound), metric emission,
and mask application all reuse existing code paths.

Specific changes:

  • verl/trainer/ppo/rollout_corr_helper.py
    • Add compute_binary_kl_divergence(log_p, log_q, eps=1e-6) — Bernoulli KL parameterized by
      log-probs; upcasts to float32 and clamps with eps to avoid NaN at the probability boundary,
      then casts back to the input dtype.
    • Register binary_kl in SUPPORTED_ROLLOUT_RS_OPTIONS and TOKEN_LEVEL_ROLLOUT_RS_OPTIONS.
    • Add a binary_kl branch in compute_rollout_rejection_mask that keeps tokens where
      max(KL_fwd, KL_rev) <= phi. Because binary KL needs both policies' raw log-probs (not just
      their ratio), two optional parameters old_log_prob / rollout_log_prob were added (default
      None; a clear ValueError is raised if binary_kl is selected without them).
    • Pass the raw log-probs through from compute_rollout_correction_and_rejection_mask.
  • verl/trainer/config/algorithm.py
    • Add the RolloutCorrectionConfig.decoupled_token_kpop(phi=2.0) preset, mirroring
      decoupled_token_icepop.
  • tests/trainer/ppo/test_binary_kl_rejection.py
    • New unit tests (see Test section).

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation.
  • Add unit or end-to-end test(s) to cover the code.
  • Once your PR is ready for CI, send a message in the ci-request channel.
  • N/A — not related to the recipe submodule.

References

@CLAassistant

CLAassistant commented Jun 21, 2026

Copy link
Copy Markdown

CLA assistant check
All committers have signed the CLA.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the binary_kl (KPop) rejection sampling option, which applies a hard trust region using the bidirectional Bernoulli KL divergence between the training and rollout policies. It includes the implementation of the divergence computation, configuration presets, integration into the rollout correction helper, and comprehensive unit tests. The feedback suggests optimizing the bidirectional KL computation in compute_rollout_rejection_mask to avoid redundant exponentiation and clamping of the log probabilities by computing them once and then calculating both forward and reverse KL divergences directly.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +360 to +362
kl_fwd = compute_binary_kl_divergence(old_log_prob, rollout_log_prob)
kl_rev = compute_binary_kl_divergence(rollout_log_prob, old_log_prob)
per_token_stat = torch.maximum(kl_fwd, kl_rev)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In bidirectional KL rejection sampling, calling compute_binary_kl_divergence twice results in redundant computations. Specifically, both old_log_prob and rollout_log_prob are exponentiated (torch.exp) and clamped (torch.clamp) twice. Since transcendental operations are relatively expensive on GPUs, we can optimize this by computing the clamped probabilities once and then calculating both forward and reverse KL divergences directly.

            eps = 1e-6
            orig_dtype = old_log_prob.dtype
            p = torch.clamp(torch.exp(old_log_prob.float()), eps, 1.0 - eps)
            q = torch.clamp(torch.exp(rollout_log_prob.float()), eps, 1.0 - eps)
            kl_fwd = p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))
            kl_rev = q * torch.log(q / p) + (1 - q) * torch.log((1 - q) / (1 - p))
            per_token_stat = torch.maximum(kl_fwd, kl_rev).to(orig_dtype)

@wuxibin89

Copy link
Copy Markdown
Collaborator

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants