loopgraph 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- loopgraph/__init__.py +38 -0
- loopgraph/_debug.py +45 -0
- loopgraph/bus/__init__.py +5 -0
- loopgraph/bus/eventbus.py +186 -0
- loopgraph/concurrency/__init__.py +5 -0
- loopgraph/concurrency/policies.py +181 -0
- loopgraph/core/__init__.py +18 -0
- loopgraph/core/graph.py +425 -0
- loopgraph/core/state.py +443 -0
- loopgraph/core/types.py +72 -0
- loopgraph/diagnostics/__init__.py +5 -0
- loopgraph/diagnostics/inspect.py +70 -0
- loopgraph/persistence/__init__.py +6 -0
- loopgraph/persistence/event_log.py +63 -0
- loopgraph/persistence/snapshot.py +52 -0
- loopgraph/py.typed +0 -0
- loopgraph/registry/__init__.py +1 -0
- loopgraph/registry/function_registry.py +117 -0
- loopgraph/scheduler/__init__.py +5 -0
- loopgraph/scheduler/scheduler.py +569 -0
- loopgraph-0.2.0.dist-info/METADATA +165 -0
- loopgraph-0.2.0.dist-info/RECORD +24 -0
- loopgraph-0.2.0.dist-info/WHEEL +5 -0
- loopgraph-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,569 @@
|
|
|
1
|
+
"""Core scheduler for executing workflow graphs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
6
|
+
|
|
7
|
+
from .._debug import (
|
|
8
|
+
log_branch,
|
|
9
|
+
log_loop_iteration,
|
|
10
|
+
log_parameter,
|
|
11
|
+
log_variable_change,
|
|
12
|
+
)
|
|
13
|
+
from ..bus.eventbus import Event, EventBus
|
|
14
|
+
from ..concurrency import ConcurrencyManager, SemaphorePolicy
|
|
15
|
+
from ..core.graph import Edge, Graph, Node
|
|
16
|
+
from ..core.state import ExecutionState
|
|
17
|
+
from ..core.types import EventType, NodeKind, NodeStatus, VisitOutcome
|
|
18
|
+
from ..persistence import EventLog, SnapshotStore
|
|
19
|
+
from ..registry.function_registry import FunctionRegistry
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Scheduler:
|
|
23
|
+
"""Execute graphs by dispatching node handlers.
|
|
24
|
+
|
|
25
|
+
>>> import asyncio
|
|
26
|
+
>>> from loopgraph.core.graph import Edge, Node, NodeKind
|
|
27
|
+
>>> registry = FunctionRegistry()
|
|
28
|
+
>>> registry.register("start", lambda payload: payload + ["start"])
|
|
29
|
+
>>> registry.register("end", lambda payload: f"{payload}-done")
|
|
30
|
+
>>> registry.register("branch", lambda payload: "right")
|
|
31
|
+
>>> graph = Graph(
|
|
32
|
+
... nodes={
|
|
33
|
+
... "start": Node(id="start", kind=NodeKind.TASK, handler="start"),
|
|
34
|
+
... "branch": Node(id="branch", kind=NodeKind.SWITCH, handler="branch"),
|
|
35
|
+
... "end": Node(id="end", kind=NodeKind.TASK, handler="end"),
|
|
36
|
+
... },
|
|
37
|
+
... edges={
|
|
38
|
+
... "e1": Edge(id="e1", source="start", target="branch"),
|
|
39
|
+
... "e2": Edge(id="e2", source="branch", target="end", metadata={"route": "right"}),
|
|
40
|
+
... },
|
|
41
|
+
... )
|
|
42
|
+
>>> async def run() -> Dict[str, Any]:
|
|
43
|
+
... bus = EventBus()
|
|
44
|
+
... policy = SemaphorePolicy(limit=2)
|
|
45
|
+
... scheduler = Scheduler(registry, bus, policy)
|
|
46
|
+
... results = await scheduler.run(graph, initial_payload=[])
|
|
47
|
+
... return results
|
|
48
|
+
>>> asyncio.run(run())
|
|
49
|
+
{'start': ['start'], 'branch': 'right', 'end': 'right-done'}
|
|
50
|
+
>>> from loopgraph.core.types import EventType
|
|
51
|
+
>>> from loopgraph.persistence import InMemoryEventLog, InMemorySnapshotStore
|
|
52
|
+
>>> async def run_with_persistence() -> Dict[str, Any]:
|
|
53
|
+
... bus = EventBus()
|
|
54
|
+
... policy = SemaphorePolicy(limit=2)
|
|
55
|
+
... store = InMemorySnapshotStore()
|
|
56
|
+
... event_log = InMemoryEventLog()
|
|
57
|
+
... scheduler = Scheduler(
|
|
58
|
+
... registry,
|
|
59
|
+
... bus,
|
|
60
|
+
... policy,
|
|
61
|
+
... snapshot_store=store,
|
|
62
|
+
... event_log=event_log,
|
|
63
|
+
... )
|
|
64
|
+
... await scheduler.run(graph, graph_id="demo", initial_payload=[])
|
|
65
|
+
... snapshot = store.load("demo")
|
|
66
|
+
... event_types = [evt.type.name for evt in event_log.iter("demo")]
|
|
67
|
+
... return {
|
|
68
|
+
... "completed": snapshot["completed_nodes"],
|
|
69
|
+
... "events": event_types,
|
|
70
|
+
... }
|
|
71
|
+
>>> asyncio.run(run_with_persistence())
|
|
72
|
+
{'completed': ['branch', 'end', 'start'], 'events': ['NODE_SCHEDULED', 'NODE_COMPLETED', 'NODE_SCHEDULED', 'NODE_COMPLETED', 'NODE_SCHEDULED', 'NODE_COMPLETED']}
|
|
73
|
+
|
|
74
|
+
Loop re-entry is supported when a SWITCH routes to an already-completed node
|
|
75
|
+
that still has visit capacity.
|
|
76
|
+
|
|
77
|
+
>>> loop_registry = FunctionRegistry()
|
|
78
|
+
>>> loop_counter = {"count": 0}
|
|
79
|
+
>>> def loop_handler(_: object) -> dict[str, int]:
|
|
80
|
+
... loop_counter["count"] += 1
|
|
81
|
+
... return {"iteration": loop_counter["count"]}
|
|
82
|
+
>>> def switch_handler(payload: dict[str, int]) -> str:
|
|
83
|
+
... if payload["iteration"] < 2:
|
|
84
|
+
... return "continue"
|
|
85
|
+
... return "done"
|
|
86
|
+
>>> loop_registry.register("start", lambda _: None)
|
|
87
|
+
>>> loop_registry.register("loop", loop_handler)
|
|
88
|
+
>>> loop_registry.register("switch", switch_handler)
|
|
89
|
+
>>> loop_registry.register("out", lambda payload: payload)
|
|
90
|
+
>>> loop_graph = Graph(
|
|
91
|
+
... nodes={
|
|
92
|
+
... "start": Node(id="start", kind=NodeKind.TASK, handler="start"),
|
|
93
|
+
... "loop": Node(
|
|
94
|
+
... id="loop",
|
|
95
|
+
... kind=NodeKind.TASK,
|
|
96
|
+
... handler="loop",
|
|
97
|
+
... max_visits=2,
|
|
98
|
+
... allow_partial_upstream=True,
|
|
99
|
+
... ),
|
|
100
|
+
... "switch": Node(id="switch", kind=NodeKind.SWITCH, handler="switch"),
|
|
101
|
+
... "out": Node(id="out", kind=NodeKind.TASK, handler="out"),
|
|
102
|
+
... },
|
|
103
|
+
... edges={
|
|
104
|
+
... "start->loop": Edge(id="start->loop", source="start", target="loop"),
|
|
105
|
+
... "loop->switch": Edge(id="loop->switch", source="loop", target="switch"),
|
|
106
|
+
... "switch->loop": Edge(
|
|
107
|
+
... id="switch->loop",
|
|
108
|
+
... source="switch",
|
|
109
|
+
... target="loop",
|
|
110
|
+
... metadata={"route": "continue"},
|
|
111
|
+
... ),
|
|
112
|
+
... "switch->out": Edge(
|
|
113
|
+
... id="switch->out",
|
|
114
|
+
... source="switch",
|
|
115
|
+
... target="out",
|
|
116
|
+
... metadata={"route": "done"},
|
|
117
|
+
... ),
|
|
118
|
+
... },
|
|
119
|
+
... )
|
|
120
|
+
>>> async def run_loop() -> Dict[str, Any]:
|
|
121
|
+
... scheduler = Scheduler(loop_registry, EventBus(), SemaphorePolicy(limit=1))
|
|
122
|
+
... return await scheduler.run(loop_graph)
|
|
123
|
+
>>> asyncio.run(run_loop())["out"]
|
|
124
|
+
'done'
|
|
125
|
+
>>> loop_counter["count"]
|
|
126
|
+
2
|
|
127
|
+
"""
|
|
128
|
+
|
|
129
|
+
def __init__(
|
|
130
|
+
self,
|
|
131
|
+
registry: FunctionRegistry,
|
|
132
|
+
event_bus: EventBus,
|
|
133
|
+
concurrency_manager: ConcurrencyManager,
|
|
134
|
+
*,
|
|
135
|
+
snapshot_store: Optional[SnapshotStore] = None,
|
|
136
|
+
event_log: Optional[EventLog] = None,
|
|
137
|
+
) -> None:
|
|
138
|
+
func_name = "Scheduler.__init__"
|
|
139
|
+
log_parameter(
|
|
140
|
+
func_name,
|
|
141
|
+
registry=registry,
|
|
142
|
+
event_bus=event_bus,
|
|
143
|
+
snapshot_store=snapshot_store,
|
|
144
|
+
event_log=event_log,
|
|
145
|
+
)
|
|
146
|
+
self._registry = registry
|
|
147
|
+
log_variable_change(func_name, "self._registry", self._registry)
|
|
148
|
+
self._event_bus = event_bus
|
|
149
|
+
log_variable_change(func_name, "self._event_bus", self._event_bus)
|
|
150
|
+
self._concurrency_manager = concurrency_manager
|
|
151
|
+
log_variable_change(
|
|
152
|
+
func_name, "self._concurrency_manager", self._concurrency_manager
|
|
153
|
+
)
|
|
154
|
+
self._uses_default_semaphore = isinstance(
|
|
155
|
+
self._concurrency_manager, SemaphorePolicy
|
|
156
|
+
)
|
|
157
|
+
log_variable_change(
|
|
158
|
+
func_name, "self._uses_default_semaphore", self._uses_default_semaphore
|
|
159
|
+
)
|
|
160
|
+
self._event_counter = 0
|
|
161
|
+
log_variable_change(func_name, "self._event_counter", self._event_counter)
|
|
162
|
+
self._snapshot_store = snapshot_store
|
|
163
|
+
log_variable_change(func_name, "self._snapshot_store", self._snapshot_store)
|
|
164
|
+
self._event_log = event_log
|
|
165
|
+
log_variable_change(func_name, "self._event_log", self._event_log)
|
|
166
|
+
|
|
167
|
+
    async def run(
        self,
        graph: Graph,
        *,
        graph_id: str = "graph",
        initial_payload: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Execute nodes in the graph until completion.

        Repeatedly sweeps the set of pending nodes, running every node that
        ``execution_state.is_ready`` approves, until no nodes remain.  Nodes
        re-entered by a SWITCH (loop iterations) are added back to ``pending``
        via the re-entry targets returned from ``_execute_node``.

        Args:
            graph: The workflow graph to execute.
            graph_id: Identifier used for events and snapshot persistence.
            initial_payload: Payload handed to nodes that have no upstream.

        Returns:
            Mapping of node id to the last handler result for that node.

        Raises:
            RuntimeError: If a full sweep completes without running any node
                (e.g. a dependency cycle with no ready node).
        """
        func_name = "Scheduler.run"
        log_parameter(
            func_name,
            graph=graph,
            graph_id=graph_id,
            initial_payload=initial_payload,
        )
        # Resume from a persisted snapshot when a store is configured.
        execution_state = self._load_or_create_state(graph_id)
        log_variable_change(func_name, "execution_state", execution_state)
        snapshot_data = execution_state.snapshot()
        log_variable_change(func_name, "snapshot_data", snapshot_data)
        # Previously completed nodes contribute their recorded results and are
        # excluded from the pending set.
        results = self._initial_results_from_snapshot(snapshot_data)
        log_variable_change(func_name, "results", results)
        completed_nodes = set(snapshot_data["completed_nodes"])
        log_variable_change(func_name, "completed_nodes", completed_nodes)
        pending = {node_id for node_id in graph.nodes if node_id not in completed_nodes}
        log_variable_change(func_name, "pending", pending)
        loop_iteration = 0
        while pending:
            log_loop_iteration(func_name, "pending_loop", loop_iteration)
            loop_iteration += 1
            # Tracks whether this sweep ran at least one node; guards against
            # an infinite loop when nothing is ready.
            progressed = False
            log_variable_change(func_name, "progressed", progressed)
            # Iterate over a snapshot of pending: the set is mutated below.
            for iteration, node_id in enumerate(list(pending)):
                log_loop_iteration(func_name, "pending_nodes", iteration)
                if not execution_state.is_ready(graph, node_id):
                    log_branch(func_name, "node_not_ready")
                    continue
                log_branch(func_name, "node_ready")
                progressed = True
                log_variable_change(func_name, "progressed", progressed)
                node = graph.nodes[node_id]
                log_variable_change(func_name, "node", node)
                execution_state.mark_running(node_id)
                event_id = self._next_event_id(node_id)
                log_variable_change(func_name, "event_id", event_id)
                input_payload = self._build_input_payload(
                    graph=graph,
                    node=node,
                    results=results,
                    initial_payload=initial_payload,
                )
                log_variable_change(func_name, "input_payload", input_payload)
                # Announce the node before execution begins.
                await self._dispatch_event(
                    Event(
                        id=event_id,
                        graph_id=graph_id,
                        node_id=node_id,
                        type=EventType.NODE_SCHEDULED,
                        payload=input_payload,
                        status=NodeStatus.RUNNING,
                    )
                )
                handler_result, reentry_targets = await self._execute_node(
                    node=node,
                    graph=graph,
                    execution_state=execution_state,
                    graph_id=graph_id,
                    upstream_payload=input_payload,
                )
                log_variable_change(func_name, "handler_result", handler_result)
                log_variable_change(func_name, "reentry_targets", reentry_targets)
                results[node_id] = handler_result
                log_variable_change(func_name, "results", results)
                pending.remove(node_id)
                log_variable_change(func_name, "pending", pending)
                # Loop re-entry: completed nodes with visit capacity are
                # scheduled again.
                for reentry_iteration, reentry_target in enumerate(reentry_targets):
                    log_loop_iteration(func_name, "reentry_targets", reentry_iteration)
                    pending.add(reentry_target)
                    log_variable_change(
                        func_name,
                        f"pending_with_reentry_{reentry_target}",
                        pending,
                    )
            if not progressed:
                log_branch(func_name, "no_progress")
                raise RuntimeError("Scheduler could not make progress")
        log_branch(func_name, "completed")
        self._persist_snapshot(execution_state, graph_id)
        return results
|
|
255
|
+
|
|
256
|
+
def _next_event_id(self, node_id: str) -> str:
|
|
257
|
+
"""Produce a unique event identifier."""
|
|
258
|
+
func_name = "Scheduler._next_event_id"
|
|
259
|
+
log_parameter(func_name, node_id=node_id)
|
|
260
|
+
self._event_counter += 1
|
|
261
|
+
log_variable_change(func_name, "self._event_counter", self._event_counter)
|
|
262
|
+
event_id = f"{node_id}-{self._event_counter}"
|
|
263
|
+
log_variable_change(func_name, "event_id", event_id)
|
|
264
|
+
return event_id
|
|
265
|
+
|
|
266
|
+
    async def _execute_node(
        self,
        node: Node,
        graph: Graph,
        execution_state: ExecutionState,
        graph_id: str,
        upstream_payload: Optional[Any],
    ) -> Tuple[Any, List[str]]:
        """Execute a single node handler.

        Runs the node's handler inside a concurrency slot, records the
        success/failure outcome on ``execution_state``, dispatches the
        corresponding event, then selects downstream edges and collects
        already-completed targets that still have visit capacity as re-entry
        targets.

        Returns:
            Tuple of (handler result, list of node ids to re-enter).

        Raises:
            Exception: Re-raises whatever the handler raised, after marking
                the node failed, dispatching NODE_FAILED, and persisting a
                snapshot.
            RuntimeError: If a selected downstream target is in a
                non-terminal state (neither COMPLETED, FAILED, nor an
                unvisited PENDING node).
        """
        func_name = "Scheduler._execute_node"
        log_parameter(
            func_name, node=node, graph=graph, upstream_payload=upstream_payload
        )
        priority = node.priority
        log_variable_change(func_name, "priority", priority)
        async with self._concurrency_manager.slot(node.id, priority=priority):
            log_branch(func_name, "acquired_slot")
            try:
                result = await self._registry.execute(node.handler, upstream_payload)
                log_variable_change(func_name, "result", result)
            except Exception as exc:  # pragma: no cover - explicit failure path
                # Failure path: record, announce, persist, then propagate.
                log_branch(func_name, "handler_failed")
                outcome = VisitOutcome.failure(str(exc))
                log_variable_change(func_name, "outcome", outcome)
                event_id = self._next_event_id(node.id)
                log_variable_change(func_name, "event_id", event_id)
                execution_state.mark_failed(node.id, event_id, outcome)
                await self._dispatch_event(
                    Event(
                        id=event_id,
                        graph_id=graph_id,
                        node_id=node.id,
                        type=EventType.NODE_FAILED,
                        payload={"error": str(exc)},
                        status=NodeStatus.FAILED,
                    )
                )
                self._persist_snapshot(execution_state, graph_id)
                raise
            outcome = VisitOutcome.success(result)
            log_variable_change(func_name, "outcome", outcome)
            event_id = self._next_event_id(node.id)
            log_variable_change(func_name, "event_id", event_id)
            execution_state.mark_complete(node.id, event_id, outcome)
            await self._dispatch_event(
                Event(
                    id=event_id,
                    graph_id=graph_id,
                    node_id=node.id,
                    type=EventType.NODE_COMPLETED,
                    payload=result,
                    status=NodeStatus.COMPLETED,
                    # NOTE(review): takes a full snapshot just to read one
                    # visit count — presumably cheap, but worth confirming.
                    visit_count=execution_state.snapshot()["states"][node.id]["visits"][
                        "count"
                    ],
                )
            )
            selected_edges = self._determine_downstream_edges(
                graph=graph,
                node=node,
                handler_result=result,
                execution_state=execution_state,
            )
            log_variable_change(func_name, "selected_edges", selected_edges)
            reentry_targets: List[str] = []
            log_variable_change(func_name, "reentry_targets", reentry_targets)
            for iteration, edge in enumerate(selected_edges):
                log_loop_iteration(func_name, "downstream_edges", iteration)
                log_variable_change(func_name, "edge", edge)
                downstream = graph.nodes[edge.target]
                log_variable_change(func_name, "downstream", downstream)
                # NOTE(review): reaches into ExecutionState's private
                # _ensure_state — a public accessor would be cleaner.
                downstream_state = execution_state._ensure_state(downstream.id)
                log_variable_change(func_name, "downstream_state", downstream_state)
                downstream_status = downstream_state.status
                log_variable_change(func_name, "downstream_status", downstream_status)
                downstream_visits = downstream_state.visits.count
                log_variable_change(func_name, "downstream_visits", downstream_visits)
                if downstream_status is NodeStatus.COMPLETED:
                    # Already-completed target: re-enter only if visit
                    # capacity remains (loop support).
                    has_remaining_visits = self._has_remaining_visits(
                        graph=graph,
                        execution_state=execution_state,
                        node_id=downstream.id,
                    )
                    log_variable_change(
                        func_name, "has_remaining_visits", has_remaining_visits
                    )
                    if has_remaining_visits:
                        log_branch(func_name, "reentry_reset_completed")
                        execution_state.reset_for_reentry(downstream.id)
                        reentry_targets.append(downstream.id)
                        log_variable_change(
                            func_name,
                            "reentry_targets",
                            list(reentry_targets),
                        )
                    else:
                        log_branch(func_name, "reentry_completed_exhausted")
                elif downstream_status is NodeStatus.FAILED:
                    # Failed targets are skipped, not retried.
                    log_branch(func_name, "reentry_failed_skip")
                elif downstream_visits == 0 and downstream_status is NodeStatus.PENDING:
                    # First-time activation of a normal downstream node.
                    log_branch(func_name, "initial_pending_target")
                else:
                    log_branch(func_name, "reentry_non_terminal_error")
                    raise RuntimeError(
                        "Encountered non-terminal re-entry target "
                        f"'{downstream.id}' in state '{downstream_status.value}'"
                    )
                execution_state.note_upstream_completion(downstream.id, node.id)
            self._persist_snapshot(execution_state, graph_id)
            return result, reentry_targets
|
|
376
|
+
|
|
377
|
+
    def _build_input_payload(
        self,
        graph: Graph,
        node: Node,
        results: Dict[str, Any],
        initial_payload: Optional[Any],
    ) -> Optional[Any]:
        """Determine the payload supplied to a handler based on upstream results.

        Source nodes (no upstream) receive ``initial_payload``.  AGGREGATE
        nodes receive a list of all available upstream results.  Any other
        node receives the result of its first upstream node, or None if that
        result is not yet recorded.
        """

        func_name = "Scheduler._build_input_payload"
        log_parameter(
            func_name,
            graph=graph,
            node=node,
            results=results,
            initial_payload=initial_payload,
        )
        upstream_nodes = graph.upstream_nodes(node.id)
        log_variable_change(func_name, "upstream_nodes", upstream_nodes)
        if not upstream_nodes:
            # Source node: seed it with the caller-provided payload.
            log_branch(func_name, "no_upstream")
            log_variable_change(func_name, "payload", initial_payload)
            return initial_payload

        if node.kind is NodeKind.AGGREGATE:
            # Fan-in: collect every upstream result that exists so far.
            log_branch(func_name, "aggregate_payload")
            aggregated: List[Any] = []
            log_variable_change(func_name, "aggregated", aggregated)
            for iteration, upstream in enumerate(upstream_nodes):
                log_loop_iteration(func_name, "aggregate_upstream", iteration)
                if upstream.id in results:
                    aggregated.append(results[upstream.id])
                    log_variable_change(
                        func_name, "aggregated", list(aggregated)
                    )
            log_variable_change(func_name, "payload", aggregated)
            return aggregated

        # Non-aggregate node: only the first upstream's result is forwarded.
        # NOTE(review): with multiple upstreams this silently ignores all but
        # upstream_nodes[0] — presumably ordering is deterministic; confirm.
        upstream_id = upstream_nodes[0].id
        log_variable_change(func_name, "upstream_id", upstream_id)
        payload = results.get(upstream_id)
        log_variable_change(func_name, "payload", payload)
        return payload
|
|
420
|
+
|
|
421
|
+
    def _determine_downstream_edges(
        self,
        graph: Graph,
        node: Node,
        handler_result: Any,
        execution_state: ExecutionState,
    ) -> List[Edge]:
        """Select downstream edges to activate after a node completes.

        Non-SWITCH nodes activate all outgoing edges.  For a SWITCH node the
        handler result (a string route) selects the edges whose
        ``metadata["route"]`` matches and whose target still has visit
        capacity; if none qualify, edges routed "exit" serve as a fallback,
        and otherwise no edge is activated.

        Raises:
            ValueError: If a SWITCH handler returned a non-string result.
        """

        func_name = "Scheduler._determine_downstream_edges"
        log_parameter(
            func_name,
            graph=graph,
            node=node,
            handler_result=handler_result,
        )
        edges = graph.downstream_edges(node.id)
        log_variable_change(func_name, "edges", edges)
        if node.kind is not NodeKind.SWITCH:
            # Plain nodes fan out to every downstream edge.
            log_branch(func_name, "non_switch")
            return edges

        log_branch(func_name, "switch_node")
        if not isinstance(handler_result, str):
            log_branch(func_name, "invalid_route_type")
            raise ValueError(
                f"Switch handler for node '{node.id}' must return a string route"
            )
        route = handler_result
        log_variable_change(func_name, "route", route)
        selected: List[Edge] = []
        exit_edges: List[Edge] = []
        log_variable_change(func_name, "selected", selected)
        for iteration, edge in enumerate(edges):
            log_loop_iteration(func_name, "switch_edges", iteration)
            metadata_route = edge.metadata.get("route")
            log_variable_change(func_name, "metadata_route", metadata_route)
            if metadata_route == route:
                # Matching edge only counts if its target can still be visited.
                has_capacity = self._has_remaining_visits(
                    graph, execution_state, edge.target
                )
                log_variable_change(func_name, "has_capacity", has_capacity)
                if has_capacity:
                    selected.append(edge)
                    log_variable_change(func_name, "selected", list(selected))
                else:
                    log_branch(func_name, "target_exhausted")
            # "exit" edges are remembered unconditionally as a fallback,
            # independent of the chosen route.
            if metadata_route == "exit":
                exit_edges.append(edge)
                log_variable_change(func_name, "exit_edges", list(exit_edges))

        if selected:
            log_variable_change(func_name, "selected_final", selected)
            return selected
        if exit_edges:
            # No routed edge had capacity — fall back to the exit route.
            log_branch(func_name, "fallback_exit")
            return exit_edges
        log_branch(func_name, "no_matching_edge")
        return []
|
|
480
|
+
|
|
481
|
+
def _has_remaining_visits(
|
|
482
|
+
self, graph: Graph, execution_state: ExecutionState, node_id: str
|
|
483
|
+
) -> bool:
|
|
484
|
+
func_name = "Scheduler._has_remaining_visits"
|
|
485
|
+
log_parameter(func_name, node_id=node_id)
|
|
486
|
+
node = graph.nodes[node_id]
|
|
487
|
+
log_variable_change(func_name, "node", node)
|
|
488
|
+
if node.max_visits is None:
|
|
489
|
+
log_branch(func_name, "no_visit_limit")
|
|
490
|
+
return True
|
|
491
|
+
state = execution_state._ensure_state(node_id)
|
|
492
|
+
visits = state.visits.count
|
|
493
|
+
log_variable_change(func_name, "visits", visits)
|
|
494
|
+
has_capacity = visits < node.max_visits
|
|
495
|
+
log_variable_change(func_name, "has_capacity", has_capacity)
|
|
496
|
+
return has_capacity
|
|
497
|
+
|
|
498
|
+
def _load_or_create_state(self, graph_id: str) -> ExecutionState:
|
|
499
|
+
"""Retrieve execution state from snapshots if available."""
|
|
500
|
+
|
|
501
|
+
func_name = "Scheduler._load_or_create_state"
|
|
502
|
+
log_parameter(func_name, graph_id=graph_id)
|
|
503
|
+
if not self._snapshot_store:
|
|
504
|
+
log_branch(func_name, "no_snapshot_store")
|
|
505
|
+
state = ExecutionState()
|
|
506
|
+
log_variable_change(func_name, "state", state)
|
|
507
|
+
return state
|
|
508
|
+
|
|
509
|
+
try:
|
|
510
|
+
snapshot_payload = self._snapshot_store.load(graph_id)
|
|
511
|
+
except KeyError:
|
|
512
|
+
log_branch(func_name, "snapshot_missing")
|
|
513
|
+
state = ExecutionState()
|
|
514
|
+
else:
|
|
515
|
+
log_branch(func_name, "snapshot_loaded")
|
|
516
|
+
state = ExecutionState.restore(dict(snapshot_payload))
|
|
517
|
+
log_variable_change(func_name, "state", state)
|
|
518
|
+
return state
|
|
519
|
+
|
|
520
|
+
def _initial_results_from_snapshot(self, snapshot: Dict[str, Any]) -> Dict[str, Any]:
|
|
521
|
+
"""Extract node results from a snapshot payload."""
|
|
522
|
+
|
|
523
|
+
func_name = "Scheduler._initial_results_from_snapshot"
|
|
524
|
+
log_parameter(func_name, snapshot=snapshot)
|
|
525
|
+
results: Dict[str, Any] = {}
|
|
526
|
+
log_variable_change(func_name, "results", results)
|
|
527
|
+
states = snapshot.get("states", {})
|
|
528
|
+
log_variable_change(func_name, "states", states)
|
|
529
|
+
for iteration, (node_id, node_state) in enumerate(states.items()):
|
|
530
|
+
log_loop_iteration(func_name, "states", iteration)
|
|
531
|
+
status_value = node_state.get("status")
|
|
532
|
+
log_variable_change(func_name, "status_value", status_value)
|
|
533
|
+
if status_value == NodeStatus.COMPLETED.value:
|
|
534
|
+
payload = node_state.get("last_payload")
|
|
535
|
+
log_variable_change(func_name, "payload", payload)
|
|
536
|
+
if payload is not None:
|
|
537
|
+
results[node_id] = payload
|
|
538
|
+
log_variable_change(
|
|
539
|
+
func_name, f"results[{node_id!r}]", results[node_id]
|
|
540
|
+
)
|
|
541
|
+
log_variable_change(func_name, "results_final", results)
|
|
542
|
+
return results
|
|
543
|
+
|
|
544
|
+
def _persist_snapshot(
|
|
545
|
+
self, execution_state: ExecutionState, graph_id: str
|
|
546
|
+
) -> None:
|
|
547
|
+
"""Persist execution state snapshot if a store is configured."""
|
|
548
|
+
|
|
549
|
+
func_name = "Scheduler._persist_snapshot"
|
|
550
|
+
log_parameter(func_name, graph_id=graph_id)
|
|
551
|
+
if not self._snapshot_store:
|
|
552
|
+
log_branch(func_name, "no_snapshot_store")
|
|
553
|
+
return
|
|
554
|
+
snapshot = execution_state.snapshot()
|
|
555
|
+
log_variable_change(func_name, "snapshot", snapshot)
|
|
556
|
+
self._snapshot_store.save(graph_id, snapshot)
|
|
557
|
+
log_branch(func_name, "snapshot_saved")
|
|
558
|
+
|
|
559
|
+
async def _dispatch_event(self, event: Event) -> None:
|
|
560
|
+
"""Append an event to the log and publish it on the bus."""
|
|
561
|
+
|
|
562
|
+
func_name = "Scheduler._dispatch_event"
|
|
563
|
+
log_parameter(func_name, event=event)
|
|
564
|
+
if self._event_log:
|
|
565
|
+
self._event_log.append(event)
|
|
566
|
+
log_branch(func_name, "event_logged")
|
|
567
|
+
else:
|
|
568
|
+
log_branch(func_name, "no_event_log")
|
|
569
|
+
await self._event_bus.emit(event)
|