pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_a2a.py +3 -3
- pydantic_ai/_agent_graph.py +220 -319
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +295 -331
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +378 -164
- pydantic_ai/exceptions.py +12 -0
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/format_prompt.py +3 -6
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +13 -5
- pydantic_ai/models/__init__.py +30 -18
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +1 -18
- pydantic_ai/models/google.py +2 -11
- pydantic_ai/models/groq.py +1 -0
- pydantic_ai/models/instrumented.py +6 -1
- pydantic_ai/models/mistral.py +1 -1
- pydantic_ai/models/openai.py +16 -4
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/result.py +58 -45
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/METADATA +10 -7
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/RECORD +48 -35
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/tools.py
CHANGED
|
@@ -1,20 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
import dataclasses
|
|
4
|
-
import json
|
|
5
3
|
from collections.abc import Awaitable, Sequence
|
|
6
4
|
from dataclasses import dataclass, field
|
|
7
5
|
from typing import Any, Callable, Generic, Literal, Union
|
|
8
6
|
|
|
9
|
-
from opentelemetry.trace import Tracer
|
|
10
|
-
from pydantic import ValidationError
|
|
11
7
|
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
|
|
12
8
|
from pydantic_core import SchemaValidator, core_schema
|
|
13
9
|
from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar
|
|
14
10
|
|
|
15
|
-
from . import _function_schema, _utils
|
|
11
|
+
from . import _function_schema, _utils
|
|
16
12
|
from ._run_context import AgentDepsT, RunContext
|
|
17
|
-
from .exceptions import ModelRetry, UnexpectedModelBehavior
|
|
18
13
|
|
|
19
14
|
__all__ = (
|
|
20
15
|
'AgentDepsT',
|
|
@@ -32,7 +27,6 @@ __all__ = (
|
|
|
32
27
|
'ToolDefinition',
|
|
33
28
|
)
|
|
34
29
|
|
|
35
|
-
from .messages import ToolReturnPart
|
|
36
30
|
|
|
37
31
|
ToolParams = ParamSpec('ToolParams', default=...)
|
|
38
32
|
"""Retrieval function param spec."""
|
|
@@ -173,12 +167,6 @@ class Tool(Generic[AgentDepsT]):
|
|
|
173
167
|
This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request.
|
|
174
168
|
"""
|
|
175
169
|
|
|
176
|
-
# TODO: Consider moving this current_retry state to live on something other than the tool.
|
|
177
|
-
# We've worked around this for now by copying instances of the tool when creating new runs,
|
|
178
|
-
# but this is a bit fragile. Moving the tool retry counts to live on the agent run state would likely clean things
|
|
179
|
-
# up, though is also likely a larger effort to refactor.
|
|
180
|
-
current_retry: int = field(default=0, init=False)
|
|
181
|
-
|
|
182
170
|
def __init__(
|
|
183
171
|
self,
|
|
184
172
|
function: ToolFuncEither[AgentDepsT],
|
|
@@ -303,6 +291,15 @@ class Tool(Generic[AgentDepsT]):
|
|
|
303
291
|
function_schema=function_schema,
|
|
304
292
|
)
|
|
305
293
|
|
|
294
|
+
@property
|
|
295
|
+
def tool_def(self):
|
|
296
|
+
return ToolDefinition(
|
|
297
|
+
name=self.name,
|
|
298
|
+
description=self.description,
|
|
299
|
+
parameters_json_schema=self.function_schema.json_schema,
|
|
300
|
+
strict=self.strict,
|
|
301
|
+
)
|
|
302
|
+
|
|
306
303
|
async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
|
|
307
304
|
"""Get the tool definition.
|
|
308
305
|
|
|
@@ -312,113 +309,11 @@ class Tool(Generic[AgentDepsT]):
|
|
|
312
309
|
Returns:
|
|
313
310
|
return a `ToolDefinition` or `None` if the tools should not be registered for this run.
|
|
314
311
|
"""
|
|
315
|
-
|
|
316
|
-
name=self.name,
|
|
317
|
-
description=self.description,
|
|
318
|
-
parameters_json_schema=self.function_schema.json_schema,
|
|
319
|
-
strict=self.strict,
|
|
320
|
-
)
|
|
312
|
+
base_tool_def = self.tool_def
|
|
321
313
|
if self.prepare is not None:
|
|
322
|
-
return await self.prepare(ctx,
|
|
314
|
+
return await self.prepare(ctx, base_tool_def)
|
|
323
315
|
else:
|
|
324
|
-
return
|
|
325
|
-
|
|
326
|
-
async def run(
|
|
327
|
-
self,
|
|
328
|
-
message: _messages.ToolCallPart,
|
|
329
|
-
run_context: RunContext[AgentDepsT],
|
|
330
|
-
tracer: Tracer,
|
|
331
|
-
include_content: bool = False,
|
|
332
|
-
) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
|
|
333
|
-
"""Run the tool function asynchronously.
|
|
334
|
-
|
|
335
|
-
This method wraps `_run` in an OpenTelemetry span.
|
|
336
|
-
|
|
337
|
-
See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>.
|
|
338
|
-
"""
|
|
339
|
-
span_attributes = {
|
|
340
|
-
'gen_ai.tool.name': self.name,
|
|
341
|
-
# NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
|
|
342
|
-
'gen_ai.tool.call.id': message.tool_call_id,
|
|
343
|
-
**({'tool_arguments': message.args_as_json_str()} if include_content else {}),
|
|
344
|
-
'logfire.msg': f'running tool: {self.name}',
|
|
345
|
-
# add the JSON schema so these attributes are formatted nicely in Logfire
|
|
346
|
-
'logfire.json_schema': json.dumps(
|
|
347
|
-
{
|
|
348
|
-
'type': 'object',
|
|
349
|
-
'properties': {
|
|
350
|
-
**(
|
|
351
|
-
{
|
|
352
|
-
'tool_arguments': {'type': 'object'},
|
|
353
|
-
'tool_response': {'type': 'object'},
|
|
354
|
-
}
|
|
355
|
-
if include_content
|
|
356
|
-
else {}
|
|
357
|
-
),
|
|
358
|
-
'gen_ai.tool.name': {},
|
|
359
|
-
'gen_ai.tool.call.id': {},
|
|
360
|
-
},
|
|
361
|
-
}
|
|
362
|
-
),
|
|
363
|
-
}
|
|
364
|
-
with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
|
|
365
|
-
response = await self._run(message, run_context)
|
|
366
|
-
if include_content and span.is_recording():
|
|
367
|
-
span.set_attribute(
|
|
368
|
-
'tool_response',
|
|
369
|
-
response.model_response_str()
|
|
370
|
-
if isinstance(response, ToolReturnPart)
|
|
371
|
-
else response.model_response(),
|
|
372
|
-
)
|
|
373
|
-
|
|
374
|
-
return response
|
|
375
|
-
|
|
376
|
-
async def _run(
|
|
377
|
-
self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
|
|
378
|
-
) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
|
|
379
|
-
try:
|
|
380
|
-
validator = self.function_schema.validator
|
|
381
|
-
if isinstance(message.args, str):
|
|
382
|
-
args_dict = validator.validate_json(message.args or '{}')
|
|
383
|
-
else:
|
|
384
|
-
args_dict = validator.validate_python(message.args or {})
|
|
385
|
-
except ValidationError as e:
|
|
386
|
-
return self._on_error(e, message)
|
|
387
|
-
|
|
388
|
-
ctx = dataclasses.replace(
|
|
389
|
-
run_context,
|
|
390
|
-
retry=self.current_retry,
|
|
391
|
-
tool_name=message.tool_name,
|
|
392
|
-
tool_call_id=message.tool_call_id,
|
|
393
|
-
)
|
|
394
|
-
try:
|
|
395
|
-
response_content = await self.function_schema.call(args_dict, ctx)
|
|
396
|
-
except ModelRetry as e:
|
|
397
|
-
return self._on_error(e, message)
|
|
398
|
-
|
|
399
|
-
self.current_retry = 0
|
|
400
|
-
return _messages.ToolReturnPart(
|
|
401
|
-
tool_name=message.tool_name,
|
|
402
|
-
content=response_content,
|
|
403
|
-
tool_call_id=message.tool_call_id,
|
|
404
|
-
)
|
|
405
|
-
|
|
406
|
-
def _on_error(
|
|
407
|
-
self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
|
|
408
|
-
) -> _messages.RetryPromptPart:
|
|
409
|
-
self.current_retry += 1
|
|
410
|
-
if self.max_retries is None or self.current_retry > self.max_retries:
|
|
411
|
-
raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
|
|
412
|
-
else:
|
|
413
|
-
if isinstance(exc, ValidationError):
|
|
414
|
-
content = exc.errors(include_url=False, include_context=False)
|
|
415
|
-
else:
|
|
416
|
-
content = exc.message
|
|
417
|
-
return _messages.RetryPromptPart(
|
|
418
|
-
tool_name=call_message.tool_name,
|
|
419
|
-
content=content,
|
|
420
|
-
tool_call_id=call_message.tool_call_id,
|
|
421
|
-
)
|
|
316
|
+
return base_tool_def
|
|
422
317
|
|
|
423
318
|
|
|
424
319
|
ObjectJsonSchema: TypeAlias = dict[str, Any]
|
|
@@ -429,6 +324,9 @@ This type is used to define tools parameters (aka arguments) in [ToolDefinition]
|
|
|
429
324
|
With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any`
|
|
430
325
|
"""
|
|
431
326
|
|
|
327
|
+
ToolKind: TypeAlias = Literal['function', 'output', 'deferred']
|
|
328
|
+
"""Kind of tool."""
|
|
329
|
+
|
|
432
330
|
|
|
433
331
|
@dataclass(repr=False)
|
|
434
332
|
class ToolDefinition:
|
|
@@ -440,7 +338,7 @@ class ToolDefinition:
|
|
|
440
338
|
name: str
|
|
441
339
|
"""The name of the tool."""
|
|
442
340
|
|
|
443
|
-
parameters_json_schema: ObjectJsonSchema
|
|
341
|
+
parameters_json_schema: ObjectJsonSchema = field(default_factory=lambda: {'type': 'object', 'properties': {}})
|
|
444
342
|
"""The JSON schema for the tool's parameters."""
|
|
445
343
|
|
|
446
344
|
description: str | None = None
|
|
@@ -464,4 +362,13 @@ class ToolDefinition:
|
|
|
464
362
|
Note: this is currently only supported by OpenAI models.
|
|
465
363
|
"""
|
|
466
364
|
|
|
365
|
+
kind: ToolKind = field(default='function')
|
|
366
|
+
"""The kind of tool:
|
|
367
|
+
|
|
368
|
+
- `'function'`: a tool that will be executed by Pydantic AI during an agent run and has its result returned to the model
|
|
369
|
+
- `'output'`: a tool that passes through an output value that ends the run
|
|
370
|
+
- `'deferred'`: a tool whose result will be produced outside of the Pydantic AI agent run in which it was called, because it depends on an upstream service (or user) or could take longer to generate than it's reasonable to keep the agent process running.
|
|
371
|
+
When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s corresponding to each deferred call.
|
|
372
|
+
"""
|
|
373
|
+
|
|
467
374
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from .abstract import AbstractToolset, ToolsetTool
|
|
2
|
+
from .combined import CombinedToolset
|
|
3
|
+
from .deferred import DeferredToolset
|
|
4
|
+
from .filtered import FilteredToolset
|
|
5
|
+
from .function import FunctionToolset
|
|
6
|
+
from .prefixed import PrefixedToolset
|
|
7
|
+
from .prepared import PreparedToolset
|
|
8
|
+
from .renamed import RenamedToolset
|
|
9
|
+
from .wrapper import WrapperToolset
|
|
10
|
+
|
|
11
|
+
__all__ = (
|
|
12
|
+
'AbstractToolset',
|
|
13
|
+
'ToolsetTool',
|
|
14
|
+
'CombinedToolset',
|
|
15
|
+
'DeferredToolset',
|
|
16
|
+
'FilteredToolset',
|
|
17
|
+
'FunctionToolset',
|
|
18
|
+
'PrefixedToolset',
|
|
19
|
+
'RenamedToolset',
|
|
20
|
+
'PreparedToolset',
|
|
21
|
+
'WrapperToolset',
|
|
22
|
+
)
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol
|
|
6
|
+
|
|
7
|
+
from pydantic_core import SchemaValidator
|
|
8
|
+
from typing_extensions import Self
|
|
9
|
+
|
|
10
|
+
from .._run_context import AgentDepsT, RunContext
|
|
11
|
+
from ..tools import ToolDefinition, ToolsPrepareFunc
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from .filtered import FilteredToolset
|
|
15
|
+
from .prefixed import PrefixedToolset
|
|
16
|
+
from .prepared import PreparedToolset
|
|
17
|
+
from .renamed import RenamedToolset
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SchemaValidatorProt(Protocol):
|
|
21
|
+
"""Protocol for a Pydantic Core `SchemaValidator` or `PluggableSchemaValidator` (which is private but API-compatible)."""
|
|
22
|
+
|
|
23
|
+
def validate_json(
|
|
24
|
+
self,
|
|
25
|
+
input: str | bytes | bytearray,
|
|
26
|
+
*,
|
|
27
|
+
allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
|
|
28
|
+
**kwargs: Any,
|
|
29
|
+
) -> Any: ...
|
|
30
|
+
|
|
31
|
+
def validate_python(
|
|
32
|
+
self, input: Any, *, allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, **kwargs: Any
|
|
33
|
+
) -> Any: ...
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
|
|
37
|
+
class ToolsetTool(Generic[AgentDepsT]):
|
|
38
|
+
"""Definition of a tool available on a toolset.
|
|
39
|
+
|
|
40
|
+
This is a wrapper around a plain tool definition that includes information about:
|
|
41
|
+
|
|
42
|
+
- the toolset that provided it, for use in error messages
|
|
43
|
+
- the maximum number of retries to attempt if the tool call fails
|
|
44
|
+
- the validator for the tool's arguments
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
toolset: AbstractToolset[AgentDepsT]
|
|
48
|
+
"""The toolset that provided this tool, for use in error messages."""
|
|
49
|
+
tool_def: ToolDefinition
|
|
50
|
+
"""The tool definition for this tool, including the name, description, and parameters."""
|
|
51
|
+
max_retries: int
|
|
52
|
+
"""The maximum number of retries to attempt if the tool call fails."""
|
|
53
|
+
args_validator: SchemaValidator | SchemaValidatorProt
|
|
54
|
+
"""The Pydantic Core validator for the tool's arguments.
|
|
55
|
+
|
|
56
|
+
For example, a [`pydantic.TypeAdapter(...).validator`](https://docs.pydantic.dev/latest/concepts/type_adapter/) or [`pydantic_core.SchemaValidator`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.SchemaValidator).
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class AbstractToolset(ABC, Generic[AgentDepsT]):
|
|
61
|
+
"""A toolset is a collection of tools that can be used by an agent.
|
|
62
|
+
|
|
63
|
+
It is responsible for:
|
|
64
|
+
|
|
65
|
+
- Listing the tools it contains
|
|
66
|
+
- Validating the arguments of the tools
|
|
67
|
+
- Calling the tools
|
|
68
|
+
|
|
69
|
+
See [toolset docs](../toolsets.md) for more information.
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
@property
|
|
73
|
+
def name(self) -> str:
|
|
74
|
+
"""The name of the toolset for use in error messages."""
|
|
75
|
+
return self.__class__.__name__.replace('Toolset', ' toolset')
|
|
76
|
+
|
|
77
|
+
@property
|
|
78
|
+
def tool_name_conflict_hint(self) -> str:
|
|
79
|
+
"""A hint for how to avoid name conflicts with other toolsets for use in error messages."""
|
|
80
|
+
return 'Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.'
|
|
81
|
+
|
|
82
|
+
async def __aenter__(self) -> Self:
|
|
83
|
+
"""Enter the toolset context.
|
|
84
|
+
|
|
85
|
+
This is where you can set up network connections in a concrete implementation.
|
|
86
|
+
"""
|
|
87
|
+
return self
|
|
88
|
+
|
|
89
|
+
async def __aexit__(self, *args: Any) -> bool | None:
|
|
90
|
+
"""Exit the toolset context.
|
|
91
|
+
|
|
92
|
+
This is where you can tear down network connections in a concrete implementation.
|
|
93
|
+
"""
|
|
94
|
+
return None
|
|
95
|
+
|
|
96
|
+
@abstractmethod
|
|
97
|
+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
|
|
98
|
+
"""The tools that are available in this toolset."""
|
|
99
|
+
raise NotImplementedError()
|
|
100
|
+
|
|
101
|
+
@abstractmethod
|
|
102
|
+
async def call_tool(
|
|
103
|
+
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
|
|
104
|
+
) -> Any:
|
|
105
|
+
"""Call a tool with the given arguments.
|
|
106
|
+
|
|
107
|
+
Args:
|
|
108
|
+
name: The name of the tool to call.
|
|
109
|
+
tool_args: The arguments to pass to the tool.
|
|
110
|
+
ctx: The run context.
|
|
111
|
+
tool: The tool definition returned by [`get_tools`][pydantic_ai.toolsets.AbstractToolset.get_tools] that was called.
|
|
112
|
+
"""
|
|
113
|
+
raise NotImplementedError()
|
|
114
|
+
|
|
115
|
+
def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
|
|
116
|
+
"""Run a visitor function on all concrete toolsets that are not wrappers (i.e. they implement their own tool listing and calling)."""
|
|
117
|
+
visitor(self)
|
|
118
|
+
|
|
119
|
+
def filtered(
|
|
120
|
+
self, filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool]
|
|
121
|
+
) -> FilteredToolset[AgentDepsT]:
|
|
122
|
+
"""Returns a new toolset that filters this toolset's tools using a filter function that takes the agent context and the tool definition.
|
|
123
|
+
|
|
124
|
+
See [toolset docs](../toolsets.md#filtering-tools) for more information.
|
|
125
|
+
"""
|
|
126
|
+
from .filtered import FilteredToolset
|
|
127
|
+
|
|
128
|
+
return FilteredToolset(self, filter_func)
|
|
129
|
+
|
|
130
|
+
def prefixed(self, prefix: str) -> PrefixedToolset[AgentDepsT]:
|
|
131
|
+
"""Returns a new toolset that prefixes the names of this toolset's tools.
|
|
132
|
+
|
|
133
|
+
See [toolset docs](../toolsets.md#prefixing-tool-names) for more information.
|
|
134
|
+
"""
|
|
135
|
+
from .prefixed import PrefixedToolset
|
|
136
|
+
|
|
137
|
+
return PrefixedToolset(self, prefix)
|
|
138
|
+
|
|
139
|
+
def prepared(self, prepare_func: ToolsPrepareFunc[AgentDepsT]) -> PreparedToolset[AgentDepsT]:
|
|
140
|
+
"""Returns a new toolset that prepares this toolset's tools using a prepare function that takes the agent context and the original tool definitions.
|
|
141
|
+
|
|
142
|
+
See [toolset docs](../toolsets.md#preparing-tool-definitions) for more information.
|
|
143
|
+
"""
|
|
144
|
+
from .prepared import PreparedToolset
|
|
145
|
+
|
|
146
|
+
return PreparedToolset(self, prepare_func)
|
|
147
|
+
|
|
148
|
+
def renamed(self, name_map: dict[str, str]) -> RenamedToolset[AgentDepsT]:
|
|
149
|
+
"""Returns a new toolset that renames this toolset's tools using a dictionary mapping new names to original names.
|
|
150
|
+
|
|
151
|
+
See [toolset docs](../toolsets.md#renaming-tools) for more information.
|
|
152
|
+
"""
|
|
153
|
+
from .renamed import RenamedToolset
|
|
154
|
+
|
|
155
|
+
return RenamedToolset(self, name_map)
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from contextlib import AsyncExitStack
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Callable
|
|
8
|
+
|
|
9
|
+
from typing_extensions import Self
|
|
10
|
+
|
|
11
|
+
from .._run_context import AgentDepsT, RunContext
|
|
12
|
+
from .._utils import get_async_lock
|
|
13
|
+
from ..exceptions import UserError
|
|
14
|
+
from .abstract import AbstractToolset, ToolsetTool
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
|
|
18
|
+
class _CombinedToolsetTool(ToolsetTool[AgentDepsT]):
|
|
19
|
+
"""A tool definition for a combined toolset tools that keeps track of the source toolset and tool."""
|
|
20
|
+
|
|
21
|
+
source_toolset: AbstractToolset[AgentDepsT]
|
|
22
|
+
source_tool: ToolsetTool[AgentDepsT]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
|
|
26
|
+
class CombinedToolset(AbstractToolset[AgentDepsT]):
|
|
27
|
+
"""A toolset that combines multiple toolsets.
|
|
28
|
+
|
|
29
|
+
See [toolset docs](../toolsets.md#combining-toolsets) for more information.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]]
|
|
33
|
+
|
|
34
|
+
_enter_lock: asyncio.Lock = field(compare=False, init=False)
|
|
35
|
+
_entered_count: int = field(init=False)
|
|
36
|
+
_exit_stack: AsyncExitStack | None = field(init=False)
|
|
37
|
+
|
|
38
|
+
def __post_init__(self):
|
|
39
|
+
self._enter_lock = get_async_lock()
|
|
40
|
+
self._entered_count = 0
|
|
41
|
+
self._exit_stack = None
|
|
42
|
+
|
|
43
|
+
async def __aenter__(self) -> Self:
|
|
44
|
+
async with self._enter_lock:
|
|
45
|
+
if self._entered_count == 0:
|
|
46
|
+
self._exit_stack = AsyncExitStack()
|
|
47
|
+
for toolset in self.toolsets:
|
|
48
|
+
await self._exit_stack.enter_async_context(toolset)
|
|
49
|
+
self._entered_count += 1
|
|
50
|
+
return self
|
|
51
|
+
|
|
52
|
+
async def __aexit__(self, *args: Any) -> bool | None:
|
|
53
|
+
async with self._enter_lock:
|
|
54
|
+
self._entered_count -= 1
|
|
55
|
+
if self._entered_count == 0 and self._exit_stack is not None:
|
|
56
|
+
await self._exit_stack.aclose()
|
|
57
|
+
self._exit_stack = None
|
|
58
|
+
|
|
59
|
+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
|
|
60
|
+
toolsets_tools = await asyncio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets))
|
|
61
|
+
all_tools: dict[str, ToolsetTool[AgentDepsT]] = {}
|
|
62
|
+
|
|
63
|
+
for toolset, tools in zip(self.toolsets, toolsets_tools):
|
|
64
|
+
for name, tool in tools.items():
|
|
65
|
+
if existing_tools := all_tools.get(name):
|
|
66
|
+
raise UserError(
|
|
67
|
+
f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}'
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
all_tools[name] = _CombinedToolsetTool(
|
|
71
|
+
toolset=tool.toolset,
|
|
72
|
+
tool_def=tool.tool_def,
|
|
73
|
+
max_retries=tool.max_retries,
|
|
74
|
+
args_validator=tool.args_validator,
|
|
75
|
+
source_toolset=toolset,
|
|
76
|
+
source_tool=tool,
|
|
77
|
+
)
|
|
78
|
+
return all_tools
|
|
79
|
+
|
|
80
|
+
async def call_tool(
|
|
81
|
+
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
|
|
82
|
+
) -> Any:
|
|
83
|
+
assert isinstance(tool, _CombinedToolsetTool)
|
|
84
|
+
return await tool.source_toolset.call_tool(name, tool_args, ctx, tool.source_tool)
|
|
85
|
+
|
|
86
|
+
def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
|
|
87
|
+
for toolset in self.toolsets:
|
|
88
|
+
toolset.apply(visitor)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, replace
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic_core import SchemaValidator, core_schema
|
|
7
|
+
|
|
8
|
+
from .._run_context import AgentDepsT, RunContext
|
|
9
|
+
from ..tools import ToolDefinition
|
|
10
|
+
from .abstract import AbstractToolset, ToolsetTool
|
|
11
|
+
|
|
12
|
+
TOOL_SCHEMA_VALIDATOR = SchemaValidator(schema=core_schema.any_schema())
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
|
|
16
|
+
class DeferredToolset(AbstractToolset[AgentDepsT]):
|
|
17
|
+
"""A toolset that holds deferred tools whose results will be produced outside of the Pydantic AI agent run in which they were called.
|
|
18
|
+
|
|
19
|
+
See [toolset docs](../toolsets.md#deferred-toolset), [`ToolDefinition.kind`][pydantic_ai.tools.ToolDefinition.kind], and [`DeferredToolCalls`][pydantic_ai.output.DeferredToolCalls] for more information.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
tool_defs: list[ToolDefinition]
|
|
23
|
+
|
|
24
|
+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
|
|
25
|
+
return {
|
|
26
|
+
tool_def.name: ToolsetTool(
|
|
27
|
+
toolset=self,
|
|
28
|
+
tool_def=replace(tool_def, kind='deferred'),
|
|
29
|
+
max_retries=0,
|
|
30
|
+
args_validator=TOOL_SCHEMA_VALIDATOR,
|
|
31
|
+
)
|
|
32
|
+
for tool_def in self.tool_defs
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
async def call_tool(
|
|
36
|
+
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
|
|
37
|
+
) -> Any:
|
|
38
|
+
raise NotImplementedError('Deferred tools cannot be called')
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
from .._run_context import AgentDepsT, RunContext
|
|
7
|
+
from ..tools import ToolDefinition
|
|
8
|
+
from .abstract import ToolsetTool
|
|
9
|
+
from .wrapper import WrapperToolset
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
|
|
13
|
+
class FilteredToolset(WrapperToolset[AgentDepsT]):
|
|
14
|
+
"""A toolset that filters the tools it contains using a filter function that takes the agent context and the tool definition.
|
|
15
|
+
|
|
16
|
+
See [toolset docs](../toolsets.md#filtering-tools) for more information.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool]
|
|
20
|
+
|
|
21
|
+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
|
|
22
|
+
return {
|
|
23
|
+
name: tool for name, tool in (await super().get_tools(ctx)).items() if self.filter_func(ctx, tool.tool_def)
|
|
24
|
+
}
|