pydantic-ai-slim 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry; it is provided for informational purposes only.

pydantic_ai/format_as_xml.py ADDED
@@ -0,0 +1,115 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import Iterable, Iterator, Mapping
+from dataclasses import asdict, dataclass, is_dataclass
+from datetime import date
+from typing import Any
+from xml.etree import ElementTree
+
+from pydantic import BaseModel
+
+__all__ = ('format_as_xml',)
+
+
+def format_as_xml(
+    obj: Any,
+    root_tag: str = 'examples',
+    item_tag: str = 'example',
+    include_root_tag: bool = True,
+    none_str: str = 'null',
+    indent: str | None = '  ',
+) -> str:
+    """Format a Python object as XML.
+
+    This is useful since LLMs often find it easier to read semi-structured data (e.g. examples) as XML,
+    rather than JSON etc.
+
+    Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `date`, `datetime`, `Mapping`,
+    `Iterable`, `dataclass`, and `BaseModel`.
+
+    Args:
+        obj: Python Object to serialize to XML.
+        root_tag: Outer tag to wrap the XML in, use `None` to omit the outer tag.
+        item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name
+            for dataclasses and Pydantic models.
+        include_root_tag: Whether to include the root tag in the output
+            (The root tag is always included if it includes a body - e.g. when the input is a simple value).
+        none_str: String to use for `None` values.
+        indent: Indentation string to use for pretty printing.
+
+    Returns: XML representation of the object.
+
+    Example:
+        ```python {title="format_as_xml_example.py" lint="skip"}
+        from pydantic_ai.format_as_xml import format_as_xml
+
+        print(format_as_xml({'name': 'John', 'height': 6, 'weight': 200}, root_tag='user'))
+        '''
+        <user>
+          <name>John</name>
+          <height>6</height>
+          <weight>200</weight>
+        </user>
+        '''
+        ```
+    """
+    el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
+    if not include_root_tag and el.text is None:
+        join = '' if indent is None else '\n'
+        return join.join(_rootless_xml_elements(el, indent))
+    else:
+        if indent is not None:
+            ElementTree.indent(el, space=indent)
+        return ElementTree.tostring(el, encoding='unicode')
+
+
+@dataclass
+class _ToXml:
+    item_tag: str
+    none_str: str
+
+    def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
+        element = ElementTree.Element(self.item_tag if tag is None else tag)
+        if value is None:
+            element.text = self.none_str
+        elif isinstance(value, str):
+            element.text = value
+        elif isinstance(value, (bytes, bytearray)):
+            element.text = value.decode(errors='ignore')
+        elif isinstance(value, (bool, int, float)):
+            element.text = str(value)
+        elif isinstance(value, date):
+            element.text = value.isoformat()
+        elif isinstance(value, Mapping):
+            self._mapping_to_xml(element, value)  # pyright: ignore[reportUnknownArgumentType]
+        elif is_dataclass(value) and not isinstance(value, type):
+            if tag is None:
+                element = ElementTree.Element(value.__class__.__name__)
+            dc_dict = asdict(value)
+            self._mapping_to_xml(element, dc_dict)
+        elif isinstance(value, BaseModel):
+            if tag is None:
+                element = ElementTree.Element(value.__class__.__name__)
+            self._mapping_to_xml(element, value.model_dump(mode='python'))
+        elif isinstance(value, Iterable):
+            for item in value:  # pyright: ignore[reportUnknownVariableType]
+                item_el = self.to_xml(item, None)
+                element.append(item_el)
+        else:
+            raise TypeError(f'Unsupported type for XML formatting: {type(value)}')
+        return element
+
+    def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None:
+        for key, value in mapping.items():
+            if isinstance(key, int):
+                key = str(key)
+            elif not isinstance(key, str):
+                raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed')
+            element.append(self.to_xml(value, key))
+
+
+def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]:
+    for sub_element in root:
+        if indent is not None:
+            ElementTree.indent(sub_element, space=indent)
+        yield ElementTree.tostring(sub_element, encoding='unicode')
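To make the new module concrete, here is a small usage sketch; the `User` dataclass is a hypothetical example type, not part of the package. Per `to_xml` above, dataclass items take their class name as the tag:

```python
from dataclasses import dataclass

from pydantic_ai.format_as_xml import format_as_xml


@dataclass
class User:  # hypothetical example type, not part of pydantic-ai
    name: str
    country: str


print(format_as_xml([User(name='Anne', country='FR'), User(name='Ben', country='DE')]))
'''
<examples>
  <User>
    <name>Anne</name>
    <country>FR</country>
  </User>
  <User>
    <name>Ben</name>
    <country>DE</country>
  </User>
</examples>
'''
```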
pydantic_ai/messages.py CHANGED
@@ -21,6 +21,12 @@ class SystemPromptPart:
     content: str
     """The content of the prompt."""
 
+    dynamic_ref: str | None = None
+    """The ref of the dynamic system prompt function that generated this part.
+
+    Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information.
+    """
+
     part_kind: Literal['system-prompt'] = 'system-prompt'
     """Part type identifier, this is available on all parts as a discriminator."""
 
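For orientation: `dynamic_ref` records which dynamic system prompt function produced the part, so the prompt can be re-evaluated when message history is reused. A minimal sketch, assuming the `dynamic` flag on the decorator that ships alongside this field (see the `Agent.system_prompt` docs):

```python
from pydantic_ai import Agent

agent = Agent('test')


@agent.system_prompt(dynamic=True)  # `dynamic=True` is an assumption here, per Agent.system_prompt
def user_name() -> str:
    return 'The user is Anne.'


result = agent.run_sync('Hello')
# the resulting SystemPromptPart carries a dynamic_ref pointing at `user_name`
```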
pydantic_ai/models/__init__.py CHANGED
@@ -48,13 +48,12 @@ KnownModelName = Literal[
     'groq:mixtral-8x7b-32768',
     'groq:gemma2-9b-it',
     'groq:gemma-7b-it',
-    'gemini-1.5-flash',
-    'gemini-1.5-pro',
-    'gemini-2.0-flash-exp',
-    'vertexai:gemini-1.5-flash',
-    'vertexai:gemini-1.5-pro',
-    # since mistral models are supported by other providers (e.g. ollama), and some of their models (e.g. "codestral")
-    # don't start with "mistral", we add the "mistral:" prefix to all to be explicit
+    'google-gla:gemini-1.5-flash',
+    'google-gla:gemini-1.5-pro',
+    'google-gla:gemini-2.0-flash-exp',
+    'google-vertex:gemini-1.5-flash',
+    'google-vertex:gemini-1.5-pro',
+    'google-vertex:gemini-2.0-flash-exp',
     'mistral:mistral-small-latest',
     'mistral:mistral-large-latest',
     'mistral:codestral-latest',
@@ -76,9 +75,9 @@ KnownModelName = Literal[
     'ollama:qwen2',
     'ollama:qwen2.5',
     'ollama:starcoder2',
-    'claude-3-5-haiku-latest',
-    'claude-3-5-sonnet-latest',
-    'claude-3-opus-latest',
+    'anthropic:claude-3-5-haiku-latest',
+    'anthropic:claude-3-5-sonnet-latest',
+    'anthropic:claude-3-opus-latest',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -274,6 +273,15 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .openai import OpenAIModel
 
         return OpenAIModel(model[7:])
+    elif model.startswith(('gpt', 'o1')):
+        from .openai import OpenAIModel
+
+        return OpenAIModel(model)
+    elif model.startswith('google-gla'):
+        from .gemini import GeminiModel
+
+        return GeminiModel(model[11:])  # pyright: ignore[reportArgumentType]
+    # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
     elif model.startswith('gemini'):
         from .gemini import GeminiModel
 
@@ -283,6 +291,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .groq import GroqModel
 
         return GroqModel(model[5:])  # pyright: ignore[reportArgumentType]
+    elif model.startswith('google-vertex'):
+        from .vertexai import VertexAIModel
+
+        return VertexAIModel(model[14:])  # pyright: ignore[reportArgumentType]
+    # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
     elif model.startswith('vertexai:'):
         from .vertexai import VertexAIModel
 
@@ -295,6 +308,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .ollama import OllamaModel
 
         return OllamaModel(model[7:])
+    elif model.startswith('anthropic'):
+        from .anthropic import AnthropicModel
+
+        return AnthropicModel(model[10:])
+    # backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
     elif model.startswith('claude'):
         from .anthropic import AnthropicModel
 
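The new provider prefixes are stripped with fixed-width slices (`model[11:]`, `model[14:]`, `model[10:]`); a quick sanity check that each offset is `len(prefix) + 1` for the trailing colon:

```python
# verify the slice offsets used in infer_model above
for prefix, offset in [('google-gla', 11), ('google-vertex', 14), ('anthropic', 10)]:
    assert len(prefix) + 1 == offset
    assert f'{prefix}:some-model'[offset:] == 'some-model'
```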
pydantic_ai/models/anthropic.py CHANGED
@@ -136,7 +136,7 @@ class AnthropicModel(Model):
         )
 
     def name(self) -> str:
-        return self.model_name
+        return f'anthropic:{self.model_name}'
 
     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolParam:
pydantic_ai/models/gemini.py CHANGED
@@ -111,7 +111,7 @@ class GeminiModel(Model):
         )
 
     def name(self) -> str:
-        return self.model_name
+        return f'google-gla:{self.model_name}'
 
 
 class AuthProtocol(Protocol):
@@ -273,17 +273,26 @@ class GeminiAgentModel(AgentModel):
         contents: list[_GeminiContent] = []
         for m in messages:
             if isinstance(m, ModelRequest):
+                message_parts: list[_GeminiPartUnion] = []
+
                 for part in m.parts:
                     if isinstance(part, SystemPromptPart):
                         sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                     elif isinstance(part, UserPromptPart):
-                        contents.append(_content_user_prompt(part))
+                        message_parts.append(_GeminiTextPart(text=part.content))
                     elif isinstance(part, ToolReturnPart):
-                        contents.append(_content_tool_return(part))
+                        message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                     elif isinstance(part, RetryPromptPart):
-                        contents.append(_content_retry_prompt(part))
+                        if part.tool_name is None:
+                            message_parts.append(_GeminiTextPart(text=part.model_response()))
+                        else:
+                            response = {'call_error': part.model_response()}
+                            message_parts.append(_response_part_from_response(part.tool_name, response))
                     else:
                         assert_never(part)
+
+                if message_parts:
+                    contents.append(_GeminiContent(role='user', parts=message_parts))
             elif isinstance(m, ModelResponse):
                 contents.append(_content_model_response(m))
             else:
@@ -420,24 +429,6 @@ class _GeminiContent(TypedDict):
     parts: list[_GeminiPartUnion]
 
 
-def _content_user_prompt(m: UserPromptPart) -> _GeminiContent:
-    return _GeminiContent(role='user', parts=[_GeminiTextPart(text=m.content)])
-
-
-def _content_tool_return(m: ToolReturnPart) -> _GeminiContent:
-    f_response = _response_part_from_response(m.tool_name, m.model_response_object())
-    return _GeminiContent(role='user', parts=[f_response])
-
-
-def _content_retry_prompt(m: RetryPromptPart) -> _GeminiContent:
-    if m.tool_name is None:
-        part = _GeminiTextPart(text=m.model_response())
-    else:
-        response = {'call_error': m.model_response()}
-        part = _response_part_from_response(m.tool_name, response)
-    return _GeminiContent(role='user', parts=[part])
-
-
 def _content_model_response(m: ModelResponse) -> _GeminiContent:
     parts: list[_GeminiPartUnion] = []
     for item in m.parts:
@@ -702,7 +693,7 @@ class _GeminiJsonSchema:
 
     def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         schema.pop('title', None)
-        default = schema.pop('default', _utils.UNSET)
+        schema.pop('default', None)
         if ref := schema.pop('$ref', None):
             # noinspection PyTypeChecker
             key = re.sub(r'^#/\$defs/', '', ref)
@@ -717,11 +708,12 @@ class _GeminiJsonSchema:
         if any_of := schema.get('anyOf'):
             for item_schema in any_of:
                 self._simplify(item_schema, refs_stack)
-            if len(any_of) == 2 and {'type': 'null'} in any_of and default is None:
+            if len(any_of) == 2 and {'type': 'null'} in any_of:
                 for item_schema in any_of:
                     if item_schema != {'type': 'null'}:
                         schema.clear()
                         schema.update(item_schema)
+                        schema['nullable'] = True
                 return
 
         type_ = schema.get('type')
@@ -730,6 +722,12 @@ class _GeminiJsonSchema:
             self._object(schema, refs_stack)
         elif type_ == 'array':
             return self._array(schema, refs_stack)
+        elif type_ == 'string' and (fmt := schema.pop('format', None)):
+            description = schema.get('description')
+            if description:
+                schema['description'] = f'{description} (format: {fmt})'
+            else:
+                schema['description'] = f'Format: {fmt}'
 
     def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         ad_props = schema.pop('additionalProperties', None)
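To illustrate the schema change: an optional field that Pydantic emits as `anyOf: [X, {'type': 'null'}]` is now always collapsed to `X` plus `nullable: true` (previously only when the default was `None`), and a `format` key on strings is folded into the description. A standalone sketch of the `anyOf` transformation, not the private `_GeminiJsonSchema` class itself:

```python
from typing import Any


def collapse_nullable(schema: dict[str, Any]) -> dict[str, Any]:
    """Mirror the anyOf handling in the hunk above for a single schema node."""
    any_of = schema.get('anyOf')
    if any_of and len(any_of) == 2 and {'type': 'null'} in any_of:
        non_null = next(s for s in any_of if s != {'type': 'null'})
        return {**non_null, 'nullable': True}
    return schema


print(collapse_nullable({'anyOf': [{'type': 'integer'}, {'type': 'null'}]}))
#> {'type': 'integer', 'nullable': True}
```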
pydantic_ai/models/vertexai.py CHANGED
@@ -164,7 +164,7 @@ class VertexAIModel(Model):
         return url, auth
 
     def name(self) -> str:
-        return f'vertexai:{self.model_name}'
+        return f'google-vertex:{self.model_name}'
 
 
 # pyright: reportUnknownMemberType=false
@@ -178,7 +178,7 @@ def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredenti
 # pyright: reportUnknownVariableType=false
 # pyright: reportUnknownArgumentType=false
 async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
-    return await run_in_executor(google.auth.default)
+    return await run_in_executor(google.auth.default, scopes=['https://www.googleapis.com/auth/cloud-platform'])
 
 
 # default expiry is 3600 seconds
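The auth change requests an explicit scope, since unscoped application-default credentials are rejected in some environments. The synchronous call now wrapped in the executor is equivalent to the following (it requires configured application-default credentials to actually run):

```python
import google.auth

# explicit cloud-platform scope, matching the change above
credentials, project_id = google.auth.default(
    scopes=['https://www.googleapis.com/auth/cloud-platform']
)
```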
pydantic_ai/result.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Callable
-from copy import copy
+from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Generic, Union, cast
@@ -11,16 +11,10 @@ import logfire_api
 from typing_extensions import TypeVar
 
 from . import _result, _utils, exceptions, messages as _messages, models
-from .settings import UsageLimits
 from .tools import AgentDeps, RunContext
+from .usage import Usage, UsageLimits
 
-__all__ = (
-    'ResultData',
-    'ResultValidatorFunc',
-    'Usage',
-    'RunResult',
-    'StreamedRunResult',
-)
+__all__ = 'ResultData', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'
 
 
 ResultData = TypeVar('ResultData', default=str)
@@ -44,55 +38,6 @@ Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
 
-@dataclass
-class Usage:
-    """LLM usage associated with a request or run.
-
-    Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.
-
-    You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
-    """
-
-    requests: int = 0
-    """Number of requests made to the LLM API."""
-    request_tokens: int | None = None
-    """Tokens used in processing requests."""
-    response_tokens: int | None = None
-    """Tokens used in generating responses."""
-    total_tokens: int | None = None
-    """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
-    details: dict[str, int] | None = None
-    """Any extra details returned by the model."""
-
-    def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
-        """Increment the usage in place.
-
-        Args:
-            incr_usage: The usage to increment by.
-            requests: The number of requests to increment by in addition to `incr_usage.requests`.
-        """
-        self.requests += requests
-        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
-            self_value = getattr(self, f)
-            other_value = getattr(incr_usage, f)
-            if self_value is not None or other_value is not None:
-                setattr(self, f, (self_value or 0) + (other_value or 0))
-
-        if incr_usage.details:
-            self.details = self.details or {}
-            for key, value in incr_usage.details.items():
-                self.details[key] = self.details.get(key, 0) + value
-
-    def __add__(self, other: Usage) -> Usage:
-        """Add two Usages together.
-
-        This is provided so it's trivial to sum usage information from multiple requests and runs.
-        """
-        new_usage = copy(self)
-        new_usage.incr(other)
-        return new_usage
-
-
 @dataclass
 class _BaseRunResult(ABC, Generic[ResultData]):
     """Base type for results.
@@ -103,25 +48,70 @@ class _BaseRunResult(ABC, Generic[ResultData]):
     _all_messages: list[_messages.ModelMessage]
     _new_message_index: int
 
-    def all_messages(self) -> list[_messages.ModelMessage]:
-        """Return the history of _messages."""
+    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return the history of _messages.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of messages.
+        """
         # this is a method to be consistent with the other methods
+        if result_tool_return_content is not None:
+            raise NotImplementedError('Setting result tool return content is not supported for this result type.')
         return self._all_messages
 
-    def all_messages_json(self) -> bytes:
-        """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes."""
-        return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages())
+    def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
+        """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the messages.
+        """
+        return _messages.ModelMessagesTypeAdapter.dump_json(
+            self.all_messages(result_tool_return_content=result_tool_return_content)
+        )
 
-    def new_messages(self) -> list[_messages.ModelMessage]:
+    def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
         """Return new messages associated with this run.
 
-        System prompts and any messages from older runs are excluded.
+        Messages from older runs are excluded.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of new messages.
         """
-        return self.all_messages()[self._new_message_index :]
+        return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :]
+
+    def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
+        """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes.
 
-    def new_messages_json(self) -> bytes:
-        """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes."""
-        return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the new messages.
+        """
+        return _messages.ModelMessagesTypeAdapter.dump_json(
+            self.new_messages(result_tool_return_content=result_tool_return_content)
+        )
 
     @abstractmethod
     def usage(self) -> Usage:
@@ -134,12 +124,45 @@ class RunResult(_BaseRunResult[ResultData]):
 
     data: ResultData
     """Data from the final response in the run."""
+    _result_tool_name: str | None
     _usage: Usage
 
     def usage(self) -> Usage:
        """Return the usage of the whole run."""
        return self._usage
 
+    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return the history of _messages.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of messages.
+        """
+        if result_tool_return_content is not None:
+            return self._set_result_tool_return(result_tool_return_content)
+        else:
+            return self._all_messages
+
+    def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
+        """Set return content for the result tool.
+
+        Useful if you want to continue the conversation and want to set the response to the result tool call.
+        """
+        if not self._result_tool_name:
+            raise ValueError('Cannot set result tool return content when the return type is `str`.')
+        messages = deepcopy(self._all_messages)
+        last_message = messages[-1]
+        for part in last_message.parts:
+            if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name:
+                part.content = return_content
+                return messages
+        raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.')
+
 
 @dataclass
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
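A sketch of what the new `result_tool_return_content` parameter enables when continuing a conversation; it requires a structured `result_type`, since with a plain `str` result there is no result tool and `_set_result_tool_return` raises `ValueError`:

```python
from pydantic import BaseModel

from pydantic_ai import Agent


class Answer(BaseModel):
    text: str


agent = Agent('test', result_type=Answer)
result = agent.run_sync('What is the capital of France?')

# rewrite the result tool's return content before reusing the history
history = result.all_messages(result_tool_return_content='Noted, thanks!')
follow_up = agent.run_sync('And of Germany?', message_history=history)
```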
pydantic_ai/settings.py CHANGED
@@ -1,15 +1,12 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 from httpx import Timeout
 from typing_extensions import TypedDict
 
-from .exceptions import UsageLimitExceeded
-
 if TYPE_CHECKING:
-    from .result import Usage
+    pass
 
 
 class ModelSettings(TypedDict, total=False):
@@ -82,60 +79,3 @@ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings |
         return base | overrides
     else:
         return base or overrides
-
-
-@dataclass
-class UsageLimits:
-    """Limits on model usage.
-
-    The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model.
-    Token counts are provided in responses from the model, and the token limits are checked after each response.
-
-    Each of the limits can be set to `None` to disable that limit.
-    """
-
-    request_limit: int | None = 50
-    """The maximum number of requests allowed to the model."""
-    request_tokens_limit: int | None = None
-    """The maximum number of tokens allowed in requests to the model."""
-    response_tokens_limit: int | None = None
-    """The maximum number of tokens allowed in responses from the model."""
-    total_tokens_limit: int | None = None
-    """The maximum number of tokens allowed in requests and responses combined."""
-
-    def has_token_limits(self) -> bool:
-        """Returns `True` if this instance places any limits on token counts.
-
-        If this returns `False`, the `check_tokens` method will never raise an error.
-
-        This is useful because if we have token limits, we need to check them after receiving each streamed message.
-        If there are no limits, we can skip that processing in the streaming response iterator.
-        """
-        return any(
-            limit is not None
-            for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit)
-        )
-
-    def check_before_request(self, usage: Usage) -> None:
-        """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit."""
-        request_limit = self.request_limit
-        if request_limit is not None and usage.requests >= request_limit:
-            raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}')
-
-    def check_tokens(self, usage: Usage) -> None:
-        """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits."""
-        request_tokens = usage.request_tokens or 0
-        if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit:
-            raise UsageLimitExceeded(
-                f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})'
-            )
-
-        response_tokens = usage.response_tokens or 0
-        if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit:
-            raise UsageLimitExceeded(
-                f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
-            )
-
-        total_tokens = usage.total_tokens or 0
-        if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
-            raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
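Taken together with the `result.py` changes above (`from .usage import Usage, UsageLimits`), both classes now live in a dedicated `pydantic_ai.usage` module; existing imports migrate like this:

```python
# before (0.0.16):
#   from pydantic_ai.result import Usage
#   from pydantic_ai.settings import UsageLimits

# after (0.0.18):
from pydantic_ai.usage import Usage, UsageLimits

limits = UsageLimits(request_limit=5, total_tokens_limit=10_000)
```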
pydantic_ai/tools.py CHANGED
@@ -4,11 +4,11 @@ import dataclasses
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast
 
 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
-from typing_extensions import Concatenate, ParamSpec, TypeAlias
+from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
 
 from . import _pydantic, _utils, messages as _messages, models
 from .exceptions import ModelRetry, UnexpectedModelBehavior
@@ -30,7 +30,7 @@ __all__ = (
     'ToolDefinition',
 )
 
-AgentDeps = TypeVar('AgentDeps')
+AgentDeps = TypeVar('AgentDeps', default=None)
 """Type variable for agent dependencies."""
 
 
@@ -67,7 +67,7 @@ class RunContext(Generic[AgentDeps]):
         return dataclasses.replace(self, **kwargs)
 
 
-ToolParams = ParamSpec('ToolParams')
+ToolParams = ParamSpec('ToolParams', default=...)
 """Retrieval function param spec."""
 
 SystemPromptFunc = Union[
@@ -92,7 +92,7 @@ ToolFuncPlain = Callable[ToolParams, Any]
 Usage `ToolPlainFunc[ToolParams]`.
 """
 ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
-"""Either part_kind of tool function.
+"""Either kind of tool function.
 
 This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
 [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
@@ -134,7 +134,7 @@ A = TypeVar('A')
 class Tool(Generic[AgentDeps]):
     """A tool function for an agent."""
 
-    function: ToolFuncEither[AgentDeps, ...]
+    function: ToolFuncEither[AgentDeps]
     takes_ctx: bool
     max_retries: int | None
     name: str
@@ -150,7 +150,7 @@ class Tool(Generic[AgentDeps]):
 
     def __init__(
         self,
-        function: ToolFuncEither[AgentDeps, ...],
+        function: ToolFuncEither[AgentDeps],
         *,
         takes_ctx: bool | None = None,
         max_retries: int | None = None,
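With `AgentDeps` now defaulting to `None` (via `typing_extensions.TypeVar`, since PEP 696 type-parameter defaults aren't available on `typing.TypeVar` in the Pythons this package supports), a bare `Tool` annotation type-checks without an explicit deps parameter. A small sketch:

```python
import random

from pydantic_ai.tools import Tool


def roll_die() -> str:
    """A plain tool function that takes no RunContext."""
    return str(random.randint(1, 6))


# Tool is now equivalent to Tool[None] thanks to the TypeVar default;
# takes_ctx is inferred from the signature per Tool.__init__ above
die_tool: Tool = Tool(roll_die)
```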