pydantic-ai-slim 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
- pydantic_ai/__init__.py +14 -3
- pydantic_ai/_result.py +6 -9
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/agent.py +154 -90
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +29 -7
- pydantic_ai/models/__init__.py +10 -9
- pydantic_ai/models/anthropic.py +12 -12
- pydantic_ai/models/function.py +16 -22
- pydantic_ai/models/gemini.py +16 -18
- pydantic_ai/models/groq.py +21 -23
- pydantic_ai/models/mistral.py +34 -51
- pydantic_ai/models/openai.py +21 -23
- pydantic_ai/models/test.py +23 -17
- pydantic_ai/result.py +82 -35
- pydantic_ai/settings.py +69 -0
- pydantic_ai/tools.py +22 -28
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.15.dist-info}/METADATA +1 -2
- pydantic_ai_slim-0.0.15.dist-info/RECORD +26 -0
- pydantic_ai_slim-0.0.13.dist-info/RECORD +0 -26
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.15.dist-info}/WHEEL +0 -0

pydantic_ai/result.py
CHANGED

@@ -4,52 +4,72 @@ from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Callable
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Generic,
+from typing import Generic, Union, cast

 import logfire_api
+from typing_extensions import TypeVar

 from . import _result, _utils, exceptions, messages as _messages, models
-from .
+from .settings import UsageLimits
+from .tools import AgentDeps, RunContext

 __all__ = (
     'ResultData',
-    'Cost',
+    'ResultValidatorFunc',
+    'Usage',
     'RunResult',
     'StreamedRunResult',
 )


-ResultData = TypeVar('ResultData')
+ResultData = TypeVar('ResultData', default=str)
 """Type variable for the result data of a run."""

+ResultValidatorFunc = Union[
+    Callable[[RunContext[AgentDeps], ResultData], ResultData],
+    Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
+    Callable[[ResultData], ResultData],
+    Callable[[ResultData], Awaitable[ResultData]],
+]
+"""
+A function that always takes `ResultData` and returns `ResultData` and:
+
+* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
+* may or may not be async
+
+Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
+"""
+
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


 @dataclass
-class Cost:
-    """
+class Usage:
+    """LLM usage associated with a request or run.

-    Responsibility for calculating
+    Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.

-    You'll need to look up the documentation of the model you're using to
+    You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
     """

+    requests: int = 0
+    """Number of requests made to the LLM API."""
     request_tokens: int | None = None
-    """Tokens used in processing
+    """Tokens used in processing requests."""
     response_tokens: int | None = None
-    """Tokens used in generating
+    """Tokens used in generating responses."""
     total_tokens: int | None = None
     """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
     details: dict[str, int] | None = None
     """Any extra details returned by the model."""

-    def __add__(self, other: Cost) -> Cost:
-        """Add two Costs together.
+    def __add__(self, other: Usage) -> Usage:
+        """Add two Usages together.

-        This is provided so it's trivial to sum
+        This is provided so it's trivial to sum usage information from multiple requests and runs.
         """
         counts: dict[str, int] = {}
-        for f in 'request_tokens', 'response_tokens', 'total_tokens':
+        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
             self_value = getattr(self, f)
             other_value = getattr(other, f)
             if self_value is not None or other_value is not None:
@@ -61,7 +81,7 @@ class Cost:
             for key, value in other.details.items():
                 details[key] = details.get(key, 0) + value

-        return Cost(**counts, details=details or None)
+        return Usage(**counts, details=details or None)


 @dataclass
@@ -95,7 +115,7 @@ class _BaseRunResult(ABC, Generic[ResultData]):
         return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())

     @abstractmethod
-    def cost(self) -> Cost:
+    def usage(self) -> Usage:
         raise NotImplementedError()


@@ -105,22 +125,23 @@ class RunResult(_BaseRunResult[ResultData]):

     data: ResultData
     """Data from the final response in the run."""
-    _cost: Cost
+    _usage: Usage

-    def cost(self) -> Cost:
-        """Return the cost of the whole run."""
-        return self._cost
+    def usage(self) -> Usage:
+        """Return the usage of the whole run."""
+        return self._usage


 @dataclass
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     """Result of a streamed run that returns structured data via a tool call."""

-    cost_so_far: Cost
-    """Cost of the run up until the last request."""
+    usage_so_far: Usage
+    """Usage of the run up until the last request."""
+    _usage_limits: UsageLimits | None
     _stream_response: models.EitherStreamedResponse
     _result_schema: _result.ResultSchema[ResultData] | None
-    _deps: AgentDeps
+    _run_ctx: RunContext[AgentDeps]
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
     _result_tool_name: str | None
     _on_complete: Callable[[], Awaitable[None]]
@@ -173,11 +194,15 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
         Debouncing is particularly important for long structured responses to reduce the overhead of
         performing validation as each token is received.
         """
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
         with _logfire.span('response stream text') as lf_span:
             if isinstance(self._stream_response, models.StreamStructuredResponse):
                 raise exceptions.UserError('stream_text() can only be used with text responses')
             if delta:
-                async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+                async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                     async for _ in group_iter:
                         yield ''.join(self._stream_response.get())
                 final_delta = ''.join(self._stream_response.get(final=True))
@@ -188,7 +213,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
                 # yielding at each step
                 chunks: list[str] = []
                 combined = ''
-                async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+                async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                     async for _ in group_iter:
                         new = False
                         for chunk in self._stream_response.get():
@@ -225,6 +250,10 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
         Returns:
             An async iterable of the structured response message and whether that is the last message.
         """
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
         with _logfire.span('response stream structured') as lf_span:
             if isinstance(self._stream_response, models.StreamTextResponse):
                 raise exceptions.UserError('stream_structured() can only be used with structured responses')
@@ -235,7 +264,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
                 if isinstance(item, _messages.ToolCallPart) and item.has_content():
                     yield msg, False
                     break
-            async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+            async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                 async for _ in group_iter:
                     msg = self._stream_response.get()
                     for item in msg.parts:
@@ -249,8 +278,13 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat

     async def get_data(self) -> ResultData:
         """Stream the whole response, validate and return it."""
-        async for _ in self._stream_response:
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
+        async for _ in usage_checking_stream:
             pass
+
         if isinstance(self._stream_response, models.StreamTextResponse):
             text = ''.join(self._stream_response.get(final=True))
             text = await self._validate_text_result(text)
@@ -266,13 +300,13 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
         """Return whether the stream response contains structured data (as opposed to text)."""
         return isinstance(self._stream_response, models.StreamStructuredResponse)

-    def cost(self) -> Cost:
-        """Return the cost of the whole run.
+    def usage(self) -> Usage:
+        """Return the usage of the whole run.

         !!! note
-            This won't return the full cost until the stream is finished.
+            This won't return the full usage until the stream is finished.
         """
-        return self.cost_so_far + self._stream_response.cost()
+        return self.usage_so_far + self._stream_response.usage()

     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
@@ -294,17 +328,15 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
             result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)

             for validator in self._result_validators:
-                result_data = await validator.validate(result_data,
+                result_data = await validator.validate(result_data, call, self._run_ctx)
             return result_data

     async def _validate_text_result(self, text: str) -> str:
         for validator in self._result_validators:
             text = await validator.validate(  # pyright: ignore[reportAssignmentType]
                 text,  # pyright: ignore[reportArgumentType]
-                self._deps,
-                0,
                 None,
-                self.
+                self._run_ctx,
             )
         return text

@@ -312,3 +344,18 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
         self.is_complete = True
         self._all_messages.append(message)
         await self._on_complete()
+
+
+def _get_usage_checking_stream_response(
+    stream_response: AsyncIterator[ResultData], limits: UsageLimits | None, get_usage: Callable[[], Usage]
+) -> AsyncIterator[ResultData]:
+    if limits is not None and limits.has_token_limits():
+
+        async def _usage_checking_iterator():
+            async for item in stream_response:
+                limits.check_tokens(get_usage())
+                yield item
+
+        return _usage_checking_iterator()
+    else:
+        return stream_response
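
The renamed `Usage` dataclass above supports `+` via `__add__`, so per-request usage can be summed. A minimal sketch using only the fields defined in this diff:

from pydantic_ai.result import Usage

# Usage as two successive model requests might report it.
first = Usage(requests=1, request_tokens=50, response_tokens=10, total_tokens=60)
second = Usage(requests=1, request_tokens=70, response_tokens=15, total_tokens=85)

# Each counter is summed; `requests` is now included in the loop, so the
# combined value also tracks how many requests were made across the run.
combined = first + second
assert combined.requests == 2
assert combined.total_tokens == 145

`ResultValidatorFunc`, now exported from this module, accepts four call shapes. Two hypothetical validators sketching the bare and context-taking forms (the bodies are illustrative only, not from this release):

from pydantic_ai.exceptions import ModelRetry
from pydantic_ai.tools import RunContext

def must_be_short(data: str) -> str:
    # Hypothetical bare synchronous form: takes and returns ResultData only.
    if len(data) > 280:
        raise ModelRetry('result too long, please shorten it')
    return data

async def log_tool(ctx: RunContext[None], data: str) -> str:
    # Hypothetical context-taking async form: RunContext first, then ResultData.
    print(f'validated result, last tool: {ctx.tool_name!r}')
    return data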

pydantic_ai/settings.py
CHANGED

@@ -1,8 +1,16 @@
 from __future__ import annotations

+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
 from httpx import Timeout
 from typing_extensions import TypedDict

+from .exceptions import UsageLimitExceeded
+
+if TYPE_CHECKING:
+    from .result import Usage
+

 class ModelSettings(TypedDict, total=False):
     """Settings to configure an LLM.
@@ -14,6 +22,7 @@ class ModelSettings(TypedDict, total=False):
     """The maximum number of tokens to generate before stopping.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -29,6 +38,7 @@ class ModelSettings(TypedDict, total=False):
     Note that even with `temperature` of `0.0`, the results will not be fully deterministic.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -43,6 +53,7 @@ class ModelSettings(TypedDict, total=False):
     You should either alter `temperature` or `top_p`, but not both.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -53,6 +64,7 @@ class ModelSettings(TypedDict, total=False):
     """Override the client-level default timeout for a request, in seconds.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -70,3 +82,60 @@ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings |
         return base | overrides
     else:
         return base or overrides
+
+
+@dataclass
+class UsageLimits:
+    """Limits on model usage.
+
+    The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model.
+    Token counts are provided in responses from the model, and the token limits are checked after each response.
+
+    Each of the limits can be set to `None` to disable that limit.
+    """
+
+    request_limit: int | None = 50
+    """The maximum number of requests allowed to the model."""
+    request_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in requests to the model."""
+    response_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in responses from the model."""
+    total_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in requests and responses combined."""
+
+    def has_token_limits(self) -> bool:
+        """Returns `True` if this instance places any limits on token counts.
+
+        If this returns `False`, the `check_tokens` method will never raise an error.
+
+        This is useful because if we have token limits, we need to check them after receiving each streamed message.
+        If there are no limits, we can skip that processing in the streaming response iterator.
+        """
+        return any(
+            limit is not None
+            for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit)
+        )
+
+    def check_before_request(self, usage: Usage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit."""
+        request_limit = self.request_limit
+        if request_limit is not None and usage.requests >= request_limit:
+            raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}')
+
+    def check_tokens(self, usage: Usage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits."""
+        request_tokens = usage.request_tokens or 0
+        if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit:
+            raise UsageLimitExceeded(
+                f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})'
+            )
+
+        response_tokens = usage.response_tokens or 0
+        if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit:
+            raise UsageLimitExceeded(
+                f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
+            )
+
+        total_tokens = request_tokens + response_tokens
+        if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
+            raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
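
A short sketch of how these checks behave in isolation, constructing `Usage` values by hand (in a real run they come from the model):

from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.result import Usage
from pydantic_ai.settings import UsageLimits

limits = UsageLimits(request_limit=2, response_tokens_limit=100)
assert limits.has_token_limits()  # a token limit is set, so streamed responses must be checked

# Under both limits: neither check raises.
limits.check_before_request(Usage(requests=1, response_tokens=40))
limits.check_tokens(Usage(requests=1, response_tokens=40))

try:
    # Two requests already made, so a third would exceed request_limit=2.
    limits.check_before_request(Usage(requests=2))
except UsageLimitExceeded as e:
    print(e)  # The next request would exceed the request_limit of 2

try:
    # 150 response tokens exceeds response_tokens_limit=100.
    limits.check_tokens(Usage(requests=2, response_tokens=150))
except UsageLimitExceeded as e:
    print(e)  # Exceeded the response_tokens_limit of 100 (response_tokens=150)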

pydantic_ai/tools.py
CHANGED

@@ -1,27 +1,21 @@
 from __future__ import annotations as _annotations

+import dataclasses
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import
+from typing import Any, Callable, Generic, TypeVar, Union, cast

 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
 from typing_extensions import Concatenate, ParamSpec, TypeAlias

-from . import _pydantic, _utils, messages as _messages
+from . import _pydantic, _utils, messages as _messages, models
 from .exceptions import ModelRetry, UnexpectedModelBehavior

-if TYPE_CHECKING:
-    from .result import ResultData
-else:
-    ResultData = Any
-
-
 __all__ = (
     'AgentDeps',
     'RunContext',
-    'ResultValidatorFunc',
     'SystemPromptFunc',
     'ToolFuncContext',
     'ToolFuncPlain',
@@ -37,7 +31,7 @@ AgentDeps = TypeVar('AgentDeps')
 """Type variable for agent dependencies."""


-@dataclass
+@dataclasses.dataclass
 class RunContext(Generic[AgentDeps]):
     """Information about the current call."""

@@ -49,6 +43,19 @@ class RunContext(Generic[AgentDeps]):
     """Messages exchanged in the conversation so far."""
     tool_name: str | None
     """Name of the tool being called."""
+    model: models.Model
+    """The model used in this run."""
+
+    def replace_with(
+        self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
+    ) -> RunContext[AgentDeps]:
+        # Create a new `RunContext` with a new `retry` value and `tool_name`.
+        kwargs = {}
+        if retry is not None:
+            kwargs['retry'] = retry
+        if tool_name is not _utils.UNSET:
+            kwargs['tool_name'] = tool_name
+        return dataclasses.replace(self, **kwargs)


 ToolParams = ParamSpec('ToolParams')
@@ -65,19 +72,6 @@ SystemPromptFunc = Union[
 Usage `SystemPromptFunc[AgentDeps]`.
 """

-ResultValidatorFunc = Union[
-    Callable[[RunContext[AgentDeps], ResultData], ResultData],
-    Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
-    Callable[[ResultData], ResultData],
-    Callable[[ResultData], Awaitable[ResultData]],
-]
-"""
-A function that always takes `ResultData` and returns `ResultData`,
-but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
-
-Usage `ResultValidator[AgentDeps, ResultData]`.
-"""
-
 ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any]
 """A tool function that takes `RunContext` as the first argument.

@@ -238,7 +232,7 @@ class Tool(Generic[AgentDeps]):
         return tool_def

     async def run(
-        self,
+        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDeps]
     ) -> _messages.ModelRequestPart:
         """Run the tool function asynchronously."""
         try:
@@ -249,7 +243,7 @@ class Tool(Generic[AgentDeps]):
         except ValidationError as e:
             return self._on_error(e, message)

-        args, kwargs = self._call_args(
+        args, kwargs = self._call_args(args_dict, message, run_context)
         try:
             if self._is_async:
                 function = cast(Callable[[Any], Awaitable[str]], self.function)
@@ -269,15 +263,15 @@ class Tool(Generic[AgentDeps]):

     def _call_args(
         self,
-        deps: AgentDeps,
         args_dict: dict[str, Any],
         message: _messages.ToolCallPart,
-
+        run_context: RunContext[AgentDeps],
     ) -> tuple[list[Any], dict[str, Any]]:
         if self._single_arg_name:
             args_dict = {self._single_arg_name: args_dict}

-
+        ctx = dataclasses.replace(run_context, retry=self.current_retry, tool_name=message.tool_name)
+        args = [ctx] if self.takes_ctx else []
         for positional_field in self._positional_fields:
             args.append(args_dict.pop(positional_field))
         if self._var_positional_field:
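
Both `replace_with` and the new `_call_args` lean on `dataclasses.replace`, which copies a dataclass instance while overriding selected fields. A standalone sketch of that pattern with a stand-in class (not the real `RunContext`, which carries more fields):

from __future__ import annotations

import dataclasses


@dataclasses.dataclass
class Ctx:
    # Hypothetical stand-in with just the two fields the code above overrides.
    retry: int
    tool_name: str | None


base = Ctx(retry=0, tool_name=None)
# As in Tool._call_args: copy the shared run context, setting the retry
# count and tool name for this particular tool invocation.
per_tool = dataclasses.replace(base, retry=1, tool_name='get_weather')
assert base.retry == 0 and per_tool.retry == 1
assert per_tool.tool_name == 'get_weather'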
{pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.15.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.13
+Version: 0.0.15
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -36,7 +36,6 @@ Requires-Dist: groq>=0.12.0; extra == 'groq'
 Provides-Extra: logfire
 Requires-Dist: logfire>=2.3; extra == 'logfire'
 Provides-Extra: mistral
-Requires-Dist: json-repair>=0.30.3; extra == 'mistral'
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
 Requires-Dist: openai>=1.54.3; extra == 'openai'

pydantic_ai_slim-0.0.15.dist-info/RECORD
ADDED

@@ -0,0 +1,26 @@
+pydantic_ai/__init__.py,sha256=FbYetEgT6OO25u2KF5ZnFxKpz5DtnSpfckRXP4mjl8E,489
+pydantic_ai/_griffe.py,sha256=pRjCJ6B1hhx6k46XJgl9zF6aRYxRmqEZKFok8unp4Iw,3449
+pydantic_ai/_pydantic.py,sha256=qXi5IsyiYOHeg_-qozCdxkfeqw2z0gBTjqgywBCiJWo,8125
+pydantic_ai/_result.py,sha256=cUSugZQV0n5Z4fFHiMqua-2xs_0S6m-rr-yd6QS3nFE,10317
+pydantic_ai/_system_prompt.py,sha256=MZJWksIoS5GM3Au5lznlcQnC-h7eqwtE7oI5WFgRcOg,1090
+pydantic_ai/_utils.py,sha256=skWNgm89US_x1EpxdRy5wCkghBrm1XgxFCiEh6wAkAo,8753
+pydantic_ai/agent.py,sha256=qa3Ox5pXEDzxcTJgwN0gebV37qQKizVc0PW-1q5MMn4,51662
+pydantic_ai/exceptions.py,sha256=eGDKX6bGhgVxXBzu81Sk3iiAkXr0GUtgT7bD5Rxlqpg,2028
+pydantic_ai/messages.py,sha256=ImbWY8Ft3mxInUQ08EmIWywf4nJBvTiJhmsECRYDkSQ,8968
+pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pydantic_ai/result.py,sha256=n2cEFwm8WhFzHuT6KRhZ2itQVPShGUd7ECbOPmRIIoM,15335
+pydantic_ai/settings.py,sha256=R71rBg2u2SgjxKWcpUdSzm7icV5_apF3b0BlBqa2lpA,4927
+pydantic_ai/tools.py,sha256=wNYzfdp1XjIVw_8bqh5GP3x3k_12gHTC74HWNvHAAwI,11447
+pydantic_ai/models/__init__.py,sha256=XHt02IDQAircb-lEkIbIcuabSAIh5_UKnz2V1xN0Glw,10926
+pydantic_ai/models/anthropic.py,sha256=EUZgmvT0jhMDbooBp_jfW0z2cM5jTMuAhVws1XKgaNs,13451
+pydantic_ai/models/function.py,sha256=i7qkS_31aHrTbYVh6OzQ7Cwucz44F5PjT2EJK3GMphw,10573
+pydantic_ai/models/gemini.py,sha256=8vdcW4izL9NUGFj6lcD9yIPaakCtsmHauTvKwlTzD14,28207
+pydantic_ai/models/groq.py,sha256=ZoPkuWJrf78JPnTRfZhi7v0ETgxJKNN5dH8BLWagGGk,15770
+pydantic_ai/models/mistral.py,sha256=xGVI6-b8-9vnFickPPI2cRaHEWLc0jKKUM_vMjipf-U,25894
+pydantic_ai/models/ollama.py,sha256=i3mMXkXu9xL6f4c52Eyx3j4aHKfYoloFondlGHPtkS4,3971
+pydantic_ai/models/openai.py,sha256=qFFInL3NbgfGcsAWigxMP5mscp76hC-jJimHc9woU6Y,16518
+pydantic_ai/models/test.py,sha256=pty5qaudHsSDvdE89HqMj-kmd4UMV9VJI2YGtdfOX1o,15960
+pydantic_ai/models/vertexai.py,sha256=DBCBfpvpIhZaMG7cKvRl5rugCZqJqqEFm74uBc45weo,9259
+pydantic_ai_slim-0.0.15.dist-info/METADATA,sha256=CM_cQ6RRb9PFJVXKVa0JIVsp_bCrdCKWGnEu0KBiD0c,2730
+pydantic_ai_slim-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydantic_ai_slim-0.0.15.dist-info/RECORD,,

pydantic_ai_slim-0.0.13.dist-info/RECORD
REMOVED

@@ -1,26 +0,0 @@
-pydantic_ai/__init__.py,sha256=a29NqQz0JyW4BoCjcRh23fBBfwY17_n57moE4QrFWM4,324
-pydantic_ai/_griffe.py,sha256=pRjCJ6B1hhx6k46XJgl9zF6aRYxRmqEZKFok8unp4Iw,3449
-pydantic_ai/_pydantic.py,sha256=qXi5IsyiYOHeg_-qozCdxkfeqw2z0gBTjqgywBCiJWo,8125
-pydantic_ai/_result.py,sha256=LycNJR_Whc9P7sz2uD-Ce5bs9kQBU6eYqQxCUDNiNxU,10453
-pydantic_ai/_system_prompt.py,sha256=2Ui7fYAXxR_aZZLJdwMnOAecBOIbrKwx1yV4Qz523WQ,1089
-pydantic_ai/_utils.py,sha256=skWNgm89US_x1EpxdRy5wCkghBrm1XgxFCiEh6wAkAo,8753
-pydantic_ai/agent.py,sha256=sDQE20lQXyWO-SrodqSNlzzGwtaLNSkla6NgyJXPKTU,48568
-pydantic_ai/exceptions.py,sha256=ko_47M0k6Rhg9mUC9P1cj7N4LCH6cC0pEsF65A2vL-U,1561
-pydantic_ai/messages.py,sha256=Qa9H5kf9qeI1NqB-XgRjPJ-wwgVKvDZxqam7gnsLtrc,8106
-pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pydantic_ai/result.py,sha256=ZhaYArCiVl9JlrTllaeFIl2vU2foiIgpQYGby55G4cQ,13664
-pydantic_ai/settings.py,sha256=3sUaDMVMBX9Pku4Bh7lpv6VizX1utenHd5kVIRJvHyY,1908
-pydantic_ai/tools.py,sha256=hhhe5_ELeyYpaRoETMLjml83FAbMY7cpo5us7qwbOWg,11532
-pydantic_ai/models/__init__.py,sha256=y1tkHgWjIGLhZX95dIeOXcljSiAiRGma2T55hNi8EwA,10897
-pydantic_ai/models/anthropic.py,sha256=YBqiYjvOVqsSHPNT2Vt2aMaAAa8nMK57IMPONrtBCyc,13430
-pydantic_ai/models/function.py,sha256=i54ce97lqmy1p7Vqc162EiF_72oA3khc7z2uBGZbuWg,10731
-pydantic_ai/models/gemini.py,sha256=3oy0FVHLOP8SEOvpvoWtlUhRCJpddRj4-J_IPXaEkLE,28277
-pydantic_ai/models/groq.py,sha256=dorqZk8xbZ4ZDzZothJoWbUkoD8TWHb6lftdkNDlsu0,15821
-pydantic_ai/models/mistral.py,sha256=S0K73J5AGKwJc0UU0ifCrPmcxR2QMvVS6L1Cq19xDrk,26793
-pydantic_ai/models/ollama.py,sha256=i3mMXkXu9xL6f4c52Eyx3j4aHKfYoloFondlGHPtkS4,3971
-pydantic_ai/models/openai.py,sha256=fo3ocIOylD8YTuJMTJR1eXcRAQDGFKWFIYYrSOQS1C0,16569
-pydantic_ai/models/test.py,sha256=mBQ0vJYjEMHv01A3yyHR2porkxekpmqIUBkK-W8d-L8,15530
-pydantic_ai/models/vertexai.py,sha256=DBCBfpvpIhZaMG7cKvRl5rugCZqJqqEFm74uBc45weo,9259
-pydantic_ai_slim-0.0.13.dist-info/METADATA,sha256=57JLefiQcRnSOVvkDHcik6wQ24FZ3HOmPrfCaBUe4X8,2785
-pydantic_ai_slim-0.0.13.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-pydantic_ai_slim-0.0.13.dist-info/RECORD,,

{pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.15.dist-info}/WHEEL
File without changes