pydantic-ai-slim 0.0.16__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.

pydantic_ai/agent.py CHANGED
@@ -20,9 +20,10 @@ from . import (
     messages as _messages,
     models,
     result,
+    usage as _usage,
 )
 from .result import ResultData
-from .settings import ModelSettings, UsageLimits, merge_model_settings
+from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDeps,
     RunContext,
@@ -192,8 +193,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
-        usage: result.Usage | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
@@ -236,7 +237,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
+            run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages
 
@@ -244,7 +245,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 tool.current_retry = 0
 
             model_settings = merge_model_settings(self.model_settings, model_settings)
-            usage_limits = usage_limits or UsageLimits()
+            usage_limits = usage_limits or _usage.UsageLimits()
 
             while True:
                 usage_limits.check_before_request(run_context.usage)
@@ -272,11 +273,14 @@ class Agent(Generic[AgentDeps, ResultData]):
                    # Check if we got a final result
                    if final_result is not None:
                        result_data = final_result.data
+                       result_tool_name = final_result.tool_name
                        run_span.set_attribute('all_messages', messages)
                        run_span.set_attribute('usage', run_context.usage)
                        handle_span.set_attribute('result', result_data)
                        handle_span.message = 'handle model response -> final result'
-                       return result.RunResult(messages, new_message_index, result_data, run_context.usage)
+                       return result.RunResult(
+                           messages, new_message_index, result_data, result_tool_name, run_context.usage
+                       )
                    else:
                        # continue the conversation
                        handle_span.set_attribute('tool_responses', tool_responses)
@@ -291,8 +295,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
-        usage: result.Usage | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
@@ -349,8 +353,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
-        usage: result.Usage | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -396,7 +400,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
+            run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages
 
@@ -404,7 +408,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 tool.current_retry = 0
 
             model_settings = merge_model_settings(self.model_settings, model_settings)
-            usage_limits = usage_limits or UsageLimits()
+            usage_limits = usage_limits or _usage.UsageLimits()
 
             while True:
                 run_context.run_step += 1
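
The caller-visible effect of these hunks is that `Usage` and `UsageLimits` now come from the new `pydantic_ai.usage` module rather than `pydantic_ai.result` / `pydantic_ai.settings`. A minimal sketch of the updated call pattern, assuming `TestModel` as a stand-in model (not part of this diff):

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel
from pydantic_ai.usage import UsageLimits

agent = Agent(TestModel())

# usage_limits is now typed as pydantic_ai.usage.UsageLimits;
# check_before_request() raises UsageLimitExceeded once the request count hits the limit.
result = agent.run_sync('hello', usage_limits=UsageLimits(request_limit=5))
print(result.usage())  # a pydantic_ai.usage.Usage instance
```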
pydantic_ai/format_as_xml.py ADDED
@@ -0,0 +1,115 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import Iterable, Iterator, Mapping
+from dataclasses import asdict, dataclass, is_dataclass
+from datetime import date
+from typing import Any
+from xml.etree import ElementTree
+
+from pydantic import BaseModel
+
+__all__ = ('format_as_xml',)
+
+
+def format_as_xml(
+    obj: Any,
+    root_tag: str = 'examples',
+    item_tag: str = 'example',
+    include_root_tag: bool = True,
+    none_str: str = 'null',
+    indent: str | None = '  ',
+) -> str:
+    """Format a Python object as XML.
+
+    This is useful since LLMs often find it easier to read semi-structured data (e.g. examples) as XML,
+    rather than JSON etc.
+
+    Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `date`, `datetime`, `Mapping`,
+    `Iterable`, `dataclass`, and `BaseModel`.
+
+    Args:
+        obj: Python Object to serialize to XML.
+        root_tag: Outer tag to wrap the XML in, use `None` to omit the outer tag.
+        item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name
+            for dataclasses and Pydantic models.
+        include_root_tag: Whether to include the root tag in the output
+            (The root tag is always included if it includes a body - e.g. when the input is a simple value).
+        none_str: String to use for `None` values.
+        indent: Indentation string to use for pretty printing.
+
+    Returns: XML representation of the object.
+
+    Example:
+        ```python {title="format_as_xml_example.py" lint="skip"}
+        from pydantic_ai.format_as_xml import format_as_xml
+
+        print(format_as_xml({'name': 'John', 'height': 6, 'weight': 200}, root_tag='user'))
+        '''
+        <user>
+          <name>John</name>
+          <height>6</height>
+          <weight>200</weight>
+        </user>
+        '''
+        ```
+    """
+    el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
+    if not include_root_tag and el.text is None:
+        join = '' if indent is None else '\n'
+        return join.join(_rootless_xml_elements(el, indent))
+    else:
+        if indent is not None:
+            ElementTree.indent(el, space=indent)
+        return ElementTree.tostring(el, encoding='unicode')
+
+
+@dataclass
+class _ToXml:
+    item_tag: str
+    none_str: str
+
+    def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
+        element = ElementTree.Element(self.item_tag if tag is None else tag)
+        if value is None:
+            element.text = self.none_str
+        elif isinstance(value, str):
+            element.text = value
+        elif isinstance(value, (bytes, bytearray)):
+            element.text = value.decode(errors='ignore')
+        elif isinstance(value, (bool, int, float)):
+            element.text = str(value)
+        elif isinstance(value, date):
+            element.text = value.isoformat()
+        elif isinstance(value, Mapping):
+            self._mapping_to_xml(element, value)  # pyright: ignore[reportUnknownArgumentType]
+        elif is_dataclass(value) and not isinstance(value, type):
+            if tag is None:
+                element = ElementTree.Element(value.__class__.__name__)
+            dc_dict = asdict(value)
+            self._mapping_to_xml(element, dc_dict)
+        elif isinstance(value, BaseModel):
+            if tag is None:
+                element = ElementTree.Element(value.__class__.__name__)
+            self._mapping_to_xml(element, value.model_dump(mode='python'))
+        elif isinstance(value, Iterable):
+            for item in value:  # pyright: ignore[reportUnknownVariableType]
+                item_el = self.to_xml(item, None)
+                element.append(item_el)
+        else:
+            raise TypeError(f'Unsupported type for XML formatting: {type(value)}')
+        return element
+
+    def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None:
+        for key, value in mapping.items():
+            if isinstance(key, int):
+                key = str(key)
+            elif not isinstance(key, str):
+                raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed')
+            element.append(self.to_xml(value, key))
+
+
+def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]:
+    for sub_element in root:
+        if indent is not None:
+            ElementTree.indent(sub_element, space=indent)
+        yield ElementTree.tostring(sub_element, encoding='unicode')
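
Beyond the docstring example, dataclass (and Pydantic-model) items take their class name as the element tag rather than `item_tag`. A small sketch of that behavior; the `Example` dataclass is illustrative, not part of the package:

```python
from dataclasses import dataclass

from pydantic_ai.format_as_xml import format_as_xml


@dataclass
class Example:
    question: str
    answer: str


print(format_as_xml([Example(question='What is 2 + 2?', answer='4')]))
# Each dataclass item is tagged with its class name, wrapped in the
# default <examples> root tag:
# <examples>
#   <Example>
#     <question>What is 2 + 2?</question>
#     <answer>4</answer>
#   </Example>
# </examples>
```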
pydantic_ai/models/gemini.py CHANGED
@@ -273,17 +273,26 @@ class GeminiAgentModel(AgentModel):
         contents: list[_GeminiContent] = []
         for m in messages:
             if isinstance(m, ModelRequest):
+                message_parts: list[_GeminiPartUnion] = []
+
                 for part in m.parts:
                     if isinstance(part, SystemPromptPart):
                         sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                     elif isinstance(part, UserPromptPart):
-                        contents.append(_content_user_prompt(part))
+                        message_parts.append(_GeminiTextPart(text=part.content))
                     elif isinstance(part, ToolReturnPart):
-                        contents.append(_content_tool_return(part))
+                        message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                     elif isinstance(part, RetryPromptPart):
-                        contents.append(_content_retry_prompt(part))
+                        if part.tool_name is None:
+                            message_parts.append(_GeminiTextPart(text=part.model_response()))
+                        else:
+                            response = {'call_error': part.model_response()}
+                            message_parts.append(_response_part_from_response(part.tool_name, response))
                     else:
                         assert_never(part)
+
+                if message_parts:
+                    contents.append(_GeminiContent(role='user', parts=message_parts))
             elif isinstance(m, ModelResponse):
                 contents.append(_content_model_response(m))
             else:
@@ -420,24 +429,6 @@ class _GeminiContent(TypedDict):
     parts: list[_GeminiPartUnion]
 
 
-def _content_user_prompt(m: UserPromptPart) -> _GeminiContent:
-    return _GeminiContent(role='user', parts=[_GeminiTextPart(text=m.content)])
-
-
-def _content_tool_return(m: ToolReturnPart) -> _GeminiContent:
-    f_response = _response_part_from_response(m.tool_name, m.model_response_object())
-    return _GeminiContent(role='user', parts=[f_response])
-
-
-def _content_retry_prompt(m: RetryPromptPart) -> _GeminiContent:
-    if m.tool_name is None:
-        part = _GeminiTextPart(text=m.model_response())
-    else:
-        response = {'call_error': m.model_response()}
-        part = _response_part_from_response(m.tool_name, response)
-    return _GeminiContent(role='user', parts=[part])
-
-
 def _content_model_response(m: ModelResponse) -> _GeminiContent:
     parts: list[_GeminiPartUnion] = []
     for item in m.parts:
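
The practical change: all parts of one `ModelRequest` are now merged into a single user-role `_GeminiContent` instead of one content per part. Illustratively (hand-written payload; the exact field shapes are assumed from the `_GeminiTextPart` and function-response helpers above):

```python
# One ModelRequest containing a user prompt part and a tool-return part
# now maps to a single user-role content with two parts:
merged_content = {
    'role': 'user',
    'parts': [
        {'text': 'What is the weather in Paris?'},
        {'function_response': {'name': 'get_weather', 'response': {'temp_c': 21}}},
    ],
}
# Previously the same request produced two separate user-role contents,
# one per part.
```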
pydantic_ai/models/vertexai.py CHANGED
@@ -178,7 +178,7 @@ def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredenti
 # pyright: reportUnknownVariableType=false
 # pyright: reportUnknownArgumentType=false
 async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
-    return await run_in_executor(google.auth.default)
+    return await run_in_executor(google.auth.default, scopes=['https://www.googleapis.com/auth/cloud-platform'])
 
 
 # default expiry is 3600 seconds
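
For reference, `google.auth.default` accepts a `scopes` keyword, so the credentials are now requested with the cloud-platform scope explicitly. The synchronous equivalent of the executor-wrapped call:

```python
import google.auth

# Application Default Credentials, scoped to the Cloud Platform API.
credentials, project_id = google.auth.default(
    scopes=['https://www.googleapis.com/auth/cloud-platform']
)
```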
pydantic_ai/result.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Callable
-from copy import copy
+from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Generic, Union, cast
@@ -11,16 +11,10 @@ import logfire_api
 from typing_extensions import TypeVar
 
 from . import _result, _utils, exceptions, messages as _messages, models
-from .settings import UsageLimits
 from .tools import AgentDeps, RunContext
+from .usage import Usage, UsageLimits
 
-__all__ = (
-    'ResultData',
-    'ResultValidatorFunc',
-    'Usage',
-    'RunResult',
-    'StreamedRunResult',
-)
+__all__ = 'ResultData', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'
 
 
 ResultData = TypeVar('ResultData', default=str)
@@ -44,55 +38,6 @@ Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
 
-@dataclass
-class Usage:
-    """LLM usage associated with a request or run.
-
-    Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.
-
-    You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
-    """
-
-    requests: int = 0
-    """Number of requests made to the LLM API."""
-    request_tokens: int | None = None
-    """Tokens used in processing requests."""
-    response_tokens: int | None = None
-    """Tokens used in generating responses."""
-    total_tokens: int | None = None
-    """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
-    details: dict[str, int] | None = None
-    """Any extra details returned by the model."""
-
-    def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
-        """Increment the usage in place.
-
-        Args:
-            incr_usage: The usage to increment by.
-            requests: The number of requests to increment by in addition to `incr_usage.requests`.
-        """
-        self.requests += requests
-        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
-            self_value = getattr(self, f)
-            other_value = getattr(incr_usage, f)
-            if self_value is not None or other_value is not None:
-                setattr(self, f, (self_value or 0) + (other_value or 0))
-
-        if incr_usage.details:
-            self.details = self.details or {}
-            for key, value in incr_usage.details.items():
-                self.details[key] = self.details.get(key, 0) + value
-
-    def __add__(self, other: Usage) -> Usage:
-        """Add two Usages together.
-
-        This is provided so it's trivial to sum usage information from multiple requests and runs.
-        """
-        new_usage = copy(self)
-        new_usage.incr(other)
-        return new_usage
-
-
 @dataclass
 class _BaseRunResult(ABC, Generic[ResultData]):
     """Base type for results.
@@ -103,25 +48,70 @@ class _BaseRunResult(ABC, Generic[ResultData]):
     _all_messages: list[_messages.ModelMessage]
     _new_message_index: int
 
-    def all_messages(self) -> list[_messages.ModelMessage]:
-        """Return the history of _messages."""
+    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return the history of _messages.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of messages.
+        """
         # this is a method to be consistent with the other methods
+        if result_tool_return_content is not None:
+            raise NotImplementedError('Setting result tool return content is not supported for this result type.')
         return self._all_messages
 
-    def all_messages_json(self) -> bytes:
-        """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes."""
-        return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages())
+    def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
+        """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the messages.
+        """
+        return _messages.ModelMessagesTypeAdapter.dump_json(
+            self.all_messages(result_tool_return_content=result_tool_return_content)
+        )
 
-    def new_messages(self) -> list[_messages.ModelMessage]:
+    def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
         """Return new messages associated with this run.
 
-        System prompts and any messages from older runs are excluded.
+        Messages from older runs are excluded.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of new messages.
         """
-        return self.all_messages()[self._new_message_index :]
+        return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :]
+
+    def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
+        """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes.
 
-    def new_messages_json(self) -> bytes:
-        """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes."""
-        return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the new messages.
+        """
+        return _messages.ModelMessagesTypeAdapter.dump_json(
+            self.new_messages(result_tool_return_content=result_tool_return_content)
+        )
 
     @abstractmethod
     def usage(self) -> Usage:
@@ -134,12 +124,45 @@ class RunResult(_BaseRunResult[ResultData]):
 
     data: ResultData
     """Data from the final response in the run."""
+    _result_tool_name: str | None
     _usage: Usage
 
     def usage(self) -> Usage:
        """Return the usage of the whole run."""
        return self._usage
 
+    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return the history of _messages.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of messages.
+        """
+        if result_tool_return_content is not None:
+            return self._set_result_tool_return(result_tool_return_content)
+        else:
+            return self._all_messages
+
+    def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
+        """Set return content for the result tool.
+
+        Useful if you want to continue the conversation and want to set the response to the result tool call.
+        """
+        if not self._result_tool_name:
+            raise ValueError('Cannot set result tool return content when the return type is `str`.')
+        messages = deepcopy(self._all_messages)
+        last_message = messages[-1]
+        for part in last_message.parts:
+            if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name:
+                part.content = return_content
+                return messages
+        raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.')
+
 
 @dataclass
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
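
A sketch of the new `result_tool_return_content` hook, assuming a structured `result_type` (so a result tool call exists to rewrite) and `TestModel` as a stand-in model:

```python
from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel


class Answer(BaseModel):
    text: str


agent = Agent(TestModel(), result_type=Answer)
result = agent.run_sync('hello')

# Deep-copies the history and rewrites the ToolReturnPart for the result
# tool; with a plain str result_type this raises ValueError instead.
history = result.all_messages(result_tool_return_content='Answer accepted.')
```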
pydantic_ai/settings.py CHANGED
@@ -1,15 +1,12 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 from httpx import Timeout
 from typing_extensions import TypedDict
 
-from .exceptions import UsageLimitExceeded
-
 if TYPE_CHECKING:
-    from .result import Usage
+    pass
 
 
 class ModelSettings(TypedDict, total=False):
@@ -82,60 +79,3 @@ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings |
         return base | overrides
     else:
         return base or overrides
-
-
-@dataclass
-class UsageLimits:
-    """Limits on model usage.
-
-    The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model.
-    Token counts are provided in responses from the model, and the token limits are checked after each response.
-
-    Each of the limits can be set to `None` to disable that limit.
-    """
-
-    request_limit: int | None = 50
-    """The maximum number of requests allowed to the model."""
-    request_tokens_limit: int | None = None
-    """The maximum number of tokens allowed in requests to the model."""
-    response_tokens_limit: int | None = None
-    """The maximum number of tokens allowed in responses from the model."""
-    total_tokens_limit: int | None = None
-    """The maximum number of tokens allowed in requests and responses combined."""
-
-    def has_token_limits(self) -> bool:
-        """Returns `True` if this instance places any limits on token counts.
-
-        If this returns `False`, the `check_tokens` method will never raise an error.
-
-        This is useful because if we have token limits, we need to check them after receiving each streamed message.
-        If there are no limits, we can skip that processing in the streaming response iterator.
-        """
-        return any(
-            limit is not None
-            for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit)
-        )
-
-    def check_before_request(self, usage: Usage) -> None:
-        """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit."""
-        request_limit = self.request_limit
-        if request_limit is not None and usage.requests >= request_limit:
-            raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}')
-
-    def check_tokens(self, usage: Usage) -> None:
-        """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits."""
-        request_tokens = usage.request_tokens or 0
-        if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit:
-            raise UsageLimitExceeded(
-                f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})'
-            )
-
-        response_tokens = usage.response_tokens or 0
-        if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit:
-            raise UsageLimitExceeded(
-                f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
-            )
-
-        total_tokens = usage.total_tokens or 0
-        if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
-            raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
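
For downstream code, the import site is the only visible change here; the class itself moved verbatim to the new `pydantic_ai/usage.py` (see below):

```python
# Before (0.0.16):
# from pydantic_ai.settings import UsageLimits

# After (0.0.17):
from pydantic_ai.usage import UsageLimits
```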
pydantic_ai/tools.py CHANGED
@@ -4,11 +4,11 @@ import dataclasses
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast
 
 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
-from typing_extensions import Concatenate, ParamSpec, TypeAlias
+from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
 
 from . import _pydantic, _utils, messages as _messages, models
 from .exceptions import ModelRetry, UnexpectedModelBehavior
@@ -30,7 +30,7 @@ __all__ = (
     'ToolDefinition',
 )
 
-AgentDeps = TypeVar('AgentDeps')
+AgentDeps = TypeVar('AgentDeps', default=None)
 """Type variable for agent dependencies."""
 
 
@@ -67,7 +67,7 @@ class RunContext(Generic[AgentDeps]):
         return dataclasses.replace(self, **kwargs)
 
 
-ToolParams = ParamSpec('ToolParams')
+ToolParams = ParamSpec('ToolParams', default=...)
 """Retrieval function param spec."""
 
 SystemPromptFunc = Union[
@@ -92,7 +92,7 @@ ToolFuncPlain = Callable[ToolParams, Any]
 Usage `ToolPlainFunc[ToolParams]`.
 """
 ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
-"""Either part_kind of tool function.
+"""Either kind of tool function.
 
 This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
 [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
@@ -134,7 +134,7 @@ A = TypeVar('A')
 class Tool(Generic[AgentDeps]):
     """A tool function for an agent."""
 
-    function: ToolFuncEither[AgentDeps, ...]
+    function: ToolFuncEither[AgentDeps]
     takes_ctx: bool
     max_retries: int | None
     name: str
@@ -150,7 +150,7 @@ class Tool(Generic[AgentDeps]):
 
     def __init__(
         self,
-        function: ToolFuncEither[AgentDeps, ...],
+        function: ToolFuncEither[AgentDeps],
         *,
         takes_ctx: bool | None = None,
         max_retries: int | None = None,
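
Switching `AgentDeps` to `typing_extensions.TypeVar` with `default=None` (and giving `ToolParams` a default) means deps-free tools no longer need an explicit type parameter. A minimal sketch; the `roll_die` function is illustrative:

```python
import random

from pydantic_ai.tools import Tool


def roll_die() -> int:
    """Roll a six-sided die."""
    return random.randint(1, 6)


# AgentDeps now defaults to None, so a plain function tool can be
# constructed without spelling out a deps type; takes_ctx is inferred.
tool = Tool(roll_die)
```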
pydantic_ai/usage.py ADDED
@@ -0,0 +1,114 @@
+from __future__ import annotations as _annotations
+
+from copy import copy
+from dataclasses import dataclass
+
+from .exceptions import UsageLimitExceeded
+
+__all__ = 'Usage', 'UsageLimits'
+
+
+@dataclass
+class Usage:
+    """LLM usage associated with a request or run.
+
+    Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.
+
+    You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
+    """
+
+    requests: int = 0
+    """Number of requests made to the LLM API."""
+    request_tokens: int | None = None
+    """Tokens used in processing requests."""
+    response_tokens: int | None = None
+    """Tokens used in generating responses."""
+    total_tokens: int | None = None
+    """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
+    details: dict[str, int] | None = None
+    """Any extra details returned by the model."""
+
+    def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
+        """Increment the usage in place.
+
+        Args:
+            incr_usage: The usage to increment by.
+            requests: The number of requests to increment by in addition to `incr_usage.requests`.
+        """
+        self.requests += requests
+        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
+            self_value = getattr(self, f)
+            other_value = getattr(incr_usage, f)
+            if self_value is not None or other_value is not None:
+                setattr(self, f, (self_value or 0) + (other_value or 0))
+
+        if incr_usage.details:
+            self.details = self.details or {}
+            for key, value in incr_usage.details.items():
+                self.details[key] = self.details.get(key, 0) + value
+
+    def __add__(self, other: Usage) -> Usage:
+        """Add two Usages together.
+
+        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        """
+        new_usage = copy(self)
+        new_usage.incr(other)
+        return new_usage
+
+
+@dataclass
+class UsageLimits:
+    """Limits on model usage.
+
+    The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model.
+    Token counts are provided in responses from the model, and the token limits are checked after each response.
+
+    Each of the limits can be set to `None` to disable that limit.
+    """
+
+    request_limit: int | None = 50
+    """The maximum number of requests allowed to the model."""
+    request_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in requests to the model."""
+    response_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in responses from the model."""
+    total_tokens_limit: int | None = None
+    """The maximum number of tokens allowed in requests and responses combined."""
+
+    def has_token_limits(self) -> bool:
+        """Returns `True` if this instance places any limits on token counts.
+
+        If this returns `False`, the `check_tokens` method will never raise an error.
+
+        This is useful because if we have token limits, we need to check them after receiving each streamed message.
+        If there are no limits, we can skip that processing in the streaming response iterator.
+        """
+        return any(
+            limit is not None
+            for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit)
+        )
+
+    def check_before_request(self, usage: Usage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit."""
+        request_limit = self.request_limit
+        if request_limit is not None and usage.requests >= request_limit:
+            raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}')
+
+    def check_tokens(self, usage: Usage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits."""
+        request_tokens = usage.request_tokens or 0
+        if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit:
+            raise UsageLimitExceeded(
+                f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})'
+            )
+
+        response_tokens = usage.response_tokens or 0
+        if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit:
+            raise UsageLimitExceeded(
+                f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
+            )
+
+        total_tokens = usage.total_tokens or 0
+        if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
+            raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
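
A quick sketch of the arithmetic these classes support:

```python
from pydantic_ai.usage import Usage, UsageLimits

a = Usage(requests=1, request_tokens=10, response_tokens=20, total_tokens=30)
b = Usage(requests=1, request_tokens=5, response_tokens=5, total_tokens=10)

combined = a + b  # __add__ copies `a`, then incr()s it with `b`
print(combined.requests, combined.total_tokens)  # 2 40

limits = UsageLimits(total_tokens_limit=50)
limits.check_tokens(combined)  # no error: 40 <= 50
# UsageLimits(total_tokens_limit=30).check_tokens(combined) would raise UsageLimitExceeded
```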
{pydantic_ai_slim-0.0.16.dist-info → pydantic_ai_slim-0.0.17.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.16
+Version: 0.0.17
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
{pydantic_ai_slim-0.0.16.dist-info → pydantic_ai_slim-0.0.17.dist-info}/RECORD RENAMED
@@ -4,23 +4,25 @@ pydantic_ai/_pydantic.py,sha256=qXi5IsyiYOHeg_-qozCdxkfeqw2z0gBTjqgywBCiJWo,8125
 pydantic_ai/_result.py,sha256=cUSugZQV0n5Z4fFHiMqua-2xs_0S6m-rr-yd6QS3nFE,10317
 pydantic_ai/_system_prompt.py,sha256=MZJWksIoS5GM3Au5lznlcQnC-h7eqwtE7oI5WFgRcOg,1090
 pydantic_ai/_utils.py,sha256=skWNgm89US_x1EpxdRy5wCkghBrm1XgxFCiEh6wAkAo,8753
-pydantic_ai/agent.py,sha256=NJTcPSlqb4Fd-x9pDPuoXGCwFGF1GHcHevutoB0Busw,52333
+pydantic_ai/agent.py,sha256=8v7gyfMKB76k04SabQNV3QtUz80fSSL2BofULWwYO-o,52514
 pydantic_ai/exceptions.py,sha256=eGDKX6bGhgVxXBzu81Sk3iiAkXr0GUtgT7bD5Rxlqpg,2028
+pydantic_ai/format_as_xml.py,sha256=Gm65687GL8Z6A_lPiJWL1O_E3ovHEBn2O1DKhn1CDnA,4472
 pydantic_ai/messages.py,sha256=ImbWY8Ft3mxInUQ08EmIWywf4nJBvTiJhmsECRYDkSQ,8968
 pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pydantic_ai/result.py,sha256=LbZVHZnJnQwgegSz5PtwS9r_ifrJnLRpsa9xjYnHg1g,15549
-pydantic_ai/settings.py,sha256=W8krcFsujjhE03rwckrz39F4Dz_9RwdBSeEF3izK0-Y,4918
-pydantic_ai/tools.py,sha256=mnh3Lvs0Ri0FkqpV1MUooExNN4epTCcBKw6DyZvNSQ8,11745
+pydantic_ai/result.py,sha256=-dpaaD24E1Ns7fxz5Gn7SKou-A8Cag4LjEyCBJbrHzY,17597
+pydantic_ai/settings.py,sha256=oTk8ZfYuUsNxpJMWLvSrO1OH_0ur7VKgDNTMQG0tPyM,1974
+pydantic_ai/tools.py,sha256=G4lwAb7QIowtSHk7w5cH8WQFIFqwMPn0J6Nqhgz7ubA,11757
+pydantic_ai/usage.py,sha256=60d9f6M7YEYuKMbqDGDogX4KsA73fhDtWyDXYXoIPaI,4948
 pydantic_ai/models/__init__.py,sha256=XHt02IDQAircb-lEkIbIcuabSAIh5_UKnz2V1xN0Glw,10926
 pydantic_ai/models/anthropic.py,sha256=EUZgmvT0jhMDbooBp_jfW0z2cM5jTMuAhVws1XKgaNs,13451
 pydantic_ai/models/function.py,sha256=i7qkS_31aHrTbYVh6OzQ7Cwucz44F5PjT2EJK3GMphw,10573
-pydantic_ai/models/gemini.py,sha256=Sr19D2hN8iEAcoLlzv5883pto90TgEr_xiGlV8hMOwA,28572
+pydantic_ai/models/gemini.py,sha256=jHBVJFLgp7kPLXYy1zYTs_-ush9qS2fkmC28hK8vkJ0,28417
 pydantic_ai/models/groq.py,sha256=ZoPkuWJrf78JPnTRfZhi7v0ETgxJKNN5dH8BLWagGGk,15770
 pydantic_ai/models/mistral.py,sha256=xGVI6-b8-9vnFickPPI2cRaHEWLc0jKKUM_vMjipf-U,25894
 pydantic_ai/models/ollama.py,sha256=ELqxhcNcnvQBnadd3gukS01zprUp6v8N_h1P5K-uf6c,4188
 pydantic_ai/models/openai.py,sha256=qFFInL3NbgfGcsAWigxMP5mscp76hC-jJimHc9woU6Y,16518
 pydantic_ai/models/test.py,sha256=u2pdZd9OLXQ_jI6CaVt96udXuIcv0Hfnfqd3pFGmeJM,16514
-pydantic_ai/models/vertexai.py,sha256=DBCBfpvpIhZaMG7cKvRl5rugCZqJqqEFm74uBc45weo,9259
-pydantic_ai_slim-0.0.16.dist-info/METADATA,sha256=4udd7j2erIuMC0ekYgmgQAqsKfhA5sLsKzTcD_QyOeo,2730
-pydantic_ai_slim-0.0.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-pydantic_ai_slim-0.0.16.dist-info/RECORD,,
+pydantic_ai/models/vertexai.py,sha256=gBlEGBIOoqGHYqu6d16VLRI0rWizx5I7P2s8IuGM1CQ,9318
+pydantic_ai_slim-0.0.17.dist-info/METADATA,sha256=hhVw5I9w5RQba3Dvsi3dKP9KUuFCfMuehHeGSQhhOmQ,2730
+pydantic_ai_slim-0.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydantic_ai_slim-0.0.17.dist-info/RECORD,,