pydantic-ai-slim 0.0.24__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

pydantic_ai/agent.py CHANGED
@@ -5,14 +5,14 @@ import dataclasses
5
5
  import inspect
6
6
  from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
7
7
  from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
8
+ from copy import deepcopy
8
9
  from types import FrameType
9
10
  from typing import Any, Callable, Generic, cast, final, overload
10
11
 
11
12
  import logfire_api
12
13
  from typing_extensions import TypeVar, deprecated
13
14
 
14
- from pydantic_graph import Graph, GraphRunContext, HistoryStep
15
- from pydantic_graph.nodes import End
15
+ from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext
16
16
 
17
17
  from . import (
18
18
  _agent_graph,
@@ -25,8 +25,7 @@ from . import (
25
25
  result,
26
26
  usage as _usage,
27
27
  )
28
- from ._agent_graph import EndStrategy, capture_run_messages # imported for re-export
29
- from .result import ResultDataT
28
+ from .result import FinalResult, ResultDataT, StreamedRunResult
30
29
  from .settings import ModelSettings, merge_model_settings
31
30
  from .tools import (
32
31
  AgentDepsT,
@@ -40,7 +39,24 @@ from .tools import (
40
39
  ToolPrepareFunc,
41
40
  )
42
41
 
43
- __all__ = 'Agent', 'capture_run_messages', 'EndStrategy'
42
+ # Re-exporting like this improves auto-import behavior in PyCharm
43
+ capture_run_messages = _agent_graph.capture_run_messages
44
+ EndStrategy = _agent_graph.EndStrategy
45
+ HandleResponseNode = _agent_graph.HandleResponseNode
46
+ ModelRequestNode = _agent_graph.ModelRequestNode
47
+ UserPromptNode = _agent_graph.UserPromptNode
48
+
49
+
50
+ __all__ = (
51
+ 'Agent',
52
+ 'AgentRun',
53
+ 'AgentRunResult',
54
+ 'capture_run_messages',
55
+ 'EndStrategy',
56
+ 'HandleResponseNode',
57
+ 'ModelRequestNode',
58
+ 'UserPromptNode',
59
+ )
44
60
 
45
61
  _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
46
62
 
@@ -214,7 +230,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
214
230
  usage_limits: _usage.UsageLimits | None = None,
215
231
  usage: _usage.Usage | None = None,
216
232
  infer_name: bool = True,
217
- ) -> result.RunResult[ResultDataT]: ...
233
+ ) -> AgentRunResult[ResultDataT]: ...
218
234
 
219
235
  @overload
220
236
  async def run(
@@ -229,23 +245,26 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
229
245
  usage_limits: _usage.UsageLimits | None = None,
230
246
  usage: _usage.Usage | None = None,
231
247
  infer_name: bool = True,
232
- ) -> result.RunResult[RunResultDataT]: ...
248
+ ) -> AgentRunResult[RunResultDataT]: ...
233
249
 
234
250
  async def run(
235
251
  self,
236
252
  user_prompt: str,
237
253
  *,
254
+ result_type: type[RunResultDataT] | None = None,
238
255
  message_history: list[_messages.ModelMessage] | None = None,
239
256
  model: models.Model | models.KnownModelName | None = None,
240
257
  deps: AgentDepsT = None,
241
258
  model_settings: ModelSettings | None = None,
242
259
  usage_limits: _usage.UsageLimits | None = None,
243
260
  usage: _usage.Usage | None = None,
244
- result_type: type[RunResultDataT] | None = None,
245
261
  infer_name: bool = True,
246
- ) -> result.RunResult[Any]:
262
+ ) -> AgentRunResult[Any]:
247
263
  """Run the agent with a user prompt in async mode.
248
264
 
265
+ This method builds an internal agent graph (using system prompts, tools and result schemas) and then
266
+ runs the graph to completion. The result of the run is returned.
267
+
249
268
  Example:
250
269
  ```python
251
270
  from pydantic_ai import Agent
@@ -253,15 +272,115 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
253
272
  agent = Agent('openai:gpt-4o')
254
273
 
255
274
  async def main():
256
- result = await agent.run('What is the capital of France?')
257
- print(result.data)
275
+ agent_run = await agent.run('What is the capital of France?')
276
+ print(agent_run.data)
258
277
  #> Paris
259
278
  ```
260
279
 
261
280
  Args:
281
+ user_prompt: User input to start/continue the conversation.
262
282
  result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
263
283
  result validators since result validators would expect an argument that matches the agent's result type.
284
+ message_history: History of the conversation so far.
285
+ model: Optional model to use for this run, required if `model` was not set when creating the agent.
286
+ deps: Optional dependencies to use for this run.
287
+ model_settings: Optional settings to use for this model's request.
288
+ usage_limits: Optional limits on model request count or token usage.
289
+ usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
290
+ infer_name: Whether to try to infer the agent name from the call frame if it's not set.
291
+
292
+ Returns:
293
+ The result of the run.
294
+ """
295
+ if infer_name and self.name is None:
296
+ self._infer_name(inspect.currentframe())
297
+ with self.iter(
298
+ user_prompt=user_prompt,
299
+ result_type=result_type,
300
+ message_history=message_history,
301
+ model=model,
302
+ deps=deps,
303
+ model_settings=model_settings,
304
+ usage_limits=usage_limits,
305
+ usage=usage,
306
+ ) as agent_run:
307
+ async for _ in agent_run:
308
+ pass
309
+
310
+ assert (final_result := agent_run.result) is not None, 'The graph run did not finish properly'
311
+ return final_result
312
+
313
+ @contextmanager
314
+ def iter(
315
+ self,
316
+ user_prompt: str,
317
+ *,
318
+ result_type: type[RunResultDataT] | None = None,
319
+ message_history: list[_messages.ModelMessage] | None = None,
320
+ model: models.Model | models.KnownModelName | None = None,
321
+ deps: AgentDepsT = None,
322
+ model_settings: ModelSettings | None = None,
323
+ usage_limits: _usage.UsageLimits | None = None,
324
+ usage: _usage.Usage | None = None,
325
+ infer_name: bool = True,
326
+ ) -> Iterator[AgentRun[AgentDepsT, Any]]:
327
+ """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
328
+
329
+ This method builds an internal agent graph (using system prompts, tools and result schemas) and then returns an
330
+ `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are
331
+ executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the
332
+ stream of events coming from the execution of tools.
333
+
334
+ The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics,
335
+ and the final result of the run once it has completed.
336
+
337
+ For more details, see the documentation of `AgentRun`.
338
+
339
+ Example:
340
+ ```python
341
+ from pydantic_ai import Agent
342
+
343
+ agent = Agent('openai:gpt-4o')
344
+
345
+ async def main():
346
+ nodes = []
347
+ with agent.iter('What is the capital of France?') as agent_run:
348
+ async for node in agent_run:
349
+ nodes.append(node)
350
+ print(nodes)
351
+ '''
352
+ [
353
+ ModelRequestNode(
354
+ request=ModelRequest(
355
+ parts=[
356
+ UserPromptPart(
357
+ content='What is the capital of France?',
358
+ timestamp=datetime.datetime(...),
359
+ part_kind='user-prompt',
360
+ )
361
+ ],
362
+ kind='request',
363
+ )
364
+ ),
365
+ HandleResponseNode(
366
+ model_response=ModelResponse(
367
+ parts=[TextPart(content='Paris', part_kind='text')],
368
+ model_name='function:model_logic',
369
+ timestamp=datetime.datetime(...),
370
+ kind='response',
371
+ )
372
+ ),
373
+ End(data=FinalResult(data='Paris', tool_name=None)),
374
+ ]
375
+ '''
376
+ print(agent_run.result.data)
377
+ #> Paris
378
+ ```
379
+
380
+ Args:
264
381
  user_prompt: User input to start/continue the conversation.
382
+ result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
383
+ result validators since result validators would expect an argument that matches the agent's result type.
265
384
  message_history: History of the conversation so far.
266
385
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
267
386
  deps: Optional dependencies to use for this run.
@@ -305,54 +424,44 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
305
424
  model_settings = merge_model_settings(self.model_settings, model_settings)
306
425
  usage_limits = usage_limits or _usage.UsageLimits()
307
426
 
308
- with _logfire.span(
427
+ # Build the deps object for the graph
428
+ run_span = _logfire.span(
309
429
  '{agent_name} run {prompt=}',
310
430
  prompt=user_prompt,
311
431
  agent=self,
312
432
  model_name=model_used.model_name if model_used else 'no-model',
313
433
  agent_name=self.name or 'agent',
314
- ) as run_span:
315
- # Build the deps object for the graph
316
- graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
317
- user_deps=deps,
318
- prompt=user_prompt,
319
- new_message_index=new_message_index,
320
- model=model_used,
321
- model_settings=model_settings,
322
- usage_limits=usage_limits,
323
- max_result_retries=self._max_result_retries,
324
- end_strategy=self.end_strategy,
325
- result_schema=result_schema,
326
- result_tools=self._result_schema.tool_defs() if self._result_schema else [],
327
- result_validators=result_validators,
328
- function_tools=self._function_tools,
329
- run_span=run_span,
330
- )
331
-
332
- start_node = _agent_graph.UserPromptNode[AgentDepsT](
333
- user_prompt=user_prompt,
334
- system_prompts=self._system_prompts,
335
- system_prompt_functions=self._system_prompt_functions,
336
- system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
337
- )
338
-
339
- # Actually run
340
- end_result, _ = await graph.run(
341
- start_node,
342
- state=state,
343
- deps=graph_deps,
344
- infer_name=False,
345
- )
346
-
347
- # Build final run result
348
- # We don't do any advanced checking if the data is actually from a final result or not
349
- return result.RunResult(
350
- state.message_history,
351
- new_message_index,
352
- end_result.data,
353
- end_result.tool_name,
354
- state.usage,
355
434
  )
435
+ graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
436
+ user_deps=deps,
437
+ prompt=user_prompt,
438
+ new_message_index=new_message_index,
439
+ model=model_used,
440
+ model_settings=model_settings,
441
+ usage_limits=usage_limits,
442
+ max_result_retries=self._max_result_retries,
443
+ end_strategy=self.end_strategy,
444
+ result_schema=result_schema,
445
+ result_tools=self._result_schema.tool_defs() if self._result_schema else [],
446
+ result_validators=result_validators,
447
+ function_tools=self._function_tools,
448
+ run_span=run_span,
449
+ )
450
+ start_node = _agent_graph.UserPromptNode[AgentDepsT](
451
+ user_prompt=user_prompt,
452
+ system_prompts=self._system_prompts,
453
+ system_prompt_functions=self._system_prompt_functions,
454
+ system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
455
+ )
456
+
457
+ with graph.iter(
458
+ start_node,
459
+ state=state,
460
+ deps=graph_deps,
461
+ infer_name=False,
462
+ span=run_span,
463
+ ) as graph_run:
464
+ yield AgentRun(graph_run)
356
465
 
357
466
  @overload
358
467
  def run_sync(
@@ -366,7 +475,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
366
475
  usage_limits: _usage.UsageLimits | None = None,
367
476
  usage: _usage.Usage | None = None,
368
477
  infer_name: bool = True,
369
- ) -> result.RunResult[ResultDataT]: ...
478
+ ) -> AgentRunResult[ResultDataT]: ...
370
479
 
371
480
  @overload
372
481
  def run_sync(
@@ -381,7 +490,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
381
490
  usage_limits: _usage.UsageLimits | None = None,
382
491
  usage: _usage.Usage | None = None,
383
492
  infer_name: bool = True,
384
- ) -> result.RunResult[RunResultDataT]: ...
493
+ ) -> AgentRunResult[RunResultDataT]: ...
385
494
 
386
495
  def run_sync(
387
496
  self,
@@ -395,8 +504,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
395
504
  usage_limits: _usage.UsageLimits | None = None,
396
505
  usage: _usage.Usage | None = None,
397
506
  infer_name: bool = True,
398
- ) -> result.RunResult[Any]:
399
- """Run the agent with a user prompt synchronously.
507
+ ) -> AgentRunResult[Any]:
508
+ """Synchronously run the agent with a user prompt.
400
509
 
401
510
  This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
402
511
  You therefore can't use this method inside async code or if there's an active event loop.
@@ -413,9 +522,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
413
522
  ```
414
523
 
415
524
  Args:
525
+ user_prompt: User input to start/continue the conversation.
416
526
  result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
417
527
  result validators since result validators would expect an argument that matches the agent's result type.
418
- user_prompt: User input to start/continue the conversation.
419
528
  message_history: History of the conversation so far.
420
529
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
421
530
  deps: Optional dependencies to use for this run.
@@ -474,7 +583,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
474
583
  ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ...
475
584
 
476
585
  @asynccontextmanager
477
- async def run_stream(
586
+ async def run_stream( # noqa C901
478
587
  self,
479
588
  user_prompt: str,
480
589
  *,
@@ -502,9 +611,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
502
611
  ```
503
612
 
504
613
  Args:
614
+ user_prompt: User input to start/continue the conversation.
505
615
  result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
506
616
  result validators since result validators would expect an argument that matches the agent's result type.
507
- user_prompt: User input to start/continue the conversation.
508
617
  message_history: History of the conversation so far.
509
618
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
510
619
  deps: Optional dependencies to use for this run.
@@ -516,94 +625,104 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
516
625
  Returns:
517
626
  The result of the run.
518
627
  """
628
+ # TODO: We need to deprecate this now that we have the `iter` method.
629
+ # Before that, though, we should add an event for when we reach the final result of the stream.
519
630
  if infer_name and self.name is None:
520
631
  # f_back because `asynccontextmanager` adds one frame
521
632
  if frame := inspect.currentframe(): # pragma: no branch
522
633
  self._infer_name(frame.f_back)
523
- model_used = self._get_model(model)
524
634
 
525
- deps = self._get_deps(deps)
526
- new_message_index = len(message_history) if message_history else 0
527
- result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type)
528
-
529
- # Build the graph
530
- graph = self._build_stream_graph(result_type)
531
-
532
- # Build the initial state
533
- graph_state = _agent_graph.GraphAgentState(
534
- message_history=message_history[:] if message_history else [],
535
- usage=usage or _usage.Usage(),
536
- retries=0,
537
- run_step=0,
538
- )
539
-
540
- # We consider it a user error if a user tries to restrict the result type while having a result validator that
541
- # may change the result type from the restricted type to something else. Therefore, we consider the following
542
- # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
543
- result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators)
544
-
545
- # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
546
- # runs. Requires some changes to `Tool` to make them copyable though.
547
- for v in self._function_tools.values():
548
- v.current_retry = 0
549
-
550
- model_settings = merge_model_settings(self.model_settings, model_settings)
551
- usage_limits = usage_limits or _usage.UsageLimits()
552
-
553
- with _logfire.span(
554
- '{agent_name} run stream {prompt=}',
555
- prompt=user_prompt,
556
- agent=self,
557
- model_name=model_used.model_name if model_used else 'no-model',
558
- agent_name=self.name or 'agent',
559
- ) as run_span:
560
- # Build the deps object for the graph
561
- graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
562
- user_deps=deps,
563
- prompt=user_prompt,
564
- new_message_index=new_message_index,
565
- model=model_used,
566
- model_settings=model_settings,
567
- usage_limits=usage_limits,
568
- max_result_retries=self._max_result_retries,
569
- end_strategy=self.end_strategy,
570
- result_schema=result_schema,
571
- result_tools=self._result_schema.tool_defs() if self._result_schema else [],
572
- result_validators=result_validators,
573
- function_tools=self._function_tools,
574
- run_span=run_span,
575
- )
576
-
577
- start_node = _agent_graph.StreamUserPromptNode[AgentDepsT](
578
- user_prompt=user_prompt,
579
- system_prompts=self._system_prompts,
580
- system_prompt_functions=self._system_prompt_functions,
581
- system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
582
- )
583
-
584
- # Actually run
585
- node = start_node
586
- history: list[HistoryStep[_agent_graph.GraphAgentState, RunResultDataT]] = []
635
+ yielded = False
636
+ with self.iter(
637
+ user_prompt,
638
+ result_type=result_type,
639
+ message_history=message_history,
640
+ model=model,
641
+ deps=deps,
642
+ model_settings=model_settings,
643
+ usage_limits=usage_limits,
644
+ usage=usage,
645
+ infer_name=False,
646
+ ) as agent_run:
647
+ first_node = agent_run.next_node # start with the first node
648
+ assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node
649
+ node: BaseNode[Any, Any, Any] = cast(BaseNode[Any, Any, Any], first_node)
587
650
  while True:
588
- if isinstance(node, _agent_graph.StreamModelRequestNode):
589
- node = cast(
590
- _agent_graph.StreamModelRequestNode[
591
- AgentDepsT, result.StreamedRunResult[AgentDepsT, RunResultDataT]
592
- ],
593
- node,
594
- )
595
- async with node.run_to_result(GraphRunContext(graph_state, graph_deps)) as r:
596
- if isinstance(r, End):
597
- yield r.data
651
+ if isinstance(node, _agent_graph.ModelRequestNode):
652
+ node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node)
653
+ graph_ctx = agent_run.ctx
654
+ async with node._stream(graph_ctx) as streamed_response: # pyright: ignore[reportPrivateUsage]
655
+
656
+ async def stream_to_final(
657
+ s: models.StreamedResponse,
658
+ ) -> FinalResult[models.StreamedResponse] | None:
659
+ result_schema = graph_ctx.deps.result_schema
660
+ async for maybe_part_event in streamed_response:
661
+ if isinstance(maybe_part_event, _messages.PartStartEvent):
662
+ new_part = maybe_part_event.part
663
+ if isinstance(new_part, _messages.TextPart):
664
+ if _agent_graph.allow_text_result(result_schema):
665
+ return FinalResult(s, None)
666
+ elif isinstance(new_part, _messages.ToolCallPart):
667
+ if result_schema is not None and (match := result_schema.find_tool([new_part])):
668
+ call, _ = match
669
+ return FinalResult(s, call.tool_name)
670
+ return None
671
+
672
+ final_result_details = await stream_to_final(streamed_response)
673
+ if final_result_details is not None:
674
+ if yielded:
675
+ raise exceptions.AgentRunError('Agent run produced final results')
676
+ yielded = True
677
+
678
+ messages = graph_ctx.state.message_history.copy()
679
+
680
+ async def on_complete() -> None:
681
+ """Called when the stream has completed.
682
+
683
+ The model response will have been added to messages by now
684
+ by `StreamedRunResult._marked_completed`.
685
+ """
686
+ last_message = messages[-1]
687
+ assert isinstance(last_message, _messages.ModelResponse)
688
+ tool_calls = [
689
+ part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
690
+ ]
691
+
692
+ parts: list[_messages.ModelRequestPart] = []
693
+ async for _event in _agent_graph.process_function_tools(
694
+ tool_calls,
695
+ final_result_details.tool_name,
696
+ graph_ctx,
697
+ parts,
698
+ ):
699
+ pass
700
+ # TODO: Should we do something here related to the retry count?
701
+ # Maybe we should move the incrementing of the retry count to where we actually make a request?
702
+ # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
703
+ # ctx.state.increment_retries(ctx.deps.max_result_retries)
704
+ if parts:
705
+ messages.append(_messages.ModelRequest(parts))
706
+
707
+ yield StreamedRunResult(
708
+ messages,
709
+ graph_ctx.deps.new_message_index,
710
+ graph_ctx.deps.usage_limits,
711
+ streamed_response,
712
+ graph_ctx.deps.result_schema,
713
+ _agent_graph.build_run_context(graph_ctx),
714
+ graph_ctx.deps.result_validators,
715
+ final_result_details.tool_name,
716
+ on_complete,
717
+ )
598
718
  break
599
- assert not isinstance(node, End) # the previous line should be hit first
600
- node = await graph.next(
601
- node,
602
- history,
603
- state=graph_state,
604
- deps=graph_deps,
605
- infer_name=False,
606
- )
719
+ next_node = await agent_run.next(node)
720
+ if not isinstance(next_node, BaseNode):
721
+ raise exceptions.AgentRunError('Should have produced a StreamedRunResult before getting here')
722
+ node = cast(BaseNode[Any, Any, Any], next_node)
723
+
724
+ if not yielded:
725
+ raise exceptions.AgentRunError('Agent run finished without producing a final result')
607
726
 
608
727
  @contextmanager
609
728
  def override(
@@ -1039,14 +1158,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1039
1158
 
1040
1159
  def _build_graph(
1041
1160
  self, result_type: type[RunResultDataT] | None
1042
- ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]:
1161
+ ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]]:
1043
1162
  return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type)
1044
1163
 
1045
- def _build_stream_graph(
1046
- self, result_type: type[RunResultDataT] | None
1047
- ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]:
1048
- return _agent_graph.build_agent_stream_graph(self.name, self._deps_type, result_type or self.result_type)
1049
-
1050
1164
  def _prepare_result_schema(
1051
1165
  self, result_type: type[RunResultDataT] | None
1052
1166
  ) -> _result.ResultSchema[RunResultDataT] | None:
@@ -1058,3 +1172,314 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1058
1172
  )
1059
1173
  else:
1060
1174
  return self._result_schema # pyright: ignore[reportReturnType]
1175
+
1176
+
1177
@dataclasses.dataclass(repr=False)
class AgentRun(Generic[AgentDepsT, ResultDataT]):
    """An in-progress run of an [`Agent`][pydantic_ai.agent.Agent] that can be async-iterated node by node.

    An `AgentRun` is normally obtained from `with my_agent.iter(...) as agent_run:`.

    Iterating over it yields the nodes of the run as they execute. After an
    [`End`][pydantic_graph.nodes.End] node has been reached the run is finished and
    [`result`][pydantic_ai.agent.AgentRun.result] is populated.

    Example:
    ```python
    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')

    async def main():
        nodes = []
        # Iterate through the run, recording each node along the way:
        with agent.iter('What is the capital of France?') as agent_run:
            async for node in agent_run:
                nodes.append(node)
        print(nodes)
        '''
        [
            ModelRequestNode(
                request=ModelRequest(
                    parts=[
                        UserPromptPart(
                            content='What is the capital of France?',
                            timestamp=datetime.datetime(...),
                            part_kind='user-prompt',
                        )
                    ],
                    kind='request',
                )
            ),
            HandleResponseNode(
                model_response=ModelResponse(
                    parts=[TextPart(content='Paris', part_kind='text')],
                    model_name='function:model_logic',
                    timestamp=datetime.datetime(...),
                    kind='response',
                )
            ),
            End(data=FinalResult(data='Paris', tool_name=None)),
        ]
        '''
        print(agent_run.result.data)
        #> Paris
    ```

    For finer control, the run can instead be driven one node at a time with
    [`next`][pydantic_ai.agent.AgentRun.next].
    """

    # The underlying graph run; every member of this class delegates to it.
    _graph_run: GraphRun[
        _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]
    ]

    @property
    def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]:
        """A graph run context built from the run's current state and deps.

        A new context object is constructed on every access.
        """
        graph_run = self._graph_run
        return GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]](
            graph_run.state, graph_run.deps
        )

    @property
    def next_node(
        self,
    ) -> (
        BaseNode[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]]
        | End[FinalResult[ResultDataT]]
    ):
        """The node the agent graph will run next.

        It is the node used by the next step of async iteration, or when no node is passed to `self.next(...)`.
        """
        return self._graph_run.next_node

    @property
    def result(self) -> AgentRunResult[ResultDataT] | None:
        """The final result of the run, or `None` while the run has not ended.

        After the run has returned an [`End`][pydantic_graph.nodes.End] node, this is an
        [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] built from the graph run's output and state.
        """
        finished = self._graph_run.result
        if finished is None:
            return None
        final = finished.output
        return AgentRunResult(
            final.data,
            final.tool_name,
            finished.state,
            self._graph_run.deps.new_message_index,
        )

    def __aiter__(
        self,
    ) -> AsyncIterator[
        BaseNode[
            _agent_graph.GraphAgentState,
            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
            FinalResult[ResultDataT],
        ]
        | End[FinalResult[ResultDataT]]
    ]:
        """Return `self`, so the run can be used directly in `async for`."""
        return self

    async def __anext__(
        self,
    ) -> (
        BaseNode[
            _agent_graph.GraphAgentState,
            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
            FinalResult[ResultDataT],
        ]
        | End[FinalResult[ResultDataT]]
    ):
        """Advance the underlying graph run by one step and return the node it produces."""
        upcoming = await self._graph_run.__anext__()
        return upcoming

    async def next(
        self,
        node: BaseNode[
            _agent_graph.GraphAgentState,
            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
            FinalResult[ResultDataT],
        ],
    ) -> (
        BaseNode[
            _agent_graph.GraphAgentState,
            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
            FinalResult[ResultDataT],
        ]
        | End[FinalResult[ResultDataT]]
    ):
        """Drive the run manually by supplying the node to execute next.

        Because the caller chooses the node, it can be inspected or modified before execution continues, and
        nodes can be skipped under dynamic conditions. Stop driving the run once an
        [`End`][pydantic_graph.nodes.End] node is returned.

        Example:
        ```python
        from pydantic_ai import Agent
        from pydantic_graph import End

        agent = Agent('openai:gpt-4o')

        async def main():
            with agent.iter('What is the capital of France?') as agent_run:
                next_node = agent_run.next_node  # start with the first node
                nodes = [next_node]
                while not isinstance(next_node, End):
                    next_node = await agent_run.next(next_node)
                    nodes.append(next_node)
                # Once `next_node` is an End, we've finished:
                print(nodes)
                '''
                [
                    UserPromptNode(
                        user_prompt='What is the capital of France?',
                        system_prompts=(),
                        system_prompt_functions=[],
                        system_prompt_dynamic_functions={},
                    ),
                    ModelRequestNode(
                        request=ModelRequest(
                            parts=[
                                UserPromptPart(
                                    content='What is the capital of France?',
                                    timestamp=datetime.datetime(...),
                                    part_kind='user-prompt',
                                )
                            ],
                            kind='request',
                        )
                    ),
                    HandleResponseNode(
                        model_response=ModelResponse(
                            parts=[TextPart(content='Paris', part_kind='text')],
                            model_name='function:model_logic',
                            timestamp=datetime.datetime(...),
                            kind='response',
                        )
                    ),
                    End(data=FinalResult(data='Paris', tool_name=None)),
                ]
                '''
                print('Final result:', agent_run.result.data)
                #> Final result: Paris
        ```

        Args:
            node: The node to run next in the graph.

        Returns:
            The node produced by running `node`, or an [`End`][pydantic_graph.nodes.End] node once the run
            has completed.
        """
        # A synchronous iteration interface is deliberately not offered on this class: if it existed, IDEs
        # would no longer warn about accidentally writing `for` instead of `async for`.
        following = await self._graph_run.next(node)
        return following

    def usage(self) -> _usage.Usage:
        """Return the usage recorded on the run's state so far (token usage, model requests, and so on)."""
        return self._graph_run.state.usage

    def __repr__(self) -> str:
        graph_result = self._graph_run.result
        if graph_result is None:
            shown = '<run not finished>'
        else:
            shown = repr(graph_result.output)
        return f'<{type(self).__name__} result={shown} usage={self.usage()}>'
1391
+
1392
+
1393
@dataclasses.dataclass
class AgentRunResult(Generic[ResultDataT]):
    """The outcome of a completed agent run: the result data plus access to messages and usage."""

    data: ResultDataT  # TODO: rename this to output. I'm putting this off for now mostly to reduce the size of the diff

    # Name of the result tool that produced `data`; falsy when no result tool was involved.
    _result_tool_name: str | None = dataclasses.field(repr=False)
    # Graph state of the run; supplies the message history and the usage.
    _state: _agent_graph.GraphAgentState = dataclasses.field(repr=False)
    # Index into the message history at which this run's own messages start.
    _new_message_index: int = dataclasses.field(repr=False)

    def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
        """Return a deep copy of the history with the result tool's return content replaced.

        Only the last message is searched, and only the first `ToolReturnPart` whose tool name matches the
        result tool is changed. The stored history is left untouched.

        Raises:
            ValueError: If this result has no result tool name.
            LookupError: If the last message has no matching tool return part.
        """
        if not self._result_tool_name:
            raise ValueError('Cannot set result tool return content when the return type is `str`.')
        history = deepcopy(self._state.message_history)
        matching_part = next(
            (
                part
                for part in history[-1].parts
                if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name
            ),
            None,
        )
        if matching_part is None:
            raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.')
        matching_part.content = return_content
        return history

    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return the full message history.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            List of messages. With `result_tool_return_content=None` this is the list held by the run state
            itself (not a copy); otherwise it is a modified deep copy.
        """
        if result_tool_return_content is None:
            return self._state.message_history
        return self._set_result_tool_return(result_tool_return_content)

    def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
        """Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            JSON bytes representing the messages.
        """
        history = self.all_messages(result_tool_return_content=result_tool_return_content)
        return _messages.ModelMessagesTypeAdapter.dump_json(history)

    def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return only the messages that belong to this run.

        Messages from older runs are excluded.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            List of new messages.
        """
        history = self.all_messages(result_tool_return_content=result_tool_return_content)
        return history[self._new_message_index :]

    def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
        """Return new messages from [`new_messages`][pydantic_ai.agent.AgentRunResult.new_messages] as JSON bytes.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            JSON bytes representing the new messages.
        """
        recent = self.new_messages(result_tool_return_content=result_tool_return_content)
        return _messages.ModelMessagesTypeAdapter.dump_json(recent)

    def usage(self) -> _usage.Usage:
        """Return the usage of the whole run."""
        return self._state.usage