pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_pydantic.py +6 -4
- pydantic_ai/_result.py +18 -22
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +11 -6
- pydantic_ai/agent.py +146 -67
- pydantic_ai/messages.py +5 -2
- pydantic_ai/models/__init__.py +30 -37
- pydantic_ai/models/function.py +8 -14
- pydantic_ai/models/gemini.py +11 -10
- pydantic_ai/models/groq.py +31 -34
- pydantic_ai/models/ollama.py +116 -0
- pydantic_ai/models/openai.py +43 -38
- pydantic_ai/models/test.py +70 -49
- pydantic_ai/models/vertexai.py +7 -6
- pydantic_ai/tools.py +119 -34
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.12.dist-info}/METADATA +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.11.dist-info/RECORD +0 -22
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.12.dist-info}/WHEEL +0 -0
pydantic_ai/_pydantic.py
CHANGED
|
@@ -17,10 +17,10 @@ from pydantic.plugin._schema_validator import create_schema_validator
|
|
|
17
17
|
from pydantic_core import SchemaValidator, core_schema
|
|
18
18
|
|
|
19
19
|
from ._griffe import doc_descriptions
|
|
20
|
-
from ._utils import
|
|
20
|
+
from ._utils import check_object_json_schema, is_model_like
|
|
21
21
|
|
|
22
22
|
if TYPE_CHECKING:
|
|
23
|
-
|
|
23
|
+
from .tools import ObjectJsonSchema
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
__all__ = 'function_schema', 'LazyTypeAdapter'
|
|
@@ -168,11 +168,13 @@ def takes_ctx(function: Callable[..., Any]) -> bool:
|
|
|
168
168
|
"""
|
|
169
169
|
sig = signature(function)
|
|
170
170
|
try:
|
|
171
|
-
|
|
171
|
+
first_param_name = next(iter(sig.parameters.keys()))
|
|
172
172
|
except StopIteration:
|
|
173
173
|
return False
|
|
174
174
|
else:
|
|
175
|
-
|
|
175
|
+
type_hints = _typing_extra.get_function_type_hints(function)
|
|
176
|
+
annotation = type_hints[first_param_name]
|
|
177
|
+
return annotation is not sig.empty and _is_call_ctx(annotation)
|
|
176
178
|
|
|
177
179
|
|
|
178
180
|
def _build_schema(
|
pydantic_ai/_result.py
CHANGED
|
@@ -14,7 +14,7 @@ from . import _utils, messages
|
|
|
14
14
|
from .exceptions import ModelRetry
|
|
15
15
|
from .messages import ModelStructuredResponse, ToolCall
|
|
16
16
|
from .result import ResultData
|
|
17
|
-
from .tools import AgentDeps, ResultValidatorFunc, RunContext
|
|
17
|
+
from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@dataclass
|
|
@@ -94,10 +94,7 @@ class ResultSchema(Generic[ResultData]):
|
|
|
94
94
|
allow_text_result = False
|
|
95
95
|
|
|
96
96
|
def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData]:
|
|
97
|
-
return cast(
|
|
98
|
-
ResultTool[ResultData],
|
|
99
|
-
ResultTool.build(a, tool_name_, description, multiple), # pyright: ignore[reportUnknownMemberType]
|
|
100
|
-
)
|
|
97
|
+
return cast(ResultTool[ResultData], ResultTool(a, tool_name_, description, multiple))
|
|
101
98
|
|
|
102
99
|
tools: dict[str, ResultTool[ResultData]] = {}
|
|
103
100
|
if args := get_union_args(response_type):
|
|
@@ -121,38 +118,38 @@ class ResultSchema(Generic[ResultData]):
|
|
|
121
118
|
"""Return the names of the tools."""
|
|
122
119
|
return list(self.tools.keys())
|
|
123
120
|
|
|
121
|
+
def tool_defs(self) -> list[ToolDefinition]:
|
|
122
|
+
"""Get tool definitions to register with the model."""
|
|
123
|
+
return [t.tool_def for t in self.tools.values()]
|
|
124
|
+
|
|
124
125
|
|
|
125
126
|
DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
|
|
126
127
|
|
|
127
128
|
|
|
128
|
-
@dataclass
|
|
129
|
+
@dataclass(init=False)
|
|
129
130
|
class ResultTool(Generic[ResultData]):
|
|
130
|
-
|
|
131
|
-
description: str
|
|
131
|
+
tool_def: ToolDefinition
|
|
132
132
|
type_adapter: TypeAdapter[Any]
|
|
133
|
-
json_schema: _utils.ObjectJsonSchema
|
|
134
|
-
outer_typed_dict_key: str | None
|
|
135
133
|
|
|
136
|
-
|
|
137
|
-
def build(cls, response_type: type[ResultData], name: str, description: str | None, multiple: bool) -> Self | None:
|
|
134
|
+
def __init__(self, response_type: type[ResultData], name: str, description: str | None, multiple: bool):
|
|
138
135
|
"""Build a ResultTool dataclass from a response type."""
|
|
139
136
|
assert response_type is not str, 'ResultTool does not support str as a response type'
|
|
140
137
|
|
|
141
138
|
if _utils.is_model_like(response_type):
|
|
142
|
-
type_adapter = TypeAdapter(response_type)
|
|
139
|
+
self.type_adapter = TypeAdapter(response_type)
|
|
143
140
|
outer_typed_dict_key: str | None = None
|
|
144
141
|
# noinspection PyArgumentList
|
|
145
|
-
|
|
142
|
+
parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
|
|
146
143
|
else:
|
|
147
144
|
response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type}) # noqa
|
|
148
|
-
type_adapter = TypeAdapter(response_data_typed_dict)
|
|
145
|
+
self.type_adapter = TypeAdapter(response_data_typed_dict)
|
|
149
146
|
outer_typed_dict_key = 'response'
|
|
150
147
|
# noinspection PyArgumentList
|
|
151
|
-
|
|
148
|
+
parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
|
|
152
149
|
# including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
|
|
153
|
-
|
|
150
|
+
parameters_json_schema.pop('title')
|
|
154
151
|
|
|
155
|
-
if json_schema_description :=
|
|
152
|
+
if json_schema_description := parameters_json_schema.pop('description', None):
|
|
156
153
|
if description is None:
|
|
157
154
|
tool_description = json_schema_description
|
|
158
155
|
else:
|
|
@@ -162,11 +159,10 @@ class ResultTool(Generic[ResultData]):
|
|
|
162
159
|
if multiple:
|
|
163
160
|
tool_description = f'{union_arg_name(response_type)}: {tool_description}'
|
|
164
161
|
|
|
165
|
-
|
|
162
|
+
self.tool_def = ToolDefinition(
|
|
166
163
|
name=name,
|
|
167
164
|
description=tool_description,
|
|
168
|
-
|
|
169
|
-
json_schema=json_schema,
|
|
165
|
+
parameters_json_schema=parameters_json_schema,
|
|
170
166
|
outer_typed_dict_key=outer_typed_dict_key,
|
|
171
167
|
)
|
|
172
168
|
|
|
@@ -204,7 +200,7 @@ class ResultTool(Generic[ResultData]):
|
|
|
204
200
|
else:
|
|
205
201
|
raise
|
|
206
202
|
else:
|
|
207
|
-
if k := self.outer_typed_dict_key:
|
|
203
|
+
if k := self.tool_def.outer_typed_dict_key:
|
|
208
204
|
result = result[k]
|
|
209
205
|
return result
|
|
210
206
|
|
pydantic_ai/_system_prompt.py
CHANGED
pydantic_ai/_utils.py
CHANGED
|
@@ -8,12 +8,15 @@ from dataclasses import dataclass, is_dataclass
|
|
|
8
8
|
from datetime import datetime, timezone
|
|
9
9
|
from functools import partial
|
|
10
10
|
from types import GenericAlias
|
|
11
|
-
from typing import Any, Callable, Generic, TypeVar, Union, cast, overload
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload
|
|
12
12
|
|
|
13
13
|
from pydantic import BaseModel
|
|
14
14
|
from pydantic.json_schema import JsonSchemaValue
|
|
15
15
|
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
|
|
16
16
|
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from .tools import ObjectJsonSchema
|
|
19
|
+
|
|
17
20
|
_P = ParamSpec('_P')
|
|
18
21
|
_R = TypeVar('_R')
|
|
19
22
|
|
|
@@ -39,10 +42,6 @@ def is_model_like(type_: Any) -> bool:
|
|
|
39
42
|
)
|
|
40
43
|
|
|
41
44
|
|
|
42
|
-
# With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
|
|
43
|
-
ObjectJsonSchema: TypeAlias = dict[str, Any]
|
|
44
|
-
|
|
45
|
-
|
|
46
45
|
def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
|
|
47
46
|
from .exceptions import UserError
|
|
48
47
|
|
|
@@ -127,6 +126,12 @@ class Either(Generic[Left, Right]):
|
|
|
127
126
|
def whichever(self) -> Left | Right:
|
|
128
127
|
return self._left.value if self._left is not None else self.right
|
|
129
128
|
|
|
129
|
+
def __repr__(self):
|
|
130
|
+
if left := self._left:
|
|
131
|
+
return f'Either(left={left.value!r})'
|
|
132
|
+
else:
|
|
133
|
+
return f'Either(right={self.right!r})'
|
|
134
|
+
|
|
130
135
|
|
|
131
136
|
@asynccontextmanager
|
|
132
137
|
async def group_by_temporal(
|
|
@@ -218,7 +223,7 @@ async def group_by_temporal(
|
|
|
218
223
|
|
|
219
224
|
try:
|
|
220
225
|
yield async_iter_groups()
|
|
221
|
-
finally:
|
|
226
|
+
finally: # pragma: no cover
|
|
222
227
|
# after iteration if a tasks still exists, cancel it, this will only happen if an error occurred
|
|
223
228
|
if task:
|
|
224
229
|
task.cancel('Cancelling due to error in iterator')
|
pydantic_ai/agent.py
CHANGED
|
@@ -22,7 +22,17 @@ from . import (
|
|
|
22
22
|
result,
|
|
23
23
|
)
|
|
24
24
|
from .result import ResultData
|
|
25
|
-
from .tools import
|
|
25
|
+
from .tools import (
|
|
26
|
+
AgentDeps,
|
|
27
|
+
RunContext,
|
|
28
|
+
Tool,
|
|
29
|
+
ToolDefinition,
|
|
30
|
+
ToolFuncContext,
|
|
31
|
+
ToolFuncEither,
|
|
32
|
+
ToolFuncPlain,
|
|
33
|
+
ToolParams,
|
|
34
|
+
ToolPrepareFunc,
|
|
35
|
+
)
|
|
26
36
|
|
|
27
37
|
__all__ = ('Agent',)
|
|
28
38
|
|
|
@@ -136,7 +146,10 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
136
146
|
self._function_tools = {}
|
|
137
147
|
self._default_retries = retries
|
|
138
148
|
for tool in tools:
|
|
139
|
-
|
|
149
|
+
if isinstance(tool, Tool):
|
|
150
|
+
self._register_tool(tool)
|
|
151
|
+
else:
|
|
152
|
+
self._register_tool(Tool(tool))
|
|
140
153
|
self._deps_type = deps_type
|
|
141
154
|
self._system_prompt_functions = []
|
|
142
155
|
self._max_result_retries = result_retries if result_retries is not None else retries
|
|
@@ -166,7 +179,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
166
179
|
"""
|
|
167
180
|
if infer_name and self.name is None:
|
|
168
181
|
self._infer_name(inspect.currentframe())
|
|
169
|
-
model_used,
|
|
182
|
+
model_used, mode_selection = await self._get_model(model)
|
|
170
183
|
|
|
171
184
|
deps = self._get_deps(deps)
|
|
172
185
|
|
|
@@ -174,7 +187,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
174
187
|
'{agent_name} run {prompt=}',
|
|
175
188
|
prompt=user_prompt,
|
|
176
189
|
agent=self,
|
|
177
|
-
|
|
190
|
+
mode_selection=mode_selection,
|
|
178
191
|
model_name=model_used.name(),
|
|
179
192
|
agent_name=self.name or 'agent',
|
|
180
193
|
) as run_span:
|
|
@@ -182,14 +195,17 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
182
195
|
self.last_run_messages = messages
|
|
183
196
|
|
|
184
197
|
for tool in self._function_tools.values():
|
|
185
|
-
tool.
|
|
198
|
+
tool.current_retry = 0
|
|
186
199
|
|
|
187
200
|
cost = result.Cost()
|
|
188
201
|
|
|
189
202
|
run_step = 0
|
|
190
203
|
while True:
|
|
191
204
|
run_step += 1
|
|
192
|
-
with _logfire.span('model
|
|
205
|
+
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
|
|
206
|
+
agent_model = await self._prepare_model(model_used, deps)
|
|
207
|
+
|
|
208
|
+
with _logfire.span('model request', run_step=run_step) as model_req_span:
|
|
193
209
|
model_response, request_cost = await agent_model.request(messages)
|
|
194
210
|
model_req_span.set_attribute('response', model_response)
|
|
195
211
|
model_req_span.set_attribute('cost', request_cost)
|
|
@@ -198,12 +214,15 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
198
214
|
messages.append(model_response)
|
|
199
215
|
cost += request_cost
|
|
200
216
|
|
|
201
|
-
with _logfire.span('handle model response') as handle_span:
|
|
202
|
-
|
|
217
|
+
with _logfire.span('handle model response', run_step=run_step) as handle_span:
|
|
218
|
+
final_result, response_messages = await self._handle_model_response(model_response, deps)
|
|
203
219
|
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
220
|
+
# Add all messages to the conversation
|
|
221
|
+
messages.extend(response_messages)
|
|
222
|
+
|
|
223
|
+
# Check if we got a final result
|
|
224
|
+
if final_result is not None:
|
|
225
|
+
result_data = final_result.data
|
|
207
226
|
run_span.set_attribute('all_messages', messages)
|
|
208
227
|
run_span.set_attribute('cost', cost)
|
|
209
228
|
handle_span.set_attribute('result', result_data)
|
|
@@ -211,11 +230,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
211
230
|
return result.RunResult(messages, new_message_index, result_data, cost)
|
|
212
231
|
else:
|
|
213
232
|
# continue the conversation
|
|
214
|
-
tool_responses
|
|
215
|
-
|
|
216
|
-
response_msgs = ' '.join(m.role for m in tool_responses)
|
|
233
|
+
handle_span.set_attribute('tool_responses', response_messages)
|
|
234
|
+
response_msgs = ' '.join(r.role for r in response_messages)
|
|
217
235
|
handle_span.message = f'handle model response -> {response_msgs}'
|
|
218
|
-
messages.extend(tool_responses)
|
|
219
236
|
|
|
220
237
|
def run_sync(
|
|
221
238
|
self,
|
|
@@ -273,7 +290,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
273
290
|
# f_back because `asynccontextmanager` adds one frame
|
|
274
291
|
if frame := inspect.currentframe(): # pragma: no branch
|
|
275
292
|
self._infer_name(frame.f_back)
|
|
276
|
-
model_used,
|
|
293
|
+
model_used, mode_selection = await self._get_model(model)
|
|
277
294
|
|
|
278
295
|
deps = self._get_deps(deps)
|
|
279
296
|
|
|
@@ -281,7 +298,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
281
298
|
'{agent_name} run stream {prompt=}',
|
|
282
299
|
prompt=user_prompt,
|
|
283
300
|
agent=self,
|
|
284
|
-
|
|
301
|
+
mode_selection=mode_selection,
|
|
285
302
|
model_name=model_used.name(),
|
|
286
303
|
agent_name=self.name or 'agent',
|
|
287
304
|
) as run_span:
|
|
@@ -289,13 +306,17 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
289
306
|
self.last_run_messages = messages
|
|
290
307
|
|
|
291
308
|
for tool in self._function_tools.values():
|
|
292
|
-
tool.
|
|
309
|
+
tool.current_retry = 0
|
|
293
310
|
|
|
294
311
|
cost = result.Cost()
|
|
295
312
|
|
|
296
313
|
run_step = 0
|
|
297
314
|
while True:
|
|
298
315
|
run_step += 1
|
|
316
|
+
|
|
317
|
+
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
|
|
318
|
+
agent_model = await self._prepare_model(model_used, deps)
|
|
319
|
+
|
|
299
320
|
with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
|
|
300
321
|
async with agent_model.request_stream(messages) as model_response:
|
|
301
322
|
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
|
|
@@ -304,10 +325,16 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
304
325
|
model_req_span.__exit__(None, None, None)
|
|
305
326
|
|
|
306
327
|
with _logfire.span('handle model response') as handle_span:
|
|
307
|
-
|
|
328
|
+
final_result, response_messages = await self._handle_streamed_model_response(
|
|
329
|
+
model_response, deps
|
|
330
|
+
)
|
|
331
|
+
|
|
332
|
+
# Add all messages to the conversation
|
|
333
|
+
messages.extend(response_messages)
|
|
308
334
|
|
|
309
|
-
if
|
|
310
|
-
|
|
335
|
+
# Check if we got a final result
|
|
336
|
+
if final_result is not None:
|
|
337
|
+
result_stream = final_result.data
|
|
311
338
|
run_span.set_attribute('all_messages', messages)
|
|
312
339
|
handle_span.set_attribute('result_type', result_stream.__class__.__name__)
|
|
313
340
|
handle_span.message = 'handle model response -> final result'
|
|
@@ -323,11 +350,10 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
323
350
|
)
|
|
324
351
|
return
|
|
325
352
|
else:
|
|
326
|
-
|
|
327
|
-
handle_span.set_attribute('tool_responses',
|
|
328
|
-
response_msgs = ' '.join(
|
|
353
|
+
# continue the conversation
|
|
354
|
+
handle_span.set_attribute('tool_responses', response_messages)
|
|
355
|
+
response_msgs = ' '.join(r.role for r in response_messages)
|
|
329
356
|
handle_span.message = f'handle model response -> {response_msgs}'
|
|
330
|
-
messages.extend(tool_responses)
|
|
331
357
|
# the model_response should have been fully streamed by now, we can add it's cost
|
|
332
358
|
cost += model_response.cost()
|
|
333
359
|
|
|
@@ -477,7 +503,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
477
503
|
|
|
478
504
|
@overload
|
|
479
505
|
def tool(
|
|
480
|
-
self,
|
|
506
|
+
self,
|
|
507
|
+
/,
|
|
508
|
+
*,
|
|
509
|
+
retries: int | None = None,
|
|
510
|
+
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
481
511
|
) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
|
|
482
512
|
|
|
483
513
|
def tool(
|
|
@@ -486,9 +516,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
486
516
|
/,
|
|
487
517
|
*,
|
|
488
518
|
retries: int | None = None,
|
|
519
|
+
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
489
520
|
) -> Any:
|
|
490
|
-
"""Decorator to register a tool function which takes
|
|
491
|
-
[`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
521
|
+
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
492
522
|
|
|
493
523
|
Can decorate a sync or async functions.
|
|
494
524
|
|
|
@@ -521,20 +551,23 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
521
551
|
func: The tool function to register.
|
|
522
552
|
retries: The number of retries to allow for this tool, defaults to the agent's default retries,
|
|
523
553
|
which defaults to 1.
|
|
524
|
-
|
|
554
|
+
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
555
|
+
tool from a given step. This is useful if you want to customise a tool at call time,
|
|
556
|
+
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
|
|
557
|
+
"""
|
|
525
558
|
if func is None:
|
|
526
559
|
|
|
527
560
|
def tool_decorator(
|
|
528
561
|
func_: ToolFuncContext[AgentDeps, ToolParams],
|
|
529
562
|
) -> ToolFuncContext[AgentDeps, ToolParams]:
|
|
530
563
|
# noinspection PyTypeChecker
|
|
531
|
-
self._register_function(func_, True, retries)
|
|
564
|
+
self._register_function(func_, True, retries, prepare)
|
|
532
565
|
return func_
|
|
533
566
|
|
|
534
567
|
return tool_decorator
|
|
535
568
|
else:
|
|
536
569
|
# noinspection PyTypeChecker
|
|
537
|
-
self._register_function(func, True, retries)
|
|
570
|
+
self._register_function(func, True, retries, prepare)
|
|
538
571
|
return func
|
|
539
572
|
|
|
540
573
|
@overload
|
|
@@ -542,10 +575,21 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
542
575
|
|
|
543
576
|
@overload
|
|
544
577
|
def tool_plain(
|
|
545
|
-
self,
|
|
578
|
+
self,
|
|
579
|
+
/,
|
|
580
|
+
*,
|
|
581
|
+
retries: int | None = None,
|
|
582
|
+
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
546
583
|
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
|
|
547
584
|
|
|
548
|
-
def tool_plain(
|
|
585
|
+
def tool_plain(
|
|
586
|
+
self,
|
|
587
|
+
func: ToolFuncPlain[ToolParams] | None = None,
|
|
588
|
+
/,
|
|
589
|
+
*,
|
|
590
|
+
retries: int | None = None,
|
|
591
|
+
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
592
|
+
) -> Any:
|
|
549
593
|
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
|
|
550
594
|
|
|
551
595
|
Can decorate a sync or async functions.
|
|
@@ -579,30 +623,38 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
579
623
|
func: The tool function to register.
|
|
580
624
|
retries: The number of retries to allow for this tool, defaults to the agent's default retries,
|
|
581
625
|
which defaults to 1.
|
|
626
|
+
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
627
|
+
tool from a given step. This is useful if you want to customise a tool at call time,
|
|
628
|
+
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
|
|
582
629
|
"""
|
|
583
630
|
if func is None:
|
|
584
631
|
|
|
585
632
|
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
|
|
586
633
|
# noinspection PyTypeChecker
|
|
587
|
-
self._register_function(func_, False, retries)
|
|
634
|
+
self._register_function(func_, False, retries, prepare)
|
|
588
635
|
return func_
|
|
589
636
|
|
|
590
637
|
return tool_decorator
|
|
591
638
|
else:
|
|
592
|
-
self._register_function(func, False, retries)
|
|
639
|
+
self._register_function(func, False, retries, prepare)
|
|
593
640
|
return func
|
|
594
641
|
|
|
595
642
|
def _register_function(
|
|
596
|
-
self,
|
|
643
|
+
self,
|
|
644
|
+
func: ToolFuncEither[AgentDeps, ToolParams],
|
|
645
|
+
takes_ctx: bool,
|
|
646
|
+
retries: int | None,
|
|
647
|
+
prepare: ToolPrepareFunc[AgentDeps] | None,
|
|
597
648
|
) -> None:
|
|
598
649
|
"""Private utility to register a function as a tool."""
|
|
599
650
|
retries_ = retries if retries is not None else self._default_retries
|
|
600
|
-
tool = Tool(func, takes_ctx, max_retries=retries_)
|
|
651
|
+
tool = Tool(func, takes_ctx=takes_ctx, max_retries=retries_, prepare=prepare)
|
|
601
652
|
self._register_tool(tool)
|
|
602
653
|
|
|
603
654
|
def _register_tool(self, tool: Tool[AgentDeps]) -> None:
|
|
604
655
|
"""Private utility to register a tool instance."""
|
|
605
656
|
if tool.max_retries is None:
|
|
657
|
+
# noinspection PyTypeChecker
|
|
606
658
|
tool = dataclasses.replace(tool, max_retries=self._default_retries)
|
|
607
659
|
|
|
608
660
|
if tool.name in self._function_tools:
|
|
@@ -613,16 +665,14 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
613
665
|
|
|
614
666
|
self._function_tools[tool.name] = tool
|
|
615
667
|
|
|
616
|
-
async def
|
|
617
|
-
self, model: models.Model | models.KnownModelName | None
|
|
618
|
-
) -> tuple[models.Model, models.Model | None, models.AgentModel]:
|
|
668
|
+
async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
|
|
619
669
|
"""Create a model configured for this agent.
|
|
620
670
|
|
|
621
671
|
Args:
|
|
622
672
|
model: model to use for this run, required if `model` was not set when creating the agent.
|
|
623
673
|
|
|
624
674
|
Returns:
|
|
625
|
-
a tuple of `(model used,
|
|
675
|
+
a tuple of `(model used, how the model was selected)`
|
|
626
676
|
"""
|
|
627
677
|
model_: models.Model
|
|
628
678
|
if some_model := self._override_model:
|
|
@@ -633,19 +683,35 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
633
683
|
'(Even when `override(model=...)` is customizing the model that will actually be called)'
|
|
634
684
|
)
|
|
635
685
|
model_ = some_model.value
|
|
636
|
-
|
|
686
|
+
mode_selection = 'override-model'
|
|
637
687
|
elif model is not None:
|
|
638
|
-
|
|
688
|
+
model_ = models.infer_model(model)
|
|
689
|
+
mode_selection = 'custom'
|
|
639
690
|
elif self.model is not None:
|
|
640
691
|
# noinspection PyTypeChecker
|
|
641
692
|
model_ = self.model = models.infer_model(self.model)
|
|
642
|
-
|
|
693
|
+
mode_selection = 'from-agent'
|
|
643
694
|
else:
|
|
644
695
|
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
|
|
645
696
|
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
697
|
+
return model_, mode_selection
|
|
698
|
+
|
|
699
|
+
async def _prepare_model(self, model: models.Model, deps: AgentDeps) -> models.AgentModel:
|
|
700
|
+
"""Create building tools and create an agent model."""
|
|
701
|
+
function_tools: list[ToolDefinition] = []
|
|
702
|
+
|
|
703
|
+
async def add_tool(tool: Tool[AgentDeps]) -> None:
|
|
704
|
+
ctx = RunContext(deps, tool.current_retry, tool.name)
|
|
705
|
+
if tool_def := await tool.prepare_tool_def(ctx):
|
|
706
|
+
function_tools.append(tool_def)
|
|
707
|
+
|
|
708
|
+
await asyncio.gather(*map(add_tool, self._function_tools.values()))
|
|
709
|
+
|
|
710
|
+
return await model.agent_model(
|
|
711
|
+
function_tools=function_tools,
|
|
712
|
+
allow_text_result=self._allow_text_result,
|
|
713
|
+
result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
|
|
714
|
+
)
|
|
649
715
|
|
|
650
716
|
async def _prepare_messages(
|
|
651
717
|
self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
|
|
@@ -665,11 +731,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
665
731
|
|
|
666
732
|
async def _handle_model_response(
|
|
667
733
|
self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
|
|
668
|
-
) -> _MarkFinalResult[ResultData] | list[_messages.Message]:
|
|
734
|
+
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
|
|
669
735
|
"""Process a non-streamed response from the model.
|
|
670
736
|
|
|
671
737
|
Returns:
|
|
672
|
-
|
|
738
|
+
A tuple of `(final_result, messages)`. If `final_result` is not `None`, the conversation should end.
|
|
673
739
|
"""
|
|
674
740
|
if model_response.role == 'model-text-response':
|
|
675
741
|
# plain string response
|
|
@@ -679,15 +745,15 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
679
745
|
result_data = await self._validate_result(result_data_input, deps, None)
|
|
680
746
|
except _result.ToolRetryError as e:
|
|
681
747
|
self._incr_result_retry()
|
|
682
|
-
return [e.tool_retry]
|
|
748
|
+
return None, [e.tool_retry]
|
|
683
749
|
else:
|
|
684
|
-
return _MarkFinalResult(result_data)
|
|
750
|
+
return _MarkFinalResult(result_data), []
|
|
685
751
|
else:
|
|
686
752
|
self._incr_result_retry()
|
|
687
753
|
response = _messages.RetryPrompt(
|
|
688
754
|
content='Plain text responses are not permitted, please call one of the functions instead.',
|
|
689
755
|
)
|
|
690
|
-
return [response]
|
|
756
|
+
return None, [response]
|
|
691
757
|
elif model_response.role == 'model-structured-response':
|
|
692
758
|
if self._result_schema is not None:
|
|
693
759
|
# if there's a result schema, and any of the calls match one of its tools, return the result
|
|
@@ -699,9 +765,15 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
699
765
|
result_data = await self._validate_result(result_data, deps, call)
|
|
700
766
|
except _result.ToolRetryError as e:
|
|
701
767
|
self._incr_result_retry()
|
|
702
|
-
return [e.tool_retry]
|
|
768
|
+
return None, [e.tool_retry]
|
|
703
769
|
else:
|
|
704
|
-
|
|
770
|
+
# Add a ToolReturn message for the schema tool call
|
|
771
|
+
tool_return = _messages.ToolReturn(
|
|
772
|
+
tool_name=call.tool_name,
|
|
773
|
+
content='Final result processed.',
|
|
774
|
+
tool_id=call.tool_id,
|
|
775
|
+
)
|
|
776
|
+
return _MarkFinalResult(result_data), [tool_return]
|
|
705
777
|
|
|
706
778
|
if not model_response.calls:
|
|
707
779
|
raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
|
|
@@ -716,26 +788,24 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
716
788
|
messages.append(self._unknown_tool(call.tool_name))
|
|
717
789
|
|
|
718
790
|
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
|
|
719
|
-
|
|
720
|
-
|
|
791
|
+
task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
|
|
792
|
+
messages.extend(task_results)
|
|
793
|
+
return None, messages
|
|
721
794
|
else:
|
|
722
795
|
assert_never(model_response)
|
|
723
796
|
|
|
724
797
|
async def _handle_streamed_model_response(
|
|
725
798
|
self, model_response: models.EitherStreamedResponse, deps: AgentDeps
|
|
726
|
-
) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]:
|
|
799
|
+
) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.Message]]:
|
|
727
800
|
"""Process a streamed response from the model.
|
|
728
801
|
|
|
729
|
-
TODO: change the response type to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
|
|
730
|
-
(with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)
|
|
731
|
-
|
|
732
802
|
Returns:
|
|
733
|
-
|
|
803
|
+
A tuple of (final_result, messages). If final_result is not None, the conversation should end.
|
|
734
804
|
"""
|
|
735
805
|
if isinstance(model_response, models.StreamTextResponse):
|
|
736
806
|
# plain string response
|
|
737
807
|
if self._allow_text_result:
|
|
738
|
-
return _MarkFinalResult(model_response)
|
|
808
|
+
return _MarkFinalResult(model_response), []
|
|
739
809
|
else:
|
|
740
810
|
self._incr_result_retry()
|
|
741
811
|
response = _messages.RetryPrompt(
|
|
@@ -745,7 +815,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
745
815
|
async for _ in model_response:
|
|
746
816
|
pass
|
|
747
817
|
|
|
748
|
-
return [response]
|
|
818
|
+
return None, [response]
|
|
749
819
|
else:
|
|
750
820
|
assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
|
|
751
821
|
if self._result_schema is not None:
|
|
@@ -759,8 +829,14 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
759
829
|
break
|
|
760
830
|
structured_msg = model_response.get()
|
|
761
831
|
|
|
762
|
-
if self._result_schema.find_tool(structured_msg):
|
|
763
|
-
|
|
832
|
+
if match := self._result_schema.find_tool(structured_msg):
|
|
833
|
+
call, _ = match
|
|
834
|
+
tool_return = _messages.ToolReturn(
|
|
835
|
+
tool_name=call.tool_name,
|
|
836
|
+
content='Final result processed.',
|
|
837
|
+
tool_id=call.tool_id,
|
|
838
|
+
)
|
|
839
|
+
return _MarkFinalResult(model_response), [tool_return]
|
|
764
840
|
|
|
765
841
|
# the model is calling a tool function, consume the response to get the next message
|
|
766
842
|
async for _ in model_response:
|
|
@@ -779,8 +855,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
779
855
|
messages.append(self._unknown_tool(call.tool_name))
|
|
780
856
|
|
|
781
857
|
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
|
|
782
|
-
|
|
783
|
-
|
|
858
|
+
task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
|
|
859
|
+
messages.extend(task_results)
|
|
860
|
+
return None, messages
|
|
784
861
|
|
|
785
862
|
async def _validate_result(
|
|
786
863
|
self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
|
|
@@ -852,6 +929,8 @@ class _MarkFinalResult(Generic[ResultData]):
|
|
|
852
929
|
"""Marker class to indicate that the result is the final result.
|
|
853
930
|
|
|
854
931
|
This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
|
|
932
|
+
|
|
933
|
+
It also avoids problems in the case where the result type is itself `None`, but is set.
|
|
855
934
|
"""
|
|
856
935
|
|
|
857
936
|
data: ResultData
|