-
Notifications
You must be signed in to change notification settings - Fork 139
[recipe] add routing-aware replay utilities for MoE RL #114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
kaining-never-stop
wants to merge
2
commits into
verl-project:main
Choose a base branch
from
kaining-never-stop:recipe/routing-aware-replay
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| # Routing-Aware Replay | ||
|
|
||
| `routing_aware_replay` is a lightweight `verl-recipe` utility for studying | ||
| routing-aware replay policies in MoE RL post-training. | ||
|
|
||
| The recipe focuses on a narrow question: when router replay is used to stabilize | ||
| MoE RL, how can we compare which experts/routes should be preserved under the | ||
| same replay budget? | ||
|
|
||
| ## Motivation | ||
|
|
||
| Router replay can reduce routing drift between training and inference, but a | ||
| uniform replay policy may preserve all selected routes equally. For debugging | ||
| and ablation, it is useful to separate: | ||
|
|
||
| - the effect of using replay at all; | ||
| - the effect of replay budget; | ||
| - the effect of choosing which experts to preserve; | ||
| - whether the replay mask is too restrictive or too weak. | ||
|
|
||
| This recipe provides CPU-testable utilities for Fisher-weighted replay masks, | ||
| budget-matched controls, and compact routing replay diagnostics. | ||
|
|
||
| ## Contents | ||
|
|
||
| ```text | ||
| routing_aware_replay/ | ||
| ├── README.md | ||
| ├── REQUIRED_VERL.txt | ||
| ├── routing_aware_replay/ | ||
| │ ├── fisher_mask.py | ||
| │ ├── replay_policy.py | ||
| │ ├── diagnostics.py | ||
| │ └── schema.py | ||
| ├── examples/ | ||
| │ └── synthetic_router_replay_demo.py | ||
| └── tests/ | ||
| ├── test_fisher_mask.py | ||
| ├── test_budget_matched_replay.py | ||
| └── test_diagnostics_schema.py | ||
| ``` | ||
|
|
||
| ## Required `verl` version | ||
|
|
||
| See [`REQUIRED_VERL.txt`](REQUIRED_VERL.txt). This recipe is intended to be | ||
| self-contained and does not require changes to `verl` core. | ||
|
|
||
| ## Quick start | ||
|
|
||
| From this recipe directory: | ||
|
|
||
| ```bash | ||
| python examples/synthetic_router_replay_demo.py | ||
| python -m unittest discover -s tests | ||
| ``` | ||
|
|
||
| If `pytest` is available, the same tests can be collected with: | ||
|
|
||
| ```bash | ||
| pytest tests | ||
| ``` | ||
|
|
||
| ## Example | ||
|
|
||
| ```python | ||
| from routing_aware_replay import ( | ||
| FisherMaskConfig, | ||
| compute_fisher_weighted_replay_mask, | ||
| make_budget_matched_mask, | ||
| summarize_replay_mask, | ||
| ) | ||
|
|
||
| fisher_scores = [0.05, 0.12, 0.81, 0.33, 1.20, 0.09, 0.74, 0.44] | ||
| config = FisherMaskConfig(target_budget=3) | ||
|
|
||
| fisher_mask = compute_fisher_weighted_replay_mask(fisher_scores, config) | ||
| uniform_mask = make_budget_matched_mask( | ||
| num_experts=len(fisher_scores), | ||
| budget=fisher_mask.effective_budget, | ||
| policy="uniform", | ||
| ) | ||
|
|
||
| print(summarize_replay_mask(fisher_mask).as_dict()) | ||
| print(summarize_replay_mask(uniform_mask).as_dict()) | ||
| ``` | ||
|
|
||
| ## Non-goals | ||
|
|
||
| The initial version does not: | ||
|
|
||
| - modify `verl` trainer behavior; | ||
| - require GPU training to validate correctness; | ||
| - claim benchmark-leading results; | ||
| - provide a full reproduction package for any unpublished paper. | ||
|
|
||
| ## Future work | ||
|
|
||
| If this recipe is useful to the community, later PRs can add: | ||
|
|
||
| - alignment with an upstream router replay output schema; | ||
| - small MoE training configs; | ||
| - richer diagnostic plots; | ||
| - a generic replay policy interface in `verl` core. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| # routing_aware_replay — rolling; exact commits refreshed from this workspace | ||
| UPSTREAM=https://github.com/verl-project/verl.git | ||
| MODE=rolling | ||
| BRANCH=main | ||
| VERL_COMMIT=8a694930275061f52ebd538c906ef8819af56dbd | ||
| PIP_INSTALL=pip install verl@git+https://github.com/verl-project/verl.git@8a694930275061f52ebd538c906ef8819af56dbd | ||
| RECIPE_SUBMODULE_COMMIT=bab0fa6ca865097f1ae3dc3a517672e788464a0a | ||
| RECIPE_FOLDER_LAST_COMMIT=initial-public-contribution | ||
| NOTES=This recipe is self-contained and CPU-testable; no verl core changes are required for the initial utilities. | ||
| REFRESH=Recompute: git ls-remote https://github.com/verl-project/verl.git HEAD; git ls-remote https://github.com/verl-project/verl-recipe.git HEAD | ||
|
|
48 changes: 48 additions & 0 deletions
48
routing_aware_replay/examples/synthetic_router_replay_demo.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| """Synthetic demo for routing-aware replay masks.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| import sys | ||
| from pathlib import Path | ||
|
|
||
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | ||
|
|
||
| from routing_aware_replay import ( | ||
| FisherMaskConfig, | ||
| compute_fisher_weighted_replay_mask, | ||
| make_budget_matched_mask, | ||
| summarize_replay_mask, | ||
| ) | ||
|
|
||
|
|
||
| def main() -> None: | ||
| """Run a small CPU-only routing replay comparison.""" | ||
|
|
||
| fisher_scores = [0.05, 0.12, 0.81, 0.33, 1.20, 0.09, 0.74, 0.44] | ||
| fisher_result = compute_fisher_weighted_replay_mask( | ||
| fisher_scores, | ||
| FisherMaskConfig(target_budget=3), | ||
| ) | ||
| uniform_result = make_budget_matched_mask( | ||
| num_experts=len(fisher_scores), | ||
| budget=fisher_result.effective_budget, | ||
| policy="uniform", | ||
| ) | ||
| random_result = make_budget_matched_mask( | ||
| num_experts=len(fisher_scores), | ||
| budget=fisher_result.effective_budget, | ||
| policy="random", | ||
| seed=7, | ||
| ) | ||
|
|
||
| payload = { | ||
| "fisher_weighted": summarize_replay_mask(fisher_result).as_dict(), | ||
| "uniform_budget_matched": summarize_replay_mask(uniform_result).as_dict(), | ||
| "random_budget_matched": summarize_replay_mask(random_result).as_dict(), | ||
| } | ||
| print(json.dumps(payload, indent=2, sort_keys=True)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| """Routing-aware replay utilities for MoE RL recipes.""" | ||
|
|
||
| from routing_aware_replay.diagnostics import summarize_replay_mask | ||
| from routing_aware_replay.fisher_mask import compute_fisher_weighted_replay_mask | ||
| from routing_aware_replay.replay_policy import make_budget_matched_mask | ||
| from routing_aware_replay.schema import ( | ||
| FisherMaskConfig, | ||
| ReplayMaskResult, | ||
| RoutingReplayDiagnostics, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "FisherMaskConfig", | ||
| "ReplayMaskResult", | ||
| "RoutingReplayDiagnostics", | ||
| "compute_fisher_weighted_replay_mask", | ||
| "make_budget_matched_mask", | ||
| "summarize_replay_mask", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| """Diagnostics for routing-aware replay masks.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from routing_aware_replay.schema import ReplayMaskResult, RoutingReplayDiagnostics | ||
|
|
||
|
|
||
| def summarize_replay_mask(result: ReplayMaskResult) -> RoutingReplayDiagnostics: | ||
| """Summarize replay mask behavior in a compact diagnostics object. | ||
|
|
||
| Args: | ||
| result: Replay mask result produced by a replay policy. | ||
|
|
||
| Returns: | ||
| RoutingReplayDiagnostics with budget and score statistics. | ||
| """ | ||
|
|
||
| mask = result.mask | ||
| if not mask: | ||
| raise ValueError("result.mask must not be empty") | ||
|
|
||
| hard_count = sum(1 for value in mask if value >= 1.0) | ||
| soft_count = sum(1 for value in mask if 0.0 < value < 1.0) | ||
| preserved_count = sum(1 for value in mask if value > 0.0) | ||
| released_count = sum(1 for value in mask if value == 0.0) | ||
| mask_mean = sum(mask) / len(mask) | ||
|
|
||
| score_min = None | ||
| score_max = None | ||
| score_mean = None | ||
| if result.scores: | ||
| score_min = min(result.scores) | ||
| score_max = max(result.scores) | ||
| score_mean = sum(result.scores) / len(result.scores) | ||
|
|
||
| return RoutingReplayDiagnostics( | ||
| policy_name=result.policy_name, | ||
| num_experts=len(mask), | ||
| effective_replay_budget=preserved_count, | ||
| hard_expert_count=hard_count, | ||
| soft_expert_count=soft_count, | ||
| preserved_count=preserved_count, | ||
| released_count=released_count, | ||
| mask_mean=mask_mean, | ||
| score_min=score_min, | ||
| score_max=score_max, | ||
| score_mean=score_mean, | ||
| metadata=dict(result.metadata), | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| """Fisher-weighted replay mask utilities.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import math | ||
| from collections.abc import Iterable | ||
|
|
||
| from routing_aware_replay.schema import FisherMaskConfig, ReplayMaskResult | ||
|
|
||
|
|
||
| def _as_float_tuple(values: Iterable[float], name: str) -> tuple[float, ...]: | ||
| result = tuple(float(value) for value in values) | ||
| if not result: | ||
| raise ValueError(f"{name} must not be empty") | ||
| if any(not math.isfinite(value) for value in result): | ||
| raise ValueError(f"{name} must contain only finite values") | ||
| return result | ||
|
|
||
|
|
||
| def _minmax_normalize(scores: tuple[float, ...]) -> tuple[float, ...]: | ||
| min_score = min(scores) | ||
| max_score = max(scores) | ||
| width = max_score - min_score | ||
| if width == 0.0: | ||
| return tuple(0.0 for _ in scores) | ||
| return tuple((score - min_score) / width for score in scores) | ||
|
|
||
|
|
||
| def _sigmoid(value: float) -> float: | ||
| if value >= 0: | ||
| z = math.exp(-value) | ||
| return 1.0 / (1.0 + z) | ||
| z = math.exp(value) | ||
| return z / (1.0 + z) | ||
|
|
||
|
|
||
| def _top_budget_mask(scores: tuple[float, ...], budget: int) -> tuple[float, ...]: | ||
| num_experts = len(scores) | ||
| clamped_budget = min(max(budget, 0), num_experts) | ||
| ranked_indices = sorted(range(num_experts), key=lambda index: (-scores[index], index)) | ||
| selected = set(ranked_indices[:clamped_budget]) | ||
| return tuple(1.0 if index in selected else 0.0 for index in range(num_experts)) | ||
|
|
||
|
|
||
| def compute_fisher_weighted_replay_mask( | ||
| fisher_scores: Iterable[float], | ||
| config: FisherMaskConfig | None = None, | ||
| ) -> ReplayMaskResult: | ||
| """Build an expert replay mask from Fisher/proxy importance scores. | ||
|
|
||
| Args: | ||
| fisher_scores: Per-expert Fisher trace or Fisher-like importance proxy. | ||
| Larger values indicate experts whose routing behavior should be | ||
| preserved more strongly. | ||
| config: Optional mask configuration. | ||
|
|
||
| Returns: | ||
| ReplayMaskResult with one mask value per expert. | ||
|
|
||
| Raises: | ||
| ValueError: If scores are empty or contain non-finite values. | ||
| """ | ||
|
|
||
| cfg = config or FisherMaskConfig() | ||
| raw_scores = _as_float_tuple(fisher_scores, "fisher_scores") | ||
| normalized_scores = _minmax_normalize(raw_scores) | ||
|
|
||
| if cfg.target_budget is not None: | ||
| mask = _top_budget_mask(normalized_scores, cfg.target_budget) | ||
| selection_mode = "top_budget" | ||
| else: | ||
| mask_values: list[float] = [] | ||
| for score in normalized_scores: | ||
| if score >= cfg.theta_high: | ||
| mask_values.append(1.0) | ||
| elif score <= cfg.theta_low: | ||
| mask_values.append(0.0) | ||
| elif cfg.soft_mask_temperature > 0.0: | ||
| mask_values.append(_sigmoid((score - cfg.tau) / cfg.soft_mask_temperature)) | ||
| else: | ||
| mask_values.append(1.0 if score >= cfg.tau else 0.0) | ||
| mask = tuple(mask_values) | ||
| selection_mode = "threshold" | ||
|
|
||
| return ReplayMaskResult( | ||
| policy_name=cfg.policy_name, | ||
| mask=mask, | ||
| scores=normalized_scores, | ||
| metadata={ | ||
| "selection_mode": selection_mode, | ||
| "target_budget": cfg.target_budget, | ||
| "tau": cfg.tau, | ||
| "theta_high": cfg.theta_high, | ||
| "theta_low": cfg.theta_low, | ||
| "soft_mask_temperature": cfg.soft_mask_temperature, | ||
| }, | ||
| ) | ||
63 changes: 63 additions & 0 deletions
63
routing_aware_replay/routing_aware_replay/replay_policy.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| """Budget-matched replay policy baselines.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import random | ||
|
|
||
| from routing_aware_replay.schema import ReplayMaskResult | ||
|
|
||
|
|
||
| def _validate_budget(num_experts: int, budget: int) -> int: | ||
| if num_experts <= 0: | ||
| raise ValueError("num_experts must be positive") | ||
| if budget < 0: | ||
| raise ValueError("budget must be non-negative") | ||
| return min(budget, num_experts) | ||
|
|
||
|
|
||
| def _mask_from_indices(num_experts: int, selected: set[int]) -> tuple[float, ...]: | ||
| return tuple(1.0 if index in selected else 0.0 for index in range(num_experts)) | ||
|
|
||
|
|
||
| def make_budget_matched_mask( | ||
| num_experts: int, | ||
| budget: int, | ||
| policy: str, | ||
| seed: int = 0, | ||
| ) -> ReplayMaskResult: | ||
| """Create a replay mask with the same budget under a control policy. | ||
|
|
||
| Args: | ||
| num_experts: Number of experts covered by the mask. | ||
| budget: Number of experts to preserve. | ||
| policy: ``"uniform"`` or ``"random"``. | ||
| seed: Random seed used only for the random control. | ||
|
|
||
| Returns: | ||
| ReplayMaskResult with exactly ``min(budget, num_experts)`` preserved | ||
| experts. | ||
|
|
||
| Raises: | ||
| ValueError: If the policy or dimensions are invalid. | ||
| """ | ||
|
|
||
| clamped_budget = _validate_budget(num_experts, budget) | ||
| if policy == "uniform": | ||
| selected = ( | ||
| {int(index * num_experts / clamped_budget) for index in range(clamped_budget)} if clamped_budget else set() | ||
| ) | ||
| elif policy == "random": | ||
| rng = random.Random(seed) | ||
| selected = set(rng.sample(range(num_experts), clamped_budget)) | ||
| else: | ||
| raise ValueError("policy must be 'uniform' or 'random'") | ||
|
|
||
| return ReplayMaskResult( | ||
| policy_name=f"{policy}_budget_matched", | ||
| mask=_mask_from_indices(num_experts, selected), | ||
| metadata={ | ||
| "budget": clamped_budget, | ||
| "requested_budget": budget, | ||
| "seed": seed if policy == "random" else None, | ||
| }, | ||
| ) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.