pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_parts_manager.py +1 -1
- pydantic_ai/_pydantic.py +1 -0
- pydantic_ai/_result.py +29 -28
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/_utils.py +1 -56
- pydantic_ai/agent.py +137 -113
- pydantic_ai/messages.py +24 -56
- pydantic_ai/models/__init__.py +122 -51
- pydantic_ai/models/anthropic.py +109 -38
- pydantic_ai/models/cohere.py +290 -0
- pydantic_ai/models/function.py +12 -8
- pydantic_ai/models/gemini.py +29 -15
- pydantic_ai/models/groq.py +27 -23
- pydantic_ai/models/mistral.py +34 -29
- pydantic_ai/models/openai.py +45 -23
- pydantic_ai/models/test.py +47 -24
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +45 -26
- pydantic_ai/settings.py +58 -1
- pydantic_ai/tools.py +29 -26
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.21.dist-info}/METADATA +6 -4
- pydantic_ai_slim-0.0.21.dist-info/RECORD +29 -0
- pydantic_ai/models/ollama.py +0 -120
- pydantic_ai_slim-0.0.19.dist-info/RECORD +0 -29
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.21.dist-info}/WHEEL +0 -0
pydantic_ai/agent.py
CHANGED
|
@@ -22,10 +22,10 @@ from . import (
|
|
|
22
22
|
result,
|
|
23
23
|
usage as _usage,
|
|
24
24
|
)
|
|
25
|
-
from .result import
|
|
25
|
+
from .result import ResultDataT
|
|
26
26
|
from .settings import ModelSettings, merge_model_settings
|
|
27
27
|
from .tools import (
|
|
28
|
-
|
|
28
|
+
AgentDepsT,
|
|
29
29
|
DocstringFormat,
|
|
30
30
|
RunContext,
|
|
31
31
|
Tool,
|
|
@@ -51,6 +51,8 @@ else:
|
|
|
51
51
|
|
|
52
52
|
logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
|
|
53
53
|
|
|
54
|
+
T = TypeVar('T')
|
|
55
|
+
"""An invariant TypeVar."""
|
|
54
56
|
NoneType = type(None)
|
|
55
57
|
EndStrategy = Literal['early', 'exhaustive']
|
|
56
58
|
"""The strategy for handling multiple tool calls when a final result is found.
|
|
@@ -58,17 +60,17 @@ EndStrategy = Literal['early', 'exhaustive']
|
|
|
58
60
|
- `'early'`: Stop processing other tool calls once a final result is found
|
|
59
61
|
- `'exhaustive'`: Process all tool calls even after finding a final result
|
|
60
62
|
"""
|
|
61
|
-
|
|
63
|
+
RunResultDataT = TypeVar('RunResultDataT')
|
|
62
64
|
"""Type variable for the result data of a run where `result_type` was customized on the run call."""
|
|
63
65
|
|
|
64
66
|
|
|
65
67
|
@final
|
|
66
68
|
@dataclasses.dataclass(init=False)
|
|
67
|
-
class Agent(Generic[
|
|
69
|
+
class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
68
70
|
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
|
|
69
71
|
|
|
70
|
-
Agents are generic in the dependency type they take [`
|
|
71
|
-
and the result data type they return, [`
|
|
72
|
+
Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
|
|
73
|
+
and the result data type they return, [`ResultDataT`][pydantic_ai.result.ResultDataT].
|
|
72
74
|
|
|
73
75
|
By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
|
|
74
76
|
|
|
@@ -104,34 +106,34 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
104
106
|
"""
|
|
105
107
|
_result_tool_name: str = dataclasses.field(repr=False)
|
|
106
108
|
_result_tool_description: str | None = dataclasses.field(repr=False)
|
|
107
|
-
_result_schema: _result.ResultSchema[
|
|
108
|
-
_result_validators: list[_result.ResultValidator[
|
|
109
|
+
_result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False)
|
|
110
|
+
_result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = dataclasses.field(repr=False)
|
|
109
111
|
_system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
|
|
110
|
-
_function_tools: dict[str, Tool[
|
|
112
|
+
_function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
|
|
111
113
|
_default_retries: int = dataclasses.field(repr=False)
|
|
112
|
-
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[
|
|
113
|
-
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[
|
|
114
|
+
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
|
|
115
|
+
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
|
|
114
116
|
repr=False
|
|
115
117
|
)
|
|
116
|
-
_deps_type: type[
|
|
118
|
+
_deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
|
|
117
119
|
_max_result_retries: int = dataclasses.field(repr=False)
|
|
118
|
-
_override_deps: _utils.Option[
|
|
120
|
+
_override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
|
|
119
121
|
_override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
|
|
120
122
|
|
|
121
123
|
def __init__(
|
|
122
124
|
self,
|
|
123
125
|
model: models.Model | models.KnownModelName | None = None,
|
|
124
126
|
*,
|
|
125
|
-
result_type: type[
|
|
127
|
+
result_type: type[ResultDataT] = str,
|
|
126
128
|
system_prompt: str | Sequence[str] = (),
|
|
127
|
-
deps_type: type[
|
|
129
|
+
deps_type: type[AgentDepsT] = NoneType,
|
|
128
130
|
name: str | None = None,
|
|
129
131
|
model_settings: ModelSettings | None = None,
|
|
130
132
|
retries: int = 1,
|
|
131
133
|
result_tool_name: str = 'final_result',
|
|
132
134
|
result_tool_description: str | None = None,
|
|
133
135
|
result_retries: int | None = None,
|
|
134
|
-
tools: Sequence[Tool[
|
|
136
|
+
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
135
137
|
defer_model_check: bool = False,
|
|
136
138
|
end_strategy: EndStrategy = 'early',
|
|
137
139
|
):
|
|
@@ -200,27 +202,27 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
200
202
|
result_type: None = None,
|
|
201
203
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
202
204
|
model: models.Model | models.KnownModelName | None = None,
|
|
203
|
-
deps:
|
|
205
|
+
deps: AgentDepsT = None,
|
|
204
206
|
model_settings: ModelSettings | None = None,
|
|
205
207
|
usage_limits: _usage.UsageLimits | None = None,
|
|
206
208
|
usage: _usage.Usage | None = None,
|
|
207
209
|
infer_name: bool = True,
|
|
208
|
-
) -> result.RunResult[
|
|
210
|
+
) -> result.RunResult[ResultDataT]: ...
|
|
209
211
|
|
|
210
212
|
@overload
|
|
211
213
|
async def run(
|
|
212
214
|
self,
|
|
213
215
|
user_prompt: str,
|
|
214
216
|
*,
|
|
215
|
-
result_type: type[
|
|
217
|
+
result_type: type[RunResultDataT],
|
|
216
218
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
217
219
|
model: models.Model | models.KnownModelName | None = None,
|
|
218
|
-
deps:
|
|
220
|
+
deps: AgentDepsT = None,
|
|
219
221
|
model_settings: ModelSettings | None = None,
|
|
220
222
|
usage_limits: _usage.UsageLimits | None = None,
|
|
221
223
|
usage: _usage.Usage | None = None,
|
|
222
224
|
infer_name: bool = True,
|
|
223
|
-
) -> result.RunResult[
|
|
225
|
+
) -> result.RunResult[RunResultDataT]: ...
|
|
224
226
|
|
|
225
227
|
async def run(
|
|
226
228
|
self,
|
|
@@ -228,11 +230,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
228
230
|
*,
|
|
229
231
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
230
232
|
model: models.Model | models.KnownModelName | None = None,
|
|
231
|
-
deps:
|
|
233
|
+
deps: AgentDepsT = None,
|
|
232
234
|
model_settings: ModelSettings | None = None,
|
|
233
235
|
usage_limits: _usage.UsageLimits | None = None,
|
|
234
236
|
usage: _usage.Usage | None = None,
|
|
235
|
-
result_type: type[
|
|
237
|
+
result_type: type[RunResultDataT] | None = None,
|
|
236
238
|
infer_name: bool = True,
|
|
237
239
|
) -> result.RunResult[Any]:
|
|
238
240
|
"""Run the agent with a user prompt in async mode.
|
|
@@ -338,36 +340,36 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
338
340
|
*,
|
|
339
341
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
340
342
|
model: models.Model | models.KnownModelName | None = None,
|
|
341
|
-
deps:
|
|
343
|
+
deps: AgentDepsT = None,
|
|
342
344
|
model_settings: ModelSettings | None = None,
|
|
343
345
|
usage_limits: _usage.UsageLimits | None = None,
|
|
344
346
|
usage: _usage.Usage | None = None,
|
|
345
347
|
infer_name: bool = True,
|
|
346
|
-
) -> result.RunResult[
|
|
348
|
+
) -> result.RunResult[ResultDataT]: ...
|
|
347
349
|
|
|
348
350
|
@overload
|
|
349
351
|
def run_sync(
|
|
350
352
|
self,
|
|
351
353
|
user_prompt: str,
|
|
352
354
|
*,
|
|
353
|
-
result_type: type[
|
|
355
|
+
result_type: type[RunResultDataT] | None,
|
|
354
356
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
355
357
|
model: models.Model | models.KnownModelName | None = None,
|
|
356
|
-
deps:
|
|
358
|
+
deps: AgentDepsT = None,
|
|
357
359
|
model_settings: ModelSettings | None = None,
|
|
358
360
|
usage_limits: _usage.UsageLimits | None = None,
|
|
359
361
|
usage: _usage.Usage | None = None,
|
|
360
362
|
infer_name: bool = True,
|
|
361
|
-
) -> result.RunResult[
|
|
363
|
+
) -> result.RunResult[RunResultDataT]: ...
|
|
362
364
|
|
|
363
365
|
def run_sync(
|
|
364
366
|
self,
|
|
365
367
|
user_prompt: str,
|
|
366
368
|
*,
|
|
367
|
-
result_type: type[
|
|
369
|
+
result_type: type[RunResultDataT] | None = None,
|
|
368
370
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
369
371
|
model: models.Model | models.KnownModelName | None = None,
|
|
370
|
-
deps:
|
|
372
|
+
deps: AgentDepsT = None,
|
|
371
373
|
model_settings: ModelSettings | None = None,
|
|
372
374
|
usage_limits: _usage.UsageLimits | None = None,
|
|
373
375
|
usage: _usage.Usage | None = None,
|
|
@@ -428,42 +430,42 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
428
430
|
result_type: None = None,
|
|
429
431
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
430
432
|
model: models.Model | models.KnownModelName | None = None,
|
|
431
|
-
deps:
|
|
433
|
+
deps: AgentDepsT = None,
|
|
432
434
|
model_settings: ModelSettings | None = None,
|
|
433
435
|
usage_limits: _usage.UsageLimits | None = None,
|
|
434
436
|
usage: _usage.Usage | None = None,
|
|
435
437
|
infer_name: bool = True,
|
|
436
|
-
) -> AbstractAsyncContextManager[result.StreamedRunResult[
|
|
438
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, ResultDataT]]: ...
|
|
437
439
|
|
|
438
440
|
@overload
|
|
439
441
|
def run_stream(
|
|
440
442
|
self,
|
|
441
443
|
user_prompt: str,
|
|
442
444
|
*,
|
|
443
|
-
result_type: type[
|
|
445
|
+
result_type: type[RunResultDataT],
|
|
444
446
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
445
447
|
model: models.Model | models.KnownModelName | None = None,
|
|
446
|
-
deps:
|
|
448
|
+
deps: AgentDepsT = None,
|
|
447
449
|
model_settings: ModelSettings | None = None,
|
|
448
450
|
usage_limits: _usage.UsageLimits | None = None,
|
|
449
451
|
usage: _usage.Usage | None = None,
|
|
450
452
|
infer_name: bool = True,
|
|
451
|
-
) -> AbstractAsyncContextManager[result.StreamedRunResult[
|
|
453
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ...
|
|
452
454
|
|
|
453
455
|
@asynccontextmanager
|
|
454
456
|
async def run_stream(
|
|
455
457
|
self,
|
|
456
458
|
user_prompt: str,
|
|
457
459
|
*,
|
|
458
|
-
result_type: type[
|
|
460
|
+
result_type: type[RunResultDataT] | None = None,
|
|
459
461
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
460
462
|
model: models.Model | models.KnownModelName | None = None,
|
|
461
|
-
deps:
|
|
463
|
+
deps: AgentDepsT = None,
|
|
462
464
|
model_settings: ModelSettings | None = None,
|
|
463
465
|
usage_limits: _usage.UsageLimits | None = None,
|
|
464
466
|
usage: _usage.Usage | None = None,
|
|
465
467
|
infer_name: bool = True,
|
|
466
|
-
) -> AsyncIterator[result.StreamedRunResult[
|
|
468
|
+
) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
|
|
467
469
|
"""Run the agent with a user prompt in async mode, returning a streamed response.
|
|
468
470
|
|
|
469
471
|
Example:
|
|
@@ -560,10 +562,19 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
560
562
|
parts = await self._process_function_tools(
|
|
561
563
|
tool_calls, result_tool_name, run_context, result_schema
|
|
562
564
|
)
|
|
565
|
+
if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
|
|
566
|
+
self._incr_result_retry(run_context)
|
|
563
567
|
if parts:
|
|
564
568
|
messages.append(_messages.ModelRequest(parts))
|
|
565
569
|
run_span.set_attribute('all_messages', messages)
|
|
566
570
|
|
|
571
|
+
# The following is not guaranteed to be true, but we consider it a user error if
|
|
572
|
+
# there are result validators that might convert the result data from an overridden
|
|
573
|
+
# `result_type` to a type that is not valid as such.
|
|
574
|
+
result_validators = cast(
|
|
575
|
+
list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators
|
|
576
|
+
)
|
|
577
|
+
|
|
567
578
|
yield result.StreamedRunResult(
|
|
568
579
|
messages,
|
|
569
580
|
new_message_index,
|
|
@@ -571,7 +582,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
571
582
|
result_stream,
|
|
572
583
|
result_schema,
|
|
573
584
|
run_context,
|
|
574
|
-
|
|
585
|
+
result_validators,
|
|
575
586
|
result_tool_name,
|
|
576
587
|
on_complete,
|
|
577
588
|
)
|
|
@@ -597,7 +608,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
597
608
|
def override(
|
|
598
609
|
self,
|
|
599
610
|
*,
|
|
600
|
-
deps:
|
|
611
|
+
deps: AgentDepsT | _utils.Unset = _utils.UNSET,
|
|
601
612
|
model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
|
|
602
613
|
) -> Iterator[None]:
|
|
603
614
|
"""Context manager to temporarily override agent dependencies and model.
|
|
@@ -633,13 +644,13 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
633
644
|
|
|
634
645
|
@overload
|
|
635
646
|
def system_prompt(
|
|
636
|
-
self, func: Callable[[RunContext[
|
|
637
|
-
) -> Callable[[RunContext[
|
|
647
|
+
self, func: Callable[[RunContext[AgentDepsT]], str], /
|
|
648
|
+
) -> Callable[[RunContext[AgentDepsT]], str]: ...
|
|
638
649
|
|
|
639
650
|
@overload
|
|
640
651
|
def system_prompt(
|
|
641
|
-
self, func: Callable[[RunContext[
|
|
642
|
-
) -> Callable[[RunContext[
|
|
652
|
+
self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
|
|
653
|
+
) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...
|
|
643
654
|
|
|
644
655
|
@overload
|
|
645
656
|
def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
|
|
@@ -650,17 +661,17 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
650
661
|
@overload
|
|
651
662
|
def system_prompt(
|
|
652
663
|
self, /, *, dynamic: bool = False
|
|
653
|
-
) -> Callable[[_system_prompt.SystemPromptFunc[
|
|
664
|
+
) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ...
|
|
654
665
|
|
|
655
666
|
def system_prompt(
|
|
656
667
|
self,
|
|
657
|
-
func: _system_prompt.SystemPromptFunc[
|
|
668
|
+
func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
|
|
658
669
|
/,
|
|
659
670
|
*,
|
|
660
671
|
dynamic: bool = False,
|
|
661
672
|
) -> (
|
|
662
|
-
Callable[[_system_prompt.SystemPromptFunc[
|
|
663
|
-
| _system_prompt.SystemPromptFunc[
|
|
673
|
+
Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]
|
|
674
|
+
| _system_prompt.SystemPromptFunc[AgentDepsT]
|
|
664
675
|
):
|
|
665
676
|
"""Decorator to register a system prompt function.
|
|
666
677
|
|
|
@@ -696,9 +707,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
696
707
|
if func is None:
|
|
697
708
|
|
|
698
709
|
def decorator(
|
|
699
|
-
func_: _system_prompt.SystemPromptFunc[
|
|
700
|
-
) -> _system_prompt.SystemPromptFunc[
|
|
701
|
-
runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic)
|
|
710
|
+
func_: _system_prompt.SystemPromptFunc[AgentDepsT],
|
|
711
|
+
) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
|
|
712
|
+
runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic)
|
|
702
713
|
self._system_prompt_functions.append(runner)
|
|
703
714
|
if dynamic:
|
|
704
715
|
self._system_prompt_dynamic_functions[func_.__qualname__] = runner
|
|
@@ -712,25 +723,27 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
712
723
|
|
|
713
724
|
@overload
|
|
714
725
|
def result_validator(
|
|
715
|
-
self, func: Callable[[RunContext[
|
|
716
|
-
) -> Callable[[RunContext[
|
|
726
|
+
self, func: Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT], /
|
|
727
|
+
) -> Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT]: ...
|
|
717
728
|
|
|
718
729
|
@overload
|
|
719
730
|
def result_validator(
|
|
720
|
-
self, func: Callable[[RunContext[
|
|
721
|
-
) -> Callable[[RunContext[
|
|
731
|
+
self, func: Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]], /
|
|
732
|
+
) -> Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]]: ...
|
|
722
733
|
|
|
723
734
|
@overload
|
|
724
|
-
def result_validator(
|
|
735
|
+
def result_validator(
|
|
736
|
+
self, func: Callable[[ResultDataT], ResultDataT], /
|
|
737
|
+
) -> Callable[[ResultDataT], ResultDataT]: ...
|
|
725
738
|
|
|
726
739
|
@overload
|
|
727
740
|
def result_validator(
|
|
728
|
-
self, func: Callable[[
|
|
729
|
-
) -> Callable[[
|
|
741
|
+
self, func: Callable[[ResultDataT], Awaitable[ResultDataT]], /
|
|
742
|
+
) -> Callable[[ResultDataT], Awaitable[ResultDataT]]: ...
|
|
730
743
|
|
|
731
744
|
def result_validator(
|
|
732
|
-
self, func: _result.ResultValidatorFunc[
|
|
733
|
-
) -> _result.ResultValidatorFunc[
|
|
745
|
+
self, func: _result.ResultValidatorFunc[AgentDepsT, ResultDataT], /
|
|
746
|
+
) -> _result.ResultValidatorFunc[AgentDepsT, ResultDataT]:
|
|
734
747
|
"""Decorator to register a result validator function.
|
|
735
748
|
|
|
736
749
|
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
@@ -762,11 +775,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
762
775
|
#> success (no tool calls)
|
|
763
776
|
```
|
|
764
777
|
"""
|
|
765
|
-
self._result_validators.append(_result.ResultValidator[
|
|
778
|
+
self._result_validators.append(_result.ResultValidator[AgentDepsT, Any](func))
|
|
766
779
|
return func
|
|
767
780
|
|
|
768
781
|
@overload
|
|
769
|
-
def tool(self, func: ToolFuncContext[
|
|
782
|
+
def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ...
|
|
770
783
|
|
|
771
784
|
@overload
|
|
772
785
|
def tool(
|
|
@@ -774,18 +787,18 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
774
787
|
/,
|
|
775
788
|
*,
|
|
776
789
|
retries: int | None = None,
|
|
777
|
-
prepare: ToolPrepareFunc[
|
|
790
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
778
791
|
docstring_format: DocstringFormat = 'auto',
|
|
779
792
|
require_parameter_descriptions: bool = False,
|
|
780
|
-
) -> Callable[[ToolFuncContext[
|
|
793
|
+
) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
|
|
781
794
|
|
|
782
795
|
def tool(
|
|
783
796
|
self,
|
|
784
|
-
func: ToolFuncContext[
|
|
797
|
+
func: ToolFuncContext[AgentDepsT, ToolParams] | None = None,
|
|
785
798
|
/,
|
|
786
799
|
*,
|
|
787
800
|
retries: int | None = None,
|
|
788
|
-
prepare: ToolPrepareFunc[
|
|
801
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
789
802
|
docstring_format: DocstringFormat = 'auto',
|
|
790
803
|
require_parameter_descriptions: bool = False,
|
|
791
804
|
) -> Any:
|
|
@@ -832,8 +845,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
832
845
|
if func is None:
|
|
833
846
|
|
|
834
847
|
def tool_decorator(
|
|
835
|
-
func_: ToolFuncContext[
|
|
836
|
-
) -> ToolFuncContext[
|
|
848
|
+
func_: ToolFuncContext[AgentDepsT, ToolParams],
|
|
849
|
+
) -> ToolFuncContext[AgentDepsT, ToolParams]:
|
|
837
850
|
# noinspection PyTypeChecker
|
|
838
851
|
self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions)
|
|
839
852
|
return func_
|
|
@@ -853,7 +866,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
853
866
|
/,
|
|
854
867
|
*,
|
|
855
868
|
retries: int | None = None,
|
|
856
|
-
prepare: ToolPrepareFunc[
|
|
869
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
857
870
|
docstring_format: DocstringFormat = 'auto',
|
|
858
871
|
require_parameter_descriptions: bool = False,
|
|
859
872
|
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
|
|
@@ -864,7 +877,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
864
877
|
/,
|
|
865
878
|
*,
|
|
866
879
|
retries: int | None = None,
|
|
867
|
-
prepare: ToolPrepareFunc[
|
|
880
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
868
881
|
docstring_format: DocstringFormat = 'auto',
|
|
869
882
|
require_parameter_descriptions: bool = False,
|
|
870
883
|
) -> Any:
|
|
@@ -924,16 +937,16 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
924
937
|
|
|
925
938
|
def _register_function(
|
|
926
939
|
self,
|
|
927
|
-
func: ToolFuncEither[
|
|
940
|
+
func: ToolFuncEither[AgentDepsT, ToolParams],
|
|
928
941
|
takes_ctx: bool,
|
|
929
942
|
retries: int | None,
|
|
930
|
-
prepare: ToolPrepareFunc[
|
|
943
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None,
|
|
931
944
|
docstring_format: DocstringFormat,
|
|
932
945
|
require_parameter_descriptions: bool,
|
|
933
946
|
) -> None:
|
|
934
947
|
"""Private utility to register a function as a tool."""
|
|
935
948
|
retries_ = retries if retries is not None else self._default_retries
|
|
936
|
-
tool = Tool(
|
|
949
|
+
tool = Tool[AgentDepsT](
|
|
937
950
|
func,
|
|
938
951
|
takes_ctx=takes_ctx,
|
|
939
952
|
max_retries=retries_,
|
|
@@ -943,7 +956,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
943
956
|
)
|
|
944
957
|
self._register_tool(tool)
|
|
945
958
|
|
|
946
|
-
def _register_tool(self, tool: Tool[
|
|
959
|
+
def _register_tool(self, tool: Tool[AgentDepsT]) -> None:
|
|
947
960
|
"""Private utility to register a tool instance."""
|
|
948
961
|
if tool.max_retries is None:
|
|
949
962
|
# noinspection PyTypeChecker
|
|
@@ -986,12 +999,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
986
999
|
return model_
|
|
987
1000
|
|
|
988
1001
|
async def _prepare_model(
|
|
989
|
-
self, run_context: RunContext[
|
|
1002
|
+
self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
|
|
990
1003
|
) -> models.AgentModel:
|
|
991
1004
|
"""Build tools and create an agent model."""
|
|
992
1005
|
function_tools: list[ToolDefinition] = []
|
|
993
1006
|
|
|
994
|
-
async def add_tool(tool: Tool[
|
|
1007
|
+
async def add_tool(tool: Tool[AgentDepsT]) -> None:
|
|
995
1008
|
ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
|
|
996
1009
|
if tool_def := await tool.prepare_tool_def(ctx):
|
|
997
1010
|
function_tools.append(tool_def)
|
|
@@ -1005,7 +1018,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1005
1018
|
)
|
|
1006
1019
|
|
|
1007
1020
|
async def _reevaluate_dynamic_prompts(
|
|
1008
|
-
self, messages: list[_messages.ModelMessage], run_context: RunContext[
|
|
1021
|
+
self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDepsT]
|
|
1009
1022
|
) -> None:
|
|
1010
1023
|
"""Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
|
|
1011
1024
|
# Only proceed if there's at least one dynamic runner.
|
|
@@ -1022,8 +1035,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1022
1035
|
)
|
|
1023
1036
|
|
|
1024
1037
|
def _prepare_result_schema(
|
|
1025
|
-
self, result_type: type[
|
|
1026
|
-
) -> _result.ResultSchema[
|
|
1038
|
+
self, result_type: type[RunResultDataT] | None
|
|
1039
|
+
) -> _result.ResultSchema[RunResultDataT] | None:
|
|
1027
1040
|
if result_type is not None:
|
|
1028
1041
|
if self._result_validators:
|
|
1029
1042
|
raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
|
|
@@ -1034,10 +1047,13 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1034
1047
|
return self._result_schema # pyright: ignore[reportReturnType]
|
|
1035
1048
|
|
|
1036
1049
|
async def _prepare_messages(
|
|
1037
|
-
self,
|
|
1050
|
+
self,
|
|
1051
|
+
user_prompt: str,
|
|
1052
|
+
message_history: list[_messages.ModelMessage] | None,
|
|
1053
|
+
run_context: RunContext[AgentDepsT],
|
|
1038
1054
|
) -> list[_messages.ModelMessage]:
|
|
1039
1055
|
try:
|
|
1040
|
-
ctx_messages =
|
|
1056
|
+
ctx_messages = get_captured_run_messages()
|
|
1041
1057
|
except LookupError:
|
|
1042
1058
|
messages: list[_messages.ModelMessage] = []
|
|
1043
1059
|
else:
|
|
@@ -1063,9 +1079,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1063
1079
|
async def _handle_model_response(
|
|
1064
1080
|
self,
|
|
1065
1081
|
model_response: _messages.ModelResponse,
|
|
1066
|
-
run_context: RunContext[
|
|
1067
|
-
result_schema: _result.ResultSchema[
|
|
1068
|
-
) -> tuple[_MarkFinalResult[
|
|
1082
|
+
run_context: RunContext[AgentDepsT],
|
|
1083
|
+
result_schema: _result.ResultSchema[RunResultDataT] | None,
|
|
1084
|
+
) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
|
|
1069
1085
|
"""Process a non-streamed response from the model.
|
|
1070
1086
|
|
|
1071
1087
|
Returns:
|
|
@@ -1094,11 +1110,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1094
1110
|
raise exceptions.UnexpectedModelBehavior('Received empty model response')
|
|
1095
1111
|
|
|
1096
1112
|
async def _handle_text_response(
|
|
1097
|
-
self, text: str, run_context: RunContext[
|
|
1098
|
-
) -> tuple[_MarkFinalResult[
|
|
1113
|
+
self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
|
|
1114
|
+
) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
|
|
1099
1115
|
"""Handle a plain text response from the model for non-streaming responses."""
|
|
1100
1116
|
if self._allow_text_result(result_schema):
|
|
1101
|
-
result_data_input = cast(
|
|
1117
|
+
result_data_input = cast(RunResultDataT, text)
|
|
1102
1118
|
try:
|
|
1103
1119
|
result_data = await self._validate_result(result_data_input, run_context, None)
|
|
1104
1120
|
except _result.ToolRetryError as e:
|
|
@@ -1116,14 +1132,14 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1116
1132
|
async def _handle_structured_response(
|
|
1117
1133
|
self,
|
|
1118
1134
|
tool_calls: list[_messages.ToolCallPart],
|
|
1119
|
-
run_context: RunContext[
|
|
1120
|
-
result_schema: _result.ResultSchema[
|
|
1121
|
-
) -> tuple[_MarkFinalResult[
|
|
1135
|
+
run_context: RunContext[AgentDepsT],
|
|
1136
|
+
result_schema: _result.ResultSchema[RunResultDataT] | None,
|
|
1137
|
+
) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
|
|
1122
1138
|
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
|
|
1123
1139
|
assert tool_calls, 'Expected at least one tool call'
|
|
1124
1140
|
|
|
1125
1141
|
# first look for the result tool call
|
|
1126
|
-
final_result: _MarkFinalResult[
|
|
1142
|
+
final_result: _MarkFinalResult[RunResultDataT] | None = None
|
|
1127
1143
|
|
|
1128
1144
|
parts: list[_messages.ModelRequestPart] = []
|
|
1129
1145
|
if result_schema is not None:
|
|
@@ -1133,7 +1149,6 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1133
1149
|
result_data = result_tool.validate(call)
|
|
1134
1150
|
result_data = await self._validate_result(result_data, run_context, call)
|
|
1135
1151
|
except _result.ToolRetryError as e:
|
|
1136
|
-
self._incr_result_retry(run_context)
|
|
1137
1152
|
parts.append(e.tool_retry)
|
|
1138
1153
|
else:
|
|
1139
1154
|
final_result = _MarkFinalResult(result_data, call.tool_name)
|
|
@@ -1143,14 +1158,17 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1143
1158
|
tool_calls, final_result and final_result.tool_name, run_context, result_schema
|
|
1144
1159
|
)
|
|
1145
1160
|
|
|
1161
|
+
if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
|
|
1162
|
+
self._incr_result_retry(run_context)
|
|
1163
|
+
|
|
1146
1164
|
return final_result, parts
|
|
1147
1165
|
|
|
1148
1166
|
async def _process_function_tools(
|
|
1149
1167
|
self,
|
|
1150
1168
|
tool_calls: list[_messages.ToolCallPart],
|
|
1151
1169
|
result_tool_name: str | None,
|
|
1152
|
-
run_context: RunContext[
|
|
1153
|
-
result_schema: _result.ResultSchema[
|
|
1170
|
+
run_context: RunContext[AgentDepsT],
|
|
1171
|
+
result_schema: _result.ResultSchema[RunResultDataT] | None,
|
|
1154
1172
|
) -> list[_messages.ModelRequestPart]:
|
|
1155
1173
|
"""Process function (non-result) tool calls in parallel.
|
|
1156
1174
|
|
|
@@ -1196,7 +1214,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1196
1214
|
)
|
|
1197
1215
|
)
|
|
1198
1216
|
else:
|
|
1199
|
-
parts.append(self._unknown_tool(call.tool_name,
|
|
1217
|
+
parts.append(self._unknown_tool(call.tool_name, result_schema))
|
|
1200
1218
|
|
|
1201
1219
|
# Run all tool tasks in parallel
|
|
1202
1220
|
if tasks:
|
|
@@ -1208,8 +1226,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1208
1226
|
async def _handle_streamed_response(
|
|
1209
1227
|
self,
|
|
1210
1228
|
streamed_response: models.StreamedResponse,
|
|
1211
|
-
run_context: RunContext[
|
|
1212
|
-
result_schema: _result.ResultSchema[
|
|
1229
|
+
run_context: RunContext[AgentDepsT],
|
|
1230
|
+
result_schema: _result.ResultSchema[RunResultDataT] | None,
|
|
1213
1231
|
) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
|
|
1214
1232
|
"""Process a streamed response from the model.
|
|
1215
1233
|
|
|
@@ -1243,7 +1261,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1243
1261
|
if tool := self._function_tools.get(p.tool_name):
|
|
1244
1262
|
tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
|
|
1245
1263
|
else:
|
|
1246
|
-
parts.append(self._unknown_tool(p.tool_name,
|
|
1264
|
+
parts.append(self._unknown_tool(p.tool_name, result_schema))
|
|
1247
1265
|
|
|
1248
1266
|
if received_text and not tasks and not parts:
|
|
1249
1267
|
# Can only get here if self._allow_text_result returns `False` for the provided result_schema
|
|
@@ -1256,30 +1274,34 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1256
1274
|
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
|
|
1257
1275
|
task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
|
|
1258
1276
|
parts.extend(task_results)
|
|
1277
|
+
|
|
1278
|
+
if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
|
|
1279
|
+
self._incr_result_retry(run_context)
|
|
1280
|
+
|
|
1259
1281
|
return model_response, parts
|
|
1260
1282
|
|
|
1261
1283
|
async def _validate_result(
|
|
1262
1284
|
self,
|
|
1263
|
-
result_data:
|
|
1264
|
-
run_context: RunContext[
|
|
1285
|
+
result_data: RunResultDataT,
|
|
1286
|
+
run_context: RunContext[AgentDepsT],
|
|
1265
1287
|
tool_call: _messages.ToolCallPart | None,
|
|
1266
|
-
) ->
|
|
1288
|
+
) -> RunResultDataT:
|
|
1267
1289
|
if self._result_validators:
|
|
1268
|
-
agent_result_data = cast(
|
|
1290
|
+
agent_result_data = cast(ResultDataT, result_data)
|
|
1269
1291
|
for validator in self._result_validators:
|
|
1270
1292
|
agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
|
|
1271
|
-
return cast(
|
|
1293
|
+
return cast(RunResultDataT, agent_result_data)
|
|
1272
1294
|
else:
|
|
1273
1295
|
return result_data
|
|
1274
1296
|
|
|
1275
|
-
def _incr_result_retry(self, run_context: RunContext[
|
|
1297
|
+
def _incr_result_retry(self, run_context: RunContext[AgentDepsT]) -> None:
|
|
1276
1298
|
run_context.retry += 1
|
|
1277
1299
|
if run_context.retry > self._max_result_retries:
|
|
1278
1300
|
raise exceptions.UnexpectedModelBehavior(
|
|
1279
1301
|
f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
|
|
1280
1302
|
)
|
|
1281
1303
|
|
|
1282
|
-
async def _sys_parts(self, run_context: RunContext[
|
|
1304
|
+
async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_messages.ModelRequestPart]:
|
|
1283
1305
|
"""Build the initial messages for the conversation."""
|
|
1284
1306
|
messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
|
|
1285
1307
|
for sys_prompt_runner in self._system_prompt_functions:
|
|
@@ -1293,10 +1315,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1293
1315
|
def _unknown_tool(
|
|
1294
1316
|
self,
|
|
1295
1317
|
tool_name: str,
|
|
1296
|
-
|
|
1297
|
-
result_schema: _result.ResultSchema[RunResultData] | None,
|
|
1318
|
+
result_schema: _result.ResultSchema[RunResultDataT] | None,
|
|
1298
1319
|
) -> _messages.RetryPromptPart:
|
|
1299
|
-
self._incr_result_retry(run_context)
|
|
1300
1320
|
names = list(self._function_tools.keys())
|
|
1301
1321
|
if result_schema:
|
|
1302
1322
|
names.extend(result_schema.tool_names())
|
|
@@ -1306,7 +1326,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1306
1326
|
msg = 'No tools available.'
|
|
1307
1327
|
return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
|
|
1308
1328
|
|
|
1309
|
-
def _get_deps(self, deps:
|
|
1329
|
+
def _get_deps(self: Agent[T, Any], deps: T) -> T:
|
|
1310
1330
|
"""Get deps for a run.
|
|
1311
1331
|
|
|
1312
1332
|
If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
|
|
@@ -1338,7 +1358,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1338
1358
|
return
|
|
1339
1359
|
|
|
1340
1360
|
@staticmethod
|
|
1341
|
-
def _allow_text_result(result_schema: _result.ResultSchema[
|
|
1361
|
+
def _allow_text_result(result_schema: _result.ResultSchema[RunResultDataT] | None) -> bool:
|
|
1342
1362
|
return result_schema is None or result_schema.allow_text_result
|
|
1343
1363
|
|
|
1344
1364
|
@property
|
|
@@ -1393,16 +1413,20 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
|
|
|
1393
1413
|
_messages_ctx_var.reset(token)
|
|
1394
1414
|
|
|
1395
1415
|
|
|
1416
|
+
def get_captured_run_messages() -> _RunMessages:
|
|
1417
|
+
return _messages_ctx_var.get()
|
|
1418
|
+
|
|
1419
|
+
|
|
1396
1420
|
@dataclasses.dataclass
|
|
1397
|
-
class _MarkFinalResult(Generic[
|
|
1421
|
+
class _MarkFinalResult(Generic[ResultDataT]):
|
|
1398
1422
|
"""Marker class to indicate that the result is the final result.
|
|
1399
1423
|
|
|
1400
|
-
This allows us to use `isinstance`, which wouldn't be possible if we were returning `
|
|
1424
|
+
This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
|
|
1401
1425
|
|
|
1402
1426
|
It also avoids problems in the case where the result type is itself `None`, but is set.
|
|
1403
1427
|
"""
|
|
1404
1428
|
|
|
1405
|
-
data:
|
|
1429
|
+
data: ResultDataT
|
|
1406
1430
|
"""The final result data."""
|
|
1407
1431
|
tool_name: str | None
|
|
1408
1432
|
"""Name of the final result tool, None if the result is a string."""
|