pydantic-ai-slim 0.0.54__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +10 -3
- pydantic_ai/_agent_graph.py +67 -55
- pydantic_ai/_cli.py +1 -2
- pydantic_ai/{_result.py → _output.py} +69 -47
- pydantic_ai/_utils.py +20 -0
- pydantic_ai/agent.py +503 -163
- pydantic_ai/format_as_xml.py +6 -113
- pydantic_ai/format_prompt.py +116 -0
- pydantic_ai/messages.py +104 -21
- pydantic_ai/models/__init__.py +25 -5
- pydantic_ai/models/_json_schema.py +156 -0
- pydantic_ai/models/anthropic.py +14 -4
- pydantic_ai/models/bedrock.py +100 -22
- pydantic_ai/models/cohere.py +48 -44
- pydantic_ai/models/fallback.py +2 -1
- pydantic_ai/models/function.py +8 -8
- pydantic_ai/models/gemini.py +65 -75
- pydantic_ai/models/groq.py +34 -29
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mistral.py +67 -58
- pydantic_ai/models/openai.py +113 -158
- pydantic_ai/models/test.py +45 -46
- pydantic_ai/models/wrapper.py +3 -0
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/azure.py +2 -2
- pydantic_ai/result.py +203 -90
- pydantic_ai/tools.py +3 -3
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/METADATA +5 -5
- pydantic_ai_slim-0.1.0.dist-info/RECORD +53 -0
- pydantic_ai_slim-0.0.54.dist-info/RECORD +0 -51
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/entry_points.txt +0 -0
pydantic_ai/models/mistral.py
CHANGED
|
@@ -5,7 +5,6 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
|
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime, timezone
|
|
8
|
-
from itertools import chain
|
|
9
8
|
from typing import Any, Literal, Union, cast
|
|
10
9
|
|
|
11
10
|
import pydantic_core
|
|
@@ -29,16 +28,18 @@ from ..messages import (
|
|
|
29
28
|
ToolCallPart,
|
|
30
29
|
ToolReturnPart,
|
|
31
30
|
UserPromptPart,
|
|
31
|
+
VideoUrl,
|
|
32
32
|
)
|
|
33
33
|
from ..providers import Provider, infer_provider
|
|
34
|
-
from ..result import Usage
|
|
35
34
|
from ..settings import ModelSettings
|
|
36
35
|
from ..tools import ToolDefinition
|
|
36
|
+
from ..usage import Usage
|
|
37
37
|
from . import (
|
|
38
38
|
Model,
|
|
39
39
|
ModelRequestParameters,
|
|
40
40
|
StreamedResponse,
|
|
41
41
|
check_allow_model_requests,
|
|
42
|
+
get_user_agent,
|
|
42
43
|
)
|
|
43
44
|
|
|
44
45
|
try:
|
|
@@ -167,7 +168,7 @@ class MistralModel(Model):
|
|
|
167
168
|
messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
|
|
168
169
|
)
|
|
169
170
|
async with response:
|
|
170
|
-
yield await self._process_streamed_response(model_request_parameters.
|
|
171
|
+
yield await self._process_streamed_response(model_request_parameters.output_tools, response)
|
|
171
172
|
|
|
172
173
|
@property
|
|
173
174
|
def model_name(self) -> MistralModelName:
|
|
@@ -189,9 +190,9 @@ class MistralModel(Model):
|
|
|
189
190
|
try:
|
|
190
191
|
response = await self.client.chat.complete_async(
|
|
191
192
|
model=str(self._model_name),
|
|
192
|
-
messages=
|
|
193
|
+
messages=self._map_messages(messages),
|
|
193
194
|
n=1,
|
|
194
|
-
tools=self.
|
|
195
|
+
tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
|
|
195
196
|
tool_choice=self._get_tool_choice(model_request_parameters),
|
|
196
197
|
stream=False,
|
|
197
198
|
max_tokens=model_settings.get('max_tokens', UNSET),
|
|
@@ -200,6 +201,7 @@ class MistralModel(Model):
|
|
|
200
201
|
timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
|
|
201
202
|
random_seed=model_settings.get('seed', UNSET),
|
|
202
203
|
stop=model_settings.get('stop_sequences', None),
|
|
204
|
+
http_headers={'User-Agent': get_user_agent()},
|
|
203
205
|
)
|
|
204
206
|
except SDKError as e:
|
|
205
207
|
if (status_code := e.status_code) >= 400:
|
|
@@ -217,10 +219,10 @@ class MistralModel(Model):
|
|
|
217
219
|
) -> MistralEventStreamAsync[MistralCompletionEvent]:
|
|
218
220
|
"""Create a streaming completion request to the Mistral model."""
|
|
219
221
|
response: MistralEventStreamAsync[MistralCompletionEvent] | None
|
|
220
|
-
mistral_messages =
|
|
222
|
+
mistral_messages = self._map_messages(messages)
|
|
221
223
|
|
|
222
224
|
if (
|
|
223
|
-
model_request_parameters.
|
|
225
|
+
model_request_parameters.output_tools
|
|
224
226
|
and model_request_parameters.function_tools
|
|
225
227
|
or model_request_parameters.function_tools
|
|
226
228
|
):
|
|
@@ -229,7 +231,7 @@ class MistralModel(Model):
|
|
|
229
231
|
model=str(self._model_name),
|
|
230
232
|
messages=mistral_messages,
|
|
231
233
|
n=1,
|
|
232
|
-
tools=self.
|
|
234
|
+
tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
|
|
233
235
|
tool_choice=self._get_tool_choice(model_request_parameters),
|
|
234
236
|
temperature=model_settings.get('temperature', UNSET),
|
|
235
237
|
top_p=model_settings.get('top_p', 1),
|
|
@@ -238,11 +240,12 @@ class MistralModel(Model):
|
|
|
238
240
|
presence_penalty=model_settings.get('presence_penalty'),
|
|
239
241
|
frequency_penalty=model_settings.get('frequency_penalty'),
|
|
240
242
|
stop=model_settings.get('stop_sequences', None),
|
|
243
|
+
http_headers={'User-Agent': get_user_agent()},
|
|
241
244
|
)
|
|
242
245
|
|
|
243
|
-
elif model_request_parameters.
|
|
246
|
+
elif model_request_parameters.output_tools:
|
|
244
247
|
# Json Mode
|
|
245
|
-
parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.
|
|
248
|
+
parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.output_tools]
|
|
246
249
|
user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
|
|
247
250
|
mistral_messages.append(user_output_format_message)
|
|
248
251
|
|
|
@@ -251,6 +254,7 @@ class MistralModel(Model):
|
|
|
251
254
|
messages=mistral_messages,
|
|
252
255
|
response_format={'type': 'json_object'},
|
|
253
256
|
stream=True,
|
|
257
|
+
http_headers={'User-Agent': get_user_agent()},
|
|
254
258
|
)
|
|
255
259
|
|
|
256
260
|
else:
|
|
@@ -259,6 +263,7 @@ class MistralModel(Model):
|
|
|
259
263
|
model=str(self._model_name),
|
|
260
264
|
messages=mistral_messages,
|
|
261
265
|
stream=True,
|
|
266
|
+
http_headers={'User-Agent': get_user_agent()},
|
|
262
267
|
)
|
|
263
268
|
assert response, 'A unexpected empty response from Mistral.'
|
|
264
269
|
return response
|
|
@@ -271,22 +276,22 @@ class MistralModel(Model):
|
|
|
271
276
|
- "none": Prevents tool use.
|
|
272
277
|
- "required": Forces tool use.
|
|
273
278
|
"""
|
|
274
|
-
if not model_request_parameters.function_tools and not model_request_parameters.
|
|
279
|
+
if not model_request_parameters.function_tools and not model_request_parameters.output_tools:
|
|
275
280
|
return None
|
|
276
|
-
elif not model_request_parameters.
|
|
281
|
+
elif not model_request_parameters.allow_text_output:
|
|
277
282
|
return 'required'
|
|
278
283
|
else:
|
|
279
284
|
return 'auto'
|
|
280
285
|
|
|
281
|
-
def
|
|
286
|
+
def _map_function_and_output_tools_definition(
|
|
282
287
|
self, model_request_parameters: ModelRequestParameters
|
|
283
288
|
) -> list[MistralTool] | None:
|
|
284
|
-
"""Map function and
|
|
289
|
+
"""Map function and output tools to MistralTool format.
|
|
285
290
|
|
|
286
|
-
Returns None if both function_tools and
|
|
291
|
+
Returns None if both function_tools and output_tools are empty.
|
|
287
292
|
"""
|
|
288
293
|
all_tools: list[ToolDefinition] = (
|
|
289
|
-
model_request_parameters.function_tools + model_request_parameters.
|
|
294
|
+
model_request_parameters.function_tools + model_request_parameters.output_tools
|
|
290
295
|
)
|
|
291
296
|
tools = [
|
|
292
297
|
MistralTool(
|
|
@@ -322,7 +327,7 @@ class MistralModel(Model):
|
|
|
322
327
|
|
|
323
328
|
async def _process_streamed_response(
|
|
324
329
|
self,
|
|
325
|
-
|
|
330
|
+
output_tools: list[ToolDefinition],
|
|
326
331
|
response: MistralEventStreamAsync[MistralCompletionEvent],
|
|
327
332
|
) -> StreamedResponse:
|
|
328
333
|
"""Process a streamed response, and prepare a streaming response to return."""
|
|
@@ -340,7 +345,7 @@ class MistralModel(Model):
|
|
|
340
345
|
_response=peekable_response,
|
|
341
346
|
_model_name=self._model_name,
|
|
342
347
|
_timestamp=timestamp,
|
|
343
|
-
|
|
348
|
+
_output_tools={c.name: c for c in output_tools},
|
|
344
349
|
)
|
|
345
350
|
|
|
346
351
|
@staticmethod
|
|
@@ -434,13 +439,12 @@ class MistralModel(Model):
|
|
|
434
439
|
return int(1000 * timeout)
|
|
435
440
|
raise NotImplementedError('Timeout object is not yet supported for MistralModel.')
|
|
436
441
|
|
|
437
|
-
|
|
438
|
-
def _map_user_message(cls, message: ModelRequest) -> Iterable[MistralMessages]:
|
|
442
|
+
def _map_user_message(self, message: ModelRequest) -> Iterable[MistralMessages]:
|
|
439
443
|
for part in message.parts:
|
|
440
444
|
if isinstance(part, SystemPromptPart):
|
|
441
445
|
yield MistralSystemMessage(content=part.content)
|
|
442
446
|
elif isinstance(part, UserPromptPart):
|
|
443
|
-
yield
|
|
447
|
+
yield self._map_user_prompt(part)
|
|
444
448
|
elif isinstance(part, ToolReturnPart):
|
|
445
449
|
yield MistralToolMessage(
|
|
446
450
|
tool_call_id=part.tool_call_id,
|
|
@@ -457,28 +461,31 @@ class MistralModel(Model):
|
|
|
457
461
|
else:
|
|
458
462
|
assert_never(part)
|
|
459
463
|
|
|
460
|
-
|
|
461
|
-
def _map_message(cls, message: ModelMessage) -> Iterable[MistralMessages]:
|
|
464
|
+
def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]:
|
|
462
465
|
"""Just maps a `pydantic_ai.Message` to a `MistralMessage`."""
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
466
|
+
mistral_messages: list[MistralMessages] = []
|
|
467
|
+
for message in messages:
|
|
468
|
+
if isinstance(message, ModelRequest):
|
|
469
|
+
mistral_messages.extend(self._map_user_message(message))
|
|
470
|
+
elif isinstance(message, ModelResponse):
|
|
471
|
+
content_chunks: list[MistralContentChunk] = []
|
|
472
|
+
tool_calls: list[MistralToolCall] = []
|
|
473
|
+
|
|
474
|
+
for part in message.parts:
|
|
475
|
+
if isinstance(part, TextPart):
|
|
476
|
+
content_chunks.append(MistralTextChunk(text=part.content))
|
|
477
|
+
elif isinstance(part, ToolCallPart):
|
|
478
|
+
tool_calls.append(self._map_tool_call(part))
|
|
479
|
+
else:
|
|
480
|
+
assert_never(part)
|
|
481
|
+
mistral_messages.append(MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls))
|
|
482
|
+
else:
|
|
483
|
+
assert_never(message)
|
|
484
|
+
if instructions := self._get_instructions(messages):
|
|
485
|
+
mistral_messages.insert(0, MistralSystemMessage(content=instructions))
|
|
486
|
+
return mistral_messages
|
|
479
487
|
|
|
480
|
-
|
|
481
|
-
def _map_user_prompt(part: UserPromptPart) -> MistralUserMessage:
|
|
488
|
+
def _map_user_prompt(self, part: UserPromptPart) -> MistralUserMessage:
|
|
482
489
|
content: str | list[MistralContentChunk]
|
|
483
490
|
if isinstance(part.content, str):
|
|
484
491
|
content = part.content
|
|
@@ -498,6 +505,8 @@ class MistralModel(Model):
|
|
|
498
505
|
raise RuntimeError('Only image binary content is supported for Mistral.')
|
|
499
506
|
elif isinstance(item, DocumentUrl):
|
|
500
507
|
raise RuntimeError('DocumentUrl is not supported in Mistral.')
|
|
508
|
+
elif isinstance(item, VideoUrl):
|
|
509
|
+
raise RuntimeError('VideoUrl is not supported in Mistral.')
|
|
501
510
|
else: # pragma: no cover
|
|
502
511
|
raise RuntimeError(f'Unsupported content type: {type(item)}')
|
|
503
512
|
return MistralUserMessage(content=content)
|
|
@@ -513,7 +522,7 @@ class MistralStreamedResponse(StreamedResponse):
|
|
|
513
522
|
_model_name: MistralModelName
|
|
514
523
|
_response: AsyncIterable[MistralCompletionEvent]
|
|
515
524
|
_timestamp: datetime
|
|
516
|
-
|
|
525
|
+
_output_tools: dict[str, ToolDefinition]
|
|
517
526
|
|
|
518
527
|
_delta_content: str = field(default='', init=False)
|
|
519
528
|
|
|
@@ -531,13 +540,13 @@ class MistralStreamedResponse(StreamedResponse):
|
|
|
531
540
|
content = choice.delta.content
|
|
532
541
|
text = _map_content(content)
|
|
533
542
|
if text:
|
|
534
|
-
# Attempt to produce
|
|
535
|
-
if self.
|
|
543
|
+
# Attempt to produce an output tool call from the received text
|
|
544
|
+
if self._output_tools:
|
|
536
545
|
self._delta_content += text
|
|
537
|
-
maybe_tool_call_part = self.
|
|
546
|
+
maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, self._output_tools)
|
|
538
547
|
if maybe_tool_call_part:
|
|
539
548
|
yield self._parts_manager.handle_tool_call_part(
|
|
540
|
-
vendor_part_id='
|
|
549
|
+
vendor_part_id='output',
|
|
541
550
|
tool_name=maybe_tool_call_part.tool_name,
|
|
542
551
|
args=maybe_tool_call_part.args_as_dict(),
|
|
543
552
|
tool_call_id=maybe_tool_call_part.tool_call_id,
|
|
@@ -563,20 +572,20 @@ class MistralStreamedResponse(StreamedResponse):
|
|
|
563
572
|
return self._timestamp
|
|
564
573
|
|
|
565
574
|
@staticmethod
|
|
566
|
-
def
|
|
575
|
+
def _try_get_output_tool_from_text(text: str, output_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
|
|
567
576
|
output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings')
|
|
568
577
|
if output_json:
|
|
569
|
-
for
|
|
570
|
-
# NOTE: Additional verification to prevent JSON validation to crash
|
|
578
|
+
for output_tool in output_tools.values():
|
|
579
|
+
# NOTE: Additional verification to prevent JSON validation to crash
|
|
571
580
|
# Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
|
|
572
581
|
# Example with BaseModel and required fields.
|
|
573
582
|
if not MistralStreamedResponse._validate_required_json_schema(
|
|
574
|
-
output_json,
|
|
583
|
+
output_json, output_tool.parameters_json_schema
|
|
575
584
|
):
|
|
576
585
|
continue
|
|
577
586
|
|
|
578
587
|
# The following part_id will be thrown away
|
|
579
|
-
return ToolCallPart(tool_name=
|
|
588
|
+
return ToolCallPart(tool_name=output_tool.name, args=output_json)
|
|
580
589
|
|
|
581
590
|
@staticmethod
|
|
582
591
|
def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
|
|
@@ -644,21 +653,21 @@ def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk)
|
|
|
644
653
|
|
|
645
654
|
def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
|
|
646
655
|
"""Maps the delta content from a Mistral Completion Chunk to a string or None."""
|
|
647
|
-
|
|
656
|
+
output: str | None = None
|
|
648
657
|
|
|
649
658
|
if isinstance(content, MistralUnset) or not content:
|
|
650
|
-
|
|
659
|
+
output = None
|
|
651
660
|
elif isinstance(content, list):
|
|
652
661
|
for chunk in content:
|
|
653
662
|
if isinstance(chunk, MistralTextChunk):
|
|
654
|
-
|
|
663
|
+
output = output or '' + chunk.text
|
|
655
664
|
else:
|
|
656
665
|
assert False, f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
|
|
657
666
|
elif isinstance(content, str):
|
|
658
|
-
|
|
667
|
+
output = content
|
|
659
668
|
|
|
660
669
|
# Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
|
|
661
|
-
if
|
|
662
|
-
|
|
670
|
+
if output and len(output) == 0: # pragma: no cover
|
|
671
|
+
output = None
|
|
663
672
|
|
|
664
|
-
return
|
|
673
|
+
return output
|