pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the registry's advisory page for more details.

Files changed (75)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_a2a.py +1 -1
  3. pydantic_ai/_agent_graph.py +323 -156
  4. pydantic_ai/_function_schema.py +5 -5
  5. pydantic_ai/_griffe.py +2 -1
  6. pydantic_ai/_otel_messages.py +2 -2
  7. pydantic_ai/_output.py +31 -35
  8. pydantic_ai/_parts_manager.py +7 -5
  9. pydantic_ai/_run_context.py +3 -1
  10. pydantic_ai/_system_prompt.py +2 -2
  11. pydantic_ai/_tool_manager.py +32 -28
  12. pydantic_ai/_utils.py +14 -26
  13. pydantic_ai/ag_ui.py +82 -51
  14. pydantic_ai/agent/__init__.py +84 -17
  15. pydantic_ai/agent/abstract.py +35 -4
  16. pydantic_ai/agent/wrapper.py +6 -0
  17. pydantic_ai/builtin_tools.py +2 -2
  18. pydantic_ai/common_tools/duckduckgo.py +4 -2
  19. pydantic_ai/durable_exec/temporal/__init__.py +70 -17
  20. pydantic_ai/durable_exec/temporal/_agent.py +93 -11
  21. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  22. pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
  23. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  24. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  25. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  26. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  27. pydantic_ai/exceptions.py +45 -2
  28. pydantic_ai/format_prompt.py +2 -2
  29. pydantic_ai/mcp.py +15 -27
  30. pydantic_ai/messages.py +156 -44
  31. pydantic_ai/models/__init__.py +20 -7
  32. pydantic_ai/models/anthropic.py +10 -17
  33. pydantic_ai/models/bedrock.py +55 -57
  34. pydantic_ai/models/cohere.py +3 -3
  35. pydantic_ai/models/fallback.py +2 -2
  36. pydantic_ai/models/function.py +25 -23
  37. pydantic_ai/models/gemini.py +13 -14
  38. pydantic_ai/models/google.py +19 -5
  39. pydantic_ai/models/groq.py +127 -39
  40. pydantic_ai/models/huggingface.py +5 -5
  41. pydantic_ai/models/instrumented.py +49 -21
  42. pydantic_ai/models/mcp_sampling.py +3 -1
  43. pydantic_ai/models/mistral.py +8 -8
  44. pydantic_ai/models/openai.py +37 -42
  45. pydantic_ai/models/test.py +24 -4
  46. pydantic_ai/output.py +27 -32
  47. pydantic_ai/profiles/__init__.py +3 -3
  48. pydantic_ai/profiles/groq.py +1 -1
  49. pydantic_ai/profiles/openai.py +25 -4
  50. pydantic_ai/providers/__init__.py +4 -0
  51. pydantic_ai/providers/anthropic.py +2 -3
  52. pydantic_ai/providers/bedrock.py +3 -2
  53. pydantic_ai/providers/google_vertex.py +2 -1
  54. pydantic_ai/providers/groq.py +21 -2
  55. pydantic_ai/providers/litellm.py +134 -0
  56. pydantic_ai/result.py +173 -52
  57. pydantic_ai/retries.py +52 -31
  58. pydantic_ai/run.py +12 -5
  59. pydantic_ai/tools.py +127 -23
  60. pydantic_ai/toolsets/__init__.py +4 -1
  61. pydantic_ai/toolsets/_dynamic.py +4 -4
  62. pydantic_ai/toolsets/abstract.py +18 -2
  63. pydantic_ai/toolsets/approval_required.py +32 -0
  64. pydantic_ai/toolsets/combined.py +7 -12
  65. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  66. pydantic_ai/toolsets/filtered.py +1 -1
  67. pydantic_ai/toolsets/function.py +58 -21
  68. pydantic_ai/toolsets/wrapper.py +2 -1
  69. pydantic_ai/usage.py +44 -8
  70. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
  71. pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
  72. pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
  73. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
  74. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/ag_ui.py CHANGED
@@ -8,12 +8,11 @@ from __future__ import annotations
8
8
 
9
9
  import json
10
10
  import uuid
11
- from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
11
+ from collections.abc import AsyncIterator, Callable, Iterable, Mapping, Sequence
12
12
  from dataclasses import Field, dataclass, replace
13
13
  from http import HTTPStatus
14
14
  from typing import (
15
15
  Any,
16
- Callable,
17
16
  ClassVar,
18
17
  Final,
19
18
  Generic,
@@ -46,11 +45,11 @@ from .messages import (
46
45
  UserPromptPart,
47
46
  )
48
47
  from .models import KnownModelName, Model
49
- from .output import DeferredToolCalls, OutputDataT, OutputSpec
48
+ from .output import OutputDataT, OutputSpec
50
49
  from .settings import ModelSettings
51
- from .tools import AgentDepsT, ToolDefinition
50
+ from .tools import AgentDepsT, DeferredToolRequests, ToolDefinition
52
51
  from .toolsets import AbstractToolset
53
- from .toolsets.deferred import DeferredToolset
52
+ from .toolsets.external import ExternalToolset
54
53
  from .usage import RunUsage, UsageLimits
55
54
 
56
55
  try:
@@ -69,6 +68,9 @@ try:
69
68
  TextMessageContentEvent,
70
69
  TextMessageEndEvent,
71
70
  TextMessageStartEvent,
71
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
72
+ # ThinkingEndEvent,
73
+ # ThinkingStartEvent,
72
74
  ThinkingTextMessageContentEvent,
73
75
  ThinkingTextMessageEndEvent,
74
76
  ThinkingTextMessageStartEvent,
@@ -343,7 +345,7 @@ async def run_ag_ui(
343
345
 
344
346
  async with agent.iter(
345
347
  user_prompt=None,
346
- output_type=[output_type or agent.output_type, DeferredToolCalls],
348
+ output_type=[output_type or agent.output_type, DeferredToolRequests],
347
349
  message_history=messages,
348
350
  model=model,
349
351
  deps=deps,
@@ -393,6 +395,12 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
393
395
  if stream_ctx.part_end: # pragma: no branch
394
396
  yield stream_ctx.part_end
395
397
  stream_ctx.part_end = None
398
+ if stream_ctx.thinking:
399
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
400
+ # yield ThinkingEndEvent(
401
+ # type=EventType.THINKING_END,
402
+ # )
403
+ stream_ctx.thinking = False
396
404
  elif isinstance(node, CallToolsNode):
397
405
  async with node.stream(run.ctx) as handle_stream:
398
406
  async for event in handle_stream:
@@ -401,7 +409,7 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
401
409
  yield msg
402
410
 
403
411
 
404
- async def _handle_model_request_event(
412
+ async def _handle_model_request_event( # noqa: C901
405
413
  stream_ctx: _RequestStreamContext,
406
414
  agent_event: ModelResponseStreamEvent,
407
415
  ) -> AsyncIterator[BaseEvent]:
@@ -421,56 +429,70 @@ async def _handle_model_request_event(
421
429
  stream_ctx.part_end = None
422
430
 
423
431
  part = agent_event.part
424
- if isinstance(part, TextPart):
425
- message_id = stream_ctx.new_message_id()
426
- yield TextMessageStartEvent(
427
- message_id=message_id,
428
- )
429
- if part.content: # pragma: no branch
430
- yield TextMessageContentEvent(
431
- message_id=message_id,
432
+ if isinstance(part, ThinkingPart): # pragma: no branch
433
+ if not stream_ctx.thinking:
434
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
435
+ # yield ThinkingStartEvent(
436
+ # type=EventType.THINKING_START,
437
+ # )
438
+ stream_ctx.thinking = True
439
+
440
+ if part.content:
441
+ yield ThinkingTextMessageStartEvent(
442
+ type=EventType.THINKING_TEXT_MESSAGE_START,
443
+ )
444
+ yield ThinkingTextMessageContentEvent(
445
+ type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
432
446
  delta=part.content,
433
447
  )
434
- stream_ctx.part_end = TextMessageEndEvent(
435
- message_id=message_id,
436
- )
437
- elif isinstance(part, ToolCallPart): # pragma: no branch
438
- message_id = stream_ctx.message_id or stream_ctx.new_message_id()
439
- yield ToolCallStartEvent(
440
- tool_call_id=part.tool_call_id,
441
- tool_call_name=part.tool_name,
442
- parent_message_id=message_id,
443
- )
444
- if part.args:
445
- yield ToolCallArgsEvent(
448
+ stream_ctx.part_end = ThinkingTextMessageEndEvent(
449
+ type=EventType.THINKING_TEXT_MESSAGE_END,
450
+ )
451
+ else:
452
+ if stream_ctx.thinking:
453
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
454
+ # yield ThinkingEndEvent(
455
+ # type=EventType.THINKING_END,
456
+ # )
457
+ stream_ctx.thinking = False
458
+
459
+ if isinstance(part, TextPart):
460
+ message_id = stream_ctx.new_message_id()
461
+ yield TextMessageStartEvent(
462
+ message_id=message_id,
463
+ )
464
+ if part.content: # pragma: no branch
465
+ yield TextMessageContentEvent(
466
+ message_id=message_id,
467
+ delta=part.content,
468
+ )
469
+ stream_ctx.part_end = TextMessageEndEvent(
470
+ message_id=message_id,
471
+ )
472
+ elif isinstance(part, ToolCallPart): # pragma: no branch
473
+ message_id = stream_ctx.message_id or stream_ctx.new_message_id()
474
+ yield ToolCallStartEvent(
475
+ tool_call_id=part.tool_call_id,
476
+ tool_call_name=part.tool_name,
477
+ parent_message_id=message_id,
478
+ )
479
+ if part.args:
480
+ yield ToolCallArgsEvent(
481
+ tool_call_id=part.tool_call_id,
482
+ delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
483
+ )
484
+ stream_ctx.part_end = ToolCallEndEvent(
446
485
  tool_call_id=part.tool_call_id,
447
- delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
448
486
  )
449
- stream_ctx.part_end = ToolCallEndEvent(
450
- tool_call_id=part.tool_call_id,
451
- )
452
-
453
- elif isinstance(part, ThinkingPart): # pragma: no branch
454
- yield ThinkingTextMessageStartEvent(
455
- type=EventType.THINKING_TEXT_MESSAGE_START,
456
- )
457
- # Always send the content even if it's empty, as it may be
458
- # used to indicate the start of thinking.
459
- yield ThinkingTextMessageContentEvent(
460
- type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
461
- delta=part.content,
462
- )
463
- stream_ctx.part_end = ThinkingTextMessageEndEvent(
464
- type=EventType.THINKING_TEXT_MESSAGE_END,
465
- )
466
487
 
467
488
  elif isinstance(agent_event, PartDeltaEvent):
468
489
  delta = agent_event.delta
469
490
  if isinstance(delta, TextPartDelta):
470
- yield TextMessageContentEvent(
471
- message_id=stream_ctx.message_id,
472
- delta=delta.content_delta,
473
- )
491
+ if delta.content_delta: # pragma: no branch
492
+ yield TextMessageContentEvent(
493
+ message_id=stream_ctx.message_id,
494
+ delta=delta.content_delta,
495
+ )
474
496
  elif isinstance(delta, ToolCallPartDelta): # pragma: no branch
475
497
  assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
476
498
  yield ToolCallArgsEvent(
@@ -479,6 +501,14 @@ async def _handle_model_request_event(
479
501
  )
480
502
  elif isinstance(delta, ThinkingPartDelta): # pragma: no branch
481
503
  if delta.content_delta: # pragma: no branch
504
+ if not isinstance(stream_ctx.part_end, ThinkingTextMessageEndEvent):
505
+ yield ThinkingTextMessageStartEvent(
506
+ type=EventType.THINKING_TEXT_MESSAGE_START,
507
+ )
508
+ stream_ctx.part_end = ThinkingTextMessageEndEvent(
509
+ type=EventType.THINKING_TEXT_MESSAGE_END,
510
+ )
511
+
482
512
  yield ThinkingTextMessageContentEvent(
483
513
  type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
484
514
  delta=delta.content_delta,
@@ -515,7 +545,7 @@ async def _handle_tool_result_event(
515
545
  content = result.content
516
546
  if isinstance(content, BaseEvent):
517
547
  yield content
518
- elif isinstance(content, (str, bytes)): # pragma: no branch
548
+ elif isinstance(content, str | bytes): # pragma: no branch
519
549
  # Avoid iterable check for strings and bytes.
520
550
  pass
521
551
  elif isinstance(content, Iterable): # pragma: no branch
@@ -630,6 +660,7 @@ class _RequestStreamContext:
630
660
 
631
661
  message_id: str = ''
632
662
  part_end: BaseEvent | None = None
663
+ thinking: bool = False
633
664
 
634
665
  def new_message_id(self) -> str:
635
666
  """Generate a new message ID for the request stream.
@@ -681,7 +712,7 @@ class _ToolCallNotFoundError(_RunError, ValueError):
681
712
  )
682
713
 
683
714
 
684
- class _AGUIFrontendToolset(DeferredToolset[AgentDepsT]):
715
+ class _AGUIFrontendToolset(ExternalToolset[AgentDepsT]):
685
716
  def __init__(self, tools: list[AGUITool]):
686
717
  super().__init__(
687
718
  [
pydantic_ai/agent/__init__.py CHANGED
@@ -5,14 +5,14 @@ import inspect
5
5
  import json
6
6
  import warnings
7
7
  from asyncio import Lock
8
- from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
8
+ from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
9
9
  from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
10
10
  from contextvars import ContextVar
11
- from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast, overload
11
+ from typing import TYPE_CHECKING, Any, ClassVar, cast, overload
12
12
 
13
13
  from opentelemetry.trace import NoOpTracer, use_span
14
14
  from pydantic.json_schema import GenerateJsonSchema
15
- from typing_extensions import TypeVar, deprecated
15
+ from typing_extensions import Self, TypeVar, deprecated
16
16
 
17
17
  from pydantic_graph import Graph
18
18
 
@@ -45,10 +45,15 @@ from ..run import AgentRun, AgentRunResult
45
45
  from ..settings import ModelSettings, merge_model_settings
46
46
  from ..tools import (
47
47
  AgentDepsT,
48
+ DeferredToolCallResult,
49
+ DeferredToolResult,
50
+ DeferredToolResults,
48
51
  DocstringFormat,
49
52
  GenerateToolJsonSchema,
50
53
  RunContext,
51
54
  Tool,
55
+ ToolApproved,
56
+ ToolDenied,
52
57
  ToolFuncContext,
53
58
  ToolFuncEither,
54
59
  ToolFuncPlain,
@@ -321,7 +326,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
321
326
 
322
327
  self._instructions = ''
323
328
  self._instructions_functions = []
324
- if isinstance(instructions, (str, Callable)):
329
+ if isinstance(instructions, str | Callable):
325
330
  instructions = [instructions]
326
331
  for instruction in instructions or []:
327
332
  if isinstance(instruction, str):
@@ -346,7 +351,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
346
351
  if self._output_toolset:
347
352
  self._output_toolset.max_retries = self._max_result_retries
348
353
 
349
- self._function_toolset = _AgentFunctionToolset(tools, max_retries=self._max_tool_retries)
354
+ self._function_toolset = _AgentFunctionToolset(
355
+ tools, max_retries=self._max_tool_retries, output_schema=self._output_schema
356
+ )
350
357
  self._dynamic_toolsets = [
351
358
  DynamicToolset[AgentDepsT](toolset_func=toolset)
352
359
  for toolset in toolsets or []
@@ -367,7 +374,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
367
374
  _utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]]
368
375
  ] = ContextVar('_override_tools', default=None)
369
376
 
370
- self._enter_lock = _utils.get_async_lock()
377
+ self._enter_lock = Lock()
371
378
  self._entered_count = 0
372
379
  self._exit_stack = None
373
380
 
@@ -427,6 +434,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
427
434
  *,
428
435
  output_type: None = None,
429
436
  message_history: list[_messages.ModelMessage] | None = None,
437
+ deferred_tool_results: DeferredToolResults | None = None,
430
438
  model: models.Model | models.KnownModelName | str | None = None,
431
439
  deps: AgentDepsT = None,
432
440
  model_settings: ModelSettings | None = None,
@@ -443,6 +451,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
443
451
  *,
444
452
  output_type: OutputSpec[RunOutputDataT],
445
453
  message_history: list[_messages.ModelMessage] | None = None,
454
+ deferred_tool_results: DeferredToolResults | None = None,
446
455
  model: models.Model | models.KnownModelName | str | None = None,
447
456
  deps: AgentDepsT = None,
448
457
  model_settings: ModelSettings | None = None,
@@ -453,12 +462,13 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
453
462
  ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...
454
463
 
455
464
  @asynccontextmanager
456
- async def iter(
465
+ async def iter( # noqa: C901
457
466
  self,
458
467
  user_prompt: str | Sequence[_messages.UserContent] | None = None,
459
468
  *,
460
469
  output_type: OutputSpec[RunOutputDataT] | None = None,
461
470
  message_history: list[_messages.ModelMessage] | None = None,
471
+ deferred_tool_results: DeferredToolResults | None = None,
462
472
  model: models.Model | models.KnownModelName | str | None = None,
463
473
  deps: AgentDepsT = None,
464
474
  model_settings: ModelSettings | None = None,
@@ -531,6 +541,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
531
541
  output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
532
542
  output validators since output validators would expect an argument that matches the agent's output type.
533
543
  message_history: History of the conversation so far.
544
+ deferred_tool_results: Optional results for deferred tool calls in the message history.
534
545
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
535
546
  deps: Optional dependencies to use for this run.
536
547
  model_settings: Optional settings to use for this model's request.
@@ -609,6 +620,23 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
609
620
  instrumentation_settings = None
610
621
  tracer = NoOpTracer()
611
622
 
623
+ tool_call_results: dict[str, DeferredToolResult] | None = None
624
+ if deferred_tool_results is not None:
625
+ tool_call_results = {}
626
+ for tool_call_id, approval in deferred_tool_results.approvals.items():
627
+ if approval is True:
628
+ approval = ToolApproved()
629
+ elif approval is False:
630
+ approval = ToolDenied()
631
+ tool_call_results[tool_call_id] = approval
632
+
633
+ if calls := deferred_tool_results.calls:
634
+ call_result_types = _utils.get_union_args(DeferredToolCallResult)
635
+ for tool_call_id, result in calls.items():
636
+ if not isinstance(result, call_result_types):
637
+ result = _messages.ToolReturn(result)
638
+ tool_call_results[tool_call_id] = result
639
+
612
640
  graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
613
641
  user_deps=deps,
614
642
  prompt=user_prompt,
@@ -623,6 +651,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
623
651
  history_processors=self.history_processors,
624
652
  builtin_tools=list(self._builtin_tools),
625
653
  tool_manager=tool_manager,
654
+ tool_call_results=tool_call_results,
626
655
  tracer=tracer,
627
656
  get_instructions=get_instructions,
628
657
  instrumentation_settings=instrumentation_settings,
@@ -678,22 +707,28 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
678
707
  self, state: _agent_graph.GraphAgentState, usage: _usage.RunUsage, settings: InstrumentationSettings
679
708
  ):
680
709
  if settings.version == 1:
681
- attr_name = 'all_messages_events'
682
- value = [
683
- InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(state.message_history)
684
- ]
710
+ attrs = {
711
+ 'all_messages_events': json.dumps(
712
+ [
713
+ InstrumentedModel.event_to_dict(e)
714
+ for e in settings.messages_to_otel_events(state.message_history)
715
+ ]
716
+ )
717
+ }
685
718
  else:
686
- attr_name = 'pydantic_ai.all_messages'
687
- value = settings.messages_to_otel_messages(state.message_history)
719
+ attrs = {
720
+ 'pydantic_ai.all_messages': json.dumps(settings.messages_to_otel_messages(state.message_history)),
721
+ **settings.system_instructions_attributes(self._instructions),
722
+ }
688
723
 
689
724
  return {
690
725
  **usage.opentelemetry_attributes(),
691
- attr_name: json.dumps(value),
726
+ **attrs,
692
727
  'logfire.json_schema': json.dumps(
693
728
  {
694
729
  'type': 'object',
695
730
  'properties': {
696
- attr_name: {'type': 'array'},
731
+ **{attr: {'type': 'array'} for attr in attrs.keys()},
697
732
  'final_result': {'type': 'object'},
698
733
  },
699
734
  }
@@ -970,6 +1005,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
970
1005
  require_parameter_descriptions: bool = False,
971
1006
  schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
972
1007
  strict: bool | None = None,
1008
+ requires_approval: bool = False,
973
1009
  ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
974
1010
 
975
1011
  def tool(
@@ -984,6 +1020,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
984
1020
  require_parameter_descriptions: bool = False,
985
1021
  schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
986
1022
  strict: bool | None = None,
1023
+ requires_approval: bool = False,
987
1024
  ) -> Any:
988
1025
  """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
989
1026
 
@@ -1028,6 +1065,8 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1028
1065
  schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
1029
1066
  strict: Whether to enforce JSON schema compliance (only affects OpenAI).
1030
1067
  See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
1068
+ requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
1069
+ See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
1031
1070
  """
1032
1071
 
1033
1072
  def tool_decorator(
@@ -1044,6 +1083,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1044
1083
  require_parameter_descriptions,
1045
1084
  schema_generator,
1046
1085
  strict,
1086
+ requires_approval,
1047
1087
  )
1048
1088
  return func_
1049
1089
 
@@ -1064,6 +1104,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1064
1104
  require_parameter_descriptions: bool = False,
1065
1105
  schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
1066
1106
  strict: bool | None = None,
1107
+ requires_approval: bool = False,
1067
1108
  ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
1068
1109
 
1069
1110
  def tool_plain(
@@ -1078,6 +1119,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1078
1119
  require_parameter_descriptions: bool = False,
1079
1120
  schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
1080
1121
  strict: bool | None = None,
1122
+ requires_approval: bool = False,
1081
1123
  ) -> Any:
1082
1124
  """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
1083
1125
 
@@ -1122,6 +1164,8 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1122
1164
  schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
1123
1165
  strict: Whether to enforce JSON schema compliance (only affects OpenAI).
1124
1166
  See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
1167
+ requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
1168
+ See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
1125
1169
  """
1126
1170
 
1127
1171
  def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
@@ -1136,6 +1180,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1136
1180
  require_parameter_descriptions,
1137
1181
  schema_generator,
1138
1182
  strict,
1183
+ requires_approval,
1139
1184
  )
1140
1185
  return func_
1141
1186
 
@@ -1279,7 +1324,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1279
1324
  toolsets: list[AbstractToolset[AgentDepsT]] = []
1280
1325
 
1281
1326
  if some_tools := self._override_tools.get():
1282
- function_toolset = _AgentFunctionToolset(some_tools.value, max_retries=self._max_tool_retries)
1327
+ function_toolset = _AgentFunctionToolset(
1328
+ some_tools.value, max_retries=self._max_tool_retries, output_schema=self._output_schema
1329
+ )
1283
1330
  else:
1284
1331
  function_toolset = self._function_toolset
1285
1332
  toolsets.append(function_toolset)
@@ -1308,7 +1355,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1308
1355
 
1309
1356
  return schema # pyright: ignore[reportReturnType]
1310
1357
 
1311
- async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
1358
+ async def __aenter__(self) -> Self:
1312
1359
  """Enter the agent context.
1313
1360
 
1314
1361
  This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used.
@@ -1376,6 +1423,19 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1376
1423
 
1377
1424
  @dataclasses.dataclass(init=False)
1378
1425
  class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
1426
+ output_schema: _output.BaseOutputSchema[Any]
1427
+
1428
+ def __init__(
1429
+ self,
1430
+ tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
1431
+ *,
1432
+ max_retries: int = 1,
1433
+ id: str | None = None,
1434
+ output_schema: _output.BaseOutputSchema[Any],
1435
+ ):
1436
+ self.output_schema = output_schema
1437
+ super().__init__(tools, max_retries=max_retries, id=id)
1438
+
1379
1439
  @property
1380
1440
  def id(self) -> str:
1381
1441
  return '<agent>'
@@ -1383,3 +1443,10 @@ class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
1383
1443
  @property
1384
1444
  def label(self) -> str:
1385
1445
  return 'the agent'
1446
+
1447
+ def add_tool(self, tool: Tool[AgentDepsT]) -> None:
1448
+ if tool.requires_approval and not self.output_schema.allows_deferred_tools:
1449
+ raise exceptions.UserError(
1450
+ 'To use tools that require approval, add `DeferredToolRequests` to the list of output types for this agent.'
1451
+ )
1452
+ super().add_tool(tool)