pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

This version of pydantic-ai-slim has been flagged as potentially problematic.

Files changed (55)
  1. pydantic_ai/_agent_graph.py +219 -315
  2. pydantic_ai/_cli.py +9 -7
  3. pydantic_ai/_output.py +296 -226
  4. pydantic_ai/_parts_manager.py +2 -2
  5. pydantic_ai/_run_context.py +8 -14
  6. pydantic_ai/_tool_manager.py +190 -0
  7. pydantic_ai/_utils.py +18 -1
  8. pydantic_ai/ag_ui.py +675 -0
  9. pydantic_ai/agent.py +369 -155
  10. pydantic_ai/common_tools/duckduckgo.py +5 -2
  11. pydantic_ai/exceptions.py +14 -2
  12. pydantic_ai/ext/aci.py +12 -3
  13. pydantic_ai/ext/langchain.py +9 -1
  14. pydantic_ai/mcp.py +147 -84
  15. pydantic_ai/messages.py +19 -9
  16. pydantic_ai/models/__init__.py +43 -19
  17. pydantic_ai/models/anthropic.py +2 -2
  18. pydantic_ai/models/bedrock.py +1 -1
  19. pydantic_ai/models/cohere.py +1 -1
  20. pydantic_ai/models/function.py +50 -24
  21. pydantic_ai/models/gemini.py +3 -11
  22. pydantic_ai/models/google.py +3 -12
  23. pydantic_ai/models/groq.py +2 -1
  24. pydantic_ai/models/huggingface.py +463 -0
  25. pydantic_ai/models/instrumented.py +1 -1
  26. pydantic_ai/models/mistral.py +3 -3
  27. pydantic_ai/models/openai.py +5 -5
  28. pydantic_ai/output.py +21 -7
  29. pydantic_ai/profiles/google.py +1 -1
  30. pydantic_ai/profiles/moonshotai.py +8 -0
  31. pydantic_ai/providers/__init__.py +4 -0
  32. pydantic_ai/providers/google.py +2 -2
  33. pydantic_ai/providers/google_vertex.py +10 -5
  34. pydantic_ai/providers/grok.py +13 -1
  35. pydantic_ai/providers/groq.py +2 -0
  36. pydantic_ai/providers/huggingface.py +88 -0
  37. pydantic_ai/result.py +57 -33
  38. pydantic_ai/tools.py +26 -119
  39. pydantic_ai/toolsets/__init__.py +22 -0
  40. pydantic_ai/toolsets/abstract.py +155 -0
  41. pydantic_ai/toolsets/combined.py +88 -0
  42. pydantic_ai/toolsets/deferred.py +38 -0
  43. pydantic_ai/toolsets/filtered.py +24 -0
  44. pydantic_ai/toolsets/function.py +238 -0
  45. pydantic_ai/toolsets/prefixed.py +37 -0
  46. pydantic_ai/toolsets/prepared.py +36 -0
  47. pydantic_ai/toolsets/renamed.py +42 -0
  48. pydantic_ai/toolsets/wrapper.py +37 -0
  49. pydantic_ai/usage.py +14 -8
  50. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
  51. pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
  52. pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
  53. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
  54. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
  55. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_output.py CHANGED
@@ -3,18 +3,19 @@ from __future__ import annotations as _annotations
 import inspect
 import json
 from abc import ABC, abstractmethod
-from collections.abc import Awaitable, Iterable, Iterator, Sequence
+from collections.abc import Awaitable, Sequence
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload
 
 from pydantic import TypeAdapter, ValidationError
-from pydantic_core import SchemaValidator
-from typing_extensions import TypedDict, TypeVar, assert_never
+from pydantic_core import SchemaValidator, to_json
+from typing_extensions import Self, TypedDict, TypeVar, assert_never
 
 from . import _function_schema, _utils, messages as _messages
 from ._run_context import AgentDepsT, RunContext
-from .exceptions import ModelRetry, UserError
+from .exceptions import ModelRetry, ToolRetryError, UserError
 from .output import (
+    DeferredToolCalls,
     NativeOutput,
     OutputDataT,
     OutputMode,
@@ -25,8 +26,10 @@ from .output import (
     TextOutput,
     TextOutputFunc,
     ToolOutput,
+    _OutputSpecItem,  # type: ignore[reportPrivateUsage]
 )
 from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
+from .toolsets.abstract import AbstractToolset, ToolsetTool
 
 if TYPE_CHECKING:
     from .profiles import ModelProfile
@@ -66,12 +69,45 @@ DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
 DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
 
 
-class ToolRetryError(Exception):
-    """Exception used to signal a `ToolRetry` message should be returned to the LLM."""
+async def execute_output_function_with_span(
+    function_schema: _function_schema.FunctionSchema,
+    run_context: RunContext[AgentDepsT],
+    args: dict[str, Any] | Any,
+) -> Any:
+    """Execute a function call within a traced span, automatically recording the response."""
+    # Set up span attributes
+    tool_name = run_context.tool_name or getattr(function_schema.function, '__name__', 'output_function')
+    attributes = {
+        'gen_ai.tool.name': tool_name,
+        'logfire.msg': f'running output function: {tool_name}',
+    }
+    if run_context.tool_call_id:
+        attributes['gen_ai.tool.call.id'] = run_context.tool_call_id
+    if run_context.trace_include_content:
+        attributes['tool_arguments'] = to_json(args).decode()
+        attributes['logfire.json_schema'] = json.dumps(
+            {
+                'type': 'object',
+                'properties': {
+                    'tool_arguments': {'type': 'object'},
+                    'tool_response': {'type': 'object'},
+                },
+            }
+        )
+
+    with run_context.tracer.start_as_current_span('running output function', attributes=attributes) as span:
+        output = await function_schema.call(args, run_context)
 
-    def __init__(self, tool_retry: _messages.RetryPromptPart):
-        self.tool_retry = tool_retry
-        super().__init__()
+        # Record response if content inclusion is enabled
+        if run_context.trace_include_content and span.is_recording():
+            from .models.instrumented import InstrumentedModel
+
+            span.set_attribute(
+                'tool_response',
+                output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)),
+            )
+
+        return output
 
 
 @dataclass
@@ -87,22 +123,21 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
     async def validate(
         self,
         result: T,
-        tool_call: _messages.ToolCallPart | None,
         run_context: RunContext[AgentDepsT],
+        wrap_validation_errors: bool = True,
     ) -> T:
         """Validate a result but calling the function.
 
         Args:
             result: The result data after Pydantic validation the message content.
-            tool_call: The original tool call message, `None` if there was no tool call.
             run_context: The current run context.
+            wrap_validation_errors: If true, wrap the validation errors in a retry message.
 
         Returns:
             Result of either the validated result data (ok) or a retry message (Err).
         """
         if self._takes_ctx:
-            ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None)
-            args = ctx, result
+            args = run_context, result
         else:
             args = (result,)
 
@@ -114,24 +149,32 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
                 function = cast(Callable[[Any], T], self.function)
                 result_data = await _utils.run_in_executor(function, *args)
         except ModelRetry as r:
-            m = _messages.RetryPromptPart(content=r.message)
-            if tool_call is not None:
-                m.tool_name = tool_call.tool_name
-                m.tool_call_id = tool_call.tool_call_id
-            raise ToolRetryError(m) from r
+            if wrap_validation_errors:
+                m = _messages.RetryPromptPart(
+                    content=r.message,
+                    tool_name=run_context.tool_name,
+                )
+                if run_context.tool_call_id:  # pragma: no cover
+                    m.tool_call_id = run_context.tool_call_id
+                raise ToolRetryError(m) from r
+            else:
+                raise r
         else:
             return result_data
 
 
+@dataclass
 class BaseOutputSchema(ABC, Generic[OutputDataT]):
+    allows_deferred_tool_calls: bool
+
     @abstractmethod
     def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
         raise NotImplementedError()
 
     @property
-    def tools(self) -> dict[str, OutputTool[OutputDataT]]:
-        """Get the tools for this output schema."""
-        return {}
+    def toolset(self) -> OutputToolset[Any] | None:
+        """Get the toolset for this output schema."""
+        return None
 
 
 @dataclass(init=False)
@@ -163,7 +206,7 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
     ) -> BaseOutputSchema[OutputDataT]: ...
 
     @classmethod
-    def build(
+    def build(  # noqa: C901
         cls,
         output_spec: OutputSpec[OutputDataT],
         *,
@@ -173,117 +216,93 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
         strict: bool | None = None,
     ) -> BaseOutputSchema[OutputDataT]:
         """Build an OutputSchema dataclass from an output type."""
-        if output_spec is str:
-            return PlainTextOutputSchema()
+        raw_outputs = _flatten_output_spec(output_spec)
+
+        outputs = [output for output in raw_outputs if output is not DeferredToolCalls]
+        allows_deferred_tool_calls = len(outputs) < len(raw_outputs)
+        if len(outputs) == 0 and allows_deferred_tool_calls:
+            raise UserError('At least one output type must be provided other than `DeferredToolCalls`.')
+
+        if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
+            if len(outputs) > 1:
+                raise UserError('`NativeOutput` must be the only output type.')  # pragma: no cover
 
-        if isinstance(output_spec, NativeOutput):
             return NativeOutputSchema(
-                cls._build_processor(
-                    _flatten_output_spec(output_spec.outputs),
-                    name=output_spec.name,
-                    description=output_spec.description,
-                    strict=output_spec.strict,
-                )
+                processor=cls._build_processor(
+                    _flatten_output_spec(output.outputs),
+                    name=output.name,
+                    description=output.description,
+                    strict=output.strict,
+                ),
+                allows_deferred_tool_calls=allows_deferred_tool_calls,
             )
-        elif isinstance(output_spec, PromptedOutput):
+        elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
+            if len(outputs) > 1:
+                raise UserError('`PromptedOutput` must be the only output type.')  # pragma: no cover
+
             return PromptedOutputSchema(
-                cls._build_processor(
-                    _flatten_output_spec(output_spec.outputs),
-                    name=output_spec.name,
-                    description=output_spec.description,
+                processor=cls._build_processor(
+                    _flatten_output_spec(output.outputs),
+                    name=output.name,
+                    description=output.description,
                 ),
-                template=output_spec.template,
+                template=output.template,
+                allows_deferred_tool_calls=allows_deferred_tool_calls,
             )
 
         text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
         tool_outputs: Sequence[ToolOutput[OutputDataT]] = []
         other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = []
-        for output in _flatten_output_spec(output_spec):
+        for output in outputs:
             if output is str:
                 text_outputs.append(cast(type[str], output))
             elif isinstance(output, TextOutput):
                 text_outputs.append(output)
             elif isinstance(output, ToolOutput):
                 tool_outputs.append(output)
+            elif isinstance(output, NativeOutput):
+                # We can never get here because this is checked for above.
+                raise UserError('`NativeOutput` must be the only output type.')  # pragma: no cover
+            elif isinstance(output, PromptedOutput):
+                # We can never get here because this is checked for above.
+                raise UserError('`PromptedOutput` must be the only output type.')  # pragma: no cover
             else:
                 other_outputs.append(output)
 
-        tools = cls._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict)
+        toolset = OutputToolset.build(tool_outputs + other_outputs, name=name, description=description, strict=strict)
 
         if len(text_outputs) > 0:
             if len(text_outputs) > 1:
-                raise UserError('Only one text output is allowed.')
+                raise UserError('Only one `str` or `TextOutput` is allowed.')
             text_output = text_outputs[0]
 
             text_output_schema = None
             if isinstance(text_output, TextOutput):
                 text_output_schema = PlainTextOutputProcessor(text_output.output_function)
 
-            if len(tools) == 0:
-                return PlainTextOutputSchema(text_output_schema)
+            if toolset:
+                return ToolOrTextOutputSchema(
+                    processor=text_output_schema, toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls
+                )
             else:
-                return ToolOrTextOutputSchema(processor=text_output_schema, tools=tools)
+                return PlainTextOutputSchema(
+                    processor=text_output_schema, allows_deferred_tool_calls=allows_deferred_tool_calls
+                )
 
         if len(tool_outputs) > 0:
-            return ToolOutputSchema(tools)
+            return ToolOutputSchema(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls)
 
         if len(other_outputs) > 0:
             schema = OutputSchemaWithoutMode(
                 processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
-                tools=tools,
+                toolset=toolset,
+                allows_deferred_tool_calls=allows_deferred_tool_calls,
             )
             if default_mode:
                 schema = schema.with_default_mode(default_mode)
             return schema
 
-        raise UserError('No output type provided.')  # pragma: no cover
-
-    @staticmethod
-    def _build_tools(
-        outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
-        name: str | None = None,
-        description: str | None = None,
-        strict: bool | None = None,
-    ) -> dict[str, OutputTool[OutputDataT]]:
-        tools: dict[str, OutputTool[OutputDataT]] = {}
-
-        default_name = name or DEFAULT_OUTPUT_TOOL_NAME
-        default_description = description
-        default_strict = strict
-
-        multiple = len(outputs) > 1
-        for output in outputs:
-            name = None
-            description = None
-            strict = None
-            if isinstance(output, ToolOutput):
-                # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
-                name = output.name
-                description = output.description
-                strict = output.strict
-
-                output = output.output
-
-            description = description or default_description
-            if strict is None:
-                strict = default_strict
-
-            processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
-
-            if name is None:
-                name = default_name
-                if multiple:
-                    name += f'_{processor.object_def.name}'
-
-            i = 1
-            original_name = name
-            while name in tools:
-                i += 1
-                name = f'{original_name}_{i}'
-
-            tools[name] = OutputTool(name=name, processor=processor, multiple=multiple)
-
-        return tools
+        raise UserError('At least one output type must be provided.')
 
     @staticmethod
     def _build_processor(
@@ -315,32 +334,39 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
 @dataclass(init=False)
 class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
     processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]
-    _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict)
+    _toolset: OutputToolset[Any] | None
 
     def __init__(
         self,
         processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT],
-        tools: dict[str, OutputTool[OutputDataT]],
+        toolset: OutputToolset[Any] | None,
+        allows_deferred_tool_calls: bool,
     ):
+        super().__init__(allows_deferred_tool_calls)
         self.processor = processor
-        self._tools = tools
+        self._toolset = toolset
 
     def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
         if mode == 'native':
-            return NativeOutputSchema(self.processor)
+            return NativeOutputSchema(
+                processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
+            )
         elif mode == 'prompted':
-            return PromptedOutputSchema(self.processor)
+            return PromptedOutputSchema(
+                processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
+            )
         elif mode == 'tool':
-            return ToolOutputSchema(self.tools)
+            return ToolOutputSchema(toolset=self.toolset, allows_deferred_tool_calls=self.allows_deferred_tool_calls)
         else:
             assert_never(mode)
 
     @property
-    def tools(self) -> dict[str, OutputTool[OutputDataT]]:
-        """Get the tools for this output schema."""
-        # We return tools here as they're checked in Agent._register_tool.
-        # At that point we may don't know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time.
-        return self._tools
+    def toolset(self) -> OutputToolset[Any] | None:
+        """Get the toolset for this output schema."""
+        # We return a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor.
+        # At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time,
+        # but we cover ourselves just in case we end up using the tool output mode.
+        return self._toolset
 
 
 class TextOutputSchema(OutputSchema[OutputDataT], ABC):
@@ -411,7 +437,7 @@ class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
     def raise_if_unsupported(self, profile: ModelProfile) -> None:
         """Raise an error if the mode is not supported by the model."""
         if not profile.supports_json_schema_output:
-            raise UserError('Structured output is not supported by the model.')
+            raise UserError('Native structured output is not supported by the model.')
 
     async def process(
         self,
@@ -491,10 +517,11 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
 
 @dataclass(init=False)
 class ToolOutputSchema(OutputSchema[OutputDataT]):
-    _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict)
+    _toolset: OutputToolset[Any] | None
 
-    def __init__(self, tools: dict[str, OutputTool[OutputDataT]]):
-        self._tools = tools
+    def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_calls: bool):
+        super().__init__(allows_deferred_tool_calls)
+        self._toolset = toolset
 
     @property
     def mode(self) -> OutputMode:
@@ -506,36 +533,9 @@ class ToolOutputSchema(OutputSchema[OutputDataT]):
             raise UserError('Output tools are not supported by the model.')
 
     @property
-    def tools(self) -> dict[str, OutputTool[OutputDataT]]:
-        """Get the tools for this output schema."""
-        return self._tools
-
-    def tool_names(self) -> list[str]:
-        """Return the names of the tools."""
-        return list(self.tools.keys())
-
-    def tool_defs(self) -> list[ToolDefinition]:
-        """Get tool definitions to register with the model."""
-        return [t.tool_def for t in self.tools.values()]
-
-    def find_named_tool(
-        self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
-    ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None:
-        """Find a tool that matches one of the calls, with a specific name."""
-        for part in parts:  # pragma: no branch
-            if isinstance(part, _messages.ToolCallPart):  # pragma: no branch
-                if part.tool_name == tool_name:
-                    return part, self.tools[tool_name]
-
-    def find_tool(
-        self,
-        parts: Iterable[_messages.ModelResponsePart],
-    ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]:
-        """Find a tool that matches one of the calls."""
-        for part in parts:
-            if isinstance(part, _messages.ToolCallPart):  # pragma: no branch
-                if result := self.tools.get(part.tool_name):
-                    yield part, result
+    def toolset(self) -> OutputToolset[Any] | None:
+        """Get the toolset for this output schema."""
+        return self._toolset
 
 
 @dataclass(init=False)
@@ -543,10 +543,11 @@ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchem
     def __init__(
         self,
         processor: PlainTextOutputProcessor[OutputDataT] | None,
-        tools: dict[str, OutputTool[OutputDataT]],
+        toolset: OutputToolset[Any] | None,
+        allows_deferred_tool_calls: bool,
     ):
+        super().__init__(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls)
         self.processor = processor
-        self._tools = tools
 
     @property
     def mode(self) -> OutputMode:
@@ -579,7 +580,7 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
 class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
     object_def: OutputObjectDefinition
     outer_typed_dict_key: str | None = None
-    _validator: SchemaValidator
+    validator: SchemaValidator
     _function_schema: _function_schema.FunctionSchema | None = None
 
     def __init__(
@@ -592,7 +593,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
     ):
         if inspect.isfunction(output) or inspect.ismethod(output):
             self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema)
-            self._validator = self._function_schema.validator
+            self.validator = self._function_schema.validator
             json_schema = self._function_schema.json_schema
             json_schema['description'] = self._function_schema.description
         else:
@@ -608,7 +609,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
             type_adapter = TypeAdapter(response_data_typed_dict)
 
             # Really a PluggableSchemaValidator, but it's API-compatible
-            self._validator = cast(SchemaValidator, type_adapter.validator)
+            self.validator = cast(SchemaValidator, type_adapter.validator)
             json_schema = _utils.check_object_json_schema(
                 type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
             )
@@ -652,11 +653,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
             Either the validated output data (left) or a retry message (right).
         """
         try:
-            pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
-            if isinstance(data, str):
-                output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
-            else:
-                output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
+            output = self.validate(data, allow_partial)
         except ValidationError as e:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
@@ -664,22 +661,42 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
                 )
                 raise ToolRetryError(m) from e
             else:
-                raise  # pragma: lax no cover
+                raise
+
+        try:
+            output = await self.call(output, run_context)
+        except ModelRetry as r:
+            if wrap_validation_errors:
+                m = _messages.RetryPromptPart(
+                    content=r.message,
+                )
+                raise ToolRetryError(m) from r
+            else:
+                raise  # pragma: no cover
+
+        return output
+
+    def validate(
+        self,
+        data: str | dict[str, Any] | None,
+        allow_partial: bool = False,
+    ) -> dict[str, Any]:
+        pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
+        if isinstance(data, str):
+            return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
+        else:
+            return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
 
+    async def call(
+        self,
+        output: Any,
+        run_context: RunContext[AgentDepsT],
+    ):
         if k := self.outer_typed_dict_key:
             output = output[k]
 
         if self._function_schema:
-            try:
-                output = await self._function_schema.call(output, run_context)
-            except ModelRetry as r:
-                if wrap_validation_errors:
-                    m = _messages.RetryPromptPart(
-                        content=r.message,
-                    )
-                    raise ToolRetryError(m) from r
-                else:
-                    raise  # pragma: lax no cover
+            output = await execute_output_function_with_span(self._function_schema, run_context, output)
 
         return output
 
@@ -839,9 +856,8 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         args = {self._str_argument_name: data}
-
         try:
-            output = await self._function_schema.call(args, run_context)
+            output = await execute_output_function_with_span(self._function_schema, run_context, args)
         except ModelRetry as r:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
@@ -849,91 +865,145 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
                 )
                 raise ToolRetryError(m) from r
             else:
-                raise  # pragma: lax no cover
+                raise  # pragma: no cover
 
         return cast(OutputDataT, output)
 
 
 @dataclass(init=False)
-class OutputTool(Generic[OutputDataT]):
-    processor: ObjectOutputProcessor[OutputDataT]
-    tool_def: ToolDefinition
+class OutputToolset(AbstractToolset[AgentDepsT]):
+    """A toolset that contains contains output tools for agent output types."""
 
-    def __init__(self, *, name: str, processor: ObjectOutputProcessor[OutputDataT], multiple: bool):
-        self.processor = processor
-        object_def = processor.object_def
+    _tool_defs: list[ToolDefinition]
+    """The tool definitions for the output tools in this toolset."""
+    processors: dict[str, ObjectOutputProcessor[Any]]
+    """The processors for the output tools in this toolset."""
+    max_retries: int
+    output_validators: list[OutputValidator[AgentDepsT, Any]]
 
-        description = object_def.description
-        if not description:
-            description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
-            if multiple:
-                description = f'{object_def.name}: {description}'
+    @classmethod
+    def build(
+        cls,
+        outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
+        name: str | None = None,
+        description: str | None = None,
+        strict: bool | None = None,
+    ) -> Self | None:
+        if len(outputs) == 0:
+            return None
 
-        self.tool_def = ToolDefinition(
-            name=name,
-            description=description,
-            parameters_json_schema=object_def.json_schema,
-            strict=object_def.strict,
-            outer_typed_dict_key=processor.outer_typed_dict_key,
-        )
+        processors: dict[str, ObjectOutputProcessor[Any]] = {}
+        tool_defs: list[ToolDefinition] = []
 
-    async def process(
-        self,
-        tool_call: _messages.ToolCallPart,
-        run_context: RunContext[AgentDepsT],
-        allow_partial: bool = False,
-        wrap_validation_errors: bool = True,
-    ) -> OutputDataT:
-        """Process an output message.
+        default_name = name or DEFAULT_OUTPUT_TOOL_NAME
+        default_description = description
+        default_strict = strict
 
-        Args:
-            tool_call: The tool call from the LLM to validate.
-            run_context: The current run context.
-            allow_partial: If true, allow partial validation.
-            wrap_validation_errors: If true, wrap the validation errors in a retry message.
+        multiple = len(outputs) > 1
+        for output in outputs:
+            name = None
+            description = None
+            strict = None
+            if isinstance(output, ToolOutput):
+                # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
+                name = output.name
+                description = output.description
+                strict = output.strict
 
-        Returns:
-            Either the validated output data (left) or a retry message (right).
-        """
-        try:
-            output = await self.processor.process(
-                tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False
+                output = output.output
+
+            description = description or default_description
+            if strict is None:
+                strict = default_strict
+
+            processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
+            object_def = processor.object_def
+
+            if name is None:
+                name = default_name
+                if multiple:
+                    name += f'_{object_def.name}'
+
+            i = 1
+            original_name = name
+            while name in processors:
+                i += 1
+                name = f'{original_name}_{i}'
+
+            description = object_def.description
+            if not description:
+                description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
+                if multiple:
+                    description = f'{object_def.name}: {description}'
+
+            tool_def = ToolDefinition(
+                name=name,
+                description=description,
+                parameters_json_schema=object_def.json_schema,
+                strict=object_def.strict,
+                outer_typed_dict_key=processor.outer_typed_dict_key,
+                kind='output',
             )
-        except ValidationError as e:
-            if wrap_validation_errors:
-                m = _messages.RetryPromptPart(
-                    tool_name=tool_call.tool_name,
-                    content=e.errors(include_url=False, include_context=False),
-                    tool_call_id=tool_call.tool_call_id,
-                )
-                raise ToolRetryError(m) from e
-            else:
-                raise  # pragma: lax no cover
-        except ModelRetry as r:
-            if wrap_validation_errors:
-                m = _messages.RetryPromptPart(
-                    tool_name=tool_call.tool_name,
-                    content=r.message,
-                    tool_call_id=tool_call.tool_call_id,
-                )
-                raise ToolRetryError(m) from r
-            else:
-                raise  # pragma: lax no cover
-        else:
-            return output
+            processors[name] = processor
+            tool_defs.append(tool_def)
+
+        return cls(processors=processors, tool_defs=tool_defs)
+
+    def __init__(
+        self,
+        tool_defs: list[ToolDefinition],
+        processors: dict[str, ObjectOutputProcessor[Any]],
+        max_retries: int = 1,
+        output_validators: list[OutputValidator[AgentDepsT, Any]] | None = None,
+    ):
+        self.processors = processors
+        self._tool_defs = tool_defs
+        self.max_retries = max_retries
+        self.output_validators = output_validators or []
+
+    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
+        return {
+            tool_def.name: ToolsetTool(
+                toolset=self,
+                tool_def=tool_def,
+                max_retries=self.max_retries,
+                args_validator=self.processors[tool_def.name].validator,
+            )
+            for tool_def in self._tool_defs
+        }
+
+    async def call_tool(
+        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
+    ) -> Any:
+        output = await self.processors[name].call(tool_args, ctx)
+        for validator in self.output_validators:
+            output = await validator.validate(output, ctx, wrap_validation_errors=False)
+        return output
+
+
+@overload
+def _flatten_output_spec(
+    output_spec: OutputTypeOrFunction[T] | Sequence[OutputTypeOrFunction[T]],
+) -> Sequence[OutputTypeOrFunction[T]]: ...
+
+
+@overload
+def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: ...
 
 
-def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]:
-    outputs: Sequence[T]
+def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]:
+    outputs: Sequence[OutputSpec[T]]
     if isinstance(output_spec, Sequence):
         outputs = output_spec
     else:
         outputs = (output_spec,)
 
-    outputs_flat: list[T] = []
+    outputs_flat: list[_OutputSpecItem[T]] = []
     for output in outputs:
-        if union_types := _utils.get_union_args(output):
+        if isinstance(output, Sequence):
+            outputs_flat.extend(_flatten_output_spec(cast(OutputSpec[T], output)))
+        elif union_types := _utils.get_union_args(output):
             outputs_flat.extend(union_types)
         else:
-            outputs_flat.append(output)
+            outputs_flat.append(cast(_OutputSpecItem[T], output))
     return outputs_flat
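
Note on the reworked output-spec handling above: `OutputSchema.build` now flattens the output spec before dispatching on mode, expanding nested sequences and union types into a flat list and filtering out the `DeferredToolCalls` marker (raising a `UserError` if nothing else remains). The snippet below is a minimal, self-contained sketch of that flattening step only; `flatten_output_spec`, `Foo` and `Bar` are illustrative names, not part of the package, and the real `_flatten_output_spec` additionally handles `X | Y` unions via `_utils.get_union_args` and the library's `OutputSpec`/`_OutputSpecItem` typing.

from collections.abc import Sequence
from typing import Union, get_args, get_origin


def flatten_output_spec(output_spec):
    """Flatten nested sequences and typing.Union members into a flat list of output types (sketch)."""
    outputs = output_spec if isinstance(output_spec, Sequence) else (output_spec,)
    flat = []
    for output in outputs:
        if isinstance(output, Sequence):
            # a nested list/tuple of output types: recurse
            flat.extend(flatten_output_spec(output))
        elif get_origin(output) is Union:
            # a typing.Union: expand into its members
            flat.extend(get_args(output))
        else:
            flat.append(output)
    return flat


class Foo: ...
class Bar: ...

# [str, [Foo, Union[Bar, int]]] flattens to [str, Foo, Bar, int]
print(flatten_output_spec([str, [Foo, Union[Bar, int]]]))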