pydantic_ai_slim-0.8.0-py3-none-any.whl → pydantic_ai_slim-1.0.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
This version of pydantic-ai-slim has been flagged as a potentially problematic release.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_a2a.py +1 -1
- pydantic_ai/_agent_graph.py +323 -156
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +7 -5
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +32 -28
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +82 -51
- pydantic_ai/agent/__init__.py +84 -17
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +70 -17
- pydantic_ai/durable_exec/temporal/_agent.py +93 -11
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +15 -27
- pydantic_ai/messages.py +156 -44
- pydantic_ai/models/__init__.py +20 -7
- pydantic_ai/models/anthropic.py +10 -17
- pydantic_ai/models/bedrock.py +55 -57
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +13 -14
- pydantic_ai/models/google.py +19 -5
- pydantic_ai/models/groq.py +127 -39
- pydantic_ai/models/huggingface.py +5 -5
- pydantic_ai/models/instrumented.py +49 -21
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +8 -8
- pydantic_ai/models/openai.py +37 -42
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/providers/google_vertex.py +2 -1
- pydantic_ai/providers/groq.py +21 -2
- pydantic_ai/providers/litellm.py +134 -0
- pydantic_ai/result.py +173 -52
- pydantic_ai/retries.py +52 -31
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +127 -23
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +58 -21
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +44 -8
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
- pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
- pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/ag_ui.py CHANGED

@@ -8,12 +8,11 @@ from __future__ import annotations
 
 import json
 import uuid
-from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
+from collections.abc import AsyncIterator, Callable, Iterable, Mapping, Sequence
 from dataclasses import Field, dataclass, replace
 from http import HTTPStatus
 from typing import (
     Any,
-    Callable,
     ClassVar,
     Final,
     Generic,
@@ -46,11 +45,11 @@ from .messages import (
     UserPromptPart,
 )
 from .models import KnownModelName, Model
-from .output import
+from .output import OutputDataT, OutputSpec
 from .settings import ModelSettings
-from .tools import AgentDepsT, ToolDefinition
+from .tools import AgentDepsT, DeferredToolRequests, ToolDefinition
 from .toolsets import AbstractToolset
-from .toolsets.
+from .toolsets.external import ExternalToolset
 from .usage import RunUsage, UsageLimits
 
 try:
@@ -69,6 +68,9 @@ try:
     TextMessageContentEvent,
     TextMessageEndEvent,
     TextMessageStartEvent,
+    # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
+    # ThinkingEndEvent,
+    # ThinkingStartEvent,
     ThinkingTextMessageContentEvent,
     ThinkingTextMessageEndEvent,
     ThinkingTextMessageStartEvent,
@@ -343,7 +345,7 @@ async def run_ag_ui(
 
     async with agent.iter(
         user_prompt=None,
-        output_type=[output_type or agent.output_type,
+        output_type=[output_type or agent.output_type, DeferredToolRequests],
         message_history=messages,
         model=model,
        deps=deps,
@@ -393,6 +395,12 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
                 if stream_ctx.part_end:  # pragma: no branch
                     yield stream_ctx.part_end
                     stream_ctx.part_end = None
+                if stream_ctx.thinking:
+                    # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
+                    # yield ThinkingEndEvent(
+                    #     type=EventType.THINKING_END,
+                    # )
+                    stream_ctx.thinking = False
         elif isinstance(node, CallToolsNode):
             async with node.stream(run.ctx) as handle_stream:
                 async for event in handle_stream:
@@ -401,7 +409,7 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
                     yield msg
 
 
-async def _handle_model_request_event(
+async def _handle_model_request_event(  # noqa: C901
     stream_ctx: _RequestStreamContext,
     agent_event: ModelResponseStreamEvent,
 ) -> AsyncIterator[BaseEvent]:
@@ -421,56 +429,70 @@ async def _handle_model_request_event(
             stream_ctx.part_end = None
 
         part = agent_event.part
-        if isinstance(part,
-            ...
+        if isinstance(part, ThinkingPart):  # pragma: no branch
+            if not stream_ctx.thinking:
+                # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
+                # yield ThinkingStartEvent(
+                #     type=EventType.THINKING_START,
+                # )
+                stream_ctx.thinking = True
+
+            if part.content:
+                yield ThinkingTextMessageStartEvent(
+                    type=EventType.THINKING_TEXT_MESSAGE_START,
+                )
+                yield ThinkingTextMessageContentEvent(
+                    type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
                     delta=part.content,
                 )
-            ...
+                stream_ctx.part_end = ThinkingTextMessageEndEvent(
+                    type=EventType.THINKING_TEXT_MESSAGE_END,
+                )
+        else:
+            if stream_ctx.thinking:
+                # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
+                # yield ThinkingEndEvent(
+                #     type=EventType.THINKING_END,
+                # )
+                stream_ctx.thinking = False
+
+            if isinstance(part, TextPart):
+                message_id = stream_ctx.new_message_id()
+                yield TextMessageStartEvent(
+                    message_id=message_id,
+                )
+                if part.content:  # pragma: no branch
+                    yield TextMessageContentEvent(
+                        message_id=message_id,
+                        delta=part.content,
+                    )
+                stream_ctx.part_end = TextMessageEndEvent(
+                    message_id=message_id,
+                )
+            elif isinstance(part, ToolCallPart):  # pragma: no branch
+                message_id = stream_ctx.message_id or stream_ctx.new_message_id()
+                yield ToolCallStartEvent(
+                    tool_call_id=part.tool_call_id,
+                    tool_call_name=part.tool_name,
+                    parent_message_id=message_id,
+                )
+                if part.args:
+                    yield ToolCallArgsEvent(
+                        tool_call_id=part.tool_call_id,
+                        delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
+                    )
+                stream_ctx.part_end = ToolCallEndEvent(
                     tool_call_id=part.tool_call_id,
-                    delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
                 )
-            stream_ctx.part_end = ToolCallEndEvent(
-                tool_call_id=part.tool_call_id,
-            )
-
-        elif isinstance(part, ThinkingPart):  # pragma: no branch
-            yield ThinkingTextMessageStartEvent(
-                type=EventType.THINKING_TEXT_MESSAGE_START,
-            )
-            # Always send the content even if it's empty, as it may be
-            # used to indicate the start of thinking.
-            yield ThinkingTextMessageContentEvent(
-                type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
-                delta=part.content,
-            )
-            stream_ctx.part_end = ThinkingTextMessageEndEvent(
-                type=EventType.THINKING_TEXT_MESSAGE_END,
-            )
 
     elif isinstance(agent_event, PartDeltaEvent):
         delta = agent_event.delta
         if isinstance(delta, TextPartDelta):
-            ...
+            if delta.content_delta:  # pragma: no branch
+                yield TextMessageContentEvent(
+                    message_id=stream_ctx.message_id,
+                    delta=delta.content_delta,
+                )
         elif isinstance(delta, ToolCallPartDelta):  # pragma: no branch
             assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
             yield ToolCallArgsEvent(
@@ -479,6 +501,14 @@ async def _handle_model_request_event(
             )
         elif isinstance(delta, ThinkingPartDelta):  # pragma: no branch
             if delta.content_delta:  # pragma: no branch
+                if not isinstance(stream_ctx.part_end, ThinkingTextMessageEndEvent):
+                    yield ThinkingTextMessageStartEvent(
+                        type=EventType.THINKING_TEXT_MESSAGE_START,
+                    )
+                    stream_ctx.part_end = ThinkingTextMessageEndEvent(
+                        type=EventType.THINKING_TEXT_MESSAGE_END,
+                    )
+
                 yield ThinkingTextMessageContentEvent(
                     type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
                     delta=delta.content_delta,
@@ -515,7 +545,7 @@ async def _handle_tool_result_event(
     content = result.content
     if isinstance(content, BaseEvent):
         yield content
-    elif isinstance(content,
+    elif isinstance(content, str | bytes):  # pragma: no branch
         # Avoid iterable check for strings and bytes.
         pass
     elif isinstance(content, Iterable):  # pragma: no branch
@@ -630,6 +660,7 @@ class _RequestStreamContext:
 
     message_id: str = ''
     part_end: BaseEvent | None = None
+    thinking: bool = False
 
     def new_message_id(self) -> str:
         """Generate a new message ID for the request stream.
@@ -681,7 +712,7 @@ class _ToolCallNotFoundError(_RunError, ValueError):
         )
 
 
-class _AGUIFrontendToolset(
+class _AGUIFrontendToolset(ExternalToolset[AgentDepsT]):
    def __init__(self, tools: list[AGUITool]):
        super().__init__(
            [
pydantic_ai/agent/__init__.py CHANGED

@@ -5,14 +5,14 @@ import inspect
 import json
 import warnings
 from asyncio import Lock
-from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
 from contextvars import ContextVar
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, ClassVar, cast, overload
 
 from opentelemetry.trace import NoOpTracer, use_span
 from pydantic.json_schema import GenerateJsonSchema
-from typing_extensions import TypeVar, deprecated
+from typing_extensions import Self, TypeVar, deprecated
 
 from pydantic_graph import Graph
 
@@ -45,10 +45,15 @@ from ..run import AgentRun, AgentRunResult
 from ..settings import ModelSettings, merge_model_settings
 from ..tools import (
     AgentDepsT,
+    DeferredToolCallResult,
+    DeferredToolResult,
+    DeferredToolResults,
     DocstringFormat,
     GenerateToolJsonSchema,
     RunContext,
     Tool,
+    ToolApproved,
+    ToolDenied,
     ToolFuncContext,
     ToolFuncEither,
     ToolFuncPlain,
@@ -321,7 +326,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
 
         self._instructions = ''
         self._instructions_functions = []
-        if isinstance(instructions,
+        if isinstance(instructions, str | Callable):
             instructions = [instructions]
         for instruction in instructions or []:
             if isinstance(instruction, str):
@@ -346,7 +351,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
         if self._output_toolset:
             self._output_toolset.max_retries = self._max_result_retries
 
-        self._function_toolset = _AgentFunctionToolset(
+        self._function_toolset = _AgentFunctionToolset(
+            tools, max_retries=self._max_tool_retries, output_schema=self._output_schema
+        )
         self._dynamic_toolsets = [
             DynamicToolset[AgentDepsT](toolset_func=toolset)
             for toolset in toolsets or []
@@ -367,7 +374,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
             _utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]]
         ] = ContextVar('_override_tools', default=None)
 
-        self._enter_lock =
+        self._enter_lock = Lock()
         self._entered_count = 0
         self._exit_stack = None
 
@@ -427,6 +434,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
         *,
         output_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
+        deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
@@ -443,6 +451,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
         *,
         output_type: OutputSpec[RunOutputDataT],
         message_history: list[_messages.ModelMessage] | None = None,
+        deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
@@ -453,12 +462,13 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
     ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...
 
     @asynccontextmanager
-    async def iter(
+    async def iter(  # noqa: C901
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         output_type: OutputSpec[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
+        deferred_tool_results: DeferredToolResults | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
@@ -531,6 +541,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
             output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
                 output validators since output validators would expect an argument that matches the agent's output type.
             message_history: History of the conversation so far.
+            deferred_tool_results: Optional results for deferred tool calls in the message history.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
@@ -609,6 +620,23 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
             instrumentation_settings = None
             tracer = NoOpTracer()
 
+        tool_call_results: dict[str, DeferredToolResult] | None = None
+        if deferred_tool_results is not None:
+            tool_call_results = {}
+            for tool_call_id, approval in deferred_tool_results.approvals.items():
+                if approval is True:
+                    approval = ToolApproved()
+                elif approval is False:
+                    approval = ToolDenied()
+                tool_call_results[tool_call_id] = approval
+
+            if calls := deferred_tool_results.calls:
+                call_result_types = _utils.get_union_args(DeferredToolCallResult)
+                for tool_call_id, result in calls.items():
+                    if not isinstance(result, call_result_types):
+                        result = _messages.ToolReturn(result)
+                    tool_call_results[tool_call_id] = result
+
         graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
             user_deps=deps,
             prompt=user_prompt,
@@ -623,6 +651,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
             history_processors=self.history_processors,
             builtin_tools=list(self._builtin_tools),
             tool_manager=tool_manager,
+            tool_call_results=tool_call_results,
             tracer=tracer,
             get_instructions=get_instructions,
             instrumentation_settings=instrumentation_settings,
@@ -678,22 +707,28 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
         self, state: _agent_graph.GraphAgentState, usage: _usage.RunUsage, settings: InstrumentationSettings
     ):
         if settings.version == 1:
-            ...
+            attrs = {
+                'all_messages_events': json.dumps(
+                    [
+                        InstrumentedModel.event_to_dict(e)
+                        for e in settings.messages_to_otel_events(state.message_history)
+                    ]
+                )
+            }
         else:
-            ...
+            attrs = {
+                'pydantic_ai.all_messages': json.dumps(settings.messages_to_otel_messages(state.message_history)),
+                **settings.system_instructions_attributes(self._instructions),
+            }
 
         return {
             **usage.opentelemetry_attributes(),
-            ...
+            **attrs,
             'logfire.json_schema': json.dumps(
                 {
                     'type': 'object',
                     'properties': {
-                        ...
+                        **{attr: {'type': 'array'} for attr in attrs.keys()},
                         'final_result': {'type': 'object'},
                     },
                 }
@@ -970,6 +1005,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
         require_parameter_descriptions: bool = False,
         schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
         strict: bool | None = None,
+        requires_approval: bool = False,
     ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
 
     def tool(
@@ -984,6 +1020,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
         require_parameter_descriptions: bool = False,
         schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
         strict: bool | None = None,
+        requires_approval: bool = False,
     ) -> Any:
         """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
 
@@ -1028,6 +1065,8 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
             schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
             strict: Whether to enforce JSON schema compliance (only affects OpenAI).
                 See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
+            requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
+                See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
         """
 
         def tool_decorator(
@@ -1044,6 +1083,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
                 require_parameter_descriptions,
                 schema_generator,
                 strict,
+                requires_approval,
             )
             return func_
 
@@ -1064,6 +1104,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
         require_parameter_descriptions: bool = False,
         schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
         strict: bool | None = None,
+        requires_approval: bool = False,
     ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
 
     def tool_plain(
@@ -1078,6 +1119,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
         require_parameter_descriptions: bool = False,
         schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
         strict: bool | None = None,
+        requires_approval: bool = False,
     ) -> Any:
         """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
 
@@ -1122,6 +1164,8 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
             schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
             strict: Whether to enforce JSON schema compliance (only affects OpenAI).
                 See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
+            requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
+                See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
         """
 
         def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
@@ -1136,6 +1180,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
                 require_parameter_descriptions,
                 schema_generator,
                 strict,
+                requires_approval,
             )
             return func_
 
@@ -1279,7 +1324,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
         toolsets: list[AbstractToolset[AgentDepsT]] = []
 
         if some_tools := self._override_tools.get():
-            function_toolset = _AgentFunctionToolset(
+            function_toolset = _AgentFunctionToolset(
+                some_tools.value, max_retries=self._max_tool_retries, output_schema=self._output_schema
+            )
         else:
             function_toolset = self._function_toolset
         toolsets.append(function_toolset)
@@ -1308,7 +1355,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
 
         return schema  # pyright: ignore[reportReturnType]
 
-    async def __aenter__(self) ->
+    async def __aenter__(self) -> Self:
         """Enter the agent context.
 
         This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used.
@@ -1376,6 +1423,19 @@
 
 @dataclasses.dataclass(init=False)
 class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
+    output_schema: _output.BaseOutputSchema[Any]
+
+    def __init__(
+        self,
+        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
+        *,
+        max_retries: int = 1,
+        id: str | None = None,
+        output_schema: _output.BaseOutputSchema[Any],
+    ):
+        self.output_schema = output_schema
+        super().__init__(tools, max_retries=max_retries, id=id)
+
     @property
     def id(self) -> str:
         return '<agent>'
@@ -1383,3 +1443,10 @@ class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
     @property
     def label(self) -> str:
         return 'the agent'
+
+    def add_tool(self, tool: Tool[AgentDepsT]) -> None:
+        if tool.requires_approval and not self.output_schema.allows_deferred_tools:
+            raise exceptions.UserError(
+                'To use tools that require approval, add `DeferredToolRequests` to the list of output types for this agent.'
+            )
+        super().add_tool(tool)