pydantic-ai-slim 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim has been flagged as potentially problematic; see the advisory details for more information.
- pydantic_ai/_a2a.py +6 -4
- pydantic_ai/_agent_graph.py +25 -32
- pydantic_ai/_cli.py +3 -3
- pydantic_ai/_output.py +8 -0
- pydantic_ai/_tool_manager.py +3 -0
- pydantic_ai/ag_ui.py +25 -14
- pydantic_ai/{agent.py → agent/__init__.py} +209 -1027
- pydantic_ai/agent/abstract.py +942 -0
- pydantic_ai/agent/wrapper.py +227 -0
- pydantic_ai/direct.py +9 -9
- pydantic_ai/durable_exec/__init__.py +0 -0
- pydantic_ai/durable_exec/temporal/__init__.py +83 -0
- pydantic_ai/durable_exec/temporal/_agent.py +699 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
- pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
- pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
- pydantic_ai/durable_exec/temporal/_model.py +168 -0
- pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
- pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
- pydantic_ai/ext/aci.py +10 -9
- pydantic_ai/ext/langchain.py +4 -2
- pydantic_ai/mcp.py +203 -75
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +65 -9
- pydantic_ai/models/anthropic.py +16 -7
- pydantic_ai/models/bedrock.py +8 -5
- pydantic_ai/models/cohere.py +1 -4
- pydantic_ai/models/fallback.py +4 -2
- pydantic_ai/models/function.py +9 -4
- pydantic_ai/models/gemini.py +15 -9
- pydantic_ai/models/google.py +18 -14
- pydantic_ai/models/groq.py +17 -14
- pydantic_ai/models/huggingface.py +18 -12
- pydantic_ai/models/instrumented.py +3 -1
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +12 -18
- pydantic_ai/models/openai.py +29 -26
- pydantic_ai/models/test.py +3 -0
- pydantic_ai/models/wrapper.py +6 -2
- pydantic_ai/profiles/openai.py +1 -1
- pydantic_ai/providers/google.py +7 -7
- pydantic_ai/result.py +21 -55
- pydantic_ai/run.py +357 -0
- pydantic_ai/tools.py +0 -1
- pydantic_ai/toolsets/__init__.py +2 -0
- pydantic_ai/toolsets/_dynamic.py +87 -0
- pydantic_ai/toolsets/abstract.py +23 -3
- pydantic_ai/toolsets/combined.py +19 -4
- pydantic_ai/toolsets/deferred.py +10 -2
- pydantic_ai/toolsets/function.py +23 -8
- pydantic_ai/toolsets/prefixed.py +4 -0
- pydantic_ai/toolsets/wrapper.py +14 -1
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +6 -4
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/RECORD +57 -44
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/run.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
from collections.abc import AsyncIterator
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
from typing import Any, Generic, overload
|
|
7
|
+
|
|
8
|
+
from typing_extensions import Literal
|
|
9
|
+
|
|
10
|
+
from pydantic_graph import End, GraphRun, GraphRunContext
|
|
11
|
+
|
|
12
|
+
from . import (
|
|
13
|
+
_agent_graph,
|
|
14
|
+
exceptions,
|
|
15
|
+
messages as _messages,
|
|
16
|
+
usage as _usage,
|
|
17
|
+
)
|
|
18
|
+
from .output import OutputDataT
|
|
19
|
+
from .result import FinalResult
|
|
20
|
+
from .tools import AgentDepsT
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclasses.dataclass(repr=False)
|
|
24
|
+
class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
25
|
+
"""A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent].
|
|
26
|
+
|
|
27
|
+
You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`.
|
|
28
|
+
|
|
29
|
+
Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an
|
|
30
|
+
[`End`][pydantic_graph.nodes.End] is reached, the run finishes and [`result`][pydantic_ai.agent.AgentRun.result]
|
|
31
|
+
becomes available.
|
|
32
|
+
|
|
33
|
+
Example:
|
|
34
|
+
```python
|
|
35
|
+
from pydantic_ai import Agent
|
|
36
|
+
|
|
37
|
+
agent = Agent('openai:gpt-4o')
|
|
38
|
+
|
|
39
|
+
async def main():
|
|
40
|
+
nodes = []
|
|
41
|
+
# Iterate through the run, recording each node along the way:
|
|
42
|
+
async with agent.iter('What is the capital of France?') as agent_run:
|
|
43
|
+
async for node in agent_run:
|
|
44
|
+
nodes.append(node)
|
|
45
|
+
print(nodes)
|
|
46
|
+
'''
|
|
47
|
+
[
|
|
48
|
+
UserPromptNode(
|
|
49
|
+
user_prompt='What is the capital of France?',
|
|
50
|
+
instructions=None,
|
|
51
|
+
instructions_functions=[],
|
|
52
|
+
system_prompts=(),
|
|
53
|
+
system_prompt_functions=[],
|
|
54
|
+
system_prompt_dynamic_functions={},
|
|
55
|
+
),
|
|
56
|
+
ModelRequestNode(
|
|
57
|
+
request=ModelRequest(
|
|
58
|
+
parts=[
|
|
59
|
+
UserPromptPart(
|
|
60
|
+
content='What is the capital of France?',
|
|
61
|
+
timestamp=datetime.datetime(...),
|
|
62
|
+
)
|
|
63
|
+
]
|
|
64
|
+
)
|
|
65
|
+
),
|
|
66
|
+
CallToolsNode(
|
|
67
|
+
model_response=ModelResponse(
|
|
68
|
+
parts=[TextPart(content='The capital of France is Paris.')],
|
|
69
|
+
usage=Usage(
|
|
70
|
+
requests=1, request_tokens=56, response_tokens=7, total_tokens=63
|
|
71
|
+
),
|
|
72
|
+
model_name='gpt-4o',
|
|
73
|
+
timestamp=datetime.datetime(...),
|
|
74
|
+
)
|
|
75
|
+
),
|
|
76
|
+
End(data=FinalResult(output='The capital of France is Paris.')),
|
|
77
|
+
]
|
|
78
|
+
'''
|
|
79
|
+
print(agent_run.result.output)
|
|
80
|
+
#> The capital of France is Paris.
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
You can also manually drive the iteration using the [`next`][pydantic_ai.agent.AgentRun.next] method for
|
|
84
|
+
more granular control.
|
|
85
|
+
"""
|
|
86
|
+
|
|
87
|
+
_graph_run: GraphRun[
|
|
88
|
+
_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[OutputDataT]
|
|
89
|
+
]
|
|
90
|
+
|
|
91
|
+
@overload
|
|
92
|
+
def _traceparent(self, *, required: Literal[False]) -> str | None: ...
|
|
93
|
+
@overload
|
|
94
|
+
def _traceparent(self) -> str: ...
|
|
95
|
+
def _traceparent(self, *, required: bool = True) -> str | None:
|
|
96
|
+
traceparent = self._graph_run._traceparent(required=False) # type: ignore[reportPrivateUsage]
|
|
97
|
+
if traceparent is None and required: # pragma: no cover
|
|
98
|
+
raise AttributeError('No span was created for this agent run')
|
|
99
|
+
return traceparent
|
|
100
|
+
|
|
101
|
+
@property
|
|
102
|
+
def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]:
|
|
103
|
+
"""The current context of the agent run."""
|
|
104
|
+
return GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]](
|
|
105
|
+
self._graph_run.state, self._graph_run.deps
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
@property
|
|
109
|
+
def next_node(
|
|
110
|
+
self,
|
|
111
|
+
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
|
|
112
|
+
"""The next node that will be run in the agent graph.
|
|
113
|
+
|
|
114
|
+
This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
|
|
115
|
+
"""
|
|
116
|
+
next_node = self._graph_run.next_node
|
|
117
|
+
if isinstance(next_node, End):
|
|
118
|
+
return next_node
|
|
119
|
+
if _agent_graph.is_agent_node(next_node):
|
|
120
|
+
return next_node
|
|
121
|
+
raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover
|
|
122
|
+
|
|
123
|
+
@property
|
|
124
|
+
def result(self) -> AgentRunResult[OutputDataT] | None:
|
|
125
|
+
"""The final result of the run if it has ended, otherwise `None`.
|
|
126
|
+
|
|
127
|
+
Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated
|
|
128
|
+
with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult].
|
|
129
|
+
"""
|
|
130
|
+
graph_run_result = self._graph_run.result
|
|
131
|
+
if graph_run_result is None:
|
|
132
|
+
return None
|
|
133
|
+
return AgentRunResult(
|
|
134
|
+
graph_run_result.output.output,
|
|
135
|
+
graph_run_result.output.tool_name,
|
|
136
|
+
graph_run_result.state,
|
|
137
|
+
self._graph_run.deps.new_message_index,
|
|
138
|
+
self._traceparent(required=False),
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
def __aiter__(
|
|
142
|
+
self,
|
|
143
|
+
) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]]:
|
|
144
|
+
"""Provide async-iteration over the nodes in the agent run."""
|
|
145
|
+
return self
|
|
146
|
+
|
|
147
|
+
async def __anext__(
|
|
148
|
+
self,
|
|
149
|
+
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
|
|
150
|
+
"""Advance to the next node automatically based on the last returned node."""
|
|
151
|
+
next_node = await self._graph_run.__anext__()
|
|
152
|
+
if _agent_graph.is_agent_node(node=next_node):
|
|
153
|
+
return next_node
|
|
154
|
+
assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
|
|
155
|
+
return next_node
|
|
156
|
+
|
|
157
|
+
async def next(
|
|
158
|
+
self,
|
|
159
|
+
node: _agent_graph.AgentNode[AgentDepsT, OutputDataT],
|
|
160
|
+
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
|
|
161
|
+
"""Manually drive the agent run by passing in the node you want to run next.
|
|
162
|
+
|
|
163
|
+
This lets you inspect or mutate the node before continuing execution, or skip certain nodes
|
|
164
|
+
under dynamic conditions. The agent run should be stopped when you return an [`End`][pydantic_graph.nodes.End]
|
|
165
|
+
node.
|
|
166
|
+
|
|
167
|
+
Example:
|
|
168
|
+
```python
|
|
169
|
+
from pydantic_ai import Agent
|
|
170
|
+
from pydantic_graph import End
|
|
171
|
+
|
|
172
|
+
agent = Agent('openai:gpt-4o')
|
|
173
|
+
|
|
174
|
+
async def main():
|
|
175
|
+
async with agent.iter('What is the capital of France?') as agent_run:
|
|
176
|
+
next_node = agent_run.next_node # start with the first node
|
|
177
|
+
nodes = [next_node]
|
|
178
|
+
while not isinstance(next_node, End):
|
|
179
|
+
next_node = await agent_run.next(next_node)
|
|
180
|
+
nodes.append(next_node)
|
|
181
|
+
# Once `next_node` is an End, we've finished:
|
|
182
|
+
print(nodes)
|
|
183
|
+
'''
|
|
184
|
+
[
|
|
185
|
+
UserPromptNode(
|
|
186
|
+
user_prompt='What is the capital of France?',
|
|
187
|
+
instructions=None,
|
|
188
|
+
instructions_functions=[],
|
|
189
|
+
system_prompts=(),
|
|
190
|
+
system_prompt_functions=[],
|
|
191
|
+
system_prompt_dynamic_functions={},
|
|
192
|
+
),
|
|
193
|
+
ModelRequestNode(
|
|
194
|
+
request=ModelRequest(
|
|
195
|
+
parts=[
|
|
196
|
+
UserPromptPart(
|
|
197
|
+
content='What is the capital of France?',
|
|
198
|
+
timestamp=datetime.datetime(...),
|
|
199
|
+
)
|
|
200
|
+
]
|
|
201
|
+
)
|
|
202
|
+
),
|
|
203
|
+
CallToolsNode(
|
|
204
|
+
model_response=ModelResponse(
|
|
205
|
+
parts=[TextPart(content='The capital of France is Paris.')],
|
|
206
|
+
usage=Usage(
|
|
207
|
+
requests=1,
|
|
208
|
+
request_tokens=56,
|
|
209
|
+
response_tokens=7,
|
|
210
|
+
total_tokens=63,
|
|
211
|
+
),
|
|
212
|
+
model_name='gpt-4o',
|
|
213
|
+
timestamp=datetime.datetime(...),
|
|
214
|
+
)
|
|
215
|
+
),
|
|
216
|
+
End(data=FinalResult(output='The capital of France is Paris.')),
|
|
217
|
+
]
|
|
218
|
+
'''
|
|
219
|
+
print('Final result:', agent_run.result.output)
|
|
220
|
+
#> Final result: The capital of France is Paris.
|
|
221
|
+
```
|
|
222
|
+
|
|
223
|
+
Args:
|
|
224
|
+
node: The node to run next in the graph.
|
|
225
|
+
|
|
226
|
+
Returns:
|
|
227
|
+
The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if
|
|
228
|
+
the run has completed.
|
|
229
|
+
"""
|
|
230
|
+
# Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it
|
|
231
|
+
# on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate.
|
|
232
|
+
next_node = await self._graph_run.next(node)
|
|
233
|
+
if _agent_graph.is_agent_node(next_node):
|
|
234
|
+
return next_node
|
|
235
|
+
assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
|
|
236
|
+
return next_node
|
|
237
|
+
|
|
238
|
+
def usage(self) -> _usage.Usage:
|
|
239
|
+
"""Get usage statistics for the run so far, including token usage, model requests, and so on."""
|
|
240
|
+
return self._graph_run.state.usage
|
|
241
|
+
|
|
242
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
243
|
+
result = self._graph_run.result
|
|
244
|
+
result_repr = '<run not finished>' if result is None else repr(result.output)
|
|
245
|
+
return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>'
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@dataclasses.dataclass
|
|
249
|
+
class AgentRunResult(Generic[OutputDataT]):
|
|
250
|
+
"""The final result of an agent run."""
|
|
251
|
+
|
|
252
|
+
output: OutputDataT
|
|
253
|
+
"""The output data from the agent run."""
|
|
254
|
+
|
|
255
|
+
_output_tool_name: str | None = dataclasses.field(repr=False)
|
|
256
|
+
_state: _agent_graph.GraphAgentState = dataclasses.field(repr=False)
|
|
257
|
+
_new_message_index: int = dataclasses.field(repr=False)
|
|
258
|
+
_traceparent_value: str | None = dataclasses.field(repr=False)
|
|
259
|
+
|
|
260
|
+
@overload
|
|
261
|
+
def _traceparent(self, *, required: Literal[False]) -> str | None: ...
|
|
262
|
+
@overload
|
|
263
|
+
def _traceparent(self) -> str: ...
|
|
264
|
+
def _traceparent(self, *, required: bool = True) -> str | None:
|
|
265
|
+
if self._traceparent_value is None and required: # pragma: no cover
|
|
266
|
+
raise AttributeError('No span was created for this agent run')
|
|
267
|
+
return self._traceparent_value
|
|
268
|
+
|
|
269
|
+
def _set_output_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
|
|
270
|
+
"""Set return content for the output tool.
|
|
271
|
+
|
|
272
|
+
Useful if you want to continue the conversation and want to set the response to the output tool call.
|
|
273
|
+
"""
|
|
274
|
+
if not self._output_tool_name:
|
|
275
|
+
raise ValueError('Cannot set output tool return content when the return type is `str`.')
|
|
276
|
+
|
|
277
|
+
messages = self._state.message_history
|
|
278
|
+
last_message = messages[-1]
|
|
279
|
+
for idx, part in enumerate(last_message.parts):
|
|
280
|
+
if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._output_tool_name:
|
|
281
|
+
# Only do deepcopy when we have to modify
|
|
282
|
+
copied_messages = list(messages)
|
|
283
|
+
copied_last = deepcopy(last_message)
|
|
284
|
+
copied_last.parts[idx].content = return_content # type: ignore[misc]
|
|
285
|
+
copied_messages[-1] = copied_last
|
|
286
|
+
return copied_messages
|
|
287
|
+
|
|
288
|
+
raise LookupError(f'No tool call found with tool name {self._output_tool_name!r}.')
|
|
289
|
+
|
|
290
|
+
def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
|
|
291
|
+
"""Return the history of _messages.
|
|
292
|
+
|
|
293
|
+
Args:
|
|
294
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
295
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
296
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
297
|
+
not be modified.
|
|
298
|
+
|
|
299
|
+
Returns:
|
|
300
|
+
List of messages.
|
|
301
|
+
"""
|
|
302
|
+
if output_tool_return_content is not None:
|
|
303
|
+
return self._set_output_tool_return(output_tool_return_content)
|
|
304
|
+
else:
|
|
305
|
+
return self._state.message_history
|
|
306
|
+
|
|
307
|
+
def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes:
|
|
308
|
+
"""Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes.
|
|
309
|
+
|
|
310
|
+
Args:
|
|
311
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
312
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
313
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
314
|
+
not be modified.
|
|
315
|
+
|
|
316
|
+
Returns:
|
|
317
|
+
JSON bytes representing the messages.
|
|
318
|
+
"""
|
|
319
|
+
return _messages.ModelMessagesTypeAdapter.dump_json(
|
|
320
|
+
self.all_messages(output_tool_return_content=output_tool_return_content)
|
|
321
|
+
)
|
|
322
|
+
|
|
323
|
+
def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
|
|
324
|
+
"""Return new messages associated with this run.
|
|
325
|
+
|
|
326
|
+
Messages from older runs are excluded.
|
|
327
|
+
|
|
328
|
+
Args:
|
|
329
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
330
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
331
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
332
|
+
not be modified.
|
|
333
|
+
|
|
334
|
+
Returns:
|
|
335
|
+
List of new messages.
|
|
336
|
+
"""
|
|
337
|
+
return self.all_messages(output_tool_return_content=output_tool_return_content)[self._new_message_index :]
|
|
338
|
+
|
|
339
|
+
def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes:
|
|
340
|
+
"""Return new messages from [`new_messages`][pydantic_ai.agent.AgentRunResult.new_messages] as JSON bytes.
|
|
341
|
+
|
|
342
|
+
Args:
|
|
343
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
344
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
345
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
346
|
+
not be modified.
|
|
347
|
+
|
|
348
|
+
Returns:
|
|
349
|
+
JSON bytes representing the new messages.
|
|
350
|
+
"""
|
|
351
|
+
return _messages.ModelMessagesTypeAdapter.dump_json(
|
|
352
|
+
self.new_messages(output_tool_return_content=output_tool_return_content)
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
def usage(self) -> _usage.Usage:
|
|
356
|
+
"""Return the usage of the whole run."""
|
|
357
|
+
return self._state.usage
|
pydantic_ai/tools.py
CHANGED
pydantic_ai/toolsets/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from ._dynamic import ToolsetFunc
|
|
1
2
|
from .abstract import AbstractToolset, ToolsetTool
|
|
2
3
|
from .combined import CombinedToolset
|
|
3
4
|
from .deferred import DeferredToolset
|
|
@@ -10,6 +11,7 @@ from .wrapper import WrapperToolset
|
|
|
10
11
|
|
|
11
12
|
__all__ = (
|
|
12
13
|
'AbstractToolset',
|
|
14
|
+
'ToolsetFunc',
|
|
13
15
|
'ToolsetTool',
|
|
14
16
|
'CombinedToolset',
|
|
15
17
|
'DeferredToolset',
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import Awaitable
|
|
5
|
+
from dataclasses import dataclass, replace
|
|
6
|
+
from typing import Any, Callable, Union
|
|
7
|
+
|
|
8
|
+
from typing_extensions import Self, TypeAlias
|
|
9
|
+
|
|
10
|
+
from .._run_context import AgentDepsT, RunContext
|
|
11
|
+
from .abstract import AbstractToolset, ToolsetTool
|
|
12
|
+
|
|
13
|
+
ToolsetFunc: TypeAlias = Callable[
|
|
14
|
+
[RunContext[AgentDepsT]],
|
|
15
|
+
Union[AbstractToolset[AgentDepsT], None, Awaitable[Union[AbstractToolset[AgentDepsT], None]]],
|
|
16
|
+
]
|
|
17
|
+
"""A sync/async function which takes a run context and returns a toolset."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
|
|
21
|
+
class DynamicToolset(AbstractToolset[AgentDepsT]):
|
|
22
|
+
"""A toolset that dynamically builds a toolset using a function that takes the run context.
|
|
23
|
+
|
|
24
|
+
It should only be used during a single agent run as it stores the generated toolset.
|
|
25
|
+
To use it multiple times, copy it using `dataclasses.replace`.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
toolset_func: ToolsetFunc[AgentDepsT]
|
|
29
|
+
per_run_step: bool = True
|
|
30
|
+
|
|
31
|
+
_toolset: AbstractToolset[AgentDepsT] | None = None
|
|
32
|
+
_run_step: int | None = None
|
|
33
|
+
|
|
34
|
+
@property
|
|
35
|
+
def id(self) -> str | None:
|
|
36
|
+
return None # pragma: no cover
|
|
37
|
+
|
|
38
|
+
async def __aenter__(self) -> Self:
|
|
39
|
+
return self
|
|
40
|
+
|
|
41
|
+
async def __aexit__(self, *args: Any) -> bool | None:
|
|
42
|
+
try:
|
|
43
|
+
if self._toolset is not None:
|
|
44
|
+
return await self._toolset.__aexit__(*args)
|
|
45
|
+
finally:
|
|
46
|
+
self._toolset = None
|
|
47
|
+
self._run_step = None
|
|
48
|
+
|
|
49
|
+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
|
|
50
|
+
if self._toolset is None or (self.per_run_step and ctx.run_step != self._run_step):
|
|
51
|
+
if self._toolset is not None:
|
|
52
|
+
await self._toolset.__aexit__()
|
|
53
|
+
|
|
54
|
+
toolset = self.toolset_func(ctx)
|
|
55
|
+
if inspect.isawaitable(toolset):
|
|
56
|
+
toolset = await toolset
|
|
57
|
+
|
|
58
|
+
if toolset is not None:
|
|
59
|
+
await toolset.__aenter__()
|
|
60
|
+
|
|
61
|
+
self._toolset = toolset
|
|
62
|
+
self._run_step = ctx.run_step
|
|
63
|
+
|
|
64
|
+
if self._toolset is None:
|
|
65
|
+
return {}
|
|
66
|
+
|
|
67
|
+
return await self._toolset.get_tools(ctx)
|
|
68
|
+
|
|
69
|
+
async def call_tool(
|
|
70
|
+
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
|
|
71
|
+
) -> Any:
|
|
72
|
+
assert self._toolset is not None
|
|
73
|
+
return await self._toolset.call_tool(name, tool_args, ctx, tool)
|
|
74
|
+
|
|
75
|
+
def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
|
|
76
|
+
if self._toolset is None:
|
|
77
|
+
super().apply(visitor)
|
|
78
|
+
else:
|
|
79
|
+
self._toolset.apply(visitor)
|
|
80
|
+
|
|
81
|
+
def visit_and_replace(
|
|
82
|
+
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
|
|
83
|
+
) -> AbstractToolset[AgentDepsT]:
|
|
84
|
+
if self._toolset is None:
|
|
85
|
+
return super().visit_and_replace(visitor)
|
|
86
|
+
else:
|
|
87
|
+
return replace(self, _toolset=self._toolset.visit_and_replace(visitor))
|
pydantic_ai/toolsets/abstract.py
CHANGED
|
@@ -70,9 +70,23 @@ class AbstractToolset(ABC, Generic[AgentDepsT]):
|
|
|
70
70
|
"""
|
|
71
71
|
|
|
72
72
|
@property
|
|
73
|
-
|
|
73
|
+
@abstractmethod
|
|
74
|
+
def id(self) -> str | None:
|
|
75
|
+
"""An ID for the toolset that is unique among all toolsets registered with the same agent.
|
|
76
|
+
|
|
77
|
+
If you're implementing a concrete implementation that users can instantiate more than once, you should let them optionally pass a custom ID to the constructor and return that here.
|
|
78
|
+
|
|
79
|
+
A toolset needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow.
|
|
80
|
+
"""
|
|
81
|
+
raise NotImplementedError()
|
|
82
|
+
|
|
83
|
+
@property
|
|
84
|
+
def label(self) -> str:
|
|
74
85
|
"""The name of the toolset for use in error messages."""
|
|
75
|
-
|
|
86
|
+
label = self.__class__.__name__
|
|
87
|
+
if self.id: # pragma: no branch
|
|
88
|
+
label += f' {self.id!r}'
|
|
89
|
+
return label
|
|
76
90
|
|
|
77
91
|
@property
|
|
78
92
|
def tool_name_conflict_hint(self) -> str:
|
|
@@ -113,9 +127,15 @@ class AbstractToolset(ABC, Generic[AgentDepsT]):
|
|
|
113
127
|
raise NotImplementedError()
|
|
114
128
|
|
|
115
129
|
def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
|
|
116
|
-
"""Run a visitor function on all
|
|
130
|
+
"""Run a visitor function on all "leaf" toolsets (i.e. those that implement their own tool listing and calling)."""
|
|
117
131
|
visitor(self)
|
|
118
132
|
|
|
133
|
+
def visit_and_replace(
|
|
134
|
+
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
|
|
135
|
+
) -> AbstractToolset[AgentDepsT]:
|
|
136
|
+
"""Run a visitor function on all "leaf" toolsets (i.e. those that implement their own tool listing and calling) and replace them in the hierarchy with the result of the function."""
|
|
137
|
+
return visitor(self)
|
|
138
|
+
|
|
119
139
|
def filtered(
|
|
120
140
|
self, filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool]
|
|
121
141
|
) -> FilteredToolset[AgentDepsT]:
|
pydantic_ai/toolsets/combined.py
CHANGED
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import asyncio
|
|
4
4
|
from collections.abc import Sequence
|
|
5
5
|
from contextlib import AsyncExitStack
|
|
6
|
-
from dataclasses import dataclass, field
|
|
6
|
+
from dataclasses import dataclass, field, replace
|
|
7
7
|
from typing import Any, Callable
|
|
8
8
|
|
|
9
9
|
from typing_extensions import Self
|
|
@@ -40,6 +40,14 @@ class CombinedToolset(AbstractToolset[AgentDepsT]):
|
|
|
40
40
|
self._entered_count = 0
|
|
41
41
|
self._exit_stack = None
|
|
42
42
|
|
|
43
|
+
@property
|
|
44
|
+
def id(self) -> str | None:
|
|
45
|
+
return None # pragma: no cover
|
|
46
|
+
|
|
47
|
+
@property
|
|
48
|
+
def label(self) -> str:
|
|
49
|
+
return f'{self.__class__.__name__}({", ".join(toolset.label for toolset in self.toolsets)})' # pragma: no cover
|
|
50
|
+
|
|
43
51
|
async def __aenter__(self) -> Self:
|
|
44
52
|
async with self._enter_lock:
|
|
45
53
|
if self._entered_count == 0:
|
|
@@ -63,13 +71,15 @@ class CombinedToolset(AbstractToolset[AgentDepsT]):
|
|
|
63
71
|
|
|
64
72
|
for toolset, tools in zip(self.toolsets, toolsets_tools):
|
|
65
73
|
for name, tool in tools.items():
|
|
66
|
-
|
|
74
|
+
tool_toolset = tool.toolset
|
|
75
|
+
if existing_tool := all_tools.get(name):
|
|
76
|
+
capitalized_toolset_label = tool_toolset.label[0].upper() + tool_toolset.label[1:]
|
|
67
77
|
raise UserError(
|
|
68
|
-
f'{
|
|
78
|
+
f'{capitalized_toolset_label} defines a tool whose name conflicts with existing tool from {existing_tool.toolset.label}: {name!r}. {toolset.tool_name_conflict_hint}'
|
|
69
79
|
)
|
|
70
80
|
|
|
71
81
|
all_tools[name] = _CombinedToolsetTool(
|
|
72
|
-
toolset=
|
|
82
|
+
toolset=tool_toolset,
|
|
73
83
|
tool_def=tool.tool_def,
|
|
74
84
|
max_retries=tool.max_retries,
|
|
75
85
|
args_validator=tool.args_validator,
|
|
@@ -87,3 +97,8 @@ class CombinedToolset(AbstractToolset[AgentDepsT]):
|
|
|
87
97
|
def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
|
|
88
98
|
for toolset in self.toolsets:
|
|
89
99
|
toolset.apply(visitor)
|
|
100
|
+
|
|
101
|
+
def visit_and_replace(
|
|
102
|
+
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
|
|
103
|
+
) -> AbstractToolset[AgentDepsT]:
|
|
104
|
+
return replace(self, toolsets=[toolset.visit_and_replace(visitor) for toolset in self.toolsets])
|
pydantic_ai/toolsets/deferred.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from dataclasses import
|
|
3
|
+
from dataclasses import replace
|
|
4
4
|
from typing import Any
|
|
5
5
|
|
|
6
6
|
from pydantic_core import SchemaValidator, core_schema
|
|
@@ -12,7 +12,6 @@ from .abstract import AbstractToolset, ToolsetTool
|
|
|
12
12
|
TOOL_SCHEMA_VALIDATOR = SchemaValidator(schema=core_schema.any_schema())
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
@dataclass
|
|
16
15
|
class DeferredToolset(AbstractToolset[AgentDepsT]):
|
|
17
16
|
"""A toolset that holds deferred tools whose results will be produced outside of the Pydantic AI agent run in which they were called.
|
|
18
17
|
|
|
@@ -20,6 +19,15 @@ class DeferredToolset(AbstractToolset[AgentDepsT]):
|
|
|
20
19
|
"""
|
|
21
20
|
|
|
22
21
|
tool_defs: list[ToolDefinition]
|
|
22
|
+
_id: str | None
|
|
23
|
+
|
|
24
|
+
def __init__(self, tool_defs: list[ToolDefinition], *, id: str | None = None):
|
|
25
|
+
self.tool_defs = tool_defs
|
|
26
|
+
self._id = id
|
|
27
|
+
|
|
28
|
+
@property
|
|
29
|
+
def id(self) -> str | None:
|
|
30
|
+
return self._id
|
|
23
31
|
|
|
24
32
|
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
|
|
25
33
|
return {
|
pydantic_ai/toolsets/function.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from collections.abc import Awaitable, Sequence
|
|
4
|
-
from dataclasses import dataclass,
|
|
4
|
+
from dataclasses import dataclass, replace
|
|
5
5
|
from typing import Any, Callable, overload
|
|
6
6
|
|
|
7
7
|
from pydantic.json_schema import GenerateJsonSchema
|
|
@@ -20,30 +20,40 @@ from .abstract import AbstractToolset, ToolsetTool
|
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
@dataclass
|
|
23
|
-
class
|
|
23
|
+
class FunctionToolsetTool(ToolsetTool[AgentDepsT]):
|
|
24
24
|
"""A tool definition for a function toolset tool that keeps track of the function to call."""
|
|
25
25
|
|
|
26
26
|
call_func: Callable[[dict[str, Any], RunContext[AgentDepsT]], Awaitable[Any]]
|
|
27
|
+
is_async: bool
|
|
27
28
|
|
|
28
29
|
|
|
29
|
-
@dataclass(init=False)
|
|
30
30
|
class FunctionToolset(AbstractToolset[AgentDepsT]):
|
|
31
31
|
"""A toolset that lets Python functions be used as tools.
|
|
32
32
|
|
|
33
33
|
See [toolset docs](../toolsets.md#function-toolset) for more information.
|
|
34
34
|
"""
|
|
35
35
|
|
|
36
|
-
max_retries: int
|
|
37
|
-
tools: dict[str, Tool[Any]]
|
|
36
|
+
max_retries: int
|
|
37
|
+
tools: dict[str, Tool[Any]]
|
|
38
|
+
_id: str | None
|
|
38
39
|
|
|
39
|
-
def __init__(
|
|
40
|
+
def __init__(
|
|
41
|
+
self,
|
|
42
|
+
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
|
|
43
|
+
max_retries: int = 1,
|
|
44
|
+
*,
|
|
45
|
+
id: str | None = None,
|
|
46
|
+
):
|
|
40
47
|
"""Build a new function toolset.
|
|
41
48
|
|
|
42
49
|
Args:
|
|
43
50
|
tools: The tools to add to the toolset.
|
|
44
51
|
max_retries: The maximum number of retries for each tool during a run.
|
|
52
|
+
id: An optional unique ID for the toolset. A toolset needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow.
|
|
45
53
|
"""
|
|
46
54
|
self.max_retries = max_retries
|
|
55
|
+
self._id = id
|
|
56
|
+
|
|
47
57
|
self.tools = {}
|
|
48
58
|
for tool in tools:
|
|
49
59
|
if isinstance(tool, Tool):
|
|
@@ -51,6 +61,10 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
|
|
|
51
61
|
else:
|
|
52
62
|
self.add_function(tool)
|
|
53
63
|
|
|
64
|
+
@property
|
|
65
|
+
def id(self) -> str | None:
|
|
66
|
+
return self._id
|
|
67
|
+
|
|
54
68
|
@overload
|
|
55
69
|
def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ...
|
|
56
70
|
|
|
@@ -222,17 +236,18 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
|
|
|
222
236
|
else:
|
|
223
237
|
raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.')
|
|
224
238
|
|
|
225
|
-
tools[new_name] =
|
|
239
|
+
tools[new_name] = FunctionToolsetTool(
|
|
226
240
|
toolset=self,
|
|
227
241
|
tool_def=tool_def,
|
|
228
242
|
max_retries=tool.max_retries if tool.max_retries is not None else self.max_retries,
|
|
229
243
|
args_validator=tool.function_schema.validator,
|
|
230
244
|
call_func=tool.function_schema.call,
|
|
245
|
+
is_async=tool.function_schema.is_async,
|
|
231
246
|
)
|
|
232
247
|
return tools
|
|
233
248
|
|
|
234
249
|
async def call_tool(
|
|
235
250
|
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
|
|
236
251
|
) -> Any:
|
|
237
|
-
assert isinstance(tool,
|
|
252
|
+
assert isinstance(tool, FunctionToolsetTool)
|
|
238
253
|
return await tool.call_func(tool_args, ctx)
|