pydantic-ai-slim 1.0.13__py3-none-any.whl → 1.0.15__py3-none-any.whl

This diff shows the content changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Note: this version of pydantic-ai-slim has been flagged as potentially problematic.
Files changed (38)
  1. pydantic_ai/__init__.py +19 -1
  2. pydantic_ai/_agent_graph.py +118 -97
  3. pydantic_ai/_cli.py +4 -7
  4. pydantic_ai/_output.py +236 -192
  5. pydantic_ai/_parts_manager.py +8 -42
  6. pydantic_ai/_tool_manager.py +9 -16
  7. pydantic_ai/agent/abstract.py +169 -1
  8. pydantic_ai/builtin_tools.py +82 -0
  9. pydantic_ai/direct.py +7 -0
  10. pydantic_ai/durable_exec/dbos/_agent.py +106 -3
  11. pydantic_ai/durable_exec/temporal/_agent.py +123 -6
  12. pydantic_ai/durable_exec/temporal/_model.py +8 -0
  13. pydantic_ai/format_prompt.py +4 -3
  14. pydantic_ai/mcp.py +20 -10
  15. pydantic_ai/messages.py +149 -3
  16. pydantic_ai/models/__init__.py +15 -1
  17. pydantic_ai/models/anthropic.py +7 -3
  18. pydantic_ai/models/cohere.py +4 -0
  19. pydantic_ai/models/function.py +7 -4
  20. pydantic_ai/models/gemini.py +8 -0
  21. pydantic_ai/models/google.py +56 -23
  22. pydantic_ai/models/groq.py +11 -5
  23. pydantic_ai/models/huggingface.py +5 -3
  24. pydantic_ai/models/mistral.py +6 -8
  25. pydantic_ai/models/openai.py +197 -58
  26. pydantic_ai/models/test.py +4 -0
  27. pydantic_ai/output.py +5 -2
  28. pydantic_ai/profiles/__init__.py +2 -0
  29. pydantic_ai/profiles/google.py +5 -2
  30. pydantic_ai/profiles/openai.py +2 -1
  31. pydantic_ai/result.py +46 -30
  32. pydantic_ai/run.py +35 -7
  33. pydantic_ai/usage.py +5 -4
  34. {pydantic_ai_slim-1.0.13.dist-info → pydantic_ai_slim-1.0.15.dist-info}/METADATA +3 -3
  35. {pydantic_ai_slim-1.0.13.dist-info → pydantic_ai_slim-1.0.15.dist-info}/RECORD +38 -38
  36. {pydantic_ai_slim-1.0.13.dist-info → pydantic_ai_slim-1.0.15.dist-info}/WHEEL +0 -0
  37. {pydantic_ai_slim-1.0.13.dist-info → pydantic_ai_slim-1.0.15.dist-info}/entry_points.txt +0 -0
  38. {pydantic_ai_slim-1.0.13.dist-info → pydantic_ai_slim-1.0.15.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_output.py CHANGED
@@ -75,7 +75,7 @@ DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversati
  async def execute_traced_output_function(
  function_schema: _function_schema.FunctionSchema,
  run_context: RunContext[AgentDepsT],
- args: dict[str, Any] | Any,
+ args: dict[str, Any],
  wrap_validation_errors: bool = True,
  ) -> Any:
  """Execute an output function within a traced span with error handling.
@@ -209,18 +209,20 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
  return result_data


- @dataclass
+ @dataclass(kw_only=True)
  class BaseOutputSchema(ABC, Generic[OutputDataT]):
- allows_deferred_tools: bool
+ text_processor: BaseOutputProcessor[OutputDataT] | None = None
+ toolset: OutputToolset[Any] | None = None
+ allows_deferred_tools: bool = False
+ allows_image: bool = False

  @abstractmethod
  def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
  raise NotImplementedError()

  @property
- def toolset(self) -> OutputToolset[Any] | None:
- """Get the toolset for this output schema."""
- return None
+ def allows_text(self) -> bool:
+ return self.text_processor is not None


  @dataclass(init=False)
@@ -262,38 +264,67 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
  strict: bool | None = None,
  ) -> BaseOutputSchema[OutputDataT]:
  """Build an OutputSchema dataclass from an output type."""
- raw_outputs = _flatten_output_spec(output_spec)
+ outputs = _flatten_output_spec(output_spec)
+
+ allows_deferred_tools = DeferredToolRequests in outputs
+ if allows_deferred_tools:
+ outputs = [output for output in outputs if output is not DeferredToolRequests]
+ if len(outputs) == 0:
+ raise UserError('At least one output type must be provided other than `DeferredToolRequests`.')

- outputs = [output for output in raw_outputs if output is not DeferredToolRequests]
- allows_deferred_tools = len(outputs) < len(raw_outputs)
- if len(outputs) == 0 and allows_deferred_tools:
- raise UserError('At least one output type must be provided other than `DeferredToolRequests`.')
+ allows_image = _messages.BinaryImage in outputs
+ if allows_image:
+ outputs = [output for output in outputs if output is not _messages.BinaryImage]

  if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
  if len(outputs) > 1:
  raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover

+ flattened_outputs = _flatten_output_spec(output.outputs)
+
+ if DeferredToolRequests in flattened_outputs:
+ raise UserError( # pragma: no cover
+ '`NativeOutput` cannot contain `DeferredToolRequests`. Include it alongside the native output marker instead: `output_type=[NativeOutput(...), DeferredToolRequests]`'
+ )
+ if _messages.BinaryImage in flattened_outputs:
+ raise UserError( # pragma: no cover
+ '`NativeOutput` cannot contain `BinaryImage`. Include it alongside the native output marker instead: `output_type=[NativeOutput(...), BinaryImage]`'
+ )
+
  return NativeOutputSchema(
  processor=cls._build_processor(
- _flatten_output_spec(output.outputs),
+ flattened_outputs,
  name=output.name,
  description=output.description,
  strict=output.strict,
  ),
  allows_deferred_tools=allows_deferred_tools,
+ allows_image=allows_image,
  )
  elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
  if len(outputs) > 1:
  raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover

+ flattened_outputs = _flatten_output_spec(output.outputs)
+
+ if DeferredToolRequests in flattened_outputs:
+ raise UserError( # pragma: no cover
+ '`PromptedOutput` cannot contain `DeferredToolRequests`. Include it alongside the prompted output marker instead: `output_type=[PromptedOutput(...), DeferredToolRequests]`'
+ )
+ if _messages.BinaryImage in flattened_outputs:
+ raise UserError( # pragma: no cover
+ '`PromptedOutput` cannot contain `BinaryImage`. Include it alongside the prompted output marker instead: `output_type=[PromptedOutput(...), BinaryImage]`'
+ )
+
  return PromptedOutputSchema(
+ template=output.template,
  processor=cls._build_processor(
- _flatten_output_spec(output.outputs),
+ flattened_outputs,
  name=output.name,
  description=output.description,
  ),
- template=output.template,
  allows_deferred_tools=allows_deferred_tools,
+ allows_image=allows_image,
  )

  text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
@@ -317,37 +348,51 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):

  toolset = OutputToolset.build(tool_outputs + other_outputs, name=name, description=description, strict=strict)

+ text_processor: BaseOutputProcessor[OutputDataT] | None = None
+
  if len(text_outputs) > 0:
  if len(text_outputs) > 1:
  raise UserError('Only one `str` or `TextOutput` is allowed.')
  text_output = text_outputs[0]

- text_output_schema = None
  if isinstance(text_output, TextOutput):
- text_output_schema = PlainTextOutputProcessor(text_output.output_function)
+ text_processor = TextFunctionOutputProcessor(text_output.output_function)
+ else:
+ text_processor = TextOutputProcessor()

  if toolset:
- return ToolOrTextOutputSchema(
- processor=text_output_schema,
+ return ToolOutputSchema(
  toolset=toolset,
+ text_processor=text_processor,
  allows_deferred_tools=allows_deferred_tools,
+ allows_image=allows_image,
  )
  else:
- return PlainTextOutputSchema(processor=text_output_schema, allows_deferred_tools=allows_deferred_tools)
+ return TextOutputSchema(
+ text_processor=text_processor,
+ allows_deferred_tools=allows_deferred_tools,
+ allows_image=allows_image,
+ )

  if len(tool_outputs) > 0:
- return ToolOutputSchema(toolset=toolset, allows_deferred_tools=allows_deferred_tools)
+ return ToolOutputSchema(
+ toolset=toolset, allows_deferred_tools=allows_deferred_tools, allows_image=allows_image
+ )

  if len(other_outputs) > 0:
  schema = OutputSchemaWithoutMode(
  processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
  toolset=toolset,
  allows_deferred_tools=allows_deferred_tools,
+ allows_image=allows_image,
  )
  if default_mode:
  schema = schema.with_default_mode(default_mode)
  return schema

+ if allows_image:
+ return ImageOutputSchema(allows_deferred_tools=allows_deferred_tools)
+
  raise UserError('At least one output type must be provided.')

  @staticmethod
@@ -356,7 +401,7 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
  name: str | None = None,
  description: str | None = None,
  strict: bool | None = None,
- ) -> ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]:
+ ) -> BaseObjectOutputProcessor[OutputDataT]:
  outputs = _flatten_output_spec(outputs)
  if len(outputs) == 1:
  return ObjectOutputProcessor(output=outputs[0], name=name, description=description, strict=strict)
@@ -368,10 +413,10 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
  def mode(self) -> OutputMode:
  raise NotImplementedError()

- @abstractmethod
  def raise_if_unsupported(self, profile: ModelProfile) -> None:
- """Raise an error if the mode is not supported by the model."""
- raise NotImplementedError()
+ """Raise an error if the mode is not supported by this model."""
+ if self.allows_image and not profile.supports_image_output:
+ raise UserError('Image output is not supported by this model.')

  def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
  return self
@@ -379,142 +424,139 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):

  @dataclass(init=False)
  class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
- processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]
- _toolset: OutputToolset[Any] | None
+ processor: BaseObjectOutputProcessor[OutputDataT]

  def __init__(
  self,
- processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT],
+ processor: BaseObjectOutputProcessor[OutputDataT],
  toolset: OutputToolset[Any] | None,
  allows_deferred_tools: bool,
+ allows_image: bool,
  ):
- super().__init__(allows_deferred_tools)
+ # We set a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor.
+ # At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time,
+ # but we cover ourselves just in case we end up using the tool output mode.
+ super().__init__(
+ allows_deferred_tools=allows_deferred_tools,
+ toolset=toolset,
+ text_processor=processor,
+ allows_image=allows_image,
+ )
  self.processor = processor
- self._toolset = toolset

  def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
  if mode == 'native':
- return NativeOutputSchema(processor=self.processor, allows_deferred_tools=self.allows_deferred_tools)
+ return NativeOutputSchema(
+ processor=self.processor,
+ allows_deferred_tools=self.allows_deferred_tools,
+ allows_image=self.allows_image,
+ )
  elif mode == 'prompted':
- return PromptedOutputSchema(processor=self.processor, allows_deferred_tools=self.allows_deferred_tools)
+ return PromptedOutputSchema(
+ processor=self.processor,
+ allows_deferred_tools=self.allows_deferred_tools,
+ allows_image=self.allows_image,
+ )
  elif mode == 'tool':
- return ToolOutputSchema(toolset=self.toolset, allows_deferred_tools=self.allows_deferred_tools)
+ return ToolOutputSchema(
+ toolset=self.toolset, allows_deferred_tools=self.allows_deferred_tools, allows_image=self.allows_image
+ )
  else:
  assert_never(mode)

- @property
- def toolset(self) -> OutputToolset[Any] | None:
- """Get the toolset for this output schema."""
- # We return a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor.
- # At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time,
- # but we cover ourselves just in case we end up using the tool output mode.
- return self._toolset
-

- class TextOutputSchema(OutputSchema[OutputDataT], ABC):
- @abstractmethod
- async def process(
+ @dataclass(init=False)
+ class TextOutputSchema(OutputSchema[OutputDataT]):
+ def __init__(
  self,
- text: str,
- run_context: RunContext[AgentDepsT],
- allow_partial: bool = False,
- wrap_validation_errors: bool = True,
- ) -> OutputDataT:
- raise NotImplementedError()
-
-
- @dataclass
- class PlainTextOutputSchema(TextOutputSchema[OutputDataT]):
- processor: PlainTextOutputProcessor[OutputDataT] | None = None
+ *,
+ text_processor: TextOutputProcessor[OutputDataT],
+ allows_deferred_tools: bool,
+ allows_image: bool,
+ ):
+ super().__init__(
+ text_processor=text_processor,
+ allows_deferred_tools=allows_deferred_tools,
+ allows_image=allows_image,
+ )

  @property
  def mode(self) -> OutputMode:
  return 'text'

  def raise_if_unsupported(self, profile: ModelProfile) -> None:
- """Raise an error if the mode is not supported by the model."""
- pass
+ """Raise an error if the mode is not supported by this model."""
+ super().raise_if_unsupported(profile)

- async def process(
- self,
- text: str,
- run_context: RunContext[AgentDepsT],
- allow_partial: bool = False,
- wrap_validation_errors: bool = True,
- ) -> OutputDataT:
- """Validate an output message.

- Args:
- text: The output text to validate.
- run_context: The current run context.
- allow_partial: If true, allow partial validation.
- wrap_validation_errors: If true, wrap the validation errors in a retry message.
+ class ImageOutputSchema(OutputSchema[OutputDataT]):
+ def __init__(self, *, allows_deferred_tools: bool):
+ super().__init__(allows_deferred_tools=allows_deferred_tools, allows_image=True)

- Returns:
- Either the validated output data (left) or a retry message (right).
- """
- if self.processor is None:
- return cast(OutputDataT, text)
+ @property
+ def mode(self) -> OutputMode:
+ return 'image'

- return await self.processor.process(
- text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
- )
+ def raise_if_unsupported(self, profile: ModelProfile) -> None:
+ """Raise an error if the mode is not supported by this model."""
+ # This already raises if image output is not supported by this model.
+ super().raise_if_unsupported(profile)


- @dataclass
- class StructuredTextOutputSchema(TextOutputSchema[OutputDataT], ABC):
- processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]
+ @dataclass(init=False)
+ class StructuredTextOutputSchema(OutputSchema[OutputDataT], ABC):
+ processor: BaseObjectOutputProcessor[OutputDataT]
+
+ def __init__(
+ self, *, processor: BaseObjectOutputProcessor[OutputDataT], allows_deferred_tools: bool, allows_image: bool
+ ):
+ super().__init__(
+ text_processor=processor, allows_deferred_tools=allows_deferred_tools, allows_image=allows_image
+ )
+ self.processor = processor

  @property
  def object_def(self) -> OutputObjectDefinition:
  return self.processor.object_def


- @dataclass
  class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
  @property
  def mode(self) -> OutputMode:
  return 'native'

  def raise_if_unsupported(self, profile: ModelProfile) -> None:
- """Raise an error if the mode is not supported by the model."""
+ """Raise an error if the mode is not supported by this model."""
  if not profile.supports_json_schema_output:
- raise UserError('Native structured output is not supported by the model.')
+ raise UserError('Native structured output is not supported by this model.')

- async def process(
- self,
- text: str,
- run_context: RunContext[AgentDepsT],
- allow_partial: bool = False,
- wrap_validation_errors: bool = True,
- ) -> OutputDataT:
- """Validate an output message.

- Args:
- text: The output text to validate.
- run_context: The current run context.
- allow_partial: If true, allow partial validation.
- wrap_validation_errors: If true, wrap the validation errors in a retry message.
+ @dataclass(init=False)
+ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
+ template: str | None

- Returns:
- Either the validated output data (left) or a retry message (right).
- """
- return await self.processor.process(
- text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+ def __init__(
+ self,
+ *,
+ template: str | None = None,
+ processor: BaseObjectOutputProcessor[OutputDataT],
+ allows_deferred_tools: bool,
+ allows_image: bool,
+ ):
+ super().__init__(
+ processor=PromptedOutputProcessor(processor),
+ allows_deferred_tools=allows_deferred_tools,
+ allows_image=allows_image,
  )
-
-
- @dataclass
- class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
- template: str | None = None
+ self.template = template

  @property
  def mode(self) -> OutputMode:
  return 'prompted'

  def raise_if_unsupported(self, profile: ModelProfile) -> None:
- """Raise an error if the mode is not supported by the model."""
- pass
+ """Raise an error if the mode is not supported by this model."""
+ super().raise_if_unsupported(profile)

  def instructions(self, default_template: str) -> str:
  """Get instructions to tell model to output JSON matching the schema."""
@@ -532,71 +574,35 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):

  return template.format(schema=json.dumps(schema))

- async def process(
- self,
- text: str,
- run_context: RunContext[AgentDepsT],
- allow_partial: bool = False,
- wrap_validation_errors: bool = True,
- ) -> OutputDataT:
- """Validate an output message.
-
- Args:
- text: The output text to validate.
- run_context: The current run context.
- allow_partial: If true, allow partial validation.
- wrap_validation_errors: If true, wrap the validation errors in a retry message.
-
- Returns:
- Either the validated output data (left) or a retry message (right).
- """
- text = _utils.strip_markdown_fences(text)
-
- return await self.processor.process(
- text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
- )
-

  @dataclass(init=False)
  class ToolOutputSchema(OutputSchema[OutputDataT]):
- _toolset: OutputToolset[Any] | None
-
- def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tools: bool):
- super().__init__(allows_deferred_tools)
- self._toolset = toolset
-
- @property
- def mode(self) -> OutputMode:
- return 'tool'
-
- def raise_if_unsupported(self, profile: ModelProfile) -> None:
- """Raise an error if the mode is not supported by the model."""
- if not profile.supports_tools:
- raise UserError('Output tools are not supported by the model.')
-
- @property
- def toolset(self) -> OutputToolset[Any] | None:
- """Get the toolset for this output schema."""
- return self._toolset
-
-
- @dataclass(init=False)
- class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchema[OutputDataT]):
  def __init__(
  self,
- processor: PlainTextOutputProcessor[OutputDataT] | None,
+ *,
  toolset: OutputToolset[Any] | None,
+ text_processor: BaseOutputProcessor[OutputDataT] | None = None,
  allows_deferred_tools: bool,
+ allows_image: bool,
  ):
- super().__init__(toolset=toolset, allows_deferred_tools=allows_deferred_tools)
- self.processor = processor
+ super().__init__(
+ toolset=toolset,
+ allows_deferred_tools=allows_deferred_tools,
+ text_processor=text_processor,
+ allows_image=allows_image,
+ )

  @property
  def mode(self) -> OutputMode:
- return 'tool_or_text'
+ return 'tool'
+
+ def raise_if_unsupported(self, profile: ModelProfile) -> None:
+ """Raise an error if the mode is not supported by this model."""
+ super().raise_if_unsupported(profile)
+ if not profile.supports_tools:
+ raise UserError('Tool output is not supported by this model.')


- @dataclass(init=False)
  class BaseOutputProcessor(ABC, Generic[OutputDataT]):
  @abstractmethod
  async def process(
@@ -610,9 +616,35 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
  raise NotImplementedError()


- @dataclass(init=False)
- class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
+ @dataclass(kw_only=True)
+ class BaseObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
  object_def: OutputObjectDefinition
+
+
+ @dataclass(init=False)
+ class PromptedOutputProcessor(BaseObjectOutputProcessor[OutputDataT]):
+ wrapped: BaseObjectOutputProcessor[OutputDataT]
+
+ def __init__(self, wrapped: BaseObjectOutputProcessor[OutputDataT]):
+ self.wrapped = wrapped
+ super().__init__(object_def=wrapped.object_def)
+
+ async def process(
+ self,
+ data: str,
+ run_context: RunContext[AgentDepsT],
+ allow_partial: bool = False,
+ wrap_validation_errors: bool = True,
+ ) -> OutputDataT:
+ text = _utils.strip_markdown_fences(data)
+
+ return await self.wrapped.process(
+ text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+ )
+
+
+ @dataclass(init=False)
+ class ObjectOutputProcessor(BaseObjectOutputProcessor[OutputDataT]):
  outer_typed_dict_key: str | None = None
  validator: SchemaValidator
  _function_schema: _function_schema.FunctionSchema | None = None
@@ -673,11 +705,13 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
  else:
  description = f'{description}. {json_schema_description}'

- self.object_def = OutputObjectDefinition(
- name=name or getattr(output, '__name__', None),
- description=description,
- json_schema=json_schema,
- strict=strict,
+ super().__init__(
+ object_def=OutputObjectDefinition(
+ name=name or getattr(output, '__name__', None),
+ description=description,
+ json_schema=json_schema,
+ strict=strict,
+ )
  )

  async def process(
@@ -726,10 +760,10 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):

  async def call(
  self,
- output: Any,
+ output: dict[str, Any],
  run_context: RunContext[AgentDepsT],
  wrap_validation_errors: bool = True,
- ):
+ ) -> Any:
  if k := self.outer_typed_dict_key:
  output = output[k]

@@ -753,8 +787,7 @@ class UnionOutputModel:


  @dataclass(init=False)
- class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
- object_def: OutputObjectDefinition
+ class UnionOutputProcessor(BaseObjectOutputProcessor[OutputDataT]):
  _union_processor: ObjectOutputProcessor[UnionOutputModel]
  _processors: dict[str, ObjectOutputProcessor[OutputDataT]]

@@ -830,16 +863,18 @@ class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
  if all_defs:
  json_schema['$defs'] = all_defs

- self.object_def = OutputObjectDefinition(
- json_schema=json_schema,
- strict=strict,
- name=name,
- description=description,
+ super().__init__(
+ object_def=OutputObjectDefinition(
+ json_schema=json_schema,
+ strict=strict,
+ name=name,
+ description=description,
+ )
  )

  async def process(
  self,
- data: str | dict[str, Any] | None,
+ data: str,
  run_context: RunContext[AgentDepsT],
  allow_partial: bool = False,
  wrap_validation_errors: bool = True,
@@ -850,7 +885,7 @@ class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):

  result = union_object.result
  kind = result.kind
- data = result.data
+ inner_data = result.data
  try:
  processor = self._processors[kind]
  except KeyError as e: # pragma: no cover
@@ -861,12 +896,23 @@ class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
  raise

  return await processor.process(
- data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+ inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
  )


+ class TextOutputProcessor(BaseOutputProcessor[OutputDataT]):
+ async def process(
+ self,
+ data: str,
+ run_context: RunContext[AgentDepsT],
+ allow_partial: bool = False,
+ wrap_validation_errors: bool = True,
+ ) -> OutputDataT:
+ return cast(OutputDataT, data)
+
+
  @dataclass(init=False)
- class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
+ class TextFunctionOutputProcessor(TextOutputProcessor[OutputDataT]):
  _function_schema: _function_schema.FunctionSchema
  _str_argument_name: str

@@ -876,17 +922,15 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
  ):
  self._function_schema = _function_schema.function_schema(output_function, GenerateToolJsonSchema)

- arguments_schema = self._function_schema.json_schema.get('properties', {})
- argument_name = next(iter(arguments_schema.keys()), None)
- if argument_name and arguments_schema.get(argument_name, {}).get('type') == 'string':
- self._str_argument_name = argument_name
- return
+ if (
+ not (arguments_schema := self._function_schema.json_schema.get('properties', {}))
+ or len(arguments_schema) != 1
+ or not (argument_name := next(iter(arguments_schema.keys()), None))
+ or arguments_schema.get(argument_name, {}).get('type') != 'string'
+ ):
+ raise UserError('TextOutput must take a function taking a single `str` argument')

- raise UserError('TextOutput must take a function taking a `str`')
-
- @property
- def object_def(self) -> None:
- return None # pragma: no cover
+ self._str_argument_name = argument_name

  async def process(
  self,
@@ -896,9 +940,9 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
  wrap_validation_errors: bool = True,
  ) -> OutputDataT:
  args = {self._str_argument_name: data}
- output = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)
+ data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)

- return cast(OutputDataT, output)
+ return await super().process(data, run_context, allow_partial, wrap_validation_errors)


  @dataclass(init=False)
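The UserError messages added to OutputSchema.build spell out how the new markers are meant to be combined: DeferredToolRequests and BinaryImage go alongside a NativeOutput or PromptedOutput marker in the output_type list rather than inside it, and BinaryImage on its own maps to the new ImageOutputSchema ('image' mode). A minimal usage sketch under those assumptions follows; the import paths (Agent, NativeOutput, DeferredToolRequests from pydantic_ai; BinaryImage from pydantic_ai.messages), the Report model, and the model id are illustrative assumptions, not taken from this diff.

from pydantic import BaseModel

from pydantic_ai import Agent, DeferredToolRequests, NativeOutput
from pydantic_ai.messages import BinaryImage


class Report(BaseModel):
    # Hypothetical structured output type used only for this sketch.
    title: str
    summary: str


# The markers sit alongside the NativeOutput marker in the output_type list;
# OutputSchema.build strips them out before checking that NativeOutput is the
# only remaining output type.
agent = Agent(
    'openai:gpt-4o',  # example model id only
    output_type=[NativeOutput(Report), DeferredToolRequests, BinaryImage],
)

# Nesting a marker inside NativeOutput instead raises UserError when the schema is built:
# NativeOutput([Report, DeferredToolRequests])
# -> '`NativeOutput` cannot contain `DeferredToolRequests`. Include it alongside the native output marker instead: ...'

Because build() removes DeferredToolRequests and BinaryImage from the flattened output list before the single-output checks on NativeOutput and PromptedOutput run, the markers are only accepted at the top level of output_type.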