pydantic-ai-slim: pydantic_ai_slim-0.3.1-py3-none-any.whl → pydantic_ai_slim-0.3.3-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the registry's advisory page for this release for more details.

pydantic_ai/_output.py CHANGED
@@ -1,19 +1,35 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  import inspect
4
+ import json
5
+ from abc import ABC, abstractmethod
4
6
  from collections.abc import Awaitable, Iterable, Iterator, Sequence
5
7
  from dataclasses import dataclass, field
6
- from typing import Any, Callable, Generic, Literal, Union, cast
8
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload
7
9
 
8
10
  from pydantic import TypeAdapter, ValidationError
9
11
  from pydantic_core import SchemaValidator
10
- from typing_extensions import TypeAliasType, TypedDict, TypeVar, get_args, get_origin
11
- from typing_inspection import typing_objects
12
- from typing_inspection.introspection import is_union_origin
12
+ from typing_extensions import TypedDict, TypeVar, assert_never
13
13
 
14
14
  from . import _function_schema, _utils, messages as _messages
15
- from .exceptions import ModelRetry
16
- from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition
15
+ from ._run_context import AgentDepsT, RunContext
16
+ from .exceptions import ModelRetry, UserError
17
+ from .output import (
18
+ NativeOutput,
19
+ OutputDataT,
20
+ OutputMode,
21
+ OutputSpec,
22
+ OutputTypeOrFunction,
23
+ PromptedOutput,
24
+ StructuredOutputMode,
25
+ TextOutput,
26
+ TextOutputFunc,
27
+ ToolOutput,
28
+ )
29
+ from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
30
+
31
+ if TYPE_CHECKING:
32
+ from .profiles import ModelProfile
17
33
 
18
34
  T = TypeVar('T')
19
35
  """An invariant TypeVar."""
@@ -29,8 +45,6 @@ changing it would have negative consequences for the ergonomics of the library.
29
45
  At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would
30
46
  resolve these potential variance issues.
31
47
  """
32
- OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
33
- """Covariant type variable for the result data type of a run."""
34
48
 
35
49
  OutputValidatorFunc = Union[
36
50
  Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv],
@@ -52,6 +66,14 @@ DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
52
66
  DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
53
67
 
54
68
 
69
+ class ToolRetryError(Exception):
70
+ """Exception used to signal a `ToolRetry` message should be returned to the LLM."""
71
+
72
+ def __init__(self, tool_retry: _messages.RetryPromptPart):
73
+ self.tool_retry = tool_retry
74
+ super().__init__()
75
+
76
+
55
77
  @dataclass
56
78
  class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
57
79
  function: OutputValidatorFunc[AgentDepsT, OutputDataT_inv]
@@ -101,140 +123,399 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
101
123
  return result_data
102
124
 
103
125
 
104
- class ToolRetryError(Exception):
105
- """Internal exception used to signal a `ToolRetry` message should be returned to the LLM."""
126
+ class BaseOutputSchema(ABC, Generic[OutputDataT]):
127
+ @abstractmethod
128
+ def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
129
+ raise NotImplementedError()
106
130
 
107
- def __init__(self, tool_retry: _messages.RetryPromptPart):
108
- self.tool_retry = tool_retry
109
- super().__init__()
131
+ @property
132
+ def tools(self) -> dict[str, OutputTool[OutputDataT]]:
133
+ """Get the tools for this output schema."""
134
+ return {}
110
135
 
111
136
 
112
137
  @dataclass(init=False)
113
- class ToolOutput(Generic[OutputDataT]):
114
- """Marker class to use tools for outputs, and customize the tool."""
115
-
116
- output_type: SimpleOutputType[OutputDataT]
117
- name: str | None
118
- description: str | None
119
- max_retries: int | None
120
- strict: bool | None
138
+ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
139
+ """Model the final output from an agent run."""
121
140
 
122
- def __init__(
123
- self,
124
- type_: SimpleOutputType[OutputDataT],
141
+ @classmethod
142
+ @overload
143
+ def build(
144
+ cls,
145
+ output_spec: OutputSpec[OutputDataT],
125
146
  *,
147
+ default_mode: StructuredOutputMode,
126
148
  name: str | None = None,
127
149
  description: str | None = None,
128
- max_retries: int | None = None,
129
150
  strict: bool | None = None,
130
- ):
131
- self.output_type = type_
132
- self.name = name
133
- self.description = description
134
- self.max_retries = max_retries
135
- self.strict = strict
136
-
137
-
138
- T_co = TypeVar('T_co', covariant=True)
139
- # output_type=Type or output_type=function or output_type=object.method
140
- SimpleOutputType = TypeAliasType(
141
- 'SimpleOutputType', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,)
142
- )
143
- # output_type=ToolOutput(<see above>) or <see above>
144
- SimpleOutputTypeOrMarker = TypeAliasType(
145
- 'SimpleOutputTypeOrMarker', Union[SimpleOutputType[T_co], ToolOutput[T_co]], type_params=(T_co,)
146
- )
147
- # output_type=<see above> or [<see above>, ...]
148
- OutputType = TypeAliasType(
149
- 'OutputType', Union[SimpleOutputTypeOrMarker[T_co], Sequence[SimpleOutputTypeOrMarker[T_co]]], type_params=(T_co,)
150
- )
151
-
152
-
153
- @dataclass
154
- class OutputSchema(Generic[OutputDataT]):
155
- """Model the final output from an agent run.
151
+ ) -> OutputSchema[OutputDataT]: ...
156
152
 
157
- Similar to `Tool` but for the final output of running an agent.
158
- """
159
-
160
- tools: dict[str, OutputTool[OutputDataT]]
161
- allow_text_output: bool
153
+ @classmethod
154
+ @overload
155
+ def build(
156
+ cls,
157
+ output_spec: OutputSpec[OutputDataT],
158
+ *,
159
+ default_mode: None = None,
160
+ name: str | None = None,
161
+ description: str | None = None,
162
+ strict: bool | None = None,
163
+ ) -> BaseOutputSchema[OutputDataT]: ...
162
164
 
163
165
  @classmethod
164
166
  def build(
165
- cls: type[OutputSchema[OutputDataT]],
166
- output_type: OutputType[OutputDataT],
167
+ cls,
168
+ output_spec: OutputSpec[OutputDataT],
169
+ *,
170
+ default_mode: StructuredOutputMode | None = None,
167
171
  name: str | None = None,
168
172
  description: str | None = None,
169
173
  strict: bool | None = None,
170
- ) -> OutputSchema[OutputDataT] | None:
174
+ ) -> BaseOutputSchema[OutputDataT]:
171
175
  """Build an OutputSchema dataclass from an output type."""
172
- if output_type is str:
173
- return None
176
+ if output_spec is str:
177
+ return PlainTextOutputSchema()
178
+
179
+ if isinstance(output_spec, NativeOutput):
180
+ return NativeOutputSchema(
181
+ cls._build_processor(
182
+ _flatten_output_spec(output_spec.outputs),
183
+ name=output_spec.name,
184
+ description=output_spec.description,
185
+ )
186
+ )
187
+ elif isinstance(output_spec, PromptedOutput):
188
+ return PromptedOutputSchema(
189
+ cls._build_processor(
190
+ _flatten_output_spec(output_spec.outputs),
191
+ name=output_spec.name,
192
+ description=output_spec.description,
193
+ ),
194
+ template=output_spec.template,
195
+ )
174
196
 
175
- output_types: Sequence[SimpleOutputTypeOrMarker[OutputDataT]]
176
- if isinstance(output_type, Sequence):
177
- output_types = output_type
178
- else:
179
- output_types = (output_type,)
197
+ text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
198
+ tool_outputs: Sequence[ToolOutput[OutputDataT]] = []
199
+ other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = []
200
+ for output in _flatten_output_spec(output_spec):
201
+ if output is str:
202
+ text_outputs.append(cast(type[str], output))
203
+ elif isinstance(output, TextOutput):
204
+ text_outputs.append(output)
205
+ elif isinstance(output, ToolOutput):
206
+ tool_outputs.append(output)
207
+ else:
208
+ other_outputs.append(output)
180
209
 
181
- output_types_flat: list[SimpleOutputTypeOrMarker[OutputDataT]] = []
182
- for output_type in output_types:
183
- if union_types := get_union_args(output_type):
184
- output_types_flat.extend(union_types)
210
+ tools = cls._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict)
211
+
212
+ if len(text_outputs) > 0:
213
+ if len(text_outputs) > 1:
214
+ raise UserError('Only one text output is allowed.')
215
+ text_output = text_outputs[0]
216
+
217
+ text_output_schema = None
218
+ if isinstance(text_output, TextOutput):
219
+ text_output_schema = PlainTextOutputProcessor(text_output.output_function)
220
+
221
+ if len(tools) == 0:
222
+ return PlainTextOutputSchema(text_output_schema)
185
223
  else:
186
- output_types_flat.append(output_type)
224
+ return ToolOrTextOutputSchema(processor=text_output_schema, tools=tools)
187
225
 
188
- allow_text_output = False
189
- if str in output_types_flat:
190
- allow_text_output = True
191
- output_types_flat = [t for t in output_types_flat if t is not str]
226
+ if len(tool_outputs) > 0:
227
+ return ToolOutputSchema(tools)
192
228
 
193
- multiple = len(output_types_flat) > 1
229
+ if len(other_outputs) > 0:
230
+ schema = OutputSchemaWithoutMode(
231
+ processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
232
+ tools=tools,
233
+ )
234
+ if default_mode:
235
+ schema = schema.with_default_mode(default_mode)
236
+ return schema
194
237
 
195
- default_tool_name = name or DEFAULT_OUTPUT_TOOL_NAME
196
- default_tool_description = description
197
- default_tool_strict = strict
238
+ raise UserError('No output type provided.') # pragma: no cover
198
239
 
240
+ @staticmethod
241
+ def _build_tools(
242
+ outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
243
+ name: str | None = None,
244
+ description: str | None = None,
245
+ strict: bool | None = None,
246
+ ) -> dict[str, OutputTool[OutputDataT]]:
199
247
  tools: dict[str, OutputTool[OutputDataT]] = {}
200
- for output_type in output_types_flat:
201
- tool_name = None
202
- tool_description = None
203
- tool_strict = None
204
- if isinstance(output_type, ToolOutput):
205
- tool_output_type = output_type.output_type
248
+
249
+ default_name = name or DEFAULT_OUTPUT_TOOL_NAME
250
+ default_description = description
251
+ default_strict = strict
252
+
253
+ multiple = len(outputs) > 1
254
+ for output in outputs:
255
+ name = None
256
+ description = None
257
+ strict = None
258
+ if isinstance(output, ToolOutput):
206
259
  # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
207
- tool_name = output_type.name
208
- tool_description = output_type.description
209
- tool_strict = output_type.strict
210
- else:
211
- tool_output_type = output_type
260
+ name = output.name
261
+ description = output.description
262
+ strict = output.strict
212
263
 
213
- if tool_name is None:
214
- tool_name = default_tool_name
264
+ output = output.output
265
+
266
+ if name is None:
267
+ name = default_name
215
268
  if multiple:
216
- tool_name += f'_{tool_output_type.__name__}'
269
+ name += f'_{output.__name__}'
217
270
 
218
271
  i = 1
219
- original_tool_name = tool_name
220
- while tool_name in tools:
272
+ original_name = name
273
+ while name in tools:
221
274
  i += 1
222
- tool_name = f'{original_tool_name}_{i}'
275
+ name = f'{original_name}_{i}'
223
276
 
224
- tool_description = tool_description or default_tool_description
225
- if tool_strict is None:
226
- tool_strict = default_tool_strict
277
+ description = description or default_description
278
+ if strict is None:
279
+ strict = default_strict
227
280
 
228
- parameters_schema = OutputObjectSchema(
229
- output_type=tool_output_type, description=tool_description, strict=tool_strict
230
- )
231
- tools[tool_name] = OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=multiple)
281
+ processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
282
+ tools[name] = OutputTool(name=name, processor=processor, multiple=multiple)
283
+
284
+ return tools
285
+
286
+ @staticmethod
287
+ def _build_processor(
288
+ outputs: Sequence[OutputTypeOrFunction[OutputDataT]],
289
+ name: str | None = None,
290
+ description: str | None = None,
291
+ strict: bool | None = None,
292
+ ) -> ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]:
293
+ outputs = _flatten_output_spec(outputs)
294
+ if len(outputs) == 1:
295
+ return ObjectOutputProcessor(output=outputs[0], name=name, description=description, strict=strict)
296
+
297
+ return UnionOutputProcessor(outputs=outputs, strict=strict, name=name, description=description)
298
+
299
+ @property
300
+ @abstractmethod
301
+ def mode(self) -> OutputMode:
302
+ raise NotImplementedError()
303
+
304
+ @abstractmethod
305
+ def raise_if_unsupported(self, profile: ModelProfile) -> None:
306
+ """Raise an error if the mode is not supported by the model."""
307
+ raise NotImplementedError()
308
+
309
+ def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
310
+ return self
311
+
312
+
313
+ @dataclass(init=False)
314
+ class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
315
+ processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]
316
+ _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict)
317
+
318
+ def __init__(
319
+ self,
320
+ processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT],
321
+ tools: dict[str, OutputTool[OutputDataT]],
322
+ ):
323
+ self.processor = processor
324
+ self._tools = tools
325
+
326
+ def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
327
+ if mode == 'native':
328
+ return NativeOutputSchema(self.processor)
329
+ elif mode == 'prompted':
330
+ return PromptedOutputSchema(self.processor)
331
+ elif mode == 'tool':
332
+ return ToolOutputSchema(self.tools)
333
+ else:
334
+ assert_never(mode)
232
335
 
233
- return cls(
234
- tools=tools,
235
- allow_text_output=allow_text_output,
336
+ @property
337
+ def tools(self) -> dict[str, OutputTool[OutputDataT]]:
338
+ """Get the tools for this output schema."""
339
+ # We return tools here as they're checked in Agent._register_tool.
340
+ # At that point we may don't know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time.
341
+ return self._tools
342
+
343
+
344
+ class TextOutputSchema(OutputSchema[OutputDataT], ABC):
345
+ @abstractmethod
346
+ async def process(
347
+ self,
348
+ text: str,
349
+ run_context: RunContext[AgentDepsT],
350
+ allow_partial: bool = False,
351
+ wrap_validation_errors: bool = True,
352
+ ) -> OutputDataT:
353
+ raise NotImplementedError()
354
+
355
+
356
+ @dataclass
357
+ class PlainTextOutputSchema(TextOutputSchema[OutputDataT]):
358
+ processor: PlainTextOutputProcessor[OutputDataT] | None = None
359
+
360
+ @property
361
+ def mode(self) -> OutputMode:
362
+ return 'text'
363
+
364
+ def raise_if_unsupported(self, profile: ModelProfile) -> None:
365
+ """Raise an error if the mode is not supported by the model."""
366
+ pass
367
+
368
+ async def process(
369
+ self,
370
+ text: str,
371
+ run_context: RunContext[AgentDepsT],
372
+ allow_partial: bool = False,
373
+ wrap_validation_errors: bool = True,
374
+ ) -> OutputDataT:
375
+ """Validate an output message.
376
+
377
+ Args:
378
+ text: The output text to validate.
379
+ run_context: The current run context.
380
+ allow_partial: If true, allow partial validation.
381
+ wrap_validation_errors: If true, wrap the validation errors in a retry message.
382
+
383
+ Returns:
384
+ Either the validated output data (left) or a retry message (right).
385
+ """
386
+ if self.processor is None:
387
+ return cast(OutputDataT, text)
388
+
389
+ return await self.processor.process(
390
+ text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
236
391
  )
237
392
 
393
+
394
+ @dataclass
395
+ class StructuredTextOutputSchema(TextOutputSchema[OutputDataT], ABC):
396
+ processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]
397
+
398
+ @property
399
+ def object_def(self) -> OutputObjectDefinition:
400
+ return self.processor.object_def
401
+
402
+
403
+ @dataclass
404
+ class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
405
+ @property
406
+ def mode(self) -> OutputMode:
407
+ return 'native'
408
+
409
+ def raise_if_unsupported(self, profile: ModelProfile) -> None:
410
+ """Raise an error if the mode is not supported by the model."""
411
+ if not profile.supports_json_schema_output:
412
+ raise UserError('Structured output is not supported by the model.')
413
+
414
+ async def process(
415
+ self,
416
+ text: str,
417
+ run_context: RunContext[AgentDepsT],
418
+ allow_partial: bool = False,
419
+ wrap_validation_errors: bool = True,
420
+ ) -> OutputDataT:
421
+ """Validate an output message.
422
+
423
+ Args:
424
+ text: The output text to validate.
425
+ run_context: The current run context.
426
+ allow_partial: If true, allow partial validation.
427
+ wrap_validation_errors: If true, wrap the validation errors in a retry message.
428
+
429
+ Returns:
430
+ Either the validated output data (left) or a retry message (right).
431
+ """
432
+ return await self.processor.process(
433
+ text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
434
+ )
435
+
436
+
437
+ @dataclass
438
+ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
439
+ template: str | None = None
440
+
441
+ @property
442
+ def mode(self) -> OutputMode:
443
+ return 'prompted'
444
+
445
+ def raise_if_unsupported(self, profile: ModelProfile) -> None:
446
+ """Raise an error if the mode is not supported by the model."""
447
+ pass
448
+
449
+ def instructions(self, default_template: str) -> str:
450
+ """Get instructions to tell model to output JSON matching the schema."""
451
+ template = self.template or default_template
452
+
453
+ if '{schema}' not in template:
454
+ template = '\n\n'.join([template, '{schema}'])
455
+
456
+ object_def = self.object_def
457
+ schema = object_def.json_schema.copy()
458
+ if object_def.name:
459
+ schema['title'] = object_def.name
460
+ if object_def.description:
461
+ schema['description'] = object_def.description
462
+
463
+ return template.format(schema=json.dumps(schema))
464
+
465
+ async def process(
466
+ self,
467
+ text: str,
468
+ run_context: RunContext[AgentDepsT],
469
+ allow_partial: bool = False,
470
+ wrap_validation_errors: bool = True,
471
+ ) -> OutputDataT:
472
+ """Validate an output message.
473
+
474
+ Args:
475
+ text: The output text to validate.
476
+ run_context: The current run context.
477
+ allow_partial: If true, allow partial validation.
478
+ wrap_validation_errors: If true, wrap the validation errors in a retry message.
479
+
480
+ Returns:
481
+ Either the validated output data (left) or a retry message (right).
482
+ """
483
+ text = _utils.strip_markdown_fences(text)
484
+
485
+ return await self.processor.process(
486
+ text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
487
+ )
488
+
489
+
490
+ @dataclass(init=False)
491
+ class ToolOutputSchema(OutputSchema[OutputDataT]):
492
+ _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict)
493
+
494
+ def __init__(self, tools: dict[str, OutputTool[OutputDataT]]):
495
+ self._tools = tools
496
+
497
+ @property
498
+ def mode(self) -> OutputMode:
499
+ return 'tool'
500
+
501
+ def raise_if_unsupported(self, profile: ModelProfile) -> None:
502
+ """Raise an error if the mode is not supported by the model."""
503
+ if not profile.supports_tools:
504
+ raise UserError('Output tools are not supported by the model.')
505
+
506
+ @property
507
+ def tools(self) -> dict[str, OutputTool[OutputDataT]]:
508
+ """Get the tools for this output schema."""
509
+ return self._tools
510
+
511
+ def tool_names(self) -> list[str]:
512
+ """Return the names of the tools."""
513
+ return list(self.tools.keys())
514
+
515
+ def tool_defs(self) -> list[ToolDefinition]:
516
+ """Get tool definitions to register with the model."""
517
+ return [t.tool_def for t in self.tools.values()]
518
+
238
519
  def find_named_tool(
239
520
  self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
240
521
  ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None:
@@ -254,61 +535,78 @@ class OutputSchema(Generic[OutputDataT]):
254
535
  if result := self.tools.get(part.tool_name):
255
536
  yield part, result
256
537
 
257
- def tool_names(self) -> list[str]:
258
- """Return the names of the tools."""
259
- return list(self.tools.keys())
260
-
261
- def tool_defs(self) -> list[ToolDefinition]:
262
- """Get tool definitions to register with the model."""
263
- return [t.tool_def for t in self.tools.values()]
264
538
 
539
+ @dataclass(init=False)
540
+ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchema[OutputDataT]):
541
+ def __init__(
542
+ self,
543
+ processor: PlainTextOutputProcessor[OutputDataT] | None,
544
+ tools: dict[str, OutputTool[OutputDataT]],
545
+ ):
546
+ self.processor = processor
547
+ self._tools = tools
265
548
 
266
- def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool:
267
- return output_schema is None or output_schema.allow_text_output
549
+ @property
550
+ def mode(self) -> OutputMode:
551
+ return 'tool_or_text'
268
552
 
269
553
 
270
554
  @dataclass
271
555
  class OutputObjectDefinition:
272
- name: str
273
556
  json_schema: ObjectJsonSchema
557
+ name: str | None = None
274
558
  description: str | None = None
275
559
  strict: bool | None = None
276
560
 
277
561
 
278
562
  @dataclass(init=False)
279
- class OutputObjectSchema(Generic[OutputDataT]):
280
- definition: OutputObjectDefinition
281
- validator: SchemaValidator
282
- function_schema: _function_schema.FunctionSchema | None = None
563
+ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
564
+ @abstractmethod
565
+ async def process(
566
+ self,
567
+ data: str,
568
+ run_context: RunContext[AgentDepsT],
569
+ allow_partial: bool = False,
570
+ wrap_validation_errors: bool = True,
571
+ ) -> OutputDataT:
572
+ """Process an output message, performing validation and (if necessary) calling the output function."""
573
+ raise NotImplementedError()
574
+
575
+
576
+ @dataclass(init=False)
577
+ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
578
+ object_def: OutputObjectDefinition
283
579
  outer_typed_dict_key: str | None = None
580
+ _validator: SchemaValidator
581
+ _function_schema: _function_schema.FunctionSchema | None = None
284
582
 
285
583
  def __init__(
286
584
  self,
585
+ output: OutputTypeOrFunction[OutputDataT],
287
586
  *,
288
- output_type: SimpleOutputType[OutputDataT],
289
587
  name: str | None = None,
290
588
  description: str | None = None,
291
589
  strict: bool | None = None,
292
590
  ):
293
- if inspect.isfunction(output_type) or inspect.ismethod(output_type):
294
- self.function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema)
295
- self.validator = self.function_schema.validator
296
- json_schema = self.function_schema.json_schema
297
- json_schema['description'] = self.function_schema.description
591
+ if inspect.isfunction(output) or inspect.ismethod(output):
592
+ self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema)
593
+ self._validator = self._function_schema.validator
594
+ json_schema = self._function_schema.json_schema
595
+ json_schema['description'] = self._function_schema.description
298
596
  else:
299
597
  type_adapter: TypeAdapter[Any]
300
- if _utils.is_model_like(output_type):
301
- type_adapter = TypeAdapter(output_type)
598
+ if _utils.is_model_like(output):
599
+ type_adapter = TypeAdapter(output)
302
600
  else:
303
601
  self.outer_typed_dict_key = 'response'
304
602
  response_data_typed_dict = TypedDict( # noqa: UP013
305
603
  'response_data_typed_dict',
306
- {'response': cast(type[OutputDataT], output_type)}, # pyright: ignore[reportInvalidTypeForm]
604
+ {'response': cast(type[OutputDataT], output)}, # pyright: ignore[reportInvalidTypeForm]
307
605
  )
308
606
  type_adapter = TypeAdapter(response_data_typed_dict)
309
607
 
310
608
  # Really a PluggableSchemaValidator, but it's API-compatible
311
- self.validator = cast(SchemaValidator, type_adapter.validator)
609
+ self._validator = cast(SchemaValidator, type_adapter.validator)
312
610
  json_schema = _utils.check_object_json_schema(
313
611
  type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
314
612
  )
@@ -323,8 +621,8 @@ class OutputObjectSchema(Generic[OutputDataT]):
323
621
  else:
324
622
  description = f'{description}. {json_schema_description}'
325
623
 
326
- self.definition = OutputObjectDefinition(
327
- name=name or getattr(output_type, '__name__', DEFAULT_OUTPUT_TOOL_NAME),
624
+ self.object_def = OutputObjectDefinition(
625
+ name=name or getattr(output, '__name__', None),
328
626
  description=description,
329
627
  json_schema=json_schema,
330
628
  strict=strict,
@@ -335,6 +633,7 @@ class OutputObjectSchema(Generic[OutputDataT]):
335
633
  data: str | dict[str, Any] | None,
336
634
  run_context: RunContext[AgentDepsT],
337
635
  allow_partial: bool = False,
636
+ wrap_validation_errors: bool = True,
338
637
  ) -> OutputDataT:
339
638
  """Process an output message, performing validation and (if necessary) calling the output function.
340
639
 
@@ -342,45 +641,235 @@ class OutputObjectSchema(Generic[OutputDataT]):
342
641
  data: The output data to validate.
343
642
  run_context: The current run context.
344
643
  allow_partial: If true, allow partial validation.
644
+ wrap_validation_errors: If true, wrap the validation errors in a retry message.
345
645
 
346
646
  Returns:
347
647
  Either the validated output data (left) or a retry message (right).
348
648
  """
349
- pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
350
- if isinstance(data, str):
351
- output = self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
352
- else:
353
- output = self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
354
-
355
- if self.function_schema:
356
- output = await self.function_schema.call(output, run_context)
649
+ try:
650
+ pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
651
+ if isinstance(data, str):
652
+ output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
653
+ else:
654
+ output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
655
+ except ValidationError as e:
656
+ if wrap_validation_errors:
657
+ m = _messages.RetryPromptPart(
658
+ content=e.errors(include_url=False),
659
+ )
660
+ raise ToolRetryError(m) from e
661
+ else:
662
+ raise # pragma: lax no cover
357
663
 
358
664
  if k := self.outer_typed_dict_key:
359
665
  output = output[k]
666
+
667
+ if self._function_schema:
668
+ try:
669
+ output = await self._function_schema.call(output, run_context)
670
+ except ModelRetry as r:
671
+ if wrap_validation_errors:
672
+ m = _messages.RetryPromptPart(
673
+ content=r.message,
674
+ )
675
+ raise ToolRetryError(m) from r
676
+ else:
677
+ raise # pragma: lax no cover
678
+
360
679
  return output
361
680
 
362
681
 
682
+ @dataclass
683
+ class UnionOutputResult:
684
+ kind: str
685
+ data: ObjectJsonSchema
686
+
687
+
688
+ @dataclass
689
+ class UnionOutputModel:
690
+ result: UnionOutputResult
691
+
692
+
693
+ @dataclass(init=False)
694
+ class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
695
+ object_def: OutputObjectDefinition
696
+ _union_processor: ObjectOutputProcessor[UnionOutputModel]
697
+ _processors: dict[str, ObjectOutputProcessor[OutputDataT]]
698
+
699
+ def __init__(
700
+ self,
701
+ outputs: Sequence[OutputTypeOrFunction[OutputDataT]],
702
+ *,
703
+ name: str | None = None,
704
+ description: str | None = None,
705
+ strict: bool | None = None,
706
+ ):
707
+ self._union_processor = ObjectOutputProcessor(output=UnionOutputModel)
708
+
709
+ json_schemas: list[ObjectJsonSchema] = []
710
+ self._processors = {}
711
+ for output in outputs:
712
+ processor = ObjectOutputProcessor(output=output, strict=strict)
713
+ object_def = processor.object_def
714
+
715
+ object_key = object_def.name or output.__name__
716
+ i = 1
717
+ original_key = object_key
718
+ while object_key in self._processors:
719
+ i += 1
720
+ object_key = f'{original_key}_{i}'
721
+
722
+ self._processors[object_key] = processor
723
+
724
+ json_schema = object_def.json_schema
725
+ if object_def.name: # pragma: no branch
726
+ json_schema['title'] = object_def.name
727
+ if object_def.description:
728
+ json_schema['description'] = object_def.description
729
+
730
+ json_schemas.append(json_schema)
731
+
732
+ json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas)
733
+
734
+ discriminated_json_schemas: list[ObjectJsonSchema] = []
735
+ for object_key, json_schema in zip(self._processors.keys(), json_schemas):
736
+ title = json_schema.pop('title', None)
737
+ description = json_schema.pop('description', None)
738
+
739
+ discriminated_json_schema = {
740
+ 'type': 'object',
741
+ 'properties': {
742
+ 'kind': {
743
+ 'type': 'string',
744
+ 'const': object_key,
745
+ },
746
+ 'data': json_schema,
747
+ },
748
+ 'required': ['kind', 'data'],
749
+ 'additionalProperties': False,
750
+ }
751
+ if title: # pragma: no branch
752
+ discriminated_json_schema['title'] = title
753
+ if description:
754
+ discriminated_json_schema['description'] = description
755
+
756
+ discriminated_json_schemas.append(discriminated_json_schema)
757
+
758
+ json_schema = {
759
+ 'type': 'object',
760
+ 'properties': {
761
+ 'result': {
762
+ 'anyOf': discriminated_json_schemas,
763
+ }
764
+ },
765
+ 'required': ['result'],
766
+ 'additionalProperties': False,
767
+ }
768
+ if all_defs:
769
+ json_schema['$defs'] = all_defs
770
+
771
+ self.object_def = OutputObjectDefinition(
772
+ json_schema=json_schema,
773
+ strict=strict,
774
+ name=name,
775
+ description=description,
776
+ )
777
+
778
+ async def process(
779
+ self,
780
+ data: str | dict[str, Any] | None,
781
+ run_context: RunContext[AgentDepsT],
782
+ allow_partial: bool = False,
783
+ wrap_validation_errors: bool = True,
784
+ ) -> OutputDataT:
785
+ union_object = await self._union_processor.process(
786
+ data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
787
+ )
788
+
789
+ result = union_object.result
790
+ kind = result.kind
791
+ data = result.data
792
+ try:
793
+ processor = self._processors[kind]
794
+ except KeyError as e: # pragma: no cover
795
+ if wrap_validation_errors:
796
+ m = _messages.RetryPromptPart(content=f'Invalid kind: {kind}')
797
+ raise ToolRetryError(m) from e
798
+ else:
799
+ raise
800
+
801
+ return await processor.process(
802
+ data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
803
+ )
804
+
805
+
806
+ @dataclass(init=False)
807
+ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
808
+ _function_schema: _function_schema.FunctionSchema
809
+ _str_argument_name: str
810
+
811
+ def __init__(
812
+ self,
813
+ output_function: TextOutputFunc[OutputDataT],
814
+ ):
815
+ self._function_schema = _function_schema.function_schema(output_function, GenerateToolJsonSchema)
816
+
817
+ arguments_schema = self._function_schema.json_schema.get('properties', {})
818
+ argument_name = next(iter(arguments_schema.keys()), None)
819
+ if argument_name and arguments_schema.get(argument_name, {}).get('type') == 'string':
820
+ self._str_argument_name = argument_name
821
+ return
822
+
823
+ raise UserError('TextOutput must take a function taking a `str`')
824
+
825
+ @property
826
+ def object_def(self) -> None:
827
+ return None # pragma: no cover
828
+
829
+ async def process(
830
+ self,
831
+ data: str,
832
+ run_context: RunContext[AgentDepsT],
833
+ allow_partial: bool = False,
834
+ wrap_validation_errors: bool = True,
835
+ ) -> OutputDataT:
836
+ args = {self._str_argument_name: data}
837
+
838
+ try:
839
+ output = await self._function_schema.call(args, run_context)
840
+ except ModelRetry as r:
841
+ if wrap_validation_errors:
842
+ m = _messages.RetryPromptPart(
843
+ content=r.message,
844
+ )
845
+ raise ToolRetryError(m) from r
846
+ else:
847
+ raise # pragma: lax no cover
848
+
849
+ return cast(OutputDataT, output)
850
+
851
+
363
852
  @dataclass(init=False)
364
853
  class OutputTool(Generic[OutputDataT]):
365
- parameters_schema: OutputObjectSchema[OutputDataT]
854
+ processor: ObjectOutputProcessor[OutputDataT]
366
855
  tool_def: ToolDefinition
367
856
 
368
- def __init__(self, *, name: str, parameters_schema: OutputObjectSchema[OutputDataT], multiple: bool):
369
- self.parameters_schema = parameters_schema
370
- definition = parameters_schema.definition
857
+ def __init__(self, *, name: str, processor: ObjectOutputProcessor[OutputDataT], multiple: bool):
858
+ self.processor = processor
859
+ object_def = processor.object_def
371
860
 
372
- description = definition.description
861
+ description = object_def.description
373
862
  if not description:
374
863
  description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
375
864
  if multiple:
376
- description = f'{definition.name}: {description}'
865
+ description = f'{object_def.name}: {description}'
377
866
 
378
867
  self.tool_def = ToolDefinition(
379
868
  name=name,
380
869
  description=description,
381
- parameters_json_schema=definition.json_schema,
382
- strict=definition.strict,
383
- outer_typed_dict_key=parameters_schema.outer_typed_dict_key,
870
+ parameters_json_schema=object_def.json_schema,
871
+ strict=object_def.strict,
872
+ outer_typed_dict_key=processor.outer_typed_dict_key,
384
873
  )
385
874
 
386
875
  async def process(
@@ -402,7 +891,9 @@ class OutputTool(Generic[OutputDataT]):
402
891
  Either the validated output data (left) or a retry message (right).
403
892
  """
404
893
  try:
405
- output = await self.parameters_schema.process(tool_call.args, run_context, allow_partial=allow_partial)
894
+ output = await self.processor.process(
895
+ tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False
896
+ )
406
897
  except ValidationError as e:
407
898
  if wrap_validation_errors:
408
899
  m = _messages.RetryPromptPart(
@@ -427,13 +918,17 @@ class OutputTool(Generic[OutputDataT]):
427
918
  return output
428
919
 
429
920
 
430
- def get_union_args(tp: Any) -> tuple[Any, ...]:
431
- """Extract the arguments of a Union type if `output_type` is a union, otherwise return an empty tuple."""
432
- if typing_objects.is_typealiastype(tp):
433
- tp = tp.__value__
434
-
435
- origin = get_origin(tp)
436
- if is_union_origin(origin):
437
- return get_args(tp)
921
+ def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]:
922
+ outputs: Sequence[T]
923
+ if isinstance(output_spec, Sequence):
924
+ outputs = output_spec
438
925
  else:
439
- return ()
926
+ outputs = (output_spec,)
927
+
928
+ outputs_flat: list[T] = []
929
+ for output in outputs:
930
+ if union_types := _utils.get_union_args(output):
931
+ outputs_flat.extend(union_types)
932
+ else:
933
+ outputs_flat.append(output)
934
+ return outputs_flat