pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_a2a.py +1 -1
- pydantic_ai/_agent_graph.py +323 -156
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +7 -5
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +32 -28
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +82 -51
- pydantic_ai/agent/__init__.py +84 -17
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +70 -17
- pydantic_ai/durable_exec/temporal/_agent.py +93 -11
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +15 -27
- pydantic_ai/messages.py +156 -44
- pydantic_ai/models/__init__.py +20 -7
- pydantic_ai/models/anthropic.py +10 -17
- pydantic_ai/models/bedrock.py +55 -57
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +13 -14
- pydantic_ai/models/google.py +19 -5
- pydantic_ai/models/groq.py +127 -39
- pydantic_ai/models/huggingface.py +5 -5
- pydantic_ai/models/instrumented.py +49 -21
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +8 -8
- pydantic_ai/models/openai.py +37 -42
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/providers/google_vertex.py +2 -1
- pydantic_ai/providers/groq.py +21 -2
- pydantic_ai/providers/litellm.py +134 -0
- pydantic_ai/result.py +173 -52
- pydantic_ai/retries.py +52 -31
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +127 -23
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +58 -21
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +44 -8
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
- pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
- pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,14 +1,16 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import AsyncIterator, Iterator, Sequence
|
|
3
|
+
from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterator, Sequence
|
|
4
4
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
|
|
5
5
|
from contextvars import ContextVar
|
|
6
|
+
from dataclasses import dataclass
|
|
6
7
|
from datetime import timedelta
|
|
7
|
-
from typing import Any, Callable, Literal, overload
|
|
8
|
+
from typing import Any, Literal, overload
|
|
8
9
|
|
|
10
|
+
from pydantic import ConfigDict, with_config
|
|
9
11
|
from pydantic.errors import PydanticUserError
|
|
10
12
|
from pydantic_core import PydanticSerializationError
|
|
11
|
-
from temporalio import workflow
|
|
13
|
+
from temporalio import activity, workflow
|
|
12
14
|
from temporalio.common import RetryPolicy
|
|
13
15
|
from temporalio.workflow import ActivityConfig
|
|
14
16
|
from typing_extensions import Never
|
|
@@ -21,22 +23,31 @@ from pydantic_ai import (
|
|
|
21
23
|
)
|
|
22
24
|
from pydantic_ai._run_context import AgentDepsT
|
|
23
25
|
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
|
|
24
|
-
from pydantic_ai.durable_exec.temporal._run_context import TemporalRunContext
|
|
25
26
|
from pydantic_ai.exceptions import UserError
|
|
26
27
|
from pydantic_ai.models import Model
|
|
27
28
|
from pydantic_ai.output import OutputDataT, OutputSpec
|
|
28
29
|
from pydantic_ai.result import StreamedRunResult
|
|
29
30
|
from pydantic_ai.settings import ModelSettings
|
|
30
31
|
from pydantic_ai.tools import (
|
|
32
|
+
DeferredToolResults,
|
|
33
|
+
RunContext,
|
|
31
34
|
Tool,
|
|
32
35
|
ToolFuncEither,
|
|
33
36
|
)
|
|
34
37
|
from pydantic_ai.toolsets import AbstractToolset
|
|
35
38
|
|
|
36
39
|
from ._model import TemporalModel
|
|
40
|
+
from ._run_context import TemporalRunContext
|
|
37
41
|
from ._toolset import TemporalWrapperToolset, temporalize_toolset
|
|
38
42
|
|
|
39
43
|
|
|
44
|
+
@dataclass
|
|
45
|
+
@with_config(ConfigDict(arbitrary_types_allowed=True))
|
|
46
|
+
class _EventStreamHandlerParams:
|
|
47
|
+
event: _messages.AgentStreamEvent
|
|
48
|
+
serialized_run_context: Any
|
|
49
|
+
|
|
50
|
+
|
|
40
51
|
class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
41
52
|
def __init__(
|
|
42
53
|
self,
|
|
@@ -85,6 +96,10 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
85
96
|
"""
|
|
86
97
|
super().__init__(wrapped)
|
|
87
98
|
|
|
99
|
+
self._name = name
|
|
100
|
+
self._event_stream_handler = event_stream_handler
|
|
101
|
+
self.run_context_type = run_context_type
|
|
102
|
+
|
|
88
103
|
# start_to_close_timeout is required
|
|
89
104
|
activity_config = activity_config or ActivityConfig(start_to_close_timeout=timedelta(seconds=60))
|
|
90
105
|
|
|
@@ -96,13 +111,13 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
96
111
|
PydanticUserError.__name__,
|
|
97
112
|
]
|
|
98
113
|
activity_config['retry_policy'] = retry_policy
|
|
114
|
+
self.activity_config = activity_config
|
|
99
115
|
|
|
100
116
|
model_activity_config = model_activity_config or {}
|
|
101
117
|
toolset_activity_config = toolset_activity_config or {}
|
|
102
118
|
tool_activity_config = tool_activity_config or {}
|
|
103
119
|
|
|
104
|
-
self._name = name or wrapped.name
|
|
105
|
-
if self._name is None:
|
|
120
|
+
if self.name is None:
|
|
106
121
|
raise UserError(
|
|
107
122
|
"An agent needs to have a unique `name` in order to be used with Temporal. The name will be used to identify the agent's activities within the workflow."
|
|
108
123
|
)
|
|
@@ -115,13 +130,33 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
115
130
|
'An agent needs to have a `model` in order to be used with Temporal, it cannot be set at agent run time.'
|
|
116
131
|
)
|
|
117
132
|
|
|
133
|
+
async def event_stream_handler_activity(params: _EventStreamHandlerParams, deps: AgentDepsT) -> None:
|
|
134
|
+
# We can never get here without an `event_stream_handler`, as `TemporalAgent.run_stream` and `TemporalAgent.iter` raise an error saying to use `TemporalAgent.run` instead,
|
|
135
|
+
# and that only ends up calling `event_stream_handler` if it is set.
|
|
136
|
+
assert self.event_stream_handler is not None
|
|
137
|
+
|
|
138
|
+
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
|
|
139
|
+
|
|
140
|
+
async def streamed_response():
|
|
141
|
+
yield params.event
|
|
142
|
+
|
|
143
|
+
await self.event_stream_handler(run_context, streamed_response())
|
|
144
|
+
|
|
145
|
+
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
|
|
146
|
+
event_stream_handler_activity.__annotations__['deps'] = self.deps_type
|
|
147
|
+
|
|
148
|
+
self.event_stream_handler_activity = activity.defn(name=f'{activity_name_prefix}__event_stream_handler')(
|
|
149
|
+
event_stream_handler_activity
|
|
150
|
+
)
|
|
151
|
+
activities.append(self.event_stream_handler_activity)
|
|
152
|
+
|
|
118
153
|
temporal_model = TemporalModel(
|
|
119
154
|
wrapped.model,
|
|
120
155
|
activity_name_prefix=activity_name_prefix,
|
|
121
156
|
activity_config=activity_config | model_activity_config,
|
|
122
157
|
deps_type=self.deps_type,
|
|
123
|
-
run_context_type=run_context_type,
|
|
124
|
-
event_stream_handler=
|
|
158
|
+
run_context_type=self.run_context_type,
|
|
159
|
+
event_stream_handler=self.event_stream_handler,
|
|
125
160
|
)
|
|
126
161
|
activities.extend(temporal_model.temporal_activities)
|
|
127
162
|
|
|
@@ -138,7 +173,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
138
173
|
activity_config | toolset_activity_config.get(id, {}),
|
|
139
174
|
tool_activity_config.get(id, {}),
|
|
140
175
|
self.deps_type,
|
|
141
|
-
run_context_type,
|
|
176
|
+
self.run_context_type,
|
|
142
177
|
)
|
|
143
178
|
if isinstance(toolset, TemporalWrapperToolset):
|
|
144
179
|
activities.extend(toolset.temporal_activities)
|
|
@@ -154,7 +189,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
154
189
|
|
|
155
190
|
@property
|
|
156
191
|
def name(self) -> str | None:
|
|
157
|
-
return self._name
|
|
192
|
+
return self._name or super().name
|
|
158
193
|
|
|
159
194
|
@name.setter
|
|
160
195
|
def name(self, value: str | None) -> None: # pragma: no cover
|
|
@@ -166,6 +201,33 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
166
201
|
def model(self) -> Model:
|
|
167
202
|
return self._model
|
|
168
203
|
|
|
204
|
+
@property
|
|
205
|
+
def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None:
|
|
206
|
+
handler = self._event_stream_handler or super().event_stream_handler
|
|
207
|
+
if handler is None:
|
|
208
|
+
return None
|
|
209
|
+
elif workflow.in_workflow():
|
|
210
|
+
return self._call_event_stream_handler_activity
|
|
211
|
+
else:
|
|
212
|
+
return handler
|
|
213
|
+
|
|
214
|
+
async def _call_event_stream_handler_activity(
|
|
215
|
+
self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[_messages.AgentStreamEvent]
|
|
216
|
+
) -> None:
|
|
217
|
+
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
|
|
218
|
+
async for event in stream:
|
|
219
|
+
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
|
|
220
|
+
activity=self.event_stream_handler_activity,
|
|
221
|
+
args=[
|
|
222
|
+
_EventStreamHandlerParams(
|
|
223
|
+
event=event,
|
|
224
|
+
serialized_run_context=serialized_run_context,
|
|
225
|
+
),
|
|
226
|
+
ctx.deps,
|
|
227
|
+
],
|
|
228
|
+
**self.activity_config,
|
|
229
|
+
)
|
|
230
|
+
|
|
169
231
|
@property
|
|
170
232
|
def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
|
|
171
233
|
with self._temporal_overrides():
|
|
@@ -196,6 +258,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
196
258
|
*,
|
|
197
259
|
output_type: None = None,
|
|
198
260
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
261
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
199
262
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
200
263
|
deps: AgentDepsT = None,
|
|
201
264
|
model_settings: ModelSettings | None = None,
|
|
@@ -213,6 +276,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
213
276
|
*,
|
|
214
277
|
output_type: OutputSpec[RunOutputDataT],
|
|
215
278
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
279
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
216
280
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
217
281
|
deps: AgentDepsT = None,
|
|
218
282
|
model_settings: ModelSettings | None = None,
|
|
@@ -229,6 +293,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
229
293
|
*,
|
|
230
294
|
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
231
295
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
296
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
232
297
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
233
298
|
deps: AgentDepsT = None,
|
|
234
299
|
model_settings: ModelSettings | None = None,
|
|
@@ -261,6 +326,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
261
326
|
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
262
327
|
output validators since output validators would expect an argument that matches the agent's output type.
|
|
263
328
|
message_history: History of the conversation so far.
|
|
329
|
+
deferred_tool_results: Optional results for deferred tool calls in the message history.
|
|
264
330
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
265
331
|
deps: Optional dependencies to use for this run.
|
|
266
332
|
model_settings: Optional settings to use for this model's request.
|
|
@@ -283,6 +349,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
283
349
|
user_prompt,
|
|
284
350
|
output_type=output_type,
|
|
285
351
|
message_history=message_history,
|
|
352
|
+
deferred_tool_results=deferred_tool_results,
|
|
286
353
|
model=model,
|
|
287
354
|
deps=deps,
|
|
288
355
|
model_settings=model_settings,
|
|
@@ -290,7 +357,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
290
357
|
usage=usage,
|
|
291
358
|
infer_name=infer_name,
|
|
292
359
|
toolsets=toolsets,
|
|
293
|
-
event_stream_handler=event_stream_handler,
|
|
360
|
+
event_stream_handler=event_stream_handler or self.event_stream_handler,
|
|
294
361
|
**_deprecated_kwargs,
|
|
295
362
|
)
|
|
296
363
|
|
|
@@ -301,6 +368,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
301
368
|
*,
|
|
302
369
|
output_type: None = None,
|
|
303
370
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
371
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
304
372
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
305
373
|
deps: AgentDepsT = None,
|
|
306
374
|
model_settings: ModelSettings | None = None,
|
|
@@ -318,6 +386,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
318
386
|
*,
|
|
319
387
|
output_type: OutputSpec[RunOutputDataT],
|
|
320
388
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
389
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
321
390
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
322
391
|
deps: AgentDepsT = None,
|
|
323
392
|
model_settings: ModelSettings | None = None,
|
|
@@ -334,6 +403,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
334
403
|
*,
|
|
335
404
|
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
336
405
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
406
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
337
407
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
338
408
|
deps: AgentDepsT = None,
|
|
339
409
|
model_settings: ModelSettings | None = None,
|
|
@@ -365,6 +435,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
365
435
|
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
366
436
|
output validators since output validators would expect an argument that matches the agent's output type.
|
|
367
437
|
message_history: History of the conversation so far.
|
|
438
|
+
deferred_tool_results: Optional results for deferred tool calls in the message history.
|
|
368
439
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
369
440
|
deps: Optional dependencies to use for this run.
|
|
370
441
|
model_settings: Optional settings to use for this model's request.
|
|
@@ -386,6 +457,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
386
457
|
user_prompt,
|
|
387
458
|
output_type=output_type,
|
|
388
459
|
message_history=message_history,
|
|
460
|
+
deferred_tool_results=deferred_tool_results,
|
|
389
461
|
model=model,
|
|
390
462
|
deps=deps,
|
|
391
463
|
model_settings=model_settings,
|
|
@@ -404,6 +476,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
404
476
|
*,
|
|
405
477
|
output_type: None = None,
|
|
406
478
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
479
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
407
480
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
408
481
|
deps: AgentDepsT = None,
|
|
409
482
|
model_settings: ModelSettings | None = None,
|
|
@@ -421,6 +494,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
421
494
|
*,
|
|
422
495
|
output_type: OutputSpec[RunOutputDataT],
|
|
423
496
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
497
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
424
498
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
425
499
|
deps: AgentDepsT = None,
|
|
426
500
|
model_settings: ModelSettings | None = None,
|
|
@@ -438,6 +512,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
438
512
|
*,
|
|
439
513
|
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
440
514
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
515
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
441
516
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
442
517
|
deps: AgentDepsT = None,
|
|
443
518
|
model_settings: ModelSettings | None = None,
|
|
@@ -467,6 +542,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
467
542
|
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
468
543
|
output validators since output validators would expect an argument that matches the agent's output type.
|
|
469
544
|
message_history: History of the conversation so far.
|
|
545
|
+
deferred_tool_results: Optional results for deferred tool calls in the message history.
|
|
470
546
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
471
547
|
deps: Optional dependencies to use for this run.
|
|
472
548
|
model_settings: Optional settings to use for this model's request.
|
|
@@ -490,6 +566,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
490
566
|
user_prompt,
|
|
491
567
|
output_type=output_type,
|
|
492
568
|
message_history=message_history,
|
|
569
|
+
deferred_tool_results=deferred_tool_results,
|
|
493
570
|
model=model,
|
|
494
571
|
deps=deps,
|
|
495
572
|
model_settings=model_settings,
|
|
@@ -509,6 +586,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
509
586
|
*,
|
|
510
587
|
output_type: None = None,
|
|
511
588
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
589
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
512
590
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
513
591
|
deps: AgentDepsT = None,
|
|
514
592
|
model_settings: ModelSettings | None = None,
|
|
@@ -526,6 +604,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
526
604
|
*,
|
|
527
605
|
output_type: OutputSpec[RunOutputDataT],
|
|
528
606
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
607
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
529
608
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
530
609
|
deps: AgentDepsT = None,
|
|
531
610
|
model_settings: ModelSettings | None = None,
|
|
@@ -543,6 +622,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
543
622
|
*,
|
|
544
623
|
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
545
624
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
625
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
546
626
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
547
627
|
deps: AgentDepsT = None,
|
|
548
628
|
model_settings: ModelSettings | None = None,
|
|
@@ -616,6 +696,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
616
696
|
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
617
697
|
output validators since output validators would expect an argument that matches the agent's output type.
|
|
618
698
|
message_history: History of the conversation so far.
|
|
699
|
+
deferred_tool_results: Optional results for deferred tool calls in the message history.
|
|
619
700
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
620
701
|
deps: Optional dependencies to use for this run.
|
|
621
702
|
model_settings: Optional settings to use for this model's request.
|
|
@@ -648,6 +729,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
648
729
|
user_prompt=user_prompt,
|
|
649
730
|
output_type=output_type,
|
|
650
731
|
message_history=message_history,
|
|
732
|
+
deferred_tool_results=deferred_tool_results,
|
|
651
733
|
model=model,
|
|
652
734
|
deps=deps,
|
|
653
735
|
model_settings=model_settings,
|
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Any, Callable
|
|
5
|
+
from typing import Annotated, Any, Literal, assert_never
|
|
5
6
|
|
|
6
|
-
from pydantic import ConfigDict, with_config
|
|
7
|
+
from pydantic import ConfigDict, Discriminator, with_config
|
|
7
8
|
from temporalio import activity, workflow
|
|
8
9
|
from temporalio.workflow import ActivityConfig
|
|
9
10
|
|
|
10
|
-
from pydantic_ai.exceptions import UserError
|
|
11
|
+
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
|
|
11
12
|
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
12
13
|
from pydantic_ai.toolsets import FunctionToolset, ToolsetTool
|
|
13
14
|
from pydantic_ai.toolsets.function import FunctionToolsetTool
|
|
@@ -24,6 +25,34 @@ class _CallToolParams:
|
|
|
24
25
|
serialized_run_context: Any
|
|
25
26
|
|
|
26
27
|
|
|
28
|
+
@dataclass
|
|
29
|
+
class _ApprovalRequired:
|
|
30
|
+
kind: Literal['approval_required'] = 'approval_required'
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
|
|
34
|
+
class _CallDeferred:
|
|
35
|
+
kind: Literal['call_deferred'] = 'call_deferred'
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
|
|
39
|
+
class _ModelRetry:
|
|
40
|
+
message: str
|
|
41
|
+
kind: Literal['model_retry'] = 'model_retry'
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class _ToolReturn:
|
|
46
|
+
result: Any
|
|
47
|
+
kind: Literal['tool_return'] = 'tool_return'
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
_CallToolResult = Annotated[
|
|
51
|
+
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
|
|
52
|
+
Discriminator('kind'),
|
|
53
|
+
]
|
|
54
|
+
|
|
55
|
+
|
|
27
56
|
class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
28
57
|
def __init__(
|
|
29
58
|
self,
|
|
@@ -40,7 +69,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
40
69
|
self.tool_activity_config = tool_activity_config
|
|
41
70
|
self.run_context_type = run_context_type
|
|
42
71
|
|
|
43
|
-
async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) ->
|
|
72
|
+
async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
|
|
44
73
|
name = params.name
|
|
45
74
|
ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
|
|
46
75
|
try:
|
|
@@ -54,7 +83,15 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
54
83
|
# The tool args will already have been validated into their proper types in the `ToolManager`,
|
|
55
84
|
# but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
|
|
56
85
|
args_dict = tool.args_validator.validate_python(params.tool_args)
|
|
57
|
-
|
|
86
|
+
try:
|
|
87
|
+
result = await self.wrapped.call_tool(name, args_dict, ctx, tool)
|
|
88
|
+
return _ToolReturn(result=result)
|
|
89
|
+
except ApprovalRequired:
|
|
90
|
+
return _ApprovalRequired()
|
|
91
|
+
except CallDeferred:
|
|
92
|
+
return _CallDeferred()
|
|
93
|
+
except ModelRetry as e:
|
|
94
|
+
return _ModelRetry(message=e.message)
|
|
58
95
|
|
|
59
96
|
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
|
|
60
97
|
call_tool_activity.__annotations__['deps'] = deps_type
|
|
@@ -85,7 +122,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
85
122
|
|
|
86
123
|
tool_activity_config = self.activity_config | tool_activity_config
|
|
87
124
|
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
|
|
88
|
-
|
|
125
|
+
result = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
|
|
89
126
|
activity=self.call_tool_activity,
|
|
90
127
|
args=[
|
|
91
128
|
_CallToolParams(
|
|
@@ -97,3 +134,13 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
97
134
|
],
|
|
98
135
|
**tool_activity_config,
|
|
99
136
|
)
|
|
137
|
+
if isinstance(result, _ApprovalRequired):
|
|
138
|
+
raise ApprovalRequired()
|
|
139
|
+
elif isinstance(result, _CallDeferred):
|
|
140
|
+
raise CallDeferred()
|
|
141
|
+
elif isinstance(result, _ModelRetry):
|
|
142
|
+
raise ModelRetry(result.message)
|
|
143
|
+
elif isinstance(result, _ToolReturn):
|
|
144
|
+
return result.result
|
|
145
|
+
else:
|
|
146
|
+
assert_never(result)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import Callable
|
|
3
|
+
from collections.abc import Callable
|
|
4
4
|
|
|
5
5
|
from logfire import Logfire
|
|
6
6
|
from opentelemetry.trace import get_tracer
|
|
@@ -25,10 +25,13 @@ class LogfirePlugin(ClientPlugin):
|
|
|
25
25
|
self.setup_logfire = setup_logfire
|
|
26
26
|
self.metrics = metrics
|
|
27
27
|
|
|
28
|
+
def init_client_plugin(self, next: ClientPlugin) -> None:
|
|
29
|
+
self.next_client_plugin = next
|
|
30
|
+
|
|
28
31
|
def configure_client(self, config: ClientConfig) -> ClientConfig:
|
|
29
32
|
interceptors = config.get('interceptors', [])
|
|
30
33
|
config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporalio'))]
|
|
31
|
-
return
|
|
34
|
+
return self.next_client_plugin.configure_client(config)
|
|
32
35
|
|
|
33
36
|
async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
|
|
34
37
|
logfire = self.setup_logfire()
|
|
@@ -45,4 +48,4 @@ class LogfirePlugin(ClientPlugin):
|
|
|
45
48
|
telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url, headers=headers))
|
|
46
49
|
)
|
|
47
50
|
|
|
48
|
-
return await
|
|
51
|
+
return await self.next_client_plugin.connect_service_client(config)
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Any, Callable, Literal
|
|
5
|
+
from typing import Any, Literal
|
|
5
6
|
|
|
6
7
|
from pydantic import ConfigDict, with_config
|
|
7
8
|
from temporalio import activity, workflow
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import AsyncIterator
|
|
3
|
+
from collections.abc import AsyncIterator, Callable
|
|
4
4
|
from contextlib import asynccontextmanager
|
|
5
5
|
from dataclasses import dataclass
|
|
6
6
|
from datetime import datetime
|
|
7
|
-
from typing import Any, Callable
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
from pydantic import ConfigDict, with_config
|
|
10
10
|
from temporalio import activity, workflow
|
|
@@ -9,7 +9,7 @@ from pydantic_ai.tools import AgentDepsT, RunContext
|
|
|
9
9
|
class TemporalRunContext(RunContext[AgentDepsT]):
|
|
10
10
|
"""The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
|
|
11
11
|
|
|
12
|
-
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `retry` and `run_step` attributes will be available.
|
|
12
|
+
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry` and `run_step` attributes will be available.
|
|
13
13
|
To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
|
|
14
14
|
"""
|
|
15
15
|
|
|
@@ -40,6 +40,7 @@ class TemporalRunContext(RunContext[AgentDepsT]):
|
|
|
40
40
|
'retries': ctx.retries,
|
|
41
41
|
'tool_call_id': ctx.tool_call_id,
|
|
42
42
|
'tool_name': ctx.tool_name,
|
|
43
|
+
'tool_call_approved': ctx.tool_call_approved,
|
|
43
44
|
'retry': ctx.retry,
|
|
44
45
|
'run_step': ctx.run_step,
|
|
45
46
|
}
|
pydantic_ai/exceptions.py
CHANGED
|
@@ -2,7 +2,9 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
4
|
import sys
|
|
5
|
-
from typing import TYPE_CHECKING
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
from pydantic_core import core_schema
|
|
6
8
|
|
|
7
9
|
if sys.version_info < (3, 11):
|
|
8
10
|
from exceptiongroup import ExceptionGroup as ExceptionGroup # pragma: lax no cover
|
|
@@ -14,6 +16,8 @@ if TYPE_CHECKING:
|
|
|
14
16
|
|
|
15
17
|
__all__ = (
|
|
16
18
|
'ModelRetry',
|
|
19
|
+
'CallDeferred',
|
|
20
|
+
'ApprovalRequired',
|
|
17
21
|
'UserError',
|
|
18
22
|
'AgentRunError',
|
|
19
23
|
'UnexpectedModelBehavior',
|
|
@@ -24,7 +28,7 @@ __all__ = (
|
|
|
24
28
|
|
|
25
29
|
|
|
26
30
|
class ModelRetry(Exception):
|
|
27
|
-
"""Exception raised when a tool function should be retried.
|
|
31
|
+
"""Exception to raise when a tool function should be retried.
|
|
28
32
|
|
|
29
33
|
The agent will return the message to the model and ask it to try calling the function/tool again.
|
|
30
34
|
"""
|
|
@@ -36,6 +40,45 @@ class ModelRetry(Exception):
|
|
|
36
40
|
self.message = message
|
|
37
41
|
super().__init__(message)
|
|
38
42
|
|
|
43
|
+
def __eq__(self, other: Any) -> bool:
|
|
44
|
+
return isinstance(other, self.__class__) and other.message == self.message
|
|
45
|
+
|
|
46
|
+
@classmethod
|
|
47
|
+
def __get_pydantic_core_schema__(cls, _: Any, __: Any) -> core_schema.CoreSchema:
|
|
48
|
+
"""Pydantic core schema to allow `ModelRetry` to be (de)serialized."""
|
|
49
|
+
schema = core_schema.typed_dict_schema(
|
|
50
|
+
{
|
|
51
|
+
'message': core_schema.typed_dict_field(core_schema.str_schema()),
|
|
52
|
+
'kind': core_schema.typed_dict_field(core_schema.literal_schema(['model-retry'])),
|
|
53
|
+
}
|
|
54
|
+
)
|
|
55
|
+
return core_schema.no_info_after_validator_function(
|
|
56
|
+
lambda dct: ModelRetry(dct['message']),
|
|
57
|
+
schema,
|
|
58
|
+
serialization=core_schema.plain_serializer_function_ser_schema(
|
|
59
|
+
lambda x: {'message': x.message, 'kind': 'model-retry'},
|
|
60
|
+
return_schema=schema,
|
|
61
|
+
),
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class CallDeferred(Exception):
|
|
66
|
+
"""Exception to raise when a tool call should be deferred.
|
|
67
|
+
|
|
68
|
+
See [tools docs](../deferred-tools.md#deferred-tools) for more information.
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
pass
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class ApprovalRequired(Exception):
|
|
75
|
+
"""Exception to raise when a tool call requires human-in-the-loop approval.
|
|
76
|
+
|
|
77
|
+
See [tools docs](../deferred-tools.md#human-in-the-loop-tool-approval) for more information.
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
pass
|
|
81
|
+
|
|
39
82
|
|
|
40
83
|
class UserError(RuntimeError):
|
|
41
84
|
"""Error caused by a usage mistake by the application developer — You!"""
|
pydantic_ai/format_prompt.py
CHANGED
|
@@ -72,9 +72,9 @@ class _ToXml:
|
|
|
72
72
|
element.text = self.none_str
|
|
73
73
|
elif isinstance(value, str):
|
|
74
74
|
element.text = value
|
|
75
|
-
elif isinstance(value, (bytes, bytearray)):
|
|
75
|
+
elif isinstance(value, bytes | bytearray):
|
|
76
76
|
element.text = value.decode(errors='ignore')
|
|
77
|
-
elif isinstance(value, (bool, int, float)):
|
|
77
|
+
elif isinstance(value, bool | int | float):
|
|
78
78
|
element.text = str(value)
|
|
79
79
|
elif isinstance(value, date):
|
|
80
80
|
element.text = value.isoformat()
|