pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +220 -319
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +295 -331
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +369 -156
- pydantic_ai/exceptions.py +12 -0
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +13 -5
- pydantic_ai/models/__init__.py +30 -18
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +1 -9
- pydantic_ai/models/google.py +2 -11
- pydantic_ai/models/groq.py +1 -0
- pydantic_ai/models/mistral.py +1 -1
- pydantic_ai/models/openai.py +3 -3
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/result.py +58 -45
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +10 -7
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/RECORD +45 -32
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_output.py
CHANGED
|
@@ -1,24 +1,21 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
import dataclasses
|
|
4
3
|
import inspect
|
|
5
4
|
import json
|
|
6
5
|
from abc import ABC, abstractmethod
|
|
7
|
-
from collections.abc import Awaitable,
|
|
6
|
+
from collections.abc import Awaitable, Sequence
|
|
8
7
|
from dataclasses import dataclass, field
|
|
9
8
|
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload
|
|
10
9
|
|
|
11
|
-
from opentelemetry.trace import Tracer
|
|
12
10
|
from pydantic import TypeAdapter, ValidationError
|
|
13
|
-
from pydantic_core import SchemaValidator
|
|
14
|
-
from typing_extensions import TypedDict, TypeVar, assert_never
|
|
15
|
-
|
|
16
|
-
from pydantic_graph.nodes import GraphRunContext
|
|
11
|
+
from pydantic_core import SchemaValidator, to_json
|
|
12
|
+
from typing_extensions import Self, TypedDict, TypeVar, assert_never
|
|
17
13
|
|
|
18
14
|
from . import _function_schema, _utils, messages as _messages
|
|
19
15
|
from ._run_context import AgentDepsT, RunContext
|
|
20
|
-
from .exceptions import ModelRetry, UserError
|
|
16
|
+
from .exceptions import ModelRetry, ToolRetryError, UserError
|
|
21
17
|
from .output import (
|
|
18
|
+
DeferredToolCalls,
|
|
22
19
|
NativeOutput,
|
|
23
20
|
OutputDataT,
|
|
24
21
|
OutputMode,
|
|
@@ -29,12 +26,12 @@ from .output import (
|
|
|
29
26
|
TextOutput,
|
|
30
27
|
TextOutputFunc,
|
|
31
28
|
ToolOutput,
|
|
29
|
+
_OutputSpecItem, # type: ignore[reportPrivateUsage]
|
|
32
30
|
)
|
|
33
31
|
from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
|
|
32
|
+
from .toolsets.abstract import AbstractToolset, ToolsetTool
|
|
34
33
|
|
|
35
34
|
if TYPE_CHECKING:
|
|
36
|
-
from pydantic_ai._agent_graph import DepsT, GraphAgentDeps, GraphAgentState
|
|
37
|
-
|
|
38
35
|
from .profiles import ModelProfile
|
|
39
36
|
|
|
40
37
|
T = TypeVar('T')
|
|
@@ -72,77 +69,45 @@ DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
|
|
|
72
69
|
DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
|
|
73
70
|
|
|
74
71
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
72
|
+
async def execute_output_function_with_span(
|
|
73
|
+
function_schema: _function_schema.FunctionSchema,
|
|
74
|
+
run_context: RunContext[AgentDepsT],
|
|
75
|
+
args: dict[str, Any] | Any,
|
|
76
|
+
) -> Any:
|
|
77
|
+
"""Execute a function call within a traced span, automatically recording the response."""
|
|
78
|
+
# Set up span attributes
|
|
79
|
+
tool_name = run_context.tool_name or getattr(function_schema.function, '__name__', 'output_function')
|
|
80
|
+
attributes = {
|
|
81
|
+
'gen_ai.tool.name': tool_name,
|
|
82
|
+
'logfire.msg': f'running output function: {tool_name}',
|
|
83
|
+
}
|
|
84
|
+
if run_context.tool_call_id:
|
|
85
|
+
attributes['gen_ai.tool.call.id'] = run_context.tool_call_id
|
|
86
|
+
if run_context.trace_include_content:
|
|
87
|
+
attributes['tool_arguments'] = to_json(args).decode()
|
|
88
|
+
attributes['logfire.json_schema'] = json.dumps(
|
|
89
|
+
{
|
|
90
|
+
'type': 'object',
|
|
91
|
+
'properties': {
|
|
92
|
+
'tool_arguments': {'type': 'object'},
|
|
93
|
+
'tool_response': {'type': 'object'},
|
|
94
|
+
},
|
|
95
|
+
}
|
|
96
|
+
)
|
|
78
97
|
|
|
79
|
-
tracer:
|
|
80
|
-
|
|
81
|
-
call: _messages.ToolCallPart | None = None
|
|
98
|
+
with run_context.tracer.start_as_current_span('running output function', attributes=attributes) as span:
|
|
99
|
+
output = await function_schema.call(args, run_context)
|
|
82
100
|
|
|
83
|
-
|
|
84
|
-
|
|
101
|
+
# Record response if content inclusion is enabled
|
|
102
|
+
if run_context.trace_include_content and span.is_recording():
|
|
103
|
+
from .models.instrumented import InstrumentedModel
|
|
85
104
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
run_context: RunContext[AgentDepsT],
|
|
90
|
-
args: dict[str, Any] | Any,
|
|
91
|
-
call: _messages.ToolCallPart,
|
|
92
|
-
include_tool_call_id: bool = True,
|
|
93
|
-
) -> Any:
|
|
94
|
-
"""Execute a function call within a traced span, automatically recording the response."""
|
|
95
|
-
# Set up span attributes
|
|
96
|
-
attributes = {
|
|
97
|
-
'gen_ai.tool.name': call.tool_name,
|
|
98
|
-
'logfire.msg': f'running output function: {call.tool_name}',
|
|
99
|
-
}
|
|
100
|
-
if include_tool_call_id:
|
|
101
|
-
attributes['gen_ai.tool.call.id'] = call.tool_call_id
|
|
102
|
-
if self.include_content:
|
|
103
|
-
attributes['tool_arguments'] = call.args_as_json_str()
|
|
104
|
-
attributes['logfire.json_schema'] = json.dumps(
|
|
105
|
-
{
|
|
106
|
-
'type': 'object',
|
|
107
|
-
'properties': {
|
|
108
|
-
'tool_arguments': {'type': 'object'},
|
|
109
|
-
'tool_response': {'type': 'object'},
|
|
110
|
-
},
|
|
111
|
-
}
|
|
105
|
+
span.set_attribute(
|
|
106
|
+
'tool_response',
|
|
107
|
+
output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)),
|
|
112
108
|
)
|
|
113
109
|
|
|
114
|
-
|
|
115
|
-
with self.tracer.start_as_current_span('running output function', attributes=attributes) as span:
|
|
116
|
-
output = await function_schema.call(args, run_context)
|
|
117
|
-
|
|
118
|
-
# Record response if content inclusion is enabled
|
|
119
|
-
if self.include_content and span.is_recording():
|
|
120
|
-
from .models.instrumented import InstrumentedModel
|
|
121
|
-
|
|
122
|
-
span.set_attribute(
|
|
123
|
-
'tool_response',
|
|
124
|
-
output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)),
|
|
125
|
-
)
|
|
126
|
-
|
|
127
|
-
return output
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
def build_trace_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> TraceContext:
|
|
131
|
-
"""Build a `TraceContext` from the current agent graph run context."""
|
|
132
|
-
return TraceContext(
|
|
133
|
-
tracer=ctx.deps.tracer,
|
|
134
|
-
include_content=(
|
|
135
|
-
ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content
|
|
136
|
-
),
|
|
137
|
-
)
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
class ToolRetryError(Exception):
|
|
141
|
-
"""Exception used to signal a `ToolRetry` message should be returned to the LLM."""
|
|
142
|
-
|
|
143
|
-
def __init__(self, tool_retry: _messages.RetryPromptPart):
|
|
144
|
-
self.tool_retry = tool_retry
|
|
145
|
-
super().__init__()
|
|
110
|
+
return output
|
|
146
111
|
|
|
147
112
|
|
|
148
113
|
@dataclass
|
|
@@ -158,23 +123,21 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
|
|
|
158
123
|
async def validate(
|
|
159
124
|
self,
|
|
160
125
|
result: T,
|
|
161
|
-
tool_call: _messages.ToolCallPart | None,
|
|
162
126
|
run_context: RunContext[AgentDepsT],
|
|
127
|
+
wrap_validation_errors: bool = True,
|
|
163
128
|
) -> T:
|
|
164
129
|
"""Validate a result but calling the function.
|
|
165
130
|
|
|
166
131
|
Args:
|
|
167
132
|
result: The result data after Pydantic validation the message content.
|
|
168
|
-
tool_call: The original tool call message, `None` if there was no tool call.
|
|
169
133
|
run_context: The current run context.
|
|
170
|
-
|
|
134
|
+
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
171
135
|
|
|
172
136
|
Returns:
|
|
173
137
|
Result of either the validated result data (ok) or a retry message (Err).
|
|
174
138
|
"""
|
|
175
139
|
if self._takes_ctx:
|
|
176
|
-
|
|
177
|
-
args = ctx, result
|
|
140
|
+
args = run_context, result
|
|
178
141
|
else:
|
|
179
142
|
args = (result,)
|
|
180
143
|
|
|
@@ -186,24 +149,32 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
|
|
|
186
149
|
function = cast(Callable[[Any], T], self.function)
|
|
187
150
|
result_data = await _utils.run_in_executor(function, *args)
|
|
188
151
|
except ModelRetry as r:
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
152
|
+
if wrap_validation_errors:
|
|
153
|
+
m = _messages.RetryPromptPart(
|
|
154
|
+
content=r.message,
|
|
155
|
+
tool_name=run_context.tool_name,
|
|
156
|
+
)
|
|
157
|
+
if run_context.tool_call_id: # pragma: no cover
|
|
158
|
+
m.tool_call_id = run_context.tool_call_id
|
|
159
|
+
raise ToolRetryError(m) from r
|
|
160
|
+
else:
|
|
161
|
+
raise r
|
|
194
162
|
else:
|
|
195
163
|
return result_data
|
|
196
164
|
|
|
197
165
|
|
|
166
|
+
@dataclass
|
|
198
167
|
class BaseOutputSchema(ABC, Generic[OutputDataT]):
|
|
168
|
+
allows_deferred_tool_calls: bool
|
|
169
|
+
|
|
199
170
|
@abstractmethod
|
|
200
171
|
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
201
172
|
raise NotImplementedError()
|
|
202
173
|
|
|
203
174
|
@property
|
|
204
|
-
def
|
|
205
|
-
"""Get the
|
|
206
|
-
return
|
|
175
|
+
def toolset(self) -> OutputToolset[Any] | None:
|
|
176
|
+
"""Get the toolset for this output schema."""
|
|
177
|
+
return None
|
|
207
178
|
|
|
208
179
|
|
|
209
180
|
@dataclass(init=False)
|
|
@@ -235,7 +206,7 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
235
206
|
) -> BaseOutputSchema[OutputDataT]: ...
|
|
236
207
|
|
|
237
208
|
@classmethod
|
|
238
|
-
def build(
|
|
209
|
+
def build( # noqa: C901
|
|
239
210
|
cls,
|
|
240
211
|
output_spec: OutputSpec[OutputDataT],
|
|
241
212
|
*,
|
|
@@ -245,117 +216,93 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
245
216
|
strict: bool | None = None,
|
|
246
217
|
) -> BaseOutputSchema[OutputDataT]:
|
|
247
218
|
"""Build an OutputSchema dataclass from an output type."""
|
|
248
|
-
|
|
249
|
-
|
|
219
|
+
raw_outputs = _flatten_output_spec(output_spec)
|
|
220
|
+
|
|
221
|
+
outputs = [output for output in raw_outputs if output is not DeferredToolCalls]
|
|
222
|
+
allows_deferred_tool_calls = len(outputs) < len(raw_outputs)
|
|
223
|
+
if len(outputs) == 0 and allows_deferred_tool_calls:
|
|
224
|
+
raise UserError('At least one output type must be provided other than `DeferredToolCalls`.')
|
|
225
|
+
|
|
226
|
+
if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
|
|
227
|
+
if len(outputs) > 1:
|
|
228
|
+
raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover
|
|
250
229
|
|
|
251
|
-
if isinstance(output_spec, NativeOutput):
|
|
252
230
|
return NativeOutputSchema(
|
|
253
|
-
cls._build_processor(
|
|
254
|
-
_flatten_output_spec(
|
|
255
|
-
name=
|
|
256
|
-
description=
|
|
257
|
-
strict=
|
|
258
|
-
)
|
|
231
|
+
processor=cls._build_processor(
|
|
232
|
+
_flatten_output_spec(output.outputs),
|
|
233
|
+
name=output.name,
|
|
234
|
+
description=output.description,
|
|
235
|
+
strict=output.strict,
|
|
236
|
+
),
|
|
237
|
+
allows_deferred_tool_calls=allows_deferred_tool_calls,
|
|
259
238
|
)
|
|
260
|
-
elif isinstance(
|
|
239
|
+
elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
|
|
240
|
+
if len(outputs) > 1:
|
|
241
|
+
raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover
|
|
242
|
+
|
|
261
243
|
return PromptedOutputSchema(
|
|
262
|
-
cls._build_processor(
|
|
263
|
-
_flatten_output_spec(
|
|
264
|
-
name=
|
|
265
|
-
description=
|
|
244
|
+
processor=cls._build_processor(
|
|
245
|
+
_flatten_output_spec(output.outputs),
|
|
246
|
+
name=output.name,
|
|
247
|
+
description=output.description,
|
|
266
248
|
),
|
|
267
|
-
template=
|
|
249
|
+
template=output.template,
|
|
250
|
+
allows_deferred_tool_calls=allows_deferred_tool_calls,
|
|
268
251
|
)
|
|
269
252
|
|
|
270
253
|
text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
|
|
271
254
|
tool_outputs: Sequence[ToolOutput[OutputDataT]] = []
|
|
272
255
|
other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = []
|
|
273
|
-
for output in
|
|
256
|
+
for output in outputs:
|
|
274
257
|
if output is str:
|
|
275
258
|
text_outputs.append(cast(type[str], output))
|
|
276
259
|
elif isinstance(output, TextOutput):
|
|
277
260
|
text_outputs.append(output)
|
|
278
261
|
elif isinstance(output, ToolOutput):
|
|
279
262
|
tool_outputs.append(output)
|
|
263
|
+
elif isinstance(output, NativeOutput):
|
|
264
|
+
# We can never get here because this is checked for above.
|
|
265
|
+
raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover
|
|
266
|
+
elif isinstance(output, PromptedOutput):
|
|
267
|
+
# We can never get here because this is checked for above.
|
|
268
|
+
raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover
|
|
280
269
|
else:
|
|
281
270
|
other_outputs.append(output)
|
|
282
271
|
|
|
283
|
-
|
|
272
|
+
toolset = OutputToolset.build(tool_outputs + other_outputs, name=name, description=description, strict=strict)
|
|
284
273
|
|
|
285
274
|
if len(text_outputs) > 0:
|
|
286
275
|
if len(text_outputs) > 1:
|
|
287
|
-
raise UserError('Only one
|
|
276
|
+
raise UserError('Only one `str` or `TextOutput` is allowed.')
|
|
288
277
|
text_output = text_outputs[0]
|
|
289
278
|
|
|
290
279
|
text_output_schema = None
|
|
291
280
|
if isinstance(text_output, TextOutput):
|
|
292
281
|
text_output_schema = PlainTextOutputProcessor(text_output.output_function)
|
|
293
282
|
|
|
294
|
-
if
|
|
295
|
-
return
|
|
283
|
+
if toolset:
|
|
284
|
+
return ToolOrTextOutputSchema(
|
|
285
|
+
processor=text_output_schema, toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls
|
|
286
|
+
)
|
|
296
287
|
else:
|
|
297
|
-
return
|
|
288
|
+
return PlainTextOutputSchema(
|
|
289
|
+
processor=text_output_schema, allows_deferred_tool_calls=allows_deferred_tool_calls
|
|
290
|
+
)
|
|
298
291
|
|
|
299
292
|
if len(tool_outputs) > 0:
|
|
300
|
-
return ToolOutputSchema(
|
|
293
|
+
return ToolOutputSchema(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls)
|
|
301
294
|
|
|
302
295
|
if len(other_outputs) > 0:
|
|
303
296
|
schema = OutputSchemaWithoutMode(
|
|
304
297
|
processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
|
|
305
|
-
|
|
298
|
+
toolset=toolset,
|
|
299
|
+
allows_deferred_tool_calls=allows_deferred_tool_calls,
|
|
306
300
|
)
|
|
307
301
|
if default_mode:
|
|
308
302
|
schema = schema.with_default_mode(default_mode)
|
|
309
303
|
return schema
|
|
310
304
|
|
|
311
|
-
raise UserError('
|
|
312
|
-
|
|
313
|
-
@staticmethod
|
|
314
|
-
def _build_tools(
|
|
315
|
-
outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
|
|
316
|
-
name: str | None = None,
|
|
317
|
-
description: str | None = None,
|
|
318
|
-
strict: bool | None = None,
|
|
319
|
-
) -> dict[str, OutputTool[OutputDataT]]:
|
|
320
|
-
tools: dict[str, OutputTool[OutputDataT]] = {}
|
|
321
|
-
|
|
322
|
-
default_name = name or DEFAULT_OUTPUT_TOOL_NAME
|
|
323
|
-
default_description = description
|
|
324
|
-
default_strict = strict
|
|
325
|
-
|
|
326
|
-
multiple = len(outputs) > 1
|
|
327
|
-
for output in outputs:
|
|
328
|
-
name = None
|
|
329
|
-
description = None
|
|
330
|
-
strict = None
|
|
331
|
-
if isinstance(output, ToolOutput):
|
|
332
|
-
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
|
|
333
|
-
name = output.name
|
|
334
|
-
description = output.description
|
|
335
|
-
strict = output.strict
|
|
336
|
-
|
|
337
|
-
output = output.output
|
|
338
|
-
|
|
339
|
-
description = description or default_description
|
|
340
|
-
if strict is None:
|
|
341
|
-
strict = default_strict
|
|
342
|
-
|
|
343
|
-
processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
|
|
344
|
-
|
|
345
|
-
if name is None:
|
|
346
|
-
name = default_name
|
|
347
|
-
if multiple:
|
|
348
|
-
name += f'_{processor.object_def.name}'
|
|
349
|
-
|
|
350
|
-
i = 1
|
|
351
|
-
original_name = name
|
|
352
|
-
while name in tools:
|
|
353
|
-
i += 1
|
|
354
|
-
name = f'{original_name}_{i}'
|
|
355
|
-
|
|
356
|
-
tools[name] = OutputTool(name=name, processor=processor, multiple=multiple)
|
|
357
|
-
|
|
358
|
-
return tools
|
|
305
|
+
raise UserError('At least one output type must be provided.')
|
|
359
306
|
|
|
360
307
|
@staticmethod
|
|
361
308
|
def _build_processor(
|
|
@@ -387,32 +334,39 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
387
334
|
@dataclass(init=False)
|
|
388
335
|
class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
|
|
389
336
|
processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]
|
|
390
|
-
|
|
337
|
+
_toolset: OutputToolset[Any] | None
|
|
391
338
|
|
|
392
339
|
def __init__(
|
|
393
340
|
self,
|
|
394
341
|
processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT],
|
|
395
|
-
|
|
342
|
+
toolset: OutputToolset[Any] | None,
|
|
343
|
+
allows_deferred_tool_calls: bool,
|
|
396
344
|
):
|
|
345
|
+
super().__init__(allows_deferred_tool_calls)
|
|
397
346
|
self.processor = processor
|
|
398
|
-
self.
|
|
347
|
+
self._toolset = toolset
|
|
399
348
|
|
|
400
349
|
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
401
350
|
if mode == 'native':
|
|
402
|
-
return NativeOutputSchema(
|
|
351
|
+
return NativeOutputSchema(
|
|
352
|
+
processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
|
|
353
|
+
)
|
|
403
354
|
elif mode == 'prompted':
|
|
404
|
-
return PromptedOutputSchema(
|
|
355
|
+
return PromptedOutputSchema(
|
|
356
|
+
processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
|
|
357
|
+
)
|
|
405
358
|
elif mode == 'tool':
|
|
406
|
-
return ToolOutputSchema(self.
|
|
359
|
+
return ToolOutputSchema(toolset=self.toolset, allows_deferred_tool_calls=self.allows_deferred_tool_calls)
|
|
407
360
|
else:
|
|
408
361
|
assert_never(mode)
|
|
409
362
|
|
|
410
363
|
@property
|
|
411
|
-
def
|
|
412
|
-
"""Get the
|
|
413
|
-
# We return
|
|
414
|
-
# At that point we may
|
|
415
|
-
|
|
364
|
+
def toolset(self) -> OutputToolset[Any] | None:
|
|
365
|
+
"""Get the toolset for this output schema."""
|
|
366
|
+
# We return a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor.
|
|
367
|
+
# At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time,
|
|
368
|
+
# but we cover ourselves just in case we end up using the tool output mode.
|
|
369
|
+
return self._toolset
|
|
416
370
|
|
|
417
371
|
|
|
418
372
|
class TextOutputSchema(OutputSchema[OutputDataT], ABC):
|
|
@@ -421,7 +375,6 @@ class TextOutputSchema(OutputSchema[OutputDataT], ABC):
|
|
|
421
375
|
self,
|
|
422
376
|
text: str,
|
|
423
377
|
run_context: RunContext[AgentDepsT],
|
|
424
|
-
trace_context: TraceContext,
|
|
425
378
|
allow_partial: bool = False,
|
|
426
379
|
wrap_validation_errors: bool = True,
|
|
427
380
|
) -> OutputDataT:
|
|
@@ -444,7 +397,6 @@ class PlainTextOutputSchema(TextOutputSchema[OutputDataT]):
|
|
|
444
397
|
self,
|
|
445
398
|
text: str,
|
|
446
399
|
run_context: RunContext[AgentDepsT],
|
|
447
|
-
trace_context: TraceContext,
|
|
448
400
|
allow_partial: bool = False,
|
|
449
401
|
wrap_validation_errors: bool = True,
|
|
450
402
|
) -> OutputDataT:
|
|
@@ -453,7 +405,6 @@ class PlainTextOutputSchema(TextOutputSchema[OutputDataT]):
|
|
|
453
405
|
Args:
|
|
454
406
|
text: The output text to validate.
|
|
455
407
|
run_context: The current run context.
|
|
456
|
-
trace_context: The trace context to use for tracing the output processing.
|
|
457
408
|
allow_partial: If true, allow partial validation.
|
|
458
409
|
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
459
410
|
|
|
@@ -464,7 +415,7 @@ class PlainTextOutputSchema(TextOutputSchema[OutputDataT]):
|
|
|
464
415
|
return cast(OutputDataT, text)
|
|
465
416
|
|
|
466
417
|
return await self.processor.process(
|
|
467
|
-
text, run_context,
|
|
418
|
+
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
|
|
468
419
|
)
|
|
469
420
|
|
|
470
421
|
|
|
@@ -486,13 +437,12 @@ class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
|
486
437
|
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
487
438
|
"""Raise an error if the mode is not supported by the model."""
|
|
488
439
|
if not profile.supports_json_schema_output:
|
|
489
|
-
raise UserError('
|
|
440
|
+
raise UserError('Native structured output is not supported by the model.')
|
|
490
441
|
|
|
491
442
|
async def process(
|
|
492
443
|
self,
|
|
493
444
|
text: str,
|
|
494
445
|
run_context: RunContext[AgentDepsT],
|
|
495
|
-
trace_context: TraceContext,
|
|
496
446
|
allow_partial: bool = False,
|
|
497
447
|
wrap_validation_errors: bool = True,
|
|
498
448
|
) -> OutputDataT:
|
|
@@ -501,7 +451,6 @@ class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
|
501
451
|
Args:
|
|
502
452
|
text: The output text to validate.
|
|
503
453
|
run_context: The current run context.
|
|
504
|
-
trace_context: The trace context to use for tracing the output processing.
|
|
505
454
|
allow_partial: If true, allow partial validation.
|
|
506
455
|
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
507
456
|
|
|
@@ -509,7 +458,7 @@ class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
|
509
458
|
Either the validated output data (left) or a retry message (right).
|
|
510
459
|
"""
|
|
511
460
|
return await self.processor.process(
|
|
512
|
-
text, run_context,
|
|
461
|
+
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
|
|
513
462
|
)
|
|
514
463
|
|
|
515
464
|
|
|
@@ -545,7 +494,6 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
|
545
494
|
self,
|
|
546
495
|
text: str,
|
|
547
496
|
run_context: RunContext[AgentDepsT],
|
|
548
|
-
trace_context: TraceContext,
|
|
549
497
|
allow_partial: bool = False,
|
|
550
498
|
wrap_validation_errors: bool = True,
|
|
551
499
|
) -> OutputDataT:
|
|
@@ -554,7 +502,6 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
|
554
502
|
Args:
|
|
555
503
|
text: The output text to validate.
|
|
556
504
|
run_context: The current run context.
|
|
557
|
-
trace_context: The trace context to use for tracing the output processing.
|
|
558
505
|
allow_partial: If true, allow partial validation.
|
|
559
506
|
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
560
507
|
|
|
@@ -564,16 +511,17 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
|
564
511
|
text = _utils.strip_markdown_fences(text)
|
|
565
512
|
|
|
566
513
|
return await self.processor.process(
|
|
567
|
-
text, run_context,
|
|
514
|
+
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
|
|
568
515
|
)
|
|
569
516
|
|
|
570
517
|
|
|
571
518
|
@dataclass(init=False)
|
|
572
519
|
class ToolOutputSchema(OutputSchema[OutputDataT]):
|
|
573
|
-
|
|
520
|
+
_toolset: OutputToolset[Any] | None
|
|
574
521
|
|
|
575
|
-
def __init__(self,
|
|
576
|
-
|
|
522
|
+
def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_calls: bool):
|
|
523
|
+
super().__init__(allows_deferred_tool_calls)
|
|
524
|
+
self._toolset = toolset
|
|
577
525
|
|
|
578
526
|
@property
|
|
579
527
|
def mode(self) -> OutputMode:
|
|
@@ -585,36 +533,9 @@ class ToolOutputSchema(OutputSchema[OutputDataT]):
|
|
|
585
533
|
raise UserError('Output tools are not supported by the model.')
|
|
586
534
|
|
|
587
535
|
@property
|
|
588
|
-
def
|
|
589
|
-
"""Get the
|
|
590
|
-
return self.
|
|
591
|
-
|
|
592
|
-
def tool_names(self) -> list[str]:
|
|
593
|
-
"""Return the names of the tools."""
|
|
594
|
-
return list(self.tools.keys())
|
|
595
|
-
|
|
596
|
-
def tool_defs(self) -> list[ToolDefinition]:
|
|
597
|
-
"""Get tool definitions to register with the model."""
|
|
598
|
-
return [t.tool_def for t in self.tools.values()]
|
|
599
|
-
|
|
600
|
-
def find_named_tool(
|
|
601
|
-
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
|
|
602
|
-
) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None:
|
|
603
|
-
"""Find a tool that matches one of the calls, with a specific name."""
|
|
604
|
-
for part in parts: # pragma: no branch
|
|
605
|
-
if isinstance(part, _messages.ToolCallPart): # pragma: no branch
|
|
606
|
-
if part.tool_name == tool_name:
|
|
607
|
-
return part, self.tools[tool_name]
|
|
608
|
-
|
|
609
|
-
def find_tool(
|
|
610
|
-
self,
|
|
611
|
-
parts: Iterable[_messages.ModelResponsePart],
|
|
612
|
-
) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]:
|
|
613
|
-
"""Find a tool that matches one of the calls."""
|
|
614
|
-
for part in parts:
|
|
615
|
-
if isinstance(part, _messages.ToolCallPart): # pragma: no branch
|
|
616
|
-
if result := self.tools.get(part.tool_name):
|
|
617
|
-
yield part, result
|
|
536
|
+
def toolset(self) -> OutputToolset[Any] | None:
|
|
537
|
+
"""Get the toolset for this output schema."""
|
|
538
|
+
return self._toolset
|
|
618
539
|
|
|
619
540
|
|
|
620
541
|
@dataclass(init=False)
|
|
@@ -622,10 +543,11 @@ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchem
|
|
|
622
543
|
def __init__(
|
|
623
544
|
self,
|
|
624
545
|
processor: PlainTextOutputProcessor[OutputDataT] | None,
|
|
625
|
-
|
|
546
|
+
toolset: OutputToolset[Any] | None,
|
|
547
|
+
allows_deferred_tool_calls: bool,
|
|
626
548
|
):
|
|
549
|
+
super().__init__(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls)
|
|
627
550
|
self.processor = processor
|
|
628
|
-
self._tools = tools
|
|
629
551
|
|
|
630
552
|
@property
|
|
631
553
|
def mode(self) -> OutputMode:
|
|
@@ -647,7 +569,6 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
|
|
|
647
569
|
self,
|
|
648
570
|
data: str,
|
|
649
571
|
run_context: RunContext[AgentDepsT],
|
|
650
|
-
trace_context: TraceContext,
|
|
651
572
|
allow_partial: bool = False,
|
|
652
573
|
wrap_validation_errors: bool = True,
|
|
653
574
|
) -> OutputDataT:
|
|
@@ -659,7 +580,7 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
|
|
|
659
580
|
class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
660
581
|
object_def: OutputObjectDefinition
|
|
661
582
|
outer_typed_dict_key: str | None = None
|
|
662
|
-
|
|
583
|
+
validator: SchemaValidator
|
|
663
584
|
_function_schema: _function_schema.FunctionSchema | None = None
|
|
664
585
|
|
|
665
586
|
def __init__(
|
|
@@ -672,7 +593,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
672
593
|
):
|
|
673
594
|
if inspect.isfunction(output) or inspect.ismethod(output):
|
|
674
595
|
self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema)
|
|
675
|
-
self.
|
|
596
|
+
self.validator = self._function_schema.validator
|
|
676
597
|
json_schema = self._function_schema.json_schema
|
|
677
598
|
json_schema['description'] = self._function_schema.description
|
|
678
599
|
else:
|
|
@@ -688,7 +609,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
688
609
|
type_adapter = TypeAdapter(response_data_typed_dict)
|
|
689
610
|
|
|
690
611
|
# Really a PluggableSchemaValidator, but it's API-compatible
|
|
691
|
-
self.
|
|
612
|
+
self.validator = cast(SchemaValidator, type_adapter.validator)
|
|
692
613
|
json_schema = _utils.check_object_json_schema(
|
|
693
614
|
type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
|
|
694
615
|
)
|
|
@@ -717,7 +638,6 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
717
638
|
self,
|
|
718
639
|
data: str | dict[str, Any] | None,
|
|
719
640
|
run_context: RunContext[AgentDepsT],
|
|
720
|
-
trace_context: TraceContext,
|
|
721
641
|
allow_partial: bool = False,
|
|
722
642
|
wrap_validation_errors: bool = True,
|
|
723
643
|
) -> OutputDataT:
|
|
@@ -726,7 +646,6 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
726
646
|
Args:
|
|
727
647
|
data: The output data to validate.
|
|
728
648
|
run_context: The current run context.
|
|
729
|
-
trace_context: The trace context to use for tracing the output processing.
|
|
730
649
|
allow_partial: If true, allow partial validation.
|
|
731
650
|
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
732
651
|
|
|
@@ -734,11 +653,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
734
653
|
Either the validated output data (left) or a retry message (right).
|
|
735
654
|
"""
|
|
736
655
|
try:
|
|
737
|
-
|
|
738
|
-
if isinstance(data, str):
|
|
739
|
-
output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
|
|
740
|
-
else:
|
|
741
|
-
output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
|
|
656
|
+
output = self.validate(data, allow_partial)
|
|
742
657
|
except ValidationError as e:
|
|
743
658
|
if wrap_validation_errors:
|
|
744
659
|
m = _messages.RetryPromptPart(
|
|
@@ -748,30 +663,40 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
748
663
|
else:
|
|
749
664
|
raise
|
|
750
665
|
|
|
666
|
+
try:
|
|
667
|
+
output = await self.call(output, run_context)
|
|
668
|
+
except ModelRetry as r:
|
|
669
|
+
if wrap_validation_errors:
|
|
670
|
+
m = _messages.RetryPromptPart(
|
|
671
|
+
content=r.message,
|
|
672
|
+
)
|
|
673
|
+
raise ToolRetryError(m) from r
|
|
674
|
+
else:
|
|
675
|
+
raise # pragma: no cover
|
|
676
|
+
|
|
677
|
+
return output
|
|
678
|
+
|
|
679
|
+
def validate(
|
|
680
|
+
self,
|
|
681
|
+
data: str | dict[str, Any] | None,
|
|
682
|
+
allow_partial: bool = False,
|
|
683
|
+
) -> dict[str, Any]:
|
|
684
|
+
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
|
|
685
|
+
if isinstance(data, str):
|
|
686
|
+
return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
|
|
687
|
+
else:
|
|
688
|
+
return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
|
|
689
|
+
|
|
690
|
+
async def call(
|
|
691
|
+
self,
|
|
692
|
+
output: Any,
|
|
693
|
+
run_context: RunContext[AgentDepsT],
|
|
694
|
+
):
|
|
751
695
|
if k := self.outer_typed_dict_key:
|
|
752
696
|
output = output[k]
|
|
753
697
|
|
|
754
698
|
if self._function_schema:
|
|
755
|
-
|
|
756
|
-
if trace_context.call:
|
|
757
|
-
call = trace_context.call
|
|
758
|
-
include_tool_call_id = True
|
|
759
|
-
else:
|
|
760
|
-
function_name = getattr(self._function_schema.function, '__name__', 'output_function')
|
|
761
|
-
call = _messages.ToolCallPart(tool_name=function_name, args=data)
|
|
762
|
-
include_tool_call_id = False
|
|
763
|
-
try:
|
|
764
|
-
output = await trace_context.execute_function_with_span(
|
|
765
|
-
self._function_schema, run_context, output, call, include_tool_call_id
|
|
766
|
-
)
|
|
767
|
-
except ModelRetry as r:
|
|
768
|
-
if wrap_validation_errors:
|
|
769
|
-
m = _messages.RetryPromptPart(
|
|
770
|
-
content=r.message,
|
|
771
|
-
)
|
|
772
|
-
raise ToolRetryError(m) from r
|
|
773
|
-
else:
|
|
774
|
-
raise
|
|
699
|
+
output = await execute_output_function_with_span(self._function_schema, run_context, output)
|
|
775
700
|
|
|
776
701
|
return output
|
|
777
702
|
|
|
@@ -876,12 +801,11 @@ class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
876
801
|
self,
|
|
877
802
|
data: str | dict[str, Any] | None,
|
|
878
803
|
run_context: RunContext[AgentDepsT],
|
|
879
|
-
trace_context: TraceContext,
|
|
880
804
|
allow_partial: bool = False,
|
|
881
805
|
wrap_validation_errors: bool = True,
|
|
882
806
|
) -> OutputDataT:
|
|
883
807
|
union_object = await self._union_processor.process(
|
|
884
|
-
data, run_context,
|
|
808
|
+
data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
|
|
885
809
|
)
|
|
886
810
|
|
|
887
811
|
result = union_object.result
|
|
@@ -897,7 +821,7 @@ class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
897
821
|
raise
|
|
898
822
|
|
|
899
823
|
return await processor.process(
|
|
900
|
-
data, run_context,
|
|
824
|
+
data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
|
|
901
825
|
)
|
|
902
826
|
|
|
903
827
|
|
|
@@ -928,20 +852,12 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
928
852
|
self,
|
|
929
853
|
data: str,
|
|
930
854
|
run_context: RunContext[AgentDepsT],
|
|
931
|
-
trace_context: TraceContext,
|
|
932
855
|
allow_partial: bool = False,
|
|
933
856
|
wrap_validation_errors: bool = True,
|
|
934
857
|
) -> OutputDataT:
|
|
935
858
|
args = {self._str_argument_name: data}
|
|
936
|
-
# Wraps the output function call in an OpenTelemetry span.
|
|
937
|
-
# Note: PlainTextOutputProcessor is used for text responses (not tool calls),
|
|
938
|
-
# so we don't have tool call attributes like gen_ai.tool.name or gen_ai.tool.call.id
|
|
939
|
-
function_name = getattr(self._function_schema.function, '__name__', 'text_output_function')
|
|
940
|
-
call = _messages.ToolCallPart(tool_name=function_name, args=args)
|
|
941
859
|
try:
|
|
942
|
-
output = await
|
|
943
|
-
self._function_schema, run_context, args, call, include_tool_call_id=False
|
|
944
|
-
)
|
|
860
|
+
output = await execute_output_function_with_span(self._function_schema, run_context, args)
|
|
945
861
|
except ModelRetry as r:
|
|
946
862
|
if wrap_validation_errors:
|
|
947
863
|
m = _messages.RetryPromptPart(
|
|
@@ -955,91 +871,139 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
955
871
|
|
|
956
872
|
|
|
957
873
|
@dataclass(init=False)
|
|
958
|
-
class
|
|
959
|
-
|
|
960
|
-
tool_def: ToolDefinition
|
|
874
|
+
class OutputToolset(AbstractToolset[AgentDepsT]):
|
|
875
|
+
"""A toolset that contains contains output tools for agent output types."""
|
|
961
876
|
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
877
|
+
_tool_defs: list[ToolDefinition]
|
|
878
|
+
"""The tool definitions for the output tools in this toolset."""
|
|
879
|
+
processors: dict[str, ObjectOutputProcessor[Any]]
|
|
880
|
+
"""The processors for the output tools in this toolset."""
|
|
881
|
+
max_retries: int
|
|
882
|
+
output_validators: list[OutputValidator[AgentDepsT, Any]]
|
|
965
883
|
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
884
|
+
@classmethod
|
|
885
|
+
def build(
|
|
886
|
+
cls,
|
|
887
|
+
outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
|
|
888
|
+
name: str | None = None,
|
|
889
|
+
description: str | None = None,
|
|
890
|
+
strict: bool | None = None,
|
|
891
|
+
) -> Self | None:
|
|
892
|
+
if len(outputs) == 0:
|
|
893
|
+
return None
|
|
971
894
|
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
description=description,
|
|
975
|
-
parameters_json_schema=object_def.json_schema,
|
|
976
|
-
strict=object_def.strict,
|
|
977
|
-
outer_typed_dict_key=processor.outer_typed_dict_key,
|
|
978
|
-
)
|
|
895
|
+
processors: dict[str, ObjectOutputProcessor[Any]] = {}
|
|
896
|
+
tool_defs: list[ToolDefinition] = []
|
|
979
897
|
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
run_context: RunContext[AgentDepsT],
|
|
984
|
-
trace_context: TraceContext,
|
|
985
|
-
allow_partial: bool = False,
|
|
986
|
-
wrap_validation_errors: bool = True,
|
|
987
|
-
) -> OutputDataT:
|
|
988
|
-
"""Process an output message.
|
|
898
|
+
default_name = name or DEFAULT_OUTPUT_TOOL_NAME
|
|
899
|
+
default_description = description
|
|
900
|
+
default_strict = strict
|
|
989
901
|
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
902
|
+
multiple = len(outputs) > 1
|
|
903
|
+
for output in outputs:
|
|
904
|
+
name = None
|
|
905
|
+
description = None
|
|
906
|
+
strict = None
|
|
907
|
+
if isinstance(output, ToolOutput):
|
|
908
|
+
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
|
|
909
|
+
name = output.name
|
|
910
|
+
description = output.description
|
|
911
|
+
strict = output.strict
|
|
996
912
|
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
913
|
+
output = output.output
|
|
914
|
+
|
|
915
|
+
description = description or default_description
|
|
916
|
+
if strict is None:
|
|
917
|
+
strict = default_strict
|
|
918
|
+
|
|
919
|
+
processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
|
|
920
|
+
object_def = processor.object_def
|
|
921
|
+
|
|
922
|
+
if name is None:
|
|
923
|
+
name = default_name
|
|
924
|
+
if multiple:
|
|
925
|
+
name += f'_{object_def.name}'
|
|
926
|
+
|
|
927
|
+
i = 1
|
|
928
|
+
original_name = name
|
|
929
|
+
while name in processors:
|
|
930
|
+
i += 1
|
|
931
|
+
name = f'{original_name}_{i}'
|
|
932
|
+
|
|
933
|
+
description = object_def.description
|
|
934
|
+
if not description:
|
|
935
|
+
description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
|
|
936
|
+
if multiple:
|
|
937
|
+
description = f'{object_def.name}: {description}'
|
|
938
|
+
|
|
939
|
+
tool_def = ToolDefinition(
|
|
940
|
+
name=name,
|
|
941
|
+
description=description,
|
|
942
|
+
parameters_json_schema=object_def.json_schema,
|
|
943
|
+
strict=object_def.strict,
|
|
944
|
+
outer_typed_dict_key=processor.outer_typed_dict_key,
|
|
945
|
+
kind='output',
|
|
1007
946
|
)
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
947
|
+
processors[name] = processor
|
|
948
|
+
tool_defs.append(tool_def)
|
|
949
|
+
|
|
950
|
+
return cls(processors=processors, tool_defs=tool_defs)
|
|
951
|
+
|
|
952
|
+
def __init__(
|
|
953
|
+
self,
|
|
954
|
+
tool_defs: list[ToolDefinition],
|
|
955
|
+
processors: dict[str, ObjectOutputProcessor[Any]],
|
|
956
|
+
max_retries: int = 1,
|
|
957
|
+
output_validators: list[OutputValidator[AgentDepsT, Any]] | None = None,
|
|
958
|
+
):
|
|
959
|
+
self.processors = processors
|
|
960
|
+
self._tool_defs = tool_defs
|
|
961
|
+
self.max_retries = max_retries
|
|
962
|
+
self.output_validators = output_validators or []
|
|
963
|
+
|
|
964
|
+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
|
|
965
|
+
return {
|
|
966
|
+
tool_def.name: ToolsetTool(
|
|
967
|
+
toolset=self,
|
|
968
|
+
tool_def=tool_def,
|
|
969
|
+
max_retries=self.max_retries,
|
|
970
|
+
args_validator=self.processors[tool_def.name].validator,
|
|
971
|
+
)
|
|
972
|
+
for tool_def in self._tool_defs
|
|
973
|
+
}
|
|
974
|
+
|
|
975
|
+
async def call_tool(
|
|
976
|
+
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
|
|
977
|
+
) -> Any:
|
|
978
|
+
output = await self.processors[name].call(tool_args, ctx)
|
|
979
|
+
for validator in self.output_validators:
|
|
980
|
+
output = await validator.validate(output, ctx, wrap_validation_errors=False)
|
|
981
|
+
return output
|
|
982
|
+
|
|
983
|
+
|
|
984
|
+
@overload
|
|
985
|
+
def _flatten_output_spec(
|
|
986
|
+
output_spec: OutputTypeOrFunction[T] | Sequence[OutputTypeOrFunction[T]],
|
|
987
|
+
) -> Sequence[OutputTypeOrFunction[T]]: ...
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
@overload
|
|
991
|
+
def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: ...
|
|
1030
992
|
|
|
1031
993
|
|
|
1032
|
-
def _flatten_output_spec(output_spec:
|
|
1033
|
-
outputs: Sequence[T]
|
|
994
|
+
def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]:
|
|
995
|
+
outputs: Sequence[OutputSpec[T]]
|
|
1034
996
|
if isinstance(output_spec, Sequence):
|
|
1035
997
|
outputs = output_spec
|
|
1036
998
|
else:
|
|
1037
999
|
outputs = (output_spec,)
|
|
1038
1000
|
|
|
1039
|
-
outputs_flat: list[T] = []
|
|
1001
|
+
outputs_flat: list[_OutputSpecItem[T]] = []
|
|
1040
1002
|
for output in outputs:
|
|
1041
|
-
if
|
|
1003
|
+
if isinstance(output, Sequence):
|
|
1004
|
+
outputs_flat.extend(_flatten_output_spec(cast(OutputSpec[T], output)))
|
|
1005
|
+
elif union_types := _utils.get_union_args(output):
|
|
1042
1006
|
outputs_flat.extend(union_types)
|
|
1043
1007
|
else:
|
|
1044
|
-
outputs_flat.append(output)
|
|
1008
|
+
outputs_flat.append(cast(_OutputSpecItem[T], output))
|
|
1045
1009
|
return outputs_flat
|