pydantic-ai-slim 0.0.13__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_result.py +4 -7
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/agent.py +85 -75
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +29 -7
- pydantic_ai/models/__init__.py +10 -9
- pydantic_ai/models/anthropic.py +12 -12
- pydantic_ai/models/function.py +16 -22
- pydantic_ai/models/gemini.py +16 -18
- pydantic_ai/models/groq.py +21 -23
- pydantic_ai/models/mistral.py +24 -36
- pydantic_ai/models/openai.py +21 -23
- pydantic_ai/models/test.py +23 -17
- pydantic_ai/result.py +63 -33
- pydantic_ai/settings.py +65 -0
- pydantic_ai/tools.py +24 -14
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +1 -1
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- pydantic_ai_slim-0.0.13.dist-info/RECORD +0 -26
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +0 -0
pydantic_ai/models/openai.py
CHANGED

@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, result
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -25,7 +24,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -147,9 +146,9 @@ class OpenAIAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         response = await self._completions_create(messages, False, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
@@ -211,14 +210,14 @@ class OpenAIAgentModel(AgentModel):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart.
+                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
         return ModelResponse(items, timestamp=timestamp)

     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         timestamp: datetime | None = None
-        start_cost = Cost()
+        start_usage = Usage()
         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
         while True:
             try:
@@ -227,19 +226,19 @@ class OpenAIAgentModel(AgentModel):
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e

             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)

             if chunk.choices:
                 delta = chunk.choices[0].delta

                 if delta.content is not None:
-                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
+                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_usage)
                 elif delta.tool_calls is not None:
                     return OpenAIStreamStructuredResponse(
                         response,
                         {c.index: c for c in delta.tool_calls},
                         timestamp,
-                        start_cost,
+                        start_usage,
                     )
                 # else continue until we get either delta.content or delta.tool_calls

@@ -302,7 +301,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: AsyncStream[ChatCompletionChunk]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
     _buffer: list[str] = field(default_factory=list, init=False)

     async def __anext__(self) -> None:
@@ -312,7 +311,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
             return None

         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk)
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -328,8 +327,8 @@ class OpenAIStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -342,11 +341,11 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
     _response: AsyncStream[ChatCompletionChunk]
     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage

     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk)
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -372,37 +371,36 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.
+                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))

         return ModelResponse(items, timestamp=self._timestamp)

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp


 def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
         id=_guard_tool_call_id(t=t, model_source='OpenAI'),
         type='function',
-        function={'name': t.tool_name, 'arguments': t.
+        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
     )


-def _map_cost(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Cost:
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage:
     usage = response.usage
     if usage is None:
-        return result.Cost()
+        return result.Usage()
     else:
         details: dict[str, int] = {}
         if usage.completion_tokens_details is not None:
             details.update(usage.completion_tokens_details.model_dump(exclude_none=True))
         if usage.prompt_tokens_details is not None:
             details.update(usage.prompt_tokens_details.model_dump(exclude_none=True))
-        return result.Cost(
+        return result.Usage(
             request_tokens=usage.prompt_tokens,
             response_tokens=usage.completion_tokens,
             total_tokens=usage.total_tokens,
pydantic_ai/models/test.py
CHANGED

@@ -21,7 +21,7 @@ from ..messages import (
     ToolCallPart,
     ToolReturnPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -31,6 +31,7 @@ from . import (
     StreamStructuredResponse,
     StreamTextResponse,
 )
+from .function import _estimate_string_usage, _estimate_usage  # pyright: ignore[reportPrivateUsage]


 @dataclass
@@ -131,15 +132,17 @@ class TestAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Cost]:
-        return self._request(messages, model_settings), Cost()
+    ) -> tuple[ModelResponse, Usage]:
+        model_response = self._request(messages, model_settings)
+        usage = _estimate_usage([*messages, model_response])
+        return model_response, usage

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[EitherStreamedResponse]:
         msg = self._request(messages, model_settings)
-        cost = Cost()
+        usage = _estimate_usage(messages)

         # TODO: Rework this once we make StreamTextResponse more general
         texts: list[str] = []
@@ -153,9 +156,9 @@ class TestAgentModel(AgentModel):
                 assert_never(item)

         if texts:
-            yield TestStreamTextResponse('\n\n'.join(texts), cost)
+            yield TestStreamTextResponse('\n\n'.join(texts), usage)
         else:
-            yield TestStreamStructuredResponse(msg, cost)
+            yield TestStreamStructuredResponse(msg, usage)

     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
@@ -164,7 +167,7 @@ class TestAgentModel(AgentModel):
         # if there are tools, the first thing we want to do is call all of them
         if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart.
+                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
             )

         if messages:
@@ -176,7 +179,7 @@ class TestAgentModel(AgentModel):
             if new_retry_names:
                 return ModelResponse(
                     parts=[
-                        ToolCallPart.
+                        ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
                         for name, args in self.tool_calls
                         if name in new_retry_names
                     ]
@@ -202,10 +205,10 @@ class TestAgentModel(AgentModel):
             custom_result_args = self.result.right
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return ModelResponse(parts=[ToolCallPart.
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)])
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart.
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)])


 @dataclass
@@ -213,7 +216,7 @@ class TestStreamTextResponse(StreamTextResponse):
     """A text response that streams test data."""

     _text: str
-    _cost: Cost
+    _usage: Usage
     _iter: Iterator[str] = field(init=False)
     _timestamp: datetime = field(default_factory=_utils.now_utc)
     _buffer: list[str] = field(default_factory=list, init=False)
@@ -228,14 +231,17 @@ class TestStreamTextResponse(StreamTextResponse):
         self._iter = iter(words)

     async def __anext__(self) -> None:
-        self._buffer.append(_utils.sync_anext(self._iter))
+        next_str = _utils.sync_anext(self._iter)
+        response_tokens = _estimate_string_usage(next_str)
+        self._usage += Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+        self._buffer.append(next_str)

     def get(self, *, final: bool = False) -> Iterable[str]:
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -246,7 +252,7 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
     """A structured response that streams test data."""

     _structured_response: ModelResponse
-    _cost: Cost
+    _usage: Usage
     _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

@@ -256,8 +262,8 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
     def get(self, *, final: bool = False) -> ModelResponse:
         return self._structured_response

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
pydantic_ai/result.py
CHANGED

@@ -9,11 +9,12 @@ from typing import Generic, TypeVar, cast
 import logfire_api

 from . import _result, _utils, exceptions, messages as _messages, models
-from .tools import AgentDeps
+from .settings import UsageLimits
+from .tools import AgentDeps, RunContext

 __all__ = (
     'ResultData',
-    'Cost',
+    'Usage',
     'RunResult',
     'StreamedRunResult',
 )
@@ -26,30 +27,32 @@ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


 @dataclass
-class Cost:
-    """
+class Usage:
+    """LLM usage associated to a request or run.

-    Responsibility for calculating
+    Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.

-    You'll need to look up the documentation of the model you're using to
+    You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
     """

+    requests: int = 0
+    """Number of requests made."""
     request_tokens: int | None = None
-    """Tokens used in processing
+    """Tokens used in processing requests."""
     response_tokens: int | None = None
-    """Tokens used in generating
+    """Tokens used in generating responses."""
     total_tokens: int | None = None
     """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
     details: dict[str, int] | None = None
     """Any extra details returned by the model."""

-    def __add__(self, other: Cost) -> Cost:
-        """Add two
+    def __add__(self, other: Usage) -> Usage:
+        """Add two Usages together.

-        This is provided so it's trivial to sum
+        This is provided so it's trivial to sum usage information from multiple requests and runs.
         """
         counts: dict[str, int] = {}
-        for f in 'request_tokens', 'response_tokens', 'total_tokens':
+        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
             self_value = getattr(self, f)
             other_value = getattr(other, f)
             if self_value is not None or other_value is not None:
@@ -61,7 +64,7 @@ class Cost:
             for key, value in other.details.items():
                 details[key] = details.get(key, 0) + value

-        return Cost(**counts, details=details or None)
+        return Usage(**counts, details=details or None)


 @dataclass
@@ -95,7 +98,7 @@ class _BaseRunResult(ABC, Generic[ResultData]):
         return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())

     @abstractmethod
-    def cost(self) -> Cost:
+    def usage(self) -> Usage:
         raise NotImplementedError()


@@ -105,22 +108,23 @@ class RunResult(_BaseRunResult[ResultData]):

     data: ResultData
     """Data from the final response in the run."""
-    _cost: Cost
+    _usage: Usage

-    def cost(self) -> Cost:
-        """Return the
-        return self._cost
+    def usage(self) -> Usage:
+        """Return the usage of the whole run."""
+        return self._usage


 @dataclass
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     """Result of a streamed run that returns structured data via a tool call."""

-    cost_so_far: Cost
-    """
+    usage_so_far: Usage
+    """Usage of the run up until the last request."""
+    _usage_limits: UsageLimits | None
     _stream_response: models.EitherStreamedResponse
     _result_schema: _result.ResultSchema[ResultData] | None
-    _deps: AgentDeps
+    _run_ctx: RunContext[AgentDeps]
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
     _result_tool_name: str | None
     _on_complete: Callable[[], Awaitable[None]]
@@ -173,11 +177,15 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         Debouncing is particularly important for long structured responses to reduce the overhead of
         performing validation as each token is received.
         """
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
         with _logfire.span('response stream text') as lf_span:
             if isinstance(self._stream_response, models.StreamStructuredResponse):
                 raise exceptions.UserError('stream_text() can only be used with text responses')
             if delta:
-                async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+                async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                     async for _ in group_iter:
                         yield ''.join(self._stream_response.get())
                     final_delta = ''.join(self._stream_response.get(final=True))
@@ -188,7 +196,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
             # yielding at each step
             chunks: list[str] = []
             combined = ''
-            async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+            async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                 async for _ in group_iter:
                     new = False
                     for chunk in self._stream_response.get():
@@ -225,6 +233,10 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         Returns:
             An async iterable of the structured response message and whether that is the last message.
         """
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
         with _logfire.span('response stream structured') as lf_span:
             if isinstance(self._stream_response, models.StreamTextResponse):
                 raise exceptions.UserError('stream_structured() can only be used with structured responses')
@@ -235,7 +247,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
                 if isinstance(item, _messages.ToolCallPart) and item.has_content():
                     yield msg, False
                     break
-            async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+            async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                 async for _ in group_iter:
                     msg = self._stream_response.get()
                     for item in msg.parts:
@@ -249,8 +261,13 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):

     async def get_data(self) -> ResultData:
         """Stream the whole response, validate and return it."""
-        async for _ in self._stream_response:
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
+        async for _ in usage_checking_stream:
             pass
+
         if isinstance(self._stream_response, models.StreamTextResponse):
             text = ''.join(self._stream_response.get(final=True))
             text = await self._validate_text_result(text)
@@ -266,13 +283,13 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         """Return whether the stream response contains structured data (as opposed to text)."""
         return isinstance(self._stream_response, models.StreamStructuredResponse)

-    def cost(self) -> Cost:
-        """Return the
+    def usage(self) -> Usage:
+        """Return the usage of the whole run.

         !!! note
-            This won't return the full
+            This won't return the full usage until the stream is finished.
         """
-        return self.cost_so_far + self._stream_response.cost()
+        return self.usage_so_far + self._stream_response.usage()

     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
@@ -294,17 +311,15 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
             result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)

         for validator in self._result_validators:
-            result_data = await validator.validate(result_data,
+            result_data = await validator.validate(result_data, call, self._run_ctx)
         return result_data

     async def _validate_text_result(self, text: str) -> str:
         for validator in self._result_validators:
             text = await validator.validate(  # pyright: ignore[reportAssignmentType]
                 text,  # pyright: ignore[reportArgumentType]
-                self._deps,
-                0,
                 None,
-                self.
+                self._run_ctx,
             )
         return text

@@ -312,3 +327,18 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         self.is_complete = True
         self._all_messages.append(message)
         await self._on_complete()
+
+
+def _get_usage_checking_stream_response(
+    stream_response: AsyncIterator[ResultData], limits: UsageLimits | None, get_usage: Callable[[], Usage]
+) -> AsyncIterator[ResultData]:
+    if limits is not None and limits.has_token_limits():
+
+        async def _usage_checking_iterator():
+            async for item in stream_response:
+                limits.check_tokens(get_usage())
+                yield item
+
+        return _usage_checking_iterator()
+    else:
+        return stream_response
pydantic_ai/settings.py
CHANGED

@@ -1,8 +1,16 @@
 from __future__ import annotations

+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
 from httpx import Timeout
 from typing_extensions import TypedDict

+from .exceptions import UsageLimitExceeded
+
+if TYPE_CHECKING:
+    from .result import Usage
+

 class ModelSettings(TypedDict, total=False):
     """Settings to configure an LLM.
@@ -70,3 +78,60 @@ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings |
         return base | overrides
     else:
         return base or overrides
+
+
+@dataclass
+class UsageLimits:
+    """Limits on model usage.
+
+    The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model.
+    Token counts are provided in responses from the model, and the token limits are checked after each response.
+
+    Each of the limits can be set to `None` to disable that limit.
+    """
+
+    request_limit: int | None = 50
+    """The maximum number of requests allowed to the model."""
+    request_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in requests to the model."""
+    response_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in responses from the model."""
+    total_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in requests and responses combined."""
+
+    def has_token_limits(self) -> bool:
+        """Returns `True` if this instance places any limits on token counts.
+
+        If this returns `False`, the `check_tokens` method will never raise an error.
+
+        This is useful because if we have token limits, we need to check them after receiving each streamed message.
+        If there are no limits, we can skip that processing in the streaming response iterator.
+        """
+        return any(
+            limit is not None
+            for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit)
+        )
+
+    def check_before_request(self, usage: Usage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit."""
+        request_limit = self.request_limit
+        if request_limit is not None and usage.requests >= request_limit:
+            raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}')
+
+    def check_tokens(self, usage: Usage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits."""
+        request_tokens = usage.request_tokens or 0
+        if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit:
+            raise UsageLimitExceeded(
+                f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})'
+            )
+
+        response_tokens = usage.response_tokens or 0
+        if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit:
+            raise UsageLimitExceeded(
+                f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
+            )
+
+        total_tokens = request_tokens + response_tokens
+        if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
+            raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')