pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +34 -16
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +333 -148
- pydantic_ai/messages.py +87 -48
- pydantic_ai/models/__init__.py +30 -6
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +59 -31
- pydantic_ai/models/gemini.py +150 -108
- pydantic_ai/models/groq.py +94 -74
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +102 -76
- pydantic_ai/models/test.py +62 -51
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +28 -18
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
pydantic_ai/_pydantic.py
CHANGED
|
@@ -8,7 +8,7 @@ from __future__ import annotations as _annotations
|
|
|
8
8
|
from inspect import Parameter, signature
|
|
9
9
|
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
|
|
10
10
|
|
|
11
|
-
from pydantic import ConfigDict
|
|
11
|
+
from pydantic import ConfigDict
|
|
12
12
|
from pydantic._internal import _decorators, _generate_schema, _typing_extra
|
|
13
13
|
from pydantic._internal._config import ConfigWrapper
|
|
14
14
|
from pydantic.fields import FieldInfo
|
|
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|
|
23
23
|
from .tools import ObjectJsonSchema
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
__all__ = 'function_schema',
|
|
26
|
+
__all__ = ('function_schema',)
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class FunctionSchema(TypedDict):
|
|
@@ -138,14 +138,14 @@ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSc
|
|
|
138
138
|
json_schema = GenerateJsonSchema().generate(schema)
|
|
139
139
|
|
|
140
140
|
# workaround for https://github.com/pydantic/pydantic/issues/10785
|
|
141
|
-
# if we build a custom TypeDict schema (matches when `single_arg_name
|
|
141
|
+
# if we build a custom TypedDict schema (matches when `single_arg_name is None`), we manually set
|
|
142
142
|
# `additionalProperties` in the JSON Schema
|
|
143
143
|
if single_arg_name is None:
|
|
144
144
|
json_schema['additionalProperties'] = bool(var_kwargs_schema)
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
145
|
+
elif not description:
|
|
146
|
+
# if the tool description is not set, and we have a single parameter, take the description from that
|
|
147
|
+
# and set it on the tool
|
|
148
|
+
description = json_schema.pop('description', None)
|
|
149
149
|
|
|
150
150
|
return FunctionSchema(
|
|
151
151
|
description=description,
|
|
@@ -214,21 +214,3 @@ def _is_call_ctx(annotation: Any) -> bool:
|
|
|
214
214
|
return annotation is RunContext or (
|
|
215
215
|
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
|
|
216
216
|
)
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
if TYPE_CHECKING:
|
|
220
|
-
LazyTypeAdapter = TypeAdapter
|
|
221
|
-
else:
|
|
222
|
-
|
|
223
|
-
class LazyTypeAdapter:
|
|
224
|
-
__slots__ = '_args', '_kwargs', '_type_adapter'
|
|
225
|
-
|
|
226
|
-
def __init__(self, *args, **kwargs):
|
|
227
|
-
self._args = args
|
|
228
|
-
self._kwargs = kwargs
|
|
229
|
-
self._type_adapter = None
|
|
230
|
-
|
|
231
|
-
def __getattr__(self, item):
|
|
232
|
-
if self._type_adapter is None:
|
|
233
|
-
self._type_adapter = TypeAdapter(*self._args, **self._kwargs)
|
|
234
|
-
return getattr(self._type_adapter, item)
|
pydantic_ai/_result.py
CHANGED
|
@@ -3,16 +3,15 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import sys
|
|
5
5
|
import types
|
|
6
|
-
from collections.abc import Awaitable
|
|
6
|
+
from collections.abc import Awaitable, Iterable
|
|
7
7
|
from dataclasses import dataclass, field
|
|
8
8
|
from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
|
|
9
9
|
|
|
10
10
|
from pydantic import TypeAdapter, ValidationError
|
|
11
11
|
from typing_extensions import Self, TypeAliasType, TypedDict
|
|
12
12
|
|
|
13
|
-
from . import _utils, messages
|
|
13
|
+
from . import _utils, messages as _messages
|
|
14
14
|
from .exceptions import ModelRetry
|
|
15
|
-
from .messages import ModelStructuredResponse, ToolCall
|
|
16
15
|
from .result import ResultData
|
|
17
16
|
from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
|
|
18
17
|
|
|
@@ -28,7 +27,12 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
28
27
|
self._is_async = inspect.iscoroutinefunction(self.function)
|
|
29
28
|
|
|
30
29
|
async def validate(
|
|
31
|
-
self,
|
|
30
|
+
self,
|
|
31
|
+
result: ResultData,
|
|
32
|
+
deps: AgentDeps,
|
|
33
|
+
retry: int,
|
|
34
|
+
tool_call: _messages.ToolCallPart | None,
|
|
35
|
+
messages: list[_messages.ModelMessage],
|
|
32
36
|
) -> ResultData:
|
|
33
37
|
"""Validate a result by calling the function.
|
|
34
38
|
|
|
@@ -37,12 +41,13 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
37
41
|
deps: The agent dependencies.
|
|
38
42
|
retry: The current retry number.
|
|
39
43
|
tool_call: The original tool call message, `None` if there was no tool call.
|
|
44
|
+
messages: The messages exchanged so far in the conversation.
|
|
40
45
|
|
|
41
46
|
Returns:
|
|
42
47
|
The validated result data. If the validator function raises `ModelRetry`, a `ToolRetryError` carrying the retry prompt is raised instead.
|
|
43
48
|
"""
|
|
44
49
|
if self._takes_ctx:
|
|
45
|
-
args = RunContext(deps, retry, tool_call.tool_name if tool_call else None), result
|
|
50
|
+
args = RunContext(deps, retry, messages, tool_call.tool_name if tool_call else None), result
|
|
46
51
|
else:
|
|
47
52
|
args = (result,)
|
|
48
53
|
|
|
@@ -54,10 +59,10 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
54
59
|
function = cast(Callable[[Any], ResultData], self.function)
|
|
55
60
|
result_data = await _utils.run_in_executor(function, *args)
|
|
56
61
|
except ModelRetry as r:
|
|
57
|
-
m =
|
|
62
|
+
m = _messages.RetryPromptPart(content=r.message)
|
|
58
63
|
if tool_call is not None:
|
|
59
64
|
m.tool_name = tool_call.tool_name
|
|
60
|
-
m.
|
|
65
|
+
m.tool_call_id = tool_call.tool_call_id
|
|
61
66
|
raise ToolRetryError(m) from r
|
|
62
67
|
else:
|
|
63
68
|
return result_data
|
|
@@ -66,7 +71,7 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
66
71
|
class ToolRetryError(Exception):
|
|
67
72
|
"""Internal exception used to signal a `ToolRetry` message should be returned to the LLM."""
|
|
68
73
|
|
|
69
|
-
def __init__(self, tool_retry:
|
|
74
|
+
def __init__(self, tool_retry: _messages.RetryPromptPart):
|
|
70
75
|
self.tool_retry = tool_retry
|
|
71
76
|
super().__init__()
|
|
72
77
|
|
|
@@ -108,11 +113,24 @@ class ResultSchema(Generic[ResultData]):
|
|
|
108
113
|
|
|
109
114
|
return cls(tools=tools, allow_text_result=allow_text_result)
|
|
110
115
|
|
|
111
|
-
def
|
|
116
|
+
def find_named_tool(
|
|
117
|
+
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
|
|
118
|
+
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
|
|
119
|
+
"""Find a tool that matches one of the calls, with a specific name."""
|
|
120
|
+
for part in parts:
|
|
121
|
+
if isinstance(part, _messages.ToolCallPart):
|
|
122
|
+
if part.tool_name == tool_name:
|
|
123
|
+
return part, self.tools[tool_name]
|
|
124
|
+
|
|
125
|
+
def find_tool(
|
|
126
|
+
self,
|
|
127
|
+
parts: Iterable[_messages.ModelResponsePart],
|
|
128
|
+
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
|
|
112
129
|
"""Find a tool that matches one of the calls."""
|
|
113
|
-
for
|
|
114
|
-
if
|
|
115
|
-
|
|
130
|
+
for part in parts:
|
|
131
|
+
if isinstance(part, _messages.ToolCallPart):
|
|
132
|
+
if result := self.tools.get(part.tool_name):
|
|
133
|
+
return part, result
|
|
116
134
|
|
|
117
135
|
def tool_names(self) -> list[str]:
|
|
118
136
|
"""Return the names of the tools."""
|
|
@@ -167,7 +185,7 @@ class ResultTool(Generic[ResultData]):
|
|
|
167
185
|
)
|
|
168
186
|
|
|
169
187
|
def validate(
|
|
170
|
-
self, tool_call:
|
|
188
|
+
self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
|
|
171
189
|
) -> ResultData:
|
|
172
190
|
"""Validate a result message.
|
|
173
191
|
|
|
@@ -181,7 +199,7 @@ class ResultTool(Generic[ResultData]):
|
|
|
181
199
|
"""
|
|
182
200
|
try:
|
|
183
201
|
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
|
|
184
|
-
if isinstance(tool_call.args,
|
|
202
|
+
if isinstance(tool_call.args, _messages.ArgsJson):
|
|
185
203
|
result = self.type_adapter.validate_json(
|
|
186
204
|
tool_call.args.args_json or '', experimental_allow_partial=pyd_allow_partial
|
|
187
205
|
)
|
|
@@ -191,10 +209,10 @@ class ResultTool(Generic[ResultData]):
|
|
|
191
209
|
)
|
|
192
210
|
except ValidationError as e:
|
|
193
211
|
if wrap_validation_errors:
|
|
194
|
-
m =
|
|
212
|
+
m = _messages.RetryPromptPart(
|
|
195
213
|
tool_name=tool_call.tool_name,
|
|
196
214
|
content=e.errors(include_url=False),
|
|
197
|
-
|
|
215
|
+
tool_call_id=tool_call.tool_call_id,
|
|
198
216
|
)
|
|
199
217
|
raise ToolRetryError(m) from e
|
|
200
218
|
else:
|
pydantic_ai/_system_prompt.py
CHANGED
pydantic_ai/_utils.py
CHANGED
|
@@ -15,6 +15,7 @@ from pydantic.json_schema import JsonSchemaValue
|
|
|
15
15
|
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
|
|
16
16
|
|
|
17
17
|
if TYPE_CHECKING:
|
|
18
|
+
from .messages import RetryPromptPart, ToolCallPart, ToolReturnPart
|
|
18
19
|
from .tools import ObjectJsonSchema
|
|
19
20
|
|
|
20
21
|
_P = ParamSpec('_P')
|
|
@@ -87,7 +88,7 @@ class Either(Generic[Left, Right]):
|
|
|
87
88
|
|
|
88
89
|
Usage:
|
|
89
90
|
|
|
90
|
-
```
|
|
91
|
+
```python
|
|
91
92
|
if left_thing := either.left:
|
|
92
93
|
use_left(left_thing.value)
|
|
93
94
|
else:
|
|
@@ -146,7 +147,7 @@ async def group_by_temporal(
|
|
|
146
147
|
|
|
147
148
|
Usage:
|
|
148
149
|
|
|
149
|
-
```
|
|
150
|
+
```python
|
|
150
151
|
async with group_by_temporal(yield_groups(), 0.1) as groups_iter:
|
|
151
152
|
async for groups in groups_iter:
|
|
152
153
|
print(groups)
|
|
@@ -254,3 +255,9 @@ def sync_anext(iterator: Iterator[T]) -> T:
|
|
|
254
255
|
|
|
255
256
|
def now_utc() -> datetime:
|
|
256
257
|
return datetime.now(tz=timezone.utc)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def guard_tool_call_id(t: ToolCallPart | ToolReturnPart | RetryPromptPart, model_source: str) -> str:
|
|
261
|
+
"""Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
|
|
262
|
+
assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
|
|
263
|
+
return t.tool_call_id
|