@ai-sdk/cohere 3.0.8 → 3.0.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,243 @@
1
+ import { createTestServer } from '@ai-sdk/test-server/with-vitest';
2
+ import { createCohere } from '../cohere-provider';
3
+ import { describe, it, expect, beforeEach } from 'vitest';
4
+ import fs from 'node:fs';
5
+ import { CohereRerankingOptions } from './cohere-reranking-options';
6
+
7
// Provider/model under test. The API key is a dummy value; requests are
// answered by the test server created inside `describe('doRerank')`.
const provider = createCohere({ apiKey: 'test-api-key' });
const model = provider.rerankingModel('rerank-english-v3.0');

/**
 * Tests for `doRerank` against a mocked `https://api.cohere.com/v2/rerank`
 * endpoint.
 *
 * Each nested `describe` performs one rerank call in `beforeEach`; the
 * individual tests then inspect either the recorded request
 * (`server.calls[0]`) or the returned `result`.
 */
describe('doRerank', () => {
  const server = createTestServer({
    'https://api.cohere.com/v2/rerank': {},
  });

  /**
   * Makes the mocked rerank endpoint respond with the parsed contents of
   * `src/reranking/__fixtures__/<filename>.json`.
   *
   * NOTE(review): the fixture path is relative, so `fs.readFileSync` resolves
   * it against the process working directory; the tests assume they are run
   * from the package root. The trailing `return;` is redundant.
   *
   * @param filename fixture name without the `.json` extension
   */
  function prepareJsonFixtureResponse(filename: string) {
    server.urls['https://api.cohere.com/v2/rerank'].response = {
      type: 'json-value',
      body: JSON.parse(
        fs.readFileSync(`src/reranking/__fixtures__/${filename}.json`, 'utf8'),
      ),
    };
    return;
  }

  // Object documents: expected to be JSON-stringified in the request body and
  // to produce a single `compatibility` warning.
  describe('json documents', () => {
    let result: Awaited<ReturnType<typeof model.doRerank>>;

    // One rerank call per test. NOTE(review): reading `server.calls[0]` in
    // every test presumes `createTestServer` resets recorded calls between
    // tests — confirm in @ai-sdk/test-server.
    beforeEach(async () => {
      prepareJsonFixtureResponse('cohere-reranking.1');

      result = await model.doRerank({
        documents: {
          type: 'object',
          values: [
            { example: 'sunny day at the beach' },
            { example: 'rainy day in the city' },
          ],
        },
        query: 'rainy day',
        topN: 2,
        providerOptions: {
          cohere: {
            maxTokensPerDoc: 1000,
            priority: 1,
          } satisfies CohereRerankingOptions,
        },
      });
    });

    it('should send request with stringified json documents', async () => {
      expect(await server.calls[0].requestBodyJson).toMatchInlineSnapshot(`
        {
          "documents": [
            "{"example":"sunny day at the beach"}",
            "{"example":"rainy day in the city"}",
          ],
          "max_tokens_per_doc": 1000,
          "model": "rerank-english-v3.0",
          "priority": 1,
          "query": "rainy day",
          "top_n": 2,
        }
      `);
    });

    it('should send request with the correct headers', async () => {
      expect(server.calls[0].requestHeaders).toMatchInlineSnapshot(`
        {
          "authorization": "Bearer test-api-key",
          "content-type": "application/json",
        }
      `);
    });

    it('should return result with warnings', async () => {
      expect(result.warnings).toMatchInlineSnapshot(`
        [
          {
            "details": "Object documents are converted to strings.",
            "feature": "object documents",
            "type": "compatibility",
          },
        ]
      `);
    });

    // Snake_case `relevance_score` from the fixture is exposed as camelCase
    // `relevanceScore`, keeping the API's result order.
    it('should return result with the correct ranking', async () => {
      expect(result.ranking).toMatchInlineSnapshot(`
        [
          {
            "index": 1,
            "relevanceScore": 0.10183054,
          },
          {
            "index": 0,
            "relevanceScore": 0.03762639,
          },
        ]
      `);
    });

    it('should not return provider metadata (use response body instead)', async () => {
      expect(result.providerMetadata).toMatchInlineSnapshot(`undefined`);
    });

    // The raw response body is passed through unchanged alongside the id and
    // the response headers.
    it('should return result with the correct response', async () => {
      expect(result.response).toMatchInlineSnapshot(`
        {
          "body": {
            "id": "b44fe75b-e3d3-489a-b61e-1a1aede3ef72",
            "meta": {
              "api_version": {
                "version": "2",
              },
              "billed_units": {
                "search_units": 1,
              },
            },
            "results": [
              {
                "index": 1,
                "relevance_score": 0.10183054,
              },
              {
                "index": 0,
                "relevance_score": 0.03762639,
              },
            ],
          },
          "headers": {
            "content-length": "212",
            "content-type": "application/json",
          },
          "id": "b44fe75b-e3d3-489a-b61e-1a1aede3ef72",
        }
      `);
    });
  });

  // Text documents: expected to be sent unchanged and to produce no warnings.
  describe('text documents', () => {
    let result: Awaited<ReturnType<typeof model.doRerank>>;

    beforeEach(async () => {
      prepareJsonFixtureResponse('cohere-reranking.1');

      result = await model.doRerank({
        documents: {
          type: 'text',
          values: ['sunny day at the beach', 'rainy day in the city'],
        },
        query: 'rainy day',
        topN: 2,
        providerOptions: {
          cohere: {
            maxTokensPerDoc: 1000,
            priority: 1,
          } satisfies CohereRerankingOptions,
        },
      });
    });

    it('should send request with text documents', async () => {
      expect(await server.calls[0].requestBodyJson).toMatchInlineSnapshot(`
        {
          "documents": [
            "sunny day at the beach",
            "rainy day in the city",
          ],
          "max_tokens_per_doc": 1000,
          "model": "rerank-english-v3.0",
          "priority": 1,
          "query": "rainy day",
          "top_n": 2,
        }
      `);
    });

    it('should send request with the correct headers', async () => {
      expect(server.calls[0].requestHeaders).toMatchInlineSnapshot(`
        {
          "authorization": "Bearer test-api-key",
          "content-type": "application/json",
        }
      `);
    });

    it('should return result without warnings', async () => {
      expect(result.warnings).toMatchInlineSnapshot(`[]`);
    });

    it('should return result with the correct ranking', async () => {
      expect(result.ranking).toMatchInlineSnapshot(`
        [
          {
            "index": 1,
            "relevanceScore": 0.10183054,
          },
          {
            "index": 0,
            "relevanceScore": 0.03762639,
          },
        ]
      `);
    });

    it('should not return provider metadata (use response body instead)', async () => {
      expect(result.providerMetadata).toMatchInlineSnapshot(`undefined`);
    });

    it('should return result with the correct response', async () => {
      expect(result.response).toMatchInlineSnapshot(`
        {
          "body": {
            "id": "b44fe75b-e3d3-489a-b61e-1a1aede3ef72",
            "meta": {
              "api_version": {
                "version": "2",
              },
              "billed_units": {
                "search_units": 1,
              },
            },
            "results": [
              {
                "index": 1,
                "relevance_score": 0.10183054,
              },
              {
                "index": 0,
                "relevance_score": 0.03762639,
              },
            ],
          },
          "headers": {
            "content-length": "212",
            "content-type": "application/json",
          },
          "id": "b44fe75b-e3d3-489a-b61e-1a1aede3ef72",
        }
      `);
    });
  });
});
@@ -0,0 +1,107 @@
1
+ import { RerankingModelV3, SharedV3Warning } from '@ai-sdk/provider';
2
+ import {
3
+ combineHeaders,
4
+ createJsonResponseHandler,
5
+ FetchFunction,
6
+ parseProviderOptions,
7
+ postJsonToApi,
8
+ } from '@ai-sdk/provider-utils';
9
+ import { cohereFailedResponseHandler } from '../cohere-error';
10
+ import {
11
+ CohereRerankingInput,
12
+ cohereRerankingResponseSchema,
13
+ } from './cohere-reranking-api';
14
+ import {
15
+ CohereRerankingModelId,
16
+ cohereRerankingOptionsSchema,
17
+ } from './cohere-reranking-options';
18
+
19
/**
 * Construction-time settings for `CohereRerankingModel`.
 */
interface CohereRerankingConfig {
  /** Identifier returned by the model's `provider` getter. */
  provider: string;
  /** Base URL; rerank requests are posted to `${baseURL}/rerank`. */
  baseURL: string;
  /** Produces default request headers; merged with per-call headers. */
  headers: () => Record<string, string | undefined>;
  /** Optional fetch implementation forwarded to the HTTP helper. */
  fetch?: FetchFunction;
}
25
+
26
/**
 * Reranking model that calls Cohere's `/rerank` endpoint.
 *
 * Text documents are forwarded unchanged. Every other document type is sent
 * as JSON strings, and object documents additionally yield a `compatibility`
 * warning in the result.
 */
export class CohereRerankingModel implements RerankingModelV3 {
  readonly specificationVersion = 'v3';
  readonly modelId: CohereRerankingModelId;

  private readonly config: CohereRerankingConfig;

  constructor(modelId: CohereRerankingModelId, config: CohereRerankingConfig) {
    this.modelId = modelId;
    this.config = config;
  }

  /** Provider identifier, taken verbatim from the configuration. */
  get provider(): string {
    return this.config.provider;
  }

  /**
   * Reranks `documents` against `query` with a single POST to
   * `${baseURL}/rerank`.
   *
   * Based on v2 of the API: https://docs.cohere.com/v2/reference/rerank
   *
   * @returns the ranking (index and relevance score per result, in response
   *   order), any warnings, and the response id, headers and raw body.
   */
  async doRerank({
    documents,
    headers,
    query,
    topN,
    abortSignal,
    providerOptions,
  }: Parameters<RerankingModelV3['doRerank']>[0]): Promise<
    Awaited<ReturnType<RerankingModelV3['doRerank']>>
  > {
    // Cohere-specific settings from `providerOptions.cohere` (may be absent).
    const cohereOptions = await parseProviderOptions({
      provider: 'cohere',
      providerOptions,
      schema: cohereRerankingOptionsSchema,
    });

    // Object documents get stringified for the request; surface that to the caller.
    const warnings: SharedV3Warning[] =
      documents.type === 'object'
        ? [
            {
              type: 'compatibility',
              feature: 'object documents',
              details: 'Object documents are converted to strings.',
            },
          ]
        : [];

    const {
      responseHeaders,
      value: rerankResponse,
      rawValue,
    } = await postJsonToApi({
      url: `${this.config.baseURL}/rerank`,
      headers: combineHeaders(this.config.headers(), headers),
      body: {
        model: this.modelId,
        query,
        documents: this.toRequestDocuments(documents),
        top_n: topN,
        max_tokens_per_doc: cohereOptions?.maxTokensPerDoc,
        priority: cohereOptions?.priority,
      } satisfies CohereRerankingInput,
      failedResponseHandler: cohereFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        cohereRerankingResponseSchema,
      ),
      abortSignal,
      fetch: this.config.fetch,
    });

    // Translate the API's snake_case score into the SDK's camelCase field.
    const ranking = rerankResponse.results.map(
      ({ index, relevance_score }) => ({
        index,
        relevanceScore: relevance_score,
      }),
    );

    return {
      ranking,
      warnings,
      response: {
        id: rerankResponse.id ?? undefined,
        headers: responseHeaders,
        body: rawValue,
      },
    };
  }

  /**
   * Builds the request's document list: text values are used as-is, every
   * other document type is JSON-stringified value by value.
   */
  private toRequestDocuments(
    documents: Parameters<RerankingModelV3['doRerank']>[0]['documents'],
  ) {
    if (documents.type === 'text') {
      return documents.values;
    }
    return documents.values.map(value => JSON.stringify(value));
  }
}
@@ -0,0 +1,35 @@
1
+ import { FlexibleSchema, lazySchema, zodSchema } from '@ai-sdk/provider-utils';
2
+ import { z } from 'zod/v4';
3
+
4
/**
 * Identifier of a Cohere reranking model.
 *
 * The literal members provide editor autocompletion for known model ids; the
 * `(string & {})` member keeps the type open so any other id string is also
 * accepted.
 *
 * @see https://docs.cohere.com/docs/rerank
 */
export type CohereRerankingModelId =
  | 'rerank-v3.5'
  | 'rerank-english-v3.0'
  | 'rerank-multilingual-v3.0'
  | (string & {});
10
+
11
/**
 * Cohere-specific reranking options, supplied under `providerOptions.cohere`.
 *
 * Omitted fields are not sent by this client, so the documented defaults are
 * whatever the Cohere API applies server-side.
 */
export type CohereRerankingOptions = {
  /**
   * Long documents will be automatically truncated to the specified number of
   * tokens. Sent to the API as `max_tokens_per_doc`.
   *
   * @default 4096
   */
  maxTokensPerDoc?: number;

  /**
   * The priority of the request. Sent to the API as `priority`.
   *
   * NOTE(review): Cohere's docs describe lower values as being handled earlier
   * under load — confirm against the current API reference.
   *
   * @default 0
   */
  priority?: number;
};
26
+
27
/**
 * Runtime validator for `CohereRerankingOptions`, used when parsing the
 * `cohere` entry of `providerOptions`. Wrapped in `lazySchema`, which
 * presumably defers building the zod schema until first use.
 */
export const cohereRerankingOptionsSchema: FlexibleSchema<CohereRerankingOptions> =
  lazySchema(() => {
    const optionsShape = z.object({
      maxTokensPerDoc: z.number().optional(),
      priority: z.number().optional(),
    });
    return zodSchema(optionsShape);
  });
package/src/version.ts ADDED
@@ -0,0 +1,6 @@
1
// Package version. `__PACKAGE_VERSION__` is substituted at build time; when no
// substitution happened (e.g. running tests from source) the `typeof` guard
// avoids a ReferenceError and a placeholder version is used instead.
declare const __PACKAGE_VERSION__: string | undefined;
export const VERSION: string =
  typeof __PACKAGE_VERSION__ === 'undefined'
    ? '0.0.0-test'
    : __PACKAGE_VERSION__;