pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their registry. It is provided for informational purposes only.


pydantic_ai/agent.py CHANGED
@@ -22,10 +22,11 @@ from . import (
     result,
     usage as _usage,
 )
-from .result import ResultData
+from .result import ResultDataT
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
-    AgentDeps,
+    AgentDepsT,
+    DocstringFormat,
     RunContext,
     Tool,
     ToolDefinition,
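
The central change in this release is the rename of the two public type variables. Downstream code that imported the old names needs only the renames; a minimal migration sketch (assuming the old aliases were removed outright, which is how the diff reads):

```python
# Before (0.0.18):
# from pydantic_ai.tools import AgentDeps
# from pydantic_ai.result import ResultData

# After (0.0.20):
from pydantic_ai.tools import AgentDepsT
from pydantic_ai.result import ResultDataT
```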
@@ -50,6 +51,8 @@ else:
 
 logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
+T = TypeVar('T')
+"""An invariant TypeVar."""
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
 """The strategy for handling multiple tool calls when a final result is found.
@@ -63,11 +66,11 @@ RunResultData = TypeVar('RunResultData')
 
 @final
 @dataclasses.dataclass(init=False)
-class Agent(Generic[AgentDeps, ResultData]):
+class Agent(Generic[AgentDepsT, ResultDataT]):
     """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
 
-    Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.tools.AgentDeps]
-    and the result data type they return, [`ResultData`][pydantic_ai.result.ResultData].
+    Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
+    and the result data type they return, [`ResultDataT`][pydantic_ai.result.ResultDataT].
 
     By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
 
@@ -103,34 +106,34 @@ class Agent(Generic[AgentDeps, ResultData]):
     """
     _result_tool_name: str = dataclasses.field(repr=False)
     _result_tool_description: str | None = dataclasses.field(repr=False)
-    _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
-    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
+    _result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False)
+    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = dataclasses.field(repr=False)
     _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
-    _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
+    _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
     _default_retries: int = dataclasses.field(repr=False)
-    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
-    _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(
+    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
+    _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
         repr=False
     )
-    _deps_type: type[AgentDeps] = dataclasses.field(repr=False)
+    _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
-    _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
+    _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
     _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
 
     def __init__(
         self,
         model: models.Model | models.KnownModelName | None = None,
         *,
-        result_type: type[ResultData] = str,
+        result_type: type[ResultDataT] = str,
         system_prompt: str | Sequence[str] = (),
-        deps_type: type[AgentDeps] = NoneType,
+        deps_type: type[AgentDepsT] = NoneType,
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
         result_tool_name: str = 'final_result',
         result_tool_description: str | None = None,
         result_retries: int | None = None,
-        tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (),
+        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
     ):
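
To make the renamed generic parameters concrete, here is a hedged sketch of constructing a typed agent against this signature (the model name and deps class are illustrative):

```python
from dataclasses import dataclass

from pydantic_ai import Agent


@dataclass
class MyDeps:
    api_key: str


# deps_type/result_type solve the generics: this is an Agent[MyDeps, bool].
agent = Agent('openai:gpt-4o', deps_type=MyDeps, result_type=bool)
```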
@@ -199,12 +202,12 @@ class Agent(Generic[AgentDeps, ResultData]):
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDeps = None,
+        deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[ResultData]: ...
+    ) -> result.RunResult[ResultDataT]: ...
 
     @overload
     async def run(
@@ -214,7 +217,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         result_type: type[RunResultData],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDeps = None,
+        deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -227,7 +230,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         *,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDeps = None,
+        deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -242,9 +245,10 @@ class Agent(Generic[AgentDeps, ResultData]):
 
         agent = Agent('openai:gpt-4o')
 
-        result_sync = agent.run_sync('What is the capital of Italy?')
-        print(result_sync.data)
-        #> Rome
+        async def main():
+            result = await agent.run('What is the capital of France?')
+            print(result.data)
+            #> Paris
         ```
 
         Args:
@@ -336,12 +340,12 @@ class Agent(Generic[AgentDeps, ResultData]):
         *,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDeps = None,
+        deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[ResultData]: ...
+    ) -> result.RunResult[ResultDataT]: ...
 
     @overload
     def run_sync(
@@ -351,7 +355,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         result_type: type[RunResultData] | None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDeps = None,
+        deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -365,7 +369,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         result_type: type[RunResultData] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDeps = None,
+        deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -382,10 +386,9 @@ class Agent(Generic[AgentDeps, ResultData]):
 
         agent = Agent('openai:gpt-4o')
 
-        async def main():
-            result = await agent.run('What is the capital of France?')
-            print(result.data)
-            #> Paris
+        result_sync = agent.run_sync('What is the capital of Italy?')
+        print(result_sync.data)
+        #> Rome
         ```
 
         Args:
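
The two docstring examples appear to have been swapped between methods in 0.0.18; this release attaches each to the method it actually demonstrates. Both calling conventions together, as a runnable sketch based on the diff's own examples:

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')

# Synchronous wrapper: runs the async method on an event loop internally.
result_sync = agent.run_sync('What is the capital of Italy?')
print(result_sync.data)


async def main():
    # The native async form of the same call.
    result = await agent.run('What is the capital of France?')
    print(result.data)


asyncio.run(main())
```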
@@ -427,12 +430,12 @@ class Agent(Generic[AgentDeps, ResultData]):
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDeps = None,
+        deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, ResultData]]: ...
+    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, ResultDataT]]: ...
 
     @overload
     def run_stream(
@@ -442,12 +445,12 @@ class Agent(Generic[AgentDeps, ResultData]):
         result_type: type[RunResultData],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDeps = None,
+        deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, RunResultData]]: ...
+    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultData]]: ...
 
     @asynccontextmanager
     async def run_stream(
@@ -457,12 +460,12 @@ class Agent(Generic[AgentDeps, ResultData]):
         result_type: type[RunResultData] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
-        deps: AgentDeps = None,
+        deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, Any]]:
+    ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
 
         Example:
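
For reference, a hedged sketch of driving the streaming API whose signature changed above (it assumes `StreamedRunResult.get_data()` as documented for this release; the prompt is illustrative):

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main():
    async with agent.run_stream('What is the capital of the UK?') as response:
        # Waits for the stream to complete and validates the final result.
        print(await response.get_data())


asyncio.run(main())
```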
@@ -535,7 +538,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                     model_req_span.__exit__(None, None, None)
 
                     with _logfire.span('handle model response') as handle_span:
-                        maybe_final_result = await self._handle_streamed_model_response(
+                        maybe_final_result = await self._handle_streamed_response(
                             model_response, run_context, result_schema
                         )
 
@@ -559,10 +562,19 @@ class Agent(Generic[AgentDeps, ResultData]):
                             parts = await self._process_function_tools(
                                 tool_calls, result_tool_name, run_context, result_schema
                             )
+                            if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
+                                self._incr_result_retry(run_context)
                             if parts:
                                 messages.append(_messages.ModelRequest(parts))
                             run_span.set_attribute('all_messages', messages)
 
+                        # The following is not guaranteed to be true, but we consider it a user error if
+                        # there are result validators that might convert the result data from an overridden
+                        # `result_type` to a type that is not valid as such.
+                        result_validators = cast(
+                            list[_result.ResultValidator[AgentDepsT, RunResultData]], self._result_validators
+                        )
+
                         yield result.StreamedRunResult(
                             messages,
                             new_message_index,
@@ -570,7 +582,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                             result_stream,
                             result_schema,
                             run_context,
-                            self._result_validators,
+                            result_validators,
                             result_tool_name,
                             on_complete,
                         )
@@ -596,7 +608,7 @@ class Agent(Generic[AgentDeps, ResultData]):
     def override(
         self,
         *,
-        deps: AgentDeps | _utils.Unset = _utils.UNSET,
+        deps: AgentDepsT | _utils.Unset = _utils.UNSET,
         model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
     ) -> Iterator[None]:
         """Context manager to temporarily override agent dependencies and model.
@@ -632,13 +644,13 @@ class Agent(Generic[AgentDeps, ResultData]):
 
     @overload
     def system_prompt(
-        self, func: Callable[[RunContext[AgentDeps]], str], /
-    ) -> Callable[[RunContext[AgentDeps]], str]: ...
+        self, func: Callable[[RunContext[AgentDepsT]], str], /
+    ) -> Callable[[RunContext[AgentDepsT]], str]: ...
 
     @overload
     def system_prompt(
-        self, func: Callable[[RunContext[AgentDeps]], Awaitable[str]], /
-    ) -> Callable[[RunContext[AgentDeps]], Awaitable[str]]: ...
+        self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
+    ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...
 
     @overload
     def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
@@ -649,17 +661,17 @@ class Agent(Generic[AgentDeps, ResultData]):
     @overload
     def system_prompt(
         self, /, *, dynamic: bool = False
-    ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]: ...
+    ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ...
 
     def system_prompt(
         self,
-        func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None,
+        func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
         /,
         *,
         dynamic: bool = False,
     ) -> (
-        Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]
-        | _system_prompt.SystemPromptFunc[AgentDeps]
+        Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]
+        | _system_prompt.SystemPromptFunc[AgentDepsT]
     ):
         """Decorator to register a system prompt function.
 
@@ -695,9 +707,9 @@ class Agent(Generic[AgentDeps, ResultData]):
         if func is None:
 
             def decorator(
-                func_: _system_prompt.SystemPromptFunc[AgentDeps],
-            ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
-                runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic)
+                func_: _system_prompt.SystemPromptFunc[AgentDepsT],
+            ) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
+                runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic)
                 self._system_prompt_functions.append(runner)
                 if dynamic:
                     self._system_prompt_dynamic_functions[func_.__qualname__] = runner
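
A hedged sketch of both decorator forms, including `dynamic=True`, which registers the runner under `func_.__qualname__` as shown above (continues the earlier `MyDeps`/`agent` sketch):

```python
from pydantic_ai import RunContext


@agent.system_prompt
def static_rules() -> str:
    return 'Reply concisely.'


@agent.system_prompt(dynamic=True)
def personalized(ctx: RunContext[MyDeps]) -> str:
    # Dynamic prompts are keyed by __qualname__ so they can be re-evaluated later.
    return f'The configured API key starts with {ctx.deps.api_key[:4]}.'
```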
@@ -711,25 +723,27 @@ class Agent(Generic[AgentDeps, ResultData]):
 
     @overload
     def result_validator(
-        self, func: Callable[[RunContext[AgentDeps], ResultData], ResultData], /
-    ) -> Callable[[RunContext[AgentDeps], ResultData], ResultData]: ...
+        self, func: Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT], /
+    ) -> Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT]: ...
 
     @overload
     def result_validator(
-        self, func: Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]], /
-    ) -> Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...
+        self, func: Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]], /
+    ) -> Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]]: ...
 
     @overload
-    def result_validator(self, func: Callable[[ResultData], ResultData], /) -> Callable[[ResultData], ResultData]: ...
+    def result_validator(
+        self, func: Callable[[ResultDataT], ResultDataT], /
+    ) -> Callable[[ResultDataT], ResultDataT]: ...
 
     @overload
     def result_validator(
-        self, func: Callable[[ResultData], Awaitable[ResultData]], /
-    ) -> Callable[[ResultData], Awaitable[ResultData]]: ...
+        self, func: Callable[[ResultDataT], Awaitable[ResultDataT]], /
+    ) -> Callable[[ResultDataT], Awaitable[ResultDataT]]: ...
 
     def result_validator(
-        self, func: _result.ResultValidatorFunc[AgentDeps, ResultData], /
-    ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
+        self, func: _result.ResultValidatorFunc[AgentDepsT, ResultDataT], /
+    ) -> _result.ResultValidatorFunc[AgentDepsT, ResultDataT]:
         """Decorator to register a result validator function.
 
         Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
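
A hedged sketch matching the tightened overloads, in the context-aware async form (`ModelRetry` is the documented way to request another attempt; continues the earlier sketch):

```python
from pydantic_ai import ModelRetry, RunContext


@agent.result_validator
async def check_result(ctx: RunContext[MyDeps], data: bool) -> bool:
    if not ctx.deps.api_key:
        raise ModelRetry('No API key configured; answer without external lookups.')
    return data
```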
@@ -761,11 +775,11 @@ class Agent(Generic[AgentDeps, ResultData]):
         #> success (no tool calls)
         ```
         """
-        self._result_validators.append(_result.ResultValidator[AgentDeps, Any](func))
+        self._result_validators.append(_result.ResultValidator[AgentDepsT, Any](func))
         return func
 
     @overload
-    def tool(self, func: ToolFuncContext[AgentDeps, ToolParams], /) -> ToolFuncContext[AgentDeps, ToolParams]: ...
+    def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ...
 
     @overload
     def tool(
@@ -773,16 +787,20 @@ class Agent(Generic[AgentDeps, ResultData]):
         /,
         *,
         retries: int | None = None,
-        prepare: ToolPrepareFunc[AgentDeps] | None = None,
-    ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
+        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
+    ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
 
     def tool(
         self,
-        func: ToolFuncContext[AgentDeps, ToolParams] | None = None,
+        func: ToolFuncContext[AgentDepsT, ToolParams] | None = None,
         /,
         *,
         retries: int | None = None,
-        prepare: ToolPrepareFunc[AgentDeps] | None = None,
+        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
     ) -> Any:
         """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
 
@@ -820,20 +838,23 @@ class Agent(Generic[AgentDeps, ResultData]):
             prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                 tool from a given step. This is useful if you want to customise a tool at call time,
                 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
+            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
+                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
+            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
         """
         if func is None:
 
             def tool_decorator(
-                func_: ToolFuncContext[AgentDeps, ToolParams],
-            ) -> ToolFuncContext[AgentDeps, ToolParams]:
+                func_: ToolFuncContext[AgentDepsT, ToolParams],
+            ) -> ToolFuncContext[AgentDepsT, ToolParams]:
                 # noinspection PyTypeChecker
-                self._register_function(func_, True, retries, prepare)
+                self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions)
                 return func_
 
             return tool_decorator
         else:
             # noinspection PyTypeChecker
-            self._register_function(func, True, retries, prepare)
+            self._register_function(func, True, retries, prepare, docstring_format, require_parameter_descriptions)
             return func
 
     @overload
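
A hedged sketch of the two new keyword arguments in use (`'google'` is one of the documented `DocstringFormat` values; the tool body is illustrative and continues the earlier sketch):

```python
from pydantic_ai import RunContext


@agent.tool(docstring_format='google', require_parameter_descriptions=True)
async def get_price(ctx: RunContext[MyDeps], item: str) -> float:
    """Look up the price of an item.

    Args:
        item: Name of the item to price.
    """
    return 9.99  # placeholder lookup
```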
@@ -845,7 +866,9 @@ class Agent(Generic[AgentDeps, ResultData]):
         /,
         *,
         retries: int | None = None,
-        prepare: ToolPrepareFunc[AgentDeps] | None = None,
+        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
     ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
 
     def tool_plain(
@@ -854,7 +877,9 @@ class Agent(Generic[AgentDeps, ResultData]):
         /,
         *,
         retries: int | None = None,
-        prepare: ToolPrepareFunc[AgentDeps] | None = None,
+        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
     ) -> Any:
         """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
 
@@ -892,32 +917,46 @@ class Agent(Generic[AgentDeps, ResultData]):
             prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                 tool from a given step. This is useful if you want to customise a tool at call time,
                 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
+            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
+                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
+            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
         """
         if func is None:
 
             def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
                 # noinspection PyTypeChecker
-                self._register_function(func_, False, retries, prepare)
+                self._register_function(
+                    func_, False, retries, prepare, docstring_format, require_parameter_descriptions
+                )
                 return func_
 
             return tool_decorator
         else:
-            self._register_function(func, False, retries, prepare)
+            self._register_function(func, False, retries, prepare, docstring_format, require_parameter_descriptions)
             return func
 
     def _register_function(
         self,
-        func: ToolFuncEither[AgentDeps, ToolParams],
+        func: ToolFuncEither[AgentDepsT, ToolParams],
         takes_ctx: bool,
         retries: int | None,
-        prepare: ToolPrepareFunc[AgentDeps] | None,
+        prepare: ToolPrepareFunc[AgentDepsT] | None,
+        docstring_format: DocstringFormat,
+        require_parameter_descriptions: bool,
     ) -> None:
         """Private utility to register a function as a tool."""
         retries_ = retries if retries is not None else self._default_retries
-        tool = Tool(func, takes_ctx=takes_ctx, max_retries=retries_, prepare=prepare)
+        tool = Tool[AgentDepsT](
+            func,
+            takes_ctx=takes_ctx,
+            max_retries=retries_,
+            prepare=prepare,
+            docstring_format=docstring_format,
+            require_parameter_descriptions=require_parameter_descriptions,
+        )
         self._register_tool(tool)
 
-    def _register_tool(self, tool: Tool[AgentDeps]) -> None:
+    def _register_tool(self, tool: Tool[AgentDepsT]) -> None:
         """Private utility to register a tool instance."""
         if tool.max_retries is None:
             # noinspection PyTypeChecker
@@ -960,12 +999,12 @@ class Agent(Generic[AgentDeps, ResultData]):
         return model_
 
     async def _prepare_model(
-        self, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
+        self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultData] | None
     ) -> models.AgentModel:
         """Build tools and create an agent model."""
         function_tools: list[ToolDefinition] = []
 
-        async def add_tool(tool: Tool[AgentDeps]) -> None:
+        async def add_tool(tool: Tool[AgentDepsT]) -> None:
             ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
             if tool_def := await tool.prepare_tool_def(ctx):
                 function_tools.append(tool_def)
@@ -979,7 +1018,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         )
 
     async def _reevaluate_dynamic_prompts(
-        self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDeps]
+        self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDepsT]
     ) -> None:
         """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
         # Only proceed if there's at least one dynamic runner.
@@ -1008,7 +1047,10 @@ class Agent(Generic[AgentDeps, ResultData]):
         return self._result_schema  # pyright: ignore[reportReturnType]
 
     async def _prepare_messages(
-        self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
+        self,
+        user_prompt: str,
+        message_history: list[_messages.ModelMessage] | None,
+        run_context: RunContext[AgentDepsT],
     ) -> list[_messages.ModelMessage]:
         try:
             ctx_messages = _messages_ctx_var.get()
@@ -1037,7 +1079,7 @@ class Agent(Generic[AgentDeps, ResultData]):
     async def _handle_model_response(
         self,
         model_response: _messages.ModelResponse,
-        run_context: RunContext[AgentDeps],
+        run_context: RunContext[AgentDepsT],
         result_schema: _result.ResultSchema[RunResultData] | None,
     ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Process a non-streamed response from the model.
@@ -1068,7 +1110,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             raise exceptions.UnexpectedModelBehavior('Received empty model response')
 
     async def _handle_text_response(
-        self, text: str, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
+        self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultData] | None
     ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a plain text response from the model for non-streaming responses."""
         if self._allow_text_result(result_schema):
@@ -1090,7 +1132,7 @@ class Agent(Generic[AgentDeps, ResultData]):
     async def _handle_structured_response(
         self,
         tool_calls: list[_messages.ToolCallPart],
-        run_context: RunContext[AgentDeps],
+        run_context: RunContext[AgentDepsT],
         result_schema: _result.ResultSchema[RunResultData] | None,
     ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a structured response containing tool calls from the model for non-streaming responses."""
@@ -1100,14 +1142,13 @@ class Agent(Generic[AgentDeps, ResultData]):
         final_result: _MarkFinalResult[RunResultData] | None = None
 
         parts: list[_messages.ModelRequestPart] = []
-        if result_schema := result_schema:
+        if result_schema is not None:
             if match := result_schema.find_tool(tool_calls):
                 call, result_tool = match
                 try:
                     result_data = result_tool.validate(call)
                     result_data = await self._validate_result(result_data, run_context, call)
                 except _result.ToolRetryError as e:
-                    self._incr_result_retry(run_context)
                     parts.append(e.tool_retry)
                 else:
                     final_result = _MarkFinalResult(result_data, call.tool_name)
@@ -1117,13 +1158,16 @@ class Agent(Generic[AgentDeps, ResultData]):
             tool_calls, final_result and final_result.tool_name, run_context, result_schema
         )
 
+        if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
+            self._incr_result_retry(run_context)
+
         return final_result, parts
 
     async def _process_function_tools(
         self,
         tool_calls: list[_messages.ToolCallPart],
         result_tool_name: str | None,
-        run_context: RunContext[AgentDeps],
+        run_context: RunContext[AgentDepsT],
         result_schema: _result.ResultSchema[RunResultData] | None,
     ) -> list[_messages.ModelRequestPart]:
         """Process function (non-result) tool calls in parallel.
@@ -1170,7 +1214,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                     )
                 )
             else:
-                parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))
+                parts.append(self._unknown_tool(call.tool_name, result_schema))
 
         # Run all tool tasks in parallel
         if tasks:
@@ -1179,99 +1223,85 @@ class Agent(Generic[AgentDeps, ResultData]):
                 parts.extend(task_results)
         return parts
 
-    async def _handle_streamed_model_response(
+    async def _handle_streamed_response(
         self,
-        model_response: models.EitherStreamedResponse,
-        run_context: RunContext[AgentDeps],
+        streamed_response: models.StreamedResponse,
+        run_context: RunContext[AgentDepsT],
         result_schema: _result.ResultSchema[RunResultData] | None,
-    ) -> (
-        _MarkFinalResult[models.EitherStreamedResponse]
-        | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
-    ):
+    ) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
         """Process a streamed response from the model.
 
         Returns:
             Either a final result or a tuple of the model response and the tool responses for the next request.
             If a final result is returned, the conversation should end.
         """
-        if isinstance(model_response, models.StreamTextResponse):
-            # plain string response
-            if self._allow_text_result(result_schema):
-                return _MarkFinalResult(model_response, None)
-            else:
-                self._incr_result_retry(run_context)
-                response = _messages.RetryPromptPart(
-                    content='Plain text responses are not permitted, please call one of the functions instead.',
-                )
-                # stream the response, so usage is correct
-                async for _ in model_response:
-                    pass
-
-                text = ''.join(model_response.get(final=True))
-                return _messages.ModelResponse([_messages.TextPart(text)]), [response]
-        elif isinstance(model_response, models.StreamStructuredResponse):
-            if result_schema is not None:
-                # if there's a result schema, iterate over the stream until we find at least one tool
-                # NOTE: this means we ignore any other tools called here
-                structured_msg = model_response.get()
-                while not structured_msg.parts:
-                    try:
-                        await model_response.__anext__()
-                    except StopAsyncIteration:
-                        break
-                    structured_msg = model_response.get()
-
-                if match := result_schema.find_tool(structured_msg.parts):
-                    call, _ = match
-                    return _MarkFinalResult(model_response, call.tool_name)
-
-            # the model is calling a tool function, consume the response to get the next message
-            async for _ in model_response:
-                pass
-            model_response_msg = model_response.get()
-            if not model_response_msg.parts:
-                raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
-
-            # we now run all tool functions in parallel
-            tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
-            parts: list[_messages.ModelRequestPart] = []
-            for item in model_response_msg.parts:
-                if isinstance(item, _messages.ToolCallPart):
-                    call = item
-                    if tool := self._function_tools.get(call.tool_name):
-                        tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
-                    else:
-                        parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))
+        received_text = False
+
+        async for maybe_part_event in streamed_response:
+            if isinstance(maybe_part_event, _messages.PartStartEvent):
+                new_part = maybe_part_event.part
+                if isinstance(new_part, _messages.TextPart):
+                    received_text = True
+                    if self._allow_text_result(result_schema):
+                        return _MarkFinalResult(streamed_response, None)
+                elif isinstance(new_part, _messages.ToolCallPart):
+                    if result_schema is not None and (match := result_schema.find_tool([new_part])):
+                        call, _ = match
+                        return _MarkFinalResult(streamed_response, call.tool_name)
+                else:
+                    assert_never(new_part)
 
-            with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
-                parts.extend(task_results)
-            return model_response_msg, parts
-        else:
-            assert_never(model_response)
+        tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+        parts: list[_messages.ModelRequestPart] = []
+        model_response = streamed_response.get()
+        if not model_response.parts:
+            raise exceptions.UnexpectedModelBehavior('Received empty model response')
+        for p in model_response.parts:
+            if isinstance(p, _messages.ToolCallPart):
+                if tool := self._function_tools.get(p.tool_name):
+                    tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
+                else:
+                    parts.append(self._unknown_tool(p.tool_name, result_schema))
+
+        if received_text and not tasks and not parts:
+            # Can only get here if self._allow_text_result returns `False` for the provided result_schema
+            self._incr_result_retry(run_context)
+            model_response = _messages.RetryPromptPart(
+                content='Plain text responses are not permitted, please call one of the functions instead.',
+            )
+            return streamed_response.get(), [model_response]
+
+        with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
+            task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+            parts.extend(task_results)
+
+        if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
+            self._incr_result_retry(run_context)
+
+        return model_response, parts
 
     async def _validate_result(
         self,
         result_data: RunResultData,
-        run_context: RunContext[AgentDeps],
+        run_context: RunContext[AgentDepsT],
         tool_call: _messages.ToolCallPart | None,
     ) -> RunResultData:
         if self._result_validators:
-            agent_result_data = cast(ResultData, result_data)
+            agent_result_data = cast(ResultDataT, result_data)
             for validator in self._result_validators:
                 agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
             return cast(RunResultData, agent_result_data)
         else:
             return result_data
 
-    def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
+    def _incr_result_retry(self, run_context: RunContext[AgentDepsT]) -> None:
         run_context.retry += 1
         if run_context.retry > self._max_result_retries:
             raise exceptions.UnexpectedModelBehavior(
                 f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
             )
 
-    async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages.ModelRequestPart]:
+    async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_messages.ModelRequestPart]:
         """Build the initial messages for the conversation."""
         messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
         for sys_prompt_runner in self._system_prompt_functions:
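
The streaming rework above replaces the two concrete stream classes (`StreamTextResponse`/`StreamStructuredResponse`) with a single `StreamedResponse` that yields part events. A hedged sketch of the event shapes the new handler consumes (class names are the ones used in the diff; the function itself is illustrative):

```python
from pydantic_ai import messages as _messages


async def classify(streamed_response) -> str:
    """Mirror the handler's first loop: stop at the first decisive part."""
    async for event in streamed_response:
        if isinstance(event, _messages.PartStartEvent):
            part = event.part
            if isinstance(part, _messages.TextPart):
                return 'text response'
            if isinstance(part, _messages.ToolCallPart):
                return f'tool call: {part.tool_name}'
    return 'empty response'
```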
@@ -1285,10 +1315,8 @@ class Agent(Generic[AgentDeps, ResultData]):
     def _unknown_tool(
         self,
         tool_name: str,
-        run_context: RunContext[AgentDeps],
         result_schema: _result.ResultSchema[RunResultData] | None,
     ) -> _messages.RetryPromptPart:
-        self._incr_result_retry(run_context)
         names = list(self._function_tools.keys())
         if result_schema:
             names.extend(result_schema.tool_names())
@@ -1298,7 +1326,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             msg = 'No tools available.'
         return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
 
-    def _get_deps(self, deps: AgentDeps) -> AgentDeps:
+    def _get_deps(self: Agent[T, Any], deps: T) -> T:
         """Get deps for a run.
 
         If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
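
The new `self: Agent[T, Any]` annotation appears to bind the module-level invariant `T` added at the top of the file, so the return type tracks the deps type of the instance the method is called on regardless of how `ResultDataT` was solved. A hedged, self-contained illustration of the self-type pattern (names are hypothetical):

```python
from typing import Any, Generic, TypeVar

T = TypeVar('T')
DepsT = TypeVar('DepsT')
OutT = TypeVar('OutT')


class Box(Generic[DepsT, OutT]):
    def __init__(self, deps: DepsT) -> None:
        self.deps = deps

    def get_deps(self: 'Box[T, Any]') -> T:
        # Binding a fresh TypeVar via the self annotation pins the return
        # type to this instance's deps parameter, independent of OutT.
        return self.deps
```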
@@ -1386,15 +1414,15 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
 
 
 @dataclasses.dataclass
-class _MarkFinalResult(Generic[ResultData]):
+class _MarkFinalResult(Generic[ResultDataT]):
     """Marker class to indicate that the result is the final result.
 
-    This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
+    This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
 
     It also avoids problems in the case where the result type is itself `None`, but is set.
     """
 
-    data: ResultData
+    data: ResultDataT
     """The final result data."""
     tool_name: str | None
     """Name of the final result tool, None if the result is a string."""