pydantic_ai_slim-0.0.6.tar.gz

This diff shows the content of this package version as published to its public registry. It is provided for informational purposes only and reflects the changes between released versions as they appear in that registry.

This version of pydantic-ai-slim has been flagged as potentially problematic.

@@ -0,0 +1,15 @@
+ site
+ .python-version
+ .venv
+ dist
+ __pycache__
+ *.env
+ /scratch/
+ /.coverage
+ env*/
+ /TODO.md
+ /postgres-data/
+ .DS_Store
+ /pydantic_ai_examples/.chat_app_messages.jsonl
+ .cache/
+ .docs-insiders-install
@@ -0,0 +1,49 @@
+ Metadata-Version: 2.3
+ Name: pydantic-ai-slim
+ Version: 0.0.6
+ Summary: Agent Framework / shim to use Pydantic with LLMs
+ Author-email: Samuel Colvin <samuel@pydantic.dev>
+ License: MIT
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Environment :: Console
+ Classifier: Environment :: MacOS X
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Information Technology
+ Classifier: Intended Audience :: System Administrators
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: POSIX :: Linux
+ Classifier: Operating System :: Unix
+ Classifier: Programming Language :: Python
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3 :: Only
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Classifier: Topic :: Internet
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Python: >=3.9
+ Requires-Dist: eval-type-backport>=0.2.0
+ Requires-Dist: griffe>=1.3.2
+ Requires-Dist: httpx>=0.27.2
+ Requires-Dist: logfire-api>=1.2.0
+ Requires-Dist: pydantic>=2.10
+ Provides-Extra: groq
+ Requires-Dist: groq>=0.12.0; extra == 'groq'
+ Provides-Extra: logfire
+ Requires-Dist: logfire>=2.3; extra == 'logfire'
+ Provides-Extra: openai
+ Requires-Dist: openai>=1.54.3; extra == 'openai'
+ Provides-Extra: vertexai
+ Requires-Dist: google-auth>=2.36.0; extra == 'vertexai'
+ Requires-Dist: requests>=2.32.3; extra == 'vertexai'
+ Description-Content-Type: text/markdown
+
+ # Coming soon
+
+ [![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain)
+ [![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai)
+ [![PyPI](https://img.shields.io/pypi/v/pydantic-ai.svg)](https://pypi.python.org/pypi/pydantic-ai)
+ [![versions](https://img.shields.io/pypi/pyversions/pydantic-ai.svg)](https://github.com/pydantic/pydantic-ai)
+ [![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE)
@@ -0,0 +1,7 @@
+ # Coming soon
+
+ [![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain)
+ [![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai)
+ [![PyPI](https://img.shields.io/pypi/v/pydantic-ai.svg)](https://pypi.python.org/pypi/pydantic-ai)
+ [![versions](https://img.shields.io/pypi/pyversions/pydantic-ai.svg)](https://github.com/pydantic/pydantic-ai)
+ [![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE)
@@ -0,0 +1,8 @@
+ from importlib.metadata import version
+
+ from .agent import Agent
+ from .dependencies import CallContext
+ from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
+
+ __all__ = 'Agent', 'CallContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
+ __version__ = version('pydantic_ai_slim')
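This `__init__` module re-exports the public surface (`Agent`, `CallContext`, and the exception types) and derives `__version__` from the installed distribution. A minimal sketch of what that exposes at runtime, assuming the import package is named `pydantic_ai` (file paths are not shown in this diff):

```python
# Hedged sketch: the distribution is pydantic-ai-slim, but the import package
# name is assumed here to be `pydantic_ai` (directory names are not shown above).
import pydantic_ai

print(pydantic_ai.__version__)  # expected: '0.0.6'
print(pydantic_ai.__all__)
# ('Agent', 'CallContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__')
```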
@@ -0,0 +1,128 @@
+ from __future__ import annotations as _annotations
+
+ import re
+ from inspect import Signature
+ from typing import Any, Callable, Literal, cast
+
+ from _griffe.enumerations import DocstringSectionKind
+ from _griffe.models import Docstring, Object as GriffeObject
+
+ DocstringStyle = Literal['google', 'numpy', 'sphinx']
+
+
+ def doc_descriptions(
+     func: Callable[..., Any], sig: Signature, *, style: DocstringStyle | None = None
+ ) -> tuple[str, dict[str, str]]:
+     """Extract the function description and parameter descriptions from a function's docstring.
+
+     Returns:
+         A tuple of (main function description, parameter descriptions).
+     """
+     doc = func.__doc__
+     if doc is None:
+         return '', {}
+
+     # see https://github.com/mkdocstrings/griffe/issues/293
+     parent = cast(GriffeObject, sig)
+
+     docstring = Docstring(doc, lineno=1, parser=style or _infer_docstring_style(doc), parent=parent)
+     sections = docstring.parse()
+
+     params = {}
+     if parameters := next((p for p in sections if p.kind == DocstringSectionKind.parameters), None):
+         params = {p.name: p.description for p in parameters.value}
+
+     main_desc = ''
+     if main := next((p for p in sections if p.kind == DocstringSectionKind.text), None):
+         main_desc = main.value
+
+     return main_desc, params
+
+
+ def _infer_docstring_style(doc: str) -> DocstringStyle:
+     """Simplistic docstring style inference."""
+     for pattern, replacements, style in _docstring_style_patterns:
+         matches = (
+             re.search(pattern.format(replacement), doc, re.IGNORECASE | re.MULTILINE) for replacement in replacements
+         )
+         if any(matches):
+             return style
+     # fallback to google style
+     return 'google'
+
+
+ # See https://github.com/mkdocstrings/griffe/issues/329#issuecomment-2425017804
+ _docstring_style_patterns: list[tuple[str, list[str], DocstringStyle]] = [
+     (
+         r'\n[ \t]*:{0}([ \t]+\w+)*:([ \t]+.+)?\n',
+         [
+             'param',
+             'parameter',
+             'arg',
+             'argument',
+             'key',
+             'keyword',
+             'type',
+             'var',
+             'ivar',
+             'cvar',
+             'vartype',
+             'returns',
+             'return',
+             'rtype',
+             'raises',
+             'raise',
+             'except',
+             'exception',
+         ],
+         'sphinx',
+     ),
+     (
+         r'\n[ \t]*{0}:([ \t]+.+)?\n[ \t]+.+',
+         [
+             'args',
+             'arguments',
+             'params',
+             'parameters',
+             'keyword args',
+             'keyword arguments',
+             'other args',
+             'other arguments',
+             'other params',
+             'other parameters',
+             'raises',
+             'exceptions',
+             'returns',
+             'yields',
+             'receives',
+             'examples',
+             'attributes',
+             'functions',
+             'methods',
+             'classes',
+             'modules',
+             'warns',
+             'warnings',
+         ],
+         'google',
+     ),
+     (
+         r'\n[ \t]*{0}\n[ \t]*---+\n',
+         [
+             'deprecated',
+             'parameters',
+             'other parameters',
+             'returns',
+             'yields',
+             'receives',
+             'raises',
+             'warns',
+             'attributes',
+             'functions',
+             'methods',
+             'classes',
+             'modules',
+         ],
+         'numpy',
+     ),
+ ]
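To make the behaviour of `doc_descriptions` concrete, here is a hedged usage sketch. The import path `pydantic_ai._griffe` is an assumption (the diff omits file names); griffe itself is a declared dependency (`griffe>=1.3.2`):

```python
from inspect import signature

from pydantic_ai._griffe import doc_descriptions  # assumed module path


def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: The first number.
        b: The second number.
    """
    return a + b


# The `Args:` section matches the Google-style pattern above, so the style is inferred as 'google'.
description, params = doc_descriptions(add, signature(add))
print(description)  # roughly: 'Add two numbers.'
print(params)       # roughly: {'a': 'The first number.', 'b': 'The second number.'}
```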
@@ -0,0 +1,216 @@
+ """Used to build pydantic validators and JSON schemas from functions.
+
+ This module has to use numerous internal Pydantic APIs and is therefore brittle to changes in Pydantic.
+ """
+
+ from __future__ import annotations as _annotations
+
+ from inspect import Parameter, signature
+ from typing import TYPE_CHECKING, Any, TypedDict, cast, get_origin
+
+ from pydantic import ConfigDict, TypeAdapter
+ from pydantic._internal import _decorators, _generate_schema, _typing_extra
+ from pydantic._internal._config import ConfigWrapper
+ from pydantic.fields import FieldInfo
+ from pydantic.json_schema import GenerateJsonSchema
+ from pydantic.plugin._schema_validator import create_schema_validator
+ from pydantic_core import SchemaValidator, core_schema
+
+ from ._griffe import doc_descriptions
+ from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
+
+ if TYPE_CHECKING:
+     from . import _retriever
+     from .dependencies import AgentDeps, RetrieverParams
+
+
+ __all__ = 'function_schema', 'LazyTypeAdapter'
+
+
+ class FunctionSchema(TypedDict):
+     """Internal information about a function schema."""
+
+     description: str
+     validator: SchemaValidator
+     json_schema: ObjectJsonSchema
+     # if not None, the function takes a single argument by that name (besides potentially `info`)
+     single_arg_name: str | None
+     positional_fields: list[str]
+     var_positional_field: str | None
+
+
+ def function_schema(either_function: _retriever.RetrieverEitherFunc[AgentDeps, RetrieverParams]) -> FunctionSchema:  # noqa: C901
+     """Build a Pydantic validator and JSON schema from a retriever function.
+
+     Args:
+         either_function: The function to build a validator and JSON schema for.
+
+     Returns:
+         A `FunctionSchema` instance.
+     """
+     function = either_function.whichever()
+     takes_ctx = either_function.is_left()
+     config = ConfigDict(title=function.__name__)
+     config_wrapper = ConfigWrapper(config)
+     gen_schema = _generate_schema.GenerateSchema(config_wrapper)
+
+     sig = signature(function)
+
+     type_hints = _typing_extra.get_function_type_hints(function)
+
+     var_kwargs_schema: core_schema.CoreSchema | None = None
+     fields: dict[str, core_schema.TypedDictField] = {}
+     positional_fields: list[str] = []
+     var_positional_field: str | None = None
+     errors: list[str] = []
+     decorators = _decorators.DecoratorInfos()
+     description, field_descriptions = doc_descriptions(function, sig)
+
+     for index, (name, p) in enumerate(sig.parameters.items()):
+         if p.annotation is sig.empty:
+             if takes_ctx and index == 0:
+                 # should be the `context` argument, skip
+                 continue
+             # TODO warn?
+             annotation = Any
+         else:
+             annotation = type_hints[name]
+
+         if index == 0 and takes_ctx:
+             if not _is_call_ctx(annotation):
+                 errors.append('First argument must be a CallContext instance when using `.retriever`')
+             continue
+         elif not takes_ctx and _is_call_ctx(annotation):
+             errors.append('CallContext instance can only be used with `.retriever`')
+             continue
+         elif index != 0 and _is_call_ctx(annotation):
+             errors.append('CallContext instance can only be used as the first argument')
+             continue
+
+         field_name = p.name
+         if p.kind == Parameter.VAR_KEYWORD:
+             var_kwargs_schema = gen_schema.generate_schema(annotation)
+         else:
+             if p.kind == Parameter.VAR_POSITIONAL:
+                 annotation = list[annotation]
+
+             # FieldInfo.from_annotation expects a type, `annotation` is Any
+             annotation = cast(type[Any], annotation)
+             field_info = FieldInfo.from_annotation(annotation)
+             if field_info.description is None:
+                 field_info.description = field_descriptions.get(field_name)
+
+             fields[field_name] = td_schema = gen_schema._generate_td_field_schema(  # pyright: ignore[reportPrivateUsage]
+                 field_name,
+                 field_info,
+                 decorators,
+             )
+             # noinspection PyTypeChecker
+             td_schema.setdefault('metadata', {})['is_model_like'] = is_model_like(annotation)
+
+             if p.kind == Parameter.POSITIONAL_ONLY:
+                 positional_fields.append(field_name)
+             elif p.kind == Parameter.VAR_POSITIONAL:
+                 var_positional_field = field_name
+
+     if errors:
+         from .exceptions import UserError
+
+         error_details = '\n '.join(errors)
+         raise UserError(f'Error generating schema for {function.__qualname__}:\n {error_details}')
+
+     core_config = config_wrapper.core_config(None)
+     # noinspection PyTypedDict
+     core_config['extra_fields_behavior'] = 'allow' if var_kwargs_schema else 'forbid'
+
+     schema, single_arg_name = _build_schema(fields, var_kwargs_schema, gen_schema, core_config)
+     schema = gen_schema.clean_schema(schema)
+     # noinspection PyUnresolvedReferences
+     schema_validator = create_schema_validator(
+         schema,
+         function,
+         function.__module__,
+         function.__qualname__,
+         'validate_call',
+         core_config,
+         config_wrapper.plugin_settings,
+     )
+     # PluggableSchemaValidator is api compatible with SchemaValidator
+     schema_validator = cast(SchemaValidator, schema_validator)
+     json_schema = GenerateJsonSchema().generate(schema)
+
+     # workaround for https://github.com/pydantic/pydantic/issues/10785
+     # if we build a custom TypedDict schema (matches when `single_arg_name` is None), we manually set
+     # `additionalProperties` in the JSON Schema
+     if single_arg_name is None:
+         json_schema['additionalProperties'] = bool(var_kwargs_schema)
+
+     # instead of passing `description` through in core_schema, we just add it here
+     if description:
+         json_schema = {'description': description} | json_schema
+
+     return FunctionSchema(
+         description=description,
+         validator=schema_validator,
+         json_schema=check_object_json_schema(json_schema),
+         single_arg_name=single_arg_name,
+         positional_fields=positional_fields,
+         var_positional_field=var_positional_field,
+     )
+
+
+ def _build_schema(
+     fields: dict[str, core_schema.TypedDictField],
+     var_kwargs_schema: core_schema.CoreSchema | None,
+     gen_schema: _generate_schema.GenerateSchema,
+     core_config: core_schema.CoreConfig,
+ ) -> tuple[core_schema.CoreSchema, str | None]:
+     """Generate a typed dict schema for function parameters.
+
+     Args:
+         fields: The fields to generate a typed dict schema for.
+         var_kwargs_schema: The variable keyword arguments schema.
+         gen_schema: The `GenerateSchema` instance.
+         core_config: The core configuration.
+
+     Returns:
+         tuple of (generated core schema, single arg name).
+     """
+     if len(fields) == 1 and var_kwargs_schema is None:
+         name = next(iter(fields))
+         td_field = fields[name]
+         if td_field['metadata']['is_model_like']:  # type: ignore
+             return td_field['schema'], name
+
+     td_schema = core_schema.typed_dict_schema(
+         fields,
+         config=core_config,
+         extras_schema=gen_schema.generate_schema(var_kwargs_schema) if var_kwargs_schema else None,
+     )
+     return td_schema, None
+
+
+ def _is_call_ctx(annotation: Any) -> bool:
+     from .dependencies import CallContext
+
+     return annotation is CallContext or (
+         _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is CallContext
+     )
+
+
+ if TYPE_CHECKING:
+     LazyTypeAdapter = TypeAdapter
+ else:
+
+     class LazyTypeAdapter:
+         __slots__ = '_args', '_kwargs', '_type_adapter'
+
+         def __init__(self, *args, **kwargs):
+             self._args = args
+             self._kwargs = kwargs
+             self._type_adapter = None
+
+         def __getattr__(self, item):
+             if self._type_adapter is None:
+                 self._type_adapter = TypeAdapter(*self._args, **self._kwargs)
+             return getattr(self._type_adapter, item)
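`LazyTypeAdapter` simply defers building the underlying `TypeAdapter` until the first attribute access, which keeps schema construction off the import path. A minimal sketch of that behaviour; the module path `pydantic_ai._pydantic` is an assumption, since file names are not shown in this diff:

```python
from pydantic_ai._pydantic import LazyTypeAdapter  # assumed module path

# No pydantic-core schema is built here; __init__ only stores the arguments.
adapter = LazyTypeAdapter(list[int])

# The first attribute access goes through __getattr__, which builds a real
# TypeAdapter(list[int]) once and then forwards every call to it.
print(adapter.validate_python(['1', 2]))  # [1, 2]
```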
@@ -0,0 +1,258 @@
+ from __future__ import annotations as _annotations
+
+ import inspect
+ import sys
+ import types
+ from collections.abc import Awaitable
+ from dataclasses import dataclass, field
+ from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
+
+ from pydantic import TypeAdapter, ValidationError
+ from typing_extensions import Self, TypeAliasType, TypedDict
+
+ from . import _utils, messages
+ from .dependencies import AgentDeps, CallContext, ResultValidatorFunc
+ from .exceptions import ModelRetry
+ from .messages import ModelStructuredResponse, ToolCall
+ from .result import ResultData
+
+
+ @dataclass
+ class ResultValidator(Generic[AgentDeps, ResultData]):
+     function: ResultValidatorFunc[AgentDeps, ResultData]
+     _takes_ctx: bool = field(init=False)
+     _is_async: bool = field(init=False)
+
+     def __post_init__(self):
+         self._takes_ctx = len(inspect.signature(self.function).parameters) > 1
+         self._is_async = inspect.iscoroutinefunction(self.function)
+
+     async def validate(
+         self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCall | None
+     ) -> ResultData:
+         """Validate a result by calling the function.
+
+         Args:
+             result: The result data after Pydantic validation of the message content.
+             deps: The agent dependencies.
+             retry: The current retry number.
+             tool_call: The original tool call message, `None` if there was no tool call.
+
+         Returns:
+             The validated result data; raises `ToolRetryError` wrapping a retry message if the function raises `ModelRetry`.
+         """
+         if self._takes_ctx:
+             args = CallContext(deps, retry, tool_call.tool_name if tool_call else None), result
+         else:
+             args = (result,)
+
+         try:
+             if self._is_async:
+                 function = cast(Callable[[Any], Awaitable[ResultData]], self.function)
+                 result_data = await function(*args)
+             else:
+                 function = cast(Callable[[Any], ResultData], self.function)
+                 result_data = await _utils.run_in_executor(function, *args)
+         except ModelRetry as r:
+             m = messages.RetryPrompt(content=r.message)
+             if tool_call is not None:
+                 m.tool_name = tool_call.tool_name
+                 m.tool_id = tool_call.tool_id
+             raise ToolRetryError(m) from r
+         else:
+             return result_data
+
+
+ class ToolRetryError(Exception):
+     """Internal exception used to signal a `ToolRetry` message should be returned to the LLM."""
+
+     def __init__(self, tool_retry: messages.RetryPrompt):
+         self.tool_retry = tool_retry
+         super().__init__()
+
+
+ @dataclass
+ class ResultSchema(Generic[ResultData]):
+     """Model the final response from an agent run.
+
+     Similar to `Retriever` but for the final result of running an agent.
+     """
+
+     tools: dict[str, ResultTool[ResultData]]
+     allow_text_result: bool
+
+     @classmethod
+     def build(cls, response_type: type[ResultData], name: str, description: str | None) -> Self | None:
+         """Build a ResultSchema dataclass from a response type."""
+         if response_type is str:
+             return None
+
+         if response_type_option := extract_str_from_union(response_type):
+             response_type = response_type_option.value
+             allow_text_result = True
+         else:
+             allow_text_result = False
+
+         def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData]:
+             return cast(
+                 ResultTool[ResultData],
+                 ResultTool.build(a, tool_name_, description, multiple),  # pyright: ignore[reportUnknownMemberType]
+             )
+
+         tools: dict[str, ResultTool[ResultData]] = {}
+         if args := get_union_args(response_type):
+             for i, arg in enumerate(args, start=1):
+                 tool_name = union_tool_name(name, arg)
+                 while tool_name in tools:
+                     tool_name = f'{tool_name}_{i}'
+                 tools[tool_name] = _build_tool(arg, tool_name, True)
+         else:
+             tools[name] = _build_tool(response_type, name, False)
+
+         return cls(tools=tools, allow_text_result=allow_text_result)
+
+     def find_tool(self, message: ModelStructuredResponse) -> tuple[ToolCall, ResultTool[ResultData]] | None:
+         """Find a tool that matches one of the calls."""
+         for call in message.calls:
+             if result := self.tools.get(call.tool_name):
+                 return call, result
+
+     def tool_names(self) -> list[str]:
+         """Return the names of the tools."""
+         return list(self.tools.keys())
+
+
+ DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
+
+
+ @dataclass
+ class ResultTool(Generic[ResultData]):
+     name: str
+     description: str
+     type_adapter: TypeAdapter[Any]
+     json_schema: _utils.ObjectJsonSchema
+     outer_typed_dict_key: str | None
+
+     @classmethod
+     def build(cls, response_type: type[ResultData], name: str, description: str | None, multiple: bool) -> Self | None:
+         """Build a ResultTool dataclass from a response type."""
+         assert response_type is not str, 'ResultTool does not support str as a response type'
+
+         if _utils.is_model_like(response_type):
+             type_adapter = TypeAdapter(response_type)
+             outer_typed_dict_key: str | None = None
+             # noinspection PyArgumentList
+             json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
+         else:
+             response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type})  # noqa
+             type_adapter = TypeAdapter(response_data_typed_dict)
+             outer_typed_dict_key = 'response'
+             # noinspection PyArgumentList
+             json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
+             # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
+             json_schema.pop('title')
+
+         if json_schema_description := json_schema.pop('description', None):
+             if description is None:
+                 tool_description = json_schema_description
+             else:
+                 tool_description = f'{description}. {json_schema_description}'
+         else:
+             tool_description = description or DEFAULT_DESCRIPTION
+             if multiple:
+                 tool_description = f'{union_arg_name(response_type)}: {tool_description}'
+
+         return cls(
+             name=name,
+             description=tool_description,
+             type_adapter=type_adapter,
+             json_schema=json_schema,
+             outer_typed_dict_key=outer_typed_dict_key,
+         )
+
+     def validate(
+         self, tool_call: messages.ToolCall, allow_partial: bool = False, wrap_validation_errors: bool = True
+     ) -> ResultData:
+         """Validate a result message.
+
+         Args:
+             tool_call: The tool call from the LLM to validate.
+             allow_partial: If true, allow partial validation.
+             wrap_validation_errors: If true, wrap the validation errors in a retry message.
+
+         Returns:
+             The validated result data; raises `ToolRetryError` if validation fails and `wrap_validation_errors` is true.
+         """
+         try:
+             pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
+             if isinstance(tool_call.args, messages.ArgsJson):
+                 result = self.type_adapter.validate_json(
+                     tool_call.args.args_json or '', experimental_allow_partial=pyd_allow_partial
+                 )
+             else:
+                 result = self.type_adapter.validate_python(
+                     tool_call.args.args_object, experimental_allow_partial=pyd_allow_partial
+                 )
+         except ValidationError as e:
+             if wrap_validation_errors:
+                 m = messages.RetryPrompt(
+                     tool_name=tool_call.tool_name,
+                     content=e.errors(include_url=False),
+                     tool_id=tool_call.tool_id,
+                 )
+                 raise ToolRetryError(m) from e
+             else:
+                 raise
+         else:
+             if k := self.outer_typed_dict_key:
+                 result = result[k]
+             return result
+
+
+ def union_tool_name(base_name: str, union_arg: Any) -> str:
+     return f'{base_name}_{union_arg_name(union_arg)}'
+
+
+ def union_arg_name(union_arg: Any) -> str:
+     return union_arg.__name__
+
+
+ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
+     """Extract the `str` type from a union, returning the remaining union or remaining single type."""
+     union_args = get_union_args(response_type)
+     if any(t is str for t in union_args):
+         remain_args: list[Any] = []
+         includes_str = False
+         for arg in union_args:
+             if arg is str:
+                 includes_str = True
+             else:
+                 remain_args.append(arg)
+         if includes_str:
+             if len(remain_args) == 1:
+                 return _utils.Some(remain_args[0])
+             else:
+                 return _utils.Some(Union[tuple(remain_args)])
+
+
+ def get_union_args(tp: Any) -> tuple[Any, ...]:
+     """Extract the arguments of a Union type if `tp` is a union, otherwise return an empty tuple."""
+     if isinstance(tp, TypeAliasType):
+         tp = tp.__value__
+
+     origin = get_origin(tp)
+     if origin_is_union(origin):
+         return get_args(tp)
+     else:
+         return ()
+
+
+ if sys.version_info < (3, 10):
+
+     def origin_is_union(tp: type[Any] | None) -> bool:
+         return tp is Union
+
+ else:
+
+     def origin_is_union(tp: type[Any] | None) -> bool:
+         return tp is Union or tp is types.UnionType
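The union helpers at the bottom of this file drive how `ResultSchema.build` decides between a single result tool and one tool per union member, and whether plain-text results are allowed. A standalone sketch mirroring that logic (it deliberately does not import the private module above, so the behaviour shown is an approximation of the code in this file):

```python
import sys
import types
from typing import Any, Union, get_args, get_origin


def origin_is_union(tp: Any) -> bool:
    # mirrors the version-dependent check above: `int | str` uses types.UnionType on 3.10+
    if sys.version_info >= (3, 10):
        return tp is Union or tp is types.UnionType
    return tp is Union


def union_args(tp: Any) -> tuple[Any, ...]:
    # mirrors get_union_args, minus the TypeAliasType unwrapping
    return get_args(tp) if origin_is_union(get_origin(tp)) else ()


print(union_args(Union[int, str]))  # (<class 'int'>, <class 'str'>)
print(union_args(int))              # ()
# With `str` present in the union, extract_str_from_union above would return Some(int),
# so ResultSchema.build would set allow_text_result=True and register one tool for int.
```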