@ai-sdk/cohere 1.2.10 → 2.0.0-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +189 -20
- package/dist/index.d.mts +6 -28
- package/dist/index.d.ts +6 -28
- package/dist/index.js +143 -144
- package/dist/index.js.map +1 -1
- package/dist/index.mjs +144 -144
- package/dist/index.mjs.map +1 -1
- package/package.json +13 -12
package/dist/index.js
CHANGED
|
@@ -26,10 +26,10 @@ __export(src_exports, {
|
|
|
26
26
|
module.exports = __toCommonJS(src_exports);
|
|
27
27
|
|
|
28
28
|
// src/cohere-provider.ts
|
|
29
|
+
var import_provider4 = require("@ai-sdk/provider");
|
|
29
30
|
var import_provider_utils4 = require("@ai-sdk/provider-utils");
|
|
30
31
|
|
|
31
32
|
// src/cohere-chat-language-model.ts
|
|
32
|
-
var import_provider3 = require("@ai-sdk/provider");
|
|
33
33
|
var import_provider_utils2 = require("@ai-sdk/provider-utils");
|
|
34
34
|
var import_zod2 = require("zod");
|
|
35
35
|
|
|
@@ -62,9 +62,9 @@ function convertToCohereChatPrompt(prompt) {
|
|
|
62
62
|
case "text": {
|
|
63
63
|
return part.text;
|
|
64
64
|
}
|
|
65
|
-
case "
|
|
65
|
+
case "file": {
|
|
66
66
|
throw new import_provider.UnsupportedFunctionalityError({
|
|
67
|
-
functionality: "
|
|
67
|
+
functionality: "File URL data"
|
|
68
68
|
});
|
|
69
69
|
}
|
|
70
70
|
}
|
|
@@ -140,9 +140,11 @@ function mapCohereFinishReason(finishReason) {
|
|
|
140
140
|
|
|
141
141
|
// src/cohere-prepare-tools.ts
|
|
142
142
|
var import_provider2 = require("@ai-sdk/provider");
|
|
143
|
-
function prepareTools(
|
|
144
|
-
|
|
145
|
-
|
|
143
|
+
function prepareTools({
|
|
144
|
+
tools,
|
|
145
|
+
toolChoice
|
|
146
|
+
}) {
|
|
147
|
+
tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
|
|
146
148
|
const toolWarnings = [];
|
|
147
149
|
if (tools == null) {
|
|
148
150
|
return { tools: void 0, toolChoice: void 0, toolWarnings };
|
|
@@ -162,7 +164,6 @@ function prepareTools(mode) {
|
|
|
162
164
|
});
|
|
163
165
|
}
|
|
164
166
|
}
|
|
165
|
-
const toolChoice = mode.toolChoice;
|
|
166
167
|
if (toolChoice == null) {
|
|
167
168
|
return { tools: cohereTools, toolChoice: void 0, toolWarnings };
|
|
168
169
|
}
|
|
@@ -185,7 +186,7 @@ function prepareTools(mode) {
|
|
|
185
186
|
default: {
|
|
186
187
|
const _exhaustiveCheck = type;
|
|
187
188
|
throw new import_provider2.UnsupportedFunctionalityError({
|
|
188
|
-
functionality: `
|
|
189
|
+
functionality: `tool choice type: ${_exhaustiveCheck}`
|
|
189
190
|
});
|
|
190
191
|
}
|
|
191
192
|
}
|
|
@@ -193,20 +194,20 @@ function prepareTools(mode) {
|
|
|
193
194
|
|
|
194
195
|
// src/cohere-chat-language-model.ts
|
|
195
196
|
var CohereChatLanguageModel = class {
|
|
196
|
-
constructor(modelId,
|
|
197
|
-
this.specificationVersion = "
|
|
198
|
-
this.
|
|
197
|
+
constructor(modelId, config) {
|
|
198
|
+
this.specificationVersion = "v2";
|
|
199
|
+
this.supportedUrls = {
|
|
200
|
+
// No URLs are supported.
|
|
201
|
+
};
|
|
199
202
|
this.modelId = modelId;
|
|
200
|
-
this.settings = settings;
|
|
201
203
|
this.config = config;
|
|
202
204
|
}
|
|
203
205
|
get provider() {
|
|
204
206
|
return this.config.provider;
|
|
205
207
|
}
|
|
206
208
|
getArgs({
|
|
207
|
-
mode,
|
|
208
209
|
prompt,
|
|
209
|
-
|
|
210
|
+
maxOutputTokens,
|
|
210
211
|
temperature,
|
|
211
212
|
topP,
|
|
212
213
|
topK,
|
|
@@ -214,78 +215,42 @@ var CohereChatLanguageModel = class {
|
|
|
214
215
|
presencePenalty,
|
|
215
216
|
stopSequences,
|
|
216
217
|
responseFormat,
|
|
217
|
-
seed
|
|
218
|
+
seed,
|
|
219
|
+
tools,
|
|
220
|
+
toolChoice
|
|
218
221
|
}) {
|
|
219
|
-
var _a;
|
|
220
|
-
const type = mode.type;
|
|
221
222
|
const chatPrompt = convertToCohereChatPrompt(prompt);
|
|
222
|
-
const
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
223
|
+
const {
|
|
224
|
+
tools: cohereTools,
|
|
225
|
+
toolChoice: cohereToolChoice,
|
|
226
|
+
toolWarnings
|
|
227
|
+
} = prepareTools({ tools, toolChoice });
|
|
228
|
+
return {
|
|
229
|
+
args: {
|
|
230
|
+
// model id:
|
|
231
|
+
model: this.modelId,
|
|
232
|
+
// standardized settings:
|
|
233
|
+
frequency_penalty: frequencyPenalty,
|
|
234
|
+
presence_penalty: presencePenalty,
|
|
235
|
+
max_tokens: maxOutputTokens,
|
|
236
|
+
temperature,
|
|
237
|
+
p: topP,
|
|
238
|
+
k: topK,
|
|
239
|
+
seed,
|
|
240
|
+
stop_sequences: stopSequences,
|
|
241
|
+
// response format:
|
|
242
|
+
response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? { type: "json_object", json_schema: responseFormat.schema } : void 0,
|
|
243
|
+
// messages:
|
|
244
|
+
messages: chatPrompt,
|
|
245
|
+
// tools:
|
|
246
|
+
tools: cohereTools,
|
|
247
|
+
tool_choice: cohereToolChoice
|
|
248
|
+
},
|
|
249
|
+
warnings: toolWarnings
|
|
238
250
|
};
|
|
239
|
-
switch (type) {
|
|
240
|
-
case "regular": {
|
|
241
|
-
const { tools, toolChoice, toolWarnings } = prepareTools(mode);
|
|
242
|
-
return {
|
|
243
|
-
args: {
|
|
244
|
-
...baseArgs,
|
|
245
|
-
tools,
|
|
246
|
-
tool_choice: toolChoice
|
|
247
|
-
},
|
|
248
|
-
warnings: toolWarnings
|
|
249
|
-
};
|
|
250
|
-
}
|
|
251
|
-
case "object-json": {
|
|
252
|
-
return {
|
|
253
|
-
args: {
|
|
254
|
-
...baseArgs,
|
|
255
|
-
response_format: mode.schema == null ? { type: "json_object" } : { type: "json_object", json_schema: mode.schema }
|
|
256
|
-
},
|
|
257
|
-
warnings: []
|
|
258
|
-
};
|
|
259
|
-
}
|
|
260
|
-
case "object-tool": {
|
|
261
|
-
return {
|
|
262
|
-
args: {
|
|
263
|
-
...baseArgs,
|
|
264
|
-
tools: [
|
|
265
|
-
{
|
|
266
|
-
type: "function",
|
|
267
|
-
function: {
|
|
268
|
-
name: mode.tool.name,
|
|
269
|
-
description: (_a = mode.tool.description) != null ? _a : "",
|
|
270
|
-
parameters: mode.tool.parameters
|
|
271
|
-
}
|
|
272
|
-
}
|
|
273
|
-
],
|
|
274
|
-
tool_choice: "REQUIRED"
|
|
275
|
-
},
|
|
276
|
-
warnings: []
|
|
277
|
-
};
|
|
278
|
-
}
|
|
279
|
-
default: {
|
|
280
|
-
const _exhaustiveCheck = type;
|
|
281
|
-
throw new import_provider3.UnsupportedFunctionalityError({
|
|
282
|
-
functionality: `Unsupported mode: ${_exhaustiveCheck}`
|
|
283
|
-
});
|
|
284
|
-
}
|
|
285
|
-
}
|
|
286
251
|
}
|
|
287
252
|
async doGenerate(options) {
|
|
288
|
-
var _a, _b, _c
|
|
253
|
+
var _a, _b, _c;
|
|
289
254
|
const { args, warnings } = this.getArgs(options);
|
|
290
255
|
const {
|
|
291
256
|
responseHeaders,
|
|
@@ -302,38 +267,40 @@ var CohereChatLanguageModel = class {
|
|
|
302
267
|
abortSignal: options.abortSignal,
|
|
303
268
|
fetch: this.config.fetch
|
|
304
269
|
});
|
|
305
|
-
const
|
|
306
|
-
const text = (
|
|
270
|
+
const content = [];
|
|
271
|
+
const text = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text;
|
|
272
|
+
if (text != null && text.length > 0) {
|
|
273
|
+
content.push({ type: "text", text });
|
|
274
|
+
}
|
|
275
|
+
if (response.message.tool_calls != null) {
|
|
276
|
+
for (const toolCall of response.message.tool_calls) {
|
|
277
|
+
content.push({
|
|
278
|
+
type: "tool-call",
|
|
279
|
+
toolCallId: toolCall.id,
|
|
280
|
+
toolName: toolCall.function.name,
|
|
281
|
+
// Cohere sometimes returns `null` for tool call arguments for tools
|
|
282
|
+
// defined as having no arguments.
|
|
283
|
+
args: toolCall.function.arguments.replace(/^null$/, "{}"),
|
|
284
|
+
toolCallType: "function"
|
|
285
|
+
});
|
|
286
|
+
}
|
|
287
|
+
}
|
|
307
288
|
return {
|
|
308
|
-
|
|
309
|
-
toolCalls: response.message.tool_calls ? response.message.tool_calls.map((toolCall) => ({
|
|
310
|
-
toolCallId: toolCall.id,
|
|
311
|
-
toolName: toolCall.function.name,
|
|
312
|
-
// Cohere sometimes returns `null` for tool call arguments for tools
|
|
313
|
-
// defined as having no arguments.
|
|
314
|
-
args: toolCall.function.arguments.replace(/^null$/, "{}"),
|
|
315
|
-
toolCallType: "function"
|
|
316
|
-
})) : [],
|
|
289
|
+
content,
|
|
317
290
|
finishReason: mapCohereFinishReason(response.finish_reason),
|
|
318
291
|
usage: {
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
rawCall: {
|
|
323
|
-
rawPrompt: {
|
|
324
|
-
messages
|
|
325
|
-
},
|
|
326
|
-
rawSettings
|
|
292
|
+
inputTokens: response.usage.tokens.input_tokens,
|
|
293
|
+
outputTokens: response.usage.tokens.output_tokens,
|
|
294
|
+
totalTokens: response.usage.tokens.input_tokens + response.usage.tokens.output_tokens
|
|
327
295
|
},
|
|
296
|
+
request: { body: args },
|
|
328
297
|
response: {
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
rawResponse: {
|
|
298
|
+
// TODO timestamp, model id
|
|
299
|
+
id: (_c = response.generation_id) != null ? _c : void 0,
|
|
332
300
|
headers: responseHeaders,
|
|
333
301
|
body: rawResponse
|
|
334
302
|
},
|
|
335
|
-
warnings
|
|
336
|
-
request: { body: JSON.stringify(args) }
|
|
303
|
+
warnings
|
|
337
304
|
};
|
|
338
305
|
}
|
|
339
306
|
async doStream(options) {
|
|
@@ -349,11 +316,11 @@ var CohereChatLanguageModel = class {
|
|
|
349
316
|
abortSignal: options.abortSignal,
|
|
350
317
|
fetch: this.config.fetch
|
|
351
318
|
});
|
|
352
|
-
const { messages, ...rawSettings } = args;
|
|
353
319
|
let finishReason = "unknown";
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
320
|
+
const usage = {
|
|
321
|
+
inputTokens: void 0,
|
|
322
|
+
outputTokens: void 0,
|
|
323
|
+
totalTokens: void 0
|
|
357
324
|
};
|
|
358
325
|
let pendingToolCallDelta = {
|
|
359
326
|
toolCallId: "",
|
|
@@ -363,6 +330,9 @@ var CohereChatLanguageModel = class {
|
|
|
363
330
|
return {
|
|
364
331
|
stream: response.pipeThrough(
|
|
365
332
|
new TransformStream({
|
|
333
|
+
start(controller) {
|
|
334
|
+
controller.enqueue({ type: "stream-start", warnings });
|
|
335
|
+
},
|
|
366
336
|
transform(chunk, controller) {
|
|
367
337
|
var _a, _b;
|
|
368
338
|
if (!chunk.success) {
|
|
@@ -375,8 +345,8 @@ var CohereChatLanguageModel = class {
|
|
|
375
345
|
switch (type) {
|
|
376
346
|
case "content-delta": {
|
|
377
347
|
controller.enqueue({
|
|
378
|
-
type: "text
|
|
379
|
-
|
|
348
|
+
type: "text",
|
|
349
|
+
text: value.delta.message.content.text
|
|
380
350
|
});
|
|
381
351
|
return;
|
|
382
352
|
}
|
|
@@ -435,10 +405,9 @@ var CohereChatLanguageModel = class {
|
|
|
435
405
|
case "message-end": {
|
|
436
406
|
finishReason = mapCohereFinishReason(value.delta.finish_reason);
|
|
437
407
|
const tokens = value.delta.usage.tokens;
|
|
438
|
-
usage =
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
};
|
|
408
|
+
usage.inputTokens = tokens.input_tokens;
|
|
409
|
+
usage.outputTokens = tokens.output_tokens;
|
|
410
|
+
usage.totalTokens = tokens.input_tokens + tokens.output_tokens;
|
|
442
411
|
}
|
|
443
412
|
default: {
|
|
444
413
|
return;
|
|
@@ -454,15 +423,8 @@ var CohereChatLanguageModel = class {
|
|
|
454
423
|
}
|
|
455
424
|
})
|
|
456
425
|
),
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
messages
|
|
460
|
-
},
|
|
461
|
-
rawSettings
|
|
462
|
-
},
|
|
463
|
-
rawResponse: { headers: responseHeaders },
|
|
464
|
-
warnings,
|
|
465
|
-
request: { body: JSON.stringify({ ...args, stream: true }) }
|
|
426
|
+
request: { body: { ...args, stream: true } },
|
|
427
|
+
response: { headers: responseHeaders }
|
|
466
428
|
};
|
|
467
429
|
}
|
|
468
430
|
};
|
|
@@ -584,16 +546,40 @@ var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("type", [
|
|
|
584
546
|
]);
|
|
585
547
|
|
|
586
548
|
// src/cohere-embedding-model.ts
|
|
587
|
-
var
|
|
549
|
+
var import_provider3 = require("@ai-sdk/provider");
|
|
588
550
|
var import_provider_utils3 = require("@ai-sdk/provider-utils");
|
|
551
|
+
var import_zod4 = require("zod");
|
|
552
|
+
|
|
553
|
+
// src/cohere-embedding-options.ts
|
|
589
554
|
var import_zod3 = require("zod");
|
|
555
|
+
var cohereEmbeddingOptions = import_zod3.z.object({
|
|
556
|
+
/**
|
|
557
|
+
* Specifies the type of input passed to the model. Default is `search_query`.
|
|
558
|
+
*
|
|
559
|
+
* - "search_document": Used for embeddings stored in a vector database for search use-cases.
|
|
560
|
+
* - "search_query": Used for embeddings of search queries run against a vector DB to find relevant documents.
|
|
561
|
+
* - "classification": Used for embeddings passed through a text classifier.
|
|
562
|
+
* - "clustering": Used for embeddings run through a clustering algorithm.
|
|
563
|
+
*/
|
|
564
|
+
inputType: import_zod3.z.enum(["search_document", "search_query", "classification", "clustering"]).optional(),
|
|
565
|
+
/**
|
|
566
|
+
* Specifies how the API will handle inputs longer than the maximum token length.
|
|
567
|
+
* Default is `END`.
|
|
568
|
+
*
|
|
569
|
+
* - "NONE": If selected, when the input exceeds the maximum input token length will return an error.
|
|
570
|
+
* - "START": Will discard the start of the input until the remaining input is exactly the maximum input token length for the model.
|
|
571
|
+
* - "END": Will discard the end of the input until the remaining input is exactly the maximum input token length for the model.
|
|
572
|
+
*/
|
|
573
|
+
truncate: import_zod3.z.enum(["NONE", "START", "END"]).optional()
|
|
574
|
+
});
|
|
575
|
+
|
|
576
|
+
// src/cohere-embedding-model.ts
|
|
590
577
|
var CohereEmbeddingModel = class {
|
|
591
|
-
constructor(modelId,
|
|
592
|
-
this.specificationVersion = "
|
|
578
|
+
constructor(modelId, config) {
|
|
579
|
+
this.specificationVersion = "v2";
|
|
593
580
|
this.maxEmbeddingsPerCall = 96;
|
|
594
581
|
this.supportsParallelCalls = true;
|
|
595
582
|
this.modelId = modelId;
|
|
596
|
-
this.settings = settings;
|
|
597
583
|
this.config = config;
|
|
598
584
|
}
|
|
599
585
|
get provider() {
|
|
@@ -602,18 +588,28 @@ var CohereEmbeddingModel = class {
|
|
|
602
588
|
async doEmbed({
|
|
603
589
|
values,
|
|
604
590
|
headers,
|
|
605
|
-
abortSignal
|
|
591
|
+
abortSignal,
|
|
592
|
+
providerOptions
|
|
606
593
|
}) {
|
|
607
594
|
var _a;
|
|
595
|
+
const embeddingOptions = await (0, import_provider_utils3.parseProviderOptions)({
|
|
596
|
+
provider: "cohere",
|
|
597
|
+
providerOptions,
|
|
598
|
+
schema: cohereEmbeddingOptions
|
|
599
|
+
});
|
|
608
600
|
if (values.length > this.maxEmbeddingsPerCall) {
|
|
609
|
-
throw new
|
|
601
|
+
throw new import_provider3.TooManyEmbeddingValuesForCallError({
|
|
610
602
|
provider: this.provider,
|
|
611
603
|
modelId: this.modelId,
|
|
612
604
|
maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
|
|
613
605
|
values
|
|
614
606
|
});
|
|
615
607
|
}
|
|
616
|
-
const {
|
|
608
|
+
const {
|
|
609
|
+
responseHeaders,
|
|
610
|
+
value: response,
|
|
611
|
+
rawValue
|
|
612
|
+
} = await (0, import_provider_utils3.postJsonToApi)({
|
|
617
613
|
url: `${this.config.baseURL}/embed`,
|
|
618
614
|
headers: (0, import_provider_utils3.combineHeaders)(this.config.headers(), headers),
|
|
619
615
|
body: {
|
|
@@ -623,8 +619,8 @@ var CohereEmbeddingModel = class {
|
|
|
623
619
|
// https://docs.cohere.com/v2/reference/embed#request.body.embedding_types
|
|
624
620
|
embedding_types: ["float"],
|
|
625
621
|
texts: values,
|
|
626
|
-
input_type: (_a =
|
|
627
|
-
truncate:
|
|
622
|
+
input_type: (_a = embeddingOptions == null ? void 0 : embeddingOptions.inputType) != null ? _a : "search_query",
|
|
623
|
+
truncate: embeddingOptions == null ? void 0 : embeddingOptions.truncate
|
|
628
624
|
},
|
|
629
625
|
failedResponseHandler: cohereFailedResponseHandler,
|
|
630
626
|
successfulResponseHandler: (0, import_provider_utils3.createJsonResponseHandler)(
|
|
@@ -636,17 +632,17 @@ var CohereEmbeddingModel = class {
|
|
|
636
632
|
return {
|
|
637
633
|
embeddings: response.embeddings.float,
|
|
638
634
|
usage: { tokens: response.meta.billed_units.input_tokens },
|
|
639
|
-
|
|
635
|
+
response: { headers: responseHeaders, body: rawValue }
|
|
640
636
|
};
|
|
641
637
|
}
|
|
642
638
|
};
|
|
643
|
-
var cohereTextEmbeddingResponseSchema =
|
|
644
|
-
embeddings:
|
|
645
|
-
float:
|
|
639
|
+
var cohereTextEmbeddingResponseSchema = import_zod4.z.object({
|
|
640
|
+
embeddings: import_zod4.z.object({
|
|
641
|
+
float: import_zod4.z.array(import_zod4.z.array(import_zod4.z.number()))
|
|
646
642
|
}),
|
|
647
|
-
meta:
|
|
648
|
-
billed_units:
|
|
649
|
-
input_tokens:
|
|
643
|
+
meta: import_zod4.z.object({
|
|
644
|
+
billed_units: import_zod4.z.object({
|
|
645
|
+
input_tokens: import_zod4.z.number()
|
|
650
646
|
})
|
|
651
647
|
})
|
|
652
648
|
});
|
|
@@ -663,29 +659,32 @@ function createCohere(options = {}) {
|
|
|
663
659
|
})}`,
|
|
664
660
|
...options.headers
|
|
665
661
|
});
|
|
666
|
-
const createChatModel = (modelId
|
|
662
|
+
const createChatModel = (modelId) => new CohereChatLanguageModel(modelId, {
|
|
667
663
|
provider: "cohere.chat",
|
|
668
664
|
baseURL,
|
|
669
665
|
headers: getHeaders,
|
|
670
666
|
fetch: options.fetch
|
|
671
667
|
});
|
|
672
|
-
const createTextEmbeddingModel = (modelId
|
|
668
|
+
const createTextEmbeddingModel = (modelId) => new CohereEmbeddingModel(modelId, {
|
|
673
669
|
provider: "cohere.textEmbedding",
|
|
674
670
|
baseURL,
|
|
675
671
|
headers: getHeaders,
|
|
676
672
|
fetch: options.fetch
|
|
677
673
|
});
|
|
678
|
-
const provider = function(modelId
|
|
674
|
+
const provider = function(modelId) {
|
|
679
675
|
if (new.target) {
|
|
680
676
|
throw new Error(
|
|
681
677
|
"The Cohere model function cannot be called with the new keyword."
|
|
682
678
|
);
|
|
683
679
|
}
|
|
684
|
-
return createChatModel(modelId
|
|
680
|
+
return createChatModel(modelId);
|
|
685
681
|
};
|
|
686
682
|
provider.languageModel = createChatModel;
|
|
687
683
|
provider.embedding = createTextEmbeddingModel;
|
|
688
684
|
provider.textEmbeddingModel = createTextEmbeddingModel;
|
|
685
|
+
provider.imageModel = (modelId) => {
|
|
686
|
+
throw new import_provider4.NoSuchModelError({ modelId, modelType: "imageModel" });
|
|
687
|
+
};
|
|
689
688
|
return provider;
|
|
690
689
|
}
|
|
691
690
|
var cohere = createCohere();
|