pydantic-graph 0.7.5__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_graph/_utils.py +70 -3
- pydantic_graph/beta/__init__.py +25 -0
- pydantic_graph/beta/decision.py +276 -0
- pydantic_graph/beta/graph.py +978 -0
- pydantic_graph/beta/graph_builder.py +1053 -0
- pydantic_graph/beta/id_types.py +76 -0
- pydantic_graph/beta/join.py +249 -0
- pydantic_graph/beta/mermaid.py +208 -0
- pydantic_graph/beta/node.py +95 -0
- pydantic_graph/beta/node_types.py +90 -0
- pydantic_graph/beta/parent_forks.py +232 -0
- pydantic_graph/beta/paths.py +421 -0
- pydantic_graph/beta/step.py +253 -0
- pydantic_graph/beta/util.py +90 -0
- pydantic_graph/exceptions.py +22 -0
- pydantic_graph/graph.py +21 -26
- pydantic_graph/mermaid.py +2 -2
- pydantic_graph/nodes.py +10 -13
- pydantic_graph/persistence/__init__.py +4 -4
- pydantic_graph/persistence/_utils.py +1 -1
- pydantic_graph/persistence/file.py +12 -13
- pydantic_graph/persistence/in_mem.py +3 -3
- {pydantic_graph-0.7.5.dist-info → pydantic_graph-1.24.0.dist-info}/METADATA +4 -5
- pydantic_graph-1.24.0.dist-info/RECORD +28 -0
- pydantic_graph-0.7.5.dist-info/RECORD +0 -15
- {pydantic_graph-0.7.5.dist-info → pydantic_graph-1.24.0.dist-info}/WHEEL +0 -0
- {pydantic_graph-0.7.5.dist-info → pydantic_graph-1.24.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""Type definitions for identifiers used throughout the graph execution system.
|
|
2
|
+
|
|
3
|
+
This module defines NewType wrappers and aliases for various ID types used in graph execution,
|
|
4
|
+
providing type safety and clarity when working with different kinds of identifiers.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import re
|
|
10
|
+
import uuid
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import NewType
|
|
13
|
+
|
|
14
|
+
NodeID = NewType('NodeID', str)
"""Unique identifier for a node in the graph."""

NodeRunID = NewType('NodeRunID', str)
"""Unique identifier for a specific execution instance of a node."""

# The following aliases are just included for clarity; making them NewTypes is a hassle.
# Because they are plain aliases, a JoinID/ForkID is interchangeable with any NodeID for type checkers.
JoinID = NodeID
"""Alias for NodeID when referring to join nodes."""

ForkID = NodeID
"""Alias for NodeID when referring to fork nodes."""

TaskID = NewType('TaskID', str)
"""Unique identifier for a task within the graph execution."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen=True)
class ForkStackItem:
    """Represents a single fork point in the execution stack.

    When a node creates multiple parallel execution paths (forks), each fork is tracked
    using a ForkStackItem. This allows the system to maintain the execution hierarchy
    and coordinate parallel branches of execution.

    The dataclass is frozen, so instances are immutable and hashable and can be shared
    between fork stacks without defensive copying.
    """

    fork_id: ForkID
    """The ID of the node that created this fork."""
    node_run_id: NodeRunID
    """The ID associated to the specific run of the node that created this fork."""
    thread_index: int
    """The index of the execution "thread" created during the node run that created this fork.

    This is largely intended for observability/debugging; it may eventually be used to ensure idempotency."""
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# A tuple rather than a list so fork stacks are immutable and hashable.
# NOTE(review): the ordering (outermost fork first?) is not visible here — confirm in graph.py.
ForkStack = tuple[ForkStackItem, ...]
"""A stack of fork items representing the full hierarchy of parallel execution branches.

The fork stack tracks the complete path through nested parallel executions,
allowing the system to coordinate and join parallel branches correctly.
"""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def generate_placeholder_node_id(label: str) -> str:
    """Build a temporary node ID of the form ``<prefix>:<label>:<uuid4>``.

    The ID is only a placeholder: it is meant to be replaced during graph building.
    The UUID suffix makes every call return a distinct value, even for equal labels.

    Args:
        label: Human-readable label embedded in the placeholder.

    Returns:
        The placeholder node ID string.
    """
    unique_suffix = uuid.uuid4()
    return f'{_NODE_ID_PLACEHOLDER_PREFIX}:{label}:{unique_suffix}'
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def replace_placeholder_id(node_id: NodeID) -> str:
    """Strip the placeholder wrapping from a node ID, leaving just its label.

    Placeholder IDs are produced by ``generate_placeholder_node_id`` and look like
    ``<prefix>:<label>:<uuid4>``. If ``node_id`` contains such a placeholder, the matched
    text is replaced by ``<label>``; any other ID is returned unchanged.

    Note that the label is captured with ``[^:]+``, so a label containing ``:`` is cut at
    its first ``:``, and an empty label does not match at all (the ID is returned as-is).

    Args:
        node_id: The node ID to normalise.

    Returns:
        The label for placeholder IDs, otherwise ``node_id`` itself.
    """
    # re.escape guards against the prefix ever containing regex metacharacters.
    prefix = re.escape(_NODE_ID_PLACEHOLDER_PREFIX)
    return re.sub(rf'{prefix}:([^:]+):.*', r'\1', node_id)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
_NODE_ID_PLACEHOLDER_PREFIX = '__placeholder__'
"""
When Node IDs are required but not specified when building a graph, we generate placeholder values
of the form `<prefix>:<label>:<uuid4>`, where `<prefix>` is this constant and the random UUID makes
each placeholder unique.

During graph building, we replace these with simpler and deterministically-selected values.
This ensures that the node IDs are stable when rebuilding the graph, and makes the generated mermaid diagrams etc.
easier to read.
"""
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
"""Join operations and reducers for graph execution.
|
|
2
|
+
|
|
3
|
+
This module provides the core components for joining parallel execution paths
|
|
4
|
+
in a graph, including various reducer types that aggregate data from multiple
|
|
5
|
+
sources into a single output.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import inspect
|
|
11
|
+
from abc import abstractmethod
|
|
12
|
+
from collections.abc import Callable, Iterable, Mapping
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from typing import Any, Generic, Literal, cast, overload
|
|
15
|
+
|
|
16
|
+
from typing_extensions import Protocol, Self, TypeAliasType, TypeVar
|
|
17
|
+
|
|
18
|
+
from pydantic_graph import BaseNode, End, GraphRunContext
|
|
19
|
+
from pydantic_graph.beta.id_types import ForkID, ForkStack, JoinID
|
|
20
|
+
|
|
21
|
+
# Type parameters for the generics in this module. `infer_variance=True` (typing_extensions)
# asks type checkers to infer each parameter's variance from how it is used.
StateT = TypeVar('StateT', infer_variance=True)
DepsT = TypeVar('DepsT', infer_variance=True)
InputT = TypeVar('InputT', infer_variance=True)
OutputT = TypeVar('OutputT', infer_variance=True)
# Element / key / value parameters used by the built-in list and dict reducers below.
T = TypeVar('T', infer_variance=True)
K = TypeVar('K', infer_variance=True)
V = TypeVar('V', infer_variance=True)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# TODO(P1): I guess we should make this class private, etc.
@dataclass
class JoinState:
    """The state of a join during graph execution associated to a particular fork run."""

    # Presumably the value accumulated by the join's reducer so far — NOTE(review): confirm in graph.py.
    current: Any
    # Fork stack for work continuing after the join — NOTE(review): inferred from the name; confirm.
    downstream_fork_stack: ForkStack
    # Flipped to True by ReducerContext.cancel_sibling_tasks() to request early stopping.
    cancelled_sibling_tasks: bool = False
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass(init=False)
class ReducerContext(Generic[StateT, DepsT]):
    """Read-only view of a graph run handed to context-aware reducer functions.

    Reducers get the current graph state and dependencies through this object, and may
    request early stopping of sibling branches via `cancel_sibling_tasks`.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
    """

    _state: StateT
    """The current graph state."""
    _deps: DepsT
    """The dependencies of the current graph run."""
    _join_state: JoinState
    """The JoinState for this reducer context."""

    def __init__(self, *, state: StateT, deps: DepsT, join_state: JoinState):
        # Everything is stored privately; reducers only read it through the properties below.
        self._join_state = join_state
        self._deps = deps
        self._state = state

    @property
    def state(self) -> StateT:
        """The graph state of the run this reducer belongs to."""
        return self._state

    @property
    def deps(self) -> DepsT:
        """The dependencies of the run this reducer belongs to."""
        return self._deps

    def cancel_sibling_tasks(self) -> None:
        """Ask the executor to cancel the remaining tasks spawned by the same fork.

        Call this from a reducer to give the join early-stopping behavior; it only sets a
        flag on the join state.
        """
        join_state = self._join_state
        join_state.cancelled_sibling_tasks = True
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Reducer taking (current, inputs) and returning the new accumulated value.
PlainReducerFunction = TypeAliasType(
    'PlainReducerFunction',
    Callable[[OutputT, InputT], OutputT],
    type_params=(InputT, OutputT),
)
# Reducer that additionally receives a ReducerContext as its first argument.
ContextReducerFunction = TypeAliasType(
    'ContextReducerFunction',
    Callable[[ReducerContext[StateT, DepsT], OutputT, InputT], OutputT],
    type_params=(StateT, DepsT, InputT, OutputT),
)
# Either form is accepted; Join.reduce tells them apart at runtime by parameter count.
ReducerFunction = TypeAliasType(
    'ReducerFunction',
    ContextReducerFunction[StateT, DepsT, InputT, OutputT] | PlainReducerFunction[InputT, OutputT],
    type_params=(StateT, DepsT, InputT, OutputT),
)
"""
A function used for reducing inputs to a join node.
"""
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def reduce_null(current: None, inputs: Any) -> None:
    """Discard every input; the accumulated value is always ``None``."""
    del current, inputs  # both arguments are intentionally ignored
    return None
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def reduce_list_append(current: list[T], inputs: T) -> list[T]:
    """Add ``inputs`` as a single new element of ``current``.

    The list is mutated in place and the same list object is returned.
    """
    accumulated = current
    accumulated.append(inputs)
    return accumulated
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def reduce_list_extend(current: list[T], inputs: Iterable[T]) -> list[T]:
    """Add every element of ``inputs`` to the end of ``current``.

    The list is mutated in place and the same list object is returned.
    """
    accumulated = current
    accumulated.extend(inputs)
    return accumulated
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def reduce_dict_update(current: dict[K, V], inputs: Mapping[K, V]) -> dict[K, V]:
    """Merge ``inputs`` into ``current``; later keys overwrite earlier ones.

    The dict is mutated in place and the same dict object is returned.
    """
    accumulated = current
    accumulated.update(inputs)
    return accumulated
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class SupportsSum(Protocol):
    """Structural type for values that can be added to another value of their own type."""

    @abstractmethod
    def __add__(self, other: Self, /) -> Self:
        """Return the combination of ``self`` and ``other``."""
        ...
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# Bounded by SupportsSum so reduce_sum only accepts types that implement `__add__`.
NumericT = TypeVar('NumericT', bound=SupportsSum, infer_variance=True)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def reduce_sum(current: NumericT, inputs: NumericT) -> NumericT:
    """Return the running total ``current + inputs``."""
    total = current + inputs
    return total
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@dataclass
class ReduceFirstValue(Generic[T]):
    """A reducer that returns the first value it encounters, and cancels all other tasks.

    Every call requests cancellation of sibling tasks and returns the incoming value,
    ignoring whatever was accumulated so far.
    NOTE(review): if several siblings finish before cancellation takes effect, the last call
    wins — confirm how graph.py handles that race.
    """

    def __call__(self, ctx: ReducerContext[object, object], current: T, inputs: T) -> T:
        """Stop sibling branches and keep ``inputs`` as the joined value."""
        first_value = inputs  # `current` is deliberately ignored
        ctx.cancel_sibling_tasks()
        return first_value
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@dataclass(init=False)
class Join(Generic[StateT, DepsT, InputT, OutputT]):
    """A join operation that synchronizes and aggregates parallel execution paths.

    A join defines how to combine outputs from multiple parallel execution paths
    using a [`ReducerFunction`][pydantic_graph.beta.join.ReducerFunction]. It specifies which fork
    it joins (if any) and manages the initialization of reducers.

    Type Parameters:
        StateT: The type of the graph state
        DepsT: The type of the dependencies
        InputT: The type of input data to join
        OutputT: The type of the final joined output
    """

    # Identifier of this join node.
    id: JoinID
    # Function folding each input into the accumulated output (plain or context-aware form).
    _reducer: ReducerFunction[StateT, DepsT, InputT, OutputT]
    # Zero-argument factory producing the initial accumulated value.
    _initial_factory: Callable[[], OutputT]
    # Explicit fork this join closes, or None if not specified.
    parent_fork_id: ForkID | None
    # Which candidate parent fork to prefer — NOTE(review): resolution logic lives outside this file.
    preferred_parent_fork: Literal['closest', 'farthest']

    def __init__(
        self,
        *,
        id: JoinID,
        reducer: ReducerFunction[StateT, DepsT, InputT, OutputT],
        initial_factory: Callable[[], OutputT],
        parent_fork_id: ForkID | None = None,
        preferred_parent_fork: Literal['farthest', 'closest'] = 'farthest',
    ):
        self.id = id
        self._reducer = reducer
        self._initial_factory = initial_factory
        self.parent_fork_id = parent_fork_id
        self.preferred_parent_fork = preferred_parent_fork

    @property
    def reducer(self) -> ReducerFunction[StateT, DepsT, InputT, OutputT]:
        """The function used to fold each incoming input into the accumulated output."""
        return self._reducer

    @property
    def initial_factory(self) -> Callable[[], OutputT]:
        """Zero-argument factory producing the initial accumulated value."""
        return self._initial_factory

    def reduce(self, ctx: ReducerContext[StateT, DepsT], current: OutputT, inputs: InputT) -> OutputT:
        """Fold ``inputs`` into ``current`` using this join's reducer.

        The reducer form is chosen by arity: a reducer whose signature has exactly two
        parameters is called as ``reducer(current, inputs)``; any other reducer is treated
        as context-aware and called as ``reducer(ctx, current, inputs)``.

        Args:
            ctx: The reducer context (only passed to context-aware reducers).
            current: The value accumulated so far.
            inputs: The newly arrived input.

        Returns:
            The new accumulated value.
        """
        reducer = self.reducer
        n_parameters = len(inspect.signature(reducer).parameters)
        if n_parameters == 2:
            return cast(PlainReducerFunction[InputT, OutputT], reducer)(current, inputs)
        return cast(ContextReducerFunction[StateT, DepsT, InputT, OutputT], reducer)(ctx, current, inputs)

    @overload
    def as_node(self, inputs: None = None) -> JoinNode[StateT, DepsT]: ...

    @overload
    def as_node(self, inputs: InputT) -> JoinNode[StateT, DepsT]: ...

    def as_node(self, inputs: InputT | None = None) -> JoinNode[StateT, DepsT]:
        """Create a join node with bound inputs.

        Args:
            inputs: The input data to bind to this join, or None

        Returns:
            A [`JoinNode`][pydantic_graph.beta.join.JoinNode] wrapping this join and the bound inputs
        """
        return JoinNode(self, inputs)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@dataclass
class JoinNode(BaseNode[StateT, DepsT, Any]):
    """A base node that represents a join item with bound inputs.

    JoinNode bridges between the v1 and v2 graph execution systems by wrapping
    a [`Join`][pydantic_graph.beta.join.Join] with bound inputs in a BaseNode interface.
    It is not meant to be run directly but rather used to indicate transitions
    to v2-style joins.
    """

    join: Join[StateT, DepsT, Any, Any]
    """The join to transition to."""

    inputs: Any
    """The inputs bound to this join."""

    async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[Any]:
        """Attempt to run the join node.

        Args:
            ctx: The graph execution context

        Returns:
            Never returns normally; see Raises.

        Raises:
            NotImplementedError: Always raised as JoinNode is not meant to be run directly
        """
        raise NotImplementedError(
            '`JoinNode` is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to indicate a transition to v2-style steps.'
        )
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
from typing_extensions import assert_never
|
|
8
|
+
|
|
9
|
+
from pydantic_graph.beta.decision import Decision
|
|
10
|
+
from pydantic_graph.beta.id_types import NodeID
|
|
11
|
+
from pydantic_graph.beta.join import Join
|
|
12
|
+
from pydantic_graph.beta.node import EndNode, Fork, StartNode
|
|
13
|
+
from pydantic_graph.beta.node_types import AnyNode
|
|
14
|
+
from pydantic_graph.beta.paths import BroadcastMarker, DestinationMarker, LabelMarker, MapMarker, Path
|
|
15
|
+
from pydantic_graph.beta.step import Step
|
|
16
|
+
|
|
17
|
+
DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32'
"""The default CSS to use for highlighting nodes."""


StateDiagramDirection = Literal['TB', 'LR', 'RL', 'BT']
"""Used to specify the direction of the state diagram generated by mermaid.

- `'TB'`: Top to bottom, this is the default for mermaid charts.
- `'LR'`: Left to right
- `'RL'`: Right to left
- `'BT'`: Bottom to top
"""

# Kinds of node a MermaidNode can represent. MermaidGraph.render draws 'map' and 'broadcast'
# identically (as <<fork>> states) and emits no declaration for 'start'/'end' (they render as [*]).
NodeKind = Literal['broadcast', 'map', 'join', 'start', 'end', 'step', 'decision']
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class MermaidNode:
    """A mermaid node."""

    # State identifier used in the rendered diagram.
    id: str
    # Which graph construct this node represents; controls how it is rendered.
    kind: NodeKind
    # Optional description appended after the ID (only rendered for 'step' nodes).
    label: str | None
    # Optional note attached to the node (only rendered for 'decision' nodes).
    note: str | None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class MermaidEdge:
    """A mermaid edge."""

    # ID of the source node (the start node's ID is rendered as [*]).
    start_id: str
    # ID of the destination node (the end node's ID is rendered as [*]).
    end_id: str
    # Optional edge label; omitted when rendering with edge_labels=False.
    label: str | None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def build_mermaid_graph(  # noqa: C901
    graph_nodes: dict[NodeID, AnyNode], graph_edges_by_source: dict[NodeID, list[Path]]
) -> MermaidGraph:
    """Build a mermaid graph from a beta graph's nodes and edge paths.

    Args:
        graph_nodes: The graph's nodes keyed by ID; mermaid nodes are created in this
            mapping's iteration order.
        graph_edges_by_source: Outgoing paths keyed by source node ID. Paths of `Decision`
            branches are additionally read from the decision nodes themselves.

    Returns:
        A `MermaidGraph` whose edges are grouped by source node, in the same order as the
        nodes. Edges whose source is not in `graph_nodes` are dropped.
    """
    nodes: list[MermaidNode] = []
    edges_by_source: dict[str, list[MermaidEdge]] = defaultdict(list)

    def _collect_edges(path: Path, last_source_id: NodeID) -> None:
        # Walk the path; a LabelMarker sets the label attached to subsequent destination edges.
        working_label: str | None = None
        for item in path.items:
            assert not isinstance(item, MapMarker | BroadcastMarker), 'These should be removed during Graph building'
            if isinstance(item, LabelMarker):
                working_label = item.label
            elif isinstance(item, DestinationMarker):
                edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.destination_id, working_label))

    for node_id, node in graph_nodes.items():
        kind: NodeKind
        label: str | None = None
        note: str | None = None
        if isinstance(node, StartNode):
            kind = 'start'
        elif isinstance(node, EndNode):
            kind = 'end'
        elif isinstance(node, Step):
            kind = 'step'
            label = node.label
        elif isinstance(node, Join):
            kind = 'join'
        elif isinstance(node, Fork):
            kind = 'map' if node.is_map else 'broadcast'
        elif isinstance(node, Decision):
            kind = 'decision'
            note = node.note
        else:
            assert_never(node)

        source_node = MermaidNode(id=node_id, kind=kind, label=label, note=note)
        nodes.append(source_node)

    for k, v in graph_edges_by_source.items():
        for path in v:
            _collect_edges(path, k)

    for node in graph_nodes.values():
        if isinstance(node, Decision):
            for branch in node.branches:
                _collect_edges(branch.path, node.id)

    # Add edges in the same order that we added nodes. A flattening comprehension avoids the
    # quadratic cost of `sum(list_of_lists, [])`, which copies the accumulator on every addition.
    edges: list[MermaidEdge] = [edge for node in nodes for edge in edges_by_source.get(node.id, [])]
    return MermaidGraph(nodes, edges)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass
class MermaidGraph:
    """A mermaid graph.

    Holds the nodes and edges produced by `build_mermaid_graph` and renders them as mermaid
    `stateDiagram-v2` source.
    """

    nodes: list[MermaidNode]
    edges: list[MermaidEdge]

    # Defaults used by `render` when no explicit title/direction argument is given.
    title: str | None = None
    direction: StateDiagramDirection | None = None

    def render(
        self,
        direction: StateDiagramDirection | None = None,
        title: str | None = None,
        edge_labels: bool = True,
    ) -> str:
        """Render the graph as mermaid `stateDiagram-v2` source.

        Args:
            direction: Diagram direction; when None, falls back to `self.direction`
                (and is omitted entirely if that is also None).
            title: Diagram title; when None, falls back to `self.title`. A falsy title
                emits no front-matter block.
            edge_labels: Whether to include edge labels in the output.

        Returns:
            The diagram source as a single newline-joined string.
        """
        # Previously the `title`/`direction` fields were silently ignored.
        if direction is None:
            direction = self.direction
        if title is None:
            title = self.title

        lines: list[str] = []
        if title:
            lines = ['---', f'title: {title}', '---']
        lines.append('stateDiagram-v2')
        if direction is not None:
            lines.append(f' direction {direction}')

        nodes, edges = _topological_sort(self.nodes, self.edges)
        for node in nodes:
            # Declare nodes in topological order (distance from the start node)
            node_lines: list[str] = []
            if node.kind == 'start' or node.kind == 'end':
                pass  # Start and end nodes use special [*] syntax in edges
            elif node.kind == 'step':
                line = f' {node.id}'
                if node.label:
                    line += f': {node.label}'
                node_lines.append(line)
            elif node.kind == 'join':
                node_lines = [f' state {node.id} <<join>>']
            elif node.kind == 'broadcast' or node.kind == 'map':
                node_lines = [f' state {node.id} <<fork>>']
            elif node.kind == 'decision':
                node_lines = [f' state {node.id} <<choice>>']
                if node.note:
                    node_lines.append(f' note right of {node.id}\n {node.note}\n end note')
            else:  # pragma: no cover
                assert_never(node.kind)
            lines.extend(node_lines)

        lines.append('')

        for edge in edges:
            # Use special [*] syntax for start/end nodes
            render_start_id = '[*]' if edge.start_id == StartNode.id else edge.start_id
            render_end_id = '[*]' if edge.end_id == EndNode.id else edge.end_id
            edge_line = f' {render_start_id} --> {render_end_id}'
            if edge.label and edge_labels:
                edge_line += f': {edge.label}'
            lines.append(edge_line)

        return '\n'.join(lines)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _topological_sort(
    nodes: list[MermaidNode], edges: list[MermaidEdge]
) -> tuple[list[MermaidNode], list[MermaidEdge]]:
    """Sort nodes and edges in a logical topological order.

    Uses BFS from the start node to assign depths, then sorts:
    - Nodes by their distance from start (ties broken by ID)
    - Edges by the distance of their source and target nodes (ties broken by IDs)

    Nodes unreachable from the start node get an infinite depth and sort last.

    Args:
        nodes: The nodes to order.
        edges: The edges to order; they also define the adjacency used for the BFS.

    Returns:
        New sorted lists of nodes and edges (the inputs are not mutated).
    """
    from collections import deque

    # Build adjacency list for BFS
    adjacency: dict[str, list[str]] = defaultdict(list)
    for edge in edges:
        adjacency[edge.start_id].append(edge.end_id)

    # BFS to assign depth to each node (distance from start). A deque gives O(1) pops from
    # the front, whereas list.pop(0) shifts the whole list on every pop.
    depths: dict[str, int] = {StartNode.id: 0}
    queue: deque[tuple[str, int]] = deque([(StartNode.id, 0)])
    while queue:
        node_id, depth = queue.popleft()
        for next_id in adjacency[node_id]:
            if next_id not in depths:  # pragma: no branch
                depths[next_id] = depth + 1
                queue.append((next_id, depth + 1))

    # Sort nodes by depth (distance from start), then by id for stability
    # Nodes not reachable from start get infinity depth (sorted to end)
    sorted_nodes = sorted(nodes, key=lambda n: (depths.get(n.id, float('inf')), n.id))

    # Sort edges by source depth, then target depth
    # This ensures edges closer to start come first, edges closer to end come last
    sorted_edges = sorted(
        edges,
        key=lambda e: (
            depths.get(e.start_id, float('inf')),
            depths.get(e.end_id, float('inf')),
            e.start_id,
            e.end_id,
        ),
    )

    return sorted_nodes, sorted_edges
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Core node types for graph construction and execution.
|
|
2
|
+
|
|
3
|
+
This module defines the fundamental node types used to build execution graphs,
|
|
4
|
+
including start/end nodes and fork nodes for parallel execution.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import Generic
|
|
11
|
+
|
|
12
|
+
from typing_extensions import TypeVar
|
|
13
|
+
|
|
14
|
+
from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
|
|
15
|
+
|
|
16
|
+
# Type parameters for the node classes below; `infer_variance=True` (typing_extensions)
# lets type checkers infer variance, helped by the `_force_variance` stubs.
StateT = TypeVar('StateT', infer_variance=True)
"""Type variable for graph state."""

OutputT = TypeVar('OutputT', infer_variance=True)
"""Type variable for node output data."""

InputT = TypeVar('InputT', infer_variance=True)
"""Type variable for node input data."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class StartNode(Generic[OutputT]):
    """Entry point node for graph execution.

    The StartNode represents the beginning of a graph execution flow. It holds no
    per-instance data; its identity is the fixed class attribute `id`.
    NOTE(review): `OutputT` presumably types the data handed to the first downstream
    node — confirm against the graph builder.
    """

    id = NodeID('__start__')
    """Fixed identifier for the start node."""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class EndNode(Generic[InputT]):
    """Terminal node representing the completion of graph execution.

    The EndNode marks the successful completion of a graph execution flow
    and can collect the final output data. Like StartNode, its identity is the
    fixed class attribute `id`.
    """

    id = NodeID('__end__')
    """Fixed identifier for the end node."""

    def _force_variance(self, inputs: InputT) -> None:  # pragma: no cover
        """Force type variance for proper generic typing.

        This method exists solely for type checking purposes and should never be called.
        Taking `InputT` as a parameter lets variance inference treat it as contravariant.

        Args:
            inputs: Input data of type InputT.

        Raises:
            RuntimeError: Always, as this method should never be executed.
        """
        raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
class Fork(Generic[InputT, OutputT]):
    """Fork node that creates parallel execution branches.

    A Fork node splits the execution flow into multiple parallel branches,
    enabling concurrent execution of downstream nodes. It can either map
    a sequence across multiple branches or duplicate data to each branch,
    as selected by `is_map`.
    """

    id: ForkID
    """Unique identifier for this fork node."""

    is_map: bool
    """Determines fork behavior.

    If True, InputT must be Sequence[OutputT] and each element is sent to a separate branch.
    If False, InputT must be OutputT and the same data is sent to all branches.
    """
    downstream_join_id: JoinID | None
    """Optional identifier of a downstream join node that should be jumped to if mapping an empty iterable."""

    def _force_variance(self, inputs: InputT) -> OutputT:  # pragma: no cover
        """Force type variance for proper generic typing.

        This method exists solely for type checking purposes and should never be called.
        Using `InputT` as a parameter and `OutputT` as the return type lets variance
        inference treat them as contravariant and covariant respectively.

        Args:
            inputs: Input data to be forked.

        Returns:
            Output data type (never actually returned).

        Raises:
            RuntimeError: Always, as this method should never be executed.
        """
        raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
|