pydantic-ai-slim 0.0.54 → 0.1.0 (py3-none-any.whl)

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
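All of the hunks below touch pydantic-ai's Mistral model adapter (by the class and import context, `pydantic_ai/models/mistral.py`). The visible themes of the 0.1.0 release are a rename of the "result" vocabulary to "output" (`result_tools` → `output_tools`, `allow_text_result` → `allow_text_output`, `Usage` now imported from `..usage` instead of `..result`), a `User-Agent` header on every outgoing Mistral request, explicit rejection of `VideoUrl` content, and batch message mapping via a new `_map_messages` that can prepend an instructions message. As a minimal sketch of what the rename means for downstream code, assuming only the names visible in this diff (the `describe` helper itself is hypothetical):

```python
# 0.0.54 imported Usage from pydantic_ai.result; 0.1.0 moves it:
from pydantic_ai.usage import Usage

from pydantic_ai.models import ModelRequestParameters


def describe(params: ModelRequestParameters) -> str:
    # Hypothetical helper: 0.0.54 used params.result_tools and
    # params.allow_text_result; 0.1.0 renames both to "output".
    n_tools = len(params.function_tools) + len(params.output_tools)
    mode = 'text output allowed' if params.allow_text_output else 'tool call required'
    return f'{n_tools} tool(s), {mode}'
```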
@@ -5,7 +5,6 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from itertools import chain
 from typing import Any, Literal, Union, cast

 import pydantic_core
@@ -29,16 +28,18 @@ from ..messages import (
     ToolCallPart,
     ToolReturnPart,
     UserPromptPart,
+    VideoUrl,
 )
 from ..providers import Provider, infer_provider
-from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
+from ..usage import Usage
 from . import (
     Model,
     ModelRequestParameters,
     StreamedResponse,
     check_allow_model_requests,
+    get_user_agent,
 )

 try:
@@ -167,7 +168,7 @@ class MistralModel(Model):
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(model_request_parameters.result_tools, response)
+            yield await self._process_streamed_response(model_request_parameters.output_tools, response)

     @property
     def model_name(self) -> MistralModelName:
@@ -189,9 +190,9 @@ class MistralModel(Model):
         try:
             response = await self.client.chat.complete_async(
                 model=str(self._model_name),
-                messages=list(chain(*(self._map_message(m) for m in messages))),
+                messages=self._map_messages(messages),
                 n=1,
-                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
                 tool_choice=self._get_tool_choice(model_request_parameters),
                 stream=False,
                 max_tokens=model_settings.get('max_tokens', UNSET),
@@ -200,6 +201,7 @@ class MistralModel(Model):
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
                 random_seed=model_settings.get('seed', UNSET),
                 stop=model_settings.get('stop_sequences', None),
+                http_headers={'User-Agent': get_user_agent()},
             )
         except SDKError as e:
             if (status_code := e.status_code) >= 400:
@@ -217,10 +219,10 @@ class MistralModel(Model):
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
-        mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
+        mistral_messages = self._map_messages(messages)

         if (
-            model_request_parameters.result_tools
+            model_request_parameters.output_tools
             and model_request_parameters.function_tools
             or model_request_parameters.function_tools
         ):
@@ -229,7 +231,7 @@ class MistralModel(Model):
                 model=str(self._model_name),
                 messages=mistral_messages,
                 n=1,
-                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
                 tool_choice=self._get_tool_choice(model_request_parameters),
                 temperature=model_settings.get('temperature', UNSET),
                 top_p=model_settings.get('top_p', 1),
@@ -238,11 +240,12 @@ class MistralModel(Model):
                 presence_penalty=model_settings.get('presence_penalty'),
                 frequency_penalty=model_settings.get('frequency_penalty'),
                 stop=model_settings.get('stop_sequences', None),
+                http_headers={'User-Agent': get_user_agent()},
             )

-        elif model_request_parameters.result_tools:
+        elif model_request_parameters.output_tools:
             # Json Mode
-            parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.result_tools]
+            parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.output_tools]
             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
             mistral_messages.append(user_output_format_message)

@@ -251,6 +254,7 @@ class MistralModel(Model):
                 messages=mistral_messages,
                 response_format={'type': 'json_object'},
                 stream=True,
+                http_headers={'User-Agent': get_user_agent()},
             )

         else:
@@ -259,6 +263,7 @@ class MistralModel(Model):
                 model=str(self._model_name),
                 messages=mistral_messages,
                 stream=True,
+                http_headers={'User-Agent': get_user_agent()},
             )
         assert response, 'A unexpected empty response from Mistral.'
         return response
@@ -271,22 +276,22 @@ class MistralModel(Model):
         - "none": Prevents tool use.
         - "required": Forces tool use.
         """
-        if not model_request_parameters.function_tools and not model_request_parameters.result_tools:
+        if not model_request_parameters.function_tools and not model_request_parameters.output_tools:
             return None
-        elif not model_request_parameters.allow_text_result:
+        elif not model_request_parameters.allow_text_output:
             return 'required'
         else:
             return 'auto'

-    def _map_function_and_result_tools_definition(
+    def _map_function_and_output_tools_definition(
         self, model_request_parameters: ModelRequestParameters
     ) -> list[MistralTool] | None:
-        """Map function and result tools to MistralTool format.
+        """Map function and output tools to MistralTool format.

-        Returns None if both function_tools and result_tools are empty.
+        Returns None if both function_tools and output_tools are empty.
         """
         all_tools: list[ToolDefinition] = (
-            model_request_parameters.function_tools + model_request_parameters.result_tools
+            model_request_parameters.function_tools + model_request_parameters.output_tools
         )
         tools = [
             MistralTool(
@@ -322,7 +327,7 @@ class MistralModel(Model):

     async def _process_streamed_response(
         self,
-        result_tools: list[ToolDefinition],
+        output_tools: list[ToolDefinition],
         response: MistralEventStreamAsync[MistralCompletionEvent],
     ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -340,7 +345,7 @@ class MistralModel(Model):
             _response=peekable_response,
             _model_name=self._model_name,
             _timestamp=timestamp,
-            _result_tools={c.name: c for c in result_tools},
+            _output_tools={c.name: c for c in output_tools},
         )

     @staticmethod
@@ -434,13 +439,12 @@ class MistralModel(Model):
             return int(1000 * timeout)
         raise NotImplementedError('Timeout object is not yet supported for MistralModel.')

-    @classmethod
-    def _map_user_message(cls, message: ModelRequest) -> Iterable[MistralMessages]:
+    def _map_user_message(self, message: ModelRequest) -> Iterable[MistralMessages]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
                 yield MistralSystemMessage(content=part.content)
             elif isinstance(part, UserPromptPart):
-                yield cls._map_user_prompt(part)
+                yield self._map_user_prompt(part)
             elif isinstance(part, ToolReturnPart):
                 yield MistralToolMessage(
                     tool_call_id=part.tool_call_id,
@@ -457,28 +461,31 @@ class MistralModel(Model):
             else:
                 assert_never(part)

-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[MistralMessages]:
+    def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]:
         """Just maps a `pydantic_ai.Message` to a `MistralMessage`."""
-        if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
-        elif isinstance(message, ModelResponse):
-            content_chunks: list[MistralContentChunk] = []
-            tool_calls: list[MistralToolCall] = []
-
-            for part in message.parts:
-                if isinstance(part, TextPart):
-                    content_chunks.append(MistralTextChunk(text=part.content))
-                elif isinstance(part, ToolCallPart):
-                    tool_calls.append(cls._map_tool_call(part))
-                else:
-                    assert_never(part)
-            yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
-        else:
-            assert_never(message)
+        mistral_messages: list[MistralMessages] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                mistral_messages.extend(self._map_user_message(message))
+            elif isinstance(message, ModelResponse):
+                content_chunks: list[MistralContentChunk] = []
+                tool_calls: list[MistralToolCall] = []
+
+                for part in message.parts:
+                    if isinstance(part, TextPart):
+                        content_chunks.append(MistralTextChunk(text=part.content))
+                    elif isinstance(part, ToolCallPart):
+                        tool_calls.append(self._map_tool_call(part))
+                    else:
+                        assert_never(part)
+                mistral_messages.append(MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls))
+            else:
+                assert_never(message)
+        if instructions := self._get_instructions(messages):
+            mistral_messages.insert(0, MistralSystemMessage(content=instructions))
+        return mistral_messages

-    @staticmethod
-    def _map_user_prompt(part: UserPromptPart) -> MistralUserMessage:
+    def _map_user_prompt(self, part: UserPromptPart) -> MistralUserMessage:
         content: str | list[MistralContentChunk]
         if isinstance(part.content, str):
             content = part.content
@@ -498,6 +505,8 @@ class MistralModel(Model):
                     raise RuntimeError('Only image binary content is supported for Mistral.')
                 elif isinstance(item, DocumentUrl):
                     raise RuntimeError('DocumentUrl is not supported in Mistral.')
+                elif isinstance(item, VideoUrl):
+                    raise RuntimeError('VideoUrl is not supported in Mistral.')
                 else:  # pragma: no cover
                     raise RuntimeError(f'Unsupported content type: {type(item)}')
         return MistralUserMessage(content=content)
@@ -513,7 +522,7 @@ class MistralStreamedResponse(StreamedResponse):
     _model_name: MistralModelName
     _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
-    _result_tools: dict[str, ToolDefinition]
+    _output_tools: dict[str, ToolDefinition]

     _delta_content: str = field(default='', init=False)

@@ -531,13 +540,13 @@ class MistralStreamedResponse(StreamedResponse):
             content = choice.delta.content
             text = _map_content(content)
             if text:
-                # Attempt to produce a result tool call from the received text
-                if self._result_tools:
+                # Attempt to produce an output tool call from the received text
+                if self._output_tools:
                     self._delta_content += text
-                    maybe_tool_call_part = self._try_get_result_tool_from_text(self._delta_content, self._result_tools)
+                    maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, self._output_tools)
                     if maybe_tool_call_part:
                         yield self._parts_manager.handle_tool_call_part(
-                            vendor_part_id='result',
+                            vendor_part_id='output',
                             tool_name=maybe_tool_call_part.tool_name,
                             args=maybe_tool_call_part.args_as_dict(),
                             tool_call_id=maybe_tool_call_part.tool_call_id,
@@ -563,20 +572,20 @@ class MistralStreamedResponse(StreamedResponse):
         return self._timestamp

     @staticmethod
-    def _try_get_result_tool_from_text(text: str, result_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
+    def _try_get_output_tool_from_text(text: str, output_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
         output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings')
         if output_json:
-            for result_tool in result_tools.values():
-                # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
+            for output_tool in output_tools.values():
+                # NOTE: Additional verification to prevent JSON validation to crash
                 # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
                 # Example with BaseModel and required fields.
                 if not MistralStreamedResponse._validate_required_json_schema(
-                    output_json, result_tool.parameters_json_schema
+                    output_json, output_tool.parameters_json_schema
                 ):
                     continue

                 # The following part_id will be thrown away
-                return ToolCallPart(tool_name=result_tool.name, args=output_json)
+                return ToolCallPart(tool_name=output_tool.name, args=output_json)

     @staticmethod
     def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
@@ -644,21 +653,21 @@ def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk)

 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
     """Maps the delta content from a Mistral Completion Chunk to a string or None."""
-    result: str | None = None
+    output: str | None = None

     if isinstance(content, MistralUnset) or not content:
-        result = None
+        output = None
     elif isinstance(content, list):
         for chunk in content:
             if isinstance(chunk, MistralTextChunk):
-                result = result or '' + chunk.text
+                output = output or '' + chunk.text
             else:
                 assert False, f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
     elif isinstance(content, str):
-        result = content
+        output = content

     # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
-    if result and len(result) == 0:
-        result = None
+    if output and len(output) == 0:  # pragma: no cover
+        output = None

-    return result
+    return output
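A note on the streaming JSON mode above: `MistralStreamedResponse` accumulates raw text deltas in `_delta_content` and re-parses the buffer on every chunk via `pydantic_core.from_json(..., allow_partial='trailing-strings')`, emitting a tool-call part as soon as the partial JSON satisfies the output tool's schema. A standalone sketch of that partial-parsing technique, with made-up chunk data:

```python
import pydantic_core

# Made-up stream of text deltas that together form a JSON object.
chunks = ['{"city": "Par', 'is", ', '"days": 3}']

buffer = ''
for chunk in chunks:
    buffer += chunk
    # allow_partial='trailing-strings' tolerates a truncated tail and even
    # includes an incomplete trailing string value in the parsed result.
    print(pydantic_core.from_json(buffer, allow_partial='trailing-strings'))
# {'city': 'Par'}
# {'city': 'Paris'}
# {'city': 'Paris', 'days': 3}
```

This is also why `_validate_required_json_schema` exists: a partial parse can be valid JSON long before all required fields have streamed in, so the schema check gates when the tool call is actually emitted.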