pydantic-ai-slim 0.2.11-py3-none-any.whl → 0.2.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (51)
  1. pydantic_ai/_agent_graph.py +29 -35
  2. pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
  3. pydantic_ai/_output.py +266 -119
  4. pydantic_ai/agent.py +15 -15
  5. pydantic_ai/mcp.py +1 -1
  6. pydantic_ai/messages.py +2 -2
  7. pydantic_ai/models/__init__.py +39 -3
  8. pydantic_ai/models/anthropic.py +4 -0
  9. pydantic_ai/models/bedrock.py +43 -16
  10. pydantic_ai/models/cohere.py +4 -0
  11. pydantic_ai/models/gemini.py +78 -109
  12. pydantic_ai/models/google.py +47 -112
  13. pydantic_ai/models/groq.py +17 -2
  14. pydantic_ai/models/mistral.py +4 -0
  15. pydantic_ai/models/openai.py +25 -158
  16. pydantic_ai/profiles/__init__.py +39 -0
  17. pydantic_ai/{models → profiles}/_json_schema.py +23 -2
  18. pydantic_ai/profiles/amazon.py +9 -0
  19. pydantic_ai/profiles/anthropic.py +8 -0
  20. pydantic_ai/profiles/cohere.py +8 -0
  21. pydantic_ai/profiles/deepseek.py +8 -0
  22. pydantic_ai/profiles/google.py +100 -0
  23. pydantic_ai/profiles/grok.py +8 -0
  24. pydantic_ai/profiles/meta.py +9 -0
  25. pydantic_ai/profiles/mistral.py +8 -0
  26. pydantic_ai/profiles/openai.py +144 -0
  27. pydantic_ai/profiles/qwen.py +9 -0
  28. pydantic_ai/providers/__init__.py +18 -0
  29. pydantic_ai/providers/anthropic.py +5 -0
  30. pydantic_ai/providers/azure.py +34 -0
  31. pydantic_ai/providers/bedrock.py +60 -1
  32. pydantic_ai/providers/cohere.py +5 -0
  33. pydantic_ai/providers/deepseek.py +12 -0
  34. pydantic_ai/providers/fireworks.py +99 -0
  35. pydantic_ai/providers/google.py +5 -0
  36. pydantic_ai/providers/google_gla.py +5 -0
  37. pydantic_ai/providers/google_vertex.py +5 -0
  38. pydantic_ai/providers/grok.py +82 -0
  39. pydantic_ai/providers/groq.py +25 -0
  40. pydantic_ai/providers/mistral.py +5 -0
  41. pydantic_ai/providers/openai.py +5 -0
  42. pydantic_ai/providers/openrouter.py +36 -0
  43. pydantic_ai/providers/together.py +96 -0
  44. pydantic_ai/result.py +34 -103
  45. pydantic_ai/tools.py +29 -59
  46. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/METADATA +4 -4
  47. pydantic_ai_slim-0.2.13.dist-info/RECORD +73 -0
  48. pydantic_ai_slim-0.2.11.dist-info/RECORD +0 -59
  49. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/WHEEL +0 -0
  50. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/entry_points.txt +0 -0
  51. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_output.py CHANGED
@@ -1,22 +1,55 @@
 from __future__ import annotations as _annotations
 
 import inspect
-from collections.abc import Awaitable, Iterable, Iterator
+from collections.abc import Awaitable, Iterable, Iterator, Sequence
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, Literal, Union, cast
 
 from pydantic import TypeAdapter, ValidationError
-from typing_extensions import TypedDict, TypeVar, get_args, get_origin
+from pydantic_core import SchemaValidator
+from typing_extensions import TypeAliasType, TypedDict, TypeVar, get_args, get_origin
 from typing_inspection import typing_objects
 from typing_inspection.introspection import is_union_origin
 
-from . import _utils, messages as _messages
+from . import _function_schema, _utils, messages as _messages
 from .exceptions import ModelRetry
-from .result import DEFAULT_OUTPUT_TOOL_NAME, OutputDataT, OutputDataT_inv, OutputValidatorFunc, ToolOutput
-from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext, ToolDefinition
+from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition
 
 T = TypeVar('T')
 """An invariant TypeVar."""
+OutputDataT_inv = TypeVar('OutputDataT_inv', default=str)
+"""
+An invariant type variable for the result data of a model.
+
+We need to use an invariant typevar for `OutputValidator` and `OutputValidatorFunc` because the output data type is used
+in both the input and output of a `OutputValidatorFunc`. This can theoretically lead to some issues assuming that types
+possessing OutputValidator's are covariant in the result data type, but in practice this is rarely an issue, and
+changing it would have negative consequences for the ergonomics of the library.
+
+At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would
+resolve these potential variance issues.
+"""
+OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
+"""Covariant type variable for the result data type of a run."""
+
+OutputValidatorFunc = Union[
+    Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv],
+    Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]],
+    Callable[[OutputDataT_inv], OutputDataT_inv],
+    Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]],
+]
+"""
+A function that always takes and returns the same type of data (which is the result type of an agent run), and:
+
+* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
+* may or may not be async
+
+Usage `OutputValidatorFunc[AgentDepsT, T]`.
+"""
+
+
+DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
+DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
 
 
 @dataclass
@@ -76,69 +109,135 @@ class ToolRetryError(Exception):
         super().__init__()
 
 
+@dataclass(init=False)
+class ToolOutput(Generic[OutputDataT]):
+    """Marker class to use tools for outputs, and customize the tool."""
+
+    output_type: SimpleOutputType[OutputDataT]
+    name: str | None
+    description: str | None
+    max_retries: int | None
+    strict: bool | None
+
+    def __init__(
+        self,
+        type_: SimpleOutputType[OutputDataT],
+        *,
+        name: str | None = None,
+        description: str | None = None,
+        max_retries: int | None = None,
+        strict: bool | None = None,
+    ):
+        self.output_type = type_
+        self.name = name
+        self.description = description
+        self.max_retries = max_retries
+        self.strict = strict
+
+
+T_co = TypeVar('T_co', covariant=True)
+# output_type=Type or output_type=function or output_type=object.method
+SimpleOutputType = TypeAliasType(
+    'SimpleOutputType', Union[type[T_co], Callable[..., T_co], Callable[..., Awaitable[T_co]]], type_params=(T_co,)
+)
+# output_type=ToolOutput(<see above>) or <see above>
+SimpleOutputTypeOrMarker = TypeAliasType(
+    'SimpleOutputTypeOrMarker', Union[SimpleOutputType[T_co], ToolOutput[T_co]], type_params=(T_co,)
+)
+# output_type=<see above> or [<see above>, ...]
+OutputType = TypeAliasType(
+    'OutputType', Union[SimpleOutputTypeOrMarker[T_co], Sequence[SimpleOutputTypeOrMarker[T_co]]], type_params=(T_co,)
+)
+
+
 @dataclass
 class OutputSchema(Generic[OutputDataT]):
-    """Model the final response from an agent run.
+    """Model the final output from an agent run.
 
     Similar to `Tool` but for the final output of running an agent.
     """
 
-    tools: dict[str, OutputSchemaTool[OutputDataT]]
+    tools: dict[str, OutputTool[OutputDataT]]
     allow_text_output: bool
 
     @classmethod
     def build(
-        cls: type[OutputSchema[T]],
-        output_type: type[T] | ToolOutput[T],
+        cls: type[OutputSchema[OutputDataT]],
+        output_type: OutputType[OutputDataT],
         name: str | None = None,
         description: str | None = None,
         strict: bool | None = None,
-    ) -> OutputSchema[T] | None:
-        """Build an OutputSchema dataclass from a response type."""
+    ) -> OutputSchema[OutputDataT] | None:
+        """Build an OutputSchema dataclass from an output type."""
         if output_type is str:
             return None
 
-        if isinstance(output_type, ToolOutput):
-            # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
-            name = output_type.name
-            description = output_type.description
-            output_type_ = output_type.output_type
-            strict = output_type.strict
+        output_types: Sequence[SimpleOutputTypeOrMarker[OutputDataT]]
+        if isinstance(output_type, Sequence):
+            output_types = output_type
         else:
-            output_type_ = output_type
+            output_types = (output_type,)
 
-        if output_type_option := extract_str_from_union(output_type):
-            output_type_ = output_type_option.value
+        output_types_flat: list[SimpleOutputTypeOrMarker[OutputDataT]] = []
+        for output_type in output_types:
+            if union_types := get_union_args(output_type):
+                output_types_flat.extend(union_types)
+            else:
+                output_types_flat.append(output_type)
+
+        allow_text_output = False
+        if str in output_types_flat:
             allow_text_output = True
-        else:
-            allow_text_output = False
-
-        tools: dict[str, OutputSchemaTool[T]] = {}
-        if args := get_union_args(output_type_):
-            for i, arg in enumerate(args, start=1):
-                tool_name = raw_tool_name = union_tool_name(name, arg)
-                while tool_name in tools:
-                    tool_name = f'{raw_tool_name}_{i}'
-                tools[tool_name] = cast(
-                    OutputSchemaTool[T],
-                    OutputSchemaTool(
-                        output_type=arg, name=tool_name, description=description, multiple=True, strict=strict
-                    ),
-                )
-        else:
-            name = name or DEFAULT_OUTPUT_TOOL_NAME
-            tools[name] = cast(
-                OutputSchemaTool[T],
-                OutputSchemaTool(
-                    output_type=output_type_, name=name, description=description, multiple=False, strict=strict
-                ),
+            output_types_flat = [t for t in output_types_flat if t is not str]
+
+        multiple = len(output_types_flat) > 1
+
+        default_tool_name = name or DEFAULT_OUTPUT_TOOL_NAME
+        default_tool_description = description
+        default_tool_strict = strict
+
+        tools: dict[str, OutputTool[OutputDataT]] = {}
+        for output_type in output_types_flat:
+            tool_name = None
+            tool_description = None
+            tool_strict = None
+            if isinstance(output_type, ToolOutput):
+                tool_output_type = output_type.output_type
+                # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
+                tool_name = output_type.name
+                tool_description = output_type.description
+                tool_strict = output_type.strict
+            else:
+                tool_output_type = output_type
+
+            if tool_name is None:
+                tool_name = default_tool_name
+                if multiple:
+                    tool_name += f'_{tool_output_type.__name__}'
+
+            i = 1
+            original_tool_name = tool_name
+            while tool_name in tools:
+                i += 1
+                tool_name = f'{original_tool_name}_{i}'
+
+            tool_description = tool_description or default_tool_description
+            if tool_strict is None:
+                tool_strict = default_tool_strict
+
+            parameters_schema = OutputObjectSchema(
+                output_type=tool_output_type, description=tool_description, strict=tool_strict
             )
+            tools[tool_name] = OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=multiple)
 
-        return cls(tools=tools, allow_text_output=allow_text_output)
+        return cls(
+            tools=tools,
+            allow_text_output=allow_text_output,
+        )
 
     def find_named_tool(
         self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
-    ) -> tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]] | None:
+    ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None:
         """Find a tool that matches one of the calls, with a specific name."""
         for part in parts:  # pragma: no branch
             if isinstance(part, _messages.ToolCallPart):  # pragma: no branch
@@ -148,7 +247,7 @@ class OutputSchema(Generic[OutputDataT]):
     def find_tool(
         self,
         parts: Iterable[_messages.ModelResponsePart],
-    ) -> Iterator[tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]]]:
+    ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]:
         """Find a tool that matches one of the calls."""
         for part in parts:
             if isinstance(part, _messages.ToolCallPart):  # pragma: no branch
@@ -164,64 +263,138 @@ class OutputSchema(Generic[OutputDataT]):
         return [t.tool_def for t in self.tools.values()]
 
 
-DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
+def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool:
+    return output_schema is None or output_schema.allow_text_output
+
+
+@dataclass
+class OutputObjectDefinition:
+    name: str
+    json_schema: ObjectJsonSchema
+    description: str | None = None
+    strict: bool | None = None
 
 
 @dataclass(init=False)
-class OutputSchemaTool(Generic[OutputDataT]):
-    tool_def: ToolDefinition
-    type_adapter: TypeAdapter[Any]
+class OutputObjectSchema(Generic[OutputDataT]):
+    definition: OutputObjectDefinition
+    validator: SchemaValidator
+    function_schema: _function_schema.FunctionSchema | None = None
+    outer_typed_dict_key: str | None = None
 
     def __init__(
-        self, *, output_type: type[OutputDataT], name: str, description: str | None, multiple: bool, strict: bool | None
+        self,
+        *,
+        output_type: SimpleOutputType[OutputDataT],
+        name: str | None = None,
+        description: str | None = None,
+        strict: bool | None = None,
     ):
-        """Build a OutputSchemaTool from a response type."""
-        if _utils.is_model_like(output_type):
-            self.type_adapter = TypeAdapter(output_type)
-            outer_typed_dict_key: str | None = None
-            # noinspection PyArgumentList
-            parameters_json_schema = _utils.check_object_json_schema(
-                self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
-            )
+        if inspect.isfunction(output_type) or inspect.ismethod(output_type):
+            self.function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema)
+            self.validator = self.function_schema.validator
+            json_schema = self.function_schema.json_schema
+            json_schema['description'] = self.function_schema.description
         else:
-            response_data_typed_dict = TypedDict(  # noqa: UP013
-                'response_data_typed_dict',
-                {'response': output_type},  # pyright: ignore[reportInvalidTypeForm]
-            )
-            self.type_adapter = TypeAdapter(response_data_typed_dict)
-            outer_typed_dict_key = 'response'
-            # noinspection PyArgumentList
-            parameters_json_schema = _utils.check_object_json_schema(
-                self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
+            type_adapter: TypeAdapter[Any]
+            if _utils.is_model_like(output_type):
+                type_adapter = TypeAdapter(output_type)
+            else:
+                self.outer_typed_dict_key = 'response'
+                response_data_typed_dict = TypedDict(  # noqa: UP013
+                    'response_data_typed_dict',
+                    {'response': cast(type[OutputDataT], output_type)},  # pyright: ignore[reportInvalidTypeForm]
+                )
+                type_adapter = TypeAdapter(response_data_typed_dict)
+
+            # Really a PluggableSchemaValidator, but it's API-compatible
+            self.validator = cast(SchemaValidator, type_adapter.validator)
+            json_schema = _utils.check_object_json_schema(
+                type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
             )
-            # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
-            parameters_json_schema.pop('title')
 
-        if json_schema_description := parameters_json_schema.pop('description', None):
+        if self.outer_typed_dict_key:
+            # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
+            json_schema.pop('title')
+
+        if json_schema_description := json_schema.pop('description', None):
             if description is None:
-                tool_description = json_schema_description
+                description = json_schema_description
             else:
-                tool_description = f'{description}. {json_schema_description}'  # pragma: no cover
+                description = f'{description}. {json_schema_description}'
+
+        self.definition = OutputObjectDefinition(
+            name=name or getattr(output_type, '__name__', DEFAULT_OUTPUT_TOOL_NAME),
+            description=description,
+            json_schema=json_schema,
+            strict=strict,
+        )
+
+    async def process(
+        self,
+        data: str | dict[str, Any] | None,
+        run_context: RunContext[AgentDepsT],
+        allow_partial: bool = False,
+    ) -> OutputDataT:
+        """Process an output message, performing validation and (if necessary) calling the output function.
+
+        Args:
+            data: The output data to validate.
+            run_context: The current run context.
+            allow_partial: If true, allow partial validation.
+
+        Returns:
+            Either the validated output data (left) or a retry message (right).
+        """
+        pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
+        if isinstance(data, str):
+            output = self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
         else:
-            tool_description = description or DEFAULT_DESCRIPTION
+            output = self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
+
+        if self.function_schema:
+            output = await self.function_schema.call(output, run_context)
+
+        if k := self.outer_typed_dict_key:
+            output = output[k]
+        return output
+
+
+@dataclass(init=False)
+class OutputTool(Generic[OutputDataT]):
+    parameters_schema: OutputObjectSchema[OutputDataT]
+    tool_def: ToolDefinition
+
+    def __init__(self, *, name: str, parameters_schema: OutputObjectSchema[OutputDataT], multiple: bool):
+        self.parameters_schema = parameters_schema
+        definition = parameters_schema.definition
+
+        description = definition.description
+        if not description:
+            description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
         if multiple:
-            tool_description = f'{union_arg_name(output_type)}: {tool_description}'
+            description = f'{definition.name}: {description}'
 
         self.tool_def = ToolDefinition(
             name=name,
-            description=tool_description,
-            parameters_json_schema=parameters_json_schema,
-            outer_typed_dict_key=outer_typed_dict_key,
-            strict=strict,
+            description=description,
+            parameters_json_schema=definition.json_schema,
+            strict=definition.strict,
+            outer_typed_dict_key=parameters_schema.outer_typed_dict_key,
         )
 
-    def validate(
-        self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
+    async def process(
+        self,
+        tool_call: _messages.ToolCallPart,
+        run_context: RunContext[AgentDepsT],
+        allow_partial: bool = False,
+        wrap_validation_errors: bool = True,
     ) -> OutputDataT:
-        """Validate an output message.
+        """Process an output message.
 
         Args:
             tool_call: The tool call from the LLM to validate.
+            run_context: The current run context.
             allow_partial: If true, allow partial validation.
             wrap_validation_errors: If true, wrap the validation errors in a retry message.
 
@@ -229,57 +402,31 @@ class OutputSchemaTool(Generic[OutputDataT]):
             Either the validated output data (left) or a retry message (right).
         """
         try:
-            pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
-            if isinstance(tool_call.args, str):
-                output = self.type_adapter.validate_json(
-                    tool_call.args or '{}', experimental_allow_partial=pyd_allow_partial
-                )
-            else:
-                output = self.type_adapter.validate_python(
-                    tool_call.args or {}, experimental_allow_partial=pyd_allow_partial
-                )
+            output = await self.parameters_schema.process(tool_call.args, run_context, allow_partial=allow_partial)
         except ValidationError as e:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
                     tool_name=tool_call.tool_name,
-                    content=e.errors(include_url=False),
+                    content=e.errors(include_url=False, include_context=False),
                     tool_call_id=tool_call.tool_call_id,
                 )
                 raise ToolRetryError(m) from e
             else:
                 raise  # pragma: lax no cover
+        except ModelRetry as r:
+            if wrap_validation_errors:
+                m = _messages.RetryPromptPart(
+                    tool_name=tool_call.tool_name,
+                    content=r.message,
+                    tool_call_id=tool_call.tool_call_id,
+                )
+                raise ToolRetryError(m) from r
+            else:
+                raise  # pragma: lax no cover
         else:
-            if k := self.tool_def.outer_typed_dict_key:
-                output = output[k]
             return output
 
 
-def union_tool_name(base_name: str | None, union_arg: Any) -> str:
-    return f'{base_name or DEFAULT_OUTPUT_TOOL_NAME}_{union_arg_name(union_arg)}'
-
-
-def union_arg_name(union_arg: Any) -> str:
-    return union_arg.__name__
-
-
-def extract_str_from_union(output_type: Any) -> _utils.Option[Any]:
-    """Extract the string type from a Union, return the remaining union or remaining type."""
-    union_args = get_union_args(output_type)
-    if any(t is str for t in union_args):
-        remain_args: list[Any] = []
-        includes_str = False
-        for arg in union_args:
-            if arg is str:
-                includes_str = True
-            else:
-                remain_args.append(arg)
-        if includes_str:  # pragma: no branch
-            if len(remain_args) == 1:
-                return _utils.Some(remain_args[0])
-            else:
-                return _utils.Some(Union[tuple(remain_args)])  # pragma: no cover
-
-
 def get_union_args(tp: Any) -> tuple[Any, ...]:
     """Extract the arguments of a Union type if `output_type` is a union, otherwise return an empty tuple."""
     if typing_objects.is_typealiastype(tp):
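
Taken together, the `_output.py` changes mean `output_type` can now be a single type, a plain or async function, a `ToolOutput(...)` marker, or a sequence mixing any of these; when several output types are given, each becomes its own output tool named `final_result_<TypeName>`, with numeric suffixes on collisions. A minimal sketch of the resulting usage, assuming an OpenAI model is configured; `ToolOutput` is imported from the private `pydantic_ai._output` module shown in this diff, since the public re-export path isn't visible here:

from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai._output import ToolOutput  # public re-export path not visible in this diff


class Fruit(BaseModel):
    name: str
    color: str


class Vehicle(BaseModel):
    name: str
    wheels: int


# Two output types: by default the output tools would be named final_result_Fruit and
# final_result_Vehicle; the ToolOutput marker overrides the name and description for one of them.
agent = Agent(
    'openai:gpt-4o',
    output_type=[
        Fruit,
        ToolOutput(Vehicle, name='return_vehicle', description='Use this for vehicles'),
    ],
)

result = agent.run_sync('What is a banana?')
print(repr(result.output))  # e.g. Fruit(name='banana', color='yellow')
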
pydantic_ai/agent.py CHANGED
@@ -29,7 +29,7 @@ from . import (
     usage as _usage,
 )
 from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
-from .result import FinalResult, OutputDataT, StreamedRunResult, ToolOutput
+from .result import FinalResult, OutputDataT, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDepsT,
@@ -127,7 +127,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
     be merged with this value, with the runtime argument taking priority.
     """
 
-    output_type: type[OutputDataT] | ToolOutput[OutputDataT]
+    output_type: _output.OutputType[OutputDataT]
     """
     The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.
     """
@@ -162,7 +162,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         model: models.Model | models.KnownModelName | str | None = None,
         *,
-        output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,
+        output_type: _output.OutputType[OutputDataT] = str,
         instructions: str
         | _system_prompt.SystemPromptFunc[AgentDepsT]
        | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
@@ -199,7 +199,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
-        result_tool_name: str = 'final_result',
+        result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME,
         result_tool_description: str | None = None,
         result_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
@@ -214,7 +214,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         model: models.Model | models.KnownModelName | str | None = None,
         *,
-        # TODO change this back to `output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,` when we remove the overloads
+        # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads
         output_type: Any = str,
         instructions: str
         | _system_prompt.SystemPromptFunc[AgentDepsT]
@@ -374,7 +374,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT],
+        output_type: _output.OutputType[RunOutputDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -404,7 +404,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
+        output_type: _output.OutputType[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -492,7 +492,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None,
         *,
-        output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT],
+        output_type: _output.OutputType[RunOutputDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -524,7 +524,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
+        output_type: _output.OutputType[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -770,7 +770,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None,
+        output_type: _output.OutputType[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -800,7 +800,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
+        output_type: _output.OutputType[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -883,7 +883,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent],
         *,
-        output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT],
+        output_type: _output.OutputType[RunOutputDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -914,7 +914,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
+        output_type: _output.OutputType[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -994,7 +994,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                                 if isinstance(maybe_part_event, _messages.PartStartEvent):
                                     new_part = maybe_part_event.part
                                     if isinstance(new_part, _messages.TextPart):
-                                        if _agent_graph.allow_text_output(output_schema):
+                                        if _output.allow_text_output(output_schema):
                                             return FinalResult(s, None, None)
                                     elif isinstance(new_part, _messages.ToolCallPart) and output_schema:
                                         for call, _ in output_schema.find_tool([new_part]):
@@ -1628,7 +1628,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
 
     def _prepare_output_schema(
-        self, output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None
+        self, output_type: _output.OutputType[RunOutputDataT] | None
     ) -> _output.OutputSchema[RunOutputDataT] | None:
         if output_type is not None:
             if self._output_validators:
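
Every `run`, `run_sync`, `run_stream` and `iter` overload above now accepts the same `_output.OutputType`, so markers and sequences can also be passed per run to override the agent-level default. A short continuation of the hypothetical `Fruit`/`Vehicle` sketch from the `_output.py` section:

# Override the output type for a single run.
vehicle_result = agent.run_sync('Name a vehicle with two wheels', output_type=Vehicle)
print(type(vehicle_result.output).__name__)  # Vehicle

# Including str in the sequence keeps plain-text output allowed alongside the output tools.
mixed_result = agent.run_sync('Describe a lemon', output_type=[str, Fruit])
print(mixed_result.output)
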
pydantic_ai/mcp.py CHANGED
@@ -373,7 +373,7 @@ class MCPServerHTTP(MCPServer):
             url=self.url,
             headers=self.headers,
             timeout=timedelta(seconds=self.timeout),
-            sse_read_timeout=timedelta(self.sse_read_timeout),
+            sse_read_timeout=timedelta(seconds=self.sse_read_timeout),
         ) as (read_stream, write_stream, _):
             yield read_stream, write_stream
 
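
The one-line `mcp.py` change is a unit fix: `datetime.timedelta`'s first positional argument is days, so passing `sse_read_timeout` positionally turned a value intended as seconds into days. For example:

from datetime import timedelta

sse_read_timeout = 300  # intended to mean 300 seconds
print(timedelta(sse_read_timeout))          # 300 days, 0:00:00  (old behaviour)
print(timedelta(seconds=sse_read_timeout))  # 0:05:00  (fixed behaviour)
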
pydantic_ai/messages.py CHANGED
@@ -378,7 +378,7 @@ class ToolReturnPart:
         """Return a dictionary representation of the content, wrapping non-dict types appropriately."""
         # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
         if isinstance(self.content, dict):
-            return tool_return_ta.dump_python(self.content, mode='json')  # pyright: ignore[reportUnknownMemberType]  # pragma: no cover
+            return tool_return_ta.dump_python(self.content, mode='json')  # pyright: ignore[reportUnknownMemberType]
         else:
             return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
 
@@ -589,7 +589,7 @@ class ModelResponse:
     kind: Literal['response'] = 'response'
     """Message type identifier, this is available on all parts as a discriminator."""
 
-    vendor_details: dict[str, Any] | None = field(default=None, repr=False)
+    vendor_details: dict[str, Any] | None = field(default=None)
     """Additional vendor-specific details in a serializable format.
 
     This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
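
Dropping `repr=False` from `vendor_details` means the field now appears in a `ModelResponse`'s generated dataclass repr. The effect of the flag, illustrated on a stand-in dataclass rather than `ModelResponse` itself:

from dataclasses import dataclass, field
from typing import Any


@dataclass
class Hidden:
    vendor_details: dict[str, Any] | None = field(default=None, repr=False)


@dataclass
class Shown:
    vendor_details: dict[str, Any] | None = field(default=None)


print(Hidden({'finish_reason': 'stop'}))  # Hidden()
print(Shown({'finish_reason': 'stop'}))   # Shown(vendor_details={'finish_reason': 'stop'})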