pydantic-ai-slim 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic; see the registry's advisory page for this release for more details.
- pydantic_ai/_agent_graph.py +29 -35
- pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
- pydantic_ai/_output.py +266 -119
- pydantic_ai/agent.py +15 -15
- pydantic_ai/mcp.py +1 -1
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +39 -3
- pydantic_ai/models/anthropic.py +4 -0
- pydantic_ai/models/bedrock.py +43 -16
- pydantic_ai/models/cohere.py +4 -0
- pydantic_ai/models/gemini.py +78 -109
- pydantic_ai/models/google.py +47 -112
- pydantic_ai/models/groq.py +17 -2
- pydantic_ai/models/mistral.py +4 -0
- pydantic_ai/models/openai.py +25 -158
- pydantic_ai/profiles/__init__.py +39 -0
- pydantic_ai/{models → profiles}/_json_schema.py +23 -2
- pydantic_ai/profiles/amazon.py +9 -0
- pydantic_ai/profiles/anthropic.py +8 -0
- pydantic_ai/profiles/cohere.py +8 -0
- pydantic_ai/profiles/deepseek.py +8 -0
- pydantic_ai/profiles/google.py +100 -0
- pydantic_ai/profiles/grok.py +8 -0
- pydantic_ai/profiles/meta.py +9 -0
- pydantic_ai/profiles/mistral.py +8 -0
- pydantic_ai/profiles/openai.py +144 -0
- pydantic_ai/profiles/qwen.py +9 -0
- pydantic_ai/providers/__init__.py +18 -0
- pydantic_ai/providers/anthropic.py +5 -0
- pydantic_ai/providers/azure.py +34 -0
- pydantic_ai/providers/bedrock.py +60 -1
- pydantic_ai/providers/cohere.py +5 -0
- pydantic_ai/providers/deepseek.py +12 -0
- pydantic_ai/providers/fireworks.py +99 -0
- pydantic_ai/providers/google.py +5 -0
- pydantic_ai/providers/google_gla.py +5 -0
- pydantic_ai/providers/google_vertex.py +5 -0
- pydantic_ai/providers/grok.py +82 -0
- pydantic_ai/providers/groq.py +25 -0
- pydantic_ai/providers/mistral.py +5 -0
- pydantic_ai/providers/openai.py +5 -0
- pydantic_ai/providers/openrouter.py +36 -0
- pydantic_ai/providers/together.py +96 -0
- pydantic_ai/result.py +34 -103
- pydantic_ai/tools.py +29 -59
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/METADATA +4 -4
- pydantic_ai_slim-0.2.13.dist-info/RECORD +73 -0
- pydantic_ai_slim-0.2.11.dist-info/RECORD +0 -59
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_output.py
CHANGED
|
@@ -1,22 +1,55 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
-
from collections.abc import Awaitable, Iterable, Iterator
|
|
4
|
+
from collections.abc import Awaitable, Iterable, Iterator, Sequence
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
6
|
from typing import Any, Callable, Generic, Literal, Union, cast
|
|
7
7
|
|
|
8
8
|
from pydantic import TypeAdapter, ValidationError
|
|
9
|
-
from
|
|
9
|
+
from pydantic_core import SchemaValidator
|
|
10
|
+
from typing_extensions import TypeAliasType, TypedDict, TypeVar, get_args, get_origin
|
|
10
11
|
from typing_inspection import typing_objects
|
|
11
12
|
from typing_inspection.introspection import is_union_origin
|
|
12
13
|
|
|
13
|
-
from . import _utils, messages as _messages
|
|
14
|
+
from . import _function_schema, _utils, messages as _messages
|
|
14
15
|
from .exceptions import ModelRetry
|
|
15
|
-
from .
|
|
16
|
-
from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext, ToolDefinition
|
|
16
|
+
from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition
|
|
17
17
|
|
|
18
18
|
T = TypeVar('T')
|
|
19
19
|
"""An invariant TypeVar."""
|
|
20
|
+
OutputDataT_inv = TypeVar('OutputDataT_inv', default=str)
|
|
21
|
+
"""
|
|
22
|
+
An invariant type variable for the result data of a model.
|
|
23
|
+
|
|
24
|
+
We need to use an invariant typevar for `OutputValidator` and `OutputValidatorFunc` because the output data type is used
|
|
25
|
+
in both the input and output of a `OutputValidatorFunc`. This can theoretically lead to some issues assuming that types
|
|
26
|
+
possessing OutputValidator's are covariant in the result data type, but in practice this is rarely an issue, and
|
|
27
|
+
changing it would have negative consequences for the ergonomics of the library.
|
|
28
|
+
|
|
29
|
+
At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would
|
|
30
|
+
resolve these potential variance issues.
|
|
31
|
+
"""
|
|
32
|
+
OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
|
|
33
|
+
"""Covariant type variable for the result data type of a run."""
|
|
34
|
+
|
|
35
|
+
OutputValidatorFunc = Union[
|
|
36
|
+
Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv],
|
|
37
|
+
Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]],
|
|
38
|
+
Callable[[OutputDataT_inv], OutputDataT_inv],
|
|
39
|
+
Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]],
|
|
40
|
+
]
|
|
41
|
+
"""
|
|
42
|
+
A function that always takes and returns the same type of data (which is the result type of an agent run), and:
|
|
43
|
+
|
|
44
|
+
* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
|
|
45
|
+
* may or may not be async
|
|
46
|
+
|
|
47
|
+
Usage `OutputValidatorFunc[AgentDepsT, T]`.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
|
|
52
|
+
DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
|
|
20
53
|
|
|
21
54
|
|
|
22
55
|
@dataclass
|
|
@@ -76,69 +109,135 @@ class ToolRetryError(Exception):
|
|
|
76
109
|
super().__init__()
|
|
77
110
|
|
|
78
111
|
|
|
112
|
+
@dataclass(init=False)
|
|
113
|
+
class ToolOutput(Generic[OutputDataT]):
|
|
114
|
+
"""Marker class to use tools for outputs, and customize the tool."""
|
|
115
|
+
|
|
116
|
+
output_type: SimpleOutputType[OutputDataT]
|
|
117
|
+
name: str | None
|
|
118
|
+
description: str | None
|
|
119
|
+
max_retries: int | None
|
|
120
|
+
strict: bool | None
|
|
121
|
+
|
|
122
|
+
def __init__(
|
|
123
|
+
self,
|
|
124
|
+
type_: SimpleOutputType[OutputDataT],
|
|
125
|
+
*,
|
|
126
|
+
name: str | None = None,
|
|
127
|
+
description: str | None = None,
|
|
128
|
+
max_retries: int | None = None,
|
|
129
|
+
strict: bool | None = None,
|
|
130
|
+
):
|
|
131
|
+
self.output_type = type_
|
|
132
|
+
self.name = name
|
|
133
|
+
self.description = description
|
|
134
|
+
self.max_retries = max_retries
|
|
135
|
+
self.strict = strict
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
T_co = TypeVar('T_co', covariant=True)
|
|
139
|
+
# output_type=Type or output_type=function or output_type=object.method
|
|
140
|
+
SimpleOutputType = TypeAliasType(
|
|
141
|
+
'SimpleOutputType', Union[type[T_co], Callable[..., T_co], Callable[..., Awaitable[T_co]]], type_params=(T_co,)
|
|
142
|
+
)
|
|
143
|
+
# output_type=ToolOutput(<see above>) or <see above>
|
|
144
|
+
SimpleOutputTypeOrMarker = TypeAliasType(
|
|
145
|
+
'SimpleOutputTypeOrMarker', Union[SimpleOutputType[T_co], ToolOutput[T_co]], type_params=(T_co,)
|
|
146
|
+
)
|
|
147
|
+
# output_type=<see above> or [<see above>, ...]
|
|
148
|
+
OutputType = TypeAliasType(
|
|
149
|
+
'OutputType', Union[SimpleOutputTypeOrMarker[T_co], Sequence[SimpleOutputTypeOrMarker[T_co]]], type_params=(T_co,)
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
|
|
79
153
|
@dataclass
|
|
80
154
|
class OutputSchema(Generic[OutputDataT]):
|
|
81
|
-
"""Model the final
|
|
155
|
+
"""Model the final output from an agent run.
|
|
82
156
|
|
|
83
157
|
Similar to `Tool` but for the final output of running an agent.
|
|
84
158
|
"""
|
|
85
159
|
|
|
86
|
-
tools: dict[str,
|
|
160
|
+
tools: dict[str, OutputTool[OutputDataT]]
|
|
87
161
|
allow_text_output: bool
|
|
88
162
|
|
|
89
163
|
@classmethod
|
|
90
164
|
def build(
|
|
91
|
-
cls: type[OutputSchema[
|
|
92
|
-
output_type:
|
|
165
|
+
cls: type[OutputSchema[OutputDataT]],
|
|
166
|
+
output_type: OutputType[OutputDataT],
|
|
93
167
|
name: str | None = None,
|
|
94
168
|
description: str | None = None,
|
|
95
169
|
strict: bool | None = None,
|
|
96
|
-
) -> OutputSchema[
|
|
97
|
-
"""Build an OutputSchema dataclass from
|
|
170
|
+
) -> OutputSchema[OutputDataT] | None:
|
|
171
|
+
"""Build an OutputSchema dataclass from an output type."""
|
|
98
172
|
if output_type is str:
|
|
99
173
|
return None
|
|
100
174
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
description = output_type.description
|
|
105
|
-
output_type_ = output_type.output_type
|
|
106
|
-
strict = output_type.strict
|
|
175
|
+
output_types: Sequence[SimpleOutputTypeOrMarker[OutputDataT]]
|
|
176
|
+
if isinstance(output_type, Sequence):
|
|
177
|
+
output_types = output_type
|
|
107
178
|
else:
|
|
108
|
-
|
|
179
|
+
output_types = (output_type,)
|
|
109
180
|
|
|
110
|
-
|
|
111
|
-
|
|
181
|
+
output_types_flat: list[SimpleOutputTypeOrMarker[OutputDataT]] = []
|
|
182
|
+
for output_type in output_types:
|
|
183
|
+
if union_types := get_union_args(output_type):
|
|
184
|
+
output_types_flat.extend(union_types)
|
|
185
|
+
else:
|
|
186
|
+
output_types_flat.append(output_type)
|
|
187
|
+
|
|
188
|
+
allow_text_output = False
|
|
189
|
+
if str in output_types_flat:
|
|
112
190
|
allow_text_output = True
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
191
|
+
output_types_flat = [t for t in output_types_flat if t is not str]
|
|
192
|
+
|
|
193
|
+
multiple = len(output_types_flat) > 1
|
|
194
|
+
|
|
195
|
+
default_tool_name = name or DEFAULT_OUTPUT_TOOL_NAME
|
|
196
|
+
default_tool_description = description
|
|
197
|
+
default_tool_strict = strict
|
|
198
|
+
|
|
199
|
+
tools: dict[str, OutputTool[OutputDataT]] = {}
|
|
200
|
+
for output_type in output_types_flat:
|
|
201
|
+
tool_name = None
|
|
202
|
+
tool_description = None
|
|
203
|
+
tool_strict = None
|
|
204
|
+
if isinstance(output_type, ToolOutput):
|
|
205
|
+
tool_output_type = output_type.output_type
|
|
206
|
+
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
|
|
207
|
+
tool_name = output_type.name
|
|
208
|
+
tool_description = output_type.description
|
|
209
|
+
tool_strict = output_type.strict
|
|
210
|
+
else:
|
|
211
|
+
tool_output_type = output_type
|
|
212
|
+
|
|
213
|
+
if tool_name is None:
|
|
214
|
+
tool_name = default_tool_name
|
|
215
|
+
if multiple:
|
|
216
|
+
tool_name += f'_{tool_output_type.__name__}'
|
|
217
|
+
|
|
218
|
+
i = 1
|
|
219
|
+
original_tool_name = tool_name
|
|
220
|
+
while tool_name in tools:
|
|
221
|
+
i += 1
|
|
222
|
+
tool_name = f'{original_tool_name}_{i}'
|
|
223
|
+
|
|
224
|
+
tool_description = tool_description or default_tool_description
|
|
225
|
+
if tool_strict is None:
|
|
226
|
+
tool_strict = default_tool_strict
|
|
227
|
+
|
|
228
|
+
parameters_schema = OutputObjectSchema(
|
|
229
|
+
output_type=tool_output_type, description=tool_description, strict=tool_strict
|
|
135
230
|
)
|
|
231
|
+
tools[tool_name] = OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=multiple)
|
|
136
232
|
|
|
137
|
-
return cls(
|
|
233
|
+
return cls(
|
|
234
|
+
tools=tools,
|
|
235
|
+
allow_text_output=allow_text_output,
|
|
236
|
+
)
|
|
138
237
|
|
|
139
238
|
def find_named_tool(
|
|
140
239
|
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
|
|
141
|
-
) -> tuple[_messages.ToolCallPart,
|
|
240
|
+
) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None:
|
|
142
241
|
"""Find a tool that matches one of the calls, with a specific name."""
|
|
143
242
|
for part in parts: # pragma: no branch
|
|
144
243
|
if isinstance(part, _messages.ToolCallPart): # pragma: no branch
|
|
@@ -148,7 +247,7 @@ class OutputSchema(Generic[OutputDataT]):
|
|
|
148
247
|
def find_tool(
|
|
149
248
|
self,
|
|
150
249
|
parts: Iterable[_messages.ModelResponsePart],
|
|
151
|
-
) -> Iterator[tuple[_messages.ToolCallPart,
|
|
250
|
+
) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]:
|
|
152
251
|
"""Find a tool that matches one of the calls."""
|
|
153
252
|
for part in parts:
|
|
154
253
|
if isinstance(part, _messages.ToolCallPart): # pragma: no branch
|
|
@@ -164,64 +263,138 @@ class OutputSchema(Generic[OutputDataT]):
|
|
|
164
263
|
return [t.tool_def for t in self.tools.values()]
|
|
165
264
|
|
|
166
265
|
|
|
167
|
-
|
|
266
|
+
def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool:
|
|
267
|
+
return output_schema is None or output_schema.allow_text_output
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
@dataclass
|
|
271
|
+
class OutputObjectDefinition:
|
|
272
|
+
name: str
|
|
273
|
+
json_schema: ObjectJsonSchema
|
|
274
|
+
description: str | None = None
|
|
275
|
+
strict: bool | None = None
|
|
168
276
|
|
|
169
277
|
|
|
170
278
|
@dataclass(init=False)
|
|
171
|
-
class
|
|
172
|
-
|
|
173
|
-
|
|
279
|
+
class OutputObjectSchema(Generic[OutputDataT]):
|
|
280
|
+
definition: OutputObjectDefinition
|
|
281
|
+
validator: SchemaValidator
|
|
282
|
+
function_schema: _function_schema.FunctionSchema | None = None
|
|
283
|
+
outer_typed_dict_key: str | None = None
|
|
174
284
|
|
|
175
285
|
def __init__(
|
|
176
|
-
self,
|
|
286
|
+
self,
|
|
287
|
+
*,
|
|
288
|
+
output_type: SimpleOutputType[OutputDataT],
|
|
289
|
+
name: str | None = None,
|
|
290
|
+
description: str | None = None,
|
|
291
|
+
strict: bool | None = None,
|
|
177
292
|
):
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
self.
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
parameters_json_schema = _utils.check_object_json_schema(
|
|
184
|
-
self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
|
|
185
|
-
)
|
|
293
|
+
if inspect.isfunction(output_type) or inspect.ismethod(output_type):
|
|
294
|
+
self.function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema)
|
|
295
|
+
self.validator = self.function_schema.validator
|
|
296
|
+
json_schema = self.function_schema.json_schema
|
|
297
|
+
json_schema['description'] = self.function_schema.description
|
|
186
298
|
else:
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
299
|
+
type_adapter: TypeAdapter[Any]
|
|
300
|
+
if _utils.is_model_like(output_type):
|
|
301
|
+
type_adapter = TypeAdapter(output_type)
|
|
302
|
+
else:
|
|
303
|
+
self.outer_typed_dict_key = 'response'
|
|
304
|
+
response_data_typed_dict = TypedDict( # noqa: UP013
|
|
305
|
+
'response_data_typed_dict',
|
|
306
|
+
{'response': cast(type[OutputDataT], output_type)}, # pyright: ignore[reportInvalidTypeForm]
|
|
307
|
+
)
|
|
308
|
+
type_adapter = TypeAdapter(response_data_typed_dict)
|
|
309
|
+
|
|
310
|
+
# Really a PluggableSchemaValidator, but it's API-compatible
|
|
311
|
+
self.validator = cast(SchemaValidator, type_adapter.validator)
|
|
312
|
+
json_schema = _utils.check_object_json_schema(
|
|
313
|
+
type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
|
|
196
314
|
)
|
|
197
|
-
# including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
|
|
198
|
-
parameters_json_schema.pop('title')
|
|
199
315
|
|
|
200
|
-
|
|
316
|
+
if self.outer_typed_dict_key:
|
|
317
|
+
# including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
|
|
318
|
+
json_schema.pop('title')
|
|
319
|
+
|
|
320
|
+
if json_schema_description := json_schema.pop('description', None):
|
|
201
321
|
if description is None:
|
|
202
|
-
|
|
322
|
+
description = json_schema_description
|
|
203
323
|
else:
|
|
204
|
-
|
|
324
|
+
description = f'{description}. {json_schema_description}'
|
|
325
|
+
|
|
326
|
+
self.definition = OutputObjectDefinition(
|
|
327
|
+
name=name or getattr(output_type, '__name__', DEFAULT_OUTPUT_TOOL_NAME),
|
|
328
|
+
description=description,
|
|
329
|
+
json_schema=json_schema,
|
|
330
|
+
strict=strict,
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
async def process(
|
|
334
|
+
self,
|
|
335
|
+
data: str | dict[str, Any] | None,
|
|
336
|
+
run_context: RunContext[AgentDepsT],
|
|
337
|
+
allow_partial: bool = False,
|
|
338
|
+
) -> OutputDataT:
|
|
339
|
+
"""Process an output message, performing validation and (if necessary) calling the output function.
|
|
340
|
+
|
|
341
|
+
Args:
|
|
342
|
+
data: The output data to validate.
|
|
343
|
+
run_context: The current run context.
|
|
344
|
+
allow_partial: If true, allow partial validation.
|
|
345
|
+
|
|
346
|
+
Returns:
|
|
347
|
+
Either the validated output data (left) or a retry message (right).
|
|
348
|
+
"""
|
|
349
|
+
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
|
|
350
|
+
if isinstance(data, str):
|
|
351
|
+
output = self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
|
|
205
352
|
else:
|
|
206
|
-
|
|
353
|
+
output = self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
|
|
354
|
+
|
|
355
|
+
if self.function_schema:
|
|
356
|
+
output = await self.function_schema.call(output, run_context)
|
|
357
|
+
|
|
358
|
+
if k := self.outer_typed_dict_key:
|
|
359
|
+
output = output[k]
|
|
360
|
+
return output
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
@dataclass(init=False)
|
|
364
|
+
class OutputTool(Generic[OutputDataT]):
|
|
365
|
+
parameters_schema: OutputObjectSchema[OutputDataT]
|
|
366
|
+
tool_def: ToolDefinition
|
|
367
|
+
|
|
368
|
+
def __init__(self, *, name: str, parameters_schema: OutputObjectSchema[OutputDataT], multiple: bool):
|
|
369
|
+
self.parameters_schema = parameters_schema
|
|
370
|
+
definition = parameters_schema.definition
|
|
371
|
+
|
|
372
|
+
description = definition.description
|
|
373
|
+
if not description:
|
|
374
|
+
description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
|
|
207
375
|
if multiple:
|
|
208
|
-
|
|
376
|
+
description = f'{definition.name}: {description}'
|
|
209
377
|
|
|
210
378
|
self.tool_def = ToolDefinition(
|
|
211
379
|
name=name,
|
|
212
|
-
description=
|
|
213
|
-
parameters_json_schema=
|
|
214
|
-
|
|
215
|
-
|
|
380
|
+
description=description,
|
|
381
|
+
parameters_json_schema=definition.json_schema,
|
|
382
|
+
strict=definition.strict,
|
|
383
|
+
outer_typed_dict_key=parameters_schema.outer_typed_dict_key,
|
|
216
384
|
)
|
|
217
385
|
|
|
218
|
-
def
|
|
219
|
-
self,
|
|
386
|
+
async def process(
|
|
387
|
+
self,
|
|
388
|
+
tool_call: _messages.ToolCallPart,
|
|
389
|
+
run_context: RunContext[AgentDepsT],
|
|
390
|
+
allow_partial: bool = False,
|
|
391
|
+
wrap_validation_errors: bool = True,
|
|
220
392
|
) -> OutputDataT:
|
|
221
|
-
"""
|
|
393
|
+
"""Process an output message.
|
|
222
394
|
|
|
223
395
|
Args:
|
|
224
396
|
tool_call: The tool call from the LLM to validate.
|
|
397
|
+
run_context: The current run context.
|
|
225
398
|
allow_partial: If true, allow partial validation.
|
|
226
399
|
wrap_validation_errors: If true, wrap the validation errors in a retry message.
|
|
227
400
|
|
|
@@ -229,57 +402,31 @@ class OutputSchemaTool(Generic[OutputDataT]):
|
|
|
229
402
|
Either the validated output data (left) or a retry message (right).
|
|
230
403
|
"""
|
|
231
404
|
try:
|
|
232
|
-
|
|
233
|
-
if isinstance(tool_call.args, str):
|
|
234
|
-
output = self.type_adapter.validate_json(
|
|
235
|
-
tool_call.args or '{}', experimental_allow_partial=pyd_allow_partial
|
|
236
|
-
)
|
|
237
|
-
else:
|
|
238
|
-
output = self.type_adapter.validate_python(
|
|
239
|
-
tool_call.args or {}, experimental_allow_partial=pyd_allow_partial
|
|
240
|
-
)
|
|
405
|
+
output = await self.parameters_schema.process(tool_call.args, run_context, allow_partial=allow_partial)
|
|
241
406
|
except ValidationError as e:
|
|
242
407
|
if wrap_validation_errors:
|
|
243
408
|
m = _messages.RetryPromptPart(
|
|
244
409
|
tool_name=tool_call.tool_name,
|
|
245
|
-
content=e.errors(include_url=False),
|
|
410
|
+
content=e.errors(include_url=False, include_context=False),
|
|
246
411
|
tool_call_id=tool_call.tool_call_id,
|
|
247
412
|
)
|
|
248
413
|
raise ToolRetryError(m) from e
|
|
249
414
|
else:
|
|
250
415
|
raise # pragma: lax no cover
|
|
416
|
+
except ModelRetry as r:
|
|
417
|
+
if wrap_validation_errors:
|
|
418
|
+
m = _messages.RetryPromptPart(
|
|
419
|
+
tool_name=tool_call.tool_name,
|
|
420
|
+
content=r.message,
|
|
421
|
+
tool_call_id=tool_call.tool_call_id,
|
|
422
|
+
)
|
|
423
|
+
raise ToolRetryError(m) from r
|
|
424
|
+
else:
|
|
425
|
+
raise # pragma: lax no cover
|
|
251
426
|
else:
|
|
252
|
-
if k := self.tool_def.outer_typed_dict_key:
|
|
253
|
-
output = output[k]
|
|
254
427
|
return output
|
|
255
428
|
|
|
256
429
|
|
|
257
|
-
def union_tool_name(base_name: str | None, union_arg: Any) -> str:
|
|
258
|
-
return f'{base_name or DEFAULT_OUTPUT_TOOL_NAME}_{union_arg_name(union_arg)}'
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
def union_arg_name(union_arg: Any) -> str:
|
|
262
|
-
return union_arg.__name__
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
def extract_str_from_union(output_type: Any) -> _utils.Option[Any]:
|
|
266
|
-
"""Extract the string type from a Union, return the remaining union or remaining type."""
|
|
267
|
-
union_args = get_union_args(output_type)
|
|
268
|
-
if any(t is str for t in union_args):
|
|
269
|
-
remain_args: list[Any] = []
|
|
270
|
-
includes_str = False
|
|
271
|
-
for arg in union_args:
|
|
272
|
-
if arg is str:
|
|
273
|
-
includes_str = True
|
|
274
|
-
else:
|
|
275
|
-
remain_args.append(arg)
|
|
276
|
-
if includes_str: # pragma: no branch
|
|
277
|
-
if len(remain_args) == 1:
|
|
278
|
-
return _utils.Some(remain_args[0])
|
|
279
|
-
else:
|
|
280
|
-
return _utils.Some(Union[tuple(remain_args)]) # pragma: no cover
|
|
281
|
-
|
|
282
|
-
|
|
283
430
|
def get_union_args(tp: Any) -> tuple[Any, ...]:
|
|
284
431
|
"""Extract the arguments of a Union type if `output_type` is a union, otherwise return an empty tuple."""
|
|
285
432
|
if typing_objects.is_typealiastype(tp):
|
pydantic_ai/agent.py
CHANGED
|
@@ -29,7 +29,7 @@ from . import (
|
|
|
29
29
|
usage as _usage,
|
|
30
30
|
)
|
|
31
31
|
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
|
|
32
|
-
from .result import FinalResult, OutputDataT, StreamedRunResult
|
|
32
|
+
from .result import FinalResult, OutputDataT, StreamedRunResult
|
|
33
33
|
from .settings import ModelSettings, merge_model_settings
|
|
34
34
|
from .tools import (
|
|
35
35
|
AgentDepsT,
|
|
@@ -127,7 +127,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
127
127
|
be merged with this value, with the runtime argument taking priority.
|
|
128
128
|
"""
|
|
129
129
|
|
|
130
|
-
output_type:
|
|
130
|
+
output_type: _output.OutputType[OutputDataT]
|
|
131
131
|
"""
|
|
132
132
|
The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.
|
|
133
133
|
"""
|
|
@@ -162,7 +162,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
162
162
|
self,
|
|
163
163
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
164
164
|
*,
|
|
165
|
-
output_type:
|
|
165
|
+
output_type: _output.OutputType[OutputDataT] = str,
|
|
166
166
|
instructions: str
|
|
167
167
|
| _system_prompt.SystemPromptFunc[AgentDepsT]
|
|
168
168
|
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
|
|
@@ -199,7 +199,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
199
199
|
name: str | None = None,
|
|
200
200
|
model_settings: ModelSettings | None = None,
|
|
201
201
|
retries: int = 1,
|
|
202
|
-
result_tool_name: str =
|
|
202
|
+
result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME,
|
|
203
203
|
result_tool_description: str | None = None,
|
|
204
204
|
result_retries: int | None = None,
|
|
205
205
|
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
@@ -214,7 +214,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
214
214
|
self,
|
|
215
215
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
216
216
|
*,
|
|
217
|
-
# TODO change this back to `output_type:
|
|
217
|
+
# TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads
|
|
218
218
|
output_type: Any = str,
|
|
219
219
|
instructions: str
|
|
220
220
|
| _system_prompt.SystemPromptFunc[AgentDepsT]
|
|
@@ -374,7 +374,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
374
374
|
self,
|
|
375
375
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
376
376
|
*,
|
|
377
|
-
output_type:
|
|
377
|
+
output_type: _output.OutputType[RunOutputDataT],
|
|
378
378
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
379
379
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
380
380
|
deps: AgentDepsT = None,
|
|
@@ -404,7 +404,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
404
404
|
self,
|
|
405
405
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
406
406
|
*,
|
|
407
|
-
output_type:
|
|
407
|
+
output_type: _output.OutputType[RunOutputDataT] | None = None,
|
|
408
408
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
409
409
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
410
410
|
deps: AgentDepsT = None,
|
|
@@ -492,7 +492,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
492
492
|
self,
|
|
493
493
|
user_prompt: str | Sequence[_messages.UserContent] | None,
|
|
494
494
|
*,
|
|
495
|
-
output_type:
|
|
495
|
+
output_type: _output.OutputType[RunOutputDataT],
|
|
496
496
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
497
497
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
498
498
|
deps: AgentDepsT = None,
|
|
@@ -524,7 +524,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
524
524
|
self,
|
|
525
525
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
526
526
|
*,
|
|
527
|
-
output_type:
|
|
527
|
+
output_type: _output.OutputType[RunOutputDataT] | None = None,
|
|
528
528
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
529
529
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
530
530
|
deps: AgentDepsT = None,
|
|
@@ -770,7 +770,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
770
770
|
self,
|
|
771
771
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
772
772
|
*,
|
|
773
|
-
output_type:
|
|
773
|
+
output_type: _output.OutputType[RunOutputDataT] | None = None,
|
|
774
774
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
775
775
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
776
776
|
deps: AgentDepsT = None,
|
|
@@ -800,7 +800,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
800
800
|
self,
|
|
801
801
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
802
802
|
*,
|
|
803
|
-
output_type:
|
|
803
|
+
output_type: _output.OutputType[RunOutputDataT] | None = None,
|
|
804
804
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
805
805
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
806
806
|
deps: AgentDepsT = None,
|
|
@@ -883,7 +883,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
883
883
|
self,
|
|
884
884
|
user_prompt: str | Sequence[_messages.UserContent],
|
|
885
885
|
*,
|
|
886
|
-
output_type:
|
|
886
|
+
output_type: _output.OutputType[RunOutputDataT],
|
|
887
887
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
888
888
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
889
889
|
deps: AgentDepsT = None,
|
|
@@ -914,7 +914,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
914
914
|
self,
|
|
915
915
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
916
916
|
*,
|
|
917
|
-
output_type:
|
|
917
|
+
output_type: _output.OutputType[RunOutputDataT] | None = None,
|
|
918
918
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
919
919
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
920
920
|
deps: AgentDepsT = None,
|
|
@@ -994,7 +994,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
994
994
|
if isinstance(maybe_part_event, _messages.PartStartEvent):
|
|
995
995
|
new_part = maybe_part_event.part
|
|
996
996
|
if isinstance(new_part, _messages.TextPart):
|
|
997
|
-
if
|
|
997
|
+
if _output.allow_text_output(output_schema):
|
|
998
998
|
return FinalResult(s, None, None)
|
|
999
999
|
elif isinstance(new_part, _messages.ToolCallPart) and output_schema:
|
|
1000
1000
|
for call, _ in output_schema.find_tool([new_part]):
|
|
@@ -1628,7 +1628,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1628
1628
|
raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
|
|
1629
1629
|
|
|
1630
1630
|
def _prepare_output_schema(
|
|
1631
|
-
self, output_type:
|
|
1631
|
+
self, output_type: _output.OutputType[RunOutputDataT] | None
|
|
1632
1632
|
) -> _output.OutputSchema[RunOutputDataT] | None:
|
|
1633
1633
|
if output_type is not None:
|
|
1634
1634
|
if self._output_validators:
|
pydantic_ai/mcp.py
CHANGED
|
@@ -373,7 +373,7 @@ class MCPServerHTTP(MCPServer):
|
|
|
373
373
|
url=self.url,
|
|
374
374
|
headers=self.headers,
|
|
375
375
|
timeout=timedelta(seconds=self.timeout),
|
|
376
|
-
sse_read_timeout=timedelta(self.sse_read_timeout),
|
|
376
|
+
sse_read_timeout=timedelta(seconds=self.sse_read_timeout),
|
|
377
377
|
) as (read_stream, write_stream, _):
|
|
378
378
|
yield read_stream, write_stream
|
|
379
379
|
|
pydantic_ai/messages.py
CHANGED
|
@@ -378,7 +378,7 @@ class ToolReturnPart:
|
|
|
378
378
|
"""Return a dictionary representation of the content, wrapping non-dict types appropriately."""
|
|
379
379
|
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
|
|
380
380
|
if isinstance(self.content, dict):
|
|
381
|
-
return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
|
|
381
|
+
return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
|
|
382
382
|
else:
|
|
383
383
|
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
|
|
384
384
|
|
|
@@ -589,7 +589,7 @@ class ModelResponse:
|
|
|
589
589
|
kind: Literal['response'] = 'response'
|
|
590
590
|
"""Message type identifier, this is available on all parts as a discriminator."""
|
|
591
591
|
|
|
592
|
-
vendor_details: dict[str, Any] | None = field(default=None
|
|
592
|
+
vendor_details: dict[str, Any] | None = field(default=None)
|
|
593
593
|
"""Additional vendor-specific details in a serializable format.
|
|
594
594
|
|
|
595
595
|
This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
|