pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

@@ -0,0 +1,668 @@
+ from __future__ import annotations as _annotations
+
+ import os
+ from collections.abc import AsyncIterator, Iterable
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass, field
+ from datetime import datetime, timezone
+ from itertools import chain
+ from typing import Any, Callable, Literal, Union
+
+ from httpx import AsyncClient as AsyncHTTPClient, Timeout
+ from typing_extensions import assert_never
+
+ from .. import UnexpectedModelBehavior
+ from .._utils import now_utc as _now_utc
+ from ..messages import (
+     ArgsJson,
+     ModelMessage,
+     ModelRequest,
+     ModelResponse,
+     ModelResponsePart,
+     RetryPromptPart,
+     SystemPromptPart,
+     TextPart,
+     ToolCallPart,
+     ToolReturnPart,
+     UserPromptPart,
+ )
+ from ..result import Usage
+ from ..settings import ModelSettings
+ from ..tools import ToolDefinition
+ from . import (
+     AgentModel,
+     EitherStreamedResponse,
+     Model,
+     StreamStructuredResponse,
+     StreamTextResponse,
+     cached_async_http_client,
+ )
+
+ try:
+     from json_repair import repair_json
+     from mistralai import (
+         UNSET,
+         CompletionChunk as MistralCompletionChunk,
+         Content as MistralContent,
+         ContentChunk as MistralContentChunk,
+         FunctionCall as MistralFunctionCall,
+         Mistral,
+         OptionalNullable as MistralOptionalNullable,
+         TextChunk as MistralTextChunk,
+         ToolChoiceEnum as MistralToolChoiceEnum,
+     )
+     from mistralai.models import (
+         ChatCompletionResponse as MistralChatCompletionResponse,
+         CompletionEvent as MistralCompletionEvent,
+         Messages as MistralMessages,
+         Tool as MistralTool,
+         ToolCall as MistralToolCall,
+     )
+     from mistralai.models.assistantmessage import AssistantMessage as MistralAssistantMessage
+     from mistralai.models.function import Function as MistralFunction
+     from mistralai.models.systemmessage import SystemMessage as MistralSystemMessage
+     from mistralai.models.toolmessage import ToolMessage as MistralToolMessage
+     from mistralai.models.usermessage import UserMessage as MistralUserMessage
+     from mistralai.types.basemodel import Unset as MistralUnset
+     from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
+ except ImportError as e:
+     raise ImportError(
+         'Please install `mistral` to use the Mistral model, '
+         "you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
+     ) from e
+
+ NamedMistralModels = Literal[
+     'mistral-large-latest', 'mistral-small-latest', 'codestral-latest', 'mistral-moderation-latest'
+ ]
+ """Latest / most popular named Mistral models."""
+
+ MistralModelName = Union[NamedMistralModels, str]
+ """Possible Mistral model names.
+
+ Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but
+ allow any name in the type hints.
+ See [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_overview/) for a full list.
+ """
+
+
+ @dataclass(init=False)
+ class MistralModel(Model):
+     """A model that uses Mistral.
+
+     Internally, this uses the [Mistral Python client](https://github.com/mistralai/client-python) to interact with the API.
+
+     [API Documentation](https://docs.mistral.ai/)
+     """
+
+     model_name: MistralModelName
+     client: Mistral = field(repr=False)
+
+     def __init__(
+         self,
+         model_name: MistralModelName,
+         *,
+         api_key: str | Callable[[], str | None] | None = None,
+         client: Mistral | None = None,
+         http_client: AsyncHTTPClient | None = None,
+     ):
+         """Initialize a Mistral model.
+
+         Args:
+             model_name: The name of the model to use.
+             api_key: The API key to use for authentication; if unset, the `MISTRAL_API_KEY` environment variable is used.
+             client: An existing `Mistral` client to use; if provided, `api_key` and `http_client` must be `None`.
+             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+         """
+         self.model_name = model_name
+
+         if client is not None:
+             assert http_client is None, 'Cannot provide both `client` and `http_client`'
+             assert api_key is None, 'Cannot provide both `client` and `api_key`'
+             self.client = client
+         else:
+             api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
+             self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
+
+     async def agent_model(
+         self,
+         *,
+         function_tools: list[ToolDefinition],
+         allow_text_result: bool,
+         result_tools: list[ToolDefinition],
+     ) -> AgentModel:
+         """Create an agent model; this is called for each step of an agent run."""
+         return MistralAgentModel(
+             self.client,
+             self.model_name,
+             allow_text_result,
+             function_tools,
+             result_tools,
+         )
+
+     def name(self) -> str:
+         return f'mistral:{self.model_name}'
+
+
+ @dataclass
+ class MistralAgentModel(AgentModel):
+     """Implementation of `AgentModel` for Mistral models."""
+
+     client: Mistral
+     model_name: str
+     allow_text_result: bool
+     function_tools: list[ToolDefinition]
+     result_tools: list[ToolDefinition]
+     json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+
+     async def request(
+         self, messages: list[ModelMessage], model_settings: ModelSettings | None
+     ) -> tuple[ModelResponse, Usage]:
+         """Make a non-streaming request to the model."""
+         response = await self._completions_create(messages, model_settings)
+         return self._process_response(response), _map_usage(response)
+
+     @asynccontextmanager
+     async def request_stream(
+         self, messages: list[ModelMessage], model_settings: ModelSettings | None
+     ) -> AsyncIterator[EitherStreamedResponse]:
+         """Make a streaming request to the model."""
+         response = await self._stream_completions_create(messages, model_settings)
+         async with response:
+             yield await self._process_streamed_response(self.result_tools, response)
+
+     async def _completions_create(
+         self, messages: list[ModelMessage], model_settings: ModelSettings | None
+     ) -> MistralChatCompletionResponse:
+         """Make a non-streaming request to the model."""
+         model_settings = model_settings or {}
+         response = await self.client.chat.complete_async(
+             model=str(self.model_name),
+             messages=list(chain(*(self._map_message(m) for m in messages))),
+             n=1,
+             tools=self._map_function_and_result_tools_definition() or UNSET,
+             tool_choice=self._get_tool_choice(),
+             stream=False,
+             max_tokens=model_settings.get('max_tokens', UNSET),
+             temperature=model_settings.get('temperature', UNSET),
+             top_p=model_settings.get('top_p', 1),
+             timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+         )
+         assert response, 'An unexpected empty response from Mistral.'
+         return response
+
+     async def _stream_completions_create(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
+         """Create a streaming completion request to the Mistral model."""
+         response: MistralEventStreamAsync[MistralCompletionEvent] | None
+         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
+
+         model_settings = model_settings or {}
+
+         if self.function_tools:
+             # Function Calling Mode
+             response = await self.client.chat.stream_async(
+                 model=str(self.model_name),
+                 messages=mistral_messages,
+                 n=1,
+                 tools=self._map_function_and_result_tools_definition() or UNSET,
+                 tool_choice=self._get_tool_choice(),
+                 temperature=model_settings.get('temperature', UNSET),
+                 top_p=model_settings.get('top_p', 1),
+                 max_tokens=model_settings.get('max_tokens', UNSET),
+                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+             )
+
+         elif self.result_tools:
+             # JSON Mode
+             parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]
+
+             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
+             mistral_messages.append(user_output_format_message)
+             response = await self.client.chat.stream_async(
+                 model=str(self.model_name),
+                 messages=mistral_messages,
+                 response_format={'type': 'json_object'},
+                 stream=True,
+             )
+
+         else:
+             # Stream Mode
+             response = await self.client.chat.stream_async(
+                 model=str(self.model_name),
+                 messages=mistral_messages,
+                 stream=True,
+             )
+         assert response, 'An unexpected empty response from Mistral.'
+         return response
+
+     def _get_tool_choice(self) -> MistralToolChoiceEnum | None:
+         """Get tool choice for the model.
+
+         - "auto": Default mode. Model decides if it uses the tool or not.
+         - "any": Select any tool.
+         - "none": Prevents tool use.
+         - "required": Forces tool use.
+         """
+         if not self.function_tools and not self.result_tools:
+             return None
+         elif not self.allow_text_result:
+             return 'required'
+         else:
+             return 'auto'
+
+     def _map_function_and_result_tools_definition(self) -> list[MistralTool] | None:
+         """Map function and result tools to MistralTool format.
+
+         Returns None if both function_tools and result_tools are empty.
+         """
+         all_tools: list[ToolDefinition] = self.function_tools + self.result_tools
+         tools = [
+             MistralTool(
+                 function=MistralFunction(name=r.name, parameters=r.parameters_json_schema, description=r.description)
+             )
+             for r in all_tools
+         ]
+         return tools if tools else None
+
+     @staticmethod
+     def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
+         """Process a non-streamed response, and prepare a message to return."""
+         if response.created:
+             timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
+         else:
+             timestamp = _now_utc()
+
+         assert response.choices, 'Unexpected empty response choice.'
+         choice = response.choices[0]
+         content = choice.message.content
+         tool_calls = choice.message.tool_calls
+
+         parts: list[ModelResponsePart] = []
+         if text := _map_content(content):
+             parts.append(TextPart(text))
+
+         if isinstance(tool_calls, list):
+             for tool_call in tool_calls:
+                 tool = _map_mistral_to_pydantic_tool_call(tool_call)
+                 parts.append(tool)
+
+         return ModelResponse(parts, timestamp=timestamp)
+
+     @staticmethod
+     async def _process_streamed_response(
+         result_tools: list[ToolDefinition],
+         response: MistralEventStreamAsync[MistralCompletionEvent],
+     ) -> EitherStreamedResponse:
+         """Process a streamed response, and prepare a streaming response to return."""
+         start_usage = Usage()
+
+         # Iterate until we get either `tool_calls` or `content` from the first chunk.
+         while True:
+             try:
+                 event = await response.__anext__()
+                 chunk = event.data
+             except StopAsyncIteration as e:
+                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
+
+             start_usage += _map_usage(chunk)
+
+             if chunk.created:
+                 timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc)
+             else:
+                 timestamp = _now_utc()
+
+             if chunk.choices:
+                 delta = chunk.choices[0].delta
+                 content = _map_content(delta.content)
+
+                 tool_calls: list[MistralToolCall] | None = None
+                 if delta.tool_calls:
+                     tool_calls = delta.tool_calls
+
+                 if tool_calls or (content and result_tools):
+                     return MistralStreamStructuredResponse(
+                         {c.id if c.id else 'null': c for c in tool_calls or []},
+                         {c.name: c for c in result_tools},
+                         response,
+                         content,
+                         timestamp,
+                         start_usage,
+                     )
+
+                 elif content:
+                     return MistralStreamTextResponse(content, response, timestamp, start_usage)
+
+     @staticmethod
+     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
+         """Maps a pydantic-ai ToolCall to a MistralToolCall."""
+         if isinstance(t.args, ArgsJson):
+             return MistralToolCall(
+                 id=t.tool_call_id,
+                 type='function',
+                 function=MistralFunctionCall(name=t.tool_name, arguments=t.args.args_json),
+             )
+         else:
+             return MistralToolCall(
+                 id=t.tool_call_id,
+                 type='function',
+                 function=MistralFunctionCall(name=t.tool_name, arguments=t.args.args_dict),
+             )
+
+     def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
+         """Get a message with an example of the expected output format."""
+         examples: list[dict[str, Any]] = []
+         for schema in schemas:
+             typed_dict_definition: dict[str, Any] = {}
+             for key, value in schema.get('properties', {}).items():
+                 typed_dict_definition[key] = self._get_python_type(value)
+             examples.append(typed_dict_definition)
+
+         example_schema = examples[0] if len(examples) == 1 else examples
+         return MistralUserMessage(content=self.json_mode_schema_prompt.format(schema=example_schema))
+
+     @classmethod
+     def _get_python_type(cls, value: dict[str, Any]) -> str:
+         """Return a string representation of the Python type for a single JSON schema property.
+
+         This function handles recursion for nested arrays/objects and `anyOf`.
+         """
+         # 1) Handle anyOf first, because it's a different schema structure
+         if any_of := value.get('anyOf'):
+             # Simplistic approach: pick the first option in anyOf
+             # (In reality, you'd possibly want to merge or union types)
+             return f'Optional[{cls._get_python_type(any_of[0])}]'
+
+         # 2) If we have a top-level "type" field
+         value_type = value.get('type')
+         if not value_type:
+             # No explicit type; fallback
+             return 'Any'
+
+         # 3) Direct simple type mapping (string, integer, float, bool, None)
+         if value_type in SIMPLE_JSON_TYPE_MAPPING and value_type != 'array' and value_type != 'object':
+             return SIMPLE_JSON_TYPE_MAPPING[value_type]
+
+         # 4) Array: Recursively get the item type
+         if value_type == 'array':
+             items = value.get('items', {})
+             return f'list[{cls._get_python_type(items)}]'
+
+         # 5) Object: Check for additionalProperties
+         if value_type == 'object':
+             additional_properties = value.get('additionalProperties', {})
+             additional_properties_type = additional_properties.get('type')
+             if (
+                 additional_properties_type in SIMPLE_JSON_TYPE_MAPPING
+                 and additional_properties_type != 'array'
+                 and additional_properties_type != 'object'
+             ):
+                 # dict[str, bool/int/float/etc...]
+                 return f'dict[str, {SIMPLE_JSON_TYPE_MAPPING[additional_properties_type]}]'
+             elif additional_properties_type == 'array':
+                 array_items = additional_properties.get('items', {})
+                 return f'dict[str, list[{cls._get_python_type(array_items)}]]'
+             elif additional_properties_type == 'object':
+                 # nested dictionary of unknown shape
+                 return 'dict[str, dict[str, Any]]'
+             else:
+                 # If no additionalProperties type or something else, default to a generic dict
+                 return 'dict[str, Any]'
+
+         # 6) Fallback
+         return 'Any'
+
+     @staticmethod
+     def _get_timeout_ms(timeout: Timeout | float | None) -> int | None:
+         """Convert a timeout to milliseconds."""
+         if timeout is None:
+             return None
+         if isinstance(timeout, float):
+             return int(1000 * timeout)
+         raise NotImplementedError('Timeout object is not yet supported for MistralModel.')
+
+     @classmethod
+     def _map_user_message(cls, message: ModelRequest) -> Iterable[MistralMessages]:
+         for part in message.parts:
+             if isinstance(part, SystemPromptPart):
+                 yield MistralSystemMessage(content=part.content)
+             elif isinstance(part, UserPromptPart):
+                 yield MistralUserMessage(content=part.content)
+             elif isinstance(part, ToolReturnPart):
+                 yield MistralToolMessage(
+                     tool_call_id=part.tool_call_id,
+                     content=part.model_response_str(),
+                 )
+             elif isinstance(part, RetryPromptPart):
+                 if part.tool_name is None:
+                     yield MistralUserMessage(content=part.model_response())
+                 else:
+                     yield MistralToolMessage(
+                         tool_call_id=part.tool_call_id,
+                         content=part.model_response(),
+                     )
+             else:
+                 assert_never(part)
+
+     @classmethod
+     def _map_message(cls, message: ModelMessage) -> Iterable[MistralMessages]:
+         """Maps a `pydantic_ai.Message` to a `MistralMessage`."""
+         if isinstance(message, ModelRequest):
+             yield from cls._map_user_message(message)
+         elif isinstance(message, ModelResponse):
+             content_chunks: list[MistralContentChunk] = []
+             tool_calls: list[MistralToolCall] = []
+
+             for part in message.parts:
+                 if isinstance(part, TextPart):
+                     content_chunks.append(MistralTextChunk(text=part.content))
+                 elif isinstance(part, ToolCallPart):
+                     tool_calls.append(cls._map_to_mistral_tool_call(part))
+                 else:
+                     assert_never(part)
+             yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
+         else:
+             assert_never(message)
+
+
+ @dataclass
+ class MistralStreamTextResponse(StreamTextResponse):
+     """Implementation of `StreamTextResponse` for Mistral models."""
+
+     _first: str | None
+     _response: MistralEventStreamAsync[MistralCompletionEvent]
+     _timestamp: datetime
+     _usage: Usage
+     _buffer: list[str] = field(default_factory=list, init=False)
+
+     async def __anext__(self) -> None:
+         if self._first is not None and len(self._first) > 0:
+             self._buffer.append(self._first)
+             self._first = None
+             return None
+
+         chunk = await self._response.__anext__()
+         self._usage += _map_usage(chunk.data)
+
+         try:
+             choice = chunk.data.choices[0]
+         except IndexError:
+             raise StopAsyncIteration()
+
+         content = choice.delta.content
+         if choice.finish_reason is None:
+             assert content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
+
+         if text := _map_content(content):
+             self._buffer.append(text)
+
+     def get(self, *, final: bool = False) -> Iterable[str]:
+         yield from self._buffer
+         self._buffer.clear()
+
+     def usage(self) -> Usage:
+         return self._usage
+
+     def timestamp(self) -> datetime:
+         return self._timestamp
+
+
+ @dataclass
+ class MistralStreamStructuredResponse(StreamStructuredResponse):
+     """Implementation of `StreamStructuredResponse` for Mistral models."""
+
+     _function_tools: dict[str, MistralToolCall]
+     _result_tools: dict[str, ToolDefinition]
+     _response: MistralEventStreamAsync[MistralCompletionEvent]
+     _delta_content: str | None
+     _timestamp: datetime
+     _usage: Usage
+
+     async def __anext__(self) -> None:
+         chunk = await self._response.__anext__()
+         self._usage += _map_usage(chunk.data)
+
+         try:
+             choice = chunk.data.choices[0]
+
+         except IndexError:
+             raise StopAsyncIteration()
+
+         if choice.finish_reason is not None:
+             raise StopAsyncIteration()
+
+         content = choice.delta.content
+         if self._result_tools:
+             if text := _map_content(content):
+                 self._delta_content = (self._delta_content or '') + text
+
+     def get(self, *, final: bool = False) -> ModelResponse:
+         calls: list[ModelResponsePart] = []
+         if self._function_tools:
+             for tool_call in self._function_tools.values():
+                 tool = _map_mistral_to_pydantic_tool_call(tool_call)
+                 calls.append(tool)
+
+         elif self._delta_content and self._result_tools:
+             # NOTE: These params give the most efficient and fastest repair path.
+             output_json = repair_json(self._delta_content, return_objects=True, skip_json_loads=True)
+             assert isinstance(
+                 output_json, dict
+             ), f'Expected repair_json to return a dict, invalid type: {type(output_json)}'
+
+             if output_json:
+                 for result_tool in self._result_tools.values():
+                     # NOTE: Additional verification to prevent JSON validation from crashing in `result.py`.
+                     # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
+                     # For example, `return_type=list[str]` expects a 'response' key whose value is an array of str.
+                     # When the stream is `{"response":`, `repair_json` produces `{"response": ""}` (type not found, defaults to str);
+                     # when it is `{"response": {`, `repair_json` produces `{"response": {}}` (type found).
+                     # This check skips chunks until the required parameters have the expected types.
+                     if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
+                         continue
+
+                     tool = ToolCallPart.from_raw_args(result_tool.name, output_json)
+                     calls.append(tool)
+
+         return ModelResponse(calls, timestamp=self._timestamp)
+
+     def usage(self) -> Usage:
+         return self._usage
+
+     def timestamp(self) -> datetime:
+         return self._timestamp
+
+     @staticmethod
+     def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
+         """Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
+         required_params = json_schema.get('required', [])
+         properties = json_schema.get('properties', {})
+
+         for param in required_params:
+             if param not in json_dict:
+                 return False
+
+             param_schema = properties.get(param, {})
+             param_type = param_schema.get('type')
+             param_items_type = param_schema.get('items', {}).get('type')
+
+             if param_type == 'array' and param_items_type:
+                 if not isinstance(json_dict[param], list):
+                     return False
+                 for item in json_dict[param]:
+                     if not isinstance(item, VALID_JSON_TYPE_MAPPING[param_items_type]):
+                         return False
+             elif param_type and not isinstance(json_dict[param], VALID_JSON_TYPE_MAPPING[param_type]):
+                 return False
+
+             if isinstance(json_dict[param], dict) and 'properties' in param_schema:
+                 nested_schema = param_schema
+                 if not MistralStreamStructuredResponse._validate_required_json_schema(json_dict[param], nested_schema):
+                     return False
+
+         return True
+
+
+ VALID_JSON_TYPE_MAPPING: dict[str, Any] = {
+     'string': str,
+     'integer': int,
+     'number': float,
+     'boolean': bool,
+     'array': list,
+     'object': dict,
+     'null': type(None),
+ }
+
+ SIMPLE_JSON_TYPE_MAPPING = {
+     'string': 'str',
+     'integer': 'int',
+     'number': 'float',
+     'boolean': 'bool',
+     'array': 'list',
+     'null': 'None',
+ }
+
+
+ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
+     """Maps a MistralToolCall to a ToolCall."""
+     tool_call_id = tool_call.id or None
+     func_call = tool_call.function
+
+     return ToolCallPart.from_raw_args(func_call.name, func_call.arguments, tool_call_id)
+
+
+ def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
+     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
+     if response.usage:
+         return Usage(
+             request_tokens=response.usage.prompt_tokens,
+             response_tokens=response.usage.completion_tokens,
+             total_tokens=response.usage.total_tokens,
+             details=None,
+         )
+     else:
+         return Usage()
+
+
+ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
+     """Maps the delta content from a Mistral Completion Chunk to a string or None."""
+     result: str | None = None
+
+     if isinstance(content, MistralUnset) or not content:
+         result = None
+     elif isinstance(content, list):
+         for chunk in content:
+             if isinstance(chunk, MistralTextChunk):
+                 result = (result or '') + chunk.text
+             else:
+                 assert False, f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
+     elif isinstance(content, str):
+         result = content
+
+     # Note: Check the length to handle a potential mismatch between function calls and responses from the API (error: `not the same number of function calls and responses`).
+     if result is not None and len(result) == 0:
+         result = None
+
+     return result
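
For context, here is a minimal sketch of how the new model class would be wired into an agent at this version of the library; the model name, prompt, and `Agent` usage are illustrative and not part of the diff, and `MISTRAL_API_KEY` is read from the environment when no `api_key` is passed:

from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel

# Illustrative only: assumes MISTRAL_API_KEY is set in the environment.
model = MistralModel('mistral-large-latest')
agent = Agent(model)

result = agent.run_sync('What is the capital of France?')
print(result.data)
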
@@ -17,7 +17,7 @@ try:
  except ImportError as e:
      raise ImportError(
          'Please install `openai` to use the OpenAI model, '
-         "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
+         "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
      ) from e
 
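
As background on the JSON-mode branch in the new module: when only result tools are configured, the streamed text is accumulated and run through `json_repair.repair_json` with `return_objects=True` and `skip_json_loads=True`, so partial JSON can be closed off and checked against the tool's required schema before validation. A standalone sketch of that behaviour, with an invented schema and payload:

from json_repair import repair_json

# A truncated payload, as it might look mid-stream.
partial = '{"city": "Paris", "population": 2140'

# Same flags the streamed response uses: return Python objects and skip
# the initial json.loads attempt, since the input is known to be broken.
repaired = repair_json(partial, return_objects=True, skip_json_loads=True)
print(repaired)  # expected: {'city': 'Paris', 'population': 2140}

# Invented result-tool schema: both keys are required, so a chunk that
# has not yet streamed 'population' would be skipped, not validated.
schema = {
    'required': ['city', 'population'],
    'properties': {'city': {'type': 'string'}, 'population': {'type': 'integer'}},
}
print(all(key in repaired for key in schema['required']))  # True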