@ai-sdk/cohere 0.0.27 → 0.0.28
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +6 -0
- package/dist/index.d.mts +1 -2
- package/dist/index.d.ts +1 -2
- package/dist/index.js +233 -201
- package/dist/index.js.map +1 -1
- package/dist/index.mjs +234 -203
- package/dist/index.mjs.map +1 -1
- package/package.json +1 -1
package/dist/index.mjs
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
// src/cohere-provider.ts
|
|
2
2
|
import {
|
|
3
|
-
generateId,
|
|
4
3
|
loadApiKey,
|
|
5
4
|
withoutTrailingSlash
|
|
6
5
|
} from "@ai-sdk/provider-utils";
|
|
@@ -11,8 +10,8 @@ import {
|
|
|
11
10
|
} from "@ai-sdk/provider";
|
|
12
11
|
import {
|
|
13
12
|
combineHeaders,
|
|
13
|
+
createEventSourceResponseHandler,
|
|
14
14
|
createJsonResponseHandler,
|
|
15
|
-
createJsonStreamResponseHandler,
|
|
16
15
|
postJsonToApi
|
|
17
16
|
} from "@ai-sdk/provider-utils";
|
|
18
17
|
import { z as z2 } from "zod";
|
|
@@ -37,13 +36,13 @@ function convertToCohereChatPrompt(prompt) {
|
|
|
37
36
|
for (const { role, content } of prompt) {
|
|
38
37
|
switch (role) {
|
|
39
38
|
case "system": {
|
|
40
|
-
messages.push({ role: "
|
|
39
|
+
messages.push({ role: "system", content });
|
|
41
40
|
break;
|
|
42
41
|
}
|
|
43
42
|
case "user": {
|
|
44
43
|
messages.push({
|
|
45
|
-
role: "
|
|
46
|
-
|
|
44
|
+
role: "user",
|
|
45
|
+
content: content.map((part) => {
|
|
47
46
|
switch (part.type) {
|
|
48
47
|
case "text": {
|
|
49
48
|
return part.text;
|
|
@@ -69,8 +68,12 @@ function convertToCohereChatPrompt(prompt) {
|
|
|
69
68
|
}
|
|
70
69
|
case "tool-call": {
|
|
71
70
|
toolCalls.push({
|
|
72
|
-
|
|
73
|
-
|
|
71
|
+
id: part.toolCallId,
|
|
72
|
+
type: "function",
|
|
73
|
+
function: {
|
|
74
|
+
name: part.toolName,
|
|
75
|
+
arguments: JSON.stringify(part.args)
|
|
76
|
+
}
|
|
74
77
|
});
|
|
75
78
|
break;
|
|
76
79
|
}
|
|
@@ -81,31 +84,23 @@ function convertToCohereChatPrompt(prompt) {
|
|
|
81
84
|
}
|
|
82
85
|
}
|
|
83
86
|
messages.push({
|
|
84
|
-
role: "
|
|
85
|
-
|
|
87
|
+
role: "assistant",
|
|
88
|
+
// note: this is a workaround for a Cohere API bug
|
|
89
|
+
// that requires content to be provided
|
|
90
|
+
// even if there are tool calls
|
|
91
|
+
content: text !== "" ? text : "call tool",
|
|
86
92
|
tool_calls: toolCalls.length > 0 ? toolCalls : void 0
|
|
87
93
|
});
|
|
88
94
|
break;
|
|
89
95
|
}
|
|
90
96
|
case "tool": {
|
|
91
|
-
messages.push(
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
/*
|
|
97
|
-
Note: Currently the tool_results field requires we pass the parameters of the tool results again. It it is blank for two reasons:
|
|
98
|
-
|
|
99
|
-
1. The parameters are already present in chat_history as a tool message
|
|
100
|
-
2. The tool core message of the ai sdk does not include parameters
|
|
101
|
-
|
|
102
|
-
It is possible to traverse through the chat history and get the parameters by id but it's currently empty since there wasn't any degradation in the output when left blank.
|
|
103
|
-
*/
|
|
104
|
-
parameters: {}
|
|
105
|
-
},
|
|
106
|
-
outputs: [toolResult.result]
|
|
97
|
+
messages.push(
|
|
98
|
+
...content.map((toolResult) => ({
|
|
99
|
+
role: "tool",
|
|
100
|
+
content: JSON.stringify(toolResult.result),
|
|
101
|
+
tool_call_id: toolResult.toolCallId
|
|
107
102
|
}))
|
|
108
|
-
|
|
103
|
+
);
|
|
109
104
|
break;
|
|
110
105
|
}
|
|
111
106
|
default: {
|
|
@@ -146,77 +141,38 @@ function prepareTools(mode) {
|
|
|
146
141
|
const tools = ((_a = mode.tools) == null ? void 0 : _a.length) ? mode.tools : void 0;
|
|
147
142
|
const toolWarnings = [];
|
|
148
143
|
if (tools == null) {
|
|
149
|
-
return { tools: void 0,
|
|
144
|
+
return { tools: void 0, tool_choice: void 0, toolWarnings };
|
|
150
145
|
}
|
|
151
146
|
const cohereTools = [];
|
|
152
147
|
for (const tool of tools) {
|
|
153
148
|
if (tool.type === "provider-defined") {
|
|
154
149
|
toolWarnings.push({ type: "unsupported-tool", tool });
|
|
155
150
|
} else {
|
|
156
|
-
const { properties, required } = tool.parameters;
|
|
157
|
-
const parameterDefinitions = {};
|
|
158
|
-
if (properties) {
|
|
159
|
-
for (const [key, value] of Object.entries(properties)) {
|
|
160
|
-
if (typeof value === "object" && value !== null) {
|
|
161
|
-
const { type: JSONType, description } = value;
|
|
162
|
-
let type2;
|
|
163
|
-
if (typeof JSONType === "string") {
|
|
164
|
-
switch (JSONType) {
|
|
165
|
-
case "string":
|
|
166
|
-
type2 = "str";
|
|
167
|
-
break;
|
|
168
|
-
case "number":
|
|
169
|
-
type2 = "float";
|
|
170
|
-
break;
|
|
171
|
-
case "integer":
|
|
172
|
-
type2 = "int";
|
|
173
|
-
break;
|
|
174
|
-
case "boolean":
|
|
175
|
-
type2 = "bool";
|
|
176
|
-
break;
|
|
177
|
-
default:
|
|
178
|
-
throw new UnsupportedFunctionalityError2({
|
|
179
|
-
functionality: `Unsupported tool parameter type: ${JSONType}`
|
|
180
|
-
});
|
|
181
|
-
}
|
|
182
|
-
} else {
|
|
183
|
-
throw new UnsupportedFunctionalityError2({
|
|
184
|
-
functionality: `Unsupported tool parameter type: ${JSONType}`
|
|
185
|
-
});
|
|
186
|
-
}
|
|
187
|
-
parameterDefinitions[key] = {
|
|
188
|
-
required: required ? required.includes(key) : false,
|
|
189
|
-
type: type2,
|
|
190
|
-
description
|
|
191
|
-
};
|
|
192
|
-
}
|
|
193
|
-
}
|
|
194
|
-
}
|
|
195
151
|
cohereTools.push({
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
152
|
+
type: "function",
|
|
153
|
+
function: {
|
|
154
|
+
name: tool.name,
|
|
155
|
+
description: tool.description,
|
|
156
|
+
parameters: tool.parameters
|
|
157
|
+
}
|
|
199
158
|
});
|
|
200
159
|
}
|
|
201
160
|
}
|
|
202
161
|
const toolChoice = mode.toolChoice;
|
|
203
162
|
if (toolChoice == null) {
|
|
204
|
-
return { tools: cohereTools,
|
|
163
|
+
return { tools: cohereTools, tool_choice: void 0, toolWarnings };
|
|
205
164
|
}
|
|
206
165
|
const type = toolChoice.type;
|
|
207
166
|
switch (type) {
|
|
208
167
|
case "auto":
|
|
209
|
-
return { tools: cohereTools,
|
|
210
|
-
case "required":
|
|
211
|
-
return { tools: cohereTools, force_single_step: true, toolWarnings };
|
|
168
|
+
return { tools: cohereTools, tool_choice: type, toolWarnings };
|
|
212
169
|
case "none":
|
|
213
|
-
return { tools: void 0,
|
|
170
|
+
return { tools: void 0, tool_choice: "any", toolWarnings };
|
|
171
|
+
case "required":
|
|
214
172
|
case "tool":
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
toolWarnings
|
|
219
|
-
};
|
|
173
|
+
throw new UnsupportedFunctionalityError2({
|
|
174
|
+
functionality: `Unsupported tool choice type: ${type}`
|
|
175
|
+
});
|
|
220
176
|
default: {
|
|
221
177
|
const _exhaustiveCheck = type;
|
|
222
178
|
throw new UnsupportedFunctionalityError2({
|
|
@@ -253,8 +209,6 @@ var CohereChatLanguageModel = class {
|
|
|
253
209
|
}) {
|
|
254
210
|
const type = mode.type;
|
|
255
211
|
const chatPrompt = convertToCohereChatPrompt(prompt);
|
|
256
|
-
const lastMessage = chatPrompt.at(-1);
|
|
257
|
-
const history = chatPrompt.slice(0, -1);
|
|
258
212
|
const baseArgs = {
|
|
259
213
|
// model id:
|
|
260
214
|
model: this.modelId,
|
|
@@ -272,17 +226,14 @@ var CohereChatLanguageModel = class {
|
|
|
272
226
|
// response format:
|
|
273
227
|
response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? { type: "json_object", schema: responseFormat.schema } : void 0,
|
|
274
228
|
// messages:
|
|
275
|
-
|
|
276
|
-
...(lastMessage == null ? void 0 : lastMessage.role) === "TOOL" ? { tool_results: lastMessage.tool_results } : {},
|
|
277
|
-
message: lastMessage ? lastMessage.role === "USER" ? lastMessage.message : void 0 : void 0
|
|
229
|
+
messages: chatPrompt
|
|
278
230
|
};
|
|
279
231
|
switch (type) {
|
|
280
232
|
case "regular": {
|
|
281
|
-
const { tools,
|
|
233
|
+
const { tools, tool_choice, toolWarnings } = prepareTools(mode);
|
|
282
234
|
return {
|
|
283
235
|
...baseArgs,
|
|
284
236
|
tools,
|
|
285
|
-
force_single_step,
|
|
286
237
|
warnings: toolWarnings
|
|
287
238
|
};
|
|
288
239
|
}
|
|
@@ -298,13 +249,41 @@ var CohereChatLanguageModel = class {
|
|
|
298
249
|
}
|
|
299
250
|
default: {
|
|
300
251
|
const _exhaustiveCheck = type;
|
|
301
|
-
throw new
|
|
252
|
+
throw new UnsupportedFunctionalityError3({
|
|
253
|
+
functionality: `Unsupported mode: ${_exhaustiveCheck}`
|
|
254
|
+
});
|
|
302
255
|
}
|
|
303
256
|
}
|
|
304
257
|
}
|
|
258
|
+
concatenateMessageText(messages) {
|
|
259
|
+
return messages.filter(
|
|
260
|
+
(message) => "content" in message
|
|
261
|
+
).map((message) => message.content).join("");
|
|
262
|
+
}
|
|
263
|
+
/*
|
|
264
|
+
Remove `additionalProperties` and `$schema` from the `parameters` object of each tool.
|
|
265
|
+
Though these are part of JSON schema, Cohere chokes if we include them in the request.
|
|
266
|
+
*/
|
|
267
|
+
// TODO(shaper): Look at defining a type to simplify the params here and a couple of other places.
|
|
268
|
+
removeJsonSchemaExtras(tools) {
|
|
269
|
+
return tools.map((tool) => {
|
|
270
|
+
if (tool.type === "function" && tool.function.parameters && typeof tool.function.parameters === "object") {
|
|
271
|
+
const { additionalProperties, $schema, ...restParameters } = tool.function.parameters;
|
|
272
|
+
return {
|
|
273
|
+
...tool,
|
|
274
|
+
function: {
|
|
275
|
+
...tool.function,
|
|
276
|
+
parameters: restParameters
|
|
277
|
+
}
|
|
278
|
+
};
|
|
279
|
+
}
|
|
280
|
+
return tool;
|
|
281
|
+
});
|
|
282
|
+
}
|
|
305
283
|
async doGenerate(options) {
|
|
306
|
-
var _a;
|
|
284
|
+
var _a, _b, _c, _d;
|
|
307
285
|
const { warnings, ...args } = this.getArgs(options);
|
|
286
|
+
args.tools = args.tools && this.removeJsonSchemaExtras(args.tools);
|
|
308
287
|
const { responseHeaders, value: response } = await postJsonToApi({
|
|
309
288
|
url: `${this.config.baseURL}/chat`,
|
|
310
289
|
headers: combineHeaders(this.config.headers(), options.headers),
|
|
@@ -316,30 +295,28 @@ var CohereChatLanguageModel = class {
|
|
|
316
295
|
abortSignal: options.abortSignal,
|
|
317
296
|
fetch: this.config.fetch
|
|
318
297
|
});
|
|
319
|
-
const {
|
|
320
|
-
const generateId2 = this.config.generateId;
|
|
298
|
+
const { messages, ...rawSettings } = args;
|
|
321
299
|
return {
|
|
322
|
-
text: response.text,
|
|
323
|
-
toolCalls: response.tool_calls ? response.tool_calls.map((toolCall) => ({
|
|
324
|
-
toolCallId:
|
|
325
|
-
toolName: toolCall.name,
|
|
326
|
-
args:
|
|
300
|
+
text: (_c = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text) != null ? _c : "",
|
|
301
|
+
toolCalls: response.message.tool_calls ? response.message.tool_calls.map((toolCall) => ({
|
|
302
|
+
toolCallId: toolCall.id,
|
|
303
|
+
toolName: toolCall.function.name,
|
|
304
|
+
args: toolCall.function.arguments,
|
|
327
305
|
toolCallType: "function"
|
|
328
306
|
})) : [],
|
|
329
307
|
finishReason: mapCohereFinishReason(response.finish_reason),
|
|
330
308
|
usage: {
|
|
331
|
-
promptTokens: response.
|
|
332
|
-
completionTokens: response.
|
|
309
|
+
promptTokens: response.usage.tokens.input_tokens,
|
|
310
|
+
completionTokens: response.usage.tokens.output_tokens
|
|
333
311
|
},
|
|
334
312
|
rawCall: {
|
|
335
313
|
rawPrompt: {
|
|
336
|
-
|
|
337
|
-
message
|
|
314
|
+
messages
|
|
338
315
|
},
|
|
339
316
|
rawSettings
|
|
340
317
|
},
|
|
341
318
|
response: {
|
|
342
|
-
id: (
|
|
319
|
+
id: (_d = response.generation_id) != null ? _d : void 0
|
|
343
320
|
},
|
|
344
321
|
rawResponse: { headers: responseHeaders },
|
|
345
322
|
warnings,
|
|
@@ -348,26 +325,30 @@ var CohereChatLanguageModel = class {
|
|
|
348
325
|
}
|
|
349
326
|
async doStream(options) {
|
|
350
327
|
const { warnings, ...args } = this.getArgs(options);
|
|
328
|
+
args.tools = args.tools && this.removeJsonSchemaExtras(args.tools);
|
|
351
329
|
const body = { ...args, stream: true };
|
|
352
330
|
const { responseHeaders, value: response } = await postJsonToApi({
|
|
353
331
|
url: `${this.config.baseURL}/chat`,
|
|
354
332
|
headers: combineHeaders(this.config.headers(), options.headers),
|
|
355
333
|
body,
|
|
356
334
|
failedResponseHandler: cohereFailedResponseHandler,
|
|
357
|
-
successfulResponseHandler:
|
|
335
|
+
successfulResponseHandler: createEventSourceResponseHandler(
|
|
358
336
|
cohereChatChunkSchema
|
|
359
337
|
),
|
|
360
338
|
abortSignal: options.abortSignal,
|
|
361
339
|
fetch: this.config.fetch
|
|
362
340
|
});
|
|
363
|
-
const {
|
|
341
|
+
const { messages, ...rawSettings } = args;
|
|
364
342
|
let finishReason = "unknown";
|
|
365
343
|
let usage = {
|
|
366
344
|
promptTokens: Number.NaN,
|
|
367
345
|
completionTokens: Number.NaN
|
|
368
346
|
};
|
|
369
|
-
|
|
370
|
-
|
|
347
|
+
let pendingToolCallDelta = {
|
|
348
|
+
toolCallId: "",
|
|
349
|
+
toolName: "",
|
|
350
|
+
argsTextDelta: ""
|
|
351
|
+
};
|
|
371
352
|
return {
|
|
372
353
|
stream: response.pipeThrough(
|
|
373
354
|
new TransformStream({
|
|
@@ -379,69 +360,68 @@ var CohereChatLanguageModel = class {
|
|
|
379
360
|
return;
|
|
380
361
|
}
|
|
381
362
|
const value = chunk.value;
|
|
382
|
-
const type = value.
|
|
363
|
+
const type = value.type;
|
|
383
364
|
switch (type) {
|
|
384
|
-
case "
|
|
365
|
+
case "content-delta": {
|
|
385
366
|
controller.enqueue({
|
|
386
367
|
type: "text-delta",
|
|
387
|
-
textDelta: value.text
|
|
368
|
+
textDelta: value.delta.message.content.text
|
|
388
369
|
});
|
|
389
370
|
return;
|
|
390
371
|
}
|
|
391
|
-
case "tool-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
type: "tool-call-delta",
|
|
405
|
-
toolCallType: "function",
|
|
406
|
-
toolCallId: toolCalls[index].toolCallId,
|
|
407
|
-
toolName: toolCalls[index].toolName,
|
|
408
|
-
argsTextDelta: ""
|
|
409
|
-
});
|
|
410
|
-
} else if (value.tool_call_delta.parameters) {
|
|
411
|
-
controller.enqueue({
|
|
412
|
-
type: "tool-call-delta",
|
|
413
|
-
toolCallType: "function",
|
|
414
|
-
toolCallId: toolCalls[index].toolCallId,
|
|
415
|
-
toolName: toolCalls[index].toolName,
|
|
416
|
-
argsTextDelta: value.tool_call_delta.parameters
|
|
417
|
-
});
|
|
418
|
-
}
|
|
419
|
-
}
|
|
372
|
+
case "tool-call-start": {
|
|
373
|
+
pendingToolCallDelta = {
|
|
374
|
+
toolCallId: value.delta.message.tool_calls.id,
|
|
375
|
+
toolName: value.delta.message.tool_calls.function.name,
|
|
376
|
+
argsTextDelta: value.delta.message.tool_calls.function.arguments
|
|
377
|
+
};
|
|
378
|
+
controller.enqueue({
|
|
379
|
+
type: "tool-call-delta",
|
|
380
|
+
toolCallId: pendingToolCallDelta.toolCallId,
|
|
381
|
+
toolName: pendingToolCallDelta.toolName,
|
|
382
|
+
toolCallType: "function",
|
|
383
|
+
argsTextDelta: pendingToolCallDelta.argsTextDelta
|
|
384
|
+
});
|
|
420
385
|
return;
|
|
421
386
|
}
|
|
422
|
-
case "tool-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
});
|
|
432
|
-
}
|
|
387
|
+
case "tool-call-delta": {
|
|
388
|
+
pendingToolCallDelta.argsTextDelta += value.delta.message.tool_calls.function.arguments;
|
|
389
|
+
controller.enqueue({
|
|
390
|
+
type: "tool-call-delta",
|
|
391
|
+
toolCallId: pendingToolCallDelta.toolCallId,
|
|
392
|
+
toolName: pendingToolCallDelta.toolName,
|
|
393
|
+
toolCallType: "function",
|
|
394
|
+
argsTextDelta: value.delta.message.tool_calls.function.arguments
|
|
395
|
+
});
|
|
433
396
|
return;
|
|
434
397
|
}
|
|
435
|
-
case "
|
|
398
|
+
case "tool-call-end": {
|
|
399
|
+
controller.enqueue({
|
|
400
|
+
type: "tool-call",
|
|
401
|
+
toolCallId: pendingToolCallDelta.toolCallId,
|
|
402
|
+
toolName: pendingToolCallDelta.toolName,
|
|
403
|
+
toolCallType: "function",
|
|
404
|
+
args: JSON.stringify(
|
|
405
|
+
JSON.parse(pendingToolCallDelta.argsTextDelta)
|
|
406
|
+
)
|
|
407
|
+
});
|
|
408
|
+
pendingToolCallDelta = {
|
|
409
|
+
toolCallId: "",
|
|
410
|
+
toolName: "",
|
|
411
|
+
argsTextDelta: ""
|
|
412
|
+
};
|
|
413
|
+
return;
|
|
414
|
+
}
|
|
415
|
+
case "message-start": {
|
|
436
416
|
controller.enqueue({
|
|
437
417
|
type: "response-metadata",
|
|
438
|
-
id: (_a = value.
|
|
418
|
+
id: (_a = value.id) != null ? _a : void 0
|
|
439
419
|
});
|
|
440
420
|
return;
|
|
441
421
|
}
|
|
442
|
-
case "
|
|
443
|
-
finishReason = mapCohereFinishReason(value.finish_reason);
|
|
444
|
-
const tokens = value.
|
|
422
|
+
case "message-end": {
|
|
423
|
+
finishReason = mapCohereFinishReason(value.delta.finish_reason);
|
|
424
|
+
const tokens = value.delta.usage.tokens;
|
|
445
425
|
usage = {
|
|
446
426
|
promptTokens: tokens.input_tokens,
|
|
447
427
|
completionTokens: tokens.output_tokens
|
|
@@ -463,8 +443,7 @@ var CohereChatLanguageModel = class {
|
|
|
463
443
|
),
|
|
464
444
|
rawCall: {
|
|
465
445
|
rawPrompt: {
|
|
466
|
-
|
|
467
|
-
message
|
|
446
|
+
messages
|
|
468
447
|
},
|
|
469
448
|
rawSettings
|
|
470
449
|
},
|
|
@@ -476,68 +455,117 @@ var CohereChatLanguageModel = class {
|
|
|
476
455
|
};
|
|
477
456
|
var cohereChatResponseSchema = z2.object({
|
|
478
457
|
generation_id: z2.string().nullish(),
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
z2.
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
458
|
+
message: z2.object({
|
|
459
|
+
role: z2.string(),
|
|
460
|
+
content: z2.array(
|
|
461
|
+
z2.object({
|
|
462
|
+
type: z2.string(),
|
|
463
|
+
text: z2.string()
|
|
464
|
+
})
|
|
465
|
+
).nullish(),
|
|
466
|
+
tool_calls: z2.array(
|
|
467
|
+
z2.object({
|
|
468
|
+
id: z2.string(),
|
|
469
|
+
type: z2.literal("function"),
|
|
470
|
+
function: z2.object({
|
|
471
|
+
name: z2.string(),
|
|
472
|
+
arguments: z2.string()
|
|
473
|
+
})
|
|
474
|
+
})
|
|
475
|
+
).nullish()
|
|
476
|
+
}),
|
|
486
477
|
finish_reason: z2.string(),
|
|
487
|
-
|
|
478
|
+
usage: z2.object({
|
|
479
|
+
billed_units: z2.object({
|
|
480
|
+
input_tokens: z2.number(),
|
|
481
|
+
output_tokens: z2.number()
|
|
482
|
+
}),
|
|
488
483
|
tokens: z2.object({
|
|
489
484
|
input_tokens: z2.number(),
|
|
490
485
|
output_tokens: z2.number()
|
|
491
486
|
})
|
|
492
487
|
})
|
|
493
488
|
});
|
|
494
|
-
var cohereChatChunkSchema = z2.discriminatedUnion("
|
|
495
|
-
z2.object({
|
|
496
|
-
event_type: z2.literal("stream-start"),
|
|
497
|
-
generation_id: z2.string().nullish()
|
|
498
|
-
}),
|
|
489
|
+
var cohereChatChunkSchema = z2.discriminatedUnion("type", [
|
|
499
490
|
z2.object({
|
|
500
|
-
|
|
491
|
+
type: z2.literal("citation-start")
|
|
501
492
|
}),
|
|
502
493
|
z2.object({
|
|
503
|
-
|
|
494
|
+
type: z2.literal("citation-end")
|
|
504
495
|
}),
|
|
505
496
|
z2.object({
|
|
506
|
-
|
|
507
|
-
text: z2.string()
|
|
497
|
+
type: z2.literal("content-start")
|
|
508
498
|
}),
|
|
509
499
|
z2.object({
|
|
510
|
-
|
|
500
|
+
type: z2.literal("content-delta"),
|
|
501
|
+
delta: z2.object({
|
|
502
|
+
message: z2.object({
|
|
503
|
+
content: z2.object({
|
|
504
|
+
text: z2.string()
|
|
505
|
+
})
|
|
506
|
+
})
|
|
507
|
+
})
|
|
511
508
|
}),
|
|
512
509
|
z2.object({
|
|
513
|
-
|
|
514
|
-
tool_calls: z2.array(
|
|
515
|
-
z2.object({
|
|
516
|
-
name: z2.string(),
|
|
517
|
-
parameters: z2.unknown({})
|
|
518
|
-
})
|
|
519
|
-
)
|
|
510
|
+
type: z2.literal("content-end")
|
|
520
511
|
}),
|
|
521
512
|
z2.object({
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
tool_call_delta: z2.object({
|
|
525
|
-
index: z2.number(),
|
|
526
|
-
name: z2.string().optional(),
|
|
527
|
-
parameters: z2.string().optional()
|
|
528
|
-
}).optional()
|
|
513
|
+
type: z2.literal("message-start"),
|
|
514
|
+
id: z2.string().nullish()
|
|
529
515
|
}),
|
|
530
516
|
z2.object({
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
517
|
+
type: z2.literal("message-end"),
|
|
518
|
+
delta: z2.object({
|
|
519
|
+
finish_reason: z2.string(),
|
|
520
|
+
usage: z2.object({
|
|
535
521
|
tokens: z2.object({
|
|
536
522
|
input_tokens: z2.number(),
|
|
537
523
|
output_tokens: z2.number()
|
|
538
524
|
})
|
|
539
525
|
})
|
|
540
526
|
})
|
|
527
|
+
}),
|
|
528
|
+
// https://docs.cohere.com/v2/docs/streaming#tool-use-stream-events-for-tool-calling
|
|
529
|
+
z2.object({
|
|
530
|
+
type: z2.literal("tool-plan-delta"),
|
|
531
|
+
delta: z2.object({
|
|
532
|
+
message: z2.object({
|
|
533
|
+
tool_plan: z2.string()
|
|
534
|
+
})
|
|
535
|
+
})
|
|
536
|
+
}),
|
|
537
|
+
z2.object({
|
|
538
|
+
type: z2.literal("tool-call-start"),
|
|
539
|
+
delta: z2.object({
|
|
540
|
+
message: z2.object({
|
|
541
|
+
tool_calls: z2.object({
|
|
542
|
+
id: z2.string(),
|
|
543
|
+
type: z2.literal("function"),
|
|
544
|
+
function: z2.object({
|
|
545
|
+
name: z2.string(),
|
|
546
|
+
arguments: z2.string()
|
|
547
|
+
})
|
|
548
|
+
})
|
|
549
|
+
})
|
|
550
|
+
})
|
|
551
|
+
}),
|
|
552
|
+
// A single tool call's `arguments` stream in chunks and must be accumulated
|
|
553
|
+
// in a string and so the full tool object info can only be parsed once we see
|
|
554
|
+
// `tool-call-end`.
|
|
555
|
+
z2.object({
|
|
556
|
+
type: z2.literal("tool-call-delta"),
|
|
557
|
+
delta: z2.object({
|
|
558
|
+
message: z2.object({
|
|
559
|
+
tool_calls: z2.object({
|
|
560
|
+
function: z2.object({
|
|
561
|
+
arguments: z2.string()
|
|
562
|
+
})
|
|
563
|
+
})
|
|
564
|
+
})
|
|
565
|
+
})
|
|
566
|
+
}),
|
|
567
|
+
z2.object({
|
|
568
|
+
type: z2.literal("tool-call-end")
|
|
541
569
|
})
|
|
542
570
|
]);
|
|
543
571
|
|
|
@@ -582,6 +610,11 @@ var CohereEmbeddingModel = class {
|
|
|
582
610
|
headers: combineHeaders2(this.config.headers(), headers),
|
|
583
611
|
body: {
|
|
584
612
|
model: this.modelId,
|
|
613
|
+
// TODO(shaper): There are other embedding types. Do we need to support them?
|
|
614
|
+
// For now we only support 'float' embeddings which are also the only ones
|
|
615
|
+
// the Cohere API docs state are supported for all models.
|
|
616
|
+
// https://docs.cohere.com/v2/reference/embed#request.body.embedding_types
|
|
617
|
+
embedding_types: ["float"],
|
|
585
618
|
texts: values,
|
|
586
619
|
input_type: (_a = this.settings.inputType) != null ? _a : "search_query",
|
|
587
620
|
truncate: this.settings.truncate
|
|
@@ -594,14 +627,16 @@ var CohereEmbeddingModel = class {
|
|
|
594
627
|
fetch: this.config.fetch
|
|
595
628
|
});
|
|
596
629
|
return {
|
|
597
|
-
embeddings: response.embeddings,
|
|
630
|
+
embeddings: response.embeddings.float,
|
|
598
631
|
usage: { tokens: response.meta.billed_units.input_tokens },
|
|
599
632
|
rawResponse: { headers: responseHeaders }
|
|
600
633
|
};
|
|
601
634
|
}
|
|
602
635
|
};
|
|
603
636
|
var cohereTextEmbeddingResponseSchema = z3.object({
|
|
604
|
-
embeddings: z3.
|
|
637
|
+
embeddings: z3.object({
|
|
638
|
+
float: z3.array(z3.array(z3.number()))
|
|
639
|
+
}),
|
|
605
640
|
meta: z3.object({
|
|
606
641
|
billed_units: z3.object({
|
|
607
642
|
input_tokens: z3.number()
|
|
@@ -612,7 +647,7 @@ var cohereTextEmbeddingResponseSchema = z3.object({
|
|
|
612
647
|
// src/cohere-provider.ts
|
|
613
648
|
function createCohere(options = {}) {
|
|
614
649
|
var _a;
|
|
615
|
-
const baseURL = (_a = withoutTrailingSlash(options.baseURL)) != null ? _a : "https://api.cohere.com/
|
|
650
|
+
const baseURL = (_a = withoutTrailingSlash(options.baseURL)) != null ? _a : "https://api.cohere.com/v2";
|
|
616
651
|
const getHeaders = () => ({
|
|
617
652
|
Authorization: `Bearer ${loadApiKey({
|
|
618
653
|
apiKey: options.apiKey,
|
|
@@ -621,16 +656,12 @@ function createCohere(options = {}) {
|
|
|
621
656
|
})}`,
|
|
622
657
|
...options.headers
|
|
623
658
|
});
|
|
624
|
-
const createChatModel = (modelId, settings = {}) => {
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
generateId: (_a2 = options.generateId) != null ? _a2 : generateId,
|
|
631
|
-
fetch: options.fetch
|
|
632
|
-
});
|
|
633
|
-
};
|
|
659
|
+
const createChatModel = (modelId, settings = {}) => new CohereChatLanguageModel(modelId, settings, {
|
|
660
|
+
provider: "cohere.chat",
|
|
661
|
+
baseURL,
|
|
662
|
+
headers: getHeaders,
|
|
663
|
+
fetch: options.fetch
|
|
664
|
+
});
|
|
634
665
|
const createTextEmbeddingModel = (modelId, settings = {}) => new CohereEmbeddingModel(modelId, settings, {
|
|
635
666
|
provider: "cohere.textEmbedding",
|
|
636
667
|
baseURL,
|