pydantic-ai-slim 1.9.1__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic; see the registry's release page for details.
- pydantic_ai/_run_context.py +8 -2
- pydantic_ai/_tool_manager.py +1 -0
- pydantic_ai/_utils.py +18 -0
- pydantic_ai/agent/__init__.py +13 -3
- pydantic_ai/agent/abstract.py +155 -3
- pydantic_ai/agent/wrapper.py +5 -0
- pydantic_ai/durable_exec/dbos/_agent.py +28 -0
- pydantic_ai/durable_exec/prefect/_agent.py +25 -0
- pydantic_ai/durable_exec/temporal/_agent.py +25 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -73
- pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
- pydantic_ai/durable_exec/temporal/_run_context.py +9 -3
- pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
- pydantic_ai/messages.py +10 -1
- pydantic_ai/models/openai.py +4 -0
- pydantic_ai/profiles/openai.py +5 -2
- pydantic_ai/providers/openrouter.py +3 -0
- pydantic_ai/result.py +159 -4
- pydantic_ai/tools.py +10 -6
- pydantic_ai/ui/_event_stream.py +4 -4
- pydantic_ai/ui/ag_ui/_event_stream.py +11 -2
- {pydantic_ai_slim-1.9.1.dist-info → pydantic_ai_slim-1.11.0.dist-info}/METADATA +8 -6
- {pydantic_ai_slim-1.9.1.dist-info → pydantic_ai_slim-1.11.0.dist-info}/RECORD +26 -26
- {pydantic_ai_slim-1.9.1.dist-info → pydantic_ai_slim-1.11.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.9.1.dist-info → pydantic_ai_slim-1.11.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.9.1.dist-info → pydantic_ai_slim-1.11.0.dist-info}/licenses/LICENSE +0 -0

pydantic_ai/durable_exec/temporal/_function_toolset.py
CHANGED

@@ -1,57 +1,22 @@
 from __future__ import annotations
 
 from collections.abc import Callable
-from
-from typing import Annotated, Any, Literal
+from typing import Any, Literal
 
-from pydantic import ConfigDict, Discriminator, with_config
 from temporalio import activity, workflow
 from temporalio.workflow import ActivityConfig
-from typing_extensions import assert_never
 
 from pydantic_ai import FunctionToolset, ToolsetTool
-from pydantic_ai.exceptions import
+from pydantic_ai.exceptions import UserError
 from pydantic_ai.tools import AgentDepsT, RunContext
 from pydantic_ai.toolsets.function import FunctionToolsetTool
 
 from ._run_context import TemporalRunContext
-from ._toolset import
-
-
-
-
-class _CallToolParams:
-    name: str
-    tool_args: dict[str, Any]
-    serialized_run_context: Any
-
-
-@dataclass
-class _ApprovalRequired:
-    kind: Literal['approval_required'] = 'approval_required'
-
-
-@dataclass
-class _CallDeferred:
-    kind: Literal['call_deferred'] = 'call_deferred'
-
-
-@dataclass
-class _ModelRetry:
-    message: str
-    kind: Literal['model_retry'] = 'model_retry'
-
-
-@dataclass
-class _ToolReturn:
-    result: Any
-    kind: Literal['tool_return'] = 'tool_return'
-
-
-_CallToolResult = Annotated[
-    _ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
-    Discriminator('kind'),
-]
+from ._toolset import (
+    CallToolParams,
+    CallToolResult,
+    TemporalWrapperToolset,
+)
 
 
 class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):

@@ -70,7 +35,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
         self.tool_activity_config = tool_activity_config
         self.run_context_type = run_context_type
 
-        async def call_tool_activity(params:
+        async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
             name = params.name
             ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
             try:

@@ -84,15 +49,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
             # The tool args will already have been validated into their proper types in the `ToolManager`,
             # but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
             args_dict = tool.args_validator.validate_python(params.tool_args)
-
-                result = await self.wrapped.call_tool(name, args_dict, ctx, tool)
-                return _ToolReturn(result=result)
-            except ApprovalRequired:
-                return _ApprovalRequired()
-            except CallDeferred:
-                return _CallDeferred()
-            except ModelRetry as e:
-                return _ModelRetry(message=e.message)
+            return await self._wrap_call_tool_result(self.wrapped.call_tool(name, args_dict, ctx, tool))
 
         # Set type hint explicitly so that Temporal can take care of serialization and deserialization
         call_tool_activity.__annotations__['deps'] = deps_type

@@ -123,25 +80,18 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
 
         tool_activity_config = self.activity_config | tool_activity_config
         serialized_run_context = self.run_context_type.serialize_run_context(ctx)
-
-
-
-
-
-
-
-
-
-
-
+        return self._unwrap_call_tool_result(
+            await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
+                activity=self.call_tool_activity,
+                args=[
+                    CallToolParams(
+                        name=name,
+                        tool_args=tool_args,
+                        serialized_run_context=serialized_run_context,
+                        tool_def=None,
+                    ),
+                    ctx.deps,
+                ],
+                **tool_activity_config,
+            )
         )
-        if isinstance(result, _ApprovalRequired):
-            raise ApprovalRequired()
-        elif isinstance(result, _CallDeferred):
-            raise CallDeferred()
-        elif isinstance(result, _ModelRetry):
-            raise ModelRetry(result.message)
-        elif isinstance(result, _ToolReturn):
-            return result.result
-        else:
-            assert_never(result)
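The surviving comment above explains why tool arguments are re-validated after `execute_activity`: Temporal's payload conversion hands values back as plain Python types. A self-contained sketch of that re-validation step using a pydantic `TypeAdapter` (illustrative names, not pydantic-ai internals):

```python
from datetime import date

from pydantic import TypeAdapter
from typing_extensions import TypedDict


class BookingArgs(TypedDict):
    guests: int
    check_in: date


args_validator = TypeAdapter(BookingArgs)

# What a serialization boundary hands back: plain strings instead of typed values.
raw_args = {'guests': '2', 'check_in': '2025-01-01'}

args = args_validator.validate_python(raw_args)
print(args)  # {'guests': 2, 'check_in': datetime.date(2025, 1, 1)}
```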
pydantic_ai/durable_exec/temporal/_mcp_server.py
CHANGED

@@ -11,11 +11,15 @@ from typing_extensions import Self
 
 from pydantic_ai import ToolsetTool
 from pydantic_ai.exceptions import UserError
-from pydantic_ai.mcp import MCPServer
+from pydantic_ai.mcp import MCPServer
 from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
 
 from ._run_context import TemporalRunContext
-from ._toolset import
+from ._toolset import (
+    CallToolParams,
+    CallToolResult,
+    TemporalWrapperToolset,
+)
 
 
 @dataclass

@@ -24,15 +28,6 @@ class _GetToolsParams:
     serialized_run_context: Any
 
 
-@dataclass
-@with_config(ConfigDict(arbitrary_types_allowed=True))
-class _CallToolParams:
-    name: str
-    tool_args: dict[str, Any]
-    serialized_run_context: Any
-    tool_def: ToolDefinition
-
-
 class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
     def __init__(
         self,

@@ -72,13 +67,16 @@ class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
             get_tools_activity
         )
 
-        async def call_tool_activity(params:
+        async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
            run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
-
-
-
-
-
+            assert isinstance(params.tool_def, ToolDefinition)
+            return await self._wrap_call_tool_result(
+                self.wrapped.call_tool(
+                    params.name,
+                    params.tool_args,
+                    run_context,
+                    self.tool_for_tool_def(params.tool_def),
+                )
            )
 
         # Set type hint explicitly so that Temporal can take care of serialization and deserialization

@@ -125,22 +123,24 @@ class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
         tool_args: dict[str, Any],
         ctx: RunContext[AgentDepsT],
         tool: ToolsetTool[AgentDepsT],
-    ) ->
+    ) -> CallToolResult:
         if not workflow.in_workflow():
             return await super().call_tool(name, tool_args, ctx, tool)
 
         tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
         serialized_run_context = self.run_context_type.serialize_run_context(ctx)
-        return
-
-
-
-
-
-
-
-
-
-
+        return self._unwrap_call_tool_result(
+            await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
+                activity=self.call_tool_activity,
+                args=[
+                    CallToolParams(
+                        name=name,
+                        tool_args=tool_args,
+                        serialized_run_context=serialized_run_context,
+                        tool_def=tool.tool_def,
+                    ),
+                    ctx.deps,
+                ],
+                **tool_activity_config,
+            )
         )
pydantic_ai/durable_exec/temporal/_run_context.py
CHANGED

@@ -2,14 +2,19 @@ from __future__ import annotations
 
 from typing import Any
 
+from typing_extensions import TypeVar
+
 from pydantic_ai.exceptions import UserError
-from pydantic_ai.tools import
+from pydantic_ai.tools import RunContext
+
+AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True)
+"""Type variable for the agent dependencies in `RunContext`."""
 
 
 class TemporalRunContext(RunContext[AgentDepsT]):
     """The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
 
-    By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `
+    By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step` and `partial_output` attributes will be available.
     To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
     """
 

@@ -44,9 +49,10 @@ class TemporalRunContext(RunContext[AgentDepsT]):
             'retry': ctx.retry,
             'max_retries': ctx.max_retries,
             'run_step': ctx.run_step,
+            'partial_output': ctx.partial_output,
         }
 
     @classmethod
-    def deserialize_run_context(cls, ctx: dict[str, Any], deps:
+    def deserialize_run_context(cls, ctx: dict[str, Any], deps: Any) -> TemporalRunContext[Any]:
        """Deserialize the run context from a `dict[str, Any]`."""
        return cls(**ctx, deps=deps)
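The docstring above points at subclassing `TemporalRunContext` to expose additional attributes inside activities. A minimal sketch under that reading (the import path and the `prompt` attribute are assumptions to verify against the pydantic-ai docs, not something this diff shows):

```python
from typing import Any

from pydantic_ai.durable_exec.temporal import TemporalRunContext
from pydantic_ai.tools import RunContext


class ExtendedTemporalRunContext(TemporalRunContext):
    @classmethod
    def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
        # Keep the default fields and additionally carry the original user prompt
        # (assumed attribute) across the activity boundary.
        return {**super().serialize_run_context(ctx), 'prompt': ctx.prompt}
```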
pydantic_ai/durable_exec/temporal/_toolset.py
CHANGED

@@ -1,17 +1,58 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from collections.abc import Callable
-from
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass
+from typing import Annotated, Any, Literal
 
+from pydantic import ConfigDict, Discriminator, with_config
 from temporalio.workflow import ActivityConfig
+from typing_extensions import assert_never
 
 from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset
-from pydantic_ai.
+from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry
+from pydantic_ai.tools import AgentDepsT, ToolDefinition
 
 from ._run_context import TemporalRunContext
 
 
+@dataclass
+@with_config(ConfigDict(arbitrary_types_allowed=True))
+class CallToolParams:
+    name: str
+    tool_args: dict[str, Any]
+    serialized_run_context: Any
+    tool_def: ToolDefinition | None
+
+
+@dataclass
+class _ApprovalRequired:
+    kind: Literal['approval_required'] = 'approval_required'
+
+
+@dataclass
+class _CallDeferred:
+    kind: Literal['call_deferred'] = 'call_deferred'
+
+
+@dataclass
+class _ModelRetry:
+    message: str
+    kind: Literal['model_retry'] = 'model_retry'
+
+
+@dataclass
+class _ToolReturn:
+    result: Any
+    kind: Literal['tool_return'] = 'tool_return'
+
+
+CallToolResult = Annotated[
+    _ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
+    Discriminator('kind'),
+]
+
+
 class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
     @property
     def id(self) -> str:

@@ -30,6 +71,29 @@ class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
         # Temporalized toolsets cannot be swapped out after the fact.
         return self
 
+    async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
+        try:
+            result = await coro
+            return _ToolReturn(result=result)
+        except ApprovalRequired:
+            return _ApprovalRequired()
+        except CallDeferred:
+            return _CallDeferred()
+        except ModelRetry as e:
+            return _ModelRetry(message=e.message)
+
+    def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
+        if isinstance(result, _ToolReturn):
+            return result.result
+        elif isinstance(result, _ApprovalRequired):
+            raise ApprovalRequired()
+        elif isinstance(result, _CallDeferred):
+            raise CallDeferred()
+        elif isinstance(result, _ModelRetry):
+            raise ModelRetry(result.message)
+        else:
+            assert_never(result)
+
 
 def temporalize_toolset(
     toolset: AbstractToolset[AgentDepsT],
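The new `_wrap_call_tool_result` / `_unwrap_call_tool_result` pair is the heart of this refactor: control-flow exceptions raised by a tool are turned into serializable, `kind`-discriminated dataclasses on the activity side and re-raised (or unwrapped) on the workflow side. A self-contained sketch of the same round-trip pattern (illustrative names, not pydantic-ai's):

```python
from dataclasses import dataclass
from typing import Annotated, Any, Literal

from pydantic import Discriminator, TypeAdapter


class ModelRetry(Exception):
    """Stand-in for pydantic_ai.exceptions.ModelRetry."""


@dataclass
class _ModelRetry:
    message: str
    kind: Literal['model_retry'] = 'model_retry'


@dataclass
class _ToolReturn:
    result: Any
    kind: Literal['tool_return'] = 'tool_return'


CallToolResult = Annotated[_ModelRetry | _ToolReturn, Discriminator('kind')]
result_adapter: TypeAdapter[Any] = TypeAdapter(CallToolResult)


def wrap(fn: Any) -> Any:
    """Activity side: convert the exception into a serializable result."""
    try:
        return _ToolReturn(result=fn())
    except ModelRetry as e:
        return _ModelRetry(message=str(e))


def unwrap(result: Any) -> Any:
    """Workflow side: return the value or re-raise after the round trip."""
    if isinstance(result, _ToolReturn):
        return result.result
    raise ModelRetry(result.message)


# Simulate the serialization round trip an activity result goes through.
wire = result_adapter.dump_python(wrap(lambda: 42))   # {'result': 42, 'kind': 'tool_return'}
print(unwrap(result_adapter.validate_python(wire)))   # 42
```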
pydantic_ai/messages.py
CHANGED

@@ -34,6 +34,7 @@ DocumentMediaType: TypeAlias = Literal[
     'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
     'text/html',
     'text/markdown',
+    'application/msword',
     'application/vnd.ms-excel',
 ]
 VideoMediaType: TypeAlias = Literal[

@@ -434,8 +435,12 @@ class DocumentUrl(FileUrl):
             return 'application/pdf'
         elif self.url.endswith('.rtf'):
             return 'application/rtf'
+        elif self.url.endswith('.doc'):
+            return 'application/msword'
         elif self.url.endswith('.docx'):
             return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
+        elif self.url.endswith('.xls'):
+            return 'application/vnd.ms-excel'
         elif self.url.endswith('.xlsx'):
             return 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
 

@@ -645,6 +650,7 @@ _document_format_lookup: dict[str, DocumentFormat] = {
     'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': 'xlsx',
     'text/html': 'html',
     'text/markdown': 'md',
+    'application/msword': 'doc',
     'application/vnd.ms-excel': 'xls',
 }
 _audio_format_lookup: dict[str, AudioFormat] = {

@@ -882,7 +888,10 @@ class RetryPromptPart:
             description = self.content
         else:
             json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
-
+            plural = isinstance(self.content, list) and len(self.content) != 1
+            description = (
+                f'{len(self.content)} validation error{"s" if plural else ""}:\n```json\n{json_errors.decode()}\n```'
+            )
         return f'{description}\n\nFix the errors and try again.'
 
     def otel_event(self, settings: InstrumentationSettings) -> Event:
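The `.doc` and `.xls` branches added above give those extensions concrete media types. A quick illustration (assumes `DocumentUrl.media_type` is the property these branches belong to, as the hunk context suggests):

```python
from pydantic_ai.messages import DocumentUrl

print(DocumentUrl(url='https://example.com/report.doc').media_type)  # application/msword
print(DocumentUrl(url='https://example.com/table.xls').media_type)   # application/vnd.ms-excel
```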
pydantic_ai/models/openai.py
CHANGED

@@ -948,6 +948,10 @@ class OpenAIResponsesModel(Model):
 
         super().__init__(settings=settings, profile=profile or provider.model_profile)
 
+    @property
+    def base_url(self) -> str:
+        return str(self.client.base_url)
+
     @property
     def model_name(self) -> OpenAIModelName:
         """The model name."""
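A small sketch of the new `base_url` property on `OpenAIResponsesModel` (assumes `OPENAI_API_KEY` is set so the default provider can be built; the printed URL is whatever the configured client uses):

```python
from pydantic_ai.models.openai import OpenAIResponsesModel

model = OpenAIResponsesModel('gpt-5')
print(model.base_url)  # e.g. https://api.openai.com/v1/
```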
pydantic_ai/profiles/openai.py
CHANGED

@@ -62,7 +62,10 @@ class OpenAIModelProfile(ModelProfile):
 
 def openai_model_profile(model_name: str) -> ModelProfile:
     """Get the model profile for an OpenAI model."""
-
+    is_gpt_5 = model_name.startswith('gpt-5')
+    is_o_series = model_name.startswith('o')
+    is_reasoning_model = is_o_series or (is_gpt_5 and 'gpt-5-chat' not in model_name)
+
     # Check if the model supports web search (only specific search-preview models)
     supports_web_search = '-search-preview' in model_name
 

@@ -91,7 +94,7 @@ def openai_model_profile(model_name: str) -> ModelProfile:
         json_schema_transformer=OpenAIJsonSchemaTransformer,
         supports_json_schema_output=True,
         supports_json_object_output=True,
-        supports_image_output=
+        supports_image_output=is_gpt_5 or 'o3' in model_name or '4.1' in model_name or '4o' in model_name,
         openai_unsupported_model_settings=openai_unsupported_model_settings,
         openai_system_prompt_role=openai_system_prompt_role,
         openai_chat_supports_web_search=supports_web_search,
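The `is_gpt_5` / `is_o_series` / `is_reasoning_model` flags introduced above drive the rest of the profile. Restated as a standalone helper for illustration (not part of pydantic-ai's API):

```python
def classify(model_name: str) -> dict[str, bool]:
    # Mirrors the classification logic added in the hunk above.
    is_gpt_5 = model_name.startswith('gpt-5')
    is_o_series = model_name.startswith('o')
    return {
        'is_gpt_5': is_gpt_5,
        'is_o_series': is_o_series,
        'is_reasoning_model': is_o_series or (is_gpt_5 and 'gpt-5-chat' not in model_name),
    }


print(classify('gpt-5-mini'))         # gpt-5, reasoning model
print(classify('gpt-5-chat-latest'))  # gpt-5, but not a reasoning model
print(classify('o3-mini'))            # o-series reasoning model
```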
pydantic_ai/providers/openrouter.py
CHANGED

@@ -81,6 +81,9 @@ class OpenRouterProvider(Provider[AsyncOpenAI]):
     @overload
     def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
 
+    @overload
+    def __init__(self, *, http_client: httpx.AsyncClient) -> None: ...
+
     @overload
     def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
 
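The new overload allows constructing the provider from just an HTTP client. A minimal usage sketch (assumes the API key is then picked up from the environment, typically `OPENROUTER_API_KEY`; verify against the provider docs):

```python
import httpx

from pydantic_ai.providers.openrouter import OpenRouterProvider

# Custom client (e.g. with a longer timeout), no explicit api_key argument.
provider = OpenRouterProvider(http_client=httpx.AsyncClient(timeout=30))
```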
pydantic_ai/result.py
CHANGED

@@ -1,8 +1,8 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
 from copy import deepcopy
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import TYPE_CHECKING, Generic, cast, overload
 

@@ -35,6 +35,7 @@ __all__ = (
     'OutputDataT_inv',
     'ToolOutput',
     'OutputValidatorFunc',
+    'StreamedRunResultSync',
 )
 
 

@@ -116,7 +117,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         else:
             async for text in self._stream_response_text(delta=False, debounce_by=debounce_by):
                 for validator in self._output_validators:
-                    text = await validator.validate(text, self._run_ctx)
+                    text = await validator.validate(text, replace(self._run_ctx, partial_output=True))
                 yield text
 
     # TODO (v2): Drop in favor of `response` property

@@ -194,7 +195,9 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                 text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
             )
             for validator in self._output_validators:
-                result_data = await validator.validate(
+                result_data = await validator.validate(
+                    result_data, replace(self._run_ctx, partial_output=allow_partial)
+                )
             return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover

@@ -555,6 +558,158 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         await self._on_complete()
 
 
+@dataclass(init=False)
+class StreamedRunResultSync(Generic[AgentDepsT, OutputDataT]):
+    """Synchronous wrapper for [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] that only exposes sync methods."""
+
+    _streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
+
+    def __init__(self, streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]) -> None:
+        self._streamed_run_result = streamed_run_result
+
+    def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return the history of messages.
+
+        Args:
+            output_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the output tool call if you want to continue
+                the conversation and want to set the response to the output tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of messages.
+        """
+        return self._streamed_run_result.all_messages(output_tool_return_content=output_tool_return_content)
+
+    def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes:  # pragma: no cover
+        """Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResultSync.all_messages] as JSON bytes.
+
+        Args:
+            output_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the output tool call if you want to continue
+                the conversation and want to set the response to the output tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the messages.
+        """
+        return self._streamed_run_result.all_messages_json(output_tool_return_content=output_tool_return_content)
+
+    def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return new messages associated with this run.
+
+        Messages from older runs are excluded.
+
+        Args:
+            output_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the output tool call if you want to continue
+                the conversation and want to set the response to the output tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of new messages.
+        """
+        return self._streamed_run_result.new_messages(output_tool_return_content=output_tool_return_content)
+
+    def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes:  # pragma: no cover
+        """Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResultSync.new_messages] as JSON bytes.
+
+        Args:
+            output_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the output tool call if you want to continue
+                the conversation and want to set the response to the output tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the new messages.
+        """
+        return self._streamed_run_result.new_messages_json(output_tool_return_content=output_tool_return_content)
+
+    def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]:
+        """Stream the output as an iterable.
+
+        The pydantic validator for structured data will be called in
+        [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
+        on each iteration.
+
+        Args:
+            debounce_by: by how much (if at all) to debounce/group the output chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured outputs to reduce the overhead of
+                performing validation as each token is received.
+
+        Returns:
+            An iterable of the response data.
+        """
+        return _utils.sync_async_iterator(self._streamed_run_result.stream_output(debounce_by=debounce_by))
+
+    def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]:
+        """Stream the text result as an iterable.
+
+        !!! note
+            Result validators will NOT be called on the text result if `delta=True`.
+
+        Args:
+            delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
+                up to the current point.
+            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured responses to reduce the overhead of
+                performing validation as each token is received.
+        """
+        return _utils.sync_async_iterator(self._streamed_run_result.stream_text(delta=delta, debounce_by=debounce_by))
+
+    def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple[_messages.ModelResponse, bool]]:
+        """Stream the response as an iterable of Structured LLM Messages.
+
+        Args:
+            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured responses to reduce the overhead of
+                performing validation as each token is received.
+
+        Returns:
+            An iterable of the structured response message and whether that is the last message.
+        """
+        return _utils.sync_async_iterator(self._streamed_run_result.stream_responses(debounce_by=debounce_by))
+
+    def get_output(self) -> OutputDataT:
+        """Stream the whole response, validate and return it."""
+        return _utils.get_event_loop().run_until_complete(self._streamed_run_result.get_output())
+
+    @property
+    def response(self) -> _messages.ModelResponse:
+        """Return the current state of the response."""
+        return self._streamed_run_result.response
+
+    def usage(self) -> RunUsage:
+        """Return the usage of the whole run.
+
+        !!! note
+            This won't return the full usage until the stream is finished.
+        """
+        return self._streamed_run_result.usage()
+
+    def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
+        return self._streamed_run_result.timestamp()
+
+    def validate_response_output(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
+        """Validate a structured result message."""
+        return _utils.get_event_loop().run_until_complete(
+            self._streamed_run_result.validate_response_output(message, allow_partial=allow_partial)
+        )
+
+    @property
+    def is_complete(self) -> bool:
+        """Whether the stream has all been received.
+
+        This is set to `True` when one of
+        [`stream_output`][pydantic_ai.result.StreamedRunResultSync.stream_output],
+        [`stream_text`][pydantic_ai.result.StreamedRunResultSync.stream_text],
+        [`stream_responses`][pydantic_ai.result.StreamedRunResultSync.stream_responses] or
+        [`get_output`][pydantic_ai.result.StreamedRunResultSync.get_output] completes.
+        """
+        return self._streamed_run_result.is_complete
+
+
 @dataclass(repr=False)
 class FinalResult(Generic[OutputDataT]):
     """Marker class storing the final output of an agent run and associated metadata."""
|