furu 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,878 @@
+from __future__ import annotations
+
+import contextlib
+import json
+import os
+import socket
+import threading
+import time
+import uuid
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Literal, Mapping, TypedDict, cast
+
+from ..adapters import SubmititAdapter
+from ..adapters.submitit import SubmititJob
+from ..config import FURU_CONFIG
+from ..core import Furu
+from ..errors import FuruComputeError, FuruMissingArtifact, FuruSpecMismatch
+from ..runtime.logging import get_logger
+from ..serialization.serializer import JsonValue
+from ..storage.state import _FuruState, _StateResultFailed, _StateResultSuccess
+from .paths import submitit_root_dir
+from .plan import DependencyPlan, build_plan, ready_todo
+from .plan_utils import reconcile_or_timeout_in_progress
+from .slurm_spec import SlurmSpec
+from .submitit_factory import make_executor_for_spec
+
+
+FailureKind = Literal["compute", "protocol"]
+PoolFailurePhase = Literal["payload", "worker"]
+MISSING_HEARTBEAT_REQUEUE_LIMIT = 1
+
+
+class _TaskPayload(TypedDict, total=False):
+    hash: str
+    spec_key: str
+    obj: JsonValue
+    error: str
+    traceback: str
+    attempt: int
+    failure_kind: FailureKind
+    failed_at: str
+    claimed_at: str
+    worker_id: str
+    missing_heartbeat_requeues: int
+    stale_heartbeat_requeues: int
+
+
+@dataclass
+class SlurmPoolRun:
+    run_dir: Path
+    submitit_root: Path
+    plan: DependencyPlan
+
+
+def classify_pool_exception(
+    exc: Exception,
+    *,
+    phase: PoolFailurePhase,
+    state: _FuruState | None = None,
+) -> FailureKind:
+    if phase == "payload":
+        return "protocol"
+    if isinstance(exc, (FuruMissingArtifact, FuruSpecMismatch)):
+        return "protocol"
+    if isinstance(exc, FuruComputeError):
+        return "compute"
+    if state is not None and isinstance(state.result, _StateResultFailed):
+        return "compute"
+    return "protocol"
+
+
+def _normalize_window_size(window_size: str | int, root_count: int) -> int:
+    if root_count == 0:
+        return 0
+    if isinstance(window_size, str):
+        match window_size:
+            case "dfs":
+                return 1
+            case "bfs":
+                return root_count
+            case _:
+                raise ValueError(
+                    "window_size must be 'dfs', 'bfs', or a positive integer"
+                )
+    if isinstance(window_size, bool) or not isinstance(window_size, int):
+        raise TypeError("window_size must be 'dfs', 'bfs', or a positive integer")
+    if window_size < 1:
+        raise ValueError("window_size must be >= 1")
+    return min(window_size, root_count)
+
+
+def _run_dir(run_root: Path | None) -> Path:
+    base = run_root or (FURU_CONFIG.base_root / "runs")
+    base.mkdir(parents=True, exist_ok=True)
+    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
+    token = uuid.uuid4().hex[:6]
+    run_dir = base / f"{stamp}-{token}"
+    run_dir.mkdir(parents=True, exist_ok=True)
+    return run_dir
+
+
+def _queue_root(run_dir: Path) -> Path:
+    return run_dir / "queue"
+
+
+def _todo_dir(run_dir: Path, spec_key: str) -> Path:
+    return _queue_root(run_dir) / "todo" / spec_key
+
+
+def _running_dir(run_dir: Path, spec_key: str) -> Path:
+    return _queue_root(run_dir) / "running" / spec_key
+
+
+def _done_dir(run_dir: Path) -> Path:
+    return _queue_root(run_dir) / "done"
+
+
+def _failed_dir(run_dir: Path) -> Path:
+    return _queue_root(run_dir) / "failed"
+
+
+@dataclass
+class _FailedQueueEntry:
+    path: Path
+    payload: _TaskPayload | None
+    parse_error: str | None
+
+
+def _read_failed_entry(path: Path) -> _FailedQueueEntry:
+    try:
+        payload = json.loads(path.read_text())
+    except json.JSONDecodeError as exc:
+        return _FailedQueueEntry(path=path, payload=None, parse_error=str(exc))
+    if not isinstance(payload, dict):
+        return _FailedQueueEntry(
+            path=path,
+            payload=None,
+            parse_error="Failed payload is not a JSON object",
+        )
+    return _FailedQueueEntry(
+        path=path, payload=cast(_TaskPayload, payload), parse_error=None
+    )
+
+
+def _scan_failed_tasks(run_dir: Path) -> list[_FailedQueueEntry]:
+    failed_root = _failed_dir(run_dir)
+    if not failed_root.exists():
+        return []
+    failed_files = sorted(failed_root.rglob("*.json"))
+    if not failed_files:
+        return []
+    return [_read_failed_entry(path) for path in failed_files]
+
+
+def _worker_id_from_path(task_path: Path) -> str | None:
+    parts = task_path.parts
+    try:
+        running_index = parts.index("running")
+    except ValueError:
+        return None
+    if len(parts) <= running_index + 2:
+        return None
+    return parts[running_index + 2]
+
+
+def _format_failed_entry(entry: _FailedQueueEntry, *, reason: str) -> str:
+    payload = entry.payload or {}
+    task_hash = payload.get("hash") or entry.path.stem
+    spec_key = payload.get("spec_key") or "unknown"
+    worker_id = payload.get("worker_id") or "unknown"
+    attempt = payload.get("attempt")
+    attempt_str = str(attempt) if isinstance(attempt, int) else "unknown"
+    failure_kind = payload.get("failure_kind") or "unknown"
+    error = payload.get("error") or "unknown"
+    lines = [
+        "run_slurm_pool stopped: failed task entry detected in queue/failed.",
+        f"Reason: {reason}",
+        f"path: {entry.path}",
+        f"hash: {task_hash}",
+        f"spec_key: {spec_key}",
+        f"worker_id: {worker_id}",
+        f"attempt: {attempt_str}",
+        f"failure_kind: {failure_kind}",
+        f"error: {error}",
+    ]
+    if entry.parse_error is not None:
+        lines.append(f"parse_error: {entry.parse_error}")
+    return "\n".join(lines)
+
+
+def _requeue_failed_task(
+    run_dir: Path,
+    entry: _FailedQueueEntry,
+    payload: _TaskPayload,
+    *,
+    next_attempt: int,
+) -> None:
+    spec_key = payload.get("spec_key")
+    obj_payload = payload.get("obj")
+    if not isinstance(spec_key, str):
+        raise RuntimeError(
+            _format_failed_entry(entry, reason="Failed entry missing spec_key")
+        )
+    if obj_payload is None:
+        raise RuntimeError(
+            _format_failed_entry(entry, reason="Failed entry missing obj payload")
+        )
+    updated_payload = cast(_TaskPayload, dict(payload))
+    for stale_field in (
+        "error",
+        "failure_kind",
+        "traceback",
+        "failed_at",
+        "claimed_at",
+        "worker_id",
+    ):
+        updated_payload.pop(stale_field, None)
+    updated_payload["attempt"] = next_attempt
+    updated_payload["spec_key"] = spec_key
+    updated_payload["obj"] = obj_payload
+    task_path = _todo_dir(run_dir, spec_key) / entry.path.name
+    task_path.parent.mkdir(parents=True, exist_ok=True)
+    if task_path.exists():
+        raise RuntimeError(
+            _format_failed_entry(
+                entry,
+                reason=f"Retry conflict: todo entry already exists at {task_path}",
+            )
+        )
+    _atomic_write_json(task_path, updated_payload)
+    entry.path.unlink(missing_ok=True)
+
+
+def _handle_failed_tasks(
+    run_dir: Path,
+    entries: list[_FailedQueueEntry],
+    *,
+    retry_failed: bool,
+    max_compute_retries: int,
+) -> int:
+    requeued = 0
+    for entry in entries:
+        if entry.parse_error is not None:
+            raise RuntimeError(
+                _format_failed_entry(entry, reason="Invalid failed task payload")
+            )
+        if entry.payload is None:
+            raise RuntimeError(
+                _format_failed_entry(entry, reason="Missing failed task payload")
+            )
+        payload = entry.payload
+        failure_kind = payload.get("failure_kind")
+        if failure_kind != "compute":
+            raise RuntimeError(
+                _format_failed_entry(entry, reason="Protocol failure in failed queue")
+            )
+        attempt = payload.get("attempt")
+        if not isinstance(attempt, int):
+            raise RuntimeError(
+                _format_failed_entry(entry, reason="Failed entry missing attempt count")
+            )
+        if retry_failed and attempt <= max_compute_retries:
+            _requeue_failed_task(
+                run_dir,
+                entry,
+                payload,
+                next_attempt=attempt + 1,
+            )
+            requeued += 1
+            continue
+        if not retry_failed:
+            reason = "Compute failure with retry_failed disabled"
+        else:
+            retries_used = max(attempt - 1, 0)
+            reason = (
+                "Compute failure exhausted retries "
+                f"({retries_used}/{max_compute_retries})"
+            )
+        raise RuntimeError(_format_failed_entry(entry, reason=reason))
+    return requeued
+
+
+def _done_hashes(run_dir: Path) -> set[str]:
+    done_dir = _done_dir(run_dir)
+    if not done_dir.exists():
+        return set()
+    return {path.stem for path in done_dir.iterdir() if path.is_file()}
+
+
+def _missing_spec_keys(
+    plan: DependencyPlan, specs: dict[str, SlurmSpec]
+) -> dict[str, list[str]]:
+    missing: dict[str, list[str]] = {}
+    for node in plan.nodes.values():
+        if node.status != "TODO":
+            continue
+        if node.spec_key in specs:
+            continue
+        missing.setdefault(node.spec_key, []).append(
+            f"{node.obj.__class__.__name__}({node.obj._furu_hash})"
+        )
+    return missing
+
+
+def _task_filename(digest: str) -> str:
+    return f"{digest}.json"
+
+
+def _atomic_write_json(path: Path, payload: Mapping[str, JsonValue]) -> None:
+    """Write a JSON payload atomically.
+
+    Queue entries are consumed by other processes (workers/controllers). Using a
+    temp file + atomic rename avoids readers observing partially-written JSON on
+    shared/network filesystems.
+    """
+
+    path.parent.mkdir(parents=True, exist_ok=True)
+    tmp_path = path.with_name(f"{path.name}.tmp-{uuid.uuid4().hex}")
+    tmp_path.write_text(json.dumps(payload, indent=2))
+    tmp_path.replace(path)
+
+
+def _ensure_queue_layout(run_dir: Path, specs: dict[str, SlurmSpec]) -> None:
+    queue_root = _queue_root(run_dir)
+    (queue_root / "todo").mkdir(parents=True, exist_ok=True)
+    (queue_root / "running").mkdir(parents=True, exist_ok=True)
+    _done_dir(run_dir).mkdir(parents=True, exist_ok=True)
+    _failed_dir(run_dir).mkdir(parents=True, exist_ok=True)
+    for spec_key in specs:
+        _todo_dir(run_dir, spec_key).mkdir(parents=True, exist_ok=True)
+        _running_dir(run_dir, spec_key).mkdir(parents=True, exist_ok=True)
+
+
+def _task_known(run_dir: Path, spec_key: str, digest: str) -> bool:
+    filename = _task_filename(digest)
+    todo_path = _todo_dir(run_dir, spec_key) / filename
+    if todo_path.exists():
+        return True
+    if (_done_dir(run_dir) / filename).exists():
+        return True
+    if (_failed_dir(run_dir) / filename).exists():
+        return True
+    running_root = _running_dir(run_dir, spec_key)
+    if running_root.exists():
+        for path in running_root.glob(f"*/{filename}"):
+            if path.exists():
+                return True
+    return False
+
+
+def _enqueue_task(run_dir: Path, node_hash: str, spec_key: str, obj: Furu) -> bool:
+    if _task_known(run_dir, spec_key, node_hash):
+        return False
+    payload: _TaskPayload = {
+        "hash": node_hash,
+        "spec_key": spec_key,
+        "obj": obj.to_dict(),
+        "attempt": 1,
+    }
+    path = _todo_dir(run_dir, spec_key) / _task_filename(node_hash)
+    _atomic_write_json(path, payload)
+    return True
+
+
+def _claim_task(run_dir: Path, spec_key: str, worker_id: str) -> Path | None:
+    todo_root = _todo_dir(run_dir, spec_key)
+    if not todo_root.exists():
+        return None
+    running_root = _running_dir(run_dir, spec_key) / worker_id
+    running_root.mkdir(parents=True, exist_ok=True)
+
+    for path in sorted(todo_root.glob("*.json")):
+        if not path.is_file():
+            continue
+        target = running_root / path.name
+        try:
+            path.replace(target)
+        except FileNotFoundError:
+            continue
+        now = time.time()
+        with contextlib.suppress(OSError):
+            os.utime(target, (now, now))
+
+        # Best-effort: persist an explicit claim timestamp. This makes missing-heartbeat
+        # grace robust on filesystems with coarse mtimes or unexpected mtime behavior.
+        try:
+            raw = json.loads(target.read_text())
+            if isinstance(raw, dict):
+                payload = cast(_TaskPayload, raw)
+                payload["claimed_at"] = datetime.now(timezone.utc).isoformat(
+                    timespec="seconds"
+                )
+                payload["worker_id"] = worker_id
+                tmp = target.with_suffix(f".tmp-{uuid.uuid4().hex}")
+                tmp.write_text(json.dumps(payload, indent=2))
+                tmp.replace(target)
+        except Exception as exc:
+            logger = get_logger()
+            logger.warning(
+                "pool claim: failed to stamp claimed_at/worker_id for %s: %s",
+                target,
+                exc,
+            )
+        return target
+    return None
+
+
+def _heartbeat_path(task_path: Path) -> Path:
+    return task_path.with_suffix(".hb")
+
+
+def _touch_heartbeat(path: Path) -> None:
+    now = time.time()
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with contextlib.suppress(OSError):
+        if path.exists():
+            os.utime(path, (now, now))
+            return
+        path.touch()
+
+
+def _heartbeat_loop(
+    path: Path, interval_sec: float, stop_event: threading.Event
+) -> None:
+    while not stop_event.wait(interval_sec):
+        _touch_heartbeat(path)
+
+
+def _mark_done(run_dir: Path, task_path: Path) -> None:
+    target = _done_dir(run_dir) / task_path.name
+    target.parent.mkdir(parents=True, exist_ok=True)
+    hb_path = _heartbeat_path(task_path)
+    try:
+        task_path.replace(target)
+    except FileNotFoundError:
+        hb_path.unlink(missing_ok=True)
+        return
+    hb_path.unlink(missing_ok=True)
+
+
+def _mark_failed(
+    run_dir: Path,
+    task_path: Path,
+    message: str,
+    *,
+    failure_kind: FailureKind,
+) -> None:
+    payload: _TaskPayload = {
+        "hash": task_path.stem,
+        "error": message,
+        "failure_kind": failure_kind,
+        "attempt": 1,
+    }
+    try:
+        raw_payload = json.loads(task_path.read_text())
+    except (json.JSONDecodeError, FileNotFoundError):
+        raw_payload = None
+    if isinstance(raw_payload, dict):
+        payload.update(cast(_TaskPayload, raw_payload))
+        payload["error"] = message
+        payload["failure_kind"] = failure_kind
+    if not isinstance(payload.get("attempt"), int):
+        payload["attempt"] = 1
+    if "worker_id" not in payload:
+        worker_id = _worker_id_from_path(task_path)
+        if worker_id is not None:
+            payload["worker_id"] = worker_id
+    target = _failed_dir(run_dir) / task_path.name
+    target.parent.mkdir(parents=True, exist_ok=True)
+    _atomic_write_json(target, payload)
+    task_path.unlink(missing_ok=True)
+    _heartbeat_path(task_path).unlink(missing_ok=True)
+
+
+def pool_worker_main(
+    run_dir: Path,
+    spec_key: str,
+    idle_timeout_sec: float,
+    poll_interval_sec: float,
+) -> None:
+    worker_id = f"{socket.gethostname()}-{os.getpid()}"
+    last_task_time = time.time()
+
+    while True:
+        task_path = _claim_task(run_dir, spec_key, worker_id)
+        if task_path is None:
+            if time.time() - last_task_time > idle_timeout_sec:
+                return
+            time.sleep(poll_interval_sec)
+            continue
+
+        last_task_time = time.time()
+        try:
+            payload = json.loads(task_path.read_text())
+        except json.JSONDecodeError as exc:
+            _mark_failed(
+                run_dir,
+                task_path,
+                f"Invalid task payload JSON: {exc}",
+                failure_kind=classify_pool_exception(exc, phase="payload"),
+            )
+            raise
+        obj_payload = payload.get("obj") if isinstance(payload, dict) else None
+        if obj_payload is None:
+            _mark_failed(
+                run_dir,
+                task_path,
+                "Missing task payload",
+                failure_kind="protocol",
+            )
+            raise RuntimeError("Missing task payload")
+
+        try:
+            obj = Furu.from_dict(obj_payload)
+        except Exception as exc:
+            _mark_failed(
+                run_dir,
+                task_path,
+                f"Invalid task payload: {exc}",
+                failure_kind=classify_pool_exception(exc, phase="payload"),
+            )
+            raise
+        if not isinstance(obj, Furu):
+            message = f"Invalid task payload: expected Furu, got {type(obj).__name__}"
+            _mark_failed(run_dir, task_path, message, failure_kind="protocol")
+            raise RuntimeError(message)
+        if obj._executor_spec_key() != spec_key:
+            message = (
+                f"Spec mismatch: task {obj._executor_spec_key()} on worker {spec_key}"
+            )
+            _mark_failed(run_dir, task_path, message, failure_kind="protocol")
+            raise RuntimeError(message)
+
+        hb_path = _heartbeat_path(task_path)
+        _touch_heartbeat(hb_path)
+        heartbeat_stop = threading.Event()
+        heartbeat_thread = threading.Thread(
+            target=_heartbeat_loop,
+            args=(hb_path, max(0.5, poll_interval_sec), heartbeat_stop),
+            daemon=True,
+        )
+        heartbeat_thread.start()
+
+        try:
+            obj._worker_entry(allow_failed=FURU_CONFIG.retry_failed)
+        except Exception as exc:
+            heartbeat_stop.set()
+            heartbeat_thread.join()
+            state = obj.get_state()
+            failure_kind = classify_pool_exception(
+                exc,
+                phase="worker",
+                state=state,
+            )
+            _mark_failed(run_dir, task_path, str(exc), failure_kind=failure_kind)
+            raise
+        finally:
+            heartbeat_stop.set()
+            heartbeat_thread.join()
+
+        state = obj.get_state()
+        if isinstance(state.result, _StateResultSuccess):
+            _mark_done(run_dir, task_path)
+            continue
+
+        if isinstance(state.result, _StateResultFailed):
+            message = "Task failed; furu state is failed"
+            failure_kind: FailureKind = "compute"
+        else:
+            message = (
+                f"Task did not complete successfully (state={state.result.status})"
+            )
+            failure_kind = "protocol"
+        _mark_failed(run_dir, task_path, message, failure_kind=failure_kind)
+        raise RuntimeError(message)
+
+
+def _backlog(run_dir: Path, spec_key: str) -> int:
+    todo_dir = _todo_dir(run_dir, spec_key)
+    if not todo_dir.exists():
+        return 0
+    return sum(1 for path in todo_dir.glob("*.json") if path.is_file())
+
+
+def _requeue_stale_running(
+    run_dir: Path,
+    *,
+    stale_sec: float,
+    heartbeat_grace_sec: float,
+    max_compute_retries: int,
+) -> int:
+    running_root = _queue_root(run_dir) / "running"
+    if not running_root.exists():
+        return 0
+
+    now = time.time()
+    moved = 0
+    for path in sorted(running_root.rglob("*.json")):
+        hb_path = _heartbeat_path(path)
+        try:
+            hb_mtime = hb_path.stat().st_mtime
+        except FileNotFoundError:
+            hb_mtime = None
+        try:
+            mtime = path.stat().st_mtime
+        except FileNotFoundError:
+            continue
+        if hb_mtime is None:
+            if now - mtime <= heartbeat_grace_sec:
+                continue
+            try:
+                raw_payload = json.loads(path.read_text())
+            except json.JSONDecodeError:
+                raw_payload = None
+            if not isinstance(raw_payload, dict):
+                message = (
+                    "Missing heartbeat file for running task beyond grace period; "
+                    "invalid payload."
+                )
+                _mark_failed(run_dir, path, message, failure_kind="protocol")
+                continue
+            payload = cast(_TaskPayload, raw_payload)
+            claimed_at = payload.get("claimed_at")
+            if isinstance(claimed_at, str):
+                normalized = claimed_at.replace("Z", "+00:00")
+                try:
+                    claimed_dt = datetime.fromisoformat(normalized)
+                except ValueError:
+                    logger = get_logger()
+                    logger.warning(
+                        "pool controller: invalid claimed_at=%r for %s; falling back to mtime",
+                        claimed_at,
+                        path,
+                    )
+                else:
+                    if claimed_dt.tzinfo is None:
+                        claimed_dt = claimed_dt.replace(tzinfo=timezone.utc)
+                    if now - claimed_dt.timestamp() <= heartbeat_grace_sec:
+                        continue
+            requeues = payload.get("missing_heartbeat_requeues")
+            requeues_count = requeues if isinstance(requeues, int) else 0
+            if requeues_count < MISSING_HEARTBEAT_REQUEUE_LIMIT:
+                if len(path.parents) < 3:
+                    continue
+                spec_key = path.parent.parent.name
+                target = _todo_dir(run_dir, spec_key) / path.name
+                if target.exists():
+                    logger = get_logger()
+                    logger.warning(
+                        "run_slurm_pool: missing-heartbeat requeue found existing todo %s; cleaning up stale running entry %s",
+                        target,
+                        path,
+                    )
+                    path.unlink(missing_ok=True)
+                    hb_path.unlink(missing_ok=True)
+                    continue
+                updated_payload = dict(payload)
+                updated_payload["missing_heartbeat_requeues"] = requeues_count + 1
+                updated_payload.pop("claimed_at", None)
+                updated_payload.pop("worker_id", None)
+                _atomic_write_json(target, updated_payload)
+                path.unlink(missing_ok=True)
+                hb_path.unlink(missing_ok=True)
+                moved += 1
+                continue
+            message = (
+                "Missing heartbeat file for running task beyond grace period; "
+                "missing-heartbeat requeues exhausted."
+            )
+            _mark_failed(run_dir, path, message, failure_kind="protocol")
+            continue
+        if now - hb_mtime <= stale_sec:
+            continue
+        if len(path.parents) < 3:
+            continue
+        try:
+            raw_payload = json.loads(path.read_text())
+        except json.JSONDecodeError:
+            raw_payload = None
+        if not isinstance(raw_payload, dict):
+            message = "Stale heartbeat beyond threshold; invalid payload."
+            _mark_failed(run_dir, path, message, failure_kind="protocol")
+            raise RuntimeError(message)
+        payload = cast(_TaskPayload, raw_payload)
+        attempt = payload.get("attempt")
+        if not isinstance(attempt, int):
+            message = "Stale heartbeat beyond threshold; missing attempt count."
+            _mark_failed(run_dir, path, message, failure_kind="protocol")
+            raise RuntimeError(message)
+        if attempt > max_compute_retries:
+            retries_used = max(attempt - 1, 0)
+            message = (
+                "Stale heartbeat exhausted retries "
+                f"({retries_used}/{max_compute_retries})."
+            )
+            _mark_failed(run_dir, path, message, failure_kind="protocol")
+            raise RuntimeError(message)
+        spec_key = path.parent.parent.name
+        target = _todo_dir(run_dir, spec_key) / path.name
+        if target.exists():
+            logger = get_logger()
+            logger.warning(
+                "run_slurm_pool: stale-heartbeat requeue found existing todo %s; cleaning up stale running entry %s",
+                target,
+                path,
+            )
+            path.unlink(missing_ok=True)
+            hb_path.unlink(missing_ok=True)
+            continue
+        requeues = payload.get("stale_heartbeat_requeues")
+        requeues_count = requeues if isinstance(requeues, int) else 0
+        updated_payload = dict(payload)
+        updated_payload["attempt"] = attempt + 1
+        updated_payload["stale_heartbeat_requeues"] = requeues_count + 1
+        updated_payload.pop("claimed_at", None)
+        updated_payload.pop("worker_id", None)
+        _atomic_write_json(target, updated_payload)
+        path.unlink(missing_ok=True)
+        hb_path.unlink(missing_ok=True)
+        moved += 1
+    return moved
+
+
+def run_slurm_pool(
+    roots: list[Furu],
+    *,
+    specs: dict[str, SlurmSpec],
+    max_workers_total: int = 50,
+    window_size: str | int = "bfs",
+    idle_timeout_sec: float = 60.0,
+    poll_interval_sec: float = 2.0,
+    stale_running_sec: float = 900.0,
+    heartbeat_grace_sec: float = 30.0,
+    submitit_root: Path | None = None,
+    run_root: Path | None = None,
+) -> SlurmPoolRun:
+    if "default" not in specs:
+        raise KeyError("Missing slurm spec for key 'default'.")
+    if max_workers_total < 1:
+        raise ValueError("max_workers_total must be >= 1")
+
+    run_dir = _run_dir(run_root)
+    run_id = run_dir.name
+    submitit_root_effective = submitit_root_dir(submitit_root)
+    _ensure_queue_layout(run_dir, specs)
+
+    window = _normalize_window_size(window_size, len(roots))
+    active_indices = list(range(min(window, len(roots))))
+    next_index = len(active_indices)
+    jobs_by_spec: dict[str, list[SubmititJob]] = {spec_key: [] for spec_key in specs}
+    job_adapter = SubmititAdapter(executor=None)
+
+    plan = build_plan([roots[index] for index in active_indices])
+
+    while True:
+        active_roots = [roots[index] for index in active_indices]
+        plan = build_plan(active_roots, completed_hashes=_done_hashes(run_dir))
+
+        failed_entries = _scan_failed_tasks(run_dir)
+        if failed_entries:
+            _handle_failed_tasks(
+                run_dir,
+                failed_entries,
+                retry_failed=FURU_CONFIG.retry_failed,
+                max_compute_retries=FURU_CONFIG.max_compute_retries,
+            )
+        _requeue_stale_running(
+            run_dir,
+            stale_sec=stale_running_sec,
+            heartbeat_grace_sec=heartbeat_grace_sec,
+            max_compute_retries=FURU_CONFIG.max_compute_retries,
+        )
+
+        if not FURU_CONFIG.retry_failed:
+            failed = [node for node in plan.nodes.values() if node.status == "FAILED"]
+            if failed:
+                names = ", ".join(
+                    f"{node.obj.__class__.__name__}({node.obj._furu_hash})"
+                    for node in failed
+                )
+                raise RuntimeError(
+                    f"Cannot run slurm pool with failed dependencies: {names}"
+                )
+
+        missing_specs = _missing_spec_keys(plan, specs)
+        if missing_specs:
+            details = "; ".join(
+                f"{key} (e.g., {', '.join(nodes[:2])})"
+                for key, nodes in sorted(missing_specs.items())
+            )
+            raise KeyError(f"Missing slurm spec for keys: {details}")
+
+        ready = ready_todo(plan)
+        for digest in ready:
+            node = plan.nodes[digest]
+            _enqueue_task(run_dir, digest, node.spec_key, node.obj)
+
+        for spec_key, jobs in jobs_by_spec.items():
+            jobs_by_spec[spec_key] = [
+                job for job in jobs if not job_adapter.is_done(job)
+            ]
+
+        total_workers = sum(len(jobs) for jobs in jobs_by_spec.values())
+        backlog_by_spec = {spec_key: _backlog(run_dir, spec_key) for spec_key in specs}
+
+        while total_workers < max_workers_total and any(
+            count > 0 for count in backlog_by_spec.values()
+        ):
+            spec_key = max(backlog_by_spec, key=lambda key: backlog_by_spec[key])
+            if backlog_by_spec[spec_key] <= 0:
+                break
+            executor = make_executor_for_spec(
+                spec_key,
+                specs[spec_key],
+                kind="workers",
+                submitit_root=submitit_root_effective,
+                run_id=run_id,
+            )
+            adapter = SubmititAdapter(executor)
+            job = adapter.submit(
+                lambda: pool_worker_main(
+                    run_dir,
+                    spec_key,
+                    idle_timeout_sec=idle_timeout_sec,
+                    poll_interval_sec=poll_interval_sec,
+                )
+            )
+            jobs_by_spec[spec_key].append(job)
+            total_workers += 1
+            backlog_by_spec[spec_key] -= 1
+
+        finished_indices = [
+            index
+            for index in active_indices
+            if plan.nodes.get(roots[index]._furu_hash) is not None
+            and plan.nodes[roots[index]._furu_hash].status == "DONE"
+        ]
+        for index in finished_indices:
+            active_indices.remove(index)
+
+        while len(active_indices) < window and next_index < len(roots):
+            active_indices.append(next_index)
+            next_index += 1
+
+        if not active_indices and next_index >= len(roots):
+            return SlurmPoolRun(
+                run_dir=run_dir,
+                submitit_root=submitit_root_effective,
+                plan=plan,
+            )
+        if (
+            not ready
+            and total_workers == 0
+            and not any(count > 0 for count in backlog_by_spec.values())
+            and not any(node.status == "IN_PROGRESS" for node in plan.nodes.values())
+        ):
+            todo_nodes = [node for node in plan.nodes.values() if node.status == "TODO"]
+            if todo_nodes:
+                sample = ", ".join(
+                    f"{node.obj.__class__.__name__}({node.obj._furu_hash})"
+                    for node in todo_nodes[:3]
+                )
+                raise RuntimeError(
+                    "run_slurm_pool stalled with no progress; "
+                    f"remaining TODO nodes: {sample}"
+                )
+
+        if any(node.status == "IN_PROGRESS" for node in plan.nodes.values()):
+            stale_detected = reconcile_or_timeout_in_progress(
+                plan,
+                stale_timeout_sec=FURU_CONFIG.stale_timeout,
+            )
+            if stale_detected:
+                continue
+
+        time.sleep(poll_interval_sec)
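For orientation, here is a minimal usage sketch of the new entry point, inferred only from the signatures added above. The import location (furu.slurm), the MyTask class, the SlurmSpec() arguments, and the run_root path are illustrative assumptions and are not part of this release.

# Illustrative sketch only; import location, MyTask, and SlurmSpec fields are assumed.
from pathlib import Path

from furu.slurm import SlurmSpec, run_slurm_pool  # assumed re-export location

# roots are the caller's own Furu objects (the pipeline outputs to materialize).
roots = [MyTask(seed=seed) for seed in range(4)]  # hypothetical Furu subclass

result = run_slurm_pool(
    roots,
    specs={"default": SlurmSpec()},       # a spec keyed "default" is required
    max_workers_total=20,                 # cap on concurrently running worker jobs
    window_size="bfs",                    # or "dfs", or a positive integer window
    stale_running_sec=900.0,              # requeue tasks whose heartbeat goes stale
    run_root=Path("/scratch/furu-runs"),  # hypothetical base for run directories
)
print(result.run_dir)  # queue/todo|running|done|failed live under this directory
print(result.plan)     # final DependencyPlan once the pool has drained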