diff --git a/.changeset/whole-birds-build.md b/.changeset/whole-birds-build.md new file mode 100644 index 00000000..1b08f99d --- /dev/null +++ b/.changeset/whole-birds-build.md @@ -0,0 +1,5 @@ +--- +"@openrouter/ai-sdk-provider": minor +--- + +Add reranking support via `rerankingModel()` and `reranking()` provider methods, implementing the `RerankingModelV3` interface for use with the `rerank()` function from the Vercel AI SDK. diff --git a/README.md b/README.md index d25ea649..5db3363c 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,57 @@ OpenRouter supports various embedding models including: - `openai/text-embedding-ada-002` - And more available on [OpenRouter](https://openrouter.ai/models?output_modalities=embeddings) +## Reranking + +OpenRouter supports reranking models for relevance scoring, search result ordering, and RAG pipeline optimization. + +### Basic Usage + +```ts +import { rerank } from 'ai'; +import { openrouter } from '@openrouter/ai-sdk-provider'; + +const { rerankedDocuments } = await rerank({ + model: openrouter.rerankingModel('cohere/rerank-v3.5'), + query: 'What is the capital of France?', + documents: [ + 'Berlin is the capital of Germany.', + 'Paris is the capital of France.', + 'Madrid is the capital of Spain.', + ], +}); + +console.log(rerankedDocuments); // ['Paris is the capital of France.', ...] +``` + +### Limiting Results with topN + +```ts +import { rerank } from 'ai'; +import { openrouter } from '@openrouter/ai-sdk-provider'; + +const { rerankedDocuments } = await rerank({ + model: openrouter.rerankingModel('cohere/rerank-v3.5'), + query: 'Famous landmarks in Paris', + documents: [ + 'The Eiffel Tower is in Paris.', + 'The Colosseum is in Rome.', + 'The Sagrada Familia is in Barcelona.', + 'Big Ben is in London.', + ], + topN: 2, +}); + +console.log(rerankedDocuments); // Top 2 most relevant documents +``` + +### Supported Reranking Models + +OpenRouter supports various reranking models including: + +- `cohere/rerank-v3.5` +- And more available on [OpenRouter](https://openrouter.ai/models?output_modalities=rerank) + ## Passing Extra Body to OpenRouter There are 3 ways to pass extra body to OpenRouter: diff --git a/e2e/rerank/index.test.ts b/e2e/rerank/index.test.ts new file mode 100644 index 00000000..50e5f14e --- /dev/null +++ b/e2e/rerank/index.test.ts @@ -0,0 +1,66 @@ +import { rerank } from 'ai'; +import { describe, expect, it, vi } from 'vitest'; +import { createOpenRouter } from '../../src/index'; + +vi.setConfig({ testTimeout: 30_000 }); + +describe('OpenRouter Reranking E2E', () => { + const openrouter = createOpenRouter({ + apiKey: process.env.OPENROUTER_API_KEY, + baseUrl: `${process.env.OPENROUTER_API_BASE}/api/v1`, + }); + + it('reranks documents and returns the most relevant first', async () => { + const documents = [ + 'Berlin is the capital of Germany.', + 'Paris is the capital of France.', + 'Madrid is the capital of Spain.', + 'Rome is the capital of Italy.', + ]; + const query = 'What is the capital of France?'; + + const result = await rerank({ + model: openrouter.rerankingModel('cohere/rerank-v3.5'), + documents, + query, + }); + + expect(result.rerankedDocuments).toHaveLength(documents.length); + expect(result.rerankedDocuments[0]).toBe('Paris is the capital of France.'); + + expect(result.ranking).toHaveLength(documents.length); + expect(result.ranking![0]!.score).toBeGreaterThan(0.5); + }); + + it('respects topN and returns only the top N results', async () => { + const documents = [ + 'The Eiffel Tower is in Paris.', + 'The Colosseum is in Rome.', + 'The Sagrada Familia is in Barcelona.', + 'Big Ben is in London.', + 'The Acropolis is in Athens.', + ]; + + const result = await rerank({ + model: openrouter.rerankingModel('cohere/rerank-v3.5'), + documents, + query: 'Famous landmark in Paris', + topN: 2, + }); + + expect(result.rerankedDocuments).toHaveLength(2); + expect(result.rerankedDocuments[0]).toBe('The Eiffel Tower is in Paris.'); + }); + + it('reranking() alias works identically to rerankingModel()', async () => { + const result = await rerank({ + model: openrouter.reranking('cohere/rerank-v3.5'), + documents: ['cat', 'dog', 'fish'], + query: 'aquatic animal', + topN: 1, + }); + + expect(result.rerankedDocuments).toHaveLength(1); + expect(result.rerankedDocuments[0]).toBe('fish'); + }); +}); diff --git a/src/facade.ts b/src/facade.ts index 6e9b83b9..93df2ec2 100644 --- a/src/facade.ts +++ b/src/facade.ts @@ -11,11 +11,16 @@ import type { OpenRouterEmbeddingModelId, OpenRouterEmbeddingSettings, } from './types/openrouter-embedding-settings'; +import type { + OpenRouterRerankModelId, + OpenRouterRerankSettings, +} from './types/openrouter-rerank-settings'; import { loadApiKey, withoutTrailingSlash } from '@ai-sdk/provider-utils'; import { OpenRouterChatLanguageModel } from './chat'; import { OpenRouterCompletionLanguageModel } from './completion'; import { OpenRouterEmbeddingModel } from './embedding'; +import { OpenRouterRerankingModel } from './rerank'; /** @deprecated Use `createOpenRouter` instead. @@ -128,4 +133,24 @@ Custom headers to include in the requests. ) { return this.textEmbeddingModel(modelId, settings); } + + rerankingModel( + modelId: OpenRouterRerankModelId, + settings: OpenRouterRerankSettings = {}, + ) { + return new OpenRouterRerankingModel(modelId, settings, { + ...this.baseConfig, + url: ({ path }: { path: string }) => `${this.baseURL}${path}`, + }); + } + + /** + * @deprecated Use rerankingModel instead + */ + reranking( + modelId: OpenRouterRerankModelId, + settings: OpenRouterRerankSettings = {}, + ) { + return this.rerankingModel(modelId, settings); + } } diff --git a/src/internal/index.ts b/src/internal/index.ts index 77567c3f..41a212b8 100644 --- a/src/internal/index.ts +++ b/src/internal/index.ts @@ -2,9 +2,11 @@ export * from '../chat'; export * from '../completion'; export * from '../embedding'; export * from '../image'; +export * from '../rerank'; export * from '../types'; export * from '../types/openrouter-chat-settings'; export * from '../types/openrouter-completion-settings'; export * from '../types/openrouter-image-settings'; +export * from '../types/openrouter-rerank-settings'; export * from '../types/openrouter-video-settings'; export * from '../video'; diff --git a/src/provider.ts b/src/provider.ts index b6386000..0f83686b 100644 --- a/src/provider.ts +++ b/src/provider.ts @@ -17,6 +17,10 @@ import type { OpenRouterImageModelId, OpenRouterImageSettings, } from './types/openrouter-image-settings'; +import type { + OpenRouterRerankModelId, + OpenRouterRerankSettings, +} from './types/openrouter-rerank-settings'; import type { OpenRouterVideoModelId, OpenRouterVideoSettings, @@ -27,6 +31,7 @@ import { OpenRouterChatLanguageModel } from './chat'; import { OpenRouterCompletionLanguageModel } from './completion'; import { OpenRouterEmbeddingModel } from './embedding'; import { OpenRouterImageModel } from './image'; +import { OpenRouterRerankingModel } from './rerank'; import { webSearch } from './tool/web-search'; import { withUserAgentSuffix } from './utils/with-user-agent-suffix'; import { VERSION } from './version'; @@ -115,6 +120,22 @@ Creates an OpenRouter video model for video generation. settings?: OpenRouterVideoSettings, ): OpenRouterVideoModel; + /** +Creates an OpenRouter reranking model. + */ + rerankingModel( + modelId: OpenRouterRerankModelId, + settings?: OpenRouterRerankSettings, + ): OpenRouterRerankingModel; + + /** +Creates an OpenRouter reranking model. Alias for rerankingModel. + */ + reranking( + modelId: OpenRouterRerankModelId, + settings?: OpenRouterRerankSettings, + ): OpenRouterRerankingModel; + /** * Provider-defined tools for OpenRouter server tools. */ @@ -280,6 +301,17 @@ export function createOpenRouter( extraBody: options.extraBody, }); + const createRerankingModel = ( + modelId: OpenRouterRerankModelId, + settings: OpenRouterRerankSettings = {}, + ) => + new OpenRouterRerankingModel(modelId, settings, { + url: ({ path }) => `${baseURL}${path}`, + headers: getHeaders, + fetch: options.fetch, + extraBody: options.extraBody, + }); + const createLanguageModel = ( modelId: OpenRouterChatModelId | OpenRouterCompletionModelId, settings?: OpenRouterChatSettings | OpenRouterCompletionSettings, @@ -312,6 +344,8 @@ export function createOpenRouter( provider.embedding = createEmbeddingModel; // deprecated alias for v4 compatibility provider.imageModel = createImageModel; provider.videoModel = createVideoModel; + provider.rerankingModel = createRerankingModel; + provider.reranking = createRerankingModel; provider.tools = { webSearch: webSearch, }; diff --git a/src/rerank/index.test.ts b/src/rerank/index.test.ts new file mode 100644 index 00000000..6917f3f2 --- /dev/null +++ b/src/rerank/index.test.ts @@ -0,0 +1,308 @@ +import { describe, expect, it } from 'vitest'; +import { createOpenRouter } from '../provider'; +import { OpenRouterRerankingModel } from './index'; + +const MOCK_RERANK_RESPONSE = { + id: 'rerank-abc123', + model: 'cohere/rerank-v3.5', + provider: 'Cohere', + results: [ + { + document: { text: 'Paris is the capital of France.' }, + index: 1, + relevance_score: 0.98, + }, + { + document: { text: 'Berlin is the capital of Germany.' }, + index: 0, + relevance_score: 0.12, + }, + ], + usage: { + total_tokens: 42, + search_units: 1, + cost: 0.0001, + }, +}; + +function mockFetch(response: unknown, status = 200) { + return async () => + new Response(JSON.stringify(response), { + status, + headers: { 'content-type': 'application/json' }, + }); +} + +type RequestCapture = { url: string; body: Record }; + +function mockFetchWithCapture(response: unknown): { + fetch: typeof fetch; + captured: () => RequestCapture; +} { + let capture: RequestCapture; + const fetchFn = async (url: string | URL | Request, init?: RequestInit) => { + capture = { + url: url.toString(), + body: JSON.parse(init?.body as string), + }; + return new Response(JSON.stringify(response), { + status: 200, + headers: { 'content-type': 'application/json' }, + }); + }; + return { + fetch: fetchFn as typeof fetch, + captured: () => capture, + }; +} + +describe('OpenRouterRerankingModel', () => { + describe('provider methods', () => { + it('rerankingModel() returns an OpenRouterRerankingModel', () => { + const openrouter = createOpenRouter({ + apiKey: 'test-key', + fetch: mockFetch(MOCK_RERANK_RESPONSE), + }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5'); + expect(model).toBeInstanceOf(OpenRouterRerankingModel); + expect(model.modelId).toBe('cohere/rerank-v3.5'); + expect(model.provider).toBe('openrouter.reranking'); + }); + + it('reranking() alias also returns an OpenRouterRerankingModel', () => { + const openrouter = createOpenRouter({ + apiKey: 'test-key', + fetch: mockFetch(MOCK_RERANK_RESPONSE), + }); + const model = openrouter.reranking('cohere/rerank-v3.5'); + expect(model).toBeInstanceOf(OpenRouterRerankingModel); + }); + }); + + describe('doRerank', () => { + it('sends correct request with text documents', async () => { + const { fetch, captured } = mockFetchWithCapture(MOCK_RERANK_RESPONSE); + const openrouter = createOpenRouter({ apiKey: 'test-key', fetch }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5'); + + await model.doRerank({ + documents: { + type: 'text', + values: [ + 'Berlin is the capital of Germany.', + 'Paris is the capital of France.', + ], + }, + query: 'What is the capital of France?', + topN: 2, + }); + + expect(captured().url).toContain('/rerank'); + expect(captured().body).toMatchObject({ + model: 'cohere/rerank-v3.5', + query: 'What is the capital of France?', + documents: [ + 'Berlin is the capital of Germany.', + 'Paris is the capital of France.', + ], + top_n: 2, + }); + }); + + it('maps results to ranking with relevanceScore', async () => { + const openrouter = createOpenRouter({ + apiKey: 'test-key', + fetch: mockFetch(MOCK_RERANK_RESPONSE), + }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5'); + + const result = await model.doRerank({ + documents: { + type: 'text', + values: [ + 'Berlin is the capital of Germany.', + 'Paris is the capital of France.', + ], + }, + query: 'What is the capital of France?', + }); + + expect(result.ranking).toHaveLength(2); + expect(result.ranking[0]).toEqual({ index: 1, relevanceScore: 0.98 }); + expect(result.ranking[1]).toEqual({ index: 0, relevanceScore: 0.12 }); + }); + + it('omits top_n when topN is not provided', async () => { + const { fetch, captured } = mockFetchWithCapture(MOCK_RERANK_RESPONSE); + const openrouter = createOpenRouter({ apiKey: 'test-key', fetch }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5'); + + await model.doRerank({ + documents: { type: 'text', values: ['doc1', 'doc2'] }, + query: 'test query', + }); + + expect(captured().body).not.toHaveProperty('top_n'); + }); + + it('converts object documents to JSON strings and emits a warning', async () => { + const { fetch, captured } = mockFetchWithCapture({ + ...MOCK_RERANK_RESPONSE, + results: [ + { + document: { text: '{"title":"France"}' }, + index: 0, + relevance_score: 0.9, + }, + ], + }); + const openrouter = createOpenRouter({ apiKey: 'test-key', fetch }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5'); + + const result = await model.doRerank({ + documents: { + type: 'object', + values: [{ title: 'France' }, { title: 'Germany' }], + }, + query: 'France', + }); + + expect(captured().body.documents).toEqual([ + '{"title":"France"}', + '{"title":"Germany"}', + ]); + expect(result.warnings).toHaveLength(1); + expect(result.warnings![0]!.type).toBe('compatibility'); + }); + + it('surfaces provider metadata with usage and provider name', async () => { + const openrouter = createOpenRouter({ + apiKey: 'test-key', + fetch: mockFetch(MOCK_RERANK_RESPONSE), + }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5'); + + const result = await model.doRerank({ + documents: { type: 'text', values: ['doc1'] }, + query: 'test', + }); + + expect(result.providerMetadata?.openrouter).toMatchObject({ + provider: 'Cohere', + usage: { total_tokens: 42, search_units: 1, cost: 0.0001 }, + }); + }); + + it('surfaces response id and modelId', async () => { + const openrouter = createOpenRouter({ + apiKey: 'test-key', + fetch: mockFetch(MOCK_RERANK_RESPONSE), + }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5'); + + const result = await model.doRerank({ + documents: { type: 'text', values: ['doc1'] }, + query: 'test', + }); + + expect(result.response?.id).toBe('rerank-abc123'); + expect(result.response?.modelId).toBe('cohere/rerank-v3.5'); + }); + + it('passes extraBody through to the request', async () => { + const { fetch, captured } = mockFetchWithCapture(MOCK_RERANK_RESPONSE); + const openrouter = createOpenRouter({ + apiKey: 'test-key', + fetch, + extraBody: { custom_field: 'custom_value' }, + }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5'); + + await model.doRerank({ + documents: { type: 'text', values: ['doc1'] }, + query: 'test', + }); + + expect(captured().body).toHaveProperty('custom_field', 'custom_value'); + }); + + it('includes Authorization header with API key', async () => { + let capturedHeaders: Record = {}; + const fetchFn = async ( + _url: string | URL | Request, + init?: RequestInit, + ) => { + capturedHeaders = Object.fromEntries( + new Headers(init?.headers).entries(), + ); + return new Response(JSON.stringify(MOCK_RERANK_RESPONSE), { + status: 200, + headers: { 'content-type': 'application/json' }, + }); + }; + + const openrouter = createOpenRouter({ + apiKey: 'test-api-key-123', + fetch: fetchFn as typeof fetch, + }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5'); + + await model.doRerank({ + documents: { type: 'text', values: ['doc1'] }, + query: 'test', + }); + + expect(capturedHeaders['authorization']).toBe('Bearer test-api-key-123'); + }); + + it('merges providerOptions.openrouter into the request body', async () => { + const { fetch, captured } = mockFetchWithCapture(MOCK_RERANK_RESPONSE); + const openrouter = createOpenRouter({ apiKey: 'test-key', fetch }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5'); + + await model.doRerank({ + documents: { type: 'text', values: ['doc1'] }, + query: 'test', + providerOptions: { + openrouter: { provider: { order: ['Cohere'] } }, + }, + }); + + expect(captured().body).toHaveProperty('provider'); + expect(captured().body['provider']).toMatchObject({ order: ['Cohere'] }); + }); + + it('merges settings.extraBody into the request body', async () => { + const { fetch, captured } = mockFetchWithCapture(MOCK_RERANK_RESPONSE); + const openrouter = createOpenRouter({ apiKey: 'test-key', fetch }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5', { + extraBody: { settings_field: 'settings_value' }, + }); + + await model.doRerank({ + documents: { type: 'text', values: ['doc1'] }, + query: 'test', + }); + + expect(captured().body).toHaveProperty( + 'settings_field', + 'settings_value', + ); + }); + + it('emits no warnings for text documents', async () => { + const openrouter = createOpenRouter({ + apiKey: 'test-key', + fetch: mockFetch(MOCK_RERANK_RESPONSE), + }); + const model = openrouter.rerankingModel('cohere/rerank-v3.5'); + + const result = await model.doRerank({ + documents: { type: 'text', values: ['doc1', 'doc2'] }, + query: 'test', + }); + + expect(result.warnings).toHaveLength(0); + }); + }); +}); diff --git a/src/rerank/index.ts b/src/rerank/index.ts new file mode 100644 index 00000000..3d4b354e --- /dev/null +++ b/src/rerank/index.ts @@ -0,0 +1,113 @@ +import type { + RerankingModelV3, + RerankingModelV3CallOptions, + SharedV3Warning, +} from '@ai-sdk/provider'; +import type { + OpenRouterRerankModelId, + OpenRouterRerankSettings, +} from '../types/openrouter-rerank-settings'; + +import { + combineHeaders, + createJsonResponseHandler, + postJsonToApi, +} from '@ai-sdk/provider-utils'; +import { openrouterFailedResponseHandler } from '../schemas/error-response'; +import { OpenRouterRerankResponseSchema } from './schemas'; + +type OpenRouterRerankingConfig = { + headers: () => Record; + url: (options: { modelId: string; path: string }) => string; + fetch?: typeof fetch; + extraBody?: Record; +}; + +export class OpenRouterRerankingModel implements RerankingModelV3 { + readonly specificationVersion = 'v3' as const; + readonly provider = 'openrouter.reranking'; + readonly modelId: OpenRouterRerankModelId; + + private readonly settings: OpenRouterRerankSettings; + private readonly config: OpenRouterRerankingConfig; + + constructor( + modelId: OpenRouterRerankModelId, + settings: OpenRouterRerankSettings, + config: OpenRouterRerankingConfig, + ) { + this.modelId = modelId; + this.settings = settings; + this.config = config; + } + + async doRerank({ + documents, + query, + topN, + headers, + abortSignal, + providerOptions, + }: RerankingModelV3CallOptions) { + const openrouterOptions = + (providerOptions?.openrouter as Record) || {}; + + const warnings: SharedV3Warning[] = []; + + if (documents.type === 'object') { + warnings.push({ + type: 'compatibility' as const, + feature: 'object documents', + details: + 'Object documents are not natively supported. They are converted to JSON strings.', + }); + } + + const { + responseHeaders, + value: response, + rawValue, + } = await postJsonToApi({ + url: this.config.url({ modelId: this.modelId, path: '/rerank' }), + headers: combineHeaders(this.config.headers(), headers), + body: { + model: this.modelId, + query, + documents: + documents.type === 'text' + ? documents.values + : documents.values.map((v) => JSON.stringify(v)), + ...(topN !== undefined && { top_n: topN }), + ...this.config.extraBody, + ...this.settings.extraBody, + ...openrouterOptions, + }, + failedResponseHandler: openrouterFailedResponseHandler, + successfulResponseHandler: createJsonResponseHandler( + OpenRouterRerankResponseSchema, + ), + abortSignal, + fetch: this.config.fetch, + }); + + return { + ranking: response.results.map((r) => ({ + index: r.index, + relevanceScore: r.relevance_score, + })), + providerMetadata: { + openrouter: { + ...(response.provider && { provider: response.provider }), + ...(response.usage && { usage: response.usage }), + }, + }, + warnings, + response: { + id: response.id, + modelId: response.model, + headers: responseHeaders, + body: rawValue, + }, + }; + } +} diff --git a/src/rerank/schemas.ts b/src/rerank/schemas.ts new file mode 100644 index 00000000..83745a02 --- /dev/null +++ b/src/rerank/schemas.ts @@ -0,0 +1,23 @@ +import { z } from 'zod/v4'; + +export const OpenRouterRerankResponseSchema = z + .object({ + id: z.string().optional(), + model: z.string(), + provider: z.string().optional(), + results: z.array( + z.object({ + document: z.object({ text: z.string() }), + index: z.number(), + relevance_score: z.number(), + }), + ), + usage: z + .object({ + cost: z.number().optional(), + search_units: z.number().optional(), + total_tokens: z.number().optional(), + }) + .optional(), + }) + .passthrough(); diff --git a/src/types/index.ts b/src/types/index.ts index c400b40f..5cd66a79 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -4,6 +4,7 @@ export type { LanguageModelV3, LanguageModelV3Prompt }; export * from './openrouter-embedding-settings'; export * from './openrouter-image-settings'; +export * from './openrouter-rerank-settings'; export * from './openrouter-video-settings'; export type OpenRouterProviderOptions = { diff --git a/src/types/openrouter-rerank-settings.ts b/src/types/openrouter-rerank-settings.ts new file mode 100644 index 00000000..b8c42982 --- /dev/null +++ b/src/types/openrouter-rerank-settings.ts @@ -0,0 +1,24 @@ +import type { OpenRouterSharedSettings } from '.'; + +export type OpenRouterRerankModelId = + | 'cohere/rerank-v3.5' + | 'cohere/rerank-4-fast' + | 'cohere/rerank-4-pro' + | (string & {}); + +/** + * Provider-specific options for OpenRouter reranking models. + * Pass these via `providerOptions.openrouter` in the `rerank()` call. + */ +export type OpenRouterRerankProviderOptions = { + /** + * Provider routing preferences for the reranking request. + * @see https://openrouter.ai/docs/features/provider-routing + */ + provider?: Record; +}; + +/** + * Settings for OpenRouter reranking models, passed at model creation time. + */ +export type OpenRouterRerankSettings = OpenRouterSharedSettings;