@ai-sdk/gateway 4.0.0-beta.39 → 4.0.0-beta.40

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -176,6 +176,34 @@ const { text } = await generateText({
176
176
 
177
177
  AI Gateway language models can also be used in the `streamText` function and support structured data generation with [`Output`](/docs/reference/ai-sdk-core/output) (see [AI SDK Core](/docs/ai-sdk-core)).
178
178
 
179
+ ## Reranking Models
180
+
181
+ You can create reranking models using the `rerankingModel` method (or its shorthand alias, `reranking`) on the provider instance:
182
+
183
+ ```ts
184
+ import { rerank } from 'ai';
185
+ import { gateway } from '@ai-sdk/gateway';
186
+
187
+ const { ranking } = await rerank({
188
+ model: gateway.rerankingModel('cohere/rerank-v3.5'),
189
+ query: 'What is the capital of France?',
190
+ documents: [
191
+ 'Paris is the capital of France.',
192
+ 'Berlin is the capital of Germany.',
193
+ 'Madrid is the capital of Spain.',
194
+ ],
195
+ topN: 2,
196
+ });
197
+
198
+ console.log(ranking);
199
+ // [
200
+ // { originalIndex: 0, score: 0.89, document: 'Paris is the capital of France.' },
201
+ // { originalIndex: 2, score: 0.15, document: 'Madrid is the capital of Spain.' },
202
+ // ]
203
+ ```
204
+
205
+ Reranking models are useful for improving search results in retrieval-augmented generation (RAG) pipelines by re-scoring candidate documents after an initial retrieval step.
206
+
179
207
  ## Available Models
180
208
 
181
209
  The AI Gateway supports models from OpenAI, Anthropic, Google, Meta, xAI, Mistral, DeepSeek, Amazon Bedrock, Cohere, Perplexity, Alibaba, and other providers.
package/package.json CHANGED
@@ -1,7 +1,7 @@
1
1
  {
2
2
  "name": "@ai-sdk/gateway",
3
3
  "private": false,
4
- "version": "4.0.0-beta.39",
4
+ "version": "4.0.0-beta.40",
5
5
  "license": "Apache-2.0",
6
6
  "sideEffects": false,
7
7
  "main": "./dist/index.js",
@@ -27,8 +27,10 @@ import { GatewayLanguageModel } from './gateway-language-model';
27
27
  import { GatewayEmbeddingModel } from './gateway-embedding-model';
28
28
  import { GatewayImageModel } from './gateway-image-model';
29
29
  import { GatewayVideoModel } from './gateway-video-model';
30
+ import { GatewayRerankingModel } from './gateway-reranking-model';
30
31
  import type { GatewayEmbeddingModelId } from './gateway-embedding-model-settings';
31
32
  import type { GatewayImageModelId } from './gateway-image-model-settings';
33
+ import type { GatewayRerankingModelId } from './gateway-reranking-model-settings';
32
34
  import type { GatewayVideoModelId } from './gateway-video-model-settings';
33
35
  import { gatewayTools } from './gateway-tools';
34
36
  import { getVercelOidcToken, getVercelRequestId } from './vercel-environment';
@@ -37,6 +39,7 @@ import type {
37
39
  LanguageModelV4,
38
40
  EmbeddingModelV4,
39
41
  ImageModelV4,
42
+ RerankingModelV4,
40
43
  Experimental_VideoModelV4,
41
44
  ProviderV4,
42
45
  } from '@ai-sdk/provider';
@@ -117,6 +120,16 @@ export interface GatewayProvider extends ProviderV4 {
117
120
  */
118
121
  videoModel(modelId: GatewayVideoModelId): Experimental_VideoModelV4;
119
122
 
123
+ /**
124
+ * Creates a model for reranking documents. Shorthand alias for `rerankingModel`.
125
+ */
126
+ reranking(modelId: GatewayRerankingModelId): RerankingModelV4;
127
+
128
+ /**
129
+ * Creates a model for reranking documents.
130
+ */
131
+ rerankingModel(modelId: GatewayRerankingModelId): RerankingModelV4;
132
+
120
133
  /**
121
134
  * Gateway-specific tools executed server-side.
122
135
  */
@@ -354,6 +367,17 @@ export function createGatewayProvider(
354
367
  o11yHeaders: createO11yHeaders(),
355
368
  });
356
369
  };
370
+ const createRerankingModel = (modelId: GatewayRerankingModelId) => {
371
+ return new GatewayRerankingModel(modelId, {
372
+ provider: 'gateway',
373
+ baseURL,
374
+ headers: getHeaders,
375
+ fetch: options.fetch,
376
+ o11yHeaders: createO11yHeaders(),
377
+ });
378
+ };
379
+ provider.rerankingModel = createRerankingModel;
380
+ provider.reranking = createRerankingModel;
357
381
  provider.chat = provider.languageModel;
358
382
  provider.embedding = provider.embeddingModel;
359
383
  provider.image = provider.imageModel;
@@ -0,0 +1 @@
1
+ export type GatewayRerankingModelId = 'cohere/rerank-v3.5' | (string & {});
@@ -0,0 +1,114 @@
1
+ import type {
2
+ RerankingModelV4,
3
+ SharedV4ProviderMetadata,
4
+ } from '@ai-sdk/provider';
5
+ import {
6
+ combineHeaders,
7
+ createJsonErrorResponseHandler,
8
+ createJsonResponseHandler,
9
+ lazySchema,
10
+ postJsonToApi,
11
+ resolve,
12
+ zodSchema,
13
+ type Resolvable,
14
+ } from '@ai-sdk/provider-utils';
15
+ import { z } from 'zod/v4';
16
+ import { asGatewayError } from './errors';
17
+ import { parseAuthMethod } from './errors/parse-auth-method';
18
+ import type { GatewayConfig } from './gateway-config';
19
+
20
+ export class GatewayRerankingModel implements RerankingModelV4 {
21
+ readonly specificationVersion = 'v4';
22
+
23
+ constructor(
24
+ readonly modelId: string,
25
+ private readonly config: GatewayConfig & {
26
+ provider: string;
27
+ o11yHeaders: Resolvable<Record<string, string>>;
28
+ },
29
+ ) {}
30
+
31
+ get provider(): string {
32
+ return this.config.provider;
33
+ }
34
+
35
+ async doRerank({
36
+ documents,
37
+ query,
38
+ topN,
39
+ headers,
40
+ abortSignal,
41
+ providerOptions,
42
+ }: Parameters<RerankingModelV4['doRerank']>[0]): Promise<
43
+ Awaited<ReturnType<RerankingModelV4['doRerank']>>
44
+ > {
45
+ const resolvedHeaders = await resolve(this.config.headers());
46
+ try {
47
+ const {
48
+ responseHeaders,
49
+ value: responseBody,
50
+ rawValue,
51
+ } = await postJsonToApi({
52
+ url: this.getUrl(),
53
+ headers: combineHeaders(
54
+ resolvedHeaders,
55
+ headers ?? {},
56
+ this.getModelConfigHeaders(),
57
+ await resolve(this.config.o11yHeaders),
58
+ ),
59
+ body: {
60
+ documents,
61
+ query,
62
+ ...(topN != null ? { topN } : {}),
63
+ ...(providerOptions ? { providerOptions } : {}),
64
+ },
65
+ successfulResponseHandler: createJsonResponseHandler(
66
+ gatewayRerankingResponseSchema,
67
+ ),
68
+ failedResponseHandler: createJsonErrorResponseHandler({
69
+ errorSchema: z.any(),
70
+ errorToMessage: data => data,
71
+ }),
72
+ ...(abortSignal && { abortSignal }),
73
+ fetch: this.config.fetch,
74
+ });
75
+
76
+ return {
77
+ ranking: responseBody.ranking,
78
+ providerMetadata:
79
+ responseBody.providerMetadata as unknown as SharedV4ProviderMetadata,
80
+ response: { headers: responseHeaders, body: rawValue },
81
+ warnings: [],
82
+ };
83
+ } catch (error) {
84
+ throw await asGatewayError(error, await parseAuthMethod(resolvedHeaders));
85
+ }
86
+ }
87
+
88
+ private getUrl() {
89
+ return `${this.config.baseURL}/reranking-model`;
90
+ }
91
+
92
+ private getModelConfigHeaders() {
93
+ return {
94
+ 'ai-reranking-model-specification-version': '4',
95
+ 'ai-model-id': this.modelId,
96
+ };
97
+ }
98
+ }
99
+
100
+ const gatewayRerankingResponseSchema = lazySchema(() =>
101
+ zodSchema(
102
+ z.object({
103
+ ranking: z.array(
104
+ z.object({
105
+ index: z.number(),
106
+ relevanceScore: z.number(),
107
+ }),
108
+ ),
109
+ providerMetadata: z
110
+ .record(z.string(), z.record(z.string(), z.unknown()))
111
+ .optional(),
112
+ }),
113
+ ),
114
+ );
package/src/index.ts CHANGED
@@ -1,4 +1,5 @@
1
1
  export type { GatewayModelId } from './gateway-language-model-settings';
2
+ export type { GatewayRerankingModelId } from './gateway-reranking-model-settings';
2
3
  export type { GatewayVideoModelId } from './gateway-video-model-settings';
3
4
  export type {
4
5
  GatewayLanguageModelEntry,