pydantic-ai-slim 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_a2a.py +6 -4
- pydantic_ai/_agent_graph.py +25 -32
- pydantic_ai/_cli.py +3 -3
- pydantic_ai/_output.py +8 -0
- pydantic_ai/_tool_manager.py +3 -0
- pydantic_ai/ag_ui.py +25 -14
- pydantic_ai/{agent.py → agent/__init__.py} +209 -1027
- pydantic_ai/agent/abstract.py +942 -0
- pydantic_ai/agent/wrapper.py +227 -0
- pydantic_ai/direct.py +9 -9
- pydantic_ai/durable_exec/__init__.py +0 -0
- pydantic_ai/durable_exec/temporal/__init__.py +83 -0
- pydantic_ai/durable_exec/temporal/_agent.py +699 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
- pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
- pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
- pydantic_ai/durable_exec/temporal/_model.py +168 -0
- pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
- pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
- pydantic_ai/ext/aci.py +10 -9
- pydantic_ai/ext/langchain.py +4 -2
- pydantic_ai/mcp.py +203 -75
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +65 -9
- pydantic_ai/models/anthropic.py +16 -7
- pydantic_ai/models/bedrock.py +8 -5
- pydantic_ai/models/cohere.py +1 -4
- pydantic_ai/models/fallback.py +4 -2
- pydantic_ai/models/function.py +9 -4
- pydantic_ai/models/gemini.py +15 -9
- pydantic_ai/models/google.py +18 -14
- pydantic_ai/models/groq.py +17 -14
- pydantic_ai/models/huggingface.py +18 -12
- pydantic_ai/models/instrumented.py +3 -1
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +12 -18
- pydantic_ai/models/openai.py +29 -26
- pydantic_ai/models/test.py +3 -0
- pydantic_ai/models/wrapper.py +6 -2
- pydantic_ai/profiles/openai.py +1 -1
- pydantic_ai/providers/google.py +7 -7
- pydantic_ai/result.py +21 -55
- pydantic_ai/run.py +357 -0
- pydantic_ai/tools.py +0 -1
- pydantic_ai/toolsets/__init__.py +2 -0
- pydantic_ai/toolsets/_dynamic.py +87 -0
- pydantic_ai/toolsets/abstract.py +23 -3
- pydantic_ai/toolsets/combined.py +19 -4
- pydantic_ai/toolsets/deferred.py +10 -2
- pydantic_ai/toolsets/function.py +23 -8
- pydantic_ai/toolsets/prefixed.py +4 -0
- pydantic_ai/toolsets/wrapper.py +14 -1
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +6 -4
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/RECORD +57 -44
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/openai.py
CHANGED
|
@@ -11,13 +11,13 @@ from typing import Any, Literal, Union, cast, overload
|
|
|
11
11
|
from pydantic import ValidationError
|
|
12
12
|
from typing_extensions import assert_never
|
|
13
13
|
|
|
14
|
-
from pydantic_ai.exceptions import UserError
|
|
15
|
-
|
|
16
14
|
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
|
|
17
15
|
from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
|
|
16
|
+
from .._run_context import RunContext
|
|
18
17
|
from .._thinking_part import split_content_into_text_and_thinking
|
|
19
18
|
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
|
|
20
19
|
from ..builtin_tools import CodeExecutionTool, WebSearchTool
|
|
20
|
+
from ..exceptions import UserError
|
|
21
21
|
from ..messages import (
|
|
22
22
|
AudioUrl,
|
|
23
23
|
BinaryContent,
|
|
@@ -256,13 +256,14 @@ class OpenAIModel(Model):
|
|
|
256
256
|
messages: list[ModelMessage],
|
|
257
257
|
model_settings: ModelSettings | None,
|
|
258
258
|
model_request_parameters: ModelRequestParameters,
|
|
259
|
+
run_context: RunContext[Any] | None = None,
|
|
259
260
|
) -> AsyncIterator[StreamedResponse]:
|
|
260
261
|
check_allow_model_requests()
|
|
261
262
|
response = await self._completions_create(
|
|
262
263
|
messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
|
|
263
264
|
)
|
|
264
265
|
async with response:
|
|
265
|
-
yield await self._process_streamed_response(response)
|
|
266
|
+
yield await self._process_streamed_response(response, model_request_parameters)
|
|
266
267
|
|
|
267
268
|
@property
|
|
268
269
|
def model_name(self) -> OpenAIModelName:
|
|
@@ -427,7 +428,9 @@ class OpenAIModel(Model):
|
|
|
427
428
|
vendor_id=response.id,
|
|
428
429
|
)
|
|
429
430
|
|
|
430
|
-
async def _process_streamed_response(
|
|
431
|
+
async def _process_streamed_response(
|
|
432
|
+
self, response: AsyncStream[ChatCompletionChunk], model_request_parameters: ModelRequestParameters
|
|
433
|
+
) -> OpenAIStreamedResponse:
|
|
431
434
|
"""Process a streamed response, and prepare a streaming response to return."""
|
|
432
435
|
peekable_response = _utils.PeekableAsyncStream(response)
|
|
433
436
|
first_chunk = await peekable_response.peek()
|
|
@@ -437,6 +440,7 @@ class OpenAIModel(Model):
|
|
|
437
440
|
)
|
|
438
441
|
|
|
439
442
|
return OpenAIStreamedResponse(
|
|
443
|
+
model_request_parameters=model_request_parameters,
|
|
440
444
|
_model_name=self._model_name,
|
|
441
445
|
_model_profile=self.profile,
|
|
442
446
|
_response=peekable_response,
|
|
@@ -444,10 +448,7 @@ class OpenAIModel(Model):
|
|
|
444
448
|
)
|
|
445
449
|
|
|
446
450
|
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
|
|
447
|
-
|
|
448
|
-
if model_request_parameters.output_tools:
|
|
449
|
-
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
|
|
450
|
-
return tools
|
|
451
|
+
return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
|
|
451
452
|
|
|
452
453
|
def _get_web_search_options(self, model_request_parameters: ModelRequestParameters) -> WebSearchOptions | None:
|
|
453
454
|
for tool in model_request_parameters.builtin_tools:
|
|
@@ -461,8 +462,10 @@ class OpenAIModel(Model):
|
|
|
461
462
|
),
|
|
462
463
|
)
|
|
463
464
|
return WebSearchOptions(search_context_size=tool.search_context_size)
|
|
464
|
-
|
|
465
|
-
raise UserError(
|
|
465
|
+
else:
|
|
466
|
+
raise UserError(
|
|
467
|
+
f'`{tool.__class__.__name__}` is not supported by `OpenAIModel`. If it should be, please file an issue.'
|
|
468
|
+
)
|
|
466
469
|
|
|
467
470
|
async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
|
|
468
471
|
"""Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
|
|
@@ -631,14 +634,6 @@ class OpenAIResponsesModel(Model):
|
|
|
631
634
|
The [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) is the
|
|
632
635
|
new API for OpenAI models.
|
|
633
636
|
|
|
634
|
-
The Responses API has built-in tools, that you can use instead of building your own:
|
|
635
|
-
|
|
636
|
-
- [Web search](https://platform.openai.com/docs/guides/tools-web-search)
|
|
637
|
-
- [File search](https://platform.openai.com/docs/guides/tools-file-search)
|
|
638
|
-
- [Computer use](https://platform.openai.com/docs/guides/tools-computer-use)
|
|
639
|
-
|
|
640
|
-
Use the `openai_builtin_tools` setting to add these tools to your model.
|
|
641
|
-
|
|
642
637
|
If you are interested in the differences between the Responses API and the Chat Completions API,
|
|
643
638
|
see the [OpenAI API docs](https://platform.openai.com/docs/guides/responses-vs-chat-completions).
|
|
644
639
|
"""
|
|
@@ -702,13 +697,14 @@ class OpenAIResponsesModel(Model):
|
|
|
702
697
|
messages: list[ModelMessage],
|
|
703
698
|
model_settings: ModelSettings | None,
|
|
704
699
|
model_request_parameters: ModelRequestParameters,
|
|
700
|
+
run_context: RunContext[Any] | None = None,
|
|
705
701
|
) -> AsyncIterator[StreamedResponse]:
|
|
706
702
|
check_allow_model_requests()
|
|
707
703
|
response = await self._responses_create(
|
|
708
704
|
messages, True, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
|
|
709
705
|
)
|
|
710
706
|
async with response:
|
|
711
|
-
yield await self._process_streamed_response(response)
|
|
707
|
+
yield await self._process_streamed_response(response, model_request_parameters)
|
|
712
708
|
|
|
713
709
|
def _process_response(self, response: responses.Response) -> ModelResponse:
|
|
714
710
|
"""Process a non-streamed response, and prepare a message to return."""
|
|
@@ -735,7 +731,9 @@ class OpenAIResponsesModel(Model):
|
|
|
735
731
|
)
|
|
736
732
|
|
|
737
733
|
async def _process_streamed_response(
|
|
738
|
-
self,
|
|
734
|
+
self,
|
|
735
|
+
response: AsyncStream[responses.ResponseStreamEvent],
|
|
736
|
+
model_request_parameters: ModelRequestParameters,
|
|
739
737
|
) -> OpenAIResponsesStreamedResponse:
|
|
740
738
|
"""Process a streamed response, and prepare a streaming response to return."""
|
|
741
739
|
peekable_response = _utils.PeekableAsyncStream(response)
|
|
@@ -745,6 +743,7 @@ class OpenAIResponsesModel(Model):
|
|
|
745
743
|
|
|
746
744
|
assert isinstance(first_chunk, responses.ResponseCreatedEvent)
|
|
747
745
|
return OpenAIResponsesStreamedResponse(
|
|
746
|
+
model_request_parameters=model_request_parameters,
|
|
748
747
|
_model_name=self._model_name,
|
|
749
748
|
_response=peekable_response,
|
|
750
749
|
_timestamp=number_to_datetime(first_chunk.response.created_at),
|
|
@@ -775,8 +774,11 @@ class OpenAIResponsesModel(Model):
|
|
|
775
774
|
model_settings: OpenAIResponsesModelSettings,
|
|
776
775
|
model_request_parameters: ModelRequestParameters,
|
|
777
776
|
) -> responses.Response | AsyncStream[responses.ResponseStreamEvent]:
|
|
778
|
-
tools =
|
|
779
|
-
|
|
777
|
+
tools = (
|
|
778
|
+
self._get_builtin_tools(model_request_parameters)
|
|
779
|
+
+ list(model_settings.get('openai_builtin_tools', []))
|
|
780
|
+
+ self._get_tools(model_request_parameters)
|
|
781
|
+
)
|
|
780
782
|
|
|
781
783
|
if not tools:
|
|
782
784
|
tool_choice: Literal['none', 'required', 'auto'] | None = None
|
|
@@ -859,10 +861,7 @@ class OpenAIResponsesModel(Model):
|
|
|
859
861
|
return Reasoning(effort=reasoning_effort, summary=reasoning_summary)
|
|
860
862
|
|
|
861
863
|
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]:
|
|
862
|
-
|
|
863
|
-
if model_request_parameters.output_tools:
|
|
864
|
-
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
|
|
865
|
-
return tools
|
|
864
|
+
return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
|
|
866
865
|
|
|
867
866
|
def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.ToolParam]:
|
|
868
867
|
tools: list[responses.ToolParam] = []
|
|
@@ -878,6 +877,10 @@ class OpenAIResponsesModel(Model):
|
|
|
878
877
|
tools.append(web_search_tool)
|
|
879
878
|
elif isinstance(tool, CodeExecutionTool): # pragma: no branch
|
|
880
879
|
tools.append({'type': 'code_interpreter', 'container': {'type': 'auto'}})
|
|
880
|
+
else:
|
|
881
|
+
raise UserError( # pragma: no cover
|
|
882
|
+
f'`{tool.__class__.__name__}` is not supported by `OpenAIResponsesModel`. If it should be, please file an issue.'
|
|
883
|
+
)
|
|
881
884
|
return tools
|
|
882
885
|
|
|
883
886
|
def _map_tool_definition(self, f: ToolDefinition) -> responses.FunctionToolParam:
|
pydantic_ai/models/test.py
CHANGED
|
@@ -12,6 +12,7 @@ import pydantic_core
|
|
|
12
12
|
from typing_extensions import assert_never
|
|
13
13
|
|
|
14
14
|
from .. import _utils
|
|
15
|
+
from .._run_context import RunContext
|
|
15
16
|
from ..exceptions import UserError
|
|
16
17
|
from ..messages import (
|
|
17
18
|
BuiltinToolCallPart,
|
|
@@ -121,11 +122,13 @@ class TestModel(Model):
|
|
|
121
122
|
messages: list[ModelMessage],
|
|
122
123
|
model_settings: ModelSettings | None,
|
|
123
124
|
model_request_parameters: ModelRequestParameters,
|
|
125
|
+
run_context: RunContext[Any] | None = None,
|
|
124
126
|
) -> AsyncIterator[StreamedResponse]:
|
|
125
127
|
self.last_model_request_parameters = model_request_parameters
|
|
126
128
|
|
|
127
129
|
model_response = self._request(messages, model_settings, model_request_parameters)
|
|
128
130
|
yield TestStreamedResponse(
|
|
131
|
+
model_request_parameters=model_request_parameters,
|
|
129
132
|
_model_name=self._model_name,
|
|
130
133
|
_structured_response=model_response,
|
|
131
134
|
_messages=messages,
|
pydantic_ai/models/wrapper.py
CHANGED
|
@@ -6,6 +6,7 @@ from dataclasses import dataclass
|
|
|
6
6
|
from functools import cached_property
|
|
7
7
|
from typing import Any
|
|
8
8
|
|
|
9
|
+
from .._run_context import RunContext
|
|
9
10
|
from ..messages import ModelMessage, ModelResponse
|
|
10
11
|
from ..profiles import ModelProfile
|
|
11
12
|
from ..settings import ModelSettings
|
|
@@ -35,8 +36,11 @@ class WrapperModel(Model):
|
|
|
35
36
|
messages: list[ModelMessage],
|
|
36
37
|
model_settings: ModelSettings | None,
|
|
37
38
|
model_request_parameters: ModelRequestParameters,
|
|
39
|
+
run_context: RunContext[Any] | None = None,
|
|
38
40
|
) -> AsyncIterator[StreamedResponse]:
|
|
39
|
-
async with self.wrapped.request_stream(
|
|
41
|
+
async with self.wrapped.request_stream(
|
|
42
|
+
messages, model_settings, model_request_parameters, run_context
|
|
43
|
+
) as response_stream:
|
|
40
44
|
yield response_stream
|
|
41
45
|
|
|
42
46
|
def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
|
|
@@ -60,4 +64,4 @@ class WrapperModel(Model):
|
|
|
60
64
|
return self.wrapped.settings
|
|
61
65
|
|
|
62
66
|
def __getattr__(self, item: str):
|
|
63
|
-
return getattr(self.wrapped, item)
|
|
67
|
+
return getattr(self.wrapped, item)
|
pydantic_ai/profiles/openai.py
CHANGED
|
@@ -32,7 +32,7 @@ class OpenAIModelProfile(ModelProfile):
|
|
|
32
32
|
|
|
33
33
|
def openai_model_profile(model_name: str) -> ModelProfile:
|
|
34
34
|
"""Get the model profile for an OpenAI model."""
|
|
35
|
-
is_reasoning_model = model_name.startswith('o')
|
|
35
|
+
is_reasoning_model = model_name.startswith('o') or model_name.startswith('gpt-5')
|
|
36
36
|
# Structured Outputs (output mode 'native') is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later.
|
|
37
37
|
# We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `native` is only used
|
|
38
38
|
# when the user specifically uses the `NativeOutput` marker, so an error from the API is acceptable.
|
pydantic_ai/providers/google.py
CHANGED
|
@@ -12,8 +12,8 @@ from pydantic_ai.profiles.google import google_model_profile
|
|
|
12
12
|
from pydantic_ai.providers import Provider
|
|
13
13
|
|
|
14
14
|
try:
|
|
15
|
-
from google import genai
|
|
16
15
|
from google.auth.credentials import Credentials
|
|
16
|
+
from google.genai import Client
|
|
17
17
|
from google.genai.types import HttpOptionsDict
|
|
18
18
|
except ImportError as _import_error:
|
|
19
19
|
raise ImportError(
|
|
@@ -22,7 +22,7 @@ except ImportError as _import_error:
|
|
|
22
22
|
) from _import_error
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
class GoogleProvider(Provider[
|
|
25
|
+
class GoogleProvider(Provider[Client]):
|
|
26
26
|
"""Provider for Google."""
|
|
27
27
|
|
|
28
28
|
@property
|
|
@@ -34,7 +34,7 @@ class GoogleProvider(Provider[genai.Client]):
|
|
|
34
34
|
return str(self._client._api_client._http_options.base_url) # type: ignore[reportPrivateUsage]
|
|
35
35
|
|
|
36
36
|
@property
|
|
37
|
-
def client(self) ->
|
|
37
|
+
def client(self) -> Client:
|
|
38
38
|
return self._client
|
|
39
39
|
|
|
40
40
|
def model_profile(self, model_name: str) -> ModelProfile | None:
|
|
@@ -53,7 +53,7 @@ class GoogleProvider(Provider[genai.Client]):
|
|
|
53
53
|
) -> None: ...
|
|
54
54
|
|
|
55
55
|
@overload
|
|
56
|
-
def __init__(self, *, client:
|
|
56
|
+
def __init__(self, *, client: Client) -> None: ...
|
|
57
57
|
|
|
58
58
|
@overload
|
|
59
59
|
def __init__(self, *, vertexai: bool = False) -> None: ...
|
|
@@ -65,7 +65,7 @@ class GoogleProvider(Provider[genai.Client]):
|
|
|
65
65
|
credentials: Credentials | None = None,
|
|
66
66
|
project: str | None = None,
|
|
67
67
|
location: VertexAILocation | Literal['global'] | None = None,
|
|
68
|
-
client:
|
|
68
|
+
client: Client | None = None,
|
|
69
69
|
vertexai: bool | None = None,
|
|
70
70
|
) -> None:
|
|
71
71
|
"""Create a new Google provider.
|
|
@@ -102,9 +102,9 @@ class GoogleProvider(Provider[genai.Client]):
|
|
|
102
102
|
'Set the `GOOGLE_API_KEY` environment variable or pass it via `GoogleProvider(api_key=...)`'
|
|
103
103
|
'to use the Google Generative Language API.'
|
|
104
104
|
)
|
|
105
|
-
self._client =
|
|
105
|
+
self._client = Client(vertexai=vertexai, api_key=api_key, http_options=http_options)
|
|
106
106
|
else:
|
|
107
|
-
self._client =
|
|
107
|
+
self._client = Client(
|
|
108
108
|
vertexai=vertexai,
|
|
109
109
|
project=project or os.environ.get('GOOGLE_CLOUD_PROJECT'),
|
|
110
110
|
# From https://github.com/pydantic/pydantic-ai/pull/2031/files#r2169682149:
|
pydantic_ai/result.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import
|
|
3
|
+
from collections.abc import AsyncIterator, Awaitable, Callable
|
|
4
4
|
from copy import copy
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
6
|
from datetime import datetime
|
|
@@ -22,7 +22,7 @@ from ._output import (
|
|
|
22
22
|
ToolOutputSchema,
|
|
23
23
|
)
|
|
24
24
|
from ._run_context import AgentDepsT, RunContext
|
|
25
|
-
from .messages import AgentStreamEvent
|
|
25
|
+
from .messages import AgentStreamEvent
|
|
26
26
|
from .output import (
|
|
27
27
|
OutputDataT,
|
|
28
28
|
ToolOutput,
|
|
@@ -45,13 +45,13 @@ T = TypeVar('T')
|
|
|
45
45
|
class AgentStream(Generic[AgentDepsT, OutputDataT]):
|
|
46
46
|
_raw_stream_response: models.StreamedResponse
|
|
47
47
|
_output_schema: OutputSchema[OutputDataT]
|
|
48
|
+
_model_request_parameters: models.ModelRequestParameters
|
|
48
49
|
_output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
|
|
49
50
|
_run_ctx: RunContext[AgentDepsT]
|
|
50
51
|
_usage_limits: UsageLimits | None
|
|
51
52
|
_tool_manager: ToolManager[AgentDepsT]
|
|
52
53
|
|
|
53
54
|
_agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
|
|
54
|
-
_final_result_event: FinalResultEvent | None = field(default=None, init=False)
|
|
55
55
|
_initial_run_ctx_usage: Usage = field(init=False)
|
|
56
56
|
|
|
57
57
|
def __post_init__(self):
|
|
@@ -60,12 +60,12 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
|
|
|
60
60
|
async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[OutputDataT]:
|
|
61
61
|
"""Asynchronously stream the (validated) agent outputs."""
|
|
62
62
|
async for response in self.stream_responses(debounce_by=debounce_by):
|
|
63
|
-
if self.
|
|
63
|
+
if self._raw_stream_response.final_result_event is not None:
|
|
64
64
|
try:
|
|
65
65
|
yield await self._validate_response(response, allow_partial=True)
|
|
66
66
|
except ValidationError:
|
|
67
67
|
pass
|
|
68
|
-
if self.
|
|
68
|
+
if self._raw_stream_response.final_result_event is not None: # pragma: no branch
|
|
69
69
|
yield await self._validate_response(self._raw_stream_response.get())
|
|
70
70
|
|
|
71
71
|
async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
|
|
@@ -131,10 +131,11 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
|
|
|
131
131
|
|
|
132
132
|
async def _validate_response(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
|
|
133
133
|
"""Validate a structured result message."""
|
|
134
|
-
|
|
134
|
+
final_result_event = self._raw_stream_response.final_result_event
|
|
135
|
+
if final_result_event is None:
|
|
135
136
|
raise exceptions.UnexpectedModelBehavior('Invalid response, unable to find output') # pragma: no cover
|
|
136
137
|
|
|
137
|
-
output_tool_name =
|
|
138
|
+
output_tool_name = final_result_event.tool_name
|
|
138
139
|
|
|
139
140
|
if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None:
|
|
140
141
|
tool_call = next(
|
|
@@ -195,7 +196,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
|
|
|
195
196
|
and isinstance(event.part, _messages.TextPart)
|
|
196
197
|
and event.part.content
|
|
197
198
|
):
|
|
198
|
-
yield event.part.content, event.index
|
|
199
|
+
yield event.part.content, event.index
|
|
199
200
|
elif ( # pragma: no branch
|
|
200
201
|
isinstance(event, _messages.PartDeltaEvent)
|
|
201
202
|
and isinstance(event.delta, _messages.TextPartDelta)
|
|
@@ -221,52 +222,12 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
|
|
|
221
222
|
yield ''.join(deltas)
|
|
222
223
|
|
|
223
224
|
def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
|
|
224
|
-
"""Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the
|
|
228
|
-
first match is found.
|
|
229
|
-
"""
|
|
230
|
-
if self._agent_stream_iterator is not None:
|
|
231
|
-
return self._agent_stream_iterator
|
|
232
|
-
|
|
233
|
-
async def aiter():
|
|
234
|
-
output_schema = self._output_schema
|
|
235
|
-
|
|
236
|
-
def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None:
|
|
237
|
-
"""Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
|
|
238
|
-
if isinstance(e, _messages.PartStartEvent):
|
|
239
|
-
new_part = e.part
|
|
240
|
-
if isinstance(new_part, _messages.TextPart) and isinstance(
|
|
241
|
-
output_schema, TextOutputSchema
|
|
242
|
-
): # pragma: no branch
|
|
243
|
-
return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
|
|
244
|
-
elif isinstance(new_part, _messages.ToolCallPart) and (
|
|
245
|
-
tool_def := self._tool_manager.get_tool_def(new_part.tool_name)
|
|
246
|
-
):
|
|
247
|
-
if tool_def.kind == 'output':
|
|
248
|
-
return _messages.FinalResultEvent(
|
|
249
|
-
tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id
|
|
250
|
-
)
|
|
251
|
-
elif tool_def.kind == 'deferred':
|
|
252
|
-
return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
|
|
253
|
-
|
|
254
|
-
usage_checking_stream = _get_usage_checking_stream_response(
|
|
225
|
+
"""Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s."""
|
|
226
|
+
if self._agent_stream_iterator is None:
|
|
227
|
+
self._agent_stream_iterator = _get_usage_checking_stream_response(
|
|
255
228
|
self._raw_stream_response, self._usage_limits, self.usage
|
|
256
229
|
)
|
|
257
|
-
|
|
258
|
-
yield event
|
|
259
|
-
if (final_result_event := _get_final_result_event(event)) is not None:
|
|
260
|
-
self._final_result_event = final_result_event
|
|
261
|
-
yield final_result_event
|
|
262
|
-
break
|
|
263
|
-
|
|
264
|
-
# If we broke out of the above loop, we need to yield the rest of the events
|
|
265
|
-
# If we didn't, this will just be a no-op
|
|
266
|
-
async for event in usage_checking_stream:
|
|
267
|
-
yield event
|
|
268
|
-
|
|
269
|
-
self._agent_stream_iterator = aiter()
|
|
230
|
+
|
|
270
231
|
return self._agent_stream_iterator
|
|
271
232
|
|
|
272
233
|
|
|
@@ -462,10 +423,10 @@ class FinalResult(Generic[OutputDataT]):
|
|
|
462
423
|
|
|
463
424
|
|
|
464
425
|
def _get_usage_checking_stream_response(
|
|
465
|
-
stream_response:
|
|
426
|
+
stream_response: models.StreamedResponse,
|
|
466
427
|
limits: UsageLimits | None,
|
|
467
428
|
get_usage: Callable[[], Usage],
|
|
468
|
-
) ->
|
|
429
|
+
) -> AsyncIterator[AgentStreamEvent]:
|
|
469
430
|
if limits is not None and limits.has_token_limits():
|
|
470
431
|
|
|
471
432
|
async def _usage_checking_iterator():
|
|
@@ -475,4 +436,9 @@ def _get_usage_checking_stream_response(
|
|
|
475
436
|
|
|
476
437
|
return _usage_checking_iterator()
|
|
477
438
|
else:
|
|
478
|
-
return stream_response
|
|
439
|
+
# TODO: Use `return aiter(stream_response)` once we drop support for Python 3.9
|
|
440
|
+
async def _iterator():
|
|
441
|
+
async for item in stream_response:
|
|
442
|
+
yield item
|
|
443
|
+
|
|
444
|
+
return _iterator()
|