safe-state 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
safe_state/__init__.py ADDED
@@ -0,0 +1,64 @@
1
+ """
2
+ safe-state: resumable execution for Python.
3
+
4
+ Drop ``@safe_state`` on top of any function that iterates through work. If the
5
+ function crashes — network blip, rate limit, power loss — the checkpoint is on
6
+ disk. Run the script again and it picks up where it left off. Live objects
7
+ (database connections, sockets, file handles) are restored via a reconnect
8
+ registry instead of being naively pickled.
9
+
10
+ Quick start:
11
+
12
+ from safe_state import safe_state
13
+
14
+ @safe_state
15
+ def remediate(clients, db):
16
+ for client in clients:
17
+ secure_endpoint(client, db)
18
+
19
+ remediate(load_clients(), open_db()) # crashes at client 15? just rerun.
20
+ """
21
+
22
+ from .core import safe_state, checkpoint, SafeIterator
23
+ from .checkpoint import Checkpoint, CheckpointStore
24
+ from .reconnect import (
25
+ ReconnectRegistry,
26
+ reconnect_handler,
27
+ register_reconnector,
28
+ get_default_registry,
29
+ )
30
+ from .serialization import freeze_state, thaw_state, can_freeze
31
+ from .exceptions import (
32
+ SafeStateError,
33
+ StateNotFoundError,
34
+ StateCorruptedError,
35
+ ReconnectError,
36
+ NoReconnectorError,
37
+ )
38
+
39
+ __version__ = "0.1.0"
40
+
41
+ __all__ = [
42
+ # decorator + helpers
43
+ "safe_state",
44
+ "checkpoint",
45
+ "SafeIterator",
46
+ # checkpoint store
47
+ "Checkpoint",
48
+ "CheckpointStore",
49
+ # reconnect machinery
50
+ "ReconnectRegistry",
51
+ "reconnect_handler",
52
+ "register_reconnector",
53
+ "get_default_registry",
54
+ # serialization
55
+ "freeze_state",
56
+ "thaw_state",
57
+ "can_freeze",
58
+ # exceptions
59
+ "SafeStateError",
60
+ "StateNotFoundError",
61
+ "StateCorruptedError",
62
+ "ReconnectError",
63
+ "NoReconnectorError",
64
+ ]
safe_state/checkpoint.py ADDED
@@ -0,0 +1,168 @@
1
+ """
2
+ Checkpoint store: on-disk persistence of execution progress.
3
+
4
+ A Checkpoint records:
5
+ - which items in an iteration have completed (by index)
6
+ - results returned for each completed item (optional)
7
+ - the last failure (exception type, message, traceback) if any
8
+ - frozen local state at the point of failure
9
+ - a timestamp and run identifier
10
+
11
+ Checkpoints are stored as a single dill blob per job name, atomically replaced
12
+ via write-temp-then-rename to avoid corruption on crash.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ import tempfile
19
+ import time
20
+ import traceback
21
+ from dataclasses import dataclass, field
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Optional, Set
24
+
25
+ from .exceptions import StateCorruptedError, StateNotFoundError
26
+ from .reconnect import ReconnectRegistry, get_default_registry
27
+ from .serialization import freeze_state, thaw_state
28
+
29
+
30
@dataclass
class Checkpoint:
    """Progress record for one safe-state job.

    Holds the set of finished item indices, optional per-index results, the
    details of the most recent failure, the frozen locals captured at that
    failure, and bookkeeping timestamps / run counter.
    """

    job_name: str
    completed_indices: Set[int] = field(default_factory=set)
    results: Dict[int, Any] = field(default_factory=dict)
    total_items: Optional[int] = None
    last_failure: Optional[Dict[str, Any]] = None
    frozen_state: Optional[bytes] = None  # raw bytes from freeze_state()
    started_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    run_count: int = 1

    def mark_complete(self, index: int, result: Any = None, store_result: bool = False) -> None:
        """Flag ``index`` as done, optionally remembering ``result`` for it."""
        self.completed_indices.add(index)
        if store_result:
            self.results[index] = result
        self.updated_at = time.time()

    def record_failure(self, index: int, exc: BaseException, locals_snapshot: Dict[str, Any]) -> None:
        """Store details of ``exc`` and freeze as much of ``locals_snapshot`` as possible.

        The whole snapshot is frozen in one go when that works. Otherwise each
        entry is probed on its own; entries that cannot be frozen are listed
        under ``last_failure["dropped_locals"]``. If even the reduced snapshot
        fails, ``frozen_state`` is ``None`` and the reason is stored under
        ``last_failure["freeze_error"]``.
        """
        exc_type = type(exc)
        details: Dict[str, Any] = {
            "index": index,
            "exception_type": exc_type.__name__,
            "exception_module": exc_type.__module__,
            "message": str(exc),
            "traceback": "".join(traceback.format_exception(exc_type, exc, exc.__traceback__)),
            "at": time.time(),
        }
        self.last_failure = details

        try:
            self.frozen_state = freeze_state(locals_snapshot)
        except Exception:
            # Probe each local individually and keep only the freezable ones.
            kept: Dict[str, Any] = {}
            skipped: List[str] = []
            for key, value in locals_snapshot.items():
                try:
                    freeze_state({key: value})
                except Exception:
                    skipped.append(key)
                else:
                    kept[key] = value
            try:
                self.frozen_state = freeze_state(kept)
            except Exception as freeze_err:
                self.frozen_state = None
                details["freeze_error"] = f"{type(freeze_err).__name__}: {freeze_err}"
            else:
                if skipped:
                    details["dropped_locals"] = skipped

        self.updated_at = time.time()

    def progress(self) -> Dict[str, Any]:
        """Return a small summary dict; unknown totals are reported as ``"?"``."""
        finished = len(self.completed_indices)
        if self.total_items is None:
            total: Any = "?"
            remaining: Any = "?"
        else:
            total = self.total_items
            remaining = self.total_items - finished
        return {
            "job": self.job_name,
            "completed": finished,
            "total": total,
            "remaining": remaining,
            "has_failure": self.last_failure is not None,
            "run_count": self.run_count,
        }
95
+
96
+
97
class CheckpointStore:
    """File-backed persistence for a single Checkpoint.

    The checkpoint is serialized with ``freeze_state`` and read back with
    ``thaw_state``, both of which receive this store's reconnect registry.
    Writes go to a temporary file in the target directory which is then
    renamed over the destination, so readers never see a half-written file.
    """

    def __init__(self, path: str | os.PathLike, registry: Optional[ReconnectRegistry] = None):
        """Bind the store to ``path``; fall back to the default registry when none is given."""
        self.path = Path(path)
        # Compare against None rather than truthiness so that a caller-supplied
        # registry is never swapped for the default just because it is falsy.
        self.registry = registry if registry is not None else get_default_registry()

    def exists(self) -> bool:
        """Return True when a checkpoint file is present at ``self.path``."""
        return self.path.exists()

    def load(self) -> Checkpoint:
        """Read and rebuild the checkpoint.

        Raises:
            StateNotFoundError: no file exists at ``self.path``.
            StateCorruptedError: the file cannot be read, cannot be thawed, or
                thaws into something that is not a valid checkpoint payload.
        """
        # EAFP: open directly instead of exists()-then-open, which is racy.
        try:
            blob = self.path.read_bytes()
        except FileNotFoundError:
            raise StateNotFoundError(f"No checkpoint at {self.path}") from None
        except OSError as e:
            raise StateCorruptedError(f"Could not load checkpoint {self.path}: {e}") from e

        # Payload validation lives inside the try so that a structurally wrong
        # payload (missing "job_name", not a mapping, non-iterable indices, ...)
        # surfaces as StateCorruptedError instead of a raw KeyError/TypeError.
        try:
            payload = thaw_state(blob, registry=self.registry)
            return Checkpoint(
                job_name=payload["job_name"],
                completed_indices=set(payload.get("completed_indices", [])),
                results=payload.get("results", {}),
                total_items=payload.get("total_items"),
                last_failure=payload.get("last_failure"),
                frozen_state=payload.get("frozen_state"),
                started_at=payload.get("started_at", time.time()),
                updated_at=payload.get("updated_at", time.time()),
                run_count=payload.get("run_count", 1),
            )
        except Exception as e:
            raise StateCorruptedError(f"Could not load checkpoint {self.path}: {e}") from e

    def save(self, cp: Checkpoint) -> None:
        """Serialize ``cp`` and atomically replace the checkpoint file with it."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "job_name": cp.job_name,
            "completed_indices": list(cp.completed_indices),
            "results": cp.results,
            "total_items": cp.total_items,
            "last_failure": cp.last_failure,
            "frozen_state": cp.frozen_state,  # already bytes
            "started_at": cp.started_at,
            "updated_at": cp.updated_at,
            "run_count": cp.run_count,
        }
        blob = freeze_state(payload, registry=self.registry)

        # Atomic write: temp file in the same directory, then rename.
        tmp_fd, tmp_path = tempfile.mkstemp(
            prefix=f".{self.path.name}.", suffix=".tmp", dir=str(self.path.parent)
        )
        try:
            with os.fdopen(tmp_fd, "wb") as f:
                f.write(blob)
                # Push the bytes to disk before the rename; otherwise a power
                # loss can leave the renamed file empty or truncated.
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, self.path)
        except BaseException:
            # Also clean up on KeyboardInterrupt/SystemExit so interrupted
            # saves do not litter the state directory with temp files.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def clear(self) -> None:
        """Delete the checkpoint file; a missing file is not an error."""
        try:
            self.path.unlink()
        except FileNotFoundError:
            pass

    def thaw_locals(self, cp: Checkpoint) -> Optional[Dict[str, Any]]:
        """Rebuild the frozen locals snapshot, reconnecting any live objects.

        Returns ``None`` when the checkpoint carries no frozen state.
        """
        if cp.frozen_state is None:
            return None
        return thaw_state(cp.frozen_state, registry=self.registry)
safe_state/core.py ADDED
@@ -0,0 +1,358 @@
1
+ """
2
+ The `safe_state` decorator and its supporting iterator.
3
+
4
+ The decorator wraps a function that processes an iterable. The first positional
5
+ argument (by default) is intercepted and replaced with a SafeIterator that:
6
+
7
+ 1. Tracks progress through the iteration to disk after every item.
8
+ 2. On exception, saves a checkpoint of completed indices + frozen locals.
9
+ 3. On the next invocation with the same job name, loads the checkpoint
10
+ and skips items that were already completed.
11
+
12
+ A `safe_state.checkpoint()` helper is also provided for use inside the function
13
+ body (manual checkpoint mode).
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import functools
19
+ import inspect
20
+ import os
21
+ import sys
22
+ from pathlib import Path
23
+ from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
24
+
25
+ from .checkpoint import Checkpoint, CheckpointStore
26
+ from .exceptions import SafeStateError, StateNotFoundError
27
+ from .reconnect import ReconnectRegistry, get_default_registry
28
+
29
+
30
+ # Default directory for checkpoints. Override via env var or decorator arg.
31
+ DEFAULT_STATE_DIR = Path(os.environ.get("SAFE_STATE_DIR", ".safe_state"))
32
+
33
+
34
+ # Thread-local-ish current context, exposed so user code can call manual checkpoint().
35
+ # We use a module-level dict keyed by id(fn.__code__) — simple and avoids threading concerns
36
+ # for the common single-threaded use case. For multithreaded use, users should pass
37
+ # explicit CheckpointStore objects.
38
+ _active_contexts: Dict[int, "_RunContext"] = {}
39
+
40
+
41
+ class _RunContext:
42
+ """Per-call context used by the SafeIterator and manual checkpoint()."""
43
+
44
+ def __init__(
45
+ self,
46
+ store: CheckpointStore,
47
+ checkpoint: Checkpoint,
48
+ verbose: bool,
49
+ ):
50
+ self.store = store
51
+ self.checkpoint = checkpoint
52
+ self.verbose = verbose
53
+ self.current_index: Optional[int] = None
54
+ self.current_locals: Dict[str, Any] = {}
55
+
56
+
57
class SafeIterator:
    """
    Wraps an iterable so iteration progress is checkpointed automatically.

    Yields only the items whose index is not already in
    ``checkpoint.completed_indices``. An item is marked complete when the
    consumer asks for the next one (i.e. its loop body finished without
    raising), and the checkpoint is saved every ``save_every`` completions.
    """

    def __init__(
        self,
        iterable: Iterable[Any],
        context: _RunContext,
        save_every: int = 1,
        store_results: bool = False,
    ):
        """
        Args:
            iterable: Work items; consumed eagerly into a list.
            context: Run context providing ``checkpoint``, ``store``,
                ``verbose`` and receiving ``current_index``.
            save_every: Persist after this many completions (clamped to >= 1).
            store_results: When true, each completed item is recorded in the
                checkpoint's ``results`` under its index.
        """
        self._items = list(iterable)  # materialize to get a stable length and indexing
        self._context = context
        self._save_every = max(1, int(save_every))
        self._store_results = store_results
        self._since_save = 0

        # Record total on first run.
        if self._context.checkpoint.total_items is None:
            self._context.checkpoint.total_items = len(self._items)

    def __iter__(self) -> Iterator[Any]:
        ctx = self._context
        cp = ctx.checkpoint
        for idx, item in enumerate(self._items):
            if idx in cp.completed_indices:
                if ctx.verbose:
                    print(f"[safe_state] skip {idx} (already done)", file=sys.stderr)
                continue
            ctx.current_index = idx
            yield item
            # Control returns here only once the consumer finished this item.
            # Pass the item through so store_results records it rather than None.
            cp.mark_complete(idx, result=item, store_result=self._store_results)
            self._since_save += 1
            if self._since_save >= self._save_every:
                ctx.store.save(cp)
                self._since_save = 0
        # Final flush of completions not yet persisted.
        if self._since_save > 0:
            ctx.store.save(cp)
            self._since_save = 0

    def __len__(self) -> int:
        """Total number of wrapped items, including already-completed ones."""
        return len(self._items)
104
+
105
+
106
def checkpoint(**locals_to_save: Any) -> None:
    """
    Manually trigger a checkpoint save from inside a decorated function.

    Pass any local variables you want frozen into the checkpoint:

        @safe_state
        def long_job(tasks, db):
            results = []
            for t in tasks:
                results.append(do(t))
                checkpoint(results=results)  # snapshot progress

    The call stack is searched, innermost frame first, for a frame whose code
    object has an active run context; the first one found receives the
    snapshot and its checkpoint is saved. If freezing the values fails, the
    checkpoint is still saved but without frozen state.

    Has no effect if called outside a safe-state-decorated function.
    """
    frame = inspect.currentframe()
    try:
        while frame is not None:
            ctx = _active_contexts.get(id(frame.f_code))
            if ctx is not None:
                ctx.current_locals = dict(locals_to_save)
                cp = ctx.checkpoint
                from .serialization import freeze_state
                try:
                    # Freeze with the store's registry: the same registry is
                    # used when the snapshot is thawed, so a custom registry
                    # must be honoured on this side as well.
                    cp.frozen_state = freeze_state(
                        locals_to_save, registry=ctx.store.registry
                    )
                except Exception:
                    # Best effort: progress is still persisted below.
                    cp.frozen_state = None
                ctx.store.save(cp)
                return
            frame = frame.f_back
    finally:
        # Drop the frame reference to avoid a reference cycle.
        del frame
140
+
141
+
142
def safe_state(
    _fn: Optional[Callable] = None,
    *,
    name: Optional[str] = None,
    state_dir: Union[str, os.PathLike, None] = None,
    iterable_arg: Union[str, int] = 0,
    save_every: int = 1,
    store_results: bool = False,
    keep_on_success: bool = False,
    verbose: bool = False,
    registry: Optional[ReconnectRegistry] = None,
    auto_iterate: bool = True,
) -> Callable:
    """
    Decorator that makes a function resumable.

    Args:
        name: Job name (used as the checkpoint filename). Defaults to the function's
            module plus qualified name.
        state_dir: Directory for checkpoint files. Defaults to ``.safe_state/`` or
            ``$SAFE_STATE_DIR``.
        iterable_arg: Position (int) or name (str) of the iterable argument that
            should be intercepted and checkpointed. Default: first positional arg.
        save_every: Persist progress every N completed items. Default: 1 (every item).
        store_results: If True, store the iterable's items in the checkpoint so
            they survive across restarts. Note: items must be serializable.
        keep_on_success: If True, leave the checkpoint file in place after the
            function completes successfully. Default: delete it.
        verbose: Print progress messages to stderr.
        registry: Optional custom ReconnectRegistry for live-object handling.
        auto_iterate: If True (default), the decorated function receives a
            checkpointing iterator in place of its iterable argument. If False,
            the arguments are passed through untouched — use this with manual
            checkpoint() calls.

    Raises (at call time):
        SafeStateError: ``iterable_arg`` does not name a parameter of the
            function, or that argument was not supplied.

    Any exception raised by the wrapped function is recorded in the checkpoint
    (together with the locals of the failing frame), the checkpoint is saved,
    and the exception is re-raised unchanged.

    Usage:

        @safe_state
        def remediate(clients, db):
            for client in clients:
                secure_endpoint(client, db)

        # Or with config:
        @safe_state(name="newsletter-blast", state_dir="state/", save_every=5)
        def remediate(clients, db):
            ...

    On the first run, the checkpoint file is created. On any subsequent run,
    if the previous run crashed, the function resumes from where it stopped.

    The returned wrapper also exposes ``checkpoint_path``, ``job_name``,
    ``peek_checkpoint()`` and ``clear_checkpoint()``.
    """
    state_dir_path = Path(state_dir) if state_dir is not None else DEFAULT_STATE_DIR
    # Identity check, not truthiness: a caller-supplied registry must never be
    # replaced by the default just because it evaluates as falsy.
    reg = registry if registry is not None else get_default_registry()

    def decorate(fn: Callable) -> Callable:
        job_name = name or f"{fn.__module__}.{fn.__qualname__}"
        # Sanitize for filename use.
        safe_name = job_name.replace("/", "_").replace(os.sep, "_")
        path = state_dir_path / f"{safe_name}.safestate"

        sig = inspect.signature(fn)
        context_key = id(fn.__code__)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            store = CheckpointStore(path, registry=reg)

            # Load or create the checkpoint.
            if store.exists():
                try:
                    cp = store.load()
                    cp.run_count += 1
                    if verbose:
                        prog = cp.progress()
                        print(
                            f"[safe_state] resuming {job_name!r}: "
                            f"{prog['completed']}/{prog['total']} done "
                            f"(run #{cp.run_count})",
                            file=sys.stderr,
                        )
                except SafeStateError as e:
                    if verbose:
                        print(
                            f"[safe_state] could not load checkpoint ({e}); starting fresh",
                            file=sys.stderr,
                        )
                    cp = Checkpoint(job_name=job_name)
            else:
                cp = Checkpoint(job_name=job_name)
                if verbose:
                    print(f"[safe_state] starting fresh job {job_name!r}", file=sys.stderr)

            ctx = _RunContext(store=store, checkpoint=cp, verbose=verbose)

            # Register the context inside a try/finally that covers argument
            # binding too, so a binding/config error cannot leave a stale
            # context behind. Remember any context already registered for this
            # code object (recursive call) and put it back afterwards.
            previous_ctx = _active_contexts.get(context_key)
            _active_contexts[context_key] = ctx
            try:
                # Bind args and intercept the iterable if auto_iterate.
                bound = sig.bind_partial(*args, **kwargs)
                bound.apply_defaults()

                if auto_iterate:
                    arg_name: Optional[str]
                    if isinstance(iterable_arg, int):
                        params = list(sig.parameters)
                        if iterable_arg >= len(params):
                            raise SafeStateError(
                                f"iterable_arg={iterable_arg} but {fn.__name__} only "
                                f"has {len(params)} parameters"
                            )
                        arg_name = params[iterable_arg]
                    else:
                        arg_name = iterable_arg
                    if arg_name not in bound.arguments:
                        raise SafeStateError(
                            f"iterable arg {arg_name!r} not provided to {fn.__name__}"
                        )
                    original_iter = bound.arguments[arg_name]
                    bound.arguments[arg_name] = (
                        item for _, item in _enumerate_safe(
                            original_iter, ctx, save_every, store_results
                        )
                    )

                try:
                    result = fn(*bound.args, **bound.kwargs)
                    # Success: optionally clear.
                    if not keep_on_success:
                        store.clear()
                    else:
                        store.save(cp)
                    return result
                except BaseException as exc:
                    # Capture the failure with locals from the failing frame.
                    locals_snapshot = _grab_failure_locals(exc, fn)
                    cp.record_failure(
                        index=ctx.current_index if ctx.current_index is not None else -1,
                        exc=exc,
                        locals_snapshot=locals_snapshot,
                    )
                    store.save(cp)
                    if verbose:
                        prog = cp.progress()
                        print(
                            f"[safe_state] {job_name!r} failed at item "
                            f"{ctx.current_index}: {type(exc).__name__}: {exc}. "
                            f"Progress {prog['completed']}/{prog['total']} saved to {path}",
                            file=sys.stderr,
                        )
                    raise
            finally:
                if previous_ctx is None:
                    _active_contexts.pop(context_key, None)
                else:
                    _active_contexts[context_key] = previous_ctx

        wrapper.checkpoint_path = path  # type: ignore[attr-defined]
        wrapper.job_name = job_name  # type: ignore[attr-defined]
        wrapper.peek_checkpoint = lambda: CheckpointStore(path, registry=reg).load() if path.exists() else None  # type: ignore[attr-defined]
        wrapper.clear_checkpoint = lambda: CheckpointStore(path, registry=reg).clear()  # type: ignore[attr-defined]
        return wrapper

    # Support both bare ``@safe_state`` and configured ``@safe_state(...)``.
    if _fn is not None and callable(_fn):
        return decorate(_fn)
    return decorate
300
+
301
+
302
+ def _enumerate_safe(
303
+ iterable: Iterable[Any],
304
+ ctx: _RunContext,
305
+ save_every: int,
306
+ store_results: bool,
307
+ ) -> Iterator[Tuple[int, Any]]:
308
+ """Internal: yield (idx, item) pairs, skipping completed indices and checkpointing."""
309
+ cp = ctx.checkpoint
310
+ items = list(iterable)
311
+ if cp.total_items is None:
312
+ cp.total_items = len(items)
313
+ since_save = 0
314
+ for idx, item in enumerate(items):
315
+ if idx in cp.completed_indices:
316
+ if ctx.verbose:
317
+ print(f"[safe_state] skip index {idx} (done)", file=sys.stderr)
318
+ continue
319
+ ctx.current_index = idx
320
+ yield idx, item
321
+ cp.mark_complete(idx, store_result=store_results)
322
+ since_save += 1
323
+ if since_save >= save_every:
324
+ ctx.store.save(cp)
325
+ since_save = 0
326
+ if since_save > 0:
327
+ ctx.store.save(cp)
328
+
329
+
330
+ def _grab_failure_locals(exc: BaseException, decorated_fn: Callable) -> Dict[str, Any]:
331
+ """
332
+ Pull locals from the frame where the exception was raised, filtered to
333
+ serializable-looking values. Falls back to the decorated function's frame
334
+ if available.
335
+ """
336
+ tb = exc.__traceback__
337
+ target_frame = None
338
+ while tb is not None:
339
+ if tb.tb_frame.f_code is decorated_fn.__code__:
340
+ target_frame = tb.tb_frame
341
+ tb = tb.tb_next
342
+
343
+ if target_frame is None:
344
+ # Use the deepest frame as fallback.
345
+ tb = exc.__traceback__
346
+ while tb is not None:
347
+ target_frame = tb.tb_frame
348
+ tb = tb.tb_next
349
+
350
+ if target_frame is None:
351
+ return {}
352
+
353
+ out: Dict[str, Any] = {}
354
+ for k, v in target_frame.f_locals.items():
355
+ if k.startswith("__"):
356
+ continue
357
+ out[k] = v
358
+ return out
safe_state/exceptions.py ADDED
@@ -0,0 +1,21 @@
1
+ """Custom exceptions for safe-state."""
2
+
3
+
4
class SafeStateError(Exception):
    """Root of the safe-state exception hierarchy; catch this to handle any of them."""


class StateNotFoundError(SafeStateError):
    """No checkpoint exists at the location a resume was attempted from."""


class StateCorruptedError(SafeStateError):
    """A checkpoint file is present but could not be deserialized."""


class ReconnectError(SafeStateError):
    """A live object could not be re-established while resuming."""


class NoReconnectorError(SafeStateError):
    """A live object was encountered for which no reconnect handler is registered."""