checkpointer 2.11.2__py3-none-any.whl → 2.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
checkpointer/__init__.py CHANGED
@@ -1,15 +1,15 @@
1
1
  import gc
2
2
  import tempfile
3
3
  from typing import Callable
4
- from .checkpoint import CachedFunction, Checkpointer, CheckpointError
4
+ from .checkpoint import CachedFunction, Checkpointer, CheckpointError, FunctionIdent
5
5
  from .object_hash import ObjectHash
6
6
  from .storages import MemoryStorage, PickleStorage, Storage
7
- from .types import AwaitableValue, HashBy, NoHash
7
+ from .types import AwaitableValue, Captured, CapturedOnce, CaptureMe, CaptureMeOnce, HashBy, NoHash
8
8
 
9
9
  checkpoint = Checkpointer()
10
10
  capture_checkpoint = Checkpointer(capture=True)
11
- memory_checkpoint = Checkpointer(format="memory", verbosity=0)
12
- tmp_checkpoint = Checkpointer(root_path=f"{tempfile.gettempdir()}/checkpoints")
11
+ memory_checkpoint = Checkpointer(storage="memory", verbosity=0)
12
+ tmp_checkpoint = Checkpointer(directory=f"{tempfile.gettempdir()}/checkpoints")
13
13
  static_checkpoint = Checkpointer(fn_hash_from=())
14
14
 
15
15
  def cleanup_all(invalidated=True, expired=True):
@@ -2,30 +2,27 @@ from __future__ import annotations
2
2
  import re
3
3
  from datetime import datetime
4
4
  from functools import cached_property, update_wrapper
5
- from inspect import Parameter, Signature, iscoroutine, signature
5
+ from inspect import Parameter, iscoroutine, signature, unwrap
6
+ from itertools import chain
6
7
  from pathlib import Path
7
8
  from typing import (
8
- Annotated, Callable, Concatenate, Coroutine, Generic,
9
- Iterable, Literal, Self, Type, TypedDict,
10
- Unpack, cast, get_args, get_origin, overload,
9
+ Callable, Concatenate, Coroutine, Generic, Iterable,
10
+ Literal, Self, Type, TypedDict, Unpack, cast, overload,
11
11
  )
12
- from .fn_ident import RawFunctionIdent, get_fn_ident
12
+ from .fn_ident import Capturable, RawFunctionIdent, get_fn_ident
13
13
  from .object_hash import ObjectHash
14
14
  from .print_checkpoint import print_checkpoint
15
- from .storages import STORAGE_MAP, Storage
16
- from .types import AwaitableValue, C, Coro, Fn, HashBy, P, R
17
- from .utils import unwrap_fn
15
+ from .storages import STORAGE_MAP, Storage, StorageType
16
+ from .types import AwaitableValue, C, Coro, Fn, P, R, hash_by_from_annotation
18
17
 
19
18
  DEFAULT_DIR = Path.home() / ".cache/checkpoints"
20
19
 
21
- empty_set = cast(set, frozenset())
22
-
23
20
  class CheckpointError(Exception):
24
21
  pass
25
22
 
26
23
  class CheckpointerOpts(TypedDict, total=False):
27
- format: Type[Storage] | Literal["pickle", "memory", "bcolz"]
28
- root_path: Path | str | None
24
+ storage: Type[Storage] | StorageType
25
+ directory: Path | str | None
29
26
  when: bool
30
27
  verbosity: Literal[0, 1, 2]
31
28
  should_expire: Callable[[datetime], bool] | None
@@ -34,8 +31,8 @@ class CheckpointerOpts(TypedDict, total=False):
34
31
 
35
32
  class Checkpointer:
36
33
  def __init__(self, **opts: Unpack[CheckpointerOpts]):
37
- self.format = opts.get("format", "pickle")
38
- self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
34
+ self.storage = opts.get("storage", "pickle")
35
+ self.directory = Path(opts.get("directory", DEFAULT_DIR) or ".")
39
36
  self.when = opts.get("when", True)
40
37
  self.verbosity = opts.get("verbosity", 1)
41
38
  self.should_expire = opts.get("should_expire")
@@ -59,119 +56,164 @@ class FunctionIdent:
59
56
  Separated from CachedFunction to prevent hash desynchronization
60
57
  among bound instances when `.reinit()` is called.
61
58
  """
62
- def __init__(self, cached_fn: CachedFunction):
63
- self.__dict__.clear()
59
+ __slots__ = (
60
+ "checkpointer", "cached_fn", "fn", "fn_dir", "pos_names",
61
+ "arg_names", "default_args", "hash_by_map", "__dict__",
62
+ )
63
+
64
+ def __init__(self, cached_fn: CachedFunction, checkpointer: Checkpointer, fn: Callable):
65
+ wrapped = unwrap(fn)
66
+ fn_file = Path(wrapped.__code__.co_filename).name
67
+ fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
68
+ params = list(signature(wrapped).parameters.values())
69
+ pos_param_types = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
70
+ named_param_types = (Parameter.KEYWORD_ONLY,) + pos_param_types
71
+ name_by_kind = {Parameter.VAR_POSITIONAL: b"*", Parameter.VAR_KEYWORD: b"**"}
72
+ self.checkpointer = checkpointer
64
73
  self.cached_fn = cached_fn
74
+ self.fn = fn
75
+ self.fn_dir = f"{fn_file}/{fn_name}"
76
+ self.pos_names = [param.name for param in params if param.kind in pos_param_types]
77
+ self.arg_names = {param.name for param in params if param.kind in named_param_types}
78
+ self.default_args = {param.name: param.default for param in params if param.default is not Parameter.empty}
79
+ self.hash_by_map = {
80
+ name_by_kind.get(param.kind, param.name): hash_by
81
+ for param in params
82
+ if (hash_by := hash_by_from_annotation(param.annotation))
83
+ }
84
+
85
+ def reset(self):
86
+ self.__dict__.clear()
87
+
88
+ def is_static(self) -> bool:
89
+ return self.checkpointer.fn_hash_from is not None
65
90
 
66
91
  @cached_property
67
92
  def raw_ident(self) -> RawFunctionIdent:
68
- return get_fn_ident(unwrap_fn(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
93
+ return get_fn_ident(unwrap(self.fn), self.checkpointer.capture)
69
94
 
70
95
  @cached_property
71
96
  def fn_hash(self) -> str:
72
- if (hash_from := self.cached_fn.checkpointer.fn_hash_from) is not None:
73
- return str(ObjectHash(hash_from, digest_size=16))
74
- deep_hashes = [depend.ident.raw_ident.fn_hash for depend in self.cached_fn.deep_depends()]
97
+ if self.is_static():
98
+ return str(ObjectHash(self.checkpointer.fn_hash_from, digest_size=16))
99
+ depends = self.deep_idents(past_static=False)
100
+ deep_hashes = [d.fn_hash if d.is_static() else d.raw_ident.fn_hash for d in depends]
75
101
  return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
76
102
 
77
103
  @cached_property
78
- def captured_hash(self) -> str:
79
- deep_hashes = [depend.ident.raw_ident.captured_hash for depend in self.cached_fn.deep_depends()]
80
- return str(ObjectHash().write_text(iter=deep_hashes))
104
+ def capturables(self) -> list[Capturable]:
105
+ return sorted({
106
+ capturable.key: capturable
107
+ for depend in self.deep_idents()
108
+ for capturable in depend.raw_ident.capturables
109
+ }.values())
110
+
111
+ def deep_depends(self, past_static=True, visited: set[Callable] = set()) -> Iterable[Callable]:
112
+ if self.cached_fn not in visited:
113
+ yield self.cached_fn
114
+ visited = visited or set()
115
+ visited.add(self.cached_fn)
116
+ stop = not past_static and self.is_static()
117
+ depends = [] if stop else self.raw_ident.depends
118
+ for depend in depends:
119
+ if isinstance(depend, CachedFunction):
120
+ yield from depend.ident.deep_depends(past_static, visited)
121
+ elif depend not in visited:
122
+ yield depend
123
+ visited.add(depend)
81
124
 
82
- def reset(self):
83
- self.__init__(self.cached_fn)
125
+ def deep_idents(self, past_static=True) -> Iterable[FunctionIdent]:
126
+ return (fn.ident for fn in self.deep_depends(past_static) if isinstance(fn, CachedFunction))
84
127
 
85
128
  class CachedFunction(Generic[Fn]):
86
129
  def __init__(self, checkpointer: Checkpointer, fn: Fn):
87
- wrapped = unwrap_fn(fn)
88
- fn_file = Path(wrapped.__code__.co_filename).name
89
- fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
90
- Storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
91
- update_wrapper(cast(Callable, self), wrapped)
92
- self.checkpointer = checkpointer
93
- self.fn = fn
94
- self.fn_dir = f"{fn_file}/{fn_name}"
130
+ store_format = checkpointer.storage
131
+ Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
132
+ update_wrapper(cast(Callable, self), unwrap(fn))
133
+ self.ident = FunctionIdent(self, checkpointer, fn)
95
134
  self.storage = Storage(self)
96
- self.cleanup = self.storage.cleanup
97
135
  self.bound = ()
98
- self.attrname: str | None = None
99
-
100
- sig = signature(wrapped)
101
- params = list(sig.parameters.items())
102
- pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
103
- self.arg_names = [name for name, param in params if param.kind in pos_params]
104
- self.default_args = {name: param.default for name, param in params if param.default is not Parameter.empty}
105
- self.hash_by_map = get_hash_by_map(sig)
106
- self.ident = FunctionIdent(self)
107
-
108
- def __set_name__(self, _, name: str):
109
- assert self.attrname is None
110
- self.attrname = name
111
136
 
112
137
  @overload
113
138
  def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
114
139
  @overload
115
- def __get__(self: CachedFunction[Callable[Concatenate[C, P], R]], instance: C, owner: Type[C]) -> CachedFunction[Callable[P, R]]: ...
140
+ def __get__(
141
+ self: CachedFunction[Callable[Concatenate[C, P], R]],
142
+ instance: C,
143
+ owner: Type[C],
144
+ ) -> CachedFunction[Callable[P, R]]: ...
116
145
  def __get__(self, instance, owner):
117
146
  if instance is None:
118
147
  return self
119
- assert self.attrname is not None
120
148
  bound_fn = object.__new__(CachedFunction)
121
149
  bound_fn.__dict__ |= self.__dict__
122
150
  bound_fn.bound = (instance,)
123
- if hasattr(instance, "__dict__"):
124
- setattr(instance, self.attrname, bound_fn)
125
151
  return bound_fn
126
152
 
127
153
  @property
128
- def depends(self) -> list[Callable]:
129
- return self.ident.raw_ident.depends
154
+ def fn(self) -> Fn:
155
+ return cast(Fn, self.ident.fn)
156
+
157
+ @property
158
+ def cleanup(self):
159
+ return self.storage.cleanup
130
160
 
131
161
  def reinit(self, recursive=False) -> CachedFunction[Fn]:
132
- depend_idents = [depend.ident for depend in self.deep_depends()] if recursive else [self.ident]
162
+ depend_idents = list(self.ident.deep_idents()) if recursive else [self.ident]
133
163
  for ident in depend_idents: ident.reset()
134
164
  for ident in depend_idents: ident.fn_hash
135
165
  return self
136
166
 
137
- def get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
167
+ def _get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
168
+ ident = self.ident
138
169
  args = self.bound + args
139
- pos_args = args[len(self.arg_names):]
140
- named_pos_args = dict(zip(self.arg_names, args))
141
- named_args = {**self.default_args, **named_pos_args, **kw}
142
- if hash_by_map := self.hash_by_map:
143
- rest_hash_by = hash_by_map.get(b"**")
144
- for key, value in named_args.items():
145
- if hash_by := hash_by_map.get(key, rest_hash_by):
146
- named_args[key] = hash_by(value)
147
- if pos_hash_by := hash_by_map.get(b"*"):
148
- pos_args = tuple(map(pos_hash_by, pos_args))
149
- return str(ObjectHash(named_args, pos_args, self.ident.captured_hash, digest_size=16))
150
-
151
- async def _resolve_coroutine(self, call_hash: str, coroutine: Coroutine):
170
+ pos_args = args[len(ident.pos_names):]
171
+ named_pos_args = dict(zip(ident.pos_names, args))
172
+ named_args = {**ident.default_args, **named_pos_args, **kw}
173
+ for key, hash_by in ident.hash_by_map.items():
174
+ if isinstance(key, str):
175
+ named_args[key] = hash_by(named_args[key])
176
+ elif key == b"*":
177
+ pos_args = map(hash_by, pos_args)
178
+ elif key == b"**":
179
+ for key in kw.keys() - ident.arg_names:
180
+ named_args[key] = hash_by(named_args[key])
181
+ named_args_iter = chain.from_iterable(sorted(named_args.items()))
182
+ captured = chain.from_iterable(capturable.capture() for capturable in ident.capturables)
183
+ call_hash = ObjectHash(digest_size=16) \
184
+ .update(iter=named_args_iter, header="NAMED") \
185
+ .update(iter=pos_args, header="POS") \
186
+ .update(iter=captured, header="CAPTURED")
187
+ return str(call_hash)
188
+
189
+ def get_call_hash(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> str:
190
+ return self._get_call_hash(args, kw)
191
+
192
+ async def _store_coroutine(self, call_hash: str, coroutine: Coroutine):
152
193
  return self.storage.store(call_hash, AwaitableValue(await coroutine)).value
153
194
 
154
195
  def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
155
196
  full_args = self.bound + args
156
- params = self.checkpointer
197
+ params = self.ident.checkpointer
198
+ storage = self.storage
157
199
  if not params.when:
158
200
  return self.fn(*full_args, **kw)
159
201
 
160
- call_hash = self.get_call_hash(args, kw)
161
- call_id = f"{self.storage.fn_id()}/{call_hash}"
202
+ call_hash = self._get_call_hash(args, kw)
203
+ call_id = f"{storage.fn_id()}/{call_hash}"
162
204
  refresh = rerun \
163
- or not self.storage.exists(call_hash) \
164
- or (params.should_expire and params.should_expire(self.storage.checkpoint_date(call_hash)))
205
+ or not storage.exists(call_hash) \
206
+ or (params.should_expire and params.should_expire(storage.checkpoint_date(call_hash)))
165
207
 
166
208
  if refresh:
167
209
  print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id, "blue")
168
210
  data = self.fn(*full_args, **kw)
169
211
  if iscoroutine(data):
170
- return self._resolve_coroutine(call_hash, data)
171
- return self.storage.store(call_hash, data)
212
+ return self._store_coroutine(call_hash, data)
213
+ return storage.store(call_hash, data)
172
214
 
173
215
  try:
174
- data = self.storage.load(call_hash)
216
+ data = storage.load(call_hash)
175
217
  print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_id, "green")
176
218
  return data
177
219
  except (EOFError, FileNotFoundError):
@@ -186,17 +228,17 @@ class CachedFunction(Generic[Fn]):
186
228
  return self._call(args, kw, True)
187
229
 
188
230
  def exists(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> bool:
189
- return self.storage.exists(self.get_call_hash(args, kw))
231
+ return self.storage.exists(self._get_call_hash(args, kw))
190
232
 
191
233
  def delete(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs):
192
- self.storage.delete(self.get_call_hash(args, kw))
234
+ self.storage.delete(self._get_call_hash(args, kw))
193
235
 
194
236
  @overload
195
237
  def get(self: Callable[P, Coro[R]], *args: P.args, **kw: P.kwargs) -> R: ...
196
238
  @overload
197
239
  def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
198
240
  def get(self, *args, **kw):
199
- call_hash = self.get_call_hash(args, kw)
241
+ call_hash = self._get_call_hash(args, kw)
200
242
  try:
201
243
  data = self.storage.load(call_hash)
202
244
  return data.value if isinstance(data, AwaitableValue) else data
@@ -208,33 +250,9 @@ class CachedFunction(Generic[Fn]):
208
250
  @overload
209
251
  def set(self: Callable[P, R], value: R, *args: P.args, **kw: P.kwargs): ...
210
252
  def set(self, value, *args, **kw):
211
- self.storage.store(self.get_call_hash(args, kw), value)
253
+ self.storage.store(self._get_call_hash(args, kw), value)
212
254
 
213
255
  def __repr__(self) -> str:
214
- return f"<CachedFunction {self.fn.__name__} {self.ident.fn_hash[:6]}>"
215
-
216
- def deep_depends(self, visited: set[CachedFunction] = empty_set) -> Iterable[CachedFunction]:
217
- if self not in visited:
218
- yield self
219
- visited = visited or set()
220
- visited.add(self)
221
- for depend in self.depends:
222
- if isinstance(depend, CachedFunction):
223
- yield from depend.deep_depends(visited)
224
-
225
- def hash_by_from_annotation(annotation: type) -> Callable[[object], object] | None:
226
- if get_origin(annotation) is Annotated:
227
- args = get_args(annotation)
228
- metadata = args[1] if len(args) > 1 else None
229
- if get_origin(metadata) is HashBy:
230
- return get_args(metadata)[0]
231
-
232
- def get_hash_by_map(sig: Signature) -> dict[str | bytes, Callable[[object], object]]:
233
- hash_by_map = {}
234
- for name, param in sig.parameters.items():
235
- if param.kind == Parameter.VAR_POSITIONAL:
236
- name = b"*"
237
- elif param.kind == Parameter.VAR_KEYWORD:
238
- name = b"**"
239
- hash_by_map[name] = hash_by_from_annotation(param.annotation)
240
- return hash_by_map if any(hash_by_map.values()) else {}
256
+ initialized = "fn_hash" in self.ident.__dict__
257
+ fn_hash = self.ident.fn_hash[:6] if initialized else "- uninitialized"
258
+ return f"<CachedFunction {self.fn.__name__} {fn_hash}>"
checkpointer/fn_ident.py CHANGED
@@ -1,103 +1,161 @@
1
1
  import dis
2
- import inspect
3
- from itertools import takewhile
4
- from pathlib import Path
5
- from types import CodeType, FunctionType, MethodType
6
- from typing import Callable, Iterable, NamedTuple, Type, TypeGuard
2
+ from inspect import Parameter, getmodule, signature, unwrap
3
+ from types import CodeType, MethodType, ModuleType
4
+ from typing import Annotated, Callable, Iterable, NamedTuple, Type, get_args, get_origin
5
+ from .fn_string import get_fn_aststr
6
+ from .import_mappings import resolve_annotation
7
7
  from .object_hash import ObjectHash
8
- from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, unwrap_fn
8
+ from .types import hash_by_from_annotation, is_capture_me, is_capture_me_once, to_none
9
+ from .utils import (
10
+ cwd, distinct, get_at, get_cell_contents,
11
+ get_file, is_class, is_user_fn, seekable, takewhile,
12
+ )
9
13
 
10
- cwd = Path.cwd().resolve()
14
+ AttrPath = tuple[str, ...]
15
+ CapturableByFn = dict[Callable, list["Capturable"]]
11
16
 
12
17
  class RawFunctionIdent(NamedTuple):
13
18
  fn_hash: str
14
- captured_hash: str
15
19
  depends: list[Callable]
20
+ capturables: set["Capturable"]
16
21
 
17
- def is_class(obj) -> TypeGuard[Type]:
18
- # isinstance works too, but needlessly triggers _lazyinit()
19
- return issubclass(type(obj), type)
22
+ class Capturable(NamedTuple):
23
+ key: str
24
+ module: ModuleType
25
+ attr_path: AttrPath
26
+ hash_by: Callable | None
27
+ hash: str | None = None
20
28
 
21
- def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
22
- attr_path: tuple[str, ...] = ()
29
+ def capture(self) -> tuple[str, object]:
30
+ if obj := self.hash:
31
+ return self.key, obj
32
+ obj = get_at(self.module, *self.attr_path)
33
+ obj = self.hash_by(obj) if self.hash_by else obj
34
+ return self.key, obj
35
+
36
+ @staticmethod
37
+ def new(module: ModuleType, attr_path: AttrPath, hash_by: Callable | None, capture_once: bool) -> "Capturable":
38
+ file = str(get_file(module).relative_to(cwd))
39
+ key = file + "/" + ".".join(attr_path)
40
+ cap = Capturable(key, module, attr_path, hash_by)
41
+ if not capture_once:
42
+ return cap
43
+ obj_hash = str(ObjectHash(cap.capture()[1]))
44
+ return Capturable(key, module, attr_path, None, obj_hash)
45
+
46
+ def extract_classvars(code: CodeType, scope_vars: dict) -> dict[str, dict[str, Type]]:
47
+ attr_path = AttrPath(())
23
48
  scope_obj = None
24
49
  classvars: dict[str, dict[str, Type]] = {}
25
- for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
26
- if instr.opname in scope_vars and not attr_path:
27
- attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
28
- attr_path = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
29
- elif instr.opname == "CALL":
30
- obj = scope_vars.get_at(attr_path)
31
- attr_path = ()
50
+ instructs = seekable(dis.get_instructions(code))
51
+ for instruct in instructs:
52
+ if instruct.opname in scope_vars and not attr_path:
53
+ attrs = takewhile((x.opname == "LOAD_ATTR", x.argval) for x in instructs)
54
+ attr_path = AttrPath((instruct.opname, instruct.argval, *attrs))
55
+ instructs.step(-1)
56
+ elif instruct.opname == "CALL":
57
+ obj = get_at(scope_vars, *attr_path)
58
+ attr_path = AttrPath(())
32
59
  if is_class(obj):
33
60
  scope_obj = obj
34
- elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
35
- load_key = instr.opname.replace("STORE", "LOAD")
36
- classvars.setdefault(load_key, {})[instr.argval] = scope_obj
61
+ elif instruct.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
62
+ load_key = instruct.opname.replace("STORE", "LOAD")
63
+ classvars.setdefault(load_key, {})[instruct.argval] = scope_obj
37
64
  scope_obj = None
38
65
  return classvars
39
66
 
40
- def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...], object]]:
67
+ def extract_scope_values(code: CodeType, scope_vars: dict) -> Iterable[tuple[AttrPath, object]]:
41
68
  classvars = extract_classvars(code, scope_vars)
42
- scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
43
- for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
44
- if instr.opname in scope_vars:
45
- attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
46
- attr_path: tuple[str, ...] = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
47
- val = scope_vars.get_at(attr_path)
48
- if val is not None:
49
- yield attr_path, val
69
+ scope_vars = {**scope_vars, **{k: {**scope_vars[k], **v} for k, v in classvars.items()}}
70
+ instructs = seekable(dis.get_instructions(code))
71
+ for instruct in instructs:
72
+ if instruct.opname in scope_vars:
73
+ attrs = takewhile((x.opname in ("LOAD_ATTR", "LOAD_METHOD"), x.argval) for x in instructs)
74
+ attr_path = AttrPath((instruct.opname, instruct.argval, *attrs))
75
+ parent_path = attr_path[:-1]
76
+ instructs.step(-1)
77
+ obj = get_at(scope_vars, *attr_path)
78
+ if obj is not None:
79
+ yield attr_path, obj
80
+ if callable(obj) and parent_path[1:]:
81
+ parent_obj = get_at(scope_vars, *parent_path)
82
+ yield parent_path, parent_obj
50
83
  for const in code.co_consts:
51
84
  if isinstance(const, CodeType):
52
- yield from extract_scope_values(const, scope_vars)
85
+ next_deref = {**scope_vars["LOAD_DEREF"], **scope_vars["LOAD_FAST"]}
86
+ next_scope_vars = {**scope_vars, "LOAD_FAST": {}, "LOAD_DEREF": next_deref}
87
+ yield from extract_scope_values(const, next_scope_vars)
88
+
89
+ def resolve_class_annotations(anno: object) -> Type | None:
90
+ if anno in (None, Annotated):
91
+ return None
92
+ elif is_class(anno):
93
+ return anno
94
+ elif get_origin(anno) is Annotated:
95
+ return resolve_class_annotations(next(iter(get_args(anno)), None))
96
+ return resolve_class_annotations(get_origin(anno))
53
97
 
54
- def get_self_value(fn: Callable) -> type | object | None:
98
+ def get_self_value(fn: Callable) -> Type | object | None:
55
99
  if isinstance(fn, MethodType):
56
100
  return fn.__self__
57
- parts = tuple(fn.__qualname__.split(".")[:-1])
58
- cls = parts and AttrDict(fn.__globals__).get_at(parts)
101
+ parts = fn.__qualname__.split(".")[:-1]
102
+ cls = parts and get_at(fn.__globals__, *parts)
59
103
  if is_class(cls):
60
104
  return cls
61
105
 
62
- def get_fn_captured_vals(fn: Callable) -> list[object]:
63
- self_value = get_self_value(fn)
64
- scope_vars = AttrDict({
65
- "LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
66
- "LOAD_DEREF": AttrDict(get_cell_contents(fn)),
67
- "LOAD_GLOBAL": AttrDict(fn.__globals__),
68
- })
69
- vals = dict(extract_scope_values(fn.__code__, scope_vars))
70
- return list(vals.values())
106
+ def get_capturables(fn: Callable, capture: bool, captured_vars: dict[AttrPath, object]) -> Iterable[Capturable]:
107
+ module = getmodule(fn)
108
+ if not module or not is_user_fn(fn):
109
+ return
110
+ for (instruct, *attr_path), obj in captured_vars.items():
111
+ attr_path = AttrPath(attr_path)
112
+ if instruct == "LOAD_GLOBAL" and not callable(obj) and not isinstance(obj, ModuleType):
113
+ anno = resolve_annotation(module, ".".join(attr_path))
114
+ if capture or is_capture_me(anno) or is_capture_me_once(anno):
115
+ hash_by = hash_by_from_annotation(anno)
116
+ if hash_by is not to_none:
117
+ yield Capturable.new(module, attr_path, hash_by, is_capture_me_once(anno))
71
118
 
72
- def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
73
- if not isinstance(candidate_fn, (FunctionType, MethodType)):
74
- return False
75
- fn_path = Path(inspect.getfile(candidate_fn)).resolve()
76
- return cwd in fn_path.parents and ".venv" not in fn_path.parts
119
+ def get_fn_captures(fn: Callable, capture: bool) -> tuple[list[Callable], list[Capturable]]:
120
+ scope_vars_signature: dict[str, Type | object] = {
121
+ param.name: class_anno
122
+ for param in signature(fn).parameters.values()
123
+ if param.annotation is not Parameter.empty
124
+ if param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
125
+ if (class_anno := resolve_class_annotations(param.annotation))
126
+ }
127
+ if self_obj := get_self_value(fn):
128
+ scope_vars_signature["self"] = self_obj
129
+ scope_vars = {
130
+ "LOAD_FAST": scope_vars_signature,
131
+ "LOAD_DEREF": dict(get_cell_contents(fn)),
132
+ "LOAD_GLOBAL": fn.__globals__,
133
+ }
134
+ captured_vars = dict(extract_scope_values(fn.__code__, scope_vars))
135
+ captured_callables = [obj for obj in captured_vars.values() if callable(obj)]
136
+ capturables = list(get_capturables(fn, capture, captured_vars))
137
+ return captured_callables, capturables
77
138
 
78
- def get_depend_fns(fn: Callable, captured_vals_by_fn: dict[Callable, list[object]] = {}) -> dict[Callable, list[object]]:
139
+ def get_depend_fns(fn: Callable, capture: bool, capturable_by_fn: CapturableByFn = {}) -> CapturableByFn:
79
140
  from .checkpoint import CachedFunction
80
- captured_vals = get_fn_captured_vals(fn)
81
- captured_vals_by_fn = captured_vals_by_fn or {}
82
- captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)]
83
- for val in captured_vals:
84
- if not callable(val):
85
- continue
86
- child_fn = unwrap_fn(val, cached_fn=True)
87
- if isinstance(child_fn, CachedFunction):
88
- captured_vals_by_fn[child_fn] = []
89
- elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
90
- get_depend_fns(child_fn, captured_vals_by_fn)
91
- return captured_vals_by_fn
141
+ captured_callables, capturables = get_fn_captures(fn, capture)
142
+ capturable_by_fn = capturable_by_fn or {}
143
+ capturable_by_fn[fn] = capturables
144
+ for depend_fn in captured_callables:
145
+ depend_fn = unwrap(depend_fn, stop=lambda f: isinstance(f, CachedFunction))
146
+ if isinstance(depend_fn, CachedFunction):
147
+ capturable_by_fn[depend_fn.ident.cached_fn] = []
148
+ elif depend_fn not in capturable_by_fn and is_user_fn(depend_fn):
149
+ get_depend_fns(depend_fn, capture, capturable_by_fn)
150
+ return capturable_by_fn
92
151
 
93
152
  def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
94
153
  from .checkpoint import CachedFunction
95
- captured_vals_by_fn = get_depend_fns(fn)
96
- depend_captured_vals = list(captured_vals_by_fn.values()) * capture
97
- depends = captured_vals_by_fn.keys()
154
+ capturable_by_fn = get_depend_fns(fn, capture)
155
+ capturables = {capt for capts in capturable_by_fn.values() for capt in capts}
156
+ depends = capturable_by_fn.keys()
98
157
  depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
99
- unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
100
- assert fn == unwrapped_depends[0]
101
- fn_hash = str(ObjectHash(iter=unwrapped_depends))
102
- captured_hash = str(ObjectHash(iter=depend_captured_vals, tolerate_errors=True))
103
- return RawFunctionIdent(fn_hash, captured_hash, depends)
158
+ depend_callables = [fn for fn in depends if not isinstance(fn, CachedFunction)]
159
+ assert fn == depend_callables[0]
160
+ fn_hash = str(ObjectHash(iter=map(get_fn_aststr, depend_callables)))
161
+ return RawFunctionIdent(fn_hash, depends, capturables)