furu 0.0.3-py3-none-any.whl → 0.0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- furu/__init__.py +8 -0
- furu/adapters/submitit.py +23 -2
- furu/config.py +13 -1
- furu/core/furu.py +355 -196
- furu/core/list.py +1 -1
- furu/dashboard/__init__.py +10 -1
- furu/dashboard/main.py +10 -3
- furu/errors.py +17 -4
- furu/execution/__init__.py +22 -0
- furu/execution/context.py +30 -0
- furu/execution/local.py +184 -0
- furu/execution/paths.py +20 -0
- furu/execution/plan.py +238 -0
- furu/execution/plan_utils.py +13 -0
- furu/execution/slurm_dag.py +271 -0
- furu/execution/slurm_pool.py +878 -0
- furu/execution/slurm_spec.py +38 -0
- furu/execution/submitit_factory.py +47 -0
- furu/runtime/logging.py +10 -10
- furu/storage/state.py +34 -6
- {furu-0.0.3.dist-info → furu-0.0.4.dist-info}/METADATA +74 -37
- {furu-0.0.3.dist-info → furu-0.0.4.dist-info}/RECORD +24 -14
- {furu-0.0.3.dist-info → furu-0.0.4.dist-info}/WHEEL +0 -0
- {furu-0.0.3.dist-info → furu-0.0.4.dist-info}/entry_points.txt +0 -0
--- /dev/null
+++ furu/execution/slurm_pool.py
@@ -0,0 +1,878 @@
from __future__ import annotations

import contextlib
import json
import os
import socket
import threading
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Mapping, TypedDict, cast

from ..adapters import SubmititAdapter
from ..adapters.submitit import SubmititJob
from ..config import FURU_CONFIG
from ..core import Furu
from ..errors import FuruComputeError, FuruMissingArtifact, FuruSpecMismatch
from ..runtime.logging import get_logger
from ..serialization.serializer import JsonValue
from ..storage.state import _FuruState, _StateResultFailed, _StateResultSuccess
from .paths import submitit_root_dir
from .plan import DependencyPlan, build_plan, ready_todo
from .plan_utils import reconcile_or_timeout_in_progress
from .slurm_spec import SlurmSpec
from .submitit_factory import make_executor_for_spec


FailureKind = Literal["compute", "protocol"]
PoolFailurePhase = Literal["payload", "worker"]
MISSING_HEARTBEAT_REQUEUE_LIMIT = 1


class _TaskPayload(TypedDict, total=False):
    hash: str
    spec_key: str
    obj: JsonValue
    error: str
    traceback: str
    attempt: int
    failure_kind: FailureKind
    failed_at: str
    claimed_at: str
    worker_id: str
    missing_heartbeat_requeues: int
    stale_heartbeat_requeues: int


@dataclass
class SlurmPoolRun:
    run_dir: Path
    submitit_root: Path
    plan: DependencyPlan


def classify_pool_exception(
    exc: Exception,
    *,
    phase: PoolFailurePhase,
    state: _FuruState | None = None,
) -> FailureKind:
    if phase == "payload":
        return "protocol"
    if isinstance(exc, (FuruMissingArtifact, FuruSpecMismatch)):
        return "protocol"
    if isinstance(exc, FuruComputeError):
        return "compute"
    if state is not None and isinstance(state.result, _StateResultFailed):
        return "compute"
    return "protocol"


def _normalize_window_size(window_size: str | int, root_count: int) -> int:
    if root_count == 0:
        return 0
    if isinstance(window_size, str):
        match window_size:
            case "dfs":
                return 1
            case "bfs":
                return root_count
            case _:
                raise ValueError(
                    "window_size must be 'dfs', 'bfs', or a positive integer"
                )
    if isinstance(window_size, bool) or not isinstance(window_size, int):
        raise TypeError("window_size must be 'dfs', 'bfs', or a positive integer")
    if window_size < 1:
        raise ValueError("window_size must be >= 1")
    return min(window_size, root_count)


def _run_dir(run_root: Path | None) -> Path:
    base = run_root or (FURU_CONFIG.base_root / "runs")
    base.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    token = uuid.uuid4().hex[:6]
    run_dir = base / f"{stamp}-{token}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir


def _queue_root(run_dir: Path) -> Path:
    return run_dir / "queue"


def _todo_dir(run_dir: Path, spec_key: str) -> Path:
    return _queue_root(run_dir) / "todo" / spec_key


def _running_dir(run_dir: Path, spec_key: str) -> Path:
    return _queue_root(run_dir) / "running" / spec_key


def _done_dir(run_dir: Path) -> Path:
    return _queue_root(run_dir) / "done"


def _failed_dir(run_dir: Path) -> Path:
    return _queue_root(run_dir) / "failed"


@dataclass
class _FailedQueueEntry:
    path: Path
    payload: _TaskPayload | None
    parse_error: str | None


def _read_failed_entry(path: Path) -> _FailedQueueEntry:
    try:
        payload = json.loads(path.read_text())
    except json.JSONDecodeError as exc:
        return _FailedQueueEntry(path=path, payload=None, parse_error=str(exc))
    if not isinstance(payload, dict):
        return _FailedQueueEntry(
            path=path,
            payload=None,
            parse_error="Failed payload is not a JSON object",
        )
    return _FailedQueueEntry(
        path=path, payload=cast(_TaskPayload, payload), parse_error=None
    )


def _scan_failed_tasks(run_dir: Path) -> list[_FailedQueueEntry]:
    failed_root = _failed_dir(run_dir)
    if not failed_root.exists():
        return []
    failed_files = sorted(failed_root.rglob("*.json"))
    if not failed_files:
        return []
    return [_read_failed_entry(path) for path in failed_files]


def _worker_id_from_path(task_path: Path) -> str | None:
    parts = task_path.parts
    try:
        running_index = parts.index("running")
    except ValueError:
        return None
    if len(parts) <= running_index + 2:
        return None
    return parts[running_index + 2]


def _format_failed_entry(entry: _FailedQueueEntry, *, reason: str) -> str:
    payload = entry.payload or {}
    task_hash = payload.get("hash") or entry.path.stem
    spec_key = payload.get("spec_key") or "unknown"
    worker_id = payload.get("worker_id") or "unknown"
    attempt = payload.get("attempt")
    attempt_str = str(attempt) if isinstance(attempt, int) else "unknown"
    failure_kind = payload.get("failure_kind") or "unknown"
    error = payload.get("error") or "unknown"
    lines = [
        "run_slurm_pool stopped: failed task entry detected in queue/failed.",
        f"Reason: {reason}",
        f"path: {entry.path}",
        f"hash: {task_hash}",
        f"spec_key: {spec_key}",
        f"worker_id: {worker_id}",
        f"attempt: {attempt_str}",
        f"failure_kind: {failure_kind}",
        f"error: {error}",
    ]
    if entry.parse_error is not None:
        lines.append(f"parse_error: {entry.parse_error}")
    return "\n".join(lines)


def _requeue_failed_task(
    run_dir: Path,
    entry: _FailedQueueEntry,
    payload: _TaskPayload,
    *,
    next_attempt: int,
) -> None:
    spec_key = payload.get("spec_key")
    obj_payload = payload.get("obj")
    if not isinstance(spec_key, str):
        raise RuntimeError(
            _format_failed_entry(entry, reason="Failed entry missing spec_key")
        )
    if obj_payload is None:
        raise RuntimeError(
            _format_failed_entry(entry, reason="Failed entry missing obj payload")
        )
    updated_payload = cast(_TaskPayload, dict(payload))
    for stale_field in (
        "error",
        "failure_kind",
        "traceback",
        "failed_at",
        "claimed_at",
        "worker_id",
    ):
        updated_payload.pop(stale_field, None)
    updated_payload["attempt"] = next_attempt
    updated_payload["spec_key"] = spec_key
    updated_payload["obj"] = obj_payload
    task_path = _todo_dir(run_dir, spec_key) / entry.path.name
    task_path.parent.mkdir(parents=True, exist_ok=True)
    if task_path.exists():
        raise RuntimeError(
            _format_failed_entry(
                entry,
                reason=f"Retry conflict: todo entry already exists at {task_path}",
            )
        )
    _atomic_write_json(task_path, updated_payload)
    entry.path.unlink(missing_ok=True)


def _handle_failed_tasks(
    run_dir: Path,
    entries: list[_FailedQueueEntry],
    *,
    retry_failed: bool,
    max_compute_retries: int,
) -> int:
    requeued = 0
    for entry in entries:
        if entry.parse_error is not None:
            raise RuntimeError(
                _format_failed_entry(entry, reason="Invalid failed task payload")
            )
        if entry.payload is None:
            raise RuntimeError(
                _format_failed_entry(entry, reason="Missing failed task payload")
            )
        payload = entry.payload
        failure_kind = payload.get("failure_kind")
        if failure_kind != "compute":
            raise RuntimeError(
                _format_failed_entry(entry, reason="Protocol failure in failed queue")
            )
        attempt = payload.get("attempt")
        if not isinstance(attempt, int):
            raise RuntimeError(
                _format_failed_entry(entry, reason="Failed entry missing attempt count")
            )
        if retry_failed and attempt <= max_compute_retries:
            _requeue_failed_task(
                run_dir,
                entry,
                payload,
                next_attempt=attempt + 1,
            )
            requeued += 1
            continue
        if not retry_failed:
            reason = "Compute failure with retry_failed disabled"
        else:
            retries_used = max(attempt - 1, 0)
            reason = (
                "Compute failure exhausted retries "
                f"({retries_used}/{max_compute_retries})"
            )
        raise RuntimeError(_format_failed_entry(entry, reason=reason))
    return requeued


def _done_hashes(run_dir: Path) -> set[str]:
    done_dir = _done_dir(run_dir)
    if not done_dir.exists():
        return set()
    return {path.stem for path in done_dir.iterdir() if path.is_file()}


def _missing_spec_keys(
    plan: DependencyPlan, specs: dict[str, SlurmSpec]
) -> dict[str, list[str]]:
    missing: dict[str, list[str]] = {}
    for node in plan.nodes.values():
        if node.status != "TODO":
            continue
        if node.spec_key in specs:
            continue
        missing.setdefault(node.spec_key, []).append(
            f"{node.obj.__class__.__name__}({node.obj._furu_hash})"
        )
    return missing


def _task_filename(digest: str) -> str:
    return f"{digest}.json"


def _atomic_write_json(path: Path, payload: Mapping[str, JsonValue]) -> None:
    """Write a JSON payload atomically.

    Queue entries are consumed by other processes (workers/controllers). Using a
    temp file + atomic rename avoids readers observing partially-written JSON on
    shared/network filesystems.
    """

    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(f"{path.name}.tmp-{uuid.uuid4().hex}")
    tmp_path.write_text(json.dumps(payload, indent=2))
    tmp_path.replace(path)


def _ensure_queue_layout(run_dir: Path, specs: dict[str, SlurmSpec]) -> None:
    queue_root = _queue_root(run_dir)
    (queue_root / "todo").mkdir(parents=True, exist_ok=True)
    (queue_root / "running").mkdir(parents=True, exist_ok=True)
    _done_dir(run_dir).mkdir(parents=True, exist_ok=True)
    _failed_dir(run_dir).mkdir(parents=True, exist_ok=True)
    for spec_key in specs:
        _todo_dir(run_dir, spec_key).mkdir(parents=True, exist_ok=True)
        _running_dir(run_dir, spec_key).mkdir(parents=True, exist_ok=True)


def _task_known(run_dir: Path, spec_key: str, digest: str) -> bool:
    filename = _task_filename(digest)
    todo_path = _todo_dir(run_dir, spec_key) / filename
    if todo_path.exists():
        return True
    if (_done_dir(run_dir) / filename).exists():
        return True
    if (_failed_dir(run_dir) / filename).exists():
        return True
    running_root = _running_dir(run_dir, spec_key)
    if running_root.exists():
        for path in running_root.glob(f"*/{filename}"):
            if path.exists():
                return True
    return False


def _enqueue_task(run_dir: Path, node_hash: str, spec_key: str, obj: Furu) -> bool:
    if _task_known(run_dir, spec_key, node_hash):
        return False
    payload: _TaskPayload = {
        "hash": node_hash,
        "spec_key": spec_key,
        "obj": obj.to_dict(),
        "attempt": 1,
    }
    path = _todo_dir(run_dir, spec_key) / _task_filename(node_hash)
    _atomic_write_json(path, payload)
    return True


def _claim_task(run_dir: Path, spec_key: str, worker_id: str) -> Path | None:
    todo_root = _todo_dir(run_dir, spec_key)
    if not todo_root.exists():
        return None
    running_root = _running_dir(run_dir, spec_key) / worker_id
    running_root.mkdir(parents=True, exist_ok=True)

    for path in sorted(todo_root.glob("*.json")):
        if not path.is_file():
            continue
        target = running_root / path.name
        try:
            path.replace(target)
        except FileNotFoundError:
            continue
        now = time.time()
        with contextlib.suppress(OSError):
            os.utime(target, (now, now))

        # Best-effort: persist an explicit claim timestamp. This makes missing-heartbeat
        # grace robust on filesystems with coarse mtimes or unexpected mtime behavior.
        try:
            raw = json.loads(target.read_text())
            if isinstance(raw, dict):
                payload = cast(_TaskPayload, raw)
                payload["claimed_at"] = datetime.now(timezone.utc).isoformat(
                    timespec="seconds"
                )
                payload["worker_id"] = worker_id
                tmp = target.with_suffix(f".tmp-{uuid.uuid4().hex}")
                tmp.write_text(json.dumps(payload, indent=2))
                tmp.replace(target)
        except Exception as exc:
            logger = get_logger()
            logger.warning(
                "pool claim: failed to stamp claimed_at/worker_id for %s: %s",
                target,
                exc,
            )
        return target
    return None


def _heartbeat_path(task_path: Path) -> Path:
    return task_path.with_suffix(".hb")


def _touch_heartbeat(path: Path) -> None:
    now = time.time()
    path.parent.mkdir(parents=True, exist_ok=True)
    with contextlib.suppress(OSError):
        if path.exists():
            os.utime(path, (now, now))
            return
        path.touch()


def _heartbeat_loop(
    path: Path, interval_sec: float, stop_event: threading.Event
) -> None:
    while not stop_event.wait(interval_sec):
        _touch_heartbeat(path)


def _mark_done(run_dir: Path, task_path: Path) -> None:
    target = _done_dir(run_dir) / task_path.name
    target.parent.mkdir(parents=True, exist_ok=True)
    hb_path = _heartbeat_path(task_path)
    try:
        task_path.replace(target)
    except FileNotFoundError:
        hb_path.unlink(missing_ok=True)
        return
    hb_path.unlink(missing_ok=True)


def _mark_failed(
    run_dir: Path,
    task_path: Path,
    message: str,
    *,
    failure_kind: FailureKind,
) -> None:
    payload: _TaskPayload = {
        "hash": task_path.stem,
        "error": message,
        "failure_kind": failure_kind,
        "attempt": 1,
    }
    try:
        raw_payload = json.loads(task_path.read_text())
    except (json.JSONDecodeError, FileNotFoundError):
        raw_payload = None
    if isinstance(raw_payload, dict):
        payload.update(cast(_TaskPayload, raw_payload))
        payload["error"] = message
        payload["failure_kind"] = failure_kind
    if not isinstance(payload.get("attempt"), int):
        payload["attempt"] = 1
    if "worker_id" not in payload:
        worker_id = _worker_id_from_path(task_path)
        if worker_id is not None:
            payload["worker_id"] = worker_id
    target = _failed_dir(run_dir) / task_path.name
    target.parent.mkdir(parents=True, exist_ok=True)
    _atomic_write_json(target, payload)
    task_path.unlink(missing_ok=True)
    _heartbeat_path(task_path).unlink(missing_ok=True)


def pool_worker_main(
    run_dir: Path,
    spec_key: str,
    idle_timeout_sec: float,
    poll_interval_sec: float,
) -> None:
    worker_id = f"{socket.gethostname()}-{os.getpid()}"
    last_task_time = time.time()

    while True:
        task_path = _claim_task(run_dir, spec_key, worker_id)
        if task_path is None:
            if time.time() - last_task_time > idle_timeout_sec:
                return
            time.sleep(poll_interval_sec)
            continue

        last_task_time = time.time()
        try:
            payload = json.loads(task_path.read_text())
        except json.JSONDecodeError as exc:
            _mark_failed(
                run_dir,
                task_path,
                f"Invalid task payload JSON: {exc}",
                failure_kind=classify_pool_exception(exc, phase="payload"),
            )
            raise
        obj_payload = payload.get("obj") if isinstance(payload, dict) else None
        if obj_payload is None:
            _mark_failed(
                run_dir,
                task_path,
                "Missing task payload",
                failure_kind="protocol",
            )
            raise RuntimeError("Missing task payload")

        try:
            obj = Furu.from_dict(obj_payload)
        except Exception as exc:
            _mark_failed(
                run_dir,
                task_path,
                f"Invalid task payload: {exc}",
                failure_kind=classify_pool_exception(exc, phase="payload"),
            )
            raise
        if not isinstance(obj, Furu):
            message = f"Invalid task payload: expected Furu, got {type(obj).__name__}"
            _mark_failed(run_dir, task_path, message, failure_kind="protocol")
            raise RuntimeError(message)
        if obj._executor_spec_key() != spec_key:
            message = (
                f"Spec mismatch: task {obj._executor_spec_key()} on worker {spec_key}"
            )
            _mark_failed(run_dir, task_path, message, failure_kind="protocol")
            raise RuntimeError(message)

        hb_path = _heartbeat_path(task_path)
        _touch_heartbeat(hb_path)
        heartbeat_stop = threading.Event()
        heartbeat_thread = threading.Thread(
            target=_heartbeat_loop,
            args=(hb_path, max(0.5, poll_interval_sec), heartbeat_stop),
            daemon=True,
        )
        heartbeat_thread.start()

        try:
            obj._worker_entry(allow_failed=FURU_CONFIG.retry_failed)
        except Exception as exc:
            heartbeat_stop.set()
            heartbeat_thread.join()
            state = obj.get_state()
            failure_kind = classify_pool_exception(
                exc,
                phase="worker",
                state=state,
            )
            _mark_failed(run_dir, task_path, str(exc), failure_kind=failure_kind)
            raise
        finally:
            heartbeat_stop.set()
            heartbeat_thread.join()

        state = obj.get_state()
        if isinstance(state.result, _StateResultSuccess):
            _mark_done(run_dir, task_path)
            continue

        if isinstance(state.result, _StateResultFailed):
            message = "Task failed; furu state is failed"
            failure_kind: FailureKind = "compute"
        else:
            message = (
                f"Task did not complete successfully (state={state.result.status})"
            )
            failure_kind = "protocol"
        _mark_failed(run_dir, task_path, message, failure_kind=failure_kind)
        raise RuntimeError(message)


def _backlog(run_dir: Path, spec_key: str) -> int:
    todo_dir = _todo_dir(run_dir, spec_key)
    if not todo_dir.exists():
        return 0
    return sum(1 for path in todo_dir.glob("*.json") if path.is_file())


def _requeue_stale_running(
    run_dir: Path,
    *,
    stale_sec: float,
    heartbeat_grace_sec: float,
    max_compute_retries: int,
) -> int:
    running_root = _queue_root(run_dir) / "running"
    if not running_root.exists():
        return 0

    now = time.time()
    moved = 0
    for path in sorted(running_root.rglob("*.json")):
        hb_path = _heartbeat_path(path)
        try:
            hb_mtime = hb_path.stat().st_mtime
        except FileNotFoundError:
            hb_mtime = None
        try:
            mtime = path.stat().st_mtime
        except FileNotFoundError:
            continue
        if hb_mtime is None:
            if now - mtime <= heartbeat_grace_sec:
                continue
            try:
                raw_payload = json.loads(path.read_text())
            except json.JSONDecodeError:
                raw_payload = None
            if not isinstance(raw_payload, dict):
                message = (
                    "Missing heartbeat file for running task beyond grace period; "
                    "invalid payload."
                )
                _mark_failed(run_dir, path, message, failure_kind="protocol")
                continue
            payload = cast(_TaskPayload, raw_payload)
            claimed_at = payload.get("claimed_at")
            if isinstance(claimed_at, str):
                normalized = claimed_at.replace("Z", "+00:00")
                try:
                    claimed_dt = datetime.fromisoformat(normalized)
                except ValueError:
                    logger = get_logger()
                    logger.warning(
                        "pool controller: invalid claimed_at=%r for %s; falling back to mtime",
                        claimed_at,
                        path,
                    )
                else:
                    if claimed_dt.tzinfo is None:
                        claimed_dt = claimed_dt.replace(tzinfo=timezone.utc)
                    if now - claimed_dt.timestamp() <= heartbeat_grace_sec:
                        continue
            requeues = payload.get("missing_heartbeat_requeues")
            requeues_count = requeues if isinstance(requeues, int) else 0
            if requeues_count < MISSING_HEARTBEAT_REQUEUE_LIMIT:
                if len(path.parents) < 3:
                    continue
                spec_key = path.parent.parent.name
                target = _todo_dir(run_dir, spec_key) / path.name
                if target.exists():
                    logger = get_logger()
                    logger.warning(
                        "run_slurm_pool: missing-heartbeat requeue found existing todo %s; cleaning up stale running entry %s",
                        target,
                        path,
                    )
                    path.unlink(missing_ok=True)
                    hb_path.unlink(missing_ok=True)
                    continue
                updated_payload = dict(payload)
                updated_payload["missing_heartbeat_requeues"] = requeues_count + 1
                updated_payload.pop("claimed_at", None)
                updated_payload.pop("worker_id", None)
                _atomic_write_json(target, updated_payload)
                path.unlink(missing_ok=True)
                hb_path.unlink(missing_ok=True)
                moved += 1
                continue
            message = (
                "Missing heartbeat file for running task beyond grace period; "
                "missing-heartbeat requeues exhausted."
            )
            _mark_failed(run_dir, path, message, failure_kind="protocol")
            continue
        if now - hb_mtime <= stale_sec:
            continue
        if len(path.parents) < 3:
            continue
        try:
            raw_payload = json.loads(path.read_text())
        except json.JSONDecodeError:
            raw_payload = None
        if not isinstance(raw_payload, dict):
            message = "Stale heartbeat beyond threshold; invalid payload."
            _mark_failed(run_dir, path, message, failure_kind="protocol")
            raise RuntimeError(message)
        payload = cast(_TaskPayload, raw_payload)
        attempt = payload.get("attempt")
        if not isinstance(attempt, int):
            message = "Stale heartbeat beyond threshold; missing attempt count."
            _mark_failed(run_dir, path, message, failure_kind="protocol")
            raise RuntimeError(message)
        if attempt > max_compute_retries:
            retries_used = max(attempt - 1, 0)
            message = (
                "Stale heartbeat exhausted retries "
                f"({retries_used}/{max_compute_retries})."
            )
            _mark_failed(run_dir, path, message, failure_kind="protocol")
            raise RuntimeError(message)
        spec_key = path.parent.parent.name
        target = _todo_dir(run_dir, spec_key) / path.name
        if target.exists():
            logger = get_logger()
            logger.warning(
                "run_slurm_pool: stale-heartbeat requeue found existing todo %s; cleaning up stale running entry %s",
                target,
                path,
            )
            path.unlink(missing_ok=True)
            hb_path.unlink(missing_ok=True)
            continue
        requeues = payload.get("stale_heartbeat_requeues")
        requeues_count = requeues if isinstance(requeues, int) else 0
        updated_payload = dict(payload)
        updated_payload["attempt"] = attempt + 1
        updated_payload["stale_heartbeat_requeues"] = requeues_count + 1
        updated_payload.pop("claimed_at", None)
        updated_payload.pop("worker_id", None)
        _atomic_write_json(target, updated_payload)
        path.unlink(missing_ok=True)
        hb_path.unlink(missing_ok=True)
        moved += 1
    return moved


def run_slurm_pool(
    roots: list[Furu],
    *,
    specs: dict[str, SlurmSpec],
    max_workers_total: int = 50,
    window_size: str | int = "bfs",
    idle_timeout_sec: float = 60.0,
    poll_interval_sec: float = 2.0,
    stale_running_sec: float = 900.0,
    heartbeat_grace_sec: float = 30.0,
    submitit_root: Path | None = None,
    run_root: Path | None = None,
) -> SlurmPoolRun:
    if "default" not in specs:
        raise KeyError("Missing slurm spec for key 'default'.")
    if max_workers_total < 1:
        raise ValueError("max_workers_total must be >= 1")

    run_dir = _run_dir(run_root)
    run_id = run_dir.name
    submitit_root_effective = submitit_root_dir(submitit_root)
    _ensure_queue_layout(run_dir, specs)

    window = _normalize_window_size(window_size, len(roots))
    active_indices = list(range(min(window, len(roots))))
    next_index = len(active_indices)
    jobs_by_spec: dict[str, list[SubmititJob]] = {spec_key: [] for spec_key in specs}
    job_adapter = SubmititAdapter(executor=None)

    plan = build_plan([roots[index] for index in active_indices])

    while True:
        active_roots = [roots[index] for index in active_indices]
        plan = build_plan(active_roots, completed_hashes=_done_hashes(run_dir))

        failed_entries = _scan_failed_tasks(run_dir)
        if failed_entries:
            _handle_failed_tasks(
                run_dir,
                failed_entries,
                retry_failed=FURU_CONFIG.retry_failed,
                max_compute_retries=FURU_CONFIG.max_compute_retries,
            )
        _requeue_stale_running(
            run_dir,
            stale_sec=stale_running_sec,
            heartbeat_grace_sec=heartbeat_grace_sec,
            max_compute_retries=FURU_CONFIG.max_compute_retries,
        )

        if not FURU_CONFIG.retry_failed:
            failed = [node for node in plan.nodes.values() if node.status == "FAILED"]
            if failed:
                names = ", ".join(
                    f"{node.obj.__class__.__name__}({node.obj._furu_hash})"
                    for node in failed
                )
                raise RuntimeError(
                    f"Cannot run slurm pool with failed dependencies: {names}"
                )

        missing_specs = _missing_spec_keys(plan, specs)
        if missing_specs:
            details = "; ".join(
                f"{key} (e.g., {', '.join(nodes[:2])})"
                for key, nodes in sorted(missing_specs.items())
            )
            raise KeyError(f"Missing slurm spec for keys: {details}")

        ready = ready_todo(plan)
        for digest in ready:
            node = plan.nodes[digest]
            _enqueue_task(run_dir, digest, node.spec_key, node.obj)

        for spec_key, jobs in jobs_by_spec.items():
            jobs_by_spec[spec_key] = [
                job for job in jobs if not job_adapter.is_done(job)
            ]

        total_workers = sum(len(jobs) for jobs in jobs_by_spec.values())
        backlog_by_spec = {spec_key: _backlog(run_dir, spec_key) for spec_key in specs}

        while total_workers < max_workers_total and any(
            count > 0 for count in backlog_by_spec.values()
        ):
            spec_key = max(backlog_by_spec, key=lambda key: backlog_by_spec[key])
            if backlog_by_spec[spec_key] <= 0:
                break
            executor = make_executor_for_spec(
                spec_key,
                specs[spec_key],
                kind="workers",
                submitit_root=submitit_root_effective,
                run_id=run_id,
            )
            adapter = SubmititAdapter(executor)
            job = adapter.submit(
                lambda: pool_worker_main(
                    run_dir,
                    spec_key,
                    idle_timeout_sec=idle_timeout_sec,
                    poll_interval_sec=poll_interval_sec,
                )
            )
            jobs_by_spec[spec_key].append(job)
            total_workers += 1
            backlog_by_spec[spec_key] -= 1

        finished_indices = [
            index
            for index in active_indices
            if plan.nodes.get(roots[index]._furu_hash) is not None
            and plan.nodes[roots[index]._furu_hash].status == "DONE"
        ]
        for index in finished_indices:
            active_indices.remove(index)

        while len(active_indices) < window and next_index < len(roots):
            active_indices.append(next_index)
            next_index += 1

        if not active_indices and next_index >= len(roots):
            return SlurmPoolRun(
                run_dir=run_dir,
                submitit_root=submitit_root_effective,
                plan=plan,
            )
        if (
            not ready
            and total_workers == 0
            and not any(count > 0 for count in backlog_by_spec.values())
            and not any(node.status == "IN_PROGRESS" for node in plan.nodes.values())
        ):
            todo_nodes = [node for node in plan.nodes.values() if node.status == "TODO"]
            if todo_nodes:
                sample = ", ".join(
                    f"{node.obj.__class__.__name__}({node.obj._furu_hash})"
                    for node in todo_nodes[:3]
                )
                raise RuntimeError(
                    "run_slurm_pool stalled with no progress; "
                    f"remaining TODO nodes: {sample}"
                )

        if any(node.status == "IN_PROGRESS" for node in plan.nodes.values()):
            stale_detected = reconcile_or_timeout_in_progress(
                plan,
                stale_timeout_sec=FURU_CONFIG.stale_timeout,
            )
            if stale_detected:
                continue

        time.sleep(poll_interval_sec)
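
Editor's note: the sketches below are illustrative reductions of patterns visible in the diff above; they are not part of the furu package, and any helper names they introduce are hypothetical.

The `_atomic_write_json` helper writes queue entries through a uniquely named temp file followed by `Path.replace`, so readers on shared or network filesystems never see a half-written JSON document. A minimal standalone sketch of the same pattern (the function name here is ours, not furu's):

import json
import uuid
from pathlib import Path

def atomic_write_json(path: Path, payload: dict) -> None:
    # Write to a uniquely named sibling file, then rename it into place.
    # On a POSIX filesystem the rename is atomic, so readers observe either
    # the old file or the complete new one, never a partial write.
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(f"{path.name}.tmp-{uuid.uuid4().hex}")
    tmp_path.write_text(json.dumps(payload, indent=2))
    tmp_path.replace(path)

atomic_write_json(Path("queue/todo/default/abc123.json"), {"hash": "abc123", "attempt": 1})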
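
`_claim_task` uses the same rename primitive as a lock-free claim: moving a task file from `todo/<spec_key>/` into `running/<spec_key>/<worker_id>/` succeeds for exactly one worker, and every other worker gets `FileNotFoundError` and moves on. An illustrative reduction of that idea, with directory arguments standing in for the queue layout above:

from pathlib import Path

def claim_one(todo_dir: Path, running_dir: Path) -> Path | None:
    # Try each pending entry; the first successful rename wins the claim.
    running_dir.mkdir(parents=True, exist_ok=True)
    for candidate in sorted(todo_dir.glob("*.json")):
        target = running_dir / candidate.name
        try:
            candidate.replace(target)  # atomic within one filesystem
        except FileNotFoundError:
            continue  # another worker claimed this entry first
        return target
    return None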
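
Workers advertise liveness by touching a `.hb` file next to the claimed task from a daemon thread (`_heartbeat_loop`), and the controller's `_requeue_stale_running` pass requeues or fails tasks whose heartbeat goes missing or stale. A self-contained sketch of that heartbeat loop, using `threading.Event.wait` as both the sleep and the stop signal:

import threading
import time
from pathlib import Path

def heartbeat_loop(path: Path, interval_sec: float, stop: threading.Event) -> None:
    # wait() returns False on timeout (keep beating) and True once stop is set.
    while not stop.wait(interval_sec):
        path.touch()

stop = threading.Event()
hb = Path("task.hb")
hb.touch()
worker = threading.Thread(target=heartbeat_loop, args=(hb, 2.0, stop), daemon=True)
worker.start()
time.sleep(5)  # ... the actual task would run here ...
stop.set()
worker.join()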
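
Finally, `window_size` controls how many root objects `run_slurm_pool` keeps active at once: "dfs" maps to a window of 1, "bfs" to one slot per root, and an integer is clamped to the number of roots. A short usage sketch of the private helper, assuming furu 0.0.4 is installed and the module path matches the RECORD entries above:

from furu.execution.slurm_pool import _normalize_window_size

assert _normalize_window_size("dfs", 10) == 1   # one root's subtree at a time
assert _normalize_window_size("bfs", 10) == 10  # all roots active at once
assert _normalize_window_size(4, 10) == 4
assert _normalize_window_size(25, 10) == 10     # clamped to the root count
assert _normalize_window_size("bfs", 0) == 0    # no roots, empty window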