pydantic-ai-slim 1.7.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic in the registry.
- pydantic_ai/__init__.py +2 -0
- pydantic_ai/_agent_graph.py +3 -0
- pydantic_ai/_cli.py +2 -2
- pydantic_ai/_run_context.py +8 -2
- pydantic_ai/_tool_manager.py +1 -0
- pydantic_ai/_utils.py +18 -0
- pydantic_ai/ag_ui.py +50 -696
- pydantic_ai/agent/__init__.py +13 -3
- pydantic_ai/agent/abstract.py +172 -9
- pydantic_ai/agent/wrapper.py +5 -0
- pydantic_ai/direct.py +16 -4
- pydantic_ai/durable_exec/dbos/_agent.py +31 -0
- pydantic_ai/durable_exec/prefect/_agent.py +28 -0
- pydantic_ai/durable_exec/temporal/_agent.py +28 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -73
- pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
- pydantic_ai/durable_exec/temporal/_run_context.py +9 -3
- pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
- pydantic_ai/messages.py +49 -8
- pydantic_ai/models/__init__.py +42 -1
- pydantic_ai/models/google.py +5 -12
- pydantic_ai/models/groq.py +9 -1
- pydantic_ai/models/openai.py +6 -3
- pydantic_ai/profiles/openai.py +5 -2
- pydantic_ai/providers/anthropic.py +2 -2
- pydantic_ai/providers/openrouter.py +3 -0
- pydantic_ai/result.py +178 -11
- pydantic_ai/tools.py +10 -6
- pydantic_ai/ui/__init__.py +16 -0
- pydantic_ai/ui/_adapter.py +386 -0
- pydantic_ai/ui/_event_stream.py +591 -0
- pydantic_ai/ui/_messages_builder.py +28 -0
- pydantic_ai/ui/ag_ui/__init__.py +9 -0
- pydantic_ai/ui/ag_ui/_adapter.py +187 -0
- pydantic_ai/ui/ag_ui/_event_stream.py +236 -0
- pydantic_ai/ui/ag_ui/app.py +148 -0
- pydantic_ai/ui/vercel_ai/__init__.py +16 -0
- pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
- pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
- pydantic_ai/ui/vercel_ai/_utils.py +16 -0
- pydantic_ai/ui/vercel_ai/request_types.py +275 -0
- pydantic_ai/ui/vercel_ai/response_types.py +230 -0
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/METADATA +10 -6
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/RECORD +47 -33
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/durable_exec/temporal/_function_toolset.py
CHANGED
@@ -1,57 +1,22 @@
 from __future__ import annotations
 
 from collections.abc import Callable
-from dataclasses import dataclass
-from typing import Annotated, Any, Literal
+from typing import Any, Literal
 
-from pydantic import ConfigDict, Discriminator, with_config
 from temporalio import activity, workflow
 from temporalio.workflow import ActivityConfig
-from typing_extensions import assert_never
 
 from pydantic_ai import FunctionToolset, ToolsetTool
-from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
+from pydantic_ai.exceptions import UserError
 from pydantic_ai.tools import AgentDepsT, RunContext
 from pydantic_ai.toolsets.function import FunctionToolsetTool
 
 from ._run_context import TemporalRunContext
-from ._toolset import TemporalWrapperToolset
-
-
-@dataclass
-@with_config(ConfigDict(arbitrary_types_allowed=True))
-class _CallToolParams:
-    name: str
-    tool_args: dict[str, Any]
-    serialized_run_context: Any
-
-
-@dataclass
-class _ApprovalRequired:
-    kind: Literal['approval_required'] = 'approval_required'
-
-
-@dataclass
-class _CallDeferred:
-    kind: Literal['call_deferred'] = 'call_deferred'
-
-
-@dataclass
-class _ModelRetry:
-    message: str
-    kind: Literal['model_retry'] = 'model_retry'
-
-
-@dataclass
-class _ToolReturn:
-    result: Any
-    kind: Literal['tool_return'] = 'tool_return'
-
-
-_CallToolResult = Annotated[
-    _ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
-    Discriminator('kind'),
-]
+from ._toolset import (
+    CallToolParams,
+    CallToolResult,
+    TemporalWrapperToolset,
+)
 
 
 class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
@@ -70,7 +35,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
         self.tool_activity_config = tool_activity_config
         self.run_context_type = run_context_type
 
-        async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
+        async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
             name = params.name
             ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
             try:
@@ -84,15 +49,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
             # The tool args will already have been validated into their proper types in the `ToolManager`,
             # but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
             args_dict = tool.args_validator.validate_python(params.tool_args)
-            try:
-                result = await self.wrapped.call_tool(name, args_dict, ctx, tool)
-                return _ToolReturn(result=result)
-            except ApprovalRequired:
-                return _ApprovalRequired()
-            except CallDeferred:
-                return _CallDeferred()
-            except ModelRetry as e:
-                return _ModelRetry(message=e.message)
+            return await self._wrap_call_tool_result(self.wrapped.call_tool(name, args_dict, ctx, tool))
 
         # Set type hint explicitly so that Temporal can take care of serialization and deserialization
         call_tool_activity.__annotations__['deps'] = deps_type
@@ -123,25 +80,18 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
 
         tool_activity_config = self.activity_config | tool_activity_config
         serialized_run_context = self.run_context_type.serialize_run_context(ctx)
-        result = await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
-            activity=self.call_tool_activity,
-            args=[
-                _CallToolParams(
-                    name=name,
-                    tool_args=tool_args,
-                    serialized_run_context=serialized_run_context,
-                ),
-                ctx.deps,
-            ],
-            **tool_activity_config,
+        return self._unwrap_call_tool_result(
+            await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
+                activity=self.call_tool_activity,
+                args=[
+                    CallToolParams(
+                        name=name,
+                        tool_args=tool_args,
+                        serialized_run_context=serialized_run_context,
+                        tool_def=None,
+                    ),
+                    ctx.deps,
+                ],
+                **tool_activity_config,
+            )
         )
-        if isinstance(result, _ApprovalRequired):
-            raise ApprovalRequired()
-        elif isinstance(result, _CallDeferred):
-            raise CallDeferred()
-        elif isinstance(result, _ModelRetry):
-            raise ModelRetry(result.message)
-        elif isinstance(result, _ToolReturn):
-            return result.result
-        else:
-            assert_never(result)
pydantic_ai/durable_exec/temporal/_mcp_server.py
CHANGED
@@ -11,11 +11,15 @@ from typing_extensions import Self
 
 from pydantic_ai import ToolsetTool
 from pydantic_ai.exceptions import UserError
-from pydantic_ai.mcp import MCPServer
+from pydantic_ai.mcp import MCPServer
 from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
 
 from ._run_context import TemporalRunContext
-from ._toolset import TemporalWrapperToolset
+from ._toolset import (
+    CallToolParams,
+    CallToolResult,
+    TemporalWrapperToolset,
+)
 
 
 @dataclass
@@ -24,15 +28,6 @@ class _GetToolsParams:
     serialized_run_context: Any
 
 
-@dataclass
-@with_config(ConfigDict(arbitrary_types_allowed=True))
-class _CallToolParams:
-    name: str
-    tool_args: dict[str, Any]
-    serialized_run_context: Any
-    tool_def: ToolDefinition
-
-
 class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
     def __init__(
         self,
@@ -72,13 +67,16 @@ class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
             get_tools_activity
         )
 
-        async def call_tool_activity(params:
+        async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
             run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
-            return await self.wrapped.call_tool(
-                params.name,
-                params.tool_args,
-                run_context,
-                self.tool_for_tool_def(params.tool_def),
+            assert isinstance(params.tool_def, ToolDefinition)
+            return await self._wrap_call_tool_result(
+                self.wrapped.call_tool(
+                    params.name,
+                    params.tool_args,
+                    run_context,
+                    self.tool_for_tool_def(params.tool_def),
+                )
             )
 
         # Set type hint explicitly so that Temporal can take care of serialization and deserialization
@@ -125,22 +123,24 @@ class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
         tool_args: dict[str, Any],
         ctx: RunContext[AgentDepsT],
         tool: ToolsetTool[AgentDepsT],
-    ) ->
+    ) -> CallToolResult:
         if not workflow.in_workflow():
             return await super().call_tool(name, tool_args, ctx, tool)
 
         tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
         serialized_run_context = self.run_context_type.serialize_run_context(ctx)
-        return await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
-            activity=self.call_tool_activity,
-            args=[
-                _CallToolParams(
-                    name=name,
-                    tool_args=tool_args,
-                    serialized_run_context=serialized_run_context,
-                    tool_def=tool.tool_def,
-                ),
-                ctx.deps,
-            ],
-            **tool_activity_config,
+        return self._unwrap_call_tool_result(
+            await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
+                activity=self.call_tool_activity,
+                args=[
+                    CallToolParams(
+                        name=name,
+                        tool_args=tool_args,
+                        serialized_run_context=serialized_run_context,
+                        tool_def=tool.tool_def,
+                    ),
+                    ctx.deps,
+                ],
+                **tool_activity_config,
+            )
         )
pydantic_ai/durable_exec/temporal/_run_context.py
CHANGED
@@ -2,14 +2,19 @@ from __future__ import annotations
 
 from typing import Any
 
+from typing_extensions import TypeVar
+
 from pydantic_ai.exceptions import UserError
-from pydantic_ai.tools import AgentDepsT, RunContext
+from pydantic_ai.tools import RunContext
+
+AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True)
+"""Type variable for the agent dependencies in `RunContext`."""
 
 
 class TemporalRunContext(RunContext[AgentDepsT]):
     """The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
 
-    By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `run_step` attributes will be available.
+    By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step` and `partial_output` attributes will be available.
     To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
     """
 
@@ -44,9 +49,10 @@ class TemporalRunContext(RunContext[AgentDepsT]):
             'retry': ctx.retry,
             'max_retries': ctx.max_retries,
             'run_step': ctx.run_step,
+            'partial_output': ctx.partial_output,
         }
 
     @classmethod
-    def deserialize_run_context(cls, ctx: dict[str, Any], deps: AgentDepsT) -> TemporalRunContext[AgentDepsT]:
+    def deserialize_run_context(cls, ctx: dict[str, Any], deps: Any) -> TemporalRunContext[Any]:
        """Deserialize the run context from a `dict[str, Any]`."""
        return cls(**ctx, deps=deps)
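Not part of the diff: a minimal sketch of the subclassing pattern the docstring above describes, extending the serialized run context with one extra attribute. The subclass name, the choice of `prompt` as the extra attribute, and the `pydantic_ai.durable_exec.temporal` import path are illustrative assumptions; how the subclass is handed to `TemporalAgent` is described in the docstring rather than shown here.

from typing import Any

from pydantic_ai.durable_exec.temporal import TemporalRunContext
from pydantic_ai.tools import RunContext


class TemporalRunContextWithPrompt(TemporalRunContext):
    """Illustrative subclass: also carry `prompt` across the Temporal activity boundary."""

    @classmethod
    def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
        # Reuse the default serialization and add one more attribute to the dict.
        return {**super().serialize_run_context(ctx), 'prompt': ctx.prompt}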
pydantic_ai/durable_exec/temporal/_toolset.py
CHANGED
@@ -1,17 +1,58 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from collections.abc import Callable
-from
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass
+from typing import Annotated, Any, Literal
 
+from pydantic import ConfigDict, Discriminator, with_config
 from temporalio.workflow import ActivityConfig
+from typing_extensions import assert_never
 
 from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset
-from pydantic_ai.tools import AgentDepsT
+from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry
+from pydantic_ai.tools import AgentDepsT, ToolDefinition
 
 from ._run_context import TemporalRunContext
 
 
+@dataclass
+@with_config(ConfigDict(arbitrary_types_allowed=True))
+class CallToolParams:
+    name: str
+    tool_args: dict[str, Any]
+    serialized_run_context: Any
+    tool_def: ToolDefinition | None
+
+
+@dataclass
+class _ApprovalRequired:
+    kind: Literal['approval_required'] = 'approval_required'
+
+
+@dataclass
+class _CallDeferred:
+    kind: Literal['call_deferred'] = 'call_deferred'
+
+
+@dataclass
+class _ModelRetry:
+    message: str
+    kind: Literal['model_retry'] = 'model_retry'
+
+
+@dataclass
+class _ToolReturn:
+    result: Any
+    kind: Literal['tool_return'] = 'tool_return'
+
+
+CallToolResult = Annotated[
+    _ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
+    Discriminator('kind'),
+]
+
+
 class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
     @property
     def id(self) -> str:
@@ -30,6 +71,29 @@ class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
         # Temporalized toolsets cannot be swapped out after the fact.
         return self
 
+    async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
+        try:
+            result = await coro
+            return _ToolReturn(result=result)
+        except ApprovalRequired:
+            return _ApprovalRequired()
+        except CallDeferred:
+            return _CallDeferred()
+        except ModelRetry as e:
+            return _ModelRetry(message=e.message)
+
+    def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
+        if isinstance(result, _ToolReturn):
+            return result.result
+        elif isinstance(result, _ApprovalRequired):
+            raise ApprovalRequired()
+        elif isinstance(result, _CallDeferred):
+            raise CallDeferred()
+        elif isinstance(result, _ModelRetry):
+            raise ModelRetry(result.message)
+        else:
+            assert_never(result)
+
 
 def temporalize_toolset(
     toolset: AbstractToolset[AgentDepsT],
pydantic_ai/messages.py
CHANGED
@@ -13,7 +13,7 @@ import pydantic
 import pydantic_core
 from genai_prices import calc_price, types as genai_types
 from opentelemetry._events import Event  # pyright: ignore[reportPrivateImportUsage]
-from typing_extensions import
+from typing_extensions import deprecated
 
 from . import _otel_messages, _utils
 from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
@@ -34,6 +34,7 @@ DocumentMediaType: TypeAlias = Literal[
     'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
     'text/html',
     'text/markdown',
+    'application/msword',
     'application/vnd.ms-excel',
 ]
 VideoMediaType: TypeAlias = Literal[
@@ -434,8 +435,12 @@ class DocumentUrl(FileUrl):
             return 'application/pdf'
         elif self.url.endswith('.rtf'):
             return 'application/rtf'
+        elif self.url.endswith('.doc'):
+            return 'application/msword'
         elif self.url.endswith('.docx'):
             return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
+        elif self.url.endswith('.xls'):
+            return 'application/vnd.ms-excel'
         elif self.url.endswith('.xlsx'):
             return 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
 
@@ -514,16 +519,16 @@ class BinaryContent:
                 vendor_metadata=bc.vendor_metadata,
             )
         else:
-            return bc
+            return bc
 
     @classmethod
-    def from_data_uri(cls, data_uri: str) ->
+    def from_data_uri(cls, data_uri: str) -> BinaryContent:
        """Create a `BinaryContent` from a data URI."""
        prefix = 'data:'
        if not data_uri.startswith(prefix):
-            raise ValueError('Data URI must start with "data:"')
+            raise ValueError('Data URI must start with "data:"')
        media_type, data = data_uri[len(prefix) :].split(';base64,', 1)
-        return cls(data=base64.b64decode(data), media_type=media_type)
+        return cls.narrow_type(cls(data=base64.b64decode(data), media_type=media_type))
 
     @pydantic.computed_field
     @property
@@ -645,6 +650,7 @@ _document_format_lookup: dict[str, DocumentFormat] = {
     'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': 'xlsx',
     'text/html': 'html',
     'text/markdown': 'md',
+    'application/msword': 'doc',
     'application/vnd.ms-excel': 'xls',
 }
 _audio_format_lookup: dict[str, AudioFormat] = {
@@ -882,7 +888,10 @@ class RetryPromptPart:
             description = self.content
         else:
             json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
-
+            plural = isinstance(self.content, list) and len(self.content) != 1
+            description = (
+                f'{len(self.content)} validation error{"s" if plural else ""}:\n```json\n{json_errors.decode()}\n```'
+            )
         return f'{description}\n\nFix the errors and try again.'
 
     def otel_event(self, settings: InstrumentationSettings) -> Event:
@@ -1612,6 +1621,14 @@ class PartStartEvent:
     part: ModelResponsePart
     """The newly started `ModelResponsePart`."""
 
+    previous_part_kind: (
+        Literal['text', 'thinking', 'tool-call', 'builtin-tool-call', 'builtin-tool-return', 'file'] | None
+    ) = None
+    """The kind of the previous part, if any.
+
+    This is useful for UI event streams to know whether to group parts of the same kind together when emitting events.
+    """
+
     event_kind: Literal['part_start'] = 'part_start'
     """Event type identifier, used as a discriminator."""
 
@@ -1634,6 +1651,30 @@ class PartDeltaEvent:
     __repr__ = _utils.dataclasses_no_defaults_repr
 
 
+@dataclass(repr=False, kw_only=True)
+class PartEndEvent:
+    """An event indicating that a part is complete."""
+
+    index: int
+    """The index of the part within the overall response parts list."""
+
+    part: ModelResponsePart
+    """The complete `ModelResponsePart`."""
+
+    next_part_kind: (
+        Literal['text', 'thinking', 'tool-call', 'builtin-tool-call', 'builtin-tool-return', 'file'] | None
+    ) = None
+    """The kind of the next part, if any.
+
+    This is useful for UI event streams to know whether to group parts of the same kind together when emitting events.
+    """
+
+    event_kind: Literal['part_end'] = 'part_end'
+    """Event type identifier, used as a discriminator."""
+
+    __repr__ = _utils.dataclasses_no_defaults_repr
+
+
 @dataclass(repr=False, kw_only=True)
 class FinalResultEvent:
     """An event indicating the response to the current model request matches the output schema and will produce a result."""
@@ -1649,9 +1690,9 @@ class FinalResultEvent:
 
 
 ModelResponseStreamEvent = Annotated[
-    PartStartEvent | PartDeltaEvent | FinalResultEvent, pydantic.Discriminator('event_kind')
+    PartStartEvent | PartDeltaEvent | PartEndEvent | FinalResultEvent, pydantic.Discriminator('event_kind')
 ]
-"""An event in the model response stream, starting a new part, applying a delta to an existing one, or indicating the final result."""
+"""An event in the model response stream, starting a new part, applying a delta to an existing one, indicating a part is complete, or indicating the final result."""
 
 
 @dataclass(repr=False)
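Not part of the diff: a small sketch of how a consumer of the stream might use the new `previous_part_kind` and `next_part_kind` hints to group adjacent parts of the same kind. The handler and the hand-built event list are illustrative only.

from pydantic_ai.messages import PartEndEvent, PartStartEvent, TextPart, ThinkingPart


def handle(event: PartStartEvent | PartEndEvent) -> None:
    # Open a UI block when a part starts with a different kind than the previous one,
    # and close it when the next part will be of a different kind.
    if isinstance(event, PartStartEvent):
        if event.previous_part_kind != event.part.part_kind:
            print(f'<open {event.part.part_kind} block>')
    elif isinstance(event, PartEndEvent):
        if event.next_part_kind != event.part.part_kind:
            print(f'<close {event.part.part_kind} block>')


events = [
    PartStartEvent(index=0, part=ThinkingPart(content='planning...')),
    PartEndEvent(index=0, part=ThinkingPart(content='planning...'), next_part_kind='text'),
    PartStartEvent(index=1, part=TextPart(content='Hello'), previous_part_kind='thinking'),
    PartEndEvent(index=1, part=TextPart(content='Hello')),
]
for e in events:
    handle(e)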
pydantic_ai/models/__init__.py
CHANGED
@@ -27,6 +27,7 @@ from .._run_context import RunContext
 from ..builtin_tools import AbstractBuiltinTool
 from ..exceptions import UserError
 from ..messages import (
+    BaseToolCallPart,
     BinaryImage,
     FilePart,
     FileUrl,
@@ -35,9 +36,12 @@ from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
+    ModelResponsePart,
     ModelResponseStreamEvent,
+    PartEndEvent,
     PartStartEvent,
     TextPart,
+    ThinkingPart,
     ToolCallPart,
     VideoUrl,
 )
@@ -543,7 +547,44 @@ class StreamedResponse(ABC):
             async for event in iterator:
                 yield event
 
-        self._event_iterator = iterator_with_final_event(self._get_event_iterator())
+        async def iterator_with_part_end(
+            iterator: AsyncIterator[ModelResponseStreamEvent],
+        ) -> AsyncIterator[ModelResponseStreamEvent]:
+            last_start_event: PartStartEvent | None = None
+
+            def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent | None:
+                if not last_start_event:
+                    return None
+
+                index = last_start_event.index
+                part = self._parts_manager.get_parts()[index]
+                if not isinstance(part, TextPart | ThinkingPart | BaseToolCallPart):
+                    # Parts other than these 3 don't have deltas, so don't need an end part.
+                    return None
+
+                return PartEndEvent(
+                    index=index,
+                    part=part,
+                    next_part_kind=next_part.part_kind if next_part else None,
+                )
+
+            async for event in iterator:
+                if isinstance(event, PartStartEvent):
+                    if last_start_event:
+                        end_event = part_end_event(event.part)
+                        if end_event:
+                            yield end_event
+
+                        event.previous_part_kind = last_start_event.part.part_kind
+
+                    last_start_event = event
+
+                yield event
+
+            end_event = part_end_event()
+            if end_event:
+                yield end_event
+
+        self._event_iterator = iterator_with_part_end(iterator_with_final_event(self._get_event_iterator()))
         return self._event_iterator
 
     @abstractmethod
pydantic_ai/models/google.py
CHANGED
@@ -471,11 +471,9 @@ class GoogleModel(Model):
                 raise UnexpectedModelBehavior(
                     f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
                 )
-
-
-
-                )  # pragma: no cover
-        parts = candidate.content.parts or []
+            parts = []  # pragma: no cover
+        else:
+            parts = candidate.content.parts or []
 
         usage = _metadata_as_usage(response)
         return _process_response_from_parts(
@@ -649,17 +647,12 @@ class GeminiStreamedResponse(StreamedResponse):
             # )
 
             if candidate.content is None or candidate.content.parts is None:
-                if self.finish_reason == 'stop':
-                    # Normal completion - skip this chunk
-                    continue
-                elif self.finish_reason == 'content_filter' and raw_finish_reason:  # pragma: no cover
+                if self.finish_reason == 'content_filter' and raw_finish_reason:  # pragma: no cover
                     raise UnexpectedModelBehavior(
                         f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
                     )
                 else:  # pragma: no cover
-                    raise UnexpectedModelBehavior(
-                        'Content field missing from streaming Gemini response', chunk.model_dump_json()
-                    )
+                    continue
 
             parts = candidate.content.parts
             if not parts:
pydantic_ai/models/groq.py
CHANGED
@@ -524,6 +524,8 @@ class GroqStreamedResponse(StreamedResponse):
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         try:
             executed_tool_call_id: str | None = None
+            reasoning_index = 0
+            reasoning = False
             async for chunk in self._response:
                 self._usage += _map_usage(chunk)
 
@@ -540,10 +542,16 @@
                 self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
 
                 if choice.delta.reasoning is not None:
+                    if not reasoning:
+                        reasoning_index += 1
+                        reasoning = True
+
                     # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
                     yield self._parts_manager.handle_thinking_delta(
-                        vendor_part_id='reasoning', content=choice.delta.reasoning
+                        vendor_part_id=f'reasoning-{reasoning_index}', content=choice.delta.reasoning
                     )
+                else:
+                    reasoning = False
 
                 if choice.delta.executed_tools:
                     for tool in choice.delta.executed_tools:
pydantic_ai/models/openai.py
CHANGED
@@ -948,6 +948,10 @@ class OpenAIResponsesModel(Model):
 
         super().__init__(settings=settings, profile=profile or provider.model_profile)
 
+    @property
+    def base_url(self) -> str:
+        return str(self.client.base_url)
+
     @property
     def model_name(self) -> OpenAIModelName:
         """The model name."""
@@ -1148,10 +1152,10 @@
             + list(model_settings.get('openai_builtin_tools', []))
             + self._get_tools(model_request_parameters)
         )
-
+        profile = OpenAIModelProfile.from_profile(self.profile)
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not model_request_parameters.allow_text_output:
+        elif not model_request_parameters.allow_text_output and profile.openai_supports_tool_choice_required:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
@@ -1184,7 +1188,6 @@
             text = text or {}
             text['verbosity'] = verbosity
 
-        profile = OpenAIModelProfile.from_profile(self.profile)
         unsupported_model_settings = profile.openai_unsupported_model_settings
         for setting in unsupported_model_settings:
             model_settings.pop(setting, None)
pydantic_ai/profiles/openai.py
CHANGED
@@ -62,7 +62,10 @@ class OpenAIModelProfile(ModelProfile):
 
 def openai_model_profile(model_name: str) -> ModelProfile:
     """Get the model profile for an OpenAI model."""
-
+    is_gpt_5 = model_name.startswith('gpt-5')
+    is_o_series = model_name.startswith('o')
+    is_reasoning_model = is_o_series or (is_gpt_5 and 'gpt-5-chat' not in model_name)
+
     # Check if the model supports web search (only specific search-preview models)
     supports_web_search = '-search-preview' in model_name
 
@@ -91,7 +94,7 @@ def openai_model_profile(model_name: str) -> ModelProfile:
         json_schema_transformer=OpenAIJsonSchemaTransformer,
         supports_json_schema_output=True,
         supports_json_object_output=True,
-        supports_image_output=
+        supports_image_output=is_gpt_5 or 'o3' in model_name or '4.1' in model_name or '4o' in model_name,
         openai_unsupported_model_settings=openai_unsupported_model_settings,
         openai_system_prompt_role=openai_system_prompt_role,
         openai_chat_supports_web_search=supports_web_search,