@ai-sdk/gateway 4.0.0-beta.6 → 4.0.0-beta.61
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +390 -4
- package/dist/index.d.ts +149 -24
- package/dist/index.js +735 -320
- package/dist/index.js.map +1 -1
- package/docs/00-ai-gateway.mdx +312 -45
- package/package.json +8 -10
- package/src/errors/create-gateway-error.ts +0 -1
- package/src/errors/gateway-authentication-error.ts +0 -1
- package/src/gateway-config.ts +1 -1
- package/src/gateway-embedding-model-settings.ts +1 -1
- package/src/gateway-embedding-model.ts +38 -14
- package/src/gateway-fetch-metadata.ts +51 -37
- package/src/gateway-generation-info.ts +149 -0
- package/src/gateway-image-model-settings.ts +9 -0
- package/src/gateway-image-model.ts +41 -21
- package/src/gateway-language-model-settings.ts +22 -10
- package/src/gateway-language-model.ts +49 -23
- package/src/gateway-model-entry.ts +13 -3
- package/src/gateway-provider-options.ts +35 -8
- package/src/gateway-provider.ts +100 -18
- package/src/gateway-reranking-model-settings.ts +7 -0
- package/src/gateway-reranking-model.ts +119 -0
- package/src/gateway-spend-report.ts +193 -0
- package/src/gateway-video-model-settings.ts +2 -0
- package/src/gateway-video-model.ts +22 -17
- package/src/index.ts +13 -3
- package/dist/index.d.mts +0 -602
- package/dist/index.mjs +0 -1539
- package/dist/index.mjs.map +0 -1
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
import type {
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
LanguageModelV3StreamResult,
|
|
2
|
+
LanguageModelV4,
|
|
3
|
+
LanguageModelV4CallOptions,
|
|
4
|
+
LanguageModelV4FilePart,
|
|
5
|
+
LanguageModelV4StreamPart,
|
|
6
|
+
LanguageModelV4GenerateResult,
|
|
7
|
+
LanguageModelV4StreamResult,
|
|
9
8
|
} from '@ai-sdk/provider';
|
|
10
9
|
import {
|
|
11
10
|
combineHeaders,
|
|
@@ -14,6 +13,9 @@ import {
|
|
|
14
13
|
createJsonResponseHandler,
|
|
15
14
|
postJsonToApi,
|
|
16
15
|
resolve,
|
|
16
|
+
serializeModelOptions,
|
|
17
|
+
WORKFLOW_SERIALIZE,
|
|
18
|
+
WORKFLOW_DESERIALIZE,
|
|
17
19
|
type ParseResult,
|
|
18
20
|
type Resolvable,
|
|
19
21
|
} from '@ai-sdk/provider-utils';
|
|
@@ -28,10 +30,24 @@ type GatewayChatConfig = GatewayConfig & {
|
|
|
28
30
|
o11yHeaders: Resolvable<Record<string, string>>;
|
|
29
31
|
};
|
|
30
32
|
|
|
31
|
-
export class GatewayLanguageModel implements
|
|
32
|
-
readonly specificationVersion = '
|
|
33
|
+
export class GatewayLanguageModel implements LanguageModelV4 {
|
|
34
|
+
readonly specificationVersion = 'v4';
|
|
33
35
|
readonly supportedUrls = { '*/*': [/.*/] };
|
|
34
36
|
|
|
37
|
+
static [WORKFLOW_SERIALIZE](model: GatewayLanguageModel) {
|
|
38
|
+
return serializeModelOptions({
|
|
39
|
+
modelId: model.modelId,
|
|
40
|
+
config: model.config,
|
|
41
|
+
});
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
static [WORKFLOW_DESERIALIZE](options: {
|
|
45
|
+
modelId: GatewayModelId;
|
|
46
|
+
config: GatewayChatConfig;
|
|
47
|
+
}) {
|
|
48
|
+
return new GatewayLanguageModel(options.modelId, options.config);
|
|
49
|
+
}
|
|
50
|
+
|
|
35
51
|
constructor(
|
|
36
52
|
readonly modelId: GatewayModelId,
|
|
37
53
|
private readonly config: GatewayChatConfig,
|
|
@@ -41,7 +57,7 @@ export class GatewayLanguageModel implements LanguageModelV3 {
|
|
|
41
57
|
return this.config.provider;
|
|
42
58
|
}
|
|
43
59
|
|
|
44
|
-
private async getArgs(options:
|
|
60
|
+
private async getArgs(options: LanguageModelV4CallOptions) {
|
|
45
61
|
const { abortSignal: _abortSignal, ...optionsWithoutSignal } = options;
|
|
46
62
|
|
|
47
63
|
return {
|
|
@@ -51,12 +67,14 @@ export class GatewayLanguageModel implements LanguageModelV3 {
|
|
|
51
67
|
}
|
|
52
68
|
|
|
53
69
|
async doGenerate(
|
|
54
|
-
options:
|
|
55
|
-
): Promise<
|
|
70
|
+
options: LanguageModelV4CallOptions,
|
|
71
|
+
): Promise<LanguageModelV4GenerateResult> {
|
|
56
72
|
const { args, warnings } = await this.getArgs(options);
|
|
57
73
|
const { abortSignal } = options;
|
|
58
74
|
|
|
59
|
-
const resolvedHeaders =
|
|
75
|
+
const resolvedHeaders = this.config.headers
|
|
76
|
+
? await resolve(this.config.headers)
|
|
77
|
+
: undefined;
|
|
60
78
|
|
|
61
79
|
try {
|
|
62
80
|
const {
|
|
@@ -88,17 +106,22 @@ export class GatewayLanguageModel implements LanguageModelV3 {
|
|
|
88
106
|
warnings,
|
|
89
107
|
};
|
|
90
108
|
} catch (error) {
|
|
91
|
-
throw await asGatewayError(
|
|
109
|
+
throw await asGatewayError(
|
|
110
|
+
error,
|
|
111
|
+
await parseAuthMethod(resolvedHeaders ?? {}),
|
|
112
|
+
);
|
|
92
113
|
}
|
|
93
114
|
}
|
|
94
115
|
|
|
95
116
|
async doStream(
|
|
96
|
-
options:
|
|
97
|
-
): Promise<
|
|
117
|
+
options: LanguageModelV4CallOptions,
|
|
118
|
+
): Promise<LanguageModelV4StreamResult> {
|
|
98
119
|
const { args, warnings } = await this.getArgs(options);
|
|
99
120
|
const { abortSignal } = options;
|
|
100
121
|
|
|
101
|
-
const resolvedHeaders =
|
|
122
|
+
const resolvedHeaders = this.config.headers
|
|
123
|
+
? await resolve(this.config.headers)
|
|
124
|
+
: undefined;
|
|
102
125
|
|
|
103
126
|
try {
|
|
104
127
|
const { value: response, responseHeaders } = await postJsonToApi({
|
|
@@ -122,8 +145,8 @@ export class GatewayLanguageModel implements LanguageModelV3 {
|
|
|
122
145
|
return {
|
|
123
146
|
stream: response.pipeThrough(
|
|
124
147
|
new TransformStream<
|
|
125
|
-
ParseResult<
|
|
126
|
-
|
|
148
|
+
ParseResult<LanguageModelV4StreamPart>,
|
|
149
|
+
LanguageModelV4StreamPart
|
|
127
150
|
>({
|
|
128
151
|
start(controller) {
|
|
129
152
|
if (warnings.length > 0) {
|
|
@@ -161,7 +184,10 @@ export class GatewayLanguageModel implements LanguageModelV3 {
|
|
|
161
184
|
response: { headers: responseHeaders },
|
|
162
185
|
};
|
|
163
186
|
} catch (error) {
|
|
164
|
-
throw await asGatewayError(
|
|
187
|
+
throw await asGatewayError(
|
|
188
|
+
error,
|
|
189
|
+
await parseAuthMethod(resolvedHeaders ?? {}),
|
|
190
|
+
);
|
|
165
191
|
}
|
|
166
192
|
}
|
|
167
193
|
|
|
@@ -177,11 +203,11 @@ export class GatewayLanguageModel implements LanguageModelV3 {
|
|
|
177
203
|
* @param options - The options to encode.
|
|
178
204
|
* @returns The options with the file parts encoded.
|
|
179
205
|
*/
|
|
180
|
-
private maybeEncodeFileParts(options:
|
|
206
|
+
private maybeEncodeFileParts(options: LanguageModelV4CallOptions) {
|
|
181
207
|
for (const message of options.prompt) {
|
|
182
208
|
for (const part of message.content) {
|
|
183
209
|
if (this.isFilePart(part)) {
|
|
184
|
-
const filePart = part as
|
|
210
|
+
const filePart = part as LanguageModelV4FilePart;
|
|
185
211
|
// If the file part is a URL it will get cleanly converted to a string.
|
|
186
212
|
// If it's a binary file attachment we convert it to a data url.
|
|
187
213
|
// In either case, server-side we should only ever see URLs as strings.
|
|
@@ -204,7 +230,7 @@ export class GatewayLanguageModel implements LanguageModelV3 {
|
|
|
204
230
|
|
|
205
231
|
private getModelConfigHeaders(modelId: string, streaming: boolean) {
|
|
206
232
|
return {
|
|
207
|
-
'ai-language-model-specification-version': '
|
|
233
|
+
'ai-language-model-specification-version': '4',
|
|
208
234
|
'ai-language-model-id': modelId,
|
|
209
235
|
'ai-language-model-streaming': String(streaming),
|
|
210
236
|
};
|
|
@@ -1,4 +1,14 @@
|
|
|
1
|
-
import type {
|
|
1
|
+
import type { LanguageModelV4 } from '@ai-sdk/provider';
|
|
2
|
+
|
|
3
|
+
export const KNOWN_MODEL_TYPES = [
|
|
4
|
+
'embedding',
|
|
5
|
+
'image',
|
|
6
|
+
'language',
|
|
7
|
+
'reranking',
|
|
8
|
+
'video',
|
|
9
|
+
] as const;
|
|
10
|
+
|
|
11
|
+
export type KnownModelType = (typeof KNOWN_MODEL_TYPES)[number];
|
|
2
12
|
|
|
3
13
|
export interface GatewayLanguageModelEntry {
|
|
4
14
|
/**
|
|
@@ -49,10 +59,10 @@ export interface GatewayLanguageModelEntry {
|
|
|
49
59
|
/**
|
|
50
60
|
* Optional field to differentiate between model types.
|
|
51
61
|
*/
|
|
52
|
-
modelType?:
|
|
62
|
+
modelType?: KnownModelType | null;
|
|
53
63
|
}
|
|
54
64
|
|
|
55
65
|
export type GatewayLanguageModelSpecification = Pick<
|
|
56
|
-
|
|
66
|
+
LanguageModelV4,
|
|
57
67
|
'specificationVersion' | 'provider' | 'modelId'
|
|
58
68
|
>;
|
|
@@ -2,7 +2,7 @@ import { InferSchema, lazySchema, zodSchema } from '@ai-sdk/provider-utils';
|
|
|
2
2
|
import { z } from 'zod/v4';
|
|
3
3
|
|
|
4
4
|
// https://vercel.com/docs/ai-gateway/provider-options
|
|
5
|
-
const
|
|
5
|
+
const gatewayProviderOptions = lazySchema(() =>
|
|
6
6
|
zodSchema(
|
|
7
7
|
z.object({
|
|
8
8
|
/**
|
|
@@ -17,6 +17,14 @@ const gatewayLanguageModelOptions = lazySchema(() =>
|
|
|
17
17
|
* Example: `['bedrock', 'anthropic']` will try Amazon Bedrock first, then Anthropic as fallback.
|
|
18
18
|
*/
|
|
19
19
|
order: z.array(z.string()).optional(),
|
|
20
|
+
/**
|
|
21
|
+
* Sort providers by a performance or cost metric before routing.
|
|
22
|
+
*
|
|
23
|
+
* - `'cost'`: lowest cost first
|
|
24
|
+
* - `'ttft'`: lowest time-to-first-token first
|
|
25
|
+
* - `'tps'`: highest tokens-per-second first
|
|
26
|
+
*/
|
|
27
|
+
sort: z.enum(['cost', 'ttft', 'tps']).optional(),
|
|
20
28
|
/**
|
|
21
29
|
* The unique identifier for the end user on behalf of whom the request was made.
|
|
22
30
|
*
|
|
@@ -53,12 +61,33 @@ const gatewayLanguageModelOptions = lazySchema(() =>
|
|
|
53
61
|
.record(z.string(), z.array(z.record(z.string(), z.unknown())))
|
|
54
62
|
.optional(),
|
|
55
63
|
/**
|
|
56
|
-
* Whether to filter by only providers that
|
|
57
|
-
*
|
|
58
|
-
*
|
|
59
|
-
*
|
|
64
|
+
* Whether to filter by only providers that have zero data retention
|
|
65
|
+
* agreements with Vercel for AI Gateway. When using BYOK credentials,
|
|
66
|
+
* this filter is not applied. If BYOK credentials fail and the request
|
|
67
|
+
* falls back to system credentials, only providers with zero data
|
|
68
|
+
* retention agreements will be used.
|
|
60
69
|
*/
|
|
61
70
|
zeroDataRetention: z.boolean().optional(),
|
|
71
|
+
/**
|
|
72
|
+
* Whether to filter by only providers that do not train on prompt data.
|
|
73
|
+
* When using BYOK credentials, this filter is not applied. If BYOK
|
|
74
|
+
* credentials fail and the request falls back to system credentials,
|
|
75
|
+
* only providers that have agreements with Vercel for AI Gateway to not
|
|
76
|
+
* use prompts for model training will be used.
|
|
77
|
+
*/
|
|
78
|
+
disallowPromptTraining: z.boolean().optional(),
|
|
79
|
+
/**
|
|
80
|
+
* Whether to filter by only providers that are HIPAA compliant with
|
|
81
|
+
* Vercel AI Gateway. When enabled, only providers that have agreements
|
|
82
|
+
* with Vercel AI Gateway for HIPAA compliance will be used.
|
|
83
|
+
*/
|
|
84
|
+
hipaaCompliant: z.boolean().optional(),
|
|
85
|
+
/**
|
|
86
|
+
* The unique identifier for the entity against which quota is tracked.
|
|
87
|
+
*
|
|
88
|
+
* Used for quota management and enforcement purposes.
|
|
89
|
+
*/
|
|
90
|
+
quotaEntityId: z.string().optional(),
|
|
62
91
|
/**
|
|
63
92
|
* Per-provider timeouts for BYOK credentials in milliseconds.
|
|
64
93
|
* Controls how long to wait for a provider to start responding
|
|
@@ -75,6 +104,4 @@ const gatewayLanguageModelOptions = lazySchema(() =>
|
|
|
75
104
|
),
|
|
76
105
|
);
|
|
77
106
|
|
|
78
|
-
export type
|
|
79
|
-
typeof gatewayLanguageModelOptions
|
|
80
|
-
>;
|
|
107
|
+
export type GatewayProviderOptions = InferSchema<typeof gatewayProviderOptions>;
|
package/src/gateway-provider.ts
CHANGED
|
@@ -13,38 +13,51 @@ import {
|
|
|
13
13
|
type GatewayFetchMetadataResponse,
|
|
14
14
|
type GatewayCreditsResponse,
|
|
15
15
|
} from './gateway-fetch-metadata';
|
|
16
|
+
import {
|
|
17
|
+
GatewaySpendReport,
|
|
18
|
+
type GatewaySpendReportParams,
|
|
19
|
+
type GatewaySpendReportResponse,
|
|
20
|
+
} from './gateway-spend-report';
|
|
21
|
+
import {
|
|
22
|
+
GatewayGenerationInfoFetcher,
|
|
23
|
+
type GatewayGenerationInfoParams,
|
|
24
|
+
type GatewayGenerationInfo,
|
|
25
|
+
} from './gateway-generation-info';
|
|
16
26
|
import { GatewayLanguageModel } from './gateway-language-model';
|
|
17
27
|
import { GatewayEmbeddingModel } from './gateway-embedding-model';
|
|
18
28
|
import { GatewayImageModel } from './gateway-image-model';
|
|
19
29
|
import { GatewayVideoModel } from './gateway-video-model';
|
|
30
|
+
import { GatewayRerankingModel } from './gateway-reranking-model';
|
|
20
31
|
import type { GatewayEmbeddingModelId } from './gateway-embedding-model-settings';
|
|
21
32
|
import type { GatewayImageModelId } from './gateway-image-model-settings';
|
|
33
|
+
import type { GatewayRerankingModelId } from './gateway-reranking-model-settings';
|
|
22
34
|
import type { GatewayVideoModelId } from './gateway-video-model-settings';
|
|
23
35
|
import { gatewayTools } from './gateway-tools';
|
|
24
36
|
import { getVercelOidcToken, getVercelRequestId } from './vercel-environment';
|
|
25
37
|
import type { GatewayModelId } from './gateway-language-model-settings';
|
|
26
38
|
import type {
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
39
|
+
LanguageModelV4,
|
|
40
|
+
EmbeddingModelV4,
|
|
41
|
+
ImageModelV4,
|
|
42
|
+
RerankingModelV4,
|
|
43
|
+
Experimental_VideoModelV4,
|
|
44
|
+
ProviderV4,
|
|
32
45
|
} from '@ai-sdk/provider';
|
|
33
46
|
import { withUserAgentSuffix } from '@ai-sdk/provider-utils';
|
|
34
47
|
import { VERSION } from './version';
|
|
35
48
|
|
|
36
|
-
export interface GatewayProvider extends
|
|
37
|
-
(modelId: GatewayModelId):
|
|
49
|
+
export interface GatewayProvider extends ProviderV4 {
|
|
50
|
+
(modelId: GatewayModelId): LanguageModelV4;
|
|
38
51
|
|
|
39
52
|
/**
|
|
40
53
|
* Creates a model for text generation.
|
|
41
54
|
*/
|
|
42
|
-
chat(modelId: GatewayModelId):
|
|
55
|
+
chat(modelId: GatewayModelId): LanguageModelV4;
|
|
43
56
|
|
|
44
57
|
/**
|
|
45
58
|
* Creates a model for text generation.
|
|
46
59
|
*/
|
|
47
|
-
languageModel(modelId: GatewayModelId):
|
|
60
|
+
languageModel(modelId: GatewayModelId): LanguageModelV4;
|
|
48
61
|
|
|
49
62
|
/**
|
|
50
63
|
* Returns available providers and models for use with the remote provider.
|
|
@@ -56,40 +69,66 @@ export interface GatewayProvider extends ProviderV3 {
|
|
|
56
69
|
*/
|
|
57
70
|
getCredits(): Promise<GatewayCreditsResponse>;
|
|
58
71
|
|
|
72
|
+
/**
|
|
73
|
+
* Returns a spend report with cost, token, and request count data,
|
|
74
|
+
* aggregated by the specified dimension.
|
|
75
|
+
*/
|
|
76
|
+
getSpendReport(
|
|
77
|
+
params: GatewaySpendReportParams,
|
|
78
|
+
): Promise<GatewaySpendReportResponse>;
|
|
79
|
+
|
|
80
|
+
/**
|
|
81
|
+
* Returns detailed information about a specific generation by its ID,
|
|
82
|
+
* including cost, token usage, latency, and provider details.
|
|
83
|
+
*/
|
|
84
|
+
getGenerationInfo(
|
|
85
|
+
params: GatewayGenerationInfoParams,
|
|
86
|
+
): Promise<GatewayGenerationInfo>;
|
|
87
|
+
|
|
59
88
|
/**
|
|
60
89
|
* Creates a model for generating text embeddings.
|
|
61
90
|
*/
|
|
62
|
-
embedding(modelId: GatewayEmbeddingModelId):
|
|
91
|
+
embedding(modelId: GatewayEmbeddingModelId): EmbeddingModelV4;
|
|
63
92
|
|
|
64
93
|
/**
|
|
65
94
|
* Creates a model for generating text embeddings.
|
|
66
95
|
*/
|
|
67
|
-
embeddingModel(modelId: GatewayEmbeddingModelId):
|
|
96
|
+
embeddingModel(modelId: GatewayEmbeddingModelId): EmbeddingModelV4;
|
|
68
97
|
|
|
69
98
|
/**
|
|
70
99
|
* @deprecated Use `embeddingModel` instead.
|
|
71
100
|
*/
|
|
72
|
-
textEmbeddingModel(modelId: GatewayEmbeddingModelId):
|
|
101
|
+
textEmbeddingModel(modelId: GatewayEmbeddingModelId): EmbeddingModelV4;
|
|
73
102
|
|
|
74
103
|
/**
|
|
75
104
|
* Creates a model for generating images.
|
|
76
105
|
*/
|
|
77
|
-
image(modelId: GatewayImageModelId):
|
|
106
|
+
image(modelId: GatewayImageModelId): ImageModelV4;
|
|
78
107
|
|
|
79
108
|
/**
|
|
80
109
|
* Creates a model for generating images.
|
|
81
110
|
*/
|
|
82
|
-
imageModel(modelId: GatewayImageModelId):
|
|
111
|
+
imageModel(modelId: GatewayImageModelId): ImageModelV4;
|
|
83
112
|
|
|
84
113
|
/**
|
|
85
114
|
* Creates a model for generating videos.
|
|
86
115
|
*/
|
|
87
|
-
video(modelId: GatewayVideoModelId):
|
|
116
|
+
video(modelId: GatewayVideoModelId): Experimental_VideoModelV4;
|
|
88
117
|
|
|
89
118
|
/**
|
|
90
119
|
* Creates a model for generating videos.
|
|
91
120
|
*/
|
|
92
|
-
videoModel(modelId: GatewayVideoModelId):
|
|
121
|
+
videoModel(modelId: GatewayVideoModelId): Experimental_VideoModelV4;
|
|
122
|
+
|
|
123
|
+
/**
|
|
124
|
+
* Creates a model for reranking documents.
|
|
125
|
+
*/
|
|
126
|
+
reranking(modelId: GatewayRerankingModelId): RerankingModelV4;
|
|
127
|
+
|
|
128
|
+
/**
|
|
129
|
+
* Creates a model for reranking documents.
|
|
130
|
+
*/
|
|
131
|
+
rerankingModel(modelId: GatewayRerankingModelId): RerankingModelV4;
|
|
93
132
|
|
|
94
133
|
/**
|
|
95
134
|
* Gateway-specific tools executed server-side.
|
|
@@ -148,7 +187,7 @@ export function createGatewayProvider(
|
|
|
148
187
|
|
|
149
188
|
const baseURL =
|
|
150
189
|
withoutTrailingSlash(options.baseURL) ??
|
|
151
|
-
'https://ai-gateway.vercel.sh/
|
|
190
|
+
'https://ai-gateway.vercel.sh/v4/ai';
|
|
152
191
|
|
|
153
192
|
const getHeaders = async () => {
|
|
154
193
|
try {
|
|
@@ -253,6 +292,36 @@ export function createGatewayProvider(
|
|
|
253
292
|
});
|
|
254
293
|
};
|
|
255
294
|
|
|
295
|
+
const getSpendReport = async (params: GatewaySpendReportParams) => {
|
|
296
|
+
return new GatewaySpendReport({
|
|
297
|
+
baseURL,
|
|
298
|
+
headers: getHeaders,
|
|
299
|
+
fetch: options.fetch,
|
|
300
|
+
})
|
|
301
|
+
.getSpendReport(params)
|
|
302
|
+
.catch(async (error: unknown) => {
|
|
303
|
+
throw await asGatewayError(
|
|
304
|
+
error,
|
|
305
|
+
await parseAuthMethod(await getHeaders()),
|
|
306
|
+
);
|
|
307
|
+
});
|
|
308
|
+
};
|
|
309
|
+
|
|
310
|
+
const getGenerationInfo = async (params: GatewayGenerationInfoParams) => {
|
|
311
|
+
return new GatewayGenerationInfoFetcher({
|
|
312
|
+
baseURL,
|
|
313
|
+
headers: getHeaders,
|
|
314
|
+
fetch: options.fetch,
|
|
315
|
+
})
|
|
316
|
+
.getGenerationInfo(params)
|
|
317
|
+
.catch(async (error: unknown) => {
|
|
318
|
+
throw await asGatewayError(
|
|
319
|
+
error,
|
|
320
|
+
await parseAuthMethod(await getHeaders()),
|
|
321
|
+
);
|
|
322
|
+
});
|
|
323
|
+
};
|
|
324
|
+
|
|
256
325
|
const provider = function (modelId: GatewayModelId) {
|
|
257
326
|
if (new.target) {
|
|
258
327
|
throw new Error(
|
|
@@ -263,9 +332,11 @@ export function createGatewayProvider(
|
|
|
263
332
|
return createLanguageModel(modelId);
|
|
264
333
|
};
|
|
265
334
|
|
|
266
|
-
provider.specificationVersion = '
|
|
335
|
+
provider.specificationVersion = 'v4' as const;
|
|
267
336
|
provider.getAvailableModels = getAvailableModels;
|
|
268
337
|
provider.getCredits = getCredits;
|
|
338
|
+
provider.getSpendReport = getSpendReport;
|
|
339
|
+
provider.getGenerationInfo = getGenerationInfo;
|
|
269
340
|
provider.imageModel = (modelId: GatewayImageModelId) => {
|
|
270
341
|
return new GatewayImageModel(modelId, {
|
|
271
342
|
provider: 'gateway',
|
|
@@ -296,6 +367,17 @@ export function createGatewayProvider(
|
|
|
296
367
|
o11yHeaders: createO11yHeaders(),
|
|
297
368
|
});
|
|
298
369
|
};
|
|
370
|
+
const createRerankingModel = (modelId: GatewayRerankingModelId) => {
|
|
371
|
+
return new GatewayRerankingModel(modelId, {
|
|
372
|
+
provider: 'gateway',
|
|
373
|
+
baseURL,
|
|
374
|
+
headers: getHeaders,
|
|
375
|
+
fetch: options.fetch,
|
|
376
|
+
o11yHeaders: createO11yHeaders(),
|
|
377
|
+
});
|
|
378
|
+
};
|
|
379
|
+
provider.rerankingModel = createRerankingModel;
|
|
380
|
+
provider.reranking = createRerankingModel;
|
|
299
381
|
provider.chat = provider.languageModel;
|
|
300
382
|
provider.embedding = provider.embeddingModel;
|
|
301
383
|
provider.image = provider.imageModel;
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
import type {
|
|
2
|
+
RerankingModelV4,
|
|
3
|
+
SharedV4ProviderMetadata,
|
|
4
|
+
} from '@ai-sdk/provider';
|
|
5
|
+
import {
|
|
6
|
+
combineHeaders,
|
|
7
|
+
createJsonErrorResponseHandler,
|
|
8
|
+
createJsonResponseHandler,
|
|
9
|
+
lazySchema,
|
|
10
|
+
postJsonToApi,
|
|
11
|
+
resolve,
|
|
12
|
+
zodSchema,
|
|
13
|
+
type Resolvable,
|
|
14
|
+
} from '@ai-sdk/provider-utils';
|
|
15
|
+
import { z } from 'zod/v4';
|
|
16
|
+
import { asGatewayError } from './errors';
|
|
17
|
+
import { parseAuthMethod } from './errors/parse-auth-method';
|
|
18
|
+
import type { GatewayConfig } from './gateway-config';
|
|
19
|
+
|
|
20
|
+
export class GatewayRerankingModel implements RerankingModelV4 {
|
|
21
|
+
readonly specificationVersion = 'v4';
|
|
22
|
+
|
|
23
|
+
constructor(
|
|
24
|
+
readonly modelId: string,
|
|
25
|
+
private readonly config: GatewayConfig & {
|
|
26
|
+
provider: string;
|
|
27
|
+
o11yHeaders: Resolvable<Record<string, string>>;
|
|
28
|
+
},
|
|
29
|
+
) {}
|
|
30
|
+
|
|
31
|
+
get provider(): string {
|
|
32
|
+
return this.config.provider;
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
async doRerank({
|
|
36
|
+
documents,
|
|
37
|
+
query,
|
|
38
|
+
topN,
|
|
39
|
+
headers,
|
|
40
|
+
abortSignal,
|
|
41
|
+
providerOptions,
|
|
42
|
+
}: Parameters<RerankingModelV4['doRerank']>[0]): Promise<
|
|
43
|
+
Awaited<ReturnType<RerankingModelV4['doRerank']>>
|
|
44
|
+
> {
|
|
45
|
+
const resolvedHeaders = this.config.headers
|
|
46
|
+
? await resolve(this.config.headers)
|
|
47
|
+
: undefined;
|
|
48
|
+
try {
|
|
49
|
+
const {
|
|
50
|
+
responseHeaders,
|
|
51
|
+
value: responseBody,
|
|
52
|
+
rawValue,
|
|
53
|
+
} = await postJsonToApi({
|
|
54
|
+
url: this.getUrl(),
|
|
55
|
+
headers: combineHeaders(
|
|
56
|
+
resolvedHeaders,
|
|
57
|
+
headers ?? {},
|
|
58
|
+
this.getModelConfigHeaders(),
|
|
59
|
+
await resolve(this.config.o11yHeaders),
|
|
60
|
+
),
|
|
61
|
+
body: {
|
|
62
|
+
documents,
|
|
63
|
+
query,
|
|
64
|
+
...(topN != null ? { topN } : {}),
|
|
65
|
+
...(providerOptions ? { providerOptions } : {}),
|
|
66
|
+
},
|
|
67
|
+
successfulResponseHandler: createJsonResponseHandler(
|
|
68
|
+
gatewayRerankingResponseSchema,
|
|
69
|
+
),
|
|
70
|
+
failedResponseHandler: createJsonErrorResponseHandler({
|
|
71
|
+
errorSchema: z.any(),
|
|
72
|
+
errorToMessage: data => data,
|
|
73
|
+
}),
|
|
74
|
+
...(abortSignal && { abortSignal }),
|
|
75
|
+
fetch: this.config.fetch,
|
|
76
|
+
});
|
|
77
|
+
|
|
78
|
+
return {
|
|
79
|
+
ranking: responseBody.ranking,
|
|
80
|
+
providerMetadata:
|
|
81
|
+
responseBody.providerMetadata as unknown as SharedV4ProviderMetadata,
|
|
82
|
+
response: { headers: responseHeaders, body: rawValue },
|
|
83
|
+
warnings: [],
|
|
84
|
+
};
|
|
85
|
+
} catch (error) {
|
|
86
|
+
throw await asGatewayError(
|
|
87
|
+
error,
|
|
88
|
+
await parseAuthMethod(resolvedHeaders ?? {}),
|
|
89
|
+
);
|
|
90
|
+
}
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
private getUrl() {
|
|
94
|
+
return `${this.config.baseURL}/reranking-model`;
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
private getModelConfigHeaders() {
|
|
98
|
+
return {
|
|
99
|
+
'ai-reranking-model-specification-version': '4',
|
|
100
|
+
'ai-model-id': this.modelId,
|
|
101
|
+
};
|
|
102
|
+
}
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
const gatewayRerankingResponseSchema = lazySchema(() =>
|
|
106
|
+
zodSchema(
|
|
107
|
+
z.object({
|
|
108
|
+
ranking: z.array(
|
|
109
|
+
z.object({
|
|
110
|
+
index: z.number(),
|
|
111
|
+
relevanceScore: z.number(),
|
|
112
|
+
}),
|
|
113
|
+
),
|
|
114
|
+
providerMetadata: z
|
|
115
|
+
.record(z.string(), z.record(z.string(), z.unknown()))
|
|
116
|
+
.optional(),
|
|
117
|
+
}),
|
|
118
|
+
),
|
|
119
|
+
);
|