@core-ai/mistral 0.2.1 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.js +290 -66
- package/package.json +2 -2
package/dist/index.js
CHANGED
|
@@ -2,12 +2,18 @@
|
|
|
2
2
|
import { Mistral } from "@mistralai/mistralai";
|
|
3
3
|
|
|
4
4
|
// src/chat-model.ts
|
|
5
|
-
import {
|
|
5
|
+
import {
|
|
6
|
+
StructuredOutputNoObjectGeneratedError,
|
|
7
|
+
StructuredOutputParseError,
|
|
8
|
+
StructuredOutputValidationError,
|
|
9
|
+
createObjectStreamResult,
|
|
10
|
+
createStreamResult
|
|
11
|
+
} from "@core-ai/core-ai";
|
|
6
12
|
|
|
7
13
|
// src/chat-adapter.ts
|
|
8
|
-
import { MistralError } from "@mistralai/mistralai/models/errors";
|
|
9
14
|
import { zodToJsonSchema } from "zod-to-json-schema";
|
|
10
|
-
|
|
15
|
+
var DEFAULT_STRUCTURED_OUTPUT_TOOL_NAME = "core_ai_generate_object";
|
|
16
|
+
var DEFAULT_STRUCTURED_OUTPUT_TOOL_DESCRIPTION = "Return a JSON object that matches the requested schema.";
|
|
11
17
|
function convertMessages(messages) {
|
|
12
18
|
return messages.map(convertMessage);
|
|
13
19
|
}
|
|
@@ -89,41 +95,69 @@ function convertToolChoice(choice) {
|
|
|
89
95
|
}
|
|
90
96
|
};
|
|
91
97
|
}
|
|
98
|
+
function getStructuredOutputToolName(options) {
|
|
99
|
+
const trimmedName = options.schemaName?.trim();
|
|
100
|
+
if (trimmedName && trimmedName.length > 0) {
|
|
101
|
+
return trimmedName;
|
|
102
|
+
}
|
|
103
|
+
return DEFAULT_STRUCTURED_OUTPUT_TOOL_NAME;
|
|
104
|
+
}
|
|
105
|
+
function createStructuredOutputOptions(options) {
|
|
106
|
+
const toolName = getStructuredOutputToolName(options);
|
|
107
|
+
return {
|
|
108
|
+
messages: options.messages,
|
|
109
|
+
tools: {
|
|
110
|
+
structured_output: {
|
|
111
|
+
name: toolName,
|
|
112
|
+
description: options.schemaDescription ?? DEFAULT_STRUCTURED_OUTPUT_TOOL_DESCRIPTION,
|
|
113
|
+
parameters: options.schema
|
|
114
|
+
}
|
|
115
|
+
},
|
|
116
|
+
toolChoice: {
|
|
117
|
+
type: "tool",
|
|
118
|
+
toolName
|
|
119
|
+
},
|
|
120
|
+
config: options.config,
|
|
121
|
+
providerOptions: options.providerOptions,
|
|
122
|
+
signal: options.signal
|
|
123
|
+
};
|
|
124
|
+
}
|
|
92
125
|
function createGenerateRequest(modelId, options) {
|
|
93
126
|
const baseRequest = {
|
|
94
|
-
model: modelId,
|
|
95
|
-
messages: convertMessages(options.messages),
|
|
96
|
-
...options.tools && Object.keys(options.tools).length > 0 ? { tools: convertTools(options.tools) } : {},
|
|
97
|
-
...options.toolChoice ? { toolChoice: convertToolChoice(options.toolChoice) } : {},
|
|
98
|
-
...options.config?.temperature !== void 0 ? { temperature: options.config.temperature } : {},
|
|
99
|
-
...options.config?.maxTokens !== void 0 ? { maxTokens: options.config.maxTokens } : {},
|
|
100
|
-
...options.config?.topP !== void 0 ? { topP: options.config.topP } : {},
|
|
101
|
-
...options.config?.stopSequences ? { stop: options.config.stopSequences } : {},
|
|
102
|
-
...options.config?.frequencyPenalty !== void 0 ? { frequencyPenalty: options.config.frequencyPenalty } : {},
|
|
103
|
-
...options.config?.presencePenalty !== void 0 ? { presencePenalty: options.config.presencePenalty } : {}
|
|
127
|
+
...createRequestBase(modelId, options)
|
|
104
128
|
};
|
|
105
|
-
return options.providerOptions ? {
|
|
106
|
-
...baseRequest,
|
|
107
|
-
...options.providerOptions
|
|
108
|
-
} : baseRequest;
|
|
129
|
+
return mergeProviderOptions(baseRequest, options.providerOptions);
|
|
109
130
|
}
|
|
110
131
|
function createStreamRequest(modelId, options) {
|
|
111
132
|
const baseRequest = {
|
|
133
|
+
...createRequestBase(modelId, options),
|
|
134
|
+
stream: true
|
|
135
|
+
};
|
|
136
|
+
return mergeProviderOptions(baseRequest, options.providerOptions);
|
|
137
|
+
}
|
|
138
|
+
function createRequestBase(modelId, options) {
|
|
139
|
+
return {
|
|
112
140
|
model: modelId,
|
|
113
141
|
messages: convertMessages(options.messages),
|
|
114
|
-
stream: true,
|
|
115
142
|
...options.tools && Object.keys(options.tools).length > 0 ? { tools: convertTools(options.tools) } : {},
|
|
116
143
|
...options.toolChoice ? { toolChoice: convertToolChoice(options.toolChoice) } : {},
|
|
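117
|
-
...options.config?.temperature !== void 0 ? { temperature: options.config.temperature } : {},
|
|
118
|
-
...options.config?.maxTokens !== void 0 ? { maxTokens: options.config.maxTokens } : {},
|
|
119
|
-
...options.config?.topP !== void 0 ? { topP: options.config.topP } : {},
|
|
120
|
-
...options.config?.stopSequences ? { stop: options.config.stopSequences } : {},
|
|
121
|
-
...options.config?.frequencyPenalty !== void 0 ? { frequencyPenalty: options.config.frequencyPenalty } : {},
|
|
122
|
-
...options.config?.presencePenalty !== void 0 ? { presencePenalty: options.config.presencePenalty } : {}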
|
|
144
|
+
...mapConfigToRequestFields(options.config)
|
|
145
|
+
};
|
|
146
|
+
}
|
|
147
|
+
function mapConfigToRequestFields(config) {
|
|
148
|
+
return {
|
|
149
|
+
...config?.temperature !== void 0 ? { temperature: config.temperature } : {},
|
|
150
|
+
...config?.maxTokens !== void 0 ? { maxTokens: config.maxTokens } : {},
|
|
151
|
+
...config?.topP !== void 0 ? { topP: config.topP } : {},
|
|
152
|
+
...config?.stopSequences ? { stop: config.stopSequences } : {},
|
|
153
|
+
...config?.frequencyPenalty !== void 0 ? { frequencyPenalty: config.frequencyPenalty } : {},
|
|
154
|
+
...config?.presencePenalty !== void 0 ? { presencePenalty: config.presencePenalty } : {}
|
|
123
155
|
};
|
|
124
|
-
return options.providerOptions ? {
|
|
156
|
+
}
|
|
157
|
+
function mergeProviderOptions(baseRequest, providerOptions) {
|
|
158
|
+
return providerOptions ? {
|
|
125
159
|
...baseRequest,
|
|
126
|
-
...options.providerOptions
|
|
160
|
+
...providerOptions
|
|
127
161
|
} : baseRequest;
|
|
128
162
|
}
|
|
129
163
|
function mapGenerateResponse(response) {
|
|
@@ -152,8 +186,13 @@ async function* transformStream(stream) {
|
|
|
152
186
|
let usage = {
|
|
153
187
|
inputTokens: 0,
|
|
154
188
|
outputTokens: 0,
|
|
155
|
-
|
|
156
|
-
|
|
189
|
+
inputTokenDetails: {
|
|
190
|
+
cacheReadTokens: 0,
|
|
191
|
+
cacheWriteTokens: 0
|
|
192
|
+
},
|
|
193
|
+
outputTokenDetails: {
|
|
194
|
+
reasoningTokens: 0
|
|
195
|
+
}
|
|
157
196
|
};
|
|
158
197
|
for await (const event of stream) {
|
|
159
198
|
const chunk = event.data;
|
|
@@ -171,7 +210,10 @@ async function* transformStream(stream) {
|
|
|
171
210
|
};
|
|
172
211
|
}
|
|
173
212
|
if (choice.delta.toolCalls) {
|
|
174
|
-
for (const [position, partialToolCall] of choice.delta.toolCalls.entries()) {
|
|
213
|
+
for (const [
|
|
214
|
+
position,
|
|
215
|
+
partialToolCall
|
|
216
|
+
] of choice.delta.toolCalls.entries()) {
|
|
175
217
|
const streamIndex = partialToolCall.index ?? position;
|
|
176
218
|
const current = bufferedToolCalls.get(streamIndex) ?? {
|
|
177
219
|
id: partialToolCall.id ?? `tool-${streamIndex}`,
|
|
@@ -270,8 +312,13 @@ function mapUsage(usage) {
|
|
|
270
312
|
return {
|
|
271
313
|
inputTokens: usage?.promptTokens ?? 0,
|
|
272
314
|
outputTokens: usage?.completionTokens ?? 0,
|
|
273
|
-
|
|
274
|
-
|
|
315
|
+
inputTokenDetails: {
|
|
316
|
+
cacheReadTokens: 0,
|
|
317
|
+
cacheWriteTokens: 0
|
|
318
|
+
},
|
|
319
|
+
outputTokenDetails: {
|
|
320
|
+
reasoningTokens: 0
|
|
321
|
+
}
|
|
275
322
|
};
|
|
276
323
|
}
|
|
277
324
|
function extractTextContent(content) {
|
|
@@ -291,7 +338,9 @@ function extractTextDeltas(content) {
|
|
|
291
338
|
if (!content || content.length === 0) {
|
|
292
339
|
return [];
|
|
293
340
|
}
|
|
294
|
-
return content.flatMap((chunk) => chunk.type === "text" ? [chunk.text] : []);
|
|
341
|
+
return content.flatMap(
|
|
342
|
+
(chunk) => chunk.type === "text" ? [chunk.text] : []
|
|
343
|
+
);
|
|
295
344
|
}
|
|
296
345
|
function serializeJsonObject(value) {
|
|
297
346
|
const objectValue = asObject(value);
|
|
@@ -317,9 +366,18 @@ function asObject(value) {
|
|
|
317
366
|
}
|
|
318
367
|
return {};
|
|
319
368
|
}
|
|
320
|
-
function wrapError(error) {
|
|
369
|
+
|
|
370
|
+
// src/mistral-error.ts
|
|
371
|
+
import { MistralError } from "@mistralai/mistralai/models/errors";
|
|
372
|
+
import { ProviderError } from "@core-ai/core-ai";
|
|
373
|
+
function wrapMistralError(error) {
|
|
321
374
|
if (error instanceof MistralError) {
|
|
322
|
-
return new ProviderError(error.message, "mistral", error.statusCode, error);
|
|
375
|
+
return new ProviderError(
|
|
376
|
+
error.message,
|
|
377
|
+
"mistral",
|
|
378
|
+
error.statusCode,
|
|
379
|
+
error
|
|
380
|
+
);
|
|
323
381
|
}
|
|
324
382
|
return new ProviderError(
|
|
325
383
|
error instanceof Error ? error.message : String(error),
|
|
@@ -331,35 +389,212 @@ function wrapError(error) {
|
|
|
331
389
|
|
|
332
390
|
// src/chat-model.ts
|
|
333
391
|
function createMistralChatModel(client, modelId) {
|
|
392
|
+
const provider = "mistral";
|
|
393
|
+
async function callMistralChatApi(call) {
|
|
394
|
+
try {
|
|
395
|
+
return await call();
|
|
396
|
+
} catch (error) {
|
|
397
|
+
throw wrapMistralError(error);
|
|
398
|
+
}
|
|
399
|
+
}
|
|
400
|
+
async function generateChat(options) {
|
|
401
|
+
const request = createGenerateRequest(modelId, options);
|
|
402
|
+
const response = await callMistralChatApi(
|
|
403
|
+
() => client.chat.complete(request)
|
|
404
|
+
);
|
|
405
|
+
return mapGenerateResponse(response);
|
|
406
|
+
}
|
|
407
|
+
async function streamChat(options) {
|
|
408
|
+
const request = createStreamRequest(modelId, options);
|
|
409
|
+
const stream = await callMistralChatApi(
|
|
410
|
+
() => client.chat.stream(request)
|
|
411
|
+
);
|
|
412
|
+
return createStreamResult(transformStream(stream));
|
|
413
|
+
}
|
|
334
414
|
return {
|
|
335
|
-
provider: "mistral",
|
|
415
|
+
provider,
|
|
336
416
|
modelId,
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
417
|
+
generate: generateChat,
|
|
418
|
+
stream: streamChat,
|
|
419
|
+
async generateObject(options) {
|
|
420
|
+
const structuredOptions = createStructuredOutputOptions(options);
|
|
421
|
+
const result = await generateChat(structuredOptions);
|
|
422
|
+
const toolName = getStructuredOutputToolName(options);
|
|
423
|
+
const object = extractStructuredObject(
|
|
424
|
+
result,
|
|
425
|
+
options.schema,
|
|
426
|
+
provider,
|
|
427
|
+
toolName
|
|
428
|
+
);
|
|
429
|
+
return {
|
|
430
|
+
object,
|
|
431
|
+
finishReason: result.finishReason,
|
|
432
|
+
usage: result.usage
|
|
433
|
+
};
|
|
345
434
|
},
|
|
346
|
-
async
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
435
|
+
async streamObject(options) {
|
|
436
|
+
const structuredOptions = createStructuredOutputOptions(options);
|
|
437
|
+
const stream = await streamChat(structuredOptions);
|
|
438
|
+
const toolName = getStructuredOutputToolName(options);
|
|
439
|
+
return createObjectStreamResult(
|
|
440
|
+
transformStructuredOutputStream(
|
|
441
|
+
stream,
|
|
442
|
+
options.schema,
|
|
443
|
+
provider,
|
|
444
|
+
toolName
|
|
445
|
+
)
|
|
446
|
+
);
|
|
447
|
+
}
|
|
448
|
+
};
|
|
449
|
+
}
|
|
450
|
+
function extractStructuredObject(result, schema, provider, toolName) {
|
|
451
|
+
const structuredToolCall = result.toolCalls.find(
|
|
452
|
+
(toolCall) => toolCall.name === toolName
|
|
453
|
+
);
|
|
454
|
+
if (structuredToolCall) {
|
|
455
|
+
return validateStructuredToolArguments(
|
|
456
|
+
schema,
|
|
457
|
+
structuredToolCall.arguments,
|
|
458
|
+
provider
|
|
459
|
+
);
|
|
460
|
+
}
|
|
461
|
+
const rawOutput = result.content?.trim();
|
|
462
|
+
if (rawOutput && rawOutput.length > 0) {
|
|
463
|
+
return parseAndValidateStructuredPayload(schema, rawOutput, provider);
|
|
464
|
+
}
|
|
465
|
+
throw new StructuredOutputNoObjectGeneratedError(
|
|
466
|
+
"model did not emit a structured object payload",
|
|
467
|
+
provider
|
|
468
|
+
);
|
|
469
|
+
}
|
|
470
|
+
async function* transformStructuredOutputStream(stream, schema, provider, toolName) {
|
|
471
|
+
let validatedObject;
|
|
472
|
+
let contentBuffer = "";
|
|
473
|
+
const toolArgumentDeltas = /* @__PURE__ */ new Map();
|
|
474
|
+
for await (const event of stream) {
|
|
475
|
+
if (event.type === "content-delta") {
|
|
476
|
+
contentBuffer += event.text;
|
|
477
|
+
yield {
|
|
478
|
+
type: "object-delta",
|
|
479
|
+
text: event.text
|
|
480
|
+
};
|
|
481
|
+
continue;
|
|
482
|
+
}
|
|
483
|
+
if (event.type === "tool-call-delta") {
|
|
484
|
+
const previous = toolArgumentDeltas.get(event.toolCallId) ?? "";
|
|
485
|
+
toolArgumentDeltas.set(
|
|
486
|
+
event.toolCallId,
|
|
487
|
+
`${previous}${event.argumentsDelta}`
|
|
488
|
+
);
|
|
489
|
+
yield {
|
|
490
|
+
type: "object-delta",
|
|
491
|
+
text: event.argumentsDelta
|
|
492
|
+
};
|
|
493
|
+
continue;
|
|
494
|
+
}
|
|
495
|
+
if (event.type === "tool-call-end" && event.toolCall.name === toolName) {
|
|
496
|
+
validatedObject = validateStructuredToolArguments(
|
|
497
|
+
schema,
|
|
498
|
+
event.toolCall.arguments,
|
|
499
|
+
provider
|
|
500
|
+
);
|
|
501
|
+
yield {
|
|
502
|
+
type: "object",
|
|
503
|
+
object: validatedObject
|
|
504
|
+
};
|
|
505
|
+
continue;
|
|
506
|
+
}
|
|
507
|
+
if (event.type === "finish") {
|
|
508
|
+
if (validatedObject === void 0) {
|
|
509
|
+
const fallbackPayload = getFallbackStructuredPayload(
|
|
510
|
+
contentBuffer,
|
|
511
|
+
toolArgumentDeltas
|
|
351
512
|
);
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
513
|
+
if (!fallbackPayload) {
|
|
514
|
+
throw new StructuredOutputNoObjectGeneratedError(
|
|
515
|
+
"structured output stream ended without an object payload",
|
|
516
|
+
provider
|
|
517
|
+
);
|
|
518
|
+
}
|
|
519
|
+
validatedObject = parseAndValidateStructuredPayload(
|
|
520
|
+
schema,
|
|
521
|
+
fallbackPayload,
|
|
522
|
+
provider
|
|
523
|
+
);
|
|
524
|
+
yield {
|
|
525
|
+
type: "object",
|
|
526
|
+
object: validatedObject
|
|
527
|
+
};
|
|
355
528
|
}
|
|
529
|
+
yield {
|
|
530
|
+
type: "finish",
|
|
531
|
+
finishReason: event.finishReason,
|
|
532
|
+
usage: event.usage
|
|
533
|
+
};
|
|
356
534
|
}
|
|
357
|
-
}
|
|
535
|
+
}
|
|
536
|
+
}
|
|
537
|
+
function getFallbackStructuredPayload(contentBuffer, toolArgumentDeltas) {
|
|
538
|
+
for (const delta of toolArgumentDeltas.values()) {
|
|
539
|
+
const trimmed = delta.trim();
|
|
540
|
+
if (trimmed.length > 0) {
|
|
541
|
+
return trimmed;
|
|
542
|
+
}
|
|
543
|
+
}
|
|
544
|
+
const trimmedContent = contentBuffer.trim();
|
|
545
|
+
if (trimmedContent.length > 0) {
|
|
546
|
+
return trimmedContent;
|
|
547
|
+
}
|
|
548
|
+
return void 0;
|
|
549
|
+
}
|
|
550
|
+
function validateStructuredToolArguments(schema, toolArguments, provider) {
|
|
551
|
+
return validateStructuredObject(
|
|
552
|
+
schema,
|
|
553
|
+
toolArguments,
|
|
554
|
+
provider,
|
|
555
|
+
JSON.stringify(toolArguments)
|
|
556
|
+
);
|
|
557
|
+
}
|
|
558
|
+
function parseAndValidateStructuredPayload(schema, rawPayload, provider) {
|
|
559
|
+
const parsedPayload = parseJson(rawPayload, provider);
|
|
560
|
+
return validateStructuredObject(schema, parsedPayload, provider, rawPayload);
|
|
561
|
+
}
|
|
562
|
+
function parseJson(rawOutput, provider) {
|
|
563
|
+
try {
|
|
564
|
+
return JSON.parse(rawOutput);
|
|
565
|
+
} catch (error) {
|
|
566
|
+
throw new StructuredOutputParseError(
|
|
567
|
+
"failed to parse structured output as JSON",
|
|
568
|
+
provider,
|
|
569
|
+
{
|
|
570
|
+
rawOutput,
|
|
571
|
+
cause: error
|
|
572
|
+
}
|
|
573
|
+
);
|
|
574
|
+
}
|
|
575
|
+
}
|
|
576
|
+
function validateStructuredObject(schema, value, provider, rawOutput) {
|
|
577
|
+
const parsed = schema.safeParse(value);
|
|
578
|
+
if (parsed.success) {
|
|
579
|
+
return parsed.data;
|
|
580
|
+
}
|
|
581
|
+
throw new StructuredOutputValidationError(
|
|
582
|
+
"structured output does not match schema",
|
|
583
|
+
provider,
|
|
584
|
+
formatZodIssues(parsed.error.issues),
|
|
585
|
+
{
|
|
586
|
+
rawOutput
|
|
587
|
+
}
|
|
588
|
+
);
|
|
589
|
+
}
|
|
590
|
+
function formatZodIssues(issues) {
|
|
591
|
+
return issues.map((issue) => {
|
|
592
|
+
const path = issue.path.length > 0 ? issue.path.map((segment) => String(segment)).join(".") : "<root>";
|
|
593
|
+
return `${path}: ${issue.message}`;
|
|
594
|
+
});
|
|
358
595
|
}
|
|
359
596
|
|
|
360
597
|
// src/embedding-model.ts
|
|
361
|
-
import { MistralError as MistralError2 } from "@mistralai/mistralai/models/errors";
|
|
362
|
-
import { ProviderError as ProviderError2 } from "@core-ai/core-ai";
|
|
363
598
|
function createMistralEmbeddingModel(client, modelId) {
|
|
364
599
|
return {
|
|
365
600
|
provider: "mistral",
|
|
@@ -383,22 +618,11 @@ function createMistralEmbeddingModel(client, modelId) {
|
|
|
383
618
|
}
|
|
384
619
|
};
|
|
385
620
|
} catch (error) {
|
|
386
|
-
throw wrapError2(error);
|
|
621
|
+
throw wrapMistralError(error);
|
|
387
622
|
}
|
|
388
623
|
}
|
|
389
624
|
};
|
|
390
625
|
}
|
|
391
|
-
function wrapError2(error) {
|
|
392
|
-
if (error instanceof MistralError2) {
|
|
393
|
-
return new ProviderError2(error.message, "mistral", error.statusCode, error);
|
|
394
|
-
}
|
|
395
|
-
return new ProviderError2(
|
|
396
|
-
error instanceof Error ? error.message : String(error),
|
|
397
|
-
"mistral",
|
|
398
|
-
void 0,
|
|
399
|
-
error
|
|
400
|
-
);
|
|
401
|
-
}
|
|
402
626
|
|
|
403
627
|
// src/provider.ts
|
|
404
628
|
function createMistral(options = {}) {
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@core-ai/mistral",
|
|
3
|
-
"version": "0.2.1",
|
|
3
|
+
"version": "0.4.0",
|
|
4
4
|
"description": "Mistral provider package for @core-ai/core-ai",
|
|
5
5
|
"license": "MIT",
|
|
6
6
|
"author": "Omnifact (https://omnifact.ai)",
|
|
@@ -42,7 +42,7 @@
|
|
|
42
42
|
"test:watch": "vitest"
|
|
43
43
|
},
|
|
44
44
|
"dependencies": {
|
|
45
|
-
"@core-ai/core-ai": "^0.
|
|
45
|
+
"@core-ai/core-ai": "^0.4.0",
|
|
46
46
|
"@mistralai/mistralai": "^1.14.0",
|
|
47
47
|
"zod-to-json-schema": "^3.25.1"
|
|
48
48
|
},
|