pydantic-ai-slim 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, cast, overload
+from typing import Literal, Union, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -28,8 +28,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -45,7 +45,7 @@ except ImportError as _import_error:
         "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
     ) from _import_error
 
-GroqModelName = Literal[
+LatestGroqModelNames = Literal[
     'llama-3.3-70b-versatile',
     'llama-3.3-70b-specdec',
     'llama-3.1-8b-instant',
@@ -58,8 +58,14 @@ GroqModelName = Literal[
     'mixtral-8x7b-32768',
     'gemma2-9b-it',
 ]
-"""Named Groq models.
+"""Latest Groq models."""
 
+GroqModelName = Union[str, LatestGroqModelNames]
+"""
+Possible Groq model names.
+
+Since Groq supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
 See [the Groq docs](https://console.groq.com/docs/models) for a full list.
 """
 
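For illustration, a minimal sketch of what the widened `GroqModelName` type accepts after this change (the second model name below is made up, standing in for a future date-stamped release that is not yet in the Literal):

from pydantic_ai.models.groq import GroqModel

# A name from LatestGroqModelNames still type-checks...
m1 = GroqModel('llama-3.3-70b-versatile', api_key='not-a-real-key')
# ...and, because GroqModelName now also includes plain str, so does any other name,
# e.g. a hypothetical date-stamped variant.
m2 = GroqModel('llama-3.3-70b-versatile-2025-01-15', api_key='not-a-real-key')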
@@ -79,9 +85,11 @@ class GroqModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    model_name: GroqModelName
     client: AsyncGroq = field(repr=False)
 
+    _model_name: GroqModelName = field(repr=False)
+    _system: str | None = field(default='groq', repr=False)
+
     def __init__(
         self,
         model_name: GroqModelName,
@@ -102,7 +110,7 @@ class GroqModel(Model):
                 client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.model_name = model_name
+        self._model_name = model_name
         if groq_client is not None:
             assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
@@ -112,81 +120,64 @@ class GroqModel(Model):
         else:
             self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())
 
-    async def agent_model(
+    async def request(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return GroqAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
+        response = await self._completions_create(
+            messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
-
-    def name(self) -> str:
-        return f'groq:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
-        return {
-            'type': 'function',
-            'function': {
-                'name': f.name,
-                'description': f.description,
-                'parameters': f.parameters_json_schema,
-            },
-        }
-
-
-@dataclass
-class GroqAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Groq models."""
-
-    client: AsyncGroq
-    model_name: str
-    allow_text_result: bool
-    tools: list[chat.ChatCompletionToolParam]
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: GroqModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+        tools = self._get_tools(model_request_parameters)
         # standalone function to make it easier to override
-        if not self.tools:
+        if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not self.allow_text_result:
+        elif not model_request_parameters.allow_text_result:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
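The hunk above removes the intermediate `GroqAgentModel` class: per-run tool information now travels with every call as a `ModelRequestParameters` object instead of being baked into an `AgentModel` up front. A minimal sketch of the new call shape, assuming `ModelRequestParameters` accepts as keyword arguments the three fields this diff reads from it (`function_tools`, `allow_text_result`, `result_tools`); the real signature may differ slightly:

import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.groq import GroqModel


async def main() -> None:
    model = GroqModel('llama-3.3-70b-versatile', api_key='not-a-real-key')
    params = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
    messages = [ModelRequest(parts=[UserPromptPart(content='Hello!')])]
    # 0.0.21: agent_model = await model.agent_model(
    #             function_tools=[], allow_text_result=True, result_tools=[])
    #         response, usage = await agent_model.request(messages, None)
    # 0.0.23: the AgentModel indirection is gone; parameters accompany each request.
    response, usage = await model.request(messages, None, params)
    print(response, usage)


asyncio.run(main())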
@@ -194,11 +185,11 @@ class GroqAgentModel(AgentModel):
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))
 
         return await self.client.chat.completions.create(
-            model=str(self.model_name),
+            model=str(self._model_name),
             messages=groq_messages,
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=self.tools or NOT_GIVEN,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
@@ -221,7 +212,7 @@ class GroqAgentModel(AgentModel):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
+        return ModelResponse(items, model_name=self._model_name, timestamp=timestamp)
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -232,15 +223,20 @@ class GroqAgentModel(AgentModel):
 
         return GroqStreamedResponse(
             _response=peekable_response,
-            _model_name=self.model_name,
+            _model_name=self._model_name,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
         )
 
-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
+            yield from self._map_user_message(message)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -248,7 +244,7 @@ class GroqAgentModel(AgentModel):
                 if isinstance(item, TextPart):
                     texts.append(item.content)
                 elif isinstance(item, ToolCallPart):
-                    tool_calls.append(_map_tool_call(item))
+                    tool_calls.append(self._map_tool_call(item))
                 else:
                     assert_never(item)
             message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
@@ -262,6 +258,25 @@ class GroqAgentModel(AgentModel):
         else:
             assert_never(message)
 
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
+        return chat.ChatCompletionMessageToolCallParam(
+            id=_guard_tool_call_id(t=t, model_source='Groq'),
+            type='function',
+            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+        return {
+            'type': 'function',
+            'function': {
+                'name': f.name,
+                'description': f.description,
+                'parameters': f.parameters_json_schema,
+            },
+        }
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
@@ -322,14 +337,6 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_call_id(t=t, model_source='Groq'),
-        type='function',
-        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
-    )
-
-
 def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
     response_usage = None
     if isinstance(completion, ChatCompletion):
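The remaining hunks apply the same consolidation to the Mistral model module: `MistralAgentModel` is folded into `MistralModel`, per-request tool information moves into `ModelRequestParameters`, and the module-level mapping helpers become methods on the class.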
@@ -31,8 +31,8 @@ from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -70,12 +70,12 @@ except ImportError as e:
         "you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
     ) from e
 
-NamedMistralModels = Literal[
+LatestMistralModelNames = Literal[
     'mistral-large-latest', 'mistral-small-latest', 'codestral-latest', 'mistral-moderation-latest'
 ]
-"""Latest / most popular named Mistral models."""
+"""Latest Mistral models."""
 
-MistralModelName = Union[NamedMistralModels, str]
+MistralModelName = Union[str, LatestMistralModelNames]
 """Possible Mistral model names.
 
 Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but
@@ -99,8 +99,11 @@ class MistralModel(Model):
     [API Documentation](https://docs.mistral.ai/)
     """
 
-    model_name: MistralModelName
     client: Mistral = field(repr=False)
+    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+
+    _model_name: MistralModelName = field(repr=False)
+    _system: str | None = field(default='mistral', repr=False)
 
     def __init__(
         self,
@@ -109,6 +112,7 @@ class MistralModel(Model):
         api_key: str | Callable[[], str | None] | None = None,
         client: Mistral | None = None,
         http_client: AsyncHTTPClient | None = None,
+        json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
     ):
         """Initialize a Mistral model.
 
@@ -117,8 +121,10 @@ class MistralModel(Model):
             api_key: The API key to use for authentication, if unset uses `MISTRAL_API_KEY` environment variable.
             client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+            json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
         """
-        self.model_name = model_name
+        self._model_name = model_name
+        self.json_mode_schema_prompt = json_mode_schema_prompt
 
         if client is not None:
             assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
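Since the JSON-mode prompt is now a documented constructor argument, it can be customised per model instance. A small sketch; the `{schema}` placeholder is presumably substituted with the result tool's JSON schema when the model falls back to JSON mode, as suggested by `_generate_user_output_format` further down this diff:

from pydantic_ai.models.mistral import MistralModel

model = MistralModel(
    'mistral-small-latest',
    api_key='not-a-real-key',
    json_mode_schema_prompt='Reply only with a JSON object matching this schema:\n```\n{schema}\n```\n',
)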
@@ -128,64 +134,50 @@ class MistralModel(Model):
         api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
         self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
 
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        """Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
-        check_allow_model_requests()
-        return MistralAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            function_tools,
-            result_tools,
-        )
-
     def name(self) -> str:
-        return f'mistral:{self.model_name}'
-
-
-@dataclass
-class MistralAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Mistral models."""
-
-    client: Mistral
-    model_name: MistralModelName
-    allow_text_result: bool
-    function_tools: list[ToolDefinition]
-    result_tools: list[ToolDefinition]
-    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+        return f'mistral:{self._model_name}'
 
     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
-        response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
+        )
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
-        response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._stream_completions_create(
+            messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
-            yield await self._process_streamed_response(self.result_tools, response)
+            yield await self._process_streamed_response(model_request_parameters.result_tools, response)
 
     async def _completions_create(
-        self, messages: list[ModelMessage], model_settings: MistralModelSettings
+        self,
+        messages: list[ModelMessage],
+        model_settings: MistralModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
         response = await self.client.chat.complete_async(
-            model=str(self.model_name),
+            model=str(self._model_name),
             messages=list(chain(*(self._map_message(m) for m in messages))),
             n=1,
-            tools=self._map_function_and_result_tools_definition() or UNSET,
-            tool_choice=self._get_tool_choice(),
+            tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+            tool_choice=self._get_tool_choice(model_request_parameters),
             stream=False,
             max_tokens=model_settings.get('max_tokens', UNSET),
             temperature=model_settings.get('temperature', UNSET),
@@ -200,19 +192,24 @@ class MistralAgentModel(AgentModel):
         self,
         messages: list[ModelMessage],
         model_settings: MistralModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-        if self.result_tools and self.function_tools or self.function_tools:
+        if (
+            model_request_parameters.result_tools
+            and model_request_parameters.function_tools
+            or model_request_parameters.function_tools
+        ):
             # Function Calling
             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 n=1,
-                tools=self._map_function_and_result_tools_definition() or UNSET,
-                tool_choice=self._get_tool_choice(),
+                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tool_choice=self._get_tool_choice(model_request_parameters),
                 temperature=model_settings.get('temperature', UNSET),
                 top_p=model_settings.get('top_p', 1),
                 max_tokens=model_settings.get('max_tokens', UNSET),
@@ -221,14 +218,14 @@ class MistralAgentModel(AgentModel):
                 frequency_penalty=model_settings.get('frequency_penalty'),
             )
 
-        elif self.result_tools:
+        elif model_request_parameters.result_tools:
             # Json Mode
-            parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]
+            parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.result_tools]
             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
             mistral_messages.append(user_output_format_message)
 
             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 response_format={'type': 'json_object'},
                 stream=True,
@@ -237,14 +234,14 @@ class MistralAgentModel(AgentModel):
         else:
             # Stream Mode
             response = await self.client.chat.stream_async(
-                model=str(self.model_name),
+                model=str(self._model_name),
                 messages=mistral_messages,
                 stream=True,
             )
         assert response, 'A unexpected empty response from Mistral.'
         return response
 
-    def _get_tool_choice(self) -> MistralToolChoiceEnum | None:
+    def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> MistralToolChoiceEnum | None:
         """Get tool choice for the model.
 
         - "auto": Default mode. Model decides if it uses the tool or not.
@@ -252,19 +249,23 @@ class MistralAgentModel(AgentModel):
         - "none": Prevents tool use.
         - "required": Forces tool use.
         """
-        if not self.function_tools and not self.result_tools:
+        if not model_request_parameters.function_tools and not model_request_parameters.result_tools:
             return None
-        elif not self.allow_text_result:
+        elif not model_request_parameters.allow_text_result:
             return 'required'
         else:
             return 'auto'
 
-    def _map_function_and_result_tools_definition(self) -> list[MistralTool] | None:
+    def _map_function_and_result_tools_definition(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> list[MistralTool] | None:
         """Map function and result tools to MistralTool format.
 
         Returns None if both function_tools and result_tools are empty.
         """
-        all_tools: list[ToolDefinition] = self.function_tools + self.result_tools
+        all_tools: list[ToolDefinition] = (
+            model_request_parameters.function_tools + model_request_parameters.result_tools
+        )
         tools = [
             MistralTool(
                 function=MistralFunction(name=r.name, parameters=r.parameters_json_schema, description=r.description)
@@ -292,10 +293,10 @@ class MistralAgentModel(AgentModel):
 
         if isinstance(tool_calls, list):
             for tool_call in tool_calls:
-                tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
+                tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
                 parts.append(tool)
 
-        return ModelResponse(parts, model_name=self.model_name, timestamp=timestamp)
+        return ModelResponse(parts, model_name=self._model_name, timestamp=timestamp)
 
     async def _process_streamed_response(
         self,
@@ -315,13 +316,21 @@ class MistralAgentModel(AgentModel):
 
         return MistralStreamedResponse(
             _response=peekable_response,
-            _model_name=self.model_name,
+            _model_name=self._model_name,
             _timestamp=timestamp,
             _result_tools={c.name: c for c in result_tools},
         )
 
     @staticmethod
-    def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
+    def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
+        """Maps a MistralToolCall to a ToolCall."""
+        tool_call_id = tool_call.id or None
+        func_call = tool_call.function
+
+        return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
+
+    @staticmethod
+    def _map_pydantic_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
         """Maps a pydantic-ai ToolCall to a MistralToolCall."""
         return MistralToolCall(
             id=t.tool_call_id,
@@ -437,7 +446,7 @@ class MistralAgentModel(AgentModel):
                 if isinstance(part, TextPart):
                     content_chunks.append(MistralTextChunk(text=part.content))
                 elif isinstance(part, ToolCallPart):
-                    tool_calls.append(cls._map_to_mistral_tool_call(part))
+                    tool_calls.append(cls._map_pydantic_to_mistral_tool_call(part))
                 else:
                     assert_never(part)
             yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
@@ -563,14 +572,6 @@ SIMPLE_JSON_TYPE_MAPPING = {
 }
 
 
-def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
-    """Maps a MistralToolCall to a ToolCall."""
-    tool_call_id = tool_call.id or None
-    func_call = tool_call.function
-
-    return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
-
-
 def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage: