pydantic-graph 1.0.14__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- pydantic_graph/_utils.py +39 -0
- pydantic_graph/beta/__init__.py +25 -0
- pydantic_graph/beta/decision.py +276 -0
- pydantic_graph/beta/graph.py +978 -0
- pydantic_graph/beta/graph_builder.py +1053 -0
- pydantic_graph/beta/id_types.py +76 -0
- pydantic_graph/beta/join.py +249 -0
- pydantic_graph/beta/mermaid.py +208 -0
- pydantic_graph/beta/node.py +95 -0
- pydantic_graph/beta/node_types.py +90 -0
- pydantic_graph/beta/parent_forks.py +232 -0
- pydantic_graph/beta/paths.py +421 -0
- pydantic_graph/beta/step.py +253 -0
- pydantic_graph/beta/util.py +90 -0
- pydantic_graph/exceptions.py +22 -0
- pydantic_graph/graph.py +12 -4
- pydantic_graph/nodes.py +0 -2
- pydantic_graph/persistence/in_mem.py +3 -3
- {pydantic_graph-1.0.14.dist-info → pydantic_graph-1.22.0.dist-info}/METADATA +1 -1
- pydantic_graph-1.22.0.dist-info/RECORD +28 -0
- pydantic_graph-1.0.14.dist-info/RECORD +0 -15
- {pydantic_graph-1.0.14.dist-info → pydantic_graph-1.22.0.dist-info}/WHEEL +0 -0
- {pydantic_graph-1.0.14.dist-info → pydantic_graph-1.22.0.dist-info}/licenses/LICENSE +0 -0
pydantic_graph/beta/graph_builder.py (added, +1053 lines)
@@ -0,0 +1,1053 @@
"""Graph builder for constructing executable graph definitions.

This module provides the GraphBuilder class and related utilities for
constructing typed, executable graph definitions with steps, joins,
decisions, and edge routing.
"""

from __future__ import annotations

import inspect
from collections import Counter, defaultdict
from collections.abc import AsyncIterable, Callable, Iterable
from dataclasses import dataclass, replace
from types import NoneType
from typing import Any, Generic, Literal, cast, get_origin, get_type_hints, overload

from typing_extensions import Never, TypeAliasType, TypeVar

from pydantic_graph import _utils, exceptions
from pydantic_graph._utils import UNSET, Unset
from pydantic_graph.beta.decision import Decision, DecisionBranch, DecisionBranchBuilder
from pydantic_graph.beta.graph import Graph
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID, generate_placeholder_node_id, replace_placeholder_id
from pydantic_graph.beta.join import Join, JoinNode, ReducerFunction
from pydantic_graph.beta.mermaid import build_mermaid_graph
from pydantic_graph.beta.node import (
    EndNode,
    Fork,
    StartNode,
)
from pydantic_graph.beta.node_types import (
    AnyDestinationNode,
    AnyNode,
    DestinationNode,
    SourceNode,
)
from pydantic_graph.beta.parent_forks import ParentFork, ParentForkFinder
from pydantic_graph.beta.paths import (
    BroadcastMarker,
    DestinationMarker,
    EdgePath,
    EdgePathBuilder,
    MapMarker,
    Path,
    PathBuilder,
)
from pydantic_graph.beta.step import NodeStep, Step, StepContext, StepFunction, StepNode, StreamFunction
from pydantic_graph.beta.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression
from pydantic_graph.exceptions import GraphBuildingError, GraphValidationError
from pydantic_graph.nodes import BaseNode, End

StateT = TypeVar('StateT', infer_variance=True)
DepsT = TypeVar('DepsT', infer_variance=True)
InputT = TypeVar('InputT', infer_variance=True)
OutputT = TypeVar('OutputT', infer_variance=True)
SourceT = TypeVar('SourceT', infer_variance=True)
SourceNodeT = TypeVar('SourceNodeT', bound=BaseNode[Any, Any, Any], infer_variance=True)
SourceOutputT = TypeVar('SourceOutputT', infer_variance=True)
GraphInputT = TypeVar('GraphInputT', infer_variance=True)
GraphOutputT = TypeVar('GraphOutputT', infer_variance=True)
T = TypeVar('T', infer_variance=True)


@dataclass(init=False)
class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
    """A builder for constructing executable graph definitions.

    GraphBuilder provides a fluent interface for defining nodes, edges, and
    routing in a graph workflow. It supports typed state, dependencies, and
    input/output validation.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        GraphInputT: The type of the graph input data
        GraphOutputT: The type of the graph output data
    """

    name: str | None
    """Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method."""

    state_type: TypeOrTypeExpression[StateT]
    """The type of the graph state."""

    deps_type: TypeOrTypeExpression[DepsT]
    """The type of the dependencies."""

    input_type: TypeOrTypeExpression[GraphInputT]
    """The type of the graph input data."""

    output_type: TypeOrTypeExpression[GraphOutputT]
    """The type of the graph output data."""

    auto_instrument: bool
    """Whether to automatically create instrumentation spans."""

    _nodes: dict[NodeID, AnyNode]
    """Internal storage for nodes in the graph."""

    _edges_by_source: dict[NodeID, list[Path]]
    """Internal storage for edges by source node."""

    _decision_index: int
    """Counter for generating unique decision node IDs."""

    Source = TypeAliasType('Source', SourceNode[StateT, DepsT, OutputT], type_params=(OutputT,))
    Destination = TypeAliasType('Destination', DestinationNode[StateT, DepsT, InputT], type_params=(InputT,))

    def __init__(
        self,
        *,
        name: str | None = None,
        state_type: TypeOrTypeExpression[StateT] = NoneType,
        deps_type: TypeOrTypeExpression[DepsT] = NoneType,
        input_type: TypeOrTypeExpression[GraphInputT] = NoneType,
        output_type: TypeOrTypeExpression[GraphOutputT] = NoneType,
        auto_instrument: bool = True,
    ):
        """Initialize a graph builder.

        Args:
            name: Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method.
            state_type: The type of the graph state
            deps_type: The type of the dependencies
            input_type: The type of the graph input data
            output_type: The type of the graph output data
            auto_instrument: Whether to automatically create instrumentation spans
        """
        self.name = name

        self.state_type = state_type
        self.deps_type = deps_type
        self.input_type = input_type
        self.output_type = output_type

        self.auto_instrument = auto_instrument

        self._nodes = {}
        self._edges_by_source = defaultdict(list)
        self._decision_index = 1

        self._start_node = StartNode[GraphInputT]()
        self._end_node = EndNode[GraphOutputT]()

    # Node building
    @property
    def start_node(self) -> StartNode[GraphInputT]:
        """Get the start node for the graph.

        Returns:
            The start node that receives the initial graph input
        """
        return self._start_node

    @property
    def end_node(self) -> EndNode[GraphOutputT]:
        """Get the end node for the graph.

        Returns:
            The end node that produces the final graph output
        """
        return self._end_node

    @overload
    def step(
        self,
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]: ...
    @overload
    def step(
        self,
        call: StepFunction[StateT, DepsT, InputT, OutputT],
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> Step[StateT, DepsT, InputT, OutputT]: ...
    def step(
        self,
        call: StepFunction[StateT, DepsT, InputT, OutputT] | None = None,
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> (
        Step[StateT, DepsT, InputT, OutputT]
        | Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]
    ):
        """Create a step from a step function.

        This method can be used as a decorator or called directly to create
        a step node from an async function.

        Args:
            call: The step function to wrap
            node_id: Optional ID for the node
            label: Optional human-readable label

        Returns:
            Either a Step instance or a decorator function
        """
        if call is None:

            def decorator(
                func: StepFunction[StateT, DepsT, InputT, OutputT],
            ) -> Step[StateT, DepsT, InputT, OutputT]:
                return self.step(call=func, node_id=node_id, label=label)

            return decorator

        node_id = node_id or get_callable_name(call)

        step = Step[StateT, DepsT, InputT, OutputT](id=NodeID(node_id), call=call, label=label)

        return step

    @overload
    def stream(
        self,
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> Callable[
        [StreamFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
    ]: ...
    @overload
    def stream(
        self,
        call: StreamFunction[StateT, DepsT, InputT, OutputT],
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]: ...
    @overload
    def stream(
        self,
        call: StreamFunction[StateT, DepsT, InputT, OutputT] | None = None,
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> (
        Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
        | Callable[
            [StreamFunction[StateT, DepsT, InputT, OutputT]],
            Step[StateT, DepsT, InputT, AsyncIterable[OutputT]],
        ]
    ): ...
    def stream(
        self,
        call: StreamFunction[StateT, DepsT, InputT, OutputT] | None = None,
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> (
        Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
        | Callable[
            [StreamFunction[StateT, DepsT, InputT, OutputT]],
            Step[StateT, DepsT, InputT, AsyncIterable[OutputT]],
        ]
    ):
        """Create a step from an async iterator (which functions like a "stream").

        This method can be used as a decorator or called directly to create
        a step node from an async function.

        Args:
            call: The step function to wrap
            node_id: Optional ID for the node
            label: Optional human-readable label

        Returns:
            Either a Step instance or a decorator function
        """
        if call is None:

            def decorator(
                func: StreamFunction[StateT, DepsT, InputT, OutputT],
            ) -> Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]:
                return self.stream(call=func, node_id=node_id, label=label)

            return decorator

        # We need to wrap the call so that we can call `await` even though the result is an async iterator
        async def wrapper(ctx: StepContext[StateT, DepsT, InputT]):
            return call(ctx)

        return self.step(call=wrapper, node_id=node_id, label=label)

    @overload
    def join(
        self,
        reducer: ReducerFunction[StateT, DepsT, InputT, OutputT],
        *,
        initial: OutputT,
        node_id: str | None = None,
        parent_fork_id: str | None = None,
        preferred_parent_fork: Literal['farthest', 'closest'] = 'farthest',
    ) -> Join[StateT, DepsT, InputT, OutputT]: ...
    @overload
    def join(
        self,
        reducer: ReducerFunction[StateT, DepsT, InputT, OutputT],
        *,
        initial_factory: Callable[[], OutputT],
        node_id: str | None = None,
        parent_fork_id: str | None = None,
        preferred_parent_fork: Literal['farthest', 'closest'] = 'farthest',
    ) -> Join[StateT, DepsT, InputT, OutputT]: ...

    def join(
        self,
        reducer: ReducerFunction[StateT, DepsT, InputT, OutputT],
        *,
        initial: OutputT | Unset = UNSET,
        initial_factory: Callable[[], OutputT] | Unset = UNSET,
        node_id: str | None = None,
        parent_fork_id: str | None = None,
        preferred_parent_fork: Literal['farthest', 'closest'] = 'farthest',
    ) -> Join[StateT, DepsT, InputT, OutputT]:
        if initial_factory is UNSET:
            initial_factory = lambda: initial  # pyright: ignore[reportAssignmentType] # noqa E731

        return Join[StateT, DepsT, InputT, OutputT](
            id=JoinID(NodeID(node_id or generate_placeholder_node_id(get_callable_name(reducer)))),
            reducer=reducer,
            initial_factory=cast(Callable[[], OutputT], initial_factory),
            parent_fork_id=ForkID(parent_fork_id) if parent_fork_id is not None else None,
            preferred_parent_fork=preferred_parent_fork,
        )

    # Edge building
    def add(self, *edges: EdgePath[StateT, DepsT]) -> None:  # noqa C901
        """Add one or more edge paths to the graph.

        This method processes edge paths and automatically creates any necessary
        fork nodes for broadcasts and maps.

        Args:
            *edges: The edge paths to add to the graph
        """

        def _handle_path(p: Path):
            """Process a path and create necessary fork nodes.

            Args:
                p: The path to process
            """
            for item in p.items:
                if isinstance(item, BroadcastMarker):
                    new_node = Fork[Any, Any](id=item.fork_id, is_map=False, downstream_join_id=None)
                    self._insert_node(new_node)
                    for path in item.paths:
                        _handle_path(Path(items=[*path.items]))
                elif isinstance(item, MapMarker):
                    new_node = Fork[Any, Any](id=item.fork_id, is_map=True, downstream_join_id=item.downstream_join_id)
                    self._insert_node(new_node)
                elif isinstance(item, DestinationMarker):
                    pass

        def _handle_destination_node(d: AnyDestinationNode):
            if id(d) in destination_ids:
                return  # prevent infinite recursion if there is a cycle of decisions

            destination_ids.add(id(d))
            destinations.append(d)
            self._insert_node(d)
            if isinstance(d, Decision):
                for branch in d.branches:
                    _handle_path(branch.path)
                    for d2 in branch.destinations:
                        _handle_destination_node(d2)

        destination_ids = set[int]()
        destinations: list[AnyDestinationNode] = []
        for edge in edges:
            for source_node in edge.sources:
                self._insert_node(source_node)
                self._edges_by_source[source_node.id].append(edge.path)
            for destination_node in edge.destinations:
                _handle_destination_node(destination_node)
            _handle_path(edge.path)

        # Automatically create edges from step function return hints including `BaseNode`s
        for destination in destinations:
            if not isinstance(destination, Step) or isinstance(destination, NodeStep):
                continue
            parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
            type_hints = get_type_hints(destination.call, localns=parent_namespace, include_extras=True)
            try:
                return_hint = type_hints['return']
            except KeyError:
                pass
            else:
                edge = self._edge_from_return_hint(destination, return_hint)
                if edge is not None:
                    self.add(edge)

    def add_edge(self, source: Source[T], destination: Destination[T], *, label: str | None = None) -> None:
        """Add a simple edge between two nodes.

        Args:
            source: The source node
            destination: The destination node
            label: Optional label for the edge
        """
        builder = self.edge_from(source)
        if label is not None:
            builder = builder.label(label)
        self.add(builder.to(destination))

    def add_mapping_edge(
        self,
        source: Source[Iterable[T]],
        map_to: Destination[T],
        *,
        pre_map_label: str | None = None,
        post_map_label: str | None = None,
        fork_id: ForkID | None = None,
        downstream_join_id: JoinID | None = None,
    ) -> None:
        """Add an edge that maps iterable data across parallel paths.

        Args:
            source: The source node that produces iterable data
            map_to: The destination node that receives individual items
            pre_map_label: Optional label before the map operation
            post_map_label: Optional label after the map operation
            fork_id: Optional ID for the fork node produced for this map operation
            downstream_join_id: Optional ID of a join node that will always be downstream of this map.
                Specifying this ensures correct handling if you try to map an empty iterable.
        """
        builder = self.edge_from(source)
        if pre_map_label is not None:
            builder = builder.label(pre_map_label)
        builder = builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id)
        if post_map_label is not None:
            builder = builder.label(post_map_label)
        self.add(builder.to(map_to))

    # TODO(DavidM): Support adding subgraphs; I think this behaves like a step with the same inputs/outputs but gets rendered as a subgraph in mermaid

    def edge_from(self, *sources: Source[SourceOutputT]) -> EdgePathBuilder[StateT, DepsT, SourceOutputT]:
        """Create an edge path builder starting from the given source nodes.

        Args:
            *sources: The source nodes to start the edge path from

        Returns:
            An EdgePathBuilder for constructing the complete edge path
        """
        return EdgePathBuilder[StateT, DepsT, SourceOutputT](
            sources=sources, path_builder=PathBuilder(working_items=[])
        )

    def decision(self, *, note: str | None = None, node_id: str | None = None) -> Decision[StateT, DepsT, Never]:
        """Create a new decision node.

        Args:
            note: Optional note to describe the decision logic
            node_id: Optional ID for the node produced for this decision logic

        Returns:
            A new Decision node with no branches
        """
        return Decision(id=NodeID(node_id or generate_placeholder_node_id('decision')), branches=[], note=note)

    def match(
        self,
        source: TypeOrTypeExpression[SourceT],
        *,
        matches: Callable[[Any], bool] | None = None,
    ) -> DecisionBranchBuilder[StateT, DepsT, SourceT, SourceT, Never]:
        """Create a decision branch matcher.

        Args:
            source: The type or type expression to match against
            matches: Optional custom matching function

        Returns:
            A DecisionBranchBuilder for constructing the branch
        """
        # Note, the following node_id really is just a placeholder and shouldn't end up in the final graph
        # This is why we don't expose a way for end users to override the value used here.
        node_id = NodeID(generate_placeholder_node_id('match_decision'))
        decision = Decision[StateT, DepsT, Never](id=node_id, branches=[], note=None)
        new_path_builder = PathBuilder[StateT, DepsT, SourceT](working_items=[])
        return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder)

    def match_node(
        self,
        source: type[SourceNodeT],
        *,
        matches: Callable[[Any], bool] | None = None,
    ) -> DecisionBranch[SourceNodeT]:
        """Create a decision branch for BaseNode subclasses.

        This is similar to match() but specifically designed for matching
        against BaseNode types from the v1 system.

        Args:
            source: The BaseNode subclass to match against
            matches: Optional custom matching function

        Returns:
            A DecisionBranch for the BaseNode type
        """
        node = NodeStep(source)
        path = Path(items=[DestinationMarker(node.id)])
        return DecisionBranch(source=source, matches=matches, path=path, destinations=[node])

    def node(
        self,
        node_type: type[BaseNode[StateT, DepsT, GraphOutputT]],
    ) -> EdgePath[StateT, DepsT]:
        """Create an edge path from a BaseNode class.

        This method integrates v1-style BaseNode classes into the v2 graph
        system by analyzing their type hints and creating appropriate edges.

        Args:
            node_type: The BaseNode subclass to integrate

        Returns:
            An EdgePath representing the node and its connections

        Raises:
            GraphSetupError: If the node type is missing required type hints
        """
        parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
        type_hints = get_type_hints(node_type.run, localns=parent_namespace, include_extras=True)
        try:
            return_hint = type_hints['return']
        except KeyError as e:  # pragma: no cover
            raise exceptions.GraphSetupError(
                f'Node {node_type} is missing a return type hint on its `run` method'
            ) from e

        node = NodeStep(node_type)

        edge = self._edge_from_return_hint(node, return_hint)
        if not edge:  # pragma: no cover
            raise exceptions.GraphSetupError(f'Node {node_type} is missing a return type hint on its `run` method')

        return edge

    # Helpers
    def _insert_node(self, node: AnyNode) -> None:
        """Insert a node into the graph, checking for ID conflicts.

        Args:
            node: The node to insert

        Raises:
            ValueError: If a different node with the same ID already exists
        """
        existing = self._nodes.get(node.id)
        if existing is None:
            self._nodes[node.id] = node
        elif isinstance(existing, NodeStep) and isinstance(node, NodeStep) and existing.node_type is node.node_type:
            pass
        elif existing is not node:
            raise GraphBuildingError(
                f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}'
            )

    def _edge_from_return_hint(
        self, node: SourceNode[StateT, DepsT, Any], return_hint: TypeOrTypeExpression[Any]
    ) -> EdgePath[StateT, DepsT] | None:
        """Create edges from a return type hint.

        This method analyzes return type hints from step functions or node methods
        to automatically create appropriate edges in the graph.

        Args:
            node: The source node
            return_hint: The return type hint to analyze

        Returns:
            An EdgePath if edges can be inferred, None otherwise

        Raises:
            GraphSetupError: If the return type hint is invalid or incomplete
        """
        destinations: list[AnyDestinationNode] = []
        union_args = _utils.get_union_args(return_hint)
        for return_type in union_args:
            return_type, annotations = _utils.unpack_annotated(return_type)
            return_type_origin = get_origin(return_type) or return_type
            if return_type_origin is End:
                destinations.append(self.end_node)
            elif return_type_origin is BaseNode:
                raise exceptions.GraphSetupError(  # pragma: no cover
                    f'Node {node} return type hint includes a plain `BaseNode`. '
                    'Edge inference requires each possible returned `BaseNode` subclass to be listed explicitly.'
                )
            elif return_type_origin is StepNode:
                step = cast(
                    Step[StateT, DepsT, Any, Any] | None,
                    next((a for a in annotations if isinstance(a, Step)), None),  # pyright: ignore[reportUnknownArgumentType]
                )
                if step is None:
                    raise exceptions.GraphSetupError(  # pragma: no cover
                        f'Node {node} return type hint includes a `StepNode` without a `Step` annotation. '
                        'When returning `my_step.as_node()`, use `Annotated[StepNode[StateT, DepsT], my_step]` as the return type hint.'
                    )
                destinations.append(step)
            elif return_type_origin is JoinNode:
                join = cast(
                    Join[StateT, DepsT, Any, Any] | None,
                    next((a for a in annotations if isinstance(a, Join)), None),  # pyright: ignore[reportUnknownArgumentType]
                )
                if join is None:
                    raise exceptions.GraphSetupError(  # pragma: no cover
                        f'Node {node} return type hint includes a `JoinNode` without a `Join` annotation. '
                        'When returning `my_join.as_node()`, use `Annotated[JoinNode[StateT, DepsT], my_join]` as the return type hint.'
                    )
                destinations.append(join)
            elif inspect.isclass(return_type_origin) and issubclass(return_type_origin, BaseNode):
                destinations.append(NodeStep(return_type))

        if len(destinations) < len(union_args):
            # Only build edges if all the return types are nodes
            return None

        edge = self.edge_from(node)
        if len(destinations) == 1:
            return edge.to(destinations[0])
        else:
            decision = self.decision()
            for destination in destinations:
                # We don't actually use this decision mechanism, but we need to build the edges for parent-fork finding
                decision = decision.branch(self.match(NoneType).to(destination))
            return edge.to(decision)

    # Graph building
    def build(self, validate_graph_structure: bool = True) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
        """Build the final executable graph from the accumulated nodes and edges.

        This method performs validation, normalization, and analysis of the graph
        structure to create a complete, executable graph instance.

        Args:
            validate_graph_structure: whether to perform validation of the graph structure
                See the docstring of _validate_graph_structure below for more details.

        Returns:
            A complete Graph instance ready for execution

        Raises:
            ValueError: If the graph structure is invalid (e.g., join without parent fork)
        """
        nodes = self._nodes
        edges_by_source = self._edges_by_source

        nodes, edges_by_source = _replace_placeholder_node_ids(nodes, edges_by_source)
        nodes, edges_by_source = _flatten_paths(nodes, edges_by_source)
        nodes, edges_by_source = _normalize_forks(nodes, edges_by_source)
        if validate_graph_structure:
            _validate_graph_structure(nodes, edges_by_source)
        parent_forks = _collect_dominating_forks(nodes, edges_by_source)
        intermediate_join_nodes = _compute_intermediate_join_nodes(nodes, parent_forks)

        return Graph[StateT, DepsT, GraphInputT, GraphOutputT](
            name=self.name,
            state_type=unpack_type_expression(self.state_type),
            deps_type=unpack_type_expression(self.deps_type),
            input_type=unpack_type_expression(self.input_type),
            output_type=unpack_type_expression(self.output_type),
            nodes=nodes,
            edges_by_source=edges_by_source,
            parent_forks=parent_forks,
            intermediate_join_nodes=intermediate_join_nodes,
            auto_instrument=self.auto_instrument,
        )


def _validate_graph_structure(  # noqa C901
    nodes: dict[NodeID, AnyNode],
    edges_by_source: dict[NodeID, list[Path]],
) -> None:
    """Validate the graph structure for common issues.

    This function raises an error if any of the following criteria are not met:
    1. There are edges from the start node
    2. There are edges to the end node
    3. No non-End node is a dead end (no outgoing edges)
    4. The end node is reachable from the start node
    5. All nodes are reachable from the start node

    Note 1: Under some circumstances it may be reasonable to build a graph that violates one or more of
    the above conditions. We may eventually add support for more granular control over validation,
    but today, if you want to build a graph that violates any of these assumptions you need to pass
    `validate_graph_structure=False` to the call to `GraphBuilder.build`.

    Note 2: Some of the earlier items in the above list are redundant with the later items.
    I've included the earlier items in the list as a reminder to ourselves if/when we add more granular validation
    because you might want to check the earlier items but not the later items, as described in Note 1.

    Args:
        nodes: The nodes in the graph
        edges_by_source: The edges by source node

    Raises:
        GraphBuildingError: If any of the aforementioned structural issues are found.
    """
    how_to_suppress = ' If this is intentional, you can suppress this error by passing `validate_graph_structure=False` to the call to `GraphBuilder.build`.'

    # Extract all destination IDs from edges and decision branches
    all_destinations: set[NodeID] = set()

    def _collect_destinations_from_path(path: Path) -> None:
        for item in path.items:
            if isinstance(item, DestinationMarker):
                all_destinations.add(item.destination_id)

    for paths in edges_by_source.values():
        for path in paths:
            _collect_destinations_from_path(path)

    # Also collect destinations from decision branches
    for node in nodes.values():
        if isinstance(node, Decision):
            for branch in node.branches:
                _collect_destinations_from_path(branch.path)

    # Check 1: Check if there are edges from the start node
    start_edges = edges_by_source.get(StartNode.id, [])
    if not start_edges:
        raise GraphValidationError('The graph has no edges from the start node.' + how_to_suppress)

    # Check 2: Check if there are edges to the end node
    if EndNode.id not in all_destinations:
        raise GraphValidationError('The graph has no edges to the end node.' + how_to_suppress)

    # Check 3: Find all nodes with no outgoing edges (dead ends)
    dead_end_nodes: list[NodeID] = []
    for node_id, node in nodes.items():
        # Skip the end node itself
        if isinstance(node, EndNode):
            continue

        # Check if this node has any outgoing edges
        has_edges = node_id in edges_by_source and len(edges_by_source[node_id]) > 0

        # Also check if it's a decision node with branches
        if isinstance(node, Decision):
            has_edges = has_edges or len(node.branches) > 0

        if not has_edges:
            dead_end_nodes.append(node_id)

    if dead_end_nodes:
        raise GraphValidationError(f'The following nodes have no outgoing edges: {dead_end_nodes}.' + how_to_suppress)

    # Checks 4 and 5: Ensure all nodes (and in particular, the end node) are reachable from the start node
    reachable: set[NodeID] = {StartNode.id}
    to_visit = [StartNode.id]

    while to_visit:
        current_id = to_visit.pop()

        # Add destinations from regular edges
        for path in edges_by_source.get(current_id, []):
            for item in path.items:
                if isinstance(item, DestinationMarker):
                    if item.destination_id not in reachable:
                        reachable.add(item.destination_id)
                        to_visit.append(item.destination_id)

        # Add destinations from decision branches
        current_node = nodes.get(current_id)
        if isinstance(current_node, Decision):
            for branch in current_node.branches:
                for item in branch.path.items:
                    if isinstance(item, DestinationMarker):
                        if item.destination_id not in reachable:
                            reachable.add(item.destination_id)
                            to_visit.append(item.destination_id)

    unreachable_nodes = [node_id for node_id in nodes if node_id not in reachable]
    if unreachable_nodes:
        raise GraphValidationError(
            f'The following nodes are not reachable from the start node: {unreachable_nodes}.' + how_to_suppress
        )


def _flatten_paths(
    nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]]
) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]:
    new_nodes = nodes.copy()
    new_edges: dict[NodeID, list[Path]] = defaultdict(list)

    paths_to_handle: list[tuple[NodeID, Path]] = []

    def _split_at_first_fork(path: Path) -> tuple[Path, list[tuple[NodeID, Path]]]:
        for i, item in enumerate(path.items):
            if isinstance(item, MapMarker):
                assert item.fork_id in nodes, 'This should have been added to the node during GraphBuilder.add'
                upstream = Path(list(path.items[:i]) + [DestinationMarker(item.fork_id)])
                downstream = Path(path.items[i + 1 :])
                return upstream, [(item.fork_id, downstream)]

            if isinstance(item, BroadcastMarker):
                assert item.fork_id in nodes, 'This should have been added to the node during GraphBuilder.add'
                upstream = Path(list(path.items[:i]) + [DestinationMarker(item.fork_id)])
                return upstream, [(item.fork_id, p) for p in item.paths]
        return path, []

    for node in new_nodes.values():
        if isinstance(node, Decision):
            for branch in node.branches:
                upstream, downstreams = _split_at_first_fork(branch.path)
                branch.path = upstream
                paths_to_handle.extend(downstreams)

    for source_id, edges_from_source in edges.items():
        for path in edges_from_source:
            paths_to_handle.append((source_id, path))

    while paths_to_handle:
        source_id, path = paths_to_handle.pop()
        upstream, downstreams = _split_at_first_fork(path)
        new_edges[source_id].append(upstream)
        paths_to_handle.extend(downstreams)

    return new_nodes, dict(new_edges)


def _normalize_forks(
    nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]]
) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]:
    """Normalize the graph structure so only broadcast forks have multiple outgoing edges.

    This function ensures that any node with multiple outgoing edges is converted
    to use an explicit broadcast fork, simplifying the graph execution model.

    Args:
        nodes: The nodes in the graph
        edges: The edges by source node

    Returns:
        A tuple of normalized nodes and edges
    """
    new_nodes = nodes.copy()
    new_edges: dict[NodeID, list[Path]] = {}

    paths_to_handle: list[Path] = []

    for source_id, edges_from_source in edges.items():
        paths_to_handle.extend(edges_from_source)

        node = nodes[source_id]
        if isinstance(node, Fork) and not node.is_map:
            new_edges[source_id] = edges_from_source
            continue  # broadcast fork; nothing to do
        if len(edges_from_source) == 1:
            new_edges[source_id] = edges_from_source
            continue
        new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_map=False, downstream_join_id=None)
        new_nodes[new_fork.id] = new_fork
        new_edges[source_id] = [Path(items=[DestinationMarker(new_fork.id)])]
        new_edges[new_fork.id] = edges_from_source

    return new_nodes, new_edges


def _collect_dominating_forks(
    graph_nodes: dict[NodeID, AnyNode], graph_edges_by_source: dict[NodeID, list[Path]]
) -> dict[JoinID, ParentFork[NodeID]]:
    """Find the dominating fork for each join node in the graph.

    This function analyzes the graph structure to find the parent fork that
    dominates each join node, which is necessary for proper synchronization
    during graph execution.

    Args:
        graph_nodes: All nodes in the graph
        graph_edges_by_source: Edges organized by source node

    Returns:
        A mapping from join IDs to their parent fork information

    Raises:
        ValueError: If any join node lacks a dominating fork
    """
    nodes = set(graph_nodes)
    start_ids: set[NodeID] = {StartNode.id}
    edges: dict[NodeID, list[NodeID]] = defaultdict(list)

    fork_ids: set[NodeID] = set(start_ids)
    for source_id in nodes:
        working_source_id = source_id
        node = graph_nodes.get(source_id)

        if isinstance(node, Fork):
            fork_ids.add(node.id)

        def _handle_path(path: Path, last_source_id: NodeID):
            """Process a path and collect edges and fork information.

            Args:
                path: The path to process
                last_source_id: The current source node ID
            """
            for item in path.items:  # pragma: no branch
                # No need to handle MapMarker or BroadcastMarker here as these should have all been removed
                # by the call to `_flatten_paths`
                if isinstance(item, DestinationMarker):
                    edges[last_source_id].append(item.destination_id)
                    # Destinations should only ever occur as the last item in the list, so no need to update the working_source_id
                    break

        if isinstance(node, Decision):
            for branch in node.branches:
                _handle_path(branch.path, working_source_id)
        else:
            for path in graph_edges_by_source.get(source_id, []):
                _handle_path(path, source_id)

    finder = ParentForkFinder(
        nodes=nodes,
        start_ids=start_ids,
        fork_ids=fork_ids,
        edges=edges,
    )

    joins = [node for node in graph_nodes.values() if isinstance(node, Join)]
    dominating_forks: dict[JoinID, ParentFork[NodeID]] = {}
    for join in joins:
        dominating_fork = finder.find_parent_fork(
            join.id, parent_fork_id=join.parent_fork_id, prefer_closest=join.preferred_parent_fork == 'closest'
        )
        if dominating_fork is None:
            rendered_mermaid_graph = build_mermaid_graph(graph_nodes, graph_edges_by_source).render()
            raise GraphBuildingError(f"""A node in the graph is missing a dominating fork.

For every Join J in the graph, there must be a Fork F between the StartNode and J satisfying:
* Every path from the StartNode to J passes through F
* There are no cycles in the graph including J that don't pass through F.
In this case, F is called a "dominating fork" for J.

This is used to determine when all tasks upstream of this Join are complete and we can proceed with execution.

Mermaid diagram:
{rendered_mermaid_graph}

Join {join.id!r} in this graph has no dominating fork in this graph.""")
        dominating_forks[join.id] = dominating_fork

    return dominating_forks


def _compute_intermediate_join_nodes(
    nodes: dict[NodeID, AnyNode], parent_forks: dict[JoinID, ParentFork[NodeID]]
) -> dict[JoinID, set[JoinID]]:
    """Compute which joins have other joins as intermediate nodes.

    A join J1 is an intermediate node of join J2 if J1 appears in J2's intermediate_nodes
    (as computed relative to J2's parent fork).

    This information is used to determine:
    1. Which joins are "final" (have no other joins in their intermediate_nodes)
    2. When selecting which reducer to proceed with when there are no active tasks

    Args:
        nodes: All nodes in the graph
        parent_forks: Parent fork information for each join

    Returns:
        A mapping from each join to the set of joins that are intermediate to it
    """
    intermediate_join_nodes: dict[JoinID, set[JoinID]] = {}

    for join_id, parent_fork in parent_forks.items():
        intermediate_joins = set[JoinID]()
        for intermediate_node_id in parent_fork.intermediate_nodes:
            # Check if this intermediate node is also a join
            intermediate_node = nodes.get(intermediate_node_id)
            if isinstance(intermediate_node, Join):
                # Add it regardless of whether it has the same parent fork
                intermediate_joins.add(JoinID(intermediate_node_id))
        intermediate_join_nodes[join_id] = intermediate_joins

    return intermediate_join_nodes


def _replace_placeholder_node_ids(nodes: dict[NodeID, AnyNode], edges_by_source: dict[NodeID, list[Path]]):
    node_id_remapping = _build_placeholder_node_id_remapping(nodes)
    replaced_nodes = {
        node_id_remapping.get(name, name): _update_node_with_id_remapping(node, node_id_remapping)
        for name, node in nodes.items()
    }
    replaced_edges_by_source = {
        node_id_remapping.get(source, source): [_update_path_with_id_remapping(p, node_id_remapping) for p in paths]
        for source, paths in edges_by_source.items()
    }
    return replaced_nodes, replaced_edges_by_source


def _build_placeholder_node_id_remapping(nodes: dict[NodeID, AnyNode]) -> dict[NodeID, NodeID]:
    """The determinism of the generated remapping here is dependent on the determinism of the ordering of the `nodes` dict.

    Note: If we want to generate more interesting names, we could try to make use of information about the edges
    into/out of the relevant nodes. I'm not sure if there's a good use case for that though so I didn't bother for now.
    """
    counter = Counter[str]()
    remapping: dict[NodeID, NodeID] = {}
    for node_id in nodes.keys():
        replaced_node_id = replace_placeholder_id(node_id)
        if replaced_node_id == node_id:
            continue
        counter[replaced_node_id] = count = counter[replaced_node_id] + 1
        remapping[node_id] = NodeID(f'{replaced_node_id}_{count}' if count > 1 else replaced_node_id)
    return remapping


def _update_node_with_id_remapping(node: AnyNode, node_id_remapping: dict[NodeID, NodeID]) -> AnyNode:
    # Note: it's a bit awkward that we mutate the provided nodes, but this is necessary to ensure that
    # calls to `.as_node` reference the correct node_ids when relying on compatibility with the v1 API.
    # We only mutate placeholder IDs so I _think_ this should generally be okay. I guess we can
    # rework it more carefully if it causes issues in the future..
    if isinstance(node, Step):
        node.id = node_id_remapping.get(node.id, node.id)
    elif isinstance(node, Join):
        node.id = JoinID(node_id_remapping.get(node.id, node.id))
    elif isinstance(node, Fork):
        node.id = ForkID(node_id_remapping.get(node.id, node.id))
        if node.downstream_join_id is not None:
            node.downstream_join_id = JoinID(node_id_remapping.get(node.downstream_join_id, node.downstream_join_id))
    elif isinstance(node, Decision):
        node.id = node_id_remapping.get(node.id, node.id)
        node.branches = [
            replace(branch, path=_update_path_with_id_remapping(branch.path, node_id_remapping))
            for branch in node.branches
        ]
    return node


def _update_path_with_id_remapping(path: Path, node_id_remapping: dict[NodeID, NodeID]) -> Path:
    # Note: we have already deepcopied the node provided to this function so it should be okay to make mutations,
    # this could change if we change the code surrounding the code paths leading to this function call though.
    for item in path.items:
        if isinstance(item, MapMarker):
            downstream_join_id = item.downstream_join_id
            if downstream_join_id is not None:
                item.downstream_join_id = JoinID(node_id_remapping.get(downstream_join_id, downstream_join_id))
            item.fork_id = ForkID(node_id_remapping.get(item.fork_id, item.fork_id))
        elif isinstance(item, BroadcastMarker):
            item.fork_id = ForkID(node_id_remapping.get(item.fork_id, item.fork_id))
            item.paths = [_update_path_with_id_remapping(p, node_id_remapping) for p in item.paths]
        elif isinstance(item, DestinationMarker):
            item.destination_id = node_id_remapping.get(item.destination_id, item.destination_id)
    return path
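For orientation, here is a minimal usage sketch of the builder API introduced by this file, based only on the signatures visible in the diff (the GraphBuilder constructor, the step decorator, add_edge, start_node/end_node, and build). How the resulting Graph is executed, and what StepContext exposes at runtime, lives in other modules of the beta package and is assumed rather than shown here:

    from pydantic_graph.beta.graph_builder import GraphBuilder
    from pydantic_graph.beta.step import StepContext

    # A builder whose graph takes no input and produces a `str` output.
    gb = GraphBuilder(name='hello_graph', output_type=str)

    @gb.step
    async def greet(ctx: StepContext[None, None, None]) -> str:
        # The step ignores its input; the runtime attributes of StepContext are not shown in this file.
        return 'hello world'

    # Wire start -> greet -> end using the simple-edge helper.
    gb.add_edge(gb.start_node, greet)
    gb.add_edge(greet, gb.end_node)

    # Validates the structure (reachability, dead ends, dominating forks for joins)
    # and returns the executable Graph definition.
    graph = gb.build()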