furu 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
furu/core/furu.py ADDED
@@ -0,0 +1,999 @@
1
+ import datetime
2
+ import getpass
3
+ import inspect
4
+ import os
5
+ import signal
6
+ import socket
7
+ import sys
8
+ import time
9
+ import traceback
10
+ from abc import ABC, abstractmethod
11
+ from pathlib import Path
12
+ from types import FrameType
13
+ from typing import Any, Callable, ClassVar, Self, TypedDict, TypeVar, cast, overload
14
+
15
+ import chz
16
+ import submitit
17
+ from typing_extensions import dataclass_transform
18
+
19
+ from ..adapters import SubmititAdapter
20
+ from ..adapters.submitit import SubmititJob
21
+ from ..config import FURU_CONFIG
22
+ from ..errors import (
23
+ MISSING,
24
+ FuruComputeError,
25
+ FuruLockNotAcquired,
26
+ FuruWaitTimeout,
27
+ )
28
+ from ..runtime import current_holder
29
+ from ..runtime.logging import enter_holder, get_logger, log, write_separator
30
+ from ..runtime.tracebacks import format_traceback
31
+ from ..serialization import FuruSerializer
32
+ from ..serialization.serializer import JsonValue
33
+ from ..storage import (
34
+ FuruMetadata,
35
+ MetadataManager,
36
+ MigrationManager,
37
+ MigrationRecord,
38
+ StateManager,
39
+ StateOwner,
40
+ )
41
+ from ..storage.state import (
42
+ _FuruState,
43
+ _OwnerDict,
44
+ _StateAttemptFailed,
45
+ _StateAttemptQueued,
46
+ _StateAttemptRunning,
47
+ _StateResultAbsent,
48
+ _StateResultFailed,
49
+ _StateResultMigrated,
50
+ _StateResultSuccess,
51
+ compute_lock,
52
+ )
53
+
54
+
55
+ class _SubmititEnvInfo(TypedDict, total=False):
56
+ """Environment info collected for submitit jobs."""
57
+
58
+ backend: str
59
+ slurm_job_id: str | None
60
+ pid: int
61
+ host: str
62
+ user: str
63
+ started_at: str
64
+ command: str
65
+
66
+
67
class _CallerInfo(TypedDict, total=False):
    """Call-site location attached to console log records.

    Both keys are optional (``total=False``); an empty dict means no caller
    outside the package was found.
    """

    # Source file of the calling frame.
    furu_caller_file: str
    # Line number within that file.
    furu_caller_line: int
72
+
73
+
74
+ @dataclass_transform(
75
+ field_specifiers=(chz.field,), kw_only_default=True, frozen_default=True
76
+ )
77
+ class Furu[T](ABC):
78
+ """
79
+ Base class for cached computations with provenance tracking.
80
+
81
+ Subclasses must implement:
82
+ - _create(self) -> T
83
+ - _load(self) -> T
84
+ """
85
+
86
+ MISSING = MISSING
87
+
88
+ # Configuration (can be overridden in subclasses)
89
+ version_controlled: ClassVar[bool] = False
90
+
91
+ # Maximum time to wait for result (seconds). Default: 10 minutes.
92
+ _max_wait_time_sec: float = 600.0
93
+
94
+ def __init_subclass__(
95
+ cls,
96
+ *,
97
+ version_controlled: bool | None = None,
98
+ version: str | None = None,
99
+ typecheck: bool | None = None,
100
+ **kwargs: object,
101
+ ) -> None:
102
+ super().__init_subclass__(**kwargs)
103
+ if cls.__name__ == "Furu" and cls.__module__ == __name__:
104
+ return
105
+
106
+ # Python 3.14+ may not populate `__annotations__` in `cls.__dict__` (PEP 649).
107
+ # `chz` expects annotations to exist for every `chz.field()` attribute, so we
108
+ # materialize them and (as a last resort) fill missing ones with `Any`.
109
+ try:
110
+ annotations = dict(getattr(cls, "__annotations__", {}) or {})
111
+ except Exception:
112
+ annotations = {}
113
+
114
+ try:
115
+ materialized = inspect.get_annotations(cls, eval_str=False)
116
+ except TypeError: # pragma: no cover
117
+ materialized = inspect.get_annotations(cls)
118
+ except Exception:
119
+ materialized = {}
120
+
121
+ if materialized:
122
+ annotations.update(materialized)
123
+
124
+ FieldType: type[object] | None
125
+ try:
126
+ from chz.field import Field as _ChzField
127
+ except Exception: # pragma: no cover
128
+ FieldType = None
129
+ else:
130
+ FieldType = _ChzField
131
+
132
+ if FieldType is not None:
133
+ for field_name, value in cls.__dict__.items():
134
+ if isinstance(value, FieldType) and field_name not in annotations:
135
+ annotations[field_name] = Any
136
+
137
+ if annotations:
138
+ type.__setattr__(cls, "__annotations__", annotations)
139
+
140
+ chz_kwargs: dict[str, str | bool] = {}
141
+ if version is not None:
142
+ chz_kwargs["version"] = version
143
+ if typecheck is not None:
144
+ chz_kwargs["typecheck"] = typecheck
145
+ chz.chz(cls, **chz_kwargs)
146
+
147
+ if version_controlled is not None:
148
+ setattr(cls, "version_controlled", version_controlled)
149
+
150
+ @classmethod
151
+ def _namespace(cls) -> Path:
152
+ module = getattr(cls, "__module__", None)
153
+ qualname = getattr(cls, "__qualname__", cls.__name__)
154
+ if not module or module == "__main__":
155
+ raise ValueError(
156
+ "Cannot derive Furu namespace from __main__; define the class in an importable module."
157
+ )
158
+ if "<locals>" in qualname:
159
+ raise ValueError(
160
+ "Cannot derive Furu namespace for a local class; define it at module scope."
161
+ )
162
+ return Path(*module.split("."), *qualname.split("."))
163
+
164
+ @abstractmethod
165
+ def _create(self: Self) -> T:
166
+ """Compute and save the result (implement in subclass)."""
167
+ raise NotImplementedError(
168
+ f"{self.__class__.__name__}._create() not implemented"
169
+ )
170
+
171
+ @abstractmethod
172
+ def _load(self: Self) -> T:
173
+ """Load the result from disk (implement in subclass)."""
174
+ raise NotImplementedError(f"{self.__class__.__name__}._load() not implemented")
175
+
176
+ def _validate(self: Self) -> bool:
177
+ """Validate that result is complete and correct (override if needed)."""
178
+ return True
179
+
180
+ def _invalidate_cached_success(self: Self, directory: Path, *, reason: str) -> None:
181
+ logger = get_logger()
182
+ logger.warning(
183
+ "invalidate %s %s %s (%s)",
184
+ self.__class__.__name__,
185
+ self._furu_hash,
186
+ directory,
187
+ reason,
188
+ )
189
+
190
+ StateManager.get_success_marker_path(directory).unlink(missing_ok=True)
191
+
192
+ now = datetime.datetime.now(datetime.timezone.utc).isoformat(timespec="seconds")
193
+
194
+ def mutate(state: _FuruState) -> None:
195
+ state.result = _StateResultAbsent(status="absent")
196
+
197
+ StateManager.update_state(directory, mutate)
198
+ StateManager.append_event(
199
+ directory, {"type": "result_invalidated", "reason": reason, "at": now}
200
+ )
201
+
202
+ @property
203
+ def _furu_hash(self: Self) -> str:
204
+ """Compute hash of this object's content for storage identification."""
205
+ return FuruSerializer.compute_hash(self)
206
+
207
+ def _force_recompute(self: Self) -> bool:
208
+ if not FURU_CONFIG.force_recompute:
209
+ return False
210
+ qualname = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
211
+ return qualname in FURU_CONFIG.force_recompute
212
+
213
+ def _base_furu_dir(self: Self) -> Path:
214
+ root = FURU_CONFIG.get_root(self.version_controlled)
215
+ return root / self.__class__._namespace() / self._furu_hash
216
+
217
+ @property
218
+ def furu_dir(self: Self) -> Path:
219
+ """Get the directory for this Furu object."""
220
+ directory = self._base_furu_dir()
221
+ migration = self._alias_record(directory)
222
+ if migration is not None and self._alias_is_active(directory, migration):
223
+ return MigrationManager.resolve_dir(migration, target="from")
224
+ return directory
225
+
226
+ @property
227
+ def raw_dir(self: Self) -> Path:
228
+ """
229
+ Get the raw directory for Furu.
230
+
231
+ This is intended for large, non-versioned byproducts or inputs.
232
+ """
233
+ return FURU_CONFIG.raw_dir
234
+
235
+ def to_dict(self: Self) -> JsonValue:
236
+ """Convert to dictionary."""
237
+ return FuruSerializer.to_dict(self)
238
+
239
+ @classmethod
240
+ def from_dict(cls, data: JsonValue) -> "Furu":
241
+ """Reconstruct from dictionary."""
242
+ return FuruSerializer.from_dict(data)
243
+
244
+ def to_python(self: Self, multiline: bool = True) -> str:
245
+ """Convert to Python code."""
246
+ return FuruSerializer.to_python(self, multiline=multiline)
247
+
248
+ def log(self: Self, message: str, *, level: str = "INFO") -> Path:
249
+ """Log a message to the current holder's `furu.log`."""
250
+ return log(message, level=level)
251
+
252
+ def exists(self: Self) -> bool:
253
+ """Check if result exists and is valid."""
254
+ logger = get_logger()
255
+ directory = self._base_furu_dir()
256
+ state = self.get_state(directory)
257
+
258
+ if not isinstance(state.result, _StateResultSuccess):
259
+ logger.info("exists %s -> false", directory)
260
+ return False
261
+
262
+ ok = self._validate()
263
+ logger.info("exists %s -> %s", directory, "true" if ok else "false")
264
+ return ok
265
+
266
+ def get_metadata(self: Self) -> "FuruMetadata":
267
+ """Get metadata for this object."""
268
+ directory = self._base_furu_dir()
269
+ return MetadataManager.read_metadata(directory)
270
+
271
+ def get_migration_record(self: Self) -> MigrationRecord | None:
272
+ """Get migration record for this object."""
273
+ return MigrationManager.read_migration(self._base_furu_dir())
274
+
275
+ @overload
276
+ def load_or_create(self, executor: submitit.Executor) -> T | submitit.Job[T]: ...
277
+
278
+ @overload
279
+ def load_or_create(self, executor: None = None) -> T: ...
280
+
281
+ def load_or_create(
282
+ self: Self,
283
+ executor: submitit.Executor | None = None,
284
+ ) -> T | submitit.Job[T]:
285
+ """
286
+ Load result if it exists, computing if necessary.
287
+
288
+ Args:
289
+ executor: Optional executor for batch submission (e.g., submitit.Executor)
290
+
291
+ Returns:
292
+ Result if wait=True, job handle if wait=False, or None if already exists
293
+
294
+ Raises:
295
+ FuruComputeError: If computation fails with detailed error information
296
+ """
297
+ logger = get_logger()
298
+ parent_holder = current_holder()
299
+ has_parent = parent_holder is not None and parent_holder is not self
300
+ if has_parent:
301
+ logger.debug(
302
+ "dep: begin %s %s %s",
303
+ self.__class__.__name__,
304
+ self._furu_hash,
305
+ self._base_furu_dir(),
306
+ )
307
+
308
+ ok = False
309
+ try:
310
+ with enter_holder(self):
311
+ start_time = time.time()
312
+ base_dir = self._base_furu_dir()
313
+ base_dir.mkdir(parents=True, exist_ok=True)
314
+ directory = base_dir
315
+ migration = self._alias_record(base_dir)
316
+ alias_active = False
317
+
318
+ if (
319
+ migration is not None
320
+ and migration.kind == "alias"
321
+ and migration.overwritten_at is None
322
+ ):
323
+ target_dir = MigrationManager.resolve_dir(migration, target="from")
324
+ target_state = StateManager.read_state(target_dir)
325
+ if isinstance(target_state.result, _StateResultSuccess):
326
+ alias_active = True
327
+ directory = target_dir
328
+ else:
329
+ self._maybe_detach_alias(
330
+ directory=base_dir,
331
+ record=migration,
332
+ reason="original_not_success",
333
+ )
334
+ migration = MigrationManager.read_migration(base_dir)
335
+
336
+ if alias_active and self._force_recompute():
337
+ if migration is not None:
338
+ self._maybe_detach_alias(
339
+ directory=base_dir,
340
+ record=migration,
341
+ reason="force_recompute",
342
+ )
343
+ migration = MigrationManager.read_migration(base_dir)
344
+ alias_active = False
345
+ directory = base_dir
346
+
347
+ # Optimistic read: if state is already good, we don't need to reconcile (write lock)
348
+ # Optimization: Check for success marker first to avoid reading state.json
349
+ # This is much faster for cache hits (11x speedup on check).
350
+ success_marker = StateManager.get_success_marker_path(directory)
351
+ if success_marker.is_file():
352
+ # We have a success marker. Check if we can use it.
353
+ if self._force_recompute():
354
+ self._invalidate_cached_success(
355
+ directory, reason="force_recompute enabled"
356
+ )
357
+ # Fall through to normal load
358
+ else:
359
+ try:
360
+ if not self._validate():
361
+ self._invalidate_cached_success(
362
+ directory, reason="_validate returned false"
363
+ )
364
+ # Fall through
365
+ else:
366
+ # Valid success! Return immediately.
367
+ # Since we didn't read state, we skip the logging below for speed
368
+ # or we can log a minimal message if needed.
369
+ ok = True
370
+ self._log_console_start(action_color="green")
371
+ return self._load()
372
+ except Exception as e:
373
+ self._invalidate_cached_success(
374
+ directory,
375
+ reason=f"_validate raised {type(e).__name__}: {e}",
376
+ )
377
+ # Fall through
378
+
379
+ state0 = StateManager.read_state(directory)
380
+
381
+ needs_reconcile = True
382
+ if isinstance(state0.result, _StateResultSuccess):
383
+ # Double check logic if we fell through to here (e.g. race condition or invalidation above)
384
+ if self._force_recompute():
385
+ self._invalidate_cached_success(
386
+ directory, reason="force_recompute enabled"
387
+ )
388
+ state0 = StateManager.read_state(directory)
389
+ else:
390
+ try:
391
+ if not self._validate():
392
+ self._invalidate_cached_success(
393
+ directory, reason="_validate returned false"
394
+ )
395
+ state0 = StateManager.read_state(directory)
396
+ else:
397
+ # Valid success found, skip reconcile
398
+ needs_reconcile = False
399
+ except Exception as e:
400
+ self._invalidate_cached_success(
401
+ directory,
402
+ reason=f"_validate raised {type(e).__name__}: {e}",
403
+ )
404
+ state0 = StateManager.read_state(directory)
405
+
406
+ if needs_reconcile and executor is not None:
407
+ adapter0 = SubmititAdapter(executor)
408
+ self._reconcile(directory, adapter=adapter0)
409
+ state0 = StateManager.read_state(directory)
410
+
411
+ attempt0 = state0.attempt
412
+ if isinstance(state0.result, _StateResultSuccess):
413
+ decision = "success->load"
414
+ action_color = "green"
415
+ elif isinstance(attempt0, (_StateAttemptQueued, _StateAttemptRunning)):
416
+ decision = f"{attempt0.status}->wait"
417
+ action_color = "yellow"
418
+ else:
419
+ decision = "create"
420
+ action_color = "blue"
421
+
422
+ # Cache hits can be extremely noisy in pipelines; keep logs for state
423
+ # transitions (create/wait) and error cases, but suppress repeated
424
+ # "success->load" lines and the raw separator on successful loads.
425
+ self._log_console_start(action_color=action_color)
426
+
427
+ if decision != "success->load":
428
+ write_separator()
429
+ logger.debug(
430
+ "load_or_create %s %s %s (%s)",
431
+ self.__class__.__name__,
432
+ self._furu_hash,
433
+ directory,
434
+ decision,
435
+ extra={"furu_action_color": action_color},
436
+ )
437
+
438
+ # Fast path: already successful
439
+ state_now = StateManager.read_state(directory)
440
+ if isinstance(state_now.result, _StateResultSuccess):
441
+ try:
442
+ result = self._load()
443
+ ok = True
444
+ return result
445
+ except Exception as e:
446
+ # Ensure there is still a clear marker in logs for unexpected
447
+ # failures even when we suppressed the cache-hit header line.
448
+ write_separator()
449
+ logger.error(
450
+ "load_or_create %s %s (load failed)",
451
+ self.__class__.__name__,
452
+ self._furu_hash,
453
+ )
454
+ raise FuruComputeError(
455
+ f"Failed to load result from {directory}",
456
+ StateManager.get_state_path(directory),
457
+ e,
458
+ ) from e
459
+
460
+ # Synchronous execution
461
+ if executor is None:
462
+ status, created_here, result = self._run_locally(
463
+ start_time=start_time
464
+ )
465
+ if status == "success":
466
+ ok = True
467
+ if created_here:
468
+ logger.debug(
469
+ "load_or_create: %s created -> return",
470
+ self.__class__.__name__,
471
+ )
472
+ return cast(T, result)
473
+ logger.debug(
474
+ "load_or_create: %s success -> _load()",
475
+ self.__class__.__name__,
476
+ )
477
+ return self._load()
478
+
479
+ state = StateManager.read_state(directory)
480
+ attempt = state.attempt
481
+ message = (
482
+ attempt.error.message
483
+ if isinstance(attempt, _StateAttemptFailed)
484
+ else None
485
+ )
486
+ suffix = (
487
+ f": {message}" if isinstance(message, str) and message else ""
488
+ )
489
+ raise FuruComputeError(
490
+ f"Computation {status}{suffix}",
491
+ StateManager.get_state_path(directory),
492
+ )
493
+
494
+ # Asynchronous execution with submitit
495
+ (submitit_folder := self._base_furu_dir() / "submitit").mkdir(
496
+ exist_ok=True, parents=True
497
+ )
498
+ executor.folder = submitit_folder
499
+ adapter = SubmititAdapter(executor)
500
+
501
+ logger.debug(
502
+ "load_or_create: %s -> submitit submit_once()",
503
+ self.__class__.__name__,
504
+ )
505
+ job = self._submit_once(adapter, directory, None)
506
+ ok = True
507
+ return cast(submitit.Job[T], job)
508
+ finally:
509
+ if has_parent:
510
+ logger.debug(
511
+ "dep: end %s %s (%s)",
512
+ self.__class__.__name__,
513
+ self._furu_hash,
514
+ "ok" if ok else "error",
515
+ )
516
+
517
+ def _log_console_start(self, action_color: str) -> None:
518
+ """Log the start of load_or_create to console with caller info."""
519
+ logger = get_logger()
520
+ frame = sys._getframe(1)
521
+
522
+ caller_info: _CallerInfo = {}
523
+ if frame is not None:
524
+ # Walk up the stack to find the caller outside of furu package
525
+ furu_pkg_dir = str(Path(__file__).parent.parent)
526
+ while frame is not None:
527
+ filename = frame.f_code.co_filename
528
+ # Skip frames from within the furu package
529
+ if not filename.startswith(furu_pkg_dir):
530
+ caller_info = {
531
+ "furu_caller_file": filename,
532
+ "furu_caller_line": frame.f_lineno,
533
+ }
534
+ break
535
+ frame = frame.f_back
536
+
537
+ logger.info(
538
+ "load_or_create %s %s",
539
+ self.__class__.__name__,
540
+ self._furu_hash,
541
+ extra={
542
+ "furu_console_only": True,
543
+ "furu_action_color": action_color,
544
+ **caller_info,
545
+ },
546
+ )
547
+
548
+ def _check_timeout(self, start_time: float) -> None:
549
+ """Check if operation has timed out."""
550
+ if self._max_wait_time_sec is not None:
551
+ if time.time() - start_time > self._max_wait_time_sec:
552
+ raise FuruWaitTimeout(
553
+ f"Furu operation timed out after {self._max_wait_time_sec} seconds."
554
+ )
555
+
556
+ def _is_migrated_state(self, directory: Path) -> bool:
557
+ record = self._alias_record(directory)
558
+ return record is not None and self._alias_is_active(directory, record)
559
+
560
+ def _migration_target_dir(self, directory: Path) -> Path | None:
561
+ record = self._alias_record(directory)
562
+ if record is None:
563
+ return None
564
+ return MigrationManager.resolve_dir(record, target="from")
565
+
566
+ def _resolve_effective_dir(self) -> Path:
567
+ return self._base_furu_dir()
568
+
569
+ def get_state(self, directory: Path | None = None) -> _FuruState:
570
+ """Return the alias-aware state for this Furu directory."""
571
+ base_dir = directory or self._base_furu_dir()
572
+ record = self._alias_record(base_dir)
573
+ if record is None or not self._alias_is_active(base_dir, record):
574
+ return StateManager.read_state(base_dir)
575
+ target_dir = MigrationManager.resolve_dir(record, target="from")
576
+ return StateManager.read_state(target_dir)
577
+
578
+ def _alias_record(self, directory: Path) -> MigrationRecord | None:
579
+ record = MigrationManager.read_migration(directory)
580
+ if record is None or record.kind != "alias":
581
+ return None
582
+ return record
583
+
584
+ def _alias_is_active(self, directory: Path, record: MigrationRecord) -> bool:
585
+ if record.overwritten_at is not None:
586
+ return False
587
+ state = StateManager.read_state(directory)
588
+ if not isinstance(state.result, _StateResultMigrated):
589
+ return False
590
+ target = MigrationManager.resolve_dir(record, target="from")
591
+ target_state = StateManager.read_state(target)
592
+ return isinstance(target_state.result, _StateResultSuccess)
593
+
594
+ def _maybe_detach_alias(
595
+ self: Self,
596
+ *,
597
+ directory: Path,
598
+ record: MigrationRecord,
599
+ reason: str,
600
+ ) -> None:
601
+ if record.overwritten_at is not None:
602
+ return
603
+ now = datetime.datetime.now(datetime.timezone.utc).isoformat(timespec="seconds")
604
+ record.overwritten_at = now
605
+ MigrationManager.write_migration(record, directory)
606
+ target = MigrationManager.resolve_dir(record, target="from")
607
+ target_record = MigrationManager.read_migration(target)
608
+ if target_record is not None:
609
+ target_record.overwritten_at = now
610
+ MigrationManager.write_migration(target_record, target)
611
+ event: dict[str, str | int] = {
612
+ "type": "migration_overwrite",
613
+ "policy": record.policy,
614
+ "from": f"{record.from_namespace}:{record.from_hash}",
615
+ "to": f"{record.to_namespace}:{record.to_hash}",
616
+ "reason": reason,
617
+ }
618
+ StateManager.append_event(directory, event.copy())
619
+ StateManager.append_event(target, event.copy())
620
+
621
+ def _submit_once(
622
+ self,
623
+ adapter: SubmititAdapter,
624
+ directory: Path,
625
+ on_job_id: Callable[[str], None] | None,
626
+ ) -> SubmititJob | None:
627
+ """Submit job once without waiting (fire-and-forget mode)."""
628
+ logger = get_logger()
629
+ self._reconcile(directory, adapter=adapter)
630
+ state = StateManager.read_state(directory)
631
+ attempt = state.attempt
632
+ if (
633
+ isinstance(attempt, (_StateAttemptQueued, _StateAttemptRunning))
634
+ and attempt.backend == "submitit"
635
+ ):
636
+ job = adapter.load_job(directory)
637
+ if job is not None:
638
+ return job
639
+
640
+ # Try to acquire submit lock
641
+ lock_path = StateManager.get_lock_path(directory, StateManager.SUBMIT_LOCK)
642
+ lock_fd = StateManager.try_lock(lock_path)
643
+
644
+ if lock_fd is None:
645
+ # Someone else is submitting, wait briefly and return their job
646
+ logger.debug(
647
+ "submit: waiting for submit lock %s %s %s",
648
+ self.__class__.__name__,
649
+ self._furu_hash,
650
+ directory,
651
+ )
652
+ time.sleep(0.5)
653
+ return adapter.load_job(directory)
654
+
655
+ attempt_id: str | None = None
656
+ try:
657
+ # Create metadata
658
+ metadata = MetadataManager.create_metadata(
659
+ self, directory, ignore_diff=FURU_CONFIG.ignore_git_diff
660
+ )
661
+ MetadataManager.write_metadata(metadata, directory)
662
+
663
+ env_info = MetadataManager.collect_environment_info()
664
+ owner_state = StateOwner(
665
+ pid=env_info.pid,
666
+ host=env_info.hostname,
667
+ hostname=env_info.hostname,
668
+ user=env_info.user,
669
+ command=env_info.command,
670
+ timestamp=env_info.timestamp,
671
+ python_version=env_info.python_version,
672
+ executable=env_info.executable,
673
+ platform=env_info.platform,
674
+ )
675
+ owner_payload: _OwnerDict = {
676
+ "pid": owner_state.pid,
677
+ "host": owner_state.host,
678
+ "hostname": owner_state.hostname,
679
+ "user": owner_state.user,
680
+ "command": owner_state.command,
681
+ "timestamp": owner_state.timestamp,
682
+ "python_version": owner_state.python_version,
683
+ "executable": owner_state.executable,
684
+ "platform": owner_state.platform,
685
+ }
686
+ attempt_id = StateManager.start_attempt_queued(
687
+ directory,
688
+ backend="submitit",
689
+ lease_duration_sec=FURU_CONFIG.lease_duration_sec,
690
+ owner=owner_payload,
691
+ scheduler={},
692
+ )
693
+
694
+ job = adapter.submit(lambda: self._worker_entry())
695
+
696
+ # Save job handle and watch for job ID
697
+ adapter.pickle_job(job, directory)
698
+ adapter.watch_job_id(
699
+ job,
700
+ directory,
701
+ attempt_id=attempt_id,
702
+ callback=on_job_id,
703
+ )
704
+
705
+ return job
706
+ except Exception as e:
707
+ if attempt_id is not None:
708
+ StateManager.finish_attempt_failed(
709
+ directory,
710
+ attempt_id=attempt_id, # type: ignore[arg-type]
711
+ error={
712
+ "type": type(e).__name__,
713
+ "message": f"Failed to submit: {e}",
714
+ },
715
+ )
716
+ else:
717
+
718
+ def mutate(state: _FuruState) -> None:
719
+ state.result = _StateResultFailed(status="failed")
720
+
721
+ StateManager.update_state(directory, mutate)
722
+ raise FuruComputeError(
723
+ "Failed to submit job",
724
+ StateManager.get_state_path(directory),
725
+ e,
726
+ ) from e
727
+ finally:
728
+ StateManager.release_lock(lock_fd, lock_path)
729
+
730
+ def _worker_entry(self: Self) -> None:
731
+ """Entry point for worker process (called by submitit or locally)."""
732
+ with enter_holder(self):
733
+ logger = get_logger()
734
+ directory = self._base_furu_dir()
735
+ directory.mkdir(parents=True, exist_ok=True)
736
+
737
+ env_info = self._collect_submitit_env()
738
+
739
+ try:
740
+ with compute_lock(
741
+ directory,
742
+ backend="submitit",
743
+ lease_duration_sec=FURU_CONFIG.lease_duration_sec,
744
+ heartbeat_interval_sec=FURU_CONFIG.heartbeat_interval_sec,
745
+ owner={
746
+ "pid": os.getpid(),
747
+ "host": socket.gethostname(),
748
+ "user": getpass.getuser(),
749
+ "command": " ".join(sys.argv) if sys.argv else "<unknown>",
750
+ },
751
+ scheduler={
752
+ "backend": env_info.get("backend"),
753
+ "job_id": env_info.get("slurm_job_id"),
754
+ },
755
+ max_wait_time_sec=None, # Workers wait indefinitely
756
+ poll_interval_sec=FURU_CONFIG.poll_interval,
757
+ wait_log_every_sec=FURU_CONFIG.wait_log_every_sec,
758
+ reconcile_fn=lambda d: self._reconcile(d),
759
+ ) as ctx:
760
+ # Refresh metadata (now safe - attempt is already recorded)
761
+ metadata = MetadataManager.create_metadata(
762
+ self, directory, ignore_diff=FURU_CONFIG.ignore_git_diff
763
+ )
764
+ MetadataManager.write_metadata(metadata, directory)
765
+
766
+ # Set up signal handlers
767
+ self._setup_signal_handlers(
768
+ directory, ctx.stop_heartbeat, attempt_id=ctx.attempt_id
769
+ )
770
+
771
+ try:
772
+ # Run computation
773
+ logger.debug(
774
+ "_create: begin %s %s %s",
775
+ self.__class__.__name__,
776
+ self._furu_hash,
777
+ directory,
778
+ )
779
+ self._create()
780
+ logger.debug(
781
+ "_create: ok %s %s %s",
782
+ self.__class__.__name__,
783
+ self._furu_hash,
784
+ directory,
785
+ )
786
+ StateManager.write_success_marker(
787
+ directory, attempt_id=ctx.attempt_id
788
+ )
789
+ StateManager.finish_attempt_success(
790
+ directory, attempt_id=ctx.attempt_id
791
+ )
792
+ logger.info(
793
+ "_create ok %s %s",
794
+ self.__class__.__name__,
795
+ self._furu_hash,
796
+ extra={"furu_console_only": True},
797
+ )
798
+ except Exception as e:
799
+ logger.error(
800
+ "_create failed %s %s %s",
801
+ self.__class__.__name__,
802
+ self._furu_hash,
803
+ directory,
804
+ extra={"furu_file_only": True},
805
+ )
806
+ logger.error(
807
+ "%s", format_traceback(e), extra={"furu_file_only": True}
808
+ )
809
+
810
+ tb = "".join(
811
+ traceback.format_exception(type(e), e, e.__traceback__)
812
+ )
813
+ StateManager.finish_attempt_failed(
814
+ directory,
815
+ attempt_id=ctx.attempt_id,
816
+ error={
817
+ "type": type(e).__name__,
818
+ "message": str(e),
819
+ "traceback": tb,
820
+ },
821
+ )
822
+ raise
823
+ except FuruLockNotAcquired:
824
+ # Experiment already completed (success or failed), nothing to do
825
+ return
826
+
827
+ def _collect_submitit_env(self: Self) -> _SubmititEnvInfo:
828
+ """Collect submitit/slurm environment information."""
829
+ slurm_id = os.getenv("SLURM_JOB_ID")
830
+
831
+ info: _SubmititEnvInfo = {
832
+ "backend": "slurm" if slurm_id else "local",
833
+ "slurm_job_id": slurm_id,
834
+ "pid": os.getpid(),
835
+ "host": socket.gethostname(),
836
+ "user": getpass.getuser(),
837
+ "started_at": datetime.datetime.now(datetime.timezone.utc).isoformat(
838
+ timespec="seconds"
839
+ ),
840
+ "command": " ".join(sys.argv) if sys.argv else "<unknown>",
841
+ }
842
+
843
+ # Only call submitit.JobEnvironment() when actually in a submitit job
844
+ if slurm_id:
845
+ env = submitit.JobEnvironment()
846
+ info["backend"] = "submitit"
847
+ info["slurm_job_id"] = str(getattr(env, "job_id", slurm_id))
848
+
849
+ return info
850
+
851
+ def _run_locally(self: Self, start_time: float) -> tuple[str, bool, T | None]:
852
+ """Run computation locally, returning (status, created_here, result)."""
853
+ logger = get_logger()
854
+ directory = self._base_furu_dir()
855
+
856
+ # Calculate remaining time for the lock wait
857
+ max_wait: float | None = None
858
+ if self._max_wait_time_sec is not None:
859
+ elapsed = time.time() - start_time
860
+ max_wait = max(0.0, self._max_wait_time_sec - elapsed)
861
+
862
+ try:
863
+ with compute_lock(
864
+ directory,
865
+ backend="local",
866
+ lease_duration_sec=FURU_CONFIG.lease_duration_sec,
867
+ heartbeat_interval_sec=FURU_CONFIG.heartbeat_interval_sec,
868
+ owner={
869
+ "pid": os.getpid(),
870
+ "host": socket.gethostname(),
871
+ "user": getpass.getuser(),
872
+ "command": " ".join(sys.argv) if sys.argv else "<unknown>",
873
+ },
874
+ scheduler={},
875
+ max_wait_time_sec=max_wait,
876
+ poll_interval_sec=FURU_CONFIG.poll_interval,
877
+ wait_log_every_sec=FURU_CONFIG.wait_log_every_sec,
878
+ reconcile_fn=lambda d: self._reconcile(d),
879
+ ) as ctx:
880
+ # Create metadata (now safe - attempt is already recorded)
881
+ try:
882
+ metadata = MetadataManager.create_metadata(
883
+ self, directory, ignore_diff=FURU_CONFIG.ignore_git_diff
884
+ )
885
+ MetadataManager.write_metadata(metadata, directory)
886
+ except Exception as e:
887
+ raise FuruComputeError(
888
+ "Failed to create metadata",
889
+ StateManager.get_state_path(directory),
890
+ e,
891
+ ) from e
892
+
893
+ # Set up preemption handler
894
+ self._setup_signal_handlers(
895
+ directory, ctx.stop_heartbeat, attempt_id=ctx.attempt_id
896
+ )
897
+
898
+ try:
899
+ # Run the computation
900
+ logger.debug(
901
+ "_create: begin %s %s %s",
902
+ self.__class__.__name__,
903
+ self._furu_hash,
904
+ directory,
905
+ )
906
+ result = self._create()
907
+ logger.debug(
908
+ "_create: ok %s %s %s",
909
+ self.__class__.__name__,
910
+ self._furu_hash,
911
+ directory,
912
+ )
913
+ StateManager.write_success_marker(
914
+ directory, attempt_id=ctx.attempt_id
915
+ )
916
+ StateManager.finish_attempt_success(
917
+ directory, attempt_id=ctx.attempt_id
918
+ )
919
+ logger.info(
920
+ "_create ok %s %s",
921
+ self.__class__.__name__,
922
+ self._furu_hash,
923
+ extra={"furu_console_only": True},
924
+ )
925
+ return "success", True, result
926
+ except Exception as e:
927
+ logger.error(
928
+ "_create failed %s %s %s",
929
+ self.__class__.__name__,
930
+ self._furu_hash,
931
+ directory,
932
+ extra={"furu_file_only": True},
933
+ )
934
+ logger.error(
935
+ "%s", format_traceback(e), extra={"furu_file_only": True}
936
+ )
937
+
938
+ # Record failure (plain text in file)
939
+ tb = "".join(
940
+ traceback.format_exception(type(e), e, e.__traceback__)
941
+ )
942
+ StateManager.finish_attempt_failed(
943
+ directory,
944
+ attempt_id=ctx.attempt_id,
945
+ error={
946
+ "type": type(e).__name__,
947
+ "message": str(e),
948
+ "traceback": tb,
949
+ },
950
+ )
951
+ raise
952
+ except FuruLockNotAcquired:
953
+ # Lock couldn't be acquired because experiment already completed
954
+ state = StateManager.read_state(directory)
955
+ if isinstance(state.result, _StateResultSuccess):
956
+ return "success", False, None
957
+ if isinstance(state.result, _StateResultFailed):
958
+ return "failed", False, None
959
+ # Shouldn't happen, but re-raise if state is unexpected
960
+ raise
961
+
962
+ def _reconcile(
963
+ self: Self, directory: Path, *, adapter: SubmititAdapter | None = None
964
+ ) -> None:
965
+ if adapter is None:
966
+ StateManager.reconcile(directory)
967
+ return
968
+
969
+ StateManager.reconcile(
970
+ directory,
971
+ submitit_probe=lambda state: adapter.probe(directory, state),
972
+ )
973
+
974
+ def _setup_signal_handlers(
975
+ self,
976
+ directory: Path,
977
+ stop_heartbeat: Callable[[], None],
978
+ *,
979
+ attempt_id: str,
980
+ ) -> None:
981
+ """Set up signal handlers for graceful preemption."""
982
+
983
+ def handle_signal(signum: int, frame: FrameType | None) -> None:
984
+ try:
985
+ StateManager.finish_attempt_preempted(
986
+ directory,
987
+ attempt_id=attempt_id,
988
+ error={"type": "signal", "message": f"signal:{signum}"},
989
+ )
990
+ finally:
991
+ stop_heartbeat()
992
+ exit_code = 143 if signum == signal.SIGTERM else 130
993
+ os._exit(exit_code)
994
+
995
+ for sig in (signal.SIGTERM, signal.SIGINT):
996
+ signal.signal(sig, handle_signal)
997
+
998
+
999
# Covariant type variable bounded by ``Furu``; its users are not visible in
# this module section.
_H = TypeVar("_H", covariant=True, bound=Furu)