pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim has been flagged as potentially problematic; consult the package registry page for details.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +17 -3
- pydantic_ai/_result.py +26 -21
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +187 -159
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +217 -15
- pydantic_ai/models/__init__.py +58 -71
- pydantic_ai/models/anthropic.py +112 -48
- pydantic_ai/models/cohere.py +278 -0
- pydantic_ai/models/function.py +57 -85
- pydantic_ai/models/gemini.py +83 -129
- pydantic_ai/models/groq.py +60 -130
- pydantic_ai/models/mistral.py +86 -142
- pydantic_ai/models/ollama.py +4 -0
- pydantic_ai/models/openai.py +75 -136
- pydantic_ai/models/test.py +55 -80
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +132 -114
- pydantic_ai/settings.py +18 -1
- pydantic_ai/tools.py +42 -23
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/METADATA +7 -3
- pydantic_ai_slim-0.0.20.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/WHEEL +0 -0
pydantic_ai/agent.py
CHANGED
|
@@ -22,10 +22,11 @@ from . import (
|
|
|
22
22
|
result,
|
|
23
23
|
usage as _usage,
|
|
24
24
|
)
|
|
25
|
-
from .result import
|
|
25
|
+
from .result import ResultDataT
|
|
26
26
|
from .settings import ModelSettings, merge_model_settings
|
|
27
27
|
from .tools import (
|
|
28
|
-
|
|
28
|
+
AgentDepsT,
|
|
29
|
+
DocstringFormat,
|
|
29
30
|
RunContext,
|
|
30
31
|
Tool,
|
|
31
32
|
ToolDefinition,
|
|
@@ -50,6 +51,8 @@ else:
|
|
|
50
51
|
|
|
51
52
|
logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
|
|
52
53
|
|
|
54
|
+
T = TypeVar('T')
|
|
55
|
+
"""An invariant TypeVar."""
|
|
53
56
|
NoneType = type(None)
|
|
54
57
|
EndStrategy = Literal['early', 'exhaustive']
|
|
55
58
|
"""The strategy for handling multiple tool calls when a final result is found.
|
|
@@ -63,11 +66,11 @@ RunResultData = TypeVar('RunResultData')
|
|
|
63
66
|
|
|
64
67
|
@final
|
|
65
68
|
@dataclasses.dataclass(init=False)
|
|
66
|
-
class Agent(Generic[
|
|
69
|
+
class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
67
70
|
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
|
|
68
71
|
|
|
69
|
-
Agents are generic in the dependency type they take [`
|
|
70
|
-
and the result data type they return, [`
|
|
72
|
+
Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
|
|
73
|
+
and the result data type they return, [`ResultDataT`][pydantic_ai.result.ResultDataT].
|
|
71
74
|
|
|
72
75
|
By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
|
|
73
76
|
|
|
@@ -103,34 +106,34 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
103
106
|
"""
|
|
104
107
|
_result_tool_name: str = dataclasses.field(repr=False)
|
|
105
108
|
_result_tool_description: str | None = dataclasses.field(repr=False)
|
|
106
|
-
_result_schema: _result.ResultSchema[
|
|
107
|
-
_result_validators: list[_result.ResultValidator[
|
|
109
|
+
_result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False)
|
|
110
|
+
_result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = dataclasses.field(repr=False)
|
|
108
111
|
_system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
|
|
109
|
-
_function_tools: dict[str, Tool[
|
|
112
|
+
_function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
|
|
110
113
|
_default_retries: int = dataclasses.field(repr=False)
|
|
111
|
-
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[
|
|
112
|
-
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[
|
|
114
|
+
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
|
|
115
|
+
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
|
|
113
116
|
repr=False
|
|
114
117
|
)
|
|
115
|
-
_deps_type: type[
|
|
118
|
+
_deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
|
|
116
119
|
_max_result_retries: int = dataclasses.field(repr=False)
|
|
117
|
-
_override_deps: _utils.Option[
|
|
120
|
+
_override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
|
|
118
121
|
_override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
|
|
119
122
|
|
|
120
123
|
def __init__(
|
|
121
124
|
self,
|
|
122
125
|
model: models.Model | models.KnownModelName | None = None,
|
|
123
126
|
*,
|
|
124
|
-
result_type: type[
|
|
127
|
+
result_type: type[ResultDataT] = str,
|
|
125
128
|
system_prompt: str | Sequence[str] = (),
|
|
126
|
-
deps_type: type[
|
|
129
|
+
deps_type: type[AgentDepsT] = NoneType,
|
|
127
130
|
name: str | None = None,
|
|
128
131
|
model_settings: ModelSettings | None = None,
|
|
129
132
|
retries: int = 1,
|
|
130
133
|
result_tool_name: str = 'final_result',
|
|
131
134
|
result_tool_description: str | None = None,
|
|
132
135
|
result_retries: int | None = None,
|
|
133
|
-
tools: Sequence[Tool[
|
|
136
|
+
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
134
137
|
defer_model_check: bool = False,
|
|
135
138
|
end_strategy: EndStrategy = 'early',
|
|
136
139
|
):
|
|
@@ -199,12 +202,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
199
202
|
result_type: None = None,
|
|
200
203
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
201
204
|
model: models.Model | models.KnownModelName | None = None,
|
|
202
|
-
deps:
|
|
205
|
+
deps: AgentDepsT = None,
|
|
203
206
|
model_settings: ModelSettings | None = None,
|
|
204
207
|
usage_limits: _usage.UsageLimits | None = None,
|
|
205
208
|
usage: _usage.Usage | None = None,
|
|
206
209
|
infer_name: bool = True,
|
|
207
|
-
) -> result.RunResult[
|
|
210
|
+
) -> result.RunResult[ResultDataT]: ...
|
|
208
211
|
|
|
209
212
|
@overload
|
|
210
213
|
async def run(
|
|
@@ -214,7 +217,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
214
217
|
result_type: type[RunResultData],
|
|
215
218
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
216
219
|
model: models.Model | models.KnownModelName | None = None,
|
|
217
|
-
deps:
|
|
220
|
+
deps: AgentDepsT = None,
|
|
218
221
|
model_settings: ModelSettings | None = None,
|
|
219
222
|
usage_limits: _usage.UsageLimits | None = None,
|
|
220
223
|
usage: _usage.Usage | None = None,
|
|
@@ -227,7 +230,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
227
230
|
*,
|
|
228
231
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
229
232
|
model: models.Model | models.KnownModelName | None = None,
|
|
230
|
-
deps:
|
|
233
|
+
deps: AgentDepsT = None,
|
|
231
234
|
model_settings: ModelSettings | None = None,
|
|
232
235
|
usage_limits: _usage.UsageLimits | None = None,
|
|
233
236
|
usage: _usage.Usage | None = None,
|
|
@@ -242,9 +245,10 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
242
245
|
|
|
243
246
|
agent = Agent('openai:gpt-4o')
|
|
244
247
|
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
+
async def main():
|
|
249
|
+
result = await agent.run('What is the capital of France?')
|
|
250
|
+
print(result.data)
|
|
251
|
+
#> Paris
|
|
248
252
|
```
|
|
249
253
|
|
|
250
254
|
Args:
|
|
@@ -336,12 +340,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
336
340
|
*,
|
|
337
341
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
338
342
|
model: models.Model | models.KnownModelName | None = None,
|
|
339
|
-
deps:
|
|
343
|
+
deps: AgentDepsT = None,
|
|
340
344
|
model_settings: ModelSettings | None = None,
|
|
341
345
|
usage_limits: _usage.UsageLimits | None = None,
|
|
342
346
|
usage: _usage.Usage | None = None,
|
|
343
347
|
infer_name: bool = True,
|
|
344
|
-
) -> result.RunResult[
|
|
348
|
+
) -> result.RunResult[ResultDataT]: ...
|
|
345
349
|
|
|
346
350
|
@overload
|
|
347
351
|
def run_sync(
|
|
@@ -351,7 +355,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
351
355
|
result_type: type[RunResultData] | None,
|
|
352
356
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
353
357
|
model: models.Model | models.KnownModelName | None = None,
|
|
354
|
-
deps:
|
|
358
|
+
deps: AgentDepsT = None,
|
|
355
359
|
model_settings: ModelSettings | None = None,
|
|
356
360
|
usage_limits: _usage.UsageLimits | None = None,
|
|
357
361
|
usage: _usage.Usage | None = None,
|
|
@@ -365,7 +369,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
365
369
|
result_type: type[RunResultData] | None = None,
|
|
366
370
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
367
371
|
model: models.Model | models.KnownModelName | None = None,
|
|
368
|
-
deps:
|
|
372
|
+
deps: AgentDepsT = None,
|
|
369
373
|
model_settings: ModelSettings | None = None,
|
|
370
374
|
usage_limits: _usage.UsageLimits | None = None,
|
|
371
375
|
usage: _usage.Usage | None = None,
|
|
@@ -382,10 +386,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
382
386
|
|
|
383
387
|
agent = Agent('openai:gpt-4o')
|
|
384
388
|
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
#> Paris
|
|
389
|
+
result_sync = agent.run_sync('What is the capital of Italy?')
|
|
390
|
+
print(result_sync.data)
|
|
391
|
+
#> Rome
|
|
389
392
|
```
|
|
390
393
|
|
|
391
394
|
Args:
|
|
@@ -427,12 +430,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
427
430
|
result_type: None = None,
|
|
428
431
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
429
432
|
model: models.Model | models.KnownModelName | None = None,
|
|
430
|
-
deps:
|
|
433
|
+
deps: AgentDepsT = None,
|
|
431
434
|
model_settings: ModelSettings | None = None,
|
|
432
435
|
usage_limits: _usage.UsageLimits | None = None,
|
|
433
436
|
usage: _usage.Usage | None = None,
|
|
434
437
|
infer_name: bool = True,
|
|
435
|
-
) -> AbstractAsyncContextManager[result.StreamedRunResult[
|
|
438
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, ResultDataT]]: ...
|
|
436
439
|
|
|
437
440
|
@overload
|
|
438
441
|
def run_stream(
|
|
@@ -442,12 +445,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
442
445
|
result_type: type[RunResultData],
|
|
443
446
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
444
447
|
model: models.Model | models.KnownModelName | None = None,
|
|
445
|
-
deps:
|
|
448
|
+
deps: AgentDepsT = None,
|
|
446
449
|
model_settings: ModelSettings | None = None,
|
|
447
450
|
usage_limits: _usage.UsageLimits | None = None,
|
|
448
451
|
usage: _usage.Usage | None = None,
|
|
449
452
|
infer_name: bool = True,
|
|
450
|
-
) -> AbstractAsyncContextManager[result.StreamedRunResult[
|
|
453
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultData]]: ...
|
|
451
454
|
|
|
452
455
|
@asynccontextmanager
|
|
453
456
|
async def run_stream(
|
|
@@ -457,12 +460,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
457
460
|
result_type: type[RunResultData] | None = None,
|
|
458
461
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
459
462
|
model: models.Model | models.KnownModelName | None = None,
|
|
460
|
-
deps:
|
|
463
|
+
deps: AgentDepsT = None,
|
|
461
464
|
model_settings: ModelSettings | None = None,
|
|
462
465
|
usage_limits: _usage.UsageLimits | None = None,
|
|
463
466
|
usage: _usage.Usage | None = None,
|
|
464
467
|
infer_name: bool = True,
|
|
465
|
-
) -> AsyncIterator[result.StreamedRunResult[
|
|
468
|
+
) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
|
|
466
469
|
"""Run the agent with a user prompt in async mode, returning a streamed response.
|
|
467
470
|
|
|
468
471
|
Example:
|
|
@@ -535,7 +538,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
535
538
|
model_req_span.__exit__(None, None, None)
|
|
536
539
|
|
|
537
540
|
with _logfire.span('handle model response') as handle_span:
|
|
538
|
-
maybe_final_result = await self.
|
|
541
|
+
maybe_final_result = await self._handle_streamed_response(
|
|
539
542
|
model_response, run_context, result_schema
|
|
540
543
|
)
|
|
541
544
|
|
|
@@ -559,10 +562,19 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
559
562
|
parts = await self._process_function_tools(
|
|
560
563
|
tool_calls, result_tool_name, run_context, result_schema
|
|
561
564
|
)
|
|
565
|
+
if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
|
|
566
|
+
self._incr_result_retry(run_context)
|
|
562
567
|
if parts:
|
|
563
568
|
messages.append(_messages.ModelRequest(parts))
|
|
564
569
|
run_span.set_attribute('all_messages', messages)
|
|
565
570
|
|
|
571
|
+
# The following is not guaranteed to be true, but we consider it a user error if
|
|
572
|
+
# there are result validators that might convert the result data from an overridden
|
|
573
|
+
# `result_type` to a type that is not valid as such.
|
|
574
|
+
result_validators = cast(
|
|
575
|
+
list[_result.ResultValidator[AgentDepsT, RunResultData]], self._result_validators
|
|
576
|
+
)
|
|
577
|
+
|
|
566
578
|
yield result.StreamedRunResult(
|
|
567
579
|
messages,
|
|
568
580
|
new_message_index,
|
|
@@ -570,7 +582,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
570
582
|
result_stream,
|
|
571
583
|
result_schema,
|
|
572
584
|
run_context,
|
|
573
|
-
|
|
585
|
+
result_validators,
|
|
574
586
|
result_tool_name,
|
|
575
587
|
on_complete,
|
|
576
588
|
)
|
|
@@ -596,7 +608,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
596
608
|
def override(
|
|
597
609
|
self,
|
|
598
610
|
*,
|
|
599
|
-
deps:
|
|
611
|
+
deps: AgentDepsT | _utils.Unset = _utils.UNSET,
|
|
600
612
|
model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
|
|
601
613
|
) -> Iterator[None]:
|
|
602
614
|
"""Context manager to temporarily override agent dependencies and model.
|
|
@@ -632,13 +644,13 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
632
644
|
|
|
633
645
|
@overload
|
|
634
646
|
def system_prompt(
|
|
635
|
-
self, func: Callable[[RunContext[
|
|
636
|
-
) -> Callable[[RunContext[
|
|
647
|
+
self, func: Callable[[RunContext[AgentDepsT]], str], /
|
|
648
|
+
) -> Callable[[RunContext[AgentDepsT]], str]: ...
|
|
637
649
|
|
|
638
650
|
@overload
|
|
639
651
|
def system_prompt(
|
|
640
|
-
self, func: Callable[[RunContext[
|
|
641
|
-
) -> Callable[[RunContext[
|
|
652
|
+
self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
|
|
653
|
+
) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...
|
|
642
654
|
|
|
643
655
|
@overload
|
|
644
656
|
def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
|
|
@@ -649,17 +661,17 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
649
661
|
@overload
|
|
650
662
|
def system_prompt(
|
|
651
663
|
self, /, *, dynamic: bool = False
|
|
652
|
-
) -> Callable[[_system_prompt.SystemPromptFunc[
|
|
664
|
+
) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ...
|
|
653
665
|
|
|
654
666
|
def system_prompt(
|
|
655
667
|
self,
|
|
656
|
-
func: _system_prompt.SystemPromptFunc[
|
|
668
|
+
func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
|
|
657
669
|
/,
|
|
658
670
|
*,
|
|
659
671
|
dynamic: bool = False,
|
|
660
672
|
) -> (
|
|
661
|
-
Callable[[_system_prompt.SystemPromptFunc[
|
|
662
|
-
| _system_prompt.SystemPromptFunc[
|
|
673
|
+
Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]
|
|
674
|
+
| _system_prompt.SystemPromptFunc[AgentDepsT]
|
|
663
675
|
):
|
|
664
676
|
"""Decorator to register a system prompt function.
|
|
665
677
|
|
|
@@ -695,9 +707,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
695
707
|
if func is None:
|
|
696
708
|
|
|
697
709
|
def decorator(
|
|
698
|
-
func_: _system_prompt.SystemPromptFunc[
|
|
699
|
-
) -> _system_prompt.SystemPromptFunc[
|
|
700
|
-
runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic)
|
|
710
|
+
func_: _system_prompt.SystemPromptFunc[AgentDepsT],
|
|
711
|
+
) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
|
|
712
|
+
runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic)
|
|
701
713
|
self._system_prompt_functions.append(runner)
|
|
702
714
|
if dynamic:
|
|
703
715
|
self._system_prompt_dynamic_functions[func_.__qualname__] = runner
|
|
@@ -711,25 +723,27 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
711
723
|
|
|
712
724
|
@overload
|
|
713
725
|
def result_validator(
|
|
714
|
-
self, func: Callable[[RunContext[
|
|
715
|
-
) -> Callable[[RunContext[
|
|
726
|
+
self, func: Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT], /
|
|
727
|
+
) -> Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT]: ...
|
|
716
728
|
|
|
717
729
|
@overload
|
|
718
730
|
def result_validator(
|
|
719
|
-
self, func: Callable[[RunContext[
|
|
720
|
-
) -> Callable[[RunContext[
|
|
731
|
+
self, func: Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]], /
|
|
732
|
+
) -> Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]]: ...
|
|
721
733
|
|
|
722
734
|
@overload
|
|
723
|
-
def result_validator(
|
|
735
|
+
def result_validator(
|
|
736
|
+
self, func: Callable[[ResultDataT], ResultDataT], /
|
|
737
|
+
) -> Callable[[ResultDataT], ResultDataT]: ...
|
|
724
738
|
|
|
725
739
|
@overload
|
|
726
740
|
def result_validator(
|
|
727
|
-
self, func: Callable[[
|
|
728
|
-
) -> Callable[[
|
|
741
|
+
self, func: Callable[[ResultDataT], Awaitable[ResultDataT]], /
|
|
742
|
+
) -> Callable[[ResultDataT], Awaitable[ResultDataT]]: ...
|
|
729
743
|
|
|
730
744
|
def result_validator(
|
|
731
|
-
self, func: _result.ResultValidatorFunc[
|
|
732
|
-
) -> _result.ResultValidatorFunc[
|
|
745
|
+
self, func: _result.ResultValidatorFunc[AgentDepsT, ResultDataT], /
|
|
746
|
+
) -> _result.ResultValidatorFunc[AgentDepsT, ResultDataT]:
|
|
733
747
|
"""Decorator to register a result validator function.
|
|
734
748
|
|
|
735
749
|
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
@@ -761,11 +775,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
761
775
|
#> success (no tool calls)
|
|
762
776
|
```
|
|
763
777
|
"""
|
|
764
|
-
self._result_validators.append(_result.ResultValidator[
|
|
778
|
+
self._result_validators.append(_result.ResultValidator[AgentDepsT, Any](func))
|
|
765
779
|
return func
|
|
766
780
|
|
|
767
781
|
@overload
|
|
768
|
-
def tool(self, func: ToolFuncContext[
|
|
782
|
+
def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ...
|
|
769
783
|
|
|
770
784
|
@overload
|
|
771
785
|
def tool(
|
|
@@ -773,16 +787,20 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
773
787
|
/,
|
|
774
788
|
*,
|
|
775
789
|
retries: int | None = None,
|
|
776
|
-
prepare: ToolPrepareFunc[
|
|
777
|
-
|
|
790
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
791
|
+
docstring_format: DocstringFormat = 'auto',
|
|
792
|
+
require_parameter_descriptions: bool = False,
|
|
793
|
+
) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
|
|
778
794
|
|
|
779
795
|
def tool(
|
|
780
796
|
self,
|
|
781
|
-
func: ToolFuncContext[
|
|
797
|
+
func: ToolFuncContext[AgentDepsT, ToolParams] | None = None,
|
|
782
798
|
/,
|
|
783
799
|
*,
|
|
784
800
|
retries: int | None = None,
|
|
785
|
-
prepare: ToolPrepareFunc[
|
|
801
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
802
|
+
docstring_format: DocstringFormat = 'auto',
|
|
803
|
+
require_parameter_descriptions: bool = False,
|
|
786
804
|
) -> Any:
|
|
787
805
|
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
788
806
|
|
|
@@ -820,20 +838,23 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
820
838
|
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
821
839
|
tool from a given step. This is useful if you want to customise a tool at call time,
|
|
822
840
|
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
|
|
841
|
+
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
|
|
842
|
+
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
|
|
843
|
+
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
|
|
823
844
|
"""
|
|
824
845
|
if func is None:
|
|
825
846
|
|
|
826
847
|
def tool_decorator(
|
|
827
|
-
func_: ToolFuncContext[
|
|
828
|
-
) -> ToolFuncContext[
|
|
848
|
+
func_: ToolFuncContext[AgentDepsT, ToolParams],
|
|
849
|
+
) -> ToolFuncContext[AgentDepsT, ToolParams]:
|
|
829
850
|
# noinspection PyTypeChecker
|
|
830
|
-
self._register_function(func_, True, retries, prepare)
|
|
851
|
+
self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions)
|
|
831
852
|
return func_
|
|
832
853
|
|
|
833
854
|
return tool_decorator
|
|
834
855
|
else:
|
|
835
856
|
# noinspection PyTypeChecker
|
|
836
|
-
self._register_function(func, True, retries, prepare)
|
|
857
|
+
self._register_function(func, True, retries, prepare, docstring_format, require_parameter_descriptions)
|
|
837
858
|
return func
|
|
838
859
|
|
|
839
860
|
@overload
|
|
@@ -845,7 +866,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
845
866
|
/,
|
|
846
867
|
*,
|
|
847
868
|
retries: int | None = None,
|
|
848
|
-
prepare: ToolPrepareFunc[
|
|
869
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
870
|
+
docstring_format: DocstringFormat = 'auto',
|
|
871
|
+
require_parameter_descriptions: bool = False,
|
|
849
872
|
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
|
|
850
873
|
|
|
851
874
|
def tool_plain(
|
|
@@ -854,7 +877,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
854
877
|
/,
|
|
855
878
|
*,
|
|
856
879
|
retries: int | None = None,
|
|
857
|
-
prepare: ToolPrepareFunc[
|
|
880
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
881
|
+
docstring_format: DocstringFormat = 'auto',
|
|
882
|
+
require_parameter_descriptions: bool = False,
|
|
858
883
|
) -> Any:
|
|
859
884
|
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
|
|
860
885
|
|
|
@@ -892,32 +917,46 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
892
917
|
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
893
918
|
tool from a given step. This is useful if you want to customise a tool at call time,
|
|
894
919
|
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
|
|
920
|
+
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
|
|
921
|
+
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
|
|
922
|
+
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
|
|
895
923
|
"""
|
|
896
924
|
if func is None:
|
|
897
925
|
|
|
898
926
|
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
|
|
899
927
|
# noinspection PyTypeChecker
|
|
900
|
-
self._register_function(
|
|
928
|
+
self._register_function(
|
|
929
|
+
func_, False, retries, prepare, docstring_format, require_parameter_descriptions
|
|
930
|
+
)
|
|
901
931
|
return func_
|
|
902
932
|
|
|
903
933
|
return tool_decorator
|
|
904
934
|
else:
|
|
905
|
-
self._register_function(func, False, retries, prepare)
|
|
935
|
+
self._register_function(func, False, retries, prepare, docstring_format, require_parameter_descriptions)
|
|
906
936
|
return func
|
|
907
937
|
|
|
908
938
|
def _register_function(
|
|
909
939
|
self,
|
|
910
|
-
func: ToolFuncEither[
|
|
940
|
+
func: ToolFuncEither[AgentDepsT, ToolParams],
|
|
911
941
|
takes_ctx: bool,
|
|
912
942
|
retries: int | None,
|
|
913
|
-
prepare: ToolPrepareFunc[
|
|
943
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None,
|
|
944
|
+
docstring_format: DocstringFormat,
|
|
945
|
+
require_parameter_descriptions: bool,
|
|
914
946
|
) -> None:
|
|
915
947
|
"""Private utility to register a function as a tool."""
|
|
916
948
|
retries_ = retries if retries is not None else self._default_retries
|
|
917
|
-
tool = Tool(
|
|
949
|
+
tool = Tool[AgentDepsT](
|
|
950
|
+
func,
|
|
951
|
+
takes_ctx=takes_ctx,
|
|
952
|
+
max_retries=retries_,
|
|
953
|
+
prepare=prepare,
|
|
954
|
+
docstring_format=docstring_format,
|
|
955
|
+
require_parameter_descriptions=require_parameter_descriptions,
|
|
956
|
+
)
|
|
918
957
|
self._register_tool(tool)
|
|
919
958
|
|
|
920
|
-
def _register_tool(self, tool: Tool[
|
|
959
|
+
def _register_tool(self, tool: Tool[AgentDepsT]) -> None:
|
|
921
960
|
"""Private utility to register a tool instance."""
|
|
922
961
|
if tool.max_retries is None:
|
|
923
962
|
# noinspection PyTypeChecker
|
|
@@ -960,12 +999,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
960
999
|
return model_
|
|
961
1000
|
|
|
962
1001
|
async def _prepare_model(
|
|
963
|
-
self, run_context: RunContext[
|
|
1002
|
+
self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultData] | None
|
|
964
1003
|
) -> models.AgentModel:
|
|
965
1004
|
"""Build tools and create an agent model."""
|
|
966
1005
|
function_tools: list[ToolDefinition] = []
|
|
967
1006
|
|
|
968
|
-
async def add_tool(tool: Tool[
|
|
1007
|
+
async def add_tool(tool: Tool[AgentDepsT]) -> None:
|
|
969
1008
|
ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
|
|
970
1009
|
if tool_def := await tool.prepare_tool_def(ctx):
|
|
971
1010
|
function_tools.append(tool_def)
|
|
@@ -979,7 +1018,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
979
1018
|
)
|
|
980
1019
|
|
|
981
1020
|
async def _reevaluate_dynamic_prompts(
|
|
982
|
-
self, messages: list[_messages.ModelMessage], run_context: RunContext[
|
|
1021
|
+
self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDepsT]
|
|
983
1022
|
) -> None:
|
|
984
1023
|
"""Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
|
|
985
1024
|
# Only proceed if there's at least one dynamic runner.
|
|
@@ -1008,7 +1047,10 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1008
1047
|
return self._result_schema # pyright: ignore[reportReturnType]
|
|
1009
1048
|
|
|
1010
1049
|
async def _prepare_messages(
|
|
1011
|
-
self,
|
|
1050
|
+
self,
|
|
1051
|
+
user_prompt: str,
|
|
1052
|
+
message_history: list[_messages.ModelMessage] | None,
|
|
1053
|
+
run_context: RunContext[AgentDepsT],
|
|
1012
1054
|
) -> list[_messages.ModelMessage]:
|
|
1013
1055
|
try:
|
|
1014
1056
|
ctx_messages = _messages_ctx_var.get()
|
|
@@ -1037,7 +1079,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1037
1079
|
async def _handle_model_response(
|
|
1038
1080
|
self,
|
|
1039
1081
|
model_response: _messages.ModelResponse,
|
|
1040
|
-
run_context: RunContext[
|
|
1082
|
+
run_context: RunContext[AgentDepsT],
|
|
1041
1083
|
result_schema: _result.ResultSchema[RunResultData] | None,
|
|
1042
1084
|
) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
|
|
1043
1085
|
"""Process a non-streamed response from the model.
|
|
@@ -1068,7 +1110,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1068
1110
|
raise exceptions.UnexpectedModelBehavior('Received empty model response')
|
|
1069
1111
|
|
|
1070
1112
|
async def _handle_text_response(
|
|
1071
|
-
self, text: str, run_context: RunContext[
|
|
1113
|
+
self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultData] | None
|
|
1072
1114
|
) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
|
|
1073
1115
|
"""Handle a plain text response from the model for non-streaming responses."""
|
|
1074
1116
|
if self._allow_text_result(result_schema):
|
|
@@ -1090,7 +1132,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1090
1132
|
async def _handle_structured_response(
|
|
1091
1133
|
self,
|
|
1092
1134
|
tool_calls: list[_messages.ToolCallPart],
|
|
1093
|
-
run_context: RunContext[
|
|
1135
|
+
run_context: RunContext[AgentDepsT],
|
|
1094
1136
|
result_schema: _result.ResultSchema[RunResultData] | None,
|
|
1095
1137
|
) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
|
|
1096
1138
|
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
|
|
@@ -1100,14 +1142,13 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1100
1142
|
final_result: _MarkFinalResult[RunResultData] | None = None
|
|
1101
1143
|
|
|
1102
1144
|
parts: list[_messages.ModelRequestPart] = []
|
|
1103
|
-
if result_schema
|
|
1145
|
+
if result_schema is not None:
|
|
1104
1146
|
if match := result_schema.find_tool(tool_calls):
|
|
1105
1147
|
call, result_tool = match
|
|
1106
1148
|
try:
|
|
1107
1149
|
result_data = result_tool.validate(call)
|
|
1108
1150
|
result_data = await self._validate_result(result_data, run_context, call)
|
|
1109
1151
|
except _result.ToolRetryError as e:
|
|
1110
|
-
self._incr_result_retry(run_context)
|
|
1111
1152
|
parts.append(e.tool_retry)
|
|
1112
1153
|
else:
|
|
1113
1154
|
final_result = _MarkFinalResult(result_data, call.tool_name)
|
|
@@ -1117,13 +1158,16 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1117
1158
|
tool_calls, final_result and final_result.tool_name, run_context, result_schema
|
|
1118
1159
|
)
|
|
1119
1160
|
|
|
1161
|
+
if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
|
|
1162
|
+
self._incr_result_retry(run_context)
|
|
1163
|
+
|
|
1120
1164
|
return final_result, parts
|
|
1121
1165
|
|
|
1122
1166
|
async def _process_function_tools(
|
|
1123
1167
|
self,
|
|
1124
1168
|
tool_calls: list[_messages.ToolCallPart],
|
|
1125
1169
|
result_tool_name: str | None,
|
|
1126
|
-
run_context: RunContext[
|
|
1170
|
+
run_context: RunContext[AgentDepsT],
|
|
1127
1171
|
result_schema: _result.ResultSchema[RunResultData] | None,
|
|
1128
1172
|
) -> list[_messages.ModelRequestPart]:
|
|
1129
1173
|
"""Process function (non-result) tool calls in parallel.
|
|
@@ -1170,7 +1214,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1170
1214
|
)
|
|
1171
1215
|
)
|
|
1172
1216
|
else:
|
|
1173
|
-
parts.append(self._unknown_tool(call.tool_name,
|
|
1217
|
+
parts.append(self._unknown_tool(call.tool_name, result_schema))
|
|
1174
1218
|
|
|
1175
1219
|
# Run all tool tasks in parallel
|
|
1176
1220
|
if tasks:
|
|
@@ -1179,99 +1223,85 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1179
1223
|
parts.extend(task_results)
|
|
1180
1224
|
return parts
|
|
1181
1225
|
|
|
1182
|
-
async def
|
|
1226
|
+
async def _handle_streamed_response(
|
|
1183
1227
|
self,
|
|
1184
|
-
|
|
1185
|
-
run_context: RunContext[
|
|
1228
|
+
streamed_response: models.StreamedResponse,
|
|
1229
|
+
run_context: RunContext[AgentDepsT],
|
|
1186
1230
|
result_schema: _result.ResultSchema[RunResultData] | None,
|
|
1187
|
-
) ->
|
|
1188
|
-
_MarkFinalResult[models.EitherStreamedResponse]
|
|
1189
|
-
| tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
|
|
1190
|
-
):
|
|
1231
|
+
) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
|
|
1191
1232
|
"""Process a streamed response from the model.
|
|
1192
1233
|
|
|
1193
1234
|
Returns:
|
|
1194
1235
|
Either a final result or a tuple of the model response and the tool responses for the next request.
|
|
1195
1236
|
If a final result is returned, the conversation should end.
|
|
1196
1237
|
"""
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
elif isinstance(model_response, models.StreamStructuredResponse):
|
|
1213
|
-
if result_schema is not None:
|
|
1214
|
-
# if there's a result schema, iterate over the stream until we find at least one tool
|
|
1215
|
-
# NOTE: this means we ignore any other tools called here
|
|
1216
|
-
structured_msg = model_response.get()
|
|
1217
|
-
while not structured_msg.parts:
|
|
1218
|
-
try:
|
|
1219
|
-
await model_response.__anext__()
|
|
1220
|
-
except StopAsyncIteration:
|
|
1221
|
-
break
|
|
1222
|
-
structured_msg = model_response.get()
|
|
1223
|
-
|
|
1224
|
-
if match := result_schema.find_tool(structured_msg.parts):
|
|
1225
|
-
call, _ = match
|
|
1226
|
-
return _MarkFinalResult(model_response, call.tool_name)
|
|
1227
|
-
|
|
1228
|
-
# the model is calling a tool function, consume the response to get the next message
|
|
1229
|
-
async for _ in model_response:
|
|
1230
|
-
pass
|
|
1231
|
-
model_response_msg = model_response.get()
|
|
1232
|
-
if not model_response_msg.parts:
|
|
1233
|
-
raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
|
|
1234
|
-
|
|
1235
|
-
# we now run all tool functions in parallel
|
|
1236
|
-
tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
|
|
1237
|
-
parts: list[_messages.ModelRequestPart] = []
|
|
1238
|
-
for item in model_response_msg.parts:
|
|
1239
|
-
if isinstance(item, _messages.ToolCallPart):
|
|
1240
|
-
call = item
|
|
1241
|
-
if tool := self._function_tools.get(call.tool_name):
|
|
1242
|
-
tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
|
|
1243
|
-
else:
|
|
1244
|
-
parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))
|
|
1238
|
+
received_text = False
|
|
1239
|
+
|
|
1240
|
+
async for maybe_part_event in streamed_response:
|
|
1241
|
+
if isinstance(maybe_part_event, _messages.PartStartEvent):
|
|
1242
|
+
new_part = maybe_part_event.part
|
|
1243
|
+
if isinstance(new_part, _messages.TextPart):
|
|
1244
|
+
received_text = True
|
|
1245
|
+
if self._allow_text_result(result_schema):
|
|
1246
|
+
return _MarkFinalResult(streamed_response, None)
|
|
1247
|
+
elif isinstance(new_part, _messages.ToolCallPart):
|
|
1248
|
+
if result_schema is not None and (match := result_schema.find_tool([new_part])):
|
|
1249
|
+
call, _ = match
|
|
1250
|
+
return _MarkFinalResult(streamed_response, call.tool_name)
|
|
1251
|
+
else:
|
|
1252
|
+
assert_never(new_part)
|
|
1245
1253
|
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1254
|
+
tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
|
|
1255
|
+
parts: list[_messages.ModelRequestPart] = []
|
|
1256
|
+
model_response = streamed_response.get()
|
|
1257
|
+
if not model_response.parts:
|
|
1258
|
+
raise exceptions.UnexpectedModelBehavior('Received empty model response')
|
|
1259
|
+
for p in model_response.parts:
|
|
1260
|
+
if isinstance(p, _messages.ToolCallPart):
|
|
1261
|
+
if tool := self._function_tools.get(p.tool_name):
|
|
1262
|
+
tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
|
|
1263
|
+
else:
|
|
1264
|
+
parts.append(self._unknown_tool(p.tool_name, result_schema))
|
|
1265
|
+
|
|
1266
|
+
if received_text and not tasks and not parts:
|
|
1267
|
+
# Can only get here if self._allow_text_result returns `False` for the provided result_schema
|
|
1268
|
+
self._incr_result_retry(run_context)
|
|
1269
|
+
model_response = _messages.RetryPromptPart(
|
|
1270
|
+
content='Plain text responses are not permitted, please call one of the functions instead.',
|
|
1271
|
+
)
|
|
1272
|
+
return streamed_response.get(), [model_response]
|
|
1273
|
+
|
|
1274
|
+
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
|
|
1275
|
+
task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
|
|
1276
|
+
parts.extend(task_results)
|
|
1277
|
+
|
|
1278
|
+
if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
|
|
1279
|
+
self._incr_result_retry(run_context)
|
|
1280
|
+
|
|
1281
|
+
return model_response, parts
|
|
1252
1282
|
|
|
1253
1283
|
async def _validate_result(
|
|
1254
1284
|
self,
|
|
1255
1285
|
result_data: RunResultData,
|
|
1256
|
-
run_context: RunContext[
|
|
1286
|
+
run_context: RunContext[AgentDepsT],
|
|
1257
1287
|
tool_call: _messages.ToolCallPart | None,
|
|
1258
1288
|
) -> RunResultData:
|
|
1259
1289
|
if self._result_validators:
|
|
1260
|
-
agent_result_data = cast(
|
|
1290
|
+
agent_result_data = cast(ResultDataT, result_data)
|
|
1261
1291
|
for validator in self._result_validators:
|
|
1262
1292
|
agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
|
|
1263
1293
|
return cast(RunResultData, agent_result_data)
|
|
1264
1294
|
else:
|
|
1265
1295
|
return result_data
|
|
1266
1296
|
|
|
1267
|
-
def _incr_result_retry(self, run_context: RunContext[
|
|
1297
|
+
def _incr_result_retry(self, run_context: RunContext[AgentDepsT]) -> None:
|
|
1268
1298
|
run_context.retry += 1
|
|
1269
1299
|
if run_context.retry > self._max_result_retries:
|
|
1270
1300
|
raise exceptions.UnexpectedModelBehavior(
|
|
1271
1301
|
f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
|
|
1272
1302
|
)
|
|
1273
1303
|
|
|
1274
|
-
async def _sys_parts(self, run_context: RunContext[
|
|
1304
|
+
async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_messages.ModelRequestPart]:
|
|
1275
1305
|
"""Build the initial messages for the conversation."""
|
|
1276
1306
|
messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
|
|
1277
1307
|
for sys_prompt_runner in self._system_prompt_functions:
|
|
@@ -1285,10 +1315,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1285
1315
|
def _unknown_tool(
|
|
1286
1316
|
self,
|
|
1287
1317
|
tool_name: str,
|
|
1288
|
-
run_context: RunContext[AgentDeps],
|
|
1289
1318
|
result_schema: _result.ResultSchema[RunResultData] | None,
|
|
1290
1319
|
) -> _messages.RetryPromptPart:
|
|
1291
|
-
self._incr_result_retry(run_context)
|
|
1292
1320
|
names = list(self._function_tools.keys())
|
|
1293
1321
|
if result_schema:
|
|
1294
1322
|
names.extend(result_schema.tool_names())
|
|
@@ -1298,7 +1326,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1298
1326
|
msg = 'No tools available.'
|
|
1299
1327
|
return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
|
|
1300
1328
|
|
|
1301
|
-
def _get_deps(self, deps:
|
|
1329
|
+
def _get_deps(self: Agent[T, Any], deps: T) -> T:
|
|
1302
1330
|
"""Get deps for a run.
|
|
1303
1331
|
|
|
1304
1332
|
If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
|
|
@@ -1386,15 +1414,15 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
|
|
|
1386
1414
|
|
|
1387
1415
|
|
|
1388
1416
|
@dataclasses.dataclass
|
|
1389
|
-
class _MarkFinalResult(Generic[
|
|
1417
|
+
class _MarkFinalResult(Generic[ResultDataT]):
|
|
1390
1418
|
"""Marker class to indicate that the result is the final result.
|
|
1391
1419
|
|
|
1392
|
-
This allows us to use `isinstance`, which wouldn't be possible if we were returning `
|
|
1420
|
+
This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
|
|
1393
1421
|
|
|
1394
1422
|
It also avoids problems in the case where the result type is itself `None`, but is set.
|
|
1395
1423
|
"""
|
|
1396
1424
|
|
|
1397
|
-
data:
|
|
1425
|
+
data: ResultDataT
|
|
1398
1426
|
"""The final result data."""
|
|
1399
1427
|
tool_name: str | None
|
|
1400
1428
|
"""Name of the final result tool, None if the result is a string."""
|