safe-state 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- safe_state/__init__.py +64 -0
- safe_state/checkpoint.py +168 -0
- safe_state/core.py +358 -0
- safe_state/exceptions.py +21 -0
- safe_state/reconnect.py +267 -0
- safe_state/serialization.py +112 -0
- safe_state-0.1.0.dist-info/METADATA +366 -0
- safe_state-0.1.0.dist-info/RECORD +11 -0
- safe_state-0.1.0.dist-info/WHEEL +5 -0
- safe_state-0.1.0.dist-info/licenses/LICENSE +21 -0
- safe_state-0.1.0.dist-info/top_level.txt +1 -0
safe_state/__init__.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""
|
|
2
|
+
safe-state: resumable execution for Python.
|
|
3
|
+
|
|
4
|
+
Drop ``@safe_state`` on top of any function that iterates through work. If the
|
|
5
|
+
function crashes — network blip, rate limit, power loss — the checkpoint is on
|
|
6
|
+
disk. Run the script again and it picks up where it left off. Live objects
|
|
7
|
+
(database connections, sockets, file handles) are restored via a reconnect
|
|
8
|
+
registry instead of being naively pickled.
|
|
9
|
+
|
|
10
|
+
Quick start:
|
|
11
|
+
|
|
12
|
+
from safe_state import safe_state
|
|
13
|
+
|
|
14
|
+
@safe_state
|
|
15
|
+
def remediate(clients, db):
|
|
16
|
+
for client in clients:
|
|
17
|
+
secure_endpoint(client, db)
|
|
18
|
+
|
|
19
|
+
remediate(load_clients(), open_db()) # crashes at client 15? just rerun.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from .core import safe_state, checkpoint, SafeIterator
|
|
23
|
+
from .checkpoint import Checkpoint, CheckpointStore
|
|
24
|
+
from .reconnect import (
|
|
25
|
+
ReconnectRegistry,
|
|
26
|
+
reconnect_handler,
|
|
27
|
+
register_reconnector,
|
|
28
|
+
get_default_registry,
|
|
29
|
+
)
|
|
30
|
+
from .serialization import freeze_state, thaw_state, can_freeze
|
|
31
|
+
from .exceptions import (
|
|
32
|
+
SafeStateError,
|
|
33
|
+
StateNotFoundError,
|
|
34
|
+
StateCorruptedError,
|
|
35
|
+
ReconnectError,
|
|
36
|
+
NoReconnectorError,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# Version string of this package release.
__version__ = "0.1.0"

# Names re-exported from the submodules imported above; this list defines what
# ``from safe_state import *`` provides.
__all__ = [
    # Decorator and the helpers used alongside it.
    "safe_state", "checkpoint", "SafeIterator",
    # Checkpoint model and its on-disk store.
    "Checkpoint", "CheckpointStore",
    # Reconnect machinery for live objects.
    "ReconnectRegistry", "reconnect_handler",
    "register_reconnector", "get_default_registry",
    # Serialization helpers.
    "freeze_state", "thaw_state", "can_freeze",
    # Exception hierarchy.
    "SafeStateError", "StateNotFoundError", "StateCorruptedError",
    "ReconnectError", "NoReconnectorError",
]
|
safe_state/checkpoint.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Checkpoint store: on-disk persistence of execution progress.
|
|
3
|
+
|
|
4
|
+
A Checkpoint records:
|
|
5
|
+
- which items in an iteration have completed (by index)
|
|
6
|
+
- results returned for each completed item (optional)
|
|
7
|
+
- the last failure (exception type, message, traceback) if any
|
|
8
|
+
- frozen local state at the point of failure
|
|
9
|
+
- a timestamp and run identifier
|
|
10
|
+
|
|
11
|
+
Checkpoints are stored as a single dill blob per job name, atomically replaced
|
|
12
|
+
via write-temp-then-rename to avoid corruption on crash.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
import tempfile
|
|
19
|
+
import time
|
|
20
|
+
import traceback
|
|
21
|
+
from dataclasses import dataclass, field
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any, Dict, List, Optional, Set
|
|
24
|
+
|
|
25
|
+
from .exceptions import StateCorruptedError, StateNotFoundError
|
|
26
|
+
from .reconnect import ReconnectRegistry, get_default_registry
|
|
27
|
+
from .serialization import freeze_state, thaw_state
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class Checkpoint:
    """Progress record for one safe-state job.

    Tracks which item indices have finished, optionally a result per index,
    details of the most recent failure, and the frozen locals captured with it.
    """

    job_name: str
    completed_indices: Set[int] = field(default_factory=set)
    results: Dict[int, Any] = field(default_factory=dict)
    total_items: Optional[int] = None
    last_failure: Optional[Dict[str, Any]] = None
    frozen_state: Optional[bytes] = None  # bytes produced by freeze_state()
    started_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    run_count: int = 1

    def mark_complete(self, index: int, result: Any = None, store_result: bool = False) -> None:
        """Record *index* as done, keeping *result* only when *store_result* is true."""
        if store_result:
            self.results[index] = result
        self.completed_indices.add(index)
        self.updated_at = time.time()

    def record_failure(self, index: int, exc: BaseException, locals_snapshot: Dict[str, Any]) -> None:
        """Store a description of *exc* and a best-effort freeze of *locals_snapshot*.

        The failure dict holds the item index, exception type name and module,
        message, formatted traceback and a timestamp. If the snapshot cannot be
        frozen as a whole, each entry is probed individually and the names that
        fail are listed under ``"dropped_locals"``. If even the reduced snapshot
        cannot be frozen, ``frozen_state`` is set to ``None`` and the error is
        noted under ``"freeze_error"``.
        """
        exc_type = type(exc)
        failure: Dict[str, Any] = {
            "index": index,
            "exception_type": exc_type.__name__,
            "exception_module": exc_type.__module__,
            "message": str(exc),
            "traceback": "".join(traceback.format_exception(exc_type, exc, exc.__traceback__)),
            "at": time.time(),
        }
        self.last_failure = failure

        try:
            self.frozen_state = freeze_state(locals_snapshot)
        except Exception:
            # Whole-snapshot freeze failed: keep only the entries that freeze alone.
            kept: Dict[str, Any] = {}
            dropped: List[str] = []
            for key, value in locals_snapshot.items():
                try:
                    freeze_state({key: value})
                except Exception:
                    dropped.append(key)
                else:
                    kept[key] = value
            try:
                self.frozen_state = freeze_state(kept)
            except Exception as freeze_err:
                self.frozen_state = None
                failure["freeze_error"] = (
                    f"{type(freeze_err).__name__}: {freeze_err}"
                )
            else:
                if dropped:
                    failure["dropped_locals"] = dropped

        self.updated_at = time.time()

    def progress(self) -> Dict[str, Any]:
        """Human-readable progress summary; unknown totals are reported as ``"?"``."""
        done = len(self.completed_indices)
        if self.total_items is None:
            total: Any = "?"
            remaining: Any = "?"
        else:
            total = self.total_items
            remaining = self.total_items - done
        return {
            "job": self.job_name,
            "completed": done,
            "total": total,
            "remaining": remaining,
            "has_failure": self.last_failure is not None,
            "run_count": self.run_count,
        }
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class CheckpointStore:
    """File-backed persistence for a single Checkpoint.

    The checkpoint is serialized with ``freeze_state`` and written to ``path``
    through a temporary file in the same directory followed by ``os.replace``.
    """

    def __init__(self, path: "str | os.PathLike", registry: "Optional[ReconnectRegistry]" = None):
        """Bind the store to *path*; use the default registry when none is given."""
        self.path = Path(path)
        # Compare against None explicitly so that a caller-supplied registry
        # that happens to be falsy is not silently swapped for the default.
        self.registry = registry if registry is not None else get_default_registry()

    def exists(self) -> bool:
        """Return True if the checkpoint file is present."""
        return self.path.exists()

    def load(self) -> "Checkpoint":
        """Read and rebuild the checkpoint.

        Raises:
            StateNotFoundError: the checkpoint file does not exist.
            StateCorruptedError: the file cannot be read or thawed, or the
                thawed payload does not have the expected mapping shape.
        """
        # Read directly instead of checking exists() first, so a file removed
        # in between is still reported as "not found" rather than "corrupted".
        try:
            blob = self.path.read_bytes()
        except FileNotFoundError:
            raise StateNotFoundError(f"No checkpoint at {self.path}") from None
        except OSError as e:
            raise StateCorruptedError(f"Could not load checkpoint {self.path}: {e}") from e

        # Building the Checkpoint stays inside the try: a payload that thaws
        # but has the wrong shape (missing "job_name", not a mapping, ...)
        # must surface as StateCorruptedError, which callers catch.
        try:
            payload = thaw_state(blob, registry=self.registry)
            return Checkpoint(
                job_name=payload["job_name"],
                completed_indices=set(payload.get("completed_indices", [])),
                results=payload.get("results", {}),
                total_items=payload.get("total_items"),
                last_failure=payload.get("last_failure"),
                frozen_state=payload.get("frozen_state"),
                started_at=payload.get("started_at", time.time()),
                updated_at=payload.get("updated_at", time.time()),
                run_count=payload.get("run_count", 1),
            )
        except Exception as e:
            raise StateCorruptedError(f"Could not load checkpoint {self.path}: {e}") from e

    def save(self, cp: "Checkpoint") -> None:
        """Serialize *cp* and atomically replace the checkpoint file with it."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "job_name": cp.job_name,
            "completed_indices": list(cp.completed_indices),
            "results": cp.results,
            "total_items": cp.total_items,
            "last_failure": cp.last_failure,
            "frozen_state": cp.frozen_state,  # already bytes
            "started_at": cp.started_at,
            "updated_at": cp.updated_at,
            "run_count": cp.run_count,
        }
        blob = freeze_state(payload, registry=self.registry)

        # Atomic write: temp file in same directory, then rename.
        tmp_fd, tmp_path = tempfile.mkstemp(
            prefix=f".{self.path.name}.", suffix=".tmp", dir=str(self.path.parent)
        )
        try:
            with os.fdopen(tmp_fd, "wb") as f:
                f.write(blob)
            os.replace(tmp_path, self.path)
        except BaseException:
            # BaseException so an interrupt (e.g. Ctrl-C) mid-write does not
            # leave a stray temp file behind; the error is always re-raised.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def clear(self) -> None:
        """Delete the checkpoint file; a missing file is not an error."""
        try:
            self.path.unlink()
        except FileNotFoundError:
            pass

    def thaw_locals(self, cp: "Checkpoint") -> Optional[Dict[str, Any]]:
        """Thaw ``cp.frozen_state`` with this store's registry.

        Returns None when the checkpoint holds no frozen state.
        """
        if cp.frozen_state is None:
            return None
        return thaw_state(cp.frozen_state, registry=self.registry)
|
safe_state/core.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
"""
|
|
2
|
+
The `safe_state` decorator and its supporting iterator.
|
|
3
|
+
|
|
4
|
+
The decorator wraps a function that processes an iterable. The first positional
|
|
5
|
+
argument (by default) is intercepted and replaced with a SafeIterator that:
|
|
6
|
+
|
|
7
|
+
1. Tracks progress through the iteration to disk after every item.
|
|
8
|
+
2. On exception, saves a checkpoint of completed indices + frozen locals.
|
|
9
|
+
3. On the next invocation with the same job name, loads the checkpoint
|
|
10
|
+
and skips items that were already completed.
|
|
11
|
+
|
|
12
|
+
A `safe_state.checkpoint()` helper is also provided for use inside the function
|
|
13
|
+
body (manual checkpoint mode).
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import functools
|
|
19
|
+
import inspect
|
|
20
|
+
import os
|
|
21
|
+
import sys
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
|
|
24
|
+
|
|
25
|
+
from .checkpoint import Checkpoint, CheckpointStore
|
|
26
|
+
from .exceptions import SafeStateError, StateNotFoundError
|
|
27
|
+
from .reconnect import ReconnectRegistry, get_default_registry
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Directory used for checkpoint files when the decorator gets no ``state_dir``.
# The ``SAFE_STATE_DIR`` environment variable, read when this module is
# imported, takes precedence over the ``.safe_state`` fallback.
DEFAULT_STATE_DIR = Path(os.getenv("SAFE_STATE_DIR", ".safe_state"))
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Contexts of decorated calls currently in progress, exposed so user code can
# call manual checkpoint(). Entries are keyed by id() of the decorated
# function's code object. This is a plain module-level dict (no locking, not
# actually thread-local), which keeps things simple for the common
# single-threaded use case. For multithreaded use, users should pass explicit
# CheckpointStore objects.
_active_contexts: Dict[int, "_RunContext"] = dict()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class _RunContext:
    """State for one invocation of a decorated function.

    Bundles the checkpoint store, the checkpoint being updated and the verbose
    flag, plus two mutable fields: the index of the item currently being
    processed and the locals most recently passed to a manual checkpoint().
    """

    def __init__(
        self,
        store: "CheckpointStore",
        checkpoint: "Checkpoint",
        verbose: bool,
    ):
        self.store, self.checkpoint, self.verbose = store, checkpoint, verbose
        # Nothing has been yielded or snapshotted yet.
        self.current_index: Optional[int] = None
        self.current_locals: Dict[str, Any] = {}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class SafeIterator:
    """
    Wraps an iterable so iteration progress is checkpointed automatically.

    The iterable is materialized into a list on construction. Iterating yields
    only items whose index is not already in ``checkpoint.completed_indices``.
    An item is marked complete when the consumer asks for the next one (that
    is, when the generator resumes after its ``yield``). The checkpoint is
    saved every ``save_every`` newly completed items, and once more at
    exhaustion if any completions are still unsaved.
    """

    def __init__(
        self,
        iterable: Iterable[Any],
        context: "_RunContext",
        save_every: int = 1,
        store_results: bool = False,
    ):
        """
        Args:
            iterable: Items to process; consumed eagerly into a list.
            context: Run context providing the checkpoint, store and verbose flag.
            save_every: Save after this many completed items (values below 1
                are treated as 1).
            store_results: If True, record each completed item in
                ``checkpoint.results`` under its index.
        """
        self._items = list(iterable)  # materialize to get a stable length and indexing
        self._context = context
        self._save_every = max(1, int(save_every))
        self._store_results = store_results
        self._since_save = 0

        # Record total on first run.
        if self._context.checkpoint.total_items is None:
            self._context.checkpoint.total_items = len(self._items)

    def __iter__(self) -> Iterator[Any]:
        """Yield the not-yet-completed items, checkpointing as they are consumed."""
        cp = self._context.checkpoint
        for idx, item in enumerate(self._items):
            if idx in cp.completed_indices:
                if self._context.verbose:
                    print(f"[safe_state] skip {idx} (already done)", file=sys.stderr)
                continue
            self._context.current_index = idx
            yield item
            # Reached only when the consumer requests the next item, i.e. the
            # just-yielded item was processed without raising. Pass the item
            # along so store_results actually records it instead of None.
            cp.mark_complete(idx, result=item, store_result=self._store_results)
            self._since_save += 1
            if self._since_save >= self._save_every:
                self._context.store.save(cp)
                self._since_save = 0
        # Final flush.
        if self._since_save > 0:
            self._context.store.save(cp)

    def __len__(self) -> int:
        """Total number of materialized items, including already-completed ones."""
        return len(self._items)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def checkpoint(**locals_to_save: Any) -> None:
    """
    Manually trigger a checkpoint save from inside a decorated function.

    Pass any local variables you want frozen into the checkpoint:

        @safe_state
        def long_job(tasks, db):
            results = []
            for t in tasks:
                results.append(do(t))
                checkpoint(results=results)   # snapshot progress

    The call stack is searched outward from the caller for a frame whose code
    object has an active run context. The given locals are frozen into that
    context's checkpoint and the checkpoint is saved. If freezing fails, the
    checkpoint is still saved with ``frozen_state`` set to ``None``, and a
    message is printed to stderr when the context is verbose.

    Has no effect if called outside a safe-state-decorated function.
    """
    # Imported once per call rather than on every frame visited.
    from .serialization import freeze_state

    # Walk up the call stack to find a frame whose function is decorated.
    frame = inspect.currentframe()
    try:
        while frame is not None:
            ctx = _active_contexts.get(id(frame.f_code))
            if ctx is None:
                frame = frame.f_back
                continue

            ctx.current_locals = dict(locals_to_save)
            cp = ctx.checkpoint
            try:
                # Freeze with the store's registry, the same one the store uses
                # to save the checkpoint and to thaw this snapshot later.
                cp.frozen_state = freeze_state(
                    locals_to_save, registry=ctx.store.registry
                )
            except Exception as freeze_err:
                # Best effort: progress is still persisted without the snapshot.
                cp.frozen_state = None
                if ctx.verbose:
                    print(
                        "[safe_state] could not freeze checkpoint locals "
                        f"({type(freeze_err).__name__}: {freeze_err}); "
                        "saving without them",
                        file=sys.stderr,
                    )
            ctx.store.save(cp)
            return
    finally:
        # Drop the frame reference to avoid a reference cycle.
        del frame
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def safe_state(
    _fn: Optional[Callable] = None,
    *,
    name: Optional[str] = None,
    state_dir: Union[str, os.PathLike, None] = None,
    iterable_arg: Union[str, int] = 0,
    save_every: int = 1,
    store_results: bool = False,
    keep_on_success: bool = False,
    verbose: bool = False,
    registry: Optional[ReconnectRegistry] = None,
    auto_iterate: bool = True,
) -> Callable:
    """
    Decorator that makes a function resumable.

    Args:
        name: Job name (used as the checkpoint filename). Defaults to the function's
            module name plus qualified name.
        state_dir: Directory for checkpoint files. Defaults to ``.safe_state/`` or
            ``$SAFE_STATE_DIR``.
        iterable_arg: Position (int) or name (str) of the iterable argument that
            should be intercepted and checkpointed. Default: first positional arg.
        save_every: Persist progress every N completed items. Default: 1 (every item).
        store_results: If True, store the iterable's items in the checkpoint so
            results from already-completed items survive across restarts. Note: items
            must be serializable.
        keep_on_success: If True, leave the checkpoint file in place after the
            function completes successfully. Default: delete it.
        verbose: Print progress messages to stderr.
        registry: Optional custom ReconnectRegistry for live-object handling.
        auto_iterate: If True (default), the iterable argument is replaced by a
            generator that skips already-completed items and checkpoints
            progress. If False, the arguments are passed through unchanged;
            use this with manual checkpoint() calls.

    Returns:
        The wrapped function when used bare (``@safe_state``), otherwise a
        decorator. The wrapper also exposes ``checkpoint_path``, ``job_name``,
        ``peek_checkpoint()`` and ``clear_checkpoint()``.

    Raises:
        SafeStateError: at call time, if ``iterable_arg`` is out of range for
            the function's parameters or the named argument was not supplied.

    Usage:

        @safe_state
        def remediate(clients, db):
            for client in clients:
                secure_endpoint(client, db)

        # Or with config:
        @safe_state(name="newsletter-blast", state_dir="state/", save_every=5)
        def remediate(clients, db):
            ...

    On the first run, the checkpoint file is created. On any subsequent run,
    if the previous run crashed, the function resumes from where it stopped.
    Any exception escaping the function (``BaseException`` included) is
    recorded in the checkpoint, the checkpoint is saved, and the exception is
    re-raised.
    """
    state_dir_path = Path(state_dir) if state_dir is not None else DEFAULT_STATE_DIR
    # Explicit None check: a caller-supplied registry that happens to be falsy
    # must not be replaced by the default one.
    reg = registry if registry is not None else get_default_registry()

    def decorate(fn: Callable) -> Callable:
        job_name = name or f"{fn.__module__}.{fn.__qualname__}"
        # Sanitize for filename use.
        safe_name = job_name.replace("/", "_").replace(os.sep, "_")
        path = state_dir_path / f"{safe_name}.safestate"

        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            store = CheckpointStore(path, registry=reg)

            # Load or create the checkpoint.
            if store.exists():
                try:
                    cp = store.load()
                    cp.run_count += 1
                    if verbose:
                        prog = cp.progress()
                        print(
                            f"[safe_state] resuming {job_name!r}: "
                            f"{prog['completed']}/{prog['total']} done "
                            f"(run #{cp.run_count})",
                            file=sys.stderr,
                        )
                except SafeStateError as e:
                    if verbose:
                        print(
                            f"[safe_state] could not load checkpoint ({e}); starting fresh",
                            file=sys.stderr,
                        )
                    cp = Checkpoint(job_name=job_name)
            else:
                cp = Checkpoint(job_name=job_name)
                if verbose:
                    print(f"[safe_state] starting fresh job {job_name!r}", file=sys.stderr)

            ctx = _RunContext(store=store, checkpoint=cp, verbose=verbose)

            # Bind args and intercept the iterable if auto_iterate.
            bound = sig.bind_partial(*args, **kwargs)
            bound.apply_defaults()

            if auto_iterate:
                arg_name: Optional[str]
                if isinstance(iterable_arg, int):
                    params = list(sig.parameters)
                    if iterable_arg >= len(params):
                        raise SafeStateError(
                            f"iterable_arg={iterable_arg} but {fn.__name__} only "
                            f"has {len(params)} parameters"
                        )
                    arg_name = params[iterable_arg]
                else:
                    arg_name = iterable_arg
                if arg_name not in bound.arguments:
                    raise SafeStateError(
                        f"iterable arg {arg_name!r} not provided to {fn.__name__}"
                    )
                original_iter = bound.arguments[arg_name]
                bound.arguments[arg_name] = (
                    item for _, item in _enumerate_safe(
                        original_iter, ctx, save_every, store_results
                    )
                )

            # Register the context only after argument binding succeeded and
            # immediately before the try/finally that removes it, so a binding
            # error cannot leave a stale entry behind.
            code_key = id(fn.__code__)
            _active_contexts[code_key] = ctx
            try:
                result = fn(*bound.args, **bound.kwargs)
                # Success: optionally clear.
                if not keep_on_success:
                    store.clear()
                else:
                    store.save(cp)
                return result
            except BaseException as exc:
                # Capture the failure with locals from the failing frame.
                locals_snapshot = _grab_failure_locals(exc, fn)
                cp.record_failure(
                    index=ctx.current_index if ctx.current_index is not None else -1,
                    exc=exc,
                    locals_snapshot=locals_snapshot,
                )
                store.save(cp)
                if verbose:
                    prog = cp.progress()
                    print(
                        f"[safe_state] {job_name!r} failed at item "
                        f"{ctx.current_index}: {type(exc).__name__}: {exc}. "
                        f"Progress {prog['completed']}/{prog['total']} saved to {path}",
                        file=sys.stderr,
                    )
                raise
            finally:
                _active_contexts.pop(code_key, None)

        wrapper.checkpoint_path = path  # type: ignore[attr-defined]
        wrapper.job_name = job_name  # type: ignore[attr-defined]
        wrapper.peek_checkpoint = lambda: CheckpointStore(path, registry=reg).load() if path.exists() else None  # type: ignore[attr-defined]
        wrapper.clear_checkpoint = lambda: CheckpointStore(path, registry=reg).clear()  # type: ignore[attr-defined]
        return wrapper

    # Bare ``@safe_state`` passes the function positionally; ``@safe_state(...)``
    # leaves _fn as None and expects a decorator back.
    if _fn is not None and callable(_fn):
        return decorate(_fn)
    return decorate
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def _enumerate_safe(
    iterable: Iterable[Any],
    ctx: "_RunContext",
    save_every: int,
    store_results: bool,
) -> Iterator[Tuple[int, Any]]:
    """Internal: yield (idx, item) pairs, skipping completed indices and checkpointing.

    The iterable is materialized into a list first. An index is marked
    complete when the consumer requests the next pair, and the checkpoint is
    saved once ``save_every`` completions have accumulated, plus a final save
    at exhaustion if any are still unsaved. With ``store_results`` the item
    itself is recorded in the checkpoint under its index.
    """
    cp = ctx.checkpoint
    items = list(iterable)
    if cp.total_items is None:
        cp.total_items = len(items)
    since_save = 0
    for idx, item in enumerate(items):
        if idx in cp.completed_indices:
            if ctx.verbose:
                print(f"[safe_state] skip index {idx} (done)", file=sys.stderr)
            continue
        ctx.current_index = idx
        yield idx, item
        # Only reached once the consumer comes back for the next pair, i.e.
        # the item was handled without raising. Pass the item so that
        # store_results records it rather than None.
        cp.mark_complete(idx, result=item, store_result=store_results)
        since_save += 1
        if since_save >= save_every:
            ctx.store.save(cp)
            since_save = 0
    if since_save > 0:
        ctx.store.save(cp)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _grab_failure_locals(exc: BaseException, decorated_fn: Callable) -> Dict[str, Any]:
    """
    Collect local variables from the traceback of *exc*.

    The frame used is the deepest traceback frame running ``decorated_fn``'s
    code object. If no frame matches (or ``decorated_fn`` has no ``__code__``),
    the deepest frame of the traceback is used instead. Names starting with a
    double underscore are skipped; no other filtering is done here, so values
    may still be unserializable.

    Returns:
        A new dict of the selected frame's locals, or ``{}`` when the
        exception carries no traceback.
    """
    # getattr: callables such as builtins or functools.partial objects have no
    # __code__; failing here would mask the exception being recorded.
    wanted_code = getattr(decorated_fn, "__code__", None)

    matched = None
    deepest = None
    tb = exc.__traceback__
    while tb is not None:
        deepest = tb.tb_frame
        if wanted_code is not None and deepest.f_code is wanted_code:
            matched = deepest  # keep overwriting: the last match is the deepest
        tb = tb.tb_next

    target_frame = matched if matched is not None else deepest
    if target_frame is None:
        return {}

    return {
        key: value
        for key, value in target_frame.f_locals.items()
        if not key.startswith("__")
    }
|
safe_state/exceptions.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Custom exceptions for safe-state."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class SafeStateError(Exception):
    """Common base class for the exceptions defined in this module."""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class StateNotFoundError(SafeStateError):
    """Signals an attempt to resume from a checkpoint that does not exist."""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class StateCorruptedError(SafeStateError):
    """Signals that a checkpoint file could not be deserialized."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ReconnectError(SafeStateError):
    """Signals that a live object could not be re-established on resume."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class NoReconnectorError(SafeStateError):
    """Signals a live object for which no reconnect handler is registered."""
|