pydantic-ai-slim 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff shows the changes between the two package versions as published to the public registry. It is provided for informational purposes only and reflects the packages exactly as released.


pydantic_ai/models/groq.py:

@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, cast, overload
+from typing import Literal, Union, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -28,8 +28,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -45,7 +45,7 @@ except ImportError as _import_error:
         "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
     ) from _import_error
 
-GroqModelName = Literal[
+LatestGroqModelNames = Literal[
     'llama-3.3-70b-versatile',
     'llama-3.3-70b-specdec',
     'llama-3.1-8b-instant',
@@ -58,8 +58,14 @@ GroqModelName = Literal[
     'mixtral-8x7b-32768',
     'gemma2-9b-it',
 ]
-"""Named Groq models.
+"""Latest Groq models."""
 
+GroqModelName = Union[str, LatestGroqModelNames]
+"""
+Possible Groq model names.
+
+Since Groq supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
 See [the Groq docs](https://console.groq.com/docs/models) for a full list.
 """
 
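The widened alias keeps editor suggestions for the listed names while accepting arbitrary strings. A minimal sketch of the effect (reusing two of the names above; the date-stamped name is hypothetical):

```python
from typing import Literal, Union

# Subset of the names listed above, for illustration.
LatestGroqModelNames = Literal['llama-3.3-70b-versatile', 'llama-3.1-8b-instant']
GroqModelName = Union[str, LatestGroqModelNames]

known: GroqModelName = 'llama-3.3-70b-versatile'  # listed name, still autocompleted
dated: GroqModelName = 'llama-3.1-8b-instant-2024-12-01'  # hypothetical date-stamped name, now also valid
```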
@@ -79,9 +85,11 @@ class GroqModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    model_name: GroqModelName
     client: AsyncGroq = field(repr=False)
 
+    _model_name: GroqModelName = field(repr=False)
+    _system: str | None = field(default='groq', repr=False)
+
     def __init__(
         self,
         model_name: GroqModelName,
@@ -102,7 +110,7 @@ class GroqModel(Model):
                 client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.model_name = model_name
+        self._model_name = model_name
         if groq_client is not None:
             assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
@@ -112,81 +120,74 @@ class GroqModel(Model):
         else:
             self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())
 
-    async def agent_model(
+    async def request(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return GroqAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
+        response = await self._completions_create(
+            messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
-
-    def name(self) -> str:
-        return f'groq:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
-        return {
-            'type': 'function',
-            'function': {
-                'name': f.name,
-                'description': f.description,
-                'parameters': f.parameters_json_schema,
-            },
-        }
-
-
-@dataclass
-class GroqAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Groq models."""
-
-    client: AsyncGroq
-    model_name: str
-    allow_text_result: bool
-    tools: list[chat.ChatCompletionToolParam]
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)
 
+    @property
+    def model_name(self) -> GroqModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
    ) -> chat.ChatCompletion:
        pass
 
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+        tools = self._get_tools(model_request_parameters)
         # standalone function to make it easier to override
-        if not self.tools:
+        if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not self.allow_text_result:
+        elif not model_request_parameters.allow_text_result:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
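With `agent_model` removed, a single `request` call now carries the per-step parameters directly. A hedged sketch of calling the new API, assuming `ModelRequestParameters` accepts the three fields this diff reads from it (`function_tools`, `allow_text_result`, `result_tools`):

```python
import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.groq import GroqModel

async def main() -> None:
    model = GroqModel('llama-3.3-70b-versatile')  # api_key falls back to GROQ_API_KEY
    params = ModelRequestParameters(
        function_tools=[],       # no tools registered for this step
        allow_text_result=True,  # so tool_choice above resolves to None
        result_tools=[],
    )
    request = ModelRequest(parts=[UserPromptPart(content='Hello!')])
    response, usage = await model.request([request], None, params)
    print(response.parts, usage)

asyncio.run(main())
```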
@@ -194,11 +195,11 @@ class GroqAgentModel(AgentModel):
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))
 
         return await self.client.chat.completions.create(
-            model=str(self.model_name),
+            model=str(self._model_name),
             messages=groq_messages,
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=self.tools or NOT_GIVEN,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
@@ -221,7 +222,7 @@ class GroqAgentModel(AgentModel):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
+        return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -232,15 +233,20 @@ class GroqAgentModel(AgentModel):
 
         return GroqStreamedResponse(
             _response=peekable_response,
-            _model_name=self.model_name,
+            _model_name=self._model_name,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
         )
 
-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
+            yield from self._map_user_message(message)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -248,7 +254,7 @@ class GroqAgentModel(AgentModel):
                 if isinstance(item, TextPart):
                     texts.append(item.content)
                 elif isinstance(item, ToolCallPart):
-                    tool_calls.append(_map_tool_call(item))
+                    tool_calls.append(self._map_tool_call(item))
                 else:
                     assert_never(item)
             message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
@@ -262,6 +268,25 @@ class GroqAgentModel(AgentModel):
         else:
             assert_never(message)
 
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
+        return chat.ChatCompletionMessageToolCallParam(
+            id=_guard_tool_call_id(t=t, model_source='Groq'),
+            type='function',
+            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+        return {
+            'type': 'function',
+            'function': {
+                'name': f.name,
+                'description': f.description,
+                'parameters': f.parameters_json_schema,
+            },
+        }
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
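For reference, `_map_tool_definition` turns a pydantic-ai `ToolDefinition` into the OpenAI-style function payload Groq expects. A small sketch (the tool itself is made up):

```python
from pydantic_ai.tools import ToolDefinition

# A made-up tool definition for illustration.
weather_tool = ToolDefinition(
    name='get_weather',
    description='Get the current weather for a city.',
    parameters_json_schema={
        'type': 'object',
        'properties': {'city': {'type': 'string'}},
        'required': ['city'],
    },
)

# _map_tool_definition(weather_tool) would produce a dict shaped like:
expected = {
    'type': 'function',
    'function': {
        'name': 'get_weather',
        'description': 'Get the current weather for a city.',
        'parameters': weather_tool.parameters_json_schema,
    },
}
print(expected)
```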
@@ -290,6 +315,7 @@ class GroqAgentModel(AgentModel):
 class GroqStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Groq models."""
 
+    _model_name: GroqModelName
     _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
 
@@ -318,18 +344,17 @@ class GroqStreamedResponse(StreamedResponse):
             if maybe_event is not None:
                 yield maybe_event
 
+    @property
+    def model_name(self) -> GroqModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
 
 
-def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_call_id(t=t, model_source='Groq'),
-        type='function',
-        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
-    )
-
-
 def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
     response_usage = None
     if isinstance(completion, ChatCompletion):
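Throughout this release the classes move to the same convention: a private dataclass field exposed through a read-only property. The pattern in miniature (names are illustrative):

```python
from dataclasses import dataclass, field

@dataclass
class Example:
    # Private storage, hidden from repr, set once at construction.
    _model_name: str = field(repr=False)

    @property
    def model_name(self) -> str:
        """The model name, read-only from outside."""
        return self._model_name

e = Example(_model_name='llama-3.3-70b-versatile')
print(e.model_name)  # llama-3.3-70b-versatile
```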
pydantic_ai/models/mistral.py:

@@ -31,8 +31,8 @@ from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -70,12 +70,12 @@ except ImportError as e:
         "you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
     ) from e
 
-NamedMistralModels = Literal[
+LatestMistralModelNames = Literal[
     'mistral-large-latest', 'mistral-small-latest', 'codestral-latest', 'mistral-moderation-latest'
 ]
-"""Latest / most popular named Mistral models."""
+"""Latest Mistral models."""
 
-MistralModelName = Union[NamedMistralModels, str]
+MistralModelName = Union[str, LatestMistralModelNames]
 """Possible Mistral model names.
 
 Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but
@@ -99,8 +99,11 @@ class MistralModel(Model):
     [API Documentation](https://docs.mistral.ai/)
     """
 
-    model_name: MistralModelName
     client: Mistral = field(repr=False)
+    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+
+    _model_name: MistralModelName = field(repr=False)
+    _system: str | None = field(default='mistral', repr=False)
 
     def __init__(
         self,
@@ -109,6 +112,7 @@ class MistralModel(Model):
         api_key: str | Callable[[], str | None] | None = None,
         client: Mistral | None = None,
         http_client: AsyncHTTPClient | None = None,
+        json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
     ):
         """Initialize a Mistral model.
 
@@ -117,8 +121,10 @@ class MistralModel(Model):
             api_key: The API key to use for authentication, if unset uses `MISTRAL_API_KEY` environment variable.
             client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+            json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
         """
-        self.model_name = model_name
+        self._model_name = model_name
+        self.json_mode_schema_prompt = json_mode_schema_prompt
 
         if client is not None:
             assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
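The new `json_mode_schema_prompt` argument makes the JSON-mode instruction customizable; judging by the `{schema}` placeholder and the JSON-mode branch later in this diff, it is presumably formatted with the result tool's JSON schema and appended as a user message. A sketch of overriding it:

```python
from pydantic_ai.models.mistral import MistralModel

# api_key falls back to the MISTRAL_API_KEY environment variable.
model = MistralModel(
    'mistral-large-latest',
    json_mode_schema_prompt=(
        'Reply only with a JSON object matching this schema:\n```\n{schema}\n```\n'
    ),
)
```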
@@ -128,64 +134,60 @@ class MistralModel(Model):
         api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
         self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
 
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        """Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
-        check_allow_model_requests()
-        return MistralAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            function_tools,
-            result_tools,
-        )
-
     def name(self) -> str:
-        return f'mistral:{self.model_name}'
-
-
-@dataclass
-class MistralAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Mistral models."""
-
-    client: Mistral
-    model_name: MistralModelName
-    allow_text_result: bool
-    function_tools: list[ToolDefinition]
-    result_tools: list[ToolDefinition]
-    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+        return f'mistral:{self._model_name}'
 
     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
-        response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
+        )
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
-        response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._stream_completions_create(
+            messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
-            yield await self._process_streamed_response(self.result_tools, response)
+            yield await self._process_streamed_response(model_request_parameters.result_tools, response)
+
+    @property
+    def model_name(self) -> MistralModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
 
     async def _completions_create(
-        self, messages: list[ModelMessage], model_settings: MistralModelSettings
+        self,
+        messages: list[ModelMessage],
+        model_settings: MistralModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
         response = await self.client.chat.complete_async(
-            model=str(self.model_name),
+            model=str(self._model_name),
             messages=list(chain(*(self._map_message(m) for m in messages))),
             n=1,
-            tools=self._map_function_and_result_tools_definition() or UNSET,
-            tool_choice=self._get_tool_choice(),
+            tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+            tool_choice=self._get_tool_choice(model_request_parameters),
             stream=False,
             max_tokens=model_settings.get('max_tokens', UNSET),
             temperature=model_settings.get('temperature', UNSET),
@@ -200,19 +202,24 @@ class MistralAgentModel(AgentModel):
         self,
         messages: list[ModelMessage],
         model_settings: MistralModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-        if self.result_tools and self.function_tools or self.function_tools:
+        if (
+            model_request_parameters.result_tools
+            and model_request_parameters.function_tools
+            or model_request_parameters.function_tools
+        ):
             # Function Calling
             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 n=1,
-                tools=self._map_function_and_result_tools_definition() or UNSET,
-                tool_choice=self._get_tool_choice(),
+                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tool_choice=self._get_tool_choice(model_request_parameters),
                 temperature=model_settings.get('temperature', UNSET),
                 top_p=model_settings.get('top_p', 1),
                 max_tokens=model_settings.get('max_tokens', UNSET),
@@ -221,14 +228,14 @@ class MistralAgentModel(AgentModel):
                 frequency_penalty=model_settings.get('frequency_penalty'),
             )
 
-        elif self.result_tools:
+        elif model_request_parameters.result_tools:
             # Json Mode
-            parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]
+            parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.result_tools]
             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
             mistral_messages.append(user_output_format_message)
 
             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 response_format={'type': 'json_object'},
                 stream=True,
@@ -237,14 +244,14 @@ class MistralAgentModel(AgentModel):
         else:
             # Stream Mode
             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 stream=True,
             )
         assert response, 'A unexpected empty response from Mistral.'
         return response
 
-    def _get_tool_choice(self) -> MistralToolChoiceEnum | None:
+    def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> MistralToolChoiceEnum | None:
         """Get tool choice for the model.
 
         - "auto": Default mode. Model decides if it uses the tool or not.
@@ -252,19 +259,23 @@ class MistralAgentModel(AgentModel):
         - "none": Prevents tool use.
         - "required": Forces tool use.
         """
-        if not self.function_tools and not self.result_tools:
+        if not model_request_parameters.function_tools and not model_request_parameters.result_tools:
             return None
-        elif not self.allow_text_result:
+        elif not model_request_parameters.allow_text_result:
             return 'required'
         else:
             return 'auto'
 
-    def _map_function_and_result_tools_definition(self) -> list[MistralTool] | None:
+    def _map_function_and_result_tools_definition(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> list[MistralTool] | None:
         """Map function and result tools to MistralTool format.
 
         Returns None if both function_tools and result_tools are empty.
         """
-        all_tools: list[ToolDefinition] = self.function_tools + self.result_tools
+        all_tools: list[ToolDefinition] = (
+            model_request_parameters.function_tools + model_request_parameters.result_tools
+        )
         tools = [
             MistralTool(
                 function=MistralFunction(name=r.name, parameters=r.parameters_json_schema, description=r.description)
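The branching in `_get_tool_choice` (and its Groq counterpart earlier in this diff) distills to a small pure function; restated here for clarity:

```python
from typing import Optional

def tool_choice(function_tools: list, result_tools: list, allow_text_result: bool) -> Optional[str]:
    """Mirror of the decision in _get_tool_choice above."""
    if not function_tools and not result_tools:
        return None  # no tools registered: nothing to choose
    if not allow_text_result:
        return 'required'  # only a tool call is an acceptable result
    return 'auto'  # model may answer in text or call a tool

assert tool_choice([], [], True) is None
assert tool_choice(['t'], [], False) == 'required'
assert tool_choice(['t'], [], True) == 'auto'
```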
@@ -292,10 +303,10 @@ class MistralAgentModel(AgentModel):
 
         if isinstance(tool_calls, list):
             for tool_call in tool_calls:
-                tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
+                tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
                 parts.append(tool)
 
-        return ModelResponse(parts, model_name=self.model_name, timestamp=timestamp)
+        return ModelResponse(parts, model_name=response.model, timestamp=timestamp)
 
     async def _process_streamed_response(
         self,
@@ -315,13 +326,21 @@ class MistralAgentModel(AgentModel):
 
         return MistralStreamedResponse(
             _response=peekable_response,
-            _model_name=self.model_name,
+            _model_name=self._model_name,
             _timestamp=timestamp,
             _result_tools={c.name: c for c in result_tools},
         )
 
     @staticmethod
-    def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
+    def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
+        """Maps a MistralToolCall to a ToolCall."""
+        tool_call_id = tool_call.id or None
+        func_call = tool_call.function
+
+        return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
+
+    @staticmethod
+    def _map_pydantic_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
         """Maps a pydantic-ai ToolCall to a MistralToolCall."""
         return MistralToolCall(
             id=t.tool_call_id,
@@ -437,7 +456,7 @@ class MistralAgentModel(AgentModel):
             if isinstance(part, TextPart):
                 content_chunks.append(MistralTextChunk(text=part.content))
             elif isinstance(part, ToolCallPart):
-                tool_calls.append(cls._map_to_mistral_tool_call(part))
+                tool_calls.append(cls._map_pydantic_to_mistral_tool_call(part))
             else:
                 assert_never(part)
         yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
@@ -452,6 +471,7 @@ MistralToolCallId = Union[str, None]
 class MistralStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Mistral models."""
 
+    _model_name: MistralModelName
     _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
     _result_tools: dict[str, ToolDefinition]
@@ -493,7 +513,14 @@ class MistralStreamedResponse(StreamedResponse):
                     vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id
                 )
 
+    @property
+    def model_name(self) -> MistralModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
 
     @staticmethod
@@ -563,14 +590,6 @@ SIMPLE_JSON_TYPE_MAPPING = {
 }
 
 
-def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
-    """Maps a MistralToolCall to a ToolCall."""
-    tool_call_id = tool_call.id or None
-    func_call = tool_call.function
-
-    return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
-
-
 def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage: