pydantic-graph 0.2.2__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,76 @@
+ """Type definitions for identifiers used throughout the graph execution system.
+
+ This module defines NewType wrappers and aliases for various ID types used in graph execution,
+ providing type safety and clarity when working with different kinds of identifiers.
+ """
+
+ from __future__ import annotations
+
+ import re
+ import uuid
+ from dataclasses import dataclass
+ from typing import NewType
+
+ NodeID = NewType('NodeID', str)
+ """Unique identifier for a node in the graph."""
+
+ NodeRunID = NewType('NodeRunID', str)
+ """Unique identifier for a specific execution instance of a node."""
+
+ # The following aliases are just included for clarity; making them NewTypes is a hassle
+ JoinID = NodeID
+ """Alias for NodeID when referring to join nodes."""
+
+ ForkID = NodeID
+ """Alias for NodeID when referring to fork nodes."""
+
+ TaskID = NewType('TaskID', str)
+ """Unique identifier for a task within the graph execution."""
+
+
+ @dataclass(frozen=True)
+ class ForkStackItem:
+     """Represents a single fork point in the execution stack.
+
+     When a node creates multiple parallel execution paths (forks), each fork is tracked
+     using a ForkStackItem. This allows the system to maintain the execution hierarchy
+     and coordinate parallel branches of execution.
+     """
+
+     fork_id: ForkID
+     """The ID of the node that created this fork."""
+     node_run_id: NodeRunID
+     """The ID associated with the specific run of the node that created this fork."""
+     thread_index: int
+     """The index of the execution "thread" created during the node run that created this fork.
+
+     This is largely intended for observability/debugging; it may eventually be used to ensure idempotency."""
+
+
+ ForkStack = tuple[ForkStackItem, ...]
+ """A stack of fork items representing the full hierarchy of parallel execution branches.
+
+ The fork stack tracks the complete path through nested parallel executions,
+ allowing the system to coordinate and join parallel branches correctly.
+ """
+
+
+ def generate_placeholder_node_id(label: str) -> str:
+     """Generate a placeholder node ID, to be replaced during graph building."""
+     return f'{_NODE_ID_PLACEHOLDER_PREFIX}:{label}:{uuid.uuid4()}'
+
+
+ def replace_placeholder_id(node_id: NodeID) -> str:
+     """Extract the label from a placeholder node ID, to be used when replacing it during graph building."""
+     return re.sub(rf'{_NODE_ID_PLACEHOLDER_PREFIX}:([^:]+):.*', r'\1', node_id)
+
+
+ _NODE_ID_PLACEHOLDER_PREFIX = '__placeholder__'
+ """
+ When Node IDs are required but not specified when building a graph, we generate placeholder values
+ using this prefix followed by a random string.
+
+ During graph building, we replace these with simpler and deterministically-selected values.
+ This ensures that the node IDs are stable when rebuilding the graph, and makes the generated mermaid diagrams etc.
+ easier to read.
+ """
@@ -0,0 +1,249 @@
+ """Join operations and reducers for graph execution.
+
+ This module provides the core components for joining parallel execution paths
+ in a graph, including various reducer types that aggregate data from multiple
+ sources into a single output.
+ """
+
+ from __future__ import annotations
+
+ import inspect
+ from abc import abstractmethod
+ from collections.abc import Callable, Iterable, Mapping
+ from dataclasses import dataclass
+ from typing import Any, Generic, Literal, cast, overload
+
+ from typing_extensions import Protocol, Self, TypeAliasType, TypeVar
+
+ from pydantic_graph import BaseNode, End, GraphRunContext
+ from pydantic_graph.beta.id_types import ForkID, ForkStack, JoinID
+
+ StateT = TypeVar('StateT', infer_variance=True)
+ DepsT = TypeVar('DepsT', infer_variance=True)
+ InputT = TypeVar('InputT', infer_variance=True)
+ OutputT = TypeVar('OutputT', infer_variance=True)
+ T = TypeVar('T', infer_variance=True)
+ K = TypeVar('K', infer_variance=True)
+ V = TypeVar('V', infer_variance=True)
+
+
+ # TODO(P1): I guess we should make this class private, etc.
+ @dataclass
+ class JoinState:
+     """The state of a join during graph execution associated with a particular fork run."""
+
+     current: Any
+     downstream_fork_stack: ForkStack
+     cancelled_sibling_tasks: bool = False
+
+
+ @dataclass(init=False)
+ class ReducerContext(Generic[StateT, DepsT]):
+     """Context information passed to reducer functions during graph execution.
+
+     The reducer context provides access to the current graph state and dependencies.
+
+     Type Parameters:
+         StateT: The type of the graph state
+         DepsT: The type of the dependencies
+     """
+
+     _state: StateT
+     """The current graph state."""
+     _deps: DepsT
+     """The dependencies of the current graph run."""
+     _join_state: JoinState
+     """The JoinState for this reducer context."""
+
+     def __init__(self, *, state: StateT, deps: DepsT, join_state: JoinState):
+         self._state = state
+         self._deps = deps
+         self._join_state = join_state
+
+     @property
+     def state(self) -> StateT:
+         """The state of the graph run."""
+         return self._state
+
+     @property
+     def deps(self) -> DepsT:
+         """The deps for the graph run."""
+         return self._deps
+
+     def cancel_sibling_tasks(self):
+         """Cancel all sibling tasks created from the same fork.
+
+         You can call this if you want your join to have early-stopping behavior.
+         """
+         self._join_state.cancelled_sibling_tasks = True
+
+
+ PlainReducerFunction = TypeAliasType(
+     'PlainReducerFunction',
+     Callable[[OutputT, InputT], OutputT],
+     type_params=(InputT, OutputT),
+ )
+ ContextReducerFunction = TypeAliasType(
+     'ContextReducerFunction',
+     Callable[[ReducerContext[StateT, DepsT], OutputT, InputT], OutputT],
+     type_params=(StateT, DepsT, InputT, OutputT),
+ )
+ ReducerFunction = TypeAliasType(
+     'ReducerFunction',
+     ContextReducerFunction[StateT, DepsT, InputT, OutputT] | PlainReducerFunction[InputT, OutputT],
+     type_params=(StateT, DepsT, InputT, OutputT),
+ )
+ """
+ A function used for reducing inputs to a join node.
+ """
+
+
+ def reduce_null(current: None, inputs: Any) -> None:
+     """A reducer that discards all input data and returns None."""
+     return None
+
+
+ def reduce_list_append(current: list[T], inputs: T) -> list[T]:
+     """A reducer that appends to a list."""
+     current.append(inputs)
+     return current
+
+
+ def reduce_list_extend(current: list[T], inputs: Iterable[T]) -> list[T]:
+     """A reducer that extends a list."""
+     current.extend(inputs)
+     return current
+
+
+ def reduce_dict_update(current: dict[K, V], inputs: Mapping[K, V]) -> dict[K, V]:
+     """A reducer that updates a dict."""
+     current.update(inputs)
+     return current
+
+
+ class SupportsSum(Protocol):
+     """A protocol for a type that supports adding to itself."""
+
+     @abstractmethod
+     def __add__(self, other: Self, /) -> Self:
+         pass
+
+
+ NumericT = TypeVar('NumericT', bound=SupportsSum, infer_variance=True)
+
+
+ def reduce_sum(current: NumericT, inputs: NumericT) -> NumericT:
+     """A reducer that sums numbers."""
+     return current + inputs
+
+
+ @dataclass
+ class ReduceFirstValue(Generic[T]):
+     """A reducer that returns the first value it encounters, and cancels all other tasks."""
+
+     def __call__(self, ctx: ReducerContext[object, object], current: T, inputs: T) -> T:
+         """The reducer function."""
+         ctx.cancel_sibling_tasks()
+         return inputs
+
+
+ @dataclass(init=False)
+ class Join(Generic[StateT, DepsT, InputT, OutputT]):
+     """A join operation that synchronizes and aggregates parallel execution paths.
+
+     A join defines how to combine outputs from multiple parallel execution paths
+     using a [`ReducerFunction`][pydantic_graph.beta.join.ReducerFunction]. It specifies which fork
+     it joins (if any) and manages the initialization of reducers.
+
+     Type Parameters:
+         StateT: The type of the graph state
+         DepsT: The type of the dependencies
+         InputT: The type of input data to join
+         OutputT: The type of the final joined output
+     """
+
+     id: JoinID
+     _reducer: ReducerFunction[StateT, DepsT, InputT, OutputT]
+     _initial_factory: Callable[[], OutputT]
+     parent_fork_id: ForkID | None
+     preferred_parent_fork: Literal['closest', 'farthest']
+
+     def __init__(
+         self,
+         *,
+         id: JoinID,
+         reducer: ReducerFunction[StateT, DepsT, InputT, OutputT],
+         initial_factory: Callable[[], OutputT],
+         parent_fork_id: ForkID | None = None,
+         preferred_parent_fork: Literal['farthest', 'closest'] = 'farthest',
+     ):
+         self.id = id
+         self._reducer = reducer
+         self._initial_factory = initial_factory
+         self.parent_fork_id = parent_fork_id
+         self.preferred_parent_fork = preferred_parent_fork
+
+     @property
+     def reducer(self):
+         return self._reducer
+
+     @property
+     def initial_factory(self):
+         return self._initial_factory
+
+     def reduce(self, ctx: ReducerContext[StateT, DepsT], current: OutputT, inputs: InputT) -> OutputT:
+         n_parameters = len(inspect.signature(self.reducer).parameters)
+         if n_parameters == 2:
+             return cast(PlainReducerFunction[InputT, OutputT], self.reducer)(current, inputs)
+         else:
+             return cast(ContextReducerFunction[StateT, DepsT, InputT, OutputT], self.reducer)(ctx, current, inputs)
+
+     @overload
+     def as_node(self, inputs: None = None) -> JoinNode[StateT, DepsT]: ...
+
+     @overload
+     def as_node(self, inputs: InputT) -> JoinNode[StateT, DepsT]: ...
+
+     def as_node(self, inputs: InputT | None = None) -> JoinNode[StateT, DepsT]:
+         """Create a join node with bound inputs.
+
+         Args:
+             inputs: The input data to bind to this join, or None
+
+         Returns:
+             A [`JoinNode`][pydantic_graph.beta.join.JoinNode] with this join and the bound inputs
+         """
+         return JoinNode(self, inputs)
+
+
+ @dataclass
+ class JoinNode(BaseNode[StateT, DepsT, Any]):
+     """A base node that represents a join item with bound inputs.
+
+     JoinNode bridges between the v1 and v2 graph execution systems by wrapping
+     a [`Join`][pydantic_graph.beta.join.Join] with bound inputs in a BaseNode interface.
+     It is not meant to be run directly but rather used to indicate transitions
+     to v2-style steps.
+     """
+
+     join: Join[StateT, DepsT, Any, Any]
+     """The join to execute."""
+
+     inputs: Any
+     """The inputs bound to this join."""
+
+     async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[Any]:
+         """Attempt to run the join node.
+
+         Args:
+             ctx: The graph execution context
+
+         Raises:
+             NotImplementedError: Always raised, as JoinNode is not meant to be run directly
+         """
+         raise NotImplementedError(
+             '`JoinNode` is not meant to be run directly, it is meant to be used in `BaseNode` subclasses to indicate a transition to v2-style steps.'
+         )
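
To make the reducer machinery above concrete, here is a hedged sketch of constructing joins directly with the API from this hunk: `reduce_list_append` is a plain two-argument reducer, while `ReduceFirstValue` is a context-aware reducer that cancels sibling branches for early stopping. (How a `Join` is attached to a graph is outside this hunk.)

from pydantic_graph.beta.id_types import JoinID
from pydantic_graph.beta.join import Join, ReduceFirstValue, reduce_list_append

# Gather every parallel branch's output into a list.
collect_all = Join(id=JoinID('collect'), reducer=reduce_list_append, initial_factory=list)

# Keep the first value to arrive and cancel the remaining sibling tasks.
first_wins = Join(
    id=JoinID('first'),
    reducer=ReduceFirstValue[str | None](),
    initial_factory=lambda: None,
)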
@@ -0,0 +1,208 @@
+ from __future__ import annotations
+
+ from collections import defaultdict
+ from dataclasses import dataclass
+ from typing import Literal
+
+ from typing_extensions import assert_never
+
+ from pydantic_graph.beta.decision import Decision
+ from pydantic_graph.beta.id_types import NodeID
+ from pydantic_graph.beta.join import Join
+ from pydantic_graph.beta.node import EndNode, Fork, StartNode
+ from pydantic_graph.beta.node_types import AnyNode
+ from pydantic_graph.beta.paths import BroadcastMarker, DestinationMarker, LabelMarker, MapMarker, Path
+ from pydantic_graph.beta.step import Step
+
+ DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32'
+ """The default CSS to use for highlighting nodes."""
+
+
+ StateDiagramDirection = Literal['TB', 'LR', 'RL', 'BT']
+ """Used to specify the direction of the state diagram generated by mermaid.
+
+ - `'TB'`: Top to bottom, this is the default for mermaid charts.
+ - `'LR'`: Left to right
+ - `'RL'`: Right to left
+ - `'BT'`: Bottom to top
+ """
+
+ NodeKind = Literal['broadcast', 'map', 'join', 'start', 'end', 'step', 'decision']
+
+
+ @dataclass
+ class MermaidNode:
+     """A mermaid node."""
+
+     id: str
+     kind: NodeKind
+     label: str | None
+     note: str | None
+
+
+ @dataclass
+ class MermaidEdge:
+     """A mermaid edge."""
+
+     start_id: str
+     end_id: str
+     label: str | None
+
+
+ def build_mermaid_graph(  # noqa: C901
+     graph_nodes: dict[NodeID, AnyNode], graph_edges_by_source: dict[NodeID, list[Path]]
+ ) -> MermaidGraph:
+     """Build a mermaid graph."""
+     nodes: list[MermaidNode] = []
+     edges_by_source: dict[str, list[MermaidEdge]] = defaultdict(list)
+
+     def _collect_edges(path: Path, last_source_id: NodeID) -> None:
+         working_label: str | None = None
+         for item in path.items:
+             assert not isinstance(item, MapMarker | BroadcastMarker), 'These should be removed during Graph building'
+             if isinstance(item, LabelMarker):
+                 working_label = item.label
+             elif isinstance(item, DestinationMarker):
+                 edges_by_source[last_source_id].append(MermaidEdge(last_source_id, item.destination_id, working_label))
+
+     for node_id, node in graph_nodes.items():
+         kind: NodeKind
+         label: str | None = None
+         note: str | None = None
+         if isinstance(node, StartNode):
+             kind = 'start'
+         elif isinstance(node, EndNode):
+             kind = 'end'
+         elif isinstance(node, Step):
+             kind = 'step'
+             label = node.label
+         elif isinstance(node, Join):
+             kind = 'join'
+         elif isinstance(node, Fork):
+             kind = 'map' if node.is_map else 'broadcast'
+         elif isinstance(node, Decision):
+             kind = 'decision'
+             note = node.note
+         else:
+             assert_never(node)
+
+         source_node = MermaidNode(id=node_id, kind=kind, label=label, note=note)
+         nodes.append(source_node)
+
+     for k, v in graph_edges_by_source.items():
+         for path in v:
+             _collect_edges(path, k)
+
+     for node in graph_nodes.values():
+         if isinstance(node, Decision):
+             for branch in node.branches:
+                 _collect_edges(branch.path, node.id)
+
+     # Add edges in the same order that we added nodes
+     edges: list[MermaidEdge] = sum([edges_by_source.get(node.id, []) for node in nodes], list[MermaidEdge]())
+     return MermaidGraph(nodes, edges)
+
+
+ @dataclass
+ class MermaidGraph:
+     """A mermaid graph."""
+
+     nodes: list[MermaidNode]
+     edges: list[MermaidEdge]
+
+     title: str | None = None
+     direction: StateDiagramDirection | None = None
+
+     def render(
+         self,
+         direction: StateDiagramDirection | None = None,
+         title: str | None = None,
+         edge_labels: bool = True,
+     ):
+         lines: list[str] = []
+         if title:
+             lines = ['---', f'title: {title}', '---']
+         lines.append('stateDiagram-v2')
+         if direction is not None:
+             lines.append(f' direction {direction}')
+
+         nodes, edges = _topological_sort(self.nodes, self.edges)
+         for node in nodes:
+             # List all nodes in order they were created
+             node_lines: list[str] = []
+             if node.kind == 'start' or node.kind == 'end':
+                 pass  # Start and end nodes use special [*] syntax in edges
+             elif node.kind == 'step':
+                 line = f' {node.id}'
+                 if node.label:
+                     line += f': {node.label}'
+                 node_lines.append(line)
+             elif node.kind == 'join':
+                 node_lines = [f' state {node.id} <<join>>']
+             elif node.kind == 'broadcast' or node.kind == 'map':
+                 node_lines = [f' state {node.id} <<fork>>']
+             elif node.kind == 'decision':
+                 node_lines = [f' state {node.id} <<choice>>']
+                 if node.note:
+                     node_lines.append(f' note right of {node.id}\n {node.note}\n end note')
+             else:  # pragma: no cover
+                 assert_never(node.kind)
+             lines.extend(node_lines)
+
+         lines.append('')
+
+         for edge in edges:
+             # Use special [*] syntax for start/end nodes
+             render_start_id = '[*]' if edge.start_id == StartNode.id else edge.start_id
+             render_end_id = '[*]' if edge.end_id == EndNode.id else edge.end_id
+             edge_line = f' {render_start_id} --> {render_end_id}'
+             if edge.label and edge_labels:
+                 edge_line += f': {edge.label}'
+             lines.append(edge_line)
+
+         return '\n'.join(lines)
+
+
+ def _topological_sort(
+     nodes: list[MermaidNode], edges: list[MermaidEdge]
+ ) -> tuple[list[MermaidNode], list[MermaidEdge]]:
+     """Sort nodes and edges in a logical topological order.
+
+     Uses BFS from the start node to assign depths, then sorts:
+     - Nodes by their distance from start
+     - Edges by the distance of their source and target nodes
+     """
+     # Build adjacency list for BFS
+     adjacency: dict[str, list[str]] = defaultdict(list)
+     for edge in edges:
+         adjacency[edge.start_id].append(edge.end_id)
+
+     # BFS to assign depth to each node (distance from start)
+     depths: dict[str, int] = {}
+     queue: list[tuple[str, int]] = [(StartNode.id, 0)]
+     depths[StartNode.id] = 0
+
+     while queue:
+         node_id, depth = queue.pop(0)
+         for next_id in adjacency[node_id]:
+             if next_id not in depths:  # pragma: no branch
+                 depths[next_id] = depth + 1
+                 queue.append((next_id, depth + 1))
+
+     # Sort nodes by depth (distance from start), then by id for stability
+     # Nodes not reachable from start get infinity depth (sorted to end)
+     sorted_nodes = sorted(nodes, key=lambda n: (depths.get(n.id, float('inf')), n.id))
+
+     # Sort edges by source depth, then target depth
+     # This ensures edges closer to start come first, edges closer to end come last
+     sorted_edges = sorted(
+         edges,
+         key=lambda e: (
+             depths.get(e.start_id, float('inf')),
+             depths.get(e.end_id, float('inf')),
+             e.start_id,
+             e.end_id,
+         ),
+     )
+
+     return sorted_nodes, sorted_edges
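
For intuition, a `MermaidGraph` can also be built and rendered by hand: `render` emits `stateDiagram-v2` text, start/end nodes become the `[*]` marker, and nodes are ordered by BFS depth from `__start__`. A sketch, assuming this file is importable as `pydantic_graph.beta.mermaid` (the hunk headers do not show file names):

from pydantic_graph.beta.mermaid import MermaidEdge, MermaidGraph, MermaidNode  # assumed module path

graph = MermaidGraph(
    nodes=[
        MermaidNode(id='__start__', kind='start', label=None, note=None),
        MermaidNode(id='greet', kind='step', label='Greet the user', note=None),
        MermaidNode(id='__end__', kind='end', label=None, note=None),
    ],
    edges=[
        MermaidEdge('__start__', 'greet', None),
        MermaidEdge('greet', '__end__', 'done'),
    ],
)
print(graph.render(direction='LR'))
# Emits a stateDiagram-v2 definition with `direction LR`, a labelled line for
# `greet`, and `[*] --> greet` / `greet --> [*]: done` edges.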
@@ -0,0 +1,95 @@
+ """Core node types for graph construction and execution.
+
+ This module defines the fundamental node types used to build execution graphs,
+ including start/end nodes and fork nodes for parallel execution.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Generic
+
+ from typing_extensions import TypeVar
+
+ from pydantic_graph.beta.id_types import ForkID, JoinID, NodeID
+
+ StateT = TypeVar('StateT', infer_variance=True)
+ """Type variable for graph state."""
+
+ OutputT = TypeVar('OutputT', infer_variance=True)
+ """Type variable for node output data."""
+
+ InputT = TypeVar('InputT', infer_variance=True)
+ """Type variable for node input data."""
+
+
+ class StartNode(Generic[OutputT]):
+     """Entry point node for graph execution.
+
+     The StartNode represents the beginning of a graph execution flow.
+     """
+
+     id = NodeID('__start__')
+     """Fixed identifier for the start node."""
+
+
+ class EndNode(Generic[InputT]):
+     """Terminal node representing the completion of graph execution.
+
+     The EndNode marks the successful completion of a graph execution flow
+     and can collect the final output data.
+     """
+
+     id = NodeID('__end__')
+     """Fixed identifier for the end node."""
+
+     def _force_variance(self, inputs: InputT) -> None:  # pragma: no cover
+         """Force type variance for proper generic typing.
+
+         This method exists solely for type checking purposes and should never be called.
+
+         Args:
+             inputs: Input data of type InputT.
+
+         Raises:
+             RuntimeError: Always, as this method should never be executed.
+         """
+         raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
+
+
+ @dataclass
+ class Fork(Generic[InputT, OutputT]):
+     """Fork node that creates parallel execution branches.
+
+     A Fork node splits the execution flow into multiple parallel branches,
+     enabling concurrent execution of downstream nodes. It can either map
+     a sequence across multiple branches or duplicate data to each branch.
+     """
+
+     id: ForkID
+     """Unique identifier for this fork node."""
+
+     is_map: bool
+     """Determines fork behavior.
+
+     If True, InputT must be Sequence[OutputT] and each element is sent to a separate branch.
+     If False, InputT must be OutputT and the same data is sent to all branches.
+     """
+     downstream_join_id: JoinID | None
+     """Optional identifier of a downstream join node that should be jumped to if mapping an empty iterable."""
+
+     def _force_variance(self, inputs: InputT) -> OutputT:  # pragma: no cover
+         """Force type variance for proper generic typing.
+
+         This method exists solely for type checking purposes and should never be called.
+
+         Args:
+             inputs: Input data to be forked.
+
+         Returns:
+             Output data type (never actually returned).
+
+         Raises:
+             RuntimeError: Always, as this method should never be executed.
+         """
+         raise RuntimeError('This method should never be called, it is just defined for typing purposes.')
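
The `is_map` flag is the entire difference between the two fork behaviors documented above: a map fork fans a sequence out element by element, while a broadcast fork duplicates a single value to every branch. A minimal construction sketch (the IDs are hypothetical; wiring the forks into a graph is not shown):

from pydantic_graph.beta.id_types import ForkID
from pydantic_graph.beta.node import Fork

# Map fork: expects a Sequence[int] input and sends each int to its own branch.
fan_out = Fork[list[int], int](id=ForkID('fan_out'), is_map=True, downstream_join_id=None)

# Broadcast fork: sends the same str input to every branch.
broadcast = Fork[str, str](id=ForkID('notify'), is_map=False, downstream_join_id=None)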