pydantic-graph 0.7.5__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1053 @@
1
+ """Graph builder for constructing executable graph definitions.
2
+
3
+ This module provides the GraphBuilder class and related utilities for
4
+ constructing typed, executable graph definitions with steps, joins,
5
+ decisions, and edge routing.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import inspect
11
+ from collections import Counter, defaultdict
12
+ from collections.abc import AsyncIterable, Callable, Iterable
13
+ from dataclasses import dataclass, replace
14
+ from types import NoneType
15
+ from typing import Any, Generic, Literal, cast, get_origin, get_type_hints, overload
16
+
17
+ from typing_extensions import Never, TypeAliasType, TypeVar
18
+
19
+ from pydantic_graph import _utils, exceptions
20
+ from pydantic_graph._utils import UNSET, Unset
21
+ from pydantic_graph.beta.decision import Decision, DecisionBranch, DecisionBranchBuilder
22
+ from pydantic_graph.beta.graph import Graph
23
+ from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID, generate_placeholder_node_id, replace_placeholder_id
24
+ from pydantic_graph.beta.join import Join, JoinNode, ReducerFunction
25
+ from pydantic_graph.beta.mermaid import build_mermaid_graph
26
+ from pydantic_graph.beta.node import (
27
+ EndNode,
28
+ Fork,
29
+ StartNode,
30
+ )
31
+ from pydantic_graph.beta.node_types import (
32
+ AnyDestinationNode,
33
+ AnyNode,
34
+ DestinationNode,
35
+ SourceNode,
36
+ )
37
+ from pydantic_graph.beta.parent_forks import ParentFork, ParentForkFinder
38
+ from pydantic_graph.beta.paths import (
39
+ BroadcastMarker,
40
+ DestinationMarker,
41
+ EdgePath,
42
+ EdgePathBuilder,
43
+ MapMarker,
44
+ Path,
45
+ PathBuilder,
46
+ )
47
+ from pydantic_graph.beta.step import NodeStep, Step, StepContext, StepFunction, StepNode, StreamFunction
48
+ from pydantic_graph.beta.util import TypeOrTypeExpression, get_callable_name, unpack_type_expression
49
+ from pydantic_graph.exceptions import GraphBuildingError, GraphValidationError
50
+ from pydantic_graph.nodes import BaseNode, End
51
+
52
+ StateT = TypeVar('StateT', infer_variance=True)
53
+ DepsT = TypeVar('DepsT', infer_variance=True)
54
+ InputT = TypeVar('InputT', infer_variance=True)
55
+ OutputT = TypeVar('OutputT', infer_variance=True)
56
+ SourceT = TypeVar('SourceT', infer_variance=True)
57
+ SourceNodeT = TypeVar('SourceNodeT', bound=BaseNode[Any, Any, Any], infer_variance=True)
58
+ SourceOutputT = TypeVar('SourceOutputT', infer_variance=True)
59
+ GraphInputT = TypeVar('GraphInputT', infer_variance=True)
60
+ GraphOutputT = TypeVar('GraphOutputT', infer_variance=True)
61
+ T = TypeVar('T', infer_variance=True)
62
+
63
+
64
@dataclass(init=False)
class GraphBuilder(Generic[StateT, DepsT, GraphInputT, GraphOutputT]):
    """A builder for constructing executable graph definitions.

    GraphBuilder provides a fluent interface for defining nodes, edges, and
    routing in a graph workflow. It supports typed state, dependencies, and
    input/output validation.

    Nodes are registered lazily: creating a step, join or decision does not add
    it to the builder; it is inserted the first time an edge referencing it is
    passed to `add` (directly or via `add_edge` / `add_mapping_edge`).

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        GraphInputT: The type of the graph input data
        GraphOutputT: The type of the graph output data
    """

    name: str | None
    """Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method."""

    state_type: TypeOrTypeExpression[StateT]
    """The type of the graph state."""

    deps_type: TypeOrTypeExpression[DepsT]
    """The type of the dependencies."""

    input_type: TypeOrTypeExpression[GraphInputT]
    """The type of the graph input data."""

    output_type: TypeOrTypeExpression[GraphOutputT]
    """The type of the graph output data."""

    auto_instrument: bool
    """Whether to automatically create instrumentation spans."""

    _nodes: dict[NodeID, AnyNode]
    """Internal storage for nodes in the graph."""

    _edges_by_source: dict[NodeID, list[Path]]
    """Internal storage for edges by source node."""

    _decision_index: int
    """Counter for generating unique decision node IDs."""

    Source = TypeAliasType('Source', SourceNode[StateT, DepsT, OutputT], type_params=(OutputT,))
    Destination = TypeAliasType('Destination', DestinationNode[StateT, DepsT, InputT], type_params=(InputT,))

    def __init__(
        self,
        *,
        name: str | None = None,
        state_type: TypeOrTypeExpression[StateT] = NoneType,
        deps_type: TypeOrTypeExpression[DepsT] = NoneType,
        input_type: TypeOrTypeExpression[GraphInputT] = NoneType,
        output_type: TypeOrTypeExpression[GraphOutputT] = NoneType,
        auto_instrument: bool = True,
    ):
        """Initialize a graph builder.

        Args:
            name: Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method.
            state_type: The type of the graph state
            deps_type: The type of the dependencies
            input_type: The type of the graph input data
            output_type: The type of the graph output data
            auto_instrument: Whether to automatically create instrumentation spans
        """
        self.name = name

        self.state_type = state_type
        self.deps_type = deps_type
        self.input_type = input_type
        self.output_type = output_type

        self.auto_instrument = auto_instrument

        self._nodes = {}
        # A defaultdict lets `add` append paths without first checking whether the source is known.
        self._edges_by_source = defaultdict(list)
        self._decision_index = 1

        self._start_node = StartNode[GraphInputT]()
        self._end_node = EndNode[GraphOutputT]()

    # Node building
    @property
    def start_node(self) -> StartNode[GraphInputT]:
        """Get the start node for the graph.

        Returns:
            The start node that receives the initial graph input
        """
        return self._start_node

    @property
    def end_node(self) -> EndNode[GraphOutputT]:
        """Get the end node for the graph.

        Returns:
            The end node that produces the final graph output
        """
        return self._end_node

    @overload
    def step(
        self,
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]: ...
    @overload
    def step(
        self,
        call: StepFunction[StateT, DepsT, InputT, OutputT],
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> Step[StateT, DepsT, InputT, OutputT]: ...
    def step(
        self,
        call: StepFunction[StateT, DepsT, InputT, OutputT] | None = None,
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> (
        Step[StateT, DepsT, InputT, OutputT]
        | Callable[[StepFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, OutputT]]
    ):
        """Create a step from a step function.

        This method can be used as a decorator or called directly to create
        a step node from an async function.

        Args:
            call: The step function to wrap. When omitted, a decorator is returned instead.
            node_id: Optional ID for the node. Falls back to the callable's name when not given (or empty).
            label: Optional human-readable label

        Returns:
            Either a Step instance or a decorator function
        """
        if call is None:
            # Used as `@builder.step(...)`: defer creation until the function is supplied.
            def decorator(
                func: StepFunction[StateT, DepsT, InputT, OutputT],
            ) -> Step[StateT, DepsT, InputT, OutputT]:
                return self.step(call=func, node_id=node_id, label=label)

            return decorator

        node_id = node_id or get_callable_name(call)

        return Step[StateT, DepsT, InputT, OutputT](id=NodeID(node_id), call=call, label=label)

    @overload
    def stream(
        self,
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> Callable[
        [StreamFunction[StateT, DepsT, InputT, OutputT]], Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
    ]: ...
    @overload
    def stream(
        self,
        call: StreamFunction[StateT, DepsT, InputT, OutputT],
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]: ...
    @overload
    def stream(
        self,
        call: StreamFunction[StateT, DepsT, InputT, OutputT] | None = None,
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> (
        Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
        | Callable[
            [StreamFunction[StateT, DepsT, InputT, OutputT]],
            Step[StateT, DepsT, InputT, AsyncIterable[OutputT]],
        ]
    ): ...
    def stream(
        self,
        call: StreamFunction[StateT, DepsT, InputT, OutputT] | None = None,
        *,
        node_id: str | None = None,
        label: str | None = None,
    ) -> (
        Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]
        | Callable[
            [StreamFunction[StateT, DepsT, InputT, OutputT]],
            Step[StateT, DepsT, InputT, AsyncIterable[OutputT]],
        ]
    ):
        """Create a step from an async iterator (which functions like a "stream").

        This method can be used as a decorator or called directly. The stream
        function is wrapped in a coroutine that returns the stream object itself,
        so the resulting step's output is the async iterable rather than its items.

        Args:
            call: The stream function to wrap. When omitted, a decorator is returned instead.
            node_id: Optional ID for the node
            label: Optional human-readable label

        Returns:
            Either a Step instance or a decorator function
        """
        if call is None:
            # Used as `@builder.stream(...)`: defer creation until the function is supplied.
            def decorator(
                func: StreamFunction[StateT, DepsT, InputT, OutputT],
            ) -> Step[StateT, DepsT, InputT, AsyncIterable[OutputT]]:
                return self.stream(call=func, node_id=node_id, label=label)

            return decorator

        # We need to wrap the call so that we can call `await` even though the result is an async iterator
        async def wrapper(ctx: StepContext[StateT, DepsT, InputT]):
            return call(ctx)

        # NOTE: when `node_id` is not given, `step` derives the ID from `wrapper`, not from `call`.
        return self.step(call=wrapper, node_id=node_id, label=label)

    @overload
    def join(
        self,
        reducer: ReducerFunction[StateT, DepsT, InputT, OutputT],
        *,
        initial: OutputT,
        node_id: str | None = None,
        parent_fork_id: str | None = None,
        preferred_parent_fork: Literal['farthest', 'closest'] = 'farthest',
    ) -> Join[StateT, DepsT, InputT, OutputT]: ...
    @overload
    def join(
        self,
        reducer: ReducerFunction[StateT, DepsT, InputT, OutputT],
        *,
        initial_factory: Callable[[], OutputT],
        node_id: str | None = None,
        parent_fork_id: str | None = None,
        preferred_parent_fork: Literal['farthest', 'closest'] = 'farthest',
    ) -> Join[StateT, DepsT, InputT, OutputT]: ...

    def join(
        self,
        reducer: ReducerFunction[StateT, DepsT, InputT, OutputT],
        *,
        initial: OutputT | Unset = UNSET,
        initial_factory: Callable[[], OutputT] | Unset = UNSET,
        node_id: str | None = None,
        parent_fork_id: str | None = None,
        preferred_parent_fork: Literal['farthest', 'closest'] = 'farthest',
    ) -> Join[StateT, DepsT, InputT, OutputT]:
        """Create a join node from a reducer function.

        The overloads expect exactly one of `initial` or `initial_factory`. At
        runtime, `initial_factory` takes precedence when both are given; when
        neither is given the generated factory returns the `UNSET` sentinel.

        Args:
            reducer: The reducer function stored on the join
            initial: Initial value; wrapped in a zero-argument factory that returns this same object on every call
            initial_factory: Zero-argument callable producing the initial value
            node_id: Optional ID for the node. When omitted, a placeholder ID based on the reducer's name is generated.
            parent_fork_id: Optional ID of the fork to record as this join's parent fork
            preferred_parent_fork: Passed through to the join unchanged ('farthest' or 'closest')

        Returns:
            A new Join node
        """
        if initial_factory is UNSET:
            initial_factory = lambda: initial  # pyright: ignore[reportAssignmentType]  # noqa: E731

        join_id = JoinID(NodeID(node_id or generate_placeholder_node_id(get_callable_name(reducer))))
        return Join[StateT, DepsT, InputT, OutputT](
            id=join_id,
            reducer=reducer,
            initial_factory=cast(Callable[[], OutputT], initial_factory),
            parent_fork_id=ForkID(parent_fork_id) if parent_fork_id is not None else None,
            preferred_parent_fork=preferred_parent_fork,
        )

    # Edge building
    def add(self, *edges: EdgePath[StateT, DepsT]) -> None:  # noqa: C901
        """Add one or more edge paths to the graph.

        This method processes edge paths and automatically creates any necessary
        fork nodes for broadcasts and maps. For every plain step (not a
        `NodeStep`) that appears as a destination, edges inferred from the step
        function's return type hint are added recursively.

        Args:
            *edges: The edge paths to add to the graph

        Raises:
            GraphBuildingError: If two different nodes share the same ID
        """

        def _handle_path(p: Path):
            """Process a path and create necessary fork nodes.

            Args:
                p: The path to process
            """
            for item in p.items:
                if isinstance(item, BroadcastMarker):
                    new_node = Fork[Any, Any](id=item.fork_id, is_map=False, downstream_join_id=None)
                    self._insert_node(new_node)
                    for path in item.paths:
                        _handle_path(Path(items=[*path.items]))
                elif isinstance(item, MapMarker):
                    new_node = Fork[Any, Any](id=item.fork_id, is_map=True, downstream_join_id=item.downstream_join_id)
                    self._insert_node(new_node)
                elif isinstance(item, DestinationMarker):
                    pass

        def _handle_destination_node(d: AnyDestinationNode):
            """Register a destination node and, for decisions, everything reachable through its branches."""
            if id(d) in destination_ids:
                return  # prevent infinite recursion if there is a cycle of decisions

            destination_ids.add(id(d))
            destinations.append(d)
            self._insert_node(d)
            if isinstance(d, Decision):
                for branch in d.branches:
                    _handle_path(branch.path)
                    for d2 in branch.destinations:
                        _handle_destination_node(d2)

        # Destinations are tracked by object identity, so each node object is handled once per call.
        destination_ids = set[int]()
        destinations: list[AnyDestinationNode] = []
        for edge in edges:
            for source_node in edge.sources:
                self._insert_node(source_node)
                self._edges_by_source[source_node.id].append(edge.path)
            for destination_node in edge.destinations:
                _handle_destination_node(destination_node)
            _handle_path(edge.path)

        # Automatically create edges from step function return hints including `BaseNode`s
        plain_steps = [d for d in destinations if isinstance(d, Step) and not isinstance(d, NodeStep)]
        if not plain_steps:
            return

        # The namespace lookup is loop-invariant, so it is resolved once. It must be evaluated
        # in this method's own frame because the helper is handed the current frame.
        parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
        for plain_step in plain_steps:
            type_hints = get_type_hints(plain_step.call, localns=parent_namespace, include_extras=True)
            if 'return' not in type_hints:
                continue  # no return annotation, nothing to infer
            inferred_edge = self._edge_from_return_hint(plain_step, type_hints['return'])
            if inferred_edge is not None:
                self.add(inferred_edge)

    def add_edge(self, source: Source[T], destination: Destination[T], *, label: str | None = None) -> None:
        """Add a simple edge between two nodes.

        Args:
            source: The source node
            destination: The destination node
            label: Optional label for the edge
        """
        builder = self.edge_from(source)
        if label is not None:
            builder = builder.label(label)
        self.add(builder.to(destination))

    def add_mapping_edge(
        self,
        source: Source[Iterable[T]],
        map_to: Destination[T],
        *,
        pre_map_label: str | None = None,
        post_map_label: str | None = None,
        fork_id: ForkID | None = None,
        downstream_join_id: JoinID | None = None,
    ) -> None:
        """Add an edge that maps iterable data across parallel paths.

        Args:
            source: The source node that produces iterable data
            map_to: The destination node that receives individual items
            pre_map_label: Optional label before the map operation
            post_map_label: Optional label after the map operation
            fork_id: Optional ID for the fork node produced for this map operation
            downstream_join_id: Optional ID of a join node that will always be downstream of this map.
                Specifying this ensures correct handling if you try to map an empty iterable.
        """
        builder = self.edge_from(source)
        if pre_map_label is not None:
            builder = builder.label(pre_map_label)
        builder = builder.map(fork_id=fork_id, downstream_join_id=downstream_join_id)
        if post_map_label is not None:
            builder = builder.label(post_map_label)
        self.add(builder.to(map_to))

    # TODO(DavidM): Support adding subgraphs; I think this behaves like a step with the same inputs/outputs but gets rendered as a subgraph in mermaid

    def edge_from(self, *sources: Source[SourceOutputT]) -> EdgePathBuilder[StateT, DepsT, SourceOutputT]:
        """Create an edge path builder starting from the given source nodes.

        Args:
            *sources: The source nodes to start the edge path from

        Returns:
            An EdgePathBuilder for constructing the complete edge path
        """
        return EdgePathBuilder[StateT, DepsT, SourceOutputT](
            sources=sources, path_builder=PathBuilder(working_items=[])
        )

    def decision(self, *, note: str | None = None, node_id: str | None = None) -> Decision[StateT, DepsT, Never]:
        """Create a new decision node.

        Args:
            note: Optional note to describe the decision logic
            node_id: Optional ID for the node produced for this decision logic.
                When omitted, a placeholder ID is generated.

        Returns:
            A new Decision node with no branches
        """
        return Decision(id=NodeID(node_id or generate_placeholder_node_id('decision')), branches=[], note=note)

    def match(
        self,
        source: TypeOrTypeExpression[SourceT],
        *,
        matches: Callable[[Any], bool] | None = None,
    ) -> DecisionBranchBuilder[StateT, DepsT, SourceT, SourceT, Never]:
        """Create a decision branch matcher.

        Args:
            source: The type or type expression to match against
            matches: Optional custom matching function

        Returns:
            A DecisionBranchBuilder for constructing the branch
        """
        # Note, the following node_id really is just a placeholder and shouldn't end up in the final graph
        # This is why we don't expose a way for end users to override the value used here.
        node_id = NodeID(generate_placeholder_node_id('match_decision'))
        decision = Decision[StateT, DepsT, Never](id=node_id, branches=[], note=None)
        new_path_builder = PathBuilder[StateT, DepsT, SourceT](working_items=[])
        return DecisionBranchBuilder(decision=decision, source=source, matches=matches, path_builder=new_path_builder)

    def match_node(
        self,
        source: type[SourceNodeT],
        *,
        matches: Callable[[Any], bool] | None = None,
    ) -> DecisionBranch[SourceNodeT]:
        """Create a decision branch for BaseNode subclasses.

        This is similar to match() but specifically designed for matching
        against BaseNode types from the v1 system.

        Args:
            source: The BaseNode subclass to match against
            matches: Optional custom matching function

        Returns:
            A DecisionBranch for the BaseNode type
        """
        node = NodeStep(source)
        path = Path(items=[DestinationMarker(node.id)])
        return DecisionBranch(source=source, matches=matches, path=path, destinations=[node])

    def node(
        self,
        node_type: type[BaseNode[StateT, DepsT, GraphOutputT]],
    ) -> EdgePath[StateT, DepsT]:
        """Create an edge path from a BaseNode class.

        This method integrates v1-style BaseNode classes into the v2 graph
        system by analyzing their type hints and creating appropriate edges.

        Args:
            node_type: The BaseNode subclass to integrate

        Returns:
            An EdgePath representing the node and its connections

        Raises:
            GraphSetupError: If the node's `run` method has no return type hint, or no
                edges could be inferred from that hint
        """
        parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
        type_hints = get_type_hints(node_type.run, localns=parent_namespace, include_extras=True)
        try:
            return_hint = type_hints['return']
        except KeyError as e:  # pragma: no cover
            raise exceptions.GraphSetupError(
                f'Node {node_type} is missing a return type hint on its `run` method'
            ) from e

        node = NodeStep(node_type)

        edge = self._edge_from_return_hint(node, return_hint)
        # `None` means the hint was present but contained something that is not a node type.
        if edge is None:  # pragma: no cover
            raise exceptions.GraphSetupError(f'Node {node_type} is missing a return type hint on its `run` method')

        return edge

    # Helpers
    def _insert_node(self, node: AnyNode) -> None:
        """Insert a node into the graph, checking for ID conflicts.

        Re-inserting the same object is a no-op, as is inserting a second
        `NodeStep` wrapping the same node type as the one already registered.

        Args:
            node: The node to insert

        Raises:
            GraphBuildingError: If a different node with the same ID already exists
        """
        existing = self._nodes.get(node.id)
        if existing is None:
            self._nodes[node.id] = node
            return
        if existing is node:
            return

        # Distinct `NodeStep` objects for the same node type are interchangeable; keep the first one.
        same_node_type = (
            isinstance(existing, NodeStep) and isinstance(node, NodeStep) and existing.node_type is node.node_type
        )
        if not same_node_type:
            raise GraphBuildingError(
                f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}'
            )

    def _edge_from_return_hint(
        self, node: SourceNode[StateT, DepsT, Any], return_hint: TypeOrTypeExpression[Any]
    ) -> EdgePath[StateT, DepsT] | None:
        """Create edges from a return type hint.

        This method analyzes return type hints from step functions or node methods
        to automatically create appropriate edges in the graph.

        Args:
            node: The source node
            return_hint: The return type hint to analyze

        Returns:
            An EdgePath if every member of the return hint maps to a node, None otherwise

        Raises:
            GraphSetupError: If the return type hint is invalid or incomplete
        """
        destinations: list[AnyDestinationNode] = []
        union_args = _utils.get_union_args(return_hint)
        for return_type in union_args:
            return_type, annotations = _utils.unpack_annotated(return_type)
            # For a parametrized generic use the origin; otherwise the type itself.
            return_type_origin = get_origin(return_type) or return_type
            if return_type_origin is End:
                destinations.append(self.end_node)
            elif return_type_origin is BaseNode:
                raise exceptions.GraphSetupError(  # pragma: no cover
                    f'Node {node} return type hint includes a plain `BaseNode`. '
                    'Edge inference requires each possible returned `BaseNode` subclass to be listed explicitly.'
                )
            elif return_type_origin is StepNode:
                # The concrete step is carried as `Annotated[...]` metadata; take the first `Step` found.
                step = cast(
                    Step[StateT, DepsT, Any, Any] | None,
                    next((a for a in annotations if isinstance(a, Step)), None),  # pyright: ignore[reportUnknownArgumentType]
                )
                if step is None:
                    raise exceptions.GraphSetupError(  # pragma: no cover
                        f'Node {node} return type hint includes a `StepNode` without a `Step` annotation. '
                        'When returning `my_step.as_node()`, use `Annotated[StepNode[StateT, DepsT], my_step]` as the return type hint.'
                    )
                destinations.append(step)
            elif return_type_origin is JoinNode:
                # Same convention as for `StepNode`, with a `Join` in the metadata.
                join = cast(
                    Join[StateT, DepsT, Any, Any] | None,
                    next((a for a in annotations if isinstance(a, Join)), None),  # pyright: ignore[reportUnknownArgumentType]
                )
                if join is None:
                    raise exceptions.GraphSetupError(  # pragma: no cover
                        f'Node {node} return type hint includes a `JoinNode` without a `Join` annotation. '
                        'When returning `my_join.as_node()`, use `Annotated[JoinNode[StateT, DepsT], my_join]` as the return type hint.'
                    )
                destinations.append(join)
            elif inspect.isclass(return_type_origin) and issubclass(return_type_origin, BaseNode):
                destinations.append(NodeStep(return_type))

        if len(destinations) < len(union_args):
            # Only build edges if all the return types are nodes
            return None

        edge = self.edge_from(node)
        if len(destinations) == 1:
            return edge.to(destinations[0])

        decision = self.decision()
        for destination in destinations:
            # We don't actually use this decision mechanism, but we need to build the edges for parent-fork finding
            decision = decision.branch(self.match(NoneType).to(destination))
        return edge.to(decision)

    # Graph building
    def build(self, validate_graph_structure: bool = True) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
        """Build the final executable graph from the accumulated nodes and edges.

        This method performs validation, normalization, and analysis of the graph
        structure to create a complete, executable graph instance.

        Args:
            validate_graph_structure: whether to perform validation of the graph structure
                See the docstring of _validate_graph_structure below for more details.

        Returns:
            A complete Graph instance ready for execution

        Raises:
            GraphValidationError: If `validate_graph_structure` is true and the structure check fails.
                The other build passes may raise their own errors for an invalid graph
                (e.g., join without parent fork).
        """
        nodes = self._nodes
        edges_by_source = self._edges_by_source

        # Each pass consumes the output of the previous one, so the order matters.
        nodes, edges_by_source = _replace_placeholder_node_ids(nodes, edges_by_source)
        nodes, edges_by_source = _flatten_paths(nodes, edges_by_source)
        nodes, edges_by_source = _normalize_forks(nodes, edges_by_source)
        if validate_graph_structure:
            _validate_graph_structure(nodes, edges_by_source)
        parent_forks = _collect_dominating_forks(nodes, edges_by_source)
        intermediate_join_nodes = _compute_intermediate_join_nodes(nodes, parent_forks)

        return Graph[StateT, DepsT, GraphInputT, GraphOutputT](
            name=self.name,
            state_type=unpack_type_expression(self.state_type),
            deps_type=unpack_type_expression(self.deps_type),
            input_type=unpack_type_expression(self.input_type),
            output_type=unpack_type_expression(self.output_type),
            nodes=nodes,
            edges_by_source=edges_by_source,
            parent_forks=parent_forks,
            intermediate_join_nodes=intermediate_join_nodes,
            auto_instrument=self.auto_instrument,
        )
675
+
676
+
677
def _validate_graph_structure(  # noqa: C901
    nodes: dict[NodeID, AnyNode],
    edges_by_source: dict[NodeID, list[Path]],
) -> None:
    """Validate the graph structure for common issues.

    This function raises an error if any of the following criteria are not met:
    1. There are edges from the start node
    2. There are edges to the end node
    3. No non-End node is a dead end (no outgoing edges)
    4. The end node is reachable from the start node
    5. All nodes are reachable from the start node

    Note 1: Under some circumstances it may be reasonable to build a graph that violates one or more of
    the above conditions. We may eventually add support for more granular control over validation,
    but today, if you want to build a graph that violates any of these assumptions you need to pass
    `validate_graph_structure=False` to the call to `GraphBuilder.build`.

    Note 2: Some of the earlier items in the above list are redundant with the later items.
    I've included the earlier items in the list as a reminder to ourselves if/when we add more granular validation
    because you might want to check the earlier items but not the later items, as described in Note 1.

    Args:
        nodes: The nodes in the graph
        edges_by_source: The edges by source node

    Raises:
        GraphValidationError: If any of the aforementioned structural issues are found.
    """
    how_to_suppress = ' If this is intentional, you can suppress this error by passing `validate_graph_structure=False` to the call to `GraphBuilder.build`.'

    def _destination_ids(path: Path) -> list[NodeID]:
        """Return the IDs targeted by the destination markers in `path`, in path order."""
        return [item.destination_id for item in path.items if isinstance(item, DestinationMarker)]

    def _outgoing_paths(node_id: NodeID) -> list[Path]:
        """Return the edges registered for `node_id` plus, for a decision node, its branch paths."""
        outgoing = list(edges_by_source.get(node_id, []))
        node = nodes.get(node_id)
        if isinstance(node, Decision):
            outgoing.extend(branch.path for branch in node.branches)
        return outgoing

    # Extract all destination IDs from edges and decision branches.
    # Every registered edge is scanned, including edges whose source is not in `nodes`.
    all_destinations: set[NodeID] = set()
    for paths in edges_by_source.values():
        for path in paths:
            all_destinations.update(_destination_ids(path))
    for node in nodes.values():
        if isinstance(node, Decision):
            for branch in node.branches:
                all_destinations.update(_destination_ids(branch.path))

    # Check 1: Check if there are edges from the start node
    if not edges_by_source.get(StartNode.id, []):
        raise GraphValidationError('The graph has no edges from the start node.' + how_to_suppress)

    # Check 2: Check if there are edges to the end node
    if EndNode.id not in all_destinations:
        raise GraphValidationError('The graph has no edges to the end node.' + how_to_suppress)

    # Check 3: Find all nodes with no outgoing edges (dead ends); the end node itself is exempt
    dead_end_nodes: list[NodeID] = [
        node_id for node_id, node in nodes.items() if not isinstance(node, EndNode) and not _outgoing_paths(node_id)
    ]
    if dead_end_nodes:
        raise GraphValidationError(f'The following nodes have no outgoing edges: {dead_end_nodes}.' + how_to_suppress)

    # Checks 4 and 5: Ensure all nodes (and in particular, the end node) are reachable from the start node
    reachable: set[NodeID] = {StartNode.id}
    to_visit = [StartNode.id]
    while to_visit:
        current_id = to_visit.pop()
        for path in _outgoing_paths(current_id):
            for destination_id in _destination_ids(path):
                if destination_id not in reachable:
                    reachable.add(destination_id)
                    to_visit.append(destination_id)

    unreachable_nodes = [node_id for node_id in nodes if node_id not in reachable]
    if unreachable_nodes:
        raise GraphValidationError(
            f'The following nodes are not reachable from the start node: {unreachable_nodes}.' + how_to_suppress
        )
785
+
786
+
787
def _flatten_paths(
    nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]]
) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]:
    """Split every path at its fork markers so that each fork becomes an explicit edge source.

    A path is cut at its first `MapMarker` or `BroadcastMarker`: the part before
    the marker is kept and terminated with a `DestinationMarker` pointing at the
    fork, and whatever follows is re-queued as a path leaving that fork. The
    re-queued pieces are split again until no fork markers remain.

    Note: the paths of `Decision` branches are truncated in place on the node
    objects themselves; only the node mapping is copied, not the nodes.

    Args:
        nodes: The nodes in the graph; must already contain every fork referenced by a marker
        edges: The edges by source node

    Returns:
        A tuple of the (shallow-copied) node mapping and the flattened edges by source node
    """
    flattened_nodes = nodes.copy()
    flattened_edges: dict[NodeID, list[Path]] = defaultdict(list)

    def _split_at_first_fork(path: Path) -> tuple[Path, list[tuple[NodeID, Path]]]:
        """Return the path up to its first fork and the `(fork_id, path)` pairs continuing from that fork."""
        for index, item in enumerate(path.items):
            if not isinstance(item, (MapMarker, BroadcastMarker)):
                continue

            assert item.fork_id in nodes, 'This should have been added to the node during GraphBuilder.add'
            head = Path([*path.items[:index], DestinationMarker(item.fork_id)])
            if isinstance(item, MapMarker):
                # A map has a single continuation: the remainder of this path.
                return head, [(item.fork_id, Path(path.items[index + 1 :]))]
            # A broadcast continues along each of its own sub-paths.
            return head, [(item.fork_id, sub_path) for sub_path in item.paths]
        return path, []

    # Worklist of (source node ID, path leaving it) still to be split.
    pending: list[tuple[NodeID, Path]] = []

    for candidate in flattened_nodes.values():
        if isinstance(candidate, Decision):
            for branch in candidate.branches:
                branch.path, tails = _split_at_first_fork(branch.path)
                pending.extend(tails)

    for source_id, source_paths in edges.items():
        pending.extend((source_id, source_path) for source_path in source_paths)

    # Processed last-in-first-out; this determines the order of paths within each source's list.
    while pending:
        source_id, current = pending.pop()
        head, tails = _split_at_first_fork(current)
        flattened_edges[source_id].append(head)
        pending.extend(tails)

    return flattened_nodes, dict(flattened_edges)
827
+
828
+
829
+ def _normalize_forks(
830
+ nodes: dict[NodeID, AnyNode], edges: dict[NodeID, list[Path]]
831
+ ) -> tuple[dict[NodeID, AnyNode], dict[NodeID, list[Path]]]:
832
+ """Normalize the graph structure so only broadcast forks have multiple outgoing edges.
833
+
834
+ This function ensures that any node with multiple outgoing edges is converted
835
+ to use an explicit broadcast fork, simplifying the graph execution model.
836
+
837
+ Args:
838
+ nodes: The nodes in the graph
839
+ edges: The edges by source node
840
+
841
+ Returns:
842
+ A tuple of normalized nodes and edges
843
+ """
844
+ new_nodes = nodes.copy()
845
+ new_edges: dict[NodeID, list[Path]] = {}
846
+
847
+ paths_to_handle: list[Path] = []
848
+
849
+ for source_id, edges_from_source in edges.items():
850
+ paths_to_handle.extend(edges_from_source)
851
+
852
+ node = nodes[source_id]
853
+ if isinstance(node, Fork) and not node.is_map:
854
+ new_edges[source_id] = edges_from_source
855
+ continue # broadcast fork; nothing to do
856
+ if len(edges_from_source) == 1:
857
+ new_edges[source_id] = edges_from_source
858
+ continue
859
+ new_fork = Fork[Any, Any](id=ForkID(NodeID(f'{node.id}_broadcast_fork')), is_map=False, downstream_join_id=None)
860
+ new_nodes[new_fork.id] = new_fork
861
+ new_edges[source_id] = [Path(items=[DestinationMarker(new_fork.id)])]
862
+ new_edges[new_fork.id] = edges_from_source
863
+
864
+ return new_nodes, new_edges
865
+
866
+
867
def _collect_dominating_forks(
    graph_nodes: dict[NodeID, AnyNode], graph_edges_by_source: dict[NodeID, list[Path]]
) -> dict[JoinID, ParentFork[NodeID]]:
    """Find the dominating fork for each join node in the graph.

    This function analyzes the graph structure to find the parent fork that
    dominates each join node, which is necessary for proper synchronization
    during graph execution.

    Args:
        graph_nodes: All nodes in the graph
        graph_edges_by_source: Edges organized by source node

    Returns:
        A mapping from join IDs to their parent fork information

    Raises:
        GraphBuildingError: If any join node lacks a dominating fork
    """
    nodes = set(graph_nodes)
    start_ids: set[NodeID] = {StartNode.id}
    edges: dict[NodeID, list[NodeID]] = defaultdict(list)

    # The start node is always included among the fork candidates.
    fork_ids: set[NodeID] = set(start_ids)
    for source_id in nodes:
        node = graph_nodes.get(source_id)

        if isinstance(node, Fork):
            fork_ids.add(node.id)

        # Decisions keep their outgoing paths on their branches; other nodes use the edge mapping.
        if isinstance(node, Decision):
            outgoing_paths = [branch.path for branch in node.branches]
        else:
            outgoing_paths = graph_edges_by_source.get(source_id, [])

        for path in outgoing_paths:
            # No need to handle MapMarker or BroadcastMarker here as these should have all been removed
            # by the call to `_flatten_paths`. Only the first DestinationMarker of a path is recorded.
            destination = next((item for item in path.items if isinstance(item, DestinationMarker)), None)
            if destination is not None:
                edges[source_id].append(destination.destination_id)

    finder = ParentForkFinder(
        nodes=nodes,
        start_ids=start_ids,
        fork_ids=fork_ids,
        edges=edges,
    )

    dominating_forks: dict[JoinID, ParentFork[NodeID]] = {}
    for join in graph_nodes.values():
        if not isinstance(join, Join):
            continue
        dominating_fork = finder.find_parent_fork(
            join.id, parent_fork_id=join.parent_fork_id, prefer_closest=join.preferred_parent_fork == 'closest'
        )
        if dominating_fork is None:
            # Include a rendering of the graph so the offending join can be located.
            rendered_mermaid_graph = build_mermaid_graph(graph_nodes, graph_edges_by_source).render()
            raise GraphBuildingError(f"""A node in the graph is missing a dominating fork.

For every Join J in the graph, there must be a Fork F between the StartNode and J satisfying:
* Every path from the StartNode to J passes through F
* There are no cycles in the graph including J that don't pass through F.
In this case, F is called a "dominating fork" for J.

This is used to determine when all tasks upstream of this Join are complete and we can proceed with execution.

Mermaid diagram:
{rendered_mermaid_graph}

Join {join.id!r} in this graph has no dominating fork in this graph.""")
        dominating_forks[join.id] = dominating_fork

    return dominating_forks
951
+
952
+
953
+ def _compute_intermediate_join_nodes(
954
+ nodes: dict[NodeID, AnyNode], parent_forks: dict[JoinID, ParentFork[NodeID]]
955
+ ) -> dict[JoinID, set[JoinID]]:
956
+ """Compute which joins have other joins as intermediate nodes.
957
+
958
+ A join J1 is an intermediate node of join J2 if J1 appears in J2's intermediate_nodes
959
+ (as computed relative to J2's parent fork).
960
+
961
+ This information is used to determine:
962
+ 1. Which joins are "final" (have no other joins in their intermediate_nodes)
963
+ 2. When selecting which reducer to proceed with when there are no active tasks
964
+
965
+ Args:
966
+ nodes: All nodes in the graph
967
+ parent_forks: Parent fork information for each join
968
+
969
+ Returns:
970
+ A mapping from each join to the set of joins that are intermediate to it
971
+ """
972
+ intermediate_join_nodes: dict[JoinID, set[JoinID]] = {}
973
+
974
+ for join_id, parent_fork in parent_forks.items():
975
+ intermediate_joins = set[JoinID]()
976
+ for intermediate_node_id in parent_fork.intermediate_nodes:
977
+ # Check if this intermediate node is also a join
978
+ intermediate_node = nodes.get(intermediate_node_id)
979
+ if isinstance(intermediate_node, Join):
980
+ # Add it regardless of whether it has the same parent fork
981
+ intermediate_joins.add(JoinID(intermediate_node_id))
982
+ intermediate_join_nodes[join_id] = intermediate_joins
983
+
984
+ return intermediate_join_nodes
985
+
986
+
987
def _replace_placeholder_node_ids(nodes: dict[NodeID, AnyNode], edges_by_source: dict[NodeID, list[Path]]):
    """Re-key the node and edge mappings using the final IDs chosen for placeholder node IDs.

    The remapping is computed once from `nodes`; IDs that are absent from it are kept
    unchanged. The node objects and paths themselves are passed through the remapping
    helpers so that the IDs they carry are rewritten as well.

    Args:
        nodes: The nodes in the graph
        edges_by_source: The edges by source node

    Returns:
        A tuple of the re-keyed nodes and the re-keyed edges by source node
    """
    id_remapping = _build_placeholder_node_id_remapping(nodes)

    renamed_nodes = {}
    for old_id, node in nodes.items():
        new_id = id_remapping.get(old_id, old_id)
        renamed_nodes[new_id] = _update_node_with_id_remapping(node, id_remapping)

    renamed_edges_by_source = {}
    for old_source_id, paths in edges_by_source.items():
        new_source_id = id_remapping.get(old_source_id, old_source_id)
        renamed_edges_by_source[new_source_id] = [
            _update_path_with_id_remapping(path, id_remapping) for path in paths
        ]

    return renamed_nodes, renamed_edges_by_source
998
+
999
+
1000
+ def _build_placeholder_node_id_remapping(nodes: dict[NodeID, AnyNode]) -> dict[NodeID, NodeID]:
1001
+ """The determinism of the generated remapping here is dependent on the determinism of the ordering of the `nodes` dict.
1002
+
1003
+ Note: If we want to generate more interesting names, we could try to make use of information about the edges
1004
+ into/out of the relevant nodes. I'm not sure if there's a good use case for that though so I didn't bother for now.
1005
+ """
1006
+ counter = Counter[str]()
1007
+ remapping: dict[NodeID, NodeID] = {}
1008
+ for node_id in nodes.keys():
1009
+ replaced_node_id = replace_placeholder_id(node_id)
1010
+ if replaced_node_id == node_id:
1011
+ continue
1012
+ counter[replaced_node_id] = count = counter[replaced_node_id] + 1
1013
+ remapping[node_id] = NodeID(f'{replaced_node_id}_{count}' if count > 1 else replaced_node_id)
1014
+ return remapping
1015
+
1016
+
1017
def _update_node_with_id_remapping(node: AnyNode, node_id_remapping: dict[NodeID, NodeID]) -> AnyNode:
    """Rewrite the IDs stored on `node` according to `node_id_remapping` and return the same node.

    Steps, joins, forks and decisions are handled; any other node is returned untouched.
    IDs missing from the remapping are kept as they are.
    """
    # Note: it's a bit awkward that we mutate the provided nodes, but this is necessary to ensure that
    # calls to `.as_node` reference the correct node_ids when relying on compatibility with the v1 API.
    # We only mutate placeholder IDs so I _think_ this should generally be okay. I guess we can
    # rework it more carefully if it causes issues in the future..

    def _remapped(node_id):
        return node_id_remapping.get(node_id, node_id)

    if isinstance(node, Step):
        node.id = _remapped(node.id)
        return node

    if isinstance(node, Join):
        node.id = JoinID(_remapped(node.id))
        return node

    if isinstance(node, Fork):
        node.id = ForkID(_remapped(node.id))
        join_id = node.downstream_join_id
        if join_id is not None:
            node.downstream_join_id = JoinID(_remapped(join_id))
        return node

    if isinstance(node, Decision):
        node.id = _remapped(node.id)
        # Branches are replaced by copies that carry the remapped path.
        node.branches = [
            replace(branch, path=_update_path_with_id_remapping(branch.path, node_id_remapping))
            for branch in node.branches
        ]

    return node
1037
+
1038
+
1039
+ def _update_path_with_id_remapping(path: Path, node_id_remapping: dict[NodeID, NodeID]) -> Path:
1040
+ # Note: we have already deepcopied the node provided to this function so it should be okay to make mutations,
1041
+ # this could change if we change the code surrounding the code paths leading to this function call though.
1042
+ for item in path.items:
1043
+ if isinstance(item, MapMarker):
1044
+ downstream_join_id = item.downstream_join_id
1045
+ if downstream_join_id is not None:
1046
+ item.downstream_join_id = JoinID(node_id_remapping.get(downstream_join_id, downstream_join_id))
1047
+ item.fork_id = ForkID(node_id_remapping.get(item.fork_id, item.fork_id))
1048
+ elif isinstance(item, BroadcastMarker):
1049
+ item.fork_id = ForkID(node_id_remapping.get(item.fork_id, item.fork_id))
1050
+ item.paths = [_update_path_with_id_remapping(p, node_id_remapping) for p in item.paths]
1051
+ elif isinstance(item, DestinationMarker):
1052
+ item.destination_id = node_id_remapping.get(item.destination_id, item.destination_id)
1053
+ return path