pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (70)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_agent_graph.py +310 -140
  3. pydantic_ai/_function_schema.py +5 -5
  4. pydantic_ai/_griffe.py +2 -1
  5. pydantic_ai/_otel_messages.py +2 -2
  6. pydantic_ai/_output.py +31 -35
  7. pydantic_ai/_parts_manager.py +4 -4
  8. pydantic_ai/_run_context.py +3 -1
  9. pydantic_ai/_system_prompt.py +2 -2
  10. pydantic_ai/_tool_manager.py +3 -22
  11. pydantic_ai/_utils.py +14 -26
  12. pydantic_ai/ag_ui.py +7 -8
  13. pydantic_ai/agent/__init__.py +70 -9
  14. pydantic_ai/agent/abstract.py +35 -4
  15. pydantic_ai/agent/wrapper.py +6 -0
  16. pydantic_ai/builtin_tools.py +2 -2
  17. pydantic_ai/common_tools/duckduckgo.py +4 -2
  18. pydantic_ai/durable_exec/temporal/__init__.py +4 -2
  19. pydantic_ai/durable_exec/temporal/_agent.py +23 -2
  20. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  21. pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
  22. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  23. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  24. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  25. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  26. pydantic_ai/exceptions.py +45 -2
  27. pydantic_ai/format_prompt.py +2 -2
  28. pydantic_ai/mcp.py +2 -2
  29. pydantic_ai/messages.py +73 -25
  30. pydantic_ai/models/__init__.py +5 -4
  31. pydantic_ai/models/anthropic.py +5 -5
  32. pydantic_ai/models/bedrock.py +58 -56
  33. pydantic_ai/models/cohere.py +3 -3
  34. pydantic_ai/models/fallback.py +2 -2
  35. pydantic_ai/models/function.py +25 -23
  36. pydantic_ai/models/gemini.py +9 -12
  37. pydantic_ai/models/google.py +3 -3
  38. pydantic_ai/models/groq.py +4 -4
  39. pydantic_ai/models/huggingface.py +4 -4
  40. pydantic_ai/models/instrumented.py +30 -16
  41. pydantic_ai/models/mcp_sampling.py +3 -1
  42. pydantic_ai/models/mistral.py +6 -6
  43. pydantic_ai/models/openai.py +18 -27
  44. pydantic_ai/models/test.py +24 -4
  45. pydantic_ai/output.py +27 -32
  46. pydantic_ai/profiles/__init__.py +3 -3
  47. pydantic_ai/profiles/groq.py +1 -1
  48. pydantic_ai/profiles/openai.py +25 -4
  49. pydantic_ai/providers/anthropic.py +2 -3
  50. pydantic_ai/providers/bedrock.py +3 -2
  51. pydantic_ai/result.py +144 -41
  52. pydantic_ai/retries.py +10 -29
  53. pydantic_ai/run.py +12 -5
  54. pydantic_ai/tools.py +126 -22
  55. pydantic_ai/toolsets/__init__.py +4 -1
  56. pydantic_ai/toolsets/_dynamic.py +4 -4
  57. pydantic_ai/toolsets/abstract.py +18 -2
  58. pydantic_ai/toolsets/approval_required.py +32 -0
  59. pydantic_ai/toolsets/combined.py +7 -12
  60. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  61. pydantic_ai/toolsets/filtered.py +1 -1
  62. pydantic_ai/toolsets/function.py +13 -4
  63. pydantic_ai/toolsets/wrapper.py +2 -1
  64. pydantic_ai/usage.py +7 -5
  65. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/METADATA +5 -6
  66. pydantic_ai_slim-1.0.0b1.dist-info/RECORD +120 -0
  67. pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
  68. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/WHEEL +0 -0
  69. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/entry_points.txt +0 -0
  70. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/output.py CHANGED
@@ -1,17 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
- from collections.abc import Awaitable, Sequence
3
+ from collections.abc import Awaitable, Callable, Sequence
4
4
  from dataclasses import dataclass
5
- from typing import Any, Callable, Generic, Literal, Union
5
+ from typing import Any, Generic, Literal
6
6
 
7
7
  from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
8
8
  from pydantic.json_schema import JsonSchemaValue
9
9
  from pydantic_core import core_schema
10
- from typing_extensions import TypeAliasType, TypeVar
10
+ from typing_extensions import TypeAliasType, TypeVar, deprecated
11
11
 
12
12
  from . import _utils
13
13
  from .messages import ToolCallPart
14
- from .tools import RunContext, ToolDefinition
14
+ from .tools import DeferredToolRequests, RunContext, ToolDefinition
15
15
 
16
16
  __all__ = (
17
17
  # classes
@@ -42,7 +42,7 @@ StructuredOutputMode = Literal['tool', 'native', 'prompted']
42
42
 
43
43
 
44
44
  OutputTypeOrFunction = TypeAliasType(
45
- 'OutputTypeOrFunction', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,)
45
+ 'OutputTypeOrFunction', type[T_co] | Callable[..., Awaitable[T_co] | T_co], type_params=(T_co,)
46
46
  )
47
47
  """Definition of an output type or function.
48
48
 
@@ -54,10 +54,7 @@ See [output docs](../output.md) for more information.
54
54
 
55
55
  TextOutputFunc = TypeAliasType(
56
56
  'TextOutputFunc',
57
- Union[
58
- Callable[[RunContext, str], Union[Awaitable[T_co], T_co]],
59
- Callable[[str], Union[Awaitable[T_co], T_co]],
60
- ],
57
+ Callable[[RunContext, str], Awaitable[T_co] | T_co] | Callable[[str], Awaitable[T_co] | T_co],
61
58
  type_params=(T_co,),
62
59
  )
63
60
  """Definition of a function that will be called to process the model's plain text output. The function must take a single string argument.
@@ -135,10 +132,9 @@ class NativeOutput(Generic[OutputDataT]):
135
132
 
136
133
  Example:
137
134
  ```python {title="native_output.py" requires="tool_output.py"}
138
- from tool_output import Fruit, Vehicle
139
-
140
135
  from pydantic_ai import Agent, NativeOutput
141
136
 
137
+ from tool_output import Fruit, Vehicle
142
138
 
143
139
  agent = Agent(
144
140
  'openai:gpt-4o',
@@ -184,10 +180,11 @@ class PromptedOutput(Generic[OutputDataT]):
184
180
  Example:
185
181
  ```python {title="prompted_output.py" requires="tool_output.py"}
186
182
  from pydantic import BaseModel
187
- from tool_output import Vehicle
188
183
 
189
184
  from pydantic_ai import Agent, PromptedOutput
190
185
 
186
+ from tool_output import Vehicle
187
+
191
188
 
192
189
  class Device(BaseModel):
193
190
  name: str
@@ -286,18 +283,17 @@ def StructuredDict(
286
283
  ```python {title="structured_dict.py"}
287
284
  from pydantic_ai import Agent, StructuredDict
288
285
 
289
-
290
286
  schema = {
291
- "type": "object",
292
- "properties": {
293
- "name": {"type": "string"},
294
- "age": {"type": "integer"}
287
+ 'type': 'object',
288
+ 'properties': {
289
+ 'name': {'type': 'string'},
290
+ 'age': {'type': 'integer'}
295
291
  },
296
- "required": ["name", "age"]
292
+ 'required': ['name', 'age']
297
293
  }
298
294
 
299
295
  agent = Agent('openai:gpt-4o', output_type=StructuredDict(schema))
300
- result = agent.run_sync("Create a person")
296
+ result = agent.run_sync('Create a person')
301
297
  print(result.output)
302
298
  #> {'name': 'John Doe', 'age': 30}
303
299
  ```
@@ -333,16 +329,13 @@ def StructuredDict(
333
329
 
334
330
  _OutputSpecItem = TypeAliasType(
335
331
  '_OutputSpecItem',
336
- Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], NativeOutput[T_co], PromptedOutput[T_co], TextOutput[T_co]],
332
+ OutputTypeOrFunction[T_co] | ToolOutput[T_co] | NativeOutput[T_co] | PromptedOutput[T_co] | TextOutput[T_co],
337
333
  type_params=(T_co,),
338
334
  )
339
335
 
340
336
  OutputSpec = TypeAliasType(
341
337
  'OutputSpec',
342
- Union[
343
- _OutputSpecItem[T_co],
344
- Sequence['OutputSpec[T_co]'],
345
- ],
338
+ _OutputSpecItem[T_co] | Sequence['OutputSpec[T_co]'],
346
339
  type_params=(T_co,),
347
340
  )
348
341
  """Specification of the agent's output data.
@@ -359,12 +352,14 @@ See [output docs](../output.md) for more information.
359
352
  """
360
353
 
361
354
 
362
- @dataclass
363
- class DeferredToolCalls:
364
- """Container for calls of deferred tools. This can be used as an agent's `output_type` and will be used as the output of the agent run if the model called any deferred tools.
365
-
366
- See [deferred toolset docs](../toolsets.md#deferred-toolset) for more information.
367
- """
355
+ @deprecated('`DeferredToolCalls` is deprecated, use `DeferredToolRequests` instead')
356
+ class DeferredToolCalls(DeferredToolRequests): # pragma: no cover
357
+ @property
358
+ @deprecated('`DeferredToolCalls.tool_calls` is deprecated, use `DeferredToolRequests.calls` instead')
359
+ def tool_calls(self) -> list[ToolCallPart]:
360
+ return self.calls
368
361
 
369
- tool_calls: list[ToolCallPart]
370
- tool_defs: dict[str, ToolDefinition]
362
+ @property
363
+ @deprecated('`DeferredToolCalls.tool_defs` is deprecated')
364
+ def tool_defs(self) -> dict[str, ToolDefinition]:
365
+ return {}
@@ -1,8 +1,8 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
+ from collections.abc import Callable
3
4
  from dataclasses import dataclass, fields, replace
4
5
  from textwrap import dedent
5
- from typing import Callable, Union
6
6
 
7
7
  from typing_extensions import Self
8
8
 
@@ -18,7 +18,7 @@ __all__ = [
18
18
  ]
19
19
 
20
20
 
21
- @dataclass
21
+ @dataclass(kw_only=True)
22
22
  class ModelProfile:
23
23
  """Describes how requests to and responses from specific models or families of models need to be constructed and processed to get the best results, independent of the model and provider classes used."""
24
24
 
@@ -75,6 +75,6 @@ class ModelProfile:
75
75
  return replace(self, **non_default_attrs)
76
76
 
77
77
 
78
- ModelProfileSpec = Union[ModelProfile, Callable[[str], Union[ModelProfile, None]]]
78
+ ModelProfileSpec = ModelProfile | Callable[[str], ModelProfile | None]
79
79
 
80
80
  DEFAULT_PROFILE = ModelProfile()
@@ -5,7 +5,7 @@ from dataclasses import dataclass
5
5
  from . import ModelProfile
6
6
 
7
7
 
8
- @dataclass
8
+ @dataclass(kw_only=True)
9
9
  class GroqModelProfile(ModelProfile):
10
10
  """Profile for models used with GroqModel.
11
11
 
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  import re
4
+ import warnings
4
5
  from collections.abc import Sequence
5
6
  from dataclasses import dataclass
6
7
  from typing import Any, Literal
@@ -11,7 +12,7 @@ from ._json_schema import JsonSchema, JsonSchemaTransformer
11
12
  OpenAISystemPromptRole = Literal['system', 'developer', 'user']
12
13
 
13
14
 
14
- @dataclass
15
+ @dataclass(kw_only=True)
15
16
  class OpenAIModelProfile(ModelProfile):
16
17
  """Profile for models used with `OpenAIChatModel`.
17
18
 
@@ -21,7 +22,6 @@ class OpenAIModelProfile(ModelProfile):
21
22
  openai_supports_strict_tool_definition: bool = True
22
23
  """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions."""
23
24
 
24
- # TODO(Marcelo): Deprecate this in favor of `openai_unsupported_model_settings`.
25
25
  openai_supports_sampling_settings: bool = True
26
26
  """Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""
27
27
 
@@ -38,6 +38,14 @@ class OpenAIModelProfile(ModelProfile):
38
38
  openai_system_prompt_role: OpenAISystemPromptRole | None = None
39
39
  """The role to use for the system prompt message. If not provided, defaults to `'system'`."""
40
40
 
41
+ def __post_init__(self): # pragma: no cover
42
+ if not self.openai_supports_sampling_settings:
43
+ warnings.warn(
44
+ 'The `openai_supports_sampling_settings` has no effect, and it will be removed in future versions. '
45
+ 'Use `openai_unsupported_model_settings` instead.',
46
+ DeprecationWarning,
47
+ )
48
+
41
49
 
42
50
  def openai_model_profile(model_name: str) -> ModelProfile:
43
51
  """Get the model profile for an OpenAI model."""
@@ -46,6 +54,19 @@ def openai_model_profile(model_name: str) -> ModelProfile:
46
54
  # We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `native` is only used
47
55
  # when the user specifically uses the `NativeOutput` marker, so an error from the API is acceptable.
48
56
 
57
+ if is_reasoning_model:
58
+ openai_unsupported_model_settings = (
59
+ 'temperature',
60
+ 'top_p',
61
+ 'presence_penalty',
62
+ 'frequency_penalty',
63
+ 'logit_bias',
64
+ 'logprobs',
65
+ 'top_logprobs',
66
+ )
67
+ else:
68
+ openai_unsupported_model_settings = ()
69
+
49
70
  # The o1-mini model doesn't support the `system` role, so we default to `user`.
50
71
  # See https://github.com/pydantic/pydantic-ai/issues/974 for more details.
51
72
  openai_system_prompt_role = 'user' if model_name.startswith('o1-mini') else None
@@ -54,7 +75,7 @@ def openai_model_profile(model_name: str) -> ModelProfile:
54
75
  json_schema_transformer=OpenAIJsonSchemaTransformer,
55
76
  supports_json_schema_output=True,
56
77
  supports_json_object_output=True,
57
- openai_supports_sampling_settings=not is_reasoning_model,
78
+ openai_unsupported_model_settings=openai_unsupported_model_settings,
58
79
  openai_system_prompt_role=openai_system_prompt_role,
59
80
  )
60
81
 
@@ -89,7 +110,7 @@ _STRICT_COMPATIBLE_STRING_FORMATS = [
89
110
  _sentinel = object()
90
111
 
91
112
 
92
- @dataclass
113
+ @dataclass(init=False)
93
114
  class OpenAIJsonSchemaTransformer(JsonSchemaTransformer):
94
115
  """Recursively handle the schema to make it compatible with OpenAI strict mode.
95
116
 
@@ -1,10 +1,9 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  import os
4
- from typing import Union, overload
4
+ from typing import TypeAlias, overload
5
5
 
6
6
  import httpx
7
- from typing_extensions import TypeAlias
8
7
 
9
8
  from pydantic_ai.exceptions import UserError
10
9
  from pydantic_ai.models import cached_async_http_client
@@ -21,7 +20,7 @@ except ImportError as _import_error:
21
20
  ) from _import_error
22
21
 
23
22
 
24
- AsyncAnthropicClient: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock]
23
+ AsyncAnthropicClient: TypeAlias = AsyncAnthropic | AsyncAnthropicBedrock
25
24
 
26
25
 
27
26
  class AnthropicProvider(Provider[AsyncAnthropicClient]):
@@ -2,8 +2,9 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import os
4
4
  import re
5
+ from collections.abc import Callable
5
6
  from dataclasses import dataclass
6
- from typing import Callable, Literal, overload
7
+ from typing import Literal, overload
7
8
 
8
9
  from pydantic_ai.exceptions import UserError
9
10
  from pydantic_ai.profiles import ModelProfile
@@ -27,7 +28,7 @@ except ImportError as _import_error:
27
28
  ) from _import_error
28
29
 
29
30
 
30
- @dataclass
31
+ @dataclass(kw_only=True)
31
32
  class BedrockModelProfile(ModelProfile):
32
33
  """Profile for models used with BedrockModel.
33
34
 
pydantic_ai/result.py CHANGED
@@ -1,16 +1,14 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
- from collections.abc import AsyncIterator, Awaitable, Callable
3
+ from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
4
4
  from copy import copy
5
5
  from dataclasses import dataclass, field
6
6
  from datetime import datetime
7
- from typing import Generic, cast
7
+ from typing import Generic, cast, overload
8
8
 
9
9
  from pydantic import ValidationError
10
10
  from typing_extensions import TypeVar, deprecated
11
11
 
12
- from pydantic_ai._tool_manager import ToolManager
13
-
14
12
  from . import _utils, exceptions, messages as _messages, models
15
13
  from ._output import (
16
14
  OutputDataT_inv,
@@ -22,11 +20,14 @@ from ._output import (
22
20
  ToolOutputSchema,
23
21
  )
24
22
  from ._run_context import AgentDepsT, RunContext
23
+ from ._tool_manager import ToolManager
25
24
  from .messages import ModelResponseStreamEvent
26
25
  from .output import (
26
+ DeferredToolRequests,
27
27
  OutputDataT,
28
28
  ToolOutput,
29
29
  )
30
+ from .run import AgentRunResult
30
31
  from .usage import RunUsage, UsageLimits
31
32
 
32
33
  __all__ = (
@@ -41,7 +42,7 @@ T = TypeVar('T')
41
42
  """An invariant TypeVar."""
42
43
 
43
44
 
44
- @dataclass
45
+ @dataclass(kw_only=True)
45
46
  class AgentStream(Generic[AgentDepsT, OutputDataT]):
46
47
  _raw_stream_response: models.StreamedResponse
47
48
  _output_schema: OutputSchema[OutputDataT]
@@ -155,12 +156,12 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
155
156
  return await self._tool_manager.handle_call(
156
157
  tool_call, allow_partial=allow_partial, wrap_validation_errors=False
157
158
  )
158
- elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
159
- if not self._output_schema.allows_deferred_tool_calls:
159
+ elif deferred_tool_requests := _get_deferred_tool_requests(message.parts, self._tool_manager):
160
+ if not self._output_schema.allows_deferred_tools:
160
161
  raise exceptions.UserError(
161
- 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
162
+ 'A deferred tool call was present, but `DeferredToolRequests` is not among output types. To resolve this, add `DeferredToolRequests` to the list of output types for this agent.'
162
163
  )
163
- return cast(OutputDataT, deferred_tool_calls)
164
+ return cast(OutputDataT, deferred_tool_requests)
164
165
  elif isinstance(self._output_schema, TextOutputSchema):
165
166
  text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
166
167
 
@@ -233,15 +234,17 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
233
234
  return self._agent_stream_iterator
234
235
 
235
236
 
236
- @dataclass
237
+ @dataclass(init=False)
237
238
  class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
238
239
  """Result of a streamed run that returns structured data via a tool call."""
239
240
 
240
241
  _all_messages: list[_messages.ModelMessage]
241
242
  _new_message_index: int
242
243
 
243
- _stream_response: AgentStream[AgentDepsT, OutputDataT]
244
- _on_complete: Callable[[], Awaitable[None]]
244
+ _stream_response: AgentStream[AgentDepsT, OutputDataT] | None = None
245
+ _on_complete: Callable[[], Awaitable[None]] | None = None
246
+
247
+ _run_result: AgentRunResult[OutputDataT] | None = None
245
248
 
246
249
  is_complete: bool = field(default=False, init=False)
247
250
  """Whether the stream has all been received.
@@ -253,6 +256,39 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
253
256
  [`get_output`][pydantic_ai.result.StreamedRunResult.get_output] completes.
254
257
  """
255
258
 
259
+ @overload
260
+ def __init__(
261
+ self,
262
+ all_messages: list[_messages.ModelMessage],
263
+ new_message_index: int,
264
+ stream_response: AgentStream[AgentDepsT, OutputDataT] | None,
265
+ on_complete: Callable[[], Awaitable[None]] | None,
266
+ ) -> None: ...
267
+
268
+ @overload
269
+ def __init__(
270
+ self,
271
+ all_messages: list[_messages.ModelMessage],
272
+ new_message_index: int,
273
+ *,
274
+ run_result: AgentRunResult[OutputDataT],
275
+ ) -> None: ...
276
+
277
+ def __init__(
278
+ self,
279
+ all_messages: list[_messages.ModelMessage],
280
+ new_message_index: int,
281
+ stream_response: AgentStream[AgentDepsT, OutputDataT] | None = None,
282
+ on_complete: Callable[[], Awaitable[None]] | None = None,
283
+ run_result: AgentRunResult[OutputDataT] | None = None,
284
+ ) -> None:
285
+ self._all_messages = all_messages
286
+ self._new_message_index = new_message_index
287
+
288
+ self._stream_response = stream_response
289
+ self._on_complete = on_complete
290
+ self._run_result = run_result
291
+
256
292
  def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
257
293
  """Return the history of _messages.
258
294
 
@@ -340,9 +376,15 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
340
376
  Returns:
341
377
  An async iterable of the response data.
342
378
  """
343
- async for output in self._stream_response.stream_output(debounce_by=debounce_by):
344
- yield output
345
- await self._marked_completed(self._stream_response.get())
379
+ if self._run_result is not None:
380
+ yield self._run_result.output
381
+ await self._marked_completed()
382
+ elif self._stream_response is not None:
383
+ async for output in self._stream_response.stream_output(debounce_by=debounce_by):
384
+ yield output
385
+ await self._marked_completed(self._stream_response.get())
386
+ else:
387
+ raise ValueError('No stream response or run result provided') # pragma: no cover
346
388
 
347
389
  async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
348
390
  """Stream the text result as an async iterable.
@@ -357,9 +399,20 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
357
399
  Debouncing is particularly important for long structured responses to reduce the overhead of
358
400
  performing validation as each token is received.
359
401
  """
360
- async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
361
- yield text
362
- await self._marked_completed(self._stream_response.get())
402
+ if self._run_result is not None: # pragma: no cover
403
+ # We can't really get here, as `_run_result` is only set in `run_stream` when `CallToolsNode` produces `DeferredToolRequests` output
404
+ # as a result of a tool function raising `CallDeferred` or `ApprovalRequired`.
405
+ # That'll change if we ever support something like `raise EndRun(output: OutputT)` where `OutputT` could be `str`.
406
+ if not isinstance(self._run_result.output, str):
407
+ raise exceptions.UserError('stream_text() can only be used with text responses')
408
+ yield self._run_result.output
409
+ await self._marked_completed()
410
+ elif self._stream_response is not None:
411
+ async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
412
+ yield text
413
+ await self._marked_completed(self._stream_response.get())
414
+ else:
415
+ raise ValueError('No stream response or run result provided') # pragma: no cover
363
416
 
364
417
  @deprecated('`StreamedRunResult.stream_structured` is deprecated, use `stream_responses` instead.')
365
418
  async def stream_structured(
@@ -381,20 +434,34 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
381
434
  Returns:
382
435
  An async iterable of the structured response message and whether that is the last message.
383
436
  """
384
- # if the message currently has any parts with content, yield before streaming
385
- async for msg in self._stream_response.stream_responses(debounce_by=debounce_by):
386
- yield msg, False
387
-
388
- msg = self._stream_response.get()
389
- yield msg, True
390
-
391
- await self._marked_completed(msg)
437
+ if self._run_result is not None:
438
+ model_response = cast(_messages.ModelResponse, self.all_messages()[-1])
439
+ yield model_response, True
440
+ await self._marked_completed()
441
+ elif self._stream_response is not None:
442
+ # if the message currently has any parts with content, yield before streaming
443
+ async for msg in self._stream_response.stream_responses(debounce_by=debounce_by):
444
+ yield msg, False
445
+
446
+ msg = self._stream_response.get()
447
+ yield msg, True
448
+
449
+ await self._marked_completed(msg)
450
+ else:
451
+ raise ValueError('No stream response or run result provided') # pragma: no cover
392
452
 
393
453
  async def get_output(self) -> OutputDataT:
394
454
  """Stream the whole response, validate and return it."""
395
- output = await self._stream_response.get_output()
396
- await self._marked_completed(self._stream_response.get())
397
- return output
455
+ if self._run_result is not None:
456
+ output = self._run_result.output
457
+ await self._marked_completed()
458
+ return output
459
+ elif self._stream_response is not None:
460
+ output = await self._stream_response.get_output()
461
+ await self._marked_completed(self._stream_response.get())
462
+ return output
463
+ else:
464
+ raise ValueError('No stream response or run result provided') # pragma: no cover
398
465
 
399
466
  def usage(self) -> RunUsage:
400
467
  """Return the usage of the whole run.
@@ -402,28 +469,45 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
402
469
  !!! note
403
470
  This won't return the full usage until the stream is finished.
404
471
  """
405
- return self._stream_response.usage()
472
+ if self._run_result is not None:
473
+ return self._run_result.usage()
474
+ elif self._stream_response is not None:
475
+ return self._stream_response.usage()
476
+ else:
477
+ raise ValueError('No stream response or run result provided') # pragma: no cover
406
478
 
407
479
  def timestamp(self) -> datetime:
408
480
  """Get the timestamp of the response."""
409
- return self._stream_response.timestamp()
481
+ if self._run_result is not None:
482
+ return self._run_result.timestamp()
483
+ elif self._stream_response is not None:
484
+ return self._stream_response.timestamp()
485
+ else:
486
+ raise ValueError('No stream response or run result provided') # pragma: no cover
410
487
 
411
488
  @deprecated('`validate_structured_output` is deprecated, use `validate_response_output` instead.')
412
489
  async def validate_structured_output(
413
490
  self, message: _messages.ModelResponse, *, allow_partial: bool = False
414
491
  ) -> OutputDataT:
415
- return await self._stream_response.validate_response_output(message, allow_partial=allow_partial)
492
+ return await self.validate_response_output(message, allow_partial=allow_partial)
416
493
 
417
494
  async def validate_response_output(
418
495
  self, message: _messages.ModelResponse, *, allow_partial: bool = False
419
496
  ) -> OutputDataT:
420
497
  """Validate a structured result message."""
421
- return await self._stream_response.validate_response_output(message, allow_partial=allow_partial)
498
+ if self._run_result is not None:
499
+ return self._run_result.output
500
+ elif self._stream_response is not None:
501
+ return await self._stream_response.validate_response_output(message, allow_partial=allow_partial)
502
+ else:
503
+ raise ValueError('No stream response or run result provided') # pragma: no cover
422
504
 
423
- async def _marked_completed(self, message: _messages.ModelResponse) -> None:
505
+ async def _marked_completed(self, message: _messages.ModelResponse | None = None) -> None:
424
506
  self.is_complete = True
425
- self._all_messages.append(message)
426
- await self._on_complete()
507
+ if message is not None:
508
+ self._all_messages.append(message)
509
+ if self._on_complete is not None:
510
+ await self._on_complete()
427
511
 
428
512
 
429
513
  @dataclass(repr=False)
@@ -432,8 +516,10 @@ class FinalResult(Generic[OutputDataT]):
432
516
 
433
517
  output: OutputDataT
434
518
  """The final result data."""
519
+
435
520
  tool_name: str | None = None
436
521
  """Name of the final output tool; `None` if the output came from unstructured text content."""
522
+
437
523
  tool_call_id: str | None = None
438
524
  """ID of the tool call that produced the final output; `None` if the output came from unstructured text content."""
439
525
 
@@ -454,9 +540,26 @@ def _get_usage_checking_stream_response(
454
540
 
455
541
  return _usage_checking_iterator()
456
542
  else:
457
- # TODO: Use `return aiter(stream_response)` once we drop support for Python 3.9
458
- async def _iterator():
459
- async for item in stream_response:
460
- yield item
543
+ return aiter(stream_response)
544
+
545
+
546
+ def _get_deferred_tool_requests(
547
+ parts: Iterable[_messages.ModelResponsePart], tool_manager: ToolManager[AgentDepsT]
548
+ ) -> DeferredToolRequests | None:
549
+ """Get the deferred tool requests from the model response parts."""
550
+ approvals: list[_messages.ToolCallPart] = []
551
+ calls: list[_messages.ToolCallPart] = []
552
+
553
+ for part in parts:
554
+ if isinstance(part, _messages.ToolCallPart):
555
+ tool_def = tool_manager.get_tool_def(part.tool_name)
556
+ if tool_def is not None: # pragma: no branch
557
+ if tool_def.kind == 'unapproved':
558
+ approvals.append(part)
559
+ elif tool_def.kind == 'external':
560
+ calls.append(part)
561
+
562
+ if not calls and not approvals:
563
+ return None
461
564
 
462
- return _iterator()
565
+ return DeferredToolRequests(calls=calls, approvals=approvals)