pydantic-graph 1.11.0.tar.gz → 1.22.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (29)
  1. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/.gitignore +2 -0
  2. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/PKG-INFO +1 -1
  3. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/graph.py +234 -183
  4. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/id_types.py +0 -3
  5. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/LICENSE +0 -0
  6. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/README.md +0 -0
  7. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/__init__.py +0 -0
  8. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/_utils.py +0 -0
  9. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/__init__.py +0 -0
  10. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/decision.py +0 -0
  11. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/graph_builder.py +0 -0
  12. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/join.py +0 -0
  13. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/mermaid.py +0 -0
  14. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/node.py +0 -0
  15. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/node_types.py +0 -0
  16. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/parent_forks.py +0 -0
  17. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/paths.py +0 -0
  18. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/step.py +0 -0
  19. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/util.py +0 -0
  20. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/exceptions.py +0 -0
  21. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/graph.py +0 -0
  22. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/mermaid.py +0 -0
  23. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/nodes.py +0 -0
  24. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/persistence/__init__.py +0 -0
  25. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/persistence/_utils.py +0 -0
  26. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/persistence/file.py +0 -0
  27. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/persistence/in_mem.py +0 -0
  28. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/py.typed +0 -0
  29. {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pyproject.toml +0 -0
@@ -21,3 +21,5 @@ node_modules/
  /test_tmp/
  .mcp.json
  .claude/
+ /.cursor/
+ /.devcontainer/
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pydantic-graph
- Version: 1.11.0
+ Version: 1.22.0
  Summary: Graph and state machine library
  Project-URL: Homepage, https://ai.pydantic.dev/graph/tree/main/pydantic_graph
  Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -8,9 +8,8 @@ the graph-based workflow system.
  from __future__ import annotations as _annotations

  import sys
- import uuid
- from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
- from contextlib import AbstractContextManager, ExitStack, asynccontextmanager, contextmanager
+ from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
+ from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
  from dataclasses import dataclass, field
  from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload

@@ -22,7 +21,7 @@ from typing_extensions import TypeVar, assert_never
  from pydantic_graph import exceptions
  from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
  from pydantic_graph.beta.decision import Decision
- from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
+ from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
  from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
  from pydantic_graph.beta.node import (
      EndNode,
@@ -306,14 +305,13 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):


  @dataclass
- class GraphTask:
-     """A single task representing the execution of a node in the graph.
+ class GraphTaskRequest:
+     """A request to run a task representing the execution of a node in the graph.

-     GraphTask encapsulates all the information needed to execute a specific
+     GraphTaskRequest encapsulates all the information needed to execute a specific
      node, including its inputs and the fork context it's executing within.
      """

-     # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
      node_id: NodeID
      """The ID of the node to execute."""

@@ -326,9 +324,26 @@ class GraphTask:
      Used by the GraphRun to decide when to proceed through joins.
      """

-     task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4())), repr=False)
+
+ @dataclass
+ class GraphTask(GraphTaskRequest):
+     """A task representing the execution of a node in the graph.
+
+     GraphTask encapsulates all the information needed to execute a specific
+     node, including its inputs and the fork context it's executing within,
+     and has a unique ID to identify the task within the graph run.
+     """
+
+     task_id: TaskID = field(repr=False)
      """Unique identifier for this task."""

+     @staticmethod
+     def from_request(request: GraphTaskRequest, get_task_id: Callable[[], TaskID]) -> GraphTask:
+         # Don't call the get_task_id callable, this is already a task
+         if isinstance(request, GraphTask):
+             return request
+         return GraphTask(request.node_id, request.inputs, request.fork_stack, get_task_id())
+

  class GraphRun(Generic[StateT, DepsT, OutputT]):
      """A single execution instance of a graph.
@@ -378,21 +393,43 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
          self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
          """The next item to be processed."""

-         run_id = GraphRunID(str(uuid.uuid4()))
-         initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
-         self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
-         self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](self.graph, self.state, self.deps)
+         self._next_task_id = 0
+         self._next_node_run_id = 0
+         initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),)
+         self._first_task = GraphTask(
+             node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id()
+         )
+         self._iterator_task_group = create_task_group()
+         self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
+             self.graph,
+             self.state,
+             self.deps,
+             self._iterator_task_group,
+             self._get_next_node_run_id,
+             self._get_next_task_id,
+         )
          self._iterator = self._iterator_instance.iter_graph(self._first_task)

          self.__traceparent = traceparent
+         self._async_exit_stack = AsyncExitStack()

      async def __aenter__(self):
+         self._async_exit_stack.enter_context(_unwrap_exception_groups())
+         await self._async_exit_stack.enter_async_context(self._iterator_task_group)
+         await self._async_exit_stack.enter_async_context(self._iterator_context())
          return self

      async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
-         self._iterator_instance.iter_stream_sender.close()
-         self._iterator_instance.iter_stream_receiver.close()
-         await self._iterator.aclose()
+         await self._async_exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+
+     @asynccontextmanager
+     async def _iterator_context(self):
+         try:
+             yield
+         finally:
+             self._iterator_instance.iter_stream_sender.close()
+             self._iterator_instance.iter_stream_receiver.close()
+             await self._iterator.aclose()

      @overload
      def _traceparent(self, *, required: Literal[False]) -> str | None: ...
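The rewritten lifecycle routes all setup and teardown through one AsyncExitStack: __aenter__ layers the exception-group unwrapper, the task group, and the iterator context, and __aexit__ unwinds them in reverse order with a single call. The shape of that pattern as a hedged sketch (hypothetical resource names, not the package's API):

import contextlib

class Lifecycle:
    def __init__(self) -> None:
        self._stack = contextlib.AsyncExitStack()

    @contextlib.asynccontextmanager
    async def _resource(self, name: str):
        print(f'open {name}')
        try:
            yield
        finally:
            print(f'close {name}')  # runs on exit, even if the body raised

    async def __aenter__(self):
        # Contexts entered first are exited last, so cleanup runs in reverse order.
        await self._stack.enter_async_context(self._resource('outer'))
        await self._stack.enter_async_context(self._resource('inner'))
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        return await self._stack.__aexit__(exc_type, exc_val, exc_tb)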
@@ -435,7 +472,7 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
          return self._next

      async def next(
-         self, value: EndMarker[OutputT] | Sequence[GraphTask] | None = None
+         self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None
      ) -> EndMarker[OutputT] | Sequence[GraphTask]:
          """Advance the graph execution by one step.

@@ -453,7 +490,10 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
          # if `next` is called before the `first_node` has run.
          await anext(self)
          if value is not None:
-             self._next = value
+             if isinstance(value, EndMarker):
+                 self._next = value
+             else:
+                 self._next = [GraphTask.from_request(gtr, self._get_next_task_id) for gtr in value]
          return await anext(self)

      @property
@@ -476,6 +516,16 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
              return self._next.value
          return None

+     def _get_next_task_id(self) -> TaskID:
+         next_id = TaskID(f'task:{self._next_task_id}')
+         self._next_task_id += 1
+         return next_id
+
+     def _get_next_node_run_id(self) -> NodeRunID:
+         next_id = NodeRunID(f'task:{self._next_node_run_id}')
+         self._next_node_run_id += 1
+         return next_id
+

  @dataclass
  class _GraphTaskAsyncIterable:
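_get_next_task_id and _get_next_node_run_id replace the old uuid4-based IDs with run-local counters: IDs become deterministic across identical runs and cheap to mint, at the price of being unique only within one GraphRun. The same factory shape in isolation (illustrative names, not the package's):

from itertools import count
from typing import Callable, NewType

TaskID = NewType('TaskID', str)

def make_id_factory(prefix: str) -> Callable[[], TaskID]:
    counter = count()  # monotonic and local to this run
    return lambda: TaskID(f'{prefix}:{next(counter)}')

get_task_id = make_id_factory('task')
assert get_task_id() == 'task:0'
assert get_task_id() == 'task:1'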
@@ -495,192 +545,193 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
      graph: Graph[StateT, DepsT, Any, OutputT]
      state: StateT
      deps: DepsT
+     task_group: TaskGroup
+     get_next_node_run_id: Callable[[], NodeRunID]
+     get_next_task_id: Callable[[], TaskID]

      cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
      active_tasks: dict[TaskID, GraphTask] = field(init=False)
      active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = field(init=False)
      iter_stream_sender: MemoryObjectSendStream[_GraphTaskResult] = field(init=False)
      iter_stream_receiver: MemoryObjectReceiveStream[_GraphTaskResult] = field(init=False)
-     _task_group: TaskGroup | None = field(init=False)

      def __post_init__(self):
          self.cancel_scopes = {}
          self.active_tasks = {}
          self.active_reducers = {}
          self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
-
-     @property
-     def task_group(self) -> TaskGroup:
-         if self._task_group is None:
-             raise RuntimeError("This graph iterator hasn't been started")  # pragma: no cover
-         return self._task_group
+         self._next_node_run_id = 1

      async def iter_graph(  # noqa C901
          self, first_task: GraphTask
      ) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]:
-         with _unwrap_exception_groups():
-             async with self.iter_stream_sender, create_task_group() as self._task_group:
-                 try:
-                     # Fire off the first task
-                     self.active_tasks[first_task.task_id] = first_task
-                     self._handle_execution_request([first_task])
-
-                     # Handle task results
-                     async with self.iter_stream_receiver:
-                         while self.active_tasks or self.active_reducers:
-                             async for task_result in self.iter_stream_receiver:  # pragma: no branch
-                                 if isinstance(task_result.result, JoinItem):
-                                     maybe_overridden_result = task_result.result
-                                 else:
-                                     maybe_overridden_result = yield task_result.result
-                                 if isinstance(maybe_overridden_result, EndMarker):
+         async with self.iter_stream_sender:
+             try:
+                 # Fire off the first task
+                 self.active_tasks[first_task.task_id] = first_task
+                 self._handle_execution_request([first_task])
+
+                 # Handle task results
+                 async with self.iter_stream_receiver:
+                     while self.active_tasks or self.active_reducers:
+                         async for task_result in self.iter_stream_receiver:  # pragma: no branch
+                             if isinstance(task_result.result, JoinItem):
+                                 maybe_overridden_result = task_result.result
+                             else:
+                                 maybe_overridden_result = yield task_result.result
+                             if isinstance(maybe_overridden_result, EndMarker):
+                                 # If we got an end marker, this task is definitely done, and we're ready to
+                                 # start cleaning everything up
+                                 await self._finish_task(task_result.source.task_id)
+                                 if self.active_tasks:
+                                     # Cancel the remaining tasks
                                      self.task_group.cancel_scope.cancel()
-                                     return
-                                 elif isinstance(maybe_overridden_result, JoinItem):
-                                     result = maybe_overridden_result
-                                     parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
-                                     for i, x in enumerate(result.fork_stack[::-1]):
-                                         if x.fork_id == parent_fork_id:
-                                             # For non-final joins (those that are intermediate nodes of other joins),
-                                             # preserve the fork stack so downstream joins can still associate with the same fork run
-                                             if self.graph.is_final_join(result.join_id):
-                                                 # Final join: remove the parent fork from the stack
-                                                 downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
-                                             else:
-                                                 # Non-final join: preserve the fork stack
-                                                 downstream_fork_stack = result.fork_stack
-                                             fork_run_id = x.node_run_id
-                                             break
-                                     else:  # pragma: no cover
-                                         raise RuntimeError('Parent fork run not found')
-
-                                     join_node = self.graph.nodes[result.join_id]
-                                     assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
-                                     join_state = self.active_reducers.get((result.join_id, fork_run_id))
-                                     if join_state is None:
-                                         current = join_node.initial_factory()
-                                         join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
-                                             current, downstream_fork_stack
-                                         )
-                                     context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
-                                     join_state.current = join_node.reduce(context, join_state.current, result.inputs)
-                                     if join_state.cancelled_sibling_tasks:
-                                         await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
-                                 else:
-                                     for new_task in maybe_overridden_result:
-                                         self.active_tasks[new_task.task_id] = new_task
-
-                                 tasks_by_id_values = list(self.active_tasks.values())
-                                 join_tasks: list[GraphTask] = []
-
-                                 for join_id, fork_run_id in self._get_completed_fork_runs(
-                                     task_result.source, tasks_by_id_values
-                                 ):
-                                     join_state = self.active_reducers.pop((join_id, fork_run_id))
-                                     join_node = self.graph.nodes[join_id]
-                                     assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
-                                     new_tasks = self._handle_non_fork_edges(
-                                         join_node, join_state.current, join_state.downstream_fork_stack
+                                 return
+                             elif isinstance(maybe_overridden_result, JoinItem):
+                                 result = maybe_overridden_result
+                                 parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
+                                 for i, x in enumerate(result.fork_stack[::-1]):
+                                     if x.fork_id == parent_fork_id:
+                                         # For non-final joins (those that are intermediate nodes of other joins),
+                                         # preserve the fork stack so downstream joins can still associate with the same fork run
+                                         if self.graph.is_final_join(result.join_id):
+                                             # Final join: remove the parent fork from the stack
+                                             downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
+                                         else:
+                                             # Non-final join: preserve the fork stack
+                                             downstream_fork_stack = result.fork_stack
+                                         fork_run_id = x.node_run_id
+                                         break
+                                 else:  # pragma: no cover
+                                     raise RuntimeError('Parent fork run not found')
+
+                                 join_node = self.graph.nodes[result.join_id]
+                                 assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                 join_state = self.active_reducers.get((result.join_id, fork_run_id))
+                                 if join_state is None:
+                                     current = join_node.initial_factory()
+                                     join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
+                                         current, downstream_fork_stack
                                      )
-                                     join_tasks.extend(new_tasks)
-                                 if join_tasks:
-                                     for new_task in join_tasks:
-                                         self.active_tasks[new_task.task_id] = new_task
-                                     self._handle_execution_request(join_tasks)
-
-                                 if isinstance(maybe_overridden_result, Sequence):
-                                     if isinstance(task_result.result, Sequence):
-                                         new_task_ids = {t.task_id for t in maybe_overridden_result}
-                                         for t in task_result.result:
-                                             if t.task_id not in new_task_ids:
-                                                 await self._finish_task(t.task_id, t.node_id)
-                                     self._handle_execution_request(maybe_overridden_result)
-
-                                 if task_result.source_is_finished:
-                                     await self._finish_task(task_result.source.task_id, task_result.source.node_id)
-
-                                 if not self.active_tasks:
-                                     # if there are no active tasks, we'll be waiting forever for the next result..
-                                     break
-
-                             if self.active_reducers:  # pragma: no branch
-                                 # In this case, there are no pending tasks. We can therefore finalize all active reducers
-                                 # that don't have intermediate joins which are also active reducers. If a join J2 has an
-                                 # intermediate join J1 that shares the same parent fork run, we must finalize J1 first
-                                 # because it might produce items that feed into J2.
-                                 for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
-                                     # Check if this join has any intermediate joins that are also active reducers
-                                     should_skip = False
-                                     intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
-
-                                     # Get the parent fork for this join to use for comparison
-                                     join_parent_fork = self.graph.get_parent_fork(join_id)
-
-                                     for intermediate_join_id in intermediate_joins:
-                                         # Check if the intermediate join is also an active reducer with matching fork run
-                                         for (other_join_id, _), other_join_state in self.active_reducers.items():
-                                             if other_join_id == intermediate_join_id:
-                                                 # Check if they share the same fork run for this join's parent fork
-                                                 # by finding the parent fork's node_run_id in both fork stacks
-                                                 join_parent_fork_run_id = None
-                                                 other_parent_fork_run_id = None
-
-                                                 for fsi in join_state.downstream_fork_stack:  # pragma: no branch
-                                                     if fsi.fork_id == join_parent_fork.fork_id:
-                                                         join_parent_fork_run_id = fsi.node_run_id
-                                                         break
-
-                                                 for fsi in other_join_state.downstream_fork_stack:  # pragma: no branch
-                                                     if fsi.fork_id == join_parent_fork.fork_id:
-                                                         other_parent_fork_run_id = fsi.node_run_id
-                                                         break
-
-                                                 if (
-                                                     join_parent_fork_run_id
-                                                     and other_parent_fork_run_id
-                                                     and join_parent_fork_run_id == other_parent_fork_run_id
-                                                 ):  # pragma: no branch
-                                                     should_skip = True
+                                 context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
+                                 join_state.current = join_node.reduce(context, join_state.current, result.inputs)
+                                 if join_state.cancelled_sibling_tasks:
+                                     await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
+                             else:
+                                 for new_task in maybe_overridden_result:
+                                     self.active_tasks[new_task.task_id] = new_task
+
+                             tasks_by_id_values = list(self.active_tasks.values())
+                             join_tasks: list[GraphTask] = []
+
+                             for join_id, fork_run_id in self._get_completed_fork_runs(
+                                 task_result.source, tasks_by_id_values
+                             ):
+                                 join_state = self.active_reducers.pop((join_id, fork_run_id))
+                                 join_node = self.graph.nodes[join_id]
+                                 assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                 new_tasks = self._handle_non_fork_edges(
+                                     join_node, join_state.current, join_state.downstream_fork_stack
+                                 )
+                                 join_tasks.extend(new_tasks)
+                             if join_tasks:
+                                 for new_task in join_tasks:
+                                     self.active_tasks[new_task.task_id] = new_task
+                                 self._handle_execution_request(join_tasks)
+
+                             if isinstance(maybe_overridden_result, Sequence):
+                                 if isinstance(task_result.result, Sequence):
+                                     new_task_ids = {t.task_id for t in maybe_overridden_result}
+                                     for t in task_result.result:
+                                         if t.task_id not in new_task_ids:
+                                             await self._finish_task(t.task_id)
+                                 self._handle_execution_request(maybe_overridden_result)
+
+                             if task_result.source_is_finished:
+                                 await self._finish_task(task_result.source.task_id)
+
+                             if not self.active_tasks:
+                                 # if there are no active tasks, we'll be waiting forever for the next result..
+                                 break
+
+                         if self.active_reducers:  # pragma: no branch
+                             # In this case, there are no pending tasks. We can therefore finalize all active reducers
+                             # that don't have intermediate joins which are also active reducers. If a join J2 has an
+                             # intermediate join J1 that shares the same parent fork run, we must finalize J1 first
+                             # because it might produce items that feed into J2.
+                             for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
+                                 # Check if this join has any intermediate joins that are also active reducers
+                                 should_skip = False
+                                 intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
+
+                                 # Get the parent fork for this join to use for comparison
+                                 join_parent_fork = self.graph.get_parent_fork(join_id)
+
+                                 for intermediate_join_id in intermediate_joins:
+                                     # Check if the intermediate join is also an active reducer with matching fork run
+                                     for (other_join_id, _), other_join_state in self.active_reducers.items():
+                                         if other_join_id == intermediate_join_id:
+                                             # Check if they share the same fork run for this join's parent fork
+                                             # by finding the parent fork's node_run_id in both fork stacks
+                                             join_parent_fork_run_id = None
+                                             other_parent_fork_run_id = None
+
+                                             for fsi in join_state.downstream_fork_stack:  # pragma: no branch
+                                                 if fsi.fork_id == join_parent_fork.fork_id:
+                                                     join_parent_fork_run_id = fsi.node_run_id
+                                                     break
+
+                                             for fsi in other_join_state.downstream_fork_stack:  # pragma: no branch
+                                                 if fsi.fork_id == join_parent_fork.fork_id:
+                                                     other_parent_fork_run_id = fsi.node_run_id
                                                      break
-                                         if should_skip:
-                                             break

+                                             if (
+                                                 join_parent_fork_run_id
+                                                 and other_parent_fork_run_id
+                                                 and join_parent_fork_run_id == other_parent_fork_run_id
+                                             ):  # pragma: no branch
+                                                 should_skip = True
+                                                 break
                                      if should_skip:
-                                         continue
-
-                                     self.active_reducers.pop(
-                                         (join_id, fork_run_id)
-                                     )  # we're handling it now, so we can pop it
-                                     join_node = self.graph.nodes[join_id]
-                                     assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
-                                     new_tasks = self._handle_non_fork_edges(
-                                         join_node, join_state.current, join_state.downstream_fork_stack
-                                     )
-                                     maybe_overridden_result = yield new_tasks
-                                     if isinstance(maybe_overridden_result, EndMarker):  # pragma: no cover
-                                         # This is theoretically reachable but it would be awkward.
-                                         # Probably a better way to get coverage here would be to unify the code pat
-                                         # with the other `if isinstance(maybe_overridden_result, EndMarker):`
-                                         self.task_group.cancel_scope.cancel()
-                                         return
-                                     for new_task in maybe_overridden_result:
-                                         self.active_tasks[new_task.task_id] = new_task
-                                     new_task_ids = {t.task_id for t in maybe_overridden_result}
-                                     for t in new_tasks:
-                                         # Same note as above about how this is theoretically reachable but we should
-                                         # just get coverage by unifying the code paths
-                                         if t.task_id not in new_task_ids:  # pragma: no cover
-                                             await self._finish_task(t.task_id, t.node_id)
-                                     self._handle_execution_request(maybe_overridden_result)
-                 except GeneratorExit:
-                     self._task_group.cancel_scope.cancel()
-                     return
+                                         break
+
+                                 if should_skip:
+                                     continue
+
+                                 self.active_reducers.pop(
+                                     (join_id, fork_run_id)
+                                 )  # we're handling it now, so we can pop it
+                                 join_node = self.graph.nodes[join_id]
+                                 assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                 new_tasks = self._handle_non_fork_edges(
+                                     join_node, join_state.current, join_state.downstream_fork_stack
+                                 )
+                                 maybe_overridden_result = yield new_tasks
+                                 if isinstance(maybe_overridden_result, EndMarker):  # pragma: no cover
+                                     # This is theoretically reachable but it would be awkward.
+                                     # Probably a better way to get coverage here would be to unify the code pat
+                                     # with the other `if isinstance(maybe_overridden_result, EndMarker):`
+                                     self.task_group.cancel_scope.cancel()
+                                     return
+                                 for new_task in maybe_overridden_result:
+                                     self.active_tasks[new_task.task_id] = new_task
+                                 new_task_ids = {t.task_id for t in maybe_overridden_result}
+                                 for t in new_tasks:
+                                     # Same note as above about how this is theoretically reachable but we should
+                                     # just get coverage by unifying the code paths
+                                     if t.task_id not in new_task_ids:  # pragma: no cover
+                                         await self._finish_task(t.task_id)
+                                 self._handle_execution_request(maybe_overridden_result)
+             except GeneratorExit:
+                 self.task_group.cancel_scope.cancel()
+                 return

          raise RuntimeError(  # pragma: no cover
              'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
          )

-     async def _finish_task(self, task_id: TaskID, node_id: str) -> None:
+     async def _finish_task(self, task_id: TaskID) -> None:
          # node_id is just included for debugging right now
          scope = self.cancel_scopes.pop(task_id, None)
          if scope is not None:
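At its core, the rewritten iter_graph is a fan-out/fan-in loop over an anyio memory object stream: node tasks started on the shared task group send results into iter_stream_sender, and the single consumer loop receives from iter_stream_receiver until no active tasks or reducers remain. Stripped to its essentials (toy payloads standing in for graph node results):

import anyio

async def main() -> None:
    send, receive = anyio.create_memory_object_stream[int]()

    async def worker(n: int) -> None:
        await send.send(n * n)  # a node reporting its result to the consumer loop

    async with anyio.create_task_group() as tg:
        for n in range(3):
            tg.start_soon(worker, n)  # fan out
        async with send, receive:
            results = [await receive.receive() for _ in range(3)]  # fan in
    print(sorted(results))  # [0, 1, 4]

anyio.run(main)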
@@ -770,12 +821,12 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
          fork_stack: ForkStack,
      ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
          if isinstance(next_node, StepNode):
-             return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
+             return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())]
          elif isinstance(next_node, JoinNode):
              return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
          elif isinstance(next_node, BaseNode):
              node_step = NodeStep(next_node.__class__)
-             return [GraphTask(node_step.id, next_node, fork_stack)]
+             return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())]
          elif isinstance(next_node, End):
              return EndMarker(next_node.data)
          else:
@@ -809,7 +860,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
              'These markers should be removed from paths during graph building'
          )
          if isinstance(item, DestinationMarker):
-             return [GraphTask(item.destination_id, inputs, fork_stack)]
+             return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())]
          elif isinstance(item, TransformMarker):
              inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
              return self._handle_path(path.next_path, inputs, fork_stack)
841
892
  ) # this should have already been ensured during graph building
842
893
 
843
894
  new_tasks: list[GraphTask] = []
844
- node_run_id = NodeRunID(str(uuid.uuid4()))
895
+ node_run_id = self.get_next_node_run_id()
845
896
  if node.is_map:
846
897
  # If the map specifies a downstream join id, eagerly create a join state for it
847
898
  if (join_id := node.downstream_join_id) is not None:
@@ -898,7 +949,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
898
949
  else:
899
950
  pass
900
951
  for task_id in task_ids_to_cancel:
901
- await self._finish_task(task_id, 'sibling')
952
+ await self._finish_task(task_id)
902
953
 
903
954
 
904
955
  def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]:
@@ -24,9 +24,6 @@ JoinID = NodeID
  ForkID = NodeID
  """Alias for NodeId when referring to fork nodes."""

- GraphRunID = NewType('GraphRunID', str)
- """Unique identifier for a complete graph execution run."""
-
  TaskID = NewType('TaskID', str)
  """Unique identifier for a task within the graph execution."""

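The id_types.py change drops GraphRunID, now that run IDs are minted per GraphRun. The aliases that remain are typing.NewType wrappers: distinct types to a checker, but plain str at runtime, so minting an ID costs nothing. A quick illustration:

from typing import NewType

TaskID = NewType('TaskID', str)
NodeRunID = NewType('NodeRunID', str)

def finish(task_id: TaskID) -> None:
    print(f'finishing {task_id}')

finish(TaskID('task:0'))    # OK
finish(NodeRunID('run:0'))  # flagged by a type checker, though identical at runtime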