pydantic-graph 1.3.0__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,939 @@
+ """Core graph execution engine for the next version of the pydantic-graph library.
+
+ This module provides the main `Graph` class and `GraphRun` execution engine that
+ handles the orchestration of nodes, edges, and parallel execution paths in
+ the graph-based workflow system.
+ """
+
+ from __future__ import annotations as _annotations
+
+ import sys
+ import uuid
+ from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
+ from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
+
+ from anyio import CancelScope, create_memory_object_stream, create_task_group
+ from anyio.abc import TaskGroup
+ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
+ from typing_extensions import TypeVar, assert_never
+
+ from pydantic_graph import exceptions
+ from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
+ from pydantic_graph.beta.decision import Decision
+ from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
+ from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
+ from pydantic_graph.beta.node import (
+     EndNode,
+     Fork,
+     StartNode,
+ )
+ from pydantic_graph.beta.node_types import AnyNode
+ from pydantic_graph.beta.parent_forks import ParentFork
+ from pydantic_graph.beta.paths import (
+     BroadcastMarker,
+     DestinationMarker,
+     LabelMarker,
+     MapMarker,
+     Path,
+     TransformMarker,
+ )
+ from pydantic_graph.beta.step import NodeStep, Step, StepContext, StepNode
+ from pydantic_graph.beta.util import unpack_type_expression
+ from pydantic_graph.nodes import BaseNode, End
+
+ if sys.version_info < (3, 11):
+     from exceptiongroup import ExceptionGroup as ExceptionGroup  # pragma: lax no cover
+ else:
+     ExceptionGroup = ExceptionGroup  # pragma: lax no cover
+
+ if TYPE_CHECKING:
+     from pydantic_graph.beta.mermaid import StateDiagramDirection
+
+
+ StateT = TypeVar('StateT', infer_variance=True)
+ """Type variable for graph state."""
+
+ DepsT = TypeVar('DepsT', infer_variance=True)
+ """Type variable for graph dependencies."""
+
+ InputT = TypeVar('InputT', infer_variance=True)
+ """Type variable for graph inputs."""
+
+ OutputT = TypeVar('OutputT', infer_variance=True)
+ """Type variable for graph outputs."""
+
+
+ @dataclass(init=False)
+ class EndMarker(Generic[OutputT]):
+     """A marker indicating the end of graph execution with a final value.
+
+     EndMarker is used internally to signal that the graph has completed
+     execution and carries the final output value.
+
+     Type Parameters:
+         OutputT: The type of the final output value
+     """
+
+     _value: OutputT
+     """The final output value from the graph execution."""
+
+     def __init__(self, value: OutputT):
+         # This manually-defined initializer is necessary due to https://github.com/python/mypy/issues/17623.
+         self._value = value
+
+     @property
+     def value(self) -> OutputT:
+         return self._value
+
+
+ @dataclass
+ class JoinItem:
+     """An item representing data flowing into a join operation.
+
+     JoinItem carries input data from a parallel execution path to a join
+     node, along with metadata about which execution 'fork' it originated from.
+     """
+
+     join_id: JoinID
+     """The ID of the join node this item is targeting."""
+
+     inputs: Any
+     """The input data for the join operation."""
+
+     fork_stack: ForkStack
+     """The stack of ForkStackItems that led to producing this join item."""
+
+
+ @dataclass(repr=False)
+ class Graph(Generic[StateT, DepsT, InputT, OutputT]):
+     """A complete graph definition ready for execution.
+
+     The Graph class represents a complete workflow graph with typed inputs,
+     outputs, state, and dependencies. It contains all nodes, edges, and
+     metadata needed for execution.
+
+     Type Parameters:
+         StateT: The type of the graph state
+         DepsT: The type of the dependencies
+         InputT: The type of the input data
+         OutputT: The type of the output data
+     """
+
+     name: str | None
+     """Optional name for the graph; if not provided, the name will be inferred from the calling frame on the first call to a graph method."""
+
+     state_type: type[StateT]
+     """The type of the graph state."""
+
+     deps_type: type[DepsT]
+     """The type of the dependencies."""
+
+     input_type: type[InputT]
+     """The type of the input data."""
+
+     output_type: type[OutputT]
+     """The type of the output data."""
+
+     auto_instrument: bool
+     """Whether to automatically create instrumentation spans."""
+
+     nodes: dict[NodeID, AnyNode]
+     """All nodes in the graph indexed by their ID."""
+
+     edges_by_source: dict[NodeID, list[Path]]
+     """Outgoing paths from each source node."""
+
+     parent_forks: dict[JoinID, ParentFork[NodeID]]
+     """Parent fork information for each join node."""
+
+     intermediate_join_nodes: dict[JoinID, set[JoinID]]
+     """For each join, the set of other joins that appear between it and its parent fork.
+
+     Used to determine which joins are "final" (have no other joins as intermediates) and
+     which joins should preserve fork stacks when proceeding downstream."""
+
+     def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]:
+         """Get the parent fork information for a join node.
+
+         Args:
+             join_id: The ID of the join node
+
+         Returns:
+             The parent fork information for the join
+
+         Raises:
+             RuntimeError: If the join ID is not found or has no parent fork
+         """
+         result = self.parent_forks.get(join_id)
+         if result is None:
+             raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)')
+         return result
+
+     def is_final_join(self, join_id: JoinID) -> bool:
+         """Check if a join is 'final' (has no downstream joins with the same parent fork).
+
+         A join is non-final if it appears as an intermediate node for another join
+         with the same parent fork.
+
+         Args:
+             join_id: The ID of the join node
+
+         Returns:
+             True if the join is final, False if it's non-final
+         """
+         # Check if this join appears in any other join's intermediate_join_nodes
+         for intermediate_joins in self.intermediate_join_nodes.values():
+             if join_id in intermediate_joins:
+                 return False
+         return True
+
+     async def run(
+         self,
+         *,
+         state: StateT = None,
+         deps: DepsT = None,
+         inputs: InputT = None,
+         span: AbstractContextManager[AbstractSpan] | None = None,
+         infer_name: bool = True,
+     ) -> OutputT:
+         """Execute the graph and return the final output.
+
+         This is the main entry point for graph execution. It runs the graph
+         to completion and returns the final output value.
+
+         Args:
+             state: The graph state instance
+             deps: The dependencies instance
+             inputs: The input data for the graph
+             span: Optional span for tracing/instrumentation
+             infer_name: Whether to infer the graph name from the calling frame.
+
+         Returns:
+             The final output from the graph execution
+         """
+         if infer_name and self.name is None:
+             inferred_name = infer_obj_name(self, depth=2)
+             if inferred_name is not None:  # pragma: no branch
+                 self.name = inferred_name
+
+         async with self.iter(state=state, deps=deps, inputs=inputs, span=span, infer_name=False) as graph_run:
+             # Note: This would probably be better using `async for _ in graph_run`, but this tests the `next` method,
+             # which I'm less confident will be implemented correctly if not used on the critical path. We can change it
+             # once we have tests, etc.
+             event: Any = None
+             while True:
+                 try:
+                     event = await graph_run.next(event)
+                 except StopAsyncIteration:
+                     assert isinstance(event, EndMarker), 'Graph run should end with an EndMarker.'
+                     return cast(EndMarker[OutputT], event).value
+
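Editorial sketch, not part of the package diff: a minimal use of `run`, assuming a hypothetical `my_graph: Graph[None, None, int, int]` built elsewhere; the keyword-only arguments mirror the signature above.

    async def main() -> None:
        # run the graph to completion and collect the typed output
        output = await my_graph.run(inputs=41)
        assert isinstance(output, int)
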
+     @asynccontextmanager
+     async def iter(
+         self,
+         *,
+         state: StateT = None,
+         deps: DepsT = None,
+         inputs: InputT = None,
+         span: AbstractContextManager[AbstractSpan] | None = None,
+         infer_name: bool = True,
+     ) -> AsyncIterator[GraphRun[StateT, DepsT, OutputT]]:
+         """Create an iterator for step-by-step graph execution.
+
+         This method allows for more fine-grained control over graph execution,
+         enabling inspection of intermediate states and results.
+
+         Args:
+             state: The graph state instance
+             deps: The dependencies instance
+             inputs: The input data for the graph
+             span: Optional span for tracing/instrumentation
+             infer_name: Whether to infer the graph name from the calling frame.
+
+         Yields:
+             A GraphRun instance that can be iterated for step-by-step execution
+         """
+         if infer_name and self.name is None:
+             inferred_name = infer_obj_name(self, depth=3)  # depth=3 because asynccontextmanager adds one
+             if inferred_name is not None:  # pragma: no branch
+                 self.name = inferred_name
+
+         with ExitStack() as stack:
+             entered_span: AbstractSpan | None = None
+             if span is None:
+                 if self.auto_instrument:
+                     entered_span = stack.enter_context(logfire_span('run graph {graph.name}', graph=self))
+             else:
+                 entered_span = stack.enter_context(span)
+             traceparent = None if entered_span is None else get_traceparent(entered_span)
+             async with GraphRun[StateT, DepsT, OutputT](
+                 graph=self,
+                 state=state,
+                 deps=deps,
+                 inputs=inputs,
+                 traceparent=traceparent,
+             ) as graph_run:
+                 yield graph_run
+
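Editorial sketch, not part of the package diff: combined with the `GraphRun` async-iteration protocol defined below, `iter` surfaces each step as either a batch of scheduled `GraphTask`s or the final `EndMarker` (`my_graph` is hypothetical, as above).

    async def inspect() -> None:
        async with my_graph.iter(inputs=41) as graph_run:
            async for step in graph_run:
                if isinstance(step, EndMarker):
                    print('final output:', step.value)
                else:
                    print('scheduled nodes:', [t.node_id for t in step])
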
+     def render(self, *, title: str | None = None, direction: StateDiagramDirection | None = None) -> str:
+         """Render the graph as a Mermaid diagram string.
+
+         Args:
+             title: Optional title for the diagram
+             direction: Optional direction for the diagram layout
+
+         Returns:
+             A string containing the Mermaid diagram representation
+         """
+         from pydantic_graph.beta.mermaid import build_mermaid_graph
+
+         return build_mermaid_graph(self.nodes, self.edges_by_source).render(title=title, direction=direction)
+
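Editorial note, not part of the package diff: `render` is also what `__str__` below delegates to, so `print(my_graph.render(title='My graph'))` and `print(my_graph)` emit the same Mermaid text, modulo the title (`my_graph` is hypothetical).
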
+     def __repr__(self) -> str:
+         super_repr = super().__repr__()  # include class and memory address
+         # Insert the result of calling `__str__` before the final '>' in the repr
+         return f'{super_repr[:-1]}\n{self}\n{super_repr[-1]}'
+
+     def __str__(self) -> str:
+         """Return a Mermaid diagram representation of the graph.
+
+         Returns:
+             A string containing the Mermaid diagram of the graph
+         """
+         return self.render()
+
+
+ @dataclass
+ class GraphTask:
+     """A single task representing the execution of a node in the graph.
+
+     GraphTask encapsulates all the information needed to execute a specific
+     node, including its inputs and the fork context it's executing within.
+     """
+
+     # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
+     node_id: NodeID
+     """The ID of the node to execute."""
+
+     inputs: Any
+     """The input data for the node."""
+
+     fork_stack: ForkStack = field(repr=False)
+     """Stack of forks that have been entered.
+
+     Used by the GraphRun to decide when to proceed through joins.
+     """
+
+     task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4())), repr=False)
+     """Unique identifier for this task."""
+
+
+ class GraphRun(Generic[StateT, DepsT, OutputT]):
+     """A single execution instance of a graph.
+
+     GraphRun manages the execution state for a single run of a graph,
+     including task scheduling, fork/join coordination, and result tracking.
+
+     Type Parameters:
+         StateT: The type of the graph state
+         DepsT: The type of the dependencies
+         OutputT: The type of the output data
+     """
+
+     def __init__(
+         self,
+         graph: Graph[StateT, DepsT, InputT, OutputT],
+         *,
+         state: StateT,
+         deps: DepsT,
+         inputs: InputT,
+         traceparent: str | None,
+     ):
+         """Initialize a graph run.
+
+         Args:
+             graph: The graph to execute
+             state: The graph state instance
+             deps: The dependencies instance
+             inputs: The input data for the graph
+             traceparent: Optional trace parent for instrumentation
+         """
+         self.graph = graph
+         """The graph being executed."""
+
+         self.state = state
+         """The graph state instance."""
+
+         self.deps = deps
+         """The dependencies instance."""
+
+         self.inputs = inputs
+         """The initial input data."""
+
+         self._active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = {}
+         """Active reducers for join operations."""
+
+         self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
+         """The next item to be processed."""
+
+         run_id = GraphRunID(str(uuid.uuid4()))
+         initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),)
+         self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack)
+         self._iterator_task_group = create_task_group()
+         self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
+             self.graph, self.state, self.deps, self._iterator_task_group
+         )
+         self._iterator = self._iterator_instance.iter_graph(self._first_task)
+
+         self.__traceparent = traceparent
+         self._async_exit_stack = AsyncExitStack()
+
+     async def __aenter__(self):
+         self._async_exit_stack.enter_context(_unwrap_exception_groups())
+         await self._async_exit_stack.enter_async_context(self._iterator_task_group)
+         await self._async_exit_stack.enter_async_context(self._iterator_context())
+         return self
+
+     async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
+         await self._async_exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+
+     @asynccontextmanager
+     async def _iterator_context(self):
+         try:
+             yield
+         finally:
+             self._iterator_instance.iter_stream_sender.close()
+             self._iterator_instance.iter_stream_receiver.close()
+             await self._iterator.aclose()
+
+     @overload
+     def _traceparent(self, *, required: Literal[False]) -> str | None: ...
+     @overload
+     def _traceparent(self) -> str: ...
+     def _traceparent(self, *, required: bool = True) -> str | None:
+         """Get the trace parent for instrumentation.
+
+         Args:
+             required: Whether to raise an error if no traceparent exists
+
+         Returns:
+             The traceparent string, or None if not required and not set
+
+         Raises:
+             GraphRuntimeError: If required is True and no traceparent exists
+         """
+         if self.__traceparent is None and required:  # pragma: no cover
+             raise exceptions.GraphRuntimeError('No span was created for this graph run')
+         return self.__traceparent
+
+     def __aiter__(self) -> AsyncIterator[EndMarker[OutputT] | Sequence[GraphTask]]:
+         """Return self as an async iterator.
+
+         Returns:
+             Self for async iteration
+         """
+         return self
+
+     async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
+         """Get the next item in the async iteration.
+
+         Returns:
+             The next execution result from the graph
+         """
+         if self._next is None:
+             self._next = await anext(self._iterator)
+         else:
+             self._next = await self._iterator.asend(self._next)
+         return self._next
+
+     async def next(
+         self, value: EndMarker[OutputT] | Sequence[GraphTask] | None = None
+     ) -> EndMarker[OutputT] | Sequence[GraphTask]:
+         """Advance the graph execution by one step.
+
+         This method allows for sending a value to the iterator, which is useful
+         for resuming iteration or overriding intermediate results.
+
+         Args:
+             value: Optional value to send to the iterator
+
+         Returns:
+             The next execution result: either an EndMarker or a sequence of GraphTasks
+         """
+         if self._next is None:
+             # Prevent `TypeError: can't send non-None value to a just-started async generator`
+             # if `next` is called before the `first_node` has run.
+             await anext(self)
+         if value is not None:
+             self._next = value
+         return await anext(self)
+
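Editorial sketch, not part of the package diff: because `next` assigns `value` to the pending result before advancing, a caller can override what the graph does next, e.g. forcing an early end. This is a sketch under the same hypothetical `graph_run`; note that sending an `EndMarker` finishes the underlying generator, so the final step surfaces as `StopAsyncIteration`.

    step = await graph_run.next()  # advance one step
    try:
        await graph_run.next(EndMarker(42))  # override: finish with a final value
    except StopAsyncIteration:
        assert graph_run.output == 42
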
+     @property
+     def next_task(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
+         """Get the next task(s) to be executed.
+
+         Returns:
+             The next execution item, or the initial task if none is set
+         """
+         return self._next or [self._first_task]
+
+     @property
+     def output(self) -> OutputT | None:
+         """Get the final output if the graph has completed.
+
+         Returns:
+             The output value if execution is complete, None otherwise
+         """
+         if isinstance(self._next, EndMarker):
+             return self._next.value
+         return None
+
+
+ @dataclass
+ class _GraphTaskAsyncIterable:
+     iterable: AsyncIterable[Sequence[GraphTask]]
+     fork_stack: ForkStack
+
+
+ @dataclass
+ class _GraphTaskResult:
+     source: GraphTask
+     result: EndMarker[Any] | Sequence[GraphTask] | JoinItem
+     source_is_finished: bool = True
+
+
+ @dataclass
+ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
+     graph: Graph[StateT, DepsT, Any, OutputT]
+     state: StateT
+     deps: DepsT
+     task_group: TaskGroup
+
+     cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
+     active_tasks: dict[TaskID, GraphTask] = field(init=False)
+     active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = field(init=False)
+     iter_stream_sender: MemoryObjectSendStream[_GraphTaskResult] = field(init=False)
+     iter_stream_receiver: MemoryObjectReceiveStream[_GraphTaskResult] = field(init=False)
+
+     def __post_init__(self):
+         self.cancel_scopes = {}
+         self.active_tasks = {}
+         self.active_reducers = {}
+         self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
+
+     async def iter_graph(  # noqa: C901
+         self, first_task: GraphTask
+     ) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]:
+         async with self.iter_stream_sender:
+             try:
+                 # Fire off the first task
+                 self.active_tasks[first_task.task_id] = first_task
+                 self._handle_execution_request([first_task])
+
+                 # Handle task results
+                 async with self.iter_stream_receiver:
+                     while self.active_tasks or self.active_reducers:
+                         async for task_result in self.iter_stream_receiver:  # pragma: no branch
+                             if isinstance(task_result.result, JoinItem):
+                                 maybe_overridden_result = task_result.result
+                             else:
+                                 maybe_overridden_result = yield task_result.result
+                             if isinstance(maybe_overridden_result, EndMarker):
+                                 # If we got an end marker, this task is definitely done, and we're ready to
+                                 # start cleaning everything up
+                                 await self._finish_task(task_result.source.task_id)
+                                 if self.active_tasks:
+                                     # Cancel the remaining tasks
+                                     self.task_group.cancel_scope.cancel()
+                                 return
+                             elif isinstance(maybe_overridden_result, JoinItem):
+                                 result = maybe_overridden_result
+                                 parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
+                                 for i, x in enumerate(result.fork_stack[::-1]):
+                                     if x.fork_id == parent_fork_id:
+                                         # For non-final joins (those that are intermediate nodes of other joins),
+                                         # preserve the fork stack so downstream joins can still associate with the same fork run
+                                         if self.graph.is_final_join(result.join_id):
+                                             # Final join: truncate the fork stack at the parent fork, dropping deeper entries
+                                             downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
+                                         else:
+                                             # Non-final join: preserve the fork stack
+                                             downstream_fork_stack = result.fork_stack
+                                         fork_run_id = x.node_run_id
+                                         break
+                                 else:  # pragma: no cover
+                                     raise RuntimeError('Parent fork run not found')
+
+                                 join_node = self.graph.nodes[result.join_id]
+                                 assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                 join_state = self.active_reducers.get((result.join_id, fork_run_id))
+                                 if join_state is None:
+                                     current = join_node.initial_factory()
+                                     join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
+                                         current, downstream_fork_stack
+                                     )
+                                 context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
+                                 join_state.current = join_node.reduce(context, join_state.current, result.inputs)
+                                 if join_state.cancelled_sibling_tasks:
+                                     await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
+                             else:
+                                 for new_task in maybe_overridden_result:
+                                     self.active_tasks[new_task.task_id] = new_task
+
+                             tasks_by_id_values = list(self.active_tasks.values())
+                             join_tasks: list[GraphTask] = []
+
+                             for join_id, fork_run_id in self._get_completed_fork_runs(
+                                 task_result.source, tasks_by_id_values
+                             ):
+                                 join_state = self.active_reducers.pop((join_id, fork_run_id))
+                                 join_node = self.graph.nodes[join_id]
+                                 assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                 new_tasks = self._handle_non_fork_edges(
+                                     join_node, join_state.current, join_state.downstream_fork_stack
+                                 )
+                                 join_tasks.extend(new_tasks)
+                             if join_tasks:
+                                 for new_task in join_tasks:
+                                     self.active_tasks[new_task.task_id] = new_task
+                                 self._handle_execution_request(join_tasks)
+
+                             if isinstance(maybe_overridden_result, Sequence):
+                                 if isinstance(task_result.result, Sequence):
+                                     new_task_ids = {t.task_id for t in maybe_overridden_result}
+                                     for t in task_result.result:
+                                         if t.task_id not in new_task_ids:
+                                             await self._finish_task(t.task_id)
+                                 self._handle_execution_request(maybe_overridden_result)
+
+                             if task_result.source_is_finished:
+                                 await self._finish_task(task_result.source.task_id)
+
+                             if not self.active_tasks:
+                                 # If there are no active tasks, we'd otherwise wait forever for the next result.
+                                 break
+
+                         if self.active_reducers:  # pragma: no branch
+                             # In this case, there are no pending tasks. We can therefore finalize all active reducers
+                             # that don't have intermediate joins which are also active reducers. If a join J2 has an
+                             # intermediate join J1 that shares the same parent fork run, we must finalize J1 first
+                             # because it might produce items that feed into J2.
+                             for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
+                                 # Check if this join has any intermediate joins that are also active reducers
+                                 should_skip = False
+                                 intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())
+
+                                 # Get the parent fork for this join to use for comparison
+                                 join_parent_fork = self.graph.get_parent_fork(join_id)
+
+                                 for intermediate_join_id in intermediate_joins:
+                                     # Check if the intermediate join is also an active reducer with matching fork run
+                                     for (other_join_id, _), other_join_state in self.active_reducers.items():
+                                         if other_join_id == intermediate_join_id:
+                                             # Check if they share the same fork run for this join's parent fork
+                                             # by finding the parent fork's node_run_id in both fork stacks
+                                             join_parent_fork_run_id = None
+                                             other_parent_fork_run_id = None
+
+                                             for fsi in join_state.downstream_fork_stack:  # pragma: no branch
+                                                 if fsi.fork_id == join_parent_fork.fork_id:
+                                                     join_parent_fork_run_id = fsi.node_run_id
+                                                     break
+
+                                             for fsi in other_join_state.downstream_fork_stack:  # pragma: no branch
+                                                 if fsi.fork_id == join_parent_fork.fork_id:
+                                                     other_parent_fork_run_id = fsi.node_run_id
+                                                     break
+
+                                             if (
+                                                 join_parent_fork_run_id
+                                                 and other_parent_fork_run_id
+                                                 and join_parent_fork_run_id == other_parent_fork_run_id
+                                             ):  # pragma: no branch
+                                                 should_skip = True
+                                                 break
+                                     if should_skip:
+                                         break
+
+                                 if should_skip:
+                                     continue
+
+                                 self.active_reducers.pop(
+                                     (join_id, fork_run_id)
+                                 )  # we're handling it now, so we can pop it
+                                 join_node = self.graph.nodes[join_id]
+                                 assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
+                                 new_tasks = self._handle_non_fork_edges(
+                                     join_node, join_state.current, join_state.downstream_fork_stack
+                                 )
+                                 maybe_overridden_result = yield new_tasks
+                                 if isinstance(maybe_overridden_result, EndMarker):  # pragma: no cover
+                                     # This is theoretically reachable but it would be awkward.
+                                     # Probably a better way to get coverage here would be to unify the code path
+                                     # with the other `if isinstance(maybe_overridden_result, EndMarker):`
+                                     self.task_group.cancel_scope.cancel()
+                                     return
+                                 for new_task in maybe_overridden_result:
+                                     self.active_tasks[new_task.task_id] = new_task
+                                 new_task_ids = {t.task_id for t in maybe_overridden_result}
+                                 for t in new_tasks:
+                                     # Same note as above about how this is theoretically reachable but we should
+                                     # just get coverage by unifying the code paths
+                                     if t.task_id not in new_task_ids:  # pragma: no cover
+                                         await self._finish_task(t.task_id)
+                                 self._handle_execution_request(maybe_overridden_result)
+             except GeneratorExit:
+                 self.task_group.cancel_scope.cancel()
+                 return
+
+         raise RuntimeError(  # pragma: no cover
+             'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
+         )
+
+     async def _finish_task(self, task_id: TaskID) -> None:
+         # Cancel the task's scope, if it has one, and drop the task from the active set
+         scope = self.cancel_scopes.pop(task_id, None)
+         if scope is not None:
+             scope.cancel()
+         self.active_tasks.pop(task_id, None)
+
+     def _handle_execution_request(self, request: Sequence[GraphTask]) -> None:
+         for new_task in request:
+             self.active_tasks[new_task.task_id] = new_task
+         for new_task in request:
+             self.task_group.start_soon(self._run_tracked_task, new_task)
+
+     async def _run_tracked_task(self, t_: GraphTask):
+         with CancelScope() as scope:
+             self.cancel_scopes[t_.task_id] = scope
+             result = await self._run_task(t_)
+             if isinstance(result, _GraphTaskAsyncIterable):
+                 async for new_tasks in result.iterable:
+                     await self.iter_stream_sender.send(_GraphTaskResult(t_, new_tasks, False))
+                 await self.iter_stream_sender.send(_GraphTaskResult(t_, []))
+             else:
+                 await self.iter_stream_sender.send(_GraphTaskResult(t_, result))
+
+     async def _run_task(
+         self,
+         task: GraphTask,
+     ) -> EndMarker[OutputT] | Sequence[GraphTask] | _GraphTaskAsyncIterable | JoinItem:
+         state = self.state
+         deps = self.deps
+
+         node_id = task.node_id
+         inputs = task.inputs
+         fork_stack = task.fork_stack
+
+         node = self.graph.nodes[node_id]
+
+         if isinstance(node, StartNode | Fork):
+             return self._handle_edges(node, inputs, fork_stack)
+         elif isinstance(node, Step):
+             with ExitStack() as stack:
+                 if self.graph.auto_instrument:
+                     stack.enter_context(logfire_span('run node {node_id}', node_id=node.id, node=node))
+
+                 step_context = StepContext[StateT, DepsT, Any](state=state, deps=deps, inputs=inputs)
+                 output = await node.call(step_context)
+                 if isinstance(node, NodeStep):
+                     return self._handle_node(output, fork_stack)
+                 else:
+                     return self._handle_edges(node, output, fork_stack)
+         elif isinstance(node, Join):
+             return JoinItem(node_id, inputs, fork_stack)
+         elif isinstance(node, Decision):
+             return self._handle_decision(node, inputs, fork_stack)
+         elif isinstance(node, EndNode):
+             return EndMarker(inputs)
+         else:
+             assert_never(node)
+
+     def _handle_decision(
+         self, decision: Decision[StateT, DepsT, Any], inputs: Any, fork_stack: ForkStack
+     ) -> Sequence[GraphTask]:
+         for branch in decision.branches:
+             match_tester = branch.matches
+             if match_tester is not None:
+                 inputs_match = match_tester(inputs)
+             else:
+                 branch_source = unpack_type_expression(branch.source)
+
+                 if branch_source in {Any, object}:
+                     inputs_match = True
+                 elif get_origin(branch_source) is Literal:
+                     inputs_match = inputs in get_args(branch_source)
+                 else:
+                     try:
+                         inputs_match = isinstance(inputs, branch_source)
+                     except TypeError as e:  # pragma: no cover
+                         raise RuntimeError(f'Decision branch source {branch_source} is not a valid type.') from e
+
+             if inputs_match:
+                 return self._handle_path(branch.path, inputs, fork_stack)
+
+         raise RuntimeError(f'No branch matched inputs {inputs} for decision node {decision}.')
+
+     def _handle_node(
+         self,
+         next_node: BaseNode[StateT, DepsT, Any] | End[Any],
+         fork_stack: ForkStack,
+     ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
+         if isinstance(next_node, StepNode):
+             return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
+         elif isinstance(next_node, JoinNode):
+             return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
+         elif isinstance(next_node, BaseNode):
+             node_step = NodeStep(next_node.__class__)
+             return [GraphTask(node_step.id, next_node, fork_stack)]
+         elif isinstance(next_node, End):
+             return EndMarker(next_node.data)
+         else:
+             assert_never(next_node)
+
+     def _get_completed_fork_runs(
+         self,
+         t: GraphTask,
+         active_tasks: Iterable[GraphTask],
+     ) -> list[tuple[JoinID, NodeRunID]]:
+         completed_fork_runs: list[tuple[JoinID, NodeRunID]] = []
+
+         fork_run_indices = {fsi.node_run_id: i for i, fsi in enumerate(t.fork_stack)}
+         for join_id, fork_run_id in self.active_reducers.keys():
+             fork_run_index = fork_run_indices.get(fork_run_id)
+             if fork_run_index is None:
+                 continue  # The fork_run_id is not in the current task's fork stack, so this task didn't complete it.
+
+             # This reducer _may_ now be ready to finalize:
+             if self._is_fork_run_completed(active_tasks, join_id, fork_run_id):
+                 completed_fork_runs.append((join_id, fork_run_id))
+
+         return completed_fork_runs
+
+     def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]:
+         if not path.items:
+             return []  # pragma: no cover
+
+         item = path.items[0]
+         assert not isinstance(item, MapMarker | BroadcastMarker), (
+             'These markers should be removed from paths during graph building'
+         )
+         if isinstance(item, DestinationMarker):
+             return [GraphTask(item.destination_id, inputs, fork_stack)]
+         elif isinstance(item, TransformMarker):
+             inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
+             return self._handle_path(path.next_path, inputs, fork_stack)
+         elif isinstance(item, LabelMarker):
+             return self._handle_path(path.next_path, inputs, fork_stack)
+         else:
+             assert_never(item)
+
+     def _handle_edges(
+         self, node: AnyNode, inputs: Any, fork_stack: ForkStack
+     ) -> Sequence[GraphTask] | _GraphTaskAsyncIterable:
+         if isinstance(node, Fork):
+             return self._handle_fork_edges(node, inputs, fork_stack)
+         else:
+             return self._handle_non_fork_edges(node, inputs, fork_stack)
+
+     def _handle_non_fork_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]:
+         edges = self.graph.edges_by_source.get(node.id, [])
+         assert len(edges) == 1  # this should have already been ensured during graph building
+         return self._handle_path(edges[0], inputs, fork_stack)
+
+     def _handle_fork_edges(
+         self, node: Fork[Any, Any], inputs: Any, fork_stack: ForkStack
+     ) -> Sequence[GraphTask] | _GraphTaskAsyncIterable:
+         edges = self.graph.edges_by_source.get(node.id, [])
+         assert len(edges) == 1 or (isinstance(node, Fork) and not node.is_map), (
+             edges,
+             node.id,
+         )  # this should have already been ensured during graph building
+
+         new_tasks: list[GraphTask] = []
+         node_run_id = NodeRunID(str(uuid.uuid4()))
+         if node.is_map:
+             # If the map specifies a downstream join id, eagerly create a join state for it
+             if (join_id := node.downstream_join_id) is not None:
+                 join_node = self.graph.nodes[join_id]
+                 assert isinstance(join_node, Join)
+                 self.active_reducers[(join_id, node_run_id)] = JoinState(join_node.initial_factory(), fork_stack)
+
+             # Eagerly raise a clear error if the input value is not iterable as expected
+             if _is_any_iterable(inputs):
+                 for thread_index, input_item in enumerate(inputs):
+                     item_tasks = self._handle_path(
+                         edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),)
+                     )
+                     new_tasks += item_tasks
+             elif _is_any_async_iterable(inputs):
+
+                 async def handle_async_iterable() -> AsyncIterator[Sequence[GraphTask]]:
+                     thread_index = 0
+                     async for input_item in inputs:
+                         item_tasks = self._handle_path(
+                             edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),)
+                         )
+                         yield item_tasks
+                         thread_index += 1
+
+                 return _GraphTaskAsyncIterable(handle_async_iterable(), fork_stack)
+
+             else:
+                 raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}')
+         else:
+             for i, path in enumerate(edges):
+                 new_tasks += self._handle_path(path, inputs, fork_stack + (ForkStackItem(node.id, node_run_id, i),))
+         return new_tasks
+
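Editorial note, not part of the package diff: in the map case above, each item of an iterable input becomes its own task on the fork's single outgoing path; the tasks share one `node_run_id` while the thread index distinguishes the parallel branches. Async-iterable inputs are instead deferred via `_GraphTaskAsyncIterable` and scheduled as items arrive in `_run_tracked_task`. A hypothetical sketch, assuming `iterator`, `map_fork`, and `fork_stack` exist and the fork has a single plain destination path:

    # mapping [1, 2, 3] yields one task per item, each with
    # ForkStackItem(map_fork.id, node_run_id, i) appended for i in 0..2
    tasks = iterator._handle_fork_edges(map_fork, [1, 2, 3], fork_stack)
    assert len(tasks) == 3
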
+     def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinID, fork_run_id: NodeRunID) -> bool:
+         # Check if any of the tasks in the graph have this fork_run_id in their fork_stack
+         # If this is the case, then the fork run is not yet completed
+         parent_fork = self.graph.get_parent_fork(join_id)
+         for t in tasks:
+             if fork_run_id in {x.node_run_id for x in t.fork_stack}:
+                 if t.node_id in parent_fork.intermediate_nodes or t.node_id == join_id:
+                     return False
+         return True
+
+     async def _cancel_sibling_tasks(self, parent_fork_id: ForkID, node_run_id: NodeRunID):
+         task_ids_to_cancel = set[TaskID]()
+         for task_id, t in self.active_tasks.items():
+             for item in t.fork_stack:  # pragma: no branch
+                 if item.fork_id == parent_fork_id and item.node_run_id == node_run_id:
+                     task_ids_to_cancel.add(task_id)
+                     break
+         for task_id in task_ids_to_cancel:
+             await self._finish_task(task_id)
+
+
+ def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]:
+     return isinstance(x, Iterable)
+
+
+ def _is_any_async_iterable(x: Any) -> TypeGuard[AsyncIterable[Any]]:
+     return isinstance(x, AsyncIterable)
+
+
+ @contextmanager
+ def _unwrap_exception_groups():
+     # I need to use a helper function for this because I can't figure out a way to get pyright
+     # to type-check the ExceptionGroup catching in both 3.13 and 3.10 without emitting type errors in one;
+     # if I try to ignore them in one, I get unnecessary-type-ignore errors in the other
+     if TYPE_CHECKING:
+         yield
+     else:
+         try:
+             yield
+         except ExceptionGroup as e:
+             exception = e.exceptions[0]
+             if exception.__cause__ is None:
+                 # Though this looks like a no-op, assigning to `__cause__` also sets
+                 # `__suppress_context__ = True`, which prevents recursion errors when
+                 # formatting the exception for logfire
+                 exception.__cause__ = None
+             raise exception
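
Editorial footnote, not part of the package diff: the `__cause__ = None` trick above relies on documented Python behavior — setting `BaseException.__cause__` implicitly sets `__suppress_context__` to `True`, which stops traceback rendering from following the implicit `__context__` chain:

    err = ValueError('boom')
    assert err.__suppress_context__ is False
    err.__cause__ = None  # no-op for __cause__, but flips __suppress_context__
    assert err.__suppress_context__ is True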