From 0952f09e419dc3c871896574818fb27a549d576b Mon Sep 17 00:00:00 2001 From: Alberto Gimeno Date: Wed, 17 Jun 2026 14:30:19 +0200 Subject: [PATCH 1/3] go: plumb context cancellation through ToolInvocation for session.Abort() Add a Context field to ToolInvocation that is cancelled when session.Abort is called. This lets tool handlers cooperatively stop in-flight work (HTTP requests, DB queries, sleeps) when the session is aborted, without requiring OS-level process kills. Changes: - Add Context context.Context to ToolInvocation; it carries OTel trace propagation AND is cancelled on Abort() - Deprecate TraceContext in favour of Context (same value, new name that better communicates its dual role) - Track per-tool-call cancel funcs in Session.toolCallCancels; clean up after each call completes - In Abort(), cancel all in-flight cancel funcs after the RPC call - RPC response calls (HandlePendingToolCall) use traceCtx (non-cancellable) so they succeed even when the handler context was aborted - Add TestToolInvocation_ContextCancelledOnAbort and TestToolInvocation_ContextPopulated unit tests - Document the cancellation pattern in go/README.md Fixes #1433 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- go/README.md | 25 ++++++ go/session.go | 44 +++++++++- go/session_test.go | 201 +++++++++++++++++++++++++++++++++++++++++++++ go/types.go | 14 +++- 4 files changed, 279 insertions(+), 5 deletions(-) diff --git a/go/README.md b/go/README.md index b89a76318..bc75faebb 100644 --- a/go/README.md +++ b/go/README.md @@ -355,6 +355,31 @@ session, _ := client.CreateSession(context.Background(), &copilot.SessionConfig{ When the model selects a tool, the SDK automatically runs your handler (in parallel with other calls) and responds to the CLI's `tool.call` with the handler's result. +#### Cooperative Cancellation via session.Abort + +`ToolInvocation.Context` is a `context.Context` that is cancelled when `session.Abort` is called. Pass it to any cancellable operation (HTTP requests, DB queries, sleeps) so the handler stops promptly when the session is aborted: + +```go +lookupIssue := copilot.DefineTool("lookup_issue", "Fetch issue details from our tracker", + func(params LookupIssueParams, inv copilot.ToolInvocation) (any, error) { + // Pass inv.Context so the HTTP request is cancelled on session.Abort. + req, err := http.NewRequestWithContext(inv.Context, "GET", + "https://api.example.com/issues/"+params.ID, nil) + if err != nil { + return nil, err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err // returns context.Canceled when aborted + } + defer resp.Body.Close() + // ... + return summary, nil + }) +``` + +Handlers that don't use `inv.Context` are unaffected; they run to completion as before. + #### Overriding Built-in Tools If you register a tool with the same name as a built-in CLI tool (e.g. `edit_file`, `read_file`), the SDK will throw an error unless you explicitly opt in by setting `OverridesBuiltInTool = true`. This flag signals that you intend to replace the built-in tool with your custom implementation. diff --git a/go/session.go b/go/session.go index acd698677..5f42114af 100644 --- a/go/session.go +++ b/go/session.go @@ -82,6 +82,11 @@ type Session struct { capabilities SessionCapabilities capabilitiesMu sync.RWMutex + // toolCallCancels tracks cancel functions for in-flight tool calls so that + // Abort can propagate cancellation into handler contexts. + toolCallCancels map[string]context.CancelFunc + toolCallCancelsMu sync.Mutex + // eventCh serializes user event handler dispatch. dispatchEvent enqueues; // a single goroutine (processEvents) dequeues and invokes handlers in FIFO order. eventCh chan SessionEvent @@ -1337,11 +1342,35 @@ func (s *Session) handleBroadcastEvent(event SessionEvent) { // executeToolAndRespond executes a tool handler and sends the result back via RPC. func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string, arguments any, handler ToolHandler, traceparent, tracestate string) { - ctx := contextWithTraceParent(context.Background(), traceparent, tracestate) + // traceCtx carries OTel trace propagation but is not subject to abort cancellation. + // It is used for administrative RPC calls that must complete regardless of abort. + traceCtx := contextWithTraceParent(context.Background(), traceparent, tracestate) + // ctx is passed to the tool handler and is cancelled when session.Abort is called, + // giving handlers a cooperative cancellation signal. + ctx, cancel := context.WithCancel(traceCtx) + + s.toolCallCancelsMu.Lock() + if s.toolCallCancels == nil { + s.toolCallCancels = make(map[string]context.CancelFunc) + } + s.toolCallCancels[toolCallID] = cancel + s.toolCallCancelsMu.Unlock() + + // Cleanup runs last (registered first). Removes the cancel from the in-flight map + // and releases context resources. + defer func() { + s.toolCallCancelsMu.Lock() + delete(s.toolCallCancels, toolCallID) + s.toolCallCancelsMu.Unlock() + cancel() + }() + + // Panic recovery runs first (registered second, LIFO). Uses traceCtx to ensure + // the error response is sent even if ctx was already cancelled by Abort. defer func() { if r := recover(); r != nil { errMsg := fmt.Sprintf("tool panic: %v", r) - s.RPC.Tools.HandlePendingToolCall(ctx, &rpc.HandlePendingToolCallRequest{ + s.RPC.Tools.HandlePendingToolCall(traceCtx, &rpc.HandlePendingToolCallRequest{ RequestID: requestID, Error: &errMsg, }) @@ -1353,13 +1382,14 @@ func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string, ToolCallID: toolCallID, ToolName: toolName, Arguments: arguments, + Context: ctx, TraceContext: ctx, } result, err := handler(invocation) if err != nil { errMsg := err.Error() - s.RPC.Tools.HandlePendingToolCall(ctx, &rpc.HandlePendingToolCallRequest{ + s.RPC.Tools.HandlePendingToolCall(traceCtx, &rpc.HandlePendingToolCallRequest{ RequestID: requestID, Error: &errMsg, }) @@ -1389,7 +1419,7 @@ func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string, if result.Error != "" { rpcResult.Error = &result.Error } - s.RPC.Tools.HandlePendingToolCall(ctx, &rpc.HandlePendingToolCallRequest{ + s.RPC.Tools.HandlePendingToolCall(traceCtx, &rpc.HandlePendingToolCallRequest{ RequestID: requestID, Result: rpcResult, }) @@ -1555,6 +1585,12 @@ func (s *Session) Abort(ctx context.Context) error { return fmt.Errorf("failed to abort session: %w", err) } + s.toolCallCancelsMu.Lock() + for _, cancel := range s.toolCallCancels { + cancel() + } + s.toolCallCancelsMu.Unlock() + return nil } diff --git a/go/session_test.go b/go/session_test.go index 15cfbcf57..2dfe22252 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -1031,3 +1031,204 @@ func TestSession_ElicitationRequestSchema(t *testing.T) { } }) } + +// TestToolInvocation_ContextCancelledOnAbort verifies that the context passed to a +// tool handler is cancelled when the in-flight cancel func (as used by Abort) fires. +func TestToolInvocation_ContextCancelledOnAbort(t *testing.T) { + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + defer stdinR.Close() + defer stdinW.Close() + defer stdoutR.Close() + defer stdoutW.Close() + + client := jsonrpc2.NewClient(stdinW, stdoutR) + client.Start() + defer client.Stop() + + session := &Session{ + SessionID: "session-abort-test", + client: client, + RPC: rpc.NewSessionRPC(client, "session-abort-test"), + } + + // Drain the RPC responses from the mock server side. + go func() { + scanner := bufio.NewScanner(stdinR) + for scanner.Scan() { + // read Content-Length header + line := scanner.Text() + if !strings.HasPrefix(line, "Content-Length:") { + continue + } + var contentLen int + fmt.Sscanf(line, "Content-Length: %d", &contentLen) + // skip blank separator + scanner.Scan() + body := make([]byte, contentLen) + io.ReadFull(stdinR, body) + + var req struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + } + if err := json.Unmarshal(body, &req); err != nil || req.ID == nil { + continue + } + resp, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": json.RawMessage(req.ID), + "result": map[string]any{}, + }) + fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(resp), resp) + } + }() + + // Channel to receive the invocation context from the handler. + ctxCh := make(chan context.Context, 1) + + // The handler blocks until its context is cancelled, then reports. + handler := ToolHandler(func(inv ToolInvocation) (ToolResult, error) { + ctxCh <- inv.Context + <-inv.Context.Done() + return ToolResult{TextResultForLLM: "cancelled"}, nil + }) + + done := make(chan struct{}) + go func() { + defer close(done) + session.executeToolAndRespond("req-1", "my_tool", "tc-1", nil, handler, "", "") + }() + + // Wait for the handler to start and capture its context. + var handlerCtx context.Context + select { + case handlerCtx = <-ctxCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handler to start") + } + + // Verify the context is not yet cancelled. + if handlerCtx.Err() != nil { + t.Fatalf("expected context to be active, got %v", handlerCtx.Err()) + } + + // Simulate what Abort() does: cancel all in-flight tool call contexts. + session.toolCallCancelsMu.Lock() + for _, cancel := range session.toolCallCancels { + cancel() + } + session.toolCallCancelsMu.Unlock() + + // Wait for the handler to finish. + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handler to finish after cancellation") + } + + // The handler context must be cancelled. + if handlerCtx.Err() == nil { + t.Fatal("expected handler context to be cancelled after abort") + } + + // The cancel func must have been removed from the map. + session.toolCallCancelsMu.Lock() + remaining := len(session.toolCallCancels) + session.toolCallCancelsMu.Unlock() + if remaining != 0 { + t.Fatalf("expected toolCallCancels to be empty after execution, got %d entries", remaining) + } +} + +// TestToolInvocation_ContextPopulated verifies that executeToolAndRespond sets +// both Context and TraceContext on the ToolInvocation passed to the handler. +func TestToolInvocation_ContextPopulated(t *testing.T) { + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + defer stdinR.Close() + defer stdinW.Close() + defer stdoutR.Close() + defer stdoutW.Close() + + client := jsonrpc2.NewClient(stdinW, stdoutR) + client.Start() + defer client.Stop() + + session := &Session{ + SessionID: "session-ctx-test", + client: client, + RPC: rpc.NewSessionRPC(client, "session-ctx-test"), + } + + // Drain RPC responses. + go func() { + scanner := bufio.NewScanner(stdinR) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "Content-Length:") { + continue + } + var contentLen int + fmt.Sscanf(line, "Content-Length: %d", &contentLen) + scanner.Scan() + body := make([]byte, contentLen) + io.ReadFull(stdinR, body) + + var req struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + } + if err := json.Unmarshal(body, &req); err != nil || req.ID == nil { + continue + } + resp, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": json.RawMessage(req.ID), + "result": map[string]any{}, + }) + fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(resp), resp) + } + }() + + invCh := make(chan ToolInvocation, 1) + handler := ToolHandler(func(inv ToolInvocation) (ToolResult, error) { + invCh <- inv + return ToolResult{TextResultForLLM: "ok"}, nil + }) + + done := make(chan struct{}) + go func() { + defer close(done) + session.executeToolAndRespond("req-2", "check_tool", "tc-2", map[string]any{"x": 1}, handler, "", "") + }() + + var inv ToolInvocation + select { + case inv = <-invCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handler invocation") + } + + if inv.Context == nil { + t.Fatal("expected ToolInvocation.Context to be set") + } + if inv.TraceContext == nil { + t.Fatal("expected ToolInvocation.TraceContext to be set") + } + if inv.Context != inv.TraceContext { + t.Error("expected Context and TraceContext to be the same value") + } + if inv.SessionID != "session-ctx-test" { + t.Errorf("expected SessionID session-ctx-test, got %q", inv.SessionID) + } + if inv.ToolCallID != "tc-2" { + t.Errorf("expected ToolCallID tc-2, got %q", inv.ToolCallID) + } + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for executeToolAndRespond to complete") + } +} diff --git a/go/types.go b/go/types.go index 5ed0b6931..9650de59f 100644 --- a/go/types.go +++ b/go/types.go @@ -1154,10 +1154,22 @@ type ToolInvocation struct { ToolName string Arguments any + // Context is the primary context for this tool invocation. It carries + // W3C Trace Context propagation (for OpenTelemetry) and is cancelled + // when session.Abort is called, allowing handlers to cooperatively stop + // in-flight work (e.g. pass to http.NewRequestWithContext, sql.QueryContext). + // + // Handlers that do not inspect the context continue to work unchanged. + Context context.Context + + // TraceContext is deprecated: use Context instead. // TraceContext carries the W3C Trace Context propagated from the CLI's - // execute_tool span. Pass this to OpenTelemetry-aware code so that + // execute_tool span. Pass this to OpenTelemetry-aware code so that // child spans created inside the handler are parented to the CLI span. // When no trace context is available this will be context.Background(). + // + // Deprecated: Use Context, which carries the same trace information and + // is additionally cancelled when session.Abort is called. TraceContext context.Context } From 036bdac2692b8c34952d8ad028c89b2a979bfe42 Mon Sep 17 00:00:00 2001 From: Alberto Gimeno Date: Wed, 17 Jun 2026 14:34:03 +0200 Subject: [PATCH 2/3] go: add CancelToolCall for single-handler granular cancellation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add Session.CancelToolCall(toolCallID string) bool — cancels a single in-flight tool handler by its tool call ID without affecting other concurrent handlers or the agentic loop. - CancelToolCall looks up the cancel func in toolCallCancels, calls it, removes the entry, and returns true; returns false if not in flight - Abort() now also removes entries after cancelling them, so a subsequent CancelToolCall on the same ID correctly returns false - Extract newRPCDrainSession test helper to eliminate boilerplate across the three tool-context tests - Add TestCancelToolCall_cancelsTargetedHandlerOnly covering: unknown ID returns false, targeted handler context is cancelled, sibling handler context stays live, idempotent false on second call - Document CancelToolCall in go/README.md alongside Abort cancellation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- go/README.md | 34 +++++++ go/session.go | 31 +++++- go/session_test.go | 228 ++++++++++++++++++++++++++++----------------- 3 files changed, 206 insertions(+), 87 deletions(-) diff --git a/go/README.md b/go/README.md index bc75faebb..ac91e2944 100644 --- a/go/README.md +++ b/go/README.md @@ -380,6 +380,40 @@ lookupIssue := copilot.DefineTool("lookup_issue", "Fetch issue details from our Handlers that don't use `inv.Context` are unaffected; they run to completion as before. +#### Cancelling a single tool call + +Use `session.CancelToolCall(toolCallID)` to cancel one specific in-flight handler without aborting the session or any other concurrent handlers. It returns `true` if the tool call was found and cancelled, `false` if it was not in flight. + +```go +// Cancel a specific tool call by its ID. +if cancelled := session.CancelToolCall(toolCallID); !cancelled { + log.Println("tool call was not in flight") +} +``` + +`toolCallID` is available inside the handler as `inv.ToolCallID`. You can capture it to enable external cancellation of a specific operation: + +```go +var mu sync.Mutex +activeCalls := map[string]string{} // label → toolCallID + +slowTool := copilot.DefineTool("slow_op", "A long-running operation", + func(params SlowOpParams, inv copilot.ToolInvocation) (any, error) { + mu.Lock() + activeCalls[params.Label] = inv.ToolCallID + mu.Unlock() + + // ... do work, checking inv.Context.Done() ... + return result, nil + }) + +// Elsewhere, cancel by label: +mu.Lock() +id := activeCalls["my-label"] +mu.Unlock() +session.CancelToolCall(id) +``` + #### Overriding Built-in Tools If you register a tool with the same name as a built-in CLI tool (e.g. `edit_file`, `read_file`), the SDK will throw an error unless you explicitly opt in by setting `OverridesBuiltInTool = true`. This flag signals that you intend to replace the built-in tool with your custom implementation. diff --git a/go/session.go b/go/session.go index 5f42114af..3a5255c49 100644 --- a/go/session.go +++ b/go/session.go @@ -1586,14 +1586,43 @@ func (s *Session) Abort(ctx context.Context) error { } s.toolCallCancelsMu.Lock() - for _, cancel := range s.toolCallCancels { + for id, cancel := range s.toolCallCancels { cancel() + delete(s.toolCallCancels, id) } s.toolCallCancelsMu.Unlock() return nil } +// CancelToolCall cancels a single in-flight tool handler identified by toolCallID +// without aborting the agentic loop or any other concurrent tool handlers. +// +// It looks up the cancel func registered when the handler was dispatched, calls it +// (cancelling the context passed to that handler via ToolInvocation.Context), removes +// the entry from the in-flight map, and returns true. If no handler with the given +// toolCallID is currently executing, CancelToolCall is a no-op and returns false. +// +// Example: +// +// // Start a session with a long-running tool registered. +// // Later, cancel only a specific tool call without aborting the session: +// if cancelled := session.CancelToolCall("tool-call-id-123"); !cancelled { +// log.Println("tool call was not in flight") +// } +func (s *Session) CancelToolCall(toolCallID string) bool { + s.toolCallCancelsMu.Lock() + defer s.toolCallCancelsMu.Unlock() + + cancel, ok := s.toolCallCancels[toolCallID] + if !ok { + return false + } + cancel() + delete(s.toolCallCancels, toolCallID) + return true +} + // SetModelOptions configures optional parameters for SetModel. type SetModelOptions struct { // ReasoningEffort sets the reasoning effort level for the new model (e.g., "low", "medium", "high", "xhigh"). diff --git a/go/session_test.go b/go/session_test.go index 2dfe22252..e6898d65a 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -1035,54 +1035,8 @@ func TestSession_ElicitationRequestSchema(t *testing.T) { // TestToolInvocation_ContextCancelledOnAbort verifies that the context passed to a // tool handler is cancelled when the in-flight cancel func (as used by Abort) fires. func TestToolInvocation_ContextCancelledOnAbort(t *testing.T) { - stdinR, stdinW := io.Pipe() - stdoutR, stdoutW := io.Pipe() - defer stdinR.Close() - defer stdinW.Close() - defer stdoutR.Close() - defer stdoutW.Close() - - client := jsonrpc2.NewClient(stdinW, stdoutR) - client.Start() - defer client.Stop() - - session := &Session{ - SessionID: "session-abort-test", - client: client, - RPC: rpc.NewSessionRPC(client, "session-abort-test"), - } - - // Drain the RPC responses from the mock server side. - go func() { - scanner := bufio.NewScanner(stdinR) - for scanner.Scan() { - // read Content-Length header - line := scanner.Text() - if !strings.HasPrefix(line, "Content-Length:") { - continue - } - var contentLen int - fmt.Sscanf(line, "Content-Length: %d", &contentLen) - // skip blank separator - scanner.Scan() - body := make([]byte, contentLen) - io.ReadFull(stdinR, body) - - var req struct { - ID json.RawMessage `json:"id"` - Method string `json:"method"` - } - if err := json.Unmarshal(body, &req); err != nil || req.ID == nil { - continue - } - resp, _ := json.Marshal(map[string]any{ - "jsonrpc": "2.0", - "id": json.RawMessage(req.ID), - "result": map[string]any{}, - }) - fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(resp), resp) - } - }() + session, cleanup := newRPCDrainSession(t, "session-abort-test") + defer cleanup() // Channel to receive the invocation context from the handler. ctxCh := make(chan context.Context, 1) @@ -1144,24 +1098,69 @@ func TestToolInvocation_ContextCancelledOnAbort(t *testing.T) { // TestToolInvocation_ContextPopulated verifies that executeToolAndRespond sets // both Context and TraceContext on the ToolInvocation passed to the handler. func TestToolInvocation_ContextPopulated(t *testing.T) { + session, cleanup := newRPCDrainSession(t, "session-ctx-test") + defer cleanup() + + invCh := make(chan ToolInvocation, 1) + handler := ToolHandler(func(inv ToolInvocation) (ToolResult, error) { + invCh <- inv + return ToolResult{TextResultForLLM: "ok"}, nil + }) + + done := make(chan struct{}) + go func() { + defer close(done) + session.executeToolAndRespond("req-2", "check_tool", "tc-2", map[string]any{"x": 1}, handler, "", "") + }() + + var inv ToolInvocation + select { + case inv = <-invCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handler invocation") + } + + if inv.Context == nil { + t.Fatal("expected ToolInvocation.Context to be set") + } + if inv.TraceContext == nil { + t.Fatal("expected ToolInvocation.TraceContext to be set") + } + if inv.Context != inv.TraceContext { + t.Error("expected Context and TraceContext to be the same value") + } + if inv.SessionID != "session-ctx-test" { + t.Errorf("expected SessionID session-ctx-test, got %q", inv.SessionID) + } + if inv.ToolCallID != "tc-2" { + t.Errorf("expected ToolCallID tc-2, got %q", inv.ToolCallID) + } + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for executeToolAndRespond to complete") + } +} + +// newRPCDrainSession creates a Session backed by a pipe-based JSON-RPC client +// whose server side drains requests and returns empty success responses. +// The caller must close stdinW and stdoutW when done. +func newRPCDrainSession(t *testing.T, sessionID string) (*Session, func()) { + t.Helper() stdinR, stdinW := io.Pipe() stdoutR, stdoutW := io.Pipe() - defer stdinR.Close() - defer stdinW.Close() - defer stdoutR.Close() - defer stdoutW.Close() client := jsonrpc2.NewClient(stdinW, stdoutR) client.Start() - defer client.Stop() session := &Session{ - SessionID: "session-ctx-test", + SessionID: sessionID, client: client, - RPC: rpc.NewSessionRPC(client, "session-ctx-test"), + RPC: rpc.NewSessionRPC(client, sessionID), } - // Drain RPC responses. + // Drain goroutine: read every RPC request and send an empty success response. go func() { scanner := bufio.NewScanner(stdinR) for scanner.Scan() { @@ -1171,13 +1170,12 @@ func TestToolInvocation_ContextPopulated(t *testing.T) { } var contentLen int fmt.Sscanf(line, "Content-Length: %d", &contentLen) - scanner.Scan() + scanner.Scan() // blank separator body := make([]byte, contentLen) - io.ReadFull(stdinR, body) + io.ReadFull(stdinR, body) //nolint:errcheck var req struct { - ID json.RawMessage `json:"id"` - Method string `json:"method"` + ID json.RawMessage `json:"id"` } if err := json.Unmarshal(body, &req); err != nil || req.ID == nil { continue @@ -1187,48 +1185,106 @@ func TestToolInvocation_ContextPopulated(t *testing.T) { "id": json.RawMessage(req.ID), "result": map[string]any{}, }) - fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(resp), resp) + fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(resp), resp) //nolint:errcheck } }() - invCh := make(chan ToolInvocation, 1) - handler := ToolHandler(func(inv ToolInvocation) (ToolResult, error) { - invCh <- inv - return ToolResult{TextResultForLLM: "ok"}, nil - }) + cleanup := func() { + client.Stop() + stdinR.Close() + stdinW.Close() + stdoutR.Close() + stdoutW.Close() + } + return session, cleanup +} - done := make(chan struct{}) +// TestCancelToolCall_cancelsTargetedHandlerOnly verifies that CancelToolCall +// cancels only the specified handler's context while leaving concurrent +// handlers unaffected, and returns false for an unknown tool call ID. +func TestCancelToolCall_cancelsTargetedHandlerOnly(t *testing.T) { + session, cleanup := newRPCDrainSession(t, "session-cancel-test") + defer cleanup() + + type handlerState struct { + ctx context.Context + done chan struct{} + } + + makeBlockingHandler := func(id string) (ToolHandler, *handlerState) { + state := &handlerState{done: make(chan struct{})} + h := ToolHandler(func(inv ToolInvocation) (ToolResult, error) { + state.ctx = inv.Context + <-inv.Context.Done() // block until cancelled + return ToolResult{TextResultForLLM: "cancelled"}, nil + }) + return h, state + } + + h1, s1 := makeBlockingHandler("tc-a") + h2, s2 := makeBlockingHandler("tc-b") + + // Start both handlers concurrently. go func() { - defer close(done) - session.executeToolAndRespond("req-2", "check_tool", "tc-2", map[string]any{"x": 1}, handler, "", "") + defer close(s1.done) + session.executeToolAndRespond("req-a", "tool_a", "tc-a", nil, h1, "", "") + }() + go func() { + defer close(s2.done) + session.executeToolAndRespond("req-b", "tool_b", "tc-b", nil, h2, "", "") }() - var inv ToolInvocation - select { - case inv = <-invCh: - case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for handler invocation") + // Wait for both handlers to start (ctx will be set once they block). + deadline := time.After(2 * time.Second) + for s1.ctx == nil || s2.ctx == nil { + select { + case <-deadline: + t.Fatal("timed out waiting for handlers to start") + default: + time.Sleep(5 * time.Millisecond) + } } - if inv.Context == nil { - t.Fatal("expected ToolInvocation.Context to be set") + // Unknown ID returns false and leaves both handlers running. + if got := session.CancelToolCall("nonexistent"); got { + t.Fatal("expected CancelToolCall(unknown) to return false") } - if inv.TraceContext == nil { - t.Fatal("expected ToolInvocation.TraceContext to be set") + if s1.ctx.Err() != nil { + t.Fatal("handler 1 should still be running after unknown CancelToolCall") } - if inv.Context != inv.TraceContext { - t.Error("expected Context and TraceContext to be the same value") + if s2.ctx.Err() != nil { + t.Fatal("handler 2 should still be running after unknown CancelToolCall") } - if inv.SessionID != "session-ctx-test" { - t.Errorf("expected SessionID session-ctx-test, got %q", inv.SessionID) + + // Cancel only handler 1. + if got := session.CancelToolCall("tc-a"); !got { + t.Fatal("expected CancelToolCall(tc-a) to return true") } - if inv.ToolCallID != "tc-2" { - t.Errorf("expected ToolCallID tc-2, got %q", inv.ToolCallID) + + // Handler 1 should finish; handler 2 should remain live. + select { + case <-s1.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handler 1 to finish after CancelToolCall") + } + + if s1.ctx.Err() == nil { + t.Fatal("expected handler 1 context to be cancelled") + } + if s2.ctx.Err() != nil { + t.Fatal("handler 2 context should still be live") } + // CancelToolCall on the same ID again returns false (already removed). + if got := session.CancelToolCall("tc-a"); got { + t.Fatal("expected second CancelToolCall(tc-a) to return false") + } + + // Cancel handler 2 to let it finish and avoid goroutine leak. + session.CancelToolCall("tc-b") select { - case <-done: + case <-s2.done: case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for executeToolAndRespond to complete") + t.Fatal("timed out waiting for handler 2 to finish") } } From d11fed10e063422904007f4ae77530ba0f150b01 Mon Sep 17 00:00:00 2001 From: Alberto Gimeno Date: Wed, 17 Jun 2026 16:07:32 +0200 Subject: [PATCH 3/3] go: address Copilot PR feedback on ToolInvocation context fields - Set TraceContext to traceCtx (non-cancellable) to preserve backward compatibility; TraceContext is no longer affected by session.Abort() or CancelToolCall(). Only Context (the new field) is cancelled. - Fix newRPCDrainSession drain goroutine: replace bufio.Scanner + io.ReadFull on the same reader (which causes buffer-ahead corruption) with a single bufio.Reader used for both header line reads and body ReadFull calls. - Update TestToolInvocation_ContextPopulated assertion: Context and TraceContext are now different instances (Context is a cancellable child of TraceContext), so assert they differ rather than being equal. - Clarify TraceContext deprecation comment: document it is never cancelled. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- go/session.go | 2 +- go/session_test.go | 38 ++++++++++++++++++++++++++++---------- go/types.go | 5 ++++- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/go/session.go b/go/session.go index 3a5255c49..c35cc367e 100644 --- a/go/session.go +++ b/go/session.go @@ -1383,7 +1383,7 @@ func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string, ToolName: toolName, Arguments: arguments, Context: ctx, - TraceContext: ctx, + TraceContext: traceCtx, } result, err := handler(invocation) diff --git a/go/session_test.go b/go/session_test.go index e6898d65a..e2356f3d3 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -1126,8 +1126,10 @@ func TestToolInvocation_ContextPopulated(t *testing.T) { if inv.TraceContext == nil { t.Fatal("expected ToolInvocation.TraceContext to be set") } - if inv.Context != inv.TraceContext { - t.Error("expected Context and TraceContext to be the same value") + // Context is a cancellable child of TraceContext; they are different instances. + // Both must be non-nil and Context must be cancellable independently. + if inv.Context == inv.TraceContext { + t.Error("expected Context and TraceContext to be different instances (Context is a cancellable child)") } if inv.SessionID != "session-ctx-test" { t.Errorf("expected SessionID session-ctx-test, got %q", inv.SessionID) @@ -1161,18 +1163,34 @@ func newRPCDrainSession(t *testing.T, sessionID string) (*Session, func()) { } // Drain goroutine: read every RPC request and send an empty success response. + // Uses a single bufio.Reader so that header parsing and body reads share the + // same read buffer — mixing bufio.Scanner with io.ReadFull on the same reader + // causes data corruption because Scanner may buffer-ahead bytes that + // io.ReadFull then misses. go func() { - scanner := bufio.NewScanner(stdinR) - for scanner.Scan() { - line := scanner.Text() - if !strings.HasPrefix(line, "Content-Length:") { + br := bufio.NewReader(stdinR) + for { + // Read headers until blank line. + var contentLen int + for { + line, err := br.ReadString('\n') + if err != nil { + return + } + line = strings.TrimRight(line, "\r\n") + if line == "" { + break // end of headers + } + fmt.Sscanf(line, "Content-Length: %d", &contentLen) + } + if contentLen == 0 { continue } - var contentLen int - fmt.Sscanf(line, "Content-Length: %d", &contentLen) - scanner.Scan() // blank separator + body := make([]byte, contentLen) - io.ReadFull(stdinR, body) //nolint:errcheck + if _, err := io.ReadFull(br, body); err != nil { + return + } var req struct { ID json.RawMessage `json:"id"` diff --git a/go/types.go b/go/types.go index 9650de59f..749b44a7e 100644 --- a/go/types.go +++ b/go/types.go @@ -1167,9 +1167,12 @@ type ToolInvocation struct { // execute_tool span. Pass this to OpenTelemetry-aware code so that // child spans created inside the handler are parented to the CLI span. // When no trace context is available this will be context.Background(). + // Unlike Context, TraceContext is never cancelled — it remains valid for + // the lifetime of the RPC call regardless of session.Abort. // // Deprecated: Use Context, which carries the same trace information and - // is additionally cancelled when session.Abort is called. + // is additionally cancelled when session.Abort or session.CancelToolCall + // is called. TraceContext context.Context }