pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_a2a.py +1 -1
- pydantic_ai/_agent_graph.py +323 -156
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +7 -5
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +32 -28
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +82 -51
- pydantic_ai/agent/__init__.py +70 -9
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +4 -2
- pydantic_ai/durable_exec/temporal/_agent.py +93 -11
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +15 -27
- pydantic_ai/messages.py +149 -42
- pydantic_ai/models/__init__.py +6 -4
- pydantic_ai/models/anthropic.py +9 -16
- pydantic_ai/models/bedrock.py +50 -56
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +12 -13
- pydantic_ai/models/google.py +18 -4
- pydantic_ai/models/groq.py +126 -38
- pydantic_ai/models/huggingface.py +4 -4
- pydantic_ai/models/instrumented.py +35 -16
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +6 -6
- pydantic_ai/models/openai.py +35 -40
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/providers/google_vertex.py +2 -1
- pydantic_ai/providers/groq.py +21 -2
- pydantic_ai/providers/litellm.py +134 -0
- pydantic_ai/result.py +144 -41
- pydantic_ai/retries.py +52 -31
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +127 -23
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +58 -21
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +44 -8
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
- pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
- pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_function_schema.py
CHANGED
|
@@ -5,10 +5,10 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations as _annotations
|
|
7
7
|
|
|
8
|
-
from collections.abc import Awaitable
|
|
8
|
+
from collections.abc import Awaitable, Callable
|
|
9
9
|
from dataclasses import dataclass, field
|
|
10
10
|
from inspect import Parameter, signature
|
|
11
|
-
from typing import TYPE_CHECKING, Any,
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Concatenate, cast, get_origin
|
|
12
12
|
|
|
13
13
|
from pydantic import ConfigDict
|
|
14
14
|
from pydantic._internal import _decorators, _generate_schema, _typing_extra
|
|
@@ -17,7 +17,7 @@ from pydantic.fields import FieldInfo
|
|
|
17
17
|
from pydantic.json_schema import GenerateJsonSchema
|
|
18
18
|
from pydantic.plugin._schema_validator import create_schema_validator
|
|
19
19
|
from pydantic_core import SchemaValidator, core_schema
|
|
20
|
-
from typing_extensions import
|
|
20
|
+
from typing_extensions import ParamSpec, TypeIs, TypeVar
|
|
21
21
|
|
|
22
22
|
from ._griffe import doc_descriptions
|
|
23
23
|
from ._run_context import RunContext
|
|
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
|
|
30
30
|
__all__ = ('function_schema',)
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
@dataclass
|
|
33
|
+
@dataclass(kw_only=True)
|
|
34
34
|
class FunctionSchema:
|
|
35
35
|
"""Internal information about a function schema."""
|
|
36
36
|
|
|
@@ -231,7 +231,7 @@ R = TypeVar('R')
|
|
|
231
231
|
|
|
232
232
|
WithCtx = Callable[Concatenate[RunContext[Any], P], R]
|
|
233
233
|
WithoutCtx = Callable[P, R]
|
|
234
|
-
TargetFunc =
|
|
234
|
+
TargetFunc = WithCtx[P, R] | WithoutCtx[P, R]
|
|
235
235
|
|
|
236
236
|
|
|
237
237
|
def _takes_ctx(function: TargetFunc[P, R]) -> TypeIs[WithCtx[P, R]]:
|
pydantic_ai/_griffe.py
CHANGED
|
@@ -2,9 +2,10 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
import re
|
|
5
|
+
from collections.abc import Callable
|
|
5
6
|
from contextlib import contextmanager
|
|
6
7
|
from inspect import Signature
|
|
7
|
-
from typing import TYPE_CHECKING, Any,
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Literal, cast
|
|
8
9
|
|
|
9
10
|
from griffe import Docstring, DocstringSectionKind, Object as GriffeObject
|
|
10
11
|
|
pydantic_ai/_otel_messages.py
CHANGED
|
@@ -5,10 +5,10 @@ Based on https://github.com/lmolkova/semantic-conventions/blob/eccd1f806e426a32c
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
-
from typing import Literal
|
|
8
|
+
from typing import Literal, TypeAlias
|
|
9
9
|
|
|
10
10
|
from pydantic import JsonValue
|
|
11
|
-
from typing_extensions import NotRequired,
|
|
11
|
+
from typing_extensions import NotRequired, TypedDict
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class TextPart(TypedDict):
|
pydantic_ai/_output.py
CHANGED
|
@@ -3,9 +3,9 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import json
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
|
-
from collections.abc import Awaitable, Sequence
|
|
6
|
+
from collections.abc import Awaitable, Callable, Sequence
|
|
7
7
|
from dataclasses import dataclass, field
|
|
8
|
-
from typing import TYPE_CHECKING, Any,
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
|
|
9
9
|
|
|
10
10
|
from pydantic import TypeAdapter, ValidationError
|
|
11
11
|
from pydantic_core import SchemaValidator, to_json
|
|
@@ -15,7 +15,7 @@ from . import _function_schema, _utils, messages as _messages
|
|
|
15
15
|
from ._run_context import AgentDepsT, RunContext
|
|
16
16
|
from .exceptions import ModelRetry, ToolRetryError, UserError
|
|
17
17
|
from .output import (
|
|
18
|
-
|
|
18
|
+
DeferredToolRequests,
|
|
19
19
|
NativeOutput,
|
|
20
20
|
OutputDataT,
|
|
21
21
|
OutputMode,
|
|
@@ -49,12 +49,12 @@ At some point, it may make sense to change the input to OutputValidatorFunc to b
|
|
|
49
49
|
resolve these potential variance issues.
|
|
50
50
|
"""
|
|
51
51
|
|
|
52
|
-
OutputValidatorFunc =
|
|
53
|
-
Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv]
|
|
54
|
-
Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]]
|
|
55
|
-
Callable[[OutputDataT_inv], OutputDataT_inv]
|
|
56
|
-
Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]]
|
|
57
|
-
|
|
52
|
+
OutputValidatorFunc = (
|
|
53
|
+
Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv]
|
|
54
|
+
| Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]]
|
|
55
|
+
| Callable[[OutputDataT_inv], OutputDataT_inv]
|
|
56
|
+
| Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]]
|
|
57
|
+
)
|
|
58
58
|
"""
|
|
59
59
|
A function that always takes and returns the same type of data (which is the result type of an agent run), and:
|
|
60
60
|
|
|
@@ -196,7 +196,7 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
|
|
|
196
196
|
|
|
197
197
|
@dataclass
|
|
198
198
|
class BaseOutputSchema(ABC, Generic[OutputDataT]):
|
|
199
|
-
|
|
199
|
+
allows_deferred_tools: bool
|
|
200
200
|
|
|
201
201
|
@abstractmethod
|
|
202
202
|
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
@@ -249,10 +249,10 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
249
249
|
"""Build an OutputSchema dataclass from an output type."""
|
|
250
250
|
raw_outputs = _flatten_output_spec(output_spec)
|
|
251
251
|
|
|
252
|
-
outputs = [output for output in raw_outputs if output is not
|
|
253
|
-
|
|
254
|
-
if len(outputs) == 0 and
|
|
255
|
-
raise UserError('At least one output type must be provided other than `
|
|
252
|
+
outputs = [output for output in raw_outputs if output is not DeferredToolRequests]
|
|
253
|
+
allows_deferred_tools = len(outputs) < len(raw_outputs)
|
|
254
|
+
if len(outputs) == 0 and allows_deferred_tools:
|
|
255
|
+
raise UserError('At least one output type must be provided other than `DeferredToolRequests`.')
|
|
256
256
|
|
|
257
257
|
if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
|
|
258
258
|
if len(outputs) > 1:
|
|
@@ -265,7 +265,7 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
265
265
|
description=output.description,
|
|
266
266
|
strict=output.strict,
|
|
267
267
|
),
|
|
268
|
-
|
|
268
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
269
269
|
)
|
|
270
270
|
elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
|
|
271
271
|
if len(outputs) > 1:
|
|
@@ -278,7 +278,7 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
278
278
|
description=output.description,
|
|
279
279
|
),
|
|
280
280
|
template=output.template,
|
|
281
|
-
|
|
281
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
282
282
|
)
|
|
283
283
|
|
|
284
284
|
text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
|
|
@@ -313,21 +313,21 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
313
313
|
|
|
314
314
|
if toolset:
|
|
315
315
|
return ToolOrTextOutputSchema(
|
|
316
|
-
processor=text_output_schema,
|
|
316
|
+
processor=text_output_schema,
|
|
317
|
+
toolset=toolset,
|
|
318
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
317
319
|
)
|
|
318
320
|
else:
|
|
319
|
-
return PlainTextOutputSchema(
|
|
320
|
-
processor=text_output_schema, allows_deferred_tool_calls=allows_deferred_tool_calls
|
|
321
|
-
)
|
|
321
|
+
return PlainTextOutputSchema(processor=text_output_schema, allows_deferred_tools=allows_deferred_tools)
|
|
322
322
|
|
|
323
323
|
if len(tool_outputs) > 0:
|
|
324
|
-
return ToolOutputSchema(toolset=toolset,
|
|
324
|
+
return ToolOutputSchema(toolset=toolset, allows_deferred_tools=allows_deferred_tools)
|
|
325
325
|
|
|
326
326
|
if len(other_outputs) > 0:
|
|
327
327
|
schema = OutputSchemaWithoutMode(
|
|
328
328
|
processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
|
|
329
329
|
toolset=toolset,
|
|
330
|
-
|
|
330
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
331
331
|
)
|
|
332
332
|
if default_mode:
|
|
333
333
|
schema = schema.with_default_mode(default_mode)
|
|
@@ -371,23 +371,19 @@ class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
|
|
|
371
371
|
self,
|
|
372
372
|
processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT],
|
|
373
373
|
toolset: OutputToolset[Any] | None,
|
|
374
|
-
|
|
374
|
+
allows_deferred_tools: bool,
|
|
375
375
|
):
|
|
376
|
-
super().__init__(
|
|
376
|
+
super().__init__(allows_deferred_tools)
|
|
377
377
|
self.processor = processor
|
|
378
378
|
self._toolset = toolset
|
|
379
379
|
|
|
380
380
|
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
381
381
|
if mode == 'native':
|
|
382
|
-
return NativeOutputSchema(
|
|
383
|
-
processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
|
|
384
|
-
)
|
|
382
|
+
return NativeOutputSchema(processor=self.processor, allows_deferred_tools=self.allows_deferred_tools)
|
|
385
383
|
elif mode == 'prompted':
|
|
386
|
-
return PromptedOutputSchema(
|
|
387
|
-
processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
|
|
388
|
-
)
|
|
384
|
+
return PromptedOutputSchema(processor=self.processor, allows_deferred_tools=self.allows_deferred_tools)
|
|
389
385
|
elif mode == 'tool':
|
|
390
|
-
return ToolOutputSchema(toolset=self.toolset,
|
|
386
|
+
return ToolOutputSchema(toolset=self.toolset, allows_deferred_tools=self.allows_deferred_tools)
|
|
391
387
|
else:
|
|
392
388
|
assert_never(mode)
|
|
393
389
|
|
|
@@ -550,8 +546,8 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
|
550
546
|
class ToolOutputSchema(OutputSchema[OutputDataT]):
|
|
551
547
|
_toolset: OutputToolset[Any] | None
|
|
552
548
|
|
|
553
|
-
def __init__(self, toolset: OutputToolset[Any] | None,
|
|
554
|
-
super().__init__(
|
|
549
|
+
def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tools: bool):
|
|
550
|
+
super().__init__(allows_deferred_tools)
|
|
555
551
|
self._toolset = toolset
|
|
556
552
|
|
|
557
553
|
@property
|
|
@@ -575,9 +571,9 @@ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchem
|
|
|
575
571
|
self,
|
|
576
572
|
processor: PlainTextOutputProcessor[OutputDataT] | None,
|
|
577
573
|
toolset: OutputToolset[Any] | None,
|
|
578
|
-
|
|
574
|
+
allows_deferred_tools: bool,
|
|
579
575
|
):
|
|
580
|
-
super().__init__(toolset=toolset,
|
|
576
|
+
super().__init__(toolset=toolset, allows_deferred_tools=allows_deferred_tools)
|
|
581
577
|
self.processor = processor
|
|
582
578
|
|
|
583
579
|
@property
|
pydantic_ai/_parts_manager.py
CHANGED
|
@@ -15,7 +15,7 @@ from __future__ import annotations as _annotations
|
|
|
15
15
|
|
|
16
16
|
from collections.abc import Hashable
|
|
17
17
|
from dataclasses import dataclass, field, replace
|
|
18
|
-
from typing import Any
|
|
18
|
+
from typing import Any
|
|
19
19
|
|
|
20
20
|
from pydantic_ai.exceptions import UnexpectedModelBehavior
|
|
21
21
|
from pydantic_ai.messages import (
|
|
@@ -38,7 +38,7 @@ VendorId = Hashable
|
|
|
38
38
|
Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
|
|
39
39
|
"""
|
|
40
40
|
|
|
41
|
-
ManagedPart =
|
|
41
|
+
ManagedPart = ModelResponsePart | ToolCallPartDelta
|
|
42
42
|
"""
|
|
43
43
|
A union of types that are managed by the ModelResponsePartsManager.
|
|
44
44
|
Because many vendors have streaming APIs that may produce not-fully-formed tool calls,
|
|
@@ -154,6 +154,7 @@ class ModelResponsePartsManager:
|
|
|
154
154
|
*,
|
|
155
155
|
vendor_part_id: Hashable | None,
|
|
156
156
|
content: str | None = None,
|
|
157
|
+
id: str | None = None,
|
|
157
158
|
signature: str | None = None,
|
|
158
159
|
) -> ModelResponseStreamEvent:
|
|
159
160
|
"""Handle incoming thinking content, creating or updating a ThinkingPart in the manager as appropriate.
|
|
@@ -167,6 +168,7 @@ class ModelResponsePartsManager:
|
|
|
167
168
|
of thinking. If None, a new part will be created unless the latest part is already
|
|
168
169
|
a ThinkingPart.
|
|
169
170
|
content: The thinking content to append to the appropriate ThinkingPart.
|
|
171
|
+
id: An optional id for the thinking part.
|
|
170
172
|
signature: An optional signature for the thinking content.
|
|
171
173
|
|
|
172
174
|
Returns:
|
|
@@ -197,7 +199,7 @@ class ModelResponsePartsManager:
|
|
|
197
199
|
if content is not None:
|
|
198
200
|
# There is no existing thinking part that should be updated, so create a new one
|
|
199
201
|
new_part_index = len(self._parts)
|
|
200
|
-
part = ThinkingPart(content=content, signature=signature)
|
|
202
|
+
part = ThinkingPart(content=content, id=id, signature=signature)
|
|
201
203
|
if vendor_part_id is not None: # pragma: no branch
|
|
202
204
|
self._vendor_id_to_part_index[vendor_part_id] = new_part_index
|
|
203
205
|
self._parts.append(part)
|
|
@@ -262,14 +264,14 @@ class ModelResponsePartsManager:
|
|
|
262
264
|
if tool_name is None and self._parts:
|
|
263
265
|
part_index = len(self._parts) - 1
|
|
264
266
|
latest_part = self._parts[part_index]
|
|
265
|
-
if isinstance(latest_part,
|
|
267
|
+
if isinstance(latest_part, ToolCallPart | ToolCallPartDelta): # pragma: no branch
|
|
266
268
|
existing_matching_part_and_index = latest_part, part_index
|
|
267
269
|
else:
|
|
268
270
|
# vendor_part_id is provided, so look up the corresponding part or delta
|
|
269
271
|
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
|
|
270
272
|
if part_index is not None:
|
|
271
273
|
existing_part = self._parts[part_index]
|
|
272
|
-
if not isinstance(existing_part,
|
|
274
|
+
if not isinstance(existing_part, ToolCallPartDelta | ToolCallPart):
|
|
273
275
|
raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}')
|
|
274
276
|
existing_matching_part_and_index = existing_part, part_index
|
|
275
277
|
|
pydantic_ai/_run_context.py
CHANGED
|
@@ -18,7 +18,7 @@ AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
|
|
|
18
18
|
"""Type variable for agent dependencies."""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
@dataclasses.dataclass(repr=False)
|
|
21
|
+
@dataclasses.dataclass(repr=False, kw_only=True)
|
|
22
22
|
class RunContext(Generic[AgentDepsT]):
|
|
23
23
|
"""Information about the current call."""
|
|
24
24
|
|
|
@@ -46,5 +46,7 @@ class RunContext(Generic[AgentDepsT]):
|
|
|
46
46
|
"""Number of retries so far."""
|
|
47
47
|
run_step: int = 0
|
|
48
48
|
"""The current step in the run."""
|
|
49
|
+
tool_call_approved: bool = False
|
|
50
|
+
"""Whether a tool call that required approval has now been approved."""
|
|
49
51
|
|
|
50
52
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
pydantic_ai/_system_prompt.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
-
from collections.abc import Awaitable
|
|
4
|
+
from collections.abc import Awaitable, Callable
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
|
-
from typing import Any,
|
|
6
|
+
from typing import Any, Generic, cast
|
|
7
7
|
|
|
8
8
|
from . import _utils
|
|
9
9
|
from ._run_context import AgentDepsT, RunContext
|
pydantic_ai/_tool_manager.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from collections.abc import Iterable
|
|
5
4
|
from dataclasses import dataclass, field, replace
|
|
6
5
|
from typing import Any, Generic
|
|
7
6
|
|
|
@@ -13,9 +12,9 @@ from . import messages as _messages
|
|
|
13
12
|
from ._run_context import AgentDepsT, RunContext
|
|
14
13
|
from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior
|
|
15
14
|
from .messages import ToolCallPart
|
|
16
|
-
from .output import DeferredToolCalls
|
|
17
15
|
from .tools import ToolDefinition
|
|
18
16
|
from .toolsets.abstract import AbstractToolset, ToolsetTool
|
|
17
|
+
from .usage import UsageLimits
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
@dataclass
|
|
@@ -68,7 +67,11 @@ class ToolManager(Generic[AgentDepsT]):
|
|
|
68
67
|
return None
|
|
69
68
|
|
|
70
69
|
async def handle_call(
|
|
71
|
-
self,
|
|
70
|
+
self,
|
|
71
|
+
call: ToolCallPart,
|
|
72
|
+
allow_partial: bool = False,
|
|
73
|
+
wrap_validation_errors: bool = True,
|
|
74
|
+
usage_limits: UsageLimits | None = None,
|
|
72
75
|
) -> Any:
|
|
73
76
|
"""Handle a tool call by validating the arguments, calling the tool, and handling retries.
|
|
74
77
|
|
|
@@ -76,13 +79,14 @@ class ToolManager(Generic[AgentDepsT]):
|
|
|
76
79
|
call: The tool call part to handle.
|
|
77
80
|
allow_partial: Whether to allow partial validation of the tool arguments.
|
|
78
81
|
wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
|
|
82
|
+
usage_limits: Optional usage limits to check before executing tools.
|
|
79
83
|
"""
|
|
80
84
|
if self.tools is None or self.ctx is None:
|
|
81
85
|
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
|
|
82
86
|
|
|
83
87
|
if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
|
|
84
|
-
# Output tool calls are not traced
|
|
85
|
-
return await self._call_tool(call, allow_partial, wrap_validation_errors)
|
|
88
|
+
# Output tool calls are not traced and not counted
|
|
89
|
+
return await self._call_tool(call, allow_partial, wrap_validation_errors, count_tool_usage=False)
|
|
86
90
|
else:
|
|
87
91
|
return await self._call_tool_traced(
|
|
88
92
|
call,
|
|
@@ -90,9 +94,17 @@ class ToolManager(Generic[AgentDepsT]):
|
|
|
90
94
|
wrap_validation_errors,
|
|
91
95
|
self.ctx.tracer,
|
|
92
96
|
self.ctx.trace_include_content,
|
|
97
|
+
usage_limits,
|
|
93
98
|
)
|
|
94
99
|
|
|
95
|
-
async def _call_tool(
|
|
100
|
+
async def _call_tool(
|
|
101
|
+
self,
|
|
102
|
+
call: ToolCallPart,
|
|
103
|
+
allow_partial: bool,
|
|
104
|
+
wrap_validation_errors: bool,
|
|
105
|
+
usage_limits: UsageLimits | None = None,
|
|
106
|
+
count_tool_usage: bool = True,
|
|
107
|
+
) -> Any:
|
|
96
108
|
if self.tools is None or self.ctx is None:
|
|
97
109
|
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
|
|
98
110
|
|
|
@@ -106,6 +118,9 @@ class ToolManager(Generic[AgentDepsT]):
|
|
|
106
118
|
msg = 'No tools available.'
|
|
107
119
|
raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
|
|
108
120
|
|
|
121
|
+
if tool.tool_def.defer:
|
|
122
|
+
raise RuntimeError('Deferred tools cannot be called')
|
|
123
|
+
|
|
109
124
|
ctx = replace(
|
|
110
125
|
self.ctx,
|
|
111
126
|
tool_name=name,
|
|
@@ -120,7 +135,15 @@ class ToolManager(Generic[AgentDepsT]):
|
|
|
120
135
|
else:
|
|
121
136
|
args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
|
|
122
137
|
|
|
123
|
-
|
|
138
|
+
if usage_limits is not None and count_tool_usage:
|
|
139
|
+
usage_limits.check_before_tool_call(self.ctx.usage)
|
|
140
|
+
|
|
141
|
+
result = await self.toolset.call_tool(name, args_dict, ctx, tool)
|
|
142
|
+
|
|
143
|
+
if count_tool_usage:
|
|
144
|
+
self.ctx.usage.tool_calls += 1
|
|
145
|
+
|
|
146
|
+
return result
|
|
124
147
|
except (ValidationError, ModelRetry) as e:
|
|
125
148
|
max_retries = tool.max_retries if tool is not None else 1
|
|
126
149
|
current_retry = self.ctx.retries.get(name, 0)
|
|
@@ -159,6 +182,7 @@ class ToolManager(Generic[AgentDepsT]):
|
|
|
159
182
|
wrap_validation_errors: bool,
|
|
160
183
|
tracer: Tracer,
|
|
161
184
|
include_content: bool = False,
|
|
185
|
+
usage_limits: UsageLimits | None = None,
|
|
162
186
|
) -> Any:
|
|
163
187
|
"""See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
|
|
164
188
|
span_attributes = {
|
|
@@ -188,7 +212,7 @@ class ToolManager(Generic[AgentDepsT]):
|
|
|
188
212
|
}
|
|
189
213
|
with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
|
|
190
214
|
try:
|
|
191
|
-
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
|
|
215
|
+
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors, usage_limits)
|
|
192
216
|
except ToolRetryError as e:
|
|
193
217
|
part = e.tool_retry
|
|
194
218
|
if include_content and span.is_recording():
|
|
@@ -204,23 +228,3 @@ class ToolManager(Generic[AgentDepsT]):
|
|
|
204
228
|
)
|
|
205
229
|
|
|
206
230
|
return tool_result
|
|
207
|
-
|
|
208
|
-
def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None:
|
|
209
|
-
"""Get the deferred tool calls from the model response parts."""
|
|
210
|
-
deferred_calls_and_defs = [
|
|
211
|
-
(part, tool_def)
|
|
212
|
-
for part in parts
|
|
213
|
-
if isinstance(part, _messages.ToolCallPart)
|
|
214
|
-
and (tool_def := self.get_tool_def(part.tool_name))
|
|
215
|
-
and tool_def.kind == 'deferred'
|
|
216
|
-
]
|
|
217
|
-
if not deferred_calls_and_defs:
|
|
218
|
-
return None
|
|
219
|
-
|
|
220
|
-
deferred_calls: list[_messages.ToolCallPart] = []
|
|
221
|
-
deferred_tool_defs: dict[str, ToolDefinition] = {}
|
|
222
|
-
for part, tool_def in deferred_calls_and_defs:
|
|
223
|
-
deferred_calls.append(part)
|
|
224
|
-
deferred_tool_defs[part.tool_name] = tool_def
|
|
225
|
-
|
|
226
|
-
return DeferredToolCalls(deferred_calls, deferred_tool_defs)
|
pydantic_ai/_utils.py
CHANGED
|
@@ -4,34 +4,28 @@ import asyncio
|
|
|
4
4
|
import functools
|
|
5
5
|
import inspect
|
|
6
6
|
import re
|
|
7
|
-
import sys
|
|
8
7
|
import time
|
|
9
8
|
import uuid
|
|
10
|
-
import
|
|
11
|
-
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
|
|
9
|
+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator
|
|
12
10
|
from contextlib import asynccontextmanager, suppress
|
|
13
11
|
from dataclasses import dataclass, fields, is_dataclass
|
|
14
12
|
from datetime import datetime, timezone
|
|
15
13
|
from functools import partial
|
|
16
14
|
from types import GenericAlias
|
|
17
|
-
from typing import TYPE_CHECKING, Any,
|
|
15
|
+
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, get_args, get_origin, overload
|
|
18
16
|
|
|
19
17
|
from anyio.to_thread import run_sync
|
|
20
18
|
from pydantic import BaseModel, TypeAdapter
|
|
21
19
|
from pydantic.json_schema import JsonSchemaValue
|
|
22
20
|
from typing_extensions import (
|
|
23
21
|
ParamSpec,
|
|
24
|
-
TypeAlias,
|
|
25
|
-
TypeGuard,
|
|
26
22
|
TypeIs,
|
|
27
|
-
get_args,
|
|
28
|
-
get_origin,
|
|
29
23
|
is_typeddict,
|
|
30
24
|
)
|
|
31
25
|
from typing_inspection import typing_objects
|
|
32
26
|
from typing_inspection.introspection import is_union_origin
|
|
33
27
|
|
|
34
|
-
from pydantic_graph._utils import AbstractSpan
|
|
28
|
+
from pydantic_graph._utils import AbstractSpan
|
|
35
29
|
|
|
36
30
|
from . import exceptions
|
|
37
31
|
|
|
@@ -96,7 +90,7 @@ class Some(Generic[T]):
|
|
|
96
90
|
value: T
|
|
97
91
|
|
|
98
92
|
|
|
99
|
-
Option: TypeAlias =
|
|
93
|
+
Option: TypeAlias = Some[T] | None
|
|
100
94
|
"""Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None`."""
|
|
101
95
|
|
|
102
96
|
|
|
@@ -459,28 +453,22 @@ def strip_markdown_fences(text: str) -> str:
|
|
|
459
453
|
return text
|
|
460
454
|
|
|
461
455
|
|
|
456
|
+
def _unwrap_annotated(tp: Any) -> Any:
|
|
457
|
+
origin = get_origin(tp)
|
|
458
|
+
while typing_objects.is_annotated(origin):
|
|
459
|
+
tp = tp.__origin__
|
|
460
|
+
origin = get_origin(tp)
|
|
461
|
+
return tp
|
|
462
|
+
|
|
463
|
+
|
|
462
464
|
def get_union_args(tp: Any) -> tuple[Any, ...]:
|
|
463
465
|
"""Extract the arguments of a Union type if `tp` is a union, otherwise return an empty tuple."""
|
|
464
466
|
if typing_objects.is_typealiastype(tp):
|
|
465
467
|
tp = tp.__value__
|
|
466
468
|
|
|
469
|
+
tp = _unwrap_annotated(tp)
|
|
467
470
|
origin = get_origin(tp)
|
|
468
471
|
if is_union_origin(origin):
|
|
469
|
-
return get_args(tp)
|
|
472
|
+
return tuple(_unwrap_annotated(arg) for arg in get_args(tp))
|
|
470
473
|
else:
|
|
471
474
|
return ()
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
# The `asyncio.Lock` `loop` argument was deprecated in 3.8 and removed in 3.10,
|
|
475
|
-
# but 3.9 still needs it to have the intended behavior.
|
|
476
|
-
|
|
477
|
-
if sys.version_info < (3, 10):
|
|
478
|
-
|
|
479
|
-
def get_async_lock() -> asyncio.Lock: # pragma: lax no cover
|
|
480
|
-
with warnings.catch_warnings():
|
|
481
|
-
warnings.simplefilter('ignore', DeprecationWarning)
|
|
482
|
-
return asyncio.Lock(loop=get_event_loop())
|
|
483
|
-
else:
|
|
484
|
-
|
|
485
|
-
def get_async_lock() -> asyncio.Lock: # pragma: lax no cover
|
|
486
|
-
return asyncio.Lock()
|