pydantic-ai-slim 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai_slim-0.0.6/.gitignore +15 -0
- pydantic_ai_slim-0.0.6/PKG-INFO +49 -0
- pydantic_ai_slim-0.0.6/README.md +7 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/__init__.py +8 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/_griffe.py +128 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/_pydantic.py +216 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/_result.py +258 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/_retriever.py +114 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/_system_prompt.py +33 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/_utils.py +247 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/agent.py +795 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/dependencies.py +83 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/exceptions.py +56 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/messages.py +205 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/models/__init__.py +300 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/models/function.py +268 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/models/gemini.py +720 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/models/groq.py +400 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/models/openai.py +379 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/models/test.py +389 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/models/vertexai.py +306 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/py.typed +0 -0
- pydantic_ai_slim-0.0.6/pydantic_ai/result.py +314 -0
- pydantic_ai_slim-0.0.6/pyproject.toml +51 -0
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: pydantic-ai-slim
|
|
3
|
+
Version: 0.0.6
|
|
4
|
+
Summary: Agent Framework / shim to use Pydantic with LLMs
|
|
5
|
+
Author-email: Samuel Colvin <samuel@pydantic.dev>
|
|
6
|
+
License: MIT
|
|
7
|
+
Classifier: Development Status :: 4 - Beta
|
|
8
|
+
Classifier: Environment :: Console
|
|
9
|
+
Classifier: Environment :: MacOS X
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: Intended Audience :: Information Technology
|
|
12
|
+
Classifier: Intended Audience :: System Administrators
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
15
|
+
Classifier: Operating System :: Unix
|
|
16
|
+
Classifier: Programming Language :: Python
|
|
17
|
+
Classifier: Programming Language :: Python :: 3
|
|
18
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
22
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
23
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
24
|
+
Classifier: Topic :: Internet
|
|
25
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
26
|
+
Requires-Python: >=3.9
|
|
27
|
+
Requires-Dist: eval-type-backport>=0.2.0
|
|
28
|
+
Requires-Dist: griffe>=1.3.2
|
|
29
|
+
Requires-Dist: httpx>=0.27.2
|
|
30
|
+
Requires-Dist: logfire-api>=1.2.0
|
|
31
|
+
Requires-Dist: pydantic>=2.10
|
|
32
|
+
Provides-Extra: groq
|
|
33
|
+
Requires-Dist: groq>=0.12.0; extra == 'groq'
|
|
34
|
+
Provides-Extra: logfire
|
|
35
|
+
Requires-Dist: logfire>=2.3; extra == 'logfire'
|
|
36
|
+
Provides-Extra: openai
|
|
37
|
+
Requires-Dist: openai>=1.54.3; extra == 'openai'
|
|
38
|
+
Provides-Extra: vertexai
|
|
39
|
+
Requires-Dist: google-auth>=2.36.0; extra == 'vertexai'
|
|
40
|
+
Requires-Dist: requests>=2.32.3; extra == 'vertexai'
|
|
41
|
+
Description-Content-Type: text/markdown
|
|
42
|
+
|
|
43
|
+
# Coming soon
|
|
44
|
+
|
|
45
|
+
[CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain)
|
|
46
|
+
[Coverage](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai)
|
|
47
|
+
[PyPI](https://pypi.python.org/pypi/pydantic-ai)
|
|
48
|
+
[GitHub](https://github.com/pydantic/pydantic-ai)
|
|
49
|
+
[License](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE)
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
# Coming soon
|
|
2
|
+
|
|
3
|
+
[CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain)
|
|
4
|
+
[Coverage](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai)
|
|
5
|
+
[PyPI](https://pypi.python.org/pypi/pydantic-ai)
|
|
6
|
+
[GitHub](https://github.com/pydantic/pydantic-ai)
|
|
7
|
+
[License](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE)
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from importlib.metadata import version
|
|
2
|
+
|
|
3
|
+
from .agent import Agent
|
|
4
|
+
from .dependencies import CallContext
|
|
5
|
+
from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
|
|
6
|
+
|
|
7
|
+
# public API re-exported at package level
__all__ = (
    'Agent',
    'CallContext',
    'ModelRetry',
    'UnexpectedModelBehavior',
    'UserError',
    '__version__',
)
# version string of the installed `pydantic_ai_slim` distribution, read from its metadata at import time
__version__ = version('pydantic_ai_slim')
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from inspect import Signature
|
|
5
|
+
from typing import Any, Callable, Literal, cast
|
|
6
|
+
|
|
7
|
+
from _griffe.enumerations import DocstringSectionKind
|
|
8
|
+
from _griffe.models import Docstring, Object as GriffeObject
|
|
9
|
+
|
|
10
|
+
# Docstring styles that can be handed to griffe's `Docstring(parser=...)`; when a caller doesn't
# choose one, `_infer_docstring_style` guesses it from the docstring text.
DocstringStyle = Literal['google', 'numpy', 'sphinx']
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def doc_descriptions(
    func: Callable[..., Any], sig: Signature, *, style: DocstringStyle | None = None
) -> tuple[str, dict[str, str]]:
    """Split a function's docstring into its main description and per-parameter descriptions.

    The docstring is parsed with griffe, using `style` when given and a guessed style otherwise.

    Returns:
        A tuple of (main description, mapping of parameter name to description). The description
        is `''` and the mapping is empty when the function has no docstring or the corresponding
        section is absent.
    """
    raw_doc = func.__doc__
    if raw_doc is None:
        return '', {}

    # griffe's `Docstring` takes a parent `Object`; the signature is passed in its place,
    # see https://github.com/mkdocstrings/griffe/issues/293
    parent = cast(GriffeObject, sig)
    parser = style or _infer_docstring_style(raw_doc)
    sections = Docstring(raw_doc, lineno=1, parser=parser, parent=parent).parse()

    def first_section(kind: DocstringSectionKind) -> Any:
        # first parsed section of the requested kind, or None
        return next((section for section in sections if section.kind == kind), None)

    param_section = first_section(DocstringSectionKind.parameters)
    params = {param.name: param.description for param in param_section.value} if param_section else {}

    text_section = first_section(DocstringSectionKind.text)
    main_desc = text_section.value if text_section else ''

    return main_desc, params
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _infer_docstring_style(doc: str) -> DocstringStyle:
    """Guess a docstring's style by searching it for style-specific section markers.

    Styles are tried in the order they appear in `_docstring_style_patterns`; the first style
    with any matching section name wins. If nothing matches, `'google'` is assumed.
    """
    flags = re.IGNORECASE | re.MULTILINE
    for template, section_names, candidate in _docstring_style_patterns:
        for section_name in section_names:
            if re.search(template.format(section_name), doc, flags):
                return candidate
    # no recognisable markers: assume google style
    return 'google'
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# See https://github.com/mkdocstrings/griffe/issues/329#issuecomment-2425017804
# Each entry is (regex template, section names, style). `_infer_docstring_style` substitutes each
# section name for `{0}` in the template, searches the docstring case-insensitively, and returns
# the style of the first entry with a match. Entries are checked in list order, so sphinx wins
# over google, which wins over numpy.
_docstring_style_patterns: list[tuple[str, list[str], DocstringStyle]] = [
    (
        # sphinx: a field-list line such as `:param x: description` (`:name`, optional
        # words, a closing `:`, then optional text)
        r'\n[ \t]*:{0}([ \t]+\w+)*:([ \t]+.+)?\n',
        [
            'param',
            'parameter',
            'arg',
            'argument',
            'key',
            'keyword',
            'type',
            'var',
            'ivar',
            'cvar',
            'vartype',
            'returns',
            'return',
            'rtype',
            'raises',
            'raise',
            'except',
            'exception',
        ],
        'sphinx',
    ),
    (
        # google: a section header such as `Args:` followed by at least one indented line
        r'\n[ \t]*{0}:([ \t]+.+)?\n[ \t]+.+',
        [
            'args',
            'arguments',
            'params',
            'parameters',
            'keyword args',
            'keyword arguments',
            'other args',
            'other arguments',
            'other params',
            'other parameters',
            'raises',
            'exceptions',
            'returns',
            'yields',
            'receives',
            'examples',
            'attributes',
            'functions',
            'methods',
            'classes',
            'modules',
            'warns',
            'warnings',
        ],
        'google',
    ),
    (
        # numpy: a section title on its own line, underlined by three or more dashes
        r'\n[ \t]*{0}\n[ \t]*---+\n',
        [
            'deprecated',
            'parameters',
            'other parameters',
            'returns',
            'yields',
            'receives',
            'raises',
            'warns',
            'attributes',
            'functions',
            'methods',
            'classes',
            'modules',
        ],
        'numpy',
    ),
]
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""Used to build pydantic validators and JSON schemas from functions.
|
|
2
|
+
|
|
3
|
+
This module has to use numerous internal Pydantic APIs and is therefore brittle to changes in Pydantic.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations as _annotations
|
|
7
|
+
|
|
8
|
+
from inspect import Parameter, signature
|
|
9
|
+
from typing import TYPE_CHECKING, Any, TypedDict, cast, get_origin
|
|
10
|
+
|
|
11
|
+
from pydantic import ConfigDict, TypeAdapter
|
|
12
|
+
from pydantic._internal import _decorators, _generate_schema, _typing_extra
|
|
13
|
+
from pydantic._internal._config import ConfigWrapper
|
|
14
|
+
from pydantic.fields import FieldInfo
|
|
15
|
+
from pydantic.json_schema import GenerateJsonSchema
|
|
16
|
+
from pydantic.plugin._schema_validator import create_schema_validator
|
|
17
|
+
from pydantic_core import SchemaValidator, core_schema
|
|
18
|
+
|
|
19
|
+
from ._griffe import doc_descriptions
|
|
20
|
+
from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from . import _retriever
|
|
24
|
+
from .dependencies import AgentDeps, RetrieverParams
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# names exported by `from ._pydantic import *`
__all__ = ('function_schema', 'LazyTypeAdapter')
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class FunctionSchema(TypedDict):
    """Internal information about a function schema."""

    # main description extracted from the function's docstring ('' if there is none)
    description: str
    # validator for the function's arguments
    validator: SchemaValidator
    # JSON schema describing the function's parameters
    json_schema: ObjectJsonSchema
    # if not None, the function takes a single argument by that name (besides the leading
    # `CallContext` argument, if the function takes one)
    single_arg_name: str | None
    # names of positional-only parameters
    positional_fields: list[str]
    # name of the `*args` parameter, if any
    var_positional_field: str | None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def function_schema(either_function: _retriever.RetrieverEitherFunc[AgentDeps, RetrieverParams]) -> FunctionSchema:  # noqa: C901
    """Build a Pydantic validator and JSON schema from a retriever function.

    Every parameter except a leading `CallContext` and `**kwargs` becomes a field of a typed-dict
    schema. `**kwargs` values are accepted as extra fields, validated against their annotation.
    A field's description comes from the function's docstring when its annotation doesn't
    provide one.

    Args:
        either_function: The function to build a validator and JSON schema for. If it is the
            "left" variant (`is_left()`), the function takes a `CallContext` as its first argument.

    Returns:
        A `FunctionSchema` instance.

    Raises:
        UserError: If `CallContext` is missing from, or misplaced in, the function's signature.
            All such problems are collected and reported in one error.
    """
    function = either_function.whichever()
    takes_ctx = either_function.is_left()
    config = ConfigDict(title=function.__name__)
    config_wrapper = ConfigWrapper(config)
    gen_schema = _generate_schema.GenerateSchema(config_wrapper)

    sig = signature(function)

    type_hints = _typing_extra.get_function_type_hints(function)

    var_kwargs_schema: core_schema.CoreSchema | None = None
    fields: dict[str, core_schema.TypedDictField] = {}
    positional_fields: list[str] = []
    var_positional_field: str | None = None
    errors: list[str] = []
    decorators = _decorators.DecoratorInfos()
    description, field_descriptions = doc_descriptions(function, sig)

    for index, (name, p) in enumerate(sig.parameters.items()):
        if p.annotation is sig.empty:
            if takes_ctx and index == 0:
                # should be the `context` argument, skip
                continue
            # TODO warn?
            annotation = Any
        else:
            annotation = type_hints[name]

        # check where `CallContext` may appear; errors are collected rather than raised immediately
        if index == 0 and takes_ctx:
            if not _is_call_ctx(annotation):
                errors.append('First argument must be a CallContext instance when using `.retriever`')
            # the context argument is never part of the schema
            continue
        elif not takes_ctx and _is_call_ctx(annotation):
            errors.append('CallContext instance can only be used with `.retriever`')
            continue
        elif index != 0 and _is_call_ctx(annotation):
            errors.append('CallContext instance can only be used as the first argument')
            continue

        field_name = p.name
        if p.kind == Parameter.VAR_KEYWORD:
            # `**kwargs: T` -> values of extra keys are validated as `T`
            var_kwargs_schema = gen_schema.generate_schema(annotation)
        else:
            if p.kind == Parameter.VAR_POSITIONAL:
                # `*args: T` is exposed as a single field holding a list of `T`
                annotation = list[annotation]

            # FieldInfo.from_annotation expects a type, `annotation` is Any
            annotation = cast(type[Any], annotation)
            field_info = FieldInfo.from_annotation(annotation)
            if field_info.description is None:
                field_info.description = field_descriptions.get(field_name)

            fields[field_name] = td_schema = gen_schema._generate_td_field_schema(  # pyright: ignore[reportPrivateUsage]
                field_name,
                field_info,
                decorators,
            )
            # recorded so `_build_schema` can decide whether a lone parameter is used unwrapped
            # noinspection PyTypeChecker
            td_schema.setdefault('metadata', {})['is_model_like'] = is_model_like(annotation)

            if p.kind == Parameter.POSITIONAL_ONLY:
                positional_fields.append(field_name)
            elif p.kind == Parameter.VAR_POSITIONAL:
                var_positional_field = field_name

    if errors:
        from .exceptions import UserError

        error_details = '\n '.join(errors)
        raise UserError(f'Error generating schema for {function.__qualname__}:\n {error_details}')

    core_config = config_wrapper.core_config(None)
    # unknown keys are only accepted when the function has `**kwargs`
    # noinspection PyTypedDict
    core_config['extra_fields_behavior'] = 'allow' if var_kwargs_schema else 'forbid'

    schema, single_arg_name = _build_schema(fields, var_kwargs_schema, gen_schema, core_config)
    schema = gen_schema.clean_schema(schema)
    # noinspection PyUnresolvedReferences
    schema_validator = create_schema_validator(
        schema,
        function,
        function.__module__,
        function.__qualname__,
        'validate_call',
        core_config,
        config_wrapper.plugin_settings,
    )
    # PluggableSchemaValidator is api compatible with SchemaValidator
    schema_validator = cast(SchemaValidator, schema_validator)
    json_schema = GenerateJsonSchema().generate(schema)

    # workaround for https://github.com/pydantic/pydantic/issues/10785
    # if we build a custom TypedDict schema (matches when `single_arg_name` is None), we manually set
    # `additionalProperties` in the JSON Schema
    if single_arg_name is None:
        json_schema['additionalProperties'] = bool(var_kwargs_schema)

    # instead of passing `description` through in core_schema, we just add it here
    # (the dict union places the `description` key first)
    if description:
        json_schema = {'description': description} | json_schema

    return FunctionSchema(
        description=description,
        validator=schema_validator,
        json_schema=check_object_json_schema(json_schema),
        single_arg_name=single_arg_name,
        positional_fields=positional_fields,
        var_positional_field=var_positional_field,
    )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _build_schema(
    fields: dict[str, core_schema.TypedDictField],
    var_kwargs_schema: core_schema.CoreSchema | None,
    gen_schema: _generate_schema.GenerateSchema,
    core_config: core_schema.CoreConfig,
) -> tuple[core_schema.CoreSchema, str | None]:
    """Generate a typed dict schema for function parameters.

    If there is exactly one field, no `**kwargs`, and that field's type was marked model-like,
    the field's own schema is returned together with its name, so the arguments are not wrapped
    in an outer object.

    Args:
        fields: The fields to generate a typed dict schema for.
        var_kwargs_schema: Core schema for the values of `**kwargs`, already generated by the
            caller, or `None` if the function has no `**kwargs`.
        gen_schema: The `GenerateSchema` instance. Kept for interface compatibility; every
            schema passed in is already generated.
        core_config: The core configuration.

    Returns:
        tuple of (generated core schema, single arg name).
    """
    if len(fields) == 1 and var_kwargs_schema is None:
        name = next(iter(fields))
        td_field = fields[name]
        if td_field['metadata']['is_model_like']:  # type: ignore
            return td_field['schema'], name

    td_schema = core_schema.typed_dict_schema(
        fields,
        config=core_config,
        # `var_kwargs_schema` is already a core schema. Use it directly instead of feeding it back
        # through `generate_schema`, which only worked because that internal API passes dicts through.
        extras_schema=var_kwargs_schema,
    )
    return td_schema, None
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _is_call_ctx(annotation: Any) -> bool:
    """Return whether `annotation` is `CallContext` itself or a parametrized `CallContext[...]`."""
    # imported lazily, as the original does, rather than at module level
    from .dependencies import CallContext

    if annotation is CallContext:
        return True
    return _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is CallContext
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
if TYPE_CHECKING:
    LazyTypeAdapter = TypeAdapter
else:

    class LazyTypeAdapter:
        """Stand-in for `TypeAdapter` that builds the real adapter on first attribute access.

        Constructor arguments are stored and forwarded to `TypeAdapter` the first time any
        attribute other than this wrapper's own slots is requested. Type checkers see plain
        `TypeAdapter` (see the `TYPE_CHECKING` branch).
        """

        __slots__ = '_args', '_kwargs', '_type_adapter'

        def __init__(self, *args, **kwargs):
            self._args = args
            self._kwargs = kwargs
            self._type_adapter = None

        def __getattr__(self, item):
            # `__getattr__` only runs when normal lookup fails. For one of our own slots, that means
            # the slot was never assigned (e.g. an instance created through `__new__` alone). Reading
            # `self._type_adapter` below would then re-enter `__getattr__` and recurse until
            # `RecursionError`, so report an ordinary AttributeError instead.
            if item in LazyTypeAdapter.__slots__:
                raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}')
            if self._type_adapter is None:
                self._type_adapter = TypeAdapter(*self._args, **self._kwargs)
            return getattr(self._type_adapter, item)
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import sys
|
|
5
|
+
import types
|
|
6
|
+
from collections.abc import Awaitable
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
|
|
9
|
+
|
|
10
|
+
from pydantic import TypeAdapter, ValidationError
|
|
11
|
+
from typing_extensions import Self, TypeAliasType, TypedDict
|
|
12
|
+
|
|
13
|
+
from . import _utils, messages
|
|
14
|
+
from .dependencies import AgentDeps, CallContext, ResultValidatorFunc
|
|
15
|
+
from .exceptions import ModelRetry
|
|
16
|
+
from .messages import ModelStructuredResponse, ToolCall
|
|
17
|
+
from .result import ResultData
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class ResultValidator(Generic[AgentDeps, ResultData]):
    """Wrap a user-supplied result validator, which may be sync or async and may take a `CallContext`."""

    # the user's validator function
    function: ResultValidatorFunc[AgentDeps, ResultData]
    # True when the function declares more than one parameter; it is then called with a `CallContext` first
    _takes_ctx: bool = field(init=False)
    # True when the function is a coroutine function and is awaited directly
    _is_async: bool = field(init=False)

    def __post_init__(self):
        param_count = len(inspect.signature(self.function).parameters)
        self._takes_ctx = param_count > 1
        self._is_async = inspect.iscoroutinefunction(self.function)

    async def validate(
        self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCall | None
    ) -> ResultData:
        """Validate a result by calling the validator function.

        Args:
            result: The result data after Pydantic has validated the message content.
            deps: The agent dependencies.
            retry: The current retry number.
            tool_call: The original tool call message, `None` if there was no tool call.

        Returns:
            The value returned by the validator function.

        Raises:
            ToolRetryError: If the validator raises `ModelRetry`. The wrapped `RetryPrompt` carries
                the tool name and id when `tool_call` is given.
        """
        if not self._takes_ctx:
            args = (result,)
        else:
            args = CallContext(deps, retry, tool_call.tool_name if tool_call else None), result

        try:
            if self._is_async:
                async_fn = cast(Callable[[Any], Awaitable[ResultData]], self.function)
                validated = await async_fn(*args)
            else:
                # sync validators are dispatched through `_utils.run_in_executor`
                sync_fn = cast(Callable[[Any], ResultData], self.function)
                validated = await _utils.run_in_executor(sync_fn, *args)
        except ModelRetry as exc:
            retry_prompt = messages.RetryPrompt(content=exc.message)
            if tool_call is not None:
                retry_prompt.tool_name = tool_call.tool_name
                retry_prompt.tool_id = tool_call.tool_id
            raise ToolRetryError(retry_prompt) from exc
        return validated
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class ToolRetryError(Exception):
    """Internal exception signalling that a `RetryPrompt` message should be sent back to the LLM."""

    def __init__(self, tool_retry: messages.RetryPrompt):
        super().__init__()
        # the retry prompt to return to the model
        self.tool_retry = tool_retry
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@dataclass
class ResultSchema(Generic[ResultData]):
    """Model the final response from an agent run.

    Similar to `Retriever` but for the final result of running an agent: it holds one
    `ResultTool` per accepted result type, keyed by tool name.
    """

    # result tools keyed by tool name
    tools: dict[str, ResultTool[ResultData]]
    # whether plain text is also an acceptable final result (the result type was a union including `str`)
    allow_text_result: bool

    @classmethod
    def build(cls, response_type: type[ResultData], name: str, description: str | None) -> Self | None:
        """Build a ResultSchema dataclass from a response type.

        Returns `None` when `response_type` is exactly `str`. If `str` is one member of a union,
        it is removed and `allow_text_result` is set. Each remaining union member gets its own
        tool named from `name` and the member's name; a non-union type gets one tool named `name`.
        """
        if response_type is str:
            return None

        allow_text_result = False
        without_str = extract_str_from_union(response_type)
        if without_str:
            response_type = without_str.value
            allow_text_result = True

        def make_tool(tp: Any, tool_name: str, multiple: bool) -> ResultTool[ResultData]:
            return cast(
                ResultTool[ResultData],
                ResultTool.build(tp, tool_name, description, multiple),  # pyright: ignore[reportUnknownMemberType]
            )

        tools: dict[str, ResultTool[ResultData]] = {}
        union_members = get_union_args(response_type)
        if not union_members:
            tools[name] = make_tool(response_type, name, False)
        else:
            for position, member in enumerate(union_members, start=1):
                tool_name = union_tool_name(name, member)
                # disambiguate a colliding name by appending the member's 1-based position
                while tool_name in tools:
                    tool_name = f'{tool_name}_{position}'
                tools[tool_name] = make_tool(member, tool_name, True)

        return cls(tools=tools, allow_text_result=allow_text_result)

    def find_tool(self, message: ModelStructuredResponse) -> tuple[ToolCall, ResultTool[ResultData]] | None:
        """Return the first call in `message` naming one of the result tools, with that tool, else `None`."""
        for call in message.calls:
            tool = self.tools.get(call.tool_name)
            if tool:
                return call, tool
        return None

    def tool_names(self) -> list[str]:
        """Return the names of the tools."""
        return [*self.tools]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# Tool description used by `ResultTool.build` when neither the caller nor the result type's
# JSON schema supplies one.
DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@dataclass
class ResultTool(Generic[ResultData]):
    """A tool the model can call to return a final result of one particular type."""

    name: str
    description: str
    type_adapter: TypeAdapter[Any]
    json_schema: _utils.ObjectJsonSchema
    # when set, validated data is a wrapper dict and the result is stored under this key
    outer_typed_dict_key: str | None

    @classmethod
    def build(cls, response_type: type[ResultData], name: str, description: str | None, multiple: bool) -> Self | None:
        """Build a ResultTool dataclass from a response type.

        Model-like types are validated directly. Any other type is wrapped in a TypedDict with a
        single `response` key so the tool's arguments form an object. When `multiple` is true (the
        type is one member of a union), the description is prefixed with the member's name.
        """
        assert response_type is not str, 'ResultTool does not support str as a response type'

        outer_typed_dict_key: str | None
        if _utils.is_model_like(response_type):
            type_adapter = TypeAdapter(response_type)
            outer_typed_dict_key = None
        else:
            wrapper = TypedDict('response_data_typed_dict', {'response': response_type})  # noqa
            type_adapter = TypeAdapter(wrapper)
            outer_typed_dict_key = 'response'

        # noinspection PyArgumentList
        json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
        if outer_typed_dict_key is not None:
            # the wrapper's `response_data_typed_dict` title adds nothing and could confuse the LLM
            json_schema.pop('title')

        schema_description = json_schema.pop('description', None)
        if not schema_description:
            tool_description = description or DEFAULT_DESCRIPTION
        elif description is None:
            tool_description = schema_description
        else:
            tool_description = f'{description}. {schema_description}'
        if multiple:
            tool_description = f'{union_arg_name(response_type)}: {tool_description}'

        return cls(
            name=name,
            description=tool_description,
            type_adapter=type_adapter,
            json_schema=json_schema,
            outer_typed_dict_key=outer_typed_dict_key,
        )

    def validate(
        self, tool_call: messages.ToolCall, allow_partial: bool = False, wrap_validation_errors: bool = True
    ) -> ResultData:
        """Validate a result message.

        Args:
            tool_call: The tool call from the LLM to validate.
            allow_partial: If true, allow partial validation (Pydantic's `'trailing-strings'` mode).
            wrap_validation_errors: If true, wrap the validation errors in a retry message.

        Returns:
            The validated result data, unwrapped from the `response` key for non-model types.

        Raises:
            ToolRetryError: On validation failure when `wrap_validation_errors` is true.
            ValidationError: On validation failure when `wrap_validation_errors` is false.
        """
        partial_mode: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
        args = tool_call.args
        try:
            if isinstance(args, messages.ArgsJson):
                validated = self.type_adapter.validate_json(args.args_json or '', experimental_allow_partial=partial_mode)
            else:
                validated = self.type_adapter.validate_python(args.args_object, experimental_allow_partial=partial_mode)
        except ValidationError as e:
            if not wrap_validation_errors:
                raise
            retry_prompt = messages.RetryPrompt(
                tool_name=tool_call.tool_name,
                content=e.errors(include_url=False),
                tool_id=tool_call.tool_id,
            )
            raise ToolRetryError(retry_prompt) from e

        key = self.outer_typed_dict_key
        return validated[key] if key else validated
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def union_tool_name(base_name: str, union_arg: Any) -> str:
    """Name the result tool for one union member: `base_name`, an underscore, then the member's name."""
    return '{}_{}'.format(base_name, union_arg_name(union_arg))
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def union_arg_name(union_arg: Any) -> str:
    """Return a short name for one member of a union, used in tool names and descriptions.

    The member's `__name__` is used when it has one, which is the case for classes. Some typing
    constructs (e.g. certain `typing` aliases on older Python versions) lack `__name__`. For those,
    fall back to the name of the unparametrized origin, and finally to the member's `repr` with
    every run of non-word characters replaced by `_`. This keeps the result usable in a tool name
    instead of raising `AttributeError`.
    """
    try:
        return union_arg.__name__
    except AttributeError:
        pass

    import re

    origin_name = getattr(get_origin(union_arg), '__name__', None)
    if isinstance(origin_name, str):
        return origin_name
    return re.sub(r'\W+', '_', repr(union_arg)).strip('_')
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
    """If `response_type` is a union that includes `str`, return `Some` of the union without `str`.

    The wrapped value is the single remaining member when only one is left, otherwise a `Union`
    of the remaining members. Returns `None` if `response_type` is not a union containing `str`.
    """
    union_args = get_union_args(response_type)
    if not any(t is str for t in union_args):
        return None

    others = [arg for arg in union_args if arg is not str]
    if len(others) == 1:
        return _utils.Some(others[0])
    return _utils.Some(Union[tuple(others)])
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def get_union_args(tp: Any) -> tuple[Any, ...]:
    """Return the members of `tp` if it is a union (unwrapping a `TypeAliasType` first), else `()`."""
    # a `TypeAliasType` keeps the aliased type in `__value__`
    target = tp.__value__ if isinstance(tp, TypeAliasType) else tp
    return get_args(target) if origin_is_union(get_origin(target)) else ()
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
if sys.version_info >= (3, 10):

    def origin_is_union(tp: type[Any] | None) -> bool:
        """Return whether `tp` (a `get_origin` result) is `typing.Union` or the PEP 604 `types.UnionType`."""
        return any(tp is union_form for union_form in (Union, types.UnionType))

else:

    def origin_is_union(tp: type[Any] | None) -> bool:
        """Return whether `tp` (a `get_origin` result) is `typing.Union`; `X | Y` unions need Python 3.10+."""
        return tp is Union
|