pydantic-ai-slim 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (57)
  1. pydantic_ai/_a2a.py +6 -4
  2. pydantic_ai/_agent_graph.py +25 -32
  3. pydantic_ai/_cli.py +3 -3
  4. pydantic_ai/_output.py +8 -0
  5. pydantic_ai/_tool_manager.py +3 -0
  6. pydantic_ai/ag_ui.py +25 -14
  7. pydantic_ai/{agent.py → agent/__init__.py} +209 -1027
  8. pydantic_ai/agent/abstract.py +942 -0
  9. pydantic_ai/agent/wrapper.py +227 -0
  10. pydantic_ai/direct.py +9 -9
  11. pydantic_ai/durable_exec/__init__.py +0 -0
  12. pydantic_ai/durable_exec/temporal/__init__.py +83 -0
  13. pydantic_ai/durable_exec/temporal/_agent.py +699 -0
  14. pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
  15. pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
  16. pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
  17. pydantic_ai/durable_exec/temporal/_model.py +168 -0
  18. pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
  19. pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
  20. pydantic_ai/ext/aci.py +10 -9
  21. pydantic_ai/ext/langchain.py +4 -2
  22. pydantic_ai/mcp.py +203 -75
  23. pydantic_ai/messages.py +2 -2
  24. pydantic_ai/models/__init__.py +65 -9
  25. pydantic_ai/models/anthropic.py +16 -7
  26. pydantic_ai/models/bedrock.py +8 -5
  27. pydantic_ai/models/cohere.py +1 -4
  28. pydantic_ai/models/fallback.py +4 -2
  29. pydantic_ai/models/function.py +9 -4
  30. pydantic_ai/models/gemini.py +15 -9
  31. pydantic_ai/models/google.py +18 -14
  32. pydantic_ai/models/groq.py +17 -14
  33. pydantic_ai/models/huggingface.py +18 -12
  34. pydantic_ai/models/instrumented.py +3 -1
  35. pydantic_ai/models/mcp_sampling.py +3 -1
  36. pydantic_ai/models/mistral.py +12 -18
  37. pydantic_ai/models/openai.py +29 -26
  38. pydantic_ai/models/test.py +3 -0
  39. pydantic_ai/models/wrapper.py +6 -2
  40. pydantic_ai/profiles/openai.py +1 -1
  41. pydantic_ai/providers/google.py +7 -7
  42. pydantic_ai/result.py +21 -55
  43. pydantic_ai/run.py +357 -0
  44. pydantic_ai/tools.py +0 -1
  45. pydantic_ai/toolsets/__init__.py +2 -0
  46. pydantic_ai/toolsets/_dynamic.py +87 -0
  47. pydantic_ai/toolsets/abstract.py +23 -3
  48. pydantic_ai/toolsets/combined.py +19 -4
  49. pydantic_ai/toolsets/deferred.py +10 -2
  50. pydantic_ai/toolsets/function.py +23 -8
  51. pydantic_ai/toolsets/prefixed.py +4 -0
  52. pydantic_ai/toolsets/wrapper.py +14 -1
  53. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +6 -4
  54. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/RECORD +57 -44
  55. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
  56. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
  57. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Literal
5
+
6
+ from pydantic import ConfigDict, with_config
7
+ from temporalio import activity, workflow
8
+ from temporalio.workflow import ActivityConfig
9
+
10
+ from pydantic_ai.exceptions import UserError
11
+ from pydantic_ai.tools import AgentDepsT, RunContext
12
+ from pydantic_ai.toolsets import FunctionToolset, ToolsetTool
13
+ from pydantic_ai.toolsets.function import FunctionToolsetTool
14
+
15
+ from ._run_context import TemporalRunContext
16
+ from ._toolset import TemporalWrapperToolset
17
+
18
+
19
@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _CallToolParams:
    """Payload passed to the function toolset's `call_tool` activity."""

    # Name of the tool to call.
    name: str
    # Arguments to call the tool with.
    tool_args: dict[str, Any]
    # Result of `serialize_run_context`; deserialized again inside the activity.
    serialized_run_context: Any
25
+
26
+
27
class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
    """A `FunctionToolset` wrapper that runs tool calls as Temporal activities when called from a workflow.

    Outside of a workflow, tool calls are passed straight through to the parent class.
    """

    def __init__(
        self,
        toolset: FunctionToolset[AgentDepsT],
        *,
        activity_name_prefix: str,
        activity_config: ActivityConfig,
        tool_activity_config: dict[str, ActivityConfig | Literal[False]],
        deps_type: type[AgentDepsT],
        run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
    ):
        """Wrap `toolset` and define the Temporal activity used to call its tools.

        Args:
            toolset: The function toolset to wrap.
            activity_name_prefix: Prefix for the Temporal activity name.
            activity_config: The base Temporal activity config used for tool calls.
            tool_activity_config: Per-tool activity config keyed by tool name, merged over `activity_config`.
                `False` means the tool is called directly in the workflow instead of in an activity.
            deps_type: The type of the dependencies object; set as the type hint of the activity's `deps` parameter.
            run_context_type: The `TemporalRunContext` (sub)class used to serialize and deserialize the run context.
        """
        super().__init__(toolset)
        self.activity_config = activity_config
        self.tool_activity_config = tool_activity_config
        self.run_context_type = run_context_type

        async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> Any:
            name = params.name
            run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
            try:
                # The tool is looked up again inside the activity, from the toolset's current tools.
                available_tools = await toolset.get_tools(run_context)
                tool = available_tools[name]
            except KeyError as e:  # pragma: no cover
                raise UserError(
                    f'Tool {name!r} not found in toolset {self.id!r}. '
                    'Removing or renaming tools during an agent run is not supported with Temporal.'
                ) from e

            return await self.wrapped.call_tool(name, params.tool_args, run_context, tool)

        # Set type hint explicitly so that Temporal can take care of serialization and deserialization
        call_tool_activity.__annotations__['deps'] = deps_type

        activity_name = f'{activity_name_prefix}__toolset__{self.id}__call_tool'
        self.call_tool_activity = activity.defn(name=activity_name)(call_tool_activity)

    @property
    def temporal_activities(self) -> list[Callable[..., Any]]:
        """The activities defined by this toolset: just the `call_tool` activity."""
        return [self.call_tool_activity]

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Call a tool, going through the `call_tool` activity when inside a workflow.

        Raises:
            UserError: If the activity has been disabled for a tool that is not async.
        """
        if not workflow.in_workflow():
            return await super().call_tool(name, tool_args, ctx, tool)

        per_tool_config = self.tool_activity_config.get(name, {})
        if per_tool_config is False:
            # Activity explicitly disabled: the tool runs directly in the workflow, which only works for async tools.
            assert isinstance(tool, FunctionToolsetTool)
            if not tool.is_async:
                raise UserError(
                    f'Temporal activity config for tool {name!r} has been explicitly set to `False` (activity disabled), '
                    'but non-async tools are run in threads which are not supported outside of an activity. Make the tool function async instead.'
                )
            return await super().call_tool(name, tool_args, ctx, tool)

        # Tool-specific settings take precedence over the base activity config.
        merged_config = self.activity_config | per_tool_config
        call_params = _CallToolParams(
            name=name,
            tool_args=tool_args,
            serialized_run_context=self.run_context_type.serialize_run_context(ctx),
        )
        return await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
            activity=self.call_tool_activity,
            args=[call_params, ctx.deps],
            **merged_config,
        )
@@ -0,0 +1,48 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable
4
+
5
+ from logfire import Logfire
6
+ from opentelemetry.trace import get_tracer
7
+ from temporalio.client import ClientConfig, Plugin as ClientPlugin
8
+ from temporalio.contrib.opentelemetry import TracingInterceptor
9
+ from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
10
+ from temporalio.service import ConnectConfig, ServiceClient
11
+
12
+
13
def _default_setup_logfire() -> Logfire:
    """Configure Logfire with its defaults, instrument Pydantic AI, and return the configured instance."""
    import logfire

    configured = logfire.configure()
    logfire.instrument_pydantic_ai()
    return configured
19
+
20
+
21
class LogfirePlugin(ClientPlugin):
    """Temporal client plugin for Logfire."""

    def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire, *, metrics: bool = True):
        """Store the Logfire setup callable and whether metrics export should be configured.

        Args:
            setup_logfire: Callable returning the `Logfire` instance to use; called when the service client connects.
            metrics: Whether to configure the Temporal runtime to send metrics to Logfire.
        """
        self.setup_logfire = setup_logfire
        self.metrics = metrics

    def configure_client(self, config: ClientConfig) -> ClientConfig:
        """Append an OpenTelemetry tracing interceptor to the client's interceptors."""
        existing = config.get('interceptors', [])
        tracing = TracingInterceptor(get_tracer('temporalio'))
        config['interceptors'] = [*existing, tracing]
        return super().configure_client(config)

    async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
        """Set up Logfire and, if applicable, point the Temporal runtime's metrics at Logfire before connecting."""
        logfire_instance = self.setup_logfire()

        if self.metrics:
            settings = logfire_instance.config
            token = settings.token
            # Metrics are only exported when Logfire is sending data, has a token and has not disabled metrics.
            if settings.send_to_logfire and token is not None and settings.metrics is not False:
                metrics_url = settings.advanced.generate_base_url(token) + '/v1/metrics'
                otel_config = OpenTelemetryConfig(url=metrics_url, headers={'Authorization': f'Bearer {token}'})
                config.runtime = Runtime(telemetry=TelemetryConfig(metrics=otel_config))

        return await super().connect_service_client(config)
@@ -0,0 +1,145 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Literal
5
+
6
+ from pydantic import ConfigDict, with_config
7
+ from temporalio import activity, workflow
8
+ from temporalio.workflow import ActivityConfig
9
+ from typing_extensions import Self
10
+
11
+ from pydantic_ai.exceptions import UserError
12
+ from pydantic_ai.mcp import MCPServer, ToolResult
13
+ from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
14
+ from pydantic_ai.toolsets.abstract import ToolsetTool
15
+
16
+ from ._run_context import TemporalRunContext
17
+ from ._toolset import TemporalWrapperToolset
18
+
19
+
20
@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _GetToolsParams:
    """Payload passed to the MCP server's `get_tools` activity."""

    # Result of `serialize_run_context`; deserialized again inside the activity.
    serialized_run_context: Any
24
+
25
+
26
@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _CallToolParams:
    """Payload passed to the MCP server's `call_tool` activity."""

    # Name of the tool to call.
    name: str
    # Arguments to call the tool with.
    tool_args: dict[str, Any]
    # Result of `serialize_run_context`; deserialized again inside the activity.
    serialized_run_context: Any
    # Definition of the tool, used inside the activity to rebuild the `ToolsetTool`.
    tool_def: ToolDefinition
33
+
34
+
35
class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
    """An `MCPServer` wrapper that lists and calls tools through Temporal activities when called from a workflow.

    Outside of a workflow, listing and calling tools is passed straight through to the parent class.
    """

    def __init__(
        self,
        server: MCPServer,
        *,
        activity_name_prefix: str,
        activity_config: ActivityConfig,
        tool_activity_config: dict[str, ActivityConfig | Literal[False]],
        deps_type: type[AgentDepsT],
        run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
    ):
        """Wrap `server` and define the Temporal activities used to list and call its tools.

        Args:
            server: The MCP server to wrap.
            activity_name_prefix: Prefix for the Temporal activity names.
            activity_config: The base Temporal activity config.
            tool_activity_config: Per-tool activity config keyed by tool name, merged over `activity_config`
                when calling that tool. `False` is not allowed for MCP tools.
            deps_type: The type of the dependencies object; set as the type hint of the activities' `deps` parameter.
            run_context_type: The `TemporalRunContext` (sub)class used to serialize and deserialize the run context.

        Raises:
            UserError: If any entry in `tool_activity_config` is `False`.
        """
        super().__init__(server)
        self.activity_config = activity_config

        # Unlike function tools, MCP tools can never run outside an activity, so `False` is rejected up front.
        self.tool_activity_config: dict[str, ActivityConfig] = {}
        for tool_name, tool_config in tool_activity_config.items():
            if tool_config is False:
                raise UserError(
                    f'Temporal activity config for MCP tool {tool_name!r} has been explicitly set to `False` (activity disabled), '
                    'but MCP tools require the use of IO and so cannot be run outside of an activity.'
                )
            self.tool_activity_config[tool_name] = tool_config

        self.run_context_type = run_context_type

        async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[str, ToolDefinition]:
            ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
            toolset_tools = await self.wrapped.get_tools(ctx)
            # ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
            # so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
            return {key: toolset_tool.tool_def for key, toolset_tool in toolset_tools.items()}

        # Set type hint explicitly so that Temporal can take care of serialization and deserialization
        get_tools_activity.__annotations__['deps'] = deps_type

        get_tools_name = f'{activity_name_prefix}__mcp_server__{self.id}__get_tools'
        self.get_tools_activity = activity.defn(name=get_tools_name)(get_tools_activity)

        async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> ToolResult:
            ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
            # Rebuild the `ToolsetTool` from the definition that was sent along with the call.
            rebuilt_tool = self.tool_for_tool_def(params.tool_def)
            return await self.wrapped.call_tool(params.name, params.tool_args, ctx, rebuilt_tool)

        # Set type hint explicitly so that Temporal can take care of serialization and deserialization
        call_tool_activity.__annotations__['deps'] = deps_type

        call_tool_name = f'{activity_name_prefix}__mcp_server__{self.id}__call_tool'
        self.call_tool_activity = activity.defn(name=call_tool_name)(call_tool_activity)

    def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
        """Build a `ToolsetTool` for `tool_def` by delegating to the wrapped `MCPServer`."""
        assert isinstance(self.wrapped, MCPServer)
        return self.wrapped.tool_for_tool_def(tool_def)

    @property
    def temporal_activities(self) -> list[Callable[..., Any]]:
        """The activities defined by this server: `get_tools` and `call_tool`."""
        return [self.get_tools_activity, self.call_tool_activity]

    async def __aenter__(self) -> Self:
        # The wrapped MCPServer enters itself around listing and calling tools
        # so we don't need to enter it here (nor could we because we're not inside a Temporal activity).
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        # Nothing was entered in `__aenter__`, so there is nothing to clean up.
        return None

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """List the server's tools, going through the `get_tools` activity when inside a workflow."""
        if not workflow.in_workflow():
            return await super().get_tools(ctx)

        get_tools_params = _GetToolsParams(serialized_run_context=self.run_context_type.serialize_run_context(ctx))
        tool_defs = await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
            activity=self.get_tools_activity,
            args=[get_tools_params, ctx.deps],
            **self.activity_config,
        )
        # The activity only returns tool definitions; wrap them in `ToolsetTool`s here.
        return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> ToolResult:
        """Call a tool, going through the `call_tool` activity when inside a workflow."""
        if not workflow.in_workflow():
            return await super().call_tool(name, tool_args, ctx, tool)

        # Tool-specific settings take precedence over the base activity config.
        merged_config = self.activity_config | self.tool_activity_config.get(name, {})
        call_params = _CallToolParams(
            name=name,
            tool_args=tool_args,
            serialized_run_context=self.run_context_type.serialize_run_context(ctx),
            tool_def=tool.tool_def,
        )
        return await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
            activity=self.call_tool_activity,
            args=[call_params, ctx.deps],
            **merged_config,
        )
@@ -0,0 +1,168 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import AsyncIterator
4
+ from contextlib import asynccontextmanager
5
+ from dataclasses import dataclass
6
+ from datetime import datetime
7
+ from typing import Any, Callable
8
+
9
+ from pydantic import ConfigDict, with_config
10
+ from temporalio import activity, workflow
11
+ from temporalio.workflow import ActivityConfig
12
+
13
+ from pydantic_ai.agent import EventStreamHandler
14
+ from pydantic_ai.exceptions import UserError
15
+ from pydantic_ai.messages import (
16
+ ModelMessage,
17
+ ModelResponse,
18
+ ModelResponseStreamEvent,
19
+ )
20
+ from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
21
+ from pydantic_ai.models.wrapper import WrapperModel
22
+ from pydantic_ai.settings import ModelSettings
23
+ from pydantic_ai.tools import AgentDepsT, RunContext
24
+ from pydantic_ai.usage import Usage
25
+
26
+ from ._run_context import TemporalRunContext
27
+
28
+
29
@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _RequestParams:
    """Payload passed to the model request activities."""

    # The messages to send to the model.
    messages: list[ModelMessage]
    # Model settings for the request, if any.
    model_settings: ModelSettings | None
    # Request parameters forwarded to the wrapped model.
    model_request_parameters: ModelRequestParameters
    # Result of `serialize_run_context` for streamed requests; `None` for non-streamed requests.
    serialized_run_context: Any
36
+
37
+
38
class TemporalStreamedResponse(StreamedResponse):
    """A `StreamedResponse` backed by an already complete `ModelResponse`; it yields no events."""

    def __init__(self, model_request_parameters: ModelRequestParameters, response: ModelResponse):
        super().__init__(model_request_parameters)
        # The finished response that `get()` and the other accessors read from.
        self.response = response

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        # Finish immediately; the unreachable `yield` below is what makes this an async generator.
        return
        # noinspection PyUnreachableCode
        yield

    def get(self) -> ModelResponse:
        """Return the wrapped response."""
        return self.response

    def usage(self) -> Usage:
        """Return the usage recorded on the wrapped response."""
        return self.response.usage  # pragma: no cover

    @property
    def model_name(self) -> str:
        """The wrapped response's model name, or an empty string if it has none."""
        return self.response.model_name or ''  # pragma: no cover

    @property
    def timestamp(self) -> datetime:
        """The wrapped response's timestamp."""
        return self.response.timestamp  # pragma: no cover
61
+
62
+
63
class TemporalModel(WrapperModel):
    """A model wrapper that performs requests as Temporal activities when called from a workflow.

    Outside of a workflow, requests are passed straight through to the parent class.
    """

    def __init__(
        self,
        model: Model,
        *,
        activity_name_prefix: str,
        activity_config: ActivityConfig,
        deps_type: type[AgentDepsT],
        run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
        event_stream_handler: EventStreamHandler[Any] | None = None,
    ):
        """Wrap `model` and define the Temporal activities used for (streamed) requests.

        Args:
            model: The model to wrap.
            activity_name_prefix: Prefix for the Temporal activity names.
            activity_config: The Temporal activity config used for both activities.
            deps_type: The type of the dependencies object; set as the type hint of the stream activity's `deps` parameter.
            run_context_type: The `TemporalRunContext` (sub)class used to serialize and deserialize the run context.
            event_stream_handler: Handler that is given the streamed response inside the stream activity.
        """
        super().__init__(model)
        self.activity_config = activity_config
        self.run_context_type = run_context_type
        self.event_stream_handler = event_stream_handler

        async def request_activity(params: _RequestParams) -> ModelResponse:
            return await self.wrapped.request(params.messages, params.model_settings, params.model_request_parameters)

        self.request_activity = activity.defn(name=f'{activity_name_prefix}__model_request')(request_activity)

        async def request_stream_activity(params: _RequestParams, deps: AgentDepsT) -> ModelResponse:
            # An error is raised in `request_stream` if no `event_stream_handler` is set.
            assert self.event_stream_handler is not None

            ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
            stream_cm = self.wrapped.request_stream(
                params.messages, params.model_settings, params.model_request_parameters, ctx
            )
            async with stream_cm as streamed_response:
                await self.event_stream_handler(ctx, streamed_response)

                # Exhaust the stream before reading the final response from it.
                async for _ in streamed_response:
                    pass
            return streamed_response.get()

        # Set type hint explicitly so that Temporal can take care of serialization and deserialization
        request_stream_activity.__annotations__['deps'] = deps_type

        stream_activity_name = f'{activity_name_prefix}__model_request_stream'
        self.request_stream_activity = activity.defn(name=stream_activity_name)(request_stream_activity)

    @property
    def temporal_activities(self) -> list[Callable[..., Any]]:
        """The activities defined by this model: the request and the streamed request activity."""
        return [self.request_activity, self.request_stream_activity]

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Make a model request, going through the request activity when inside a workflow."""
        if not workflow.in_workflow():
            return await super().request(messages, model_settings, model_request_parameters)

        # The non-streamed activity does not use a run context, so none is serialized.
        request_params = _RequestParams(
            messages=messages,
            model_settings=model_settings,
            model_request_parameters=model_request_parameters,
            serialized_run_context=None,
        )
        return await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
            activity=self.request_activity,
            arg=request_params,
            **self.activity_config,
        )

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streamed model request, going through the stream activity when inside a workflow.

        Inside a workflow the whole stream is consumed within the activity, and the yielded
        `TemporalStreamedResponse` only carries the final response.

        Raises:
            UserError: If called inside a workflow without a `run_context`.
        """
        if not workflow.in_workflow():
            passthrough = super().request_stream(messages, model_settings, model_request_parameters, run_context)
            async with passthrough as streamed_response:
                yield streamed_response
                return

        if run_context is None:
            raise UserError(
                'A Temporal model cannot be used with `pydantic_ai.direct.model_request_stream()` as it requires a `run_context`. Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
            )

        # We can never get here without an `event_stream_handler`, as `TemporalAgent.run_stream` and `TemporalAgent.iter` raise an error saying to use `TemporalAgent.run` instead,
        # and that only calls `request_stream` if `event_stream_handler` is set.
        assert self.event_stream_handler is not None

        request_params = _RequestParams(
            messages=messages,
            model_settings=model_settings,
            model_request_parameters=model_request_parameters,
            serialized_run_context=self.run_context_type.serialize_run_context(run_context),
        )
        final_response = await workflow.execute_activity(  # pyright: ignore[reportUnknownMemberType]
            activity=self.request_stream_activity,
            args=[request_params, run_context.deps],
            **self.activity_config,
        )
        yield TemporalStreamedResponse(model_request_parameters, final_response)
@@ -0,0 +1,50 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from pydantic_ai.exceptions import UserError
6
+ from pydantic_ai.tools import AgentDepsT, RunContext
7
+
8
+
9
class TemporalRunContext(RunContext[AgentDepsT]):
    """The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.

    By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `retry` and `run_step` attributes will be available.
    To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
    """

    def __init__(self, deps: AgentDepsT, **kwargs: Any):
        """Create a run context holding only `deps` and the attributes passed as keyword arguments.

        The parent initializer is intentionally not called: the instance dictionary is replaced wholesale,
        so attributes that were not passed are simply absent.
        """
        self.__dict__ = {**kwargs, 'deps': deps}
        # Shadow the class-level dataclass fields with only those that are actually set on this instance.
        setattr(
            self,
            '__dataclass_fields__',
            {name: field for name, field in RunContext.__dataclass_fields__.items() if name in self.__dict__},
        )

    def __getattribute__(self, name: str) -> Any:
        """Look up an attribute, explaining how to make missing `RunContext` fields available.

        Raises:
            UserError: If `name` is a `RunContext` field that was not included in the serialized run context.
            AttributeError: If `name` is not an attribute at all.
        """
        try:
            return super().__getattribute__(name)
        except AttributeError as e:  # pragma: no cover
            if name in RunContext.__dataclass_fields__:
                # Chain explicitly so the original `AttributeError` is reported as the cause.
                raise UserError(
                    f'{self.__class__.__name__!r} object has no attribute {name!r}. '
                    'To make the attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to `TemporalAgent`.'
                ) from e
            # Not a known `RunContext` field: re-raise the original error unchanged.
            raise

    @classmethod
    def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
        """Serialize the run context to a `dict[str, Any]`."""
        return {
            'retries': ctx.retries,
            'tool_call_id': ctx.tool_call_id,
            'tool_name': ctx.tool_name,
            'retry': ctx.retry,
            'run_step': ctx.run_step,
        }

    @classmethod
    def deserialize_run_context(cls, ctx: dict[str, Any], deps: AgentDepsT) -> TemporalRunContext[AgentDepsT]:
        """Deserialize the run context from a `dict[str, Any]`."""
        return cls(**ctx, deps=deps)
@@ -0,0 +1,77 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Callable, Literal
5
+
6
+ from temporalio.workflow import ActivityConfig
7
+
8
+ from pydantic_ai.mcp import MCPServer
9
+ from pydantic_ai.tools import AgentDepsT
10
+ from pydantic_ai.toolsets.abstract import AbstractToolset
11
+ from pydantic_ai.toolsets.function import FunctionToolset
12
+ from pydantic_ai.toolsets.wrapper import WrapperToolset
13
+
14
+ from ._run_context import TemporalRunContext
15
+
16
+
17
class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
    """Base class for toolset wrappers that expose their operations as Temporal activities."""

    @property
    def id(self) -> str:
        """The ID of the wrapped toolset, which must be set."""
        wrapped_id = self.wrapped.id
        # An error is raised in `TemporalAgent` if no `id` is set.
        assert wrapped_id is not None
        return wrapped_id

    @property
    @abstractmethod
    def temporal_activities(self) -> list[Callable[..., Any]]:
        """The Temporal activities defined by this toolset; subclasses must implement this."""
        raise NotImplementedError

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        """Return this toolset unchanged, without calling `visitor`."""
        # Temporalized toolsets cannot be swapped out after the fact.
        return self
34
+
35
+
36
def temporalize_toolset(
    toolset: AbstractToolset[AgentDepsT],
    activity_name_prefix: str,
    activity_config: ActivityConfig,
    tool_activity_config: dict[str, ActivityConfig | Literal[False]],
    deps_type: type[AgentDepsT],
    run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
) -> AbstractToolset[AgentDepsT]:
    """Temporalize a toolset.

    `FunctionToolset` and `MCPServer` instances are wrapped in their Temporal counterparts;
    any other toolset is returned as is.

    Args:
        toolset: The toolset to temporalize.
        activity_name_prefix: Prefix for Temporal activity names.
        activity_config: The Temporal activity config to use.
        tool_activity_config: The Temporal activity config to use for specific tools identified by tool name.
        deps_type: The type of agent's dependencies object. It needs to be serializable using Pydantic's `TypeAdapter`.
        run_context_type: The `TemporalRunContext` (sub)class that's used to serialize and deserialize the run context.
    """
    if isinstance(toolset, FunctionToolset):
        # Imported only when this branch is taken, as the wrapper module imports from this one.
        from ._function_toolset import TemporalFunctionToolset

        return TemporalFunctionToolset(
            toolset,
            activity_name_prefix=activity_name_prefix,
            activity_config=activity_config,
            tool_activity_config=tool_activity_config,
            deps_type=deps_type,
            run_context_type=run_context_type,
        )

    if isinstance(toolset, MCPServer):
        # Imported only when this branch is taken, as the wrapper module imports from this one.
        from ._mcp_server import TemporalMCPServer

        return TemporalMCPServer(
            toolset,
            activity_name_prefix=activity_name_prefix,
            activity_config=activity_config,
            tool_activity_config=tool_activity_config,
            deps_type=deps_type,
            run_context_type=run_context_type,
        )

    # No Temporal wrapper applies to this toolset type.
    return toolset
pydantic_ai/ext/aci.py CHANGED
@@ -1,17 +1,16 @@
1
- # Checking whether aci-sdk is installed
2
- try:
3
- from aci import ACI
4
- except ImportError as _import_error:
5
- raise ImportError('Please install `aci-sdk` to use ACI.dev tools') from _import_error
1
+ from __future__ import annotations
6
2
 
7
3
  from collections.abc import Sequence
8
4
  from typing import Any
9
5
 
10
- from aci import ACI
11
-
12
6
  from pydantic_ai.tools import Tool
13
7
  from pydantic_ai.toolsets.function import FunctionToolset
14
8
 
9
+ try:
10
+ from aci import ACI
11
+ except ImportError as _import_error:
12
+ raise ImportError('Please install `aci-sdk` to use ACI.dev tools') from _import_error
13
+
15
14
 
16
15
  def _clean_schema(schema):
17
16
  if isinstance(schema, dict):
@@ -71,5 +70,7 @@ def tool_from_aci(aci_function: str, linked_account_owner_id: str) -> Tool:
71
70
  class ACIToolset(FunctionToolset):
72
71
  """A toolset that wraps ACI.dev tools."""
73
72
 
74
- def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str):
75
- super().__init__([tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions])
73
+ def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str, *, id: str | None = None):
74
+ super().__init__(
75
+ [tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions], id=id
76
+ )
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Any, Protocol
2
4
 
3
5
  from pydantic.json_schema import JsonSchemaValue
@@ -65,5 +67,5 @@ def tool_from_langchain(langchain_tool: LangChainTool) -> Tool:
65
67
  class LangChainToolset(FunctionToolset):
66
68
  """A toolset that wraps LangChain tools."""
67
69
 
68
- def __init__(self, tools: list[LangChainTool]):
69
- super().__init__([tool_from_langchain(tool) for tool in tools])
70
+ def __init__(self, tools: list[LangChainTool], *, id: str | None = None):
71
+ super().__init__([tool_from_langchain(tool) for tool in tools], id=id)