pydantic-ai-slim 0.0.6a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

@@ -0,0 +1,114 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import inspect
4
+ from collections.abc import Awaitable
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Callable, Generic, cast
7
+
8
+ from pydantic import ValidationError
9
+ from pydantic_core import SchemaValidator
10
+
11
+ from . import _pydantic, _utils, messages
12
+ from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
13
+ from .exceptions import ModelRetry, UnexpectedModelBehavior
14
+
15
# Usage: `RetrieverEitherFunc[AgentDependencies, P]`. A `left` value is a function that receives a
# `CallContext` as its first positional argument; a `right` value is a plain function that does not.
RetrieverEitherFunc = _utils.Either[
    RetrieverContextFunc[AgentDeps, RetrieverParams], RetrieverPlainFunc[RetrieverParams]
]


@dataclass(init=False)
class Retriever(Generic[AgentDeps, RetrieverParams]):
    """A retriever function for an agent.

    Bundles the user function with the validator and JSON schema derived from its signature, and counts
    consecutive failed calls so the model can be asked to retry up to `max_retries` times.
    """

    name: str
    description: str
    function: RetrieverEitherFunc[AgentDeps, RetrieverParams] = field(repr=False)
    is_async: bool
    single_arg_name: str | None
    positional_fields: list[str]
    var_positional_field: str | None
    validator: SchemaValidator = field(repr=False)
    json_schema: _utils.ObjectJsonSchema
    max_retries: int
    _current_retry: int = 0
    outer_typed_dict_key: str | None = None

    def __init__(self, function: RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int):
        """Build a Retriever dataclass from a function.

        Args:
            function: The retriever, wrapped in `Either` to record whether it takes a `CallContext`.
            retries: How many failed calls are tolerated before `UnexpectedModelBehavior` is raised.
        """
        self.function = function
        # noinspection PyTypeChecker
        schema = _pydantic.function_schema(function)
        wrapped = function.whichever()

        self.name = wrapped.__name__
        self.is_async = inspect.iscoroutinefunction(wrapped)
        self.max_retries = retries

        self.description = schema['description']
        self.single_arg_name = schema['single_arg_name']
        self.positional_fields = schema['positional_fields']
        self.var_positional_field = schema['var_positional_field']
        self.validator = schema['validator']
        self.json_schema = schema['json_schema']

    def reset(self) -> None:
        """Set the consecutive-failure counter back to zero."""
        self._current_retry = 0

    async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
        """Validate the call's arguments, invoke the retriever and wrap its result.

        Returns a `ToolReturn` on success, or a `RetryPrompt` when argument validation fails or the
        function raises `ModelRetry` (see `_on_error`, which may instead raise `UnexpectedModelBehavior`).
        """
        call_args = message.args
        try:
            if isinstance(call_args, messages.ArgsJson):
                validated = self.validator.validate_json(call_args.args_json)
            else:
                validated = self.validator.validate_python(call_args.args_object)
        except ValidationError as exc:
            return self._on_error(exc, message)

        positional, keyword = self._call_args(deps, validated, message)
        raw_function = self.function.whichever()
        try:
            if not self.is_async:
                # synchronous retrievers are pushed to a worker thread so they don't block the event loop
                sync_function = cast(Callable[[Any], str], raw_function)
                response_content = await _utils.run_in_executor(sync_function, *positional, **keyword)
            else:
                async_function = cast(Callable[[Any], Awaitable[str]], raw_function)
                response_content = await async_function(*positional, **keyword)
        except ModelRetry as exc:
            return self._on_error(exc, message)

        # a successful call ends the current streak of failures
        self._current_retry = 0
        return messages.ToolReturn(
            tool_name=message.tool_name,
            content=response_content,
            tool_id=message.tool_id,
        )

    def _call_args(
        self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
    ) -> tuple[list[Any], dict[str, Any]]:
        """Split validated arguments into the positional list and keyword dict used to call the function."""
        if self.single_arg_name:
            # the whole validated value becomes the function's single parameter
            args_dict = {self.single_arg_name: args_dict}

        positional: list[Any] = []
        if self.function.is_left():
            positional.append(CallContext(deps, self._current_retry, message.tool_name))
        positional.extend(args_dict.pop(field_name) for field_name in self.positional_fields)
        if self.var_positional_field:
            positional.extend(args_dict.pop(self.var_positional_field))

        # whatever was not consumed as positional is passed by keyword
        return positional, args_dict

    def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
        """Record a failed call and build a retry prompt, or raise once `max_retries` is exceeded."""
        self._current_retry += 1
        if self._current_retry > self.max_retries:
            # TODO custom error with details of the retriever
            raise UnexpectedModelBehavior(f'Retriever exceeded max retries count of {self.max_retries}') from exc

        content = exc.errors(include_url=False) if isinstance(exc, ValidationError) else exc.message
        return messages.RetryPrompt(
            tool_name=call_message.tool_name,
            content=content,
            tool_id=call_message.tool_id,
        )
@@ -0,0 +1,33 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import inspect
4
+ from collections.abc import Awaitable
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Callable, Generic, cast
7
+
8
+ from . import _utils
9
+ from .dependencies import AgentDeps, CallContext, SystemPromptFunc
10
+
11
+
12
@dataclass
class SystemPromptRunner(Generic[AgentDeps]):
    """Invoke a user-supplied system prompt function, sync or async, with or without a `CallContext`."""

    function: SystemPromptFunc[AgentDeps]
    _takes_ctx: bool = field(init=False)
    _is_async: bool = field(init=False)

    def __post_init__(self):
        # any declared parameter at all is taken to mean the function wants a `CallContext`
        parameters = inspect.signature(self.function).parameters
        self._takes_ctx = bool(parameters)
        self._is_async = inspect.iscoroutinefunction(self.function)

    async def run(self, deps: AgentDeps) -> str:
        """Call the wrapped function and return the system prompt text it produces."""
        args = (CallContext(deps, 0, None),) if self._takes_ctx else ()

        if not self._is_async:
            # synchronous functions run on the default executor so they don't block the event loop
            sync_function = cast(Callable[[Any], str], self.function)
            return await _utils.run_in_executor(sync_function, *args)

        async_function = cast(Callable[[Any], Awaitable[str]], self.function)
        return await async_function(*args)
pydantic_ai/_utils.py ADDED
@@ -0,0 +1,247 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import asyncio
4
+ import time
5
+ from collections.abc import AsyncIterable, AsyncIterator, Iterator
6
+ from contextlib import asynccontextmanager, suppress
7
+ from dataclasses import dataclass, is_dataclass
8
+ from datetime import datetime, timezone
9
+ from functools import partial
10
+ from types import GenericAlias
11
+ from typing import Any, Callable, Generic, TypeVar, Union, cast, overload
12
+
13
+ from pydantic import BaseModel
14
+ from pydantic.json_schema import JsonSchemaValue
15
+ from typing_extensions import ParamSpec, TypeAlias, is_typeddict
16
+
17
+ _P = ParamSpec('_P')
18
+ _R = TypeVar('_R')
19
+
20
+
21
+ async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R:
22
+ if kwargs:
23
+ # noinspection PyTypeChecker
24
+ return await asyncio.get_running_loop().run_in_executor(None, partial(func, *args, **kwargs))
25
+ else:
26
+ return await asyncio.get_running_loop().run_in_executor(None, func, *args) # type: ignore
27
+
28
+
29
+ def is_model_like(type_: Any) -> bool:
30
+ """Check if something is a pydantic model, dataclass or typedict.
31
+
32
+ These should all generate a JSON Schema with `{"type": "object"}` and therefore be usable directly as
33
+ function parameters.
34
+ """
35
+ return (
36
+ isinstance(type_, type)
37
+ and not isinstance(type_, GenericAlias)
38
+ and (issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_))
39
+ )
40
+
41
+
42
+ # With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
43
+ ObjectJsonSchema: TypeAlias = dict[str, Any]
44
+
45
+
46
+ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
47
+ from .exceptions import UserError
48
+
49
+ if schema.get('type') == 'object':
50
+ return schema
51
+ else:
52
+ raise UserError('Schema must be an object')
53
+
54
+
55
+ T = TypeVar('T')
56
+
57
+
58
+ @dataclass
59
+ class Some(Generic[T]):
60
+ """Analogous to Rust's `Option::Some` type."""
61
+
62
+ value: T
63
+
64
+
65
+ Option: TypeAlias = Union[Some[T], None]
66
+ """Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None`."""
67
+
68
+
69
+ Left = TypeVar('Left')
70
+ Right = TypeVar('Right')
71
+
72
+
73
+ class Unset:
74
+ """A singleton to represent an unset value."""
75
+
76
+ pass
77
+
78
+
79
+ UNSET = Unset()
80
+
81
+
82
+ class Either(Generic[Left, Right]):
83
+ """Two member Union that records which member was set, this is analogous to Rust enums with two variants.
84
+
85
+ Usage:
86
+
87
+ ```py
88
+ if left_thing := either.left:
89
+ use_left(left_thing.value)
90
+ else:
91
+ use_right(either.right)
92
+ ```
93
+ """
94
+
95
+ __slots__ = '_left', '_right'
96
+
97
+ @overload
98
+ def __init__(self, *, left: Left) -> None: ...
99
+
100
+ @overload
101
+ def __init__(self, *, right: Right) -> None: ...
102
+
103
+ def __init__(self, left: Left | Unset = UNSET, right: Right | Unset = UNSET) -> None:
104
+ if left is not UNSET:
105
+ assert right is UNSET, '`Either` must receive exactly one argument - `left` or `right`'
106
+ self._left: Option[Left] = Some(cast(Left, left))
107
+ else:
108
+ assert right is not UNSET, '`Either` must receive exactly one argument - `left` or `right`'
109
+ self._left = None
110
+ self._right = cast(Right, right)
111
+
112
+ @property
113
+ def left(self) -> Option[Left]:
114
+ return self._left
115
+
116
+ @property
117
+ def right(self) -> Right:
118
+ return self._right
119
+
120
+ def is_left(self) -> bool:
121
+ return self._left is not None
122
+
123
+ def whichever(self) -> Left | Right:
124
+ return self._left.value if self._left is not None else self.right
125
+
126
+
127
+ @asynccontextmanager
128
+ async def group_by_temporal(
129
+ aiter: AsyncIterator[T], soft_max_interval: float | None
130
+ ) -> AsyncIterator[AsyncIterable[list[T]]]:
131
+ """Group items from an async iterable into lists based on time interval between them.
132
+
133
+ Effectively debouncing the iterator.
134
+
135
+ This returns a context manager usable as an iterator so any pending tasks can be cancelled if an error occurs
136
+ during iteration.
137
+
138
+ Usage:
139
+
140
+ ```py
141
+ async with group_by_temporal(yield_groups(), 0.1) as groups_iter:
142
+ async for groups in groups_iter:
143
+ print(groups)
144
+ ```
145
+
146
+ Args:
147
+ aiter: The async iterable to group.
148
+ soft_max_interval: Maximum interval over which to group items, this should avoid a trickle of items causing
149
+ a group to never be yielded. It's a soft max in the sense that once we're over this time, we yield items
150
+ as soon as `aiter.__anext__()` returns. If `None`, no grouping/debouncing is performed
151
+
152
+ Returns:
153
+ A context manager usable as an iterator async iterable of lists of items from the input async iterable.
154
+ """
155
+ if soft_max_interval is None:
156
+
157
+ async def async_iter_groups_noop() -> AsyncIterator[list[T]]:
158
+ async for item in aiter:
159
+ yield [item]
160
+
161
+ yield async_iter_groups_noop()
162
+ return
163
+
164
+ # we might wait for the next item more than once, so we store the task to await next time
165
+ task: asyncio.Task[T] | None = None
166
+
167
+ async def async_iter_groups() -> AsyncIterator[list[T]]:
168
+ nonlocal task
169
+
170
+ assert soft_max_interval is not None and soft_max_interval >= 0, 'soft_max_interval must be a positive number'
171
+ buffer: list[T] = []
172
+ group_start_time = time.monotonic()
173
+
174
+ while True:
175
+ if group_start_time is None:
176
+ # group hasn't started, we just wait for the maximum interval
177
+ wait_time = soft_max_interval
178
+ else:
179
+ # wait for the time remaining in the group
180
+ wait_time = soft_max_interval - (time.monotonic() - group_start_time)
181
+
182
+ # if there's no current task, we get the next one
183
+ if task is None:
184
+ # aiter.__anext__() returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
185
+ # so far, this doesn't seem to be a problem
186
+ task = asyncio.create_task(aiter.__anext__()) # pyright: ignore[reportArgumentType]
187
+
188
+ # we use asyncio.wait to avoid cancelling the coroutine if it's not done
189
+ done, _ = await asyncio.wait((task,), timeout=wait_time)
190
+
191
+ if done:
192
+ # the one task we waited for completed
193
+ try:
194
+ item = done.pop().result()
195
+ except StopAsyncIteration:
196
+ # if the task raised StopAsyncIteration, we're done iterating
197
+ if buffer:
198
+ yield buffer
199
+ task = None
200
+ break
201
+ else:
202
+ # we got an item, add it to the buffer and set task to None to get the next item
203
+ buffer.append(item)
204
+ task = None
205
+ # if this is the first item in the group, set the group start time
206
+ if group_start_time is None:
207
+ group_start_time = time.monotonic()
208
+ elif buffer:
209
+ # otherwise if the task timeout expired and we have items in the buffer, yield the buffer
210
+ yield buffer
211
+ # clear the buffer and reset the group start time ready for the next group
212
+ buffer = []
213
+ group_start_time = None
214
+
215
+ try:
216
+ yield async_iter_groups()
217
+ finally:
218
+ # after iteration if a tasks still exists, cancel it, this will only happen if an error occurred
219
+ if task:
220
+ task.cancel('Cancelling due to error in iterator')
221
+ with suppress(asyncio.CancelledError):
222
+ await task
223
+
224
+
225
+ def add_optional(a: str | None, b: str | None) -> str | None:
226
+ """Add two optional strings."""
227
+ if a is None:
228
+ return b
229
+ elif b is None:
230
+ return a
231
+ else:
232
+ return a + b
233
+
234
+
235
+ def sync_anext(iterator: Iterator[T]) -> T:
236
+ """Get the next item from a sync iterator, raising `StopAsyncIteration` if it's exhausted.
237
+
238
+ Useful when iterating over a sync iterator in an async context.
239
+ """
240
+ try:
241
+ return next(iterator)
242
+ except StopIteration as e:
243
+ raise StopAsyncIteration() from e
244
+
245
+
246
+ def now_utc() -> datetime:
247
+ return datetime.now(tz=timezone.utc)