Skip to content

Commit 928b6ee

Browse files
committed
Add request preprocessing and response postprocessing to MainLoop
Add two overridable methods for transforming requests before agent processing and responses before returning to callers: - preprocess_request(request) - override to transform incoming requests - postprocess_response(response) - override to transform outgoing responses Both methods return their input unchanged by default. Subclass TriviaAgentLoop and override these methods to implement custom preprocessing/postprocessing logic.
1 parent 625be17 commit 928b6ee

2 files changed

Lines changed: 269 additions & 2 deletions

File tree

src/trivia_agent/worker.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,21 @@
1414

1515
import os
1616
import sys
17-
from collections.abc import Sequence
17+
from collections.abc import Mapping, Sequence
1818
from dataclasses import field
1919
from datetime import UTC, datetime, timedelta
2020
from pathlib import Path
2121
from typing import TYPE_CHECKING, TextIO
2222

23-
from weakincentives import FrozenDataclass, Prompt
23+
from weakincentives import Budget, FrozenDataclass, Prompt, PromptResponse
2424
from weakincentives.adapters import ProviderAdapter
2525
from weakincentives.adapters.claude_agent_sdk import ClaudeAgentWorkspaceSection, HostMount
2626
from weakincentives.deadlines import Deadline
2727
from weakincentives.debug.bundle import BundleConfig
2828
from weakincentives.prompt import PromptTemplate
2929
from weakincentives.prompt.overrides import LocalPromptOverridesStore, PromptOverridesStore
3030
from weakincentives.runtime import (
31+
Heartbeat,
3132
LoopGroup,
3233
MainLoop,
3334
MainLoopConfig,
@@ -204,6 +205,10 @@ class TriviaAgentLoop(MainLoop[TriviaRequest, TriviaResponse]):
204205
customization of prompt content without code changes
205206
- **Evaluation integration**: Compatible with EvalLoop for running
206207
evaluations with session-aware scoring
208+
- **Request preprocessing**: Override preprocess_request() to transform
209+
requests before agent processing
210+
- **Response postprocessing**: Override postprocess_response() to transform
211+
responses before returning to callers
207212
208213
To use this loop, instantiate it with an adapter and mailbox, then either:
209214
1. Call run() directly for single-threaded processing
@@ -265,6 +270,34 @@ def __init__(
265270
self._base_template = build_prompt_template()
266271
self._overrides_store = overrides_store
267272

273+
def preprocess_request(self, request: TriviaRequest) -> TriviaRequest:
274+
"""Transform request before agent processing.
275+
276+
Override this method in subclasses to implement custom preprocessing
277+
logic such as validation, normalization, or enrichment.
278+
279+
Args:
280+
request: The incoming TriviaRequest.
281+
282+
Returns:
283+
TriviaRequest: The preprocessed request.
284+
"""
285+
return request
286+
287+
def postprocess_response(self, response: TriviaResponse) -> TriviaResponse:
288+
"""Transform response before returning to caller.
289+
290+
Override this method in subclasses to implement custom postprocessing
291+
logic such as formatting, cleanup, or validation.
292+
293+
Args:
294+
response: The TriviaResponse from the agent.
295+
296+
Returns:
297+
TriviaResponse: The postprocessed response.
298+
"""
299+
return response
300+
268301
def prepare(
269302
self,
270303
request: TriviaRequest,
@@ -277,6 +310,10 @@ def prepare(
277310
for isolation, builds the complete PromptTemplate with workspace section,
278311
binds request parameters, and optionally applies experiment overrides.
279312
313+
Note: Request preprocessing is applied in the execute() method before
314+
prepare() is called. This ensures the preprocessed request is used
315+
consistently throughout the execution flow.
316+
280317
This method demonstrates key WINK patterns:
281318
282319
- **Session per request**: Each request gets its own Session for proper
@@ -348,6 +385,53 @@ def prepare(
348385

349386
return prompt, session
350387

388+
def execute(
389+
self,
390+
request: TriviaRequest,
391+
*,
392+
budget: Budget | None = None,
393+
deadline: Deadline | None = None,
394+
resources: Mapping[type[object], object] | None = None,
395+
heartbeat: Heartbeat | None = None,
396+
experiment: Experiment | None = None,
397+
) -> tuple[PromptResponse[TriviaResponse], Session]:
398+
"""Execute a trivia request with preprocessing and postprocessing.
399+
400+
Overrides the parent MainLoop.execute() to apply preprocess_request()
401+
before execution and postprocess_response() after execution.
402+
403+
Args:
404+
request: TriviaRequest containing the question to process.
405+
budget: Optional Budget for token/cost limits.
406+
deadline: Optional Deadline for time limits.
407+
resources: Optional mapping of resource types to instances.
408+
heartbeat: Optional Heartbeat for progress reporting.
409+
experiment: Optional Experiment for evaluation runs.
410+
411+
Returns:
412+
tuple[PromptResponse[TriviaResponse], Session]: Response and session.
413+
"""
414+
# Apply preprocessing to the request
415+
preprocessed_request = self.preprocess_request(request)
416+
417+
# Execute with the preprocessed request
418+
prompt_response, session = super().execute(
419+
preprocessed_request,
420+
budget=budget,
421+
deadline=deadline,
422+
resources=resources,
423+
heartbeat=heartbeat,
424+
experiment=experiment,
425+
)
426+
427+
# Apply postprocessing to the response output (if present)
428+
output = prompt_response.output
429+
if output is not None:
430+
postprocessed_output = self.postprocess_response(output)
431+
return prompt_response.update(output=postprocessed_output), session # type: ignore[return-value]
432+
433+
return prompt_response, session
434+
351435

352436
@FrozenDataclass()
353437
class TriviaRuntime:

tests/trivia_agent/test_worker.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,189 @@ def test_prepare_seeds_overrides_store(
238238
assert call_args.kwargs.get("tag") == "latest"
239239

240240

241+
class TestTriviaAgentLoopPreprocessing:
242+
"""Tests for TriviaAgentLoop.preprocess_request() method."""
243+
244+
def test_preprocess_request_returns_unchanged_by_default(
245+
self,
246+
fake_mailboxes: TriviaMailboxes,
247+
) -> None:
248+
"""Test that preprocess_request returns request unchanged by default."""
249+
mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock()
250+
251+
loop = TriviaAgentLoop(
252+
adapter=mock_adapter,
253+
requests=fake_mailboxes.requests,
254+
)
255+
256+
request = TriviaRequest(question="test question")
257+
result = loop.preprocess_request(request)
258+
259+
assert result is request
260+
261+
def test_preprocess_request_can_be_overridden(
262+
self,
263+
fake_mailboxes: TriviaMailboxes,
264+
) -> None:
265+
"""Test that preprocess_request can be overridden in subclass."""
266+
267+
class CustomLoop(TriviaAgentLoop):
268+
def preprocess_request(self, request: TriviaRequest) -> TriviaRequest:
269+
return TriviaRequest(question=request.question.upper())
270+
271+
mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock()
272+
273+
loop = CustomLoop(
274+
adapter=mock_adapter,
275+
requests=fake_mailboxes.requests,
276+
)
277+
278+
request = TriviaRequest(question="hello world")
279+
result = loop.preprocess_request(request)
280+
281+
assert result.question == "HELLO WORLD"
282+
283+
284+
class TestTriviaAgentLoopPostprocessing:
285+
"""Tests for TriviaAgentLoop.postprocess_response() method."""
286+
287+
def test_postprocess_response_returns_unchanged_by_default(
288+
self,
289+
fake_mailboxes: TriviaMailboxes,
290+
) -> None:
291+
"""Test that postprocess_response returns response unchanged by default."""
292+
mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock()
293+
294+
loop = TriviaAgentLoop(
295+
adapter=mock_adapter,
296+
requests=fake_mailboxes.requests,
297+
)
298+
299+
response = TriviaResponse(answer="42")
300+
result = loop.postprocess_response(response)
301+
302+
assert result is response
303+
304+
def test_postprocess_response_can_be_overridden(
305+
self,
306+
fake_mailboxes: TriviaMailboxes,
307+
) -> None:
308+
"""Test that postprocess_response can be overridden in subclass."""
309+
310+
class CustomLoop(TriviaAgentLoop):
311+
def postprocess_response(self, response: TriviaResponse) -> TriviaResponse:
312+
return TriviaResponse(answer=f"Answer: {response.answer}")
313+
314+
mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock()
315+
316+
loop = CustomLoop(
317+
adapter=mock_adapter,
318+
requests=fake_mailboxes.requests,
319+
)
320+
321+
response = TriviaResponse(answer="42")
322+
result = loop.postprocess_response(response)
323+
324+
assert result.answer == "Answer: 42"
325+
326+
327+
class TestTriviaAgentLoopExecute:
328+
"""Tests for TriviaAgentLoop.execute() with preprocessing/postprocessing."""
329+
330+
def test_execute_calls_preprocess_request(
331+
self,
332+
fake_mailboxes: TriviaMailboxes,
333+
) -> None:
334+
"""Test that execute() calls preprocess_request."""
335+
from weakincentives.runtime import MainLoop
336+
337+
class CustomLoop(TriviaAgentLoop):
338+
def preprocess_request(self, request: TriviaRequest) -> TriviaRequest:
339+
return TriviaRequest(question=request.question.strip())
340+
341+
mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock()
342+
343+
loop = CustomLoop(
344+
adapter=mock_adapter,
345+
requests=fake_mailboxes.requests,
346+
)
347+
348+
mock_response = MagicMock()
349+
mock_response.output = TriviaResponse(answer="42")
350+
mock_session = MagicMock()
351+
352+
captured_requests: list[TriviaRequest] = []
353+
354+
def capture_execute(self_arg, request, **kwargs):
355+
captured_requests.append(request)
356+
return (mock_response, mock_session)
357+
358+
with patch.object(MainLoop, "execute", capture_execute):
359+
request = TriviaRequest(question=" What is the answer? ")
360+
loop.execute(request)
361+
362+
assert len(captured_requests) == 1
363+
assert captured_requests[0].question == "What is the answer?"
364+
365+
def test_execute_calls_postprocess_response(
366+
self,
367+
fake_mailboxes: TriviaMailboxes,
368+
) -> None:
369+
"""Test that execute() calls postprocess_response."""
370+
from weakincentives.runtime import MainLoop
371+
372+
class CustomLoop(TriviaAgentLoop):
373+
def postprocess_response(self, response: TriviaResponse) -> TriviaResponse:
374+
return TriviaResponse(answer=response.answer.strip())
375+
376+
mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock()
377+
378+
loop = CustomLoop(
379+
adapter=mock_adapter,
380+
requests=fake_mailboxes.requests,
381+
)
382+
383+
mock_response = MagicMock()
384+
mock_response.output = TriviaResponse(answer=" 42 ")
385+
mock_response.update = MagicMock(return_value=mock_response)
386+
mock_session = MagicMock()
387+
388+
with patch.object(MainLoop, "execute", return_value=(mock_response, mock_session)):
389+
request = TriviaRequest(question="What is the answer?")
390+
loop.execute(request)
391+
392+
mock_response.update.assert_called_once()
393+
call_kwargs = mock_response.update.call_args.kwargs
394+
assert call_kwargs["output"].answer == "42"
395+
396+
def test_execute_skips_postprocessing_when_output_is_none(
397+
self,
398+
fake_mailboxes: TriviaMailboxes,
399+
) -> None:
400+
"""Test that execute() skips postprocessing when output is None."""
401+
from weakincentives.runtime import MainLoop
402+
403+
mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock()
404+
405+
loop = TriviaAgentLoop(
406+
adapter=mock_adapter,
407+
requests=fake_mailboxes.requests,
408+
)
409+
410+
mock_response = MagicMock()
411+
mock_response.output = None
412+
mock_response.update = MagicMock()
413+
mock_session = MagicMock()
414+
415+
with patch.object(MainLoop, "execute", return_value=(mock_response, mock_session)):
416+
request = TriviaRequest(question="What is the answer?")
417+
result_response, result_session = loop.execute(request)
418+
419+
mock_response.update.assert_not_called()
420+
assert result_response is mock_response
421+
assert result_session is mock_session
422+
423+
241424
class TestTriviaRuntime:
242425
"""Tests for TriviaRuntime dataclass."""
243426

0 commit comments

Comments
 (0)