pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (45) hide show
  1. pydantic_ai/_agent_graph.py +220 -319
  2. pydantic_ai/_cli.py +9 -7
  3. pydantic_ai/_output.py +295 -331
  4. pydantic_ai/_parts_manager.py +2 -2
  5. pydantic_ai/_run_context.py +8 -14
  6. pydantic_ai/_tool_manager.py +190 -0
  7. pydantic_ai/_utils.py +18 -1
  8. pydantic_ai/ag_ui.py +675 -0
  9. pydantic_ai/agent.py +369 -156
  10. pydantic_ai/exceptions.py +12 -0
  11. pydantic_ai/ext/aci.py +12 -3
  12. pydantic_ai/ext/langchain.py +9 -1
  13. pydantic_ai/mcp.py +147 -84
  14. pydantic_ai/messages.py +13 -5
  15. pydantic_ai/models/__init__.py +30 -18
  16. pydantic_ai/models/anthropic.py +1 -1
  17. pydantic_ai/models/function.py +50 -24
  18. pydantic_ai/models/gemini.py +1 -9
  19. pydantic_ai/models/google.py +2 -11
  20. pydantic_ai/models/groq.py +1 -0
  21. pydantic_ai/models/mistral.py +1 -1
  22. pydantic_ai/models/openai.py +3 -3
  23. pydantic_ai/output.py +21 -7
  24. pydantic_ai/profiles/google.py +1 -1
  25. pydantic_ai/profiles/moonshotai.py +8 -0
  26. pydantic_ai/providers/grok.py +13 -1
  27. pydantic_ai/providers/groq.py +2 -0
  28. pydantic_ai/result.py +58 -45
  29. pydantic_ai/tools.py +26 -119
  30. pydantic_ai/toolsets/__init__.py +22 -0
  31. pydantic_ai/toolsets/abstract.py +155 -0
  32. pydantic_ai/toolsets/combined.py +88 -0
  33. pydantic_ai/toolsets/deferred.py +38 -0
  34. pydantic_ai/toolsets/filtered.py +24 -0
  35. pydantic_ai/toolsets/function.py +238 -0
  36. pydantic_ai/toolsets/prefixed.py +37 -0
  37. pydantic_ai/toolsets/prepared.py +36 -0
  38. pydantic_ai/toolsets/renamed.py +42 -0
  39. pydantic_ai/toolsets/wrapper.py +37 -0
  40. pydantic_ai/usage.py +14 -8
  41. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +10 -7
  42. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/RECORD +45 -32
  43. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
  44. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
  45. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/tools.py CHANGED
@@ -1,20 +1,15 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
- import dataclasses
4
- import json
5
3
  from collections.abc import Awaitable, Sequence
6
4
  from dataclasses import dataclass, field
7
5
  from typing import Any, Callable, Generic, Literal, Union
8
6
 
9
- from opentelemetry.trace import Tracer
10
- from pydantic import ValidationError
11
7
  from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
12
8
  from pydantic_core import SchemaValidator, core_schema
13
9
  from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar
14
10
 
15
- from . import _function_schema, _utils, messages as _messages
11
+ from . import _function_schema, _utils
16
12
  from ._run_context import AgentDepsT, RunContext
17
- from .exceptions import ModelRetry, UnexpectedModelBehavior
18
13
 
19
14
  __all__ = (
20
15
  'AgentDepsT',
@@ -32,7 +27,6 @@ __all__ = (
32
27
  'ToolDefinition',
33
28
  )
34
29
 
35
- from .messages import ToolReturnPart
36
30
 
37
31
  ToolParams = ParamSpec('ToolParams', default=...)
38
32
  """Retrieval function param spec."""
@@ -173,12 +167,6 @@ class Tool(Generic[AgentDepsT]):
173
167
  This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request.
174
168
  """
175
169
 
176
- # TODO: Consider moving this current_retry state to live on something other than the tool.
177
- # We've worked around this for now by copying instances of the tool when creating new runs,
178
- # but this is a bit fragile. Moving the tool retry counts to live on the agent run state would likely clean things
179
- # up, though is also likely a larger effort to refactor.
180
- current_retry: int = field(default=0, init=False)
181
-
182
170
  def __init__(
183
171
  self,
184
172
  function: ToolFuncEither[AgentDepsT],
@@ -303,6 +291,15 @@ class Tool(Generic[AgentDepsT]):
303
291
  function_schema=function_schema,
304
292
  )
305
293
 
294
+ @property
295
+ def tool_def(self):
296
+ return ToolDefinition(
297
+ name=self.name,
298
+ description=self.description,
299
+ parameters_json_schema=self.function_schema.json_schema,
300
+ strict=self.strict,
301
+ )
302
+
306
303
  async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
307
304
  """Get the tool definition.
308
305
 
@@ -312,113 +309,11 @@ class Tool(Generic[AgentDepsT]):
312
309
  Returns:
313
310
  return a `ToolDefinition` or `None` if the tools should not be registered for this run.
314
311
  """
315
- tool_def = ToolDefinition(
316
- name=self.name,
317
- description=self.description,
318
- parameters_json_schema=self.function_schema.json_schema,
319
- strict=self.strict,
320
- )
312
+ base_tool_def = self.tool_def
321
313
  if self.prepare is not None:
322
- return await self.prepare(ctx, tool_def)
314
+ return await self.prepare(ctx, base_tool_def)
323
315
  else:
324
- return tool_def
325
-
326
- async def run(
327
- self,
328
- message: _messages.ToolCallPart,
329
- run_context: RunContext[AgentDepsT],
330
- tracer: Tracer,
331
- include_content: bool = False,
332
- ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
333
- """Run the tool function asynchronously.
334
-
335
- This method wraps `_run` in an OpenTelemetry span.
336
-
337
- See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>.
338
- """
339
- span_attributes = {
340
- 'gen_ai.tool.name': self.name,
341
- # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
342
- 'gen_ai.tool.call.id': message.tool_call_id,
343
- **({'tool_arguments': message.args_as_json_str()} if include_content else {}),
344
- 'logfire.msg': f'running tool: {self.name}',
345
- # add the JSON schema so these attributes are formatted nicely in Logfire
346
- 'logfire.json_schema': json.dumps(
347
- {
348
- 'type': 'object',
349
- 'properties': {
350
- **(
351
- {
352
- 'tool_arguments': {'type': 'object'},
353
- 'tool_response': {'type': 'object'},
354
- }
355
- if include_content
356
- else {}
357
- ),
358
- 'gen_ai.tool.name': {},
359
- 'gen_ai.tool.call.id': {},
360
- },
361
- }
362
- ),
363
- }
364
- with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
365
- response = await self._run(message, run_context)
366
- if include_content and span.is_recording():
367
- span.set_attribute(
368
- 'tool_response',
369
- response.model_response_str()
370
- if isinstance(response, ToolReturnPart)
371
- else response.model_response(),
372
- )
373
-
374
- return response
375
-
376
- async def _run(
377
- self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
378
- ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
379
- try:
380
- validator = self.function_schema.validator
381
- if isinstance(message.args, str):
382
- args_dict = validator.validate_json(message.args or '{}')
383
- else:
384
- args_dict = validator.validate_python(message.args or {})
385
- except ValidationError as e:
386
- return self._on_error(e, message)
387
-
388
- ctx = dataclasses.replace(
389
- run_context,
390
- retry=self.current_retry,
391
- tool_name=message.tool_name,
392
- tool_call_id=message.tool_call_id,
393
- )
394
- try:
395
- response_content = await self.function_schema.call(args_dict, ctx)
396
- except ModelRetry as e:
397
- return self._on_error(e, message)
398
-
399
- self.current_retry = 0
400
- return _messages.ToolReturnPart(
401
- tool_name=message.tool_name,
402
- content=response_content,
403
- tool_call_id=message.tool_call_id,
404
- )
405
-
406
- def _on_error(
407
- self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
408
- ) -> _messages.RetryPromptPart:
409
- self.current_retry += 1
410
- if self.max_retries is None or self.current_retry > self.max_retries:
411
- raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
412
- else:
413
- if isinstance(exc, ValidationError):
414
- content = exc.errors(include_url=False, include_context=False)
415
- else:
416
- content = exc.message
417
- return _messages.RetryPromptPart(
418
- tool_name=call_message.tool_name,
419
- content=content,
420
- tool_call_id=call_message.tool_call_id,
421
- )
316
+ return base_tool_def
422
317
 
423
318
 
424
319
  ObjectJsonSchema: TypeAlias = dict[str, Any]
@@ -429,6 +324,9 @@ This type is used to define tools parameters (aka arguments) in [ToolDefinition]
429
324
  With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
430
325
  """
431
326
 
327
+ ToolKind: TypeAlias = Literal['function', 'output', 'deferred']
328
+ """Kind of tool."""
329
+
432
330
 
433
331
  @dataclass(repr=False)
434
332
  class ToolDefinition:
@@ -440,7 +338,7 @@ class ToolDefinition:
440
338
  name: str
441
339
  """The name of the tool."""
442
340
 
443
- parameters_json_schema: ObjectJsonSchema
341
+ parameters_json_schema: ObjectJsonSchema = field(default_factory=lambda: {'type': 'object', 'properties': {}})
444
342
  """The JSON schema for the tool's parameters."""
445
343
 
446
344
  description: str | None = None
@@ -464,4 +362,13 @@ class ToolDefinition:
464
362
  Note: this is currently only supported by OpenAI models.
465
363
  """
466
364
 
365
+ kind: ToolKind = field(default='function')
366
+ """The kind of tool:
367
+
368
+ - `'function'`: a tool that will be executed by Pydantic AI during an agent run and has its result returned to the model
369
+ - `'output'`: a tool that passes through an output value that ends the run
370
+ - `'deferred'`: a tool whose result will be produced outside of the Pydantic AI agent run in which it was called, because it depends on an upstream service (or user) or could take longer to generate than it's reasonable to keep the agent process running.
371
+ When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s corresponding to each deferred call.
372
+ """
373
+
467
374
  __repr__ = _utils.dataclasses_no_defaults_repr
@@ -0,0 +1,22 @@
1
+ from .abstract import AbstractToolset, ToolsetTool
2
+ from .combined import CombinedToolset
3
+ from .deferred import DeferredToolset
4
+ from .filtered import FilteredToolset
5
+ from .function import FunctionToolset
6
+ from .prefixed import PrefixedToolset
7
+ from .prepared import PreparedToolset
8
+ from .renamed import RenamedToolset
9
+ from .wrapper import WrapperToolset
10
+
11
+ __all__ = (
12
+ 'AbstractToolset',
13
+ 'ToolsetTool',
14
+ 'CombinedToolset',
15
+ 'DeferredToolset',
16
+ 'FilteredToolset',
17
+ 'FunctionToolset',
18
+ 'PrefixedToolset',
19
+ 'RenamedToolset',
20
+ 'PreparedToolset',
21
+ 'WrapperToolset',
22
+ )
@@ -0,0 +1,155 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol
6
+
7
+ from pydantic_core import SchemaValidator
8
+ from typing_extensions import Self
9
+
10
+ from .._run_context import AgentDepsT, RunContext
11
+ from ..tools import ToolDefinition, ToolsPrepareFunc
12
+
13
+ if TYPE_CHECKING:
14
+ from .filtered import FilteredToolset
15
+ from .prefixed import PrefixedToolset
16
+ from .prepared import PreparedToolset
17
+ from .renamed import RenamedToolset
18
+
19
+
20
+ class SchemaValidatorProt(Protocol):
21
+ """Protocol for a Pydantic Core `SchemaValidator` or `PluggableSchemaValidator` (which is private but API-compatible)."""
22
+
23
+ def validate_json(
24
+ self,
25
+ input: str | bytes | bytearray,
26
+ *,
27
+ allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
28
+ **kwargs: Any,
29
+ ) -> Any: ...
30
+
31
+ def validate_python(
32
+ self, input: Any, *, allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, **kwargs: Any
33
+ ) -> Any: ...
34
+
35
+
36
+ @dataclass
37
+ class ToolsetTool(Generic[AgentDepsT]):
38
+ """Definition of a tool available on a toolset.
39
+
40
+ This is a wrapper around a plain tool definition that includes information about:
41
+
42
+ - the toolset that provided it, for use in error messages
43
+ - the maximum number of retries to attempt if the tool call fails
44
+ - the validator for the tool's arguments
45
+ """
46
+
47
+ toolset: AbstractToolset[AgentDepsT]
48
+ """The toolset that provided this tool, for use in error messages."""
49
+ tool_def: ToolDefinition
50
+ """The tool definition for this tool, including the name, description, and parameters."""
51
+ max_retries: int
52
+ """The maximum number of retries to attempt if the tool call fails."""
53
+ args_validator: SchemaValidator | SchemaValidatorProt
54
+ """The Pydantic Core validator for the tool's arguments.
55
+
56
+ For example, a [`pydantic.TypeAdapter(...).validator`](https://docs.pydantic.dev/latest/concepts/type_adapter/) or [`pydantic_core.SchemaValidator`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.SchemaValidator).
57
+ """
58
+
59
+
60
+ class AbstractToolset(ABC, Generic[AgentDepsT]):
61
+ """A toolset is a collection of tools that can be used by an agent.
62
+
63
+ It is responsible for:
64
+
65
+ - Listing the tools it contains
66
+ - Validating the arguments of the tools
67
+ - Calling the tools
68
+
69
+ See [toolset docs](../toolsets.md) for more information.
70
+ """
71
+
72
+ @property
73
+ def name(self) -> str:
74
+ """The name of the toolset for use in error messages."""
75
+ return self.__class__.__name__.replace('Toolset', ' toolset')
76
+
77
+ @property
78
+ def tool_name_conflict_hint(self) -> str:
79
+ """A hint for how to avoid name conflicts with other toolsets for use in error messages."""
80
+ return 'Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.'
81
+
82
+ async def __aenter__(self) -> Self:
83
+ """Enter the toolset context.
84
+
85
+ This is where you can set up network connections in a concrete implementation.
86
+ """
87
+ return self
88
+
89
+ async def __aexit__(self, *args: Any) -> bool | None:
90
+ """Exit the toolset context.
91
+
92
+ This is where you can tear down network connections in a concrete implementation.
93
+ """
94
+ return None
95
+
96
+ @abstractmethod
97
+ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
98
+ """The tools that are available in this toolset."""
99
+ raise NotImplementedError()
100
+
101
+ @abstractmethod
102
+ async def call_tool(
103
+ self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
104
+ ) -> Any:
105
+ """Call a tool with the given arguments.
106
+
107
+ Args:
108
+ name: The name of the tool to call.
109
+ tool_args: The arguments to pass to the tool.
110
+ ctx: The run context.
111
+ tool: The tool definition returned by [`get_tools`][pydantic_ai.toolsets.AbstractToolset.get_tools] that was called.
112
+ """
113
+ raise NotImplementedError()
114
+
115
+ def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
116
+ """Run a visitor function on all concrete toolsets that are not wrappers (i.e. they implement their own tool listing and calling)."""
117
+ visitor(self)
118
+
119
+ def filtered(
120
+ self, filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool]
121
+ ) -> FilteredToolset[AgentDepsT]:
122
+ """Returns a new toolset that filters this toolset's tools using a filter function that takes the agent context and the tool definition.
123
+
124
+ See [toolset docs](../toolsets.md#filtering-tools) for more information.
125
+ """
126
+ from .filtered import FilteredToolset
127
+
128
+ return FilteredToolset(self, filter_func)
129
+
130
+ def prefixed(self, prefix: str) -> PrefixedToolset[AgentDepsT]:
131
+ """Returns a new toolset that prefixes the names of this toolset's tools.
132
+
133
+ See [toolset docs](../toolsets.md#prefixing-tool-names) for more information.
134
+ """
135
+ from .prefixed import PrefixedToolset
136
+
137
+ return PrefixedToolset(self, prefix)
138
+
139
+ def prepared(self, prepare_func: ToolsPrepareFunc[AgentDepsT]) -> PreparedToolset[AgentDepsT]:
140
+ """Returns a new toolset that prepares this toolset's tools using a prepare function that takes the agent context and the original tool definitions.
141
+
142
+ See [toolset docs](../toolsets.md#preparing-tool-definitions) for more information.
143
+ """
144
+ from .prepared import PreparedToolset
145
+
146
+ return PreparedToolset(self, prepare_func)
147
+
148
+ def renamed(self, name_map: dict[str, str]) -> RenamedToolset[AgentDepsT]:
149
+ """Returns a new toolset that renames this toolset's tools using a dictionary mapping new names to original names.
150
+
151
+ See [toolset docs](../toolsets.md#renaming-tools) for more information.
152
+ """
153
+ from .renamed import RenamedToolset
154
+
155
+ return RenamedToolset(self, name_map)
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from collections.abc import Sequence
5
+ from contextlib import AsyncExitStack
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Callable
8
+
9
+ from typing_extensions import Self
10
+
11
+ from .._run_context import AgentDepsT, RunContext
12
+ from .._utils import get_async_lock
13
+ from ..exceptions import UserError
14
+ from .abstract import AbstractToolset, ToolsetTool
15
+
16
+
17
+ @dataclass
18
+ class _CombinedToolsetTool(ToolsetTool[AgentDepsT]):
19
 + """A tool definition for a combined toolset tool that keeps track of the source toolset and tool."""
20
+
21
+ source_toolset: AbstractToolset[AgentDepsT]
22
+ source_tool: ToolsetTool[AgentDepsT]
23
+
24
+
25
+ @dataclass
26
+ class CombinedToolset(AbstractToolset[AgentDepsT]):
27
+ """A toolset that combines multiple toolsets.
28
+
29
+ See [toolset docs](../toolsets.md#combining-toolsets) for more information.
30
+ """
31
+
32
+ toolsets: Sequence[AbstractToolset[AgentDepsT]]
33
+
34
+ _enter_lock: asyncio.Lock = field(compare=False, init=False)
35
+ _entered_count: int = field(init=False)
36
+ _exit_stack: AsyncExitStack | None = field(init=False)
37
+
38
+ def __post_init__(self):
39
+ self._enter_lock = get_async_lock()
40
+ self._entered_count = 0
41
+ self._exit_stack = None
42
+
43
+ async def __aenter__(self) -> Self:
44
+ async with self._enter_lock:
45
+ if self._entered_count == 0:
46
+ self._exit_stack = AsyncExitStack()
47
+ for toolset in self.toolsets:
48
+ await self._exit_stack.enter_async_context(toolset)
49
+ self._entered_count += 1
50
+ return self
51
+
52
+ async def __aexit__(self, *args: Any) -> bool | None:
53
+ async with self._enter_lock:
54
+ self._entered_count -= 1
55
+ if self._entered_count == 0 and self._exit_stack is not None:
56
+ await self._exit_stack.aclose()
57
+ self._exit_stack = None
58
+
59
+ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
60
+ toolsets_tools = await asyncio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets))
61
+ all_tools: dict[str, ToolsetTool[AgentDepsT]] = {}
62
+
63
+ for toolset, tools in zip(self.toolsets, toolsets_tools):
64
+ for name, tool in tools.items():
65
+ if existing_tools := all_tools.get(name):
66
+ raise UserError(
67
+ f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}'
68
+ )
69
+
70
+ all_tools[name] = _CombinedToolsetTool(
71
+ toolset=tool.toolset,
72
+ tool_def=tool.tool_def,
73
+ max_retries=tool.max_retries,
74
+ args_validator=tool.args_validator,
75
+ source_toolset=toolset,
76
+ source_tool=tool,
77
+ )
78
+ return all_tools
79
+
80
+ async def call_tool(
81
+ self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
82
+ ) -> Any:
83
+ assert isinstance(tool, _CombinedToolsetTool)
84
+ return await tool.source_toolset.call_tool(name, tool_args, ctx, tool.source_tool)
85
+
86
+ def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
87
+ for toolset in self.toolsets:
88
+ toolset.apply(visitor)
@@ -0,0 +1,38 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, replace
4
+ from typing import Any
5
+
6
+ from pydantic_core import SchemaValidator, core_schema
7
+
8
+ from .._run_context import AgentDepsT, RunContext
9
+ from ..tools import ToolDefinition
10
+ from .abstract import AbstractToolset, ToolsetTool
11
+
12
+ TOOL_SCHEMA_VALIDATOR = SchemaValidator(schema=core_schema.any_schema())
13
+
14
+
15
+ @dataclass
16
+ class DeferredToolset(AbstractToolset[AgentDepsT]):
17
+ """A toolset that holds deferred tools whose results will be produced outside of the Pydantic AI agent run in which they were called.
18
+
19
+ See [toolset docs](../toolsets.md#deferred-toolset), [`ToolDefinition.kind`][pydantic_ai.tools.ToolDefinition.kind], and [`DeferredToolCalls`][pydantic_ai.output.DeferredToolCalls] for more information.
20
+ """
21
+
22
+ tool_defs: list[ToolDefinition]
23
+
24
+ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
25
+ return {
26
+ tool_def.name: ToolsetTool(
27
+ toolset=self,
28
+ tool_def=replace(tool_def, kind='deferred'),
29
+ max_retries=0,
30
+ args_validator=TOOL_SCHEMA_VALIDATOR,
31
+ )
32
+ for tool_def in self.tool_defs
33
+ }
34
+
35
+ async def call_tool(
36
+ self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
37
+ ) -> Any:
38
+ raise NotImplementedError('Deferred tools cannot be called')
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Callable
5
+
6
+ from .._run_context import AgentDepsT, RunContext
7
+ from ..tools import ToolDefinition
8
+ from .abstract import ToolsetTool
9
+ from .wrapper import WrapperToolset
10
+
11
+
12
+ @dataclass
13
+ class FilteredToolset(WrapperToolset[AgentDepsT]):
14
+ """A toolset that filters the tools it contains using a filter function that takes the agent context and the tool definition.
15
+
16
+ See [toolset docs](../toolsets.md#filtering-tools) for more information.
17
+ """
18
+
19
+ filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool]
20
+
21
+ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
22
+ return {
23
+ name: tool for name, tool in (await super().get_tools(ctx)).items() if self.filter_func(ctx, tool.tool_def)
24
+ }