pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_a2a.py +1 -1
- pydantic_ai/_agent_graph.py +323 -156
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +7 -5
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +32 -28
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +82 -51
- pydantic_ai/agent/__init__.py +84 -17
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +70 -17
- pydantic_ai/durable_exec/temporal/_agent.py +93 -11
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +15 -27
- pydantic_ai/messages.py +156 -44
- pydantic_ai/models/__init__.py +20 -7
- pydantic_ai/models/anthropic.py +10 -17
- pydantic_ai/models/bedrock.py +55 -57
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +13 -14
- pydantic_ai/models/google.py +19 -5
- pydantic_ai/models/groq.py +127 -39
- pydantic_ai/models/huggingface.py +5 -5
- pydantic_ai/models/instrumented.py +49 -21
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +8 -8
- pydantic_ai/models/openai.py +37 -42
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/providers/google_vertex.py +2 -1
- pydantic_ai/providers/groq.py +21 -2
- pydantic_ai/providers/litellm.py +134 -0
- pydantic_ai/result.py +173 -52
- pydantic_ai/retries.py +52 -31
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +127 -23
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +58 -21
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +44 -8
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
- pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
- pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_agent_graph.py
CHANGED
@@ -2,16 +2,15 @@ from __future__ import annotations as _annotations
 
 import asyncio
 import dataclasses
-import hashlib
 from collections import defaultdict, deque
-from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast
 
 from opentelemetry.trace import Tracer
-from typing_extensions import
+from typing_extensions import TypeVar, assert_never
 
 from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx  # type: ignore
 from pydantic_ai._tool_manager import ToolManager
@@ -24,7 +23,14 @@ from . import _output, _system_prompt, exceptions, messages as _messages, models
 from .exceptions import ToolRetryError
 from .output import OutputDataT, OutputSpec
 from .settings import ModelSettings
-from .tools import
+from .tools import (
+    DeferredToolResult,
+    RunContext,
+    ToolApproved,
+    ToolDefinition,
+    ToolDenied,
+    ToolKind,
+)
 
 if TYPE_CHECKING:
     from .models.instrumented import InstrumentationSettings
@@ -59,19 +65,19 @@ _HistoryProcessorSyncWithCtx = Callable[[RunContext[DepsT], list[_messages.Model
 _HistoryProcessorAsyncWithCtx = Callable[
     [RunContext[DepsT], list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]]
 ]
-HistoryProcessor =
-    _HistoryProcessorSync
-    _HistoryProcessorAsync
-    _HistoryProcessorSyncWithCtx[DepsT]
-    _HistoryProcessorAsyncWithCtx[DepsT]
-
+HistoryProcessor = (
+    _HistoryProcessorSync
+    | _HistoryProcessorAsync
+    | _HistoryProcessorSyncWithCtx[DepsT]
+    | _HistoryProcessorAsyncWithCtx[DepsT]
+)
 """A function that processes a list of model messages and returns a list of model messages.
 
 Can optionally accept a `RunContext` as a parameter.
 """
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class GraphAgentState:
     """State kept across the execution of the agent graph."""
 
@@ -92,7 +98,7 @@ class GraphAgentState:
         raise exceptions.UnexpectedModelBehavior(message)
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     """Dependencies/config passed to the agent graph."""
 
@@ -115,9 +121,10 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
 
     builtin_tools: list[AbstractBuiltinTool] = dataclasses.field(repr=False)
     tool_manager: ToolManager[DepsT]
+    tool_call_results: dict[str, DeferredToolResult] | None
 
     tracer: Tracer
-    instrumentation_settings: InstrumentationSettings | None
+    instrumentation_settings: InstrumentationSettings | None
 
 
 class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
@@ -149,6 +156,8 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
     user_prompt: str | Sequence[_messages.UserContent] | None
 
+    _: dataclasses.KW_ONLY
+
    instructions: str | None
     instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
 
@@ -158,7 +167,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) ->
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | CallToolsNode[DepsT, NodeRunEndT]:
         try:
             ctx_messages = get_captured_run_messages()
         except LookupError:
@@ -184,26 +193,29 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
         else:
             parts.extend(await self._sys_parts(run_context))
 
+        if (tool_call_results := ctx.deps.tool_call_results) is not None:
+            if messages and (last_message := messages[-1]) and isinstance(last_message, _messages.ModelRequest):
+                # If tool call results were provided, that means the previous run ended on deferred tool calls.
+                # That run would typically have ended on a `ModelResponse`, but if it had a mix of deferred tool calls and ones that could already be executed,
+                # a `ModelRequest` would already have been added to the history with the preliminary results, even if it wouldn't have been sent to the model yet.
+                # So now that we have all of the deferred results, we roll back to the last `ModelResponse` and store the contents of the `ModelRequest` on `deferred_tool_results` to be handled by `CallToolsNode`.
+                ctx.deps.tool_call_results = self._update_tool_call_results_from_model_request(
+                    tool_call_results, last_message
+                )
+                messages.pop()
+
+            if not messages:
+                raise exceptions.UserError('Tool call results were provided, but the message history is empty.')
+
         if messages and (last_message := messages[-1]):
             if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
                 # Drop last message from history and reuse its parts
                 messages.pop()
                 parts.extend(last_message.parts)
             elif isinstance(last_message, _messages.ModelResponse):
-
-
-
-                ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
-
-                # Skip ModelRequestNode and go directly to CallToolsNode
-                return CallToolsNode[DepsT, NodeRunEndT](model_response=last_message)
-            elif any(isinstance(part, _messages.ToolCallPart) for part in last_message.parts):
-                raise exceptions.UserError(
-                    'Cannot provide a new user prompt when the message history ends with '
-                    'a model response containing unprocessed tool calls. Either process the '
-                    'tool calls first (by calling `iter` with `user_prompt=None`) or append a '
-                    '`ModelRequest` with `ToolResultPart`s.'
-                )
+                call_tools_node = await self._handle_message_history_model_response(ctx, last_message)
+                if call_tools_node is not None:
+                    return call_tools_node
 
         if self.user_prompt is not None:
             parts.append(_messages.UserPromptPart(self.user_prompt))
@@ -213,6 +225,74 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
         return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)
 
+    async def _handle_message_history_model_response(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        message: _messages.ModelResponse,
+    ) -> CallToolsNode[DepsT, NodeRunEndT] | None:
+        unprocessed_tool_calls = any(isinstance(part, _messages.ToolCallPart) for part in message.parts)
+        if unprocessed_tool_calls:
+            if self.user_prompt is not None:
+                raise exceptions.UserError(
+                    'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
+                )
+        else:
+            if ctx.deps.tool_call_results is not None:
+                raise exceptions.UserError(
+                    'Tool call results were provided, but the message history does not contain any unprocessed tool calls.'
+                )
+
+        if unprocessed_tool_calls or self.user_prompt is None:
+            # `CallToolsNode` requires the tool manager to be prepared for the run step
+            # This will raise errors for any tool name conflicts
+            run_context = build_run_context(ctx)
+            ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
+
+            # Skip ModelRequestNode and go directly to CallToolsNode
+            return CallToolsNode[DepsT, NodeRunEndT](model_response=message)
+
+    def _update_tool_call_results_from_model_request(
+        self, tool_call_results: dict[str, DeferredToolResult], message: _messages.ModelRequest
+    ) -> dict[str, DeferredToolResult]:
+        last_tool_return: _messages.ToolReturn | None = None
+        user_content: list[str | _messages.UserContent] = []
+        for part in message.parts:
+            if isinstance(part, _messages.ToolReturnPart):
+                if part.tool_call_id in tool_call_results:
+                    raise exceptions.UserError(
+                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
+                    )
+
+                last_tool_return = _messages.ToolReturn(return_value=part.content, metadata=part.metadata)
+                tool_call_results[part.tool_call_id] = last_tool_return
+            elif isinstance(part, _messages.RetryPromptPart):
+                if part.tool_call_id in tool_call_results:
+                    raise exceptions.UserError(
+                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
+                    )
+
+                tool_call_results[part.tool_call_id] = part
+            elif isinstance(part, _messages.UserPromptPart):
+                # Tools can return user parts via `ToolReturn.content` or by returning multi-modal content.
+                # These go together with a specific `ToolReturnPart`, but we don't have a way to know which,
+                # so (below) we just add them to the last one, matching the tool-results-before-user-parts order of the request.
+                if isinstance(part.content, str):
+                    user_content.append(part.content)
+                else:
+                    user_content.extend(part.content)
+            else:
+                raise exceptions.UserError(f'Unexpected message part type: {type(part)}')  # pragma: no cover
+
+        if user_content:
+            if last_tool_return is None:
+                raise exceptions.UserError(
+                    'Tool call results were provided, but the last message in the history was a `ModelRequest` with user parts not tied to preliminary tool results.'
+                )
+            assert last_tool_return is not None
+            last_tool_return.content = user_content
+
+        return tool_call_results
+
     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
     ) -> None:
@@ -221,16 +301,21 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
         if self.system_prompt_dynamic_functions:
             for msg in messages:
                 if isinstance(msg, _messages.ModelRequest):
-
+                    reevaluated_message_parts: list[_messages.ModelRequestPart] = []
+                    for part in msg.parts:
                         if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
                             # Look up the runner by its ref
                             if runner := self.system_prompt_dynamic_functions.get(  # pragma: lax no cover
                                 part.dynamic_ref
                             ):
                                 updated_part_content = await runner.run(run_context)
-
-
-
+                                part = _messages.SystemPromptPart(updated_part_content, dynamic_ref=part.dynamic_ref)
+
+                        reevaluated_message_parts.append(part)
+
+                    # Replace message parts with reevaluated ones to prevent mutating parts list
+                    if reevaluated_message_parts != msg.parts:
+                        msg.parts = reevaluated_message_parts
 
     async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
         """Build the initial messages for the conversation."""
@@ -280,8 +365,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
 
     request: _messages.ModelRequest
 
-    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(
-    _did_stream: bool = field(
+    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None)
+    _did_stream: bool = field(repr=False, init=False, default=False)
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -310,13 +395,13 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
             self._did_stream = True
             ctx.state.usage.requests += 1
             agent_stream = result.AgentStream[DepsT, T](
-                streamed_response,
-                ctx.deps.output_schema,
-                model_request_parameters,
-                ctx.deps.output_validators,
-                build_run_context(ctx),
-                ctx.deps.usage_limits,
-                ctx.deps.tool_manager,
+                _raw_stream_response=streamed_response,
+                _output_schema=ctx.deps.output_schema,
+                _model_request_parameters=model_request_parameters,
+                _output_validators=ctx.deps.output_validators,
+                _run_ctx=build_run_context(ctx),
+                _usage_limits=ctx.deps.usage_limits,
+                _tool_manager=ctx.deps.tool_manager,
             )
             yield agent_stream
             # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
@@ -396,14 +481,14 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
 
     model_response: _messages.ModelResponse
 
-    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
+    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
     _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
-        default=None, repr=False
+        default=None, init=False, repr=False
     )
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) ->
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
         async with self.stream(ctx):
             pass
         assert self._next_node is not None, 'the stream should set `self._next_node` before it ends'
@@ -506,13 +591,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         if output_final_result:
             final_result = output_final_result[0]
             self._next_node = self._handle_final_result(ctx, final_result, output_parts)
-        elif deferred_tool_calls := ctx.deps.tool_manager.get_deferred_tool_calls(tool_calls):
-            if not ctx.deps.output_schema.allows_deferred_tool_calls:
-                raise exceptions.UserError(
-                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
-                )
-            final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None)
-            self._next_node = self._handle_final_result(ctx, final_result, output_parts)
         else:
             instructions = await ctx.deps.get_instructions(run_context)
             self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
@@ -557,7 +635,7 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
             ctx.state.increment_retries(ctx.deps.max_result_retries, e)
             return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
         else:
-            return self._handle_final_result(ctx, result.FinalResult(result_data
+            return self._handle_final_result(ctx, result.FinalResult(result_data), [])
 
 
 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
@@ -572,16 +650,10 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
         trace_include_content=ctx.deps.instrumentation_settings is not None
         and ctx.deps.instrumentation_settings.include_content,
         run_step=ctx.state.run_step,
+        tool_call_approved=ctx.state.run_step == 0 and ctx.deps.tool_call_results is not None,
     )
 
 
-def multi_modal_content_identifier(identifier: str | bytes) -> str:
-    """Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
-    if isinstance(identifier, str):
-        identifier = identifier.encode('utf-8')
-    return hashlib.sha1(identifier).hexdigest()[:6]
-
-
 async def process_function_tools(  # noqa: C901
     tool_manager: ToolManager[DepsT],
     tool_calls: list[_messages.ToolCallPart],
@@ -599,7 +671,10 @@ async def process_function_tools(  # noqa: C901
     tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list)
     for call in tool_calls:
         tool_def = tool_manager.get_tool_def(call.tool_name)
-
+        if tool_def:
+            kind = tool_def.kind
+        else:
+            kind = 'unknown'
         tool_calls_by_kind[kind].append(call)
 
     # First, we handle output tool calls
@@ -662,132 +737,224 @@ async def process_function_tools(  # noqa: C901
         ctx.state.increment_retries(ctx.deps.max_result_retries)
         calls_to_run.extend(tool_calls_by_kind['unknown'])
 
-
-
+    deferred_tool_results: dict[str, DeferredToolResult] = {}
+    if build_run_context(ctx).tool_call_approved and ctx.deps.tool_call_results is not None:
+        deferred_tool_results = ctx.deps.tool_call_results
 
-
+        # Deferred tool calls are "run" as well, by reading their value from the tool call results
+        calls_to_run.extend(tool_calls_by_kind['external'])
+        calls_to_run.extend(tool_calls_by_kind['unapproved'])
 
-
-
-
-
-
-
-
-            'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
-        },
-    ):
-        tasks = [
-            asyncio.create_task(_call_function_tool(tool_manager, call), name=call.tool_name)
-            for call in calls_to_run
-        ]
-
-        pending = tasks
-        while pending:
-            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-            for task in done:
-                index = tasks.index(task)
-                tool_part, tool_user_parts = task.result()
-                yield _messages.FunctionToolResultEvent(tool_part)
+        result_tool_call_ids = set(deferred_tool_results.keys())
+        tool_call_ids_to_run = {call.tool_call_id for call in calls_to_run}
+        if tool_call_ids_to_run != result_tool_call_ids:
+            raise exceptions.UserError(
+                'Tool call results need to be provided for all deferred tool calls. '
+                f'Expected: {tool_call_ids_to_run}, got: {result_tool_call_ids}'
+            )
 
-
-                user_parts_by_index[index] = tool_user_parts
+    deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
 
-
-
-
-
+    if calls_to_run:
+        async for event in _call_tools(
+            tool_manager,
+            calls_to_run,
+            deferred_tool_results,
+            ctx.deps.tracer,
+            ctx.deps.usage_limits,
+            output_parts,
+            deferred_calls,
+        ):
+            yield event
 
-    # Finally, we handle deferred tool calls
-
+    # Finally, we handle deferred tool calls (unless they were already included in the run because results were provided)
+    if not deferred_tool_results:
        if final_result:
-
-
-
-
-
+            for call in [*tool_calls_by_kind['external'], *tool_calls_by_kind['unapproved']]:
+                output_parts.append(
+                    _messages.ToolReturnPart(
+                        tool_name=call.tool_name,
+                        content='Tool not executed - a final result was already processed.',
+                        tool_call_id=call.tool_call_id,
+                    )
                 )
-            )
        else:
-
+            for call in tool_calls_by_kind['external']:
+                deferred_calls['external'].append(call)
+                yield _messages.FunctionToolCallEvent(call)
 
-
-
+            for call in tool_calls_by_kind['unapproved']:
+                deferred_calls['unapproved'].append(call)
+                yield _messages.FunctionToolCallEvent(call)
+
+    if not final_result and deferred_calls:
+        if not ctx.deps.output_schema.allows_deferred_tools:
+            raise exceptions.UserError(
+                'A deferred tool call was present, but `DeferredToolRequests` is not among output types. To resolve this, add `DeferredToolRequests` to the list of output types for this agent.'
+            )
+        deferred_tool_requests = _output.DeferredToolRequests(
+            calls=deferred_calls['external'],
+            approvals=deferred_calls['unapproved'],
+        )
+
+        final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)
 
     if final_result:
         output_final_result.append(final_result)
 
 
-async def
+async def _call_tools(
     tool_manager: ToolManager[DepsT],
-
-
-
-
-
-
+    tool_calls: list[_messages.ToolCallPart],
+    deferred_tool_results: dict[str, DeferredToolResult],
+    tracer: Tracer,
+    usage_limits: _usage.UsageLimits | None,
+    output_parts: list[_messages.ModelRequestPart],
+    output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+) -> AsyncIterator[_messages.HandleResponseEvent]:
+    tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
+    user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
+    deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
 
-
-
-            content=tool_result,
-            tool_call_id=tool_call.tool_call_id,
-        )
-        user_parts: list[_messages.UserPromptPart] = []
+    for call in tool_calls:
+        yield _messages.FunctionToolCallEvent(call)
 
-
-
-
-
-
-
-
-
-
-
-
-
+    # Run all tool tasks in parallel
+    with tracer.start_as_current_span(
+        'running tools',
+        attributes={
+            'tools': [call.tool_name for call in tool_calls],
+            'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
+        },
+    ):
+        tasks = [
+            asyncio.create_task(
+                _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits),
+                name=call.tool_name,
            )
+            for call in tool_calls
+        ]
+
+        pending = tasks
+        while pending:
+            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                index = tasks.index(task)
+                try:
+                    tool_part, tool_user_part = task.result()
+                except exceptions.CallDeferred:
+                    deferred_calls_by_index[index] = 'external'
+                except exceptions.ApprovalRequired:
+                    deferred_calls_by_index[index] = 'unapproved'
+                else:
+                    yield _messages.FunctionToolResultEvent(tool_part)
 
-
-
-
-
-
-
-
-
+                    tool_parts_by_index[index] = tool_part
+                    if tool_user_part:
+                        user_parts_by_index[index] = tool_user_part
+
+    # We append the results at the end, rather than as they are received, to retain a consistent ordering
+    # This is mostly just to simplify testing
+    for k in sorted(tool_parts_by_index):
+        output_parts.append(tool_parts_by_index[k])
+
+    for k in sorted(user_parts_by_index):
+        output_parts.append(user_parts_by_index[k])
+
+    for k in sorted(deferred_calls_by_index):
+        output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
+
+
+async def _call_tool(
+    tool_manager: ToolManager[DepsT],
+    tool_call: _messages.ToolCallPart,
+    tool_call_result: DeferredToolResult | None,
+    usage_limits: _usage.UsageLimits | None,
+) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]:
+    try:
+        if tool_call_result is None:
+            tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
+        elif isinstance(tool_call_result, ToolApproved):
+            if tool_call_result.override_args is not None:
+                tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
+            tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
+        elif isinstance(tool_call_result, ToolDenied):
+            return _messages.ToolReturnPart(
+                tool_name=tool_call.tool_name,
+                content=tool_call_result.message,
+                tool_call_id=tool_call.tool_call_id,
+            ), None
+        elif isinstance(tool_call_result, exceptions.ModelRetry):
+            m = _messages.RetryPromptPart(
+                content=tool_call_result.message,
+                tool_name=tool_call.tool_name,
+                tool_call_id=tool_call.tool_call_id,
            )
+            raise ToolRetryError(m)
+        elif isinstance(tool_call_result, _messages.RetryPromptPart):
+            tool_call_result.tool_name = tool_call.tool_name
+            tool_call_result.tool_call_id = tool_call.tool_call_id
+            raise ToolRetryError(tool_call_result)
+        else:
+            tool_result = tool_call_result
+    except ToolRetryError as e:
+        return e.tool_retry, None
+
+    if isinstance(tool_result, _messages.ToolReturn):
+        tool_return = tool_result
    else:
+        result_is_list = isinstance(tool_result, list)
+        contents = cast(list[Any], tool_result) if result_is_list else [tool_result]
 
-
+        return_values: list[Any] = []
+        user_contents: list[str | _messages.UserContent] = []
+        for content in contents:
            if isinstance(content, _messages.ToolReturn):
                raise exceptions.UserError(
                    f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. '
                    f'`ToolReturn` should be used directly.'
                )
-            elif isinstance(content, _messages.
-
-                identifier = content.identifier or multi_modal_content_identifier(content.data)
-            else:
-                identifier = multi_modal_content_identifier(content.url)
+            elif isinstance(content, _messages.MultiModalContent):
+                identifier = content.identifier
 
-
-
-
-
-            )
-        )
-        return f'See file {identifier}'
+                return_values.append(f'See file {identifier}')
+                user_contents.extend([f'This is file {identifier}:', content])
+            else:
+                return_values.append(content)
 
-
+        tool_return = _messages.ToolReturn(
+            return_value=return_values[0] if len(return_values) == 1 and not result_is_list else return_values,
+            content=user_contents,
+        )
 
-
-
-
-
-
+    if (
+        isinstance(tool_return.return_value, _messages.MultiModalContent)
+        or isinstance(tool_return.return_value, list)
+        and any(
+            isinstance(content, _messages.MultiModalContent)
+            for content in tool_return.return_value  # type: ignore
+        )
+    ):
+        raise exceptions.UserError(
+            f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContent` objects. '
+            f'Please use `content` instead.'
+        )
+
+    return_part = _messages.ToolReturnPart(
+        tool_name=tool_call.tool_name,
+        tool_call_id=tool_call.tool_call_id,
+        content=tool_return.return_value,  # type: ignore
+        metadata=tool_return.metadata,
+    )
+
+    user_part: _messages.UserPromptPart | None = None
+    if tool_return.content:
+        user_part = _messages.UserPromptPart(
+            content=tool_return.content,
+            part_kind='user-prompt',
+        )
 
-    return
+    return return_part, user_part
 
 
 @dataclasses.dataclass