pydantic-ai-slim 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.

Potentially problematic release.

pydantic_ai/models/openai.py:

@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import os
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
@@ -28,8 +29,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -45,10 +46,16 @@ except ImportError as _import_error:
         "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
     ) from _import_error
 
-OpenAIModelName = Union[ChatModel, str]
+OpenAIModelName = Union[str, ChatModel]
 """
+Possible OpenAI model names.
+
+Since OpenAI supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+See [the OpenAI docs](https://platform.openai.com/docs/models) for a full list.
+
 Using this more broad type for the model name instead of the ChatModel definition
-allows this model to be used more easily with other model types (ie, Ollama, Deepseek)
+allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
 """
 
 OpenAISystemPromptRole = Literal['system', 'developer', 'user']
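Note: with the broadened `Union[str, ChatModel]` alias, any string type-checks as a model name, which is what makes OpenAI-compatible providers usable through this class. A minimal sketch (the model name, URL, and key below are illustrative, not taken from this diff):

from pydantic_ai.models.openai import OpenAIModel

# Any OpenAI-compatible endpoint works; `OpenAIModelName` no longer
# restricts the name to OpenAI's own `ChatModel` literals.
model = OpenAIModel(
    'deepseek-chat',
    base_url='https://api.deepseek.com',
    api_key='your-api-key',
)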
@@ -57,7 +64,12 @@ OpenAISystemPromptRole = Literal['system', 'developer', 'user']
 class OpenAIModelSettings(ModelSettings):
     """Settings used for an OpenAI model request."""
 
-    # This class is a placeholder for any future openai-specific settings
+    openai_reasoning_effort: chat.ChatCompletionReasoningEffort
+    """
+    Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning).
+    Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can
+    result in faster responses and fewer tokens used on reasoning in a response.
+    """
 
 
 @dataclass(init=False)
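Note: `openai_reasoning_effort` is forwarded to the `reasoning_effort` parameter of `chat.completions.create` in a later hunk of this diff. A minimal sketch of passing it through a run, assuming the `Agent`/`run_sync` API from this release (the prompt and model choice are illustrative):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel, OpenAIModelSettings

agent = Agent(OpenAIModel('o3-mini'))

# `OpenAIModelSettings` is a TypedDict, so unset keys are simply absent and
# fall back to NOT_GIVEN when the request payload is built.
result = agent.run_sync(
    'Summarise the proof that there are infinitely many primes.',
    model_settings=OpenAIModelSettings(openai_reasoning_effort='low'),
)
print(result.data)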
@@ -69,10 +81,12 @@ class OpenAIModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
     system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
 
+    _model_name: OpenAIModelName = field(repr=False)
+    _system: str | None = field(repr=False)
+
     def __init__(
         self,
         model_name: OpenAIModelName,
@@ -82,6 +96,7 @@ class OpenAIModel(Model):
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
         system_prompt_role: OpenAISystemPromptRole | None = None,
+        system: str | None = 'openai',
     ):
         """Initialize an OpenAI model.
 
@@ -99,9 +114,15 @@
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
             system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                 In the future, this may be inferred from the model name.
+            system: The model provider used, defaults to `openai`. This is for observability purposes, you must
+                customize the `base_url` and `api_key` to use a different provider.
         """
-        self.model_name: OpenAIModelName = model_name
-        if openai_client is not None:
+        self._model_name = model_name
+        # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
+        # openai compatible models do not always need an API key.
+        if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
+            api_key = ''
+        elif openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
             assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
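Note: the new branch means a `base_url`-only construction no longer raises inside `AsyncOpenAI`, which otherwise requires some API key. A sketch against a locally served OpenAI-compatible endpoint (the URL and model name are illustrative):

from pydantic_ai.models.openai import OpenAIModel

# No `api_key` argument and no OPENAI_API_KEY in the environment: because
# `base_url` is set, the constructor substitutes an empty-string key.
model = OpenAIModel(
    'llama3.2',
    base_url='http://localhost:11434/v1',  # e.g. an Ollama server
)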
@@ -111,84 +132,70 @@
         else:
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
         self.system_prompt_role = system_prompt_role
-
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return OpenAIAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
-            self.system_prompt_role,
-        )
+        self._system = system
 
     def name(self) -> str:
-        return f'openai:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
-        return {
-            'type': 'function',
-            'function': {
-                'name': f.name,
-                'description': f.description,
-                'parameters': f.parameters_json_schema,
-            },
-        }
-
-
-@dataclass
-class OpenAIAgentModel(AgentModel):
-    """Implementation of `AgentModel` for OpenAI models."""
-
-    client: AsyncOpenAI
-    model_name: OpenAIModelName
-    allow_text_result: bool
-    tools: list[chat.ChatCompletionToolParam]
-    system_prompt_role: OpenAISystemPromptRole | None
+        return f'openai:{self._model_name}'
 
     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+        )
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+        tools = self._get_tools(model_request_parameters)
+
         # standalone function to make it easier to override
-        if not self.tools:
+        if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not self.allow_text_result:
+        elif not model_request_parameters.allow_text_result:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
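Note: this hunk replaces the two-step `agent_model()` factory with direct `request()`/`request_stream()` methods that receive a `ModelRequestParameters` object. Its definition is not shown in this diff, but from its usage here it evidently bundles the three arguments `agent_model()` used to take; a plausible shape (the real dataclass lives in the models package and may differ in detail):

from dataclasses import dataclass

from pydantic_ai.tools import ToolDefinition


@dataclass
class ModelRequestParameters:
    """Configuration for an agent's request to a model (inferred from usage in this diff)."""

    function_tools: list[ToolDefinition]
    allow_text_result: bool
    result_tools: list[ToolDefinition]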
@@ -196,11 +203,11 @@ class OpenAIAgentModel(AgentModel):
         openai_messages = list(chain(*(self._map_message(m) for m in messages)))
 
         return await self.client.chat.completions.create(
-            model=self.model_name,
+            model=self._model_name,
             messages=openai_messages,
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=self.tools or NOT_GIVEN,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             stream_options={'include_usage': True} if stream else NOT_GIVEN,
@@ -212,6 +219,7 @@ class OpenAIAgentModel(AgentModel):
             presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
             frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
             logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+            reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
         )
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
@@ -224,7 +232,7 @@ class OpenAIAgentModel(AgentModel):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
-        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
+        return ModelResponse(items, model_name=self._model_name, timestamp=timestamp)
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -234,11 +242,17 @@ class OpenAIAgentModel(AgentModel):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
         return OpenAIStreamedResponse(
-            _model_name=self.model_name,
+            _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
         )
 
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
     def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
@@ -250,7 +264,7 @@ class OpenAIAgentModel(AgentModel):
                 if isinstance(item, TextPart):
                     texts.append(item.content)
                 elif isinstance(item, ToolCallPart):
-                    tool_calls.append(_map_tool_call(item))
+                    tool_calls.append(self._map_tool_call(item))
                 else:
                     assert_never(item)
             message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
@@ -264,6 +278,25 @@ class OpenAIAgentModel(AgentModel):
         else:
             assert_never(message)
 
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
+        return chat.ChatCompletionMessageToolCallParam(
+            id=_guard_tool_call_id(t=t, model_source='OpenAI'),
+            type='function',
+            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+        return {
+            'type': 'function',
+            'function': {
+                'name': f.name,
+                'description': f.description,
+                'parameters': f.parameters_json_schema,
+            },
+        }
+
     def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
@@ -329,14 +362,6 @@ class OpenAIStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_call_id(t=t, model_source='OpenAI'),
-        type='function',
-        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
-    )
-
-
 def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
     response_usage = response.usage
     if response_usage is None:
pydantic_ai/models/test.py:

@@ -26,8 +26,8 @@ from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
 )
 from .function import _estimate_string_tokens, _estimate_usage  # pyright: ignore[reportPrivateUsage]
@@ -71,102 +71,92 @@ class TestModel(Model):
     """If set, these args will be passed to the result tool."""
     seed: int = 0
     """Seed for generating random data."""
-    agent_model_function_tools: list[ToolDefinition] | None = field(default=None, init=False)
-    """Definition of function tools passed to the model.
+    last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False)
+    """The last ModelRequestParameters passed to the model in a request.
 
-    This is set when the model is called, so will reflect the function tools from the last step of the last run.
-    """
-    agent_model_allow_text_result: bool | None = field(default=None, init=False)
-    """Whether plain text responses from the model are allowed.
+    The ModelRequestParameters contains information about the function and result tools available during request handling.
 
-    This is set when the model is called, so will reflect the value from the last step of the last run.
+    This is set when a request is made, so will reflect the function tools from the last step of the last run.
     """
-    agent_model_result_tools: list[ToolDefinition] | None = field(default=None, init=False)
-    """Definition of result tools passed to the model.
+    _model_name: str = field(default='test', repr=False)
+    _system: str | None = field(default=None, repr=False)
 
-    This is set when the model is called, so will reflect the result tools from the last step of the last run.
-    """
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        self.last_model_request_parameters = model_request_parameters
 
-    async def agent_model(
+        model_response = self._request(messages, model_settings, model_request_parameters)
+        usage = _estimate_usage([*messages, model_response])
+        return model_response, usage
+
+    @asynccontextmanager
+    async def request_stream(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        self.agent_model_function_tools = function_tools
-        self.agent_model_allow_text_result = allow_text_result
-        self.agent_model_result_tools = result_tools
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        self.last_model_request_parameters = model_request_parameters
+
+        model_response = self._request(messages, model_settings, model_request_parameters)
+        yield TestStreamedResponse(
+            _model_name=self._model_name, _structured_response=model_response, _messages=messages
+        )
 
+    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
+        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
+
+    def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> list[tuple[str, ToolDefinition]]:
         if self.call_tools == 'all':
-            tool_calls = [(r.name, r) for r in function_tools]
+            return [(r.name, r) for r in model_request_parameters.function_tools]
         else:
-            function_tools_lookup = {t.name: t for t in function_tools}
+            function_tools_lookup = {t.name: t for t in model_request_parameters.function_tools}
             tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
-            tool_calls = [(r.name, r) for r in tools_to_call]
+            return [(r.name, r) for r in tools_to_call]
 
+    def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
         if self.custom_result_text is not None:
-            assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
+            assert (
+                model_request_parameters.allow_text_result
+            ), 'Plain response not allowed, but `custom_result_text` is set.'
             assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
-            result: _TextResult | _FunctionToolResult = _TextResult(self.custom_result_text)
+            return _TextResult(self.custom_result_text)
         elif self.custom_result_args is not None:
-            assert result_tools is not None, 'No result tools provided, but `custom_result_args` is set.'
-            result_tool = result_tools[0]
+            assert (
+                model_request_parameters.result_tools is not None
+            ), 'No result tools provided, but `custom_result_args` is set.'
+            result_tool = model_request_parameters.result_tools[0]
 
             if k := result_tool.outer_typed_dict_key:
-                result = _FunctionToolResult({k: self.custom_result_args})
+                return _FunctionToolResult({k: self.custom_result_args})
             else:
-                result = _FunctionToolResult(self.custom_result_args)
-        elif allow_text_result:
-            result = _TextResult(None)
-        elif result_tools:
-            result = _FunctionToolResult(None)
+                return _FunctionToolResult(self.custom_result_args)
+        elif model_request_parameters.allow_text_result:
+            return _TextResult(None)
+        elif model_request_parameters.result_tools:
+            return _FunctionToolResult(None)
         else:
-            result = _TextResult(None)
-
-        return TestAgentModel(tool_calls, result, result_tools, self.seed)
-
-    def name(self) -> str:
-        return 'test-model'
-
+            return _TextResult(None)
 
-@dataclass
-class TestAgentModel(AgentModel):
-    """Implementation of `AgentModel` for testing purposes."""
-
-    # NOTE: Avoid test discovery by pytest.
-    __test__ = False
-
-    tool_calls: list[tuple[str, ToolDefinition]]
-    # left means the text is plain text; right means it's a function call
-    result: _TextResult | _FunctionToolResult
-    result_tools: list[ToolDefinition]
-    seed: int
-    model_name: str = 'test'
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Usage]:
-        model_response = self._request(messages, model_settings)
-        usage = _estimate_usage([*messages, model_response])
-        return model_response, usage
-
-    @asynccontextmanager
-    async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[StreamedResponse]:
-        model_response = self._request(messages, model_settings)
-        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)
-
-    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
-        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
+    def _request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        tool_calls = self._get_tool_calls(model_request_parameters)
+        result = self._get_result(model_request_parameters)
+        result_tools = model_request_parameters.result_tools
 
-    def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
         # if there are tools, the first thing we want to do is call all of them
-        if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
+        if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
-                model_name=self.model_name,
+                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls],
+                model_name=self._model_name,
             )
 
         if messages:
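Note: the three removed `agent_model_*` introspection attributes are folded into the single `last_model_request_parameters`. A sketch of the updated usage in a test (the tool and prompt are illustrative):

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent()

@agent.tool_plain
def greet(name: str) -> str:
    return f'hello {name}'

test_model = TestModel()
agent.run_sync('call the tool', model=test_model)

# Previously: test_model.agent_model_function_tools and friends; now one
# object records what the last request saw.
params = test_model.last_model_request_parameters
assert params is not None
print([t.name for t in params.function_tools])  # ['greet']
print(params.allow_text_result)  # True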
@@ -179,28 +169,26 @@ class TestAgentModel(AgentModel):
             # Handle retries for both function tools and result tools
             # Check function tools first
             retry_parts: list[ModelResponsePart] = [
-                ToolCallPart(name, self.gen_tool_args(args))
-                for name, args in self.tool_calls
-                if name in new_retry_names
+                ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names
             ]
             # Check result tools
-            if self.result_tools:
+            if result_tools:
                 retry_parts.extend(
                     [
                         ToolCallPart(
                             tool.name,
-                            self.result.value
-                            if isinstance(self.result, _FunctionToolResult) and self.result.value is not None
+                            result.value
+                            if isinstance(result, _FunctionToolResult) and result.value is not None
                             else self.gen_tool_args(tool),
                         )
-                        for tool in self.result_tools
+                        for tool in result_tools
                         if tool.name in new_retry_names
                     ]
                 )
-            return ModelResponse(parts=retry_parts, model_name=self.model_name)
+            return ModelResponse(parts=retry_parts, model_name=self._model_name)
 
-        if isinstance(self.result, _TextResult):
-            if (response_text := self.result.value) is None:
+        if isinstance(result, _TextResult):
+            if (response_text := result.value) is None:
                 # build up details of tool responses
                 output: dict[str, Any] = {}
                 for message in messages:
@@ -210,23 +198,23 @@ class TestAgentModel(AgentModel):
                                 output[part.tool_name] = part.content
                 if output:
                     return ModelResponse(
-                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
+                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self._model_name
                     )
                 else:
-                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
+                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self._model_name)
             else:
-                return ModelResponse(parts=[TextPart(response_text)], model_name=self.model_name)
+                return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
         else:
-            assert self.result_tools, 'No result tools provided'
-            custom_result_args = self.result.value
-            result_tool = self.result_tools[self.seed % len(self.result_tools)]
+            assert result_tools, 'No result tools provided'
+            custom_result_args = result.value
+            result_tool = result_tools[self.seed % len(result_tools)]
             if custom_result_args is not None:
                 return ModelResponse(
-                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.model_name
+                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self._model_name
                 )
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.model_name)
+                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self._model_name)
 
 
 @dataclass