pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_agent_graph.py +310 -140
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +4 -4
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +3 -22
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +7 -8
- pydantic_ai/agent/__init__.py +84 -17
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +70 -17
- pydantic_ai/durable_exec/temporal/_agent.py +23 -2
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +2 -2
- pydantic_ai/messages.py +81 -28
- pydantic_ai/models/__init__.py +19 -7
- pydantic_ai/models/anthropic.py +6 -6
- pydantic_ai/models/bedrock.py +63 -57
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +10 -13
- pydantic_ai/models/google.py +4 -4
- pydantic_ai/models/groq.py +5 -5
- pydantic_ai/models/huggingface.py +5 -5
- pydantic_ai/models/instrumented.py +44 -21
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +8 -8
- pydantic_ai/models/openai.py +20 -29
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/result.py +173 -52
- pydantic_ai/retries.py +10 -29
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +126 -22
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +13 -4
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +7 -5
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/METADATA +6 -7
- pydantic_ai_slim-1.0.0b1.dist-info/RECORD +120 -0
- pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_agent_graph.py
CHANGED
@@ -4,14 +4,14 @@ import asyncio
 import dataclasses
 import hashlib
 from collections import defaultdict, deque
-from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast

 from opentelemetry.trace import Tracer
-from typing_extensions import
+from typing_extensions import TypeVar, assert_never

 from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx  # type: ignore
 from pydantic_ai._tool_manager import ToolManager
@@ -24,7 +24,14 @@ from . import _output, _system_prompt, exceptions, messages as _messages, models
 from .exceptions import ToolRetryError
 from .output import OutputDataT, OutputSpec
 from .settings import ModelSettings
-from .tools import
+from .tools import (
+    DeferredToolResult,
+    RunContext,
+    ToolApproved,
+    ToolDefinition,
+    ToolDenied,
+    ToolKind,
+)

 if TYPE_CHECKING:
     from .models.instrumented import InstrumentationSettings
@@ -59,19 +66,19 @@ _HistoryProcessorSyncWithCtx = Callable[[RunContext[DepsT], list[_messages.Model
 _HistoryProcessorAsyncWithCtx = Callable[
     [RunContext[DepsT], list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]]
 ]
-HistoryProcessor =
-    _HistoryProcessorSync
-    _HistoryProcessorAsync
-    _HistoryProcessorSyncWithCtx[DepsT]
-    _HistoryProcessorAsyncWithCtx[DepsT]
-
+HistoryProcessor = (
+    _HistoryProcessorSync
+    | _HistoryProcessorAsync
+    | _HistoryProcessorSyncWithCtx[DepsT]
+    | _HistoryProcessorAsyncWithCtx[DepsT]
+)
 """A function that processes a list of model messages and returns a list of model messages.

 Can optionally accept a `RunContext` as a parameter.
 """


-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class GraphAgentState:
     """State kept across the execution of the agent graph."""

@@ -92,7 +99,7 @@ class GraphAgentState:
         raise exceptions.UnexpectedModelBehavior(message)


-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     """Dependencies/config passed to the agent graph."""

@@ -115,9 +122,10 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):

     builtin_tools: list[AbstractBuiltinTool] = dataclasses.field(repr=False)
     tool_manager: ToolManager[DepsT]
+    tool_call_results: dict[str, DeferredToolResult] | None

     tracer: Tracer
-    instrumentation_settings: InstrumentationSettings | None
+    instrumentation_settings: InstrumentationSettings | None


 class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
@@ -149,6 +157,8 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):

     user_prompt: str | Sequence[_messages.UserContent] | None

+    _: dataclasses.KW_ONLY
+
     instructions: str | None
     instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]

@@ -158,7 +168,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):

     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) ->
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | CallToolsNode[DepsT, NodeRunEndT]:
         try:
             ctx_messages = get_captured_run_messages()
         except LookupError:
@@ -184,26 +194,29 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
         else:
             parts.extend(await self._sys_parts(run_context))

+        if (tool_call_results := ctx.deps.tool_call_results) is not None:
+            if messages and (last_message := messages[-1]) and isinstance(last_message, _messages.ModelRequest):
+                # If tool call results were provided, that means the previous run ended on deferred tool calls.
+                # That run would typically have ended on a `ModelResponse`, but if it had a mix of deferred tool calls and ones that could already be executed,
+                # a `ModelRequest` would already have been added to the history with the preliminary results, even if it wouldn't have been sent to the model yet.
+                # So now that we have all of the deferred results, we roll back to the last `ModelResponse` and store the contents of the `ModelRequest` on `deferred_tool_results` to be handled by `CallToolsNode`.
+                ctx.deps.tool_call_results = self._update_tool_call_results_from_model_request(
+                    tool_call_results, last_message
+                )
+                messages.pop()
+
+            if not messages:
+                raise exceptions.UserError('Tool call results were provided, but the message history is empty.')
+
         if messages and (last_message := messages[-1]):
             if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
                 # Drop last message from history and reuse its parts
                 messages.pop()
                 parts.extend(last_message.parts)
             elif isinstance(last_message, _messages.ModelResponse):
-
-
-
-                ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
-
-                # Skip ModelRequestNode and go directly to CallToolsNode
-                return CallToolsNode[DepsT, NodeRunEndT](model_response=last_message)
-            elif any(isinstance(part, _messages.ToolCallPart) for part in last_message.parts):
-                raise exceptions.UserError(
-                    'Cannot provide a new user prompt when the message history ends with '
-                    'a model response containing unprocessed tool calls. Either process the '
-                    'tool calls first (by calling `iter` with `user_prompt=None`) or append a '
-                    '`ModelRequest` with `ToolResultPart`s.'
-                )
+                call_tools_node = await self._handle_message_history_model_response(ctx, last_message)
+                if call_tools_node is not None:
+                    return call_tools_node

         if self.user_prompt is not None:
             parts.append(_messages.UserPromptPart(self.user_prompt))
@@ -213,6 +226,74 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):

         return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)

+    async def _handle_message_history_model_response(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        message: _messages.ModelResponse,
+    ) -> CallToolsNode[DepsT, NodeRunEndT] | None:
+        unprocessed_tool_calls = any(isinstance(part, _messages.ToolCallPart) for part in message.parts)
+        if unprocessed_tool_calls:
+            if self.user_prompt is not None:
+                raise exceptions.UserError(
+                    'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
+                )
+        else:
+            if ctx.deps.tool_call_results is not None:
+                raise exceptions.UserError(
+                    'Tool call results were provided, but the message history does not contain any unprocessed tool calls.'
+                )
+
+        if unprocessed_tool_calls or self.user_prompt is None:
+            # `CallToolsNode` requires the tool manager to be prepared for the run step
+            # This will raise errors for any tool name conflicts
+            run_context = build_run_context(ctx)
+            ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
+
+            # Skip ModelRequestNode and go directly to CallToolsNode
+            return CallToolsNode[DepsT, NodeRunEndT](model_response=message)
+
+    def _update_tool_call_results_from_model_request(
+        self, tool_call_results: dict[str, DeferredToolResult], message: _messages.ModelRequest
+    ) -> dict[str, DeferredToolResult]:
+        last_tool_return: _messages.ToolReturn | None = None
+        user_content: list[str | _messages.UserContent] = []
+        for part in message.parts:
+            if isinstance(part, _messages.ToolReturnPart):
+                if part.tool_call_id in tool_call_results:
+                    raise exceptions.UserError(
+                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
+                    )
+
+                last_tool_return = _messages.ToolReturn(return_value=part.content, metadata=part.metadata)
+                tool_call_results[part.tool_call_id] = last_tool_return
+            elif isinstance(part, _messages.RetryPromptPart):
+                if part.tool_call_id in tool_call_results:
+                    raise exceptions.UserError(
+                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
+                    )
+
+                tool_call_results[part.tool_call_id] = part
+            elif isinstance(part, _messages.UserPromptPart):
+                # Tools can return user parts via `ToolReturn.content` or by returning multi-modal content.
+                # These go together with a specific `ToolReturnPart`, but we don't have a way to know which,
+                # so (below) we just add them to the last one, matching the tool-results-before-user-parts order of the request.
+                if isinstance(part.content, str):
+                    user_content.append(part.content)
+                else:
+                    user_content.extend(part.content)
+            else:
+                raise exceptions.UserError(f'Unexpected message part type: {type(part)}')  # pragma: no cover
+
+        if user_content:
+            if last_tool_return is None:
+                raise exceptions.UserError(
+                    'Tool call results were provided, but the last message in the history was a `ModelRequest` with user parts not tied to preliminary tool results.'
+                )
+            assert last_tool_return is not None
+            last_tool_return.content = user_content
+
+        return tool_call_results
+
     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
     ) -> None:
@@ -280,8 +361,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):

     request: _messages.ModelRequest

-    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(
-    _did_stream: bool = field(
+    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None)
+    _did_stream: bool = field(repr=False, init=False, default=False)

     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -310,13 +391,13 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
             self._did_stream = True
             ctx.state.usage.requests += 1
             agent_stream = result.AgentStream[DepsT, T](
-                streamed_response,
-                ctx.deps.output_schema,
-                model_request_parameters,
-                ctx.deps.output_validators,
-                build_run_context(ctx),
-                ctx.deps.usage_limits,
-                ctx.deps.tool_manager,
+                _raw_stream_response=streamed_response,
+                _output_schema=ctx.deps.output_schema,
+                _model_request_parameters=model_request_parameters,
+                _output_validators=ctx.deps.output_validators,
+                _run_ctx=build_run_context(ctx),
+                _usage_limits=ctx.deps.usage_limits,
+                _tool_manager=ctx.deps.tool_manager,
             )
             yield agent_stream
             # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
@@ -396,14 +477,14 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):

     model_response: _messages.ModelResponse

-    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
+    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
     _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
-        default=None, repr=False
+        default=None, init=False, repr=False
     )

     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) ->
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
         async with self.stream(ctx):
             pass
         assert self._next_node is not None, 'the stream should set `self._next_node` before it ends'
@@ -506,13 +587,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         if output_final_result:
             final_result = output_final_result[0]
             self._next_node = self._handle_final_result(ctx, final_result, output_parts)
-        elif deferred_tool_calls := ctx.deps.tool_manager.get_deferred_tool_calls(tool_calls):
-            if not ctx.deps.output_schema.allows_deferred_tool_calls:
-                raise exceptions.UserError(
-                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
-                )
-            final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None)
-            self._next_node = self._handle_final_result(ctx, final_result, output_parts)
         else:
             instructions = await ctx.deps.get_instructions(run_context)
             self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
@@ -557,7 +631,7 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
             ctx.state.increment_retries(ctx.deps.max_result_retries, e)
             return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
         else:
-            return self._handle_final_result(ctx, result.FinalResult(result_data
+            return self._handle_final_result(ctx, result.FinalResult(result_data), [])


 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
@@ -572,6 +646,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
         trace_include_content=ctx.deps.instrumentation_settings is not None
         and ctx.deps.instrumentation_settings.include_content,
         run_step=ctx.state.run_step,
+        tool_call_approved=ctx.state.run_step == 0 and ctx.deps.tool_call_results is not None,
     )


@@ -599,7 +674,10 @@ async def process_function_tools(  # noqa: C901
     tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list)
     for call in tool_calls:
         tool_def = tool_manager.get_tool_def(call.tool_name)
-
+        if tool_def:
+            kind = tool_def.kind
+        else:
+            kind = 'unknown'
         tool_calls_by_kind[kind].append(call)

     # First, we handle output tool calls
@@ -662,132 +740,224 @@ async def process_function_tools(  # noqa: C901
         ctx.state.increment_retries(ctx.deps.max_result_retries)
         calls_to_run.extend(tool_calls_by_kind['unknown'])

-
-
+    deferred_tool_results: dict[str, DeferredToolResult] = {}
+    if build_run_context(ctx).tool_call_approved and ctx.deps.tool_call_results is not None:
+        deferred_tool_results = ctx.deps.tool_call_results

-
+        # Deferred tool calls are "run" as well, by reading their value from the tool call results
+        calls_to_run.extend(tool_calls_by_kind['external'])
+        calls_to_run.extend(tool_calls_by_kind['unapproved'])

-
-
-
-
-
-
-
-            'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
-        },
-    ):
-        tasks = [
-            asyncio.create_task(_call_function_tool(tool_manager, call), name=call.tool_name)
-            for call in calls_to_run
-        ]
-
-        pending = tasks
-        while pending:
-            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-            for task in done:
-                index = tasks.index(task)
-                tool_part, tool_user_parts = task.result()
-                yield _messages.FunctionToolResultEvent(tool_part)
+        result_tool_call_ids = set(deferred_tool_results.keys())
+        tool_call_ids_to_run = {call.tool_call_id for call in calls_to_run}
+        if tool_call_ids_to_run != result_tool_call_ids:
+            raise exceptions.UserError(
+                'Tool call results need to be provided for all deferred tool calls. '
+                f'Expected: {tool_call_ids_to_run}, got: {result_tool_call_ids}'
+            )

-
-                user_parts_by_index[index] = tool_user_parts
+    deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)

-
-
-
-
+    if calls_to_run:
+        async for event in _call_tools(
+            tool_manager,
+            calls_to_run,
+            deferred_tool_results,
+            ctx.deps.tracer,
+            output_parts,
+            deferred_calls,
+        ):
+            yield event

-    # Finally, we handle deferred tool calls
-
+    # Finally, we handle deferred tool calls (unless they were already included in the run because results were provided)
+    if not deferred_tool_results:
         if final_result:
-
-
-
-
-
+            for call in [*tool_calls_by_kind['external'], *tool_calls_by_kind['unapproved']]:
+                output_parts.append(
+                    _messages.ToolReturnPart(
+                        tool_name=call.tool_name,
+                        content='Tool not executed - a final result was already processed.',
+                        tool_call_id=call.tool_call_id,
+                    )
                 )
-            )
         else:
-
+            for call in tool_calls_by_kind['external']:
+                deferred_calls['external'].append(call)
+                yield _messages.FunctionToolCallEvent(call)

-
-
+            for call in tool_calls_by_kind['unapproved']:
+                deferred_calls['unapproved'].append(call)
+                yield _messages.FunctionToolCallEvent(call)
+
+    if not final_result and deferred_calls:
+        if not ctx.deps.output_schema.allows_deferred_tools:
+            raise exceptions.UserError(
+                'A deferred tool call was present, but `DeferredToolRequests` is not among output types. To resolve this, add `DeferredToolRequests` to the list of output types for this agent.'
+            )
+        deferred_tool_requests = _output.DeferredToolRequests(
+            calls=deferred_calls['external'],
+            approvals=deferred_calls['unapproved'],
+        )
+
+        final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)

     if final_result:
         output_final_result.append(final_result)


-async def
+async def _call_tools(
     tool_manager: ToolManager[DepsT],
-
-
-
-
-
-
+    tool_calls: list[_messages.ToolCallPart],
+    deferred_tool_results: dict[str, DeferredToolResult],
+    tracer: Tracer,
+    output_parts: list[_messages.ModelRequestPart],
+    output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+) -> AsyncIterator[_messages.HandleResponseEvent]:
+    tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
+    user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
+    deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}

-
-
-            content=tool_result,
-            tool_call_id=tool_call.tool_call_id,
-        )
-    user_parts: list[_messages.UserPromptPart] = []
+    for call in tool_calls:
+        yield _messages.FunctionToolCallEvent(call)

-
-
-
-
-
-
-
-
-
-
-
-
+    # Run all tool tasks in parallel
+    with tracer.start_as_current_span(
+        'running tools',
+        attributes={
+            'tools': [call.tool_name for call in tool_calls],
+            'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
+        },
+    ):
+        tasks = [
+            asyncio.create_task(
+                _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id)),
+                name=call.tool_name,
             )
+            for call in tool_calls
+        ]
+
+        pending = tasks
+        while pending:
+            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                index = tasks.index(task)
+                try:
+                    tool_part, tool_user_part = task.result()
+                except exceptions.CallDeferred:
+                    deferred_calls_by_index[index] = 'external'
+                except exceptions.ApprovalRequired:
+                    deferred_calls_by_index[index] = 'unapproved'
+                else:
+                    yield _messages.FunctionToolResultEvent(tool_part)

-
-
-
-
-
-
-
-
+                    tool_parts_by_index[index] = tool_part
+                    if tool_user_part:
+                        user_parts_by_index[index] = tool_user_part
+
+    # We append the results at the end, rather than as they are received, to retain a consistent ordering
+    # This is mostly just to simplify testing
+    for k in sorted(tool_parts_by_index):
+        output_parts.append(tool_parts_by_index[k])
+
+    for k in sorted(user_parts_by_index):
+        output_parts.append(user_parts_by_index[k])
+
+    for k in sorted(deferred_calls_by_index):
+        output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
+
+
+async def _call_tool(
+    tool_manager: ToolManager[DepsT],
+    tool_call: _messages.ToolCallPart,
+    tool_call_result: DeferredToolResult | None,
+) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]:
+    try:
+        if tool_call_result is None:
+            tool_result = await tool_manager.handle_call(tool_call)
+        elif isinstance(tool_call_result, ToolApproved):
+            if tool_call_result.override_args is not None:
+                tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
+            tool_result = await tool_manager.handle_call(tool_call)
+        elif isinstance(tool_call_result, ToolDenied):
+            return _messages.ToolReturnPart(
+                tool_name=tool_call.tool_name,
+                content=tool_call_result.message,
+                tool_call_id=tool_call.tool_call_id,
+            ), None
+        elif isinstance(tool_call_result, exceptions.ModelRetry):
+            m = _messages.RetryPromptPart(
+                content=tool_call_result.message,
+                tool_name=tool_call.tool_name,
+                tool_call_id=tool_call.tool_call_id,
             )
+            raise ToolRetryError(m)
+        elif isinstance(tool_call_result, _messages.RetryPromptPart):
+            tool_call_result.tool_name = tool_call.tool_name
+            tool_call_result.tool_call_id = tool_call.tool_call_id
+            raise ToolRetryError(tool_call_result)
+        else:
+            tool_result = tool_call_result
+    except ToolRetryError as e:
+        return e.tool_retry, None
+
+    if isinstance(tool_result, _messages.ToolReturn):
+        tool_return = tool_result
     else:
+        result_is_list = isinstance(tool_result, list)
+        contents = cast(list[Any], tool_result) if result_is_list else [tool_result]

-
+        return_values: list[Any] = []
+        user_contents: list[str | _messages.UserContent] = []
+        for content in contents:
             if isinstance(content, _messages.ToolReturn):
                 raise exceptions.UserError(
                     f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. '
                     f'`ToolReturn` should be used directly.'
                 )
-            elif isinstance(content, _messages.
+            elif isinstance(content, _messages.MultiModalContent):
                 if isinstance(content, _messages.BinaryContent):
                     identifier = content.identifier or multi_modal_content_identifier(content.data)
                 else:
                     identifier = multi_modal_content_identifier(content.url)

-
-
-
-
-                    )
-                )
-                return f'See file {identifier}'
+                return_values.append(f'See file {identifier}')
+                user_contents.extend([f'This is file {identifier}:', content])
+            else:
+                return_values.append(content)

-
+        tool_return = _messages.ToolReturn(
+            return_value=return_values[0] if len(return_values) == 1 and not result_is_list else return_values,
+            content=user_contents,
+        )

-
-
-
-
-
+    if (
+        isinstance(tool_return.return_value, _messages.MultiModalContent)
+        or isinstance(tool_return.return_value, list)
+        and any(
+            isinstance(content, _messages.MultiModalContent)
+            for content in tool_return.return_value  # type: ignore
+        )
+    ):
+        raise exceptions.UserError(
+            f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContent` objects. '
+            f'Please use `content` instead.'
+        )
+
+    return_part = _messages.ToolReturnPart(
+        tool_name=tool_call.tool_name,
+        tool_call_id=tool_call.tool_call_id,
+        content=tool_return.return_value,  # type: ignore
+        metadata=tool_return.metadata,
+    )
+
+    user_part: _messages.UserPromptPart | None = None
+    if tool_return.content:
+        user_part = _messages.UserPromptPart(
+            content=tool_return.content,
+            part_kind='user-prompt',
+        )

-    return
+    return return_part, user_part


 @dataclasses.dataclass
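In the new `_call_tool` helper above, a previously deferred call is resolved from the `DeferredToolResult` supplied for it: `ToolApproved` re-runs the tool (optionally with overridden arguments), `ToolDenied` becomes a plain tool return carrying the denial message, `ModelRetry` and `RetryPromptPart` turn into a retry, and any other value is used directly as the tool's result. Below is a minimal standalone sketch of that dispatch; the `ToolApproved`, `ToolDenied` and `ToolCall` classes are simplified stand-ins rather than the real pydantic_ai types, and the retry branches are omitted.

# Standalone sketch of the deferred-result dispatch shown in `_call_tool` above.
# The classes below are simplified stand-ins, not the real pydantic_ai types.
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any


@dataclass
class ToolApproved:
    """Stand-in for pydantic_ai.tools.ToolApproved."""

    override_args: dict[str, Any] | None = None


@dataclass
class ToolDenied:
    """Stand-in for pydantic_ai.tools.ToolDenied."""

    message: str = 'The tool call was denied.'


@dataclass
class ToolCall:
    """Stand-in for a tool call part: just a name and its arguments."""

    name: str
    args: dict[str, Any]


def resolve_deferred_call(
    call: ToolCall,
    result: Any,
    run_tool: Callable[[str, dict[str, Any]], str],
) -> str:
    """Mirror the dispatch above: turn a deferred result into the tool's return content."""
    if isinstance(result, ToolApproved):
        # An approval may override the original arguments before the tool is actually run.
        args = result.override_args if result.override_args is not None else call.args
        return run_tool(call.name, args)
    if isinstance(result, ToolDenied):
        # A denial becomes the tool's return content, reported back to the model.
        return result.message
    # Any other value is treated as the externally produced tool result.
    return str(result)


if __name__ == '__main__':
    call = ToolCall(name='delete_file', args={'path': 'a.txt'})

    def run_tool(name: str, args: dict[str, Any]) -> str:
        return f'{name} executed with {args}'

    print(resolve_deferred_call(call, ToolApproved(override_args={'path': 'b.txt'}), run_tool))
    print(resolve_deferred_call(call, ToolDenied('Not allowed in production.'), run_tool))
    print(resolve_deferred_call(call, '42 files deleted', run_tool))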
pydantic_ai/_function_schema.py
CHANGED
@@ -5,10 +5,10 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle

 from __future__ import annotations as _annotations

-from collections.abc import Awaitable
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
 from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, Concatenate, cast, get_origin

 from pydantic import ConfigDict
 from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -17,7 +17,7 @@ from pydantic.fields import FieldInfo
 from pydantic.json_schema import GenerateJsonSchema
 from pydantic.plugin._schema_validator import create_schema_validator
 from pydantic_core import SchemaValidator, core_schema
-from typing_extensions import
+from typing_extensions import ParamSpec, TypeIs, TypeVar

 from ._griffe import doc_descriptions
 from ._run_context import RunContext
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
 __all__ = ('function_schema',)


-@dataclass
+@dataclass(kw_only=True)
 class FunctionSchema:
     """Internal information about a function schema."""

@@ -231,7 +231,7 @@ R = TypeVar('R')

 WithCtx = Callable[Concatenate[RunContext[Any], P], R]
 WithoutCtx = Callable[P, R]
-TargetFunc =
+TargetFunc = WithCtx[P, R] | WithoutCtx[P, R]


 def _takes_ctx(function: TargetFunc[P, R]) -> TypeIs[WithCtx[P, R]]: