From 422ffe42ab3c44cae13e0107f768f3ddf762d246 Mon Sep 17 00:00:00 2001 From: Harish Krishnamoorthy Murali Date: Fri, 5 Jun 2026 14:04:42 +0000 Subject: [PATCH 1/2] feat(recipe): add context-management agent loops Port the agentic context-management code from verl#5636 ([algo] feat: supporting agentic rl with context management; issue #5375) into a self-contained recipe, as the verl maintainers requested. Contents: the ContextManager abstraction with sliding-window and summarizer implementations, the naive_summarizer_agent and tool_sliding_window_agent loops, the design doc, a runnable GRPO example, and the original CPU unit tests (19 passing against verl main 9c38b8bb). Changes vs #5636: relocate intra-package imports to recipe.context_management.* and fix one drifted core import (verl.tools.utils.tool_registry -> verl.tools.tool_registry) for current verl. AI assistance was used. Co-Authored-By: Claude Opus 4.8 Co-authored-by: Zhentao Fan --- context_management/README.md | 62 ++ context_management/REQUIRED_VERL.txt | 11 + context_management/__init__.py | 13 + .../agent_loop_with_context_management.py | 483 ++++++++++++++++ context_management/context_manager.py | 288 ++++++++++ context_management/context_manager_plugin.md | 295 ++++++++++ context_management/example/README.md | 36 ++ context_management/example/agent.yaml | 16 + context_management/example/run_summarizer.sh | 36 ++ ...test_agent_loop_with_context_management.py | 528 ++++++++++++++++++ context_management/test_context_manager.py | 337 +++++++++++ 11 files changed, 2105 insertions(+) create mode 100644 context_management/README.md create mode 100644 context_management/REQUIRED_VERL.txt create mode 100644 context_management/__init__.py create mode 100644 context_management/agent_loop_with_context_management.py create mode 100644 context_management/context_manager.py create mode 100644 context_management/context_manager_plugin.md create mode 100644 context_management/example/README.md create mode 100644 
context_management/example/agent.yaml create mode 100644 context_management/example/run_summarizer.sh create mode 100644 context_management/test_agent_loop_with_context_management.py create mode 100644 context_management/test_context_manager.py diff --git a/context_management/README.md b/context_management/README.md new file mode 100644 index 00000000..e51b0333 --- /dev/null +++ b/context_management/README.md @@ -0,0 +1,62 @@ +# Context-management agent loops + +Plug-in **context management** for verl agent loops: keep multi-turn / long-horizon rollouts within +the model's context window by compressing the trajectory on the fly, instead of truncating or +failing once the window is exceeded. + +This recipe provides two ready-to-use agent loops and the `ContextManager` abstraction they share: + +| Agent loop (`name`) | Class | Strategy | +|---|---|---| +| `naive_summarizer_agent` | `SummarizerAgentLoop` | When the model emits a `<summary>...</summary>` block, replace the history with `(initial prompt + summary)` and continue. | +| `tool_sliding_window_agent` | `ToolSlidingWindowAgentLoop` | Keep a sliding window over tool observations: once enough have accumulated, replace all but the most recent ones with a short placeholder (`[Compressed]` by default). | + +Both subclass `AgentLoopWithContextManagement`, which drives a generic +`generate → check_and_compress → continue` loop around any `ContextManager` +(`SummarizerContextManager`, `SlidingWindowContextManager`, or your own). + +## Background + +This code was originally proposed for verl core in +[verl-project/verl#5636](https://github.com/verl-project/verl/pull/5636) +("[algo] feat: supporting agentic rl with context management", see issue +[#5375](https://github.com/verl-project/verl/issues/5375)). At the maintainers' request it now lives +here as a self-contained recipe rather than in `verl/experimental/agent_loop/`, so it can evolve +independently of the core library. The multi-trajectory / session-level GRPO training support that +complements it lands separately in core (see verl#5401, #5969). 
+ +## Layout + +``` +context_management/ + context_manager.py # ContextManager + Sliding-window / Summarizer implementations + agent_loop_with_context_management.py # AgentLoopWithContextManagement + the two agent loops + context_manager_plugin.md # design notes / how to write a custom ContextManager + test_context_manager.py # CPU unit tests + test_agent_loop_with_context_management.py + example/ # runnable GRPO example wiring the summarizer loop +``` + +## Usage + +The loops register themselves under the `name`s above. Point verl at this recipe's agent-loop config +and select a loop: + +```bash +actor_rollout_ref.rollout.agent.agent_loop_config_path=recipe/context_management/example/agent.yaml +actor_rollout_ref.rollout.agent.default_agent_loop=naive_summarizer_agent +``` + +See [`example/`](example/) for a full run script, and +[`context_manager_plugin.md`](context_manager_plugin.md) for writing your own `ContextManager`. + +## Required verl version + +See [`REQUIRED_VERL.txt`](REQUIRED_VERL.txt) for the upstream repo and the pinned core-library commit. + +## Tests + +```bash +pytest recipe/context_management/test_context_manager.py +pytest recipe/context_management/test_agent_loop_with_context_management.py +``` diff --git a/context_management/REQUIRED_VERL.txt b/context_management/REQUIRED_VERL.txt new file mode 100644 index 00000000..8ce8556d --- /dev/null +++ b/context_management/REQUIRED_VERL.txt @@ -0,0 +1,11 @@ +# context_management — rolling; refresh the commit against your verl checkout before publishing +UPSTREAM=https://github.com/verl-project/verl.git +MODE=rolling +BRANCH=main +# Core-library commit this recipe was developed/tested against. Refresh before opening the PR. 
+VERL_COMMIT=9c38b8bb1876a81273d76de3e79328b2dd2b7b32 +PIP_INSTALL=pip install verl@git+https://github.com/verl-project/verl.git@9c38b8bb1876a81273d76de3e79328b2dd2b7b32 +GIT_SETUP=git clone https://github.com/verl-project/verl.git && cd verl && git checkout 9c38b8bb1876a81273d76de3e79328b2dd2b7b32 && git submodule update --init --recursive recipe +RECIPE_FOLDER=context_management +NOTES=Depends only on stable verl core APIs: verl.experimental.agent_loop.agent_loop (AgentLoopBase, register, AgentLoopOutput, AgentLoopMetrics), verl.tools, verl.utils.chat_template, verl.utils.tokenizer, verl.workers.rollout.replica.TokenOutput. No core code changes are required to use this recipe. +REFRESH=Recompute VERL_COMMIT: (cd verl && git rev-parse HEAD). Re-run the tests under recipe/context_management/ after bumping. diff --git a/context_management/__init__.py b/context_management/__init__.py new file mode 100644 index 00000000..1ce90c5e --- /dev/null +++ b/context_management/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/context_management/agent_loop_with_context_management.py b/context_management/agent_loop_with_context_management.py new file mode 100644 index 00000000..41e881bb --- /dev/null +++ b/context_management/agent_loop_with_context_management.py @@ -0,0 +1,483 @@ +# Copyright 2025 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import json +import logging +import os +from abc import ABC, abstractmethod +from typing import Any, Optional +from uuid import uuid4 + +from recipe.context_management.context_manager import ( + ContextManager, + ContextState, + SlidingWindowContextManager, + SummarizerContextManager, +) + +from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopMetrics, AgentLoopOutput, register +from verl.experimental.agent_loop.tool_parser import ToolParser +from verl.experimental.agent_loop.utils import build_gpt_oss_tool_response_text +from verl.tools.schemas import ToolResponse +from verl.tools.tool_registry import initialize_tools_from_config +from verl.utils.profiler import simple_timer +from verl.workers.rollout.replica import TokenOutput + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class AgentLoopWithContextManagement(AgentLoopBase, ABC): + """Abstract base class for custom agent loops with pluggable context management.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.response_length = self.rollout_config.response_length + self.context_manager: Optional[ContextManager] = None + + def _build_output_from_state(self, state: ContextState) -> AgentLoopOutput: + response_length = len(state.response_mask) + prompt_ids = state.trajectory_ids[:-response_length] if response_length > 0 
else list(state.trajectory_ids) + response_ids = state.trajectory_ids[-response_length:] if response_length > 0 else [] + + output = AgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[: self.response_length], + response_mask=state.response_mask[: self.response_length], + response_logprobs=state.response_logprobs[: self.response_length] if state.response_logprobs else None, + routed_experts=state.routed_experts, + multi_modal_data=state.multi_modal_data or None, + reward_score=state.reward_score, + num_turns=state.num_turns, + metrics=state.metrics, + extra_fields=dict(state.extra_fields), + ) + output.extra_fields.update({"turn_scores": [], "tool_rewards": []}) + return output + + async def _generate_next_state( + self, + *, + state: ContextState, + request_id: str, + sampling_params: dict[str, Any], + image_data=None, + video_data=None, + accumulate_metrics: bool = True, + preserve_extra_fields: bool = True, + preserve_routed_experts: bool = True, + ) -> tuple[ContextState, list[int]]: + """Call the LLM once and append assistant tokens to the context state. + + Returns the updated state and the raw assistant response ids. Tool loops pass + those ids to the tool parser, while non-tool loops can ignore them. 
+ """ + metrics = state.metrics.model_dump() if accumulate_metrics else {} + prompt_ids = state.trajectory_ids + + with simple_timer("generate_sequences", metrics): + output: TokenOutput = await self.server_manager.generate( + request_id=request_id, + prompt_ids=prompt_ids, + sampling_params=sampling_params, + image_data=image_data, + video_data=video_data, + ) + + num_preempted = output.num_preempted if output.num_preempted is not None else -1 + if metrics.get("num_preempted", -1) < 0: + metrics["num_preempted"] = num_preempted + elif num_preempted > 0: + metrics["num_preempted"] += num_preempted + + response_text = self.tokenizer.decode(output.token_ids, skip_special_tokens=True) + messages = [dict(message) for message in state.messages] + messages.append({"role": "assistant", "content": response_text}) + + response_mask = list(state.response_mask) + [1] * len(output.token_ids) + if state.response_logprobs or output.log_probs: + prefix_logprobs = ( + list(state.response_logprobs) if state.response_logprobs else [0.0] * len(state.response_mask) + ) + current_logprobs = output.log_probs if output.log_probs is not None else [0.0] * len(output.token_ids) + response_logprobs = prefix_logprobs + current_logprobs + else: + response_logprobs = [] + + if preserve_extra_fields: + extra_fields = dict(state.extra_fields) + if not extra_fields: + extra_fields.update(output.extra_fields) + else: + max_global_steps = output.extra_fields.get("max_global_steps") + if max_global_steps: + extra_fields["max_global_steps"] = max_global_steps + else: + extra_fields = dict(output.extra_fields) + + if output.routed_experts is not None: + routed_experts = output.routed_experts[: len(prompt_ids) + self.response_length] + elif preserve_routed_experts: + routed_experts = state.routed_experts + else: + routed_experts = None + + next_state = ContextState( + messages=messages, + trajectory_ids=list(state.trajectory_ids) + output.token_ids, + response_mask=response_mask, + 
response_logprobs=response_logprobs, + multi_modal_data=dict(state.multi_modal_data), + routed_experts=routed_experts, + reward_score=state.reward_score, + num_turns=sum(1 for message in messages if message.get("role") != "system"), + metrics=AgentLoopMetrics(**metrics), + extra_fields=extra_fields, + ) + return next_state, output.token_ids + + @abstractmethod + async def run(self, sampling_params: dict[str, Any], **kwargs) -> list[AgentLoopOutput]: + raise NotImplementedError + + +@register("naive_summarizer_agent") +class SummarizerAgentLoop(AgentLoopWithContextManagement): + """Naive agent loop of multi-trajectory that uses model-generated summaries for context compression.""" + + def __init__(self, *args, max_context_compressions: int = 4, **kwargs): + super().__init__(*args, **kwargs) + if max_context_compressions < 0: + raise ValueError("max_context_compressions must be non-negative.") + + self.max_context_compressions = max_context_compressions + self.context_manager = SummarizerContextManager( + tokenizer=self.tokenizer, + apply_chat_template_kwargs=self.apply_chat_template_kwargs, + ) + + async def run(self, sampling_params: dict[str, Any], **kwargs) -> list[AgentLoopOutput]: + messages = list(kwargs["raw_prompt"]) + + multi_modal_data = await self.process_vision_info(messages) + images = multi_modal_data.get("images") + videos = multi_modal_data.get("videos") + prompt_ids = await self.apply_chat_template(messages, images=images, videos=videos) + + state = ContextState( + messages=messages, + trajectory_ids=prompt_ids, + multi_modal_data=multi_modal_data, + num_turns=sum(1 for message in messages if message.get("role") != "system"), + metrics=AgentLoopMetrics(), + ) + + outputs = [] + request_id = uuid4().hex + compression_count = 0 + while True: + state, _ = await self._generate_next_state( + state=state, + request_id=request_id, + sampling_params=sampling_params, + image_data=images, + video_data=videos, + accumulate_metrics=False, + 
preserve_extra_fields=False, + preserve_routed_experts=False, + ) + outputs.append(self._build_output_from_state(state)) + + if compression_count >= self.max_context_compressions: + break + + next_state, compressed = await self.context_manager.check_and_compress(state) + if not compressed: + break + + state = next_state + compression_count += 1 + + return outputs + + +@register("tool_sliding_window_agent") +class ToolSlidingWindowAgentLoop(AgentLoopWithContextManagement): + """Text-only tool agent loop with sliding-window context compression. + + Targets coder-style text tools: no multi-modal tool returns and no user interaction loop. + """ + + def __init__( + self, + *args, + max_context_compressions: int = 4, + compress_when_m_observations: int = 16, + keep_last_n_observations: int = 2, + replacing_text: str = "[Compressed]", + tool_response_pattern: str = r"(<tool_response>)(.*?)(</tool_response>)", + **kwargs, + ): + super().__init__(*args, **kwargs) + if max_context_compressions < 0: + raise ValueError("max_context_compressions must be non-negative.") + + mt = self.rollout_config.multi_turn + if mt.interaction_config_path: + raise ValueError("ToolSlidingWindowAgentLoop does not support interaction_config_path.") + + self.max_context_compressions = max_context_compressions + self.max_user_turns = mt.max_user_turns + self.max_assistant_turns = mt.max_assistant_turns + self.max_parallel_calls = mt.max_parallel_calls + self.max_tool_response_length = mt.max_tool_response_length + self.tool_response_truncate_side = mt.tool_response_truncate_side + + # Tool infrastructure: parser, schemas, and tool instances (same config as ToolAgentLoop) + tool_config_path = mt.tool_config_path + tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else [] + self.tools = {tool.name: tool for tool in tool_list} + self.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list] + self.tool_parser_schemas = [tool.tool_schema for tool in 
tool_list] + self.tool_parser = ToolParser.get_tool_parser(mt.format, self.tokenizer) + self.tool_parser_name = mt.format + + # Sliding window compressor: replaces old blocks with placeholder text + self.context_manager = SlidingWindowContextManager( + compress_when_m_observations=compress_when_m_observations, + keep_last_n_observations=keep_last_n_observations, + replacing_text=replacing_text, + tool_response_pattern=tool_response_pattern, + tokenizer=self.tokenizer, + ) + + async def run(self, sampling_params: dict[str, Any], **kwargs) -> list[AgentLoopOutput]: + # Input validation: text-only, no multi-modal, no interaction + messages = [dict(message) for message in list(kwargs["raw_prompt"])] + self._validate_text_messages(messages) + + multi_modal_data = await self.process_vision_info(messages) + if multi_modal_data: + raise ValueError("ToolSlidingWindowAgentLoop only supports text prompts.") + + prompt_ids = await self.apply_chat_template(messages, tools=self.tool_schemas) + state = ContextState( + messages=messages, + trajectory_ids=prompt_ids, + num_turns=sum(1 for message in messages if message.get("role") != "system"), + metrics=AgentLoopMetrics(), + ) + + # Session-level bookkeeping + outputs: list[AgentLoopOutput] = [] + request_id = uuid4().hex + compression_count = 0 + assistant_turns = 0 + tool_turns = 0 + session_tool_rewards: list[float] = [] # accumulated across all trajectories + trajectory_tool_rewards: list[float] = [] # reset on each compression boundary + + def build_output(current_state: ContextState) -> AgentLoopOutput: + """Closure that captures the current reward/compression counters.""" + output = self._build_output_from_state(current_state) + output.extra_fields.update( + { + "tool_rewards": list(trajectory_tool_rewards), + "session_tool_rewards": list(session_tool_rewards), + "context_compression_count": compression_count, + "agent_loop_impl": "ToolSlidingWindowAgentLoop", + } + ) + return output + + # Main generate-tool-compress loop + 
while True: + # 1. LLM generation: append assistant tokens (mask=1) to state + state, assistant_response_ids = await self._generate_next_state( + state=state, + request_id=request_id, + sampling_params=sampling_params, + ) + assistant_turns += 1 + + # 2. Check termination (response budget / turn budget) + if self._should_terminate_after_generation(state, assistant_turns, tool_turns): + outputs.append(build_output(state)) + break + + # 3. Extract tool calls from the latest assistant response + _, tool_calls = await self.tool_parser.extract_tool_calls(assistant_response_ids, self.tool_parser_schemas) + if not tool_calls: + # No tool call → final answer, terminate + outputs.append(build_output(state)) + break + + # 4. Execute tools and append observation tokens (mask=0) to state + state, current_tool_rewards = await self._append_tool_responses( + state=state, + tool_calls=tool_calls, + tools_kwargs=kwargs.get("tools_kwargs", {}), + ) + tool_turns += 1 + session_tool_rewards.extend(current_tool_rewards) + trajectory_tool_rewards.extend(current_tool_rewards) + + # 5. Check response budget after tool response append + if len(state.response_mask) >= self.response_length: + outputs.append(build_output(state)) + break + + # 6. Sliding window compression: emit trajectory and start new one + if compression_count < self.max_context_compressions: + next_state, compressed = await self.context_manager.check_and_compress(state) + if compressed: + outputs.append(build_output(state)) + state = next_state + compression_count += 1 + trajectory_tool_rewards = [] # reset for the new trajectory + + return outputs + + async def _append_tool_responses( + self, + *, + state: ContextState, + tool_calls, + tools_kwargs: dict[str, Any], + ) -> tuple[ContextState, list[float]]: + """Execute tool calls in parallel and append observation tokens (mask=0) to the state. 
+ + Returns (updated_state, tool_rewards) where tool_rewards is a list of per-tool + reward floats for this round of tool execution. + """ + metrics = state.metrics.model_dump() + tasks = [] + tool_call_names = [] + for tool_call in tool_calls[: self.max_parallel_calls]: + tasks.append(self._call_text_tool(tool_call, tools_kwargs)) + tool_call_names.append(tool_call.name) + + with simple_timer("tool_calls", metrics): + responses = await asyncio.gather(*tasks) + + # Collect text responses and rewards; reject any multi-modal returns + add_messages: list[dict[str, Any]] = [] + tool_rewards: list[float] = [] + for tool_response, tool_reward in responses: + if tool_response.image or tool_response.video: + raise ValueError("ToolSlidingWindowAgentLoop only supports text tool responses.") + add_messages.append({"role": "tool", "content": tool_response.text or ""}) + if tool_reward is not None: + tool_rewards.append(tool_reward) + + # Tokenize tool response messages into token ids + response_ids = await self._encode_tool_response_messages(add_messages, tool_call_names) + + # Tool response tokens are mask=0 (environment observation, no gradient) + response_mask = list(state.response_mask) + [0] * len(response_ids) + response_logprobs = list(state.response_logprobs) + [0.0] * len(response_ids) if state.response_logprobs else [] + messages = [dict(message) for message in state.messages] + add_messages + + next_state = ContextState( + messages=messages, + trajectory_ids=list(state.trajectory_ids) + response_ids, + response_mask=response_mask, + response_logprobs=response_logprobs, + routed_experts=state.routed_experts, + reward_score=state.reward_score, + num_turns=sum(1 for message in messages if message.get("role") != "system"), + metrics=AgentLoopMetrics(**metrics), + extra_fields=dict(state.extra_fields), + ) + return next_state, tool_rewards + + async def _call_text_tool(self, tool_call, tools_kwargs: dict[str, Any]) -> tuple[ToolResponse, Optional[float]]: + """Execute a 
single tool call. Returns (response, reward). + + On failure, returns an error message as the response text with reward=0. + The caller is responsible for rejecting multi-modal responses. + """ + tool, instance_id = None, None + try: + tool_name = tool_call.name + tool_args = json.loads(tool_call.arguments) + tool = self.tools[tool_name] + kwargs = tools_kwargs.get(tool_name, {}) + instance_id, _ = await tool.create(create_kwargs=kwargs.get("create_kwargs", {})) + tool_execution_response, tool_reward, _ = await tool.execute(instance_id, tool_args) + except Exception as e: + logger.warning(f"Error when executing tool: {e}") + return ToolResponse(text=f"Error when executing tool: {e}"), 0.0 + finally: + if tool and instance_id: + await tool.release(instance_id) + + tool_response_text = self._truncate_tool_response_text(tool_execution_response.text) + + # Propagate image/video attrs so the caller can detect and reject multi-modal returns + tool_response_kwargs = {"text": tool_response_text} + for attr_name in ["image", "video"]: + if hasattr(tool_execution_response, attr_name): + attr_value = getattr(tool_execution_response, attr_name) + if attr_value is not None: + tool_response_kwargs[attr_name] = attr_value + + return ToolResponse(**tool_response_kwargs), tool_reward + + async def _encode_tool_response_messages( + self, add_messages: list[dict[str, Any]], tool_call_names: list[str] + ) -> list[int]: + """Tokenize tool response messages into token ids, handling gpt-oss format specially.""" + if self.tool_parser_name == "gpt-oss": + tool_response_text = build_gpt_oss_tool_response_text(add_messages, tool_call_names) + return await self.loop.run_in_executor( + None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False) + ) + return await self.apply_chat_template(add_messages, remove_system_prompt=True) + + def _truncate_tool_response_text(self, text: Optional[str]) -> Optional[str]: + """Truncate tool response text to max_tool_response_length if 
exceeded.""" + if not text or len(text) <= self.max_tool_response_length: + return text + + if self.tool_response_truncate_side == "left": + return text[: self.max_tool_response_length] + "...(truncated)" + if self.tool_response_truncate_side == "right": + return "(truncated)..." + text[-self.max_tool_response_length :] + + length = self.max_tool_response_length // 2 + return text[:length] + "...(truncated)..." + text[-length:] + + def _should_terminate_after_generation( + self, + state: ContextState, + assistant_turns: int, + tool_turns: int, + ) -> bool: + """Check whether the loop should stop after an LLM generation step.""" + if len(state.response_mask) >= self.response_length: + return True + if self.max_assistant_turns and assistant_turns >= self.max_assistant_turns: + return True + # tool_turns maps to max_user_turns (each tool round is one "user" turn in agent loop semantics) + return bool(self.max_user_turns and tool_turns >= self.max_user_turns) + + @staticmethod + def _validate_text_messages(messages: list[dict[str, Any]]) -> None: + """Reject messages with non-string content (e.g. multi-modal structured content).""" + for message in messages: + content = message.get("content", "") + if not isinstance(content, str): + raise ValueError("ToolSlidingWindowAgentLoop only supports string message content.") diff --git a/context_management/context_manager.py b/context_management/context_manager.py new file mode 100644 index 00000000..02083697 --- /dev/null +++ b/context_management/context_manager.py @@ -0,0 +1,288 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Optional + +from verl.experimental.agent_loop.agent_loop import AgentLoopMetrics +from verl.utils.chat_template import apply_chat_template, initialize_system_prompt +from verl.utils.tokenizer import normalize_token_ids + + +@dataclass +class ContextState: + """State boundary shared by agent loops and context managers.""" + + messages: list[dict[str, Any]] + trajectory_ids: list[int] = field(default_factory=list) + response_mask: list[int] = field(default_factory=list) + response_logprobs: list[float] = field(default_factory=list) + multi_modal_data: dict[str, Any] = field(default_factory=dict) + routed_experts: Optional[Any] = None + reward_score: Optional[float] = None + num_turns: int = 0 + metrics: AgentLoopMetrics = field(default_factory=AgentLoopMetrics) + extra_fields: dict[str, Any] = field(default_factory=dict) + + +class ContextManager(ABC): + """Plugin interface for context management.""" + + async def check_and_compress(self, state: ContextState) -> tuple[ContextState, bool]: + if not await self._should_compress(state): + return state, False + compressed_state = await self._compress_impl(state) + return compressed_state, compressed_state != state + + @abstractmethod + async def _should_compress(self, state: ContextState) -> bool: + raise NotImplementedError + + @abstractmethod + async def _compress_impl(self, state: ContextState) -> ContextState: + raise NotImplementedError + + +class SlidingWindowContextManager(ContextManager): + 
"""Rule-based sliding-window compressor, following the structure of Figure 3 + in paper: https://arxiv.org/pdf/2510.08276 + + Keeps the last N tool responses/observations when M have accumulated, where + tool_window_size = M and slide_size = M - N in paper. + """ + + def __init__( + self, + compress_when_m_observations: int = 16, + keep_last_n_observations: int = 0, + replacing_text: str = "[Compressed]", + tool_response_pattern: str = r"(<tool_response>)(.*?)(</tool_response>)", + *, + tokenizer: Any, + ): + if compress_when_m_observations <= 0 or keep_last_n_observations < 0: + raise ValueError( + "compress_when_m_observations must be positive and keep_last_n_observations must be non-negative." + ) + if keep_last_n_observations >= compress_when_m_observations: + raise ValueError("keep_last_n_observations must be less than compress_when_m_observations.") + if tokenizer is None: + raise ValueError("tokenizer must be provided for SlidingWindowContextManager.") + + self.compress_when_m_observations = compress_when_m_observations + self.keep_last_n_observations = keep_last_n_observations + self.replacing_text = replacing_text + self.tokenizer = tokenizer + self.tool_response_pattern = re.compile(tool_response_pattern, re.DOTALL) + + async def _should_compress(self, state: ContextState) -> bool: + """Return True only when the number of remaining observations reaches the threshold M.""" + response_length = len(state.response_mask) + if response_length == 0: + return False + response_ids = state.trajectory_ids[-response_length:] + response_text = self.tokenizer.decode(response_ids, skip_special_tokens=False) + + observation_count = 0 + for _, body, _ in self.tool_response_pattern.findall(response_text): + # Only consider the observations that haven't been compressed or replaced + if body.strip() != self.replacing_text: + observation_count += 1 + return observation_count >= self.compress_when_m_observations + + async def _compress_impl(self, state: ContextState) -> ContextState: + 
"""Remove earlier 
observations and keep only the last N observations.""" + response_length = len(state.response_mask) + + # 'response_length' won't be zero as it has been checked by _should_compress() + prompt_ids = state.trajectory_ids[:-response_length] + response_ids = state.trajectory_ids[-response_length:] + + # Compress both trajectory_ids and messages, and they should be aligned. + compressed_response_ids, removed_num_obs_from_ids = self._compress_token_ids(response_ids) + compressed_messages, removed_num_obs_from_messages = self._compress_messages(state.messages) + + if removed_num_obs_from_ids != removed_num_obs_from_messages: + raise ValueError("_compress_token_ids and _compress_messages must remove the same number of observations.") + if removed_num_obs_from_ids == 0: + raise ValueError("SlidingWindowContextManager.compress removed zero observations unexpectedly.") + + # Reconstruct the context state + compressed_trajectory_ids = prompt_ids + compressed_response_ids + response_mask = [0] * len(compressed_response_ids) + response_logprobs = [0.0] * len(compressed_response_ids) if state.response_logprobs else [] + + return ContextState( + messages=compressed_messages, + trajectory_ids=compressed_trajectory_ids, + response_mask=response_mask, + response_logprobs=response_logprobs, + multi_modal_data=dict(state.multi_modal_data), + routed_experts=None, + reward_score=state.reward_score, + num_turns=state.num_turns, + metrics=state.metrics.model_copy(deep=True), + extra_fields=dict(state.extra_fields), + ) + + def _compress_messages(self, messages: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], int]: + """Compress earlier observations in messages through message struct and keep the last N unchanged.""" + compressed_messages = [dict(message) for message in messages] + removed_num_obs = 0 + + tool_message_indices = [index for index, message in enumerate(messages) if message.get("role") == "tool"] + num_to_compress = len(tool_message_indices) - self.keep_last_n_observations 
+ for message_index in tool_message_indices[:num_to_compress]: + content = messages[message_index].get("content") + already_compressed = False + + # For Multi-modal messages, we will replace them entirely. + if isinstance(content, str): + already_compressed = content.strip() == self.replacing_text + + if not already_compressed: + removed_num_obs += 1 + compressed_messages[message_index]["content"] = self.replacing_text + + return compressed_messages, removed_num_obs + + def _compress_token_ids(self, token_ids: list[int]) -> tuple[list[int], int]: + """Compress earlier observations in token ids through regex matching and keep the last N unchanged.""" + text = self.tokenizer.decode(token_ids, skip_special_tokens=False) + matches = list(self.tool_response_pattern.finditer(text)) + num_to_compress = len(matches) - self.keep_last_n_observations + + compressed_parts = [] + last_end = 0 + removed_num_obs = 0 + for index, match in enumerate(matches): + compressed_parts.append(text[last_end : match.start()]) + if index < num_to_compress: + start_tag, _, end_tag = match.groups() + # Previously compressed wouldn't be counted as 'removed_num_obs' this time + if match.group(2).strip() != self.replacing_text: + removed_num_obs += 1 + compressed_parts.append(f"{start_tag}{self.replacing_text}{end_tag}") + else: + compressed_parts.append(match.group(0)) + last_end = match.end() + compressed_parts.append(text[last_end:]) + compressed_text = "".join(compressed_parts) + return self.tokenizer.encode(compressed_text, add_special_tokens=False), removed_num_obs + + +class SummarizerContextManager(ContextManager): + """Model-based summarization compressor, following the structure of Figure 1 in + paper: https://arxiv.org/pdf/2510.06727 + + Models are doing the summarization by itself and start the next trajectory from + the initial token_ids of prompt_ids + summarization_ids. 
+ """ + + def __init__( + self, + summary_pattern: str = r"()(.*?)()", + *, + tokenizer: Any, + apply_chat_template_kwargs: Optional[dict[str, Any]] = None, + ): + if tokenizer is None: + raise ValueError("tokenizer must be provided for SummarizerContextManager.") + + self.tokenizer = tokenizer + self.summary_pattern = re.compile(summary_pattern, re.DOTALL) + self.apply_chat_template_kwargs = apply_chat_template_kwargs or {} + self.system_prompt_length = len(initialize_system_prompt(self.tokenizer, **self.apply_chat_template_kwargs)) + + async def _should_compress(self, state: ContextState) -> bool: + """Return True only when a model-generated summary exists in the current generated + response of this trajectory. + + NOTE: Previous summarization from the preceding run shouldn't be counted as one in + current trajectory. + """ + response_length = len(state.response_mask) + if response_length == 0: + return False + response_ids = state.trajectory_ids[-response_length:] + # NOTE: Should only consider the summarization in generated tokens of current trajectory, otherwise we + # will get into a infinite loop as previous summarization may continuously trigger the compression. 
+ generated_response_ids = [ + token_id for token_id, token_mask in zip(response_ids, state.response_mask, strict=False) if token_mask == 1 + ] + response_text = self.tokenizer.decode(generated_response_ids, skip_special_tokens=False) + return self.summary_pattern.search(response_text) is not None + + async def _compress_impl(self, state: ContextState) -> ContextState: + """Keep the last summarization only, prepended with original prompts.""" + response_length = len(state.response_mask) + + # 'response_length' won't be zero as it has been checked by _should_compress() + prompt_ids = state.trajectory_ids[:-response_length] + response_ids = state.trajectory_ids[-response_length:] + + # Take the llm-generated tokens of current trajectory and search for summarization + generated_response_ids = [ + token_id for token_id, token_mask in zip(response_ids, state.response_mask, strict=False) if token_mask == 1 + ] + response_text = self.tokenizer.decode(generated_response_ids, skip_special_tokens=False) + + summary_match = None + # Use the last summarization + for match in self.summary_pattern.finditer(response_text): + summary_match = match + if summary_match is None: + raise ValueError("SummarizerContextManager.compress expected a block but found none.") + + compressed_messages = [] + # Only keep the prompts from user and system, and append it with summarization + for message in state.messages: + if message.get("role") in {"assistant", "tool"}: + break + compressed_messages.append(dict(message)) + summary_text = summary_match.group(0) + compressed_messages.append({"role": "assistant", "content": summary_text}) + + # NOTE: We use chat_template to rebuild the trajectory_ids because we need 'add_generation_prompt' to encourage + # the model continuously infering after the tag. Otherwise, model may directly output EOS and stop. + # Meanwhile, we should keep the original prompt_ids unchanged, since it may have multi-modal data. 
+ # The system prompt prefix is removed here to align with the incremental append behavior in ToolAgentLoop. + tokenized_summary = apply_chat_template( + self.tokenizer, + [{"role": "assistant", "content": summary_text}], + add_generation_prompt=True, + tokenize=True, + **self.apply_chat_template_kwargs, + ) + summary_ids = normalize_token_ids(tokenized_summary)[self.system_prompt_length :] + + # Reconstruct the context state + compressed_trajectory_ids = prompt_ids + summary_ids + response_mask = [0] * len(summary_ids) + response_logprobs = [0.0] * len(summary_ids) if state.response_logprobs else [] + + return ContextState( + messages=compressed_messages, + trajectory_ids=compressed_trajectory_ids, + response_mask=response_mask, + response_logprobs=response_logprobs, + multi_modal_data=dict(state.multi_modal_data), + routed_experts=None, + reward_score=state.reward_score, + num_turns=state.num_turns, + metrics=state.metrics.model_copy(deep=True), + extra_fields=dict(state.extra_fields), + ) diff --git a/context_management/context_manager_plugin.md b/context_management/context_manager_plugin.md new file mode 100644 index 00000000..4a3c0826 --- /dev/null +++ b/context_management/context_manager_plugin.md @@ -0,0 +1,295 @@ +# Context Manager Plugin + +Last updated: 04/16/2026. + +## Introduction + +In agentic RL, a model interacts with tools over many turns. As the conversation grows, the context may exceed the model's context window, leading to degraded performance or generation failures. The **Context Manager Plugin** provides a pluggable interface for compressing context during agent loop execution, enabling longer and more effective multi-turn interactions. + +This plugin extends the existing [Agent Loop](agent_loop.rst) framework with **pluggable context compression** via the `ContextManager` abstract class. 
Each compression boundary naturally produces a new trajectory from a single prompt, and the TQ Trainer infrastructure provides the necessary support for end-to-end training: + +- **Multi-trajectory Management**: `AgentLoopWorkerTQ` accepts `list[AgentLoopOutput]` per prompt, storing each trajectory as a separate sample in TransferQueue with key format `{uid}_{session_id}_{index}`. The TQ trainer (`main_ppo_sync.py`) then consumes these trajectories for training. +- **Session-level GRPO Advantage**: `compute_advantage_for_multi_trajectories` groups trajectories by `{uid}_{session_id}` and computes group-relative advantages only on each session's final output. The advantage is then broadcast to the preceding trajectories of the same session. You may customize this credit assignment inside each session. +- **Batch padding**: `upsample_batch_to_divisible_size` appends synthetic samples when the dynamic batch size is not divisible by `dp_size` or `mini_batch_size`, while accounting for advantage correctness, gradient consistency, metric independence, router replay support, multi-modal compatibility, and performance.
+ +## Architecture + +``` +AgentLoopBase (existing) + │ + └── AgentLoopWithContextManagement (new abstract base) + │ + ├── context_manager: ContextManager ← pluggable + │ ├── SlidingWindowContextManager + │ └── SummarizerContextManager + │ + ├── SummarizerAgentLoop (model-generated summary, no tool calling) + └── ToolSlidingWindowAgentLoop (sliding window + tool calling, text-only) +``` + +The `ContextManager` communicates with the agent loop through `ContextState`, a shared dataclass that carries the full state of the conversation: + +```python +@dataclass +class ContextState: + messages: list[dict[str, Any]] # Chat-format messages + trajectory_ids: list[int] # prompt_ids + response_ids + response_mask: list[int] # 1 = model-generated, 0 = observation/compressed + response_logprobs: list[float] # Log probabilities for response tokens + multi_modal_data: dict[str, Any] # Images, videos, etc. + routed_experts: Optional[Any] # MoE routing info + reward_score: Optional[float] # Reward for the trajectory + num_turns: int # Number of chat turns + metrics: AgentLoopMetrics # Performance metrics + extra_fields: dict[str, Any] # Extensible metadata +``` + +## Algorithm Principle + +After context compression, a trajectory is split into two training samples. Each sample carries its own `response_mask` (equivalently `loss_mask`) that determines which tokens contribute to the policy gradient. The key principle is that only the token **generated by model** in this trajectory **turn** will have gradient. For example: + +**Trajectory 0** (before compression): + +``` +| prompts | reasoning | tool_call | tool_response | ... 
| +| mask=0 | mask=1 | mask=1 | mask=0 | mask=1 | +``` + +**Trajectory 1** (after compression, starts a new generation): + +``` +| prompts | compressed_history | reasoning | tool_call | tool_response | final_output | +| mask=0 | mask=0 | mask=1 | mask=1 | mask=0 | mask=1 | +``` + +Both trajectories share the same `uid` and `session_id`, so they are grouped together for GRPO advantage computation. By default, the advantage is computed from the final trajectory's reward and broadcast back to all trajectories in the session. + +## ContextManager API + +`ContextManager` is the abstract plugin interface. It exposes a single public method: + +```python +class ContextManager(ABC): + async def check_and_compress(self, state: ContextState) -> tuple[ContextState, bool]: + """Check whether compression is needed and apply it if so. + + Returns: + A tuple of (new_state, compressed) where compressed is True + if the state was modified. + """ + if not await self._should_compress(state): + return state, False + compressed_state = await self._compress_impl(state) + return compressed_state, compressed_state != state +``` + +To implement a custom context manager, subclass `ContextManager` and override: + +- `_should_compress(state)` — return `True` when compression should trigger. +- `_compress_impl(state)` — return a new `ContextState` with compressed context. + +### Post-compression invariants + +After compression, the returned `ContextState` must satisfy: + +- `response_mask` is all zeros (compressed tokens are not model-generated). +- `response_logprobs` is all zeros (no valid log probabilities for compressed tokens). +- `routed_experts` is `None` (token positions have changed, so routing info is invalidated). +- `trajectory_ids` preserves the original prompt prefix and appends compressed content. 
+ +### Supported Basic Strategies + +| Strategy | Trigger | Method | Reference | +|---|---|---|---| +| **Sliding Window** | Observation count reaches threshold | Replace earlier tool responses with placeholder text | [Figure 3, arXiv:2510.08276](https://arxiv.org/pdf/2510.08276) | +| **Summarizer** | Model generates `<summary>` tag | Keep original prompt + model-generated summary | [Figure 1, arXiv:2510.06727](https://arxiv.org/pdf/2510.06727) | + +## Built-in Manager: Sliding Window + +The `SlidingWindowContextManager` compresses context by replacing earlier tool observations with placeholder text. It is a rule-based strategy that does not require LLM calls. + +### Parameters + +| Parameter | Default | Description | +|---|---|---| +| `compress_when_m_observations` | 16 | Trigger compression when this many uncompressed observations accumulate | +| `keep_last_n_observations` | 0 | Number of recent observations to preserve | +| `replacing_text` | `"[Compressed]"` | Placeholder text for compressed observations | +| `tool_response_pattern` | `r"(<tool_response>)(.*?)(</tool_response>)"` | Regex pattern to match tool response blocks | +| `tokenizer` | *(required)* | Tokenizer for encode/decode | + +### Example + +```python +from recipe.context_management.context_manager import SlidingWindowContextManager + +manager = SlidingWindowContextManager( + compress_when_m_observations=8, + keep_last_n_observations=2, + replacing_text="[Compressed]", + tokenizer=tokenizer, +) +``` + +This triggers compression when 8 uncompressed tool responses accumulate, keeps the last 2, and replaces the rest with `[Compressed]`. + +### How it works + +1. **Trigger**: counts uncompressed `<tool_response>` blocks in the response tokens (already-compressed blocks are excluded from the count). +2. **Compress**: replaces the content of earlier observations with `[Compressed]` in both token IDs (via decode-regex-encode) and chat messages (via role-based matching). +3.
**Align**: verifies that the same number of observations were removed from both representations of token_ids and messages. + +## Built-in Manager: Summarizer + +The `SummarizerContextManager` relies on the model itself to produce a summary. When the model generates a `<summary>...</summary>` block, the context is compressed to just the original prompt plus the summary. + +### Parameters + +| Parameter | Default | Description | +|---|---|---| +| `summary_pattern` | `r"(<summary>)(.*?)(</summary>)"` | Regex pattern to match summary blocks | +| `tokenizer` | *(required)* | Tokenizer for encode/decode | +| `apply_chat_template_kwargs` | `{}` | Additional kwargs passed to `apply_chat_template` | + +### How it works + +1. **Trigger**: checks only `mask=1` tokens (current generation) for `<summary>` tags. This prevents a summary carried over from an earlier compression from re-triggering compression in an infinite loop. +2. **Compress**: keeps the original system/user messages, discards all assistant/tool turns, and appends the last `<summary>` block as a new assistant message. +3. **Rebuild token IDs**: uses `apply_chat_template(add_generation_prompt=True)` to reconstruct token IDs, which appends an assistant generation prefix to encourage the model to continue generating. + +## Loop: ToolSlidingWindowAgentLoop + +`ToolSlidingWindowAgentLoop` combines tool calling with sliding window context compression for text-only coder-style scenarios. It is registered as `"tool_sliding_window_agent"`. + +Unlike `SummarizerAgentLoop` which relies on the model to generate summaries, this loop uses rule-based sliding window compression: when tool response observations accumulate beyond a threshold, earlier observations are replaced with placeholder text.
+ +### Parameters + +| Parameter | Default | Description | +|---|---|---| +| `max_context_compressions` | 4 | Maximum number of compression cycles per prompt | +| `compress_when_m_observations` | 16 | Trigger compression when this many uncompressed observations accumulate | +| `keep_last_n_observations` | 2 | Number of recent observations to preserve after compression | +| `replacing_text` | `"[Compressed]"` | Placeholder text for compressed observations | +| `tool_response_pattern` | `r"(<tool_response>)(.*?)(</tool_response>)"` | Regex pattern to match tool response blocks in token stream | + +Tool calling parameters (`max_assistant_turns`, `max_parallel_calls`, `max_tool_response_length`, etc.) are read from `rollout_config.multi_turn`, same as `ToolAgentLoop`. + +### Multi-trajectory generation flow + +``` +Prompt → Generate → tool_call? ─No──→ Emit output[k], return [output[0..k]] + │ + Yes + ↓ + Execute tools → append tool_response (mask=0) + ↓ + Observations ≥ M? ─No──→ loop back to Generate + │ + Yes + ↓ + Emit output[k], compress context, k+=1, loop back to Generate + ↓ + ... (up to max_context_compressions) +``` + +### Usage + +Register `"tool_sliding_window_agent"` in your config: + +```yaml +actor_rollout_ref: + rollout: + agent: + default_agent_loop: tool_sliding_window_agent +``` + +### Important: `tool_response_pattern` must match your chat template + +The sliding window compressor finds tool responses in the token stream via regex. The default pattern `<tool_response>...</tool_response>` works for Qwen-style chat templates. If your model uses a different format (e.g. `gpt-oss` uses `<|start|>functions.xxx...`), you must set `tool_response_pattern` accordingly, otherwise compression will silently never trigger. + +## Loop: SummarizerAgentLoop + +`SummarizerAgentLoop` is the built-in agent loop that uses `SummarizerContextManager` to produce multi-trajectory training data. It is registered as `"naive_summarizer_agent"`.
+ +### Parameters + +| Parameter | Default | Description | +|---|---|---| +| `max_context_compressions` | 4 | Maximum number of compression cycles per prompt | + +### Multi-trajectory generation flow + +``` +Prompt → Generate trajectory[0] → <summary> detected? + │ + No ──→ return [output[0]] + Yes ─→ Compress → Generate trajectory[1] → new <summary> detected? + │ + No ──→ return [output[0], output[1]] + Yes ─→ Compress → ... (up to max_context_compressions) +``` + +Each call to `run()` returns `list[AgentLoopOutput]`, one output per trajectory. The outputs differ in their `response_mask`: + +- **output[0]**: `response_mask` is all 1s (fully generated). +- **output[k] (k > 0)**: `response_mask` starts with 0s (compressed summary prefix) followed by 1s (new generation). + +### Usage + +Register `"naive_summarizer_agent"` in your config: + +```yaml +actor_rollout_ref: + rollout: + agent: + default_agent_loop: naive_summarizer_agent +``` + + +## Implementing a Custom Context Manager + +To build your own compression strategy: + +```python +from recipe.context_management.context_manager import ContextManager, ContextState + +class MyContextManager(ContextManager): + async def _should_compress(self, state: ContextState) -> bool: + # Example: compress when total tokens exceed a threshold + return len(state.trajectory_ids) > 8192 + + async def _compress_impl(self, state: ContextState) -> ContextState: + # Your compression logic here + compressed_trajectory_ids = ... + compressed_messages = ...
+ response_length = len(compressed_trajectory_ids) - len(original_prompt_ids) + + return ContextState( + messages=compressed_messages, + trajectory_ids=compressed_trajectory_ids, + response_mask=[0] * response_length, # must be all zeros + response_logprobs=[0.0] * response_length, # must be all zeros + multi_modal_data=dict(state.multi_modal_data), + routed_experts=None, # must be cleared + reward_score=state.reward_score, + num_turns=state.num_turns, + metrics=state.metrics.model_copy(deep=True), + extra_fields=dict(state.extra_fields), + ) +``` + +Then plug it into your custom `AgentLoopWithContextManagement` subclass: + +```python +from recipe.context_management.agent_loop_with_context_management import AgentLoopWithContextManagement + +class MyAgentLoop(AgentLoopWithContextManagement): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.context_manager = MyContextManager() + + async def run(self, sampling_params, **kwargs): + # Your generation + compression loop + ... +``` diff --git a/context_management/example/README.md b/context_management/example/README.md new file mode 100644 index 00000000..45646388 --- /dev/null +++ b/context_management/example/README.md @@ -0,0 +1,36 @@ +# Context-management example + +Runs GRPO with the `naive_summarizer_agent` loop, which compresses the trajectory whenever the model +emits a `<summary>...</summary>` block and continues from `(initial prompt + summary)`. + +## Files + +- `agent.yaml` — registers `naive_summarizer_agent` and `tool_sliding_window_agent` (forwarded as + agent-loop `__init__` kwargs). Passed to verl via `agent_loop_config_path`. +- `run_summarizer.sh` — minimal GRPO launch. Set `MODEL_PATH`, `TRAIN_FILES`, `VAL_FILES`.
+ +## Run + +```bash +# from a verl checkout that includes this recipe (git submodule update --init --recursive recipe) +export MODEL_PATH=Qwen/Qwen2.5-3B-Instruct +export TRAIN_FILES=$HOME/data/gsm8k/train.parquet +export VAL_FILES=$HOME/data/gsm8k/test.parquet +bash recipe/context_management/example/run_summarizer.sh +``` + +## Switching strategy + +Select a different registered loop without code changes: + +```bash +bash run_summarizer.sh actor_rollout_ref.rollout.agent.default_agent_loop=tool_sliding_window_agent +``` + +## Notes + +- Summarization only triggers when a trajectory actually approaches the context window. On tasks whose + rollouts comfortably fit, the loop is effectively a no-op — to see (and train) compression, use a + long-horizon task or tighten `data.max_response_length` / `actor_rollout_ref.rollout.max_model_len`. +- The model must already know how to emit a `<summary>` block on demand (e.g. via SFT cold-start); + otherwise the summarizer never triggers. diff --git a/context_management/example/agent.yaml b/context_management/example/agent.yaml new file mode 100644 index 00000000..3d3c05c1 --- /dev/null +++ b/context_management/example/agent.yaml @@ -0,0 +1,16 @@ +# Agent-loop registry for the context-management recipe. +# Pass via: actor_rollout_ref.rollout.agent.agent_loop_config_path=recipe/context_management/example/agent.yaml +# then select one with actor_rollout_ref.rollout.agent.default_agent_loop=<name>. +# +# Each entry's extra keys (beyond name/_target_) are forwarded as __init__ kwargs to the loop. + +- name: naive_summarizer_agent + _target_: recipe.context_management.agent_loop_with_context_management.SummarizerAgentLoop + # Max number of summarize-and-continue compressions per rollout.
+ max_context_compressions: 4 + +- name: tool_sliding_window_agent + _target_: recipe.context_management.agent_loop_with_context_management.ToolSlidingWindowAgentLoop + # Max number of sliding-window compressions per rollout (turn limits come from + # actor_rollout_ref.rollout.multi_turn.*). + max_context_compressions: 4 diff --git a/context_management/example/run_summarizer.sh b/context_management/example/run_summarizer.sh new file mode 100644 index 00000000..723d15e5 --- /dev/null +++ b/context_management/example/run_summarizer.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Minimal GRPO example wiring the `naive_summarizer_agent` context-management loop. +# Set MODEL_PATH / TRAIN_FILES / VAL_FILES for your task, then: bash run_summarizer.sh +set -xeuo pipefail + +MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2.5-3B-Instruct"} +TRAIN_FILES=${TRAIN_FILES:-"$HOME/data/gsm8k/train.parquet"} +VAL_FILES=${VAL_FILES:-"$HOME/data/gsm8k/test.parquet"} + +# Path to this recipe's agent-loop registry (relative to a verl checkout with the recipe submodule). 
+AGENT_LOOP_CONFIG=recipe/context_management/example/agent.yaml + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="$TRAIN_FILES" \ + data.val_files="$VAL_FILES" \ + data.train_batch_size=128 \ + data.max_prompt_length=2048 \ + data.max_response_length=16384 \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path="$MODEL_PATH" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.rollout.multi_turn.enable=True \ + `# --- context-management wiring (the point of this recipe) ---` \ + actor_rollout_ref.rollout.agent.agent_loop_config_path="$AGENT_LOOP_CONFIG" \ + actor_rollout_ref.rollout.agent.default_agent_loop=naive_summarizer_agent \ + trainer.logger='["console"]' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.total_epochs=15 \ + trainer.device=cuda "$@" diff --git a/context_management/test_agent_loop_with_context_management.py b/context_management/test_agent_loop_with_context_management.py new file mode 100644 index 00000000..2533ec28 --- /dev/null +++ b/context_management/test_agent_loop_with_context_management.py @@ -0,0 +1,528 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +from typing import Any, Optional + +import pytest +from omegaconf import OmegaConf +from recipe.context_management.agent_loop_with_context_management import ( + SummarizerAgentLoop, + ToolSlidingWindowAgentLoop, +) +from recipe.context_management.context_manager import ContextState + +from verl.experimental.agent_loop.agent_loop import AgentLoopMetrics, DictConfigWrap +from verl.tools.schemas import ToolResponse +from verl.utils.chat_template import initialize_system_prompt +from verl.utils.dataset.rl_dataset import RLHFDataset +from verl.workers.rollout.replica import TokenOutput + + +class _FakeTokenizer: + """Char-level tokenizer mock for deterministic unit tests.""" + + def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: + del add_special_tokens + return [ord(ch) for ch in text] + + def decode(self, token_ids: list[int], skip_special_tokens: bool = False) -> str: + del skip_special_tokens + return "".join(chr(token_id) for token_id in token_ids) + + def apply_chat_template( + self, + messages: list[dict[str, Any]], + *, + tools: Optional[list[dict]] = None, + add_generation_prompt: bool = True, + tokenize: bool = True, + **kwargs, + ) -> list[int] | str: + del tools, kwargs + parts = [] + for message in messages: + if message["role"] == "tool": + parts.append(f"{message['content']}") + else: + parts.append(f"<{message['role']}>{message['content']}") + text = "".join(parts) + if add_generation_prompt: + text += "" + if not tokenize: + return text + return self.encode(text) + + +class _QueuedServerManager: + """Minimal fake server manager that returns pre-seeded responses in order. + + Pops one response string per generate() call and records each call in self.calls + so tests can inspect the prompt_ids passed to the model. 
+ """ + + def __init__(self, tokenizer: _FakeTokenizer, responses: list[str]): + self._tokenizer = tokenizer + self._responses = list(responses) + self.calls: list[dict[str, Any]] = [] + + async def generate( + self, + request_id: str, + *, + prompt_ids: list[int], + sampling_params: dict[str, Any], + image_data: Optional[list[Any]] = None, + video_data: Optional[list[Any]] = None, + ) -> TokenOutput: + del sampling_params, image_data, video_data + if not self._responses: + raise AssertionError("No fake response left for _QueuedServerManager.generate().") + + response_text = self._responses.pop(0) + response_ids = self._tokenizer.encode(response_text) + self.calls.append({"request_id": request_id, "prompt_ids": list(prompt_ids), "response_text": response_text}) + return TokenOutput( + token_ids=response_ids, + log_probs=[0.0] * len(response_ids), + num_preempted=0, + ) + + +def _build_summarizer_loop( + *, responses: list[str], max_context_compressions: int = 4 +) -> tuple[SummarizerAgentLoop, _FakeTokenizer]: + """Build a summarizer agent loop with deterministic fake dependencies for unit tests.""" + + config = OmegaConf.create( + { + "actor_rollout_ref": { + "rollout": {"prompt_length": 128, "response_length": 256}, + "model": {}, + }, + "data": {"apply_chat_template_kwargs": {}}, + } + ) + tokenizer = _FakeTokenizer() + loop = SummarizerAgentLoop( + trainer_config=DictConfigWrap(config), + server_manager=_QueuedServerManager(tokenizer, responses), + tokenizer=tokenizer, + processor=None, + dataset_cls=RLHFDataset, + data_config=DictConfigWrap(config.data), + max_context_compressions=max_context_compressions, + ) + return loop, tokenizer + + +class _FakeTextTool: + """Small async text tool used by ToolSlidingWindowAgentLoop tests.""" + + def __init__( + self, + *, + responses: Optional[dict[str, ToolResponse]] = None, + rewards: Optional[dict[str, float]] = None, + ): + self.responses = responses or {} + self.rewards = rewards or {} + self.created_with: 
list[dict[str, Any]] = [] + self.executed_with: list[dict[str, Any]] = [] + self.released: list[str] = [] + + async def create(self, create_kwargs: dict[str, Any]): + self.created_with.append(dict(create_kwargs)) + return f"instance-{len(self.created_with)}", {} + + async def execute(self, instance_id: str, parameters: dict[str, Any]): + del instance_id + self.executed_with.append(dict(parameters)) + value = str(parameters.get("value", "")) + response = self.responses.get(value, ToolResponse(text=f"obs:{value}")) + reward = self.rewards.get(value, 0.0) + return response, reward, {} + + async def release(self, instance_id: str): + self.released.append(instance_id) + + +def _build_tool_loop( + *, + responses: list[str], + tool: Optional[_FakeTextTool] = None, + response_length: int = 4096, + max_context_compressions: int = 4, + compress_when_m_observations: int = 16, + keep_last_n_observations: int = 2, + max_assistant_turns: Optional[int] = None, + max_user_turns: Optional[int] = None, + interaction_config_path: Optional[str] = None, +) -> tuple[ToolSlidingWindowAgentLoop, _FakeTokenizer, _FakeTextTool]: + """Build a ToolSlidingWindowAgentLoop with fake server/parser-compatible tool calls.""" + + config = OmegaConf.create( + { + "actor_rollout_ref": { + "rollout": { + "prompt_length": 512, + "response_length": response_length, + "multi_turn": { + "format": "hermes", + "tool_config_path": None, + "interaction_config_path": interaction_config_path, + "max_parallel_calls": 1, + "max_assistant_turns": max_assistant_turns, + "max_user_turns": max_user_turns, + "max_tool_response_length": 256, + "tool_response_truncate_side": "middle", + }, + }, + "model": {}, + }, + "data": {"apply_chat_template_kwargs": {}}, + } + ) + tokenizer = _FakeTokenizer() + loop = ToolSlidingWindowAgentLoop( + trainer_config=DictConfigWrap(config), + server_manager=_QueuedServerManager(tokenizer, responses), + tokenizer=tokenizer, + processor=None, + dataset_cls=RLHFDataset, + 
data_config=DictConfigWrap(config.data), + max_context_compressions=max_context_compressions, + compress_when_m_observations=compress_when_m_observations, + keep_last_n_observations=keep_last_n_observations, + ) + tool = tool or _FakeTextTool() + loop.tools = {"lookup": tool} + return loop, tokenizer, tool + + +def _tool_call(value: str, *, name: str = "lookup") -> str: + payload = {"name": name, "arguments": {"value": value}} + return f"{json.dumps(payload, separators=(',', ':'))}" + + +def _tool_response_ids(tokenizer: _FakeTokenizer, text: str) -> list[int]: + return tokenizer.apply_chat_template( + [{"role": "tool", "content": text}], + add_generation_prompt=True, + tokenize=True, + ) + + +def _build_expected_summary_ids(tokenizer: _FakeTokenizer, summary_text: str) -> list[int]: + """Return the token ids that the loop prepends after summarization compression. + + Mirrors SummarizerContextManager: apply_chat_template on the summary assistant + message with additional generation tokens, then strip the system-prompt prefix. + """ + system_prompt_ids = initialize_system_prompt(tokenizer) + summary_ids = tokenizer.apply_chat_template( + [{"role": "assistant", "content": summary_text}], + add_generation_prompt=True, + tokenize=True, + ) + return summary_ids[len(system_prompt_ids) :] + + +def test_summarizer_agent_loop_rejects_negative_max_context_compressions(): + # Passing a negative compression cap should raise ValueError at construction time. + with pytest.raises(ValueError, match="max_context_compressions must be non-negative"): + _build_summarizer_loop(responses=["hello"], max_context_compressions=-1) + + +@pytest.mark.asyncio +async def test_build_output_from_state_handles_empty_response(): + # When response_mask is empty, the entire trajectory should become prompt_ids + # and response_ids / response_mask should both be empty lists. 
+ loop, _ = _build_summarizer_loop(responses=[]) + state = ContextState( + messages=[{"role": "user", "content": "hi"}], + trajectory_ids=[1, 2, 3], + response_mask=[], + response_logprobs=[], + metrics=AgentLoopMetrics(), + extra_fields={"source": "test"}, + ) + + output = loop._build_output_from_state(state) + + assert output.prompt_ids == [1, 2, 3] + assert output.response_ids == [] + assert output.response_mask == [] + assert output.extra_fields["source"] == "test" + assert output.extra_fields["turn_scores"] == [] + assert output.extra_fields["tool_rewards"] == [] + + +@pytest.mark.asyncio +async def test_summarizer_agent_loop_run_returns_multiple_outputs_after_summary_compression(): + # First generation contains a ; the loop compresses and generates again. + # Verifies: two outputs are returned, both calls share the same request_id, + # the second output starts with the summary token ids (mask=0) followed by the + # new generation (mask=1). + summary_text = "compressed summary" + first_response = f"thinking...{summary_text}" + second_response = "final answer" + raw_prompt = [{"role": "user", "content": "hello"}] + loop, tokenizer = _build_summarizer_loop( + responses=[first_response, second_response], + max_context_compressions=1, + ) + + outputs = await loop.run(sampling_params={}, raw_prompt=raw_prompt) + + assert len(outputs) == 2 + assert len(loop.server_manager.calls) == 2 + assert loop.server_manager.calls[0]["request_id"] == loop.server_manager.calls[1]["request_id"] + + first_output_text = tokenizer.decode(outputs[0].response_ids) + second_output_text = tokenizer.decode(outputs[1].response_ids) + summary_ids = _build_expected_summary_ids(tokenizer, summary_text) + + assert first_output_text == first_response + assert second_output_text == tokenizer.decode(summary_ids) + second_response + assert outputs[0].response_mask == [1] * len(outputs[0].response_ids) + assert outputs[1].response_mask[: len(summary_ids)] == [0] * len(summary_ids) + assert 
@pytest.mark.asyncio
async def test_summarizer_agent_loop_run_returns_single_output_without_summary():
    # A response with no summary block never triggers compression, so run() must
    # yield exactly one output whose tokens are all marked as generated
    # (response_mask all-ones).
    answer = "plain final answer"
    user_prompt = [{"role": "user", "content": "hello"}]
    loop, tokenizer = _build_summarizer_loop(responses=[answer], max_context_compressions=4)

    outputs = await loop.run(sampling_params={}, raw_prompt=user_prompt)

    assert len(outputs) == 1
    only_output = outputs[0]
    assert tokenizer.decode(only_output.response_ids) == answer
    assert only_output.response_mask == [1] * len(only_output.response_ids)
def test_tool_sliding_window_agent_loop_rejects_invalid_constructor_args():
    # Each case pairs one invalid constructor override with the error text it must
    # produce. The interaction_config_path case exists because the loop is
    # intentionally text-only and does not support interaction callbacks.
    rejected_cases = (
        (
            {"max_context_compressions": -1},
            "max_context_compressions must be non-negative",
        ),
        (
            {"interaction_config_path": "/tmp/interaction.json"},
            "does not support interaction_config_path",
        ),
    )

    for invalid_kwargs, expected_message in rejected_cases:
        with pytest.raises(ValueError, match=expected_message):
            _build_tool_loop(responses=["unused"], **invalid_kwargs)
If the parser accidentally scans the whole history, it would see the + # old tool call again and try to execute the tool a second time. + tool_call = _tool_call("first") + final_response = "final answer" + tool = _FakeTextTool(rewards={"first": 0.5}) + loop, tokenizer, tool = _build_tool_loop( + responses=[tool_call, final_response], + tool=tool, + compress_when_m_observations=2, + keep_last_n_observations=1, + ) + + outputs = await loop.run(sampling_params={}, raw_prompt=[{"role": "user", "content": "lookup first"}]) + + assert len(outputs) == 1 + assert len(loop.server_manager.calls) == 2 + assert loop.server_manager.calls[0]["request_id"] == loop.server_manager.calls[1]["request_id"] + assert tool.executed_with == [{"value": "first"}] + assert tool.released == ["instance-1"] + + tool_response_ids = _tool_response_ids(tokenizer, "obs:first") + expected_response_text = tool_call + tokenizer.decode(tool_response_ids) + final_response + expected_response_mask = ( + [1] * len(tokenizer.encode(tool_call)) + + [0] * len(tool_response_ids) + + [1] * len(tokenizer.encode(final_response)) + ) + + output = outputs[0] + assert tokenizer.decode(output.response_ids) == expected_response_text + assert output.response_mask == expected_response_mask + assert output.extra_fields["tool_rewards"] == [0.5] + assert output.extra_fields["session_tool_rewards"] == [0.5] + assert output.extra_fields["context_compression_count"] == 0 + assert output.extra_fields["agent_loop_impl"] == "ToolSlidingWindowAgentLoop" + + +@pytest.mark.asyncio +async def test_tool_sliding_window_agent_loop_recompresses_without_recounting_replaced_tool_responses(): + first_tool_call = _tool_call("first") + second_tool_call = _tool_call("second") + third_tool_call = _tool_call("third") + first_assistant = f"reasoning1: inspect first. {first_tool_call}" + second_assistant = f"reasoning2: inspect second. {second_tool_call}" + third_assistant = f"reasoning3: inspect third. 
{third_tool_call}" + final_response = "final output" + tool = _FakeTextTool(rewards={"first": 0.25, "second": 0.75, "third": 1.25}) + loop, tokenizer, tool = _build_tool_loop( + responses=[first_assistant, second_assistant, third_assistant, final_response], + tool=tool, + max_context_compressions=2, + compress_when_m_observations=2, + keep_last_n_observations=1, + ) + + outputs = await loop.run(sampling_params={}, raw_prompt=[{"role": "user", "content": "lookup three times"}]) + + assert len(outputs) == 3 + assert len(loop.server_manager.calls) == 4 + assert tool.executed_with == [{"value": "first"}, {"value": "second"}, {"value": "third"}] + + prompt_text = tokenizer.decode(outputs[0].prompt_ids) + first_tool_response_ids = _tool_response_ids(tokenizer, "obs:first") + second_tool_response_ids = _tool_response_ids(tokenizer, "obs:second") + third_tool_response_ids = _tool_response_ids(tokenizer, "obs:third") + compressed_tool_response_text = tokenizer.decode(_tool_response_ids(tokenizer, "[Compressed]")) + first_tool_response_text = tokenizer.decode(first_tool_response_ids) + second_tool_response_text = tokenizer.decode(second_tool_response_ids) + third_tool_response_text = tokenizer.decode(third_tool_response_ids) + + first_precompression_text = ( + first_assistant + first_tool_response_text + second_assistant + second_tool_response_text + ) + first_precompression_mask = ( + [1] * len(tokenizer.encode(first_assistant)) + + [0] * len(first_tool_response_ids) + + [1] * len(tokenizer.encode(second_assistant)) + + [0] * len(second_tool_response_ids) + ) + first_output = outputs[0] + first_full_text = tokenizer.decode(first_output.prompt_ids + first_output.response_ids) + assert first_full_text == prompt_text + first_precompression_text + assert first_output.response_mask == first_precompression_mask + assert first_output.extra_fields["tool_rewards"] == [0.25, 0.75] + assert first_output.extra_fields["session_tool_rewards"] == [0.25, 0.75] + assert 
first_output.extra_fields["context_compression_count"] == 0 + + once_compressed_prefix_text = ( + first_assistant + compressed_tool_response_text + second_assistant + second_tool_response_text + ) + second_precompression_text = once_compressed_prefix_text + third_assistant + third_tool_response_text + second_precompression_mask = ( + [0] * len(tokenizer.encode(once_compressed_prefix_text)) + + [1] * len(tokenizer.encode(third_assistant)) + + [0] * len(third_tool_response_ids) + ) + second_output = outputs[1] + assert tokenizer.decode(second_output.response_ids) == second_precompression_text + assert second_output.response_mask == second_precompression_mask + assert "obs:first" not in second_precompression_text + assert "obs:second" in second_precompression_text + assert "obs:third" in second_precompression_text + assert second_output.extra_fields["tool_rewards"] == [1.25] + assert second_output.extra_fields["session_tool_rewards"] == [0.25, 0.75, 1.25] + assert second_output.extra_fields["context_compression_count"] == 1 + + twice_compressed_prefix_text = ( + first_assistant + + compressed_tool_response_text + + second_assistant + + compressed_tool_response_text + + third_assistant + + third_tool_response_text + ) + final_output = outputs[2] + assert tokenizer.decode(final_output.response_ids) == twice_compressed_prefix_text + final_response + assert final_output.response_mask == [0] * len(tokenizer.encode(twice_compressed_prefix_text)) + [1] * len( + tokenizer.encode(final_response) + ) + assert final_output.extra_fields["tool_rewards"] == [] + assert final_output.extra_fields["session_tool_rewards"] == [0.25, 0.75, 1.25] + assert final_output.extra_fields["context_compression_count"] == 2 + + first_compressed_prompt = tokenizer.decode(loop.server_manager.calls[2]["prompt_ids"], skip_special_tokens=False) + second_compressed_prompt = tokenizer.decode(loop.server_manager.calls[3]["prompt_ids"], skip_special_tokens=False) + assert first_compressed_prompt == 
@pytest.mark.asyncio
async def test_tool_sliding_window_agent_loop_rejects_multimodal_tool_response():
    # The fake tool answers the "image" lookup with a response that carries an image
    # payload. run() must raise, yet the tool call must already have been executed
    # and its instance released.
    image_response = ToolResponse(text="obs:image", image=["fake-image"])
    fake_tool = _FakeTextTool(responses={"image": image_response})
    loop, _, tool = _build_tool_loop(responses=[_tool_call("image")], tool=fake_tool)
    raw_prompt = [{"role": "user", "content": "call image tool"}]

    with pytest.raises(ValueError, match="only supports text tool responses"):
        await loop.run(sampling_params={}, raw_prompt=raw_prompt)

    assert tool.executed_with == [{"value": "image"}]
    assert tool.released == ["instance-1"]
+ +from typing import Any + +import pytest +from recipe.context_management.context_manager import ( + ContextState, + SlidingWindowContextManager, + SummarizerContextManager, +) + +from verl.utils.chat_template import initialize_system_prompt + + +class _FakeTokenizer: + """Char-level tokenizer mock for deterministic encode/decode in unit tests.""" + + def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: + del add_special_tokens + return [ord(ch) for ch in text] + + def decode(self, token_ids: list[int], skip_special_tokens: bool = False) -> str: + del skip_special_tokens + return "".join(chr(token_id) for token_id in token_ids) + + def apply_chat_template( + self, + messages: list[dict[str, Any]], + *, + tools=None, + add_generation_prompt: bool = True, + tokenize: bool = True, + **kwargs, + ) -> list[int] | str: + del tools, kwargs + text = "".join(f"<{message['role']}>{message['content']}" for message in messages) + if add_generation_prompt: + text += "" + if not tokenize: + return text + return self.encode(text) + + +def _build_state( + *, + prompt_text: str, + response_text: str, + messages: list[dict[str, Any]], + response_mask: list[int] | None = None, + response_logprobs: list[float] | None = None, + routed_experts=None, +) -> ContextState: + """Build a ContextState from raw text strings for use in tests. + + prompt_text / response_text are char-encoded by _FakeTokenizer. + response_mask defaults to all-ones (fully generated); response_logprobs defaults to empty. 
def _build_expected_summary_suffix_ids(tokenizer: _FakeTokenizer, summary_text: str) -> list[int]:
    """Compute the token ids expected at the tail of a summarizer-compressed trajectory.

    Intended to mirror SummarizerContextManager: the summary is rendered as a
    single assistant message through the chat template (generation prompt
    enabled), and the leading tokens that ``initialize_system_prompt`` reports
    for this tokenizer are dropped from the front.
    """
    prefix_length = len(initialize_system_prompt(tokenizer))
    summary_message = {"role": "assistant", "content": summary_text}
    templated_ids = tokenizer.apply_chat_template(
        [summary_message],
        add_generation_prompt=True,
        tokenize=True,
    )
    return templated_ids[prefix_length:]
+ tokenizer = _FakeTokenizer() + manager = SlidingWindowContextManager( + compress_when_m_observations=2, + keep_last_n_observations=1, + tokenizer=tokenizer, + ) + obs1 = "[Compressed]" + obs2 = "obs2" + state = _build_state( + prompt_text="PROMPT", + response_text=obs1 + obs2, + messages=[ + {"role": "user", "content": "prompt"}, + {"role": "tool", "content": "[Compressed]"}, + {"role": "tool", "content": "obs2"}, + ], + ) + + next_state, compressed = await manager.check_and_compress(state) + + assert next_state == state + assert not compressed + + +@pytest.mark.asyncio +async def test_sliding_window_compress_rewrites_messages_and_response_segment(): + # 3 observations reach threshold M=3; keep last N=1, replace the first two with placeholders + # in both token ids and messages. After compression, response_mask is all-zero and routed_experts is cleared. + tokenizer = _FakeTokenizer() + manager = SlidingWindowContextManager( + compress_when_m_observations=3, + keep_last_n_observations=1, + tokenizer=tokenizer, + ) + obs1 = "obs1" + obs2 = "obs2" + obs3 = "obs3" + response_text = obs1 + obs2 + obs3 + state = _build_state( + prompt_text="PROMPT", + response_text=response_text, + messages=[ + {"role": "user", "content": "prompt"}, + {"role": "tool", "content": "obs1"}, + {"role": "tool", "content": "obs2"}, + {"role": "tool", "content": "obs3"}, + ], + response_logprobs=[0.1] * len(response_text), + routed_experts="stale-routes", + ) + + compressed_state, compressed = await manager.check_and_compress(state) + compressed_response_ids = compressed_state.trajectory_ids[-len(compressed_state.response_mask) :] + compressed_response_text = tokenizer.decode(compressed_response_ids) + compressed_obs = "[Compressed]" + + assert compressed + assert compressed_response_text == compressed_obs + compressed_obs + obs3 + assert compressed_state.messages[1]["content"] == "[Compressed]" + assert compressed_state.messages[2]["content"] == "[Compressed]" + assert 
@pytest.mark.asyncio
async def test_sliding_window_check_and_compress_returns_false_below_threshold():
    # A single tool observation stays under the M=2 trigger, so check_and_compress
    # must report compressed=False and hand back a state equal to its input.
    window_manager = SlidingWindowContextManager(
        tokenizer=_FakeTokenizer(),
        keep_last_n_observations=1,
        compress_when_m_observations=2,
    )
    conversation = [
        {"role": "user", "content": "prompt"},
        {"role": "tool", "content": "obs1"},
    ]
    original_state = _build_state(prompt_text="PROMPT", response_text="obs1", messages=conversation)

    returned_state, was_compressed = await window_manager.check_and_compress(original_state)

    assert returned_state == original_state
    assert not was_compressed
@pytest.mark.asyncio
async def test_summarizer_compress_keeps_last_summary_when_multiple_exist():
    # The generated response carries an older and a newer summary; after compression
    # only the newer one should remain as the content of the final message.
    # NOTE(review): the pieces below are joined with no visible summary delimiters —
    # confirm none were lost when this file was extracted.
    tokenizer = _FakeTokenizer()
    summary_old = "old summary"
    summary_new = "new summary"
    response_text = "".join(["thinking...", summary_old, "more thinking...", summary_new])
    prompt_messages = [
        {"role": "system", "content": "sys"},
        {"role": "user", "content": "prompt"},
    ]
    manager = SummarizerContextManager(tokenizer=tokenizer)
    rendered_prompt = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
    state = _build_state(
        prompt_text=rendered_prompt,
        response_text=response_text,
        messages=prompt_messages,
    )

    compressed_state, compressed = await manager.check_and_compress(state)

    assert compressed
    final_message = compressed_state.messages[-1]
    assert final_message["content"] == summary_new
@pytest.mark.asyncio
async def test_summarizer_compress_raises_when_summary_is_missing():
    # Calling _compress_impl directly on a response that carries no summary block
    # must fail with ValueError rather than produce a compressed state.
    summarizer = SummarizerContextManager(tokenizer=_FakeTokenizer())
    prompt_only_messages = [
        {"role": "system", "content": "sys"},
        {"role": "user", "content": "prompt"},
    ]
    unsummarized_state = _build_state(
        prompt_text="PROMPT",
        response_text="plain response without summary",
        messages=prompt_only_messages,
    )

    with pytest.raises(ValueError, match="expected a block"):
        await summarizer._compress_impl(unsummarized_state)
self.tool_response_truncate_side == "left": - return text[: self.max_tool_response_length] + "...(truncated)" - if self.tool_response_truncate_side == "right": return "(truncated)..." + text[-self.max_tool_response_length :] + if self.tool_response_truncate_side == "right": + return text[: self.max_tool_response_length] + "...(truncated)" length = self.max_tool_response_length // 2 return text[:length] + "...(truncated)..." + text[-length:] diff --git a/context_management/example/run_summarizer.sh b/context_management/example/run_summarizer.sh index 723d15e5..dc74ae3c 100644 --- a/context_management/example/run_summarizer.sh +++ b/context_management/example/run_summarizer.sh @@ -26,7 +26,6 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.mode=async \ actor_rollout_ref.rollout.n=8 \ actor_rollout_ref.rollout.multi_turn.enable=True \ - `# --- context-management wiring (the point of this recipe) ---` \ actor_rollout_ref.rollout.agent.agent_loop_config_path="$AGENT_LOOP_CONFIG" \ actor_rollout_ref.rollout.agent.default_agent_loop=naive_summarizer_agent \ trainer.logger='["console"]' \