pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of pydantic-ai-slim has been flagged as potentially problematic.
Files changed (70)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_agent_graph.py +310 -140
  3. pydantic_ai/_function_schema.py +5 -5
  4. pydantic_ai/_griffe.py +2 -1
  5. pydantic_ai/_otel_messages.py +2 -2
  6. pydantic_ai/_output.py +31 -35
  7. pydantic_ai/_parts_manager.py +4 -4
  8. pydantic_ai/_run_context.py +3 -1
  9. pydantic_ai/_system_prompt.py +2 -2
  10. pydantic_ai/_tool_manager.py +3 -22
  11. pydantic_ai/_utils.py +14 -26
  12. pydantic_ai/ag_ui.py +7 -8
  13. pydantic_ai/agent/__init__.py +70 -9
  14. pydantic_ai/agent/abstract.py +35 -4
  15. pydantic_ai/agent/wrapper.py +6 -0
  16. pydantic_ai/builtin_tools.py +2 -2
  17. pydantic_ai/common_tools/duckduckgo.py +4 -2
  18. pydantic_ai/durable_exec/temporal/__init__.py +4 -2
  19. pydantic_ai/durable_exec/temporal/_agent.py +23 -2
  20. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  21. pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
  22. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  23. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  24. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  25. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  26. pydantic_ai/exceptions.py +45 -2
  27. pydantic_ai/format_prompt.py +2 -2
  28. pydantic_ai/mcp.py +2 -2
  29. pydantic_ai/messages.py +73 -25
  30. pydantic_ai/models/__init__.py +5 -4
  31. pydantic_ai/models/anthropic.py +5 -5
  32. pydantic_ai/models/bedrock.py +58 -56
  33. pydantic_ai/models/cohere.py +3 -3
  34. pydantic_ai/models/fallback.py +2 -2
  35. pydantic_ai/models/function.py +25 -23
  36. pydantic_ai/models/gemini.py +9 -12
  37. pydantic_ai/models/google.py +3 -3
  38. pydantic_ai/models/groq.py +4 -4
  39. pydantic_ai/models/huggingface.py +4 -4
  40. pydantic_ai/models/instrumented.py +30 -16
  41. pydantic_ai/models/mcp_sampling.py +3 -1
  42. pydantic_ai/models/mistral.py +6 -6
  43. pydantic_ai/models/openai.py +18 -27
  44. pydantic_ai/models/test.py +24 -4
  45. pydantic_ai/output.py +27 -32
  46. pydantic_ai/profiles/__init__.py +3 -3
  47. pydantic_ai/profiles/groq.py +1 -1
  48. pydantic_ai/profiles/openai.py +25 -4
  49. pydantic_ai/providers/anthropic.py +2 -3
  50. pydantic_ai/providers/bedrock.py +3 -2
  51. pydantic_ai/result.py +144 -41
  52. pydantic_ai/retries.py +10 -29
  53. pydantic_ai/run.py +12 -5
  54. pydantic_ai/tools.py +126 -22
  55. pydantic_ai/toolsets/__init__.py +4 -1
  56. pydantic_ai/toolsets/_dynamic.py +4 -4
  57. pydantic_ai/toolsets/abstract.py +18 -2
  58. pydantic_ai/toolsets/approval_required.py +32 -0
  59. pydantic_ai/toolsets/combined.py +7 -12
  60. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  61. pydantic_ai/toolsets/filtered.py +1 -1
  62. pydantic_ai/toolsets/function.py +13 -4
  63. pydantic_ai/toolsets/wrapper.py +2 -1
  64. pydantic_ai/usage.py +7 -5
  65. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/METADATA +5 -6
  66. pydantic_ai_slim-1.0.0b1.dist-info/RECORD +120 -0
  67. pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
  68. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/WHEEL +0 -0
  69. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/entry_points.txt +0 -0
  70. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_agent_graph.py

@@ -4,14 +4,14 @@ import asyncio
 import dataclasses
 import hashlib
 from collections import defaultdict, deque
-from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
-from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast
 
 from opentelemetry.trace import Tracer
-from typing_extensions import TypeGuard, TypeVar, assert_never
+from typing_extensions import TypeVar, assert_never
 
 from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx  # type: ignore
 from pydantic_ai._tool_manager import ToolManager
@@ -24,7 +24,14 @@ from . import _output, _system_prompt, exceptions, messages as _messages, models
 from .exceptions import ToolRetryError
 from .output import OutputDataT, OutputSpec
 from .settings import ModelSettings
-from .tools import RunContext, ToolDefinition, ToolKind
+from .tools import (
+    DeferredToolResult,
+    RunContext,
+    ToolApproved,
+    ToolDefinition,
+    ToolDenied,
+    ToolKind,
+)
 
 if TYPE_CHECKING:
     from .models.instrumented import InstrumentationSettings
@@ -59,19 +66,19 @@ _HistoryProcessorSyncWithCtx = Callable[[RunContext[DepsT], list[_messages.Model
 _HistoryProcessorAsyncWithCtx = Callable[
     [RunContext[DepsT], list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]]
 ]
-HistoryProcessor = Union[
-    _HistoryProcessorSync,
-    _HistoryProcessorAsync,
-    _HistoryProcessorSyncWithCtx[DepsT],
-    _HistoryProcessorAsyncWithCtx[DepsT],
-]
+HistoryProcessor = (
+    _HistoryProcessorSync
+    | _HistoryProcessorAsync
+    | _HistoryProcessorSyncWithCtx[DepsT]
+    | _HistoryProcessorAsyncWithCtx[DepsT]
+)
 """A function that processes a list of model messages and returns a list of model messages.
 
 Can optionally accept a `RunContext` as a parameter.
 """
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class GraphAgentState:
     """State kept across the execution of the agent graph."""
 
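For reference, any of the four callable shapes in the `HistoryProcessor` union above is accepted, typically passed via `Agent`'s `history_processors` argument. A minimal sketch of the two most common shapes (the function names and the trim-to-last-ten policy are illustrative only, not part of the package):

from pydantic_ai import RunContext
from pydantic_ai.messages import ModelMessage


def keep_recent(messages: list[ModelMessage]) -> list[ModelMessage]:
    # Plain sync processor: keep only the last ten messages.
    return messages[-10:]


async def keep_recent_with_ctx(ctx: RunContext[None], messages: list[ModelMessage]) -> list[ModelMessage]:
    # Async variant that also receives the RunContext (deps, usage, run_step, ...).
    return messages[-10:]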
@@ -92,7 +99,7 @@ class GraphAgentState:
             raise exceptions.UnexpectedModelBehavior(message)
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     """Dependencies/config passed to the agent graph."""
 
@@ -115,9 +122,10 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
 
     builtin_tools: list[AbstractBuiltinTool] = dataclasses.field(repr=False)
     tool_manager: ToolManager[DepsT]
+    tool_call_results: dict[str, DeferredToolResult] | None
 
     tracer: Tracer
-    instrumentation_settings: InstrumentationSettings | None = None
+    instrumentation_settings: InstrumentationSettings | None
 
 
 class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
@@ -149,6 +157,8 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
     user_prompt: str | Sequence[_messages.UserContent] | None
 
+    _: dataclasses.KW_ONLY
+
     instructions: str | None
     instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
 
@@ -158,7 +168,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], CallToolsNode[DepsT, NodeRunEndT]]:  # noqa UP007
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | CallToolsNode[DepsT, NodeRunEndT]:
         try:
             ctx_messages = get_captured_run_messages()
         except LookupError:
@@ -184,26 +194,29 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
         else:
             parts.extend(await self._sys_parts(run_context))
 
+        if (tool_call_results := ctx.deps.tool_call_results) is not None:
+            if messages and (last_message := messages[-1]) and isinstance(last_message, _messages.ModelRequest):
+                # If tool call results were provided, that means the previous run ended on deferred tool calls.
+                # That run would typically have ended on a `ModelResponse`, but if it had a mix of deferred tool calls and ones that could already be executed,
+                # a `ModelRequest` would already have been added to the history with the preliminary results, even if it wouldn't have been sent to the model yet.
+                # So now that we have all of the deferred results, we roll back to the last `ModelResponse` and store the contents of the `ModelRequest` on `deferred_tool_results` to be handled by `CallToolsNode`.
+                ctx.deps.tool_call_results = self._update_tool_call_results_from_model_request(
+                    tool_call_results, last_message
+                )
+                messages.pop()
+
+            if not messages:
+                raise exceptions.UserError('Tool call results were provided, but the message history is empty.')
+
         if messages and (last_message := messages[-1]):
             if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
                 # Drop last message from history and reuse its parts
                 messages.pop()
                 parts.extend(last_message.parts)
             elif isinstance(last_message, _messages.ModelResponse):
-                if self.user_prompt is None:
-                    # `CallToolsNode` requires the tool manager to be prepared for the run step
-                    # This will raise errors for any tool name conflicts
-                    ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
-
-                    # Skip ModelRequestNode and go directly to CallToolsNode
-                    return CallToolsNode[DepsT, NodeRunEndT](model_response=last_message)
-                elif any(isinstance(part, _messages.ToolCallPart) for part in last_message.parts):
-                    raise exceptions.UserError(
-                        'Cannot provide a new user prompt when the message history ends with '
-                        'a model response containing unprocessed tool calls. Either process the '
-                        'tool calls first (by calling `iter` with `user_prompt=None`) or append a '
-                        '`ModelRequest` with `ToolResultPart`s.'
-                    )
+                call_tools_node = await self._handle_message_history_model_response(ctx, last_message)
+                if call_tools_node is not None:
+                    return call_tools_node
 
         if self.user_prompt is not None:
             parts.append(_messages.UserPromptPart(self.user_prompt))
@@ -213,6 +226,74 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
         return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)
 
+    async def _handle_message_history_model_response(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        message: _messages.ModelResponse,
+    ) -> CallToolsNode[DepsT, NodeRunEndT] | None:
+        unprocessed_tool_calls = any(isinstance(part, _messages.ToolCallPart) for part in message.parts)
+        if unprocessed_tool_calls:
+            if self.user_prompt is not None:
+                raise exceptions.UserError(
+                    'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
+                )
+        else:
+            if ctx.deps.tool_call_results is not None:
+                raise exceptions.UserError(
+                    'Tool call results were provided, but the message history does not contain any unprocessed tool calls.'
+                )
+
+        if unprocessed_tool_calls or self.user_prompt is None:
+            # `CallToolsNode` requires the tool manager to be prepared for the run step
+            # This will raise errors for any tool name conflicts
+            run_context = build_run_context(ctx)
+            ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
+
+            # Skip ModelRequestNode and go directly to CallToolsNode
+            return CallToolsNode[DepsT, NodeRunEndT](model_response=message)
+
+    def _update_tool_call_results_from_model_request(
+        self, tool_call_results: dict[str, DeferredToolResult], message: _messages.ModelRequest
+    ) -> dict[str, DeferredToolResult]:
+        last_tool_return: _messages.ToolReturn | None = None
+        user_content: list[str | _messages.UserContent] = []
+        for part in message.parts:
+            if isinstance(part, _messages.ToolReturnPart):
+                if part.tool_call_id in tool_call_results:
+                    raise exceptions.UserError(
+                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
+                    )
+
+                last_tool_return = _messages.ToolReturn(return_value=part.content, metadata=part.metadata)
+                tool_call_results[part.tool_call_id] = last_tool_return
+            elif isinstance(part, _messages.RetryPromptPart):
+                if part.tool_call_id in tool_call_results:
+                    raise exceptions.UserError(
+                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
+                    )
+
+                tool_call_results[part.tool_call_id] = part
+            elif isinstance(part, _messages.UserPromptPart):
+                # Tools can return user parts via `ToolReturn.content` or by returning multi-modal content.
+                # These go together with a specific `ToolReturnPart`, but we don't have a way to know which,
+                # so (below) we just add them to the last one, matching the tool-results-before-user-parts order of the request.
+                if isinstance(part.content, str):
+                    user_content.append(part.content)
+                else:
+                    user_content.extend(part.content)
+            else:
+                raise exceptions.UserError(f'Unexpected message part type: {type(part)}')  # pragma: no cover
+
+        if user_content:
+            if last_tool_return is None:
+                raise exceptions.UserError(
+                    'Tool call results were provided, but the last message in the history was a `ModelRequest` with user parts not tied to preliminary tool results.'
+                )
+            assert last_tool_return is not None
+            last_tool_return.content = user_content
+
+        return tool_call_results
+
     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
     ) -> None:
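The `tool_call_results` mapping consumed above is keyed by `ToolCallPart.tool_call_id`, and each value is a `DeferredToolResult`: `ToolApproved`, `ToolDenied`, a retry (`ModelRetry` / `RetryPromptPart`), or a plain return value. A hedged sketch of such a mapping (the call IDs are invented, and the constructor keywords of `ToolApproved`/`ToolDenied` are assumed from how their attributes are read in this diff):

from pydantic_ai.tools import ToolApproved, ToolDenied

results = {
    # Approve the call as-is; `_call_tool` below then executes it via the tool manager.
    'call_1': ToolApproved(),
    # Approve but replace the arguments the model proposed (assumed keyword: override_args).
    'call_2': ToolApproved(override_args={'path': '/tmp/safe.txt'}),
    # Deny: the message becomes the ToolReturnPart content sent back to the model.
    'call_3': ToolDenied(message='The user rejected this action.'),
    # A plain value is used directly as the tool's return value.
    'call_4': {'status': 'done'},
}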
@@ -280,8 +361,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
 
     request: _messages.ModelRequest
 
-    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
-    _did_stream: bool = field(default=False, repr=False)
+    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None)
+    _did_stream: bool = field(repr=False, init=False, default=False)
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -310,13 +391,13 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
             self._did_stream = True
             ctx.state.usage.requests += 1
             agent_stream = result.AgentStream[DepsT, T](
-                streamed_response,
-                ctx.deps.output_schema,
-                model_request_parameters,
-                ctx.deps.output_validators,
-                build_run_context(ctx),
-                ctx.deps.usage_limits,
-                ctx.deps.tool_manager,
+                _raw_stream_response=streamed_response,
+                _output_schema=ctx.deps.output_schema,
+                _model_request_parameters=model_request_parameters,
+                _output_validators=ctx.deps.output_validators,
+                _run_ctx=build_run_context(ctx),
+                _usage_limits=ctx.deps.usage_limits,
+                _tool_manager=ctx.deps.tool_manager,
             )
             yield agent_stream
             # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
@@ -396,14 +477,14 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
 
     model_response: _messages.ModelResponse
 
-    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
+    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
     _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
-        default=None, repr=False
+        default=None, init=False, repr=False
     )
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]:  # noqa UP007
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
         async with self.stream(ctx):
             pass
         assert self._next_node is not None, 'the stream should set `self._next_node` before it ends'
@@ -506,13 +587,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
             if output_final_result:
                 final_result = output_final_result[0]
                 self._next_node = self._handle_final_result(ctx, final_result, output_parts)
-            elif deferred_tool_calls := ctx.deps.tool_manager.get_deferred_tool_calls(tool_calls):
-                if not ctx.deps.output_schema.allows_deferred_tool_calls:
-                    raise exceptions.UserError(
-                        'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
-                    )
-                final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None)
-                self._next_node = self._handle_final_result(ctx, final_result, output_parts)
             else:
                 instructions = await ctx.deps.get_instructions(run_context)
                 self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
@@ -557,7 +631,7 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
             ctx.state.increment_retries(ctx.deps.max_result_retries, e)
             return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
         else:
-            return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
+            return self._handle_final_result(ctx, result.FinalResult(result_data), [])
 
 
 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
@@ -572,6 +646,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
         trace_include_content=ctx.deps.instrumentation_settings is not None
         and ctx.deps.instrumentation_settings.include_content,
         run_step=ctx.state.run_step,
+        tool_call_approved=ctx.state.run_step == 0 and ctx.deps.tool_call_results is not None,
     )
 
 
@@ -599,7 +674,10 @@ async def process_function_tools( # noqa: C901
     tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list)
     for call in tool_calls:
         tool_def = tool_manager.get_tool_def(call.tool_name)
-        kind = tool_def.kind if tool_def else 'unknown'
+        if tool_def:
+            kind = tool_def.kind
+        else:
+            kind = 'unknown'
         tool_calls_by_kind[kind].append(call)
 
     # First, we handle output tool calls
@@ -662,132 +740,224 @@ async def process_function_tools( # noqa: C901
         ctx.state.increment_retries(ctx.deps.max_result_retries)
         calls_to_run.extend(tool_calls_by_kind['unknown'])
 
-    for call in calls_to_run:
-        yield _messages.FunctionToolCallEvent(call)
+    deferred_tool_results: dict[str, DeferredToolResult] = {}
+    if build_run_context(ctx).tool_call_approved and ctx.deps.tool_call_results is not None:
+        deferred_tool_results = ctx.deps.tool_call_results
 
-    user_parts_by_index: dict[int, list[_messages.UserPromptPart]] = defaultdict(list)
+        # Deferred tool calls are "run" as well, by reading their value from the tool call results
+        calls_to_run.extend(tool_calls_by_kind['external'])
+        calls_to_run.extend(tool_calls_by_kind['unapproved'])
 
-    if calls_to_run:
-        # Run all tool tasks in parallel
-        tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
-        with ctx.deps.tracer.start_as_current_span(
-            'running tools',
-            attributes={
-                'tools': [call.tool_name for call in calls_to_run],
-                'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
-            },
-        ):
-            tasks = [
-                asyncio.create_task(_call_function_tool(tool_manager, call), name=call.tool_name)
-                for call in calls_to_run
-            ]
-
-            pending = tasks
-            while pending:
-                done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-                for task in done:
-                    index = tasks.index(task)
-                    tool_part, tool_user_parts = task.result()
-                    yield _messages.FunctionToolResultEvent(tool_part)
+        result_tool_call_ids = set(deferred_tool_results.keys())
+        tool_call_ids_to_run = {call.tool_call_id for call in calls_to_run}
+        if tool_call_ids_to_run != result_tool_call_ids:
+            raise exceptions.UserError(
+                'Tool call results need to be provided for all deferred tool calls. '
+                f'Expected: {tool_call_ids_to_run}, got: {result_tool_call_ids}'
+            )
 
-                    tool_parts_by_index[index] = tool_part
-                    user_parts_by_index[index] = tool_user_parts
+    deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
 
-        # We append the results at the end, rather than as they are received, to retain a consistent ordering
-        # This is mostly just to simplify testing
-        for k in sorted(tool_parts_by_index):
-            output_parts.append(tool_parts_by_index[k])
+    if calls_to_run:
+        async for event in _call_tools(
+            tool_manager,
+            calls_to_run,
+            deferred_tool_results,
+            ctx.deps.tracer,
+            output_parts,
+            deferred_calls,
+        ):
+            yield event
 
-    # Finally, we handle deferred tool calls
-    for call in tool_calls_by_kind['deferred']:
+    # Finally, we handle deferred tool calls (unless they were already included in the run because results were provided)
+    if not deferred_tool_results:
         if final_result:
-            output_parts.append(
-                _messages.ToolReturnPart(
-                    tool_name=call.tool_name,
-                    content='Tool not executed - a final result was already processed.',
-                    tool_call_id=call.tool_call_id,
+            for call in [*tool_calls_by_kind['external'], *tool_calls_by_kind['unapproved']]:
+                output_parts.append(
+                    _messages.ToolReturnPart(
+                        tool_name=call.tool_name,
+                        content='Tool not executed - a final result was already processed.',
+                        tool_call_id=call.tool_call_id,
+                    )
                 )
-            )
         else:
-            yield _messages.FunctionToolCallEvent(call)
+            for call in tool_calls_by_kind['external']:
+                deferred_calls['external'].append(call)
+                yield _messages.FunctionToolCallEvent(call)
 
-    for k in sorted(user_parts_by_index):
-        output_parts.extend(user_parts_by_index[k])
+            for call in tool_calls_by_kind['unapproved']:
+                deferred_calls['unapproved'].append(call)
+                yield _messages.FunctionToolCallEvent(call)
+
+    if not final_result and deferred_calls:
+        if not ctx.deps.output_schema.allows_deferred_tools:
+            raise exceptions.UserError(
+                'A deferred tool call was present, but `DeferredToolRequests` is not among output types. To resolve this, add `DeferredToolRequests` to the list of output types for this agent.'
+            )
+        deferred_tool_requests = _output.DeferredToolRequests(
+            calls=deferred_calls['external'],
+            approvals=deferred_calls['unapproved'],
+        )
+
+        final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)
 
     if final_result:
         output_final_result.append(final_result)
 
 
-async def _call_function_tool(
+async def _call_tools(
     tool_manager: ToolManager[DepsT],
-    tool_call: _messages.ToolCallPart,
-) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.UserPromptPart]]:
-    try:
-        tool_result = await tool_manager.handle_call(tool_call)
-    except ToolRetryError as e:
-        return (e.tool_retry, [])
+    tool_calls: list[_messages.ToolCallPart],
+    deferred_tool_results: dict[str, DeferredToolResult],
+    tracer: Tracer,
+    output_parts: list[_messages.ModelRequestPart],
+    output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+) -> AsyncIterator[_messages.HandleResponseEvent]:
+    tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
+    user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
+    deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
 
-    tool_part = _messages.ToolReturnPart(
-        tool_name=tool_call.tool_name,
-        content=tool_result,
-        tool_call_id=tool_call.tool_call_id,
-    )
-    user_parts: list[_messages.UserPromptPart] = []
+    for call in tool_calls:
+        yield _messages.FunctionToolCallEvent(call)
 
-    if isinstance(tool_result, _messages.ToolReturn):
-        if (
-            isinstance(tool_result.return_value, _messages.MultiModalContentTypes)
-            or isinstance(tool_result.return_value, list)
-            and any(
-                isinstance(content, _messages.MultiModalContentTypes)
-                for content in tool_result.return_value  # type: ignore
-            )
-        ):
-            raise exceptions.UserError(
-                f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContentTypes` objects. '
-                f'Please use `content` instead.'
+    # Run all tool tasks in parallel
+    with tracer.start_as_current_span(
+        'running tools',
+        attributes={
+            'tools': [call.tool_name for call in tool_calls],
+            'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
+        },
+    ):
+        tasks = [
+            asyncio.create_task(
+                _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id)),
+                name=call.tool_name,
             )
+            for call in tool_calls
+        ]
+
+        pending = tasks
+        while pending:
+            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                index = tasks.index(task)
+                try:
+                    tool_part, tool_user_part = task.result()
+                except exceptions.CallDeferred:
+                    deferred_calls_by_index[index] = 'external'
+                except exceptions.ApprovalRequired:
+                    deferred_calls_by_index[index] = 'unapproved'
+                else:
+                    yield _messages.FunctionToolResultEvent(tool_part)
 
-        tool_part.content = tool_result.return_value  # type: ignore
-        tool_part.metadata = tool_result.metadata
-        if tool_result.content:
-            user_parts.append(
-                _messages.UserPromptPart(
-                    content=tool_result.content,
-                    part_kind='user-prompt',
-                )
+                    tool_parts_by_index[index] = tool_part
+                    if tool_user_part:
+                        user_parts_by_index[index] = tool_user_part
+
+    # We append the results at the end, rather than as they are received, to retain a consistent ordering
+    # This is mostly just to simplify testing
+    for k in sorted(tool_parts_by_index):
+        output_parts.append(tool_parts_by_index[k])
+
+    for k in sorted(user_parts_by_index):
+        output_parts.append(user_parts_by_index[k])
+
+    for k in sorted(deferred_calls_by_index):
+        output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
+
+
+async def _call_tool(
+    tool_manager: ToolManager[DepsT],
+    tool_call: _messages.ToolCallPart,
+    tool_call_result: DeferredToolResult | None,
+) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]:
+    try:
+        if tool_call_result is None:
+            tool_result = await tool_manager.handle_call(tool_call)
+        elif isinstance(tool_call_result, ToolApproved):
+            if tool_call_result.override_args is not None:
+                tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
+            tool_result = await tool_manager.handle_call(tool_call)
+        elif isinstance(tool_call_result, ToolDenied):
+            return _messages.ToolReturnPart(
+                tool_name=tool_call.tool_name,
+                content=tool_call_result.message,
+                tool_call_id=tool_call.tool_call_id,
+            ), None
+        elif isinstance(tool_call_result, exceptions.ModelRetry):
+            m = _messages.RetryPromptPart(
+                content=tool_call_result.message,
+                tool_name=tool_call.tool_name,
+                tool_call_id=tool_call.tool_call_id,
             )
+            raise ToolRetryError(m)
+        elif isinstance(tool_call_result, _messages.RetryPromptPart):
+            tool_call_result.tool_name = tool_call.tool_name
+            tool_call_result.tool_call_id = tool_call.tool_call_id
+            raise ToolRetryError(tool_call_result)
+        else:
+            tool_result = tool_call_result
+    except ToolRetryError as e:
+        return e.tool_retry, None
+
+    if isinstance(tool_result, _messages.ToolReturn):
+        tool_return = tool_result
     else:
+        result_is_list = isinstance(tool_result, list)
+        contents = cast(list[Any], tool_result) if result_is_list else [tool_result]
 
-        def process_content(content: Any) -> Any:
+        return_values: list[Any] = []
+        user_contents: list[str | _messages.UserContent] = []
+        for content in contents:
             if isinstance(content, _messages.ToolReturn):
                 raise exceptions.UserError(
                     f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. '
                     f'`ToolReturn` should be used directly.'
                 )
-            elif isinstance(content, _messages.MultiModalContentTypes):
+            elif isinstance(content, _messages.MultiModalContent):
                 if isinstance(content, _messages.BinaryContent):
                     identifier = content.identifier or multi_modal_content_identifier(content.data)
                 else:
                     identifier = multi_modal_content_identifier(content.url)
 
-                user_parts.append(
-                    _messages.UserPromptPart(
-                        content=[f'This is file {identifier}:', content],
-                        part_kind='user-prompt',
-                    )
-                )
-                return f'See file {identifier}'
+                return_values.append(f'See file {identifier}')
+                user_contents.extend([f'This is file {identifier}:', content])
+            else:
+                return_values.append(content)
 
-            return content
+        tool_return = _messages.ToolReturn(
+            return_value=return_values[0] if len(return_values) == 1 and not result_is_list else return_values,
+            content=user_contents,
+        )
 
-        if isinstance(tool_result, list):
-            contents = cast(list[Any], tool_result)
-            tool_part.content = [process_content(content) for content in contents]
-        else:
-            tool_part.content = process_content(tool_result)
+    if (
+        isinstance(tool_return.return_value, _messages.MultiModalContent)
+        or isinstance(tool_return.return_value, list)
+        and any(
+            isinstance(content, _messages.MultiModalContent)
+            for content in tool_return.return_value  # type: ignore
+        )
+    ):
+        raise exceptions.UserError(
+            f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContent` objects. '
+            f'Please use `content` instead.'
+        )
+
+    return_part = _messages.ToolReturnPart(
+        tool_name=tool_call.tool_name,
+        tool_call_id=tool_call.tool_call_id,
+        content=tool_return.return_value,  # type: ignore
+        metadata=tool_return.metadata,
+    )
+
+    user_part: _messages.UserPromptPart | None = None
+    if tool_return.content:
+        user_part = _messages.UserPromptPart(
+            content=tool_return.content,
+            part_kind='user-prompt',
+        )
 
-    return (tool_part, user_parts)
+    return return_part, user_part
 
 
 @dataclasses.dataclass
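Putting the pieces of this hunk together (the `ApprovalRequired`/`CallDeferred` exceptions, the `tool_call_approved` flag set in `build_run_context`, and the `DeferredToolRequests` output built from `calls`/`approvals`), an agent-level round trip would look roughly like the sketch below. This is a reconstruction from the diff, not documented usage: the `delete_file` tool, the model name, the import path for `DeferredToolRequests`, and the mechanism for feeding the results mapping into a follow-up run are all assumptions.

from pydantic_ai import Agent, RunContext
from pydantic_ai.exceptions import ApprovalRequired
from pydantic_ai.output import DeferredToolRequests
from pydantic_ai.tools import ToolApproved, ToolDenied

agent = Agent('openai:gpt-4o', output_type=[str, DeferredToolRequests])


@agent.tool
def delete_file(ctx: RunContext[None], path: str) -> str:
    # `tool_call_approved` is the RunContext flag set in `build_run_context` above for resumed runs.
    if not ctx.tool_call_approved:
        raise ApprovalRequired
    return f'{path} deleted'


result = agent.run_sync('Delete the file temp.txt')
if isinstance(result.output, DeferredToolRequests):
    # `approvals` holds the ToolCallParts whose tools raised ApprovalRequired;
    # `calls` holds calls deferred to external execution (CallDeferred / the 'external' kind).
    results = {
        call.tool_call_id: ToolApproved() if call.tool_name == 'delete_file' else ToolDenied(message='Not allowed.')
        for call in result.output.approvals
    }
    # On the next run this mapping ends up in GraphAgentDeps.tool_call_results, keyed by
    # tool_call_id; the public parameter for passing it back in is outside this hunk.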
pydantic_ai/_function_schema.py

@@ -5,10 +5,10 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
 
 from __future__ import annotations as _annotations
 
-from collections.abc import Awaitable
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
 from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any, Callable, Union, cast
+from typing import TYPE_CHECKING, Any, Concatenate, cast, get_origin
 
 from pydantic import ConfigDict
 from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -17,7 +17,7 @@ from pydantic.fields import FieldInfo
 from pydantic.json_schema import GenerateJsonSchema
 from pydantic.plugin._schema_validator import create_schema_validator
 from pydantic_core import SchemaValidator, core_schema
-from typing_extensions import Concatenate, ParamSpec, TypeIs, TypeVar, get_origin
+from typing_extensions import ParamSpec, TypeIs, TypeVar
 
 from ._griffe import doc_descriptions
 from ._run_context import RunContext
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
 __all__ = ('function_schema',)
 
 
-@dataclass
+@dataclass(kw_only=True)
 class FunctionSchema:
     """Internal information about a function schema."""
 
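For context, `kw_only=True` (applied here to `FunctionSchema` and above to `GraphAgentState` and `GraphAgentDeps`) makes every field keyword-only, so these internal dataclasses can no longer be constructed positionally. A generic illustration with a hypothetical class:

from dataclasses import dataclass


@dataclass(kw_only=True)
class Example:
    name: str
    retries: int = 0


Example(name='a', retries=1)  # OK
# Example('a', 1)  -> TypeError: __init__() takes 1 positional argument but 3 were given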
@@ -231,7 +231,7 @@ R = TypeVar('R')
 
 WithCtx = Callable[Concatenate[RunContext[Any], P], R]
 WithoutCtx = Callable[P, R]
-TargetFunc = Union[WithCtx[P, R], WithoutCtx[P, R]]
+TargetFunc = WithCtx[P, R] | WithoutCtx[P, R]
 
 
 def _takes_ctx(function: TargetFunc[P, R]) -> TypeIs[WithCtx[P, R]]: