pydantic-graph 1.10.0__tar.gz → 1.11.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/PKG-INFO +1 -1
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/graph.py +201 -128
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/graph_builder.py +36 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/.gitignore +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/LICENSE +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/README.md +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/__init__.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/_utils.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/__init__.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/decision.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/id_types.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/join.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/mermaid.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/node.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/node_types.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/parent_forks.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/paths.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/step.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/beta/util.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/exceptions.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/graph.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/mermaid.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/nodes.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/persistence/__init__.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/persistence/_utils.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/persistence/file.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/persistence/in_mem.py +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pydantic_graph/py.typed +0 -0
- {pydantic_graph-1.10.0 → pydantic_graph-1.11.1}/pyproject.toml +0 -0
|
@@ -10,7 +10,7 @@ from __future__ import annotations as _annotations
|
|
|
10
10
|
import sys
|
|
11
11
|
import uuid
|
|
12
12
|
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
|
|
13
|
-
from contextlib import AbstractContextManager, ExitStack, asynccontextmanager, contextmanager
|
|
13
|
+
from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
|
|
14
14
|
from dataclasses import dataclass, field
|
|
15
15
|
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
|
|
16
16
|
|
|
@@ -148,6 +148,12 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
|
|
|
148
148
|
parent_forks: dict[JoinID, ParentFork[NodeID]]
|
|
149
149
|
"""Parent fork information for each join node."""
|
|
150
150
|
|
|
151
|
+
intermediate_join_nodes: dict[JoinID, set[JoinID]]
|
|
152
|
+
"""For each join, the set of other joins that appear between it and its parent fork.
|
|
153
|
+
|
|
154
|
+
Used to determine which joins are "final" (have no other joins as intermediates) and
|
|
155
|
+
which joins should preserve fork stacks when proceeding downstream."""
|
|
156
|
+
|
|
151
157
|
def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]:
|
|
152
158
|
"""Get the parent fork information for a join node.
|
|
153
159
|
|
|
@@ -165,6 +171,24 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
|
|
|
165
171
|
raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)')
|
|
166
172
|
return result
|
|
167
173
|
|
|
174
|
+
def is_final_join(self, join_id: JoinID) -> bool:
|
|
175
|
+
"""Check if a join is 'final' (has no downstream joins with the same parent fork).
|
|
176
|
+
|
|
177
|
+
A join is non-final if it appears as an intermediate node for another join
|
|
178
|
+
with the same parent fork.
|
|
179
|
+
|
|
180
|
+
Args:
|
|
181
|
+
join_id: The ID of the join node
|
|
182
|
+
|
|
183
|
+
Returns:
|
|
184
|
+
True if the join is final, False if it's non-final
|
|
185
|
+
"""
|
|
186
|
+
# Check if this join appears in any other join's intermediate_join_nodes
|
|
187
|
+
for intermediate_joins in self.intermediate_join_nodes.values():
|
|
188
|
+
if join_id in intermediate_joins:
|
|
189
|
+
return False
|
|
190
|
+
return True
|
|
191
|
+
|
|
168
192
|
async def run(
|
|
169
193
|
self,
|
|
170
194
|
*,
|
|
@@ -357,18 +381,32 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
357
381
|
run_id = GraphRunID(str(uuid.uuid4()))
|
|
358
382
|
initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
|
|
359
383
|
self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
|
|
360
|
-
self.
|
|
384
|
+
self._iterator_task_group = create_task_group()
|
|
385
|
+
self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
|
|
386
|
+
self.graph, self.state, self.deps, self._iterator_task_group
|
|
387
|
+
)
|
|
361
388
|
self._iterator = self._iterator_instance.iter_graph(self._first_task)
|
|
362
389
|
|
|
363
390
|
self.__traceparent = traceparent
|
|
391
|
+
self._async_exit_stack = AsyncExitStack()
|
|
364
392
|
|
|
365
393
|
async def __aenter__(self):
|
|
394
|
+
self._async_exit_stack.enter_context(_unwrap_exception_groups())
|
|
395
|
+
await self._async_exit_stack.enter_async_context(self._iterator_task_group)
|
|
396
|
+
await self._async_exit_stack.enter_async_context(self._iterator_context())
|
|
366
397
|
return self
|
|
367
398
|
|
|
368
399
|
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
|
|
369
|
-
self.
|
|
370
|
-
|
|
371
|
-
|
|
400
|
+
await self._async_exit_stack.__aexit__(exc_type, exc_val, exc_tb)
|
|
401
|
+
|
|
402
|
+
@asynccontextmanager
|
|
403
|
+
async def _iterator_context(self):
|
|
404
|
+
try:
|
|
405
|
+
yield
|
|
406
|
+
finally:
|
|
407
|
+
self._iterator_instance.iter_stream_sender.close()
|
|
408
|
+
self._iterator_instance.iter_stream_receiver.close()
|
|
409
|
+
await self._iterator.aclose()
|
|
372
410
|
|
|
373
411
|
@overload
|
|
374
412
|
def _traceparent(self, *, required: Literal[False]) -> str | None: ...
|
|
@@ -471,13 +509,13 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
471
509
|
graph: Graph[StateT, DepsT, Any, OutputT]
|
|
472
510
|
state: StateT
|
|
473
511
|
deps: DepsT
|
|
512
|
+
task_group: TaskGroup
|
|
474
513
|
|
|
475
514
|
cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
|
|
476
515
|
active_tasks: dict[TaskID, GraphTask] = field(init=False)
|
|
477
516
|
active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = field(init=False)
|
|
478
517
|
iter_stream_sender: MemoryObjectSendStream[_GraphTaskResult] = field(init=False)
|
|
479
518
|
iter_stream_receiver: MemoryObjectReceiveStream[_GraphTaskResult] = field(init=False)
|
|
480
|
-
_task_group: TaskGroup | None = field(init=False)
|
|
481
519
|
|
|
482
520
|
def __post_init__(self):
|
|
483
521
|
self.cancel_scopes = {}
|
|
@@ -485,142 +523,177 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
485
523
|
self.active_reducers = {}
|
|
486
524
|
self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
|
|
487
525
|
|
|
488
|
-
@property
|
|
489
|
-
def task_group(self) -> TaskGroup:
|
|
490
|
-
if self._task_group is None:
|
|
491
|
-
raise RuntimeError("This graph iterator hasn't been started") # pragma: no cover
|
|
492
|
-
return self._task_group
|
|
493
|
-
|
|
494
526
|
async def iter_graph( # noqa C901
|
|
495
527
|
self, first_task: GraphTask
|
|
496
528
|
) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]:
|
|
497
|
-
with
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
529
|
+
async with self.iter_stream_sender:
|
|
530
|
+
try:
|
|
531
|
+
# Fire off the first task
|
|
532
|
+
self.active_tasks[first_task.task_id] = first_task
|
|
533
|
+
self._handle_execution_request([first_task])
|
|
534
|
+
|
|
535
|
+
# Handle task results
|
|
536
|
+
async with self.iter_stream_receiver:
|
|
537
|
+
while self.active_tasks or self.active_reducers:
|
|
538
|
+
async for task_result in self.iter_stream_receiver: # pragma: no branch
|
|
539
|
+
if isinstance(task_result.result, JoinItem):
|
|
540
|
+
maybe_overridden_result = task_result.result
|
|
541
|
+
else:
|
|
542
|
+
maybe_overridden_result = yield task_result.result
|
|
543
|
+
if isinstance(maybe_overridden_result, EndMarker):
|
|
544
|
+
# If we got an end marker, this task is definitely done, and we're ready to
|
|
545
|
+
# start cleaning everything up
|
|
546
|
+
await self._finish_task(task_result.source.task_id)
|
|
547
|
+
if self.active_tasks:
|
|
548
|
+
# Cancel the remaining tasks
|
|
513
549
|
self.task_group.cancel_scope.cancel()
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
550
|
+
return
|
|
551
|
+
elif isinstance(maybe_overridden_result, JoinItem):
|
|
552
|
+
result = maybe_overridden_result
|
|
553
|
+
parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
|
|
554
|
+
for i, x in enumerate(result.fork_stack[::-1]):
|
|
555
|
+
if x.fork_id == parent_fork_id:
|
|
556
|
+
# For non-final joins (those that are intermediate nodes of other joins),
|
|
557
|
+
# preserve the fork stack so downstream joins can still associate with the same fork run
|
|
558
|
+
if self.graph.is_final_join(result.join_id):
|
|
559
|
+
# Final join: remove the parent fork from the stack
|
|
520
560
|
downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
if join_state.cancelled_sibling_tasks:
|
|
537
|
-
await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
|
|
538
|
-
if task_result.source_is_finished: # pragma: no branch
|
|
539
|
-
await self._finish_task(task_result.source.task_id)
|
|
540
|
-
else:
|
|
541
|
-
for new_task in maybe_overridden_result:
|
|
542
|
-
self.active_tasks[new_task.task_id] = new_task
|
|
543
|
-
if task_result.source_is_finished:
|
|
544
|
-
await self._finish_task(task_result.source.task_id)
|
|
545
|
-
|
|
546
|
-
tasks_by_id_values = list(self.active_tasks.values())
|
|
547
|
-
join_tasks: list[GraphTask] = []
|
|
548
|
-
|
|
549
|
-
for join_id, fork_run_id in self._get_completed_fork_runs(
|
|
550
|
-
task_result.source, tasks_by_id_values
|
|
551
|
-
):
|
|
552
|
-
join_state = self.active_reducers.pop((join_id, fork_run_id))
|
|
553
|
-
join_node = self.graph.nodes[join_id]
|
|
554
|
-
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
|
|
555
|
-
new_tasks = self._handle_non_fork_edges(
|
|
556
|
-
join_node, join_state.current, join_state.downstream_fork_stack
|
|
557
|
-
)
|
|
558
|
-
join_tasks.extend(new_tasks)
|
|
559
|
-
if join_tasks:
|
|
560
|
-
for new_task in join_tasks:
|
|
561
|
-
self.active_tasks[new_task.task_id] = new_task
|
|
562
|
-
self._handle_execution_request(join_tasks)
|
|
563
|
-
|
|
564
|
-
if isinstance(maybe_overridden_result, Sequence):
|
|
565
|
-
if isinstance(task_result.result, Sequence):
|
|
566
|
-
new_task_ids = {t.task_id for t in maybe_overridden_result}
|
|
567
|
-
for t in task_result.result:
|
|
568
|
-
if t.task_id not in new_task_ids:
|
|
569
|
-
await self._finish_task(t.task_id)
|
|
570
|
-
self._handle_execution_request(maybe_overridden_result)
|
|
571
|
-
|
|
572
|
-
if not self.active_tasks:
|
|
573
|
-
# if there are no active tasks, we'll be waiting forever for the next result..
|
|
574
|
-
break
|
|
575
|
-
|
|
576
|
-
if self.active_reducers: # pragma: no branch
|
|
577
|
-
# In this case, there are no pending tasks. We can therefore finalize all active reducers whose
|
|
578
|
-
# downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the
|
|
579
|
-
# deeper reducer could produce new tasks in the "prefix" reducer.)
|
|
580
|
-
active_fork_stacks = [
|
|
581
|
-
join_state.downstream_fork_stack for join_state in self.active_reducers.values()
|
|
582
|
-
]
|
|
583
|
-
for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
|
|
584
|
-
fork_stack = join_state.downstream_fork_stack
|
|
585
|
-
if any(
|
|
586
|
-
len(afs) > len(fork_stack) and fork_stack == afs[: len(fork_stack)]
|
|
587
|
-
for afs in active_fork_stacks
|
|
588
|
-
):
|
|
589
|
-
# this join_state is a strict prefix for one of the other active join_states
|
|
590
|
-
continue # pragma: no cover # It's difficult to cover this
|
|
591
|
-
self.active_reducers.pop(
|
|
592
|
-
(join_id, fork_run_id)
|
|
593
|
-
) # we're handling it now, so we can pop it
|
|
594
|
-
join_node = self.graph.nodes[join_id]
|
|
595
|
-
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
|
|
596
|
-
new_tasks = self._handle_non_fork_edges(
|
|
597
|
-
join_node, join_state.current, join_state.downstream_fork_stack
|
|
561
|
+
else:
|
|
562
|
+
# Non-final join: preserve the fork stack
|
|
563
|
+
downstream_fork_stack = result.fork_stack
|
|
564
|
+
fork_run_id = x.node_run_id
|
|
565
|
+
break
|
|
566
|
+
else: # pragma: no cover
|
|
567
|
+
raise RuntimeError('Parent fork run not found')
|
|
568
|
+
|
|
569
|
+
join_node = self.graph.nodes[result.join_id]
|
|
570
|
+
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
|
|
571
|
+
join_state = self.active_reducers.get((result.join_id, fork_run_id))
|
|
572
|
+
if join_state is None:
|
|
573
|
+
current = join_node.initial_factory()
|
|
574
|
+
join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
|
|
575
|
+
current, downstream_fork_stack
|
|
598
576
|
)
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
577
|
+
context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
|
|
578
|
+
join_state.current = join_node.reduce(context, join_state.current, result.inputs)
|
|
579
|
+
if join_state.cancelled_sibling_tasks:
|
|
580
|
+
await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
|
|
581
|
+
else:
|
|
582
|
+
for new_task in maybe_overridden_result:
|
|
583
|
+
self.active_tasks[new_task.task_id] = new_task
|
|
584
|
+
|
|
585
|
+
tasks_by_id_values = list(self.active_tasks.values())
|
|
586
|
+
join_tasks: list[GraphTask] = []
|
|
587
|
+
|
|
588
|
+
for join_id, fork_run_id in self._get_completed_fork_runs(
|
|
589
|
+
task_result.source, tasks_by_id_values
|
|
590
|
+
):
|
|
591
|
+
join_state = self.active_reducers.pop((join_id, fork_run_id))
|
|
592
|
+
join_node = self.graph.nodes[join_id]
|
|
593
|
+
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
|
|
594
|
+
new_tasks = self._handle_non_fork_edges(
|
|
595
|
+
join_node, join_state.current, join_state.downstream_fork_stack
|
|
596
|
+
)
|
|
597
|
+
join_tasks.extend(new_tasks)
|
|
598
|
+
if join_tasks:
|
|
599
|
+
for new_task in join_tasks:
|
|
600
|
+
self.active_tasks[new_task.task_id] = new_task
|
|
601
|
+
self._handle_execution_request(join_tasks)
|
|
602
|
+
|
|
603
|
+
if isinstance(maybe_overridden_result, Sequence):
|
|
604
|
+
if isinstance(task_result.result, Sequence):
|
|
608
605
|
new_task_ids = {t.task_id for t in maybe_overridden_result}
|
|
609
|
-
for t in
|
|
610
|
-
|
|
611
|
-
# just get coverage by unifying the code paths
|
|
612
|
-
if t.task_id not in new_task_ids: # pragma: no cover
|
|
606
|
+
for t in task_result.result:
|
|
607
|
+
if t.task_id not in new_task_ids:
|
|
613
608
|
await self._finish_task(t.task_id)
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
609
|
+
self._handle_execution_request(maybe_overridden_result)
|
|
610
|
+
|
|
611
|
+
if task_result.source_is_finished:
|
|
612
|
+
await self._finish_task(task_result.source.task_id)
|
|
613
|
+
|
|
614
|
+
if not self.active_tasks:
|
|
615
|
+
# if there are no active tasks, we'll be waiting forever for the next result..
|
|
616
|
+
break
|
|
617
|
+
|
|
618
|
+
if self.active_reducers: # pragma: no branch
|
|
619
|
+
# In this case, there are no pending tasks. We can therefore finalize all active reducers
|
|
620
|
+
# that don't have intermediate joins which are also active reducers. If a join J2 has an
|
|
621
|
+
# intermediate join J1 that shares the same parent fork run, we must finalize J1 first
|
|
622
|
+
# because it might produce items that feed into J2.
|
|
623
|
+
for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
|
|
624
|
+
# Check if this join has any intermediate joins that are also active reducers
|
|
625
|
+
should_skip = False
|
|
626
|
+
intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
|
|
627
|
+
|
|
628
|
+
# Get the parent fork for this join to use for comparison
|
|
629
|
+
join_parent_fork = self.graph.get_parent_fork(join_id)
|
|
630
|
+
|
|
631
|
+
for intermediate_join_id in intermediate_joins:
|
|
632
|
+
# Check if the intermediate join is also an active reducer with matching fork run
|
|
633
|
+
for (other_join_id, _), other_join_state in self.active_reducers.items():
|
|
634
|
+
if other_join_id == intermediate_join_id:
|
|
635
|
+
# Check if they share the same fork run for this join's parent fork
|
|
636
|
+
# by finding the parent fork's node_run_id in both fork stacks
|
|
637
|
+
join_parent_fork_run_id = None
|
|
638
|
+
other_parent_fork_run_id = None
|
|
639
|
+
|
|
640
|
+
for fsi in join_state.downstream_fork_stack: # pragma: no branch
|
|
641
|
+
if fsi.fork_id == join_parent_fork.fork_id:
|
|
642
|
+
join_parent_fork_run_id = fsi.node_run_id
|
|
643
|
+
break
|
|
644
|
+
|
|
645
|
+
for fsi in other_join_state.downstream_fork_stack: # pragma: no branch
|
|
646
|
+
if fsi.fork_id == join_parent_fork.fork_id:
|
|
647
|
+
other_parent_fork_run_id = fsi.node_run_id
|
|
648
|
+
break
|
|
649
|
+
|
|
650
|
+
if (
|
|
651
|
+
join_parent_fork_run_id
|
|
652
|
+
and other_parent_fork_run_id
|
|
653
|
+
and join_parent_fork_run_id == other_parent_fork_run_id
|
|
654
|
+
): # pragma: no branch
|
|
655
|
+
should_skip = True
|
|
656
|
+
break
|
|
657
|
+
if should_skip:
|
|
658
|
+
break
|
|
659
|
+
|
|
660
|
+
if should_skip:
|
|
661
|
+
continue
|
|
662
|
+
|
|
663
|
+
self.active_reducers.pop(
|
|
664
|
+
(join_id, fork_run_id)
|
|
665
|
+
) # we're handling it now, so we can pop it
|
|
666
|
+
join_node = self.graph.nodes[join_id]
|
|
667
|
+
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
|
|
668
|
+
new_tasks = self._handle_non_fork_edges(
|
|
669
|
+
join_node, join_state.current, join_state.downstream_fork_stack
|
|
670
|
+
)
|
|
671
|
+
maybe_overridden_result = yield new_tasks
|
|
672
|
+
if isinstance(maybe_overridden_result, EndMarker): # pragma: no cover
|
|
673
|
+
# This is theoretically reachable but it would be awkward.
|
|
674
|
+
# Probably a better way to get coverage here would be to unify the code pat
|
|
675
|
+
# with the other `if isinstance(maybe_overridden_result, EndMarker):`
|
|
676
|
+
self.task_group.cancel_scope.cancel()
|
|
677
|
+
return
|
|
678
|
+
for new_task in maybe_overridden_result:
|
|
679
|
+
self.active_tasks[new_task.task_id] = new_task
|
|
680
|
+
new_task_ids = {t.task_id for t in maybe_overridden_result}
|
|
681
|
+
for t in new_tasks:
|
|
682
|
+
# Same note as above about how this is theoretically reachable but we should
|
|
683
|
+
# just get coverage by unifying the code paths
|
|
684
|
+
if t.task_id not in new_task_ids: # pragma: no cover
|
|
685
|
+
await self._finish_task(t.task_id)
|
|
686
|
+
self._handle_execution_request(maybe_overridden_result)
|
|
687
|
+
except GeneratorExit:
|
|
688
|
+
self.task_group.cancel_scope.cancel()
|
|
689
|
+
return
|
|
618
690
|
|
|
619
691
|
raise RuntimeError( # pragma: no cover
|
|
620
692
|
'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
|
|
621
693
|
)
|
|
622
694
|
|
|
623
695
|
async def _finish_task(self, task_id: TaskID) -> None:
|
|
696
|
+
# node_id is just included for debugging right now
|
|
624
697
|
scope = self.cancel_scopes.pop(task_id, None)
|
|
625
698
|
if scope is not None:
|
|
626
699
|
scope.cancel()
|
|
@@ -658,6 +658,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
|
|
|
658
658
|
if validate_graph_structure:
|
|
659
659
|
_validate_graph_structure(nodes, edges_by_source)
|
|
660
660
|
parent_forks = _collect_dominating_forks(nodes, edges_by_source)
|
|
661
|
+
intermediate_join_nodes = _compute_intermediate_join_nodes(nodes, parent_forks)
|
|
661
662
|
|
|
662
663
|
return Graph[StateT, DepsT, GraphInputT, GraphOutputT](
|
|
663
664
|
name=self.name,
|
|
@@ -668,6 +669,7 @@ class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
|
|
|
668
669
|
nodes=nodes,
|
|
669
670
|
edges_by_source=edges_by_source,
|
|
670
671
|
parent_forks=parent_forks,
|
|
672
|
+
intermediate_join_nodes=intermediate_join_nodes,
|
|
671
673
|
auto_instrument=self.auto_instrument,
|
|
672
674
|
)
|
|
673
675
|
|
|
@@ -948,6 +950,40 @@ Join {join.id!r} in this graph has no dominating fork in this graph.""")
|
|
|
948
950
|
return dominating_forks
|
|
949
951
|
|
|
950
952
|
|
|
953
|
+
def _compute_intermediate_join_nodes(
|
|
954
|
+
nodes: dict[NodeID, AnyNode], parent_forks: dict[JoinID, ParentFork[NodeID]]
|
|
955
|
+
) -> dict[JoinID, set[JoinID]]:
|
|
956
|
+
"""Compute which joins have other joins as intermediate nodes.
|
|
957
|
+
|
|
958
|
+
A join J1 is an intermediate node of join J2 if J1 appears in J2's intermediate_nodes
|
|
959
|
+
(as computed relative to J2's parent fork).
|
|
960
|
+
|
|
961
|
+
This information is used to determine:
|
|
962
|
+
1. Which joins are "final" (have no other joins in their intermediate_nodes)
|
|
963
|
+
2. When selecting which reducer to proceed with when there are no active tasks
|
|
964
|
+
|
|
965
|
+
Args:
|
|
966
|
+
nodes: All nodes in the graph
|
|
967
|
+
parent_forks: Parent fork information for each join
|
|
968
|
+
|
|
969
|
+
Returns:
|
|
970
|
+
A mapping from each join to the set of joins that are intermediate to it
|
|
971
|
+
"""
|
|
972
|
+
intermediate_join_nodes: dict[JoinID, set[JoinID]] = {}
|
|
973
|
+
|
|
974
|
+
for join_id, parent_fork in parent_forks.items():
|
|
975
|
+
intermediate_joins = set[JoinID]()
|
|
976
|
+
for intermediate_node_id in parent_fork.intermediate_nodes:
|
|
977
|
+
# Check if this intermediate node is also a join
|
|
978
|
+
intermediate_node = nodes.get(intermediate_node_id)
|
|
979
|
+
if isinstance(intermediate_node, Join):
|
|
980
|
+
# Add it regardless of whether it has the same parent fork
|
|
981
|
+
intermediate_joins.add(JoinID(intermediate_node_id))
|
|
982
|
+
intermediate_join_nodes[join_id] = intermediate_joins
|
|
983
|
+
|
|
984
|
+
return intermediate_join_nodes
|
|
985
|
+
|
|
986
|
+
|
|
951
987
|
def _replace_placeholder_node_ids(nodes: dict[NodeID, AnyNode], edges_by_source: dict[NodeID, list[Path]]):
|
|
952
988
|
node_id_remapping = _build_placeholder_node_id_remapping(nodes)
|
|
953
989
|
replaced_nodes = {
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|