@core-ai/mistral 0.2.1 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (2)
  1. package/dist/index.js +290 -66
  2. package/package.json +2 -2
package/dist/index.js CHANGED
@@ -2,12 +2,18 @@
2
2
  import { Mistral } from "@mistralai/mistralai";
3
3
 
4
4
  // src/chat-model.ts
5
- import { createStreamResult } from "@core-ai/core-ai";
5
+ import {
6
+ StructuredOutputNoObjectGeneratedError,
7
+ StructuredOutputParseError,
8
+ StructuredOutputValidationError,
9
+ createObjectStreamResult,
10
+ createStreamResult
11
+ } from "@core-ai/core-ai";
6
12
 
7
13
  // src/chat-adapter.ts
8
- import { MistralError } from "@mistralai/mistralai/models/errors";
9
14
  import { zodToJsonSchema } from "zod-to-json-schema";
10
- import { ProviderError } from "@core-ai/core-ai";
15
// Fallback tool name for structured output, used when the caller supplies
// no non-blank `schemaName`.
var DEFAULT_STRUCTURED_OUTPUT_TOOL_NAME = "core_ai_generate_object";
// Fallback tool description for structured output, used when the caller
// supplies no `schemaDescription`.
var DEFAULT_STRUCTURED_OUTPUT_TOOL_DESCRIPTION =
  "Return a JSON object that matches the requested schema.";
11
17
  function convertMessages(messages) {
12
18
  return messages.map(convertMessage);
13
19
  }
@@ -89,41 +95,69 @@ function convertToolChoice(choice) {
89
95
  }
90
96
  };
91
97
  }
98
/**
 * Resolves the tool name used to carry structured output.
 *
 * @param options structured-output options; only `schemaName` is read.
 * @returns the trimmed `schemaName` when it is non-blank, otherwise
 *   `DEFAULT_STRUCTURED_OUTPUT_TOOL_NAME`.
 */
function getStructuredOutputToolName(options) {
  const candidate = options.schemaName?.trim();
  return candidate ? candidate : DEFAULT_STRUCTURED_OUTPUT_TOOL_NAME;
}
105
/**
 * Turns structured-output options into regular chat options.
 *
 * The caller's schema is exposed as a single tool (registered under the
 * `structured_output` key), and `toolChoice` names that tool. Messages,
 * config, provider options and the abort signal are passed through unchanged.
 *
 * @param options structured-output options (`schema`, optional
 *   `schemaName` / `schemaDescription`, plus the usual chat options).
 * @returns chat options suitable for `generate` / `stream`.
 */
function createStructuredOutputOptions(options) {
  const toolName = getStructuredOutputToolName(options);
  const { messages, schema, schemaDescription, config, providerOptions, signal } = options;
  const structuredOutputTool = {
    name: toolName,
    description: schemaDescription ?? DEFAULT_STRUCTURED_OUTPUT_TOOL_DESCRIPTION,
    parameters: schema
  };
  return {
    messages,
    tools: { structured_output: structuredOutputTool },
    toolChoice: { type: "tool", toolName },
    config,
    providerOptions,
    signal
  };
}
92
125
  function createGenerateRequest(modelId, options) {
93
126
  const baseRequest = {
94
- model: modelId,
95
- messages: convertMessages(options.messages),
96
- ...options.tools && Object.keys(options.tools).length > 0 ? { tools: convertTools(options.tools) } : {},
97
- ...options.toolChoice ? { toolChoice: convertToolChoice(options.toolChoice) } : {},
98
- ...options.config?.temperature !== void 0 ? { temperature: options.config.temperature } : {},
99
- ...options.config?.maxTokens !== void 0 ? { maxTokens: options.config.maxTokens } : {},
100
- ...options.config?.topP !== void 0 ? { topP: options.config.topP } : {},
101
- ...options.config?.stopSequences ? { stop: options.config.stopSequences } : {},
102
- ...options.config?.frequencyPenalty !== void 0 ? { frequencyPenalty: options.config.frequencyPenalty } : {},
103
- ...options.config?.presencePenalty !== void 0 ? { presencePenalty: options.config.presencePenalty } : {}
127
+ ...createRequestBase(modelId, options)
104
128
  };
105
- return options.providerOptions ? {
106
- ...baseRequest,
107
- ...options.providerOptions
108
- } : baseRequest;
129
+ return mergeProviderOptions(baseRequest, options.providerOptions);
109
130
  }
110
131
  function createStreamRequest(modelId, options) {
111
132
  const baseRequest = {
133
+ ...createRequestBase(modelId, options),
134
+ stream: true
135
+ };
136
+ return mergeProviderOptions(baseRequest, options.providerOptions);
137
+ }
138
/**
 * Builds the request fields shared by generate and stream requests.
 *
 * `tools` is only included when `options.tools` has at least one key, and
 * `toolChoice` only when one is given. Config-derived sampling fields come
 * last.
 *
 * @param modelId Mistral model identifier.
 * @param options core-ai chat options.
 * @returns a plain request object (without `stream` or provider options).
 */
function createRequestBase(modelId, options) {
  const messages = convertMessages(options.messages);
  const hasTools = Boolean(options.tools) && Object.keys(options.tools).length > 0;
  const toolFields = hasTools ? { tools: convertTools(options.tools) } : {};
  const toolChoiceFields = options.toolChoice
    ? { toolChoice: convertToolChoice(options.toolChoice) }
    : {};
  const configFields = mapConfigToRequestFields(options.config);
  return {
    model: modelId,
    messages,
    ...toolFields,
    ...toolChoiceFields,
    ...configFields
  };
}
147
/**
 * Maps core-ai generation config onto Mistral request field names.
 *
 * Numeric fields are included whenever they are not `undefined` (so `0` is
 * kept); `stopSequences` becomes `stop` only when it is truthy.
 *
 * @param config optional core-ai generation config.
 * @returns an object containing only the fields that apply, in the order
 *   temperature, maxTokens, topP, stop, frequencyPenalty, presencePenalty.
 */
function mapConfigToRequestFields(config) {
  if (config == null) {
    return {};
  }
  const candidates = [
    ["temperature", config.temperature, config.temperature !== void 0],
    ["maxTokens", config.maxTokens, config.maxTokens !== void 0],
    ["topP", config.topP, config.topP !== void 0],
    ["stop", config.stopSequences, Boolean(config.stopSequences)],
    ["frequencyPenalty", config.frequencyPenalty, config.frequencyPenalty !== void 0],
    ["presencePenalty", config.presencePenalty, config.presencePenalty !== void 0]
  ];
  return Object.fromEntries(
    candidates
      .filter(([, , isPresent]) => isPresent)
      .map(([fieldName, fieldValue]) => [fieldName, fieldValue])
  );
}
157
/**
 * Overlays caller-supplied provider options on a request body.
 *
 * @param baseRequest request built by this adapter.
 * @param providerOptions optional raw Mistral request fields; they win on conflicts.
 * @returns `baseRequest` itself when there are no provider options,
 *   otherwise a new merged object.
 */
function mergeProviderOptions(baseRequest, providerOptions) {
  if (!providerOptions) {
    return baseRequest;
  }
  return { ...baseRequest, ...providerOptions };
}
129
163
  function mapGenerateResponse(response) {
@@ -152,8 +186,13 @@ async function* transformStream(stream) {
152
186
  let usage = {
153
187
  inputTokens: 0,
154
188
  outputTokens: 0,
155
- reasoningTokens: 0,
156
- totalTokens: 0
189
+ inputTokenDetails: {
190
+ cacheReadTokens: 0,
191
+ cacheWriteTokens: 0
192
+ },
193
+ outputTokenDetails: {
194
+ reasoningTokens: 0
195
+ }
157
196
  };
158
197
  for await (const event of stream) {
159
198
  const chunk = event.data;
@@ -171,7 +210,10 @@ async function* transformStream(stream) {
171
210
  };
172
211
  }
173
212
  if (choice.delta.toolCalls) {
174
- for (const [position, partialToolCall] of choice.delta.toolCalls.entries()) {
213
+ for (const [
214
+ position,
215
+ partialToolCall
216
+ ] of choice.delta.toolCalls.entries()) {
175
217
  const streamIndex = partialToolCall.index ?? position;
176
218
  const current = bufferedToolCalls.get(streamIndex) ?? {
177
219
  id: partialToolCall.id ?? `tool-${streamIndex}`,
@@ -270,8 +312,13 @@ function mapUsage(usage) {
270
312
  return {
271
313
  inputTokens: usage?.promptTokens ?? 0,
272
314
  outputTokens: usage?.completionTokens ?? 0,
273
- reasoningTokens: 0,
274
- totalTokens: usage?.totalTokens ?? 0
315
+ inputTokenDetails: {
316
+ cacheReadTokens: 0,
317
+ cacheWriteTokens: 0
318
+ },
319
+ outputTokenDetails: {
320
+ reasoningTokens: 0
321
+ }
275
322
  };
276
323
  }
277
324
  function extractTextContent(content) {
@@ -291,7 +338,9 @@ function extractTextDeltas(content) {
291
338
  if (!content || content.length === 0) {
292
339
  return [];
293
340
  }
294
- return content.flatMap((chunk) => chunk.type === "text" ? [chunk.text] : []);
341
+ return content.flatMap(
342
+ (chunk) => chunk.type === "text" ? [chunk.text] : []
343
+ );
295
344
  }
296
345
  function serializeJsonObject(value) {
297
346
  const objectValue = asObject(value);
@@ -317,9 +366,18 @@ function asObject(value) {
317
366
  }
318
367
  return {};
319
368
  }
320
- function wrapError(error) {
369
+
370
+ // src/mistral-error.ts
371
+ import { MistralError } from "@mistralai/mistralai/models/errors";
372
+ import { ProviderError } from "@core-ai/core-ai";
373
+ function wrapMistralError(error) {
321
374
  if (error instanceof MistralError) {
322
- return new ProviderError(error.message, "mistral", error.statusCode, error);
375
+ return new ProviderError(
376
+ error.message,
377
+ "mistral",
378
+ error.statusCode,
379
+ error
380
+ );
323
381
  }
324
382
  return new ProviderError(
325
383
  error instanceof Error ? error.message : String(error),
@@ -331,35 +389,212 @@ function wrapError(error) {
331
389
 
332
390
  // src/chat-model.ts
333
391
  function createMistralChatModel(client, modelId) {
392
+ const provider = "mistral";
393
+ async function callMistralChatApi(call) {
394
+ try {
395
+ return await call();
396
+ } catch (error) {
397
+ throw wrapMistralError(error);
398
+ }
399
+ }
400
+ async function generateChat(options) {
401
+ const request = createGenerateRequest(modelId, options);
402
+ const response = await callMistralChatApi(
403
+ () => client.chat.complete(request)
404
+ );
405
+ return mapGenerateResponse(response);
406
+ }
407
+ async function streamChat(options) {
408
+ const request = createStreamRequest(modelId, options);
409
+ const stream = await callMistralChatApi(
410
+ () => client.chat.stream(request)
411
+ );
412
+ return createStreamResult(transformStream(stream));
413
+ }
334
414
  return {
335
- provider: "mistral",
415
+ provider,
336
416
  modelId,
337
- async generate(options) {
338
- try {
339
- const request = createGenerateRequest(modelId, options);
340
- const response = await client.chat.complete(request);
341
- return mapGenerateResponse(response);
342
- } catch (error) {
343
- throw wrapError(error);
344
- }
417
+ generate: generateChat,
418
+ stream: streamChat,
419
+ async generateObject(options) {
420
+ const structuredOptions = createStructuredOutputOptions(options);
421
+ const result = await generateChat(structuredOptions);
422
+ const toolName = getStructuredOutputToolName(options);
423
+ const object = extractStructuredObject(
424
+ result,
425
+ options.schema,
426
+ provider,
427
+ toolName
428
+ );
429
+ return {
430
+ object,
431
+ finishReason: result.finishReason,
432
+ usage: result.usage
433
+ };
345
434
  },
346
- async stream(options) {
347
- try {
348
- const request = createStreamRequest(modelId, options);
349
- const stream = await client.chat.stream(
350
- request
435
+ async streamObject(options) {
436
+ const structuredOptions = createStructuredOutputOptions(options);
437
+ const stream = await streamChat(structuredOptions);
438
+ const toolName = getStructuredOutputToolName(options);
439
+ return createObjectStreamResult(
440
+ transformStructuredOutputStream(
441
+ stream,
442
+ options.schema,
443
+ provider,
444
+ toolName
445
+ )
446
+ );
447
+ }
448
+ };
449
+ }
450
/**
 * Pulls the structured object out of a non-streaming chat result.
 *
 * Lookup order:
 * 1. A tool call named `toolName` supplies the arguments to validate.
 * 2. Otherwise the trimmed text content is parsed as JSON and validated.
 * 3. If neither exists, a StructuredOutputNoObjectGeneratedError is thrown.
 *
 * @param result mapped generate result (`toolCalls`, `content`).
 * @param schema schema passed to the validation helpers.
 * @param provider provider id attached to any thrown error.
 * @param toolName name of the structured-output tool.
 * @returns the validated object.
 */
function extractStructuredObject(result, schema, provider, toolName) {
  const matchingCall = result.toolCalls.find((call) => call.name === toolName);
  if (matchingCall) {
    return validateStructuredToolArguments(schema, matchingCall.arguments, provider);
  }
  const textPayload = result.content?.trim() ?? "";
  if (textPayload.length === 0) {
    throw new StructuredOutputNoObjectGeneratedError(
      "model did not emit a structured object payload",
      provider
    );
  }
  return parseAndValidateStructuredPayload(schema, textPayload, provider);
}
470
/**
 * Converts a chat event stream into a structured-object event stream.
 *
 * - Every text delta and tool-argument delta is re-emitted as `object-delta`
 *   and buffered (tool arguments are buffered per tool-call id).
 * - A `tool-call-end` for `toolName` is validated and emitted as `object`.
 *   `tool-call-end` events for other tools, and any other event type, are
 *   ignored.
 * - On `finish`, if no object has been validated yet, the buffered payload
 *   (tool arguments first, then text) is parsed and validated. If there is
 *   no payload, a StructuredOutputNoObjectGeneratedError is thrown. The
 *   `finish` event is then re-emitted.
 *
 * @param stream async iterable of core-ai chat stream events.
 * @param schema schema passed to the validation helpers.
 * @param provider provider id attached to any thrown error.
 * @param toolName name of the structured-output tool.
 */
async function* transformStructuredOutputStream(stream, schema, provider, toolName) {
  const argumentsByToolCall = /* @__PURE__ */ new Map();
  let textSoFar = "";
  let structuredObject;
  for await (const part of stream) {
    switch (part.type) {
      case "content-delta": {
        textSoFar += part.text;
        yield { type: "object-delta", text: part.text };
        break;
      }
      case "tool-call-delta": {
        const accumulated = argumentsByToolCall.get(part.toolCallId) ?? "";
        argumentsByToolCall.set(part.toolCallId, `${accumulated}${part.argumentsDelta}`);
        yield { type: "object-delta", text: part.argumentsDelta };
        break;
      }
      case "tool-call-end": {
        if (part.toolCall.name !== toolName) {
          break;
        }
        structuredObject = validateStructuredToolArguments(
          schema,
          part.toolCall.arguments,
          provider
        );
        yield { type: "object", object: structuredObject };
        break;
      }
      case "finish": {
        if (structuredObject === void 0) {
          const bufferedPayload = getFallbackStructuredPayload(textSoFar, argumentsByToolCall);
          if (!bufferedPayload) {
            throw new StructuredOutputNoObjectGeneratedError(
              "structured output stream ended without an object payload",
              provider
            );
          }
          structuredObject = parseAndValidateStructuredPayload(schema, bufferedPayload, provider);
          yield { type: "object", object: structuredObject };
        }
        yield { type: "finish", finishReason: part.finishReason, usage: part.usage };
        break;
      }
      default:
        break;
    }
  }
}
537
/**
 * Picks the payload to parse when a stream finished without a validated
 * tool call.
 *
 * @param contentBuffer concatenated text deltas.
 * @param toolArgumentDeltas Map of tool-call id to concatenated argument text.
 * @returns the first non-blank trimmed tool-argument buffer (in Map insertion
 *   order), else the trimmed text content if non-blank, else `undefined`.
 */
function getFallbackStructuredPayload(contentBuffer, toolArgumentDeltas) {
  const candidates = [...toolArgumentDeltas.values(), contentBuffer];
  for (const candidate of candidates) {
    const trimmedCandidate = candidate.trim();
    if (trimmedCandidate) {
      return trimmedCandidate;
    }
  }
  return void 0;
}
550
/**
 * Validates already-parsed tool-call arguments against the schema.
 *
 * The arguments are re-serialized with `JSON.stringify` so that a validation
 * error can carry a raw-output string.
 *
 * @returns the validated value from `validateStructuredObject`.
 */
function validateStructuredToolArguments(schema, toolArguments, provider) {
  const serializedArguments = JSON.stringify(toolArguments);
  return validateStructuredObject(schema, toolArguments, provider, serializedArguments);
}
558
/**
 * Parses a raw text payload as JSON, then validates it against the schema.
 * The original text is passed along as the raw output for error reporting.
 *
 * @returns the validated value from `validateStructuredObject`.
 */
function parseAndValidateStructuredPayload(schema, rawPayload, provider) {
  return validateStructuredObject(
    schema,
    parseJson(rawPayload, provider),
    provider,
    rawPayload
  );
}
562
/**
 * Strips a markdown code fence that wraps the entire text
 * (```` ```json\n...\n``` ```` or ```` ```\n...\n``` ````).
 * Text that is not fully fenced is returned unchanged.
 */
function stripMarkdownCodeFence(text) {
  const fenced = /^```[\w-]*[ \t]*\r?\n([\s\S]*?)\r?\n?```$/.exec(text.trim());
  return fenced ? fenced[1] : text;
}
/**
 * Parses structured-output text as JSON.
 *
 * When the model ignores the forced tool and answers in plain content, it
 * often wraps the JSON in a markdown code fence. That fence is removed before
 * parsing. Unfenced input is parsed exactly as before, so the behavior only
 * changes for inputs that previously always failed.
 *
 * @param rawOutput text to parse.
 * @param provider provider id attached to the error.
 * @returns the parsed JSON value.
 * @throws StructuredOutputParseError when the text is not valid JSON; the
 *   error carries the ORIGINAL `rawOutput` (fence included) and the cause.
 */
function parseJson(rawOutput, provider) {
  try {
    return JSON.parse(stripMarkdownCodeFence(rawOutput));
  } catch (error) {
    throw new StructuredOutputParseError(
      "failed to parse structured output as JSON",
      provider,
      {
        rawOutput,
        cause: error
      }
    );
  }
}
576
/**
 * Validates a value with the schema's `safeParse`.
 *
 * @param schema object exposing `safeParse` (Zod-style result shape).
 * @param value candidate structured object.
 * @param provider provider id attached to the error.
 * @param rawOutput raw text attached to the error for debugging.
 * @returns the parsed data on success.
 * @throws StructuredOutputValidationError with the formatted issues on failure.
 */
function validateStructuredObject(schema, value, provider, rawOutput) {
  const validation = schema.safeParse(value);
  if (!validation.success) {
    const issueMessages = formatZodIssues(validation.error.issues);
    throw new StructuredOutputValidationError(
      "structured output does not match schema",
      provider,
      issueMessages,
      { rawOutput }
    );
  }
  return validation.data;
}
590
/**
 * Renders validation issues as `"<path>: <message>"` strings.
 * Path segments are joined with "."; an empty path is shown as `<root>`.
 *
 * @param issues array of `{ path, message }` issue objects.
 * @returns one string per issue, in input order.
 */
function formatZodIssues(issues) {
  const describeIssue = (issue) => {
    const location = issue.path.length === 0 ? "<root>" : issue.path.map(String).join(".");
    return `${location}: ${issue.message}`;
  };
  return issues.map(describeIssue);
}
359
596
 
360
597
  // src/embedding-model.ts
361
- import { MistralError as MistralError2 } from "@mistralai/mistralai/models/errors";
362
- import { ProviderError as ProviderError2 } from "@core-ai/core-ai";
363
598
  function createMistralEmbeddingModel(client, modelId) {
364
599
  return {
365
600
  provider: "mistral",
@@ -383,22 +618,11 @@ function createMistralEmbeddingModel(client, modelId) {
383
618
  }
384
619
  };
385
620
  } catch (error) {
386
- throw wrapError2(error);
621
+ throw wrapMistralError(error);
387
622
  }
388
623
  }
389
624
  };
390
625
  }
391
- function wrapError2(error) {
392
- if (error instanceof MistralError2) {
393
- return new ProviderError2(error.message, "mistral", error.statusCode, error);
394
- }
395
- return new ProviderError2(
396
- error instanceof Error ? error.message : String(error),
397
- "mistral",
398
- void 0,
399
- error
400
- );
401
- }
402
626
 
403
627
  // src/provider.ts
404
628
  function createMistral(options = {}) {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@core-ai/mistral",
3
- "version": "0.2.1",
3
+ "version": "0.4.0",
4
4
  "description": "Mistral provider package for @core-ai/core-ai",
5
5
  "license": "MIT",
6
6
  "author": "Omnifact (https://omnifact.ai)",
@@ -42,7 +42,7 @@
42
42
  "test:watch": "vitest"
43
43
  },
44
44
  "dependencies": {
45
- "@core-ai/core-ai": "^0.2.1",
45
+ "@core-ai/core-ai": "^0.4.0",
46
46
  "@mistralai/mistralai": "^1.14.0",
47
47
  "zod-to-json-schema": "^3.25.1"
48
48
  },