pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic.

Files changed (75)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_a2a.py +1 -1
  3. pydantic_ai/_agent_graph.py +323 -156
  4. pydantic_ai/_function_schema.py +5 -5
  5. pydantic_ai/_griffe.py +2 -1
  6. pydantic_ai/_otel_messages.py +2 -2
  7. pydantic_ai/_output.py +31 -35
  8. pydantic_ai/_parts_manager.py +7 -5
  9. pydantic_ai/_run_context.py +3 -1
  10. pydantic_ai/_system_prompt.py +2 -2
  11. pydantic_ai/_tool_manager.py +32 -28
  12. pydantic_ai/_utils.py +14 -26
  13. pydantic_ai/ag_ui.py +82 -51
  14. pydantic_ai/agent/__init__.py +70 -9
  15. pydantic_ai/agent/abstract.py +35 -4
  16. pydantic_ai/agent/wrapper.py +6 -0
  17. pydantic_ai/builtin_tools.py +2 -2
  18. pydantic_ai/common_tools/duckduckgo.py +4 -2
  19. pydantic_ai/durable_exec/temporal/__init__.py +4 -2
  20. pydantic_ai/durable_exec/temporal/_agent.py +93 -11
  21. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  22. pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
  23. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  24. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  25. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  26. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  27. pydantic_ai/exceptions.py +45 -2
  28. pydantic_ai/format_prompt.py +2 -2
  29. pydantic_ai/mcp.py +15 -27
  30. pydantic_ai/messages.py +149 -42
  31. pydantic_ai/models/__init__.py +6 -4
  32. pydantic_ai/models/anthropic.py +9 -16
  33. pydantic_ai/models/bedrock.py +50 -56
  34. pydantic_ai/models/cohere.py +3 -3
  35. pydantic_ai/models/fallback.py +2 -2
  36. pydantic_ai/models/function.py +25 -23
  37. pydantic_ai/models/gemini.py +12 -13
  38. pydantic_ai/models/google.py +18 -4
  39. pydantic_ai/models/groq.py +126 -38
  40. pydantic_ai/models/huggingface.py +4 -4
  41. pydantic_ai/models/instrumented.py +35 -16
  42. pydantic_ai/models/mcp_sampling.py +3 -1
  43. pydantic_ai/models/mistral.py +6 -6
  44. pydantic_ai/models/openai.py +35 -40
  45. pydantic_ai/models/test.py +24 -4
  46. pydantic_ai/output.py +27 -32
  47. pydantic_ai/profiles/__init__.py +3 -3
  48. pydantic_ai/profiles/groq.py +1 -1
  49. pydantic_ai/profiles/openai.py +25 -4
  50. pydantic_ai/providers/__init__.py +4 -0
  51. pydantic_ai/providers/anthropic.py +2 -3
  52. pydantic_ai/providers/bedrock.py +3 -2
  53. pydantic_ai/providers/google_vertex.py +2 -1
  54. pydantic_ai/providers/groq.py +21 -2
  55. pydantic_ai/providers/litellm.py +134 -0
  56. pydantic_ai/result.py +144 -41
  57. pydantic_ai/retries.py +52 -31
  58. pydantic_ai/run.py +12 -5
  59. pydantic_ai/tools.py +127 -23
  60. pydantic_ai/toolsets/__init__.py +4 -1
  61. pydantic_ai/toolsets/_dynamic.py +4 -4
  62. pydantic_ai/toolsets/abstract.py +18 -2
  63. pydantic_ai/toolsets/approval_required.py +32 -0
  64. pydantic_ai/toolsets/combined.py +7 -12
  65. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  66. pydantic_ai/toolsets/filtered.py +1 -1
  67. pydantic_ai/toolsets/function.py +58 -21
  68. pydantic_ai/toolsets/wrapper.py +2 -1
  69. pydantic_ai/usage.py +44 -8
  70. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
  71. pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
  72. pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
  73. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
  74. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_function_schema.py CHANGED
@@ -5,10 +5,10 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
 
 from __future__ import annotations as _annotations
 
-from collections.abc import Awaitable
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
 from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any, Callable, Union, cast
+from typing import TYPE_CHECKING, Any, Concatenate, cast, get_origin
 
 from pydantic import ConfigDict
 from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -17,7 +17,7 @@ from pydantic.fields import FieldInfo
 from pydantic.json_schema import GenerateJsonSchema
 from pydantic.plugin._schema_validator import create_schema_validator
 from pydantic_core import SchemaValidator, core_schema
-from typing_extensions import Concatenate, ParamSpec, TypeIs, TypeVar, get_origin
+from typing_extensions import ParamSpec, TypeIs, TypeVar
 
 from ._griffe import doc_descriptions
 from ._run_context import RunContext
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
 __all__ = ('function_schema',)
 
 
-@dataclass
+@dataclass(kw_only=True)
 class FunctionSchema:
     """Internal information about a function schema."""
 
@@ -231,7 +231,7 @@ R = TypeVar('R')
 
 WithCtx = Callable[Concatenate[RunContext[Any], P], R]
 WithoutCtx = Callable[P, R]
-TargetFunc = Union[WithCtx[P, R], WithoutCtx[P, R]]
+TargetFunc = WithCtx[P, R] | WithoutCtx[P, R]
 
 
 def _takes_ctx(function: TargetFunc[P, R]) -> TypeIs[WithCtx[P, R]]:
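Two patterns in the hunks above recur throughout this release: `Union[X, Y]` annotations are replaced by PEP 604 `X | Y` syntax, and internal dataclasses such as `FunctionSchema` (and `RunContext` below) become keyword-only. A minimal sketch of what `kw_only=True` means for callers; the `Point` class is invented for illustration and is not pydantic-ai code:

```python
from dataclasses import dataclass


@dataclass(kw_only=True)
class Point:
    x: int
    y: int
    label: str = 'origin'


# Every field must now be passed by keyword, so fields can later be added or
# reordered without breaking existing call sites.
p = Point(x=1, y=2)
# Point(1, 2) would raise TypeError: positional arguments are not accepted
```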
pydantic_ai/_griffe.py CHANGED
@@ -2,9 +2,10 @@ from __future__ import annotations as _annotations
 
 import logging
 import re
+from collections.abc import Callable
 from contextlib import contextmanager
 from inspect import Signature
-from typing import TYPE_CHECKING, Any, Callable, Literal, cast
+from typing import TYPE_CHECKING, Any, Literal, cast
 
 from griffe import Docstring, DocstringSectionKind, Object as GriffeObject
 
pydantic_ai/_otel_messages.py CHANGED
@@ -5,10 +5,10 @@ Based on https://github.com/lmolkova/semantic-conventions/blob/eccd1f806e426a32c
 
 from __future__ import annotations
 
-from typing import Literal
+from typing import Literal, TypeAlias
 
 from pydantic import JsonValue
-from typing_extensions import NotRequired, TypeAlias, TypedDict
+from typing_extensions import NotRequired, TypedDict
 
 
 class TextPart(TypedDict):
pydantic_ai/_output.py CHANGED
@@ -3,9 +3,9 @@ from __future__ import annotations as _annotations
 import inspect
 import json
 from abc import ABC, abstractmethod
-from collections.abc import Awaitable, Sequence
+from collections.abc import Awaitable, Callable, Sequence
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload
+from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
 
 from pydantic import TypeAdapter, ValidationError
 from pydantic_core import SchemaValidator, to_json
@@ -15,7 +15,7 @@ from . import _function_schema, _utils, messages as _messages
 from ._run_context import AgentDepsT, RunContext
 from .exceptions import ModelRetry, ToolRetryError, UserError
 from .output import (
-    DeferredToolCalls,
+    DeferredToolRequests,
     NativeOutput,
     OutputDataT,
     OutputMode,
@@ -49,12 +49,12 @@ At some point, it may make sense to change the input to OutputValidatorFunc to b
 resolve these potential variance issues.
 """
 
-OutputValidatorFunc = Union[
-    Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv],
-    Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]],
-    Callable[[OutputDataT_inv], OutputDataT_inv],
-    Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]],
-]
+OutputValidatorFunc = (
+    Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv]
+    | Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]]
+    | Callable[[OutputDataT_inv], OutputDataT_inv]
+    | Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]]
+)
 """
 A function that always takes and returns the same type of data (which is the result type of an agent run), and:
 
@@ -196,7 +196,7 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
 
 @dataclass
 class BaseOutputSchema(ABC, Generic[OutputDataT]):
-    allows_deferred_tool_calls: bool
+    allows_deferred_tools: bool
 
     @abstractmethod
     def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
@@ -249,10 +249,10 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
         """Build an OutputSchema dataclass from an output type."""
         raw_outputs = _flatten_output_spec(output_spec)
 
-        outputs = [output for output in raw_outputs if output is not DeferredToolCalls]
-        allows_deferred_tool_calls = len(outputs) < len(raw_outputs)
-        if len(outputs) == 0 and allows_deferred_tool_calls:
-            raise UserError('At least one output type must be provided other than `DeferredToolCalls`.')
+        outputs = [output for output in raw_outputs if output is not DeferredToolRequests]
+        allows_deferred_tools = len(outputs) < len(raw_outputs)
+        if len(outputs) == 0 and allows_deferred_tools:
+            raise UserError('At least one output type must be provided other than `DeferredToolRequests`.')
 
         if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
             if len(outputs) > 1:
@@ -265,7 +265,7 @@
                     description=output.description,
                     strict=output.strict,
                 ),
-                allows_deferred_tool_calls=allows_deferred_tool_calls,
+                allows_deferred_tools=allows_deferred_tools,
             )
         elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
             if len(outputs) > 1:
@@ -278,7 +278,7 @@
                     description=output.description,
                 ),
                 template=output.template,
-                allows_deferred_tool_calls=allows_deferred_tool_calls,
+                allows_deferred_tools=allows_deferred_tools,
             )
 
         text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
@@ -313,21 +313,21 @@
 
             if toolset:
                 return ToolOrTextOutputSchema(
-                    processor=text_output_schema, toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls
+                    processor=text_output_schema,
+                    toolset=toolset,
+                    allows_deferred_tools=allows_deferred_tools,
                 )
             else:
-                return PlainTextOutputSchema(
-                    processor=text_output_schema, allows_deferred_tool_calls=allows_deferred_tool_calls
-                )
+                return PlainTextOutputSchema(processor=text_output_schema, allows_deferred_tools=allows_deferred_tools)
 
         if len(tool_outputs) > 0:
-            return ToolOutputSchema(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls)
+            return ToolOutputSchema(toolset=toolset, allows_deferred_tools=allows_deferred_tools)
 
         if len(other_outputs) > 0:
             schema = OutputSchemaWithoutMode(
                 processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
                 toolset=toolset,
-                allows_deferred_tool_calls=allows_deferred_tool_calls,
+                allows_deferred_tools=allows_deferred_tools,
             )
             if default_mode:
                 schema = schema.with_default_mode(default_mode)
@@ -371,23 +371,19 @@ class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
         self,
         processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT],
         toolset: OutputToolset[Any] | None,
-        allows_deferred_tool_calls: bool,
+        allows_deferred_tools: bool,
     ):
-        super().__init__(allows_deferred_tool_calls)
+        super().__init__(allows_deferred_tools)
         self.processor = processor
         self._toolset = toolset
 
     def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
         if mode == 'native':
-            return NativeOutputSchema(
-                processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
-            )
+            return NativeOutputSchema(processor=self.processor, allows_deferred_tools=self.allows_deferred_tools)
         elif mode == 'prompted':
-            return PromptedOutputSchema(
-                processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
-            )
+            return PromptedOutputSchema(processor=self.processor, allows_deferred_tools=self.allows_deferred_tools)
         elif mode == 'tool':
-            return ToolOutputSchema(toolset=self.toolset, allows_deferred_tool_calls=self.allows_deferred_tool_calls)
+            return ToolOutputSchema(toolset=self.toolset, allows_deferred_tools=self.allows_deferred_tools)
         else:
             assert_never(mode)
 
@@ -550,8 +546,8 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
 class ToolOutputSchema(OutputSchema[OutputDataT]):
     _toolset: OutputToolset[Any] | None
 
-    def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_calls: bool):
-        super().__init__(allows_deferred_tool_calls)
+    def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tools: bool):
+        super().__init__(allows_deferred_tools)
         self._toolset = toolset
 
     @property
@@ -575,9 +571,9 @@ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchem
         self,
         processor: PlainTextOutputProcessor[OutputDataT] | None,
         toolset: OutputToolset[Any] | None,
-        allows_deferred_tool_calls: bool,
+        allows_deferred_tools: bool,
     ):
-        super().__init__(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls)
+        super().__init__(toolset=toolset, allows_deferred_tools=allows_deferred_tools)
         self.processor = processor
 
     @property
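The `OutputValidatorFunc` alias above admits four callable shapes: sync or async, with or without a leading `RunContext` argument. A hedged sketch of two validators matching the first and last shapes; `Profile` and both function names are invented for illustration, and such functions would typically be registered through an agent's output-validator hook rather than called directly:

```python
from dataclasses import dataclass

from pydantic_ai import ModelRetry, RunContext


@dataclass
class Profile:
    """Stand-in output type for this illustration."""

    name: str


# Shape 1: takes RunContext and the output value, returns the same type.
def check_profile(ctx: RunContext[None], profile: Profile) -> Profile:
    if not profile.name.strip():
        raise ModelRetry('name must not be empty')  # ask the model to try again
    return profile


# Shape 4: async, no RunContext, still returns the type it received.
async def normalize_profile(profile: Profile) -> Profile:
    return Profile(name=profile.name.title())
```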
pydantic_ai/_parts_manager.py CHANGED
@@ -15,7 +15,7 @@ from __future__ import annotations as _annotations
 
 from collections.abc import Hashable
 from dataclasses import dataclass, field, replace
-from typing import Any, Union
+from typing import Any
 
 from pydantic_ai.exceptions import UnexpectedModelBehavior
 from pydantic_ai.messages import (
@@ -38,7 +38,7 @@ VendorId = Hashable
 Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
 """
 
-ManagedPart = Union[ModelResponsePart, ToolCallPartDelta]
+ManagedPart = ModelResponsePart | ToolCallPartDelta
 """
 A union of types that are managed by the ModelResponsePartsManager.
 Because many vendors have streaming APIs that may produce not-fully-formed tool calls,
@@ -154,6 +154,7 @@ class ModelResponsePartsManager:
         *,
         vendor_part_id: Hashable | None,
         content: str | None = None,
+        id: str | None = None,
         signature: str | None = None,
     ) -> ModelResponseStreamEvent:
         """Handle incoming thinking content, creating or updating a ThinkingPart in the manager as appropriate.
@@ -167,6 +168,7 @@
                 of thinking. If None, a new part will be created unless the latest part is already
                 a ThinkingPart.
             content: The thinking content to append to the appropriate ThinkingPart.
+            id: An optional id for the thinking part.
             signature: An optional signature for the thinking content.
 
         Returns:
@@ -197,7 +199,7 @@
             if content is not None:
                 # There is no existing thinking part that should be updated, so create a new one
                 new_part_index = len(self._parts)
-                part = ThinkingPart(content=content, signature=signature)
+                part = ThinkingPart(content=content, id=id, signature=signature)
                 if vendor_part_id is not None:  # pragma: no branch
                     self._vendor_id_to_part_index[vendor_part_id] = new_part_index
                 self._parts.append(part)
262
264
  if tool_name is None and self._parts:
263
265
  part_index = len(self._parts) - 1
264
266
  latest_part = self._parts[part_index]
265
- if isinstance(latest_part, (ToolCallPart, ToolCallPartDelta)): # pragma: no branch
267
+ if isinstance(latest_part, ToolCallPart | ToolCallPartDelta): # pragma: no branch
266
268
  existing_matching_part_and_index = latest_part, part_index
267
269
  else:
268
270
  # vendor_part_id is provided, so look up the corresponding part or delta
269
271
  part_index = self._vendor_id_to_part_index.get(vendor_part_id)
270
272
  if part_index is not None:
271
273
  existing_part = self._parts[part_index]
272
- if not isinstance(existing_part, (ToolCallPartDelta, ToolCallPart)):
274
+ if not isinstance(existing_part, ToolCallPartDelta | ToolCallPart):
273
275
  raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}')
274
276
  existing_matching_part_and_index = existing_part, part_index
275
277
 
pydantic_ai/_run_context.py CHANGED
@@ -18,7 +18,7 @@ AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
 """Type variable for agent dependencies."""
 
 
-@dataclasses.dataclass(repr=False)
+@dataclasses.dataclass(repr=False, kw_only=True)
 class RunContext(Generic[AgentDepsT]):
     """Information about the current call."""
 
@@ -46,5 +46,7 @@ class RunContext(Generic[AgentDepsT]):
     """Number of retries so far."""
     run_step: int = 0
     """The current step in the run."""
+    tool_call_approved: bool = False
+    """Whether a tool call that required approval has now been approved."""
 
     __repr__ = _utils.dataclasses_no_defaults_repr
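`RunContext` becomes keyword-only to construct and gains a `tool_call_approved` flag. A tool function can read the flag to decide whether to proceed once approval has been granted; a hedged sketch in which the tool body and the `delete_file` name are invented (the approval machinery itself lives in the new `toolsets/approval_required.py` listed above):

```python
from pydantic_ai import RunContext


# Illustrative context-aware tool: pydantic-ai passes RunContext as the first
# argument; everything this function does with it is made up for the example.
async def delete_file(ctx: RunContext[None], path: str) -> str:
    if not ctx.tool_call_approved:
        # A real agent would route this through an approval-required flow
        # instead of deleting anything.
        return f'refusing to delete {path!r} without approval'
    return f'deleted {path!r}'
```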
pydantic_ai/_system_prompt.py CHANGED
@@ -1,9 +1,9 @@
 from __future__ import annotations as _annotations
 
 import inspect
-from collections.abc import Awaitable
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, cast
+from typing import Any, Generic, cast
 
 from . import _utils
 from ._run_context import AgentDepsT, RunContext
pydantic_ai/_tool_manager.py CHANGED
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import json
-from collections.abc import Iterable
 from dataclasses import dataclass, field, replace
 from typing import Any, Generic
 
@@ -13,9 +12,9 @@ from . import messages as _messages
 from ._run_context import AgentDepsT, RunContext
 from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior
 from .messages import ToolCallPart
-from .output import DeferredToolCalls
 from .tools import ToolDefinition
 from .toolsets.abstract import AbstractToolset, ToolsetTool
+from .usage import UsageLimits
 
 
 @dataclass
@@ -68,7 +67,11 @@ class ToolManager(Generic[AgentDepsT]):
         return None
 
     async def handle_call(
-        self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
+        self,
+        call: ToolCallPart,
+        allow_partial: bool = False,
+        wrap_validation_errors: bool = True,
+        usage_limits: UsageLimits | None = None,
     ) -> Any:
         """Handle a tool call by validating the arguments, calling the tool, and handling retries.
 
@@ -76,13 +79,14 @@
             call: The tool call part to handle.
             allow_partial: Whether to allow partial validation of the tool arguments.
             wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
+            usage_limits: Optional usage limits to check before executing tools.
         """
         if self.tools is None or self.ctx is None:
             raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
 
         if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
-            # Output tool calls are not traced
-            return await self._call_tool(call, allow_partial, wrap_validation_errors)
+            # Output tool calls are not traced and not counted
+            return await self._call_tool(call, allow_partial, wrap_validation_errors, count_tool_usage=False)
         else:
             return await self._call_tool_traced(
                 call,
@@ -90,9 +94,17 @@
                 wrap_validation_errors,
                 self.ctx.tracer,
                 self.ctx.trace_include_content,
+                usage_limits,
             )
 
-    async def _call_tool(self, call: ToolCallPart, allow_partial: bool, wrap_validation_errors: bool) -> Any:
+    async def _call_tool(
+        self,
+        call: ToolCallPart,
+        allow_partial: bool,
+        wrap_validation_errors: bool,
+        usage_limits: UsageLimits | None = None,
+        count_tool_usage: bool = True,
+    ) -> Any:
         if self.tools is None or self.ctx is None:
             raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
 
@@ -106,6 +118,9 @@
                 msg = 'No tools available.'
                 raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
 
+            if tool.tool_def.defer:
+                raise RuntimeError('Deferred tools cannot be called')
+
             ctx = replace(
                 self.ctx,
                 tool_name=name,
@@ -120,7 +135,15 @@
             else:
                 args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
 
-            return await self.toolset.call_tool(name, args_dict, ctx, tool)
+            if usage_limits is not None and count_tool_usage:
+                usage_limits.check_before_tool_call(self.ctx.usage)
+
+            result = await self.toolset.call_tool(name, args_dict, ctx, tool)
+
+            if count_tool_usage:
+                self.ctx.usage.tool_calls += 1
+
+            return result
         except (ValidationError, ModelRetry) as e:
             max_retries = tool.max_retries if tool is not None else 1
             current_retry = self.ctx.retries.get(name, 0)
@@ -159,6 +182,7 @@
         wrap_validation_errors: bool,
         tracer: Tracer,
         include_content: bool = False,
+        usage_limits: UsageLimits | None = None,
     ) -> Any:
         """See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
         span_attributes = {
@@ -188,7 +212,7 @@
         }
         with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
             try:
-                tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
+                tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors, usage_limits)
             except ToolRetryError as e:
                 part = e.tool_retry
                 if include_content and span.is_recording():
@@ -204,23 +228,3 @@
             )
 
         return tool_result
-
-    def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None:
-        """Get the deferred tool calls from the model response parts."""
-        deferred_calls_and_defs = [
-            (part, tool_def)
-            for part in parts
-            if isinstance(part, _messages.ToolCallPart)
-            and (tool_def := self.get_tool_def(part.tool_name))
-            and tool_def.kind == 'deferred'
-        ]
-        if not deferred_calls_and_defs:
-            return None
-
-        deferred_calls: list[_messages.ToolCallPart] = []
-        deferred_tool_defs: dict[str, ToolDefinition] = {}
-        for part, tool_def in deferred_calls_and_defs:
-            deferred_calls.append(part)
-            deferred_tool_defs[part.tool_name] = tool_def
-
-        return DeferredToolCalls(deferred_calls, deferred_tool_defs)
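`ToolManager` now threads usage limits through tool execution: a counted call first goes through `usage_limits.check_before_tool_call(...)` and increments `usage.tool_calls` afterwards, while output tool calls are neither traced nor counted. A self-contained sketch of that check-then-count pattern with stand-in classes; these are not the pydantic-ai types, and only the method and counter names mirror the diff:

```python
from dataclasses import dataclass


@dataclass
class FakeUsage:
    tool_calls: int = 0


@dataclass
class FakeUsageLimits:
    tool_calls_limit: int | None = None  # hypothetical limit field

    def check_before_tool_call(self, usage: FakeUsage) -> None:
        # Refuse to start a call that would exceed the configured limit.
        if self.tool_calls_limit is not None and usage.tool_calls + 1 > self.tool_calls_limit:
            raise RuntimeError(f'would exceed tool call limit of {self.tool_calls_limit}')


def run_tool(usage: FakeUsage, limits: FakeUsageLimits | None, count: bool = True) -> str:
    if limits is not None and count:
        limits.check_before_tool_call(usage)
    result = 'tool result'      # the real code awaits toolset.call_tool(...) here
    if count:
        usage.tool_calls += 1   # incremented only after a successful call
    return result
```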
pydantic_ai/_utils.py CHANGED
@@ -4,34 +4,28 @@ import asyncio
 import functools
 import inspect
 import re
-import sys
 import time
 import uuid
-import warnings
-from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
+from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator
 from contextlib import asynccontextmanager, suppress
 from dataclasses import dataclass, fields, is_dataclass
 from datetime import datetime, timezone
 from functools import partial
 from types import GenericAlias
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload
+from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, get_args, get_origin, overload
 
 from anyio.to_thread import run_sync
 from pydantic import BaseModel, TypeAdapter
 from pydantic.json_schema import JsonSchemaValue
 from typing_extensions import (
     ParamSpec,
-    TypeAlias,
-    TypeGuard,
     TypeIs,
-    get_args,
-    get_origin,
     is_typeddict,
 )
 from typing_inspection import typing_objects
 from typing_inspection.introspection import is_union_origin
 
-from pydantic_graph._utils import AbstractSpan, get_event_loop
+from pydantic_graph._utils import AbstractSpan
 
 from . import exceptions
 
@@ -96,7 +90,7 @@ class Some(Generic[T]):
     value: T
 
 
-Option: TypeAlias = Union[Some[T], None]
+Option: TypeAlias = Some[T] | None
 """Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None`."""
 
 
@@ -459,28 +453,22 @@ def strip_markdown_fences(text: str) -> str:
     return text
 
 
+def _unwrap_annotated(tp: Any) -> Any:
+    origin = get_origin(tp)
+    while typing_objects.is_annotated(origin):
+        tp = tp.__origin__
+        origin = get_origin(tp)
+    return tp
+
+
 def get_union_args(tp: Any) -> tuple[Any, ...]:
     """Extract the arguments of a Union type if `tp` is a union, otherwise return an empty tuple."""
     if typing_objects.is_typealiastype(tp):
         tp = tp.__value__
 
+    tp = _unwrap_annotated(tp)
     origin = get_origin(tp)
     if is_union_origin(origin):
-        return get_args(tp)
+        return tuple(_unwrap_annotated(arg) for arg in get_args(tp))
     else:
         return ()
-
-
-# The `asyncio.Lock` `loop` argument was deprecated in 3.8 and removed in 3.10,
-# but 3.9 still needs it to have the intended behavior.
-
-if sys.version_info < (3, 10):
-
-    def get_async_lock() -> asyncio.Lock:  # pragma: lax no cover
-        with warnings.catch_warnings():
-            warnings.simplefilter('ignore', DeprecationWarning)
-            return asyncio.Lock(loop=get_event_loop())
-else:
-
-    def get_async_lock() -> asyncio.Lock:  # pragma: lax no cover
-        return asyncio.Lock()
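The new `_unwrap_annotated` helper makes `get_union_args` see through `Annotated[...]` wrappers both around a union and around its members. A standalone approximation using only the standard `typing` module (the real implementation relies on `typing_inspection`, as shown above); this is illustrative, not the library code:

```python
import types
from typing import Annotated, Union, get_args, get_origin


def unwrap_annotated(tp):
    """Strip any number of Annotated[...] layers, keeping the wrapped type."""
    while get_origin(tp) is Annotated:
        tp = get_args(tp)[0]
    return tp


def union_args(tp):
    tp = unwrap_annotated(tp)
    if get_origin(tp) in (Union, types.UnionType):  # typing.Union or X | Y
        return tuple(unwrap_annotated(arg) for arg in get_args(tp))
    return ()


# Annotated wrappers no longer hide the union or its members:
assert union_args(Annotated[int | str, 'meta']) == (int, str)
assert union_args(int | Annotated[str, 'meta']) == (int, str)
assert union_args(int) == ()
```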