@core-ai/mistral 0.2.1 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (2)
  1. package/dist/index.js +276 -62
  2. package/package.json +2 -2
package/dist/index.js CHANGED
@@ -2,12 +2,18 @@
2
2
  import { Mistral } from "@mistralai/mistralai";
3
3
 
4
4
  // src/chat-model.ts
5
- import { createStreamResult } from "@core-ai/core-ai";
5
+ import {
6
+ StructuredOutputNoObjectGeneratedError,
7
+ StructuredOutputParseError,
8
+ StructuredOutputValidationError,
9
+ createObjectStreamResult,
10
+ createStreamResult
11
+ } from "@core-ai/core-ai";
6
12
 
7
13
  // src/chat-adapter.ts
8
- import { MistralError } from "@mistralai/mistralai/models/errors";
9
14
  import { zodToJsonSchema } from "zod-to-json-schema";
10
- import { ProviderError } from "@core-ai/core-ai";
15
// Function-tool name used for structured output when the caller gives no
// usable `schemaName`.
const DEFAULT_STRUCTURED_OUTPUT_TOOL_NAME = "core_ai_generate_object";
// Function-tool description used for structured output when the caller gives
// no `schemaDescription`.
const DEFAULT_STRUCTURED_OUTPUT_TOOL_DESCRIPTION = "Return a JSON object that matches the requested schema.";
11
17
  function convertMessages(messages) {
12
18
  return messages.map(convertMessage);
13
19
  }
@@ -89,41 +95,69 @@ function convertToolChoice(choice) {
89
95
  }
90
96
  };
91
97
  }
98
/**
 * Resolves the function-tool name used to request structured output.
 *
 * Uses the trimmed `options.schemaName` when it is non-blank, otherwise
 * DEFAULT_STRUCTURED_OUTPUT_TOOL_NAME.
 *
 * Every character outside [A-Za-z0-9_-] is replaced with "_". Mistral rejects
 * function names containing other characters (e.g. spaces or dots), so a
 * human-readable schema name such as "User Profile" would otherwise fail the
 * request. This function computes both the name sent in the request and the
 * name matched in the response, so the two stay identical.
 *
 * @param options - structured-output options; only `schemaName` is read.
 * @returns a non-empty tool name.
 */
function getStructuredOutputToolName(options) {
  const trimmedName = options.schemaName?.trim();
  if (!trimmedName) {
    return DEFAULT_STRUCTURED_OUTPUT_TOOL_NAME;
  }
  return trimmedName.replace(/[^A-Za-z0-9_-]/g, "_");
}
105
/**
 * Turns structured-output options into ordinary chat options that declare a
 * single function tool (parameters = `options.schema`) and set `toolChoice`
 * to that tool.
 *
 * Only `messages`, `config`, `providerOptions` and `signal` are copied from
 * `options`; no other fields are carried over.
 *
 * @param options - structured-output options (`schema`, optional
 *   `schemaName` / `schemaDescription`, plus the usual chat fields).
 * @returns chat options suitable for the plain generate/stream path.
 */
function createStructuredOutputOptions(options) {
  const toolName = getStructuredOutputToolName(options);
  const { messages } = options;
  const structuredOutputTool = {
    name: toolName,
    description: options.schemaDescription ?? DEFAULT_STRUCTURED_OUTPUT_TOOL_DESCRIPTION,
    parameters: options.schema
  };
  const { config, providerOptions, signal } = options;
  return {
    messages,
    tools: { structured_output: structuredOutputTool },
    toolChoice: { type: "tool", toolName },
    config,
    providerOptions,
    signal
  };
}
92
125
  function createGenerateRequest(modelId, options) {
93
126
  const baseRequest = {
94
- model: modelId,
95
- messages: convertMessages(options.messages),
96
- ...options.tools && Object.keys(options.tools).length > 0 ? { tools: convertTools(options.tools) } : {},
97
- ...options.toolChoice ? { toolChoice: convertToolChoice(options.toolChoice) } : {},
98
- ...options.config?.temperature !== void 0 ? { temperature: options.config.temperature } : {},
99
- ...options.config?.maxTokens !== void 0 ? { maxTokens: options.config.maxTokens } : {},
100
- ...options.config?.topP !== void 0 ? { topP: options.config.topP } : {},
101
- ...options.config?.stopSequences ? { stop: options.config.stopSequences } : {},
102
- ...options.config?.frequencyPenalty !== void 0 ? { frequencyPenalty: options.config.frequencyPenalty } : {},
103
- ...options.config?.presencePenalty !== void 0 ? { presencePenalty: options.config.presencePenalty } : {}
127
+ ...createRequestBase(modelId, options)
104
128
  };
105
- return options.providerOptions ? {
106
- ...baseRequest,
107
- ...options.providerOptions
108
- } : baseRequest;
129
+ return mergeProviderOptions(baseRequest, options.providerOptions);
109
130
  }
110
131
  function createStreamRequest(modelId, options) {
111
132
  const baseRequest = {
133
+ ...createRequestBase(modelId, options),
134
+ stream: true
135
+ };
136
+ return mergeProviderOptions(baseRequest, options.providerOptions);
137
+ }
138
/**
 * Builds the request fields shared by streaming and non-streaming calls.
 *
 * `tools` is included only when `options.tools` has at least one key, and
 * `toolChoice` only when it is set. Sampling settings come from
 * `mapConfigToRequestFields(options.config)` and are appended last.
 *
 * @param modelId - model identifier for the request.
 * @param options - core-ai chat options.
 * @returns a fresh request object.
 */
function createRequestBase(modelId, options) {
  const request = {
    model: modelId,
    messages: convertMessages(options.messages)
  };
  if (options.tools && Object.keys(options.tools).length > 0) {
    request.tools = convertTools(options.tools);
  }
  if (options.toolChoice) {
    request.toolChoice = convertToolChoice(options.toolChoice);
  }
  Object.assign(request, mapConfigToRequestFields(options.config));
  return request;
}
147
/**
 * Translates core-ai generation config into Mistral request fields.
 *
 * Numeric settings are copied when they are not `undefined`. `stopSequences`
 * is copied as `stop` whenever it is truthy, which includes an empty array.
 * Fields are added in a fixed order: temperature, maxTokens, topP, stop,
 * frequencyPenalty, presencePenalty.
 *
 * @param config - optional generation config; may be null/undefined.
 * @returns a new object containing only the fields that were set.
 */
function mapConfigToRequestFields(config) {
  const fields = {};
  if (config?.temperature !== void 0) {
    fields.temperature = config.temperature;
  }
  if (config?.maxTokens !== void 0) {
    fields.maxTokens = config.maxTokens;
  }
  if (config?.topP !== void 0) {
    fields.topP = config.topP;
  }
  if (config?.stopSequences) {
    fields.stop = config.stopSequences;
  }
  if (config?.frequencyPenalty !== void 0) {
    fields.frequencyPenalty = config.frequencyPenalty;
  }
  if (config?.presencePenalty !== void 0) {
    fields.presencePenalty = config.presencePenalty;
  }
  return fields;
}
157
/**
 * Overlays caller-supplied provider options on a request.
 *
 * @param baseRequest - request built by this module.
 * @param providerOptions - optional extra/override fields.
 * @returns `baseRequest` itself when `providerOptions` is falsy; otherwise a
 *   new object in which `providerOptions` wins on key conflicts.
 */
function mergeProviderOptions(baseRequest, providerOptions) {
  if (!providerOptions) {
    return baseRequest;
  }
  return { ...baseRequest, ...providerOptions };
}
129
163
  function mapGenerateResponse(response) {
@@ -171,7 +205,10 @@ async function* transformStream(stream) {
171
205
  };
172
206
  }
173
207
  if (choice.delta.toolCalls) {
174
- for (const [position, partialToolCall] of choice.delta.toolCalls.entries()) {
208
+ for (const [
209
+ position,
210
+ partialToolCall
211
+ ] of choice.delta.toolCalls.entries()) {
175
212
  const streamIndex = partialToolCall.index ?? position;
176
213
  const current = bufferedToolCalls.get(streamIndex) ?? {
177
214
  id: partialToolCall.id ?? `tool-${streamIndex}`,
@@ -291,7 +328,9 @@ function extractTextDeltas(content) {
291
328
  if (!content || content.length === 0) {
292
329
  return [];
293
330
  }
294
- return content.flatMap((chunk) => chunk.type === "text" ? [chunk.text] : []);
331
+ return content.flatMap(
332
+ (chunk) => chunk.type === "text" ? [chunk.text] : []
333
+ );
295
334
  }
296
335
  function serializeJsonObject(value) {
297
336
  const objectValue = asObject(value);
@@ -317,9 +356,18 @@ function asObject(value) {
317
356
  }
318
357
  return {};
319
358
  }
320
- function wrapError(error) {
359
+
360
+ // src/mistral-error.ts
361
+ import { MistralError } from "@mistralai/mistralai/models/errors";
362
+ import { ProviderError } from "@core-ai/core-ai";
363
+ function wrapMistralError(error) {
321
364
  if (error instanceof MistralError) {
322
- return new ProviderError(error.message, "mistral", error.statusCode, error);
365
+ return new ProviderError(
366
+ error.message,
367
+ "mistral",
368
+ error.statusCode,
369
+ error
370
+ );
323
371
  }
324
372
  return new ProviderError(
325
373
  error instanceof Error ? error.message : String(error),
@@ -331,35 +379,212 @@ function wrapError(error) {
331
379
 
332
380
  // src/chat-model.ts
333
381
  function createMistralChatModel(client, modelId) {
382
+ const provider = "mistral";
383
+ async function callMistralChatApi(call) {
384
+ try {
385
+ return await call();
386
+ } catch (error) {
387
+ throw wrapMistralError(error);
388
+ }
389
+ }
390
+ async function generateChat(options) {
391
+ const request = createGenerateRequest(modelId, options);
392
+ const response = await callMistralChatApi(
393
+ () => client.chat.complete(request)
394
+ );
395
+ return mapGenerateResponse(response);
396
+ }
397
+ async function streamChat(options) {
398
+ const request = createStreamRequest(modelId, options);
399
+ const stream = await callMistralChatApi(
400
+ () => client.chat.stream(request)
401
+ );
402
+ return createStreamResult(transformStream(stream));
403
+ }
334
404
  return {
335
- provider: "mistral",
405
+ provider,
336
406
  modelId,
337
- async generate(options) {
338
- try {
339
- const request = createGenerateRequest(modelId, options);
340
- const response = await client.chat.complete(request);
341
- return mapGenerateResponse(response);
342
- } catch (error) {
343
- throw wrapError(error);
344
- }
407
+ generate: generateChat,
408
+ stream: streamChat,
409
+ async generateObject(options) {
410
+ const structuredOptions = createStructuredOutputOptions(options);
411
+ const result = await generateChat(structuredOptions);
412
+ const toolName = getStructuredOutputToolName(options);
413
+ const object = extractStructuredObject(
414
+ result,
415
+ options.schema,
416
+ provider,
417
+ toolName
418
+ );
419
+ return {
420
+ object,
421
+ finishReason: result.finishReason,
422
+ usage: result.usage
423
+ };
345
424
  },
346
- async stream(options) {
347
- try {
348
- const request = createStreamRequest(modelId, options);
349
- const stream = await client.chat.stream(
350
- request
425
+ async streamObject(options) {
426
+ const structuredOptions = createStructuredOutputOptions(options);
427
+ const stream = await streamChat(structuredOptions);
428
+ const toolName = getStructuredOutputToolName(options);
429
+ return createObjectStreamResult(
430
+ transformStructuredOutputStream(
431
+ stream,
432
+ options.schema,
433
+ provider,
434
+ toolName
435
+ )
436
+ );
437
+ }
438
+ };
439
+ }
440
/**
 * Pulls the structured object out of a completed (non-streaming) result.
 *
 * The first tool call named `toolName` takes priority: its `arguments` are
 * validated against `schema`. Otherwise the trimmed text `content` is parsed
 * as JSON and validated.
 *
 * @throws StructuredOutputNoObjectGeneratedError when there is neither a
 *   matching tool call nor non-blank content. Parse and validation errors
 *   propagate from the helpers.
 */
function extractStructuredObject(result, schema, provider, toolName) {
  const matchingCall = result.toolCalls.find(
    (candidate) => candidate.name === toolName
  );
  if (matchingCall) {
    return validateStructuredToolArguments(schema, matchingCall.arguments, provider);
  }
  const textPayload = result.content?.trim();
  if (!textPayload) {
    throw new StructuredOutputNoObjectGeneratedError(
      "model did not emit a structured object payload",
      provider
    );
  }
  return parseAndValidateStructuredPayload(schema, textPayload, provider);
}
460
/**
 * Adapts a chat event stream into structured-output stream events.
 *
 * - `content-delta` and `tool-call-delta` events are forwarded as
 *   `object-delta` and also buffered (tool arguments per tool-call id).
 * - A `tool-call-end` for `toolName` is validated and emitted as `object`.
 *   Tool-call ends for other names are ignored.
 * - On `finish`, if no object was validated yet, the buffered tool arguments
 *   (preferred) or buffered text are parsed and validated as a fallback, then
 *   `finish` is forwarded.
 *
 * @throws StructuredOutputNoObjectGeneratedError when `finish` arrives with
 *   no object and nothing non-blank buffered. Parse and validation errors
 *   propagate from the helpers.
 */
async function* transformStructuredOutputStream(stream, schema, provider, toolName) {
  let bufferedText = "";
  const argumentsByToolCallId = /* @__PURE__ */ new Map();
  let validatedObject;
  for await (const event of stream) {
    switch (event.type) {
      case "content-delta": {
        bufferedText += event.text;
        yield { type: "object-delta", text: event.text };
        break;
      }
      case "tool-call-delta": {
        const soFar = argumentsByToolCallId.get(event.toolCallId) ?? "";
        argumentsByToolCallId.set(event.toolCallId, `${soFar}${event.argumentsDelta}`);
        yield { type: "object-delta", text: event.argumentsDelta };
        break;
      }
      case "tool-call-end": {
        if (event.toolCall.name !== toolName) {
          break;
        }
        validatedObject = validateStructuredToolArguments(
          schema,
          event.toolCall.arguments,
          provider
        );
        yield { type: "object", object: validatedObject };
        break;
      }
      case "finish": {
        if (validatedObject === void 0) {
          const fallbackPayload = getFallbackStructuredPayload(
            bufferedText,
            argumentsByToolCallId
          );
          if (!fallbackPayload) {
            throw new StructuredOutputNoObjectGeneratedError(
              "structured output stream ended without an object payload",
              provider
            );
          }
          validatedObject = parseAndValidateStructuredPayload(
            schema,
            fallbackPayload,
            provider
          );
          yield { type: "object", object: validatedObject };
        }
        yield {
          type: "finish",
          finishReason: event.finishReason,
          usage: event.usage
        };
        break;
      }
      default:
        // Other event kinds carry nothing for structured output.
        break;
    }
  }
}
527
/**
 * Picks the text to parse when a structured stream finished without a
 * validated tool call.
 *
 * Buffered tool-argument strings are checked first (in Map insertion order),
 * then the buffered content text.
 *
 * @returns the first candidate that is non-empty after trimming (returned
 *   trimmed), or `undefined` when every candidate is blank.
 */
function getFallbackStructuredPayload(contentBuffer, toolArgumentDeltas) {
  return [...toolArgumentDeltas.values(), contentBuffer]
    .map((candidate) => candidate.trim())
    .find((candidate) => candidate.length > 0);
}
540
/**
 * Validates already-parsed tool-call arguments against `schema`.
 *
 * The arguments are re-serialized with `JSON.stringify` so that a validation
 * error can report them as `rawOutput`.
 */
function validateStructuredToolArguments(schema, toolArguments, provider) {
  const serializedArguments = JSON.stringify(toolArguments);
  return validateStructuredObject(schema, toolArguments, provider, serializedArguments);
}
548
/**
 * Parses a raw JSON string and validates the result against `schema`.
 * The raw string is passed along as `rawOutput` for error reporting.
 */
function parseAndValidateStructuredPayload(schema, rawPayload, provider) {
  return validateStructuredObject(
    schema,
    parseJson(rawPayload, provider),
    provider,
    rawPayload
  );
}
552
/**
 * Parses `rawOutput` with `JSON.parse`.
 *
 * @throws StructuredOutputParseError (with `rawOutput` and the original
 *   error as `cause`) when the text is not valid JSON.
 */
function parseJson(rawOutput, provider) {
  let parsedValue;
  try {
    parsedValue = JSON.parse(rawOutput);
  } catch (parseFailure) {
    throw new StructuredOutputParseError(
      "failed to parse structured output as JSON",
      provider,
      { rawOutput, cause: parseFailure }
    );
  }
  return parsedValue;
}
566
/**
 * Checks `value` with `schema.safeParse` and returns the parsed data.
 *
 * @throws StructuredOutputValidationError carrying the formatted issue list
 *   and `rawOutput` when validation fails.
 */
function validateStructuredObject(schema, value, provider, rawOutput) {
  const outcome = schema.safeParse(value);
  if (!outcome.success) {
    throw new StructuredOutputValidationError(
      "structured output does not match schema",
      provider,
      formatZodIssues(outcome.error.issues),
      { rawOutput }
    );
  }
  return outcome.data;
}
580
/**
 * Renders validation issues as "path: message" strings.
 *
 * Path segments are stringified and joined with "."; an empty path is shown
 * as "<root>".
 */
function formatZodIssues(issues) {
  const describeIssue = (issue) => {
    const location = issue.path.length === 0
      ? "<root>"
      : issue.path.map(String).join(".");
    return `${location}: ${issue.message}`;
  };
  return issues.map(describeIssue);
}
359
586
 
360
587
  // src/embedding-model.ts
361
- import { MistralError as MistralError2 } from "@mistralai/mistralai/models/errors";
362
- import { ProviderError as ProviderError2 } from "@core-ai/core-ai";
363
588
  function createMistralEmbeddingModel(client, modelId) {
364
589
  return {
365
590
  provider: "mistral",
@@ -383,22 +608,11 @@ function createMistralEmbeddingModel(client, modelId) {
383
608
  }
384
609
  };
385
610
  } catch (error) {
386
- throw wrapError2(error);
611
+ throw wrapMistralError(error);
387
612
  }
388
613
  }
389
614
  };
390
615
  }
391
- function wrapError2(error) {
392
- if (error instanceof MistralError2) {
393
- return new ProviderError2(error.message, "mistral", error.statusCode, error);
394
- }
395
- return new ProviderError2(
396
- error instanceof Error ? error.message : String(error),
397
- "mistral",
398
- void 0,
399
- error
400
- );
401
- }
402
616
 
403
617
  // src/provider.ts
404
618
  function createMistral(options = {}) {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@core-ai/mistral",
3
- "version": "0.2.1",
3
+ "version": "0.3.0",
4
4
  "description": "Mistral provider package for @core-ai/core-ai",
5
5
  "license": "MIT",
6
6
  "author": "Omnifact (https://omnifact.ai)",
@@ -42,7 +42,7 @@
42
42
  "test:watch": "vitest"
43
43
  },
44
44
  "dependencies": {
45
- "@core-ai/core-ai": "^0.2.1",
45
+ "@core-ai/core-ai": "^0.3.0",
46
46
  "@mistralai/mistralai": "^1.14.0",
47
47
  "zod-to-json-schema": "^3.25.1"
48
48
  },