pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (55)
  1. pydantic_ai/_agent_graph.py +219 -315
  2. pydantic_ai/_cli.py +9 -7
  3. pydantic_ai/_output.py +296 -226
  4. pydantic_ai/_parts_manager.py +2 -2
  5. pydantic_ai/_run_context.py +8 -14
  6. pydantic_ai/_tool_manager.py +190 -0
  7. pydantic_ai/_utils.py +18 -1
  8. pydantic_ai/ag_ui.py +675 -0
  9. pydantic_ai/agent.py +369 -155
  10. pydantic_ai/common_tools/duckduckgo.py +5 -2
  11. pydantic_ai/exceptions.py +14 -2
  12. pydantic_ai/ext/aci.py +12 -3
  13. pydantic_ai/ext/langchain.py +9 -1
  14. pydantic_ai/mcp.py +147 -84
  15. pydantic_ai/messages.py +19 -9
  16. pydantic_ai/models/__init__.py +43 -19
  17. pydantic_ai/models/anthropic.py +2 -2
  18. pydantic_ai/models/bedrock.py +1 -1
  19. pydantic_ai/models/cohere.py +1 -1
  20. pydantic_ai/models/function.py +50 -24
  21. pydantic_ai/models/gemini.py +3 -11
  22. pydantic_ai/models/google.py +3 -12
  23. pydantic_ai/models/groq.py +2 -1
  24. pydantic_ai/models/huggingface.py +463 -0
  25. pydantic_ai/models/instrumented.py +1 -1
  26. pydantic_ai/models/mistral.py +3 -3
  27. pydantic_ai/models/openai.py +5 -5
  28. pydantic_ai/output.py +21 -7
  29. pydantic_ai/profiles/google.py +1 -1
  30. pydantic_ai/profiles/moonshotai.py +8 -0
  31. pydantic_ai/providers/__init__.py +4 -0
  32. pydantic_ai/providers/google.py +2 -2
  33. pydantic_ai/providers/google_vertex.py +10 -5
  34. pydantic_ai/providers/grok.py +13 -1
  35. pydantic_ai/providers/groq.py +2 -0
  36. pydantic_ai/providers/huggingface.py +88 -0
  37. pydantic_ai/result.py +57 -33
  38. pydantic_ai/tools.py +26 -119
  39. pydantic_ai/toolsets/__init__.py +22 -0
  40. pydantic_ai/toolsets/abstract.py +155 -0
  41. pydantic_ai/toolsets/combined.py +88 -0
  42. pydantic_ai/toolsets/deferred.py +38 -0
  43. pydantic_ai/toolsets/filtered.py +24 -0
  44. pydantic_ai/toolsets/function.py +238 -0
  45. pydantic_ai/toolsets/prefixed.py +37 -0
  46. pydantic_ai/toolsets/prepared.py +36 -0
  47. pydantic_ai/toolsets/renamed.py +42 -0
  48. pydantic_ai/toolsets/wrapper.py +37 -0
  49. pydantic_ai/usage.py +14 -8
  50. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
  51. pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
  52. pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
  53. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
  54. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
  55. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,155 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol
6
+
7
+ from pydantic_core import SchemaValidator
8
+ from typing_extensions import Self
9
+
10
+ from .._run_context import AgentDepsT, RunContext
11
+ from ..tools import ToolDefinition, ToolsPrepareFunc
12
+
13
+ if TYPE_CHECKING:
14
+ from .filtered import FilteredToolset
15
+ from .prefixed import PrefixedToolset
16
+ from .prepared import PreparedToolset
17
+ from .renamed import RenamedToolset
18
+
19
+
20
class SchemaValidatorProt(Protocol):
    """Structural type for argument validators.

    Matches Pydantic Core's `SchemaValidator` as well as `PluggableSchemaValidator`, which is private but
    exposes the same validation API.
    """

    def validate_json(
        self,
        input: str | bytes | bytearray,
        *,
        allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
        **kwargs: Any,
    ) -> Any:
        """Validate a JSON document (text or bytes) and return the validated value."""
        ...

    def validate_python(
        self,
        input: Any,
        *,
        allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
        **kwargs: Any,
    ) -> Any:
        """Validate an already-parsed Python object and return the validated value."""
        ...
34
+
35
+
36
@dataclass
class ToolsetTool(Generic[AgentDepsT]):
    """A tool as exposed by a toolset.

    Besides the plain tool definition, it records:

    - which toolset produced it (used when building error messages),
    - how many times a failing call to it may be retried,
    - the validator used to check the tool's arguments.
    """

    toolset: AbstractToolset[AgentDepsT]
    """The toolset that provided this tool, for use in error messages."""
    tool_def: ToolDefinition
    """The tool definition for this tool, including the name, description, and parameters."""
    max_retries: int
    """The maximum number of retries to attempt if the tool call fails."""
    args_validator: SchemaValidator | SchemaValidatorProt
    """The Pydantic Core validator for the tool's arguments.

    For example, a [`pydantic.TypeAdapter(...).validator`](https://docs.pydantic.dev/latest/concepts/type_adapter/) or [`pydantic_core.SchemaValidator`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.SchemaValidator).
    """
58
+
59
+
60
class AbstractToolset(ABC, Generic[AgentDepsT]):
    """Base class for a collection of tools an agent can use.

    A concrete toolset must:

    - list the tools it offers (`get_tools`),
    - provide the validators for those tools' arguments (via the returned `ToolsetTool`s),
    - execute tool calls (`call_tool`).

    See [toolset docs](../toolsets.md) for more information.
    """

    @property
    def name(self) -> str:
        """Human-readable toolset name used in error messages, e.g. `FunctionToolset` -> `Function toolset`."""
        class_name = type(self).__name__
        return class_name.replace('Toolset', ' toolset')

    @property
    def tool_name_conflict_hint(self) -> str:
        """Advice, shown in error messages, on how to resolve a tool name clash with another toolset."""
        return 'Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.'

    async def __aenter__(self) -> Self:
        """Enter the toolset's async context.

        The base implementation does nothing; subclasses can override it to open network connections.
        """
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        """Leave the toolset's async context.

        The base implementation does nothing; subclasses can override it to close network connections.
        """
        return None

    @abstractmethod
    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return the tools this toolset currently offers, keyed by tool name."""
        raise NotImplementedError()

    @abstractmethod
    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Execute one of this toolset's tools.

        Args:
            name: Name of the tool being called.
            tool_args: Arguments for the call.
            ctx: The run context.
            tool: The tool definition returned by [`get_tools`][pydantic_ai.toolsets.AbstractToolset.get_tools] that was called.
        """
        raise NotImplementedError()

    def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
        """Call `visitor` on each concrete (non-wrapper) toolset; a plain toolset just visits itself."""
        visitor(self)

    def filtered(
        self, filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool]
    ) -> FilteredToolset[AgentDepsT]:
        """Wrap this toolset so only tools accepted by `filter_func(ctx, tool_def)` are exposed.

        See [toolset docs](../toolsets.md#filtering-tools) for more information.
        """
        # Imported lazily: the wrapper modules import this module.
        from .filtered import FilteredToolset as _Filtered

        return _Filtered(self, filter_func)

    def prefixed(self, prefix: str) -> PrefixedToolset[AgentDepsT]:
        """Wrap this toolset so every tool name is exposed with `prefix` in front of it.

        See [toolset docs](../toolsets.md#prefixing-tool-names) for more information.
        """
        from .prefixed import PrefixedToolset as _Prefixed

        return _Prefixed(self, prefix)

    def prepared(self, prepare_func: ToolsPrepareFunc[AgentDepsT]) -> PreparedToolset[AgentDepsT]:
        """Wrap this toolset so its tool definitions are passed through `prepare_func(ctx, tool_defs)` first.

        See [toolset docs](../toolsets.md#preparing-tool-definitions) for more information.
        """
        from .prepared import PreparedToolset as _Prepared

        return _Prepared(self, prepare_func)

    def renamed(self, name_map: dict[str, str]) -> RenamedToolset[AgentDepsT]:
        """Wrap this toolset so tools are renamed according to `name_map` (new name -> original name).

        See [toolset docs](../toolsets.md#renaming-tools) for more information.
        """
        from .renamed import RenamedToolset as _Renamed

        return _Renamed(self, name_map)
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from collections.abc import Sequence
5
+ from contextlib import AsyncExitStack
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Callable
8
+
9
+ from typing_extensions import Self
10
+
11
+ from .._run_context import AgentDepsT, RunContext
12
+ from .._utils import get_async_lock
13
+ from ..exceptions import UserError
14
+ from .abstract import AbstractToolset, ToolsetTool
15
+
16
+
17
@dataclass
class _CombinedToolsetTool(ToolsetTool[AgentDepsT]):
    """`ToolsetTool` returned by `CombinedToolset`, remembering which member toolset and member tool it came from.

    `CombinedToolset.call_tool` uses these two fields to forward a call to the right member.
    """

    # The member toolset (one of `CombinedToolset.toolsets`) that returned this tool.
    source_toolset: AbstractToolset[AgentDepsT]
    # The tool exactly as that member toolset returned it.
    source_tool: ToolsetTool[AgentDepsT]
23
+
24
+
25
@dataclass
class CombinedToolset(AbstractToolset[AgentDepsT]):
    """A toolset that combines multiple toolsets.

    See [toolset docs](../toolsets.md#combining-toolsets) for more information.
    """

    toolsets: Sequence[AbstractToolset[AgentDepsT]]

    _enter_lock: asyncio.Lock = field(compare=False, init=False)
    _entered_count: int = field(init=False)
    _exit_stack: AsyncExitStack | None = field(init=False)

    def __post_init__(self):
        self._enter_lock = get_async_lock()
        self._entered_count = 0
        self._exit_stack = None

    async def __aenter__(self) -> Self:
        """Enter all member toolsets on the first entry of this combined toolset.

        Entries are reference-counted under a lock: only the first entry enters the members, and
        only the matching last `__aexit__` exits them.
        """
        async with self._enter_lock:
            if self._entered_count == 0:
                # Enter the members inside a temporary exit stack: if entering any member raises,
                # the members already entered are exited again instead of being leaked. On
                # success, `pop_all()` moves the registered exits to a new stack that is kept
                # until the last `__aexit__`.
                async with AsyncExitStack() as exit_stack:
                    for toolset in self.toolsets:
                        await exit_stack.enter_async_context(toolset)
                    self._exit_stack = exit_stack.pop_all()
            self._entered_count += 1
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        """Exit all member toolsets once the last entry of this combined toolset is exited."""
        async with self._enter_lock:
            self._entered_count -= 1
            if self._entered_count == 0 and self._exit_stack is not None:
                await self._exit_stack.aclose()
                self._exit_stack = None

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Collect the tools of all member toolsets (fetched concurrently) into one mapping.

        Raises:
            UserError: If two member toolsets expose a tool with the same name.
        """
        toolsets_tools = await asyncio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets))
        all_tools: dict[str, ToolsetTool[AgentDepsT]] = {}

        for toolset, tools in zip(self.toolsets, toolsets_tools):
            for name, tool in tools.items():
                existing_tool = all_tools.get(name)
                if existing_tool is not None:
                    # Name both clashing tools by the toolset that actually defined them
                    # (`tool.toolset`, which can be nested inside a wrapper member), and take the
                    # hint from that same toolset so toolset-specific advice is shown.
                    tool_toolset = tool.toolset
                    raise UserError(
                        f'{tool_toolset.name} defines a tool whose name conflicts with existing tool from {existing_tool.toolset.name}: {name!r}. {tool_toolset.tool_name_conflict_hint}'
                    )

                all_tools[name] = _CombinedToolsetTool(
                    toolset=tool.toolset,
                    tool_def=tool.tool_def,
                    max_retries=tool.max_retries,
                    args_validator=tool.args_validator,
                    source_toolset=toolset,
                    source_tool=tool,
                )
        return all_tools

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Forward the call to the member toolset that provided `tool`, passing its original tool object."""
        assert isinstance(tool, _CombinedToolsetTool)
        return await tool.source_toolset.call_tool(name, tool_args, ctx, tool.source_tool)

    def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
        """Apply `visitor` to every member toolset (recursively, via their own `apply`)."""
        for toolset in self.toolsets:
            toolset.apply(visitor)
@@ -0,0 +1,38 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, replace
4
+ from typing import Any
5
+
6
+ from pydantic_core import SchemaValidator, core_schema
7
+
8
+ from .._run_context import AgentDepsT, RunContext
9
+ from ..tools import ToolDefinition
10
+ from .abstract import AbstractToolset, ToolsetTool
11
+
12
# Shared validator for deferred tools: accepts any arguments unchanged (`any_schema`).
TOOL_SCHEMA_VALIDATOR = SchemaValidator(core_schema.any_schema())
13
+
14
+
15
@dataclass
class DeferredToolset(AbstractToolset[AgentDepsT]):
    """A toolset that holds deferred tools whose results will be produced outside of the Pydantic AI agent run in which they were called.

    See [toolset docs](../toolsets.md#deferred-toolset), [`ToolDefinition.kind`][pydantic_ai.tools.ToolDefinition.kind], and [`DeferredToolCalls`][pydantic_ai.output.DeferredToolCalls] for more information.
    """

    tool_defs: list[ToolDefinition]

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Expose every stored definition with `kind='deferred'`, no retries, and an accept-anything validator."""
        deferred_tools: dict[str, ToolsetTool[AgentDepsT]] = {}
        for definition in self.tool_defs:
            deferred_tools[definition.name] = ToolsetTool(
                toolset=self,
                tool_def=replace(definition, kind='deferred'),
                max_retries=0,
                args_validator=TOOL_SCHEMA_VALIDATOR,
            )
        return deferred_tools

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Always raise: deferred tools are executed outside the agent run."""
        raise NotImplementedError('Deferred tools cannot be called')
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Callable
5
+
6
+ from .._run_context import AgentDepsT, RunContext
7
+ from ..tools import ToolDefinition
8
+ from .abstract import ToolsetTool
9
+ from .wrapper import WrapperToolset
10
+
11
+
12
@dataclass
class FilteredToolset(WrapperToolset[AgentDepsT]):
    """A toolset that filters the tools it contains using a filter function that takes the agent context and the tool definition.

    See [toolset docs](../toolsets.md#filtering-tools) for more information.
    """

    filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool]

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return only the wrapped tools for which `filter_func(ctx, tool_def)` is truthy."""
        candidates = await super().get_tools(ctx)
        kept: dict[str, ToolsetTool[AgentDepsT]] = {}
        for tool_name, candidate in candidates.items():
            if self.filter_func(ctx, candidate.tool_def):
                kept[tool_name] = candidate
        return kept
@@ -0,0 +1,238 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Awaitable, Sequence
4
+ from dataclasses import dataclass, field, replace
5
+ from typing import Any, Callable, overload
6
+
7
+ from pydantic.json_schema import GenerateJsonSchema
8
+
9
+ from .._run_context import AgentDepsT, RunContext
10
+ from ..exceptions import UserError
11
+ from ..tools import (
12
+ DocstringFormat,
13
+ GenerateToolJsonSchema,
14
+ Tool,
15
+ ToolFuncEither,
16
+ ToolParams,
17
+ ToolPrepareFunc,
18
+ )
19
+ from .abstract import AbstractToolset, ToolsetTool
20
+
21
+
22
@dataclass
class _FunctionToolsetTool(ToolsetTool[AgentDepsT]):
    """`ToolsetTool` returned by `FunctionToolset`, additionally carrying the coroutine function that runs the tool."""

    # Awaited by `FunctionToolset.call_tool` with the tool arguments and the run context.
    call_func: Callable[[dict[str, Any], RunContext[AgentDepsT]], Awaitable[Any]]
27
+
28
+
29
@dataclass(init=False)
class FunctionToolset(AbstractToolset[AgentDepsT]):
    """A toolset that lets Python functions be used as tools.

    See [toolset docs](../toolsets.md#function-toolset) for more information.
    """

    max_retries: int = field(default=1)
    tools: dict[str, Tool[Any]] = field(default_factory=dict)

    def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), max_retries: int = 1):
        """Build a new function toolset.

        Args:
            tools: The tools to add to the toolset, either as `Tool` instances (added with `add_tool`)
                or as plain functions (added with `add_function`).
            max_retries: The maximum number of retries for each tool during a run, used for tools that
                don't set their own.
        """
        self.max_retries = max_retries
        self.tools = {}
        for tool in tools:
            if isinstance(tool, Tool):
                self.add_tool(tool)
            else:
                self.add_function(tool)

    @overload
    def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ...

    @overload
    def tool(
        self,
        /,
        *,
        name: str | None = None,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
        strict: bool | None = None,
    ) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ...

    def tool(
        self,
        func: ToolFuncEither[AgentDepsT, ToolParams] | None = None,
        /,
        *,
        name: str | None = None,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
        strict: bool | None = None,
    ) -> Any:
        """Decorator to register a tool function, optionally taking a [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

        Whether the function takes a `RunContext` is inferred from its signature.
        Can decorate sync or async functions.

        The docstring is inspected to extract both the tool description and description of each parameter,
        [learn more](../tools.md#function-tools-and-schema).

        We can't add overloads for every possible signature of tool, since the return type is a recursive union
        so the signature of functions decorated with `@toolset.tool` is obscured.

        Example:
        ```python
        from pydantic_ai import Agent, RunContext
        from pydantic_ai.toolsets.function import FunctionToolset

        toolset = FunctionToolset()

        @toolset.tool
        def foobar(ctx: RunContext[int], x: int) -> int:
            return ctx.deps + x

        @toolset.tool(retries=2)
        async def spam(ctx: RunContext[int], y: float) -> float:
            return ctx.deps + y

        agent = Agent('test', toolsets=[toolset], deps_type=int)
        result = agent.run_sync('foobar', deps=1)
        print(result.output)
        #> {"foobar":1,"spam":1.0}
        ```

        Args:
            func: The tool function to register.
            name: The name of the tool, defaults to the function name.
            retries: The number of retries to allow for this tool, defaults to this toolset's `max_retries`,
                which defaults to 1.
            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
            schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
            strict: Whether to enforce JSON schema compliance (only affects OpenAI).
                See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
        """

        def tool_decorator(
            func_: ToolFuncEither[AgentDepsT, ToolParams],
        ) -> ToolFuncEither[AgentDepsT, ToolParams]:
            # Keyword arguments keep this call correct even if `add_function`'s parameter order changes.
            # noinspection PyTypeChecker
            self.add_function(
                func_,
                takes_ctx=None,
                name=name,
                retries=retries,
                prepare=prepare,
                docstring_format=docstring_format,
                require_parameter_descriptions=require_parameter_descriptions,
                schema_generator=schema_generator,
                strict=strict,
            )
            return func_

        # Support both bare `@toolset.tool` and parametrized `@toolset.tool(...)` usage.
        return tool_decorator if func is None else tool_decorator(func)

    def add_function(
        self,
        func: ToolFuncEither[AgentDepsT, ToolParams],
        takes_ctx: bool | None = None,
        name: str | None = None,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
        strict: bool | None = None,
    ) -> None:
        """Add a function as a tool to the toolset.

        Can take a sync or async function.

        The docstring is inspected to extract both the tool description and description of each parameter,
        [learn more](../tools.md#function-tools-and-schema).

        Args:
            func: The tool function to register.
            takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. If `None`, this is inferred from the function signature.
            name: The name of the tool, defaults to the function name.
            retries: The number of retries to allow for this tool, defaults to this toolset's `max_retries`,
                which defaults to 1.
            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
            schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
            strict: Whether to enforce JSON schema compliance (only affects OpenAI).
                See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.

        Raises:
            UserError: If a tool with the same name has already been added.
        """
        tool = Tool[AgentDepsT](
            func,
            takes_ctx=takes_ctx,
            name=name,
            max_retries=retries,
            prepare=prepare,
            docstring_format=docstring_format,
            require_parameter_descriptions=require_parameter_descriptions,
            schema_generator=schema_generator,
            strict=strict,
        )
        self.add_tool(tool)

    def add_tool(self, tool: Tool[AgentDepsT]) -> None:
        """Add a tool to the toolset.

        If the tool has no `max_retries` of its own, it is set (in place, on `tool`) to this toolset's
        `max_retries`.

        Args:
            tool: The tool to add.

        Raises:
            UserError: If a tool with the same name has already been added.
        """
        if tool.name in self.tools:
            raise UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
        if tool.max_retries is None:
            tool.max_retries = self.max_retries
        self.tools[tool.name] = tool

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Prepare each registered tool for the current step and return the ones that remain, keyed by name.

        Tools whose prepare step returns no definition are omitted for this step.

        Raises:
            UserError: If two tools end up with the same name after preparation.
        """
        tools: dict[str, ToolsetTool[AgentDepsT]] = {}
        for original_name, tool in self.tools.items():
            # Each tool is prepared with a context naming it and carrying its current retry count.
            run_context = replace(ctx, tool_name=original_name, retry=ctx.retries.get(original_name, 0))
            tool_def = await tool.prepare_tool_def(run_context)
            if not tool_def:
                continue

            new_name = tool_def.name
            if new_name in tools:
                if new_name != original_name:
                    raise UserError(f'Renaming tool {original_name!r} to {new_name!r} conflicts with existing tool.')
                else:
                    raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.')

            tools[new_name] = _FunctionToolsetTool(
                toolset=self,
                tool_def=tool_def,
                max_retries=tool.max_retries if tool.max_retries is not None else self.max_retries,
                args_validator=tool.function_schema.validator,
                call_func=tool.function_schema.call,
            )
        return tools

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Run the tool's function with `tool_args` and `ctx`."""
        assert isinstance(tool, _FunctionToolsetTool)
        return await tool.call_func(tool_args, ctx)
@@ -0,0 +1,37 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, replace
4
+ from typing import Any
5
+
6
+ from .._run_context import AgentDepsT, RunContext
7
+ from .abstract import ToolsetTool
8
+ from .wrapper import WrapperToolset
9
+
10
+
11
@dataclass
class PrefixedToolset(WrapperToolset[AgentDepsT]):
    """A toolset that prefixes the names of the tools it contains.

    A wrapped tool named `name` is exposed as `f'{prefix}_{name}'`.

    See [toolset docs](../toolsets.md#prefixing-tool-names) for more information.
    """

    prefix: str

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return the wrapped tools under their prefixed names, attributed to this toolset."""
        prefixed_tools: dict[str, ToolsetTool[AgentDepsT]] = {}
        for name, tool in (await super().get_tools(ctx)).items():
            new_name = f'{self.prefix}_{name}'
            prefixed_tools[new_name] = replace(
                tool,
                toolset=self,
                tool_def=replace(tool.tool_def, name=new_name),
            )
        return prefixed_tools

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Strip the prefix added by `get_tools` and forward the call to the wrapped toolset."""
        original_name = name.removeprefix(self.prefix + '_')
        # The wrapped toolset must see its own tool name in both the context and the tool definition.
        ctx = replace(ctx, tool_name=original_name)
        tool = replace(tool, tool_def=replace(tool.tool_def, name=original_name))
        return await super().call_tool(original_name, tool_args, ctx, tool)
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, replace
4
+
5
+ from .._run_context import AgentDepsT, RunContext
6
+ from ..exceptions import UserError
7
+ from ..tools import ToolsPrepareFunc
8
+ from .abstract import ToolsetTool
9
+ from .wrapper import WrapperToolset
10
+
11
+
12
@dataclass
class PreparedToolset(WrapperToolset[AgentDepsT]):
    """A toolset that prepares the tools it contains using a prepare function that takes the agent context and the original tool definitions.

    See [toolset docs](../toolsets.md#preparing-tool-definitions) for more information.
    """

    prepare_func: ToolsPrepareFunc[AgentDepsT]

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Pass the wrapped tool definitions through `prepare_func` and keep only the definitions it returns.

        A falsy result from `prepare_func` leaves no tools.

        Raises:
            UserError: If `prepare_func` returns a definition whose name is not one of the wrapped tools.
        """
        wrapped_tools = await super().get_tools(ctx)
        prepared_defs = await self.prepare_func(ctx, [wrapped.tool_def for wrapped in wrapped_tools.values()])
        defs_by_name = {prepared.name: prepared for prepared in (prepared_defs or [])}

        unknown_names = defs_by_name.keys() - wrapped_tools.keys()
        if unknown_names:
            raise UserError(
                'Prepare function cannot add or rename tools. Use `FunctionToolset.add_function()` or `RenamedToolset` instead.'
            )

        result: dict[str, ToolsetTool[AgentDepsT]] = {}
        for tool_name, prepared in defs_by_name.items():
            result[tool_name] = replace(wrapped_tools[tool_name], tool_def=prepared)
        return result
@@ -0,0 +1,42 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, replace
4
+ from typing import Any
5
+
6
+ from .._run_context import AgentDepsT, RunContext
7
+ from .abstract import ToolsetTool
8
+ from .wrapper import WrapperToolset
9
+
10
+
11
@dataclass
class RenamedToolset(WrapperToolset[AgentDepsT]):
    """A toolset that renames the tools it contains using a dictionary mapping new names to original names.

    See [toolset docs](../toolsets.md#renaming-tools) for more information.
    """

    # Maps new (exposed) tool names to the wrapped toolset's original tool names.
    name_map: dict[str, str]

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return the wrapped tools, exposing renamed ones under their new names.

        Tools not renamed by `name_map` keep their original names.

        Raises:
            UserError: If a tool that is not renamed has a name that `name_map` uses as a new name.
                Exposing it would overwrite (or be overwritten by) a renamed tool, and `call_tool`
                would route calls for it to a different original tool.
        """
        from ..exceptions import UserError

        original_to_new_name_map = {v: k for k, v in self.name_map.items()}
        original_tools = await super().get_tools(ctx)
        tools: dict[str, ToolsetTool[AgentDepsT]] = {}
        for original_name, tool in original_tools.items():
            new_name = original_to_new_name_map.get(original_name)
            if new_name:
                tools[new_name] = replace(
                    tool,
                    toolset=self,
                    tool_def=replace(tool.tool_def, name=new_name),
                )
            elif original_name in self.name_map:
                raise UserError(
                    f'Tool {original_name!r} is not renamed, but its name is used in the rename map as the new name '
                    f'for {self.name_map[original_name]!r}. Rename it as well or choose a different new name.'
                )
            else:
                tools[original_name] = tool
        return tools

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Translate `name` back to the wrapped toolset's original name and forward the call."""
        original_name = self.name_map.get(name, name)
        # The wrapped toolset must see its own tool name in both the context and the tool definition.
        ctx = replace(ctx, tool_name=original_name)
        tool = replace(tool, tool_def=replace(tool.tool_def, name=original_name))
        return await super().call_tool(original_name, tool_args, ctx, tool)
+ return await super().call_tool(original_name, tool_args, ctx, tool)