pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of pydantic-ai-slim might be problematic.
Files changed (45)
  1. pydantic_ai/_agent_graph.py +220 -319
  2. pydantic_ai/_cli.py +9 -7
  3. pydantic_ai/_output.py +295 -331
  4. pydantic_ai/_parts_manager.py +2 -2
  5. pydantic_ai/_run_context.py +8 -14
  6. pydantic_ai/_tool_manager.py +190 -0
  7. pydantic_ai/_utils.py +18 -1
  8. pydantic_ai/ag_ui.py +675 -0
  9. pydantic_ai/agent.py +369 -156
  10. pydantic_ai/exceptions.py +12 -0
  11. pydantic_ai/ext/aci.py +12 -3
  12. pydantic_ai/ext/langchain.py +9 -1
  13. pydantic_ai/mcp.py +147 -84
  14. pydantic_ai/messages.py +13 -5
  15. pydantic_ai/models/__init__.py +30 -18
  16. pydantic_ai/models/anthropic.py +1 -1
  17. pydantic_ai/models/function.py +50 -24
  18. pydantic_ai/models/gemini.py +1 -9
  19. pydantic_ai/models/google.py +2 -11
  20. pydantic_ai/models/groq.py +1 -0
  21. pydantic_ai/models/mistral.py +1 -1
  22. pydantic_ai/models/openai.py +3 -3
  23. pydantic_ai/output.py +21 -7
  24. pydantic_ai/profiles/google.py +1 -1
  25. pydantic_ai/profiles/moonshotai.py +8 -0
  26. pydantic_ai/providers/grok.py +13 -1
  27. pydantic_ai/providers/groq.py +2 -0
  28. pydantic_ai/result.py +58 -45
  29. pydantic_ai/tools.py +26 -119
  30. pydantic_ai/toolsets/__init__.py +22 -0
  31. pydantic_ai/toolsets/abstract.py +155 -0
  32. pydantic_ai/toolsets/combined.py +88 -0
  33. pydantic_ai/toolsets/deferred.py +38 -0
  34. pydantic_ai/toolsets/filtered.py +24 -0
  35. pydantic_ai/toolsets/function.py +238 -0
  36. pydantic_ai/toolsets/prefixed.py +37 -0
  37. pydantic_ai/toolsets/prepared.py +36 -0
  38. pydantic_ai/toolsets/renamed.py +42 -0
  39. pydantic_ai/toolsets/wrapper.py +37 -0
  40. pydantic_ai/usage.py +14 -8
  41. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +10 -7
  42. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/RECORD +45 -32
  43. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
  44. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
  45. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
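The headline change in this release is the new pydantic_ai/toolsets package (abstract, combined, deferred, filtered, function, prefixed, prepared, renamed, wrapper), which the _output.py diff below is refactored around. As a rough sketch of the usage this enables — assuming the `FunctionToolset` constructor, its `tool` decorator, and the `Agent(..., toolsets=...)` parameter work as the file names here suggest (not verified against the 0.4.4 docs):

    from pydantic_ai import Agent
    from pydantic_ai.toolsets import FunctionToolset

    toolset = FunctionToolset()

    @toolset.tool
    def add(a: int, b: int) -> int:
        """Add two numbers."""
        return a + b

    # Toolsets are registered on the agent as a group rather than tool-by-tool.
    agent = Agent('test', toolsets=[toolset])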
pydantic_ai/_output.py CHANGED
@@ -1,24 +1,21 @@
  from __future__ import annotations as _annotations

- import dataclasses
  import inspect
  import json
  from abc import ABC, abstractmethod
- from collections.abc import Awaitable, Iterable, Iterator, Sequence
+ from collections.abc import Awaitable, Sequence
  from dataclasses import dataclass, field
  from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload

- from opentelemetry.trace import Tracer
  from pydantic import TypeAdapter, ValidationError
- from pydantic_core import SchemaValidator
- from typing_extensions import TypedDict, TypeVar, assert_never
-
- from pydantic_graph.nodes import GraphRunContext
+ from pydantic_core import SchemaValidator, to_json
+ from typing_extensions import Self, TypedDict, TypeVar, assert_never

  from . import _function_schema, _utils, messages as _messages
  from ._run_context import AgentDepsT, RunContext
- from .exceptions import ModelRetry, UserError
+ from .exceptions import ModelRetry, ToolRetryError, UserError
  from .output import (
+     DeferredToolCalls,
      NativeOutput,
      OutputDataT,
      OutputMode,
@@ -29,12 +26,12 @@ from .output import (
      TextOutput,
      TextOutputFunc,
      ToolOutput,
+     _OutputSpecItem,  # type: ignore[reportPrivateUsage]
  )
  from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
+ from .toolsets.abstract import AbstractToolset, ToolsetTool

  if TYPE_CHECKING:
-     from pydantic_ai._agent_graph import DepsT, GraphAgentDeps, GraphAgentState
-
      from .profiles import ModelProfile

  T = TypeVar('T')
@@ -72,77 +69,45 @@ DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
  DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'


- @dataclass(frozen=True)
- class TraceContext:
-     """A context for tracing output processing."""
+ async def execute_output_function_with_span(
+     function_schema: _function_schema.FunctionSchema,
+     run_context: RunContext[AgentDepsT],
+     args: dict[str, Any] | Any,
+ ) -> Any:
+     """Execute a function call within a traced span, automatically recording the response."""
+     # Set up span attributes
+     tool_name = run_context.tool_name or getattr(function_schema.function, '__name__', 'output_function')
+     attributes = {
+         'gen_ai.tool.name': tool_name,
+         'logfire.msg': f'running output function: {tool_name}',
+     }
+     if run_context.tool_call_id:
+         attributes['gen_ai.tool.call.id'] = run_context.tool_call_id
+     if run_context.trace_include_content:
+         attributes['tool_arguments'] = to_json(args).decode()
+         attributes['logfire.json_schema'] = json.dumps(
+             {
+                 'type': 'object',
+                 'properties': {
+                     'tool_arguments': {'type': 'object'},
+                     'tool_response': {'type': 'object'},
+                 },
+             }
+         )

-     tracer: Tracer
-     include_content: bool
-     call: _messages.ToolCallPart | None = None
+     with run_context.tracer.start_as_current_span('running output function', attributes=attributes) as span:
+         output = await function_schema.call(args, run_context)

-     def with_call(self, call: _messages.ToolCallPart):
-         return dataclasses.replace(self, call=call)
+         # Record response if content inclusion is enabled
+         if run_context.trace_include_content and span.is_recording():
+             from .models.instrumented import InstrumentedModel

-     async def execute_function_with_span(
-         self,
-         function_schema: _function_schema.FunctionSchema,
-         run_context: RunContext[AgentDepsT],
-         args: dict[str, Any] | Any,
-         call: _messages.ToolCallPart,
-         include_tool_call_id: bool = True,
-     ) -> Any:
-         """Execute a function call within a traced span, automatically recording the response."""
-         # Set up span attributes
-         attributes = {
-             'gen_ai.tool.name': call.tool_name,
-             'logfire.msg': f'running output function: {call.tool_name}',
-         }
-         if include_tool_call_id:
-             attributes['gen_ai.tool.call.id'] = call.tool_call_id
-         if self.include_content:
-             attributes['tool_arguments'] = call.args_as_json_str()
-             attributes['logfire.json_schema'] = json.dumps(
-                 {
-                     'type': 'object',
-                     'properties': {
-                         'tool_arguments': {'type': 'object'},
-                         'tool_response': {'type': 'object'},
-                     },
-                 }
+             span.set_attribute(
+                 'tool_response',
+                 output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)),
              )

-         # Execute function within span
-         with self.tracer.start_as_current_span('running output function', attributes=attributes) as span:
-             output = await function_schema.call(args, run_context)
-
-             # Record response if content inclusion is enabled
-             if self.include_content and span.is_recording():
-                 from .models.instrumented import InstrumentedModel
-
-                 span.set_attribute(
-                     'tool_response',
-                     output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)),
-                 )
-
-             return output
-
-
- def build_trace_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> TraceContext:
-     """Build a `TraceContext` from the current agent graph run context."""
-     return TraceContext(
-         tracer=ctx.deps.tracer,
-         include_content=(
-             ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content
-         ),
-     )
-
-
- class ToolRetryError(Exception):
-     """Exception used to signal a `ToolRetry` message should be returned to the LLM."""
-
-     def __init__(self, tool_retry: _messages.RetryPromptPart):
-         self.tool_retry = tool_retry
-         super().__init__()
+     return output


  @dataclass
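The hunk above replaces the `TraceContext` dataclass with a module-level `execute_output_function_with_span` helper that reads everything it needs (tracer, tool name, content-inclusion flag) from the `RunContext`. For reference, a standalone sketch of the span shape it emits, using only the public opentelemetry-api; the attribute values are made-up stand-ins:

    import json

    from opentelemetry import trace

    tracer = trace.get_tracer(__name__)

    tool_name = 'final_result'  # hypothetical output tool name
    attributes = {
        'gen_ai.tool.name': tool_name,
        'logfire.msg': f'running output function: {tool_name}',
        'tool_arguments': json.dumps({'city': 'London'}),  # only set when content tracing is enabled
    }
    with tracer.start_as_current_span('running output function', attributes=attributes) as span:
        output = {'temperature': '21C'}  # stand-in for `await function_schema.call(args, run_context)`
        if span.is_recording():
            span.set_attribute('tool_response', json.dumps(output))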
@@ -158,23 +123,21 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
      async def validate(
          self,
          result: T,
-         tool_call: _messages.ToolCallPart | None,
          run_context: RunContext[AgentDepsT],
+         wrap_validation_errors: bool = True,
      ) -> T:
          """Validate a result but calling the function.

          Args:
              result: The result data after Pydantic validation the message content.
-             tool_call: The original tool call message, `None` if there was no tool call.
              run_context: The current run context.
-             trace_context: The trace context to use for tracing the output processing.
+             wrap_validation_errors: If true, wrap the validation errors in a retry message.

          Returns:
              Result of either the validated result data (ok) or a retry message (Err).
          """
          if self._takes_ctx:
-             ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None)
-             args = ctx, result
+             args = run_context, result
          else:
              args = (result,)

@@ -186,24 +149,32 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
                  function = cast(Callable[[Any], T], self.function)
                  result_data = await _utils.run_in_executor(function, *args)
          except ModelRetry as r:
-             m = _messages.RetryPromptPart(content=r.message)
-             if tool_call is not None:
-                 m.tool_name = tool_call.tool_name
-                 m.tool_call_id = tool_call.tool_call_id
-             raise ToolRetryError(m) from r
+             if wrap_validation_errors:
+                 m = _messages.RetryPromptPart(
+                     content=r.message,
+                     tool_name=run_context.tool_name,
+                 )
+                 if run_context.tool_call_id:  # pragma: no cover
+                     m.tool_call_id = run_context.tool_call_id
+                 raise ToolRetryError(m) from r
+             else:
+                 raise r
          else:
              return result_data


+ @dataclass
  class BaseOutputSchema(ABC, Generic[OutputDataT]):
+     allows_deferred_tool_calls: bool
+
      @abstractmethod
      def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
          raise NotImplementedError()

      @property
-     def tools(self) -> dict[str, OutputTool[OutputDataT]]:
-         """Get the tools for this output schema."""
-         return {}
+     def toolset(self) -> OutputToolset[Any] | None:
+         """Get the toolset for this output schema."""
+         return None


  @dataclass(init=False)
@@ -235,7 +206,7 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
      ) -> BaseOutputSchema[OutputDataT]: ...

      @classmethod
-     def build(
+     def build(  # noqa: C901
          cls,
          output_spec: OutputSpec[OutputDataT],
          *,
@@ -245,117 +216,93 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
          strict: bool | None = None,
      ) -> BaseOutputSchema[OutputDataT]:
          """Build an OutputSchema dataclass from an output type."""
-         if output_spec is str:
-             return PlainTextOutputSchema()
+         raw_outputs = _flatten_output_spec(output_spec)
+
+         outputs = [output for output in raw_outputs if output is not DeferredToolCalls]
+         allows_deferred_tool_calls = len(outputs) < len(raw_outputs)
+         if len(outputs) == 0 and allows_deferred_tool_calls:
+             raise UserError('At least one output type must be provided other than `DeferredToolCalls`.')
+
+         if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
+             if len(outputs) > 1:
+                 raise UserError('`NativeOutput` must be the only output type.')  # pragma: no cover

-         if isinstance(output_spec, NativeOutput):
              return NativeOutputSchema(
-                 cls._build_processor(
-                     _flatten_output_spec(output_spec.outputs),
-                     name=output_spec.name,
-                     description=output_spec.description,
-                     strict=output_spec.strict,
-                 )
+                 processor=cls._build_processor(
+                     _flatten_output_spec(output.outputs),
+                     name=output.name,
+                     description=output.description,
+                     strict=output.strict,
+                 ),
+                 allows_deferred_tool_calls=allows_deferred_tool_calls,
              )
-         elif isinstance(output_spec, PromptedOutput):
+         elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
+             if len(outputs) > 1:
+                 raise UserError('`PromptedOutput` must be the only output type.')  # pragma: no cover
+
              return PromptedOutputSchema(
-                 cls._build_processor(
-                     _flatten_output_spec(output_spec.outputs),
-                     name=output_spec.name,
-                     description=output_spec.description,
+                 processor=cls._build_processor(
+                     _flatten_output_spec(output.outputs),
+                     name=output.name,
+                     description=output.description,
                  ),
-                 template=output_spec.template,
+                 template=output.template,
+                 allows_deferred_tool_calls=allows_deferred_tool_calls,
              )

          text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
          tool_outputs: Sequence[ToolOutput[OutputDataT]] = []
          other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = []
-         for output in _flatten_output_spec(output_spec):
+         for output in outputs:
              if output is str:
                  text_outputs.append(cast(type[str], output))
              elif isinstance(output, TextOutput):
                  text_outputs.append(output)
              elif isinstance(output, ToolOutput):
                  tool_outputs.append(output)
+             elif isinstance(output, NativeOutput):
+                 # We can never get here because this is checked for above.
+                 raise UserError('`NativeOutput` must be the only output type.')  # pragma: no cover
+             elif isinstance(output, PromptedOutput):
+                 # We can never get here because this is checked for above.
+                 raise UserError('`PromptedOutput` must be the only output type.')  # pragma: no cover
              else:
                  other_outputs.append(output)

-         tools = cls._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict)
+         toolset = OutputToolset.build(tool_outputs + other_outputs, name=name, description=description, strict=strict)

          if len(text_outputs) > 0:
              if len(text_outputs) > 1:
-                 raise UserError('Only one text output is allowed.')
+                 raise UserError('Only one `str` or `TextOutput` is allowed.')
              text_output = text_outputs[0]

              text_output_schema = None
              if isinstance(text_output, TextOutput):
                  text_output_schema = PlainTextOutputProcessor(text_output.output_function)

-             if len(tools) == 0:
-                 return PlainTextOutputSchema(text_output_schema)
+             if toolset:
+                 return ToolOrTextOutputSchema(
+                     processor=text_output_schema, toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls
+                 )
              else:
-                 return ToolOrTextOutputSchema(processor=text_output_schema, tools=tools)
+                 return PlainTextOutputSchema(
+                     processor=text_output_schema, allows_deferred_tool_calls=allows_deferred_tool_calls
+                 )

          if len(tool_outputs) > 0:
-             return ToolOutputSchema(tools)
+             return ToolOutputSchema(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls)

          if len(other_outputs) > 0:
              schema = OutputSchemaWithoutMode(
                  processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
-                 tools=tools,
+                 toolset=toolset,
+                 allows_deferred_tool_calls=allows_deferred_tool_calls,
              )
              if default_mode:
                  schema = schema.with_default_mode(default_mode)
              return schema

-         raise UserError('No output type provided.')  # pragma: no cover
-
-     @staticmethod
-     def _build_tools(
-         outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
-         name: str | None = None,
-         description: str | None = None,
-         strict: bool | None = None,
-     ) -> dict[str, OutputTool[OutputDataT]]:
-         tools: dict[str, OutputTool[OutputDataT]] = {}
-
-         default_name = name or DEFAULT_OUTPUT_TOOL_NAME
-         default_description = description
-         default_strict = strict
-
-         multiple = len(outputs) > 1
-         for output in outputs:
-             name = None
-             description = None
-             strict = None
-             if isinstance(output, ToolOutput):
-                 # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
-                 name = output.name
-                 description = output.description
-                 strict = output.strict
-
-                 output = output.output
-
-             description = description or default_description
-             if strict is None:
-                 strict = default_strict
-
-             processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
-
-             if name is None:
-                 name = default_name
-                 if multiple:
-                     name += f'_{processor.object_def.name}'
-
-             i = 1
-             original_name = name
-             while name in tools:
-                 i += 1
-                 name = f'{original_name}_{i}'
-
-             tools[name] = OutputTool(name=name, processor=processor, multiple=multiple)
-
-         return tools
+         raise UserError('At least one output type must be provided.')

      @staticmethod
      def _build_processor(
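The new `DeferredToolCalls` handling at the top of `build` is easy to check in isolation. A pure-Python sketch of the filtering rule, with `DeferredToolCalls` stubbed as a plain sentinel class and `UserError` replaced by `ValueError`:

    class DeferredToolCalls:  # stub for pydantic_ai.output.DeferredToolCalls
        pass

    def split_deferred(raw_outputs: list) -> tuple[list, bool]:
        outputs = [o for o in raw_outputs if o is not DeferredToolCalls]
        allows_deferred = len(outputs) < len(raw_outputs)
        if not outputs and allows_deferred:
            raise ValueError('At least one output type must be provided other than `DeferredToolCalls`.')
        return outputs, allows_deferred

    assert split_deferred([str, DeferredToolCalls]) == ([str], True)
    assert split_deferred([int]) == ([int], False)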
@@ -387,32 +334,39 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
  @dataclass(init=False)
  class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
      processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]
-     _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict)
+     _toolset: OutputToolset[Any] | None

      def __init__(
          self,
          processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT],
-         tools: dict[str, OutputTool[OutputDataT]],
+         toolset: OutputToolset[Any] | None,
+         allows_deferred_tool_calls: bool,
      ):
+         super().__init__(allows_deferred_tool_calls)
          self.processor = processor
-         self._tools = tools
+         self._toolset = toolset

      def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
          if mode == 'native':
-             return NativeOutputSchema(self.processor)
+             return NativeOutputSchema(
+                 processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
+             )
          elif mode == 'prompted':
-             return PromptedOutputSchema(self.processor)
+             return PromptedOutputSchema(
+                 processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
+             )
          elif mode == 'tool':
-             return ToolOutputSchema(self.tools)
+             return ToolOutputSchema(toolset=self.toolset, allows_deferred_tool_calls=self.allows_deferred_tool_calls)
          else:
              assert_never(mode)

      @property
-     def tools(self) -> dict[str, OutputTool[OutputDataT]]:
-         """Get the tools for this output schema."""
-         # We return tools here as they're checked in Agent._register_tool.
-         # At that point we may don't know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time.
-         return self._tools
+     def toolset(self) -> OutputToolset[Any] | None:
+         """Get the toolset for this output schema."""
+         # We return a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor.
+         # At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time,
+         # but we cover ourselves just in case we end up using the tool output mode.
+         return self._toolset


  class TextOutputSchema(OutputSchema[OutputDataT], ABC):
@@ -421,7 +375,6 @@ class TextOutputSchema(OutputSchema[OutputDataT], ABC):
          self,
          text: str,
          run_context: RunContext[AgentDepsT],
-         trace_context: TraceContext,
          allow_partial: bool = False,
          wrap_validation_errors: bool = True,
      ) -> OutputDataT:
@@ -444,7 +397,6 @@ class PlainTextOutputSchema(TextOutputSchema[OutputDataT]):
          self,
          text: str,
          run_context: RunContext[AgentDepsT],
-         trace_context: TraceContext,
          allow_partial: bool = False,
          wrap_validation_errors: bool = True,
      ) -> OutputDataT:
@@ -453,7 +405,6 @@ class PlainTextOutputSchema(TextOutputSchema[OutputDataT]):
          Args:
              text: The output text to validate.
              run_context: The current run context.
-             trace_context: The trace context to use for tracing the output processing.
              allow_partial: If true, allow partial validation.
              wrap_validation_errors: If true, wrap the validation errors in a retry message.

@@ -464,7 +415,7 @@ class PlainTextOutputSchema(TextOutputSchema[OutputDataT]):
              return cast(OutputDataT, text)

          return await self.processor.process(
-             text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+             text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
          )


@@ -486,13 +437,12 @@ class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
      def raise_if_unsupported(self, profile: ModelProfile) -> None:
          """Raise an error if the mode is not supported by the model."""
          if not profile.supports_json_schema_output:
-             raise UserError('Structured output is not supported by the model.')
+             raise UserError('Native structured output is not supported by the model.')

      async def process(
          self,
          text: str,
          run_context: RunContext[AgentDepsT],
-         trace_context: TraceContext,
          allow_partial: bool = False,
          wrap_validation_errors: bool = True,
      ) -> OutputDataT:
@@ -501,7 +451,6 @@ class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
          Args:
              text: The output text to validate.
              run_context: The current run context.
-             trace_context: The trace context to use for tracing the output processing.
              allow_partial: If true, allow partial validation.
              wrap_validation_errors: If true, wrap the validation errors in a retry message.

@@ -509,7 +458,7 @@ class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
              Either the validated output data (left) or a retry message (right).
          """
          return await self.processor.process(
-             text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+             text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
          )


@@ -545,7 +494,6 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
          self,
          text: str,
          run_context: RunContext[AgentDepsT],
-         trace_context: TraceContext,
          allow_partial: bool = False,
          wrap_validation_errors: bool = True,
      ) -> OutputDataT:
@@ -554,7 +502,6 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
          Args:
              text: The output text to validate.
              run_context: The current run context.
-             trace_context: The trace context to use for tracing the output processing.
              allow_partial: If true, allow partial validation.
              wrap_validation_errors: If true, wrap the validation errors in a retry message.

@@ -564,16 +511,17 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
          text = _utils.strip_markdown_fences(text)

          return await self.processor.process(
-             text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+             text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
          )


  @dataclass(init=False)
  class ToolOutputSchema(OutputSchema[OutputDataT]):
-     _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict)
+     _toolset: OutputToolset[Any] | None

-     def __init__(self, tools: dict[str, OutputTool[OutputDataT]]):
-         self._tools = tools
+     def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_calls: bool):
+         super().__init__(allows_deferred_tool_calls)
+         self._toolset = toolset

      @property
      def mode(self) -> OutputMode:
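The `strip_markdown_fences` call at the top of this hunk guards prompted output against models that wrap their JSON in code fences. A simplified, self-contained approximation of that behaviour (the real `_utils.strip_markdown_fences` may differ in edge cases):

    import re

    def strip_markdown_fences(text: str) -> str:
        """Simplified stand-in: unwrap one fenced block if present."""
        match = re.search(r'```(?:\w+)?\s*\n(.*?)```', text, re.DOTALL)
        return match.group(1).strip() if match else text

    assert strip_markdown_fences('```json\n{"answer": 42}\n```') == '{"answer": 42}'
    assert strip_markdown_fences('{"answer": 42}') == '{"answer": 42}'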
@@ -585,36 +533,9 @@ class ToolOutputSchema(OutputSchema[OutputDataT]):
              raise UserError('Output tools are not supported by the model.')

      @property
-     def tools(self) -> dict[str, OutputTool[OutputDataT]]:
-         """Get the tools for this output schema."""
-         return self._tools
-
-     def tool_names(self) -> list[str]:
-         """Return the names of the tools."""
-         return list(self.tools.keys())
-
-     def tool_defs(self) -> list[ToolDefinition]:
-         """Get tool definitions to register with the model."""
-         return [t.tool_def for t in self.tools.values()]
-
-     def find_named_tool(
-         self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
-     ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None:
-         """Find a tool that matches one of the calls, with a specific name."""
-         for part in parts:  # pragma: no branch
-             if isinstance(part, _messages.ToolCallPart):  # pragma: no branch
-                 if part.tool_name == tool_name:
-                     return part, self.tools[tool_name]
-
-     def find_tool(
-         self,
-         parts: Iterable[_messages.ModelResponsePart],
-     ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]:
-         """Find a tool that matches one of the calls."""
-         for part in parts:
-             if isinstance(part, _messages.ToolCallPart):  # pragma: no branch
-                 if result := self.tools.get(part.tool_name):
-                     yield part, result
+     def toolset(self) -> OutputToolset[Any] | None:
+         """Get the toolset for this output schema."""
+         return self._toolset


  @dataclass(init=False)
@@ -622,10 +543,11 @@ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchem
      def __init__(
          self,
          processor: PlainTextOutputProcessor[OutputDataT] | None,
-         tools: dict[str, OutputTool[OutputDataT]],
+         toolset: OutputToolset[Any] | None,
+         allows_deferred_tool_calls: bool,
      ):
+         super().__init__(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls)
          self.processor = processor
-         self._tools = tools

      @property
      def mode(self) -> OutputMode:
@@ -647,7 +569,6 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
          self,
          data: str,
          run_context: RunContext[AgentDepsT],
-         trace_context: TraceContext,
          allow_partial: bool = False,
          wrap_validation_errors: bool = True,
      ) -> OutputDataT:
@@ -659,7 +580,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
      object_def: OutputObjectDefinition
      outer_typed_dict_key: str | None = None
-     _validator: SchemaValidator
+     validator: SchemaValidator
      _function_schema: _function_schema.FunctionSchema | None = None

      def __init__(
@@ -672,7 +593,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
      ):
          if inspect.isfunction(output) or inspect.ismethod(output):
              self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema)
-             self._validator = self._function_schema.validator
+             self.validator = self._function_schema.validator
              json_schema = self._function_schema.json_schema
              json_schema['description'] = self._function_schema.description
          else:
@@ -688,7 +609,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
              type_adapter = TypeAdapter(response_data_typed_dict)

              # Really a PluggableSchemaValidator, but it's API-compatible
-             self._validator = cast(SchemaValidator, type_adapter.validator)
+             self.validator = cast(SchemaValidator, type_adapter.validator)
              json_schema = _utils.check_object_json_schema(
                  type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
              )
@@ -717,7 +638,6 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
          self,
          data: str | dict[str, Any] | None,
          run_context: RunContext[AgentDepsT],
-         trace_context: TraceContext,
          allow_partial: bool = False,
          wrap_validation_errors: bool = True,
      ) -> OutputDataT:
@@ -726,7 +646,6 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
          Args:
              data: The output data to validate.
              run_context: The current run context.
-             trace_context: The trace context to use for tracing the output processing.
              allow_partial: If true, allow partial validation.
              wrap_validation_errors: If true, wrap the validation errors in a retry message.

@@ -734,11 +653,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
              Either the validated output data (left) or a retry message (right).
          """
          try:
-             pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
-             if isinstance(data, str):
-                 output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
-             else:
-                 output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
+             output = self.validate(data, allow_partial)
          except ValidationError as e:
              if wrap_validation_errors:
                  m = _messages.RetryPromptPart(
@@ -748,30 +663,40 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
              else:
                  raise

+         try:
+             output = await self.call(output, run_context)
+         except ModelRetry as r:
+             if wrap_validation_errors:
+                 m = _messages.RetryPromptPart(
+                     content=r.message,
+                 )
+                 raise ToolRetryError(m) from r
+             else:
+                 raise  # pragma: no cover
+
+         return output
+
+     def validate(
+         self,
+         data: str | dict[str, Any] | None,
+         allow_partial: bool = False,
+     ) -> dict[str, Any]:
+         pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
+         if isinstance(data, str):
+             return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
+         else:
+             return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
+
+     async def call(
+         self,
+         output: Any,
+         run_context: RunContext[AgentDepsT],
+     ):
          if k := self.outer_typed_dict_key:
              output = output[k]

          if self._function_schema:
-             # Wraps the output function call in an OpenTelemetry span.
-             if trace_context.call:
-                 call = trace_context.call
-                 include_tool_call_id = True
-             else:
-                 function_name = getattr(self._function_schema.function, '__name__', 'output_function')
-                 call = _messages.ToolCallPart(tool_name=function_name, args=data)
-                 include_tool_call_id = False
-             try:
-                 output = await trace_context.execute_function_with_span(
-                     self._function_schema, run_context, output, call, include_tool_call_id
-                 )
-             except ModelRetry as r:
-                 if wrap_validation_errors:
-                     m = _messages.RetryPromptPart(
-                         content=r.message,
-                     )
-                     raise ToolRetryError(m) from r
-                 else:
-                     raise
+             output = await execute_output_function_with_span(self._function_schema, run_context, output)

          return output

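The `validate`/`call` split above isolates the pydantic-core validation step, whose `allow_partial='trailing-strings'` mode is what lets still-streaming, truncated JSON validate. A small self-contained demonstration using a plain `TypeAdapter` (the same `.validator` object the constructor hunks above use); exact partial-validation behaviour depends on your pydantic version:

    from typing import TypedDict

    from pydantic import TypeAdapter

    class WeatherArgs(TypedDict):
        city: str

    validator = TypeAdapter(WeatherArgs).validator

    # A complete payload validates normally...
    print(validator.validate_json('{"city": "London"}', allow_partial='off'))
    # ...while a truncated streaming payload is tolerated in trailing-strings mode.
    print(validator.validate_json('{"city": "Lon', allow_partial='trailing-strings'))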
@@ -876,12 +801,11 @@ class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
          self,
          data: str | dict[str, Any] | None,
          run_context: RunContext[AgentDepsT],
-         trace_context: TraceContext,
          allow_partial: bool = False,
          wrap_validation_errors: bool = True,
      ) -> OutputDataT:
          union_object = await self._union_processor.process(
-             data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+             data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
          )

          result = union_object.result
@@ -897,7 +821,7 @@ class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
                  raise

          return await processor.process(
-             data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+             data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
          )


@@ -928,20 +852,12 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
          self,
          data: str,
          run_context: RunContext[AgentDepsT],
-         trace_context: TraceContext,
          allow_partial: bool = False,
          wrap_validation_errors: bool = True,
      ) -> OutputDataT:
          args = {self._str_argument_name: data}
-         # Wraps the output function call in an OpenTelemetry span.
-         # Note: PlainTextOutputProcessor is used for text responses (not tool calls),
-         # so we don't have tool call attributes like gen_ai.tool.name or gen_ai.tool.call.id
-         function_name = getattr(self._function_schema.function, '__name__', 'text_output_function')
-         call = _messages.ToolCallPart(tool_name=function_name, args=args)
          try:
-             output = await trace_context.execute_function_with_span(
-                 self._function_schema, run_context, args, call, include_tool_call_id=False
-             )
+             output = await execute_output_function_with_span(self._function_schema, run_context, args)
          except ModelRetry as r:
              if wrap_validation_errors:
                  m = _messages.RetryPromptPart(
@@ -955,91 +871,139 @@


  @dataclass(init=False)
- class OutputTool(Generic[OutputDataT]):
-     processor: ObjectOutputProcessor[OutputDataT]
-     tool_def: ToolDefinition
+ class OutputToolset(AbstractToolset[AgentDepsT]):
+     """A toolset that contains output tools for agent output types."""

-     def __init__(self, *, name: str, processor: ObjectOutputProcessor[OutputDataT], multiple: bool):
-         self.processor = processor
-         object_def = processor.object_def
+     _tool_defs: list[ToolDefinition]
+     """The tool definitions for the output tools in this toolset."""
+     processors: dict[str, ObjectOutputProcessor[Any]]
+     """The processors for the output tools in this toolset."""
+     max_retries: int
+     output_validators: list[OutputValidator[AgentDepsT, Any]]

-         description = object_def.description
-         if not description:
-             description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
-         if multiple:
-             description = f'{object_def.name}: {description}'
+     @classmethod
+     def build(
+         cls,
+         outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
+         name: str | None = None,
+         description: str | None = None,
+         strict: bool | None = None,
+     ) -> Self | None:
+         if len(outputs) == 0:
+             return None

-         self.tool_def = ToolDefinition(
-             name=name,
-             description=description,
-             parameters_json_schema=object_def.json_schema,
-             strict=object_def.strict,
-             outer_typed_dict_key=processor.outer_typed_dict_key,
-         )
+         processors: dict[str, ObjectOutputProcessor[Any]] = {}
+         tool_defs: list[ToolDefinition] = []

-     async def process(
-         self,
-         tool_call: _messages.ToolCallPart,
-         run_context: RunContext[AgentDepsT],
-         trace_context: TraceContext,
-         allow_partial: bool = False,
-         wrap_validation_errors: bool = True,
-     ) -> OutputDataT:
-         """Process an output message.
+         default_name = name or DEFAULT_OUTPUT_TOOL_NAME
+         default_description = description
+         default_strict = strict

-         Args:
-             tool_call: The tool call from the LLM to validate.
-             run_context: The current run context.
-             trace_context: The trace context to use for tracing the output processing.
-             allow_partial: If true, allow partial validation.
-             wrap_validation_errors: If true, wrap the validation errors in a retry message.
+         multiple = len(outputs) > 1
+         for output in outputs:
+             name = None
+             description = None
+             strict = None
+             if isinstance(output, ToolOutput):
+                 # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
+                 name = output.name
+                 description = output.description
+                 strict = output.strict

-         Returns:
-             Either the validated output data (left) or a retry message (right).
-         """
-         try:
-             output = await self.processor.process(
-                 tool_call.args,
-                 run_context,
-                 trace_context.with_call(tool_call),
-                 allow_partial=allow_partial,
-                 wrap_validation_errors=False,
+                 output = output.output
+
+             description = description or default_description
+             if strict is None:
+                 strict = default_strict
+
+             processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
+             object_def = processor.object_def
+
+             if name is None:
+                 name = default_name
+                 if multiple:
+                     name += f'_{object_def.name}'
+
+             i = 1
+             original_name = name
+             while name in processors:
+                 i += 1
+                 name = f'{original_name}_{i}'
+
+             description = object_def.description
+             if not description:
+                 description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
+             if multiple:
+                 description = f'{object_def.name}: {description}'
+
+             tool_def = ToolDefinition(
+                 name=name,
+                 description=description,
+                 parameters_json_schema=object_def.json_schema,
+                 strict=object_def.strict,
+                 outer_typed_dict_key=processor.outer_typed_dict_key,
+                 kind='output',
              )
-         except ValidationError as e:
-             if wrap_validation_errors:
-                 m = _messages.RetryPromptPart(
-                     tool_name=tool_call.tool_name,
-                     content=e.errors(include_url=False, include_context=False),
-                     tool_call_id=tool_call.tool_call_id,
-                 )
-                 raise ToolRetryError(m) from e
-             else:
-                 raise  # pragma: no cover
-         except ModelRetry as r:
-             if wrap_validation_errors:
-                 m = _messages.RetryPromptPart(
-                     tool_name=tool_call.tool_name,
-                     content=r.message,
-                     tool_call_id=tool_call.tool_call_id,
-                 )
-                 raise ToolRetryError(m) from r
-             else:
-                 raise  # pragma: no cover
-         else:
-             return output
+             processors[name] = processor
+             tool_defs.append(tool_def)
+
+         return cls(processors=processors, tool_defs=tool_defs)
+
+     def __init__(
+         self,
+         tool_defs: list[ToolDefinition],
+         processors: dict[str, ObjectOutputProcessor[Any]],
+         max_retries: int = 1,
+         output_validators: list[OutputValidator[AgentDepsT, Any]] | None = None,
+     ):
+         self.processors = processors
+         self._tool_defs = tool_defs
+         self.max_retries = max_retries
+         self.output_validators = output_validators or []
+
+     async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
+         return {
+             tool_def.name: ToolsetTool(
+                 toolset=self,
+                 tool_def=tool_def,
+                 max_retries=self.max_retries,
+                 args_validator=self.processors[tool_def.name].validator,
+             )
+             for tool_def in self._tool_defs
+         }
+
+     async def call_tool(
+         self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
+     ) -> Any:
+         output = await self.processors[name].call(tool_args, ctx)
+         for validator in self.output_validators:
+             output = await validator.validate(output, ctx, wrap_validation_errors=False)
+         return output
+
+
+ @overload
+ def _flatten_output_spec(
+     output_spec: OutputTypeOrFunction[T] | Sequence[OutputTypeOrFunction[T]],
+ ) -> Sequence[OutputTypeOrFunction[T]]: ...
+
+
+ @overload
+ def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: ...


- def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]:
-     outputs: Sequence[T]
+ def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]:
+     outputs: Sequence[OutputSpec[T]]
  if isinstance(output_spec, Sequence):
      outputs = output_spec
  else:
      outputs = (output_spec,)

-     outputs_flat: list[T] = []
+     outputs_flat: list[_OutputSpecItem[T]] = []
      for output in outputs:
-         if union_types := _utils.get_union_args(output):
+         if isinstance(output, Sequence):
+             outputs_flat.extend(_flatten_output_spec(cast(OutputSpec[T], output)))
+         elif union_types := _utils.get_union_args(output):
              outputs_flat.extend(union_types)
          else:
-             outputs_flat.append(output)
+             outputs_flat.append(cast(_OutputSpecItem[T], output))
      return outputs_flat
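To see what the new recursive branch in `_flatten_output_spec` adds, here is a self-contained approximation; `typing.get_origin`/`get_args` stand in for `_utils.get_union_args`, and the `OutputSpec` typing is dropped:

    from collections.abc import Sequence
    from typing import Union, get_args, get_origin

    def flatten_output_spec(output_spec) -> list:
        outputs = output_spec if isinstance(output_spec, Sequence) else (output_spec,)
        flat: list = []
        for output in outputs:
            if isinstance(output, Sequence):  # new in 0.4.4: nested sequences recurse
                flat.extend(flatten_output_spec(output))
            elif get_origin(output) is Union:  # stand-in for _utils.get_union_args
                flat.extend(get_args(output))
            else:
                flat.append(output)
        return flat

    # [<class 'str'>, <class 'int'>, <class 'float'>, <class 'bool'>]
    print(flatten_output_spec([str, [int, Union[float, bool]]]))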