-
Notifications
You must be signed in to change notification settings - Fork 155
fix: add runtime reranking model support #509
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| --- | ||
| "@openrouter/ai-sdk-provider": patch | ||
| --- | ||
|
|
||
| Add runtime support for `rerankingModel()` on the OpenRouter provider. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| import { describe, expect, it } from 'vitest'; | ||
| import { createOpenRouter } from '../provider'; | ||
| import { OpenRouterRerankingModel } from './index'; | ||
|
|
||
| describe('OpenRouterRerankingModel', () => { | ||
| describe('provider methods', () => { | ||
| it('should expose rerankingModel method', () => { | ||
| const provider = createOpenRouter({ apiKey: 'test-api-key' }); | ||
| expect(provider.rerankingModel).toBeDefined(); | ||
| expect(typeof provider.rerankingModel).toBe('function'); | ||
| }); | ||
|
|
||
| it('should create a reranking model instance', () => { | ||
| const provider = createOpenRouter({ apiKey: 'test-api-key' }); | ||
| const model = provider.rerankingModel('cohere/rerank-v3.5'); | ||
| expect(model).toBeInstanceOf(OpenRouterRerankingModel); | ||
| expect(model.modelId).toBe('cohere/rerank-v3.5'); | ||
| expect(model.provider).toBe('openrouter'); | ||
| expect(model.specificationVersion).toBe('v3'); | ||
| }); | ||
| }); | ||
|
|
||
| describe('doRerank', () => { | ||
| it('should rerank text documents', async () => { | ||
| let capturedUrl: string | undefined; | ||
| let capturedRequest: Record<string, unknown> | undefined; | ||
|
|
||
| const mockFetch = async ( | ||
| url: URL | RequestInfo, | ||
| init?: RequestInit, | ||
| ): Promise<Response> => { | ||
| capturedUrl = url.toString(); | ||
| capturedRequest = JSON.parse(init?.body as string); | ||
| return new Response( | ||
| JSON.stringify({ | ||
| id: 'rerank-test-id', | ||
| model: 'cohere/rerank-v3.5', | ||
| results: [ | ||
| { index: 1, relevance_score: 0.98 }, | ||
| { index: 0, relevance_score: 0.12 }, | ||
| ], | ||
| usage: { | ||
| prompt_tokens: 12, | ||
| total_tokens: 12, | ||
| cost: 0.00002, | ||
| }, | ||
| }), | ||
| { | ||
| status: 200, | ||
| headers: { | ||
| 'content-type': 'application/json', | ||
| }, | ||
| }, | ||
| ); | ||
| }; | ||
|
|
||
| const provider = createOpenRouter({ | ||
| apiKey: 'test-api-key', | ||
| fetch: mockFetch, | ||
| }); | ||
| const model = provider.rerankingModel('cohere/rerank-v3.5'); | ||
|
|
||
| const result = await model.doRerank({ | ||
| query: 'capital of France', | ||
| documents: { | ||
| type: 'text', | ||
| values: ['Berlin is in Germany', 'Paris is in France'], | ||
| }, | ||
| topN: 2, | ||
| }); | ||
|
|
||
| expect(capturedUrl).toBe('https://openrouter.ai/api/v1/rerank'); | ||
| expect(capturedRequest).toMatchObject({ | ||
| model: 'cohere/rerank-v3.5', | ||
| query: 'capital of France', | ||
| documents: ['Berlin is in Germany', 'Paris is in France'], | ||
| top_n: 2, | ||
| }); | ||
| expect(result.ranking).toEqual([ | ||
| { index: 1, relevanceScore: 0.98 }, | ||
| { index: 0, relevanceScore: 0.12 }, | ||
| ]); | ||
| expect(result.response?.id).toBe('rerank-test-id'); | ||
| expect(result.response?.modelId).toBe('cohere/rerank-v3.5'); | ||
| expect( | ||
| (result.providerMetadata?.openrouter as { usage?: { cost?: number } }) | ||
| ?.usage?.cost, | ||
| ).toBe(0.00002); | ||
| }); | ||
| }); | ||
| }); |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,110 @@ | ||||||
| import type { | ||||||
| JSONObject, | ||||||
| RerankingModelV3, | ||||||
| SharedV3Headers, | ||||||
| SharedV3ProviderMetadata, | ||||||
| } from '@ai-sdk/provider'; | ||||||
| import type { | ||||||
| OpenRouterRerankingModelId, | ||||||
| OpenRouterRerankingSettings, | ||||||
| } from '../types/openrouter-reranking-settings'; | ||||||
|
|
||||||
| import { | ||||||
| combineHeaders, | ||||||
| createJsonResponseHandler, | ||||||
| postJsonToApi, | ||||||
| } from '@ai-sdk/provider-utils'; | ||||||
| import { openrouterFailedResponseHandler } from '../schemas/error-response'; | ||||||
| import { OpenRouterProviderMetadataSchema } from '../schemas/provider-metadata'; | ||||||
| import { OpenRouterRerankingResponseSchema } from './schemas'; | ||||||
|
|
||||||
| type OpenRouterRerankingConfig = { | ||||||
| provider: string; | ||||||
| headers: () => Record<string, string | undefined>; | ||||||
| url: (options: { modelId: string; path: string }) => string; | ||||||
| fetch?: typeof fetch; | ||||||
| extraBody?: Record<string, unknown>; | ||||||
| }; | ||||||
|
|
||||||
| export class OpenRouterRerankingModel implements RerankingModelV3 { | ||||||
| readonly specificationVersion = 'v3' as const; | ||||||
| readonly provider = 'openrouter'; | ||||||
| readonly modelId: OpenRouterRerankingModelId; | ||||||
| readonly settings: OpenRouterRerankingSettings; | ||||||
|
|
||||||
| private readonly config: OpenRouterRerankingConfig; | ||||||
|
|
||||||
| constructor( | ||||||
| modelId: OpenRouterRerankingModelId, | ||||||
| settings: OpenRouterRerankingSettings, | ||||||
| config: OpenRouterRerankingConfig, | ||||||
| ) { | ||||||
| this.modelId = modelId; | ||||||
| this.settings = settings; | ||||||
| this.config = config; | ||||||
| } | ||||||
|
|
||||||
| async doRerank({ | ||||||
| documents, | ||||||
| query, | ||||||
| topN, | ||||||
| abortSignal, | ||||||
| headers, | ||||||
| }: Parameters<RerankingModelV3['doRerank']>[0]): Promise< | ||||||
| Awaited<ReturnType<RerankingModelV3['doRerank']>> | ||||||
| > { | ||||||
| const documentValues: string[] | JSONObject[] = documents.values; | ||||||
| const args = { | ||||||
| model: this.modelId, | ||||||
| query, | ||||||
| documents: documentValues, | ||||||
| top_n: topN, | ||||||
| user: this.settings.user, | ||||||
| provider: this.settings.provider, | ||||||
| ...this.config.extraBody, | ||||||
| ...this.settings.extraBody, | ||||||
| }; | ||||||
|
|
||||||
| const { value: responseValue, responseHeaders } = await postJsonToApi({ | ||||||
| url: this.config.url({ | ||||||
| path: '/rerank', | ||||||
| modelId: this.modelId, | ||||||
| }), | ||||||
| headers: combineHeaders(this.config.headers(), headers), | ||||||
| body: args, | ||||||
| failedResponseHandler: openrouterFailedResponseHandler, | ||||||
| successfulResponseHandler: createJsonResponseHandler( | ||||||
| OpenRouterRerankingResponseSchema, | ||||||
| ), | ||||||
| abortSignal, | ||||||
| fetch: this.config.fetch, | ||||||
| }); | ||||||
|
|
||||||
| return { | ||||||
| ranking: responseValue.results.map((result) => ({ | ||||||
| index: result.index, | ||||||
| relevanceScore: result.relevance_score, | ||||||
| })), | ||||||
| providerMetadata: { | ||||||
| openrouter: OpenRouterProviderMetadataSchema.parse({ | ||||||
| provider: '', | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [suggestion]
Suggested change
DetailsWhy: Every other model in this provider surfaces the upstream provider slug in The rerank response schema ( Reviewed at |
||||||
| usage: { | ||||||
| promptTokens: responseValue.usage?.prompt_tokens ?? 0, | ||||||
| completionTokens: 0, | ||||||
| totalTokens: responseValue.usage?.total_tokens ?? 0, | ||||||
| ...(responseValue.usage?.cost != null | ||||||
| ? { cost: responseValue.usage.cost } | ||||||
| : {}), | ||||||
| }, | ||||||
| }), | ||||||
| } satisfies SharedV3ProviderMetadata, | ||||||
| response: { | ||||||
| id: responseValue.id, | ||||||
| modelId: responseValue.model, | ||||||
| headers: responseHeaders as SharedV3Headers, | ||||||
| body: responseValue, | ||||||
| }, | ||||||
| warnings: [], | ||||||
| }; | ||||||
| } | ||||||
| } | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| import { z } from 'zod/v4'; | ||
|
|
||
| export const OpenRouterRerankingResponseSchema = z | ||
| .object({ | ||
| id: z.string().optional(), | ||
| model: z.string().optional(), | ||
| results: z.array( | ||
| z | ||
| .object({ | ||
| index: z.number(), | ||
| relevance_score: z.number(), | ||
| }) | ||
| .passthrough(), | ||
| ), | ||
| usage: z | ||
| .object({ | ||
| prompt_tokens: z.number().optional(), | ||
| total_tokens: z.number().optional(), | ||
| cost: z.number().optional(), | ||
| }) | ||
| .passthrough() | ||
| .optional(), | ||
| }) | ||
| .passthrough(); | ||
|
|
||
| export type OpenRouterRerankingResponse = z.infer< | ||
| typeof OpenRouterRerankingResponseSchema | ||
| >; |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| import type { OpenRouterSharedSettings } from '..'; | ||
|
|
||
| // https://openrouter.ai/models?fmt=cards&supported_parameters=rerank | ||
| export type OpenRouterRerankingModelId = string; | ||
|
|
||
| export type OpenRouterRerankingSettings = { | ||
| /** | ||
| * Provider routing preferences to control request routing behavior. | ||
| */ | ||
| provider?: { | ||
| /** | ||
| * List of provider slugs to try in order. | ||
| */ | ||
| order?: string[]; | ||
| /** | ||
| * Whether to allow backup providers when primary is unavailable. | ||
| */ | ||
| allow_fallbacks?: boolean; | ||
| /** | ||
| * Only use providers that support all parameters in your request. | ||
| */ | ||
| require_parameters?: boolean; | ||
| /** | ||
| * Control whether to use providers that may store data. | ||
| */ | ||
| data_collection?: 'allow' | 'deny'; | ||
| /** | ||
| * List of provider slugs to allow for this request. | ||
| */ | ||
| only?: string[]; | ||
| /** | ||
| * List of provider slugs to skip for this request. | ||
| */ | ||
| ignore?: string[]; | ||
| /** | ||
| * Sort providers by price, throughput, or latency. | ||
| */ | ||
| sort?: 'price' | 'throughput' | 'latency'; | ||
| }; | ||
| } & Pick<OpenRouterSharedSettings, 'extraBody' | 'user'>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[suggestion] Changeset is marked
patch— a new public API method is conventionally aminorbump.Details
Why: This PR adds a brand-new public surface —
provider.rerankingModel()and the exportedOpenRouterRerankingModel/OpenRouterRerankingSettingstypes. Per semver, new backwards-compatible API additions areminor, notpatch. The most recent comparable change in this repo — #479, which addedprovider.videoModel()— shipped as a Minor changeset (seeCHANGELOG.md). Marking thispatchmeans the new method lands in a patch release, which under-signals the feature to consumers pinning on~ranges.Non-blocking: if the maintainers intentionally batch feature additions into patch releases for this package, disregard.
Reviewed at
bd948de