checkpointer 2.10.0__py3-none-any.whl → 2.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpointer/__init__.py +4 -4
- checkpointer/checkpoint.py +87 -42
- checkpointer/fn_ident.py +25 -16
- checkpointer/object_hash.py +32 -17
- checkpointer/storages/storage.py +1 -1
- checkpointer/test_checkpointer.py +6 -8
- checkpointer/types.py +17 -0
- checkpointer/utils.py +8 -29
- {checkpointer-2.10.0.dist-info → checkpointer-2.11.0.dist-info}/METADATA +38 -9
- checkpointer-2.11.0.dist-info/RECORD +16 -0
- checkpointer-2.10.0.dist-info/RECORD +0 -15
- {checkpointer-2.10.0.dist-info → checkpointer-2.11.0.dist-info}/WHEEL +0 -0
- {checkpointer-2.10.0.dist-info → checkpointer-2.11.0.dist-info}/licenses/LICENSE +0 -0
checkpointer/__init__.py
CHANGED
@@ -4,18 +4,18 @@ from typing import Callable
|
|
4
4
|
from .checkpoint import CachedFunction, Checkpointer, CheckpointError
|
5
5
|
from .object_hash import ObjectHash
|
6
6
|
from .storages import MemoryStorage, PickleStorage, Storage
|
7
|
-
from .
|
7
|
+
from .types import AwaitableValue, HashBy, NoHash
|
8
8
|
|
9
9
|
checkpoint = Checkpointer()
|
10
10
|
capture_checkpoint = Checkpointer(capture=True)
|
11
11
|
memory_checkpoint = Checkpointer(format="memory", verbosity=0)
|
12
12
|
tmp_checkpoint = Checkpointer(root_path=f"{tempfile.gettempdir()}/checkpoints")
|
13
|
-
static_checkpoint = Checkpointer(
|
13
|
+
static_checkpoint = Checkpointer(fn_hash_from=())
|
14
14
|
|
15
15
|
def cleanup_all(invalidated=True, expired=True):
|
16
16
|
for obj in gc.get_objects():
|
17
17
|
if isinstance(obj, CachedFunction):
|
18
18
|
obj.cleanup(invalidated=invalidated, expired=expired)
|
19
19
|
|
20
|
-
def get_function_hash(fn: Callable
|
21
|
-
return CachedFunction(Checkpointer(
|
20
|
+
def get_function_hash(fn: Callable) -> str:
|
21
|
+
return CachedFunction(Checkpointer(), fn).ident.fn_hash
|
checkpointer/checkpoint.py
CHANGED
@@ -1,18 +1,20 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
import inspect
|
3
2
|
import re
|
4
3
|
from datetime import datetime
|
5
4
|
from functools import cached_property, update_wrapper
|
5
|
+
from inspect import Parameter, Signature, iscoroutine, signature
|
6
6
|
from pathlib import Path
|
7
7
|
from typing import (
|
8
|
-
|
9
|
-
ParamSpec, Self, Type, TypedDict,
|
8
|
+
Annotated, Callable, Concatenate, Coroutine, Generic,
|
9
|
+
Iterable, Literal, ParamSpec, Self, Type, TypedDict,
|
10
|
+
TypeVar, Unpack, cast, get_args, get_origin, overload,
|
10
11
|
)
|
11
|
-
from .fn_ident import get_fn_ident
|
12
|
+
from .fn_ident import RawFunctionIdent, get_fn_ident
|
12
13
|
from .object_hash import ObjectHash
|
13
14
|
from .print_checkpoint import print_checkpoint
|
14
15
|
from .storages import STORAGE_MAP, Storage
|
15
|
-
from .
|
16
|
+
from .types import AwaitableValue, HashBy
|
17
|
+
from .utils import unwrap_fn
|
16
18
|
|
17
19
|
Fn = TypeVar("Fn", bound=Callable)
|
18
20
|
P = ParamSpec("P")
|
@@ -29,10 +31,9 @@ class CheckpointerOpts(TypedDict, total=False):
|
|
29
31
|
root_path: Path | str | None
|
30
32
|
when: bool
|
31
33
|
verbosity: Literal[0, 1, 2]
|
32
|
-
hash_by: Callable | None
|
33
34
|
should_expire: Callable[[datetime], bool] | None
|
34
35
|
capture: bool
|
35
|
-
|
36
|
+
fn_hash_from: object
|
36
37
|
|
37
38
|
class Checkpointer:
|
38
39
|
def __init__(self, **opts: Unpack[CheckpointerOpts]):
|
@@ -40,10 +41,9 @@ class Checkpointer:
|
|
40
41
|
self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
|
41
42
|
self.when = opts.get("when", True)
|
42
43
|
self.verbosity = opts.get("verbosity", 1)
|
43
|
-
self.hash_by = opts.get("hash_by")
|
44
44
|
self.should_expire = opts.get("should_expire")
|
45
45
|
self.capture = opts.get("capture", False)
|
46
|
-
self.
|
46
|
+
self.fn_hash_from = opts.get("fn_hash_from")
|
47
47
|
|
48
48
|
@overload
|
49
49
|
def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CachedFunction[Fn]: ...
|
@@ -56,6 +56,35 @@ class Checkpointer:
|
|
56
56
|
|
57
57
|
return CachedFunction(self, fn) if callable(fn) else self
|
58
58
|
|
59
|
+
class FunctionIdent:
|
60
|
+
"""
|
61
|
+
Represents the identity and hash state of a cached function.
|
62
|
+
Separated from CachedFunction to prevent hash desynchronization
|
63
|
+
among bound instances when `.reinit()` is called.
|
64
|
+
"""
|
65
|
+
def __init__(self, cached_fn: CachedFunction):
|
66
|
+
self.__dict__.clear()
|
67
|
+
self.cached_fn = cached_fn
|
68
|
+
|
69
|
+
@cached_property
|
70
|
+
def raw_ident(self) -> RawFunctionIdent:
|
71
|
+
return get_fn_ident(unwrap_fn(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
|
72
|
+
|
73
|
+
@cached_property
|
74
|
+
def fn_hash(self) -> str:
|
75
|
+
if self.cached_fn.checkpointer.fn_hash_from is not None:
|
76
|
+
return str(ObjectHash(self.cached_fn.checkpointer.fn_hash_from, digest_size=16))
|
77
|
+
deep_hashes = [depend.ident.raw_ident.fn_hash for depend in self.cached_fn.deep_depends()]
|
78
|
+
return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
|
79
|
+
|
80
|
+
@cached_property
|
81
|
+
def captured_hash(self) -> str:
|
82
|
+
deep_hashes = [depend.ident.raw_ident.captured_hash for depend in self.cached_fn.deep_depends()]
|
83
|
+
return str(ObjectHash().write_text(iter=deep_hashes))
|
84
|
+
|
85
|
+
def clear(self):
|
86
|
+
self.__init__(self.cached_fn)
|
87
|
+
|
59
88
|
class CachedFunction(Generic[Fn]):
|
60
89
|
def __init__(self, checkpointer: Checkpointer, fn: Fn):
|
61
90
|
wrapped = unwrap_fn(fn)
|
@@ -70,6 +99,14 @@ class CachedFunction(Generic[Fn]):
|
|
70
99
|
self.cleanup = self.storage.cleanup
|
71
100
|
self.bound = ()
|
72
101
|
|
102
|
+
sig = signature(wrapped)
|
103
|
+
params = list(sig.parameters.items())
|
104
|
+
pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
105
|
+
self.arg_names = [name for name, param in params if param.kind in pos_params]
|
106
|
+
self.default_args = {name: param.default for name, param in params if param.default is not Parameter.empty}
|
107
|
+
self.hash_by_map = get_hash_by_map(sig)
|
108
|
+
self.ident = FunctionIdent(self)
|
109
|
+
|
73
110
|
@overload
|
74
111
|
def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
|
75
112
|
@overload
|
@@ -82,41 +119,32 @@ class CachedFunction(Generic[Fn]):
|
|
82
119
|
bound_fn.bound = (instance,)
|
83
120
|
return bound_fn
|
84
121
|
|
85
|
-
@cached_property
|
86
|
-
def ident_tuple(self) -> tuple[str, list[Callable]]:
|
87
|
-
return get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
|
88
|
-
|
89
|
-
@property
|
90
|
-
def fn_hash_raw(self) -> str:
|
91
|
-
return self.ident_tuple[0]
|
92
|
-
|
93
122
|
@property
|
94
123
|
def depends(self) -> list[Callable]:
|
95
|
-
return self.
|
96
|
-
|
97
|
-
@cached_property
|
98
|
-
def fn_hash(self) -> str:
|
99
|
-
deep_hashes = [depend.fn_hash_raw for depend in self.deep_depends()]
|
100
|
-
fn_hash = ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes)
|
101
|
-
return str(self.checkpointer.fn_hash or fn_hash)[:32]
|
124
|
+
return self.ident.raw_ident.depends
|
102
125
|
|
103
126
|
def reinit(self, recursive=False) -> CachedFunction[Fn]:
|
104
|
-
|
105
|
-
for
|
106
|
-
|
107
|
-
self.__dict__.pop("ident_tuple", None)
|
108
|
-
for depend in depends:
|
109
|
-
depend.fn_hash
|
127
|
+
depend_idents = [depend.ident for depend in self.deep_depends()] if recursive else [self.ident]
|
128
|
+
for ident in depend_idents: ident.clear()
|
129
|
+
for ident in depend_idents: ident.fn_hash
|
110
130
|
return self
|
111
131
|
|
112
|
-
def get_call_id(self, args: tuple, kw: dict) -> str:
|
132
|
+
def get_call_id(self, args: tuple, kw: dict[str, object]) -> str:
|
113
133
|
args = self.bound + args
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
134
|
+
pos_args = args[len(self.arg_names):]
|
135
|
+
named_pos_args = dict(zip(self.arg_names, args))
|
136
|
+
named_args = {**self.default_args, **named_pos_args, **kw}
|
137
|
+
if hash_by_map := self.hash_by_map:
|
138
|
+
rest_hash_by = hash_by_map.get(b"**")
|
139
|
+
for key, value in named_args.items():
|
140
|
+
if hash_by := hash_by_map.get(key, rest_hash_by):
|
141
|
+
named_args[key] = hash_by(value)
|
142
|
+
if pos_hash_by := hash_by_map.get(b"*"):
|
143
|
+
pos_args = tuple(map(pos_hash_by, pos_args))
|
144
|
+
return str(ObjectHash(named_args, pos_args, self.ident.captured_hash, digest_size=16))
|
145
|
+
|
146
|
+
async def _resolve_coroutine(self, call_id: str, coroutine: Coroutine):
|
147
|
+
return self.storage.store(call_id, AwaitableValue(await coroutine)).value
|
120
148
|
|
121
149
|
def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
|
122
150
|
full_args = self.bound + args
|
@@ -125,7 +153,7 @@ class CachedFunction(Generic[Fn]):
|
|
125
153
|
return self.fn(*full_args, **kw)
|
126
154
|
|
127
155
|
call_id = self.get_call_id(args, kw)
|
128
|
-
call_id_long = f"{self.fn_dir}/{self.fn_hash}/{call_id}"
|
156
|
+
call_id_long = f"{self.fn_dir}/{self.ident.fn_hash}/{call_id}"
|
129
157
|
|
130
158
|
refresh = rerun \
|
131
159
|
or not self.storage.exists(call_id) \
|
@@ -134,8 +162,8 @@ class CachedFunction(Generic[Fn]):
|
|
134
162
|
if refresh:
|
135
163
|
print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id_long, "blue")
|
136
164
|
data = self.fn(*full_args, **kw)
|
137
|
-
if
|
138
|
-
return self.
|
165
|
+
if iscoroutine(data):
|
166
|
+
return self._resolve_coroutine(call_id, data)
|
139
167
|
return self.storage.store(call_id, data)
|
140
168
|
|
141
169
|
try:
|
@@ -154,7 +182,7 @@ class CachedFunction(Generic[Fn]):
|
|
154
182
|
return self._call(args, kw, True)
|
155
183
|
|
156
184
|
@overload
|
157
|
-
def get(self: Callable[P,
|
185
|
+
def get(self: Callable[P, Coroutine[object, object, R]], *args: P.args, **kw: P.kwargs) -> R: ...
|
158
186
|
@overload
|
159
187
|
def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
|
160
188
|
def get(self, *args, **kw):
|
@@ -172,7 +200,7 @@ class CachedFunction(Generic[Fn]):
|
|
172
200
|
self.storage.delete(self.get_call_id(args, kw))
|
173
201
|
|
174
202
|
def __repr__(self) -> str:
|
175
|
-
return f"<CachedFunction {self.fn.__name__} {self.fn_hash[:6]}>"
|
203
|
+
return f"<CachedFunction {self.fn.__name__} {self.ident.fn_hash[:6]}>"
|
176
204
|
|
177
205
|
def deep_depends(self, visited: set[CachedFunction] = set()) -> Iterable[CachedFunction]:
|
178
206
|
if self not in visited:
|
@@ -182,3 +210,20 @@ class CachedFunction(Generic[Fn]):
|
|
182
210
|
for depend in self.depends:
|
183
211
|
if isinstance(depend, CachedFunction):
|
184
212
|
yield from depend.deep_depends(visited)
|
213
|
+
|
214
|
+
def hash_by_from_annotation(annotation: type) -> Callable[[object], object] | None:
|
215
|
+
if get_origin(annotation) is Annotated:
|
216
|
+
args = get_args(annotation)
|
217
|
+
metadata = args[1] if len(args) > 1 else None
|
218
|
+
if get_origin(metadata) is HashBy:
|
219
|
+
return get_args(metadata)[0]
|
220
|
+
|
221
|
+
def get_hash_by_map(sig: Signature) -> dict[str | bytes, Callable[[object], object]]:
|
222
|
+
hash_by_map = {}
|
223
|
+
for name, param in sig.parameters.items():
|
224
|
+
if param.kind == Parameter.VAR_POSITIONAL:
|
225
|
+
name = b"*"
|
226
|
+
elif param.kind == Parameter.VAR_KEYWORD:
|
227
|
+
name = b"**"
|
228
|
+
hash_by_map[name] = hash_by_from_annotation(param.annotation)
|
229
|
+
return hash_by_map if any(hash_by_map.values()) else {}
|
checkpointer/fn_ident.py
CHANGED
@@ -1,15 +1,19 @@
|
|
1
1
|
import dis
|
2
2
|
import inspect
|
3
|
-
from collections.abc import Callable
|
4
3
|
from itertools import takewhile
|
5
4
|
from pathlib import Path
|
6
5
|
from types import CodeType, FunctionType, MethodType
|
7
|
-
from typing import
|
6
|
+
from typing import Callable, Iterable, NamedTuple, Type, TypeGuard
|
8
7
|
from .object_hash import ObjectHash
|
9
|
-
from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming,
|
8
|
+
from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, unwrap_fn
|
10
9
|
|
11
10
|
cwd = Path.cwd().resolve()
|
12
11
|
|
12
|
+
class RawFunctionIdent(NamedTuple):
|
13
|
+
fn_hash: str
|
14
|
+
captured_hash: str
|
15
|
+
depends: list[Callable]
|
16
|
+
|
13
17
|
def is_class(obj) -> TypeGuard[Type]:
|
14
18
|
# isinstance works too, but needlessly triggers _lazyinit()
|
15
19
|
return issubclass(type(obj), type)
|
@@ -33,7 +37,7 @@ def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[st
|
|
33
37
|
scope_obj = None
|
34
38
|
return classvars
|
35
39
|
|
36
|
-
def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...],
|
40
|
+
def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...], object]]:
|
37
41
|
classvars = extract_classvars(code, scope_vars)
|
38
42
|
scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
|
39
43
|
for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
|
@@ -55,7 +59,7 @@ def get_self_value(fn: Callable) -> type | object | None:
|
|
55
59
|
if is_class(cls):
|
56
60
|
return cls
|
57
61
|
|
58
|
-
def get_fn_captured_vals(fn: Callable) -> list[
|
62
|
+
def get_fn_captured_vals(fn: Callable) -> list[object]:
|
59
63
|
self_value = get_self_value(fn)
|
60
64
|
scope_vars = AttrDict({
|
61
65
|
"LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
|
@@ -71,24 +75,29 @@ def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
|
|
71
75
|
fn_path = Path(inspect.getfile(candidate_fn)).resolve()
|
72
76
|
return cwd in fn_path.parents and ".venv" not in fn_path.parts
|
73
77
|
|
74
|
-
def get_depend_fns(fn: Callable,
|
78
|
+
def get_depend_fns(fn: Callable, captured_vals_by_fn: dict[Callable, list[object]] = {}) -> dict[Callable, list[object]]:
|
75
79
|
from .checkpoint import CachedFunction
|
76
|
-
captured_vals_by_fn = captured_vals_by_fn or {}
|
77
80
|
captured_vals = get_fn_captured_vals(fn)
|
78
|
-
captured_vals_by_fn
|
79
|
-
|
80
|
-
for
|
81
|
+
captured_vals_by_fn = captured_vals_by_fn or {}
|
82
|
+
captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)]
|
83
|
+
for val in captured_vals:
|
84
|
+
if not callable(val):
|
85
|
+
continue
|
86
|
+
child_fn = unwrap_fn(val, cached_fn=True)
|
81
87
|
if isinstance(child_fn, CachedFunction):
|
82
88
|
captured_vals_by_fn[child_fn] = []
|
83
89
|
elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
|
84
|
-
get_depend_fns(child_fn,
|
90
|
+
get_depend_fns(child_fn, captured_vals_by_fn)
|
85
91
|
return captured_vals_by_fn
|
86
92
|
|
87
|
-
def get_fn_ident(fn: Callable, capture: bool) ->
|
93
|
+
def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
|
88
94
|
from .checkpoint import CachedFunction
|
89
|
-
captured_vals_by_fn = get_depend_fns(fn
|
90
|
-
|
95
|
+
captured_vals_by_fn = get_depend_fns(fn)
|
96
|
+
depend_captured_vals = list(captured_vals_by_fn.values()) * capture
|
97
|
+
depends = captured_vals_by_fn.keys()
|
91
98
|
depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
|
92
99
|
unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
|
93
|
-
|
94
|
-
|
100
|
+
assert fn == unwrapped_depends[0]
|
101
|
+
fn_hash = str(ObjectHash(iter=unwrapped_depends))
|
102
|
+
captured_hash = str(ObjectHash(iter=depend_captured_vals, tolerate_errors=True))
|
103
|
+
return RawFunctionIdent(fn_hash, captured_hash, depends)
|
checkpointer/object_hash.py
CHANGED
@@ -1,16 +1,19 @@
|
|
1
1
|
import ctypes
|
2
2
|
import hashlib
|
3
|
+
import inspect
|
3
4
|
import io
|
4
5
|
import re
|
5
6
|
import sys
|
7
|
+
import tokenize
|
6
8
|
from collections.abc import Iterable
|
7
9
|
from contextlib import nullcontext, suppress
|
8
10
|
from decimal import Decimal
|
11
|
+
from io import StringIO
|
9
12
|
from itertools import chain
|
10
|
-
from pickle import HIGHEST_PROTOCOL as
|
13
|
+
from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
|
11
14
|
from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
|
12
|
-
from typing import
|
13
|
-
from .utils import ContextVar
|
15
|
+
from typing import Callable, TypeVar
|
16
|
+
from .utils import ContextVar
|
14
17
|
|
15
18
|
np, torch = None, None
|
16
19
|
|
@@ -31,16 +34,16 @@ else:
|
|
31
34
|
def encode_type(t: type | FunctionType) -> str:
|
32
35
|
return f"{t.__module__}:{t.__qualname__}"
|
33
36
|
|
34
|
-
def encode_type_of(v:
|
37
|
+
def encode_type_of(v: object) -> str:
|
35
38
|
return encode_type(type(v))
|
36
39
|
|
37
40
|
class ObjectHashError(Exception):
|
38
|
-
def __init__(self, obj:
|
41
|
+
def __init__(self, obj: object, cause: Exception):
|
39
42
|
super().__init__(f"{type(cause).__name__} error when hashing {obj}")
|
40
43
|
self.obj = obj
|
41
44
|
|
42
45
|
class ObjectHash:
|
43
|
-
def __init__(self, *objs:
|
46
|
+
def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerate_errors=False) -> None:
|
44
47
|
self.hash = hashlib.blake2b(digest_size=digest_size)
|
45
48
|
self.current: dict[int, int] = {}
|
46
49
|
self.tolerate_errors = ContextVar(tolerate_errors)
|
@@ -59,7 +62,7 @@ class ObjectHash:
|
|
59
62
|
def __eq__(self, value: object) -> bool:
|
60
63
|
return isinstance(value, ObjectHash) and str(self) == str(value)
|
61
64
|
|
62
|
-
def nested_hash(self, *objs:
|
65
|
+
def nested_hash(self, *objs: object) -> str:
|
63
66
|
return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()
|
64
67
|
|
65
68
|
def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> "ObjectHash":
|
@@ -70,10 +73,10 @@ class ObjectHash:
|
|
70
73
|
def write_text(self, *data: str, iter: Iterable[str] = ()) -> "ObjectHash":
|
71
74
|
return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))
|
72
75
|
|
73
|
-
def header(self, *args:
|
76
|
+
def header(self, *args: object) -> "ObjectHash":
|
74
77
|
return self.write_bytes(":".join(map(str, args)).encode())
|
75
78
|
|
76
|
-
def update(self, *objs:
|
79
|
+
def update(self, *objs: object, iter: Iterable[object] = (), tolerate_errors: bool | None=None) -> "ObjectHash":
|
77
80
|
with nullcontext() if tolerate_errors is None else self.tolerate_errors.set(tolerate_errors):
|
78
81
|
for obj in chain(objs, iter):
|
79
82
|
try:
|
@@ -81,11 +84,11 @@ class ObjectHash:
|
|
81
84
|
except Exception as ex:
|
82
85
|
if self.tolerate_errors.value:
|
83
86
|
self.header("error").update(type(ex))
|
84
|
-
|
85
|
-
|
87
|
+
else:
|
88
|
+
raise ObjectHashError(obj, ex) from ex
|
86
89
|
return self
|
87
90
|
|
88
|
-
def _update_one(self, obj:
|
91
|
+
def _update_one(self, obj: object) -> None:
|
89
92
|
match obj:
|
90
93
|
case None:
|
91
94
|
self.header("null")
|
@@ -142,7 +145,7 @@ class ObjectHash:
|
|
142
145
|
case _ if np and isinstance(obj, np.ndarray):
|
143
146
|
self.header("ndarray", encode_type_of(obj), obj.shape, obj.strides).update(obj.dtype)
|
144
147
|
if obj.dtype.hasobject:
|
145
|
-
self.update(obj.__reduce_ex__(
|
148
|
+
self.update(obj.__reduce_ex__(PICKLE_PROTOCOL))
|
146
149
|
else:
|
147
150
|
array = np.ascontiguousarray(obj if obj.base is None else obj.base).view(np.uint8)
|
148
151
|
self.write_bytes(array.data)
|
@@ -180,13 +183,14 @@ class ObjectHash:
|
|
180
183
|
def _update_iterator(self, obj: Iterable) -> "ObjectHash":
|
181
184
|
return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")
|
182
185
|
|
183
|
-
def _update_object(self, obj:
|
186
|
+
def _update_object(self, obj: object) -> "ObjectHash":
|
184
187
|
self.header("instance", encode_type_of(obj))
|
185
|
-
|
186
|
-
|
188
|
+
get_hash = hasattr(obj, "__objecthash__") and getattr(obj, "__objecthash__")
|
189
|
+
if callable(get_hash):
|
190
|
+
return self.header("objecthash").update(get_hash())
|
187
191
|
reduced = None
|
188
192
|
with suppress(Exception):
|
189
|
-
reduced = obj.__reduce_ex__(
|
193
|
+
reduced = obj.__reduce_ex__(PICKLE_PROTOCOL)
|
190
194
|
with suppress(Exception):
|
191
195
|
reduced = reduced or obj.__reduce__()
|
192
196
|
if isinstance(reduced, str):
|
@@ -202,5 +206,16 @@ class ObjectHash:
|
|
202
206
|
return self.header("slots").update(slots)
|
203
207
|
if d := getattr(obj, "__dict__", {}):
|
204
208
|
return self.header("dict").update(d)
|
209
|
+
if isinstance(obj, Iterable):
|
210
|
+
return self._update_iterator(obj)
|
205
211
|
repr_str = re.sub(r"\s+(at\s+0x[0-9a-fA-F]+)(>)$", r"\2", repr(obj))
|
206
212
|
return self.header("repr").update(repr_str)
|
213
|
+
|
214
|
+
def get_fn_body(fn: Callable) -> str:
|
215
|
+
try:
|
216
|
+
source = inspect.getsource(fn)
|
217
|
+
except OSError:
|
218
|
+
return ""
|
219
|
+
tokens = tokenize.generate_tokens(StringIO(source).readline)
|
220
|
+
ignore_types = (tokenize.COMMENT, tokenize.NL)
|
221
|
+
return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
|
checkpointer/storages/storage.py
CHANGED
@@ -15,7 +15,7 @@ class Storage:
|
|
15
15
|
self.cached_fn = cached_fn
|
16
16
|
|
17
17
|
def fn_id(self) -> str:
|
18
|
-
return f"{self.cached_fn.fn_dir}/{self.cached_fn.fn_hash}"
|
18
|
+
return f"{self.cached_fn.fn_dir}/{self.cached_fn.ident.fn_hash}"
|
19
19
|
|
20
20
|
def fn_dir(self) -> Path:
|
21
21
|
return self.checkpointer.root_path / self.fn_id()
|
@@ -1,8 +1,6 @@
|
|
1
1
|
import asyncio
|
2
2
|
import pytest
|
3
|
-
from
|
4
|
-
from . import checkpoint
|
5
|
-
from .checkpoint import CheckpointError
|
3
|
+
from checkpointer import CheckpointError, checkpoint
|
6
4
|
from .utils import AttrDict
|
7
5
|
|
8
6
|
def global_multiply(a: int, b: int) -> int:
|
@@ -112,16 +110,16 @@ def test_capture():
|
|
112
110
|
def test_a():
|
113
111
|
return item_dict.a + 1
|
114
112
|
|
115
|
-
init_hash_a = test_a.
|
116
|
-
init_hash_whole = test_whole.
|
113
|
+
init_hash_a = test_a.ident.captured_hash
|
114
|
+
init_hash_whole = test_whole.ident.captured_hash
|
117
115
|
item_dict.b += 1
|
118
116
|
test_whole.reinit()
|
119
117
|
test_a.reinit()
|
120
|
-
assert test_whole.
|
121
|
-
assert test_a.
|
118
|
+
assert test_whole.ident.captured_hash != init_hash_whole
|
119
|
+
assert test_a.ident.captured_hash == init_hash_a
|
122
120
|
item_dict.a += 1
|
123
121
|
test_a.reinit()
|
124
|
-
assert test_a.
|
122
|
+
assert test_a.ident.captured_hash != init_hash_a
|
125
123
|
|
126
124
|
def test_depends():
|
127
125
|
def multiply_wrapper(a: int, b: int) -> int:
|
checkpointer/types.py
ADDED
@@ -0,0 +1,17 @@
|
|
1
|
+
from typing import Annotated, Callable, Generic, TypeVar
|
2
|
+
|
3
|
+
T = TypeVar("T")
|
4
|
+
Fn = TypeVar("Fn", bound=Callable)
|
5
|
+
|
6
|
+
class HashBy(Generic[Fn]):
|
7
|
+
pass
|
8
|
+
|
9
|
+
NoHash = Annotated[T, HashBy[lambda _: None]]
|
10
|
+
|
11
|
+
class AwaitableValue:
|
12
|
+
def __init__(self, value):
|
13
|
+
self.value = value
|
14
|
+
|
15
|
+
def __await__(self):
|
16
|
+
yield
|
17
|
+
return self.value
|
checkpointer/utils.py
CHANGED
@@ -1,7 +1,4 @@
|
|
1
|
-
import inspect
|
2
|
-
import tokenize
|
3
1
|
from contextlib import contextmanager
|
4
|
-
from io import StringIO
|
5
2
|
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
6
3
|
|
7
4
|
T = TypeVar("T")
|
@@ -10,21 +7,6 @@ Fn = TypeVar("Fn", bound=Callable)
|
|
10
7
|
def distinct(seq: Iterable[T]) -> list[T]:
|
11
8
|
return list(dict.fromkeys(seq))
|
12
9
|
|
13
|
-
def transpose(tuples, default_num_returns=0):
|
14
|
-
output = tuple(zip(*tuples))
|
15
|
-
if not output:
|
16
|
-
return ([],) * default_num_returns
|
17
|
-
return tuple(map(list, output))
|
18
|
-
|
19
|
-
def get_fn_body(fn: Callable) -> str:
|
20
|
-
try:
|
21
|
-
source = inspect.getsource(fn)
|
22
|
-
except OSError:
|
23
|
-
return ""
|
24
|
-
tokens = tokenize.generate_tokens(StringIO(source).readline)
|
25
|
-
ignore_types = (tokenize.COMMENT, tokenize.NL)
|
26
|
-
return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
|
27
|
-
|
28
10
|
def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
|
29
11
|
for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
|
30
12
|
try:
|
@@ -39,14 +21,6 @@ def unwrap_fn(fn: Fn, cached_fn=False) -> Fn:
|
|
39
21
|
return cast(Fn, fn)
|
40
22
|
fn = getattr(fn, "__wrapped__")
|
41
23
|
|
42
|
-
class AwaitableValue:
|
43
|
-
def __init__(self, value):
|
44
|
-
self.value = value
|
45
|
-
|
46
|
-
def __await__(self):
|
47
|
-
yield
|
48
|
-
return self.value
|
49
|
-
|
50
24
|
class AttrDict(dict):
|
51
25
|
def __init__(self, *args, **kwargs):
|
52
26
|
super().__init__(*args, **kwargs)
|
@@ -91,14 +65,19 @@ class iterate_and_upcoming(Generic[T]):
|
|
91
65
|
def __init__(self, it: Iterable[T]) -> None:
|
92
66
|
self.it = iter(it)
|
93
67
|
self.previous: tuple[()] | tuple[T] = ()
|
68
|
+
self.tracked = self._tracked_iter()
|
94
69
|
|
95
70
|
def __iter__(self):
|
96
71
|
return self
|
97
72
|
|
98
73
|
def __next__(self) -> tuple[T, Iterable[T]]:
|
99
|
-
|
100
|
-
|
101
|
-
|
74
|
+
try:
|
75
|
+
item = self.previous[0] if self.previous else next(self.it)
|
76
|
+
self.previous = ()
|
77
|
+
return item, self.tracked
|
78
|
+
except StopIteration:
|
79
|
+
self.tracked.close()
|
80
|
+
raise
|
102
81
|
|
103
82
|
def _tracked_iter(self):
|
104
83
|
for x in self.it:
|
@@ -1,7 +1,7 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: checkpointer
|
3
|
-
Version: 2.
|
4
|
-
Summary:
|
3
|
+
Version: 2.11.0
|
4
|
+
Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
|
5
5
|
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
6
|
Author: Hampus Hallman
|
7
7
|
License: Copyright 2018-2025 Hampus Hallman
|
@@ -121,11 +121,8 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
|
|
121
121
|
* **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
|
122
122
|
A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
|
123
123
|
|
124
|
-
* **`
|
125
|
-
|
126
|
-
|
127
|
-
* **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)\
|
128
|
-
An optional parameter that takes an instance of `checkpointer.ObjectHash`. This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any values relevant to your invalidation logic to `ObjectHash` (e.g., `ObjectHash(version_string, config_id, ...)`, as it can consistently hash most Python values.
|
124
|
+
* **`fn_hash_from`** (Type: `Any`, Default: `None`)\
|
125
|
+
This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. Pass any object relevant to your invalidation logic (e.g., a version string or config parameters); `checkpointer` hashes it and uses the result in place of the computed function hash. Because `None` (the default) means "compute the hash automatically", use another value such as `()` if you want a fixed hash that never changes.
|
129
126
|
|
130
127
|
* **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
|
131
128
|
Controls the level of logging output from `checkpointer`.
|
@@ -133,13 +130,45 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
|
|
133
130
|
* `1`: Shows when functions are computed and cached.
|
134
131
|
* `2`: Also shows when cached results are remembered (loaded from cache).
|
135
132
|
|
136
|
-
|
133
|
+
## 🔬 Customize Argument Hashing
|
134
|
+
|
135
|
+
You can customize how individual function arguments contribute to the cache key without changing the values your function actually receives.
|
136
|
+
|
137
|
+
* **`Annotated[T, HashBy[fn]]`**:\
|
138
|
+
Hashes the result of `fn(argument)` instead of the argument itself; the function is still called with the original value. This enables custom normalization (e.g., sorting a list so that element order is ignored) or a cheaper hash for complex types, which can improve cache hit rates or speed up hashing.
|
139
|
+
|
140
|
+
* **`NoHash[T]`**:\
|
141
|
+
Excludes the argument's value from the cache key (it is hashed as `None`), so calls that differ only in this argument share the same cached result.
|
142
|
+
|
143
|
+
**Example:**
|
144
|
+
|
145
|
+
```python
|
146
|
+
from typing import Annotated
|
147
|
+
from checkpointer import checkpoint, HashBy, NoHash
|
148
|
+
from pathlib import Path
|
149
|
+
import logging
|
150
|
+
|
151
|
+
def file_bytes(path: Path) -> bytes:
|
152
|
+
return path.read_bytes()
|
153
|
+
|
154
|
+
@checkpoint
|
155
|
+
def process(
|
156
|
+
numbers: Annotated[list[int], HashBy[sorted]], # Hash by sorted list
|
157
|
+
data_file: Annotated[Path, HashBy[file_bytes]], # Hash by file content
|
158
|
+
log: NoHash[logging.Logger], # Exclude logger from hashing
|
159
|
+
):
|
160
|
+
...
|
161
|
+
```
|
162
|
+
|
163
|
+
In this example, the cache key ignores the order of `numbers`, hashes `data_file` by its contents rather than its path, and is unaffected by which `log` is passed.
|
164
|
+
|
165
|
+
## 🗄️ Custom Storage Backends
|
137
166
|
|
138
167
|
For integration with databases, cloud storage, or custom serialization, implement your own storage backend by inheriting from `checkpointer.Storage` and implementing its abstract methods.
|
139
168
|
|
140
169
|
Within custom storage methods, `call_id` identifies a call by its arguments. Use `self.fn_id()` to get the function's identity (name + hash/version); this is essential for organizing stored checkpoints, e.g. by function version. Access the global `Checkpointer` configuration via `self.checkpointer`.
|
141
170
|
|
142
|
-
|
171
|
+
**Example: Custom Storage Backend**
|
143
172
|
|
144
173
|
```python
|
145
174
|
from checkpointer import checkpoint, Storage
|
@@ -0,0 +1,16 @@
|
|
1
|
+
checkpointer/__init__.py,sha256=ayjFyHwvl_HRHwocY-hOJvAx0Ko5X9IMZrNT4CwfoMU,824
|
2
|
+
checkpointer/checkpoint.py,sha256=zU67_PGrVCMP90clPTR7AA4vfxECqmeFr0jJElUL5iQ,9051
|
3
|
+
checkpointer/fn_ident.py,sha256=-5XbovQowVyYCFc7JdT9z1NoIEiL8h9fi7alF_34Ils,4470
|
4
|
+
checkpointer/object_hash.py,sha256=YlyFupQrg3V2mpzTLfOqpqlZWhoSCHliScQ4cKd36T0,8133
|
5
|
+
checkpointer/print_checkpoint.py,sha256=aJCeWMRJiIR3KpyPk_UOKTaD906kArGrmLGQ3LqcVgo,1369
|
6
|
+
checkpointer/test_checkpointer.py,sha256=-EvsMMNOOiIxhTcG97LLX0jUMWp534ko7qCKDSFWiA0,3802
|
7
|
+
checkpointer/types.py,sha256=rAdjZNn1-jk35df7UVtby_qlp-8_18ucXfVCtS4RI_M,323
|
8
|
+
checkpointer/utils.py,sha256=0cGVSlTnABgs3jI1uHoTfz353kkGa-qtTfe7jG4NCr0,2192
|
9
|
+
checkpointer/storages/__init__.py,sha256=en32nTUltpCSgz8RVGS_leIHC1Y1G89IqG1ZqAb6qUo,236
|
10
|
+
checkpointer/storages/memory_storage.py,sha256=Br30b1AyNOcNjjAaDui1mBjDKfhDbu--jV4WmJenzaE,1109
|
11
|
+
checkpointer/storages/pickle_storage.py,sha256=xS96q8TvwxH_TOZsiKmrMrhFwQKZCcaH7XOiECmuwl8,1628
|
12
|
+
checkpointer/storages/storage.py,sha256=5NWmnfsU2QY24NwKC5MZNe4h7pcYhADw-y800v17aYE,895
|
13
|
+
checkpointer-2.11.0.dist-info/METADATA,sha256=GFRKdyTOIQwvmzIIrkV2owT8NO-X2ONRXdnadY9j5xY,11617
|
14
|
+
checkpointer-2.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
15
|
+
checkpointer-2.11.0.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
|
16
|
+
checkpointer-2.11.0.dist-info/RECORD,,
|
@@ -1,15 +0,0 @@
|
|
1
|
-
checkpointer/__init__.py,sha256=FBdEwdyQS2tc4eQMwrb9qsFgjPu9QOyLOqgDK2c1rAI,837
|
2
|
-
checkpointer/checkpoint.py,sha256=SThQh4cdODw6waf1s3wYhcuHwjAPEmrsUzzsTMpPtow,6827
|
3
|
-
checkpointer/fn_ident.py,sha256=hiPvm1lw9Aol6v-d5bjteWdg0VhrPQH4r1Wiow-Wzpo,4311
|
4
|
-
checkpointer/object_hash.py,sha256=m-GH-hJ7kQ5wOeX8_4eBGdX6gWe8vwuxDt5RP1evoxY,7621
|
5
|
-
checkpointer/print_checkpoint.py,sha256=aJCeWMRJiIR3KpyPk_UOKTaD906kArGrmLGQ3LqcVgo,1369
|
6
|
-
checkpointer/test_checkpointer.py,sha256=GJgM0Cp7V-otr6GDzinmq7lBbm9rngH9CNSqlG1m8VM,3791
|
7
|
-
checkpointer/utils.py,sha256=_zvbkno7221_4l_FbROJMRrs-enx8BRS4FGxn8et9Ns,2751
|
8
|
-
checkpointer/storages/__init__.py,sha256=en32nTUltpCSgz8RVGS_leIHC1Y1G89IqG1ZqAb6qUo,236
|
9
|
-
checkpointer/storages/memory_storage.py,sha256=Br30b1AyNOcNjjAaDui1mBjDKfhDbu--jV4WmJenzaE,1109
|
10
|
-
checkpointer/storages/pickle_storage.py,sha256=xS96q8TvwxH_TOZsiKmrMrhFwQKZCcaH7XOiECmuwl8,1628
|
11
|
-
checkpointer/storages/storage.py,sha256=dR84h-IN644XLGe7invSlzLuXlFlU8QIbsDZa1xdtPM,889
|
12
|
-
checkpointer-2.10.0.dist-info/METADATA,sha256=99lN6dxixf5EMhgnN_9Me5iC3BHh9VnzwbwGGzcx1rk,10919
|
13
|
-
checkpointer-2.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
14
|
-
checkpointer-2.10.0.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
|
15
|
-
checkpointer-2.10.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|