@ai-sdk/cohere 1.2.10 → 2.0.0-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -26,10 +26,10 @@ __export(src_exports, {
26
26
  module.exports = __toCommonJS(src_exports);
27
27
 
28
28
  // src/cohere-provider.ts
29
+ var import_provider4 = require("@ai-sdk/provider");
29
30
  var import_provider_utils4 = require("@ai-sdk/provider-utils");
30
31
 
31
32
  // src/cohere-chat-language-model.ts
32
- var import_provider3 = require("@ai-sdk/provider");
33
33
  var import_provider_utils2 = require("@ai-sdk/provider-utils");
34
34
  var import_zod2 = require("zod");
35
35
 
@@ -62,9 +62,9 @@ function convertToCohereChatPrompt(prompt) {
62
62
  case "text": {
63
63
  return part.text;
64
64
  }
65
- case "image": {
65
+ case "file": {
66
66
  throw new import_provider.UnsupportedFunctionalityError({
67
- functionality: "image-part"
67
+ functionality: "File URL data"
68
68
  });
69
69
  }
70
70
  }
@@ -140,9 +140,11 @@ function mapCohereFinishReason(finishReason) {
140
140
 
141
141
  // src/cohere-prepare-tools.ts
142
142
  var import_provider2 = require("@ai-sdk/provider");
143
- function prepareTools(mode) {
144
- var _a;
145
- const tools = ((_a = mode.tools) == null ? void 0 : _a.length) ? mode.tools : void 0;
143
+ function prepareTools({
144
+ tools,
145
+ toolChoice
146
+ }) {
147
+ tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
146
148
  const toolWarnings = [];
147
149
  if (tools == null) {
148
150
  return { tools: void 0, toolChoice: void 0, toolWarnings };
@@ -162,7 +164,6 @@ function prepareTools(mode) {
162
164
  });
163
165
  }
164
166
  }
165
- const toolChoice = mode.toolChoice;
166
167
  if (toolChoice == null) {
167
168
  return { tools: cohereTools, toolChoice: void 0, toolWarnings };
168
169
  }
@@ -185,7 +186,7 @@ function prepareTools(mode) {
185
186
  default: {
186
187
  const _exhaustiveCheck = type;
187
188
  throw new import_provider2.UnsupportedFunctionalityError({
188
- functionality: `Unsupported tool choice type: ${_exhaustiveCheck}`
189
+ functionality: `tool choice type: ${_exhaustiveCheck}`
189
190
  });
190
191
  }
191
192
  }
@@ -193,20 +194,20 @@ function prepareTools(mode) {
193
194
 
194
195
  // src/cohere-chat-language-model.ts
195
196
  var CohereChatLanguageModel = class {
196
- constructor(modelId, settings, config) {
197
- this.specificationVersion = "v1";
198
- this.defaultObjectGenerationMode = "json";
197
+ constructor(modelId, config) {
198
+ this.specificationVersion = "v2";
199
+ this.supportedUrls = {
200
+ // No URLs are supported.
201
+ };
199
202
  this.modelId = modelId;
200
- this.settings = settings;
201
203
  this.config = config;
202
204
  }
203
205
  get provider() {
204
206
  return this.config.provider;
205
207
  }
206
208
  getArgs({
207
- mode,
208
209
  prompt,
209
- maxTokens,
210
+ maxOutputTokens,
210
211
  temperature,
211
212
  topP,
212
213
  topK,
@@ -214,78 +215,42 @@ var CohereChatLanguageModel = class {
214
215
  presencePenalty,
215
216
  stopSequences,
216
217
  responseFormat,
217
- seed
218
+ seed,
219
+ tools,
220
+ toolChoice
218
221
  }) {
219
- var _a;
220
- const type = mode.type;
221
222
  const chatPrompt = convertToCohereChatPrompt(prompt);
222
- const baseArgs = {
223
- // model id:
224
- model: this.modelId,
225
- // standardized settings:
226
- frequency_penalty: frequencyPenalty,
227
- presence_penalty: presencePenalty,
228
- max_tokens: maxTokens,
229
- temperature,
230
- p: topP,
231
- k: topK,
232
- seed,
233
- stop_sequences: stopSequences,
234
- // response format:
235
- response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? { type: "json_object", json_schema: responseFormat.schema } : void 0,
236
- // messages:
237
- messages: chatPrompt
223
+ const {
224
+ tools: cohereTools,
225
+ toolChoice: cohereToolChoice,
226
+ toolWarnings
227
+ } = prepareTools({ tools, toolChoice });
228
+ return {
229
+ args: {
230
+ // model id:
231
+ model: this.modelId,
232
+ // standardized settings:
233
+ frequency_penalty: frequencyPenalty,
234
+ presence_penalty: presencePenalty,
235
+ max_tokens: maxOutputTokens,
236
+ temperature,
237
+ p: topP,
238
+ k: topK,
239
+ seed,
240
+ stop_sequences: stopSequences,
241
+ // response format:
242
+ response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? { type: "json_object", json_schema: responseFormat.schema } : void 0,
243
+ // messages:
244
+ messages: chatPrompt,
245
+ // tools:
246
+ tools: cohereTools,
247
+ tool_choice: cohereToolChoice
248
+ },
249
+ warnings: toolWarnings
238
250
  };
239
- switch (type) {
240
- case "regular": {
241
- const { tools, toolChoice, toolWarnings } = prepareTools(mode);
242
- return {
243
- args: {
244
- ...baseArgs,
245
- tools,
246
- tool_choice: toolChoice
247
- },
248
- warnings: toolWarnings
249
- };
250
- }
251
- case "object-json": {
252
- return {
253
- args: {
254
- ...baseArgs,
255
- response_format: mode.schema == null ? { type: "json_object" } : { type: "json_object", json_schema: mode.schema }
256
- },
257
- warnings: []
258
- };
259
- }
260
- case "object-tool": {
261
- return {
262
- args: {
263
- ...baseArgs,
264
- tools: [
265
- {
266
- type: "function",
267
- function: {
268
- name: mode.tool.name,
269
- description: (_a = mode.tool.description) != null ? _a : "",
270
- parameters: mode.tool.parameters
271
- }
272
- }
273
- ],
274
- tool_choice: "REQUIRED"
275
- },
276
- warnings: []
277
- };
278
- }
279
- default: {
280
- const _exhaustiveCheck = type;
281
- throw new import_provider3.UnsupportedFunctionalityError({
282
- functionality: `Unsupported mode: ${_exhaustiveCheck}`
283
- });
284
- }
285
- }
286
251
  }
287
252
  async doGenerate(options) {
288
- var _a, _b, _c, _d;
253
+ var _a, _b, _c;
289
254
  const { args, warnings } = this.getArgs(options);
290
255
  const {
291
256
  responseHeaders,
@@ -302,38 +267,40 @@ var CohereChatLanguageModel = class {
302
267
  abortSignal: options.abortSignal,
303
268
  fetch: this.config.fetch
304
269
  });
305
- const { messages, ...rawSettings } = args;
306
- const text = (_c = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text) != null ? _c : "";
270
+ const content = [];
271
+ const text = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text;
272
+ if (text != null && text.length > 0) {
273
+ content.push({ type: "text", text });
274
+ }
275
+ if (response.message.tool_calls != null) {
276
+ for (const toolCall of response.message.tool_calls) {
277
+ content.push({
278
+ type: "tool-call",
279
+ toolCallId: toolCall.id,
280
+ toolName: toolCall.function.name,
281
+ // Cohere sometimes returns `null` for tool call arguments for tools
282
+ // defined as having no arguments.
283
+ args: toolCall.function.arguments.replace(/^null$/, "{}"),
284
+ toolCallType: "function"
285
+ });
286
+ }
287
+ }
307
288
  return {
308
- text,
309
- toolCalls: response.message.tool_calls ? response.message.tool_calls.map((toolCall) => ({
310
- toolCallId: toolCall.id,
311
- toolName: toolCall.function.name,
312
- // Cohere sometimes returns `null` for tool call arguments for tools
313
- // defined as having no arguments.
314
- args: toolCall.function.arguments.replace(/^null$/, "{}"),
315
- toolCallType: "function"
316
- })) : [],
289
+ content,
317
290
  finishReason: mapCohereFinishReason(response.finish_reason),
318
291
  usage: {
319
- promptTokens: response.usage.tokens.input_tokens,
320
- completionTokens: response.usage.tokens.output_tokens
321
- },
322
- rawCall: {
323
- rawPrompt: {
324
- messages
325
- },
326
- rawSettings
292
+ inputTokens: response.usage.tokens.input_tokens,
293
+ outputTokens: response.usage.tokens.output_tokens,
294
+ totalTokens: response.usage.tokens.input_tokens + response.usage.tokens.output_tokens
327
295
  },
296
+ request: { body: args },
328
297
  response: {
329
- id: (_d = response.generation_id) != null ? _d : void 0
330
- },
331
- rawResponse: {
298
+ // TODO timestamp, model id
299
+ id: (_c = response.generation_id) != null ? _c : void 0,
332
300
  headers: responseHeaders,
333
301
  body: rawResponse
334
302
  },
335
- warnings,
336
- request: { body: JSON.stringify(args) }
303
+ warnings
337
304
  };
338
305
  }
339
306
  async doStream(options) {
@@ -349,11 +316,11 @@ var CohereChatLanguageModel = class {
349
316
  abortSignal: options.abortSignal,
350
317
  fetch: this.config.fetch
351
318
  });
352
- const { messages, ...rawSettings } = args;
353
319
  let finishReason = "unknown";
354
- let usage = {
355
- promptTokens: Number.NaN,
356
- completionTokens: Number.NaN
320
+ const usage = {
321
+ inputTokens: void 0,
322
+ outputTokens: void 0,
323
+ totalTokens: void 0
357
324
  };
358
325
  let pendingToolCallDelta = {
359
326
  toolCallId: "",
@@ -363,6 +330,9 @@ var CohereChatLanguageModel = class {
363
330
  return {
364
331
  stream: response.pipeThrough(
365
332
  new TransformStream({
333
+ start(controller) {
334
+ controller.enqueue({ type: "stream-start", warnings });
335
+ },
366
336
  transform(chunk, controller) {
367
337
  var _a, _b;
368
338
  if (!chunk.success) {
@@ -375,8 +345,8 @@ var CohereChatLanguageModel = class {
375
345
  switch (type) {
376
346
  case "content-delta": {
377
347
  controller.enqueue({
378
- type: "text-delta",
379
- textDelta: value.delta.message.content.text
348
+ type: "text",
349
+ text: value.delta.message.content.text
380
350
  });
381
351
  return;
382
352
  }
@@ -435,10 +405,9 @@ var CohereChatLanguageModel = class {
435
405
  case "message-end": {
436
406
  finishReason = mapCohereFinishReason(value.delta.finish_reason);
437
407
  const tokens = value.delta.usage.tokens;
438
- usage = {
439
- promptTokens: tokens.input_tokens,
440
- completionTokens: tokens.output_tokens
441
- };
408
+ usage.inputTokens = tokens.input_tokens;
409
+ usage.outputTokens = tokens.output_tokens;
410
+ usage.totalTokens = tokens.input_tokens + tokens.output_tokens;
442
411
  }
443
412
  default: {
444
413
  return;
@@ -454,15 +423,8 @@ var CohereChatLanguageModel = class {
454
423
  }
455
424
  })
456
425
  ),
457
- rawCall: {
458
- rawPrompt: {
459
- messages
460
- },
461
- rawSettings
462
- },
463
- rawResponse: { headers: responseHeaders },
464
- warnings,
465
- request: { body: JSON.stringify({ ...args, stream: true }) }
426
+ request: { body: { ...args, stream: true } },
427
+ response: { headers: responseHeaders }
466
428
  };
467
429
  }
468
430
  };
@@ -584,16 +546,40 @@ var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("type", [
584
546
  ]);
585
547
 
586
548
  // src/cohere-embedding-model.ts
587
- var import_provider4 = require("@ai-sdk/provider");
549
+ var import_provider3 = require("@ai-sdk/provider");
588
550
  var import_provider_utils3 = require("@ai-sdk/provider-utils");
551
+ var import_zod4 = require("zod");
552
+
553
+ // src/cohere-embedding-options.ts
589
554
  var import_zod3 = require("zod");
555
+ var cohereEmbeddingOptions = import_zod3.z.object({
556
+ /**
557
+ * Specifies the type of input passed to the model. Default is `search_query`.
558
+ *
559
+ * - "search_document": Used for embeddings stored in a vector database for search use-cases.
560
+ * - "search_query": Used for embeddings of search queries run against a vector DB to find relevant documents.
561
+ * - "classification": Used for embeddings passed through a text classifier.
562
+ * - "clustering": Used for embeddings run through a clustering algorithm.
563
+ */
564
+ inputType: import_zod3.z.enum(["search_document", "search_query", "classification", "clustering"]).optional(),
565
+ /**
566
+ * Specifies how the API will handle inputs longer than the maximum token length.
567
+ * Default is `END`.
568
+ *
569
+ * - "NONE": If selected, when the input exceeds the maximum input token length will return an error.
570
+ * - "START": Will discard the start of the input until the remaining input is exactly the maximum input token length for the model.
571
+ * - "END": Will discard the end of the input until the remaining input is exactly the maximum input token length for the model.
572
+ */
573
+ truncate: import_zod3.z.enum(["NONE", "START", "END"]).optional()
574
+ });
575
+
576
+ // src/cohere-embedding-model.ts
590
577
  var CohereEmbeddingModel = class {
591
- constructor(modelId, settings, config) {
592
- this.specificationVersion = "v1";
578
+ constructor(modelId, config) {
579
+ this.specificationVersion = "v2";
593
580
  this.maxEmbeddingsPerCall = 96;
594
581
  this.supportsParallelCalls = true;
595
582
  this.modelId = modelId;
596
- this.settings = settings;
597
583
  this.config = config;
598
584
  }
599
585
  get provider() {
@@ -602,18 +588,28 @@ var CohereEmbeddingModel = class {
602
588
  async doEmbed({
603
589
  values,
604
590
  headers,
605
- abortSignal
591
+ abortSignal,
592
+ providerOptions
606
593
  }) {
607
594
  var _a;
595
+ const embeddingOptions = await (0, import_provider_utils3.parseProviderOptions)({
596
+ provider: "cohere",
597
+ providerOptions,
598
+ schema: cohereEmbeddingOptions
599
+ });
608
600
  if (values.length > this.maxEmbeddingsPerCall) {
609
- throw new import_provider4.TooManyEmbeddingValuesForCallError({
601
+ throw new import_provider3.TooManyEmbeddingValuesForCallError({
610
602
  provider: this.provider,
611
603
  modelId: this.modelId,
612
604
  maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
613
605
  values
614
606
  });
615
607
  }
616
- const { responseHeaders, value: response } = await (0, import_provider_utils3.postJsonToApi)({
608
+ const {
609
+ responseHeaders,
610
+ value: response,
611
+ rawValue
612
+ } = await (0, import_provider_utils3.postJsonToApi)({
617
613
  url: `${this.config.baseURL}/embed`,
618
614
  headers: (0, import_provider_utils3.combineHeaders)(this.config.headers(), headers),
619
615
  body: {
@@ -623,8 +619,8 @@ var CohereEmbeddingModel = class {
623
619
  // https://docs.cohere.com/v2/reference/embed#request.body.embedding_types
624
620
  embedding_types: ["float"],
625
621
  texts: values,
626
- input_type: (_a = this.settings.inputType) != null ? _a : "search_query",
627
- truncate: this.settings.truncate
622
+ input_type: (_a = embeddingOptions == null ? void 0 : embeddingOptions.inputType) != null ? _a : "search_query",
623
+ truncate: embeddingOptions == null ? void 0 : embeddingOptions.truncate
628
624
  },
629
625
  failedResponseHandler: cohereFailedResponseHandler,
630
626
  successfulResponseHandler: (0, import_provider_utils3.createJsonResponseHandler)(
@@ -636,17 +632,17 @@ var CohereEmbeddingModel = class {
636
632
  return {
637
633
  embeddings: response.embeddings.float,
638
634
  usage: { tokens: response.meta.billed_units.input_tokens },
639
- rawResponse: { headers: responseHeaders }
635
+ response: { headers: responseHeaders, body: rawValue }
640
636
  };
641
637
  }
642
638
  };
643
- var cohereTextEmbeddingResponseSchema = import_zod3.z.object({
644
- embeddings: import_zod3.z.object({
645
- float: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number()))
639
+ var cohereTextEmbeddingResponseSchema = import_zod4.z.object({
640
+ embeddings: import_zod4.z.object({
641
+ float: import_zod4.z.array(import_zod4.z.array(import_zod4.z.number()))
646
642
  }),
647
- meta: import_zod3.z.object({
648
- billed_units: import_zod3.z.object({
649
- input_tokens: import_zod3.z.number()
643
+ meta: import_zod4.z.object({
644
+ billed_units: import_zod4.z.object({
645
+ input_tokens: import_zod4.z.number()
650
646
  })
651
647
  })
652
648
  });
@@ -663,29 +659,32 @@ function createCohere(options = {}) {
663
659
  })}`,
664
660
  ...options.headers
665
661
  });
666
- const createChatModel = (modelId, settings = {}) => new CohereChatLanguageModel(modelId, settings, {
662
+ const createChatModel = (modelId) => new CohereChatLanguageModel(modelId, {
667
663
  provider: "cohere.chat",
668
664
  baseURL,
669
665
  headers: getHeaders,
670
666
  fetch: options.fetch
671
667
  });
672
- const createTextEmbeddingModel = (modelId, settings = {}) => new CohereEmbeddingModel(modelId, settings, {
668
+ const createTextEmbeddingModel = (modelId) => new CohereEmbeddingModel(modelId, {
673
669
  provider: "cohere.textEmbedding",
674
670
  baseURL,
675
671
  headers: getHeaders,
676
672
  fetch: options.fetch
677
673
  });
678
- const provider = function(modelId, settings) {
674
+ const provider = function(modelId) {
679
675
  if (new.target) {
680
676
  throw new Error(
681
677
  "The Cohere model function cannot be called with the new keyword."
682
678
  );
683
679
  }
684
- return createChatModel(modelId, settings);
680
+ return createChatModel(modelId);
685
681
  };
686
682
  provider.languageModel = createChatModel;
687
683
  provider.embedding = createTextEmbeddingModel;
688
684
  provider.textEmbeddingModel = createTextEmbeddingModel;
685
+ provider.imageModel = (modelId) => {
686
+ throw new import_provider4.NoSuchModelError({ modelId, modelType: "imageModel" });
687
+ };
689
688
  return provider;
690
689
  }
691
690
  var cohere = createCohere();