pydantic-graph 1.11.0__tar.gz → 1.11.1__tar.gz

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (29)
  1. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/PKG-INFO +1 -1
  2. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/graph.py +178 -166
  3. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/.gitignore +0 -0
  4. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/LICENSE +0 -0
  5. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/README.md +0 -0
  6. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/__init__.py +0 -0
  7. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/_utils.py +0 -0
  8. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/__init__.py +0 -0
  9. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/decision.py +0 -0
  10. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/graph_builder.py +0 -0
  11. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/id_types.py +0 -0
  12. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/join.py +0 -0
  13. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/mermaid.py +0 -0
  14. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/node.py +0 -0
  15. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/node_types.py +0 -0
  16. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/parent_forks.py +0 -0
  17. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/paths.py +0 -0
  18. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/step.py +0 -0
  19. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/util.py +0 -0
  20. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/exceptions.py +0 -0
  21. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/graph.py +0 -0
  22. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/mermaid.py +0 -0
  23. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/nodes.py +0 -0
  24. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/persistence/__init__.py +0 -0
  25. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/persistence/_utils.py +0 -0
  26. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/persistence/file.py +0 -0
  27. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/persistence/in_mem.py +0 -0
  28. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pydantic_graph/py.typed +0 -0
  29. {pydantic_graph-1.11.0 → pydantic_graph-1.11.1}/pyproject.toml +0 -0
--- pydantic_graph-1.11.0/PKG-INFO
+++ pydantic_graph-1.11.1/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-graph
-Version: 1.11.0
+Version: 1.11.1
 Summary: Graph and state machine library
 Project-URL: Homepage, https://ai.pydantic.dev/graph/tree/main/pydantic_graph
 Project-URL: Source, https://github.com/pydantic/pydantic-ai
--- pydantic_graph-1.11.0/pydantic_graph/beta/graph.py
+++ pydantic_graph-1.11.1/pydantic_graph/beta/graph.py
@@ -10,7 +10,7 @@ from __future__ import annotations as _annotations
 import sys
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
-from contextlib import AbstractContextManager, ExitStack, asynccontextmanager, contextmanager
+from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
 
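The only import change is the addition of `AsyncExitStack`, which the `GraphRun` hunk below uses to hold both synchronous and asynchronous context managers on a single stack and unwind them in LIFO order. A minimal standalone sketch of that behavior (stdlib only; not pydantic-graph code):

    import asyncio
    from contextlib import AsyncExitStack, asynccontextmanager, contextmanager

    @contextmanager
    def sync_resource():  # a plain synchronous context manager
        print('sync enter')
        try:
            yield
        finally:
            print('sync exit')  # runs last: it was entered first

    @asynccontextmanager
    async def async_resource():  # an asynchronous context manager
        print('async enter')
        try:
            yield
        finally:
            print('async exit')  # runs first: it was entered last

    async def main():
        async with AsyncExitStack() as stack:
            stack.enter_context(sync_resource())  # sync CMs go on the same stack
            await stack.enter_async_context(async_resource())
            print('body')

    asyncio.run(main())  # sync enter, async enter, body, async exit, sync exit

That LIFO unwinding is what the hunks below rely on: whatever is entered last is torn down first.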
@@ -381,18 +381,32 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
         run_id = GraphRunID(str(uuid.uuid4()))
         initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
         self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
-        self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](self.graph, self.state, self.deps)
+        self._iterator_task_group = create_task_group()
+        self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
+            self.graph, self.state, self.deps, self._iterator_task_group
+        )
         self._iterator = self._iterator_instance.iter_graph(self._first_task)
 
         self.__traceparent = traceparent
+        self._async_exit_stack = AsyncExitStack()
 
     async def __aenter__(self):
+        self._async_exit_stack.enter_context(_unwrap_exception_groups())
+        await self._async_exit_stack.enter_async_context(self._iterator_task_group)
+        await self._async_exit_stack.enter_async_context(self._iterator_context())
         return self
 
     async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
-        self._iterator_instance.iter_stream_sender.close()
-        self._iterator_instance.iter_stream_receiver.close()
-        await self._iterator.aclose()
+        await self._async_exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+
+    @asynccontextmanager
+    async def _iterator_context(self):
+        try:
+            yield
+        finally:
+            self._iterator_instance.iter_stream_sender.close()
+            self._iterator_instance.iter_stream_receiver.close()
+            await self._iterator.aclose()
 
     @overload
     def _traceparent(self, *, required: Literal[False]) -> str | None: ...
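This hunk replaces the hand-written cleanup in `__aexit__` with an `AsyncExitStack` created in `__init__` and populated in `__aenter__`: the exception-group unwrapper is entered first, then the task group, then an `_iterator_context()` whose `finally` block closes the memory streams and the iterator. Because the stack unwinds in reverse, the iterator is finalized while the task group is still alive, and `_unwrap_exception_groups()` exits last, wrapping everything. A toy class showing the same delegation pattern (a sketch of the pattern only, not the actual `GraphRun`):

    import asyncio
    from contextlib import AsyncExitStack, asynccontextmanager

    class Scope:
        def __init__(self) -> None:
            self._stack = AsyncExitStack()

        @asynccontextmanager
        async def _cleanup(self):
            try:
                yield
            finally:
                print('cleanup (entered last, exits first)')

        async def __aenter__(self):
            await self._stack.enter_async_context(self._cleanup())
            return self

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            # Delegate teardown: the stack replays every registered exit in
            # LIFO order and passes exc_* into each, as contextlib specifies.
            return await self._stack.__aexit__(exc_type, exc_val, exc_tb)

    async def demo():
        async with Scope():
            print('body')

    asyncio.run(demo())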
@@ -495,13 +509,13 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
     graph: Graph[StateT, DepsT, Any, OutputT]
     state: StateT
     deps: DepsT
+    task_group: TaskGroup
 
     cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
     active_tasks: dict[TaskID, GraphTask] = field(init=False)
     active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = field(init=False)
     iter_stream_sender: MemoryObjectSendStream[_GraphTaskResult] = field(init=False)
     iter_stream_receiver: MemoryObjectReceiveStream[_GraphTaskResult] = field(init=False)
-    _task_group: TaskGroup | None = field(init=False)
 
     def __post_init__(self):
         self.cancel_scopes = {}
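Here `_GraphIterator` gains `task_group` as a required constructor field and drops the optional `_task_group` placeholder, which is what makes the runtime-guard property deleted in the next hunk unnecessary: the task group is guaranteed to exist by construction. A minimal sketch of the injected-task-group shape, assuming anyio (the `TaskGroup`/`CancelScope`/memory-stream names in this module suggest anyio; `Worker` is a hypothetical stand-in):

    from dataclasses import dataclass

    import anyio
    from anyio.abc import TaskGroup

    @dataclass
    class Worker:
        task_group: TaskGroup  # injected by the owner; no Optional, no guard property

    async def main() -> None:
        tg = anyio.create_task_group()
        async with tg:  # the owner, not Worker, controls the task group's lifetime
            worker = Worker(task_group=tg)
            worker.task_group.start_soon(anyio.sleep, 0)

    anyio.run(main)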
@@ -509,178 +523,176 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
         self.active_reducers = {}
         self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
 
-    @property
-    def task_group(self) -> TaskGroup:
-        if self._task_group is None:
-            raise RuntimeError("This graph iterator hasn't been started")  # pragma: no cover
-        return self._task_group
-
     async def iter_graph(  # noqa C901
         self, first_task: GraphTask
     ) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]:
-        with _unwrap_exception_groups():
-            async with self.iter_stream_sender, create_task_group() as self._task_group:
-                try:
-                    # Fire off the first task
-                    self.active_tasks[first_task.task_id] = first_task
-                    self._handle_execution_request([first_task])
-
-                    # Handle task results
-                    async with self.iter_stream_receiver:
-                        while self.active_tasks or self.active_reducers:
-                            async for task_result in self.iter_stream_receiver:  # pragma: no branch
-                                if isinstance(task_result.result, JoinItem):
-                                    maybe_overridden_result = task_result.result
-                                else:
-                                    maybe_overridden_result = yield task_result.result
-                                if isinstance(maybe_overridden_result, EndMarker):
+        async with self.iter_stream_sender:
+            try:
+                # Fire off the first task
+                self.active_tasks[first_task.task_id] = first_task
+                self._handle_execution_request([first_task])
+
+                # Handle task results
+                async with self.iter_stream_receiver:
+                    while self.active_tasks or self.active_reducers:
+                        async for task_result in self.iter_stream_receiver:  # pragma: no branch
+                            if isinstance(task_result.result, JoinItem):
+                                maybe_overridden_result = task_result.result
+                            else:
+                                maybe_overridden_result = yield task_result.result
+                            if isinstance(maybe_overridden_result, EndMarker):
+                                # If we got an end marker, this task is definitely done, and we're ready to
+                                # start cleaning everything up
+                                await self._finish_task(task_result.source.task_id)
+                                if self.active_tasks:
+                                    # Cancel the remaining tasks
                                     self.task_group.cancel_scope.cancel()
-                                    return
-                                elif isinstance(maybe_overridden_result, JoinItem):
-                                    result = maybe_overridden_result
-                                    parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
-                                    for i, x in enumerate(result.fork_stack[::-1]):
-                                        if x.fork_id == parent_fork_id:
-                                            # For non-final joins (those that are intermediate nodes of other joins),
-                                            # preserve the fork stack so downstream joins can still associate with the same fork run
-                                            if self.graph.is_final_join(result.join_id):
-                                                # Final join: remove the parent fork from the stack
-                                                downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
-                                            else:
-                                                # Non-final join: preserve the fork stack
-                                                downstream_fork_stack = result.fork_stack
-                                            fork_run_id = x.node_run_id
-                                            break
-                                    else:  # pragma: no cover
-                                        raise RuntimeError('Parent fork run not found')
-
-                                    join_node = self.graph.nodes[result.join_id]
-                                    assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
-                                    join_state = self.active_reducers.get((result.join_id, fork_run_id))
-                                    if join_state is None:
-                                        current = join_node.initial_factory()
-                                        join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
-                                            current, downstream_fork_stack
-                                        )
-                                    context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
-                                    join_state.current = join_node.reduce(context, join_state.current, result.inputs)
-                                    if join_state.cancelled_sibling_tasks:
-                                        await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
-                                else:
-                                    for new_task in maybe_overridden_result:
-                                        self.active_tasks[new_task.task_id] = new_task
-
-                                tasks_by_id_values = list(self.active_tasks.values())
-                                join_tasks: list[GraphTask] = []
-
-                                for join_id, fork_run_id in self._get_completed_fork_runs(
-                                    task_result.source, tasks_by_id_values
-                                ):
-                                    join_state = self.active_reducers.pop((join_id, fork_run_id))
-                                    join_node = self.graph.nodes[join_id]
-                                    assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
-                                    new_tasks = self._handle_non_fork_edges(
-                                        join_node, join_state.current, join_state.downstream_fork_stack
+                                return
+                            elif isinstance(maybe_overridden_result, JoinItem):
+                                result = maybe_overridden_result
+                                parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
+                                for i, x in enumerate(result.fork_stack[::-1]):
+                                    if x.fork_id == parent_fork_id:
+                                        # For non-final joins (those that are intermediate nodes of other joins),
+                                        # preserve the fork stack so downstream joins can still associate with the same fork run
+                                        if self.graph.is_final_join(result.join_id):
+                                            # Final join: remove the parent fork from the stack
+                                            downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
+                                        else:
+                                            # Non-final join: preserve the fork stack
+                                            downstream_fork_stack = result.fork_stack
+                                        fork_run_id = x.node_run_id
+                                        break
+                                else:  # pragma: no cover
+                                    raise RuntimeError('Parent fork run not found')
+
+                                join_node = self.graph.nodes[result.join_id]
+                                assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                join_state = self.active_reducers.get((result.join_id, fork_run_id))
+                                if join_state is None:
+                                    current = join_node.initial_factory()
+                                    join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
+                                        current, downstream_fork_stack
                                     )
-                                    join_tasks.extend(new_tasks)
-                                if join_tasks:
-                                    for new_task in join_tasks:
-                                        self.active_tasks[new_task.task_id] = new_task
-                                    self._handle_execution_request(join_tasks)
-
-                                if isinstance(maybe_overridden_result, Sequence):
-                                    if isinstance(task_result.result, Sequence):
-                                        new_task_ids = {t.task_id for t in maybe_overridden_result}
-                                        for t in task_result.result:
-                                            if t.task_id not in new_task_ids:
-                                                await self._finish_task(t.task_id, t.node_id)
-                                    self._handle_execution_request(maybe_overridden_result)
-
-                                if task_result.source_is_finished:
-                                    await self._finish_task(task_result.source.task_id, task_result.source.node_id)
-
-                                if not self.active_tasks:
-                                    # if there are no active tasks, we'll be waiting forever for the next result..
-                                    break
-
-                            if self.active_reducers:  # pragma: no branch
-                                # In this case, there are no pending tasks. We can therefore finalize all active reducers
-                                # that don't have intermediate joins which are also active reducers. If a join J2 has an
-                                # intermediate join J1 that shares the same parent fork run, we must finalize J1 first
-                                # because it might produce items that feed into J2.
-                                for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
-                                    # Check if this join has any intermediate joins that are also active reducers
-                                    should_skip = False
-                                    intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
-
-                                    # Get the parent fork for this join to use for comparison
-                                    join_parent_fork = self.graph.get_parent_fork(join_id)
-
-                                    for intermediate_join_id in intermediate_joins:
-                                        # Check if the intermediate join is also an active reducer with matching fork run
-                                        for (other_join_id, _), other_join_state in self.active_reducers.items():
-                                            if other_join_id == intermediate_join_id:
-                                                # Check if they share the same fork run for this join's parent fork
-                                                # by finding the parent fork's node_run_id in both fork stacks
-                                                join_parent_fork_run_id = None
-                                                other_parent_fork_run_id = None
-
-                                                for fsi in join_state.downstream_fork_stack:  # pragma: no branch
-                                                    if fsi.fork_id == join_parent_fork.fork_id:
-                                                        join_parent_fork_run_id = fsi.node_run_id
-                                                        break
-
-                                                for fsi in other_join_state.downstream_fork_stack:  # pragma: no branch
-                                                    if fsi.fork_id == join_parent_fork.fork_id:
-                                                        other_parent_fork_run_id = fsi.node_run_id
-                                                        break
-
-                                                if (
-                                                    join_parent_fork_run_id
-                                                    and other_parent_fork_run_id
-                                                    and join_parent_fork_run_id == other_parent_fork_run_id
-                                                ):  # pragma: no branch
-                                                    should_skip = True
+                                context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
+                                join_state.current = join_node.reduce(context, join_state.current, result.inputs)
+                                if join_state.cancelled_sibling_tasks:
+                                    await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
+                            else:
+                                for new_task in maybe_overridden_result:
+                                    self.active_tasks[new_task.task_id] = new_task
+
+                            tasks_by_id_values = list(self.active_tasks.values())
+                            join_tasks: list[GraphTask] = []
+
+                            for join_id, fork_run_id in self._get_completed_fork_runs(
+                                task_result.source, tasks_by_id_values
+                            ):
+                                join_state = self.active_reducers.pop((join_id, fork_run_id))
+                                join_node = self.graph.nodes[join_id]
+                                assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                new_tasks = self._handle_non_fork_edges(
+                                    join_node, join_state.current, join_state.downstream_fork_stack
+                                )
+                                join_tasks.extend(new_tasks)
+                            if join_tasks:
+                                for new_task in join_tasks:
+                                    self.active_tasks[new_task.task_id] = new_task
+                                self._handle_execution_request(join_tasks)
+
+                            if isinstance(maybe_overridden_result, Sequence):
+                                if isinstance(task_result.result, Sequence):
+                                    new_task_ids = {t.task_id for t in maybe_overridden_result}
+                                    for t in task_result.result:
+                                        if t.task_id not in new_task_ids:
+                                            await self._finish_task(t.task_id)
+                                self._handle_execution_request(maybe_overridden_result)
+
+                            if task_result.source_is_finished:
+                                await self._finish_task(task_result.source.task_id)
+
+                            if not self.active_tasks:
+                                # if there are no active tasks, we'll be waiting forever for the next result..
+                                break
+
+                        if self.active_reducers:  # pragma: no branch
+                            # In this case, there are no pending tasks. We can therefore finalize all active reducers
+                            # that don't have intermediate joins which are also active reducers. If a join J2 has an
+                            # intermediate join J1 that shares the same parent fork run, we must finalize J1 first
+                            # because it might produce items that feed into J2.
+                            for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
+                                # Check if this join has any intermediate joins that are also active reducers
+                                should_skip = False
+                                intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
+
+                                # Get the parent fork for this join to use for comparison
+                                join_parent_fork = self.graph.get_parent_fork(join_id)
+
+                                for intermediate_join_id in intermediate_joins:
+                                    # Check if the intermediate join is also an active reducer with matching fork run
+                                    for (other_join_id, _), other_join_state in self.active_reducers.items():
+                                        if other_join_id == intermediate_join_id:
+                                            # Check if they share the same fork run for this join's parent fork
+                                            # by finding the parent fork's node_run_id in both fork stacks
+                                            join_parent_fork_run_id = None
+                                            other_parent_fork_run_id = None
+
+                                            for fsi in join_state.downstream_fork_stack:  # pragma: no branch
+                                                if fsi.fork_id == join_parent_fork.fork_id:
+                                                    join_parent_fork_run_id = fsi.node_run_id
+                                                    break
+
+                                            for fsi in other_join_state.downstream_fork_stack:  # pragma: no branch
+                                                if fsi.fork_id == join_parent_fork.fork_id:
+                                                    other_parent_fork_run_id = fsi.node_run_id
                                                     break
-                                        if should_skip:
-                                            break
 
+                                            if (
+                                                join_parent_fork_run_id
+                                                and other_parent_fork_run_id
+                                                and join_parent_fork_run_id == other_parent_fork_run_id
+                                            ):  # pragma: no branch
+                                                should_skip = True
                                     if should_skip:
-                                        continue
-
-                                    self.active_reducers.pop(
-                                        (join_id, fork_run_id)
-                                    )  # we're handling it now, so we can pop it
-                                    join_node = self.graph.nodes[join_id]
-                                    assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
-                                    new_tasks = self._handle_non_fork_edges(
-                                        join_node, join_state.current, join_state.downstream_fork_stack
-                                    )
-                                    maybe_overridden_result = yield new_tasks
-                                    if isinstance(maybe_overridden_result, EndMarker):  # pragma: no cover
-                                        # This is theoretically reachable but it would be awkward.
-                                        # Probably a better way to get coverage here would be to unify the code pat
-                                        # with the other `if isinstance(maybe_overridden_result, EndMarker):`
-                                        self.task_group.cancel_scope.cancel()
-                                        return
-                                    for new_task in maybe_overridden_result:
-                                        self.active_tasks[new_task.task_id] = new_task
-                                    new_task_ids = {t.task_id for t in maybe_overridden_result}
-                                    for t in new_tasks:
-                                        # Same note as above about how this is theoretically reachable but we should
-                                        # just get coverage by unifying the code paths
-                                        if t.task_id not in new_task_ids:  # pragma: no cover
-                                            await self._finish_task(t.task_id, t.node_id)
-                                    self._handle_execution_request(maybe_overridden_result)
-                except GeneratorExit:
-                    self._task_group.cancel_scope.cancel()
-                    return
+                                        break
+
+                                if should_skip:
+                                    continue
+
+                                self.active_reducers.pop(
+                                    (join_id, fork_run_id)
+                                )  # we're handling it now, so we can pop it
+                                join_node = self.graph.nodes[join_id]
+                                assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                new_tasks = self._handle_non_fork_edges(
+                                    join_node, join_state.current, join_state.downstream_fork_stack
+                                )
+                                maybe_overridden_result = yield new_tasks
+                                if isinstance(maybe_overridden_result, EndMarker):  # pragma: no cover
+                                    # This is theoretically reachable but it would be awkward.
+                                    # Probably a better way to get coverage here would be to unify the code pat
+                                    # with the other `if isinstance(maybe_overridden_result, EndMarker):`
+                                    self.task_group.cancel_scope.cancel()
+                                    return
+                                for new_task in maybe_overridden_result:
+                                    self.active_tasks[new_task.task_id] = new_task
+                                new_task_ids = {t.task_id for t in maybe_overridden_result}
+                                for t in new_tasks:
+                                    # Same note as above about how this is theoretically reachable but we should
+                                    # just get coverage by unifying the code paths
+                                    if t.task_id not in new_task_ids:  # pragma: no cover
+                                        await self._finish_task(t.task_id)
+                                self._handle_execution_request(maybe_overridden_result)
+            except GeneratorExit:
+                self.task_group.cancel_scope.cancel()
+                return
 
         raise RuntimeError(  # pragma: no cover
             'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
         )
 
-    async def _finish_task(self, task_id: TaskID, node_id: str) -> None:
+    async def _finish_task(self, task_id: TaskID) -> None:
         # node_id is just included for debugging right now
         scope = self.cancel_scopes.pop(task_id, None)
         if scope is not None:
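Most of this hunk is a dedent: `iter_graph` no longer wraps its body in `with _unwrap_exception_groups()` and `async with ... create_task_group() as self._task_group`, since both contexts are now entered by `GraphRun.__aenter__`. The visible consequence is that the task group's cancel scope is no longer opened and closed inside an async generator, where finalization (`GeneratorExit` handling, `aclose()`) can run at an awkward point; the `except GeneratorExit` branch now only cancels the externally owned scope. A reduced before/after sketch of that shape, assuming anyio semantics (illustrative only, not library code):

    import anyio
    from anyio.abc import TaskGroup

    # Old shape: the generator owns the task group, so its cancel scope must
    # be unwound during async-generator finalization.
    async def iter_before():
        async with anyio.create_task_group() as tg:
            tg.start_soon(anyio.sleep, 0)
            yield 'item'

    # New shape: the caller owns the task group; the generator only uses it,
    # so aclose() never has to exit a cancel scope the generator opened itself.
    async def iter_after(tg: TaskGroup):
        tg.start_soon(anyio.sleep, 0)
        yield 'item'

    async def main() -> None:
        tg = anyio.create_task_group()
        async with tg:  # lifetime controlled here, outside the generator
            agen = iter_after(tg)
            async for item in agen:
                print(item)
            await agen.aclose()

    anyio.run(main)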
@@ -898,7 +910,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
         else:
             pass
         for task_id in task_ids_to_cancel:
-            await self._finish_task(task_id, 'sibling')
+            await self._finish_task(task_id)
 
 
 def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]: