pydantic-ai-slim 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +774 -0
- pydantic_ai/agent.py +183 -555
- pydantic_ai/models/__init__.py +43 -37
- pydantic_ai/models/anthropic.py +69 -66
- pydantic_ai/models/cohere.py +56 -68
- pydantic_ai/models/function.py +58 -60
- pydantic_ai/models/gemini.py +139 -100
- pydantic_ai/models/groq.py +79 -72
- pydantic_ai/models/mistral.py +72 -71
- pydantic_ai/models/openai.py +96 -71
- pydantic_ai/models/test.py +81 -93
- pydantic_ai/models/vertexai.py +38 -44
- pydantic_ai/settings.py +5 -0
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.23.dist-info}/METADATA +3 -4
- pydantic_ai_slim-0.0.23.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.21.dist-info/RECORD +0 -29
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.23.dist-info}/WHEEL +0 -0
pydantic_ai/models/openai.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
+
import os
|
|
3
4
|
from collections.abc import AsyncIterable, AsyncIterator, Iterable
|
|
4
5
|
from contextlib import asynccontextmanager
|
|
5
6
|
from dataclasses import dataclass, field
|
|
@@ -28,8 +29,8 @@ from ..messages import (
|
|
|
28
29
|
from ..settings import ModelSettings
|
|
29
30
|
from ..tools import ToolDefinition
|
|
30
31
|
from . import (
|
|
31
|
-
AgentModel,
|
|
32
32
|
Model,
|
|
33
|
+
ModelRequestParameters,
|
|
33
34
|
StreamedResponse,
|
|
34
35
|
cached_async_http_client,
|
|
35
36
|
check_allow_model_requests,
|
|
@@ -45,10 +46,16 @@ except ImportError as _import_error:
|
|
|
45
46
|
"you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
|
|
46
47
|
) from _import_error
|
|
47
48
|
|
|
48
|
-
OpenAIModelName = Union[
|
|
49
|
+
OpenAIModelName = Union[str, ChatModel]
|
|
49
50
|
"""
|
|
51
|
+
Possible OpenAI model names.
|
|
52
|
+
|
|
53
|
+
Since OpenAI supports a variety of date-stamped models, we explicitly list the latest models but
|
|
54
|
+
allow any name in the type hints.
|
|
55
|
+
See [the OpenAI docs](https://platform.openai.com/docs/models) for a full list.
|
|
56
|
+
|
|
50
57
|
Using this more broad type for the model name instead of the ChatModel definition
|
|
51
|
-
allows this model to be used more easily with other model types (ie, Ollama, Deepseek)
|
|
58
|
+
allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
|
|
52
59
|
"""
|
|
53
60
|
|
|
54
61
|
OpenAISystemPromptRole = Literal['system', 'developer', 'user']
|
|
@@ -57,7 +64,12 @@ OpenAISystemPromptRole = Literal['system', 'developer', 'user']
|
|
|
57
64
|
class OpenAIModelSettings(ModelSettings):
|
|
58
65
|
"""Settings used for an OpenAI model request."""
|
|
59
66
|
|
|
60
|
-
|
|
67
|
+
openai_reasoning_effort: chat.ChatCompletionReasoningEffort
|
|
68
|
+
"""
|
|
69
|
+
Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning).
|
|
70
|
+
Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can
|
|
71
|
+
result in faster responses and fewer tokens used on reasoning in a response.
|
|
72
|
+
"""
|
|
61
73
|
|
|
62
74
|
|
|
63
75
|
@dataclass(init=False)
|
|
@@ -69,10 +81,12 @@ class OpenAIModel(Model):
|
|
|
69
81
|
Apart from `__init__`, all methods are private or match those of the base class.
|
|
70
82
|
"""
|
|
71
83
|
|
|
72
|
-
model_name: OpenAIModelName
|
|
73
84
|
client: AsyncOpenAI = field(repr=False)
|
|
74
85
|
system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
|
|
75
86
|
|
|
87
|
+
_model_name: OpenAIModelName = field(repr=False)
|
|
88
|
+
_system: str | None = field(repr=False)
|
|
89
|
+
|
|
76
90
|
def __init__(
|
|
77
91
|
self,
|
|
78
92
|
model_name: OpenAIModelName,
|
|
@@ -82,6 +96,7 @@ class OpenAIModel(Model):
|
|
|
82
96
|
openai_client: AsyncOpenAI | None = None,
|
|
83
97
|
http_client: AsyncHTTPClient | None = None,
|
|
84
98
|
system_prompt_role: OpenAISystemPromptRole | None = None,
|
|
99
|
+
system: str | None = 'openai',
|
|
85
100
|
):
|
|
86
101
|
"""Initialize an OpenAI model.
|
|
87
102
|
|
|
@@ -99,9 +114,15 @@ class OpenAIModel(Model):
|
|
|
99
114
|
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
100
115
|
system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
|
|
101
116
|
In the future, this may be inferred from the model name.
|
|
117
|
+
system: The model provider used, defaults to `openai`. This is for observability purposes, you must
|
|
118
|
+
customize the `base_url` and `api_key` to use a different provider.
|
|
102
119
|
"""
|
|
103
|
-
self.
|
|
104
|
-
|
|
120
|
+
self._model_name = model_name
|
|
121
|
+
# This is a workaround for the OpenAI client requiring an API key, whilst locally served,
|
|
122
|
+
# openai compatible models do not always need an API key.
|
|
123
|
+
if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
|
|
124
|
+
api_key = ''
|
|
125
|
+
elif openai_client is not None:
|
|
105
126
|
assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
|
|
106
127
|
assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
|
|
107
128
|
assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
|
|
@@ -111,84 +132,70 @@ class OpenAIModel(Model):
|
|
|
111
132
|
else:
|
|
112
133
|
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
|
|
113
134
|
self.system_prompt_role = system_prompt_role
|
|
114
|
-
|
|
115
|
-
async def agent_model(
|
|
116
|
-
self,
|
|
117
|
-
*,
|
|
118
|
-
function_tools: list[ToolDefinition],
|
|
119
|
-
allow_text_result: bool,
|
|
120
|
-
result_tools: list[ToolDefinition],
|
|
121
|
-
) -> AgentModel:
|
|
122
|
-
check_allow_model_requests()
|
|
123
|
-
tools = [self._map_tool_definition(r) for r in function_tools]
|
|
124
|
-
if result_tools:
|
|
125
|
-
tools += [self._map_tool_definition(r) for r in result_tools]
|
|
126
|
-
return OpenAIAgentModel(
|
|
127
|
-
self.client,
|
|
128
|
-
self.model_name,
|
|
129
|
-
allow_text_result,
|
|
130
|
-
tools,
|
|
131
|
-
self.system_prompt_role,
|
|
132
|
-
)
|
|
135
|
+
self._system = system
|
|
133
136
|
|
|
134
137
|
def name(self) -> str:
|
|
135
|
-
return f'openai:{self.
|
|
136
|
-
|
|
137
|
-
@staticmethod
|
|
138
|
-
def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
|
|
139
|
-
return {
|
|
140
|
-
'type': 'function',
|
|
141
|
-
'function': {
|
|
142
|
-
'name': f.name,
|
|
143
|
-
'description': f.description,
|
|
144
|
-
'parameters': f.parameters_json_schema,
|
|
145
|
-
},
|
|
146
|
-
}
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
@dataclass
|
|
150
|
-
class OpenAIAgentModel(AgentModel):
|
|
151
|
-
"""Implementation of `AgentModel` for OpenAI models."""
|
|
152
|
-
|
|
153
|
-
client: AsyncOpenAI
|
|
154
|
-
model_name: OpenAIModelName
|
|
155
|
-
allow_text_result: bool
|
|
156
|
-
tools: list[chat.ChatCompletionToolParam]
|
|
157
|
-
system_prompt_role: OpenAISystemPromptRole | None
|
|
138
|
+
return f'openai:{self._model_name}'
|
|
158
139
|
|
|
159
140
|
async def request(
|
|
160
|
-
self,
|
|
141
|
+
self,
|
|
142
|
+
messages: list[ModelMessage],
|
|
143
|
+
model_settings: ModelSettings | None,
|
|
144
|
+
model_request_parameters: ModelRequestParameters,
|
|
161
145
|
) -> tuple[ModelResponse, usage.Usage]:
|
|
162
|
-
|
|
146
|
+
check_allow_model_requests()
|
|
147
|
+
response = await self._completions_create(
|
|
148
|
+
messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
|
|
149
|
+
)
|
|
163
150
|
return self._process_response(response), _map_usage(response)
|
|
164
151
|
|
|
165
152
|
@asynccontextmanager
|
|
166
153
|
async def request_stream(
|
|
167
|
-
self,
|
|
154
|
+
self,
|
|
155
|
+
messages: list[ModelMessage],
|
|
156
|
+
model_settings: ModelSettings | None,
|
|
157
|
+
model_request_parameters: ModelRequestParameters,
|
|
168
158
|
) -> AsyncIterator[StreamedResponse]:
|
|
169
|
-
|
|
159
|
+
check_allow_model_requests()
|
|
160
|
+
response = await self._completions_create(
|
|
161
|
+
messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
|
|
162
|
+
)
|
|
170
163
|
async with response:
|
|
171
164
|
yield await self._process_streamed_response(response)
|
|
172
165
|
|
|
173
166
|
@overload
|
|
174
167
|
async def _completions_create(
|
|
175
|
-
self,
|
|
168
|
+
self,
|
|
169
|
+
messages: list[ModelMessage],
|
|
170
|
+
stream: Literal[True],
|
|
171
|
+
model_settings: OpenAIModelSettings,
|
|
172
|
+
model_request_parameters: ModelRequestParameters,
|
|
176
173
|
) -> AsyncStream[ChatCompletionChunk]:
|
|
177
174
|
pass
|
|
178
175
|
|
|
179
176
|
@overload
|
|
180
177
|
async def _completions_create(
|
|
181
|
-
self,
|
|
178
|
+
self,
|
|
179
|
+
messages: list[ModelMessage],
|
|
180
|
+
stream: Literal[False],
|
|
181
|
+
model_settings: OpenAIModelSettings,
|
|
182
|
+
model_request_parameters: ModelRequestParameters,
|
|
182
183
|
) -> chat.ChatCompletion:
|
|
183
184
|
pass
|
|
184
185
|
|
|
185
186
|
async def _completions_create(
|
|
186
|
-
self,
|
|
187
|
+
self,
|
|
188
|
+
messages: list[ModelMessage],
|
|
189
|
+
stream: bool,
|
|
190
|
+
model_settings: OpenAIModelSettings,
|
|
191
|
+
model_request_parameters: ModelRequestParameters,
|
|
187
192
|
) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
|
|
193
|
+
tools = self._get_tools(model_request_parameters)
|
|
194
|
+
|
|
188
195
|
# standalone function to make it easier to override
|
|
189
|
-
if not
|
|
196
|
+
if not tools:
|
|
190
197
|
tool_choice: Literal['none', 'required', 'auto'] | None = None
|
|
191
|
-
elif not
|
|
198
|
+
elif not model_request_parameters.allow_text_result:
|
|
192
199
|
tool_choice = 'required'
|
|
193
200
|
else:
|
|
194
201
|
tool_choice = 'auto'
|
|
@@ -196,11 +203,11 @@ class OpenAIAgentModel(AgentModel):
|
|
|
196
203
|
openai_messages = list(chain(*(self._map_message(m) for m in messages)))
|
|
197
204
|
|
|
198
205
|
return await self.client.chat.completions.create(
|
|
199
|
-
model=self.
|
|
206
|
+
model=self._model_name,
|
|
200
207
|
messages=openai_messages,
|
|
201
208
|
n=1,
|
|
202
209
|
parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
|
|
203
|
-
tools=
|
|
210
|
+
tools=tools or NOT_GIVEN,
|
|
204
211
|
tool_choice=tool_choice or NOT_GIVEN,
|
|
205
212
|
stream=stream,
|
|
206
213
|
stream_options={'include_usage': True} if stream else NOT_GIVEN,
|
|
@@ -212,6 +219,7 @@ class OpenAIAgentModel(AgentModel):
|
|
|
212
219
|
presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
|
|
213
220
|
frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
|
|
214
221
|
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
|
|
222
|
+
reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
|
|
215
223
|
)
|
|
216
224
|
|
|
217
225
|
def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
|
|
@@ -224,7 +232,7 @@ class OpenAIAgentModel(AgentModel):
|
|
|
224
232
|
if choice.message.tool_calls is not None:
|
|
225
233
|
for c in choice.message.tool_calls:
|
|
226
234
|
items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
|
|
227
|
-
return ModelResponse(items, model_name=self.
|
|
235
|
+
return ModelResponse(items, model_name=self._model_name, timestamp=timestamp)
|
|
228
236
|
|
|
229
237
|
async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
|
|
230
238
|
"""Process a streamed response, and prepare a streaming response to return."""
|
|
@@ -234,11 +242,17 @@ class OpenAIAgentModel(AgentModel):
|
|
|
234
242
|
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
|
|
235
243
|
|
|
236
244
|
return OpenAIStreamedResponse(
|
|
237
|
-
_model_name=self.
|
|
245
|
+
_model_name=self._model_name,
|
|
238
246
|
_response=peekable_response,
|
|
239
247
|
_timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
|
|
240
248
|
)
|
|
241
249
|
|
|
250
|
+
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
|
|
251
|
+
tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
|
|
252
|
+
if model_request_parameters.result_tools:
|
|
253
|
+
tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
|
|
254
|
+
return tools
|
|
255
|
+
|
|
242
256
|
def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
|
|
243
257
|
"""Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
|
|
244
258
|
if isinstance(message, ModelRequest):
|
|
@@ -250,7 +264,7 @@ class OpenAIAgentModel(AgentModel):
|
|
|
250
264
|
if isinstance(item, TextPart):
|
|
251
265
|
texts.append(item.content)
|
|
252
266
|
elif isinstance(item, ToolCallPart):
|
|
253
|
-
tool_calls.append(_map_tool_call(item))
|
|
267
|
+
tool_calls.append(self._map_tool_call(item))
|
|
254
268
|
else:
|
|
255
269
|
assert_never(item)
|
|
256
270
|
message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
|
|
@@ -264,6 +278,25 @@ class OpenAIAgentModel(AgentModel):
|
|
|
264
278
|
else:
|
|
265
279
|
assert_never(message)
|
|
266
280
|
|
|
281
|
+
@staticmethod
|
|
282
|
+
def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
|
|
283
|
+
return chat.ChatCompletionMessageToolCallParam(
|
|
284
|
+
id=_guard_tool_call_id(t=t, model_source='OpenAI'),
|
|
285
|
+
type='function',
|
|
286
|
+
function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
@staticmethod
|
|
290
|
+
def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
|
|
291
|
+
return {
|
|
292
|
+
'type': 'function',
|
|
293
|
+
'function': {
|
|
294
|
+
'name': f.name,
|
|
295
|
+
'description': f.description,
|
|
296
|
+
'parameters': f.parameters_json_schema,
|
|
297
|
+
},
|
|
298
|
+
}
|
|
299
|
+
|
|
267
300
|
def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
|
|
268
301
|
for part in message.parts:
|
|
269
302
|
if isinstance(part, SystemPromptPart):
|
|
@@ -329,14 +362,6 @@ class OpenAIStreamedResponse(StreamedResponse):
|
|
|
329
362
|
return self._timestamp
|
|
330
363
|
|
|
331
364
|
|
|
332
|
-
def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
|
|
333
|
-
return chat.ChatCompletionMessageToolCallParam(
|
|
334
|
-
id=_guard_tool_call_id(t=t, model_source='OpenAI'),
|
|
335
|
-
type='function',
|
|
336
|
-
function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
|
|
337
|
-
)
|
|
338
|
-
|
|
339
|
-
|
|
340
365
|
def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
|
|
341
366
|
response_usage = response.usage
|
|
342
367
|
if response_usage is None:
|
pydantic_ai/models/test.py
CHANGED
|
@@ -26,8 +26,8 @@ from ..result import Usage
|
|
|
26
26
|
from ..settings import ModelSettings
|
|
27
27
|
from ..tools import ToolDefinition
|
|
28
28
|
from . import (
|
|
29
|
-
AgentModel,
|
|
30
29
|
Model,
|
|
30
|
+
ModelRequestParameters,
|
|
31
31
|
StreamedResponse,
|
|
32
32
|
)
|
|
33
33
|
from .function import _estimate_string_tokens, _estimate_usage # pyright: ignore[reportPrivateUsage]
|
|
@@ -71,102 +71,92 @@ class TestModel(Model):
|
|
|
71
71
|
"""If set, these args will be passed to the result tool."""
|
|
72
72
|
seed: int = 0
|
|
73
73
|
"""Seed for generating random data."""
|
|
74
|
-
|
|
75
|
-
"""
|
|
74
|
+
last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False)
|
|
75
|
+
"""The last ModelRequestParameters passed to the model in a request.
|
|
76
76
|
|
|
77
|
-
|
|
78
|
-
"""
|
|
79
|
-
agent_model_allow_text_result: bool | None = field(default=None, init=False)
|
|
80
|
-
"""Whether plain text responses from the model are allowed.
|
|
77
|
+
The ModelRequestParameters contains information about the function and result tools available during request handling.
|
|
81
78
|
|
|
82
|
-
This is set when
|
|
79
|
+
This is set when a request is made, so will reflect the function tools from the last step of the last run.
|
|
83
80
|
"""
|
|
84
|
-
|
|
85
|
-
|
|
81
|
+
_model_name: str = field(default='test', repr=False)
|
|
82
|
+
_system: str | None = field(default=None, repr=False)
|
|
86
83
|
|
|
87
|
-
|
|
88
|
-
|
|
84
|
+
async def request(
|
|
85
|
+
self,
|
|
86
|
+
messages: list[ModelMessage],
|
|
87
|
+
model_settings: ModelSettings | None,
|
|
88
|
+
model_request_parameters: ModelRequestParameters,
|
|
89
|
+
) -> tuple[ModelResponse, Usage]:
|
|
90
|
+
self.last_model_request_parameters = model_request_parameters
|
|
89
91
|
|
|
90
|
-
|
|
92
|
+
model_response = self._request(messages, model_settings, model_request_parameters)
|
|
93
|
+
usage = _estimate_usage([*messages, model_response])
|
|
94
|
+
return model_response, usage
|
|
95
|
+
|
|
96
|
+
@asynccontextmanager
|
|
97
|
+
async def request_stream(
|
|
91
98
|
self,
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
self.
|
|
99
|
-
|
|
99
|
+
messages: list[ModelMessage],
|
|
100
|
+
model_settings: ModelSettings | None,
|
|
101
|
+
model_request_parameters: ModelRequestParameters,
|
|
102
|
+
) -> AsyncIterator[StreamedResponse]:
|
|
103
|
+
self.last_model_request_parameters = model_request_parameters
|
|
104
|
+
|
|
105
|
+
model_response = self._request(messages, model_settings, model_request_parameters)
|
|
106
|
+
yield TestStreamedResponse(
|
|
107
|
+
_model_name=self._model_name, _structured_response=model_response, _messages=messages
|
|
108
|
+
)
|
|
100
109
|
|
|
110
|
+
def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
|
|
111
|
+
return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
|
|
112
|
+
|
|
113
|
+
def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> list[tuple[str, ToolDefinition]]:
|
|
101
114
|
if self.call_tools == 'all':
|
|
102
|
-
|
|
115
|
+
return [(r.name, r) for r in model_request_parameters.function_tools]
|
|
103
116
|
else:
|
|
104
|
-
function_tools_lookup = {t.name: t for t in function_tools}
|
|
117
|
+
function_tools_lookup = {t.name: t for t in model_request_parameters.function_tools}
|
|
105
118
|
tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
|
|
106
|
-
|
|
119
|
+
return [(r.name, r) for r in tools_to_call]
|
|
107
120
|
|
|
121
|
+
def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
|
|
108
122
|
if self.custom_result_text is not None:
|
|
109
|
-
assert
|
|
123
|
+
assert (
|
|
124
|
+
model_request_parameters.allow_text_result
|
|
125
|
+
), 'Plain response not allowed, but `custom_result_text` is set.'
|
|
110
126
|
assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
|
|
111
|
-
|
|
127
|
+
return _TextResult(self.custom_result_text)
|
|
112
128
|
elif self.custom_result_args is not None:
|
|
113
|
-
assert
|
|
114
|
-
|
|
129
|
+
assert (
|
|
130
|
+
model_request_parameters.result_tools is not None
|
|
131
|
+
), 'No result tools provided, but `custom_result_args` is set.'
|
|
132
|
+
result_tool = model_request_parameters.result_tools[0]
|
|
115
133
|
|
|
116
134
|
if k := result_tool.outer_typed_dict_key:
|
|
117
|
-
|
|
135
|
+
return _FunctionToolResult({k: self.custom_result_args})
|
|
118
136
|
else:
|
|
119
|
-
|
|
120
|
-
elif allow_text_result:
|
|
121
|
-
|
|
122
|
-
elif result_tools:
|
|
123
|
-
|
|
137
|
+
return _FunctionToolResult(self.custom_result_args)
|
|
138
|
+
elif model_request_parameters.allow_text_result:
|
|
139
|
+
return _TextResult(None)
|
|
140
|
+
elif model_request_parameters.result_tools:
|
|
141
|
+
return _FunctionToolResult(None)
|
|
124
142
|
else:
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
return TestAgentModel(tool_calls, result, result_tools, self.seed)
|
|
128
|
-
|
|
129
|
-
def name(self) -> str:
|
|
130
|
-
return 'test-model'
|
|
131
|
-
|
|
143
|
+
return _TextResult(None)
|
|
132
144
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
result: _TextResult | _FunctionToolResult
|
|
143
|
-
result_tools: list[ToolDefinition]
|
|
144
|
-
seed: int
|
|
145
|
-
model_name: str = 'test'
|
|
146
|
-
|
|
147
|
-
async def request(
|
|
148
|
-
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
149
|
-
) -> tuple[ModelResponse, Usage]:
|
|
150
|
-
model_response = self._request(messages, model_settings)
|
|
151
|
-
usage = _estimate_usage([*messages, model_response])
|
|
152
|
-
return model_response, usage
|
|
153
|
-
|
|
154
|
-
@asynccontextmanager
|
|
155
|
-
async def request_stream(
|
|
156
|
-
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
157
|
-
) -> AsyncIterator[StreamedResponse]:
|
|
158
|
-
model_response = self._request(messages, model_settings)
|
|
159
|
-
yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)
|
|
160
|
-
|
|
161
|
-
def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
|
|
162
|
-
return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
|
|
145
|
+
def _request(
|
|
146
|
+
self,
|
|
147
|
+
messages: list[ModelMessage],
|
|
148
|
+
model_settings: ModelSettings | None,
|
|
149
|
+
model_request_parameters: ModelRequestParameters,
|
|
150
|
+
) -> ModelResponse:
|
|
151
|
+
tool_calls = self._get_tool_calls(model_request_parameters)
|
|
152
|
+
result = self._get_result(model_request_parameters)
|
|
153
|
+
result_tools = model_request_parameters.result_tools
|
|
163
154
|
|
|
164
|
-
def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
|
|
165
155
|
# if there are tools, the first thing we want to do is call all of them
|
|
166
|
-
if
|
|
156
|
+
if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
|
|
167
157
|
return ModelResponse(
|
|
168
|
-
parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in
|
|
169
|
-
model_name=self.
|
|
158
|
+
parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls],
|
|
159
|
+
model_name=self._model_name,
|
|
170
160
|
)
|
|
171
161
|
|
|
172
162
|
if messages:
|
|
@@ -179,28 +169,26 @@ class TestAgentModel(AgentModel):
|
|
|
179
169
|
# Handle retries for both function tools and result tools
|
|
180
170
|
# Check function tools first
|
|
181
171
|
retry_parts: list[ModelResponsePart] = [
|
|
182
|
-
ToolCallPart(name, self.gen_tool_args(args))
|
|
183
|
-
for name, args in self.tool_calls
|
|
184
|
-
if name in new_retry_names
|
|
172
|
+
ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names
|
|
185
173
|
]
|
|
186
174
|
# Check result tools
|
|
187
|
-
if
|
|
175
|
+
if result_tools:
|
|
188
176
|
retry_parts.extend(
|
|
189
177
|
[
|
|
190
178
|
ToolCallPart(
|
|
191
179
|
tool.name,
|
|
192
|
-
|
|
193
|
-
if isinstance(
|
|
180
|
+
result.value
|
|
181
|
+
if isinstance(result, _FunctionToolResult) and result.value is not None
|
|
194
182
|
else self.gen_tool_args(tool),
|
|
195
183
|
)
|
|
196
|
-
for tool in
|
|
184
|
+
for tool in result_tools
|
|
197
185
|
if tool.name in new_retry_names
|
|
198
186
|
]
|
|
199
187
|
)
|
|
200
|
-
return ModelResponse(parts=retry_parts, model_name=self.
|
|
188
|
+
return ModelResponse(parts=retry_parts, model_name=self._model_name)
|
|
201
189
|
|
|
202
|
-
if isinstance(
|
|
203
|
-
if (response_text :=
|
|
190
|
+
if isinstance(result, _TextResult):
|
|
191
|
+
if (response_text := result.value) is None:
|
|
204
192
|
# build up details of tool responses
|
|
205
193
|
output: dict[str, Any] = {}
|
|
206
194
|
for message in messages:
|
|
@@ -210,23 +198,23 @@ class TestAgentModel(AgentModel):
|
|
|
210
198
|
output[part.tool_name] = part.content
|
|
211
199
|
if output:
|
|
212
200
|
return ModelResponse(
|
|
213
|
-
parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.
|
|
201
|
+
parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self._model_name
|
|
214
202
|
)
|
|
215
203
|
else:
|
|
216
|
-
return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.
|
|
204
|
+
return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self._model_name)
|
|
217
205
|
else:
|
|
218
|
-
return ModelResponse(parts=[TextPart(response_text)], model_name=self.
|
|
206
|
+
return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
|
|
219
207
|
else:
|
|
220
|
-
assert
|
|
221
|
-
custom_result_args =
|
|
222
|
-
result_tool =
|
|
208
|
+
assert result_tools, 'No result tools provided'
|
|
209
|
+
custom_result_args = result.value
|
|
210
|
+
result_tool = result_tools[self.seed % len(result_tools)]
|
|
223
211
|
if custom_result_args is not None:
|
|
224
212
|
return ModelResponse(
|
|
225
|
-
parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.
|
|
213
|
+
parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self._model_name
|
|
226
214
|
)
|
|
227
215
|
else:
|
|
228
216
|
response_args = self.gen_tool_args(result_tool)
|
|
229
|
-
return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.
|
|
217
|
+
return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self._model_name)
|
|
230
218
|
|
|
231
219
|
|
|
232
220
|
@dataclass
|