pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +219 -315
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +296 -226
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +369 -155
- pydantic_ai/common_tools/duckduckgo.py +5 -2
- pydantic_ai/exceptions.py +14 -2
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +19 -9
- pydantic_ai/models/__init__.py +43 -19
- pydantic_ai/models/anthropic.py +2 -2
- pydantic_ai/models/bedrock.py +1 -1
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +3 -11
- pydantic_ai/models/google.py +3 -12
- pydantic_ai/models/groq.py +2 -1
- pydantic_ai/models/huggingface.py +463 -0
- pydantic_ai/models/instrumented.py +1 -1
- pydantic_ai/models/mistral.py +3 -3
- pydantic_ai/models/openai.py +5 -5
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google.py +2 -2
- pydantic_ai/providers/google_vertex.py +10 -5
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/providers/huggingface.py +88 -0
- pydantic_ai/result.py +57 -33
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
- pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
- pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_agent_graph.py
CHANGED
|
@@ -3,6 +3,7 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import dataclasses
|
|
5
5
|
import hashlib
|
|
6
|
+
from collections import defaultdict, deque
|
|
6
7
|
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
|
7
8
|
from contextlib import asynccontextmanager, contextmanager
|
|
8
9
|
from contextvars import ContextVar
|
|
@@ -13,17 +14,18 @@ from opentelemetry.trace import Tracer
|
|
|
13
14
|
from typing_extensions import TypeGuard, TypeVar, assert_never
|
|
14
15
|
|
|
15
16
|
from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore
|
|
17
|
+
from pydantic_ai._tool_manager import ToolManager
|
|
16
18
|
from pydantic_ai._utils import is_async_callable, run_in_executor
|
|
17
19
|
from pydantic_graph import BaseNode, Graph, GraphRunContext
|
|
18
20
|
from pydantic_graph.nodes import End, NodeRunEndT
|
|
19
21
|
|
|
20
22
|
from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
|
|
23
|
+
from .exceptions import ToolRetryError
|
|
21
24
|
from .output import OutputDataT, OutputSpec
|
|
22
25
|
from .settings import ModelSettings, merge_model_settings
|
|
23
|
-
from .tools import RunContext,
|
|
26
|
+
from .tools import RunContext, ToolDefinition, ToolKind
|
|
24
27
|
|
|
25
28
|
if TYPE_CHECKING:
|
|
26
|
-
from .mcp import MCPServer
|
|
27
29
|
from .models.instrumented import InstrumentationSettings
|
|
28
30
|
|
|
29
31
|
__all__ = (
|
|
@@ -77,11 +79,13 @@ class GraphAgentState:
|
|
|
77
79
|
retries: int
|
|
78
80
|
run_step: int
|
|
79
81
|
|
|
80
|
-
def increment_retries(self, max_result_retries: int, error:
|
|
82
|
+
def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
|
|
81
83
|
self.retries += 1
|
|
82
84
|
if self.retries > max_result_retries:
|
|
83
|
-
message = f'Exceeded maximum retries ({max_result_retries}) for
|
|
85
|
+
message = f'Exceeded maximum retries ({max_result_retries}) for output validation'
|
|
84
86
|
if error:
|
|
87
|
+
if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None:
|
|
88
|
+
error = error.__cause__
|
|
85
89
|
raise exceptions.UnexpectedModelBehavior(message) from error
|
|
86
90
|
else:
|
|
87
91
|
raise exceptions.UnexpectedModelBehavior(message)
|
|
@@ -108,15 +112,11 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
|
|
|
108
112
|
|
|
109
113
|
history_processors: Sequence[HistoryProcessor[DepsT]]
|
|
110
114
|
|
|
111
|
-
|
|
112
|
-
mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
|
|
113
|
-
default_retries: int
|
|
115
|
+
tool_manager: ToolManager[DepsT]
|
|
114
116
|
|
|
115
117
|
tracer: Tracer
|
|
116
118
|
instrumentation_settings: InstrumentationSettings | None = None
|
|
117
119
|
|
|
118
|
-
prepare_tools: ToolsPrepareFunc[DepsT] | None = None
|
|
119
|
-
|
|
120
120
|
|
|
121
121
|
class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
|
|
122
122
|
"""The base class for all agent nodes.
|
|
@@ -248,59 +248,27 @@ async def _prepare_request_parameters(
|
|
|
248
248
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
249
249
|
) -> models.ModelRequestParameters:
|
|
250
250
|
"""Build tools and create an agent model."""
|
|
251
|
-
function_tool_defs_map: dict[str, ToolDefinition] = {}
|
|
252
|
-
|
|
253
251
|
run_context = build_run_context(ctx)
|
|
254
|
-
|
|
255
|
-
async def add_tool(tool: Tool[DepsT]) -> None:
|
|
256
|
-
ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
|
|
257
|
-
if tool_def := await tool.prepare_tool_def(ctx):
|
|
258
|
-
# prepare_tool_def may change tool_def.name
|
|
259
|
-
if tool_def.name in function_tool_defs_map:
|
|
260
|
-
if tool_def.name != tool.name:
|
|
261
|
-
# Prepare tool def may have renamed the tool
|
|
262
|
-
raise exceptions.UserError(
|
|
263
|
-
f"Renaming tool '{tool.name}' to '{tool_def.name}' conflicts with existing tool."
|
|
264
|
-
)
|
|
265
|
-
else:
|
|
266
|
-
raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}.')
|
|
267
|
-
function_tool_defs_map[tool_def.name] = tool_def
|
|
268
|
-
|
|
269
|
-
async def add_mcp_server_tools(server: MCPServer) -> None:
|
|
270
|
-
if not server.is_running:
|
|
271
|
-
raise exceptions.UserError(f'MCP server is not running: {server}')
|
|
272
|
-
tool_defs = await server.list_tools()
|
|
273
|
-
for tool_def in tool_defs:
|
|
274
|
-
if tool_def.name in function_tool_defs_map:
|
|
275
|
-
raise exceptions.UserError(
|
|
276
|
-
f"MCP Server '{server}' defines a tool whose name conflicts with existing tool: {tool_def.name!r}. Consider using `tool_prefix` to avoid name conflicts."
|
|
277
|
-
)
|
|
278
|
-
function_tool_defs_map[tool_def.name] = tool_def
|
|
279
|
-
|
|
280
|
-
await asyncio.gather(
|
|
281
|
-
*map(add_tool, ctx.deps.function_tools.values()),
|
|
282
|
-
*map(add_mcp_server_tools, ctx.deps.mcp_servers),
|
|
283
|
-
)
|
|
284
|
-
function_tool_defs = list(function_tool_defs_map.values())
|
|
285
|
-
if ctx.deps.prepare_tools:
|
|
286
|
-
# Prepare the tools using the provided function
|
|
287
|
-
# This also acts over tool definitions pulled from MCP servers
|
|
288
|
-
function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or []
|
|
252
|
+
ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
|
|
289
253
|
|
|
290
254
|
output_schema = ctx.deps.output_schema
|
|
291
|
-
|
|
292
|
-
output_tools = []
|
|
293
255
|
output_object = None
|
|
294
|
-
if isinstance(output_schema, _output.
|
|
295
|
-
output_tools = output_schema.tool_defs()
|
|
296
|
-
elif isinstance(output_schema, _output.NativeOutputSchema):
|
|
256
|
+
if isinstance(output_schema, _output.NativeOutputSchema):
|
|
297
257
|
output_object = output_schema.object_def
|
|
298
258
|
|
|
299
259
|
# ToolOrTextOutputSchema, NativeOutputSchema, and PromptedOutputSchema all inherit from TextOutputSchema
|
|
300
260
|
allow_text_output = isinstance(output_schema, _output.TextOutputSchema)
|
|
301
261
|
|
|
262
|
+
function_tools: list[ToolDefinition] = []
|
|
263
|
+
output_tools: list[ToolDefinition] = []
|
|
264
|
+
for tool_def in ctx.deps.tool_manager.tool_defs:
|
|
265
|
+
if tool_def.kind == 'output':
|
|
266
|
+
output_tools.append(tool_def)
|
|
267
|
+
else:
|
|
268
|
+
function_tools.append(tool_def)
|
|
269
|
+
|
|
302
270
|
return models.ModelRequestParameters(
|
|
303
|
-
function_tools=
|
|
271
|
+
function_tools=function_tools,
|
|
304
272
|
output_mode=output_schema.mode,
|
|
305
273
|
output_tools=output_tools,
|
|
306
274
|
output_object=output_object,
|
|
@@ -342,6 +310,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
342
310
|
ctx.deps.output_validators,
|
|
343
311
|
build_run_context(ctx),
|
|
344
312
|
ctx.deps.usage_limits,
|
|
313
|
+
ctx.deps.tool_manager,
|
|
345
314
|
)
|
|
346
315
|
yield agent_stream
|
|
347
316
|
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
|
|
@@ -437,7 +406,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
437
406
|
_next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
|
|
438
407
|
default=None, repr=False
|
|
439
408
|
)
|
|
440
|
-
_tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)
|
|
441
409
|
|
|
442
410
|
async def run(
|
|
443
411
|
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
|
|
@@ -519,46 +487,30 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
519
487
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
520
488
|
tool_calls: list[_messages.ToolCallPart],
|
|
521
489
|
) -> AsyncIterator[_messages.HandleResponseEvent]:
|
|
522
|
-
output_schema = ctx.deps.output_schema
|
|
523
490
|
run_context = build_run_context(ctx)
|
|
524
491
|
|
|
525
|
-
|
|
526
|
-
|
|
492
|
+
output_parts: list[_messages.ModelRequestPart] = []
|
|
493
|
+
output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1)
|
|
527
494
|
|
|
528
|
-
# first, look for the output tool call
|
|
529
|
-
if isinstance(output_schema, _output.ToolOutputSchema):
|
|
530
|
-
for call, output_tool in output_schema.find_tool(tool_calls):
|
|
531
|
-
try:
|
|
532
|
-
result_data = await output_tool.process(call, run_context)
|
|
533
|
-
result_data = await _validate_output(result_data, ctx, call)
|
|
534
|
-
except _output.ToolRetryError as e:
|
|
535
|
-
# TODO: Should only increment retry stuff once per node execution, not for each tool call
|
|
536
|
-
# Also, should increment the tool-specific retry count rather than the run retry count
|
|
537
|
-
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
|
|
538
|
-
parts.append(e.tool_retry)
|
|
539
|
-
else:
|
|
540
|
-
final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
|
|
541
|
-
break
|
|
542
|
-
|
|
543
|
-
# Then build the other request parts based on end strategy
|
|
544
|
-
tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
|
|
545
495
|
async for event in process_function_tools(
|
|
546
|
-
tool_calls,
|
|
547
|
-
final_result and final_result.tool_name,
|
|
548
|
-
final_result and final_result.tool_call_id,
|
|
549
|
-
ctx,
|
|
550
|
-
tool_responses,
|
|
496
|
+
ctx.deps.tool_manager, tool_calls, None, ctx, output_parts, output_final_result
|
|
551
497
|
):
|
|
552
498
|
yield event
|
|
553
499
|
|
|
554
|
-
if
|
|
555
|
-
|
|
500
|
+
if output_final_result:
|
|
501
|
+
final_result = output_final_result[0]
|
|
502
|
+
self._next_node = self._handle_final_result(ctx, final_result, output_parts)
|
|
503
|
+
elif deferred_tool_calls := ctx.deps.tool_manager.get_deferred_tool_calls(tool_calls):
|
|
504
|
+
if not ctx.deps.output_schema.allows_deferred_tool_calls:
|
|
505
|
+
raise exceptions.UserError(
|
|
506
|
+
'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
|
|
507
|
+
)
|
|
508
|
+
final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None)
|
|
509
|
+
self._next_node = self._handle_final_result(ctx, final_result, output_parts)
|
|
556
510
|
else:
|
|
557
|
-
if tool_responses:
|
|
558
|
-
parts.extend(tool_responses)
|
|
559
511
|
instructions = await ctx.deps.get_instructions(run_context)
|
|
560
512
|
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
|
|
561
|
-
_messages.ModelRequest(parts=
|
|
513
|
+
_messages.ModelRequest(parts=output_parts, instructions=instructions)
|
|
562
514
|
)
|
|
563
515
|
|
|
564
516
|
def _handle_final_result(
|
|
@@ -584,17 +536,18 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
584
536
|
|
|
585
537
|
text = '\n\n'.join(texts)
|
|
586
538
|
try:
|
|
539
|
+
run_context = build_run_context(ctx)
|
|
587
540
|
if isinstance(output_schema, _output.TextOutputSchema):
|
|
588
|
-
run_context = build_run_context(ctx)
|
|
589
541
|
result_data = await output_schema.process(text, run_context)
|
|
590
542
|
else:
|
|
591
543
|
m = _messages.RetryPromptPart(
|
|
592
544
|
content='Plain text responses are not permitted, please include your response in a tool call',
|
|
593
545
|
)
|
|
594
|
-
raise
|
|
546
|
+
raise ToolRetryError(m)
|
|
595
547
|
|
|
596
|
-
|
|
597
|
-
|
|
548
|
+
for validator in ctx.deps.output_validators:
|
|
549
|
+
result_data = await validator.validate(result_data, run_context)
|
|
550
|
+
except ToolRetryError as e:
|
|
598
551
|
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
|
|
599
552
|
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
|
|
600
553
|
else:
|
|
@@ -609,6 +562,9 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
|
|
|
609
562
|
usage=ctx.state.usage,
|
|
610
563
|
prompt=ctx.deps.prompt,
|
|
611
564
|
messages=ctx.state.message_history,
|
|
565
|
+
tracer=ctx.deps.tracer,
|
|
566
|
+
trace_include_content=ctx.deps.instrumentation_settings is not None
|
|
567
|
+
and ctx.deps.instrumentation_settings.include_content,
|
|
612
568
|
run_step=ctx.state.run_step,
|
|
613
569
|
)
|
|
614
570
|
|
|
@@ -620,269 +576,210 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str:
|
|
|
620
576
|
return hashlib.sha1(identifier).hexdigest()[:6]
|
|
621
577
|
|
|
622
578
|
|
|
623
|
-
async def process_function_tools( # noqa C901
|
|
579
|
+
async def process_function_tools( # noqa: C901
|
|
580
|
+
tool_manager: ToolManager[DepsT],
|
|
624
581
|
tool_calls: list[_messages.ToolCallPart],
|
|
625
|
-
|
|
626
|
-
output_tool_call_id: str | None,
|
|
582
|
+
final_result: result.FinalResult[NodeRunEndT] | None,
|
|
627
583
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
628
584
|
output_parts: list[_messages.ModelRequestPart],
|
|
585
|
+
output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1),
|
|
629
586
|
) -> AsyncIterator[_messages.HandleResponseEvent]:
|
|
630
587
|
"""Process function (i.e., non-result) tool calls in parallel.
|
|
631
588
|
|
|
632
589
|
Also add stub return parts for any other tools that need it.
|
|
633
590
|
|
|
634
|
-
Because async iterators can't have return values, we use `output_parts` as
|
|
591
|
+
Because async iterators can't have return values, we use `output_parts` and `output_final_result` as output arguments.
|
|
635
592
|
"""
|
|
636
|
-
|
|
637
|
-
output_schema = ctx.deps.output_schema
|
|
638
|
-
|
|
639
|
-
# we rely on the fact that if we found a result, it's the first output tool in the last
|
|
640
|
-
found_used_output_tool = False
|
|
641
|
-
run_context = build_run_context(ctx)
|
|
642
|
-
|
|
643
|
-
calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
|
|
593
|
+
tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list)
|
|
644
594
|
for call in tool_calls:
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
595
|
+
tool_def = tool_manager.get_tool_def(call.tool_name)
|
|
596
|
+
kind = tool_def.kind if tool_def else 'unknown'
|
|
597
|
+
tool_calls_by_kind[kind].append(call)
|
|
598
|
+
|
|
599
|
+
# First, we handle output tool calls
|
|
600
|
+
for call in tool_calls_by_kind['output']:
|
|
601
|
+
if final_result:
|
|
602
|
+
if final_result.tool_call_id == call.tool_call_id:
|
|
603
|
+
part = _messages.ToolReturnPart(
|
|
653
604
|
tool_name=call.tool_name,
|
|
654
605
|
content='Final result processed.',
|
|
655
606
|
tool_call_id=call.tool_call_id,
|
|
656
607
|
)
|
|
657
|
-
)
|
|
658
|
-
elif tool := ctx.deps.function_tools.get(call.tool_name):
|
|
659
|
-
if stub_function_tools:
|
|
660
|
-
output_parts.append(
|
|
661
|
-
_messages.ToolReturnPart(
|
|
662
|
-
tool_name=call.tool_name,
|
|
663
|
-
content='Tool not executed - a final result was already processed.',
|
|
664
|
-
tool_call_id=call.tool_call_id,
|
|
665
|
-
)
|
|
666
|
-
)
|
|
667
608
|
else:
|
|
668
|
-
event = _messages.FunctionToolCallEvent(call)
|
|
669
|
-
yield event
|
|
670
|
-
calls_to_run.append((tool, call))
|
|
671
|
-
elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx):
|
|
672
|
-
if stub_function_tools:
|
|
673
|
-
# TODO(Marcelo): We should add coverage for this part of the code.
|
|
674
|
-
output_parts.append( # pragma: no cover
|
|
675
|
-
_messages.ToolReturnPart(
|
|
676
|
-
tool_name=call.tool_name,
|
|
677
|
-
content='Tool not executed - a final result was already processed.',
|
|
678
|
-
tool_call_id=call.tool_call_id,
|
|
679
|
-
)
|
|
680
|
-
)
|
|
681
|
-
else:
|
|
682
|
-
event = _messages.FunctionToolCallEvent(call)
|
|
683
|
-
yield event
|
|
684
|
-
calls_to_run.append((mcp_tool, call))
|
|
685
|
-
elif call.tool_name in output_schema.tools:
|
|
686
|
-
# if tool_name is in output_schema, it means we found a output tool but an error occurred in
|
|
687
|
-
# validation, we don't add another part here
|
|
688
|
-
if output_tool_name is not None:
|
|
689
609
|
yield _messages.FunctionToolCallEvent(call)
|
|
690
|
-
if found_used_output_tool:
|
|
691
|
-
content = 'Output tool not used - a final result was already processed.'
|
|
692
|
-
else:
|
|
693
|
-
# TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
|
|
694
|
-
content = 'Output tool not used - result failed validation.'
|
|
695
610
|
part = _messages.ToolReturnPart(
|
|
696
611
|
tool_name=call.tool_name,
|
|
697
|
-
content=
|
|
612
|
+
content='Output tool not used - a final result was already processed.',
|
|
698
613
|
tool_call_id=call.tool_call_id,
|
|
699
614
|
)
|
|
700
615
|
yield _messages.FunctionToolResultEvent(part)
|
|
701
|
-
output_parts.append(part)
|
|
702
|
-
else:
|
|
703
|
-
yield _messages.FunctionToolCallEvent(call)
|
|
704
616
|
|
|
705
|
-
part = _unknown_tool(call.tool_name, call.tool_call_id, ctx)
|
|
706
|
-
yield _messages.FunctionToolResultEvent(part)
|
|
707
617
|
output_parts.append(part)
|
|
618
|
+
else:
|
|
619
|
+
try:
|
|
620
|
+
result_data = await tool_manager.handle_call(call)
|
|
621
|
+
except exceptions.UnexpectedModelBehavior as e:
|
|
622
|
+
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
|
|
623
|
+
raise e # pragma: no cover
|
|
624
|
+
except ToolRetryError as e:
|
|
625
|
+
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
|
|
626
|
+
yield _messages.FunctionToolCallEvent(call)
|
|
627
|
+
output_parts.append(e.tool_retry)
|
|
628
|
+
yield _messages.FunctionToolResultEvent(e.tool_retry)
|
|
629
|
+
else:
|
|
630
|
+
part = _messages.ToolReturnPart(
|
|
631
|
+
tool_name=call.tool_name,
|
|
632
|
+
content='Final result processed.',
|
|
633
|
+
tool_call_id=call.tool_call_id,
|
|
634
|
+
)
|
|
635
|
+
output_parts.append(part)
|
|
636
|
+
final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
|
|
708
637
|
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
638
|
+
# Then, we handle function tool calls
|
|
639
|
+
calls_to_run: list[_messages.ToolCallPart] = []
|
|
640
|
+
if final_result and ctx.deps.end_strategy == 'early':
|
|
641
|
+
output_parts.extend(
|
|
642
|
+
[
|
|
643
|
+
_messages.ToolReturnPart(
|
|
644
|
+
tool_name=call.tool_name,
|
|
645
|
+
content='Tool not executed - a final result was already processed.',
|
|
646
|
+
tool_call_id=call.tool_call_id,
|
|
647
|
+
)
|
|
648
|
+
for call in tool_calls_by_kind['function']
|
|
649
|
+
]
|
|
650
|
+
)
|
|
651
|
+
else:
|
|
652
|
+
calls_to_run.extend(tool_calls_by_kind['function'])
|
|
713
653
|
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
654
|
+
# Then, we handle unknown tool calls
|
|
655
|
+
if tool_calls_by_kind['unknown']:
|
|
656
|
+
ctx.state.increment_retries(ctx.deps.max_result_retries)
|
|
657
|
+
calls_to_run.extend(tool_calls_by_kind['unknown'])
|
|
717
658
|
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
with ctx.deps.tracer.start_as_current_span(
|
|
721
|
-
'running tools',
|
|
722
|
-
attributes={
|
|
723
|
-
'tools': [call.tool_name for _, call in calls_to_run],
|
|
724
|
-
'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
|
|
725
|
-
},
|
|
726
|
-
):
|
|
727
|
-
tasks = [
|
|
728
|
-
asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer, include_content), name=call.tool_name)
|
|
729
|
-
for tool, call in calls_to_run
|
|
730
|
-
]
|
|
731
|
-
|
|
732
|
-
pending = tasks
|
|
733
|
-
while pending:
|
|
734
|
-
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
|
735
|
-
for task in done:
|
|
736
|
-
index = tasks.index(task)
|
|
737
|
-
result = task.result()
|
|
738
|
-
yield _messages.FunctionToolResultEvent(result)
|
|
739
|
-
|
|
740
|
-
if isinstance(result, _messages.RetryPromptPart):
|
|
741
|
-
results_by_index[index] = result
|
|
742
|
-
elif isinstance(result, _messages.ToolReturnPart):
|
|
743
|
-
if isinstance(result.content, _messages.ToolReturn):
|
|
744
|
-
tool_return = result.content
|
|
745
|
-
if (
|
|
746
|
-
isinstance(tool_return.return_value, _messages.MultiModalContentTypes)
|
|
747
|
-
or isinstance(tool_return.return_value, list)
|
|
748
|
-
and any(
|
|
749
|
-
isinstance(content, _messages.MultiModalContentTypes)
|
|
750
|
-
for content in tool_return.return_value # type: ignore
|
|
751
|
-
)
|
|
752
|
-
):
|
|
753
|
-
raise exceptions.UserError(
|
|
754
|
-
f"{result.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. "
|
|
755
|
-
f'Please use `content` instead.'
|
|
756
|
-
)
|
|
757
|
-
result.content = tool_return.return_value # type: ignore
|
|
758
|
-
result.metadata = tool_return.metadata
|
|
759
|
-
if tool_return.content:
|
|
760
|
-
user_parts.append(
|
|
761
|
-
_messages.UserPromptPart(
|
|
762
|
-
content=list(tool_return.content),
|
|
763
|
-
timestamp=result.timestamp,
|
|
764
|
-
part_kind='user-prompt',
|
|
765
|
-
)
|
|
766
|
-
)
|
|
767
|
-
contents: list[Any]
|
|
768
|
-
single_content: bool
|
|
769
|
-
if isinstance(result.content, list):
|
|
770
|
-
contents = result.content # type: ignore
|
|
771
|
-
single_content = False
|
|
772
|
-
else:
|
|
773
|
-
contents = [result.content]
|
|
774
|
-
single_content = True
|
|
775
|
-
|
|
776
|
-
processed_contents: list[Any] = []
|
|
777
|
-
for content in contents:
|
|
778
|
-
if isinstance(content, _messages.ToolReturn):
|
|
779
|
-
raise exceptions.UserError(
|
|
780
|
-
f"{result.tool_name}'s return contains invalid nested ToolReturn objects. "
|
|
781
|
-
f'ToolReturn should be used directly.'
|
|
782
|
-
)
|
|
783
|
-
elif isinstance(content, _messages.MultiModalContentTypes):
|
|
784
|
-
# Handle direct multimodal content
|
|
785
|
-
if isinstance(content, _messages.BinaryContent):
|
|
786
|
-
identifier = multi_modal_content_identifier(content.data)
|
|
787
|
-
else:
|
|
788
|
-
identifier = multi_modal_content_identifier(content.url)
|
|
789
|
-
|
|
790
|
-
user_parts.append(
|
|
791
|
-
_messages.UserPromptPart(
|
|
792
|
-
content=[f'This is file {identifier}:', content],
|
|
793
|
-
timestamp=result.timestamp,
|
|
794
|
-
part_kind='user-prompt',
|
|
795
|
-
)
|
|
796
|
-
)
|
|
797
|
-
processed_contents.append(f'See file {identifier}')
|
|
798
|
-
else:
|
|
799
|
-
# Handle regular content
|
|
800
|
-
processed_contents.append(content)
|
|
801
|
-
|
|
802
|
-
if single_content:
|
|
803
|
-
result.content = processed_contents[0]
|
|
804
|
-
else:
|
|
805
|
-
result.content = processed_contents
|
|
659
|
+
for call in calls_to_run:
|
|
660
|
+
yield _messages.FunctionToolCallEvent(call)
|
|
806
661
|
|
|
807
|
-
|
|
808
|
-
else:
|
|
809
|
-
assert_never(result)
|
|
662
|
+
user_parts: list[_messages.UserPromptPart] = []
|
|
810
663
|
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
664
|
+
if calls_to_run:
|
|
665
|
+
# Run all tool tasks in parallel
|
|
666
|
+
parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {}
|
|
667
|
+
with ctx.deps.tracer.start_as_current_span(
|
|
668
|
+
'running tools',
|
|
669
|
+
attributes={
|
|
670
|
+
'tools': [call.tool_name for call in calls_to_run],
|
|
671
|
+
'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
|
|
672
|
+
},
|
|
673
|
+
):
|
|
674
|
+
tasks = [
|
|
675
|
+
asyncio.create_task(_call_function_tool(tool_manager, call), name=call.tool_name)
|
|
676
|
+
for call in calls_to_run
|
|
677
|
+
]
|
|
678
|
+
|
|
679
|
+
pending = tasks
|
|
680
|
+
while pending:
|
|
681
|
+
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
|
682
|
+
for task in done:
|
|
683
|
+
index = tasks.index(task)
|
|
684
|
+
tool_result_part, extra_parts = task.result()
|
|
685
|
+
yield _messages.FunctionToolResultEvent(tool_result_part)
|
|
686
|
+
|
|
687
|
+
parts_by_index[index] = [tool_result_part, *extra_parts]
|
|
688
|
+
|
|
689
|
+
# We append the results at the end, rather than as they are received, to retain a consistent ordering
|
|
690
|
+
# This is mostly just to simplify testing
|
|
691
|
+
for k in sorted(parts_by_index):
|
|
692
|
+
output_parts.extend(parts_by_index[k])
|
|
693
|
+
|
|
694
|
+
# Finally, we handle deferred tool calls
|
|
695
|
+
for call in tool_calls_by_kind['deferred']:
|
|
696
|
+
if final_result:
|
|
697
|
+
output_parts.append(
|
|
698
|
+
_messages.ToolReturnPart(
|
|
699
|
+
tool_name=call.tool_name,
|
|
700
|
+
content='Tool not executed - a final result was already processed.',
|
|
701
|
+
tool_call_id=call.tool_call_id,
|
|
702
|
+
)
|
|
703
|
+
)
|
|
704
|
+
else:
|
|
705
|
+
yield _messages.FunctionToolCallEvent(call)
|
|
815
706
|
|
|
816
707
|
output_parts.extend(user_parts)
|
|
817
708
|
|
|
709
|
+
if final_result:
|
|
710
|
+
output_final_result.append(final_result)
|
|
818
711
|
|
|
819
|
-
async def _tool_from_mcp_server(
|
|
820
|
-
tool_name: str,
|
|
821
|
-
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
822
|
-
) -> Tool[DepsT] | None:
|
|
823
|
-
"""Call each MCP server to find the tool with the given name.
|
|
824
|
-
|
|
825
|
-
Args:
|
|
826
|
-
tool_name: The name of the tool to find.
|
|
827
|
-
ctx: The current run context.
|
|
828
|
-
|
|
829
|
-
Returns:
|
|
830
|
-
The tool with the given name, or `None` if no tool with the given name is found.
|
|
831
|
-
"""
|
|
832
|
-
|
|
833
|
-
async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
|
|
834
|
-
# There's no normal situation where the server will not be running at this point, we check just in case
|
|
835
|
-
# some weird edge case occurs.
|
|
836
|
-
if not server.is_running: # pragma: no cover
|
|
837
|
-
raise exceptions.UserError(f'MCP server is not running: {server}')
|
|
838
712
|
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
713
|
+
async def _call_function_tool(
|
|
714
|
+
tool_manager: ToolManager[DepsT],
|
|
715
|
+
tool_call: _messages.ToolCallPart,
|
|
716
|
+
) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]:
|
|
717
|
+
try:
|
|
718
|
+
tool_result = await tool_manager.handle_call(tool_call)
|
|
719
|
+
except ToolRetryError as e:
|
|
720
|
+
return (e.tool_retry, [])
|
|
721
|
+
|
|
722
|
+
part = _messages.ToolReturnPart(
|
|
723
|
+
tool_name=tool_call.tool_name,
|
|
724
|
+
content=tool_result,
|
|
725
|
+
tool_call_id=tool_call.tool_call_id,
|
|
726
|
+
)
|
|
727
|
+
extra_parts: list[_messages.ModelRequestPart] = []
|
|
851
728
|
|
|
729
|
+
if isinstance(tool_result, _messages.ToolReturn):
|
|
730
|
+
if (
|
|
731
|
+
isinstance(tool_result.return_value, _messages.MultiModalContentTypes)
|
|
732
|
+
or isinstance(tool_result.return_value, list)
|
|
733
|
+
and any(
|
|
734
|
+
isinstance(content, _messages.MultiModalContentTypes)
|
|
735
|
+
for content in tool_result.return_value # type: ignore
|
|
736
|
+
)
|
|
737
|
+
):
|
|
738
|
+
raise exceptions.UserError(
|
|
739
|
+
f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContentTypes` objects. '
|
|
740
|
+
f'Please use `content` instead.'
|
|
741
|
+
)
|
|
852
742
|
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
743
|
+
part.content = tool_result.return_value # type: ignore
|
|
744
|
+
part.metadata = tool_result.metadata
|
|
745
|
+
if tool_result.content:
|
|
746
|
+
extra_parts.append(
|
|
747
|
+
_messages.UserPromptPart(
|
|
748
|
+
content=list(tool_result.content),
|
|
749
|
+
part_kind='user-prompt',
|
|
750
|
+
)
|
|
751
|
+
)
|
|
752
|
+
else:
|
|
860
753
|
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
754
|
+
def process_content(content: Any) -> Any:
|
|
755
|
+
if isinstance(content, _messages.ToolReturn):
|
|
756
|
+
raise exceptions.UserError(
|
|
757
|
+
f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. '
|
|
758
|
+
f'`ToolReturn` should be used directly.'
|
|
759
|
+
)
|
|
760
|
+
elif isinstance(content, _messages.MultiModalContentTypes):
|
|
761
|
+
if isinstance(content, _messages.BinaryContent):
|
|
762
|
+
identifier = content.identifier or multi_modal_content_identifier(content.data)
|
|
763
|
+
else:
|
|
764
|
+
identifier = multi_modal_content_identifier(content.url)
|
|
864
765
|
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
766
|
+
extra_parts.append(
|
|
767
|
+
_messages.UserPromptPart(
|
|
768
|
+
content=[f'This is file {identifier}:', content],
|
|
769
|
+
part_kind='user-prompt',
|
|
770
|
+
)
|
|
771
|
+
)
|
|
772
|
+
return f'See file {identifier}'
|
|
869
773
|
|
|
870
|
-
|
|
871
|
-
tool_name=tool_name,
|
|
872
|
-
tool_call_id=tool_call_id,
|
|
873
|
-
content=f'Unknown tool name: {tool_name!r}. {msg}',
|
|
874
|
-
)
|
|
774
|
+
return content
|
|
875
775
|
|
|
776
|
+
if isinstance(tool_result, list):
|
|
777
|
+
contents = cast(list[Any], tool_result)
|
|
778
|
+
part.content = [process_content(content) for content in contents]
|
|
779
|
+
else:
|
|
780
|
+
part.content = process_content(tool_result)
|
|
876
781
|
|
|
877
|
-
|
|
878
|
-
result_data: T,
|
|
879
|
-
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
|
|
880
|
-
tool_call: _messages.ToolCallPart | None,
|
|
881
|
-
) -> T:
|
|
882
|
-
for validator in ctx.deps.output_validators:
|
|
883
|
-
run_context = build_run_context(ctx)
|
|
884
|
-
result_data = await validator.validate(result_data, tool_call, run_context)
|
|
885
|
-
return result_data
|
|
782
|
+
return (part, extra_parts)
|
|
886
783
|
|
|
887
784
|
|
|
888
785
|
@dataclasses.dataclass
|
|
@@ -918,14 +815,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
|
|
|
918
815
|
If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
|
|
919
816
|
`messages` will represent the messages exchanged during the first call only.
|
|
920
817
|
"""
|
|
818
|
+
token = None
|
|
819
|
+
messages: list[_messages.ModelMessage] = []
|
|
820
|
+
|
|
821
|
+
# Try to reuse existing message context if available
|
|
921
822
|
try:
|
|
922
|
-
|
|
823
|
+
messages = _messages_ctx_var.get().messages
|
|
923
824
|
except LookupError:
|
|
924
|
-
|
|
825
|
+
# No existing context, create a new one
|
|
925
826
|
token = _messages_ctx_var.set(_RunMessages(messages))
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
827
|
+
|
|
828
|
+
try:
|
|
829
|
+
yield messages
|
|
830
|
+
finally:
|
|
831
|
+
# Clean up context if we created it
|
|
832
|
+
if token is not None:
|
|
929
833
|
_messages_ctx_var.reset(token)
|
|
930
834
|
|
|
931
835
|
|