pydantic-ai-slim 0.4.4__tar.gz → 0.4.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of pydantic-ai-slim has been flagged as potentially problematic.

Files changed (98):
  1. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/.gitignore +0 -1
  2. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/PKG-INFO +4 -4
  3. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_a2a.py +3 -3
  4. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_function_schema.py +13 -4
  5. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_output.py +41 -25
  6. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/agent.py +18 -37
  7. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/format_prompt.py +3 -6
  8. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/__init__.py +1 -1
  9. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/function.py +15 -16
  10. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/gemini.py +0 -9
  11. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/instrumented.py +6 -1
  12. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/mistral.py +12 -2
  13. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/openai.py +13 -1
  14. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/result.py +115 -151
  15. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pyproject.toml +1 -1
  16. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/LICENSE +0 -0
  17. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/README.md +0 -0
  18. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/__init__.py +0 -0
  19. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/__main__.py +0 -0
  20. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_agent_graph.py +0 -0
  21. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_cli.py +0 -0
  22. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_griffe.py +0 -0
  23. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_mcp.py +0 -0
  24. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_parts_manager.py +0 -0
  25. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_run_context.py +0 -0
  26. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_system_prompt.py +0 -0
  27. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_thinking_part.py +0 -0
  28. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_tool_manager.py +0 -0
  29. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/_utils.py +0 -0
  30. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/ag_ui.py +0 -0
  31. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/common_tools/__init__.py +0 -0
  32. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  33. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/common_tools/tavily.py +0 -0
  34. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/direct.py +0 -0
  35. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/exceptions.py +0 -0
  36. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/ext/__init__.py +0 -0
  37. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/ext/aci.py +0 -0
  38. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/ext/langchain.py +0 -0
  39. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/format_as_xml.py +0 -0
  40. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/mcp.py +0 -0
  41. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/messages.py +0 -0
  42. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/anthropic.py +0 -0
  43. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/bedrock.py +0 -0
  44. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/cohere.py +0 -0
  45. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/fallback.py +0 -0
  46. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/google.py +0 -0
  47. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/groq.py +0 -0
  48. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/huggingface.py +0 -0
  49. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/mcp_sampling.py +0 -0
  50. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/test.py +0 -0
  51. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/wrapper.py +0 -0
  52. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/output.py +0 -0
  53. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/__init__.py +0 -0
  54. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/_json_schema.py +0 -0
  55. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/amazon.py +0 -0
  56. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/anthropic.py +0 -0
  57. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/cohere.py +0 -0
  58. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/deepseek.py +0 -0
  59. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/google.py +0 -0
  60. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/grok.py +0 -0
  61. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/meta.py +0 -0
  62. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/mistral.py +0 -0
  63. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/moonshotai.py +0 -0
  64. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/openai.py +0 -0
  65. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/qwen.py +0 -0
  66. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/__init__.py +0 -0
  67. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/anthropic.py +0 -0
  68. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/azure.py +0 -0
  69. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/bedrock.py +0 -0
  70. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/cohere.py +0 -0
  71. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/deepseek.py +0 -0
  72. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/fireworks.py +0 -0
  73. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/github.py +0 -0
  74. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/google.py +0 -0
  75. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/google_gla.py +0 -0
  76. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/google_vertex.py +0 -0
  77. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/grok.py +0 -0
  78. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/groq.py +0 -0
  79. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/heroku.py +0 -0
  80. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/huggingface.py +0 -0
  81. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/mistral.py +0 -0
  82. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/openai.py +0 -0
  83. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/openrouter.py +0 -0
  84. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/together.py +0 -0
  85. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/py.typed +0 -0
  86. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/settings.py +0 -0
  87. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/tools.py +0 -0
  88. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/__init__.py +0 -0
  89. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/abstract.py +0 -0
  90. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/combined.py +0 -0
  91. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/deferred.py +0 -0
  92. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/filtered.py +0 -0
  93. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/function.py +0 -0
  94. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/prefixed.py +0 -0
  95. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/prepared.py +0 -0
  96. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/renamed.py +0 -0
  97. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/wrapper.py +0 -0
  98. {pydantic_ai_slim-0.4.4 → pydantic_ai_slim-0.4.6}/pydantic_ai/usage.py +0 -0
--- pydantic_ai_slim-0.4.4/.gitignore
+++ pydantic_ai_slim-0.4.6/.gitignore
@@ -15,7 +15,6 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite
 .vscode/
 /question_graph_history.json
 /docs-site/.wrangler/
-/CLAUDE.md
 node_modules/
 **.idea/
 .coverage*
--- pydantic_ai_slim-0.4.4/PKG-INFO
+++ pydantic_ai_slim-0.4.6/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.4.4
+Version: 0.4.6
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>, Douwe Maan <douwe@pydantic.dev>
 License-Expression: MIT
@@ -30,7 +30,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.4.4
+Requires-Dist: pydantic-graph==0.4.6
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: a2a
@@ -51,7 +51,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
 Provides-Extra: duckduckgo
 Requires-Dist: ddgs>=9.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.4.4; extra == 'evals'
+Requires-Dist: pydantic-evals==0.4.6; extra == 'evals'
 Provides-Extra: google
 Requires-Dist: google-genai>=1.24.0; extra == 'google'
 Provides-Extra: groq
@@ -63,7 +63,7 @@ Requires-Dist: logfire>=3.11.0; extra == 'logfire'
 Provides-Extra: mcp
 Requires-Dist: mcp>=1.9.4; (python_version >= '3.10') and extra == 'mcp'
 Provides-Extra: mistral
-Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
+Requires-Dist: mistralai>=1.9.2; extra == 'mistral'
 Provides-Extra: openai
 Requires-Dist: openai>=1.92.0; extra == 'openai'
 Provides-Extra: tavily
--- pydantic_ai_slim-0.4.4/pydantic_ai/_a2a.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/_a2a.py
@@ -59,12 +59,12 @@ except ImportError as _import_error:


 @asynccontextmanager
-async def worker_lifespan(app: FastA2A, worker: Worker) -> AsyncIterator[None]:
+async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT, OutputDataT]) -> AsyncIterator[None]:
     """Custom lifespan that runs the worker during application startup.

     This ensures the worker is started and ready to process tasks as soon as the application starts.
     """
-    async with app.task_manager:
+    async with app.task_manager, agent:
        async with worker.run():
            yield

@@ -93,7 +93,7 @@ def agent_to_a2a(
    broker = broker or InMemoryBroker()
    worker = AgentWorker(agent=agent, broker=broker, storage=storage)

-    lifespan = lifespan or partial(worker_lifespan, worker=worker)
+    lifespan = lifespan or partial(worker_lifespan, worker=worker, agent=agent)

    return FastA2A(
        storage=storage,
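
The change above makes the A2A app enter the agent itself as an async context manager at startup, so long-lived resources attached to the agent (e.g. MCP servers) are opened once for the app's lifetime. A minimal usage sketch; the `to_a2a` convenience method and model name are illustrative of the public API, not part of this diff:

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o', instructions='Be helpful.')

    # to_a2a() calls agent_to_a2a() under the hood, which now wires the agent
    # into the worker lifespan shown above
    app = agent.to_a2a()
    # serve with any ASGI server, e.g. `uvicorn my_module:app`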
--- pydantic_ai_slim-0.4.4/pydantic_ai/_function_schema.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/_function_schema.py
@@ -96,8 +96,13 @@ def function_schema(  # noqa: C901
     config = ConfigDict(title=function.__name__, use_attribute_docstrings=True)
     config_wrapper = ConfigWrapper(config)
     gen_schema = _generate_schema.GenerateSchema(config_wrapper)
+    errors: list[str] = []

-    sig = signature(function)
+    try:
+        sig = signature(function)
+    except ValueError as e:
+        errors.append(str(e))
+        sig = signature(lambda: None)

     type_hints = _typing_extra.get_function_type_hints(function)

@@ -105,7 +110,6 @@ def function_schema(  # noqa: C901
     fields: dict[str, core_schema.TypedDictField] = {}
     positional_fields: list[str] = []
     var_positional_field: str | None = None
-    errors: list[str] = []
     decorators = _decorators.DecoratorInfos()

     description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format)
@@ -235,14 +239,19 @@ def _takes_ctx(function: TargetFunc[P, R]) -> TypeIs[WithCtx[P, R]]:
     Returns:
         `True` if the function takes a `RunContext` as first argument, `False` otherwise.
     """
-    sig = signature(function)
+    try:
+        sig = signature(function)
+    except ValueError:  # pragma: no cover
+        return False  # pragma: no cover
     try:
         first_param_name = next(iter(sig.parameters.keys()))
     except StopIteration:
         return False
     else:
         type_hints = _typing_extra.get_function_type_hints(function)
-        annotation = type_hints[first_param_name]
+        annotation = type_hints.get(first_param_name)
+        if annotation is None:
+            return False  # pragma: no cover
         return annotation is not sig.empty and _is_call_ctx(annotation)

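Context for the new try/except blocks: `inspect.signature` raises `ValueError` for callables that expose no signature metadata, and the fallback substitutes an empty signature while recording the error. A small standalone sketch of the pattern (the set of affected builtins varies by CPython version; `min` is one common example):

    import inspect

    def safe_signature(fn):
        try:
            return inspect.signature(fn)
        except ValueError:
            # e.g. C builtins with multiple call forms, such as `min` on CPython
            return inspect.signature(lambda: None)

    print(safe_signature(min))  # falls back to the empty () signature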
--- pydantic_ai_slim-0.4.4/pydantic_ai/_output.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/_output.py
@@ -69,12 +69,31 @@ DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
 DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'


-async def execute_output_function_with_span(
+async def execute_traced_output_function(
     function_schema: _function_schema.FunctionSchema,
     run_context: RunContext[AgentDepsT],
     args: dict[str, Any] | Any,
+    wrap_validation_errors: bool = True,
 ) -> Any:
-    """Execute a function call within a traced span, automatically recording the response."""
+    """Execute an output function within a traced span with error handling.
+
+    This function executes the output function within an OpenTelemetry span for observability,
+    automatically records the function response, and handles ModelRetry exceptions by converting
+    them to ToolRetryError when wrap_validation_errors is True.
+
+    Args:
+        function_schema: The function schema containing the function to execute
+        run_context: The current run context containing tracing and tool information
+        args: Arguments to pass to the function
+        wrap_validation_errors: If True, wrap ModelRetry exceptions in ToolRetryError
+
+    Returns:
+        The result of the function execution
+
+    Raises:
+        ToolRetryError: When wrap_validation_errors is True and a ModelRetry is caught
+        ModelRetry: When wrap_validation_errors is False and a ModelRetry occurs
+    """
     # Set up span attributes
     tool_name = run_context.tool_name or getattr(function_schema.function, '__name__', 'output_function')
     attributes = {
@@ -96,7 +115,19 @@ async def execute_output_function_with_span(
     )

     with run_context.tracer.start_as_current_span('running output function', attributes=attributes) as span:
-        output = await function_schema.call(args, run_context)
+        try:
+            output = await function_schema.call(args, run_context)
+        except ModelRetry as r:
+            if wrap_validation_errors:
+                m = _messages.RetryPromptPart(
+                    content=r.message,
+                    tool_name=run_context.tool_name,
+                )
+                if run_context.tool_call_id:
+                    m.tool_call_id = run_context.tool_call_id  # pragma: no cover
+                raise ToolRetryError(m) from r
+            else:
+                raise

         # Record response if content inclusion is enabled
         if run_context.trace_include_content and span.is_recording():
@@ -663,16 +694,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
             else:
                 raise

-        try:
-            output = await self.call(output, run_context)
-        except ModelRetry as r:
-            if wrap_validation_errors:
-                m = _messages.RetryPromptPart(
-                    content=r.message,
-                )
-                raise ToolRetryError(m) from r
-            else:
-                raise  # pragma: no cover
+        output = await self.call(output, run_context, wrap_validation_errors)

         return output

@@ -691,12 +713,15 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
         self,
         output: Any,
         run_context: RunContext[AgentDepsT],
+        wrap_validation_errors: bool = True,
     ):
         if k := self.outer_typed_dict_key:
             output = output[k]

         if self._function_schema:
-            output = await execute_output_function_with_span(self._function_schema, run_context, output)
+            output = await execute_traced_output_function(
+                self._function_schema, run_context, output, wrap_validation_errors
+            )

         return output

@@ -856,16 +881,7 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         args = {self._str_argument_name: data}
-        try:
-            output = await execute_output_function_with_span(self._function_schema, run_context, args)
-        except ModelRetry as r:
-            if wrap_validation_errors:
-                m = _messages.RetryPromptPart(
-                    content=r.message,
-                )
-                raise ToolRetryError(m) from r
-            else:
-                raise  # pragma: no cover
+        output = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)

         return cast(OutputDataT, output)

@@ -975,7 +991,7 @@ class OutputToolset(AbstractToolset[AgentDepsT]):
     async def call_tool(
         self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
     ) -> Any:
-        output = await self.processors[name].call(tool_args, ctx)
+        output = await self.processors[name].call(tool_args, ctx, wrap_validation_errors=False)
         for validator in self.output_validators:
             output = await validator.validate(output, ctx, wrap_validation_errors=False)
         return output
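
With retry handling centralized in `execute_traced_output_function`, any output function can raise `ModelRetry` and have it converted into a retry prompt for the model (unless `wrap_validation_errors=False`, as in `OutputToolset.call_tool` above). A hedged sketch of the user-facing pattern; the validation rule is invented for illustration:

    from pydantic_ai import Agent, ModelRetry

    def check_answer(answer: int) -> int:
        # raising ModelRetry here is caught inside the traced span and
        # re-raised as a ToolRetryError, prompting the model to try again
        if answer < 0:
            raise ModelRetry('The answer must be non-negative, please try again.')
        return answer

    agent = Agent('openai:gpt-4o', output_type=check_answer)  # model name illustrative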
--- pydantic_ai_slim-0.4.4/pydantic_ai/agent.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/agent.py
@@ -36,7 +36,7 @@ from ._tool_manager import ToolManager
 from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
 from .output import OutputDataT, OutputSpec
 from .profiles import ModelProfile
-from .result import FinalResult, StreamedRunResult
+from .result import AgentStream, FinalResult, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDepsT,
@@ -843,14 +843,15 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             agent_run = AgentRun(graph_run)
             yield agent_run
             if (final_result := agent_run.result) is not None and run_span.is_recording():
-                run_span.set_attribute(
-                    'final_result',
-                    (
-                        final_result.output
-                        if isinstance(final_result.output, str)
-                        else json.dumps(InstrumentedModel.serialize_any(final_result.output))
-                    ),
-                )
+                if instrumentation_settings and instrumentation_settings.include_content:
+                    run_span.set_attribute(
+                        'final_result',
+                        (
+                            final_result.output
+                            if isinstance(final_result.output, str)
+                            else json.dumps(InstrumentedModel.serialize_any(final_result.output))
+                        ),
+                    )
         finally:
             try:
                 if instrumentation_settings and run_span.is_recording():
@@ -1126,29 +1127,15 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         while True:
             if self.is_model_request_node(node):
                 graph_ctx = agent_run.ctx
-                async with node._stream(graph_ctx) as streamed_response:  # pyright: ignore[reportPrivateUsage]
-
-                    async def stream_to_final(
-                        s: models.StreamedResponse,
-                    ) -> FinalResult[models.StreamedResponse] | None:
-                        output_schema = graph_ctx.deps.output_schema
-                        async for maybe_part_event in streamed_response:
-                            if isinstance(maybe_part_event, _messages.PartStartEvent):
-                                new_part = maybe_part_event.part
-                                if isinstance(new_part, _messages.TextPart) and isinstance(
-                                    output_schema, _output.TextOutputSchema
-                                ):
-                                    return FinalResult(s, None, None)
-                                elif isinstance(new_part, _messages.ToolCallPart) and (
-                                    tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name)
-                                ):
-                                    if tool_def.kind == 'output':
-                                        return FinalResult(s, new_part.tool_name, new_part.tool_call_id)
-                                    elif tool_def.kind == 'deferred':
-                                        return FinalResult(s, None, None)
+                async with node.stream(graph_ctx) as stream:
+
+                    async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None:
+                        async for event in stream:
+                            if isinstance(event, _messages.FinalResultEvent):
+                                return FinalResult(s, event.tool_name, event.tool_call_id)
                         return None

-                    final_result = await stream_to_final(streamed_response)
+                    final_result = await stream_to_final(stream)
                    if final_result is not None:
                        if yielded:
                            raise exceptions.AgentRunError('Agent run produced final results')  # pragma: no cover
@@ -1183,14 +1170,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                        yield StreamedRunResult(
                            messages,
                            graph_ctx.deps.new_message_index,
-                            graph_ctx.deps.usage_limits,
-                            streamed_response,
-                            graph_ctx.deps.output_schema,
-                            _agent_graph.build_run_context(graph_ctx),
-                            graph_ctx.deps.output_validators,
-                            final_result.tool_name,
+                            stream,
                            on_complete,
-                            graph_ctx.deps.tool_manager,
                        )
                        break
                next_node = await agent_run.next(node)
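
The rewrite routes `run_stream` through the public `node.stream()` / `AgentStream` machinery and relies on `FinalResultEvent` instead of re-implementing final-output detection inline, which is why the `StreamedRunResult` constructor shrinks to three arguments. The public streaming surface is unchanged, e.g.:

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')  # model name illustrative

    async def main():
        async with agent.run_stream('Tell me a joke.') as result:
            # yields progressively longer validated text snapshots
            async for text in result.stream_text():
                print(text)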
--- pydantic_ai_slim-0.4.4/pydantic_ai/format_prompt.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/format_prompt.py
@@ -13,9 +13,8 @@ __all__ = ('format_as_xml',)

 def format_as_xml(
     obj: Any,
-    root_tag: str = 'examples',
-    item_tag: str = 'example',
-    include_root_tag: bool = True,
+    root_tag: str | None = None,
+    item_tag: str = 'item',
     none_str: str = 'null',
     indent: str | None = '  ',
 ) -> str:
@@ -32,8 +31,6 @@ def format_as_xml(
         root_tag: Outer tag to wrap the XML in, use `None` to omit the outer tag.
         item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name
             for dataclasses and Pydantic models.
-        include_root_tag: Whether to include the root tag in the output
-            (The root tag is always included if it includes a body - e.g. when the input is a simple value).
         none_str: String to use for `None` values.
         indent: Indentation string to use for pretty printing.

@@ -55,7 +52,7 @@ def format_as_xml(
     ```
     """
     el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
-    if not include_root_tag and el.text is None:
+    if root_tag is None and el.text is None:
         join = '' if indent is None else '\n'
         return join.join(_rootless_xml_elements(el, indent))
     else:
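
Behaviour change to note: the defaults move from `root_tag='examples'` / `item_tag='example'` to no root tag and `item_tag='item'`, with `include_root_tag` removed in favour of `root_tag=None`. Expected output under the new defaults, as I read the code above (not taken from the package docs):

    from pydantic_ai.format_prompt import format_as_xml

    print(format_as_xml(['apple', 'banana']))
    # <item>apple</item>
    # <item>banana</item>

    print(format_as_xml(['apple', 'banana'], root_tag='examples', item_tag='example'))
    # <examples>
    #   <example>apple</example>
    #   <example>banana</example>
    # </examples>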
--- pydantic_ai_slim-0.4.4/pydantic_ai/models/__init__.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/models/__init__.py
@@ -758,7 +758,7 @@ async def download_item(

     data_type = media_type
     if type_format == 'extension':
-        data_type = data_type.split('/')[1]
+        data_type = item.format

     data = response.content
     if data_format in ('base64', 'base64_uri'):
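
The fix here: deriving the "extension" by splitting the media type on '/' gives the raw subtype, whereas the content item's `format` property maps the media type to a conventional file extension. An illustration (the 'audio/mpeg' → 'mp3' mapping reflects the media-type tables in pydantic_ai.messages):

    media_type = 'audio/mpeg'
    naive = media_type.split('/')[1]  # 'mpeg', not a usable file extension
    # vs. the format property on content items such as AudioUrl:
    # AudioUrl(url='https://example.com/a.mp3').format  -> 'mp3'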
--- pydantic_ai_slim-0.4.4/pydantic_ai/models/function.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/models/function.py
@@ -16,9 +16,7 @@ from pydantic_ai.profiles import ModelProfileSpec
 from .. import _utils, usage
 from .._utils import PeekableAsyncStream
 from ..messages import (
-    AudioUrl,
     BinaryContent,
-    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -345,18 +343,19 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
 def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     if not content:
         return 0
+
     if isinstance(content, str):
-        return len(re.split(r'[\s",.:]+', content.strip()))
-    else:
-        tokens = 0
-        for part in content:
-            if isinstance(part, str):
-                tokens += len(re.split(r'[\s",.:]+', part.strip()))
-            # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
-            if isinstance(part, (AudioUrl, ImageUrl)):
-                tokens += 0
-            elif isinstance(part, BinaryContent):
-                tokens += len(part.data)
-            else:
-                tokens += 0
-        return tokens
+        return len(_TOKEN_SPLIT_RE.split(content.strip()))
+
+    tokens = 0
+    for part in content:
+        if isinstance(part, str):
+            tokens += len(_TOKEN_SPLIT_RE.split(part.strip()))
+        elif isinstance(part, BinaryContent):
+            tokens += len(part.data)
+        # TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
+
+    return tokens
+
+
+_TOKEN_SPLIT_RE = re.compile(r'[\s",.:]+')
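
The estimator's behaviour for strings is unchanged (a rough whitespace/punctuation word count, now via a precompiled module-level regex); the restructuring also collapses the old second isinstance chain, which only ever added zero for URL parts. A quick check of the string path:

    import re

    _TOKEN_SPLIT_RE = re.compile(r'[\s",.:]+')
    # splits on runs of whitespace, quotes, commas, periods and colons
    print(len(_TOKEN_SPLIT_RE.split('The quick brown fox'.strip())))  # 4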
--- pydantic_ai_slim-0.4.4/pydantic_ai/models/gemini.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/models/gemini.py
@@ -91,15 +91,6 @@ class GeminiModelSettings(ModelSettings, total=False):
     See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations.
     """

-    gemini_thinking_config: ThinkingConfig
-    """Thinking is on by default in both the API and AI Studio.
-
-    Being on by default doesn't mean the model will send back thoughts. For that, you need to set `include_thoughts`
-    to `True`. If you want to turn it off, set `thinking_budget` to `0`.
-
-    See more about it on <https://ai.google.dev/gemini-api/docs/thinking>.
-    """
-

 @dataclass(init=False)
 class GeminiModel(Model):
--- pydantic_ai_slim-0.4.4/pydantic_ai/models/instrumented.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/models/instrumented.py
@@ -156,7 +156,12 @@ class InstrumentationSettings:
         events: list[Event] = []
         instructions = InstrumentedModel._get_instructions(messages)  # pyright: ignore [reportPrivateUsage]
         if instructions is not None:
-            events.append(Event('gen_ai.system.message', body={'content': instructions, 'role': 'system'}))
+            events.append(
+                Event(
+                    'gen_ai.system.message',
+                    body={**({'content': instructions} if self.include_content else {}), 'role': 'system'},
+                )
+            )

         for message_index, message in enumerate(messages):
             message_events: list[Event] = []
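
This aligns the system-message event with the `include_content` setting already honoured elsewhere: when content capture is off, the event keeps its role but drops the instruction text. A sketch of the knob (passing settings via the `instrument=` parameter is the agent's documented instrumentation hook; model name illustrative):

    from pydantic_ai import Agent
    from pydantic_ai.models.instrumented import InstrumentationSettings

    # spans/events keep structure (event names, roles) but omit message bodies
    agent = Agent('openai:gpt-4o', instrument=InstrumentationSettings(include_content=False))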
--- pydantic_ai_slim-0.4.4/pydantic_ai/models/mistral.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/models/mistral.py
@@ -52,6 +52,7 @@ try:
     CompletionChunk as MistralCompletionChunk,
     Content as MistralContent,
     ContentChunk as MistralContentChunk,
+    DocumentURLChunk as MistralDocumentURLChunk,
     FunctionCall as MistralFunctionCall,
     ImageURL as MistralImageURL,
     ImageURLChunk as MistralImageURLChunk,
@@ -539,10 +540,19 @@ class MistralModel(Model):
                 if item.is_image:
                     image_url = MistralImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
                     content.append(MistralImageURLChunk(image_url=image_url, type='image_url'))
+                elif item.media_type == 'application/pdf':
+                    content.append(
+                        MistralDocumentURLChunk(
+                            document_url=f'data:application/pdf;base64,{base64_encoded}', type='document_url'
+                        )
+                    )
                 else:
-                    raise RuntimeError('Only image binary content is supported for Mistral.')
+                    raise RuntimeError('BinaryContent other than image or PDF is not supported in Mistral.')
             elif isinstance(item, DocumentUrl):
-                raise RuntimeError('DocumentUrl is not supported in Mistral.')  # pragma: no cover
+                if item.media_type == 'application/pdf':
+                    content.append(MistralDocumentURLChunk(document_url=item.url, type='document_url'))
+                else:
+                    raise RuntimeError('DocumentUrl other than PDF is not supported in Mistral.')
             elif isinstance(item, VideoUrl):
                 raise RuntimeError('VideoUrl is not supported in Mistral.')
             else:  # pragma: no cover
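
PDF input now works for Mistral both as inline binary content (sent as a base64 data URL) and as a `DocumentUrl`. A hedged usage sketch; the model name and URL are placeholders:

    from pydantic_ai import Agent
    from pydantic_ai.messages import DocumentUrl

    agent = Agent('mistral:mistral-small-latest')
    result = agent.run_sync(
        [
            'Summarize this document.',
            DocumentUrl(url='https://example.com/report.pdf'),  # must resolve to application/pdf
        ]
    )
    print(result.output)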
--- pydantic_ai_slim-0.4.4/pydantic_ai/models/openai.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/models/openai.py
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Any, Literal, Union, cast, overload

+from pydantic import ValidationError
 from typing_extensions import assert_never

 from pydantic_ai._thinking_part import split_content_into_text_and_thinking
@@ -347,8 +348,19 @@ class OpenAIModel(Model):
             raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
         raise  # pragma: no cover

-    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
+        # Although the OpenAI SDK claims to return a Pydantic model (`ChatCompletion`) from the chat completions function:
+        # * it hasn't actually performed validation (presumably they're creating the model with `model_construct` or something?!)
+        # * if the endpoint returns plain text, the return type is a string
+        # Thus we validate it fully here.
+        if not isinstance(response, chat.ChatCompletion):
+            raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data')
+
+        try:
+            response = chat.ChatCompletion.model_validate(response.model_dump())
+        except ValidationError as e:
+            raise UnexpectedModelBehavior(f'Invalid response from OpenAI chat completions endpoint: {e}') from e
         timestamp = number_to_datetime(response.created)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
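
The comment in the diff explains the motivation; the general pattern is that a pydantic model built via `model_construct` skips validation, so round-tripping through `model_validate` is what surfaces malformed payloads. A standalone illustration using generic pydantic, not the OpenAI types:

    from pydantic import BaseModel, ValidationError

    class Completion(BaseModel):
        created: int

    # model_construct() performs no validation, mirroring what the SDK appears to do
    unchecked = Completion.model_construct(created='not-a-timestamp')

    try:
        Completion.model_validate(unchecked.model_dump())
    except ValidationError as e:
        print(f'invalid payload: {e.error_count()} error(s)')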
--- pydantic_ai_slim-0.4.4/pydantic_ai/result.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/result.py
@@ -63,22 +63,18 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         async for response in self.stream_responses(debounce_by=debounce_by):
             if self._final_result_event is not None:
                 try:
-                    yield await self._validate_response(
-                        response, self._final_result_event.tool_name, allow_partial=True
-                    )
+                    yield await self._validate_response(response, allow_partial=True)
                 except ValidationError:
                     pass
         if self._final_result_event is not None:  # pragma: no branch
-            yield await self._validate_response(
-                self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False
-            )
+            yield await self._validate_response(self._raw_stream_response.get(), allow_partial=False)

     async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
         """Asynchronously stream the (unvalidated) model responses for the agent."""
         # if the message currently has any parts with content, yield before streaming
         msg = self._raw_stream_response.get()
         for part in msg.parts:
-            if part.has_content():  # pragma: no cover
+            if part.has_content():
                 yield msg
                 break

@@ -86,6 +82,35 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
             async for _items in group_iter:
                 yield self._raw_stream_response.get()  # current state of the response

+    async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
+        """Stream the text result as an async iterable.
+
+        !!! note
+            Result validators will NOT be called on the text result if `delta=True`.
+
+        Args:
+            delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
+                up to the current point.
+            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured responses to reduce the overhead of
+                performing validation as each token is received.
+        """
+        if not isinstance(self._output_schema, PlainTextOutputSchema):
+            raise exceptions.UserError('stream_text() can only be used with text responses')
+
+        if delta:
+            async for text in self._stream_response_text(delta=True, debounce_by=debounce_by):
+                yield text
+        else:
+            async for text in self._stream_response_text(delta=False, debounce_by=debounce_by):
+                for validator in self._output_validators:
+                    text = await validator.validate(text, self._run_ctx)  # pragma: no cover
+                yield text
+
+    def get(self) -> _messages.ModelResponse:
+        """Get the current state of the response."""
+        return self._raw_stream_response.get()
+
     def usage(self) -> Usage:
         """Return the usage of the whole run.

@@ -94,10 +119,24 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         """
         return self._initial_run_ctx_usage + self._raw_stream_response.usage()

-    async def _validate_response(
-        self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False
-    ) -> OutputDataT:
+    def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
+        return self._raw_stream_response.timestamp
+
+    async def get_output(self) -> OutputDataT:
+        """Stream the whole response, validate the output and return it."""
+        async for _ in self:
+            pass
+
+        return await self._validate_response(self._raw_stream_response.get(), allow_partial=False)
+
+    async def _validate_response(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
         """Validate a structured result message."""
+        if self._final_result_event is None:
+            raise exceptions.UnexpectedModelBehavior('Invalid response, unable to find output')  # pragma: no cover
+
+        output_tool_name = self._final_result_event.tool_name
+
         if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None:
             tool_call = next(
                 (
@@ -114,7 +153,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
             return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
         elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
             if not self._output_schema.allows_deferred_tool_calls:
-                raise exceptions.UserError(  # pragma: no cover
+                raise exceptions.UserError(
                     'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
                 )
             return cast(OutputDataT, deferred_tool_calls)
@@ -132,6 +171,54 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                 'Invalid response, unable to process text output'
             )

+    async def _stream_response_text(
+        self, *, delta: bool = False, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[str]:
+        """Stream the response as an async iterable of text."""
+
+        # Define a "merged" version of the iterator that will yield items that have already been retrieved
+        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
+        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
+        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
+            # yields tuples of (text_content, part_index)
+            # we don't currently make use of the part_index, but in principle this may be useful
+            # so we retain it here for now to make possible future refactors simpler
+            msg = self._raw_stream_response.get()
+            for i, part in enumerate(msg.parts):
+                if isinstance(part, _messages.TextPart) and part.content:
+                    yield part.content, i
+
+            async for event in self._raw_stream_response:
+                if (
+                    isinstance(event, _messages.PartStartEvent)
+                    and isinstance(event.part, _messages.TextPart)
+                    and event.part.content
+                ):
+                    yield event.part.content, event.index  # pragma: no cover
+                elif (  # pragma: no branch
+                    isinstance(event, _messages.PartDeltaEvent)
+                    and isinstance(event.delta, _messages.TextPartDelta)
+                    and event.delta.content_delta
+                ):
+                    yield event.delta.content_delta, event.index
+
+        async def _stream_text_deltas() -> AsyncIterator[str]:
+            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
+                async for items in group_iter:
+                    # Note: we are currently just dropping the part index on the group here
+                    yield ''.join([content for content, _ in items])
+
+        if delta:
+            async for text in _stream_text_deltas():
+                yield text
+        else:
+            # a quick benchmark shows it's faster to build up a string with concat when we're
+            # yielding at each step
+            deltas: list[str] = []
+            async for text in _stream_text_deltas():
+                deltas.append(text)
+                yield ''.join(deltas)
+
     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.

@@ -189,16 +276,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     _all_messages: list[_messages.ModelMessage]
     _new_message_index: int

-    _usage_limits: UsageLimits | None
-    _stream_response: models.StreamedResponse
-    _output_schema: OutputSchema[OutputDataT]
-    _run_ctx: RunContext[AgentDepsT]
-    _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
-    _output_tool_name: str | None
+    _stream_response: AgentStream[AgentDepsT, OutputDataT]
     _on_complete: Callable[[], Awaitable[None]]
-    _tool_manager: ToolManager[AgentDepsT]

-    _initial_run_ctx_usage: Usage = field(init=False)
     is_complete: bool = field(default=False, init=False)
     """Whether the stream has all been received.

@@ -209,9 +289,6 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     [`get_output`][pydantic_ai.result.StreamedRunResult.get_output] completes.
     """

-    def __post_init__(self):
-        self._initial_run_ctx_usage = copy(self._run_ctx.usage)
-
     @overload
     def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ...

@@ -332,12 +409,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         Returns:
             An async iterable of the response data.
         """
-        async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
-            try:
-                yield await self.validate_structured_output(structured_message, allow_partial=not is_last)
-            except ValidationError:
-                if is_last:
-                    raise  # pragma: no cover
+        async for output in self._stream_response.stream_output(debounce_by=debounce_by):
+            yield output
+        await self._marked_completed(self._stream_response.get())

     async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
         """Stream the text result as an async iterable.
@@ -352,16 +426,8 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
             Debouncing is particularly important for long structured responses to reduce the overhead of
             performing validation as each token is received.
         """
-        if not isinstance(self._output_schema, PlainTextOutputSchema):
-            raise exceptions.UserError('stream_text() can only be used with text responses')
-
-        if delta:
-            async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
-                yield text
-        else:
-            async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
-                combined_validated_text = await self._validate_text_output(text)
-                yield combined_validated_text
+        async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+            yield text
         await self._marked_completed(self._stream_response.get())

     async def stream_structured(
@@ -378,13 +444,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
             An async iterable of the structured response message and whether that is the last message.
         """
         # if the message currently has any parts with content, yield before streaming
-        msg = self._stream_response.get()
-        for part in msg.parts:
-            if part.has_content():
-                yield msg, False
-                break
-
-        async for msg in self._stream_response_structured(debounce_by=debounce_by):
+        async for msg in self._stream_response.stream_responses(debounce_by=debounce_by):
             yield msg, False

         msg = self._stream_response.get()
@@ -394,15 +454,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):

     async def get_output(self) -> OutputDataT:
         """Stream the whole response, validate and return it."""
-        usage_checking_stream = _get_usage_checking_stream_response(
-            self._stream_response, self._usage_limits, self.usage
-        )
-
-        async for _ in usage_checking_stream:
-            pass
-        message = self._stream_response.get()
-        await self._marked_completed(message)
-        return await self.validate_structured_output(message)
+        output = await self._stream_response.get_output()
+        await self._marked_completed(self._stream_response.get())
+        return output

     @deprecated('`get_data` is deprecated, use `get_output` instead.')
     async def get_data(self) -> OutputDataT:
@@ -414,11 +468,11 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         !!! note
             This won't return the full usage until the stream is finished.
         """
-        return self._initial_run_ctx_usage + self._stream_response.usage()
+        return self._stream_response.usage()

     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
-        return self._stream_response.timestamp
+        return self._stream_response.timestamp()

     @deprecated('`validate_structured_result` is deprecated, use `validate_structured_output` instead.')
     async def validate_structured_result(
@@ -430,105 +484,15 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         self, message: _messages.ModelResponse, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
-        if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None:
-            tool_call = next(
-                (
-                    part
-                    for part in message.parts
-                    if isinstance(part, _messages.ToolCallPart) and part.tool_name == self._output_tool_name
-                ),
-                None,
-            )
-            if tool_call is None:
-                raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool call for {self._output_tool_name!r}'
-                )
-            return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
-        elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
-            if not self._output_schema.allows_deferred_tool_calls:
-                raise exceptions.UserError(
-                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
-                )
-            return cast(OutputDataT, deferred_tool_calls)
-        elif isinstance(self._output_schema, TextOutputSchema):
-            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
-
-            result_data = await self._output_schema.process(
-                text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
-            )
-            for validator in self._output_validators:
-                result_data = await validator.validate(result_data, self._run_ctx)  # pragma: no cover
-            return result_data
-        else:
-            raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                'Invalid response, unable to process text output'
-            )
-
-    async def _validate_text_output(self, text: str) -> str:
-        for validator in self._output_validators:
-            text = await validator.validate(text, self._run_ctx)  # pragma: no cover
-        return text
+        return await self._stream_response._validate_response(  # pyright: ignore[reportPrivateUsage]
+            message, allow_partial=allow_partial
+        )

     async def _marked_completed(self, message: _messages.ModelResponse) -> None:
         self.is_complete = True
         self._all_messages.append(message)
         await self._on_complete()

-    async def _stream_response_structured(
-        self, *, debounce_by: float | None = 0.1
-    ) -> AsyncIterator[_messages.ModelResponse]:
-        async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
-            async for _items in group_iter:
-                yield self._stream_response.get()
-
-    async def _stream_response_text(
-        self, *, delta: bool = False, debounce_by: float | None = 0.1
-    ) -> AsyncIterator[str]:
-        """Stream the response as an async iterable of text."""
-
-        # Define a "merged" version of the iterator that will yield items that have already been retrieved
-        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
-        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
-        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
-            # yields tuples of (text_content, part_index)
-            # we don't currently make use of the part_index, but in principle this may be useful
-            # so we retain it here for now to make possible future refactors simpler
-            msg = self._stream_response.get()
-            for i, part in enumerate(msg.parts):
-                if isinstance(part, _messages.TextPart) and part.content:
-                    yield part.content, i
-
-            async for event in self._stream_response:
-                if (
-                    isinstance(event, _messages.PartStartEvent)
-                    and isinstance(event.part, _messages.TextPart)
-                    and event.part.content
-                ):
-                    yield event.part.content, event.index  # pragma: no cover
-                elif (  # pragma: no branch
-                    isinstance(event, _messages.PartDeltaEvent)
-                    and isinstance(event.delta, _messages.TextPartDelta)
-                    and event.delta.content_delta
-                ):
-                    yield event.delta.content_delta, event.index
-
-        async def _stream_text_deltas() -> AsyncIterator[str]:
-            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
-                async for items in group_iter:
-                    # Note: we are currently just dropping the part index on the group here
-                    yield ''.join([content for content, _ in items])
-
-        if delta:
-            async for text in _stream_text_deltas():
-                yield text
-        else:
-            # a quick benchmark shows it's faster to build up a string with concat when we're
-            # yielding at each step
-            deltas: list[str] = []
-            async for text in _stream_text_deltas():
-                deltas.append(text)
-                yield ''.join(deltas)
-

 @dataclass(repr=False)
 class FinalResult(Generic[OutputDataT]):
@@ -556,12 +520,12 @@ def _get_usage_checking_stream_response(
 ) -> AsyncIterable[_messages.ModelResponseStreamEvent]:
     if limits is not None and limits.has_token_limits():

-        async def _usage_checking_iterator():  # pragma: no cover
+        async def _usage_checking_iterator():
             async for item in stream_response:
                 limits.check_tokens(get_usage())
                 yield item

-        return _usage_checking_iterator()  # pragma: no cover
+        return _usage_checking_iterator()
     else:
         return stream_response
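Net effect of the result.py changes: `StreamedRunResult` becomes a thin wrapper over `AgentStream`, which now owns text streaming, output validation, usage and timestamps; `stream_output`, `stream_text`, and `get_output` all delegate to it. Delta-mode streaming, for example, still looks like this (note the docstring's caveat that output validators are skipped when `delta=True`):

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')  # model name illustrative

    async def main():
        async with agent.run_stream('Write a haiku.') as result:
            # delta=True yields only the newly received chunks
            async for chunk in result.stream_text(delta=True):
                print(chunk, end='', flush=True)
        print()
        print(result.usage())  # full usage is available once the stream finishes
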
--- pydantic_ai_slim-0.4.4/pyproject.toml
+++ pydantic_ai_slim-0.4.6/pyproject.toml
@@ -68,7 +68,7 @@ vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
 google = ["google-genai>=1.24.0"]
 anthropic = ["anthropic>=0.52.0"]
 groq = ["groq>=0.19.0"]
-mistral = ["mistralai>=1.2.5"]
+mistral = ["mistralai>=1.9.2"]
 bedrock = ["boto3>=1.37.24"]
 huggingface = ["huggingface-hub[inference]>=0.33.2"]
 # Tools
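
The `mistralai` floor moves to 1.9.2, presumably because the PDF support added in models/mistral.py imports `DocumentURLChunk`, which older SDK releases may not export. Upgrading the extra picks this up, e.g. `pip install --upgrade 'pydantic-ai-slim[mistral]'`.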