pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +220 -319
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +295 -331
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +369 -156
- pydantic_ai/exceptions.py +12 -0
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +13 -5
- pydantic_ai/models/__init__.py +30 -18
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +1 -9
- pydantic_ai/models/google.py +2 -11
- pydantic_ai/models/groq.py +1 -0
- pydantic_ai/models/mistral.py +1 -1
- pydantic_ai/models/openai.py +3 -3
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/result.py +58 -45
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +10 -7
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/RECORD +45 -32
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_agent_graph.py
CHANGED
|
@@ -3,6 +3,7 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import dataclasses
|
|
5
5
|
import hashlib
|
|
6
|
+
from collections import defaultdict, deque
|
|
6
7
|
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
|
7
8
|
from contextlib import asynccontextmanager, contextmanager
|
|
8
9
|
from contextvars import ContextVar
|
|
@@ -13,17 +14,18 @@ from opentelemetry.trace import Tracer
|
|
|
13
14
|
from typing_extensions import TypeGuard, TypeVar, assert_never
|
|
14
15
|
|
|
15
16
|
from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore
|
|
17
|
+
from pydantic_ai._tool_manager import ToolManager
|
|
16
18
|
from pydantic_ai._utils import is_async_callable, run_in_executor
|
|
17
19
|
from pydantic_graph import BaseNode, Graph, GraphRunContext
|
|
18
20
|
from pydantic_graph.nodes import End, NodeRunEndT
|
|
19
21
|
|
|
20
22
|
from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
|
|
23
|
+
from .exceptions import ToolRetryError
|
|
21
24
|
from .output import OutputDataT, OutputSpec
|
|
22
25
|
from .settings import ModelSettings, merge_model_settings
|
|
23
|
-
from .tools import RunContext,
|
|
26
|
+
from .tools import RunContext, ToolDefinition, ToolKind
|
|
24
27
|
|
|
25
28
|
if TYPE_CHECKING:
|
|
26
|
-
from .mcp import MCPServer
|
|
27
29
|
from .models.instrumented import InstrumentationSettings
|
|
28
30
|
|
|
29
31
|
__all__ = (
|
|
@@ -77,11 +79,13 @@ class GraphAgentState:
|
|
|
77
79
|
retries: int
|
|
78
80
|
run_step: int
|
|
79
81
|
|
|
80
|
-
def increment_retries(self, max_result_retries: int, error:
|
|
82
|
+
def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
|
|
81
83
|
self.retries += 1
|
|
82
84
|
if self.retries > max_result_retries:
|
|
83
|
-
message = f'Exceeded maximum retries ({max_result_retries}) for
|
|
85
|
+
message = f'Exceeded maximum retries ({max_result_retries}) for output validation'
|
|
84
86
|
if error:
|
|
87
|
+
if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None:
|
|
88
|
+
error = error.__cause__
|
|
85
89
|
raise exceptions.UnexpectedModelBehavior(message) from error
|
|
86
90
|
else:
|
|
87
91
|
raise exceptions.UnexpectedModelBehavior(message)
|
|
@@ -108,15 +112,11 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
|
|
|
108
112
|
|
|
109
113
|
history_processors: Sequence[HistoryProcessor[DepsT]]
|
|
110
114
|
|
|
111
|
-
|
|
112
|
-
mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
|
|
113
|
-
default_retries: int
|
|
115
|
+
tool_manager: ToolManager[DepsT]
|
|
114
116
|
|
|
115
117
|
tracer: Tracer
|
|
116
118
|
instrumentation_settings: InstrumentationSettings | None = None
|
|
117
119
|
|
|
118
|
-
prepare_tools: ToolsPrepareFunc[DepsT] | None = None
|
|
119
|
-
|
|
120
120
|
|
|
121
121
|
class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
|
|
122
122
|
"""The base class for all agent nodes.
|
|
@@ -248,59 +248,27 @@ async def _prepare_request_parameters(
|
|
|
248
248
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
249
249
|
) -> models.ModelRequestParameters:
|
|
250
250
|
"""Build tools and create an agent model."""
|
|
251
|
-
function_tool_defs_map: dict[str, ToolDefinition] = {}
|
|
252
|
-
|
|
253
251
|
run_context = build_run_context(ctx)
|
|
254
|
-
|
|
255
|
-
async def add_tool(tool: Tool[DepsT]) -> None:
|
|
256
|
-
ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
|
|
257
|
-
if tool_def := await tool.prepare_tool_def(ctx):
|
|
258
|
-
# prepare_tool_def may change tool_def.name
|
|
259
|
-
if tool_def.name in function_tool_defs_map:
|
|
260
|
-
if tool_def.name != tool.name:
|
|
261
|
-
# Prepare tool def may have renamed the tool
|
|
262
|
-
raise exceptions.UserError(
|
|
263
|
-
f"Renaming tool '{tool.name}' to '{tool_def.name}' conflicts with existing tool."
|
|
264
|
-
)
|
|
265
|
-
else:
|
|
266
|
-
raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}.')
|
|
267
|
-
function_tool_defs_map[tool_def.name] = tool_def
|
|
268
|
-
|
|
269
|
-
async def add_mcp_server_tools(server: MCPServer) -> None:
|
|
270
|
-
if not server.is_running:
|
|
271
|
-
raise exceptions.UserError(f'MCP server is not running: {server}')
|
|
272
|
-
tool_defs = await server.list_tools()
|
|
273
|
-
for tool_def in tool_defs:
|
|
274
|
-
if tool_def.name in function_tool_defs_map:
|
|
275
|
-
raise exceptions.UserError(
|
|
276
|
-
f"MCP Server '{server}' defines a tool whose name conflicts with existing tool: {tool_def.name!r}. Consider using `tool_prefix` to avoid name conflicts."
|
|
277
|
-
)
|
|
278
|
-
function_tool_defs_map[tool_def.name] = tool_def
|
|
279
|
-
|
|
280
|
-
await asyncio.gather(
|
|
281
|
-
*map(add_tool, ctx.deps.function_tools.values()),
|
|
282
|
-
*map(add_mcp_server_tools, ctx.deps.mcp_servers),
|
|
283
|
-
)
|
|
284
|
-
function_tool_defs = list(function_tool_defs_map.values())
|
|
285
|
-
if ctx.deps.prepare_tools:
|
|
286
|
-
# Prepare the tools using the provided function
|
|
287
|
-
# This also acts over tool definitions pulled from MCP servers
|
|
288
|
-
function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or []
|
|
252
|
+
ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
|
|
289
253
|
|
|
290
254
|
output_schema = ctx.deps.output_schema
|
|
291
|
-
|
|
292
|
-
output_tools = []
|
|
293
255
|
output_object = None
|
|
294
|
-
if isinstance(output_schema, _output.
|
|
295
|
-
output_tools = output_schema.tool_defs()
|
|
296
|
-
elif isinstance(output_schema, _output.NativeOutputSchema):
|
|
256
|
+
if isinstance(output_schema, _output.NativeOutputSchema):
|
|
297
257
|
output_object = output_schema.object_def
|
|
298
258
|
|
|
299
259
|
# ToolOrTextOutputSchema, NativeOutputSchema, and PromptedOutputSchema all inherit from TextOutputSchema
|
|
300
260
|
allow_text_output = isinstance(output_schema, _output.TextOutputSchema)
|
|
301
261
|
|
|
262
|
+
function_tools: list[ToolDefinition] = []
|
|
263
|
+
output_tools: list[ToolDefinition] = []
|
|
264
|
+
for tool_def in ctx.deps.tool_manager.tool_defs:
|
|
265
|
+
if tool_def.kind == 'output':
|
|
266
|
+
output_tools.append(tool_def)
|
|
267
|
+
else:
|
|
268
|
+
function_tools.append(tool_def)
|
|
269
|
+
|
|
302
270
|
return models.ModelRequestParameters(
|
|
303
|
-
function_tools=
|
|
271
|
+
function_tools=function_tools,
|
|
304
272
|
output_mode=output_schema.mode,
|
|
305
273
|
output_tools=output_tools,
|
|
306
274
|
output_object=output_object,
|
|
@@ -341,8 +309,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
341
309
|
ctx.deps.output_schema,
|
|
342
310
|
ctx.deps.output_validators,
|
|
343
311
|
build_run_context(ctx),
|
|
344
|
-
_output.build_trace_context(ctx),
|
|
345
312
|
ctx.deps.usage_limits,
|
|
313
|
+
ctx.deps.tool_manager,
|
|
346
314
|
)
|
|
347
315
|
yield agent_stream
|
|
348
316
|
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
|
|
@@ -438,7 +406,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
438
406
|
_next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
|
|
439
407
|
default=None, repr=False
|
|
440
408
|
)
|
|
441
|
-
_tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)
|
|
442
409
|
|
|
443
410
|
async def run(
|
|
444
411
|
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
|
|
@@ -520,47 +487,30 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
520
487
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
521
488
|
tool_calls: list[_messages.ToolCallPart],
|
|
522
489
|
) -> AsyncIterator[_messages.HandleResponseEvent]:
|
|
523
|
-
output_schema = ctx.deps.output_schema
|
|
524
490
|
run_context = build_run_context(ctx)
|
|
525
491
|
|
|
526
|
-
|
|
527
|
-
|
|
492
|
+
output_parts: list[_messages.ModelRequestPart] = []
|
|
493
|
+
output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1)
|
|
528
494
|
|
|
529
|
-
# first, look for the output tool call
|
|
530
|
-
if isinstance(output_schema, _output.ToolOutputSchema):
|
|
531
|
-
for call, output_tool in output_schema.find_tool(tool_calls):
|
|
532
|
-
try:
|
|
533
|
-
trace_context = _output.build_trace_context(ctx)
|
|
534
|
-
result_data = await output_tool.process(call, run_context, trace_context)
|
|
535
|
-
result_data = await _validate_output(result_data, ctx, call)
|
|
536
|
-
except _output.ToolRetryError as e:
|
|
537
|
-
# TODO: Should only increment retry stuff once per node execution, not for each tool call
|
|
538
|
-
# Also, should increment the tool-specific retry count rather than the run retry count
|
|
539
|
-
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
|
|
540
|
-
parts.append(e.tool_retry)
|
|
541
|
-
else:
|
|
542
|
-
final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
|
|
543
|
-
break
|
|
544
|
-
|
|
545
|
-
# Then build the other request parts based on end strategy
|
|
546
|
-
tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
|
|
547
495
|
async for event in process_function_tools(
|
|
548
|
-
tool_calls,
|
|
549
|
-
final_result and final_result.tool_name,
|
|
550
|
-
final_result and final_result.tool_call_id,
|
|
551
|
-
ctx,
|
|
552
|
-
tool_responses,
|
|
496
|
+
ctx.deps.tool_manager, tool_calls, None, ctx, output_parts, output_final_result
|
|
553
497
|
):
|
|
554
498
|
yield event
|
|
555
499
|
|
|
556
|
-
if
|
|
557
|
-
|
|
500
|
+
if output_final_result:
|
|
501
|
+
final_result = output_final_result[0]
|
|
502
|
+
self._next_node = self._handle_final_result(ctx, final_result, output_parts)
|
|
503
|
+
elif deferred_tool_calls := ctx.deps.tool_manager.get_deferred_tool_calls(tool_calls):
|
|
504
|
+
if not ctx.deps.output_schema.allows_deferred_tool_calls:
|
|
505
|
+
raise exceptions.UserError(
|
|
506
|
+
'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
|
|
507
|
+
)
|
|
508
|
+
final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None)
|
|
509
|
+
self._next_node = self._handle_final_result(ctx, final_result, output_parts)
|
|
558
510
|
else:
|
|
559
|
-
if tool_responses:
|
|
560
|
-
parts.extend(tool_responses)
|
|
561
511
|
instructions = await ctx.deps.get_instructions(run_context)
|
|
562
512
|
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
|
|
563
|
-
_messages.ModelRequest(parts=
|
|
513
|
+
_messages.ModelRequest(parts=output_parts, instructions=instructions)
|
|
564
514
|
)
|
|
565
515
|
|
|
566
516
|
def _handle_final_result(
|
|
@@ -586,18 +536,18 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
586
536
|
|
|
587
537
|
text = '\n\n'.join(texts)
|
|
588
538
|
try:
|
|
539
|
+
run_context = build_run_context(ctx)
|
|
589
540
|
if isinstance(output_schema, _output.TextOutputSchema):
|
|
590
|
-
|
|
591
|
-
trace_context = _output.build_trace_context(ctx)
|
|
592
|
-
result_data = await output_schema.process(text, run_context, trace_context)
|
|
541
|
+
result_data = await output_schema.process(text, run_context)
|
|
593
542
|
else:
|
|
594
543
|
m = _messages.RetryPromptPart(
|
|
595
544
|
content='Plain text responses are not permitted, please include your response in a tool call',
|
|
596
545
|
)
|
|
597
|
-
raise
|
|
546
|
+
raise ToolRetryError(m)
|
|
598
547
|
|
|
599
|
-
|
|
600
|
-
|
|
548
|
+
for validator in ctx.deps.output_validators:
|
|
549
|
+
result_data = await validator.validate(result_data, run_context)
|
|
550
|
+
except ToolRetryError as e:
|
|
601
551
|
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
|
|
602
552
|
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
|
|
603
553
|
else:
|
|
@@ -612,6 +562,9 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
|
|
|
612
562
|
usage=ctx.state.usage,
|
|
613
563
|
prompt=ctx.deps.prompt,
|
|
614
564
|
messages=ctx.state.message_history,
|
|
565
|
+
tracer=ctx.deps.tracer,
|
|
566
|
+
trace_include_content=ctx.deps.instrumentation_settings is not None
|
|
567
|
+
and ctx.deps.instrumentation_settings.include_content,
|
|
615
568
|
run_step=ctx.state.run_step,
|
|
616
569
|
)
|
|
617
570
|
|
|
@@ -623,269 +576,210 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str:
|
|
|
623
576
|
return hashlib.sha1(identifier).hexdigest()[:6]
|
|
624
577
|
|
|
625
578
|
|
|
626
|
-
async def process_function_tools( # noqa C901
|
|
579
|
+
async def process_function_tools( # noqa: C901
|
|
580
|
+
tool_manager: ToolManager[DepsT],
|
|
627
581
|
tool_calls: list[_messages.ToolCallPart],
|
|
628
|
-
|
|
629
|
-
output_tool_call_id: str | None,
|
|
582
|
+
final_result: result.FinalResult[NodeRunEndT] | None,
|
|
630
583
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
631
584
|
output_parts: list[_messages.ModelRequestPart],
|
|
585
|
+
output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1),
|
|
632
586
|
) -> AsyncIterator[_messages.HandleResponseEvent]:
|
|
633
587
|
"""Process function (i.e., non-result) tool calls in parallel.
|
|
634
588
|
|
|
635
589
|
Also add stub return parts for any other tools that need it.
|
|
636
590
|
|
|
637
|
-
Because async iterators can't have return values, we use `output_parts` as
|
|
591
|
+
Because async iterators can't have return values, we use `output_parts` and `output_final_result` as output arguments.
|
|
638
592
|
"""
|
|
639
|
-
|
|
640
|
-
output_schema = ctx.deps.output_schema
|
|
641
|
-
|
|
642
|
-
# we rely on the fact that if we found a result, it's the first output tool in the last
|
|
643
|
-
found_used_output_tool = False
|
|
644
|
-
run_context = build_run_context(ctx)
|
|
645
|
-
|
|
646
|
-
calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
|
|
593
|
+
tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list)
|
|
647
594
|
for call in tool_calls:
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
595
|
+
tool_def = tool_manager.get_tool_def(call.tool_name)
|
|
596
|
+
kind = tool_def.kind if tool_def else 'unknown'
|
|
597
|
+
tool_calls_by_kind[kind].append(call)
|
|
598
|
+
|
|
599
|
+
# First, we handle output tool calls
|
|
600
|
+
for call in tool_calls_by_kind['output']:
|
|
601
|
+
if final_result:
|
|
602
|
+
if final_result.tool_call_id == call.tool_call_id:
|
|
603
|
+
part = _messages.ToolReturnPart(
|
|
656
604
|
tool_name=call.tool_name,
|
|
657
605
|
content='Final result processed.',
|
|
658
606
|
tool_call_id=call.tool_call_id,
|
|
659
607
|
)
|
|
660
|
-
)
|
|
661
|
-
elif tool := ctx.deps.function_tools.get(call.tool_name):
|
|
662
|
-
if stub_function_tools:
|
|
663
|
-
output_parts.append(
|
|
664
|
-
_messages.ToolReturnPart(
|
|
665
|
-
tool_name=call.tool_name,
|
|
666
|
-
content='Tool not executed - a final result was already processed.',
|
|
667
|
-
tool_call_id=call.tool_call_id,
|
|
668
|
-
)
|
|
669
|
-
)
|
|
670
608
|
else:
|
|
671
|
-
event = _messages.FunctionToolCallEvent(call)
|
|
672
|
-
yield event
|
|
673
|
-
calls_to_run.append((tool, call))
|
|
674
|
-
elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx):
|
|
675
|
-
if stub_function_tools:
|
|
676
|
-
# TODO(Marcelo): We should add coverage for this part of the code.
|
|
677
|
-
output_parts.append( # pragma: no cover
|
|
678
|
-
_messages.ToolReturnPart(
|
|
679
|
-
tool_name=call.tool_name,
|
|
680
|
-
content='Tool not executed - a final result was already processed.',
|
|
681
|
-
tool_call_id=call.tool_call_id,
|
|
682
|
-
)
|
|
683
|
-
)
|
|
684
|
-
else:
|
|
685
|
-
event = _messages.FunctionToolCallEvent(call)
|
|
686
|
-
yield event
|
|
687
|
-
calls_to_run.append((mcp_tool, call))
|
|
688
|
-
elif call.tool_name in output_schema.tools:
|
|
689
|
-
# if tool_name is in output_schema, it means we found a output tool but an error occurred in
|
|
690
|
-
# validation, we don't add another part here
|
|
691
|
-
if output_tool_name is not None:
|
|
692
609
|
yield _messages.FunctionToolCallEvent(call)
|
|
693
|
-
if found_used_output_tool:
|
|
694
|
-
content = 'Output tool not used - a final result was already processed.'
|
|
695
|
-
else:
|
|
696
|
-
# TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
|
|
697
|
-
content = 'Output tool not used - result failed validation.'
|
|
698
610
|
part = _messages.ToolReturnPart(
|
|
699
611
|
tool_name=call.tool_name,
|
|
700
|
-
content=
|
|
612
|
+
content='Output tool not used - a final result was already processed.',
|
|
701
613
|
tool_call_id=call.tool_call_id,
|
|
702
614
|
)
|
|
703
615
|
yield _messages.FunctionToolResultEvent(part)
|
|
704
|
-
output_parts.append(part)
|
|
705
|
-
else:
|
|
706
|
-
yield _messages.FunctionToolCallEvent(call)
|
|
707
616
|
|
|
708
|
-
part = _unknown_tool(call.tool_name, call.tool_call_id, ctx)
|
|
709
|
-
yield _messages.FunctionToolResultEvent(part)
|
|
710
617
|
output_parts.append(part)
|
|
618
|
+
else:
|
|
619
|
+
try:
|
|
620
|
+
result_data = await tool_manager.handle_call(call)
|
|
621
|
+
except exceptions.UnexpectedModelBehavior as e:
|
|
622
|
+
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
|
|
623
|
+
raise e # pragma: no cover
|
|
624
|
+
except ToolRetryError as e:
|
|
625
|
+
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
|
|
626
|
+
yield _messages.FunctionToolCallEvent(call)
|
|
627
|
+
output_parts.append(e.tool_retry)
|
|
628
|
+
yield _messages.FunctionToolResultEvent(e.tool_retry)
|
|
629
|
+
else:
|
|
630
|
+
part = _messages.ToolReturnPart(
|
|
631
|
+
tool_name=call.tool_name,
|
|
632
|
+
content='Final result processed.',
|
|
633
|
+
tool_call_id=call.tool_call_id,
|
|
634
|
+
)
|
|
635
|
+
output_parts.append(part)
|
|
636
|
+
final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
|
|
711
637
|
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
638
|
+
# Then, we handle function tool calls
|
|
639
|
+
calls_to_run: list[_messages.ToolCallPart] = []
|
|
640
|
+
if final_result and ctx.deps.end_strategy == 'early':
|
|
641
|
+
output_parts.extend(
|
|
642
|
+
[
|
|
643
|
+
_messages.ToolReturnPart(
|
|
644
|
+
tool_name=call.tool_name,
|
|
645
|
+
content='Tool not executed - a final result was already processed.',
|
|
646
|
+
tool_call_id=call.tool_call_id,
|
|
647
|
+
)
|
|
648
|
+
for call in tool_calls_by_kind['function']
|
|
649
|
+
]
|
|
650
|
+
)
|
|
651
|
+
else:
|
|
652
|
+
calls_to_run.extend(tool_calls_by_kind['function'])
|
|
716
653
|
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
654
|
+
# Then, we handle unknown tool calls
|
|
655
|
+
if tool_calls_by_kind['unknown']:
|
|
656
|
+
ctx.state.increment_retries(ctx.deps.max_result_retries)
|
|
657
|
+
calls_to_run.extend(tool_calls_by_kind['unknown'])
|
|
720
658
|
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
with ctx.deps.tracer.start_as_current_span(
|
|
724
|
-
'running tools',
|
|
725
|
-
attributes={
|
|
726
|
-
'tools': [call.tool_name for _, call in calls_to_run],
|
|
727
|
-
'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
|
|
728
|
-
},
|
|
729
|
-
):
|
|
730
|
-
tasks = [
|
|
731
|
-
asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer, include_content), name=call.tool_name)
|
|
732
|
-
for tool, call in calls_to_run
|
|
733
|
-
]
|
|
734
|
-
|
|
735
|
-
pending = tasks
|
|
736
|
-
while pending:
|
|
737
|
-
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
|
738
|
-
for task in done:
|
|
739
|
-
index = tasks.index(task)
|
|
740
|
-
result = task.result()
|
|
741
|
-
yield _messages.FunctionToolResultEvent(result)
|
|
742
|
-
|
|
743
|
-
if isinstance(result, _messages.RetryPromptPart):
|
|
744
|
-
results_by_index[index] = result
|
|
745
|
-
elif isinstance(result, _messages.ToolReturnPart):
|
|
746
|
-
if isinstance(result.content, _messages.ToolReturn):
|
|
747
|
-
tool_return = result.content
|
|
748
|
-
if (
|
|
749
|
-
isinstance(tool_return.return_value, _messages.MultiModalContentTypes)
|
|
750
|
-
or isinstance(tool_return.return_value, list)
|
|
751
|
-
and any(
|
|
752
|
-
isinstance(content, _messages.MultiModalContentTypes)
|
|
753
|
-
for content in tool_return.return_value # type: ignore
|
|
754
|
-
)
|
|
755
|
-
):
|
|
756
|
-
raise exceptions.UserError(
|
|
757
|
-
f"{result.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. "
|
|
758
|
-
f'Please use `content` instead.'
|
|
759
|
-
)
|
|
760
|
-
result.content = tool_return.return_value # type: ignore
|
|
761
|
-
result.metadata = tool_return.metadata
|
|
762
|
-
if tool_return.content:
|
|
763
|
-
user_parts.append(
|
|
764
|
-
_messages.UserPromptPart(
|
|
765
|
-
content=list(tool_return.content),
|
|
766
|
-
timestamp=result.timestamp,
|
|
767
|
-
part_kind='user-prompt',
|
|
768
|
-
)
|
|
769
|
-
)
|
|
770
|
-
contents: list[Any]
|
|
771
|
-
single_content: bool
|
|
772
|
-
if isinstance(result.content, list):
|
|
773
|
-
contents = result.content # type: ignore
|
|
774
|
-
single_content = False
|
|
775
|
-
else:
|
|
776
|
-
contents = [result.content]
|
|
777
|
-
single_content = True
|
|
778
|
-
|
|
779
|
-
processed_contents: list[Any] = []
|
|
780
|
-
for content in contents:
|
|
781
|
-
if isinstance(content, _messages.ToolReturn):
|
|
782
|
-
raise exceptions.UserError(
|
|
783
|
-
f"{result.tool_name}'s return contains invalid nested ToolReturn objects. "
|
|
784
|
-
f'ToolReturn should be used directly.'
|
|
785
|
-
)
|
|
786
|
-
elif isinstance(content, _messages.MultiModalContentTypes):
|
|
787
|
-
# Handle direct multimodal content
|
|
788
|
-
if isinstance(content, _messages.BinaryContent):
|
|
789
|
-
identifier = multi_modal_content_identifier(content.data)
|
|
790
|
-
else:
|
|
791
|
-
identifier = multi_modal_content_identifier(content.url)
|
|
792
|
-
|
|
793
|
-
user_parts.append(
|
|
794
|
-
_messages.UserPromptPart(
|
|
795
|
-
content=[f'This is file {identifier}:', content],
|
|
796
|
-
timestamp=result.timestamp,
|
|
797
|
-
part_kind='user-prompt',
|
|
798
|
-
)
|
|
799
|
-
)
|
|
800
|
-
processed_contents.append(f'See file {identifier}')
|
|
801
|
-
else:
|
|
802
|
-
# Handle regular content
|
|
803
|
-
processed_contents.append(content)
|
|
804
|
-
|
|
805
|
-
if single_content:
|
|
806
|
-
result.content = processed_contents[0]
|
|
807
|
-
else:
|
|
808
|
-
result.content = processed_contents
|
|
659
|
+
for call in calls_to_run:
|
|
660
|
+
yield _messages.FunctionToolCallEvent(call)
|
|
809
661
|
|
|
810
|
-
|
|
811
|
-
else:
|
|
812
|
-
assert_never(result)
|
|
662
|
+
user_parts: list[_messages.UserPromptPart] = []
|
|
813
663
|
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
664
|
+
if calls_to_run:
|
|
665
|
+
# Run all tool tasks in parallel
|
|
666
|
+
parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {}
|
|
667
|
+
with ctx.deps.tracer.start_as_current_span(
|
|
668
|
+
'running tools',
|
|
669
|
+
attributes={
|
|
670
|
+
'tools': [call.tool_name for call in calls_to_run],
|
|
671
|
+
'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
|
|
672
|
+
},
|
|
673
|
+
):
|
|
674
|
+
tasks = [
|
|
675
|
+
asyncio.create_task(_call_function_tool(tool_manager, call), name=call.tool_name)
|
|
676
|
+
for call in calls_to_run
|
|
677
|
+
]
|
|
678
|
+
|
|
679
|
+
pending = tasks
|
|
680
|
+
while pending:
|
|
681
|
+
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
|
682
|
+
for task in done:
|
|
683
|
+
index = tasks.index(task)
|
|
684
|
+
tool_result_part, extra_parts = task.result()
|
|
685
|
+
yield _messages.FunctionToolResultEvent(tool_result_part)
|
|
686
|
+
|
|
687
|
+
parts_by_index[index] = [tool_result_part, *extra_parts]
|
|
688
|
+
|
|
689
|
+
# We append the results at the end, rather than as they are received, to retain a consistent ordering
|
|
690
|
+
# This is mostly just to simplify testing
|
|
691
|
+
for k in sorted(parts_by_index):
|
|
692
|
+
output_parts.extend(parts_by_index[k])
|
|
693
|
+
|
|
694
|
+
# Finally, we handle deferred tool calls
|
|
695
|
+
for call in tool_calls_by_kind['deferred']:
|
|
696
|
+
if final_result:
|
|
697
|
+
output_parts.append(
|
|
698
|
+
_messages.ToolReturnPart(
|
|
699
|
+
tool_name=call.tool_name,
|
|
700
|
+
content='Tool not executed - a final result was already processed.',
|
|
701
|
+
tool_call_id=call.tool_call_id,
|
|
702
|
+
)
|
|
703
|
+
)
|
|
704
|
+
else:
|
|
705
|
+
yield _messages.FunctionToolCallEvent(call)
|
|
818
706
|
|
|
819
707
|
output_parts.extend(user_parts)
|
|
820
708
|
|
|
709
|
+
if final_result:
|
|
710
|
+
output_final_result.append(final_result)
|
|
821
711
|
|
|
822
|
-
async def _tool_from_mcp_server(
|
|
823
|
-
tool_name: str,
|
|
824
|
-
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
825
|
-
) -> Tool[DepsT] | None:
|
|
826
|
-
"""Call each MCP server to find the tool with the given name.
|
|
827
|
-
|
|
828
|
-
Args:
|
|
829
|
-
tool_name: The name of the tool to find.
|
|
830
|
-
ctx: The current run context.
|
|
831
|
-
|
|
832
|
-
Returns:
|
|
833
|
-
The tool with the given name, or `None` if no tool with the given name is found.
|
|
834
|
-
"""
|
|
835
|
-
|
|
836
|
-
async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
|
|
837
|
-
# There's no normal situation where the server will not be running at this point, we check just in case
|
|
838
|
-
# some weird edge case occurs.
|
|
839
|
-
if not server.is_running: # pragma: no cover
|
|
840
|
-
raise exceptions.UserError(f'MCP server is not running: {server}')
|
|
841
712
|
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
713
|
+
async def _call_function_tool(
|
|
714
|
+
tool_manager: ToolManager[DepsT],
|
|
715
|
+
tool_call: _messages.ToolCallPart,
|
|
716
|
+
) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]:
|
|
717
|
+
try:
|
|
718
|
+
tool_result = await tool_manager.handle_call(tool_call)
|
|
719
|
+
except ToolRetryError as e:
|
|
720
|
+
return (e.tool_retry, [])
|
|
721
|
+
|
|
722
|
+
part = _messages.ToolReturnPart(
|
|
723
|
+
tool_name=tool_call.tool_name,
|
|
724
|
+
content=tool_result,
|
|
725
|
+
tool_call_id=tool_call.tool_call_id,
|
|
726
|
+
)
|
|
727
|
+
extra_parts: list[_messages.ModelRequestPart] = []
|
|
854
728
|
|
|
729
|
+
if isinstance(tool_result, _messages.ToolReturn):
|
|
730
|
+
if (
|
|
731
|
+
isinstance(tool_result.return_value, _messages.MultiModalContentTypes)
|
|
732
|
+
or isinstance(tool_result.return_value, list)
|
|
733
|
+
and any(
|
|
734
|
+
isinstance(content, _messages.MultiModalContentTypes)
|
|
735
|
+
for content in tool_result.return_value # type: ignore
|
|
736
|
+
)
|
|
737
|
+
):
|
|
738
|
+
raise exceptions.UserError(
|
|
739
|
+
f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContentTypes` objects. '
|
|
740
|
+
f'Please use `content` instead.'
|
|
741
|
+
)
|
|
855
742
|
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
743
|
+
part.content = tool_result.return_value # type: ignore
|
|
744
|
+
part.metadata = tool_result.metadata
|
|
745
|
+
if tool_result.content:
|
|
746
|
+
extra_parts.append(
|
|
747
|
+
_messages.UserPromptPart(
|
|
748
|
+
content=list(tool_result.content),
|
|
749
|
+
part_kind='user-prompt',
|
|
750
|
+
)
|
|
751
|
+
)
|
|
752
|
+
else:
|
|
863
753
|
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
754
|
+
def process_content(content: Any) -> Any:
|
|
755
|
+
if isinstance(content, _messages.ToolReturn):
|
|
756
|
+
raise exceptions.UserError(
|
|
757
|
+
f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. '
|
|
758
|
+
f'`ToolReturn` should be used directly.'
|
|
759
|
+
)
|
|
760
|
+
elif isinstance(content, _messages.MultiModalContentTypes):
|
|
761
|
+
if isinstance(content, _messages.BinaryContent):
|
|
762
|
+
identifier = content.identifier or multi_modal_content_identifier(content.data)
|
|
763
|
+
else:
|
|
764
|
+
identifier = multi_modal_content_identifier(content.url)
|
|
867
765
|
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
766
|
+
extra_parts.append(
|
|
767
|
+
_messages.UserPromptPart(
|
|
768
|
+
content=[f'This is file {identifier}:', content],
|
|
769
|
+
part_kind='user-prompt',
|
|
770
|
+
)
|
|
771
|
+
)
|
|
772
|
+
return f'See file {identifier}'
|
|
872
773
|
|
|
873
|
-
|
|
874
|
-
tool_name=tool_name,
|
|
875
|
-
tool_call_id=tool_call_id,
|
|
876
|
-
content=f'Unknown tool name: {tool_name!r}. {msg}',
|
|
877
|
-
)
|
|
774
|
+
return content
|
|
878
775
|
|
|
776
|
+
if isinstance(tool_result, list):
|
|
777
|
+
contents = cast(list[Any], tool_result)
|
|
778
|
+
part.content = [process_content(content) for content in contents]
|
|
779
|
+
else:
|
|
780
|
+
part.content = process_content(tool_result)
|
|
879
781
|
|
|
880
|
-
|
|
881
|
-
result_data: T,
|
|
882
|
-
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
|
|
883
|
-
tool_call: _messages.ToolCallPart | None,
|
|
884
|
-
) -> T:
|
|
885
|
-
for validator in ctx.deps.output_validators:
|
|
886
|
-
run_context = build_run_context(ctx)
|
|
887
|
-
result_data = await validator.validate(result_data, tool_call, run_context)
|
|
888
|
-
return result_data
|
|
782
|
+
return (part, extra_parts)
|
|
889
783
|
|
|
890
784
|
|
|
891
785
|
@dataclasses.dataclass
|
|
@@ -921,14 +815,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
|
|
|
921
815
|
If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
|
|
922
816
|
`messages` will represent the messages exchanged during the first call only.
|
|
923
817
|
"""
|
|
818
|
+
token = None
|
|
819
|
+
messages: list[_messages.ModelMessage] = []
|
|
820
|
+
|
|
821
|
+
# Try to reuse existing message context if available
|
|
924
822
|
try:
|
|
925
|
-
|
|
823
|
+
messages = _messages_ctx_var.get().messages
|
|
926
824
|
except LookupError:
|
|
927
|
-
|
|
825
|
+
# No existing context, create a new one
|
|
928
826
|
token = _messages_ctx_var.set(_RunMessages(messages))
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
827
|
+
|
|
828
|
+
try:
|
|
829
|
+
yield messages
|
|
830
|
+
finally:
|
|
831
|
+
# Clean up context if we created it
|
|
832
|
+
if token is not None:
|
|
932
833
|
_messages_ctx_var.reset(token)
|
|
933
834
|
|
|
934
835
|
|