pydantic-ai-slim 0.4.8__tar.gz → 0.4.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

Files changed (101):
  1. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/PKG-INFO +3 -4
  2. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_agent_graph.py +21 -19
  3. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_tool_manager.py +43 -31
  4. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/ag_ui.py +33 -40
  5. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/agent.py +81 -80
  6. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/messages.py +2 -2
  7. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/openai.py +3 -2
  8. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/google.py +0 -1
  9. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/vercel.py +8 -2
  10. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/result.py +5 -3
  11. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pyproject.toml +1 -6
  12. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/.gitignore +0 -0
  13. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/LICENSE +0 -0
  14. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/README.md +0 -0
  15. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/__init__.py +0 -0
  16. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/__main__.py +0 -0
  17. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_a2a.py +0 -0
  18. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_cli.py +0 -0
  19. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_function_schema.py +0 -0
  20. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_griffe.py +0 -0
  21. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_mcp.py +0 -0
  22. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_output.py +0 -0
  23. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_parts_manager.py +0 -0
  24. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_run_context.py +0 -0
  25. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_system_prompt.py +0 -0
  26. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_thinking_part.py +0 -0
  27. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/_utils.py +0 -0
  28. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/common_tools/__init__.py +0 -0
  29. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  30. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/common_tools/tavily.py +0 -0
  31. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/direct.py +0 -0
  32. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/exceptions.py +0 -0
  33. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/ext/__init__.py +0 -0
  34. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/ext/aci.py +0 -0
  35. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/ext/langchain.py +0 -0
  36. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/format_as_xml.py +0 -0
  37. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/format_prompt.py +0 -0
  38. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/mcp.py +0 -0
  39. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/__init__.py +0 -0
  40. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/anthropic.py +0 -0
  41. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/bedrock.py +0 -0
  42. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/cohere.py +0 -0
  43. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/fallback.py +0 -0
  44. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/function.py +0 -0
  45. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/gemini.py +0 -0
  46. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/google.py +0 -0
  47. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/groq.py +0 -0
  48. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/huggingface.py +0 -0
  49. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/instrumented.py +0 -0
  50. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/mcp_sampling.py +0 -0
  51. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/mistral.py +0 -0
  52. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/test.py +0 -0
  53. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/models/wrapper.py +0 -0
  54. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/output.py +0 -0
  55. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/__init__.py +0 -0
  56. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/_json_schema.py +0 -0
  57. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/amazon.py +0 -0
  58. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/anthropic.py +0 -0
  59. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/cohere.py +0 -0
  60. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/deepseek.py +0 -0
  61. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/grok.py +0 -0
  62. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/meta.py +0 -0
  63. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/mistral.py +0 -0
  64. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/moonshotai.py +0 -0
  65. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/openai.py +0 -0
  66. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/profiles/qwen.py +0 -0
  67. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/__init__.py +0 -0
  68. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/anthropic.py +0 -0
  69. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/azure.py +0 -0
  70. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/bedrock.py +0 -0
  71. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/cohere.py +0 -0
  72. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/deepseek.py +0 -0
  73. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/fireworks.py +0 -0
  74. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/github.py +0 -0
  75. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/google.py +0 -0
  76. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/google_gla.py +0 -0
  77. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/google_vertex.py +0 -0
  78. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/grok.py +0 -0
  79. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/groq.py +0 -0
  80. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/heroku.py +0 -0
  81. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/huggingface.py +0 -0
  82. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/mistral.py +0 -0
  83. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/moonshotai.py +0 -0
  84. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/openai.py +0 -0
  85. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/openrouter.py +0 -0
  86. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/providers/together.py +0 -0
  87. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/py.typed +0 -0
  88. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/retries.py +0 -0
  89. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/settings.py +0 -0
  90. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/tools.py +0 -0
  91. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/toolsets/__init__.py +0 -0
  92. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/toolsets/abstract.py +0 -0
  93. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/toolsets/combined.py +0 -0
  94. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/toolsets/deferred.py +0 -0
  95. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/toolsets/filtered.py +0 -0
  96. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/toolsets/function.py +0 -0
  97. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/toolsets/prefixed.py +0 -0
  98. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/toolsets/prepared.py +0 -0
  99. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/toolsets/renamed.py +0 -0
  100. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/toolsets/wrapper.py +0 -0
  101. {pydantic_ai_slim-0.4.8 → pydantic_ai_slim-0.4.10}/pydantic_ai/usage.py +0 -0
--- pydantic_ai_slim-0.4.8/PKG-INFO
+++ pydantic_ai_slim-0.4.10/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.4.8
+Version: 0.4.10
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>, Douwe Maan <douwe@pydantic.dev>
 License-Expression: MIT
@@ -30,7 +30,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.4.8
+Requires-Dist: pydantic-graph==0.4.10
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: a2a
@@ -48,11 +48,10 @@ Requires-Dist: prompt-toolkit>=3; extra == 'cli'
 Requires-Dist: rich>=13; extra == 'cli'
 Provides-Extra: cohere
 Requires-Dist: cohere>=5.16.0; (platform_system != 'Emscripten') and extra == 'cohere'
-Requires-Dist: tokenizers<=0.21.2; extra == 'cohere'
 Provides-Extra: duckduckgo
 Requires-Dist: ddgs>=9.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.4.8; extra == 'evals'
+Requires-Dist: pydantic-evals==0.4.10; extra == 'evals'
 Provides-Extra: google
 Requires-Dist: google-genai>=1.24.0; extra == 'google'
 Provides-Extra: groq
--- pydantic_ai_slim-0.4.8/pydantic_ai/_agent_graph.py
+++ pydantic_ai_slim-0.4.10/pydantic_ai/_agent_graph.py
@@ -659,11 +659,11 @@ async def process_function_tools(  # noqa: C901
     for call in calls_to_run:
         yield _messages.FunctionToolCallEvent(call)
 
-    user_parts: list[_messages.UserPromptPart] = []
+    user_parts_by_index: dict[int, list[_messages.UserPromptPart]] = defaultdict(list)
 
     if calls_to_run:
         # Run all tool tasks in parallel
-        parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {}
+        tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
         with ctx.deps.tracer.start_as_current_span(
             'running tools',
             attributes={
@@ -681,15 +681,16 @@ async def process_function_tools(  # noqa: C901
                 done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
                 for task in done:
                     index = tasks.index(task)
-                    tool_result_part, extra_parts = task.result()
-                    yield _messages.FunctionToolResultEvent(tool_result_part)
+                    tool_part, tool_user_parts = task.result()
+                    yield _messages.FunctionToolResultEvent(tool_part)
 
-                    parts_by_index[index] = [tool_result_part, *extra_parts]
+                    tool_parts_by_index[index] = tool_part
+                    user_parts_by_index[index] = tool_user_parts
 
         # We append the results at the end, rather than as they are received, to retain a consistent ordering
        # This is mostly just to simplify testing
-        for k in sorted(parts_by_index):
-            output_parts.extend(parts_by_index[k])
+        for k in sorted(tool_parts_by_index):
+            output_parts.append(tool_parts_by_index[k])
 
     # Finally, we handle deferred tool calls
     for call in tool_calls_by_kind['deferred']:
@@ -704,7 +705,8 @@ async def process_function_tools(  # noqa: C901
         else:
             yield _messages.FunctionToolCallEvent(call)
 
-    output_parts.extend(user_parts)
+    for k in sorted(user_parts_by_index):
+        output_parts.extend(user_parts_by_index[k])
 
     if final_result:
         output_final_result.append(final_result)
@@ -713,18 +715,18 @@
 async def _call_function_tool(
     tool_manager: ToolManager[DepsT],
     tool_call: _messages.ToolCallPart,
-) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]:
+) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.UserPromptPart]]:
     try:
         tool_result = await tool_manager.handle_call(tool_call)
     except ToolRetryError as e:
         return (e.tool_retry, [])
 
-    part = _messages.ToolReturnPart(
+    tool_part = _messages.ToolReturnPart(
         tool_name=tool_call.tool_name,
         content=tool_result,
         tool_call_id=tool_call.tool_call_id,
     )
-    extra_parts: list[_messages.ModelRequestPart] = []
+    user_parts: list[_messages.UserPromptPart] = []
 
     if isinstance(tool_result, _messages.ToolReturn):
         if (
@@ -740,12 +742,12 @@ async def _call_function_tool(
                 f'Please use `content` instead.'
             )
 
-        part.content = tool_result.return_value  # type: ignore
-        part.metadata = tool_result.metadata
+        tool_part.content = tool_result.return_value  # type: ignore
+        tool_part.metadata = tool_result.metadata
         if tool_result.content:
-            extra_parts.append(
+            user_parts.append(
                 _messages.UserPromptPart(
-                    content=list(tool_result.content),
+                    content=tool_result.content,
                     part_kind='user-prompt',
                 )
             )
@@ -763,7 +765,7 @@ async def _call_function_tool(
             else:
                 identifier = multi_modal_content_identifier(content.url)
 
-            extra_parts.append(
+            user_parts.append(
                 _messages.UserPromptPart(
                     content=[f'This is file {identifier}:', content],
                     part_kind='user-prompt',
@@ -775,11 +777,11 @@ async def _call_function_tool(
 
         if isinstance(tool_result, list):
             contents = cast(list[Any], tool_result)
-            part.content = [process_content(content) for content in contents]
+            tool_part.content = [process_content(content) for content in contents]
         else:
-            part.content = process_content(tool_result)
+            tool_part.content = process_content(tool_result)
 
-    return (part, extra_parts)
+    return (tool_part, user_parts)
 
 
 @dataclasses.dataclass
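
Note (not part of the diff): the `_agent_graph.py` change splits tool outputs into two index-keyed maps, so all `ToolReturnPart`s land in the request first, in call order, followed by any `UserPromptPart`s the same calls produced. A minimal, self-contained sketch of that collect-then-sort pattern, with illustrative names:

import asyncio
from collections import defaultdict


async def run_tools(calls: list[str]) -> list[str]:
    """Run 'tools' concurrently but emit results deterministically: returns first, then user parts."""

    async def call_tool(name: str) -> tuple[str, list[str]]:
        await asyncio.sleep(0.01)  # stand-in for real tool work
        return f'return:{name}', [f'user-part:{name}']

    tool_parts_by_index: dict[int, str] = {}
    user_parts_by_index: dict[int, list[str]] = defaultdict(list)

    tasks = [asyncio.ensure_future(call_tool(name)) for name in calls]
    pending = set(tasks)
    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            # Completion order is arbitrary, so key results by the original call index.
            index = tasks.index(task)
            tool_part, user_parts = task.result()
            tool_parts_by_index[index] = tool_part
            user_parts_by_index[index] = user_parts

    output: list[str] = []
    for k in sorted(tool_parts_by_index):  # all tool returns first, in call order
        output.append(tool_parts_by_index[k])
    for k in sorted(user_parts_by_index):  # then the user-prompt parts, also in call order
        output.extend(user_parts_by_index[k])
    return output


print(asyncio.run(run_tools(['a', 'b', 'c'])))
# ['return:a', 'return:b', 'return:c', 'user-part:a', 'user-part:b', 'user-part:c']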
--- pydantic_ai_slim-0.4.8/pydantic_ai/_tool_manager.py
+++ pydantic_ai_slim-0.4.10/pydantic_ai/_tool_manager.py
@@ -2,18 +2,17 @@ from __future__ import annotations
 
 import json
 from collections.abc import Iterable
-from dataclasses import dataclass, replace
+from dataclasses import dataclass, field, replace
 from typing import Any, Generic
 
 from pydantic import ValidationError
 from typing_extensions import assert_never
 
-from pydantic_ai.output import DeferredToolCalls
-
 from . import messages as _messages
 from ._run_context import AgentDepsT, RunContext
 from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior
 from .messages import ToolCallPart
+from .output import DeferredToolCalls
 from .tools import ToolDefinition
 from .toolsets.abstract import AbstractToolset, ToolsetTool
 
@@ -28,6 +27,8 @@ class ToolManager(Generic[AgentDepsT]):
     """The toolset that provides the tools for this run step."""
     tools: dict[str, ToolsetTool[AgentDepsT]]
     """The cached tools for this run step."""
+    failed_tools: set[str] = field(default_factory=set)
+    """Names of tools that failed in this run step."""
 
     @classmethod
     async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
@@ -40,7 +41,10 @@ class ToolManager(Generic[AgentDepsT]):
 
     async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
         """Build a new tool manager for the next run step, carrying over the retries from the current run step."""
-        return await self.__class__.build(self.toolset, replace(ctx, retries=self.ctx.retries))
+        retries = {
+            failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1 for failed_tool_name in self.failed_tools
+        }
+        return await self.__class__.build(self.toolset, replace(ctx, retries=retries))
 
     @property
     def tool_defs(self) -> list[ToolDefinition]:
@@ -54,20 +58,25 @@ class ToolManager(Generic[AgentDepsT]):
         except KeyError:
             return None
 
-    async def handle_call(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
+    async def handle_call(
+        self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
+    ) -> Any:
         """Handle a tool call by validating the arguments, calling the tool, and handling retries.
 
         Args:
             call: The tool call part to handle.
             allow_partial: Whether to allow partial validation of the tool arguments.
+            wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
         """
         if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
             # Output tool calls are not traced
-            return await self._call_tool(call, allow_partial)
+            return await self._call_tool(call, allow_partial, wrap_validation_errors)
         else:
-            return await self._call_tool_traced(call, allow_partial)
+            return await self._call_tool_traced(call, allow_partial, wrap_validation_errors)
 
-    async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
+    async def _call_tool(
+        self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
+    ) -> Any:
         name = call.tool_name
         tool = self.tools.get(name)
         try:
@@ -92,7 +101,7 @@ class ToolManager(Generic[AgentDepsT]):
             else:
                 args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
 
-            output = await self.toolset.call_tool(name, args_dict, ctx, tool)
+            return await self.toolset.call_tool(name, args_dict, ctx, tool)
         except (ValidationError, ModelRetry) as e:
             max_retries = tool.max_retries if tool is not None else 1
             current_retry = self.ctx.retries.get(name, 0)
@@ -100,30 +109,33 @@ class ToolManager(Generic[AgentDepsT]):
             if current_retry == max_retries:
                 raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
             else:
-                if isinstance(e, ValidationError):
-                    m = _messages.RetryPromptPart(
-                        tool_name=name,
-                        content=e.errors(include_url=False, include_context=False),
-                        tool_call_id=call.tool_call_id,
-                    )
-                    e = ToolRetryError(m)
-                elif isinstance(e, ModelRetry):
-                    m = _messages.RetryPromptPart(
-                        tool_name=name,
-                        content=e.message,
-                        tool_call_id=call.tool_call_id,
-                    )
-                    e = ToolRetryError(m)
-                else:
-                    assert_never(e)
+                if wrap_validation_errors:
+                    if isinstance(e, ValidationError):
+                        m = _messages.RetryPromptPart(
+                            tool_name=name,
+                            content=e.errors(include_url=False, include_context=False),
+                            tool_call_id=call.tool_call_id,
+                        )
+                        e = ToolRetryError(m)
+                    elif isinstance(e, ModelRetry):
+                        m = _messages.RetryPromptPart(
+                            tool_name=name,
+                            content=e.message,
+                            tool_call_id=call.tool_call_id,
+                        )
+                        e = ToolRetryError(m)
+                    else:
+                        assert_never(e)
+
+                if not allow_partial:
+                    # If we're validating partial arguments, we don't want to count this as a failed tool as it may still succeed once the full arguments are received.
+                    self.failed_tools.add(name)
 
-                self.ctx.retries[name] = current_retry + 1
                 raise e
-        else:
-            self.ctx.retries.pop(name, None)
-            return output
 
-    async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
+    async def _call_tool_traced(
+        self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
+    ) -> Any:
         """See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
         span_attributes = {
             'gen_ai.tool.name': call.tool_name,
@@ -152,7 +164,7 @@ class ToolManager(Generic[AgentDepsT]):
         }
         with self.ctx.tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
             try:
-                tool_result = await self._call_tool(call, allow_partial)
+                tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
             except ToolRetryError as e:
                 part = e.tool_retry
                 if self.ctx.trace_include_content and span.is_recording():
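
Note (not part of the diff): `_tool_manager.py` stops mutating the shared `ctx.retries` during a step; failures are recorded in `failed_tools`, and `for_run_step` derives the next step's retry counts from them, so a tool that succeeds simply drops out of the map. A simplified sketch of the derivation; `RetryTracker` is a made-up stand-in, not a library class:

from dataclasses import dataclass, field


@dataclass
class RetryTracker:
    """Per-step retry bookkeeping: record failures, derive the next step's counts."""

    retries: dict[str, int] = field(default_factory=dict)
    failed_tools: set[str] = field(default_factory=set)

    def record_failure(self, tool_name: str) -> None:
        self.failed_tools.add(tool_name)

    def for_next_step(self) -> 'RetryTracker':
        # Only tools that failed this step carry an incremented count into the
        # next step; tools that succeeded reset implicitly by being omitted.
        next_retries = {name: self.retries.get(name, 0) + 1 for name in self.failed_tools}
        return RetryTracker(retries=next_retries)


step1 = RetryTracker()
step1.record_failure('search')
step2 = step1.for_next_step()
assert step2.retries == {'search': 1}
step3 = step2.for_next_step()  # 'search' succeeded this step: no failure recorded
assert step3.retries == {}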
--- pydantic_ai_slim-0.4.8/pydantic_ai/ag_ui.py
+++ pydantic_ai_slim-0.4.10/pydantic_ai/ag_ui.py
@@ -9,11 +9,13 @@ from __future__ import annotations
 import json
 import uuid
 from collections.abc import Iterable, Mapping, Sequence
-from dataclasses import dataclass, field
+from dataclasses import Field, dataclass, field, replace
 from http import HTTPStatus
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
+    ClassVar,
     Final,
     Generic,
     Protocol,
@@ -21,6 +23,11 @@ from typing import (
     runtime_checkable,
 )
 
+from pydantic_ai.exceptions import UserError
+
+if TYPE_CHECKING:
+    pass
+
 try:
     from ag_ui.core import (
         AssistantMessage,
@@ -288,8 +295,24 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
         if not run_input.messages:
             raise _NoMessagesError
 
+        raw_state: dict[str, Any] = run_input.state or {}
         if isinstance(deps, StateHandler):
-            deps.state = run_input.state
+            if isinstance(deps.state, BaseModel):
+                try:
+                    state = type(deps.state).model_validate(raw_state)
+                except ValidationError as e:  # pragma: no cover
+                    raise _InvalidStateError from e
+            else:
+                state = raw_state
+
+            deps = replace(deps, state=state)
+        elif raw_state:
+            raise UserError(
+                f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'
+            )
+        else:
+            # `deps` not being a `StateHandler` is OK if there is no state.
+            pass
 
         messages = _messages_from_ag_ui(run_input.messages)
 
@@ -311,7 +334,7 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
             yield encoder.encode(
                 RunErrorEvent(message=e.message, code=e.code),
             )
-        except Exception as e:  # pragma: no cover
+        except Exception as e:
             yield encoder.encode(
                 RunErrorEvent(message=str(e)),
             )
@@ -531,7 +554,11 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
 
 @runtime_checkable
 class StateHandler(Protocol):
-    """Protocol for state handlers in agent runs."""
+    """Protocol for state handlers in agent runs. Requires the class to be a dataclass with a `state` field."""
+
+    # Has to be a dataclass so we can use `replace` to update the state.
+    # From https://github.com/python/typeshed/blob/9ab7fde0a0cd24ed7a72837fcb21093b811b80d8/stdlib/_typeshed/__init__.pyi#L352
+    __dataclass_fields__: ClassVar[dict[str, Field[Any]]]
 
     @property
     def state(self) -> State:
@@ -558,6 +585,7 @@ StateT = TypeVar('StateT', bound=BaseModel)
 """Type variable for the state type, which must be a subclass of `BaseModel`."""
 
 
+@dataclass
 class StateDeps(Generic[StateT]):
     """Provides AG-UI state management.
 
@@ -570,42 +598,7 @@ class StateDeps(Generic[StateT]):
     Implements the `StateHandler` protocol.
     """
 
-    def __init__(self, default: StateT) -> None:
-        """Initialize the state with the provided state type."""
-        self._state = default
-
-    @property
-    def state(self) -> StateT:
-        """Get the current state of the agent run.
-
-        Returns:
-            The current run state.
-        """
-        return self._state
-
-    @state.setter
-    def state(self, state: State) -> None:
-        """Set the state of the agent run.
-
-        This method is called to update the state of the agent run with the
-        provided state.
-
-        Implements the `StateHandler` protocol.
-
-        Args:
-            state: The run state, which must be `None` or model validate for the state type.
-
-        Raises:
-            InvalidStateError: If `state` does not validate.
-        """
-        if state is None:
-            # If state is None, we keep the current state, which will be the default state.
-            return
-
-        try:
-            self._state = type(self._state).model_validate(state)
-        except ValidationError as e:  # pragma: no cover
-            raise _InvalidStateError from e
+    state: StateT
 
 
 @dataclass(repr=False)
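
Note (not part of the diff): with `StateDeps` now a plain dataclass exposing a required `state` field, and `StateHandler` matched structurally via `__dataclass_fields__`, the adapter swaps state in with `dataclasses.replace` instead of calling a setter. A usage sketch, assuming the `ag_ui` extra is installed; `MyState` is a hypothetical user model:

from dataclasses import replace

from pydantic import BaseModel

from pydantic_ai.ag_ui import StateDeps


class MyState(BaseModel):
    """Hypothetical run state shared with the AG-UI frontend."""

    counter: int = 0


deps = StateDeps(MyState())  # the generated dataclass __init__ replaces the old `default=...` constructor
updated = replace(deps, state=MyState(counter=1))  # works because StateDeps is now a dataclass
print(updated.state.counter)  # 1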
--- pydantic_ai_slim-0.4.8/pydantic_ai/agent.py
+++ pydantic_ai_slim-0.4.10/pydantic_ai/agent.py
@@ -774,90 +774,91 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
 
         toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
         # This will raise errors for any name conflicts
-        run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
-
-        # Merge model settings in order of precedence: run > agent > model
-        merged_settings = merge_model_settings(model_used.settings, self.model_settings)
-        model_settings = merge_model_settings(merged_settings, model_settings)
-        usage_limits = usage_limits or _usage.UsageLimits()
-        agent_name = self.name or 'agent'
-        run_span = tracer.start_span(
-            'agent run',
-            attributes={
-                'model_name': model_used.model_name if model_used else 'no-model',
-                'agent_name': agent_name,
-                'logfire.msg': f'{agent_name} run',
-            },
-        )
-
-        async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
-            parts = [
-                self._instructions,
-                *[await func.run(run_context) for func in self._instructions_functions],
-            ]
-
-            model_profile = model_used.profile
-            if isinstance(output_schema, _output.PromptedOutputSchema):
-                instructions = output_schema.instructions(model_profile.prompted_output_template)
-                parts.append(instructions)
+        async with toolset:
+            run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
+
+            # Merge model settings in order of precedence: run > agent > model
+            merged_settings = merge_model_settings(model_used.settings, self.model_settings)
+            model_settings = merge_model_settings(merged_settings, model_settings)
+            usage_limits = usage_limits or _usage.UsageLimits()
+            agent_name = self.name or 'agent'
+            run_span = tracer.start_span(
+                'agent run',
+                attributes={
+                    'model_name': model_used.model_name if model_used else 'no-model',
+                    'agent_name': agent_name,
+                    'logfire.msg': f'{agent_name} run',
+                },
+            )
 
-            parts = [p for p in parts if p]
-            if not parts:
-                return None
-            return '\n\n'.join(parts).strip()
+            async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
+                parts = [
+                    self._instructions,
+                    *[await func.run(run_context) for func in self._instructions_functions],
+                ]
 
-        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
-            user_deps=deps,
-            prompt=user_prompt,
-            new_message_index=new_message_index,
-            model=model_used,
-            model_settings=model_settings,
-            usage_limits=usage_limits,
-            max_result_retries=self._max_result_retries,
-            end_strategy=self.end_strategy,
-            output_schema=output_schema,
-            output_validators=output_validators,
-            history_processors=self.history_processors,
-            tool_manager=run_toolset,
-            tracer=tracer,
-            get_instructions=get_instructions,
-            instrumentation_settings=instrumentation_settings,
-        )
-        start_node = _agent_graph.UserPromptNode[AgentDepsT](
-            user_prompt=user_prompt,
-            instructions=self._instructions,
-            instructions_functions=self._instructions_functions,
-            system_prompts=self._system_prompts,
-            system_prompt_functions=self._system_prompt_functions,
-            system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
-        )
+                model_profile = model_used.profile
+                if isinstance(output_schema, _output.PromptedOutputSchema):
+                    instructions = output_schema.instructions(model_profile.prompted_output_template)
+                    parts.append(instructions)
+
+                parts = [p for p in parts if p]
+                if not parts:
+                    return None
+                return '\n\n'.join(parts).strip()
+
+            graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
+                user_deps=deps,
+                prompt=user_prompt,
+                new_message_index=new_message_index,
+                model=model_used,
+                model_settings=model_settings,
+                usage_limits=usage_limits,
+                max_result_retries=self._max_result_retries,
+                end_strategy=self.end_strategy,
+                output_schema=output_schema,
+                output_validators=output_validators,
+                history_processors=self.history_processors,
+                tool_manager=run_toolset,
+                tracer=tracer,
+                get_instructions=get_instructions,
+                instrumentation_settings=instrumentation_settings,
+            )
+            start_node = _agent_graph.UserPromptNode[AgentDepsT](
+                user_prompt=user_prompt,
+                instructions=self._instructions,
+                instructions_functions=self._instructions_functions,
+                system_prompts=self._system_prompts,
+                system_prompt_functions=self._system_prompt_functions,
+                system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
+            )
 
-        try:
-            async with graph.iter(
-                start_node,
-                state=state,
-                deps=graph_deps,
-                span=use_span(run_span) if run_span.is_recording() else None,
-                infer_name=False,
-            ) as graph_run:
-                agent_run = AgentRun(graph_run)
-                yield agent_run
-                if (final_result := agent_run.result) is not None and run_span.is_recording():
-                    if instrumentation_settings and instrumentation_settings.include_content:
-                        run_span.set_attribute(
-                            'final_result',
-                            (
-                                final_result.output
-                                if isinstance(final_result.output, str)
-                                else json.dumps(InstrumentedModel.serialize_any(final_result.output))
-                            ),
-                        )
-        finally:
             try:
-                if instrumentation_settings and run_span.is_recording():
-                    run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
+                async with graph.iter(
+                    start_node,
+                    state=state,
+                    deps=graph_deps,
+                    span=use_span(run_span) if run_span.is_recording() else None,
+                    infer_name=False,
+                ) as graph_run:
+                    agent_run = AgentRun(graph_run)
+                    yield agent_run
+                    if (final_result := agent_run.result) is not None and run_span.is_recording():
+                        if instrumentation_settings and instrumentation_settings.include_content:
+                            run_span.set_attribute(
+                                'final_result',
+                                (
+                                    final_result.output
+                                    if isinstance(final_result.output, str)
+                                    else json.dumps(InstrumentedModel.serialize_any(final_result.output))
+                                ),
+                            )
             finally:
-                run_span.end()
+                try:
+                    if instrumentation_settings and run_span.is_recording():
+                        run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
+                finally:
+                    run_span.end()
 
     def _run_span_end_attributes(
         self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings
@@ -2173,7 +2174,7 @@
     ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
         """Advance to the next node automatically based on the last returned node."""
         next_node = await self._graph_run.__anext__()
-        if _agent_graph.is_agent_node(next_node):
+        if _agent_graph.is_agent_node(node=next_node):
            return next_node
        assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
        return next_node
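
Note (not part of the diff): the large indentation-only change in `agent.py` wraps the run body in `async with toolset:`, so toolsets that own resources (e.g. an MCP server connection) are entered once per run and exited even if the run raises. A minimal sketch of the pattern; `Toolset` here is a made-up stand-in, not the library class:

import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager


class Toolset:
    """Stand-in for a toolset that owns a connection for the run's duration."""

    async def __aenter__(self) -> 'Toolset':
        print('toolset entered (e.g. MCP server connected)')
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        print('toolset exited (connection closed, even on error)')


@asynccontextmanager
async def run_agent(toolset: Toolset) -> AsyncIterator[str]:
    # Everything that may call tools now happens inside the toolset's context.
    async with toolset:
        yield 'agent run'


async def main() -> None:
    async with run_agent(Toolset()) as run:
        print(run)


asyncio.run(main())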
--- pydantic_ai_slim-0.4.8/pydantic_ai/messages.py
+++ pydantic_ai_slim-0.4.10/pydantic_ai/messages.py
@@ -412,8 +412,8 @@ class ToolReturn:
     return_value: Any
     """The return value to be used in the tool response."""
 
-    content: Sequence[UserContent] | None = None
-    """The content sequence to be sent to the model as a UserPromptPart."""
+    content: str | Sequence[UserContent] | None = None
+    """The content to be sent to the model as a UserPromptPart."""
 
     metadata: Any = None
     """Additional data that can be accessed programmatically by the application but is not sent to the LLM."""
--- pydantic_ai_slim-0.4.8/pydantic_ai/models/openai.py
+++ pydantic_ai_slim-0.4.10/pydantic_ai/models/openai.py
@@ -120,10 +120,10 @@ class OpenAIModelSettings(ModelSettings, total=False):
     See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details.
     """
 
-    openai_service_tier: Literal['auto', 'default', 'flex']
+    openai_service_tier: Literal['auto', 'default', 'flex', 'priority']
     """The service tier to use for the model request.
 
-    Currently supported values are `auto`, `default`, and `flex`.
+    Currently supported values are `auto`, `default`, `flex`, and `priority`.
     For more information, see [OpenAI's service tiers documentation](https://platform.openai.com/docs/api-reference/chat/object#chat/object-service_tier).
     """
 
@@ -803,6 +803,7 @@ class OpenAIResponsesModel(Model):
             top_p=sampling_settings.get('top_p', NOT_GIVEN),
             truncation=model_settings.get('openai_truncation', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
             reasoning=reasoning,
             user=model_settings.get('openai_user', NOT_GIVEN),
             text=text or NOT_GIVEN,
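
Note (not part of the diff): with `'priority'` added to the literal and the Responses API now forwarding the setting, opting in is a one-liner on model settings. A sketch; the model name is illustrative:

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModelSettings

agent = Agent(
    'openai:gpt-4o',  # illustrative model name
    model_settings=OpenAIModelSettings(openai_service_tier='priority'),
)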
--- pydantic_ai_slim-0.4.8/pydantic_ai/profiles/google.py
+++ pydantic_ai_slim-0.4.10/pydantic_ai/profiles/google.py
@@ -49,7 +49,6 @@ class GoogleJsonSchemaTransformer(JsonSchemaTransformer):
         )
 
         schema.pop('title', None)
-        schema.pop('default', None)
         schema.pop('$schema', None)
         if (const := schema.pop('const', None)) is not None:
             # Gemini doesn't support const, but it does support enum with a single value
--- pydantic_ai_slim-0.4.8/pydantic_ai/providers/vercel.py
+++ pydantic_ai_slim-0.4.10/pydantic_ai/providers/vercel.py
@@ -98,10 +98,16 @@ class VercelProvider(Provider[AsyncOpenAI]):
                 'or pass the API key via `VercelProvider(api_key=...)` to use the Vercel provider.'
             )
 
+        default_headers = {'http-referer': 'https://ai.pydantic.dev/', 'x-title': 'pydantic-ai'}
+
         if openai_client is not None:
             self._client = openai_client
         elif http_client is not None:
-            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+            self._client = AsyncOpenAI(
+                base_url=self.base_url, api_key=api_key, http_client=http_client, default_headers=default_headers
+            )
         else:
             http_client = cached_async_http_client(provider='vercel')
-            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+            self._client = AsyncOpenAI(
+                base_url=self.base_url, api_key=api_key, http_client=http_client, default_headers=default_headers
+            )
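
Note (not part of the diff): `VercelProvider` now attaches the `http-referer` and `x-title` attribution headers on both client-construction paths, including when a custom `http_client` is supplied. A usage sketch; the gateway model name and key are placeholders:

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.vercel import VercelProvider

model = OpenAIModel(
    'anthropic/claude-3.5-sonnet',  # placeholder AI Gateway model name
    provider=VercelProvider(api_key='your-vercel-ai-gateway-key'),
)
agent = Agent(model)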
--- pydantic_ai_slim-0.4.8/pydantic_ai/result.py
+++ pydantic_ai_slim-0.4.10/pydantic_ai/result.py
@@ -67,7 +67,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
             except ValidationError:
                 pass
         if self._final_result_event is not None:  # pragma: no branch
-            yield await self._validate_response(self._raw_stream_response.get(), allow_partial=False)
+            yield await self._validate_response(self._raw_stream_response.get())
 
     async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
         """Asynchronously stream the (unvalidated) model responses for the agent."""
@@ -128,7 +128,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         async for _ in self:
             pass
 
-        return await self._validate_response(self._raw_stream_response.get(), allow_partial=False)
+        return await self._validate_response(self._raw_stream_response.get())
 
     async def _validate_response(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
         """Validate a structured result message."""
@@ -150,7 +150,9 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                 raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
                     f'Invalid response, unable to find tool call for {output_tool_name!r}'
                 )
-            return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
+            return await self._tool_manager.handle_call(
+                tool_call, allow_partial=allow_partial, wrap_validation_errors=False
+            )
         elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
             if not self._output_schema.allows_deferred_tool_calls:
                 raise exceptions.UserError(
--- pydantic_ai_slim-0.4.8/pyproject.toml
+++ pydantic_ai_slim-0.4.10/pyproject.toml
@@ -63,12 +63,7 @@ dependencies = [
 logfire = ["logfire>=3.11.0"]
 # Models
 openai = ["openai>=1.92.0"]
-cohere = [
-    "cohere>=5.16.0; platform_system != 'Emscripten'",
-    # Remove once all wheels for 0.21.4+ are built successfully
-    # https://github.com/huggingface/tokenizers/actions/runs/16570140346/job/46860152621
-    "tokenizers<=0.21.2",
-]
+cohere = ["cohere>=5.16.0; platform_system != 'Emscripten'"]
 vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
 google = ["google-genai>=1.24.0"]
 anthropic = ["anthropic>=0.52.0"]