pydantic-ai-slim 1.0.13__py3-none-any.whl → 1.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +19 -1
- pydantic_ai/_agent_graph.py +118 -97
- pydantic_ai/_cli.py +4 -7
- pydantic_ai/_output.py +236 -192
- pydantic_ai/_parts_manager.py +8 -42
- pydantic_ai/_tool_manager.py +9 -16
- pydantic_ai/agent/abstract.py +169 -1
- pydantic_ai/builtin_tools.py +82 -0
- pydantic_ai/direct.py +7 -0
- pydantic_ai/durable_exec/dbos/_agent.py +106 -3
- pydantic_ai/durable_exec/temporal/_agent.py +123 -6
- pydantic_ai/durable_exec/temporal/_model.py +8 -0
- pydantic_ai/format_prompt.py +4 -3
- pydantic_ai/mcp.py +20 -10
- pydantic_ai/messages.py +149 -3
- pydantic_ai/models/__init__.py +15 -1
- pydantic_ai/models/anthropic.py +7 -3
- pydantic_ai/models/cohere.py +4 -0
- pydantic_ai/models/function.py +7 -4
- pydantic_ai/models/gemini.py +8 -0
- pydantic_ai/models/google.py +56 -23
- pydantic_ai/models/groq.py +11 -5
- pydantic_ai/models/huggingface.py +5 -3
- pydantic_ai/models/mistral.py +6 -8
- pydantic_ai/models/openai.py +197 -58
- pydantic_ai/models/test.py +4 -0
- pydantic_ai/output.py +5 -2
- pydantic_ai/profiles/__init__.py +2 -0
- pydantic_ai/profiles/google.py +5 -2
- pydantic_ai/profiles/openai.py +2 -1
- pydantic_ai/result.py +46 -30
- pydantic_ai/run.py +35 -7
- pydantic_ai/usage.py +5 -4
- {pydantic_ai_slim-1.0.13.dist-info → pydantic_ai_slim-1.0.15.dist-info}/METADATA +3 -3
- {pydantic_ai_slim-1.0.13.dist-info → pydantic_ai_slim-1.0.15.dist-info}/RECORD +38 -38
- {pydantic_ai_slim-1.0.13.dist-info → pydantic_ai_slim-1.0.15.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.13.dist-info → pydantic_ai_slim-1.0.15.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.13.dist-info → pydantic_ai_slim-1.0.15.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_output.py
CHANGED
|
@@ -75,7 +75,7 @@ DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversati
|
|
|
75
75
|
async def execute_traced_output_function(
|
|
76
76
|
function_schema: _function_schema.FunctionSchema,
|
|
77
77
|
run_context: RunContext[AgentDepsT],
|
|
78
|
-
args: dict[str, Any]
|
|
78
|
+
args: dict[str, Any],
|
|
79
79
|
wrap_validation_errors: bool = True,
|
|
80
80
|
) -> Any:
|
|
81
81
|
"""Execute an output function within a traced span with error handling.
|
|
@@ -209,18 +209,20 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
|
|
|
209
209
|
return result_data
|
|
210
210
|
|
|
211
211
|
|
|
212
|
-
@dataclass
|
|
212
|
+
@dataclass(kw_only=True)
|
|
213
213
|
class BaseOutputSchema(ABC, Generic[OutputDataT]):
|
|
214
|
-
|
|
214
|
+
text_processor: BaseOutputProcessor[OutputDataT] | None = None
|
|
215
|
+
toolset: OutputToolset[Any] | None = None
|
|
216
|
+
allows_deferred_tools: bool = False
|
|
217
|
+
allows_image: bool = False
|
|
215
218
|
|
|
216
219
|
@abstractmethod
|
|
217
220
|
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
218
221
|
raise NotImplementedError()
|
|
219
222
|
|
|
220
223
|
@property
|
|
221
|
-
def
|
|
222
|
-
|
|
223
|
-
return None
|
|
224
|
+
def allows_text(self) -> bool:
|
|
225
|
+
return self.text_processor is not None
|
|
224
226
|
|
|
225
227
|
|
|
226
228
|
@dataclass(init=False)
|
|
@@ -262,38 +264,67 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
262
264
|
strict: bool | None = None,
|
|
263
265
|
) -> BaseOutputSchema[OutputDataT]:
|
|
264
266
|
"""Build an OutputSchema dataclass from an output type."""
|
|
265
|
-
|
|
267
|
+
outputs = _flatten_output_spec(output_spec)
|
|
268
|
+
|
|
269
|
+
allows_deferred_tools = DeferredToolRequests in outputs
|
|
270
|
+
if allows_deferred_tools:
|
|
271
|
+
outputs = [output for output in outputs if output is not DeferredToolRequests]
|
|
272
|
+
if len(outputs) == 0:
|
|
273
|
+
raise UserError('At least one output type must be provided other than `DeferredToolRequests`.')
|
|
266
274
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
raise UserError('At least one output type must be provided other than `DeferredToolRequests`.')
|
|
275
|
+
allows_image = _messages.BinaryImage in outputs
|
|
276
|
+
if allows_image:
|
|
277
|
+
outputs = [output for output in outputs if output is not _messages.BinaryImage]
|
|
271
278
|
|
|
272
279
|
if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
|
|
273
280
|
if len(outputs) > 1:
|
|
274
281
|
raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover
|
|
275
282
|
|
|
283
|
+
flattened_outputs = _flatten_output_spec(output.outputs)
|
|
284
|
+
|
|
285
|
+
if DeferredToolRequests in flattened_outputs:
|
|
286
|
+
raise UserError( # pragma: no cover
|
|
287
|
+
'`NativeOutput` cannot contain `DeferredToolRequests`. Include it alongside the native output marker instead: `output_type=[NativeOutput(...), DeferredToolRequests]`'
|
|
288
|
+
)
|
|
289
|
+
if _messages.BinaryImage in flattened_outputs:
|
|
290
|
+
raise UserError( # pragma: no cover
|
|
291
|
+
'`NativeOutput` cannot contain `BinaryImage`. Include it alongside the native output marker instead: `output_type=[NativeOutput(...), BinaryImage]`'
|
|
292
|
+
)
|
|
293
|
+
|
|
276
294
|
return NativeOutputSchema(
|
|
277
295
|
processor=cls._build_processor(
|
|
278
|
-
|
|
296
|
+
flattened_outputs,
|
|
279
297
|
name=output.name,
|
|
280
298
|
description=output.description,
|
|
281
299
|
strict=output.strict,
|
|
282
300
|
),
|
|
283
301
|
allows_deferred_tools=allows_deferred_tools,
|
|
302
|
+
allows_image=allows_image,
|
|
284
303
|
)
|
|
285
304
|
elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
|
|
286
305
|
if len(outputs) > 1:
|
|
287
306
|
raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover
|
|
288
307
|
|
|
308
|
+
flattened_outputs = _flatten_output_spec(output.outputs)
|
|
309
|
+
|
|
310
|
+
if DeferredToolRequests in flattened_outputs:
|
|
311
|
+
raise UserError( # pragma: no cover
|
|
312
|
+
'`PromptedOutput` cannot contain `DeferredToolRequests`. Include it alongside the prompted output marker instead: `output_type=[PromptedOutput(...), DeferredToolRequests]`'
|
|
313
|
+
)
|
|
314
|
+
if _messages.BinaryImage in flattened_outputs:
|
|
315
|
+
raise UserError( # pragma: no cover
|
|
316
|
+
'`PromptedOutput` cannot contain `BinaryImage`. Include it alongside the prompted output marker instead: `output_type=[PromptedOutput(...), BinaryImage]`'
|
|
317
|
+
)
|
|
318
|
+
|
|
289
319
|
return PromptedOutputSchema(
|
|
320
|
+
template=output.template,
|
|
290
321
|
processor=cls._build_processor(
|
|
291
|
-
|
|
322
|
+
flattened_outputs,
|
|
292
323
|
name=output.name,
|
|
293
324
|
description=output.description,
|
|
294
325
|
),
|
|
295
|
-
template=output.template,
|
|
296
326
|
allows_deferred_tools=allows_deferred_tools,
|
|
327
|
+
allows_image=allows_image,
|
|
297
328
|
)
|
|
298
329
|
|
|
299
330
|
text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
|
|
@@ -317,37 +348,51 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
317
348
|
|
|
318
349
|
toolset = OutputToolset.build(tool_outputs + other_outputs, name=name, description=description, strict=strict)
|
|
319
350
|
|
|
351
|
+
text_processor: BaseOutputProcessor[OutputDataT] | None = None
|
|
352
|
+
|
|
320
353
|
if len(text_outputs) > 0:
|
|
321
354
|
if len(text_outputs) > 1:
|
|
322
355
|
raise UserError('Only one `str` or `TextOutput` is allowed.')
|
|
323
356
|
text_output = text_outputs[0]
|
|
324
357
|
|
|
325
|
-
text_output_schema = None
|
|
326
358
|
if isinstance(text_output, TextOutput):
|
|
327
|
-
|
|
359
|
+
text_processor = TextFunctionOutputProcessor(text_output.output_function)
|
|
360
|
+
else:
|
|
361
|
+
text_processor = TextOutputProcessor()
|
|
328
362
|
|
|
329
363
|
if toolset:
|
|
330
|
-
return
|
|
331
|
-
processor=text_output_schema,
|
|
364
|
+
return ToolOutputSchema(
|
|
332
365
|
toolset=toolset,
|
|
366
|
+
text_processor=text_processor,
|
|
333
367
|
allows_deferred_tools=allows_deferred_tools,
|
|
368
|
+
allows_image=allows_image,
|
|
334
369
|
)
|
|
335
370
|
else:
|
|
336
|
-
return
|
|
371
|
+
return TextOutputSchema(
|
|
372
|
+
text_processor=text_processor,
|
|
373
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
374
|
+
allows_image=allows_image,
|
|
375
|
+
)
|
|
337
376
|
|
|
338
377
|
if len(tool_outputs) > 0:
|
|
339
|
-
return ToolOutputSchema(
|
|
378
|
+
return ToolOutputSchema(
|
|
379
|
+
toolset=toolset, allows_deferred_tools=allows_deferred_tools, allows_image=allows_image
|
|
380
|
+
)
|
|
340
381
|
|
|
341
382
|
if len(other_outputs) > 0:
|
|
342
383
|
schema = OutputSchemaWithoutMode(
|
|
343
384
|
processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
|
|
344
385
|
toolset=toolset,
|
|
345
386
|
allows_deferred_tools=allows_deferred_tools,
|
|
387
|
+
allows_image=allows_image,
|
|
346
388
|
)
|
|
347
389
|
if default_mode:
|
|
348
390
|
schema = schema.with_default_mode(default_mode)
|
|
349
391
|
return schema
|
|
350
392
|
|
|
393
|
+
if allows_image:
|
|
394
|
+
return ImageOutputSchema(allows_deferred_tools=allows_deferred_tools)
|
|
395
|
+
|
|
351
396
|
raise UserError('At least one output type must be provided.')
|
|
352
397
|
|
|
353
398
|
@staticmethod
|
|
@@ -356,7 +401,7 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
356
401
|
name: str | None = None,
|
|
357
402
|
description: str | None = None,
|
|
358
403
|
strict: bool | None = None,
|
|
359
|
-
) ->
|
|
404
|
+
) -> BaseObjectOutputProcessor[OutputDataT]:
|
|
360
405
|
outputs = _flatten_output_spec(outputs)
|
|
361
406
|
if len(outputs) == 1:
|
|
362
407
|
return ObjectOutputProcessor(output=outputs[0], name=name, description=description, strict=strict)
|
|
@@ -368,10 +413,10 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
368
413
|
def mode(self) -> OutputMode:
|
|
369
414
|
raise NotImplementedError()
|
|
370
415
|
|
|
371
|
-
@abstractmethod
|
|
372
416
|
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
373
|
-
"""Raise an error if the mode is not supported by
|
|
374
|
-
|
|
417
|
+
"""Raise an error if the mode is not supported by this model."""
|
|
418
|
+
if self.allows_image and not profile.supports_image_output:
|
|
419
|
+
raise UserError('Image output is not supported by this model.')
|
|
375
420
|
|
|
376
421
|
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
377
422
|
return self
|
|
@@ -379,142 +424,139 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
379
424
|
|
|
380
425
|
@dataclass(init=False)
|
|
381
426
|
class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
|
|
382
|
-
processor:
|
|
383
|
-
_toolset: OutputToolset[Any] | None
|
|
427
|
+
processor: BaseObjectOutputProcessor[OutputDataT]
|
|
384
428
|
|
|
385
429
|
def __init__(
|
|
386
430
|
self,
|
|
387
|
-
processor:
|
|
431
|
+
processor: BaseObjectOutputProcessor[OutputDataT],
|
|
388
432
|
toolset: OutputToolset[Any] | None,
|
|
389
433
|
allows_deferred_tools: bool,
|
|
434
|
+
allows_image: bool,
|
|
390
435
|
):
|
|
391
|
-
|
|
436
|
+
# We set a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor.
|
|
437
|
+
# At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time,
|
|
438
|
+
# but we cover ourselves just in case we end up using the tool output mode.
|
|
439
|
+
super().__init__(
|
|
440
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
441
|
+
toolset=toolset,
|
|
442
|
+
text_processor=processor,
|
|
443
|
+
allows_image=allows_image,
|
|
444
|
+
)
|
|
392
445
|
self.processor = processor
|
|
393
|
-
self._toolset = toolset
|
|
394
446
|
|
|
395
447
|
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
396
448
|
if mode == 'native':
|
|
397
|
-
return NativeOutputSchema(
|
|
449
|
+
return NativeOutputSchema(
|
|
450
|
+
processor=self.processor,
|
|
451
|
+
allows_deferred_tools=self.allows_deferred_tools,
|
|
452
|
+
allows_image=self.allows_image,
|
|
453
|
+
)
|
|
398
454
|
elif mode == 'prompted':
|
|
399
|
-
return PromptedOutputSchema(
|
|
455
|
+
return PromptedOutputSchema(
|
|
456
|
+
processor=self.processor,
|
|
457
|
+
allows_deferred_tools=self.allows_deferred_tools,
|
|
458
|
+
allows_image=self.allows_image,
|
|
459
|
+
)
|
|
400
460
|
elif mode == 'tool':
|
|
401
|
-
return ToolOutputSchema(
|
|
461
|
+
return ToolOutputSchema(
|
|
462
|
+
toolset=self.toolset, allows_deferred_tools=self.allows_deferred_tools, allows_image=self.allows_image
|
|
463
|
+
)
|
|
402
464
|
else:
|
|
403
465
|
assert_never(mode)
|
|
404
466
|
|
|
405
|
-
@property
|
|
406
|
-
def toolset(self) -> OutputToolset[Any] | None:
|
|
407
|
-
"""Get the toolset for this output schema."""
|
|
408
|
-
# We return a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor.
|
|
409
|
-
# At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time,
|
|
410
|
-
# but we cover ourselves just in case we end up using the tool output mode.
|
|
411
|
-
return self._toolset
|
|
412
|
-
|
|
413
467
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
468
|
+
@dataclass(init=False)
|
|
469
|
+
class TextOutputSchema(OutputSchema[OutputDataT]):
|
|
470
|
+
def __init__(
|
|
417
471
|
self,
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
)
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
processor: PlainTextOutputProcessor[OutputDataT] | None = None
|
|
472
|
+
*,
|
|
473
|
+
text_processor: TextOutputProcessor[OutputDataT],
|
|
474
|
+
allows_deferred_tools: bool,
|
|
475
|
+
allows_image: bool,
|
|
476
|
+
):
|
|
477
|
+
super().__init__(
|
|
478
|
+
text_processor=text_processor,
|
|
479
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
480
|
+
allows_image=allows_image,
|
|
481
|
+
)
|
|
429
482
|
|
|
430
483
|
@property
|
|
431
484
|
def mode(self) -> OutputMode:
|
|
432
485
|
return 'text'
|
|
433
486
|
|
|
434
487
|
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
435
|
-
"""Raise an error if the mode is not supported by
|
|
436
|
-
|
|
488
|
+
"""Raise an error if the mode is not supported by this model."""
|
|
489
|
+
super().raise_if_unsupported(profile)
|
|
437
490
|
|
|
438
|
-
async def process(
|
|
439
|
-
self,
|
|
440
|
-
text: str,
|
|
441
|
-
run_context: RunContext[AgentDepsT],
|
|
442
|
-
allow_partial: bool = False,
|
|
443
|
-
wrap_validation_errors: bool = True,
|
|
444
|
-
) -> OutputDataT:
|
|
445
|
-
"""Validate an output message.
|
|
446
491
|
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
allow_partial: If true, allow partial validation.
|
|
451
|
-
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
492
|
+
class ImageOutputSchema(OutputSchema[OutputDataT]):
|
|
493
|
+
def __init__(self, *, allows_deferred_tools: bool):
|
|
494
|
+
super().__init__(allows_deferred_tools=allows_deferred_tools, allows_image=True)
|
|
452
495
|
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
if self.processor is None:
|
|
457
|
-
return cast(OutputDataT, text)
|
|
496
|
+
@property
|
|
497
|
+
def mode(self) -> OutputMode:
|
|
498
|
+
return 'image'
|
|
458
499
|
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
500
|
+
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
501
|
+
"""Raise an error if the mode is not supported by this model."""
|
|
502
|
+
# This already raises if image output is not supported by this model.
|
|
503
|
+
super().raise_if_unsupported(profile)
|
|
462
504
|
|
|
463
505
|
|
|
464
|
-
@dataclass
|
|
465
|
-
class StructuredTextOutputSchema(
|
|
466
|
-
processor:
|
|
506
|
+
@dataclass(init=False)
|
|
507
|
+
class StructuredTextOutputSchema(OutputSchema[OutputDataT], ABC):
|
|
508
|
+
processor: BaseObjectOutputProcessor[OutputDataT]
|
|
509
|
+
|
|
510
|
+
def __init__(
|
|
511
|
+
self, *, processor: BaseObjectOutputProcessor[OutputDataT], allows_deferred_tools: bool, allows_image: bool
|
|
512
|
+
):
|
|
513
|
+
super().__init__(
|
|
514
|
+
text_processor=processor, allows_deferred_tools=allows_deferred_tools, allows_image=allows_image
|
|
515
|
+
)
|
|
516
|
+
self.processor = processor
|
|
467
517
|
|
|
468
518
|
@property
|
|
469
519
|
def object_def(self) -> OutputObjectDefinition:
|
|
470
520
|
return self.processor.object_def
|
|
471
521
|
|
|
472
522
|
|
|
473
|
-
@dataclass
|
|
474
523
|
class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
475
524
|
@property
|
|
476
525
|
def mode(self) -> OutputMode:
|
|
477
526
|
return 'native'
|
|
478
527
|
|
|
479
528
|
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
480
|
-
"""Raise an error if the mode is not supported by
|
|
529
|
+
"""Raise an error if the mode is not supported by this model."""
|
|
481
530
|
if not profile.supports_json_schema_output:
|
|
482
|
-
raise UserError('Native structured output is not supported by
|
|
531
|
+
raise UserError('Native structured output is not supported by this model.')
|
|
483
532
|
|
|
484
|
-
async def process(
|
|
485
|
-
self,
|
|
486
|
-
text: str,
|
|
487
|
-
run_context: RunContext[AgentDepsT],
|
|
488
|
-
allow_partial: bool = False,
|
|
489
|
-
wrap_validation_errors: bool = True,
|
|
490
|
-
) -> OutputDataT:
|
|
491
|
-
"""Validate an output message.
|
|
492
533
|
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
allow_partial: If true, allow partial validation.
|
|
497
|
-
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
534
|
+
@dataclass(init=False)
|
|
535
|
+
class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
536
|
+
template: str | None
|
|
498
537
|
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
538
|
+
def __init__(
|
|
539
|
+
self,
|
|
540
|
+
*,
|
|
541
|
+
template: str | None = None,
|
|
542
|
+
processor: BaseObjectOutputProcessor[OutputDataT],
|
|
543
|
+
allows_deferred_tools: bool,
|
|
544
|
+
allows_image: bool,
|
|
545
|
+
):
|
|
546
|
+
super().__init__(
|
|
547
|
+
processor=PromptedOutputProcessor(processor),
|
|
548
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
549
|
+
allows_image=allows_image,
|
|
504
550
|
)
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
@dataclass
|
|
508
|
-
class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
509
|
-
template: str | None = None
|
|
551
|
+
self.template = template
|
|
510
552
|
|
|
511
553
|
@property
|
|
512
554
|
def mode(self) -> OutputMode:
|
|
513
555
|
return 'prompted'
|
|
514
556
|
|
|
515
557
|
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
516
|
-
"""Raise an error if the mode is not supported by
|
|
517
|
-
|
|
558
|
+
"""Raise an error if the mode is not supported by this model."""
|
|
559
|
+
super().raise_if_unsupported(profile)
|
|
518
560
|
|
|
519
561
|
def instructions(self, default_template: str) -> str:
|
|
520
562
|
"""Get instructions to tell model to output JSON matching the schema."""
|
|
@@ -532,71 +574,35 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
|
532
574
|
|
|
533
575
|
return template.format(schema=json.dumps(schema))
|
|
534
576
|
|
|
535
|
-
async def process(
|
|
536
|
-
self,
|
|
537
|
-
text: str,
|
|
538
|
-
run_context: RunContext[AgentDepsT],
|
|
539
|
-
allow_partial: bool = False,
|
|
540
|
-
wrap_validation_errors: bool = True,
|
|
541
|
-
) -> OutputDataT:
|
|
542
|
-
"""Validate an output message.
|
|
543
|
-
|
|
544
|
-
Args:
|
|
545
|
-
text: The output text to validate.
|
|
546
|
-
run_context: The current run context.
|
|
547
|
-
allow_partial: If true, allow partial validation.
|
|
548
|
-
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
549
|
-
|
|
550
|
-
Returns:
|
|
551
|
-
Either the validated output data (left) or a retry message (right).
|
|
552
|
-
"""
|
|
553
|
-
text = _utils.strip_markdown_fences(text)
|
|
554
|
-
|
|
555
|
-
return await self.processor.process(
|
|
556
|
-
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
|
|
557
|
-
)
|
|
558
|
-
|
|
559
577
|
|
|
560
578
|
@dataclass(init=False)
|
|
561
579
|
class ToolOutputSchema(OutputSchema[OutputDataT]):
|
|
562
|
-
_toolset: OutputToolset[Any] | None
|
|
563
|
-
|
|
564
|
-
def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tools: bool):
|
|
565
|
-
super().__init__(allows_deferred_tools)
|
|
566
|
-
self._toolset = toolset
|
|
567
|
-
|
|
568
|
-
@property
|
|
569
|
-
def mode(self) -> OutputMode:
|
|
570
|
-
return 'tool'
|
|
571
|
-
|
|
572
|
-
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
573
|
-
"""Raise an error if the mode is not supported by the model."""
|
|
574
|
-
if not profile.supports_tools:
|
|
575
|
-
raise UserError('Output tools are not supported by the model.')
|
|
576
|
-
|
|
577
|
-
@property
|
|
578
|
-
def toolset(self) -> OutputToolset[Any] | None:
|
|
579
|
-
"""Get the toolset for this output schema."""
|
|
580
|
-
return self._toolset
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
@dataclass(init=False)
|
|
584
|
-
class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchema[OutputDataT]):
|
|
585
580
|
def __init__(
|
|
586
581
|
self,
|
|
587
|
-
|
|
582
|
+
*,
|
|
588
583
|
toolset: OutputToolset[Any] | None,
|
|
584
|
+
text_processor: BaseOutputProcessor[OutputDataT] | None = None,
|
|
589
585
|
allows_deferred_tools: bool,
|
|
586
|
+
allows_image: bool,
|
|
590
587
|
):
|
|
591
|
-
super().__init__(
|
|
592
|
-
|
|
588
|
+
super().__init__(
|
|
589
|
+
toolset=toolset,
|
|
590
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
591
|
+
text_processor=text_processor,
|
|
592
|
+
allows_image=allows_image,
|
|
593
|
+
)
|
|
593
594
|
|
|
594
595
|
@property
|
|
595
596
|
def mode(self) -> OutputMode:
|
|
596
|
-
return '
|
|
597
|
+
return 'tool'
|
|
598
|
+
|
|
599
|
+
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
600
|
+
"""Raise an error if the mode is not supported by this model."""
|
|
601
|
+
super().raise_if_unsupported(profile)
|
|
602
|
+
if not profile.supports_tools:
|
|
603
|
+
raise UserError('Tool output is not supported by this model.')
|
|
597
604
|
|
|
598
605
|
|
|
599
|
-
@dataclass(init=False)
|
|
600
606
|
class BaseOutputProcessor(ABC, Generic[OutputDataT]):
|
|
601
607
|
@abstractmethod
|
|
602
608
|
async def process(
|
|
@@ -610,9 +616,35 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
|
|
|
610
616
|
raise NotImplementedError()
|
|
611
617
|
|
|
612
618
|
|
|
613
|
-
@dataclass(
|
|
614
|
-
class
|
|
619
|
+
@dataclass(kw_only=True)
|
|
620
|
+
class BaseObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
615
621
|
object_def: OutputObjectDefinition
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
@dataclass(init=False)
|
|
625
|
+
class PromptedOutputProcessor(BaseObjectOutputProcessor[OutputDataT]):
|
|
626
|
+
wrapped: BaseObjectOutputProcessor[OutputDataT]
|
|
627
|
+
|
|
628
|
+
def __init__(self, wrapped: BaseObjectOutputProcessor[OutputDataT]):
|
|
629
|
+
self.wrapped = wrapped
|
|
630
|
+
super().__init__(object_def=wrapped.object_def)
|
|
631
|
+
|
|
632
|
+
async def process(
|
|
633
|
+
self,
|
|
634
|
+
data: str,
|
|
635
|
+
run_context: RunContext[AgentDepsT],
|
|
636
|
+
allow_partial: bool = False,
|
|
637
|
+
wrap_validation_errors: bool = True,
|
|
638
|
+
) -> OutputDataT:
|
|
639
|
+
text = _utils.strip_markdown_fences(data)
|
|
640
|
+
|
|
641
|
+
return await self.wrapped.process(
|
|
642
|
+
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
@dataclass(init=False)
|
|
647
|
+
class ObjectOutputProcessor(BaseObjectOutputProcessor[OutputDataT]):
|
|
616
648
|
outer_typed_dict_key: str | None = None
|
|
617
649
|
validator: SchemaValidator
|
|
618
650
|
_function_schema: _function_schema.FunctionSchema | None = None
|
|
@@ -673,11 +705,13 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
673
705
|
else:
|
|
674
706
|
description = f'{description}. {json_schema_description}'
|
|
675
707
|
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
708
|
+
super().__init__(
|
|
709
|
+
object_def=OutputObjectDefinition(
|
|
710
|
+
name=name or getattr(output, '__name__', None),
|
|
711
|
+
description=description,
|
|
712
|
+
json_schema=json_schema,
|
|
713
|
+
strict=strict,
|
|
714
|
+
)
|
|
681
715
|
)
|
|
682
716
|
|
|
683
717
|
async def process(
|
|
@@ -726,10 +760,10 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
726
760
|
|
|
727
761
|
async def call(
|
|
728
762
|
self,
|
|
729
|
-
output: Any,
|
|
763
|
+
output: dict[str, Any],
|
|
730
764
|
run_context: RunContext[AgentDepsT],
|
|
731
765
|
wrap_validation_errors: bool = True,
|
|
732
|
-
):
|
|
766
|
+
) -> Any:
|
|
733
767
|
if k := self.outer_typed_dict_key:
|
|
734
768
|
output = output[k]
|
|
735
769
|
|
|
@@ -753,8 +787,7 @@ class UnionOutputModel:
|
|
|
753
787
|
|
|
754
788
|
|
|
755
789
|
@dataclass(init=False)
|
|
756
|
-
class UnionOutputProcessor(
|
|
757
|
-
object_def: OutputObjectDefinition
|
|
790
|
+
class UnionOutputProcessor(BaseObjectOutputProcessor[OutputDataT]):
|
|
758
791
|
_union_processor: ObjectOutputProcessor[UnionOutputModel]
|
|
759
792
|
_processors: dict[str, ObjectOutputProcessor[OutputDataT]]
|
|
760
793
|
|
|
@@ -830,16 +863,18 @@ class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
830
863
|
if all_defs:
|
|
831
864
|
json_schema['$defs'] = all_defs
|
|
832
865
|
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
866
|
+
super().__init__(
|
|
867
|
+
object_def=OutputObjectDefinition(
|
|
868
|
+
json_schema=json_schema,
|
|
869
|
+
strict=strict,
|
|
870
|
+
name=name,
|
|
871
|
+
description=description,
|
|
872
|
+
)
|
|
838
873
|
)
|
|
839
874
|
|
|
840
875
|
async def process(
|
|
841
876
|
self,
|
|
842
|
-
data: str
|
|
877
|
+
data: str,
|
|
843
878
|
run_context: RunContext[AgentDepsT],
|
|
844
879
|
allow_partial: bool = False,
|
|
845
880
|
wrap_validation_errors: bool = True,
|
|
@@ -850,7 +885,7 @@ class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
850
885
|
|
|
851
886
|
result = union_object.result
|
|
852
887
|
kind = result.kind
|
|
853
|
-
|
|
888
|
+
inner_data = result.data
|
|
854
889
|
try:
|
|
855
890
|
processor = self._processors[kind]
|
|
856
891
|
except KeyError as e: # pragma: no cover
|
|
@@ -861,12 +896,23 @@ class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
861
896
|
raise
|
|
862
897
|
|
|
863
898
|
return await processor.process(
|
|
864
|
-
|
|
899
|
+
inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
|
|
865
900
|
)
|
|
866
901
|
|
|
867
902
|
|
|
903
|
+
class TextOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
904
|
+
async def process(
|
|
905
|
+
self,
|
|
906
|
+
data: str,
|
|
907
|
+
run_context: RunContext[AgentDepsT],
|
|
908
|
+
allow_partial: bool = False,
|
|
909
|
+
wrap_validation_errors: bool = True,
|
|
910
|
+
) -> OutputDataT:
|
|
911
|
+
return cast(OutputDataT, data)
|
|
912
|
+
|
|
913
|
+
|
|
868
914
|
@dataclass(init=False)
|
|
869
|
-
class
|
|
915
|
+
class TextFunctionOutputProcessor(TextOutputProcessor[OutputDataT]):
|
|
870
916
|
_function_schema: _function_schema.FunctionSchema
|
|
871
917
|
_str_argument_name: str
|
|
872
918
|
|
|
@@ -876,17 +922,15 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
876
922
|
):
|
|
877
923
|
self._function_schema = _function_schema.function_schema(output_function, GenerateToolJsonSchema)
|
|
878
924
|
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
925
|
+
if (
|
|
926
|
+
not (arguments_schema := self._function_schema.json_schema.get('properties', {}))
|
|
927
|
+
or len(arguments_schema) != 1
|
|
928
|
+
or not (argument_name := next(iter(arguments_schema.keys()), None))
|
|
929
|
+
or arguments_schema.get(argument_name, {}).get('type') != 'string'
|
|
930
|
+
):
|
|
931
|
+
raise UserError('TextOutput must take a function taking a single `str` argument')
|
|
884
932
|
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
@property
|
|
888
|
-
def object_def(self) -> None:
|
|
889
|
-
return None # pragma: no cover
|
|
933
|
+
self._str_argument_name = argument_name
|
|
890
934
|
|
|
891
935
|
async def process(
|
|
892
936
|
self,
|
|
@@ -896,9 +940,9 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
896
940
|
wrap_validation_errors: bool = True,
|
|
897
941
|
) -> OutputDataT:
|
|
898
942
|
args = {self._str_argument_name: data}
|
|
899
|
-
|
|
943
|
+
data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)
|
|
900
944
|
|
|
901
|
-
return
|
|
945
|
+
return await super().process(data, run_context, allow_partial, wrap_validation_errors)
|
|
902
946
|
|
|
903
947
|
|
|
904
948
|
@dataclass(init=False)
|