pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of pydantic-ai-slim has been flagged as potentially problematic.

@@ -6,7 +6,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Any, Callable, Literal, Union
+from typing import Any, Callable, Literal, Union, cast
 
 import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
@@ -15,7 +15,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -36,6 +35,7 @@ from . import (
     Model,
     StreamedResponse,
     cached_async_http_client,
+    check_allow_model_requests,
 )
 
 try:
@@ -84,6 +84,12 @@ Since [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_o
 """
 
 
+class MistralModelSettings(ModelSettings):
+    """Settings used for a Mistral model request."""
+
+    # This class is a placeholder for any future mistral-specific settings
+
+
 @dataclass(init=False)
 class MistralModel(Model):
     """A model that uses Mistral.
@@ -130,6 +136,7 @@ class MistralModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         """Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
+        check_allow_model_requests()
         return MistralAgentModel(
             self.client,
             self.model_name,
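
`check_allow_model_requests()` guards against accidental real API calls, raising if requests are disabled. A sketch of the intended use, assuming the module-level `ALLOW_MODEL_REQUESTS` flag in `pydantic_ai.models`:

from pydantic_ai import models

# e.g. at the top of a test suite: block any real model request.
models.ALLOW_MODEL_REQUESTS = False

# From here on, MistralModel.agent_model() (and VertexAIModel, see the final
# hunks below) raises an error instead of silently hitting the network.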
@@ -147,7 +154,7 @@ class MistralAgentModel(AgentModel):
     """Implementation of `AgentModel` for Mistral models."""
 
     client: Mistral
-    model_name: str
+    model_name: MistralModelName
     allow_text_result: bool
     function_tools: list[ToolDefinition]
     result_tools: list[ToolDefinition]
@@ -157,7 +164,7 @@ class MistralAgentModel(AgentModel):
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
-        response = await self._completions_create(messages, model_settings)
+        response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
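
`ModelSettings` and its subclasses are TypedDicts, so the `cast(...)` above is purely a type-checker hint with no runtime cost; an empty dict stands in when no settings were supplied. A generic illustration of the pattern (all names here are hypothetical):

from typing import TypedDict, cast

class BaseSettings(TypedDict, total=False):
    temperature: float

class ProviderSettings(BaseSettings, total=False):
    pass  # placeholder, like MistralModelSettings above

def request(settings: BaseSettings | None) -> None:
    # cast() does nothing at runtime; it only narrows the type for checkers.
    narrowed = cast(ProviderSettings, settings or {})
    print(narrowed.get('temperature', 'unset'))

request(None)                  # -> unset
request({'temperature': 0.2})  # -> 0.2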
@@ -165,15 +172,14 @@ class MistralAgentModel(AgentModel):
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
-        response = await self._stream_completions_create(messages, model_settings)
+        response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(self.result_tools, response)
 
     async def _completions_create(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], model_settings: MistralModelSettings
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
-        model_settings = model_settings or {}
         response = await self.client.chat.complete_async(
             model=str(self.model_name),
             messages=list(chain(*(self._map_message(m) for m in messages))),
@@ -185,6 +191,7 @@ class MistralAgentModel(AgentModel):
             temperature=model_settings.get('temperature', UNSET),
             top_p=model_settings.get('top_p', 1),
             timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+            random_seed=model_settings.get('seed', UNSET),
         )
         assert response, 'A unexpected empty response from Mistral.'
         return response
@@ -192,12 +199,11 @@ class MistralAgentModel(AgentModel):
     async def _stream_completions_create(
         self,
         messages: list[ModelMessage],
-        model_settings: ModelSettings | None,
+        model_settings: MistralModelSettings,
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
-        model_settings = model_settings or {}
 
         if self.result_tools and self.function_tools or self.function_tools:
             # Function Calling
@@ -211,6 +217,8 @@ class MistralAgentModel(AgentModel):
                 top_p=model_settings.get('top_p', 1),
                 max_tokens=model_settings.get('max_tokens', UNSET),
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+                presence_penalty=model_settings.get('presence_penalty'),
+                frequency_penalty=model_settings.get('frequency_penalty'),
             )
 
         elif self.result_tools:
@@ -265,8 +273,7 @@ class MistralAgentModel(AgentModel):
         ]
         return tools if tools else None
 
-    @staticmethod
-    def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
+    def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         assert response.choices, 'Unexpected empty response choice.'
 
@@ -288,10 +295,10 @@ class MistralAgentModel(AgentModel):
             tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
             parts.append(tool)
 
-        return ModelResponse(parts, timestamp=timestamp)
+        return ModelResponse(parts, model_name=self.model_name, timestamp=timestamp)
 
-    @staticmethod
     async def _process_streamed_response(
+        self,
         result_tools: list[ToolDefinition],
         response: MistralEventStreamAsync[MistralCompletionEvent],
     ) -> StreamedResponse:
@@ -306,23 +313,21 @@ class MistralAgentModel(AgentModel):
         else:
             timestamp = datetime.now(tz=timezone.utc)
 
-        return MistralStreamedResponse(peekable_response, timestamp, {c.name: c for c in result_tools})
+        return MistralStreamedResponse(
+            _response=peekable_response,
+            _model_name=self.model_name,
+            _timestamp=timestamp,
+            _result_tools={c.name: c for c in result_tools},
+        )
 
     @staticmethod
     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
         """Maps a pydantic-ai ToolCall to a MistralToolCall."""
-        if isinstance(t.args, ArgsJson):
-            return MistralToolCall(
-                id=t.tool_call_id,
-                type='function',
-                function=MistralFunctionCall(name=t.tool_name, arguments=t.args.args_json),
-            )
-        else:
-            return MistralToolCall(
-                id=t.tool_call_id,
-                type='function',
-                function=MistralFunctionCall(name=t.tool_name, arguments=t.args.args_dict),
-            )
+        return MistralToolCall(
+            id=t.tool_call_id,
+            type='function',
+            function=MistralFunctionCall(name=t.tool_name, arguments=t.args),
+        )
 
     def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
         """Get a message with an example of the expected output format."""
@@ -505,7 +510,7 @@ class MistralStreamedResponse(StreamedResponse):
                 continue
 
             # The following part_id will be thrown away
-            return ToolCallPart.from_raw_args(tool_name=result_tool.name, args=output_json)
+            return ToolCallPart(tool_name=result_tool.name, args=output_json)
 
     @staticmethod
     def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
@@ -563,7 +568,7 @@ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
     tool_call_id = tool_call.id or None
     func_call = tool_call.function
 
-    return ToolCallPart.from_raw_args(func_call.name, func_call.arguments, tool_call_id)
+    return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
 
 
 def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
@@ -594,7 +599,7 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
     elif isinstance(content, str):
         result = content
 
-    # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and reponses`)
+    # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
     if result and len(result) == 0:
         result = None
 
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, Union, overload
+from typing import Literal, Union, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -48,9 +48,17 @@ except ImportError as _import_error:
 OpenAIModelName = Union[ChatModel, str]
 """
 Using this more broad type for the model name instead of the ChatModel definition
-allows this model to be used more easily with other model types (ie, Ollama)
+allows this model to be used more easily with other model types (ie, Ollama, Deepseek)
 """
 
+OpenAISystemPromptRole = Literal['system', 'developer', 'user']
+
+
+class OpenAIModelSettings(ModelSettings):
+    """Settings used for an OpenAI model request."""
+
+    # This class is a placeholder for any future openai-specific settings
+
 
 @dataclass(init=False)
 class OpenAIModel(Model):
@@ -63,6 +71,7 @@ class OpenAIModel(Model):
 
     model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
+    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
 
     def __init__(
         self,
@@ -72,6 +81,7 @@ class OpenAIModel(Model):
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
+        system_prompt_role: OpenAISystemPromptRole | None = None,
     ):
         """Initialize an OpenAI model.
 
@@ -87,6 +97,8 @@ class OpenAIModel(Model):
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                 client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+            system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
+                In the future, this may be inferred from the model name.
         """
         self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
@@ -98,6 +110,7 @@ class OpenAIModel(Model):
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
+        self.system_prompt_role = system_prompt_role
 
     async def agent_model(
         self,
@@ -115,6 +128,7 @@ class OpenAIModel(Model):
             self.model_name,
             allow_text_result,
             tools,
+            self.system_prompt_role,
         )
 
     def name(self) -> str:
@@ -140,35 +154,36 @@ class OpenAIAgentModel(AgentModel):
     model_name: OpenAIModelName
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
+    system_prompt_role: OpenAISystemPromptRole | None
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, model_settings)
+        response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, model_settings)
+        response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings
     ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -180,13 +195,11 @@ class OpenAIAgentModel(AgentModel):
 
         openai_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-        model_settings = model_settings or {}
-
         return await self.client.chat.completions.create(
             model=self.model_name,
             messages=openai_messages,
             n=1,
-            parallel_tool_calls=True if self.tools else NOT_GIVEN,
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
@@ -195,10 +208,13 @@ class OpenAIAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            seed=model_settings.get('seed', NOT_GIVEN),
+            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
         )
 
-    @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
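
With these additions, the base settings keys `seed`, `presence_penalty`, `frequency_penalty`, `logit_bias`, and `parallel_tool_calls` all flow through to the OpenAI client, and `parallel_tool_calls` is no longer forced to `True` whenever tools are registered. A hedged sketch of a run exercising them (assuming these keys exist on the base `ModelSettings`, as the `get` calls above imply):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel

agent = Agent(OpenAIModel('gpt-4o'))  # requires an OpenAI API key
result = agent.run_sync(
    'Pick a random fruit.',
    model_settings={
        'seed': 12345,  # best-effort reproducibility
        'presence_penalty': 0.5,
        'frequency_penalty': 0.5,
        'parallel_tool_calls': False,  # previously hard-coded to True when tools were set
    },
)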
@@ -207,24 +223,26 @@ class OpenAIAgentModel(AgentModel):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
-        return ModelResponse(items, timestamp=timestamp)
+                items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
+        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
 
-    @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        return OpenAIStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc))
+        return OpenAIStreamedResponse(
+            _model_name=self.model_name,
+            _response=peekable_response,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )
 
-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
+            yield from self._map_user_message(message)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
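
`ModelResponse` now records which model produced it. A small sketch (the `model_name` keyword matches the constructor calls above; `timestamp` presumably has a default, since the diff omits it in several of these calls):

from pydantic_ai.messages import ModelResponse, TextPart

response = ModelResponse(parts=[TextPart('hello')], model_name='gpt-4o')
# Downstream code can attribute each response in a message history to its model.
print(response.model_name)  # -> gpt-4o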
@@ -246,11 +264,15 @@ class OpenAIAgentModel(AgentModel):
         else:
             assert_never(message)
 
-    @classmethod
-    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
-                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+                if self.system_prompt_role == 'developer':
+                    yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
+                elif self.system_prompt_role == 'user':
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+                else:
+                    yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
             elif isinstance(part, UserPromptPart):
                 yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
             elif isinstance(part, ToolReturnPart):
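
The new `system_prompt_role` decides which chat role carries the system prompt: `'developer'` for the newer OpenAI reasoning models, `'user'` for models that reject system messages, and `'system'` (the default) otherwise. A sketch of configuring it (the model name is illustrative):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel

# System prompts will be sent as role='developer' messages.
model = OpenAIModel('o1', system_prompt_role='developer')
agent = Agent(model, system_prompt='Be concise.')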
@@ -12,7 +12,6 @@ import pydantic_core
 
 from .. import _utils
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -34,6 +33,20 @@ from . import (
 from .function import _estimate_string_tokens, _estimate_usage  # pyright: ignore[reportPrivateUsage]
 
 
+@dataclass
+class _TextResult:
+    """A private wrapper class to tag a result that came from the custom_result_text field."""
+
+    value: str | None
+
+
+@dataclass
+class _FunctionToolResult:
+    """A wrapper class to tag a result that came from the custom_result_args field."""
+
+    value: Any | None
+
+
 @dataclass
 class TestModel(Model):
     """A model specifically for testing purposes.
@@ -53,7 +66,7 @@ class TestModel(Model):
     call_tools: list[str] | Literal['all'] = 'all'
     """List of tools to call. If `'all'`, all tools will be called."""
     custom_result_text: str | None = None
-    """If set, this text is return as the final result."""
+    """If set, this text is returned as the final result."""
     custom_result_args: Any | None = None
     """If set, these args will be passed to the result tool."""
     seed: int = 0
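
A minimal sketch of driving these fields from a test (no network access involved, since `TestModel` generates responses locally):

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel(custom_result_text='canned answer'))
result = agent.run_sync('anything')
assert result.data == 'canned answer'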
@@ -95,21 +108,21 @@ class TestModel(Model):
         if self.custom_result_text is not None:
             assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
             assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
-            result: _utils.Either[str | None, Any | None] = _utils.Either(left=self.custom_result_text)
+            result: _TextResult | _FunctionToolResult = _TextResult(self.custom_result_text)
         elif self.custom_result_args is not None:
             assert result_tools is not None, 'No result tools provided, but `custom_result_args` is set.'
             result_tool = result_tools[0]
 
             if k := result_tool.outer_typed_dict_key:
-                result = _utils.Either(right={k: self.custom_result_args})
+                result = _FunctionToolResult({k: self.custom_result_args})
             else:
-                result = _utils.Either(right=self.custom_result_args)
+                result = _FunctionToolResult(self.custom_result_args)
         elif allow_text_result:
-            result = _utils.Either(left=None)
+            result = _TextResult(None)
         elif result_tools:
-            result = _utils.Either(right=None)
+            result = _FunctionToolResult(None)
         else:
-            result = _utils.Either(left=None)
+            result = _TextResult(None)
 
         return TestAgentModel(tool_calls, result, result_tools, self.seed)
 
@@ -126,9 +139,10 @@ class TestAgentModel(AgentModel):
 
     tool_calls: list[tuple[str, ToolDefinition]]
     # left means the text is plain text; right means it's a function call
-    result: _utils.Either[str | None, Any | None]
+    result: _TextResult | _FunctionToolResult
     result_tools: list[ToolDefinition]
     seed: int
+    model_name: str = 'test'
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
@@ -142,7 +156,7 @@ class TestAgentModel(AgentModel):
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
         model_response = self._request(messages, model_settings)
-        yield TestStreamedResponse(model_response, messages)
+        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)
 
     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
@@ -151,7 +165,8 @@ class TestAgentModel(AgentModel):
         # if there are tools, the first thing we want to do is call all of them
         if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
+                model_name=self.model_name,
             )
 
         if messages:
@@ -164,7 +179,7 @@ class TestAgentModel(AgentModel):
                 # Handle retries for both function tools and result tools
                 # Check function tools first
                 retry_parts: list[ModelResponsePart] = [
-                    ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
+                    ToolCallPart(name, self.gen_tool_args(args))
                     for name, args in self.tool_calls
                     if name in new_retry_names
                 ]
@@ -172,15 +187,20 @@ class TestAgentModel(AgentModel):
                 if self.result_tools:
                     retry_parts.extend(
                         [
-                            ToolCallPart.from_raw_args(tool.name, self.gen_tool_args(tool))
+                            ToolCallPart(
+                                tool.name,
+                                self.result.value
+                                if isinstance(self.result, _FunctionToolResult) and self.result.value is not None
+                                else self.gen_tool_args(tool),
+                            )
                             for tool in self.result_tools
                             if tool.name in new_retry_names
                         ]
                     )
-                return ModelResponse(parts=retry_parts)
+                return ModelResponse(parts=retry_parts, model_name=self.model_name)
 
-        if response_text := self.result.left:
-            if response_text.value is None:
+        if isinstance(self.result, _TextResult):
+            if (response_text := self.result.value) is None:
                 # build up details of tool responses
                 output: dict[str, Any] = {}
                 for message in messages:
@@ -189,20 +209,24 @@ class TestAgentModel(AgentModel):
                             if isinstance(part, ToolReturnPart):
                                 output[part.tool_name] = part.content
                 if output:
-                    return ModelResponse.from_text(pydantic_core.to_json(output).decode())
+                    return ModelResponse(
+                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
+                    )
                 else:
-                    return ModelResponse.from_text('success (no tool calls)')
+                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
             else:
-                return ModelResponse.from_text(response_text.value)
+                return ModelResponse(parts=[TextPart(response_text)], model_name=self.model_name)
         else:
             assert self.result_tools, 'No result tools provided'
-            custom_result_args = self.result.right
+            custom_result_args = self.result.value
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)])
+                return ModelResponse(
+                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.model_name
+                )
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)])
+                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.model_name)
 
 
 @dataclass
@@ -233,9 +257,8 @@ class TestStreamedResponse(StreamedResponse):
                     self._usage += _get_string_usage(word)
                     yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
             else:
-                args = part.args.args_json if isinstance(part.args, ArgsJson) else part.args.args_dict
                 yield self._parts_manager.handle_tool_call_part(
-                    vendor_part_id=i, tool_name=part.tool_name, args=args, tool_call_id=part.tool_call_id
+                    vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
                 )
 
     def timestamp(self) -> datetime:
@@ -10,7 +10,7 @@ from httpx import AsyncClient as AsyncHTTPClient
 from .._utils import run_in_executor
 from ..exceptions import UserError
 from ..tools import ToolDefinition
-from . import Model, cached_async_http_client
+from . import Model, cached_async_http_client, check_allow_model_requests
 from .gemini import GeminiAgentModel, GeminiModelName
 
 try:
@@ -114,6 +114,7 @@ class VertexAIModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
+        check_allow_model_requests()
         url, auth = await self.ainit()
         return GeminiAgentModel(
             http_client=self.http_client,