pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +33 -18
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +366 -171
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +111 -50
- pydantic_ai/models/__init__.py +39 -14
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +62 -40
- pydantic_ai/models/gemini.py +164 -124
- pydantic_ai/models/groq.py +112 -94
- pydantic_ai/models/mistral.py +668 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +120 -96
- pydantic_ai/models/test.py +78 -61
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +96 -68
- pydantic_ai/settings.py +137 -0
- pydantic_ai/tools.py +46 -26
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
|
@@ -0,0 +1,668 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from collections.abc import AsyncIterator, Iterable
|
|
5
|
+
from contextlib import asynccontextmanager
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from datetime import datetime, timezone
|
|
8
|
+
from itertools import chain
|
|
9
|
+
from typing import Any, Callable, Literal, Union
|
|
10
|
+
|
|
11
|
+
from httpx import AsyncClient as AsyncHTTPClient, Timeout
|
|
12
|
+
from typing_extensions import assert_never
|
|
13
|
+
|
|
14
|
+
from .. import UnexpectedModelBehavior
|
|
15
|
+
from .._utils import now_utc as _now_utc
|
|
16
|
+
from ..messages import (
|
|
17
|
+
ArgsJson,
|
|
18
|
+
ModelMessage,
|
|
19
|
+
ModelRequest,
|
|
20
|
+
ModelResponse,
|
|
21
|
+
ModelResponsePart,
|
|
22
|
+
RetryPromptPart,
|
|
23
|
+
SystemPromptPart,
|
|
24
|
+
TextPart,
|
|
25
|
+
ToolCallPart,
|
|
26
|
+
ToolReturnPart,
|
|
27
|
+
UserPromptPart,
|
|
28
|
+
)
|
|
29
|
+
from ..result import Usage
|
|
30
|
+
from ..settings import ModelSettings
|
|
31
|
+
from ..tools import ToolDefinition
|
|
32
|
+
from . import (
|
|
33
|
+
AgentModel,
|
|
34
|
+
EitherStreamedResponse,
|
|
35
|
+
Model,
|
|
36
|
+
StreamStructuredResponse,
|
|
37
|
+
StreamTextResponse,
|
|
38
|
+
cached_async_http_client,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
try:
|
|
42
|
+
from json_repair import repair_json
|
|
43
|
+
from mistralai import (
|
|
44
|
+
UNSET,
|
|
45
|
+
CompletionChunk as MistralCompletionChunk,
|
|
46
|
+
Content as MistralContent,
|
|
47
|
+
ContentChunk as MistralContentChunk,
|
|
48
|
+
FunctionCall as MistralFunctionCall,
|
|
49
|
+
Mistral,
|
|
50
|
+
OptionalNullable as MistralOptionalNullable,
|
|
51
|
+
TextChunk as MistralTextChunk,
|
|
52
|
+
ToolChoiceEnum as MistralToolChoiceEnum,
|
|
53
|
+
)
|
|
54
|
+
from mistralai.models import (
|
|
55
|
+
ChatCompletionResponse as MistralChatCompletionResponse,
|
|
56
|
+
CompletionEvent as MistralCompletionEvent,
|
|
57
|
+
Messages as MistralMessages,
|
|
58
|
+
Tool as MistralTool,
|
|
59
|
+
ToolCall as MistralToolCall,
|
|
60
|
+
)
|
|
61
|
+
from mistralai.models.assistantmessage import AssistantMessage as MistralAssistantMessage
|
|
62
|
+
from mistralai.models.function import Function as MistralFunction
|
|
63
|
+
from mistralai.models.systemmessage import SystemMessage as MistralSystemMessage
|
|
64
|
+
from mistralai.models.toolmessage import ToolMessage as MistralToolMessage
|
|
65
|
+
from mistralai.models.usermessage import UserMessage as MistralUserMessage
|
|
66
|
+
from mistralai.types.basemodel import Unset as MistralUnset
|
|
67
|
+
from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
|
|
68
|
+
except ImportError as e:
|
|
69
|
+
raise ImportError(
|
|
70
|
+
'Please install `mistral` to use the Mistral model, '
|
|
71
|
+
"you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
|
|
72
|
+
) from e
|
|
73
|
+
|
|
74
|
+
NamedMistralModels = Literal[
    'mistral-large-latest',
    'mistral-small-latest',
    'codestral-latest',
    'mistral-moderation-latest',
]
"""Latest / most popular named Mistral models."""

MistralModelName = Union[NamedMistralModels, str]
"""Possible Mistral model names.

Mistral publishes many date-stamped models, so only the most popular names are listed explicitly
while the type hint still accepts any string.
See [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_overview/) for a full list.
"""
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass(init=False)
class MistralModel(Model):
    """A model backed by the Mistral API.

    Requests are sent through the [Mistral Python client](https://github.com/mistralai/client-python).

    [API Documentation](https://docs.mistral.ai/)
    """

    model_name: MistralModelName
    client: Mistral = field(repr=False)

    def __init__(
        self,
        model_name: MistralModelName,
        *,
        api_key: str | Callable[[], str | None] | None = None,
        client: Mistral | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize a Mistral model.

        Args:
            model_name: The name of the model to use.
            api_key: The API key to use for authentication; when `None`, the `MISTRAL_API_KEY`
                environment variable is read instead.
            client: An existing `Mistral` client to use; when given, `api_key` and `http_client`
                must both be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests; when
                `None`, the shared cached client is used.
        """
        self.model_name = model_name

        if client is None:
            # No ready-made client: build one from the explicit key (or the environment).
            resolved_key = api_key if api_key is not None else os.getenv('MISTRAL_API_KEY')
            transport = http_client or cached_async_http_client()
            self.client = Mistral(api_key=resolved_key, async_client=transport)
            return

        # A pre-built client already carries its own credentials and transport.
        assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
        assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
        self.client = client

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Build the `MistralAgentModel` for one agent step, sharing this model's client and name."""
        return MistralAgentModel(
            client=self.client,
            model_name=self.model_name,
            allow_text_result=allow_text_result,
            function_tools=function_tools,
            result_tools=result_tools,
        )

    def name(self) -> str:
        """Return the model identifier in the form `mistral:<model_name>`."""
        return f'mistral:{self.model_name}'
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@dataclass
class MistralAgentModel(AgentModel):
    """Implementation of `AgentModel` for Mistral models.

    Bundles the Mistral client with the tool configuration of one agent step and converts
    between pydantic-ai messages and Mistral chat messages.
    """

    client: Mistral
    model_name: str
    allow_text_result: bool
    function_tools: list[ToolDefinition]
    result_tools: list[ToolDefinition]
    # Template for the extra user message sent in streaming JSON mode; `{schema}` is filled in
    # by `_generate_user_output_format`.
    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, Usage]:
        """Make a non-streaming request and return the mapped response together with its usage."""
        response = await self._completions_create(messages, model_settings)
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[EitherStreamedResponse]:
        """Make a streaming request; the event stream stays open for the duration of the context."""
        response = await self._stream_completions_create(messages, model_settings)
        async with response:
            yield await self._process_streamed_response(self.result_tools, response)

    async def _completions_create(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> MistralChatCompletionResponse:
        """Make a non-streaming chat completion request to the model."""
        model_settings = model_settings or {}
        response = await self.client.chat.complete_async(
            model=str(self.model_name),
            messages=list(chain(*(self._map_message(m) for m in messages))),
            n=1,
            # An empty/None tool list is sent as UNSET rather than as an empty value.
            tools=self._map_function_and_result_tools_definition() or UNSET,
            tool_choice=self._get_tool_choice(),
            stream=False,
            max_tokens=model_settings.get('max_tokens', UNSET),
            temperature=model_settings.get('temperature', UNSET),
            top_p=model_settings.get('top_p', 1),
            timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
        )
        assert response, 'A unexpected empty response from Mistral.'
        return response

    async def _stream_completions_create(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
    ) -> MistralEventStreamAsync[MistralCompletionEvent]:
        """Create a streaming completion request to the Mistral model.

        Three modes are used depending on the configured tools:

        - function tools present: function-calling mode (result tools are sent as tools too);
        - only result tools: JSON mode, with a user message describing the expected output;
        - no tools: plain text streaming.
        """
        response: MistralEventStreamAsync[MistralCompletionEvent] | None
        mistral_messages = list(chain(*(self._map_message(m) for m in messages)))

        model_settings = model_settings or {}

        # The original test `result_tools and function_tools or function_tools` reduces to this.
        if self.function_tools:
            # Function Calling Mode
            response = await self.client.chat.stream_async(
                model=str(self.model_name),
                messages=mistral_messages,
                n=1,
                tools=self._map_function_and_result_tools_definition() or UNSET,
                tool_choice=self._get_tool_choice(),
                temperature=model_settings.get('temperature', UNSET),
                top_p=model_settings.get('top_p', 1),
                max_tokens=model_settings.get('max_tokens', UNSET),
                timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
            )

        elif self.result_tools:
            # Json Mode
            # NOTE(review): model_settings are not forwarded in this branch — confirm intended.
            parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]

            user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
            mistral_messages.append(user_output_format_message)
            response = await self.client.chat.stream_async(
                model=str(self.model_name),
                messages=mistral_messages,
                response_format={'type': 'json_object'},
                stream=True,
            )

        else:
            # Stream Mode
            # NOTE(review): model_settings are not forwarded in this branch — confirm intended.
            response = await self.client.chat.stream_async(
                model=str(self.model_name),
                messages=mistral_messages,
                stream=True,
            )
        assert response, 'A unexpected empty response from Mistral.'
        return response

    def _get_tool_choice(self) -> MistralToolChoiceEnum | None:
        """Get tool choice for the model.

        - "auto": Default mode. Model decides if it uses the tool or not.
        - "any": Select any tool.
        - "none": Prevents tool use.
        - "required": Forces tool use.

        Returns `None` when no tools are configured, 'required' when plain-text results are
        not allowed, and 'auto' otherwise.
        """
        if not self.function_tools and not self.result_tools:
            return None
        elif not self.allow_text_result:
            return 'required'
        else:
            return 'auto'

    def _map_function_and_result_tools_definition(self) -> list[MistralTool] | None:
        """Map function and result tools to MistralTool format.

        Returns None if both function_tools and result_tools are empty.
        """
        all_tools: list[ToolDefinition] = self.function_tools + self.result_tools
        tools = [
            MistralTool(
                function=MistralFunction(name=r.name, parameters=r.parameters_json_schema, description=r.description)
            )
            for r in all_tools
        ]
        return tools or None

    @staticmethod
    def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return.

        Only the first choice is used; its text (if any) becomes a `TextPart` and each tool call
        becomes a `ToolCallPart`.
        """
        if response.created:
            timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        else:
            # No server timestamp available: fall back to the current UTC time.
            timestamp = _now_utc()

        assert response.choices, 'Unexpected empty response choice.'
        choice = response.choices[0]
        content = choice.message.content
        tool_calls = choice.message.tool_calls

        parts: list[ModelResponsePart] = []
        if text := _map_content(content):
            parts.append(TextPart(text))

        # `tool_calls` may be unset/None; only a real list is iterated.
        if isinstance(tool_calls, list):
            parts.extend(_map_mistral_to_pydantic_tool_call(tool_call) for tool_call in tool_calls)

        return ModelResponse(parts, timestamp=timestamp)

    @staticmethod
    async def _process_streamed_response(
        result_tools: list[ToolDefinition],
        response: MistralEventStreamAsync[MistralCompletionEvent],
    ) -> EitherStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return.

        Raises:
            UnexpectedModelBehavior: if the stream ends before any content or tool call arrives.
        """
        start_usage = Usage()

        # Iterate until we get either `tool_calls` or `content` from the first chunk.
        while True:
            try:
                event = await response.__anext__()
                chunk = event.data
            except StopAsyncIteration as e:
                raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e

            # Usage from the chunks consumed here is handed over to the returned stream object.
            start_usage += _map_usage(chunk)

            if chunk.created:
                timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc)
            else:
                timestamp = _now_utc()

            if chunk.choices:
                delta = chunk.choices[0].delta
                content = _map_content(delta.content)

                tool_calls: list[MistralToolCall] | None = None
                if delta.tool_calls:
                    tool_calls = delta.tool_calls

                # Structured when tool calls arrive, or when text arrives while result tools are
                # configured (the text is then treated as JSON-mode output).
                if tool_calls or (content and result_tools):
                    return MistralStreamStructuredResponse(
                        {c.id if c.id else 'null': c for c in tool_calls or []},
                        {c.name: c for c in result_tools},
                        response,
                        content,
                        timestamp,
                        start_usage,
                    )

                elif content:
                    return MistralStreamTextResponse(content, response, timestamp, start_usage)

    @staticmethod
    def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
        """Maps a pydantic-ai ToolCall to a MistralToolCall.

        JSON-string arguments are passed through as a string, otherwise the argument dict is used.
        """
        arguments = t.args.args_json if isinstance(t.args, ArgsJson) else t.args.args_dict
        return MistralToolCall(
            id=t.tool_call_id,
            type='function',
            function=MistralFunctionCall(name=t.tool_name, arguments=arguments),
        )

    def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
        """Get a user message with an example of the expected output format.

        Each schema's `properties` are rendered as `{name: python-type-string}`; a single schema
        is shown as one mapping, several schemas as a list of mappings.
        """
        examples: list[dict[str, Any]] = [
            {key: self._get_python_type(value) for key, value in schema.get('properties', {}).items()}
            for schema in schemas
        ]

        example_schema = examples[0] if len(examples) == 1 else examples
        return MistralUserMessage(content=self.json_mode_schema_prompt.format(schema=example_schema))

    @classmethod
    def _get_python_type(cls, value: dict[str, Any]) -> str:
        """Return a string representation of the Python type for a single JSON schema property.

        This function handles recursion for nested arrays/objects and `anyOf`.
        """
        # 1) Handle anyOf first, because it's a different schema structure
        if any_of := value.get('anyOf'):
            # Simplistic approach: pick the first option in anyOf
            # (In reality, you'd possibly want to merge or union types)
            return f'Optional[{cls._get_python_type(any_of[0])}]'

        # 2) If we have a top-level "type" field
        value_type = value.get('type')
        if not value_type:
            # No explicit type; fallback
            return 'Any'

        # 3) Direct simple type mapping (string, integer, float, bool, None)
        if value_type in SIMPLE_JSON_TYPE_MAPPING and value_type != 'array' and value_type != 'object':
            return SIMPLE_JSON_TYPE_MAPPING[value_type]

        # 4) Array: Recursively get the item type
        if value_type == 'array':
            items = value.get('items', {})
            return f'list[{cls._get_python_type(items)}]'

        # 5) Object: Check for additionalProperties
        if value_type == 'object':
            additional_properties = value.get('additionalProperties', {})
            if not isinstance(additional_properties, dict):
                # JSON Schema also permits a boolean here (e.g. `additionalProperties: false`);
                # that carries no value type, so fall through to the generic dict.
                additional_properties = {}
            additional_properties_type = additional_properties.get('type')
            if (
                additional_properties_type in SIMPLE_JSON_TYPE_MAPPING
                and additional_properties_type != 'array'
                and additional_properties_type != 'object'
            ):
                # dict[str, bool/int/float/etc...]
                return f'dict[str, {SIMPLE_JSON_TYPE_MAPPING[additional_properties_type]}]'
            elif additional_properties_type == 'array':
                array_items = additional_properties.get('items', {})
                return f'dict[str, list[{cls._get_python_type(array_items)}]]'
            elif additional_properties_type == 'object':
                # nested dictionary of unknown shape
                return 'dict[str, dict[str, Any]]'
            else:
                # If no additionalProperties type or something else, default to a generic dict
                return 'dict[str, Any]'

        # 6) Fallback
        return 'Any'

    @staticmethod
    def _get_timeout_ms(timeout: Timeout | float | None) -> int | None:
        """Convert a timeout in seconds to whole milliseconds.

        Returns `None` when no timeout is given. Both `int` and `float` seconds are accepted
        (an `int` is a valid value wherever `float` is annotated).

        Raises:
            NotImplementedError: for any other value, e.g. an `httpx.Timeout` object.
        """
        if timeout is None:
            return None
        if isinstance(timeout, (int, float)):
            return int(1000 * timeout)
        raise NotImplementedError('Timeout object is not yet supported for MistralModel.')

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[MistralMessages]:
        """Yield one Mistral message per part of a `ModelRequest`."""
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield MistralSystemMessage(content=part.content)
            elif isinstance(part, UserPromptPart):
                yield MistralUserMessage(content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield MistralToolMessage(
                    tool_call_id=part.tool_call_id,
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                # A retry not tied to a tool is sent as a user message, otherwise as a tool message.
                if part.tool_name is None:
                    yield MistralUserMessage(content=part.model_response())
                else:
                    yield MistralToolMessage(
                        tool_call_id=part.tool_call_id,
                        content=part.model_response(),
                    )
            else:
                assert_never(part)

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[MistralMessages]:
        """Just maps a `pydantic_ai.Message` to a `MistralMessage`.

        A `ModelRequest` expands to one message per part; a `ModelResponse` becomes a single
        assistant message carrying its text chunks and tool calls.
        """
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            content_chunks: list[MistralContentChunk] = []
            tool_calls: list[MistralToolCall] = []

            for part in message.parts:
                if isinstance(part, TextPart):
                    content_chunks.append(MistralTextChunk(text=part.content))
                elif isinstance(part, ToolCallPart):
                    tool_calls.append(cls._map_to_mistral_tool_call(part))
                else:
                    assert_never(part)
            yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
        else:
            assert_never(message)
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
@dataclass
class MistralStreamTextResponse(StreamTextResponse):
    """Implementation of `StreamTextResponse` for Mistral models."""

    # Text already read from the stream before this object was created; emitted first.
    _first: str | None
    _response: MistralEventStreamAsync[MistralCompletionEvent]
    _timestamp: datetime
    _usage: Usage
    # Text received since the last call to `get`.
    _buffer: list[str] = field(default_factory=list, init=False)

    async def __anext__(self) -> None:
        """Advance the stream by one step, buffering any text it produced."""
        if self._first:
            # Hand out the pre-read text once, without touching the underlying stream.
            self._buffer.append(self._first)
            self._first = None
            return None

        event = await self._response.__anext__()
        self._usage += _map_usage(event.data)

        try:
            first_choice = event.data.choices[0]
        except IndexError:
            # A chunk without choices ends the iteration.
            raise StopAsyncIteration()

        delta_content = first_choice.delta.content
        if first_choice.finish_reason is None:
            assert delta_content is not None, f'Expected delta with content, invalid chunk: {event!r}'

        text = _map_content(delta_content)
        if text:
            self._buffer.append(text)

    def get(self, *, final: bool = False) -> Iterable[str]:
        """Yield the buffered text pieces, then empty the buffer."""
        for piece in self._buffer:
            yield piece
        self._buffer.clear()

    def usage(self) -> Usage:
        """Return the usage accumulated so far."""
        return self._usage

    def timestamp(self) -> datetime:
        """Return the timestamp associated with this response."""
        return self._timestamp
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
@dataclass
class MistralStreamStructuredResponse(StreamStructuredResponse):
    """Implementation of `StreamStructuredResponse` for Mistral models."""

    # Tool calls taken from the first chunk, keyed by call id ('null' when the id is missing).
    _function_tools: dict[str, MistralToolCall]
    _result_tools: dict[str, ToolDefinition]
    _response: MistralEventStreamAsync[MistralCompletionEvent]
    # Text accumulated so far; parsed as (possibly partial) JSON in `get`.
    _delta_content: str | None
    _timestamp: datetime
    _usage: Usage

    async def __anext__(self) -> None:
        """Read the next chunk, accumulating its text when result tools are configured."""
        event = await self._response.__anext__()
        self._usage += _map_usage(event.data)

        try:
            first_choice = event.data.choices[0]
        except IndexError:
            # A chunk without choices ends the iteration.
            raise StopAsyncIteration()

        if first_choice.finish_reason is not None:
            raise StopAsyncIteration()

        delta_content = first_choice.delta.content
        if self._result_tools:
            text = _map_content(delta_content)
            if text:
                self._delta_content = (self._delta_content or '') + text

    def get(self, *, final: bool = False) -> ModelResponse:
        """Build a `ModelResponse` from what has been received so far.

        Function tool calls take priority. Otherwise the accumulated text is repaired into a JSON
        object and turned into one tool call per result tool whose schema it satisfies.
        """
        parts: list[ModelResponsePart] = []

        if self._function_tools:
            parts.extend(_map_mistral_to_pydantic_tool_call(call) for call in self._function_tools.values())

        elif self._delta_content and self._result_tools:
            # NOTE: Params set for the most efficient and fastest way.
            repaired = repair_json(self._delta_content, return_objects=True, skip_json_loads=True)
            assert isinstance(
                repaired, dict
            ), f'Expected repair_json as type dict, invalid type: {type(repaired)}'

            if repaired:
                # NOTE: Extra check so that JSON validation in `result.py` does not crash on
                # partial streams. `repair_json` fills gaps with defaults (e.g. `{"response":`
                # becomes `{"response": ""}`), so a candidate is only emitted once every required
                # parameter is present with the expected type.
                parts.extend(
                    ToolCallPart.from_raw_args(result_tool.name, repaired)
                    for result_tool in self._result_tools.values()
                    if self._validate_required_json_schema(repaired, result_tool.parameters_json_schema)
                )

        return ModelResponse(parts, timestamp=self._timestamp)

    def usage(self) -> Usage:
        """Return the usage accumulated so far."""
        return self._usage

    def timestamp(self) -> datetime:
        """Return the timestamp associated with this response."""
        return self._timestamp

    @staticmethod
    def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
        """Check that every required parameter of the schema is present with the expected type.

        Arrays with a declared item type have each element checked; dict values whose schema
        declares `properties` are validated recursively.
        """
        properties = json_schema.get('properties', {})

        for name in json_schema.get('required', []):
            if name not in json_dict:
                return False

            value = json_dict[name]
            schema = properties.get(name, {})
            expected = schema.get('type')
            item_type = schema.get('items', {}).get('type')

            if expected == 'array' and item_type:
                if not isinstance(value, list):
                    return False
                if not all(isinstance(item, VALIDE_JSON_TYPE_MAPPING[item_type]) for item in value):
                    return False
            elif expected and not isinstance(value, VALIDE_JSON_TYPE_MAPPING[expected]):
                return False

            if isinstance(value, dict) and 'properties' in schema:
                if not MistralStreamStructuredResponse._validate_required_json_schema(value, schema):
                    return False

        return True
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
# JSON Schema type name -> Python type(s) accepted by `isinstance` when validating parsed JSON.
VALIDE_JSON_TYPE_MAPPING: dict[str, Any] = {
    'string': str,
    'integer': int,
    # A JSON "number" may be written without a decimal point and then parses to `int`,
    # so both numeric types must be accepted.
    'number': (int, float),
    'boolean': bool,
    'array': list,
    'object': dict,
    'null': type(None),
}

# JSON Schema type name -> Python type name used when rendering the expected output format.
SIMPLE_JSON_TYPE_MAPPING = {
    'string': 'str',
    'integer': 'int',
    'number': 'float',
    'boolean': 'bool',
    'array': 'list',
    'null': 'None',
}
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
    """Convert a `MistralToolCall` into a pydantic-ai `ToolCallPart`.

    A missing or empty call id is normalised to `None`.
    """
    function = tool_call.function
    return ToolCallPart.from_raw_args(function.name, function.arguments, tool_call.id or None)
|
|
634
|
+
|
|
635
|
+
|
|
636
|
+
def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
    """Build a `Usage` from a Mistral chat response or completion chunk.

    Returns an empty `Usage` when the response carries no usage information.
    """
    usage = response.usage
    if not usage:
        return Usage()

    return Usage(
        request_tokens=usage.prompt_tokens,
        response_tokens=usage.completion_tokens,
        total_tokens=usage.total_tokens,
        details=None,
    )
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
    """Map the content of a Mistral message or delta to a string, or `None` when there is no text.

    Unset, `None` and empty content all map to `None`. A plain string is returned as is. A list
    of chunks is concatenated in order; only text chunks are supported.

    Raises:
        AssertionError: if a list contains a chunk that is not a text chunk.
    """
    if isinstance(content, MistralUnset) or not content:
        return None

    result: str | None = None
    if isinstance(content, list):
        texts: list[str] = []
        for chunk in content:
            if isinstance(chunk, MistralTextChunk):
                texts.append(chunk.text)
            else:
                assert False, f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
        # Join every chunk: the previous `result or '' + chunk.text` parsed as
        # `result or ('' + chunk.text)` and silently kept only the first chunk.
        result = ''.join(texts)
    elif isinstance(content, str):
        result = content

    # Normalise an empty string (e.g. a list of empty text chunks) to None so callers can rely
    # on a simple truthiness/None check.
    return result or None
|
pydantic_ai/models/ollama.py
CHANGED
|
@@ -17,7 +17,7 @@ try:
|
|
|
17
17
|
except ImportError as e:
|
|
18
18
|
raise ImportError(
|
|
19
19
|
'Please install `openai` to use the OpenAI model, '
|
|
20
|
-
"you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
|
|
20
|
+
"you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
|
|
21
21
|
) from e
|
|
22
22
|
|
|
23
23
|
|