pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff shows the changes between package versions as published to their respective public registries; it is provided for informational purposes only.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +33 -18
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +366 -171
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +111 -50
- pydantic_ai/models/__init__.py +39 -14
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +62 -40
- pydantic_ai/models/gemini.py +164 -124
- pydantic_ai/models/groq.py +112 -94
- pydantic_ai/models/mistral.py +668 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +120 -96
- pydantic_ai/models/test.py +78 -61
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +96 -68
- pydantic_ai/settings.py +137 -0
- pydantic_ai/tools.py +46 -26
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
pydantic_ai/result.py
CHANGED

@@ -1,19 +1,20 @@
 from __future__ import annotations as _annotations
 
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Callable
+from collections.abc import AsyncIterator, Awaitable, Callable
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Generic, TypeVar, cast
 
 import logfire_api
 
-from . import _result, _utils, exceptions, messages, models
-from .
+from . import _result, _utils, exceptions, messages as _messages, models
+from .settings import UsageLimits
+from .tools import AgentDeps, RunContext
 
 __all__ = (
     'ResultData',
-    '
+    'Usage',
     'RunResult',
     'StreamedRunResult',
 )
@@ -26,30 +27,32 @@ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
 
 @dataclass
-class
-"""
+class Usage:
+    """LLM usage associated to a request or run.
 
-    Responsibility for calculating
+    Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.
 
-    You'll need to look up the documentation of the model you're using to
+    You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
     """
 
+    requests: int = 0
+    """Number of requests made."""
     request_tokens: int | None = None
-    """Tokens used in processing
+    """Tokens used in processing requests."""
     response_tokens: int | None = None
-    """Tokens used in generating
+    """Tokens used in generating responses."""
     total_tokens: int | None = None
     """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
     details: dict[str, int] | None = None
     """Any extra details returned by the model."""
 
-    def __add__(self, other:
-    """Add two
+    def __add__(self, other: Usage) -> Usage:
+        """Add two Usages together.
 
-        This is provided so it's trivial to sum
+        This is provided so it's trivial to sum usage information from multiple requests and runs.
         """
         counts: dict[str, int] = {}
-        for f in 'request_tokens', 'response_tokens', 'total_tokens':
+        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
            self_value = getattr(self, f)
            other_value = getattr(other, f)
            if self_value is not None or other_value is not None:
@@ -61,7 +64,7 @@ class Cost:
            for key, value in other.details.items():
                details[key] = details.get(key, 0) + value
 
-        return
+        return Usage(**counts, details=details or None)
 
 
 @dataclass
@@ -71,19 +74,19 @@ class _BaseRunResult(ABC, Generic[ResultData]):
     You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`.
     """
 
-    _all_messages: list[
+    _all_messages: list[_messages.ModelMessage]
     _new_message_index: int
 
-    def all_messages(self) -> list[
-    """Return the history of
+    def all_messages(self) -> list[_messages.ModelMessage]:
+        """Return the history of _messages."""
         # this is a method to be consistent with the other methods
         return self._all_messages
 
     def all_messages_json(self) -> bytes:
-        """Return all messages from [`all_messages`][
-        return
+        """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes."""
+        return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages())
 
-    def new_messages(self) -> list[
+    def new_messages(self) -> list[_messages.ModelMessage]:
         """Return new messages associated with this run.
 
         System prompts and any messages from older runs are excluded.
@@ -91,11 +94,11 @@ class _BaseRunResult(ABC, Generic[ResultData]):
         return self.all_messages()[self._new_message_index :]
 
     def new_messages_json(self) -> bytes:
-        """Return new messages from [`new_messages`][
-        return
+        """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes."""
+        return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())
 
     @abstractmethod
-    def
+    def usage(self) -> Usage:
         raise NotImplementedError()
 
 
@@ -105,24 +108,26 @@ class RunResult(_BaseRunResult[ResultData]):
 
     data: ResultData
     """Data from the final response in the run."""
-
+    _usage: Usage
 
-    def
-    """Return the
-    return self.
+    def usage(self) -> Usage:
+        """Return the usage of the whole run."""
+        return self._usage
 
 
 @dataclass
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     """Result of a streamed run that returns structured data via a tool call."""
 
-
-    """
+    usage_so_far: Usage
+    """Usage of the run up until the last request."""
+    _usage_limits: UsageLimits | None
     _stream_response: models.EitherStreamedResponse
     _result_schema: _result.ResultSchema[ResultData] | None
-
+    _run_ctx: RunContext[AgentDeps]
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
-
+    _result_tool_name: str | None
+    _on_complete: Callable[[], Awaitable[None]]
     is_complete: bool = field(default=False, init=False)
     """Whether the stream has all been received.
 
@@ -172,11 +177,15 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         Debouncing is particularly important for long structured responses to reduce the overhead of
         performing validation as each token is received.
         """
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
         with _logfire.span('response stream text') as lf_span:
             if isinstance(self._stream_response, models.StreamStructuredResponse):
                 raise exceptions.UserError('stream_text() can only be used with text responses')
             if delta:
-                async with _utils.group_by_temporal(
+                async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                     async for _ in group_iter:
                         yield ''.join(self._stream_response.get())
                 final_delta = ''.join(self._stream_response.get(final=True))
@@ -187,7 +196,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
                 # yielding at each step
                 chunks: list[str] = []
                 combined = ''
-                async with _utils.group_by_temporal(
+                async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                     async for _ in group_iter:
                         new = False
                         for chunk in self._stream_response.get():
@@ -205,11 +214,11 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
                             combined = await self._validate_text_result(''.join(chunks))
                             yield combined
                 lf_span.set_attribute('combined_text', combined)
-                self._marked_completed(
+                await self._marked_completed(_messages.ModelResponse.from_text(combined))
 
     async def stream_structured(
         self, *, debounce_by: float | None = 0.1
-    ) -> AsyncIterator[tuple[
+    ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
         """Stream the response as an async iterable of Structured LLM Messages.
 
         !!! note
@@ -224,61 +233,75 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         Returns:
             An async iterable of the structured response message and whether that is the last message.
         """
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
         with _logfire.span('response stream structured') as lf_span:
             if isinstance(self._stream_response, models.StreamTextResponse):
                 raise exceptions.UserError('stream_structured() can only be used with structured responses')
             else:
                 # we should already have a message at this point, yield that first if it has any content
                 msg = self._stream_response.get()
-
-
-
+                for item in msg.parts:
+                    if isinstance(item, _messages.ToolCallPart) and item.has_content():
+                        yield msg, False
+                        break
+                async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                     async for _ in group_iter:
                         msg = self._stream_response.get()
-
-
+                        for item in msg.parts:
+                            if isinstance(item, _messages.ToolCallPart) and item.has_content():
+                                yield msg, False
+                                break
                 msg = self._stream_response.get(final=True)
                 yield msg, True
                 lf_span.set_attribute('structured_response', msg)
-                self._marked_completed(
+                await self._marked_completed(msg)
 
     async def get_data(self) -> ResultData:
         """Stream the whole response, validate and return it."""
-
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
+        async for _ in usage_checking_stream:
             pass
+
         if isinstance(self._stream_response, models.StreamTextResponse):
             text = ''.join(self._stream_response.get(final=True))
             text = await self._validate_text_result(text)
-            self._marked_completed(text
+            await self._marked_completed(_messages.ModelResponse.from_text(text))
            return cast(ResultData, text)
         else:
-
-            self._marked_completed(
-            return await self.validate_structured_result(
+            message = self._stream_response.get(final=True)
+            await self._marked_completed(message)
+            return await self.validate_structured_result(message)
 
     @property
     def is_structured(self) -> bool:
         """Return whether the stream response contains structured data (as opposed to text)."""
         return isinstance(self._stream_response, models.StreamStructuredResponse)
 
-    def
-    """Return the
+    def usage(self) -> Usage:
+        """Return the usage of the whole run.
 
         !!! note
-            This won't return the full
+            This won't return the full usage until the stream is finished.
         """
-        return self.
+        return self.usage_so_far + self._stream_response.usage()
 
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
         return self._stream_response.timestamp()
 
     async def validate_structured_result(
-        self, message:
+        self, message: _messages.ModelResponse, *, allow_partial: bool = False
     ) -> ResultData:
         """Validate a structured result message."""
         assert self._result_schema is not None, 'Expected _result_schema to not be None'
-
+        assert self._result_tool_name is not None, 'Expected _result_tool_name to not be None'
+        match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
         if match is None:
             raise exceptions.UnexpectedModelBehavior(
                 f'Invalid message, unable to find tool: {self._result_schema.tool_names()}'
@@ -288,29 +311,34 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
 
         for validator in self._result_validators:
-            result_data = await validator.validate(result_data, self.
+            result_data = await validator.validate(result_data, call, self._run_ctx)
         return result_data
 
     async def _validate_text_result(self, text: str) -> str:
         for validator in self._result_validators:
             text = await validator.validate(  # pyright: ignore[reportAssignmentType]
                 text,  # pyright: ignore[reportArgumentType]
-                self._deps,
-                0,
                 None,
+                self._run_ctx,
             )
         return text
 
-    def _marked_completed(
-        self, *, text: str | None = None, structured_message: messages.ModelStructuredResponse | None = None
-    ) -> None:
+    async def _marked_completed(self, message: _messages.ModelResponse) -> None:
         self.is_complete = True
-
-
-
-
-
-
-
-
-
+        self._all_messages.append(message)
+        await self._on_complete()
+
+
+def _get_usage_checking_stream_response(
+    stream_response: AsyncIterator[ResultData], limits: UsageLimits | None, get_usage: Callable[[], Usage]
+) -> AsyncIterator[ResultData]:
+    if limits is not None and limits.has_token_limits():
+
+        async def _usage_checking_iterator():
+            async for item in stream_response:
+                limits.check_tokens(get_usage())
+                yield item
+
+        return _usage_checking_iterator()
+    else:
+        return stream_response
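The renamed `Usage` dataclass (formerly `Cost`, as the hunk header above shows) now counts `requests` alongside tokens, and `__add__` makes it trivial to sum usage across requests. A minimal sketch of that arithmetic, assuming pydantic-ai-slim 0.0.14 is installed and that the branch elided from the hunk sums each pair of values, treating `None` as 0:

```python
from pydantic_ai.result import Usage

first = Usage(requests=1, request_tokens=20, response_tokens=5, total_tokens=25)
second = Usage(requests=1, request_tokens=30, response_tokens=10, total_tokens=40)

combined = first + second  # Usage.__add__ from the diff above
print(combined.requests)      # expected: 2
print(combined.total_tokens)  # expected: 65
```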
pydantic_ai/settings.py
ADDED

@@ -0,0 +1,137 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from httpx import Timeout
+from typing_extensions import TypedDict
+
+from .exceptions import UsageLimitExceeded
+
+if TYPE_CHECKING:
+    from .result import Usage
+
+
+class ModelSettings(TypedDict, total=False):
+    """Settings to configure an LLM.
+
+    Here we include only settings which apply to multiple models / model providers.
+    """
+
+    max_tokens: int
+    """The maximum number of tokens to generate before stopping.
+
+    Supported by:
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+    temperature: float
+    """Amount of randomness injected into the response.
+
+    Use `temperature` closer to `0.0` for analytical / multiple choice, and closer to a model's
+    maximum `temperature` for creative and generative tasks.
+
+    Note that even with `temperature` of `0.0`, the results will not be fully deterministic.
+
+    Supported by:
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+    top_p: float
+    """An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
+
+    So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+
+    You should either alter `temperature` or `top_p`, but not both.
+
+    Supported by:
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+    timeout: float | Timeout
+    """Override the client-level default timeout for a request, in seconds.
+
+    Supported by:
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+
+def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
+    """Merge two sets of model settings, preferring the overrides.
+
+    A common use case is: merge_model_settings(<agent settings>, <run settings>)
+    """
+    # Note: we may want merge recursively if/when we add non-primitive values
+    if base and overrides:
+        return base | overrides
+    else:
+        return base or overrides
+
+
+@dataclass
+class UsageLimits:
+    """Limits on model usage.
+
+    The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model.
+    Token counts are provided in responses from the model, and the token limits are checked after each response.
+
+    Each of the limits can be set to `None` to disable that limit.
+    """
+
+    request_limit: int | None = 50
+    """The maximum number of requests allowed to the model."""
+    request_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in requests to the model."""
+    response_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in responses from the model."""
+    total_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in requests and responses combined."""
+
+    def has_token_limits(self) -> bool:
+        """Returns `True` if this instance places any limits on token counts.
+
+        If this returns `False`, the `check_tokens` method will never raise an error.
+
+        This is useful because if we have token limits, we need to check them after receiving each streamed message.
+        If there are no limits, we can skip that processing in the streaming response iterator.
+        """
+        return any(
+            limit is not None
+            for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit)
+        )
+
+    def check_before_request(self, usage: Usage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit."""
+        request_limit = self.request_limit
+        if request_limit is not None and usage.requests >= request_limit:
+            raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}')
+
+    def check_tokens(self, usage: Usage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits."""
+        request_tokens = usage.request_tokens or 0
+        if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit:
+            raise UsageLimitExceeded(
+                f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})'
+            )
+
+        response_tokens = usage.response_tokens or 0
+        if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit:
+            raise UsageLimitExceeded(
+                f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
+            )
+
+        total_tokens = request_tokens + response_tokens
+        if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
+            raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
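Since `settings.py` is entirely new and self-contained, its helpers can be exercised directly. A short sketch assuming pydantic-ai-slim 0.0.14 is installed; `Usage` comes from the renamed class in `result.py` above, and `UsageLimitExceeded` from the updated `exceptions.py` listed in the file summary:

```python
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.result import Usage
from pydantic_ai.settings import UsageLimits, merge_model_settings

# Run-level settings override base settings key by key (a plain dict union).
merged = merge_model_settings({'temperature': 0.0, 'max_tokens': 256}, {'temperature': 0.7})
print(merged)  # {'temperature': 0.7, 'max_tokens': 256}

# Token limits are checked against accumulated usage after each response.
limits = UsageLimits(request_limit=3, response_tokens_limit=100)
assert limits.has_token_limits()
try:
    limits.check_tokens(Usage(requests=1, request_tokens=50, response_tokens=150))
except UsageLimitExceeded as exc:
    print(exc)  # Exceeded the response_tokens_limit of 100 (response_tokens=150)
```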
pydantic_ai/tools.py
CHANGED

@@ -1,23 +1,18 @@
 from __future__ import annotations as _annotations
 
+import dataclasses
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import
+from typing import Any, Callable, Generic, TypeVar, Union, cast
 
 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
 from typing_extensions import Concatenate, ParamSpec, TypeAlias
 
-from . import _pydantic, _utils, messages
+from . import _pydantic, _utils, messages as _messages, models
 from .exceptions import ModelRetry, UnexpectedModelBehavior
 
-if TYPE_CHECKING:
-    from .result import ResultData
-else:
-    ResultData = Any
-
-
 __all__ = (
     'AgentDeps',
     'RunContext',
@@ -37,7 +32,7 @@ AgentDeps = TypeVar('AgentDeps')
 """Type variable for agent dependencies."""
 
 
-@dataclass
+@dataclasses.dataclass
 class RunContext(Generic[AgentDeps]):
     """Information about the current call."""
 
@@ -45,8 +40,23 @@ class RunContext(Generic[AgentDeps]):
     """Dependencies for the agent."""
     retry: int
     """Number of retries so far."""
-
+    messages: list[_messages.ModelMessage]
+    """Messages exchanged in the conversation so far."""
+    tool_name: str | None
     """Name of the tool being called."""
+    model: models.Model
+    """The model used in this run."""
+
+    def replace_with(
+        self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
+    ) -> RunContext[AgentDeps]:
+        # Create a new `RunContext` a new `retry` value and `tool_name`.
+        kwargs = {}
+        if retry is not None:
+            kwargs['retry'] = retry
+        if tool_name is not _utils.UNSET:
+            kwargs['tool_name'] = tool_name
+        return dataclasses.replace(self, **kwargs)
 
 
 ToolParams = ParamSpec('ToolParams')
@@ -63,6 +73,8 @@ SystemPromptFunc = Union[
 Usage `SystemPromptFunc[AgentDeps]`.
 """
 
+ResultData = TypeVar('ResultData')
+
 ResultValidatorFunc = Union[
     Callable[[RunContext[AgentDeps], ResultData], ResultData],
     Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
@@ -87,7 +99,7 @@ ToolFuncPlain = Callable[ToolParams, Any]
 Usage `ToolPlainFunc[ToolParams]`.
 """
 ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
-"""Either
+"""Either part_kind of tool function.
 
 This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
 [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
@@ -97,11 +109,11 @@ Usage `ToolFuncEither[AgentDeps, ToolParams]`.
 ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDeps], ToolDefinition], Awaitable[ToolDefinition | None]]'
 """Definition of a function that can prepare a tool definition at call time.
 
-See [tool docs](../
+See [tool docs](../tools.md#tool-prepare) for more information.
 
 Example — here `only_if_42` is valid as a `ToolPrepareFunc`:
 
-```
+```python {lint="not-imports"}
 from typing import Union
 
 from pydantic_ai import RunContext, Tool
@@ -157,7 +169,7 @@ class Tool(Generic[AgentDeps]):
 
     Example usage:
 
-    ```
+    ```python {lint="not-imports"}
     from pydantic_ai import Agent, RunContext, Tool
 
     async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
@@ -168,7 +180,7 @@ class Tool(Generic[AgentDeps]):
 
     or with a custom prepare method:
 
-    ```
+    ```python {lint="not-imports"}
     from typing import Union
 
     from pydantic_ai import Agent, RunContext, Tool
@@ -235,17 +247,19 @@ class Tool(Generic[AgentDeps]):
         else:
             return tool_def
 
-    async def run(
+    async def run(
+        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDeps]
+    ) -> _messages.ModelRequestPart:
         """Run the tool function asynchronously."""
         try:
-            if isinstance(message.args,
+            if isinstance(message.args, _messages.ArgsJson):
                 args_dict = self._validator.validate_json(message.args.args_json)
             else:
                 args_dict = self._validator.validate_python(message.args.args_dict)
         except ValidationError as e:
             return self._on_error(e, message)
 
-        args, kwargs = self._call_args(
+        args, kwargs = self._call_args(args_dict, message, run_context)
         try:
             if self._is_async:
                 function = cast(Callable[[Any], Awaitable[str]], self.function)
@@ -257,19 +271,23 @@ class Tool(Generic[AgentDeps]):
             return self._on_error(e, message)
 
         self.current_retry = 0
-        return
+        return _messages.ToolReturnPart(
             tool_name=message.tool_name,
             content=response_content,
-
+            tool_call_id=message.tool_call_id,
         )
 
     def _call_args(
-        self,
+        self,
+        args_dict: dict[str, Any],
+        message: _messages.ToolCallPart,
+        run_context: RunContext[AgentDeps],
     ) -> tuple[list[Any], dict[str, Any]]:
         if self._single_arg_name:
             args_dict = {self._single_arg_name: args_dict}
 
-
+        ctx = dataclasses.replace(run_context, retry=self.current_retry, tool_name=message.tool_name)
+        args = [ctx] if self.takes_ctx else []
         for positional_field in self._positional_fields:
             args.append(args_dict.pop(positional_field))
         if self._var_positional_field:
@@ -277,7 +295,9 @@ class Tool(Generic[AgentDeps]):
 
         return args, args_dict
 
-    def _on_error(
+    def _on_error(
+        self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
+    ) -> _messages.RetryPromptPart:
         self.current_retry += 1
         if self.max_retries is None or self.current_retry > self.max_retries:
             raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
@@ -286,10 +306,10 @@ class Tool(Generic[AgentDeps]):
             content = exc.errors(include_url=False)
         else:
             content = exc.message
-        return
+        return _messages.RetryPromptPart(
             tool_name=call_message.tool_name,
             content=content,
-
+            tool_call_id=call_message.tool_call_id,
         )
 
 
@@ -298,7 +318,7 @@ ObjectJsonSchema: TypeAlias = dict[str, Any]
 
 This type is used to define tools parameters (aka arguments) in [ToolDefinition][pydantic_ai.tools.ToolDefinition].
 
-With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `
+With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any`
 """
 
 
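`RunContext` now carries the conversation `messages` and the `model` alongside `deps`, `retry` and `tool_name`, and gains a `replace_with` helper that the tool-calling machinery uses to swap in the current retry count and tool name. A hedged sketch of constructing a context by hand; `TestModel()` (from the package's `models/test.py`, listed in the file summary above) is assumed to be constructible without arguments:

```python
from pydantic_ai.models.test import TestModel  # assumed no-arg constructor
from pydantic_ai.tools import RunContext

ctx = RunContext(deps='my deps', retry=0, messages=[], tool_name=None, model=TestModel())
updated = ctx.replace_with(retry=2, tool_name='roll_die')

print(updated.retry, updated.tool_name)  # 2 roll_die
print(updated.deps is ctx.deps)          # True; other fields are carried over unchanged
```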