103 changes: 103 additions & 0 deletions routing_aware_replay/README.md
@@ -0,0 +1,103 @@
# Routing-Aware Replay

`routing_aware_replay` is a lightweight `verl-recipe` utility for studying
routing-aware replay policies in MoE RL post-training.

The recipe focuses on a narrow question: when router replay is used to stabilize
MoE RL, which experts or routes should be preserved, and how do different
choices compare under the same replay budget?

## Motivation

Router replay can reduce routing drift between training and inference, but a
uniform replay policy treats all selected routes as equally worth preserving.
For debugging and ablation, it is useful to separate:

- the effect of using replay at all;
- the effect of replay budget;
- the effect of choosing which experts to preserve;
- whether the replay mask is too restrictive or too weak.

This recipe provides CPU-testable utilities for Fisher-weighted replay masks,
budget-matched controls, and compact routing replay diagnostics.
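
As a rough illustration of what "budget-matched" means here, the standalone sketch below builds a score-driven mask and a score-agnostic control that preserve the same number of experts. It does not import this recipe; `top_budget_mask` and `evenly_spaced_mask` are hypothetical names, and the selection rules are assumptions modeled on the top-budget and `"uniform"` policies in this PR.

```python
def top_budget_mask(scores, budget):
    """Preserve the `budget` highest-scoring experts (ties go to the lower index)."""
    k = min(max(budget, 0), len(scores))
    ranked = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    keep = set(ranked[:k])
    return tuple(1.0 if i in keep else 0.0 for i in range(len(scores)))


def evenly_spaced_mask(num_experts, budget):
    """Score-agnostic control: spread `budget` preserved experts over the index range."""
    k = min(max(budget, 0), num_experts)
    keep = {i * num_experts // k for i in range(k)} if k else set()
    return tuple(1.0 if i in keep else 0.0 for i in range(num_experts))


scores = [0.05, 0.12, 0.81, 0.33, 1.20, 0.09, 0.74, 0.44]
fisher_mask = top_budget_mask(scores, budget=3)
control_mask = evenly_spaced_mask(len(scores), budget=3)

print(fisher_mask)   # experts 2, 4, 6 preserved
print(control_mask)  # experts 0, 2, 5 preserved
```

Both masks preserve three experts, so any difference between the two arms can be attributed to which experts were chosen rather than to how many.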

## Contents

```text
routing_aware_replay/
├── README.md
├── REQUIRED_VERL.txt
├── routing_aware_replay/
│   ├── fisher_mask.py
│   ├── replay_policy.py
│   ├── diagnostics.py
│   └── schema.py
├── examples/
│   └── synthetic_router_replay_demo.py
└── tests/
    ├── test_fisher_mask.py
    ├── test_budget_matched_replay.py
    └── test_diagnostics_schema.py
```

## Required `verl` version

See [`REQUIRED_VERL.txt`](REQUIRED_VERL.txt). This recipe is intended to be
self-contained and does not require changes to `verl` core.

## Quick start

From this recipe directory:

```bash
python examples/synthetic_router_replay_demo.py
python -m unittest discover -s tests
```

If `pytest` is available, the same tests can be collected with:

```bash
pytest tests
```

## Example

```python
from routing_aware_replay import (
    FisherMaskConfig,
    compute_fisher_weighted_replay_mask,
    make_budget_matched_mask,
    summarize_replay_mask,
)

fisher_scores = [0.05, 0.12, 0.81, 0.33, 1.20, 0.09, 0.74, 0.44]
config = FisherMaskConfig(target_budget=3)

fisher_mask = compute_fisher_weighted_replay_mask(fisher_scores, config)
uniform_mask = make_budget_matched_mask(
    num_experts=len(fisher_scores),
    budget=fisher_mask.effective_budget,
    policy="uniform",
)

print(summarize_replay_mask(fisher_mask).as_dict())
print(summarize_replay_mask(uniform_mask).as_dict())
```
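
The example above uses the top-budget mode. When no `target_budget` is set, the mask is instead built from thresholds on min-max-normalized scores, with a sigmoid for the band in between. The standalone sketch below mirrors that logic without importing the recipe. The function names and the parameter values (`tau=0.5`, `theta_low=0.2`, `theta_high=0.8`, `temperature=0.1`) are assumptions chosen for illustration; the actual defaults live in `schema.py`.

```python
import math


def minmax_normalize(scores):
    """Rescale scores to [0, 1]; constant inputs map to all zeros."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.0 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]


def stable_sigmoid(x):
    """Sigmoid that avoids overflow in math.exp for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def soft_threshold_mask(scores, tau, theta_low, theta_high, temperature):
    """Hard-preserve above theta_high, release below theta_low, soften in between."""
    mask = []
    for s in minmax_normalize(scores):
        if s >= theta_high:
            mask.append(1.0)
        elif s <= theta_low:
            mask.append(0.0)
        elif temperature > 0.0:
            mask.append(stable_sigmoid((s - tau) / temperature))
        else:
            mask.append(1.0 if s >= tau else 0.0)
    return mask


scores = [0.05, 0.12, 0.81, 0.33, 1.20, 0.09, 0.74, 0.44]
# Illustrative parameter values only, not the recipe's defaults.
mask = soft_threshold_mask(scores, tau=0.5, theta_low=0.2, theta_high=0.8, temperature=0.1)

# The same kind of counts that the diagnostics summary reports.
counts = {
    "hard": sum(1 for m in mask if m >= 1.0),
    "soft": sum(1 for m in mask if 0.0 < m < 1.0),
    "released": sum(1 for m in mask if m == 0.0),
}
print([round(m, 3) for m in mask])
print(counts)
```

With these values, one expert is hard-preserved, four receive soft weights, and three are released. A mask that is too restrictive or too weak shows up directly in these counts.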

## Non-goals

The initial version does not:

- modify `verl` trainer behavior;
- require GPU training to validate correctness;
- claim benchmark-leading results;
- provide a full reproduction package for any unpublished paper.

## Future work

If this recipe is useful to the community, later PRs can add:

- alignment with an upstream router replay output schema;
- small MoE training configs;
- richer diagnostic plots;
- a generic replay policy interface in `verl` core.
11 changes: 11 additions & 0 deletions routing_aware_replay/REQUIRED_VERL.txt
@@ -0,0 +1,11 @@
# routing_aware_replay — rolling; exact commits refreshed from this workspace
UPSTREAM=https://github.com/verl-project/verl.git
MODE=rolling
BRANCH=main
VERL_COMMIT=8a694930275061f52ebd538c906ef8819af56dbd
PIP_INSTALL=pip install verl@git+https://github.com/verl-project/verl.git@8a694930275061f52ebd538c906ef8819af56dbd
RECIPE_SUBMODULE_COMMIT=bab0fa6ca865097f1ae3dc3a517672e788464a0a
RECIPE_FOLDER_LAST_COMMIT=initial-public-contribution
NOTES=This recipe is self-contained and CPU-testable; no verl core changes are required for the initial utilities.
REFRESH=Recompute: git ls-remote https://github.com/verl-project/verl.git HEAD; git ls-remote https://github.com/verl-project/verl-recipe.git HEAD

48 changes: 48 additions & 0 deletions routing_aware_replay/examples/synthetic_router_replay_demo.py
@@ -0,0 +1,48 @@
"""Synthetic demo for routing-aware replay masks."""

from __future__ import annotations

import json
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from routing_aware_replay import (
    FisherMaskConfig,
    compute_fisher_weighted_replay_mask,
    make_budget_matched_mask,
    summarize_replay_mask,
)


def main() -> None:
    """Run a small CPU-only routing replay comparison."""

    fisher_scores = [0.05, 0.12, 0.81, 0.33, 1.20, 0.09, 0.74, 0.44]
    fisher_result = compute_fisher_weighted_replay_mask(
        fisher_scores,
        FisherMaskConfig(target_budget=3),
    )
    uniform_result = make_budget_matched_mask(
        num_experts=len(fisher_scores),
        budget=fisher_result.effective_budget,
        policy="uniform",
    )
    random_result = make_budget_matched_mask(
        num_experts=len(fisher_scores),
        budget=fisher_result.effective_budget,
        policy="random",
        seed=7,
    )

    payload = {
        "fisher_weighted": summarize_replay_mask(fisher_result).as_dict(),
        "uniform_budget_matched": summarize_replay_mask(uniform_result).as_dict(),
        "random_budget_matched": summarize_replay_mask(random_result).as_dict(),
    }
    print(json.dumps(payload, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()
19 changes: 19 additions & 0 deletions routing_aware_replay/routing_aware_replay/__init__.py
@@ -0,0 +1,19 @@
"""Routing-aware replay utilities for MoE RL recipes."""

from routing_aware_replay.diagnostics import summarize_replay_mask
from routing_aware_replay.fisher_mask import compute_fisher_weighted_replay_mask
from routing_aware_replay.replay_policy import make_budget_matched_mask
from routing_aware_replay.schema import (
    FisherMaskConfig,
    ReplayMaskResult,
    RoutingReplayDiagnostics,
)

__all__ = [
    "FisherMaskConfig",
    "ReplayMaskResult",
    "RoutingReplayDiagnostics",
    "compute_fisher_weighted_replay_mask",
    "make_budget_matched_mask",
    "summarize_replay_mask",
]
49 changes: 49 additions & 0 deletions routing_aware_replay/routing_aware_replay/diagnostics.py
@@ -0,0 +1,49 @@
"""Diagnostics for routing-aware replay masks."""

from __future__ import annotations

from routing_aware_replay.schema import ReplayMaskResult, RoutingReplayDiagnostics


def summarize_replay_mask(result: ReplayMaskResult) -> RoutingReplayDiagnostics:
    """Summarize replay mask behavior in a compact diagnostics object.

    Args:
        result: Replay mask result produced by a replay policy.

    Returns:
        RoutingReplayDiagnostics with budget and score statistics.
    """

    mask = result.mask
    if not mask:
        raise ValueError("result.mask must not be empty")

    hard_count = sum(1 for value in mask if value >= 1.0)
    soft_count = sum(1 for value in mask if 0.0 < value < 1.0)
    preserved_count = sum(1 for value in mask if value > 0.0)
    released_count = sum(1 for value in mask if value == 0.0)
    mask_mean = sum(mask) / len(mask)

    score_min = None
    score_max = None
    score_mean = None
    if result.scores:
        score_min = min(result.scores)
        score_max = max(result.scores)
        score_mean = sum(result.scores) / len(result.scores)

    return RoutingReplayDiagnostics(
        policy_name=result.policy_name,
        num_experts=len(mask),
        effective_replay_budget=preserved_count,
        hard_expert_count=hard_count,
        soft_expert_count=soft_count,
        preserved_count=preserved_count,
        released_count=released_count,
        mask_mean=mask_mean,
        score_min=score_min,
        score_max=score_max,
        score_mean=score_mean,
        metadata=dict(result.metadata),
    )
97 changes: 97 additions & 0 deletions routing_aware_replay/routing_aware_replay/fisher_mask.py
@@ -0,0 +1,97 @@
"""Fisher-weighted replay mask utilities."""

from __future__ import annotations

import math
from collections.abc import Iterable

from routing_aware_replay.schema import FisherMaskConfig, ReplayMaskResult


def _as_float_tuple(values: Iterable[float], name: str) -> tuple[float, ...]:
    result = tuple(float(value) for value in values)
    if not result:
        raise ValueError(f"{name} must not be empty")
    if any(not math.isfinite(value) for value in result):
        raise ValueError(f"{name} must contain only finite values")
    return result


def _minmax_normalize(scores: tuple[float, ...]) -> tuple[float, ...]:
    min_score = min(scores)
    max_score = max(scores)
    width = max_score - min_score
    if width == 0.0:
        return tuple(0.0 for _ in scores)
    return tuple((score - min_score) / width for score in scores)


def _sigmoid(value: float) -> float:
    if value >= 0:
        z = math.exp(-value)
        return 1.0 / (1.0 + z)
    z = math.exp(value)
    return z / (1.0 + z)


def _top_budget_mask(scores: tuple[float, ...], budget: int) -> tuple[float, ...]:
    num_experts = len(scores)
    clamped_budget = min(max(budget, 0), num_experts)
    ranked_indices = sorted(range(num_experts), key=lambda index: (-scores[index], index))
    selected = set(ranked_indices[:clamped_budget])
    return tuple(1.0 if index in selected else 0.0 for index in range(num_experts))


def compute_fisher_weighted_replay_mask(
    fisher_scores: Iterable[float],
    config: FisherMaskConfig | None = None,
) -> ReplayMaskResult:
    """Build an expert replay mask from Fisher/proxy importance scores.

    Args:
        fisher_scores: Per-expert Fisher trace or Fisher-like importance proxy.
            Larger values indicate experts whose routing behavior should be
            preserved more strongly.
        config: Optional mask configuration.

    Returns:
        ReplayMaskResult with one mask value per expert.

    Raises:
        ValueError: If scores are empty or contain non-finite values.
    """

    cfg = config or FisherMaskConfig()
    raw_scores = _as_float_tuple(fisher_scores, "fisher_scores")
    normalized_scores = _minmax_normalize(raw_scores)

    if cfg.target_budget is not None:
        mask = _top_budget_mask(normalized_scores, cfg.target_budget)
        selection_mode = "top_budget"
    else:
        mask_values: list[float] = []
        for score in normalized_scores:
            if score >= cfg.theta_high:
                mask_values.append(1.0)
            elif score <= cfg.theta_low:
                mask_values.append(0.0)
            elif cfg.soft_mask_temperature > 0.0:
                mask_values.append(_sigmoid((score - cfg.tau) / cfg.soft_mask_temperature))
            else:
                mask_values.append(1.0 if score >= cfg.tau else 0.0)
        mask = tuple(mask_values)
        selection_mode = "threshold"

    return ReplayMaskResult(
        policy_name=cfg.policy_name,
        mask=mask,
        scores=normalized_scores,
        metadata={
            "selection_mode": selection_mode,
            "target_budget": cfg.target_budget,
            "tau": cfg.tau,
            "theta_high": cfg.theta_high,
            "theta_low": cfg.theta_low,
            "soft_mask_temperature": cfg.soft_mask_temperature,
        },
    )
63 changes: 63 additions & 0 deletions routing_aware_replay/routing_aware_replay/replay_policy.py
@@ -0,0 +1,63 @@
"""Budget-matched replay policy baselines."""

from __future__ import annotations

import random

from routing_aware_replay.schema import ReplayMaskResult


def _validate_budget(num_experts: int, budget: int) -> int:
    if num_experts <= 0:
        raise ValueError("num_experts must be positive")
    if budget < 0:
        raise ValueError("budget must be non-negative")
    return min(budget, num_experts)


def _mask_from_indices(num_experts: int, selected: set[int]) -> tuple[float, ...]:
    return tuple(1.0 if index in selected else 0.0 for index in range(num_experts))


def make_budget_matched_mask(
    num_experts: int,
    budget: int,
    policy: str,
    seed: int = 0,
) -> ReplayMaskResult:
    """Create a replay mask with the same budget under a control policy.

    Args:
        num_experts: Number of experts covered by the mask.
        budget: Number of experts to preserve.
        policy: ``"uniform"`` or ``"random"``.
        seed: Random seed used only for the random control.

    Returns:
        ReplayMaskResult with exactly ``min(budget, num_experts)`` preserved
        experts.

    Raises:
        ValueError: If the policy or dimensions are invalid.
    """

    clamped_budget = _validate_budget(num_experts, budget)
    if policy == "uniform":
        # Spread the preserved experts evenly across the index range.
        selected = (
            {int(index * num_experts / clamped_budget) for index in range(clamped_budget)}
            if clamped_budget
            else set()
        )
    elif policy == "random":
        # Seeded sampling without replacement keeps the control reproducible.
        rng = random.Random(seed)
        selected = set(rng.sample(range(num_experts), clamped_budget))
    else:
        raise ValueError("policy must be 'uniform' or 'random'")

    return ReplayMaskResult(
        policy_name=f"{policy}_budget_matched",
        mask=_mask_from_indices(num_experts, selected),
        metadata={
            "budget": clamped_budget,
            "requested_budget": budget,
            "seed": seed if policy == "random" else None,
        },
    )