pydantic-graph 1.8.0.tar.gz → 1.13.0.tar.gz

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (29)
  1. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/.gitignore +1 -1
  2. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/PKG-INFO +1 -1
  3. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/graph.py +201 -128
  4. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/graph_builder.py +36 -0
  5. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/LICENSE +0 -0
  6. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/README.md +0 -0
  7. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/__init__.py +0 -0
  8. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/_utils.py +0 -0
  9. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/__init__.py +0 -0
  10. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/decision.py +0 -0
  11. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/id_types.py +0 -0
  12. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/join.py +0 -0
  13. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/mermaid.py +0 -0
  14. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/node.py +0 -0
  15. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/node_types.py +0 -0
  16. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/parent_forks.py +0 -0
  17. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/paths.py +0 -0
  18. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/step.py +0 -0
  19. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/beta/util.py +0 -0
  20. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/exceptions.py +0 -0
  21. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/graph.py +0 -0
  22. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/mermaid.py +0 -0
  23. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/nodes.py +0 -0
  24. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/persistence/__init__.py +0 -0
  25. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/persistence/_utils.py +0 -0
  26. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/persistence/file.py +0 -0
  27. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/persistence/in_mem.py +0 -0
  28. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pydantic_graph/py.typed +0 -0
  29. {pydantic_graph-1.8.0 → pydantic_graph-1.13.0}/pyproject.toml +0 -0
@@ -10,7 +10,7 @@ env*/
 /TODO.md
 /postgres-data/
 .DS_Store
-examples/pydantic_ai_examples/.chat_app_messages.sqlite
+.chat_app_messages.sqlite
 .cache/
 .vscode/
 /question_graph_history.json
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-graph
-Version: 1.8.0
+Version: 1.13.0
 Summary: Graph and state machine library
 Project-URL: Homepage, https://ai.pydantic.dev/graph/tree/main/pydantic_graph
 Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -10,7 +10,7 @@ from __future__ import annotations as _annotations
 import sys
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
-from contextlib import AbstractContextManager, ExitStack, asynccontextmanager, contextmanager
+from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
 
@@ -148,6 +148,12 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
     parent_forks: dict[JoinID, ParentFork[NodeID]]
     """Parent fork information for each join node."""
 
+    intermediate_join_nodes: dict[JoinID, set[JoinID]]
+    """For each join, the set of other joins that appear between it and its parent fork.
+
+    Used to determine which joins are "final" (have no other joins as intermediates) and
+    which joins should preserve fork stacks when proceeding downstream."""
+
     def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]:
         """Get the parent fork information for a join node.
 
@@ -165,6 +171,24 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
             raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)')
         return result
 
+    def is_final_join(self, join_id: JoinID) -> bool:
+        """Check if a join is 'final' (has no downstream joins with the same parent fork).
+
+        A join is non-final if it appears as an intermediate node for another join
+        with the same parent fork.
+
+        Args:
+            join_id: The ID of the join node
+
+        Returns:
+            True if the join is final, False if it's non-final
+        """
+        # Check if this join appears in any other join's intermediate_join_nodes
+        for intermediate_joins in self.intermediate_join_nodes.values():
+            if join_id in intermediate_joins:
+                return False
+        return True
+
     async def run(
         self,
         *,
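The two additions above work together: the builder records, for every join, which other joins sit between it and its parent fork, and is_final_join then treats a join as final only when no other join lists it as an intermediate. A minimal standalone sketch of that rule, using plain strings for IDs rather than the package's JoinID type, for a hypothetical fork F whose branches pass through join J1 before reaching join J2:

    # Hypothetical mapping, shaped like Graph.intermediate_join_nodes
    intermediate_join_nodes: dict[str, set[str]] = {
        'J1': set(),   # no join sits between F and J1
        'J2': {'J1'},  # J1 sits between F and J2
    }

    def is_final_join(join_id: str) -> bool:
        # A join is non-final if any other join lists it as an intermediate.
        return all(join_id not in joins for joins in intermediate_join_nodes.values())

    assert not is_final_join('J1')  # J1 must keep the fork stack alive for J2
    assert is_final_join('J2')      # J2 is the last join under fork F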
@@ -357,18 +381,32 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
         run_id = GraphRunID(str(uuid.uuid4()))
         initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
         self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
-        self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](self.graph, self.state, self.deps)
+        self._iterator_task_group = create_task_group()
+        self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
+            self.graph, self.state, self.deps, self._iterator_task_group
+        )
         self._iterator = self._iterator_instance.iter_graph(self._first_task)
 
         self.__traceparent = traceparent
+        self._async_exit_stack = AsyncExitStack()
 
     async def __aenter__(self):
+        self._async_exit_stack.enter_context(_unwrap_exception_groups())
+        await self._async_exit_stack.enter_async_context(self._iterator_task_group)
+        await self._async_exit_stack.enter_async_context(self._iterator_context())
         return self
 
     async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
-        self._iterator_instance.iter_stream_sender.close()
-        self._iterator_instance.iter_stream_receiver.close()
-        await self._iterator.aclose()
+        await self._async_exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+
+    @asynccontextmanager
+    async def _iterator_context(self):
+        try:
+            yield
+        finally:
+            self._iterator_instance.iter_stream_sender.close()
+            self._iterator_instance.iter_stream_receiver.close()
+            await self._iterator.aclose()
 
     @overload
     def _traceparent(self, *, required: Literal[False]) -> str | None: ...
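For context on the new lifecycle: everything that __aenter__ opens is registered on a single AsyncExitStack, so __aexit__ can unwind it in LIFO order (iterator cleanup first, then the task group, then the exception-group unwrapping). A self-contained sketch of that pattern, with placeholder resources standing in for the real task group and iterator:

    import asyncio
    from contextlib import AsyncExitStack, asynccontextmanager

    @asynccontextmanager
    async def resource(name: str):
        print(f'enter {name}')
        try:
            yield
        finally:
            print(f'exit {name}')

    class Run:
        async def __aenter__(self):
            self._stack = AsyncExitStack()
            await self._stack.enter_async_context(resource('task group'))
            await self._stack.enter_async_context(resource('iterator'))
            return self

        async def __aexit__(self, *exc_info):
            # Unwinds in reverse order: 'iterator' closes before 'task group'.
            await self._stack.__aexit__(*exc_info)

    async def main():
        async with Run():
            print('running')

    # Prints: enter task group, enter iterator, running, exit iterator, exit task group
    asyncio.run(main())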
@@ -471,13 +509,13 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
     graph: Graph[StateT, DepsT, Any, OutputT]
     state: StateT
     deps: DepsT
+    task_group: TaskGroup
 
     cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
     active_tasks: dict[TaskID, GraphTask] = field(init=False)
     active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = field(init=False)
     iter_stream_sender: MemoryObjectSendStream[_GraphTaskResult] = field(init=False)
     iter_stream_receiver: MemoryObjectReceiveStream[_GraphTaskResult] = field(init=False)
-    _task_group: TaskGroup | None = field(init=False)
 
     def __post_init__(self):
         self.cancel_scopes = {}
@@ -485,142 +523,177 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
         self.active_reducers = {}
         self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
 
-    @property
-    def task_group(self) -> TaskGroup:
-        if self._task_group is None:
-            raise RuntimeError("This graph iterator hasn't been started")  # pragma: no cover
-        return self._task_group
-
     async def iter_graph(  # noqa C901
         self, first_task: GraphTask
     ) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]:
-        with _unwrap_exception_groups():
-            async with self.iter_stream_sender, create_task_group() as self._task_group:
-                try:
-                    # Fire off the first task
-                    self.active_tasks[first_task.task_id] = first_task
-                    self._handle_execution_request([first_task])
-
-                    # Handle task results
-                    async with self.iter_stream_receiver:
-                        while self.active_tasks or self.active_reducers:
-                            async for task_result in self.iter_stream_receiver:  # pragma: no branch
-                                if isinstance(task_result.result, JoinItem):
-                                    maybe_overridden_result = task_result.result
-                                else:
-                                    maybe_overridden_result = yield task_result.result
-                                if isinstance(maybe_overridden_result, EndMarker):
+        async with self.iter_stream_sender:
+            try:
+                # Fire off the first task
+                self.active_tasks[first_task.task_id] = first_task
+                self._handle_execution_request([first_task])
+
+                # Handle task results
+                async with self.iter_stream_receiver:
+                    while self.active_tasks or self.active_reducers:
+                        async for task_result in self.iter_stream_receiver:  # pragma: no branch
+                            if isinstance(task_result.result, JoinItem):
+                                maybe_overridden_result = task_result.result
+                            else:
+                                maybe_overridden_result = yield task_result.result
+                            if isinstance(maybe_overridden_result, EndMarker):
+                                # If we got an end marker, this task is definitely done, and we're ready to
+                                # start cleaning everything up
+                                await self._finish_task(task_result.source.task_id)
+                                if self.active_tasks:
+                                    # Cancel the remaining tasks
                                     self.task_group.cancel_scope.cancel()
-                                    return
-                                elif isinstance(maybe_overridden_result, JoinItem):
-                                    result = maybe_overridden_result
-                                    parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
-                                    for i, x in enumerate(result.fork_stack[::-1]):
-                                        if x.fork_id == parent_fork_id:
+                                return
+                            elif isinstance(maybe_overridden_result, JoinItem):
+                                result = maybe_overridden_result
+                                parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
+                                for i, x in enumerate(result.fork_stack[::-1]):
+                                    if x.fork_id == parent_fork_id:
+                                        # For non-final joins (those that are intermediate nodes of other joins),
+                                        # preserve the fork stack so downstream joins can still associate with the same fork run
+                                        if self.graph.is_final_join(result.join_id):
+                                            # Final join: remove the parent fork from the stack
                                             downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
-                                            fork_run_id = x.node_run_id
-                                            break
-                                    else:  # pragma: no cover
-                                        raise RuntimeError('Parent fork run not found')
-
-                                    join_node = self.graph.nodes[result.join_id]
-                                    assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
-                                    join_state = self.active_reducers.get((result.join_id, fork_run_id))
-                                    if join_state is None:
-                                        current = join_node.initial_factory()
-                                        join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
-                                            current, downstream_fork_stack
-                                        )
-                                    context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
-                                    join_state.current = join_node.reduce(context, join_state.current, result.inputs)
-                                    if join_state.cancelled_sibling_tasks:
-                                        await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
-                                    if task_result.source_is_finished:  # pragma: no branch
-                                        await self._finish_task(task_result.source.task_id)
-                                else:
-                                    for new_task in maybe_overridden_result:
-                                        self.active_tasks[new_task.task_id] = new_task
-                                    if task_result.source_is_finished:
-                                        await self._finish_task(task_result.source.task_id)
-
-                                tasks_by_id_values = list(self.active_tasks.values())
-                                join_tasks: list[GraphTask] = []
-
-                                for join_id, fork_run_id in self._get_completed_fork_runs(
-                                    task_result.source, tasks_by_id_values
-                                ):
-                                    join_state = self.active_reducers.pop((join_id, fork_run_id))
-                                    join_node = self.graph.nodes[join_id]
-                                    assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
-                                    new_tasks = self._handle_non_fork_edges(
-                                        join_node, join_state.current, join_state.downstream_fork_stack
-                                    )
-                                    join_tasks.extend(new_tasks)
-                                if join_tasks:
-                                    for new_task in join_tasks:
-                                        self.active_tasks[new_task.task_id] = new_task
-                                    self._handle_execution_request(join_tasks)
-
-                                if isinstance(maybe_overridden_result, Sequence):
-                                    if isinstance(task_result.result, Sequence):
-                                        new_task_ids = {t.task_id for t in maybe_overridden_result}
-                                        for t in task_result.result:
-                                            if t.task_id not in new_task_ids:
-                                                await self._finish_task(t.task_id)
-                                    self._handle_execution_request(maybe_overridden_result)
-
-                                if not self.active_tasks:
-                                    # if there are no active tasks, we'll be waiting forever for the next result..
-                                    break
-
-                            if self.active_reducers:  # pragma: no branch
-                                # In this case, there are no pending tasks. We can therefore finalize all active reducers whose
-                                # downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the
-                                # deeper reducer could produce new tasks in the "prefix" reducer.)
-                                active_fork_stacks = [
-                                    join_state.downstream_fork_stack for join_state in self.active_reducers.values()
-                                ]
-                                for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
-                                    fork_stack = join_state.downstream_fork_stack
-                                    if any(
-                                        len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)]
-                                        for afs in active_fork_stacks
-                                    ):
-                                        # this join_state is a strict prefix for one of the other active join_states
-                                        continue  # pragma: no cover # It's difficult to cover this
-                                    self.active_reducers.pop(
-                                        (join_id, fork_run_id)
-                                    )  # we're handling it now, so we can pop it
-                                    join_node = self.graph.nodes[join_id]
-                                    assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
-                                    new_tasks = self._handle_non_fork_edges(
-                                        join_node, join_state.current, join_state.downstream_fork_stack
+                                        else:
+                                            # Non-final join: preserve the fork stack
+                                            downstream_fork_stack = result.fork_stack
+                                        fork_run_id = x.node_run_id
+                                        break
+                                else:  # pragma: no cover
+                                    raise RuntimeError('Parent fork run not found')
+
+                                join_node = self.graph.nodes[result.join_id]
+                                assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                join_state = self.active_reducers.get((result.join_id, fork_run_id))
+                                if join_state is None:
+                                    current = join_node.initial_factory()
+                                    join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
+                                        current, downstream_fork_stack
                                     )
-                                    maybe_overridden_result = yield new_tasks
-                                    if isinstance(maybe_overridden_result, EndMarker):  # pragma: no cover
-                                        # This is theoretically reachable but it would be awkward.
-                                        # Probably a better way to get coverage here would be to unify the code pat
-                                        # with the other `if isinstance(maybe_overridden_result, EndMarker):`
-                                        self.task_group.cancel_scope.cancel()
-                                        return
-                                    for new_task in maybe_overridden_result:
-                                        self.active_tasks[new_task.task_id] = new_task
+                                context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
+                                join_state.current = join_node.reduce(context, join_state.current, result.inputs)
+                                if join_state.cancelled_sibling_tasks:
+                                    await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
+                            else:
+                                for new_task in maybe_overridden_result:
+                                    self.active_tasks[new_task.task_id] = new_task
+
+                            tasks_by_id_values = list(self.active_tasks.values())
+                            join_tasks: list[GraphTask] = []
+
+                            for join_id, fork_run_id in self._get_completed_fork_runs(
+                                task_result.source, tasks_by_id_values
+                            ):
+                                join_state = self.active_reducers.pop((join_id, fork_run_id))
+                                join_node = self.graph.nodes[join_id]
+                                assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                new_tasks = self._handle_non_fork_edges(
+                                    join_node, join_state.current, join_state.downstream_fork_stack
+                                )
+                                join_tasks.extend(new_tasks)
+                            if join_tasks:
+                                for new_task in join_tasks:
+                                    self.active_tasks[new_task.task_id] = new_task
+                                self._handle_execution_request(join_tasks)
+
+                            if isinstance(maybe_overridden_result, Sequence):
+                                if isinstance(task_result.result, Sequence):
                                     new_task_ids = {t.task_id for t in maybe_overridden_result}
-                                    for t in new_tasks:
-                                        # Same note as above about how this is theoretically reachable but we should
-                                        # just get coverage by unifying the code paths
-                                        if t.task_id not in new_task_ids:  # pragma: no cover
+                                    for t in task_result.result:
+                                        if t.task_id not in new_task_ids:
                                             await self._finish_task(t.task_id)
-                                    self._handle_execution_request(maybe_overridden_result)
-                except GeneratorExit:
-                    self._task_group.cancel_scope.cancel()
-                    return
+                                self._handle_execution_request(maybe_overridden_result)
+
+                            if task_result.source_is_finished:
+                                await self._finish_task(task_result.source.task_id)
+
+                            if not self.active_tasks:
+                                # if there are no active tasks, we'll be waiting forever for the next result..
+                                break
+
+                        if self.active_reducers:  # pragma: no branch
+                            # In this case, there are no pending tasks. We can therefore finalize all active reducers
+                            # that don't have intermediate joins which are also active reducers. If a join J2 has an
+                            # intermediate join J1 that shares the same parent fork run, we must finalize J1 first
+                            # because it might produce items that feed into J2.
+                            for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
+                                # Check if this join has any intermediate joins that are also active reducers
+                                should_skip = False
+                                intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
+
+                                # Get the parent fork for this join to use for comparison
+                                join_parent_fork = self.graph.get_parent_fork(join_id)
+
+                                for intermediate_join_id in intermediate_joins:
+                                    # Check if the intermediate join is also an active reducer with matching fork run
+                                    for (other_join_id, _), other_join_state in self.active_reducers.items():
+                                        if other_join_id == intermediate_join_id:
+                                            # Check if they share the same fork run for this join's parent fork
+                                            # by finding the parent fork's node_run_id in both fork stacks
+                                            join_parent_fork_run_id = None
+                                            other_parent_fork_run_id = None
+
+                                            for fsi in join_state.downstream_fork_stack:  # pragma: no branch
+                                                if fsi.fork_id == join_parent_fork.fork_id:
+                                                    join_parent_fork_run_id = fsi.node_run_id
+                                                    break
+
+                                            for fsi in other_join_state.downstream_fork_stack:  # pragma: no branch
+                                                if fsi.fork_id == join_parent_fork.fork_id:
+                                                    other_parent_fork_run_id = fsi.node_run_id
+                                                    break
+
+                                            if (
+                                                join_parent_fork_run_id
+                                                and other_parent_fork_run_id
+                                                and join_parent_fork_run_id == other_parent_fork_run_id
+                                            ):  # pragma: no branch
+                                                should_skip = True
+                                                break
+                                    if should_skip:
+                                        break
+
+                                if should_skip:
+                                    continue
+
+                                self.active_reducers.pop(
+                                    (join_id, fork_run_id)
+                                )  # we're handling it now, so we can pop it
+                                join_node = self.graph.nodes[join_id]
+                                assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                new_tasks = self._handle_non_fork_edges(
+                                    join_node, join_state.current, join_state.downstream_fork_stack
+                                )
+                                maybe_overridden_result = yield new_tasks
+                                if isinstance(maybe_overridden_result, EndMarker):  # pragma: no cover
+                                    # This is theoretically reachable but it would be awkward.
+                                    # Probably a better way to get coverage here would be to unify the code pat
+                                    # with the other `if isinstance(maybe_overridden_result, EndMarker):`
+                                    self.task_group.cancel_scope.cancel()
+                                    return
+                                for new_task in maybe_overridden_result:
+                                    self.active_tasks[new_task.task_id] = new_task
+                                new_task_ids = {t.task_id for t in maybe_overridden_result}
+                                for t in new_tasks:
+                                    # Same note as above about how this is theoretically reachable but we should
+                                    # just get coverage by unifying the code paths
+                                    if t.task_id not in new_task_ids:  # pragma: no cover
+                                        await self._finish_task(t.task_id)
+                                self._handle_execution_request(maybe_overridden_result)
+            except GeneratorExit:
+                self.task_group.cancel_scope.cancel()
+                return
 
         raise RuntimeError(  # pragma: no cover
             'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
         )
 
     async def _finish_task(self, task_id: TaskID) -> None:
+        # node_id is just included for debugging right now
         scope = self.cancel_scopes.pop(task_id, None)
         if scope is not None:
             scope.cancel()
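The final/non-final distinction shows up concretely in how downstream_fork_stack is computed: the slice fork_stack[: len(fork_stack) - i] cuts the stack back to the parent fork's frame, dropping any frames pushed after it, while a non-final join passes the stack through untouched. A standalone sketch with hypothetical frames (plain tuples instead of the package's ForkStackItem):

    # Stack frames as (fork_id, node_run_id); 'F' is the join's parent fork,
    # and 'G' is a fork entered after F on this branch.
    fork_stack = (('START', 'run-0'), ('F', 'run-7'), ('G', 'run-9'))

    def downstream_stack(stack, parent_fork_id, is_final):
        # Mirror iter_graph: scan from the end for the parent fork's frame.
        for i, (fork_id, _) in enumerate(stack[::-1]):
            if fork_id == parent_fork_id:
                # Final join: truncate back to the parent fork's frame.
                # Non-final join: keep the stack intact for downstream joins.
                return stack[: len(stack) - i] if is_final else stack
        raise RuntimeError('Parent fork run not found')

    assert downstream_stack(fork_stack, 'F', is_final=True) == (('START', 'run-0'), ('F', 'run-7'))
    assert downstream_stack(fork_stack, 'F', is_final=False) == fork_stack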
@@ -658,6 +658,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
         if validate_graph_structure:
             _validate_graph_structure(nodes, edges_by_source)
         parent_forks = _collect_dominating_forks(nodes, edges_by_source)
+        intermediate_join_nodes = _compute_intermediate_join_nodes(nodes, parent_forks)
 
         return Graph[StateT, DepsT, GraphInputT, GraphOutputT](
             name=self.name,
@@ -668,6 +669,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
             nodes=nodes,
             edges_by_source=edges_by_source,
             parent_forks=parent_forks,
+            intermediate_join_nodes=intermediate_join_nodes,
             auto_instrument=self.auto_instrument,
         )
 
@@ -948,6 +950,40 @@ Join {join.id!r} in this graph has no dominating fork in this graph.""")
     return dominating_forks
 
 
+def _compute_intermediate_join_nodes(
+    nodes: dict[NodeID, AnyNode], parent_forks: dict[JoinID, ParentFork[NodeID]]
+) -> dict[JoinID, set[JoinID]]:
+    """Compute which joins have other joins as intermediate nodes.
+
+    A join J1 is an intermediate node of join J2 if J1 appears in J2's intermediate_nodes
+    (as computed relative to J2's parent fork).
+
+    This information is used to determine:
+    1. Which joins are "final" (have no other joins in their intermediate_nodes)
+    2. When selecting which reducer to proceed with when there are no active tasks
+
+    Args:
+        nodes: All nodes in the graph
+        parent_forks: Parent fork information for each join
+
+    Returns:
+        A mapping from each join to the set of joins that are intermediate to it
+    """
+    intermediate_join_nodes: dict[JoinID, set[JoinID]] = {}
+
+    for join_id, parent_fork in parent_forks.items():
+        intermediate_joins = set[JoinID]()
+        for intermediate_node_id in parent_fork.intermediate_nodes:
+            # Check if this intermediate node is also a join
+            intermediate_node = nodes.get(intermediate_node_id)
+            if isinstance(intermediate_node, Join):
+                # Add it regardless of whether it has the same parent fork
+                intermediate_joins.add(JoinID(intermediate_node_id))
+        intermediate_join_nodes[join_id] = intermediate_joins
+
+    return intermediate_join_nodes
+
+
 def _replace_placeholder_node_ids(nodes: dict[NodeID, AnyNode], edges_by_source: dict[NodeID, list[Path]]):
     node_id_remapping = _build_placeholder_node_id_remapping(nodes)
     replaced_nodes = {
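To make the new helper's output concrete: for a hypothetical graph where fork F fans out to steps a and b, which feed join J1, which in turn feeds join J2, the parent-fork analysis would list {'a', 'b'} as J1's intermediate nodes and {'a', 'b', 'J1'} as J2's, and only the entries that are themselves joins survive the isinstance filter. A plain-dict restatement of that filtering step:

    # Hypothetical inputs, mirroring the filter in _compute_intermediate_join_nodes
    joins = {'J1', 'J2'}  # stand-in for isinstance(nodes[n], Join)
    intermediate_nodes = {
        'J1': {'a', 'b'},
        'J2': {'a', 'b', 'J1'},
    }
    intermediate_join_nodes = {
        join_id: {n for n in mids if n in joins}
        for join_id, mids in intermediate_nodes.items()
    }
    assert intermediate_join_nodes == {'J1': set(), 'J2': {'J1'}}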