pydantic-ai-slim 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic; see the package's registry page for more details.
- pydantic_ai/__init__.py +2 -2
- pydantic_ai/_pydantic.py +27 -11
- pydantic_ai/_result.py +1 -1
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/agent.py +86 -34
- pydantic_ai/messages.py +7 -16
- pydantic_ai/models/__init__.py +21 -11
- pydantic_ai/result.py +1 -1
- pydantic_ai/tools.py +240 -0
- {pydantic_ai_slim-0.0.8.dist-info → pydantic_ai_slim-0.0.10.dist-info}/METADATA +1 -1
- pydantic_ai_slim-0.0.10.dist-info/RECORD +22 -0
- pydantic_ai/_tool.py +0 -112
- pydantic_ai/dependencies.py +0 -83
- pydantic_ai_slim-0.0.8.dist-info/RECORD +0 -23
- {pydantic_ai_slim-0.0.8.dist-info → pydantic_ai_slim-0.0.10.dist-info}/WHEEL +0 -0
pydantic_ai/__init__.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from importlib.metadata import version
|
|
2
2
|
|
|
3
3
|
from .agent import Agent
|
|
4
|
-
from .dependencies import RunContext
|
|
5
4
|
from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
|
|
5
|
+
from .tools import RunContext, Tool
|
|
6
6
|
|
|
7
|
-
__all__ = 'Agent', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
|
|
7
|
+
__all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
|
|
8
8
|
__version__ = version('pydantic_ai_slim')
|
pydantic_ai/_pydantic.py
CHANGED
|
@@ -6,7 +6,7 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
|
|
|
6
6
|
from __future__ import annotations as _annotations
|
|
7
7
|
|
|
8
8
|
from inspect import Parameter, signature
|
|
9
|
-
from typing import TYPE_CHECKING, Any, TypedDict, cast, get_origin
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
|
|
10
10
|
|
|
11
11
|
from pydantic import ConfigDict, TypeAdapter
|
|
12
12
|
from pydantic._internal import _decorators, _generate_schema, _typing_extra
|
|
@@ -20,8 +20,7 @@ from ._griffe import doc_descriptions
|
|
|
20
20
|
from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
|
|
21
21
|
|
|
22
22
|
if TYPE_CHECKING:
|
|
23
|
-
|
|
24
|
-
from .dependencies import AgentDeps, ToolParams
|
|
23
|
+
pass
|
|
25
24
|
|
|
26
25
|
|
|
27
26
|
__all__ = 'function_schema', 'LazyTypeAdapter'
|
|
@@ -39,17 +38,16 @@ class FunctionSchema(TypedDict):
|
|
|
39
38
|
var_positional_field: str | None
|
|
40
39
|
|
|
41
40
|
|
|
42
|
-
def function_schema(
|
|
41
|
+
def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSchema: # noqa: C901
|
|
43
42
|
"""Build a Pydantic validator and JSON schema from a tool function.
|
|
44
43
|
|
|
45
44
|
Args:
|
|
46
|
-
|
|
45
|
+
function: The function to build a validator and JSON schema for.
|
|
46
|
+
takes_ctx: Whether the function takes a `RunContext` first argument.
|
|
47
47
|
|
|
48
48
|
Returns:
|
|
49
49
|
A `FunctionSchema` instance.
|
|
50
50
|
"""
|
|
51
|
-
function = either_function.whichever()
|
|
52
|
-
takes_ctx = either_function.is_left()
|
|
53
51
|
config = ConfigDict(title=function.__name__)
|
|
54
52
|
config_wrapper = ConfigWrapper(config)
|
|
55
53
|
gen_schema = _generate_schema.GenerateSchema(config_wrapper)
|
|
@@ -78,13 +76,13 @@ def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]
|
|
|
78
76
|
|
|
79
77
|
if index == 0 and takes_ctx:
|
|
80
78
|
if not _is_call_ctx(annotation):
|
|
81
|
-
errors.append('First
|
|
79
|
+
errors.append('First parameter of tools that take context must be annotated with RunContext[...]')
|
|
82
80
|
continue
|
|
83
81
|
elif not takes_ctx and _is_call_ctx(annotation):
|
|
84
|
-
errors.append('RunContext
|
|
82
|
+
errors.append('RunContext annotations can only be used with tools that take context')
|
|
85
83
|
continue
|
|
86
84
|
elif index != 0 and _is_call_ctx(annotation):
|
|
87
|
-
errors.append('RunContext
|
|
85
|
+
errors.append('RunContext annotations can only be used as the first argument')
|
|
88
86
|
continue
|
|
89
87
|
|
|
90
88
|
field_name = p.name
|
|
@@ -159,6 +157,24 @@ def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]
|
|
|
159
157
|
)
|
|
160
158
|
|
|
161
159
|
|
|
160
|
+
def takes_ctx(function: Callable[..., Any]) -> bool:
|
|
161
|
+
"""Check if a function takes a `RunContext` first argument.
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
function: The function to check.
|
|
165
|
+
|
|
166
|
+
Returns:
|
|
167
|
+
`True` if the function takes a `RunContext` as first argument, `False` otherwise.
|
|
168
|
+
"""
|
|
169
|
+
sig = signature(function)
|
|
170
|
+
try:
|
|
171
|
+
_, first_param = next(iter(sig.parameters.items()))
|
|
172
|
+
except StopIteration:
|
|
173
|
+
return False
|
|
174
|
+
else:
|
|
175
|
+
return first_param.annotation is not sig.empty and _is_call_ctx(first_param.annotation)
|
|
176
|
+
|
|
177
|
+
|
|
162
178
|
def _build_schema(
|
|
163
179
|
fields: dict[str, core_schema.TypedDictField],
|
|
164
180
|
var_kwargs_schema: core_schema.CoreSchema | None,
|
|
@@ -191,7 +207,7 @@ def _build_schema(
|
|
|
191
207
|
|
|
192
208
|
|
|
193
209
|
def _is_call_ctx(annotation: Any) -> bool:
|
|
194
|
-
from .
|
|
210
|
+
from .tools import RunContext
|
|
195
211
|
|
|
196
212
|
return annotation is RunContext or (
|
|
197
213
|
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
|
pydantic_ai/_result.py
CHANGED
|
@@ -11,10 +11,10 @@ from pydantic import TypeAdapter, ValidationError
|
|
|
11
11
|
from typing_extensions import Self, TypeAliasType, TypedDict
|
|
12
12
|
|
|
13
13
|
from . import _utils, messages
|
|
14
|
-
from .dependencies import AgentDeps, ResultValidatorFunc, RunContext
|
|
15
14
|
from .exceptions import ModelRetry
|
|
16
15
|
from .messages import ModelStructuredResponse, ToolCall
|
|
17
16
|
from .result import ResultData
|
|
17
|
+
from .tools import AgentDeps, ResultValidatorFunc, RunContext
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@dataclass
|
pydantic_ai/_system_prompt.py
CHANGED
pydantic_ai/agent.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import dataclasses
|
|
5
|
+
import inspect
|
|
4
6
|
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
|
5
7
|
from contextlib import asynccontextmanager, contextmanager
|
|
6
8
|
from dataclasses import dataclass, field
|
|
9
|
+
from types import FrameType
|
|
7
10
|
from typing import Any, Callable, Generic, cast, final, overload
|
|
8
11
|
|
|
9
12
|
import logfire_api
|
|
@@ -12,15 +15,14 @@ from typing_extensions import assert_never
|
|
|
12
15
|
from . import (
|
|
13
16
|
_result,
|
|
14
17
|
_system_prompt,
|
|
15
|
-
_tool as _r,
|
|
16
18
|
_utils,
|
|
17
19
|
exceptions,
|
|
18
20
|
messages as _messages,
|
|
19
21
|
models,
|
|
20
22
|
result,
|
|
21
23
|
)
|
|
22
|
-
from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
|
|
23
24
|
from .result import ResultData
|
|
25
|
+
from .tools import AgentDeps, RunContext, Tool, ToolFuncContext, ToolFuncEither, ToolFuncPlain, ToolParams
|
|
24
26
|
|
|
25
27
|
__all__ = ('Agent',)
|
|
26
28
|
|
|
@@ -34,7 +36,7 @@ NoneType = type(None)
|
|
|
34
36
|
class Agent(Generic[AgentDeps, ResultData]):
|
|
35
37
|
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
|
|
36
38
|
|
|
37
|
-
Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.
|
|
39
|
+
Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.tools.AgentDeps]
|
|
38
40
|
and the result data type they return, [`ResultData`][pydantic_ai.result.ResultData].
|
|
39
41
|
|
|
40
42
|
By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
|
|
@@ -54,11 +56,16 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
54
56
|
# dataclass fields mostly for my sanity — knowing what attributes are available
|
|
55
57
|
model: models.Model | models.KnownModelName | None
|
|
56
58
|
"""The default model configured for this agent."""
|
|
59
|
+
name: str | None
|
|
60
|
+
"""The name of the agent, used for logging.
|
|
61
|
+
|
|
62
|
+
If `None`, we try to infer the agent name from the call frame when the agent is first run.
|
|
63
|
+
"""
|
|
57
64
|
_result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
|
|
58
65
|
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
|
|
59
66
|
_allow_text_result: bool = field(repr=False)
|
|
60
67
|
_system_prompts: tuple[str, ...] = field(repr=False)
|
|
61
|
-
_function_tools: dict[str,
|
|
68
|
+
_function_tools: dict[str, Tool[AgentDeps]] = field(repr=False)
|
|
62
69
|
_default_retries: int = field(repr=False)
|
|
63
70
|
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
|
|
64
71
|
_deps_type: type[AgentDeps] = field(repr=False)
|
|
@@ -79,10 +86,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
79
86
|
result_type: type[ResultData] = str,
|
|
80
87
|
system_prompt: str | Sequence[str] = (),
|
|
81
88
|
deps_type: type[AgentDeps] = NoneType,
|
|
89
|
+
name: str | None = None,
|
|
82
90
|
retries: int = 1,
|
|
83
91
|
result_tool_name: str = 'final_result',
|
|
84
92
|
result_tool_description: str | None = None,
|
|
85
93
|
result_retries: int | None = None,
|
|
94
|
+
tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (),
|
|
86
95
|
defer_model_check: bool = False,
|
|
87
96
|
):
|
|
88
97
|
"""Create an agent.
|
|
@@ -97,10 +106,14 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
97
106
|
parameterize the agent, and therefore get the best out of static type checking.
|
|
98
107
|
If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
|
|
99
108
|
or add a type hint `: Agent[None, <return type>]`.
|
|
109
|
+
name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
|
|
110
|
+
when the agent is first run.
|
|
100
111
|
retries: The default number of retries to allow before raising an error.
|
|
101
112
|
result_tool_name: The name of the tool to use for the final result.
|
|
102
113
|
result_tool_description: The description of the final result tool.
|
|
103
114
|
result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
|
|
115
|
+
tools: Tools to register with the agent, you can also register tools via the decorators
|
|
116
|
+
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
|
|
104
117
|
defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
|
|
105
118
|
it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
|
|
106
119
|
which checks for the necessary environment variables. Set this to `false`
|
|
@@ -112,6 +125,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
112
125
|
else:
|
|
113
126
|
self.model = models.infer_model(model)
|
|
114
127
|
|
|
128
|
+
self.name = name
|
|
115
129
|
self._result_schema = _result.ResultSchema[result_type].build(
|
|
116
130
|
result_type, result_tool_name, result_tool_description
|
|
117
131
|
)
|
|
@@ -119,9 +133,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
119
133
|
self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result
|
|
120
134
|
|
|
121
135
|
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
|
|
122
|
-
self._function_tools
|
|
123
|
-
self._deps_type = deps_type
|
|
136
|
+
self._function_tools = {}
|
|
124
137
|
self._default_retries = retries
|
|
138
|
+
for tool in tools:
|
|
139
|
+
self._register_tool(Tool.infer(tool))
|
|
140
|
+
self._deps_type = deps_type
|
|
125
141
|
self._system_prompt_functions = []
|
|
126
142
|
self._max_result_retries = result_retries if result_retries is not None else retries
|
|
127
143
|
self._current_result_retry = 0
|
|
@@ -134,6 +150,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
134
150
|
message_history: list[_messages.Message] | None = None,
|
|
135
151
|
model: models.Model | models.KnownModelName | None = None,
|
|
136
152
|
deps: AgentDeps = None,
|
|
153
|
+
infer_name: bool = True,
|
|
137
154
|
) -> result.RunResult[ResultData]:
|
|
138
155
|
"""Run the agent with a user prompt in async mode.
|
|
139
156
|
|
|
@@ -142,16 +159,19 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
142
159
|
message_history: History of the conversation so far.
|
|
143
160
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
144
161
|
deps: Optional dependencies to use for this run.
|
|
162
|
+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
145
163
|
|
|
146
164
|
Returns:
|
|
147
165
|
The result of the run.
|
|
148
166
|
"""
|
|
167
|
+
if infer_name and self.name is None:
|
|
168
|
+
self._infer_name(inspect.currentframe())
|
|
149
169
|
model_used, custom_model, agent_model = await self._get_agent_model(model)
|
|
150
170
|
|
|
151
171
|
deps = self._get_deps(deps)
|
|
152
172
|
|
|
153
173
|
with _logfire.span(
|
|
154
|
-
'agent run {prompt=}',
|
|
174
|
+
'{agent.name} run {prompt=}',
|
|
155
175
|
prompt=user_prompt,
|
|
156
176
|
agent=self,
|
|
157
177
|
custom_model=custom_model,
|
|
@@ -203,21 +223,28 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
203
223
|
message_history: list[_messages.Message] | None = None,
|
|
204
224
|
model: models.Model | models.KnownModelName | None = None,
|
|
205
225
|
deps: AgentDeps = None,
|
|
226
|
+
infer_name: bool = True,
|
|
206
227
|
) -> result.RunResult[ResultData]:
|
|
207
228
|
"""Run the agent with a user prompt synchronously.
|
|
208
229
|
|
|
209
|
-
This is a convenience method that wraps `self.run` with `
|
|
230
|
+
This is a convenience method that wraps `self.run` with `loop.run_until_complete()`.
|
|
210
231
|
|
|
211
232
|
Args:
|
|
212
233
|
user_prompt: User input to start/continue the conversation.
|
|
213
234
|
message_history: History of the conversation so far.
|
|
214
235
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
215
236
|
deps: Optional dependencies to use for this run.
|
|
237
|
+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
216
238
|
|
|
217
239
|
Returns:
|
|
218
240
|
The result of the run.
|
|
219
241
|
"""
|
|
220
|
-
|
|
242
|
+
if infer_name and self.name is None:
|
|
243
|
+
self._infer_name(inspect.currentframe())
|
|
244
|
+
loop = asyncio.get_event_loop()
|
|
245
|
+
return loop.run_until_complete(
|
|
246
|
+
self.run(user_prompt, message_history=message_history, model=model, deps=deps, infer_name=False)
|
|
247
|
+
)
|
|
221
248
|
|
|
222
249
|
@asynccontextmanager
|
|
223
250
|
async def run_stream(
|
|
@@ -227,6 +254,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
227
254
|
message_history: list[_messages.Message] | None = None,
|
|
228
255
|
model: models.Model | models.KnownModelName | None = None,
|
|
229
256
|
deps: AgentDeps = None,
|
|
257
|
+
infer_name: bool = True,
|
|
230
258
|
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
|
|
231
259
|
"""Run the agent with a user prompt in async mode, returning a streamed response.
|
|
232
260
|
|
|
@@ -235,16 +263,21 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
235
263
|
message_history: History of the conversation so far.
|
|
236
264
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
237
265
|
deps: Optional dependencies to use for this run.
|
|
266
|
+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
238
267
|
|
|
239
268
|
Returns:
|
|
240
269
|
The result of the run.
|
|
241
270
|
"""
|
|
271
|
+
if infer_name and self.name is None:
|
|
272
|
+
# f_back because `asynccontextmanager` adds one frame
|
|
273
|
+
if frame := inspect.currentframe(): # pragma: no branch
|
|
274
|
+
self._infer_name(frame.f_back)
|
|
242
275
|
model_used, custom_model, agent_model = await self._get_agent_model(model)
|
|
243
276
|
|
|
244
277
|
deps = self._get_deps(deps)
|
|
245
278
|
|
|
246
279
|
with _logfire.span(
|
|
247
|
-
'agent run stream {prompt=}',
|
|
280
|
+
'{agent.name} run stream {prompt=}',
|
|
248
281
|
prompt=user_prompt,
|
|
249
282
|
agent=self,
|
|
250
283
|
custom_model=custom_model,
|
|
@@ -354,7 +387,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
354
387
|
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
|
|
355
388
|
"""Decorator to register a system prompt function.
|
|
356
389
|
|
|
357
|
-
Optionally takes [`RunContext`][pydantic_ai.
|
|
390
|
+
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as it's only argument.
|
|
358
391
|
Can decorate a sync or async functions.
|
|
359
392
|
|
|
360
393
|
Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
|
|
@@ -405,7 +438,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
405
438
|
) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
|
|
406
439
|
"""Decorator to register a result validator function.
|
|
407
440
|
|
|
408
|
-
Optionally takes [`RunContext`][pydantic_ai.
|
|
441
|
+
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as it's first argument.
|
|
409
442
|
Can decorate a sync or async functions.
|
|
410
443
|
|
|
411
444
|
Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
|
|
@@ -438,22 +471,22 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
438
471
|
return func
|
|
439
472
|
|
|
440
473
|
@overload
|
|
441
|
-
def tool(self, func:
|
|
474
|
+
def tool(self, func: ToolFuncContext[AgentDeps, ToolParams], /) -> ToolFuncContext[AgentDeps, ToolParams]: ...
|
|
442
475
|
|
|
443
476
|
@overload
|
|
444
477
|
def tool(
|
|
445
478
|
self, /, *, retries: int | None = None
|
|
446
|
-
) -> Callable[[
|
|
479
|
+
) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
|
|
447
480
|
|
|
448
481
|
def tool(
|
|
449
482
|
self,
|
|
450
|
-
func:
|
|
483
|
+
func: ToolFuncContext[AgentDeps, ToolParams] | None = None,
|
|
451
484
|
/,
|
|
452
485
|
*,
|
|
453
486
|
retries: int | None = None,
|
|
454
487
|
) -> Any:
|
|
455
488
|
"""Decorator to register a tool function which takes
|
|
456
|
-
[`RunContext`][pydantic_ai.
|
|
489
|
+
[`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
457
490
|
|
|
458
491
|
Can decorate a sync or async functions.
|
|
459
492
|
|
|
@@ -490,27 +523,27 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
490
523
|
if func is None:
|
|
491
524
|
|
|
492
525
|
def tool_decorator(
|
|
493
|
-
func_:
|
|
494
|
-
) ->
|
|
526
|
+
func_: ToolFuncContext[AgentDeps, ToolParams],
|
|
527
|
+
) -> ToolFuncContext[AgentDeps, ToolParams]:
|
|
495
528
|
# noinspection PyTypeChecker
|
|
496
|
-
self.
|
|
529
|
+
self._register_function(func_, True, retries)
|
|
497
530
|
return func_
|
|
498
531
|
|
|
499
532
|
return tool_decorator
|
|
500
533
|
else:
|
|
501
534
|
# noinspection PyTypeChecker
|
|
502
|
-
self.
|
|
535
|
+
self._register_function(func, True, retries)
|
|
503
536
|
return func
|
|
504
537
|
|
|
505
538
|
@overload
|
|
506
|
-
def tool_plain(self, func:
|
|
539
|
+
def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ...
|
|
507
540
|
|
|
508
541
|
@overload
|
|
509
542
|
def tool_plain(
|
|
510
543
|
self, /, *, retries: int | None = None
|
|
511
|
-
) -> Callable[[
|
|
544
|
+
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
|
|
512
545
|
|
|
513
|
-
def tool_plain(self, func:
|
|
546
|
+
def tool_plain(self, func: ToolFuncPlain[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
|
|
514
547
|
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
|
|
515
548
|
|
|
516
549
|
Can decorate a sync or async functions.
|
|
@@ -547,28 +580,34 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
547
580
|
"""
|
|
548
581
|
if func is None:
|
|
549
582
|
|
|
550
|
-
def tool_decorator(
|
|
551
|
-
func_: ToolPlainFunc[ToolParams],
|
|
552
|
-
) -> ToolPlainFunc[ToolParams]:
|
|
583
|
+
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
|
|
553
584
|
# noinspection PyTypeChecker
|
|
554
|
-
self.
|
|
585
|
+
self._register_function(func_, False, retries)
|
|
555
586
|
return func_
|
|
556
587
|
|
|
557
588
|
return tool_decorator
|
|
558
589
|
else:
|
|
559
|
-
self.
|
|
590
|
+
self._register_function(func, False, retries)
|
|
560
591
|
return func
|
|
561
592
|
|
|
562
|
-
def
|
|
563
|
-
|
|
593
|
+
def _register_function(
|
|
594
|
+
self, func: ToolFuncEither[AgentDeps, ToolParams], takes_ctx: bool, retries: int | None
|
|
595
|
+
) -> None:
|
|
596
|
+
"""Private utility to register a function as a tool."""
|
|
564
597
|
retries_ = retries if retries is not None else self._default_retries
|
|
565
|
-
tool =
|
|
598
|
+
tool = Tool(func, takes_ctx, max_retries=retries_)
|
|
599
|
+
self._register_tool(tool)
|
|
566
600
|
|
|
567
|
-
|
|
568
|
-
|
|
601
|
+
def _register_tool(self, tool: Tool[AgentDeps]) -> None:
|
|
602
|
+
"""Private utility to register a tool instance."""
|
|
603
|
+
if tool.max_retries is None:
|
|
604
|
+
tool = dataclasses.replace(tool, max_retries=self._default_retries)
|
|
569
605
|
|
|
570
606
|
if tool.name in self._function_tools:
|
|
571
|
-
raise
|
|
607
|
+
raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
|
|
608
|
+
|
|
609
|
+
if self._result_schema and tool.name in self._result_schema.tools:
|
|
610
|
+
raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')
|
|
572
611
|
|
|
573
612
|
self._function_tools[tool.name] = tool
|
|
574
613
|
|
|
@@ -786,6 +825,19 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
786
825
|
else:
|
|
787
826
|
return deps
|
|
788
827
|
|
|
828
|
+
def _infer_name(self, function_frame: FrameType | None) -> None:
|
|
829
|
+
"""Infer the agent name from the call frame.
|
|
830
|
+
|
|
831
|
+
Usage should be `self._infer_name(inspect.currentframe())`.
|
|
832
|
+
"""
|
|
833
|
+
assert self.name is None, 'Name already set'
|
|
834
|
+
if function_frame is not None: # pragma: no branch
|
|
835
|
+
if parent_frame := function_frame.f_back: # pragma: no branch
|
|
836
|
+
for name, item in parent_frame.f_locals.items():
|
|
837
|
+
if item is self:
|
|
838
|
+
self.name = name
|
|
839
|
+
return
|
|
840
|
+
|
|
789
841
|
|
|
790
842
|
@dataclass
|
|
791
843
|
class _MarkFinalResult(Generic[ResultData]):
|
pydantic_ai/messages.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from collections.abc import Mapping, Sequence
|
|
5
4
|
from dataclasses import dataclass, field
|
|
6
5
|
from datetime import datetime
|
|
7
|
-
from typing import
|
|
6
|
+
from typing import Annotated, Any, Literal, Union
|
|
8
7
|
|
|
9
8
|
import pydantic
|
|
10
9
|
import pydantic_core
|
|
11
10
|
from pydantic import TypeAdapter
|
|
12
|
-
from typing_extensions import TypeAlias, TypeAliasType
|
|
13
11
|
|
|
14
12
|
from . import _pydantic
|
|
15
13
|
from ._utils import now_utc as _now_utc
|
|
@@ -44,13 +42,7 @@ class UserPrompt:
|
|
|
44
42
|
"""Message type identifier, this type is available on all message as a discriminator."""
|
|
45
43
|
|
|
46
44
|
|
|
47
|
-
|
|
48
|
-
if not TYPE_CHECKING:
|
|
49
|
-
# work around for https://github.com/pydantic/pydantic/issues/10873
|
|
50
|
-
# this is need for pydantic to work both `json_ta` and `MessagesTypeAdapter` at the bottom of this file
|
|
51
|
-
JsonData = TypeAliasType('JsonData', 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]')
|
|
52
|
-
|
|
53
|
-
json_ta: TypeAdapter[JsonData] = TypeAdapter(JsonData)
|
|
45
|
+
tool_return_ta: TypeAdapter[Any] = TypeAdapter(Any)
|
|
54
46
|
|
|
55
47
|
|
|
56
48
|
@dataclass
|
|
@@ -59,7 +51,7 @@ class ToolReturn:
|
|
|
59
51
|
|
|
60
52
|
tool_name: str
|
|
61
53
|
"""The name of the "tool" was called."""
|
|
62
|
-
content:
|
|
54
|
+
content: Any
|
|
63
55
|
"""The return value."""
|
|
64
56
|
tool_id: str | None = None
|
|
65
57
|
"""Optional tool identifier, this is used by some models including OpenAI."""
|
|
@@ -72,15 +64,14 @@ class ToolReturn:
|
|
|
72
64
|
if isinstance(self.content, str):
|
|
73
65
|
return self.content
|
|
74
66
|
else:
|
|
75
|
-
|
|
76
|
-
return json_ta.dump_json(content).decode()
|
|
67
|
+
return tool_return_ta.dump_json(self.content).decode()
|
|
77
68
|
|
|
78
|
-
def model_response_object(self) -> dict[str,
|
|
69
|
+
def model_response_object(self) -> dict[str, Any]:
|
|
79
70
|
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
|
|
80
71
|
if isinstance(self.content, dict):
|
|
81
|
-
return
|
|
72
|
+
return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
|
|
82
73
|
else:
|
|
83
|
-
return {'return_value':
|
|
74
|
+
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
|
|
84
75
|
|
|
85
76
|
|
|
86
77
|
@dataclass
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -259,20 +259,30 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
259
259
|
class AbstractToolDefinition(Protocol):
|
|
260
260
|
"""Abstract definition of a function/tool.
|
|
261
261
|
|
|
262
|
-
This is used for both tools and result tools.
|
|
262
|
+
This is used for both function tools and result tools.
|
|
263
263
|
"""
|
|
264
264
|
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
json_schema: ObjectJsonSchema
|
|
270
|
-
"""The JSON schema for the tool's arguments."""
|
|
271
|
-
outer_typed_dict_key: str | None
|
|
272
|
-
"""The key in the outer [TypedDict] that wraps a result tool.
|
|
265
|
+
@property
|
|
266
|
+
def name(self) -> str:
|
|
267
|
+
"""The name of the tool."""
|
|
268
|
+
...
|
|
273
269
|
|
|
274
|
-
|
|
275
|
-
|
|
270
|
+
@property
|
|
271
|
+
def description(self) -> str:
|
|
272
|
+
"""The description of the tool."""
|
|
273
|
+
...
|
|
274
|
+
|
|
275
|
+
@property
|
|
276
|
+
def json_schema(self) -> ObjectJsonSchema:
|
|
277
|
+
"""The JSON schema for the tool's arguments."""
|
|
278
|
+
...
|
|
279
|
+
|
|
280
|
+
@property
|
|
281
|
+
def outer_typed_dict_key(self) -> str | None:
|
|
282
|
+
"""The key in the outer [TypedDict] that wraps a result tool.
|
|
283
|
+
|
|
284
|
+
This will only be set for result tools which don't have an `object` JSON schema.
|
|
285
|
+
"""
|
|
276
286
|
|
|
277
287
|
|
|
278
288
|
@cache
|
pydantic_ai/result.py
CHANGED
pydantic_ai/tools.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import Awaitable
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
|
|
7
|
+
|
|
8
|
+
from pydantic import ValidationError
|
|
9
|
+
from pydantic_core import SchemaValidator
|
|
10
|
+
from typing_extensions import Concatenate, ParamSpec, final
|
|
11
|
+
|
|
12
|
+
from . import _pydantic, _utils, messages
|
|
13
|
+
from .exceptions import ModelRetry, UnexpectedModelBehavior
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from .result import ResultData
|
|
17
|
+
else:
|
|
18
|
+
ResultData = Any
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
__all__ = (
|
|
22
|
+
'AgentDeps',
|
|
23
|
+
'RunContext',
|
|
24
|
+
'ResultValidatorFunc',
|
|
25
|
+
'SystemPromptFunc',
|
|
26
|
+
'ToolFuncContext',
|
|
27
|
+
'ToolFuncPlain',
|
|
28
|
+
'ToolFuncEither',
|
|
29
|
+
'ToolParams',
|
|
30
|
+
'Tool',
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
AgentDeps = TypeVar('AgentDeps')
|
|
34
|
+
"""Type variable for agent dependencies."""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
|
|
38
|
+
class RunContext(Generic[AgentDeps]):
|
|
39
|
+
"""Information about the current call."""
|
|
40
|
+
|
|
41
|
+
deps: AgentDeps
|
|
42
|
+
"""Dependencies for the agent."""
|
|
43
|
+
retry: int
|
|
44
|
+
"""Number of retries so far."""
|
|
45
|
+
tool_name: str | None
|
|
46
|
+
"""Name of the tool being called."""
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
ToolParams = ParamSpec('ToolParams')
|
|
50
|
+
"""Retrieval function param spec."""
|
|
51
|
+
|
|
52
|
+
SystemPromptFunc = Union[
|
|
53
|
+
Callable[[RunContext[AgentDeps]], str],
|
|
54
|
+
Callable[[RunContext[AgentDeps]], Awaitable[str]],
|
|
55
|
+
Callable[[], str],
|
|
56
|
+
Callable[[], Awaitable[str]],
|
|
57
|
+
]
|
|
58
|
+
"""A function that may or maybe not take `RunContext` as an argument, and may or may not be async.
|
|
59
|
+
|
|
60
|
+
Usage `SystemPromptFunc[AgentDeps]`.
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
ResultValidatorFunc = Union[
|
|
64
|
+
Callable[[RunContext[AgentDeps], ResultData], ResultData],
|
|
65
|
+
Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
|
|
66
|
+
Callable[[ResultData], ResultData],
|
|
67
|
+
Callable[[ResultData], Awaitable[ResultData]],
|
|
68
|
+
]
|
|
69
|
+
"""
|
|
70
|
+
A function that always takes `ResultData` and returns `ResultData`,
|
|
71
|
+
but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
|
|
72
|
+
|
|
73
|
+
Usage `ResultValidator[AgentDeps, ResultData]`.
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any]
|
|
77
|
+
"""A tool function that takes `RunContext` as the first argument.
|
|
78
|
+
|
|
79
|
+
Usage `ToolContextFunc[AgentDeps, ToolParams]`.
|
|
80
|
+
"""
|
|
81
|
+
ToolFuncPlain = Callable[ToolParams, Any]
|
|
82
|
+
"""A tool function that does not take `RunContext` as the first argument.
|
|
83
|
+
|
|
84
|
+
Usage `ToolPlainFunc[ToolParams]`.
|
|
85
|
+
"""
|
|
86
|
+
ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
|
|
87
|
+
"""Either kind of tool function.
|
|
88
|
+
|
|
89
|
+
This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
|
|
90
|
+
[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
|
|
91
|
+
|
|
92
|
+
Usage `ToolFuncEither[AgentDeps, ToolParams]`.
|
|
93
|
+
"""
|
|
94
|
+
|
|
95
|
+
A = TypeVar('A')
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@final
|
|
99
|
+
@dataclass(init=False)
|
|
100
|
+
class Tool(Generic[AgentDeps]):
|
|
101
|
+
"""A tool function for an agent."""
|
|
102
|
+
|
|
103
|
+
function: ToolFuncEither[AgentDeps, ...]
|
|
104
|
+
takes_ctx: bool
|
|
105
|
+
max_retries: int | None
|
|
106
|
+
name: str
|
|
107
|
+
description: str
|
|
108
|
+
_is_async: bool = field(init=False)
|
|
109
|
+
_single_arg_name: str | None = field(init=False)
|
|
110
|
+
_positional_fields: list[str] = field(init=False)
|
|
111
|
+
_var_positional_field: str | None = field(init=False)
|
|
112
|
+
_validator: SchemaValidator = field(init=False, repr=False)
|
|
113
|
+
_json_schema: _utils.ObjectJsonSchema = field(init=False)
|
|
114
|
+
_current_retry: int = field(default=0, init=False)
|
|
115
|
+
|
|
116
|
+
def __init__(
|
|
117
|
+
self,
|
|
118
|
+
function: ToolFuncEither[AgentDeps, ...],
|
|
119
|
+
takes_ctx: bool,
|
|
120
|
+
*,
|
|
121
|
+
max_retries: int | None = None,
|
|
122
|
+
name: str | None = None,
|
|
123
|
+
description: str | None = None,
|
|
124
|
+
):
|
|
125
|
+
"""Create a new tool instance.
|
|
126
|
+
|
|
127
|
+
Example usage:
|
|
128
|
+
|
|
129
|
+
```py
|
|
130
|
+
from pydantic_ai import Agent, RunContext, Tool
|
|
131
|
+
|
|
132
|
+
async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
|
|
133
|
+
return f'{ctx.deps} {x} {y}'
|
|
134
|
+
|
|
135
|
+
agent = Agent('test', tools=[Tool(my_tool, True)])
|
|
136
|
+
```
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
function: The Python function to call as the tool.
|
|
140
|
+
takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument.
|
|
141
|
+
max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
|
|
142
|
+
name: Name of the tool, inferred from the function if `None`.
|
|
143
|
+
description: Description of the tool, inferred from the function if `None`.
|
|
144
|
+
"""
|
|
145
|
+
f = _pydantic.function_schema(function, takes_ctx)
|
|
146
|
+
self.function = function
|
|
147
|
+
self.takes_ctx = takes_ctx
|
|
148
|
+
self.max_retries = max_retries
|
|
149
|
+
self.name = name or function.__name__
|
|
150
|
+
self.description = description or f['description']
|
|
151
|
+
self._is_async = inspect.iscoroutinefunction(self.function)
|
|
152
|
+
self._single_arg_name = f['single_arg_name']
|
|
153
|
+
self._positional_fields = f['positional_fields']
|
|
154
|
+
self._var_positional_field = f['var_positional_field']
|
|
155
|
+
self._validator = f['validator']
|
|
156
|
+
self._json_schema = f['json_schema']
|
|
157
|
+
|
|
158
|
+
@staticmethod
|
|
159
|
+
def infer(function: ToolFuncEither[A, ...] | Tool[A]) -> Tool[A]:
|
|
160
|
+
"""Create a tool from a pure function, inferring whether it takes `RunContext` as its first argument.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
function: The tool function to wrap; or for convenience, a `Tool` instance.
|
|
164
|
+
|
|
165
|
+
Returns:
|
|
166
|
+
A new `Tool` instance.
|
|
167
|
+
"""
|
|
168
|
+
if isinstance(function, Tool):
|
|
169
|
+
return function
|
|
170
|
+
else:
|
|
171
|
+
return Tool(function, takes_ctx=_pydantic.takes_ctx(function))
|
|
172
|
+
|
|
173
|
+
def reset(self) -> None:
|
|
174
|
+
"""Reset the current retry count."""
|
|
175
|
+
self._current_retry = 0
|
|
176
|
+
|
|
177
|
+
async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
|
|
178
|
+
"""Run the tool function asynchronously."""
|
|
179
|
+
try:
|
|
180
|
+
if isinstance(message.args, messages.ArgsJson):
|
|
181
|
+
args_dict = self._validator.validate_json(message.args.args_json)
|
|
182
|
+
else:
|
|
183
|
+
args_dict = self._validator.validate_python(message.args.args_dict)
|
|
184
|
+
except ValidationError as e:
|
|
185
|
+
return self._on_error(e, message)
|
|
186
|
+
|
|
187
|
+
args, kwargs = self._call_args(deps, args_dict, message)
|
|
188
|
+
try:
|
|
189
|
+
if self._is_async:
|
|
190
|
+
function = cast(Callable[[Any], Awaitable[str]], self.function)
|
|
191
|
+
response_content = await function(*args, **kwargs)
|
|
192
|
+
else:
|
|
193
|
+
function = cast(Callable[[Any], str], self.function)
|
|
194
|
+
response_content = await _utils.run_in_executor(function, *args, **kwargs)
|
|
195
|
+
except ModelRetry as e:
|
|
196
|
+
return self._on_error(e, message)
|
|
197
|
+
|
|
198
|
+
self._current_retry = 0
|
|
199
|
+
return messages.ToolReturn(
|
|
200
|
+
tool_name=message.tool_name,
|
|
201
|
+
content=response_content,
|
|
202
|
+
tool_id=message.tool_id,
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
@property
|
|
206
|
+
def json_schema(self) -> _utils.ObjectJsonSchema:
|
|
207
|
+
return self._json_schema
|
|
208
|
+
|
|
209
|
+
@property
|
|
210
|
+
def outer_typed_dict_key(self) -> str | None:
|
|
211
|
+
return None
|
|
212
|
+
|
|
213
|
+
def _call_args(
|
|
214
|
+
self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
|
|
215
|
+
) -> tuple[list[Any], dict[str, Any]]:
|
|
216
|
+
if self._single_arg_name:
|
|
217
|
+
args_dict = {self._single_arg_name: args_dict}
|
|
218
|
+
|
|
219
|
+
args = [RunContext(deps, self._current_retry, message.tool_name)] if self.takes_ctx else []
|
|
220
|
+
for positional_field in self._positional_fields:
|
|
221
|
+
args.append(args_dict.pop(positional_field))
|
|
222
|
+
if self._var_positional_field:
|
|
223
|
+
args.extend(args_dict.pop(self._var_positional_field))
|
|
224
|
+
|
|
225
|
+
return args, args_dict
|
|
226
|
+
|
|
227
|
+
def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
|
|
228
|
+
self._current_retry += 1
|
|
229
|
+
if self.max_retries is None or self._current_retry > self.max_retries:
|
|
230
|
+
raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
|
|
231
|
+
else:
|
|
232
|
+
if isinstance(exc, ValidationError):
|
|
233
|
+
content = exc.errors(include_url=False)
|
|
234
|
+
else:
|
|
235
|
+
content = exc.message
|
|
236
|
+
return messages.RetryPrompt(
|
|
237
|
+
tool_name=call_message.tool_name,
|
|
238
|
+
content=content,
|
|
239
|
+
tool_id=call_message.tool_id,
|
|
240
|
+
)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
pydantic_ai/__init__.py,sha256=a29NqQz0JyW4BoCjcRh23fBBfwY17_n57moE4QrFWM4,324
|
|
2
|
+
pydantic_ai/_griffe.py,sha256=pRjCJ6B1hhx6k46XJgl9zF6aRYxRmqEZKFok8unp4Iw,3449
|
|
3
|
+
pydantic_ai/_pydantic.py,sha256=oFfcHDv_wuL1NQ7mCzVHvP1HBaVzyvb7xS-_Iiri_tA,8491
|
|
4
|
+
pydantic_ai/_result.py,sha256=wzcfwDpr_sro1Vkn3DkyIhCXMHTReDxL_ZYm50JzdRI,9667
|
|
5
|
+
pydantic_ai/_system_prompt.py,sha256=vFT0y9Wykl5veGMgLLkGRYiHQrgdW2BZ1rwMn4izjjo,1085
|
|
6
|
+
pydantic_ai/_utils.py,sha256=eNb7f3-ZQC8WDEa87iUcXGQ-lyuutFQG-5yBCMD4Vvs,8227
|
|
7
|
+
pydantic_ai/agent.py,sha256=dB2_JshYBjK04fmzJP79wmoKJwMuEBaLmjIRVdwrISM,36854
|
|
8
|
+
pydantic_ai/exceptions.py,sha256=ko_47M0k6Rhg9mUC9P1cj7N4LCH6cC0pEsF65A2vL-U,1561
|
|
9
|
+
pydantic_ai/messages.py,sha256=I0_CPXDIGGSy-PXHuKq540oAXYOO9uyylpsfSsE4vLs,7032
|
|
10
|
+
pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
pydantic_ai/result.py,sha256=UB3vFOcDAceeNLXh_f_3PKy2J0A6FIHdgDFJxPH6OJk,13651
|
|
12
|
+
pydantic_ai/tools.py,sha256=rzchmsEPYtUzBBFoeeJKC1ct36iqOxFOY6kfKqItCCs,8210
|
|
13
|
+
pydantic_ai/models/__init__.py,sha256=_Mz_32WGlAf4NlxXfdQ-EAaY_bDOk10gIc5HmTAO_ts,10318
|
|
14
|
+
pydantic_ai/models/function.py,sha256=Mzc-zXnb2RayWAA8N9NS7KGF49do1S-VW3U9fkc661o,10045
|
|
15
|
+
pydantic_ai/models/gemini.py,sha256=ruO4tnnpDDuHThg7jUOphs8I_KXBJH7gfDMluliED8E,26606
|
|
16
|
+
pydantic_ai/models/groq.py,sha256=Tx2yU3ysmPLBmWGsjzES-XcumzrsoBtB7spCnJBlLiM,14947
|
|
17
|
+
pydantic_ai/models/openai.py,sha256=5ihH25CrS0tnZNW-BZw4GyPe8V-IxIHWw3B9ulPVjQE,14931
|
|
18
|
+
pydantic_ai/models/test.py,sha256=q1wch_E7TSb4qx9PCcP1YyBGZx567MGlAQhlAlON0S8,14463
|
|
19
|
+
pydantic_ai/models/vertexai.py,sha256=5wI8y2YjeRgSE51uKy5OtevQkks65uEbxIUAs5EGBaI,9161
|
|
20
|
+
pydantic_ai_slim-0.0.10.dist-info/METADATA,sha256=2-sEVelPFDeLZwYC4o8ZgGvyXvJsVXZIkj3Mx3wXs6g,2562
|
|
21
|
+
pydantic_ai_slim-0.0.10.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
|
|
22
|
+
pydantic_ai_slim-0.0.10.dist-info/RECORD,,
|
pydantic_ai/_tool.py
DELETED
|
@@ -1,112 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations as _annotations
|
|
2
|
-
|
|
3
|
-
import inspect
|
|
4
|
-
from collections.abc import Awaitable
|
|
5
|
-
from dataclasses import dataclass, field
|
|
6
|
-
from typing import Any, Callable, Generic, cast
|
|
7
|
-
|
|
8
|
-
from pydantic import ValidationError
|
|
9
|
-
from pydantic_core import SchemaValidator
|
|
10
|
-
|
|
11
|
-
from . import _pydantic, _utils, messages
|
|
12
|
-
from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
|
|
13
|
-
from .exceptions import ModelRetry, UnexpectedModelBehavior
|
|
14
|
-
|
|
15
|
-
# Usage `ToolEitherFunc[AgentDependencies, P]`
|
|
16
|
-
ToolEitherFunc = _utils.Either[ToolContextFunc[AgentDeps, ToolParams], ToolPlainFunc[ToolParams]]
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@dataclass(init=False)
|
|
20
|
-
class Tool(Generic[AgentDeps, ToolParams]):
|
|
21
|
-
"""A tool function for an agent."""
|
|
22
|
-
|
|
23
|
-
name: str
|
|
24
|
-
description: str
|
|
25
|
-
function: ToolEitherFunc[AgentDeps, ToolParams] = field(repr=False)
|
|
26
|
-
is_async: bool
|
|
27
|
-
single_arg_name: str | None
|
|
28
|
-
positional_fields: list[str]
|
|
29
|
-
var_positional_field: str | None
|
|
30
|
-
validator: SchemaValidator = field(repr=False)
|
|
31
|
-
json_schema: _utils.ObjectJsonSchema
|
|
32
|
-
max_retries: int
|
|
33
|
-
_current_retry: int = 0
|
|
34
|
-
outer_typed_dict_key: str | None = None
|
|
35
|
-
|
|
36
|
-
def __init__(self, function: ToolEitherFunc[AgentDeps, ToolParams], retries: int):
|
|
37
|
-
"""Build a Tool dataclass from a function."""
|
|
38
|
-
self.function = function
|
|
39
|
-
# noinspection PyTypeChecker
|
|
40
|
-
f = _pydantic.function_schema(function)
|
|
41
|
-
raw_function = function.whichever()
|
|
42
|
-
self.name = raw_function.__name__
|
|
43
|
-
self.description = f['description']
|
|
44
|
-
self.is_async = inspect.iscoroutinefunction(raw_function)
|
|
45
|
-
self.single_arg_name = f['single_arg_name']
|
|
46
|
-
self.positional_fields = f['positional_fields']
|
|
47
|
-
self.var_positional_field = f['var_positional_field']
|
|
48
|
-
self.validator = f['validator']
|
|
49
|
-
self.json_schema = f['json_schema']
|
|
50
|
-
self.max_retries = retries
|
|
51
|
-
|
|
52
|
-
def reset(self) -> None:
|
|
53
|
-
"""Reset the current retry count."""
|
|
54
|
-
self._current_retry = 0
|
|
55
|
-
|
|
56
|
-
async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
|
|
57
|
-
"""Run the tool function asynchronously."""
|
|
58
|
-
try:
|
|
59
|
-
if isinstance(message.args, messages.ArgsJson):
|
|
60
|
-
args_dict = self.validator.validate_json(message.args.args_json)
|
|
61
|
-
else:
|
|
62
|
-
args_dict = self.validator.validate_python(message.args.args_dict)
|
|
63
|
-
except ValidationError as e:
|
|
64
|
-
return self._on_error(e, message)
|
|
65
|
-
|
|
66
|
-
args, kwargs = self._call_args(deps, args_dict, message)
|
|
67
|
-
try:
|
|
68
|
-
if self.is_async:
|
|
69
|
-
function = cast(Callable[[Any], Awaitable[str]], self.function.whichever())
|
|
70
|
-
response_content = await function(*args, **kwargs)
|
|
71
|
-
else:
|
|
72
|
-
function = cast(Callable[[Any], str], self.function.whichever())
|
|
73
|
-
response_content = await _utils.run_in_executor(function, *args, **kwargs)
|
|
74
|
-
except ModelRetry as e:
|
|
75
|
-
return self._on_error(e, message)
|
|
76
|
-
|
|
77
|
-
self._current_retry = 0
|
|
78
|
-
return messages.ToolReturn(
|
|
79
|
-
tool_name=message.tool_name,
|
|
80
|
-
content=response_content,
|
|
81
|
-
tool_id=message.tool_id,
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
def _call_args(
|
|
85
|
-
self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
|
|
86
|
-
) -> tuple[list[Any], dict[str, Any]]:
|
|
87
|
-
if self.single_arg_name:
|
|
88
|
-
args_dict = {self.single_arg_name: args_dict}
|
|
89
|
-
|
|
90
|
-
args = [RunContext(deps, self._current_retry, message.tool_name)] if self.function.is_left() else []
|
|
91
|
-
for positional_field in self.positional_fields:
|
|
92
|
-
args.append(args_dict.pop(positional_field))
|
|
93
|
-
if self.var_positional_field:
|
|
94
|
-
args.extend(args_dict.pop(self.var_positional_field))
|
|
95
|
-
|
|
96
|
-
return args, args_dict
|
|
97
|
-
|
|
98
|
-
def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
|
|
99
|
-
self._current_retry += 1
|
|
100
|
-
if self._current_retry > self.max_retries:
|
|
101
|
-
# TODO custom error with details of the tool
|
|
102
|
-
raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
|
|
103
|
-
else:
|
|
104
|
-
if isinstance(exc, ValidationError):
|
|
105
|
-
content = exc.errors(include_url=False)
|
|
106
|
-
else:
|
|
107
|
-
content = exc.message
|
|
108
|
-
return messages.RetryPrompt(
|
|
109
|
-
tool_name=call_message.tool_name,
|
|
110
|
-
content=content,
|
|
111
|
-
tool_id=call_message.tool_id,
|
|
112
|
-
)
|
pydantic_ai/dependencies.py
DELETED
|
@@ -1,83 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations as _annotations
|
|
2
|
-
|
|
3
|
-
from collections.abc import Awaitable, Mapping, Sequence
|
|
4
|
-
from dataclasses import dataclass
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
|
|
6
|
-
|
|
7
|
-
from typing_extensions import Concatenate, ParamSpec, TypeAlias
|
|
8
|
-
|
|
9
|
-
if TYPE_CHECKING:
|
|
10
|
-
from .result import ResultData
|
|
11
|
-
else:
|
|
12
|
-
ResultData = Any
|
|
13
|
-
|
|
14
|
-
__all__ = (
|
|
15
|
-
'AgentDeps',
|
|
16
|
-
'RunContext',
|
|
17
|
-
'ResultValidatorFunc',
|
|
18
|
-
'SystemPromptFunc',
|
|
19
|
-
'ToolReturnValue',
|
|
20
|
-
'ToolContextFunc',
|
|
21
|
-
'ToolPlainFunc',
|
|
22
|
-
'ToolParams',
|
|
23
|
-
'JsonData',
|
|
24
|
-
)
|
|
25
|
-
|
|
26
|
-
AgentDeps = TypeVar('AgentDeps')
|
|
27
|
-
"""Type variable for agent dependencies."""
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
@dataclass
|
|
31
|
-
class RunContext(Generic[AgentDeps]):
|
|
32
|
-
"""Information about the current call."""
|
|
33
|
-
|
|
34
|
-
deps: AgentDeps
|
|
35
|
-
"""Dependencies for the agent."""
|
|
36
|
-
retry: int
|
|
37
|
-
"""Number of retries so far."""
|
|
38
|
-
tool_name: str | None
|
|
39
|
-
"""Name of the tool being called."""
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
ToolParams = ParamSpec('ToolParams')
|
|
43
|
-
"""Retrieval function param spec."""
|
|
44
|
-
|
|
45
|
-
SystemPromptFunc = Union[
|
|
46
|
-
Callable[[RunContext[AgentDeps]], str],
|
|
47
|
-
Callable[[RunContext[AgentDeps]], Awaitable[str]],
|
|
48
|
-
Callable[[], str],
|
|
49
|
-
Callable[[], Awaitable[str]],
|
|
50
|
-
]
|
|
51
|
-
"""A function that may or maybe not take `RunContext` as an argument, and may or may not be async.
|
|
52
|
-
|
|
53
|
-
Usage `SystemPromptFunc[AgentDeps]`.
|
|
54
|
-
"""
|
|
55
|
-
|
|
56
|
-
ResultValidatorFunc = Union[
|
|
57
|
-
Callable[[RunContext[AgentDeps], ResultData], ResultData],
|
|
58
|
-
Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
|
|
59
|
-
Callable[[ResultData], ResultData],
|
|
60
|
-
Callable[[ResultData], Awaitable[ResultData]],
|
|
61
|
-
]
|
|
62
|
-
"""
|
|
63
|
-
A function that always takes `ResultData` and returns `ResultData`,
|
|
64
|
-
but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
|
|
65
|
-
|
|
66
|
-
Usage `ResultValidator[AgentDeps, ResultData]`.
|
|
67
|
-
"""
|
|
68
|
-
|
|
69
|
-
JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
|
|
70
|
-
"""Type representing any JSON data."""
|
|
71
|
-
|
|
72
|
-
ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
|
|
73
|
-
"""Return value of a tool function."""
|
|
74
|
-
ToolContextFunc = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue]
|
|
75
|
-
"""A tool function that takes `RunContext` as the first argument.
|
|
76
|
-
|
|
77
|
-
Usage `ToolContextFunc[AgentDeps, ToolParams]`.
|
|
78
|
-
"""
|
|
79
|
-
ToolPlainFunc = Callable[ToolParams, ToolReturnValue]
|
|
80
|
-
"""A tool function that does not take `RunContext` as the first argument.
|
|
81
|
-
|
|
82
|
-
Usage `ToolPlainFunc[ToolParams]`.
|
|
83
|
-
"""
|
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
pydantic_ai/__init__.py,sha256=KaTzG8uBSEpfxk_y2a_O_R2Xa53GXYfiigjWtD4PCeI,317
|
|
2
|
-
pydantic_ai/_griffe.py,sha256=pRjCJ6B1hhx6k46XJgl9zF6aRYxRmqEZKFok8unp4Iw,3449
|
|
3
|
-
pydantic_ai/_pydantic.py,sha256=j1kObPIUDwn1VOHbJBwMFbWLxHM_OYRH7GFAda68ZC0,8010
|
|
4
|
-
pydantic_ai/_result.py,sha256=cAqfPipK39cz-p-ftlJ83Q5_LI1TRb3-HH-iivb5rEM,9674
|
|
5
|
-
pydantic_ai/_system_prompt.py,sha256=63egOej8zHsDVOInPayn0EEEDXKd0HVAbbrqXUTV96s,1092
|
|
6
|
-
pydantic_ai/_tool.py,sha256=5Q9XaGOEXbyOLS644osB1AA5EMoJkr4eYK60MVZo0Z8,4528
|
|
7
|
-
pydantic_ai/_utils.py,sha256=eNb7f3-ZQC8WDEa87iUcXGQ-lyuutFQG-5yBCMD4Vvs,8227
|
|
8
|
-
pydantic_ai/agent.py,sha256=r5DI4ZBqYE67GOMEEu-LXrTa5ty2AchW4szotwm5Qis,34338
|
|
9
|
-
pydantic_ai/dependencies.py,sha256=EHvD68AFkItxMnfHzJLG7T_AD1RGI2MZOfzm1v89hGQ,2399
|
|
10
|
-
pydantic_ai/exceptions.py,sha256=ko_47M0k6Rhg9mUC9P1cj7N4LCH6cC0pEsF65A2vL-U,1561
|
|
11
|
-
pydantic_ai/messages.py,sha256=FFTQ9Bo2Ct4bLuyJF-M9xkeraw05I--NC_ieR6oGtTM,7587
|
|
12
|
-
pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
|
-
pydantic_ai/result.py,sha256=qsanb7v4qJ4pJdkdsqpy68kuZ1WNCpQB5jcXeTwGpe0,13658
|
|
14
|
-
pydantic_ai/models/__init__.py,sha256=Cx8PjEsi5gkNOVQic32sf4CmM-A3pRu1LcjpM6poiBI,10138
|
|
15
|
-
pydantic_ai/models/function.py,sha256=Mzc-zXnb2RayWAA8N9NS7KGF49do1S-VW3U9fkc661o,10045
|
|
16
|
-
pydantic_ai/models/gemini.py,sha256=ruO4tnnpDDuHThg7jUOphs8I_KXBJH7gfDMluliED8E,26606
|
|
17
|
-
pydantic_ai/models/groq.py,sha256=Tx2yU3ysmPLBmWGsjzES-XcumzrsoBtB7spCnJBlLiM,14947
|
|
18
|
-
pydantic_ai/models/openai.py,sha256=5ihH25CrS0tnZNW-BZw4GyPe8V-IxIHWw3B9ulPVjQE,14931
|
|
19
|
-
pydantic_ai/models/test.py,sha256=q1wch_E7TSb4qx9PCcP1YyBGZx567MGlAQhlAlON0S8,14463
|
|
20
|
-
pydantic_ai/models/vertexai.py,sha256=5wI8y2YjeRgSE51uKy5OtevQkks65uEbxIUAs5EGBaI,9161
|
|
21
|
-
pydantic_ai_slim-0.0.8.dist-info/METADATA,sha256=CmpvlEAUyaWaPbdRCbFuypVLJ8yNC3TwZ0jgvlR9yps,2561
|
|
22
|
-
pydantic_ai_slim-0.0.8.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
|
|
23
|
-
pydantic_ai_slim-0.0.8.dist-info/RECORD,,
|
|
File without changes
|