pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (75) hide show
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_a2a.py +1 -1
  3. pydantic_ai/_agent_graph.py +323 -156
  4. pydantic_ai/_function_schema.py +5 -5
  5. pydantic_ai/_griffe.py +2 -1
  6. pydantic_ai/_otel_messages.py +2 -2
  7. pydantic_ai/_output.py +31 -35
  8. pydantic_ai/_parts_manager.py +7 -5
  9. pydantic_ai/_run_context.py +3 -1
  10. pydantic_ai/_system_prompt.py +2 -2
  11. pydantic_ai/_tool_manager.py +32 -28
  12. pydantic_ai/_utils.py +14 -26
  13. pydantic_ai/ag_ui.py +82 -51
  14. pydantic_ai/agent/__init__.py +84 -17
  15. pydantic_ai/agent/abstract.py +35 -4
  16. pydantic_ai/agent/wrapper.py +6 -0
  17. pydantic_ai/builtin_tools.py +2 -2
  18. pydantic_ai/common_tools/duckduckgo.py +4 -2
  19. pydantic_ai/durable_exec/temporal/__init__.py +70 -17
  20. pydantic_ai/durable_exec/temporal/_agent.py +93 -11
  21. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  22. pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
  23. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  24. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  25. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  26. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  27. pydantic_ai/exceptions.py +45 -2
  28. pydantic_ai/format_prompt.py +2 -2
  29. pydantic_ai/mcp.py +15 -27
  30. pydantic_ai/messages.py +156 -44
  31. pydantic_ai/models/__init__.py +20 -7
  32. pydantic_ai/models/anthropic.py +10 -17
  33. pydantic_ai/models/bedrock.py +55 -57
  34. pydantic_ai/models/cohere.py +3 -3
  35. pydantic_ai/models/fallback.py +2 -2
  36. pydantic_ai/models/function.py +25 -23
  37. pydantic_ai/models/gemini.py +13 -14
  38. pydantic_ai/models/google.py +19 -5
  39. pydantic_ai/models/groq.py +127 -39
  40. pydantic_ai/models/huggingface.py +5 -5
  41. pydantic_ai/models/instrumented.py +49 -21
  42. pydantic_ai/models/mcp_sampling.py +3 -1
  43. pydantic_ai/models/mistral.py +8 -8
  44. pydantic_ai/models/openai.py +37 -42
  45. pydantic_ai/models/test.py +24 -4
  46. pydantic_ai/output.py +27 -32
  47. pydantic_ai/profiles/__init__.py +3 -3
  48. pydantic_ai/profiles/groq.py +1 -1
  49. pydantic_ai/profiles/openai.py +25 -4
  50. pydantic_ai/providers/__init__.py +4 -0
  51. pydantic_ai/providers/anthropic.py +2 -3
  52. pydantic_ai/providers/bedrock.py +3 -2
  53. pydantic_ai/providers/google_vertex.py +2 -1
  54. pydantic_ai/providers/groq.py +21 -2
  55. pydantic_ai/providers/litellm.py +134 -0
  56. pydantic_ai/result.py +173 -52
  57. pydantic_ai/retries.py +52 -31
  58. pydantic_ai/run.py +12 -5
  59. pydantic_ai/tools.py +127 -23
  60. pydantic_ai/toolsets/__init__.py +4 -1
  61. pydantic_ai/toolsets/_dynamic.py +4 -4
  62. pydantic_ai/toolsets/abstract.py +18 -2
  63. pydantic_ai/toolsets/approval_required.py +32 -0
  64. pydantic_ai/toolsets/combined.py +7 -12
  65. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  66. pydantic_ai/toolsets/filtered.py +1 -1
  67. pydantic_ai/toolsets/function.py +58 -21
  68. pydantic_ai/toolsets/wrapper.py +2 -1
  69. pydantic_ai/usage.py +44 -8
  70. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
  71. pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
  72. pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
  73. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
  74. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,14 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
- from collections.abc import AsyncIterator, Iterator, Sequence
3
+ from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterator, Sequence
4
4
  from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
5
5
  from contextvars import ContextVar
6
+ from dataclasses import dataclass
6
7
  from datetime import timedelta
7
- from typing import Any, Callable, Literal, overload
8
+ from typing import Any, Literal, overload
8
9
 
10
+ from pydantic import ConfigDict, with_config
9
11
  from pydantic.errors import PydanticUserError
10
12
  from pydantic_core import PydanticSerializationError
11
- from temporalio import workflow
13
+ from temporalio import activity, workflow
12
14
  from temporalio.common import RetryPolicy
13
15
  from temporalio.workflow import ActivityConfig
14
16
  from typing_extensions import Never
@@ -21,22 +23,31 @@ from pydantic_ai import (
21
23
  )
22
24
  from pydantic_ai._run_context import AgentDepsT
23
25
  from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
24
- from pydantic_ai.durable_exec.temporal._run_context import TemporalRunContext
25
26
  from pydantic_ai.exceptions import UserError
26
27
  from pydantic_ai.models import Model
27
28
  from pydantic_ai.output import OutputDataT, OutputSpec
28
29
  from pydantic_ai.result import StreamedRunResult
29
30
  from pydantic_ai.settings import ModelSettings
30
31
  from pydantic_ai.tools import (
32
+ DeferredToolResults,
33
+ RunContext,
31
34
  Tool,
32
35
  ToolFuncEither,
33
36
  )
34
37
  from pydantic_ai.toolsets import AbstractToolset
35
38
 
36
39
  from ._model import TemporalModel
40
+ from ._run_context import TemporalRunContext
37
41
  from ._toolset import TemporalWrapperToolset, temporalize_toolset
38
42
 
39
43
 
44
+ @dataclass
45
+ @with_config(ConfigDict(arbitrary_types_allowed=True))
46
+ class _EventStreamHandlerParams:
47
+ event: _messages.AgentStreamEvent
48
+ serialized_run_context: Any
49
+
50
+
40
51
  class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
41
52
  def __init__(
42
53
  self,
@@ -85,6 +96,10 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
85
96
  """
86
97
  super().__init__(wrapped)
87
98
 
99
+ self._name = name
100
+ self._event_stream_handler = event_stream_handler
101
+ self.run_context_type = run_context_type
102
+
88
103
  # start_to_close_timeout is required
89
104
  activity_config = activity_config or ActivityConfig(start_to_close_timeout=timedelta(seconds=60))
90
105
 
@@ -96,13 +111,13 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
96
111
  PydanticUserError.__name__,
97
112
  ]
98
113
  activity_config['retry_policy'] = retry_policy
114
+ self.activity_config = activity_config
99
115
 
100
116
  model_activity_config = model_activity_config or {}
101
117
  toolset_activity_config = toolset_activity_config or {}
102
118
  tool_activity_config = tool_activity_config or {}
103
119
 
104
- self._name = name or wrapped.name
105
- if self._name is None:
120
+ if self.name is None:
106
121
  raise UserError(
107
122
  "An agent needs to have a unique `name` in order to be used with Temporal. The name will be used to identify the agent's activities within the workflow."
108
123
  )
@@ -115,13 +130,33 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
115
130
  'An agent needs to have a `model` in order to be used with Temporal, it cannot be set at agent run time.'
116
131
  )
117
132
 
133
+ async def event_stream_handler_activity(params: _EventStreamHandlerParams, deps: AgentDepsT) -> None:
134
+ # We can never get here without an `event_stream_handler`, as `TemporalAgent.run_stream` and `TemporalAgent.iter` raise an error saying to use `TemporalAgent.run` instead,
135
+ # and that only ends up calling `event_stream_handler` if it is set.
136
+ assert self.event_stream_handler is not None
137
+
138
+ run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
139
+
140
+ async def streamed_response():
141
+ yield params.event
142
+
143
+ await self.event_stream_handler(run_context, streamed_response())
144
+
145
+ # Set type hint explicitly so that Temporal can take care of serialization and deserialization
146
+ event_stream_handler_activity.__annotations__['deps'] = self.deps_type
147
+
148
+ self.event_stream_handler_activity = activity.defn(name=f'{activity_name_prefix}__event_stream_handler')(
149
+ event_stream_handler_activity
150
+ )
151
+ activities.append(self.event_stream_handler_activity)
152
+
118
153
  temporal_model = TemporalModel(
119
154
  wrapped.model,
120
155
  activity_name_prefix=activity_name_prefix,
121
156
  activity_config=activity_config | model_activity_config,
122
157
  deps_type=self.deps_type,
123
- run_context_type=run_context_type,
124
- event_stream_handler=event_stream_handler or wrapped.event_stream_handler,
158
+ run_context_type=self.run_context_type,
159
+ event_stream_handler=self.event_stream_handler,
125
160
  )
126
161
  activities.extend(temporal_model.temporal_activities)
127
162
 
@@ -138,7 +173,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
138
173
  activity_config | toolset_activity_config.get(id, {}),
139
174
  tool_activity_config.get(id, {}),
140
175
  self.deps_type,
141
- run_context_type,
176
+ self.run_context_type,
142
177
  )
143
178
  if isinstance(toolset, TemporalWrapperToolset):
144
179
  activities.extend(toolset.temporal_activities)
@@ -154,7 +189,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
154
189
 
155
190
  @property
156
191
  def name(self) -> str | None:
157
- return self._name
192
+ return self._name or super().name
158
193
 
159
194
  @name.setter
160
195
  def name(self, value: str | None) -> None: # pragma: no cover
@@ -166,6 +201,33 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
166
201
  def model(self) -> Model:
167
202
  return self._model
168
203
 
204
+ @property
205
+ def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None:
206
+ handler = self._event_stream_handler or super().event_stream_handler
207
+ if handler is None:
208
+ return None
209
+ elif workflow.in_workflow():
210
+ return self._call_event_stream_handler_activity
211
+ else:
212
+ return handler
213
+
214
+ async def _call_event_stream_handler_activity(
215
+ self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[_messages.AgentStreamEvent]
216
+ ) -> None:
217
+ serialized_run_context = self.run_context_type.serialize_run_context(ctx)
218
+ async for event in stream:
219
+ await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
220
+ activity=self.event_stream_handler_activity,
221
+ args=[
222
+ _EventStreamHandlerParams(
223
+ event=event,
224
+ serialized_run_context=serialized_run_context,
225
+ ),
226
+ ctx.deps,
227
+ ],
228
+ **self.activity_config,
229
+ )
230
+
169
231
  @property
170
232
  def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
171
233
  with self._temporal_overrides():
@@ -196,6 +258,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
196
258
  *,
197
259
  output_type: None = None,
198
260
  message_history: list[_messages.ModelMessage] | None = None,
261
+ deferred_tool_results: DeferredToolResults | None = None,
199
262
  model: models.Model | models.KnownModelName | str | None = None,
200
263
  deps: AgentDepsT = None,
201
264
  model_settings: ModelSettings | None = None,
@@ -213,6 +276,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
213
276
  *,
214
277
  output_type: OutputSpec[RunOutputDataT],
215
278
  message_history: list[_messages.ModelMessage] | None = None,
279
+ deferred_tool_results: DeferredToolResults | None = None,
216
280
  model: models.Model | models.KnownModelName | str | None = None,
217
281
  deps: AgentDepsT = None,
218
282
  model_settings: ModelSettings | None = None,
@@ -229,6 +293,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
229
293
  *,
230
294
  output_type: OutputSpec[RunOutputDataT] | None = None,
231
295
  message_history: list[_messages.ModelMessage] | None = None,
296
+ deferred_tool_results: DeferredToolResults | None = None,
232
297
  model: models.Model | models.KnownModelName | str | None = None,
233
298
  deps: AgentDepsT = None,
234
299
  model_settings: ModelSettings | None = None,
@@ -261,6 +326,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
261
326
  output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
262
327
  output validators since output validators would expect an argument that matches the agent's output type.
263
328
  message_history: History of the conversation so far.
329
+ deferred_tool_results: Optional results for deferred tool calls in the message history.
264
330
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
265
331
  deps: Optional dependencies to use for this run.
266
332
  model_settings: Optional settings to use for this model's request.
@@ -283,6 +349,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
283
349
  user_prompt,
284
350
  output_type=output_type,
285
351
  message_history=message_history,
352
+ deferred_tool_results=deferred_tool_results,
286
353
  model=model,
287
354
  deps=deps,
288
355
  model_settings=model_settings,
@@ -290,7 +357,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
290
357
  usage=usage,
291
358
  infer_name=infer_name,
292
359
  toolsets=toolsets,
293
- event_stream_handler=event_stream_handler,
360
+ event_stream_handler=event_stream_handler or self.event_stream_handler,
294
361
  **_deprecated_kwargs,
295
362
  )
296
363
 
@@ -301,6 +368,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
301
368
  *,
302
369
  output_type: None = None,
303
370
  message_history: list[_messages.ModelMessage] | None = None,
371
+ deferred_tool_results: DeferredToolResults | None = None,
304
372
  model: models.Model | models.KnownModelName | str | None = None,
305
373
  deps: AgentDepsT = None,
306
374
  model_settings: ModelSettings | None = None,
@@ -318,6 +386,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
318
386
  *,
319
387
  output_type: OutputSpec[RunOutputDataT],
320
388
  message_history: list[_messages.ModelMessage] | None = None,
389
+ deferred_tool_results: DeferredToolResults | None = None,
321
390
  model: models.Model | models.KnownModelName | str | None = None,
322
391
  deps: AgentDepsT = None,
323
392
  model_settings: ModelSettings | None = None,
@@ -334,6 +403,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
334
403
  *,
335
404
  output_type: OutputSpec[RunOutputDataT] | None = None,
336
405
  message_history: list[_messages.ModelMessage] | None = None,
406
+ deferred_tool_results: DeferredToolResults | None = None,
337
407
  model: models.Model | models.KnownModelName | str | None = None,
338
408
  deps: AgentDepsT = None,
339
409
  model_settings: ModelSettings | None = None,
@@ -365,6 +435,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
365
435
  output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
366
436
  output validators since output validators would expect an argument that matches the agent's output type.
367
437
  message_history: History of the conversation so far.
438
+ deferred_tool_results: Optional results for deferred tool calls in the message history.
368
439
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
369
440
  deps: Optional dependencies to use for this run.
370
441
  model_settings: Optional settings to use for this model's request.
@@ -386,6 +457,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
386
457
  user_prompt,
387
458
  output_type=output_type,
388
459
  message_history=message_history,
460
+ deferred_tool_results=deferred_tool_results,
389
461
  model=model,
390
462
  deps=deps,
391
463
  model_settings=model_settings,
@@ -404,6 +476,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
404
476
  *,
405
477
  output_type: None = None,
406
478
  message_history: list[_messages.ModelMessage] | None = None,
479
+ deferred_tool_results: DeferredToolResults | None = None,
407
480
  model: models.Model | models.KnownModelName | str | None = None,
408
481
  deps: AgentDepsT = None,
409
482
  model_settings: ModelSettings | None = None,
@@ -421,6 +494,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
421
494
  *,
422
495
  output_type: OutputSpec[RunOutputDataT],
423
496
  message_history: list[_messages.ModelMessage] | None = None,
497
+ deferred_tool_results: DeferredToolResults | None = None,
424
498
  model: models.Model | models.KnownModelName | str | None = None,
425
499
  deps: AgentDepsT = None,
426
500
  model_settings: ModelSettings | None = None,
@@ -438,6 +512,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
438
512
  *,
439
513
  output_type: OutputSpec[RunOutputDataT] | None = None,
440
514
  message_history: list[_messages.ModelMessage] | None = None,
515
+ deferred_tool_results: DeferredToolResults | None = None,
441
516
  model: models.Model | models.KnownModelName | str | None = None,
442
517
  deps: AgentDepsT = None,
443
518
  model_settings: ModelSettings | None = None,
@@ -467,6 +542,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
467
542
  output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
468
543
  output validators since output validators would expect an argument that matches the agent's output type.
469
544
  message_history: History of the conversation so far.
545
+ deferred_tool_results: Optional results for deferred tool calls in the message history.
470
546
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
471
547
  deps: Optional dependencies to use for this run.
472
548
  model_settings: Optional settings to use for this model's request.
@@ -490,6 +566,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
490
566
  user_prompt,
491
567
  output_type=output_type,
492
568
  message_history=message_history,
569
+ deferred_tool_results=deferred_tool_results,
493
570
  model=model,
494
571
  deps=deps,
495
572
  model_settings=model_settings,
@@ -509,6 +586,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
509
586
  *,
510
587
  output_type: None = None,
511
588
  message_history: list[_messages.ModelMessage] | None = None,
589
+ deferred_tool_results: DeferredToolResults | None = None,
512
590
  model: models.Model | models.KnownModelName | str | None = None,
513
591
  deps: AgentDepsT = None,
514
592
  model_settings: ModelSettings | None = None,
@@ -526,6 +604,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
526
604
  *,
527
605
  output_type: OutputSpec[RunOutputDataT],
528
606
  message_history: list[_messages.ModelMessage] | None = None,
607
+ deferred_tool_results: DeferredToolResults | None = None,
529
608
  model: models.Model | models.KnownModelName | str | None = None,
530
609
  deps: AgentDepsT = None,
531
610
  model_settings: ModelSettings | None = None,
@@ -543,6 +622,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
543
622
  *,
544
623
  output_type: OutputSpec[RunOutputDataT] | None = None,
545
624
  message_history: list[_messages.ModelMessage] | None = None,
625
+ deferred_tool_results: DeferredToolResults | None = None,
546
626
  model: models.Model | models.KnownModelName | str | None = None,
547
627
  deps: AgentDepsT = None,
548
628
  model_settings: ModelSettings | None = None,
@@ -616,6 +696,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
616
696
  output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
617
697
  output validators since output validators would expect an argument that matches the agent's output type.
618
698
  message_history: History of the conversation so far.
699
+ deferred_tool_results: Optional results for deferred tool calls in the message history.
619
700
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
620
701
  deps: Optional dependencies to use for this run.
621
702
  model_settings: Optional settings to use for this model's request.
@@ -648,6 +729,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
648
729
  user_prompt=user_prompt,
649
730
  output_type=output_type,
650
731
  message_history=message_history,
732
+ deferred_tool_results=deferred_tool_results,
651
733
  model=model,
652
734
  deps=deps,
653
735
  model_settings=model_settings,
@@ -1,13 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from collections.abc import Callable
3
4
  from dataclasses import dataclass
4
- from typing import Any, Callable, Literal
5
+ from typing import Annotated, Any, Literal, assert_never
5
6
 
6
- from pydantic import ConfigDict, with_config
7
+ from pydantic import ConfigDict, Discriminator, with_config
7
8
  from temporalio import activity, workflow
8
9
  from temporalio.workflow import ActivityConfig
9
10
 
10
- from pydantic_ai.exceptions import UserError
11
+ from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
11
12
  from pydantic_ai.tools import AgentDepsT, RunContext
12
13
  from pydantic_ai.toolsets import FunctionToolset, ToolsetTool
13
14
  from pydantic_ai.toolsets.function import FunctionToolsetTool
@@ -24,6 +25,34 @@ class _CallToolParams:
24
25
  serialized_run_context: Any
25
26
 
26
27
 
28
+ @dataclass
29
+ class _ApprovalRequired:
30
+ kind: Literal['approval_required'] = 'approval_required'
31
+
32
+
33
+ @dataclass
34
+ class _CallDeferred:
35
+ kind: Literal['call_deferred'] = 'call_deferred'
36
+
37
+
38
+ @dataclass
39
+ class _ModelRetry:
40
+ message: str
41
+ kind: Literal['model_retry'] = 'model_retry'
42
+
43
+
44
+ @dataclass
45
+ class _ToolReturn:
46
+ result: Any
47
+ kind: Literal['tool_return'] = 'tool_return'
48
+
49
+
50
+ _CallToolResult = Annotated[
51
+ _ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
52
+ Discriminator('kind'),
53
+ ]
54
+
55
+
27
56
  class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
28
57
  def __init__(
29
58
  self,
@@ -40,7 +69,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
40
69
  self.tool_activity_config = tool_activity_config
41
70
  self.run_context_type = run_context_type
42
71
 
43
- async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> Any:
72
+ async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
44
73
  name = params.name
45
74
  ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
46
75
  try:
@@ -54,7 +83,15 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
54
83
  # The tool args will already have been validated into their proper types in the `ToolManager`,
55
84
  # but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
56
85
  args_dict = tool.args_validator.validate_python(params.tool_args)
57
- return await self.wrapped.call_tool(name, args_dict, ctx, tool)
86
+ try:
87
+ result = await self.wrapped.call_tool(name, args_dict, ctx, tool)
88
+ return _ToolReturn(result=result)
89
+ except ApprovalRequired:
90
+ return _ApprovalRequired()
91
+ except CallDeferred:
92
+ return _CallDeferred()
93
+ except ModelRetry as e:
94
+ return _ModelRetry(message=e.message)
58
95
 
59
96
  # Set type hint explicitly so that Temporal can take care of serialization and deserialization
60
97
  call_tool_activity.__annotations__['deps'] = deps_type
@@ -85,7 +122,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
85
122
 
86
123
  tool_activity_config = self.activity_config | tool_activity_config
87
124
  serialized_run_context = self.run_context_type.serialize_run_context(ctx)
88
- return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
125
+ result = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
89
126
  activity=self.call_tool_activity,
90
127
  args=[
91
128
  _CallToolParams(
@@ -97,3 +134,13 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
97
134
  ],
98
135
  **tool_activity_config,
99
136
  )
137
+ if isinstance(result, _ApprovalRequired):
138
+ raise ApprovalRequired()
139
+ elif isinstance(result, _CallDeferred):
140
+ raise CallDeferred()
141
+ elif isinstance(result, _ModelRetry):
142
+ raise ModelRetry(result.message)
143
+ elif isinstance(result, _ToolReturn):
144
+ return result.result
145
+ else:
146
+ assert_never(result)
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Callable
3
+ from collections.abc import Callable
4
4
 
5
5
  from logfire import Logfire
6
6
  from opentelemetry.trace import get_tracer
@@ -25,10 +25,13 @@ class LogfirePlugin(ClientPlugin):
25
25
  self.setup_logfire = setup_logfire
26
26
  self.metrics = metrics
27
27
 
28
+ def init_client_plugin(self, next: ClientPlugin) -> None:
29
+ self.next_client_plugin = next
30
+
28
31
  def configure_client(self, config: ClientConfig) -> ClientConfig:
29
32
  interceptors = config.get('interceptors', [])
30
33
  config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporalio'))]
31
- return super().configure_client(config)
34
+ return self.next_client_plugin.configure_client(config)
32
35
 
33
36
  async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
34
37
  logfire = self.setup_logfire()
@@ -45,4 +48,4 @@ class LogfirePlugin(ClientPlugin):
45
48
  telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url, headers=headers))
46
49
  )
47
50
 
48
- return await super().connect_service_client(config)
51
+ return await self.next_client_plugin.connect_service_client(config)
@@ -1,7 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from collections.abc import Callable
3
4
  from dataclasses import dataclass
4
- from typing import Any, Callable, Literal
5
+ from typing import Any, Literal
5
6
 
6
7
  from pydantic import ConfigDict, with_config
7
8
  from temporalio import activity, workflow
@@ -1,10 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- from collections.abc import AsyncIterator
3
+ from collections.abc import AsyncIterator, Callable
4
4
  from contextlib import asynccontextmanager
5
5
  from dataclasses import dataclass
6
6
  from datetime import datetime
7
- from typing import Any, Callable
7
+ from typing import Any
8
8
 
9
9
  from pydantic import ConfigDict, with_config
10
10
  from temporalio import activity, workflow
@@ -9,7 +9,7 @@ from pydantic_ai.tools import AgentDepsT, RunContext
9
9
  class TemporalRunContext(RunContext[AgentDepsT]):
10
10
  """The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
11
11
 
12
- By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `retry` and `run_step` attributes will be available.
12
+ By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry` and `run_step` attributes will be available.
13
13
  To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
14
14
  """
15
15
 
@@ -40,6 +40,7 @@ class TemporalRunContext(RunContext[AgentDepsT]):
40
40
  'retries': ctx.retries,
41
41
  'tool_call_id': ctx.tool_call_id,
42
42
  'tool_name': ctx.tool_name,
43
+ 'tool_call_approved': ctx.tool_call_approved,
43
44
  'retry': ctx.retry,
44
45
  'run_step': ctx.run_step,
45
46
  }
@@ -1,7 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import Any, Callable, Literal
4
+ from collections.abc import Callable
5
+ from typing import Any, Literal
5
6
 
6
7
  from temporalio.workflow import ActivityConfig
7
8
 
pydantic_ai/exceptions.py CHANGED
@@ -2,7 +2,9 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import json
4
4
  import sys
5
- from typing import TYPE_CHECKING
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ from pydantic_core import core_schema
6
8
 
7
9
  if sys.version_info < (3, 11):
8
10
  from exceptiongroup import ExceptionGroup as ExceptionGroup # pragma: lax no cover
@@ -14,6 +16,8 @@ if TYPE_CHECKING:
14
16
 
15
17
  __all__ = (
16
18
  'ModelRetry',
19
+ 'CallDeferred',
20
+ 'ApprovalRequired',
17
21
  'UserError',
18
22
  'AgentRunError',
19
23
  'UnexpectedModelBehavior',
@@ -24,7 +28,7 @@ __all__ = (
24
28
 
25
29
 
26
30
  class ModelRetry(Exception):
27
- """Exception raised when a tool function should be retried.
31
+ """Exception to raise when a tool function should be retried.
28
32
 
29
33
  The agent will return the message to the model and ask it to try calling the function/tool again.
30
34
  """
@@ -36,6 +40,45 @@ class ModelRetry(Exception):
36
40
  self.message = message
37
41
  super().__init__(message)
38
42
 
43
+ def __eq__(self, other: Any) -> bool:
44
+ return isinstance(other, self.__class__) and other.message == self.message
45
+
46
+ @classmethod
47
+ def __get_pydantic_core_schema__(cls, _: Any, __: Any) -> core_schema.CoreSchema:
48
+ """Pydantic core schema to allow `ModelRetry` to be (de)serialized."""
49
+ schema = core_schema.typed_dict_schema(
50
+ {
51
+ 'message': core_schema.typed_dict_field(core_schema.str_schema()),
52
+ 'kind': core_schema.typed_dict_field(core_schema.literal_schema(['model-retry'])),
53
+ }
54
+ )
55
+ return core_schema.no_info_after_validator_function(
56
+ lambda dct: ModelRetry(dct['message']),
57
+ schema,
58
+ serialization=core_schema.plain_serializer_function_ser_schema(
59
+ lambda x: {'message': x.message, 'kind': 'model-retry'},
60
+ return_schema=schema,
61
+ ),
62
+ )
63
+
64
+
65
+ class CallDeferred(Exception):
66
+ """Exception to raise when a tool call should be deferred.
67
+
68
+ See [tools docs](../deferred-tools.md#deferred-tools) for more information.
69
+ """
70
+
71
+ pass
72
+
73
+
74
+ class ApprovalRequired(Exception):
75
+ """Exception to raise when a tool call requires human-in-the-loop approval.
76
+
77
+ See [tools docs](../deferred-tools.md#human-in-the-loop-tool-approval) for more information.
78
+ """
79
+
80
+ pass
81
+
39
82
 
40
83
  class UserError(RuntimeError):
41
84
  """Error caused by a usage mistake by the application developer — You!"""
@@ -72,9 +72,9 @@ class _ToXml:
72
72
  element.text = self.none_str
73
73
  elif isinstance(value, str):
74
74
  element.text = value
75
- elif isinstance(value, (bytes, bytearray)):
75
+ elif isinstance(value, bytes | bytearray):
76
76
  element.text = value.decode(errors='ignore')
77
- elif isinstance(value, (bool, int, float)):
77
+ elif isinstance(value, bool | int | float):
78
78
  element.text = str(value)
79
79
  elif isinstance(value, date):
80
80
  element.text = value.isoformat()