Skip to content

Commit 0ba9325

Browse files
Add getBearerToken callback for BYOK providers (Managed Identity)
Lets BYOK provider configs supply a `getBearerToken` callback so the SDK consumer resolves bearer tokens (e.g. Azure Managed Identity via @azure/identity) on demand. The callback never crosses the wire: the SDK strips it from the provider config, sends a `hasBearerTokenProvider: true` flag, and answers the runtime's session-scoped `providerToken.acquire` RPC by routing to the matching per-provider callback. The returned token is applied as the Authorization header for outbound model requests; the consumer owns caching/refresh. - client.ts: strip the callback, emit the `hasBearerTokenProvider` wire flag, register per-provider callbacks on the session. - session.ts: handle `providerToken.acquire` by dispatching on provider name. - types.ts: public `getBearerToken` / `ProviderTokenArgs` / `ProviderBearerToken`. - generated/rpc.ts: regenerated contract (providerToken.acquire + hasBearerTokenProvider/bearerTokenScope fields). - e2e: callback token reaches model, refresh-on-expiry, per-provider dispatch. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent d15cfcb commit 0ba9325

8 files changed

Lines changed: 601 additions & 4 deletions

nodejs/src/client.ts

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,14 @@ import type {
4747
ExitPlanModeResult,
4848
ForegroundSessionInfo,
4949
GetAuthStatusResponse,
50+
GetBearerToken,
5051
GetStatusResponse,
5152
InternalRuntimeConnection,
5253
LargeToolOutputConfig,
5354
MCPServerConfig,
5455
ModelInfo,
56+
NamedProviderConfig,
57+
ProviderConfig,
5558
ResumeSessionConfig,
5659
SectionTransformFn,
5760
SessionConfig,
@@ -150,6 +153,62 @@ function toJsonSchema(parameters: Tool["parameters"]): Record<string, unknown> |
150153
return parameters;
151154
}
152155

156+
/** Implicit provider name for the singular, whole-session {@link ProviderConfig}. */
157+
const DEFAULT_PROVIDER_NAME = "default";
158+
159+
/** Wire-safe singular provider config carrying the `hasBearerTokenProvider` flag. */
160+
type WireProviderConfig = Omit<ProviderConfig, "getBearerToken"> & { hasBearerTokenProvider?: boolean };
161+
162+
/** Wire-safe named provider config carrying the `hasBearerTokenProvider` flag. */
163+
type WireNamedProviderConfig = Omit<NamedProviderConfig, "getBearerToken"> & {
164+
hasBearerTokenProvider?: boolean;
165+
};
166+
167+
/**
168+
* Strips the non-serializable {@link GetBearerToken} callbacks from the singular
169+
* and named provider configs before they cross the RPC boundary, replacing each
170+
* with a `hasBearerTokenProvider: true` wire flag. Any configured
171+
* {@link ProviderConfig.bearerTokenScope} is forwarded verbatim (the bearer-token
172+
* surface is provider-agnostic, so the SDK never substitutes a default scope).
173+
* Returns wire-safe provider configs alongside a map of provider name → callback
174+
* for session-side registration.
175+
*/
176+
function extractBearerTokenProviders(
177+
provider: ProviderConfig | undefined,
178+
providers: NamedProviderConfig[] | undefined
179+
): {
180+
wireProvider: WireProviderConfig | undefined;
181+
wireProviders: WireNamedProviderConfig[] | undefined;
182+
callbacks: Map<string, GetBearerToken>;
183+
} {
184+
const callbacks = new Map<string, GetBearerToken>();
185+
186+
let wireProvider: WireProviderConfig | undefined = provider;
187+
if (provider?.getBearerToken) {
188+
const { getBearerToken, ...rest } = provider;
189+
callbacks.set(DEFAULT_PROVIDER_NAME, getBearerToken);
190+
wireProvider = {
191+
...rest,
192+
hasBearerTokenProvider: true,
193+
};
194+
}
195+
196+
let wireProviders: WireNamedProviderConfig[] | undefined = providers;
197+
if (providers?.some((p) => p.getBearerToken)) {
198+
wireProviders = providers.map((p) => {
199+
if (!p.getBearerToken) return p;
200+
const { getBearerToken, ...rest } = p;
201+
callbacks.set(p.name, getBearerToken);
202+
return {
203+
...rest,
204+
hasBearerTokenProvider: true,
205+
};
206+
});
207+
}
208+
209+
return { wireProvider, wireProviders, callbacks };
210+
}
211+
153212
/**
154213
* Convert MCP server configs from public API format (workingDirectory) to
155214
* wire format (cwd) expected by the runtime.
@@ -1161,6 +1220,15 @@ export class CopilotClient {
11611220
const useServerGeneratedId = config.cloud != null && callerSessionId == null;
11621221
const localSessionId = useServerGeneratedId ? undefined : (callerSessionId ?? randomUUID());
11631222

1223+
// Strip non-serializable getBearerToken callbacks from provider configs,
1224+
// replacing them with a wire flag; keep the callbacks for session-side
1225+
// registration so the runtime can call back to acquire tokens.
1226+
const {
1227+
wireProvider: bearerWireProvider,
1228+
wireProviders: bearerWireProviders,
1229+
callbacks: bearerTokenCallbacks,
1230+
} = extractBearerTokenProviders(config.provider, config.providers);
1231+
11641232
// Extract transform callbacks from system message config before serialization.
11651233
const { wirePayload: wireSystemMessage, transformCallbacks } = extractTransformCallbacks(
11661234
config.systemMessage
@@ -1178,6 +1246,9 @@ export class CopilotClient {
11781246
s.registerTools(config.tools);
11791247
s.registerCanvases(config.canvases);
11801248
s.registerCommands(config.commands);
1249+
if (bearerTokenCallbacks.size > 0) {
1250+
s.registerBearerTokenProviders(bearerTokenCallbacks);
1251+
}
11811252
s.registerPermissionHandler(config.onPermissionRequest);
11821253
if (config.onUserInputRequest) {
11831254
s.registerUserInputHandler(config.onUserInputRequest);
@@ -1249,8 +1320,8 @@ export class CopilotClient {
12491320
availableTools: toolFilterOptions.availableTools,
12501321
excludedTools: toolFilterOptions.excludedTools,
12511322
toolFilterPrecedence: toolFilterOptions.toolFilterPrecedence,
1252-
provider: config.provider,
1253-
providers: config.providers,
1323+
provider: bearerWireProvider,
1324+
providers: bearerWireProviders,
12541325
models: config.models,
12551326
enableSessionTelemetry: config.enableSessionTelemetry,
12561327
modelCapabilities: config.modelCapabilities,
@@ -1369,6 +1440,14 @@ export class CopilotClient {
13691440
session.registerTools(config.tools);
13701441
session.registerCanvases(config.canvases);
13711442
session.registerCommands(config.commands);
1443+
const {
1444+
wireProvider: bearerWireProvider,
1445+
wireProviders: bearerWireProviders,
1446+
callbacks: bearerTokenCallbacks,
1447+
} = extractBearerTokenProviders(config.provider, config.providers);
1448+
if (bearerTokenCallbacks.size > 0) {
1449+
session.registerBearerTokenProviders(bearerTokenCallbacks);
1450+
}
13721451
session.registerPermissionHandler(config.onPermissionRequest);
13731452
if (config.onUserInputRequest) {
13741453
session.registerUserInputHandler(config.onUserInputRequest);
@@ -1435,8 +1514,8 @@ export class CopilotClient {
14351514
name: cmd.name,
14361515
description: cmd.description,
14371516
})),
1438-
provider: config.provider,
1439-
providers: config.providers,
1517+
provider: bearerWireProvider,
1518+
providers: bearerWireProviders,
14401519
models: config.models,
14411520
modelCapabilities: config.modelCapabilities,
14421521
largeOutput: toWireLargeOutput(config.largeOutput),

nodejs/src/generated/rpc.ts

Lines changed: 73 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

nodejs/src/session.ts

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import type {
2626
ExitPlanModeHandler,
2727
ExitPlanModeRequest,
2828
ExitPlanModeResult,
29+
GetBearerToken,
2930
UiInputOptions,
3031
MessageOptions,
3132
PermissionHandler,
@@ -122,6 +123,7 @@ export class CopilotSession {
122123
new Map();
123124
private toolHandlers: Map<string, ToolHandler> = new Map();
124125
private canvases: Map<string, Canvas> = new Map();
126+
private bearerTokenProviders: Map<string, GetBearerToken> = new Map();
125127
private commandHandlers: Map<string, CommandHandler> = new Map();
126128
private permissionHandler?: PermissionHandler;
127129
private userInputHandler?: UserInputHandler;
@@ -759,6 +761,52 @@ export class CopilotSession {
759761
};
760762
}
761763

764+
/**
765+
* Registers per-provider {@link GetBearerToken} callbacks for BYOK providers
766+
* configured with managed-identity / on-demand bearer-token auth.
767+
*
768+
* The runtime never receives the callback itself; the SDK strips it from the
769+
* provider config and instead sends `hasBearerTokenProvider: true`. When the
770+
* runtime needs a token it issues a session-scoped `providerToken.acquire`
771+
* request, which this handler routes to the matching per-provider callback.
772+
*
773+
* @param providers - Map of provider name → callback, or undefined/empty to clear.
774+
* @internal This method is called internally when creating/resuming a session.
775+
*/
776+
registerBearerTokenProviders(providers?: Map<string, GetBearerToken>): void {
777+
this.bearerTokenProviders.clear();
778+
if (!providers || providers.size === 0) {
779+
delete this.clientSessionApis.providerToken;
780+
return;
781+
}
782+
for (const [name, callback] of providers) {
783+
this.bearerTokenProviders.set(name, callback);
784+
}
785+
786+
const self = this;
787+
this.clientSessionApis.providerToken = {
788+
async acquire(params) {
789+
const callback = self.bearerTokenProviders.get(params.providerName);
790+
if (!callback) {
791+
throw new Error(
792+
`No bearer-token provider registered for provider "${params.providerName}"`
793+
);
794+
}
795+
const result = await callback({
796+
providerName: params.providerName,
797+
scope: params.scope,
798+
});
799+
if (typeof result === "string") {
800+
return { token: result };
801+
}
802+
return {
803+
token: result.token,
804+
expiresOnTimestamp: result.expiresOnTimestamp,
805+
};
806+
},
807+
};
808+
}
809+
762810
/**
763811
* Registers command handlers for this session.
764812
*

0 commit comments

Comments
 (0)