checkpointer 2.12.0__py3-none-any.whl → 2.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpointer/__init__.py +2 -2
- checkpointer/checkpoint.py +74 -62
- checkpointer/fn_ident.py +36 -34
- checkpointer/fn_string.py +77 -0
- checkpointer/object_hash.py +22 -17
- checkpointer/storages/storage.py +5 -4
- checkpointer/utils.py +13 -15
- checkpointer-2.13.0.dist-info/METADATA +260 -0
- checkpointer-2.13.0.dist-info/RECORD +18 -0
- checkpointer-2.12.0.dist-info/METADATA +0 -236
- checkpointer-2.12.0.dist-info/RECORD +0 -17
- {checkpointer-2.12.0.dist-info → checkpointer-2.13.0.dist-info}/WHEEL +0 -0
- {checkpointer-2.12.0.dist-info → checkpointer-2.13.0.dist-info}/licenses/ATTRIBUTION.md +0 -0
- {checkpointer-2.12.0.dist-info → checkpointer-2.13.0.dist-info}/licenses/LICENSE +0 -0
checkpointer/__init__.py
CHANGED
@@ -8,8 +8,8 @@ from .types import AwaitableValue, Captured, CapturedOnce, CaptureMe, CaptureMeO
|
|
8
8
|
|
9
9
|
checkpoint = Checkpointer()
|
10
10
|
capture_checkpoint = Checkpointer(capture=True)
|
11
|
-
memory_checkpoint = Checkpointer(
|
12
|
-
tmp_checkpoint = Checkpointer(
|
11
|
+
memory_checkpoint = Checkpointer(storage="memory", verbosity=0)
|
12
|
+
tmp_checkpoint = Checkpointer(directory=f"{tempfile.gettempdir()}/checkpoints")
|
13
13
|
static_checkpoint = Checkpointer(fn_hash_from=())
|
14
14
|
|
15
15
|
def cleanup_all(invalidated=True, expired=True):
|
checkpointer/checkpoint.py
CHANGED
@@ -21,8 +21,8 @@ class CheckpointError(Exception):
|
|
21
21
|
pass
|
22
22
|
|
23
23
|
class CheckpointerOpts(TypedDict, total=False):
|
24
|
-
|
25
|
-
|
24
|
+
storage: Type[Storage] | StorageType
|
25
|
+
directory: Path | str | None
|
26
26
|
when: bool
|
27
27
|
verbosity: Literal[0, 1, 2]
|
28
28
|
should_expire: Callable[[datetime], bool] | None
|
@@ -31,8 +31,8 @@ class CheckpointerOpts(TypedDict, total=False):
|
|
31
31
|
|
32
32
|
class Checkpointer:
|
33
33
|
def __init__(self, **opts: Unpack[CheckpointerOpts]):
|
34
|
-
self.
|
35
|
-
self.
|
34
|
+
self.storage = opts.get("storage", "pickle")
|
35
|
+
self.directory = Path(opts.get("directory", DEFAULT_DIR) or ".")
|
36
36
|
self.when = opts.get("when", True)
|
37
37
|
self.verbosity = opts.get("verbosity", 1)
|
38
38
|
self.should_expire = opts.get("should_expire")
|
@@ -56,24 +56,46 @@ class FunctionIdent:
|
|
56
56
|
Separated from CachedFunction to prevent hash desynchronization
|
57
57
|
among bound instances when `.reinit()` is called.
|
58
58
|
"""
|
59
|
-
|
60
|
-
|
59
|
+
__slots__ = (
|
60
|
+
"checkpointer", "cached_fn", "fn", "fn_dir", "pos_names",
|
61
|
+
"arg_names", "default_args", "hash_by_map", "__dict__",
|
62
|
+
)
|
63
|
+
|
64
|
+
def __init__(self, cached_fn: CachedFunction, checkpointer: Checkpointer, fn: Callable):
|
65
|
+
wrapped = unwrap(fn)
|
66
|
+
fn_file = Path(wrapped.__code__.co_filename).name
|
67
|
+
fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
|
68
|
+
params = list(signature(wrapped).parameters.values())
|
69
|
+
pos_param_types = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
70
|
+
named_param_types = (Parameter.KEYWORD_ONLY,) + pos_param_types
|
71
|
+
name_by_kind = {Parameter.VAR_POSITIONAL: b"*", Parameter.VAR_KEYWORD: b"**"}
|
72
|
+
self.checkpointer = checkpointer
|
61
73
|
self.cached_fn = cached_fn
|
74
|
+
self.fn = fn
|
75
|
+
self.fn_dir = f"{fn_file}/{fn_name}"
|
76
|
+
self.pos_names = [param.name for param in params if param.kind in pos_param_types]
|
77
|
+
self.arg_names = {param.name for param in params if param.kind in named_param_types}
|
78
|
+
self.default_args = {param.name: param.default for param in params if param.default is not Parameter.empty}
|
79
|
+
self.hash_by_map = {
|
80
|
+
name_by_kind.get(param.kind, param.name): hash_by
|
81
|
+
for param in params
|
82
|
+
if (hash_by := hash_by_from_annotation(param.annotation))
|
83
|
+
}
|
62
84
|
|
63
85
|
def reset(self):
|
64
|
-
self.
|
86
|
+
self.__dict__.clear()
|
65
87
|
|
66
88
|
def is_static(self) -> bool:
|
67
|
-
return self.
|
89
|
+
return self.checkpointer.fn_hash_from is not None
|
68
90
|
|
69
91
|
@cached_property
|
70
92
|
def raw_ident(self) -> RawFunctionIdent:
|
71
|
-
return get_fn_ident(unwrap(self.
|
93
|
+
return get_fn_ident(unwrap(self.fn), self.checkpointer.capture)
|
72
94
|
|
73
95
|
@cached_property
|
74
96
|
def fn_hash(self) -> str:
|
75
97
|
if self.is_static():
|
76
|
-
return str(ObjectHash(self.
|
98
|
+
return str(ObjectHash(self.checkpointer.fn_hash_from, digest_size=16))
|
77
99
|
depends = self.deep_idents(past_static=False)
|
78
100
|
deep_hashes = [d.fn_hash if d.is_static() else d.raw_ident.fn_hash for d in depends]
|
79
101
|
return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
|
@@ -105,29 +127,21 @@ class FunctionIdent:
|
|
105
127
|
|
106
128
|
class CachedFunction(Generic[Fn]):
|
107
129
|
def __init__(self, checkpointer: Checkpointer, fn: Fn):
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
update_wrapper(cast(Callable, self), wrapped)
|
113
|
-
self.checkpointer = checkpointer
|
114
|
-
self.fn = fn
|
115
|
-
self.fn_dir = f"{fn_file}/{fn_name}"
|
130
|
+
store_format = checkpointer.storage
|
131
|
+
Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
|
132
|
+
update_wrapper(cast(Callable, self), unwrap(fn))
|
133
|
+
self.ident = FunctionIdent(self, checkpointer, fn)
|
116
134
|
self.storage = Storage(self)
|
117
|
-
self.cleanup = self.storage.cleanup
|
118
135
|
self.bound = ()
|
119
136
|
|
120
|
-
params = list(signature(wrapped).parameters.values())
|
121
|
-
pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
122
|
-
self.arg_names = [param.name for param in params if param.kind in pos_params]
|
123
|
-
self.default_args = {param.name: param.default for param in params if param.default is not Parameter.empty}
|
124
|
-
self.hash_by_map = get_hash_by_map(params)
|
125
|
-
self.ident = FunctionIdent(self)
|
126
|
-
|
127
137
|
@overload
|
128
138
|
def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
|
129
139
|
@overload
|
130
|
-
def __get__(
|
140
|
+
def __get__(
|
141
|
+
self: CachedFunction[Callable[Concatenate[C, P], R]],
|
142
|
+
instance: C,
|
143
|
+
owner: Type[C],
|
144
|
+
) -> CachedFunction[Callable[P, R]]: ...
|
131
145
|
def __get__(self, instance, owner):
|
132
146
|
if instance is None:
|
133
147
|
return self
|
@@ -137,8 +151,12 @@ class CachedFunction(Generic[Fn]):
|
|
137
151
|
return bound_fn
|
138
152
|
|
139
153
|
@property
|
140
|
-
def
|
141
|
-
return self.ident.
|
154
|
+
def fn(self) -> Fn:
|
155
|
+
return cast(Fn, self.ident.fn)
|
156
|
+
|
157
|
+
@property
|
158
|
+
def cleanup(self):
|
159
|
+
return self.storage.cleanup
|
142
160
|
|
143
161
|
def reinit(self, recursive=False) -> CachedFunction[Fn]:
|
144
162
|
depend_idents = list(self.ident.deep_idents()) if recursive else [self.ident]
|
@@ -147,52 +165,55 @@ class CachedFunction(Generic[Fn]):
|
|
147
165
|
return self
|
148
166
|
|
149
167
|
def _get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
|
168
|
+
ident = self.ident
|
150
169
|
args = self.bound + args
|
151
|
-
pos_args = args[len(
|
152
|
-
named_pos_args = dict(zip(
|
153
|
-
named_args = {**
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
170
|
+
pos_args = args[len(ident.pos_names):]
|
171
|
+
named_pos_args = dict(zip(ident.pos_names, args))
|
172
|
+
named_args = {**ident.default_args, **named_pos_args, **kw}
|
173
|
+
for key, hash_by in ident.hash_by_map.items():
|
174
|
+
if isinstance(key, str):
|
175
|
+
named_args[key] = hash_by(named_args[key])
|
176
|
+
elif key == b"*":
|
177
|
+
pos_args = map(hash_by, pos_args)
|
178
|
+
elif key == b"**":
|
179
|
+
for key in kw.keys() - ident.arg_names:
|
180
|
+
named_args[key] = hash_by(named_args[key])
|
161
181
|
named_args_iter = chain.from_iterable(sorted(named_args.items()))
|
162
|
-
captured = chain.from_iterable(capturable.capture() for capturable in
|
163
|
-
|
182
|
+
captured = chain.from_iterable(capturable.capture() for capturable in ident.capturables)
|
183
|
+
call_hash = ObjectHash(digest_size=16) \
|
164
184
|
.update(iter=named_args_iter, header="NAMED") \
|
165
185
|
.update(iter=pos_args, header="POS") \
|
166
186
|
.update(iter=captured, header="CAPTURED")
|
167
|
-
return str(
|
187
|
+
return str(call_hash)
|
168
188
|
|
169
189
|
def get_call_hash(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> str:
|
170
190
|
return self._get_call_hash(args, kw)
|
171
191
|
|
172
|
-
async def
|
192
|
+
async def _store_coroutine(self, call_hash: str, coroutine: Coroutine):
|
173
193
|
return self.storage.store(call_hash, AwaitableValue(await coroutine)).value
|
174
194
|
|
175
195
|
def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
|
176
196
|
full_args = self.bound + args
|
177
|
-
params = self.checkpointer
|
197
|
+
params = self.ident.checkpointer
|
198
|
+
storage = self.storage
|
178
199
|
if not params.when:
|
179
200
|
return self.fn(*full_args, **kw)
|
180
201
|
|
181
202
|
call_hash = self._get_call_hash(args, kw)
|
182
|
-
call_id = f"{
|
203
|
+
call_id = f"{storage.fn_id()}/{call_hash}"
|
183
204
|
refresh = rerun \
|
184
|
-
or not
|
185
|
-
or (params.should_expire and params.should_expire(
|
205
|
+
or not storage.exists(call_hash) \
|
206
|
+
or (params.should_expire and params.should_expire(storage.checkpoint_date(call_hash)))
|
186
207
|
|
187
208
|
if refresh:
|
188
209
|
print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id, "blue")
|
189
210
|
data = self.fn(*full_args, **kw)
|
190
211
|
if iscoroutine(data):
|
191
|
-
return self.
|
192
|
-
return
|
212
|
+
return self._store_coroutine(call_hash, data)
|
213
|
+
return storage.store(call_hash, data)
|
193
214
|
|
194
215
|
try:
|
195
|
-
data =
|
216
|
+
data = storage.load(call_hash)
|
196
217
|
print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_id, "green")
|
197
218
|
return data
|
198
219
|
except (EOFError, FileNotFoundError):
|
@@ -232,15 +253,6 @@ class CachedFunction(Generic[Fn]):
|
|
232
253
|
self.storage.store(self._get_call_hash(args, kw), value)
|
233
254
|
|
234
255
|
def __repr__(self) -> str:
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
hash_by_map = {}
|
239
|
-
for param in params:
|
240
|
-
name = param.name
|
241
|
-
if param.kind == Parameter.VAR_POSITIONAL:
|
242
|
-
name = b"*"
|
243
|
-
elif param.kind == Parameter.VAR_KEYWORD:
|
244
|
-
name = b"**"
|
245
|
-
hash_by_map[name] = hash_by_from_annotation(param.annotation)
|
246
|
-
return hash_by_map if any(hash_by_map.values()) else {}
|
256
|
+
initialized = "fn_hash" in self.ident.__dict__
|
257
|
+
fn_hash = self.ident.fn_hash[:6] if initialized else "- uninitialized"
|
258
|
+
return f"<CachedFunction {self.fn.__name__} {fn_hash}>"
|
checkpointer/fn_ident.py
CHANGED
@@ -2,11 +2,12 @@ import dis
|
|
2
2
|
from inspect import Parameter, getmodule, signature, unwrap
|
3
3
|
from types import CodeType, MethodType, ModuleType
|
4
4
|
from typing import Annotated, Callable, Iterable, NamedTuple, Type, get_args, get_origin
|
5
|
+
from .fn_string import get_fn_aststr
|
5
6
|
from .import_mappings import resolve_annotation
|
6
7
|
from .object_hash import ObjectHash
|
7
8
|
from .types import hash_by_from_annotation, is_capture_me, is_capture_me_once, to_none
|
8
9
|
from .utils import (
|
9
|
-
|
10
|
+
cwd, distinct, get_at, get_cell_contents,
|
10
11
|
get_file, is_class, is_user_fn, seekable, takewhile,
|
11
12
|
)
|
12
13
|
|
@@ -28,61 +29,61 @@ class Capturable(NamedTuple):
|
|
28
29
|
def capture(self) -> tuple[str, object]:
|
29
30
|
if obj := self.hash:
|
30
31
|
return self.key, obj
|
31
|
-
obj =
|
32
|
+
obj = get_at(self.module, *self.attr_path)
|
32
33
|
obj = self.hash_by(obj) if self.hash_by else obj
|
33
34
|
return self.key, obj
|
34
35
|
|
35
36
|
@staticmethod
|
36
37
|
def new(module: ModuleType, attr_path: AttrPath, hash_by: Callable | None, capture_once: bool) -> "Capturable":
|
37
38
|
file = str(get_file(module).relative_to(cwd))
|
38
|
-
key = "
|
39
|
+
key = file + "/" + ".".join(attr_path)
|
39
40
|
cap = Capturable(key, module, attr_path, hash_by)
|
40
41
|
if not capture_once:
|
41
42
|
return cap
|
42
43
|
obj_hash = str(ObjectHash(cap.capture()[1]))
|
43
44
|
return Capturable(key, module, attr_path, None, obj_hash)
|
44
45
|
|
45
|
-
def extract_classvars(code: CodeType, scope_vars:
|
46
|
+
def extract_classvars(code: CodeType, scope_vars: dict) -> dict[str, dict[str, Type]]:
|
46
47
|
attr_path = AttrPath(())
|
47
48
|
scope_obj = None
|
48
49
|
classvars: dict[str, dict[str, Type]] = {}
|
49
50
|
instructs = seekable(dis.get_instructions(code))
|
50
|
-
for
|
51
|
-
if
|
51
|
+
for instruct in instructs:
|
52
|
+
if instruct.opname in scope_vars and not attr_path:
|
52
53
|
attrs = takewhile((x.opname == "LOAD_ATTR", x.argval) for x in instructs)
|
53
|
-
attr_path = AttrPath((
|
54
|
+
attr_path = AttrPath((instruct.opname, instruct.argval, *attrs))
|
54
55
|
instructs.step(-1)
|
55
|
-
elif
|
56
|
-
obj =
|
56
|
+
elif instruct.opname == "CALL":
|
57
|
+
obj = get_at(scope_vars, *attr_path)
|
57
58
|
attr_path = AttrPath(())
|
58
59
|
if is_class(obj):
|
59
60
|
scope_obj = obj
|
60
|
-
elif
|
61
|
-
load_key =
|
62
|
-
classvars.setdefault(load_key, {})[
|
61
|
+
elif instruct.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
|
62
|
+
load_key = instruct.opname.replace("STORE", "LOAD")
|
63
|
+
classvars.setdefault(load_key, {})[instruct.argval] = scope_obj
|
63
64
|
scope_obj = None
|
64
65
|
return classvars
|
65
66
|
|
66
|
-
def extract_scope_values(code: CodeType, scope_vars:
|
67
|
+
def extract_scope_values(code: CodeType, scope_vars: dict) -> Iterable[tuple[AttrPath, object]]:
|
67
68
|
classvars = extract_classvars(code, scope_vars)
|
68
|
-
scope_vars = scope_vars
|
69
|
+
scope_vars = {**scope_vars, **{k: {**scope_vars[k], **v} for k, v in classvars.items()}}
|
69
70
|
instructs = seekable(dis.get_instructions(code))
|
70
|
-
for
|
71
|
-
if
|
71
|
+
for instruct in instructs:
|
72
|
+
if instruct.opname in scope_vars:
|
72
73
|
attrs = takewhile((x.opname in ("LOAD_ATTR", "LOAD_METHOD"), x.argval) for x in instructs)
|
73
|
-
attr_path = AttrPath((
|
74
|
+
attr_path = AttrPath((instruct.opname, instruct.argval, *attrs))
|
74
75
|
parent_path = attr_path[:-1]
|
75
76
|
instructs.step(-1)
|
76
|
-
obj =
|
77
|
+
obj = get_at(scope_vars, *attr_path)
|
77
78
|
if obj is not None:
|
78
79
|
yield attr_path, obj
|
79
80
|
if callable(obj) and parent_path[1:]:
|
80
|
-
parent_obj =
|
81
|
+
parent_obj = get_at(scope_vars, *parent_path)
|
81
82
|
yield parent_path, parent_obj
|
82
83
|
for const in code.co_consts:
|
83
84
|
if isinstance(const, CodeType):
|
84
|
-
next_deref = scope_vars
|
85
|
-
next_scope_vars =
|
85
|
+
next_deref = {**scope_vars["LOAD_DEREF"], **scope_vars["LOAD_FAST"]}
|
86
|
+
next_scope_vars = {**scope_vars, "LOAD_FAST": {}, "LOAD_DEREF": next_deref}
|
86
87
|
yield from extract_scope_values(const, next_scope_vars)
|
87
88
|
|
88
89
|
def resolve_class_annotations(anno: object) -> Type | None:
|
@@ -94,11 +95,11 @@ def resolve_class_annotations(anno: object) -> Type | None:
|
|
94
95
|
return resolve_class_annotations(next(iter(get_args(anno)), None))
|
95
96
|
return resolve_class_annotations(get_origin(anno))
|
96
97
|
|
97
|
-
def get_self_value(fn: Callable) ->
|
98
|
+
def get_self_value(fn: Callable) -> Type | object | None:
|
98
99
|
if isinstance(fn, MethodType):
|
99
100
|
return fn.__self__
|
100
101
|
parts = fn.__qualname__.split(".")[:-1]
|
101
|
-
cls = parts and
|
102
|
+
cls = parts and get_at(fn.__globals__, *parts)
|
102
103
|
if is_class(cls):
|
103
104
|
return cls
|
104
105
|
|
@@ -116,19 +117,20 @@ def get_capturables(fn: Callable, capture: bool, captured_vars: dict[AttrPath, o
|
|
116
117
|
yield Capturable.new(module, attr_path, hash_by, is_capture_me_once(anno))
|
117
118
|
|
118
119
|
def get_fn_captures(fn: Callable, capture: bool) -> tuple[list[Callable], list[Capturable]]:
|
119
|
-
|
120
|
+
scope_vars_signature: dict[str, Type | object] = {
|
120
121
|
param.name: class_anno
|
121
122
|
for param in signature(fn).parameters.values()
|
122
123
|
if param.annotation is not Parameter.empty
|
123
124
|
if param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
|
124
125
|
if (class_anno := resolve_class_annotations(param.annotation))
|
125
126
|
}
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
"
|
130
|
-
"
|
131
|
-
|
127
|
+
if self_obj := get_self_value(fn):
|
128
|
+
scope_vars_signature["self"] = self_obj
|
129
|
+
scope_vars = {
|
130
|
+
"LOAD_FAST": scope_vars_signature,
|
131
|
+
"LOAD_DEREF": dict(get_cell_contents(fn)),
|
132
|
+
"LOAD_GLOBAL": fn.__globals__,
|
133
|
+
}
|
132
134
|
captured_vars = dict(extract_scope_values(fn.__code__, scope_vars))
|
133
135
|
captured_callables = [obj for obj in captured_vars.values() if callable(obj)]
|
134
136
|
capturables = list(get_capturables(fn, capture, captured_vars))
|
@@ -142,7 +144,7 @@ def get_depend_fns(fn: Callable, capture: bool, capturable_by_fn: CapturableByFn
|
|
142
144
|
for depend_fn in captured_callables:
|
143
145
|
depend_fn = unwrap(depend_fn, stop=lambda f: isinstance(f, CachedFunction))
|
144
146
|
if isinstance(depend_fn, CachedFunction):
|
145
|
-
capturable_by_fn[depend_fn] = []
|
147
|
+
capturable_by_fn[depend_fn.ident.cached_fn] = []
|
146
148
|
elif depend_fn not in capturable_by_fn and is_user_fn(depend_fn):
|
147
149
|
get_depend_fns(depend_fn, capture, capturable_by_fn)
|
148
150
|
return capturable_by_fn
|
@@ -153,7 +155,7 @@ def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
|
|
153
155
|
capturables = {capt for capts in capturable_by_fn.values() for capt in capts}
|
154
156
|
depends = capturable_by_fn.keys()
|
155
157
|
depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
|
156
|
-
|
157
|
-
assert fn ==
|
158
|
-
fn_hash = str(ObjectHash(iter=
|
158
|
+
depend_callables = [fn for fn in depends if not isinstance(fn, CachedFunction)]
|
159
|
+
assert fn == depend_callables[0]
|
160
|
+
fn_hash = str(ObjectHash(iter=map(get_fn_aststr, depend_callables)))
|
159
161
|
return RawFunctionIdent(fn_hash, depends, capturables)
|
@@ -0,0 +1,77 @@
|
|
1
|
+
import ast
|
2
|
+
import sys
|
3
|
+
from inspect import getsource
|
4
|
+
from textwrap import dedent
|
5
|
+
from typing import Callable
|
6
|
+
from .utils import drop_none, get_at
|
7
|
+
|
8
|
+
def get_decorator_path(node: ast.AST) -> tuple[str, ...]:
|
9
|
+
if isinstance(node, ast.Call):
|
10
|
+
return get_decorator_path(node.func)
|
11
|
+
elif isinstance(node, ast.Attribute):
|
12
|
+
return get_decorator_path(node.value) + (node.attr,)
|
13
|
+
elif isinstance(node, ast.Name):
|
14
|
+
return (node.id,)
|
15
|
+
else:
|
16
|
+
return ()
|
17
|
+
|
18
|
+
def is_empty_expression(node: ast.AST) -> bool:
|
19
|
+
# Filter out docstrings
|
20
|
+
return isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant)
|
21
|
+
|
22
|
+
class CleanFunctionTransform(ast.NodeTransformer):
|
23
|
+
def __init__(self, fn_globals: dict):
|
24
|
+
self.is_root = True
|
25
|
+
self.fn_globals = fn_globals
|
26
|
+
|
27
|
+
def is_checkpointer(self, node: ast.AST) -> bool:
|
28
|
+
from .checkpoint import Checkpointer
|
29
|
+
return isinstance(get_at(self.fn_globals, *get_decorator_path(node)), Checkpointer)
|
30
|
+
|
31
|
+
def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef):
|
32
|
+
fn_type = type(node).__name__
|
33
|
+
fn_name = None if self.is_root else node.name
|
34
|
+
args_by_type = [
|
35
|
+
node.args.posonlyargs + node.args.args,
|
36
|
+
drop_none([node.args.vararg]),
|
37
|
+
sorted(node.args.kwonlyargs, key=lambda x: x.arg),
|
38
|
+
drop_none([node.args.kwarg]),
|
39
|
+
]
|
40
|
+
arg_kind_names = ",".join(f"{i}:{arg.arg}" for i, args in enumerate(args_by_type) for arg in args)
|
41
|
+
header = " ".join(drop_none((fn_type, fn_name, arg_kind_names or None)))
|
42
|
+
|
43
|
+
self.is_root = False
|
44
|
+
|
45
|
+
return ast.List([
|
46
|
+
ast.Constant(header),
|
47
|
+
ast.List([child for child in node.decorator_list if not self.is_checkpointer(child)], ast.Load()),
|
48
|
+
ast.List([self.visit(child) for child in node.body if not is_empty_expression(child)], ast.Load()),
|
49
|
+
], ast.Load())
|
50
|
+
|
51
|
+
def visit_AsyncFunctionDef(self, node):
|
52
|
+
return self.visit_FunctionDef(node)
|
53
|
+
|
54
|
+
def get_fn_aststr(fn: Callable) -> str:
|
55
|
+
try:
|
56
|
+
source = getsource(fn)
|
57
|
+
except OSError:
|
58
|
+
return ""
|
59
|
+
try:
|
60
|
+
tree = ast.parse(dedent(source), mode="exec")
|
61
|
+
tree = tree.body[0]
|
62
|
+
except SyntaxError:
|
63
|
+
# lambda functions can cause SyntaxError in ast.parse
|
64
|
+
return source.strip()
|
65
|
+
|
66
|
+
if fn.__name__ != "<lambda>":
|
67
|
+
tree = CleanFunctionTransform(fn.__globals__).visit(tree)
|
68
|
+
else:
|
69
|
+
for node in ast.walk(tree):
|
70
|
+
if isinstance(node, ast.Lambda):
|
71
|
+
tree = node
|
72
|
+
break
|
73
|
+
|
74
|
+
if sys.version_info >= (3, 13):
|
75
|
+
return ast.dump(tree, annotate_fields=False, show_empty=True)
|
76
|
+
else:
|
77
|
+
return ast.dump(tree, annotate_fields=False)
|
checkpointer/object_hash.py
CHANGED
@@ -5,32 +5,35 @@ import io
|
|
5
5
|
import re
|
6
6
|
import sys
|
7
7
|
import tokenize
|
8
|
+
from collections import OrderedDict
|
8
9
|
from collections.abc import Iterable
|
9
10
|
from contextlib import nullcontext, suppress
|
10
11
|
from decimal import Decimal
|
11
12
|
from io import StringIO
|
12
13
|
from itertools import chain
|
13
14
|
from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
|
14
|
-
from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
|
15
|
+
from types import BuiltinFunctionType, FunctionType, GeneratorType, MappingProxyType, MethodType, ModuleType, UnionType
|
15
16
|
from typing import Callable, Self, TypeVar
|
16
17
|
from .utils import ContextVar
|
17
18
|
|
18
19
|
np, torch = None, None
|
19
20
|
|
20
|
-
with suppress(Exception):
|
21
|
-
import numpy as np
|
22
|
-
with suppress(Exception):
|
23
|
-
import torch
|
24
|
-
|
25
21
|
class _Never:
|
26
22
|
def __getattribute__(self, _: str):
|
27
23
|
pass
|
28
24
|
|
25
|
+
with suppress(Exception):
|
26
|
+
import numpy as np
|
27
|
+
with suppress(Exception):
|
28
|
+
import torch
|
29
29
|
if sys.version_info >= (3, 12):
|
30
30
|
from typing import TypeAliasType
|
31
31
|
else:
|
32
32
|
TypeAliasType = _Never
|
33
33
|
|
34
|
+
flatten = chain.from_iterable
|
35
|
+
nc = nullcontext()
|
36
|
+
|
34
37
|
def encode_type(t: type | FunctionType) -> str:
|
35
38
|
return f"{t.__module__}:{t.__qualname__}"
|
36
39
|
|
@@ -77,7 +80,7 @@ class ObjectHash:
|
|
77
80
|
return self.write_bytes(":".join(map(str, args)).encode())
|
78
81
|
|
79
82
|
def update(self, *objs: object, iter: Iterable[object] = (), tolerable: bool | None=None, header: str | None = None) -> Self:
|
80
|
-
with
|
83
|
+
with nc if tolerable is None else self.tolerable.set(tolerable):
|
81
84
|
for obj in chain(objs, iter):
|
82
85
|
if header is not None:
|
83
86
|
self.write_bytes(header.encode())
|
@@ -105,11 +108,11 @@ class ObjectHash:
|
|
105
108
|
|
106
109
|
case set() | frozenset():
|
107
110
|
try:
|
108
|
-
items = sorted(obj)
|
109
111
|
header = "set"
|
112
|
+
items = sorted(obj)
|
110
113
|
except:
|
111
|
-
items = sorted(map(self.nested_hash, obj))
|
112
114
|
header = "set-unsortable"
|
115
|
+
items = sorted(map(self.nested_hash, obj))
|
113
116
|
self.header(header, encode_type_of(obj), len(obj)).update(iter=items)
|
114
117
|
|
115
118
|
case TypeVar():
|
@@ -170,14 +173,16 @@ class ObjectHash:
|
|
170
173
|
match obj:
|
171
174
|
case list() | tuple():
|
172
175
|
self.header("list", encode_type_of(obj), len(obj)).update(iter=obj)
|
173
|
-
case dict():
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
176
|
+
case dict() | MappingProxyType():
|
177
|
+
header = "dict"
|
178
|
+
items = obj.items()
|
179
|
+
if not isinstance(obj, OrderedDict):
|
180
|
+
try:
|
181
|
+
items = sorted(items)
|
182
|
+
except:
|
183
|
+
header = "dict-unsortable"
|
184
|
+
items = sorted((self.nested_hash(key), val) for key, val in items)
|
185
|
+
self.header(header, encode_type_of(obj), len(obj)).update(iter=flatten(items))
|
181
186
|
case _:
|
182
187
|
self._update_object(obj)
|
183
188
|
finally:
|
checkpointer/storages/storage.py
CHANGED
@@ -8,17 +8,18 @@ if TYPE_CHECKING:
|
|
8
8
|
|
9
9
|
class Storage:
|
10
10
|
checkpointer: Checkpointer
|
11
|
-
|
11
|
+
ident: CachedFunction
|
12
12
|
|
13
13
|
def __init__(self, cached_fn: CachedFunction):
|
14
|
-
self.checkpointer = cached_fn.checkpointer
|
14
|
+
self.checkpointer = cached_fn.ident.checkpointer
|
15
15
|
self.cached_fn = cached_fn
|
16
16
|
|
17
17
|
def fn_id(self) -> str:
|
18
|
-
|
18
|
+
ident = self.cached_fn.ident
|
19
|
+
return f"{ident.fn_dir}/{ident.fn_hash}"
|
19
20
|
|
20
21
|
def fn_dir(self) -> Path:
|
21
|
-
return self.checkpointer.
|
22
|
+
return self.checkpointer.directory / self.fn_id()
|
22
23
|
|
23
24
|
def store(self, call_hash: str, data: Any) -> Any: ...
|
24
25
|
|
checkpointer/utils.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import inspect
|
3
|
-
from contextlib import contextmanager
|
3
|
+
from contextlib import contextmanager, suppress
|
4
4
|
from itertools import islice
|
5
5
|
from pathlib import Path
|
6
6
|
from types import FunctionType, MethodType, ModuleType
|
@@ -23,10 +23,11 @@ def is_user_fn(obj) -> TypeGuard[Callable]:
|
|
23
23
|
|
24
24
|
def get_cell_contents(fn: Callable) -> Iterable[tuple[str, object]]:
|
25
25
|
for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
|
26
|
-
|
26
|
+
with suppress(ValueError):
|
27
27
|
yield (key, cell.cell_contents)
|
28
|
-
|
29
|
-
|
28
|
+
|
29
|
+
def drop_none(iterable: Iterable[T | None]) -> list[T]:
|
30
|
+
return [x for x in iterable if x is not None]
|
30
31
|
|
31
32
|
def distinct(seq: Iterable[T]) -> list[T]:
|
32
33
|
return list(dict.fromkeys(seq))
|
@@ -80,6 +81,14 @@ class seekable(Generic[T]):
|
|
80
81
|
with self.freeze():
|
81
82
|
return list(islice(self, count))
|
82
83
|
|
84
|
+
def get_at(obj: object, *attrs: str) -> object:
|
85
|
+
for attr in attrs:
|
86
|
+
if type(obj) is dict:
|
87
|
+
obj = obj.get(attr, None)
|
88
|
+
else:
|
89
|
+
obj = getattr(obj, attr, None)
|
90
|
+
return obj
|
91
|
+
|
83
92
|
class AttrDict(dict):
|
84
93
|
def __init__(self, *args, **kwargs):
|
85
94
|
super().__init__(*args, **kwargs)
|
@@ -91,17 +100,6 @@ class AttrDict(dict):
|
|
91
100
|
def __setattr__(self, name: str, value: object):
|
92
101
|
super().__setattr__(name, value)
|
93
102
|
|
94
|
-
def set(self, d: dict) -> AttrDict:
|
95
|
-
if not d:
|
96
|
-
return self
|
97
|
-
return AttrDict({**self, **d})
|
98
|
-
|
99
|
-
def get_at(self: object, *attrs: str) -> object:
|
100
|
-
obj = self
|
101
|
-
for attr in attrs:
|
102
|
-
obj = getattr(obj, attr, None)
|
103
|
-
return obj
|
104
|
-
|
105
103
|
class ContextVar(Generic[T]):
|
106
104
|
def __init__(self, value: T):
|
107
105
|
self.value = value
|
@@ -0,0 +1,260 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: checkpointer
|
3
|
+
Version: 2.13.0
|
4
|
+
Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
|
5
|
+
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
|
+
Author: Hampus Hallman
|
7
|
+
License-Expression: MIT
|
8
|
+
License-File: ATTRIBUTION.md
|
9
|
+
License-File: LICENSE
|
10
|
+
Keywords: async,cache,caching,data analysis,data processing,fast,hashing,invalidation,memoization,optimization,performance,workflow
|
11
|
+
Classifier: Programming Language :: Python :: 3.11
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
13
|
+
Classifier: Programming Language :: Python :: 3.13
|
14
|
+
Requires-Python: >=3.11
|
15
|
+
Description-Content-Type: text/markdown
|
16
|
+
|
17
|
+
# checkpointer · [License](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [PyPI version](https://pypi.org/project/checkpointer/) [Python versions](https://pypi.org/project/checkpointer/)
|
18
|
+
|
19
|
+
`checkpointer` is a Python library offering a decorator-based API for memoizing (caching) function results with code-aware cache invalidation. It works with sync and async functions, supports multiple storage backends, and refreshes caches automatically when your code or dependencies change - helping you maintain correctness, speed up execution, and smooth out your workflows by skipping redundant, costly operations.
|
20
|
+
|
21
|
+
## 📦 Installation
|
22
|
+
|
23
|
+
```bash
|
24
|
+
pip install checkpointer
|
25
|
+
```
|
26
|
+
|
27
|
+
## 🚀 Quick Start
|
28
|
+
|
29
|
+
Apply the `@checkpoint` decorator to any function:
|
30
|
+
|
31
|
+
```python
|
32
|
+
from checkpointer import checkpoint
|
33
|
+
|
34
|
+
@checkpoint
|
35
|
+
def expensive_function(x: int) -> int:
|
36
|
+
print("Computing...")
|
37
|
+
return x ** 2
|
38
|
+
|
39
|
+
result = expensive_function(4) # Computes and stores the result
|
40
|
+
result = expensive_function(4) # Loads from the cache
|
41
|
+
```
|
42
|
+
|
43
|
+
## 🧠 How It Works
|
44
|
+
|
45
|
+
When a `@checkpoint`-decorated function is called, `checkpointer` computes a unique identifier for the call. This identifier derives from the function's source code, its dependencies, captured variables, and the arguments passed.
|
46
|
+
|
47
|
+
It then tries to retrieve a cached result using this identifier. If a valid cached result is found, it's returned immediately. Otherwise, the original function executes, its result is stored, and then returned.
|
48
|
+
|
49
|
+
### 🚨 What Triggers Cache Invalidation?
|
50
|
+
|
51
|
+
`checkpointer` maintains cache correctness using two types of hashes:
|
52
|
+
|
53
|
+
#### 1. Function Identity Hash (One-Time per Function)
|
54
|
+
|
55
|
+
This hash represents the decorated function itself and is computed once (usually on first invocation). It covers:
|
56
|
+
|
57
|
+
* **Decorated Function's Code:**\
|
58
|
+
The function's logic and signature (excluding parameter type annotations) are hashed. Formatting changes like whitespace, newlines, comments, or trailing commas do **not** cause invalidation.
|
59
|
+
|
60
|
+
* **Dependencies:**\
|
61
|
+
All user-defined functions and methods the function calls or uses are included recursively. Dependencies are detected by:
|
62
|
+
* Inspecting the function's global scope for referenced functions/objects.
|
63
|
+
* Inferring from argument type annotations.
|
64
|
+
* Analyzing object constructions and method calls to identify classes and methods used.
|
65
|
+
|
66
|
+
* **Top-Level Module Code:**\
|
67
|
+
This is the exception: module-level changes that are unrelated to the function or its dependencies are not part of the hash and do **not** trigger invalidation.
|
68
|
+
|
69
|
+
#### 2. Call Hash (Computed on Every Function Call)
|
70
|
+
|
71
|
+
Each function call's cache key (the **call hash**) combines:
|
72
|
+
|
73
|
+
* **Passed Arguments:**\
|
74
|
+
Includes positional and keyword arguments, combined with default values. Changing a default value alone does not trigger invalidation; it only matters when it changes the values a call actually receives.
|
75
|
+
|
76
|
+
* **Captured Global Variables:**\
|
77
|
+
When `capture=True` or explicit capture annotations are used, `checkpointer` hashes global variables referenced by the function:
|
78
|
+
* `CaptureMe` variables are hashed on every call, so changes trigger invalidation.
|
79
|
+
* `CaptureMeOnce` variables are hashed once per session for performance optimization.
|
80
|
+
|
81
|
+
* **Custom Argument Hashing:**\
|
82
|
+
Using `HashBy` annotations, arguments or captured variables can be transformed before hashing (e.g., sorting lists to ignore order), allowing more precise or efficient call hashes.
|
83
|
+
|
84
|
+
## 💡 Usage
|
85
|
+
|
86
|
+
Once a function is decorated with `@checkpoint`, you can interact with its caching behavior using the following methods:
|
87
|
+
|
88
|
+
* **`expensive_function(...)`**:\
|
89
|
+
Call the function normally. This will compute and cache the result or load it from cache.
|
90
|
+
|
91
|
+
* **`expensive_function.rerun(...)`**:\
|
92
|
+
Force the original function to execute and overwrite any existing cached result.
|
93
|
+
|
94
|
+
* **`expensive_function.fn(...)`**:\
|
95
|
+
Call the undecorated function directly, bypassing the cache (useful in recursion to prevent caching intermediate steps).
|
96
|
+
|
97
|
+
* **`expensive_function.get(...)`**:\
|
98
|
+
Retrieve the cached result without executing the function. Raises `CheckpointError` if no valid cached result exists.
|
99
|
+
|
100
|
+
* **`expensive_function.exists(...)`**:\
|
101
|
+
Check if a cached result exists without computing or loading it.
|
102
|
+
|
103
|
+
* **`expensive_function.delete(...)`**:\
|
104
|
+
Remove the cached entry for given arguments.
|
105
|
+
|
106
|
+
* **`expensive_function.reinit(recursive: bool = False)`**:\
|
107
|
+
Recalculate the function identity hash and recapture `CaptureMeOnce` variables, updating the cached function state within the same Python session.
|
108
|
+
|
109
|
+
## ⚙️ Configuration & Customization
|
110
|
+
|
111
|
+
The `@checkpoint` decorator accepts the following parameters:
|
112
|
+
|
113
|
+
* **`storage`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)\
|
114
|
+
Storage backend to use: `"pickle"` (disk-based, persistent), `"memory"` (in-memory, non-persistent), or a custom `Storage` class.
|
115
|
+
|
116
|
+
* **`directory`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)\
|
117
|
+
Base directory for disk-based checkpoints (only for `"pickle"` storage).
|
118
|
+
|
119
|
+
* **`when`** (Type: `bool`, Default: `True`)\
|
120
|
+
Enable or disable checkpointing dynamically, useful for environment-based toggling.
|
121
|
+
|
122
|
+
* **`capture`** (Type: `bool`, Default: `False`)\
|
123
|
+
If `True`, includes global variables referenced by the function in call hashes (except those excluded via `NoHash`).
|
124
|
+
|
125
|
+
* **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
|
126
|
+
A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
|
127
|
+
|
128
|
+
* **`fn_hash_from`** (Type: `Any`, Default: `None`)\
|
129
|
+
Override the computed function identity hash with any hashable object you provide (e.g., version strings, config IDs). This gives you explicit control over the function's version and when its cache should be invalidated.
|
130
|
+
|
131
|
+
* **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
|
132
|
+
Controls the level of logging output from `checkpointer`.
|
133
|
+
* `0`: No output.
|
134
|
+
* `1`: Shows when functions are computed and cached.
|
135
|
+
* `2`: Also shows when cached results are remembered (loaded from cache).
|
136
|
+
|
137
|
+
## 🔬 Customize Argument Hashing
|
138
|
+
|
139
|
+
To improve cache hit rates or speed up hashing, you can customize how arguments are hashed without modifying the actual argument values.
|
140
|
+
|
141
|
+
* **`Annotated[T, HashBy[fn]]`**:\
|
142
|
+
Transform the argument via `fn(argument)` before hashing. Useful for normalization (e.g., sorting lists) or optimized hashing for complex inputs.
|
143
|
+
|
144
|
+
* **`NoHash[T]`**:\
|
145
|
+
Exclude the argument from hashing completely, so changes to it won't trigger cache invalidation.
|
146
|
+
|
147
|
+
**Example:**
|
148
|
+
|
149
|
+
```python
|
150
|
+
from typing import Annotated
|
151
|
+
from checkpointer import checkpoint, HashBy, NoHash
|
152
|
+
from pathlib import Path
|
153
|
+
import logging
|
154
|
+
|
155
|
+
def file_bytes(path: Path) -> bytes:
|
156
|
+
return path.read_bytes()
|
157
|
+
|
158
|
+
@checkpoint
|
159
|
+
def process(
|
160
|
+
numbers: Annotated[list[int], HashBy[sorted]], # Hash by sorted list
|
161
|
+
data_file: Annotated[Path, HashBy[file_bytes]], # Hash by file content
|
162
|
+
log: NoHash[logging.Logger], # Exclude logger from hashing
|
163
|
+
):
|
164
|
+
...
|
165
|
+
```
|
166
|
+
|
167
|
+
In this example, the hash for `numbers` ignores order, `data_file` is hashed based on its contents rather than path, and changes to `log` don't affect caching.
|
168
|
+
|
169
|
+
## 🎯 Capturing Global Variables
|
170
|
+
|
171
|
+
`checkpointer` can include **captured global variables** in call hashes - these are globals your function reads during execution that may affect results.
|
172
|
+
|
173
|
+
Use `capture=True` on `@checkpoint` to capture **all** referenced globals (except those explicitly excluded with `NoHash`).
|
174
|
+
|
175
|
+
Alternatively, you can **opt-in selectively** by annotating globals with:
|
176
|
+
|
177
|
+
* **`CaptureMe[T]`**:\
|
178
|
+
Capture the variable on every call (triggers invalidation on changes).
|
179
|
+
|
180
|
+
* **`CaptureMeOnce[T]`**:\
|
181
|
+
Capture once per Python session (for expensive, immutable globals).
|
182
|
+
|
183
|
+
You can also combine these with `HashBy` to customize how captured variables are hashed (e.g., hash by subset of attributes).
|
184
|
+
|
185
|
+
**Example:**
|
186
|
+
|
187
|
+
```python
|
188
|
+
from typing import Annotated
|
189
|
+
from checkpointer import checkpoint, CaptureMe, CaptureMeOnce, HashBy
|
190
|
+
from pathlib import Path
|
191
|
+
|
192
|
+
def file_bytes(path: Path) -> bytes:
|
193
|
+
return path.read_bytes()
|
194
|
+
|
195
|
+
captured_data: CaptureMe[Annotated[Path, HashBy[file_bytes]]] = Path("data.txt")
|
196
|
+
session_config: CaptureMeOnce[dict] = {"mode": "prod"}
|
197
|
+
|
198
|
+
@checkpoint
|
199
|
+
def process():
|
200
|
+
# `captured_data` is included in the call hash on every call, hashed by file content
|
201
|
+
# `session_config` is hashed once per session
|
202
|
+
...
|
203
|
+
```
|
204
|
+
|
205
|
+
## 🗄️ Custom Storage Backends
|
206
|
+
|
207
|
+
Implement your own storage backend by subclassing `checkpointer.Storage` and overriding required methods.
|
208
|
+
|
209
|
+
Within storage methods, `call_hash` identifies a call by its arguments. Use `self.fn_id()` to get the function's identity (name + hash/version), which is important for organizing checkpoints.
|
210
|
+
|
211
|
+
**Example:**
|
212
|
+
|
213
|
+
```python
|
214
|
+
from checkpointer import checkpoint, Storage
|
215
|
+
from datetime import datetime
|
216
|
+
|
217
|
+
class MyCustomStorage(Storage):
|
218
|
+
def exists(self, call_hash):
|
219
|
+
fn_dir = self.checkpointer.directory / self.fn_id()
|
220
|
+
return (fn_dir / call_hash).exists()
|
221
|
+
|
222
|
+
def store(self, call_hash, data):
|
223
|
+
... # Store serialized data
|
224
|
+
return data # Must return data to checkpointer
|
225
|
+
|
226
|
+
def checkpoint_date(self, call_hash): ...
|
227
|
+
def load(self, call_hash): ...
|
228
|
+
def delete(self, call_hash): ...
|
229
|
+
|
230
|
+
@checkpoint(storage=MyCustomStorage)
|
231
|
+
def custom_cached_function(x: int):
|
232
|
+
return x ** 2
|
233
|
+
```
|
234
|
+
|
235
|
+
## ⚡ Async Support
|
236
|
+
|
237
|
+
`checkpointer` works with Python's `asyncio` and other async runtimes.
|
238
|
+
|
239
|
+
```python
|
240
|
+
import asyncio
|
241
|
+
from checkpointer import checkpoint
|
242
|
+
|
243
|
+
@checkpoint
|
244
|
+
async def async_compute_sum(a: int, b: int) -> int:
|
245
|
+
print(f"Asynchronously computing {a} + {b}...")
|
246
|
+
await asyncio.sleep(1)
|
247
|
+
return a + b
|
248
|
+
|
249
|
+
async def main():
|
250
|
+
result1 = await async_compute_sum(3, 7)
|
251
|
+
print(f"Result 1: {result1}")
|
252
|
+
|
253
|
+
result2 = await async_compute_sum(3, 7)
|
254
|
+
print(f"Result 2: {result2}")
|
255
|
+
|
256
|
+
result3 = async_compute_sum.get(3, 7)
|
257
|
+
print(f"Result 3 (from cache): {result3}")
|
258
|
+
|
259
|
+
asyncio.run(main())
|
260
|
+
```
|
@@ -0,0 +1,18 @@
|
|
1
|
+
checkpointer/__init__.py,sha256=l14EbRTgmkPlJUJc-5uAjWmB4dQr6kXXOPAjHcqbKK8,890
|
2
|
+
checkpointer/checkpoint.py,sha256=Ylf0Yel9WiyHg7EoP355b2huS_yh0hgBUx2l3bOyI_c,10149
|
3
|
+
checkpointer/fn_ident.py,sha256=mZfGPSIidlZVHsG3Kc6jk8rTNDoN_k2tCkJmbSKRhdg,6921
|
4
|
+
checkpointer/fn_string.py,sha256=R1evcaBKoVP9SSDd741O1FoaC9SiaA3-kfn6QLxd9No,2532
|
5
|
+
checkpointer/import_mappings.py,sha256=ESqWvZTzYAmaVnJ6NulUvn3_8CInOOPmEKUXO2WD_WA,1794
|
6
|
+
checkpointer/object_hash.py,sha256=NXlhME87iA9rrMRjF_Au4SKXFVc14j-SiSEsGhH7M8s,8327
|
7
|
+
checkpointer/print_checkpoint.py,sha256=uUQ493fJCaB4nhp4Ox60govSCiBTIPbBX15zt2QiRGo,1356
|
8
|
+
checkpointer/types.py,sha256=GFqbGACdDxzQX3bb2LmF9UxQVWOEisGvdtobnqCBAOA,1129
|
9
|
+
checkpointer/utils.py,sha256=FEabT0jp7Bx2KN-uco0QDAsBJ6hgauKwEjKR4B8IKzk,2977
|
10
|
+
checkpointer/storages/__init__.py,sha256=p-r4YrPXn505_S3qLrSXHSlsEtb13w_DFnCt9IiUomk,296
|
11
|
+
checkpointer/storages/memory_storage.py,sha256=aQRSOmAfS0UudubCpv8cdfu2ycM8mlsO9tFMcD2kmgo,1133
|
12
|
+
checkpointer/storages/pickle_storage.py,sha256=je1LM2lTSs5yzm25Apg5tJ9jU9T6nXCgD9SlqQRIFaM,1652
|
13
|
+
checkpointer/storages/storage.py,sha256=9CZquRjw9lWpcCVCiU7E6P-eJxGBUVbNKWncsTRwmoc,916
|
14
|
+
checkpointer-2.13.0.dist-info/METADATA,sha256=vyBK4qpjq4c3S2yCCxjHTF2C5QtuQ3WDTWhEj-etz6U,10925
|
15
|
+
checkpointer-2.13.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
+
checkpointer-2.13.0.dist-info/licenses/ATTRIBUTION.md,sha256=WF6L7-sD4s9t9ytVJOhjhpDoZ6TrWpqE3_bMdDIeJxI,1078
|
17
|
+
checkpointer-2.13.0.dist-info/licenses/LICENSE,sha256=drXs6vIb7uW49r70UuMz2A1VtOCl626kiTbcmrar1Xo,1072
|
18
|
+
checkpointer-2.13.0.dist-info/RECORD,,
|
@@ -1,236 +0,0 @@
|
|
1
|
-
Metadata-Version: 2.4
|
2
|
-
Name: checkpointer
|
3
|
-
Version: 2.12.0
|
4
|
-
Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
|
5
|
-
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
|
-
Author: Hampus Hallman
|
7
|
-
License-Expression: MIT
|
8
|
-
License-File: ATTRIBUTION.md
|
9
|
-
License-File: LICENSE
|
10
|
-
Keywords: async,cache,caching,code-aware,decorator,fast,hashing,invalidation,memoization,memoize,memory,optimization,performance,workflow
|
11
|
-
Classifier: Programming Language :: Python :: 3.11
|
12
|
-
Classifier: Programming Language :: Python :: 3.12
|
13
|
-
Classifier: Programming Language :: Python :: 3.13
|
14
|
-
Requires-Python: >=3.11
|
15
|
-
Description-Content-Type: text/markdown
|
16
|
-
|
17
|
-
# checkpointer · [](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [](https://pypi.org/project/checkpointer/) [](https://pypi.org/project/checkpointer/)
|
18
|
-
|
19
|
-
`checkpointer` is a Python library offering a decorator-based API for memoizing (caching) function results with code-aware cache invalidation. It works with sync and async functions, supports multiple storage backends, and refreshes caches automatically when your code or dependencies change - helping you maintain correctness, speed up execution, and smooth out your workflows by skipping redundant, costly operations.
|
20
|
-
|
21
|
-
## 📦 Installation
|
22
|
-
|
23
|
-
```bash
|
24
|
-
pip install checkpointer
|
25
|
-
```
|
26
|
-
|
27
|
-
## 🚀 Quick Start
|
28
|
-
|
29
|
-
Apply the `@checkpoint` decorator to any function:
|
30
|
-
|
31
|
-
```python
|
32
|
-
from checkpointer import checkpoint
|
33
|
-
|
34
|
-
@checkpoint
|
35
|
-
def expensive_function(x: int) -> int:
|
36
|
-
print("Computing...")
|
37
|
-
return x ** 2
|
38
|
-
|
39
|
-
result = expensive_function(4) # Computes and stores the result
|
40
|
-
result = expensive_function(4) # Loads from the cache
|
41
|
-
```
|
42
|
-
|
43
|
-
## 🧠 How It Works
|
44
|
-
|
45
|
-
When a `@checkpoint`-decorated function is called, `checkpointer` first computes a unique identifier (hash) for the call. This hash is derived from the function's source code, its dependencies, and the arguments passed.
|
46
|
-
|
47
|
-
It then tries to retrieve a cached result using this ID. If a valid cached result is found, it's returned immediately. Otherwise, the original function executes, its result is stored, and then returned.
|
48
|
-
|
49
|
-
Cache validity is determined by this function's hash, which automatically updates if:
|
50
|
-
|
51
|
-
* **Function Code Changes**: The decorated function's source code is modified.
|
52
|
-
* **Dependencies Change**: Any user-defined function in its dependency tree (direct or indirect, even across modules or not decorated) is modified.
|
53
|
-
* **Captured Variables Change** (with `capture=True`): Global or closure-based variables used within the function are altered.
|
54
|
-
|
55
|
-
**Example: Dependency Invalidation**
|
56
|
-
|
57
|
-
```python
|
58
|
-
def multiply(a, b):
|
59
|
-
return a * b
|
60
|
-
|
61
|
-
@checkpoint
|
62
|
-
def helper(x):
|
63
|
-
# Depends on `multiply`
|
64
|
-
return multiply(x + 1, 2)
|
65
|
-
|
66
|
-
@checkpoint
|
67
|
-
def compute(a, b):
|
68
|
-
# Depends on `helper` and `multiply`
|
69
|
-
return helper(a) + helper(b)
|
70
|
-
```
|
71
|
-
|
72
|
-
If `multiply` is modified, caches for both `helper` and `compute` are automatically invalidated and recomputed.
|
73
|
-
|
74
|
-
## 💡 Usage
|
75
|
-
|
76
|
-
Once a function is decorated with `@checkpoint`, you can interact with its caching behavior using the following methods:
|
77
|
-
|
78
|
-
* **`expensive_function(...)`**:\
|
79
|
-
Call the function normally. This will either compute and cache the result or load it from the cache if available.
|
80
|
-
|
81
|
-
* **`expensive_function.rerun(...)`**:\
|
82
|
-
Forces the original function to execute, compute a new result, and overwrite any existing cached value for the given arguments.
|
83
|
-
|
84
|
-
* **`expensive_function.fn(...)`**:\
|
85
|
-
Calls the original, undecorated function directly, bypassing the cache entirely. This is particularly useful within recursive functions to prevent caching intermediate steps.
|
86
|
-
|
87
|
-
* **`expensive_function.get(...)`**:\
|
88
|
-
Attempts to retrieve the cached result for the given arguments without executing the original function. Raises `CheckpointError` if no valid cached result exists.
|
89
|
-
|
90
|
-
* **`expensive_function.exists(...)`**:\
|
91
|
-
Checks if a cached result exists for the given arguments without attempting to compute or load it. Returns `True` if a valid checkpoint exists, `False` otherwise.
|
92
|
-
|
93
|
-
* **`expensive_function.delete(...)`**:\
|
94
|
-
Removes the cached entry for the specified arguments.
|
95
|
-
|
96
|
-
* **`expensive_function.reinit()`**:\
|
97
|
-
Recalculates the function's internal hash. This is primarily used when `capture=True` and you need to update the cache based on changes to external variables within the same Python session.
|
98
|
-
|
99
|
-
## ⚙️ Configuration & Customization
|
100
|
-
|
101
|
-
The `@checkpoint` decorator accepts the following parameters to customize its behavior:
|
102
|
-
|
103
|
-
* **`format`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)\
|
104
|
-
Defines the storage backend to use. Built-in options are `"pickle"` (disk-based, persistent) and `"memory"` (in-memory, non-persistent). You can also provide a custom `Storage` class.
|
105
|
-
|
106
|
-
* **`root_path`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)\
|
107
|
-
The base directory for storing disk-based checkpoints. This parameter is only relevant when `format` is set to `"pickle"`.
|
108
|
-
|
109
|
-
* **`when`** (Type: `bool`, Default: `True`)\
|
110
|
-
A boolean flag to enable or disable checkpointing for the decorated function. This is particularly useful for toggling caching based on environment variables (e.g., `when=os.environ.get("ENABLE_CACHING", "false").lower() == "true"`).
|
111
|
-
|
112
|
-
* **`capture`** (Type: `bool`, Default: `False`)\
|
113
|
-
If set to `True`, `checkpointer` includes global or closure-based variables used by the function in its hash calculation. This ensures that changes to these external variables also trigger cache invalidation and recomputation.
|
114
|
-
|
115
|
-
* **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
|
116
|
-
A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
|
117
|
-
|
118
|
-
* **`fn_hash_from`** (Type: `Any`, Default: `None`)\
|
119
|
-
This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any object relevant to your invalidation logic (e.g., version strings, config parameters). The object you provide will be hashed internally by `checkpointer`.
|
120
|
-
|
121
|
-
* **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
|
122
|
-
Controls the level of logging output from `checkpointer`.
|
123
|
-
* `0`: No output.
|
124
|
-
* `1`: Shows when functions are computed and cached.
|
125
|
-
* `2`: Also shows when cached results are remembered (loaded from cache).
|
126
|
-
|
127
|
-
## 🔬 Customize Argument Hashing
|
128
|
-
|
129
|
-
You can customize how individual function arguments are hashed without changing their actual values when passed in.
|
130
|
-
|
131
|
-
* **`Annotated[T, HashBy[fn]]`**:\
|
132
|
-
Hashes the argument by applying `fn(argument)` before hashing. This enables custom normalization (e.g., sorting lists to ignore order) or optimized hashing for complex types, improving cache hit rates or speeding up hashing.
|
133
|
-
|
134
|
-
* **`NoHash[T]`**:\
|
135
|
-
Completely excludes the argument from hashing, so changes to it won’t trigger cache invalidation.
|
136
|
-
|
137
|
-
**Example:**
|
138
|
-
|
139
|
-
```python
|
140
|
-
from typing import Annotated
|
141
|
-
from checkpointer import checkpoint, HashBy, NoHash
|
142
|
-
from pathlib import Path
|
143
|
-
import logging
|
144
|
-
|
145
|
-
def file_bytes(path: Path) -> bytes:
|
146
|
-
return path.read_bytes()
|
147
|
-
|
148
|
-
@checkpoint
|
149
|
-
def process(
|
150
|
-
numbers: Annotated[list[int], HashBy[sorted]], # Hash by sorted list
|
151
|
-
data_file: Annotated[Path, HashBy[file_bytes]], # Hash by file content
|
152
|
-
log: NoHash[logging.Logger], # Exclude logger from hashing
|
153
|
-
):
|
154
|
-
...
|
155
|
-
```
|
156
|
-
|
157
|
-
In this example, the cache key for `numbers` ignores order, `data_file` is hashed based on its contents rather than path, and changes to `log` don’t affect caching.
|
158
|
-
|
159
|
-
## 🗄️ Custom Storage Backends
|
160
|
-
|
161
|
-
For integration with databases, cloud storage, or custom serialization, implement your own storage backend by inheriting from `checkpointer.Storage` and implementing its abstract methods.
|
162
|
-
|
163
|
-
Within custom storage methods, `call_hash` identifies calls by arguments. Use `self.fn_id()` to get the function's unique identity (name + hash/version), crucial for organizing stored checkpoints (e.g., by function version). Access global `Checkpointer` config via `self.checkpointer`.
|
164
|
-
|
165
|
-
**Example: Custom Storage Backend**
|
166
|
-
|
167
|
-
```python
|
168
|
-
from checkpointer import checkpoint, Storage
|
169
|
-
from datetime import datetime
|
170
|
-
|
171
|
-
class MyCustomStorage(Storage):
|
172
|
-
def exists(self, call_hash):
|
173
|
-
# Example: Constructing a path based on function ID and call ID
|
174
|
-
fn_dir = self.checkpointer.root_path / self.fn_id()
|
175
|
-
return (fn_dir / call_hash).exists()
|
176
|
-
|
177
|
-
def store(self, call_hash, data):
|
178
|
-
... # Store the serialized data for `call_hash`
|
179
|
-
return data # This method must return the data back to checkpointer
|
180
|
-
|
181
|
-
def checkpoint_date(self, call_hash): ...
|
182
|
-
def load(self, call_hash): ...
|
183
|
-
def delete(self, call_hash): ...
|
184
|
-
|
185
|
-
@checkpoint(format=MyCustomStorage)
|
186
|
-
def custom_cached_function(x: int):
|
187
|
-
return x ** 2
|
188
|
-
```
|
189
|
-
|
190
|
-
## 🧱 Layered Caching
|
191
|
-
|
192
|
-
You can apply multiple `@checkpoint` decorators to a single function to create layered caching strategies. `checkpointer` processes these decorators from bottom to top, meaning the decorator closest to the function definition is evaluated first.
|
193
|
-
|
194
|
-
This is useful for scenarios like combining a fast, ephemeral cache (e.g., in-memory) with a persistent, slower cache (e.g., disk-based).
|
195
|
-
|
196
|
-
**Example: Memory Cache over Disk Cache**
|
197
|
-
|
198
|
-
```python
|
199
|
-
from checkpointer import checkpoint
|
200
|
-
|
201
|
-
@checkpoint(format="memory") # Layer 2: Fast, ephemeral in-memory cache
|
202
|
-
@checkpoint(format="pickle") # Layer 1: Persistent disk cache
|
203
|
-
def some_expensive_operation():
|
204
|
-
print("Performing a time-consuming operation...")
|
205
|
-
return sum(i for i in range(10**7))
|
206
|
-
```
|
207
|
-
|
208
|
-
## ⚡ Async Support
|
209
|
-
|
210
|
-
`checkpointer` works seamlessly with Python's `asyncio` and other async runtimes.
|
211
|
-
|
212
|
-
```python
|
213
|
-
import asyncio
|
214
|
-
from checkpointer import checkpoint
|
215
|
-
|
216
|
-
@checkpoint
|
217
|
-
async def async_compute_sum(a: int, b: int) -> int:
|
218
|
-
print(f"Asynchronously computing {a} + {b}...")
|
219
|
-
await asyncio.sleep(1)
|
220
|
-
return a + b
|
221
|
-
|
222
|
-
async def main():
|
223
|
-
# First call computes and caches
|
224
|
-
result1 = await async_compute_sum(3, 7)
|
225
|
-
print(f"Result 1: {result1}")
|
226
|
-
|
227
|
-
# Second call loads from cache
|
228
|
-
result2 = await async_compute_sum(3, 7)
|
229
|
-
print(f"Result 2: {result2}")
|
230
|
-
|
231
|
-
# Retrieve from cache without re-running the async function
|
232
|
-
result3 = async_compute_sum.get(3, 7)
|
233
|
-
print(f"Result 3 (from cache): {result3}")
|
234
|
-
|
235
|
-
asyncio.run(main())
|
236
|
-
```
|
@@ -1,17 +0,0 @@
|
|
1
|
-
checkpointer/__init__.py,sha256=x8LbL-URPg-afn0O-HMxPsJkV9cmW-Cw5epKSCGClOM,889
|
2
|
-
checkpointer/checkpoint.py,sha256=2K2D_aYHpI1XBcNrWQxta06wRdUj81TeCXshBVxl7cA,9869
|
3
|
-
checkpointer/fn_ident.py,sha256=4qg4NIvcWRV5GduO70an_iu12dn8LLIlw3Q8F3gm0Mo,6826
|
4
|
-
checkpointer/import_mappings.py,sha256=ESqWvZTzYAmaVnJ6NulUvn3_8CInOOPmEKUXO2WD_WA,1794
|
5
|
-
checkpointer/object_hash.py,sha256=MkrwSJJYVlWOjIDpGfYpAeYFPgCJhXHYBdlKVcUy2kw,8145
|
6
|
-
checkpointer/print_checkpoint.py,sha256=uUQ493fJCaB4nhp4Ox60govSCiBTIPbBX15zt2QiRGo,1356
|
7
|
-
checkpointer/types.py,sha256=GFqbGACdDxzQX3bb2LmF9UxQVWOEisGvdtobnqCBAOA,1129
|
8
|
-
checkpointer/utils.py,sha256=rq7AjR0WJql1o8clKMdXj4p3wrPYMLyS6G11nvNNjFE,2934
|
9
|
-
checkpointer/storages/__init__.py,sha256=p-r4YrPXn505_S3qLrSXHSlsEtb13w_DFnCt9IiUomk,296
|
10
|
-
checkpointer/storages/memory_storage.py,sha256=aQRSOmAfS0UudubCpv8cdfu2ycM8mlsO9tFMcD2kmgo,1133
|
11
|
-
checkpointer/storages/pickle_storage.py,sha256=je1LM2lTSs5yzm25Apg5tJ9jU9T6nXCgD9SlqQRIFaM,1652
|
12
|
-
checkpointer/storages/storage.py,sha256=5Jel7VlmCG9gnIbZFKT1NrEiAePz8ZD8hfsD-tYEiP4,905
|
13
|
-
checkpointer-2.12.0.dist-info/METADATA,sha256=B-R8Aq2caTM7fclANY7MJVZJQaU0nB1Mdf7wMr6HSTI,10705
|
14
|
-
checkpointer-2.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
15
|
-
checkpointer-2.12.0.dist-info/licenses/ATTRIBUTION.md,sha256=WF6L7-sD4s9t9ytVJOhjhpDoZ6TrWpqE3_bMdDIeJxI,1078
|
16
|
-
checkpointer-2.12.0.dist-info/licenses/LICENSE,sha256=drXs6vIb7uW49r70UuMz2A1VtOCl626kiTbcmrar1Xo,1072
|
17
|
-
checkpointer-2.12.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|