@ai-sdk/amazon-bedrock 4.0.24 → 4.0.26
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +16 -0
- package/dist/anthropic/index.js +1 -1
- package/dist/anthropic/index.mjs +1 -1
- package/dist/index.js +1 -1
- package/dist/index.mjs +1 -1
- package/docs/08-amazon-bedrock.mdx +1453 -0
- package/package.json +11 -6
- package/src/__fixtures__/bedrock-json-only-text-first.1.chunks.txt +7 -0
- package/src/__fixtures__/bedrock-json-other-tool.1.chunks.txt +6 -0
- package/src/__fixtures__/bedrock-json-other-tool.1.json +24 -0
- package/src/__fixtures__/bedrock-json-tool-text-then-weather-then-json.1.chunks.txt +12 -0
- package/src/__fixtures__/bedrock-json-tool-with-answer.1.json +29 -0
- package/src/__fixtures__/bedrock-json-tool.1.chunks.txt +4 -0
- package/src/__fixtures__/bedrock-json-tool.1.json +35 -0
- package/src/__fixtures__/bedrock-json-tool.2.chunks.txt +6 -0
- package/src/__fixtures__/bedrock-json-tool.2.json +28 -0
- package/src/__fixtures__/bedrock-json-tool.3.chunks.txt +7 -0
- package/src/__fixtures__/bedrock-json-tool.3.json +36 -0
- package/src/__fixtures__/bedrock-json-with-tool.1.chunks.txt +9 -0
- package/src/__fixtures__/bedrock-json-with-tool.1.json +41 -0
- package/src/__fixtures__/bedrock-json-with-tools.1.chunks.txt +12 -0
- package/src/__fixtures__/bedrock-json-with-tools.1.json +50 -0
- package/src/__fixtures__/bedrock-tool-call.1.chunks.txt +6 -0
- package/src/__fixtures__/bedrock-tool-call.1.json +24 -0
- package/src/__fixtures__/bedrock-tool-no-args.chunks.txt +8 -0
- package/src/__fixtures__/bedrock-tool-no-args.json +25 -0
- package/src/anthropic/bedrock-anthropic-fetch.test.ts +344 -0
- package/src/anthropic/bedrock-anthropic-fetch.ts +62 -0
- package/src/anthropic/bedrock-anthropic-options.ts +28 -0
- package/src/anthropic/bedrock-anthropic-provider.test.ts +456 -0
- package/src/anthropic/bedrock-anthropic-provider.ts +357 -0
- package/src/anthropic/index.ts +9 -0
- package/src/bedrock-api-types.ts +195 -0
- package/src/bedrock-chat-language-model.test.ts +4569 -0
- package/src/bedrock-chat-language-model.ts +1019 -0
- package/src/bedrock-chat-options.ts +114 -0
- package/src/bedrock-embedding-model.test.ts +148 -0
- package/src/bedrock-embedding-model.ts +104 -0
- package/src/bedrock-embedding-options.ts +24 -0
- package/src/bedrock-error.ts +6 -0
- package/src/bedrock-event-stream-decoder.ts +59 -0
- package/src/bedrock-event-stream-response-handler.test.ts +233 -0
- package/src/bedrock-event-stream-response-handler.ts +57 -0
- package/src/bedrock-image-model.test.ts +866 -0
- package/src/bedrock-image-model.ts +297 -0
- package/src/bedrock-image-settings.ts +6 -0
- package/src/bedrock-prepare-tools.ts +190 -0
- package/src/bedrock-provider.test.ts +457 -0
- package/src/bedrock-provider.ts +351 -0
- package/src/bedrock-sigv4-fetch.test.ts +675 -0
- package/src/bedrock-sigv4-fetch.ts +138 -0
- package/src/convert-bedrock-usage.test.ts +207 -0
- package/src/convert-bedrock-usage.ts +50 -0
- package/src/convert-to-bedrock-chat-messages.test.ts +1175 -0
- package/src/convert-to-bedrock-chat-messages.ts +452 -0
- package/src/index.ts +10 -0
- package/src/inject-fetch-headers.test.ts +135 -0
- package/src/inject-fetch-headers.ts +32 -0
- package/src/map-bedrock-finish-reason.ts +22 -0
- package/src/normalize-tool-call-id.test.ts +72 -0
- package/src/normalize-tool-call-id.ts +36 -0
- package/src/reranking/__fixtures__/bedrock-reranking.1.json +12 -0
- package/src/reranking/bedrock-reranking-api.ts +44 -0
- package/src/reranking/bedrock-reranking-model.test.ts +299 -0
- package/src/reranking/bedrock-reranking-model.ts +115 -0
- package/src/reranking/bedrock-reranking-options.ts +36 -0
- package/src/version.ts +6 -0
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import { describe, it, expect } from 'vitest';
|
|
2
|
+
import { isMistralModel, normalizeToolCallId } from './normalize-tool-call-id';
|
|
3
|
+
|
|
4
|
+
describe('isMistralModel', () => {
  // Asserts that every listed model ID produces the same verdict.
  const expectAll = (modelIds: string[], expected: boolean) => {
    for (const modelId of modelIds) {
      expect(isMistralModel(modelId)).toBe(expected);
    }
  };

  it('should return true for mistral models', () => {
    expectAll(
      [
        'mistral.mistral-7b-instruct-v0:2',
        'mistral.mixtral-8x7b-instruct-v0:1',
        'mistral.mistral-large-2402-v1:0',
        'mistral.mistral-small-2402-v1:0',
        'mistral.mistral-large-2407-v1:0',
        'mistral.ministral-3-14b-instruct',
        'mistral.ministral-3-8b-instruct',
      ],
      true,
    );
  });

  it('should return true for region-prefixed mistral models', () => {
    expectAll(
      [
        'us.mistral.pixtral-large-2502-v1:0',
        'eu.mistral.mistral-large-2407-v1:0',
      ],
      true,
    );
  });

  it('should return false for non-mistral models', () => {
    expectAll(
      [
        'anthropic.claude-3-5-sonnet-20241022-v2:0',
        'amazon.nova-pro-v1:0',
        'openai.gpt-4o',
        'meta.llama3-70b-instruct-v1:0',
      ],
      false,
    );
  });
});
|
|
29
|
+
|
|
30
|
+
describe('normalizeToolCallId', () => {
  // Representative ID in the format Bedrock generates.
  const bedrockId = 'tooluse_bpe71yCfRu2b5i-nKGDr5g';

  // Runs each [input, expected] pair through the Mistral code path.
  const expectMistral = (cases: Array<[string, string]>) => {
    for (const [input, expected] of cases) {
      expect(normalizeToolCallId(input, true)).toBe(expected);
    }
  };

  it('should return the original ID when not a Mistral model', () => {
    expect(normalizeToolCallId(bedrockId, false)).toBe(bedrockId);
  });

  it('should extract first 9 alphanumeric characters for Mistral models', () => {
    // Bedrock format: tooluse_bpe71yCfRu2b5i-nKGDr5g
    // After removing non-alphanumeric: toolusebpe71yCfRu2b5inKGDr5g
    // First 9 chars: toolusebp
    expectMistral([[bedrockId, 'toolusebp']]);
  });

  it('should handle IDs with various special characters', () => {
    expectMistral([
      ['tool-use_123ABC456', 'tooluse12'],
      ['___abc123DEF___', 'abc123DEF'],
    ]);
  });

  it('should handle IDs that are already alphanumeric', () => {
    expectMistral([
      ['abcdefghi', 'abcdefghi'],
      ['abc123XYZ', 'abc123XYZ'],
    ]);
  });

  it('should handle short IDs', () => {
    expectMistral([
      ['abc', 'abc'],
      ['12345', '12345'],
    ]);
  });

  it('should handle IDs with only special characters', () => {
    expectMistral([['___---___', '']]);
  });

  it('should produce valid Mistral tool call IDs (9 alphanumeric chars)', () => {
    // Verify the ID matches Mistral's requirements: ^[a-zA-Z0-9]{1,9}$
    expect(normalizeToolCallId(bedrockId, true)).toMatch(/^[a-zA-Z0-9]{1,9}$/);
  });
});
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
/**
 * Reports whether a Bedrock model ID refers to a Mistral model.
 *
 * The check is a substring match on `mistral.`, so it accepts both plain IDs
 * (`mistral.…`) and region-prefixed IDs (`us.mistral.…`). It is not anchored
 * to the start of the string: any ID containing `mistral.` matches.
 *
 * @param modelId - The Bedrock model ID to inspect
 * @returns `true` when the ID contains `mistral.`, otherwise `false`
 */
export function isMistralModel(modelId: string): boolean {
  const vendorMarker = 'mistral.';
  return modelId.includes(vendorMarker);
}
|
|
8
|
+
|
|
9
|
+
/**
 * Rewrites a Bedrock tool call ID into a form acceptable to Mistral models.
 *
 * Mistral models require tool call IDs to match `^[a-zA-Z0-9]{9}$`
 * (alphanumeric only, 9 characters). Bedrock generates IDs such as
 * `tooluse_bpe71yCfRu2b5i-nKGDr5g`, which contain underscores/hyphens and are
 * longer than that.
 *
 * For Mistral, every non-alphanumeric character is dropped and the result is
 * truncated to its first 9 characters. No padding is applied: an input with
 * fewer than 9 alphanumeric characters yields a shorter string, and an input
 * with none yields the empty string.
 *
 * @param toolCallId - The original tool call ID from Bedrock
 * @param isMistral - Whether the model is a Mistral model
 * @returns The original ID when `isMistral` is false; otherwise up to 9
 *   alphanumeric characters taken from the start of the ID
 */
export function normalizeToolCallId(
  toolCallId: string,
  isMistral: boolean,
): string {
  return isMistral
    ? toolCallId.replace(/[^a-zA-Z0-9]/g, '').slice(0, 9)
    : toolCallId;
}
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import { lazySchema, zodSchema } from '@ai-sdk/provider-utils';
|
|
2
|
+
import { z } from 'zod/v4';
|
|
3
|
+
|
|
4
|
+
// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_Rerank.html
/**
 * Shape of the JSON request body sent to the Bedrock rerank endpoint.
 */
export type BedrockRerankingInput = {
  /** Pagination token for fetching a further batch of results. */
  nextToken?: string;

  /** Exactly one text query. */
  queries: [{ type: 'TEXT'; textQuery: { text: string } }];

  /** Selects the reranking model and how many results to return. */
  rerankingConfiguration: {
    type: 'BEDROCK_RERANKING_MODEL';
    bedrockRerankingConfiguration: {
      modelConfiguration: {
        modelArn: string;
        additionalModelRequestFields?: Record<string, unknown>;
      };
      numberOfResults?: number;
    };
  };

  /** Inline documents to rank; each is either plain text or a JSON value. */
  sources: {
    type: 'INLINE';
    inlineDocumentSource:
      | {
          type: 'TEXT';
          textDocument: { text: string };
        }
      | {
          type: 'JSON';
          jsonDocument: unknown;
        };
  }[];
};
|
|
31
|
+
|
|
32
|
+
/**
 * Lazily constructed schema for a successful rerank response: a list of
 * `{ index, relevanceScore }` entries plus an optional `nextToken`.
 */
export const bedrockRerankingResponseSchema = lazySchema(() => {
  // One ranked entry of the response's `results` array.
  const rankedEntry = z.object({
    index: z.number(),
    relevanceScore: z.number(),
  });

  return zodSchema(
    z.object({
      results: z.array(rankedEntry),
      nextToken: z.string().optional(),
    }),
  );
});
|
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
import { createTestServer } from '@ai-sdk/test-server/with-vitest';
|
|
2
|
+
import fs from 'node:fs';
|
|
3
|
+
import { beforeEach, describe, expect, it } from 'vitest';
|
|
4
|
+
import { injectFetchHeaders } from '../inject-fetch-headers';
|
|
5
|
+
import { BedrockRerankingModel } from './bedrock-reranking-model';
|
|
6
|
+
import { BedrockRerankingOptions } from './bedrock-reranking-options';
|
|
7
|
+
|
|
8
|
+
// Fetch stand-in used by the model under test; the header snapshots below
// show `x-amz-auth` arriving on the outgoing request.
const fakeFetchWithAuth = injectFetchHeaders({ 'x-amz-auth': 'test-auth' });

// Configuration shared by every test. Note that `region` differs from the
// region embedded in the base URL; the request snapshots show the model ARN
// being built from `region`.
const testModelConfig = {
  baseUrl: () => 'https://bedrock-agent-runtime.us-east-1.amazonaws.com',
  region: 'us-west-2',
  headers: {
    'config-header': 'config-value',
    'shared-header': 'config-shared',
  },
  fetch: fakeFetchWithAuth,
};

const model = new BedrockRerankingModel('cohere.rerank-v3-5:0', testModelConfig);
|
|
19
|
+
|
|
20
|
+
describe('doRerank', () => {
|
|
21
|
+
// Endpoint the model under test is expected to call.
const rerankEndpoint =
  'https://bedrock-agent-runtime.us-east-1.amazonaws.com/rerank';

const server = createTestServer({
  [rerankEndpoint]: {},
});

/**
 * Configures the test server to answer the rerank endpoint with the contents
 * of `src/reranking/__fixtures__/<filename>.json`, served as a binary body
 * with a JSON content type.
 */
function prepareJsonFixtureResponse(filename: string) {
  const fixturePath = `src/reranking/__fixtures__/${filename}.json`;
  const fixtureText = fs.readFileSync(fixturePath, 'utf8');

  server.urls[rerankEndpoint].response = {
    type: 'binary',
    headers: { 'content-type': 'application/json' },
    body: Buffer.from(fixtureText),
  };
}
|
|
36
|
+
|
|
37
|
+
// Reranking of object (JSON) documents: verifies the outgoing request and the
// mapped result against the `bedrock-reranking.1` fixture.
describe('json documents', () => {
  let result: Awaited<ReturnType<typeof model.doRerank>>;

  beforeEach(async () => {
    prepareJsonFixtureResponse('bedrock-reranking.1');

    result = await model.doRerank({
      documents: {
        type: 'object',
        values: [
          { example: 'sunny day at the beach' },
          { example: 'rainy day in the city' },
        ],
      },
      query: 'rainy day',
      topN: 2,
      providerOptions: {
        bedrock: {
          nextToken: 'test-token',
          additionalModelRequestFields: {
            test: 'test-value',
          },
        } satisfies BedrockRerankingOptions,
      },
    });
  });

  // The documents are sent as JSON objects (not strings), as the snapshot shows.
  it('should send request with json documents', async () => {
    expect(await server.calls[0].requestBodyJson).toMatchInlineSnapshot(`
      {
        "nextToken": "test-token",
        "queries": [
          {
            "textQuery": {
              "text": "rainy day",
            },
            "type": "TEXT",
          },
        ],
        "rerankingConfiguration": {
          "bedrockRerankingConfiguration": {
            "modelConfiguration": {
              "additionalModelRequestFields": {
                "test": "test-value",
              },
              "modelArn": "arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0",
            },
            "numberOfResults": 2,
          },
          "type": "BEDROCK_RERANKING_MODEL",
        },
        "sources": [
          {
            "inlineDocumentSource": {
              "jsonDocument": {
                "example": "sunny day at the beach",
              },
              "type": "JSON",
            },
            "type": "INLINE",
          },
          {
            "inlineDocumentSource": {
              "jsonDocument": {
                "example": "rainy day in the city",
              },
              "type": "JSON",
            },
            "type": "INLINE",
          },
        ],
      }
    `);
  });

  it('should send request with the correct headers', async () => {
    expect(server.calls[0].requestHeaders).toMatchInlineSnapshot(`
      {
        "config-header": "config-value",
        "content-type": "application/json",
        "shared-header": "config-shared",
        "x-amz-auth": "test-auth",
      }
    `);
  });

  // `warnings` is expected to be absent, matching the text-documents suite.
  it('should return result without warnings', async () => {
    expect(result.warnings).toMatchInlineSnapshot(`undefined`);
  });

  it('should return result with the correct ranking', async () => {
    expect(result.ranking).toMatchInlineSnapshot(`
      [
        {
          "index": 0,
          "relevanceScore": 0.5110583305358887,
        },
        {
          "index": 5,
          "relevanceScore": 0.30241215229034424,
        },
      ]
    `);
  });

  it('should not return provider metadata (use response body instead)', async () => {
    expect(result.providerMetadata).toMatchInlineSnapshot(`undefined`);
  });

  it('should return result with the correct response', async () => {
    expect(result.response).toMatchInlineSnapshot(`
      {
        "body": {
          "results": [
            {
              "index": 0,
              "relevanceScore": 0.5110583305358887,
            },
            {
              "index": 5,
              "relevanceScore": 0.30241215229034424,
            },
          ],
        },
        "headers": {
          "content-length": "171",
          "content-type": "application/json",
        },
      }
    `);
  });
});
|
|
169
|
+
|
|
170
|
+
// Reranking of plain-text documents: verifies the outgoing request and the
// mapped result against the `bedrock-reranking.1` fixture.
describe('text documents', () => {
  type RerankOutcome = Awaited<ReturnType<typeof model.doRerank>>;

  let outcome: RerankOutcome;

  beforeEach(async () => {
    const bedrockOptions = {
      nextToken: 'test-token',
      additionalModelRequestFields: {
        test: 'test-value',
      },
    } satisfies BedrockRerankingOptions;

    prepareJsonFixtureResponse('bedrock-reranking.1');

    outcome = await model.doRerank({
      documents: {
        type: 'text',
        values: ['sunny day at the beach', 'rainy day in the city'],
      },
      query: 'rainy day',
      topN: 2,
      providerOptions: { bedrock: bedrockOptions },
    });
  });

  it('should send request with text documents', async () => {
    const sentBody = await server.calls[0].requestBodyJson;

    expect(sentBody).toMatchInlineSnapshot(`
      {
        "nextToken": "test-token",
        "queries": [
          {
            "textQuery": {
              "text": "rainy day",
            },
            "type": "TEXT",
          },
        ],
        "rerankingConfiguration": {
          "bedrockRerankingConfiguration": {
            "modelConfiguration": {
              "additionalModelRequestFields": {
                "test": "test-value",
              },
              "modelArn": "arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0",
            },
            "numberOfResults": 2,
          },
          "type": "BEDROCK_RERANKING_MODEL",
        },
        "sources": [
          {
            "inlineDocumentSource": {
              "textDocument": {
                "text": "sunny day at the beach",
              },
              "type": "TEXT",
            },
            "type": "INLINE",
          },
          {
            "inlineDocumentSource": {
              "textDocument": {
                "text": "rainy day in the city",
              },
              "type": "TEXT",
            },
            "type": "INLINE",
          },
        ],
      }
    `);
  });

  it('should send request with the correct headers', async () => {
    const sentHeaders = server.calls[0].requestHeaders;

    expect(sentHeaders).toMatchInlineSnapshot(`
      {
        "config-header": "config-value",
        "content-type": "application/json",
        "shared-header": "config-shared",
        "x-amz-auth": "test-auth",
      }
    `);
  });

  it('should return result without warnings', async () => {
    expect(outcome.warnings).toMatchInlineSnapshot(`undefined`);
  });

  it('should return result with the correct ranking', async () => {
    expect(outcome.ranking).toMatchInlineSnapshot(`
      [
        {
          "index": 0,
          "relevanceScore": 0.5110583305358887,
        },
        {
          "index": 5,
          "relevanceScore": 0.30241215229034424,
        },
      ]
    `);
  });

  it('should not return provider metadata (use response body instead)', async () => {
    expect(outcome.providerMetadata).toMatchInlineSnapshot(`undefined`);
  });

  it('should return result with the correct response', async () => {
    expect(outcome.response).toMatchInlineSnapshot(`
      {
        "body": {
          "results": [
            {
              "index": 0,
              "relevanceScore": 0.5110583305358887,
            },
            {
              "index": 5,
              "relevanceScore": 0.30241215229034424,
            },
          ],
        },
        "headers": {
          "content-length": "171",
          "content-type": "application/json",
        },
      }
    `);
  });
});
|
|
299
|
+
});
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import { RerankingModelV3 } from '@ai-sdk/provider';
|
|
2
|
+
import {
|
|
3
|
+
FetchFunction,
|
|
4
|
+
Resolvable,
|
|
5
|
+
combineHeaders,
|
|
6
|
+
createJsonErrorResponseHandler,
|
|
7
|
+
createJsonResponseHandler,
|
|
8
|
+
parseProviderOptions,
|
|
9
|
+
postJsonToApi,
|
|
10
|
+
resolve,
|
|
11
|
+
} from '@ai-sdk/provider-utils';
|
|
12
|
+
import { BedrockErrorSchema } from '../bedrock-error';
|
|
13
|
+
import {
|
|
14
|
+
BedrockRerankingInput,
|
|
15
|
+
bedrockRerankingResponseSchema,
|
|
16
|
+
} from './bedrock-reranking-api';
|
|
17
|
+
import {
|
|
18
|
+
BedrockRerankingModelId,
|
|
19
|
+
bedrockRerankingOptionsSchema,
|
|
20
|
+
} from './bedrock-reranking-options';
|
|
21
|
+
|
|
22
|
+
/**
 * Settings supplied to `BedrockRerankingModel` at construction time.
 */
type BedrockRerankingConfig = {
  /** Returns the service base URL; `/rerank` is appended to it per request. */
  baseUrl: () => string;

  /** Region inserted into the foundation-model ARN sent with each request. */
  region: string;

  /** Headers (possibly lazily resolved) merged into every request. */
  headers: Resolvable<Record<string, string | undefined>>;

  /** Optional custom fetch implementation passed to the HTTP helper. */
  fetch?: FetchFunction;
};
|
|
28
|
+
|
|
29
|
+
/** Resolved (awaited) return value of `RerankingModelV3['doRerank']`. */
type DoRerankResponse = Awaited<
  ReturnType<RerankingModelV3['doRerank']>
>;
|
|
30
|
+
|
|
31
|
+
/**
 * `RerankingModelV3` implementation that posts rerank requests to
 * `<baseUrl>/rerank` and maps the response's `results` to a ranking.
 */
export class BedrockRerankingModel implements RerankingModelV3 {
  readonly specificationVersion = 'v3';
  readonly provider = 'amazon-bedrock';

  /**
   * @param modelId - Reranking model ID; embedded in the model ARN of each request
   * @param config - Base URL, region, headers and optional fetch implementation
   */
  constructor(
    readonly modelId: BedrockRerankingModelId,
    private readonly config: BedrockRerankingConfig,
  ) {}

  /**
   * Ranks the given documents against the query.
   *
   * Provider options are read from the `bedrock` key. Text documents are sent
   * as `TEXT` inline sources, anything else as `JSON` inline sources.
   *
   * @returns The ranking from the response plus the raw response headers/body
   */
  async doRerank(
    options: Parameters<RerankingModelV3['doRerank']>[0],
  ): Promise<DoRerankResponse> {
    const { documents, headers, query, topN, abortSignal, providerOptions } =
      options;

    const bedrockOptions = await parseProviderOptions({
      provider: 'bedrock',
      providerOptions,
      schema: bedrockRerankingOptionsSchema,
    });

    const endpoint = `${this.config.baseUrl()}/rerank`;

    // Config-level headers first, per-call headers second.
    const requestHeaders = await resolve(
      combineHeaders(await resolve(this.config.headers), headers),
    );

    // One inline source per document, in input order.
    const sources = documents.values.map(value => {
      const inlineDocumentSource =
        documents.type === 'text'
          ? {
              type: 'TEXT' as const,
              textDocument: { text: value as string },
            }
          : {
              type: 'JSON' as const,
              jsonDocument: value,
            };

      return { type: 'INLINE' as const, inlineDocumentSource };
    });

    const requestBody = {
      nextToken: bedrockOptions?.nextToken,
      queries: [
        {
          textQuery: { text: query },
          type: 'TEXT',
        },
      ],
      rerankingConfiguration: {
        bedrockRerankingConfiguration: {
          modelConfiguration: {
            modelArn: `arn:aws:bedrock:${this.config.region}::foundation-model/${this.modelId}`,
            additionalModelRequestFields:
              bedrockOptions?.additionalModelRequestFields,
          },
          numberOfResults: topN,
        },
        type: 'BEDROCK_RERANKING_MODEL',
      },
      sources,
    } satisfies BedrockRerankingInput;

    const apiResult = await postJsonToApi({
      url: endpoint,
      headers: requestHeaders,
      body: requestBody,
      failedResponseHandler: createJsonErrorResponseHandler({
        errorSchema: BedrockErrorSchema,
        errorToMessage: error => `${error.type}: ${error.message}`,
      }),
      successfulResponseHandler: createJsonResponseHandler(
        bedrockRerankingResponseSchema,
      ),
      fetch: this.config.fetch,
      abortSignal,
    });

    return {
      ranking: apiResult.value.results,
      response: {
        headers: apiResult.responseHeaders,
        body: apiResult.rawValue,
      },
    };
  }
}
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import { lazySchema, zodSchema } from '@ai-sdk/provider-utils';
|
|
2
|
+
import { z } from 'zod/v4';
|
|
3
|
+
|
|
4
|
+
// https://docs.aws.amazon.com/bedrock/latest/userguide/rerank-supported.html
/**
 * Reranking model identifier: the listed IDs are offered as literal
 * suggestions, and any other string is also accepted.
 */
export type BedrockRerankingModelId =
  | 'amazon.rerank-v1:0'
  | 'cohere.rerank-v3-5:0'
  | (string & {});
|
|
9
|
+
|
|
10
|
+
/**
 * Provider-specific options for Bedrock reranking calls.
 */
export type BedrockRerankingOptions = {
  /**
   * Pagination token. When a response could not hold all results, it carries
   * a `nextToken`; pass that value here to request the next batch.
   */
  nextToken?: string;

  /**
   * Extra fields forwarded to the model as part of the request.
   */
  additionalModelRequestFields?: Record<string, unknown>;
};
|
|
21
|
+
|
|
22
|
+
/**
 * Lazily constructed validation schema for the Bedrock reranking provider
 * options; both fields are optional.
 */
export const bedrockRerankingOptionsSchema = lazySchema(() => {
  const optionsShape = z.object({
    // Pagination token: pass the `nextToken` from a previous response to
    // request the next batch of results.
    nextToken: z.string().optional(),

    // Extra fields forwarded to the model as part of the request.
    additionalModelRequestFields: z.record(z.string(), z.any()).optional(),
  });

  return zodSchema(optionsShape);
});
|