pydantic-graph 1.2.1__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_graph/_utils.py +39 -0
- pydantic_graph/beta/__init__.py +25 -0
- pydantic_graph/beta/decision.py +276 -0
- pydantic_graph/beta/graph.py +978 -0
- pydantic_graph/beta/graph_builder.py +1053 -0
- pydantic_graph/beta/id_types.py +76 -0
- pydantic_graph/beta/join.py +249 -0
- pydantic_graph/beta/mermaid.py +208 -0
- pydantic_graph/beta/node.py +95 -0
- pydantic_graph/beta/node_types.py +90 -0
- pydantic_graph/beta/parent_forks.py +232 -0
- pydantic_graph/beta/paths.py +421 -0
- pydantic_graph/beta/step.py +253 -0
- pydantic_graph/beta/util.py +90 -0
- pydantic_graph/exceptions.py +22 -0
- pydantic_graph/graph.py +12 -4
- pydantic_graph/nodes.py +0 -2
- pydantic_graph/persistence/in_mem.py +1 -1
- {pydantic_graph-1.2.1.dist-info → pydantic_graph-1.22.0.dist-info}/METADATA +1 -1
- pydantic_graph-1.22.0.dist-info/RECORD +28 -0
- pydantic_graph-1.2.1.dist-info/RECORD +0 -15
- {pydantic_graph-1.2.1.dist-info → pydantic_graph-1.22.0.dist-info}/WHEEL +0 -0
- {pydantic_graph-1.2.1.dist-info → pydantic_graph-1.22.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,978 @@
|
|
|
1
|
+
"""Core graph execution engine for the next version of the pydantic-graph library.
|
|
2
|
+
|
|
3
|
+
This module provides the main `Graph` class and `GraphRun` execution engine that
|
|
4
|
+
handles the orchestration of nodes, edges, and parallel execution paths in
|
|
5
|
+
the graph-based workflow system.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations as _annotations
|
|
9
|
+
|
|
10
|
+
import sys
|
|
11
|
+
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
|
|
12
|
+
from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
|
|
15
|
+
|
|
16
|
+
from anyio import CancelScope, create_memory_object_stream, create_task_group
|
|
17
|
+
from anyio.abc import TaskGroup
|
|
18
|
+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
19
|
+
from typing_extensions import TypeVar, assert_never
|
|
20
|
+
|
|
21
|
+
from pydantic_graph import exceptions
|
|
22
|
+
from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
|
|
23
|
+
from pydantic_graph.beta.decision import Decision
|
|
24
|
+
from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
|
|
25
|
+
from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
|
|
26
|
+
from pydantic_graph.beta.node import (
|
|
27
|
+
EndNode,
|
|
28
|
+
Fork,
|
|
29
|
+
StartNode,
|
|
30
|
+
)
|
|
31
|
+
from pydantic_graph.beta.node_types import AnyNode
|
|
32
|
+
from pydantic_graph.beta.parent_forks import ParentFork
|
|
33
|
+
from pydantic_graph.beta.paths import (
|
|
34
|
+
BroadcastMarker,
|
|
35
|
+
DestinationMarker,
|
|
36
|
+
LabelMarker,
|
|
37
|
+
MapMarker,
|
|
38
|
+
Path,
|
|
39
|
+
TransformMarker,
|
|
40
|
+
)
|
|
41
|
+
from pydantic_graph.beta.step import NodeStep, Step, StepContext, StepNode
|
|
42
|
+
from pydantic_graph.beta.util import unpack_type_expression
|
|
43
|
+
from pydantic_graph.nodes import BaseNode, End
|
|
44
|
+
|
|
45
|
+
if sys.version_info < (3, 11):
|
|
46
|
+
from exceptiongroup import ExceptionGroup as ExceptionGroup # pragma: lax no cover
|
|
47
|
+
else:
|
|
48
|
+
ExceptionGroup = ExceptionGroup # pragma: lax no cover
|
|
49
|
+
|
|
50
|
+
if TYPE_CHECKING:
|
|
51
|
+
from pydantic_graph.beta.mermaid import StateDiagramDirection
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# These TypeVars come from typing_extensions so `infer_variance=True` is
# available: variance is inferred from usage rather than declared explicitly.
StateT = TypeVar('StateT', infer_variance=True)
"""Type variable for graph state."""

DepsT = TypeVar('DepsT', infer_variance=True)
"""Type variable for graph dependencies."""

InputT = TypeVar('InputT', infer_variance=True)
"""Type variable for graph inputs."""

OutputT = TypeVar('OutputT', infer_variance=True)
"""Type variable for graph outputs."""
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass(init=False)
class EndMarker(Generic[OutputT]):
    """A marker indicating the end of graph execution with a final value.

    EndMarker is used internally to signal that the graph has completed
    execution and carries the final output value.

    Type Parameters:
        OutputT: The type of the final output value
    """

    _value: OutputT
    """The final output value from the graph execution."""

    def __init__(self, value: OutputT):
        # This manually-defined initializer is necessary due to https://github.com/python/mypy/issues/17623.
        self._value = value

    @property
    def value(self) -> OutputT:
        """The final output value carried by this marker."""
        return self._value
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class JoinItem:
    """An item representing data flowing into a join operation.

    JoinItem carries input data from a parallel execution path to a join
    node, along with metadata about which execution 'fork' it originated from.
    """

    join_id: JoinID
    """The ID of the join node this item is targeting."""

    inputs: Any
    """The input data for the join operation."""

    fork_stack: ForkStack
    """The stack of ForkStackItems that led to producing this join item."""
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@dataclass(repr=False)
class Graph(Generic[StateT, DepsT, InputT, OutputT]):
    """A complete graph definition ready for execution.

    The Graph class represents a complete workflow graph with typed inputs,
    outputs, state, and dependencies. It contains all nodes, edges, and
    metadata needed for execution.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of the input data
        OutputT: The type of the output data
    """

    name: str | None
    """Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method."""

    state_type: type[StateT]
    """The type of the graph state."""

    deps_type: type[DepsT]
    """The type of the dependencies."""

    input_type: type[InputT]
    """The type of the input data."""

    output_type: type[OutputT]
    """The type of the output data."""

    auto_instrument: bool
    """Whether to automatically create instrumentation spans."""

    nodes: dict[NodeID, AnyNode]
    """All nodes in the graph indexed by their ID."""

    edges_by_source: dict[NodeID, list[Path]]
    """Outgoing paths from each source node."""

    parent_forks: dict[JoinID, ParentFork[NodeID]]
    """Parent fork information for each join node."""

    intermediate_join_nodes: dict[JoinID, set[JoinID]]
    """For each join, the set of other joins that appear between it and its parent fork.

    Used to determine which joins are "final" (have no other joins as intermediates) and
    which joins should preserve fork stacks when proceeding downstream."""

    def get_parent_fork(self, join_id: JoinID) -> ParentFork[NodeID]:
        """Get the parent fork information for a join node.

        Args:
            join_id: The ID of the join node

        Returns:
            The parent fork information for the join

        Raises:
            RuntimeError: If the join ID is not found or has no parent fork
        """
        result = self.parent_forks.get(join_id)
        if result is None:
            raise RuntimeError(f'Node {join_id} is not a join node or did not have a dominating fork (this is a bug)')
        return result

    def is_final_join(self, join_id: JoinID) -> bool:
        """Check if a join is 'final' (has no downstream joins with the same parent fork).

        A join is non-final if it appears as an intermediate node for another join
        with the same parent fork.

        Args:
            join_id: The ID of the join node

        Returns:
            True if the join is final, False if it's non-final
        """
        # Check if this join appears in any other join's intermediate_join_nodes
        for intermediate_joins in self.intermediate_join_nodes.values():
            if join_id in intermediate_joins:
                return False
        return True

    async def run(
        self,
        *,
        state: StateT = None,
        deps: DepsT = None,
        inputs: InputT = None,
        span: AbstractContextManager[AbstractSpan] | None = None,
        infer_name: bool = True,
    ) -> OutputT:
        """Execute the graph and return the final output.

        This is the main entry point for graph execution. It runs the graph
        to completion and returns the final output value.

        Args:
            state: The graph state instance
            deps: The dependencies instance
            inputs: The input data for the graph
            span: Optional span for tracing/instrumentation
            infer_name: Whether to infer the graph name from the calling frame.

        Returns:
            The final output from the graph execution
        """
        if infer_name and self.name is None:
            # depth=2: skip this frame and look at our caller's frame
            inferred_name = infer_obj_name(self, depth=2)
            if inferred_name is not None:  # pragma: no branch
                self.name = inferred_name

        # infer_name=False below because the name (if any) was already inferred above.
        async with self.iter(state=state, deps=deps, inputs=inputs, span=span, infer_name=False) as graph_run:
            # Note: This would probably be better using `async for _ in graph_run`, but this tests the `next` method,
            # which I'm less confident will be implemented correctly if not used on the critical path. We can change it
            # once we have tests, etc.
            event: Any = None
            while True:
                try:
                    event = await graph_run.next(event)
                except StopAsyncIteration:
                    # The run's iterator is exhausted; the last yielded event must be the end marker.
                    assert isinstance(event, EndMarker), 'Graph run should end with an EndMarker.'
                    return cast(EndMarker[OutputT], event).value

    @asynccontextmanager
    async def iter(
        self,
        *,
        state: StateT = None,
        deps: DepsT = None,
        inputs: InputT = None,
        span: AbstractContextManager[AbstractSpan] | None = None,
        infer_name: bool = True,
    ) -> AsyncIterator[GraphRun[StateT, DepsT, OutputT]]:
        """Create an iterator for step-by-step graph execution.

        This method allows for more fine-grained control over graph execution,
        enabling inspection of intermediate states and results.

        Args:
            state: The graph state instance
            deps: The dependencies instance
            inputs: The input data for the graph
            span: Optional span for tracing/instrumentation
            infer_name: Whether to infer the graph name from the calling frame.

        Yields:
            A GraphRun instance that can be iterated for step-by-step execution
        """
        if infer_name and self.name is None:
            inferred_name = infer_obj_name(self, depth=3)  # depth=3 because asynccontextmanager adds one
            if inferred_name is not None:  # pragma: no branch
                self.name = inferred_name

        with ExitStack() as stack:
            entered_span: AbstractSpan | None = None
            if span is None:
                # No caller-provided span; open our own only if instrumentation is enabled.
                if self.auto_instrument:
                    entered_span = stack.enter_context(logfire_span('run graph {graph.name}', graph=self))
            else:
                entered_span = stack.enter_context(span)
            traceparent = None if entered_span is None else get_traceparent(entered_span)
            async with GraphRun[StateT, DepsT, OutputT](
                graph=self,
                state=state,
                deps=deps,
                inputs=inputs,
                traceparent=traceparent,
            ) as graph_run:
                yield graph_run

    def render(self, *, title: str | None = None, direction: StateDiagramDirection | None = None) -> str:
        """Render the graph as a Mermaid diagram string.

        Args:
            title: Optional title for the diagram
            direction: Optional direction for the diagram layout

        Returns:
            A string containing the Mermaid diagram representation
        """
        # Imported locally to avoid a module-level import cycle with the mermaid module.
        from pydantic_graph.beta.mermaid import build_mermaid_graph

        return build_mermaid_graph(self.nodes, self.edges_by_source).render(title=title, direction=direction)

    def __repr__(self) -> str:
        super_repr = super().__repr__()  # include class and memory address
        # Insert the result of calling `__str__` before the final '>' in the repr
        return f'{super_repr[:-1]}\n{self}\n{super_repr[-1]}'

    def __str__(self) -> str:
        """Return a Mermaid diagram representation of the graph.

        Returns:
            A string containing the Mermaid diagram of the graph
        """
        return self.render()
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
@dataclass
class GraphTaskRequest:
    """A request to run a task representing the execution of a node in the graph.

    GraphTaskRequest encapsulates all the information needed to execute a specific
    node, including its inputs and the fork context it's executing within.
    """

    node_id: NodeID
    """The ID of the node to execute."""

    inputs: Any
    """The input data for the node."""

    fork_stack: ForkStack = field(repr=False)
    """Stack of forks that have been entered.

    Used by the GraphRun to decide when to proceed through joins.
    """
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
@dataclass
class GraphTask(GraphTaskRequest):
    """A task representing the execution of a node in the graph.

    GraphTask encapsulates all the information needed to execute a specific
    node, including its inputs and the fork context it's executing within,
    and has a unique ID to identify the task within the graph run.
    """

    task_id: TaskID = field(repr=False)
    """Unique identifier for this task."""

    @staticmethod
    def from_request(request: GraphTaskRequest, get_task_id: Callable[[], TaskID]) -> GraphTask:
        """Promote a request to a full task, minting a new ID only when needed.

        Args:
            request: The request (or already-built task) to promote.
            get_task_id: Factory invoked to produce a fresh TaskID when `request`
                is a plain GraphTaskRequest.

        Returns:
            `request` unchanged if it is already a GraphTask, otherwise a new
            GraphTask with a freshly generated ID.
        """
        # Don't call the get_task_id callable, this is already a task
        if isinstance(request, GraphTask):
            return request
        return GraphTask(request.node_id, request.inputs, request.fork_stack, get_task_id())
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
class GraphRun(Generic[StateT, DepsT, OutputT]):
    """A single execution instance of a graph.

    GraphRun manages the execution state for a single run of a graph,
    including task scheduling, fork/join coordination, and result tracking.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        OutputT: The type of the output data
    """

    def __init__(
        self,
        graph: Graph[StateT, DepsT, InputT, OutputT],
        *,
        state: StateT,
        deps: DepsT,
        inputs: InputT,
        traceparent: str | None,
    ):
        """Initialize a graph run.

        Args:
            graph: The graph to execute
            state: The graph state instance
            deps: The dependencies instance
            inputs: The input data for the graph
            traceparent: Optional trace parent for instrumentation
        """
        self.graph = graph
        """The graph being executed."""

        self.state = state
        """The graph state instance."""

        self.deps = deps
        """The dependencies instance."""

        self.inputs = inputs
        """The initial input data."""

        self._active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = {}
        """Active reducers for join operations."""

        self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None
        """The next item to be processed."""

        self._next_task_id = 0
        self._next_node_run_id = 0
        # Every run starts with a single synthetic fork stack entry rooted at the start node.
        initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),)
        self._first_task = GraphTask(
            node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id()
        )
        self._iterator_task_group = create_task_group()
        self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT](
            self.graph,
            self.state,
            self.deps,
            self._iterator_task_group,
            self._get_next_node_run_id,
            self._get_next_task_id,
        )
        self._iterator = self._iterator_instance.iter_graph(self._first_task)

        self.__traceparent = traceparent
        self._async_exit_stack = AsyncExitStack()

    async def __aenter__(self):
        # Order matters: exception-group unwrapping wraps the task group, which must
        # outlive the iterator context that closes the streams and generator.
        self._async_exit_stack.enter_context(_unwrap_exception_groups())
        await self._async_exit_stack.enter_async_context(self._iterator_task_group)
        await self._async_exit_stack.enter_async_context(self._iterator_context())
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
        await self._async_exit_stack.__aexit__(exc_type, exc_val, exc_tb)

    @asynccontextmanager
    async def _iterator_context(self):
        # Ensures the underlying streams and async generator are always closed,
        # even if the run is abandoned part-way through.
        try:
            yield
        finally:
            self._iterator_instance.iter_stream_sender.close()
            self._iterator_instance.iter_stream_receiver.close()
            await self._iterator.aclose()

    @overload
    def _traceparent(self, *, required: Literal[False]) -> str | None: ...
    @overload
    def _traceparent(self) -> str: ...
    def _traceparent(self, *, required: bool = True) -> str | None:
        """Get the trace parent for instrumentation.

        Args:
            required: Whether to raise an error if no traceparent exists

        Returns:
            The traceparent string, or None if not required and not set

        Raises:
            GraphRuntimeError: If required is True and no traceparent exists
        """
        if self.__traceparent is None and required:  # pragma: no cover
            raise exceptions.GraphRuntimeError('No span was created for this graph run')
        return self.__traceparent

    def __aiter__(self) -> AsyncIterator[EndMarker[OutputT] | Sequence[GraphTask]]:
        """Return self as an async iterator.

        Returns:
            Self for async iteration
        """
        return self

    async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
        """Get the next item in the async iteration.

        Returns:
            The next execution result from the graph
        """
        if self._next is None:
            self._next = await anext(self._iterator)
        else:
            # Send the (possibly caller-overridden) previous result back into the generator.
            self._next = await self._iterator.asend(self._next)
        return self._next

    async def next(
        self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None
    ) -> EndMarker[OutputT] | Sequence[GraphTask]:
        """Advance the graph execution by one step.

        This method allows for sending a value to the iterator, which is useful
        for resuming iteration or overriding intermediate results.

        Args:
            value: Optional value to send to the iterator

        Returns:
            The next execution result: either an EndMarker, or sequence of GraphTasks
        """
        if self._next is None:
            # Prevent `TypeError: can't send non-None value to a just-started async generator`
            # if `next` is called before the `first_node` has run.
            await anext(self)
        if value is not None:
            if isinstance(value, EndMarker):
                self._next = value
            else:
                # Promote any plain requests to tasks with fresh IDs before sending them back in.
                self._next = [GraphTask.from_request(gtr, self._get_next_task_id) for gtr in value]
        return await anext(self)

    @property
    def next_task(self) -> EndMarker[OutputT] | Sequence[GraphTask]:
        """Get the next task(s) to be executed.

        Returns:
            The next execution item, or the initial task if none is set
        """
        return self._next or [self._first_task]

    @property
    def output(self) -> OutputT | None:
        """Get the final output if the graph has completed.

        Returns:
            The output value if execution is complete, None otherwise
        """
        if isinstance(self._next, EndMarker):
            return self._next.value
        return None

    def _get_next_task_id(self) -> TaskID:
        # Monotonic counter; task IDs are unique within this run.
        next_id = TaskID(f'task:{self._next_task_id}')
        self._next_task_id += 1
        return next_id

    def _get_next_node_run_id(self) -> NodeRunID:
        # NOTE(review): the 'task:' prefix here looks copy-pasted from _get_next_task_id —
        # presumably 'node_run:' was intended. IDs stay unique per counter/type so behavior
        # is unaffected, but the prefix is misleading when debugging. TODO confirm.
        next_id = NodeRunID(f'task:{self._next_node_run_id}')
        self._next_node_run_id += 1
        return next_id
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
@dataclass
class _GraphTaskAsyncIterable:
    """A task result that streams out batches of follow-on tasks asynchronously."""

    # The stream of task batches produced by the source task.
    iterable: AsyncIterable[Sequence[GraphTask]]
    # The fork stack of the task that produced this iterable.
    fork_stack: ForkStack
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
@dataclass
class _GraphTaskResult:
    """The outcome of running a single graph task, sent back to the iterator loop."""

    # The task that produced this result.
    source: GraphTask
    # The payload: the run's final value, follow-on tasks, or an item for a join.
    result: EndMarker[Any] | Sequence[GraphTask] | JoinItem
    # False for intermediate batches from a streaming task that is still running.
    source_is_finished: bool = True
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
@dataclass
class _GraphIterator(Generic[StateT, DepsT, OutputT]):
    """Internal driver that executes graph tasks concurrently and yields results.

    Spawns each task on the shared task group, collects results over an in-memory
    stream, and coordinates fork/join bookkeeping for a single GraphRun.
    """

    graph: Graph[StateT, DepsT, Any, OutputT]
    state: StateT
    deps: DepsT
    # Task group owned by the GraphRun; tasks are spawned here via start_soon.
    task_group: TaskGroup
    # ID factories shared with the owning GraphRun so IDs are unique run-wide.
    get_next_node_run_id: Callable[[], NodeRunID]
    get_next_task_id: Callable[[], TaskID]

    # Per-task cancel scopes, used to cancel tasks individually.
    cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
    # Tasks currently running or scheduled.
    active_tasks: dict[TaskID, GraphTask] = field(init=False)
    # In-progress join reductions keyed by (join, fork run).
    active_reducers: dict[tuple[JoinID, NodeRunID], JoinState] = field(init=False)
    # Channel over which task results flow back to iter_graph.
    iter_stream_sender: MemoryObjectSendStream[_GraphTaskResult] = field(init=False)
    iter_stream_receiver: MemoryObjectReceiveStream[_GraphTaskResult] = field(init=False)
|
|
557
|
+
|
|
558
|
+
def __post_init__(self):
    """Initialize the mutable runtime state that dataclass fields can't default."""
    self.cancel_scopes = {}
    self.active_tasks = {}
    self.active_reducers = {}
    # Zero-buffer stream: senders block until iter_graph receives each result.
    self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]()
    # NOTE(review): this attribute appears unused — node run IDs come from the
    # injected get_next_node_run_id callable; presumably a leftover. TODO confirm.
    self._next_node_run_id = 1
|
|
564
|
+
|
|
565
|
+
async def iter_graph(  # noqa C901
    self, first_task: GraphTask
) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]:
    """Drive the graph run to completion, yielding each step's result.

    Yields either the final EndMarker or the batch of tasks produced by a step;
    the caller may send back an override (an EndMarker to stop early, or a
    replacement task batch) via `asend`. JoinItems are handled internally and
    are never yielded.
    """
    async with self.iter_stream_sender:
        try:
            # Fire off the first task
            self.active_tasks[first_task.task_id] = first_task
            self._handle_execution_request([first_task])

            # Handle task results
            async with self.iter_stream_receiver:
                while self.active_tasks or self.active_reducers:
                    async for task_result in self.iter_stream_receiver:  # pragma: no branch
                        if isinstance(task_result.result, JoinItem):
                            # Join items are internal bookkeeping; don't surface them to the caller.
                            maybe_overridden_result = task_result.result
                        else:
                            # Surface the result; the caller may send back an override.
                            maybe_overridden_result = yield task_result.result
                        if isinstance(maybe_overridden_result, EndMarker):
                            # If we got an end marker, this task is definitely done, and we're ready to
                            # start cleaning everything up
                            await self._finish_task(task_result.source.task_id)
                            if self.active_tasks:
                                # Cancel the remaining tasks
                                self.task_group.cancel_scope.cancel()
                            return
                        elif isinstance(maybe_overridden_result, JoinItem):
                            result = maybe_overridden_result
                            parent_fork_id = self.graph.get_parent_fork(result.join_id).fork_id
                            # Walk the fork stack from the top to find the run of this join's parent fork.
                            for i, x in enumerate(result.fork_stack[::-1]):
                                if x.fork_id == parent_fork_id:
                                    # For non-final joins (those that are intermediate nodes of other joins),
                                    # preserve the fork stack so downstream joins can still associate with the same fork run
                                    if self.graph.is_final_join(result.join_id):
                                        # Final join: remove the parent fork from the stack
                                        downstream_fork_stack = result.fork_stack[: len(result.fork_stack) - i]
                                    else:
                                        # Non-final join: preserve the fork stack
                                        downstream_fork_stack = result.fork_stack
                                    fork_run_id = x.node_run_id
                                    break
                            else:  # pragma: no cover
                                raise RuntimeError('Parent fork run not found')

                            join_node = self.graph.nodes[result.join_id]
                            assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
                            join_state = self.active_reducers.get((result.join_id, fork_run_id))
                            if join_state is None:
                                # First item for this (join, fork run): seed the reducer state.
                                current = join_node.initial_factory()
                                join_state = self.active_reducers[(result.join_id, fork_run_id)] = JoinState(
                                    current, downstream_fork_stack
                                )
                            context = ReducerContext(state=self.state, deps=self.deps, join_state=join_state)
                            join_state.current = join_node.reduce(context, join_state.current, result.inputs)
                            if join_state.cancelled_sibling_tasks:
                                # The reducer requested early termination of the other fork branches.
                                await self._cancel_sibling_tasks(parent_fork_id, fork_run_id)
                        else:
                            # A batch of follow-on tasks: register them before checking for completed forks.
                            for new_task in maybe_overridden_result:
                                self.active_tasks[new_task.task_id] = new_task

                        tasks_by_id_values = list(self.active_tasks.values())
                        join_tasks: list[GraphTask] = []

                        # Any fork runs with no remaining upstream tasks can have their joins finalized now.
                        for join_id, fork_run_id in self._get_completed_fork_runs(
                            task_result.source, tasks_by_id_values
                        ):
                            join_state = self.active_reducers.pop((join_id, fork_run_id))
                            join_node = self.graph.nodes[join_id]
                            assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
                            new_tasks = self._handle_non_fork_edges(
                                join_node, join_state.current, join_state.downstream_fork_stack
                            )
                            join_tasks.extend(new_tasks)
                        if join_tasks:
                            for new_task in join_tasks:
                                self.active_tasks[new_task.task_id] = new_task
                            self._handle_execution_request(join_tasks)

                        if isinstance(maybe_overridden_result, Sequence):
                            if isinstance(task_result.result, Sequence):
                                # The caller may have dropped some of the originally produced tasks;
                                # finish any that were removed by the override.
                                new_task_ids = {t.task_id for t in maybe_overridden_result}
                                for t in task_result.result:
                                    if t.task_id not in new_task_ids:
                                        await self._finish_task(t.task_id)
                            self._handle_execution_request(maybe_overridden_result)

                        if task_result.source_is_finished:
                            await self._finish_task(task_result.source.task_id)

                        if not self.active_tasks:
                            # if there are no active tasks, we'll be waiting forever for the next result..
                            break

                    if self.active_reducers:  # pragma: no branch
                        # In this case, there are no pending tasks. We can therefore finalize all active reducers
                        # that don't have intermediate joins which are also active reducers. If a join J2 has an
                        # intermediate join J1 that shares the same parent fork run, we must finalize J1 first
                        # because it might produce items that feed into J2.
                        for (join_id, fork_run_id), join_state in list(self.active_reducers.items()):
                            # Check if this join has any intermediate joins that are also active reducers
                            should_skip = False
                            intermediate_joins = self.graph.intermediate_join_nodes.get(join_id, set())

                            # Get the parent fork for this join to use for comparison
                            join_parent_fork = self.graph.get_parent_fork(join_id)

                            for intermediate_join_id in intermediate_joins:
                                # Check if the intermediate join is also an active reducer with matching fork run
                                for (other_join_id, _), other_join_state in self.active_reducers.items():
                                    if other_join_id == intermediate_join_id:
                                        # Check if they share the same fork run for this join's parent fork
                                        # by finding the parent fork's node_run_id in both fork stacks
                                        join_parent_fork_run_id = None
                                        other_parent_fork_run_id = None

                                        for fsi in join_state.downstream_fork_stack:  # pragma: no branch
                                            if fsi.fork_id == join_parent_fork.fork_id:
                                                join_parent_fork_run_id = fsi.node_run_id
                                                break

                                        for fsi in other_join_state.downstream_fork_stack:  # pragma: no branch
                                            if fsi.fork_id == join_parent_fork.fork_id:
                                                other_parent_fork_run_id = fsi.node_run_id
                                                break

                                        if (
                                            join_parent_fork_run_id
                                            and other_parent_fork_run_id
                                            and join_parent_fork_run_id == other_parent_fork_run_id
                                        ):  # pragma: no branch
                                            should_skip = True
                                            break
                                if should_skip:
                                    break

                            if should_skip:
                                continue

                            self.active_reducers.pop(
                                (join_id, fork_run_id)
                            )  # we're handling it now, so we can pop it
                            join_node = self.graph.nodes[join_id]
                            assert isinstance(join_node, Join), f'Expected a `Join` but got {join_node}'
                            new_tasks = self._handle_non_fork_edges(
                                join_node, join_state.current, join_state.downstream_fork_stack
                            )
                            maybe_overridden_result = yield new_tasks
                            if isinstance(maybe_overridden_result, EndMarker):  # pragma: no cover
                                # This is theoretically reachable but it would be awkward.
                                # Probably a better way to get coverage here would be to unify the code pat
                                # with the other `if isinstance(maybe_overridden_result, EndMarker):`
                                self.task_group.cancel_scope.cancel()
                                return
                            for new_task in maybe_overridden_result:
                                self.active_tasks[new_task.task_id] = new_task
                            new_task_ids = {t.task_id for t in maybe_overridden_result}
                            for t in new_tasks:
                                # Same note as above about how this is theoretically reachable but we should
                                # just get coverage by unifying the code paths
                                if t.task_id not in new_task_ids:  # pragma: no cover
                                    await self._finish_task(t.task_id)
                            self._handle_execution_request(maybe_overridden_result)
        except GeneratorExit:
            # The consumer closed us early; cancel all in-flight work and exit quietly.
            self.task_group.cancel_scope.cancel()
            return

    raise RuntimeError(  # pragma: no cover
        'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
    )
|
|
733
|
+
|
|
734
|
+
async def _finish_task(self, task_id: TaskID) -> None:
    """Cancel and forget the task identified by *task_id*.

    Cancels the task's registered cancel scope (if any) and removes the task
    from both `self.cancel_scopes` and `self.active_tasks`. Safe to call for a
    task id that has already been finished.
    """
    if (scope := self.cancel_scopes.pop(task_id, None)) is not None:
        scope.cancel()
    self.active_tasks.pop(task_id, None)
|
|
740
|
+
|
|
741
|
+
def _handle_execution_request(self, request: Sequence[GraphTask]) -> None:
    """Register all tasks in *request*, then schedule each on the task group.

    The whole batch is registered in `self.active_tasks` before any task is
    started (two separate passes, as in the original ordering).
    """
    self.active_tasks.update((task.task_id, task) for task in request)
    for task in request:
        self.task_group.start_soon(self._run_tracked_task, task)
|
|
746
|
+
|
|
747
|
+
async def _run_tracked_task(self, t_: GraphTask):
    """Run a single task inside a registered cancel scope and stream its results.

    The `CancelScope` is stored in `self.cancel_scopes[t_.task_id]` so that
    `_finish_task` can cancel this task while it runs. Results are sent to the
    main loop via `self.iter_stream_sender`.
    """
    with CancelScope() as scope:
        self.cancel_scopes[t_.task_id] = scope
        result = await self._run_task(t_)
        if isinstance(result, _GraphTaskAsyncIterable):
            # Streaming result (produced by a map over an async iterable): forward
            # each batch of new tasks as it arrives, then send a final empty batch.
            # NOTE(review): the third `_GraphTaskResult` argument is presumably a
            # "final" flag (False for intermediate batches) — confirm against
            # `_GraphTaskResult`'s definition.
            async for new_tasks in result.iterable:
                await self.iter_stream_sender.send(_GraphTaskResult(t_, new_tasks, False))
            await self.iter_stream_sender.send(_GraphTaskResult(t_, []))
        else:
            await self.iter_stream_sender.send(_GraphTaskResult(t_, result))
|
|
757
|
+
|
|
758
|
+
async def _run_task(
    self,
    task: GraphTask,
) -> EndMarker[OutputT] | Sequence[GraphTask] | _GraphTaskAsyncIterable | JoinItem:
    """Execute one graph task and return what should happen next.

    Dispatches on the type of the node referenced by ``task.node_id``:
    start/fork nodes follow their edges, steps are called with a `StepContext`,
    joins are reported back to the main loop as a `JoinItem`, decisions select
    a branch, and the end node produces an `EndMarker`.
    """
    state = self.state
    deps = self.deps

    node_id = task.node_id
    inputs = task.inputs
    fork_stack = task.fork_stack

    node = self.graph.nodes[node_id]

    if isinstance(node, StartNode | Fork):
        # Start and fork nodes perform no work of their own; just follow edges.
        return self._handle_edges(node, inputs, fork_stack)
    elif isinstance(node, Step):
        with ExitStack() as stack:
            # Optionally wrap the step call in an instrumentation span.
            if self.graph.auto_instrument:
                stack.enter_context(logfire_span('run node {node_id}', node_id=node.id, node=node))

            step_context = StepContext[StateT, DepsT, Any](state=state, deps=deps, inputs=inputs)
            output = await node.call(step_context)
            if isinstance(node, NodeStep):
                # A NodeStep's output is itself a node (or End) that selects the successor.
                return self._handle_node(output, fork_stack)
            else:
                # Ordinary steps route their output along the graph's edges.
                return self._handle_edges(node, output, fork_stack)
    elif isinstance(node, Join):
        # Joins are reduced centrally by the run loop; report the item to join.
        return JoinItem(node_id, inputs, fork_stack)
    elif isinstance(node, Decision):
        return self._handle_decision(node, inputs, fork_stack)
    elif isinstance(node, EndNode):
        return EndMarker(inputs)
    else:
        assert_never(node)
|
|
792
|
+
|
|
793
|
+
def _handle_decision(
    self, decision: Decision[StateT, DepsT, Any], inputs: Any, fork_stack: ForkStack
) -> Sequence[GraphTask]:
    """Follow the path of the first branch of *decision* that accepts *inputs*.

    A branch matches via its explicit ``matches`` callable when present;
    otherwise its source type expression is tested (``Any``/``object`` match
    everything, ``Literal[...]`` matches by value, anything else via isinstance).

    Raises:
        RuntimeError: if a branch source is not a valid type, or no branch matches.
    """

    def _accepts(branch) -> bool:
        match_tester = branch.matches
        if match_tester is not None:
            return match_tester(inputs)
        branch_source = unpack_type_expression(branch.source)
        if branch_source in {Any, object}:
            return True
        if get_origin(branch_source) is Literal:
            return inputs in get_args(branch_source)
        try:
            return isinstance(inputs, branch_source)
        except TypeError as e:  # pragma: no cover
            raise RuntimeError(f'Decision branch source {branch_source} is not a valid type.') from e

    for branch in decision.branches:
        if _accepts(branch):
            return self._handle_path(branch.path, inputs, fork_stack)

    raise RuntimeError(f'No branch matched inputs {inputs} for decision node {decision}.')
|
|
817
|
+
|
|
818
|
+
def _handle_node(
    self,
    next_node: BaseNode[StateT, DepsT, Any] | End[Any],
    fork_stack: ForkStack,
) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
    """Translate a node instance returned by a `NodeStep` into the next work unit.

    The isinstance order is significant: `StepNode` and `JoinNode` are checked
    before the general `BaseNode` case.
    """
    if isinstance(next_node, StepNode):
        # Route directly to the referenced step with the node's recorded inputs.
        task = GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())
        return [task]
    if isinstance(next_node, JoinNode):
        # Joins are reduced by the run loop; hand back a join item instead of a task.
        return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
    if isinstance(next_node, BaseNode):
        # Wrap an arbitrary BaseNode subclass in a NodeStep and run the instance as its input.
        wrapper = NodeStep(next_node.__class__)
        return [GraphTask(wrapper.id, next_node, fork_stack, self.get_next_task_id())]
    if isinstance(next_node, End):
        return EndMarker(next_node.data)
    assert_never(next_node)
|
|
834
|
+
|
|
835
|
+
def _get_completed_fork_runs(
    self,
    t: GraphTask,
    active_tasks: Iterable[GraphTask],
) -> list[tuple[JoinID, NodeRunID]]:
    """Return the (join_id, fork_run_id) pairs whose fork runs are now complete.

    Only reducers whose fork run appears in *t*'s fork stack are candidates —
    other runs cannot have been completed by this task finishing.
    """
    run_ids_in_stack = {fsi.node_run_id for fsi in t.fork_stack}

    completed: list[tuple[JoinID, NodeRunID]] = []
    for join_id, fork_run_id in self.active_reducers.keys():
        if fork_run_id not in run_ids_in_stack:
            # The fork_run_id is not in the current task's fork stack, so this
            # task didn't complete it.
            continue
        # This reducer _may_ now be ready to finalize:
        if self._is_fork_run_completed(active_tasks, join_id, fork_run_id):
            completed.append((join_id, fork_run_id))

    return completed
|
|
853
|
+
|
|
854
|
+
def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]:
    """Consume the head marker of *path* and produce the resulting tasks.

    Transform and label markers recurse into ``path.next_path``; a destination
    marker yields a concrete `GraphTask` for the target node.
    """
    if not path.items:
        return []  # pragma: no cover

    head = path.items[0]
    assert not isinstance(head, MapMarker | BroadcastMarker), (
        'These markers should be removed from paths during graph building'
    )
    if isinstance(head, DestinationMarker):
        return [GraphTask(head.destination_id, inputs, fork_stack, self.get_next_task_id())]
    if isinstance(head, TransformMarker):
        transformed = head.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
        return self._handle_path(path.next_path, transformed, fork_stack)
    if isinstance(head, LabelMarker):
        # Labels carry no runtime behavior; skip past them.
        return self._handle_path(path.next_path, inputs, fork_stack)
    assert_never(head)
|
|
871
|
+
|
|
872
|
+
def _handle_edges(
    self, node: AnyNode, inputs: Any, fork_stack: ForkStack
) -> Sequence[GraphTask] | _GraphTaskAsyncIterable:
    """Route *inputs* along *node*'s outgoing edges, forking if *node* is a `Fork`."""
    handler = self._handle_fork_edges if isinstance(node, Fork) else self._handle_non_fork_edges
    return handler(node, inputs, fork_stack)
|
|
879
|
+
|
|
880
|
+
def _handle_non_fork_edges(self, node: AnyNode, inputs: Any, fork_stack: ForkStack) -> Sequence[GraphTask]:
    """Follow the single outgoing path of a non-fork node."""
    outgoing = self.graph.edges_by_source.get(node.id, [])
    # Graph building guarantees exactly one outgoing path for non-fork nodes.
    assert len(outgoing) == 1
    (only_path,) = outgoing
    return self._handle_path(only_path, inputs, fork_stack)
|
|
884
|
+
|
|
885
|
+
def _handle_fork_edges(
    self, node: Fork[Any, Any], inputs: Any, fork_stack: ForkStack
) -> Sequence[GraphTask] | _GraphTaskAsyncIterable:
    """Expand a fork node into its downstream tasks.

    For a broadcast fork (``not node.is_map``) the same *inputs* value is sent
    down every outgoing path. For a map fork, each item of *inputs* (which must
    be a sync or async iterable) is sent down the single outgoing path; an async
    iterable yields batches lazily via a `_GraphTaskAsyncIterable`. Every
    spawned task gets a `ForkStackItem` recording this fork, the fork-run id,
    and the item/branch index.
    """
    edges = self.graph.edges_by_source.get(node.id, [])
    assert len(edges) == 1 or (isinstance(node, Fork) and not node.is_map), (
        edges,
        node.id,
    )  # this should have already been ensured during graph building

    new_tasks: list[GraphTask] = []
    node_run_id = self.get_next_node_run_id()
    if node.is_map:
        # If the map specifies a downstream join id, eagerly create a join state for it
        if (join_id := node.downstream_join_id) is not None:
            join_node = self.graph.nodes[join_id]
            assert isinstance(join_node, Join)
            self.active_reducers[(join_id, node_run_id)] = JoinState(join_node.initial_factory(), fork_stack)

        # Eagerly raise a clear error if the input value is not iterable as expected
        if _is_any_iterable(inputs):
            for thread_index, input_item in enumerate(inputs):
                item_tasks = self._handle_path(
                    edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),)
                )
                new_tasks += item_tasks
        elif _is_any_async_iterable(inputs):

            async def handle_async_iterable() -> AsyncIterator[Sequence[GraphTask]]:
                # Consume the async iterable lazily, yielding one batch of tasks
                # per item; the batches are scheduled by `_run_tracked_task`.
                thread_index = 0
                async for input_item in inputs:
                    item_tasks = self._handle_path(
                        edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),)
                    )
                    yield item_tasks
                    thread_index += 1

            return _GraphTaskAsyncIterable(handle_async_iterable(), fork_stack)

        else:
            raise RuntimeError(f'Cannot map non-iterable value: {inputs!r}')
    else:
        # Broadcast: one new task chain per outgoing path, all sharing node_run_id.
        for i, path in enumerate(edges):
            new_tasks += self._handle_path(path, inputs, fork_stack + (ForkStackItem(node.id, node_run_id, i),))
    return new_tasks
|
|
929
|
+
|
|
930
|
+
def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinID, fork_run_id: NodeRunID) -> bool:
    """Return True when no active task can still deliver work to the join.

    A task blocks completion only if it belongs to this fork run (the run id
    appears in its fork stack) AND it sits at the join itself or at one of the
    parent fork's intermediate nodes.
    """
    parent_fork = self.graph.get_parent_fork(join_id)
    blocking_nodes = parent_fork.intermediate_nodes
    for task in tasks:
        if all(item.node_run_id != fork_run_id for item in task.fork_stack):
            continue  # task belongs to a different fork run
        if task.node_id == join_id or task.node_id in blocking_nodes:
            return False
    return True
|
|
941
|
+
|
|
942
|
+
async def _cancel_sibling_tasks(self, parent_fork_id: ForkID, node_run_id: NodeRunID):
    """Finish every active task spawned by the given run of the given fork.

    Task ids are collected first so that `_finish_task`'s mutation of
    `self.active_tasks` cannot interfere with the scan.
    """
    doomed = {
        task_id
        for task_id, task in self.active_tasks.items()
        if any(item.fork_id == parent_fork_id and item.node_run_id == node_run_id for item in task.fork_stack)
    }
    for task_id in doomed:
        await self._finish_task(task_id)
|
|
953
|
+
|
|
954
|
+
|
|
955
|
+
def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]:
|
|
956
|
+
return isinstance(x, Iterable)
|
|
957
|
+
|
|
958
|
+
|
|
959
|
+
def _is_any_async_iterable(x: Any) -> TypeGuard[AsyncIterable[Any]]:
|
|
960
|
+
return isinstance(x, AsyncIterable)
|
|
961
|
+
|
|
962
|
+
|
|
963
|
+
@contextmanager
def _unwrap_exception_groups():
    """Context manager that re-raises the first exception of an `ExceptionGroup` bare.

    Used so callers see a plain exception instead of the group raised by the
    structured-concurrency task group.
    """
    # I need to use a helper function for this because I can't figure out a way to get pyright
    # to type-check the ExceptionGroup catching in both 3.13 and 3.10 without emitting type errors in one;
    # if I try to ignore them in one, I get unnecessary-type-ignore errors in the other
    if TYPE_CHECKING:
        yield
    else:
        try:
            yield
        except ExceptionGroup as e:
            exception = e.exceptions[0]
            if exception.__cause__ is None:
                # bizarrely, this prevents recursion errors when formatting the exception for logfire
                # NOTE(review): this assignment is not a no-op — setting `__cause__` (even to the
                # same None) also sets `__suppress_context__ = True`, which changes how the
                # traceback is rendered; presumably that's the mechanism here — confirm.
                exception.__cause__ = None
            raise exception
|