pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +219 -315
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +296 -226
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +369 -155
- pydantic_ai/common_tools/duckduckgo.py +5 -2
- pydantic_ai/exceptions.py +14 -2
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +19 -9
- pydantic_ai/models/__init__.py +43 -19
- pydantic_ai/models/anthropic.py +2 -2
- pydantic_ai/models/bedrock.py +1 -1
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +3 -11
- pydantic_ai/models/google.py +3 -12
- pydantic_ai/models/groq.py +2 -1
- pydantic_ai/models/huggingface.py +463 -0
- pydantic_ai/models/instrumented.py +1 -1
- pydantic_ai/models/mistral.py +3 -3
- pydantic_ai/models/openai.py +5 -5
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google.py +2 -2
- pydantic_ai/providers/google_vertex.py +10 -5
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/providers/huggingface.py +88 -0
- pydantic_ai/result.py +57 -33
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
- pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
- pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_output.py
CHANGED
|
@@ -3,18 +3,19 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import json
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
|
-
from collections.abc import Awaitable,
|
|
6
|
+
from collections.abc import Awaitable, Sequence
|
|
7
7
|
from dataclasses import dataclass, field
|
|
8
8
|
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload
|
|
9
9
|
|
|
10
10
|
from pydantic import TypeAdapter, ValidationError
|
|
11
|
-
from pydantic_core import SchemaValidator
|
|
12
|
-
from typing_extensions import TypedDict, TypeVar, assert_never
|
|
11
|
+
from pydantic_core import SchemaValidator, to_json
|
|
12
|
+
from typing_extensions import Self, TypedDict, TypeVar, assert_never
|
|
13
13
|
|
|
14
14
|
from . import _function_schema, _utils, messages as _messages
|
|
15
15
|
from ._run_context import AgentDepsT, RunContext
|
|
16
|
-
from .exceptions import ModelRetry, UserError
|
|
16
|
+
from .exceptions import ModelRetry, ToolRetryError, UserError
|
|
17
17
|
from .output import (
|
|
18
|
+
DeferredToolCalls,
|
|
18
19
|
NativeOutput,
|
|
19
20
|
OutputDataT,
|
|
20
21
|
OutputMode,
|
|
@@ -25,8 +26,10 @@ from .output import (
|
|
|
25
26
|
TextOutput,
|
|
26
27
|
TextOutputFunc,
|
|
27
28
|
ToolOutput,
|
|
29
|
+
_OutputSpecItem, # type: ignore[reportPrivateUsage]
|
|
28
30
|
)
|
|
29
31
|
from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
|
|
32
|
+
from .toolsets.abstract import AbstractToolset, ToolsetTool
|
|
30
33
|
|
|
31
34
|
if TYPE_CHECKING:
|
|
32
35
|
from .profiles import ModelProfile
|
|
@@ -66,12 +69,45 @@ DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
|
|
|
66
69
|
DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
|
|
67
70
|
|
|
68
71
|
|
|
69
|
-
|
|
70
|
-
|
|
72
|
+
async def execute_output_function_with_span(
|
|
73
|
+
function_schema: _function_schema.FunctionSchema,
|
|
74
|
+
run_context: RunContext[AgentDepsT],
|
|
75
|
+
args: dict[str, Any] | Any,
|
|
76
|
+
) -> Any:
|
|
77
|
+
"""Execute a function call within a traced span, automatically recording the response."""
|
|
78
|
+
# Set up span attributes
|
|
79
|
+
tool_name = run_context.tool_name or getattr(function_schema.function, '__name__', 'output_function')
|
|
80
|
+
attributes = {
|
|
81
|
+
'gen_ai.tool.name': tool_name,
|
|
82
|
+
'logfire.msg': f'running output function: {tool_name}',
|
|
83
|
+
}
|
|
84
|
+
if run_context.tool_call_id:
|
|
85
|
+
attributes['gen_ai.tool.call.id'] = run_context.tool_call_id
|
|
86
|
+
if run_context.trace_include_content:
|
|
87
|
+
attributes['tool_arguments'] = to_json(args).decode()
|
|
88
|
+
attributes['logfire.json_schema'] = json.dumps(
|
|
89
|
+
{
|
|
90
|
+
'type': 'object',
|
|
91
|
+
'properties': {
|
|
92
|
+
'tool_arguments': {'type': 'object'},
|
|
93
|
+
'tool_response': {'type': 'object'},
|
|
94
|
+
},
|
|
95
|
+
}
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
with run_context.tracer.start_as_current_span('running output function', attributes=attributes) as span:
|
|
99
|
+
output = await function_schema.call(args, run_context)
|
|
71
100
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
101
|
+
# Record response if content inclusion is enabled
|
|
102
|
+
if run_context.trace_include_content and span.is_recording():
|
|
103
|
+
from .models.instrumented import InstrumentedModel
|
|
104
|
+
|
|
105
|
+
span.set_attribute(
|
|
106
|
+
'tool_response',
|
|
107
|
+
output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)),
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
return output
|
|
75
111
|
|
|
76
112
|
|
|
77
113
|
@dataclass
|
|
@@ -87,22 +123,21 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
|
|
|
87
123
|
async def validate(
|
|
88
124
|
self,
|
|
89
125
|
result: T,
|
|
90
|
-
tool_call: _messages.ToolCallPart | None,
|
|
91
126
|
run_context: RunContext[AgentDepsT],
|
|
127
|
+
wrap_validation_errors: bool = True,
|
|
92
128
|
) -> T:
|
|
93
129
|
"""Validate a result but calling the function.
|
|
94
130
|
|
|
95
131
|
Args:
|
|
96
132
|
result: The result data after Pydantic validation the message content.
|
|
97
|
-
tool_call: The original tool call message, `None` if there was no tool call.
|
|
98
133
|
run_context: The current run context.
|
|
134
|
+
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
99
135
|
|
|
100
136
|
Returns:
|
|
101
137
|
Result of either the validated result data (ok) or a retry message (Err).
|
|
102
138
|
"""
|
|
103
139
|
if self._takes_ctx:
|
|
104
|
-
|
|
105
|
-
args = ctx, result
|
|
140
|
+
args = run_context, result
|
|
106
141
|
else:
|
|
107
142
|
args = (result,)
|
|
108
143
|
|
|
@@ -114,24 +149,32 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
|
|
|
114
149
|
function = cast(Callable[[Any], T], self.function)
|
|
115
150
|
result_data = await _utils.run_in_executor(function, *args)
|
|
116
151
|
except ModelRetry as r:
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
152
|
+
if wrap_validation_errors:
|
|
153
|
+
m = _messages.RetryPromptPart(
|
|
154
|
+
content=r.message,
|
|
155
|
+
tool_name=run_context.tool_name,
|
|
156
|
+
)
|
|
157
|
+
if run_context.tool_call_id: # pragma: no cover
|
|
158
|
+
m.tool_call_id = run_context.tool_call_id
|
|
159
|
+
raise ToolRetryError(m) from r
|
|
160
|
+
else:
|
|
161
|
+
raise r
|
|
122
162
|
else:
|
|
123
163
|
return result_data
|
|
124
164
|
|
|
125
165
|
|
|
166
|
+
@dataclass
|
|
126
167
|
class BaseOutputSchema(ABC, Generic[OutputDataT]):
|
|
168
|
+
allows_deferred_tool_calls: bool
|
|
169
|
+
|
|
127
170
|
@abstractmethod
|
|
128
171
|
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
129
172
|
raise NotImplementedError()
|
|
130
173
|
|
|
131
174
|
@property
|
|
132
|
-
def
|
|
133
|
-
"""Get the
|
|
134
|
-
return
|
|
175
|
+
def toolset(self) -> OutputToolset[Any] | None:
|
|
176
|
+
"""Get the toolset for this output schema."""
|
|
177
|
+
return None
|
|
135
178
|
|
|
136
179
|
|
|
137
180
|
@dataclass(init=False)
|
|
@@ -163,7 +206,7 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
163
206
|
) -> BaseOutputSchema[OutputDataT]: ...
|
|
164
207
|
|
|
165
208
|
@classmethod
|
|
166
|
-
def build(
|
|
209
|
+
def build( # noqa: C901
|
|
167
210
|
cls,
|
|
168
211
|
output_spec: OutputSpec[OutputDataT],
|
|
169
212
|
*,
|
|
@@ -173,117 +216,93 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
173
216
|
strict: bool | None = None,
|
|
174
217
|
) -> BaseOutputSchema[OutputDataT]:
|
|
175
218
|
"""Build an OutputSchema dataclass from an output type."""
|
|
176
|
-
|
|
177
|
-
|
|
219
|
+
raw_outputs = _flatten_output_spec(output_spec)
|
|
220
|
+
|
|
221
|
+
outputs = [output for output in raw_outputs if output is not DeferredToolCalls]
|
|
222
|
+
allows_deferred_tool_calls = len(outputs) < len(raw_outputs)
|
|
223
|
+
if len(outputs) == 0 and allows_deferred_tool_calls:
|
|
224
|
+
raise UserError('At least one output type must be provided other than `DeferredToolCalls`.')
|
|
225
|
+
|
|
226
|
+
if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
|
|
227
|
+
if len(outputs) > 1:
|
|
228
|
+
raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover
|
|
178
229
|
|
|
179
|
-
if isinstance(output_spec, NativeOutput):
|
|
180
230
|
return NativeOutputSchema(
|
|
181
|
-
cls._build_processor(
|
|
182
|
-
_flatten_output_spec(
|
|
183
|
-
name=
|
|
184
|
-
description=
|
|
185
|
-
strict=
|
|
186
|
-
)
|
|
231
|
+
processor=cls._build_processor(
|
|
232
|
+
_flatten_output_spec(output.outputs),
|
|
233
|
+
name=output.name,
|
|
234
|
+
description=output.description,
|
|
235
|
+
strict=output.strict,
|
|
236
|
+
),
|
|
237
|
+
allows_deferred_tool_calls=allows_deferred_tool_calls,
|
|
187
238
|
)
|
|
188
|
-
elif isinstance(
|
|
239
|
+
elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
|
|
240
|
+
if len(outputs) > 1:
|
|
241
|
+
raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover
|
|
242
|
+
|
|
189
243
|
return PromptedOutputSchema(
|
|
190
|
-
cls._build_processor(
|
|
191
|
-
_flatten_output_spec(
|
|
192
|
-
name=
|
|
193
|
-
description=
|
|
244
|
+
processor=cls._build_processor(
|
|
245
|
+
_flatten_output_spec(output.outputs),
|
|
246
|
+
name=output.name,
|
|
247
|
+
description=output.description,
|
|
194
248
|
),
|
|
195
|
-
template=
|
|
249
|
+
template=output.template,
|
|
250
|
+
allows_deferred_tool_calls=allows_deferred_tool_calls,
|
|
196
251
|
)
|
|
197
252
|
|
|
198
253
|
text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
|
|
199
254
|
tool_outputs: Sequence[ToolOutput[OutputDataT]] = []
|
|
200
255
|
other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = []
|
|
201
|
-
for output in
|
|
256
|
+
for output in outputs:
|
|
202
257
|
if output is str:
|
|
203
258
|
text_outputs.append(cast(type[str], output))
|
|
204
259
|
elif isinstance(output, TextOutput):
|
|
205
260
|
text_outputs.append(output)
|
|
206
261
|
elif isinstance(output, ToolOutput):
|
|
207
262
|
tool_outputs.append(output)
|
|
263
|
+
elif isinstance(output, NativeOutput):
|
|
264
|
+
# We can never get here because this is checked for above.
|
|
265
|
+
raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover
|
|
266
|
+
elif isinstance(output, PromptedOutput):
|
|
267
|
+
# We can never get here because this is checked for above.
|
|
268
|
+
raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover
|
|
208
269
|
else:
|
|
209
270
|
other_outputs.append(output)
|
|
210
271
|
|
|
211
|
-
|
|
272
|
+
toolset = OutputToolset.build(tool_outputs + other_outputs, name=name, description=description, strict=strict)
|
|
212
273
|
|
|
213
274
|
if len(text_outputs) > 0:
|
|
214
275
|
if len(text_outputs) > 1:
|
|
215
|
-
raise UserError('Only one
|
|
276
|
+
raise UserError('Only one `str` or `TextOutput` is allowed.')
|
|
216
277
|
text_output = text_outputs[0]
|
|
217
278
|
|
|
218
279
|
text_output_schema = None
|
|
219
280
|
if isinstance(text_output, TextOutput):
|
|
220
281
|
text_output_schema = PlainTextOutputProcessor(text_output.output_function)
|
|
221
282
|
|
|
222
|
-
if
|
|
223
|
-
return
|
|
283
|
+
if toolset:
|
|
284
|
+
return ToolOrTextOutputSchema(
|
|
285
|
+
processor=text_output_schema, toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls
|
|
286
|
+
)
|
|
224
287
|
else:
|
|
225
|
-
return
|
|
288
|
+
return PlainTextOutputSchema(
|
|
289
|
+
processor=text_output_schema, allows_deferred_tool_calls=allows_deferred_tool_calls
|
|
290
|
+
)
|
|
226
291
|
|
|
227
292
|
if len(tool_outputs) > 0:
|
|
228
|
-
return ToolOutputSchema(
|
|
293
|
+
return ToolOutputSchema(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls)
|
|
229
294
|
|
|
230
295
|
if len(other_outputs) > 0:
|
|
231
296
|
schema = OutputSchemaWithoutMode(
|
|
232
297
|
processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
|
|
233
|
-
|
|
298
|
+
toolset=toolset,
|
|
299
|
+
allows_deferred_tool_calls=allows_deferred_tool_calls,
|
|
234
300
|
)
|
|
235
301
|
if default_mode:
|
|
236
302
|
schema = schema.with_default_mode(default_mode)
|
|
237
303
|
return schema
|
|
238
304
|
|
|
239
|
-
raise UserError('
|
|
240
|
-
|
|
241
|
-
@staticmethod
|
|
242
|
-
def _build_tools(
|
|
243
|
-
outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
|
|
244
|
-
name: str | None = None,
|
|
245
|
-
description: str | None = None,
|
|
246
|
-
strict: bool | None = None,
|
|
247
|
-
) -> dict[str, OutputTool[OutputDataT]]:
|
|
248
|
-
tools: dict[str, OutputTool[OutputDataT]] = {}
|
|
249
|
-
|
|
250
|
-
default_name = name or DEFAULT_OUTPUT_TOOL_NAME
|
|
251
|
-
default_description = description
|
|
252
|
-
default_strict = strict
|
|
253
|
-
|
|
254
|
-
multiple = len(outputs) > 1
|
|
255
|
-
for output in outputs:
|
|
256
|
-
name = None
|
|
257
|
-
description = None
|
|
258
|
-
strict = None
|
|
259
|
-
if isinstance(output, ToolOutput):
|
|
260
|
-
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
|
|
261
|
-
name = output.name
|
|
262
|
-
description = output.description
|
|
263
|
-
strict = output.strict
|
|
264
|
-
|
|
265
|
-
output = output.output
|
|
266
|
-
|
|
267
|
-
description = description or default_description
|
|
268
|
-
if strict is None:
|
|
269
|
-
strict = default_strict
|
|
270
|
-
|
|
271
|
-
processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
|
|
272
|
-
|
|
273
|
-
if name is None:
|
|
274
|
-
name = default_name
|
|
275
|
-
if multiple:
|
|
276
|
-
name += f'_{processor.object_def.name}'
|
|
277
|
-
|
|
278
|
-
i = 1
|
|
279
|
-
original_name = name
|
|
280
|
-
while name in tools:
|
|
281
|
-
i += 1
|
|
282
|
-
name = f'{original_name}_{i}'
|
|
283
|
-
|
|
284
|
-
tools[name] = OutputTool(name=name, processor=processor, multiple=multiple)
|
|
285
|
-
|
|
286
|
-
return tools
|
|
305
|
+
raise UserError('At least one output type must be provided.')
|
|
287
306
|
|
|
288
307
|
@staticmethod
|
|
289
308
|
def _build_processor(
|
|
@@ -315,32 +334,39 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
315
334
|
@dataclass(init=False)
|
|
316
335
|
class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
|
|
317
336
|
processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]
|
|
318
|
-
|
|
337
|
+
_toolset: OutputToolset[Any] | None
|
|
319
338
|
|
|
320
339
|
def __init__(
|
|
321
340
|
self,
|
|
322
341
|
processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT],
|
|
323
|
-
|
|
342
|
+
toolset: OutputToolset[Any] | None,
|
|
343
|
+
allows_deferred_tool_calls: bool,
|
|
324
344
|
):
|
|
345
|
+
super().__init__(allows_deferred_tool_calls)
|
|
325
346
|
self.processor = processor
|
|
326
|
-
self.
|
|
347
|
+
self._toolset = toolset
|
|
327
348
|
|
|
328
349
|
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
329
350
|
if mode == 'native':
|
|
330
|
-
return NativeOutputSchema(
|
|
351
|
+
return NativeOutputSchema(
|
|
352
|
+
processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
|
|
353
|
+
)
|
|
331
354
|
elif mode == 'prompted':
|
|
332
|
-
return PromptedOutputSchema(
|
|
355
|
+
return PromptedOutputSchema(
|
|
356
|
+
processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
|
|
357
|
+
)
|
|
333
358
|
elif mode == 'tool':
|
|
334
|
-
return ToolOutputSchema(self.
|
|
359
|
+
return ToolOutputSchema(toolset=self.toolset, allows_deferred_tool_calls=self.allows_deferred_tool_calls)
|
|
335
360
|
else:
|
|
336
361
|
assert_never(mode)
|
|
337
362
|
|
|
338
363
|
@property
|
|
339
|
-
def
|
|
340
|
-
"""Get the
|
|
341
|
-
# We return
|
|
342
|
-
# At that point we may
|
|
343
|
-
|
|
364
|
+
def toolset(self) -> OutputToolset[Any] | None:
|
|
365
|
+
"""Get the toolset for this output schema."""
|
|
366
|
+
# We return a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor.
|
|
367
|
+
# At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time,
|
|
368
|
+
# but we cover ourselves just in case we end up using the tool output mode.
|
|
369
|
+
return self._toolset
|
|
344
370
|
|
|
345
371
|
|
|
346
372
|
class TextOutputSchema(OutputSchema[OutputDataT], ABC):
|
|
@@ -411,7 +437,7 @@ class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
|
411
437
|
def raise_if_unsupported(self, profile: ModelProfile) -> None:
|
|
412
438
|
"""Raise an error if the mode is not supported by the model."""
|
|
413
439
|
if not profile.supports_json_schema_output:
|
|
414
|
-
raise UserError('
|
|
440
|
+
raise UserError('Native structured output is not supported by the model.')
|
|
415
441
|
|
|
416
442
|
async def process(
|
|
417
443
|
self,
|
|
@@ -491,10 +517,11 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
|
491
517
|
|
|
492
518
|
@dataclass(init=False)
|
|
493
519
|
class ToolOutputSchema(OutputSchema[OutputDataT]):
|
|
494
|
-
|
|
520
|
+
_toolset: OutputToolset[Any] | None
|
|
495
521
|
|
|
496
|
-
def __init__(self,
|
|
497
|
-
|
|
522
|
+
def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_calls: bool):
|
|
523
|
+
super().__init__(allows_deferred_tool_calls)
|
|
524
|
+
self._toolset = toolset
|
|
498
525
|
|
|
499
526
|
@property
|
|
500
527
|
def mode(self) -> OutputMode:
|
|
@@ -506,36 +533,9 @@ class ToolOutputSchema(OutputSchema[OutputDataT]):
|
|
|
506
533
|
raise UserError('Output tools are not supported by the model.')
|
|
507
534
|
|
|
508
535
|
@property
|
|
509
|
-
def
|
|
510
|
-
"""Get the
|
|
511
|
-
return self.
|
|
512
|
-
|
|
513
|
-
def tool_names(self) -> list[str]:
|
|
514
|
-
"""Return the names of the tools."""
|
|
515
|
-
return list(self.tools.keys())
|
|
516
|
-
|
|
517
|
-
def tool_defs(self) -> list[ToolDefinition]:
|
|
518
|
-
"""Get tool definitions to register with the model."""
|
|
519
|
-
return [t.tool_def for t in self.tools.values()]
|
|
520
|
-
|
|
521
|
-
def find_named_tool(
|
|
522
|
-
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
|
|
523
|
-
) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None:
|
|
524
|
-
"""Find a tool that matches one of the calls, with a specific name."""
|
|
525
|
-
for part in parts: # pragma: no branch
|
|
526
|
-
if isinstance(part, _messages.ToolCallPart): # pragma: no branch
|
|
527
|
-
if part.tool_name == tool_name:
|
|
528
|
-
return part, self.tools[tool_name]
|
|
529
|
-
|
|
530
|
-
def find_tool(
|
|
531
|
-
self,
|
|
532
|
-
parts: Iterable[_messages.ModelResponsePart],
|
|
533
|
-
) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]:
|
|
534
|
-
"""Find a tool that matches one of the calls."""
|
|
535
|
-
for part in parts:
|
|
536
|
-
if isinstance(part, _messages.ToolCallPart): # pragma: no branch
|
|
537
|
-
if result := self.tools.get(part.tool_name):
|
|
538
|
-
yield part, result
|
|
536
|
+
def toolset(self) -> OutputToolset[Any] | None:
|
|
537
|
+
"""Get the toolset for this output schema."""
|
|
538
|
+
return self._toolset
|
|
539
539
|
|
|
540
540
|
|
|
541
541
|
@dataclass(init=False)
|
|
@@ -543,10 +543,11 @@ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchem
|
|
|
543
543
|
def __init__(
|
|
544
544
|
self,
|
|
545
545
|
processor: PlainTextOutputProcessor[OutputDataT] | None,
|
|
546
|
-
|
|
546
|
+
toolset: OutputToolset[Any] | None,
|
|
547
|
+
allows_deferred_tool_calls: bool,
|
|
547
548
|
):
|
|
549
|
+
super().__init__(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls)
|
|
548
550
|
self.processor = processor
|
|
549
|
-
self._tools = tools
|
|
550
551
|
|
|
551
552
|
@property
|
|
552
553
|
def mode(self) -> OutputMode:
|
|
@@ -579,7 +580,7 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
|
|
|
579
580
|
class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
580
581
|
object_def: OutputObjectDefinition
|
|
581
582
|
outer_typed_dict_key: str | None = None
|
|
582
|
-
|
|
583
|
+
validator: SchemaValidator
|
|
583
584
|
_function_schema: _function_schema.FunctionSchema | None = None
|
|
584
585
|
|
|
585
586
|
def __init__(
|
|
@@ -592,7 +593,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
592
593
|
):
|
|
593
594
|
if inspect.isfunction(output) or inspect.ismethod(output):
|
|
594
595
|
self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema)
|
|
595
|
-
self.
|
|
596
|
+
self.validator = self._function_schema.validator
|
|
596
597
|
json_schema = self._function_schema.json_schema
|
|
597
598
|
json_schema['description'] = self._function_schema.description
|
|
598
599
|
else:
|
|
@@ -608,7 +609,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
608
609
|
type_adapter = TypeAdapter(response_data_typed_dict)
|
|
609
610
|
|
|
610
611
|
# Really a PluggableSchemaValidator, but it's API-compatible
|
|
611
|
-
self.
|
|
612
|
+
self.validator = cast(SchemaValidator, type_adapter.validator)
|
|
612
613
|
json_schema = _utils.check_object_json_schema(
|
|
613
614
|
type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
|
|
614
615
|
)
|
|
@@ -652,11 +653,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
652
653
|
Either the validated output data (left) or a retry message (right).
|
|
653
654
|
"""
|
|
654
655
|
try:
|
|
655
|
-
|
|
656
|
-
if isinstance(data, str):
|
|
657
|
-
output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
|
|
658
|
-
else:
|
|
659
|
-
output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
|
|
656
|
+
output = self.validate(data, allow_partial)
|
|
660
657
|
except ValidationError as e:
|
|
661
658
|
if wrap_validation_errors:
|
|
662
659
|
m = _messages.RetryPromptPart(
|
|
@@ -664,22 +661,42 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
664
661
|
)
|
|
665
662
|
raise ToolRetryError(m) from e
|
|
666
663
|
else:
|
|
667
|
-
raise
|
|
664
|
+
raise
|
|
665
|
+
|
|
666
|
+
try:
|
|
667
|
+
output = await self.call(output, run_context)
|
|
668
|
+
except ModelRetry as r:
|
|
669
|
+
if wrap_validation_errors:
|
|
670
|
+
m = _messages.RetryPromptPart(
|
|
671
|
+
content=r.message,
|
|
672
|
+
)
|
|
673
|
+
raise ToolRetryError(m) from r
|
|
674
|
+
else:
|
|
675
|
+
raise # pragma: no cover
|
|
676
|
+
|
|
677
|
+
return output
|
|
678
|
+
|
|
679
|
+
def validate(
|
|
680
|
+
self,
|
|
681
|
+
data: str | dict[str, Any] | None,
|
|
682
|
+
allow_partial: bool = False,
|
|
683
|
+
) -> dict[str, Any]:
|
|
684
|
+
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
|
|
685
|
+
if isinstance(data, str):
|
|
686
|
+
return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
|
|
687
|
+
else:
|
|
688
|
+
return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
|
|
668
689
|
|
|
690
|
+
async def call(
|
|
691
|
+
self,
|
|
692
|
+
output: Any,
|
|
693
|
+
run_context: RunContext[AgentDepsT],
|
|
694
|
+
):
|
|
669
695
|
if k := self.outer_typed_dict_key:
|
|
670
696
|
output = output[k]
|
|
671
697
|
|
|
672
698
|
if self._function_schema:
|
|
673
|
-
|
|
674
|
-
output = await self._function_schema.call(output, run_context)
|
|
675
|
-
except ModelRetry as r:
|
|
676
|
-
if wrap_validation_errors:
|
|
677
|
-
m = _messages.RetryPromptPart(
|
|
678
|
-
content=r.message,
|
|
679
|
-
)
|
|
680
|
-
raise ToolRetryError(m) from r
|
|
681
|
-
else:
|
|
682
|
-
raise # pragma: lax no cover
|
|
699
|
+
output = await execute_output_function_with_span(self._function_schema, run_context, output)
|
|
683
700
|
|
|
684
701
|
return output
|
|
685
702
|
|
|
@@ -839,9 +856,8 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
839
856
|
wrap_validation_errors: bool = True,
|
|
840
857
|
) -> OutputDataT:
|
|
841
858
|
args = {self._str_argument_name: data}
|
|
842
|
-
|
|
843
859
|
try:
|
|
844
|
-
output = await self._function_schema
|
|
860
|
+
output = await execute_output_function_with_span(self._function_schema, run_context, args)
|
|
845
861
|
except ModelRetry as r:
|
|
846
862
|
if wrap_validation_errors:
|
|
847
863
|
m = _messages.RetryPromptPart(
|
|
@@ -849,91 +865,145 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
|
|
|
849
865
|
)
|
|
850
866
|
raise ToolRetryError(m) from r
|
|
851
867
|
else:
|
|
852
|
-
raise # pragma:
|
|
868
|
+
raise # pragma: no cover
|
|
853
869
|
|
|
854
870
|
return cast(OutputDataT, output)
|
|
855
871
|
|
|
856
872
|
|
|
857
873
|
@dataclass(init=False)
|
|
858
|
-
class
|
|
859
|
-
|
|
860
|
-
tool_def: ToolDefinition
|
|
874
|
+
class OutputToolset(AbstractToolset[AgentDepsT]):
|
|
875
|
+
"""A toolset that contains contains output tools for agent output types."""
|
|
861
876
|
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
877
|
+
_tool_defs: list[ToolDefinition]
|
|
878
|
+
"""The tool definitions for the output tools in this toolset."""
|
|
879
|
+
processors: dict[str, ObjectOutputProcessor[Any]]
|
|
880
|
+
"""The processors for the output tools in this toolset."""
|
|
881
|
+
max_retries: int
|
|
882
|
+
output_validators: list[OutputValidator[AgentDepsT, Any]]
|
|
865
883
|
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
884
|
+
@classmethod
|
|
885
|
+
def build(
|
|
886
|
+
cls,
|
|
887
|
+
outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
|
|
888
|
+
name: str | None = None,
|
|
889
|
+
description: str | None = None,
|
|
890
|
+
strict: bool | None = None,
|
|
891
|
+
) -> Self | None:
|
|
892
|
+
if len(outputs) == 0:
|
|
893
|
+
return None
|
|
871
894
|
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
description=description,
|
|
875
|
-
parameters_json_schema=object_def.json_schema,
|
|
876
|
-
strict=object_def.strict,
|
|
877
|
-
outer_typed_dict_key=processor.outer_typed_dict_key,
|
|
878
|
-
)
|
|
895
|
+
processors: dict[str, ObjectOutputProcessor[Any]] = {}
|
|
896
|
+
tool_defs: list[ToolDefinition] = []
|
|
879
897
|
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
run_context: RunContext[AgentDepsT],
|
|
884
|
-
allow_partial: bool = False,
|
|
885
|
-
wrap_validation_errors: bool = True,
|
|
886
|
-
) -> OutputDataT:
|
|
887
|
-
"""Process an output message.
|
|
898
|
+
default_name = name or DEFAULT_OUTPUT_TOOL_NAME
|
|
899
|
+
default_description = description
|
|
900
|
+
default_strict = strict
|
|
888
901
|
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
902
|
+
multiple = len(outputs) > 1
|
|
903
|
+
for output in outputs:
|
|
904
|
+
name = None
|
|
905
|
+
description = None
|
|
906
|
+
strict = None
|
|
907
|
+
if isinstance(output, ToolOutput):
|
|
908
|
+
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
|
|
909
|
+
name = output.name
|
|
910
|
+
description = output.description
|
|
911
|
+
strict = output.strict
|
|
894
912
|
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
913
|
+
output = output.output
|
|
914
|
+
|
|
915
|
+
description = description or default_description
|
|
916
|
+
if strict is None:
|
|
917
|
+
strict = default_strict
|
|
918
|
+
|
|
919
|
+
processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
|
|
920
|
+
object_def = processor.object_def
|
|
921
|
+
|
|
922
|
+
if name is None:
|
|
923
|
+
name = default_name
|
|
924
|
+
if multiple:
|
|
925
|
+
name += f'_{object_def.name}'
|
|
926
|
+
|
|
927
|
+
i = 1
|
|
928
|
+
original_name = name
|
|
929
|
+
while name in processors:
|
|
930
|
+
i += 1
|
|
931
|
+
name = f'{original_name}_{i}'
|
|
932
|
+
|
|
933
|
+
description = object_def.description
|
|
934
|
+
if not description:
|
|
935
|
+
description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
|
|
936
|
+
if multiple:
|
|
937
|
+
description = f'{object_def.name}: {description}'
|
|
938
|
+
|
|
939
|
+
tool_def = ToolDefinition(
|
|
940
|
+
name=name,
|
|
941
|
+
description=description,
|
|
942
|
+
parameters_json_schema=object_def.json_schema,
|
|
943
|
+
strict=object_def.strict,
|
|
944
|
+
outer_typed_dict_key=processor.outer_typed_dict_key,
|
|
945
|
+
kind='output',
|
|
901
946
|
)
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
947
|
+
processors[name] = processor
|
|
948
|
+
tool_defs.append(tool_def)
|
|
949
|
+
|
|
950
|
+
return cls(processors=processors, tool_defs=tool_defs)
|
|
951
|
+
|
|
952
|
+
def __init__(
|
|
953
|
+
self,
|
|
954
|
+
tool_defs: list[ToolDefinition],
|
|
955
|
+
processors: dict[str, ObjectOutputProcessor[Any]],
|
|
956
|
+
max_retries: int = 1,
|
|
957
|
+
output_validators: list[OutputValidator[AgentDepsT, Any]] | None = None,
|
|
958
|
+
):
|
|
959
|
+
self.processors = processors
|
|
960
|
+
self._tool_defs = tool_defs
|
|
961
|
+
self.max_retries = max_retries
|
|
962
|
+
self.output_validators = output_validators or []
|
|
963
|
+
|
|
964
|
+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
|
|
965
|
+
return {
|
|
966
|
+
tool_def.name: ToolsetTool(
|
|
967
|
+
toolset=self,
|
|
968
|
+
tool_def=tool_def,
|
|
969
|
+
max_retries=self.max_retries,
|
|
970
|
+
args_validator=self.processors[tool_def.name].validator,
|
|
971
|
+
)
|
|
972
|
+
for tool_def in self._tool_defs
|
|
973
|
+
}
|
|
974
|
+
|
|
975
|
+
async def call_tool(
|
|
976
|
+
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
|
|
977
|
+
) -> Any:
|
|
978
|
+
output = await self.processors[name].call(tool_args, ctx)
|
|
979
|
+
for validator in self.output_validators:
|
|
980
|
+
output = await validator.validate(output, ctx, wrap_validation_errors=False)
|
|
981
|
+
return output
|
|
982
|
+
|
|
983
|
+
|
|
984
|
+
@overload
|
|
985
|
+
def _flatten_output_spec(
|
|
986
|
+
output_spec: OutputTypeOrFunction[T] | Sequence[OutputTypeOrFunction[T]],
|
|
987
|
+
) -> Sequence[OutputTypeOrFunction[T]]: ...
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
@overload
|
|
991
|
+
def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: ...
|
|
924
992
|
|
|
925
993
|
|
|
926
|
-
def _flatten_output_spec(output_spec:
|
|
927
|
-
outputs: Sequence[T]
|
|
994
|
+
def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]:
|
|
995
|
+
outputs: Sequence[OutputSpec[T]]
|
|
928
996
|
if isinstance(output_spec, Sequence):
|
|
929
997
|
outputs = output_spec
|
|
930
998
|
else:
|
|
931
999
|
outputs = (output_spec,)
|
|
932
1000
|
|
|
933
|
-
outputs_flat: list[T] = []
|
|
1001
|
+
outputs_flat: list[_OutputSpecItem[T]] = []
|
|
934
1002
|
for output in outputs:
|
|
935
|
-
if
|
|
1003
|
+
if isinstance(output, Sequence):
|
|
1004
|
+
outputs_flat.extend(_flatten_output_spec(cast(OutputSpec[T], output)))
|
|
1005
|
+
elif union_types := _utils.get_union_args(output):
|
|
936
1006
|
outputs_flat.extend(union_types)
|
|
937
1007
|
else:
|
|
938
|
-
outputs_flat.append(output)
|
|
1008
|
+
outputs_flat.append(cast(_OutputSpecItem[T], output))
|
|
939
1009
|
return outputs_flat
|