pydantic-ai-slim 1.0.0b1__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pydantic_ai/_a2a.py CHANGED
@@ -272,7 +272,7 @@ class AgentWorker(Worker[list[ModelMessage]], Generic[WorkerOutputT, AgentDepsT]
272
272
  assert_never(part)
273
273
  return model_parts
274
274
 
275
- def _response_parts_to_a2a(self, parts: list[ModelResponsePart]) -> list[Part]:
275
+ def _response_parts_to_a2a(self, parts: Sequence[ModelResponsePart]) -> list[Part]:
276
276
  """Convert pydantic-ai ModelResponsePart objects to A2A Part objects.
277
277
 
278
278
  This handles the conversion from pydantic-ai's internal response parts to
pydantic_ai/_agent_graph.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import asyncio
4
4
  import dataclasses
5
- import hashlib
6
5
  from collections import defaultdict, deque
7
6
  from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
8
7
  from contextlib import asynccontextmanager, contextmanager
@@ -302,16 +301,21 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
302
301
  if self.system_prompt_dynamic_functions:
303
302
  for msg in messages:
304
303
  if isinstance(msg, _messages.ModelRequest):
305
- for i, part in enumerate(msg.parts):
304
+ reevaluated_message_parts: list[_messages.ModelRequestPart] = []
305
+ for part in msg.parts:
306
306
  if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
307
307
  # Look up the runner by its ref
308
308
  if runner := self.system_prompt_dynamic_functions.get( # pragma: lax no cover
309
309
  part.dynamic_ref
310
310
  ):
311
311
  updated_part_content = await runner.run(run_context)
312
- msg.parts[i] = _messages.SystemPromptPart(
313
- updated_part_content, dynamic_ref=part.dynamic_ref
314
- )
312
+ part = _messages.SystemPromptPart(updated_part_content, dynamic_ref=part.dynamic_ref)
313
+
314
+ reevaluated_message_parts.append(part)
315
+
316
+ # Replace message parts with reevaluated ones to prevent mutating parts list
317
+ if reevaluated_message_parts != msg.parts:
318
+ msg.parts = reevaluated_message_parts
315
319
 
316
320
  async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
317
321
  """Build the initial messages for the conversation."""
@@ -650,13 +654,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
650
654
  )
651
655
 
652
656
 
653
- def multi_modal_content_identifier(identifier: str | bytes) -> str:
654
- """Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
655
- if isinstance(identifier, str):
656
- identifier = identifier.encode('utf-8')
657
- return hashlib.sha1(identifier).hexdigest()[:6]
658
-
659
-
660
657
  async def process_function_tools( # noqa: C901
661
658
  tool_manager: ToolManager[DepsT],
662
659
  tool_calls: list[_messages.ToolCallPart],
@@ -764,6 +761,7 @@ async def process_function_tools( # noqa: C901
764
761
  calls_to_run,
765
762
  deferred_tool_results,
766
763
  ctx.deps.tracer,
764
+ ctx.deps.usage_limits,
767
765
  output_parts,
768
766
  deferred_calls,
769
767
  ):
@@ -810,6 +808,7 @@ async def _call_tools(
810
808
  tool_calls: list[_messages.ToolCallPart],
811
809
  deferred_tool_results: dict[str, DeferredToolResult],
812
810
  tracer: Tracer,
811
+ usage_limits: _usage.UsageLimits | None,
813
812
  output_parts: list[_messages.ModelRequestPart],
814
813
  output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
815
814
  ) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -830,7 +829,7 @@ async def _call_tools(
830
829
  ):
831
830
  tasks = [
832
831
  asyncio.create_task(
833
- _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id)),
832
+ _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits),
834
833
  name=call.tool_name,
835
834
  )
836
835
  for call in tool_calls
@@ -870,14 +869,15 @@ async def _call_tool(
870
869
  tool_manager: ToolManager[DepsT],
871
870
  tool_call: _messages.ToolCallPart,
872
871
  tool_call_result: DeferredToolResult | None,
872
+ usage_limits: _usage.UsageLimits | None,
873
873
  ) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]:
874
874
  try:
875
875
  if tool_call_result is None:
876
- tool_result = await tool_manager.handle_call(tool_call)
876
+ tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
877
877
  elif isinstance(tool_call_result, ToolApproved):
878
878
  if tool_call_result.override_args is not None:
879
879
  tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
880
- tool_result = await tool_manager.handle_call(tool_call)
880
+ tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
881
881
  elif isinstance(tool_call_result, ToolDenied):
882
882
  return _messages.ToolReturnPart(
883
883
  tool_name=tool_call.tool_name,
@@ -915,10 +915,7 @@ async def _call_tool(
915
915
  f'`ToolReturn` should be used directly.'
916
916
  )
917
917
  elif isinstance(content, _messages.MultiModalContent):
918
- if isinstance(content, _messages.BinaryContent):
919
- identifier = content.identifier or multi_modal_content_identifier(content.data)
920
- else:
921
- identifier = multi_modal_content_identifier(content.url)
918
+ identifier = content.identifier
922
919
 
923
920
  return_values.append(f'See file {identifier}')
924
921
  user_contents.extend([f'This is file {identifier}:', content])
pydantic_ai/_parts_manager.py CHANGED
@@ -154,6 +154,7 @@ class ModelResponsePartsManager:
154
154
  *,
155
155
  vendor_part_id: Hashable | None,
156
156
  content: str | None = None,
157
+ id: str | None = None,
157
158
  signature: str | None = None,
158
159
  ) -> ModelResponseStreamEvent:
159
160
  """Handle incoming thinking content, creating or updating a ThinkingPart in the manager as appropriate.
@@ -167,6 +168,7 @@ class ModelResponsePartsManager:
167
168
  of thinking. If None, a new part will be created unless the latest part is already
168
169
  a ThinkingPart.
169
170
  content: The thinking content to append to the appropriate ThinkingPart.
171
+ id: An optional id for the thinking part.
170
172
  signature: An optional signature for the thinking content.
171
173
 
172
174
  Returns:
@@ -197,7 +199,7 @@ class ModelResponsePartsManager:
197
199
  if content is not None:
198
200
  # There is no existing thinking part that should be updated, so create a new one
199
201
  new_part_index = len(self._parts)
200
- part = ThinkingPart(content=content, signature=signature)
202
+ part = ThinkingPart(content=content, id=id, signature=signature)
201
203
  if vendor_part_id is not None: # pragma: no branch
202
204
  self._vendor_id_to_part_index[vendor_part_id] = new_part_index
203
205
  self._parts.append(part)
pydantic_ai/_tool_manager.py CHANGED
@@ -14,6 +14,7 @@ from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior
14
14
  from .messages import ToolCallPart
15
15
  from .tools import ToolDefinition
16
16
  from .toolsets.abstract import AbstractToolset, ToolsetTool
17
+ from .usage import UsageLimits
17
18
 
18
19
 
19
20
  @dataclass
@@ -66,7 +67,11 @@ class ToolManager(Generic[AgentDepsT]):
66
67
  return None
67
68
 
68
69
  async def handle_call(
69
- self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
70
+ self,
71
+ call: ToolCallPart,
72
+ allow_partial: bool = False,
73
+ wrap_validation_errors: bool = True,
74
+ usage_limits: UsageLimits | None = None,
70
75
  ) -> Any:
71
76
  """Handle a tool call by validating the arguments, calling the tool, and handling retries.
72
77
 
@@ -74,13 +79,14 @@ class ToolManager(Generic[AgentDepsT]):
74
79
  call: The tool call part to handle.
75
80
  allow_partial: Whether to allow partial validation of the tool arguments.
76
81
  wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
82
+ usage_limits: Optional usage limits to check before executing tools.
77
83
  """
78
84
  if self.tools is None or self.ctx is None:
79
85
  raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
80
86
 
81
87
  if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
82
- # Output tool calls are not traced
83
- return await self._call_tool(call, allow_partial, wrap_validation_errors)
88
+ # Output tool calls are not traced and not counted
89
+ return await self._call_tool(call, allow_partial, wrap_validation_errors, count_tool_usage=False)
84
90
  else:
85
91
  return await self._call_tool_traced(
86
92
  call,
@@ -88,9 +94,17 @@ class ToolManager(Generic[AgentDepsT]):
88
94
  wrap_validation_errors,
89
95
  self.ctx.tracer,
90
96
  self.ctx.trace_include_content,
97
+ usage_limits,
91
98
  )
92
99
 
93
- async def _call_tool(self, call: ToolCallPart, allow_partial: bool, wrap_validation_errors: bool) -> Any:
100
+ async def _call_tool(
101
+ self,
102
+ call: ToolCallPart,
103
+ allow_partial: bool,
104
+ wrap_validation_errors: bool,
105
+ usage_limits: UsageLimits | None = None,
106
+ count_tool_usage: bool = True,
107
+ ) -> Any:
94
108
  if self.tools is None or self.ctx is None:
95
109
  raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
96
110
 
@@ -121,7 +135,15 @@ class ToolManager(Generic[AgentDepsT]):
121
135
  else:
122
136
  args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
123
137
 
124
- return await self.toolset.call_tool(name, args_dict, ctx, tool)
138
+ if usage_limits is not None and count_tool_usage:
139
+ usage_limits.check_before_tool_call(self.ctx.usage)
140
+
141
+ result = await self.toolset.call_tool(name, args_dict, ctx, tool)
142
+
143
+ if count_tool_usage:
144
+ self.ctx.usage.tool_calls += 1
145
+
146
+ return result
125
147
  except (ValidationError, ModelRetry) as e:
126
148
  max_retries = tool.max_retries if tool is not None else 1
127
149
  current_retry = self.ctx.retries.get(name, 0)
@@ -160,6 +182,7 @@ class ToolManager(Generic[AgentDepsT]):
160
182
  wrap_validation_errors: bool,
161
183
  tracer: Tracer,
162
184
  include_content: bool = False,
185
+ usage_limits: UsageLimits | None = None,
163
186
  ) -> Any:
164
187
  """See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
165
188
  span_attributes = {
@@ -189,7 +212,7 @@ class ToolManager(Generic[AgentDepsT]):
189
212
  }
190
213
  with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
191
214
  try:
192
- tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
215
+ tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors, usage_limits)
193
216
  except ToolRetryError as e:
194
217
  part = e.tool_retry
195
218
  if include_content and span.is_recording():
pydantic_ai/ag_ui.py CHANGED
@@ -68,6 +68,9 @@ try:
68
68
  TextMessageContentEvent,
69
69
  TextMessageEndEvent,
70
70
  TextMessageStartEvent,
71
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
72
+ # ThinkingEndEvent,
73
+ # ThinkingStartEvent,
71
74
  ThinkingTextMessageContentEvent,
72
75
  ThinkingTextMessageEndEvent,
73
76
  ThinkingTextMessageStartEvent,
@@ -392,6 +395,12 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
392
395
  if stream_ctx.part_end: # pragma: no branch
393
396
  yield stream_ctx.part_end
394
397
  stream_ctx.part_end = None
398
+ if stream_ctx.thinking:
399
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
400
+ # yield ThinkingEndEvent(
401
+ # type=EventType.THINKING_END,
402
+ # )
403
+ stream_ctx.thinking = False
395
404
  elif isinstance(node, CallToolsNode):
396
405
  async with node.stream(run.ctx) as handle_stream:
397
406
  async for event in handle_stream:
@@ -400,7 +409,7 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
400
409
  yield msg
401
410
 
402
411
 
403
- async def _handle_model_request_event(
412
+ async def _handle_model_request_event( # noqa: C901
404
413
  stream_ctx: _RequestStreamContext,
405
414
  agent_event: ModelResponseStreamEvent,
406
415
  ) -> AsyncIterator[BaseEvent]:
@@ -420,56 +429,70 @@ async def _handle_model_request_event(
420
429
  stream_ctx.part_end = None
421
430
 
422
431
  part = agent_event.part
423
- if isinstance(part, TextPart):
424
- message_id = stream_ctx.new_message_id()
425
- yield TextMessageStartEvent(
426
- message_id=message_id,
427
- )
428
- if part.content: # pragma: no branch
429
- yield TextMessageContentEvent(
430
- message_id=message_id,
432
+ if isinstance(part, ThinkingPart): # pragma: no branch
433
+ if not stream_ctx.thinking:
434
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
435
+ # yield ThinkingStartEvent(
436
+ # type=EventType.THINKING_START,
437
+ # )
438
+ stream_ctx.thinking = True
439
+
440
+ if part.content:
441
+ yield ThinkingTextMessageStartEvent(
442
+ type=EventType.THINKING_TEXT_MESSAGE_START,
443
+ )
444
+ yield ThinkingTextMessageContentEvent(
445
+ type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
431
446
  delta=part.content,
432
447
  )
433
- stream_ctx.part_end = TextMessageEndEvent(
434
- message_id=message_id,
435
- )
436
- elif isinstance(part, ToolCallPart): # pragma: no branch
437
- message_id = stream_ctx.message_id or stream_ctx.new_message_id()
438
- yield ToolCallStartEvent(
439
- tool_call_id=part.tool_call_id,
440
- tool_call_name=part.tool_name,
441
- parent_message_id=message_id,
442
- )
443
- if part.args:
444
- yield ToolCallArgsEvent(
448
+ stream_ctx.part_end = ThinkingTextMessageEndEvent(
449
+ type=EventType.THINKING_TEXT_MESSAGE_END,
450
+ )
451
+ else:
452
+ if stream_ctx.thinking:
453
+ # TODO: Enable once https://github.com/ag-ui-protocol/ag-ui/issues/289 is resolved.
454
+ # yield ThinkingEndEvent(
455
+ # type=EventType.THINKING_END,
456
+ # )
457
+ stream_ctx.thinking = False
458
+
459
+ if isinstance(part, TextPart):
460
+ message_id = stream_ctx.new_message_id()
461
+ yield TextMessageStartEvent(
462
+ message_id=message_id,
463
+ )
464
+ if part.content: # pragma: no branch
465
+ yield TextMessageContentEvent(
466
+ message_id=message_id,
467
+ delta=part.content,
468
+ )
469
+ stream_ctx.part_end = TextMessageEndEvent(
470
+ message_id=message_id,
471
+ )
472
+ elif isinstance(part, ToolCallPart): # pragma: no branch
473
+ message_id = stream_ctx.message_id or stream_ctx.new_message_id()
474
+ yield ToolCallStartEvent(
475
+ tool_call_id=part.tool_call_id,
476
+ tool_call_name=part.tool_name,
477
+ parent_message_id=message_id,
478
+ )
479
+ if part.args:
480
+ yield ToolCallArgsEvent(
481
+ tool_call_id=part.tool_call_id,
482
+ delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
483
+ )
484
+ stream_ctx.part_end = ToolCallEndEvent(
445
485
  tool_call_id=part.tool_call_id,
446
- delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
447
486
  )
448
- stream_ctx.part_end = ToolCallEndEvent(
449
- tool_call_id=part.tool_call_id,
450
- )
451
-
452
- elif isinstance(part, ThinkingPart): # pragma: no branch
453
- yield ThinkingTextMessageStartEvent(
454
- type=EventType.THINKING_TEXT_MESSAGE_START,
455
- )
456
- # Always send the content even if it's empty, as it may be
457
- # used to indicate the start of thinking.
458
- yield ThinkingTextMessageContentEvent(
459
- type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
460
- delta=part.content,
461
- )
462
- stream_ctx.part_end = ThinkingTextMessageEndEvent(
463
- type=EventType.THINKING_TEXT_MESSAGE_END,
464
- )
465
487
 
466
488
  elif isinstance(agent_event, PartDeltaEvent):
467
489
  delta = agent_event.delta
468
490
  if isinstance(delta, TextPartDelta):
469
- yield TextMessageContentEvent(
470
- message_id=stream_ctx.message_id,
471
- delta=delta.content_delta,
472
- )
491
+ if delta.content_delta: # pragma: no branch
492
+ yield TextMessageContentEvent(
493
+ message_id=stream_ctx.message_id,
494
+ delta=delta.content_delta,
495
+ )
473
496
  elif isinstance(delta, ToolCallPartDelta): # pragma: no branch
474
497
  assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
475
498
  yield ToolCallArgsEvent(
@@ -478,6 +501,14 @@ async def _handle_model_request_event(
478
501
  )
479
502
  elif isinstance(delta, ThinkingPartDelta): # pragma: no branch
480
503
  if delta.content_delta: # pragma: no branch
504
+ if not isinstance(stream_ctx.part_end, ThinkingTextMessageEndEvent):
505
+ yield ThinkingTextMessageStartEvent(
506
+ type=EventType.THINKING_TEXT_MESSAGE_START,
507
+ )
508
+ stream_ctx.part_end = ThinkingTextMessageEndEvent(
509
+ type=EventType.THINKING_TEXT_MESSAGE_END,
510
+ )
511
+
481
512
  yield ThinkingTextMessageContentEvent(
482
513
  type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
483
514
  delta=delta.content_delta,
@@ -629,6 +660,7 @@ class _RequestStreamContext:
629
660
 
630
661
  message_id: str = ''
631
662
  part_end: BaseEvent | None = None
663
+ thinking: bool = False
632
664
 
633
665
  def new_message_id(self) -> str:
634
666
  """Generate a new message ID for the request stream.
pydantic_ai/agent/__init__.py CHANGED
@@ -4,15 +4,15 @@ import dataclasses
4
4
  import inspect
5
5
  import json
6
6
  import warnings
7
+ from asyncio import Lock
7
8
  from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
8
9
  from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
9
10
  from contextvars import ContextVar
10
11
  from typing import TYPE_CHECKING, Any, ClassVar, cast, overload
11
12
 
12
- import anyio
13
13
  from opentelemetry.trace import NoOpTracer, use_span
14
14
  from pydantic.json_schema import GenerateJsonSchema
15
- from typing_extensions import TypeVar, deprecated
15
+ from typing_extensions import Self, TypeVar, deprecated
16
16
 
17
17
  from pydantic_graph import Graph
18
18
 
@@ -157,7 +157,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
157
157
 
158
158
  _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)
159
159
 
160
- _enter_lock: anyio.Lock = dataclasses.field(repr=False)
160
+ _enter_lock: Lock = dataclasses.field(repr=False)
161
161
  _entered_count: int = dataclasses.field(repr=False)
162
162
  _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False)
163
163
 
@@ -374,7 +374,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
374
374
  _utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]]
375
375
  ] = ContextVar('_override_tools', default=None)
376
376
 
377
- self._enter_lock = anyio.Lock()
377
+ self._enter_lock = Lock()
378
378
  self._entered_count = 0
379
379
  self._exit_stack = None
380
380
 
@@ -1066,7 +1066,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1066
1066
  strict: Whether to enforce JSON schema compliance (only affects OpenAI).
1067
1067
  See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
1068
1068
  requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
1069
- See the [tools documentation](../tools.md#human-in-the-loop-tool-approval) for more info.
1069
+ See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
1070
1070
  """
1071
1071
 
1072
1072
  def tool_decorator(
@@ -1165,7 +1165,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1165
1165
  strict: Whether to enforce JSON schema compliance (only affects OpenAI).
1166
1166
  See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
1167
1167
  requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
1168
- See the [tools documentation](../tools.md#human-in-the-loop-tool-approval) for more info.
1168
+ See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
1169
1169
  """
1170
1170
 
1171
1171
  def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
@@ -1355,7 +1355,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
1355
1355
 
1356
1356
  return schema # pyright: ignore[reportReturnType]
1357
1357
 
1358
- async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
1358
+ async def __aenter__(self) -> Self:
1359
1359
  """Enter the agent context.
1360
1360
 
1361
1361
  This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used.
pydantic_ai/durable_exec/temporal/_agent.py CHANGED
@@ -1,14 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
- from collections.abc import AsyncIterator, Callable, Iterator, Sequence
3
+ from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterator, Sequence
4
4
  from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
5
5
  from contextvars import ContextVar
6
+ from dataclasses import dataclass
6
7
  from datetime import timedelta
7
8
  from typing import Any, Literal, overload
8
9
 
10
+ from pydantic import ConfigDict, with_config
9
11
  from pydantic.errors import PydanticUserError
10
12
  from pydantic_core import PydanticSerializationError
11
- from temporalio import workflow
13
+ from temporalio import activity, workflow
12
14
  from temporalio.common import RetryPolicy
13
15
  from temporalio.workflow import ActivityConfig
14
16
  from typing_extensions import Never
@@ -21,7 +23,6 @@ from pydantic_ai import (
21
23
  )
22
24
  from pydantic_ai._run_context import AgentDepsT
23
25
  from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
24
- from pydantic_ai.durable_exec.temporal._run_context import TemporalRunContext
25
26
  from pydantic_ai.exceptions import UserError
26
27
  from pydantic_ai.models import Model
27
28
  from pydantic_ai.output import OutputDataT, OutputSpec
@@ -29,15 +30,24 @@ from pydantic_ai.result import StreamedRunResult
29
30
  from pydantic_ai.settings import ModelSettings
30
31
  from pydantic_ai.tools import (
31
32
  DeferredToolResults,
33
+ RunContext,
32
34
  Tool,
33
35
  ToolFuncEither,
34
36
  )
35
37
  from pydantic_ai.toolsets import AbstractToolset
36
38
 
37
39
  from ._model import TemporalModel
40
+ from ._run_context import TemporalRunContext
38
41
  from ._toolset import TemporalWrapperToolset, temporalize_toolset
39
42
 
40
43
 
44
+ @dataclass
45
+ @with_config(ConfigDict(arbitrary_types_allowed=True))
46
+ class _EventStreamHandlerParams:
47
+ event: _messages.AgentStreamEvent
48
+ serialized_run_context: Any
49
+
50
+
41
51
  class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
42
52
  def __init__(
43
53
  self,
@@ -86,6 +96,10 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
86
96
  """
87
97
  super().__init__(wrapped)
88
98
 
99
+ self._name = name
100
+ self._event_stream_handler = event_stream_handler
101
+ self.run_context_type = run_context_type
102
+
89
103
  # start_to_close_timeout is required
90
104
  activity_config = activity_config or ActivityConfig(start_to_close_timeout=timedelta(seconds=60))
91
105
 
@@ -97,13 +111,13 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
97
111
  PydanticUserError.__name__,
98
112
  ]
99
113
  activity_config['retry_policy'] = retry_policy
114
+ self.activity_config = activity_config
100
115
 
101
116
  model_activity_config = model_activity_config or {}
102
117
  toolset_activity_config = toolset_activity_config or {}
103
118
  tool_activity_config = tool_activity_config or {}
104
119
 
105
- self._name = name or wrapped.name
106
- if self._name is None:
120
+ if self.name is None:
107
121
  raise UserError(
108
122
  "An agent needs to have a unique `name` in order to be used with Temporal. The name will be used to identify the agent's activities within the workflow."
109
123
  )
@@ -116,13 +130,33 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
116
130
  'An agent needs to have a `model` in order to be used with Temporal, it cannot be set at agent run time.'
117
131
  )
118
132
 
133
+ async def event_stream_handler_activity(params: _EventStreamHandlerParams, deps: AgentDepsT) -> None:
134
+ # We can never get here without an `event_stream_handler`, as `TemporalAgent.run_stream` and `TemporalAgent.iter` raise an error saying to use `TemporalAgent.run` instead,
135
+ # and that only ends up calling `event_stream_handler` if it is set.
136
+ assert self.event_stream_handler is not None
137
+
138
+ run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
139
+
140
+ async def streamed_response():
141
+ yield params.event
142
+
143
+ await self.event_stream_handler(run_context, streamed_response())
144
+
145
+ # Set type hint explicitly so that Temporal can take care of serialization and deserialization
146
+ event_stream_handler_activity.__annotations__['deps'] = self.deps_type
147
+
148
+ self.event_stream_handler_activity = activity.defn(name=f'{activity_name_prefix}__event_stream_handler')(
149
+ event_stream_handler_activity
150
+ )
151
+ activities.append(self.event_stream_handler_activity)
152
+
119
153
  temporal_model = TemporalModel(
120
154
  wrapped.model,
121
155
  activity_name_prefix=activity_name_prefix,
122
156
  activity_config=activity_config | model_activity_config,
123
157
  deps_type=self.deps_type,
124
- run_context_type=run_context_type,
125
- event_stream_handler=event_stream_handler or wrapped.event_stream_handler,
158
+ run_context_type=self.run_context_type,
159
+ event_stream_handler=self.event_stream_handler,
126
160
  )
127
161
  activities.extend(temporal_model.temporal_activities)
128
162
 
@@ -139,7 +173,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
139
173
  activity_config | toolset_activity_config.get(id, {}),
140
174
  tool_activity_config.get(id, {}),
141
175
  self.deps_type,
142
- run_context_type,
176
+ self.run_context_type,
143
177
  )
144
178
  if isinstance(toolset, TemporalWrapperToolset):
145
179
  activities.extend(toolset.temporal_activities)
@@ -155,7 +189,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
155
189
 
156
190
  @property
157
191
  def name(self) -> str | None:
158
- return self._name
192
+ return self._name or super().name
159
193
 
160
194
  @name.setter
161
195
  def name(self, value: str | None) -> None: # pragma: no cover
@@ -167,6 +201,33 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
167
201
  def model(self) -> Model:
168
202
  return self._model
169
203
 
204
+ @property
205
+ def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None:
206
+ handler = self._event_stream_handler or super().event_stream_handler
207
+ if handler is None:
208
+ return None
209
+ elif workflow.in_workflow():
210
+ return self._call_event_stream_handler_activity
211
+ else:
212
+ return handler
213
+
214
+ async def _call_event_stream_handler_activity(
215
+ self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[_messages.AgentStreamEvent]
216
+ ) -> None:
217
+ serialized_run_context = self.run_context_type.serialize_run_context(ctx)
218
+ async for event in stream:
219
+ await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
220
+ activity=self.event_stream_handler_activity,
221
+ args=[
222
+ _EventStreamHandlerParams(
223
+ event=event,
224
+ serialized_run_context=serialized_run_context,
225
+ ),
226
+ ctx.deps,
227
+ ],
228
+ **self.activity_config,
229
+ )
230
+
170
231
  @property
171
232
  def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
172
233
  with self._temporal_overrides():
@@ -296,7 +357,7 @@ class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]):
296
357
  usage=usage,
297
358
  infer_name=infer_name,
298
359
  toolsets=toolsets,
299
- event_stream_handler=event_stream_handler,
360
+ event_stream_handler=event_stream_handler or self.event_stream_handler,
300
361
  **_deprecated_kwargs,
301
362
  )
302
363
 
pydantic_ai/exceptions.py CHANGED
@@ -65,7 +65,7 @@ class ModelRetry(Exception):
65
65
  class CallDeferred(Exception):
66
66
  """Exception to raise when a tool call should be deferred.
67
67
 
68
- See [tools docs](../tools.md#deferred-tools) for more information.
68
+ See [tools docs](../deferred-tools.md#deferred-tools) for more information.
69
69
  """
70
70
 
71
71
  pass
@@ -74,7 +74,7 @@ class CallDeferred(Exception):
74
74
  class ApprovalRequired(Exception):
75
75
  """Exception to raise when a tool call requires human-in-the-loop approval.
76
76
 
77
- See [tools docs](../tools.md#human-in-the-loop-tool-approval) for more information.
77
+ See [tools docs](../deferred-tools.md#human-in-the-loop-tool-approval) for more information.
78
78
  """
79
79
 
80
80
  pass