pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the package's advisory page on the registry for more details.

pydantic_ai/agent.py CHANGED
@@ -22,10 +22,10 @@ from . import (
22
22
  result,
23
23
  usage as _usage,
24
24
  )
25
- from .result import ResultData
25
+ from .result import ResultDataT
26
26
  from .settings import ModelSettings, merge_model_settings
27
27
  from .tools import (
28
- AgentDeps,
28
+ AgentDepsT,
29
29
  DocstringFormat,
30
30
  RunContext,
31
31
  Tool,
@@ -51,6 +51,8 @@ else:
51
51
 
52
52
  logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
53
53
 
54
+ T = TypeVar('T')
55
+ """An invariant TypeVar."""
54
56
  NoneType = type(None)
55
57
  EndStrategy = Literal['early', 'exhaustive']
56
58
  """The strategy for handling multiple tool calls when a final result is found.
@@ -58,17 +60,17 @@ EndStrategy = Literal['early', 'exhaustive']
58
60
  - `'early'`: Stop processing other tool calls once a final result is found
59
61
  - `'exhaustive'`: Process all tool calls even after finding a final result
60
62
  """
61
- RunResultData = TypeVar('RunResultData')
63
+ RunResultDataT = TypeVar('RunResultDataT')
62
64
  """Type variable for the result data of a run where `result_type` was customized on the run call."""
63
65
 
64
66
 
65
67
  @final
66
68
  @dataclasses.dataclass(init=False)
67
- class Agent(Generic[AgentDeps, ResultData]):
69
+ class Agent(Generic[AgentDepsT, ResultDataT]):
68
70
  """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
69
71
 
70
- Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.tools.AgentDeps]
71
- and the result data type they return, [`ResultData`][pydantic_ai.result.ResultData].
72
+ Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
73
+ and the result data type they return, [`ResultDataT`][pydantic_ai.result.ResultDataT].
72
74
 
73
75
  By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
74
76
 
@@ -104,34 +106,34 @@ class Agent(Generic[AgentDeps, ResultData]):
104
106
  """
105
107
  _result_tool_name: str = dataclasses.field(repr=False)
106
108
  _result_tool_description: str | None = dataclasses.field(repr=False)
107
- _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
108
- _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
109
+ _result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False)
110
+ _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = dataclasses.field(repr=False)
109
111
  _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
110
- _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
112
+ _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
111
113
  _default_retries: int = dataclasses.field(repr=False)
112
- _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
113
- _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(
114
+ _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
115
+ _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
114
116
  repr=False
115
117
  )
116
- _deps_type: type[AgentDeps] = dataclasses.field(repr=False)
118
+ _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
117
119
  _max_result_retries: int = dataclasses.field(repr=False)
118
- _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
120
+ _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
119
121
  _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
120
122
 
121
123
  def __init__(
122
124
  self,
123
125
  model: models.Model | models.KnownModelName | None = None,
124
126
  *,
125
- result_type: type[ResultData] = str,
127
+ result_type: type[ResultDataT] = str,
126
128
  system_prompt: str | Sequence[str] = (),
127
- deps_type: type[AgentDeps] = NoneType,
129
+ deps_type: type[AgentDepsT] = NoneType,
128
130
  name: str | None = None,
129
131
  model_settings: ModelSettings | None = None,
130
132
  retries: int = 1,
131
133
  result_tool_name: str = 'final_result',
132
134
  result_tool_description: str | None = None,
133
135
  result_retries: int | None = None,
134
- tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (),
136
+ tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
135
137
  defer_model_check: bool = False,
136
138
  end_strategy: EndStrategy = 'early',
137
139
  ):
@@ -200,27 +202,27 @@ class Agent(Generic[AgentDeps, ResultData]):
200
202
  result_type: None = None,
201
203
  message_history: list[_messages.ModelMessage] | None = None,
202
204
  model: models.Model | models.KnownModelName | None = None,
203
- deps: AgentDeps = None,
205
+ deps: AgentDepsT = None,
204
206
  model_settings: ModelSettings | None = None,
205
207
  usage_limits: _usage.UsageLimits | None = None,
206
208
  usage: _usage.Usage | None = None,
207
209
  infer_name: bool = True,
208
- ) -> result.RunResult[ResultData]: ...
210
+ ) -> result.RunResult[ResultDataT]: ...
209
211
 
210
212
  @overload
211
213
  async def run(
212
214
  self,
213
215
  user_prompt: str,
214
216
  *,
215
- result_type: type[RunResultData],
217
+ result_type: type[RunResultDataT],
216
218
  message_history: list[_messages.ModelMessage] | None = None,
217
219
  model: models.Model | models.KnownModelName | None = None,
218
- deps: AgentDeps = None,
220
+ deps: AgentDepsT = None,
219
221
  model_settings: ModelSettings | None = None,
220
222
  usage_limits: _usage.UsageLimits | None = None,
221
223
  usage: _usage.Usage | None = None,
222
224
  infer_name: bool = True,
223
- ) -> result.RunResult[RunResultData]: ...
225
+ ) -> result.RunResult[RunResultDataT]: ...
224
226
 
225
227
  async def run(
226
228
  self,
@@ -228,11 +230,11 @@ class Agent(Generic[AgentDeps, ResultData]):
228
230
  *,
229
231
  message_history: list[_messages.ModelMessage] | None = None,
230
232
  model: models.Model | models.KnownModelName | None = None,
231
- deps: AgentDeps = None,
233
+ deps: AgentDepsT = None,
232
234
  model_settings: ModelSettings | None = None,
233
235
  usage_limits: _usage.UsageLimits | None = None,
234
236
  usage: _usage.Usage | None = None,
235
- result_type: type[RunResultData] | None = None,
237
+ result_type: type[RunResultDataT] | None = None,
236
238
  infer_name: bool = True,
237
239
  ) -> result.RunResult[Any]:
238
240
  """Run the agent with a user prompt in async mode.
@@ -338,36 +340,36 @@ class Agent(Generic[AgentDeps, ResultData]):
338
340
  *,
339
341
  message_history: list[_messages.ModelMessage] | None = None,
340
342
  model: models.Model | models.KnownModelName | None = None,
341
- deps: AgentDeps = None,
343
+ deps: AgentDepsT = None,
342
344
  model_settings: ModelSettings | None = None,
343
345
  usage_limits: _usage.UsageLimits | None = None,
344
346
  usage: _usage.Usage | None = None,
345
347
  infer_name: bool = True,
346
- ) -> result.RunResult[ResultData]: ...
348
+ ) -> result.RunResult[ResultDataT]: ...
347
349
 
348
350
  @overload
349
351
  def run_sync(
350
352
  self,
351
353
  user_prompt: str,
352
354
  *,
353
- result_type: type[RunResultData] | None,
355
+ result_type: type[RunResultDataT] | None,
354
356
  message_history: list[_messages.ModelMessage] | None = None,
355
357
  model: models.Model | models.KnownModelName | None = None,
356
- deps: AgentDeps = None,
358
+ deps: AgentDepsT = None,
357
359
  model_settings: ModelSettings | None = None,
358
360
  usage_limits: _usage.UsageLimits | None = None,
359
361
  usage: _usage.Usage | None = None,
360
362
  infer_name: bool = True,
361
- ) -> result.RunResult[RunResultData]: ...
363
+ ) -> result.RunResult[RunResultDataT]: ...
362
364
 
363
365
  def run_sync(
364
366
  self,
365
367
  user_prompt: str,
366
368
  *,
367
- result_type: type[RunResultData] | None = None,
369
+ result_type: type[RunResultDataT] | None = None,
368
370
  message_history: list[_messages.ModelMessage] | None = None,
369
371
  model: models.Model | models.KnownModelName | None = None,
370
- deps: AgentDeps = None,
372
+ deps: AgentDepsT = None,
371
373
  model_settings: ModelSettings | None = None,
372
374
  usage_limits: _usage.UsageLimits | None = None,
373
375
  usage: _usage.Usage | None = None,
@@ -428,42 +430,42 @@ class Agent(Generic[AgentDeps, ResultData]):
428
430
  result_type: None = None,
429
431
  message_history: list[_messages.ModelMessage] | None = None,
430
432
  model: models.Model | models.KnownModelName | None = None,
431
- deps: AgentDeps = None,
433
+ deps: AgentDepsT = None,
432
434
  model_settings: ModelSettings | None = None,
433
435
  usage_limits: _usage.UsageLimits | None = None,
434
436
  usage: _usage.Usage | None = None,
435
437
  infer_name: bool = True,
436
- ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, ResultData]]: ...
438
+ ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, ResultDataT]]: ...
437
439
 
438
440
  @overload
439
441
  def run_stream(
440
442
  self,
441
443
  user_prompt: str,
442
444
  *,
443
- result_type: type[RunResultData],
445
+ result_type: type[RunResultDataT],
444
446
  message_history: list[_messages.ModelMessage] | None = None,
445
447
  model: models.Model | models.KnownModelName | None = None,
446
- deps: AgentDeps = None,
448
+ deps: AgentDepsT = None,
447
449
  model_settings: ModelSettings | None = None,
448
450
  usage_limits: _usage.UsageLimits | None = None,
449
451
  usage: _usage.Usage | None = None,
450
452
  infer_name: bool = True,
451
- ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, RunResultData]]: ...
453
+ ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ...
452
454
 
453
455
  @asynccontextmanager
454
456
  async def run_stream(
455
457
  self,
456
458
  user_prompt: str,
457
459
  *,
458
- result_type: type[RunResultData] | None = None,
460
+ result_type: type[RunResultDataT] | None = None,
459
461
  message_history: list[_messages.ModelMessage] | None = None,
460
462
  model: models.Model | models.KnownModelName | None = None,
461
- deps: AgentDeps = None,
463
+ deps: AgentDepsT = None,
462
464
  model_settings: ModelSettings | None = None,
463
465
  usage_limits: _usage.UsageLimits | None = None,
464
466
  usage: _usage.Usage | None = None,
465
467
  infer_name: bool = True,
466
- ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, Any]]:
468
+ ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
467
469
  """Run the agent with a user prompt in async mode, returning a streamed response.
468
470
 
469
471
  Example:
@@ -560,10 +562,19 @@ class Agent(Generic[AgentDeps, ResultData]):
560
562
  parts = await self._process_function_tools(
561
563
  tool_calls, result_tool_name, run_context, result_schema
562
564
  )
565
+ if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
566
+ self._incr_result_retry(run_context)
563
567
  if parts:
564
568
  messages.append(_messages.ModelRequest(parts))
565
569
  run_span.set_attribute('all_messages', messages)
566
570
 
571
+ # The following is not guaranteed to be true, but we consider it a user error if
572
+ # there are result validators that might convert the result data from an overridden
573
+ # `result_type` to a type that is not valid as such.
574
+ result_validators = cast(
575
+ list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators
576
+ )
577
+
567
578
  yield result.StreamedRunResult(
568
579
  messages,
569
580
  new_message_index,
@@ -571,7 +582,7 @@ class Agent(Generic[AgentDeps, ResultData]):
571
582
  result_stream,
572
583
  result_schema,
573
584
  run_context,
574
- self._result_validators,
585
+ result_validators,
575
586
  result_tool_name,
576
587
  on_complete,
577
588
  )
@@ -597,7 +608,7 @@ class Agent(Generic[AgentDeps, ResultData]):
597
608
  def override(
598
609
  self,
599
610
  *,
600
- deps: AgentDeps | _utils.Unset = _utils.UNSET,
611
+ deps: AgentDepsT | _utils.Unset = _utils.UNSET,
601
612
  model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
602
613
  ) -> Iterator[None]:
603
614
  """Context manager to temporarily override agent dependencies and model.
@@ -633,13 +644,13 @@ class Agent(Generic[AgentDeps, ResultData]):
633
644
 
634
645
  @overload
635
646
  def system_prompt(
636
- self, func: Callable[[RunContext[AgentDeps]], str], /
637
- ) -> Callable[[RunContext[AgentDeps]], str]: ...
647
+ self, func: Callable[[RunContext[AgentDepsT]], str], /
648
+ ) -> Callable[[RunContext[AgentDepsT]], str]: ...
638
649
 
639
650
  @overload
640
651
  def system_prompt(
641
- self, func: Callable[[RunContext[AgentDeps]], Awaitable[str]], /
642
- ) -> Callable[[RunContext[AgentDeps]], Awaitable[str]]: ...
652
+ self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
653
+ ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...
643
654
 
644
655
  @overload
645
656
  def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
@@ -650,17 +661,17 @@ class Agent(Generic[AgentDeps, ResultData]):
650
661
  @overload
651
662
  def system_prompt(
652
663
  self, /, *, dynamic: bool = False
653
- ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]: ...
664
+ ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ...
654
665
 
655
666
  def system_prompt(
656
667
  self,
657
- func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None,
668
+ func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
658
669
  /,
659
670
  *,
660
671
  dynamic: bool = False,
661
672
  ) -> (
662
- Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]
663
- | _system_prompt.SystemPromptFunc[AgentDeps]
673
+ Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]
674
+ | _system_prompt.SystemPromptFunc[AgentDepsT]
664
675
  ):
665
676
  """Decorator to register a system prompt function.
666
677
 
@@ -696,9 +707,9 @@ class Agent(Generic[AgentDeps, ResultData]):
696
707
  if func is None:
697
708
 
698
709
  def decorator(
699
- func_: _system_prompt.SystemPromptFunc[AgentDeps],
700
- ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
701
- runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic)
710
+ func_: _system_prompt.SystemPromptFunc[AgentDepsT],
711
+ ) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
712
+ runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic)
702
713
  self._system_prompt_functions.append(runner)
703
714
  if dynamic:
704
715
  self._system_prompt_dynamic_functions[func_.__qualname__] = runner
@@ -712,25 +723,27 @@ class Agent(Generic[AgentDeps, ResultData]):
712
723
 
713
724
  @overload
714
725
  def result_validator(
715
- self, func: Callable[[RunContext[AgentDeps], ResultData], ResultData], /
716
- ) -> Callable[[RunContext[AgentDeps], ResultData], ResultData]: ...
726
+ self, func: Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT], /
727
+ ) -> Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT]: ...
717
728
 
718
729
  @overload
719
730
  def result_validator(
720
- self, func: Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]], /
721
- ) -> Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...
731
+ self, func: Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]], /
732
+ ) -> Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]]: ...
722
733
 
723
734
  @overload
724
- def result_validator(self, func: Callable[[ResultData], ResultData], /) -> Callable[[ResultData], ResultData]: ...
735
+ def result_validator(
736
+ self, func: Callable[[ResultDataT], ResultDataT], /
737
+ ) -> Callable[[ResultDataT], ResultDataT]: ...
725
738
 
726
739
  @overload
727
740
  def result_validator(
728
- self, func: Callable[[ResultData], Awaitable[ResultData]], /
729
- ) -> Callable[[ResultData], Awaitable[ResultData]]: ...
741
+ self, func: Callable[[ResultDataT], Awaitable[ResultDataT]], /
742
+ ) -> Callable[[ResultDataT], Awaitable[ResultDataT]]: ...
730
743
 
731
744
  def result_validator(
732
- self, func: _result.ResultValidatorFunc[AgentDeps, ResultData], /
733
- ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
745
+ self, func: _result.ResultValidatorFunc[AgentDepsT, ResultDataT], /
746
+ ) -> _result.ResultValidatorFunc[AgentDepsT, ResultDataT]:
734
747
  """Decorator to register a result validator function.
735
748
 
736
749
  Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
@@ -762,11 +775,11 @@ class Agent(Generic[AgentDeps, ResultData]):
762
775
  #> success (no tool calls)
763
776
  ```
764
777
  """
765
- self._result_validators.append(_result.ResultValidator[AgentDeps, Any](func))
778
+ self._result_validators.append(_result.ResultValidator[AgentDepsT, Any](func))
766
779
  return func
767
780
 
768
781
  @overload
769
- def tool(self, func: ToolFuncContext[AgentDeps, ToolParams], /) -> ToolFuncContext[AgentDeps, ToolParams]: ...
782
+ def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ...
770
783
 
771
784
  @overload
772
785
  def tool(
@@ -774,18 +787,18 @@ class Agent(Generic[AgentDeps, ResultData]):
774
787
  /,
775
788
  *,
776
789
  retries: int | None = None,
777
- prepare: ToolPrepareFunc[AgentDeps] | None = None,
790
+ prepare: ToolPrepareFunc[AgentDepsT] | None = None,
778
791
  docstring_format: DocstringFormat = 'auto',
779
792
  require_parameter_descriptions: bool = False,
780
- ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
793
+ ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
781
794
 
782
795
  def tool(
783
796
  self,
784
- func: ToolFuncContext[AgentDeps, ToolParams] | None = None,
797
+ func: ToolFuncContext[AgentDepsT, ToolParams] | None = None,
785
798
  /,
786
799
  *,
787
800
  retries: int | None = None,
788
- prepare: ToolPrepareFunc[AgentDeps] | None = None,
801
+ prepare: ToolPrepareFunc[AgentDepsT] | None = None,
789
802
  docstring_format: DocstringFormat = 'auto',
790
803
  require_parameter_descriptions: bool = False,
791
804
  ) -> Any:
@@ -832,8 +845,8 @@ class Agent(Generic[AgentDeps, ResultData]):
832
845
  if func is None:
833
846
 
834
847
  def tool_decorator(
835
- func_: ToolFuncContext[AgentDeps, ToolParams],
836
- ) -> ToolFuncContext[AgentDeps, ToolParams]:
848
+ func_: ToolFuncContext[AgentDepsT, ToolParams],
849
+ ) -> ToolFuncContext[AgentDepsT, ToolParams]:
837
850
  # noinspection PyTypeChecker
838
851
  self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions)
839
852
  return func_
@@ -853,7 +866,7 @@ class Agent(Generic[AgentDeps, ResultData]):
853
866
  /,
854
867
  *,
855
868
  retries: int | None = None,
856
- prepare: ToolPrepareFunc[AgentDeps] | None = None,
869
+ prepare: ToolPrepareFunc[AgentDepsT] | None = None,
857
870
  docstring_format: DocstringFormat = 'auto',
858
871
  require_parameter_descriptions: bool = False,
859
872
  ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
@@ -864,7 +877,7 @@ class Agent(Generic[AgentDeps, ResultData]):
864
877
  /,
865
878
  *,
866
879
  retries: int | None = None,
867
- prepare: ToolPrepareFunc[AgentDeps] | None = None,
880
+ prepare: ToolPrepareFunc[AgentDepsT] | None = None,
868
881
  docstring_format: DocstringFormat = 'auto',
869
882
  require_parameter_descriptions: bool = False,
870
883
  ) -> Any:
@@ -924,16 +937,16 @@ class Agent(Generic[AgentDeps, ResultData]):
924
937
 
925
938
  def _register_function(
926
939
  self,
927
- func: ToolFuncEither[AgentDeps, ToolParams],
940
+ func: ToolFuncEither[AgentDepsT, ToolParams],
928
941
  takes_ctx: bool,
929
942
  retries: int | None,
930
- prepare: ToolPrepareFunc[AgentDeps] | None,
943
+ prepare: ToolPrepareFunc[AgentDepsT] | None,
931
944
  docstring_format: DocstringFormat,
932
945
  require_parameter_descriptions: bool,
933
946
  ) -> None:
934
947
  """Private utility to register a function as a tool."""
935
948
  retries_ = retries if retries is not None else self._default_retries
936
- tool = Tool(
949
+ tool = Tool[AgentDepsT](
937
950
  func,
938
951
  takes_ctx=takes_ctx,
939
952
  max_retries=retries_,
@@ -943,7 +956,7 @@ class Agent(Generic[AgentDeps, ResultData]):
943
956
  )
944
957
  self._register_tool(tool)
945
958
 
946
- def _register_tool(self, tool: Tool[AgentDeps]) -> None:
959
+ def _register_tool(self, tool: Tool[AgentDepsT]) -> None:
947
960
  """Private utility to register a tool instance."""
948
961
  if tool.max_retries is None:
949
962
  # noinspection PyTypeChecker
@@ -986,12 +999,12 @@ class Agent(Generic[AgentDeps, ResultData]):
986
999
  return model_
987
1000
 
988
1001
  async def _prepare_model(
989
- self, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
1002
+ self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
990
1003
  ) -> models.AgentModel:
991
1004
  """Build tools and create an agent model."""
992
1005
  function_tools: list[ToolDefinition] = []
993
1006
 
994
- async def add_tool(tool: Tool[AgentDeps]) -> None:
1007
+ async def add_tool(tool: Tool[AgentDepsT]) -> None:
995
1008
  ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
996
1009
  if tool_def := await tool.prepare_tool_def(ctx):
997
1010
  function_tools.append(tool_def)
@@ -1005,7 +1018,7 @@ class Agent(Generic[AgentDeps, ResultData]):
1005
1018
  )
1006
1019
 
1007
1020
  async def _reevaluate_dynamic_prompts(
1008
- self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDeps]
1021
+ self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDepsT]
1009
1022
  ) -> None:
1010
1023
  """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
1011
1024
  # Only proceed if there's at least one dynamic runner.
@@ -1022,8 +1035,8 @@ class Agent(Generic[AgentDeps, ResultData]):
1022
1035
  )
1023
1036
 
1024
1037
  def _prepare_result_schema(
1025
- self, result_type: type[RunResultData] | None
1026
- ) -> _result.ResultSchema[RunResultData] | None:
1038
+ self, result_type: type[RunResultDataT] | None
1039
+ ) -> _result.ResultSchema[RunResultDataT] | None:
1027
1040
  if result_type is not None:
1028
1041
  if self._result_validators:
1029
1042
  raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
@@ -1034,10 +1047,13 @@ class Agent(Generic[AgentDeps, ResultData]):
1034
1047
  return self._result_schema # pyright: ignore[reportReturnType]
1035
1048
 
1036
1049
  async def _prepare_messages(
1037
- self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
1050
+ self,
1051
+ user_prompt: str,
1052
+ message_history: list[_messages.ModelMessage] | None,
1053
+ run_context: RunContext[AgentDepsT],
1038
1054
  ) -> list[_messages.ModelMessage]:
1039
1055
  try:
1040
- ctx_messages = _messages_ctx_var.get()
1056
+ ctx_messages = get_captured_run_messages()
1041
1057
  except LookupError:
1042
1058
  messages: list[_messages.ModelMessage] = []
1043
1059
  else:
@@ -1063,9 +1079,9 @@ class Agent(Generic[AgentDeps, ResultData]):
1063
1079
  async def _handle_model_response(
1064
1080
  self,
1065
1081
  model_response: _messages.ModelResponse,
1066
- run_context: RunContext[AgentDeps],
1067
- result_schema: _result.ResultSchema[RunResultData] | None,
1068
- ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1082
+ run_context: RunContext[AgentDepsT],
1083
+ result_schema: _result.ResultSchema[RunResultDataT] | None,
1084
+ ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
1069
1085
  """Process a non-streamed response from the model.
1070
1086
 
1071
1087
  Returns:
@@ -1094,11 +1110,11 @@ class Agent(Generic[AgentDeps, ResultData]):
1094
1110
  raise exceptions.UnexpectedModelBehavior('Received empty model response')
1095
1111
 
1096
1112
  async def _handle_text_response(
1097
- self, text: str, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
1098
- ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1113
+ self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
1114
+ ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
1099
1115
  """Handle a plain text response from the model for non-streaming responses."""
1100
1116
  if self._allow_text_result(result_schema):
1101
- result_data_input = cast(RunResultData, text)
1117
+ result_data_input = cast(RunResultDataT, text)
1102
1118
  try:
1103
1119
  result_data = await self._validate_result(result_data_input, run_context, None)
1104
1120
  except _result.ToolRetryError as e:
@@ -1116,14 +1132,14 @@ class Agent(Generic[AgentDeps, ResultData]):
1116
1132
  async def _handle_structured_response(
1117
1133
  self,
1118
1134
  tool_calls: list[_messages.ToolCallPart],
1119
- run_context: RunContext[AgentDeps],
1120
- result_schema: _result.ResultSchema[RunResultData] | None,
1121
- ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1135
+ run_context: RunContext[AgentDepsT],
1136
+ result_schema: _result.ResultSchema[RunResultDataT] | None,
1137
+ ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
1122
1138
  """Handle a structured response containing tool calls from the model for non-streaming responses."""
1123
1139
  assert tool_calls, 'Expected at least one tool call'
1124
1140
 
1125
1141
  # first look for the result tool call
1126
- final_result: _MarkFinalResult[RunResultData] | None = None
1142
+ final_result: _MarkFinalResult[RunResultDataT] | None = None
1127
1143
 
1128
1144
  parts: list[_messages.ModelRequestPart] = []
1129
1145
  if result_schema is not None:
@@ -1133,7 +1149,6 @@ class Agent(Generic[AgentDeps, ResultData]):
1133
1149
  result_data = result_tool.validate(call)
1134
1150
  result_data = await self._validate_result(result_data, run_context, call)
1135
1151
  except _result.ToolRetryError as e:
1136
- self._incr_result_retry(run_context)
1137
1152
  parts.append(e.tool_retry)
1138
1153
  else:
1139
1154
  final_result = _MarkFinalResult(result_data, call.tool_name)
@@ -1143,14 +1158,17 @@ class Agent(Generic[AgentDeps, ResultData]):
1143
1158
  tool_calls, final_result and final_result.tool_name, run_context, result_schema
1144
1159
  )
1145
1160
 
1161
+ if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
1162
+ self._incr_result_retry(run_context)
1163
+
1146
1164
  return final_result, parts
1147
1165
 
1148
1166
  async def _process_function_tools(
1149
1167
  self,
1150
1168
  tool_calls: list[_messages.ToolCallPart],
1151
1169
  result_tool_name: str | None,
1152
- run_context: RunContext[AgentDeps],
1153
- result_schema: _result.ResultSchema[RunResultData] | None,
1170
+ run_context: RunContext[AgentDepsT],
1171
+ result_schema: _result.ResultSchema[RunResultDataT] | None,
1154
1172
  ) -> list[_messages.ModelRequestPart]:
1155
1173
  """Process function (non-result) tool calls in parallel.
1156
1174
 
@@ -1196,7 +1214,7 @@ class Agent(Generic[AgentDeps, ResultData]):
1196
1214
  )
1197
1215
  )
1198
1216
  else:
1199
- parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))
1217
+ parts.append(self._unknown_tool(call.tool_name, result_schema))
1200
1218
 
1201
1219
  # Run all tool tasks in parallel
1202
1220
  if tasks:
@@ -1208,8 +1226,8 @@ class Agent(Generic[AgentDeps, ResultData]):
1208
1226
  async def _handle_streamed_response(
1209
1227
  self,
1210
1228
  streamed_response: models.StreamedResponse,
1211
- run_context: RunContext[AgentDeps],
1212
- result_schema: _result.ResultSchema[RunResultData] | None,
1229
+ run_context: RunContext[AgentDepsT],
1230
+ result_schema: _result.ResultSchema[RunResultDataT] | None,
1213
1231
  ) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
1214
1232
  """Process a streamed response from the model.
1215
1233
 
@@ -1243,7 +1261,7 @@ class Agent(Generic[AgentDeps, ResultData]):
1243
1261
  if tool := self._function_tools.get(p.tool_name):
1244
1262
  tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
1245
1263
  else:
1246
- parts.append(self._unknown_tool(p.tool_name, run_context, result_schema))
1264
+ parts.append(self._unknown_tool(p.tool_name, result_schema))
1247
1265
 
1248
1266
  if received_text and not tasks and not parts:
1249
1267
  # Can only get here if self._allow_text_result returns `False` for the provided result_schema
@@ -1256,30 +1274,34 @@ class Agent(Generic[AgentDeps, ResultData]):
1256
1274
  with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
1257
1275
  task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
1258
1276
  parts.extend(task_results)
1277
+
1278
+ if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
1279
+ self._incr_result_retry(run_context)
1280
+
1259
1281
  return model_response, parts
1260
1282
 
1261
1283
  async def _validate_result(
1262
1284
  self,
1263
- result_data: RunResultData,
1264
- run_context: RunContext[AgentDeps],
1285
+ result_data: RunResultDataT,
1286
+ run_context: RunContext[AgentDepsT],
1265
1287
  tool_call: _messages.ToolCallPart | None,
1266
- ) -> RunResultData:
1288
+ ) -> RunResultDataT:
1267
1289
  if self._result_validators:
1268
- agent_result_data = cast(ResultData, result_data)
1290
+ agent_result_data = cast(ResultDataT, result_data)
1269
1291
  for validator in self._result_validators:
1270
1292
  agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
1271
- return cast(RunResultData, agent_result_data)
1293
+ return cast(RunResultDataT, agent_result_data)
1272
1294
  else:
1273
1295
  return result_data
1274
1296
 
1275
- def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
1297
+ def _incr_result_retry(self, run_context: RunContext[AgentDepsT]) -> None:
1276
1298
  run_context.retry += 1
1277
1299
  if run_context.retry > self._max_result_retries:
1278
1300
  raise exceptions.UnexpectedModelBehavior(
1279
1301
  f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
1280
1302
  )
1281
1303
 
1282
- async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages.ModelRequestPart]:
1304
+ async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_messages.ModelRequestPart]:
1283
1305
  """Build the initial messages for the conversation."""
1284
1306
  messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
1285
1307
  for sys_prompt_runner in self._system_prompt_functions:
@@ -1293,10 +1315,8 @@ class Agent(Generic[AgentDeps, ResultData]):
1293
1315
  def _unknown_tool(
1294
1316
  self,
1295
1317
  tool_name: str,
1296
- run_context: RunContext[AgentDeps],
1297
- result_schema: _result.ResultSchema[RunResultData] | None,
1318
+ result_schema: _result.ResultSchema[RunResultDataT] | None,
1298
1319
  ) -> _messages.RetryPromptPart:
1299
- self._incr_result_retry(run_context)
1300
1320
  names = list(self._function_tools.keys())
1301
1321
  if result_schema:
1302
1322
  names.extend(result_schema.tool_names())
@@ -1306,7 +1326,7 @@ class Agent(Generic[AgentDeps, ResultData]):
1306
1326
  msg = 'No tools available.'
1307
1327
  return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
1308
1328
 
1309
- def _get_deps(self, deps: AgentDeps) -> AgentDeps:
1329
+ def _get_deps(self: Agent[T, Any], deps: T) -> T:
1310
1330
  """Get deps for a run.
1311
1331
 
1312
1332
  If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
@@ -1338,7 +1358,7 @@ class Agent(Generic[AgentDeps, ResultData]):
1338
1358
  return
1339
1359
 
1340
1360
  @staticmethod
1341
- def _allow_text_result(result_schema: _result.ResultSchema[RunResultData] | None) -> bool:
1361
+ def _allow_text_result(result_schema: _result.ResultSchema[RunResultDataT] | None) -> bool:
1342
1362
  return result_schema is None or result_schema.allow_text_result
1343
1363
 
1344
1364
  @property
@@ -1393,16 +1413,20 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
1393
1413
  _messages_ctx_var.reset(token)
1394
1414
 
1395
1415
 
1416
+ def get_captured_run_messages() -> _RunMessages:
1417
+ return _messages_ctx_var.get()
1418
+
1419
+
1396
1420
  @dataclasses.dataclass
1397
- class _MarkFinalResult(Generic[ResultData]):
1421
+ class _MarkFinalResult(Generic[ResultDataT]):
1398
1422
  """Marker class to indicate that the result is the final result.
1399
1423
 
1400
- This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
1424
+ This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
1401
1425
 
1402
1426
  It also avoids problems in the case where the result type is itself `None`, but is set.
1403
1427
  """
1404
1428
 
1405
- data: ResultData
1429
+ data: ResultDataT
1406
1430
  """The final result data."""
1407
1431
  tool_name: str | None
1408
1432
  """Name of the final result tool, None if the result is a string."""