checkpointer 2.11.2__tar.gz → 2.12.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpointer-2.12.0/ATTRIBUTION.md +33 -0
- {checkpointer-2.11.2 → checkpointer-2.12.0}/LICENSE +2 -0
- {checkpointer-2.11.2 → checkpointer-2.12.0}/PKG-INFO +4 -8
- {checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/__init__.py +2 -2
- {checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/checkpoint.py +67 -61
- checkpointer-2.12.0/checkpointer/fn_ident.py +159 -0
- checkpointer-2.12.0/checkpointer/import_mappings.py +47 -0
- {checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/object_hash.py +16 -13
- {checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/storages/__init__.py +6 -4
- checkpointer-2.12.0/checkpointer/types.py +49 -0
- checkpointer-2.12.0/checkpointer/utils.py +115 -0
- {checkpointer-2.11.2 → checkpointer-2.12.0}/pyproject.toml +23 -3
- {checkpointer-2.11.2 → checkpointer-2.12.0}/uv.lock +1 -1
- checkpointer-2.11.2/checkpointer/fn_ident.py +0 -103
- checkpointer-2.11.2/checkpointer/test_checkpointer.py +0 -168
- checkpointer-2.11.2/checkpointer/types.py +0 -21
- checkpointer-2.11.2/checkpointer/utils.py +0 -83
- {checkpointer-2.11.2 → checkpointer-2.12.0}/.gitignore +0 -0
- {checkpointer-2.11.2 → checkpointer-2.12.0}/.python-version +0 -0
- {checkpointer-2.11.2 → checkpointer-2.12.0}/README.md +0 -0
- {checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/print_checkpoint.py +0 -0
- {checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/storages/memory_storage.py +0 -0
- {checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/storages/pickle_storage.py +0 -0
- {checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/storages/storage.py +0 -0
checkpointer-2.12.0/ATTRIBUTION.md
@@ -0,0 +1,33 @@
+# Attribution and License Notices
+
+This project includes code copied or adapted from third-party open-source projects. The following acknowledges the original sources and complies with their licensing requirements.
+
+---
+
+## Third-Party Code
+
+### more-itertools
+- **Source:** https://github.com/more-itertools/more-itertools
+- **Author:** Erik Rose
+- **Copyright:** (c) 2012 Erik Rose
+- **License:** MIT (https://github.com/more-itertools/more-itertools/blob/master/LICENSE)
+
+### colored
+- **Source:** https://gitlab.com/dslackw/colored
+- **Author:** Dimitris Zlatanidis
+- **Copyright:** (c) 2014-2025 Dimitris Zlatanidis
+- **License:** MIT (https://gitlab.com/dslackw/colored/-/blob/master/LICENSE.txt)
+
+---
+
+## License
+
+This project is licensed under the MIT License. See the `LICENSE` file for details.
+
+---
+
+## Notes
+
+- Third-party code is included under their original MIT licenses.
+- This file documents those license notices, fulfilling attribution obligations.
+- Source files with copied code may omit individual license headers in favor of this centralized attribution.
{checkpointer-2.11.2 → checkpointer-2.12.0}/LICENSE
@@ -1,3 +1,5 @@
+MIT License
+
 Copyright 2018-2025 Hampus Hallman
 
 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
{checkpointer-2.11.2 → checkpointer-2.12.0}/PKG-INFO
@@ -1,17 +1,13 @@
 Metadata-Version: 2.4
 Name: checkpointer
-Version: 2.11.2
+Version: 2.12.0
 Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
 Project-URL: Repository, https://github.com/Reddan/checkpointer.git
 Author: Hampus Hallman
-License:
-
-Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+License-Expression: MIT
+License-File: ATTRIBUTION.md
 License-File: LICENSE
+Keywords: async,cache,caching,code-aware,decorator,fast,hashing,invalidation,memoization,memoize,memory,optimization,performance,workflow
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
{checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/__init__.py
@@ -1,10 +1,10 @@
 import gc
 import tempfile
 from typing import Callable
-from .checkpoint import CachedFunction, Checkpointer, CheckpointError
+from .checkpoint import CachedFunction, Checkpointer, CheckpointError, FunctionIdent
 from .object_hash import ObjectHash
 from .storages import MemoryStorage, PickleStorage, Storage
-from .types import AwaitableValue, HashBy, NoHash
+from .types import AwaitableValue, Captured, CapturedOnce, CaptureMe, CaptureMeOnce, HashBy, NoHash
 
 checkpoint = Checkpointer()
 capture_checkpoint = Checkpointer(capture=True)
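Context for the changed exports above: the newly exported capture annotations complement the existing capture_checkpoint entry point, which hashes module-level globals that a cached function reads. A minimal usage sketch, not taken from the package itself (the names multiplier and scale are illustrative):

    from checkpointer import capture_checkpoint

    multiplier = 3  # read by the function below, so it is captured into the call hash

    @capture_checkpoint
    def scale(x: int) -> int:
      return x * multiplier

    scale(2)  # editing multiplier later yields a different call hash, so the old result is not reused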
{checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/checkpoint.py
@@ -2,29 +2,26 @@ from __future__ import annotations
 import re
 from datetime import datetime
 from functools import cached_property, update_wrapper
-from inspect import Parameter,
+from inspect import Parameter, iscoroutine, signature, unwrap
+from itertools import chain
 from pathlib import Path
 from typing import (
-
-
-  Unpack, cast, get_args, get_origin, overload,
+  Callable, Concatenate, Coroutine, Generic, Iterable,
+  Literal, Self, Type, TypedDict, Unpack, cast, overload,
 )
-from .fn_ident import RawFunctionIdent, get_fn_ident
+from .fn_ident import Capturable, RawFunctionIdent, get_fn_ident
 from .object_hash import ObjectHash
 from .print_checkpoint import print_checkpoint
-from .storages import STORAGE_MAP, Storage
-from .types import AwaitableValue, C, Coro, Fn,
-from .utils import unwrap_fn
+from .storages import STORAGE_MAP, Storage, StorageType
+from .types import AwaitableValue, C, Coro, Fn, P, R, hash_by_from_annotation
 
 DEFAULT_DIR = Path.home() / ".cache/checkpoints"
 
-empty_set = cast(set, frozenset())
-
 class CheckpointError(Exception):
   pass
 
 class CheckpointerOpts(TypedDict, total=False):
-  format: Type[Storage] |
+  format: Type[Storage] | StorageType
   root_path: Path | str | None
   when: bool
   verbosity: Literal[0, 1, 2]
@@ -63,28 +60,52 @@ class FunctionIdent:
     self.__dict__.clear()
     self.cached_fn = cached_fn
 
+  def reset(self):
+    self.__init__(self.cached_fn)
+
+  def is_static(self) -> bool:
+    return self.cached_fn.checkpointer.fn_hash_from is not None
+
   @cached_property
   def raw_ident(self) -> RawFunctionIdent:
-    return get_fn_ident(
+    return get_fn_ident(unwrap(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
 
   @cached_property
   def fn_hash(self) -> str:
-    if
-      return str(ObjectHash(
-
+    if self.is_static():
+      return str(ObjectHash(self.cached_fn.checkpointer.fn_hash_from, digest_size=16))
+    depends = self.deep_idents(past_static=False)
+    deep_hashes = [d.fn_hash if d.is_static() else d.raw_ident.fn_hash for d in depends]
     return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
 
   @cached_property
-  def
-
-
+  def capturables(self) -> list[Capturable]:
+    return sorted({
+      capturable.key: capturable
+      for depend in self.deep_idents()
+      for capturable in depend.raw_ident.capturables
+    }.values())
+
+  def deep_depends(self, past_static=True, visited: set[Callable] = set()) -> Iterable[Callable]:
+    if self.cached_fn not in visited:
+      yield self.cached_fn
+      visited = visited or set()
+      visited.add(self.cached_fn)
+      stop = not past_static and self.is_static()
+      depends = [] if stop else self.raw_ident.depends
+      for depend in depends:
+        if isinstance(depend, CachedFunction):
+          yield from depend.ident.deep_depends(past_static, visited)
+        elif depend not in visited:
+          yield depend
+          visited.add(depend)
 
-  def
-    self.
+  def deep_idents(self, past_static=True) -> Iterable[FunctionIdent]:
+    return (fn.ident for fn in self.deep_depends(past_static) if isinstance(fn, CachedFunction))
 
 class CachedFunction(Generic[Fn]):
   def __init__(self, checkpointer: Checkpointer, fn: Fn):
-    wrapped =
+    wrapped = unwrap(fn)
     fn_file = Path(wrapped.__code__.co_filename).name
     fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
     Storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
@@ -95,20 +116,14 @@ class CachedFunction(Generic[Fn]):
     self.storage = Storage(self)
     self.cleanup = self.storage.cleanup
     self.bound = ()
-    self.attrname: str | None = None
 
-
-    params = list(sig.parameters.items())
+    params = list(signature(wrapped).parameters.values())
     pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
-    self.arg_names = [name for
-    self.default_args = {name: param.default for
-    self.hash_by_map = get_hash_by_map(
+    self.arg_names = [param.name for param in params if param.kind in pos_params]
+    self.default_args = {param.name: param.default for param in params if param.default is not Parameter.empty}
+    self.hash_by_map = get_hash_by_map(params)
     self.ident = FunctionIdent(self)
 
-  def __set_name__(self, _, name: str):
-    assert self.attrname is None
-    self.attrname = name
-
   @overload
   def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
   @overload
@@ -116,12 +131,9 @@ class CachedFunction(Generic[Fn]):
   def __get__(self, instance, owner):
     if instance is None:
       return self
-    assert self.attrname is not None
     bound_fn = object.__new__(CachedFunction)
     bound_fn.__dict__ |= self.__dict__
     bound_fn.bound = (instance,)
-    if hasattr(instance, "__dict__"):
-      setattr(instance, self.attrname, bound_fn)
     return bound_fn
 
   @property
@@ -129,12 +141,12 @@ class CachedFunction(Generic[Fn]):
     return self.ident.raw_ident.depends
 
   def reinit(self, recursive=False) -> CachedFunction[Fn]:
-    depend_idents =
+    depend_idents = list(self.ident.deep_idents()) if recursive else [self.ident]
    for ident in depend_idents: ident.reset()
     for ident in depend_idents: ident.fn_hash
     return self
 
-  def
+  def _get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
     args = self.bound + args
     pos_args = args[len(self.arg_names):]
     named_pos_args = dict(zip(self.arg_names, args))
@@ -145,8 +157,17 @@ class CachedFunction(Generic[Fn]):
       if hash_by := hash_by_map.get(key, rest_hash_by):
         named_args[key] = hash_by(value)
     if pos_hash_by := hash_by_map.get(b"*"):
-      pos_args =
-
+      pos_args = map(pos_hash_by, pos_args)
+    named_args_iter = chain.from_iterable(sorted(named_args.items()))
+    captured = chain.from_iterable(capturable.capture() for capturable in self.ident.capturables)
+    obj_hash = ObjectHash(digest_size=16) \
+      .update(iter=named_args_iter, header="NAMED") \
+      .update(iter=pos_args, header="POS") \
+      .update(iter=captured, header="CAPTURED")
+    return str(obj_hash)
+
+  def get_call_hash(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> str:
+    return self._get_call_hash(args, kw)
 
   async def _resolve_coroutine(self, call_hash: str, coroutine: Coroutine):
     return self.storage.store(call_hash, AwaitableValue(await coroutine)).value
@@ -157,7 +178,7 @@ class CachedFunction(Generic[Fn]):
     if not params.when:
       return self.fn(*full_args, **kw)
 
-    call_hash = self.
+    call_hash = self._get_call_hash(args, kw)
     call_id = f"{self.storage.fn_id()}/{call_hash}"
     refresh = rerun \
       or not self.storage.exists(call_hash) \
@@ -186,17 +207,17 @@ class CachedFunction(Generic[Fn]):
     return self._call(args, kw, True)
 
   def exists(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> bool:
-    return self.storage.exists(self.
+    return self.storage.exists(self._get_call_hash(args, kw))
 
   def delete(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs):
-    self.storage.delete(self.
+    self.storage.delete(self._get_call_hash(args, kw))
 
   @overload
   def get(self: Callable[P, Coro[R]], *args: P.args, **kw: P.kwargs) -> R: ...
   @overload
   def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
   def get(self, *args, **kw):
-    call_hash = self.
+    call_hash = self._get_call_hash(args, kw)
     try:
       data = self.storage.load(call_hash)
       return data.value if isinstance(data, AwaitableValue) else data
@@ -208,30 +229,15 @@ class CachedFunction(Generic[Fn]):
   @overload
   def set(self: Callable[P, R], value: R, *args: P.args, **kw: P.kwargs): ...
   def set(self, value, *args, **kw):
-    self.storage.store(self.
+    self.storage.store(self._get_call_hash(args, kw), value)
 
   def __repr__(self) -> str:
     return f"<CachedFunction {self.fn.__name__} {self.ident.fn_hash[:6]}>"
 
-
-    if self not in visited:
-      yield self
-      visited = visited or set()
-      visited.add(self)
-      for depend in self.depends:
-        if isinstance(depend, CachedFunction):
-          yield from depend.deep_depends(visited)
-
-def hash_by_from_annotation(annotation: type) -> Callable[[object], object] | None:
-  if get_origin(annotation) is Annotated:
-    args = get_args(annotation)
-    metadata = args[1] if len(args) > 1 else None
-    if get_origin(metadata) is HashBy:
-      return get_args(metadata)[0]
-
-def get_hash_by_map(sig: Signature) -> dict[str | bytes, Callable[[object], object]]:
+def get_hash_by_map(params: list[Parameter]) -> dict[str | bytes, Callable[[object], object]]:
   hash_by_map = {}
-  for
+  for param in params:
+    name = param.name
     if param.kind == Parameter.VAR_POSITIONAL:
      name = b"*"
    elif param.kind == Parameter.VAR_KEYWORD:
checkpointer-2.12.0/checkpointer/fn_ident.py
@@ -0,0 +1,159 @@
+import dis
+from inspect import Parameter, getmodule, signature, unwrap
+from types import CodeType, MethodType, ModuleType
+from typing import Annotated, Callable, Iterable, NamedTuple, Type, get_args, get_origin
+from .import_mappings import resolve_annotation
+from .object_hash import ObjectHash
+from .types import hash_by_from_annotation, is_capture_me, is_capture_me_once, to_none
+from .utils import (
+  AttrDict, cwd, distinct, get_cell_contents,
+  get_file, is_class, is_user_fn, seekable, takewhile,
+)
+
+AttrPath = tuple[str, ...]
+CapturableByFn = dict[Callable, list["Capturable"]]
+
+class RawFunctionIdent(NamedTuple):
+  fn_hash: str
+  depends: list[Callable]
+  capturables: set["Capturable"]
+
+class Capturable(NamedTuple):
+  key: str
+  module: ModuleType
+  attr_path: AttrPath
+  hash_by: Callable | None
+  hash: str | None = None
+
+  def capture(self) -> tuple[str, object]:
+    if obj := self.hash:
+      return self.key, obj
+    obj = AttrDict.get_at(self.module, *self.attr_path)
+    obj = self.hash_by(obj) if self.hash_by else obj
+    return self.key, obj
+
+  @staticmethod
+  def new(module: ModuleType, attr_path: AttrPath, hash_by: Callable | None, capture_once: bool) -> "Capturable":
+    file = str(get_file(module).relative_to(cwd))
+    key = "-".join((file, *attr_path))
+    cap = Capturable(key, module, attr_path, hash_by)
+    if not capture_once:
+      return cap
+    obj_hash = str(ObjectHash(cap.capture()[1]))
+    return Capturable(key, module, attr_path, None, obj_hash)
+
+def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
+  attr_path = AttrPath(())
+  scope_obj = None
+  classvars: dict[str, dict[str, Type]] = {}
+  instructs = seekable(dis.get_instructions(code))
+  for instr in instructs:
+    if instr.opname in scope_vars and not attr_path:
+      attrs = takewhile((x.opname == "LOAD_ATTR", x.argval) for x in instructs)
+      attr_path = AttrPath((instr.opname, instr.argval, *attrs))
+      instructs.step(-1)
+    elif instr.opname == "CALL":
+      obj = scope_vars.get_at(*attr_path)
+      attr_path = AttrPath(())
+      if is_class(obj):
+        scope_obj = obj
+    elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
+      load_key = instr.opname.replace("STORE", "LOAD")
+      classvars.setdefault(load_key, {})[instr.argval] = scope_obj
+      scope_obj = None
+  return classvars
+
+def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[AttrPath, object]]:
+  classvars = extract_classvars(code, scope_vars)
+  scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
+  instructs = seekable(dis.get_instructions(code))
+  for instr in instructs:
+    if instr.opname in scope_vars:
+      attrs = takewhile((x.opname in ("LOAD_ATTR", "LOAD_METHOD"), x.argval) for x in instructs)
+      attr_path = AttrPath((instr.opname, instr.argval, *attrs))
+      parent_path = attr_path[:-1]
+      instructs.step(-1)
+      obj = scope_vars.get_at(*attr_path)
+      if obj is not None:
+        yield attr_path, obj
+      if callable(obj) and parent_path[1:]:
+        parent_obj = scope_vars.get_at(*parent_path)
+        yield parent_path, parent_obj
+  for const in code.co_consts:
+    if isinstance(const, CodeType):
+      next_deref = scope_vars.LOAD_DEREF.set(scope_vars.LOAD_FAST)
+      next_scope_vars = AttrDict({**scope_vars, "LOAD_FAST": {}, "LOAD_DEREF": next_deref})
+      yield from extract_scope_values(const, next_scope_vars)
+
+def resolve_class_annotations(anno: object) -> Type | None:
+  if anno in (None, Annotated):
+    return None
+  elif is_class(anno):
+    return anno
+  elif get_origin(anno) is Annotated:
+    return resolve_class_annotations(next(iter(get_args(anno)), None))
+  return resolve_class_annotations(get_origin(anno))
+
+def get_self_value(fn: Callable) -> type | object | None:
+  if isinstance(fn, MethodType):
+    return fn.__self__
+  parts = fn.__qualname__.split(".")[:-1]
+  cls = parts and AttrDict(fn.__globals__).get_at(*parts)
+  if is_class(cls):
+    return cls
+
+def get_capturables(fn: Callable, capture: bool, captured_vars: dict[AttrPath, object]) -> Iterable[Capturable]:
+  module = getmodule(fn)
+  if not module or not is_user_fn(fn):
+    return
+  for (instruct, *attr_path), obj in captured_vars.items():
+    attr_path = AttrPath(attr_path)
+    if instruct == "LOAD_GLOBAL" and not callable(obj) and not isinstance(obj, ModuleType):
+      anno = resolve_annotation(module, ".".join(attr_path))
+      if capture or is_capture_me(anno) or is_capture_me_once(anno):
+        hash_by = hash_by_from_annotation(anno)
+        if hash_by is not to_none:
+          yield Capturable.new(module, attr_path, hash_by, is_capture_me_once(anno))
+
+def get_fn_captures(fn: Callable, capture: bool) -> tuple[list[Callable], list[Capturable]]:
+  sig_scope = {
+    param.name: class_anno
+    for param in signature(fn).parameters.values()
+    if param.annotation is not Parameter.empty
+    if param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
+    if (class_anno := resolve_class_annotations(param.annotation))
+  }
+  self_value = get_self_value(fn)
+  scope_vars = AttrDict({
+    "LOAD_FAST": AttrDict({**sig_scope, "self": self_value} if self_value else sig_scope),
+    "LOAD_DEREF": AttrDict(get_cell_contents(fn)),
+    "LOAD_GLOBAL": AttrDict(fn.__globals__),
+  })
+  captured_vars = dict(extract_scope_values(fn.__code__, scope_vars))
+  captured_callables = [obj for obj in captured_vars.values() if callable(obj)]
+  capturables = list(get_capturables(fn, capture, captured_vars))
+  return captured_callables, capturables
+
+def get_depend_fns(fn: Callable, capture: bool, capturable_by_fn: CapturableByFn = {}) -> CapturableByFn:
+  from .checkpoint import CachedFunction
+  captured_callables, capturables = get_fn_captures(fn, capture)
+  capturable_by_fn = capturable_by_fn or {}
+  capturable_by_fn[fn] = capturables
+  for depend_fn in captured_callables:
+    depend_fn = unwrap(depend_fn, stop=lambda f: isinstance(f, CachedFunction))
+    if isinstance(depend_fn, CachedFunction):
+      capturable_by_fn[depend_fn] = []
+    elif depend_fn not in capturable_by_fn and is_user_fn(depend_fn):
+      get_depend_fns(depend_fn, capture, capturable_by_fn)
+  return capturable_by_fn
+
+def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
+  from .checkpoint import CachedFunction
+  capturable_by_fn = get_depend_fns(fn, capture)
+  capturables = {capt for capts in capturable_by_fn.values() for capt in capts}
+  depends = capturable_by_fn.keys()
+  depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
+  unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
+  assert fn == unwrapped_depends[0]
+  fn_hash = str(ObjectHash(iter=unwrapped_depends))
+  return RawFunctionIdent(fn_hash, depends, capturables)
checkpointer-2.12.0/checkpointer/import_mappings.py
@@ -0,0 +1,47 @@
+import ast
+import inspect
+import sys
+from types import ModuleType
+from typing import Iterable, Type
+from .utils import cwd, get_file, is_user_file
+
+ImportTarget = tuple[str, str | None]
+
+cache: dict[tuple[str, int], dict[str, ImportTarget]] = {}
+
+def generate_import_mappings(module: ModuleType) -> Iterable[tuple[str, ImportTarget]]:
+  mod_path = get_file(module)
+  if not is_user_file(mod_path):
+    return
+  mod_parts = list(mod_path.with_suffix("").relative_to(cwd).parts)
+  source = inspect.getsource(module)
+  tree = ast.parse(source)
+  for node in ast.walk(tree):
+    if isinstance(node, ast.Import):
+      for alias in node.names:
+        yield (alias.asname or alias.name, (alias.name, None))
+    elif isinstance(node, ast.ImportFrom):
+      target_mod = node.module or ""
+      if node.level > 0:
+        target_mod_parts = target_mod.split(".") * bool(target_mod)
+        target_mod_parts = mod_parts[:-node.level] + target_mod_parts
+        target_mod = ".".join(target_mod_parts)
+      for alias in node.names:
+        yield (alias.asname or alias.name, (target_mod, alias.name))
+
+def get_import_mappings(module: ModuleType) -> dict[str, ImportTarget]:
+  cache_key = (module.__name__, id(module))
+  if cached := cache.get(cache_key):
+    return cached
+  import_mappings = dict(generate_import_mappings(module))
+  return cache.setdefault(cache_key, import_mappings)
+
+def resolve_annotation(module: ModuleType, attr_name: str | None) -> Type | None:
+  if not attr_name:
+    return None
+  if anno := module.__annotations__.get(attr_name):
+    return anno
+  if next_pair := get_import_mappings(module).get(attr_name):
+    next_module_name, next_attr_name = next_pair
+    if next_module := sys.modules.get(next_module_name):
+      return resolve_annotation(next_module, next_attr_name)
{checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/object_hash.py
@@ -12,7 +12,7 @@ from io import StringIO
 from itertools import chain
 from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
 from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
-from typing import Callable, TypeVar
+from typing import Callable, Self, TypeVar
 from .utils import ContextVar
 
 np, torch = None, None
@@ -43,14 +43,14 @@ class ObjectHashError(Exception):
     self.obj = obj
 
 class ObjectHash:
-  def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64,
+  def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerable=False) -> None:
     self.hash = hashlib.blake2b(digest_size=digest_size)
     self.current: dict[int, int] = {}
-    self.
+    self.tolerable = ContextVar(tolerable)
     self.update(iter=chain(objs, iter))
 
   def copy(self) -> "ObjectHash":
-    new = ObjectHash(
+    new = ObjectHash(tolerable=self.tolerable.value)
     new.hash = self.hash.copy()
     return new
 
@@ -63,26 +63,29 @@ class ObjectHash:
     return isinstance(value, ObjectHash) and str(self) == str(value)
 
   def nested_hash(self, *objs: object) -> str:
-    return ObjectHash(iter=objs,
+    return ObjectHash(iter=objs, tolerable=self.tolerable.value).hexdigest()
 
-  def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) ->
+  def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> Self:
     for d in chain(data, iter):
       self.hash.update(d)
     return self
 
-  def write_text(self, *data: str, iter: Iterable[str] = ()) ->
+  def write_text(self, *data: str, iter: Iterable[str] = ()) -> Self:
     return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))
 
-  def header(self, *args: object) ->
+  def header(self, *args: object) -> Self:
     return self.write_bytes(":".join(map(str, args)).encode())
 
-  def update(self, *objs: object, iter: Iterable[object] = (),
-    with nullcontext() if
+  def update(self, *objs: object, iter: Iterable[object] = (), tolerable: bool | None=None, header: str | None = None) -> Self:
+    with nullcontext() if tolerable is None else self.tolerable.set(tolerable):
       for obj in chain(objs, iter):
+        if header is not None:
+          self.write_bytes(header.encode())
+          header = None
         try:
           self._update_one(obj)
         except Exception as ex:
-          if self.
+          if self.tolerable.value:
            self.header("error").update(type(ex))
           else:
             raise ObjectHashError(obj, ex) from ex
@@ -180,10 +183,10 @@ class ObjectHash:
     finally:
       del self.current[id(obj)]
 
-  def _update_iterator(self, obj: Iterable) ->
+  def _update_iterator(self, obj: Iterable) -> Self:
     return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")
 
-  def _update_object(self, obj: object) ->
+  def _update_object(self, obj: object) -> Self:
     self.header("instance", encode_type_of(obj))
     get_hash = hasattr(obj, "__objecthash__") and getattr(obj, "__objecthash__")
     if callable(get_hash):
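In object_hash.py the old tolerate_errors flag becomes tolerable (now backed by a ContextVar), and update() gains an optional header used by _get_call_hash above to separate its NAMED, POS and CAPTURED sections. A small sketch of the resulting API (assumed usage, not taken from the diff):

    from checkpointer import ObjectHash

    a = str(ObjectHash(digest_size=16).update(iter=["x", 4], header="NAMED"))
    b = str(ObjectHash(digest_size=16).update(iter=["x", 4], header="NAMED"))
    assert a == b  # hashing is deterministic for equal inputs
    # With tolerable=True, objects that fail to hash contribute an "error" marker
    # plus the exception type instead of raising ObjectHashError.
    tolerant = ObjectHash(tolerable=True)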
{checkpointer-2.11.2 → checkpointer-2.12.0}/checkpointer/storages/__init__.py
@@ -1,9 +1,11 @@
-from typing import Type
-from .storage import Storage
-from .pickle_storage import PickleStorage
+from typing import Literal, Type
 from .memory_storage import MemoryStorage
+from .pickle_storage import PickleStorage
+from .storage import Storage
+
+StorageType = Literal["pickle", "memory"]
 
-STORAGE_MAP: dict[
+STORAGE_MAP: dict[StorageType, Type[Storage]] = {
   "pickle": PickleStorage,
   "memory": MemoryStorage,
 }
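With StorageType defined here, the format option (see CheckpointerOpts above) accepts either the "pickle"/"memory" literal or a Storage subclass. A brief illustrative sketch, not taken from the diff:

    from checkpointer import Checkpointer, MemoryStorage

    mem = Checkpointer(format="memory")        # StorageType literal
    same = Checkpointer(format=MemoryStorage)  # or a Storage subclass directly

    @mem
    def double(x: int) -> int:
      return x * 2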
checkpointer-2.12.0/checkpointer/types.py
@@ -0,0 +1,49 @@
+from typing import (
+  Annotated, Callable, Coroutine, Generic,
+  ParamSpec, TypeVar, get_args, get_origin,
+)
+
+Fn = TypeVar("Fn", bound=Callable)
+P = ParamSpec("P")
+R = TypeVar("R")
+C = TypeVar("C")
+T = TypeVar("T")
+
+class HashBy(Generic[Fn]):
+  pass
+
+class Captured:
+  pass
+
+class CapturedOnce:
+  pass
+
+def to_none(_):
+  return None
+
+def get_annotated_args(anno: object) -> tuple[object, ...]:
+  return get_args(anno) if get_origin(anno) is Annotated else ()
+
+def hash_by_from_annotation(anno: object) -> Callable[[object], object] | None:
+  for arg in get_annotated_args(anno):
+    if get_origin(arg) is HashBy:
+      return get_args(arg)[0]
+
+def is_capture_me(anno: object) -> bool:
+  return Captured in get_annotated_args(anno)
+
+def is_capture_me_once(anno: object) -> bool:
+  return CapturedOnce in get_annotated_args(anno)
+
+NoHash = Annotated[T, HashBy[to_none]]
+CaptureMe = Annotated[T, Captured]
+CaptureMeOnce = Annotated[T, CapturedOnce]
+Coro = Coroutine[object, object, R]
+
+class AwaitableValue(Generic[T]):
+  def __init__(self, value: T):
+    self.value = value
+
+  def __await__(self):
+    yield
+    return self.value
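These definitions drive the capture and hashing annotations used by fn_ident.py: CaptureMe and CaptureMeOnce mark module globals for capture without enabling capture=True globally, while HashBy and NoHash customize how an argument contributes to the call hash. A minimal sketch based on these definitions (variable and function names are illustrative):

    from typing import Annotated
    from checkpointer import checkpoint, CaptureMe, HashBy, NoHash

    settings: CaptureMe[dict] = {"mode": "fast"}  # captured: editing it changes the call hash

    @checkpoint
    def process(items: Annotated[list, HashBy[sorted]], log_path: NoHash[str]) -> int:
      # items is hashed via sorted(items); log_path does not affect the call hash
      return len(items) if settings["mode"] == "fast" else 0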
checkpointer-2.12.0/checkpointer/utils.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+import inspect
+from contextlib import contextmanager
+from itertools import islice
+from pathlib import Path
+from types import FunctionType, MethodType, ModuleType
+from typing import Callable, Generic, Iterable, Self, Type, TypeGuard
+from .types import T
+
+cwd = Path.cwd().resolve()
+
+def is_class(obj) -> TypeGuard[Type]:
+  return isinstance(obj, type)
+
+def get_file(obj: Callable | ModuleType) -> Path:
+  return Path(inspect.getfile(obj)).resolve()
+
+def is_user_file(path: Path) -> bool:
+  return cwd in path.parents and ".venv" not in path.parts
+
+def is_user_fn(obj) -> TypeGuard[Callable]:
+  return isinstance(obj, (FunctionType, MethodType)) and is_user_file(get_file(obj))
+
+def get_cell_contents(fn: Callable) -> Iterable[tuple[str, object]]:
+  for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
+    try:
+      yield (key, cell.cell_contents)
+    except ValueError:
+      pass
+
+def distinct(seq: Iterable[T]) -> list[T]:
+  return list(dict.fromkeys(seq))
+
+def takewhile(iter: Iterable[tuple[bool, T]]) -> Iterable[T]:
+  for condition, value in iter:
+    if not condition:
+      return
+    yield value
+
+class seekable(Generic[T]):
+  def __init__(self, iterable: Iterable[T]):
+    self.index = 0
+    self.source = iter(iterable)
+    self.sink: list[T] = []
+
+  def __iter__(self):
+    return self
+
+  def __next__(self) -> T:
+    if len(self.sink) > self.index:
+      item = self.sink[self.index]
+    else:
+      item = next(self.source)
+      self.sink.append(item)
+    self.index += 1
+    return item
+
+  def __bool__(self):
+    return bool(self.lookahead(1))
+
+  def seek(self, index: int) -> Self:
+    remainder = index - len(self.sink)
+    if remainder > 0:
+      next(islice(self, remainder, remainder), None)
+    self.index = max(0, min(index, len(self.sink)))
+    return self
+
+  def step(self, count: int) -> Self:
+    return self.seek(self.index + count)
+
+  @contextmanager
+  def freeze(self):
+    initial_index = self.index
+    try:
+      yield
+    finally:
+      self.seek(initial_index)
+
+  def lookahead(self, count: int) -> list[T]:
+    with self.freeze():
+      return list(islice(self, count))
+
+class AttrDict(dict):
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self.__dict__ = self
+
+  def __getattribute__(self, name: str):
+    return super().__getattribute__(name)
+
+  def __setattr__(self, name: str, value: object):
+    super().__setattr__(name, value)
+
+  def set(self, d: dict) -> AttrDict:
+    if not d:
+      return self
+    return AttrDict({**self, **d})
+
+  def get_at(self: object, *attrs: str) -> object:
+    obj = self
+    for attr in attrs:
+      obj = getattr(obj, attr, None)
+    return obj
+
+class ContextVar(Generic[T]):
+  def __init__(self, value: T):
+    self.value = value
+
+  @contextmanager
+  def set(self, value: T):
+    self.value, old = value, self.value
+    try:
+      yield
+    finally:
+      self.value = old
{checkpointer-2.11.2 → checkpointer-2.12.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "checkpointer"
-version = "2.11.2"
+version = "2.12.0"
 requires-python = ">=3.11"
 dependencies = []
 authors = [
@@ -8,12 +8,29 @@ authors = [
 ]
 description = "checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes."
 readme = "README.md"
-license =
+license = "MIT"
+license-files = ["LICENSE", "ATTRIBUTION.md"]
 classifiers = [
   "Programming Language :: Python :: 3.11",
   "Programming Language :: Python :: 3.12",
   "Programming Language :: Python :: 3.13",
 ]
+keywords = [
+  "cache",
+  "caching",
+  "memoize",
+  "memoization",
+  "performance",
+  "fast",
+  "memory",
+  "invalidation",
+  "code-aware",
+  "optimization",
+  "hashing",
+  "async",
+  "workflow",
+  "decorator",
+]
 
 [project.urls]
 Repository = "https://github.com/Reddan/checkpointer.git"
@@ -30,7 +47,7 @@ dev = [
 ]
 
 [tool.poe.tasks]
-tests = "pytest checkpointer
+tests = "pytest checkpointer/**/test_*.py"
 tests-debug = "poe tests -s"
 
 [build-system]
@@ -40,5 +57,8 @@ build-backend = "hatchling.build"
 [tool.hatch.build.targets.wheel]
 packages = ["checkpointer", "checkpointer.storages"]
 
+[tool.hatch.build]
+exclude = ["test_*.py"]
+
 [tool.pytest.ini_options]
 asyncio_default_fixture_loop_scope = "session"
checkpointer-2.11.2/checkpointer/fn_ident.py (removed)
@@ -1,103 +0,0 @@
-import dis
-import inspect
-from itertools import takewhile
-from pathlib import Path
-from types import CodeType, FunctionType, MethodType
-from typing import Callable, Iterable, NamedTuple, Type, TypeGuard
-from .object_hash import ObjectHash
-from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, unwrap_fn
-
-cwd = Path.cwd().resolve()
-
-class RawFunctionIdent(NamedTuple):
-  fn_hash: str
-  captured_hash: str
-  depends: list[Callable]
-
-def is_class(obj) -> TypeGuard[Type]:
-  # isinstance works too, but needlessly triggers _lazyinit()
-  return issubclass(type(obj), type)
-
-def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
-  attr_path: tuple[str, ...] = ()
-  scope_obj = None
-  classvars: dict[str, dict[str, Type]] = {}
-  for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
-    if instr.opname in scope_vars and not attr_path:
-      attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
-      attr_path = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
-    elif instr.opname == "CALL":
-      obj = scope_vars.get_at(attr_path)
-      attr_path = ()
-      if is_class(obj):
-        scope_obj = obj
-    elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
-      load_key = instr.opname.replace("STORE", "LOAD")
-      classvars.setdefault(load_key, {})[instr.argval] = scope_obj
-      scope_obj = None
-  return classvars
-
-def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...], object]]:
-  classvars = extract_classvars(code, scope_vars)
-  scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
-  for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
-    if instr.opname in scope_vars:
-      attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
-      attr_path: tuple[str, ...] = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
-      val = scope_vars.get_at(attr_path)
-      if val is not None:
-        yield attr_path, val
-  for const in code.co_consts:
-    if isinstance(const, CodeType):
-      yield from extract_scope_values(const, scope_vars)
-
-def get_self_value(fn: Callable) -> type | object | None:
-  if isinstance(fn, MethodType):
-    return fn.__self__
-  parts = tuple(fn.__qualname__.split(".")[:-1])
-  cls = parts and AttrDict(fn.__globals__).get_at(parts)
-  if is_class(cls):
-    return cls
-
-def get_fn_captured_vals(fn: Callable) -> list[object]:
-  self_value = get_self_value(fn)
-  scope_vars = AttrDict({
-    "LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
-    "LOAD_DEREF": AttrDict(get_cell_contents(fn)),
-    "LOAD_GLOBAL": AttrDict(fn.__globals__),
-  })
-  vals = dict(extract_scope_values(fn.__code__, scope_vars))
-  return list(vals.values())
-
-def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
-  if not isinstance(candidate_fn, (FunctionType, MethodType)):
-    return False
-  fn_path = Path(inspect.getfile(candidate_fn)).resolve()
-  return cwd in fn_path.parents and ".venv" not in fn_path.parts
-
-def get_depend_fns(fn: Callable, captured_vals_by_fn: dict[Callable, list[object]] = {}) -> dict[Callable, list[object]]:
-  from .checkpoint import CachedFunction
-  captured_vals = get_fn_captured_vals(fn)
-  captured_vals_by_fn = captured_vals_by_fn or {}
-  captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)]
-  for val in captured_vals:
-    if not callable(val):
-      continue
-    child_fn = unwrap_fn(val, cached_fn=True)
-    if isinstance(child_fn, CachedFunction):
-      captured_vals_by_fn[child_fn] = []
-    elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
-      get_depend_fns(child_fn, captured_vals_by_fn)
-  return captured_vals_by_fn
-
-def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
-  from .checkpoint import CachedFunction
-  captured_vals_by_fn = get_depend_fns(fn)
-  depend_captured_vals = list(captured_vals_by_fn.values()) * capture
-  depends = captured_vals_by_fn.keys()
-  depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
-  unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
-  assert fn == unwrapped_depends[0]
-  fn_hash = str(ObjectHash(iter=unwrapped_depends))
-  captured_hash = str(ObjectHash(iter=depend_captured_vals, tolerate_errors=True))
-  return RawFunctionIdent(fn_hash, captured_hash, depends)
checkpointer-2.11.2/checkpointer/test_checkpointer.py (removed)
@@ -1,168 +0,0 @@
-import asyncio
-import pytest
-from checkpointer import CheckpointError, checkpoint
-from .utils import AttrDict
-
-def global_multiply(a: int, b: int) -> int:
-  return a * b
-
-@pytest.fixture(autouse=True)
-def run_before_and_after_tests(tmpdir):
-  global checkpoint
-  checkpoint = checkpoint(root_path=tmpdir)
-  yield
-
-def test_basic_caching():
-  @checkpoint
-  def square(x: int) -> int:
-    return x ** 2
-
-  result1 = square(4)
-  result2 = square(4)
-
-  assert result1 == result2 == 16
-
-def test_cache_invalidation():
-  @checkpoint
-  def multiply(a: int, b: int):
-    return a * b
-
-  @checkpoint
-  def helper(x: int):
-    return multiply(x + 1, 2)
-
-  @checkpoint
-  def compute(a: int, b: int):
-    return helper(a) + helper(b)
-
-  result1 = compute(3, 4)
-  assert result1 == 18
-
-def test_layered_caching():
-  dev_checkpoint = checkpoint(when=True)
-
-  @checkpoint(format="memory")
-  @dev_checkpoint
-  def expensive_function(x: int):
-    return x ** 2
-
-  assert expensive_function(4) == 16
-  assert expensive_function(4) == 16
-
-def test_recursive_caching1():
-  @checkpoint
-  def fib(n: int) -> int:
-    return fib(n - 1) + fib(n - 2) if n > 1 else n
-
-  assert fib(10) == 55
-  assert fib.get(10) == 55
-  assert fib.get(5) == 5
-
-def test_recursive_caching2():
-  @checkpoint
-  def fib(n: int) -> int:
-    return fib.fn(n - 1) + fib.fn(n - 2) if n > 1 else n
-
-  assert fib(10) == 55
-  assert fib.get(10) == 55
-  with pytest.raises(CheckpointError):
-    fib.get(5)
-
-@pytest.mark.asyncio
-async def test_async_caching():
-  @checkpoint(format="memory")
-  async def async_square(x: int) -> int:
-    await asyncio.sleep(0.1)
-    return x ** 2
-
-  result1 = await async_square(3)
-  result2 = await async_square(3)
-  result3 = async_square.get(3)
-
-  assert result1 == result2 == result3 == 9
-
-def test_force_recalculation():
-  @checkpoint
-  def square(x: int) -> int:
-    return x ** 2
-
-  assert square(5) == 25
-  square.rerun(5)
-  assert square.get(5) == 25
-
-def test_multi_layer_decorator():
-  @checkpoint(format="memory")
-  @checkpoint(format="pickle")
-  def add(a: int, b: int) -> int:
-    return a + b
-
-  assert add(2, 3) == 5
-  assert add.get(2, 3) == 5
-
-def test_capture():
-  item_dict = AttrDict({"a": 1, "b": 1})
-
-  @checkpoint(capture=True)
-  def test_whole():
-    return item_dict
-
-  @checkpoint(capture=True)
-  def test_a():
-    return item_dict.a + 1
-
-  init_hash_a = test_a.ident.captured_hash
-  init_hash_whole = test_whole.ident.captured_hash
-  item_dict.b += 1
-  test_whole.reinit()
-  test_a.reinit()
-  assert test_whole.ident.captured_hash != init_hash_whole
-  assert test_a.ident.captured_hash == init_hash_a
-  item_dict.a += 1
-  test_a.reinit()
-  assert test_a.ident.captured_hash != init_hash_a
-
-def test_depends():
-  def multiply_wrapper(a: int, b: int) -> int:
-    return global_multiply(a, b)
-
-  def helper(a: int, b: int) -> int:
-    return multiply_wrapper(a + 1, b + 1)
-
-  @checkpoint
-  def test_a(a: int, b: int) -> int:
-    return helper(a, b)
-
-  @checkpoint
-  def test_b(a: int, b: int) -> int:
-    return test_a(a, b) + multiply_wrapper(a, b)
-
-  assert set(test_a.depends) == {test_a.fn, helper, multiply_wrapper, global_multiply}
-  assert set(test_b.depends) == {test_b.fn, test_a, multiply_wrapper, global_multiply}
-
-def test_lazy_init_1():
-  @checkpoint
-  def fn1(x: object) -> object:
-    return fn2(x)
-
-  @checkpoint
-  def fn2(x: object) -> object:
-    return fn1(x)
-
-  assert set(fn1.depends) == {fn1.fn, fn2}
-  assert set(fn2.depends) == {fn1, fn2.fn}
-
-def test_lazy_init_2():
-  @checkpoint
-  def fn1(x: object) -> object:
-    return fn2(x)
-
-  assert set(fn1.depends) == {fn1.fn}
-
-  @checkpoint
-  def fn2(x: object) -> object:
-    return fn1(x)
-
-  assert set(fn1.depends) == {fn1.fn}
-  fn1.reinit()
-  assert set(fn1.depends) == {fn1.fn, fn2}
-  assert set(fn2.depends) == {fn1, fn2.fn}
checkpointer-2.11.2/checkpointer/types.py (removed)
@@ -1,21 +0,0 @@
-from typing import Annotated, Callable, Coroutine, Generic, ParamSpec, TypeVar
-
-Fn = TypeVar("Fn", bound=Callable)
-P = ParamSpec("P")
-R = TypeVar("R")
-C = TypeVar("C")
-T = TypeVar("T")
-
-class HashBy(Generic[Fn]):
-  pass
-
-NoHash = Annotated[T, HashBy[lambda _: None]]
-Coro = Coroutine[object, object, R]
-
-class AwaitableValue(Generic[T]):
-  def __init__(self, value: T):
-    self.value = value
-
-  def __await__(self):
-    yield
-    return self.value
checkpointer-2.11.2/checkpointer/utils.py (removed)
@@ -1,83 +0,0 @@
-from contextlib import contextmanager
-from typing import Callable, Generic, Iterable, cast
-from .types import Fn, T
-
-def distinct(seq: Iterable[T]) -> list[T]:
-  return list(dict.fromkeys(seq))
-
-def get_cell_contents(fn: Callable) -> Iterable[tuple[str, object]]:
-  for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
-    try:
-      yield (key, cell.cell_contents)
-    except ValueError:
-      pass
-
-def unwrap_fn(fn: Fn, cached_fn=False) -> Fn:
-  from .checkpoint import CachedFunction
-  while True:
-    if (cached_fn and isinstance(fn, CachedFunction)) or not hasattr(fn, "__wrapped__"):
-      return cast(Fn, fn)
-    fn = getattr(fn, "__wrapped__")
-
-class AttrDict(dict):
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-    self.__dict__ = self
-
-  def __getattribute__(self, name: str):
-    return super().__getattribute__(name)
-
-  def __setattr__(self, name: str, value: object):
-    super().__setattr__(name, value)
-
-  def set(self, d: dict) -> "AttrDict":
-    if not d:
-      return self
-    return AttrDict({**self, **d})
-
-  def delete(self, *attrs: str) -> "AttrDict":
-    d = AttrDict(self)
-    for attr in attrs:
-      del d[attr]
-    return d
-
-  def get_at(self, attrs: tuple[str, ...]) -> object:
-    d = self
-    for attr in attrs:
-      d = getattr(d, attr, None)
-    return d
-
-class ContextVar(Generic[T]):
-  def __init__(self, value: T):
-    self.value = value
-
-  @contextmanager
-  def set(self, value: T):
-    self.value, old = value, self.value
-    try:
-      yield
-    finally:
-      self.value = old
-
-class iterate_and_upcoming(Generic[T]):
-  def __init__(self, it: Iterable[T]) -> None:
-    self.it = iter(it)
-    self.previous: tuple[()] | tuple[T] = ()
-    self.tracked = self._tracked_iter()
-
-  def __iter__(self):
-    return self
-
-  def __next__(self) -> tuple[T, Iterable[T]]:
-    try:
-      item = self.previous[0] if self.previous else next(self.it)
-      self.previous = ()
-      return item, self.tracked
-    except StopIteration:
-      self.tracked.close()
-      raise
-
-  def _tracked_iter(self):
-    for x in self.it:
-      self.previous = (x,)
-      yield x
Files without changes: .gitignore, .python-version, README.md, checkpointer/print_checkpoint.py, checkpointer/storages/memory_storage.py, checkpointer/storages/pickle_storage.py, checkpointer/storages/storage.py