pydantic-ai-slim 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +5 -0
- pydantic_ai/_a2a.py +6 -4
- pydantic_ai/_agent_graph.py +32 -32
- pydantic_ai/_cli.py +3 -3
- pydantic_ai/_output.py +8 -0
- pydantic_ai/_tool_manager.py +3 -0
- pydantic_ai/_utils.py +7 -1
- pydantic_ai/ag_ui.py +25 -14
- pydantic_ai/{agent.py → agent/__init__.py} +217 -1026
- pydantic_ai/agent/abstract.py +942 -0
- pydantic_ai/agent/wrapper.py +227 -0
- pydantic_ai/builtin_tools.py +105 -0
- pydantic_ai/direct.py +9 -9
- pydantic_ai/durable_exec/__init__.py +0 -0
- pydantic_ai/durable_exec/temporal/__init__.py +83 -0
- pydantic_ai/durable_exec/temporal/_agent.py +699 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
- pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
- pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
- pydantic_ai/durable_exec/temporal/_model.py +168 -0
- pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
- pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
- pydantic_ai/ext/aci.py +10 -9
- pydantic_ai/ext/langchain.py +4 -2
- pydantic_ai/mcp.py +203 -75
- pydantic_ai/messages.py +75 -13
- pydantic_ai/models/__init__.py +66 -8
- pydantic_ai/models/anthropic.py +135 -18
- pydantic_ai/models/bedrock.py +16 -5
- pydantic_ai/models/cohere.py +11 -4
- pydantic_ai/models/fallback.py +4 -2
- pydantic_ai/models/function.py +18 -4
- pydantic_ai/models/gemini.py +20 -9
- pydantic_ai/models/google.py +53 -15
- pydantic_ai/models/groq.py +47 -11
- pydantic_ai/models/huggingface.py +26 -11
- pydantic_ai/models/instrumented.py +3 -1
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +27 -17
- pydantic_ai/models/openai.py +97 -33
- pydantic_ai/models/test.py +12 -0
- pydantic_ai/models/wrapper.py +6 -2
- pydantic_ai/profiles/groq.py +23 -0
- pydantic_ai/profiles/openai.py +1 -1
- pydantic_ai/providers/google.py +7 -7
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/result.py +21 -55
- pydantic_ai/run.py +357 -0
- pydantic_ai/tools.py +0 -1
- pydantic_ai/toolsets/__init__.py +2 -0
- pydantic_ai/toolsets/_dynamic.py +87 -0
- pydantic_ai/toolsets/abstract.py +23 -3
- pydantic_ai/toolsets/combined.py +19 -4
- pydantic_ai/toolsets/deferred.py +10 -2
- pydantic_ai/toolsets/function.py +23 -8
- pydantic_ai/toolsets/prefixed.py +4 -0
- pydantic_ai/toolsets/wrapper.py +14 -1
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +7 -5
- pydantic_ai_slim-0.7.0.dist-info/RECORD +115 -0
- pydantic_ai_slim-0.6.1.dist-info/RECORD +0 -100
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Callable, Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import ConfigDict, with_config
|
|
7
|
+
from temporalio import activity, workflow
|
|
8
|
+
from temporalio.workflow import ActivityConfig
|
|
9
|
+
|
|
10
|
+
from pydantic_ai.exceptions import UserError
|
|
11
|
+
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
12
|
+
from pydantic_ai.toolsets import FunctionToolset, ToolsetTool
|
|
13
|
+
from pydantic_ai.toolsets.function import FunctionToolsetTool
|
|
14
|
+
|
|
15
|
+
from ._run_context import TemporalRunContext
|
|
16
|
+
from ._toolset import TemporalWrapperToolset
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _CallToolParams:
    """Payload sent to the function-toolset `call_tool` Temporal activity.

    Only the tool's name crosses the activity boundary; the tool itself is looked up again
    inside the activity.
    """

    # Name of the tool to invoke.
    name: str
    # Arguments for the tool call.
    tool_args: dict[str, Any]
    # Whatever `TemporalRunContext.serialize_run_context` (or a subclass override) returned.
    serialized_run_context: Any
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
    """Wraps a `FunctionToolset` so that, inside a Temporal workflow, tool calls run in a Temporal activity.

    Outside a workflow every call is delegated unchanged to the wrapped toolset.
    """

    def __init__(
        self,
        toolset: FunctionToolset[AgentDepsT],
        *,
        activity_name_prefix: str,
        activity_config: ActivityConfig,
        tool_activity_config: dict[str, ActivityConfig | Literal[False]],
        deps_type: type[AgentDepsT],
        run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
    ):
        """Wrap `toolset` and define its `call_tool` activity.

        Args:
            toolset: The function toolset to wrap.
            activity_name_prefix: Prefix used to build the activity name.
            activity_config: Default activity options applied to every tool call.
            tool_activity_config: Per-tool activity options keyed by tool name, merged over
                `activity_config`. `False` runs that (async) tool directly in the workflow instead.
            deps_type: The agent's dependencies type, set as the activity's `deps` type hint.
            run_context_type: The `TemporalRunContext` (sub)class used to (de)serialize the run context.
        """
        super().__init__(toolset)
        self.activity_config = activity_config
        self.tool_activity_config = tool_activity_config
        self.run_context_type = run_context_type

        async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> Any:
            name = params.name
            run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
            try:
                # Only the tool name is sent to the activity, so resolve the tool again here.
                toolset_tool = (await toolset.get_tools(run_context))[name]
            except KeyError as e:  # pragma: no cover
                raise UserError(
                    f'Tool {name!r} not found in toolset {self.id!r}. '
                    'Removing or renaming tools during an agent run is not supported with Temporal.'
                ) from e
            return await self.wrapped.call_tool(name, params.tool_args, run_context, toolset_tool)

        # Replace the generic `deps` hint with the concrete deps type so that Temporal can take care of
        # serialization and deserialization. This is done before `activity.defn` is applied below.
        call_tool_activity.__annotations__['deps'] = deps_type

        activity_name = f'{activity_name_prefix}__toolset__{self.id}__call_tool'
        self.call_tool_activity = activity.defn(name=activity_name)(call_tool_activity)

    @property
    def temporal_activities(self) -> list[Callable[..., Any]]:
        """The Temporal activities that must be registered on the worker for this toolset."""
        return [self.call_tool_activity]

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Call a tool, via the `call_tool` activity when running inside a workflow.

        Raises:
            UserError: If the tool's activity is disabled (`False`) but the tool function is not async.
        """
        if not workflow.in_workflow():
            return await super().call_tool(name, tool_args, ctx, tool)

        per_tool_config = self.tool_activity_config.get(name, {})
        if per_tool_config is not False:
            merged_config = self.activity_config | per_tool_config
            serialized_run_context = self.run_context_type.serialize_run_context(ctx)
            params = _CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context)
            return await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
                activity=self.call_tool_activity,
                args=[params, ctx.deps],
                **merged_config,
            )

        # The activity is disabled for this tool, so it runs directly inside the workflow.
        assert isinstance(tool, FunctionToolsetTool)
        if not tool.is_async:
            raise UserError(
                f'Temporal activity config for tool {name!r} has been explicitly set to `False` (activity disabled), '
                'but non-async tools are run in threads which are not supported outside of an activity. Make the tool function async instead.'
            )
        return await super().call_tool(name, tool_args, ctx, tool)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Callable
|
|
4
|
+
|
|
5
|
+
from logfire import Logfire
|
|
6
|
+
from opentelemetry.trace import get_tracer
|
|
7
|
+
from temporalio.client import ClientConfig, Plugin as ClientPlugin
|
|
8
|
+
from temporalio.contrib.opentelemetry import TracingInterceptor
|
|
9
|
+
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
|
|
10
|
+
from temporalio.service import ConnectConfig, ServiceClient
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _default_setup_logfire() -> Logfire:
    """Configure Logfire with its defaults, instrument Pydantic AI, and return the configured instance."""
    import logfire

    configured_instance = logfire.configure()
    logfire.instrument_pydantic_ai()
    return configured_instance
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class LogfirePlugin(ClientPlugin):
    """Temporal client plugin for Logfire.

    Adds an OpenTelemetry tracing interceptor to the client. Optionally, it also points the Temporal
    runtime's metrics exporter at Logfire.
    """

    def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire, *, metrics: bool = True):
        """Create the plugin.

        Args:
            setup_logfire: Called when the service client connects. It must configure Logfire and return the instance.
            metrics: Whether to export Temporal runtime metrics to Logfire when Logfire's config allows it.
        """
        self.setup_logfire = setup_logfire
        self.metrics = metrics

    def configure_client(self, config: ClientConfig) -> ClientConfig:
        """Append an OpenTelemetry `TracingInterceptor` after any interceptors already on `config`."""
        existing_interceptors = config.get('interceptors', [])
        tracing_interceptor = TracingInterceptor(get_tracer('temporalio'))
        config['interceptors'] = [*existing_interceptors, tracing_interceptor]
        return super().configure_client(config)

    async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
        """Set up Logfire, optionally route Temporal runtime metrics to it, then connect."""
        logfire_instance = self.setup_logfire()

        if self.metrics:
            lf_config = logfire_instance.config
            token = lf_config.token
            if lf_config.send_to_logfire and token is not None and lf_config.metrics is not False:
                metrics_url = lf_config.advanced.generate_base_url(token) + '/v1/metrics'
                auth_headers = {'Authorization': f'Bearer {token}'}
                # NOTE(review): this unconditionally replaces any `runtime` already present on `config`.
                config.runtime = Runtime(
                    telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url, headers=auth_headers))
                )

        return await super().connect_service_client(config)
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Callable, Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import ConfigDict, with_config
|
|
7
|
+
from temporalio import activity, workflow
|
|
8
|
+
from temporalio.workflow import ActivityConfig
|
|
9
|
+
from typing_extensions import Self
|
|
10
|
+
|
|
11
|
+
from pydantic_ai.exceptions import UserError
|
|
12
|
+
from pydantic_ai.mcp import MCPServer, ToolResult
|
|
13
|
+
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
|
|
14
|
+
from pydantic_ai.toolsets.abstract import ToolsetTool
|
|
15
|
+
|
|
16
|
+
from ._run_context import TemporalRunContext
|
|
17
|
+
from ._toolset import TemporalWrapperToolset
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _GetToolsParams:
    """Payload sent to the MCP server `get_tools` Temporal activity."""

    # Whatever `TemporalRunContext.serialize_run_context` (or a subclass override) returned.
    serialized_run_context: Any
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _CallToolParams:
    """Payload sent to the MCP server `call_tool` Temporal activity.

    The `ToolDefinition` is sent rather than a `ToolsetTool`. The activity rebuilds the `ToolsetTool` on its side.
    """

    # Name of the MCP tool to invoke.
    name: str
    # Arguments for the tool call.
    tool_args: dict[str, Any]
    # Whatever `TemporalRunContext.serialize_run_context` (or a subclass override) returned.
    serialized_run_context: Any
    # Definition of the tool being called.
    tool_def: ToolDefinition
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
    """Wraps an `MCPServer` so that, inside a Temporal workflow, listing and calling tools run in Temporal activities.

    Outside a workflow every call is delegated unchanged to the wrapped server.
    """

    def __init__(
        self,
        server: MCPServer,
        *,
        activity_name_prefix: str,
        activity_config: ActivityConfig,
        tool_activity_config: dict[str, ActivityConfig | Literal[False]],
        deps_type: type[AgentDepsT],
        run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
    ):
        """Wrap `server` and define its `get_tools` and `call_tool` activities.

        Args:
            server: The MCP server to wrap.
            activity_name_prefix: Prefix used to build the activity names.
            activity_config: Default activity options for listing and calling tools.
            tool_activity_config: Per-tool activity options keyed by tool name, merged over `activity_config`.
            deps_type: The agent's dependencies type, set as the activities' `deps` type hint.
            run_context_type: The `TemporalRunContext` (sub)class used to (de)serialize the run context.

        Raises:
            UserError: If any entry of `tool_activity_config` is `False`.
        """
        super().__init__(server)
        self.activity_config = activity_config

        for tool_name, tool_config in tool_activity_config.items():
            if tool_config is False:
                raise UserError(
                    f'Temporal activity config for MCP tool {tool_name!r} has been explicitly set to `False` (activity disabled), '
                    'but MCP tools require the use of IO and so cannot be run outside of an activity.'
                )
        # Every value is an `ActivityConfig` by now. The filter only makes that visible to the type checker.
        self.tool_activity_config: dict[str, ActivityConfig] = {
            tool_name: tool_config
            for tool_name, tool_config in tool_activity_config.items()
            if tool_config is not False
        }

        self.run_context_type = run_context_type

        async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[str, ToolDefinition]:
            run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
            available_tools = await self.wrapped.get_tools(run_context)
            # ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
            # so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
            return {name: toolset_tool.tool_def for name, toolset_tool in available_tools.items()}

        # Replace the generic `deps` hint with the concrete deps type so that Temporal can take care of
        # serialization and deserialization. This is done before `activity.defn` is applied below.
        get_tools_activity.__annotations__['deps'] = deps_type

        get_tools_name = f'{activity_name_prefix}__mcp_server__{self.id}__get_tools'
        self.get_tools_activity = activity.defn(name=get_tools_name)(get_tools_activity)

        async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> ToolResult:
            run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
            toolset_tool = self.tool_for_tool_def(params.tool_def)
            return await self.wrapped.call_tool(params.name, params.tool_args, run_context, toolset_tool)

        # Same reason as for `get_tools_activity` above.
        call_tool_activity.__annotations__['deps'] = deps_type

        call_tool_name = f'{activity_name_prefix}__mcp_server__{self.id}__call_tool'
        self.call_tool_activity = activity.defn(name=call_tool_name)(call_tool_activity)

    def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
        """Build a `ToolsetTool` for `tool_def` using the wrapped MCP server."""
        assert isinstance(self.wrapped, MCPServer)
        return self.wrapped.tool_for_tool_def(tool_def)

    @property
    def temporal_activities(self) -> list[Callable[..., Any]]:
        """The Temporal activities that must be registered on the worker for this server."""
        return [self.get_tools_activity, self.call_tool_activity]

    async def __aenter__(self) -> Self:
        # The wrapped MCPServer enters itself around listing and calling tools
        # so we don't need to enter it here (nor could we because we're not inside a Temporal activity).
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        # Nothing was entered in `__aenter__`, so there is nothing to exit.
        return None

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """List the server's tools, via the `get_tools` activity when running inside a workflow."""
        if workflow.in_workflow():
            serialized_run_context = self.run_context_type.serialize_run_context(ctx)
            params = _GetToolsParams(serialized_run_context=serialized_run_context)
            tool_defs = await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
                activity=self.get_tools_activity,
                args=[params, ctx.deps],
                **self.activity_config,
            )
            return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}

        return await super().get_tools(ctx)

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> ToolResult:
        """Call an MCP tool, via the `call_tool` activity when running inside a workflow."""
        if workflow.in_workflow():
            merged_config = self.activity_config | self.tool_activity_config.get(name, {})
            serialized_run_context = self.run_context_type.serialize_run_context(ctx)
            params = _CallToolParams(
                name=name,
                tool_args=tool_args,
                serialized_run_context=serialized_run_context,
                tool_def=tool.tool_def,
            )
            return await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
                activity=self.call_tool_activity,
                args=[params, ctx.deps],
                **merged_config,
            )

        return await super().call_tool(name, tool_args, ctx, tool)
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Any, Callable
|
|
8
|
+
|
|
9
|
+
from pydantic import ConfigDict, with_config
|
|
10
|
+
from temporalio import activity, workflow
|
|
11
|
+
from temporalio.workflow import ActivityConfig
|
|
12
|
+
|
|
13
|
+
from pydantic_ai.agent import EventStreamHandler
|
|
14
|
+
from pydantic_ai.exceptions import UserError
|
|
15
|
+
from pydantic_ai.messages import (
|
|
16
|
+
ModelMessage,
|
|
17
|
+
ModelResponse,
|
|
18
|
+
ModelResponseStreamEvent,
|
|
19
|
+
)
|
|
20
|
+
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
|
|
21
|
+
from pydantic_ai.models.wrapper import WrapperModel
|
|
22
|
+
from pydantic_ai.settings import ModelSettings
|
|
23
|
+
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
24
|
+
from pydantic_ai.usage import Usage
|
|
25
|
+
|
|
26
|
+
from ._run_context import TemporalRunContext
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _RequestParams:
    """Payload sent to the model request Temporal activities."""

    # Message history for the request.
    messages: list[ModelMessage]
    # Model settings for the request, if any.
    model_settings: ModelSettings | None
    # Request parameters for the wrapped model.
    model_request_parameters: ModelRequestParameters
    # `None` for plain requests; the serialized run context for streaming requests.
    serialized_run_context: Any
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TemporalStreamedResponse(StreamedResponse):
    """A `StreamedResponse` that stands in for a `ModelResponse` that has already been fully received.

    It emits no events of its own, and `get()` returns the stored response.
    """

    def __init__(self, model_request_parameters: ModelRequestParameters, response: ModelResponse):
        super().__init__(model_request_parameters)
        # The complete response this stream stands in for.
        self.response = response

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        # There is nothing to stream, so return immediately.
        # The unreachable `yield` below is what makes this an async generator.
        return
        # noinspection PyUnreachableCode
        yield

    def get(self) -> ModelResponse:
        """Return the stored, complete response."""
        return self.response

    def usage(self) -> Usage:
        """Return the usage recorded on the stored response."""
        return self.response.usage  # pragma: no cover

    @property
    def model_name(self) -> str:
        """The stored response's model name, or `''` if it has none."""
        return self.response.model_name or ''  # pragma: no cover

    @property
    def timestamp(self) -> datetime:
        """The stored response's timestamp."""
        return self.response.timestamp  # pragma: no cover
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class TemporalModel(WrapperModel):
    """Wraps a `Model` so that, inside a Temporal workflow, model requests run in Temporal activities.

    Outside a workflow every request is delegated unchanged to the wrapped model.
    """

    def __init__(
        self,
        model: Model,
        *,
        activity_name_prefix: str,
        activity_config: ActivityConfig,
        deps_type: type[AgentDepsT],
        run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
        event_stream_handler: EventStreamHandler[Any] | None = None,
    ):
        """Wrap `model` and define its `request` and `request_stream` activities.

        Args:
            model: The model to wrap.
            activity_name_prefix: Prefix used to build the activity names.
            activity_config: Activity options for every model request.
            deps_type: The agent's dependencies type, set as the streaming activity's `deps` type hint.
            run_context_type: The `TemporalRunContext` (sub)class used to (de)serialize the run context.
            event_stream_handler: Handler invoked on the streamed response inside the streaming activity.
        """
        super().__init__(model)
        self.activity_config = activity_config
        self.run_context_type = run_context_type
        self.event_stream_handler = event_stream_handler

        async def request_activity(params: _RequestParams) -> ModelResponse:
            return await self.wrapped.request(params.messages, params.model_settings, params.model_request_parameters)

        self.request_activity = activity.defn(name=f'{activity_name_prefix}__model_request')(request_activity)

        async def request_stream_activity(params: _RequestParams, deps: AgentDepsT) -> ModelResponse:
            # An error is raised in `request_stream` if no `event_stream_handler` is set.
            assert self.event_stream_handler is not None

            run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
            stream_cm = self.wrapped.request_stream(
                params.messages, params.model_settings, params.model_request_parameters, run_context
            )
            async with stream_cm as streamed_response:
                await self.event_stream_handler(run_context, streamed_response)
                # Consume any events the handler left unconsumed before reading the final response.
                async for _ in streamed_response:
                    pass
                return streamed_response.get()

        # Replace the generic `deps` hint with the concrete deps type so that Temporal can take care of
        # serialization and deserialization. This is done before `activity.defn` is applied below.
        request_stream_activity.__annotations__['deps'] = deps_type

        stream_activity_name = f'{activity_name_prefix}__model_request_stream'
        self.request_stream_activity = activity.defn(name=stream_activity_name)(request_stream_activity)

    @property
    def temporal_activities(self) -> list[Callable[..., Any]]:
        """The Temporal activities that must be registered on the worker for this model."""
        return [self.request_activity, self.request_stream_activity]

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Make a model request, via the `request` activity when running inside a workflow."""
        if workflow.in_workflow():
            params = _RequestParams(
                messages=messages,
                model_settings=model_settings,
                model_request_parameters=model_request_parameters,
                serialized_run_context=None,
            )
            return await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
                activity=self.request_activity,
                arg=params,
                **self.activity_config,
            )

        return await super().request(messages, model_settings, model_request_parameters)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streamed model request.

        Inside a workflow, the whole stream is consumed by the streaming activity, which also passes it to
        `event_stream_handler`. The result is then yielded as a `TemporalStreamedResponse` holding the
        complete response.

        Raises:
            UserError: Inside a workflow, if `run_context` is `None`.
        """
        if workflow.in_workflow():
            if run_context is None:
                raise UserError(
                    'A Temporal model cannot be used with `pydantic_ai.direct.model_request_stream()` as it requires a `run_context`. Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
                )

            # We can never get here without an `event_stream_handler`, as `TemporalAgent.run_stream` and `TemporalAgent.iter` raise an error saying to use `TemporalAgent.run` instead,
            # and that only calls `request_stream` if `event_stream_handler` is set.
            assert self.event_stream_handler is not None

            serialized_run_context = self.run_context_type.serialize_run_context(run_context)
            params = _RequestParams(
                messages=messages,
                model_settings=model_settings,
                model_request_parameters=model_request_parameters,
                serialized_run_context=serialized_run_context,
            )
            completed_response = await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
                activity=self.request_stream_activity,
                args=[params, run_context.deps],
                **self.activity_config,
            )
            yield TemporalStreamedResponse(model_request_parameters, completed_response)
        else:
            async with super().request_stream(
                messages, model_settings, model_request_parameters, run_context
            ) as passthrough_response:
                yield passthrough_response
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic_ai.exceptions import UserError
|
|
6
|
+
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TemporalRunContext(RunContext[AgentDepsT]):
    """The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.

    By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `retry` and `run_step` attributes will be available.
    To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
    """

    def __init__(self, deps: AgentDepsT, **kwargs: Any):
        """Build a partial run context from `deps` plus the serialized attributes in `kwargs`.

        `RunContext`'s dataclass `__init__` is deliberately not called. Only the given attributes are
        set on the instance.
        """
        self.__dict__ = {**kwargs, 'deps': deps}
        # Shadow the class-level `__dataclass_fields__` with only the `RunContext` fields actually present on
        # this instance. Presumably this is so that dataclass helpers working on the instance ignore the missing ones.
        setattr(
            self,
            '__dataclass_fields__',
            {name: field for name, field in RunContext.__dataclass_fields__.items() if name in self.__dict__},
        )

    def __getattribute__(self, name: str) -> Any:
        """Look up `name` normally, turning a missing `RunContext` field into an actionable `UserError`.

        NOTE(review): a `RunContext` field that has a plain class-level default may resolve to that default
        instead of raising. Confirm whether that is intended.

        Raises:
            UserError: If `name` is a `RunContext` field that was not serialized into this context.
            AttributeError: For any other missing attribute (re-raised unchanged).
        """
        try:
            return super().__getattribute__(name)
        except AttributeError as e:  # pragma: no cover
            if name in RunContext.__dataclass_fields__:
                # Chain explicitly so the original AttributeError is reported as the direct cause.
                raise UserError(
                    f'{self.__class__.__name__!r} object has no attribute {name!r}. '
                    'To make the attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to `TemporalAgent`.'
                ) from e
            # A bare `raise` re-raises the original error with its traceback intact.
            raise

    @classmethod
    def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
        """Serialize the run context to a `dict[str, Any]`.

        `deps` is not included; it is passed to the activity as a separate argument.
        Override this in a subclass to make more attributes available inside activities.
        """
        return {
            'retries': ctx.retries,
            'tool_call_id': ctx.tool_call_id,
            'tool_name': ctx.tool_name,
            'retry': ctx.retry,
            'run_step': ctx.run_step,
        }

    @classmethod
    def deserialize_run_context(cls, ctx: dict[str, Any], deps: AgentDepsT) -> TemporalRunContext[AgentDepsT]:
        """Deserialize the run context from a `dict[str, Any]`, reattaching `deps`."""
        return cls(**ctx, deps=deps)
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any, Callable, Literal
|
|
5
|
+
|
|
6
|
+
from temporalio.workflow import ActivityConfig
|
|
7
|
+
|
|
8
|
+
from pydantic_ai.mcp import MCPServer
|
|
9
|
+
from pydantic_ai.tools import AgentDepsT
|
|
10
|
+
from pydantic_ai.toolsets.abstract import AbstractToolset
|
|
11
|
+
from pydantic_ai.toolsets.function import FunctionToolset
|
|
12
|
+
from pydantic_ai.toolsets.wrapper import WrapperToolset
|
|
13
|
+
|
|
14
|
+
from ._run_context import TemporalRunContext
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
    """Common base for toolsets whose work is routed through Temporal activities when inside a workflow."""

    @property
    def id(self) -> str:
        """The wrapped toolset's `id`. Subclasses use it to build their activity names."""
        wrapped_id = self.wrapped.id
        # An error is raised in `TemporalAgent` if no `id` is set, so this should always hold here.
        assert wrapped_id is not None
        return wrapped_id

    @property
    @abstractmethod
    def temporal_activities(self) -> list[Callable[..., Any]]:
        """The Temporal activities that must be registered on the worker for this toolset."""
        raise NotImplementedError

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        """Return `self` unchanged; `visitor` is intentionally ignored.

        Temporalized toolsets cannot be swapped out after the fact.
        """
        return self
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def temporalize_toolset(
    toolset: AbstractToolset[AgentDepsT],
    activity_name_prefix: str,
    activity_config: ActivityConfig,
    tool_activity_config: dict[str, ActivityConfig | Literal[False]],
    deps_type: type[AgentDepsT],
    run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
) -> AbstractToolset[AgentDepsT]:
    """Temporalize a toolset.

    A `FunctionToolset` is wrapped in a `TemporalFunctionToolset`, and an `MCPServer` in a
    `TemporalMCPServer`. Any other toolset is returned as-is.

    Args:
        toolset: The toolset to temporalize.
        activity_name_prefix: Prefix for Temporal activity names.
        activity_config: The Temporal activity config to use.
        tool_activity_config: The Temporal activity config to use for specific tools identified by tool name.
        deps_type: The type of agent's dependencies object. It needs to be serializable using Pydantic's `TypeAdapter`.
        run_context_type: The `TemporalRunContext` (sub)class that's used to serialize and deserialize the run context.

    Returns:
        The wrapping toolset, or `toolset` itself when its type is not supported.
    """
    # Both wrapper classes take exactly the same keyword arguments.
    wrapper_kwargs: dict[str, Any] = {
        'activity_name_prefix': activity_name_prefix,
        'activity_config': activity_config,
        'tool_activity_config': tool_activity_config,
        'deps_type': deps_type,
        'run_context_type': run_context_type,
    }

    if isinstance(toolset, FunctionToolset):
        from ._function_toolset import TemporalFunctionToolset

        return TemporalFunctionToolset(toolset, **wrapper_kwargs)

    if isinstance(toolset, MCPServer):
        from ._mcp_server import TemporalMCPServer

        return TemporalMCPServer(toolset, **wrapper_kwargs)

    return toolset
|
pydantic_ai/ext/aci.py
CHANGED
|
@@ -1,17 +1,16 @@
|
|
|
1
|
-
|
|
2
|
-
try:
|
|
3
|
-
from aci import ACI
|
|
4
|
-
except ImportError as _import_error:
|
|
5
|
-
raise ImportError('Please install `aci-sdk` to use ACI.dev tools') from _import_error
|
|
1
|
+
from __future__ import annotations
|
|
6
2
|
|
|
7
3
|
from collections.abc import Sequence
|
|
8
4
|
from typing import Any
|
|
9
5
|
|
|
10
|
-
from aci import ACI
|
|
11
|
-
|
|
12
6
|
from pydantic_ai.tools import Tool
|
|
13
7
|
from pydantic_ai.toolsets.function import FunctionToolset
|
|
14
8
|
|
|
9
|
+
try:
|
|
10
|
+
from aci import ACI
|
|
11
|
+
except ImportError as _import_error:
|
|
12
|
+
raise ImportError('Please install `aci-sdk` to use ACI.dev tools') from _import_error
|
|
13
|
+
|
|
15
14
|
|
|
16
15
|
def _clean_schema(schema):
|
|
17
16
|
if isinstance(schema, dict):
|
|
@@ -71,5 +70,7 @@ def tool_from_aci(aci_function: str, linked_account_owner_id: str) -> Tool:
|
|
|
71
70
|
class ACIToolset(FunctionToolset):
    """A toolset that wraps ACI.dev tools."""

    def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str, *, id: str | None = None):
        """Build one tool per ACI.dev function.

        Args:
            aci_functions: The ACI.dev function identifiers to convert into tools.
            linked_account_owner_id: Passed to `tool_from_aci` for every function.
            id: Optional toolset id, forwarded to `FunctionToolset`.
        """
        aci_tools = [tool_from_aci(function_id, linked_account_owner_id) for function_id in aci_functions]
        super().__init__(aci_tools, id=id)
|
pydantic_ai/ext/langchain.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from typing import Any, Protocol
|
|
2
4
|
|
|
3
5
|
from pydantic.json_schema import JsonSchemaValue
|
|
@@ -65,5 +67,5 @@ def tool_from_langchain(langchain_tool: LangChainTool) -> Tool:
|
|
|
65
67
|
class LangChainToolset(FunctionToolset):
    """A toolset that wraps LangChain tools."""

    def __init__(self, tools: list[LangChainTool], *, id: str | None = None):
        """Convert each LangChain tool with `tool_from_langchain`.

        Args:
            tools: The LangChain tools to expose.
            id: Optional toolset id, forwarded to `FunctionToolset`.
        """
        converted_tools = [tool_from_langchain(langchain_tool) for langchain_tool in tools]
        super().__init__(converted_tools, id=id)
|