pydantic-ai-slim 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +2 -2
- pydantic_ai/_pydantic.py +27 -11
- pydantic_ai/_result.py +1 -1
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/agent.py +45 -32
- pydantic_ai/messages.py +11 -16
- pydantic_ai/models/__init__.py +21 -11
- pydantic_ai/models/gemini.py +4 -0
- pydantic_ai/models/test.py +4 -8
- pydantic_ai/models/vertexai.py +2 -0
- pydantic_ai/result.py +10 -8
- pydantic_ai/tools.py +240 -0
- {pydantic_ai_slim-0.0.7.dist-info → pydantic_ai_slim-0.0.9.dist-info}/METADATA +1 -1
- pydantic_ai_slim-0.0.9.dist-info/RECORD +22 -0
- pydantic_ai/_tool.py +0 -112
- pydantic_ai/dependencies.py +0 -83
- pydantic_ai_slim-0.0.7.dist-info/RECORD +0 -23
- {pydantic_ai_slim-0.0.7.dist-info → pydantic_ai_slim-0.0.9.dist-info}/WHEEL +0 -0
pydantic_ai/__init__.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from importlib.metadata import version
|
|
2
2
|
|
|
3
3
|
from .agent import Agent
|
|
4
|
-
from .dependencies import RunContext
|
|
5
4
|
from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
|
|
5
|
+
from .tools import RunContext, Tool
|
|
6
6
|
|
|
7
|
-
__all__ = 'Agent', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
|
|
7
|
+
__all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
|
|
8
8
|
__version__ = version('pydantic_ai_slim')
|
pydantic_ai/_pydantic.py
CHANGED
|
@@ -6,7 +6,7 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
|
|
|
6
6
|
from __future__ import annotations as _annotations
|
|
7
7
|
|
|
8
8
|
from inspect import Parameter, signature
|
|
9
|
-
from typing import TYPE_CHECKING, Any, TypedDict, cast, get_origin
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
|
|
10
10
|
|
|
11
11
|
from pydantic import ConfigDict, TypeAdapter
|
|
12
12
|
from pydantic._internal import _decorators, _generate_schema, _typing_extra
|
|
@@ -20,8 +20,7 @@ from ._griffe import doc_descriptions
|
|
|
20
20
|
from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
|
|
21
21
|
|
|
22
22
|
if TYPE_CHECKING:
|
|
23
|
-
|
|
24
|
-
from .dependencies import AgentDeps, ToolParams
|
|
23
|
+
pass
|
|
25
24
|
|
|
26
25
|
|
|
27
26
|
__all__ = 'function_schema', 'LazyTypeAdapter'
|
|
@@ -39,17 +38,16 @@ class FunctionSchema(TypedDict):
|
|
|
39
38
|
var_positional_field: str | None
|
|
40
39
|
|
|
41
40
|
|
|
42
|
-
def function_schema(
|
|
41
|
+
def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSchema: # noqa: C901
|
|
43
42
|
"""Build a Pydantic validator and JSON schema from a tool function.
|
|
44
43
|
|
|
45
44
|
Args:
|
|
46
|
-
|
|
45
|
+
function: The function to build a validator and JSON schema for.
|
|
46
|
+
takes_ctx: Whether the function takes a `RunContext` first argument.
|
|
47
47
|
|
|
48
48
|
Returns:
|
|
49
49
|
A `FunctionSchema` instance.
|
|
50
50
|
"""
|
|
51
|
-
function = either_function.whichever()
|
|
52
|
-
takes_ctx = either_function.is_left()
|
|
53
51
|
config = ConfigDict(title=function.__name__)
|
|
54
52
|
config_wrapper = ConfigWrapper(config)
|
|
55
53
|
gen_schema = _generate_schema.GenerateSchema(config_wrapper)
|
|
@@ -78,13 +76,13 @@ def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]
|
|
|
78
76
|
|
|
79
77
|
if index == 0 and takes_ctx:
|
|
80
78
|
if not _is_call_ctx(annotation):
|
|
81
|
-
errors.append('First
|
|
79
|
+
errors.append('First parameter of tools that take context must be annotated with RunContext[...]')
|
|
82
80
|
continue
|
|
83
81
|
elif not takes_ctx and _is_call_ctx(annotation):
|
|
84
|
-
errors.append('RunContext
|
|
82
|
+
errors.append('RunContext annotations can only be used with tools that take context')
|
|
85
83
|
continue
|
|
86
84
|
elif index != 0 and _is_call_ctx(annotation):
|
|
87
|
-
errors.append('RunContext
|
|
85
|
+
errors.append('RunContext annotations can only be used as the first argument')
|
|
88
86
|
continue
|
|
89
87
|
|
|
90
88
|
field_name = p.name
|
|
@@ -159,6 +157,24 @@ def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]
|
|
|
159
157
|
)
|
|
160
158
|
|
|
161
159
|
|
|
160
|
+
def takes_ctx(function: Callable[..., Any]) -> bool:
|
|
161
|
+
"""Check if a function takes a `RunContext` first argument.
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
function: The function to check.
|
|
165
|
+
|
|
166
|
+
Returns:
|
|
167
|
+
`True` if the function takes a `RunContext` as first argument, `False` otherwise.
|
|
168
|
+
"""
|
|
169
|
+
sig = signature(function)
|
|
170
|
+
try:
|
|
171
|
+
_, first_param = next(iter(sig.parameters.items()))
|
|
172
|
+
except StopIteration:
|
|
173
|
+
return False
|
|
174
|
+
else:
|
|
175
|
+
return first_param.annotation is not sig.empty and _is_call_ctx(first_param.annotation)
|
|
176
|
+
|
|
177
|
+
|
|
162
178
|
def _build_schema(
|
|
163
179
|
fields: dict[str, core_schema.TypedDictField],
|
|
164
180
|
var_kwargs_schema: core_schema.CoreSchema | None,
|
|
@@ -191,7 +207,7 @@ def _build_schema(
|
|
|
191
207
|
|
|
192
208
|
|
|
193
209
|
def _is_call_ctx(annotation: Any) -> bool:
|
|
194
|
-
from .
|
|
210
|
+
from .tools import RunContext
|
|
195
211
|
|
|
196
212
|
return annotation is RunContext or (
|
|
197
213
|
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
|
pydantic_ai/_result.py
CHANGED
|
@@ -11,10 +11,10 @@ from pydantic import TypeAdapter, ValidationError
|
|
|
11
11
|
from typing_extensions import Self, TypeAliasType, TypedDict
|
|
12
12
|
|
|
13
13
|
from . import _utils, messages
|
|
14
|
-
from .dependencies import AgentDeps, ResultValidatorFunc, RunContext
|
|
15
14
|
from .exceptions import ModelRetry
|
|
16
15
|
from .messages import ModelStructuredResponse, ToolCall
|
|
17
16
|
from .result import ResultData
|
|
17
|
+
from .tools import AgentDeps, ResultValidatorFunc, RunContext
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@dataclass
|
pydantic_ai/_system_prompt.py
CHANGED
pydantic_ai/agent.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import dataclasses
|
|
4
5
|
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
|
5
6
|
from contextlib import asynccontextmanager, contextmanager
|
|
6
7
|
from dataclasses import dataclass, field
|
|
@@ -12,15 +13,14 @@ from typing_extensions import assert_never
|
|
|
12
13
|
from . import (
|
|
13
14
|
_result,
|
|
14
15
|
_system_prompt,
|
|
15
|
-
_tool as _r,
|
|
16
16
|
_utils,
|
|
17
17
|
exceptions,
|
|
18
18
|
messages as _messages,
|
|
19
19
|
models,
|
|
20
20
|
result,
|
|
21
21
|
)
|
|
22
|
-
from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
|
|
23
22
|
from .result import ResultData
|
|
23
|
+
from .tools import AgentDeps, RunContext, Tool, ToolFuncContext, ToolFuncEither, ToolFuncPlain, ToolParams
|
|
24
24
|
|
|
25
25
|
__all__ = ('Agent',)
|
|
26
26
|
|
|
@@ -34,7 +34,7 @@ NoneType = type(None)
|
|
|
34
34
|
class Agent(Generic[AgentDeps, ResultData]):
|
|
35
35
|
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
|
|
36
36
|
|
|
37
|
-
Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.
|
|
37
|
+
Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.tools.AgentDeps]
|
|
38
38
|
and the result data type they return, [`ResultData`][pydantic_ai.result.ResultData].
|
|
39
39
|
|
|
40
40
|
By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
|
|
@@ -58,7 +58,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
58
58
|
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
|
|
59
59
|
_allow_text_result: bool = field(repr=False)
|
|
60
60
|
_system_prompts: tuple[str, ...] = field(repr=False)
|
|
61
|
-
_function_tools: dict[str,
|
|
61
|
+
_function_tools: dict[str, Tool[AgentDeps]] = field(repr=False)
|
|
62
62
|
_default_retries: int = field(repr=False)
|
|
63
63
|
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
|
|
64
64
|
_deps_type: type[AgentDeps] = field(repr=False)
|
|
@@ -83,6 +83,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
83
83
|
result_tool_name: str = 'final_result',
|
|
84
84
|
result_tool_description: str | None = None,
|
|
85
85
|
result_retries: int | None = None,
|
|
86
|
+
tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (),
|
|
86
87
|
defer_model_check: bool = False,
|
|
87
88
|
):
|
|
88
89
|
"""Create an agent.
|
|
@@ -101,6 +102,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
101
102
|
result_tool_name: The name of the tool to use for the final result.
|
|
102
103
|
result_tool_description: The description of the final result tool.
|
|
103
104
|
result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
|
|
105
|
+
tools: Tools to register with the agent, you can also register tools via the decorators
|
|
106
|
+
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
|
|
104
107
|
defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
|
|
105
108
|
it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
|
|
106
109
|
which checks for the necessary environment variables. Set this to `false`
|
|
@@ -119,9 +122,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
119
122
|
self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result
|
|
120
123
|
|
|
121
124
|
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
|
|
122
|
-
self._function_tools
|
|
123
|
-
self._deps_type = deps_type
|
|
125
|
+
self._function_tools = {}
|
|
124
126
|
self._default_retries = retries
|
|
127
|
+
for tool in tools:
|
|
128
|
+
self._register_tool(Tool.infer(tool))
|
|
129
|
+
self._deps_type = deps_type
|
|
125
130
|
self._system_prompt_functions = []
|
|
126
131
|
self._max_result_retries = result_retries if result_retries is not None else retries
|
|
127
132
|
self._current_result_retry = 0
|
|
@@ -206,7 +211,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
206
211
|
) -> result.RunResult[ResultData]:
|
|
207
212
|
"""Run the agent with a user prompt synchronously.
|
|
208
213
|
|
|
209
|
-
This is a convenience method that wraps `self.run` with `
|
|
214
|
+
This is a convenience method that wraps `self.run` with `loop.run_until_complete()`.
|
|
210
215
|
|
|
211
216
|
Args:
|
|
212
217
|
user_prompt: User input to start/continue the conversation.
|
|
@@ -217,7 +222,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
217
222
|
Returns:
|
|
218
223
|
The result of the run.
|
|
219
224
|
"""
|
|
220
|
-
|
|
225
|
+
loop = asyncio.get_event_loop()
|
|
226
|
+
return loop.run_until_complete(self.run(user_prompt, message_history=message_history, model=model, deps=deps))
|
|
221
227
|
|
|
222
228
|
@asynccontextmanager
|
|
223
229
|
async def run_stream(
|
|
@@ -284,6 +290,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
284
290
|
self._result_schema,
|
|
285
291
|
deps,
|
|
286
292
|
self._result_validators,
|
|
293
|
+
lambda m: run_span.set_attribute('all_messages', messages),
|
|
287
294
|
)
|
|
288
295
|
return
|
|
289
296
|
else:
|
|
@@ -353,7 +360,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
353
360
|
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
|
|
354
361
|
"""Decorator to register a system prompt function.
|
|
355
362
|
|
|
356
|
-
Optionally takes [`RunContext`][pydantic_ai.
|
|
363
|
+
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
|
|
357
364
|
Can decorate sync or async functions.
|
|
358
365
|
|
|
359
366
|
Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
|
|
@@ -404,7 +411,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
404
411
|
) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
|
|
405
412
|
"""Decorator to register a result validator function.
|
|
406
413
|
|
|
407
|
-
Optionally takes [`RunContext`][pydantic_ai.
|
|
414
|
+
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
408
415
|
Can decorate sync or async functions.
|
|
409
416
|
|
|
410
417
|
Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
|
|
@@ -437,22 +444,22 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
437
444
|
return func
|
|
438
445
|
|
|
439
446
|
@overload
|
|
440
|
-
def tool(self, func:
|
|
447
|
+
def tool(self, func: ToolFuncContext[AgentDeps, ToolParams], /) -> ToolFuncContext[AgentDeps, ToolParams]: ...
|
|
441
448
|
|
|
442
449
|
@overload
|
|
443
450
|
def tool(
|
|
444
451
|
self, /, *, retries: int | None = None
|
|
445
|
-
) -> Callable[[
|
|
452
|
+
) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
|
|
446
453
|
|
|
447
454
|
def tool(
|
|
448
455
|
self,
|
|
449
|
-
func:
|
|
456
|
+
func: ToolFuncContext[AgentDeps, ToolParams] | None = None,
|
|
450
457
|
/,
|
|
451
458
|
*,
|
|
452
459
|
retries: int | None = None,
|
|
453
460
|
) -> Any:
|
|
454
461
|
"""Decorator to register a tool function which takes
|
|
455
|
-
[`RunContext`][pydantic_ai.
|
|
462
|
+
[`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
456
463
|
|
|
457
464
|
Can decorate sync or async functions.
|
|
458
465
|
|
|
@@ -489,27 +496,27 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
489
496
|
if func is None:
|
|
490
497
|
|
|
491
498
|
def tool_decorator(
|
|
492
|
-
func_:
|
|
493
|
-
) ->
|
|
499
|
+
func_: ToolFuncContext[AgentDeps, ToolParams],
|
|
500
|
+
) -> ToolFuncContext[AgentDeps, ToolParams]:
|
|
494
501
|
# noinspection PyTypeChecker
|
|
495
|
-
self.
|
|
502
|
+
self._register_function(func_, True, retries)
|
|
496
503
|
return func_
|
|
497
504
|
|
|
498
505
|
return tool_decorator
|
|
499
506
|
else:
|
|
500
507
|
# noinspection PyTypeChecker
|
|
501
|
-
self.
|
|
508
|
+
self._register_function(func, True, retries)
|
|
502
509
|
return func
|
|
503
510
|
|
|
504
511
|
@overload
|
|
505
|
-
def tool_plain(self, func:
|
|
512
|
+
def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ...
|
|
506
513
|
|
|
507
514
|
@overload
|
|
508
515
|
def tool_plain(
|
|
509
516
|
self, /, *, retries: int | None = None
|
|
510
|
-
) -> Callable[[
|
|
517
|
+
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
|
|
511
518
|
|
|
512
|
-
def tool_plain(self, func:
|
|
519
|
+
def tool_plain(self, func: ToolFuncPlain[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
|
|
513
520
|
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
|
|
514
521
|
|
|
515
522
|
Can decorate sync or async functions.
|
|
@@ -546,28 +553,34 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
546
553
|
"""
|
|
547
554
|
if func is None:
|
|
548
555
|
|
|
549
|
-
def tool_decorator(
|
|
550
|
-
func_: ToolPlainFunc[ToolParams],
|
|
551
|
-
) -> ToolPlainFunc[ToolParams]:
|
|
556
|
+
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
|
|
552
557
|
# noinspection PyTypeChecker
|
|
553
|
-
self.
|
|
558
|
+
self._register_function(func_, False, retries)
|
|
554
559
|
return func_
|
|
555
560
|
|
|
556
561
|
return tool_decorator
|
|
557
562
|
else:
|
|
558
|
-
self.
|
|
563
|
+
self._register_function(func, False, retries)
|
|
559
564
|
return func
|
|
560
565
|
|
|
561
|
-
def
|
|
562
|
-
|
|
566
|
+
def _register_function(
|
|
567
|
+
self, func: ToolFuncEither[AgentDeps, ToolParams], takes_ctx: bool, retries: int | None
|
|
568
|
+
) -> None:
|
|
569
|
+
"""Private utility to register a function as a tool."""
|
|
563
570
|
retries_ = retries if retries is not None else self._default_retries
|
|
564
|
-
tool =
|
|
571
|
+
tool = Tool(func, takes_ctx, max_retries=retries_)
|
|
572
|
+
self._register_tool(tool)
|
|
565
573
|
|
|
566
|
-
|
|
567
|
-
|
|
574
|
+
def _register_tool(self, tool: Tool[AgentDeps]) -> None:
|
|
575
|
+
"""Private utility to register a tool instance."""
|
|
576
|
+
if tool.max_retries is None:
|
|
577
|
+
tool = dataclasses.replace(tool, max_retries=self._default_retries)
|
|
568
578
|
|
|
569
579
|
if tool.name in self._function_tools:
|
|
570
|
-
raise
|
|
580
|
+
raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
|
|
581
|
+
|
|
582
|
+
if self._result_schema and tool.name in self._result_schema.tools:
|
|
583
|
+
raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')
|
|
571
584
|
|
|
572
585
|
self._function_tools[tool.name] = tool
|
|
573
586
|
|
pydantic_ai/messages.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from collections.abc import Mapping, Sequence
|
|
5
4
|
from dataclasses import dataclass, field
|
|
6
5
|
from datetime import datetime
|
|
7
|
-
from typing import
|
|
6
|
+
from typing import Annotated, Any, Literal, Union
|
|
8
7
|
|
|
9
8
|
import pydantic
|
|
10
9
|
import pydantic_core
|
|
11
10
|
from pydantic import TypeAdapter
|
|
12
|
-
from typing_extensions import TypeAlias, TypeAliasType
|
|
13
11
|
|
|
14
12
|
from . import _pydantic
|
|
15
13
|
from ._utils import now_utc as _now_utc
|
|
@@ -44,13 +42,7 @@ class UserPrompt:
|
|
|
44
42
|
"""Message type identifier, this type is available on all message as a discriminator."""
|
|
45
43
|
|
|
46
44
|
|
|
47
|
-
|
|
48
|
-
if not TYPE_CHECKING:
|
|
49
|
-
# work around for https://github.com/pydantic/pydantic/issues/10873
|
|
50
|
-
# this is need for pydantic to work both `json_ta` and `MessagesTypeAdapter` at the bottom of this file
|
|
51
|
-
JsonData = TypeAliasType('JsonData', 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]')
|
|
52
|
-
|
|
53
|
-
json_ta: TypeAdapter[JsonData] = TypeAdapter(JsonData)
|
|
45
|
+
tool_return_ta: TypeAdapter[Any] = TypeAdapter(Any)
|
|
54
46
|
|
|
55
47
|
|
|
56
48
|
@dataclass
|
|
@@ -59,7 +51,7 @@ class ToolReturn:
|
|
|
59
51
|
|
|
60
52
|
tool_name: str
|
|
61
53
|
"""The name of the "tool" that was called."""
|
|
62
|
-
content:
|
|
54
|
+
content: Any
|
|
63
55
|
"""The return value."""
|
|
64
56
|
tool_id: str | None = None
|
|
65
57
|
"""Optional tool identifier, this is used by some models including OpenAI."""
|
|
@@ -72,15 +64,14 @@ class ToolReturn:
|
|
|
72
64
|
if isinstance(self.content, str):
|
|
73
65
|
return self.content
|
|
74
66
|
else:
|
|
75
|
-
|
|
76
|
-
return json_ta.dump_json(content).decode()
|
|
67
|
+
return tool_return_ta.dump_json(self.content).decode()
|
|
77
68
|
|
|
78
|
-
def model_response_object(self) -> dict[str,
|
|
69
|
+
def model_response_object(self) -> dict[str, Any]:
|
|
79
70
|
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
|
|
80
71
|
if isinstance(self.content, dict):
|
|
81
|
-
return
|
|
72
|
+
return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
|
|
82
73
|
else:
|
|
83
|
-
return {'return_value':
|
|
74
|
+
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
|
|
84
75
|
|
|
85
76
|
|
|
86
77
|
@dataclass
|
|
@@ -139,12 +130,16 @@ class ModelTextResponse:
|
|
|
139
130
|
|
|
140
131
|
@dataclass
|
|
141
132
|
class ArgsJson:
|
|
133
|
+
"""Tool arguments as a JSON string."""
|
|
134
|
+
|
|
142
135
|
args_json: str
|
|
143
136
|
"""A JSON string of arguments."""
|
|
144
137
|
|
|
145
138
|
|
|
146
139
|
@dataclass
|
|
147
140
|
class ArgsDict:
|
|
141
|
+
"""Tool arguments as a Python dictionary."""
|
|
142
|
+
|
|
148
143
|
args_dict: dict[str, Any]
|
|
149
144
|
"""A python dictionary of arguments."""
|
|
150
145
|
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -259,20 +259,30 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
259
259
|
class AbstractToolDefinition(Protocol):
|
|
260
260
|
"""Abstract definition of a function/tool.
|
|
261
261
|
|
|
262
|
-
This is used for both tools and result tools.
|
|
262
|
+
This is used for both function tools and result tools.
|
|
263
263
|
"""
|
|
264
264
|
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
json_schema: ObjectJsonSchema
|
|
270
|
-
"""The JSON schema for the tool's arguments."""
|
|
271
|
-
outer_typed_dict_key: str | None
|
|
272
|
-
"""The key in the outer [TypedDict] that wraps a result tool.
|
|
265
|
+
@property
|
|
266
|
+
def name(self) -> str:
|
|
267
|
+
"""The name of the tool."""
|
|
268
|
+
...
|
|
273
269
|
|
|
274
|
-
|
|
275
|
-
|
|
270
|
+
@property
|
|
271
|
+
def description(self) -> str:
|
|
272
|
+
"""The description of the tool."""
|
|
273
|
+
...
|
|
274
|
+
|
|
275
|
+
@property
|
|
276
|
+
def json_schema(self) -> ObjectJsonSchema:
|
|
277
|
+
"""The JSON schema for the tool's arguments."""
|
|
278
|
+
...
|
|
279
|
+
|
|
280
|
+
@property
|
|
281
|
+
def outer_typed_dict_key(self) -> str | None:
|
|
282
|
+
"""The key in the outer [TypedDict] that wraps a result tool.
|
|
283
|
+
|
|
284
|
+
This will only be set for result tools which don't have an `object` JSON schema.
|
|
285
|
+
"""
|
|
276
286
|
|
|
277
287
|
|
|
278
288
|
@cache
|
pydantic_ai/models/gemini.py
CHANGED
|
@@ -109,11 +109,15 @@ class GeminiModel(Model):
|
|
|
109
109
|
|
|
110
110
|
|
|
111
111
|
class AuthProtocol(Protocol):
|
|
112
|
+
"""Abstract definition for Gemini authentication."""
|
|
113
|
+
|
|
112
114
|
async def headers(self) -> dict[str, str]: ...
|
|
113
115
|
|
|
114
116
|
|
|
115
117
|
@dataclass
|
|
116
118
|
class ApiKeyAuth:
|
|
119
|
+
"""Authentication using an API key for the `X-Goog-Api-Key` header."""
|
|
120
|
+
|
|
117
121
|
api_key: str
|
|
118
122
|
|
|
119
123
|
async def headers(self) -> dict[str, str]:
|
pydantic_ai/models/test.py
CHANGED
|
@@ -31,14 +31,6 @@ from . import (
|
|
|
31
31
|
)
|
|
32
32
|
|
|
33
33
|
|
|
34
|
-
class UnSetType:
|
|
35
|
-
def __repr__(self):
|
|
36
|
-
return 'UnSet'
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
UnSet = UnSetType()
|
|
40
|
-
|
|
41
|
-
|
|
42
34
|
@dataclass
|
|
43
35
|
class TestModel(Model):
|
|
44
36
|
"""A model specifically for testing purposes.
|
|
@@ -186,6 +178,8 @@ class TestAgentModel(AgentModel):
|
|
|
186
178
|
|
|
187
179
|
@dataclass
|
|
188
180
|
class TestStreamTextResponse(StreamTextResponse):
|
|
181
|
+
"""A text response that streams test data."""
|
|
182
|
+
|
|
189
183
|
_text: str
|
|
190
184
|
_cost: Cost
|
|
191
185
|
_iter: Iterator[str] = field(init=False)
|
|
@@ -217,6 +211,8 @@ class TestStreamTextResponse(StreamTextResponse):
|
|
|
217
211
|
|
|
218
212
|
@dataclass
|
|
219
213
|
class TestStreamStructuredResponse(StreamStructuredResponse):
|
|
214
|
+
"""A structured response that streams test data."""
|
|
215
|
+
|
|
220
216
|
_structured_response: ModelStructuredResponse
|
|
221
217
|
_cost: Cost
|
|
222
218
|
_iter: Iterator[None] = field(default_factory=lambda: iter([None]))
|
pydantic_ai/models/vertexai.py
CHANGED
|
@@ -182,6 +182,8 @@ MAX_TOKEN_AGE = timedelta(seconds=3000)
|
|
|
182
182
|
|
|
183
183
|
@dataclass
|
|
184
184
|
class BearerTokenAuth:
|
|
185
|
+
"""Authentication using a bearer token generated by google-auth."""
|
|
186
|
+
|
|
185
187
|
credentials: BaseCredentials | ServiceAccountCredentials
|
|
186
188
|
token_created: datetime | None = field(default=None, init=False)
|
|
187
189
|
|
pydantic_ai/result.py
CHANGED
|
@@ -1,15 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from collections.abc import AsyncIterator
|
|
5
|
-
from dataclasses import dataclass
|
|
4
|
+
from collections.abc import AsyncIterator, Callable
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
6
|
from datetime import datetime
|
|
7
7
|
from typing import Generic, TypeVar, cast
|
|
8
8
|
|
|
9
9
|
import logfire_api
|
|
10
10
|
|
|
11
11
|
from . import _result, _utils, exceptions, messages, models
|
|
12
|
-
from .
|
|
12
|
+
from .tools import AgentDeps
|
|
13
13
|
|
|
14
14
|
__all__ = (
|
|
15
15
|
'ResultData',
|
|
@@ -49,11 +49,11 @@ class Cost:
|
|
|
49
49
|
This is provided so it's trivial to sum costs from multiple requests and runs.
|
|
50
50
|
"""
|
|
51
51
|
counts: dict[str, int] = {}
|
|
52
|
-
for
|
|
53
|
-
self_value = getattr(self,
|
|
54
|
-
other_value = getattr(other,
|
|
52
|
+
for f in 'request_tokens', 'response_tokens', 'total_tokens':
|
|
53
|
+
self_value = getattr(self, f)
|
|
54
|
+
other_value = getattr(other, f)
|
|
55
55
|
if self_value is not None or other_value is not None:
|
|
56
|
-
counts[
|
|
56
|
+
counts[f] = (self_value or 0) + (other_value or 0)
|
|
57
57
|
|
|
58
58
|
details = self.details.copy() if self.details is not None else None
|
|
59
59
|
if other.details is not None:
|
|
@@ -122,7 +122,8 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
122
122
|
_result_schema: _result.ResultSchema[ResultData] | None
|
|
123
123
|
_deps: AgentDeps
|
|
124
124
|
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
|
|
125
|
-
|
|
125
|
+
_on_complete: Callable[[list[messages.Message]], None]
|
|
126
|
+
is_complete: bool = field(default=False, init=False)
|
|
126
127
|
"""Whether the stream has all been received.
|
|
127
128
|
|
|
128
129
|
This is set to `True` when one of
|
|
@@ -312,3 +313,4 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
312
313
|
else:
|
|
313
314
|
assert structured_message is not None, 'Either text or structured_message should provided, not both'
|
|
314
315
|
self._all_messages.append(structured_message)
|
|
316
|
+
self._on_complete(self._all_messages)
|
pydantic_ai/tools.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import Awaitable
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
|
|
7
|
+
|
|
8
|
+
from pydantic import ValidationError
|
|
9
|
+
from pydantic_core import SchemaValidator
|
|
10
|
+
from typing_extensions import Concatenate, ParamSpec, final
|
|
11
|
+
|
|
12
|
+
from . import _pydantic, _utils, messages
|
|
13
|
+
from .exceptions import ModelRetry, UnexpectedModelBehavior
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from .result import ResultData
|
|
17
|
+
else:
|
|
18
|
+
ResultData = Any
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
__all__ = (
|
|
22
|
+
'AgentDeps',
|
|
23
|
+
'RunContext',
|
|
24
|
+
'ResultValidatorFunc',
|
|
25
|
+
'SystemPromptFunc',
|
|
26
|
+
'ToolFuncContext',
|
|
27
|
+
'ToolFuncPlain',
|
|
28
|
+
'ToolFuncEither',
|
|
29
|
+
'ToolParams',
|
|
30
|
+
'Tool',
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
AgentDeps = TypeVar('AgentDeps')
|
|
34
|
+
"""Type variable for agent dependencies."""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
|
|
38
|
+
class RunContext(Generic[AgentDeps]):
|
|
39
|
+
"""Information about the current call."""
|
|
40
|
+
|
|
41
|
+
deps: AgentDeps
|
|
42
|
+
"""Dependencies for the agent."""
|
|
43
|
+
retry: int
|
|
44
|
+
"""Number of retries so far."""
|
|
45
|
+
tool_name: str | None
|
|
46
|
+
"""Name of the tool being called."""
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
ToolParams = ParamSpec('ToolParams')
|
|
50
|
+
"""Tool function param spec."""
|
|
51
|
+
|
|
52
|
+
SystemPromptFunc = Union[
|
|
53
|
+
Callable[[RunContext[AgentDeps]], str],
|
|
54
|
+
Callable[[RunContext[AgentDeps]], Awaitable[str]],
|
|
55
|
+
Callable[[], str],
|
|
56
|
+
Callable[[], Awaitable[str]],
|
|
57
|
+
]
|
|
58
|
+
"""A function that may or may not take `RunContext` as an argument, and may or may not be async.
|
|
59
|
+
|
|
60
|
+
Usage `SystemPromptFunc[AgentDeps]`.
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
ResultValidatorFunc = Union[
|
|
64
|
+
Callable[[RunContext[AgentDeps], ResultData], ResultData],
|
|
65
|
+
Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
|
|
66
|
+
Callable[[ResultData], ResultData],
|
|
67
|
+
Callable[[ResultData], Awaitable[ResultData]],
|
|
68
|
+
]
|
|
69
|
+
"""
|
|
70
|
+
A function that always takes `ResultData` and returns `ResultData`,
|
|
71
|
+
but may or may not take `RunContext` as a first argument, and may or may not be async.
|
|
72
|
+
|
|
73
|
+
Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any]
|
|
77
|
+
"""A tool function that takes `RunContext` as the first argument.
|
|
78
|
+
|
|
79
|
+
Usage `ToolFuncContext[AgentDeps, ToolParams]`.
|
|
80
|
+
"""
|
|
81
|
+
ToolFuncPlain = Callable[ToolParams, Any]
|
|
82
|
+
"""A tool function that does not take `RunContext` as the first argument.
|
|
83
|
+
|
|
84
|
+
Usage `ToolFuncPlain[ToolParams]`.
|
|
85
|
+
"""
|
|
86
|
+
ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
|
|
87
|
+
"""Either kind of tool function.
|
|
88
|
+
|
|
89
|
+
This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
|
|
90
|
+
[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
|
|
91
|
+
|
|
92
|
+
Usage `ToolFuncEither[AgentDeps, ToolParams]`.
|
|
93
|
+
"""
|
|
94
|
+
|
|
95
|
+
A = TypeVar('A')
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@final
|
|
99
|
+
@dataclass(init=False)
|
|
100
|
+
class Tool(Generic[AgentDeps]):
|
|
101
|
+
"""A tool function for an agent."""
|
|
102
|
+
|
|
103
|
+
function: ToolFuncEither[AgentDeps, ...]
|
|
104
|
+
takes_ctx: bool
|
|
105
|
+
max_retries: int | None
|
|
106
|
+
name: str
|
|
107
|
+
description: str
|
|
108
|
+
_is_async: bool = field(init=False)
|
|
109
|
+
_single_arg_name: str | None = field(init=False)
|
|
110
|
+
_positional_fields: list[str] = field(init=False)
|
|
111
|
+
_var_positional_field: str | None = field(init=False)
|
|
112
|
+
_validator: SchemaValidator = field(init=False, repr=False)
|
|
113
|
+
_json_schema: _utils.ObjectJsonSchema = field(init=False)
|
|
114
|
+
_current_retry: int = field(default=0, init=False)
|
|
115
|
+
|
|
116
|
+
def __init__(
|
|
117
|
+
self,
|
|
118
|
+
function: ToolFuncEither[AgentDeps, ...],
|
|
119
|
+
takes_ctx: bool,
|
|
120
|
+
*,
|
|
121
|
+
max_retries: int | None = None,
|
|
122
|
+
name: str | None = None,
|
|
123
|
+
description: str | None = None,
|
|
124
|
+
):
|
|
125
|
+
"""Create a new tool instance.
|
|
126
|
+
|
|
127
|
+
Example usage:
|
|
128
|
+
|
|
129
|
+
```py
|
|
130
|
+
from pydantic_ai import Agent, RunContext, Tool
|
|
131
|
+
|
|
132
|
+
async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
|
|
133
|
+
return f'{ctx.deps} {x} {y}'
|
|
134
|
+
|
|
135
|
+
agent = Agent('test', tools=[Tool(my_tool, True)])
|
|
136
|
+
```
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
function: The Python function to call as the tool.
|
|
140
|
+
takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument.
|
|
141
|
+
max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
|
|
142
|
+
name: Name of the tool, inferred from the function if `None`.
|
|
143
|
+
description: Description of the tool, inferred from the function if `None`.
|
|
144
|
+
"""
|
|
145
|
+
f = _pydantic.function_schema(function, takes_ctx)
|
|
146
|
+
self.function = function
|
|
147
|
+
self.takes_ctx = takes_ctx
|
|
148
|
+
self.max_retries = max_retries
|
|
149
|
+
self.name = name or function.__name__
|
|
150
|
+
self.description = description or f['description']
|
|
151
|
+
self._is_async = inspect.iscoroutinefunction(self.function)
|
|
152
|
+
self._single_arg_name = f['single_arg_name']
|
|
153
|
+
self._positional_fields = f['positional_fields']
|
|
154
|
+
self._var_positional_field = f['var_positional_field']
|
|
155
|
+
self._validator = f['validator']
|
|
156
|
+
self._json_schema = f['json_schema']
|
|
157
|
+
|
|
158
|
+
@staticmethod
|
|
159
|
+
def infer(function: ToolFuncEither[A, ...] | Tool[A]) -> Tool[A]:
|
|
160
|
+
"""Create a tool from a pure function, inferring whether it takes `RunContext` as its first argument.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
function: The tool function to wrap; or for convenience, a `Tool` instance.
|
|
164
|
+
|
|
165
|
+
Returns:
|
|
166
|
+
The given `Tool` instance unchanged, or a new `Tool` wrapping the function.
|
|
167
|
+
"""
|
|
168
|
+
if isinstance(function, Tool):
|
|
169
|
+
return function
|
|
170
|
+
else:
|
|
171
|
+
return Tool(function, takes_ctx=_pydantic.takes_ctx(function))
|
|
172
|
+
|
|
173
|
+
def reset(self) -> None:
|
|
174
|
+
"""Reset the current retry count."""
|
|
175
|
+
self._current_retry = 0
|
|
176
|
+
|
|
177
|
+
async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
|
|
178
|
+
"""Run the tool function asynchronously."""
|
|
179
|
+
try:
|
|
180
|
+
if isinstance(message.args, messages.ArgsJson):
|
|
181
|
+
args_dict = self._validator.validate_json(message.args.args_json)
|
|
182
|
+
else:
|
|
183
|
+
args_dict = self._validator.validate_python(message.args.args_dict)
|
|
184
|
+
except ValidationError as e:
|
|
185
|
+
return self._on_error(e, message)
|
|
186
|
+
|
|
187
|
+
args, kwargs = self._call_args(deps, args_dict, message)
|
|
188
|
+
try:
|
|
189
|
+
if self._is_async:
|
|
190
|
+
function = cast(Callable[[Any], Awaitable[str]], self.function)
|
|
191
|
+
response_content = await function(*args, **kwargs)
|
|
192
|
+
else:
|
|
193
|
+
function = cast(Callable[[Any], str], self.function)
|
|
194
|
+
response_content = await _utils.run_in_executor(function, *args, **kwargs)
|
|
195
|
+
except ModelRetry as e:
|
|
196
|
+
return self._on_error(e, message)
|
|
197
|
+
|
|
198
|
+
self._current_retry = 0
|
|
199
|
+
return messages.ToolReturn(
|
|
200
|
+
tool_name=message.tool_name,
|
|
201
|
+
content=response_content,
|
|
202
|
+
tool_id=message.tool_id,
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
@property
|
|
206
|
+
def json_schema(self) -> _utils.ObjectJsonSchema:
|
|
207
|
+
return self._json_schema
|
|
208
|
+
|
|
209
|
+
@property
|
|
210
|
+
def outer_typed_dict_key(self) -> str | None:
|
|
211
|
+
return None
|
|
212
|
+
|
|
213
|
+
def _call_args(
|
|
214
|
+
self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
|
|
215
|
+
) -> tuple[list[Any], dict[str, Any]]:
|
|
216
|
+
if self._single_arg_name:
|
|
217
|
+
args_dict = {self._single_arg_name: args_dict}
|
|
218
|
+
|
|
219
|
+
args = [RunContext(deps, self._current_retry, message.tool_name)] if self.takes_ctx else []
|
|
220
|
+
for positional_field in self._positional_fields:
|
|
221
|
+
args.append(args_dict.pop(positional_field))
|
|
222
|
+
if self._var_positional_field:
|
|
223
|
+
args.extend(args_dict.pop(self._var_positional_field))
|
|
224
|
+
|
|
225
|
+
return args, args_dict
|
|
226
|
+
|
|
227
|
+
def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
|
|
228
|
+
self._current_retry += 1
|
|
229
|
+
if self.max_retries is None or self._current_retry > self.max_retries:
|
|
230
|
+
raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
|
|
231
|
+
else:
|
|
232
|
+
if isinstance(exc, ValidationError):
|
|
233
|
+
content = exc.errors(include_url=False)
|
|
234
|
+
else:
|
|
235
|
+
content = exc.message
|
|
236
|
+
return messages.RetryPrompt(
|
|
237
|
+
tool_name=call_message.tool_name,
|
|
238
|
+
content=content,
|
|
239
|
+
tool_id=call_message.tool_id,
|
|
240
|
+
)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
pydantic_ai/__init__.py,sha256=a29NqQz0JyW4BoCjcRh23fBBfwY17_n57moE4QrFWM4,324
|
|
2
|
+
pydantic_ai/_griffe.py,sha256=pRjCJ6B1hhx6k46XJgl9zF6aRYxRmqEZKFok8unp4Iw,3449
|
|
3
|
+
pydantic_ai/_pydantic.py,sha256=oFfcHDv_wuL1NQ7mCzVHvP1HBaVzyvb7xS-_Iiri_tA,8491
|
|
4
|
+
pydantic_ai/_result.py,sha256=wzcfwDpr_sro1Vkn3DkyIhCXMHTReDxL_ZYm50JzdRI,9667
|
|
5
|
+
pydantic_ai/_system_prompt.py,sha256=vFT0y9Wykl5veGMgLLkGRYiHQrgdW2BZ1rwMn4izjjo,1085
|
|
6
|
+
pydantic_ai/_utils.py,sha256=eNb7f3-ZQC8WDEa87iUcXGQ-lyuutFQG-5yBCMD4Vvs,8227
|
|
7
|
+
pydantic_ai/agent.py,sha256=bHjw1rskPjkAQc8BdoB6AbWjQVTQK0cDDFu7v3EQXqE,34972
|
|
8
|
+
pydantic_ai/exceptions.py,sha256=ko_47M0k6Rhg9mUC9P1cj7N4LCH6cC0pEsF65A2vL-U,1561
|
|
9
|
+
pydantic_ai/messages.py,sha256=I0_CPXDIGGSy-PXHuKq540oAXYOO9uyylpsfSsE4vLs,7032
|
|
10
|
+
pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
pydantic_ai/result.py,sha256=UB3vFOcDAceeNLXh_f_3PKy2J0A6FIHdgDFJxPH6OJk,13651
|
|
12
|
+
pydantic_ai/tools.py,sha256=rzchmsEPYtUzBBFoeeJKC1ct36iqOxFOY6kfKqItCCs,8210
|
|
13
|
+
pydantic_ai/models/__init__.py,sha256=_Mz_32WGlAf4NlxXfdQ-EAaY_bDOk10gIc5HmTAO_ts,10318
|
|
14
|
+
pydantic_ai/models/function.py,sha256=Mzc-zXnb2RayWAA8N9NS7KGF49do1S-VW3U9fkc661o,10045
|
|
15
|
+
pydantic_ai/models/gemini.py,sha256=ruO4tnnpDDuHThg7jUOphs8I_KXBJH7gfDMluliED8E,26606
|
|
16
|
+
pydantic_ai/models/groq.py,sha256=Tx2yU3ysmPLBmWGsjzES-XcumzrsoBtB7spCnJBlLiM,14947
|
|
17
|
+
pydantic_ai/models/openai.py,sha256=5ihH25CrS0tnZNW-BZw4GyPe8V-IxIHWw3B9ulPVjQE,14931
|
|
18
|
+
pydantic_ai/models/test.py,sha256=q1wch_E7TSb4qx9PCcP1YyBGZx567MGlAQhlAlON0S8,14463
|
|
19
|
+
pydantic_ai/models/vertexai.py,sha256=5wI8y2YjeRgSE51uKy5OtevQkks65uEbxIUAs5EGBaI,9161
|
|
20
|
+
pydantic_ai_slim-0.0.9.dist-info/METADATA,sha256=NI561tFX5Xwjdak9lEYgMNnXOuKvGY8VZxmTNrt7k5k,2561
|
|
21
|
+
pydantic_ai_slim-0.0.9.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
|
|
22
|
+
pydantic_ai_slim-0.0.9.dist-info/RECORD,,
|
pydantic_ai/_tool.py
DELETED
|
@@ -1,112 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations as _annotations
|
|
2
|
-
|
|
3
|
-
import inspect
|
|
4
|
-
from collections.abc import Awaitable
|
|
5
|
-
from dataclasses import dataclass, field
|
|
6
|
-
from typing import Any, Callable, Generic, cast
|
|
7
|
-
|
|
8
|
-
from pydantic import ValidationError
|
|
9
|
-
from pydantic_core import SchemaValidator
|
|
10
|
-
|
|
11
|
-
from . import _pydantic, _utils, messages
|
|
12
|
-
from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
|
|
13
|
-
from .exceptions import ModelRetry, UnexpectedModelBehavior
|
|
14
|
-
|
|
15
|
-
# Usage `ToolEitherFunc[AgentDependencies, P]`
|
|
16
|
-
ToolEitherFunc = _utils.Either[ToolContextFunc[AgentDeps, ToolParams], ToolPlainFunc[ToolParams]]
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@dataclass(init=False)
|
|
20
|
-
class Tool(Generic[AgentDeps, ToolParams]):
|
|
21
|
-
"""A tool function for an agent."""
|
|
22
|
-
|
|
23
|
-
name: str
|
|
24
|
-
description: str
|
|
25
|
-
function: ToolEitherFunc[AgentDeps, ToolParams] = field(repr=False)
|
|
26
|
-
is_async: bool
|
|
27
|
-
single_arg_name: str | None
|
|
28
|
-
positional_fields: list[str]
|
|
29
|
-
var_positional_field: str | None
|
|
30
|
-
validator: SchemaValidator = field(repr=False)
|
|
31
|
-
json_schema: _utils.ObjectJsonSchema
|
|
32
|
-
max_retries: int
|
|
33
|
-
_current_retry: int = 0
|
|
34
|
-
outer_typed_dict_key: str | None = None
|
|
35
|
-
|
|
36
|
-
def __init__(self, function: ToolEitherFunc[AgentDeps, ToolParams], retries: int):
|
|
37
|
-
"""Build a Tool dataclass from a function."""
|
|
38
|
-
self.function = function
|
|
39
|
-
# noinspection PyTypeChecker
|
|
40
|
-
f = _pydantic.function_schema(function)
|
|
41
|
-
raw_function = function.whichever()
|
|
42
|
-
self.name = raw_function.__name__
|
|
43
|
-
self.description = f['description']
|
|
44
|
-
self.is_async = inspect.iscoroutinefunction(raw_function)
|
|
45
|
-
self.single_arg_name = f['single_arg_name']
|
|
46
|
-
self.positional_fields = f['positional_fields']
|
|
47
|
-
self.var_positional_field = f['var_positional_field']
|
|
48
|
-
self.validator = f['validator']
|
|
49
|
-
self.json_schema = f['json_schema']
|
|
50
|
-
self.max_retries = retries
|
|
51
|
-
|
|
52
|
-
def reset(self) -> None:
|
|
53
|
-
"""Reset the current retry count."""
|
|
54
|
-
self._current_retry = 0
|
|
55
|
-
|
|
56
|
-
async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
|
|
57
|
-
"""Run the tool function asynchronously."""
|
|
58
|
-
try:
|
|
59
|
-
if isinstance(message.args, messages.ArgsJson):
|
|
60
|
-
args_dict = self.validator.validate_json(message.args.args_json)
|
|
61
|
-
else:
|
|
62
|
-
args_dict = self.validator.validate_python(message.args.args_dict)
|
|
63
|
-
except ValidationError as e:
|
|
64
|
-
return self._on_error(e, message)
|
|
65
|
-
|
|
66
|
-
args, kwargs = self._call_args(deps, args_dict, message)
|
|
67
|
-
try:
|
|
68
|
-
if self.is_async:
|
|
69
|
-
function = cast(Callable[[Any], Awaitable[str]], self.function.whichever())
|
|
70
|
-
response_content = await function(*args, **kwargs)
|
|
71
|
-
else:
|
|
72
|
-
function = cast(Callable[[Any], str], self.function.whichever())
|
|
73
|
-
response_content = await _utils.run_in_executor(function, *args, **kwargs)
|
|
74
|
-
except ModelRetry as e:
|
|
75
|
-
return self._on_error(e, message)
|
|
76
|
-
|
|
77
|
-
self._current_retry = 0
|
|
78
|
-
return messages.ToolReturn(
|
|
79
|
-
tool_name=message.tool_name,
|
|
80
|
-
content=response_content,
|
|
81
|
-
tool_id=message.tool_id,
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
def _call_args(
|
|
85
|
-
self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
|
|
86
|
-
) -> tuple[list[Any], dict[str, Any]]:
|
|
87
|
-
if self.single_arg_name:
|
|
88
|
-
args_dict = {self.single_arg_name: args_dict}
|
|
89
|
-
|
|
90
|
-
args = [RunContext(deps, self._current_retry, message.tool_name)] if self.function.is_left() else []
|
|
91
|
-
for positional_field in self.positional_fields:
|
|
92
|
-
args.append(args_dict.pop(positional_field))
|
|
93
|
-
if self.var_positional_field:
|
|
94
|
-
args.extend(args_dict.pop(self.var_positional_field))
|
|
95
|
-
|
|
96
|
-
return args, args_dict
|
|
97
|
-
|
|
98
|
-
def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
|
|
99
|
-
self._current_retry += 1
|
|
100
|
-
if self._current_retry > self.max_retries:
|
|
101
|
-
# TODO custom error with details of the tool
|
|
102
|
-
raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
|
|
103
|
-
else:
|
|
104
|
-
if isinstance(exc, ValidationError):
|
|
105
|
-
content = exc.errors(include_url=False)
|
|
106
|
-
else:
|
|
107
|
-
content = exc.message
|
|
108
|
-
return messages.RetryPrompt(
|
|
109
|
-
tool_name=call_message.tool_name,
|
|
110
|
-
content=content,
|
|
111
|
-
tool_id=call_message.tool_id,
|
|
112
|
-
)
|
pydantic_ai/dependencies.py
DELETED
|
@@ -1,83 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations as _annotations
|
|
2
|
-
|
|
3
|
-
from collections.abc import Awaitable, Mapping, Sequence
|
|
4
|
-
from dataclasses import dataclass
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
|
|
6
|
-
|
|
7
|
-
from typing_extensions import Concatenate, ParamSpec, TypeAlias
|
|
8
|
-
|
|
9
|
-
if TYPE_CHECKING:
|
|
10
|
-
from .result import ResultData
|
|
11
|
-
else:
|
|
12
|
-
ResultData = Any
|
|
13
|
-
|
|
14
|
-
__all__ = (
|
|
15
|
-
'AgentDeps',
|
|
16
|
-
'RunContext',
|
|
17
|
-
'ResultValidatorFunc',
|
|
18
|
-
'SystemPromptFunc',
|
|
19
|
-
'ToolReturnValue',
|
|
20
|
-
'ToolContextFunc',
|
|
21
|
-
'ToolPlainFunc',
|
|
22
|
-
'ToolParams',
|
|
23
|
-
'JsonData',
|
|
24
|
-
)
|
|
25
|
-
|
|
26
|
-
AgentDeps = TypeVar('AgentDeps')
|
|
27
|
-
"""Type variable for agent dependencies."""
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
@dataclass
|
|
31
|
-
class RunContext(Generic[AgentDeps]):
|
|
32
|
-
"""Information about the current call."""
|
|
33
|
-
|
|
34
|
-
deps: AgentDeps
|
|
35
|
-
"""Dependencies for the agent."""
|
|
36
|
-
retry: int
|
|
37
|
-
"""Number of retries so far."""
|
|
38
|
-
tool_name: str | None
|
|
39
|
-
"""Name of the tool being called."""
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
ToolParams = ParamSpec('ToolParams')
|
|
43
|
-
"""Retrieval function param spec."""
|
|
44
|
-
|
|
45
|
-
SystemPromptFunc = Union[
|
|
46
|
-
Callable[[RunContext[AgentDeps]], str],
|
|
47
|
-
Callable[[RunContext[AgentDeps]], Awaitable[str]],
|
|
48
|
-
Callable[[], str],
|
|
49
|
-
Callable[[], Awaitable[str]],
|
|
50
|
-
]
|
|
51
|
-
"""A function that may or maybe not take `RunContext` as an argument, and may or may not be async.
|
|
52
|
-
|
|
53
|
-
Usage `SystemPromptFunc[AgentDeps]`.
|
|
54
|
-
"""
|
|
55
|
-
|
|
56
|
-
ResultValidatorFunc = Union[
|
|
57
|
-
Callable[[RunContext[AgentDeps], ResultData], ResultData],
|
|
58
|
-
Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
|
|
59
|
-
Callable[[ResultData], ResultData],
|
|
60
|
-
Callable[[ResultData], Awaitable[ResultData]],
|
|
61
|
-
]
|
|
62
|
-
"""
|
|
63
|
-
A function that always takes `ResultData` and returns `ResultData`,
|
|
64
|
-
but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
|
|
65
|
-
|
|
66
|
-
Usage `ResultValidator[AgentDeps, ResultData]`.
|
|
67
|
-
"""
|
|
68
|
-
|
|
69
|
-
JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
|
|
70
|
-
"""Type representing any JSON data."""
|
|
71
|
-
|
|
72
|
-
ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
|
|
73
|
-
"""Return value of a tool function."""
|
|
74
|
-
ToolContextFunc = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue]
|
|
75
|
-
"""A tool function that takes `RunContext` as the first argument.
|
|
76
|
-
|
|
77
|
-
Usage `ToolContextFunc[AgentDeps, ToolParams]`.
|
|
78
|
-
"""
|
|
79
|
-
ToolPlainFunc = Callable[ToolParams, ToolReturnValue]
|
|
80
|
-
"""A tool function that does not take `RunContext` as the first argument.
|
|
81
|
-
|
|
82
|
-
Usage `ToolPlainFunc[ToolParams]`.
|
|
83
|
-
"""
|
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
pydantic_ai/__init__.py,sha256=KaTzG8uBSEpfxk_y2a_O_R2Xa53GXYfiigjWtD4PCeI,317
|
|
2
|
-
pydantic_ai/_griffe.py,sha256=pRjCJ6B1hhx6k46XJgl9zF6aRYxRmqEZKFok8unp4Iw,3449
|
|
3
|
-
pydantic_ai/_pydantic.py,sha256=j1kObPIUDwn1VOHbJBwMFbWLxHM_OYRH7GFAda68ZC0,8010
|
|
4
|
-
pydantic_ai/_result.py,sha256=cAqfPipK39cz-p-ftlJ83Q5_LI1TRb3-HH-iivb5rEM,9674
|
|
5
|
-
pydantic_ai/_system_prompt.py,sha256=63egOej8zHsDVOInPayn0EEEDXKd0HVAbbrqXUTV96s,1092
|
|
6
|
-
pydantic_ai/_tool.py,sha256=5Q9XaGOEXbyOLS644osB1AA5EMoJkr4eYK60MVZo0Z8,4528
|
|
7
|
-
pydantic_ai/_utils.py,sha256=eNb7f3-ZQC8WDEa87iUcXGQ-lyuutFQG-5yBCMD4Vvs,8227
|
|
8
|
-
pydantic_ai/agent.py,sha256=RhLl6tf6kogtauqQr3U1ufg712k-g002YPcio9-BPB4,34242
|
|
9
|
-
pydantic_ai/dependencies.py,sha256=EHvD68AFkItxMnfHzJLG7T_AD1RGI2MZOfzm1v89hGQ,2399
|
|
10
|
-
pydantic_ai/exceptions.py,sha256=ko_47M0k6Rhg9mUC9P1cj7N4LCH6cC0pEsF65A2vL-U,1561
|
|
11
|
-
pydantic_ai/messages.py,sha256=tyBs2ucqIrKjjMIX2bvjMaZ6BMxim4VBkmCzTbeH-AI,7493
|
|
12
|
-
pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
|
-
pydantic_ai/result.py,sha256=Gs2ZjuFJgONGJm8M5R2mGph5-lUFLBg7FxVrs2CVDPs,13525
|
|
14
|
-
pydantic_ai/models/__init__.py,sha256=Cx8PjEsi5gkNOVQic32sf4CmM-A3pRu1LcjpM6poiBI,10138
|
|
15
|
-
pydantic_ai/models/function.py,sha256=Mzc-zXnb2RayWAA8N9NS7KGF49do1S-VW3U9fkc661o,10045
|
|
16
|
-
pydantic_ai/models/gemini.py,sha256=2mD5MJT7qaiMAWrySLjeUFSVoMYhYOS4V7ueTOxzkdA,26472
|
|
17
|
-
pydantic_ai/models/groq.py,sha256=Tx2yU3ysmPLBmWGsjzES-XcumzrsoBtB7spCnJBlLiM,14947
|
|
18
|
-
pydantic_ai/models/openai.py,sha256=5ihH25CrS0tnZNW-BZw4GyPe8V-IxIHWw3B9ulPVjQE,14931
|
|
19
|
-
pydantic_ai/models/test.py,sha256=HOS3_u0n3nYWJmBFPsAFpZ9gfHiZLauz24o825q3e9M,14443
|
|
20
|
-
pydantic_ai/models/vertexai.py,sha256=xHatvwRn7_vyqmp3aDtHqryYAq8NoeXgXVATkj7yHuw,9088
|
|
21
|
-
pydantic_ai_slim-0.0.7.dist-info/METADATA,sha256=UeP3i7vUZv1MCKG-VGLZUwEO8kbcCzzLpD6RsAN_Vbs,2561
|
|
22
|
-
pydantic_ai_slim-0.0.7.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
|
|
23
|
-
pydantic_ai_slim-0.0.7.dist-info/RECORD,,
|
|
File without changes
|