pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +33 -18
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +366 -171
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +111 -50
- pydantic_ai/models/__init__.py +39 -14
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +62 -40
- pydantic_ai/models/gemini.py +164 -124
- pydantic_ai/models/groq.py +112 -94
- pydantic_ai/models/mistral.py +668 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +120 -96
- pydantic_ai/models/test.py +78 -61
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +96 -68
- pydantic_ai/settings.py +137 -0
- pydantic_ai/tools.py +46 -26
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
pydantic_ai/__init__.py
CHANGED
|
@@ -1,8 +1,18 @@
|
|
|
1
1
|
from importlib.metadata import version
|
|
2
2
|
|
|
3
3
|
from .agent import Agent
|
|
4
|
-
from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
|
|
4
|
+
from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
|
|
5
5
|
from .tools import RunContext, Tool
|
|
6
6
|
|
|
7
|
-
__all__ =
|
|
7
|
+
__all__ = (
|
|
8
|
+
'Agent',
|
|
9
|
+
'RunContext',
|
|
10
|
+
'Tool',
|
|
11
|
+
'AgentRunError',
|
|
12
|
+
'ModelRetry',
|
|
13
|
+
'UnexpectedModelBehavior',
|
|
14
|
+
'UsageLimitExceeded',
|
|
15
|
+
'UserError',
|
|
16
|
+
'__version__',
|
|
17
|
+
)
|
|
8
18
|
__version__ = version('pydantic_ai_slim')
|
pydantic_ai/_pydantic.py
CHANGED
|
@@ -8,7 +8,7 @@ from __future__ import annotations as _annotations
|
|
|
8
8
|
from inspect import Parameter, signature
|
|
9
9
|
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
|
|
10
10
|
|
|
11
|
-
from pydantic import ConfigDict
|
|
11
|
+
from pydantic import ConfigDict
|
|
12
12
|
from pydantic._internal import _decorators, _generate_schema, _typing_extra
|
|
13
13
|
from pydantic._internal._config import ConfigWrapper
|
|
14
14
|
from pydantic.fields import FieldInfo
|
|
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|
|
23
23
|
from .tools import ObjectJsonSchema
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
__all__ = 'function_schema',
|
|
26
|
+
__all__ = ('function_schema',)
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class FunctionSchema(TypedDict):
|
|
@@ -138,14 +138,14 @@ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSc
|
|
|
138
138
|
json_schema = GenerateJsonSchema().generate(schema)
|
|
139
139
|
|
|
140
140
|
# workaround for https://github.com/pydantic/pydantic/issues/10785
|
|
141
|
-
# if we build a custom TypeDict schema (matches when `single_arg_name
|
|
141
|
+
# if we build a custom TypedDict schema (matches when `single_arg_name is None`), we manually set
|
|
142
142
|
# `additionalProperties` in the JSON Schema
|
|
143
143
|
if single_arg_name is None:
|
|
144
144
|
json_schema['additionalProperties'] = bool(var_kwargs_schema)
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
145
|
+
elif not description:
|
|
146
|
+
# if the tool description is not set, and we have a single parameter, take the description from that
|
|
147
|
+
# and set it on the tool
|
|
148
|
+
description = json_schema.pop('description', None)
|
|
149
149
|
|
|
150
150
|
return FunctionSchema(
|
|
151
151
|
description=description,
|
|
@@ -214,21 +214,3 @@ def _is_call_ctx(annotation: Any) -> bool:
|
|
|
214
214
|
return annotation is RunContext or (
|
|
215
215
|
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
|
|
216
216
|
)
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
if TYPE_CHECKING:
|
|
220
|
-
LazyTypeAdapter = TypeAdapter
|
|
221
|
-
else:
|
|
222
|
-
|
|
223
|
-
class LazyTypeAdapter:
|
|
224
|
-
__slots__ = '_args', '_kwargs', '_type_adapter'
|
|
225
|
-
|
|
226
|
-
def __init__(self, *args, **kwargs):
|
|
227
|
-
self._args = args
|
|
228
|
-
self._kwargs = kwargs
|
|
229
|
-
self._type_adapter = None
|
|
230
|
-
|
|
231
|
-
def __getattr__(self, item):
|
|
232
|
-
if self._type_adapter is None:
|
|
233
|
-
self._type_adapter = TypeAdapter(*self._args, **self._kwargs)
|
|
234
|
-
return getattr(self._type_adapter, item)
|
pydantic_ai/_result.py
CHANGED
|
@@ -3,16 +3,15 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import sys
|
|
5
5
|
import types
|
|
6
|
-
from collections.abc import Awaitable
|
|
6
|
+
from collections.abc import Awaitable, Iterable
|
|
7
7
|
from dataclasses import dataclass, field
|
|
8
8
|
from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
|
|
9
9
|
|
|
10
10
|
from pydantic import TypeAdapter, ValidationError
|
|
11
11
|
from typing_extensions import Self, TypeAliasType, TypedDict
|
|
12
12
|
|
|
13
|
-
from . import _utils, messages
|
|
13
|
+
from . import _utils, messages as _messages
|
|
14
14
|
from .exceptions import ModelRetry
|
|
15
|
-
from .messages import ModelStructuredResponse, ToolCall
|
|
16
15
|
from .result import ResultData
|
|
17
16
|
from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
|
|
18
17
|
|
|
@@ -28,21 +27,24 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
28
27
|
self._is_async = inspect.iscoroutinefunction(self.function)
|
|
29
28
|
|
|
30
29
|
async def validate(
|
|
31
|
-
self,
|
|
30
|
+
self,
|
|
31
|
+
result: ResultData,
|
|
32
|
+
tool_call: _messages.ToolCallPart | None,
|
|
33
|
+
run_context: RunContext[AgentDeps],
|
|
32
34
|
) -> ResultData:
|
|
33
35
|
"""Validate a result by calling the function.
|
|
34
36
|
|
|
35
37
|
Args:
|
|
36
38
|
result: The result data after Pydantic validation of the message content.
|
|
37
|
-
deps: The agent dependencies.
|
|
38
|
-
retry: The current retry number.
|
|
39
39
|
tool_call: The original tool call message, `None` if there was no tool call.
|
|
40
|
+
run_context: The current run context.
|
|
40
41
|
|
|
41
42
|
Returns:
|
|
42
43
|
Result of either the validated result data (ok) or a retry message (Err).
|
|
43
44
|
"""
|
|
44
45
|
if self._takes_ctx:
|
|
45
|
-
|
|
46
|
+
ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None)
|
|
47
|
+
args = ctx, result
|
|
46
48
|
else:
|
|
47
49
|
args = (result,)
|
|
48
50
|
|
|
@@ -54,10 +56,10 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
54
56
|
function = cast(Callable[[Any], ResultData], self.function)
|
|
55
57
|
result_data = await _utils.run_in_executor(function, *args)
|
|
56
58
|
except ModelRetry as r:
|
|
57
|
-
m =
|
|
59
|
+
m = _messages.RetryPromptPart(content=r.message)
|
|
58
60
|
if tool_call is not None:
|
|
59
61
|
m.tool_name = tool_call.tool_name
|
|
60
|
-
m.
|
|
62
|
+
m.tool_call_id = tool_call.tool_call_id
|
|
61
63
|
raise ToolRetryError(m) from r
|
|
62
64
|
else:
|
|
63
65
|
return result_data
|
|
@@ -66,7 +68,7 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
66
68
|
class ToolRetryError(Exception):
|
|
67
69
|
"""Internal exception used to signal a `ToolRetry` message should be returned to the LLM."""
|
|
68
70
|
|
|
69
|
-
def __init__(self, tool_retry:
|
|
71
|
+
def __init__(self, tool_retry: _messages.RetryPromptPart):
|
|
70
72
|
self.tool_retry = tool_retry
|
|
71
73
|
super().__init__()
|
|
72
74
|
|
|
@@ -108,11 +110,24 @@ class ResultSchema(Generic[ResultData]):
|
|
|
108
110
|
|
|
109
111
|
return cls(tools=tools, allow_text_result=allow_text_result)
|
|
110
112
|
|
|
111
|
-
def
|
|
113
|
+
def find_named_tool(
|
|
114
|
+
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
|
|
115
|
+
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
|
|
116
|
+
"""Find a tool that matches one of the calls, with a specific name."""
|
|
117
|
+
for part in parts:
|
|
118
|
+
if isinstance(part, _messages.ToolCallPart):
|
|
119
|
+
if part.tool_name == tool_name:
|
|
120
|
+
return part, self.tools[tool_name]
|
|
121
|
+
|
|
122
|
+
def find_tool(
|
|
123
|
+
self,
|
|
124
|
+
parts: Iterable[_messages.ModelResponsePart],
|
|
125
|
+
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
|
|
112
126
|
"""Find a tool that matches one of the calls."""
|
|
113
|
-
for
|
|
114
|
-
if
|
|
115
|
-
|
|
127
|
+
for part in parts:
|
|
128
|
+
if isinstance(part, _messages.ToolCallPart):
|
|
129
|
+
if result := self.tools.get(part.tool_name):
|
|
130
|
+
return part, result
|
|
116
131
|
|
|
117
132
|
def tool_names(self) -> list[str]:
|
|
118
133
|
"""Return the names of the tools."""
|
|
@@ -167,7 +182,7 @@ class ResultTool(Generic[ResultData]):
|
|
|
167
182
|
)
|
|
168
183
|
|
|
169
184
|
def validate(
|
|
170
|
-
self, tool_call:
|
|
185
|
+
self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
|
|
171
186
|
) -> ResultData:
|
|
172
187
|
"""Validate a result message.
|
|
173
188
|
|
|
@@ -181,7 +196,7 @@ class ResultTool(Generic[ResultData]):
|
|
|
181
196
|
"""
|
|
182
197
|
try:
|
|
183
198
|
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
|
|
184
|
-
if isinstance(tool_call.args,
|
|
199
|
+
if isinstance(tool_call.args, _messages.ArgsJson):
|
|
185
200
|
result = self.type_adapter.validate_json(
|
|
186
201
|
tool_call.args.args_json or '', experimental_allow_partial=pyd_allow_partial
|
|
187
202
|
)
|
|
@@ -191,10 +206,10 @@ class ResultTool(Generic[ResultData]):
|
|
|
191
206
|
)
|
|
192
207
|
except ValidationError as e:
|
|
193
208
|
if wrap_validation_errors:
|
|
194
|
-
m =
|
|
209
|
+
m = _messages.RetryPromptPart(
|
|
195
210
|
tool_name=tool_call.tool_name,
|
|
196
211
|
content=e.errors(include_url=False),
|
|
197
|
-
|
|
212
|
+
tool_call_id=tool_call.tool_call_id,
|
|
198
213
|
)
|
|
199
214
|
raise ToolRetryError(m) from e
|
|
200
215
|
else:
|
pydantic_ai/_system_prompt.py
CHANGED
|
@@ -19,9 +19,9 @@ class SystemPromptRunner(Generic[AgentDeps]):
|
|
|
19
19
|
self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
|
|
20
20
|
self._is_async = inspect.iscoroutinefunction(self.function)
|
|
21
21
|
|
|
22
|
-
async def run(self,
|
|
22
|
+
async def run(self, run_context: RunContext[AgentDeps]) -> str:
|
|
23
23
|
if self._takes_ctx:
|
|
24
|
-
args = (
|
|
24
|
+
args = (run_context,)
|
|
25
25
|
else:
|
|
26
26
|
args = ()
|
|
27
27
|
|
pydantic_ai/_utils.py
CHANGED
|
@@ -15,6 +15,7 @@ from pydantic.json_schema import JsonSchemaValue
|
|
|
15
15
|
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
|
|
16
16
|
|
|
17
17
|
if TYPE_CHECKING:
|
|
18
|
+
from .messages import RetryPromptPart, ToolCallPart, ToolReturnPart
|
|
18
19
|
from .tools import ObjectJsonSchema
|
|
19
20
|
|
|
20
21
|
_P = ParamSpec('_P')
|
|
@@ -87,7 +88,7 @@ class Either(Generic[Left, Right]):
|
|
|
87
88
|
|
|
88
89
|
Usage:
|
|
89
90
|
|
|
90
|
-
```
|
|
91
|
+
```python
|
|
91
92
|
if left_thing := either.left:
|
|
92
93
|
use_left(left_thing.value)
|
|
93
94
|
else:
|
|
@@ -146,7 +147,7 @@ async def group_by_temporal(
|
|
|
146
147
|
|
|
147
148
|
Usage:
|
|
148
149
|
|
|
149
|
-
```
|
|
150
|
+
```python
|
|
150
151
|
async with group_by_temporal(yield_groups(), 0.1) as groups_iter:
|
|
151
152
|
async for groups in groups_iter:
|
|
152
153
|
print(groups)
|
|
@@ -254,3 +255,9 @@ def sync_anext(iterator: Iterator[T]) -> T:
|
|
|
254
255
|
|
|
255
256
|
def now_utc() -> datetime:
|
|
256
257
|
return datetime.now(tz=timezone.utc)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def guard_tool_call_id(t: ToolCallPart | ToolReturnPart | RetryPromptPart, model_source: str) -> str:
|
|
261
|
+
"""Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
|
|
262
|
+
assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
|
|
263
|
+
return t.tool_call_id
|