pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic.

Files changed (75)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_a2a.py +1 -1
  3. pydantic_ai/_agent_graph.py +323 -156
  4. pydantic_ai/_function_schema.py +5 -5
  5. pydantic_ai/_griffe.py +2 -1
  6. pydantic_ai/_otel_messages.py +2 -2
  7. pydantic_ai/_output.py +31 -35
  8. pydantic_ai/_parts_manager.py +7 -5
  9. pydantic_ai/_run_context.py +3 -1
  10. pydantic_ai/_system_prompt.py +2 -2
  11. pydantic_ai/_tool_manager.py +32 -28
  12. pydantic_ai/_utils.py +14 -26
  13. pydantic_ai/ag_ui.py +82 -51
  14. pydantic_ai/agent/__init__.py +84 -17
  15. pydantic_ai/agent/abstract.py +35 -4
  16. pydantic_ai/agent/wrapper.py +6 -0
  17. pydantic_ai/builtin_tools.py +2 -2
  18. pydantic_ai/common_tools/duckduckgo.py +4 -2
  19. pydantic_ai/durable_exec/temporal/__init__.py +70 -17
  20. pydantic_ai/durable_exec/temporal/_agent.py +93 -11
  21. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  22. pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
  23. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  24. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  25. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  26. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  27. pydantic_ai/exceptions.py +45 -2
  28. pydantic_ai/format_prompt.py +2 -2
  29. pydantic_ai/mcp.py +15 -27
  30. pydantic_ai/messages.py +156 -44
  31. pydantic_ai/models/__init__.py +20 -7
  32. pydantic_ai/models/anthropic.py +10 -17
  33. pydantic_ai/models/bedrock.py +55 -57
  34. pydantic_ai/models/cohere.py +3 -3
  35. pydantic_ai/models/fallback.py +2 -2
  36. pydantic_ai/models/function.py +25 -23
  37. pydantic_ai/models/gemini.py +13 -14
  38. pydantic_ai/models/google.py +19 -5
  39. pydantic_ai/models/groq.py +127 -39
  40. pydantic_ai/models/huggingface.py +5 -5
  41. pydantic_ai/models/instrumented.py +49 -21
  42. pydantic_ai/models/mcp_sampling.py +3 -1
  43. pydantic_ai/models/mistral.py +8 -8
  44. pydantic_ai/models/openai.py +37 -42
  45. pydantic_ai/models/test.py +24 -4
  46. pydantic_ai/output.py +27 -32
  47. pydantic_ai/profiles/__init__.py +3 -3
  48. pydantic_ai/profiles/groq.py +1 -1
  49. pydantic_ai/profiles/openai.py +25 -4
  50. pydantic_ai/providers/__init__.py +4 -0
  51. pydantic_ai/providers/anthropic.py +2 -3
  52. pydantic_ai/providers/bedrock.py +3 -2
  53. pydantic_ai/providers/google_vertex.py +2 -1
  54. pydantic_ai/providers/groq.py +21 -2
  55. pydantic_ai/providers/litellm.py +134 -0
  56. pydantic_ai/result.py +173 -52
  57. pydantic_ai/retries.py +52 -31
  58. pydantic_ai/run.py +12 -5
  59. pydantic_ai/tools.py +127 -23
  60. pydantic_ai/toolsets/__init__.py +4 -1
  61. pydantic_ai/toolsets/_dynamic.py +4 -4
  62. pydantic_ai/toolsets/abstract.py +18 -2
  63. pydantic_ai/toolsets/approval_required.py +32 -0
  64. pydantic_ai/toolsets/combined.py +7 -12
  65. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  66. pydantic_ai/toolsets/filtered.py +1 -1
  67. pydantic_ai/toolsets/function.py +58 -21
  68. pydantic_ai/toolsets/wrapper.py +2 -1
  69. pydantic_ai/usage.py +44 -8
  70. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
  71. pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
  72. pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
  73. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
  74. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
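The most consequential API change visible in this list is the rework of deferred tools: `toolsets/deferred.py` becomes `toolsets/external.py`, `toolsets/approval_required.py` is new, and `pydantic_ai/tools.py` gains `DeferredToolResult`, `ToolApproved`, and `ToolDenied`, while the old `DeferredToolCalls` output type is replaced by `DeferredToolRequests` (see the `_agent_graph.py` diff below). A rough sketch of the new flow, inferred only from names in this diff — the import locations, the model string, and the mechanism for resuming a run with the collected results are assumptions, not confirmed here:

    # Hypothetical usage of the 1.0.0 deferred-tool types; only the type names
    # and the calls/approvals split are confirmed by the diff below.
    from pydantic_ai import Agent
    from pydantic_ai.output import DeferredToolRequests  # import location assumed
    from pydantic_ai.tools import ToolApproved, ToolDenied

    agent = Agent('openai:gpt-5', output_type=[str, DeferredToolRequests])

    result = agent.run_sync('Clean up the old backups')
    if isinstance(result.output, DeferredToolRequests):
        # `approvals` holds ToolCallPart objects awaiting a human decision;
        # build a mapping of tool_call_id -> DeferredToolResult to resume the run.
        results = {
            call.tool_call_id: ToolApproved() if call.tool_name == 'list_backups' else ToolDenied(message='Not allowed.')
            for call in result.output.approvals
        }
        # Resuming the run with these results goes through the agent API
        # (changed in pydantic_ai/agent/__init__.py, not shown in this diff).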
pydantic_ai/_agent_graph.py
@@ -2,16 +2,15 @@ from __future__ import annotations as _annotations
 
 import asyncio
 import dataclasses
-import hashlib
 from collections import defaultdict, deque
-from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
-from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast
 
 from opentelemetry.trace import Tracer
-from typing_extensions import TypeGuard, TypeVar, assert_never
+from typing_extensions import TypeVar, assert_never
 
 from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx  # type: ignore
 from pydantic_ai._tool_manager import ToolManager
@@ -24,7 +23,14 @@ from . import _output, _system_prompt, exceptions, messages as _messages, models
 from .exceptions import ToolRetryError
 from .output import OutputDataT, OutputSpec
 from .settings import ModelSettings
-from .tools import RunContext, ToolDefinition, ToolKind
+from .tools import (
+    DeferredToolResult,
+    RunContext,
+    ToolApproved,
+    ToolDefinition,
+    ToolDenied,
+    ToolKind,
+)
 
 if TYPE_CHECKING:
     from .models.instrumented import InstrumentationSettings
@@ -59,19 +65,19 @@ _HistoryProcessorSyncWithCtx = Callable[[RunContext[DepsT], list[_messages.Model
 _HistoryProcessorAsyncWithCtx = Callable[
     [RunContext[DepsT], list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]]
 ]
-HistoryProcessor = Union[
-    _HistoryProcessorSync,
-    _HistoryProcessorAsync,
-    _HistoryProcessorSyncWithCtx[DepsT],
-    _HistoryProcessorAsyncWithCtx[DepsT],
-]
+HistoryProcessor = (
+    _HistoryProcessorSync
+    | _HistoryProcessorAsync
+    | _HistoryProcessorSyncWithCtx[DepsT]
+    | _HistoryProcessorAsyncWithCtx[DepsT]
+)
 """A function that processes a list of model messages and returns a list of model messages.
 
 Can optionally accept a `RunContext` as a parameter.
 """
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class GraphAgentState:
     """State kept across the execution of the agent graph."""
 
@@ -92,7 +98,7 @@ class GraphAgentState:
         raise exceptions.UnexpectedModelBehavior(message)
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     """Dependencies/config passed to the agent graph."""
 
@@ -115,9 +121,10 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
 
     builtin_tools: list[AbstractBuiltinTool] = dataclasses.field(repr=False)
     tool_manager: ToolManager[DepsT]
+    tool_call_results: dict[str, DeferredToolResult] | None
 
     tracer: Tracer
-    instrumentation_settings: InstrumentationSettings | None = None
+    instrumentation_settings: InstrumentationSettings | None
 
 
 class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
@@ -149,6 +156,8 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
     user_prompt: str | Sequence[_messages.UserContent] | None
 
+    _: dataclasses.KW_ONLY
+
    instructions: str | None
    instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
 
@@ -158,7 +167,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], CallToolsNode[DepsT, NodeRunEndT]]:  # noqa UP007
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | CallToolsNode[DepsT, NodeRunEndT]:
         try:
             ctx_messages = get_captured_run_messages()
         except LookupError:
@@ -184,26 +193,29 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
         else:
             parts.extend(await self._sys_parts(run_context))
 
+        if (tool_call_results := ctx.deps.tool_call_results) is not None:
+            if messages and (last_message := messages[-1]) and isinstance(last_message, _messages.ModelRequest):
+                # If tool call results were provided, that means the previous run ended on deferred tool calls.
+                # That run would typically have ended on a `ModelResponse`, but if it had a mix of deferred tool calls and ones that could already be executed,
+                # a `ModelRequest` would already have been added to the history with the preliminary results, even if it wouldn't have been sent to the model yet.
+                # So now that we have all of the deferred results, we roll back to the last `ModelResponse` and store the contents of the `ModelRequest` on `deferred_tool_results` to be handled by `CallToolsNode`.
+                ctx.deps.tool_call_results = self._update_tool_call_results_from_model_request(
+                    tool_call_results, last_message
+                )
+                messages.pop()
+
+            if not messages:
+                raise exceptions.UserError('Tool call results were provided, but the message history is empty.')
+
         if messages and (last_message := messages[-1]):
             if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
                 # Drop last message from history and reuse its parts
                 messages.pop()
                 parts.extend(last_message.parts)
             elif isinstance(last_message, _messages.ModelResponse):
-                if self.user_prompt is None:
-                    # `CallToolsNode` requires the tool manager to be prepared for the run step
-                    # This will raise errors for any tool name conflicts
-                    ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
-
-                    # Skip ModelRequestNode and go directly to CallToolsNode
-                    return CallToolsNode[DepsT, NodeRunEndT](model_response=last_message)
-                elif any(isinstance(part, _messages.ToolCallPart) for part in last_message.parts):
-                    raise exceptions.UserError(
-                        'Cannot provide a new user prompt when the message history ends with '
-                        'a model response containing unprocessed tool calls. Either process the '
-                        'tool calls first (by calling `iter` with `user_prompt=None`) or append a '
-                        '`ModelRequest` with `ToolResultPart`s.'
-                    )
+                call_tools_node = await self._handle_message_history_model_response(ctx, last_message)
+                if call_tools_node is not None:
+                    return call_tools_node
 
         if self.user_prompt is not None:
             parts.append(_messages.UserPromptPart(self.user_prompt))
@@ -213,6 +225,74 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
         return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)
 
+    async def _handle_message_history_model_response(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        message: _messages.ModelResponse,
+    ) -> CallToolsNode[DepsT, NodeRunEndT] | None:
+        unprocessed_tool_calls = any(isinstance(part, _messages.ToolCallPart) for part in message.parts)
+        if unprocessed_tool_calls:
+            if self.user_prompt is not None:
+                raise exceptions.UserError(
+                    'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
+                )
+        else:
+            if ctx.deps.tool_call_results is not None:
+                raise exceptions.UserError(
+                    'Tool call results were provided, but the message history does not contain any unprocessed tool calls.'
+                )
+
+        if unprocessed_tool_calls or self.user_prompt is None:
+            # `CallToolsNode` requires the tool manager to be prepared for the run step
+            # This will raise errors for any tool name conflicts
+            run_context = build_run_context(ctx)
+            ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
+
+            # Skip ModelRequestNode and go directly to CallToolsNode
+            return CallToolsNode[DepsT, NodeRunEndT](model_response=message)
+
+    def _update_tool_call_results_from_model_request(
+        self, tool_call_results: dict[str, DeferredToolResult], message: _messages.ModelRequest
+    ) -> dict[str, DeferredToolResult]:
+        last_tool_return: _messages.ToolReturn | None = None
+        user_content: list[str | _messages.UserContent] = []
+        for part in message.parts:
+            if isinstance(part, _messages.ToolReturnPart):
+                if part.tool_call_id in tool_call_results:
+                    raise exceptions.UserError(
+                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
+                    )
+
+                last_tool_return = _messages.ToolReturn(return_value=part.content, metadata=part.metadata)
+                tool_call_results[part.tool_call_id] = last_tool_return
+            elif isinstance(part, _messages.RetryPromptPart):
+                if part.tool_call_id in tool_call_results:
+                    raise exceptions.UserError(
+                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
+                    )
+
+                tool_call_results[part.tool_call_id] = part
+            elif isinstance(part, _messages.UserPromptPart):
+                # Tools can return user parts via `ToolReturn.content` or by returning multi-modal content.
+                # These go together with a specific `ToolReturnPart`, but we don't have a way to know which,
+                # so (below) we just add them to the last one, matching the tool-results-before-user-parts order of the request.
+                if isinstance(part.content, str):
+                    user_content.append(part.content)
+                else:
+                    user_content.extend(part.content)
+            else:
+                raise exceptions.UserError(f'Unexpected message part type: {type(part)}')  # pragma: no cover
+
+        if user_content:
+            if last_tool_return is None:
+                raise exceptions.UserError(
+                    'Tool call results were provided, but the last message in the history was a `ModelRequest` with user parts not tied to preliminary tool results.'
+                )
+            assert last_tool_return is not None
+            last_tool_return.content = user_content
+
+        return tool_call_results
+
     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
     ) -> None:
@@ -221,16 +301,21 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
         if self.system_prompt_dynamic_functions:
             for msg in messages:
                 if isinstance(msg, _messages.ModelRequest):
-                    for i, part in enumerate(msg.parts):
+                    reevaluated_message_parts: list[_messages.ModelRequestPart] = []
+                    for part in msg.parts:
                         if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
                             # Look up the runner by its ref
                             if runner := self.system_prompt_dynamic_functions.get(  # pragma: lax no cover
                                 part.dynamic_ref
                             ):
                                 updated_part_content = await runner.run(run_context)
-                                msg.parts[i] = _messages.SystemPromptPart(
-                                    updated_part_content, dynamic_ref=part.dynamic_ref
-                                )
+                                part = _messages.SystemPromptPart(updated_part_content, dynamic_ref=part.dynamic_ref)
+
+                        reevaluated_message_parts.append(part)
+
+                    # Replace message parts with reevaluated ones to prevent mutating parts list
+                    if reevaluated_message_parts != msg.parts:
+                        msg.parts = reevaluated_message_parts
 
     async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
         """Build the initial messages for the conversation."""
@@ -280,8 +365,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
 
     request: _messages.ModelRequest
 
-    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
-    _did_stream: bool = field(default=False, repr=False)
+    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None)
+    _did_stream: bool = field(repr=False, init=False, default=False)
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -310,13 +395,13 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
             self._did_stream = True
             ctx.state.usage.requests += 1
             agent_stream = result.AgentStream[DepsT, T](
-                streamed_response,
-                ctx.deps.output_schema,
-                model_request_parameters,
-                ctx.deps.output_validators,
-                build_run_context(ctx),
-                ctx.deps.usage_limits,
-                ctx.deps.tool_manager,
+                _raw_stream_response=streamed_response,
+                _output_schema=ctx.deps.output_schema,
+                _model_request_parameters=model_request_parameters,
+                _output_validators=ctx.deps.output_validators,
+                _run_ctx=build_run_context(ctx),
+                _usage_limits=ctx.deps.usage_limits,
+                _tool_manager=ctx.deps.tool_manager,
             )
             yield agent_stream
             # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
@@ -396,14 +481,14 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
 
     model_response: _messages.ModelResponse
 
-    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
+    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
     _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
-        default=None, repr=False
+        default=None, init=False, repr=False
     )
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]:  # noqa UP007
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
         async with self.stream(ctx):
             pass
         assert self._next_node is not None, 'the stream should set `self._next_node` before it ends'
@@ -506,13 +591,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         if output_final_result:
             final_result = output_final_result[0]
             self._next_node = self._handle_final_result(ctx, final_result, output_parts)
-        elif deferred_tool_calls := ctx.deps.tool_manager.get_deferred_tool_calls(tool_calls):
-            if not ctx.deps.output_schema.allows_deferred_tool_calls:
-                raise exceptions.UserError(
-                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
-                )
-            final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None)
-            self._next_node = self._handle_final_result(ctx, final_result, output_parts)
         else:
             instructions = await ctx.deps.get_instructions(run_context)
             self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
@@ -557,7 +635,7 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
             ctx.state.increment_retries(ctx.deps.max_result_retries, e)
             return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
         else:
-            return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
+            return self._handle_final_result(ctx, result.FinalResult(result_data), [])
 
 
 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
@@ -572,16 +650,10 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
         trace_include_content=ctx.deps.instrumentation_settings is not None
         and ctx.deps.instrumentation_settings.include_content,
         run_step=ctx.state.run_step,
+        tool_call_approved=ctx.state.run_step == 0 and ctx.deps.tool_call_results is not None,
     )
 
 
-def multi_modal_content_identifier(identifier: str | bytes) -> str:
-    """Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
-    if isinstance(identifier, str):
-        identifier = identifier.encode('utf-8')
-    return hashlib.sha1(identifier).hexdigest()[:6]
-
-
 async def process_function_tools(  # noqa: C901
     tool_manager: ToolManager[DepsT],
     tool_calls: list[_messages.ToolCallPart],
@@ -599,7 +671,10 @@ async def process_function_tools(  # noqa: C901
     tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list)
     for call in tool_calls:
         tool_def = tool_manager.get_tool_def(call.tool_name)
-        kind = tool_def.kind if tool_def else 'unknown'
+        if tool_def:
+            kind = tool_def.kind
+        else:
+            kind = 'unknown'
         tool_calls_by_kind[kind].append(call)
 
     # First, we handle output tool calls
@@ -662,132 +737,224 @@ async def process_function_tools(  # noqa: C901
         ctx.state.increment_retries(ctx.deps.max_result_retries)
         calls_to_run.extend(tool_calls_by_kind['unknown'])
 
-    for call in calls_to_run:
-        yield _messages.FunctionToolCallEvent(call)
+    deferred_tool_results: dict[str, DeferredToolResult] = {}
+    if build_run_context(ctx).tool_call_approved and ctx.deps.tool_call_results is not None:
+        deferred_tool_results = ctx.deps.tool_call_results
 
-    user_parts_by_index: dict[int, list[_messages.UserPromptPart]] = defaultdict(list)
+        # Deferred tool calls are "run" as well, by reading their value from the tool call results
+        calls_to_run.extend(tool_calls_by_kind['external'])
+        calls_to_run.extend(tool_calls_by_kind['unapproved'])
 
-    if calls_to_run:
-        # Run all tool tasks in parallel
-        tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
-        with ctx.deps.tracer.start_as_current_span(
-            'running tools',
-            attributes={
-                'tools': [call.tool_name for call in calls_to_run],
-                'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
-            },
-        ):
-            tasks = [
-                asyncio.create_task(_call_function_tool(tool_manager, call), name=call.tool_name)
-                for call in calls_to_run
-            ]
-
-            pending = tasks
-            while pending:
-                done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-                for task in done:
-                    index = tasks.index(task)
-                    tool_part, tool_user_parts = task.result()
-                    yield _messages.FunctionToolResultEvent(tool_part)
+        result_tool_call_ids = set(deferred_tool_results.keys())
+        tool_call_ids_to_run = {call.tool_call_id for call in calls_to_run}
+        if tool_call_ids_to_run != result_tool_call_ids:
+            raise exceptions.UserError(
+                'Tool call results need to be provided for all deferred tool calls. '
+                f'Expected: {tool_call_ids_to_run}, got: {result_tool_call_ids}'
+            )
 
-                    tool_parts_by_index[index] = tool_part
-                    user_parts_by_index[index] = tool_user_parts
+    deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
 
-        # We append the results at the end, rather than as they are received, to retain a consistent ordering
-        # This is mostly just to simplify testing
-        for k in sorted(tool_parts_by_index):
-            output_parts.append(tool_parts_by_index[k])
+    if calls_to_run:
+        async for event in _call_tools(
+            tool_manager,
+            calls_to_run,
+            deferred_tool_results,
+            ctx.deps.tracer,
+            ctx.deps.usage_limits,
+            output_parts,
+            deferred_calls,
+        ):
+            yield event
 
-    # Finally, we handle deferred tool calls
-    for call in tool_calls_by_kind['deferred']:
+    # Finally, we handle deferred tool calls (unless they were already included in the run because results were provided)
+    if not deferred_tool_results:
         if final_result:
-            output_parts.append(
-                _messages.ToolReturnPart(
-                    tool_name=call.tool_name,
-                    content='Tool not executed - a final result was already processed.',
-                    tool_call_id=call.tool_call_id,
+            for call in [*tool_calls_by_kind['external'], *tool_calls_by_kind['unapproved']]:
+                output_parts.append(
+                    _messages.ToolReturnPart(
+                        tool_name=call.tool_name,
+                        content='Tool not executed - a final result was already processed.',
+                        tool_call_id=call.tool_call_id,
+                    )
                 )
-            )
         else:
-            yield _messages.FunctionToolCallEvent(call)
+            for call in tool_calls_by_kind['external']:
+                deferred_calls['external'].append(call)
+                yield _messages.FunctionToolCallEvent(call)
 
-    for k in sorted(user_parts_by_index):
-        output_parts.extend(user_parts_by_index[k])
+            for call in tool_calls_by_kind['unapproved']:
+                deferred_calls['unapproved'].append(call)
+                yield _messages.FunctionToolCallEvent(call)
+
+    if not final_result and deferred_calls:
+        if not ctx.deps.output_schema.allows_deferred_tools:
+            raise exceptions.UserError(
+                'A deferred tool call was present, but `DeferredToolRequests` is not among output types. To resolve this, add `DeferredToolRequests` to the list of output types for this agent.'
+            )
+        deferred_tool_requests = _output.DeferredToolRequests(
+            calls=deferred_calls['external'],
+            approvals=deferred_calls['unapproved'],
+        )
+
+        final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)
 
     if final_result:
         output_final_result.append(final_result)
 
 
-async def _call_function_tool(
+async def _call_tools(
     tool_manager: ToolManager[DepsT],
-    tool_call: _messages.ToolCallPart,
-) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.UserPromptPart]]:
-    try:
-        tool_result = await tool_manager.handle_call(tool_call)
-    except ToolRetryError as e:
-        return (e.tool_retry, [])
+    tool_calls: list[_messages.ToolCallPart],
+    deferred_tool_results: dict[str, DeferredToolResult],
+    tracer: Tracer,
+    usage_limits: _usage.UsageLimits | None,
+    output_parts: list[_messages.ModelRequestPart],
+    output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+) -> AsyncIterator[_messages.HandleResponseEvent]:
+    tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
+    user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
+    deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
 
-    tool_part = _messages.ToolReturnPart(
-        tool_name=tool_call.tool_name,
-        content=tool_result,
-        tool_call_id=tool_call.tool_call_id,
-    )
-    user_parts: list[_messages.UserPromptPart] = []
+    for call in tool_calls:
+        yield _messages.FunctionToolCallEvent(call)
 
-    if isinstance(tool_result, _messages.ToolReturn):
-        if (
-            isinstance(tool_result.return_value, _messages.MultiModalContentTypes)
-            or isinstance(tool_result.return_value, list)
-            and any(
-                isinstance(content, _messages.MultiModalContentTypes)
-                for content in tool_result.return_value  # type: ignore
-            )
-        ):
-            raise exceptions.UserError(
-                f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContentTypes` objects. '
-                f'Please use `content` instead.'
+    # Run all tool tasks in parallel
+    with tracer.start_as_current_span(
+        'running tools',
+        attributes={
+            'tools': [call.tool_name for call in tool_calls],
+            'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
+        },
+    ):
+        tasks = [
+            asyncio.create_task(
+                _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits),
+                name=call.tool_name,
             )
+            for call in tool_calls
+        ]
+
+        pending = tasks
+        while pending:
+            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                index = tasks.index(task)
+                try:
+                    tool_part, tool_user_part = task.result()
+                except exceptions.CallDeferred:
+                    deferred_calls_by_index[index] = 'external'
+                except exceptions.ApprovalRequired:
+                    deferred_calls_by_index[index] = 'unapproved'
+                else:
+                    yield _messages.FunctionToolResultEvent(tool_part)
 
-        tool_part.content = tool_result.return_value  # type: ignore
-        tool_part.metadata = tool_result.metadata
-        if tool_result.content:
-            user_parts.append(
-                _messages.UserPromptPart(
-                    content=tool_result.content,
-                    part_kind='user-prompt',
-                )
+                    tool_parts_by_index[index] = tool_part
+                    if tool_user_part:
+                        user_parts_by_index[index] = tool_user_part
+
+    # We append the results at the end, rather than as they are received, to retain a consistent ordering
+    # This is mostly just to simplify testing
+    for k in sorted(tool_parts_by_index):
+        output_parts.append(tool_parts_by_index[k])
+
+    for k in sorted(user_parts_by_index):
+        output_parts.append(user_parts_by_index[k])
+
+    for k in sorted(deferred_calls_by_index):
+        output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
+
+
+async def _call_tool(
+    tool_manager: ToolManager[DepsT],
+    tool_call: _messages.ToolCallPart,
+    tool_call_result: DeferredToolResult | None,
+    usage_limits: _usage.UsageLimits | None,
+) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]:
+    try:
+        if tool_call_result is None:
+            tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
+        elif isinstance(tool_call_result, ToolApproved):
+            if tool_call_result.override_args is not None:
+                tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
+            tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
+        elif isinstance(tool_call_result, ToolDenied):
+            return _messages.ToolReturnPart(
+                tool_name=tool_call.tool_name,
+                content=tool_call_result.message,
+                tool_call_id=tool_call.tool_call_id,
+            ), None
+        elif isinstance(tool_call_result, exceptions.ModelRetry):
+            m = _messages.RetryPromptPart(
+                content=tool_call_result.message,
+                tool_name=tool_call.tool_name,
+                tool_call_id=tool_call.tool_call_id,
             )
+            raise ToolRetryError(m)
+        elif isinstance(tool_call_result, _messages.RetryPromptPart):
+            tool_call_result.tool_name = tool_call.tool_name
+            tool_call_result.tool_call_id = tool_call.tool_call_id
+            raise ToolRetryError(tool_call_result)
+        else:
+            tool_result = tool_call_result
+    except ToolRetryError as e:
+        return e.tool_retry, None
+
+    if isinstance(tool_result, _messages.ToolReturn):
+        tool_return = tool_result
     else:
+        result_is_list = isinstance(tool_result, list)
+        contents = cast(list[Any], tool_result) if result_is_list else [tool_result]
 
-        def process_content(content: Any) -> Any:
+        return_values: list[Any] = []
+        user_contents: list[str | _messages.UserContent] = []
+        for content in contents:
             if isinstance(content, _messages.ToolReturn):
                 raise exceptions.UserError(
                     f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. '
                     f'`ToolReturn` should be used directly.'
                 )
-            elif isinstance(content, _messages.MultiModalContentTypes):
-                if isinstance(content, _messages.BinaryContent):
-                    identifier = content.identifier or multi_modal_content_identifier(content.data)
-                else:
-                    identifier = multi_modal_content_identifier(content.url)
+            elif isinstance(content, _messages.MultiModalContent):
+                identifier = content.identifier
 
-                user_parts.append(
-                    _messages.UserPromptPart(
-                        content=[f'This is file {identifier}:', content],
-                        part_kind='user-prompt',
-                    )
-                )
-                return f'See file {identifier}'
+                return_values.append(f'See file {identifier}')
+                user_contents.extend([f'This is file {identifier}:', content])
+            else:
+                return_values.append(content)
 
-            return content
+        tool_return = _messages.ToolReturn(
+            return_value=return_values[0] if len(return_values) == 1 and not result_is_list else return_values,
+            content=user_contents,
+        )
 
-    if isinstance(tool_result, list):
-        contents = cast(list[Any], tool_result)
-        tool_part.content = [process_content(content) for content in contents]
-    else:
-        tool_part.content = process_content(tool_result)
+    if (
+        isinstance(tool_return.return_value, _messages.MultiModalContent)
+        or isinstance(tool_return.return_value, list)
+        and any(
+            isinstance(content, _messages.MultiModalContent)
+            for content in tool_return.return_value  # type: ignore
+        )
+    ):
+        raise exceptions.UserError(
+            f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContent` objects. '
+            f'Please use `content` instead.'
+        )
+
+    return_part = _messages.ToolReturnPart(
+        tool_name=tool_call.tool_name,
+        tool_call_id=tool_call.tool_call_id,
+        content=tool_return.return_value,  # type: ignore
+        metadata=tool_return.metadata,
+    )
+
+    user_part: _messages.UserPromptPart | None = None
+    if tool_return.content:
+        user_part = _messages.UserPromptPart(
+            content=tool_return.content,
+            part_kind='user-prompt',
+        )
 
-    return (tool_part, user_parts)
+    return return_part, user_part
 
 
 @dataclasses.dataclass
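For reference, the task loop in `_call_tools` above turns two new exceptions into deferred-call buckets: `exceptions.CallDeferred` marks a call as `external` and `exceptions.ApprovalRequired` as `unapproved`, and `build_run_context` sets `tool_call_approved` on the run context when results are supplied on run step zero. A minimal, hypothetical tool using that approval path — the decorator and model name are illustrative; only the exception name and the `tool_call_approved` field come from this diff:

    from pydantic_ai import Agent, RunContext
    from pydantic_ai.exceptions import ApprovalRequired  # raised by tools, caught in _call_tools

    agent = Agent('openai:gpt-5')

    @agent.tool
    def delete_file(ctx: RunContext[None], path: str) -> str:
        # On the first pass this raises, so the call surfaces in
        # DeferredToolRequests.approvals; once the run is resumed with a
        # ToolApproved result, tool_call_approved is True and the tool executes.
        if not ctx.tool_call_approved:
            raise ApprovalRequired
        return f'{path} deleted'  # actual deletion elided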