@ai-sdk/cohere 1.2.10 → 2.0.0-alpha.10

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
package/dist/index.mjs CHANGED
@@ -1,13 +1,13 @@
1
1
  // src/cohere-provider.ts
2
+ import {
3
+ NoSuchModelError
4
+ } from "@ai-sdk/provider";
2
5
  import {
3
6
  loadApiKey,
4
7
  withoutTrailingSlash
5
8
  } from "@ai-sdk/provider-utils";
6
9
 
7
10
  // src/cohere-chat-language-model.ts
8
- import {
9
- UnsupportedFunctionalityError as UnsupportedFunctionalityError3
10
- } from "@ai-sdk/provider";
11
11
  import {
12
12
  combineHeaders,
13
13
  createEventSourceResponseHandler,
@@ -47,9 +47,9 @@ function convertToCohereChatPrompt(prompt) {
47
47
  case "text": {
48
48
  return part.text;
49
49
  }
50
- case "image": {
50
+ case "file": {
51
51
  throw new UnsupportedFunctionalityError({
52
- functionality: "image-part"
52
+ functionality: "File URL data"
53
53
  });
54
54
  }
55
55
  }
@@ -127,9 +127,11 @@ function mapCohereFinishReason(finishReason) {
127
127
  import {
128
128
  UnsupportedFunctionalityError as UnsupportedFunctionalityError2
129
129
  } from "@ai-sdk/provider";
130
- function prepareTools(mode) {
131
- var _a;
132
- const tools = ((_a = mode.tools) == null ? void 0 : _a.length) ? mode.tools : void 0;
130
+ function prepareTools({
131
+ tools,
132
+ toolChoice
133
+ }) {
134
+ tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
133
135
  const toolWarnings = [];
134
136
  if (tools == null) {
135
137
  return { tools: void 0, toolChoice: void 0, toolWarnings };
@@ -149,7 +151,6 @@ function prepareTools(mode) {
149
151
  });
150
152
  }
151
153
  }
152
- const toolChoice = mode.toolChoice;
153
154
  if (toolChoice == null) {
154
155
  return { tools: cohereTools, toolChoice: void 0, toolWarnings };
155
156
  }
@@ -172,7 +173,7 @@ function prepareTools(mode) {
172
173
  default: {
173
174
  const _exhaustiveCheck = type;
174
175
  throw new UnsupportedFunctionalityError2({
175
- functionality: `Unsupported tool choice type: ${_exhaustiveCheck}`
176
+ functionality: `tool choice type: ${_exhaustiveCheck}`
176
177
  });
177
178
  }
178
179
  }
@@ -180,20 +181,20 @@ function prepareTools(mode) {
180
181
 
181
182
  // src/cohere-chat-language-model.ts
182
183
  var CohereChatLanguageModel = class {
183
- constructor(modelId, settings, config) {
184
- this.specificationVersion = "v1";
185
- this.defaultObjectGenerationMode = "json";
184
+ constructor(modelId, config) {
185
+ this.specificationVersion = "v2";
186
+ this.supportedUrls = {
187
+ // No URLs are supported.
188
+ };
186
189
  this.modelId = modelId;
187
- this.settings = settings;
188
190
  this.config = config;
189
191
  }
190
192
  get provider() {
191
193
  return this.config.provider;
192
194
  }
193
195
  getArgs({
194
- mode,
195
196
  prompt,
196
- maxTokens,
197
+ maxOutputTokens,
197
198
  temperature,
198
199
  topP,
199
200
  topK,
@@ -201,78 +202,42 @@ var CohereChatLanguageModel = class {
201
202
  presencePenalty,
202
203
  stopSequences,
203
204
  responseFormat,
204
- seed
205
+ seed,
206
+ tools,
207
+ toolChoice
205
208
  }) {
206
- var _a;
207
- const type = mode.type;
208
209
  const chatPrompt = convertToCohereChatPrompt(prompt);
209
- const baseArgs = {
210
- // model id:
211
- model: this.modelId,
212
- // standardized settings:
213
- frequency_penalty: frequencyPenalty,
214
- presence_penalty: presencePenalty,
215
- max_tokens: maxTokens,
216
- temperature,
217
- p: topP,
218
- k: topK,
219
- seed,
220
- stop_sequences: stopSequences,
221
- // response format:
222
- response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? { type: "json_object", json_schema: responseFormat.schema } : void 0,
223
- // messages:
224
- messages: chatPrompt
210
+ const {
211
+ tools: cohereTools,
212
+ toolChoice: cohereToolChoice,
213
+ toolWarnings
214
+ } = prepareTools({ tools, toolChoice });
215
+ return {
216
+ args: {
217
+ // model id:
218
+ model: this.modelId,
219
+ // standardized settings:
220
+ frequency_penalty: frequencyPenalty,
221
+ presence_penalty: presencePenalty,
222
+ max_tokens: maxOutputTokens,
223
+ temperature,
224
+ p: topP,
225
+ k: topK,
226
+ seed,
227
+ stop_sequences: stopSequences,
228
+ // response format:
229
+ response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? { type: "json_object", json_schema: responseFormat.schema } : void 0,
230
+ // messages:
231
+ messages: chatPrompt,
232
+ // tools:
233
+ tools: cohereTools,
234
+ tool_choice: cohereToolChoice
235
+ },
236
+ warnings: toolWarnings
225
237
  };
226
- switch (type) {
227
- case "regular": {
228
- const { tools, toolChoice, toolWarnings } = prepareTools(mode);
229
- return {
230
- args: {
231
- ...baseArgs,
232
- tools,
233
- tool_choice: toolChoice
234
- },
235
- warnings: toolWarnings
236
- };
237
- }
238
- case "object-json": {
239
- return {
240
- args: {
241
- ...baseArgs,
242
- response_format: mode.schema == null ? { type: "json_object" } : { type: "json_object", json_schema: mode.schema }
243
- },
244
- warnings: []
245
- };
246
- }
247
- case "object-tool": {
248
- return {
249
- args: {
250
- ...baseArgs,
251
- tools: [
252
- {
253
- type: "function",
254
- function: {
255
- name: mode.tool.name,
256
- description: (_a = mode.tool.description) != null ? _a : "",
257
- parameters: mode.tool.parameters
258
- }
259
- }
260
- ],
261
- tool_choice: "REQUIRED"
262
- },
263
- warnings: []
264
- };
265
- }
266
- default: {
267
- const _exhaustiveCheck = type;
268
- throw new UnsupportedFunctionalityError3({
269
- functionality: `Unsupported mode: ${_exhaustiveCheck}`
270
- });
271
- }
272
- }
273
238
  }
274
239
  async doGenerate(options) {
275
- var _a, _b, _c, _d;
240
+ var _a, _b, _c;
276
241
  const { args, warnings } = this.getArgs(options);
277
242
  const {
278
243
  responseHeaders,
@@ -289,38 +254,40 @@ var CohereChatLanguageModel = class {
289
254
  abortSignal: options.abortSignal,
290
255
  fetch: this.config.fetch
291
256
  });
292
- const { messages, ...rawSettings } = args;
293
- const text = (_c = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text) != null ? _c : "";
257
+ const content = [];
258
+ const text = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text;
259
+ if (text != null && text.length > 0) {
260
+ content.push({ type: "text", text });
261
+ }
262
+ if (response.message.tool_calls != null) {
263
+ for (const toolCall of response.message.tool_calls) {
264
+ content.push({
265
+ type: "tool-call",
266
+ toolCallId: toolCall.id,
267
+ toolName: toolCall.function.name,
268
+ // Cohere sometimes returns `null` for tool call arguments for tools
269
+ // defined as having no arguments.
270
+ args: toolCall.function.arguments.replace(/^null$/, "{}"),
271
+ toolCallType: "function"
272
+ });
273
+ }
274
+ }
294
275
  return {
295
- text,
296
- toolCalls: response.message.tool_calls ? response.message.tool_calls.map((toolCall) => ({
297
- toolCallId: toolCall.id,
298
- toolName: toolCall.function.name,
299
- // Cohere sometimes returns `null` for tool call arguments for tools
300
- // defined as having no arguments.
301
- args: toolCall.function.arguments.replace(/^null$/, "{}"),
302
- toolCallType: "function"
303
- })) : [],
276
+ content,
304
277
  finishReason: mapCohereFinishReason(response.finish_reason),
305
278
  usage: {
306
- promptTokens: response.usage.tokens.input_tokens,
307
- completionTokens: response.usage.tokens.output_tokens
308
- },
309
- rawCall: {
310
- rawPrompt: {
311
- messages
312
- },
313
- rawSettings
279
+ inputTokens: response.usage.tokens.input_tokens,
280
+ outputTokens: response.usage.tokens.output_tokens,
281
+ totalTokens: response.usage.tokens.input_tokens + response.usage.tokens.output_tokens
314
282
  },
283
+ request: { body: args },
315
284
  response: {
316
- id: (_d = response.generation_id) != null ? _d : void 0
317
- },
318
- rawResponse: {
285
+ // TODO timestamp, model id
286
+ id: (_c = response.generation_id) != null ? _c : void 0,
319
287
  headers: responseHeaders,
320
288
  body: rawResponse
321
289
  },
322
- warnings,
323
- request: { body: JSON.stringify(args) }
290
+ warnings
324
291
  };
325
292
  }
326
293
  async doStream(options) {
@@ -336,11 +303,11 @@ var CohereChatLanguageModel = class {
336
303
  abortSignal: options.abortSignal,
337
304
  fetch: this.config.fetch
338
305
  });
339
- const { messages, ...rawSettings } = args;
340
306
  let finishReason = "unknown";
341
- let usage = {
342
- promptTokens: Number.NaN,
343
- completionTokens: Number.NaN
307
+ const usage = {
308
+ inputTokens: void 0,
309
+ outputTokens: void 0,
310
+ totalTokens: void 0
344
311
  };
345
312
  let pendingToolCallDelta = {
346
313
  toolCallId: "",
@@ -350,6 +317,9 @@ var CohereChatLanguageModel = class {
350
317
  return {
351
318
  stream: response.pipeThrough(
352
319
  new TransformStream({
320
+ start(controller) {
321
+ controller.enqueue({ type: "stream-start", warnings });
322
+ },
353
323
  transform(chunk, controller) {
354
324
  var _a, _b;
355
325
  if (!chunk.success) {
@@ -362,8 +332,8 @@ var CohereChatLanguageModel = class {
362
332
  switch (type) {
363
333
  case "content-delta": {
364
334
  controller.enqueue({
365
- type: "text-delta",
366
- textDelta: value.delta.message.content.text
335
+ type: "text",
336
+ text: value.delta.message.content.text
367
337
  });
368
338
  return;
369
339
  }
@@ -422,10 +392,9 @@ var CohereChatLanguageModel = class {
422
392
  case "message-end": {
423
393
  finishReason = mapCohereFinishReason(value.delta.finish_reason);
424
394
  const tokens = value.delta.usage.tokens;
425
- usage = {
426
- promptTokens: tokens.input_tokens,
427
- completionTokens: tokens.output_tokens
428
- };
395
+ usage.inputTokens = tokens.input_tokens;
396
+ usage.outputTokens = tokens.output_tokens;
397
+ usage.totalTokens = tokens.input_tokens + tokens.output_tokens;
429
398
  }
430
399
  default: {
431
400
  return;
@@ -441,15 +410,8 @@ var CohereChatLanguageModel = class {
441
410
  }
442
411
  })
443
412
  ),
444
- rawCall: {
445
- rawPrompt: {
446
- messages
447
- },
448
- rawSettings
449
- },
450
- rawResponse: { headers: responseHeaders },
451
- warnings,
452
- request: { body: JSON.stringify({ ...args, stream: true }) }
413
+ request: { body: { ...args, stream: true } },
414
+ response: { headers: responseHeaders }
453
415
  };
454
416
  }
455
417
  };
@@ -577,16 +539,41 @@ import {
577
539
  import {
578
540
  combineHeaders as combineHeaders2,
579
541
  createJsonResponseHandler as createJsonResponseHandler2,
542
+ parseProviderOptions,
580
543
  postJsonToApi as postJsonToApi2
581
544
  } from "@ai-sdk/provider-utils";
545
+ import { z as z4 } from "zod";
546
+
547
+ // src/cohere-embedding-options.ts
582
548
  import { z as z3 } from "zod";
549
+ var cohereEmbeddingOptions = z3.object({
550
+ /**
551
+ * Specifies the type of input passed to the model. Default is `search_query`.
552
+ *
553
+ * - "search_document": Used for embeddings stored in a vector database for search use-cases.
554
+ * - "search_query": Used for embeddings of search queries run against a vector DB to find relevant documents.
555
+ * - "classification": Used for embeddings passed through a text classifier.
556
+ * - "clustering": Used for embeddings run through a clustering algorithm.
557
+ */
558
+ inputType: z3.enum(["search_document", "search_query", "classification", "clustering"]).optional(),
559
+ /**
560
+ * Specifies how the API will handle inputs longer than the maximum token length.
561
+ * Default is `END`.
562
+ *
563
+ * - "NONE": If selected, when the input exceeds the maximum input token length will return an error.
564
+ * - "START": Will discard the start of the input until the remaining input is exactly the maximum input token length for the model.
565
+ * - "END": Will discard the end of the input until the remaining input is exactly the maximum input token length for the model.
566
+ */
567
+ truncate: z3.enum(["NONE", "START", "END"]).optional()
568
+ });
569
+
570
+ // src/cohere-embedding-model.ts
583
571
  var CohereEmbeddingModel = class {
584
- constructor(modelId, settings, config) {
585
- this.specificationVersion = "v1";
572
+ constructor(modelId, config) {
573
+ this.specificationVersion = "v2";
586
574
  this.maxEmbeddingsPerCall = 96;
587
575
  this.supportsParallelCalls = true;
588
576
  this.modelId = modelId;
589
- this.settings = settings;
590
577
  this.config = config;
591
578
  }
592
579
  get provider() {
@@ -595,9 +582,15 @@ var CohereEmbeddingModel = class {
595
582
  async doEmbed({
596
583
  values,
597
584
  headers,
598
- abortSignal
585
+ abortSignal,
586
+ providerOptions
599
587
  }) {
600
588
  var _a;
589
+ const embeddingOptions = await parseProviderOptions({
590
+ provider: "cohere",
591
+ providerOptions,
592
+ schema: cohereEmbeddingOptions
593
+ });
601
594
  if (values.length > this.maxEmbeddingsPerCall) {
602
595
  throw new TooManyEmbeddingValuesForCallError({
603
596
  provider: this.provider,
@@ -606,7 +599,11 @@ var CohereEmbeddingModel = class {
606
599
  values
607
600
  });
608
601
  }
609
- const { responseHeaders, value: response } = await postJsonToApi2({
602
+ const {
603
+ responseHeaders,
604
+ value: response,
605
+ rawValue
606
+ } = await postJsonToApi2({
610
607
  url: `${this.config.baseURL}/embed`,
611
608
  headers: combineHeaders2(this.config.headers(), headers),
612
609
  body: {
@@ -616,8 +613,8 @@ var CohereEmbeddingModel = class {
616
613
  // https://docs.cohere.com/v2/reference/embed#request.body.embedding_types
617
614
  embedding_types: ["float"],
618
615
  texts: values,
619
- input_type: (_a = this.settings.inputType) != null ? _a : "search_query",
620
- truncate: this.settings.truncate
616
+ input_type: (_a = embeddingOptions == null ? void 0 : embeddingOptions.inputType) != null ? _a : "search_query",
617
+ truncate: embeddingOptions == null ? void 0 : embeddingOptions.truncate
621
618
  },
622
619
  failedResponseHandler: cohereFailedResponseHandler,
623
620
  successfulResponseHandler: createJsonResponseHandler2(
@@ -629,17 +626,17 @@ var CohereEmbeddingModel = class {
629
626
  return {
630
627
  embeddings: response.embeddings.float,
631
628
  usage: { tokens: response.meta.billed_units.input_tokens },
632
- rawResponse: { headers: responseHeaders }
629
+ response: { headers: responseHeaders, body: rawValue }
633
630
  };
634
631
  }
635
632
  };
636
- var cohereTextEmbeddingResponseSchema = z3.object({
637
- embeddings: z3.object({
638
- float: z3.array(z3.array(z3.number()))
633
+ var cohereTextEmbeddingResponseSchema = z4.object({
634
+ embeddings: z4.object({
635
+ float: z4.array(z4.array(z4.number()))
639
636
  }),
640
- meta: z3.object({
641
- billed_units: z3.object({
642
- input_tokens: z3.number()
637
+ meta: z4.object({
638
+ billed_units: z4.object({
639
+ input_tokens: z4.number()
643
640
  })
644
641
  })
645
642
  });
@@ -656,29 +653,32 @@ function createCohere(options = {}) {
656
653
  })}`,
657
654
  ...options.headers
658
655
  });
659
- const createChatModel = (modelId, settings = {}) => new CohereChatLanguageModel(modelId, settings, {
656
+ const createChatModel = (modelId) => new CohereChatLanguageModel(modelId, {
660
657
  provider: "cohere.chat",
661
658
  baseURL,
662
659
  headers: getHeaders,
663
660
  fetch: options.fetch
664
661
  });
665
- const createTextEmbeddingModel = (modelId, settings = {}) => new CohereEmbeddingModel(modelId, settings, {
662
+ const createTextEmbeddingModel = (modelId) => new CohereEmbeddingModel(modelId, {
666
663
  provider: "cohere.textEmbedding",
667
664
  baseURL,
668
665
  headers: getHeaders,
669
666
  fetch: options.fetch
670
667
  });
671
- const provider = function(modelId, settings) {
668
+ const provider = function(modelId) {
672
669
  if (new.target) {
673
670
  throw new Error(
674
671
  "The Cohere model function cannot be called with the new keyword."
675
672
  );
676
673
  }
677
- return createChatModel(modelId, settings);
674
+ return createChatModel(modelId);
678
675
  };
679
676
  provider.languageModel = createChatModel;
680
677
  provider.embedding = createTextEmbeddingModel;
681
678
  provider.textEmbeddingModel = createTextEmbeddingModel;
679
+ provider.imageModel = (modelId) => {
680
+ throw new NoSuchModelError({ modelId, modelType: "imageModel" });
681
+ };
682
682
  return provider;
683
683
  }
684
684
  var cohere = createCohere();