pydantic-ai-slim 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic; see the package's registry page for more details.
- pydantic_ai/__init__.py +5 -2
- pydantic_ai/_agent_graph.py +33 -15
- pydantic_ai/_cli.py +7 -3
- pydantic_ai/_function_schema.py +1 -4
- pydantic_ai/_output.py +654 -159
- pydantic_ai/_run_context.py +56 -0
- pydantic_ai/_system_prompt.py +2 -1
- pydantic_ai/_utils.py +111 -1
- pydantic_ai/agent.py +56 -34
- pydantic_ai/models/__init__.py +21 -2
- pydantic_ai/models/function.py +21 -3
- pydantic_ai/models/gemini.py +27 -4
- pydantic_ai/models/google.py +29 -4
- pydantic_ai/models/mistral.py +5 -1
- pydantic_ai/models/openai.py +70 -9
- pydantic_ai/models/test.py +1 -1
- pydantic_ai/models/wrapper.py +6 -0
- pydantic_ai/output.py +288 -0
- pydantic_ai/profiles/__init__.py +21 -0
- pydantic_ai/profiles/_json_schema.py +1 -1
- pydantic_ai/profiles/google.py +6 -2
- pydantic_ai/profiles/openai.py +5 -0
- pydantic_ai/result.py +52 -26
- pydantic_ai/tools.py +2 -47
- {pydantic_ai_slim-0.3.2.dist-info → pydantic_ai_slim-0.3.3.dist-info}/METADATA +4 -4
- {pydantic_ai_slim-0.3.2.dist-info → pydantic_ai_slim-0.3.3.dist-info}/RECORD +29 -27
- {pydantic_ai_slim-0.3.2.dist-info → pydantic_ai_slim-0.3.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.3.2.dist-info → pydantic_ai_slim-0.3.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.3.2.dist-info → pydantic_ai_slim-0.3.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_output.py
CHANGED
|
@@ -1,19 +1,35 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
+
import json
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
4
6
|
from collections.abc import Awaitable, Iterable, Iterator, Sequence
|
|
5
7
|
from dataclasses import dataclass, field
|
|
6
|
-
from typing import Any, Callable, Generic, Literal, Union, cast
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload
|
|
7
9
|
|
|
8
10
|
from pydantic import TypeAdapter, ValidationError
|
|
9
11
|
from pydantic_core import SchemaValidator
|
|
10
|
-
from typing_extensions import
|
|
11
|
-
from typing_inspection import typing_objects
|
|
12
|
-
from typing_inspection.introspection import is_union_origin
|
|
12
|
+
from typing_extensions import TypedDict, TypeVar, assert_never
|
|
13
13
|
|
|
14
14
|
from . import _function_schema, _utils, messages as _messages
|
|
15
|
-
from .
|
|
16
|
-
from .
|
|
15
|
+
from ._run_context import AgentDepsT, RunContext
|
|
16
|
+
from .exceptions import ModelRetry, UserError
|
|
17
|
+
from .output import (
|
|
18
|
+
NativeOutput,
|
|
19
|
+
OutputDataT,
|
|
20
|
+
OutputMode,
|
|
21
|
+
OutputSpec,
|
|
22
|
+
OutputTypeOrFunction,
|
|
23
|
+
PromptedOutput,
|
|
24
|
+
StructuredOutputMode,
|
|
25
|
+
TextOutput,
|
|
26
|
+
TextOutputFunc,
|
|
27
|
+
ToolOutput,
|
|
28
|
+
)
|
|
29
|
+
from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from .profiles import ModelProfile
|
|
17
33
|
|
|
18
34
|
T = TypeVar('T')
|
|
19
35
|
"""An invariant TypeVar."""
|
|
@@ -29,8 +45,6 @@ changing it would have negative consequences for the ergonomics of the library.
|
|
|
29
45
|
At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would
|
|
30
46
|
resolve these potential variance issues.
|
|
31
47
|
"""
|
|
32
|
-
OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
|
|
33
|
-
"""Covariant type variable for the result data type of a run."""
|
|
34
48
|
|
|
35
49
|
OutputValidatorFunc = Union[
|
|
36
50
|
Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv],
|
|
@@ -52,6 +66,14 @@ DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
|
|
|
52
66
|
DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
|
|
53
67
|
|
|
54
68
|
|
|
69
|
+
class ToolRetryError(Exception):
|
|
70
|
+
"""Exception used to signal a `ToolRetry` message should be returned to the LLM."""
|
|
71
|
+
|
|
72
|
+
def __init__(self, tool_retry: _messages.RetryPromptPart):
|
|
73
|
+
self.tool_retry = tool_retry
|
|
74
|
+
super().__init__()
|
|
75
|
+
|
|
76
|
+
|
|
55
77
|
@dataclass
|
|
56
78
|
class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
|
|
57
79
|
function: OutputValidatorFunc[AgentDepsT, OutputDataT_inv]
|
|
@@ -101,140 +123,399 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
|
|
|
101
123
|
return result_data
|
|
102
124
|
|
|
103
125
|
|
|
104
|
-
class
|
|
105
|
-
|
|
126
|
+
class BaseOutputSchema(ABC, Generic[OutputDataT]):
|
|
127
|
+
@abstractmethod
|
|
128
|
+
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
129
|
+
raise NotImplementedError()
|
|
106
130
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
131
|
+
@property
|
|
132
|
+
def tools(self) -> dict[str, OutputTool[OutputDataT]]:
|
|
133
|
+
"""Get the tools for this output schema."""
|
|
134
|
+
return {}
|
|
110
135
|
|
|
111
136
|
|
|
112
137
|
@dataclass(init=False)
|
|
113
|
-
class
|
|
114
|
-
"""
|
|
115
|
-
|
|
116
|
-
output_type: SimpleOutputType[OutputDataT]
|
|
117
|
-
name: str | None
|
|
118
|
-
description: str | None
|
|
119
|
-
max_retries: int | None
|
|
120
|
-
strict: bool | None
|
|
138
|
+
class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
139
|
+
"""Model the final output from an agent run."""
|
|
121
140
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
141
|
+
@classmethod
|
|
142
|
+
@overload
|
|
143
|
+
def build(
|
|
144
|
+
cls,
|
|
145
|
+
output_spec: OutputSpec[OutputDataT],
|
|
125
146
|
*,
|
|
147
|
+
default_mode: StructuredOutputMode,
|
|
126
148
|
name: str | None = None,
|
|
127
149
|
description: str | None = None,
|
|
128
|
-
max_retries: int | None = None,
|
|
129
150
|
strict: bool | None = None,
|
|
130
|
-
):
|
|
131
|
-
self.output_type = type_
|
|
132
|
-
self.name = name
|
|
133
|
-
self.description = description
|
|
134
|
-
self.max_retries = max_retries
|
|
135
|
-
self.strict = strict
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
T_co = TypeVar('T_co', covariant=True)
|
|
139
|
-
# output_type=Type or output_type=function or output_type=object.method
|
|
140
|
-
SimpleOutputType = TypeAliasType(
|
|
141
|
-
'SimpleOutputType', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,)
|
|
142
|
-
)
|
|
143
|
-
# output_type=ToolOutput(<see above>) or <see above>
|
|
144
|
-
SimpleOutputTypeOrMarker = TypeAliasType(
|
|
145
|
-
'SimpleOutputTypeOrMarker', Union[SimpleOutputType[T_co], ToolOutput[T_co]], type_params=(T_co,)
|
|
146
|
-
)
|
|
147
|
-
# output_type=<see above> or [<see above>, ...]
|
|
148
|
-
OutputType = TypeAliasType(
|
|
149
|
-
'OutputType', Union[SimpleOutputTypeOrMarker[T_co], Sequence[SimpleOutputTypeOrMarker[T_co]]], type_params=(T_co,)
|
|
150
|
-
)
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
@dataclass
|
|
154
|
-
class OutputSchema(Generic[OutputDataT]):
|
|
155
|
-
"""Model the final output from an agent run.
|
|
151
|
+
) -> OutputSchema[OutputDataT]: ...
|
|
156
152
|
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
153
|
+
@classmethod
|
|
154
|
+
@overload
|
|
155
|
+
def build(
|
|
156
|
+
cls,
|
|
157
|
+
output_spec: OutputSpec[OutputDataT],
|
|
158
|
+
*,
|
|
159
|
+
default_mode: None = None,
|
|
160
|
+
name: str | None = None,
|
|
161
|
+
description: str | None = None,
|
|
162
|
+
strict: bool | None = None,
|
|
163
|
+
) -> BaseOutputSchema[OutputDataT]: ...
|
|
162
164
|
|
|
163
165
|
@classmethod
|
|
164
166
|
def build(
|
|
165
|
-
cls
|
|
166
|
-
|
|
167
|
+
cls,
|
|
168
|
+
output_spec: OutputSpec[OutputDataT],
|
|
169
|
+
*,
|
|
170
|
+
default_mode: StructuredOutputMode | None = None,
|
|
167
171
|
name: str | None = None,
|
|
168
172
|
description: str | None = None,
|
|
169
173
|
strict: bool | None = None,
|
|
170
|
-
) ->
|
|
174
|
+
) -> BaseOutputSchema[OutputDataT]:
|
|
171
175
|
"""Build an OutputSchema dataclass from an output type."""
|
|
172
|
-
if
|
|
173
|
-
return
|
|
176
|
+
if output_spec is str:
|
|
177
|
+
return PlainTextOutputSchema()
|
|
178
|
+
|
|
179
|
+
if isinstance(output_spec, NativeOutput):
|
|
180
|
+
return NativeOutputSchema(
|
|
181
|
+
cls._build_processor(
|
|
182
|
+
_flatten_output_spec(output_spec.outputs),
|
|
183
|
+
name=output_spec.name,
|
|
184
|
+
description=output_spec.description,
|
|
185
|
+
)
|
|
186
|
+
)
|
|
187
|
+
elif isinstance(output_spec, PromptedOutput):
|
|
188
|
+
return PromptedOutputSchema(
|
|
189
|
+
cls._build_processor(
|
|
190
|
+
_flatten_output_spec(output_spec.outputs),
|
|
191
|
+
name=output_spec.name,
|
|
192
|
+
description=output_spec.description,
|
|
193
|
+
),
|
|
194
|
+
template=output_spec.template,
|
|
195
|
+
)
|
|
174
196
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
197
|
+
text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
|
|
198
|
+
tool_outputs: Sequence[ToolOutput[OutputDataT]] = []
|
|
199
|
+
other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = []
|
|
200
|
+
for output in _flatten_output_spec(output_spec):
|
|
201
|
+
if output is str:
|
|
202
|
+
text_outputs.append(cast(type[str], output))
|
|
203
|
+
elif isinstance(output, TextOutput):
|
|
204
|
+
text_outputs.append(output)
|
|
205
|
+
elif isinstance(output, ToolOutput):
|
|
206
|
+
tool_outputs.append(output)
|
|
207
|
+
else:
|
|
208
|
+
other_outputs.append(output)
|
|
180
209
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
210
|
+
tools = cls._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict)
|
|
211
|
+
|
|
212
|
+
if len(text_outputs) > 0:
|
|
213
|
+
if len(text_outputs) > 1:
|
|
214
|
+
raise UserError('Only one text output is allowed.')
|
|
215
|
+
text_output = text_outputs[0]
|
|
216
|
+
|
|
217
|
+
text_output_schema = None
|
|
218
|
+
if isinstance(text_output, TextOutput):
|
|
219
|
+
text_output_schema = PlainTextOutputProcessor(text_output.output_function)
|
|
220
|
+
|
|
221
|
+
if len(tools) == 0:
|
|
222
|
+
return PlainTextOutputSchema(text_output_schema)
|
|
185
223
|
else:
|
|
186
|
-
|
|
224
|
+
return ToolOrTextOutputSchema(processor=text_output_schema, tools=tools)
|
|
187
225
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
allow_text_output = True
|
|
191
|
-
output_types_flat = [t for t in output_types_flat if t is not str]
|
|
226
|
+
if len(tool_outputs) > 0:
|
|
227
|
+
return ToolOutputSchema(tools)
|
|
192
228
|
|
|
193
|
-
|
|
229
|
+
if len(other_outputs) > 0:
|
|
230
|
+
schema = OutputSchemaWithoutMode(
|
|
231
|
+
processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
|
|
232
|
+
tools=tools,
|
|
233
|
+
)
|
|
234
|
+
if default_mode:
|
|
235
|
+
schema = schema.with_default_mode(default_mode)
|
|
236
|
+
return schema
|
|
194
237
|
|
|
195
|
-
|
|
196
|
-
default_tool_description = description
|
|
197
|
-
default_tool_strict = strict
|
|
238
|
+
raise UserError('No output type provided.') # pragma: no cover
|
|
198
239
|
|
|
240
|
+
@staticmethod
|
|
241
|
+
def _build_tools(
|
|
242
|
+
outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
|
|
243
|
+
name: str | None = None,
|
|
244
|
+
description: str | None = None,
|
|
245
|
+
strict: bool | None = None,
|
|
246
|
+
) -> dict[str, OutputTool[OutputDataT]]:
|
|
199
247
|
tools: dict[str, OutputTool[OutputDataT]] = {}
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
248
|
+
|
|
249
|
+
default_name = name or DEFAULT_OUTPUT_TOOL_NAME
|
|
250
|
+
default_description = description
|
|
251
|
+
default_strict = strict
|
|
252
|
+
|
|
253
|
+
multiple = len(outputs) > 1
|
|
254
|
+
for output in outputs:
|
|
255
|
+
name = None
|
|
256
|
+
description = None
|
|
257
|
+
strict = None
|
|
258
|
+
if isinstance(output, ToolOutput):
|
|
206
259
|
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
else:
|
|
211
|
-
tool_output_type = output_type
|
|
260
|
+
name = output.name
|
|
261
|
+
description = output.description
|
|
262
|
+
strict = output.strict
|
|
212
263
|
|
|
213
|
-
|
|
214
|
-
|
|
264
|
+
output = output.output
|
|
265
|
+
|
|
266
|
+
if name is None:
|
|
267
|
+
name = default_name
|
|
215
268
|
if multiple:
|
|
216
|
-
|
|
269
|
+
name += f'_{output.__name__}'
|
|
217
270
|
|
|
218
271
|
i = 1
|
|
219
|
-
|
|
220
|
-
while
|
|
272
|
+
original_name = name
|
|
273
|
+
while name in tools:
|
|
221
274
|
i += 1
|
|
222
|
-
|
|
275
|
+
name = f'{original_name}_{i}'
|
|
223
276
|
|
|
224
|
-
|
|
225
|
-
if
|
|
226
|
-
|
|
277
|
+
description = description or default_description
|
|
278
|
+
if strict is None:
|
|
279
|
+
strict = default_strict
|
|
227
280
|
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
281
|
+
processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
|
|
282
|
+
tools[name] = OutputTool(name=name, processor=processor, multiple=multiple)
|
|
283
|
+
|
|
284
|
+
return tools
|
|
285
|
+
|
|
286
|
+
@staticmethod
|
|
287
|
+
def _build_processor(
|
|
288
|
+
outputs: Sequence[OutputTypeOrFunction[OutputDataT]],
|
|
289
|
+
name: str | None = None,
|
|
290
|
+
description: str | None = None,
|
|
291
|
+
strict: bool | None = None,
|
|
292
|
+
) -> ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]:
|
|
293
|
+
outputs = _flatten_output_spec(outputs)
|
|
294
|
+
if len(outputs) == 1:
|
|
295
|
+
return ObjectOutputProcessor(output=outputs[0], name=name, description=description, strict=strict)
|
|
296
|
+
|
|
297
|
+
return UnionOutputProcessor(outputs=outputs, strict=strict, name=name, description=description)
|
|
298
|
+
|
|
299
|
+
@property
|
|
300
|
+
@abstractmethod
|
|
301
|
+
def mode(self) -> OutputMode:
|
|
302
|
+
raise NotImplementedError()
|
|
303
|
+
|
|
304
|
+
@abstractmethod
|
|
305
|
+
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
306
|
+
"""Raise an error if the mode is not supported by the model."""
|
|
307
|
+
raise NotImplementedError()
|
|
308
|
+
|
|
309
|
+
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
310
|
+
return self
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
@dataclass(init=False)
|
|
314
|
+
class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
|
|
315
|
+
processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]
|
|
316
|
+
_tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict)
|
|
317
|
+
|
|
318
|
+
def __init__(
|
|
319
|
+
self,
|
|
320
|
+
processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT],
|
|
321
|
+
tools: dict[str, OutputTool[OutputDataT]],
|
|
322
|
+
):
|
|
323
|
+
self.processor = processor
|
|
324
|
+
self._tools = tools
|
|
325
|
+
|
|
326
|
+
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
327
|
+
if mode == 'native':
|
|
328
|
+
return NativeOutputSchema(self.processor)
|
|
329
|
+
elif mode == 'prompted':
|
|
330
|
+
return PromptedOutputSchema(self.processor)
|
|
331
|
+
elif mode == 'tool':
|
|
332
|
+
return ToolOutputSchema(self.tools)
|
|
333
|
+
else:
|
|
334
|
+
assert_never(mode)
|
|
232
335
|
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
336
|
+
@property
|
|
337
|
+
def tools(self) -> dict[str, OutputTool[OutputDataT]]:
|
|
338
|
+
"""Get the tools for this output schema."""
|
|
339
|
+
# We return tools here as they're checked in Agent._register_tool.
|
|
340
|
+
# At that point we may don't know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time.
|
|
341
|
+
return self._tools
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
class TextOutputSchema(OutputSchema[OutputDataT], ABC):
|
|
345
|
+
@abstractmethod
|
|
346
|
+
async def process(
|
|
347
|
+
self,
|
|
348
|
+
text: str,
|
|
349
|
+
run_context: RunContext[AgentDepsT],
|
|
350
|
+
allow_partial: bool = False,
|
|
351
|
+
wrap_validation_errors: bool = True,
|
|
352
|
+
) -> OutputDataT:
|
|
353
|
+
raise NotImplementedError()
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
@dataclass
|
|
357
|
+
class PlainTextOutputSchema(TextOutputSchema[OutputDataT]):
|
|
358
|
+
processor: PlainTextOutputProcessor[OutputDataT] | None = None
|
|
359
|
+
|
|
360
|
+
@property
|
|
361
|
+
def mode(self) -> OutputMode:
|
|
362
|
+
return 'text'
|
|
363
|
+
|
|
364
|
+
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
365
|
+
"""Raise an error if the mode is not supported by the model."""
|
|
366
|
+
pass
|
|
367
|
+
|
|
368
|
+
async def process(
|
|
369
|
+
self,
|
|
370
|
+
text: str,
|
|
371
|
+
run_context: RunContext[AgentDepsT],
|
|
372
|
+
allow_partial: bool = False,
|
|
373
|
+
wrap_validation_errors: bool = True,
|
|
374
|
+
) -> OutputDataT:
|
|
375
|
+
"""Validate an output message.
|
|
376
|
+
|
|
377
|
+
Args:
|
|
378
|
+
text: The output text to validate.
|
|
379
|
+
run_context: The current run context.
|
|
380
|
+
allow_partial: If true, allow partial validation.
|
|
381
|
+
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
382
|
+
|
|
383
|
+
Returns:
|
|
384
|
+
Either the validated output data (left) or a retry message (right).
|
|
385
|
+
"""
|
|
386
|
+
if self.processor is None:
|
|
387
|
+
return cast(OutputDataT, text)
|
|
388
|
+
|
|
389
|
+
return await self.processor.process(
|
|
390
|
+
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
|
|
236
391
|
)
|
|
237
392
|
|
|
393
|
+
|
|
394
|
+
@dataclass
|
|
395
|
+
class StructuredTextOutputSchema(TextOutputSchema[OutputDataT], ABC):
|
|
396
|
+
processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]
|
|
397
|
+
|
|
398
|
+
@property
|
|
399
|
+
def object_def(self) -> OutputObjectDefinition:
|
|
400
|
+
return self.processor.object_def
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
@dataclass
|
|
404
|
+
class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
405
|
+
@property
|
|
406
|
+
def mode(self) -> OutputMode:
|
|
407
|
+
return 'native'
|
|
408
|
+
|
|
409
|
+
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
410
|
+
"""Raise an error if the mode is not supported by the model."""
|
|
411
|
+
if not profile.supports_json_schema_output:
|
|
412
|
+
raise UserError('Structured output is not supported by the model.')
|
|
413
|
+
|
|
414
|
+
async def process(
|
|
415
|
+
self,
|
|
416
|
+
text: str,
|
|
417
|
+
run_context: RunContext[AgentDepsT],
|
|
418
|
+
allow_partial: bool = False,
|
|
419
|
+
wrap_validation_errors: bool = True,
|
|
420
|
+
) -> OutputDataT:
|
|
421
|
+
"""Validate an output message.
|
|
422
|
+
|
|
423
|
+
Args:
|
|
424
|
+
text: The output text to validate.
|
|
425
|
+
run_context: The current run context.
|
|
426
|
+
allow_partial: If true, allow partial validation.
|
|
427
|
+
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
428
|
+
|
|
429
|
+
Returns:
|
|
430
|
+
Either the validated output data (left) or a retry message (right).
|
|
431
|
+
"""
|
|
432
|
+
return await self.processor.process(
|
|
433
|
+
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
|
|
434
|
+
)
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
@dataclass
|
|
438
|
+
class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
439
|
+
template: str | None = None
|
|
440
|
+
|
|
441
|
+
@property
|
|
442
|
+
def mode(self) -> OutputMode:
|
|
443
|
+
return 'prompted'
|
|
444
|
+
|
|
445
|
+
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
446
|
+
"""Raise an error if the mode is not supported by the model."""
|
|
447
|
+
pass
|
|
448
|
+
|
|
449
|
+
def instructions(self, default_template: str) -> str:
|
|
450
|
+
"""Get instructions to tell model to output JSON matching the schema."""
|
|
451
|
+
template = self.template or default_template
|
|
452
|
+
|
|
453
|
+
if '{schema}' not in template:
|
|
454
|
+
template = '\n\n'.join([template, '{schema}'])
|
|
455
|
+
|
|
456
|
+
object_def = self.object_def
|
|
457
|
+
schema = object_def.json_schema.copy()
|
|
458
|
+
if object_def.name:
|
|
459
|
+
schema['title'] = object_def.name
|
|
460
|
+
if object_def.description:
|
|
461
|
+
schema['description'] = object_def.description
|
|
462
|
+
|
|
463
|
+
return template.format(schema=json.dumps(schema))
|
|
464
|
+
|
|
465
|
+
async def process(
|
|
466
|
+
self,
|
|
467
|
+
text: str,
|
|
468
|
+
run_context: RunContext[AgentDepsT],
|
|
469
|
+
allow_partial: bool = False,
|
|
470
|
+
wrap_validation_errors: bool = True,
|
|
471
|
+
) -> OutputDataT:
|
|
472
|
+
"""Validate an output message.
|
|
473
|
+
|
|
474
|
+
Args:
|
|
475
|
+
text: The output text to validate.
|
|
476
|
+
run_context: The current run context.
|
|
477
|
+
allow_partial: If true, allow partial validation.
|
|
478
|
+
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
479
|
+
|
|
480
|
+
Returns:
|
|
481
|
+
Either the validated output data (left) or a retry message (right).
|
|
482
|
+
"""
|
|
483
|
+
text = _utils.strip_markdown_fences(text)
|
|
484
|
+
|
|
485
|
+
return await self.processor.process(
|
|
486
|
+
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
|
|
487
|
+
)
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
@dataclass(init=False)
|
|
491
|
+
class ToolOutputSchema(OutputSchema[OutputDataT]):
|
|
492
|
+
_tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict)
|
|
493
|
+
|
|
494
|
+
def __init__(self, tools: dict[str, OutputTool[OutputDataT]]):
|
|
495
|
+
self._tools = tools
|
|
496
|
+
|
|
497
|
+
@property
|
|
498
|
+
def mode(self) -> OutputMode:
|
|
499
|
+
return 'tool'
|
|
500
|
+
|
|
501
|
+
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
502
|
+
"""Raise an error if the mode is not supported by the model."""
|
|
503
|
+
if not profile.supports_tools:
|
|
504
|
+
raise UserError('Output tools are not supported by the model.')
|
|
505
|
+
|
|
506
|
+
@property
|
|
507
|
+
def tools(self) -> dict[str, OutputTool[OutputDataT]]:
|
|
508
|
+
"""Get the tools for this output schema."""
|
|
509
|
+
return self._tools
|
|
510
|
+
|
|
511
|
+
def tool_names(self) -> list[str]:
|
|
512
|
+
"""Return the names of the tools."""
|
|
513
|
+
return list(self.tools.keys())
|
|
514
|
+
|
|
515
|
+
def tool_defs(self) -> list[ToolDefinition]:
|
|
516
|
+
"""Get tool definitions to register with the model."""
|
|
517
|
+
return [t.tool_def for t in self.tools.values()]
|
|
518
|
+
|
|
238
519
|
def find_named_tool(
|
|
239
520
|
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
|
|
240
521
|
) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None:
|
|
@@ -254,61 +535,78 @@ class OutputSchema(Generic[OutputDataT]):
|
|
|
254
535
|
if result := self.tools.get(part.tool_name):
|
|
255
536
|
yield part, result
|
|
256
537
|
|
|
257
|
-
def tool_names(self) -> list[str]:
|
|
258
|
-
"""Return the names of the tools."""
|
|
259
|
-
return list(self.tools.keys())
|
|
260
|
-
|
|
261
|
-
def tool_defs(self) -> list[ToolDefinition]:
|
|
262
|
-
"""Get tool definitions to register with the model."""
|
|
263
|
-
return [t.tool_def for t in self.tools.values()]
|
|
264
538
|
|
|
539
|
+
@dataclass(init=False)
|
|
540
|
+
class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchema[OutputDataT]):
|
|
541
|
+
def __init__(
|
|
542
|
+
self,
|
|
543
|
+
processor: PlainTextOutputProcessor[OutputDataT] | None,
|
|
544
|
+
tools: dict[str, OutputTool[OutputDataT]],
|
|
545
|
+
):
|
|
546
|
+
self.processor = processor
|
|
547
|
+
self._tools = tools
|
|
265
548
|
|
|
266
|
-
|
|
267
|
-
|
|
549
|
+
@property
|
|
550
|
+
def mode(self) -> OutputMode:
|
|
551
|
+
return 'tool_or_text'
|
|
268
552
|
|
|
269
553
|
|
|
270
554
|
@dataclass
|
|
271
555
|
class OutputObjectDefinition:
|
|
272
|
-
name: str
|
|
273
556
|
json_schema: ObjectJsonSchema
|
|
557
|
+
name: str | None = None
|
|
274
558
|
description: str | None = None
|
|
275
559
|
strict: bool | None = None
|
|
276
560
|
|
|
277
561
|
|
|
278
562
|
@dataclass(init=False)
|
|
279
|
-
class
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
563
|
+
class BaseOutputProcessor(ABC, Generic[OutputDataT]):
|
|
564
|
+
@abstractmethod
|
|
565
|
+
async def process(
|
|
566
|
+
self,
|
|
567
|
+
data: str,
|
|
568
|
+
run_context: RunContext[AgentDepsT],
|
|
569
|
+
allow_partial: bool = False,
|
|
570
|
+
wrap_validation_errors: bool = True,
|
|
571
|
+
) -> OutputDataT:
|
|
572
|
+
"""Process an output message, performing validation and (if necessary) calling the output function."""
|
|
573
|
+
raise NotImplementedError()
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
@dataclass(init=False)
|
|
577
|
+
class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
578
|
+
object_def: OutputObjectDefinition
|
|
283
579
|
outer_typed_dict_key: str | None = None
|
|
580
|
+
_validator: SchemaValidator
|
|
581
|
+
_function_schema: _function_schema.FunctionSchema | None = None
|
|
284
582
|
|
|
285
583
|
def __init__(
|
|
286
584
|
self,
|
|
585
|
+
output: OutputTypeOrFunction[OutputDataT],
|
|
287
586
|
*,
|
|
288
|
-
output_type: SimpleOutputType[OutputDataT],
|
|
289
587
|
name: str | None = None,
|
|
290
588
|
description: str | None = None,
|
|
291
589
|
strict: bool | None = None,
|
|
292
590
|
):
|
|
293
|
-
if inspect.isfunction(
|
|
294
|
-
self.
|
|
295
|
-
self.
|
|
296
|
-
json_schema = self.
|
|
297
|
-
json_schema['description'] = self.
|
|
591
|
+
if inspect.isfunction(output) or inspect.ismethod(output):
|
|
592
|
+
self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema)
|
|
593
|
+
self._validator = self._function_schema.validator
|
|
594
|
+
json_schema = self._function_schema.json_schema
|
|
595
|
+
json_schema['description'] = self._function_schema.description
|
|
298
596
|
else:
|
|
299
597
|
type_adapter: TypeAdapter[Any]
|
|
300
|
-
if _utils.is_model_like(
|
|
301
|
-
type_adapter = TypeAdapter(
|
|
598
|
+
if _utils.is_model_like(output):
|
|
599
|
+
type_adapter = TypeAdapter(output)
|
|
302
600
|
else:
|
|
303
601
|
self.outer_typed_dict_key = 'response'
|
|
304
602
|
response_data_typed_dict = TypedDict( # noqa: UP013
|
|
305
603
|
'response_data_typed_dict',
|
|
306
|
-
{'response': cast(type[OutputDataT],
|
|
604
|
+
{'response': cast(type[OutputDataT], output)}, # pyright: ignore[reportInvalidTypeForm]
|
|
307
605
|
)
|
|
308
606
|
type_adapter = TypeAdapter(response_data_typed_dict)
|
|
309
607
|
|
|
310
608
|
# Really a PluggableSchemaValidator, but it's API-compatible
|
|
311
|
-
self.
|
|
609
|
+
self._validator = cast(SchemaValidator, type_adapter.validator)
|
|
312
610
|
json_schema = _utils.check_object_json_schema(
|
|
313
611
|
type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
|
|
314
612
|
)
|
|
@@ -323,8 +621,8 @@ class OutputObjectSchema(Generic[OutputDataT]):
|
|
|
323
621
|
else:
|
|
324
622
|
description = f'{description}. {json_schema_description}'
|
|
325
623
|
|
|
326
|
-
self.
|
|
327
|
-
name=name or getattr(
|
|
624
|
+
self.object_def = OutputObjectDefinition(
|
|
625
|
+
name=name or getattr(output, '__name__', None),
|
|
328
626
|
description=description,
|
|
329
627
|
json_schema=json_schema,
|
|
330
628
|
strict=strict,
|
|
@@ -335,6 +633,7 @@ class OutputObjectSchema(Generic[OutputDataT]):
|
|
|
335
633
|
data: str | dict[str, Any] | None,
|
|
336
634
|
run_context: RunContext[AgentDepsT],
|
|
337
635
|
allow_partial: bool = False,
|
|
636
|
+
wrap_validation_errors: bool = True,
|
|
338
637
|
) -> OutputDataT:
|
|
339
638
|
"""Process an output message, performing validation and (if necessary) calling the output function.
|
|
340
639
|
|
|
@@ -342,45 +641,235 @@ class OutputObjectSchema(Generic[OutputDataT]):
|
|
|
342
641
|
data: The output data to validate.
|
|
343
642
|
run_context: The current run context.
|
|
344
643
|
allow_partial: If true, allow partial validation.
|
|
644
|
+
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
345
645
|
|
|
346
646
|
Returns:
|
|
347
647
|
Either the validated output data (left) or a retry message (right).
|
|
348
648
|
"""
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
649
|
+
try:
|
|
650
|
+
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
|
|
651
|
+
if isinstance(data, str):
|
|
652
|
+
output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
|
|
653
|
+
else:
|
|
654
|
+
output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
|
|
655
|
+
except ValidationError as e:
|
|
656
|
+
if wrap_validation_errors:
|
|
657
|
+
m = _messages.RetryPromptPart(
|
|
658
|
+
content=e.errors(include_url=False),
|
|
659
|
+
)
|
|
660
|
+
raise ToolRetryError(m) from e
|
|
661
|
+
else:
|
|
662
|
+
raise # pragma: lax no cover
|
|
357
663
|
|
|
358
664
|
if k := self.outer_typed_dict_key:
|
|
359
665
|
output = output[k]
|
|
666
|
+
|
|
667
|
+
if self._function_schema:
|
|
668
|
+
try:
|
|
669
|
+
output = await self._function_schema.call(output, run_context)
|
|
670
|
+
except ModelRetry as r:
|
|
671
|
+
if wrap_validation_errors:
|
|
672
|
+
m = _messages.RetryPromptPart(
|
|
673
|
+
content=r.message,
|
|
674
|
+
)
|
|
675
|
+
raise ToolRetryError(m) from r
|
|
676
|
+
else:
|
|
677
|
+
raise # pragma: lax no cover
|
|
678
|
+
|
|
360
679
|
return output
|
|
361
680
|
|
|
362
681
|
|
|
682
|
+
@dataclass
|
|
683
|
+
class UnionOutputResult:
|
|
684
|
+
kind: str
|
|
685
|
+
data: ObjectJsonSchema
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
@dataclass
|
|
689
|
+
class UnionOutputModel:
|
|
690
|
+
result: UnionOutputResult
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
@dataclass(init=False)
class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
    """Output processor for a union of several object outputs.

    The model is asked for a single object shaped like
    `{"result": {"kind": <key>, "data": <member object>}}`, where `kind` selects which
    member output `data` is validated against. Each member output gets its own
    `ObjectOutputProcessor`, keyed by its object name (falling back to `__name__`) and
    de-duplicated with a numeric suffix (`Foo`, `Foo_2`, `Foo_3`, ...).
    """

    object_def: OutputObjectDefinition
    # Validates the outer `{"result": {"kind": ..., "data": ...}}` envelope.
    _union_processor: ObjectOutputProcessor[UnionOutputModel]
    # One processor per member output, keyed by the `kind` discriminator value.
    _processors: dict[str, ObjectOutputProcessor[OutputDataT]]

    def __init__(
        self,
        outputs: Sequence[OutputTypeOrFunction[OutputDataT]],
        *,
        name: str | None = None,
        description: str | None = None,
        strict: bool | None = None,
    ):
        """Build one processor per member output and the combined discriminated JSON schema.

        Args:
            outputs: The member output types or functions making up the union.
            name: Name for the combined output object definition.
            description: Description for the combined output object definition.
            strict: Passed to every member processor and to the combined object definition.
        """
        self._union_processor = ObjectOutputProcessor(output=UnionOutputModel)

        json_schemas: list[ObjectJsonSchema] = []
        self._processors = {}
        for output in outputs:
            processor = ObjectOutputProcessor(output=output, strict=strict)
            object_def = processor.object_def

            # Pick a unique discriminator key: `Name`, then `Name_2`, `Name_3`, ...
            object_key = object_def.name or output.__name__
            i = 1
            original_key = object_key
            while object_key in self._processors:
                i += 1
                object_key = f'{original_key}_{i}'

            self._processors[object_key] = processor

            # Shallow copy: the title/description set here (and popped again below) must not
            # leak into the member processor's own schema.
            json_schema = dict(object_def.json_schema)
            if object_def.name:  # pragma: no branch
                json_schema['title'] = object_def.name
            if object_def.description:
                json_schema['description'] = object_def.description

            json_schemas.append(json_schema)

        # Hoist every member's `$defs` into one shared top-level `$defs`.
        json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas)

        discriminated_json_schemas: list[ObjectJsonSchema] = []
        for object_key, json_schema in zip(self._processors.keys(), json_schemas):
            # Move the member's title/description from `data` up to its discriminated wrapper.
            # These locals are deliberately NOT named `description`, so they don't clobber the
            # `description` argument used for the combined object definition below.
            member_title = json_schema.pop('title', None)
            member_description = json_schema.pop('description', None)

            discriminated_json_schema = {
                'type': 'object',
                'properties': {
                    'kind': {
                        'type': 'string',
                        'const': object_key,
                    },
                    'data': json_schema,
                },
                'required': ['kind', 'data'],
                'additionalProperties': False,
            }
            if member_title:  # pragma: no branch
                discriminated_json_schema['title'] = member_title
            if member_description:
                discriminated_json_schema['description'] = member_description

            discriminated_json_schemas.append(discriminated_json_schema)

        union_json_schema: ObjectJsonSchema = {
            'type': 'object',
            'properties': {
                'result': {
                    'anyOf': discriminated_json_schemas,
                }
            },
            'required': ['result'],
            'additionalProperties': False,
        }
        if all_defs:
            union_json_schema['$defs'] = all_defs

        self.object_def = OutputObjectDefinition(
            json_schema=union_json_schema,
            strict=strict,
            name=name,
            description=description,
        )

    async def process(
        self,
        data: str | dict[str, Any] | None,
        run_context: RunContext[AgentDepsT],
        allow_partial: bool = False,
        wrap_validation_errors: bool = True,
    ) -> OutputDataT:
        """Validate the union envelope, then delegate its `data` to the processor chosen by `kind`.

        Raises:
            ToolRetryError: If `kind` matches no member and `wrap_validation_errors` is true.
            KeyError: If `kind` matches no member and `wrap_validation_errors` is false.
        """
        union_object = await self._union_processor.process(
            data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
        )

        result = union_object.result
        kind = result.kind
        data = result.data
        try:
            processor = self._processors[kind]
        except KeyError as e:  # pragma: no cover
            if wrap_validation_errors:
                m = _messages.RetryPromptPart(content=f'Invalid kind: {kind}')
                raise ToolRetryError(m) from e
            else:
                raise

        return await processor.process(
            data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
        )
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
@dataclass(init=False)
class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
    """Output processor that passes the model's plain-text response to a user function.

    The first property of the function's generated JSON schema must be a string. The text
    is passed under that property's name, and the function's return value becomes the output.
    """

    _function_schema: _function_schema.FunctionSchema
    # Name of the function parameter that receives the text.
    _str_argument_name: str

    def __init__(
        self,
        output_function: TextOutputFunc[OutputDataT],
    ):
        """Build the function schema and locate the string parameter.

        Raises:
            UserError: If the schema's first property is missing or is not typed as `string`.
        """
        self._function_schema = _function_schema.function_schema(output_function, GenerateToolJsonSchema)

        parameters = self._function_schema.json_schema.get('properties', {})
        first_parameter = next(iter(parameters), None)
        if not first_parameter or parameters.get(first_parameter, {}).get('type') != 'string':
            raise UserError('TextOutput must take a function taking a `str`')

        self._str_argument_name = first_parameter

    @property
    def object_def(self) -> None:
        return None  # pragma: no cover

    async def process(
        self,
        data: str,
        run_context: RunContext[AgentDepsT],
        allow_partial: bool = False,
        wrap_validation_errors: bool = True,
    ) -> OutputDataT:
        """Call the output function with the text and return its result.

        `allow_partial` is accepted to match the processor interface but is not used here.

        Raises:
            ToolRetryError: If the function raises `ModelRetry` and `wrap_validation_errors` is true.
            ModelRetry: Re-raised unchanged when `wrap_validation_errors` is false.
        """
        call_args = {self._str_argument_name: data}
        try:
            result = await self._function_schema.call(call_args, run_context)
        except ModelRetry as retry:
            if not wrap_validation_errors:
                raise  # pragma: lax no cover
            # Turn the function's retry request into a retry prompt for the model.
            raise ToolRetryError(_messages.RetryPromptPart(content=retry.message)) from retry

        return cast(OutputDataT, result)
|
|
850
|
+
|
|
851
|
+
|
|
363
852
|
@dataclass(init=False)
|
|
364
853
|
class OutputTool(Generic[OutputDataT]):
|
|
365
|
-
|
|
854
|
+
processor: ObjectOutputProcessor[OutputDataT]
|
|
366
855
|
tool_def: ToolDefinition
|
|
367
856
|
|
|
368
|
-
def __init__(self, *, name: str,
|
|
369
|
-
self.
|
|
370
|
-
|
|
857
|
+
def __init__(self, *, name: str, processor: ObjectOutputProcessor[OutputDataT], multiple: bool):
|
|
858
|
+
self.processor = processor
|
|
859
|
+
object_def = processor.object_def
|
|
371
860
|
|
|
372
|
-
description =
|
|
861
|
+
description = object_def.description
|
|
373
862
|
if not description:
|
|
374
863
|
description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
|
|
375
864
|
if multiple:
|
|
376
|
-
description = f'{
|
|
865
|
+
description = f'{object_def.name}: {description}'
|
|
377
866
|
|
|
378
867
|
self.tool_def = ToolDefinition(
|
|
379
868
|
name=name,
|
|
380
869
|
description=description,
|
|
381
|
-
parameters_json_schema=
|
|
382
|
-
strict=
|
|
383
|
-
outer_typed_dict_key=
|
|
870
|
+
parameters_json_schema=object_def.json_schema,
|
|
871
|
+
strict=object_def.strict,
|
|
872
|
+
outer_typed_dict_key=processor.outer_typed_dict_key,
|
|
384
873
|
)
|
|
385
874
|
|
|
386
875
|
async def process(
|
|
@@ -402,7 +891,9 @@ class OutputTool(Generic[OutputDataT]):
|
|
|
402
891
|
Either the validated output data (left) or a retry message (right).
|
|
403
892
|
"""
|
|
404
893
|
try:
|
|
405
|
-
output = await self.
|
|
894
|
+
output = await self.processor.process(
|
|
895
|
+
tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False
|
|
896
|
+
)
|
|
406
897
|
except ValidationError as e:
|
|
407
898
|
if wrap_validation_errors:
|
|
408
899
|
m = _messages.RetryPromptPart(
|
|
@@ -427,13 +918,17 @@ class OutputTool(Generic[OutputDataT]):
|
|
|
427
918
|
return output
|
|
428
919
|
|
|
429
920
|
|
|
430
|
-
def
|
|
431
|
-
|
|
432
|
-
if
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
origin = get_origin(tp)
|
|
436
|
-
if is_union_origin(origin):
|
|
437
|
-
return get_args(tp)
|
|
921
|
+
def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]:
    """Normalize an output spec into a flat list of individual outputs.

    A non-sequence spec is treated as a one-element sequence. Every entry for which
    `_utils.get_union_args` returns members is replaced by those members; all other
    entries are kept as-is, in order.
    """
    # NOTE(review): a `str` instance would also count as a Sequence here; specs are
    # presumably types/functions, so this is assumed not to occur.
    candidates: Sequence[T] = output_spec if isinstance(output_spec, Sequence) else (output_spec,)

    flattened: list[T] = []
    for candidate in candidates:
        union_members = _utils.get_union_args(candidate)
        if union_members:
            flattened.extend(union_members)
        else:
            flattened.append(candidate)
    return flattened
|