loopgraph 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,425 @@
1
+ """Graph primitives describing workflow structure."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, cast
7
+
8
+ from .._debug import (
9
+ log_branch,
10
+ log_loop_iteration,
11
+ log_parameter,
12
+ log_variable_change,
13
+ )
14
+ from .types import NodeKind
15
+
16
+
17
@dataclass(frozen=True)
class Node:
    """Immutable graph node definition.

    Fields cannot be reassigned after construction (``frozen=True``).
    ``config`` is excluded from the generated ``__hash__`` because dicts are
    unhashable: previously ``hash(node)`` raised ``TypeError`` even though the
    dataclass is frozen. ``config`` still takes part in ``==`` comparisons, so
    equal nodes keep equal hashes.
    """

    # Node identifier; conventionally equal to its key in ``Graph.nodes``
    # (``Graph.from_dict`` stores nodes under ``node.id``).
    id: str
    # Node category; this module checks SWITCH and AGGREGATE specially.
    kind: NodeKind
    # Handler name; ``Graph.validate`` rejects empty values.
    handler: str
    # Free-form per-node settings (AGGREGATE nodes may set ``"required"``).
    config: Dict[str, object] = field(default_factory=dict, hash=False)
    # NOTE(review): presumably a cap on how often the node may run; it is
    # only stored/serialized in this module, not enforced here.
    max_visits: Optional[int] = None
    # NOTE(review): ordering hint interpreted outside this module.
    priority: int = 0
    # NOTE(review): presumably lets the node run before all upstream nodes
    # finish; not interpreted in this module.
    allow_partial_upstream: bool = False
28
+
29
+
30
@dataclass(frozen=True)
class Edge:
    """Directed edge linking nodes.

    ``metadata`` is excluded from the generated ``__hash__`` because dicts
    are unhashable: previously ``hash(edge)`` raised ``TypeError`` despite
    the dataclass being frozen. ``metadata`` still participates in ``==``.
    """

    # Edge identifier; conventionally equal to its key in ``Graph.edges``.
    id: str
    # Id of the node the edge leaves.
    source: str
    # Id of the node the edge enters.
    target: str
    # Free-form edge data; edges leaving a SWITCH node must carry "route".
    metadata: Dict[str, object] = field(default_factory=dict, hash=False)
38
+
39
+
40
@dataclass
class Graph:
    """Graph container maintaining adjacency helpers.

    ``nodes`` maps node id to :class:`Node`; ``edges`` maps edge id to
    :class:`Edge`. Forward (node -> outgoing edges) and reverse
    (node -> incoming edges) indices are built in ``__post_init__``. If
    ``nodes`` or ``edges`` are mutated afterwards, call ``_rebuild_indices``
    to refresh them.

    >>> nodes = {
    ...     "start": Node(id="start", kind=NodeKind.TASK, handler="start_handler"),
    ...     "end": Node(id="end", kind=NodeKind.TERMINAL, handler="end_handler"),
    ... }
    >>> edges = {
    ...     "edge-1": Edge(id="edge-1", source="start", target="end"),
    ... }
    >>> graph = Graph(nodes=nodes, edges=edges)
    >>> graph.validate()
    >>> [node.id for node in graph.downstream_nodes("start")]
    ['end']
    """

    nodes: Dict[str, Node] = field(default_factory=dict)
    edges: Dict[str, Edge] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Create the adjacency indices from ``nodes`` and ``edges``."""
        func_name = "Graph.__post_init__"
        log_parameter(func_name, nodes=self.nodes, edges=self.edges)
        # Assigned here rather than declared as fields so they stay out of
        # the dataclass-generated __init__, __repr__ and __eq__.
        self._forward_adj: Dict[str, List[Edge]] = {}
        log_variable_change(func_name, "self._forward_adj", self._forward_adj)
        self._reverse_adj: Dict[str, List[Edge]] = {}
        log_variable_change(func_name, "self._reverse_adj", self._reverse_adj)
        self._rebuild_indices()

    def _rebuild_indices(self) -> None:
        """Rebuild adjacency maps when graph structure changes.

        ``_forward_adj[n]`` lists the edges leaving ``n`` and
        ``_reverse_adj[n]`` the edges entering ``n``, both in ``edges``
        iteration order. Every known node gets an entry (possibly empty).
        Endpoints absent from ``nodes`` are indexed too, via ``setdefault``;
        ``validate`` reports such edges separately.
        """
        func_name = "Graph._rebuild_indices"
        log_parameter(func_name)
        forward: Dict[str, List[Edge]] = {node_id: [] for node_id in self.nodes}
        log_variable_change(func_name, "forward", forward)
        reverse: Dict[str, List[Edge]] = {node_id: [] for node_id in self.nodes}
        log_variable_change(func_name, "reverse", reverse)
        for iteration, edge in enumerate(self.edges.values()):
            log_loop_iteration(func_name, "edges", iteration)
            forward.setdefault(edge.source, []).append(edge)
            log_variable_change(
                func_name,
                f"forward[{edge.source}]",
                forward.get(edge.source),
            )
            reverse.setdefault(edge.target, []).append(edge)
            log_variable_change(
                func_name,
                f"reverse[{edge.target}]",
                reverse.get(edge.target),
            )
        # Swap in the new maps only once they are fully built.
        self._forward_adj = forward
        log_variable_change(func_name, "self._forward_adj", self._forward_adj)
        self._reverse_adj = reverse
        log_variable_change(func_name, "self._reverse_adj", self._reverse_adj)

    def validate(self) -> None:
        """Validate edge connectivity, node handlers and structural rules.

        Checks run in this order:

        1. Every edge's source and target must name a node in ``nodes``.
        2. For each node:
           - a SWITCH node must have no self-loop, and every outgoing edge
             must carry ``"route"`` metadata;
           - an AGGREGATE node with a ``required`` config must have a
             positive ``int`` (``bool`` is rejected) no larger than its
             number of incoming edges.
        3. Every node must have a non-empty ``handler``. Missing handlers are
           collected during step 2 but reported only after it.
        4. No node may belong to two different directed cycles of length
           >= 2 (self-loops are not counted as cycles).

        Raises:
            ValueError: for the first violated rule.
        """
        func_name = "Graph.validate"
        log_parameter(func_name)
        missing_nodes: List[str] = []
        log_variable_change(func_name, "missing_nodes", missing_nodes)
        for iteration, edge in enumerate(self.edges.values()):
            log_loop_iteration(func_name, "edges", iteration)
            if edge.source not in self.nodes:
                log_branch(func_name, "missing_source")
                missing_nodes.append(edge.source)
                log_variable_change(func_name, "missing_nodes", missing_nodes)
            else:
                log_branch(func_name, "existing_source")
            if edge.target not in self.nodes:
                log_branch(func_name, "missing_target")
                missing_nodes.append(edge.target)
                log_variable_change(func_name, "missing_nodes", missing_nodes)
            else:
                log_branch(func_name, "existing_target")

        if missing_nodes:
            log_branch(func_name, "missing_nodes_error")
            raise ValueError(f"Edges reference unknown nodes: {missing_nodes}")
        log_branch(func_name, "all_edges_valid")

        invalid_handlers: List[str] = []
        log_variable_change(func_name, "invalid_handlers", invalid_handlers)
        for iteration, node in enumerate(self.nodes.values()):
            log_loop_iteration(func_name, "nodes", iteration)
            if not node.handler:
                log_branch(func_name, "missing_handler")
                invalid_handlers.append(node.id)
                log_variable_change(func_name, "invalid_handlers", invalid_handlers)
            else:
                log_branch(func_name, "handler_present")

            if node.kind is NodeKind.SWITCH:
                log_branch(func_name, "switch_node_check")
                switch_edges = self._forward_adj.get(node.id, [])
                log_variable_change(func_name, "switch_edges", switch_edges)
                has_self_loop = any(
                    edge.source == edge.target == node.id for edge in switch_edges
                )
                log_variable_change(func_name, "has_self_loop", has_self_loop)
                if has_self_loop:
                    log_branch(func_name, "switch_self_loop")
                    raise ValueError(
                        f"Switch node '{node.id}' cannot have a self-loop edge"
                    )
                missing_route = [
                    edge.id for edge in switch_edges if "route" not in edge.metadata
                ]
                log_variable_change(func_name, "missing_route", missing_route)
                if missing_route:
                    log_branch(func_name, "switch_missing_route")
                    raise ValueError(
                        f"Switch node '{node.id}' requires route metadata on edges: {missing_route}"
                    )
            elif node.kind is NodeKind.AGGREGATE:
                log_branch(func_name, "aggregate_node_check")
                required_raw = node.config.get("required")
                log_variable_change(func_name, "required_raw", required_raw)
                if required_raw is not None:
                    # ``bool`` subclasses ``int``. Reject it explicitly so that
                    # ``required: true`` is not silently accepted as 1.
                    if isinstance(required_raw, bool) or not isinstance(
                        required_raw, int
                    ):
                        log_branch(func_name, "aggregate_required_not_int")
                        raise ValueError(
                            f"Aggregate node '{node.id}' requires integer 'required' config"
                        )
                    if required_raw <= 0:
                        log_branch(func_name, "aggregate_required_non_positive")
                        raise ValueError(
                            f"Aggregate node '{node.id}' requires 'required' > 0"
                        )
                    # One entry per incoming edge (see ``upstream_nodes``).
                    upstream_count = len(self.upstream_nodes(node.id))
                    log_variable_change(
                        func_name, "upstream_count", upstream_count
                    )
                    if required_raw > upstream_count:
                        log_branch(func_name, "aggregate_required_exceeds_upstream")
                        raise ValueError(
                            f"Aggregate node '{node.id}' requires {required_raw} upstream nodes but only has {upstream_count}"
                        )

        if invalid_handlers:
            log_branch(func_name, "invalid_handlers_error")
            raise ValueError(f"Nodes missing handlers: {invalid_handlers}")
        log_branch(func_name, "handlers_valid")

        # A node shared by two distinct cycles is rejected; presumably the
        # loop machinery supports only one loop per node.
        cycles = self._find_cycles()
        log_variable_change(func_name, "cycles", cycles)
        cycle_node_sets = [set(cycle) for cycle in cycles]
        log_variable_change(func_name, "cycle_node_sets", cycle_node_sets)
        shared_nodes: Set[str] = set()
        log_variable_change(func_name, "shared_nodes", shared_nodes)
        for left_idx, left in enumerate(cycle_node_sets):
            log_loop_iteration(func_name, "left_cycle", left_idx)
            for right_idx in range(left_idx + 1, len(cycle_node_sets)):
                overlap = left & cycle_node_sets[right_idx]
                log_variable_change(func_name, "overlap", overlap)
                if overlap:
                    log_branch(func_name, "shared_nodes_detected")
                    shared_nodes.update(overlap)
                    log_variable_change(func_name, "shared_nodes", shared_nodes)
        if shared_nodes:
            log_branch(func_name, "shared_node_multi_loop_error")
            shared_list = sorted(shared_nodes)
            log_variable_change(func_name, "shared_list", shared_list)
            raise ValueError(
                f"Graph has multi-loop shared nodes: {shared_list}"
            )
        log_branch(func_name, "cycle_validation_passed")

    @staticmethod
    def _canonical_cycle(cycle: List[str]) -> Tuple[str, ...]:
        """Normalize a cycle so equivalent rotations share one representation.

        Returns the lexicographically smallest rotation of ``cycle`` as a
        tuple. That rotation begins with the smallest node id. ``cycle``
        must be non-empty, because ``min`` of no rotations raises
        ``ValueError``.
        """
        cycle_len = len(cycle)
        rotations = [
            tuple(cycle[index:] + cycle[:index]) for index in range(cycle_len)
        ]
        return min(rotations)

    def _find_cycles(self) -> List[Tuple[str, ...]]:
        """Find directed simple cycles in the graph.

        Returns every simple cycle of length >= 2, each in the canonical
        form produced by ``_canonical_cycle``, sorted. Self-loops are not
        reported.

        A depth-first search starts from each node in sorted order. The
        search from ``start`` only passes through nodes ordered after
        ``start``, so each cycle is discovered only from its smallest node.
        Any remaining duplicates (for example via parallel edges) collapse
        in the ``seen`` set.

        The number of simple cycles, and thus the work, can be exponential
        in graph size. Recursion depth grows with the longest simple path
        explored.
        """
        func_name = "Graph._find_cycles"
        log_parameter(func_name)
        adjacency: Dict[str, List[str]] = {
            node_id: [edge.target for edge in self._forward_adj.get(node_id, [])]
            for node_id in self.nodes
        }
        log_variable_change(func_name, "adjacency", adjacency)
        seen: Set[Tuple[str, ...]] = set()
        log_variable_change(func_name, "seen", seen)

        def dfs(
            start: str,
            current: str,
            path: List[str],
            in_path: Set[str],
        ) -> None:
            for neighbor in adjacency.get(current, []):
                # Closing back to ``start`` after >= 2 nodes is a cycle. A
                # direct self-loop (len(path) == 1) falls through and is
                # skipped below because ``start`` is already in the path.
                if neighbor == start and len(path) > 1:
                    canonical = self._canonical_cycle(path)
                    seen.add(canonical)
                    continue
                # Cycles through a node smaller than ``start`` were already
                # found from that node, so don't explore through it again.
                if neighbor < start or neighbor in in_path:
                    continue
                path.append(neighbor)
                in_path.add(neighbor)
                dfs(start, neighbor, path, in_path)
                in_path.remove(neighbor)
                path.pop()

        for iteration, start in enumerate(sorted(self.nodes)):
            log_loop_iteration(func_name, "cycle_start_nodes", iteration)
            dfs(start, start, [start], {start})

        cycles = sorted(seen)
        log_variable_change(func_name, "cycles", cycles)
        return cycles

    def downstream_nodes(self, node_id: str) -> List[Node]:
        """Return nodes that can be visited after the given node.

        There is one entry per outgoing edge, in edge order. A target reached
        by several edges therefore appears several times.

        Raises:
            KeyError: if ``node_id`` is unknown, or if an outgoing edge
                targets a node missing from ``nodes`` (an unvalidated graph).

        >>> graph = Graph(
        ...     nodes={
        ...         "a": Node(id="a", kind=NodeKind.TASK, handler="A"),
        ...         "b": Node(id="b", kind=NodeKind.TERMINAL, handler="B"),
        ...     },
        ...     edges={"e": Edge(id="e", source="a", target="b")},
        ... )
        >>> [node.id for node in graph.downstream_nodes("a")]
        ['b']
        """
        func_name = "Graph.downstream_nodes"
        log_parameter(func_name, node_id=node_id)
        if node_id not in self.nodes:
            log_branch(func_name, "missing_node")
            raise KeyError(f"Node '{node_id}' not found")
        log_branch(func_name, "node_present")
        edges = self._forward_adj.get(node_id, [])
        log_variable_change(func_name, "edges", edges)
        targets = [self.nodes[edge.target] for edge in edges]
        log_variable_change(func_name, "targets", targets)
        return targets

    def downstream_edges(self, node_id: str) -> List[Edge]:
        """Return edges originating from the given node.

        A new list is returned, so callers may modify it without touching
        the internal index.

        Raises:
            KeyError: if ``node_id`` is unknown.

        >>> graph = Graph(
        ...     nodes={
        ...         "a": Node(id="a", kind=NodeKind.TASK, handler="A"),
        ...         "b": Node(id="b", kind=NodeKind.TASK, handler="B"),
        ...     },
        ...     edges={"e": Edge(id="e", source="a", target="b")},
        ... )
        >>> [edge.id for edge in graph.downstream_edges("a")]
        ['e']
        """

        func_name = "Graph.downstream_edges"
        log_parameter(func_name, node_id=node_id)
        if node_id not in self.nodes:
            log_branch(func_name, "missing_node")
            raise KeyError(f"Node '{node_id}' not found")
        log_branch(func_name, "node_present")
        edges = list(self._forward_adj.get(node_id, []))
        log_variable_change(func_name, "edges", edges)
        return edges

    def upstream_nodes(self, node_id: str) -> List[Node]:
        """Return the source node of every edge entering the given node.

        There is one entry per incoming edge, in edge order. A source
        connected by several edges therefore appears several times.

        Raises:
            KeyError: if ``node_id`` is unknown, or if an incoming edge comes
                from a node missing from ``nodes`` (an unvalidated graph).

        >>> graph = Graph(
        ...     nodes={
        ...         "a": Node(id="a", kind=NodeKind.TASK, handler="A"),
        ...         "b": Node(id="b", kind=NodeKind.TASK, handler="B"),
        ...     },
        ...     edges={"e": Edge(id="e", source="a", target="b")},
        ... )
        >>> [node.id for node in graph.upstream_nodes("b")]
        ['a']
        """
        func_name = "Graph.upstream_nodes"
        log_parameter(func_name, node_id=node_id)
        if node_id not in self.nodes:
            log_branch(func_name, "missing_node")
            raise KeyError(f"Node '{node_id}' not found")
        log_branch(func_name, "node_present")
        edges = self._reverse_adj.get(node_id, [])
        log_variable_change(func_name, "edges", edges)
        sources = [self.nodes[edge.source] for edge in edges]
        log_variable_change(func_name, "sources", sources)
        return sources

    def to_dict(self) -> Dict[str, object]:
        """Serialize the graph to a dictionary.

        Returns ``{"nodes": [...], "edges": [...]}`` in the shape accepted by
        ``from_dict``. Each node's ``config`` and each edge's ``metadata``
        are shallow-copied. Mutating the payload therefore no longer alters
        this graph's (frozen) nodes and edges, mirroring ``from_dict``, which
        also copies them.
        """
        func_name = "Graph.to_dict"
        log_parameter(func_name)
        node_list = [
            {
                "id": node.id,
                "kind": node.kind.value,
                "handler": node.handler,
                "config": dict(node.config),
                "max_visits": node.max_visits,
                "priority": node.priority,
                "allow_partial_upstream": node.allow_partial_upstream,
            }
            for node in self.nodes.values()
        ]
        log_variable_change(func_name, "node_list", node_list)
        edge_list = [
            {
                "id": edge.id,
                "source": edge.source,
                "target": edge.target,
                "metadata": dict(edge.metadata),
            }
            for edge in self.edges.values()
        ]
        log_variable_change(func_name, "edge_list", edge_list)
        payload = {"nodes": node_list, "edges": edge_list}
        log_variable_change(func_name, "payload", payload)
        return payload

    @classmethod
    def from_dict(cls, payload: Mapping[str, Iterable[Mapping[str, object]]]) -> "Graph":
        """Deserialize a graph from a dictionary payload.

        Accepts the shape produced by ``to_dict``. The following defaults and
        coercions apply:

        - Missing ``"nodes"``/``"edges"`` keys mean empty lists.
        - A ``config``/``metadata`` that is not a mapping becomes ``{}``;
          a mapping is shallow-copied.
        - ``priority`` defaults to 0 when absent or ``None``.
        - ``allow_partial_upstream`` is coerced with ``bool()``.
        - ``id``/``handler``/``source``/``target`` are coerced with ``str()``.
        - ``max_visits`` and ``priority`` are passed through unconverted.

        Later entries with a duplicate id replace earlier ones. The result is
        not validated; call ``validate`` separately.

        Raises:
            KeyError: if a node lacks ``id``/``kind``/``handler`` or an edge
                lacks ``id``/``source``/``target``.
            ValueError: presumably, from ``NodeKind(...)`` for an unknown
                kind (standard ``Enum`` behavior).
        """
        func_name = "Graph.from_dict"
        log_parameter(func_name, payload=payload)
        empty_nodes: Iterable[Mapping[str, object]] = []
        node_entries = payload.get("nodes", empty_nodes)
        log_variable_change(func_name, "node_entries", node_entries)
        nodes: Dict[str, Node] = {}
        log_variable_change(func_name, "nodes", nodes)
        for iteration, entry in enumerate(node_entries):
            log_loop_iteration(func_name, "nodes", iteration)
            config_entry = entry.get("config", {})
            config = (
                dict(cast(Mapping[str, Any], config_entry))
                if isinstance(config_entry, Mapping)
                else {}
            )
            log_variable_change(func_name, "config", config)
            max_visits_value = entry.get("max_visits")
            max_visits = cast(Optional[int], max_visits_value)
            log_variable_change(func_name, "max_visits", max_visits)
            priority_value = cast(Optional[int], entry.get("priority"))
            priority = priority_value if priority_value is not None else 0
            log_variable_change(func_name, "priority", priority)
            allow_partial_value = entry.get("allow_partial_upstream", False)
            log_variable_change(func_name, "allow_partial_value", allow_partial_value)
            node = Node(
                id=str(entry["id"]),
                kind=NodeKind(entry["kind"]),
                handler=str(entry["handler"]),
                config=config,
                max_visits=max_visits,
                priority=priority,
                allow_partial_upstream=bool(allow_partial_value),
            )
            nodes[node.id] = node
            log_variable_change(func_name, f"nodes[{node.id}]", nodes[node.id])

        empty_edges: Iterable[Mapping[str, object]] = []
        edge_entries = payload.get("edges", empty_edges)
        log_variable_change(func_name, "edge_entries", edge_entries)
        edges: Dict[str, Edge] = {}
        log_variable_change(func_name, "edges", edges)
        for iteration, entry in enumerate(edge_entries):
            log_loop_iteration(func_name, "edges", iteration)
            metadata_entry = entry.get("metadata", {})
            metadata = (
                dict(cast(Mapping[str, Any], metadata_entry))
                if isinstance(metadata_entry, Mapping)
                else {}
            )
            log_variable_change(func_name, "metadata", metadata)
            edge = Edge(
                id=str(entry["id"]),
                source=str(entry["source"]),
                target=str(entry["target"]),
                metadata=metadata,
            )
            edges[edge.id] = edge
            log_variable_change(func_name, f"edges[{edge.id}]", edges[edge.id])

        graph = cls(nodes=nodes, edges=edges)
        log_variable_change(func_name, "graph", graph)
        return graph