pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +219 -315
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +296 -226
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +369 -155
- pydantic_ai/common_tools/duckduckgo.py +5 -2
- pydantic_ai/exceptions.py +14 -2
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +19 -9
- pydantic_ai/models/__init__.py +43 -19
- pydantic_ai/models/anthropic.py +2 -2
- pydantic_ai/models/bedrock.py +1 -1
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +3 -11
- pydantic_ai/models/google.py +3 -12
- pydantic_ai/models/groq.py +2 -1
- pydantic_ai/models/huggingface.py +463 -0
- pydantic_ai/models/instrumented.py +1 -1
- pydantic_ai/models/mistral.py +3 -3
- pydantic_ai/models/openai.py +5 -5
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google.py +2 -2
- pydantic_ai/providers/google_vertex.py +10 -5
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/providers/huggingface.py +88 -0
- pydantic_ai/result.py +57 -33
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
- pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
- pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol
|
|
6
|
+
|
|
7
|
+
from pydantic_core import SchemaValidator
|
|
8
|
+
from typing_extensions import Self
|
|
9
|
+
|
|
10
|
+
from .._run_context import AgentDepsT, RunContext
|
|
11
|
+
from ..tools import ToolDefinition, ToolsPrepareFunc
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from .filtered import FilteredToolset
|
|
15
|
+
from .prefixed import PrefixedToolset
|
|
16
|
+
from .prepared import PreparedToolset
|
|
17
|
+
from .renamed import RenamedToolset
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SchemaValidatorProt(Protocol):
    """Structural type for the validator objects used to check tool arguments.

    It is satisfied by a Pydantic Core `SchemaValidator` as well as by `PluggableSchemaValidator`
    (which is private but API-compatible).
    """

    def validate_json(
        self,
        input: str | bytes | bytearray,
        *,
        allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
        **kwargs: Any,
    ) -> Any: ...

    def validate_python(
        self,
        input: Any,
        *,
        allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
        **kwargs: Any,
    ) -> Any: ...
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class ToolsetTool(Generic[AgentDepsT]):
    """A tool as exposed by a toolset.

    Wraps a plain tool definition and additionally carries:

    - the toolset that provided it, for use in error messages
    - the maximum number of retries to attempt if the tool call fails
    - the validator for the tool's arguments
    """

    toolset: AbstractToolset[AgentDepsT]
    """The toolset this tool came from; used when building error messages."""

    tool_def: ToolDefinition
    """The definition of the tool: its name, description, and parameters."""

    max_retries: int
    """How many times the tool call may be retried if it fails."""

    args_validator: SchemaValidator | SchemaValidatorProt
    """The Pydantic Core validator used for the tool's arguments.

    For example, a [`pydantic.TypeAdapter(...).validator`](https://docs.pydantic.dev/latest/concepts/type_adapter/) or [`pydantic_core.SchemaValidator`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.SchemaValidator).
    """
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class AbstractToolset(ABC, Generic[AgentDepsT]):
    """Base class for a collection of tools that an agent can use.

    A toolset is responsible for:

    - Listing the tools it contains
    - Validating the arguments of the tools
    - Calling the tools

    See [toolset docs](../toolsets.md) for more information.
    """

    @property
    def name(self) -> str:
        """Human-readable toolset name, derived from the class name, for use in error messages."""
        class_name = self.__class__.__name__
        return class_name.replace('Toolset', ' toolset')

    @property
    def tool_name_conflict_hint(self) -> str:
        """Advice, for use in error messages, on resolving a tool name clash with another toolset."""
        return 'Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.'

    async def __aenter__(self) -> Self:
        """Enter the toolset context.

        Concrete implementations can override this to set up network connections.
        The base implementation does nothing and returns the toolset itself.
        """
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        """Exit the toolset context.

        Concrete implementations can override this to tear down network connections.
        The base implementation does nothing and returns `None`.
        """
        return None

    @abstractmethod
    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return the tools available in this toolset, keyed by tool name."""
        raise NotImplementedError()

    @abstractmethod
    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> Any:
        """Call a tool with the given arguments.

        Args:
            name: The name of the tool to call.
            tool_args: The arguments to pass to the tool.
            ctx: The run context.
            tool: The tool definition returned by [`get_tools`][pydantic_ai.toolsets.AbstractToolset.get_tools] that was called.
        """
        raise NotImplementedError()

    def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
        """Run a visitor function on all concrete toolsets that are not wrappers (i.e. they implement their own tool listing and calling).

        The base implementation simply calls the visitor with this toolset.
        """
        visitor(self)

    def filtered(
        self, filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool]
    ) -> FilteredToolset[AgentDepsT]:
        """Wrap this toolset in a `FilteredToolset` built from a filter function that takes the agent context and the tool definition.

        See [toolset docs](../toolsets.md#filtering-tools) for more information.
        """
        # Imported here rather than at module level; the module-level import is only done under `TYPE_CHECKING`.
        from .filtered import FilteredToolset

        return FilteredToolset(self, filter_func)

    def prefixed(self, prefix: str) -> PrefixedToolset[AgentDepsT]:
        """Wrap this toolset in a `PrefixedToolset` that prefixes the names of its tools.

        See [toolset docs](../toolsets.md#prefixing-tool-names) for more information.
        """
        from .prefixed import PrefixedToolset

        return PrefixedToolset(self, prefix)

    def prepared(self, prepare_func: ToolsPrepareFunc[AgentDepsT]) -> PreparedToolset[AgentDepsT]:
        """Wrap this toolset in a `PreparedToolset` built from a prepare function that takes the agent context and the original tool definitions.

        See [toolset docs](../toolsets.md#preparing-tool-definitions) for more information.
        """
        from .prepared import PreparedToolset

        return PreparedToolset(self, prepare_func)

    def renamed(self, name_map: dict[str, str]) -> RenamedToolset[AgentDepsT]:
        """Wrap this toolset in a `RenamedToolset` built from a dictionary mapping new names to original names.

        See [toolset docs](../toolsets.md#renaming-tools) for more information.
        """
        from .renamed import RenamedToolset

        return RenamedToolset(self, name_map)
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from contextlib import AsyncExitStack
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Callable
|
|
8
|
+
|
|
9
|
+
from typing_extensions import Self
|
|
10
|
+
|
|
11
|
+
from .._run_context import AgentDepsT, RunContext
|
|
12
|
+
from .._utils import get_async_lock
|
|
13
|
+
from ..exceptions import UserError
|
|
14
|
+
from .abstract import AbstractToolset, ToolsetTool
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class _CombinedToolsetTool(ToolsetTool[AgentDepsT]):
    """A `ToolsetTool` for a combined toolset that also records the source toolset and the source tool."""

    # The member toolset of the combined toolset that listed this tool.
    source_toolset: AbstractToolset[AgentDepsT]
    # The tool object exactly as that member toolset returned it.
    source_tool: ToolsetTool[AgentDepsT]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class CombinedToolset(AbstractToolset[AgentDepsT]):
    """A toolset that combines multiple toolsets.

    Entering the combined toolset enters every member toolset; the context is
    reference counted, so members are entered on the first enter and exited
    on the matching last exit.

    See [toolset docs](../toolsets.md#combining-toolsets) for more information.
    """

    toolsets: Sequence[AbstractToolset[AgentDepsT]]

    _enter_lock: asyncio.Lock = field(compare=False, init=False)
    _entered_count: int = field(init=False)
    _exit_stack: AsyncExitStack | None = field(init=False)

    def __post_init__(self):
        """Initialise the bookkeeping used by `__aenter__` / `__aexit__`."""
        self._enter_lock = get_async_lock()
        self._entered_count = 0
        self._exit_stack = None

    async def __aenter__(self) -> Self:
        """Enter all member toolsets on the first (outermost) enter.

        Raises:
            Whatever a member toolset's `__aenter__` raises. In that case the members
            that had already been entered are exited again before the error propagates.
        """
        async with self._enter_lock:
            if self._entered_count == 0:
                # Build the stack inside its own `async with` so that a failure while
                # entering one toolset unwinds the toolsets entered before it. Assigning
                # a half-filled stack to `self._exit_stack` would leak them, because the
                # count stays at 0 and the next enter replaces the stack.
                async with AsyncExitStack() as exit_stack:
                    for toolset in self.toolsets:
                        await exit_stack.enter_async_context(toolset)
                    # Everything entered: move the callbacks to a new stack that we own,
                    # so leaving this `async with` does not exit the toolsets.
                    self._exit_stack = exit_stack.pop_all()
            self._entered_count += 1
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        """Exit all member toolsets once the last (outermost) context is left."""
        async with self._enter_lock:
            self._entered_count -= 1
            if self._entered_count == 0 and self._exit_stack is not None:
                await self._exit_stack.aclose()
                self._exit_stack = None

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Collect the tools of all member toolsets into a single mapping.

        Args:
            ctx: The run context, passed on to every member toolset.

        Returns:
            A mapping from tool name to a tool that remembers which member toolset it came from.

        Raises:
            UserError: If two member toolsets define a tool with the same name.
        """
        # Member toolsets are queried concurrently; `gather` returns results in input order,
        # so they can be zipped back with `self.toolsets`.
        toolsets_tools = await asyncio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets))
        all_tools: dict[str, ToolsetTool[AgentDepsT]] = {}

        for toolset, tools in zip(self.toolsets, toolsets_tools):
            for name, tool in tools.items():
                if existing_tools := all_tools.get(name):
                    raise UserError(
                        f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}'
                    )

                all_tools[name] = _CombinedToolsetTool(
                    toolset=tool.toolset,
                    tool_def=tool.tool_def,
                    max_retries=tool.max_retries,
                    args_validator=tool.args_validator,
                    source_toolset=toolset,
                    source_tool=tool,
                )
        return all_tools

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Forward the call to the member toolset that listed the tool.

        Args:
            name: The name of the tool to call.
            tool_args: The arguments to pass to the tool.
            ctx: The run context.
            tool: The tool returned by `get_tools`; it must be one created by this class.
        """
        assert isinstance(tool, _CombinedToolsetTool)
        return await tool.source_toolset.call_tool(name, tool_args, ctx, tool.source_tool)

    def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
        """Apply the visitor to each member toolset rather than to this combined toolset."""
        for toolset in self.toolsets:
            toolset.apply(visitor)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, replace
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic_core import SchemaValidator, core_schema
|
|
7
|
+
|
|
8
|
+
from .._run_context import AgentDepsT, RunContext
|
|
9
|
+
from ..tools import ToolDefinition
|
|
10
|
+
from .abstract import AbstractToolset, ToolsetTool
|
|
11
|
+
|
|
12
|
+
# Module-level validator built from `core_schema.any_schema()`; `DeferredToolset.get_tools`
# passes this same instance as the `args_validator` of every deferred tool.
TOOL_SCHEMA_VALIDATOR = SchemaValidator(schema=core_schema.any_schema())
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class DeferredToolset(AbstractToolset[AgentDepsT]):
    """A toolset of deferred tools, whose results are produced outside of the Pydantic AI agent run in which they were called.

    See [toolset docs](../toolsets.md#deferred-toolset), [`ToolDefinition.kind`][pydantic_ai.tools.ToolDefinition.kind], and [`DeferredToolCalls`][pydantic_ai.output.DeferredToolCalls] for more information.
    """

    tool_defs: list[ToolDefinition]

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Expose each stored tool definition as a tool of kind `'deferred'` with zero retries."""
        deferred_tools: dict[str, ToolsetTool[AgentDepsT]] = {}
        for tool_def in self.tool_defs:
            # The stored definition is copied, not mutated, when its kind is set.
            deferred_tools[tool_def.name] = ToolsetTool(
                toolset=self,
                tool_def=replace(tool_def, kind='deferred'),
                max_retries=0,
                args_validator=TOOL_SCHEMA_VALIDATOR,
            )
        return deferred_tools

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> Any:
        """Always raise `NotImplementedError`: deferred tools cannot be called through the toolset."""
        raise NotImplementedError('Deferred tools cannot be called')
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
from .._run_context import AgentDepsT, RunContext
|
|
7
|
+
from ..tools import ToolDefinition
|
|
8
|
+
from .abstract import ToolsetTool
|
|
9
|
+
from .wrapper import WrapperToolset
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class FilteredToolset(WrapperToolset[AgentDepsT]):
    """A wrapper toolset that only exposes the tools accepted by a filter function taking the agent context and the tool definition.

    See [toolset docs](../toolsets.md#filtering-tools) for more information.
    """

    filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool]

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return the parent class's tools for which `filter_func(ctx, tool_def)` is truthy."""
        candidate_tools = await super().get_tools(ctx)
        kept_tools: dict[str, ToolsetTool[AgentDepsT]] = {}
        for name, tool in candidate_tools.items():
            if self.filter_func(ctx, tool.tool_def):
                kept_tools[name] = tool
        return kept_tools
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Awaitable, Sequence
|
|
4
|
+
from dataclasses import dataclass, field, replace
|
|
5
|
+
from typing import Any, Callable, overload
|
|
6
|
+
|
|
7
|
+
from pydantic.json_schema import GenerateJsonSchema
|
|
8
|
+
|
|
9
|
+
from .._run_context import AgentDepsT, RunContext
|
|
10
|
+
from ..exceptions import UserError
|
|
11
|
+
from ..tools import (
|
|
12
|
+
DocstringFormat,
|
|
13
|
+
GenerateToolJsonSchema,
|
|
14
|
+
Tool,
|
|
15
|
+
ToolFuncEither,
|
|
16
|
+
ToolParams,
|
|
17
|
+
ToolPrepareFunc,
|
|
18
|
+
)
|
|
19
|
+
from .abstract import AbstractToolset, ToolsetTool
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class _FunctionToolsetTool(ToolsetTool[AgentDepsT]):
    """A `ToolsetTool` for a function toolset that also carries the callable to invoke."""

    # Awaitable callable that receives the tool arguments and the run context.
    call_func: Callable[[dict[str, Any], RunContext[AgentDepsT]], Awaitable[Any]]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(init=False)
class FunctionToolset(AbstractToolset[AgentDepsT]):
    """A toolset that lets Python functions be used as tools.

    See [toolset docs](../toolsets.md#function-toolset) for more information.
    """

    max_retries: int = field(default=1)
    tools: dict[str, Tool[Any]] = field(default_factory=dict)

    # The default for `tools` is an immutable empty tuple rather than a shared mutable list.
    def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), max_retries: int = 1):
        """Build a new function toolset.

        Args:
            tools: The tools to add to the toolset. Each item is either a `Tool` or a plain function.
            max_retries: The maximum number of retries for each tool during a run.

        Raises:
            UserError: If two of the given tools have the same name.
        """
        self.max_retries = max_retries
        self.tools = {}
        for tool in tools:
            if isinstance(tool, Tool):
                self.add_tool(tool)
            else:
                self.add_function(tool)

    @overload
    def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ...

    @overload
    def tool(
        self,
        /,
        *,
        name: str | None = None,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
        strict: bool | None = None,
    ) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ...

    def tool(
        self,
        func: ToolFuncEither[AgentDepsT, ToolParams] | None = None,
        /,
        *,
        name: str | None = None,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
        strict: bool | None = None,
    ) -> Any:
        """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

        Can decorate sync or async functions, either bare (`@toolset.tool`) or with
        keyword arguments (`@toolset.tool(retries=2)`).

        The docstring is inspected to extract both the tool description and description of each parameter,
        [learn more](../tools.md#function-tools-and-schema).

        We can't add overloads for every possible signature of tool, since the return type is a recursive union
        so the signature of functions decorated with `@toolset.tool` is obscured.

        Example:
        ```python
        from pydantic_ai import Agent, RunContext
        from pydantic_ai.toolsets.function import FunctionToolset

        toolset = FunctionToolset()

        @toolset.tool
        def foobar(ctx: RunContext[int], x: int) -> int:
            return ctx.deps + x

        @toolset.tool(retries=2)
        async def spam(ctx: RunContext[str], y: float) -> float:
            return ctx.deps + y

        agent = Agent('test', toolsets=[toolset], deps_type=int)
        result = agent.run_sync('foobar', deps=1)
        print(result.output)
        #> {"foobar":1,"spam":1.0}
        ```

        Args:
            func: The tool function to register.
            name: The name of the tool, defaults to the function name.
            retries: The number of retries to allow for this tool, defaults to the toolset's `max_retries`,
                which defaults to 1.
            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
            schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
            strict: Whether to enforce JSON schema compliance (only affects OpenAI).
                See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.

        Returns:
            The undecorated function when used bare, otherwise a decorator that registers
            the function and returns it unchanged.
        """

        def tool_decorator(
            func_: ToolFuncEither[AgentDepsT, ToolParams],
        ) -> ToolFuncEither[AgentDepsT, ToolParams]:
            # `takes_ctx` is passed as `None` so that `add_function` leaves it to be inferred.
            # noinspection PyTypeChecker
            self.add_function(
                func_,
                None,
                name,
                retries,
                prepare,
                docstring_format,
                require_parameter_descriptions,
                schema_generator,
                strict,
            )
            return func_

        return tool_decorator if func is None else tool_decorator(func)

    def add_function(
        self,
        func: ToolFuncEither[AgentDepsT, ToolParams],
        takes_ctx: bool | None = None,
        name: str | None = None,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
        strict: bool | None = None,
    ) -> None:
        """Add a function as a tool to the toolset.

        Can take a sync or async function.

        The docstring is inspected to extract both the tool description and description of each parameter,
        [learn more](../tools.md#function-tools-and-schema).

        Args:
            func: The tool function to register.
            takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. If `None`, this is inferred from the function signature.
            name: The name of the tool, defaults to the function name.
            retries: The number of retries to allow for this tool, defaults to the toolset's `max_retries`,
                which defaults to 1.
            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
            schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
            strict: Whether to enforce JSON schema compliance (only affects OpenAI).
                See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.

        Raises:
            UserError: If a tool with the same name is already registered.
        """
        tool = Tool[AgentDepsT](
            func,
            takes_ctx=takes_ctx,
            name=name,
            max_retries=retries,
            prepare=prepare,
            docstring_format=docstring_format,
            require_parameter_descriptions=require_parameter_descriptions,
            schema_generator=schema_generator,
            strict=strict,
        )
        self.add_tool(tool)

    def add_tool(self, tool: Tool[AgentDepsT]) -> None:
        """Add a tool to the toolset.

        If the tool has no `max_retries` of its own, it is set to the toolset's `max_retries`.

        Args:
            tool: The tool to add.

        Raises:
            UserError: If a tool with the same name is already registered.
        """
        if tool.name in self.tools:
            raise UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
        if tool.max_retries is None:
            tool.max_retries = self.max_retries
        self.tools[tool.name] = tool

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Prepare every registered tool for this step and return those that remain available.

        Args:
            ctx: The run context; a per-tool copy with `tool_name` and `retry` set is passed to
                each tool's `prepare_tool_def`.

        Returns:
            A mapping from the (possibly renamed) tool name to the tool.

        Raises:
            UserError: If preparing the tools results in two tools with the same name.
        """
        tools: dict[str, ToolsetTool[AgentDepsT]] = {}
        for original_name, tool in self.tools.items():
            run_context = replace(ctx, tool_name=original_name, retry=ctx.retries.get(original_name, 0))
            tool_def = await tool.prepare_tool_def(run_context)
            if not tool_def:
                # A falsy result from `prepare_tool_def` omits the tool from this step.
                continue

            # `prepare_tool_def` may return a definition with a different name than the registered one.
            new_name = tool_def.name
            if new_name in tools:
                if new_name != original_name:
                    raise UserError(f'Renaming tool {original_name!r} to {new_name!r} conflicts with existing tool.')
                else:
                    raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.')

            tools[new_name] = _FunctionToolsetTool(
                toolset=self,
                tool_def=tool_def,
                max_retries=tool.max_retries if tool.max_retries is not None else self.max_retries,
                args_validator=tool.function_schema.validator,
                call_func=tool.function_schema.call,
            )
        return tools

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Call the function behind the tool with the given arguments and run context.

        Args:
            name: The name of the tool to call.
            tool_args: The arguments to pass to the tool.
            ctx: The run context.
            tool: The tool returned by `get_tools`; it must be one created by this class.
        """
        assert isinstance(tool, _FunctionToolsetTool)
        return await tool.call_func(tool_args, ctx)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, replace
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .._run_context import AgentDepsT, RunContext
|
|
7
|
+
from .abstract import ToolsetTool
|
|
8
|
+
from .wrapper import WrapperToolset
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class PrefixedToolset(WrapperToolset[AgentDepsT]):
    """A wrapper toolset that exposes its tools under names of the form `<prefix>_<name>`.

    See [toolset docs](../toolsets.md#prefixing-tool-names) for more information.
    """

    prefix: str

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return the parent class's tools re-keyed and renamed with the prefix, and attributed to this toolset."""
        inner_tools = await super().get_tools(ctx)
        prefixed_tools: dict[str, ToolsetTool[AgentDepsT]] = {}
        for name, tool in inner_tools.items():
            new_name = f'{self.prefix}_{name}'
            # Copies are made so the tool objects returned by the parent class stay untouched.
            renamed_def = replace(tool.tool_def, name=new_name)
            prefixed_tools[new_name] = replace(tool, toolset=self, tool_def=renamed_def)
        return prefixed_tools

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> Any:
        """Strip the prefix from the tool name and delegate the call to the parent class under that name."""
        original_name = name.removeprefix(self.prefix + '_')
        # The context and the tool definition are both rewritten to carry the unprefixed name.
        inner_ctx = replace(ctx, tool_name=original_name)
        inner_def = replace(tool.tool_def, name=original_name)
        inner_tool = replace(tool, tool_def=inner_def)
        return await super().call_tool(original_name, tool_args, inner_ctx, inner_tool)
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, replace
|
|
4
|
+
|
|
5
|
+
from .._run_context import AgentDepsT, RunContext
|
|
6
|
+
from ..exceptions import UserError
|
|
7
|
+
from ..tools import ToolsPrepareFunc
|
|
8
|
+
from .abstract import ToolsetTool
|
|
9
|
+
from .wrapper import WrapperToolset
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class PreparedToolset(WrapperToolset[AgentDepsT]):
    """A wrapper toolset that passes its tool definitions through a prepare function taking the agent context and the original tool definitions.

    See [toolset docs](../toolsets.md#preparing-tool-definitions) for more information.
    """

    prepare_func: ToolsPrepareFunc[AgentDepsT]

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return the tools whose definitions the prepare function returned, with those definitions applied.

        A falsy result from the prepare function is treated as an empty list.

        Raises:
            UserError: If the prepare function returns a definition whose name is not one of the original tool names.
        """
        source_tools = await super().get_tools(ctx)
        source_defs = [tool.tool_def for tool in source_tools.values()]
        prepared_defs = await self.prepare_func(ctx, source_defs) or []
        prepared_by_name = {tool_def.name: tool_def for tool_def in prepared_defs}

        # Any prepared name that is not an original name means a tool was added or renamed.
        unknown_names = prepared_by_name.keys() - source_tools.keys()
        if unknown_names:
            raise UserError(
                'Prepare function cannot add or rename tools. Use `FunctionToolset.add_function()` or `RenamedToolset` instead.'
            )

        prepared_tools: dict[str, ToolsetTool[AgentDepsT]] = {}
        for name, tool_def in prepared_by_name.items():
            prepared_tools[name] = replace(source_tools[name], tool_def=tool_def)
        return prepared_tools
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, replace
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .._run_context import AgentDepsT, RunContext
|
|
7
|
+
from .abstract import ToolsetTool
|
|
8
|
+
from .wrapper import WrapperToolset
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class RenamedToolset(WrapperToolset[AgentDepsT]):
    """A wrapper toolset that exposes its tools under new names, given a dictionary mapping new names to original names.

    See [toolset docs](../toolsets.md#renaming-tools) for more information.
    """

    name_map: dict[str, str]

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return the parent class's tools, renaming those that have an entry in `name_map`.

        Tools without a (non-empty) new name are returned unchanged under their original name.
        """
        # `name_map` goes new -> original; invert it for lookups by original name.
        new_name_by_original = {original: new for new, original in self.name_map.items()}
        inner_tools = await super().get_tools(ctx)
        renamed_tools: dict[str, ToolsetTool[AgentDepsT]] = {}
        for original_name, tool in inner_tools.items():
            new_name = new_name_by_original.get(original_name)
            if not new_name:
                renamed_tools[original_name] = tool
                continue
            renamed_def = replace(tool.tool_def, name=new_name)
            renamed_tools[new_name] = replace(tool, toolset=self, tool_def=renamed_def)
        return renamed_tools

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> Any:
        """Map the name back to the original (if it was renamed) and delegate the call to the parent class."""
        original_name = self.name_map.get(name, name)
        # The context and the tool definition are both rewritten to carry the original name.
        inner_ctx = replace(ctx, tool_name=original_name)
        inner_def = replace(tool.tool_def, name=original_name)
        inner_tool = replace(tool, tool_def=inner_def)
        return await super().call_tool(original_name, tool_args, inner_ctx, inner_tool)
|