checkpointer 2.11.2__py3-none-any.whl → 2.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpointer/__init__.py CHANGED
@@ -1,10 +1,10 @@
1
1
  import gc
2
2
  import tempfile
3
3
  from typing import Callable
4
- from .checkpoint import CachedFunction, Checkpointer, CheckpointError
4
+ from .checkpoint import CachedFunction, Checkpointer, CheckpointError, FunctionIdent
5
5
  from .object_hash import ObjectHash
6
6
  from .storages import MemoryStorage, PickleStorage, Storage
7
- from .types import AwaitableValue, HashBy, NoHash
7
+ from .types import AwaitableValue, Captured, CapturedOnce, CaptureMe, CaptureMeOnce, HashBy, NoHash
8
8
 
9
9
  checkpoint = Checkpointer()
10
10
  capture_checkpoint = Checkpointer(capture=True)
@@ -2,29 +2,26 @@ from __future__ import annotations
2
2
  import re
3
3
  from datetime import datetime
4
4
  from functools import cached_property, update_wrapper
5
- from inspect import Parameter, Signature, iscoroutine, signature
5
+ from inspect import Parameter, iscoroutine, signature, unwrap
6
+ from itertools import chain
6
7
  from pathlib import Path
7
8
  from typing import (
8
- Annotated, Callable, Concatenate, Coroutine, Generic,
9
- Iterable, Literal, Self, Type, TypedDict,
10
- Unpack, cast, get_args, get_origin, overload,
9
+ Callable, Concatenate, Coroutine, Generic, Iterable,
10
+ Literal, Self, Type, TypedDict, Unpack, cast, overload,
11
11
  )
12
- from .fn_ident import RawFunctionIdent, get_fn_ident
12
+ from .fn_ident import Capturable, RawFunctionIdent, get_fn_ident
13
13
  from .object_hash import ObjectHash
14
14
  from .print_checkpoint import print_checkpoint
15
- from .storages import STORAGE_MAP, Storage
16
- from .types import AwaitableValue, C, Coro, Fn, HashBy, P, R
17
- from .utils import unwrap_fn
15
+ from .storages import STORAGE_MAP, Storage, StorageType
16
+ from .types import AwaitableValue, C, Coro, Fn, P, R, hash_by_from_annotation
18
17
 
19
18
  DEFAULT_DIR = Path.home() / ".cache/checkpoints"
20
19
 
21
- empty_set = cast(set, frozenset())
22
-
23
20
  class CheckpointError(Exception):
24
21
  pass
25
22
 
26
23
  class CheckpointerOpts(TypedDict, total=False):
27
- format: Type[Storage] | Literal["pickle", "memory", "bcolz"]
24
+ format: Type[Storage] | StorageType
28
25
  root_path: Path | str | None
29
26
  when: bool
30
27
  verbosity: Literal[0, 1, 2]
@@ -63,28 +60,52 @@ class FunctionIdent:
63
60
  self.__dict__.clear()
64
61
  self.cached_fn = cached_fn
65
62
 
63
+ def reset(self):
64
+ self.__init__(self.cached_fn)
65
+
66
+ def is_static(self) -> bool:
67
+ return self.cached_fn.checkpointer.fn_hash_from is not None
68
+
66
69
  @cached_property
67
70
  def raw_ident(self) -> RawFunctionIdent:
68
- return get_fn_ident(unwrap_fn(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
71
+ return get_fn_ident(unwrap(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
69
72
 
70
73
  @cached_property
71
74
  def fn_hash(self) -> str:
72
- if (hash_from := self.cached_fn.checkpointer.fn_hash_from) is not None:
73
- return str(ObjectHash(hash_from, digest_size=16))
74
- deep_hashes = [depend.ident.raw_ident.fn_hash for depend in self.cached_fn.deep_depends()]
75
+ if self.is_static():
76
+ return str(ObjectHash(self.cached_fn.checkpointer.fn_hash_from, digest_size=16))
77
+ depends = self.deep_idents(past_static=False)
78
+ deep_hashes = [d.fn_hash if d.is_static() else d.raw_ident.fn_hash for d in depends]
75
79
  return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
76
80
 
77
81
  @cached_property
78
- def captured_hash(self) -> str:
79
- deep_hashes = [depend.ident.raw_ident.captured_hash for depend in self.cached_fn.deep_depends()]
80
- return str(ObjectHash().write_text(iter=deep_hashes))
82
+ def capturables(self) -> list[Capturable]:
83
+ return sorted({
84
+ capturable.key: capturable
85
+ for depend in self.deep_idents()
86
+ for capturable in depend.raw_ident.capturables
87
+ }.values())
88
+
89
+ def deep_depends(self, past_static=True, visited: set[Callable] = set()) -> Iterable[Callable]:
90
+ if self.cached_fn not in visited:
91
+ yield self.cached_fn
92
+ visited = visited or set()
93
+ visited.add(self.cached_fn)
94
+ stop = not past_static and self.is_static()
95
+ depends = [] if stop else self.raw_ident.depends
96
+ for depend in depends:
97
+ if isinstance(depend, CachedFunction):
98
+ yield from depend.ident.deep_depends(past_static, visited)
99
+ elif depend not in visited:
100
+ yield depend
101
+ visited.add(depend)
81
102
 
82
- def reset(self):
83
- self.__init__(self.cached_fn)
103
+ def deep_idents(self, past_static=True) -> Iterable[FunctionIdent]:
104
+ return (fn.ident for fn in self.deep_depends(past_static) if isinstance(fn, CachedFunction))
84
105
 
85
106
  class CachedFunction(Generic[Fn]):
86
107
  def __init__(self, checkpointer: Checkpointer, fn: Fn):
87
- wrapped = unwrap_fn(fn)
108
+ wrapped = unwrap(fn)
88
109
  fn_file = Path(wrapped.__code__.co_filename).name
89
110
  fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
90
111
  Storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
@@ -95,20 +116,14 @@ class CachedFunction(Generic[Fn]):
95
116
  self.storage = Storage(self)
96
117
  self.cleanup = self.storage.cleanup
97
118
  self.bound = ()
98
- self.attrname: str | None = None
99
119
 
100
- sig = signature(wrapped)
101
- params = list(sig.parameters.items())
120
+ params = list(signature(wrapped).parameters.values())
102
121
  pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
103
- self.arg_names = [name for name, param in params if param.kind in pos_params]
104
- self.default_args = {name: param.default for name, param in params if param.default is not Parameter.empty}
105
- self.hash_by_map = get_hash_by_map(sig)
122
+ self.arg_names = [param.name for param in params if param.kind in pos_params]
123
+ self.default_args = {param.name: param.default for param in params if param.default is not Parameter.empty}
124
+ self.hash_by_map = get_hash_by_map(params)
106
125
  self.ident = FunctionIdent(self)
107
126
 
108
- def __set_name__(self, _, name: str):
109
- assert self.attrname is None
110
- self.attrname = name
111
-
112
127
  @overload
113
128
  def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
114
129
  @overload
@@ -116,12 +131,9 @@ class CachedFunction(Generic[Fn]):
116
131
  def __get__(self, instance, owner):
117
132
  if instance is None:
118
133
  return self
119
- assert self.attrname is not None
120
134
  bound_fn = object.__new__(CachedFunction)
121
135
  bound_fn.__dict__ |= self.__dict__
122
136
  bound_fn.bound = (instance,)
123
- if hasattr(instance, "__dict__"):
124
- setattr(instance, self.attrname, bound_fn)
125
137
  return bound_fn
126
138
 
127
139
  @property
@@ -129,12 +141,12 @@ class CachedFunction(Generic[Fn]):
129
141
  return self.ident.raw_ident.depends
130
142
 
131
143
  def reinit(self, recursive=False) -> CachedFunction[Fn]:
132
- depend_idents = [depend.ident for depend in self.deep_depends()] if recursive else [self.ident]
144
+ depend_idents = list(self.ident.deep_idents()) if recursive else [self.ident]
133
145
  for ident in depend_idents: ident.reset()
134
146
  for ident in depend_idents: ident.fn_hash
135
147
  return self
136
148
 
137
- def get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
149
+ def _get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
138
150
  args = self.bound + args
139
151
  pos_args = args[len(self.arg_names):]
140
152
  named_pos_args = dict(zip(self.arg_names, args))
@@ -145,8 +157,17 @@ class CachedFunction(Generic[Fn]):
145
157
  if hash_by := hash_by_map.get(key, rest_hash_by):
146
158
  named_args[key] = hash_by(value)
147
159
  if pos_hash_by := hash_by_map.get(b"*"):
148
- pos_args = tuple(map(pos_hash_by, pos_args))
149
- return str(ObjectHash(named_args, pos_args, self.ident.captured_hash, digest_size=16))
160
+ pos_args = map(pos_hash_by, pos_args)
161
+ named_args_iter = chain.from_iterable(sorted(named_args.items()))
162
+ captured = chain.from_iterable(capturable.capture() for capturable in self.ident.capturables)
163
+ obj_hash = ObjectHash(digest_size=16) \
164
+ .update(iter=named_args_iter, header="NAMED") \
165
+ .update(iter=pos_args, header="POS") \
166
+ .update(iter=captured, header="CAPTURED")
167
+ return str(obj_hash)
168
+
169
+ def get_call_hash(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> str:
170
+ return self._get_call_hash(args, kw)
150
171
 
151
172
  async def _resolve_coroutine(self, call_hash: str, coroutine: Coroutine):
152
173
  return self.storage.store(call_hash, AwaitableValue(await coroutine)).value
@@ -157,7 +178,7 @@ class CachedFunction(Generic[Fn]):
157
178
  if not params.when:
158
179
  return self.fn(*full_args, **kw)
159
180
 
160
- call_hash = self.get_call_hash(args, kw)
181
+ call_hash = self._get_call_hash(args, kw)
161
182
  call_id = f"{self.storage.fn_id()}/{call_hash}"
162
183
  refresh = rerun \
163
184
  or not self.storage.exists(call_hash) \
@@ -186,17 +207,17 @@ class CachedFunction(Generic[Fn]):
186
207
  return self._call(args, kw, True)
187
208
 
188
209
  def exists(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> bool:
189
- return self.storage.exists(self.get_call_hash(args, kw))
210
+ return self.storage.exists(self._get_call_hash(args, kw))
190
211
 
191
212
  def delete(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs):
192
- self.storage.delete(self.get_call_hash(args, kw))
213
+ self.storage.delete(self._get_call_hash(args, kw))
193
214
 
194
215
  @overload
195
216
  def get(self: Callable[P, Coro[R]], *args: P.args, **kw: P.kwargs) -> R: ...
196
217
  @overload
197
218
  def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
198
219
  def get(self, *args, **kw):
199
- call_hash = self.get_call_hash(args, kw)
220
+ call_hash = self._get_call_hash(args, kw)
200
221
  try:
201
222
  data = self.storage.load(call_hash)
202
223
  return data.value if isinstance(data, AwaitableValue) else data
@@ -208,30 +229,15 @@ class CachedFunction(Generic[Fn]):
208
229
  @overload
209
230
  def set(self: Callable[P, R], value: R, *args: P.args, **kw: P.kwargs): ...
210
231
  def set(self, value, *args, **kw):
211
- self.storage.store(self.get_call_hash(args, kw), value)
232
+ self.storage.store(self._get_call_hash(args, kw), value)
212
233
 
213
234
  def __repr__(self) -> str:
214
235
  return f"<CachedFunction {self.fn.__name__} {self.ident.fn_hash[:6]}>"
215
236
 
216
- def deep_depends(self, visited: set[CachedFunction] = empty_set) -> Iterable[CachedFunction]:
217
- if self not in visited:
218
- yield self
219
- visited = visited or set()
220
- visited.add(self)
221
- for depend in self.depends:
222
- if isinstance(depend, CachedFunction):
223
- yield from depend.deep_depends(visited)
224
-
225
- def hash_by_from_annotation(annotation: type) -> Callable[[object], object] | None:
226
- if get_origin(annotation) is Annotated:
227
- args = get_args(annotation)
228
- metadata = args[1] if len(args) > 1 else None
229
- if get_origin(metadata) is HashBy:
230
- return get_args(metadata)[0]
231
-
232
- def get_hash_by_map(sig: Signature) -> dict[str | bytes, Callable[[object], object]]:
237
+ def get_hash_by_map(params: list[Parameter]) -> dict[str | bytes, Callable[[object], object]]:
233
238
  hash_by_map = {}
234
- for name, param in sig.parameters.items():
239
+ for param in params:
240
+ name = param.name
235
241
  if param.kind == Parameter.VAR_POSITIONAL:
236
242
  name = b"*"
237
243
  elif param.kind == Parameter.VAR_KEYWORD:
checkpointer/fn_ident.py CHANGED
@@ -1,34 +1,60 @@
1
1
  import dis
2
- import inspect
3
- from itertools import takewhile
4
- from pathlib import Path
5
- from types import CodeType, FunctionType, MethodType
6
- from typing import Callable, Iterable, NamedTuple, Type, TypeGuard
2
+ from inspect import Parameter, getmodule, signature, unwrap
3
+ from types import CodeType, MethodType, ModuleType
4
+ from typing import Annotated, Callable, Iterable, NamedTuple, Type, get_args, get_origin
5
+ from .import_mappings import resolve_annotation
7
6
  from .object_hash import ObjectHash
8
- from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, unwrap_fn
7
+ from .types import hash_by_from_annotation, is_capture_me, is_capture_me_once, to_none
8
+ from .utils import (
9
+ AttrDict, cwd, distinct, get_cell_contents,
10
+ get_file, is_class, is_user_fn, seekable, takewhile,
11
+ )
9
12
 
10
- cwd = Path.cwd().resolve()
13
+ AttrPath = tuple[str, ...]
14
+ CapturableByFn = dict[Callable, list["Capturable"]]
11
15
 
12
16
  class RawFunctionIdent(NamedTuple):
13
17
  fn_hash: str
14
- captured_hash: str
15
18
  depends: list[Callable]
19
+ capturables: set["Capturable"]
16
20
 
17
- def is_class(obj) -> TypeGuard[Type]:
18
- # isinstance works too, but needlessly triggers _lazyinit()
19
- return issubclass(type(obj), type)
21
+ class Capturable(NamedTuple):
22
+ key: str
23
+ module: ModuleType
24
+ attr_path: AttrPath
25
+ hash_by: Callable | None
26
+ hash: str | None = None
27
+
28
+ def capture(self) -> tuple[str, object]:
29
+ if obj := self.hash:
30
+ return self.key, obj
31
+ obj = AttrDict.get_at(self.module, *self.attr_path)
32
+ obj = self.hash_by(obj) if self.hash_by else obj
33
+ return self.key, obj
34
+
35
+ @staticmethod
36
+ def new(module: ModuleType, attr_path: AttrPath, hash_by: Callable | None, capture_once: bool) -> "Capturable":
37
+ file = str(get_file(module).relative_to(cwd))
38
+ key = "-".join((file, *attr_path))
39
+ cap = Capturable(key, module, attr_path, hash_by)
40
+ if not capture_once:
41
+ return cap
42
+ obj_hash = str(ObjectHash(cap.capture()[1]))
43
+ return Capturable(key, module, attr_path, None, obj_hash)
20
44
 
21
45
  def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
22
- attr_path: tuple[str, ...] = ()
46
+ attr_path = AttrPath(())
23
47
  scope_obj = None
24
48
  classvars: dict[str, dict[str, Type]] = {}
25
- for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
49
+ instructs = seekable(dis.get_instructions(code))
50
+ for instr in instructs:
26
51
  if instr.opname in scope_vars and not attr_path:
27
- attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
28
- attr_path = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
52
+ attrs = takewhile((x.opname == "LOAD_ATTR", x.argval) for x in instructs)
53
+ attr_path = AttrPath((instr.opname, instr.argval, *attrs))
54
+ instructs.step(-1)
29
55
  elif instr.opname == "CALL":
30
- obj = scope_vars.get_at(attr_path)
31
- attr_path = ()
56
+ obj = scope_vars.get_at(*attr_path)
57
+ attr_path = AttrPath(())
32
58
  if is_class(obj):
33
59
  scope_obj = obj
34
60
  elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
@@ -37,67 +63,97 @@ def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[st
37
63
  scope_obj = None
38
64
  return classvars
39
65
 
40
- def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...], object]]:
66
+ def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[AttrPath, object]]:
41
67
  classvars = extract_classvars(code, scope_vars)
42
68
  scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
43
- for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
69
+ instructs = seekable(dis.get_instructions(code))
70
+ for instr in instructs:
44
71
  if instr.opname in scope_vars:
45
- attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
46
- attr_path: tuple[str, ...] = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
47
- val = scope_vars.get_at(attr_path)
48
- if val is not None:
49
- yield attr_path, val
72
+ attrs = takewhile((x.opname in ("LOAD_ATTR", "LOAD_METHOD"), x.argval) for x in instructs)
73
+ attr_path = AttrPath((instr.opname, instr.argval, *attrs))
74
+ parent_path = attr_path[:-1]
75
+ instructs.step(-1)
76
+ obj = scope_vars.get_at(*attr_path)
77
+ if obj is not None:
78
+ yield attr_path, obj
79
+ if callable(obj) and parent_path[1:]:
80
+ parent_obj = scope_vars.get_at(*parent_path)
81
+ yield parent_path, parent_obj
50
82
  for const in code.co_consts:
51
83
  if isinstance(const, CodeType):
52
- yield from extract_scope_values(const, scope_vars)
84
+ next_deref = scope_vars.LOAD_DEREF.set(scope_vars.LOAD_FAST)
85
+ next_scope_vars = AttrDict({**scope_vars, "LOAD_FAST": {}, "LOAD_DEREF": next_deref})
86
+ yield from extract_scope_values(const, next_scope_vars)
87
+
88
+ def resolve_class_annotations(anno: object) -> Type | None:
89
+ if anno in (None, Annotated):
90
+ return None
91
+ elif is_class(anno):
92
+ return anno
93
+ elif get_origin(anno) is Annotated:
94
+ return resolve_class_annotations(next(iter(get_args(anno)), None))
95
+ return resolve_class_annotations(get_origin(anno))
53
96
 
54
97
  def get_self_value(fn: Callable) -> type | object | None:
55
98
  if isinstance(fn, MethodType):
56
99
  return fn.__self__
57
- parts = tuple(fn.__qualname__.split(".")[:-1])
58
- cls = parts and AttrDict(fn.__globals__).get_at(parts)
100
+ parts = fn.__qualname__.split(".")[:-1]
101
+ cls = parts and AttrDict(fn.__globals__).get_at(*parts)
59
102
  if is_class(cls):
60
103
  return cls
61
104
 
62
- def get_fn_captured_vals(fn: Callable) -> list[object]:
105
+ def get_capturables(fn: Callable, capture: bool, captured_vars: dict[AttrPath, object]) -> Iterable[Capturable]:
106
+ module = getmodule(fn)
107
+ if not module or not is_user_fn(fn):
108
+ return
109
+ for (instruct, *attr_path), obj in captured_vars.items():
110
+ attr_path = AttrPath(attr_path)
111
+ if instruct == "LOAD_GLOBAL" and not callable(obj) and not isinstance(obj, ModuleType):
112
+ anno = resolve_annotation(module, ".".join(attr_path))
113
+ if capture or is_capture_me(anno) or is_capture_me_once(anno):
114
+ hash_by = hash_by_from_annotation(anno)
115
+ if hash_by is not to_none:
116
+ yield Capturable.new(module, attr_path, hash_by, is_capture_me_once(anno))
117
+
118
+ def get_fn_captures(fn: Callable, capture: bool) -> tuple[list[Callable], list[Capturable]]:
119
+ sig_scope = {
120
+ param.name: class_anno
121
+ for param in signature(fn).parameters.values()
122
+ if param.annotation is not Parameter.empty
123
+ if param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
124
+ if (class_anno := resolve_class_annotations(param.annotation))
125
+ }
63
126
  self_value = get_self_value(fn)
64
127
  scope_vars = AttrDict({
65
- "LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
128
+ "LOAD_FAST": AttrDict({**sig_scope, "self": self_value} if self_value else sig_scope),
66
129
  "LOAD_DEREF": AttrDict(get_cell_contents(fn)),
67
130
  "LOAD_GLOBAL": AttrDict(fn.__globals__),
68
131
  })
69
- vals = dict(extract_scope_values(fn.__code__, scope_vars))
70
- return list(vals.values())
71
-
72
- def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
73
- if not isinstance(candidate_fn, (FunctionType, MethodType)):
74
- return False
75
- fn_path = Path(inspect.getfile(candidate_fn)).resolve()
76
- return cwd in fn_path.parents and ".venv" not in fn_path.parts
132
+ captured_vars = dict(extract_scope_values(fn.__code__, scope_vars))
133
+ captured_callables = [obj for obj in captured_vars.values() if callable(obj)]
134
+ capturables = list(get_capturables(fn, capture, captured_vars))
135
+ return captured_callables, capturables
77
136
 
78
- def get_depend_fns(fn: Callable, captured_vals_by_fn: dict[Callable, list[object]] = {}) -> dict[Callable, list[object]]:
137
+ def get_depend_fns(fn: Callable, capture: bool, capturable_by_fn: CapturableByFn = {}) -> CapturableByFn:
79
138
  from .checkpoint import CachedFunction
80
- captured_vals = get_fn_captured_vals(fn)
81
- captured_vals_by_fn = captured_vals_by_fn or {}
82
- captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)]
83
- for val in captured_vals:
84
- if not callable(val):
85
- continue
86
- child_fn = unwrap_fn(val, cached_fn=True)
87
- if isinstance(child_fn, CachedFunction):
88
- captured_vals_by_fn[child_fn] = []
89
- elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
90
- get_depend_fns(child_fn, captured_vals_by_fn)
91
- return captured_vals_by_fn
139
+ captured_callables, capturables = get_fn_captures(fn, capture)
140
+ capturable_by_fn = capturable_by_fn or {}
141
+ capturable_by_fn[fn] = capturables
142
+ for depend_fn in captured_callables:
143
+ depend_fn = unwrap(depend_fn, stop=lambda f: isinstance(f, CachedFunction))
144
+ if isinstance(depend_fn, CachedFunction):
145
+ capturable_by_fn[depend_fn] = []
146
+ elif depend_fn not in capturable_by_fn and is_user_fn(depend_fn):
147
+ get_depend_fns(depend_fn, capture, capturable_by_fn)
148
+ return capturable_by_fn
92
149
 
93
150
  def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
94
151
  from .checkpoint import CachedFunction
95
- captured_vals_by_fn = get_depend_fns(fn)
96
- depend_captured_vals = list(captured_vals_by_fn.values()) * capture
97
- depends = captured_vals_by_fn.keys()
152
+ capturable_by_fn = get_depend_fns(fn, capture)
153
+ capturables = {capt for capts in capturable_by_fn.values() for capt in capts}
154
+ depends = capturable_by_fn.keys()
98
155
  depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
99
156
  unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
100
157
  assert fn == unwrapped_depends[0]
101
158
  fn_hash = str(ObjectHash(iter=unwrapped_depends))
102
- captured_hash = str(ObjectHash(iter=depend_captured_vals, tolerate_errors=True))
103
- return RawFunctionIdent(fn_hash, captured_hash, depends)
159
+ return RawFunctionIdent(fn_hash, depends, capturables)
@@ -0,0 +1,47 @@
1
+ import ast
2
+ import inspect
3
+ import sys
4
+ from types import ModuleType
5
+ from typing import Iterable, Type
6
+ from .utils import cwd, get_file, is_user_file
7
+
8
+ ImportTarget = tuple[str, str | None]
9
+
10
+ cache: dict[tuple[str, int], dict[str, ImportTarget]] = {}
11
+
12
+ def generate_import_mappings(module: ModuleType) -> Iterable[tuple[str, ImportTarget]]:
13
+ mod_path = get_file(module)
14
+ if not is_user_file(mod_path):
15
+ return
16
+ mod_parts = list(mod_path.with_suffix("").relative_to(cwd).parts)
17
+ source = inspect.getsource(module)
18
+ tree = ast.parse(source)
19
+ for node in ast.walk(tree):
20
+ if isinstance(node, ast.Import):
21
+ for alias in node.names:
22
+ yield (alias.asname or alias.name, (alias.name, None))
23
+ elif isinstance(node, ast.ImportFrom):
24
+ target_mod = node.module or ""
25
+ if node.level > 0:
26
+ target_mod_parts = target_mod.split(".") * bool(target_mod)
27
+ target_mod_parts = mod_parts[:-node.level] + target_mod_parts
28
+ target_mod = ".".join(target_mod_parts)
29
+ for alias in node.names:
30
+ yield (alias.asname or alias.name, (target_mod, alias.name))
31
+
32
+ def get_import_mappings(module: ModuleType) -> dict[str, ImportTarget]:
33
+ cache_key = (module.__name__, id(module))
34
+ if cached := cache.get(cache_key):
35
+ return cached
36
+ import_mappings = dict(generate_import_mappings(module))
37
+ return cache.setdefault(cache_key, import_mappings)
38
+
39
+ def resolve_annotation(module: ModuleType, attr_name: str | None) -> Type | None:
40
+ if not attr_name:
41
+ return None
42
+ if anno := module.__annotations__.get(attr_name):
43
+ return anno
44
+ if next_pair := get_import_mappings(module).get(attr_name):
45
+ next_module_name, next_attr_name = next_pair
46
+ if next_module := sys.modules.get(next_module_name):
47
+ return resolve_annotation(next_module, next_attr_name)
@@ -12,7 +12,7 @@ from io import StringIO
12
12
  from itertools import chain
13
13
  from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
14
14
  from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
15
- from typing import Callable, TypeVar
15
+ from typing import Callable, Self, TypeVar
16
16
  from .utils import ContextVar
17
17
 
18
18
  np, torch = None, None
@@ -43,14 +43,14 @@ class ObjectHashError(Exception):
43
43
  self.obj = obj
44
44
 
45
45
  class ObjectHash:
46
- def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerate_errors=False) -> None:
46
+ def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerable=False) -> None:
47
47
  self.hash = hashlib.blake2b(digest_size=digest_size)
48
48
  self.current: dict[int, int] = {}
49
- self.tolerate_errors = ContextVar(tolerate_errors)
49
+ self.tolerable = ContextVar(tolerable)
50
50
  self.update(iter=chain(objs, iter))
51
51
 
52
52
  def copy(self) -> "ObjectHash":
53
- new = ObjectHash(tolerate_errors=self.tolerate_errors.value)
53
+ new = ObjectHash(tolerable=self.tolerable.value)
54
54
  new.hash = self.hash.copy()
55
55
  return new
56
56
 
@@ -63,26 +63,29 @@ class ObjectHash:
63
63
  return isinstance(value, ObjectHash) and str(self) == str(value)
64
64
 
65
65
  def nested_hash(self, *objs: object) -> str:
66
- return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()
66
+ return ObjectHash(iter=objs, tolerable=self.tolerable.value).hexdigest()
67
67
 
68
- def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> "ObjectHash":
68
+ def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> Self:
69
69
  for d in chain(data, iter):
70
70
  self.hash.update(d)
71
71
  return self
72
72
 
73
- def write_text(self, *data: str, iter: Iterable[str] = ()) -> "ObjectHash":
73
+ def write_text(self, *data: str, iter: Iterable[str] = ()) -> Self:
74
74
  return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))
75
75
 
76
- def header(self, *args: object) -> "ObjectHash":
76
+ def header(self, *args: object) -> Self:
77
77
  return self.write_bytes(":".join(map(str, args)).encode())
78
78
 
79
- def update(self, *objs: object, iter: Iterable[object] = (), tolerate_errors: bool | None=None) -> "ObjectHash":
80
- with nullcontext() if tolerate_errors is None else self.tolerate_errors.set(tolerate_errors):
79
+ def update(self, *objs: object, iter: Iterable[object] = (), tolerable: bool | None=None, header: str | None = None) -> Self:
80
+ with nullcontext() if tolerable is None else self.tolerable.set(tolerable):
81
81
  for obj in chain(objs, iter):
82
+ if header is not None:
83
+ self.write_bytes(header.encode())
84
+ header = None
82
85
  try:
83
86
  self._update_one(obj)
84
87
  except Exception as ex:
85
- if self.tolerate_errors.value:
88
+ if self.tolerable.value:
86
89
  self.header("error").update(type(ex))
87
90
  else:
88
91
  raise ObjectHashError(obj, ex) from ex
@@ -180,10 +183,10 @@ class ObjectHash:
180
183
  finally:
181
184
  del self.current[id(obj)]
182
185
 
183
- def _update_iterator(self, obj: Iterable) -> "ObjectHash":
186
+ def _update_iterator(self, obj: Iterable) -> Self:
184
187
  return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")
185
188
 
186
- def _update_object(self, obj: object) -> "ObjectHash":
189
+ def _update_object(self, obj: object) -> Self:
187
190
  self.header("instance", encode_type_of(obj))
188
191
  get_hash = hasattr(obj, "__objecthash__") and getattr(obj, "__objecthash__")
189
192
  if callable(get_hash):
@@ -1,9 +1,11 @@
1
- from typing import Type
2
- from .storage import Storage
3
- from .pickle_storage import PickleStorage
1
+ from typing import Literal, Type
4
2
  from .memory_storage import MemoryStorage
3
+ from .pickle_storage import PickleStorage
4
+ from .storage import Storage
5
+
6
+ StorageType = Literal["pickle", "memory"]
5
7
 
6
- STORAGE_MAP: dict[str, Type[Storage]] = {
8
+ STORAGE_MAP: dict[StorageType, Type[Storage]] = {
7
9
  "pickle": PickleStorage,
8
10
  "memory": MemoryStorage,
9
11
  }
checkpointer/types.py CHANGED
@@ -1,4 +1,7 @@
1
- from typing import Annotated, Callable, Coroutine, Generic, ParamSpec, TypeVar
1
+ from typing import (
2
+ Annotated, Callable, Coroutine, Generic,
3
+ ParamSpec, TypeVar, get_args, get_origin,
4
+ )
2
5
 
3
6
  Fn = TypeVar("Fn", bound=Callable)
4
7
  P = ParamSpec("P")
@@ -9,7 +12,32 @@ T = TypeVar("T")
9
12
  class HashBy(Generic[Fn]):
10
13
  pass
11
14
 
12
- NoHash = Annotated[T, HashBy[lambda _: None]]
15
+ class Captured:
16
+ pass
17
+
18
+ class CapturedOnce:
19
+ pass
20
+
21
+ def to_none(_):
22
+ return None
23
+
24
+ def get_annotated_args(anno: object) -> tuple[object, ...]:
25
+ return get_args(anno) if get_origin(anno) is Annotated else ()
26
+
27
+ def hash_by_from_annotation(anno: object) -> Callable[[object], object] | None:
28
+ for arg in get_annotated_args(anno):
29
+ if get_origin(arg) is HashBy:
30
+ return get_args(arg)[0]
31
+
32
+ def is_capture_me(anno: object) -> bool:
33
+ return Captured in get_annotated_args(anno)
34
+
35
+ def is_capture_me_once(anno: object) -> bool:
36
+ return CapturedOnce in get_annotated_args(anno)
37
+
38
+ NoHash = Annotated[T, HashBy[to_none]]
39
+ CaptureMe = Annotated[T, Captured]
40
+ CaptureMeOnce = Annotated[T, CapturedOnce]
13
41
  Coro = Coroutine[object, object, R]
14
42
 
15
43
  class AwaitableValue(Generic[T]):
checkpointer/utils.py CHANGED
@@ -1,9 +1,25 @@
1
+ from __future__ import annotations
2
+ import inspect
1
3
  from contextlib import contextmanager
2
- from typing import Callable, Generic, Iterable, cast
3
- from .types import Fn, T
4
+ from itertools import islice
5
+ from pathlib import Path
6
+ from types import FunctionType, MethodType, ModuleType
7
+ from typing import Callable, Generic, Iterable, Self, Type, TypeGuard
8
+ from .types import T
4
9
 
5
- def distinct(seq: Iterable[T]) -> list[T]:
6
- return list(dict.fromkeys(seq))
10
+ cwd = Path.cwd().resolve()
11
+
12
+ def is_class(obj) -> TypeGuard[Type]:
13
+ return isinstance(obj, type)
14
+
15
+ def get_file(obj: Callable | ModuleType) -> Path:
16
+ return Path(inspect.getfile(obj)).resolve()
17
+
18
+ def is_user_file(path: Path) -> bool:
19
+ return cwd in path.parents and ".venv" not in path.parts
20
+
21
+ def is_user_fn(obj) -> TypeGuard[Callable]:
22
+ return isinstance(obj, (FunctionType, MethodType)) and is_user_file(get_file(obj))
7
23
 
8
24
  def get_cell_contents(fn: Callable) -> Iterable[tuple[str, object]]:
9
25
  for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
@@ -12,12 +28,57 @@ def get_cell_contents(fn: Callable) -> Iterable[tuple[str, object]]:
12
28
  except ValueError:
13
29
  pass
14
30
 
15
- def unwrap_fn(fn: Fn, cached_fn=False) -> Fn:
16
- from .checkpoint import CachedFunction
17
- while True:
18
- if (cached_fn and isinstance(fn, CachedFunction)) or not hasattr(fn, "__wrapped__"):
19
- return cast(Fn, fn)
20
- fn = getattr(fn, "__wrapped__")
31
+ def distinct(seq: Iterable[T]) -> list[T]:
32
+ return list(dict.fromkeys(seq))
33
+
34
+ def takewhile(iter: Iterable[tuple[bool, T]]) -> Iterable[T]:
35
+ for condition, value in iter:
36
+ if not condition:
37
+ return
38
+ yield value
39
+
40
+ class seekable(Generic[T]):
41
+ def __init__(self, iterable: Iterable[T]):
42
+ self.index = 0
43
+ self.source = iter(iterable)
44
+ self.sink: list[T] = []
45
+
46
+ def __iter__(self):
47
+ return self
48
+
49
+ def __next__(self) -> T:
50
+ if len(self.sink) > self.index:
51
+ item = self.sink[self.index]
52
+ else:
53
+ item = next(self.source)
54
+ self.sink.append(item)
55
+ self.index += 1
56
+ return item
57
+
58
+ def __bool__(self):
59
+ return bool(self.lookahead(1))
60
+
61
+ def seek(self, index: int) -> Self:
62
+ remainder = index - len(self.sink)
63
+ if remainder > 0:
64
+ next(islice(self, remainder, remainder), None)
65
+ self.index = max(0, min(index, len(self.sink)))
66
+ return self
67
+
68
+ def step(self, count: int) -> Self:
69
+ return self.seek(self.index + count)
70
+
71
+ @contextmanager
72
+ def freeze(self):
73
+ initial_index = self.index
74
+ try:
75
+ yield
76
+ finally:
77
+ self.seek(initial_index)
78
+
79
+ def lookahead(self, count: int) -> list[T]:
80
+ with self.freeze():
81
+ return list(islice(self, count))
21
82
 
22
83
  class AttrDict(dict):
23
84
  def __init__(self, *args, **kwargs):
@@ -30,22 +91,16 @@ class AttrDict(dict):
30
91
  def __setattr__(self, name: str, value: object):
31
92
  super().__setattr__(name, value)
32
93
 
33
- def set(self, d: dict) -> "AttrDict":
94
+ def set(self, d: dict) -> AttrDict:
34
95
  if not d:
35
96
  return self
36
97
  return AttrDict({**self, **d})
37
98
 
38
- def delete(self, *attrs: str) -> "AttrDict":
39
- d = AttrDict(self)
99
+ def get_at(self: object, *attrs: str) -> object:
100
+ obj = self
40
101
  for attr in attrs:
41
- del d[attr]
42
- return d
43
-
44
- def get_at(self, attrs: tuple[str, ...]) -> object:
45
- d = self
46
- for attr in attrs:
47
- d = getattr(d, attr, None)
48
- return d
102
+ obj = getattr(obj, attr, None)
103
+ return obj
49
104
 
50
105
  class ContextVar(Generic[T]):
51
106
  def __init__(self, value: T):
@@ -58,26 +113,3 @@ class ContextVar(Generic[T]):
58
113
  yield
59
114
  finally:
60
115
  self.value = old
61
-
62
- class iterate_and_upcoming(Generic[T]):
63
- def __init__(self, it: Iterable[T]) -> None:
64
- self.it = iter(it)
65
- self.previous: tuple[()] | tuple[T] = ()
66
- self.tracked = self._tracked_iter()
67
-
68
- def __iter__(self):
69
- return self
70
-
71
- def __next__(self) -> tuple[T, Iterable[T]]:
72
- try:
73
- item = self.previous[0] if self.previous else next(self.it)
74
- self.previous = ()
75
- return item, self.tracked
76
- except StopIteration:
77
- self.tracked.close()
78
- raise
79
-
80
- def _tracked_iter(self):
81
- for x in self.it:
82
- self.previous = (x,)
83
- yield x
@@ -1,17 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpointer
3
- Version: 2.11.2
3
+ Version: 2.12.0
4
4
  Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
5
5
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
6
  Author: Hampus Hallman
7
- License: Copyright 2018-2025 Hampus Hallman
8
-
9
- Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
10
-
11
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
12
-
13
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
7
+ License-Expression: MIT
8
+ License-File: ATTRIBUTION.md
14
9
  License-File: LICENSE
10
+ Keywords: async,cache,caching,code-aware,decorator,fast,hashing,invalidation,memoization,memoize,memory,optimization,performance,workflow
15
11
  Classifier: Programming Language :: Python :: 3.11
16
12
  Classifier: Programming Language :: Python :: 3.12
17
13
  Classifier: Programming Language :: Python :: 3.13
@@ -0,0 +1,17 @@
1
+ checkpointer/__init__.py,sha256=x8LbL-URPg-afn0O-HMxPsJkV9cmW-Cw5epKSCGClOM,889
2
+ checkpointer/checkpoint.py,sha256=2K2D_aYHpI1XBcNrWQxta06wRdUj81TeCXshBVxl7cA,9869
3
+ checkpointer/fn_ident.py,sha256=4qg4NIvcWRV5GduO70an_iu12dn8LLIlw3Q8F3gm0Mo,6826
4
+ checkpointer/import_mappings.py,sha256=ESqWvZTzYAmaVnJ6NulUvn3_8CInOOPmEKUXO2WD_WA,1794
5
+ checkpointer/object_hash.py,sha256=MkrwSJJYVlWOjIDpGfYpAeYFPgCJhXHYBdlKVcUy2kw,8145
6
+ checkpointer/print_checkpoint.py,sha256=uUQ493fJCaB4nhp4Ox60govSCiBTIPbBX15zt2QiRGo,1356
7
+ checkpointer/types.py,sha256=GFqbGACdDxzQX3bb2LmF9UxQVWOEisGvdtobnqCBAOA,1129
8
+ checkpointer/utils.py,sha256=rq7AjR0WJql1o8clKMdXj4p3wrPYMLyS6G11nvNNjFE,2934
9
+ checkpointer/storages/__init__.py,sha256=p-r4YrPXn505_S3qLrSXHSlsEtb13w_DFnCt9IiUomk,296
10
+ checkpointer/storages/memory_storage.py,sha256=aQRSOmAfS0UudubCpv8cdfu2ycM8mlsO9tFMcD2kmgo,1133
11
+ checkpointer/storages/pickle_storage.py,sha256=je1LM2lTSs5yzm25Apg5tJ9jU9T6nXCgD9SlqQRIFaM,1652
12
+ checkpointer/storages/storage.py,sha256=5Jel7VlmCG9gnIbZFKT1NrEiAePz8ZD8hfsD-tYEiP4,905
13
+ checkpointer-2.12.0.dist-info/METADATA,sha256=B-R8Aq2caTM7fclANY7MJVZJQaU0nB1Mdf7wMr6HSTI,10705
14
+ checkpointer-2.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
+ checkpointer-2.12.0.dist-info/licenses/ATTRIBUTION.md,sha256=WF6L7-sD4s9t9ytVJOhjhpDoZ6TrWpqE3_bMdDIeJxI,1078
16
+ checkpointer-2.12.0.dist-info/licenses/LICENSE,sha256=drXs6vIb7uW49r70UuMz2A1VtOCl626kiTbcmrar1Xo,1072
17
+ checkpointer-2.12.0.dist-info/RECORD,,
@@ -0,0 +1,33 @@
1
+ # Attribution and License Notices
2
+
3
+ This project includes code copied or adapted from third-party open-source projects. The following acknowledges the original sources and complies with their licensing requirements.
4
+
5
+ ---
6
+
7
+ ## Third-Party Code
8
+
9
+ ### more-itertools
10
+ - **Source:** https://github.com/more-itertools/more-itertools
11
+ - **Author:** Erik Rose
12
+ - **Copyright:** (c) 2012 Erik Rose
13
+ - **License:** MIT (https://github.com/more-itertools/more-itertools/blob/master/LICENSE)
14
+
15
+ ### colored
16
+ - **Source:** https://gitlab.com/dslackw/colored
17
+ - **Author:** Dimitris Zlatanidis
18
+ - **Copyright:** (c) 2014-2025 Dimitris Zlatanidis
19
+ - **License:** MIT (https://gitlab.com/dslackw/colored/-/blob/master/LICENSE.txt)
20
+
21
+ ---
22
+
23
+ ## License
24
+
25
+ This project is licensed under the MIT License. See the `LICENSE` file for details.
26
+
27
+ ---
28
+
29
+ ## Notes
30
+
31
+ - Third-party code is included under their original MIT licenses.
32
+ - This file documents those license notices, fulfilling attribution obligations.
33
+ - Source files with copied code may omit individual license headers in favor of this centralized attribution.
@@ -1,3 +1,5 @@
1
+ MIT License
2
+
1
3
  Copyright 2018-2025 Hampus Hallman
2
4
 
3
5
  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
@@ -1,168 +0,0 @@
1
- import asyncio
2
- import pytest
3
- from checkpointer import CheckpointError, checkpoint
4
- from .utils import AttrDict
5
-
6
- def global_multiply(a: int, b: int) -> int:
7
- return a * b
8
-
9
- @pytest.fixture(autouse=True)
10
- def run_before_and_after_tests(tmpdir):
11
- global checkpoint
12
- checkpoint = checkpoint(root_path=tmpdir)
13
- yield
14
-
15
- def test_basic_caching():
16
- @checkpoint
17
- def square(x: int) -> int:
18
- return x ** 2
19
-
20
- result1 = square(4)
21
- result2 = square(4)
22
-
23
- assert result1 == result2 == 16
24
-
25
- def test_cache_invalidation():
26
- @checkpoint
27
- def multiply(a: int, b: int):
28
- return a * b
29
-
30
- @checkpoint
31
- def helper(x: int):
32
- return multiply(x + 1, 2)
33
-
34
- @checkpoint
35
- def compute(a: int, b: int):
36
- return helper(a) + helper(b)
37
-
38
- result1 = compute(3, 4)
39
- assert result1 == 18
40
-
41
- def test_layered_caching():
42
- dev_checkpoint = checkpoint(when=True)
43
-
44
- @checkpoint(format="memory")
45
- @dev_checkpoint
46
- def expensive_function(x: int):
47
- return x ** 2
48
-
49
- assert expensive_function(4) == 16
50
- assert expensive_function(4) == 16
51
-
52
- def test_recursive_caching1():
53
- @checkpoint
54
- def fib(n: int) -> int:
55
- return fib(n - 1) + fib(n - 2) if n > 1 else n
56
-
57
- assert fib(10) == 55
58
- assert fib.get(10) == 55
59
- assert fib.get(5) == 5
60
-
61
- def test_recursive_caching2():
62
- @checkpoint
63
- def fib(n: int) -> int:
64
- return fib.fn(n - 1) + fib.fn(n - 2) if n > 1 else n
65
-
66
- assert fib(10) == 55
67
- assert fib.get(10) == 55
68
- with pytest.raises(CheckpointError):
69
- fib.get(5)
70
-
71
- @pytest.mark.asyncio
72
- async def test_async_caching():
73
- @checkpoint(format="memory")
74
- async def async_square(x: int) -> int:
75
- await asyncio.sleep(0.1)
76
- return x ** 2
77
-
78
- result1 = await async_square(3)
79
- result2 = await async_square(3)
80
- result3 = async_square.get(3)
81
-
82
- assert result1 == result2 == result3 == 9
83
-
84
- def test_force_recalculation():
85
- @checkpoint
86
- def square(x: int) -> int:
87
- return x ** 2
88
-
89
- assert square(5) == 25
90
- square.rerun(5)
91
- assert square.get(5) == 25
92
-
93
- def test_multi_layer_decorator():
94
- @checkpoint(format="memory")
95
- @checkpoint(format="pickle")
96
- def add(a: int, b: int) -> int:
97
- return a + b
98
-
99
- assert add(2, 3) == 5
100
- assert add.get(2, 3) == 5
101
-
102
- def test_capture():
103
- item_dict = AttrDict({"a": 1, "b": 1})
104
-
105
- @checkpoint(capture=True)
106
- def test_whole():
107
- return item_dict
108
-
109
- @checkpoint(capture=True)
110
- def test_a():
111
- return item_dict.a + 1
112
-
113
- init_hash_a = test_a.ident.captured_hash
114
- init_hash_whole = test_whole.ident.captured_hash
115
- item_dict.b += 1
116
- test_whole.reinit()
117
- test_a.reinit()
118
- assert test_whole.ident.captured_hash != init_hash_whole
119
- assert test_a.ident.captured_hash == init_hash_a
120
- item_dict.a += 1
121
- test_a.reinit()
122
- assert test_a.ident.captured_hash != init_hash_a
123
-
124
- def test_depends():
125
- def multiply_wrapper(a: int, b: int) -> int:
126
- return global_multiply(a, b)
127
-
128
- def helper(a: int, b: int) -> int:
129
- return multiply_wrapper(a + 1, b + 1)
130
-
131
- @checkpoint
132
- def test_a(a: int, b: int) -> int:
133
- return helper(a, b)
134
-
135
- @checkpoint
136
- def test_b(a: int, b: int) -> int:
137
- return test_a(a, b) + multiply_wrapper(a, b)
138
-
139
- assert set(test_a.depends) == {test_a.fn, helper, multiply_wrapper, global_multiply}
140
- assert set(test_b.depends) == {test_b.fn, test_a, multiply_wrapper, global_multiply}
141
-
142
- def test_lazy_init_1():
143
- @checkpoint
144
- def fn1(x: object) -> object:
145
- return fn2(x)
146
-
147
- @checkpoint
148
- def fn2(x: object) -> object:
149
- return fn1(x)
150
-
151
- assert set(fn1.depends) == {fn1.fn, fn2}
152
- assert set(fn2.depends) == {fn1, fn2.fn}
153
-
154
- def test_lazy_init_2():
155
- @checkpoint
156
- def fn1(x: object) -> object:
157
- return fn2(x)
158
-
159
- assert set(fn1.depends) == {fn1.fn}
160
-
161
- @checkpoint
162
- def fn2(x: object) -> object:
163
- return fn1(x)
164
-
165
- assert set(fn1.depends) == {fn1.fn}
166
- fn1.reinit()
167
- assert set(fn1.depends) == {fn1.fn, fn2}
168
- assert set(fn2.depends) == {fn1, fn2.fn}
@@ -1,16 +0,0 @@
1
- checkpointer/__init__.py,sha256=ayjFyHwvl_HRHwocY-hOJvAx0Ko5X9IMZrNT4CwfoMU,824
2
- checkpointer/checkpoint.py,sha256=jUTImaeAMde2skReH8DxmlTaUe8XzL1uKRSkbS1-N80,9523
3
- checkpointer/fn_ident.py,sha256=-5XbovQowVyYCFc7JdT9z1NoIEiL8h9fi7alF_34Ils,4470
4
- checkpointer/object_hash.py,sha256=YlyFupQrg3V2mpzTLfOqpqlZWhoSCHliScQ4cKd36T0,8133
5
- checkpointer/print_checkpoint.py,sha256=uUQ493fJCaB4nhp4Ox60govSCiBTIPbBX15zt2QiRGo,1356
6
- checkpointer/test_checkpointer.py,sha256=-EvsMMNOOiIxhTcG97LLX0jUMWp534ko7qCKDSFWiA0,3802
7
- checkpointer/types.py,sha256=n1CxJsh7c_o72pFEfSbZ8cZgMeSNfAbLWUwrca8-zNo,449
8
- checkpointer/utils.py,sha256=CXoofmu6nxY5uW7oPUqdH31qffi3jQn_sc1qXbzxCsU,2137
9
- checkpointer/storages/__init__.py,sha256=en32nTUltpCSgz8RVGS_leIHC1Y1G89IqG1ZqAb6qUo,236
10
- checkpointer/storages/memory_storage.py,sha256=aQRSOmAfS0UudubCpv8cdfu2ycM8mlsO9tFMcD2kmgo,1133
11
- checkpointer/storages/pickle_storage.py,sha256=je1LM2lTSs5yzm25Apg5tJ9jU9T6nXCgD9SlqQRIFaM,1652
12
- checkpointer/storages/storage.py,sha256=5Jel7VlmCG9gnIbZFKT1NrEiAePz8ZD8hfsD-tYEiP4,905
13
- checkpointer-2.11.2.dist-info/METADATA,sha256=5nwfJ3M_F-RjJLCr49fMaXdjbl5yk1dKZt6Fk7WCy_k,11630
14
- checkpointer-2.11.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- checkpointer-2.11.2.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
16
- checkpointer-2.11.2.dist-info/RECORD,,