aethergraph 0.1.0a1__py3-none-any.whl
This diff shows the content of a publicly available package version as released to a supported registry. It is provided for informational purposes only and reflects the package as it appears in its public registry.
- aethergraph/__init__.py +49 -0
- aethergraph/config/__init__.py +0 -0
- aethergraph/config/config.py +121 -0
- aethergraph/config/context.py +16 -0
- aethergraph/config/llm.py +26 -0
- aethergraph/config/loader.py +60 -0
- aethergraph/config/runtime.py +9 -0
- aethergraph/contracts/errors/errors.py +44 -0
- aethergraph/contracts/services/artifacts.py +142 -0
- aethergraph/contracts/services/channel.py +72 -0
- aethergraph/contracts/services/continuations.py +23 -0
- aethergraph/contracts/services/eventbus.py +12 -0
- aethergraph/contracts/services/kv.py +24 -0
- aethergraph/contracts/services/llm.py +17 -0
- aethergraph/contracts/services/mcp.py +22 -0
- aethergraph/contracts/services/memory.py +108 -0
- aethergraph/contracts/services/resume.py +28 -0
- aethergraph/contracts/services/state_stores.py +33 -0
- aethergraph/contracts/services/wakeup.py +28 -0
- aethergraph/core/execution/base_scheduler.py +77 -0
- aethergraph/core/execution/forward_scheduler.py +777 -0
- aethergraph/core/execution/global_scheduler.py +634 -0
- aethergraph/core/execution/retry_policy.py +22 -0
- aethergraph/core/execution/step_forward.py +411 -0
- aethergraph/core/execution/step_result.py +18 -0
- aethergraph/core/execution/wait_types.py +72 -0
- aethergraph/core/graph/graph_builder.py +192 -0
- aethergraph/core/graph/graph_fn.py +219 -0
- aethergraph/core/graph/graph_io.py +67 -0
- aethergraph/core/graph/graph_refs.py +154 -0
- aethergraph/core/graph/graph_spec.py +115 -0
- aethergraph/core/graph/graph_state.py +59 -0
- aethergraph/core/graph/graphify.py +128 -0
- aethergraph/core/graph/interpreter.py +145 -0
- aethergraph/core/graph/node_handle.py +33 -0
- aethergraph/core/graph/node_spec.py +46 -0
- aethergraph/core/graph/node_state.py +63 -0
- aethergraph/core/graph/task_graph.py +747 -0
- aethergraph/core/graph/task_node.py +82 -0
- aethergraph/core/graph/utils.py +37 -0
- aethergraph/core/graph/visualize.py +239 -0
- aethergraph/core/runtime/ad_hoc_context.py +61 -0
- aethergraph/core/runtime/base_service.py +153 -0
- aethergraph/core/runtime/bind_adapter.py +42 -0
- aethergraph/core/runtime/bound_memory.py +69 -0
- aethergraph/core/runtime/execution_context.py +220 -0
- aethergraph/core/runtime/graph_runner.py +349 -0
- aethergraph/core/runtime/lifecycle.py +26 -0
- aethergraph/core/runtime/node_context.py +203 -0
- aethergraph/core/runtime/node_services.py +30 -0
- aethergraph/core/runtime/recovery.py +159 -0
- aethergraph/core/runtime/run_registration.py +33 -0
- aethergraph/core/runtime/runtime_env.py +157 -0
- aethergraph/core/runtime/runtime_registry.py +32 -0
- aethergraph/core/runtime/runtime_services.py +224 -0
- aethergraph/core/runtime/wakeup_watcher.py +40 -0
- aethergraph/core/tools/__init__.py +10 -0
- aethergraph/core/tools/builtins/channel_tools.py +194 -0
- aethergraph/core/tools/builtins/toolset.py +134 -0
- aethergraph/core/tools/toolkit.py +510 -0
- aethergraph/core/tools/waitable.py +109 -0
- aethergraph/plugins/channel/__init__.py +0 -0
- aethergraph/plugins/channel/adapters/__init__.py +0 -0
- aethergraph/plugins/channel/adapters/console.py +106 -0
- aethergraph/plugins/channel/adapters/file.py +102 -0
- aethergraph/plugins/channel/adapters/slack.py +285 -0
- aethergraph/plugins/channel/adapters/telegram.py +302 -0
- aethergraph/plugins/channel/adapters/webhook.py +104 -0
- aethergraph/plugins/channel/adapters/webui.py +134 -0
- aethergraph/plugins/channel/routes/__init__.py +0 -0
- aethergraph/plugins/channel/routes/console_routes.py +86 -0
- aethergraph/plugins/channel/routes/slack_routes.py +49 -0
- aethergraph/plugins/channel/routes/telegram_routes.py +26 -0
- aethergraph/plugins/channel/routes/webui_routes.py +136 -0
- aethergraph/plugins/channel/utils/__init__.py +0 -0
- aethergraph/plugins/channel/utils/slack_utils.py +278 -0
- aethergraph/plugins/channel/utils/telegram_utils.py +324 -0
- aethergraph/plugins/channel/websockets/slack_ws.py +68 -0
- aethergraph/plugins/channel/websockets/telegram_polling.py +151 -0
- aethergraph/plugins/mcp/fs_server.py +128 -0
- aethergraph/plugins/mcp/http_server.py +101 -0
- aethergraph/plugins/mcp/ws_server.py +180 -0
- aethergraph/plugins/net/http.py +10 -0
- aethergraph/plugins/utils/data_io.py +359 -0
- aethergraph/runner/__init__.py +5 -0
- aethergraph/runtime/__init__.py +62 -0
- aethergraph/server/__init__.py +3 -0
- aethergraph/server/app_factory.py +84 -0
- aethergraph/server/start.py +122 -0
- aethergraph/services/__init__.py +10 -0
- aethergraph/services/artifacts/facade.py +284 -0
- aethergraph/services/artifacts/factory.py +35 -0
- aethergraph/services/artifacts/fs_store.py +656 -0
- aethergraph/services/artifacts/jsonl_index.py +123 -0
- aethergraph/services/artifacts/paths.py +23 -0
- aethergraph/services/artifacts/sqlite_index.py +209 -0
- aethergraph/services/artifacts/utils.py +124 -0
- aethergraph/services/auth/dev.py +16 -0
- aethergraph/services/channel/channel_bus.py +293 -0
- aethergraph/services/channel/factory.py +44 -0
- aethergraph/services/channel/session.py +511 -0
- aethergraph/services/channel/wait_helpers.py +57 -0
- aethergraph/services/clock/clock.py +9 -0
- aethergraph/services/container/default_container.py +320 -0
- aethergraph/services/continuations/continuation.py +56 -0
- aethergraph/services/continuations/factory.py +34 -0
- aethergraph/services/continuations/stores/fs_store.py +264 -0
- aethergraph/services/continuations/stores/inmem_store.py +95 -0
- aethergraph/services/eventbus/inmem.py +21 -0
- aethergraph/services/features/static.py +10 -0
- aethergraph/services/kv/ephemeral.py +90 -0
- aethergraph/services/kv/factory.py +27 -0
- aethergraph/services/kv/layered.py +41 -0
- aethergraph/services/kv/sqlite_kv.py +128 -0
- aethergraph/services/llm/factory.py +157 -0
- aethergraph/services/llm/generic_client.py +542 -0
- aethergraph/services/llm/providers.py +3 -0
- aethergraph/services/llm/service.py +105 -0
- aethergraph/services/logger/base.py +36 -0
- aethergraph/services/logger/compat.py +50 -0
- aethergraph/services/logger/formatters.py +106 -0
- aethergraph/services/logger/std.py +203 -0
- aethergraph/services/mcp/helpers.py +23 -0
- aethergraph/services/mcp/http_client.py +70 -0
- aethergraph/services/mcp/mcp_tools.py +21 -0
- aethergraph/services/mcp/registry.py +14 -0
- aethergraph/services/mcp/service.py +100 -0
- aethergraph/services/mcp/stdio_client.py +70 -0
- aethergraph/services/mcp/ws_client.py +115 -0
- aethergraph/services/memory/bound.py +106 -0
- aethergraph/services/memory/distillers/episode.py +116 -0
- aethergraph/services/memory/distillers/rolling.py +74 -0
- aethergraph/services/memory/facade.py +633 -0
- aethergraph/services/memory/factory.py +78 -0
- aethergraph/services/memory/hotlog_kv.py +27 -0
- aethergraph/services/memory/indices.py +74 -0
- aethergraph/services/memory/io_helpers.py +72 -0
- aethergraph/services/memory/persist_fs.py +40 -0
- aethergraph/services/memory/resolver.py +152 -0
- aethergraph/services/metering/noop.py +4 -0
- aethergraph/services/prompts/file_store.py +41 -0
- aethergraph/services/rag/chunker.py +29 -0
- aethergraph/services/rag/facade.py +593 -0
- aethergraph/services/rag/index/base.py +27 -0
- aethergraph/services/rag/index/faiss_index.py +121 -0
- aethergraph/services/rag/index/sqlite_index.py +134 -0
- aethergraph/services/rag/index_factory.py +52 -0
- aethergraph/services/rag/parsers/md.py +7 -0
- aethergraph/services/rag/parsers/pdf.py +14 -0
- aethergraph/services/rag/parsers/txt.py +7 -0
- aethergraph/services/rag/utils/hybrid.py +39 -0
- aethergraph/services/rag/utils/make_fs_key.py +62 -0
- aethergraph/services/redactor/simple.py +16 -0
- aethergraph/services/registry/key_parsing.py +44 -0
- aethergraph/services/registry/registry_key.py +19 -0
- aethergraph/services/registry/unified_registry.py +185 -0
- aethergraph/services/resume/multi_scheduler_resume_bus.py +65 -0
- aethergraph/services/resume/router.py +73 -0
- aethergraph/services/schedulers/registry.py +41 -0
- aethergraph/services/secrets/base.py +7 -0
- aethergraph/services/secrets/env.py +8 -0
- aethergraph/services/state_stores/externalize.py +135 -0
- aethergraph/services/state_stores/graph_observer.py +131 -0
- aethergraph/services/state_stores/json_store.py +67 -0
- aethergraph/services/state_stores/resume_policy.py +119 -0
- aethergraph/services/state_stores/serialize.py +249 -0
- aethergraph/services/state_stores/utils.py +91 -0
- aethergraph/services/state_stores/validate.py +78 -0
- aethergraph/services/tracing/noop.py +18 -0
- aethergraph/services/waits/wait_registry.py +91 -0
- aethergraph/services/wakeup/memory_queue.py +57 -0
- aethergraph/services/wakeup/scanner_producer.py +56 -0
- aethergraph/services/wakeup/worker.py +31 -0
- aethergraph/tools/__init__.py +25 -0
- aethergraph/utils/optdeps.py +8 -0
- aethergraph-0.1.0a1.dist-info/METADATA +410 -0
- aethergraph-0.1.0a1.dist-info/RECORD +182 -0
- aethergraph-0.1.0a1.dist-info/WHEEL +5 -0
- aethergraph-0.1.0a1.dist-info/entry_points.txt +2 -0
- aethergraph-0.1.0a1.dist-info/licenses/LICENSE +176 -0
- aethergraph-0.1.0a1.dist-info/licenses/NOTICE +31 -0
- aethergraph-0.1.0a1.dist-info/top_level.txt +1 -0
aethergraph/core/execution/global_scheduler.py
@@ -0,0 +1,634 @@
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime
import inspect
from typing import TYPE_CHECKING, Any

from aethergraph.contracts.services.resume import (
    ResumeEvent,  # we’ll extend usage to include run_id
)
from aethergraph.contracts.services.wakeup import WakeupEvent

from ..graph.graph_refs import GRAPH_INPUTS_NODE_ID
from ..graph.node_spec import NodeEvent
from ..graph.node_state import TERMINAL_STATES, WAITING_STATES, NodeStatus
from ..graph.task_node import TaskNodeRuntime
from .retry_policy import RetryPolicy

if TYPE_CHECKING:
    from aethergraph.services.schedulers.registry import SchedulerRegistry

    from ..graph.task_graph import TaskGraph
    from ..runtime.runtime_env import RuntimeEnv


# --------- Global control events tagged with run_id ---------
@dataclass
class GlobalResumeEvent:
    run_id: str
    node_id: str
    payload: dict[str, Any]


@dataclass
class GlobalWakeupEvent:
    run_id: str
    node_id: str


# --------- Per-run state ---------
@dataclass
class RunSettings:
    max_concurrency: int = 4
    retry_policy: RetryPolicy = field(default_factory=RetryPolicy)
    stop_on_first_error: bool = False
    skip_dependents_on_failure: bool = True


@dataclass
class RunState:
    run_id: str
    graph: TaskGraph
    env: RuntimeEnv
    settings: RunSettings

    # bookkeeping
    running_tasks: dict[str, asyncio.Task] = field(default_factory=dict)  # node_id -> task
    resume_payloads: dict[str, dict[str, Any]] = field(default_factory=dict)  # node_id -> payload
    resume_pending: set[str] = field(default_factory=set)  # node_ids awaiting capacity
    ready_pending: set[str] = field(default_factory=set)  # nodes explicitly enqueued
    backoff_tasks: dict[str, asyncio.Task] = field(default_factory=dict)  # node_id -> sleeper task
    terminated: bool = False

    def capacity(self) -> int:
        return max(0, self.settings.max_concurrency - len(self.running_tasks))

    def any_waiting(self) -> bool:
        return any(
            (n.spec.type != "plan") and (n.state.status in WAITING_STATES) for n in self.graph.nodes
        )

    def all_terminal(self) -> bool:
        for n in self.graph.nodes:
            if n.spec.type == "plan":
                continue
            if n.state.status not in TERMINAL_STATES:
                return False
        return True


@dataclass
class RunEvent:
    run_id: str
    status: str  # "SUCCESS" | "FAILED" | "CANCELLED"
    timestamp: float


# --------- Global Forward Scheduler ---------
class GlobalForwardScheduler:
    """
    A global event-driven DAG scheduler that coordinates execution across many graphs (runs)
    in a single asyncio event loop.

    • One global control plane (queue) carrying (run_id, node_id, payload) events
    • Resumed nodes (WAITING_* -> RUNNING) are prioritized globally
    • Each run has its own capacity; also a global cap can be applied if desired
    """

    def __init__(
        self,
        *,
        registry: SchedulerRegistry,
        global_max_concurrency: int | None = None,
        logger: Any | None = None,
    ):
        self._runs: dict[str, RunState] = {}
        self._listeners: list[Callable[[NodeEvent], Awaitable[None]]] = []
        self._events: asyncio.Queue = asyncio.Queue()
        self._pause_event = asyncio.Event()
        self._pause_event.set()
        self._terminated = False

        # Optional global cap across all runs
        self._global_max_concurrency = global_max_concurrency  # None => unlimited
        self._logger = logger

        # registry for MultiSchedulerResumeBus routing
        self._registry = registry

        # convenience: track our loop (used by ResumeBus cross-thread dispatch)
        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            self.loop = None

        self._run_listeners: list[Callable[[RunEvent], Awaitable[None]]] = []

    # ----- public hooks -----
    def add_listener(self, cb: Callable[[NodeEvent], Awaitable[None]]):
        if not inspect.iscoroutinefunction(cb):
            raise ValueError("Listener must be an async function")
        self._listeners.append(cb)

    async def submit(
        self, *, run_id: str, graph: TaskGraph, env: RuntimeEnv, settings: RunSettings | None = None
    ):
        """Register a new run (graph+env) with optional per-run settings."""
        if run_id in self._runs:
            raise ValueError(f"run_id already submitted: {run_id}")
        rs = RunState(run_id=run_id, graph=graph, env=env, settings=settings or RunSettings())
        self._runs[run_id] = rs
        self._registry.register(run_id, self)  # so MultiSchedulerResumeBus can find us

    async def run_until_all_done(self):
        """Drive the global loop until all runs are terminal."""
        await self._drive_loop(block_until="all_done")

    async def run_until_complete(self, run_id: str):
        """Drive the global loop until the specified run is terminal."""
        await self._drive_loop(block_until=("run_done", run_id))

    async def run_forever(self):
        """Service-mode: keep running until shutdown() is called."""
        await self._drive_loop(block_until="forever")

    async def shutdown(self):
        if self._terminated:
            return
        self._terminated = True

        # mark runs as terminated & cancel sleepers/runners
        for rs in self._runs.values():
            rs.terminated = True
            for t in list(rs.backoff_tasks.values()):
                t.cancel()
            for t in list(rs.running_tasks.values()):
                t.cancel()

        # wake the driver if it's blocked on events.get()
        try:
            await self._events.put(GlobalWakeupEvent(run_id="__shutdown__", node_id="__shutdown__"))
        except RuntimeError as e:
            # queue may be closing; best-effort
            if self._logger:
                self._logger.warning(f"[GlobalForwardScheduler.shutdown] failed to wake up: {e}")

        # also ensure the pause gate isn’t closed
        if hasattr(self, "_pause_event"):
            self._pause_event.set()

    async def terminate_run(self, run_id: str):
        rs = self._runs.get(run_id)
        if not rs:
            return
        rs.terminated = True
        for t in list(rs.backoff_tasks.values()):
            t.cancel()
        for t in list(rs.running_tasks.values()):
            t.cancel()

    # external resume/wakeup API (called by ResumeBus)
    async def on_resume_event(self, run_id: str, node_id: str, payload: dict[str, Any]):
        await self._events.put(GlobalResumeEvent(run_id=run_id, node_id=node_id, payload=payload))

    async def on_wakeup_event(self, run_id: str, node_id: str):
        await self._events.put(GlobalWakeupEvent(run_id=run_id, node_id=node_id))

    # ----- main loop -----
    async def _drive_loop(self, *, block_until: str | tuple[str, str]):
        if self.loop is None:
            self.loop = asyncio.get_running_loop()

        MAX_DRAIN = 200
        while not self._terminated:
            await self._pause_event.wait()

            # 1) Drain control events (non-blocking)
            drained = 0
            while drained < MAX_DRAIN:
                try:
                    ev = self._events.get_nowait()
                except asyncio.QueueEmpty:
                    break
                drained += 1
                await self._handle_event(ev)

            # 2) Attempt to schedule work
            scheduled_any = await self._schedule_global()

            # 3) Check termination conditions
            if block_until == "all_done":
                if all(
                    rs.all_terminal()
                    and not rs.running_tasks
                    and not rs.backoff_tasks
                    and not rs.resume_pending
                    for rs in self._runs.values()
                ):
                    break

            elif isinstance(block_until, tuple) and block_until[0] == "run_done":
                tgt = self._runs.get(block_until[1])
                if (
                    tgt
                    and tgt.all_terminal()
                    and not tgt.running_tasks
                    and not tgt.backoff_tasks
                    and not tgt.resume_pending
                ):
                    # compute a simple status
                    status = "SUCCESS"
                    for n in tgt.graph.nodes:
                        if n.spec.type == "plan":
                            continue
                        if n.state.status == NodeStatus.FAILED:
                            status = "FAILED"
                            break
                    await self._emit_run(
                        RunEvent(
                            run_id=tgt.run_id,
                            status=status,
                            timestamp=datetime.utcnow().timestamp(),
                        )
                    )
                    break

            # 4) If nothing is running anywhere and nothing scheduled, decide how to wait
            any_running = any(rs.running_tasks for rs in self._runs.values())
            if not any_running and not scheduled_any:
                # if any run has waiting nodes, block for a global resume/wakeup
                if any(rs.any_waiting() for rs in self._runs.values()):
                    ev = await self._events.get()
                    await self._handle_event(ev)
                    continue

                # if all runs are terminal, and we’re not in 'forever' mode, the outer loop will exit next tick
                if block_until == "forever":
                    # idle until next event
                    ev = await self._events.get()
                    await self._handle_event(ev)
                    continue

            # 5) Wait for either any task to finish OR a control event
            running_tasks = [t for rs in self._runs.values() for t in rs.running_tasks.values()]
            ctrl = asyncio.create_task(self._events.get())
            try:
                if running_tasks:
                    done, _ = await asyncio.wait(
                        running_tasks + [ctrl], return_when=asyncio.FIRST_COMPLETED
                    )
                    if ctrl in done:
                        ev = ctrl.result()
                        await self._handle_event(ev)
                else:
                    # No running tasks; wait for the next control event
                    ev = await ctrl
                    await self._handle_event(ev)
            finally:
                if not ctrl.done():
                    ctrl.cancel()

    # ----- scheduling -----
    async def _schedule_global(self) -> bool:
        """
        Global scheduling:
          1) Start resumed waiters across all runs (respect per-run capacity and optional global cap)
          2) Start any explicitly pending nodes
          3) Compute new ready sets (round-robin across runs)
        """
        scheduled = 0

        def global_capacity_left() -> int:
            if self._global_max_concurrency is None:
                return 10**9
            total_running = sum(len(rs.running_tasks) for rs in self._runs.values())
            return max(0, self._global_max_concurrency - total_running)

        # phase 1: resumed waiters first (global)
        for rs in self._runs.values():
            while rs.resume_pending and rs.capacity() > 0 and global_capacity_left() > 0:
                nid = rs.resume_pending.pop()
                node = rs.graph.node(nid)
                if node and node.state.status in WAITING_STATES and nid not in rs.running_tasks:
                    await self._start_node(rs, node)
                    scheduled += 1

        # phase 2: explicit pending (from run_one-style requests)
        for rs in self._runs.values():
            while rs.ready_pending and rs.capacity() > 0 and global_capacity_left() > 0:
                nid = rs.ready_pending.pop()
                node = rs.graph.node(nid)
                if (
                    node
                    and nid not in rs.running_tasks
                    and node.state.status not in TERMINAL_STATES
                    and self._deps_satisfied(rs, node)
                ):
                    await self._start_node(rs, node)
                    scheduled += 1

        # phase 3: normal ready nodes (round-robin for fairness)
        any_capacity = any(rs.capacity() > 0 for rs in self._runs.values())
        if any_capacity and global_capacity_left() > 0:
            # simple round-robin by iterating runs and taking up to run capacity
            for rs in self._runs.values():
                if rs.capacity() <= 0:
                    continue
                ready = list(self._compute_ready(rs))
                take = min(len(ready), rs.capacity(), global_capacity_left())
                for nid in ready[:take]:
                    await self._start_node(rs, rs.graph.node(nid))
                    scheduled += 1

        return scheduled > 0

    def _compute_ready(self, rs: RunState) -> set[str]:
        ready: set[str] = set()
        for node in rs.graph.nodes:
            node_id = node.node_id
            if node.spec.type == "plan":
                continue
            st = node.state.status
            if st in (NodeStatus.DONE, NodeStatus.FAILED, NodeStatus.SKIPPED, *WAITING_STATES):
                continue
            if node_id in rs.running_tasks:
                continue
            if self._deps_satisfied(rs, node):
                ready.add(node_id)
        return ready

    def _deps_satisfied(self, rs: RunState, node: TaskNodeRuntime) -> bool:
        for dep in node.spec.dependencies or []:
            if dep == GRAPH_INPUTS_NODE_ID:
                continue
            dn = rs.graph.node(dep)
            if dn is None or dn.state.status != NodeStatus.DONE:
                return False
        return True

    # ----- event handling -----
    async def _handle_event(
        self, ev: GlobalResumeEvent | GlobalWakeupEvent | ResumeEvent | WakeupEvent
    ):
        # Back-compat: if someone still enqueues a plain ResumeEvent without run_id, ignore (we’re global now).
        if isinstance(ev, ResumeEvent):
            if self._logger:
                self._logger.warning(
                    "Ignored legacy ResumeEvent without run_id in GlobalForwardScheduler"
                )
            return
        if isinstance(ev, WakeupEvent):
            if self._logger:
                self._logger.warning(
                    "Ignored legacy WakeupEvent without run_id in GlobalForwardScheduler"
                )
            return

        if isinstance(ev, GlobalResumeEvent):
            rs = self._runs.get(ev.run_id)
            if not rs or rs.terminated:
                return
            rs.resume_payloads[ev.node_id] = ev.payload
            # cancel any backoff
            t = rs.backoff_tasks.pop(ev.node_id, None)
            if t:
                t.cancel()
            # try immediate start
            started = await self._try_start_immediately(rs, ev.node_id)
            if not started:
                rs.resume_pending.add(ev.node_id)
            return

        if isinstance(ev, GlobalWakeupEvent):
            rs = self._runs.get(ev.run_id)
            if not rs or rs.terminated:
                return
            await self._try_start_immediately(rs, ev.node_id)
            return

    async def _try_start_immediately(self, rs: RunState, node_id: str) -> bool:
        if rs.capacity() <= 0:
            return False
        node = rs.graph.node(node_id)
        if not node:
            return False
        if node.state.status not in WAITING_STATES:
            return False
        await self._start_node(rs, node)
        return True

    # ----- node execution -----
    async def _start_node(self, rs: RunState, node: TaskNodeRuntime):
        from .step_forward import step_forward

        if rs.terminated:
            return
        node_id = node.node_id
        resume_payload = rs.resume_payloads.pop(node_id, None)

        if node.state.status in WAITING_STATES and resume_payload is None:
            # no payload yet; keep pending
            rs.resume_pending.add(node_id)
            return

        async def _runner():
            try:
                await rs.graph.set_node_status(node_id, NodeStatus.RUNNING)
                ctx = rs.env.make_ctx(node=node, resume_payload=resume_payload)
                result = await step_forward(
                    node=node, ctx=ctx, retry_policy=rs.settings.retry_policy
                )

                if result.status == NodeStatus.DONE:
                    outs = result.outputs or {}
                    await rs.graph.set_node_outputs(node_id, outs)
                    await rs.graph.set_node_status(node_id, NodeStatus.DONE)
                    rs.env.outputs_by_node[node.node_id] = outs
                    await self._emit(
                        NodeEvent(
                            run_id=rs.env.run_id,
                            graph_id=getattr(rs.graph.spec, "graph_id", "inline"),
                            node_id=node.node_id,
                            status=str(NodeStatus.DONE),
                            outputs=outs,
                            timestamp=datetime.utcnow().timestamp(),
                        )
                    )
                elif result.status.startswith("WAITING_"):
                    await rs.graph.set_node_status(node_id, result.status)
                    await self._emit(
                        NodeEvent(
                            run_id=rs.env.run_id,
                            graph_id=getattr(rs.graph.spec, "graph_id", "inline"),
                            node_id=node.node_id,
                            status=result.status,
                            outputs=node.outputs or {},
                            timestamp=datetime.utcnow().timestamp(),
                        )
                    )
                elif result.status == NodeStatus.FAILED:
                    await rs.graph.set_node_status(node_id, NodeStatus.FAILED)
                    await self._emit(
                        NodeEvent(
                            run_id=rs.env.run_id,
                            graph_id=getattr(rs.graph.spec, "graph_id", "inline"),
                            node_id=node.node_id,
                            status=str(NodeStatus.FAILED),
                            outputs=node.outputs or {},
                            timestamp=datetime.utcnow().timestamp(),
                        )
                    )
                    attempts = getattr(node, "attempts", 0)
                    if attempts > 0 and attempts < rs.settings.retry_policy.max_attempts:
                        delay = rs.settings.retry_policy.backoff(attempts - 1).total_seconds()
                        rs.backoff_tasks[node.node_id] = asyncio.create_task(
                            self._sleep_and_requeue(rs, node, delay)
                        )
                    else:
                        if rs.settings.skip_dependents_on_failure:
                            await self._skip_dependents(rs, node_id)
                        if rs.settings.stop_on_first_error:
                            rs.terminated = True
                elif result.status == NodeStatus.SKIPPED:
                    await rs.graph.set_node_status(node_id, NodeStatus.SKIPPED)
                    await self._emit(
                        NodeEvent(
                            run_id=rs.env.run_id,
                            graph_id=getattr(rs.graph.spec, "graph_id", "inline"),
                            node_id=node.node_id,
                            status=str(NodeStatus.SKIPPED),
                            outputs=node.outputs or {},
                            timestamp=datetime.utcnow().timestamp(),
                        )
                    )
            except asyncio.CancelledError:
                try:
                    await rs.graph.set_node_status(node_id, NodeStatus.FAILED)
                except Exception as e:
                    if self._logger:
                        self._logger.warning(
                            f"[GlobalForwardScheduler._start_node] failed to set node {node_id} as FAILED on cancellation: {e}"
                        )
            finally:
                pass

        task = asyncio.create_task(_runner())
        rs.running_tasks[node_id] = task
        task.add_done_callback(lambda t, nid=node_id, r=rs: r.running_tasks.pop(nid, None))

    async def _sleep_and_requeue(self, rs: RunState, node: TaskNodeRuntime, delay: float):
        try:
            await asyncio.sleep(delay)
            if not rs.terminated:
                await self._start_node(rs, node)
        except asyncio.CancelledError:
            pass
        finally:
            rs.backoff_tasks.pop(node.node_id, None)

    async def _skip_dependents(self, rs: RunState, failed_node_id: str):
        q = [failed_node_id]
        seen = set()
        while q:
            cur = q.pop(0)
            for n in rs.graph.nodes:
                if cur in (n.spec.dependencies or []):
                    if n.node_id in seen:
                        continue
                    seen.add(n.node_id)
                    node = rs.graph.node(n.node_id)
                    if (
                        node.state.status not in TERMINAL_STATES
                        and n.node_id not in rs.running_tasks
                    ):
                        await rs.graph.set_node_status(n.node_id, NodeStatus.SKIPPED)
                        q.append(n.node_id)

    async def _emit(self, event: NodeEvent):
        for cb in self._listeners:
            try:
                await cb(event)
            except Exception as e:
                if self._logger:
                    self._logger.warning(f"[GlobalForwardScheduler._emit] listener error: {e}")
                else:
                    print(f"[GlobalForwardScheduler._emit] listener error: {e}")

    def _get_run(self, run_id: str) -> RunState:
        rs = self._runs.get(run_id)
        if not rs:
            raise KeyError(f"Unknown run_id: {run_id}")
        return rs

    async def enqueue_ready(self, run_id: str, node_id: str) -> None:
        """Mark a node in this run as explicitly pending (like old run_one)."""
        rs = self._get_run(run_id)
        rs.ready_pending.add(node_id)
        # nudge the loop (a no-op wakeup is fine)
        await self._events.put(GlobalWakeupEvent(run_id=run_id, node_id=node_id))

    async def wait_for_node_terminal(self, run_id: str, node_id: str) -> dict[str, Any]:
        """Resolve when node reaches a terminal status; return its outputs (may be {})."""
        loop = asyncio.get_running_loop()
        fut: asyncio.Future = loop.create_future()

        statuses_to_wait_for = (
            str(NodeStatus.DONE),
            str(NodeStatus.FAILED),
            str(NodeStatus.SKIPPED),
        )

        async def _once(ev):
            if ev.run_id == run_id and ev.node_id == node_id and ev.status in statuses_to_wait_for:
                if not fut.done():
                    fut.set_result(ev.outputs or {})
            else:
                pass

        # one-shot listener
        self.add_listener(_once)
        try:
            return await fut
        finally:
            # best-effort: remove listener by rebuilding list (small scale)
            self._listeners = [cb for cb in self._listeners if cb is not _once]

    def post_resume_event_threadsafe(self, run_id: str, node_id: str, payload: dict) -> None:
        if self.loop is None:
            raise RuntimeError("GlobalForwardScheduler.loop is not set yet")
        self.loop.call_soon_threadsafe(
            self._events.put_nowait,
            GlobalResumeEvent(run_id=run_id, node_id=node_id, payload=payload),
        )

    def get_status(self) -> dict:
        runs = {}
        for run_id, rs in self._runs.items():
            waiting = sum(1 for n in rs.graph.nodes if n.state.status in WAITING_STATES)
            runs[run_id] = {
                "running": len(rs.running_tasks),
                "resume_pending": len(rs.resume_pending),
                "backoff_sleepers": len(rs.backoff_tasks),
                "waiting": waiting,
                "terminated": rs.terminated,
                "capacity": rs.capacity(),
            }
        total_running = sum(r["running"] for r in runs.values())
        idle = (total_running == 0) and any(r["waiting"] > 0 for r in runs.values())
        return {"idle": idle, "runs": runs}

    def add_run_listener(self, cb):
        if not inspect.iscoroutinefunction(cb):
            raise ValueError("Listener must be async")
        self._run_listeners.append(cb)

    async def _emit_run(self, ev: RunEvent):
        for cb in list(self._run_listeners):
            try:
                await cb(ev)
            except Exception as e:
                if self._logger:
                    self._logger.warning(f"run listener error: {e}")
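The scheduler above is driven through a small public API: submit() registers a run, run_until_complete() / run_until_all_done() / run_forever() drive the loop, and add_listener() observes node transitions. A minimal usage sketch follows, under stated assumptions: SchedulerRegistry() is assumed to construct without arguments, and build_graph() / build_env() are hypothetical placeholders for however the application constructs its TaskGraph and RuntimeEnv; this is not part of the package's documented surface.

import asyncio

from aethergraph.core.execution.global_scheduler import (
    GlobalForwardScheduler,
    RunSettings,
)
from aethergraph.services.schedulers.registry import SchedulerRegistry


async def main():
    # Assumption: SchedulerRegistry() constructs without arguments.
    scheduler = GlobalForwardScheduler(
        registry=SchedulerRegistry(),
        global_max_concurrency=8,  # optional cap across all runs; None => unlimited
    )

    # Listeners must be async functions; add_listener() enforces this.
    async def on_node_event(ev):
        print(f"[{ev.run_id}] {ev.node_id} -> {ev.status}")

    scheduler.add_listener(on_node_event)

    # Hypothetical helpers standing in for application-specific construction
    # of the TaskGraph and RuntimeEnv.
    graph, env = build_graph(), build_env()

    await scheduler.submit(
        run_id="run-1",
        graph=graph,
        env=env,
        settings=RunSettings(max_concurrency=2),  # per-run cap
    )
    await scheduler.run_until_complete("run-1")  # drive until this run is terminal
    await scheduler.shutdown()


asyncio.run(main())

Note the two-level capacity design: RunSettings.max_concurrency bounds parallelism within one run, while global_max_concurrency bounds the total number of node tasks across all runs sharing the event loop.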
aethergraph/core/execution/retry_policy.py
@@ -0,0 +1,22 @@
from dataclasses import dataclass
from datetime import timedelta


@dataclass
class RetryPolicy:
    max_attempts: int = 0
    backoff_base: float = 2.0
    backoff_first: float = 1.0  # seconds
    retry_on: tuple[type[BaseException], ...] = (Exception,)

    def should_retry(self, attempt: int, error: BaseException) -> bool:
        """Determine if we should retry based on attempt count and error type."""
        if attempt >= self.max_attempts:
            return False
        return isinstance(error, self.retry_on)

    def backoff(self, attempt: int) -> timedelta:
        """Calculate the backoff delay for the given attempt."""
        # attempt = 0 -> first failure
        delay = self.backoff_first * (self.backoff_base**attempt)
        return timedelta(seconds=delay)
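RetryPolicy composes with the scheduler above: on a FAILED node, GlobalForwardScheduler calls backoff(attempts - 1).total_seconds() and sleeps for that long before requeueing. A short behavior sketch with a non-default max_attempts; the delays follow directly from the formula backoff_first * backoff_base ** attempt, and the default max_attempts = 0 disables retries entirely.

from aethergraph.core.execution.retry_policy import RetryPolicy

policy = RetryPolicy(max_attempts=3)  # defaults: backoff_first=1.0s, backoff_base=2.0

# Exponential backoff: attempt 0 -> 1.0s, attempt 1 -> 2.0s, attempt 2 -> 4.0s.
for attempt in range(3):
    print(attempt, policy.backoff(attempt).total_seconds())

# should_retry() gates on both the attempt count and the error type.
print(policy.should_retry(attempt=1, error=ValueError("transient")))  # True
print(policy.should_retry(attempt=3, error=ValueError("transient")))  # False: attempt >= max_attempts
print(RetryPolicy().should_retry(attempt=0, error=ValueError()))      # False: max_attempts defaults to 0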