aethergraph 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182) hide show
  1. aethergraph/__init__.py +49 -0
  2. aethergraph/config/__init__.py +0 -0
  3. aethergraph/config/config.py +121 -0
  4. aethergraph/config/context.py +16 -0
  5. aethergraph/config/llm.py +26 -0
  6. aethergraph/config/loader.py +60 -0
  7. aethergraph/config/runtime.py +9 -0
  8. aethergraph/contracts/errors/errors.py +44 -0
  9. aethergraph/contracts/services/artifacts.py +142 -0
  10. aethergraph/contracts/services/channel.py +72 -0
  11. aethergraph/contracts/services/continuations.py +23 -0
  12. aethergraph/contracts/services/eventbus.py +12 -0
  13. aethergraph/contracts/services/kv.py +24 -0
  14. aethergraph/contracts/services/llm.py +17 -0
  15. aethergraph/contracts/services/mcp.py +22 -0
  16. aethergraph/contracts/services/memory.py +108 -0
  17. aethergraph/contracts/services/resume.py +28 -0
  18. aethergraph/contracts/services/state_stores.py +33 -0
  19. aethergraph/contracts/services/wakeup.py +28 -0
  20. aethergraph/core/execution/base_scheduler.py +77 -0
  21. aethergraph/core/execution/forward_scheduler.py +777 -0
  22. aethergraph/core/execution/global_scheduler.py +634 -0
  23. aethergraph/core/execution/retry_policy.py +22 -0
  24. aethergraph/core/execution/step_forward.py +411 -0
  25. aethergraph/core/execution/step_result.py +18 -0
  26. aethergraph/core/execution/wait_types.py +72 -0
  27. aethergraph/core/graph/graph_builder.py +192 -0
  28. aethergraph/core/graph/graph_fn.py +219 -0
  29. aethergraph/core/graph/graph_io.py +67 -0
  30. aethergraph/core/graph/graph_refs.py +154 -0
  31. aethergraph/core/graph/graph_spec.py +115 -0
  32. aethergraph/core/graph/graph_state.py +59 -0
  33. aethergraph/core/graph/graphify.py +128 -0
  34. aethergraph/core/graph/interpreter.py +145 -0
  35. aethergraph/core/graph/node_handle.py +33 -0
  36. aethergraph/core/graph/node_spec.py +46 -0
  37. aethergraph/core/graph/node_state.py +63 -0
  38. aethergraph/core/graph/task_graph.py +747 -0
  39. aethergraph/core/graph/task_node.py +82 -0
  40. aethergraph/core/graph/utils.py +37 -0
  41. aethergraph/core/graph/visualize.py +239 -0
  42. aethergraph/core/runtime/ad_hoc_context.py +61 -0
  43. aethergraph/core/runtime/base_service.py +153 -0
  44. aethergraph/core/runtime/bind_adapter.py +42 -0
  45. aethergraph/core/runtime/bound_memory.py +69 -0
  46. aethergraph/core/runtime/execution_context.py +220 -0
  47. aethergraph/core/runtime/graph_runner.py +349 -0
  48. aethergraph/core/runtime/lifecycle.py +26 -0
  49. aethergraph/core/runtime/node_context.py +203 -0
  50. aethergraph/core/runtime/node_services.py +30 -0
  51. aethergraph/core/runtime/recovery.py +159 -0
  52. aethergraph/core/runtime/run_registration.py +33 -0
  53. aethergraph/core/runtime/runtime_env.py +157 -0
  54. aethergraph/core/runtime/runtime_registry.py +32 -0
  55. aethergraph/core/runtime/runtime_services.py +224 -0
  56. aethergraph/core/runtime/wakeup_watcher.py +40 -0
  57. aethergraph/core/tools/__init__.py +10 -0
  58. aethergraph/core/tools/builtins/channel_tools.py +194 -0
  59. aethergraph/core/tools/builtins/toolset.py +134 -0
  60. aethergraph/core/tools/toolkit.py +510 -0
  61. aethergraph/core/tools/waitable.py +109 -0
  62. aethergraph/plugins/channel/__init__.py +0 -0
  63. aethergraph/plugins/channel/adapters/__init__.py +0 -0
  64. aethergraph/plugins/channel/adapters/console.py +106 -0
  65. aethergraph/plugins/channel/adapters/file.py +102 -0
  66. aethergraph/plugins/channel/adapters/slack.py +285 -0
  67. aethergraph/plugins/channel/adapters/telegram.py +302 -0
  68. aethergraph/plugins/channel/adapters/webhook.py +104 -0
  69. aethergraph/plugins/channel/adapters/webui.py +134 -0
  70. aethergraph/plugins/channel/routes/__init__.py +0 -0
  71. aethergraph/plugins/channel/routes/console_routes.py +86 -0
  72. aethergraph/plugins/channel/routes/slack_routes.py +49 -0
  73. aethergraph/plugins/channel/routes/telegram_routes.py +26 -0
  74. aethergraph/plugins/channel/routes/webui_routes.py +136 -0
  75. aethergraph/plugins/channel/utils/__init__.py +0 -0
  76. aethergraph/plugins/channel/utils/slack_utils.py +278 -0
  77. aethergraph/plugins/channel/utils/telegram_utils.py +324 -0
  78. aethergraph/plugins/channel/websockets/slack_ws.py +68 -0
  79. aethergraph/plugins/channel/websockets/telegram_polling.py +151 -0
  80. aethergraph/plugins/mcp/fs_server.py +128 -0
  81. aethergraph/plugins/mcp/http_server.py +101 -0
  82. aethergraph/plugins/mcp/ws_server.py +180 -0
  83. aethergraph/plugins/net/http.py +10 -0
  84. aethergraph/plugins/utils/data_io.py +359 -0
  85. aethergraph/runner/__init__.py +5 -0
  86. aethergraph/runtime/__init__.py +62 -0
  87. aethergraph/server/__init__.py +3 -0
  88. aethergraph/server/app_factory.py +84 -0
  89. aethergraph/server/start.py +122 -0
  90. aethergraph/services/__init__.py +10 -0
  91. aethergraph/services/artifacts/facade.py +284 -0
  92. aethergraph/services/artifacts/factory.py +35 -0
  93. aethergraph/services/artifacts/fs_store.py +656 -0
  94. aethergraph/services/artifacts/jsonl_index.py +123 -0
  95. aethergraph/services/artifacts/paths.py +23 -0
  96. aethergraph/services/artifacts/sqlite_index.py +209 -0
  97. aethergraph/services/artifacts/utils.py +124 -0
  98. aethergraph/services/auth/dev.py +16 -0
  99. aethergraph/services/channel/channel_bus.py +293 -0
  100. aethergraph/services/channel/factory.py +44 -0
  101. aethergraph/services/channel/session.py +511 -0
  102. aethergraph/services/channel/wait_helpers.py +57 -0
  103. aethergraph/services/clock/clock.py +9 -0
  104. aethergraph/services/container/default_container.py +320 -0
  105. aethergraph/services/continuations/continuation.py +56 -0
  106. aethergraph/services/continuations/factory.py +34 -0
  107. aethergraph/services/continuations/stores/fs_store.py +264 -0
  108. aethergraph/services/continuations/stores/inmem_store.py +95 -0
  109. aethergraph/services/eventbus/inmem.py +21 -0
  110. aethergraph/services/features/static.py +10 -0
  111. aethergraph/services/kv/ephemeral.py +90 -0
  112. aethergraph/services/kv/factory.py +27 -0
  113. aethergraph/services/kv/layered.py +41 -0
  114. aethergraph/services/kv/sqlite_kv.py +128 -0
  115. aethergraph/services/llm/factory.py +157 -0
  116. aethergraph/services/llm/generic_client.py +542 -0
  117. aethergraph/services/llm/providers.py +3 -0
  118. aethergraph/services/llm/service.py +105 -0
  119. aethergraph/services/logger/base.py +36 -0
  120. aethergraph/services/logger/compat.py +50 -0
  121. aethergraph/services/logger/formatters.py +106 -0
  122. aethergraph/services/logger/std.py +203 -0
  123. aethergraph/services/mcp/helpers.py +23 -0
  124. aethergraph/services/mcp/http_client.py +70 -0
  125. aethergraph/services/mcp/mcp_tools.py +21 -0
  126. aethergraph/services/mcp/registry.py +14 -0
  127. aethergraph/services/mcp/service.py +100 -0
  128. aethergraph/services/mcp/stdio_client.py +70 -0
  129. aethergraph/services/mcp/ws_client.py +115 -0
  130. aethergraph/services/memory/bound.py +106 -0
  131. aethergraph/services/memory/distillers/episode.py +116 -0
  132. aethergraph/services/memory/distillers/rolling.py +74 -0
  133. aethergraph/services/memory/facade.py +633 -0
  134. aethergraph/services/memory/factory.py +78 -0
  135. aethergraph/services/memory/hotlog_kv.py +27 -0
  136. aethergraph/services/memory/indices.py +74 -0
  137. aethergraph/services/memory/io_helpers.py +72 -0
  138. aethergraph/services/memory/persist_fs.py +40 -0
  139. aethergraph/services/memory/resolver.py +152 -0
  140. aethergraph/services/metering/noop.py +4 -0
  141. aethergraph/services/prompts/file_store.py +41 -0
  142. aethergraph/services/rag/chunker.py +29 -0
  143. aethergraph/services/rag/facade.py +593 -0
  144. aethergraph/services/rag/index/base.py +27 -0
  145. aethergraph/services/rag/index/faiss_index.py +121 -0
  146. aethergraph/services/rag/index/sqlite_index.py +134 -0
  147. aethergraph/services/rag/index_factory.py +52 -0
  148. aethergraph/services/rag/parsers/md.py +7 -0
  149. aethergraph/services/rag/parsers/pdf.py +14 -0
  150. aethergraph/services/rag/parsers/txt.py +7 -0
  151. aethergraph/services/rag/utils/hybrid.py +39 -0
  152. aethergraph/services/rag/utils/make_fs_key.py +62 -0
  153. aethergraph/services/redactor/simple.py +16 -0
  154. aethergraph/services/registry/key_parsing.py +44 -0
  155. aethergraph/services/registry/registry_key.py +19 -0
  156. aethergraph/services/registry/unified_registry.py +185 -0
  157. aethergraph/services/resume/multi_scheduler_resume_bus.py +65 -0
  158. aethergraph/services/resume/router.py +73 -0
  159. aethergraph/services/schedulers/registry.py +41 -0
  160. aethergraph/services/secrets/base.py +7 -0
  161. aethergraph/services/secrets/env.py +8 -0
  162. aethergraph/services/state_stores/externalize.py +135 -0
  163. aethergraph/services/state_stores/graph_observer.py +131 -0
  164. aethergraph/services/state_stores/json_store.py +67 -0
  165. aethergraph/services/state_stores/resume_policy.py +119 -0
  166. aethergraph/services/state_stores/serialize.py +249 -0
  167. aethergraph/services/state_stores/utils.py +91 -0
  168. aethergraph/services/state_stores/validate.py +78 -0
  169. aethergraph/services/tracing/noop.py +18 -0
  170. aethergraph/services/waits/wait_registry.py +91 -0
  171. aethergraph/services/wakeup/memory_queue.py +57 -0
  172. aethergraph/services/wakeup/scanner_producer.py +56 -0
  173. aethergraph/services/wakeup/worker.py +31 -0
  174. aethergraph/tools/__init__.py +25 -0
  175. aethergraph/utils/optdeps.py +8 -0
  176. aethergraph-0.1.0a1.dist-info/METADATA +410 -0
  177. aethergraph-0.1.0a1.dist-info/RECORD +182 -0
  178. aethergraph-0.1.0a1.dist-info/WHEEL +5 -0
  179. aethergraph-0.1.0a1.dist-info/entry_points.txt +2 -0
  180. aethergraph-0.1.0a1.dist-info/licenses/LICENSE +176 -0
  181. aethergraph-0.1.0a1.dist-info/licenses/NOTICE +31 -0
  182. aethergraph-0.1.0a1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,634 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from collections.abc import Awaitable, Callable
5
+ from dataclasses import dataclass, field
6
+ from datetime import datetime
7
+ import inspect
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ from aethergraph.contracts.services.resume import (
11
+ ResumeEvent, # we’ll extend usage to include run_id
12
+ )
13
+ from aethergraph.contracts.services.wakeup import WakeupEvent
14
+
15
+ from ..graph.graph_refs import GRAPH_INPUTS_NODE_ID
16
+ from ..graph.node_spec import NodeEvent
17
+ from ..graph.node_state import TERMINAL_STATES, WAITING_STATES, NodeStatus
18
+ from ..graph.task_node import TaskNodeRuntime
19
+ from .retry_policy import RetryPolicy
20
+
21
+ if TYPE_CHECKING:
22
+ from aethergraph.services.schedulers.registry import SchedulerRegistry
23
+
24
+ from ..graph.task_graph import TaskGraph
25
+ from ..runtime.runtime_env import RuntimeEnv
26
+
27
+
28
+ # --------- Global control events tagged with run_id ---------
29
@dataclass
class GlobalResumeEvent:
    """Control-plane event: resume a waiting node of a specific run with a payload."""

    run_id: str  # run that owns the node
    node_id: str  # node currently in a WAITING_* state
    payload: dict[str, Any]  # resume data handed to the node on restart
34
+
35
+
36
@dataclass
class GlobalWakeupEvent:
    """Control-plane event: nudge the scheduler to re-examine a node (no payload)."""

    run_id: str  # run that owns the node
    node_id: str  # node to re-check for startability
40
+
41
+
42
+ # --------- Per-run state ---------
43
@dataclass
class RunSettings:
    """Per-run scheduling knobs: concurrency, retries, and failure policy."""

    max_concurrency: int = 4  # max nodes of this run executing at once
    retry_policy: RetryPolicy = field(default_factory=RetryPolicy)  # backoff/retry for failed nodes
    stop_on_first_error: bool = False  # terminate the whole run on first non-retryable failure
    skip_dependents_on_failure: bool = True  # cascade SKIPPED to downstream nodes of a failed node
49
+
50
+
51
@dataclass
class RunState:
    """Mutable per-run bookkeeping held by the global scheduler for one graph run."""

    run_id: str
    graph: TaskGraph
    env: RuntimeEnv
    settings: RunSettings

    # bookkeeping
    running_tasks: dict[str, asyncio.Task] = field(default_factory=dict)  # node_id -> task
    resume_payloads: dict[str, dict[str, Any]] = field(default_factory=dict)  # node_id -> payload
    resume_pending: set[str] = field(default_factory=set)  # node_ids awaiting capacity
    ready_pending: set[str] = field(default_factory=set)  # nodes explicitly enqueued
    backoff_tasks: dict[str, asyncio.Task] = field(default_factory=dict)  # node_id -> sleeper task
    terminated: bool = False

    def capacity(self) -> int:
        """Remaining execution slots for this run (never negative)."""
        free = self.settings.max_concurrency - len(self.running_tasks)
        return free if free > 0 else 0

    def any_waiting(self) -> bool:
        """True when at least one non-plan node sits in a WAITING_* state."""
        for node in self.graph.nodes:
            if node.spec.type == "plan":
                continue
            if node.state.status in WAITING_STATES:
                return True
        return False

    def all_terminal(self) -> bool:
        """True when every non-plan node has reached a terminal status."""
        return all(
            node.state.status in TERMINAL_STATES
            for node in self.graph.nodes
            if node.spec.type != "plan"
        )
81
+
82
+
83
@dataclass
class RunEvent:
    """Lifecycle event emitted when an entire run reaches a terminal status."""

    run_id: str
    status: str  # "SUCCESS" | "FAILED" | "CANCELLED"
    timestamp: float  # epoch seconds at emission time
88
+
89
+
90
+ # --------- Global Forward Scheduler ---------
91
class GlobalForwardScheduler:
    """
    A global event-driven DAG scheduler that coordinates execution across many graphs (runs)
    in a single asyncio event loop.

    • One global control plane (queue) carrying (run_id, node_id, payload) events
    • Resumed nodes (WAITING_* -> RUNNING) are prioritized globally
    • Each run has its own capacity; also a global cap can be applied if desired
    """

    def __init__(
        self,
        *,
        registry: SchedulerRegistry,
        global_max_concurrency: int | None = None,
        logger: Any | None = None,
    ):
        self._runs: dict[str, RunState] = {}
        self._listeners: list[Callable[[NodeEvent], Awaitable[None]]] = []
        self._events: asyncio.Queue = asyncio.Queue()
        self._pause_event = asyncio.Event()
        self._pause_event.set()  # gate open: the drive loop proceeds unless paused
        self._terminated = False

        # Optional global cap across all runs
        self._global_max_concurrency = global_max_concurrency  # None => unlimited
        self._logger = logger

        # registry for MultiSchedulerResumeBus routing
        self._registry = registry

        # convenience: track our loop (used by ResumeBus cross-thread dispatch);
        # set lazily on the first _drive_loop tick when constructed outside a loop
        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            self.loop = None

        self._run_listeners: list[Callable[[RunEvent], Awaitable[None]]] = []

    # ----- public hooks -----
    def add_listener(self, cb: Callable[[NodeEvent], Awaitable[None]]):
        """Register an async callback invoked for every NodeEvent from any run.

        Raises:
            ValueError: if *cb* is not a coroutine function.
        """
        if not inspect.iscoroutinefunction(cb):
            raise ValueError("Listener must be an async function")
        self._listeners.append(cb)

    async def submit(
        self, *, run_id: str, graph: TaskGraph, env: RuntimeEnv, settings: RunSettings | None = None
    ):
        """Register a new run (graph+env) with optional per-run settings."""
        if run_id in self._runs:
            raise ValueError(f"run_id already submitted: {run_id}")
        rs = RunState(run_id=run_id, graph=graph, env=env, settings=settings or RunSettings())
        self._runs[run_id] = rs
        self._registry.register(run_id, self)  # so MultiSchedulerResumeBus can find us

    async def run_until_all_done(self):
        """Drive the global loop until all runs are terminal."""
        await self._drive_loop(block_until="all_done")

    async def run_until_complete(self, run_id: str):
        """Drive the global loop until the specified run is terminal."""
        await self._drive_loop(block_until=("run_done", run_id))

    async def run_forever(self):
        """Service-mode: keep running until shutdown() is called."""
        await self._drive_loop(block_until="forever")

    async def shutdown(self):
        """Terminate all runs, cancel their tasks, and wake the drive loop. Idempotent."""
        if self._terminated:
            return
        self._terminated = True

        # mark runs as terminated & cancel sleepers/runners
        for rs in self._runs.values():
            rs.terminated = True
            for t in list(rs.backoff_tasks.values()):
                t.cancel()
            for t in list(rs.running_tasks.values()):
                t.cancel()

        # wake the driver if it's blocked on events.get()
        try:
            await self._events.put(GlobalWakeupEvent(run_id="__shutdown__", node_id="__shutdown__"))
        except RuntimeError as e:
            # queue may be closing; best-effort
            if self._logger:
                self._logger.warning(f"[GlobalForwardScheduler.shutdown] failed to wake up: {e}")

        # also ensure the pause gate isn't closed (defensive hasattr kept from original)
        if hasattr(self, "_pause_event"):
            self._pause_event.set()

    async def terminate_run(self, run_id: str):
        """Mark one run terminated and cancel its backoff sleepers and running tasks."""
        rs = self._runs.get(run_id)
        if not rs:
            return
        rs.terminated = True
        for t in list(rs.backoff_tasks.values()):
            t.cancel()
        for t in list(rs.running_tasks.values()):
            t.cancel()

    # external resume/wakeup API (called by ResumeBus)
    async def on_resume_event(self, run_id: str, node_id: str, payload: dict[str, Any]):
        """Enqueue a resume event for a waiting node (payload delivered on restart)."""
        await self._events.put(GlobalResumeEvent(run_id=run_id, node_id=node_id, payload=payload))

    async def on_wakeup_event(self, run_id: str, node_id: str):
        """Enqueue a payload-less wakeup for a node."""
        await self._events.put(GlobalWakeupEvent(run_id=run_id, node_id=node_id))

    # ----- main loop -----
    async def _drive_loop(self, *, block_until: str | tuple[str, str]):
        """Central scheduling loop.

        *block_until* is "all_done", "forever", or ("run_done", run_id) and
        determines when the loop may exit.
        """
        if self.loop is None:
            self.loop = asyncio.get_running_loop()

        MAX_DRAIN = 200  # bound per-tick draining so scheduling still gets a turn
        while not self._terminated:
            await self._pause_event.wait()

            # 1) Drain control events (non-blocking)
            drained = 0
            while drained < MAX_DRAIN:
                try:
                    ev = self._events.get_nowait()
                except asyncio.QueueEmpty:
                    break
                drained += 1
                await self._handle_event(ev)

            # 2) Attempt to schedule work
            scheduled_any = await self._schedule_global()

            # 3) Check termination conditions
            if block_until == "all_done":
                if all(
                    rs.all_terminal()
                    and not rs.running_tasks
                    and not rs.backoff_tasks
                    and not rs.resume_pending
                    for rs in self._runs.values()
                ):
                    break

            elif isinstance(block_until, tuple) and block_until[0] == "run_done":
                tgt = self._runs.get(block_until[1])
                if (
                    tgt
                    and tgt.all_terminal()
                    and not tgt.running_tasks
                    and not tgt.backoff_tasks
                    and not tgt.resume_pending
                ):
                    # compute a simple status
                    # NOTE(review): "CANCELLED" is declared in RunEvent but never produced here
                    status = "SUCCESS"
                    for n in tgt.graph.nodes:
                        if n.spec.type == "plan":
                            continue
                        if n.state.status == NodeStatus.FAILED:
                            status = "FAILED"
                            break
                    await self._emit_run(
                        RunEvent(
                            run_id=tgt.run_id,
                            status=status,
                            # naive now().timestamp() yields the correct Unix epoch;
                            # utcnow().timestamp() was offset by the local UTC delta
                            timestamp=datetime.now().timestamp(),
                        )
                    )
                    break

            # 4) If nothing is running anywhere and nothing scheduled, decide how to wait
            any_running = any(rs.running_tasks for rs in self._runs.values())
            if not any_running and not scheduled_any:
                # if any run has waiting nodes, block for a global resume/wakeup
                if any(rs.any_waiting() for rs in self._runs.values()):
                    ev = await self._events.get()
                    await self._handle_event(ev)
                    continue

                # if all runs are terminal, and we're not in 'forever' mode, the outer loop will exit next tick
                if block_until == "forever":
                    # idle until next event
                    ev = await self._events.get()
                    await self._handle_event(ev)
                    continue

            # 5) Wait for either any task to finish OR a control event
            running_tasks = [t for rs in self._runs.values() for t in rs.running_tasks.values()]
            ctrl = asyncio.create_task(self._events.get())
            try:
                if running_tasks:
                    done, _ = await asyncio.wait(
                        running_tasks + [ctrl], return_when=asyncio.FIRST_COMPLETED
                    )
                    if ctrl in done:
                        ev = ctrl.result()
                        await self._handle_event(ev)
                else:
                    # No running tasks; wait for the next control event
                    ev = await ctrl
                    await self._handle_event(ev)
            finally:
                if not ctrl.done():
                    ctrl.cancel()

    # ----- scheduling -----
    async def _schedule_global(self) -> bool:
        """
        Global scheduling:
          1) Start resumed waiters across all runs (respect per-run capacity and optional global cap)
          2) Start any explicitly pending nodes
          3) Compute new ready sets (round-robin across runs)

        Returns True if at least one node was started this pass.
        """
        scheduled = 0

        def global_capacity_left() -> int:
            # effectively unlimited when no global cap is configured
            if self._global_max_concurrency is None:
                return 10**9
            total_running = sum(len(rs.running_tasks) for rs in self._runs.values())
            return max(0, self._global_max_concurrency - total_running)

        # phase 1: resumed waiters first (global)
        for rs in self._runs.values():
            while rs.resume_pending and rs.capacity() > 0 and global_capacity_left() > 0:
                nid = rs.resume_pending.pop()
                node = rs.graph.node(nid)
                if node and node.state.status in WAITING_STATES and nid not in rs.running_tasks:
                    await self._start_node(rs, node)
                    scheduled += 1

        # phase 2: explicit pending (from run_one-style requests)
        for rs in self._runs.values():
            while rs.ready_pending and rs.capacity() > 0 and global_capacity_left() > 0:
                nid = rs.ready_pending.pop()
                node = rs.graph.node(nid)
                if (
                    node
                    and nid not in rs.running_tasks
                    and node.state.status not in TERMINAL_STATES
                    and self._deps_satisfied(rs, node)
                ):
                    await self._start_node(rs, node)
                    scheduled += 1

        # phase 3: normal ready nodes (round-robin for fairness)
        any_capacity = any(rs.capacity() > 0 for rs in self._runs.values())
        if any_capacity and global_capacity_left() > 0:
            # simple round-robin by iterating runs and taking up to run capacity
            for rs in self._runs.values():
                if rs.capacity() <= 0:
                    continue
                ready = list(self._compute_ready(rs))
                take = min(len(ready), rs.capacity(), global_capacity_left())
                for nid in ready[:take]:
                    await self._start_node(rs, rs.graph.node(nid))
                    scheduled += 1

        return scheduled > 0

    def _compute_ready(self, rs: RunState) -> set[str]:
        """Nodes of *rs* that are neither terminal, waiting, running, nor dep-blocked."""
        ready: set[str] = set()
        for node in rs.graph.nodes:
            node_id = node.node_id
            if node.spec.type == "plan":
                continue
            st = node.state.status
            if st in (NodeStatus.DONE, NodeStatus.FAILED, NodeStatus.SKIPPED, *WAITING_STATES):
                continue
            if node_id in rs.running_tasks:
                continue
            if self._deps_satisfied(rs, node):
                ready.add(node_id)
        return ready

    def _deps_satisfied(self, rs: RunState, node: TaskNodeRuntime) -> bool:
        """True when every dependency of *node* (other than graph inputs) is DONE."""
        for dep in node.spec.dependencies or []:
            if dep == GRAPH_INPUTS_NODE_ID:
                continue
            dn = rs.graph.node(dep)
            if dn is None or dn.state.status != NodeStatus.DONE:
                return False
        return True

    # ----- event handling -----
    async def _handle_event(
        self, ev: GlobalResumeEvent | GlobalWakeupEvent | ResumeEvent | WakeupEvent
    ):
        """Dispatch one control-plane event; legacy run_id-less events are dropped."""
        # Back-compat: if someone still enqueues a plain ResumeEvent without run_id, ignore (we're global now).
        if isinstance(ev, ResumeEvent):
            if self._logger:
                self._logger.warning(
                    "Ignored legacy ResumeEvent without run_id in GlobalForwardScheduler"
                )
            return
        if isinstance(ev, WakeupEvent):
            if self._logger:
                self._logger.warning(
                    "Ignored legacy WakeupEvent without run_id in GlobalForwardScheduler"
                )
            return

        if isinstance(ev, GlobalResumeEvent):
            rs = self._runs.get(ev.run_id)
            if not rs or rs.terminated:
                return
            rs.resume_payloads[ev.node_id] = ev.payload
            # cancel any backoff
            t = rs.backoff_tasks.pop(ev.node_id, None)
            if t:
                t.cancel()
            # try immediate start
            started = await self._try_start_immediately(rs, ev.node_id)
            if not started:
                rs.resume_pending.add(ev.node_id)
            return

        if isinstance(ev, GlobalWakeupEvent):
            rs = self._runs.get(ev.run_id)
            if not rs or rs.terminated:
                return
            await self._try_start_immediately(rs, ev.node_id)
            return

    async def _try_start_immediately(self, rs: RunState, node_id: str) -> bool:
        """Start a WAITING node right away if the run has spare capacity."""
        if rs.capacity() <= 0:
            return False
        node = rs.graph.node(node_id)
        if not node:
            return False
        if node.state.status not in WAITING_STATES:
            return False
        await self._start_node(rs, node)
        return True

    # ----- node execution -----
    async def _start_node(self, rs: RunState, node: TaskNodeRuntime):
        """Spawn an asyncio task that runs *node* one step and records the outcome."""
        from .step_forward import step_forward

        if rs.terminated:
            return
        node_id = node.node_id
        resume_payload = rs.resume_payloads.pop(node_id, None)

        if node.state.status in WAITING_STATES and resume_payload is None:
            # no payload yet; keep pending
            rs.resume_pending.add(node_id)
            return

        async def _runner():
            try:
                await rs.graph.set_node_status(node_id, NodeStatus.RUNNING)
                ctx = rs.env.make_ctx(node=node, resume_payload=resume_payload)
                result = await step_forward(
                    node=node, ctx=ctx, retry_policy=rs.settings.retry_policy
                )

                if result.status == NodeStatus.DONE:
                    outs = result.outputs or {}
                    await rs.graph.set_node_outputs(node_id, outs)
                    await rs.graph.set_node_status(node_id, NodeStatus.DONE)
                    rs.env.outputs_by_node[node.node_id] = outs
                    await self._emit(
                        NodeEvent(
                            run_id=rs.env.run_id,
                            graph_id=getattr(rs.graph.spec, "graph_id", "inline"),
                            node_id=node.node_id,
                            status=str(NodeStatus.DONE),
                            outputs=outs,
                            # naive now().timestamp() is the correct Unix epoch
                            # (utcnow().timestamp() was offset by the local UTC delta)
                            timestamp=datetime.now().timestamp(),
                        )
                    )
                # NOTE(review): statuses appear str-valued (startswith works);
                # WAITING_* nodes park until a resume/wakeup event arrives
                elif result.status.startswith("WAITING_"):
                    await rs.graph.set_node_status(node_id, result.status)
                    await self._emit(
                        NodeEvent(
                            run_id=rs.env.run_id,
                            graph_id=getattr(rs.graph.spec, "graph_id", "inline"),
                            node_id=node.node_id,
                            status=result.status,
                            outputs=node.outputs or {},
                            timestamp=datetime.now().timestamp(),
                        )
                    )
                elif result.status == NodeStatus.FAILED:
                    await rs.graph.set_node_status(node_id, NodeStatus.FAILED)
                    await self._emit(
                        NodeEvent(
                            run_id=rs.env.run_id,
                            graph_id=getattr(rs.graph.spec, "graph_id", "inline"),
                            node_id=node.node_id,
                            status=str(NodeStatus.FAILED),
                            outputs=node.outputs or {},
                            timestamp=datetime.now().timestamp(),
                        )
                    )
                    attempts = getattr(node, "attempts", 0)
                    if 0 < attempts < rs.settings.retry_policy.max_attempts:
                        # schedule a retry after the policy's backoff delay
                        delay = rs.settings.retry_policy.backoff(attempts - 1).total_seconds()
                        rs.backoff_tasks[node.node_id] = asyncio.create_task(
                            self._sleep_and_requeue(rs, node, delay)
                        )
                    else:
                        if rs.settings.skip_dependents_on_failure:
                            await self._skip_dependents(rs, node_id)
                        if rs.settings.stop_on_first_error:
                            rs.terminated = True
                elif result.status == NodeStatus.SKIPPED:
                    await rs.graph.set_node_status(node_id, NodeStatus.SKIPPED)
                    await self._emit(
                        NodeEvent(
                            run_id=rs.env.run_id,
                            graph_id=getattr(rs.graph.spec, "graph_id", "inline"),
                            node_id=node.node_id,
                            status=str(NodeStatus.SKIPPED),
                            outputs=node.outputs or {},
                            timestamp=datetime.now().timestamp(),
                        )
                    )
            except asyncio.CancelledError:
                # best-effort: record the cancellation as FAILED
                try:
                    await rs.graph.set_node_status(node_id, NodeStatus.FAILED)
                except Exception as e:
                    if self._logger:
                        self._logger.warning(
                            f"[GlobalForwardScheduler._start_node] failed to set node {node_id} as FAILED on cancellation: {e}"
                        )

        task = asyncio.create_task(_runner())
        rs.running_tasks[node_id] = task
        task.add_done_callback(lambda t, nid=node_id, r=rs: r.running_tasks.pop(nid, None))

    async def _sleep_and_requeue(self, rs: RunState, node: TaskNodeRuntime, delay: float):
        """Backoff sleeper: wait *delay* seconds then restart the node (unless terminated)."""
        try:
            await asyncio.sleep(delay)
            if not rs.terminated:
                await self._start_node(rs, node)
        except asyncio.CancelledError:
            pass
        finally:
            rs.backoff_tasks.pop(node.node_id, None)

    async def _skip_dependents(self, rs: RunState, failed_node_id: str):
        """BFS over the graph marking all transitive dependents of a failed node SKIPPED."""
        q = [failed_node_id]
        seen = set()
        while q:
            cur = q.pop(0)
            for n in rs.graph.nodes:
                if cur in (n.spec.dependencies or []):
                    if n.node_id in seen:
                        continue
                    seen.add(n.node_id)
                    node = rs.graph.node(n.node_id)
                    if (
                        node.state.status not in TERMINAL_STATES
                        and n.node_id not in rs.running_tasks
                    ):
                        await rs.graph.set_node_status(n.node_id, NodeStatus.SKIPPED)
                        q.append(n.node_id)

    async def _emit(self, event: NodeEvent):
        """Fan a NodeEvent out to all listeners; listener errors are logged, not raised."""
        for cb in self._listeners:
            try:
                await cb(event)
            except Exception as e:
                if self._logger:
                    self._logger.warning(f"[GlobalForwardScheduler._emit] listener error: {e}")
                else:
                    print(f"[GlobalForwardScheduler._emit] listener error: {e}")

    def _get_run(self, run_id: str) -> RunState:
        """Look up a run; raises KeyError for unknown run_ids."""
        rs = self._runs.get(run_id)
        if not rs:
            raise KeyError(f"Unknown run_id: {run_id}")
        return rs

    async def enqueue_ready(self, run_id: str, node_id: str) -> None:
        """Mark a node in this run as explicitly pending (like old run_one)."""
        rs = self._get_run(run_id)
        rs.ready_pending.add(node_id)
        # nudge the loop (a no-op wakeup is fine)
        await self._events.put(GlobalWakeupEvent(run_id=run_id, node_id=node_id))

    async def wait_for_node_terminal(self, run_id: str, node_id: str) -> dict[str, Any]:
        """Resolve when node reaches a terminal status; return its outputs (may be {})."""
        loop = asyncio.get_running_loop()
        fut: asyncio.Future = loop.create_future()

        statuses_to_wait_for = (
            str(NodeStatus.DONE),
            str(NodeStatus.FAILED),
            str(NodeStatus.SKIPPED),
        )

        async def _once(ev):
            if ev.run_id == run_id and ev.node_id == node_id and ev.status in statuses_to_wait_for:
                if not fut.done():
                    fut.set_result(ev.outputs or {})

        # one-shot listener
        self.add_listener(_once)
        try:
            return await fut
        finally:
            # best-effort: remove listener by rebuilding list (small scale)
            self._listeners = [cb for cb in self._listeners if cb is not _once]

    def post_resume_event_threadsafe(self, run_id: str, node_id: str, payload: dict) -> None:
        """Thread-safe resume injection from outside the event loop.

        Raises:
            RuntimeError: if the scheduler's loop has not been captured yet.
        """
        if self.loop is None:
            raise RuntimeError("GlobalForwardScheduler.loop is not set yet")
        self.loop.call_soon_threadsafe(
            self._events.put_nowait,
            GlobalResumeEvent(run_id=run_id, node_id=node_id, payload=payload),
        )

    def get_status(self) -> dict:
        """Snapshot of per-run counters plus a global idle flag."""
        runs = {}
        for run_id, rs in self._runs.items():
            waiting = sum(1 for n in rs.graph.nodes if n.state.status in WAITING_STATES)
            runs[run_id] = {
                "running": len(rs.running_tasks),
                "resume_pending": len(rs.resume_pending),
                "backoff_sleepers": len(rs.backoff_tasks),
                "waiting": waiting,
                "terminated": rs.terminated,
                "capacity": rs.capacity(),
            }
        total_running = sum(r["running"] for r in runs.values())
        idle = (total_running == 0) and any(r["waiting"] > 0 for r in runs.values())
        return {"idle": idle, "runs": runs}

    def add_run_listener(self, cb):
        """Register an async callback for RunEvents (run-level terminal notifications)."""
        if not inspect.iscoroutinefunction(cb):
            raise ValueError("Listener must be async")
        self._run_listeners.append(cb)

    async def _emit_run(self, ev: RunEvent):
        """Fan a RunEvent out to run listeners; errors are logged, not raised."""
        for cb in list(self._run_listeners):
            try:
                await cb(ev)
            except Exception as e:
                if self._logger:
                    self._logger.warning(f"run listener error: {e}")
@@ -0,0 +1,22 @@
1
+ from dataclasses import dataclass
2
+ from datetime import timedelta
3
+
4
+
5
@dataclass
class RetryPolicy:
    """Exponential-backoff retry policy for failed nodes.

    With the default ``max_attempts=0`` retrying is disabled entirely.
    """

    max_attempts: int = 0  # total attempts allowed; 0 disables retrying
    backoff_base: float = 2.0  # exponential growth factor per attempt
    backoff_first: float = 1.0  # seconds
    retry_on: tuple[type[BaseException], ...] = (Exception,)  # error types eligible for retry

    def should_retry(self, attempt: int, error: BaseException) -> bool:
        """Determine if we should retry based on attempt count and error type."""
        if attempt >= self.max_attempts:
            return False
        return isinstance(error, self.retry_on)

    def backoff(self, attempt: int) -> timedelta:
        """Calculate the backoff delay for the given attempt.

        Returns a ``timedelta`` (callers use ``.total_seconds()``); the previous
        ``-> float`` annotation did not match the actual return type.
        """
        # attempt = 0 -> first failure
        delay = self.backoff_first * (self.backoff_base**attempt)
        return timedelta(seconds=delay)