checkpointer 2.1.0__py3-none-any.whl → 2.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpointer/__init__.py +13 -3
- checkpointer/checkpoint.py +62 -25
- checkpointer/fn_ident.py +94 -0
- checkpointer/object_hash.py +186 -0
- checkpointer/storages/__init__.py +1 -1
- checkpointer/storages/bcolz_storage.py +6 -7
- checkpointer/storages/memory_storage.py +25 -11
- checkpointer/storages/pickle_storage.py +27 -13
- checkpointer/{types.py → storages/storage.py} +9 -5
- checkpointer/test_checkpointer.py +170 -0
- checkpointer/utils.py +92 -32
- {checkpointer-2.1.0.dist-info → checkpointer-2.5.0.dist-info}/METADATA +18 -6
- checkpointer-2.5.0.dist-info/RECORD +16 -0
- {checkpointer-2.1.0.dist-info → checkpointer-2.5.0.dist-info}/WHEEL +1 -1
- {checkpointer-2.1.0.dist-info → checkpointer-2.5.0.dist-info}/licenses/LICENSE +1 -1
- checkpointer/function_body.py +0 -80
- checkpointer-2.1.0.dist-info/RECORD +0 -14
checkpointer/__init__.py
CHANGED
@@ -1,10 +1,20 @@
|
|
1
|
-
|
2
|
-
from .types import Storage
|
3
|
-
from .function_body import get_function_hash
|
1
|
+
import gc
|
4
2
|
import tempfile
|
3
|
+
from typing import Callable
|
4
|
+
from .checkpoint import Checkpointer, CheckpointError, CheckpointFn
|
5
|
+
from .object_hash import ObjectHash
|
6
|
+
from .storages import MemoryStorage, PickleStorage, Storage
|
5
7
|
|
6
8
|
create_checkpointer = Checkpointer
|
7
9
|
checkpoint = Checkpointer()
|
8
10
|
capture_checkpoint = Checkpointer(capture=True)
|
9
11
|
memory_checkpoint = Checkpointer(format="memory", verbosity=0)
|
10
12
|
tmp_checkpoint = Checkpointer(root_path=tempfile.gettempdir() + "/checkpoints")
|
13
|
+
|
14
|
+
def cleanup_all(invalidated=True, expired=True):
|
15
|
+
for obj in gc.get_objects():
|
16
|
+
if isinstance(obj, CheckpointFn):
|
17
|
+
obj.cleanup(invalidated=invalidated, expired=expired)
|
18
|
+
|
19
|
+
def get_function_hash(fn: Callable, capture=False) -> str:
|
20
|
+
return CheckpointFn(Checkpointer(capture=capture), fn).fn_hash
|
checkpointer/checkpoint.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import inspect
|
3
|
-
import
|
4
|
-
from typing import Generic, TypeVar, Type, TypedDict, Callable, Unpack, Literal, Any, cast, overload
|
5
|
-
from pathlib import Path
|
3
|
+
import re
|
6
4
|
from datetime import datetime
|
7
5
|
from functools import update_wrapper
|
8
|
-
from
|
9
|
-
from
|
10
|
-
from .
|
11
|
-
from .
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Any, Callable, Generic, Iterable, Literal, Type, TypedDict, TypeVar, Unpack, cast, overload
|
8
|
+
from .fn_ident import get_fn_ident
|
9
|
+
from .object_hash import ObjectHash
|
12
10
|
from .print_checkpoint import print_checkpoint
|
11
|
+
from .storages import STORAGE_MAP, Storage
|
12
|
+
from .utils import resolved_awaitable, sync_resolve_coroutine, unwrap_fn
|
13
13
|
|
14
14
|
Fn = TypeVar("Fn", bound=Callable)
|
15
15
|
|
@@ -50,22 +50,47 @@ class Checkpointer:
|
|
50
50
|
|
51
51
|
class CheckpointFn(Generic[Fn]):
|
52
52
|
def __init__(self, checkpointer: Checkpointer, fn: Fn):
|
53
|
-
wrapped = unwrap_fn(fn)
|
54
|
-
file_name = Path(wrapped.__code__.co_filename).name
|
55
|
-
update_wrapper(cast(Callable, self), wrapped)
|
56
|
-
storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
|
57
53
|
self.checkpointer = checkpointer
|
58
54
|
self.fn = fn
|
59
|
-
|
60
|
-
|
55
|
+
|
56
|
+
def _set_ident(self, force=False):
|
57
|
+
if not hasattr(self, "fn_hash_raw") or force:
|
58
|
+
self.fn_hash_raw, self.depends = get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
|
59
|
+
return self
|
60
|
+
|
61
|
+
def _lazyinit(self):
|
62
|
+
wrapped = unwrap_fn(self.fn)
|
63
|
+
fn_file = Path(wrapped.__code__.co_filename).name
|
64
|
+
fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
|
65
|
+
update_wrapper(cast(Callable, self), wrapped)
|
66
|
+
store_format = self.checkpointer.format
|
67
|
+
Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
|
68
|
+
deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
|
69
|
+
self.fn_hash = str(ObjectHash().update_hash(self.fn_hash_raw, iter=deep_hashes))
|
70
|
+
self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
|
61
71
|
self.is_async = inspect.iscoroutinefunction(wrapped)
|
62
|
-
self.storage =
|
72
|
+
self.storage = Storage(self)
|
73
|
+
self.cleanup = self.storage.cleanup
|
74
|
+
|
75
|
+
def __getattribute__(self, name: str) -> Any:
|
76
|
+
return object.__getattribute__(self, "_getattribute")(name)
|
77
|
+
|
78
|
+
def _getattribute(self, name: str) -> Any:
|
79
|
+
setattr(self, "_getattribute", super().__getattribute__)
|
80
|
+
self._lazyinit()
|
81
|
+
return self._getattribute(name)
|
82
|
+
|
83
|
+
def reinit(self, recursive=False):
|
84
|
+
pointfns = list(iterate_checkpoint_fns(self)) if recursive else [self]
|
85
|
+
for pointfn in pointfns:
|
86
|
+
pointfn._set_ident(True)
|
87
|
+
for pointfn in pointfns:
|
88
|
+
pointfn._lazyinit()
|
63
89
|
|
64
90
|
def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
|
65
91
|
if not callable(self.checkpointer.path):
|
66
|
-
|
67
|
-
|
68
|
-
return f"{self.fn_id}/{call_hash}"
|
92
|
+
call_hash = ObjectHash(self.fn_hash, args, kw, digest_size=16)
|
93
|
+
return f"{self.fn_subdir}/{call_hash}"
|
69
94
|
checkpoint_id = self.checkpointer.path(*args, **kw)
|
70
95
|
if not isinstance(checkpoint_id, str):
|
71
96
|
raise CheckpointError(f"path function must return a string, got {type(checkpoint_id)}")
|
@@ -74,13 +99,13 @@ class CheckpointFn(Generic[Fn]):
|
|
74
99
|
async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
|
75
100
|
checkpoint_id = self.get_checkpoint_id(args, kw)
|
76
101
|
checkpoint_path = self.checkpointer.root_path / checkpoint_id
|
77
|
-
|
102
|
+
verbose = self.checkpointer.verbosity > 0
|
78
103
|
refresh = rerun \
|
79
104
|
or not self.storage.exists(checkpoint_path) \
|
80
105
|
or (self.checkpointer.should_expire and self.checkpointer.should_expire(self.storage.checkpoint_date(checkpoint_path)))
|
81
106
|
|
82
107
|
if refresh:
|
83
|
-
print_checkpoint(
|
108
|
+
print_checkpoint(verbose, "MEMORIZING", checkpoint_id, "blue")
|
84
109
|
data = self.fn(*args, **kw)
|
85
110
|
if inspect.iscoroutine(data):
|
86
111
|
data = await data
|
@@ -89,12 +114,12 @@ class CheckpointFn(Generic[Fn]):
|
|
89
114
|
|
90
115
|
try:
|
91
116
|
data = self.storage.load(checkpoint_path)
|
92
|
-
print_checkpoint(
|
117
|
+
print_checkpoint(verbose, "REMEMBERED", checkpoint_id, "green")
|
93
118
|
return data
|
94
119
|
except (EOFError, FileNotFoundError):
|
95
|
-
|
96
|
-
|
97
|
-
|
120
|
+
pass
|
121
|
+
print_checkpoint(verbose, "CORRUPTED", checkpoint_id, "yellow")
|
122
|
+
return await self._store_on_demand(args, kw, True)
|
98
123
|
|
99
124
|
def _call(self, args: tuple, kw: dict, rerun=False):
|
100
125
|
if not self.checkpointer.when:
|
@@ -107,8 +132,8 @@ class CheckpointFn(Generic[Fn]):
|
|
107
132
|
try:
|
108
133
|
val = self.storage.load(checkpoint_path)
|
109
134
|
return resolved_awaitable(val) if self.is_async else val
|
110
|
-
except:
|
111
|
-
raise CheckpointError("Could not load checkpoint")
|
135
|
+
except Exception as ex:
|
136
|
+
raise CheckpointError("Could not load checkpoint") from ex
|
112
137
|
|
113
138
|
def exists(self, *args: tuple, **kw: dict) -> bool:
|
114
139
|
return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
|
@@ -116,3 +141,15 @@ class CheckpointFn(Generic[Fn]):
|
|
116
141
|
__call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
|
117
142
|
rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
|
118
143
|
get: Fn = cast(Fn, lambda self, *args, **kw: self._get(args, kw))
|
144
|
+
|
145
|
+
def __repr__(self) -> str:
|
146
|
+
return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
|
147
|
+
|
148
|
+
def iterate_checkpoint_fns(pointfn: CheckpointFn, visited: set[CheckpointFn] = set()) -> Iterable[CheckpointFn]:
|
149
|
+
visited = visited or set()
|
150
|
+
if pointfn not in visited:
|
151
|
+
yield pointfn
|
152
|
+
visited.add(pointfn)
|
153
|
+
for depend in pointfn.depends:
|
154
|
+
if isinstance(depend, CheckpointFn):
|
155
|
+
yield from iterate_checkpoint_fns(depend, visited)
|
checkpointer/fn_ident.py
ADDED
@@ -0,0 +1,94 @@
|
|
1
|
+
import dis
|
2
|
+
import inspect
|
3
|
+
from collections.abc import Callable
|
4
|
+
from itertools import takewhile
|
5
|
+
from pathlib import Path
|
6
|
+
from types import CodeType, FunctionType, MethodType
|
7
|
+
from typing import Any, Generator, Type, TypeGuard
|
8
|
+
from .object_hash import ObjectHash
|
9
|
+
from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, transpose, unwrap_fn
|
10
|
+
|
11
|
+
cwd = Path.cwd()
|
12
|
+
|
13
|
+
def is_class(obj) -> TypeGuard[Type]:
|
14
|
+
# isinstance works too, but needlessly triggers __getattribute__
|
15
|
+
return issubclass(type(obj), type)
|
16
|
+
|
17
|
+
def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
|
18
|
+
attr_path: tuple[str, ...] = ()
|
19
|
+
scope_obj = None
|
20
|
+
classvars: dict[str, dict[str, Type]] = {}
|
21
|
+
for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
|
22
|
+
if instr.opname in scope_vars and not attr_path:
|
23
|
+
attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
|
24
|
+
attr_path = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
|
25
|
+
elif instr.opname == "CALL":
|
26
|
+
obj = scope_vars.get_at(attr_path)
|
27
|
+
attr_path = ()
|
28
|
+
if is_class(obj):
|
29
|
+
scope_obj = obj
|
30
|
+
elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
|
31
|
+
load_key = instr.opname.replace("STORE", "LOAD")
|
32
|
+
classvars.setdefault(load_key, {})[instr.argval] = scope_obj
|
33
|
+
scope_obj = None
|
34
|
+
return classvars
|
35
|
+
|
36
|
+
def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Generator[tuple[tuple[str, ...], Any], None, None]:
|
37
|
+
classvars = extract_classvars(code, scope_vars)
|
38
|
+
scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
|
39
|
+
for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
|
40
|
+
if instr.opname in scope_vars:
|
41
|
+
attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
|
42
|
+
attr_path: tuple[str, ...] = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
|
43
|
+
val = scope_vars.get_at(attr_path)
|
44
|
+
if val is not None:
|
45
|
+
yield attr_path, val
|
46
|
+
for const in code.co_consts:
|
47
|
+
if isinstance(const, CodeType):
|
48
|
+
yield from extract_scope_values(const, scope_vars)
|
49
|
+
|
50
|
+
def get_self_value(fn: Callable) -> type | object | None:
|
51
|
+
if isinstance(fn, MethodType):
|
52
|
+
return fn.__self__
|
53
|
+
parts = tuple(fn.__qualname__.split(".")[:-1])
|
54
|
+
cls = parts and AttrDict(fn.__globals__).get_at(parts)
|
55
|
+
if is_class(cls):
|
56
|
+
return cls
|
57
|
+
|
58
|
+
def get_fn_captured_vals(fn: Callable) -> list[Any]:
|
59
|
+
self_value = get_self_value(fn)
|
60
|
+
scope_vars = AttrDict({
|
61
|
+
"LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
|
62
|
+
"LOAD_DEREF": AttrDict(get_cell_contents(fn)),
|
63
|
+
"LOAD_GLOBAL": AttrDict(fn.__globals__),
|
64
|
+
})
|
65
|
+
vals = dict(extract_scope_values(fn.__code__, scope_vars))
|
66
|
+
return list(vals.values())
|
67
|
+
|
68
|
+
def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
|
69
|
+
if not isinstance(candidate_fn, (FunctionType, MethodType)):
|
70
|
+
return False
|
71
|
+
fn_path = Path(inspect.getfile(candidate_fn)).resolve()
|
72
|
+
return cwd in fn_path.parents and ".venv" not in fn_path.parts
|
73
|
+
|
74
|
+
def get_depend_fns(fn: Callable, capture: bool, captured_vals_by_fn: dict[Callable, list[Any]] = {}) -> dict[Callable, list[Any]]:
|
75
|
+
from .checkpoint import CheckpointFn
|
76
|
+
captured_vals_by_fn = captured_vals_by_fn or {}
|
77
|
+
captured_vals = get_fn_captured_vals(fn)
|
78
|
+
captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)] * capture
|
79
|
+
child_fns = (unwrap_fn(val, checkpoint_fn=True) for val in captured_vals if callable(val))
|
80
|
+
for child_fn in child_fns:
|
81
|
+
if isinstance(child_fn, CheckpointFn):
|
82
|
+
captured_vals_by_fn[child_fn] = []
|
83
|
+
elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
|
84
|
+
get_depend_fns(child_fn, capture, captured_vals_by_fn)
|
85
|
+
return captured_vals_by_fn
|
86
|
+
|
87
|
+
def get_fn_ident(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
|
88
|
+
from .checkpoint import CheckpointFn
|
89
|
+
captured_vals_by_fn = get_depend_fns(fn, capture)
|
90
|
+
depends, depend_captured_vals = transpose(captured_vals_by_fn.items(), 2)
|
91
|
+
depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
|
92
|
+
unwrapped_depends = [fn for fn in depends if not isinstance(fn, CheckpointFn)]
|
93
|
+
fn_hash = str(ObjectHash(fn, unwrapped_depends).update(depend_captured_vals, tolerate_errors=True))
|
94
|
+
return fn_hash, depends
|
@@ -0,0 +1,186 @@
|
|
1
|
+
import ctypes
|
2
|
+
import hashlib
|
3
|
+
import io
|
4
|
+
import re
|
5
|
+
from collections.abc import Iterable
|
6
|
+
from contextlib import nullcontext
|
7
|
+
from decimal import Decimal
|
8
|
+
from itertools import chain
|
9
|
+
from pickle import HIGHEST_PROTOCOL as PROTOCOL
|
10
|
+
from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
|
11
|
+
from typing import Any, TypeAliasType, TypeVar
|
12
|
+
from .utils import ContextVar, get_fn_body
|
13
|
+
|
14
|
+
try:
|
15
|
+
import numpy as np
|
16
|
+
except:
|
17
|
+
np = None
|
18
|
+
try:
|
19
|
+
import torch
|
20
|
+
except:
|
21
|
+
torch = None
|
22
|
+
|
23
|
+
def encode_type(t: type | FunctionType) -> str:
|
24
|
+
return f"{t.__module__}:{t.__qualname__}"
|
25
|
+
|
26
|
+
def encode_val(v: Any) -> str:
|
27
|
+
return encode_type(type(v))
|
28
|
+
|
29
|
+
class ObjectHashError(Exception):
|
30
|
+
def __init__(self, obj: Any, cause: Exception):
|
31
|
+
super().__init__(f"{type(cause).__name__} error when hashing {obj}")
|
32
|
+
self.obj = obj
|
33
|
+
|
34
|
+
class ObjectHash:
|
35
|
+
def __init__(self, *obj: Any, iter: Iterable[Any] = [], digest_size=64, tolerate_errors=False) -> None:
|
36
|
+
self.hash = hashlib.blake2b(digest_size=digest_size)
|
37
|
+
self.current: dict[int, int] = {}
|
38
|
+
self.tolerate_errors = ContextVar(tolerate_errors)
|
39
|
+
self.update(iter=chain(obj, iter))
|
40
|
+
|
41
|
+
def copy(self) -> "ObjectHash":
|
42
|
+
new = ObjectHash(tolerate_errors=self.tolerate_errors.value)
|
43
|
+
new.hash = self.hash.copy()
|
44
|
+
return new
|
45
|
+
|
46
|
+
def hexdigest(self) -> str:
|
47
|
+
return self.hash.hexdigest()
|
48
|
+
|
49
|
+
__str__ = hexdigest
|
50
|
+
|
51
|
+
def update_hash(self, *data: bytes | str, iter: Iterable[bytes | str] = []) -> "ObjectHash":
|
52
|
+
for d in chain(data, iter):
|
53
|
+
self.hash.update(d.encode() if isinstance(d, str) else d)
|
54
|
+
return self
|
55
|
+
|
56
|
+
def header(self, *args: Any) -> "ObjectHash":
|
57
|
+
return self.update_hash(":".join(map(str, args)))
|
58
|
+
|
59
|
+
def update(self, *objs: Any, iter: Iterable[Any] = [], tolerate_errors: bool | None=None) -> "ObjectHash":
|
60
|
+
with nullcontext() if tolerate_errors is None else self.tolerate_errors.set(tolerate_errors):
|
61
|
+
for obj in chain(objs, iter):
|
62
|
+
try:
|
63
|
+
self._update_one(obj)
|
64
|
+
except Exception as ex:
|
65
|
+
if self.tolerate_errors.value:
|
66
|
+
self.header("error").update(type(ex))
|
67
|
+
continue
|
68
|
+
raise ObjectHashError(obj, ex) from ex
|
69
|
+
return self
|
70
|
+
|
71
|
+
def _update_one(self, obj: Any) -> None:
|
72
|
+
match obj:
|
73
|
+
case None:
|
74
|
+
self.header("null")
|
75
|
+
|
76
|
+
case bool() | int() | float() | complex() | Decimal() | ObjectHash():
|
77
|
+
self.header("number", encode_val(obj), obj)
|
78
|
+
|
79
|
+
case str() | bytes() | bytearray() | memoryview():
|
80
|
+
self.header("bytes", encode_val(obj), len(obj)).update_hash(obj)
|
81
|
+
|
82
|
+
case set() | frozenset():
|
83
|
+
self.header("set", encode_val(obj), len(obj))
|
84
|
+
try:
|
85
|
+
items = sorted(obj)
|
86
|
+
except:
|
87
|
+
self.header("unsortable")
|
88
|
+
items = sorted(str(ObjectHash(item, tolerate_errors=self.tolerate_errors.value)) for item in obj)
|
89
|
+
self.update(iter=items)
|
90
|
+
|
91
|
+
case TypeVar():
|
92
|
+
self.header("TypeVar").update(obj.__name__, obj.__bound__, obj.__constraints__, obj.__contravariant__, obj.__covariant__)
|
93
|
+
|
94
|
+
case TypeAliasType():
|
95
|
+
self.header("TypeAliasType").update(obj.__name__, obj.__value__)
|
96
|
+
|
97
|
+
case UnionType():
|
98
|
+
self.header("UnionType").update(obj.__args__)
|
99
|
+
|
100
|
+
case BuiltinFunctionType():
|
101
|
+
self.header("builtin", obj.__qualname__)
|
102
|
+
|
103
|
+
case FunctionType():
|
104
|
+
self.header("function", encode_type(obj)).update(get_fn_body(obj), obj.__defaults__, obj.__kwdefaults__, obj.__annotations__)
|
105
|
+
|
106
|
+
case MethodType():
|
107
|
+
self.header("method").update(obj.__func__, obj.__self__.__class__)
|
108
|
+
|
109
|
+
case ModuleType():
|
110
|
+
self.header("module", obj.__name__, obj.__file__)
|
111
|
+
|
112
|
+
case GeneratorType():
|
113
|
+
self.header("generator", obj.__qualname__)._update_iterator(obj)
|
114
|
+
|
115
|
+
case io.TextIOWrapper() | io.FileIO() | io.BufferedRandom() | io.BufferedWriter() | io.BufferedReader():
|
116
|
+
self.header("file", encode_val(obj)).update(obj.name, obj.mode, obj.tell())
|
117
|
+
|
118
|
+
case type():
|
119
|
+
self.header("type", encode_type(obj))
|
120
|
+
|
121
|
+
case _ if np and isinstance(obj, np.dtype):
|
122
|
+
self.header("dtype").update(obj.__class__, obj.descr)
|
123
|
+
|
124
|
+
case _ if np and isinstance(obj, np.ndarray):
|
125
|
+
self.header("ndarray", encode_val(obj), obj.shape, obj.strides).update(obj.dtype)
|
126
|
+
if obj.dtype.hasobject:
|
127
|
+
self.update(obj.__reduce_ex__(PROTOCOL))
|
128
|
+
else:
|
129
|
+
array = np.ascontiguousarray(obj if obj.base is None else obj.base).view(np.uint8)
|
130
|
+
self.update_hash(array.data)
|
131
|
+
|
132
|
+
case _ if torch and isinstance(obj, torch.Tensor):
|
133
|
+
self.header("tensor", encode_val(obj), str(obj.dtype), tuple(obj.shape), obj.stride(), str(obj.device))
|
134
|
+
if obj.device.type != "cpu":
|
135
|
+
obj = obj.cpu()
|
136
|
+
storage = obj.storage()
|
137
|
+
buffer = (ctypes.c_ubyte * (storage.nbytes())).from_address(storage.data_ptr())
|
138
|
+
self.update_hash(memoryview(buffer))
|
139
|
+
|
140
|
+
case _ if id(obj) in self.current:
|
141
|
+
self.header("circular", self.current[id(obj)])
|
142
|
+
|
143
|
+
case _:
|
144
|
+
try:
|
145
|
+
self.current[id(obj)] = len(self.current)
|
146
|
+
match obj:
|
147
|
+
case list() | tuple():
|
148
|
+
self.header("list", encode_val(obj), len(obj)).update(iter=obj)
|
149
|
+
case dict():
|
150
|
+
try:
|
151
|
+
items = sorted(obj.items())
|
152
|
+
except:
|
153
|
+
items = sorted((str(ObjectHash(key, tolerate_errors=self.tolerate_errors.value)), val) for key, val in obj.items())
|
154
|
+
self.header("dict", encode_val(obj), len(obj)).update(iter=chain.from_iterable(items))
|
155
|
+
case _:
|
156
|
+
self._update_object(obj)
|
157
|
+
finally:
|
158
|
+
del self.current[id(obj)]
|
159
|
+
|
160
|
+
def _update_iterator(self, obj: Iterable) -> None:
|
161
|
+
self.header("iterator", encode_val(obj)).update(iter=obj).header(b"iterator-end")
|
162
|
+
|
163
|
+
def _update_object(self, obj: object) -> "ObjectHash":
|
164
|
+
self.header("instance", encode_val(obj))
|
165
|
+
try:
|
166
|
+
reduced = obj.__reduce_ex__(PROTOCOL) if hasattr(obj, "__reduce_ex__") else obj.__reduce__()
|
167
|
+
except:
|
168
|
+
reduced = None
|
169
|
+
if isinstance(reduced, str):
|
170
|
+
return self.header("reduce-str").update(reduced)
|
171
|
+
if reduced:
|
172
|
+
reduced = list(reduced)
|
173
|
+
it = reduced.pop(3) if len(reduced) >= 4 else None
|
174
|
+
self.header("reduce").update(reduced)
|
175
|
+
if it is not None:
|
176
|
+
self._update_iterator(it)
|
177
|
+
return self
|
178
|
+
if state := hasattr(obj, "__getstate__") and obj.__getstate__():
|
179
|
+
return self.header("getstate").update(state)
|
180
|
+
if len(getattr(obj, "__slots__", [])):
|
181
|
+
slots = {slot: getattr(obj, slot, None) for slot in getattr(obj, "__slots__")}
|
182
|
+
return self.header("slots").update(slots)
|
183
|
+
if d := getattr(obj, "__dict__", {}):
|
184
|
+
return self.header("dict").update(d)
|
185
|
+
repr_str = re.sub(r"\s+(at\s+0x[0-9a-fA-F]+)(>)$", r"\2", repr(obj))
|
186
|
+
return self.header("repr").update(repr_str)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import shutil
|
2
2
|
from pathlib import Path
|
3
3
|
from datetime import datetime
|
4
|
-
from
|
4
|
+
from .storage import Storage
|
5
5
|
|
6
6
|
def get_data_type_str(x):
|
7
7
|
if isinstance(x, tuple):
|
@@ -73,9 +73,8 @@ class BcolzStorage(Storage):
|
|
73
73
|
|
74
74
|
def delete(self, path):
|
75
75
|
# NOTE: Not recursive
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
pass
|
76
|
+
shutil.rmtree(get_metapath(path), ignore_errors=True)
|
77
|
+
shutil.rmtree(path, ignore_errors=True)
|
78
|
+
|
79
|
+
def cleanup(self, invalidated=True, expired=True):
|
80
|
+
raise NotImplementedError("cleanup() not implemented for bcolz storage")
|
@@ -1,25 +1,39 @@
|
|
1
1
|
from typing import Any
|
2
2
|
from pathlib import Path
|
3
3
|
from datetime import datetime
|
4
|
-
from
|
4
|
+
from .storage import Storage
|
5
5
|
|
6
|
-
item_map: dict[str, tuple[datetime, Any]] = {}
|
6
|
+
item_map: dict[Path, dict[str, tuple[datetime, Any]]] = {}
|
7
|
+
|
8
|
+
def get_short_path(path: Path):
|
9
|
+
return path.parts[-1]
|
7
10
|
|
8
11
|
class MemoryStorage(Storage):
|
9
|
-
def
|
10
|
-
return
|
12
|
+
def get_dict(self):
|
13
|
+
return item_map.setdefault(self.checkpointer.root_path / self.checkpoint_fn.fn_subdir, {})
|
14
|
+
|
15
|
+
def store(self, path, data):
|
16
|
+
self.get_dict()[get_short_path(path)] = (datetime.now(), data)
|
11
17
|
|
12
18
|
def exists(self, path):
|
13
|
-
return
|
19
|
+
return get_short_path(path) in self.get_dict()
|
14
20
|
|
15
21
|
def checkpoint_date(self, path):
|
16
|
-
return
|
17
|
-
|
18
|
-
def store(self, path, data):
|
19
|
-
item_map[self.get_short_path(path)] = (datetime.now(), data)
|
22
|
+
return self.get_dict()[get_short_path(path)][0]
|
20
23
|
|
21
24
|
def load(self, path):
|
22
|
-
return
|
25
|
+
return self.get_dict()[get_short_path(path)][1]
|
23
26
|
|
24
27
|
def delete(self, path):
|
25
|
-
del
|
28
|
+
del self.get_dict()[get_short_path(path)]
|
29
|
+
|
30
|
+
def cleanup(self, invalidated=True, expired=True):
|
31
|
+
curr_key = self.checkpointer.root_path / self.checkpoint_fn.fn_subdir
|
32
|
+
for key, calldict in list(item_map.items()):
|
33
|
+
if key.parent == curr_key.parent:
|
34
|
+
if invalidated and key != curr_key:
|
35
|
+
del item_map[key]
|
36
|
+
elif expired and self.checkpointer.should_expire:
|
37
|
+
for callid, (date, _) in list(calldict.items()):
|
38
|
+
if self.checkpointer.should_expire(date):
|
39
|
+
del calldict[callid]
|
@@ -1,31 +1,45 @@
|
|
1
1
|
import pickle
|
2
|
+
import shutil
|
2
3
|
from pathlib import Path
|
3
4
|
from datetime import datetime
|
4
|
-
from
|
5
|
+
from .storage import Storage
|
5
6
|
|
6
7
|
def get_path(path: Path):
|
7
8
|
return path.with_name(f"{path.name}.pkl")
|
8
9
|
|
9
10
|
class PickleStorage(Storage):
|
10
|
-
def exists(self, path):
|
11
|
-
return get_path(path).exists()
|
12
|
-
|
13
|
-
def checkpoint_date(self, path):
|
14
|
-
return datetime.fromtimestamp(get_path(path).stat().st_mtime)
|
15
|
-
|
16
11
|
def store(self, path, data):
|
17
12
|
full_path = get_path(path)
|
18
13
|
full_path.parent.mkdir(parents=True, exist_ok=True)
|
19
14
|
with full_path.open("wb") as file:
|
20
15
|
pickle.dump(data, file, -1)
|
21
16
|
|
17
|
+
def exists(self, path):
|
18
|
+
return get_path(path).exists()
|
19
|
+
|
20
|
+
def checkpoint_date(self, path):
|
21
|
+
return datetime.fromtimestamp(get_path(path).stat().st_mtime)
|
22
|
+
|
22
23
|
def load(self, path):
|
23
|
-
|
24
|
-
with full_path.open("rb") as file:
|
24
|
+
with get_path(path).open("rb") as file:
|
25
25
|
return pickle.load(file)
|
26
26
|
|
27
27
|
def delete(self, path):
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
28
|
+
get_path(path).unlink(missing_ok=True)
|
29
|
+
|
30
|
+
def cleanup(self, invalidated=True, expired=True):
|
31
|
+
version_path = self.checkpointer.root_path.resolve() / self.checkpoint_fn.fn_subdir
|
32
|
+
fn_path = version_path.parent
|
33
|
+
if invalidated:
|
34
|
+
old_dirs = [path for path in fn_path.iterdir() if path.is_dir() and path != version_path]
|
35
|
+
for path in old_dirs:
|
36
|
+
shutil.rmtree(path)
|
37
|
+
print(f"Removed {len(old_dirs)} invalidated directories for {self.checkpoint_fn.__qualname__}")
|
38
|
+
if expired and self.checkpointer.should_expire:
|
39
|
+
count = 0
|
40
|
+
for pkl_path in fn_path.rglob("*.pkl"):
|
41
|
+
path = pkl_path.with_suffix("")
|
42
|
+
if self.checkpointer.should_expire(self.checkpoint_date(path)):
|
43
|
+
count += 1
|
44
|
+
self.delete(path)
|
45
|
+
print(f"Removed {count} expired checkpoints for {self.checkpoint_fn.__qualname__}")
|
@@ -4,20 +4,24 @@ from pathlib import Path
|
|
4
4
|
from datetime import datetime
|
5
5
|
|
6
6
|
if TYPE_CHECKING:
|
7
|
-
from
|
7
|
+
from ..checkpoint import Checkpointer, CheckpointFn
|
8
8
|
|
9
9
|
class Storage:
|
10
10
|
checkpointer: Checkpointer
|
11
|
+
checkpoint_fn: CheckpointFn
|
11
12
|
|
12
|
-
def __init__(self,
|
13
|
-
self.checkpointer = checkpointer
|
13
|
+
def __init__(self, checkpoint_fn: CheckpointFn):
|
14
|
+
self.checkpointer = checkpoint_fn.checkpointer
|
15
|
+
self.checkpoint_fn = checkpoint_fn
|
16
|
+
|
17
|
+
def store(self, path: Path, data: Any) -> None: ...
|
14
18
|
|
15
19
|
def exists(self, path: Path) -> bool: ...
|
16
20
|
|
17
21
|
def checkpoint_date(self, path: Path) -> datetime: ...
|
18
22
|
|
19
|
-
def store(self, path: Path, data: Any) -> None: ...
|
20
|
-
|
21
23
|
def load(self, path: Path) -> Any: ...
|
22
24
|
|
23
25
|
def delete(self, path: Path) -> None: ...
|
26
|
+
|
27
|
+
def cleanup(self, invalidated=True, expired=True): ...
|
@@ -0,0 +1,170 @@
|
|
1
|
+
import asyncio
|
2
|
+
import pytest
|
3
|
+
from riprint import riprint as print
|
4
|
+
from types import MethodType, MethodWrapperType
|
5
|
+
from . import checkpoint
|
6
|
+
from .checkpoint import CheckpointError
|
7
|
+
from .utils import AttrDict
|
8
|
+
|
9
|
+
def global_multiply(a, b):
|
10
|
+
return a * b
|
11
|
+
|
12
|
+
@pytest.fixture(autouse=True)
|
13
|
+
def run_before_and_after_tests(tmpdir):
|
14
|
+
global checkpoint
|
15
|
+
checkpoint = checkpoint(root_path=tmpdir)
|
16
|
+
yield
|
17
|
+
|
18
|
+
def test_basic_caching():
|
19
|
+
@checkpoint
|
20
|
+
def square(x: int) -> int:
|
21
|
+
return x ** 2
|
22
|
+
|
23
|
+
result1 = square(4)
|
24
|
+
result2 = square(4)
|
25
|
+
|
26
|
+
assert result1 == result2 == 16
|
27
|
+
|
28
|
+
def test_cache_invalidation():
|
29
|
+
@checkpoint
|
30
|
+
def multiply(a, b):
|
31
|
+
return a * b
|
32
|
+
|
33
|
+
@checkpoint
|
34
|
+
def helper(x):
|
35
|
+
return multiply(x + 1, 2)
|
36
|
+
|
37
|
+
@checkpoint
|
38
|
+
def compute(a, b):
|
39
|
+
return helper(a) + helper(b)
|
40
|
+
|
41
|
+
result1 = compute(3, 4)
|
42
|
+
assert result1 == 18
|
43
|
+
|
44
|
+
def test_layered_caching():
|
45
|
+
dev_checkpoint = checkpoint(when=True)
|
46
|
+
|
47
|
+
@checkpoint(format="memory")
|
48
|
+
@dev_checkpoint
|
49
|
+
def expensive_function(x):
|
50
|
+
return x ** 2
|
51
|
+
|
52
|
+
assert expensive_function(4) == 16
|
53
|
+
assert expensive_function(4) == 16
|
54
|
+
|
55
|
+
def test_recursive_caching1():
|
56
|
+
@checkpoint
|
57
|
+
def fib(n: int) -> int:
|
58
|
+
return fib(n - 1) + fib(n - 2) if n > 1 else n
|
59
|
+
|
60
|
+
assert fib(10) == 55
|
61
|
+
assert fib.get(10) == 55
|
62
|
+
assert fib.get(5) == 5
|
63
|
+
|
64
|
+
def test_recursive_caching2():
|
65
|
+
@checkpoint
|
66
|
+
def fib(n: int) -> int:
|
67
|
+
return fib.fn(n - 1) + fib.fn(n - 2) if n > 1 else n
|
68
|
+
|
69
|
+
assert fib(10) == 55
|
70
|
+
assert fib.get(10) == 55
|
71
|
+
with pytest.raises(CheckpointError):
|
72
|
+
fib.get(5)
|
73
|
+
|
74
|
+
@pytest.mark.asyncio
|
75
|
+
async def test_async_caching():
|
76
|
+
@checkpoint(format="memory")
|
77
|
+
async def async_square(x: int) -> int:
|
78
|
+
await asyncio.sleep(0.1)
|
79
|
+
return x ** 2
|
80
|
+
|
81
|
+
result1 = await async_square(3)
|
82
|
+
result2 = await async_square.get(3)
|
83
|
+
|
84
|
+
assert result1 == result2 == 9
|
85
|
+
|
86
|
+
def test_custom_path_caching():
|
87
|
+
def custom_path(a, b):
|
88
|
+
return f"add/{a}-{b}"
|
89
|
+
|
90
|
+
@checkpoint(path=custom_path)
|
91
|
+
def add(a, b):
|
92
|
+
return a + b
|
93
|
+
|
94
|
+
add(3, 4)
|
95
|
+
assert (checkpoint.root_path / "add/3-4.pkl").exists()
|
96
|
+
|
97
|
+
def test_force_recalculation():
|
98
|
+
@checkpoint
|
99
|
+
def square(x: int) -> int:
|
100
|
+
return x ** 2
|
101
|
+
|
102
|
+
assert square(5) == 25
|
103
|
+
square.rerun(5)
|
104
|
+
assert square.get(5) == 25
|
105
|
+
|
106
|
+
def test_multi_layer_decorator():
|
107
|
+
@checkpoint(format="memory")
|
108
|
+
@checkpoint(format="pickle")
|
109
|
+
def add(a, b):
|
110
|
+
return a + b
|
111
|
+
|
112
|
+
assert add(2, 3) == 5
|
113
|
+
assert add.get(2, 3) == 5
|
114
|
+
|
115
|
+
def test_capture():
|
116
|
+
item_dict = AttrDict({"a": 1, "b": 1})
|
117
|
+
|
118
|
+
@checkpoint(capture=True)
|
119
|
+
def test_whole():
|
120
|
+
return item_dict
|
121
|
+
|
122
|
+
@checkpoint(capture=True)
|
123
|
+
def test_a():
|
124
|
+
return item_dict.a + 1
|
125
|
+
|
126
|
+
init_hash_a = test_a.fn_hash
|
127
|
+
init_hash_whole = test_whole.fn_hash
|
128
|
+
item_dict.b += 1
|
129
|
+
test_whole.reinit()
|
130
|
+
test_a.reinit()
|
131
|
+
assert test_whole.fn_hash != init_hash_whole
|
132
|
+
assert test_a.fn_hash == init_hash_a
|
133
|
+
item_dict.a += 1
|
134
|
+
test_a.reinit()
|
135
|
+
assert test_a.fn_hash != init_hash_a
|
136
|
+
|
137
|
+
def test_depends():
|
138
|
+
def multiply_wrapper(a, b):
|
139
|
+
return global_multiply(a, b)
|
140
|
+
|
141
|
+
def helper(a, b):
|
142
|
+
return multiply_wrapper(a + 1, b + 1)
|
143
|
+
|
144
|
+
@checkpoint
|
145
|
+
def test_a(a, b):
|
146
|
+
return helper(a, b)
|
147
|
+
|
148
|
+
@checkpoint
|
149
|
+
def test_b(a, b):
|
150
|
+
return test_a(a, b) + multiply_wrapper(a, b)
|
151
|
+
|
152
|
+
assert set(test_a.depends) == {test_a.fn, helper, multiply_wrapper, global_multiply}
|
153
|
+
assert set(test_b.depends) == {test_b.fn, test_a, multiply_wrapper, global_multiply}
|
154
|
+
|
155
|
+
def test_lazy_init():
|
156
|
+
@checkpoint
|
157
|
+
def fn1(x):
|
158
|
+
return fn2(x)
|
159
|
+
|
160
|
+
@checkpoint
|
161
|
+
def fn2(x):
|
162
|
+
return fn1(x)
|
163
|
+
|
164
|
+
assert type(object.__getattribute__(fn1, "_getattribute")) == MethodType
|
165
|
+
with pytest.raises(AttributeError):
|
166
|
+
object.__getattribute__(fn1, "fn_hash")
|
167
|
+
assert fn1.fn_hash == object.__getattribute__(fn1, "fn_hash")
|
168
|
+
assert type(object.__getattribute__(fn1, "_getattribute")) == MethodWrapperType
|
169
|
+
assert set(fn1.depends) == {fn1.fn, fn2}
|
170
|
+
assert set(fn2.depends) == {fn1, fn2.fn}
|
checkpointer/utils.py
CHANGED
@@ -1,22 +1,44 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
from
|
1
|
+
import inspect
|
2
|
+
import tokenize
|
3
|
+
from contextlib import contextmanager
|
4
|
+
from io import StringIO
|
5
|
+
from types import coroutine
|
6
|
+
from typing import Any, Callable, Coroutine, Generator, Iterable, cast
|
4
7
|
|
5
|
-
|
6
|
-
|
7
|
-
super(AttrDict, self).__init__(*args, **kwargs)
|
8
|
-
self.__dict__ = self
|
8
|
+
def distinct[T](seq: Iterable[T]) -> list[T]:
|
9
|
+
return list(dict.fromkeys(seq))
|
9
10
|
|
10
|
-
|
11
|
-
|
11
|
+
def transpose(tuples, default_num_returns=0):
|
12
|
+
output = tuple(zip(*tuples))
|
13
|
+
if not output:
|
14
|
+
return ([],) * default_num_returns
|
15
|
+
return tuple(map(list, output))
|
16
|
+
|
17
|
+
def get_fn_body(fn: Callable) -> str:
|
18
|
+
try:
|
19
|
+
source = inspect.getsource(fn)
|
20
|
+
except OSError:
|
21
|
+
return ""
|
22
|
+
tokens = tokenize.generate_tokens(StringIO(source).readline)
|
23
|
+
ignore_types = (tokenize.COMMENT, tokenize.NL)
|
24
|
+
return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
|
12
25
|
|
13
|
-
def
|
26
|
+
def get_cell_contents(fn: Callable) -> Generator[tuple[str, Any], None, None]:
|
27
|
+
for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
|
28
|
+
try:
|
29
|
+
yield (key, cell.cell_contents)
|
30
|
+
except ValueError:
|
31
|
+
pass
|
32
|
+
|
33
|
+
def unwrap_fn[T: Callable](fn: T, checkpoint_fn=False) -> T:
|
14
34
|
from .checkpoint import CheckpointFn
|
15
|
-
while
|
16
|
-
if checkpoint_fn and isinstance(fn, CheckpointFn):
|
17
|
-
return fn
|
35
|
+
while True:
|
36
|
+
if (checkpoint_fn and isinstance(fn, CheckpointFn)) or not hasattr(fn, "__wrapped__"):
|
37
|
+
return cast(T, fn)
|
18
38
|
fn = getattr(fn, "__wrapped__")
|
19
|
-
|
39
|
+
|
40
|
+
async def resolved_awaitable[T](value: T) -> T:
|
41
|
+
return value
|
20
42
|
|
21
43
|
@coroutine
|
22
44
|
def coroutine_as_generator[T](coroutine: Coroutine[None, None, T]) -> Generator[None, None, T]:
|
@@ -26,27 +48,65 @@ def coroutine_as_generator[T](coroutine: Coroutine[None, None, T]) -> Generator[
|
|
26
48
|
def sync_resolve_coroutine[T](coroutine: Coroutine[None, None, T]) -> T:
|
27
49
|
gen = cast(Generator, coroutine_as_generator(coroutine))
|
28
50
|
try:
|
29
|
-
while True:
|
51
|
+
while True:
|
52
|
+
next(gen)
|
30
53
|
except StopIteration as ex:
|
31
54
|
return ex.value
|
32
55
|
|
33
|
-
|
34
|
-
|
56
|
+
class AttrDict(dict):
|
57
|
+
def __init__(self, *args, **kwargs):
|
58
|
+
super().__init__(*args, **kwargs)
|
59
|
+
self.__dict__ = self
|
35
60
|
|
36
|
-
def
|
37
|
-
|
38
|
-
yield item, islice(l, i + 1, None)
|
61
|
+
def __getattribute__(self, name: str) -> Any:
|
62
|
+
return super().__getattribute__(name)
|
39
63
|
|
40
|
-
def
|
41
|
-
|
42
|
-
for key in keys:
|
43
|
-
d = getattr(d, key)
|
44
|
-
except AttributeError:
|
45
|
-
return None
|
46
|
-
return d
|
64
|
+
def __setattr__(self, name: str, value: Any) -> None:
|
65
|
+
return super().__setattr__(name, value)
|
47
66
|
|
48
|
-
def
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
67
|
+
def set(self, d: dict) -> "AttrDict":
|
68
|
+
if not d:
|
69
|
+
return self
|
70
|
+
return AttrDict({**self, **d})
|
71
|
+
|
72
|
+
def delete(self, *attrs: str) -> "AttrDict":
|
73
|
+
d = AttrDict(self)
|
74
|
+
for attr in attrs:
|
75
|
+
del d[attr]
|
76
|
+
return d
|
77
|
+
|
78
|
+
def get_at(self, attrs: tuple[str, ...]) -> Any:
|
79
|
+
d = self
|
80
|
+
for attr in attrs:
|
81
|
+
d = getattr(d, attr, None)
|
82
|
+
return d
|
83
|
+
|
84
|
+
class ContextVar[T]:
|
85
|
+
def __init__(self, value: T):
|
86
|
+
self.value = value
|
87
|
+
|
88
|
+
@contextmanager
|
89
|
+
def set(self, value: T):
|
90
|
+
self.value, old = value, self.value
|
91
|
+
try:
|
92
|
+
yield
|
93
|
+
finally:
|
94
|
+
self.value = old
|
95
|
+
|
96
|
+
class iterate_and_upcoming[T]:
|
97
|
+
def __init__(self, it: Iterable[T]) -> None:
|
98
|
+
self.it = iter(it)
|
99
|
+
self.previous: tuple[()] | tuple[T] = ()
|
100
|
+
|
101
|
+
def __iter__(self):
|
102
|
+
return self
|
103
|
+
|
104
|
+
def __next__(self) -> tuple[T, Iterable[T]]:
|
105
|
+
item = self.previous[0] if self.previous else next(self.it)
|
106
|
+
self.previous = ()
|
107
|
+
return item, self._tracked_iter()
|
108
|
+
|
109
|
+
def _tracked_iter(self):
|
110
|
+
for x in self.it:
|
111
|
+
self.previous = (x,)
|
112
|
+
yield x
|
@@ -1,18 +1,18 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: checkpointer
|
3
|
-
Version: 2.
|
3
|
+
Version: 2.5.0
|
4
4
|
Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
|
5
5
|
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
6
|
Author: Hampus Hallman
|
7
|
-
License: Copyright
|
7
|
+
License: Copyright 2018-2025 Hampus Hallman
|
8
8
|
|
9
9
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
10
10
|
|
11
11
|
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
12
12
|
|
13
13
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
14
|
+
License-File: LICENSE
|
14
15
|
Requires-Python: >=3.12
|
15
|
-
Requires-Dist: relib
|
16
16
|
Description-Content-Type: text/markdown
|
17
17
|
|
18
18
|
# checkpointer · [](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [](https://pypi.org/project/checkpointer/) [](https://pypi.org/project/checkpointer/)
|
@@ -108,7 +108,7 @@ Layer caches by stacking checkpoints:
|
|
108
108
|
@dev_checkpoint # Adds caching during development
|
109
109
|
def some_expensive_function():
|
110
110
|
print("Performing a time-consuming operation...")
|
111
|
-
return sum(i * i for i in range(10**
|
111
|
+
return sum(i * i for i in range(10**8))
|
112
112
|
```
|
113
113
|
|
114
114
|
- **In development**: Both `dev_checkpoint` and `memory` caches are active.
|
@@ -153,6 +153,18 @@ Access cached results without recalculating:
|
|
153
153
|
stored_result = expensive_function.get(4)
|
154
154
|
```
|
155
155
|
|
156
|
+
### Refresh Function Hash
|
157
|
+
|
158
|
+
When using `capture=True`, changes to captured variables are included in the function's hash to determine cache invalidation. However, `checkpointer` does not automatically update this hash during a running Python session—it recalculates between sessions or when you explicitly refresh it.
|
159
|
+
|
160
|
+
Use the `reinit` method to manually refresh the function's hash within the same session:
|
161
|
+
|
162
|
+
```python
|
163
|
+
expensive_function.reinit()
|
164
|
+
```
|
165
|
+
|
166
|
+
This forces `checkpointer` to recalculate the hash of `expensive_function`, considering any changes to captured variables. It's useful when you've modified external variables that the function depends on and want the cache to reflect these changes immediately.
|
167
|
+
|
156
168
|
---
|
157
169
|
|
158
170
|
## Storage Backends
|
@@ -189,9 +201,9 @@ from checkpointer import checkpoint, Storage
|
|
189
201
|
from datetime import datetime
|
190
202
|
|
191
203
|
class CustomStorage(Storage):
|
204
|
+
def store(self, path, data): ... # Save the checkpoint data
|
192
205
|
def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
|
193
206
|
def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
|
194
|
-
def store(self, path, data): ... # Save the checkpoint data
|
195
207
|
def load(self, path): ... # Return the checkpoint data
|
196
208
|
def delete(self, path): ... # Delete the checkpoint
|
197
209
|
|
@@ -0,0 +1,16 @@
|
|
1
|
+
checkpointer/__init__.py,sha256=ZJ6frUNgkklUi85b5uXTyTfRzMvZgQOJY-ZOnu7jh78,777
|
2
|
+
checkpointer/checkpoint.py,sha256=FeizwZf0r6j_xy8EOyDKXqcfCNNZnUYBziVbxPu9kwE,6284
|
3
|
+
checkpointer/fn_ident.py,sha256=_GyIfoUvEpjZ4dUAa04NK4pJdSmDAXuufjt7z2xQP8w,4316
|
4
|
+
checkpointer/object_hash.py,sha256=ekpKXtbKtHBl5e9s-uyng8qOHTFl9CCT6QHlQTZTQn8,6860
|
5
|
+
checkpointer/print_checkpoint.py,sha256=21aeqgM9CMjNAJyScqFmXCWWfh3jBIn7o7i5zJkZGaA,1369
|
6
|
+
checkpointer/test_checkpointer.py,sha256=qpG_p4DVMlOBxt71v93-GTKve08EQaExFAf6xSv0wUg,3821
|
7
|
+
checkpointer/utils.py,sha256=Rvm2NaJHtPTusM7fyHz_w9HUy_fqQfx8S1fr5CBWGL0,3047
|
8
|
+
checkpointer/storages/__init__.py,sha256=Kl4Og5jhYxn6m3tB_kTMsabf4_eWVLmFVAoC-pikNQE,301
|
9
|
+
checkpointer/storages/bcolz_storage.py,sha256=3QkSUSeG5s2kFuVV_LZpzMn1A5E7kqC7jk7w35c0NyQ,2314
|
10
|
+
checkpointer/storages/memory_storage.py,sha256=S5ayOZE_CyaFQJ-vSgObTanldPzG3gh3NksjNAc7vsk,1282
|
11
|
+
checkpointer/storages/pickle_storage.py,sha256=lJ0ton9ib3eifiny8XtPSNsx-w4Cm8oYUlbmKob34xU,1554
|
12
|
+
checkpointer/storages/storage.py,sha256=_m18Z8TKrdAbi6YYYQmuNOnhna4RB2sJDn1v3liaU3U,721
|
13
|
+
checkpointer-2.5.0.dist-info/METADATA,sha256=yd3D8zAWdCe4011MLBwknwDvSYx3LGGSmk6AzDXmMUg,10647
|
14
|
+
checkpointer-2.5.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
15
|
+
checkpointer-2.5.0.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
|
16
|
+
checkpointer-2.5.0.dist-info/RECORD,,
|
@@ -1,4 +1,4 @@
|
|
1
|
-
Copyright
|
1
|
+
Copyright 2018-2025 Hampus Hallman
|
2
2
|
|
3
3
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
4
4
|
|
checkpointer/function_body.py
DELETED
@@ -1,80 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
import dis
|
3
|
-
import inspect
|
4
|
-
import tokenize
|
5
|
-
from io import StringIO
|
6
|
-
from collections.abc import Callable
|
7
|
-
from itertools import chain, takewhile
|
8
|
-
from operator import itemgetter
|
9
|
-
from pathlib import Path
|
10
|
-
from typing import Any, TypeGuard, TYPE_CHECKING
|
11
|
-
from types import CodeType, FunctionType
|
12
|
-
from relib import transpose, hashing, merge_dicts, drop_none
|
13
|
-
from .utils import unwrap_fn, iterate_and_upcoming, get_cell_contents, AttrDict, get_at_attr
|
14
|
-
|
15
|
-
if TYPE_CHECKING:
|
16
|
-
from .checkpoint import CheckpointFn
|
17
|
-
|
18
|
-
cwd = Path.cwd()
|
19
|
-
|
20
|
-
def extract_scope_values(code: CodeType, scope_vars: dict[str, Any], closure = False) -> dict[tuple[str, ...], Any]:
|
21
|
-
opname = "LOAD_GLOBAL" if not closure else "LOAD_DEREF"
|
22
|
-
scope_values_by_path: dict[tuple[str, ...], Any] = {}
|
23
|
-
instructions = list(dis.get_instructions(code))
|
24
|
-
|
25
|
-
for instr, upcoming_instrs in iterate_and_upcoming(instructions):
|
26
|
-
if instr.opname == opname:
|
27
|
-
name = instr.argval
|
28
|
-
attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
|
29
|
-
attr_path = (name, *(instr.argval for instr in attrs))
|
30
|
-
scope_values_by_path[attr_path] = get_at_attr(scope_vars, attr_path)
|
31
|
-
|
32
|
-
children = (extract_scope_values(const, scope_vars, closure) for const in code.co_consts if isinstance(const, CodeType))
|
33
|
-
return merge_dicts(scope_values_by_path, *children)
|
34
|
-
|
35
|
-
def get_fn_captured_vals(fn: Callable) -> list[Any]:
|
36
|
-
closure_scope = {k: get_cell_contents(v) for k, v in zip(fn.__code__.co_freevars, fn.__closure__ or [])}
|
37
|
-
global_vals = extract_scope_values(fn.__code__, AttrDict(fn.__globals__), closure=False)
|
38
|
-
closure_vals = extract_scope_values(fn.__code__, AttrDict(closure_scope), closure=True)
|
39
|
-
sorted_items = chain(sorted(global_vals.items()), sorted(closure_vals.items()))
|
40
|
-
return drop_none(map(itemgetter(1), sorted_items))
|
41
|
-
|
42
|
-
def get_fn_body(fn: Callable) -> str:
|
43
|
-
source = "".join(inspect.getsourcelines(fn)[0])
|
44
|
-
tokens = tokenize.generate_tokens(StringIO(source).readline)
|
45
|
-
ignore_types = (tokenize.COMMENT, tokenize.NL)
|
46
|
-
return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
|
47
|
-
|
48
|
-
def get_fn_path(fn: Callable) -> Path:
|
49
|
-
return Path(inspect.getfile(fn)).resolve()
|
50
|
-
|
51
|
-
def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
|
52
|
-
return isinstance(candidate_fn, FunctionType) \
|
53
|
-
and cwd in get_fn_path(candidate_fn).parents
|
54
|
-
|
55
|
-
def append_fn_depends(checkpoint_fns: set[CheckpointFn], captured_vals_by_fn: dict[Callable, list[Any]], fn: Callable, capture: bool) -> None:
|
56
|
-
from .checkpoint import CheckpointFn
|
57
|
-
captured_vals = get_fn_captured_vals(fn)
|
58
|
-
captured_vals_by_fn[fn] = [v for v in captured_vals if capture and not callable(v)]
|
59
|
-
callables = [unwrap_fn(val, checkpoint_fn=True) for val in captured_vals if callable(val)]
|
60
|
-
depends = {val for val in callables if is_user_fn(val)}
|
61
|
-
checkpoint_fns.update({val for val in callables if isinstance(val, CheckpointFn)})
|
62
|
-
not_appended = depends - captured_vals_by_fn.keys()
|
63
|
-
captured_vals_by_fn.update({fn: [] for fn in not_appended})
|
64
|
-
for child_fn in not_appended:
|
65
|
-
append_fn_depends(checkpoint_fns, captured_vals_by_fn, child_fn, capture)
|
66
|
-
|
67
|
-
def get_depend_fns(fn: Callable, capture: bool) -> tuple[set[CheckpointFn], dict[Callable, list[Any]]]:
|
68
|
-
checkpoint_fns: set[CheckpointFn] = set()
|
69
|
-
captured_vals_by_fn: dict[Callable, list[Any]] = {}
|
70
|
-
append_fn_depends(checkpoint_fns, captured_vals_by_fn, fn, capture)
|
71
|
-
return checkpoint_fns, captured_vals_by_fn
|
72
|
-
|
73
|
-
def get_function_hash(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
|
74
|
-
checkpoint_fns, captured_vals_by_fn = get_depend_fns(fn, capture)
|
75
|
-
checkpoint_fns = sorted(checkpoint_fns, key=lambda fn: unwrap_fn(fn).__qualname__)
|
76
|
-
checkpoint_hashes = [check.fn_hash for check in checkpoint_fns]
|
77
|
-
depend_fns, depend_captured_vals = transpose(sorted(captured_vals_by_fn.items(), key=lambda x: x[0].__qualname__), 2)
|
78
|
-
fn_bodies = list(map(get_fn_body, [fn] + depend_fns))
|
79
|
-
fn_hash = hashing.hash((fn_bodies, depend_captured_vals, checkpoint_hashes), "blake2b")
|
80
|
-
return fn_hash, checkpoint_fns + depend_fns
|
@@ -1,14 +0,0 @@
|
|
1
|
-
checkpointer/__init__.py,sha256=t-dv0hIfgJHFx2M8tjCUMC9DlucPM8hvJOwGv86owUo,411
|
2
|
-
checkpointer/checkpoint.py,sha256=NHY_63EzlY3X6eqbOBE-dIprMZ_-X_GRC-nhy6cI1QQ,4990
|
3
|
-
checkpointer/function_body.py,sha256=DAq5fj1MMgb3az_Pfdxzqg6woJ6esFgvaqkkKqogJBY,4074
|
4
|
-
checkpointer/print_checkpoint.py,sha256=21aeqgM9CMjNAJyScqFmXCWWfh3jBIn7o7i5zJkZGaA,1369
|
5
|
-
checkpointer/types.py,sha256=SslunQTXxovFuGOR_VKfL7z5Vif9RD1PPx0J1FQdGLw,564
|
6
|
-
checkpointer/utils.py,sha256=qT7pk3o6GkX-1ylfi6I-DJO5fdVOIHMuaEK8dTEAoVw,1465
|
7
|
-
checkpointer/storages/__init__.py,sha256=G7JrOAyCGITd1wOz-u6_4RZVgxzxGLVLHPwBuW1sx1U,300
|
8
|
-
checkpointer/storages/bcolz_storage.py,sha256=UoeREc3oS8skFClu9sULpgpqbIVcp3tVd8CeYfAe5yM,2220
|
9
|
-
checkpointer/storages/memory_storage.py,sha256=RQ4WTVapxJGVPv1DNlb9VFTifxtyQy8YVo8fwaRLfdk,692
|
10
|
-
checkpointer/storages/pickle_storage.py,sha256=nyrBWLXKnyzXgZIMwrpWUOAGRozpX3jL9pCyCV29e4E,787
|
11
|
-
checkpointer-2.1.0.dist-info/METADATA,sha256=tdpLxisGi4Wx3gvs_W52t1SQqRwgDuke26pkfzHVuKY,9926
|
12
|
-
checkpointer-2.1.0.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
|
13
|
-
checkpointer-2.1.0.dist-info/licenses/LICENSE,sha256=0cmUKqBotzbBcysIexd52AhjwbphhlGYiWbvg5l2QAU,1054
|
14
|
-
checkpointer-2.1.0.dist-info/RECORD,,
|