pydantic-ai-slim 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic; see the registry's advisory for this release for more details.
- pydantic_ai/_a2a.py +6 -4
- pydantic_ai/_agent_graph.py +25 -32
- pydantic_ai/_cli.py +3 -3
- pydantic_ai/_output.py +8 -0
- pydantic_ai/_tool_manager.py +3 -0
- pydantic_ai/ag_ui.py +25 -14
- pydantic_ai/{agent.py → agent/__init__.py} +209 -1027
- pydantic_ai/agent/abstract.py +942 -0
- pydantic_ai/agent/wrapper.py +227 -0
- pydantic_ai/direct.py +9 -9
- pydantic_ai/durable_exec/__init__.py +0 -0
- pydantic_ai/durable_exec/temporal/__init__.py +83 -0
- pydantic_ai/durable_exec/temporal/_agent.py +699 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
- pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
- pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
- pydantic_ai/durable_exec/temporal/_model.py +168 -0
- pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
- pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
- pydantic_ai/ext/aci.py +10 -9
- pydantic_ai/ext/langchain.py +4 -2
- pydantic_ai/mcp.py +203 -75
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +65 -9
- pydantic_ai/models/anthropic.py +16 -7
- pydantic_ai/models/bedrock.py +8 -5
- pydantic_ai/models/cohere.py +1 -4
- pydantic_ai/models/fallback.py +4 -2
- pydantic_ai/models/function.py +9 -4
- pydantic_ai/models/gemini.py +15 -9
- pydantic_ai/models/google.py +18 -14
- pydantic_ai/models/groq.py +17 -14
- pydantic_ai/models/huggingface.py +18 -12
- pydantic_ai/models/instrumented.py +3 -1
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +12 -18
- pydantic_ai/models/openai.py +29 -26
- pydantic_ai/models/test.py +3 -0
- pydantic_ai/models/wrapper.py +6 -2
- pydantic_ai/profiles/openai.py +1 -1
- pydantic_ai/providers/google.py +7 -7
- pydantic_ai/result.py +21 -55
- pydantic_ai/run.py +357 -0
- pydantic_ai/tools.py +0 -1
- pydantic_ai/toolsets/__init__.py +2 -0
- pydantic_ai/toolsets/_dynamic.py +87 -0
- pydantic_ai/toolsets/abstract.py +23 -3
- pydantic_ai/toolsets/combined.py +19 -4
- pydantic_ai/toolsets/deferred.py +10 -2
- pydantic_ai/toolsets/function.py +23 -8
- pydantic_ai/toolsets/prefixed.py +4 -0
- pydantic_ai/toolsets/wrapper.py +14 -1
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +6 -4
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/RECORD +57 -44
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncIterator, Iterator, Sequence
|
|
4
|
+
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
|
|
5
|
+
from typing import Any, overload
|
|
6
|
+
|
|
7
|
+
from .. import (
|
|
8
|
+
_utils,
|
|
9
|
+
messages as _messages,
|
|
10
|
+
models,
|
|
11
|
+
usage as _usage,
|
|
12
|
+
)
|
|
13
|
+
from ..output import OutputDataT, OutputSpec
|
|
14
|
+
from ..run import AgentRun
|
|
15
|
+
from ..settings import ModelSettings
|
|
16
|
+
from ..tools import (
|
|
17
|
+
AgentDepsT,
|
|
18
|
+
Tool,
|
|
19
|
+
ToolFuncEither,
|
|
20
|
+
)
|
|
21
|
+
from ..toolsets import AbstractToolset
|
|
22
|
+
from .abstract import AbstractAgent, EventStreamHandler, RunOutputDataT
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class WrapperAgent(AbstractAgent[AgentDepsT, OutputDataT]):
    """Agent which wraps another agent.

    Does nothing on its own, used as a base class: every property and method below forwards to
    `self.wrapped`, so subclasses only need to override the parts whose behavior they change.
    """

    def __init__(self, wrapped: AbstractAgent[AgentDepsT, OutputDataT]):
        # The agent every call is delegated to.
        self.wrapped = wrapped

    @property
    def model(self) -> models.Model | models.KnownModelName | str | None:
        """The model configured on the wrapped agent."""
        return self.wrapped.model

    @property
    def name(self) -> str | None:
        """The name of the wrapped agent."""
        return self.wrapped.name

    @name.setter
    def name(self, value: str | None) -> None:
        # Write through, so the wrapper and the wrapped agent always report the same name.
        self.wrapped.name = value

    @property
    def deps_type(self) -> type:
        """The dependency type of the wrapped agent."""
        return self.wrapped.deps_type

    @property
    def output_type(self) -> OutputSpec[OutputDataT]:
        """The output spec of the wrapped agent."""
        return self.wrapped.output_type

    @property
    def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None:
        """The event stream handler of the wrapped agent, if any."""
        return self.wrapped.event_stream_handler

    @property
    def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
        """The toolsets of the wrapped agent."""
        return self.wrapped.toolsets

    async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
        """Enter the wrapped agent's async context and return this wrapper.

        The wrapper itself is returned (rather than whatever the wrapped agent's `__aenter__` returns) so that
        `async with wrapper_agent as agent:` binds `agent` to the wrapper, and calls made through it keep going
        through any behavior a subclass adds on top of the wrapped agent.
        """
        await self.wrapped.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        """Exit the wrapped agent's async context, forwarding the exception info and its return value."""
        return await self.wrapped.__aexit__(*args)

    @overload
    def iter(
        self,
        user_prompt: str | Sequence[_messages.UserContent] | None = None,
        *,
        output_type: None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
    ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ...

    @overload
    def iter(
        self,
        user_prompt: str | Sequence[_messages.UserContent] | None = None,
        *,
        output_type: OutputSpec[RunOutputDataT],
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
    ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...

    @asynccontextmanager
    async def iter(
        self,
        user_prompt: str | Sequence[_messages.UserContent] | None = None,
        *,
        output_type: OutputSpec[RunOutputDataT] | None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | str | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
    ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
        """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.

        This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an
        `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are
        executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the
        stream of events coming from the execution of tools.

        The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics,
        and the final result of the run once it has completed.

        For more details, see the documentation of `AgentRun`.

        Example:
        ```python
        from pydantic_ai import Agent

        agent = Agent('openai:gpt-4o')

        async def main():
            nodes = []
            async with agent.iter('What is the capital of France?') as agent_run:
                async for node in agent_run:
                    nodes.append(node)
            print(nodes)
            '''
            [
                UserPromptNode(
                    user_prompt='What is the capital of France?',
                    instructions=None,
                    instructions_functions=[],
                    system_prompts=(),
                    system_prompt_functions=[],
                    system_prompt_dynamic_functions={},
                ),
                ModelRequestNode(
                    request=ModelRequest(
                        parts=[
                            UserPromptPart(
                                content='What is the capital of France?',
                                timestamp=datetime.datetime(...),
                            )
                        ]
                    )
                ),
                CallToolsNode(
                    model_response=ModelResponse(
                        parts=[TextPart(content='The capital of France is Paris.')],
                        usage=Usage(
                            requests=1, request_tokens=56, response_tokens=7, total_tokens=63
                        ),
                        model_name='gpt-4o',
                        timestamp=datetime.datetime(...),
                    )
                ),
                End(data=FinalResult(output='The capital of France is Paris.')),
            ]
            '''
            print(agent_run.result.output)
            #> The capital of France is Paris.
        ```

        Args:
            user_prompt: User input to start/continue the conversation.
            output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
                output validators since output validators would expect an argument that matches the agent's output type.
            message_history: History of the conversation so far.
            model: Optional model to use for this run, required if `model` was not set when creating the agent.
            deps: Optional dependencies to use for this run.
            model_settings: Optional settings to use for this model's request.
            usage_limits: Optional limits on model request count or token usage.
            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
            toolsets: Optional additional toolsets for this run.

        Yields:
            The `AgentRun` produced by the wrapped agent's `iter`, which can be iterated over to drive the run
            and, once it has completed, exposes the result of the run.
        """
        async with self.wrapped.iter(
            user_prompt=user_prompt,
            output_type=output_type,
            message_history=message_history,
            model=model,
            deps=deps,
            model_settings=model_settings,
            usage_limits=usage_limits,
            usage=usage,
            infer_name=infer_name,
            toolsets=toolsets,
        ) as run:
            yield run

    @contextmanager
    def override(
        self,
        *,
        deps: AgentDepsT | _utils.Unset = _utils.UNSET,
        model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
    ) -> Iterator[None]:
        """Context manager to temporarily override agent dependencies, model, toolsets, or tools.

        This is particularly useful when testing.
        You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).

        The override is applied to the wrapped agent and undone when the context exits.

        Args:
            deps: The dependencies to use instead of the dependencies passed to the agent run.
            model: The model to use instead of the model passed to the agent run.
            toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
            tools: The tools to use instead of the tools registered with the agent.
        """
        with self.wrapped.override(deps=deps, model=model, toolsets=toolsets, tools=tools):
            yield
|
pydantic_ai/direct.py
CHANGED
|
@@ -56,8 +56,8 @@ async def model_request(
|
|
|
56
56
|
print(model_response)
|
|
57
57
|
'''
|
|
58
58
|
ModelResponse(
|
|
59
|
-
parts=[TextPart(content='Paris')],
|
|
60
|
-
usage=Usage(requests=1, request_tokens=56, response_tokens=
|
|
59
|
+
parts=[TextPart(content='The capital of France is Paris.')],
|
|
60
|
+
usage=Usage(requests=1, request_tokens=56, response_tokens=7, total_tokens=63),
|
|
61
61
|
model_name='claude-3-5-haiku-latest',
|
|
62
62
|
timestamp=datetime.datetime(...),
|
|
63
63
|
)
|
|
@@ -109,8 +109,8 @@ def model_request_sync(
|
|
|
109
109
|
print(model_response)
|
|
110
110
|
'''
|
|
111
111
|
ModelResponse(
|
|
112
|
-
parts=[TextPart(content='Paris')],
|
|
113
|
-
usage=Usage(requests=1, request_tokens=56, response_tokens=
|
|
112
|
+
parts=[TextPart(content='The capital of France is Paris.')],
|
|
113
|
+
usage=Usage(requests=1, request_tokens=56, response_tokens=7, total_tokens=63),
|
|
114
114
|
model_name='claude-3-5-haiku-latest',
|
|
115
115
|
timestamp=datetime.datetime(...),
|
|
116
116
|
)
|
|
@@ -167,6 +167,7 @@ def model_request_stream(
|
|
|
167
167
|
'''
|
|
168
168
|
[
|
|
169
169
|
PartStartEvent(index=0, part=TextPart(content='Albert Einstein was ')),
|
|
170
|
+
FinalResultEvent(tool_name=None, tool_call_id=None),
|
|
170
171
|
PartDeltaEvent(
|
|
171
172
|
index=0, delta=TextPartDelta(content_delta='a German-born theoretical ')
|
|
172
173
|
),
|
|
@@ -223,6 +224,7 @@ def model_request_stream_sync(
|
|
|
223
224
|
'''
|
|
224
225
|
[
|
|
225
226
|
PartStartEvent(index=0, part=TextPart(content='Albert Einstein was ')),
|
|
227
|
+
FinalResultEvent(tool_name=None, tool_call_id=None),
|
|
226
228
|
PartDeltaEvent(
|
|
227
229
|
index=0, delta=TextPartDelta(content_delta='a German-born theoretical ')
|
|
228
230
|
),
|
|
@@ -273,9 +275,7 @@ class StreamedResponseSync:
|
|
|
273
275
|
"""
|
|
274
276
|
|
|
275
277
|
_async_stream_cm: AbstractAsyncContextManager[StreamedResponse]
|
|
276
|
-
_queue: queue.Queue[messages.
|
|
277
|
-
default_factory=queue.Queue, init=False
|
|
278
|
-
)
|
|
278
|
+
_queue: queue.Queue[messages.AgentStreamEvent | Exception | None] = field(default_factory=queue.Queue, init=False)
|
|
279
279
|
_thread: threading.Thread | None = field(default=None, init=False)
|
|
280
280
|
_stream_response: StreamedResponse | None = field(default=None, init=False)
|
|
281
281
|
_exception: Exception | None = field(default=None, init=False)
|
|
@@ -295,8 +295,8 @@ class StreamedResponseSync:
|
|
|
295
295
|
) -> None:
|
|
296
296
|
self._cleanup()
|
|
297
297
|
|
|
298
|
-
def __iter__(self) -> Iterator[messages.
|
|
299
|
-
"""Stream the response as an iterable of [`
|
|
298
|
+
def __iter__(self) -> Iterator[messages.AgentStreamEvent]:
|
|
299
|
+
"""Stream the response as an iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s."""
|
|
300
300
|
self._check_context_manager_usage()
|
|
301
301
|
|
|
302
302
|
while True:
|
|
File without changes
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from dataclasses import replace
|
|
6
|
+
from typing import Any, Callable
|
|
7
|
+
|
|
8
|
+
from pydantic.errors import PydanticUserError
|
|
9
|
+
from temporalio.client import ClientConfig, Plugin as ClientPlugin
|
|
10
|
+
from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter
|
|
11
|
+
from temporalio.converter import DefaultPayloadConverter
|
|
12
|
+
from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig
|
|
13
|
+
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
|
|
14
|
+
|
|
15
|
+
from ...exceptions import UserError
|
|
16
|
+
from ._agent import TemporalAgent
|
|
17
|
+
from ._logfire import LogfirePlugin
|
|
18
|
+
from ._run_context import TemporalRunContext
|
|
19
|
+
from ._toolset import TemporalWrapperToolset
|
|
20
|
+
|
|
21
|
+
# Names re-exported as the public API of this package.
__all__ = [
    'TemporalAgent',
    'PydanticAIPlugin',
    'LogfirePlugin',
    'AgentPlugin',
    'TemporalRunContext',
    'TemporalWrapperToolset',
]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
    """Temporal client and worker plugin for Pydantic AI.

    On the client side it makes payloads go through `PydanticPayloadConverter`; on the worker side it lets
    Pydantic AI and its runtime dependencies through the workflow sandbox and registers Pydantic AI's user
    errors as workflow-failure exception types.
    """

    def configure_client(self, config: ClientConfig) -> ClientConfig:
        """Switch the client's payload converter to `PydanticPayloadConverter`.

        Only the payload converter class is swapped: any payload codec (e.g. encryption or compression) and
        failure converter already configured on the client's data converter are preserved instead of being
        silently dropped. A warning is emitted when a custom (neither default nor Pydantic) payload converter
        gets replaced.
        """
        data_converter = config.get('data_converter')
        if data_converter is None:
            config['data_converter'] = pydantic_data_converter
        elif data_converter.payload_converter_class is not PydanticPayloadConverter:
            if data_converter.payload_converter_class is not DefaultPayloadConverter:
                warnings.warn(  # pragma: no cover
                    'A non-default Temporal payload converter was used which has been replaced with the Pydantic payload converter.'
                )
            # NOTE: relies on temporalio's `DataConverter` being a dataclass; `replace` builds a copy with
            # only the payload converter class changed, keeping the codec and failure converter.
            config['data_converter'] = replace(data_converter, payload_converter_class=PydanticPayloadConverter)

        return super().configure_client(config)

    def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
        """Loosen the workflow sandbox for Pydantic AI and register its user errors as workflow failures."""
        runner = config.get('workflow_runner')  # pyright: ignore[reportUnknownMemberType]
        if isinstance(runner, SandboxedWorkflowRunner):  # pragma: no branch
            config['workflow_runner'] = replace(
                runner,
                restrictions=runner.restrictions.with_passthrough_modules(
                    'pydantic_ai',
                    'logfire',
                    'rich',
                    'httpx',
                    # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
                    'attrs',
                    # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
                    'numpy',
                    'pandas',
                ),
            )

        # Append to (rather than replace) any exception types the user already configured.
        config['workflow_failure_exception_types'] = [
            *config.get('workflow_failure_exception_types', []),  # pyright: ignore[reportUnknownMemberType]
            UserError,
            PydanticUserError,
        ]

        return super().configure_worker(config)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class AgentPlugin(WorkerPlugin):
    """Temporal worker plugin that registers the activities of one Pydantic AI agent.

    Args:
        agent: The `TemporalAgent` whose `temporal_activities` are added to the worker.
    """

    def __init__(self, agent: TemporalAgent[Any, Any]):
        self.agent = agent

    def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
        """Append the agent's activities after any activities already present in the worker config."""
        registered: Sequence[Callable[..., Any]] = config.get('activities', [])  # pyright: ignore[reportUnknownMemberType]
        merged = list(registered)
        merged.extend(self.agent.temporal_activities)
        # No de-duplication here: Temporal itself checks activities for name conflicts.
        config['activities'] = merged
        return super().configure_worker(config)
|