pydantic-ai-slim 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl


@@ -29,8 +29,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -46,10 +46,16 @@ except ImportError as _import_error:
         "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
     ) from _import_error

-OpenAIModelName = Union[ChatModel, str]
+OpenAIModelName = Union[str, ChatModel]
 """
+Possible OpenAI model names.
+
+Since OpenAI supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+See [the OpenAI docs](https://platform.openai.com/docs/models) for a full list.
+
 Using this more broad type for the model name instead of the ChatModel definition
-allows this model to be used more easily with other model types (ie, Ollama, Deepseek)
+allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
 """

 OpenAISystemPromptRole = Literal['system', 'developer', 'user']
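
A quick illustration (not part of the diff): with the alias reordered to `Union[str, ChatModel]`, both the listed `ChatModel` literals and arbitrary strings such as date-stamped or OpenAI-compatible third-party names satisfy the type.

    from pydantic_ai.models.openai import OpenAIModelName

    # All of these satisfy Union[str, ChatModel]; the broad alias exists so that
    # OpenAI-compatible providers are not rejected by type checkers.
    latest: OpenAIModelName = 'gpt-4o'              # listed ChatModel literal
    dated: OpenAIModelName = 'gpt-4o-2024-08-06'    # date-stamped variant
    other: OpenAIModelName = 'deepseek-chat'        # OpenAI-compatible provider name
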
@@ -58,7 +64,12 @@ OpenAISystemPromptRole = Literal['system', 'developer', 'user']
 class OpenAIModelSettings(ModelSettings):
     """Settings used for an OpenAI model request."""

-    # This class is a placeholder for any future openai-specific settings
+    openai_reasoning_effort: chat.ChatCompletionReasoningEffort
+    """
+    Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning).
+    Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can
+    result in faster responses and fewer tokens used on reasoning in a response.
+    """


 @dataclass(init=False)
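
A minimal usage sketch for the new setting (not taken from the diff; it assumes an agent run where `model_settings` is forwarded to the model, and uses a reasoning-capable model name as a placeholder):

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModel, OpenAIModelSettings

    # Hypothetical example: ask a reasoning model to spend less effort thinking.
    model = OpenAIModel('o3-mini')
    agent = Agent(model)
    result = agent.run_sync(
        'Summarise this change in one sentence.',
        model_settings=OpenAIModelSettings(openai_reasoning_effort='low'),
    )
    print(result.data)
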
@@ -70,10 +81,12 @@ class OpenAIModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
     system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

+    _model_name: OpenAIModelName = field(repr=False)
+    _system: str | None = field(repr=False)
+
     def __init__(
         self,
         model_name: OpenAIModelName,
@@ -83,6 +96,7 @@ class OpenAIModel(Model):
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
         system_prompt_role: OpenAISystemPromptRole | None = None,
+        system: str | None = 'openai',
     ):
         """Initialize an OpenAI model.

@@ -100,13 +114,16 @@ class OpenAIModel(Model):
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
             system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                 In the future, this may be inferred from the model name.
+            system: The model provider used, defaults to `openai`. This is for observability purposes, you must
+                customize the `base_url` and `api_key` to use a different provider.
         """
-        self.model_name: OpenAIModelName = model_name
+        self._model_name = model_name
         # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
         # openai compatible models do not always need an API key.
         if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
             api_key = ''
-        elif openai_client is not None:
+
+        if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
             assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
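
Sketch of how the new `system` argument combines with `base_url`/`api_key` when pointing the client at an OpenAI-compatible provider; the endpoint, key, and model name below are placeholders, not values from the diff:

    from pydantic_ai.models.openai import OpenAIModel

    # Hypothetical OpenAI-compatible endpoint, labelled for observability via `system`.
    model = OpenAIModel(
        'deepseek-chat',
        base_url='https://api.deepseek.com/v1',  # placeholder provider URL
        api_key='my-api-key',                    # placeholder credential
        system='deepseek',                       # recorded as the provider / system name
    )
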
@@ -116,84 +133,80 @@ class OpenAIModel(Model):
         else:
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
         self.system_prompt_role = system_prompt_role
-
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return OpenAIAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
-            self.system_prompt_role,
-        )
+        self._system = system

     def name(self) -> str:
-        return f'openai:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
-        return {
-            'type': 'function',
-            'function': {
-                'name': f.name,
-                'description': f.description,
-                'parameters': f.parameters_json_schema,
-            },
-        }
-
-
-@dataclass
-class OpenAIAgentModel(AgentModel):
-    """Implementation of `AgentModel` for OpenAI models."""
-
-    client: AsyncOpenAI
-    model_name: OpenAIModelName
-    allow_text_result: bool
-    tools: list[chat.ChatCompletionToolParam]
-    system_prompt_role: OpenAISystemPromptRole | None
+        return f'openai:{self._model_name}'

     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+        )
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)

+    @property
+    def model_name(self) -> OpenAIModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+        tools = self._get_tools(model_request_parameters)
+
         # standalone function to make it easier to override
-        if not self.tools:
+        if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not self.allow_text_result:
+        elif not model_request_parameters.allow_text_result:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
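
The net effect of this hunk (in pydantic_ai/models/openai.py) is that the per-run `OpenAIAgentModel` wrapper disappears: callers now pass a `ModelRequestParameters` object carrying the tool definitions and `allow_text_result` straight into `request`/`request_stream`. A rough sketch of the new calling convention, assuming `ModelRequestParameters` is constructed from the three fields referenced in the diff:

    from pydantic_ai.messages import ModelRequest, UserPromptPart
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.models.openai import OpenAIModel

    async def one_shot() -> None:
        model = OpenAIModel('gpt-4o')
        params = ModelRequestParameters(
            function_tools=[],       # ToolDefinition list; empty for this sketch
            allow_text_result=True,  # plain-text responses are acceptable
            result_tools=[],         # no structured result tools
        )
        messages = [ModelRequest(parts=[UserPromptPart(content='Hello!')])]
        response, usage = await model.request(messages, None, params)
        print(response, usage)
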
@@ -201,11 +214,11 @@ class OpenAIAgentModel(AgentModel):
         openai_messages = list(chain(*(self._map_message(m) for m in messages)))

         return await self.client.chat.completions.create(
-            model=self.model_name,
+            model=self._model_name,
             messages=openai_messages,
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=self.tools or NOT_GIVEN,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             stream_options={'include_usage': True} if stream else NOT_GIVEN,
@@ -217,6 +230,7 @@ class OpenAIAgentModel(AgentModel):
             presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
             frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
             logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+            reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
         )

     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
  def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
@@ -229,7 +243,7 @@ class OpenAIAgentModel(AgentModel):
229
243
  if choice.message.tool_calls is not None:
230
244
  for c in choice.message.tool_calls:
231
245
  items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
232
- return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
246
+ return ModelResponse(items, model_name=response.model, timestamp=timestamp)
233
247
 
234
248
  async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
235
249
  """Process a streamed response, and prepare a streaming response to return."""
@@ -239,11 +253,17 @@ class OpenAIAgentModel(AgentModel):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

         return OpenAIStreamedResponse(
-            _model_name=self.model_name,
+            _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
         )

+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
     def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
@@ -255,7 +275,7 @@ class OpenAIAgentModel(AgentModel):
                 if isinstance(item, TextPart):
                     texts.append(item.content)
                 elif isinstance(item, ToolCallPart):
-                    tool_calls.append(_map_tool_call(item))
+                    tool_calls.append(self._map_tool_call(item))
                 else:
                     assert_never(item)
             message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
@@ -269,6 +289,25 @@ class OpenAIAgentModel(AgentModel):
         else:
             assert_never(message)

+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
+        return chat.ChatCompletionMessageToolCallParam(
+            id=_guard_tool_call_id(t=t, model_source='OpenAI'),
+            type='function',
+            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+        return {
+            'type': 'function',
+            'function': {
+                'name': f.name,
+                'description': f.description,
+                'parameters': f.parameters_json_schema,
+            },
+        }
+
     def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
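
For reference, the relocated `_map_tool_definition` emits the standard OpenAI function-tool payload. An illustrative result for a hypothetical `get_weather` tool (the name and schema are invented for the example):

    {
        'type': 'function',
        'function': {
            'name': 'get_weather',
            'description': 'Get the current weather for a city.',
            'parameters': {
                'type': 'object',
                'properties': {'city': {'type': 'string'}},
                'required': ['city'],
            },
        },
    }
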
@@ -303,6 +342,7 @@ class OpenAIAgentModel(AgentModel):
 class OpenAIStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for OpenAI models."""

+    _model_name: OpenAIModelName
     _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime

@@ -330,18 +370,17 @@ class OpenAIStreamedResponse(StreamedResponse):
                 if maybe_event is not None:
                     yield maybe_event

+    @property
+    def model_name(self) -> OpenAIModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp


-def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_call_id(t=t, model_source='OpenAI'),
-        type='function',
-        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
-    )
-
-
 def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
     response_usage = response.usage
     if response_usage is None:
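
(The hunks above are against the OpenAI model implementation, pydantic_ai/models/openai.py; the hunks below apply the same AgentModel → ModelRequestParameters refactor to the test model, pydantic_ai/models/test.py.)
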
@@ -26,8 +26,8 @@ from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
 )
 from .function import _estimate_string_tokens, _estimate_usage  # pyright: ignore[reportPrivateUsage]
@@ -71,102 +71,102 @@ class TestModel(Model):
     """If set, these args will be passed to the result tool."""
     seed: int = 0
     """Seed for generating random data."""
-    agent_model_function_tools: list[ToolDefinition] | None = field(default=None, init=False)
-    """Definition of function tools passed to the model.
+    last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False)
+    """The last ModelRequestParameters passed to the model in a request.

-    This is set when the model is called, so will reflect the function tools from the last step of the last run.
-    """
-    agent_model_allow_text_result: bool | None = field(default=None, init=False)
-    """Whether plain text responses from the model are allowed.
+    The ModelRequestParameters contains information about the function and result tools available during request handling.

-    This is set when the model is called, so will reflect the value from the last step of the last run.
+    This is set when a request is made, so will reflect the function tools from the last step of the last run.
     """
-    agent_model_result_tools: list[ToolDefinition] | None = field(default=None, init=False)
-    """Definition of result tools passed to the model.
+    _model_name: str = field(default='test', repr=False)
+    _system: str | None = field(default=None, repr=False)

-    This is set when the model is called, so will reflect the result tools from the last step of the last run.
-    """
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        self.last_model_request_parameters = model_request_parameters

-    async def agent_model(
+        model_response = self._request(messages, model_settings, model_request_parameters)
+        usage = _estimate_usage([*messages, model_response])
+        return model_response, usage
+
+    @asynccontextmanager
+    async def request_stream(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        self.agent_model_function_tools = function_tools
-        self.agent_model_allow_text_result = allow_text_result
-        self.agent_model_result_tools = result_tools
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        self.last_model_request_parameters = model_request_parameters
+
+        model_response = self._request(messages, model_settings, model_request_parameters)
+        yield TestStreamedResponse(
+            _model_name=self._model_name, _structured_response=model_response, _messages=messages
+        )
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
+    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
+        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()

+    def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> list[tuple[str, ToolDefinition]]:
         if self.call_tools == 'all':
-            tool_calls = [(r.name, r) for r in function_tools]
+            return [(r.name, r) for r in model_request_parameters.function_tools]
         else:
-            function_tools_lookup = {t.name: t for t in function_tools}
+            function_tools_lookup = {t.name: t for t in model_request_parameters.function_tools}
             tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
-            tool_calls = [(r.name, r) for r in tools_to_call]
+            return [(r.name, r) for r in tools_to_call]

+    def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
         if self.custom_result_text is not None:
-            assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
+            assert (
+                model_request_parameters.allow_text_result
+            ), 'Plain response not allowed, but `custom_result_text` is set.'
             assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
-            result: _TextResult | _FunctionToolResult = _TextResult(self.custom_result_text)
+            return _TextResult(self.custom_result_text)
         elif self.custom_result_args is not None:
-            assert result_tools is not None, 'No result tools provided, but `custom_result_args` is set.'
-            result_tool = result_tools[0]
+            assert (
+                model_request_parameters.result_tools is not None
+            ), 'No result tools provided, but `custom_result_args` is set.'
+            result_tool = model_request_parameters.result_tools[0]

             if k := result_tool.outer_typed_dict_key:
-                result = _FunctionToolResult({k: self.custom_result_args})
+                return _FunctionToolResult({k: self.custom_result_args})
             else:
-                result = _FunctionToolResult(self.custom_result_args)
-        elif allow_text_result:
-            result = _TextResult(None)
-        elif result_tools:
-            result = _FunctionToolResult(None)
+                return _FunctionToolResult(self.custom_result_args)
+        elif model_request_parameters.allow_text_result:
+            return _TextResult(None)
+        elif model_request_parameters.result_tools:
+            return _FunctionToolResult(None)
         else:
-            result = _TextResult(None)
-
-        return TestAgentModel(tool_calls, result, result_tools, self.seed)
-
-    def name(self) -> str:
-        return 'test-model'
-
+            return _TextResult(None)

-
-@dataclass
-class TestAgentModel(AgentModel):
-    """Implementation of `AgentModel` for testing purposes."""
-
-    # NOTE: Avoid test discovery by pytest.
-    __test__ = False
-
-    tool_calls: list[tuple[str, ToolDefinition]]
-    # left means the text is plain text; right means it's a function call
-    result: _TextResult | _FunctionToolResult
-    result_tools: list[ToolDefinition]
-    seed: int
-    model_name: str = 'test'
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Usage]:
-        model_response = self._request(messages, model_settings)
-        usage = _estimate_usage([*messages, model_response])
-        return model_response, usage
-
-    @asynccontextmanager
-    async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[StreamedResponse]:
-        model_response = self._request(messages, model_settings)
-        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)
-
-    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
-        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
+    def _request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        tool_calls = self._get_tool_calls(model_request_parameters)
+        result = self._get_result(model_request_parameters)
+        result_tools = model_request_parameters.result_tools

-    def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
         # if there are tools, the first thing we want to do is call all of them
-        if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
+        if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
-                model_name=self.model_name,
+                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls],
+                model_name=self._model_name,
             )

         if messages:
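
The three `agent_model_*` attributes are collapsed into a single `last_model_request_parameters` snapshot. A small test sketch of how the renamed attribute would be inspected (the agent and tool are invented for the example):

    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    test_model = TestModel()
    agent = Agent(test_model)

    @agent.tool_plain
    def add(a: int, b: int) -> int:
        return a + b

    agent.run_sync('add two numbers')

    # Previously: agent_model_function_tools / agent_model_allow_text_result / agent_model_result_tools
    params = test_model.last_model_request_parameters
    assert params is not None
    print([t.name for t in params.function_tools])  # expected: ['add']
    print(params.allow_text_result)
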
@@ -179,28 +179,26 @@ class TestAgentModel(AgentModel):
                 # Handle retries for both function tools and result tools
                 # Check function tools first
                 retry_parts: list[ModelResponsePart] = [
-                    ToolCallPart(name, self.gen_tool_args(args))
-                    for name, args in self.tool_calls
-                    if name in new_retry_names
+                    ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names
                 ]
                 # Check result tools
-                if self.result_tools:
+                if result_tools:
                     retry_parts.extend(
                         [
                             ToolCallPart(
                                 tool.name,
-                                self.result.value
-                                if isinstance(self.result, _FunctionToolResult) and self.result.value is not None
+                                result.value
+                                if isinstance(result, _FunctionToolResult) and result.value is not None
                                 else self.gen_tool_args(tool),
                             )
-                            for tool in self.result_tools
+                            for tool in result_tools
                             if tool.name in new_retry_names
                         ]
                     )
-                return ModelResponse(parts=retry_parts, model_name=self.model_name)
+                return ModelResponse(parts=retry_parts, model_name=self._model_name)

-        if isinstance(self.result, _TextResult):
-            if (response_text := self.result.value) is None:
+        if isinstance(result, _TextResult):
+            if (response_text := result.value) is None:
                 # build up details of tool responses
                 output: dict[str, Any] = {}
                 for message in messages:
@@ -210,32 +208,32 @@ class TestAgentModel(AgentModel):
                                 output[part.tool_name] = part.content
                 if output:
                     return ModelResponse(
-                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
+                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self._model_name
                     )
                 else:
-                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
+                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self._model_name)
             else:
-                return ModelResponse(parts=[TextPart(response_text)], model_name=self.model_name)
+                return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
         else:
-            assert self.result_tools, 'No result tools provided'
-            custom_result_args = self.result.value
-            result_tool = self.result_tools[self.seed % len(self.result_tools)]
+            assert result_tools, 'No result tools provided'
+            custom_result_args = result.value
+            result_tool = result_tools[self.seed % len(result_tools)]
             if custom_result_args is not None:
                 return ModelResponse(
-                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.model_name
+                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self._model_name
                 )
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.model_name)
+                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self._model_name)


 @dataclass
 class TestStreamedResponse(StreamedResponse):
     """A structured response that streams test data."""

+    _model_name: str
     _structured_response: ModelResponse
     _messages: InitVar[Iterable[ModelMessage]]
-
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

     def __post_init__(self, _messages: Iterable[ModelMessage]):
@@ -261,7 +259,14 @@ class TestStreamedResponse(StreamedResponse):
                     vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
                 )

+    @property
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
