pydantic-ai-slim 1.9.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +18 -14
- pydantic_ai/_output.py +20 -105
- pydantic_ai/_run_context.py +8 -2
- pydantic_ai/_tool_manager.py +30 -11
- pydantic_ai/_utils.py +18 -0
- pydantic_ai/agent/__init__.py +34 -32
- pydantic_ai/agent/abstract.py +155 -3
- pydantic_ai/agent/wrapper.py +5 -0
- pydantic_ai/common_tools/duckduckgo.py +1 -1
- pydantic_ai/durable_exec/dbos/_agent.py +28 -0
- pydantic_ai/durable_exec/prefect/_agent.py +25 -0
- pydantic_ai/durable_exec/temporal/_agent.py +25 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -73
- pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
- pydantic_ai/durable_exec/temporal/_run_context.py +9 -3
- pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
- pydantic_ai/mcp.py +4 -4
- pydantic_ai/messages.py +11 -2
- pydantic_ai/models/__init__.py +80 -35
- pydantic_ai/models/anthropic.py +27 -8
- pydantic_ai/models/bedrock.py +3 -3
- pydantic_ai/models/cohere.py +5 -3
- pydantic_ai/models/fallback.py +25 -4
- pydantic_ai/models/function.py +8 -0
- pydantic_ai/models/gemini.py +3 -3
- pydantic_ai/models/google.py +25 -22
- pydantic_ai/models/groq.py +5 -3
- pydantic_ai/models/huggingface.py +3 -3
- pydantic_ai/models/instrumented.py +29 -13
- pydantic_ai/models/mistral.py +6 -4
- pydantic_ai/models/openai.py +15 -6
- pydantic_ai/models/outlines.py +21 -12
- pydantic_ai/models/wrapper.py +1 -1
- pydantic_ai/output.py +3 -2
- pydantic_ai/profiles/openai.py +5 -2
- pydantic_ai/providers/anthropic.py +2 -2
- pydantic_ai/providers/openrouter.py +3 -0
- pydantic_ai/result.py +159 -4
- pydantic_ai/tools.py +12 -10
- pydantic_ai/ui/_adapter.py +2 -2
- pydantic_ai/ui/_event_stream.py +4 -4
- pydantic_ai/ui/ag_ui/_event_stream.py +11 -2
- pydantic_ai/ui/ag_ui/app.py +8 -1
- {pydantic_ai_slim-1.9.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/METADATA +9 -7
- {pydantic_ai_slim-1.9.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/RECORD +48 -48
- {pydantic_ai_slim-1.9.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.9.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.9.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,57 +1,22 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from collections.abc import Callable
|
|
4
|
-
from
|
|
5
|
-
from typing import Annotated, Any, Literal
|
|
4
|
+
from typing import Any, Literal
|
|
6
5
|
|
|
7
|
-
from pydantic import ConfigDict, Discriminator, with_config
|
|
8
6
|
from temporalio import activity, workflow
|
|
9
7
|
from temporalio.workflow import ActivityConfig
|
|
10
|
-
from typing_extensions import assert_never
|
|
11
8
|
|
|
12
9
|
from pydantic_ai import FunctionToolset, ToolsetTool
|
|
13
|
-
from pydantic_ai.exceptions import
|
|
10
|
+
from pydantic_ai.exceptions import UserError
|
|
14
11
|
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
15
12
|
from pydantic_ai.toolsets.function import FunctionToolsetTool
|
|
16
13
|
|
|
17
14
|
from ._run_context import TemporalRunContext
|
|
18
|
-
from ._toolset import
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class _CallToolParams:
|
|
24
|
-
name: str
|
|
25
|
-
tool_args: dict[str, Any]
|
|
26
|
-
serialized_run_context: Any
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
@dataclass
|
|
30
|
-
class _ApprovalRequired:
|
|
31
|
-
kind: Literal['approval_required'] = 'approval_required'
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
@dataclass
|
|
35
|
-
class _CallDeferred:
|
|
36
|
-
kind: Literal['call_deferred'] = 'call_deferred'
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
@dataclass
|
|
40
|
-
class _ModelRetry:
|
|
41
|
-
message: str
|
|
42
|
-
kind: Literal['model_retry'] = 'model_retry'
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
@dataclass
|
|
46
|
-
class _ToolReturn:
|
|
47
|
-
result: Any
|
|
48
|
-
kind: Literal['tool_return'] = 'tool_return'
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
_CallToolResult = Annotated[
|
|
52
|
-
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
|
|
53
|
-
Discriminator('kind'),
|
|
54
|
-
]
|
|
15
|
+
from ._toolset import (
|
|
16
|
+
CallToolParams,
|
|
17
|
+
CallToolResult,
|
|
18
|
+
TemporalWrapperToolset,
|
|
19
|
+
)
|
|
55
20
|
|
|
56
21
|
|
|
57
22
|
class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
@@ -70,7 +35,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
70
35
|
self.tool_activity_config = tool_activity_config
|
|
71
36
|
self.run_context_type = run_context_type
|
|
72
37
|
|
|
73
|
-
async def call_tool_activity(params:
|
|
38
|
+
async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
|
|
74
39
|
name = params.name
|
|
75
40
|
ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
|
|
76
41
|
try:
|
|
@@ -84,15 +49,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
84
49
|
# The tool args will already have been validated into their proper types in the `ToolManager`,
|
|
85
50
|
# but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
|
|
86
51
|
args_dict = tool.args_validator.validate_python(params.tool_args)
|
|
87
|
-
|
|
88
|
-
result = await self.wrapped.call_tool(name, args_dict, ctx, tool)
|
|
89
|
-
return _ToolReturn(result=result)
|
|
90
|
-
except ApprovalRequired:
|
|
91
|
-
return _ApprovalRequired()
|
|
92
|
-
except CallDeferred:
|
|
93
|
-
return _CallDeferred()
|
|
94
|
-
except ModelRetry as e:
|
|
95
|
-
return _ModelRetry(message=e.message)
|
|
52
|
+
return await self._wrap_call_tool_result(self.wrapped.call_tool(name, args_dict, ctx, tool))
|
|
96
53
|
|
|
97
54
|
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
|
|
98
55
|
call_tool_activity.__annotations__['deps'] = deps_type
|
|
@@ -123,25 +80,18 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
123
80
|
|
|
124
81
|
tool_activity_config = self.activity_config | tool_activity_config
|
|
125
82
|
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
83
|
+
return self._unwrap_call_tool_result(
|
|
84
|
+
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
|
|
85
|
+
activity=self.call_tool_activity,
|
|
86
|
+
args=[
|
|
87
|
+
CallToolParams(
|
|
88
|
+
name=name,
|
|
89
|
+
tool_args=tool_args,
|
|
90
|
+
serialized_run_context=serialized_run_context,
|
|
91
|
+
tool_def=None,
|
|
92
|
+
),
|
|
93
|
+
ctx.deps,
|
|
94
|
+
],
|
|
95
|
+
**tool_activity_config,
|
|
96
|
+
)
|
|
137
97
|
)
|
|
138
|
-
if isinstance(result, _ApprovalRequired):
|
|
139
|
-
raise ApprovalRequired()
|
|
140
|
-
elif isinstance(result, _CallDeferred):
|
|
141
|
-
raise CallDeferred()
|
|
142
|
-
elif isinstance(result, _ModelRetry):
|
|
143
|
-
raise ModelRetry(result.message)
|
|
144
|
-
elif isinstance(result, _ToolReturn):
|
|
145
|
-
return result.result
|
|
146
|
-
else:
|
|
147
|
-
assert_never(result)
|
|
@@ -11,11 +11,15 @@ from typing_extensions import Self
|
|
|
11
11
|
|
|
12
12
|
from pydantic_ai import ToolsetTool
|
|
13
13
|
from pydantic_ai.exceptions import UserError
|
|
14
|
-
from pydantic_ai.mcp import MCPServer
|
|
14
|
+
from pydantic_ai.mcp import MCPServer
|
|
15
15
|
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
|
|
16
16
|
|
|
17
17
|
from ._run_context import TemporalRunContext
|
|
18
|
-
from ._toolset import
|
|
18
|
+
from ._toolset import (
|
|
19
|
+
CallToolParams,
|
|
20
|
+
CallToolResult,
|
|
21
|
+
TemporalWrapperToolset,
|
|
22
|
+
)
|
|
19
23
|
|
|
20
24
|
|
|
21
25
|
@dataclass
|
|
@@ -24,15 +28,6 @@ class _GetToolsParams:
|
|
|
24
28
|
serialized_run_context: Any
|
|
25
29
|
|
|
26
30
|
|
|
27
|
-
@dataclass
|
|
28
|
-
@with_config(ConfigDict(arbitrary_types_allowed=True))
|
|
29
|
-
class _CallToolParams:
|
|
30
|
-
name: str
|
|
31
|
-
tool_args: dict[str, Any]
|
|
32
|
-
serialized_run_context: Any
|
|
33
|
-
tool_def: ToolDefinition
|
|
34
|
-
|
|
35
|
-
|
|
36
31
|
class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
|
|
37
32
|
def __init__(
|
|
38
33
|
self,
|
|
@@ -72,13 +67,16 @@ class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
|
|
|
72
67
|
get_tools_activity
|
|
73
68
|
)
|
|
74
69
|
|
|
75
|
-
async def call_tool_activity(params:
|
|
70
|
+
async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
|
|
76
71
|
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
72
|
+
assert isinstance(params.tool_def, ToolDefinition)
|
|
73
|
+
return await self._wrap_call_tool_result(
|
|
74
|
+
self.wrapped.call_tool(
|
|
75
|
+
params.name,
|
|
76
|
+
params.tool_args,
|
|
77
|
+
run_context,
|
|
78
|
+
self.tool_for_tool_def(params.tool_def),
|
|
79
|
+
)
|
|
82
80
|
)
|
|
83
81
|
|
|
84
82
|
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
|
|
@@ -125,22 +123,24 @@ class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
|
|
|
125
123
|
tool_args: dict[str, Any],
|
|
126
124
|
ctx: RunContext[AgentDepsT],
|
|
127
125
|
tool: ToolsetTool[AgentDepsT],
|
|
128
|
-
) ->
|
|
126
|
+
) -> CallToolResult:
|
|
129
127
|
if not workflow.in_workflow():
|
|
130
128
|
return await super().call_tool(name, tool_args, ctx, tool)
|
|
131
129
|
|
|
132
130
|
tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
|
|
133
131
|
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
|
|
134
|
-
return
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
132
|
+
return self._unwrap_call_tool_result(
|
|
133
|
+
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
|
|
134
|
+
activity=self.call_tool_activity,
|
|
135
|
+
args=[
|
|
136
|
+
CallToolParams(
|
|
137
|
+
name=name,
|
|
138
|
+
tool_args=tool_args,
|
|
139
|
+
serialized_run_context=serialized_run_context,
|
|
140
|
+
tool_def=tool.tool_def,
|
|
141
|
+
),
|
|
142
|
+
ctx.deps,
|
|
143
|
+
],
|
|
144
|
+
**tool_activity_config,
|
|
145
|
+
)
|
|
146
146
|
)
|
|
@@ -2,14 +2,19 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
|
+
from typing_extensions import TypeVar
|
|
6
|
+
|
|
5
7
|
from pydantic_ai.exceptions import UserError
|
|
6
|
-
from pydantic_ai.tools import
|
|
8
|
+
from pydantic_ai.tools import RunContext
|
|
9
|
+
|
|
10
|
+
AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True)
|
|
11
|
+
"""Type variable for the agent dependencies in `RunContext`."""
|
|
7
12
|
|
|
8
13
|
|
|
9
14
|
class TemporalRunContext(RunContext[AgentDepsT]):
|
|
10
15
|
"""The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
|
|
11
16
|
|
|
12
|
-
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `
|
|
17
|
+
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step` and `partial_output` attributes will be available.
|
|
13
18
|
To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
|
|
14
19
|
"""
|
|
15
20
|
|
|
@@ -44,9 +49,10 @@ class TemporalRunContext(RunContext[AgentDepsT]):
|
|
|
44
49
|
'retry': ctx.retry,
|
|
45
50
|
'max_retries': ctx.max_retries,
|
|
46
51
|
'run_step': ctx.run_step,
|
|
52
|
+
'partial_output': ctx.partial_output,
|
|
47
53
|
}
|
|
48
54
|
|
|
49
55
|
@classmethod
|
|
50
|
-
def deserialize_run_context(cls, ctx: dict[str, Any], deps:
|
|
56
|
+
def deserialize_run_context(cls, ctx: dict[str, Any], deps: Any) -> TemporalRunContext[Any]:
|
|
51
57
|
"""Deserialize the run context from a `dict[str, Any]`."""
|
|
52
58
|
return cls(**ctx, deps=deps)
|
|
@@ -1,17 +1,58 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from collections.abc import Callable
|
|
5
|
-
from
|
|
4
|
+
from collections.abc import Awaitable, Callable
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Annotated, Any, Literal
|
|
6
7
|
|
|
8
|
+
from pydantic import ConfigDict, Discriminator, with_config
|
|
7
9
|
from temporalio.workflow import ActivityConfig
|
|
10
|
+
from typing_extensions import assert_never
|
|
8
11
|
|
|
9
12
|
from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset
|
|
10
|
-
from pydantic_ai.
|
|
13
|
+
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry
|
|
14
|
+
from pydantic_ai.tools import AgentDepsT, ToolDefinition
|
|
11
15
|
|
|
12
16
|
from ._run_context import TemporalRunContext
|
|
13
17
|
|
|
14
18
|
|
|
19
|
+
@dataclass
|
|
20
|
+
@with_config(ConfigDict(arbitrary_types_allowed=True))
|
|
21
|
+
class CallToolParams:
|
|
22
|
+
name: str
|
|
23
|
+
tool_args: dict[str, Any]
|
|
24
|
+
serialized_run_context: Any
|
|
25
|
+
tool_def: ToolDefinition | None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
|
|
29
|
+
class _ApprovalRequired:
|
|
30
|
+
kind: Literal['approval_required'] = 'approval_required'
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
|
|
34
|
+
class _CallDeferred:
|
|
35
|
+
kind: Literal['call_deferred'] = 'call_deferred'
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
|
|
39
|
+
class _ModelRetry:
|
|
40
|
+
message: str
|
|
41
|
+
kind: Literal['model_retry'] = 'model_retry'
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class _ToolReturn:
|
|
46
|
+
result: Any
|
|
47
|
+
kind: Literal['tool_return'] = 'tool_return'
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
CallToolResult = Annotated[
|
|
51
|
+
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
|
|
52
|
+
Discriminator('kind'),
|
|
53
|
+
]
|
|
54
|
+
|
|
55
|
+
|
|
15
56
|
class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
|
|
16
57
|
@property
|
|
17
58
|
def id(self) -> str:
|
|
@@ -30,6 +71,29 @@ class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
|
|
|
30
71
|
# Temporalized toolsets cannot be swapped out after the fact.
|
|
31
72
|
return self
|
|
32
73
|
|
|
74
|
+
async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
|
|
75
|
+
try:
|
|
76
|
+
result = await coro
|
|
77
|
+
return _ToolReturn(result=result)
|
|
78
|
+
except ApprovalRequired:
|
|
79
|
+
return _ApprovalRequired()
|
|
80
|
+
except CallDeferred:
|
|
81
|
+
return _CallDeferred()
|
|
82
|
+
except ModelRetry as e:
|
|
83
|
+
return _ModelRetry(message=e.message)
|
|
84
|
+
|
|
85
|
+
def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
|
|
86
|
+
if isinstance(result, _ToolReturn):
|
|
87
|
+
return result.result
|
|
88
|
+
elif isinstance(result, _ApprovalRequired):
|
|
89
|
+
raise ApprovalRequired()
|
|
90
|
+
elif isinstance(result, _CallDeferred):
|
|
91
|
+
raise CallDeferred()
|
|
92
|
+
elif isinstance(result, _ModelRetry):
|
|
93
|
+
raise ModelRetry(result.message)
|
|
94
|
+
else:
|
|
95
|
+
assert_never(result)
|
|
96
|
+
|
|
33
97
|
|
|
34
98
|
def temporalize_toolset(
|
|
35
99
|
toolset: AbstractToolset[AgentDepsT],
|
pydantic_ai/mcp.py
CHANGED
|
@@ -726,9 +726,9 @@ class _MCPServerHTTP(MCPServer):
|
|
|
726
726
|
MemoryObjectReceiveStream[SessionMessage | Exception],
|
|
727
727
|
MemoryObjectSendStream[SessionMessage],
|
|
728
728
|
]
|
|
729
|
-
]:
|
|
729
|
+
]:
|
|
730
730
|
if self.http_client and self.headers:
|
|
731
|
-
raise ValueError('`http_client` is mutually exclusive with `headers`.')
|
|
731
|
+
raise ValueError('`http_client` is mutually exclusive with `headers`.') # pragma: no cover
|
|
732
732
|
|
|
733
733
|
transport_client_partial = functools.partial(
|
|
734
734
|
self._transport_client,
|
|
@@ -737,7 +737,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
737
737
|
sse_read_timeout=self.read_timeout,
|
|
738
738
|
)
|
|
739
739
|
|
|
740
|
-
if self.http_client is not None:
|
|
740
|
+
if self.http_client is not None: # pragma: no cover
|
|
741
741
|
|
|
742
742
|
def httpx_client_factory(
|
|
743
743
|
headers: dict[str, str] | None = None,
|
|
@@ -866,7 +866,7 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):
|
|
|
866
866
|
|
|
867
867
|
@property
|
|
868
868
|
def _transport_client(self):
|
|
869
|
-
return streamablehttp_client
|
|
869
|
+
return streamablehttp_client
|
|
870
870
|
|
|
871
871
|
def __eq__(self, value: object, /) -> bool:
|
|
872
872
|
return super().__eq__(value) and isinstance(value, MCPServerStreamableHTTP) and self.url == value.url
|
pydantic_ai/messages.py
CHANGED
|
@@ -34,6 +34,7 @@ DocumentMediaType: TypeAlias = Literal[
|
|
|
34
34
|
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
|
35
35
|
'text/html',
|
|
36
36
|
'text/markdown',
|
|
37
|
+
'application/msword',
|
|
37
38
|
'application/vnd.ms-excel',
|
|
38
39
|
]
|
|
39
40
|
VideoMediaType: TypeAlias = Literal[
|
|
@@ -434,8 +435,12 @@ class DocumentUrl(FileUrl):
|
|
|
434
435
|
return 'application/pdf'
|
|
435
436
|
elif self.url.endswith('.rtf'):
|
|
436
437
|
return 'application/rtf'
|
|
438
|
+
elif self.url.endswith('.doc'):
|
|
439
|
+
return 'application/msword'
|
|
437
440
|
elif self.url.endswith('.docx'):
|
|
438
441
|
return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
|
442
|
+
elif self.url.endswith('.xls'):
|
|
443
|
+
return 'application/vnd.ms-excel'
|
|
439
444
|
elif self.url.endswith('.xlsx'):
|
|
440
445
|
return 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
|
|
441
446
|
|
|
@@ -480,7 +485,7 @@ class BinaryContent:
|
|
|
480
485
|
"""
|
|
481
486
|
|
|
482
487
|
_identifier: Annotated[str | None, pydantic.Field(alias='identifier', default=None, exclude=True)] = field(
|
|
483
|
-
compare=False, default=None
|
|
488
|
+
compare=False, default=None
|
|
484
489
|
)
|
|
485
490
|
|
|
486
491
|
kind: Literal['binary'] = 'binary'
|
|
@@ -645,6 +650,7 @@ _document_format_lookup: dict[str, DocumentFormat] = {
|
|
|
645
650
|
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': 'xlsx',
|
|
646
651
|
'text/html': 'html',
|
|
647
652
|
'text/markdown': 'md',
|
|
653
|
+
'application/msword': 'doc',
|
|
648
654
|
'application/vnd.ms-excel': 'xls',
|
|
649
655
|
}
|
|
650
656
|
_audio_format_lookup: dict[str, AudioFormat] = {
|
|
@@ -882,7 +888,10 @@ class RetryPromptPart:
|
|
|
882
888
|
description = self.content
|
|
883
889
|
else:
|
|
884
890
|
json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
|
|
885
|
-
|
|
891
|
+
plural = isinstance(self.content, list) and len(self.content) != 1
|
|
892
|
+
description = (
|
|
893
|
+
f'{len(self.content)} validation error{"s" if plural else ""}:\n```json\n{json_errors.decode()}\n```'
|
|
894
|
+
)
|
|
886
895
|
return f'{description}\n\nFix the errors and try again.'
|
|
887
896
|
|
|
888
897
|
def otel_event(self, settings: InstrumentationSettings) -> Event:
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -21,7 +21,7 @@ from typing_extensions import TypeAliasType, TypedDict
|
|
|
21
21
|
|
|
22
22
|
from .. import _utils
|
|
23
23
|
from .._json_schema import JsonSchemaTransformer
|
|
24
|
-
from .._output import OutputObjectDefinition
|
|
24
|
+
from .._output import OutputObjectDefinition, PromptedOutputSchema
|
|
25
25
|
from .._parts_manager import ModelResponsePartsManager
|
|
26
26
|
from .._run_context import RunContext
|
|
27
27
|
from ..builtin_tools import AbstractBuiltinTool
|
|
@@ -309,6 +309,7 @@ class ModelRequestParameters:
|
|
|
309
309
|
output_mode: OutputMode = 'text'
|
|
310
310
|
output_object: OutputObjectDefinition | None = None
|
|
311
311
|
output_tools: list[ToolDefinition] = field(default_factory=list)
|
|
312
|
+
prompted_output_template: str | None = None
|
|
312
313
|
allow_text_output: bool = True
|
|
313
314
|
allow_image_output: bool = False
|
|
314
315
|
|
|
@@ -316,6 +317,12 @@ class ModelRequestParameters:
|
|
|
316
317
|
def tool_defs(self) -> dict[str, ToolDefinition]:
|
|
317
318
|
return {tool_def.name: tool_def for tool_def in [*self.function_tools, *self.output_tools]}
|
|
318
319
|
|
|
320
|
+
@cached_property
|
|
321
|
+
def prompted_output_instructions(self) -> str | None:
|
|
322
|
+
if self.output_mode == 'prompted' and self.prompted_output_template and self.output_object:
|
|
323
|
+
return PromptedOutputSchema.build_instructions(self.prompted_output_template, self.output_object)
|
|
324
|
+
return None
|
|
325
|
+
|
|
319
326
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
320
327
|
|
|
321
328
|
|
|
@@ -408,23 +415,52 @@ class Model(ABC):
|
|
|
408
415
|
) -> tuple[ModelSettings | None, ModelRequestParameters]:
|
|
409
416
|
"""Prepare request inputs before they are passed to the provider.
|
|
410
417
|
|
|
411
|
-
This merges the given
|
|
412
|
-
|
|
418
|
+
This merges the given `model_settings` with the model's own `settings` attribute and ensures
|
|
419
|
+
`customize_request_parameters` is applied to the resolved
|
|
413
420
|
[`ModelRequestParameters`][pydantic_ai.models.ModelRequestParameters]. Subclasses can override this method if
|
|
414
421
|
they need to customize the preparation flow further, but most implementations should simply call
|
|
415
|
-
|
|
422
|
+
`self.prepare_request(...)` at the start of their `request` (and related) methods.
|
|
416
423
|
"""
|
|
417
424
|
model_settings = merge_model_settings(self.settings, model_settings)
|
|
418
425
|
|
|
419
|
-
|
|
426
|
+
params = self.customize_request_parameters(model_request_parameters)
|
|
427
|
+
|
|
428
|
+
if builtin_tools := params.builtin_tools:
|
|
420
429
|
# Deduplicate builtin tools
|
|
421
|
-
|
|
422
|
-
|
|
430
|
+
params = replace(
|
|
431
|
+
params,
|
|
423
432
|
builtin_tools=list({tool.unique_id: tool for tool in builtin_tools}.values()),
|
|
424
433
|
)
|
|
425
434
|
|
|
426
|
-
|
|
427
|
-
|
|
435
|
+
if params.output_mode == 'auto':
|
|
436
|
+
output_mode = self.profile.default_structured_output_mode
|
|
437
|
+
params = replace(
|
|
438
|
+
params,
|
|
439
|
+
output_mode=output_mode,
|
|
440
|
+
allow_text_output=output_mode in ('native', 'prompted'),
|
|
441
|
+
)
|
|
442
|
+
|
|
443
|
+
# Reset irrelevant fields
|
|
444
|
+
if params.output_tools and params.output_mode != 'tool':
|
|
445
|
+
params = replace(params, output_tools=[])
|
|
446
|
+
if params.output_object and params.output_mode not in ('native', 'prompted'):
|
|
447
|
+
params = replace(params, output_object=None)
|
|
448
|
+
if params.prompted_output_template and params.output_mode != 'prompted':
|
|
449
|
+
params = replace(params, prompted_output_template=None) # pragma: no cover
|
|
450
|
+
|
|
451
|
+
# Set default prompted output template
|
|
452
|
+
if params.output_mode == 'prompted' and not params.prompted_output_template:
|
|
453
|
+
params = replace(params, prompted_output_template=self.profile.prompted_output_template)
|
|
454
|
+
|
|
455
|
+
# Check if output mode is supported
|
|
456
|
+
if params.output_mode == 'native' and not self.profile.supports_json_schema_output:
|
|
457
|
+
raise UserError('Native structured output is not supported by this model.')
|
|
458
|
+
if params.output_mode == 'tool' and not self.profile.supports_tools:
|
|
459
|
+
raise UserError('Tool output is not supported by this model.')
|
|
460
|
+
if params.allow_image_output and not self.profile.supports_image_output:
|
|
461
|
+
raise UserError('Image output is not supported by this model.')
|
|
462
|
+
|
|
463
|
+
return model_settings, params
|
|
428
464
|
|
|
429
465
|
@property
|
|
430
466
|
@abstractmethod
|
|
@@ -462,13 +498,17 @@ class Model(ABC):
|
|
|
462
498
|
return None
|
|
463
499
|
|
|
464
500
|
@staticmethod
|
|
465
|
-
def _get_instructions(
|
|
501
|
+
def _get_instructions(
|
|
502
|
+
messages: list[ModelMessage], model_request_parameters: ModelRequestParameters | None = None
|
|
503
|
+
) -> str | None:
|
|
466
504
|
"""Get instructions from the first ModelRequest found when iterating messages in reverse.
|
|
467
505
|
|
|
468
506
|
In the case that a "mock" request was generated to include a tool-return part for a result tool,
|
|
469
507
|
we want to use the instructions from the second-to-most-recent request (which should correspond to the
|
|
470
508
|
original request that generated the response that resulted in the tool-return part).
|
|
471
509
|
"""
|
|
510
|
+
instructions = None
|
|
511
|
+
|
|
472
512
|
last_two_requests: list[ModelRequest] = []
|
|
473
513
|
for message in reversed(messages):
|
|
474
514
|
if isinstance(message, ModelRequest):
|
|
@@ -476,33 +516,38 @@ class Model(ABC):
|
|
|
476
516
|
if len(last_two_requests) == 2:
|
|
477
517
|
break
|
|
478
518
|
if message.instructions is not None:
|
|
479
|
-
|
|
519
|
+
instructions = message.instructions
|
|
520
|
+
break
|
|
480
521
|
|
|
481
522
|
# If we don't have two requests, and we didn't already return instructions, there are definitely not any:
|
|
482
|
-
if len(last_two_requests)
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
523
|
+
if instructions is None and len(last_two_requests) == 2:
|
|
524
|
+
most_recent_request = last_two_requests[0]
|
|
525
|
+
second_most_recent_request = last_two_requests[1]
|
|
526
|
+
|
|
527
|
+
# If we've gotten this far and the most recent request consists of only tool-return parts or retry-prompt parts,
|
|
528
|
+
# we use the instructions from the second-to-most-recent request. This is necessary because when handling
|
|
529
|
+
# result tools, we generate a "mock" ModelRequest with a tool-return part for it, and that ModelRequest will not
|
|
530
|
+
# have the relevant instructions from the agent.
|
|
531
|
+
|
|
532
|
+
# While it's possible that you could have a message history where the most recent request has only tool returns,
|
|
533
|
+
# I believe there is no way to achieve that would _change_ the instructions without manually crafting the most
|
|
534
|
+
# recent message. That might make sense in principle for some usage pattern, but it's enough of an edge case
|
|
535
|
+
# that I think it's not worth worrying about, since you can work around this by inserting another ModelRequest
|
|
536
|
+
# with no parts at all immediately before the request that has the tool calls (that works because we only look
|
|
537
|
+
# at the two most recent ModelRequests here).
|
|
538
|
+
|
|
539
|
+
# If you have a use case where this causes pain, please open a GitHub issue and we can discuss alternatives.
|
|
540
|
+
|
|
541
|
+
if all(p.part_kind == 'tool-return' or p.part_kind == 'retry-prompt' for p in most_recent_request.parts):
|
|
542
|
+
instructions = second_most_recent_request.instructions
|
|
543
|
+
|
|
544
|
+
if model_request_parameters and (output_instructions := model_request_parameters.prompted_output_instructions):
|
|
545
|
+
if instructions:
|
|
546
|
+
instructions = '\n\n'.join([instructions, output_instructions])
|
|
547
|
+
else:
|
|
548
|
+
instructions = output_instructions
|
|
549
|
+
|
|
550
|
+
return instructions
|
|
506
551
|
|
|
507
552
|
|
|
508
553
|
@dataclass
|