pydantic-ai-slim 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +12 -8
- pydantic_ai/agent.py +5 -5
- pydantic_ai/models/__init__.py +52 -45
- pydantic_ai/models/anthropic.py +87 -66
- pydantic_ai/models/cohere.py +65 -67
- pydantic_ai/models/function.py +76 -60
- pydantic_ai/models/gemini.py +153 -99
- pydantic_ai/models/groq.py +97 -72
- pydantic_ai/models/mistral.py +90 -71
- pydantic_ai/models/openai.py +110 -71
- pydantic_ai/models/test.py +99 -94
- pydantic_ai/models/vertexai.py +48 -44
- pydantic_ai/result.py +2 -2
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.24.dist-info}/METADATA +3 -3
- pydantic_ai_slim-0.0.24.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.22.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.24.dist-info}/WHEEL +0 -0
pydantic_ai/models/groq.py
CHANGED
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, cast, overload
+from typing import Literal, Union, cast, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -28,8 +28,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -45,7 +45,7 @@ except ImportError as _import_error:
     "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
 ) from _import_error

-GroqModelName = Literal[
+LatestGroqModelNames = Literal[
     'llama-3.3-70b-versatile',
     'llama-3.3-70b-specdec',
     'llama-3.1-8b-instant',
@@ -58,8 +58,14 @@ GroqModelName = Literal[
     'mixtral-8x7b-32768',
     'gemma2-9b-it',
 ]
-"""
+"""Latest Groq models."""

+GroqModelName = Union[str, LatestGroqModelNames]
+"""
+Possible Groq model names.
+
+Since Groq supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
 See [the Groq docs](https://console.groq.com/docs/models) for a full list.
 """

@@ -79,9 +85,11 @@ class GroqModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    model_name: GroqModelName
     client: AsyncGroq = field(repr=False)

+    _model_name: GroqModelName = field(repr=False)
+    _system: str | None = field(default='groq', repr=False)
+
     def __init__(
         self,
         model_name: GroqModelName,
@@ -102,7 +110,7 @@ class GroqModel(Model):
             client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.model_name = model_name
+        self._model_name = model_name
         if groq_client is not None:
             assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
@@ -112,81 +120,74 @@ class GroqModel(Model):
         else:
             self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())

-    async def agent_model(
+    async def request(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return GroqAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
+        response = await self._completions_create(
+            messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
-
-    def name(self) -> str:
-        return f'groq:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
-        return {
-            'type': 'function',
-            'function': {
-                'name': f.name,
-                'description': f.description,
-                'parameters': f.parameters_json_schema,
-            },
-        }
-
-
-@dataclass
-class GroqAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Groq models."""
-
-    client: AsyncGroq
-    model_name: str
-    allow_text_result: bool
-    tools: list[chat.ChatCompletionToolParam]
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)

+    @property
+    def model_name(self) -> GroqModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+        tools = self._get_tools(model_request_parameters)
         # standalone function to make it easier to override
-        if not self.tools:
+        if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not self.allow_text_result:
+        elif not model_request_parameters.allow_text_result:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
@@ -194,11 +195,11 @@ class GroqAgentModel(AgentModel):
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))

         return await self.client.chat.completions.create(
-            model=str(self.model_name),
+            model=str(self._model_name),
             messages=groq_messages,
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=self.tools or NOT_GIVEN,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
@@ -221,7 +222,7 @@ class GroqAgentModel(AgentModel):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
+        return ModelResponse(items, model_name=response.model, timestamp=timestamp)

     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -232,15 +233,20 @@ class GroqAgentModel(AgentModel):

         return GroqStreamedResponse(
             _response=peekable_response,
-            _model_name=self.model_name,
+            _model_name=self._model_name,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
         )

-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
+            yield from self._map_user_message(message)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -248,7 +254,7 @@ class GroqAgentModel(AgentModel):
                 if isinstance(item, TextPart):
                     texts.append(item.content)
                 elif isinstance(item, ToolCallPart):
-                    tool_calls.append(_map_tool_call(item))
+                    tool_calls.append(self._map_tool_call(item))
                 else:
                     assert_never(item)
             message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
@@ -262,6 +268,25 @@ class GroqAgentModel(AgentModel):
         else:
             assert_never(message)

+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
+        return chat.ChatCompletionMessageToolCallParam(
+            id=_guard_tool_call_id(t=t, model_source='Groq'),
+            type='function',
+            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+        return {
+            'type': 'function',
+            'function': {
+                'name': f.name,
+                'description': f.description,
+                'parameters': f.parameters_json_schema,
+            },
+        }
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
@@ -290,6 +315,7 @@ class GroqAgentModel(AgentModel):
 class GroqStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Groq models."""

+    _model_name: GroqModelName
     _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime

@@ -318,18 +344,17 @@ class GroqStreamedResponse(StreamedResponse):
         if maybe_event is not None:
             yield maybe_event

+    @property
+    def model_name(self) -> GroqModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp


-def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_call_id(t=t, model_source='Groq'),
-        type='function',
-        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
-    )
-
-
 def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
     response_usage = None
     if isinstance(completion, ChatCompletion):
pydantic_ai/models/mistral.py
CHANGED
@@ -31,8 +31,8 @@ from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -70,12 +70,12 @@ except ImportError as e:
     "you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
 ) from e

-NamedMistralModels = Literal[
+LatestMistralModelNames = Literal[
     'mistral-large-latest', 'mistral-small-latest', 'codestral-latest', 'mistral-moderation-latest'
 ]
-"""Latest / most popular named Mistral models."""
+"""Latest Mistral models."""

-MistralModelName = Union[NamedMistralModels, str]
+MistralModelName = Union[str, LatestMistralModelNames]
 """Possible Mistral model names.

 Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but
@@ -99,8 +99,11 @@ class MistralModel(Model):
     [API Documentation](https://docs.mistral.ai/)
     """

-    model_name: MistralModelName
     client: Mistral = field(repr=False)
+    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+
+    _model_name: MistralModelName = field(repr=False)
+    _system: str | None = field(default='mistral', repr=False)

     def __init__(
         self,
@@ -109,6 +112,7 @@ class MistralModel(Model):
         api_key: str | Callable[[], str | None] | None = None,
         client: Mistral | None = None,
         http_client: AsyncHTTPClient | None = None,
+        json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
     ):
         """Initialize a Mistral model.

@@ -117,8 +121,10 @@ class MistralModel(Model):
             api_key: The API key to use for authentication, if unset uses `MISTRAL_API_KEY` environment variable.
             client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+            json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
         """
-        self.model_name = model_name
+        self._model_name = model_name
+        self.json_mode_schema_prompt = json_mode_schema_prompt

         if client is not None:
             assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
@@ -128,64 +134,60 @@ class MistralModel(Model):
         api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
         self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())

-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        """Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
-        check_allow_model_requests()
-        return MistralAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            function_tools,
-            result_tools,
-        )
-
     def name(self) -> str:
-        return f'mistral:{self.model_name}'
-
-
-@dataclass
-class MistralAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Mistral models."""
-
-    client: Mistral
-    model_name: MistralModelName
-    allow_text_result: bool
-    function_tools: list[ToolDefinition]
-    result_tools: list[ToolDefinition]
-    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+        return f'mistral:{self._model_name}'

     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
-        response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
+        )
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
-        response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._stream_completions_create(
+            messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
-            yield await self._process_streamed_response(self.result_tools, response)
+            yield await self._process_streamed_response(model_request_parameters.result_tools, response)
+
+    @property
+    def model_name(self) -> MistralModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system

     async def _completions_create(
-        self, messages: list[ModelMessage], model_settings: MistralModelSettings
+        self,
+        messages: list[ModelMessage],
+        model_settings: MistralModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
         response = await self.client.chat.complete_async(
-            model=str(self.model_name),
+            model=str(self._model_name),
             messages=list(chain(*(self._map_message(m) for m in messages))),
             n=1,
-            tools=self._map_function_and_result_tools_definition() or UNSET,
-            tool_choice=self._get_tool_choice(),
+            tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+            tool_choice=self._get_tool_choice(model_request_parameters),
             stream=False,
             max_tokens=model_settings.get('max_tokens', UNSET),
             temperature=model_settings.get('temperature', UNSET),
@@ -200,19 +202,24 @@ class MistralAgentModel(AgentModel):
         self,
         messages: list[ModelMessage],
         model_settings: MistralModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))

-        if self.result_tools and self.function_tools or self.function_tools:
+        if (
+            model_request_parameters.result_tools
+            and model_request_parameters.function_tools
+            or model_request_parameters.function_tools
+        ):
             # Function Calling
             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 n=1,
-                tools=self._map_function_and_result_tools_definition() or UNSET,
-                tool_choice=self._get_tool_choice(),
+                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tool_choice=self._get_tool_choice(model_request_parameters),
                 temperature=model_settings.get('temperature', UNSET),
                 top_p=model_settings.get('top_p', 1),
                 max_tokens=model_settings.get('max_tokens', UNSET),
@@ -221,14 +228,14 @@ class MistralAgentModel(AgentModel):
                 frequency_penalty=model_settings.get('frequency_penalty'),
             )

-        elif self.result_tools:
+        elif model_request_parameters.result_tools:
             # Json Mode
-            parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]
+            parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.result_tools]
             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
             mistral_messages.append(user_output_format_message)

             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 response_format={'type': 'json_object'},
                 stream=True,
@@ -237,14 +244,14 @@ class MistralAgentModel(AgentModel):
         else:
             # Stream Mode
             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 stream=True,
             )
         assert response, 'A unexpected empty response from Mistral.'
         return response

-    def _get_tool_choice(self) -> MistralToolChoiceEnum | None:
+    def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> MistralToolChoiceEnum | None:
         """Get tool choice for the model.

         - "auto": Default mode. Model decides if it uses the tool or not.
@@ -252,19 +259,23 @@ class MistralAgentModel(AgentModel):
         - "none": Prevents tool use.
         - "required": Forces tool use.
         """
-        if not self.function_tools and not self.result_tools:
+        if not model_request_parameters.function_tools and not model_request_parameters.result_tools:
             return None
-        elif not self.allow_text_result:
+        elif not model_request_parameters.allow_text_result:
             return 'required'
         else:
             return 'auto'

-    def _map_function_and_result_tools_definition(self) -> list[MistralTool] | None:
+    def _map_function_and_result_tools_definition(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> list[MistralTool] | None:
         """Map function and result tools to MistralTool format.

         Returns None if both function_tools and result_tools are empty.
         """
-        all_tools: list[ToolDefinition] = self.function_tools + self.result_tools
+        all_tools: list[ToolDefinition] = (
+            model_request_parameters.function_tools + model_request_parameters.result_tools
+        )
         tools = [
             MistralTool(
                 function=MistralFunction(name=r.name, parameters=r.parameters_json_schema, description=r.description)
@@ -292,10 +303,10 @@ class MistralAgentModel(AgentModel):

         if isinstance(tool_calls, list):
             for tool_call in tool_calls:
-                tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
+                tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
                 parts.append(tool)

-        return ModelResponse(parts, model_name=self.model_name, timestamp=timestamp)
+        return ModelResponse(parts, model_name=response.model, timestamp=timestamp)

     async def _process_streamed_response(
         self,
@@ -315,13 +326,21 @@ class MistralAgentModel(AgentModel):

         return MistralStreamedResponse(
             _response=peekable_response,
-            _model_name=self.model_name,
+            _model_name=self._model_name,
             _timestamp=timestamp,
             _result_tools={c.name: c for c in result_tools},
         )

     @staticmethod
-    def
+    def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
+        """Maps a MistralToolCall to a ToolCall."""
+        tool_call_id = tool_call.id or None
+        func_call = tool_call.function
+
+        return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
+
+    @staticmethod
+    def _map_pydantic_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
         """Maps a pydantic-ai ToolCall to a MistralToolCall."""
         return MistralToolCall(
             id=t.tool_call_id,
@@ -437,7 +456,7 @@ class MistralAgentModel(AgentModel):
                 if isinstance(part, TextPart):
                     content_chunks.append(MistralTextChunk(text=part.content))
                 elif isinstance(part, ToolCallPart):
-                    tool_calls.append(cls.
+                    tool_calls.append(cls._map_pydantic_to_mistral_tool_call(part))
                 else:
                     assert_never(part)
             yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
@@ -452,6 +471,7 @@ MistralToolCallId = Union[str, None]
 class MistralStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Mistral models."""

+    _model_name: MistralModelName
     _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
     _result_tools: dict[str, ToolDefinition]
@@ -493,7 +513,14 @@ class MistralStreamedResponse(StreamedResponse):
                 vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id
             )

+    @property
+    def model_name(self) -> MistralModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp

     @staticmethod
@@ -563,14 +590,6 @@ SIMPLE_JSON_TYPE_MAPPING = {
 }


-def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
-    """Maps a MistralToolCall to a ToolCall."""
-    tool_call_id = tool_call.id or None
-    func_call = tool_call.function
-
-    return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
-
-
 def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage: