pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (45)
  1. pydantic_ai/_agent_graph.py +220 -319
  2. pydantic_ai/_cli.py +9 -7
  3. pydantic_ai/_output.py +295 -331
  4. pydantic_ai/_parts_manager.py +2 -2
  5. pydantic_ai/_run_context.py +8 -14
  6. pydantic_ai/_tool_manager.py +190 -0
  7. pydantic_ai/_utils.py +18 -1
  8. pydantic_ai/ag_ui.py +675 -0
  9. pydantic_ai/agent.py +369 -156
  10. pydantic_ai/exceptions.py +12 -0
  11. pydantic_ai/ext/aci.py +12 -3
  12. pydantic_ai/ext/langchain.py +9 -1
  13. pydantic_ai/mcp.py +147 -84
  14. pydantic_ai/messages.py +13 -5
  15. pydantic_ai/models/__init__.py +30 -18
  16. pydantic_ai/models/anthropic.py +1 -1
  17. pydantic_ai/models/function.py +50 -24
  18. pydantic_ai/models/gemini.py +1 -9
  19. pydantic_ai/models/google.py +2 -11
  20. pydantic_ai/models/groq.py +1 -0
  21. pydantic_ai/models/mistral.py +1 -1
  22. pydantic_ai/models/openai.py +3 -3
  23. pydantic_ai/output.py +21 -7
  24. pydantic_ai/profiles/google.py +1 -1
  25. pydantic_ai/profiles/moonshotai.py +8 -0
  26. pydantic_ai/providers/grok.py +13 -1
  27. pydantic_ai/providers/groq.py +2 -0
  28. pydantic_ai/result.py +58 -45
  29. pydantic_ai/tools.py +26 -119
  30. pydantic_ai/toolsets/__init__.py +22 -0
  31. pydantic_ai/toolsets/abstract.py +155 -0
  32. pydantic_ai/toolsets/combined.py +88 -0
  33. pydantic_ai/toolsets/deferred.py +38 -0
  34. pydantic_ai/toolsets/filtered.py +24 -0
  35. pydantic_ai/toolsets/function.py +238 -0
  36. pydantic_ai/toolsets/prefixed.py +37 -0
  37. pydantic_ai/toolsets/prepared.py +36 -0
  38. pydantic_ai/toolsets/renamed.py +42 -0
  39. pydantic_ai/toolsets/wrapper.py +37 -0
  40. pydantic_ai/usage.py +14 -8
  41. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +10 -7
  42. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/RECORD +45 -32
  43. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
  44. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
  45. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
@@ -3,6 +3,7 @@ from __future__ import annotations as _annotations
3
3
  import asyncio
4
4
  import dataclasses
5
5
  import hashlib
6
+ from collections import defaultdict, deque
6
7
  from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
7
8
  from contextlib import asynccontextmanager, contextmanager
8
9
  from contextvars import ContextVar
@@ -13,17 +14,18 @@ from opentelemetry.trace import Tracer
13
14
  from typing_extensions import TypeGuard, TypeVar, assert_never
14
15
 
15
16
  from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore
17
+ from pydantic_ai._tool_manager import ToolManager
16
18
  from pydantic_ai._utils import is_async_callable, run_in_executor
17
19
  from pydantic_graph import BaseNode, Graph, GraphRunContext
18
20
  from pydantic_graph.nodes import End, NodeRunEndT
19
21
 
20
22
  from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
23
+ from .exceptions import ToolRetryError
21
24
  from .output import OutputDataT, OutputSpec
22
25
  from .settings import ModelSettings, merge_model_settings
23
- from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc
26
+ from .tools import RunContext, ToolDefinition, ToolKind
24
27
 
25
28
  if TYPE_CHECKING:
26
- from .mcp import MCPServer
27
29
  from .models.instrumented import InstrumentationSettings
28
30
 
29
31
  __all__ = (
@@ -77,11 +79,13 @@ class GraphAgentState:
77
79
  retries: int
78
80
  run_step: int
79
81
 
80
- def increment_retries(self, max_result_retries: int, error: Exception | None = None) -> None:
82
+ def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
81
83
  self.retries += 1
82
84
  if self.retries > max_result_retries:
83
- message = f'Exceeded maximum retries ({max_result_retries}) for result validation'
85
+ message = f'Exceeded maximum retries ({max_result_retries}) for output validation'
84
86
  if error:
87
+ if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None:
88
+ error = error.__cause__
85
89
  raise exceptions.UnexpectedModelBehavior(message) from error
86
90
  else:
87
91
  raise exceptions.UnexpectedModelBehavior(message)
@@ -108,15 +112,11 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
108
112
 
109
113
  history_processors: Sequence[HistoryProcessor[DepsT]]
110
114
 
111
- function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
112
- mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
113
- default_retries: int
115
+ tool_manager: ToolManager[DepsT]
114
116
 
115
117
  tracer: Tracer
116
118
  instrumentation_settings: InstrumentationSettings | None = None
117
119
 
118
- prepare_tools: ToolsPrepareFunc[DepsT] | None = None
119
-
120
120
 
121
121
  class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
122
122
  """The base class for all agent nodes.
@@ -248,59 +248,27 @@ async def _prepare_request_parameters(
248
248
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
249
249
  ) -> models.ModelRequestParameters:
250
250
  """Build tools and create an agent model."""
251
- function_tool_defs_map: dict[str, ToolDefinition] = {}
252
-
253
251
  run_context = build_run_context(ctx)
254
-
255
- async def add_tool(tool: Tool[DepsT]) -> None:
256
- ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
257
- if tool_def := await tool.prepare_tool_def(ctx):
258
- # prepare_tool_def may change tool_def.name
259
- if tool_def.name in function_tool_defs_map:
260
- if tool_def.name != tool.name:
261
- # Prepare tool def may have renamed the tool
262
- raise exceptions.UserError(
263
- f"Renaming tool '{tool.name}' to '{tool_def.name}' conflicts with existing tool."
264
- )
265
- else:
266
- raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}.')
267
- function_tool_defs_map[tool_def.name] = tool_def
268
-
269
- async def add_mcp_server_tools(server: MCPServer) -> None:
270
- if not server.is_running:
271
- raise exceptions.UserError(f'MCP server is not running: {server}')
272
- tool_defs = await server.list_tools()
273
- for tool_def in tool_defs:
274
- if tool_def.name in function_tool_defs_map:
275
- raise exceptions.UserError(
276
- f"MCP Server '{server}' defines a tool whose name conflicts with existing tool: {tool_def.name!r}. Consider using `tool_prefix` to avoid name conflicts."
277
- )
278
- function_tool_defs_map[tool_def.name] = tool_def
279
-
280
- await asyncio.gather(
281
- *map(add_tool, ctx.deps.function_tools.values()),
282
- *map(add_mcp_server_tools, ctx.deps.mcp_servers),
283
- )
284
- function_tool_defs = list(function_tool_defs_map.values())
285
- if ctx.deps.prepare_tools:
286
- # Prepare the tools using the provided function
287
- # This also acts over tool definitions pulled from MCP servers
288
- function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or []
252
+ ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
289
253
 
290
254
  output_schema = ctx.deps.output_schema
291
-
292
- output_tools = []
293
255
  output_object = None
294
- if isinstance(output_schema, _output.ToolOutputSchema):
295
- output_tools = output_schema.tool_defs()
296
- elif isinstance(output_schema, _output.NativeOutputSchema):
256
+ if isinstance(output_schema, _output.NativeOutputSchema):
297
257
  output_object = output_schema.object_def
298
258
 
299
259
  # ToolOrTextOutputSchema, NativeOutputSchema, and PromptedOutputSchema all inherit from TextOutputSchema
300
260
  allow_text_output = isinstance(output_schema, _output.TextOutputSchema)
301
261
 
262
+ function_tools: list[ToolDefinition] = []
263
+ output_tools: list[ToolDefinition] = []
264
+ for tool_def in ctx.deps.tool_manager.tool_defs:
265
+ if tool_def.kind == 'output':
266
+ output_tools.append(tool_def)
267
+ else:
268
+ function_tools.append(tool_def)
269
+
302
270
  return models.ModelRequestParameters(
303
- function_tools=function_tool_defs,
271
+ function_tools=function_tools,
304
272
  output_mode=output_schema.mode,
305
273
  output_tools=output_tools,
306
274
  output_object=output_object,
@@ -341,8 +309,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
341
309
  ctx.deps.output_schema,
342
310
  ctx.deps.output_validators,
343
311
  build_run_context(ctx),
344
- _output.build_trace_context(ctx),
345
312
  ctx.deps.usage_limits,
313
+ ctx.deps.tool_manager,
346
314
  )
347
315
  yield agent_stream
348
316
  # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
@@ -438,7 +406,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
438
406
  _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
439
407
  default=None, repr=False
440
408
  )
441
- _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)
442
409
 
443
410
  async def run(
444
411
  self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -520,47 +487,30 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
520
487
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
521
488
  tool_calls: list[_messages.ToolCallPart],
522
489
  ) -> AsyncIterator[_messages.HandleResponseEvent]:
523
- output_schema = ctx.deps.output_schema
524
490
  run_context = build_run_context(ctx)
525
491
 
526
- final_result: result.FinalResult[NodeRunEndT] | None = None
527
- parts: list[_messages.ModelRequestPart] = []
492
+ output_parts: list[_messages.ModelRequestPart] = []
493
+ output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1)
528
494
 
529
- # first, look for the output tool call
530
- if isinstance(output_schema, _output.ToolOutputSchema):
531
- for call, output_tool in output_schema.find_tool(tool_calls):
532
- try:
533
- trace_context = _output.build_trace_context(ctx)
534
- result_data = await output_tool.process(call, run_context, trace_context)
535
- result_data = await _validate_output(result_data, ctx, call)
536
- except _output.ToolRetryError as e:
537
- # TODO: Should only increment retry stuff once per node execution, not for each tool call
538
- # Also, should increment the tool-specific retry count rather than the run retry count
539
- ctx.state.increment_retries(ctx.deps.max_result_retries, e)
540
- parts.append(e.tool_retry)
541
- else:
542
- final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
543
- break
544
-
545
- # Then build the other request parts based on end strategy
546
- tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
547
495
  async for event in process_function_tools(
548
- tool_calls,
549
- final_result and final_result.tool_name,
550
- final_result and final_result.tool_call_id,
551
- ctx,
552
- tool_responses,
496
+ ctx.deps.tool_manager, tool_calls, None, ctx, output_parts, output_final_result
553
497
  ):
554
498
  yield event
555
499
 
556
- if final_result:
557
- self._next_node = self._handle_final_result(ctx, final_result, tool_responses)
500
+ if output_final_result:
501
+ final_result = output_final_result[0]
502
+ self._next_node = self._handle_final_result(ctx, final_result, output_parts)
503
+ elif deferred_tool_calls := ctx.deps.tool_manager.get_deferred_tool_calls(tool_calls):
504
+ if not ctx.deps.output_schema.allows_deferred_tool_calls:
505
+ raise exceptions.UserError(
506
+ 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
507
+ )
508
+ final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None)
509
+ self._next_node = self._handle_final_result(ctx, final_result, output_parts)
558
510
  else:
559
- if tool_responses:
560
- parts.extend(tool_responses)
561
511
  instructions = await ctx.deps.get_instructions(run_context)
562
512
  self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
563
- _messages.ModelRequest(parts=parts, instructions=instructions)
513
+ _messages.ModelRequest(parts=output_parts, instructions=instructions)
564
514
  )
565
515
 
566
516
  def _handle_final_result(
@@ -586,18 +536,18 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
586
536
 
587
537
  text = '\n\n'.join(texts)
588
538
  try:
539
+ run_context = build_run_context(ctx)
589
540
  if isinstance(output_schema, _output.TextOutputSchema):
590
- run_context = build_run_context(ctx)
591
- trace_context = _output.build_trace_context(ctx)
592
- result_data = await output_schema.process(text, run_context, trace_context)
541
+ result_data = await output_schema.process(text, run_context)
593
542
  else:
594
543
  m = _messages.RetryPromptPart(
595
544
  content='Plain text responses are not permitted, please include your response in a tool call',
596
545
  )
597
- raise _output.ToolRetryError(m)
546
+ raise ToolRetryError(m)
598
547
 
599
- result_data = await _validate_output(result_data, ctx, None)
600
- except _output.ToolRetryError as e:
548
+ for validator in ctx.deps.output_validators:
549
+ result_data = await validator.validate(result_data, run_context)
550
+ except ToolRetryError as e:
601
551
  ctx.state.increment_retries(ctx.deps.max_result_retries, e)
602
552
  return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
603
553
  else:
@@ -612,6 +562,9 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
612
562
  usage=ctx.state.usage,
613
563
  prompt=ctx.deps.prompt,
614
564
  messages=ctx.state.message_history,
565
+ tracer=ctx.deps.tracer,
566
+ trace_include_content=ctx.deps.instrumentation_settings is not None
567
+ and ctx.deps.instrumentation_settings.include_content,
615
568
  run_step=ctx.state.run_step,
616
569
  )
617
570
 
@@ -623,269 +576,210 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str:
623
576
  return hashlib.sha1(identifier).hexdigest()[:6]
624
577
 
625
578
 
626
- async def process_function_tools( # noqa C901
579
+ async def process_function_tools( # noqa: C901
580
+ tool_manager: ToolManager[DepsT],
627
581
  tool_calls: list[_messages.ToolCallPart],
628
- output_tool_name: str | None,
629
- output_tool_call_id: str | None,
582
+ final_result: result.FinalResult[NodeRunEndT] | None,
630
583
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
631
584
  output_parts: list[_messages.ModelRequestPart],
585
+ output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1),
632
586
  ) -> AsyncIterator[_messages.HandleResponseEvent]:
633
587
  """Process function (i.e., non-result) tool calls in parallel.
634
588
 
635
589
  Also add stub return parts for any other tools that need it.
636
590
 
637
- Because async iterators can't have return values, we use `output_parts` as an output argument.
591
+ Because async iterators can't have return values, we use `output_parts` and `output_final_result` as output arguments.
638
592
  """
639
- stub_function_tools = bool(output_tool_name) and ctx.deps.end_strategy == 'early'
640
- output_schema = ctx.deps.output_schema
641
-
642
- # we rely on the fact that if we found a result, it's the first output tool in the last
643
- found_used_output_tool = False
644
- run_context = build_run_context(ctx)
645
-
646
- calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
593
+ tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list)
647
594
  for call in tool_calls:
648
- if (
649
- call.tool_name == output_tool_name
650
- and call.tool_call_id == output_tool_call_id
651
- and not found_used_output_tool
652
- ):
653
- found_used_output_tool = True
654
- output_parts.append(
655
- _messages.ToolReturnPart(
595
+ tool_def = tool_manager.get_tool_def(call.tool_name)
596
+ kind = tool_def.kind if tool_def else 'unknown'
597
+ tool_calls_by_kind[kind].append(call)
598
+
599
+ # First, we handle output tool calls
600
+ for call in tool_calls_by_kind['output']:
601
+ if final_result:
602
+ if final_result.tool_call_id == call.tool_call_id:
603
+ part = _messages.ToolReturnPart(
656
604
  tool_name=call.tool_name,
657
605
  content='Final result processed.',
658
606
  tool_call_id=call.tool_call_id,
659
607
  )
660
- )
661
- elif tool := ctx.deps.function_tools.get(call.tool_name):
662
- if stub_function_tools:
663
- output_parts.append(
664
- _messages.ToolReturnPart(
665
- tool_name=call.tool_name,
666
- content='Tool not executed - a final result was already processed.',
667
- tool_call_id=call.tool_call_id,
668
- )
669
- )
670
608
  else:
671
- event = _messages.FunctionToolCallEvent(call)
672
- yield event
673
- calls_to_run.append((tool, call))
674
- elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx):
675
- if stub_function_tools:
676
- # TODO(Marcelo): We should add coverage for this part of the code.
677
- output_parts.append( # pragma: no cover
678
- _messages.ToolReturnPart(
679
- tool_name=call.tool_name,
680
- content='Tool not executed - a final result was already processed.',
681
- tool_call_id=call.tool_call_id,
682
- )
683
- )
684
- else:
685
- event = _messages.FunctionToolCallEvent(call)
686
- yield event
687
- calls_to_run.append((mcp_tool, call))
688
- elif call.tool_name in output_schema.tools:
689
- # if tool_name is in output_schema, it means we found a output tool but an error occurred in
690
- # validation, we don't add another part here
691
- if output_tool_name is not None:
692
609
  yield _messages.FunctionToolCallEvent(call)
693
- if found_used_output_tool:
694
- content = 'Output tool not used - a final result was already processed.'
695
- else:
696
- # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
697
- content = 'Output tool not used - result failed validation.'
698
610
  part = _messages.ToolReturnPart(
699
611
  tool_name=call.tool_name,
700
- content=content,
612
+ content='Output tool not used - a final result was already processed.',
701
613
  tool_call_id=call.tool_call_id,
702
614
  )
703
615
  yield _messages.FunctionToolResultEvent(part)
704
- output_parts.append(part)
705
- else:
706
- yield _messages.FunctionToolCallEvent(call)
707
616
 
708
- part = _unknown_tool(call.tool_name, call.tool_call_id, ctx)
709
- yield _messages.FunctionToolResultEvent(part)
710
617
  output_parts.append(part)
618
+ else:
619
+ try:
620
+ result_data = await tool_manager.handle_call(call)
621
+ except exceptions.UnexpectedModelBehavior as e:
622
+ ctx.state.increment_retries(ctx.deps.max_result_retries, e)
623
+ raise e # pragma: no cover
624
+ except ToolRetryError as e:
625
+ ctx.state.increment_retries(ctx.deps.max_result_retries, e)
626
+ yield _messages.FunctionToolCallEvent(call)
627
+ output_parts.append(e.tool_retry)
628
+ yield _messages.FunctionToolResultEvent(e.tool_retry)
629
+ else:
630
+ part = _messages.ToolReturnPart(
631
+ tool_name=call.tool_name,
632
+ content='Final result processed.',
633
+ tool_call_id=call.tool_call_id,
634
+ )
635
+ output_parts.append(part)
636
+ final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
711
637
 
712
- if not calls_to_run:
713
- return
714
-
715
- user_parts: list[_messages.UserPromptPart] = []
638
+ # Then, we handle function tool calls
639
+ calls_to_run: list[_messages.ToolCallPart] = []
640
+ if final_result and ctx.deps.end_strategy == 'early':
641
+ output_parts.extend(
642
+ [
643
+ _messages.ToolReturnPart(
644
+ tool_name=call.tool_name,
645
+ content='Tool not executed - a final result was already processed.',
646
+ tool_call_id=call.tool_call_id,
647
+ )
648
+ for call in tool_calls_by_kind['function']
649
+ ]
650
+ )
651
+ else:
652
+ calls_to_run.extend(tool_calls_by_kind['function'])
716
653
 
717
- include_content = (
718
- ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content
719
- )
654
+ # Then, we handle unknown tool calls
655
+ if tool_calls_by_kind['unknown']:
656
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
657
+ calls_to_run.extend(tool_calls_by_kind['unknown'])
720
658
 
721
- # Run all tool tasks in parallel
722
- results_by_index: dict[int, _messages.ModelRequestPart] = {}
723
- with ctx.deps.tracer.start_as_current_span(
724
- 'running tools',
725
- attributes={
726
- 'tools': [call.tool_name for _, call in calls_to_run],
727
- 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
728
- },
729
- ):
730
- tasks = [
731
- asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer, include_content), name=call.tool_name)
732
- for tool, call in calls_to_run
733
- ]
734
-
735
- pending = tasks
736
- while pending:
737
- done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
738
- for task in done:
739
- index = tasks.index(task)
740
- result = task.result()
741
- yield _messages.FunctionToolResultEvent(result)
742
-
743
- if isinstance(result, _messages.RetryPromptPart):
744
- results_by_index[index] = result
745
- elif isinstance(result, _messages.ToolReturnPart):
746
- if isinstance(result.content, _messages.ToolReturn):
747
- tool_return = result.content
748
- if (
749
- isinstance(tool_return.return_value, _messages.MultiModalContentTypes)
750
- or isinstance(tool_return.return_value, list)
751
- and any(
752
- isinstance(content, _messages.MultiModalContentTypes)
753
- for content in tool_return.return_value # type: ignore
754
- )
755
- ):
756
- raise exceptions.UserError(
757
- f"{result.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. "
758
- f'Please use `content` instead.'
759
- )
760
- result.content = tool_return.return_value # type: ignore
761
- result.metadata = tool_return.metadata
762
- if tool_return.content:
763
- user_parts.append(
764
- _messages.UserPromptPart(
765
- content=list(tool_return.content),
766
- timestamp=result.timestamp,
767
- part_kind='user-prompt',
768
- )
769
- )
770
- contents: list[Any]
771
- single_content: bool
772
- if isinstance(result.content, list):
773
- contents = result.content # type: ignore
774
- single_content = False
775
- else:
776
- contents = [result.content]
777
- single_content = True
778
-
779
- processed_contents: list[Any] = []
780
- for content in contents:
781
- if isinstance(content, _messages.ToolReturn):
782
- raise exceptions.UserError(
783
- f"{result.tool_name}'s return contains invalid nested ToolReturn objects. "
784
- f'ToolReturn should be used directly.'
785
- )
786
- elif isinstance(content, _messages.MultiModalContentTypes):
787
- # Handle direct multimodal content
788
- if isinstance(content, _messages.BinaryContent):
789
- identifier = multi_modal_content_identifier(content.data)
790
- else:
791
- identifier = multi_modal_content_identifier(content.url)
792
-
793
- user_parts.append(
794
- _messages.UserPromptPart(
795
- content=[f'This is file {identifier}:', content],
796
- timestamp=result.timestamp,
797
- part_kind='user-prompt',
798
- )
799
- )
800
- processed_contents.append(f'See file {identifier}')
801
- else:
802
- # Handle regular content
803
- processed_contents.append(content)
804
-
805
- if single_content:
806
- result.content = processed_contents[0]
807
- else:
808
- result.content = processed_contents
659
+ for call in calls_to_run:
660
+ yield _messages.FunctionToolCallEvent(call)
809
661
 
810
- results_by_index[index] = result
811
- else:
812
- assert_never(result)
662
+ user_parts: list[_messages.UserPromptPart] = []
813
663
 
814
- # We append the results at the end, rather than as they are received, to retain a consistent ordering
815
- # This is mostly just to simplify testing
816
- for k in sorted(results_by_index):
817
- output_parts.append(results_by_index[k])
664
+ if calls_to_run:
665
+ # Run all tool tasks in parallel
666
+ parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {}
667
+ with ctx.deps.tracer.start_as_current_span(
668
+ 'running tools',
669
+ attributes={
670
+ 'tools': [call.tool_name for call in calls_to_run],
671
+ 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
672
+ },
673
+ ):
674
+ tasks = [
675
+ asyncio.create_task(_call_function_tool(tool_manager, call), name=call.tool_name)
676
+ for call in calls_to_run
677
+ ]
678
+
679
+ pending = tasks
680
+ while pending:
681
+ done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
682
+ for task in done:
683
+ index = tasks.index(task)
684
+ tool_result_part, extra_parts = task.result()
685
+ yield _messages.FunctionToolResultEvent(tool_result_part)
686
+
687
+ parts_by_index[index] = [tool_result_part, *extra_parts]
688
+
689
+ # We append the results at the end, rather than as they are received, to retain a consistent ordering
690
+ # This is mostly just to simplify testing
691
+ for k in sorted(parts_by_index):
692
+ output_parts.extend(parts_by_index[k])
693
+
694
+ # Finally, we handle deferred tool calls
695
+ for call in tool_calls_by_kind['deferred']:
696
+ if final_result:
697
+ output_parts.append(
698
+ _messages.ToolReturnPart(
699
+ tool_name=call.tool_name,
700
+ content='Tool not executed - a final result was already processed.',
701
+ tool_call_id=call.tool_call_id,
702
+ )
703
+ )
704
+ else:
705
+ yield _messages.FunctionToolCallEvent(call)
818
706
 
819
707
  output_parts.extend(user_parts)
820
708
 
709
+ if final_result:
710
+ output_final_result.append(final_result)
821
711
 
822
- async def _tool_from_mcp_server(
823
- tool_name: str,
824
- ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
825
- ) -> Tool[DepsT] | None:
826
- """Call each MCP server to find the tool with the given name.
827
-
828
- Args:
829
- tool_name: The name of the tool to find.
830
- ctx: The current run context.
831
-
832
- Returns:
833
- The tool with the given name, or `None` if no tool with the given name is found.
834
- """
835
-
836
- async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
837
- # There's no normal situation where the server will not be running at this point, we check just in case
838
- # some weird edge case occurs.
839
- if not server.is_running: # pragma: no cover
840
- raise exceptions.UserError(f'MCP server is not running: {server}')
841
712
 
842
- if server.process_tool_call is not None:
843
- result = await server.process_tool_call(ctx, server.call_tool, tool_name, args)
844
- else:
845
- result = await server.call_tool(tool_name, args)
846
-
847
- return result
848
-
849
- for server in ctx.deps.mcp_servers:
850
- tools = await server.list_tools()
851
- if tool_name in {tool.name for tool in tools}: # pragma: no branch
852
- return Tool(name=tool_name, function=run_tool, takes_ctx=True, max_retries=ctx.deps.default_retries)
853
- return None
713
+ async def _call_function_tool(
714
+ tool_manager: ToolManager[DepsT],
715
+ tool_call: _messages.ToolCallPart,
716
+ ) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]:
717
+ try:
718
+ tool_result = await tool_manager.handle_call(tool_call)
719
+ except ToolRetryError as e:
720
+ return (e.tool_retry, [])
721
+
722
+ part = _messages.ToolReturnPart(
723
+ tool_name=tool_call.tool_name,
724
+ content=tool_result,
725
+ tool_call_id=tool_call.tool_call_id,
726
+ )
727
+ extra_parts: list[_messages.ModelRequestPart] = []
854
728
 
729
+ if isinstance(tool_result, _messages.ToolReturn):
730
+ if (
731
+ isinstance(tool_result.return_value, _messages.MultiModalContentTypes)
732
+ or isinstance(tool_result.return_value, list)
733
+ and any(
734
+ isinstance(content, _messages.MultiModalContentTypes)
735
+ for content in tool_result.return_value # type: ignore
736
+ )
737
+ ):
738
+ raise exceptions.UserError(
739
+ f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContentTypes` objects. '
740
+ f'Please use `content` instead.'
741
+ )
855
742
 
856
- def _unknown_tool(
857
- tool_name: str,
858
- tool_call_id: str,
859
- ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
860
- ) -> _messages.RetryPromptPart:
861
- ctx.state.increment_retries(ctx.deps.max_result_retries)
862
- tool_names = list(ctx.deps.function_tools.keys())
743
+ part.content = tool_result.return_value # type: ignore
744
+ part.metadata = tool_result.metadata
745
+ if tool_result.content:
746
+ extra_parts.append(
747
+ _messages.UserPromptPart(
748
+ content=list(tool_result.content),
749
+ part_kind='user-prompt',
750
+ )
751
+ )
752
+ else:
863
753
 
864
- output_schema = ctx.deps.output_schema
865
- if isinstance(output_schema, _output.ToolOutputSchema):
866
- tool_names.extend(output_schema.tool_names())
754
+ def process_content(content: Any) -> Any:
755
+ if isinstance(content, _messages.ToolReturn):
756
+ raise exceptions.UserError(
757
+ f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. '
758
+ f'`ToolReturn` should be used directly.'
759
+ )
760
+ elif isinstance(content, _messages.MultiModalContentTypes):
761
+ if isinstance(content, _messages.BinaryContent):
762
+ identifier = content.identifier or multi_modal_content_identifier(content.data)
763
+ else:
764
+ identifier = multi_modal_content_identifier(content.url)
867
765
 
868
- if tool_names:
869
- msg = f'Available tools: {", ".join(tool_names)}'
870
- else:
871
- msg = 'No tools available.'
766
+ extra_parts.append(
767
+ _messages.UserPromptPart(
768
+ content=[f'This is file {identifier}:', content],
769
+ part_kind='user-prompt',
770
+ )
771
+ )
772
+ return f'See file {identifier}'
872
773
 
873
- return _messages.RetryPromptPart(
874
- tool_name=tool_name,
875
- tool_call_id=tool_call_id,
876
- content=f'Unknown tool name: {tool_name!r}. {msg}',
877
- )
774
+ return content
878
775
 
776
+ if isinstance(tool_result, list):
777
+ contents = cast(list[Any], tool_result)
778
+ part.content = [process_content(content) for content in contents]
779
+ else:
780
+ part.content = process_content(tool_result)
879
781
 
880
- async def _validate_output(
881
- result_data: T,
882
- ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
883
- tool_call: _messages.ToolCallPart | None,
884
- ) -> T:
885
- for validator in ctx.deps.output_validators:
886
- run_context = build_run_context(ctx)
887
- result_data = await validator.validate(result_data, tool_call, run_context)
888
- return result_data
782
+ return (part, extra_parts)
889
783
 
890
784
 
891
785
  @dataclasses.dataclass
@@ -921,14 +815,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
921
815
  If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
922
816
  `messages` will represent the messages exchanged during the first call only.
923
817
  """
818
+ token = None
819
+ messages: list[_messages.ModelMessage] = []
820
+
821
+ # Try to reuse existing message context if available
924
822
  try:
925
- yield _messages_ctx_var.get().messages
823
+ messages = _messages_ctx_var.get().messages
926
824
  except LookupError:
927
- messages: list[_messages.ModelMessage] = []
825
+ # No existing context, create a new one
928
826
  token = _messages_ctx_var.set(_RunMessages(messages))
929
- try:
930
- yield messages
931
- finally:
827
+
828
+ try:
829
+ yield messages
830
+ finally:
831
+ # Clean up context if we created it
832
+ if token is not None:
932
833
  _messages_ctx_var.reset(token)
933
834
 
934
835