pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff compares publicly available package versions as released to their public registry. It is provided for informational purposes only.

Files changed (55)
  1. pydantic_ai/_agent_graph.py +219 -315
  2. pydantic_ai/_cli.py +9 -7
  3. pydantic_ai/_output.py +296 -226
  4. pydantic_ai/_parts_manager.py +2 -2
  5. pydantic_ai/_run_context.py +8 -14
  6. pydantic_ai/_tool_manager.py +190 -0
  7. pydantic_ai/_utils.py +18 -1
  8. pydantic_ai/ag_ui.py +675 -0
  9. pydantic_ai/agent.py +369 -155
  10. pydantic_ai/common_tools/duckduckgo.py +5 -2
  11. pydantic_ai/exceptions.py +14 -2
  12. pydantic_ai/ext/aci.py +12 -3
  13. pydantic_ai/ext/langchain.py +9 -1
  14. pydantic_ai/mcp.py +147 -84
  15. pydantic_ai/messages.py +19 -9
  16. pydantic_ai/models/__init__.py +43 -19
  17. pydantic_ai/models/anthropic.py +2 -2
  18. pydantic_ai/models/bedrock.py +1 -1
  19. pydantic_ai/models/cohere.py +1 -1
  20. pydantic_ai/models/function.py +50 -24
  21. pydantic_ai/models/gemini.py +3 -11
  22. pydantic_ai/models/google.py +3 -12
  23. pydantic_ai/models/groq.py +2 -1
  24. pydantic_ai/models/huggingface.py +463 -0
  25. pydantic_ai/models/instrumented.py +1 -1
  26. pydantic_ai/models/mistral.py +3 -3
  27. pydantic_ai/models/openai.py +5 -5
  28. pydantic_ai/output.py +21 -7
  29. pydantic_ai/profiles/google.py +1 -1
  30. pydantic_ai/profiles/moonshotai.py +8 -0
  31. pydantic_ai/providers/__init__.py +4 -0
  32. pydantic_ai/providers/google.py +2 -2
  33. pydantic_ai/providers/google_vertex.py +10 -5
  34. pydantic_ai/providers/grok.py +13 -1
  35. pydantic_ai/providers/groq.py +2 -0
  36. pydantic_ai/providers/huggingface.py +88 -0
  37. pydantic_ai/result.py +57 -33
  38. pydantic_ai/tools.py +26 -119
  39. pydantic_ai/toolsets/__init__.py +22 -0
  40. pydantic_ai/toolsets/abstract.py +155 -0
  41. pydantic_ai/toolsets/combined.py +88 -0
  42. pydantic_ai/toolsets/deferred.py +38 -0
  43. pydantic_ai/toolsets/filtered.py +24 -0
  44. pydantic_ai/toolsets/function.py +238 -0
  45. pydantic_ai/toolsets/prefixed.py +37 -0
  46. pydantic_ai/toolsets/prepared.py +36 -0
  47. pydantic_ai/toolsets/renamed.py +42 -0
  48. pydantic_ai/toolsets/wrapper.py +37 -0
  49. pydantic_ai/usage.py +14 -8
  50. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
  51. pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
  52. pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
  53. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
  54. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
  55. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
@@ -3,6 +3,7 @@ from __future__ import annotations as _annotations
 import asyncio
 import dataclasses
 import hashlib
+from collections import defaultdict, deque
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
@@ -13,17 +14,18 @@ from opentelemetry.trace import Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
 
 from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx  # type: ignore
+from pydantic_ai._tool_manager import ToolManager
 from pydantic_ai._utils import is_async_callable, run_in_executor
 from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT
 
 from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
+from .exceptions import ToolRetryError
 from .output import OutputDataT, OutputSpec
 from .settings import ModelSettings, merge_model_settings
-from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc
+from .tools import RunContext, ToolDefinition, ToolKind
 
 if TYPE_CHECKING:
-    from .mcp import MCPServer
     from .models.instrumented import InstrumentationSettings
 
 __all__ = (
@@ -77,11 +79,13 @@ class GraphAgentState:
     retries: int
     run_step: int
 
-    def increment_retries(self, max_result_retries: int, error: Exception | None = None) -> None:
+    def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
         self.retries += 1
         if self.retries > max_result_retries:
-            message = f'Exceeded maximum retries ({max_result_retries}) for result validation'
+            message = f'Exceeded maximum retries ({max_result_retries}) for output validation'
             if error:
+                if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None:
+                    error = error.__cause__
                 raise exceptions.UnexpectedModelBehavior(message) from error
             else:
                 raise exceptions.UnexpectedModelBehavior(message)
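
A note on the `__cause__` unwrapping added here: chaining the final `UnexpectedModelBehavior` from the root cause keeps tracebacks pointed at the original validation failure instead of an intermediate wrapper. A standalone sketch of the pattern (the class below is a simplified stand-in, not pydantic-ai's own):

    class UnexpectedModelBehavior(RuntimeError):
        """Simplified stand-in for pydantic_ai.exceptions.UnexpectedModelBehavior."""

    def raise_retry_limit(max_retries: int, error: BaseException | None = None) -> None:
        message = f'Exceeded maximum retries ({max_retries}) for output validation'
        if error:
            # Unwrap wrapper exceptions so `raise ... from` points at the root cause.
            if isinstance(error, UnexpectedModelBehavior) and error.__cause__ is not None:
                error = error.__cause__
            raise UnexpectedModelBehavior(message) from error
        raise UnexpectedModelBehavior(message)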
@@ -108,15 +112,11 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
 
     history_processors: Sequence[HistoryProcessor[DepsT]]
 
-    function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
-    mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
-    default_retries: int
+    tool_manager: ToolManager[DepsT]
 
     tracer: Tracer
     instrumentation_settings: InstrumentationSettings | None = None
 
-    prepare_tools: ToolsPrepareFunc[DepsT] | None = None
-
 
 class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
     """The base class for all agent nodes.
@@ -248,59 +248,27 @@ async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
 ) -> models.ModelRequestParameters:
     """Build tools and create an agent model."""
-    function_tool_defs_map: dict[str, ToolDefinition] = {}
-
     run_context = build_run_context(ctx)
-
-    async def add_tool(tool: Tool[DepsT]) -> None:
-        ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
-        if tool_def := await tool.prepare_tool_def(ctx):
-            # prepare_tool_def may change tool_def.name
-            if tool_def.name in function_tool_defs_map:
-                if tool_def.name != tool.name:
-                    # Prepare tool def may have renamed the tool
-                    raise exceptions.UserError(
-                        f"Renaming tool '{tool.name}' to '{tool_def.name}' conflicts with existing tool."
-                    )
-                else:
-                    raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}.')
-            function_tool_defs_map[tool_def.name] = tool_def
-
-    async def add_mcp_server_tools(server: MCPServer) -> None:
-        if not server.is_running:
-            raise exceptions.UserError(f'MCP server is not running: {server}')
-        tool_defs = await server.list_tools()
-        for tool_def in tool_defs:
-            if tool_def.name in function_tool_defs_map:
-                raise exceptions.UserError(
-                    f"MCP Server '{server}' defines a tool whose name conflicts with existing tool: {tool_def.name!r}. Consider using `tool_prefix` to avoid name conflicts."
-                )
-            function_tool_defs_map[tool_def.name] = tool_def
-
-    await asyncio.gather(
-        *map(add_tool, ctx.deps.function_tools.values()),
-        *map(add_mcp_server_tools, ctx.deps.mcp_servers),
-    )
-    function_tool_defs = list(function_tool_defs_map.values())
-    if ctx.deps.prepare_tools:
-        # Prepare the tools using the provided function
-        # This also acts over tool definitions pulled from MCP servers
-        function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or []
+    ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
 
     output_schema = ctx.deps.output_schema
-
-    output_tools = []
     output_object = None
-    if isinstance(output_schema, _output.ToolOutputSchema):
-        output_tools = output_schema.tool_defs()
-    elif isinstance(output_schema, _output.NativeOutputSchema):
+    if isinstance(output_schema, _output.NativeOutputSchema):
        output_object = output_schema.object_def
 
     # ToolOrTextOutputSchema, NativeOutputSchema, and PromptedOutputSchema all inherit from TextOutputSchema
     allow_text_output = isinstance(output_schema, _output.TextOutputSchema)
 
+    function_tools: list[ToolDefinition] = []
+    output_tools: list[ToolDefinition] = []
+    for tool_def in ctx.deps.tool_manager.tool_defs:
+        if tool_def.kind == 'output':
+            output_tools.append(tool_def)
+        else:
+            function_tools.append(tool_def)
+
     return models.ModelRequestParameters(
-        function_tools=function_tool_defs,
+        function_tools=function_tools,
         output_mode=output_schema.mode,
         output_tools=output_tools,
         output_object=output_object,
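
With the `ToolManager` in place, building request parameters reduces to refreshing the manager for the run step and partitioning its tool definitions by `kind`. A minimal, self-contained sketch of that partitioning (`ToolDef` here is a simplified stand-in for pydantic-ai's `ToolDefinition`):

    from dataclasses import dataclass
    from typing import Literal

    @dataclass
    class ToolDef:
        name: str
        kind: Literal['function', 'output', 'deferred'] = 'function'

    tool_defs = [ToolDef('get_weather'), ToolDef('final_result', kind='output')]
    # Output tools go to `output_tools`; everything else stays a function tool.
    function_tools = [d for d in tool_defs if d.kind != 'output']
    output_tools = [d for d in tool_defs if d.kind == 'output']
    assert [d.name for d in function_tools] == ['get_weather']
    assert [d.name for d in output_tools] == ['final_result']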
@@ -342,6 +310,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
                 ctx.deps.output_validators,
                 build_run_context(ctx),
                 ctx.deps.usage_limits,
+                ctx.deps.tool_manager,
             )
             yield agent_stream
             # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
@@ -437,7 +406,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
     _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
         default=None, repr=False
     )
-    _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -519,46 +487,30 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         tool_calls: list[_messages.ToolCallPart],
     ) -> AsyncIterator[_messages.HandleResponseEvent]:
-        output_schema = ctx.deps.output_schema
         run_context = build_run_context(ctx)
 
-        final_result: result.FinalResult[NodeRunEndT] | None = None
-        parts: list[_messages.ModelRequestPart] = []
+        output_parts: list[_messages.ModelRequestPart] = []
+        output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1)
 
-        # first, look for the output tool call
-        if isinstance(output_schema, _output.ToolOutputSchema):
-            for call, output_tool in output_schema.find_tool(tool_calls):
-                try:
-                    result_data = await output_tool.process(call, run_context)
-                    result_data = await _validate_output(result_data, ctx, call)
-                except _output.ToolRetryError as e:
-                    # TODO: Should only increment retry stuff once per node execution, not for each tool call
-                    # Also, should increment the tool-specific retry count rather than the run retry count
-                    ctx.state.increment_retries(ctx.deps.max_result_retries, e)
-                    parts.append(e.tool_retry)
-                else:
-                    final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
-                    break
-
-        # Then build the other request parts based on end strategy
-        tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
         async for event in process_function_tools(
-            tool_calls,
-            final_result and final_result.tool_name,
-            final_result and final_result.tool_call_id,
-            ctx,
-            tool_responses,
+            ctx.deps.tool_manager, tool_calls, None, ctx, output_parts, output_final_result
         ):
             yield event
 
-        if final_result:
-            self._next_node = self._handle_final_result(ctx, final_result, tool_responses)
+        if output_final_result:
+            final_result = output_final_result[0]
+            self._next_node = self._handle_final_result(ctx, final_result, output_parts)
+        elif deferred_tool_calls := ctx.deps.tool_manager.get_deferred_tool_calls(tool_calls):
+            if not ctx.deps.output_schema.allows_deferred_tool_calls:
+                raise exceptions.UserError(
+                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
+                )
+            final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None)
+            self._next_node = self._handle_final_result(ctx, final_result, output_parts)
         else:
-            if tool_responses:
-                parts.extend(tool_responses)
             instructions = await ctx.deps.get_instructions(run_context)
             self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
-                _messages.ModelRequest(parts=parts, instructions=instructions)
+                _messages.ModelRequest(parts=output_parts, instructions=instructions)
             )
 
     def _handle_final_result(
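
The new `elif` branch above means a run can now end with deferred tool calls, but only when the agent's output types opt in. Roughly what that looks like from user code, as a hedged sketch against the 0.4.4 public API (the `tool_calls` attribute on the result is assumed, not taken from this diff):

    from pydantic_ai import Agent
    from pydantic_ai.output import DeferredToolCalls

    # Allow the run to end when the model calls a tool whose execution is deferred.
    agent = Agent('openai:gpt-4o', output_type=[str, DeferredToolCalls])

    result = agent.run_sync('Call the deferred tool, please')
    if isinstance(result.output, DeferredToolCalls):
        # Execute these calls externally, then feed the results into a follow-up run.
        print(result.output.tool_calls)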
@@ -584,17 +536,18 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
 
         text = '\n\n'.join(texts)
         try:
+            run_context = build_run_context(ctx)
             if isinstance(output_schema, _output.TextOutputSchema):
-                run_context = build_run_context(ctx)
                 result_data = await output_schema.process(text, run_context)
             else:
                 m = _messages.RetryPromptPart(
                     content='Plain text responses are not permitted, please include your response in a tool call',
                 )
-                raise _output.ToolRetryError(m)
+                raise ToolRetryError(m)
 
-            result_data = await _validate_output(result_data, ctx, None)
-        except _output.ToolRetryError as e:
+            for validator in ctx.deps.output_validators:
+                result_data = await validator.validate(result_data, run_context)
+        except ToolRetryError as e:
             ctx.state.increment_retries(ctx.deps.max_result_retries, e)
             return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
         else:
@@ -609,6 +562,9 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
         usage=ctx.state.usage,
         prompt=ctx.deps.prompt,
         messages=ctx.state.message_history,
+        tracer=ctx.deps.tracer,
+        trace_include_content=ctx.deps.instrumentation_settings is not None
+        and ctx.deps.instrumentation_settings.include_content,
         run_step=ctx.state.run_step,
     )
 
@@ -620,269 +576,210 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str:
     return hashlib.sha1(identifier).hexdigest()[:6]
 
 
-async def process_function_tools(  # noqa C901
+async def process_function_tools(  # noqa: C901
+    tool_manager: ToolManager[DepsT],
     tool_calls: list[_messages.ToolCallPart],
-    output_tool_name: str | None,
-    output_tool_call_id: str | None,
+    final_result: result.FinalResult[NodeRunEndT] | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
     output_parts: list[_messages.ModelRequestPart],
+    output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1),
 ) -> AsyncIterator[_messages.HandleResponseEvent]:
     """Process function (i.e., non-result) tool calls in parallel.
 
     Also add stub return parts for any other tools that need it.
 
-    Because async iterators can't have return values, we use `output_parts` as an output argument.
+    Because async iterators can't have return values, we use `output_parts` and `output_final_result` as output arguments.
     """
-    stub_function_tools = bool(output_tool_name) and ctx.deps.end_strategy == 'early'
-    output_schema = ctx.deps.output_schema
-
-    # we rely on the fact that if we found a result, it's the first output tool in the last
-    found_used_output_tool = False
-    run_context = build_run_context(ctx)
-
-    calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
+    tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list)
     for call in tool_calls:
-        if (
-            call.tool_name == output_tool_name
-            and call.tool_call_id == output_tool_call_id
-            and not found_used_output_tool
-        ):
-            found_used_output_tool = True
-            output_parts.append(
-                _messages.ToolReturnPart(
+        tool_def = tool_manager.get_tool_def(call.tool_name)
+        kind = tool_def.kind if tool_def else 'unknown'
+        tool_calls_by_kind[kind].append(call)
+
+    # First, we handle output tool calls
+    for call in tool_calls_by_kind['output']:
+        if final_result:
+            if final_result.tool_call_id == call.tool_call_id:
+                part = _messages.ToolReturnPart(
                     tool_name=call.tool_name,
                     content='Final result processed.',
                     tool_call_id=call.tool_call_id,
                 )
-            )
-        elif tool := ctx.deps.function_tools.get(call.tool_name):
-            if stub_function_tools:
-                output_parts.append(
-                    _messages.ToolReturnPart(
-                        tool_name=call.tool_name,
-                        content='Tool not executed - a final result was already processed.',
-                        tool_call_id=call.tool_call_id,
-                    )
-                )
             else:
-                event = _messages.FunctionToolCallEvent(call)
-                yield event
-                calls_to_run.append((tool, call))
-        elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx):
-            if stub_function_tools:
-                # TODO(Marcelo): We should add coverage for this part of the code.
-                output_parts.append(  # pragma: no cover
-                    _messages.ToolReturnPart(
-                        tool_name=call.tool_name,
-                        content='Tool not executed - a final result was already processed.',
-                        tool_call_id=call.tool_call_id,
-                    )
-                )
-            else:
-                event = _messages.FunctionToolCallEvent(call)
-                yield event
-                calls_to_run.append((mcp_tool, call))
-        elif call.tool_name in output_schema.tools:
-            # if tool_name is in output_schema, it means we found a output tool but an error occurred in
-            # validation, we don't add another part here
-            if output_tool_name is not None:
                 yield _messages.FunctionToolCallEvent(call)
-                if found_used_output_tool:
-                    content = 'Output tool not used - a final result was already processed.'
-                else:
-                    # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
-                    content = 'Output tool not used - result failed validation.'
                 part = _messages.ToolReturnPart(
                     tool_name=call.tool_name,
-                    content=content,
+                    content='Output tool not used - a final result was already processed.',
                     tool_call_id=call.tool_call_id,
                 )
                 yield _messages.FunctionToolResultEvent(part)
-                output_parts.append(part)
-        else:
-            yield _messages.FunctionToolCallEvent(call)
 
-            part = _unknown_tool(call.tool_name, call.tool_call_id, ctx)
-            yield _messages.FunctionToolResultEvent(part)
             output_parts.append(part)
+        else:
+            try:
+                result_data = await tool_manager.handle_call(call)
+            except exceptions.UnexpectedModelBehavior as e:
+                ctx.state.increment_retries(ctx.deps.max_result_retries, e)
+                raise e  # pragma: no cover
+            except ToolRetryError as e:
+                ctx.state.increment_retries(ctx.deps.max_result_retries, e)
+                yield _messages.FunctionToolCallEvent(call)
+                output_parts.append(e.tool_retry)
+                yield _messages.FunctionToolResultEvent(e.tool_retry)
+            else:
+                part = _messages.ToolReturnPart(
+                    tool_name=call.tool_name,
+                    content='Final result processed.',
+                    tool_call_id=call.tool_call_id,
+                )
+                output_parts.append(part)
+                final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
 
-    if not calls_to_run:
-        return
-
-    user_parts: list[_messages.UserPromptPart] = []
+    # Then, we handle function tool calls
+    calls_to_run: list[_messages.ToolCallPart] = []
+    if final_result and ctx.deps.end_strategy == 'early':
+        output_parts.extend(
+            [
+                _messages.ToolReturnPart(
+                    tool_name=call.tool_name,
+                    content='Tool not executed - a final result was already processed.',
+                    tool_call_id=call.tool_call_id,
+                )
+                for call in tool_calls_by_kind['function']
+            ]
+        )
+    else:
+        calls_to_run.extend(tool_calls_by_kind['function'])
 
-    include_content = (
-        ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content
-    )
+    # Then, we handle unknown tool calls
+    if tool_calls_by_kind['unknown']:
+        ctx.state.increment_retries(ctx.deps.max_result_retries)
+        calls_to_run.extend(tool_calls_by_kind['unknown'])
 
-    # Run all tool tasks in parallel
-    results_by_index: dict[int, _messages.ModelRequestPart] = {}
-    with ctx.deps.tracer.start_as_current_span(
-        'running tools',
-        attributes={
-            'tools': [call.tool_name for _, call in calls_to_run],
-            'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
-        },
-    ):
-        tasks = [
-            asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer, include_content), name=call.tool_name)
-            for tool, call in calls_to_run
-        ]
-
-        pending = tasks
-        while pending:
-            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-            for task in done:
-                index = tasks.index(task)
-                result = task.result()
-                yield _messages.FunctionToolResultEvent(result)
-
-                if isinstance(result, _messages.RetryPromptPart):
-                    results_by_index[index] = result
-                elif isinstance(result, _messages.ToolReturnPart):
-                    if isinstance(result.content, _messages.ToolReturn):
-                        tool_return = result.content
-                        if (
-                            isinstance(tool_return.return_value, _messages.MultiModalContentTypes)
-                            or isinstance(tool_return.return_value, list)
-                            and any(
-                                isinstance(content, _messages.MultiModalContentTypes)
-                                for content in tool_return.return_value  # type: ignore
-                            )
-                        ):
-                            raise exceptions.UserError(
-                                f"{result.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. "
-                                f'Please use `content` instead.'
-                            )
-                        result.content = tool_return.return_value  # type: ignore
-                        result.metadata = tool_return.metadata
-                        if tool_return.content:
-                            user_parts.append(
-                                _messages.UserPromptPart(
-                                    content=list(tool_return.content),
-                                    timestamp=result.timestamp,
-                                    part_kind='user-prompt',
-                                )
-                            )
-                    contents: list[Any]
-                    single_content: bool
-                    if isinstance(result.content, list):
-                        contents = result.content  # type: ignore
-                        single_content = False
-                    else:
-                        contents = [result.content]
-                        single_content = True
-
-                    processed_contents: list[Any] = []
-                    for content in contents:
-                        if isinstance(content, _messages.ToolReturn):
-                            raise exceptions.UserError(
-                                f"{result.tool_name}'s return contains invalid nested ToolReturn objects. "
-                                f'ToolReturn should be used directly.'
-                            )
-                        elif isinstance(content, _messages.MultiModalContentTypes):
-                            # Handle direct multimodal content
-                            if isinstance(content, _messages.BinaryContent):
-                                identifier = multi_modal_content_identifier(content.data)
-                            else:
-                                identifier = multi_modal_content_identifier(content.url)
-
-                            user_parts.append(
-                                _messages.UserPromptPart(
-                                    content=[f'This is file {identifier}:', content],
-                                    timestamp=result.timestamp,
-                                    part_kind='user-prompt',
-                                )
-                            )
-                            processed_contents.append(f'See file {identifier}')
-                        else:
-                            # Handle regular content
-                            processed_contents.append(content)
-
-                    if single_content:
-                        result.content = processed_contents[0]
-                    else:
-                        result.content = processed_contents
+    for call in calls_to_run:
+        yield _messages.FunctionToolCallEvent(call)
 
-                    results_by_index[index] = result
-                else:
-                    assert_never(result)
+    user_parts: list[_messages.UserPromptPart] = []
 
-    # We append the results at the end, rather than as they are received, to retain a consistent ordering
-    # This is mostly just to simplify testing
-    for k in sorted(results_by_index):
-        output_parts.append(results_by_index[k])
+    if calls_to_run:
+        # Run all tool tasks in parallel
+        parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {}
+        with ctx.deps.tracer.start_as_current_span(
+            'running tools',
+            attributes={
+                'tools': [call.tool_name for call in calls_to_run],
+                'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
+            },
+        ):
+            tasks = [
+                asyncio.create_task(_call_function_tool(tool_manager, call), name=call.tool_name)
+                for call in calls_to_run
+            ]
+
+            pending = tasks
+            while pending:
+                done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+                for task in done:
+                    index = tasks.index(task)
+                    tool_result_part, extra_parts = task.result()
+                    yield _messages.FunctionToolResultEvent(tool_result_part)
+
+                    parts_by_index[index] = [tool_result_part, *extra_parts]
+
+        # We append the results at the end, rather than as they are received, to retain a consistent ordering
+        # This is mostly just to simplify testing
+        for k in sorted(parts_by_index):
+            output_parts.extend(parts_by_index[k])
+
+    # Finally, we handle deferred tool calls
+    for call in tool_calls_by_kind['deferred']:
+        if final_result:
+            output_parts.append(
+                _messages.ToolReturnPart(
+                    tool_name=call.tool_name,
+                    content='Tool not executed - a final result was already processed.',
+                    tool_call_id=call.tool_call_id,
+                )
+            )
+        else:
+            yield _messages.FunctionToolCallEvent(call)
 
     output_parts.extend(user_parts)
 
+    if final_result:
+        output_final_result.append(final_result)
 
-async def _tool_from_mcp_server(
-    tool_name: str,
-    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> Tool[DepsT] | None:
-    """Call each MCP server to find the tool with the given name.
-
-    Args:
-        tool_name: The name of the tool to find.
-        ctx: The current run context.
-
-    Returns:
-        The tool with the given name, or `None` if no tool with the given name is found.
-    """
-
-    async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
-        # There's no normal situation where the server will not be running at this point, we check just in case
-        # some weird edge case occurs.
-        if not server.is_running:  # pragma: no cover
-            raise exceptions.UserError(f'MCP server is not running: {server}')
 
-        if server.process_tool_call is not None:
-            result = await server.process_tool_call(ctx, server.call_tool, tool_name, args)
-        else:
-            result = await server.call_tool(tool_name, args)
-
-        return result
-
-    for server in ctx.deps.mcp_servers:
-        tools = await server.list_tools()
-        if tool_name in {tool.name for tool in tools}:  # pragma: no branch
-            return Tool(name=tool_name, function=run_tool, takes_ctx=True, max_retries=ctx.deps.default_retries)
-    return None
+async def _call_function_tool(
+    tool_manager: ToolManager[DepsT],
+    tool_call: _messages.ToolCallPart,
+) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]:
+    try:
+        tool_result = await tool_manager.handle_call(tool_call)
+    except ToolRetryError as e:
+        return (e.tool_retry, [])
+
+    part = _messages.ToolReturnPart(
+        tool_name=tool_call.tool_name,
+        content=tool_result,
+        tool_call_id=tool_call.tool_call_id,
+    )
+    extra_parts: list[_messages.ModelRequestPart] = []
 
+    if isinstance(tool_result, _messages.ToolReturn):
+        if (
+            isinstance(tool_result.return_value, _messages.MultiModalContentTypes)
+            or isinstance(tool_result.return_value, list)
+            and any(
+                isinstance(content, _messages.MultiModalContentTypes)
+                for content in tool_result.return_value  # type: ignore
+            )
+        ):
+            raise exceptions.UserError(
+                f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContentTypes` objects. '
+                f'Please use `content` instead.'
+            )
 
-def _unknown_tool(
-    tool_name: str,
-    tool_call_id: str,
-    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> _messages.RetryPromptPart:
-    ctx.state.increment_retries(ctx.deps.max_result_retries)
-    tool_names = list(ctx.deps.function_tools.keys())
+        part.content = tool_result.return_value  # type: ignore
+        part.metadata = tool_result.metadata
+        if tool_result.content:
+            extra_parts.append(
+                _messages.UserPromptPart(
+                    content=list(tool_result.content),
+                    part_kind='user-prompt',
+                )
+            )
+    else:
 
-    output_schema = ctx.deps.output_schema
-    if isinstance(output_schema, _output.ToolOutputSchema):
-        tool_names.extend(output_schema.tool_names())
+        def process_content(content: Any) -> Any:
+            if isinstance(content, _messages.ToolReturn):
+                raise exceptions.UserError(
+                    f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. '
+                    f'`ToolReturn` should be used directly.'
+                )
+            elif isinstance(content, _messages.MultiModalContentTypes):
+                if isinstance(content, _messages.BinaryContent):
+                    identifier = content.identifier or multi_modal_content_identifier(content.data)
+                else:
+                    identifier = multi_modal_content_identifier(content.url)
 
-    if tool_names:
-        msg = f'Available tools: {", ".join(tool_names)}'
-    else:
-        msg = 'No tools available.'
+                extra_parts.append(
+                    _messages.UserPromptPart(
+                        content=[f'This is file {identifier}:', content],
+                        part_kind='user-prompt',
+                    )
+                )
+                return f'See file {identifier}'
 
-    return _messages.RetryPromptPart(
-        tool_name=tool_name,
-        tool_call_id=tool_call_id,
-        content=f'Unknown tool name: {tool_name!r}. {msg}',
-    )
+            return content
 
+        if isinstance(tool_result, list):
+            contents = cast(list[Any], tool_result)
+            part.content = [process_content(content) for content in contents]
+        else:
+            part.content = process_content(tool_result)
 
-async def _validate_output(
-    result_data: T,
-    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
-    tool_call: _messages.ToolCallPart | None,
-) -> T:
-    for validator in ctx.deps.output_validators:
-        run_context = build_run_context(ctx)
-        result_data = await validator.validate(result_data, tool_call, run_context)
-    return result_data
+    return (part, extra_parts)
 
 
 @dataclasses.dataclass
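
One pattern worth calling out in the rewrite above: async generators cannot `return` a value, so `process_function_tools` hands its final result back through the mutable `output_final_result` argument, a `deque(maxlen=1)` that stays empty unless a result was produced. A self-contained sketch of the technique:

    import asyncio
    from collections import deque

    async def gen(out: deque[int]):
        # Yield events as work progresses...
        yield 'step 1'
        yield 'step 2'
        # ...and stash the "return value" in the output argument before finishing.
        out.append(42)

    async def main():
        out: deque[int] = deque(maxlen=1)
        async for event in gen(out):
            print(event)
        if out:  # an empty deque means no final result was set
            print('result:', out[0])

    asyncio.run(main())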
@@ -918,14 +815,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
     If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
     `messages` will represent the messages exchanged during the first call only.
     """
+    token = None
+    messages: list[_messages.ModelMessage] = []
+
+    # Try to reuse existing message context if available
     try:
-        yield _messages_ctx_var.get().messages
+        messages = _messages_ctx_var.get().messages
     except LookupError:
-        messages: list[_messages.ModelMessage] = []
+        # No existing context, create a new one
         token = _messages_ctx_var.set(_RunMessages(messages))
-        try:
-            yield messages
-        finally:
+
+    try:
+        yield messages
+    finally:
+        # Clean up context if we created it
+        if token is not None:
             _messages_ctx_var.reset(token)
 
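
The net effect of this last change is that `capture_run_messages` is now reentrant: a nested `with` block reuses the outer block's list rather than shadowing it, and only the outermost block resets the context variable. A standalone model of the pattern:

    from contextlib import contextmanager
    from contextvars import ContextVar

    _ctx: ContextVar[list[str]] = ContextVar('messages')

    @contextmanager
    def capture():
        token = None
        try:
            messages = _ctx.get()  # reuse an existing context if one is active
        except LookupError:
            messages = []
            token = _ctx.set(messages)
        try:
            yield messages
        finally:
            if token is not None:  # only the block that created the context cleans up
                _ctx.reset(token)

    with capture() as outer:
        with capture() as inner:
            assert outer is inner  # nested blocks share one list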