pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/_agent_graph.py +219 -315
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +296 -226
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +369 -155
- pydantic_ai/common_tools/duckduckgo.py +5 -2
- pydantic_ai/exceptions.py +14 -2
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +19 -9
- pydantic_ai/models/__init__.py +43 -19
- pydantic_ai/models/anthropic.py +2 -2
- pydantic_ai/models/bedrock.py +1 -1
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +3 -11
- pydantic_ai/models/google.py +3 -12
- pydantic_ai/models/groq.py +2 -1
- pydantic_ai/models/huggingface.py +463 -0
- pydantic_ai/models/instrumented.py +1 -1
- pydantic_ai/models/mistral.py +3 -3
- pydantic_ai/models/openai.py +5 -5
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google.py +2 -2
- pydantic_ai/providers/google_vertex.py +10 -5
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/providers/huggingface.py +88 -0
- pydantic_ai/result.py +57 -33
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
- pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
- pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_parts_manager.py
CHANGED
@@ -2,10 +2,10 @@
 
 The manager tracks which parts (in particular, text and tool calls) correspond to which
 vendor-specific identifiers (e.g., `index`, `tool_call_id`, etc., as appropriate for a given model),
-and produces
+and produces Pydantic AI-format events as appropriate for consumers of the streaming APIs.
 
 The "vendor-specific identifiers" to use depend on the semantics of the responses of the responses from the vendor,
-and are tightly coupled to the specific model being used, and the
+and are tightly coupled to the specific model being used, and the Pydantic AI Model subclass implementation.
 
 This `ModelResponsePartsManager` is used in each of the subclasses of `StreamedResponse` as a way to consolidate
 event-emitting logic.
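For context on what "consumers of the streaming APIs" means here: the events assembled by `ModelResponsePartsManager` surface through an agent's streaming run. A minimal consumer-side sketch, assuming the public `Agent.run_stream` API and a placeholder model string:

import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # placeholder model string; any supported model works

async def main() -> None:
    # Each text delta below originates from a Pydantic AI-format event that the
    # parts manager produced from the vendor's streaming response.
    async with agent.run_stream('Tell me a joke.') as result:
        async for chunk in result.stream_text(delta=True):
            print(chunk, end='', flush=True)

asyncio.run(main())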
pydantic_ai/_run_context.py
CHANGED
@@ -5,6 +5,7 @@ from collections.abc import Sequence
 from dataclasses import field
 from typing import TYPE_CHECKING, Generic
 
+from opentelemetry.trace import NoOpTracer, Tracer
 from typing_extensions import TypeVar
 
 from . import _utils, messages as _messages
@@ -27,10 +28,16 @@ class RunContext(Generic[AgentDepsT]):
     """The model used in this run."""
     usage: Usage
     """LLM usage associated with the run."""
-    prompt: str | Sequence[_messages.UserContent] | None
+    prompt: str | Sequence[_messages.UserContent] | None = None
     """The original user prompt passed to the run."""
     messages: list[_messages.ModelMessage] = field(default_factory=list)
     """Messages exchanged in the conversation so far."""
+    tracer: Tracer = field(default_factory=NoOpTracer)
+    """The tracer to use for tracing the run."""
+    trace_include_content: bool = False
+    """Whether to include the content of the messages in the trace."""
+    retries: dict[str, int] = field(default_factory=dict)
+    """Number of retries for each tool so far."""
     tool_call_id: str | None = None
     """The ID of the tool call."""
     tool_name: str | None = None
@@ -40,17 +47,4 @@ class RunContext(Generic[AgentDepsT]):
     run_step: int = 0
     """The current step in the run."""
 
-    def replace_with(
-        self,
-        retry: int | None = None,
-        tool_name: str | None | _utils.Unset = _utils.UNSET,
-    ) -> RunContext[AgentDepsT]:
-        # Create a new `RunContext` a new `retry` value and `tool_name`.
-        kwargs = {}
-        if retry is not None:
-            kwargs['retry'] = retry
-        if tool_name is not _utils.UNSET:  # pragma: no branch
-            kwargs['tool_name'] = tool_name
-        return dataclasses.replace(self, **kwargs)
-
     __repr__ = _utils.dataclasses_no_defaults_repr
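With `replace_with` removed, callers derive per-call contexts with plain `dataclasses.replace`, and the new `retries` dict carries per-tool retry counts across the run (this is how the new `ToolManager` below uses it). A hedged sketch of the pattern, with placeholder values:

from dataclasses import replace

# `ctx` is an existing RunContext for the current run step; the tool name and call ID
# are placeholders for values taken from the model's tool call.
call_ctx = replace(
    ctx,
    tool_name='my_tool',
    tool_call_id='call_123',
    retry=ctx.retries.get('my_tool', 0),  # per-tool retry count from the new `retries` field
)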
pydantic_ai/_tool_manager.py
ADDED
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import json
+from collections.abc import Iterable
+from dataclasses import dataclass, replace
+from typing import Any, Generic
+
+from pydantic import ValidationError
+from typing_extensions import assert_never
+
+from pydantic_ai.output import DeferredToolCalls
+
+from . import messages as _messages
+from ._run_context import AgentDepsT, RunContext
+from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior
+from .messages import ToolCallPart
+from .tools import ToolDefinition
+from .toolsets.abstract import AbstractToolset, ToolsetTool
+
+
+@dataclass
+class ToolManager(Generic[AgentDepsT]):
+    """Manages tools for an agent run step. It caches the agent run's toolset's tool definitions and handles calling tools and retries."""
+
+    ctx: RunContext[AgentDepsT]
+    """The agent run context for a specific run step."""
+    toolset: AbstractToolset[AgentDepsT]
+    """The toolset that provides the tools for this run step."""
+    tools: dict[str, ToolsetTool[AgentDepsT]]
+    """The cached tools for this run step."""
+
+    @classmethod
+    async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
+        """Build a new tool manager for a specific run step."""
+        return cls(
+            ctx=ctx,
+            toolset=toolset,
+            tools=await toolset.get_tools(ctx),
+        )
+
+    async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
+        """Build a new tool manager for the next run step, carrying over the retries from the current run step."""
+        return await self.__class__.build(self.toolset, replace(ctx, retries=self.ctx.retries))
+
+    @property
+    def tool_defs(self) -> list[ToolDefinition]:
+        """The tool definitions for the tools in this tool manager."""
+        return [tool.tool_def for tool in self.tools.values()]
+
+    def get_tool_def(self, name: str) -> ToolDefinition | None:
+        """Get the tool definition for a given tool name, or `None` if the tool is unknown."""
+        try:
+            return self.tools[name].tool_def
+        except KeyError:
+            return None
+
+    async def handle_call(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
+        """Handle a tool call by validating the arguments, calling the tool, and handling retries.
+
+        Args:
+            call: The tool call part to handle.
+            allow_partial: Whether to allow partial validation of the tool arguments.
+        """
+        if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
+            # Output tool calls are not traced
+            return await self._call_tool(call, allow_partial)
+        else:
+            return await self._call_tool_traced(call, allow_partial)
+
+    async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
+        name = call.tool_name
+        tool = self.tools.get(name)
+        try:
+            if tool is None:
+                if self.tools:
+                    msg = f'Available tools: {", ".join(f"{name!r}" for name in self.tools.keys())}'
+                else:
+                    msg = 'No tools available.'
+                raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
+
+            ctx = replace(
+                self.ctx,
+                tool_name=name,
+                tool_call_id=call.tool_call_id,
+                retry=self.ctx.retries.get(name, 0),
+            )
+
+            pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
+            validator = tool.args_validator
+            if isinstance(call.args, str):
+                args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
+            else:
+                args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
+
+            output = await self.toolset.call_tool(name, args_dict, ctx, tool)
+        except (ValidationError, ModelRetry) as e:
+            max_retries = tool.max_retries if tool is not None else 1
+            current_retry = self.ctx.retries.get(name, 0)
+
+            if current_retry == max_retries:
+                raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
+            else:
+                if isinstance(e, ValidationError):
+                    m = _messages.RetryPromptPart(
+                        tool_name=name,
+                        content=e.errors(include_url=False, include_context=False),
+                        tool_call_id=call.tool_call_id,
+                    )
+                    e = ToolRetryError(m)
+                elif isinstance(e, ModelRetry):
+                    m = _messages.RetryPromptPart(
+                        tool_name=name,
+                        content=e.message,
+                        tool_call_id=call.tool_call_id,
+                    )
+                    e = ToolRetryError(m)
+                else:
+                    assert_never(e)
+
+                self.ctx.retries[name] = current_retry + 1
+                raise e
+        else:
+            self.ctx.retries.pop(name, None)
+            return output
+
+    async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
+        """See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
+        span_attributes = {
+            'gen_ai.tool.name': call.tool_name,
+            # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
+            'gen_ai.tool.call.id': call.tool_call_id,
+            **({'tool_arguments': call.args_as_json_str()} if self.ctx.trace_include_content else {}),
+            'logfire.msg': f'running tool: {call.tool_name}',
+            # add the JSON schema so these attributes are formatted nicely in Logfire
+            'logfire.json_schema': json.dumps(
+                {
+                    'type': 'object',
+                    'properties': {
+                        **(
+                            {
+                                'tool_arguments': {'type': 'object'},
+                                'tool_response': {'type': 'object'},
+                            }
+                            if self.ctx.trace_include_content
+                            else {}
+                        ),
+                        'gen_ai.tool.name': {},
+                        'gen_ai.tool.call.id': {},
+                    },
+                }
+            ),
+        }
+        with self.ctx.tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
+            try:
+                tool_result = await self._call_tool(call, allow_partial)
+            except ToolRetryError as e:
+                part = e.tool_retry
+                if self.ctx.trace_include_content and span.is_recording():
+                    span.set_attribute('tool_response', part.model_response())
+                raise e
+
+            if self.ctx.trace_include_content and span.is_recording():
+                span.set_attribute(
+                    'tool_response',
+                    tool_result
+                    if isinstance(tool_result, str)
+                    else _messages.tool_return_ta.dump_json(tool_result).decode(),
+                )
+
+            return tool_result
+
+    def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None:
+        """Get the deferred tool calls from the model response parts."""
+        deferred_calls_and_defs = [
+            (part, tool_def)
+            for part in parts
+            if isinstance(part, _messages.ToolCallPart)
+            and (tool_def := self.get_tool_def(part.tool_name))
+            and tool_def.kind == 'deferred'
+        ]
+        if not deferred_calls_and_defs:
+            return None
+
+        deferred_calls: list[_messages.ToolCallPart] = []
+        deferred_tool_defs: dict[str, ToolDefinition] = {}
+        for part, tool_def in deferred_calls_and_defs:
+            deferred_calls.append(part)
+            deferred_tool_defs[part.tool_name] = tool_def
+
+        return DeferredToolCalls(deferred_calls, deferred_tool_defs)
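Taken together, `ToolManager` is built once per run step, validates and dispatches each tool call, records retries on `ctx.retries`, and separates out deferred calls. A hedged usage sketch: only the methods shown above come from the diff, while `toolset`, `ctx`, `next_ctx`, and `calls` are assumed inputs.

# Hypothetical driver for one run step (not taken from the package).
manager = await ToolManager.build(toolset, ctx)

for call in calls:  # ToolCallPart instances from the model response
    try:
        result = await manager.handle_call(call)
    except ToolRetryError as e:
        retry_part = e.tool_retry  # RetryPromptPart sent back to the model

# Tool calls whose definition has kind == 'deferred' are collected rather than executed:
deferred = manager.get_deferred_tool_calls(calls)

# The next step reuses the toolset and carries the retry counts forward:
manager = await manager.for_run_step(next_ctx)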
pydantic_ai/_utils.py
CHANGED
@@ -4,8 +4,10 @@ import asyncio
 import functools
 import inspect
 import re
+import sys
 import time
 import uuid
+import warnings
 from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
 from contextlib import asynccontextmanager, suppress
 from dataclasses import dataclass, fields, is_dataclass
@@ -29,7 +31,7 @@ from typing_extensions import (
 from typing_inspection import typing_objects
 from typing_inspection.introspection import is_union_origin
 
-from pydantic_graph._utils import AbstractSpan
+from pydantic_graph._utils import AbstractSpan, get_event_loop
 
 from . import exceptions
 
@@ -461,3 +463,18 @@ def get_union_args(tp: Any) -> tuple[Any, ...]:
         return get_args(tp)
     else:
         return ()
+
+
+# The `asyncio.Lock` `loop` argument was deprecated in 3.8 and removed in 3.10,
+# but 3.9 still needs it to have the intended behavior.
+
+if sys.version_info < (3, 10):
+
+    def get_async_lock() -> asyncio.Lock:  # pragma: lax no cover
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore', DeprecationWarning)
+            return asyncio.Lock(loop=get_event_loop())
+else:
+
+    def get_async_lock() -> asyncio.Lock:  # pragma: lax no cover
+        return asyncio.Lock()
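The version gate exists because on Python 3.9 `asyncio.Lock()` binds to an event loop at construction time, so a lock created outside the eventual running loop needs the (deprecated) `loop` argument, while on 3.10+ the primitives bind to the running loop lazily. A small hedged usage sketch; the guarded resource and all names other than `get_async_lock` are illustrative:

from typing import Optional

# Illustrative only: guard lazy creation of a shared resource with a lock obtained
# from the version-aware helper above.
_shared_lock = get_async_lock()
_shared: Optional[object] = None

async def get_shared() -> object:
    global _shared
    async with _shared_lock:
        if _shared is None:
            _shared = object()  # stand-in for an expensive, async-created resource
        return _shared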