pydantic-graph 0.2.2__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,978 @@
+ """Core graph execution engine for the next version of the pydantic-graph library.
+
+ This module provides the main `Graph` class and `GraphRun` execution engine that
+ handles the orchestration of nodes, edges, and parallel execution paths in
+ the graph-based workflow system.
+ """
+
+ from __future__ import annotations as _annotations
+
+ import sys
+ from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
+ from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
+
+ from anyio import CancelScope, create_memory_object_stream, create_task_group
+ from anyio.abc import TaskGroup
+ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
+ from typing_extensions import TypeVar, assert_never
+
+ from pydantic_graph import exceptions
+ from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
+ from pydantic_graph.beta.decision import Decision
+ from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
+ from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
+ from pydantic_graph.beta.node import (
+     EndNode,
+     Fork,
+     StartNode,
+ )
+ from pydantic_graph.beta.node_types import AnyNode
+ from pydantic_graph.beta.parent_forks import ParentFork
+ from pydantic_graph.beta.paths import (
+     BroadcastMarker,
+     DestinationMarker,
+     LabelMarker,
+     MapMarker,
+     Path,
+     TransformMarker,
+ )
+ from pydantic_graph.beta.step import NodeStep, Step, StepContext, StepNode
+ from pydantic_graph.beta.util import unpack_type_expression
+ from pydantic_graph.nodes import BaseNode, End
+
+ if sys.version_info < (3, 11):
+     from exceptiongroup import BaseExceptionGroup as BaseExceptionGroup  # pragma: lax no cover
+ else:
+     BaseExceptionGroup = BaseExceptionGroup  # pragma: lax no cover
+
+ if TYPE_CHECKING:
+     from pydantic_graph.beta.mermaid import StateDiagramDirection
+
+
+ StateT = TypeVar('StateT', infer_variance=True)
+ """Type variable for graph state."""
+
+ DepsT = TypeVar('DepsT', infer_variance=True)
+ """Type variable for graph dependencies."""
+
+ InputT = TypeVar('InputT', infer_variance=True)
+ """Type variable for graph inputs."""
+
+ OutputT = TypeVar('OutputT', infer_variance=True)
+ """Type variable for graph outputs."""
+
+
+ @dataclass(init=False)
+ class EndMarker(Generic[OutputT]):
+     """A marker indicating the end of graph execution with a final value.
+
+     EndMarker is used internally to signal that the graph has completed
+     execution and carries the final output value.
+
+     Type Parameters:
+         OutputT: The type of the final output value
+     """
+
+     _value: OutputT
+     """The final output value from the graph execution."""
+
+     def __init__(self, value: OutputT):
+         # This manually-defined initializer is necessary due to https://github.com/python/mypy/issues/17623.
+         self._value = value
+
+     @property
+     def value(self) -> OutputT:
+         return self._value
+
+
+ @dataclass
+ class JoinItem:
+     """An item representing data flowing into a join operation.
+
+     JoinItem carries input data from a parallel execution path to a join
+     node, along with metadata about which execution 'fork' it originated from.
+     """
+
+     join_id: JoinID
+     """The ID of the join node this item is targeting."""
+
+     inputs: Any
+     """The input data for the join operation."""
+
+     fork_stack: ForkStack
+     """The stack of ForkStackItems that led to producing this join item."""
+
+
+ @dataclass(repr=False)
+ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
+     """A complete graph definition ready for execution.
+
+     The Graph class represents a complete workflow graph with typed inputs,
+     outputs, state, and dependencies. It contains all nodes, edges, and
+     metadata needed for execution.
+
+     Type Parameters:
+         StateT: The type of the graph state
+         DepsT: The type of the dependencies
+         InputT: The type of the input data
+         OutputT: The type of the output data
+     """
+
+     name: str | None
+     """Optional name for the graph. If not provided, the name will be inferred from the calling frame on the first call to a graph method."""
+
+     state_type: type[StateT]
+     """The type of the graph state."""
+
+     deps_type: type[DepsT]
+     """The type of the dependencies."""
+
+     input_type: type[InputT]
+     """The type of the input data."""
+
+     output_type: type[OutputT]
+     """The type of the output data."""
+
+     auto_instrument: bool
+     """Whether to automatically create instrumentation spans."""
+
+     nodes: dict[NodeID, AnyNode]
+     """All nodes in the graph indexed by their ID."""
+
+     edges_by_source: dict[NodeID, list[Path]]
+     """Outgoing paths from each source node."""
+
+     parent_forks: dict[JoinID, ParentFork[NodeID]]
+     """Parent fork information for each join node."""
+
+     intermediate_join_nodes: dict[JoinID, set[JoinID]]
+     """For each join, the set of other joins that appear between it and its parent fork.
+
+     Used to determine which joins are "final" (have no other joins as intermediates) and
+     which joins should preserve fork stacks when proceeding downstream."""
+
+     def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]:
+         """Get the parent fork information for a join node.
+
+         Args:
+             join_id: The ID of the join node
+
+         Returns:
+             The parent fork information for the join
+
+         Raises:
+             RuntimeError: If the join ID is not found or has no parent fork
+         """
+         result = self.parent_forks.get(join_id)
+         if result is None:
+             raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)')
+         return result
+
+     def is_final_join(self, join_id: JoinID) -> bool:
+         """Check if a join is 'final' (has no downstream joins with the same parent fork).
+
+         A join is non-final if it appears as an intermediate node for another join
+         with the same parent fork.
+
+         Args:
+             join_id: The ID of the join node
+
+         Returns:
+             True if the join is final, False if it's non-final
+         """
+         # Check if this join appears in any other join's intermediate_join_nodes
+         for intermediate_joins in self.intermediate_join_nodes.values():
+             if join_id in intermediate_joins:
+                 return False
+         return True
+
+     async def run(
+         self,
+         *,
+         state: StateT = None,
+         deps: DepsT = None,
+         inputs: InputT = None,
+         span: AbstractContextManager[AbstractSpan] | None = None,
+         infer_name: bool = True,
+     ) -> OutputT:
+         """Execute the graph and return the final output.
+
+         This is the main entry point for graph execution. It runs the graph
+         to completion and returns the final output value.
+
+         Args:
+             state: The graph state instance
+             deps: The dependencies instance
+             inputs: The input data for the graph
+             span: Optional span for tracing/instrumentation
+             infer_name: Whether to infer the graph name from the calling frame.
+
+         Returns:
+             The final output from the graph execution
+         """
+         if infer_name and self.name is None:
+             inferred_name = infer_obj_name(self, depth=2)
+             if inferred_name is not None:  # pragma: no branch
+                 self.name = inferred_name
+
+         async with self.iter(state=state, deps=deps, inputs=inputs, span=span, infer_name=False) as graph_run:
+             # Note: This would probably be better using `async for _ in graph_run`, but this tests the `next` method,
+             # which I'm less confident will be implemented correctly if not used on the critical path. We can change it
+             # once we have tests, etc.
+             event: Any = None
+             while True:
+                 try:
+                     event = await graph_run.next(event)
+                 except StopAsyncIteration:
+                     assert isinstance(event, EndMarker), 'Graph run should end with an EndMarker.'
+                     return cast(EndMarker[OutputT], event).value
+
+     @asynccontextmanager
+     async def iter(
+         self,
+         *,
+         state: StateT = None,
+         deps: DepsT = None,
+         inputs: InputT = None,
+         span: AbstractContextManager[AbstractSpan] | None = None,
+         infer_name: bool = True,
+     ) -> AsyncIterator[GraphRun[StateT, DepsT, OutputT]]:
+         """Create an iterator for step-by-step graph execution.
+
+         This method allows for more fine-grained control over graph execution,
+         enabling inspection of intermediate states and results.
+
+         Args:
+             state: The graph state instance
+             deps: The dependencies instance
+             inputs: The input data for the graph
+             span: Optional span for tracing/instrumentation
+             infer_name: Whether to infer the graph name from the calling frame.
+
+         Yields:
+             A GraphRun instance that can be iterated for step-by-step execution
+         """
+         if infer_name and self.name is None:
+             inferred_name = infer_obj_name(self, depth=3)  # depth=3 because asynccontextmanager adds one
+             if inferred_name is not None:  # pragma: no branch
+                 self.name = inferred_name
+
+         with ExitStack() as stack:
+             entered_span: AbstractSpan | None = None
+             if span is None:
+                 if self.auto_instrument:
+                     entered_span = stack.enter_context(logfire_span('run graph {graph.name}', graph=self))
+             else:
+                 entered_span = stack.enter_context(span)
+             traceparent = None if entered_span is None else get_traceparent(entered_span)
+             async with GraphRun[StateT, DepsT, OutputT](
+                 graph=self,
+                 state=state,
+                 deps=deps,
+                 inputs=inputs,
+                 traceparent=traceparent,
+             ) as graph_run:
+                 yield graph_run
+
+     def render(self, *, title: str | None = None, direction: StateDiagramDirection | None = None) -> str:
+         """Render the graph as a Mermaid diagram string.
+
+         Args:
+             title: Optional title for the diagram
+             direction: Optional direction for the diagram layout
+
+         Returns:
+             A string containing the Mermaid diagram representation
+         """
+         from pydantic_graph.beta.mermaid import build_mermaid_graph
+
+         return build_mermaid_graph(self.nodes, self.edges_by_source).render(title=title, direction=direction)
+
+     def __repr__(self) -> str:
+         super_repr = super().__repr__()  # include class and memory address
+         # Insert the result of calling `__str__` before the final '>' in the repr
+         return f'{super_repr[:-1]}\n{self}\n{super_repr[-1]}'
+
+     def __str__(self) -> str:
+         """Return a Mermaid diagram representation of the graph.
+
+         Returns:
+             A string containing the Mermaid diagram of the graph
+         """
+         return self.render()
+
+
+ @dataclass
+ class GraphTaskRequest:
+     """A request to run a task representing the execution of a node in the graph.
+
+     GraphTaskRequest encapsulates all the information needed to execute a specific
+     node, including its inputs and the fork context it's executing within.
+     """
+
+     node_id: NodeID
+     """The ID of the node to execute."""
+
+     inputs: Any
+     """The input data for the node."""
+
+     fork_stack: ForkStack = field(repr=False)
+     """Stack of forks that have been entered.
+
+     Used by the GraphRun to decide when to proceed through joins.
+     """
+
+
+ @dataclass
+ class GraphTask(GraphTaskRequest):
+     """A task representing the execution of a node in the graph.
+
+     GraphTask encapsulates all the information needed to execute a specific
+     node, including its inputs and the fork context it's executing within,
+     and has a unique ID to identify the task within the graph run.
+     """
+
+     task_id: TaskID = field(repr=False)
+     """Unique identifier for this task."""
+
+     @staticmethod
+     def from_request(request: GraphTaskRequest, get_task_id: Callable[[], TaskID]) -> GraphTask:
+         # Don't call the get_task_id callable, this is already a task
+         if isinstance(request, GraphTask):
+             return request
+         return GraphTask(request.node_id, request.inputs, request.fork_stack, get_task_id())
+
+
+ class GraphRun(Generic[StateT, DepsT, OutputT]):
+     """A single execution instance of a graph.
+
+     GraphRun manages the execution state for a single run of a graph,
+     including task scheduling, fork/join coordination, and result tracking.
+
+     Type Parameters:
+         StateT: The type of the graph state
+         DepsT: The type of the dependencies
+         OutputT: The type of the output data
+     """
+
+     def __init__(
+         self,
+         graph: Graph[StateT, DepsT, InputT, OutputT],
+         *,
+         state: StateT,
+         deps: DepsT,
+         inputs: InputT,
+         traceparent: str | None,
+     ):
+         """Initialize a graph run.
+
+         Args:
+             graph: The graph to execute
+             state: The graph state instance
+             deps: The dependencies instance
+             inputs: The input data for the graph
+             traceparent: Optional trace parent for instrumentation
+         """
+         self.graph = graph
+         """The graph being executed."""
+
+         self.state = state
+         """The graph state instance."""
+
+         self.deps = deps
+         """The dependencies instance."""
+
+         self.inputs = inputs
+         """The initial input data."""
+
+         self._active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = {}
+         """Active reducers for join operations."""
+
+         self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
+         """The next item to be processed."""
+
+         self._next_task_id = 0
+         self._next_node_run_id = 0
+         initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),)
+         self._first_task = GraphTask(
+             node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id()
+         )
+         self._iterator_task_group = create_task_group()
+         self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
+             self.graph,
+             self.state,
+             self.deps,
+             self._iterator_task_group,
+             self._get_next_node_run_id,
+             self._get_next_task_id,
+         )
+         self._iterator = self._iterator_instance.iter_graph(self._first_task)
+
+         self.__traceparent = traceparent
+         self._async_exit_stack = AsyncExitStack()
+
+     async def __aenter__(self):
+         self._async_exit_stack.enter_context(_unwrap_exception_groups())
+         await self._async_exit_stack.enter_async_context(self._iterator_task_group)
+         await self._async_exit_stack.enter_async_context(self._iterator_context())
+         return self
+
+     async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
+         await self._async_exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+
+     @asynccontextmanager
+     async def _iterator_context(self):
+         try:
+             yield
+         finally:
+             self._iterator_instance.iter_stream_sender.close()
+             self._iterator_instance.iter_stream_receiver.close()
+             await self._iterator.aclose()
+
+     @overload
+     def _traceparent(self, *, required: Literal[False]) -> str | None: ...
+     @overload
+     def _traceparent(self) -> str: ...
+     def _traceparent(self, *, required: bool = True) -> str | None:
+         """Get the trace parent for instrumentation.
+
+         Args:
+             required: Whether to raise an error if no traceparent exists
+
+         Returns:
+             The traceparent string, or None if not required and not set
+
+         Raises:
+             GraphRuntimeError: If required is True and no traceparent exists
+         """
+         if self.__traceparent is None and required:  # pragma: no cover
+             raise exceptions.GraphRuntimeError('No span was created for this graph run')
+         return self.__traceparent
+
+     def __aiter__(self) -> AsyncIterator[EndMarker[OutputT] | Sequence[GraphTask]]:
+         """Return self as an async iterator.
+
+         Returns:
+             Self for async iteration
+         """
+         return self
+
+     async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
+         """Get the next item in the async iteration.
+
+         Returns:
+             The next execution result from the graph
+         """
+         if self._next is None:
+             self._next = await anext(self._iterator)
+         else:
+             self._next = await self._iterator.asend(self._next)
+         return self._next
+
+     async def next(
+         self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None
+     ) -> EndMarker[OutputT] | Sequence[GraphTask]:
+         """Advance the graph execution by one step.
+
+         This method allows for sending a value to the iterator, which is useful
+         for resuming iteration or overriding intermediate results.
+
+         Args:
+             value: Optional value to send to the iterator
+
+         Returns:
+             The next execution result: either an EndMarker, or sequence of GraphTasks
+         """
+         if self._next is None:
+             # Prevent `TypeError: can't send non-None value to a just-started async generator`
+             # if `next` is called before the `first_node` has run.
+             await anext(self)
+         if value is not None:
+             if isinstance(value, EndMarker):
+                 self._next = value
+             else:
+                 self._next = [GraphTask.from_request(gtr, self._get_next_task_id) for gtr in value]
+         return await anext(self)
+
+     @property
+     def next_task(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
+         """Get the next task(s) to be executed.
+
+         Returns:
+             The next execution item, or the initial task if none is set
+         """
+         return self._next or [self._first_task]
+
+     @property
+     def output(self) -> OutputT | None:
+         """Get the final output if the graph has completed.
+
+         Returns:
+             The output value if execution is complete, None otherwise
+         """
+         if isinstance(self._next, EndMarker):
+             return self._next.value
+         return None
+
+     def _get_next_task_id(self) -> TaskID:
+         next_id = TaskID(f'task:{self._next_task_id}')
+         self._next_task_id += 1
+         return next_id
+
+     def _get_next_node_run_id(self) -> NodeRunID:
+         next_id = NodeRunID(f'task:{self._next_node_run_id}')
+         self._next_node_run_id += 1
+         return next_id
+
+
+ @dataclass
+ class _GraphTaskAsyncIterable:
+     iterable: AsyncIterable[Sequence[GraphTask]]
+     fork_stack: ForkStack
+
+
+ @dataclass
+ class _GraphTaskResult:
+     source: GraphTask
+     result: EndMarker[Any] | Sequence[GraphTask] | JoinItem
+     source_is_finished: bool = True
+
+
+ @dataclass
+ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
+     graph: Graph[StateT, DepsT, Any, OutputT]
+     state: StateT
+     deps: DepsT
+     task_group: TaskGroup
+     get_next_node_run_id: Callable[[], NodeRunID]
+     get_next_task_id: Callable[[], TaskID]
+
+     cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
+     active_tasks: dict[TaskID, GraphTask] = field(init=False)
+     active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = field(init=False)
+     iter_stream_sender: MemoryObjectSendStream[_GraphTaskResult] = field(init=False)
+     iter_stream_receiver: MemoryObjectReceiveStream[_GraphTaskResult] = field(init=False)
+
+     def __post_init__(self):
+         self.cancel_scopes = {}
+         self.active_tasks = {}
+         self.active_reducers = {}
+         self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
+         self._next_node_run_id = 1
+
+     async def iter_graph(  # noqa: C901
+         self, first_task: GraphTask
+     ) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]:
+         async with self.iter_stream_sender:
+             try:
+                 # Fire off the first task
+                 self.active_tasks[first_task.task_id] = first_task
+                 self._handle_execution_request([first_task])
+
+                 # Handle task results
+                 async with self.iter_stream_receiver:
+                     while self.active_tasks or self.active_reducers:
+                         async for task_result in self.iter_stream_receiver:  # pragma: no branch
+                             if isinstance(task_result.result, JoinItem):
+                                 maybe_overridden_result = task_result.result
+                             else:
+                                 maybe_overridden_result = yield task_result.result
+                             if isinstance(maybe_overridden_result, EndMarker):
+                                 # If we got an end marker, this task is definitely done, and we're ready to
+                                 # start cleaning everything up
+                                 await self._finish_task(task_result.source.task_id)
+                                 if self.active_tasks:
+                                     # Cancel the remaining tasks
+                                     self.task_group.cancel_scope.cancel()
+                                 return
+                             elif isinstance(maybe_overridden_result, JoinItem):
+                                 result = maybe_overridden_result
+                                 parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
+                                 for i, x in enumerate(result.fork_stack[::-1]):
+                                     if x.fork_id == parent_fork_id:
+                                         # For non-final joins (those that are intermediate nodes of other joins),
+                                         # preserve the fork stack so downstream joins can still associate with the same fork run
+                                         if self.graph.is_final_join(result.join_id):
+                                             # Final join: remove the parent fork from the stack
+                                             downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
+                                         else:
+                                             # Non-final join: preserve the fork stack
+                                             downstream_fork_stack = result.fork_stack
+                                         fork_run_id = x.node_run_id
+                                         break
+                                 else:  # pragma: no cover
+                                     raise RuntimeError('Parent fork run not found')
+
+                                 join_node = self.graph.nodes[result.join_id]
+                                 assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                 join_state = self.active_reducers.get((result.join_id, fork_run_id))
+                                 if join_state is None:
+                                     current = join_node.initial_factory()
+                                     join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
+                                         current, downstream_fork_stack
+                                     )
+                                 context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
+                                 join_state.current = join_node.reduce(context, join_state.current, result.inputs)
+                                 if join_state.cancelled_sibling_tasks:
+                                     await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
+                             else:
+                                 for new_task in maybe_overridden_result:
+                                     self.active_tasks[new_task.task_id] = new_task
+
+                             tasks_by_id_values = list(self.active_tasks.values())
+                             join_tasks: list[GraphTask] = []
+
+                             for join_id, fork_run_id in self._get_completed_fork_runs(
+                                 task_result.source, tasks_by_id_values
+                             ):
+                                 join_state = self.active_reducers.pop((join_id, fork_run_id))
+                                 join_node = self.graph.nodes[join_id]
+                                 assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                 new_tasks = self._handle_non_fork_edges(
+                                     join_node, join_state.current, join_state.downstream_fork_stack
+                                 )
+                                 join_tasks.extend(new_tasks)
+                             if join_tasks:
+                                 for new_task in join_tasks:
+                                     self.active_tasks[new_task.task_id] = new_task
+                                 self._handle_execution_request(join_tasks)
+
+                             if isinstance(maybe_overridden_result, Sequence):
+                                 if isinstance(task_result.result, Sequence):
+                                     new_task_ids = {t.task_id for t in maybe_overridden_result}
+                                     for t in task_result.result:
+                                         if t.task_id not in new_task_ids:
+                                             await self._finish_task(t.task_id)
+                                 self._handle_execution_request(maybe_overridden_result)
+
+                             if task_result.source_is_finished:
+                                 await self._finish_task(task_result.source.task_id)
+
+                             if not self.active_tasks:
+                                 # if there are no active tasks, we'll be waiting forever for the next result.
+                                 break
+
+                         if self.active_reducers:  # pragma: no branch
+                             # In this case, there are no pending tasks. We can therefore finalize all active reducers
+                             # that don't have intermediate joins which are also active reducers. If a join J2 has an
+                             # intermediate join J1 that shares the same parent fork run, we must finalize J1 first
+                             # because it might produce items that feed into J2.
+                             for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
+                                 # Check if this join has any intermediate joins that are also active reducers
+                                 should_skip = False
+                                 intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
+
+                                 # Get the parent fork for this join to use for comparison
+                                 join_parent_fork = self.graph.get_parent_fork(join_id)
+
+                                 for intermediate_join_id in intermediate_joins:
+                                     # Check if the intermediate join is also an active reducer with matching fork run
+                                     for (other_join_id, _), other_join_state in self.active_reducers.items():
+                                         if other_join_id == intermediate_join_id:
+                                             # Check if they share the same fork run for this join's parent fork
+                                             # by finding the parent fork's node_run_id in both fork stacks
+                                             join_parent_fork_run_id = None
+                                             other_parent_fork_run_id = None
+
+                                             for fsi in join_state.downstream_fork_stack:  # pragma: no branch
+                                                 if fsi.fork_id == join_parent_fork.fork_id:
+                                                     join_parent_fork_run_id = fsi.node_run_id
+                                                     break
+
+                                             for fsi in other_join_state.downstream_fork_stack:  # pragma: no branch
+                                                 if fsi.fork_id == join_parent_fork.fork_id:
+                                                     other_parent_fork_run_id = fsi.node_run_id
+                                                     break
+
+                                             if (
+                                                 join_parent_fork_run_id
+                                                 and other_parent_fork_run_id
+                                                 and join_parent_fork_run_id == other_parent_fork_run_id
+                                             ):  # pragma: no branch
+                                                 should_skip = True
+                                                 break
+                                     if should_skip:
+                                         break
+
+                                 if should_skip:
+                                     continue
+
+                                 self.active_reducers.pop(
+                                     (join_id, fork_run_id)
+                                 )  # we're handling it now, so we can pop it
+                                 join_node = self.graph.nodes[join_id]
+                                 assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                 new_tasks = self._handle_non_fork_edges(
+                                     join_node, join_state.current, join_state.downstream_fork_stack
+                                 )
+                                 maybe_overridden_result = yield new_tasks
+                                 if isinstance(maybe_overridden_result, EndMarker):  # pragma: no cover
+                                     # This is theoretically reachable but it would be awkward.
+                                     # Probably a better way to get coverage here would be to unify the code path
+                                     # with the other `if isinstance(maybe_overridden_result, EndMarker):`
+                                     self.task_group.cancel_scope.cancel()
+                                     return
+                                 for new_task in maybe_overridden_result:
+                                     self.active_tasks[new_task.task_id] = new_task
+                                 new_task_ids = {t.task_id for t in maybe_overridden_result}
+                                 for t in new_tasks:
+                                     # Same note as above about how this is theoretically reachable but we should
+                                     # just get coverage by unifying the code paths
+                                     if t.task_id not in new_task_ids:  # pragma: no cover
+                                         await self._finish_task(t.task_id)
+                                 self._handle_execution_request(maybe_overridden_result)
+             except GeneratorExit:
+                 self.task_group.cancel_scope.cancel()
+                 return
+
+         raise RuntimeError(  # pragma: no cover
+             'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
+         )
+
+     async def _finish_task(self, task_id: TaskID) -> None:
+         # Cancel the task's cancel scope, if any, and drop the task from the active set.
+         scope = self.cancel_scopes.pop(task_id, None)
+         if scope is not None:
+             scope.cancel()
+         self.active_tasks.pop(task_id, None)
+
+     def _handle_execution_request(self, request: Sequence[GraphTask]) -> None:
+         for new_task in request:
+             self.active_tasks[new_task.task_id] = new_task
+         for new_task in request:
+             self.task_group.start_soon(self._run_tracked_task, new_task)
+
+     async def _run_tracked_task(self, t_: GraphTask):
+         with CancelScope() as scope:
+             self.cancel_scopes[t_.task_id] = scope
+             result = await self._run_task(t_)
+             if isinstance(result, _GraphTaskAsyncIterable):
+                 async for new_tasks in result.iterable:
+                     await self.iter_stream_sender.send(_GraphTaskResult(t_, new_tasks, False))
+                 await self.iter_stream_sender.send(_GraphTaskResult(t_, []))
+             else:
+                 await self.iter_stream_sender.send(_GraphTaskResult(t_, result))
+
+     async def _run_task(
+         self,
+         task: GraphTask,
+     ) -> EndMarker[OutputT] | Sequence[GraphTask] | _GraphTaskAsyncIterable | JoinItem:
+         state = self.state
+         deps = self.deps
+
+         node_id = task.node_id
+         inputs = task.inputs
+         fork_stack = task.fork_stack
+
+         node = self.graph.nodes[node_id]
+
+         if isinstance(node, StartNode | Fork):
+             return self._handle_edges(node, inputs, fork_stack)
+         elif isinstance(node, Step):
+             with ExitStack() as stack:
+                 if self.graph.auto_instrument:
+                     stack.enter_context(logfire_span('run node {node_id}', node_id=node.id, node=node))
+
+                 step_context = StepContext[StateT, DepsT, Any](state=state, deps=deps, inputs=inputs)
+                 output = await node.call(step_context)
+                 if isinstance(node, NodeStep):
+                     return self._handle_node(output, fork_stack)
+                 else:
+                     return self._handle_edges(node, output, fork_stack)
+         elif isinstance(node, Join):
+             return JoinItem(node_id, inputs, fork_stack)
+         elif isinstance(node, Decision):
+             return self._handle_decision(node, inputs, fork_stack)
+         elif isinstance(node, EndNode):
+             return EndMarker(inputs)
+         else:
+             assert_never(node)
+
+     def _handle_decision(
+         self, decision: Decision[StateT, DepsT, Any], inputs: Any, fork_stack: ForkStack
+     ) -> Sequence[GraphTask]:
+         for branch in decision.branches:
+             match_tester = branch.matches
+             if match_tester is not None:
+                 inputs_match = match_tester(inputs)
+             else:
+                 branch_source = unpack_type_expression(branch.source)
+
+                 if branch_source in {Any, object}:
+                     inputs_match = True
+                 elif get_origin(branch_source) is Literal:
+                     inputs_match = inputs in get_args(branch_source)
+                 else:
+                     try:
+                         inputs_match = isinstance(inputs, branch_source)
+                     except TypeError as e:  # pragma: no cover
+                         raise RuntimeError(f'Decision branch source {branch_source} is not a valid type.') from e
+
+             if inputs_match:
+                 return self._handle_path(branch.path, inputs, fork_stack)
+
+         raise RuntimeError(f'No branch matched inputs {inputs} for decision node {decision}.')
+
+     def _handle_node(
+         self,
+         next_node: BaseNode[StateT, DepsT, Any] | End[Any],
+         fork_stack: ForkStack,
+     ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
+         if isinstance(next_node, StepNode):
+             return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())]
+         elif isinstance(next_node, JoinNode):
+             return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
+         elif isinstance(next_node, BaseNode):
+             node_step = NodeStep(next_node.__class__)
+             return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())]
+         elif isinstance(next_node, End):
+             return EndMarker(next_node.data)
+         else:
+             assert_never(next_node)
+
+     def _get_completed_fork_runs(
+         self,
+         t: GraphTask,
+         active_tasks: Iterable[GraphTask],
+     ) -> list[tuple[JoinID, NodeRunID]]:
+         completed_fork_runs: list[tuple[JoinID, NodeRunID]] = []
+
+         fork_run_indices = {fsi.node_run_id: i for i, fsi in enumerate(t.fork_stack)}
+         for join_id, fork_run_id in self.active_reducers.keys():
+             fork_run_index = fork_run_indices.get(fork_run_id)
+             if fork_run_index is None:
+                 continue  # The fork_run_id is not in the current task's fork stack, so this task didn't complete it.
+
+             # This reducer _may_ now be ready to finalize:
+             if self._is_fork_run_completed(active_tasks, join_id, fork_run_id):
+                 completed_fork_runs.append((join_id, fork_run_id))
+
+         return completed_fork_runs
+
+     def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]:
+         if not path.items:
+             return []  # pragma: no cover
+
+         item = path.items[0]
+         assert not isinstance(item, MapMarker | BroadcastMarker), (
+             'These markers should be removed from paths during graph building'
+         )
+         if isinstance(item, DestinationMarker):
+             return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())]
+         elif isinstance(item, TransformMarker):
+             inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
+             return self._handle_path(path.next_path, inputs, fork_stack)
+         elif isinstance(item, LabelMarker):
+             return self._handle_path(path.next_path, inputs, fork_stack)
+         else:
+             assert_never(item)
+
+     def _handle_edges(
+         self, node: AnyNode, inputs: Any, fork_stack: ForkStack
+     ) -> Sequence[GraphTask] | _GraphTaskAsyncIterable:
+         if isinstance(node, Fork):
+             return self._handle_fork_edges(node, inputs, fork_stack)
+         else:
+             return self._handle_non_fork_edges(node, inputs, fork_stack)
+
+     def _handle_non_fork_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]:
+         edges = self.graph.edges_by_source.get(node.id, [])
+         assert len(edges) == 1  # this should have already been ensured during graph building
+         return self._handle_path(edges[0], inputs, fork_stack)
+
+     def _handle_fork_edges(
+         self, node: Fork[Any, Any], inputs: Any, fork_stack: ForkStack
+     ) -> Sequence[GraphTask] | _GraphTaskAsyncIterable:
+         edges = self.graph.edges_by_source.get(node.id, [])
+         assert len(edges) == 1 or (isinstance(node, Fork) and not node.is_map), (
+             edges,
+             node.id,
+         )  # this should have already been ensured during graph building
+
+         new_tasks: list[GraphTask] = []
+         node_run_id = self.get_next_node_run_id()
+         if node.is_map:
+             # If the map specifies a downstream join id, eagerly create a join state for it
+             if (join_id := node.downstream_join_id) is not None:
+                 join_node = self.graph.nodes[join_id]
+                 assert isinstance(join_node, Join)
+                 self.active_reducers[(join_id, node_run_id)] = JoinState(join_node.initial_factory(), fork_stack)
+
+             # Eagerly raise a clear error if the input value is not iterable as expected
+             if _is_any_iterable(inputs):
+                 for thread_index, input_item in enumerate(inputs):
+                     item_tasks = self._handle_path(
+                         edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),)
+                     )
+                     new_tasks += item_tasks
+             elif _is_any_async_iterable(inputs):
+
+                 async def handle_async_iterable() -> AsyncIterator[Sequence[GraphTask]]:
+                     thread_index = 0
+                     async for input_item in inputs:
+                         item_tasks = self._handle_path(
+                             edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),)
+                         )
+                         yield item_tasks
+                         thread_index += 1
+
+                 return _GraphTaskAsyncIterable(handle_async_iterable(), fork_stack)
+
+             else:
+                 raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}')
+         else:
+             for i, path in enumerate(edges):
+                 new_tasks += self._handle_path(path, inputs, fork_stack + (ForkStackItem(node.id, node_run_id, i),))
+         return new_tasks
+
+     def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinID, fork_run_id: NodeRunID) -> bool:
+         # Check if any of the tasks in the graph have this fork_run_id in their fork_stack
+         # If this is the case, then the fork run is not yet completed
+         parent_fork = self.graph.get_parent_fork(join_id)
+         for t in tasks:
+             if fork_run_id in {x.node_run_id for x in t.fork_stack}:
+                 if t.node_id in parent_fork.intermediate_nodes or t.node_id == join_id:
+                     return False
+                 else:
+                     pass
+         return True
+
+     async def _cancel_sibling_tasks(self, parent_fork_id: ForkID, node_run_id: NodeRunID):
+         task_ids_to_cancel = set[TaskID]()
+         for task_id, t in self.active_tasks.items():
+             for item in t.fork_stack:  # pragma: no branch
+                 if item.fork_id == parent_fork_id and item.node_run_id == node_run_id:
+                     task_ids_to_cancel.add(task_id)
+                     break
+             else:
+                 pass
+         for task_id in task_ids_to_cancel:
+             await self._finish_task(task_id)
+
+
+ def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]:
+     return isinstance(x, Iterable)
+
+
+ def _is_any_async_iterable(x: Any) -> TypeGuard[AsyncIterable[Any]]:
+     return isinstance(x, AsyncIterable)
+
+
+ @contextmanager
+ def _unwrap_exception_groups():
+     # I need to use a helper function for this because I can't figure out a way to get pyright
+     # to type-check the ExceptionGroup catching in both 3.13 and 3.10 without emitting type errors in one;
+     # if I try to ignore them in one, I get unnecessary-type-ignore errors in the other
+     if TYPE_CHECKING:
+         yield
+     else:
+         try:
+             yield
+         except BaseExceptionGroup as e:
+             exception = e.exceptions[0]
+             if exception.__cause__ is None:
+                 # bizarrely, this prevents recursion errors when formatting the exception for logfire
+                 exception.__cause__ = None
+             raise exception
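
Usage sketch (not part of the diff): a minimal example of the two public entry points defined above, `Graph.run` and `Graph.iter`. `MyState` and the `graph` argument are hypothetical; `Graph` instances are produced by builder code that is not part of this file, while `EndMarker` and `GraphRun` are the classes defined above.

from dataclasses import dataclass


@dataclass
class MyState:
    """Hypothetical user-defined state shared by every node in a run."""

    counter: int = 0


async def demo(graph: Graph[MyState, None, int, str]) -> None:
    # Run to completion: `run` drives the iterator internally and returns the
    # value carried by the final EndMarker.
    output = await graph.run(state=MyState(), deps=None, inputs=42)
    print(output)

    # Step through the same run instead: each iteration yields either the next
    # batch of scheduled GraphTasks or the final EndMarker, then iteration stops.
    async with graph.iter(state=MyState(), deps=None, inputs=42) as graph_run:
        async for step in graph_run:
            if isinstance(step, EndMarker):
                print('finished with:', step.value)
            else:
                print('scheduled:', [t.node_id for t in step])
        print('final output:', graph_run.output)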
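
`GraphRun.next` also accepts an override value: a sequence of `GraphTaskRequest`s reroutes execution, while an `EndMarker` ends the run early and the engine cancels any outstanding tasks. A sketch with the same loop shape as the `while True` loop inside `Graph.run`, reusing the hypothetical `MyState` above:

from collections.abc import Sequence


async def run_with_step_budget(graph: Graph[MyState, None, int, str], max_steps: int = 100) -> str:
    async with graph.iter(state=MyState(), deps=None, inputs=42) as graph_run:
        event: EndMarker[str] | Sequence[GraphTask] | None = None
        steps = 0
        while True:
            steps += 1
            if steps > max_steps and not isinstance(event, EndMarker):
                # Override whatever would have run next; remaining tasks get cancelled.
                event = EndMarker('step budget exhausted')
            try:
                event = await graph_run.next(event)
            except StopAsyncIteration:
                # As in `Graph.run` above, a run always finishes on an EndMarker.
                assert isinstance(event, EndMarker)
                return event.value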
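
Finally, since `__str__` delegates to `render`, printing a graph yields its Mermaid state-diagram source, which is handy when debugging a topology. Again, `graph` is a hypothetical instance:

def show_topology(graph: Graph[MyState, None, int, str]) -> None:
    # Mermaid source built from `graph.nodes` and `graph.edges_by_source`;
    # `print(graph)` produces the same text via `__str__`.
    print(graph.render(title='my workflow'))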