pydantic-ai-slim 0.0.13__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim was flagged by the registry; see its page for details.

@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, result
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -25,7 +24,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -147,9 +146,9 @@ class OpenAIAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         response = await self._completions_create(messages, False, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
@@ -211,14 +210,14 @@ class OpenAIAgentModel(AgentModel):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
         return ModelResponse(items, timestamp=timestamp)

     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         timestamp: datetime | None = None
-        start_cost = Cost()
+        start_usage = Usage()
         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
         while True:
             try:
@@ -227,19 +226,19 @@ class OpenAIAgentModel(AgentModel):
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e

             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)

             if chunk.choices:
                 delta = chunk.choices[0].delta

                 if delta.content is not None:
-                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
+                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_usage)
                 elif delta.tool_calls is not None:
                     return OpenAIStreamStructuredResponse(
                         response,
                         {c.index: c for c in delta.tool_calls},
                         timestamp,
-                        start_cost,
+                        start_usage,
                     )
             # else continue until we get either delta.content or delta.tool_calls

@@ -302,7 +301,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: AsyncStream[ChatCompletionChunk]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
     _buffer: list[str] = field(default_factory=list, init=False)

     async def __anext__(self) -> None:
@@ -312,7 +311,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
             return None

         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk)
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -328,8 +327,8 @@ class OpenAIStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -342,11 +341,11 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
     _response: AsyncStream[ChatCompletionChunk]
     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage

     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk)
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -372,37 +371,36 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))

         return ModelResponse(items, timestamp=self._timestamp)

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp


 def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
         id=_guard_tool_call_id(t=t, model_source='OpenAI'),
         type='function',
-        function={'name': t.tool_name, 'arguments': t.args.args_json},
+        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
     )


-def _map_cost(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Cost:
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage:
     usage = response.usage
     if usage is None:
-        return result.Cost()
+        return result.Usage()
     else:
         details: dict[str, int] = {}
         if usage.completion_tokens_details is not None:
             details.update(usage.completion_tokens_details.model_dump(exclude_none=True))
         if usage.prompt_tokens_details is not None:
             details.update(usage.prompt_tokens_details.model_dump(exclude_none=True))
-        return result.Cost(
+        return result.Usage(
             request_tokens=usage.prompt_tokens,
             response_tokens=usage.completion_tokens,
             total_tokens=usage.total_tokens,
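
The hunks above rename the OpenAI adapter's cost tracking to usage and replace the ArgsJson-specific tool-call constructor with ToolCallPart.from_raw_args. A minimal sketch of the new constructor, inferred only from the call sites visible in this diff (the tool name, arguments and call id below are made up):

from pydantic_ai.messages import ToolCallPart

# from_raw_args appears to accept either a raw JSON string (as the OpenAI adapter passes)
# or a plain dict (as the test model passes), plus an optional tool call id.
part = ToolCallPart.from_raw_args('get_weather', '{"city": "Paris"}', 'call_123')
print(part.args_as_json_str())  # '{"city": "Paris"}'
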
@@ -21,7 +21,7 @@ from ..messages import (
     ToolCallPart,
     ToolReturnPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -31,6 +31,7 @@ from . import (
     StreamStructuredResponse,
     StreamTextResponse,
 )
+from .function import _estimate_string_usage, _estimate_usage  # pyright: ignore[reportPrivateUsage]


 @dataclass
@@ -131,15 +132,17 @@ class TestAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Cost]:
-        return self._request(messages, model_settings), Cost()
+    ) -> tuple[ModelResponse, Usage]:
+        model_response = self._request(messages, model_settings)
+        usage = _estimate_usage([*messages, model_response])
+        return model_response, usage

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[EitherStreamedResponse]:
         msg = self._request(messages, model_settings)
-        cost = Cost()
+        usage = _estimate_usage(messages)

         # TODO: Rework this once we make StreamTextResponse more general
         texts: list[str] = []
@@ -153,9 +156,9 @@ class TestAgentModel(AgentModel):
                 assert_never(item)

         if texts:
-            yield TestStreamTextResponse('\n\n'.join(texts), cost)
+            yield TestStreamTextResponse('\n\n'.join(texts), usage)
         else:
-            yield TestStreamStructuredResponse(msg, cost)
+            yield TestStreamStructuredResponse(msg, usage)

     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
@@ -164,7 +167,7 @@ class TestAgentModel(AgentModel):
         # if there are tools, the first thing we want to do is call all of them
         if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
             )

         if messages:
@@ -176,7 +179,7 @@ class TestAgentModel(AgentModel):
             if new_retry_names:
                 return ModelResponse(
                     parts=[
-                        ToolCallPart.from_dict(name, self.gen_tool_args(args))
+                        ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
                         for name, args in self.tool_calls
                         if name in new_retry_names
                     ]
@@ -202,10 +205,10 @@ class TestAgentModel(AgentModel):
             custom_result_args = self.result.right
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, custom_result_args)])
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)])
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, response_args)])
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)])


 @dataclass
@@ -213,7 +216,7 @@ class TestStreamTextResponse(StreamTextResponse):
     """A text response that streams test data."""

     _text: str
-    _cost: Cost
+    _usage: Usage
     _iter: Iterator[str] = field(init=False)
     _timestamp: datetime = field(default_factory=_utils.now_utc)
     _buffer: list[str] = field(default_factory=list, init=False)
@@ -228,14 +231,17 @@ class TestStreamTextResponse(StreamTextResponse):
         self._iter = iter(words)

     async def __anext__(self) -> None:
-        self._buffer.append(_utils.sync_anext(self._iter))
+        next_str = _utils.sync_anext(self._iter)
+        response_tokens = _estimate_string_usage(next_str)
+        self._usage += Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+        self._buffer.append(next_str)

     def get(self, *, final: bool = False) -> Iterable[str]:
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -246,7 +252,7 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
     """A structured response that streams test data."""

     _structured_response: ModelResponse
-    _cost: Cost
+    _usage: Usage
     _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

@@ -256,8 +262,8 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
     def get(self, *, final: bool = False) -> ModelResponse:
         return self._structured_response

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
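
The hunks above are from the package's test model (TestAgentModel): instead of returning an empty Cost(), it now estimates token usage via the private _estimate_usage/_estimate_string_usage helpers. A rough sketch of the visible effect, assuming the package's public TestModel and Agent APIs (which are not part of this diff):

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel())
result = agent.run_sync('hello')
# Illustrative only: the actual numbers depend on the estimation helpers.
print(result.usage())  # e.g. Usage(requests=1, request_tokens=..., response_tokens=..., total_tokens=...)
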
pydantic_ai/result.py CHANGED
@@ -9,11 +9,12 @@ from typing import Generic, TypeVar, cast
 import logfire_api

 from . import _result, _utils, exceptions, messages as _messages, models
-from .tools import AgentDeps
+from .settings import UsageLimits
+from .tools import AgentDeps, RunContext

 __all__ = (
     'ResultData',
-    'Cost',
+    'Usage',
     'RunResult',
     'StreamedRunResult',
 )
@@ -26,30 +27,32 @@ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


 @dataclass
-class Cost:
-    """Cost of a request or run.
+class Usage:
+    """LLM usage associated to a request or run.

-    Responsibility for calculating costs is on the model used, PydanticAI simply sums the cost of requests.
+    Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.

-    You'll need to look up the documentation of the model you're using to convent "token count" costs to monetary costs.
+    You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
     """

+    requests: int = 0
+    """Number of requests made."""
     request_tokens: int | None = None
-    """Tokens used in processing the request."""
+    """Tokens used in processing requests."""
     response_tokens: int | None = None
-    """Tokens used in generating the response."""
+    """Tokens used in generating responses."""
     total_tokens: int | None = None
     """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
     details: dict[str, int] | None = None
     """Any extra details returned by the model."""

-    def __add__(self, other: Cost) -> Cost:
-        """Add two costs together.
+    def __add__(self, other: Usage) -> Usage:
+        """Add two Usages together.

-        This is provided so it's trivial to sum costs from multiple requests and runs.
+        This is provided so it's trivial to sum usage information from multiple requests and runs.
         """
         counts: dict[str, int] = {}
-        for f in 'request_tokens', 'response_tokens', 'total_tokens':
+        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
             self_value = getattr(self, f)
             other_value = getattr(other, f)
             if self_value is not None or other_value is not None:
@@ -61,7 +64,7 @@ class Cost:
             for key, value in other.details.items():
                 details[key] = details.get(key, 0) + value

-        return Cost(**counts, details=details or None)
+        return Usage(**counts, details=details or None)


 @dataclass
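
A quick illustration of the renamed class: __add__ now also sums the new requests counter alongside the token counts (module path taken from the file header above; the numbers are arbitrary):

from pydantic_ai.result import Usage

total = Usage(requests=1, request_tokens=50, response_tokens=12, total_tokens=62) + Usage(
    requests=1, request_tokens=70, response_tokens=8, total_tokens=78
)
print(total)  # Usage(requests=2, request_tokens=120, response_tokens=20, total_tokens=140, details=None)
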
@@ -95,7 +98,7 @@ class _BaseRunResult(ABC, Generic[ResultData]):
         return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())

     @abstractmethod
-    def cost(self) -> Cost:
+    def usage(self) -> Usage:
         raise NotImplementedError()


@@ -105,22 +108,23 @@ class RunResult(_BaseRunResult[ResultData]):

     data: ResultData
     """Data from the final response in the run."""
-    _cost: Cost
+    _usage: Usage

-    def cost(self) -> Cost:
-        """Return the cost of the whole run."""
-        return self._cost
+    def usage(self) -> Usage:
+        """Return the usage of the whole run."""
+        return self._usage


 @dataclass
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     """Result of a streamed run that returns structured data via a tool call."""

-    cost_so_far: Cost
-    """Cost of the run up until the last request."""
+    usage_so_far: Usage
+    """Usage of the run up until the last request."""
+    _usage_limits: UsageLimits | None
     _stream_response: models.EitherStreamedResponse
     _result_schema: _result.ResultSchema[ResultData] | None
-    _deps: AgentDeps
+    _run_ctx: RunContext[AgentDeps]
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
     _result_tool_name: str | None
     _on_complete: Callable[[], Awaitable[None]]
@@ -173,11 +177,15 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         Debouncing is particularly important for long structured responses to reduce the overhead of
         performing validation as each token is received.
         """
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
         with _logfire.span('response stream text') as lf_span:
             if isinstance(self._stream_response, models.StreamStructuredResponse):
                 raise exceptions.UserError('stream_text() can only be used with text responses')
             if delta:
-                async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+                async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                     async for _ in group_iter:
                         yield ''.join(self._stream_response.get())
                 final_delta = ''.join(self._stream_response.get(final=True))
@@ -188,7 +196,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
                 # yielding at each step
                 chunks: list[str] = []
                 combined = ''
-                async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+                async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                     async for _ in group_iter:
                         new = False
                         for chunk in self._stream_response.get():
@@ -225,6 +233,10 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         Returns:
             An async iterable of the structured response message and whether that is the last message.
         """
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
         with _logfire.span('response stream structured') as lf_span:
             if isinstance(self._stream_response, models.StreamTextResponse):
                 raise exceptions.UserError('stream_structured() can only be used with structured responses')
@@ -235,7 +247,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
                 if isinstance(item, _messages.ToolCallPart) and item.has_content():
                     yield msg, False
                     break
-            async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+            async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                 async for _ in group_iter:
                     msg = self._stream_response.get()
                     for item in msg.parts:
@@ -249,8 +261,13 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):

     async def get_data(self) -> ResultData:
         """Stream the whole response, validate and return it."""
-        async for _ in self._stream_response:
+        usage_checking_stream = _get_usage_checking_stream_response(
+            self._stream_response, self._usage_limits, self.usage
+        )
+
+        async for _ in usage_checking_stream:
             pass
+
         if isinstance(self._stream_response, models.StreamTextResponse):
             text = ''.join(self._stream_response.get(final=True))
             text = await self._validate_text_result(text)
@@ -266,13 +283,13 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         """Return whether the stream response contains structured data (as opposed to text)."""
         return isinstance(self._stream_response, models.StreamStructuredResponse)

-    def cost(self) -> Cost:
-        """Return the cost of the whole run.
+    def usage(self) -> Usage:
+        """Return the usage of the whole run.

         !!! note
-            This won't return the full cost until the stream is finished.
+            This won't return the full usage until the stream is finished.
         """
-        return self.cost_so_far + self._stream_response.cost()
+        return self.usage_so_far + self._stream_response.usage()

     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
@@ -294,17 +311,15 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
             result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)

             for validator in self._result_validators:
-                result_data = await validator.validate(result_data, self._deps, 0, call, self._all_messages)
+                result_data = await validator.validate(result_data, call, self._run_ctx)
             return result_data

     async def _validate_text_result(self, text: str) -> str:
         for validator in self._result_validators:
             text = await validator.validate(  # pyright: ignore[reportAssignmentType]
                 text,  # pyright: ignore[reportArgumentType]
-                self._deps,
-                0,
                 None,
-                self._all_messages,
+                self._run_ctx,
             )
         return text

@@ -312,3 +327,18 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         self.is_complete = True
         self._all_messages.append(message)
         await self._on_complete()
+
+
+def _get_usage_checking_stream_response(
+    stream_response: AsyncIterator[ResultData], limits: UsageLimits | None, get_usage: Callable[[], Usage]
+) -> AsyncIterator[ResultData]:
+    if limits is not None and limits.has_token_limits():
+
+        async def _usage_checking_iterator():
+            async for item in stream_response:
+                limits.check_tokens(get_usage())
+                yield item
+
+        return _usage_checking_iterator()
+    else:
+        return stream_response
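
The streaming methods now wrap the underlying response in _get_usage_checking_stream_response, so token limits are re-checked as each chunk arrives. A sketch of what this looks like from the caller's side, assuming the Agent.run_stream API (which is not shown in this diff):

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')

async def main():
    async with agent.run_stream('Tell me a short joke.') as result:
        async for text in result.stream_text():
            print(text)
        # Per the note in usage(): the total is only complete once the stream has finished.
        print(result.usage())
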
pydantic_ai/settings.py CHANGED
@@ -1,8 +1,16 @@
 from __future__ import annotations

+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
 from httpx import Timeout
 from typing_extensions import TypedDict

+from .exceptions import UsageLimitExceeded
+
+if TYPE_CHECKING:
+    from .result import Usage
+

 class ModelSettings(TypedDict, total=False):
     """Settings to configure an LLM.
@@ -70,3 +78,60 @@ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings |
         return base | overrides
     else:
         return base or overrides
+
+
+@dataclass
+class UsageLimits:
+    """Limits on model usage.
+
+    The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model.
+    Token counts are provided in responses from the model, and the token limits are checked after each response.
+
+    Each of the limits can be set to `None` to disable that limit.
+    """
+
+    request_limit: int | None = 50
+    """The maximum number of requests allowed to the model."""
+    request_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in requests to the model."""
+    response_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in responses from the model."""
+    total_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in requests and responses combined."""
+
+    def has_token_limits(self) -> bool:
+        """Returns `True` if this instance places any limits on token counts.
+
+        If this returns `False`, the `check_tokens` method will never raise an error.
+
+        This is useful because if we have token limits, we need to check them after receiving each streamed message.
+        If there are no limits, we can skip that processing in the streaming response iterator.
+        """
+        return any(
+            limit is not None
+            for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit)
+        )
+
+    def check_before_request(self, usage: Usage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit."""
+        request_limit = self.request_limit
+        if request_limit is not None and usage.requests >= request_limit:
+            raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}')
+
+    def check_tokens(self, usage: Usage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits."""
+        request_tokens = usage.request_tokens or 0
+        if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit:
+            raise UsageLimitExceeded(
+                f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})'
+            )
+
+        response_tokens = usage.response_tokens or 0
+        if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit:
+            raise UsageLimitExceeded(
+                f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
+            )
+
+        total_tokens = request_tokens + response_tokens
+        if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
+            raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
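
A small sketch of how the new UsageLimits and Usage classes fit together, using only the APIs added in this diff (how a UsageLimits instance is threaded into an agent run is not shown here):

from pydantic_ai.result import Usage
from pydantic_ai.settings import UsageLimits

limits = UsageLimits(response_tokens_limit=100)

limits.check_tokens(Usage(request_tokens=40, response_tokens=60, total_tokens=100))  # within limits
limits.check_tokens(Usage(request_tokens=40, response_tokens=160, total_tokens=200))  # raises UsageLimitExceeded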