pydantic-ai-slim 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +774 -0
- pydantic_ai/agent.py +183 -555
- pydantic_ai/models/__init__.py +43 -37
- pydantic_ai/models/anthropic.py +69 -66
- pydantic_ai/models/cohere.py +56 -68
- pydantic_ai/models/function.py +58 -60
- pydantic_ai/models/gemini.py +139 -100
- pydantic_ai/models/groq.py +79 -72
- pydantic_ai/models/mistral.py +72 -71
- pydantic_ai/models/openai.py +96 -71
- pydantic_ai/models/test.py +81 -93
- pydantic_ai/models/vertexai.py +38 -44
- pydantic_ai/settings.py +5 -0
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.23.dist-info}/METADATA +3 -4
- pydantic_ai_slim-0.0.23.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.21.dist-info/RECORD +0 -29
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.23.dist-info}/WHEEL +0 -0
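The main interface change visible in the two expanded files below is the removal of the per-model `AgentModel` wrapper classes: `Model.request()` and `Model.request_stream()` now receive the messages, the optional `ModelSettings`, and a `ModelRequestParameters` bundle directly. The definition of `ModelRequestParameters` lives in pydantic_ai/models/__init__.py, which is not expanded in this diff, so the sketch below only infers its shape from the attribute accesses (`function_tools`, `allow_text_result`, `result_tools`) that appear in groq.py and mistral.py; the field order and defaults are assumptions.

# Inferred shape only, not copied from the package source.
from dataclasses import dataclass, field

from pydantic_ai.tools import ToolDefinition


@dataclass
class ModelRequestParameters:
    """Assumed layout of the parameter bundle passed to Model.request()/request_stream()."""

    function_tools: list[ToolDefinition] = field(default_factory=list)
    allow_text_result: bool = True
    result_tools: list[ToolDefinition] = field(default_factory=list)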
pydantic_ai/models/groq.py
CHANGED
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, cast, overload
+from typing import Literal, Union, cast, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -28,8 +28,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -45,7 +45,7 @@ except ImportError as _import_error:
         "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
     ) from _import_error

-GroqModelName = Literal[
+LatestGroqModelNames = Literal[
     'llama-3.3-70b-versatile',
     'llama-3.3-70b-specdec',
     'llama-3.1-8b-instant',
@@ -58,8 +58,14 @@ GroqModelName = Literal[
     'mixtral-8x7b-32768',
     'gemma2-9b-it',
 ]
-"""
+"""Latest Groq models."""

+GroqModelName = Union[str, LatestGroqModelNames]
+"""
+Possible Groq model names.
+
+Since Groq supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
 See [the Groq docs](https://console.groq.com/docs/models) for a full list.
 """

@@ -79,9 +85,11 @@ class GroqModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    model_name: GroqModelName
     client: AsyncGroq = field(repr=False)

+    _model_name: GroqModelName = field(repr=False)
+    _system: str | None = field(default='groq', repr=False)
+
     def __init__(
         self,
         model_name: GroqModelName,
@@ -102,7 +110,7 @@ class GroqModel(Model):
                 client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.model_name = model_name
+        self._model_name = model_name
         if groq_client is not None:
             assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
@@ -112,81 +120,64 @@
         else:
             self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())

-    async def agent_model(
+    async def request(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return GroqAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
+        response = await self._completions_create(
+            messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
-
-    def name(self) -> str:
-        return f'groq:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
-        return {
-            'type': 'function',
-            'function': {
-                'name': f.name,
-                'description': f.description,
-                'parameters': f.parameters_json_schema,
-            },
-        }
-
-
-@dataclass
-class GroqAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Groq models."""
-
-    client: AsyncGroq
-    model_name: str
-    allow_text_result: bool
-    tools: list[chat.ChatCompletionToolParam]
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+        tools = self._get_tools(model_request_parameters)
         # standalone function to make it easier to override
-        if not self.tools:
+        if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not self.allow_text_result:
+        elif not model_request_parameters.allow_text_result:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
@@ -194,11 +185,11 @@ class GroqAgentModel(AgentModel):
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))

         return await self.client.chat.completions.create(
-            model=str(self.model_name),
+            model=str(self._model_name),
             messages=groq_messages,
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=self.tools or NOT_GIVEN,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
@@ -221,7 +212,7 @@ class GroqAgentModel(AgentModel):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
+        return ModelResponse(items, model_name=self._model_name, timestamp=timestamp)

     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -232,15 +223,20 @@ class GroqAgentModel(AgentModel):

         return GroqStreamedResponse(
             _response=peekable_response,
-            _model_name=self.model_name,
+            _model_name=self._model_name,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
         )

-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
+            yield from self._map_user_message(message)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -248,7 +244,7 @@ class GroqAgentModel(AgentModel):
                 if isinstance(item, TextPart):
                     texts.append(item.content)
                 elif isinstance(item, ToolCallPart):
-                    tool_calls.append(_map_tool_call(item))
+                    tool_calls.append(self._map_tool_call(item))
                 else:
                     assert_never(item)
             message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
@@ -262,6 +258,25 @@ class GroqAgentModel(AgentModel):
         else:
             assert_never(message)

+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
+        return chat.ChatCompletionMessageToolCallParam(
+            id=_guard_tool_call_id(t=t, model_source='Groq'),
+            type='function',
+            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+        return {
+            'type': 'function',
+            'function': {
+                'name': f.name,
+                'description': f.description,
+                'parameters': f.parameters_json_schema,
+            },
+        }
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
@@ -322,14 +337,6 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_call_id(t=t, model_source='Groq'),
-        type='function',
-        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
-    )
-
-
 def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
     response_usage = None
     if isinstance(completion, ChatCompletion):
pydantic_ai/models/mistral.py
CHANGED
@@ -31,8 +31,8 @@ from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -70,12 +70,12 @@ except ImportError as e:
     "you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
 ) from e

-NamedMistralModels = Literal[
+LatestMistralModelNames = Literal[
     'mistral-large-latest', 'mistral-small-latest', 'codestral-latest', 'mistral-moderation-latest'
 ]
-"""Latest / most popular named Mistral models."""
+"""Latest Mistral models."""

-MistralModelName = Union[NamedMistralModels, str]
+MistralModelName = Union[str, LatestMistralModelNames]
 """Possible Mistral model names.

 Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but
@@ -99,8 +99,11 @@ class MistralModel(Model):
     [API Documentation](https://docs.mistral.ai/)
     """

-    model_name: MistralModelName
     client: Mistral = field(repr=False)
+    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+
+    _model_name: MistralModelName = field(repr=False)
+    _system: str | None = field(default='mistral', repr=False)

     def __init__(
         self,
@@ -109,6 +112,7 @@ class MistralModel(Model):
         api_key: str | Callable[[], str | None] | None = None,
         client: Mistral | None = None,
         http_client: AsyncHTTPClient | None = None,
+        json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
     ):
         """Initialize a Mistral model.

@@ -117,8 +121,10 @@ class MistralModel(Model):
             api_key: The API key to use for authentication, if unset uses `MISTRAL_API_KEY` environment variable.
             client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+            json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
         """
-        self.model_name = model_name
+        self._model_name = model_name
+        self.json_mode_schema_prompt = json_mode_schema_prompt

         if client is not None:
             assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
@@ -128,64 +134,50 @@ class MistralModel(Model):
             api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
             self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())

-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        """Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
-        check_allow_model_requests()
-        return MistralAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            function_tools,
-            result_tools,
-        )
-
     def name(self) -> str:
-        return f'mistral:{self.model_name}'
-
-
-@dataclass
-class MistralAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Mistral models."""
-
-    client: Mistral
-    model_name: MistralModelName
-    allow_text_result: bool
-    function_tools: list[ToolDefinition]
-    result_tools: list[ToolDefinition]
-    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+        return f'mistral:{self._model_name}'

     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
-        response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
+        )
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
-        response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._stream_completions_create(
+            messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
-            yield await self._process_streamed_response(self.result_tools, response)
+            yield await self._process_streamed_response(model_request_parameters.result_tools, response)

     async def _completions_create(
-        self, messages: list[ModelMessage], model_settings: MistralModelSettings
+        self,
+        messages: list[ModelMessage],
+        model_settings: MistralModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
         response = await self.client.chat.complete_async(
-            model=str(self.model_name),
+            model=str(self._model_name),
             messages=list(chain(*(self._map_message(m) for m in messages))),
             n=1,
-            tools=self._map_function_and_result_tools_definition() or UNSET,
-            tool_choice=self._get_tool_choice(),
+            tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+            tool_choice=self._get_tool_choice(model_request_parameters),
             stream=False,
             max_tokens=model_settings.get('max_tokens', UNSET),
             temperature=model_settings.get('temperature', UNSET),
@@ -200,19 +192,24 @@ class MistralAgentModel(AgentModel):
         self,
         messages: list[ModelMessage],
         model_settings: MistralModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))

-        if self.result_tools and self.function_tools or self.function_tools:
+        if (
+            model_request_parameters.result_tools
+            and model_request_parameters.function_tools
+            or model_request_parameters.function_tools
+        ):
             # Function Calling
             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 n=1,
-                tools=self._map_function_and_result_tools_definition() or UNSET,
-                tool_choice=self._get_tool_choice(),
+                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tool_choice=self._get_tool_choice(model_request_parameters),
                 temperature=model_settings.get('temperature', UNSET),
                 top_p=model_settings.get('top_p', 1),
                 max_tokens=model_settings.get('max_tokens', UNSET),
@@ -221,14 +218,14 @@ class MistralAgentModel(AgentModel):
                 frequency_penalty=model_settings.get('frequency_penalty'),
             )

-        elif self.result_tools:
+        elif model_request_parameters.result_tools:
             # Json Mode
-            parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]
+            parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.result_tools]
             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
             mistral_messages.append(user_output_format_message)

             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 response_format={'type': 'json_object'},
                 stream=True,
@@ -237,14 +234,14 @@ class MistralAgentModel(AgentModel):
         else:
             # Stream Mode
             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 stream=True,
             )
         assert response, 'A unexpected empty response from Mistral.'
         return response

-    def _get_tool_choice(self) -> MistralToolChoiceEnum | None:
+    def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> MistralToolChoiceEnum | None:
         """Get tool choice for the model.

         - "auto": Default mode. Model decides if it uses the tool or not.
@@ -252,19 +249,23 @@ class MistralAgentModel(AgentModel):
         - "none": Prevents tool use.
         - "required": Forces tool use.
         """
-        if not self.function_tools and not self.result_tools:
+        if not model_request_parameters.function_tools and not model_request_parameters.result_tools:
             return None
-        elif not self.allow_text_result:
+        elif not model_request_parameters.allow_text_result:
             return 'required'
         else:
             return 'auto'

-    def _map_function_and_result_tools_definition(self) -> list[MistralTool] | None:
+    def _map_function_and_result_tools_definition(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> list[MistralTool] | None:
         """Map function and result tools to MistralTool format.

         Returns None if both function_tools and result_tools are empty.
         """
-        all_tools: list[ToolDefinition] = self.function_tools + self.result_tools
+        all_tools: list[ToolDefinition] = (
+            model_request_parameters.function_tools + model_request_parameters.result_tools
+        )
         tools = [
             MistralTool(
                 function=MistralFunction(name=r.name, parameters=r.parameters_json_schema, description=r.description)
@@ -292,10 +293,10 @@ class MistralAgentModel(AgentModel):

         if isinstance(tool_calls, list):
             for tool_call in tool_calls:
-                tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
+                tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
                 parts.append(tool)

-        return ModelResponse(parts, model_name=self.model_name, timestamp=timestamp)
+        return ModelResponse(parts, model_name=self._model_name, timestamp=timestamp)

     async def _process_streamed_response(
         self,
@@ -315,13 +316,21 @@ class MistralAgentModel(AgentModel):

         return MistralStreamedResponse(
             _response=peekable_response,
-            _model_name=self.model_name,
+            _model_name=self._model_name,
             _timestamp=timestamp,
             _result_tools={c.name: c for c in result_tools},
         )

     @staticmethod
-    def _map_tool_call(t: ToolCallPart) -> MistralToolCall:
+    def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
+        """Maps a MistralToolCall to a ToolCall."""
+        tool_call_id = tool_call.id or None
+        func_call = tool_call.function
+
+        return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
+
+    @staticmethod
+    def _map_pydantic_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
         """Maps a pydantic-ai ToolCall to a MistralToolCall."""
         return MistralToolCall(
             id=t.tool_call_id,
@@ -437,7 +446,7 @@ class MistralAgentModel(AgentModel):
                 if isinstance(part, TextPart):
                     content_chunks.append(MistralTextChunk(text=part.content))
                 elif isinstance(part, ToolCallPart):
-                    tool_calls.append(cls._map_tool_call(part))
+                    tool_calls.append(cls._map_pydantic_to_mistral_tool_call(part))
                 else:
                     assert_never(part)
             yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
@@ -563,14 +572,6 @@ SIMPLE_JSON_TYPE_MAPPING = {
 }


-def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
-    """Maps a MistralToolCall to a ToolCall."""
-    tool_call_id = tool_call.id or None
-    func_call = tool_call.function
-
-    return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
-
-
 def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage: