pydantic-ai-slim 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_griffe.py +13 -1
- pydantic_ai/_system_prompt.py +1 -0
- pydantic_ai/agent.py +272 -68
- pydantic_ai/format_as_xml.py +115 -0
- pydantic_ai/messages.py +6 -0
- pydantic_ai/models/__init__.py +28 -10
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/gemini.py +22 -24
- pydantic_ai/models/vertexai.py +2 -2
- pydantic_ai/result.py +92 -69
- pydantic_ai/settings.py +1 -61
- pydantic_ai/tools.py +7 -7
- pydantic_ai/usage.py +114 -0
- {pydantic_ai_slim-0.0.16.dist-info → pydantic_ai_slim-0.0.18.dist-info}/METADATA +1 -1
- pydantic_ai_slim-0.0.18.dist-info/RECORD +28 -0
- pydantic_ai_slim-0.0.16.dist-info/RECORD +0 -26
- {pydantic_ai_slim-0.0.16.dist-info → pydantic_ai_slim-0.0.18.dist-info}/WHEEL +0 -0
pydantic_ai/format_as_xml.py
ADDED

@@ -0,0 +1,115 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import Iterable, Iterator, Mapping
+from dataclasses import asdict, dataclass, is_dataclass
+from datetime import date
+from typing import Any
+from xml.etree import ElementTree
+
+from pydantic import BaseModel
+
+__all__ = ('format_as_xml',)
+
+
+def format_as_xml(
+    obj: Any,
+    root_tag: str = 'examples',
+    item_tag: str = 'example',
+    include_root_tag: bool = True,
+    none_str: str = 'null',
+    indent: str | None = '  ',
+) -> str:
+    """Format a Python object as XML.
+
+    This is useful since LLMs often find it easier to read semi-structured data (e.g. examples) as XML,
+    rather than JSON etc.
+
+    Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `date`, `datetime`, `Mapping`,
+    `Iterable`, `dataclass`, and `BaseModel`.
+
+    Args:
+        obj: Python Object to serialize to XML.
+        root_tag: Outer tag to wrap the XML in, use `None` to omit the outer tag.
+        item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name
+            for dataclasses and Pydantic models.
+        include_root_tag: Whether to include the root tag in the output
+            (The root tag is always included if it includes a body - e.g. when the input is a simple value).
+        none_str: String to use for `None` values.
+        indent: Indentation string to use for pretty printing.
+
+    Returns: XML representation of the object.
+
+    Example:
+    ```python {title="format_as_xml_example.py" lint="skip"}
+    from pydantic_ai.format_as_xml import format_as_xml
+
+    print(format_as_xml({'name': 'John', 'height': 6, 'weight': 200}, root_tag='user'))
+    '''
+    <user>
+      <name>John</name>
+      <height>6</height>
+      <weight>200</weight>
+    </user>
+    '''
+    ```
+    """
+    el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
+    if not include_root_tag and el.text is None:
+        join = '' if indent is None else '\n'
+        return join.join(_rootless_xml_elements(el, indent))
+    else:
+        if indent is not None:
+            ElementTree.indent(el, space=indent)
+        return ElementTree.tostring(el, encoding='unicode')
+
+
+@dataclass
+class _ToXml:
+    item_tag: str
+    none_str: str
+
+    def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
+        element = ElementTree.Element(self.item_tag if tag is None else tag)
+        if value is None:
+            element.text = self.none_str
+        elif isinstance(value, str):
+            element.text = value
+        elif isinstance(value, (bytes, bytearray)):
+            element.text = value.decode(errors='ignore')
+        elif isinstance(value, (bool, int, float)):
+            element.text = str(value)
+        elif isinstance(value, date):
+            element.text = value.isoformat()
+        elif isinstance(value, Mapping):
+            self._mapping_to_xml(element, value)  # pyright: ignore[reportUnknownArgumentType]
+        elif is_dataclass(value) and not isinstance(value, type):
+            if tag is None:
+                element = ElementTree.Element(value.__class__.__name__)
+            dc_dict = asdict(value)
+            self._mapping_to_xml(element, dc_dict)
+        elif isinstance(value, BaseModel):
+            if tag is None:
+                element = ElementTree.Element(value.__class__.__name__)
+            self._mapping_to_xml(element, value.model_dump(mode='python'))
+        elif isinstance(value, Iterable):
+            for item in value:  # pyright: ignore[reportUnknownVariableType]
+                item_el = self.to_xml(item, None)
+                element.append(item_el)
+        else:
+            raise TypeError(f'Unsupported type for XML formatting: {type(value)}')
+        return element
+
+    def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None:
+        for key, value in mapping.items():
+            if isinstance(key, int):
+                key = str(key)
+            elif not isinstance(key, str):
+                raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed')
+            element.append(self.to_xml(value, key))
+
+
+def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]:
+    for sub_element in root:
+        if indent is not None:
+            ElementTree.indent(sub_element, space=indent)
+        yield ElementTree.tostring(sub_element, encoding='unicode')
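For reference, a minimal sketch of the new helper applied to a dataclass, with behaviour inferred from the implementation above: dataclass items take their class name as the tag, and `None` becomes `none_str`.

```python
from dataclasses import dataclass

from pydantic_ai.format_as_xml import format_as_xml


@dataclass
class User:
    name: str
    country: str | None


# with include_root_tag=False the outer <examples> wrapper is dropped
print(format_as_xml([User('John', None)], include_root_tag=False))
'''
<User>
  <name>John</name>
  <country>null</country>
</User>
'''
```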
pydantic_ai/messages.py
CHANGED
@@ -21,6 +21,12 @@ class SystemPromptPart:
     content: str
     """The content of the prompt."""

+    dynamic_ref: str | None = None
+    """The ref of the dynamic system prompt function that generated this part.
+
+    Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information.
+    """
+
     part_kind: Literal['system-prompt'] = 'system-prompt'
     """Part type identifier, this is available on all parts as a discriminator."""

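A sketch of the intended usage, assuming the `dynamic` flag on `Agent.system_prompt` that the docstring above points to; the decorated function's ref is what ends up in `dynamic_ref`, so the prompt can be re-evaluated when message history is reused.

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test', deps_type=str)


# hypothetical illustration: the SystemPromptPart generated from this function
# should carry dynamic_ref pointing at this function
@agent.system_prompt(dynamic=True)
def add_user_name(ctx: RunContext[str]) -> str:
    return f"The user's name is {ctx.deps}."


result = agent.run_sync('What is my name?', deps='Frank')
```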
pydantic_ai/models/__init__.py
CHANGED
@@ -48,13 +48,12 @@ KnownModelName = Literal[
     'groq:mixtral-8x7b-32768',
     'groq:gemma2-9b-it',
     'groq:gemma-7b-it',
-    'gemini-1.5-flash',
-    'gemini-1.5-pro',
-    'gemini-2.0-flash-exp',
-    'vertexai:gemini-1.5-flash',
-    'vertexai:gemini-1.5-pro',
-
-    # don't start with "mistral", we add the "mistral:" prefix to all to be explicit
+    'google-gla:gemini-1.5-flash',
+    'google-gla:gemini-1.5-pro',
+    'google-gla:gemini-2.0-flash-exp',
+    'google-vertex:gemini-1.5-flash',
+    'google-vertex:gemini-1.5-pro',
+    'google-vertex:gemini-2.0-flash-exp',
     'mistral:mistral-small-latest',
     'mistral:mistral-large-latest',
     'mistral:codestral-latest',
@@ -76,9 +75,9 @@ KnownModelName = Literal[
     'ollama:qwen2',
     'ollama:qwen2.5',
     'ollama:starcoder2',
-    'claude-3-5-haiku-latest',
-    'claude-3-5-sonnet-latest',
-    'claude-3-opus-latest',
+    'anthropic:claude-3-5-haiku-latest',
+    'anthropic:claude-3-5-sonnet-latest',
+    'anthropic:claude-3-opus-latest',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -274,6 +273,15 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .openai import OpenAIModel

         return OpenAIModel(model[7:])
+    elif model.startswith(('gpt', 'o1')):
+        from .openai import OpenAIModel
+
+        return OpenAIModel(model)
+    elif model.startswith('google-gla'):
+        from .gemini import GeminiModel
+
+        return GeminiModel(model[11:])  # pyright: ignore[reportArgumentType]
+    # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
     elif model.startswith('gemini'):
         from .gemini import GeminiModel

@@ -283,6 +291,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .groq import GroqModel

         return GroqModel(model[5:])  # pyright: ignore[reportArgumentType]
+    elif model.startswith('google-vertex'):
+        from .vertexai import VertexAIModel
+
+        return VertexAIModel(model[14:])  # pyright: ignore[reportArgumentType]
+    # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
     elif model.startswith('vertexai:'):
         from .vertexai import VertexAIModel

@@ -295,6 +308,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .ollama import OllamaModel

         return OllamaModel(model[7:])
+    elif model.startswith('anthropic'):
+        from .anthropic import AnthropicModel
+
+        return AnthropicModel(model[10:])
+    # backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
     elif model.startswith('claude'):
         from .anthropic import AnthropicModel

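A sketch of how the renamed model strings resolve after this change, following the branches above; actually constructing the models still requires the relevant optional dependencies and API keys at runtime.

```python
from pydantic_ai.models import infer_model

# new canonical, provider-prefixed names
gemini = infer_model('google-gla:gemini-1.5-flash')         # GeminiModel
claude = infer_model('anthropic:claude-3-5-sonnet-latest')  # AnthropicModel

# old un-prefixed names still work via the backwards-compatibility branches
legacy_gemini = infer_model('gemini-1.5-flash')             # same as above
legacy_claude = infer_model('claude-3-5-sonnet-latest')     # same as above

# bare OpenAI names are now also accepted
gpt = infer_model('gpt-4o')                                 # OpenAIModel
```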
pydantic_ai/models/anthropic.py
CHANGED
pydantic_ai/models/gemini.py
CHANGED
@@ -111,7 +111,7 @@ class GeminiModel(Model):
         )

     def name(self) -> str:
-        return self.model_name
+        return f'google-gla:{self.model_name}'


 class AuthProtocol(Protocol):
@@ -273,17 +273,26 @@ class GeminiAgentModel(AgentModel):
         contents: list[_GeminiContent] = []
         for m in messages:
             if isinstance(m, ModelRequest):
+                message_parts: list[_GeminiPartUnion] = []
+
                 for part in m.parts:
                     if isinstance(part, SystemPromptPart):
                         sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                     elif isinstance(part, UserPromptPart):
-                        contents.append(_content_user_prompt(part))
+                        message_parts.append(_GeminiTextPart(text=part.content))
                     elif isinstance(part, ToolReturnPart):
-                        contents.append(_content_tool_return(part))
+                        message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                     elif isinstance(part, RetryPromptPart):
-                        contents.append(_content_retry_prompt(part))
+                        if part.tool_name is None:
+                            message_parts.append(_GeminiTextPart(text=part.model_response()))
+                        else:
+                            response = {'call_error': part.model_response()}
+                            message_parts.append(_response_part_from_response(part.tool_name, response))
                     else:
                         assert_never(part)
+
+                if message_parts:
+                    contents.append(_GeminiContent(role='user', parts=message_parts))
             elif isinstance(m, ModelResponse):
                 contents.append(_content_model_response(m))
             else:
@@ -420,24 +429,6 @@ class _GeminiContent(TypedDict):
     parts: list[_GeminiPartUnion]


-def _content_user_prompt(m: UserPromptPart) -> _GeminiContent:
-    return _GeminiContent(role='user', parts=[_GeminiTextPart(text=m.content)])
-
-
-def _content_tool_return(m: ToolReturnPart) -> _GeminiContent:
-    f_response = _response_part_from_response(m.tool_name, m.model_response_object())
-    return _GeminiContent(role='user', parts=[f_response])
-
-
-def _content_retry_prompt(m: RetryPromptPart) -> _GeminiContent:
-    if m.tool_name is None:
-        part = _GeminiTextPart(text=m.model_response())
-    else:
-        response = {'call_error': m.model_response()}
-        part = _response_part_from_response(m.tool_name, response)
-    return _GeminiContent(role='user', parts=[part])
-
-
 def _content_model_response(m: ModelResponse) -> _GeminiContent:
     parts: list[_GeminiPartUnion] = []
     for item in m.parts:
@@ -702,7 +693,7 @@ class _GeminiJsonSchema:

     def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         schema.pop('title', None)
-
+        schema.pop('default', None)
         if ref := schema.pop('$ref', None):
             # noinspection PyTypeChecker
             key = re.sub(r'^#/\$defs/', '', ref)
@@ -717,11 +708,12 @@ class _GeminiJsonSchema:
         if any_of := schema.get('anyOf'):
             for item_schema in any_of:
                 self._simplify(item_schema, refs_stack)
-            if len(any_of) == 2 and {'type': 'null'} in any_of
+            if len(any_of) == 2 and {'type': 'null'} in any_of:
                 for item_schema in any_of:
                     if item_schema != {'type': 'null'}:
                         schema.clear()
                         schema.update(item_schema)
+                        schema['nullable'] = True
                         return

         type_ = schema.get('type')
@@ -730,6 +722,12 @@ class _GeminiJsonSchema:
             self._object(schema, refs_stack)
         elif type_ == 'array':
             return self._array(schema, refs_stack)
+        elif type_ == 'string' and (fmt := schema.pop('format', None)):
+            description = schema.get('description')
+            if description:
+                schema['description'] = f'{description} (format: {fmt})'
+            else:
+                schema['description'] = f'Format: {fmt}'

     def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         ad_props = schema.pop('additionalProperties', None)
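The net effect of the `_simplify` changes on a typical Pydantic-generated schema, sketched below under the new rules; field names are illustrative and the shapes are approximate.

```python
from datetime import date

from pydantic import BaseModel


class Profile(BaseModel):
    name: str | None = None
    birthday: date | None = None


# Pydantic emits roughly:
#   name:     {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None}
#   birthday: {'anyOf': [{'type': 'string', 'format': 'date'}, {'type': 'null'}], 'default': None}
#
# after the new _simplify rules the Gemini-facing schema becomes roughly:
#   name:     {'type': 'string', 'nullable': True}
#   birthday: {'type': 'string', 'description': 'Format: date', 'nullable': True}
print(Profile.model_json_schema())
```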
pydantic_ai/models/vertexai.py
CHANGED
@@ -164,7 +164,7 @@ class VertexAIModel(Model):
         return url, auth

     def name(self) -> str:
-        return f'vertexai:{self.model_name}'
+        return f'google-vertex:{self.model_name}'


 # pyright: reportUnknownMemberType=false
@@ -178,7 +178,7 @@ def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
 # pyright: reportUnknownVariableType=false
 # pyright: reportUnknownArgumentType=false
 async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
-    return await run_in_executor(google.auth.default)
+    return await run_in_executor(google.auth.default, scopes=['https://www.googleapis.com/auth/cloud-platform'])


 # default expiry is 3600 seconds
pydantic_ai/result.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations

 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Callable
-from copy import copy
+from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Generic, Union, cast
@@ -11,16 +11,10 @@ import logfire_api
 from typing_extensions import TypeVar

 from . import _result, _utils, exceptions, messages as _messages, models
-from .settings import UsageLimits
 from .tools import AgentDeps, RunContext
+from .usage import Usage, UsageLimits

-__all__ = (
-    'ResultData',
-    'ResultValidatorFunc',
-    'Usage',
-    'RunResult',
-    'StreamedRunResult',
-)
+__all__ = 'ResultData', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'


 ResultData = TypeVar('ResultData', default=str)
@@ -44,55 +38,6 @@ Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


-@dataclass
-class Usage:
-    """LLM usage associated with a request or run.
-
-    Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.
-
-    You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
-    """
-
-    requests: int = 0
-    """Number of requests made to the LLM API."""
-    request_tokens: int | None = None
-    """Tokens used in processing requests."""
-    response_tokens: int | None = None
-    """Tokens used in generating responses."""
-    total_tokens: int | None = None
-    """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
-    details: dict[str, int] | None = None
-    """Any extra details returned by the model."""
-
-    def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
-        """Increment the usage in place.
-
-        Args:
-            incr_usage: The usage to increment by.
-            requests: The number of requests to increment by in addition to `incr_usage.requests`.
-        """
-        self.requests += requests
-        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
-            self_value = getattr(self, f)
-            other_value = getattr(incr_usage, f)
-            if self_value is not None or other_value is not None:
-                setattr(self, f, (self_value or 0) + (other_value or 0))
-
-        if incr_usage.details:
-            self.details = self.details or {}
-            for key, value in incr_usage.details.items():
-                self.details[key] = self.details.get(key, 0) + value
-
-    def __add__(self, other: Usage) -> Usage:
-        """Add two Usages together.
-
-        This is provided so it's trivial to sum usage information from multiple requests and runs.
-        """
-        new_usage = copy(self)
-        new_usage.incr(other)
-        return new_usage
-
-
 @dataclass
 class _BaseRunResult(ABC, Generic[ResultData]):
     """Base type for results.
@@ -103,25 +48,70 @@ class _BaseRunResult(ABC, Generic[ResultData]):
     _all_messages: list[_messages.ModelMessage]
     _new_message_index: int

-    def all_messages(self) -> list[_messages.ModelMessage]:
-        """Return the history of _messages."""
+    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return the history of _messages.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of messages.
+        """
         # this is a method to be consistent with the other methods
+        if result_tool_return_content is not None:
+            raise NotImplementedError('Setting result tool return content is not supported for this result type.')
         return self._all_messages

-    def all_messages_json(self) -> bytes:
-        """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes."""
-        return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages())
+    def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
+        """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the messages.
+        """
+        return _messages.ModelMessagesTypeAdapter.dump_json(
+            self.all_messages(result_tool_return_content=result_tool_return_content)
+        )

-    def new_messages(self) -> list[_messages.ModelMessage]:
+    def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
         """Return new messages associated with this run.

-        Messages from older runs are excluded.
+        Messages from older runs are excluded.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of new messages.
         """
-        return self.all_messages()[self._new_message_index :]
+        return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :]

-    def new_messages_json(self) -> bytes:
-        """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes."""
-        return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())
+    def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
+        """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the new messages.
+        """
+        return _messages.ModelMessagesTypeAdapter.dump_json(
+            self.new_messages(result_tool_return_content=result_tool_return_content)
+        )

     @abstractmethod
     def usage(self) -> Usage:
@@ -134,12 +124,45 @@ class RunResult(_BaseRunResult[ResultData]):

     data: ResultData
     """Data from the final response in the run."""
+    _result_tool_name: str | None
     _usage: Usage

     def usage(self) -> Usage:
         """Return the usage of the whole run."""
         return self._usage

+    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return the history of _messages.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of messages.
+        """
+        if result_tool_return_content is not None:
+            return self._set_result_tool_return(result_tool_return_content)
+        else:
+            return self._all_messages
+
+    def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
+        """Set return content for the result tool.
+
+        Useful if you want to continue the conversation and want to set the response to the result tool call.
+        """
+        if not self._result_tool_name:
+            raise ValueError('Cannot set result tool return content when the return type is `str`.')
+        messages = deepcopy(self._all_messages)
+        last_message = messages[-1]
+        for part in last_message.parts:
+            if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name:
+                part.content = return_content
+                return messages
+        raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.')
+

 @dataclass
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
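A sketch of what the new `result_tool_return_content` parameter is for, using the built-in `test` model and an illustrative result type: rewriting the result tool's return content before feeding the history into a follow-up run.

```python
from pydantic import BaseModel

from pydantic_ai import Agent


class Answer(BaseModel):
    value: int


agent = Agent('test', result_type=Answer)

result = agent.run_sync('What is 2 + 2?')
# replace the result tool's return content before continuing the conversation;
# per _set_result_tool_return above, this raises ValueError for plain `str`
# results, which have no result tool
history = result.all_messages(result_tool_return_content='Final answer recorded.')
next_result = agent.run_sync('Now double it.', message_history=history)
```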
pydantic_ai/settings.py
CHANGED
@@ -1,15 +1,12 @@
 from __future__ import annotations

-from dataclasses import dataclass
 from typing import TYPE_CHECKING

 from httpx import Timeout
 from typing_extensions import TypedDict

-from .exceptions import UsageLimitExceeded
-
 if TYPE_CHECKING:
-    from .result import Usage
+    pass


 class ModelSettings(TypedDict, total=False):
@@ -82,60 +79,3 @@ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings |
         return base | overrides
     else:
         return base or overrides
-
-
-@dataclass
-class UsageLimits:
-    """Limits on model usage.
-
-    The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model.
-    Token counts are provided in responses from the model, and the token limits are checked after each response.
-
-    Each of the limits can be set to `None` to disable that limit.
-    """
-
-    request_limit: int | None = 50
-    """The maximum number of requests allowed to the model."""
-    request_tokens_limit: int | None = None
-    """The maximum number of tokens allowed in requests to the model."""
-    response_tokens_limit: int | None = None
-    """The maximum number of tokens allowed in responses from the model."""
-    total_tokens_limit: int | None = None
-    """The maximum number of tokens allowed in requests and responses combined."""
-
-    def has_token_limits(self) -> bool:
-        """Returns `True` if this instance places any limits on token counts.
-
-        If this returns `False`, the `check_tokens` method will never raise an error.
-
-        This is useful because if we have token limits, we need to check them after receiving each streamed message.
-        If there are no limits, we can skip that processing in the streaming response iterator.
-        """
-        return any(
-            limit is not None
-            for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit)
-        )
-
-    def check_before_request(self, usage: Usage) -> None:
-        """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit."""
-        request_limit = self.request_limit
-        if request_limit is not None and usage.requests >= request_limit:
-            raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}')
-
-    def check_tokens(self, usage: Usage) -> None:
-        """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits."""
-        request_tokens = usage.request_tokens or 0
-        if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit:
-            raise UsageLimitExceeded(
-                f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})'
-            )
-
-        response_tokens = usage.response_tokens or 0
-        if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit:
-            raise UsageLimitExceeded(
-                f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
-            )
-
-        total_tokens = usage.total_tokens or 0
-        if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
-            raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
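`Usage` and `UsageLimits` now live in the new `pydantic_ai.usage` module (usage.py, +114 lines, not shown above); a minimal sketch of the limit-checking API, assuming the moved classes keep the behaviour of the code deleted here.

```python
from pydantic_ai.usage import Usage, UsageLimits

limits = UsageLimits(request_limit=5, total_tokens_limit=1000)

# `+` copies the left operand, then increments it in place via `incr`
usage = Usage(requests=1, total_tokens=400) + Usage(requests=1, total_tokens=700)

limits.check_before_request(usage)  # fine: 2 requests < request_limit of 5
limits.check_tokens(usage)  # raises UsageLimitExceeded: 1100 > total_tokens_limit
```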
pydantic_ai/tools.py
CHANGED
@@ -4,11 +4,11 @@ import dataclasses
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast

 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
-from typing_extensions import Concatenate, ParamSpec, TypeAlias
+from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar

 from . import _pydantic, _utils, messages as _messages, models
 from .exceptions import ModelRetry, UnexpectedModelBehavior
@@ -30,7 +30,7 @@ __all__ = (
     'ToolDefinition',
 )

-AgentDeps = TypeVar('AgentDeps')
+AgentDeps = TypeVar('AgentDeps', default=None)
 """Type variable for agent dependencies."""


@@ -67,7 +67,7 @@ class RunContext(Generic[AgentDeps]):
         return dataclasses.replace(self, **kwargs)


-ToolParams = ParamSpec('ToolParams')
+ToolParams = ParamSpec('ToolParams', default=...)
 """Retrieval function param spec."""

 SystemPromptFunc = Union[
@@ -92,7 +92,7 @@ ToolFuncPlain = Callable[ToolParams, Any]
 Usage `ToolPlainFunc[ToolParams]`.
 """
 ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
-"""Either
+"""Either kind of tool function.

 This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
 [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
@@ -134,7 +134,7 @@ A = TypeVar('A')
 class Tool(Generic[AgentDeps]):
     """A tool function for an agent."""

-    function: ToolFuncEither[AgentDeps, ...]
+    function: ToolFuncEither[AgentDeps]
     takes_ctx: bool
     max_retries: int | None
     name: str
@@ -150,7 +150,7 @@ class Tool(Generic[AgentDeps]):

     def __init__(
         self,
-        function: ToolFuncEither[AgentDeps, ...],
+        function: ToolFuncEither[AgentDeps],
         *,
         takes_ctx: bool | None = None,
         max_retries: int | None = None,