[rollout, algo] feat: add binary_kl (KPop) bidirectional KL rejection sampling#6800
[rollout, algo] feat: add binary_kl (KPop) bidirectional KL rejection sampling#6800yan-sun-x wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the binary_kl (KPop) rejection sampling option, which applies a hard trust region using the bidirectional Bernoulli KL divergence between the training and rollout policies. It includes the implementation of the divergence computation, configuration presets, integration into the rollout correction helper, and comprehensive unit tests. The feedback suggests optimizing the bidirectional KL computation in compute_rollout_rejection_mask to avoid redundant exponentiation and clamping of the log probabilities by computing them once and then calculating both forward and reverse KL divergences directly.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| kl_fwd = compute_binary_kl_divergence(old_log_prob, rollout_log_prob) | ||
| kl_rev = compute_binary_kl_divergence(rollout_log_prob, old_log_prob) | ||
| per_token_stat = torch.maximum(kl_fwd, kl_rev) |
There was a problem hiding this comment.
In bidirectional KL rejection sampling, calling compute_binary_kl_divergence twice results in redundant computations. Specifically, both old_log_prob and rollout_log_prob are exponentiated (torch.exp) and clamped (torch.clamp) twice. Since transcendental operations are relatively expensive on GPUs, we can optimize this by computing the clamped probabilities once and then calculating both forward and reverse KL divergences directly.
eps = 1e-6
orig_dtype = old_log_prob.dtype
p = torch.clamp(torch.exp(old_log_prob.float()), eps, 1.0 - eps)
q = torch.clamp(torch.exp(rollout_log_prob.float()), eps, 1.0 - eps)
kl_fwd = p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))
kl_rev = q * torch.log(q / p) + (1 - q) * torch.log((1 - q) / (1 - p))
per_token_stat = torch.maximum(kl_fwd, kl_rev).to(orig_dtype)
What does this PR do?
Adds KPop as a new rejection-sampling option (
rollout_rs="binary_kl") for off-policy /rollout correction. KPop applies a hard trust region using the bidirectional Bernoulli KL
divergence between the training policy and the rollout (behavior) policy: a token is kept only
when
max(KL(train‖rollout), KL(rollout‖train)) <= phi.This complements the existing IcePop support (the
lower_upperIS-weight threshold, e.g.decoupled_token_icepop), which already covers ratio-based double-sided masking. KPop adds aKL-based criterion that is symmetric and bounded per token.
Inspired by the IcePop/KPop masking strategies from AReaL
(areal-project/AReaL#1405). Since IcePop already exists in verl, this PR ports only the missing
half (KPop) and wires it into verl's existing rejection-sampling framework rather than adding new
top-level switches.
Checklist Before Starting
https://github.com/verl-project/verl/pulls?q=is%3Apr+rejection+sampling+icepop
[{modules}] {type}: {description}.Test
Added
tests/trainer/ppo/test_binary_kl_rejection.py(CPU-runnable, mirrors the style oftests/trainer/ppo/test_rollout_corr.py). 10 cases, all passing locally:Coverage:
compute_binary_kl_divergence: self-KL is exactly 0, matches a hand-computed reference value,and preserves input dtype (math runs in float32 internally).
p = 1boundary: whenlog_q == 0the termlog((1 - p)/(1 - q))would produce NaN in float32 without the upcast +
epsclamp. There is an explicit test thatasserts no NaN/Inf and non-negative KL for this case.
binary_klis registered as a token-level rejection option.compute_rollout_rejection_mask: high-divergence tokens are rejected, matched tokens arekept, a loose threshold keeps everything unchanged, and missing log-probs raise a clear error.
compute_rollout_correction_and_rejection_mask(correct metric keys emitted)and the
decoupled_token_kpoppreset wiring.The masking applies through the existing, already-tested rejection-sampling pathway, so behavior
outside the new
binary_klbranch is unchanged. Happy to add a training-curve comparison(KPop vs. no-correction vs. IcePop) on a small task if reviewers would like end-to-end evidence.
API and Usage Example
New rejection-sampling option
binary_klwith a single upper boundphi:binary_klcan also be chained with other options, e.g.rollout_rs="binary_kl,seq_max_k3",rollout_rs_threshold="2.0,3.0".Design & Code Changes
Rather than introduce new top-level config flags, KPop reuses verl's existing
rejection-sampling framework in
rollout_corr_helper.py. The only conceptually new piece is abidirectional Bernoulli-KL metric; the threshold parsing (single upper bound), metric emission,
and mask application all reuse existing code paths.
Specific changes:
verl/trainer/ppo/rollout_corr_helper.pycompute_binary_kl_divergence(log_p, log_q, eps=1e-6)— Bernoulli KL parameterized bylog-probs; upcasts to float32 and clamps with
epsto avoid NaN at the probability boundary,then casts back to the input dtype.
binary_klinSUPPORTED_ROLLOUT_RS_OPTIONSandTOKEN_LEVEL_ROLLOUT_RS_OPTIONS.binary_klbranch incompute_rollout_rejection_maskthat keeps tokens wheremax(KL_fwd, KL_rev) <= phi. Because binary KL needs both policies' raw log-probs (not justtheir ratio), two optional parameters
old_log_prob/rollout_log_probwere added (defaultNone; a clearValueErroris raised ifbinary_klis selected without them).compute_rollout_correction_and_rejection_mask.verl/trainer/config/algorithm.pyRolloutCorrectionConfig.decoupled_token_kpop(phi=2.0)preset, mirroringdecoupled_token_icepop.tests/trainer/ppo/test_binary_kl_rejection.pyChecklist Before Submitting
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel.recipesubmodule.References
Every Step Evolves: Scaling Reinforcement Learning for Trillion-Scale Thinking Model,
arXiv:2510.18855. Technical blog: https://ringtech.notion.site/icepop
probability ratio. Technical blog: https://ringtech.notion.site/kpop
rejection-sampling option).