pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_pydantic.py +13 -29
- pydantic_ai/_result.py +52 -38
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +20 -8
- pydantic_ai/agent.py +431 -167
- pydantic_ai/messages.py +90 -48
- pydantic_ai/models/__init__.py +59 -42
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +66 -44
- pydantic_ai/models/gemini.py +160 -117
- pydantic_ai/models/groq.py +125 -108
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +116 -0
- pydantic_ai/models/openai.py +145 -114
- pydantic_ai/models/test.py +109 -77
- pydantic_ai/models/vertexai.py +14 -9
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +140 -45
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.11.dist-info/RECORD +0 -22
pydantic_ai/_pydantic.py
CHANGED
|
@@ -8,7 +8,7 @@ from __future__ import annotations as _annotations
|
|
|
8
8
|
from inspect import Parameter, signature
|
|
9
9
|
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
|
|
10
10
|
|
|
11
|
-
from pydantic import ConfigDict
|
|
11
|
+
from pydantic import ConfigDict
|
|
12
12
|
from pydantic._internal import _decorators, _generate_schema, _typing_extra
|
|
13
13
|
from pydantic._internal._config import ConfigWrapper
|
|
14
14
|
from pydantic.fields import FieldInfo
|
|
@@ -17,13 +17,13 @@ from pydantic.plugin._schema_validator import create_schema_validator
|
|
|
17
17
|
from pydantic_core import SchemaValidator, core_schema
|
|
18
18
|
|
|
19
19
|
from ._griffe import doc_descriptions
|
|
20
|
-
from ._utils import
|
|
20
|
+
from ._utils import check_object_json_schema, is_model_like
|
|
21
21
|
|
|
22
22
|
if TYPE_CHECKING:
|
|
23
|
-
|
|
23
|
+
from .tools import ObjectJsonSchema
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
__all__ = 'function_schema',
|
|
26
|
+
__all__ = ('function_schema',)
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class FunctionSchema(TypedDict):
|
|
@@ -138,14 +138,14 @@ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSc
|
|
|
138
138
|
json_schema = GenerateJsonSchema().generate(schema)
|
|
139
139
|
|
|
140
140
|
# workaround for https://github.com/pydantic/pydantic/issues/10785
|
|
141
|
-
# if we build a custom TypeDict schema (matches when `single_arg_name
|
|
141
|
+
# if we build a custom TypeDict schema (matches when `single_arg_name is None`), we manually set
|
|
142
142
|
# `additionalProperties` in the JSON Schema
|
|
143
143
|
if single_arg_name is None:
|
|
144
144
|
json_schema['additionalProperties'] = bool(var_kwargs_schema)
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
145
|
+
elif not description:
|
|
146
|
+
# if the tool description is not set, and we have a single parameter, take the description from that
|
|
147
|
+
# and set it on the tool
|
|
148
|
+
description = json_schema.pop('description', None)
|
|
149
149
|
|
|
150
150
|
return FunctionSchema(
|
|
151
151
|
description=description,
|
|
@@ -168,11 +168,13 @@ def takes_ctx(function: Callable[..., Any]) -> bool:
|
|
|
168
168
|
"""
|
|
169
169
|
sig = signature(function)
|
|
170
170
|
try:
|
|
171
|
-
|
|
171
|
+
first_param_name = next(iter(sig.parameters.keys()))
|
|
172
172
|
except StopIteration:
|
|
173
173
|
return False
|
|
174
174
|
else:
|
|
175
|
-
|
|
175
|
+
type_hints = _typing_extra.get_function_type_hints(function)
|
|
176
|
+
annotation = type_hints[first_param_name]
|
|
177
|
+
return annotation is not sig.empty and _is_call_ctx(annotation)
|
|
176
178
|
|
|
177
179
|
|
|
178
180
|
def _build_schema(
|
|
@@ -212,21 +214,3 @@ def _is_call_ctx(annotation: Any) -> bool:
|
|
|
212
214
|
return annotation is RunContext or (
|
|
213
215
|
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
|
|
214
216
|
)
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
if TYPE_CHECKING:
|
|
218
|
-
LazyTypeAdapter = TypeAdapter
|
|
219
|
-
else:
|
|
220
|
-
|
|
221
|
-
class LazyTypeAdapter:
|
|
222
|
-
__slots__ = '_args', '_kwargs', '_type_adapter'
|
|
223
|
-
|
|
224
|
-
def __init__(self, *args, **kwargs):
|
|
225
|
-
self._args = args
|
|
226
|
-
self._kwargs = kwargs
|
|
227
|
-
self._type_adapter = None
|
|
228
|
-
|
|
229
|
-
def __getattr__(self, item):
|
|
230
|
-
if self._type_adapter is None:
|
|
231
|
-
self._type_adapter = TypeAdapter(*self._args, **self._kwargs)
|
|
232
|
-
return getattr(self._type_adapter, item)
|
pydantic_ai/_result.py
CHANGED
|
@@ -3,18 +3,17 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import sys
|
|
5
5
|
import types
|
|
6
|
-
from collections.abc import Awaitable
|
|
6
|
+
from collections.abc import Awaitable, Iterable
|
|
7
7
|
from dataclasses import dataclass, field
|
|
8
8
|
from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
|
|
9
9
|
|
|
10
10
|
from pydantic import TypeAdapter, ValidationError
|
|
11
11
|
from typing_extensions import Self, TypeAliasType, TypedDict
|
|
12
12
|
|
|
13
|
-
from . import _utils, messages
|
|
13
|
+
from . import _utils, messages as _messages
|
|
14
14
|
from .exceptions import ModelRetry
|
|
15
|
-
from .messages import ModelStructuredResponse, ToolCall
|
|
16
15
|
from .result import ResultData
|
|
17
|
-
from .tools import AgentDeps, ResultValidatorFunc, RunContext
|
|
16
|
+
from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
|
|
18
17
|
|
|
19
18
|
|
|
20
19
|
@dataclass
|
|
@@ -28,7 +27,12 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
28
27
|
self._is_async = inspect.iscoroutinefunction(self.function)
|
|
29
28
|
|
|
30
29
|
async def validate(
|
|
31
|
-
self,
|
|
30
|
+
self,
|
|
31
|
+
result: ResultData,
|
|
32
|
+
deps: AgentDeps,
|
|
33
|
+
retry: int,
|
|
34
|
+
tool_call: _messages.ToolCallPart | None,
|
|
35
|
+
messages: list[_messages.ModelMessage],
|
|
32
36
|
) -> ResultData:
|
|
33
37
|
"""Validate a result but calling the function.
|
|
34
38
|
|
|
@@ -37,12 +41,13 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
37
41
|
deps: The agent dependencies.
|
|
38
42
|
retry: The current retry number.
|
|
39
43
|
tool_call: The original tool call message, `None` if there was no tool call.
|
|
44
|
+
messages: The messages exchanged so far in the conversation.
|
|
40
45
|
|
|
41
46
|
Returns:
|
|
42
47
|
Result of either the validated result data (ok) or a retry message (Err).
|
|
43
48
|
"""
|
|
44
49
|
if self._takes_ctx:
|
|
45
|
-
args = RunContext(deps, retry, tool_call.tool_name if tool_call else None), result
|
|
50
|
+
args = RunContext(deps, retry, messages, tool_call.tool_name if tool_call else None), result
|
|
46
51
|
else:
|
|
47
52
|
args = (result,)
|
|
48
53
|
|
|
@@ -54,10 +59,10 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
54
59
|
function = cast(Callable[[Any], ResultData], self.function)
|
|
55
60
|
result_data = await _utils.run_in_executor(function, *args)
|
|
56
61
|
except ModelRetry as r:
|
|
57
|
-
m =
|
|
62
|
+
m = _messages.RetryPromptPart(content=r.message)
|
|
58
63
|
if tool_call is not None:
|
|
59
64
|
m.tool_name = tool_call.tool_name
|
|
60
|
-
m.
|
|
65
|
+
m.tool_call_id = tool_call.tool_call_id
|
|
61
66
|
raise ToolRetryError(m) from r
|
|
62
67
|
else:
|
|
63
68
|
return result_data
|
|
@@ -66,7 +71,7 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
66
71
|
class ToolRetryError(Exception):
|
|
67
72
|
"""Internal exception used to signal a `ToolRetry` message should be returned to the LLM."""
|
|
68
73
|
|
|
69
|
-
def __init__(self, tool_retry:
|
|
74
|
+
def __init__(self, tool_retry: _messages.RetryPromptPart):
|
|
70
75
|
self.tool_retry = tool_retry
|
|
71
76
|
super().__init__()
|
|
72
77
|
|
|
@@ -94,10 +99,7 @@ class ResultSchema(Generic[ResultData]):
|
|
|
94
99
|
allow_text_result = False
|
|
95
100
|
|
|
96
101
|
def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData]:
|
|
97
|
-
return cast(
|
|
98
|
-
ResultTool[ResultData],
|
|
99
|
-
ResultTool.build(a, tool_name_, description, multiple), # pyright: ignore[reportUnknownMemberType]
|
|
100
|
-
)
|
|
102
|
+
return cast(ResultTool[ResultData], ResultTool(a, tool_name_, description, multiple))
|
|
101
103
|
|
|
102
104
|
tools: dict[str, ResultTool[ResultData]] = {}
|
|
103
105
|
if args := get_union_args(response_type):
|
|
@@ -111,48 +113,61 @@ class ResultSchema(Generic[ResultData]):
|
|
|
111
113
|
|
|
112
114
|
return cls(tools=tools, allow_text_result=allow_text_result)
|
|
113
115
|
|
|
114
|
-
def
|
|
116
|
+
def find_named_tool(
|
|
117
|
+
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
|
|
118
|
+
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
|
|
119
|
+
"""Find a tool that matches one of the calls, with a specific name."""
|
|
120
|
+
for part in parts:
|
|
121
|
+
if isinstance(part, _messages.ToolCallPart):
|
|
122
|
+
if part.tool_name == tool_name:
|
|
123
|
+
return part, self.tools[tool_name]
|
|
124
|
+
|
|
125
|
+
def find_tool(
|
|
126
|
+
self,
|
|
127
|
+
parts: Iterable[_messages.ModelResponsePart],
|
|
128
|
+
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
|
|
115
129
|
"""Find a tool that matches one of the calls."""
|
|
116
|
-
for
|
|
117
|
-
if
|
|
118
|
-
|
|
130
|
+
for part in parts:
|
|
131
|
+
if isinstance(part, _messages.ToolCallPart):
|
|
132
|
+
if result := self.tools.get(part.tool_name):
|
|
133
|
+
return part, result
|
|
119
134
|
|
|
120
135
|
def tool_names(self) -> list[str]:
|
|
121
136
|
"""Return the names of the tools."""
|
|
122
137
|
return list(self.tools.keys())
|
|
123
138
|
|
|
139
|
+
def tool_defs(self) -> list[ToolDefinition]:
|
|
140
|
+
"""Get tool definitions to register with the model."""
|
|
141
|
+
return [t.tool_def for t in self.tools.values()]
|
|
142
|
+
|
|
124
143
|
|
|
125
144
|
DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
|
|
126
145
|
|
|
127
146
|
|
|
128
|
-
@dataclass
|
|
147
|
+
@dataclass(init=False)
|
|
129
148
|
class ResultTool(Generic[ResultData]):
|
|
130
|
-
|
|
131
|
-
description: str
|
|
149
|
+
tool_def: ToolDefinition
|
|
132
150
|
type_adapter: TypeAdapter[Any]
|
|
133
|
-
json_schema: _utils.ObjectJsonSchema
|
|
134
|
-
outer_typed_dict_key: str | None
|
|
135
151
|
|
|
136
|
-
|
|
137
|
-
def build(cls, response_type: type[ResultData], name: str, description: str | None, multiple: bool) -> Self | None:
|
|
152
|
+
def __init__(self, response_type: type[ResultData], name: str, description: str | None, multiple: bool):
|
|
138
153
|
"""Build a ResultTool dataclass from a response type."""
|
|
139
154
|
assert response_type is not str, 'ResultTool does not support str as a response type'
|
|
140
155
|
|
|
141
156
|
if _utils.is_model_like(response_type):
|
|
142
|
-
type_adapter = TypeAdapter(response_type)
|
|
157
|
+
self.type_adapter = TypeAdapter(response_type)
|
|
143
158
|
outer_typed_dict_key: str | None = None
|
|
144
159
|
# noinspection PyArgumentList
|
|
145
|
-
|
|
160
|
+
parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
|
|
146
161
|
else:
|
|
147
162
|
response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type}) # noqa
|
|
148
|
-
type_adapter = TypeAdapter(response_data_typed_dict)
|
|
163
|
+
self.type_adapter = TypeAdapter(response_data_typed_dict)
|
|
149
164
|
outer_typed_dict_key = 'response'
|
|
150
165
|
# noinspection PyArgumentList
|
|
151
|
-
|
|
166
|
+
parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
|
|
152
167
|
# including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
|
|
153
|
-
|
|
168
|
+
parameters_json_schema.pop('title')
|
|
154
169
|
|
|
155
|
-
if json_schema_description :=
|
|
170
|
+
if json_schema_description := parameters_json_schema.pop('description', None):
|
|
156
171
|
if description is None:
|
|
157
172
|
tool_description = json_schema_description
|
|
158
173
|
else:
|
|
@@ -162,16 +177,15 @@ class ResultTool(Generic[ResultData]):
|
|
|
162
177
|
if multiple:
|
|
163
178
|
tool_description = f'{union_arg_name(response_type)}: {tool_description}'
|
|
164
179
|
|
|
165
|
-
|
|
180
|
+
self.tool_def = ToolDefinition(
|
|
166
181
|
name=name,
|
|
167
182
|
description=tool_description,
|
|
168
|
-
|
|
169
|
-
json_schema=json_schema,
|
|
183
|
+
parameters_json_schema=parameters_json_schema,
|
|
170
184
|
outer_typed_dict_key=outer_typed_dict_key,
|
|
171
185
|
)
|
|
172
186
|
|
|
173
187
|
def validate(
|
|
174
|
-
self, tool_call:
|
|
188
|
+
self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
|
|
175
189
|
) -> ResultData:
|
|
176
190
|
"""Validate a result message.
|
|
177
191
|
|
|
@@ -185,7 +199,7 @@ class ResultTool(Generic[ResultData]):
|
|
|
185
199
|
"""
|
|
186
200
|
try:
|
|
187
201
|
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
|
|
188
|
-
if isinstance(tool_call.args,
|
|
202
|
+
if isinstance(tool_call.args, _messages.ArgsJson):
|
|
189
203
|
result = self.type_adapter.validate_json(
|
|
190
204
|
tool_call.args.args_json or '', experimental_allow_partial=pyd_allow_partial
|
|
191
205
|
)
|
|
@@ -195,16 +209,16 @@ class ResultTool(Generic[ResultData]):
|
|
|
195
209
|
)
|
|
196
210
|
except ValidationError as e:
|
|
197
211
|
if wrap_validation_errors:
|
|
198
|
-
m =
|
|
212
|
+
m = _messages.RetryPromptPart(
|
|
199
213
|
tool_name=tool_call.tool_name,
|
|
200
214
|
content=e.errors(include_url=False),
|
|
201
|
-
|
|
215
|
+
tool_call_id=tool_call.tool_call_id,
|
|
202
216
|
)
|
|
203
217
|
raise ToolRetryError(m) from e
|
|
204
218
|
else:
|
|
205
219
|
raise
|
|
206
220
|
else:
|
|
207
|
-
if k := self.outer_typed_dict_key:
|
|
221
|
+
if k := self.tool_def.outer_typed_dict_key:
|
|
208
222
|
result = result[k]
|
|
209
223
|
return result
|
|
210
224
|
|
pydantic_ai/_system_prompt.py
CHANGED
pydantic_ai/_utils.py
CHANGED
|
@@ -8,12 +8,16 @@ from dataclasses import dataclass, is_dataclass
|
|
|
8
8
|
from datetime import datetime, timezone
|
|
9
9
|
from functools import partial
|
|
10
10
|
from types import GenericAlias
|
|
11
|
-
from typing import Any, Callable, Generic, TypeVar, Union, cast, overload
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload
|
|
12
12
|
|
|
13
13
|
from pydantic import BaseModel
|
|
14
14
|
from pydantic.json_schema import JsonSchemaValue
|
|
15
15
|
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
|
|
16
16
|
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from .messages import RetryPromptPart, ToolCallPart, ToolReturnPart
|
|
19
|
+
from .tools import ObjectJsonSchema
|
|
20
|
+
|
|
17
21
|
_P = ParamSpec('_P')
|
|
18
22
|
_R = TypeVar('_R')
|
|
19
23
|
|
|
@@ -39,10 +43,6 @@ def is_model_like(type_: Any) -> bool:
|
|
|
39
43
|
)
|
|
40
44
|
|
|
41
45
|
|
|
42
|
-
# With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
|
|
43
|
-
ObjectJsonSchema: TypeAlias = dict[str, Any]
|
|
44
|
-
|
|
45
|
-
|
|
46
46
|
def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
|
|
47
47
|
from .exceptions import UserError
|
|
48
48
|
|
|
@@ -88,7 +88,7 @@ class Either(Generic[Left, Right]):
|
|
|
88
88
|
|
|
89
89
|
Usage:
|
|
90
90
|
|
|
91
|
-
```
|
|
91
|
+
```python
|
|
92
92
|
if left_thing := either.left:
|
|
93
93
|
use_left(left_thing.value)
|
|
94
94
|
else:
|
|
@@ -127,6 +127,12 @@ class Either(Generic[Left, Right]):
|
|
|
127
127
|
def whichever(self) -> Left | Right:
|
|
128
128
|
return self._left.value if self._left is not None else self.right
|
|
129
129
|
|
|
130
|
+
def __repr__(self):
|
|
131
|
+
if left := self._left:
|
|
132
|
+
return f'Either(left={left.value!r})'
|
|
133
|
+
else:
|
|
134
|
+
return f'Either(right={self.right!r})'
|
|
135
|
+
|
|
130
136
|
|
|
131
137
|
@asynccontextmanager
|
|
132
138
|
async def group_by_temporal(
|
|
@@ -141,7 +147,7 @@ async def group_by_temporal(
|
|
|
141
147
|
|
|
142
148
|
Usage:
|
|
143
149
|
|
|
144
|
-
```
|
|
150
|
+
```python
|
|
145
151
|
async with group_by_temporal(yield_groups(), 0.1) as groups_iter:
|
|
146
152
|
async for groups in groups_iter:
|
|
147
153
|
print(groups)
|
|
@@ -218,7 +224,7 @@ async def group_by_temporal(
|
|
|
218
224
|
|
|
219
225
|
try:
|
|
220
226
|
yield async_iter_groups()
|
|
221
|
-
finally:
|
|
227
|
+
finally: # pragma: no cover
|
|
222
228
|
# after iteration if a tasks still exists, cancel it, this will only happen if an error occurred
|
|
223
229
|
if task:
|
|
224
230
|
task.cancel('Cancelling due to error in iterator')
|
|
@@ -249,3 +255,9 @@ def sync_anext(iterator: Iterator[T]) -> T:
|
|
|
249
255
|
|
|
250
256
|
def now_utc() -> datetime:
|
|
251
257
|
return datetime.now(tz=timezone.utc)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def guard_tool_call_id(t: ToolCallPart | ToolReturnPart | RetryPromptPart, model_source: str) -> str:
|
|
261
|
+
"""Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
|
|
262
|
+
assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
|
|
263
|
+
return t.tool_call_id
|