pydantic-ai-slim 0.0.22__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/_agent_graph.py +12 -8
- pydantic_ai/agent.py +2 -2
- pydantic_ai/models/__init__.py +39 -37
- pydantic_ai/models/anthropic.py +69 -66
- pydantic_ai/models/cohere.py +56 -68
- pydantic_ai/models/function.py +58 -60
- pydantic_ai/models/gemini.py +132 -99
- pydantic_ai/models/groq.py +79 -72
- pydantic_ai/models/mistral.py +72 -71
- pydantic_ai/models/openai.py +90 -70
- pydantic_ai/models/test.py +81 -93
- pydantic_ai/models/vertexai.py +38 -44
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.23.dist-info}/METADATA +3 -3
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.23.dist-info}/RECORD +15 -15
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.23.dist-info}/WHEEL +0 -0
pydantic_ai/models/openai.py
CHANGED
@@ -29,8 +29,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -46,10 +46,16 @@ except ImportError as _import_error:
         "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
     ) from _import_error

-OpenAIModelName = Union[
+OpenAIModelName = Union[str, ChatModel]
 """
+Possible OpenAI model names.
+
+Since OpenAI supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+See [the OpenAI docs](https://platform.openai.com/docs/models) for a full list.
+
 Using this more broad type for the model name instead of the ChatModel definition
-allows this model to be used more easily with other model types (ie, Ollama, Deepseek)
+allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
 """

 OpenAISystemPromptRole = Literal['system', 'developer', 'user']
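Note: a rough illustration of the widened name type (the model names below are examples, not taken from the diff):

    from pydantic_ai.models.openai import OpenAIModelName

    official: OpenAIModelName = 'gpt-4o'        # an OpenAI chat model name
    custom: OpenAIModelName = 'deepseek-chat'   # any other string now type-checks too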
@@ -58,7 +64,12 @@ OpenAISystemPromptRole = Literal['system', 'developer', 'user']
 class OpenAIModelSettings(ModelSettings):
     """Settings used for an OpenAI model request."""

-
+    openai_reasoning_effort: chat.ChatCompletionReasoningEffort
+    """
+    Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning).
+    Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can
+    result in faster responses and fewer tokens used on reasoning in a response.
+    """


 @dataclass(init=False)
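Note: a minimal sketch of how the new setting might be supplied at run time (the prompt and model name are illustrative; as a later hunk in this diff shows, `openai_reasoning_effort` is forwarded to the API as `reasoning_effort`):

    from pydantic_ai import Agent

    agent = Agent('openai:o1')  # a reasoning-capable model name, for illustration
    result = agent.run_sync(
        'Summarise the trade-offs of the two designs.',
        model_settings={'openai_reasoning_effort': 'low'},
    )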
@@ -70,10 +81,12 @@ class OpenAIModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
     system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

+    _model_name: OpenAIModelName = field(repr=False)
+    _system: str | None = field(repr=False)
+
     def __init__(
         self,
         model_name: OpenAIModelName,
@@ -83,6 +96,7 @@ class OpenAIModel(Model):
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
         system_prompt_role: OpenAISystemPromptRole | None = None,
+        system: str | None = 'openai',
     ):
         """Initialize an OpenAI model.

@@ -100,8 +114,10 @@ class OpenAIModel(Model):
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
             system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                 In the future, this may be inferred from the model name.
+            system: The model provider used, defaults to `openai`. This is for observability purposes, you must
+                customize the `base_url` and `api_key` to use a different provider.
         """
-        self.model_name = model_name
+        self._model_name = model_name
         # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
         # openai compatible models do not always need an API key.
         if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
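Note: an assumed usage sketch for the new `system` argument when pointing the client at an OpenAI-compatible provider; the endpoint URL, key, and provider label below are illustrative only:

    from pydantic_ai.models.openai import OpenAIModel

    model = OpenAIModel(
        'deepseek-chat',
        base_url='https://api.deepseek.com/v1',   # example endpoint, not from the diff
        api_key='your-api-key',
        system='deepseek',  # recorded for observability; does not change request behaviour
    )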
@@ -116,84 +132,70 @@ class OpenAIModel(Model):
         else:
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
         self.system_prompt_role = system_prompt_role
-
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return OpenAIAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
-            self.system_prompt_role,
-        )
+        self._system = system

     def name(self) -> str:
-        return f'openai:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
-        return {
-            'type': 'function',
-            'function': {
-                'name': f.name,
-                'description': f.description,
-                'parameters': f.parameters_json_schema,
-            },
-        }
-
-
-@dataclass
-class OpenAIAgentModel(AgentModel):
-    """Implementation of `AgentModel` for OpenAI models."""
-
-    client: AsyncOpenAI
-    model_name: OpenAIModelName
-    allow_text_result: bool
-    tools: list[chat.ChatCompletionToolParam]
-    system_prompt_role: OpenAISystemPromptRole | None
+        return f'openai:{self._model_name}'

     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+        )
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+        tools = self._get_tools(model_request_parameters)
+
         # standalone function to make it easier to override
-        if not self.tools:
+        if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not self.allow_text_result:
+        elif not model_request_parameters.allow_text_result:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
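Note: with `agent_model` removed, per-request tool configuration now travels in `ModelRequestParameters`. A minimal sketch of calling the model directly (assumed usage; normally the agent graph builds these parameters for you):

    import asyncio

    from pydantic_ai.messages import ModelRequest, UserPromptPart
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.models.openai import OpenAIModel

    async def main() -> None:
        model = OpenAIModel('gpt-4o')  # illustrative model name; requires OPENAI_API_KEY in the environment
        params = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
        messages = [ModelRequest(parts=[UserPromptPart(content='Hello!')])]
        response, usage = await model.request(messages, None, params)
        print(response, usage)

    asyncio.run(main())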
@@ -201,11 +203,11 @@ class OpenAIAgentModel(AgentModel):
         openai_messages = list(chain(*(self._map_message(m) for m in messages)))

         return await self.client.chat.completions.create(
-            model=self.model_name,
+            model=self._model_name,
             messages=openai_messages,
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=self.tools or NOT_GIVEN,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             stream_options={'include_usage': True} if stream else NOT_GIVEN,
@@ -217,6 +219,7 @@ class OpenAIAgentModel(AgentModel):
             presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
             frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
             logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+            reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
         )

     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
@@ -229,7 +232,7 @@ class OpenAIAgentModel(AgentModel):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
-        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
+        return ModelResponse(items, model_name=self._model_name, timestamp=timestamp)

     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -239,11 +242,17 @@ class OpenAIAgentModel(AgentModel):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

         return OpenAIStreamedResponse(
-            _model_name=self.model_name,
+            _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
         )

+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
     def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
@@ -255,7 +264,7 @@ class OpenAIAgentModel(AgentModel):
                 if isinstance(item, TextPart):
                     texts.append(item.content)
                 elif isinstance(item, ToolCallPart):
-                    tool_calls.append(_map_tool_call(item))
+                    tool_calls.append(self._map_tool_call(item))
                 else:
                     assert_never(item)
             message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
@@ -269,6 +278,25 @@ class OpenAIAgentModel(AgentModel):
         else:
             assert_never(message)

+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
+        return chat.ChatCompletionMessageToolCallParam(
+            id=_guard_tool_call_id(t=t, model_source='OpenAI'),
+            type='function',
+            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+        return {
+            'type': 'function',
+            'function': {
+                'name': f.name,
+                'description': f.description,
+                'parameters': f.parameters_json_schema,
+            },
+        }
+
     def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
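Note: for reference, the dict shape `_map_tool_definition` produces for the chat completions API, shown with a made-up tool; the `ToolDefinition` fields used here are the ones referenced in the diff above:

    from pydantic_ai.tools import ToolDefinition

    weather_tool = ToolDefinition(
        name='get_weather',
        description='Get the current weather for a city.',
        parameters_json_schema={
            'type': 'object',
            'properties': {'city': {'type': 'string'}},
            'required': ['city'],
        },
    )
    openai_tool_param = {
        'type': 'function',
        'function': {
            'name': weather_tool.name,
            'description': weather_tool.description,
            'parameters': weather_tool.parameters_json_schema,
        },
    }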
@@ -334,14 +362,6 @@ class OpenAIStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_call_id(t=t, model_source='OpenAI'),
-        type='function',
-        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
-    )
-
-
 def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
     response_usage = response.usage
     if response_usage is None:
pydantic_ai/models/test.py
CHANGED
@@ -26,8 +26,8 @@ from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
 )
 from .function import _estimate_string_tokens, _estimate_usage  # pyright: ignore[reportPrivateUsage]
@@ -71,102 +71,92 @@ class TestModel(Model):
     """If set, these args will be passed to the result tool."""
     seed: int = 0
     """Seed for generating random data."""
-
-    """
+    last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False)
+    """The last ModelRequestParameters passed to the model in a request.

-
-    """
-    agent_model_allow_text_result: bool | None = field(default=None, init=False)
-    """Whether plain text responses from the model are allowed.
+    The ModelRequestParameters contains information about the function and result tools available during request handling.

-    This is set when
+    This is set when a request is made, so will reflect the function tools from the last step of the last run.
     """
-
-
+    _model_name: str = field(default='test', repr=False)
+    _system: str | None = field(default=None, repr=False)

-
-
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        self.last_model_request_parameters = model_request_parameters

-
+        model_response = self._request(messages, model_settings, model_request_parameters)
+        usage = _estimate_usage([*messages, model_response])
+        return model_response, usage
+
+    @asynccontextmanager
+    async def request_stream(
         self,
-
-
-
-
-
-
-        self.
-
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        self.last_model_request_parameters = model_request_parameters
+
+        model_response = self._request(messages, model_settings, model_request_parameters)
+        yield TestStreamedResponse(
+            _model_name=self._model_name, _structured_response=model_response, _messages=messages
+        )

+    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
+        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
+
+    def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> list[tuple[str, ToolDefinition]]:
         if self.call_tools == 'all':
-
+            return [(r.name, r) for r in model_request_parameters.function_tools]
         else:
-            function_tools_lookup = {t.name: t for t in function_tools}
+            function_tools_lookup = {t.name: t for t in model_request_parameters.function_tools}
             tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
-
+            return [(r.name, r) for r in tools_to_call]

+    def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
         if self.custom_result_text is not None:
-            assert
+            assert (
+                model_request_parameters.allow_text_result
+            ), 'Plain response not allowed, but `custom_result_text` is set.'
             assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
-
+            return _TextResult(self.custom_result_text)
         elif self.custom_result_args is not None:
-            assert
-
+            assert (
+                model_request_parameters.result_tools is not None
+            ), 'No result tools provided, but `custom_result_args` is set.'
+            result_tool = model_request_parameters.result_tools[0]

             if k := result_tool.outer_typed_dict_key:
-
+                return _FunctionToolResult({k: self.custom_result_args})
             else:
-
-
-        elif allow_text_result:
-
-        elif result_tools:
+                return _FunctionToolResult(self.custom_result_args)
+        elif model_request_parameters.allow_text_result:
+            return _TextResult(None)
+        elif model_request_parameters.result_tools:
+            return _FunctionToolResult(None)
         else:
-
-
-            return TestAgentModel(tool_calls, result, result_tools, self.seed)
-
-    def name(self) -> str:
-        return 'test-model'
-
+            return _TextResult(None)

-
-
-
-
-
-
-
-
-
-    result: _TextResult | _FunctionToolResult
-    result_tools: list[ToolDefinition]
-    seed: int
-    model_name: str = 'test'
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Usage]:
-        model_response = self._request(messages, model_settings)
-        usage = _estimate_usage([*messages, model_response])
-        return model_response, usage
-
-    @asynccontextmanager
-    async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[StreamedResponse]:
-        model_response = self._request(messages, model_settings)
-        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)
-
-    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
-        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
+    def _request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        tool_calls = self._get_tool_calls(model_request_parameters)
+        result = self._get_result(model_request_parameters)
+        result_tools = model_request_parameters.result_tools

-    def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
         # if there are tools, the first thing we want to do is call all of them
-        if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
+        if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
-                model_name=self.model_name,
+                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls],
+                model_name=self._model_name,
             )

         if messages:
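Note: a short sketch of how a test might use the new attribute (assumed usage; the agent and tool are illustrative):

    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    agent = Agent()

    @agent.tool_plain
    def get_time() -> str:
        """Return a fixed time for testing."""
        return '12:00'

    test_model = TestModel()
    result = agent.run_sync('What time is it?', model=test_model)
    assert test_model.last_model_request_parameters is not None
    print([t.name for t in test_model.last_model_request_parameters.function_tools])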
@@ -179,28 +169,26 @@ class TestAgentModel(AgentModel):
         # Handle retries for both function tools and result tools
         # Check function tools first
         retry_parts: list[ModelResponsePart] = [
-            ToolCallPart(name, self.gen_tool_args(args))
-            for name, args in self.tool_calls
-            if name in new_retry_names
+            ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names
         ]
         # Check result tools
-        if self.result_tools:
+        if result_tools:
             retry_parts.extend(
                 [
                     ToolCallPart(
                         tool.name,
-                        self.result.value
-                        if isinstance(self.result, _FunctionToolResult) and self.result.value is not None
+                        result.value
+                        if isinstance(result, _FunctionToolResult) and result.value is not None
                         else self.gen_tool_args(tool),
                     )
-                    for tool in self.result_tools
+                    for tool in result_tools
                     if tool.name in new_retry_names
                 ]
             )
-        return ModelResponse(parts=retry_parts, model_name=self.model_name)
+        return ModelResponse(parts=retry_parts, model_name=self._model_name)

-        if isinstance(self.result, _TextResult):
-            if (response_text := self.result.value) is None:
+        if isinstance(result, _TextResult):
+            if (response_text := result.value) is None:
                 # build up details of tool responses
                 output: dict[str, Any] = {}
                 for message in messages:
@@ -210,23 +198,23 @@ class TestAgentModel(AgentModel):
                                 output[part.tool_name] = part.content
                 if output:
                     return ModelResponse(
-                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
+                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self._model_name
                     )
                 else:
-                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
+                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self._model_name)
             else:
-                return ModelResponse(parts=[TextPart(response_text)], model_name=self.model_name)
+                return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
         else:
-            assert self.result_tools, 'No result tools provided'
-            custom_result_args = self.result.value
-            result_tool = self.result_tools[self.seed % len(self.result_tools)]
+            assert result_tools, 'No result tools provided'
+            custom_result_args = result.value
+            result_tool = result_tools[self.seed % len(result_tools)]
             if custom_result_args is not None:
                 return ModelResponse(
-                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.model_name
+                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self._model_name
                 )
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.model_name)
+                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self._model_name)


 @dataclass