@ai-sdk/cohere 1.2.10 → 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +254 -20
- package/dist/index.d.mts +6 -28
- package/dist/index.d.ts +6 -28
- package/dist/index.js +143 -144
- package/dist/index.js.map +1 -1
- package/dist/index.mjs +144 -144
- package/dist/index.mjs.map +1 -1
- package/package.json +13 -12
package/dist/index.mjs
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
// src/cohere-provider.ts
|
|
2
|
+
import {
|
|
3
|
+
NoSuchModelError
|
|
4
|
+
} from "@ai-sdk/provider";
|
|
2
5
|
import {
|
|
3
6
|
loadApiKey,
|
|
4
7
|
withoutTrailingSlash
|
|
5
8
|
} from "@ai-sdk/provider-utils";
|
|
6
9
|
|
|
7
10
|
// src/cohere-chat-language-model.ts
|
|
8
|
-
import {
|
|
9
|
-
UnsupportedFunctionalityError as UnsupportedFunctionalityError3
|
|
10
|
-
} from "@ai-sdk/provider";
|
|
11
11
|
import {
|
|
12
12
|
combineHeaders,
|
|
13
13
|
createEventSourceResponseHandler,
|
|
@@ -47,9 +47,9 @@ function convertToCohereChatPrompt(prompt) {
|
|
|
47
47
|
case "text": {
|
|
48
48
|
return part.text;
|
|
49
49
|
}
|
|
50
|
-
case "
|
|
50
|
+
case "file": {
|
|
51
51
|
throw new UnsupportedFunctionalityError({
|
|
52
|
-
functionality: "
|
|
52
|
+
functionality: "File URL data"
|
|
53
53
|
});
|
|
54
54
|
}
|
|
55
55
|
}
|
|
@@ -127,9 +127,11 @@ function mapCohereFinishReason(finishReason) {
|
|
|
127
127
|
import {
|
|
128
128
|
UnsupportedFunctionalityError as UnsupportedFunctionalityError2
|
|
129
129
|
} from "@ai-sdk/provider";
|
|
130
|
-
function prepareTools(
|
|
131
|
-
|
|
132
|
-
|
|
130
|
+
function prepareTools({
|
|
131
|
+
tools,
|
|
132
|
+
toolChoice
|
|
133
|
+
}) {
|
|
134
|
+
tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
|
|
133
135
|
const toolWarnings = [];
|
|
134
136
|
if (tools == null) {
|
|
135
137
|
return { tools: void 0, toolChoice: void 0, toolWarnings };
|
|
@@ -149,7 +151,6 @@ function prepareTools(mode) {
|
|
|
149
151
|
});
|
|
150
152
|
}
|
|
151
153
|
}
|
|
152
|
-
const toolChoice = mode.toolChoice;
|
|
153
154
|
if (toolChoice == null) {
|
|
154
155
|
return { tools: cohereTools, toolChoice: void 0, toolWarnings };
|
|
155
156
|
}
|
|
@@ -172,7 +173,7 @@ function prepareTools(mode) {
|
|
|
172
173
|
default: {
|
|
173
174
|
const _exhaustiveCheck = type;
|
|
174
175
|
throw new UnsupportedFunctionalityError2({
|
|
175
|
-
functionality: `
|
|
176
|
+
functionality: `tool choice type: ${_exhaustiveCheck}`
|
|
176
177
|
});
|
|
177
178
|
}
|
|
178
179
|
}
|
|
@@ -180,20 +181,20 @@ function prepareTools(mode) {
|
|
|
180
181
|
|
|
181
182
|
// src/cohere-chat-language-model.ts
|
|
182
183
|
var CohereChatLanguageModel = class {
|
|
183
|
-
constructor(modelId,
|
|
184
|
-
this.specificationVersion = "
|
|
185
|
-
this.
|
|
184
|
+
constructor(modelId, config) {
|
|
185
|
+
this.specificationVersion = "v2";
|
|
186
|
+
this.supportedUrls = {
|
|
187
|
+
// No URLs are supported.
|
|
188
|
+
};
|
|
186
189
|
this.modelId = modelId;
|
|
187
|
-
this.settings = settings;
|
|
188
190
|
this.config = config;
|
|
189
191
|
}
|
|
190
192
|
get provider() {
|
|
191
193
|
return this.config.provider;
|
|
192
194
|
}
|
|
193
195
|
getArgs({
|
|
194
|
-
mode,
|
|
195
196
|
prompt,
|
|
196
|
-
|
|
197
|
+
maxOutputTokens,
|
|
197
198
|
temperature,
|
|
198
199
|
topP,
|
|
199
200
|
topK,
|
|
@@ -201,78 +202,42 @@ var CohereChatLanguageModel = class {
|
|
|
201
202
|
presencePenalty,
|
|
202
203
|
stopSequences,
|
|
203
204
|
responseFormat,
|
|
204
|
-
seed
|
|
205
|
+
seed,
|
|
206
|
+
tools,
|
|
207
|
+
toolChoice
|
|
205
208
|
}) {
|
|
206
|
-
var _a;
|
|
207
|
-
const type = mode.type;
|
|
208
209
|
const chatPrompt = convertToCohereChatPrompt(prompt);
|
|
209
|
-
const
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
210
|
+
const {
|
|
211
|
+
tools: cohereTools,
|
|
212
|
+
toolChoice: cohereToolChoice,
|
|
213
|
+
toolWarnings
|
|
214
|
+
} = prepareTools({ tools, toolChoice });
|
|
215
|
+
return {
|
|
216
|
+
args: {
|
|
217
|
+
// model id:
|
|
218
|
+
model: this.modelId,
|
|
219
|
+
// standardized settings:
|
|
220
|
+
frequency_penalty: frequencyPenalty,
|
|
221
|
+
presence_penalty: presencePenalty,
|
|
222
|
+
max_tokens: maxOutputTokens,
|
|
223
|
+
temperature,
|
|
224
|
+
p: topP,
|
|
225
|
+
k: topK,
|
|
226
|
+
seed,
|
|
227
|
+
stop_sequences: stopSequences,
|
|
228
|
+
// response format:
|
|
229
|
+
response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? { type: "json_object", json_schema: responseFormat.schema } : void 0,
|
|
230
|
+
// messages:
|
|
231
|
+
messages: chatPrompt,
|
|
232
|
+
// tools:
|
|
233
|
+
tools: cohereTools,
|
|
234
|
+
tool_choice: cohereToolChoice
|
|
235
|
+
},
|
|
236
|
+
warnings: toolWarnings
|
|
225
237
|
};
|
|
226
|
-
switch (type) {
|
|
227
|
-
case "regular": {
|
|
228
|
-
const { tools, toolChoice, toolWarnings } = prepareTools(mode);
|
|
229
|
-
return {
|
|
230
|
-
args: {
|
|
231
|
-
...baseArgs,
|
|
232
|
-
tools,
|
|
233
|
-
tool_choice: toolChoice
|
|
234
|
-
},
|
|
235
|
-
warnings: toolWarnings
|
|
236
|
-
};
|
|
237
|
-
}
|
|
238
|
-
case "object-json": {
|
|
239
|
-
return {
|
|
240
|
-
args: {
|
|
241
|
-
...baseArgs,
|
|
242
|
-
response_format: mode.schema == null ? { type: "json_object" } : { type: "json_object", json_schema: mode.schema }
|
|
243
|
-
},
|
|
244
|
-
warnings: []
|
|
245
|
-
};
|
|
246
|
-
}
|
|
247
|
-
case "object-tool": {
|
|
248
|
-
return {
|
|
249
|
-
args: {
|
|
250
|
-
...baseArgs,
|
|
251
|
-
tools: [
|
|
252
|
-
{
|
|
253
|
-
type: "function",
|
|
254
|
-
function: {
|
|
255
|
-
name: mode.tool.name,
|
|
256
|
-
description: (_a = mode.tool.description) != null ? _a : "",
|
|
257
|
-
parameters: mode.tool.parameters
|
|
258
|
-
}
|
|
259
|
-
}
|
|
260
|
-
],
|
|
261
|
-
tool_choice: "REQUIRED"
|
|
262
|
-
},
|
|
263
|
-
warnings: []
|
|
264
|
-
};
|
|
265
|
-
}
|
|
266
|
-
default: {
|
|
267
|
-
const _exhaustiveCheck = type;
|
|
268
|
-
throw new UnsupportedFunctionalityError3({
|
|
269
|
-
functionality: `Unsupported mode: ${_exhaustiveCheck}`
|
|
270
|
-
});
|
|
271
|
-
}
|
|
272
|
-
}
|
|
273
238
|
}
|
|
274
239
|
async doGenerate(options) {
|
|
275
|
-
var _a, _b, _c
|
|
240
|
+
var _a, _b, _c;
|
|
276
241
|
const { args, warnings } = this.getArgs(options);
|
|
277
242
|
const {
|
|
278
243
|
responseHeaders,
|
|
@@ -289,38 +254,40 @@ var CohereChatLanguageModel = class {
|
|
|
289
254
|
abortSignal: options.abortSignal,
|
|
290
255
|
fetch: this.config.fetch
|
|
291
256
|
});
|
|
292
|
-
const
|
|
293
|
-
const text = (
|
|
257
|
+
const content = [];
|
|
258
|
+
const text = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text;
|
|
259
|
+
if (text != null && text.length > 0) {
|
|
260
|
+
content.push({ type: "text", text });
|
|
261
|
+
}
|
|
262
|
+
if (response.message.tool_calls != null) {
|
|
263
|
+
for (const toolCall of response.message.tool_calls) {
|
|
264
|
+
content.push({
|
|
265
|
+
type: "tool-call",
|
|
266
|
+
toolCallId: toolCall.id,
|
|
267
|
+
toolName: toolCall.function.name,
|
|
268
|
+
// Cohere sometimes returns `null` for tool call arguments for tools
|
|
269
|
+
// defined as having no arguments.
|
|
270
|
+
args: toolCall.function.arguments.replace(/^null$/, "{}"),
|
|
271
|
+
toolCallType: "function"
|
|
272
|
+
});
|
|
273
|
+
}
|
|
274
|
+
}
|
|
294
275
|
return {
|
|
295
|
-
|
|
296
|
-
toolCalls: response.message.tool_calls ? response.message.tool_calls.map((toolCall) => ({
|
|
297
|
-
toolCallId: toolCall.id,
|
|
298
|
-
toolName: toolCall.function.name,
|
|
299
|
-
// Cohere sometimes returns `null` for tool call arguments for tools
|
|
300
|
-
// defined as having no arguments.
|
|
301
|
-
args: toolCall.function.arguments.replace(/^null$/, "{}"),
|
|
302
|
-
toolCallType: "function"
|
|
303
|
-
})) : [],
|
|
276
|
+
content,
|
|
304
277
|
finishReason: mapCohereFinishReason(response.finish_reason),
|
|
305
278
|
usage: {
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
rawCall: {
|
|
310
|
-
rawPrompt: {
|
|
311
|
-
messages
|
|
312
|
-
},
|
|
313
|
-
rawSettings
|
|
279
|
+
inputTokens: response.usage.tokens.input_tokens,
|
|
280
|
+
outputTokens: response.usage.tokens.output_tokens,
|
|
281
|
+
totalTokens: response.usage.tokens.input_tokens + response.usage.tokens.output_tokens
|
|
314
282
|
},
|
|
283
|
+
request: { body: args },
|
|
315
284
|
response: {
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
rawResponse: {
|
|
285
|
+
// TODO timestamp, model id
|
|
286
|
+
id: (_c = response.generation_id) != null ? _c : void 0,
|
|
319
287
|
headers: responseHeaders,
|
|
320
288
|
body: rawResponse
|
|
321
289
|
},
|
|
322
|
-
warnings
|
|
323
|
-
request: { body: JSON.stringify(args) }
|
|
290
|
+
warnings
|
|
324
291
|
};
|
|
325
292
|
}
|
|
326
293
|
async doStream(options) {
|
|
@@ -336,11 +303,11 @@ var CohereChatLanguageModel = class {
|
|
|
336
303
|
abortSignal: options.abortSignal,
|
|
337
304
|
fetch: this.config.fetch
|
|
338
305
|
});
|
|
339
|
-
const { messages, ...rawSettings } = args;
|
|
340
306
|
let finishReason = "unknown";
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
307
|
+
const usage = {
|
|
308
|
+
inputTokens: void 0,
|
|
309
|
+
outputTokens: void 0,
|
|
310
|
+
totalTokens: void 0
|
|
344
311
|
};
|
|
345
312
|
let pendingToolCallDelta = {
|
|
346
313
|
toolCallId: "",
|
|
@@ -350,6 +317,9 @@ var CohereChatLanguageModel = class {
|
|
|
350
317
|
return {
|
|
351
318
|
stream: response.pipeThrough(
|
|
352
319
|
new TransformStream({
|
|
320
|
+
start(controller) {
|
|
321
|
+
controller.enqueue({ type: "stream-start", warnings });
|
|
322
|
+
},
|
|
353
323
|
transform(chunk, controller) {
|
|
354
324
|
var _a, _b;
|
|
355
325
|
if (!chunk.success) {
|
|
@@ -362,8 +332,8 @@ var CohereChatLanguageModel = class {
|
|
|
362
332
|
switch (type) {
|
|
363
333
|
case "content-delta": {
|
|
364
334
|
controller.enqueue({
|
|
365
|
-
type: "text
|
|
366
|
-
|
|
335
|
+
type: "text",
|
|
336
|
+
text: value.delta.message.content.text
|
|
367
337
|
});
|
|
368
338
|
return;
|
|
369
339
|
}
|
|
@@ -422,10 +392,9 @@ var CohereChatLanguageModel = class {
|
|
|
422
392
|
case "message-end": {
|
|
423
393
|
finishReason = mapCohereFinishReason(value.delta.finish_reason);
|
|
424
394
|
const tokens = value.delta.usage.tokens;
|
|
425
|
-
usage =
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
};
|
|
395
|
+
usage.inputTokens = tokens.input_tokens;
|
|
396
|
+
usage.outputTokens = tokens.output_tokens;
|
|
397
|
+
usage.totalTokens = tokens.input_tokens + tokens.output_tokens;
|
|
429
398
|
}
|
|
430
399
|
default: {
|
|
431
400
|
return;
|
|
@@ -441,15 +410,8 @@ var CohereChatLanguageModel = class {
|
|
|
441
410
|
}
|
|
442
411
|
})
|
|
443
412
|
),
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
messages
|
|
447
|
-
},
|
|
448
|
-
rawSettings
|
|
449
|
-
},
|
|
450
|
-
rawResponse: { headers: responseHeaders },
|
|
451
|
-
warnings,
|
|
452
|
-
request: { body: JSON.stringify({ ...args, stream: true }) }
|
|
413
|
+
request: { body: { ...args, stream: true } },
|
|
414
|
+
response: { headers: responseHeaders }
|
|
453
415
|
};
|
|
454
416
|
}
|
|
455
417
|
};
|
|
@@ -577,16 +539,41 @@ import {
|
|
|
577
539
|
import {
|
|
578
540
|
combineHeaders as combineHeaders2,
|
|
579
541
|
createJsonResponseHandler as createJsonResponseHandler2,
|
|
542
|
+
parseProviderOptions,
|
|
580
543
|
postJsonToApi as postJsonToApi2
|
|
581
544
|
} from "@ai-sdk/provider-utils";
|
|
545
|
+
import { z as z4 } from "zod";
|
|
546
|
+
|
|
547
|
+
// src/cohere-embedding-options.ts
|
|
582
548
|
import { z as z3 } from "zod";
|
|
549
|
+
var cohereEmbeddingOptions = z3.object({
|
|
550
|
+
/**
|
|
551
|
+
* Specifies the type of input passed to the model. Default is `search_query`.
|
|
552
|
+
*
|
|
553
|
+
* - "search_document": Used for embeddings stored in a vector database for search use-cases.
|
|
554
|
+
* - "search_query": Used for embeddings of search queries run against a vector DB to find relevant documents.
|
|
555
|
+
* - "classification": Used for embeddings passed through a text classifier.
|
|
556
|
+
* - "clustering": Used for embeddings run through a clustering algorithm.
|
|
557
|
+
*/
|
|
558
|
+
inputType: z3.enum(["search_document", "search_query", "classification", "clustering"]).optional(),
|
|
559
|
+
/**
|
|
560
|
+
* Specifies how the API will handle inputs longer than the maximum token length.
|
|
561
|
+
* Default is `END`.
|
|
562
|
+
*
|
|
563
|
+
* - "NONE": If selected, when the input exceeds the maximum input token length will return an error.
|
|
564
|
+
* - "START": Will discard the start of the input until the remaining input is exactly the maximum input token length for the model.
|
|
565
|
+
* - "END": Will discard the end of the input until the remaining input is exactly the maximum input token length for the model.
|
|
566
|
+
*/
|
|
567
|
+
truncate: z3.enum(["NONE", "START", "END"]).optional()
|
|
568
|
+
});
|
|
569
|
+
|
|
570
|
+
// src/cohere-embedding-model.ts
|
|
583
571
|
var CohereEmbeddingModel = class {
|
|
584
|
-
constructor(modelId,
|
|
585
|
-
this.specificationVersion = "
|
|
572
|
+
constructor(modelId, config) {
|
|
573
|
+
this.specificationVersion = "v2";
|
|
586
574
|
this.maxEmbeddingsPerCall = 96;
|
|
587
575
|
this.supportsParallelCalls = true;
|
|
588
576
|
this.modelId = modelId;
|
|
589
|
-
this.settings = settings;
|
|
590
577
|
this.config = config;
|
|
591
578
|
}
|
|
592
579
|
get provider() {
|
|
@@ -595,9 +582,15 @@ var CohereEmbeddingModel = class {
|
|
|
595
582
|
async doEmbed({
|
|
596
583
|
values,
|
|
597
584
|
headers,
|
|
598
|
-
abortSignal
|
|
585
|
+
abortSignal,
|
|
586
|
+
providerOptions
|
|
599
587
|
}) {
|
|
600
588
|
var _a;
|
|
589
|
+
const embeddingOptions = await parseProviderOptions({
|
|
590
|
+
provider: "cohere",
|
|
591
|
+
providerOptions,
|
|
592
|
+
schema: cohereEmbeddingOptions
|
|
593
|
+
});
|
|
601
594
|
if (values.length > this.maxEmbeddingsPerCall) {
|
|
602
595
|
throw new TooManyEmbeddingValuesForCallError({
|
|
603
596
|
provider: this.provider,
|
|
@@ -606,7 +599,11 @@ var CohereEmbeddingModel = class {
|
|
|
606
599
|
values
|
|
607
600
|
});
|
|
608
601
|
}
|
|
609
|
-
const {
|
|
602
|
+
const {
|
|
603
|
+
responseHeaders,
|
|
604
|
+
value: response,
|
|
605
|
+
rawValue
|
|
606
|
+
} = await postJsonToApi2({
|
|
610
607
|
url: `${this.config.baseURL}/embed`,
|
|
611
608
|
headers: combineHeaders2(this.config.headers(), headers),
|
|
612
609
|
body: {
|
|
@@ -616,8 +613,8 @@ var CohereEmbeddingModel = class {
|
|
|
616
613
|
// https://docs.cohere.com/v2/reference/embed#request.body.embedding_types
|
|
617
614
|
embedding_types: ["float"],
|
|
618
615
|
texts: values,
|
|
619
|
-
input_type: (_a =
|
|
620
|
-
truncate:
|
|
616
|
+
input_type: (_a = embeddingOptions == null ? void 0 : embeddingOptions.inputType) != null ? _a : "search_query",
|
|
617
|
+
truncate: embeddingOptions == null ? void 0 : embeddingOptions.truncate
|
|
621
618
|
},
|
|
622
619
|
failedResponseHandler: cohereFailedResponseHandler,
|
|
623
620
|
successfulResponseHandler: createJsonResponseHandler2(
|
|
@@ -629,17 +626,17 @@ var CohereEmbeddingModel = class {
|
|
|
629
626
|
return {
|
|
630
627
|
embeddings: response.embeddings.float,
|
|
631
628
|
usage: { tokens: response.meta.billed_units.input_tokens },
|
|
632
|
-
|
|
629
|
+
response: { headers: responseHeaders, body: rawValue }
|
|
633
630
|
};
|
|
634
631
|
}
|
|
635
632
|
};
|
|
636
|
-
var cohereTextEmbeddingResponseSchema =
|
|
637
|
-
embeddings:
|
|
638
|
-
float:
|
|
633
|
+
var cohereTextEmbeddingResponseSchema = z4.object({
|
|
634
|
+
embeddings: z4.object({
|
|
635
|
+
float: z4.array(z4.array(z4.number()))
|
|
639
636
|
}),
|
|
640
|
-
meta:
|
|
641
|
-
billed_units:
|
|
642
|
-
input_tokens:
|
|
637
|
+
meta: z4.object({
|
|
638
|
+
billed_units: z4.object({
|
|
639
|
+
input_tokens: z4.number()
|
|
643
640
|
})
|
|
644
641
|
})
|
|
645
642
|
});
|
|
@@ -656,29 +653,32 @@ function createCohere(options = {}) {
|
|
|
656
653
|
})}`,
|
|
657
654
|
...options.headers
|
|
658
655
|
});
|
|
659
|
-
const createChatModel = (modelId
|
|
656
|
+
const createChatModel = (modelId) => new CohereChatLanguageModel(modelId, {
|
|
660
657
|
provider: "cohere.chat",
|
|
661
658
|
baseURL,
|
|
662
659
|
headers: getHeaders,
|
|
663
660
|
fetch: options.fetch
|
|
664
661
|
});
|
|
665
|
-
const createTextEmbeddingModel = (modelId
|
|
662
|
+
const createTextEmbeddingModel = (modelId) => new CohereEmbeddingModel(modelId, {
|
|
666
663
|
provider: "cohere.textEmbedding",
|
|
667
664
|
baseURL,
|
|
668
665
|
headers: getHeaders,
|
|
669
666
|
fetch: options.fetch
|
|
670
667
|
});
|
|
671
|
-
const provider = function(modelId
|
|
668
|
+
const provider = function(modelId) {
|
|
672
669
|
if (new.target) {
|
|
673
670
|
throw new Error(
|
|
674
671
|
"The Cohere model function cannot be called with the new keyword."
|
|
675
672
|
);
|
|
676
673
|
}
|
|
677
|
-
return createChatModel(modelId
|
|
674
|
+
return createChatModel(modelId);
|
|
678
675
|
};
|
|
679
676
|
provider.languageModel = createChatModel;
|
|
680
677
|
provider.embedding = createTextEmbeddingModel;
|
|
681
678
|
provider.textEmbeddingModel = createTextEmbeddingModel;
|
|
679
|
+
provider.imageModel = (modelId) => {
|
|
680
|
+
throw new NoSuchModelError({ modelId, modelType: "imageModel" });
|
|
681
|
+
};
|
|
682
682
|
return provider;
|
|
683
683
|
}
|
|
684
684
|
var cohere = createCohere();
|