pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of pydantic-ai-slim might be problematic.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_agent_graph.py +310 -140
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +4 -4
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +3 -22
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +7 -8
- pydantic_ai/agent/__init__.py +70 -9
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +4 -2
- pydantic_ai/durable_exec/temporal/_agent.py +23 -2
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +2 -2
- pydantic_ai/messages.py +73 -25
- pydantic_ai/models/__init__.py +5 -4
- pydantic_ai/models/anthropic.py +5 -5
- pydantic_ai/models/bedrock.py +58 -56
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +9 -12
- pydantic_ai/models/google.py +3 -3
- pydantic_ai/models/groq.py +4 -4
- pydantic_ai/models/huggingface.py +4 -4
- pydantic_ai/models/instrumented.py +30 -16
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +6 -6
- pydantic_ai/models/openai.py +18 -27
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/result.py +144 -41
- pydantic_ai/retries.py +10 -29
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +126 -22
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +13 -4
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +7 -5
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/METADATA +5 -6
- pydantic_ai_slim-1.0.0b1.dist-info/RECORD +120 -0
- pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/output.py
CHANGED
@@ -1,17 +1,17 @@
 from __future__ import annotations
 
-from collections.abc import Awaitable, Sequence
+from collections.abc import Awaitable, Callable, Sequence
 from dataclasses import dataclass
-from typing import Any,
+from typing import Any, Generic, Literal
 
 from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
 from pydantic.json_schema import JsonSchemaValue
 from pydantic_core import core_schema
-from typing_extensions import TypeAliasType, TypeVar
+from typing_extensions import TypeAliasType, TypeVar, deprecated
 
 from . import _utils
 from .messages import ToolCallPart
-from .tools import RunContext, ToolDefinition
+from .tools import DeferredToolRequests, RunContext, ToolDefinition
 
 __all__ = (
     # classes
@@ -42,7 +42,7 @@ StructuredOutputMode = Literal['tool', 'native', 'prompted']
 
 
 OutputTypeOrFunction = TypeAliasType(
-    'OutputTypeOrFunction',
+    'OutputTypeOrFunction', type[T_co] | Callable[..., Awaitable[T_co] | T_co], type_params=(T_co,)
 )
 """Definition of an output type or function.
 
@@ -54,10 +54,7 @@ See [output docs](../output.md) for more information.
 
 TextOutputFunc = TypeAliasType(
     'TextOutputFunc',
-    Union[
-        Callable[[RunContext, str], Union[Awaitable[T_co], T_co]],
-        Callable[[str], Union[Awaitable[T_co], T_co]],
-    ],
+    Callable[[RunContext, str], Awaitable[T_co] | T_co] | Callable[[str], Awaitable[T_co] | T_co],
     type_params=(T_co,),
 )
 """Definition of a function that will be called to process the model's plain text output. The function must take a single string argument.
@@ -135,10 +132,9 @@ class NativeOutput(Generic[OutputDataT]):
 
     Example:
     ```python {title="native_output.py" requires="tool_output.py"}
-    from tool_output import Fruit, Vehicle
-
     from pydantic_ai import Agent, NativeOutput
 
+    from tool_output import Fruit, Vehicle
 
     agent = Agent(
         'openai:gpt-4o',
@@ -184,10 +180,11 @@ class PromptedOutput(Generic[OutputDataT]):
     Example:
     ```python {title="prompted_output.py" requires="tool_output.py"}
     from pydantic import BaseModel
-    from tool_output import Vehicle
 
     from pydantic_ai import Agent, PromptedOutput
 
+    from tool_output import Vehicle
+
 
     class Device(BaseModel):
         name: str
@@ -286,18 +283,17 @@ def StructuredDict(
     ```python {title="structured_dict.py"}
     from pydantic_ai import Agent, StructuredDict
 
-
     schema = {
-
-
-
-
+        'type': 'object',
+        'properties': {
+            'name': {'type': 'string'},
+            'age': {'type': 'integer'}
         },
-
+        'required': ['name', 'age']
     }
 
     agent = Agent('openai:gpt-4o', output_type=StructuredDict(schema))
-    result = agent.run_sync(
+    result = agent.run_sync('Create a person')
     print(result.output)
     #> {'name': 'John Doe', 'age': 30}
     ```
@@ -333,16 +329,13 @@ def StructuredDict(
 
 _OutputSpecItem = TypeAliasType(
     '_OutputSpecItem',
-    Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], NativeOutput[T_co], PromptedOutput[T_co], TextOutput[T_co]],
+    OutputTypeOrFunction[T_co] | ToolOutput[T_co] | NativeOutput[T_co] | PromptedOutput[T_co] | TextOutput[T_co],
     type_params=(T_co,),
 )
 
 OutputSpec = TypeAliasType(
     'OutputSpec',
-    Union[
-        _OutputSpecItem[T_co],
-        Sequence['OutputSpec[T_co]'],
-    ],
+    _OutputSpecItem[T_co] | Sequence['OutputSpec[T_co]'],
     type_params=(T_co,),
 )
 """Specification of the agent's output data.
@@ -359,12 +352,14 @@ See [output docs](../output.md) for more information.
 """
 
 
-@
-class DeferredToolCalls:
-
-
-
-
+@deprecated('`DeferredToolCalls` is deprecated, use `DeferredToolRequests` instead')
+class DeferredToolCalls(DeferredToolRequests):  # pragma: no cover
+    @property
+    @deprecated('`DeferredToolCalls.tool_calls` is deprecated, use `DeferredToolRequests.calls` instead')
+    def tool_calls(self) -> list[ToolCallPart]:
+        return self.calls
 
-
-    tool_defs
+    @property
+    @deprecated('`DeferredToolCalls.tool_defs` is deprecated')
+    def tool_defs(self) -> dict[str, ToolDefinition]:
+        return {}
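
The hunk above replaces the old `DeferredToolCalls` dataclass with a shim that subclasses `DeferredToolRequests` and emits `DeprecationWarning`s. A minimal migration sketch, assuming an agent whose toolset can actually defer calls; the model name and prompt are illustrative, and `DeferredToolRequests` is imported from `pydantic_ai.tools`, matching the import added in this diff:

from pydantic_ai import Agent
from pydantic_ai.tools import DeferredToolRequests

# Per the new UserError message in result.py below, DeferredToolRequests
# must be listed among the agent's output types for deferred tool calls
# to be surfaced as output instead of raising.
agent = Agent('openai:gpt-4o', output_type=[str, DeferredToolRequests])

result = agent.run_sync('Kick off the external tool')
if isinstance(result.output, DeferredToolRequests):
    # 0.8.1-era code reading `.tool_calls` still works through the shim,
    # but now raises a DeprecationWarning; `.calls` is the replacement.
    for call in result.output.calls:
        print(call.tool_name, call.tool_call_id)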
pydantic_ai/profiles/__init__.py
CHANGED
@@ -1,8 +1,8 @@
 from __future__ import annotations as _annotations
 
+from collections.abc import Callable
 from dataclasses import dataclass, fields, replace
 from textwrap import dedent
-from typing import Callable, Union
 
 from typing_extensions import Self
 
@@ -18,7 +18,7 @@ __all__ = [
 ]
 
 
-@dataclass
+@dataclass(kw_only=True)
 class ModelProfile:
     """Describes how requests to and responses from specific models or families of models need to be constructed and processed to get the best results, independent of the model and provider classes used."""
 
@@ -75,6 +75,6 @@ class ModelProfile:
         return replace(self, **non_default_attrs)
 
 
-ModelProfileSpec =
+ModelProfileSpec = ModelProfile | Callable[[str], ModelProfile | None]
 
 DEFAULT_PROFILE = ModelProfile()
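
`ModelProfile` becoming `@dataclass(kw_only=True)` means positional construction no longer works, and `ModelProfileSpec` is now spelled as a PEP 604 union. A small sketch of both, assuming default values for unspecified fields (`supports_json_schema_output` is a profile field referenced elsewhere in this diff):

from pydantic_ai.profiles import ModelProfile, ModelProfileSpec

# With kw_only=True, every field must be passed by keyword:
profile = ModelProfile(supports_json_schema_output=True)

# ModelProfileSpec = ModelProfile | Callable[[str], ModelProfile | None],
# so a spec may also be a function from model name to an optional profile:
def profile_for(model_name: str) -> ModelProfile | None:
    return profile if model_name.startswith('my-model') else None

spec: ModelProfileSpec = profile_for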
pydantic_ai/profiles/groq.py
CHANGED
pydantic_ai/profiles/openai.py
CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
 import re
+import warnings
 from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Any, Literal
@@ -11,7 +12,7 @@ from ._json_schema import JsonSchema, JsonSchemaTransformer
 OpenAISystemPromptRole = Literal['system', 'developer', 'user']
 
 
-@dataclass
+@dataclass(kw_only=True)
 class OpenAIModelProfile(ModelProfile):
     """Profile for models used with `OpenAIChatModel`.
 
@@ -21,7 +22,6 @@ class OpenAIModelProfile(ModelProfile):
     openai_supports_strict_tool_definition: bool = True
     """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions."""
 
-    # TODO(Marcelo): Deprecate this in favor of `openai_unsupported_model_settings`.
     openai_supports_sampling_settings: bool = True
     """Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""
 
@@ -38,6 +38,14 @@ class OpenAIModelProfile(ModelProfile):
     openai_system_prompt_role: OpenAISystemPromptRole | None = None
     """The role to use for the system prompt message. If not provided, defaults to `'system'`."""
 
+    def __post_init__(self):  # pragma: no cover
+        if not self.openai_supports_sampling_settings:
+            warnings.warn(
+                'The `openai_supports_sampling_settings` has no effect, and it will be removed in future versions. '
+                'Use `openai_unsupported_model_settings` instead.',
+                DeprecationWarning,
+            )
+
 
 def openai_model_profile(model_name: str) -> ModelProfile:
     """Get the model profile for an OpenAI model."""
@@ -46,6 +54,19 @@ def openai_model_profile(model_name: str) -> ModelProfile:
     # We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `native` is only used
     # when the user specifically uses the `NativeOutput` marker, so an error from the API is acceptable.
 
+    if is_reasoning_model:
+        openai_unsupported_model_settings = (
+            'temperature',
+            'top_p',
+            'presence_penalty',
+            'frequency_penalty',
+            'logit_bias',
+            'logprobs',
+            'top_logprobs',
+        )
+    else:
+        openai_unsupported_model_settings = ()
+
     # The o1-mini model doesn't support the `system` role, so we default to `user`.
     # See https://github.com/pydantic/pydantic-ai/issues/974 for more details.
     openai_system_prompt_role = 'user' if model_name.startswith('o1-mini') else None
@@ -54,7 +75,7 @@ def openai_model_profile(model_name: str) -> ModelProfile:
         json_schema_transformer=OpenAIJsonSchemaTransformer,
         supports_json_schema_output=True,
         supports_json_object_output=True,
-
+        openai_unsupported_model_settings=openai_unsupported_model_settings,
         openai_system_prompt_role=openai_system_prompt_role,
     )
 
@@ -89,7 +110,7 @@ _STRICT_COMPATIBLE_STRING_FORMATS = [
 _sentinel = object()
 
 
-@dataclass
+@dataclass(init=False)
 class OpenAIJsonSchemaTransformer(JsonSchemaTransformer):
     """Recursively handle the schema to make it compatible with OpenAI strict mode.
 
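
Two changes stand out here: reasoning models now receive an `openai_unsupported_model_settings` tuple of sampling settings to withhold, and flipping the legacy `openai_supports_sampling_settings` flag now warns via `__post_init__`. A short sketch of the deprecation path, assuming field defaults for everything else:

import warnings

from pydantic_ai.profiles.openai import OpenAIModelProfile

# The legacy flag still exists, but setting it to False now emits a
# DeprecationWarning pointing at openai_unsupported_model_settings:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    OpenAIModelProfile(openai_supports_sampling_settings=False)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)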
pydantic_ai/providers/anthropic.py
CHANGED
@@ -1,10 +1,9 @@
 from __future__ import annotations as _annotations
 
 import os
-from typing import
+from typing import TypeAlias, overload
 
 import httpx
-from typing_extensions import TypeAlias
 
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
@@ -21,7 +20,7 @@ except ImportError as _import_error:
     ) from _import_error
 
 
-AsyncAnthropicClient: TypeAlias =
+AsyncAnthropicClient: TypeAlias = AsyncAnthropic | AsyncAnthropicBedrock
 
 
 class AnthropicProvider(Provider[AsyncAnthropicClient]):
pydantic_ai/providers/bedrock.py
CHANGED
@@ -2,8 +2,9 @@ from __future__ import annotations as _annotations
 
 import os
 import re
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import
+from typing import Literal, overload
 
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.profiles import ModelProfile
@@ -27,7 +28,7 @@ except ImportError as _import_error:
     ) from _import_error
 
 
-@dataclass
+@dataclass(kw_only=True)
 class BedrockModelProfile(ModelProfile):
     """Profile for models used with BedrockModel.
 
pydantic_ai/result.py
CHANGED
@@ -1,16 +1,14 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Awaitable, Callable
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
 from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Generic, cast
+from typing import Generic, cast, overload
 
 from pydantic import ValidationError
 from typing_extensions import TypeVar, deprecated
 
-from pydantic_ai._tool_manager import ToolManager
-
 from . import _utils, exceptions, messages as _messages, models
 from ._output import (
     OutputDataT_inv,
@@ -22,11 +20,14 @@ from ._output import (
     ToolOutputSchema,
 )
 from ._run_context import AgentDepsT, RunContext
+from ._tool_manager import ToolManager
 from .messages import ModelResponseStreamEvent
 from .output import (
+    DeferredToolRequests,
     OutputDataT,
     ToolOutput,
 )
+from .run import AgentRunResult
 from .usage import RunUsage, UsageLimits
 
 __all__ = (
@@ -41,7 +42,7 @@ T = TypeVar('T')
 """An invariant TypeVar."""
 
 
-@dataclass
+@dataclass(kw_only=True)
 class AgentStream(Generic[AgentDepsT, OutputDataT]):
     _raw_stream_response: models.StreamedResponse
     _output_schema: OutputSchema[OutputDataT]
@@ -155,12 +156,12 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                 return await self._tool_manager.handle_call(
                     tool_call, allow_partial=allow_partial, wrap_validation_errors=False
                 )
-        elif
-            if not self._output_schema.
+        elif deferred_tool_requests := _get_deferred_tool_requests(message.parts, self._tool_manager):
+            if not self._output_schema.allows_deferred_tools:
                 raise exceptions.UserError(
-                    'A deferred tool call was present, but `
+                    'A deferred tool call was present, but `DeferredToolRequests` is not among output types. To resolve this, add `DeferredToolRequests` to the list of output types for this agent.'
                 )
-            return cast(OutputDataT,
+            return cast(OutputDataT, deferred_tool_requests)
         elif isinstance(self._output_schema, TextOutputSchema):
            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
 
@@ -233,15 +234,17 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         return self._agent_stream_iterator
 
 
-@dataclass
+@dataclass(init=False)
 class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     """Result of a streamed run that returns structured data via a tool call."""
 
     _all_messages: list[_messages.ModelMessage]
     _new_message_index: int
 
-    _stream_response: AgentStream[AgentDepsT, OutputDataT]
-    _on_complete: Callable[[], Awaitable[None]]
+    _stream_response: AgentStream[AgentDepsT, OutputDataT] | None = None
+    _on_complete: Callable[[], Awaitable[None]] | None = None
+
+    _run_result: AgentRunResult[OutputDataT] | None = None
 
     is_complete: bool = field(default=False, init=False)
     """Whether the stream has all been received.
@@ -253,6 +256,39 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     [`get_output`][pydantic_ai.result.StreamedRunResult.get_output] completes.
     """
 
+    @overload
+    def __init__(
+        self,
+        all_messages: list[_messages.ModelMessage],
+        new_message_index: int,
+        stream_response: AgentStream[AgentDepsT, OutputDataT] | None,
+        on_complete: Callable[[], Awaitable[None]] | None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        all_messages: list[_messages.ModelMessage],
+        new_message_index: int,
+        *,
+        run_result: AgentRunResult[OutputDataT],
+    ) -> None: ...
+
+    def __init__(
+        self,
+        all_messages: list[_messages.ModelMessage],
+        new_message_index: int,
+        stream_response: AgentStream[AgentDepsT, OutputDataT] | None = None,
+        on_complete: Callable[[], Awaitable[None]] | None = None,
+        run_result: AgentRunResult[OutputDataT] | None = None,
+    ) -> None:
+        self._all_messages = all_messages
+        self._new_message_index = new_message_index
+
+        self._stream_response = stream_response
+        self._on_complete = on_complete
+        self._run_result = run_result
+
     def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
         """Return the history of _messages.
 
@@ -340,9 +376,15 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         Returns:
             An async iterable of the response data.
         """
-        async for output in self._stream_response.stream_output(debounce_by=debounce_by):
-            yield output
-        await self._marked_completed(self._stream_response.get())
+        if self._run_result is not None:
+            yield self._run_result.output
+            await self._marked_completed()
+        elif self._stream_response is not None:
+            async for output in self._stream_response.stream_output(debounce_by=debounce_by):
+                yield output
+            await self._marked_completed(self._stream_response.get())
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
     async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
         """Stream the text result as an async iterable.
@@ -357,9 +399,20 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         Debouncing is particularly important for long structured responses to reduce the overhead of
         performing validation as each token is received.
         """
-        async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
-            yield text
-        await self._marked_completed(self._stream_response.get())
+        if self._run_result is not None:  # pragma: no cover
+            # We can't really get here, as `_run_result` is only set in `run_stream` when `CallToolsNode` produces `DeferredToolRequests` output
+            # as a result of a tool function raising `CallDeferred` or `ApprovalRequired`.
+            # That'll change if we ever support something like `raise EndRun(output: OutputT)` where `OutputT` could be `str`.
+            if not isinstance(self._run_result.output, str):
+                raise exceptions.UserError('stream_text() can only be used with text responses')
+            yield self._run_result.output
+            await self._marked_completed()
+        elif self._stream_response is not None:
+            async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+                yield text
+            await self._marked_completed(self._stream_response.get())
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
     @deprecated('`StreamedRunResult.stream_structured` is deprecated, use `stream_responses` instead.')
     async def stream_structured(
@@ -381,20 +434,34 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         Returns:
             An async iterable of the structured response message and whether that is the last message.
         """
-        # if the message currently has any parts with content, yield before streaming
-        async for msg in self._stream_response.stream_responses(debounce_by=debounce_by):
-            yield msg, False
-
-        msg = self._stream_response.get()
-        yield msg, True
-
-        await self._marked_completed(msg)
+        if self._run_result is not None:
+            model_response = cast(_messages.ModelResponse, self.all_messages()[-1])
+            yield model_response, True
+            await self._marked_completed()
+        elif self._stream_response is not None:
+            # if the message currently has any parts with content, yield before streaming
+            async for msg in self._stream_response.stream_responses(debounce_by=debounce_by):
+                yield msg, False
+
+            msg = self._stream_response.get()
+            yield msg, True
+
+            await self._marked_completed(msg)
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
     async def get_output(self) -> OutputDataT:
         """Stream the whole response, validate and return it."""
-        output = await self._stream_response.get_output()
-        await self._marked_completed(self._stream_response.get())
-        return output
+        if self._run_result is not None:
+            output = self._run_result.output
+            await self._marked_completed()
+            return output
+        elif self._stream_response is not None:
+            output = await self._stream_response.get_output()
+            await self._marked_completed(self._stream_response.get())
+            return output
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
     def usage(self) -> RunUsage:
         """Return the usage of the whole run.
@@ -402,28 +469,45 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         !!! note
             This won't return the full usage until the stream is finished.
         """
-        return self._stream_response.usage()
+        if self._run_result is not None:
+            return self._run_result.usage()
+        elif self._stream_response is not None:
+            return self._stream_response.usage()
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
-        return self._stream_response.timestamp()
+        if self._run_result is not None:
+            return self._run_result.timestamp()
+        elif self._stream_response is not None:
+            return self._stream_response.timestamp()
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
     @deprecated('`validate_structured_output` is deprecated, use `validate_response_output` instead.')
     async def validate_structured_output(
         self, message: _messages.ModelResponse, *, allow_partial: bool = False
     ) -> OutputDataT:
-        return await self.
+        return await self.validate_response_output(message, allow_partial=allow_partial)
 
     async def validate_response_output(
         self, message: _messages.ModelResponse, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
-        return await self._stream_response.validate_response_output(message, allow_partial=allow_partial)
+        if self._run_result is not None:
+            return self._run_result.output
+        elif self._stream_response is not None:
+            return await self._stream_response.validate_response_output(message, allow_partial=allow_partial)
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
-    async def _marked_completed(self, message: _messages.ModelResponse) -> None:
+    async def _marked_completed(self, message: _messages.ModelResponse | None = None) -> None:
         self.is_complete = True
-        self._all_messages.append(message)
-        await self._on_complete()
+        if message is not None:
+            self._all_messages.append(message)
+        if self._on_complete is not None:
+            await self._on_complete()
 
 
 @dataclass(repr=False)
@@ -432,8 +516,10 @@ class FinalResult(Generic[OutputDataT]):
 
     output: OutputDataT
     """The final result data."""
+
    tool_name: str | None = None
     """Name of the final output tool; `None` if the output came from unstructured text content."""
+
     tool_call_id: str | None = None
     """ID of the tool call that produced the final output; `None` if the output came from unstructured text content."""
 
@@ -454,9 +540,26 @@ def _get_usage_checking_stream_response(
 
         return _usage_checking_iterator()
     else:
-
-
-
-
+        return aiter(stream_response)
+
+
+def _get_deferred_tool_requests(
+    parts: Iterable[_messages.ModelResponsePart], tool_manager: ToolManager[AgentDepsT]
+) -> DeferredToolRequests | None:
+    """Get the deferred tool requests from the model response parts."""
+    approvals: list[_messages.ToolCallPart] = []
+    calls: list[_messages.ToolCallPart] = []
+
+    for part in parts:
+        if isinstance(part, _messages.ToolCallPart):
+            tool_def = tool_manager.get_tool_def(part.tool_name)
+            if tool_def is not None:  # pragma: no branch
+                if tool_def.kind == 'unapproved':
+                    approvals.append(part)
+                elif tool_def.kind == 'external':
+                    calls.append(part)
+
+    if not calls and not approvals:
+        return None
 
-
+    return DeferredToolRequests(calls=calls, approvals=approvals)