pydantic-graph 1.11.0__tar.gz → 1.22.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/.gitignore +2 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/PKG-INFO +1 -1
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/graph.py +234 -183
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/id_types.py +0 -3
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/LICENSE +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/README.md +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/__init__.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/_utils.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/__init__.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/decision.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/graph_builder.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/join.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/mermaid.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/node.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/node_types.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/parent_forks.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/paths.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/step.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/beta/util.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/exceptions.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/graph.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/mermaid.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/nodes.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/persistence/__init__.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/persistence/_utils.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/persistence/file.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/persistence/in_mem.py +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pydantic_graph/py.typed +0 -0
- {pydantic_graph-1.11.0 → pydantic_graph-1.22.0}/pyproject.toml +0 -0
|
@@ -8,9 +8,8 @@ the graph-based workflow system.
|
|
|
8
8
|
from __future__ import annotations as _annotations
|
|
9
9
|
|
|
10
10
|
import sys
|
|
11
|
-
import
|
|
12
|
-
from
|
|
13
|
-
from contextlib import AbstractContextManager, ExitStack, asynccontextmanager, contextmanager
|
|
11
|
+
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
|
|
12
|
+
from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
|
|
14
13
|
from dataclasses import dataclass, field
|
|
15
14
|
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
|
|
16
15
|
|
|
@@ -22,7 +21,7 @@ from typing_extensions import TypeVar, assert_never
|
|
|
22
21
|
from pydantic_graph import exceptions
|
|
23
22
|
from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
|
|
24
23
|
from pydantic_graph.beta.decision import Decision
|
|
25
|
-
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem,
|
|
24
|
+
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
|
|
26
25
|
from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
|
|
27
26
|
from pydantic_graph.beta.node import (
|
|
28
27
|
EndNode,
|
|
@@ -306,14 +305,13 @@ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
|
|
|
306
305
|
|
|
307
306
|
|
|
308
307
|
@dataclass
|
|
309
|
-
class
|
|
310
|
-
"""A
|
|
308
|
+
class GraphTaskRequest:
|
|
309
|
+
"""A request to run a task representing the execution of a node in the graph.
|
|
311
310
|
|
|
312
|
-
|
|
311
|
+
GraphTaskRequest encapsulates all the information needed to execute a specific
|
|
313
312
|
node, including its inputs and the fork context it's executing within.
|
|
314
313
|
"""
|
|
315
314
|
|
|
316
|
-
# With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
|
|
317
315
|
node_id: NodeID
|
|
318
316
|
"""The ID of the node to execute."""
|
|
319
317
|
|
|
@@ -326,9 +324,26 @@ class GraphTask:
|
|
|
326
324
|
Used by the GraphRun to decide when to proceed through joins.
|
|
327
325
|
"""
|
|
328
326
|
|
|
329
|
-
|
|
327
|
+
|
|
328
|
+
@dataclass
|
|
329
|
+
class GraphTask(GraphTaskRequest):
|
|
330
|
+
"""A task representing the execution of a node in the graph.
|
|
331
|
+
|
|
332
|
+
GraphTask encapsulates all the information needed to execute a specific
|
|
333
|
+
node, including its inputs and the fork context it's executing within,
|
|
334
|
+
and has a unique ID to identify the task within the graph run.
|
|
335
|
+
"""
|
|
336
|
+
|
|
337
|
+
task_id: TaskID = field(repr=False)
|
|
330
338
|
"""Unique identifier for this task."""
|
|
331
339
|
|
|
340
|
+
@staticmethod
|
|
341
|
+
def from_request(request: GraphTaskRequest, get_task_id: Callable[[], TaskID]) -> GraphTask:
|
|
342
|
+
# Don't call the get_task_id callable, this is already a task
|
|
343
|
+
if isinstance(request, GraphTask):
|
|
344
|
+
return request
|
|
345
|
+
return GraphTask(request.node_id, request.inputs, request.fork_stack, get_task_id())
|
|
346
|
+
|
|
332
347
|
|
|
333
348
|
class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
334
349
|
"""A single execution instance of a graph.
|
|
@@ -378,21 +393,43 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
378
393
|
self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
|
|
379
394
|
"""The next item to be processed."""
|
|
380
395
|
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
self.
|
|
396
|
+
self._next_task_id = 0
|
|
397
|
+
self._next_node_run_id = 0
|
|
398
|
+
initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),)
|
|
399
|
+
self._first_task = GraphTask(
|
|
400
|
+
node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id()
|
|
401
|
+
)
|
|
402
|
+
self._iterator_task_group = create_task_group()
|
|
403
|
+
self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
|
|
404
|
+
self.graph,
|
|
405
|
+
self.state,
|
|
406
|
+
self.deps,
|
|
407
|
+
self._iterator_task_group,
|
|
408
|
+
self._get_next_node_run_id,
|
|
409
|
+
self._get_next_task_id,
|
|
410
|
+
)
|
|
385
411
|
self._iterator = self._iterator_instance.iter_graph(self._first_task)
|
|
386
412
|
|
|
387
413
|
self.__traceparent = traceparent
|
|
414
|
+
self._async_exit_stack = AsyncExitStack()
|
|
388
415
|
|
|
389
416
|
async def __aenter__(self):
|
|
417
|
+
self._async_exit_stack.enter_context(_unwrap_exception_groups())
|
|
418
|
+
await self._async_exit_stack.enter_async_context(self._iterator_task_group)
|
|
419
|
+
await self._async_exit_stack.enter_async_context(self._iterator_context())
|
|
390
420
|
return self
|
|
391
421
|
|
|
392
422
|
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
|
|
393
|
-
self.
|
|
394
|
-
|
|
395
|
-
|
|
423
|
+
await self._async_exit_stack.__aexit__(exc_type, exc_val, exc_tb)
|
|
424
|
+
|
|
425
|
+
@asynccontextmanager
|
|
426
|
+
async def _iterator_context(self):
|
|
427
|
+
try:
|
|
428
|
+
yield
|
|
429
|
+
finally:
|
|
430
|
+
self._iterator_instance.iter_stream_sender.close()
|
|
431
|
+
self._iterator_instance.iter_stream_receiver.close()
|
|
432
|
+
await self._iterator.aclose()
|
|
396
433
|
|
|
397
434
|
@overload
|
|
398
435
|
def _traceparent(self, *, required: Literal[False]) -> str | None: ...
|
|
@@ -435,7 +472,7 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
435
472
|
return self._next
|
|
436
473
|
|
|
437
474
|
async def next(
|
|
438
|
-
self, value: EndMarker[OutputT] | Sequence[
|
|
475
|
+
self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None
|
|
439
476
|
) -> EndMarker[OutputT] | Sequence[GraphTask]:
|
|
440
477
|
"""Advance the graph execution by one step.
|
|
441
478
|
|
|
@@ -453,7 +490,10 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
453
490
|
# if `next` is called before the `first_node` has run.
|
|
454
491
|
await anext(self)
|
|
455
492
|
if value is not None:
|
|
456
|
-
|
|
493
|
+
if isinstance(value, EndMarker):
|
|
494
|
+
self._next = value
|
|
495
|
+
else:
|
|
496
|
+
self._next = [GraphTask.from_request(gtr, self._get_next_task_id) for gtr in value]
|
|
457
497
|
return await anext(self)
|
|
458
498
|
|
|
459
499
|
@property
|
|
@@ -476,6 +516,16 @@ class GraphRun(Generic[StateT, DepsT, OutputT]):
|
|
|
476
516
|
return self._next.value
|
|
477
517
|
return None
|
|
478
518
|
|
|
519
|
+
def _get_next_task_id(self) -> TaskID:
|
|
520
|
+
next_id = TaskID(f'task:{self._next_task_id}')
|
|
521
|
+
self._next_task_id += 1
|
|
522
|
+
return next_id
|
|
523
|
+
|
|
524
|
+
def _get_next_node_run_id(self) -> NodeRunID:
|
|
525
|
+
next_id = NodeRunID(f'task:{self._next_node_run_id}')
|
|
526
|
+
self._next_node_run_id += 1
|
|
527
|
+
return next_id
|
|
528
|
+
|
|
479
529
|
|
|
480
530
|
@dataclass
|
|
481
531
|
class _GraphTaskAsyncIterable:
|
|
@@ -495,192 +545,193 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
495
545
|
graph: Graph[StateT, DepsT, Any, OutputT]
|
|
496
546
|
state: StateT
|
|
497
547
|
deps: DepsT
|
|
548
|
+
task_group: TaskGroup
|
|
549
|
+
get_next_node_run_id: Callable[[], NodeRunID]
|
|
550
|
+
get_next_task_id: Callable[[], TaskID]
|
|
498
551
|
|
|
499
552
|
cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
|
|
500
553
|
active_tasks: dict[TaskID, GraphTask] = field(init=False)
|
|
501
554
|
active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = field(init=False)
|
|
502
555
|
iter_stream_sender: MemoryObjectSendStream[_GraphTaskResult] = field(init=False)
|
|
503
556
|
iter_stream_receiver: MemoryObjectReceiveStream[_GraphTaskResult] = field(init=False)
|
|
504
|
-
_task_group: TaskGroup | None = field(init=False)
|
|
505
557
|
|
|
506
558
|
def __post_init__(self):
|
|
507
559
|
self.cancel_scopes = {}
|
|
508
560
|
self.active_tasks = {}
|
|
509
561
|
self.active_reducers = {}
|
|
510
562
|
self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
|
|
511
|
-
|
|
512
|
-
@property
|
|
513
|
-
def task_group(self) -> TaskGroup:
|
|
514
|
-
if self._task_group is None:
|
|
515
|
-
raise RuntimeError("This graph iterator hasn't been started") # pragma: no cover
|
|
516
|
-
return self._task_group
|
|
563
|
+
self._next_node_run_id = 1
|
|
517
564
|
|
|
518
565
|
async def iter_graph( # noqa C901
|
|
519
566
|
self, first_task: GraphTask
|
|
520
567
|
) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]:
|
|
521
|
-
with
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
568
|
+
async with self.iter_stream_sender:
|
|
569
|
+
try:
|
|
570
|
+
# Fire off the first task
|
|
571
|
+
self.active_tasks[first_task.task_id] = first_task
|
|
572
|
+
self._handle_execution_request([first_task])
|
|
573
|
+
|
|
574
|
+
# Handle task results
|
|
575
|
+
async with self.iter_stream_receiver:
|
|
576
|
+
while self.active_tasks or self.active_reducers:
|
|
577
|
+
async for task_result in self.iter_stream_receiver: # pragma: no branch
|
|
578
|
+
if isinstance(task_result.result, JoinItem):
|
|
579
|
+
maybe_overridden_result = task_result.result
|
|
580
|
+
else:
|
|
581
|
+
maybe_overridden_result = yield task_result.result
|
|
582
|
+
if isinstance(maybe_overridden_result, EndMarker):
|
|
583
|
+
# If we got an end marker, this task is definitely done, and we're ready to
|
|
584
|
+
# start cleaning everything up
|
|
585
|
+
await self._finish_task(task_result.source.task_id)
|
|
586
|
+
if self.active_tasks:
|
|
587
|
+
# Cancel the remaining tasks
|
|
537
588
|
self.task_group.cancel_scope.cancel()
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
)
|
|
565
|
-
context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
|
|
566
|
-
join_state.current = join_node.reduce(context, join_state.current, result.inputs)
|
|
567
|
-
if join_state.cancelled_sibling_tasks:
|
|
568
|
-
await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
|
|
569
|
-
else:
|
|
570
|
-
for new_task in maybe_overridden_result:
|
|
571
|
-
self.active_tasks[new_task.task_id] = new_task
|
|
572
|
-
|
|
573
|
-
tasks_by_id_values = list(self.active_tasks.values())
|
|
574
|
-
join_tasks: list[GraphTask] = []
|
|
575
|
-
|
|
576
|
-
for join_id, fork_run_id in self._get_completed_fork_runs(
|
|
577
|
-
task_result.source, tasks_by_id_values
|
|
578
|
-
):
|
|
579
|
-
join_state = self.active_reducers.pop((join_id, fork_run_id))
|
|
580
|
-
join_node = self.graph.nodes[join_id]
|
|
581
|
-
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
|
|
582
|
-
new_tasks = self._handle_non_fork_edges(
|
|
583
|
-
join_node, join_state.current, join_state.downstream_fork_stack
|
|
589
|
+
return
|
|
590
|
+
elif isinstance(maybe_overridden_result, JoinItem):
|
|
591
|
+
result = maybe_overridden_result
|
|
592
|
+
parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
|
|
593
|
+
for i, x in enumerate(result.fork_stack[::-1]):
|
|
594
|
+
if x.fork_id == parent_fork_id:
|
|
595
|
+
# For non-final joins (those that are intermediate nodes of other joins),
|
|
596
|
+
# preserve the fork stack so downstream joins can still associate with the same fork run
|
|
597
|
+
if self.graph.is_final_join(result.join_id):
|
|
598
|
+
# Final join: remove the parent fork from the stack
|
|
599
|
+
downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
|
|
600
|
+
else:
|
|
601
|
+
# Non-final join: preserve the fork stack
|
|
602
|
+
downstream_fork_stack = result.fork_stack
|
|
603
|
+
fork_run_id = x.node_run_id
|
|
604
|
+
break
|
|
605
|
+
else: # pragma: no cover
|
|
606
|
+
raise RuntimeError('Parent fork run not found')
|
|
607
|
+
|
|
608
|
+
join_node = self.graph.nodes[result.join_id]
|
|
609
|
+
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
|
|
610
|
+
join_state = self.active_reducers.get((result.join_id, fork_run_id))
|
|
611
|
+
if join_state is None:
|
|
612
|
+
current = join_node.initial_factory()
|
|
613
|
+
join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
|
|
614
|
+
current, downstream_fork_stack
|
|
584
615
|
)
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
if
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
616
|
+
context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
|
|
617
|
+
join_state.current = join_node.reduce(context, join_state.current, result.inputs)
|
|
618
|
+
if join_state.cancelled_sibling_tasks:
|
|
619
|
+
await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
|
|
620
|
+
else:
|
|
621
|
+
for new_task in maybe_overridden_result:
|
|
622
|
+
self.active_tasks[new_task.task_id] = new_task
|
|
623
|
+
|
|
624
|
+
tasks_by_id_values = list(self.active_tasks.values())
|
|
625
|
+
join_tasks: list[GraphTask] = []
|
|
626
|
+
|
|
627
|
+
for join_id, fork_run_id in self._get_completed_fork_runs(
|
|
628
|
+
task_result.source, tasks_by_id_values
|
|
629
|
+
):
|
|
630
|
+
join_state = self.active_reducers.pop((join_id, fork_run_id))
|
|
631
|
+
join_node = self.graph.nodes[join_id]
|
|
632
|
+
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
|
|
633
|
+
new_tasks = self._handle_non_fork_edges(
|
|
634
|
+
join_node, join_state.current, join_state.downstream_fork_stack
|
|
635
|
+
)
|
|
636
|
+
join_tasks.extend(new_tasks)
|
|
637
|
+
if join_tasks:
|
|
638
|
+
for new_task in join_tasks:
|
|
639
|
+
self.active_tasks[new_task.task_id] = new_task
|
|
640
|
+
self._handle_execution_request(join_tasks)
|
|
641
|
+
|
|
642
|
+
if isinstance(maybe_overridden_result, Sequence):
|
|
643
|
+
if isinstance(task_result.result, Sequence):
|
|
644
|
+
new_task_ids = {t.task_id for t in maybe_overridden_result}
|
|
645
|
+
for t in task_result.result:
|
|
646
|
+
if t.task_id not in new_task_ids:
|
|
647
|
+
await self._finish_task(t.task_id)
|
|
648
|
+
self._handle_execution_request(maybe_overridden_result)
|
|
649
|
+
|
|
650
|
+
if task_result.source_is_finished:
|
|
651
|
+
await self._finish_task(task_result.source.task_id)
|
|
652
|
+
|
|
653
|
+
if not self.active_tasks:
|
|
654
|
+
# if there are no active tasks, we'll be waiting forever for the next result.
|
|
655
|
+
break
|
|
656
|
+
|
|
657
|
+
if self.active_reducers: # pragma: no branch
|
|
658
|
+
# In this case, there are no pending tasks. We can therefore finalize all active reducers
|
|
659
|
+
# that don't have intermediate joins which are also active reducers. If a join J2 has an
|
|
660
|
+
# intermediate join J1 that shares the same parent fork run, we must finalize J1 first
|
|
661
|
+
# because it might produce items that feed into J2.
|
|
662
|
+
for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
|
|
663
|
+
# Check if this join has any intermediate joins that are also active reducers
|
|
664
|
+
should_skip = False
|
|
665
|
+
intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
|
|
666
|
+
|
|
667
|
+
# Get the parent fork for this join to use for comparison
|
|
668
|
+
join_parent_fork = self.graph.get_parent_fork(join_id)
|
|
669
|
+
|
|
670
|
+
for intermediate_join_id in intermediate_joins:
|
|
671
|
+
# Check if the intermediate join is also an active reducer with matching fork run
|
|
672
|
+
for (other_join_id, _), other_join_state in self.active_reducers.items():
|
|
673
|
+
if other_join_id == intermediate_join_id:
|
|
674
|
+
# Check if they share the same fork run for this join's parent fork
|
|
675
|
+
# by finding the parent fork's node_run_id in both fork stacks
|
|
676
|
+
join_parent_fork_run_id = None
|
|
677
|
+
other_parent_fork_run_id = None
|
|
678
|
+
|
|
679
|
+
for fsi in join_state.downstream_fork_stack: # pragma: no branch
|
|
680
|
+
if fsi.fork_id == join_parent_fork.fork_id:
|
|
681
|
+
join_parent_fork_run_id = fsi.node_run_id
|
|
682
|
+
break
|
|
683
|
+
|
|
684
|
+
for fsi in other_join_state.downstream_fork_stack: # pragma: no branch
|
|
685
|
+
if fsi.fork_id == join_parent_fork.fork_id:
|
|
686
|
+
other_parent_fork_run_id = fsi.node_run_id
|
|
644
687
|
break
|
|
645
|
-
if should_skip:
|
|
646
|
-
break
|
|
647
688
|
|
|
689
|
+
if (
|
|
690
|
+
join_parent_fork_run_id
|
|
691
|
+
and other_parent_fork_run_id
|
|
692
|
+
and join_parent_fork_run_id == other_parent_fork_run_id
|
|
693
|
+
): # pragma: no branch
|
|
694
|
+
should_skip = True
|
|
695
|
+
break
|
|
648
696
|
if should_skip:
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
697
|
+
break
|
|
698
|
+
|
|
699
|
+
if should_skip:
|
|
700
|
+
continue
|
|
701
|
+
|
|
702
|
+
self.active_reducers.pop(
|
|
703
|
+
(join_id, fork_run_id)
|
|
704
|
+
) # we're handling it now, so we can pop it
|
|
705
|
+
join_node = self.graph.nodes[join_id]
|
|
706
|
+
assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
|
|
707
|
+
new_tasks = self._handle_non_fork_edges(
|
|
708
|
+
join_node, join_state.current, join_state.downstream_fork_stack
|
|
709
|
+
)
|
|
710
|
+
maybe_overridden_result = yield new_tasks
|
|
711
|
+
if isinstance(maybe_overridden_result, EndMarker): # pragma: no cover
|
|
712
|
+
# This is theoretically reachable but it would be awkward.
|
|
713
|
+
# Probably a better way to get coverage here would be to unify the code path
|
|
714
|
+
# with the other `if isinstance(maybe_overridden_result, EndMarker):`
|
|
715
|
+
self.task_group.cancel_scope.cancel()
|
|
716
|
+
return
|
|
717
|
+
for new_task in maybe_overridden_result:
|
|
718
|
+
self.active_tasks[new_task.task_id] = new_task
|
|
719
|
+
new_task_ids = {t.task_id for t in maybe_overridden_result}
|
|
720
|
+
for t in new_tasks:
|
|
721
|
+
# Same note as above about how this is theoretically reachable but we should
|
|
722
|
+
# just get coverage by unifying the code paths
|
|
723
|
+
if t.task_id not in new_task_ids: # pragma: no cover
|
|
724
|
+
await self._finish_task(t.task_id)
|
|
725
|
+
self._handle_execution_request(maybe_overridden_result)
|
|
726
|
+
except GeneratorExit:
|
|
727
|
+
self.task_group.cancel_scope.cancel()
|
|
728
|
+
return
|
|
678
729
|
|
|
679
730
|
raise RuntimeError( # pragma: no cover
|
|
680
731
|
'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
|
|
681
732
|
)
|
|
682
733
|
|
|
683
|
-
async def _finish_task(self, task_id: TaskID
|
|
734
|
+
async def _finish_task(self, task_id: TaskID) -> None:
|
|
684
735
|
# node_id is just included for debugging right now
|
|
685
736
|
scope = self.cancel_scopes.pop(task_id, None)
|
|
686
737
|
if scope is not None:
|
|
@@ -770,12 +821,12 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
770
821
|
fork_stack: ForkStack,
|
|
771
822
|
) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
|
|
772
823
|
if isinstance(next_node, StepNode):
|
|
773
|
-
return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
|
|
824
|
+
return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())]
|
|
774
825
|
elif isinstance(next_node, JoinNode):
|
|
775
826
|
return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
|
|
776
827
|
elif isinstance(next_node, BaseNode):
|
|
777
828
|
node_step = NodeStep(next_node.__class__)
|
|
778
|
-
return [GraphTask(node_step.id, next_node, fork_stack)]
|
|
829
|
+
return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())]
|
|
779
830
|
elif isinstance(next_node, End):
|
|
780
831
|
return EndMarker(next_node.data)
|
|
781
832
|
else:
|
|
@@ -809,7 +860,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
809
860
|
'These markers should be removed from paths during graph building'
|
|
810
861
|
)
|
|
811
862
|
if isinstance(item, DestinationMarker):
|
|
812
|
-
return [GraphTask(item.destination_id, inputs, fork_stack)]
|
|
863
|
+
return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())]
|
|
813
864
|
elif isinstance(item, TransformMarker):
|
|
814
865
|
inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
|
|
815
866
|
return self._handle_path(path.next_path, inputs, fork_stack)
|
|
@@ -841,7 +892,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
841
892
|
) # this should have already been ensured during graph building
|
|
842
893
|
|
|
843
894
|
new_tasks: list[GraphTask] = []
|
|
844
|
-
node_run_id =
|
|
895
|
+
node_run_id = self.get_next_node_run_id()
|
|
845
896
|
if node.is_map:
|
|
846
897
|
# If the map specifies a downstream join id, eagerly create a join state for it
|
|
847
898
|
if (join_id := node.downstream_join_id) is not None:
|
|
@@ -898,7 +949,7 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
|
|
|
898
949
|
else:
|
|
899
950
|
pass
|
|
900
951
|
for task_id in task_ids_to_cancel:
|
|
901
|
-
await self._finish_task(task_id
|
|
952
|
+
await self._finish_task(task_id)
|
|
902
953
|
|
|
903
954
|
|
|
904
955
|
def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]:
|
|
@@ -24,9 +24,6 @@ JoinID = NodeID
|
|
|
24
24
|
ForkID = NodeID
|
|
25
25
|
"""Alias for NodeId when referring to fork nodes."""
|
|
26
26
|
|
|
27
|
-
GraphRunID = NewType('GraphRunID', str)
|
|
28
|
-
"""Unique identifier for a complete graph execution run."""
|
|
29
|
-
|
|
30
27
|
TaskID = NewType('TaskID', str)
|
|
31
28
|
"""Unique identifier for a task within the graph execution."""
|
|
32
29
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|