pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the package registry's release page for more details.

Files changed (75)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_a2a.py +1 -1
  3. pydantic_ai/_agent_graph.py +323 -156
  4. pydantic_ai/_function_schema.py +5 -5
  5. pydantic_ai/_griffe.py +2 -1
  6. pydantic_ai/_otel_messages.py +2 -2
  7. pydantic_ai/_output.py +31 -35
  8. pydantic_ai/_parts_manager.py +7 -5
  9. pydantic_ai/_run_context.py +3 -1
  10. pydantic_ai/_system_prompt.py +2 -2
  11. pydantic_ai/_tool_manager.py +32 -28
  12. pydantic_ai/_utils.py +14 -26
  13. pydantic_ai/ag_ui.py +82 -51
  14. pydantic_ai/agent/__init__.py +70 -9
  15. pydantic_ai/agent/abstract.py +35 -4
  16. pydantic_ai/agent/wrapper.py +6 -0
  17. pydantic_ai/builtin_tools.py +2 -2
  18. pydantic_ai/common_tools/duckduckgo.py +4 -2
  19. pydantic_ai/durable_exec/temporal/__init__.py +4 -2
  20. pydantic_ai/durable_exec/temporal/_agent.py +93 -11
  21. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  22. pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
  23. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  24. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  25. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  26. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  27. pydantic_ai/exceptions.py +45 -2
  28. pydantic_ai/format_prompt.py +2 -2
  29. pydantic_ai/mcp.py +15 -27
  30. pydantic_ai/messages.py +149 -42
  31. pydantic_ai/models/__init__.py +6 -4
  32. pydantic_ai/models/anthropic.py +9 -16
  33. pydantic_ai/models/bedrock.py +50 -56
  34. pydantic_ai/models/cohere.py +3 -3
  35. pydantic_ai/models/fallback.py +2 -2
  36. pydantic_ai/models/function.py +25 -23
  37. pydantic_ai/models/gemini.py +12 -13
  38. pydantic_ai/models/google.py +18 -4
  39. pydantic_ai/models/groq.py +126 -38
  40. pydantic_ai/models/huggingface.py +4 -4
  41. pydantic_ai/models/instrumented.py +35 -16
  42. pydantic_ai/models/mcp_sampling.py +3 -1
  43. pydantic_ai/models/mistral.py +6 -6
  44. pydantic_ai/models/openai.py +35 -40
  45. pydantic_ai/models/test.py +24 -4
  46. pydantic_ai/output.py +27 -32
  47. pydantic_ai/profiles/__init__.py +3 -3
  48. pydantic_ai/profiles/groq.py +1 -1
  49. pydantic_ai/profiles/openai.py +25 -4
  50. pydantic_ai/providers/__init__.py +4 -0
  51. pydantic_ai/providers/anthropic.py +2 -3
  52. pydantic_ai/providers/bedrock.py +3 -2
  53. pydantic_ai/providers/google_vertex.py +2 -1
  54. pydantic_ai/providers/groq.py +21 -2
  55. pydantic_ai/providers/litellm.py +134 -0
  56. pydantic_ai/result.py +144 -41
  57. pydantic_ai/retries.py +52 -31
  58. pydantic_ai/run.py +12 -5
  59. pydantic_ai/tools.py +127 -23
  60. pydantic_ai/toolsets/__init__.py +4 -1
  61. pydantic_ai/toolsets/_dynamic.py +4 -4
  62. pydantic_ai/toolsets/abstract.py +18 -2
  63. pydantic_ai/toolsets/approval_required.py +32 -0
  64. pydantic_ai/toolsets/combined.py +7 -12
  65. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  66. pydantic_ai/toolsets/filtered.py +1 -1
  67. pydantic_ai/toolsets/function.py +58 -21
  68. pydantic_ai/toolsets/wrapper.py +2 -1
  69. pydantic_ai/usage.py +44 -8
  70. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
  71. pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
  72. pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
  73. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
  74. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/ag_ui.py CHANGED
@@ -8,12 +8,11 @@ from __future__ import annotations
8
8
 
9
9
  import json
10
10
  import uuid
11
- from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
11
+ from collections.abc import AsyncIterator, Callable, Iterable, Mapping, Sequence
12
12
  from dataclasses import Field, dataclass, replace
13
13
  from http import HTTPStatus
14
14
  from typing import (
15
15
  Any,
16
- Callable,
17
16
  ClassVar,
18
17
  Final,
19
18
  Generic,
@@ -46,11 +45,11 @@ from .messages import (
46
45
  UserPromptPart,
47
46
  )
48
47
  from .models import KnownModelName, Model
49
- from .output import DeferredToolCalls, OutputDataT, OutputSpec
48
+ from .output import OutputDataT, OutputSpec
50
49
  from .settings import ModelSettings
51
- from .tools import AgentDepsT, ToolDefinition
50
+ from .tools import AgentDepsT, DeferredToolRequests, ToolDefinition
52
51
  from .toolsets import AbstractToolset
53
- from .toolsets.deferred import DeferredToolset
52
+ from .toolsets.external import ExternalToolset
54
53
  from .usage import RunUsage, UsageLimits
55
54
 
56
55
  try:
@@ -69,6 +68,9 @@ try:
69
68
  TextMessageContentEvent,
70
69
  TextMessageEndEvent,
71
70
  TextMessageStartEvent,
71
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
72
+ # ThinkingEndEvent,
73
+ # ThinkingStartEvent,
72
74
  ThinkingTextMessageContentEvent,
73
75
  ThinkingTextMessageEndEvent,
74
76
  ThinkingTextMessageStartEvent,
@@ -343,7 +345,7 @@ async def run_ag_ui(
343
345
 
344
346
  async with agent.iter(
345
347
  user_prompt=None,
346
- output_type=[output_type or agent.output_type, DeferredToolCalls],
348
+ output_type=[output_type or agent.output_type, DeferredToolRequests],
347
349
  message_history=messages,
348
350
  model=model,
349
351
  deps=deps,
@@ -393,6 +395,12 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
393
395
  if stream_ctx.part_end: # pragma: no branch
394
396
  yield stream_ctx.part_end
395
397
  stream_ctx.part_end = None
398
+ if stream_ctx.thinking:
399
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
400
+ # yield ThinkingEndEvent(
401
+ # type=EventType.THINKING_END,
402
+ # )
403
+ stream_ctx.thinking = False
396
404
  elif isinstance(node, CallToolsNode):
397
405
  async with node.stream(run.ctx) as handle_stream:
398
406
  async for event in handle_stream:
@@ -401,7 +409,7 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
401
409
  yield msg
402
410
 
403
411
 
404
- async def _handle_model_request_event(
412
+ async def _handle_model_request_event( # noqa: C901
405
413
  stream_ctx: _RequestStreamContext,
406
414
  agent_event: ModelResponseStreamEvent,
407
415
  ) -> AsyncIterator[BaseEvent]:
@@ -421,56 +429,70 @@ async def _handle_model_request_event(
421
429
  stream_ctx.part_end = None
422
430
 
423
431
  part = agent_event.part
424
- if isinstance(part, TextPart):
425
- message_id = stream_ctx.new_message_id()
426
- yield TextMessageStartEvent(
427
- message_id=message_id,
428
- )
429
- if part.content: # pragma: no branch
430
- yield TextMessageContentEvent(
431
- message_id=message_id,
432
+ if isinstance(part, ThinkingPart): # pragma: no branch
433
+ if not stream_ctx.thinking:
434
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
435
+ # yield ThinkingStartEvent(
436
+ # type=EventType.THINKING_START,
437
+ # )
438
+ stream_ctx.thinking = True
439
+
440
+ if part.content:
441
+ yield ThinkingTextMessageStartEvent(
442
+ type=EventType.THINKING_TEXT_MESSAGE_START,
443
+ )
444
+ yield ThinkingTextMessageContentEvent(
445
+ type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
432
446
  delta=part.content,
433
447
  )
434
- stream_ctx.part_end = TextMessageEndEvent(
435
- message_id=message_id,
436
- )
437
- elif isinstance(part, ToolCallPart): # pragma: no branch
438
- message_id = stream_ctx.message_id or stream_ctx.new_message_id()
439
- yield ToolCallStartEvent(
440
- tool_call_id=part.tool_call_id,
441
- tool_call_name=part.tool_name,
442
- parent_message_id=message_id,
443
- )
444
- if part.args:
445
- yield ToolCallArgsEvent(
448
+ stream_ctx.part_end = ThinkingTextMessageEndEvent(
449
+ type=EventType.THINKING_TEXT_MESSAGE_END,
450
+ )
451
+ else:
452
+ if stream_ctx.thinking:
453
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
454
+ # yield ThinkingEndEvent(
455
+ # type=EventType.THINKING_END,
456
+ # )
457
+ stream_ctx.thinking = False
458
+
459
+ if isinstance(part, TextPart):
460
+ message_id = stream_ctx.new_message_id()
461
+ yield TextMessageStartEvent(
462
+ message_id=message_id,
463
+ )
464
+ if part.content: # pragma: no branch
465
+ yield TextMessageContentEvent(
466
+ message_id=message_id,
467
+ delta=part.content,
468
+ )
469
+ stream_ctx.part_end = TextMessageEndEvent(
470
+ message_id=message_id,
471
+ )
472
+ elif isinstance(part, ToolCallPart): # pragma: no branch
473
+ message_id = stream_ctx.message_id or stream_ctx.new_message_id()
474
+ yield ToolCallStartEvent(
475
+ tool_call_id=part.tool_call_id,
476
+ tool_call_name=part.tool_name,
477
+ parent_message_id=message_id,
478
+ )
479
+ if part.args:
480
+ yield ToolCallArgsEvent(
481
+ tool_call_id=part.tool_call_id,
482
+ delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
483
+ )
484
+ stream_ctx.part_end = ToolCallEndEvent(
446
485
  tool_call_id=part.tool_call_id,
447
- delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
448
486
  )
449
- stream_ctx.part_end = ToolCallEndEvent(
450
- tool_call_id=part.tool_call_id,
451
- )
452
-
453
- elif isinstance(part, ThinkingPart): # pragma: no branch
454
- yield ThinkingTextMessageStartEvent(
455
- type=EventType.THINKING_TEXT_MESSAGE_START,
456
- )
457
- # Always send the content even if it's empty, as it may be
458
- # used to indicate the start of thinking.
459
- yield ThinkingTextMessageContentEvent(
460
- type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
461
- delta=part.content,
462
- )
463
- stream_ctx.part_end = ThinkingTextMessageEndEvent(
464
- type=EventType.THINKING_TEXT_MESSAGE_END,
465
- )
466
487
 
467
488
  elif isinstance(agent_event, PartDeltaEvent):
468
489
  delta = agent_event.delta
469
490
  if isinstance(delta, TextPartDelta):
470
- yield TextMessageContentEvent(
471
- message_id=stream_ctx.message_id,
472
- delta=delta.content_delta,
473
- )
491
+ if delta.content_delta: # pragma: no branch
492
+ yield TextMessageContentEvent(
493
+ message_id=stream_ctx.message_id,
494
+ delta=delta.content_delta,
495
+ )
474
496
  elif isinstance(delta, ToolCallPartDelta): # pragma: no branch
475
497
  assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
476
498
  yield ToolCallArgsEvent(
@@ -479,6 +501,14 @@ async def _handle_model_request_event(
479
501
  )
480
502
  elif isinstance(delta, ThinkingPartDelta): # pragma: no branch
481
503
  if delta.content_delta: # pragma: no branch
504
+ if not isinstance(stream_ctx.part_end, ThinkingTextMessageEndEvent):
505
+ yield ThinkingTextMessageStartEvent(
506
+ type=EventType.THINKING_TEXT_MESSAGE_START,
507
+ )
508
+ stream_ctx.part_end = ThinkingTextMessageEndEvent(
509
+ type=EventType.THINKING_TEXT_MESSAGE_END,
510
+ )
511
+
482
512
  yield ThinkingTextMessageContentEvent(
483
513
  type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
484
514
  delta=delta.content_delta,
@@ -515,7 +545,7 @@ async def _handle_tool_result_event(
515
545
  content = result.content
516
546
  if isinstance(content, BaseEvent):
517
547
  yield content
518
- elif isinstance(content, (str, bytes)): # pragma: no branch
548
+ elif isinstance(content, str | bytes): # pragma: no branch
519
549
  # Avoid iterable check for strings and bytes.
520
550
  pass
521
551
  elif isinstance(content, Iterable): # pragma: no branch
@@ -630,6 +660,7 @@ class _RequestStreamContext:
630
660
 
631
661
  message_id: str = ''
632
662
  part_end: BaseEvent | None = None
663
+ thinking: bool = False
633
664
 
634
665
  def new_message_id(self) -> str:
635
666
  """Generate a new message ID for the request stream.
@@ -681,7 +712,7 @@ class _ToolCallNotFoundError(_RunError, ValueError):
681
712
  )
682
713
 
683
714
 
684
- class _AGUIFrontendToolset(DeferredToolset[AgentDepsT]):
715
+ class _AGUIFrontendToolset(ExternalToolset[AgentDepsT]):
685
716
  def __init__(self, tools: list[AGUITool]):
686
717
  super().__init__(
687
718
  [
pydantic_ai/agent/__init__.py CHANGED
@@ -5,14 +5,14 @@ import inspect
5
5
  import json
6
6
  import warnings
7
7
  from asyncio import Lock
8
- from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
8
+ from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
9
9
  from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
10
10
  from contextvars import ContextVar
11
- from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast, overload
11
+ from typing import TYPE_CHECKING, Any, ClassVar, cast, overload
12
12
 
13
13
  from opentelemetry.trace import NoOpTracer, use_span
14
14
  from pydantic.json_schema import GenerateJsonSchema
15
- from typing_extensions import TypeVar, deprecated
15
+ from typing_extensions import Self, TypeVar, deprecated
16
16
 
17
17
  from pydantic_graph import Graph
18
18
 
@@ -45,10 +45,15 @@ from ..run import AgentRun, AgentRunResult
45
45
  from ..settings import ModelSettings, merge_model_settings
46
46
  from ..tools import (
47
47
  AgentDepsT,
48
+ DeferredToolCallResult,
49
+ DeferredToolResult,
50
+ DeferredToolResults,
48
51
  DocstringFormat,
49
52
  GenerateToolJsonSchema,
50
53
  RunContext,
51
54
  Tool,
55
+ ToolApproved,
56
+ ToolDenied,
52
57
  ToolFuncContext,
53
58
  ToolFuncEither,
54
59
  ToolFuncPlain,
@@ -321,7 +326,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
321
326
 
322
327
  self._instructions = ''
323
328
  self._instructions_functions = []
324
- if isinstance(instructions, (str, Callable)):
329
+ if isinstance(instructions, str | Callable):
325
330
  instructions = [instructions]
326
331
  for instruction in instructions or []:
327
332
  if isinstance(instruction, str):
@@ -346,7 +351,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
346
351
  if self._output_toolset:
347
352
  self._output_toolset.max_retries = self._max_result_retries
348
353
 
349
- self._function_toolset = _AgentFunctionToolset(tools, max_retries=self._max_tool_retries)
354
+ self._function_toolset = _AgentFunctionToolset(
355
+ tools, max_retries=self._max_tool_retries, output_schema=self._output_schema
356
+ )
350
357
  self._dynamic_toolsets = [
351
358
  DynamicToolset[AgentDepsT](toolset_func=toolset)
352
359
  for toolset in toolsets or []
@@ -367,7 +374,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
367
374
  _utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]]
368
375
  ] = ContextVar('_override_tools', default=None)
369
376
 
370
- self._enter_lock = _utils.get_async_lock()
377
+ self._enter_lock = Lock()
371
378
  self._entered_count = 0
372
379
  self._exit_stack = None
373
380
 
@@ -427,6 +434,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
427
434
  *,
428
435
  output_type: None = None,
429
436
  message_history: list[_messages.ModelMessage] | None = None,
437
+ deferred_tool_results: DeferredToolResults | None = None,
430
438
  model: models.Model | models.KnownModelName | str | None = None,
431
439
  deps: AgentDepsT = None,
432
440
  model_settings: ModelSettings | None = None,
@@ -443,6 +451,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
443
451
  *,
444
452
  output_type: OutputSpec[RunOutputDataT],
445
453
  message_history: list[_messages.ModelMessage] | None = None,
454
+ deferred_tool_results: DeferredToolResults | None = None,
446
455
  model: models.Model | models.KnownModelName | str | None = None,
447
456
  deps: AgentDepsT = None,
448
457
  model_settings: ModelSettings | None = None,
@@ -453,12 +462,13 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
453
462
  ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...
454
463
 
455
464
  @asynccontextmanager
456
- async def iter(
465
+ async def iter( # noqa: C901
457
466
  self,
458
467
  user_prompt: str | Sequence[_messages.UserContent] | None = None,
459
468
  *,
460
469
  output_type: OutputSpec[RunOutputDataT] | None = None,
461
470
  message_history: list[_messages.ModelMessage] | None = None,
471
+ deferred_tool_results: DeferredToolResults | None = None,
462
472
  model: models.Model | models.KnownModelName | str | None = None,
463
473
  deps: AgentDepsT = None,
464
474
  model_settings: ModelSettings | None = None,
@@ -531,6 +541,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
531
541
  output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
532
542
  output validators since output validators would expect an argument that matches the agent's output type.
533
543
  message_history: History of the conversation so far.
544
+ deferred_tool_results: Optional results for deferred tool calls in the message history.
534
545
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
535
546
  deps: Optional dependencies to use for this run.
536
547
  model_settings: Optional settings to use for this model's request.
@@ -609,6 +620,23 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
609
620
  instrumentation_settings = None
610
621
  tracer = NoOpTracer()
611
622
 
623
+ tool_call_results: dict[str, DeferredToolResult] | None = None
624
+ if deferred_tool_results is not None:
625
+ tool_call_results = {}
626
+ for tool_call_id, approval in deferred_tool_results.approvals.items():
627
+ if approval is True:
628
+ approval = ToolApproved()
629
+ elif approval is False:
630
+ approval = ToolDenied()
631
+ tool_call_results[tool_call_id] = approval
632
+
633
+ if calls := deferred_tool_results.calls:
634
+ call_result_types = _utils.get_union_args(DeferredToolCallResult)
635
+ for tool_call_id, result in calls.items():
636
+ if not isinstance(result, call_result_types):
637
+ result = _messages.ToolReturn(result)
638
+ tool_call_results[tool_call_id] = result
639
+
612
640
  graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
613
641
  user_deps=deps,
614
642
  prompt=user_prompt,
@@ -623,6 +651,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
623
651
  history_processors=self.history_processors,
624
652
  builtin_tools=list(self._builtin_tools),
625
653
  tool_manager=tool_manager,
654
+ tool_call_results=tool_call_results,
626
655
  tracer=tracer,
627
656
  get_instructions=get_instructions,
628
657
  instrumentation_settings=instrumentation_settings,
@@ -976,6 +1005,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
976
1005
  require_parameter_descriptions: bool = False,
977
1006
  schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
978
1007
  strict: bool | None = None,
1008
+ requires_approval: bool = False,
979
1009
  ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
980
1010
 
981
1011
  def tool(
@@ -990,6 +1020,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
990
1020
  require_parameter_descriptions: bool = False,
991
1021
  schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
992
1022
  strict: bool | None = None,
1023
+ requires_approval: bool = False,
993
1024
  ) -> Any:
994
1025
  """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
995
1026
 
@@ -1034,6 +1065,8 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1034
1065
  schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
1035
1066
  strict: Whether to enforce JSON schema compliance (only affects OpenAI).
1036
1067
  See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
1068
+ requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
1069
+ See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
1037
1070
  """
1038
1071
 
1039
1072
  def tool_decorator(
@@ -1050,6 +1083,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1050
1083
  require_parameter_descriptions,
1051
1084
  schema_generator,
1052
1085
  strict,
1086
+ requires_approval,
1053
1087
  )
1054
1088
  return func_
1055
1089
 
@@ -1070,6 +1104,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1070
1104
  require_parameter_descriptions: bool = False,
1071
1105
  schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
1072
1106
  strict: bool | None = None,
1107
+ requires_approval: bool = False,
1073
1108
  ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
1074
1109
 
1075
1110
  def tool_plain(
@@ -1084,6 +1119,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1084
1119
  require_parameter_descriptions: bool = False,
1085
1120
  schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
1086
1121
  strict: bool | None = None,
1122
+ requires_approval: bool = False,
1087
1123
  ) -> Any:
1088
1124
  """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
1089
1125
 
@@ -1128,6 +1164,8 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1128
1164
  schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
1129
1165
  strict: Whether to enforce JSON schema compliance (only affects OpenAI).
1130
1166
  See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
1167
+ requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
1168
+ See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
1131
1169
  """
1132
1170
 
1133
1171
  def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
@@ -1142,6 +1180,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1142
1180
  require_parameter_descriptions,
1143
1181
  schema_generator,
1144
1182
  strict,
1183
+ requires_approval,
1145
1184
  )
1146
1185
  return func_
1147
1186
 
@@ -1285,7 +1324,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1285
1324
  toolsets: list[AbstractToolset[AgentDepsT]] = []
1286
1325
 
1287
1326
  if some_tools := self._override_tools.get():
1288
- function_toolset = _AgentFunctionToolset(some_tools.value, max_retries=self._max_tool_retries)
1327
+ function_toolset = _AgentFunctionToolset(
1328
+ some_tools.value, max_retries=self._max_tool_retries, output_schema=self._output_schema
1329
+ )
1289
1330
  else:
1290
1331
  function_toolset = self._function_toolset
1291
1332
  toolsets.append(function_toolset)
@@ -1314,7 +1355,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1314
1355
 
1315
1356
  return schema # pyright: ignore[reportReturnType]
1316
1357
 
1317
- async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
1358
+ async def __aenter__(self) -> Self:
1318
1359
  """Enter the agent context.
1319
1360
 
1320
1361
  This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used.
@@ -1382,6 +1423,19 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1382
1423
 
1383
1424
  @dataclasses.dataclass(init=False)
1384
1425
  class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
1426
+ output_schema: _output.BaseOutputSchema[Any]
1427
+
1428
+ def __init__(
1429
+ self,
1430
+ tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
1431
+ *,
1432
+ max_retries: int = 1,
1433
+ id: str | None = None,
1434
+ output_schema: _output.BaseOutputSchema[Any],
1435
+ ):
1436
+ self.output_schema = output_schema
1437
+ super().__init__(tools, max_retries=max_retries, id=id)
1438
+
1385
1439
  @property
1386
1440
  def id(self) -> str:
1387
1441
  return '<agent>'
@@ -1389,3 +1443,10 @@ class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
1389
1443
  @property
1390
1444
  def label(self) -> str:
1391
1445
  return 'the agent'
1446
+
1447
+ def add_tool(self, tool: Tool[AgentDepsT]) -> None:
1448
+ if tool.requires_approval and not self.output_schema.allows_deferred_tools:
1449
+ raise exceptions.UserError(
1450
+ 'To use tools that require approval, add `DeferredToolRequests` to the list of output types for this agent.'
1451
+ )
1452
+ super().add_tool(tool)
pydantic_ai/agent/abstract.py CHANGED
@@ -2,12 +2,12 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import inspect
4
4
  from abc import ABC, abstractmethod
5
- from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator, Mapping, Sequence
5
+ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator, Mapping, Sequence
6
6
  from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
7
7
  from types import FrameType
8
- from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload
8
+ from typing import TYPE_CHECKING, Any, Generic, TypeAlias, cast, overload
9
9
 
10
- from typing_extensions import Self, TypeAlias, TypeIs, TypeVar
10
+ from typing_extensions import Self, TypeIs, TypeVar
11
11
 
12
12
  from pydantic_graph import End
13
13
  from pydantic_graph._utils import get_event_loop
@@ -27,6 +27,7 @@ from ..run import AgentRun, AgentRunResult
27
27
  from ..settings import ModelSettings
28
28
  from ..tools import (
29
29
  AgentDepsT,
30
+ DeferredToolResults,
30
31
  RunContext,
31
32
  Tool,
32
33
  ToolFuncEither,
@@ -116,6 +117,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
116
117
  *,
117
118
  output_type: None = None,
118
119
  message_history: list[_messages.ModelMessage] | None = None,
120
+ deferred_tool_results: DeferredToolResults | None = None,
119
121
  model: models.Model | models.KnownModelName | str | None = None,
120
122
  deps: AgentDepsT = None,
121
123
  model_settings: ModelSettings | None = None,
@@ -133,6 +135,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
133
135
  *,
134
136
  output_type: OutputSpec[RunOutputDataT],
135
137
  message_history: list[_messages.ModelMessage] | None = None,
138
+ deferred_tool_results: DeferredToolResults | None = None,
136
139
  model: models.Model | models.KnownModelName | str | None = None,
137
140
  deps: AgentDepsT = None,
138
141
  model_settings: ModelSettings | None = None,
@@ -149,6 +152,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
149
152
  *,
150
153
  output_type: OutputSpec[RunOutputDataT] | None = None,
151
154
  message_history: list[_messages.ModelMessage] | None = None,
155
+ deferred_tool_results: DeferredToolResults | None = None,
152
156
  model: models.Model | models.KnownModelName | str | None = None,
153
157
  deps: AgentDepsT = None,
154
158
  model_settings: ModelSettings | None = None,
@@ -180,6 +184,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
180
184
  output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
181
185
  output validators since output validators would expect an argument that matches the agent's output type.
182
186
  message_history: History of the conversation so far.
187
+ deferred_tool_results: Optional results for deferred tool calls in the message history.
183
188
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
184
189
  deps: Optional dependencies to use for this run.
185
190
  model_settings: Optional settings to use for this model's request.
@@ -201,6 +206,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
201
206
  user_prompt=user_prompt,
202
207
  output_type=output_type,
203
208
  message_history=message_history,
209
+ deferred_tool_results=deferred_tool_results,
204
210
  model=model,
205
211
  deps=deps,
206
212
  model_settings=model_settings,
@@ -225,6 +231,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
225
231
  *,
226
232
  output_type: None = None,
227
233
  message_history: list[_messages.ModelMessage] | None = None,
234
+ deferred_tool_results: DeferredToolResults | None = None,
228
235
  model: models.Model | models.KnownModelName | str | None = None,
229
236
  deps: AgentDepsT = None,
230
237
  model_settings: ModelSettings | None = None,
@@ -242,6 +249,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
242
249
  *,
243
250
  output_type: OutputSpec[RunOutputDataT],
244
251
  message_history: list[_messages.ModelMessage] | None = None,
252
+ deferred_tool_results: DeferredToolResults | None = None,
245
253
  model: models.Model | models.KnownModelName | str | None = None,
246
254
  deps: AgentDepsT = None,
247
255
  model_settings: ModelSettings | None = None,
@@ -258,6 +266,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
258
266
  *,
259
267
  output_type: OutputSpec[RunOutputDataT] | None = None,
260
268
  message_history: list[_messages.ModelMessage] | None = None,
269
+ deferred_tool_results: DeferredToolResults | None = None,
261
270
  model: models.Model | models.KnownModelName | str | None = None,
262
271
  deps: AgentDepsT = None,
263
272
  model_settings: ModelSettings | None = None,
@@ -288,6 +297,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
288
297
  output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
289
298
  output validators since output validators would expect an argument that matches the agent's output type.
290
299
  message_history: History of the conversation so far.
300
+ deferred_tool_results: Optional results for deferred tool calls in the message history.
291
301
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
292
302
  deps: Optional dependencies to use for this run.
293
303
  model_settings: Optional settings to use for this model's request.
@@ -308,6 +318,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
308
318
  user_prompt,
309
319
  output_type=output_type,
310
320
  message_history=message_history,
321
+ deferred_tool_results=deferred_tool_results,
311
322
  model=model,
312
323
  deps=deps,
313
324
  model_settings=model_settings,
@@ -326,6 +337,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
326
337
  *,
327
338
  output_type: None = None,
328
339
  message_history: list[_messages.ModelMessage] | None = None,
340
+ deferred_tool_results: DeferredToolResults | None = None,
329
341
  model: models.Model | models.KnownModelName | str | None = None,
330
342
  deps: AgentDepsT = None,
331
343
  model_settings: ModelSettings | None = None,
@@ -343,6 +355,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
343
355
  *,
344
356
  output_type: OutputSpec[RunOutputDataT],
345
357
  message_history: list[_messages.ModelMessage] | None = None,
358
+ deferred_tool_results: DeferredToolResults | None = None,
346
359
  model: models.Model | models.KnownModelName | str | None = None,
347
360
  deps: AgentDepsT = None,
348
361
  model_settings: ModelSettings | None = None,
@@ -360,6 +373,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
360
373
  *,
361
374
  output_type: OutputSpec[RunOutputDataT] | None = None,
362
375
  message_history: list[_messages.ModelMessage] | None = None,
376
+ deferred_tool_results: DeferredToolResults | None = None,
363
377
  model: models.Model | models.KnownModelName | str | None = None,
364
378
  deps: AgentDepsT = None,
365
379
  model_settings: ModelSettings | None = None,
@@ -398,6 +412,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
398
412
  output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
399
413
  output validators since output validators would expect an argument that matches the agent's output type.
400
414
  message_history: History of the conversation so far.
415
+ deferred_tool_results: Optional results for deferred tool calls in the message history.
401
416
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
402
417
  deps: Optional dependencies to use for this run.
403
418
  model_settings: Optional settings to use for this model's request.
@@ -424,6 +439,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
424
439
  user_prompt,
425
440
  output_type=output_type,
426
441
  message_history=message_history,
442
+ deferred_tool_results=deferred_tool_results,
427
443
  model=model,
428
444
  deps=deps,
429
445
  model_settings=model_settings,
@@ -436,8 +452,8 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
436
452
  assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node
437
453
  node = first_node
438
454
  while True:
455
+ graph_ctx = agent_run.ctx
439
456
  if self.is_model_request_node(node):
440
- graph_ctx = agent_run.ctx
441
457
  async with node.stream(graph_ctx) as stream:
442
458
  final_result_event = None
443
459
 
@@ -505,6 +521,17 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
505
521
  await event_stream_handler(_agent_graph.build_run_context(agent_run.ctx), stream)
506
522
 
507
523
  next_node = await agent_run.next(node)
524
+ if isinstance(next_node, End) and agent_run.result is not None:
525
+ # A final output could have been produced by the CallToolsNode rather than the ModelRequestNode,
526
+ # if a tool function raised CallDeferred or ApprovalRequired.
527
+ # In this case there's no response to stream, but we still let the user access the output etc as normal.
528
+ yield StreamedRunResult(
529
+ graph_ctx.state.message_history,
530
+ graph_ctx.deps.new_message_index,
531
+ run_result=agent_run.result,
532
+ )
533
+ yielded = True
534
+ break
508
535
  if not isinstance(next_node, _agent_graph.AgentNode):
509
536
  raise exceptions.AgentRunError( # pragma: no cover
510
537
  'Should have produced a StreamedRunResult before getting here'
@@ -521,6 +548,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
521
548
  *,
522
549
  output_type: None = None,
523
550
  message_history: list[_messages.ModelMessage] | None = None,
551
+ deferred_tool_results: DeferredToolResults | None = None,
524
552
  model: models.Model | models.KnownModelName | str | None = None,
525
553
  deps: AgentDepsT = None,
526
554
  model_settings: ModelSettings | None = None,
@@ -537,6 +565,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
537
565
  *,
538
566
  output_type: OutputSpec[RunOutputDataT],
539
567
  message_history: list[_messages.ModelMessage] | None = None,
568
+ deferred_tool_results: DeferredToolResults | None = None,
540
569
  model: models.Model | models.KnownModelName | str | None = None,
541
570
  deps: AgentDepsT = None,
542
571
  model_settings: ModelSettings | None = None,
@@ -554,6 +583,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
554
583
  *,
555
584
  output_type: OutputSpec[RunOutputDataT] | None = None,
556
585
  message_history: list[_messages.ModelMessage] | None = None,
586
+ deferred_tool_results: DeferredToolResults | None = None,
557
587
  model: models.Model | models.KnownModelName | str | None = None,
558
588
  deps: AgentDepsT = None,
559
589
  model_settings: ModelSettings | None = None,
@@ -626,6 +656,7 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
626
656
  output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
627
657
  output validators since output validators would expect an argument that matches the agent's output type.
628
658
  message_history: History of the conversation so far.
659
+ deferred_tool_results: Optional results for deferred tool calls in the message history.
629
660
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
630
661
  deps: Optional dependencies to use for this run.
631
662
  model_settings: Optional settings to use for this model's request.