diff --git a/routing_aware_replay/README.md b/routing_aware_replay/README.md new file mode 100644 index 00000000..08068c44 --- /dev/null +++ b/routing_aware_replay/README.md @@ -0,0 +1,103 @@ +# Routing-Aware Replay + +`routing_aware_replay` is a lightweight `verl-recipe` utility for studying +routing-aware replay policies in MoE RL post-training. + +The recipe focuses on a narrow question: when router replay is used to stabilize +MoE RL, how can we compare which experts/routes should be preserved under the +same replay budget? + +## Motivation + +Router replay can reduce routing drift between training and inference, but a +uniform replay policy may preserve all selected routes equally. For debugging +and ablation, it is useful to separate: + +- the effect of using replay at all; +- the effect of replay budget; +- the effect of choosing which experts to preserve; +- whether the replay mask is too restrictive or too weak. + +This recipe provides CPU-testable utilities for Fisher-weighted replay masks, +budget-matched controls, and compact routing replay diagnostics. + +## Contents + +```text +routing_aware_replay/ +├── README.md +├── REQUIRED_VERL.txt +├── routing_aware_replay/ +│ ├── fisher_mask.py +│ ├── replay_policy.py +│ ├── diagnostics.py +│ └── schema.py +├── examples/ +│ └── synthetic_router_replay_demo.py +└── tests/ + ├── test_fisher_mask.py + ├── test_budget_matched_replay.py + └── test_diagnostics_schema.py +``` + +## Required `verl` version + +See [`REQUIRED_VERL.txt`](REQUIRED_VERL.txt). This recipe is intended to be +self-contained and does not require changes to `verl` core. + +## Quick start + +From this recipe directory: + +```bash +python examples/synthetic_router_replay_demo.py +python -m unittest discover -s tests +``` + +If `pytest` is available, the same tests can be collected with: + +```bash +pytest tests +``` + +## Example + +```python +from routing_aware_replay import ( + FisherMaskConfig, + compute_fisher_weighted_replay_mask, + make_budget_matched_mask, + summarize_replay_mask, +) + +fisher_scores = [0.05, 0.12, 0.81, 0.33, 1.20, 0.09, 0.74, 0.44] +config = FisherMaskConfig(target_budget=3) + +fisher_mask = compute_fisher_weighted_replay_mask(fisher_scores, config) +uniform_mask = make_budget_matched_mask( + num_experts=len(fisher_scores), + budget=fisher_mask.effective_budget, + policy="uniform", +) + +print(summarize_replay_mask(fisher_mask).as_dict()) +print(summarize_replay_mask(uniform_mask).as_dict()) +``` + +## Non-goals + +The initial version does not: + +- modify `verl` trainer behavior; +- require GPU training to validate correctness; +- claim benchmark-leading results; +- provide a full reproduction package for any unpublished paper. + +## Future work + +If this recipe is useful to the community, later PRs can add: + +- alignment with an upstream router replay output schema; +- small MoE training configs; +- richer diagnostic plots; +- a generic replay policy interface in `verl` core. diff --git a/routing_aware_replay/REQUIRED_VERL.txt b/routing_aware_replay/REQUIRED_VERL.txt new file mode 100644 index 00000000..4a2fc982 --- /dev/null +++ b/routing_aware_replay/REQUIRED_VERL.txt @@ -0,0 +1,11 @@ +# routing_aware_replay — rolling; exact commits refreshed from this workspace +UPSTREAM=https://github.com/verl-project/verl.git +MODE=rolling +BRANCH=main +VERL_COMMIT=8a694930275061f52ebd538c906ef8819af56dbd +PIP_INSTALL=pip install verl@git+https://github.com/verl-project/verl.git@8a694930275061f52ebd538c906ef8819af56dbd +RECIPE_SUBMODULE_COMMIT=bab0fa6ca865097f1ae3dc3a517672e788464a0a +RECIPE_FOLDER_LAST_COMMIT=initial-public-contribution +NOTES=This recipe is self-contained and CPU-testable; no verl core changes are required for the initial utilities. +REFRESH=Recompute: git ls-remote https://github.com/verl-project/verl.git HEAD; git ls-remote https://github.com/verl-project/verl-recipe.git HEAD + diff --git a/routing_aware_replay/examples/synthetic_router_replay_demo.py b/routing_aware_replay/examples/synthetic_router_replay_demo.py new file mode 100644 index 00000000..cf0abf8f --- /dev/null +++ b/routing_aware_replay/examples/synthetic_router_replay_demo.py @@ -0,0 +1,48 @@ +"""Synthetic demo for routing-aware replay masks.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from routing_aware_replay import ( + FisherMaskConfig, + compute_fisher_weighted_replay_mask, + make_budget_matched_mask, + summarize_replay_mask, +) + + +def main() -> None: + """Run a small CPU-only routing replay comparison.""" + + fisher_scores = [0.05, 0.12, 0.81, 0.33, 1.20, 0.09, 0.74, 0.44] + fisher_result = compute_fisher_weighted_replay_mask( + fisher_scores, + FisherMaskConfig(target_budget=3), + ) + uniform_result = make_budget_matched_mask( + num_experts=len(fisher_scores), + budget=fisher_result.effective_budget, + policy="uniform", + ) + random_result = make_budget_matched_mask( + num_experts=len(fisher_scores), + budget=fisher_result.effective_budget, + policy="random", + seed=7, + ) + + payload = { + "fisher_weighted": summarize_replay_mask(fisher_result).as_dict(), + "uniform_budget_matched": summarize_replay_mask(uniform_result).as_dict(), + "random_budget_matched": summarize_replay_mask(random_result).as_dict(), + } + print(json.dumps(payload, indent=2, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/routing_aware_replay/routing_aware_replay/__init__.py b/routing_aware_replay/routing_aware_replay/__init__.py new file mode 100644 index 00000000..86dda09a --- /dev/null +++ b/routing_aware_replay/routing_aware_replay/__init__.py @@ -0,0 +1,19 @@ +"""Routing-aware replay utilities for MoE RL recipes.""" + +from routing_aware_replay.diagnostics import summarize_replay_mask +from routing_aware_replay.fisher_mask import compute_fisher_weighted_replay_mask +from routing_aware_replay.replay_policy import make_budget_matched_mask +from routing_aware_replay.schema import ( + FisherMaskConfig, + ReplayMaskResult, + RoutingReplayDiagnostics, +) + +__all__ = [ + "FisherMaskConfig", + "ReplayMaskResult", + "RoutingReplayDiagnostics", + "compute_fisher_weighted_replay_mask", + "make_budget_matched_mask", + "summarize_replay_mask", +] diff --git a/routing_aware_replay/routing_aware_replay/diagnostics.py b/routing_aware_replay/routing_aware_replay/diagnostics.py new file mode 100644 index 00000000..1f624d7d --- /dev/null +++ b/routing_aware_replay/routing_aware_replay/diagnostics.py @@ -0,0 +1,49 @@ +"""Diagnostics for routing-aware replay masks.""" + +from __future__ import annotations + +from routing_aware_replay.schema import ReplayMaskResult, RoutingReplayDiagnostics + + +def summarize_replay_mask(result: ReplayMaskResult) -> RoutingReplayDiagnostics: + """Summarize replay mask behavior in a compact diagnostics object. + + Args: + result: Replay mask result produced by a replay policy. + + Returns: + RoutingReplayDiagnostics with budget and score statistics. + """ + + mask = result.mask + if not mask: + raise ValueError("result.mask must not be empty") + + hard_count = sum(1 for value in mask if value >= 1.0) + soft_count = sum(1 for value in mask if 0.0 < value < 1.0) + preserved_count = sum(1 for value in mask if value > 0.0) + released_count = sum(1 for value in mask if value == 0.0) + mask_mean = sum(mask) / len(mask) + + score_min = None + score_max = None + score_mean = None + if result.scores: + score_min = min(result.scores) + score_max = max(result.scores) + score_mean = sum(result.scores) / len(result.scores) + + return RoutingReplayDiagnostics( + policy_name=result.policy_name, + num_experts=len(mask), + effective_replay_budget=preserved_count, + hard_expert_count=hard_count, + soft_expert_count=soft_count, + preserved_count=preserved_count, + released_count=released_count, + mask_mean=mask_mean, + score_min=score_min, + score_max=score_max, + score_mean=score_mean, + metadata=dict(result.metadata), + ) diff --git a/routing_aware_replay/routing_aware_replay/fisher_mask.py b/routing_aware_replay/routing_aware_replay/fisher_mask.py new file mode 100644 index 00000000..870d1220 --- /dev/null +++ b/routing_aware_replay/routing_aware_replay/fisher_mask.py @@ -0,0 +1,97 @@ +"""Fisher-weighted replay mask utilities.""" + +from __future__ import annotations + +import math +from collections.abc import Iterable + +from routing_aware_replay.schema import FisherMaskConfig, ReplayMaskResult + + +def _as_float_tuple(values: Iterable[float], name: str) -> tuple[float, ...]: + result = tuple(float(value) for value in values) + if not result: + raise ValueError(f"{name} must not be empty") + if any(not math.isfinite(value) for value in result): + raise ValueError(f"{name} must contain only finite values") + return result + + +def _minmax_normalize(scores: tuple[float, ...]) -> tuple[float, ...]: + min_score = min(scores) + max_score = max(scores) + width = max_score - min_score + if width == 0.0: + return tuple(1.0 if max_score > 0.0 else 0.0 for _ in scores) + return tuple((score - min_score) / width for score in scores) + + +def _sigmoid(value: float) -> float: + if value >= 0: + z = math.exp(-value) + return 1.0 / (1.0 + z) + z = math.exp(value) + return z / (1.0 + z) + + +def _top_budget_mask(scores: tuple[float, ...], budget: int) -> tuple[float, ...]: + num_experts = len(scores) + clamped_budget = min(max(budget, 0), num_experts) + ranked_indices = sorted(range(num_experts), key=lambda index: (-scores[index], index)) + selected = set(ranked_indices[:clamped_budget]) + return tuple(1.0 if index in selected else 0.0 for index in range(num_experts)) + + +def compute_fisher_weighted_replay_mask( + fisher_scores: Iterable[float], + config: FisherMaskConfig | None = None, +) -> ReplayMaskResult: + """Build an expert replay mask from Fisher/proxy importance scores. + + Args: + fisher_scores: Per-expert Fisher trace or Fisher-like importance proxy. + Larger values indicate experts whose routing behavior should be + preserved more strongly. + config: Optional mask configuration. + + Returns: + ReplayMaskResult with one mask value per expert. + + Raises: + ValueError: If scores are empty or contain non-finite values. + """ + + cfg = config or FisherMaskConfig() + raw_scores = _as_float_tuple(fisher_scores, "fisher_scores") + normalized_scores = _minmax_normalize(raw_scores) + + if cfg.target_budget is not None: + mask = _top_budget_mask(normalized_scores, cfg.target_budget) + selection_mode = "top_budget" + else: + mask_values: list[float] = [] + for score in normalized_scores: + if score >= cfg.theta_high: + mask_values.append(1.0) + elif score <= cfg.theta_low: + mask_values.append(0.0) + elif cfg.soft_mask_temperature > 0.0: + mask_values.append(_sigmoid((score - cfg.tau) / cfg.soft_mask_temperature)) + else: + mask_values.append(1.0 if score >= cfg.tau else 0.0) + mask = tuple(mask_values) + selection_mode = "threshold" + + return ReplayMaskResult( + policy_name=cfg.policy_name, + mask=mask, + scores=normalized_scores, + metadata={ + "selection_mode": selection_mode, + "target_budget": cfg.target_budget, + "tau": cfg.tau, + "theta_high": cfg.theta_high, + "theta_low": cfg.theta_low, + "soft_mask_temperature": cfg.soft_mask_temperature, + }, + ) diff --git a/routing_aware_replay/routing_aware_replay/replay_policy.py b/routing_aware_replay/routing_aware_replay/replay_policy.py new file mode 100644 index 00000000..de4af23e --- /dev/null +++ b/routing_aware_replay/routing_aware_replay/replay_policy.py @@ -0,0 +1,63 @@ +"""Budget-matched replay policy baselines.""" + +from __future__ import annotations + +import random + +from routing_aware_replay.schema import ReplayMaskResult + + +def _validate_budget(num_experts: int, budget: int) -> int: + if num_experts <= 0: + raise ValueError("num_experts must be positive") + if budget < 0: + raise ValueError("budget must be non-negative") + return min(budget, num_experts) + + +def _mask_from_indices(num_experts: int, selected: set[int]) -> tuple[float, ...]: + return tuple(1.0 if index in selected else 0.0 for index in range(num_experts)) + + +def make_budget_matched_mask( + num_experts: int, + budget: int, + policy: str, + seed: int = 0, +) -> ReplayMaskResult: + """Create a replay mask with the same budget under a control policy. + + Args: + num_experts: Number of experts covered by the mask. + budget: Number of experts to preserve. + policy: ``"uniform"`` or ``"random"``. + seed: Random seed used only for the random control. + + Returns: + ReplayMaskResult with exactly ``min(budget, num_experts)`` preserved + experts. + + Raises: + ValueError: If the policy or dimensions are invalid. + """ + + clamped_budget = _validate_budget(num_experts, budget) + if policy == "uniform": + selected = ( + {int(index * num_experts / clamped_budget) for index in range(clamped_budget)} if clamped_budget else set() + ) + elif policy == "random": + rng = random.Random(seed) + selected = set(rng.sample(range(num_experts), clamped_budget)) + else: + raise ValueError("policy must be 'uniform' or 'random'") + + return ReplayMaskResult( + policy_name=f"{policy}_budget_matched", + mask=_mask_from_indices(num_experts, selected), + metadata={ + "budget": clamped_budget, + "requested_budget": budget, + "seed": seed if policy == "random" else None, + }, + ) diff --git a/routing_aware_replay/routing_aware_replay/schema.py b/routing_aware_replay/routing_aware_replay/schema.py new file mode 100644 index 00000000..1d452401 --- /dev/null +++ b/routing_aware_replay/routing_aware_replay/schema.py @@ -0,0 +1,114 @@ +"""Shared schema for routing-aware replay utilities.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class FisherMaskConfig: + """Configuration for Fisher-weighted replay mask construction. + + Args: + target_budget: Optional number of experts to preserve. When set, the + highest-scoring experts are selected and threshold fields are not + used for hard selection. + tau: Center point used for threshold or soft-mask interpolation. + theta_high: Normalized score above which an expert is always preserved. + theta_low: Normalized score below which an expert is always released. + soft_mask_temperature: If positive, experts between ``theta_low`` and + ``theta_high`` receive a sigmoid soft mask around ``tau``. + policy_name: Human-readable policy identifier stored in results. + """ + + target_budget: int | None = None + tau: float = 0.5 + theta_high: float = 0.7 + theta_low: float = 0.2 + soft_mask_temperature: float = 0.0 + policy_name: str = "fisher_weighted" + + def __post_init__(self) -> None: + if self.target_budget is not None and self.target_budget < 0: + raise ValueError("target_budget must be non-negative or None") + if self.theta_low > self.theta_high: + raise ValueError("theta_low must be <= theta_high") + if not (self.theta_low <= self.tau <= self.theta_high): + raise ValueError("tau must be between theta_low and theta_high") + for field_name in ("tau", "theta_high", "theta_low"): + value = getattr(self, field_name) + if value < 0.0 or value > 1.0: + raise ValueError(f"{field_name} must be in [0, 1]") + if self.soft_mask_temperature < 0.0: + raise ValueError("soft_mask_temperature must be non-negative") + + +@dataclass(frozen=True) +class ReplayMaskResult: + """Result returned by replay mask builders.""" + + policy_name: str + mask: tuple[float, ...] + scores: tuple[float, ...] = () + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def num_experts(self) -> int: + """Number of experts covered by the mask.""" + + return len(self.mask) + + @property + def effective_budget(self) -> int: + """Number of experts with a non-zero replay mask.""" + + return sum(1 for value in self.mask if value > 0.0) + + def as_dict(self) -> dict[str, Any]: + """Return a JSON-serializable representation.""" + + return { + "policy_name": self.policy_name, + "mask": list(self.mask), + "scores": list(self.scores), + "num_experts": self.num_experts, + "effective_budget": self.effective_budget, + "metadata": dict(self.metadata), + } + + +@dataclass(frozen=True) +class RoutingReplayDiagnostics: + """Compact diagnostics for a replay mask.""" + + policy_name: str + num_experts: int + effective_replay_budget: int + hard_expert_count: int + soft_expert_count: int + preserved_count: int + released_count: int + mask_mean: float + score_min: float | None = None + score_max: float | None = None + score_mean: float | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def as_dict(self) -> dict[str, Any]: + """Return a JSON-serializable representation.""" + + return { + "policy_name": self.policy_name, + "num_experts": self.num_experts, + "effective_replay_budget": self.effective_replay_budget, + "hard_expert_count": self.hard_expert_count, + "soft_expert_count": self.soft_expert_count, + "preserved_count": self.preserved_count, + "released_count": self.released_count, + "mask_mean": self.mask_mean, + "score_min": self.score_min, + "score_max": self.score_max, + "score_mean": self.score_mean, + "metadata": dict(self.metadata), + } diff --git a/routing_aware_replay/tests/conftest.py b/routing_aware_replay/tests/conftest.py new file mode 100644 index 00000000..b4c34994 --- /dev/null +++ b/routing_aware_replay/tests/conftest.py @@ -0,0 +1,8 @@ +"""Test path setup for running this recipe inside verl-recipe.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) diff --git a/routing_aware_replay/tests/test_budget_matched_replay.py b/routing_aware_replay/tests/test_budget_matched_replay.py new file mode 100644 index 00000000..6ebae74e --- /dev/null +++ b/routing_aware_replay/tests/test_budget_matched_replay.py @@ -0,0 +1,46 @@ +"""Tests for budget-matched replay controls.""" + +from __future__ import annotations + +import unittest + +from routing_aware_replay import make_budget_matched_mask + + +class BudgetMatchedReplayTest(unittest.TestCase): + def test_uniform_mask_matches_budget(self) -> None: + result = make_budget_matched_mask(num_experts=8, budget=3, policy="uniform") + + self.assertEqual(result.effective_budget, 3) + self.assertEqual(result.mask, (1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0)) + + def test_random_mask_is_deterministic_with_seed(self) -> None: + first = make_budget_matched_mask( + num_experts=8, + budget=3, + policy="random", + seed=11, + ) + second = make_budget_matched_mask( + num_experts=8, + budget=3, + policy="random", + seed=11, + ) + + self.assertEqual(first.mask, second.mask) + self.assertEqual(first.effective_budget, 3) + + def test_budget_is_clamped_to_num_experts(self) -> None: + result = make_budget_matched_mask(num_experts=4, budget=10, policy="uniform") + + self.assertEqual(result.effective_budget, 4) + self.assertEqual(result.mask, (1.0, 1.0, 1.0, 1.0)) + + def test_invalid_policy_raises(self) -> None: + with self.assertRaises(ValueError): + make_budget_matched_mask(num_experts=4, budget=1, policy="unknown") + + +if __name__ == "__main__": + unittest.main() diff --git a/routing_aware_replay/tests/test_diagnostics_schema.py b/routing_aware_replay/tests/test_diagnostics_schema.py new file mode 100644 index 00000000..735768ee --- /dev/null +++ b/routing_aware_replay/tests/test_diagnostics_schema.py @@ -0,0 +1,32 @@ +"""Tests for routing replay diagnostics.""" + +from __future__ import annotations + +import json +import unittest + +from routing_aware_replay import ( + FisherMaskConfig, + compute_fisher_weighted_replay_mask, + summarize_replay_mask, +) + + +class DiagnosticsSchemaTest(unittest.TestCase): + def test_diagnostics_are_json_serializable(self) -> None: + result = compute_fisher_weighted_replay_mask( + [0.1, 0.3, 0.9, 0.2], + FisherMaskConfig(target_budget=2), + ) + diagnostics = summarize_replay_mask(result).as_dict() + + encoded = json.dumps(diagnostics, sort_keys=True) + self.assertIn("effective_replay_budget", encoded) + self.assertEqual(diagnostics["effective_replay_budget"], 2) + self.assertEqual(diagnostics["preserved_count"], 2) + self.assertEqual(diagnostics["released_count"], 2) + self.assertEqual(diagnostics["policy_name"], "fisher_weighted") + + +if __name__ == "__main__": + unittest.main() diff --git a/routing_aware_replay/tests/test_fisher_mask.py b/routing_aware_replay/tests/test_fisher_mask.py new file mode 100644 index 00000000..5a5f6fee --- /dev/null +++ b/routing_aware_replay/tests/test_fisher_mask.py @@ -0,0 +1,72 @@ +"""Tests for Fisher-weighted replay masks.""" + +from __future__ import annotations + +import unittest + +from routing_aware_replay import ( + FisherMaskConfig, + compute_fisher_weighted_replay_mask, +) + + +class FisherMaskTest(unittest.TestCase): + def test_target_budget_selects_highest_scores(self) -> None: + result = compute_fisher_weighted_replay_mask( + [0.1, 0.9, 0.4, 1.2], + FisherMaskConfig(target_budget=2), + ) + + self.assertEqual(result.mask, (0.0, 1.0, 0.0, 1.0)) + self.assertEqual(result.effective_budget, 2) + self.assertEqual(result.num_experts, 4) + + def test_fixed_scores_are_deterministic(self) -> None: + scores = [0.2, 0.5, 0.5, 0.1] + config = FisherMaskConfig(target_budget=2) + + first = compute_fisher_weighted_replay_mask(scores, config) + second = compute_fisher_weighted_replay_mask(scores, config) + + self.assertEqual(first.mask, second.mask) + self.assertEqual(first.mask, (0.0, 1.0, 1.0, 0.0)) + + def test_threshold_mode_produces_soft_values(self) -> None: + result = compute_fisher_weighted_replay_mask( + [0.0, 0.4, 0.8, 1.0], + FisherMaskConfig( + theta_low=0.1, + theta_high=0.9, + tau=0.5, + soft_mask_temperature=0.2, + ), + ) + + self.assertEqual(result.mask[0], 0.0) + self.assertEqual(result.mask[-1], 1.0) + self.assertGreater(result.mask[1], 0.0) + self.assertLess(result.mask[1], 1.0) + + def test_invalid_scores_raise(self) -> None: + with self.assertRaises(ValueError): + compute_fisher_weighted_replay_mask([]) + + def test_identical_positive_scores_are_preserved(self) -> None: + result = compute_fisher_weighted_replay_mask([1.5, 1.5, 1.5]) + + self.assertEqual(result.mask, (1.0, 1.0, 1.0)) + self.assertEqual(result.scores, (1.0, 1.0, 1.0)) + + def test_identical_zero_scores_are_released(self) -> None: + result = compute_fisher_weighted_replay_mask([0.0, 0.0, 0.0]) + + self.assertEqual(result.mask, (0.0, 0.0, 0.0)) + self.assertEqual(result.scores, (0.0, 0.0, 0.0)) + + def test_tau_must_be_between_thresholds(self) -> None: + with self.assertRaises(ValueError): + FisherMaskConfig(theta_low=0.2, theta_high=0.7, tau=0.8) + + +if __name__ == "__main__": + unittest.main()