pydantic-ai-slim 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.
- pydantic_ai/_agent_graph.py +12 -8
- pydantic_ai/agent.py +5 -5
- pydantic_ai/models/__init__.py +52 -45
- pydantic_ai/models/anthropic.py +87 -66
- pydantic_ai/models/cohere.py +65 -67
- pydantic_ai/models/function.py +76 -60
- pydantic_ai/models/gemini.py +153 -99
- pydantic_ai/models/groq.py +97 -72
- pydantic_ai/models/mistral.py +90 -71
- pydantic_ai/models/openai.py +110 -71
- pydantic_ai/models/test.py +99 -94
- pydantic_ai/models/vertexai.py +48 -44
- pydantic_ai/result.py +2 -2
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.24.dist-info}/METADATA +3 -3
- pydantic_ai_slim-0.0.24.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.22.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.24.dist-info}/WHEEL +0 -0
pydantic_ai/models/openai.py
CHANGED
@@ -29,8 +29,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -46,10 +46,16 @@ except ImportError as _import_error:
     "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
 ) from _import_error
 
-OpenAIModelName = Union[
+OpenAIModelName = Union[str, ChatModel]
 """
+Possible OpenAI model names.
+
+Since OpenAI supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+See [the OpenAI docs](https://platform.openai.com/docs/models) for a full list.
+
 Using this more broad type for the model name instead of the ChatModel definition
-allows this model to be used more easily with other model types (ie, Ollama, Deepseek)
+allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
 """
 
 OpenAISystemPromptRole = Literal['system', 'developer', 'user']
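Note: with `OpenAIModelName = Union[str, ChatModel]`, any model name string passes the type hints, which is what makes the class usable against OpenAI-compatible servers. A minimal sketch under that assumption (the Ollama-style endpoint and model name below are illustrative, not part of this diff):

    from pydantic_ai.models.openai import OpenAIModel

    # Point the OpenAI client at a locally served, OpenAI-compatible API.
    # With `base_url` set and no key available, __init__ falls back to api_key=''.
    model = OpenAIModel('llama3.2', base_url='http://localhost:11434/v1')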
@@ -58,7 +64,12 @@ OpenAISystemPromptRole = Literal['system', 'developer', 'user']
 class OpenAIModelSettings(ModelSettings):
     """Settings used for an OpenAI model request."""
 
-
+    openai_reasoning_effort: chat.ChatCompletionReasoningEffort
+    """
+    Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning).
+    Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can
+    result in faster responses and fewer tokens used on reasoning in a response.
+    """
 
 
 @dataclass(init=False)
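Note: `openai_reasoning_effort` is forwarded to the OpenAI `reasoning_effort` request parameter (see the `_completions_create` change further down). A hedged sketch of how it would be used through the usual pydantic-ai `model_settings` plumbing; the reasoning model name is an assumption and `OPENAI_API_KEY` must be set:

    from pydantic_ai import Agent

    agent = Agent('openai:o1')  # assumed reasoning-capable model
    result = agent.run_sync(
        'Summarise the trade-offs of consistent hashing.',
        model_settings={'openai_reasoning_effort': 'low'},  # faster, fewer reasoning tokens
    )
    print(result.data)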
@@ -70,10 +81,12 @@ class OpenAIModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
     system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
 
+    _model_name: OpenAIModelName = field(repr=False)
+    _system: str | None = field(repr=False)
+
     def __init__(
         self,
         model_name: OpenAIModelName,
@@ -83,6 +96,7 @@ class OpenAIModel(Model):
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
         system_prompt_role: OpenAISystemPromptRole | None = None,
+        system: str | None = 'openai',
     ):
         """Initialize an OpenAI model.
 
@@ -100,13 +114,16 @@
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
             system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                 In the future, this may be inferred from the model name.
+            system: The model provider used, defaults to `openai`. This is for observability purposes, you must
+                customize the `base_url` and `api_key` to use a different provider.
         """
-        self.
+        self._model_name = model_name
         # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
         # openai compatible models do not always need an API key.
         if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
             api_key = ''
-
+
+        if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
             assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
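Note: the new `system` argument is only an observability tag for the provider; routing requests elsewhere is still done through `base_url` and `api_key`, as the docstring above says. A sketch with an assumed third-party, OpenAI-compatible provider:

    from pydantic_ai.models.openai import OpenAIModel

    model = OpenAIModel(
        'deepseek-chat',                      # arbitrary model name string
        base_url='https://api.deepseek.com',  # assumed OpenAI-compatible endpoint
        api_key='your-deepseek-api-key',      # placeholder
        system='deepseek',                    # recorded as the provider for observability
    )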
@@ -116,84 +133,80 @@ class OpenAIModel(Model):
         else:
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
         self.system_prompt_role = system_prompt_role
-
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return OpenAIAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
-            self.system_prompt_role,
-        )
+        self._system = system
 
     def name(self) -> str:
-        return f'openai:{self.
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
-        return {
-            'type': 'function',
-            'function': {
-                'name': f.name,
-                'description': f.description,
-                'parameters': f.parameters_json_schema,
-            },
-        }
-
-
-@dataclass
-class OpenAIAgentModel(AgentModel):
-    """Implementation of `AgentModel` for OpenAI models."""
-
-    client: AsyncOpenAI
-    model_name: OpenAIModelName
-    allow_text_result: bool
-    tools: list[chat.ChatCompletionToolParam]
-    system_prompt_role: OpenAISystemPromptRole | None
+        return f'openai:{self._model_name}'
 
     async def request(
-        self,
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, usage.Usage]:
-
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+        )
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
-        self,
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)
 
+    @property
+    def model_name(self) -> OpenAIModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     @overload
     async def _completions_create(
-        self,
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
    ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
     async def _completions_create(
-        self,
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self,
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+        tools = self._get_tools(model_request_parameters)
+
         # standalone function to make it easier to override
-        if not
+        if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not
+        elif not model_request_parameters.allow_text_result:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
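Note: this hunk removes the `OpenAIAgentModel` indirection: the per-step tool definitions that `agent_model()` used to receive are now passed straight into `request()`/`request_stream()` as a `ModelRequestParameters` object. Its real definition lives in `pydantic_ai/models/__init__.py` (also changed in this release); judging only from how it is used in this diff, its shape is roughly:

    from dataclasses import dataclass

    from pydantic_ai.tools import ToolDefinition

    @dataclass
    class ModelRequestParameters:
        """Sketch inferred from usage in this diff; see pydantic_ai/models/__init__.py for the real class."""
        function_tools: list[ToolDefinition]
        allow_text_result: bool
        result_tools: list[ToolDefinition]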
@@ -201,11 +214,11 @@ class OpenAIAgentModel(AgentModel):
         openai_messages = list(chain(*(self._map_message(m) for m in messages)))
 
         return await self.client.chat.completions.create(
-            model=self.
+            model=self._model_name,
             messages=openai_messages,
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             stream_options={'include_usage': True} if stream else NOT_GIVEN,
@@ -217,6 +230,7 @@ class OpenAIAgentModel(AgentModel):
             presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
             frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
             logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+            reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
         )
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
@@ -229,7 +243,7 @@ class OpenAIAgentModel(AgentModel):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
-        return ModelResponse(items, model_name=
+        return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -239,11 +253,17 @@ class OpenAIAgentModel(AgentModel):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
         return OpenAIStreamedResponse(
-            _model_name=self.
+            _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
         )
 
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
     def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
@@ -255,7 +275,7 @@ class OpenAIAgentModel(AgentModel):
                 if isinstance(item, TextPart):
                     texts.append(item.content)
                 elif isinstance(item, ToolCallPart):
-                    tool_calls.append(_map_tool_call(item))
+                    tool_calls.append(self._map_tool_call(item))
                 else:
                     assert_never(item)
             message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
@@ -269,6 +289,25 @@ class OpenAIAgentModel(AgentModel):
         else:
             assert_never(message)
 
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
+        return chat.ChatCompletionMessageToolCallParam(
+            id=_guard_tool_call_id(t=t, model_source='OpenAI'),
+            type='function',
+            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+        return {
+            'type': 'function',
+            'function': {
+                'name': f.name,
+                'description': f.description,
+                'parameters': f.parameters_json_schema,
+            },
+        }
+
     def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
@@ -303,6 +342,7 @@ class OpenAIAgentModel(AgentModel):
 class OpenAIStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for OpenAI models."""
 
+    _model_name: OpenAIModelName
     _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
 
@@ -330,18 +370,17 @@ class OpenAIStreamedResponse(StreamedResponse):
             if maybe_event is not None:
                 yield maybe_event
 
+    @property
+    def model_name(self) -> OpenAIModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
 
 
-def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    return chat.ChatCompletionMessageToolCallParam(
-        id=_guard_tool_call_id(t=t, model_source='OpenAI'),
-        type='function',
-        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
-    )
-
-
 def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
     response_usage = response.usage
     if response_usage is None:
pydantic_ai/models/test.py
CHANGED
@@ -26,8 +26,8 @@ from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
 )
 from .function import _estimate_string_tokens, _estimate_usage  # pyright: ignore[reportPrivateUsage]
@@ -71,102 +71,102 @@ class TestModel(Model):
     """If set, these args will be passed to the result tool."""
     seed: int = 0
     """Seed for generating random data."""
-
-    """
+    last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False)
+    """The last ModelRequestParameters passed to the model in a request.
 
-
-    """
-    agent_model_allow_text_result: bool | None = field(default=None, init=False)
-    """Whether plain text responses from the model are allowed.
+    The ModelRequestParameters contains information about the function and result tools available during request handling.
 
-    This is set when
+    This is set when a request is made, so will reflect the function tools from the last step of the last run.
     """
-
-
+    _model_name: str = field(default='test', repr=False)
+    _system: str | None = field(default=None, repr=False)
 
-
-
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        self.last_model_request_parameters = model_request_parameters
 
-
+        model_response = self._request(messages, model_settings, model_request_parameters)
+        usage = _estimate_usage([*messages, model_response])
+        return model_response, usage
+
+    @asynccontextmanager
+    async def request_stream(
         self,
-
-
-
-
-
-
-        self.
-
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        self.last_model_request_parameters = model_request_parameters
+
+        model_response = self._request(messages, model_settings, model_request_parameters)
+        yield TestStreamedResponse(
+            _model_name=self._model_name, _structured_response=model_response, _messages=messages
+        )
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
+    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
+        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
 
+    def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> list[tuple[str, ToolDefinition]]:
         if self.call_tools == 'all':
-
+            return [(r.name, r) for r in model_request_parameters.function_tools]
         else:
-            function_tools_lookup = {t.name: t for t in function_tools}
+            function_tools_lookup = {t.name: t for t in model_request_parameters.function_tools}
            tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
-
+            return [(r.name, r) for r in tools_to_call]
 
+    def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
         if self.custom_result_text is not None:
-            assert
+            assert (
+                model_request_parameters.allow_text_result
+            ), 'Plain response not allowed, but `custom_result_text` is set.'
             assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
-
+            return _TextResult(self.custom_result_text)
         elif self.custom_result_args is not None:
-            assert
-
+            assert (
+                model_request_parameters.result_tools is not None
+            ), 'No result tools provided, but `custom_result_args` is set.'
+            result_tool = model_request_parameters.result_tools[0]
 
             if k := result_tool.outer_typed_dict_key:
-
+                return _FunctionToolResult({k: self.custom_result_args})
             else:
-
-        elif allow_text_result:
-
-        elif result_tools:
-
+                return _FunctionToolResult(self.custom_result_args)
+        elif model_request_parameters.allow_text_result:
+            return _TextResult(None)
+        elif model_request_parameters.result_tools:
+            return _FunctionToolResult(None)
         else:
-
-
-        return TestAgentModel(tool_calls, result, result_tools, self.seed)
-
-    def name(self) -> str:
-        return 'test-model'
-
+            return _TextResult(None)
 
-
-
-
-
-
-
-
-
-    result: _TextResult | _FunctionToolResult
-    result_tools: list[ToolDefinition]
-    seed: int
-    model_name: str = 'test'
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Usage]:
-        model_response = self._request(messages, model_settings)
-        usage = _estimate_usage([*messages, model_response])
-        return model_response, usage
-
-    @asynccontextmanager
-    async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[StreamedResponse]:
-        model_response = self._request(messages, model_settings)
-        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)
-
-    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
-        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
+    def _request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        tool_calls = self._get_tool_calls(model_request_parameters)
+        result = self._get_result(model_request_parameters)
+        result_tools = model_request_parameters.result_tools
 
-    def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
         # if there are tools, the first thing we want to do is call all of them
-        if
+        if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in
-                model_name=self.
+                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls],
+                model_name=self._model_name,
             )
 
         if messages:
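Note: `TestModel` now records the parameters of the last request in `last_model_request_parameters`, replacing the old `agent_model_*` attributes, so tests can assert on which tools the agent exposed. A sketch of the intended use (the agent and tool are made up for illustration, and the exact `Agent` wiring is an assumption about the wider pydantic-ai API):

    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    agent = Agent()

    @agent.tool_plain
    def add(a: int, b: int) -> int:
        """Add two integers."""
        return a + b

    test_model = TestModel()
    agent.run_sync('add 1 and 2', model=test_model)

    assert test_model.last_model_request_parameters is not None
    assert [t.name for t in test_model.last_model_request_parameters.function_tools] == ['add']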
@@ -179,28 +179,26 @@ class TestAgentModel(AgentModel):
             # Handle retries for both function tools and result tools
             # Check function tools first
             retry_parts: list[ModelResponsePart] = [
-                ToolCallPart(name, self.gen_tool_args(args))
-                for name, args in self.tool_calls
-                if name in new_retry_names
+                ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names
             ]
             # Check result tools
-            if
+            if result_tools:
                 retry_parts.extend(
                     [
                         ToolCallPart(
                             tool.name,
-
-                            if isinstance(
+                            result.value
+                            if isinstance(result, _FunctionToolResult) and result.value is not None
                             else self.gen_tool_args(tool),
                         )
-                        for tool in
+                        for tool in result_tools
                         if tool.name in new_retry_names
                     ]
                 )
-            return ModelResponse(parts=retry_parts, model_name=self.
+            return ModelResponse(parts=retry_parts, model_name=self._model_name)
 
-        if isinstance(
-            if (response_text :=
+        if isinstance(result, _TextResult):
+            if (response_text := result.value) is None:
                 # build up details of tool responses
                 output: dict[str, Any] = {}
                 for message in messages:
@@ -210,32 +208,32 @@ class TestAgentModel(AgentModel):
                             output[part.tool_name] = part.content
                 if output:
                     return ModelResponse(
-                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.
+                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self._model_name
                     )
                 else:
-                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.
+                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self._model_name)
             else:
-                return ModelResponse(parts=[TextPart(response_text)], model_name=self.
+                return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
         else:
-            assert
-            custom_result_args =
-            result_tool =
+            assert result_tools, 'No result tools provided'
+            custom_result_args = result.value
+            result_tool = result_tools[self.seed % len(result_tools)]
             if custom_result_args is not None:
                 return ModelResponse(
-                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.
+                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self._model_name
                 )
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.
+                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self._model_name)
 
 
 @dataclass
 class TestStreamedResponse(StreamedResponse):
     """A structured response that streams test data."""
 
+    _model_name: str
     _structured_response: ModelResponse
     _messages: InitVar[Iterable[ModelMessage]]
-
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
 
     def __post_init__(self, _messages: Iterable[ModelMessage]):
@@ -261,7 +259,14 @@ class TestStreamedResponse(StreamedResponse):
                     vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
                 )
 
+    @property
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
 
 