pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_agent_graph.py +310 -140
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +4 -4
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +3 -22
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +7 -8
- pydantic_ai/agent/__init__.py +84 -17
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +70 -17
- pydantic_ai/durable_exec/temporal/_agent.py +23 -2
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +2 -2
- pydantic_ai/messages.py +81 -28
- pydantic_ai/models/__init__.py +19 -7
- pydantic_ai/models/anthropic.py +6 -6
- pydantic_ai/models/bedrock.py +63 -57
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +10 -13
- pydantic_ai/models/google.py +4 -4
- pydantic_ai/models/groq.py +5 -5
- pydantic_ai/models/huggingface.py +5 -5
- pydantic_ai/models/instrumented.py +44 -21
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +8 -8
- pydantic_ai/models/openai.py +20 -29
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/result.py +173 -52
- pydantic_ai/retries.py +10 -29
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +126 -22
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +13 -4
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +7 -5
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/METADATA +6 -7
- pydantic_ai_slim-1.0.0b1.dist-info/RECORD +120 -0
- pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,15 +1,24 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import warnings
|
|
4
|
-
from collections.abc import Sequence
|
|
4
|
+
from collections.abc import AsyncIterator, Callable, Sequence
|
|
5
|
+
from contextlib import AbstractAsyncContextManager
|
|
5
6
|
from dataclasses import replace
|
|
6
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
from pydantic.errors import PydanticUserError
|
|
9
|
-
from temporalio.client import ClientConfig, Plugin as ClientPlugin
|
|
10
|
+
from temporalio.client import ClientConfig, Plugin as ClientPlugin, WorkflowHistory
|
|
10
11
|
from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter
|
|
11
|
-
from temporalio.converter import DefaultPayloadConverter
|
|
12
|
-
from temporalio.
|
|
12
|
+
from temporalio.converter import DataConverter, DefaultPayloadConverter
|
|
13
|
+
from temporalio.service import ConnectConfig, ServiceClient
|
|
14
|
+
from temporalio.worker import (
|
|
15
|
+
Plugin as WorkerPlugin,
|
|
16
|
+
Replayer,
|
|
17
|
+
ReplayerConfig,
|
|
18
|
+
Worker,
|
|
19
|
+
WorkerConfig,
|
|
20
|
+
WorkflowReplayResult,
|
|
21
|
+
)
|
|
13
22
|
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
|
|
14
23
|
|
|
15
24
|
from ...exceptions import UserError
|
|
@@ -31,17 +40,15 @@ __all__ = [
|
|
|
31
40
|
class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
|
|
32
41
|
"""Temporal client and worker plugin for Pydantic AI."""
|
|
33
42
|
|
|
34
|
-
def
|
|
35
|
-
|
|
36
|
-
DefaultPayloadConverter,
|
|
37
|
-
PydanticPayloadConverter,
|
|
38
|
-
):
|
|
39
|
-
warnings.warn( # pragma: no cover
|
|
40
|
-
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
|
|
41
|
-
)
|
|
43
|
+
def init_client_plugin(self, next: ClientPlugin) -> None:
|
|
44
|
+
self.next_client_plugin = next
|
|
42
45
|
|
|
43
|
-
|
|
44
|
-
|
|
46
|
+
def init_worker_plugin(self, next: WorkerPlugin) -> None:
|
|
47
|
+
self.next_worker_plugin = next
|
|
48
|
+
|
|
49
|
+
def configure_client(self, config: ClientConfig) -> ClientConfig:
|
|
50
|
+
config['data_converter'] = self._get_new_data_converter(config.get('data_converter'))
|
|
51
|
+
return self.next_client_plugin.configure_client(config)
|
|
45
52
|
|
|
46
53
|
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
|
|
47
54
|
runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType]
|
|
@@ -50,6 +57,8 @@ class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
|
|
|
50
57
|
runner,
|
|
51
58
|
restrictions=runner.restrictions.with_passthrough_modules(
|
|
52
59
|
'pydantic_ai',
|
|
60
|
+
'pydantic',
|
|
61
|
+
'pydantic_core',
|
|
53
62
|
'logfire',
|
|
54
63
|
'rich',
|
|
55
64
|
'httpx',
|
|
@@ -67,7 +76,35 @@ class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
|
|
|
67
76
|
PydanticUserError,
|
|
68
77
|
]
|
|
69
78
|
|
|
70
|
-
return
|
|
79
|
+
return self.next_worker_plugin.configure_worker(config)
|
|
80
|
+
|
|
81
|
+
async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
|
|
82
|
+
return await self.next_client_plugin.connect_service_client(config)
|
|
83
|
+
|
|
84
|
+
async def run_worker(self, worker: Worker) -> None:
|
|
85
|
+
await self.next_worker_plugin.run_worker(worker)
|
|
86
|
+
|
|
87
|
+
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
|
|
88
|
+
config['data_converter'] = self._get_new_data_converter(config.get('data_converter')) # pyright: ignore[reportUnknownMemberType]
|
|
89
|
+
return self.next_worker_plugin.configure_replayer(config)
|
|
90
|
+
|
|
91
|
+
def run_replayer(
|
|
92
|
+
self,
|
|
93
|
+
replayer: Replayer,
|
|
94
|
+
histories: AsyncIterator[WorkflowHistory],
|
|
95
|
+
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
|
|
96
|
+
return self.next_worker_plugin.run_replayer(replayer, histories)
|
|
97
|
+
|
|
98
|
+
def _get_new_data_converter(self, converter: DataConverter | None) -> DataConverter:
|
|
99
|
+
if converter and converter.payload_converter_class not in (
|
|
100
|
+
DefaultPayloadConverter,
|
|
101
|
+
PydanticPayloadConverter,
|
|
102
|
+
):
|
|
103
|
+
warnings.warn( # pragma: no cover
|
|
104
|
+
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
return pydantic_data_converter
|
|
71
108
|
|
|
72
109
|
|
|
73
110
|
class AgentPlugin(WorkerPlugin):
|
|
@@ -76,8 +113,24 @@ class AgentPlugin(WorkerPlugin):
|
|
|
76
113
|
def __init__(self, agent: TemporalAgent[Any, Any]):
|
|
77
114
|
self.agent = agent
|
|
78
115
|
|
|
116
|
+
def init_worker_plugin(self, next: WorkerPlugin) -> None:
|
|
117
|
+
self.next_worker_plugin = next
|
|
118
|
+
|
|
79
119
|
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
|
|
80
120
|
activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType]
|
|
81
121
|
# Activities are checked for name conflicts by Temporal.
|
|
82
122
|
config['activities'] = [*activities, *self.agent.temporal_activities]
|
|
83
|
-
return
|
|
123
|
+
return self.next_worker_plugin.configure_worker(config)
|
|
124
|
+
|
|
125
|
+
async def run_worker(self, worker: Worker) -> None:
|
|
126
|
+
await self.next_worker_plugin.run_worker(worker)
|
|
127
|
+
|
|
128
|
+
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
|
|
129
|
+
return self.next_worker_plugin.configure_replayer(config)
|
|
130
|
+
|
|
131
|
+
def run_replayer(
|
|
132
|
+
self,
|
|
133
|
+
replayer: Replayer,
|
|
134
|
+
histories: AsyncIterator[WorkflowHistory],
|
|
135
|
+
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
|
|
136
|
+
return self.next_worker_plugin.run_replayer(replayer, histories)
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import AsyncIterator, Iterator, Sequence
|
|
3
|
+
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
|
4
4
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
|
|
5
5
|
from contextvars import ContextVar
|
|
6
6
|
from datetime import timedelta
|
|
7
|
-
from typing import Any,
|
|
7
|
+
from typing import Any, Literal, overload
|
|
8
8
|
|
|
9
9
|
from pydantic.errors import PydanticUserError
|
|
10
10
|
from pydantic_core import PydanticSerializationError
|
|
@@ -28,6 +28,7 @@ from pydantic_ai.output import OutputDataT, OutputSpec
|
|
|
28
28
|
from pydantic_ai.result import StreamedRunResult
|
|
29
29
|
from pydantic_ai.settings import ModelSettings
|
|
30
30
|
from pydantic_ai.tools import (
|
|
31
|
+
DeferredToolResults,
|
|
31
32
|
Tool,
|
|
32
33
|
ToolFuncEither,
|
|
33
34
|
)
|
|
@@ -196,6 +197,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
196
197
|
*,
|
|
197
198
|
output_type: None = None,
|
|
198
199
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
200
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
199
201
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
200
202
|
deps: AgentDepsT = None,
|
|
201
203
|
model_settings: ModelSettings | None = None,
|
|
@@ -213,6 +215,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
213
215
|
*,
|
|
214
216
|
output_type: OutputSpec[RunOutputDataT],
|
|
215
217
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
218
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
216
219
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
217
220
|
deps: AgentDepsT = None,
|
|
218
221
|
model_settings: ModelSettings | None = None,
|
|
@@ -229,6 +232,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
229
232
|
*,
|
|
230
233
|
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
231
234
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
235
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
232
236
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
233
237
|
deps: AgentDepsT = None,
|
|
234
238
|
model_settings: ModelSettings | None = None,
|
|
@@ -261,6 +265,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
261
265
|
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
262
266
|
output validators since output validators would expect an argument that matches the agent's output type.
|
|
263
267
|
message_history: History of the conversation so far.
|
|
268
|
+
deferred_tool_results: Optional results for deferred tool calls in the message history.
|
|
264
269
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
265
270
|
deps: Optional dependencies to use for this run.
|
|
266
271
|
model_settings: Optional settings to use for this model's request.
|
|
@@ -283,6 +288,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
283
288
|
user_prompt,
|
|
284
289
|
output_type=output_type,
|
|
285
290
|
message_history=message_history,
|
|
291
|
+
deferred_tool_results=deferred_tool_results,
|
|
286
292
|
model=model,
|
|
287
293
|
deps=deps,
|
|
288
294
|
model_settings=model_settings,
|
|
@@ -301,6 +307,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
301
307
|
*,
|
|
302
308
|
output_type: None = None,
|
|
303
309
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
310
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
304
311
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
305
312
|
deps: AgentDepsT = None,
|
|
306
313
|
model_settings: ModelSettings | None = None,
|
|
@@ -318,6 +325,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
318
325
|
*,
|
|
319
326
|
output_type: OutputSpec[RunOutputDataT],
|
|
320
327
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
328
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
321
329
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
322
330
|
deps: AgentDepsT = None,
|
|
323
331
|
model_settings: ModelSettings | None = None,
|
|
@@ -334,6 +342,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
334
342
|
*,
|
|
335
343
|
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
336
344
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
345
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
337
346
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
338
347
|
deps: AgentDepsT = None,
|
|
339
348
|
model_settings: ModelSettings | None = None,
|
|
@@ -365,6 +374,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
365
374
|
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
366
375
|
output validators since output validators would expect an argument that matches the agent's output type.
|
|
367
376
|
message_history: History of the conversation so far.
|
|
377
|
+
deferred_tool_results: Optional results for deferred tool calls in the message history.
|
|
368
378
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
369
379
|
deps: Optional dependencies to use for this run.
|
|
370
380
|
model_settings: Optional settings to use for this model's request.
|
|
@@ -386,6 +396,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
386
396
|
user_prompt,
|
|
387
397
|
output_type=output_type,
|
|
388
398
|
message_history=message_history,
|
|
399
|
+
deferred_tool_results=deferred_tool_results,
|
|
389
400
|
model=model,
|
|
390
401
|
deps=deps,
|
|
391
402
|
model_settings=model_settings,
|
|
@@ -404,6 +415,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
404
415
|
*,
|
|
405
416
|
output_type: None = None,
|
|
406
417
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
418
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
407
419
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
408
420
|
deps: AgentDepsT = None,
|
|
409
421
|
model_settings: ModelSettings | None = None,
|
|
@@ -421,6 +433,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
421
433
|
*,
|
|
422
434
|
output_type: OutputSpec[RunOutputDataT],
|
|
423
435
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
436
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
424
437
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
425
438
|
deps: AgentDepsT = None,
|
|
426
439
|
model_settings: ModelSettings | None = None,
|
|
@@ -438,6 +451,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
438
451
|
*,
|
|
439
452
|
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
440
453
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
454
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
441
455
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
442
456
|
deps: AgentDepsT = None,
|
|
443
457
|
model_settings: ModelSettings | None = None,
|
|
@@ -467,6 +481,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
467
481
|
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
468
482
|
output validators since output validators would expect an argument that matches the agent's output type.
|
|
469
483
|
message_history: History of the conversation so far.
|
|
484
|
+
deferred_tool_results: Optional results for deferred tool calls in the message history.
|
|
470
485
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
471
486
|
deps: Optional dependencies to use for this run.
|
|
472
487
|
model_settings: Optional settings to use for this model's request.
|
|
@@ -490,6 +505,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
490
505
|
user_prompt,
|
|
491
506
|
output_type=output_type,
|
|
492
507
|
message_history=message_history,
|
|
508
|
+
deferred_tool_results=deferred_tool_results,
|
|
493
509
|
model=model,
|
|
494
510
|
deps=deps,
|
|
495
511
|
model_settings=model_settings,
|
|
@@ -509,6 +525,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
509
525
|
*,
|
|
510
526
|
output_type: None = None,
|
|
511
527
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
528
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
512
529
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
513
530
|
deps: AgentDepsT = None,
|
|
514
531
|
model_settings: ModelSettings | None = None,
|
|
@@ -526,6 +543,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
526
543
|
*,
|
|
527
544
|
output_type: OutputSpec[RunOutputDataT],
|
|
528
545
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
546
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
529
547
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
530
548
|
deps: AgentDepsT = None,
|
|
531
549
|
model_settings: ModelSettings | None = None,
|
|
@@ -543,6 +561,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
543
561
|
*,
|
|
544
562
|
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
545
563
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
564
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
546
565
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
547
566
|
deps: AgentDepsT = None,
|
|
548
567
|
model_settings: ModelSettings | None = None,
|
|
@@ -616,6 +635,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
616
635
|
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
617
636
|
output validators since output validators would expect an argument that matches the agent's output type.
|
|
618
637
|
message_history: History of the conversation so far.
|
|
638
|
+
deferred_tool_results: Optional results for deferred tool calls in the message history.
|
|
619
639
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
620
640
|
deps: Optional dependencies to use for this run.
|
|
621
641
|
model_settings: Optional settings to use for this model's request.
|
|
@@ -648,6 +668,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
|
|
|
648
668
|
user_prompt=user_prompt,
|
|
649
669
|
output_type=output_type,
|
|
650
670
|
message_history=message_history,
|
|
671
|
+
deferred_tool_results=deferred_tool_results,
|
|
651
672
|
model=model,
|
|
652
673
|
deps=deps,
|
|
653
674
|
model_settings=model_settings,
|
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Any,
|
|
5
|
+
from typing import Annotated, Any, Literal, assert_never
|
|
5
6
|
|
|
6
|
-
from pydantic import ConfigDict, with_config
|
|
7
|
+
from pydantic import ConfigDict, Discriminator, with_config
|
|
7
8
|
from temporalio import activity, workflow
|
|
8
9
|
from temporalio.workflow import ActivityConfig
|
|
9
10
|
|
|
10
|
-
from pydantic_ai.exceptions import UserError
|
|
11
|
+
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
|
|
11
12
|
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
12
13
|
from pydantic_ai.toolsets import FunctionToolset, ToolsetTool
|
|
13
14
|
from pydantic_ai.toolsets.function import FunctionToolsetTool
|
|
@@ -24,6 +25,34 @@ class _CallToolParams:
|
|
|
24
25
|
serialized_run_context: Any
|
|
25
26
|
|
|
26
27
|
|
|
28
|
+
@dataclass
|
|
29
|
+
class _ApprovalRequired:
|
|
30
|
+
kind: Literal['approval_required'] = 'approval_required'
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
|
|
34
|
+
class _CallDeferred:
|
|
35
|
+
kind: Literal['call_deferred'] = 'call_deferred'
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
|
|
39
|
+
class _ModelRetry:
|
|
40
|
+
message: str
|
|
41
|
+
kind: Literal['model_retry'] = 'model_retry'
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class _ToolReturn:
|
|
46
|
+
result: Any
|
|
47
|
+
kind: Literal['tool_return'] = 'tool_return'
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
_CallToolResult = Annotated[
|
|
51
|
+
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
|
|
52
|
+
Discriminator('kind'),
|
|
53
|
+
]
|
|
54
|
+
|
|
55
|
+
|
|
27
56
|
class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
28
57
|
def __init__(
|
|
29
58
|
self,
|
|
@@ -40,7 +69,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
40
69
|
self.tool_activity_config = tool_activity_config
|
|
41
70
|
self.run_context_type = run_context_type
|
|
42
71
|
|
|
43
|
-
async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) ->
|
|
72
|
+
async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
|
|
44
73
|
name = params.name
|
|
45
74
|
ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
|
|
46
75
|
try:
|
|
@@ -54,7 +83,15 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
54
83
|
# The tool args will already have been validated into their proper types in the `ToolManager`,
|
|
55
84
|
# but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
|
|
56
85
|
args_dict = tool.args_validator.validate_python(params.tool_args)
|
|
57
|
-
|
|
86
|
+
try:
|
|
87
|
+
result = await self.wrapped.call_tool(name, args_dict, ctx, tool)
|
|
88
|
+
return _ToolReturn(result=result)
|
|
89
|
+
except ApprovalRequired:
|
|
90
|
+
return _ApprovalRequired()
|
|
91
|
+
except CallDeferred:
|
|
92
|
+
return _CallDeferred()
|
|
93
|
+
except ModelRetry as e:
|
|
94
|
+
return _ModelRetry(message=e.message)
|
|
58
95
|
|
|
59
96
|
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
|
|
60
97
|
call_tool_activity.__annotations__['deps'] = deps_type
|
|
@@ -85,7 +122,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
85
122
|
|
|
86
123
|
tool_activity_config = self.activity_config | tool_activity_config
|
|
87
124
|
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
|
|
88
|
-
|
|
125
|
+
result = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
|
|
89
126
|
activity=self.call_tool_activity,
|
|
90
127
|
args=[
|
|
91
128
|
_CallToolParams(
|
|
@@ -97,3 +134,13 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
97
134
|
],
|
|
98
135
|
**tool_activity_config,
|
|
99
136
|
)
|
|
137
|
+
if isinstance(result, _ApprovalRequired):
|
|
138
|
+
raise ApprovalRequired()
|
|
139
|
+
elif isinstance(result, _CallDeferred):
|
|
140
|
+
raise CallDeferred()
|
|
141
|
+
elif isinstance(result, _ModelRetry):
|
|
142
|
+
raise ModelRetry(result.message)
|
|
143
|
+
elif isinstance(result, _ToolReturn):
|
|
144
|
+
return result.result
|
|
145
|
+
else:
|
|
146
|
+
assert_never(result)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from collections.abc import Callable
|
|
4
4
|
|
|
5
5
|
from logfire import Logfire
|
|
6
6
|
from opentelemetry.trace import get_tracer
|
|
@@ -25,10 +25,13 @@ class LogfirePlugin(ClientPlugin):
|
|
|
25
25
|
self.setup_logfire = setup_logfire
|
|
26
26
|
self.metrics = metrics
|
|
27
27
|
|
|
28
|
+
def init_client_plugin(self, next: ClientPlugin) -> None:
|
|
29
|
+
self.next_client_plugin = next
|
|
30
|
+
|
|
28
31
|
def configure_client(self, config: ClientConfig) -> ClientConfig:
|
|
29
32
|
interceptors = config.get('interceptors', [])
|
|
30
33
|
config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporalio'))]
|
|
31
|
-
return
|
|
34
|
+
return self.next_client_plugin.configure_client(config)
|
|
32
35
|
|
|
33
36
|
async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
|
|
34
37
|
logfire = self.setup_logfire()
|
|
@@ -45,4 +48,4 @@ class LogfirePlugin(ClientPlugin):
|
|
|
45
48
|
telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url, headers=headers))
|
|
46
49
|
)
|
|
47
50
|
|
|
48
|
-
return await
|
|
51
|
+
return await self.next_client_plugin.connect_service_client(config)
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Literal
|
|
5
6
|
|
|
6
7
|
from pydantic import ConfigDict, with_config
|
|
7
8
|
from temporalio import activity, workflow
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import AsyncIterator
|
|
3
|
+
from collections.abc import AsyncIterator, Callable
|
|
4
4
|
from contextlib import asynccontextmanager
|
|
5
5
|
from dataclasses import dataclass
|
|
6
6
|
from datetime import datetime
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
from pydantic import ConfigDict, with_config
|
|
10
10
|
from temporalio import activity, workflow
|
|
@@ -9,7 +9,7 @@ from pydantic_ai.tools import AgentDepsT, RunContext
|
|
|
9
9
|
class TemporalRunContext(RunContext[AgentDepsT]):
|
|
10
10
|
"""The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
|
|
11
11
|
|
|
12
|
-
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `retry` and `run_step` attributes will be available.
|
|
12
|
+
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry` and `run_step` attributes will be available.
|
|
13
13
|
To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
|
|
14
14
|
"""
|
|
15
15
|
|
|
@@ -40,6 +40,7 @@ class TemporalRunContext(RunContext[AgentDepsT]):
|
|
|
40
40
|
'retries': ctx.retries,
|
|
41
41
|
'tool_call_id': ctx.tool_call_id,
|
|
42
42
|
'tool_name': ctx.tool_name,
|
|
43
|
+
'tool_call_approved': ctx.tool_call_approved,
|
|
43
44
|
'retry': ctx.retry,
|
|
44
45
|
'run_step': ctx.run_step,
|
|
45
46
|
}
|
pydantic_ai/exceptions.py
CHANGED
|
@@ -2,7 +2,9 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
4
|
import sys
|
|
5
|
-
from typing import TYPE_CHECKING
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
from pydantic_core import core_schema
|
|
6
8
|
|
|
7
9
|
if sys.version_info < (3, 11):
|
|
8
10
|
from exceptiongroup import ExceptionGroup as ExceptionGroup # pragma: lax no cover
|
|
@@ -14,6 +16,8 @@ if TYPE_CHECKING:
|
|
|
14
16
|
|
|
15
17
|
__all__ = (
|
|
16
18
|
'ModelRetry',
|
|
19
|
+
'CallDeferred',
|
|
20
|
+
'ApprovalRequired',
|
|
17
21
|
'UserError',
|
|
18
22
|
'AgentRunError',
|
|
19
23
|
'UnexpectedModelBehavior',
|
|
@@ -24,7 +28,7 @@ __all__ = (
|
|
|
24
28
|
|
|
25
29
|
|
|
26
30
|
class ModelRetry(Exception):
|
|
27
|
-
"""Exception
|
|
31
|
+
"""Exception to raise when a tool function should be retried.
|
|
28
32
|
|
|
29
33
|
The agent will return the message to the model and ask it to try calling the function/tool again.
|
|
30
34
|
"""
|
|
@@ -36,6 +40,45 @@ class ModelRetry(Exception):
|
|
|
36
40
|
self.message = message
|
|
37
41
|
super().__init__(message)
|
|
38
42
|
|
|
43
|
+
def __eq__(self, other: Any) -> bool:
|
|
44
|
+
return isinstance(other, self.__class__) and other.message == self.message
|
|
45
|
+
|
|
46
|
+
@classmethod
|
|
47
|
+
def __get_pydantic_core_schema__(cls, _: Any, __: Any) -> core_schema.CoreSchema:
|
|
48
|
+
"""Pydantic core schema to allow `ModelRetry` to be (de)serialized."""
|
|
49
|
+
schema = core_schema.typed_dict_schema(
|
|
50
|
+
{
|
|
51
|
+
'message': core_schema.typed_dict_field(core_schema.str_schema()),
|
|
52
|
+
'kind': core_schema.typed_dict_field(core_schema.literal_schema(['model-retry'])),
|
|
53
|
+
}
|
|
54
|
+
)
|
|
55
|
+
return core_schema.no_info_after_validator_function(
|
|
56
|
+
lambda dct: ModelRetry(dct['message']),
|
|
57
|
+
schema,
|
|
58
|
+
serialization=core_schema.plain_serializer_function_ser_schema(
|
|
59
|
+
lambda x: {'message': x.message, 'kind': 'model-retry'},
|
|
60
|
+
return_schema=schema,
|
|
61
|
+
),
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class CallDeferred(Exception):
|
|
66
|
+
"""Exception to raise when a tool call should be deferred.
|
|
67
|
+
|
|
68
|
+
See [tools docs](../tools.md#deferred-tools) for more information.
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
pass
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class ApprovalRequired(Exception):
|
|
75
|
+
"""Exception to raise when a tool call requires human-in-the-loop approval.
|
|
76
|
+
|
|
77
|
+
See [tools docs](../tools.md#human-in-the-loop-tool-approval) for more information.
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
pass
|
|
81
|
+
|
|
39
82
|
|
|
40
83
|
class UserError(RuntimeError):
|
|
41
84
|
"""Error caused by a usage mistake by the application developer — You!"""
|
pydantic_ai/format_prompt.py
CHANGED
|
@@ -72,9 +72,9 @@ class _ToXml:
|
|
|
72
72
|
element.text = self.none_str
|
|
73
73
|
elif isinstance(value, str):
|
|
74
74
|
element.text = value
|
|
75
|
-
elif isinstance(value,
|
|
75
|
+
elif isinstance(value, bytes | bytearray):
|
|
76
76
|
element.text = value.decode(errors='ignore')
|
|
77
|
-
elif isinstance(value,
|
|
77
|
+
elif isinstance(value, bool | int | float):
|
|
78
78
|
element.text = str(value)
|
|
79
79
|
elif isinstance(value, date):
|
|
80
80
|
element.text = value.isoformat()
|
pydantic_ai/mcp.py
CHANGED
|
@@ -5,12 +5,12 @@ import functools
|
|
|
5
5
|
import warnings
|
|
6
6
|
from abc import ABC, abstractmethod
|
|
7
7
|
from asyncio import Lock
|
|
8
|
-
from collections.abc import AsyncIterator, Awaitable, Sequence
|
|
8
|
+
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
|
|
9
9
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
|
10
10
|
from dataclasses import field, replace
|
|
11
11
|
from datetime import timedelta
|
|
12
12
|
from pathlib import Path
|
|
13
|
-
from typing import Any
|
|
13
|
+
from typing import Any
|
|
14
14
|
|
|
15
15
|
import anyio
|
|
16
16
|
import httpx
|