pydantic-ai-slim 0.4.7__tar.gz → 0.4.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (101)
  1. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/PKG-INFO +5 -3
  2. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_tool_manager.py +43 -31
  3. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/ag_ui.py +36 -43
  4. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/agent.py +5 -3
  5. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/mcp.py +19 -19
  6. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/messages.py +10 -5
  7. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/google.py +0 -1
  8. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/result.py +5 -3
  9. pydantic_ai_slim-0.4.9/pydantic_ai/retries.py +249 -0
  10. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/toolsets/combined.py +4 -3
  11. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pyproject.toml +2 -0
  12. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/.gitignore +0 -0
  13. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/LICENSE +0 -0
  14. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/README.md +0 -0
  15. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/__init__.py +0 -0
  16. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/__main__.py +0 -0
  17. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_a2a.py +0 -0
  18. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_agent_graph.py +0 -0
  19. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_cli.py +0 -0
  20. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_function_schema.py +0 -0
  21. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_griffe.py +0 -0
  22. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_mcp.py +0 -0
  23. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_output.py +0 -0
  24. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_parts_manager.py +0 -0
  25. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_run_context.py +0 -0
  26. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_system_prompt.py +0 -0
  27. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_thinking_part.py +0 -0
  28. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/_utils.py +0 -0
  29. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/common_tools/__init__.py +0 -0
  30. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  31. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/common_tools/tavily.py +0 -0
  32. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/direct.py +0 -0
  33. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/exceptions.py +0 -0
  34. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/ext/__init__.py +0 -0
  35. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/ext/aci.py +0 -0
  36. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/ext/langchain.py +0 -0
  37. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/format_as_xml.py +0 -0
  38. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/format_prompt.py +0 -0
  39. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/__init__.py +0 -0
  40. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/anthropic.py +0 -0
  41. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/bedrock.py +0 -0
  42. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/cohere.py +0 -0
  43. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/fallback.py +0 -0
  44. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/function.py +0 -0
  45. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/gemini.py +0 -0
  46. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/google.py +0 -0
  47. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/groq.py +0 -0
  48. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/huggingface.py +0 -0
  49. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/instrumented.py +0 -0
  50. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/mcp_sampling.py +0 -0
  51. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/mistral.py +0 -0
  52. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/openai.py +0 -0
  53. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/test.py +0 -0
  54. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/models/wrapper.py +0 -0
  55. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/output.py +0 -0
  56. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/__init__.py +0 -0
  57. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/_json_schema.py +0 -0
  58. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/amazon.py +0 -0
  59. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/anthropic.py +0 -0
  60. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/cohere.py +0 -0
  61. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/deepseek.py +0 -0
  62. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/grok.py +0 -0
  63. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/meta.py +0 -0
  64. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/mistral.py +0 -0
  65. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/moonshotai.py +0 -0
  66. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/openai.py +0 -0
  67. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/profiles/qwen.py +0 -0
  68. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/__init__.py +0 -0
  69. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/anthropic.py +0 -0
  70. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/azure.py +0 -0
  71. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/bedrock.py +0 -0
  72. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/cohere.py +0 -0
  73. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/deepseek.py +0 -0
  74. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/fireworks.py +0 -0
  75. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/github.py +0 -0
  76. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/google.py +0 -0
  77. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/google_gla.py +0 -0
  78. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/google_vertex.py +0 -0
  79. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/grok.py +0 -0
  80. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/groq.py +0 -0
  81. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/heroku.py +0 -0
  82. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/huggingface.py +0 -0
  83. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/mistral.py +0 -0
  84. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/moonshotai.py +0 -0
  85. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/openai.py +0 -0
  86. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/openrouter.py +0 -0
  87. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/together.py +0 -0
  88. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/providers/vercel.py +0 -0
  89. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/py.typed +0 -0
  90. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/settings.py +0 -0
  91. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/tools.py +0 -0
  92. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/toolsets/__init__.py +0 -0
  93. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/toolsets/abstract.py +0 -0
  94. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/toolsets/deferred.py +0 -0
  95. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/toolsets/filtered.py +0 -0
  96. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/toolsets/function.py +0 -0
  97. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/toolsets/prefixed.py +0 -0
  98. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/toolsets/prepared.py +0 -0
  99. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/toolsets/renamed.py +0 -0
  100. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/toolsets/wrapper.py +0 -0
  101. {pydantic_ai_slim-0.4.7 → pydantic_ai_slim-0.4.9}/pydantic_ai/usage.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.4.7
3
+ Version: 0.4.9
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>, Douwe Maan <douwe@pydantic.dev>
6
6
  License-Expression: MIT
@@ -30,7 +30,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
30
30
  Requires-Dist: griffe>=1.3.2
31
31
  Requires-Dist: httpx>=0.27
32
32
  Requires-Dist: opentelemetry-api>=1.28.0
33
- Requires-Dist: pydantic-graph==0.4.7
33
+ Requires-Dist: pydantic-graph==0.4.9
34
34
  Requires-Dist: pydantic>=2.10
35
35
  Requires-Dist: typing-inspection>=0.4.0
36
36
  Provides-Extra: a2a
@@ -51,7 +51,7 @@ Requires-Dist: cohere>=5.16.0; (platform_system != 'Emscripten') and extra == 'c
51
51
  Provides-Extra: duckduckgo
52
52
  Requires-Dist: ddgs>=9.0.0; extra == 'duckduckgo'
53
53
  Provides-Extra: evals
54
- Requires-Dist: pydantic-evals==0.4.7; extra == 'evals'
54
+ Requires-Dist: pydantic-evals==0.4.9; extra == 'evals'
55
55
  Provides-Extra: google
56
56
  Requires-Dist: google-genai>=1.24.0; extra == 'google'
57
57
  Provides-Extra: groq
@@ -66,6 +66,8 @@ Provides-Extra: mistral
66
66
  Requires-Dist: mistralai>=1.9.2; extra == 'mistral'
67
67
  Provides-Extra: openai
68
68
  Requires-Dist: openai>=1.92.0; extra == 'openai'
69
+ Provides-Extra: retries
70
+ Requires-Dist: tenacity>=8.2.3; extra == 'retries'
69
71
  Provides-Extra: tavily
70
72
  Requires-Dist: tavily-python>=0.5.0; extra == 'tavily'
71
73
  Provides-Extra: vertexai
@@ -2,18 +2,17 @@ from __future__ import annotations
2
2
 
3
3
  import json
4
4
  from collections.abc import Iterable
5
- from dataclasses import dataclass, replace
5
+ from dataclasses import dataclass, field, replace
6
6
  from typing import Any, Generic
7
7
 
8
8
  from pydantic import ValidationError
9
9
  from typing_extensions import assert_never
10
10
 
11
- from pydantic_ai.output import DeferredToolCalls
12
-
13
11
  from . import messages as _messages
14
12
  from ._run_context import AgentDepsT, RunContext
15
13
  from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior
16
14
  from .messages import ToolCallPart
15
+ from .output import DeferredToolCalls
17
16
  from .tools import ToolDefinition
18
17
  from .toolsets.abstract import AbstractToolset, ToolsetTool
19
18
 
@@ -28,6 +27,8 @@ class ToolManager(Generic[AgentDepsT]):
28
27
  """The toolset that provides the tools for this run step."""
29
28
  tools: dict[str, ToolsetTool[AgentDepsT]]
30
29
  """The cached tools for this run step."""
30
+ failed_tools: set[str] = field(default_factory=set)
31
+ """Names of tools that failed in this run step."""
31
32
 
32
33
  @classmethod
33
34
  async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
@@ -40,7 +41,10 @@ class ToolManager(Generic[AgentDepsT]):
40
41
 
41
42
  async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
42
43
  """Build a new tool manager for the next run step, carrying over the retries from the current run step."""
43
- return await self.__class__.build(self.toolset, replace(ctx, retries=self.ctx.retries))
44
+ retries = {
45
+ failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1 for failed_tool_name in self.failed_tools
46
+ }
47
+ return await self.__class__.build(self.toolset, replace(ctx, retries=retries))
44
48
 
45
49
  @property
46
50
  def tool_defs(self) -> list[ToolDefinition]:
@@ -54,20 +58,25 @@ class ToolManager(Generic[AgentDepsT]):
54
58
  except KeyError:
55
59
  return None
56
60
 
57
- async def handle_call(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
61
+ async def handle_call(
62
+ self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
63
+ ) -> Any:
58
64
  """Handle a tool call by validating the arguments, calling the tool, and handling retries.
59
65
 
60
66
  Args:
61
67
  call: The tool call part to handle.
62
68
  allow_partial: Whether to allow partial validation of the tool arguments.
69
+ wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
63
70
  """
64
71
  if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
65
72
  # Output tool calls are not traced
66
- return await self._call_tool(call, allow_partial)
73
+ return await self._call_tool(call, allow_partial, wrap_validation_errors)
67
74
  else:
68
- return await self._call_tool_traced(call, allow_partial)
75
+ return await self._call_tool_traced(call, allow_partial, wrap_validation_errors)
69
76
 
70
- async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
77
+ async def _call_tool(
78
+ self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
79
+ ) -> Any:
71
80
  name = call.tool_name
72
81
  tool = self.tools.get(name)
73
82
  try:
@@ -92,7 +101,7 @@ class ToolManager(Generic[AgentDepsT]):
92
101
  else:
93
102
  args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
94
103
 
95
- output = await self.toolset.call_tool(name, args_dict, ctx, tool)
104
+ return await self.toolset.call_tool(name, args_dict, ctx, tool)
96
105
  except (ValidationError, ModelRetry) as e:
97
106
  max_retries = tool.max_retries if tool is not None else 1
98
107
  current_retry = self.ctx.retries.get(name, 0)
@@ -100,30 +109,33 @@ class ToolManager(Generic[AgentDepsT]):
100
109
  if current_retry == max_retries:
101
110
  raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
102
111
  else:
103
- if isinstance(e, ValidationError):
104
- m = _messages.RetryPromptPart(
105
- tool_name=name,
106
- content=e.errors(include_url=False, include_context=False),
107
- tool_call_id=call.tool_call_id,
108
- )
109
- e = ToolRetryError(m)
110
- elif isinstance(e, ModelRetry):
111
- m = _messages.RetryPromptPart(
112
- tool_name=name,
113
- content=e.message,
114
- tool_call_id=call.tool_call_id,
115
- )
116
- e = ToolRetryError(m)
117
- else:
118
- assert_never(e)
112
+ if wrap_validation_errors:
113
+ if isinstance(e, ValidationError):
114
+ m = _messages.RetryPromptPart(
115
+ tool_name=name,
116
+ content=e.errors(include_url=False, include_context=False),
117
+ tool_call_id=call.tool_call_id,
118
+ )
119
+ e = ToolRetryError(m)
120
+ elif isinstance(e, ModelRetry):
121
+ m = _messages.RetryPromptPart(
122
+ tool_name=name,
123
+ content=e.message,
124
+ tool_call_id=call.tool_call_id,
125
+ )
126
+ e = ToolRetryError(m)
127
+ else:
128
+ assert_never(e)
129
+
130
+ if not allow_partial:
131
+ # If we're validating partial arguments, we don't want to count this as a failed tool as it may still succeed once the full arguments are received.
132
+ self.failed_tools.add(name)
119
133
 
120
- self.ctx.retries[name] = current_retry + 1
121
134
  raise e
122
- else:
123
- self.ctx.retries.pop(name, None)
124
- return output
125
135
 
126
- async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
136
+ async def _call_tool_traced(
137
+ self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
138
+ ) -> Any:
127
139
  """See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
128
140
  span_attributes = {
129
141
  'gen_ai.tool.name': call.tool_name,
@@ -152,7 +164,7 @@ class ToolManager(Generic[AgentDepsT]):
152
164
  }
153
165
  with self.ctx.tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
154
166
  try:
155
- tool_result = await self._call_tool(call, allow_partial)
167
+ tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
156
168
  except ToolRetryError as e:
157
169
  part = e.tool_retry
158
170
  if self.ctx.trace_include_content and span.is_recording():
@@ -9,11 +9,13 @@ from __future__ import annotations
9
9
  import json
10
10
  import uuid
11
11
  from collections.abc import Iterable, Mapping, Sequence
12
- from dataclasses import dataclass, field
12
+ from dataclasses import Field, dataclass, field, replace
13
13
  from http import HTTPStatus
14
14
  from typing import (
15
+ TYPE_CHECKING,
15
16
  Any,
16
17
  Callable,
18
+ ClassVar,
17
19
  Final,
18
20
  Generic,
19
21
  Protocol,
@@ -21,6 +23,11 @@ from typing import (
21
23
  runtime_checkable,
22
24
  )
23
25
 
26
+ from pydantic_ai.exceptions import UserError
27
+
28
+ if TYPE_CHECKING:
29
+ pass
30
+
24
31
  try:
25
32
  from ag_ui.core import (
26
33
  AssistantMessage,
@@ -288,8 +295,24 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
288
295
  if not run_input.messages:
289
296
  raise _NoMessagesError
290
297
 
298
+ raw_state: dict[str, Any] = run_input.state or {}
291
299
  if isinstance(deps, StateHandler):
292
- deps.state = run_input.state
300
+ if isinstance(deps.state, BaseModel):
301
+ try:
302
+ state = type(deps.state).model_validate(raw_state)
303
+ except ValidationError as e: # pragma: no cover
304
+ raise _InvalidStateError from e
305
+ else:
306
+ state = raw_state
307
+
308
+ deps = replace(deps, state=state)
309
+ elif raw_state:
310
+ raise UserError(
311
+ f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'
312
+ )
313
+ else:
314
+ # `deps` not being a `StateHandler` is OK if there is no state.
315
+ pass
293
316
 
294
317
  messages = _messages_from_ag_ui(run_input.messages)
295
318
 
@@ -311,7 +334,7 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
311
334
  yield encoder.encode(
312
335
  RunErrorEvent(message=e.message, code=e.code),
313
336
  )
314
- except Exception as e: # pragma: no cover
337
+ except Exception as e:
315
338
  yield encoder.encode(
316
339
  RunErrorEvent(message=str(e)),
317
340
  )
@@ -486,6 +509,9 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
486
509
  if isinstance(msg, UserMessage):
487
510
  result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
488
511
  elif isinstance(msg, AssistantMessage):
512
+ if msg.content:
513
+ result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
514
+
489
515
  if msg.tool_calls:
490
516
  for tool_call in msg.tool_calls:
491
517
  tool_calls[tool_call.id] = tool_call.function.name
@@ -502,9 +528,6 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
502
528
  ]
503
529
  )
504
530
  )
505
-
506
- if msg.content:
507
- result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
508
531
  elif isinstance(msg, SystemMessage):
509
532
  result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
510
533
  elif isinstance(msg, ToolMessage):
@@ -531,7 +554,11 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
531
554
 
532
555
  @runtime_checkable
533
556
  class StateHandler(Protocol):
534
- """Protocol for state handlers in agent runs."""
557
+ """Protocol for state handlers in agent runs. Requires the class to be a dataclass with a `state` field."""
558
+
559
+ # Has to be a dataclass so we can use `replace` to update the state.
560
+ # From https://github.com/python/typeshed/blob/9ab7fde0a0cd24ed7a72837fcb21093b811b80d8/stdlib/_typeshed/__init__.pyi#L352
561
+ __dataclass_fields__: ClassVar[dict[str, Field[Any]]]
535
562
 
536
563
  @property
537
564
  def state(self) -> State:
@@ -558,6 +585,7 @@ StateT = TypeVar('StateT', bound=BaseModel)
558
585
  """Type variable for the state type, which must be a subclass of `BaseModel`."""
559
586
 
560
587
 
588
+ @dataclass
561
589
  class StateDeps(Generic[StateT]):
562
590
  """Provides AG-UI state management.
563
591
 
@@ -570,42 +598,7 @@ class StateDeps(Generic[StateT]):
570
598
  Implements the `StateHandler` protocol.
571
599
  """
572
600
 
573
- def __init__(self, default: StateT) -> None:
574
- """Initialize the state with the provided state type."""
575
- self._state = default
576
-
577
- @property
578
- def state(self) -> StateT:
579
- """Get the current state of the agent run.
580
-
581
- Returns:
582
- The current run state.
583
- """
584
- return self._state
585
-
586
- @state.setter
587
- def state(self, state: State) -> None:
588
- """Set the state of the agent run.
589
-
590
- This method is called to update the state of the agent run with the
591
- provided state.
592
-
593
- Implements the `StateHandler` protocol.
594
-
595
- Args:
596
- state: The run state, which must be `None` or model validate for the state type.
597
-
598
- Raises:
599
- InvalidStateError: If `state` does not validate.
600
- """
601
- if state is None:
602
- # If state is None, we keep the current state, which will be the default state.
603
- return
604
-
605
- try:
606
- self._state = type(self._state).model_validate(state)
607
- except ValidationError as e: # pragma: no cover
608
- raise _InvalidStateError from e
601
+ state: StateT
609
602
 
610
603
 
611
604
  @dataclass(repr=False)
@@ -1792,9 +1792,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
1792
1792
  """
1793
1793
  async with self._enter_lock:
1794
1794
  if self._entered_count == 0:
1795
- self._exit_stack = AsyncExitStack()
1796
- toolset = self._get_toolset()
1797
- await self._exit_stack.enter_async_context(toolset)
1795
+ async with AsyncExitStack() as exit_stack:
1796
+ toolset = self._get_toolset()
1797
+ await exit_stack.enter_async_context(toolset)
1798
+
1799
+ self._exit_stack = exit_stack.pop_all()
1798
1800
  self._entered_count += 1
1799
1801
  return self
1800
1802
 
@@ -201,25 +201,24 @@ class MCPServer(AbstractToolset[Any], ABC):
201
201
  """
202
202
  async with self._enter_lock:
203
203
  if self._running_count == 0:
204
- self._exit_stack = AsyncExitStack()
205
-
206
- self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(
207
- self.client_streams()
208
- )
209
- client = ClientSession(
210
- read_stream=self._read_stream,
211
- write_stream=self._write_stream,
212
- sampling_callback=self._sampling_callback if self.allow_sampling else None,
213
- logging_callback=self.log_handler,
214
- read_timeout_seconds=timedelta(seconds=self.read_timeout),
215
- )
216
- self._client = await self._exit_stack.enter_async_context(client)
217
-
218
- with anyio.fail_after(self.timeout):
219
- await self._client.initialize()
220
-
221
- if log_level := self.log_level:
222
- await self._client.set_logging_level(log_level)
204
+ async with AsyncExitStack() as exit_stack:
205
+ self._read_stream, self._write_stream = await exit_stack.enter_async_context(self.client_streams())
206
+ client = ClientSession(
207
+ read_stream=self._read_stream,
208
+ write_stream=self._write_stream,
209
+ sampling_callback=self._sampling_callback if self.allow_sampling else None,
210
+ logging_callback=self.log_handler,
211
+ read_timeout_seconds=timedelta(seconds=self.read_timeout),
212
+ )
213
+ self._client = await exit_stack.enter_async_context(client)
214
+
215
+ with anyio.fail_after(self.timeout):
216
+ await self._client.initialize()
217
+
218
+ if log_level := self.log_level:
219
+ await self._client.set_logging_level(log_level)
220
+
221
+ self._exit_stack = exit_stack.pop_all()
223
222
  self._running_count += 1
224
223
  return self
225
224
 
@@ -544,6 +543,7 @@ class _MCPServerHTTP(MCPServer):
544
543
  self.max_retries = max_retries
545
544
  self.sampling_model = sampling_model
546
545
  self.read_timeout = read_timeout
546
+ self.__post_init__()
547
547
 
548
548
  @property
549
549
  @abstractmethod
@@ -815,11 +815,16 @@ class ModelResponse:
815
815
  },
816
816
  }
817
817
  )
818
- elif isinstance(part, TextPart):
819
- if body.get('content'):
820
- body = new_event_body()
821
- if settings.include_content:
822
- body['content'] = part.content
818
+ elif isinstance(part, (TextPart, ThinkingPart)):
819
+ kind = part.part_kind
820
+ body.setdefault('content', []).append(
821
+ {'kind': kind, **({'text': part.content} if settings.include_content else {})}
822
+ )
823
+
824
+ if content := body.get('content'):
825
+ text_content = content[0].get('text')
826
+ if content == [{'kind': 'text', 'text': text_content}]:
827
+ body['content'] = text_content
823
828
 
824
829
  return result
825
830
 
@@ -49,7 +49,6 @@ class GoogleJsonSchemaTransformer(JsonSchemaTransformer):
49
49
  )
50
50
 
51
51
  schema.pop('title', None)
52
- schema.pop('default', None)
53
52
  schema.pop('$schema', None)
54
53
  if (const := schema.pop('const', None)) is not None:
55
54
  # Gemini doesn't support const, but it does support enum with a single value
@@ -67,7 +67,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
67
67
  except ValidationError:
68
68
  pass
69
69
  if self._final_result_event is not None: # pragma: no branch
70
- yield await self._validate_response(self._raw_stream_response.get(), allow_partial=False)
70
+ yield await self._validate_response(self._raw_stream_response.get())
71
71
 
72
72
  async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
73
73
  """Asynchronously stream the (unvalidated) model responses for the agent."""
@@ -128,7 +128,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
128
128
  async for _ in self:
129
129
  pass
130
130
 
131
- return await self._validate_response(self._raw_stream_response.get(), allow_partial=False)
131
+ return await self._validate_response(self._raw_stream_response.get())
132
132
 
133
133
  async def _validate_response(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
134
134
  """Validate a structured result message."""
@@ -150,7 +150,9 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
150
150
  raise exceptions.UnexpectedModelBehavior( # pragma: no cover
151
151
  f'Invalid response, unable to find tool call for {output_tool_name!r}'
152
152
  )
153
- return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
153
+ return await self._tool_manager.handle_call(
154
+ tool_call, allow_partial=allow_partial, wrap_validation_errors=False
155
+ )
154
156
  elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
155
157
  if not self._output_schema.allows_deferred_tool_calls:
156
158
  raise exceptions.UserError(
@@ -0,0 +1,249 @@
1
+ """Retries utilities based on tenacity, especially for HTTP requests.
2
+
3
+ This module provides HTTP transport wrappers and wait strategies that integrate with
4
+ the tenacity library to add retry capabilities to HTTP requests. The transports can be
5
+ used with HTTP clients that support custom transports (such as httpx), while the wait
6
+ strategies can be used with any tenacity retry decorator.
7
+
8
+ The module includes:
9
+ - TenacityTransport: Synchronous HTTP transport with retry capabilities
10
+ - AsyncTenacityTransport: Asynchronous HTTP transport with retry capabilities
11
+ - wait_retry_after: Wait strategy that respects HTTP Retry-After headers
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from httpx import AsyncBaseTransport, AsyncHTTPTransport, BaseTransport, HTTPTransport, Request, Response
17
+
18
+ try:
19
+ from tenacity import AsyncRetrying, Retrying
20
+ except ImportError as _import_error:
21
+ raise ImportError(
22
+ 'Please install `tenacity` to use the retries utilities, '
23
+ 'you can use the `retries` optional group — `pip install "pydantic-ai-slim[retries]"`'
24
+ ) from _import_error
25
+
26
+
27
+ __all__ = ['TenacityTransport', 'AsyncTenacityTransport', 'wait_retry_after']
28
+
29
+ from datetime import datetime, timezone
30
+ from email.utils import parsedate_to_datetime
31
+ from typing import Callable, cast
32
+
33
+ from httpx import HTTPStatusError
34
+ from tenacity import RetryCallState, wait_exponential
35
+
36
+
37
class TenacityTransport(BaseTransport):
    """Synchronous HTTP transport with tenacity-based retry functionality.

    This transport wraps another BaseTransport and adds retry capabilities using the tenacity library.
    It can be configured to retry requests based on various conditions such as specific exception types,
    response status codes, or custom validation logic.

    The transport works by intercepting HTTP requests and responses, allowing the tenacity controller
    to determine when and how to retry failed requests. The validate_response function can be used
    to convert HTTP responses into exceptions that trigger retries.

    Args:
        controller: The tenacity Retrying instance that defines the retry behavior
            (retry conditions, wait strategy, stop conditions, etc.).
        wrapped: The underlying transport to wrap and add retry functionality to.
            A new `HTTPTransport()` is created when this is omitted.
        validate_response: Optional callable that takes a Response and can raise an exception
            to be handled by the controller if the response should trigger a retry.
            Common use case is to raise exceptions for certain HTTP status codes.
            If None, no response validation is performed.

    Example:
        ```python
        from httpx import Client, HTTPStatusError, HTTPTransport
        from tenacity import Retrying, retry_if_exception_type, stop_after_attempt
        from pydantic_ai.retries import TenacityTransport, wait_retry_after

        transport = TenacityTransport(
            Retrying(
                retry=retry_if_exception_type(HTTPStatusError),
                wait=wait_retry_after(max_wait=300),
                stop=stop_after_attempt(5),
                reraise=True,
            ),
            HTTPTransport(),
            validate_response=lambda r: r.raise_for_status(),
        )
        client = Client(transport=transport)
        ```
    """

    def __init__(
        self,
        controller: Retrying,
        wrapped: BaseTransport | None = None,
        validate_response: Callable[[Response], None] | None = None,
    ):
        self.controller = controller
        self.wrapped = wrapped or HTTPTransport()
        self.validate_response = validate_response

    def handle_request(self, request: Request) -> Response:
        """Handle an HTTP request with retry logic.

        Args:
            request: The HTTP request to handle.

        Returns:
            The HTTP response.

        Raises:
            RuntimeError: If the retry controller did not make any attempts.
            Exception: Any exception raised by the wrapped transport or validation function.
        """
        for attempt in self.controller:
            with attempt:
                response = self.wrapped.handle_request(request)

                # httpx attaches the request to the response only after the transport returns,
                # but validators such as `Response.raise_for_status()` need it to be set already.
                response.request = request

                if self.validate_response:
                    try:
                        self.validate_response(response)
                    except Exception:
                        # The rejected response is discarded before any retry, so release
                        # its connection instead of leaking it.
                        response.close()
                        raise
                return response
        raise RuntimeError('The retry controller did not make any attempts')  # pragma: no cover

    def close(self) -> None:
        """Close the wrapped transport so its connections are released with this transport."""
        self.wrapped.close()
107
+
108
+
109
class AsyncTenacityTransport(AsyncBaseTransport):
    """Asynchronous HTTP transport with tenacity-based retry functionality.

    This transport wraps another AsyncBaseTransport and adds retry capabilities using the tenacity library.
    It can be configured to retry requests based on various conditions such as specific exception types,
    response status codes, or custom validation logic.

    The transport works by intercepting HTTP requests and responses, allowing the tenacity controller
    to determine when and how to retry failed requests. The validate_response function can be used
    to convert HTTP responses into exceptions that trigger retries.

    Args:
        controller: The tenacity AsyncRetrying instance that defines the retry behavior
            (retry conditions, wait strategy, stop conditions, etc.).
        wrapped: The underlying async transport to wrap and add retry functionality to.
            A new `AsyncHTTPTransport()` is created when this is omitted.
        validate_response: Optional callable that takes a Response and can raise an exception
            to be handled by the controller if the response should trigger a retry.
            Common use case is to raise exceptions for certain HTTP status codes.
            If None, no response validation is performed.

    Example:
        ```python
        from httpx import AsyncClient, HTTPStatusError
        from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_attempt
        from pydantic_ai.retries import AsyncTenacityTransport, wait_retry_after

        transport = AsyncTenacityTransport(
            AsyncRetrying(
                retry=retry_if_exception_type(HTTPStatusError),
                wait=wait_retry_after(max_wait=300),
                stop=stop_after_attempt(5),
                reraise=True,
            ),
            validate_response=lambda r: r.raise_for_status(),
        )
        client = AsyncClient(transport=transport)
        ```
    """

    def __init__(
        self,
        controller: AsyncRetrying,
        wrapped: AsyncBaseTransport | None = None,
        validate_response: Callable[[Response], None] | None = None,
    ):
        self.controller = controller
        self.wrapped = wrapped or AsyncHTTPTransport()
        self.validate_response = validate_response

    async def handle_async_request(self, request: Request) -> Response:
        """Handle an async HTTP request with retry logic.

        Args:
            request: The HTTP request to handle.

        Returns:
            The HTTP response.

        Raises:
            RuntimeError: If the retry controller did not make any attempts.
            Exception: Any exception raised by the wrapped transport or validation function.
        """
        # Iterate a per-request copy of the controller: AsyncRetrying keeps the state of an
        # `async for` loop on the instance itself, so concurrent requests sharing one instance
        # would overwrite each other's attempt counters and outcomes. Retry statistics are
        # consequently recorded on the copy rather than on `self.controller`.
        async for attempt in self.controller.copy():
            with attempt:
                response = await self.wrapped.handle_async_request(request)

                # httpx attaches the request to the response only after the transport returns,
                # but validators such as `Response.raise_for_status()` need it to be set already.
                response.request = request

                if self.validate_response:
                    try:
                        self.validate_response(response)
                    except Exception:
                        # The rejected response is discarded before any retry, so release
                        # its connection instead of leaking it.
                        await response.aclose()
                        raise
                return response
        raise RuntimeError('The retry controller did not make any attempts')  # pragma: no cover

    async def aclose(self) -> None:
        """Close the wrapped transport so its connections are released with this transport."""
        await self.wrapped.aclose()
178
+
179
+
180
def wait_retry_after(
    fallback_strategy: Callable[[RetryCallState], float] | None = None, max_wait: float = 300
) -> Callable[[RetryCallState], float]:
    """Create a tenacity-compatible wait strategy that respects HTTP Retry-After headers.

    This wait strategy checks if the exception contains an HTTPStatusError with a
    Retry-After header, and if so, waits for the time specified in the header.
    If no header is present or parsing fails, it falls back to the provided strategy.

    The Retry-After header can be in two formats:
    - A non-negative integer representing seconds to wait
    - An HTTP date string representing when to retry

    A negative number of seconds, or a date that is not in the future, is treated like an
    unusable header and the fallback strategy is used instead. A date that parses without
    timezone information is interpreted as UTC.

    Args:
        fallback_strategy: Wait strategy to use when no Retry-After header is present
            or parsing fails. Defaults to exponential backoff with max 60s.
        max_wait: Maximum time to wait in seconds, regardless of header value.
            Defaults to 300 (5 minutes).

    Returns:
        A wait function that can be used with tenacity retry decorators.

    Example:
        ```python
        from httpx import AsyncClient, HTTPStatusError
        from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_attempt
        from pydantic_ai.retries import AsyncTenacityTransport, wait_retry_after

        transport = AsyncTenacityTransport(
            AsyncRetrying(
                retry=retry_if_exception_type(HTTPStatusError),
                wait=wait_retry_after(max_wait=120),
                stop=stop_after_attempt(5),
                reraise=True,
            ),
            validate_response=lambda r: r.raise_for_status(),
        )
        client = AsyncClient(transport=transport)
        ```
    """
    if fallback_strategy is None:
        fallback_strategy = wait_exponential(multiplier=1, max=60)

    def parse_retry_after(value: str) -> float | None:
        """Return the delay in seconds requested by a Retry-After value, or None if unusable."""
        # Try parsing as seconds first
        try:
            seconds = float(int(value))
        except ValueError:
            pass
        else:
            # A negative delay is not a valid Retry-After value; passing it on as a wait
            # time would make a blocking sleep fail, so treat it as unparseable.
            return seconds if seconds >= 0 else None

        # Try parsing as HTTP date
        try:
            retry_time = parsedate_to_datetime(value)
        except (ValueError, TypeError):
            # Depending on the Python version, unparseable dates raise either exception.
            return None
        if retry_time.tzinfo is None:
            # The parser yields a naive datetime when the header carries no usable zone;
            # subtracting it from an aware "now" would fail, so read it as UTC.
            retry_time = retry_time.replace(tzinfo=timezone.utc)
        delay = (retry_time - datetime.now(timezone.utc)).total_seconds()
        return delay if delay > 0 else None

    def wait_func(state: RetryCallState) -> float:
        exc = state.outcome.exception() if state.outcome else None
        if isinstance(exc, HTTPStatusError):
            retry_after = exc.response.headers.get('retry-after')
            if retry_after:
                delay = parse_retry_after(retry_after)
                if delay is not None:
                    return min(delay, max_wait)

        # Use fallback strategy
        return fallback_strategy(state)

    return wait_func
@@ -43,9 +43,10 @@ class CombinedToolset(AbstractToolset[AgentDepsT]):
43
43
  async def __aenter__(self) -> Self:
44
44
  async with self._enter_lock:
45
45
  if self._entered_count == 0:
46
- self._exit_stack = AsyncExitStack()
47
- for toolset in self.toolsets:
48
- await self._exit_stack.enter_async_context(toolset)
46
+ async with AsyncExitStack() as exit_stack:
47
+ for toolset in self.toolsets:
48
+ await exit_stack.enter_async_context(toolset)
49
+ self._exit_stack = exit_stack.pop_all()
49
50
  self._entered_count += 1
50
51
  return self
51
52
 
@@ -84,6 +84,8 @@ evals = ["pydantic-evals=={{ version }}"]
84
84
  a2a = ["fasta2a>=0.4.1"]
85
85
  # AG-UI
86
86
  ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"]
87
+ # Retries
88
+ retries = ["tenacity>=8.2.3"]
87
89
 
88
90
  [dependency-groups]
89
91
  dev = [