@core-ai/mistral 0.2.1 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.js +276 -62
- package/package.json +2 -2
package/dist/index.js
CHANGED
|
@@ -2,12 +2,18 @@
|
|
|
2
2
|
import { Mistral } from "@mistralai/mistralai";
|
|
3
3
|
|
|
4
4
|
// src/chat-model.ts
|
|
5
|
-
import {
|
|
5
|
+
import {
|
|
6
|
+
StructuredOutputNoObjectGeneratedError,
|
|
7
|
+
StructuredOutputParseError,
|
|
8
|
+
StructuredOutputValidationError,
|
|
9
|
+
createObjectStreamResult,
|
|
10
|
+
createStreamResult
|
|
11
|
+
} from "@core-ai/core-ai";
|
|
6
12
|
|
|
7
13
|
// src/chat-adapter.ts
|
|
8
|
-
import { MistralError } from "@mistralai/mistralai/models/errors";
|
|
9
14
|
import { zodToJsonSchema } from "zod-to-json-schema";
|
|
10
|
-
|
|
15
|
+
// Tool name used for forced structured-output calls when the caller supplies no
// (non-blank) `schemaName`.
var DEFAULT_STRUCTURED_OUTPUT_TOOL_NAME = "core_ai_generate_object";
// Tool description used when the caller supplies no `schemaDescription`.
var DEFAULT_STRUCTURED_OUTPUT_TOOL_DESCRIPTION = "Return a JSON object that matches the requested schema.";
|
|
11
17
|
/**
 * Converts a list of core-ai chat messages into Mistral request messages.
 *
 * @param {Array} messages - messages in conversation order.
 * @returns {Array} one converted entry per input message, produced by `convertMessage`
 *   (which receives the same element/index/array arguments `Array#map` supplies).
 */
function convertMessages(messages) {
  const toMistralMessage = (message, index, list) => convertMessage(message, index, list);
  return messages.map(toMistralMessage);
}
|
|
@@ -89,41 +95,69 @@ function convertToolChoice(choice) {
|
|
|
89
95
|
}
|
|
90
96
|
};
|
|
91
97
|
}
|
|
98
|
+
/**
 * Resolves the tool name used to force a structured-output tool call.
 *
 * @param {{ schemaName?: string }} options - structured-output options.
 * @returns {string} the trimmed `schemaName` when it is non-blank, otherwise
 *   `DEFAULT_STRUCTURED_OUTPUT_TOOL_NAME`.
 */
function getStructuredOutputToolName(options) {
  const candidate = options.schemaName?.trim();
  // A missing or whitespace-only schemaName falls back to the package default;
  // a trimmed string is truthy exactly when it is non-empty.
  return candidate ? candidate : DEFAULT_STRUCTURED_OUTPUT_TOOL_NAME;
}
|
|
105
|
+
/**
 * Turns structured-output options into plain chat options that force the model to
 * answer through a single tool whose parameters are the caller's schema.
 *
 * @param {object} options - structured-output options (`messages`, `schema`,
 *   optional `schemaName`, `schemaDescription`, `config`, `providerOptions`, `signal`).
 * @returns {object} chat options with one `structured_output` tool and a
 *   `toolChoice` pinned to that tool's name.
 */
function createStructuredOutputOptions(options) {
  const toolName = getStructuredOutputToolName(options);
  const structuredOutputTool = {
    name: toolName,
    description: options.schemaDescription ?? DEFAULT_STRUCTURED_OUTPUT_TOOL_DESCRIPTION,
    parameters: options.schema
  };
  const forcedChoice = { type: "tool", toolName };
  return {
    messages: options.messages,
    tools: { structured_output: structuredOutputTool },
    toolChoice: forcedChoice,
    config: options.config,
    providerOptions: options.providerOptions,
    signal: options.signal
  };
}
|
|
92
125
|
/**
 * Builds the body of a non-streaming Mistral chat-completion request.
 *
 * @param {string} modelId - Mistral model identifier.
 * @param {object} options - core-ai generate options.
 * @returns {object} request body; `options.providerOptions` keys override shared fields.
 */
function createGenerateRequest(modelId, options) {
  const request = createRequestBase(modelId, options);
  return mergeProviderOptions(request, options.providerOptions);
}
|
|
110
131
|
/**
 * Builds the body of a streaming Mistral chat-completion request.
 *
 * @param {string} modelId - Mistral model identifier.
 * @param {object} options - core-ai stream options.
 * @returns {object} request body with `stream: true`; `options.providerOptions`
 *   keys are applied last and may override any field, including `stream`.
 */
function createStreamRequest(modelId, options) {
  const streamingRequest = Object.assign(createRequestBase(modelId, options), { stream: true });
  return mergeProviderOptions(streamingRequest, options.providerOptions);
}
|
|
138
|
+
/**
 * Builds the request fields shared by streaming and non-streaming chat calls.
 *
 * @param {string} modelId - Mistral model identifier.
 * @param {object} options - core-ai options (`messages`, optional `tools`,
 *   `toolChoice`, `config`).
 * @returns {object} `model` and `messages`, plus `tools` only when at least one tool
 *   is defined, `toolChoice` only when one is given, and the mapped config fields.
 */
function createRequestBase(modelId, options) {
  const request = {
    model: modelId,
    messages: convertMessages(options.messages)
  };
  const hasTools = Boolean(options.tools) && Object.keys(options.tools).length > 0;
  if (hasTools) {
    request.tools = convertTools(options.tools);
  }
  if (options.toolChoice) {
    request.toolChoice = convertToolChoice(options.toolChoice);
  }
  return Object.assign(request, mapConfigToRequestFields(options.config));
}
|
|
147
|
+
/**
 * Maps core-ai sampling config onto Mistral request field names.
 *
 * Numeric settings are copied whenever they are not `undefined` (so `0` and `null`
 * are forwarded). `stopSequences` is copied to `stop` only when it is truthy.
 *
 * @param {object|undefined|null} config - core-ai generation config.
 * @returns {object} a new object containing only the fields that were set.
 */
function mapConfigToRequestFields(config) {
  const fields = {};
  if (config === null || config === void 0) {
    return fields;
  }
  if (config.temperature !== void 0) {
    fields.temperature = config.temperature;
  }
  if (config.maxTokens !== void 0) {
    fields.maxTokens = config.maxTokens;
  }
  if (config.topP !== void 0) {
    fields.topP = config.topP;
  }
  if (config.stopSequences) {
    fields.stop = config.stopSequences;
  }
  if (config.frequencyPenalty !== void 0) {
    fields.frequencyPenalty = config.frequencyPenalty;
  }
  if (config.presencePenalty !== void 0) {
    fields.presencePenalty = config.presencePenalty;
  }
  return fields;
}
|
|
157
|
+
/**
 * Overlays caller-supplied provider options onto a request body.
 *
 * @param {object} baseRequest - request built by this adapter.
 * @param {object|undefined} providerOptions - raw Mistral fields supplied by the caller.
 * @returns {object} `baseRequest` itself when there are no provider options, otherwise
 *   a new shallow copy in which provider option keys win.
 */
function mergeProviderOptions(baseRequest, providerOptions) {
  if (!providerOptions) {
    return baseRequest;
  }
  return { ...baseRequest, ...providerOptions };
}
|
|
129
163
|
function mapGenerateResponse(response) {
|
|
@@ -171,7 +205,10 @@ async function* transformStream(stream) {
|
|
|
171
205
|
};
|
|
172
206
|
}
|
|
173
207
|
if (choice.delta.toolCalls) {
|
|
174
|
-
for (const [
|
|
208
|
+
for (const [
|
|
209
|
+
position,
|
|
210
|
+
partialToolCall
|
|
211
|
+
] of choice.delta.toolCalls.entries()) {
|
|
175
212
|
const streamIndex = partialToolCall.index ?? position;
|
|
176
213
|
const current = bufferedToolCalls.get(streamIndex) ?? {
|
|
177
214
|
id: partialToolCall.id ?? `tool-${streamIndex}`,
|
|
@@ -291,7 +328,9 @@ function extractTextDeltas(content) {
|
|
|
291
328
|
if (!content || content.length === 0) {
|
|
292
329
|
return [];
|
|
293
330
|
}
|
|
294
|
-
return content.flatMap(
|
|
331
|
+
return content.flatMap(
|
|
332
|
+
(chunk) => chunk.type === "text" ? [chunk.text] : []
|
|
333
|
+
);
|
|
295
334
|
}
|
|
296
335
|
function serializeJsonObject(value) {
|
|
297
336
|
const objectValue = asObject(value);
|
|
@@ -317,9 +356,18 @@ function asObject(value) {
|
|
|
317
356
|
}
|
|
318
357
|
return {};
|
|
319
358
|
}
|
|
320
|
-
|
|
359
|
+
|
|
360
|
+
// src/mistral-error.ts
|
|
361
|
+
import { MistralError } from "@mistralai/mistralai/models/errors";
|
|
362
|
+
import { ProviderError } from "@core-ai/core-ai";
|
|
363
|
+
function wrapMistralError(error) {
|
|
321
364
|
if (error instanceof MistralError) {
|
|
322
|
-
return new ProviderError(
|
|
365
|
+
return new ProviderError(
|
|
366
|
+
error.message,
|
|
367
|
+
"mistral",
|
|
368
|
+
error.statusCode,
|
|
369
|
+
error
|
|
370
|
+
);
|
|
323
371
|
}
|
|
324
372
|
return new ProviderError(
|
|
325
373
|
error instanceof Error ? error.message : String(error),
|
|
@@ -331,35 +379,212 @@ function wrapError(error) {
|
|
|
331
379
|
|
|
332
380
|
// src/chat-model.ts
|
|
333
381
|
/**
 * Creates a core-ai chat model backed by the Mistral chat-completions API.
 *
 * `generate` and `stream` send plain chat requests. `generateObject` and
 * `streamObject` force a single structured-output tool call and validate the
 * result against the caller's schema.
 *
 * @param {object} client - Mistral SDK client; only `client.chat.complete` and
 *   `client.chat.stream` are used here.
 * @param {string} modelId - Mistral model identifier placed on every request.
 * @returns {object} chat model exposing `provider`, `modelId`, `generate`,
 *   `stream`, `generateObject` and `streamObject`.
 */
function createMistralChatModel(client, modelId) {
  const provider = "mistral";
  // Per-call SDK options. An AbortSignal cannot travel in the request body, so it is
  // handed to the SDK's underlying fetch. Without this, `options.signal` never
  // cancels the HTTP call.
  function createRequestOptions(signal) {
    return signal ? { fetchOptions: { signal } } : void 0;
  }
  // Runs one Mistral SDK call and converts anything it throws via wrapMistralError.
  async function callMistralChatApi(call) {
    try {
      return await call();
    } catch (error) {
      throw wrapMistralError(error);
    }
  }
  async function generateChat(options) {
    const request = createGenerateRequest(modelId, options);
    const requestOptions = createRequestOptions(options.signal);
    const response = await callMistralChatApi(
      () => client.chat.complete(request, requestOptions)
    );
    return mapGenerateResponse(response);
  }
  async function streamChat(options) {
    const request = createStreamRequest(modelId, options);
    const requestOptions = createRequestOptions(options.signal);
    // Only opening the stream is wrapped here; errors raised later, while the
    // stream is consumed, are not routed through callMistralChatApi.
    const stream = await callMistralChatApi(
      () => client.chat.stream(request, requestOptions)
    );
    return createStreamResult(transformStream(stream));
  }
  return {
    provider,
    modelId,
    generate: generateChat,
    stream: streamChat,
    async generateObject(options) {
      const structuredOptions = createStructuredOutputOptions(options);
      const result = await generateChat(structuredOptions);
      const toolName = getStructuredOutputToolName(options);
      const object = extractStructuredObject(
        result,
        options.schema,
        provider,
        toolName
      );
      return {
        object,
        finishReason: result.finishReason,
        usage: result.usage
      };
    },
    async streamObject(options) {
      const structuredOptions = createStructuredOutputOptions(options);
      const stream = await streamChat(structuredOptions);
      const toolName = getStructuredOutputToolName(options);
      return createObjectStreamResult(
        transformStructuredOutputStream(
          stream,
          options.schema,
          provider,
          toolName
        )
      );
    }
  };
}
|
|
440
|
+
/**
 * Pulls the structured object out of a completed chat result.
 *
 * Prefers the arguments of the tool call named `toolName`. Otherwise it falls back
 * to parsing the trimmed text content as JSON.
 *
 * @param {object} result - generate result with `toolCalls` and optional `content`.
 * @param {object} schema - schema the object must satisfy.
 * @param {string} provider - provider name attached to thrown errors.
 * @param {string} toolName - name of the forced structured-output tool.
 * @returns {*} the schema-validated object.
 * @throws {StructuredOutputNoObjectGeneratedError} when there is neither a matching
 *   tool call nor non-blank content.
 */
function extractStructuredObject(result, schema, provider, toolName) {
  const matchingCall = result.toolCalls.find((call) => call.name === toolName);
  if (matchingCall) {
    return validateStructuredToolArguments(schema, matchingCall.arguments, provider);
  }
  const trimmedContent = result.content?.trim();
  if (!trimmedContent) {
    throw new StructuredOutputNoObjectGeneratedError(
      "model did not emit a structured object payload",
      provider
    );
  }
  return parseAndValidateStructuredPayload(schema, trimmedContent, provider);
}
|
|
460
|
+
/**
 * Converts chat stream events into structured-object stream events.
 *
 * Text and tool-argument deltas are re-emitted as `object-delta`. A finished call of
 * the `toolName` tool is validated and emitted as `object`. On `finish`, if no object
 * has been produced yet, the buffered tool arguments (or, failing that, the buffered
 * text) are parsed and validated as a fallback before `finish` is re-emitted.
 *
 * @param {AsyncIterable<object>} stream - core-ai chat stream events.
 * @param {object} schema - schema the object must satisfy.
 * @param {string} provider - provider name attached to thrown errors.
 * @param {string} toolName - name of the forced structured-output tool.
 * @throws {StructuredOutputNoObjectGeneratedError} when `finish` arrives with no
 *   object and no non-blank fallback payload.
 */
async function* transformStructuredOutputStream(stream, schema, provider, toolName) {
  const argumentsByToolCall = new Map();
  let bufferedContent = "";
  let structuredObject;
  for await (const event of stream) {
    switch (event.type) {
      case "content-delta": {
        bufferedContent += event.text;
        yield { type: "object-delta", text: event.text };
        break;
      }
      case "tool-call-delta": {
        const soFar = argumentsByToolCall.get(event.toolCallId) ?? "";
        argumentsByToolCall.set(event.toolCallId, `${soFar}${event.argumentsDelta}`);
        yield { type: "object-delta", text: event.argumentsDelta };
        break;
      }
      case "tool-call-end": {
        // Calls to any other tool are ignored.
        if (event.toolCall.name !== toolName) {
          break;
        }
        structuredObject = validateStructuredToolArguments(
          schema,
          event.toolCall.arguments,
          provider
        );
        yield { type: "object", object: structuredObject };
        break;
      }
      case "finish": {
        if (structuredObject === void 0) {
          const fallbackPayload = getFallbackStructuredPayload(
            bufferedContent,
            argumentsByToolCall
          );
          if (!fallbackPayload) {
            throw new StructuredOutputNoObjectGeneratedError(
              "structured output stream ended without an object payload",
              provider
            );
          }
          structuredObject = parseAndValidateStructuredPayload(
            schema,
            fallbackPayload,
            provider
          );
          yield { type: "object", object: structuredObject };
        }
        yield {
          type: "finish",
          finishReason: event.finishReason,
          usage: event.usage
        };
        break;
      }
      default:
        break;
    }
  }
}
|
|
527
|
+
/**
 * Picks the payload to parse when a structured stream finished without a completed
 * structured tool call.
 *
 * @param {string} contentBuffer - all streamed text content.
 * @param {Map<string, string>} toolArgumentDeltas - accumulated argument text per tool call id.
 * @returns {string|undefined} the first non-blank accumulated tool-argument string (in
 *   Map insertion order), else the non-blank content, each trimmed; `undefined` if all are blank.
 */
function getFallbackStructuredPayload(contentBuffer, toolArgumentDeltas) {
  const candidates = [...toolArgumentDeltas.values(), contentBuffer];
  // Trim lazily so later candidates are only touched when earlier ones are blank.
  for (const candidate of candidates) {
    const trimmedCandidate = candidate.trim();
    if (trimmedCandidate.length > 0) {
      return trimmedCandidate;
    }
  }
  return void 0;
}
|
|
540
|
+
/**
 * Validates already-parsed tool-call arguments against the schema.
 *
 * @param {object} schema - schema the arguments must satisfy.
 * @param {*} toolArguments - parsed tool-call arguments.
 * @param {string} provider - provider name attached to thrown errors.
 * @returns {*} the validated object; the JSON serialization of the arguments is
 *   passed along as the raw output for error reporting.
 */
function validateStructuredToolArguments(schema, toolArguments, provider) {
  const serializedArguments = JSON.stringify(toolArguments);
  return validateStructuredObject(schema, toolArguments, provider, serializedArguments);
}
|
|
548
|
+
/**
 * Parses a raw JSON payload and validates the result against the schema.
 *
 * @param {object} schema - schema the payload must satisfy.
 * @param {string} rawPayload - JSON text.
 * @param {string} provider - provider name attached to thrown errors.
 * @returns {*} the validated object.
 */
function parseAndValidateStructuredPayload(schema, rawPayload, provider) {
  return validateStructuredObject(
    schema,
    parseJson(rawPayload, provider),
    provider,
    rawPayload
  );
}
|
|
552
|
+
/**
 * Parses model output as JSON.
 *
 * @param {string} rawOutput - text to parse.
 * @param {string} provider - provider name attached to the thrown error.
 * @returns {*} the parsed value.
 * @throws {StructuredOutputParseError} carrying `rawOutput` and the original
 *   `JSON.parse` error as `cause`.
 */
function parseJson(rawOutput, provider) {
  let parsedValue;
  try {
    parsedValue = JSON.parse(rawOutput);
  } catch (cause) {
    throw new StructuredOutputParseError(
      "failed to parse structured output as JSON",
      provider,
      { rawOutput, cause }
    );
  }
  return parsedValue;
}
|
|
566
|
+
/**
 * Validates a value with the schema's `safeParse`.
 *
 * @param {object} schema - schema exposing `safeParse` (Zod-style result object).
 * @param {*} value - candidate object.
 * @param {string} provider - provider name attached to the thrown error.
 * @param {string} rawOutput - raw text the value came from, kept for diagnostics.
 * @returns {*} `data` from a successful parse.
 * @throws {StructuredOutputValidationError} with the formatted issue list when
 *   validation fails.
 */
function validateStructuredObject(schema, value, provider, rawOutput) {
  const validation = schema.safeParse(value);
  if (!validation.success) {
    const issues = formatZodIssues(validation.error.issues);
    throw new StructuredOutputValidationError(
      "structured output does not match schema",
      provider,
      issues,
      { rawOutput }
    );
  }
  return validation.data;
}
|
|
580
|
+
/**
 * Renders validation issues as `"<path>: <message>"` strings.
 *
 * @param {Array<{path: Array, message: string}>} issues - schema validation issues.
 * @returns {string[]} one line per issue; path segments are joined with `.` and an
 *   empty path is shown as `<root>`.
 */
function formatZodIssues(issues) {
  const describeIssue = (issue) => {
    const location = issue.path.length === 0
      ? "<root>"
      : issue.path.map((segment) => String(segment)).join(".");
    return `${location}: ${issue.message}`;
  };
  return issues.map(describeIssue);
}
|
|
359
586
|
|
|
360
587
|
// src/embedding-model.ts
|
|
361
|
-
import { MistralError as MistralError2 } from "@mistralai/mistralai/models/errors";
|
|
362
|
-
import { ProviderError as ProviderError2 } from "@core-ai/core-ai";
|
|
363
588
|
function createMistralEmbeddingModel(client, modelId) {
|
|
364
589
|
return {
|
|
365
590
|
provider: "mistral",
|
|
@@ -383,22 +608,11 @@ function createMistralEmbeddingModel(client, modelId) {
|
|
|
383
608
|
}
|
|
384
609
|
};
|
|
385
610
|
} catch (error) {
|
|
386
|
-
throw
|
|
611
|
+
throw wrapMistralError(error);
|
|
387
612
|
}
|
|
388
613
|
}
|
|
389
614
|
};
|
|
390
615
|
}
|
|
391
|
-
function wrapError2(error) {
|
|
392
|
-
if (error instanceof MistralError2) {
|
|
393
|
-
return new ProviderError2(error.message, "mistral", error.statusCode, error);
|
|
394
|
-
}
|
|
395
|
-
return new ProviderError2(
|
|
396
|
-
error instanceof Error ? error.message : String(error),
|
|
397
|
-
"mistral",
|
|
398
|
-
void 0,
|
|
399
|
-
error
|
|
400
|
-
);
|
|
401
|
-
}
|
|
402
616
|
|
|
403
617
|
// src/provider.ts
|
|
404
618
|
function createMistral(options = {}) {
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@core-ai/mistral",
|
|
3
|
-
"version": "0.
|
|
3
|
+
"version": "0.3.0",
|
|
4
4
|
"description": "Mistral provider package for @core-ai/core-ai",
|
|
5
5
|
"license": "MIT",
|
|
6
6
|
"author": "Omnifact (https://omnifact.ai)",
|
|
@@ -42,7 +42,7 @@
|
|
|
42
42
|
"test:watch": "vitest"
|
|
43
43
|
},
|
|
44
44
|
"dependencies": {
|
|
45
|
-
"@core-ai/core-ai": "^0.
|
|
45
|
+
"@core-ai/core-ai": "^0.3.0",
|
|
46
46
|
"@mistralai/mistralai": "^1.14.0",
|
|
47
47
|
"zod-to-json-schema": "^3.25.1"
|
|
48
48
|
},
|