pydantic-ai-slim 0.0.6a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +8 -0
- pydantic_ai/_griffe.py +128 -0
- pydantic_ai/_pydantic.py +216 -0
- pydantic_ai/_result.py +258 -0
- pydantic_ai/_retriever.py +114 -0
- pydantic_ai/_system_prompt.py +33 -0
- pydantic_ai/_utils.py +247 -0
- pydantic_ai/agent.py +795 -0
- pydantic_ai/dependencies.py +83 -0
- pydantic_ai/exceptions.py +56 -0
- pydantic_ai/messages.py +205 -0
- pydantic_ai/models/__init__.py +300 -0
- pydantic_ai/models/function.py +268 -0
- pydantic_ai/models/gemini.py +720 -0
- pydantic_ai/models/groq.py +400 -0
- pydantic_ai/models/openai.py +379 -0
- pydantic_ai/models/test.py +389 -0
- pydantic_ai/models/vertexai.py +306 -0
- pydantic_ai/py.typed +0 -0
- pydantic_ai/result.py +314 -0
- pydantic_ai_slim-0.0.6a1.dist-info/METADATA +49 -0
- pydantic_ai_slim-0.0.6a1.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.6a1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import Awaitable
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Callable, Generic, cast
|
|
7
|
+
|
|
8
|
+
from pydantic import ValidationError
|
|
9
|
+
from pydantic_core import SchemaValidator
|
|
10
|
+
|
|
11
|
+
from . import _pydantic, _utils, messages
|
|
12
|
+
from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
|
|
13
|
+
from .exceptions import ModelRetry, UnexpectedModelBehavior
|
|
14
|
+
|
|
15
|
+
# Usage `RetrieverEitherFunc[AgentDeps, P]`.
# The left variant holds a retriever that takes a `CallContext` as its first argument (`RetrieverContextFunc`),
# the right variant holds a plain retriever that does not (`RetrieverPlainFunc`).
RetrieverEitherFunc = _utils.Either[
    RetrieverContextFunc[AgentDeps, RetrieverParams], RetrieverPlainFunc[RetrieverParams]
]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(init=False)
class Retriever(Generic[AgentDeps, RetrieverParams]):
    """A retriever (tool) function that an agent can call, plus the metadata needed to call it.

    Arguments from a model's tool call are validated against a schema derived from the function's
    signature. Validation failures and `ModelRetry` raised by the function are reported back as a
    `RetryPrompt` until more than `max_retries` consecutive failures have occurred.
    """

    name: str
    description: str
    function: RetrieverEitherFunc[AgentDeps, RetrieverParams] = field(repr=False)
    is_async: bool
    single_arg_name: str | None
    positional_fields: list[str]
    var_positional_field: str | None
    validator: SchemaValidator = field(repr=False)
    json_schema: _utils.ObjectJsonSchema
    max_retries: int
    _current_retry: int = 0
    outer_typed_dict_key: str | None = None

    def __init__(self, function: RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int):
        """Inspect `function` and record its name, description, argument validator and JSON schema.

        Args:
            function: The retriever, either context-taking (left) or plain (right).
            retries: How many consecutive failed calls (each answered with a `RetryPrompt`) are
                tolerated before `UnexpectedModelBehavior` is raised.
        """
        self.function = function
        # noinspection PyTypeChecker
        schema = _pydantic.function_schema(function)
        raw_function = function.whichever()
        self.name = raw_function.__name__
        self.description = schema['description']
        self.is_async = inspect.iscoroutinefunction(raw_function)
        self.single_arg_name = schema['single_arg_name']
        self.positional_fields = schema['positional_fields']
        self.var_positional_field = schema['var_positional_field']
        self.validator = schema['validator']
        self.json_schema = schema['json_schema']
        self.max_retries = retries

    def reset(self) -> None:
        """Set the consecutive-failure counter back to zero."""
        self._current_retry = 0

    async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
        """Validate the call's arguments, invoke the retriever and wrap its result.

        Returns a `ToolReturn` on success (resetting the retry counter), or a `RetryPrompt` when the
        arguments fail validation or the retriever raises `ModelRetry`.
        """
        call_args = message.args
        try:
            if isinstance(call_args, messages.ArgsJson):
                validated = self.validator.validate_json(call_args.args_json)
            else:
                validated = self.validator.validate_python(call_args.args_object)
        except ValidationError as err:
            return self._on_error(err, message)

        positional, keyword = self._call_args(deps, validated, message)
        target = self.function.whichever()
        try:
            if self.is_async:
                async_target = cast(Callable[[Any], Awaitable[str]], target)
                response_content = await async_target(*positional, **keyword)
            else:
                # sync retrievers run in the default executor so they don't block the event loop
                sync_target = cast(Callable[[Any], str], target)
                response_content = await _utils.run_in_executor(sync_target, *positional, **keyword)
        except ModelRetry as err:
            return self._on_error(err, message)

        self._current_retry = 0
        return messages.ToolReturn(
            tool_name=message.tool_name,
            content=response_content,
            tool_id=message.tool_id,
        )

    def _call_args(
        self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
    ) -> tuple[list[Any], dict[str, Any]]:
        """Split validated arguments into the positional list and keyword dict used to call the retriever."""
        # for a single-argument retriever the validated value is that argument, so key it by name
        if self.single_arg_name:
            args_dict = {self.single_arg_name: args_dict}

        positional: list[Any] = []
        if self.function.is_left():
            # context-taking retrievers get a CallContext as their first positional argument
            positional.append(CallContext(deps, self._current_retry, message.tool_name))
        positional.extend(args_dict.pop(field_name) for field_name in self.positional_fields)
        if self.var_positional_field:
            positional.extend(args_dict.pop(self.var_positional_field))

        # whatever was not consumed positionally is passed by keyword
        return positional, args_dict

    def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
        """Count a failed call and build the `RetryPrompt` describing it.

        Raises:
            UnexpectedModelBehavior: once the number of consecutive failures exceeds `max_retries`.
        """
        self._current_retry += 1
        if self._current_retry > self.max_retries:
            # TODO custom error with details of the retriever
            raise UnexpectedModelBehavior(f'Retriever exceeded max retries count of {self.max_retries}') from exc

        content = exc.errors(include_url=False) if isinstance(exc, ValidationError) else exc.message
        return messages.RetryPrompt(
            tool_name=call_message.tool_name,
            content=content,
            tool_id=call_message.tool_id,
        )
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import Awaitable
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Callable, Generic, cast
|
|
7
|
+
|
|
8
|
+
from . import _utils
|
|
9
|
+
from .dependencies import AgentDeps, CallContext, SystemPromptFunc
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class SystemPromptRunner(Generic[AgentDeps]):
    """Adapter that lets a system prompt function be awaited uniformly.

    The wrapped function may be sync or async, and may either take one `CallContext` argument or none.
    """

    function: SystemPromptFunc[AgentDeps]
    _takes_ctx: bool = field(init=False)
    _is_async: bool = field(init=False)

    def __post_init__(self):
        # inspect the function once, up front, instead of on every run
        self._takes_ctx = bool(inspect.signature(self.function).parameters)
        self._is_async = inspect.iscoroutinefunction(self.function)

    async def run(self, deps: AgentDeps) -> str:
        """Call the system prompt function and return the prompt text it produces."""
        # system prompts are not tool calls: retry count is 0 and there is no tool name
        args = (CallContext(deps, 0, None),) if self._takes_ctx else ()

        if not self._is_async:
            sync_function = cast(Callable[[Any], str], self.function)
            return await _utils.run_in_executor(sync_function, *args)

        async_function = cast(Callable[[Any], Awaitable[str]], self.function)
        return await async_function(*args)
|
pydantic_ai/_utils.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import time
|
|
5
|
+
from collections.abc import AsyncIterable, AsyncIterator, Iterator
|
|
6
|
+
from contextlib import asynccontextmanager, suppress
|
|
7
|
+
from dataclasses import dataclass, is_dataclass
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from functools import partial
|
|
10
|
+
from types import GenericAlias
|
|
11
|
+
from typing import Any, Callable, Generic, TypeVar, Union, cast, overload
|
|
12
|
+
|
|
13
|
+
from pydantic import BaseModel
|
|
14
|
+
from pydantic.json_schema import JsonSchemaValue
|
|
15
|
+
from typing_extensions import ParamSpec, TypeAlias, is_typeddict
|
|
16
|
+
|
|
17
|
+
# Parameter spec and return type of the callable passed to `run_in_executor`, so the forwarded
# `*args`/`**kwargs` are type-checked against the callable's signature.
_P = ParamSpec('_P')
_R = TypeVar('_R')
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R:
    """Run the synchronous callable `func` in the running loop's default executor and await its result.

    `loop.run_in_executor` only forwards positional arguments, so any keyword arguments are bound
    up front with `functools.partial`.
    """
    loop = asyncio.get_running_loop()
    if not kwargs:
        return await loop.run_in_executor(None, func, *args)  # type: ignore
    # noinspection PyTypeChecker
    return await loop.run_in_executor(None, partial(func, *args, **kwargs))
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def is_model_like(type_: Any) -> bool:
    """Return whether `type_` is a pydantic model class, a dataclass or a TypedDict.

    Such types all produce a JSON Schema with `{"type": "object"}`, so they can be used directly as
    function parameters. Non-classes and parametrised generics (e.g. `list[int]`) are rejected first.
    """
    if not isinstance(type_, type) or isinstance(type_, GenericAlias):
        return False
    return issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# A JSON Schema dict whose top-level `type` is `object` (enforced at runtime by `check_object_json_schema`).
# With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
ObjectJsonSchema: TypeAlias = dict[str, Any]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
    """Return `schema` unchanged if its top-level `type` is `object`.

    Raises:
        UserError: if the schema does not describe an object.
    """
    from .exceptions import UserError

    if schema.get('type') != 'object':
        raise UserError('Schema must be an object')
    return schema
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Generic item type shared by `Some`, `Option`, `group_by_temporal` and `sync_anext`.
T = TypeVar('T')
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass
class Some(Generic[T]):
    """Analogous to Rust's `Option::Some` type.

    Wrapping a present value in `Some` keeps it distinct from an absent one (plain `None`), so
    `Some(None)` still counts as present.
    """

    value: T  # the wrapped (present) value
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# `None` means absent; a present value is wrapped in `Some`, so `Option` can carry a present `None`.
Option: TypeAlias = Union[Some[T], None]
"""Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None`."""
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Type parameters for the two variants of `Either`.
Left = TypeVar('Left')
Right = TypeVar('Right')
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class Unset:
    """Marker type meaning "no value was passed"; compare against the shared `UNSET` instance with `is`.

    Unlike `None`, this lets callers pass `None` as a real value.
    """
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# Shared sentinel instance, checked by identity (`is`), e.g. as the default arguments of `Either.__init__`.
UNSET = Unset()
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class Either(Generic[Left, Right]):
    """Two-member union that remembers which member was set, like a Rust enum with two variants.

    Construct with exactly one of `left=` or `right=` (checked with `assert`).

    Usage:

    ```py
    if left_thing := either.left:
        use_left(left_thing.value)
    else:
        use_right(either.right)
    ```
    """

    __slots__ = '_left', '_right'

    @overload
    def __init__(self, *, left: Left) -> None: ...

    @overload
    def __init__(self, *, right: Right) -> None: ...

    def __init__(self, left: Left | Unset = UNSET, right: Right | Unset = UNSET) -> None:
        if left is UNSET:
            assert right is not UNSET, '`Either` must receive exactly one argument - `left` or `right`'
            self._left: Option[Left] = None
            self._right = cast(Right, right)
        else:
            assert right is UNSET, '`Either` must receive exactly one argument - `left` or `right`'
            # `_right` is deliberately left unset in this case
            self._left = Some(cast(Left, left))

    @property
    def left(self) -> Option[Left]:
        """`Some(value)` when the left variant is held, otherwise `None`."""
        return self._left

    @property
    def right(self) -> Right:
        """The right value; only assigned when the left variant is absent."""
        return self._right

    def is_left(self) -> bool:
        """Whether the left variant is held."""
        return self._left is not None

    def whichever(self) -> Left | Right:
        """Return the held value, whichever variant it is."""
        if self._left is None:
            return self.right
        return self._left.value
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@asynccontextmanager
async def group_by_temporal(
    aiter: AsyncIterator[T], soft_max_interval: float | None
) -> AsyncIterator[AsyncIterable[list[T]]]:
    """Group items from an async iterable into lists based on time interval between them.

    Effectively debouncing the iterator.

    This returns a context manager usable as an iterator so any pending tasks can be cancelled if an error occurs
    during iteration.

    Usage:

    ```py
    async with group_by_temporal(yield_groups(), 0.1) as groups_iter:
        async for groups in groups_iter:
            print(groups)
    ```

    Args:
        aiter: The async iterable to group.
        soft_max_interval: Maximum interval over which to group items, this should avoid a trickle of items causing
            a group to never be yielded. It's a soft max in the sense that once we're over this time, we yield items
            as soon as `aiter.__anext__()` returns. If `None`, no grouping/debouncing is performed and each item
            is yielded in its own single-element list.

    Returns:
        A context manager whose value is an async iterable of lists of items from the input async iterable.
    """
    if soft_max_interval is None:

        async def async_iter_groups_noop() -> AsyncIterator[list[T]]:
            async for item in aiter:
                yield [item]

        yield async_iter_groups_noop()
        return

    # we might wait for the next item more than once, so we store the task to await next time
    task: asyncio.Task[T] | None = None

    async def async_iter_groups() -> AsyncIterator[list[T]]:
        nonlocal task

        assert soft_max_interval is not None and soft_max_interval >= 0, 'soft_max_interval must be a positive number'
        buffer: list[T] = []
        # the first group is timed from the start of iteration, so a slow first item is yielded
        # promptly instead of waiting a further full interval after it arrives
        group_start_time: float | None = time.monotonic()

        while True:
            wait_time: float | None
            if group_start_time is None:
                # group hasn't started, we just wait for the maximum interval
                wait_time = soft_max_interval
            else:
                # wait for the time remaining in the group
                wait_time = soft_max_interval - (time.monotonic() - group_start_time)
                if wait_time <= 0 and not buffer:
                    # The group's time is up but there is nothing to yield yet (only possible before the
                    # first item arrives). A zero/negative timeout makes `asyncio.wait` return at once and
                    # this loop would spin on the event loop, so block until the next item arrives instead.
                    wait_time = None

            # if there's no current task, we get the next one
            if task is None:
                # aiter.__anext__() returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
                # so far, this doesn't seem to be a problem
                task = asyncio.create_task(aiter.__anext__())  # pyright: ignore[reportArgumentType]

            # we use asyncio.wait to avoid cancelling the coroutine if it's not done
            done, _ = await asyncio.wait((task,), timeout=wait_time)

            if done:
                # the one task we waited for completed
                try:
                    item = done.pop().result()
                except StopAsyncIteration:
                    # if the task raised StopAsyncIteration, we're done iterating
                    if buffer:
                        yield buffer
                    task = None
                    break
                else:
                    # we got an item, add it to the buffer and set task to None to get the next item
                    buffer.append(item)
                    task = None
                    # if this is the first item in the group, set the group start time
                    if group_start_time is None:
                        group_start_time = time.monotonic()
            elif buffer:
                # otherwise if the task timeout expired and we have items in the buffer, yield the buffer
                yield buffer
                # clear the buffer and reset the group start time ready for the next group
                buffer = []
                group_start_time = None

    try:
        yield async_iter_groups()
    finally:
        # after iteration if a task still exists, cancel it, this will only happen if an error occurred
        # or the consumer stopped iterating early
        if task is not None:
            task.cancel('Cancelling due to error in iterator')
            with suppress(asyncio.CancelledError):
                await task
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def add_optional(a: str | None, b: str | None) -> str | None:
    """Concatenate two optional strings, treating `None` as "absent".

    When one operand is `None` the other is returned as-is (so the result is `None` only if both are).
    """
    if a is not None and b is not None:
        return a + b
    return b if a is None else a
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def sync_anext(iterator: Iterator[T]) -> T:
    """Return the next item of a sync iterator, converting exhaustion into `StopAsyncIteration`.

    Useful when a sync iterator has to be driven from async code that expects async-iteration semantics.
    The original `StopIteration` is kept as the exception's `__cause__`.
    """
    try:
        item = next(iterator)
    except StopIteration as exhausted:
        raise StopAsyncIteration() from exhausted
    return item
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def now_utc() -> datetime:
    """Return the current time as a timezone-aware datetime in UTC."""
    return datetime.now(timezone.utc)
|