pydantic_graph-1.9.1-py3-none-any.whl → pydantic_graph-1.11.0-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

pydantic_graph/beta/graph.py

@@ -148,6 +148,12 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
     parent_forks: dict[JoinID, ParentFork[NodeID]]
     """Parent fork information for each join node."""
 
+    intermediate_join_nodes: dict[JoinID, set[JoinID]]
+    """For each join, the set of other joins that appear between it and its parent fork.
+
+    Used to determine which joins are "final" (have no other joins as intermediates) and
+    which joins should preserve fork stacks when proceeding downstream."""
+
     def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]:
         """Get the parent fork information for a join node.
 
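
To make the new field concrete, here is a minimal sketch of the shape this mapping takes. The node IDs (F, J1, J2, branch_a, branch_b) are hypothetical illustrations, not identifiers from the package, and plain strings stand in for JoinID:

# Hypothetical illustration of the new field's shape, not pydantic-graph API.
# A fork F fans out to branches that meet at join J1, whose output feeds a
# second join J2:
#
#   F --> branch_a --> J1 --> J2
#   F --> branch_b --> J1
#
# J1 sits on the paths between J2 and their shared parent fork F, so:
intermediate_join_nodes = {
    'J1': set(),   # no other join lies between J1 and its parent fork
    'J2': {'J1'},  # J1 is an intermediate join for J2
}
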
@@ -165,6 +171,24 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
             raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)')
         return result
 
+    def is_final_join(self, join_id: JoinID) -> bool:
+        """Check if a join is 'final' (has no downstream joins with the same parent fork).
+
+        A join is non-final if it appears as an intermediate node for another join
+        with the same parent fork.
+
+        Args:
+            join_id: The ID of the join node
+
+        Returns:
+            True if the join is final, False if it's non-final
+        """
+        # Check if this join appears in any other join's intermediate_join_nodes
+        for intermediate_joins in self.intermediate_join_nodes.values():
+            if join_id in intermediate_joins:
+                return False
+        return True
+
     async def run(
         self,
         *,
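
Note that the scan is slightly broader than the docstring's "with the same parent fork" wording: a join is treated as non-final if it appears in any join's intermediate set, regardless of parent fork (see the "regardless of whether it has the same parent fork" comment in _compute_intermediate_join_nodes below); the same-fork-run comparison is deferred to reducer finalization. A standalone sketch of the scan, reusing the hypothetical mapping from the previous example:

# Standalone sketch of the same scan Graph.is_final_join performs,
# with plain strings standing in for JoinID (hypothetical data).
def is_final_join(intermediate_join_nodes: dict[str, set[str]], join_id: str) -> bool:
    # A join is non-final as soon as any join lists it as an intermediate.
    return all(join_id not in joins for joins in intermediate_join_nodes.values())

mapping: dict[str, set[str]] = {'J1': set(), 'J2': {'J1'}}
assert is_final_join(mapping, 'J1') is False  # J1 feeds J2, so it is non-final
assert is_final_join(mapping, 'J2') is True   # no join lists J2 as an intermediate
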
@@ -517,7 +541,14 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
                parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
                for i, x in enumerate(result.fork_stack[::-1]):
                    if x.fork_id == parent_fork_id:
-                        downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
+                        # For non-final joins (those that are intermediate nodes of other joins),
+                        # preserve the fork stack so downstream joins can still associate with the same fork run
+                        if self.graph.is_final_join(result.join_id):
+                            # Final join: remove the parent fork from the stack
+                            downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
+                        else:
+                            # Non-final join: preserve the fork stack
+                            downstream_fork_stack = result.fork_stack
                        fork_run_id = x.node_run_id
                        break
                else:  # pragma: no cover
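
The effect of the two branches can be traced with plain data. A sketch of the slice above, with hypothetical (fork_id, node_run_id) tuples standing in for the real fork-stack items (which are objects carrying fork_id and node_run_id attributes):

# Sketch of the fork-stack handling above (hypothetical fork IDs and run IDs).
fork_stack = [('F_outer', 1), ('F_parent', 7), ('F_inner', 9)]
parent_fork_id = 'F_parent'

for i, (fork_id, _node_run_id) in enumerate(fork_stack[::-1]):
    if fork_id == parent_fork_id:
        # Final join: truncate everything above the parent fork entry.
        final_stack = fork_stack[: len(fork_stack) - i]
        # Non-final join: keep the whole stack, so a downstream join can
        # still locate the same parent fork run.
        non_final_stack = fork_stack
        break

assert final_stack == [('F_outer', 1), ('F_parent', 7)]
assert non_final_stack == fork_stack
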
@@ -535,13 +566,9 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
                join_state.current = join_node.reduce(context, join_state.current, result.inputs)
                if join_state.cancelled_sibling_tasks:
                    await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
-                if task_result.source_is_finished:  # pragma: no branch
-                    await self._finish_task(task_result.source.task_id)
            else:
                for new_task in maybe_overridden_result:
                    self.active_tasks[new_task.task_id] = new_task
-                if task_result.source_is_finished:
-                    await self._finish_task(task_result.source.task_id)
 
            tasks_by_id_values = list(self.active_tasks.values())
            join_tasks: list[GraphTask] = []
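
(The two source_is_finished blocks deleted here are not lost: the next hunk hoists the same check out of both branches, so the source task is finished exactly once after either path, using the new two-argument _finish_task.)
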
@@ -566,28 +593,61 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
                new_task_ids = {t.task_id for t in maybe_overridden_result}
                for t in task_result.result:
                    if t.task_id not in new_task_ids:
-                        await self._finish_task(t.task_id)
+                        await self._finish_task(t.task_id, t.node_id)
                self._handle_execution_request(maybe_overridden_result)
 
+            if task_result.source_is_finished:
+                await self._finish_task(task_result.source.task_id, task_result.source.node_id)
+
            if not self.active_tasks:
                # if there are no active tasks, we'll be waiting forever for the next result..
                break
 
            if self.active_reducers:  # pragma: no branch
-                # In this case, there are no pending tasks. We can therefore finalize all active reducers whose
-                # downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the
-                # deeper reducer could produce new tasks in the "prefix" reducer.)
-                active_fork_stacks = [
-                    join_state.downstream_fork_stack for join_state in self.active_reducers.values()
-                ]
+                # In this case, there are no pending tasks. We can therefore finalize all active reducers
+                # that don't have intermediate joins which are also active reducers. If a join J2 has an
+                # intermediate join J1 that shares the same parent fork run, we must finalize J1 first
+                # because it might produce items that feed into J2.
                for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
-                    fork_stack = join_state.downstream_fork_stack
-                    if any(
-                        len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)]
-                        for afs in active_fork_stacks
-                    ):
-                        # this join_state is a strict prefix for one of the other active join_states
-                        continue  # pragma: no cover # It's difficult to cover this
+                    # Check if this join has any intermediate joins that are also active reducers
+                    should_skip = False
+                    intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
+
+                    # Get the parent fork for this join to use for comparison
+                    join_parent_fork = self.graph.get_parent_fork(join_id)
+
+                    for intermediate_join_id in intermediate_joins:
+                        # Check if the intermediate join is also an active reducer with matching fork run
+                        for (other_join_id, _), other_join_state in self.active_reducers.items():
+                            if other_join_id == intermediate_join_id:
+                                # Check if they share the same fork run for this join's parent fork
+                                # by finding the parent fork's node_run_id in both fork stacks
+                                join_parent_fork_run_id = None
+                                other_parent_fork_run_id = None
+
+                                for fsi in join_state.downstream_fork_stack:  # pragma: no branch
+                                    if fsi.fork_id == join_parent_fork.fork_id:
+                                        join_parent_fork_run_id = fsi.node_run_id
+                                        break
+
+                                for fsi in other_join_state.downstream_fork_stack:  # pragma: no branch
+                                    if fsi.fork_id == join_parent_fork.fork_id:
+                                        other_parent_fork_run_id = fsi.node_run_id
+                                        break
+
+                                if (
+                                    join_parent_fork_run_id
+                                    and other_parent_fork_run_id
+                                    and join_parent_fork_run_id == other_parent_fork_run_id
+                                ):  # pragma: no branch
+                                    should_skip = True
+                                    break
+                        if should_skip:
+                            break
+
+                    if should_skip:
+                        continue
+
                    self.active_reducers.pop(
                        (join_id, fork_run_id)
                    )  # we're handling it now, so we can pop it
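
The nested loops reduce to one predicate per active reducer: is any of my intermediate joins itself an active reducer under the same run of my parent fork? A condensed standalone sketch of that test, with hypothetical names and plain (fork_id, node_run_id) lists standing in for the package's fork-stack types:

# Condensed sketch of the skip test above (hypothetical, not package API).
def run_id_for(fork_stack: list[tuple[str, int]], fork_id: str) -> int | None:
    # Find the node_run_id recorded for a given fork in a fork stack.
    for fid, node_run_id in fork_stack:
        if fid == fork_id:
            return node_run_id
    return None

def must_wait(
    fork_stack: list[tuple[str, int]],
    parent_fork_id: str,
    intermediate_joins: set[str],
    active: dict[str, list[tuple[str, int]]],  # active reducer join_id -> its fork stack
) -> bool:
    # A reducer must wait if an intermediate join is still active under the
    # same run of this reducer's parent fork.
    mine = run_id_for(fork_stack, parent_fork_id)
    return any(
        mine is not None and run_id_for(active[j], parent_fork_id) == mine
        for j in intermediate_joins
        if j in active
    )

active = {'J1': [('F', 7)]}
assert must_wait([('F', 7)], 'F', {'J1'}, active) is True   # J1 active under same F run
assert must_wait([('F', 8)], 'F', {'J1'}, active) is False  # different run of F
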
@@ -610,7 +670,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
                        # Same note as above about how this is theoretically reachable but we should
                        # just get coverage by unifying the code paths
                        if t.task_id not in new_task_ids:  # pragma: no cover
-                            await self._finish_task(t.task_id)
+                            await self._finish_task(t.task_id, t.node_id)
                    self._handle_execution_request(maybe_overridden_result)
        except GeneratorExit:
            self._task_group.cancel_scope.cancel()
@@ -620,7 +680,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
                'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
            )
 
-    async def _finish_task(self, task_id: TaskID) -> None:
+    async def _finish_task(self, task_id: TaskID, node_id: str) -> None:
+        # node_id is just included for debugging right now
        scope = self.cancel_scopes.pop(task_id, None)
        if scope is not None:
            scope.cancel()
@@ -837,7 +898,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
            else:
                pass
        for task_id in task_ids_to_cancel:
-            await self._finish_task(task_id)
+            await self._finish_task(task_id, 'sibling')
 
 
 def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]:

pydantic_graph/beta/graph_builder.py

@@ -658,6 +658,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
        if validate_graph_structure:
            _validate_graph_structure(nodes, edges_by_source)
        parent_forks = _collect_dominating_forks(nodes, edges_by_source)
+        intermediate_join_nodes = _compute_intermediate_join_nodes(nodes, parent_forks)
 
        return Graph[StateT, DepsT, GraphInputT, GraphOutputT](
            name=self.name,
@@ -668,6 +669,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
            nodes=nodes,
            edges_by_source=edges_by_source,
            parent_forks=parent_forks,
+            intermediate_join_nodes=intermediate_join_nodes,
            auto_instrument=self.auto_instrument,
        )
 
@@ -948,6 +950,40 @@ Join {join.id!r} in this graph has no dominating fork in this graph.""")
    return dominating_forks
 
 
+def _compute_intermediate_join_nodes(
+    nodes: dict[NodeID, AnyNode], parent_forks: dict[JoinID, ParentFork[NodeID]]
+) -> dict[JoinID, set[JoinID]]:
+    """Compute which joins have other joins as intermediate nodes.
+
+    A join J1 is an intermediate node of join J2 if J1 appears in J2's intermediate_nodes
+    (as computed relative to J2's parent fork).
+
+    This information is used to determine:
+    1. Which joins are "final" (have no other joins in their intermediate_nodes)
+    2. When selecting which reducer to proceed with when there are no active tasks
+
+    Args:
+        nodes: All nodes in the graph
+        parent_forks: Parent fork information for each join
+
+    Returns:
+        A mapping from each join to the set of joins that are intermediate to it
+    """
+    intermediate_join_nodes: dict[JoinID, set[JoinID]] = {}
+
+    for join_id, parent_fork in parent_forks.items():
+        intermediate_joins = set[JoinID]()
+        for intermediate_node_id in parent_fork.intermediate_nodes:
+            # Check if this intermediate node is also a join
+            intermediate_node = nodes.get(intermediate_node_id)
+            if isinstance(intermediate_node, Join):
+                # Add it regardless of whether it has the same parent fork
+                intermediate_joins.add(JoinID(intermediate_node_id))
+        intermediate_join_nodes[join_id] = intermediate_joins
+
+    return intermediate_join_nodes
+
+
 def _replace_placeholder_node_ids(nodes: dict[NodeID, AnyNode], edges_by_source: dict[NodeID, list[Path]]):
    node_id_remapping = _build_placeholder_node_id_remapping(nodes)
    replaced_nodes = {
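
A standalone sketch of the same builder-time computation on the toy graph from earlier, with plain strings and (fork_id, intermediate_nodes) tuples standing in for AnyNode and ParentFork (hypothetical, for illustration only):

# Sketch of the computation with plain data (hypothetical node IDs).
def compute_intermediate_join_nodes(
    join_ids: set[str],
    parent_forks: dict[str, tuple[str, set[str]]],
) -> dict[str, set[str]]:
    # Keep only the intermediate nodes that are themselves joins.
    return {
        join_id: {n for n in intermediates if n in join_ids}
        for join_id, (_fork_id, intermediates) in parent_forks.items()
    }

parent_forks = {
    'J1': ('F', {'branch_a', 'branch_b'}),
    'J2': ('F', {'branch_a', 'branch_b', 'J1'}),
}
assert compute_intermediate_join_nodes({'J1', 'J2'}, parent_forks) == {
    'J1': set(),
    'J2': {'J1'},
}
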

pydantic_graph-1.9.1.dist-info/METADATA → pydantic_graph-1.11.0.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-graph
-Version: 1.9.1
+Version: 1.11.0
 Summary: Graph and state machine library
 Project-URL: Homepage, https://ai.pydantic.dev/graph/tree/main/pydantic_graph
 Project-URL: Source, https://github.com/pydantic/pydantic-ai

pydantic_graph-1.9.1.dist-info/RECORD → pydantic_graph-1.11.0.dist-info/RECORD

@@ -7,8 +7,8 @@ pydantic_graph/nodes.py,sha256=CkY3lrC6jqZtzwhSRjFzmM69TdFFFrr58XSDU4THKHA,7450
 pydantic_graph/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pydantic_graph/beta/__init__.py,sha256=VVmbEFaCSXYHwXqS4pANg4B3cn_c86tT62tW_EXcuyw,751
 pydantic_graph/beta/decision.py,sha256=x-Ta549b-j5hyBPUWFdwRQDRaJqnBHF1pfBP9L8I3vI,11239
-pydantic_graph/beta/graph.py,sha256=X3eEniWjdFNjDn4tGGsmjFBBawg7FAoWOXS6P3vh4-8,37680
-pydantic_graph/beta/graph_builder.py,sha256=jvMwL-r2l5qT9_YbVM7N-m3Mv5wEGm_5AOTkfFLxxvw,41753
+pydantic_graph/beta/graph.py,sha256=9ZXdjDP67c-e7DOW20JQ6Lb2bZjQ1Z6Gn48MmdMI1NI,41328
+pydantic_graph/beta/graph_builder.py,sha256=dCw4LePreagujGNtTdCVZfRVWkCs35MpoPEtAncLo5U,43326
 pydantic_graph/beta/id_types.py,sha256=mIhS3HYvmTWWfkGZmt5UEyedn5Ave424FyzUalP9nsU,2915
 pydantic_graph/beta/join.py,sha256=rzCumDX_YgaU_a5bisfbjbbOuI3IwSZsCZs9TC0T9E4,8002
 pydantic_graph/beta/mermaid.py,sha256=Bj8a3CODPcojwT7BnrYqLBKTp0AbA1T3XsTmK2St3v4,7127
@@ -22,7 +22,7 @@ pydantic_graph/persistence/__init__.py,sha256=NLBGvUWhem23EdMHHxtX0XgTS2vyixmuWt
 pydantic_graph/persistence/_utils.py,sha256=6ySxCc1lFz7bbLUwDLkoZWNqi8VNLBVU4xxJbKI23fQ,2264
 pydantic_graph/persistence/file.py,sha256=XZy295cGc86HfUl_KuB-e7cECZW3bubiEdyJMVQ1OD0,6906
 pydantic_graph/persistence/in_mem.py,sha256=MmahaVpdzmDB30Dm3ZfSCZBqgmx6vH4HXdBaWwVF0K0,6799
-pydantic_graph-1.9.1.dist-info/METADATA,sha256=QFv7MmAYf6spz9HAwabU6-Xn5gcbXDspiISOuqXwOHA,3894
-pydantic_graph-1.9.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-pydantic_graph-1.9.1.dist-info/licenses/LICENSE,sha256=vA6Jc482lEyBBuGUfD1pYx-cM7jxvLYOxPidZ30t_PQ,1100
-pydantic_graph-1.9.1.dist-info/RECORD,,
+pydantic_graph-1.11.0.dist-info/METADATA,sha256=J4gOyw89bTDiuNDjHQ67tlQQ0or9xdLTrDYsAXJ6R7U,3895
+pydantic_graph-1.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydantic_graph-1.11.0.dist-info/licenses/LICENSE,sha256=vA6Jc482lEyBBuGUfD1pYx-cM7jxvLYOxPidZ30t_PQ,1100
+pydantic_graph-1.11.0.dist-info/RECORD,,