checkpointer 2.11.2__py3-none-any.whl → 2.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpointer/__init__.py +4 -4
- checkpointer/checkpoint.py +125 -107
- checkpointer/fn_ident.py +129 -71
- checkpointer/fn_string.py +77 -0
- checkpointer/import_mappings.py +47 -0
- checkpointer/object_hash.py +37 -29
- checkpointer/storages/__init__.py +6 -4
- checkpointer/storages/storage.py +5 -4
- checkpointer/types.py +30 -2
- checkpointer/utils.py +84 -54
- checkpointer-2.13.0.dist-info/METADATA +260 -0
- checkpointer-2.13.0.dist-info/RECORD +18 -0
- checkpointer-2.13.0.dist-info/licenses/ATTRIBUTION.md +33 -0
- {checkpointer-2.11.2.dist-info → checkpointer-2.13.0.dist-info}/licenses/LICENSE +2 -0
- checkpointer/test_checkpointer.py +0 -168
- checkpointer-2.11.2.dist-info/METADATA +0 -240
- checkpointer-2.11.2.dist-info/RECORD +0 -16
- {checkpointer-2.11.2.dist-info → checkpointer-2.13.0.dist-info}/WHEEL +0 -0
checkpointer/__init__.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1
1
|
import gc
|
2
2
|
import tempfile
|
3
3
|
from typing import Callable
|
4
|
-
from .checkpoint import CachedFunction, Checkpointer, CheckpointError
|
4
|
+
from .checkpoint import CachedFunction, Checkpointer, CheckpointError, FunctionIdent
|
5
5
|
from .object_hash import ObjectHash
|
6
6
|
from .storages import MemoryStorage, PickleStorage, Storage
|
7
|
-
from .types import AwaitableValue, HashBy, NoHash
|
7
|
+
from .types import AwaitableValue, Captured, CapturedOnce, CaptureMe, CaptureMeOnce, HashBy, NoHash
|
8
8
|
|
9
9
|
checkpoint = Checkpointer()
|
10
10
|
capture_checkpoint = Checkpointer(capture=True)
|
11
|
-
memory_checkpoint = Checkpointer(
|
12
|
-
tmp_checkpoint = Checkpointer(
|
11
|
+
memory_checkpoint = Checkpointer(storage="memory", verbosity=0)
|
12
|
+
tmp_checkpoint = Checkpointer(directory=f"{tempfile.gettempdir()}/checkpoints")
|
13
13
|
static_checkpoint = Checkpointer(fn_hash_from=())
|
14
14
|
|
15
15
|
def cleanup_all(invalidated=True, expired=True):
|
checkpointer/checkpoint.py
CHANGED
@@ -2,30 +2,27 @@ from __future__ import annotations
|
|
2
2
|
import re
|
3
3
|
from datetime import datetime
|
4
4
|
from functools import cached_property, update_wrapper
|
5
|
-
from inspect import Parameter,
|
5
|
+
from inspect import Parameter, iscoroutine, signature, unwrap
|
6
|
+
from itertools import chain
|
6
7
|
from pathlib import Path
|
7
8
|
from typing import (
|
8
|
-
|
9
|
-
|
10
|
-
Unpack, cast, get_args, get_origin, overload,
|
9
|
+
Callable, Concatenate, Coroutine, Generic, Iterable,
|
10
|
+
Literal, Self, Type, TypedDict, Unpack, cast, overload,
|
11
11
|
)
|
12
|
-
from .fn_ident import RawFunctionIdent, get_fn_ident
|
12
|
+
from .fn_ident import Capturable, RawFunctionIdent, get_fn_ident
|
13
13
|
from .object_hash import ObjectHash
|
14
14
|
from .print_checkpoint import print_checkpoint
|
15
|
-
from .storages import STORAGE_MAP, Storage
|
16
|
-
from .types import AwaitableValue, C, Coro, Fn,
|
17
|
-
from .utils import unwrap_fn
|
15
|
+
from .storages import STORAGE_MAP, Storage, StorageType
|
16
|
+
from .types import AwaitableValue, C, Coro, Fn, P, R, hash_by_from_annotation
|
18
17
|
|
19
18
|
DEFAULT_DIR = Path.home() / ".cache/checkpoints"
|
20
19
|
|
21
|
-
empty_set = cast(set, frozenset())
|
22
|
-
|
23
20
|
class CheckpointError(Exception):
|
24
21
|
pass
|
25
22
|
|
26
23
|
class CheckpointerOpts(TypedDict, total=False):
|
27
|
-
|
28
|
-
|
24
|
+
storage: Type[Storage] | StorageType
|
25
|
+
directory: Path | str | None
|
29
26
|
when: bool
|
30
27
|
verbosity: Literal[0, 1, 2]
|
31
28
|
should_expire: Callable[[datetime], bool] | None
|
@@ -34,8 +31,8 @@ class CheckpointerOpts(TypedDict, total=False):
|
|
34
31
|
|
35
32
|
class Checkpointer:
|
36
33
|
def __init__(self, **opts: Unpack[CheckpointerOpts]):
|
37
|
-
self.
|
38
|
-
self.
|
34
|
+
self.storage = opts.get("storage", "pickle")
|
35
|
+
self.directory = Path(opts.get("directory", DEFAULT_DIR) or ".")
|
39
36
|
self.when = opts.get("when", True)
|
40
37
|
self.verbosity = opts.get("verbosity", 1)
|
41
38
|
self.should_expire = opts.get("should_expire")
|
@@ -59,119 +56,164 @@ class FunctionIdent:
|
|
59
56
|
Separated from CachedFunction to prevent hash desynchronization
|
60
57
|
among bound instances when `.reinit()` is called.
|
61
58
|
"""
|
62
|
-
|
63
|
-
|
59
|
+
__slots__ = (
|
60
|
+
"checkpointer", "cached_fn", "fn", "fn_dir", "pos_names",
|
61
|
+
"arg_names", "default_args", "hash_by_map", "__dict__",
|
62
|
+
)
|
63
|
+
|
64
|
+
def __init__(self, cached_fn: CachedFunction, checkpointer: Checkpointer, fn: Callable):
|
65
|
+
wrapped = unwrap(fn)
|
66
|
+
fn_file = Path(wrapped.__code__.co_filename).name
|
67
|
+
fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
|
68
|
+
params = list(signature(wrapped).parameters.values())
|
69
|
+
pos_param_types = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
70
|
+
named_param_types = (Parameter.KEYWORD_ONLY,) + pos_param_types
|
71
|
+
name_by_kind = {Parameter.VAR_POSITIONAL: b"*", Parameter.VAR_KEYWORD: b"**"}
|
72
|
+
self.checkpointer = checkpointer
|
64
73
|
self.cached_fn = cached_fn
|
74
|
+
self.fn = fn
|
75
|
+
self.fn_dir = f"{fn_file}/{fn_name}"
|
76
|
+
self.pos_names = [param.name for param in params if param.kind in pos_param_types]
|
77
|
+
self.arg_names = {param.name for param in params if param.kind in named_param_types}
|
78
|
+
self.default_args = {param.name: param.default for param in params if param.default is not Parameter.empty}
|
79
|
+
self.hash_by_map = {
|
80
|
+
name_by_kind.get(param.kind, param.name): hash_by
|
81
|
+
for param in params
|
82
|
+
if (hash_by := hash_by_from_annotation(param.annotation))
|
83
|
+
}
|
84
|
+
|
85
|
+
def reset(self):
|
86
|
+
self.__dict__.clear()
|
87
|
+
|
88
|
+
def is_static(self) -> bool:
|
89
|
+
return self.checkpointer.fn_hash_from is not None
|
65
90
|
|
66
91
|
@cached_property
|
67
92
|
def raw_ident(self) -> RawFunctionIdent:
|
68
|
-
return get_fn_ident(
|
93
|
+
return get_fn_ident(unwrap(self.fn), self.checkpointer.capture)
|
69
94
|
|
70
95
|
@cached_property
|
71
96
|
def fn_hash(self) -> str:
|
72
|
-
if
|
73
|
-
return str(ObjectHash(
|
74
|
-
|
97
|
+
if self.is_static():
|
98
|
+
return str(ObjectHash(self.checkpointer.fn_hash_from, digest_size=16))
|
99
|
+
depends = self.deep_idents(past_static=False)
|
100
|
+
deep_hashes = [d.fn_hash if d.is_static() else d.raw_ident.fn_hash for d in depends]
|
75
101
|
return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
|
76
102
|
|
77
103
|
@cached_property
|
78
|
-
def
|
79
|
-
|
80
|
-
|
104
|
+
def capturables(self) -> list[Capturable]:
|
105
|
+
return sorted({
|
106
|
+
capturable.key: capturable
|
107
|
+
for depend in self.deep_idents()
|
108
|
+
for capturable in depend.raw_ident.capturables
|
109
|
+
}.values())
|
110
|
+
|
111
|
+
def deep_depends(self, past_static=True, visited: set[Callable] = set()) -> Iterable[Callable]:
|
112
|
+
if self.cached_fn not in visited:
|
113
|
+
yield self.cached_fn
|
114
|
+
visited = visited or set()
|
115
|
+
visited.add(self.cached_fn)
|
116
|
+
stop = not past_static and self.is_static()
|
117
|
+
depends = [] if stop else self.raw_ident.depends
|
118
|
+
for depend in depends:
|
119
|
+
if isinstance(depend, CachedFunction):
|
120
|
+
yield from depend.ident.deep_depends(past_static, visited)
|
121
|
+
elif depend not in visited:
|
122
|
+
yield depend
|
123
|
+
visited.add(depend)
|
81
124
|
|
82
|
-
def
|
83
|
-
self.
|
125
|
+
def deep_idents(self, past_static=True) -> Iterable[FunctionIdent]:
|
126
|
+
return (fn.ident for fn in self.deep_depends(past_static) if isinstance(fn, CachedFunction))
|
84
127
|
|
85
128
|
class CachedFunction(Generic[Fn]):
|
86
129
|
def __init__(self, checkpointer: Checkpointer, fn: Fn):
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
update_wrapper(cast(Callable, self), wrapped)
|
92
|
-
self.checkpointer = checkpointer
|
93
|
-
self.fn = fn
|
94
|
-
self.fn_dir = f"{fn_file}/{fn_name}"
|
130
|
+
store_format = checkpointer.storage
|
131
|
+
Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
|
132
|
+
update_wrapper(cast(Callable, self), unwrap(fn))
|
133
|
+
self.ident = FunctionIdent(self, checkpointer, fn)
|
95
134
|
self.storage = Storage(self)
|
96
|
-
self.cleanup = self.storage.cleanup
|
97
135
|
self.bound = ()
|
98
|
-
self.attrname: str | None = None
|
99
|
-
|
100
|
-
sig = signature(wrapped)
|
101
|
-
params = list(sig.parameters.items())
|
102
|
-
pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
103
|
-
self.arg_names = [name for name, param in params if param.kind in pos_params]
|
104
|
-
self.default_args = {name: param.default for name, param in params if param.default is not Parameter.empty}
|
105
|
-
self.hash_by_map = get_hash_by_map(sig)
|
106
|
-
self.ident = FunctionIdent(self)
|
107
|
-
|
108
|
-
def __set_name__(self, _, name: str):
|
109
|
-
assert self.attrname is None
|
110
|
-
self.attrname = name
|
111
136
|
|
112
137
|
@overload
|
113
138
|
def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
|
114
139
|
@overload
|
115
|
-
def __get__(
|
140
|
+
def __get__(
|
141
|
+
self: CachedFunction[Callable[Concatenate[C, P], R]],
|
142
|
+
instance: C,
|
143
|
+
owner: Type[C],
|
144
|
+
) -> CachedFunction[Callable[P, R]]: ...
|
116
145
|
def __get__(self, instance, owner):
|
117
146
|
if instance is None:
|
118
147
|
return self
|
119
|
-
assert self.attrname is not None
|
120
148
|
bound_fn = object.__new__(CachedFunction)
|
121
149
|
bound_fn.__dict__ |= self.__dict__
|
122
150
|
bound_fn.bound = (instance,)
|
123
|
-
if hasattr(instance, "__dict__"):
|
124
|
-
setattr(instance, self.attrname, bound_fn)
|
125
151
|
return bound_fn
|
126
152
|
|
127
153
|
@property
|
128
|
-
def
|
129
|
-
return self.ident.
|
154
|
+
def fn(self) -> Fn:
|
155
|
+
return cast(Fn, self.ident.fn)
|
156
|
+
|
157
|
+
@property
|
158
|
+
def cleanup(self):
|
159
|
+
return self.storage.cleanup
|
130
160
|
|
131
161
|
def reinit(self, recursive=False) -> CachedFunction[Fn]:
|
132
|
-
depend_idents =
|
162
|
+
depend_idents = list(self.ident.deep_idents()) if recursive else [self.ident]
|
133
163
|
for ident in depend_idents: ident.reset()
|
134
164
|
for ident in depend_idents: ident.fn_hash
|
135
165
|
return self
|
136
166
|
|
137
|
-
def
|
167
|
+
def _get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
|
168
|
+
ident = self.ident
|
138
169
|
args = self.bound + args
|
139
|
-
pos_args = args[len(
|
140
|
-
named_pos_args = dict(zip(
|
141
|
-
named_args = {**
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
170
|
+
pos_args = args[len(ident.pos_names):]
|
171
|
+
named_pos_args = dict(zip(ident.pos_names, args))
|
172
|
+
named_args = {**ident.default_args, **named_pos_args, **kw}
|
173
|
+
for key, hash_by in ident.hash_by_map.items():
|
174
|
+
if isinstance(key, str):
|
175
|
+
named_args[key] = hash_by(named_args[key])
|
176
|
+
elif key == b"*":
|
177
|
+
pos_args = map(hash_by, pos_args)
|
178
|
+
elif key == b"**":
|
179
|
+
for key in kw.keys() - ident.arg_names:
|
180
|
+
named_args[key] = hash_by(named_args[key])
|
181
|
+
named_args_iter = chain.from_iterable(sorted(named_args.items()))
|
182
|
+
captured = chain.from_iterable(capturable.capture() for capturable in ident.capturables)
|
183
|
+
call_hash = ObjectHash(digest_size=16) \
|
184
|
+
.update(iter=named_args_iter, header="NAMED") \
|
185
|
+
.update(iter=pos_args, header="POS") \
|
186
|
+
.update(iter=captured, header="CAPTURED")
|
187
|
+
return str(call_hash)
|
188
|
+
|
189
|
+
def get_call_hash(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> str:
|
190
|
+
return self._get_call_hash(args, kw)
|
191
|
+
|
192
|
+
async def _store_coroutine(self, call_hash: str, coroutine: Coroutine):
|
152
193
|
return self.storage.store(call_hash, AwaitableValue(await coroutine)).value
|
153
194
|
|
154
195
|
def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
|
155
196
|
full_args = self.bound + args
|
156
|
-
params = self.checkpointer
|
197
|
+
params = self.ident.checkpointer
|
198
|
+
storage = self.storage
|
157
199
|
if not params.when:
|
158
200
|
return self.fn(*full_args, **kw)
|
159
201
|
|
160
|
-
call_hash = self.
|
161
|
-
call_id = f"{
|
202
|
+
call_hash = self._get_call_hash(args, kw)
|
203
|
+
call_id = f"{storage.fn_id()}/{call_hash}"
|
162
204
|
refresh = rerun \
|
163
|
-
or not
|
164
|
-
or (params.should_expire and params.should_expire(
|
205
|
+
or not storage.exists(call_hash) \
|
206
|
+
or (params.should_expire and params.should_expire(storage.checkpoint_date(call_hash)))
|
165
207
|
|
166
208
|
if refresh:
|
167
209
|
print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id, "blue")
|
168
210
|
data = self.fn(*full_args, **kw)
|
169
211
|
if iscoroutine(data):
|
170
|
-
return self.
|
171
|
-
return
|
212
|
+
return self._store_coroutine(call_hash, data)
|
213
|
+
return storage.store(call_hash, data)
|
172
214
|
|
173
215
|
try:
|
174
|
-
data =
|
216
|
+
data = storage.load(call_hash)
|
175
217
|
print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_id, "green")
|
176
218
|
return data
|
177
219
|
except (EOFError, FileNotFoundError):
|
@@ -186,17 +228,17 @@ class CachedFunction(Generic[Fn]):
|
|
186
228
|
return self._call(args, kw, True)
|
187
229
|
|
188
230
|
def exists(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> bool:
|
189
|
-
return self.storage.exists(self.
|
231
|
+
return self.storage.exists(self._get_call_hash(args, kw))
|
190
232
|
|
191
233
|
def delete(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs):
|
192
|
-
self.storage.delete(self.
|
234
|
+
self.storage.delete(self._get_call_hash(args, kw))
|
193
235
|
|
194
236
|
@overload
|
195
237
|
def get(self: Callable[P, Coro[R]], *args: P.args, **kw: P.kwargs) -> R: ...
|
196
238
|
@overload
|
197
239
|
def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
|
198
240
|
def get(self, *args, **kw):
|
199
|
-
call_hash = self.
|
241
|
+
call_hash = self._get_call_hash(args, kw)
|
200
242
|
try:
|
201
243
|
data = self.storage.load(call_hash)
|
202
244
|
return data.value if isinstance(data, AwaitableValue) else data
|
@@ -208,33 +250,9 @@ class CachedFunction(Generic[Fn]):
|
|
208
250
|
@overload
|
209
251
|
def set(self: Callable[P, R], value: R, *args: P.args, **kw: P.kwargs): ...
|
210
252
|
def set(self, value, *args, **kw):
|
211
|
-
self.storage.store(self.
|
253
|
+
self.storage.store(self._get_call_hash(args, kw), value)
|
212
254
|
|
213
255
|
def __repr__(self) -> str:
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
if self not in visited:
|
218
|
-
yield self
|
219
|
-
visited = visited or set()
|
220
|
-
visited.add(self)
|
221
|
-
for depend in self.depends:
|
222
|
-
if isinstance(depend, CachedFunction):
|
223
|
-
yield from depend.deep_depends(visited)
|
224
|
-
|
225
|
-
def hash_by_from_annotation(annotation: type) -> Callable[[object], object] | None:
|
226
|
-
if get_origin(annotation) is Annotated:
|
227
|
-
args = get_args(annotation)
|
228
|
-
metadata = args[1] if len(args) > 1 else None
|
229
|
-
if get_origin(metadata) is HashBy:
|
230
|
-
return get_args(metadata)[0]
|
231
|
-
|
232
|
-
def get_hash_by_map(sig: Signature) -> dict[str | bytes, Callable[[object], object]]:
|
233
|
-
hash_by_map = {}
|
234
|
-
for name, param in sig.parameters.items():
|
235
|
-
if param.kind == Parameter.VAR_POSITIONAL:
|
236
|
-
name = b"*"
|
237
|
-
elif param.kind == Parameter.VAR_KEYWORD:
|
238
|
-
name = b"**"
|
239
|
-
hash_by_map[name] = hash_by_from_annotation(param.annotation)
|
240
|
-
return hash_by_map if any(hash_by_map.values()) else {}
|
256
|
+
initialized = "fn_hash" in self.ident.__dict__
|
257
|
+
fn_hash = self.ident.fn_hash[:6] if initialized else "- uninitialized"
|
258
|
+
return f"<CachedFunction {self.fn.__name__} {fn_hash}>"
|
checkpointer/fn_ident.py
CHANGED
@@ -1,103 +1,161 @@
|
|
1
1
|
import dis
|
2
|
-
import
|
3
|
-
from
|
4
|
-
from
|
5
|
-
from
|
6
|
-
from
|
2
|
+
from inspect import Parameter, getmodule, signature, unwrap
|
3
|
+
from types import CodeType, MethodType, ModuleType
|
4
|
+
from typing import Annotated, Callable, Iterable, NamedTuple, Type, get_args, get_origin
|
5
|
+
from .fn_string import get_fn_aststr
|
6
|
+
from .import_mappings import resolve_annotation
|
7
7
|
from .object_hash import ObjectHash
|
8
|
-
from .
|
8
|
+
from .types import hash_by_from_annotation, is_capture_me, is_capture_me_once, to_none
|
9
|
+
from .utils import (
|
10
|
+
cwd, distinct, get_at, get_cell_contents,
|
11
|
+
get_file, is_class, is_user_fn, seekable, takewhile,
|
12
|
+
)
|
9
13
|
|
10
|
-
|
14
|
+
AttrPath = tuple[str, ...]
|
15
|
+
CapturableByFn = dict[Callable, list["Capturable"]]
|
11
16
|
|
12
17
|
class RawFunctionIdent(NamedTuple):
|
13
18
|
fn_hash: str
|
14
|
-
captured_hash: str
|
15
19
|
depends: list[Callable]
|
20
|
+
capturables: set["Capturable"]
|
16
21
|
|
17
|
-
|
18
|
-
|
19
|
-
|
22
|
+
class Capturable(NamedTuple):
|
23
|
+
key: str
|
24
|
+
module: ModuleType
|
25
|
+
attr_path: AttrPath
|
26
|
+
hash_by: Callable | None
|
27
|
+
hash: str | None = None
|
20
28
|
|
21
|
-
def
|
22
|
-
|
29
|
+
def capture(self) -> tuple[str, object]:
|
30
|
+
if obj := self.hash:
|
31
|
+
return self.key, obj
|
32
|
+
obj = get_at(self.module, *self.attr_path)
|
33
|
+
obj = self.hash_by(obj) if self.hash_by else obj
|
34
|
+
return self.key, obj
|
35
|
+
|
36
|
+
@staticmethod
|
37
|
+
def new(module: ModuleType, attr_path: AttrPath, hash_by: Callable | None, capture_once: bool) -> "Capturable":
|
38
|
+
file = str(get_file(module).relative_to(cwd))
|
39
|
+
key = file + "/" + ".".join(attr_path)
|
40
|
+
cap = Capturable(key, module, attr_path, hash_by)
|
41
|
+
if not capture_once:
|
42
|
+
return cap
|
43
|
+
obj_hash = str(ObjectHash(cap.capture()[1]))
|
44
|
+
return Capturable(key, module, attr_path, None, obj_hash)
|
45
|
+
|
46
|
+
def extract_classvars(code: CodeType, scope_vars: dict) -> dict[str, dict[str, Type]]:
|
47
|
+
attr_path = AttrPath(())
|
23
48
|
scope_obj = None
|
24
49
|
classvars: dict[str, dict[str, Type]] = {}
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
50
|
+
instructs = seekable(dis.get_instructions(code))
|
51
|
+
for instruct in instructs:
|
52
|
+
if instruct.opname in scope_vars and not attr_path:
|
53
|
+
attrs = takewhile((x.opname == "LOAD_ATTR", x.argval) for x in instructs)
|
54
|
+
attr_path = AttrPath((instruct.opname, instruct.argval, *attrs))
|
55
|
+
instructs.step(-1)
|
56
|
+
elif instruct.opname == "CALL":
|
57
|
+
obj = get_at(scope_vars, *attr_path)
|
58
|
+
attr_path = AttrPath(())
|
32
59
|
if is_class(obj):
|
33
60
|
scope_obj = obj
|
34
|
-
elif
|
35
|
-
load_key =
|
36
|
-
classvars.setdefault(load_key, {})[
|
61
|
+
elif instruct.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
|
62
|
+
load_key = instruct.opname.replace("STORE", "LOAD")
|
63
|
+
classvars.setdefault(load_key, {})[instruct.argval] = scope_obj
|
37
64
|
scope_obj = None
|
38
65
|
return classvars
|
39
66
|
|
40
|
-
def extract_scope_values(code: CodeType, scope_vars:
|
67
|
+
def extract_scope_values(code: CodeType, scope_vars: dict) -> Iterable[tuple[AttrPath, object]]:
|
41
68
|
classvars = extract_classvars(code, scope_vars)
|
42
|
-
scope_vars = scope_vars
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
69
|
+
scope_vars = {**scope_vars, **{k: {**scope_vars[k], **v} for k, v in classvars.items()}}
|
70
|
+
instructs = seekable(dis.get_instructions(code))
|
71
|
+
for instruct in instructs:
|
72
|
+
if instruct.opname in scope_vars:
|
73
|
+
attrs = takewhile((x.opname in ("LOAD_ATTR", "LOAD_METHOD"), x.argval) for x in instructs)
|
74
|
+
attr_path = AttrPath((instruct.opname, instruct.argval, *attrs))
|
75
|
+
parent_path = attr_path[:-1]
|
76
|
+
instructs.step(-1)
|
77
|
+
obj = get_at(scope_vars, *attr_path)
|
78
|
+
if obj is not None:
|
79
|
+
yield attr_path, obj
|
80
|
+
if callable(obj) and parent_path[1:]:
|
81
|
+
parent_obj = get_at(scope_vars, *parent_path)
|
82
|
+
yield parent_path, parent_obj
|
50
83
|
for const in code.co_consts:
|
51
84
|
if isinstance(const, CodeType):
|
52
|
-
|
85
|
+
next_deref = {**scope_vars["LOAD_DEREF"], **scope_vars["LOAD_FAST"]}
|
86
|
+
next_scope_vars = {**scope_vars, "LOAD_FAST": {}, "LOAD_DEREF": next_deref}
|
87
|
+
yield from extract_scope_values(const, next_scope_vars)
|
88
|
+
|
89
|
+
def resolve_class_annotations(anno: object) -> Type | None:
|
90
|
+
if anno in (None, Annotated):
|
91
|
+
return None
|
92
|
+
elif is_class(anno):
|
93
|
+
return anno
|
94
|
+
elif get_origin(anno) is Annotated:
|
95
|
+
return resolve_class_annotations(next(iter(get_args(anno)), None))
|
96
|
+
return resolve_class_annotations(get_origin(anno))
|
53
97
|
|
54
|
-
def get_self_value(fn: Callable) ->
|
98
|
+
def get_self_value(fn: Callable) -> Type | object | None:
|
55
99
|
if isinstance(fn, MethodType):
|
56
100
|
return fn.__self__
|
57
|
-
parts =
|
58
|
-
cls = parts and
|
101
|
+
parts = fn.__qualname__.split(".")[:-1]
|
102
|
+
cls = parts and get_at(fn.__globals__, *parts)
|
59
103
|
if is_class(cls):
|
60
104
|
return cls
|
61
105
|
|
62
|
-
def
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
106
|
+
def get_capturables(fn: Callable, capture: bool, captured_vars: dict[AttrPath, object]) -> Iterable[Capturable]:
|
107
|
+
module = getmodule(fn)
|
108
|
+
if not module or not is_user_fn(fn):
|
109
|
+
return
|
110
|
+
for (instruct, *attr_path), obj in captured_vars.items():
|
111
|
+
attr_path = AttrPath(attr_path)
|
112
|
+
if instruct == "LOAD_GLOBAL" and not callable(obj) and not isinstance(obj, ModuleType):
|
113
|
+
anno = resolve_annotation(module, ".".join(attr_path))
|
114
|
+
if capture or is_capture_me(anno) or is_capture_me_once(anno):
|
115
|
+
hash_by = hash_by_from_annotation(anno)
|
116
|
+
if hash_by is not to_none:
|
117
|
+
yield Capturable.new(module, attr_path, hash_by, is_capture_me_once(anno))
|
71
118
|
|
72
|
-
def
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
119
|
+
def get_fn_captures(fn: Callable, capture: bool) -> tuple[list[Callable], list[Capturable]]:
|
120
|
+
scope_vars_signature: dict[str, Type | object] = {
|
121
|
+
param.name: class_anno
|
122
|
+
for param in signature(fn).parameters.values()
|
123
|
+
if param.annotation is not Parameter.empty
|
124
|
+
if param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
|
125
|
+
if (class_anno := resolve_class_annotations(param.annotation))
|
126
|
+
}
|
127
|
+
if self_obj := get_self_value(fn):
|
128
|
+
scope_vars_signature["self"] = self_obj
|
129
|
+
scope_vars = {
|
130
|
+
"LOAD_FAST": scope_vars_signature,
|
131
|
+
"LOAD_DEREF": dict(get_cell_contents(fn)),
|
132
|
+
"LOAD_GLOBAL": fn.__globals__,
|
133
|
+
}
|
134
|
+
captured_vars = dict(extract_scope_values(fn.__code__, scope_vars))
|
135
|
+
captured_callables = [obj for obj in captured_vars.values() if callable(obj)]
|
136
|
+
capturables = list(get_capturables(fn, capture, captured_vars))
|
137
|
+
return captured_callables, capturables
|
77
138
|
|
78
|
-
def get_depend_fns(fn: Callable,
|
139
|
+
def get_depend_fns(fn: Callable, capture: bool, capturable_by_fn: CapturableByFn = {}) -> CapturableByFn:
|
79
140
|
from .checkpoint import CachedFunction
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
for
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
get_depend_fns(child_fn, captured_vals_by_fn)
|
91
|
-
return captured_vals_by_fn
|
141
|
+
captured_callables, capturables = get_fn_captures(fn, capture)
|
142
|
+
capturable_by_fn = capturable_by_fn or {}
|
143
|
+
capturable_by_fn[fn] = capturables
|
144
|
+
for depend_fn in captured_callables:
|
145
|
+
depend_fn = unwrap(depend_fn, stop=lambda f: isinstance(f, CachedFunction))
|
146
|
+
if isinstance(depend_fn, CachedFunction):
|
147
|
+
capturable_by_fn[depend_fn.ident.cached_fn] = []
|
148
|
+
elif depend_fn not in capturable_by_fn and is_user_fn(depend_fn):
|
149
|
+
get_depend_fns(depend_fn, capture, capturable_by_fn)
|
150
|
+
return capturable_by_fn
|
92
151
|
|
93
152
|
def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
|
94
153
|
from .checkpoint import CachedFunction
|
95
|
-
|
96
|
-
|
97
|
-
depends =
|
154
|
+
capturable_by_fn = get_depend_fns(fn, capture)
|
155
|
+
capturables = {capt for capts in capturable_by_fn.values() for capt in capts}
|
156
|
+
depends = capturable_by_fn.keys()
|
98
157
|
depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
|
99
|
-
|
100
|
-
assert fn ==
|
101
|
-
fn_hash = str(ObjectHash(iter=
|
102
|
-
|
103
|
-
return RawFunctionIdent(fn_hash, captured_hash, depends)
|
158
|
+
depend_callables = [fn for fn in depends if not isinstance(fn, CachedFunction)]
|
159
|
+
assert fn == depend_callables[0]
|
160
|
+
fn_hash = str(ObjectHash(iter=map(get_fn_aststr, depend_callables)))
|
161
|
+
return RawFunctionIdent(fn_hash, depends, capturables)
|