checkpointer 2.11.1__py3-none-any.whl → 2.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpointer/__init__.py CHANGED
@@ -1,10 +1,10 @@
1
1
  import gc
2
2
  import tempfile
3
3
  from typing import Callable
4
- from .checkpoint import CachedFunction, Checkpointer, CheckpointError
4
+ from .checkpoint import CachedFunction, Checkpointer, CheckpointError, FunctionIdent
5
5
  from .object_hash import ObjectHash
6
6
  from .storages import MemoryStorage, PickleStorage, Storage
7
- from .types import AwaitableValue, HashBy, NoHash
7
+ from .types import AwaitableValue, Captured, CapturedOnce, CaptureMe, CaptureMeOnce, HashBy, NoHash
8
8
 
9
9
  checkpoint = Checkpointer()
10
10
  capture_checkpoint = Checkpointer(capture=True)
@@ -2,24 +2,18 @@ from __future__ import annotations
2
2
  import re
3
3
  from datetime import datetime
4
4
  from functools import cached_property, update_wrapper
5
- from inspect import Parameter, Signature, iscoroutine, signature
5
+ from inspect import Parameter, iscoroutine, signature, unwrap
6
+ from itertools import chain
6
7
  from pathlib import Path
7
8
  from typing import (
8
- Annotated, Callable, Concatenate, Coroutine, Generic,
9
- Iterable, Literal, ParamSpec, Self, Type, TypedDict,
10
- TypeVar, Unpack, cast, get_args, get_origin, overload,
9
+ Callable, Concatenate, Coroutine, Generic, Iterable,
10
+ Literal, Self, Type, TypedDict, Unpack, cast, overload,
11
11
  )
12
- from .fn_ident import RawFunctionIdent, get_fn_ident
12
+ from .fn_ident import Capturable, RawFunctionIdent, get_fn_ident
13
13
  from .object_hash import ObjectHash
14
14
  from .print_checkpoint import print_checkpoint
15
- from .storages import STORAGE_MAP, Storage
16
- from .types import AwaitableValue, HashBy
17
- from .utils import unwrap_fn
18
-
19
- Fn = TypeVar("Fn", bound=Callable)
20
- P = ParamSpec("P")
21
- R = TypeVar("R")
22
- C = TypeVar("C")
15
+ from .storages import STORAGE_MAP, Storage, StorageType
16
+ from .types import AwaitableValue, C, Coro, Fn, P, R, hash_by_from_annotation
23
17
 
24
18
  DEFAULT_DIR = Path.home() / ".cache/checkpoints"
25
19
 
@@ -27,7 +21,7 @@ class CheckpointError(Exception):
27
21
  pass
28
22
 
29
23
  class CheckpointerOpts(TypedDict, total=False):
30
- format: Type[Storage] | Literal["pickle", "memory", "bcolz"]
24
+ format: Type[Storage] | StorageType
31
25
  root_path: Path | str | None
32
26
  when: bool
33
27
  verbosity: Literal[0, 1, 2]
@@ -66,28 +60,52 @@ class FunctionIdent:
66
60
  self.__dict__.clear()
67
61
  self.cached_fn = cached_fn
68
62
 
63
+ def reset(self):
64
+ self.__init__(self.cached_fn)
65
+
66
+ def is_static(self) -> bool:
67
+ return self.cached_fn.checkpointer.fn_hash_from is not None
68
+
69
69
  @cached_property
70
70
  def raw_ident(self) -> RawFunctionIdent:
71
- return get_fn_ident(unwrap_fn(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
71
+ return get_fn_ident(unwrap(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
72
72
 
73
73
  @cached_property
74
74
  def fn_hash(self) -> str:
75
- if self.cached_fn.checkpointer.fn_hash_from is not None:
75
+ if self.is_static():
76
76
  return str(ObjectHash(self.cached_fn.checkpointer.fn_hash_from, digest_size=16))
77
- deep_hashes = [depend.ident.raw_ident.fn_hash for depend in self.cached_fn.deep_depends()]
77
+ depends = self.deep_idents(past_static=False)
78
+ deep_hashes = [d.fn_hash if d.is_static() else d.raw_ident.fn_hash for d in depends]
78
79
  return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
79
80
 
80
81
  @cached_property
81
- def captured_hash(self) -> str:
82
- deep_hashes = [depend.ident.raw_ident.captured_hash for depend in self.cached_fn.deep_depends()]
83
- return str(ObjectHash().write_text(iter=deep_hashes))
82
+ def capturables(self) -> list[Capturable]:
83
+ return sorted({
84
+ capturable.key: capturable
85
+ for depend in self.deep_idents()
86
+ for capturable in depend.raw_ident.capturables
87
+ }.values())
88
+
89
+ def deep_depends(self, past_static=True, visited: set[Callable] = set()) -> Iterable[Callable]:
90
+ if self.cached_fn not in visited:
91
+ yield self.cached_fn
92
+ visited = visited or set()
93
+ visited.add(self.cached_fn)
94
+ stop = not past_static and self.is_static()
95
+ depends = [] if stop else self.raw_ident.depends
96
+ for depend in depends:
97
+ if isinstance(depend, CachedFunction):
98
+ yield from depend.ident.deep_depends(past_static, visited)
99
+ elif depend not in visited:
100
+ yield depend
101
+ visited.add(depend)
84
102
 
85
- def clear(self):
86
- self.__init__(self.cached_fn)
103
+ def deep_idents(self, past_static=True) -> Iterable[FunctionIdent]:
104
+ return (fn.ident for fn in self.deep_depends(past_static) if isinstance(fn, CachedFunction))
87
105
 
88
106
  class CachedFunction(Generic[Fn]):
89
107
  def __init__(self, checkpointer: Checkpointer, fn: Fn):
90
- wrapped = unwrap_fn(fn)
108
+ wrapped = unwrap(fn)
91
109
  fn_file = Path(wrapped.__code__.co_filename).name
92
110
  fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
93
111
  Storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
@@ -99,12 +117,11 @@ class CachedFunction(Generic[Fn]):
99
117
  self.cleanup = self.storage.cleanup
100
118
  self.bound = ()
101
119
 
102
- sig = signature(wrapped)
103
- params = list(sig.parameters.items())
120
+ params = list(signature(wrapped).parameters.values())
104
121
  pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
105
- self.arg_names = [name for name, param in params if param.kind in pos_params]
106
- self.default_args = {name: param.default for name, param in params if param.default is not Parameter.empty}
107
- self.hash_by_map = get_hash_by_map(sig)
122
+ self.arg_names = [param.name for param in params if param.kind in pos_params]
123
+ self.default_args = {param.name: param.default for param in params if param.default is not Parameter.empty}
124
+ self.hash_by_map = get_hash_by_map(params)
108
125
  self.ident = FunctionIdent(self)
109
126
 
110
127
  @overload
@@ -124,12 +141,12 @@ class CachedFunction(Generic[Fn]):
124
141
  return self.ident.raw_ident.depends
125
142
 
126
143
  def reinit(self, recursive=False) -> CachedFunction[Fn]:
127
- depend_idents = [depend.ident for depend in self.deep_depends()] if recursive else [self.ident]
128
- for ident in depend_idents: ident.clear()
144
+ depend_idents = list(self.ident.deep_idents()) if recursive else [self.ident]
145
+ for ident in depend_idents: ident.reset()
129
146
  for ident in depend_idents: ident.fn_hash
130
147
  return self
131
148
 
132
- def get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
149
+ def _get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
133
150
  args = self.bound + args
134
151
  pos_args = args[len(self.arg_names):]
135
152
  named_pos_args = dict(zip(self.arg_names, args))
@@ -140,8 +157,17 @@ class CachedFunction(Generic[Fn]):
140
157
  if hash_by := hash_by_map.get(key, rest_hash_by):
141
158
  named_args[key] = hash_by(value)
142
159
  if pos_hash_by := hash_by_map.get(b"*"):
143
- pos_args = tuple(map(pos_hash_by, pos_args))
144
- return str(ObjectHash(named_args, pos_args, self.ident.captured_hash, digest_size=16))
160
+ pos_args = map(pos_hash_by, pos_args)
161
+ named_args_iter = chain.from_iterable(sorted(named_args.items()))
162
+ captured = chain.from_iterable(capturable.capture() for capturable in self.ident.capturables)
163
+ obj_hash = ObjectHash(digest_size=16) \
164
+ .update(iter=named_args_iter, header="NAMED") \
165
+ .update(iter=pos_args, header="POS") \
166
+ .update(iter=captured, header="CAPTURED")
167
+ return str(obj_hash)
168
+
169
+ def get_call_hash(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> str:
170
+ return self._get_call_hash(args, kw)
145
171
 
146
172
  async def _resolve_coroutine(self, call_hash: str, coroutine: Coroutine):
147
173
  return self.storage.store(call_hash, AwaitableValue(await coroutine)).value
@@ -152,15 +178,14 @@ class CachedFunction(Generic[Fn]):
152
178
  if not params.when:
153
179
  return self.fn(*full_args, **kw)
154
180
 
155
- call_hash = self.get_call_hash(args, kw)
156
- call_hash_long = f"{self.fn_dir}/{self.ident.fn_hash}/{call_hash}"
157
-
181
+ call_hash = self._get_call_hash(args, kw)
182
+ call_id = f"{self.storage.fn_id()}/{call_hash}"
158
183
  refresh = rerun \
159
184
  or not self.storage.exists(call_hash) \
160
185
  or (params.should_expire and params.should_expire(self.storage.checkpoint_date(call_hash)))
161
186
 
162
187
  if refresh:
163
- print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_hash_long, "blue")
188
+ print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id, "blue")
164
189
  data = self.fn(*full_args, **kw)
165
190
  if iscoroutine(data):
166
191
  return self._resolve_coroutine(call_hash, data)
@@ -168,11 +193,11 @@ class CachedFunction(Generic[Fn]):
168
193
 
169
194
  try:
170
195
  data = self.storage.load(call_hash)
171
- print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_hash_long, "green")
196
+ print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_id, "green")
172
197
  return data
173
198
  except (EOFError, FileNotFoundError):
174
199
  pass
175
- print_checkpoint(params.verbosity >= 1, "CORRUPTED", call_hash_long, "yellow")
200
+ print_checkpoint(params.verbosity >= 1, "CORRUPTED", call_id, "yellow")
176
201
  return self._call(args, kw, True)
177
202
 
178
203
  def __call__(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
@@ -182,17 +207,17 @@ class CachedFunction(Generic[Fn]):
182
207
  return self._call(args, kw, True)
183
208
 
184
209
  def exists(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> bool:
185
- return self.storage.exists(self.get_call_hash(args, kw))
210
+ return self.storage.exists(self._get_call_hash(args, kw))
186
211
 
187
212
  def delete(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs):
188
- self.storage.delete(self.get_call_hash(args, kw))
213
+ self.storage.delete(self._get_call_hash(args, kw))
189
214
 
190
215
  @overload
191
- def get(self: Callable[P, Coroutine[object, object, R]], *args: P.args, **kw: P.kwargs) -> R: ...
216
+ def get(self: Callable[P, Coro[R]], *args: P.args, **kw: P.kwargs) -> R: ...
192
217
  @overload
193
218
  def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
194
219
  def get(self, *args, **kw):
195
- call_hash = self.get_call_hash(args, kw)
220
+ call_hash = self._get_call_hash(args, kw)
196
221
  try:
197
222
  data = self.storage.load(call_hash)
198
223
  return data.value if isinstance(data, AwaitableValue) else data
@@ -200,34 +225,19 @@ class CachedFunction(Generic[Fn]):
200
225
  raise CheckpointError("Could not load checkpoint") from ex
201
226
 
202
227
  @overload
203
- def _set(self: Callable[P, Coroutine[object, object, R]], value: AwaitableValue[R], *args: P.args, **kw: P.kwargs): ...
228
+ def set(self: Callable[P, Coro[R]], value: AwaitableValue[R], *args: P.args, **kw: P.kwargs): ...
204
229
  @overload
205
- def _set(self: Callable[P, R], value: R, *args: P.args, **kw: P.kwargs): ...
206
- def _set(self, value, *args, **kw):
207
- self.storage.store(self.get_call_hash(args, kw), value)
230
+ def set(self: Callable[P, R], value: R, *args: P.args, **kw: P.kwargs): ...
231
+ def set(self, value, *args, **kw):
232
+ self.storage.store(self._get_call_hash(args, kw), value)
208
233
 
209
234
  def __repr__(self) -> str:
210
235
  return f"<CachedFunction {self.fn.__name__} {self.ident.fn_hash[:6]}>"
211
236
 
212
- def deep_depends(self, visited: set[CachedFunction] = set()) -> Iterable[CachedFunction]:
213
- if self not in visited:
214
- yield self
215
- visited = visited or set()
216
- visited.add(self)
217
- for depend in self.depends:
218
- if isinstance(depend, CachedFunction):
219
- yield from depend.deep_depends(visited)
220
-
221
- def hash_by_from_annotation(annotation: type) -> Callable[[object], object] | None:
222
- if get_origin(annotation) is Annotated:
223
- args = get_args(annotation)
224
- metadata = args[1] if len(args) > 1 else None
225
- if get_origin(metadata) is HashBy:
226
- return get_args(metadata)[0]
227
-
228
- def get_hash_by_map(sig: Signature) -> dict[str | bytes, Callable[[object], object]]:
237
+ def get_hash_by_map(params: list[Parameter]) -> dict[str | bytes, Callable[[object], object]]:
229
238
  hash_by_map = {}
230
- for name, param in sig.parameters.items():
239
+ for param in params:
240
+ name = param.name
231
241
  if param.kind == Parameter.VAR_POSITIONAL:
232
242
  name = b"*"
233
243
  elif param.kind == Parameter.VAR_KEYWORD:
checkpointer/fn_ident.py CHANGED
@@ -1,34 +1,60 @@
1
1
  import dis
2
- import inspect
3
- from itertools import takewhile
4
- from pathlib import Path
5
- from types import CodeType, FunctionType, MethodType
6
- from typing import Callable, Iterable, NamedTuple, Type, TypeGuard
2
+ from inspect import Parameter, getmodule, signature, unwrap
3
+ from types import CodeType, MethodType, ModuleType
4
+ from typing import Annotated, Callable, Iterable, NamedTuple, Type, get_args, get_origin
5
+ from .import_mappings import resolve_annotation
7
6
  from .object_hash import ObjectHash
8
- from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, unwrap_fn
7
+ from .types import hash_by_from_annotation, is_capture_me, is_capture_me_once, to_none
8
+ from .utils import (
9
+ AttrDict, cwd, distinct, get_cell_contents,
10
+ get_file, is_class, is_user_fn, seekable, takewhile,
11
+ )
9
12
 
10
- cwd = Path.cwd().resolve()
13
+ AttrPath = tuple[str, ...]
14
+ CapturableByFn = dict[Callable, list["Capturable"]]
11
15
 
12
16
  class RawFunctionIdent(NamedTuple):
13
17
  fn_hash: str
14
- captured_hash: str
15
18
  depends: list[Callable]
19
+ capturables: set["Capturable"]
16
20
 
17
- def is_class(obj) -> TypeGuard[Type]:
18
- # isinstance works too, but needlessly triggers _lazyinit()
19
- return issubclass(type(obj), type)
21
+ class Capturable(NamedTuple):
22
+ key: str
23
+ module: ModuleType
24
+ attr_path: AttrPath
25
+ hash_by: Callable | None
26
+ hash: str | None = None
27
+
28
+ def capture(self) -> tuple[str, object]:
29
+ if obj := self.hash:
30
+ return self.key, obj
31
+ obj = AttrDict.get_at(self.module, *self.attr_path)
32
+ obj = self.hash_by(obj) if self.hash_by else obj
33
+ return self.key, obj
34
+
35
+ @staticmethod
36
+ def new(module: ModuleType, attr_path: AttrPath, hash_by: Callable | None, capture_once: bool) -> "Capturable":
37
+ file = str(get_file(module).relative_to(cwd))
38
+ key = "-".join((file, *attr_path))
39
+ cap = Capturable(key, module, attr_path, hash_by)
40
+ if not capture_once:
41
+ return cap
42
+ obj_hash = str(ObjectHash(cap.capture()[1]))
43
+ return Capturable(key, module, attr_path, None, obj_hash)
20
44
 
21
45
  def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
22
- attr_path: tuple[str, ...] = ()
46
+ attr_path = AttrPath(())
23
47
  scope_obj = None
24
48
  classvars: dict[str, dict[str, Type]] = {}
25
- for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
49
+ instructs = seekable(dis.get_instructions(code))
50
+ for instr in instructs:
26
51
  if instr.opname in scope_vars and not attr_path:
27
- attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
28
- attr_path = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
52
+ attrs = takewhile((x.opname == "LOAD_ATTR", x.argval) for x in instructs)
53
+ attr_path = AttrPath((instr.opname, instr.argval, *attrs))
54
+ instructs.step(-1)
29
55
  elif instr.opname == "CALL":
30
- obj = scope_vars.get_at(attr_path)
31
- attr_path = ()
56
+ obj = scope_vars.get_at(*attr_path)
57
+ attr_path = AttrPath(())
32
58
  if is_class(obj):
33
59
  scope_obj = obj
34
60
  elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
@@ -37,67 +63,97 @@ def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[st
37
63
  scope_obj = None
38
64
  return classvars
39
65
 
40
- def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...], object]]:
66
+ def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[AttrPath, object]]:
41
67
  classvars = extract_classvars(code, scope_vars)
42
68
  scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
43
- for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
69
+ instructs = seekable(dis.get_instructions(code))
70
+ for instr in instructs:
44
71
  if instr.opname in scope_vars:
45
- attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
46
- attr_path: tuple[str, ...] = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
47
- val = scope_vars.get_at(attr_path)
48
- if val is not None:
49
- yield attr_path, val
72
+ attrs = takewhile((x.opname in ("LOAD_ATTR", "LOAD_METHOD"), x.argval) for x in instructs)
73
+ attr_path = AttrPath((instr.opname, instr.argval, *attrs))
74
+ parent_path = attr_path[:-1]
75
+ instructs.step(-1)
76
+ obj = scope_vars.get_at(*attr_path)
77
+ if obj is not None:
78
+ yield attr_path, obj
79
+ if callable(obj) and parent_path[1:]:
80
+ parent_obj = scope_vars.get_at(*parent_path)
81
+ yield parent_path, parent_obj
50
82
  for const in code.co_consts:
51
83
  if isinstance(const, CodeType):
52
- yield from extract_scope_values(const, scope_vars)
84
+ next_deref = scope_vars.LOAD_DEREF.set(scope_vars.LOAD_FAST)
85
+ next_scope_vars = AttrDict({**scope_vars, "LOAD_FAST": {}, "LOAD_DEREF": next_deref})
86
+ yield from extract_scope_values(const, next_scope_vars)
87
+
88
+ def resolve_class_annotations(anno: object) -> Type | None:
89
+ if anno in (None, Annotated):
90
+ return None
91
+ elif is_class(anno):
92
+ return anno
93
+ elif get_origin(anno) is Annotated:
94
+ return resolve_class_annotations(next(iter(get_args(anno)), None))
95
+ return resolve_class_annotations(get_origin(anno))
53
96
 
54
97
  def get_self_value(fn: Callable) -> type | object | None:
55
98
  if isinstance(fn, MethodType):
56
99
  return fn.__self__
57
- parts = tuple(fn.__qualname__.split(".")[:-1])
58
- cls = parts and AttrDict(fn.__globals__).get_at(parts)
100
+ parts = fn.__qualname__.split(".")[:-1]
101
+ cls = parts and AttrDict(fn.__globals__).get_at(*parts)
59
102
  if is_class(cls):
60
103
  return cls
61
104
 
62
- def get_fn_captured_vals(fn: Callable) -> list[object]:
105
+ def get_capturables(fn: Callable, capture: bool, captured_vars: dict[AttrPath, object]) -> Iterable[Capturable]:
106
+ module = getmodule(fn)
107
+ if not module or not is_user_fn(fn):
108
+ return
109
+ for (instruct, *attr_path), obj in captured_vars.items():
110
+ attr_path = AttrPath(attr_path)
111
+ if instruct == "LOAD_GLOBAL" and not callable(obj) and not isinstance(obj, ModuleType):
112
+ anno = resolve_annotation(module, ".".join(attr_path))
113
+ if capture or is_capture_me(anno) or is_capture_me_once(anno):
114
+ hash_by = hash_by_from_annotation(anno)
115
+ if hash_by is not to_none:
116
+ yield Capturable.new(module, attr_path, hash_by, is_capture_me_once(anno))
117
+
118
+ def get_fn_captures(fn: Callable, capture: bool) -> tuple[list[Callable], list[Capturable]]:
119
+ sig_scope = {
120
+ param.name: class_anno
121
+ for param in signature(fn).parameters.values()
122
+ if param.annotation is not Parameter.empty
123
+ if param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
124
+ if (class_anno := resolve_class_annotations(param.annotation))
125
+ }
63
126
  self_value = get_self_value(fn)
64
127
  scope_vars = AttrDict({
65
- "LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
128
+ "LOAD_FAST": AttrDict({**sig_scope, "self": self_value} if self_value else sig_scope),
66
129
  "LOAD_DEREF": AttrDict(get_cell_contents(fn)),
67
130
  "LOAD_GLOBAL": AttrDict(fn.__globals__),
68
131
  })
69
- vals = dict(extract_scope_values(fn.__code__, scope_vars))
70
- return list(vals.values())
71
-
72
- def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
73
- if not isinstance(candidate_fn, (FunctionType, MethodType)):
74
- return False
75
- fn_path = Path(inspect.getfile(candidate_fn)).resolve()
76
- return cwd in fn_path.parents and ".venv" not in fn_path.parts
132
+ captured_vars = dict(extract_scope_values(fn.__code__, scope_vars))
133
+ captured_callables = [obj for obj in captured_vars.values() if callable(obj)]
134
+ capturables = list(get_capturables(fn, capture, captured_vars))
135
+ return captured_callables, capturables
77
136
 
78
- def get_depend_fns(fn: Callable, captured_vals_by_fn: dict[Callable, list[object]] = {}) -> dict[Callable, list[object]]:
137
+ def get_depend_fns(fn: Callable, capture: bool, capturable_by_fn: CapturableByFn = {}) -> CapturableByFn:
79
138
  from .checkpoint import CachedFunction
80
- captured_vals = get_fn_captured_vals(fn)
81
- captured_vals_by_fn = captured_vals_by_fn or {}
82
- captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)]
83
- for val in captured_vals:
84
- if not callable(val):
85
- continue
86
- child_fn = unwrap_fn(val, cached_fn=True)
87
- if isinstance(child_fn, CachedFunction):
88
- captured_vals_by_fn[child_fn] = []
89
- elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
90
- get_depend_fns(child_fn, captured_vals_by_fn)
91
- return captured_vals_by_fn
139
+ captured_callables, capturables = get_fn_captures(fn, capture)
140
+ capturable_by_fn = capturable_by_fn or {}
141
+ capturable_by_fn[fn] = capturables
142
+ for depend_fn in captured_callables:
143
+ depend_fn = unwrap(depend_fn, stop=lambda f: isinstance(f, CachedFunction))
144
+ if isinstance(depend_fn, CachedFunction):
145
+ capturable_by_fn[depend_fn] = []
146
+ elif depend_fn not in capturable_by_fn and is_user_fn(depend_fn):
147
+ get_depend_fns(depend_fn, capture, capturable_by_fn)
148
+ return capturable_by_fn
92
149
 
93
150
  def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
94
151
  from .checkpoint import CachedFunction
95
- captured_vals_by_fn = get_depend_fns(fn)
96
- depend_captured_vals = list(captured_vals_by_fn.values()) * capture
97
- depends = captured_vals_by_fn.keys()
152
+ capturable_by_fn = get_depend_fns(fn, capture)
153
+ capturables = {capt for capts in capturable_by_fn.values() for capt in capts}
154
+ depends = capturable_by_fn.keys()
98
155
  depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
99
156
  unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
100
157
  assert fn == unwrapped_depends[0]
101
158
  fn_hash = str(ObjectHash(iter=unwrapped_depends))
102
- captured_hash = str(ObjectHash(iter=depend_captured_vals, tolerate_errors=True))
103
- return RawFunctionIdent(fn_hash, captured_hash, depends)
159
+ return RawFunctionIdent(fn_hash, depends, capturables)
@@ -0,0 +1,47 @@
1
+ import ast
2
+ import inspect
3
+ import sys
4
+ from types import ModuleType
5
+ from typing import Iterable, Type
6
+ from .utils import cwd, get_file, is_user_file
7
+
8
+ ImportTarget = tuple[str, str | None]
9
+
10
+ cache: dict[tuple[str, int], dict[str, ImportTarget]] = {}
11
+
12
+ def generate_import_mappings(module: ModuleType) -> Iterable[tuple[str, ImportTarget]]:
13
+ mod_path = get_file(module)
14
+ if not is_user_file(mod_path):
15
+ return
16
+ mod_parts = list(mod_path.with_suffix("").relative_to(cwd).parts)
17
+ source = inspect.getsource(module)
18
+ tree = ast.parse(source)
19
+ for node in ast.walk(tree):
20
+ if isinstance(node, ast.Import):
21
+ for alias in node.names:
22
+ yield (alias.asname or alias.name, (alias.name, None))
23
+ elif isinstance(node, ast.ImportFrom):
24
+ target_mod = node.module or ""
25
+ if node.level > 0:
26
+ target_mod_parts = target_mod.split(".") * bool(target_mod)
27
+ target_mod_parts = mod_parts[:-node.level] + target_mod_parts
28
+ target_mod = ".".join(target_mod_parts)
29
+ for alias in node.names:
30
+ yield (alias.asname or alias.name, (target_mod, alias.name))
31
+
32
+ def get_import_mappings(module: ModuleType) -> dict[str, ImportTarget]:
33
+ cache_key = (module.__name__, id(module))
34
+ if cached := cache.get(cache_key):
35
+ return cached
36
+ import_mappings = dict(generate_import_mappings(module))
37
+ return cache.setdefault(cache_key, import_mappings)
38
+
39
+ def resolve_annotation(module: ModuleType, attr_name: str | None) -> Type | None:
40
+ if not attr_name:
41
+ return None
42
+ if anno := module.__annotations__.get(attr_name):
43
+ return anno
44
+ if next_pair := get_import_mappings(module).get(attr_name):
45
+ next_module_name, next_attr_name = next_pair
46
+ if next_module := sys.modules.get(next_module_name):
47
+ return resolve_annotation(next_module, next_attr_name)
@@ -12,7 +12,7 @@ from io import StringIO
12
12
  from itertools import chain
13
13
  from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
14
14
  from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
15
- from typing import Callable, TypeVar
15
+ from typing import Callable, Self, TypeVar
16
16
  from .utils import ContextVar
17
17
 
18
18
  np, torch = None, None
@@ -43,14 +43,14 @@ class ObjectHashError(Exception):
43
43
  self.obj = obj
44
44
 
45
45
  class ObjectHash:
46
- def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerate_errors=False) -> None:
46
+ def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerable=False) -> None:
47
47
  self.hash = hashlib.blake2b(digest_size=digest_size)
48
48
  self.current: dict[int, int] = {}
49
- self.tolerate_errors = ContextVar(tolerate_errors)
49
+ self.tolerable = ContextVar(tolerable)
50
50
  self.update(iter=chain(objs, iter))
51
51
 
52
52
  def copy(self) -> "ObjectHash":
53
- new = ObjectHash(tolerate_errors=self.tolerate_errors.value)
53
+ new = ObjectHash(tolerable=self.tolerable.value)
54
54
  new.hash = self.hash.copy()
55
55
  return new
56
56
 
@@ -63,26 +63,29 @@ class ObjectHash:
63
63
  return isinstance(value, ObjectHash) and str(self) == str(value)
64
64
 
65
65
  def nested_hash(self, *objs: object) -> str:
66
- return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()
66
+ return ObjectHash(iter=objs, tolerable=self.tolerable.value).hexdigest()
67
67
 
68
- def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> "ObjectHash":
68
+ def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> Self:
69
69
  for d in chain(data, iter):
70
70
  self.hash.update(d)
71
71
  return self
72
72
 
73
- def write_text(self, *data: str, iter: Iterable[str] = ()) -> "ObjectHash":
73
+ def write_text(self, *data: str, iter: Iterable[str] = ()) -> Self:
74
74
  return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))
75
75
 
76
- def header(self, *args: object) -> "ObjectHash":
76
+ def header(self, *args: object) -> Self:
77
77
  return self.write_bytes(":".join(map(str, args)).encode())
78
78
 
79
- def update(self, *objs: object, iter: Iterable[object] = (), tolerate_errors: bool | None=None) -> "ObjectHash":
80
- with nullcontext() if tolerate_errors is None else self.tolerate_errors.set(tolerate_errors):
79
+ def update(self, *objs: object, iter: Iterable[object] = (), tolerable: bool | None=None, header: str | None = None) -> Self:
80
+ with nullcontext() if tolerable is None else self.tolerable.set(tolerable):
81
81
  for obj in chain(objs, iter):
82
+ if header is not None:
83
+ self.write_bytes(header.encode())
84
+ header = None
82
85
  try:
83
86
  self._update_one(obj)
84
87
  except Exception as ex:
85
- if self.tolerate_errors.value:
88
+ if self.tolerable.value:
86
89
  self.header("error").update(type(ex))
87
90
  else:
88
91
  raise ObjectHashError(obj, ex) from ex
@@ -180,10 +183,10 @@ class ObjectHash:
180
183
  finally:
181
184
  del self.current[id(obj)]
182
185
 
183
- def _update_iterator(self, obj: Iterable) -> "ObjectHash":
186
+ def _update_iterator(self, obj: Iterable) -> Self:
184
187
  return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")
185
188
 
186
- def _update_object(self, obj: object) -> "ObjectHash":
189
+ def _update_object(self, obj: object) -> Self:
187
190
  self.header("instance", encode_type_of(obj))
188
191
  get_hash = hasattr(obj, "__objecthash__") and getattr(obj, "__objecthash__")
189
192
  if callable(get_hash):
@@ -29,7 +29,7 @@ COLOR_MAP: dict[Color, int] = {
29
29
  "white": 97,
30
30
  }
31
31
 
32
- def allow_color() -> bool:
32
+ def _allow_color() -> bool:
33
33
  if "NO_COLOR" in os.environ or os.environ.get("TERM") == "dumb" or not hasattr(sys.stdout, "fileno"):
34
34
  return False
35
35
  try:
@@ -37,16 +37,17 @@ def allow_color() -> bool:
37
37
  except io.UnsupportedOperation:
38
38
  return sys.stdout.isatty()
39
39
 
40
- def colored_(text: str, color: Color | None = None, on_color: Color | None = None) -> str:
40
+ allow_color = _allow_color()
41
+
42
+ def colored(text: str, color: Color | None = None, on_color: Color | None = None) -> str:
43
+ if not allow_color:
44
+ return text
41
45
  if color:
42
46
  text = f"\033[{COLOR_MAP[color]}m{text}"
43
47
  if on_color:
44
48
  text = f"\033[{COLOR_MAP[on_color] + 10}m{text}"
45
49
  return text + "\033[0m"
46
50
 
47
- noop = lambda text, *a, **k: text
48
- colored = colored_ if allow_color() else noop
49
-
50
51
  def print_checkpoint(should_log: bool, title: str, text: str, color: Color):
51
52
  if should_log:
52
53
  print(f'{colored(f" {title} ", "grey", color)} {colored(text, color)}')
@@ -1,9 +1,11 @@
1
- from typing import Type
2
- from .storage import Storage
3
- from .pickle_storage import PickleStorage
1
+ from typing import Literal, Type
4
2
  from .memory_storage import MemoryStorage
3
+ from .pickle_storage import PickleStorage
4
+ from .storage import Storage
5
+
6
+ StorageType = Literal["pickle", "memory"]
5
7
 
6
- STORAGE_MAP: dict[str, Type[Storage]] = {
8
+ STORAGE_MAP: dict[StorageType, Type[Storage]] = {
7
9
  "pickle": PickleStorage,
8
10
  "memory": MemoryStorage,
9
11
  }
checkpointer/types.py CHANGED
@@ -1,12 +1,44 @@
1
- from typing import Annotated, Callable, Generic, TypeVar
1
+ from typing import (
2
+ Annotated, Callable, Coroutine, Generic,
3
+ ParamSpec, TypeVar, get_args, get_origin,
4
+ )
2
5
 
3
- T = TypeVar("T")
4
6
  Fn = TypeVar("Fn", bound=Callable)
7
+ P = ParamSpec("P")
8
+ R = TypeVar("R")
9
+ C = TypeVar("C")
10
+ T = TypeVar("T")
5
11
 
6
12
  class HashBy(Generic[Fn]):
7
13
  pass
8
14
 
9
- NoHash = Annotated[T, HashBy[lambda _: None]]
15
+ class Captured:
16
+ pass
17
+
18
+ class CapturedOnce:
19
+ pass
20
+
21
+ def to_none(_):
22
+ return None
23
+
24
+ def get_annotated_args(anno: object) -> tuple[object, ...]:
25
+ return get_args(anno) if get_origin(anno) is Annotated else ()
26
+
27
+ def hash_by_from_annotation(anno: object) -> Callable[[object], object] | None:
28
+ for arg in get_annotated_args(anno):
29
+ if get_origin(arg) is HashBy:
30
+ return get_args(arg)[0]
31
+
32
+ def is_capture_me(anno: object) -> bool:
33
+ return Captured in get_annotated_args(anno)
34
+
35
+ def is_capture_me_once(anno: object) -> bool:
36
+ return CapturedOnce in get_annotated_args(anno)
37
+
38
+ NoHash = Annotated[T, HashBy[to_none]]
39
+ CaptureMe = Annotated[T, Captured]
40
+ CaptureMeOnce = Annotated[T, CapturedOnce]
41
+ Coro = Coroutine[object, object, R]
10
42
 
11
43
  class AwaitableValue(Generic[T]):
12
44
  def __init__(self, value: T):
checkpointer/utils.py CHANGED
@@ -1,53 +1,106 @@
1
+ from __future__ import annotations
2
+ import inspect
1
3
  from contextlib import contextmanager
2
- from typing import Any, Callable, Generic, Iterable, TypeVar, cast
4
+ from itertools import islice
5
+ from pathlib import Path
6
+ from types import FunctionType, MethodType, ModuleType
7
+ from typing import Callable, Generic, Iterable, Self, Type, TypeGuard
8
+ from .types import T
3
9
 
4
- T = TypeVar("T")
5
- Fn = TypeVar("Fn", bound=Callable)
10
+ cwd = Path.cwd().resolve()
6
11
 
7
- def distinct(seq: Iterable[T]) -> list[T]:
8
- return list(dict.fromkeys(seq))
12
+ def is_class(obj) -> TypeGuard[Type]:
13
+ return isinstance(obj, type)
14
+
15
+ def get_file(obj: Callable | ModuleType) -> Path:
16
+ return Path(inspect.getfile(obj)).resolve()
17
+
18
+ def is_user_file(path: Path) -> bool:
19
+ return cwd in path.parents and ".venv" not in path.parts
9
20
 
10
- def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
21
+ def is_user_fn(obj) -> TypeGuard[Callable]:
22
+ return isinstance(obj, (FunctionType, MethodType)) and is_user_file(get_file(obj))
23
+
24
+ def get_cell_contents(fn: Callable) -> Iterable[tuple[str, object]]:
11
25
  for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
12
26
  try:
13
27
  yield (key, cell.cell_contents)
14
28
  except ValueError:
15
29
  pass
16
30
 
17
- def unwrap_fn(fn: Fn, cached_fn=False) -> Fn:
18
- from .checkpoint import CachedFunction
19
- while True:
20
- if (cached_fn and isinstance(fn, CachedFunction)) or not hasattr(fn, "__wrapped__"):
21
- return cast(Fn, fn)
22
- fn = getattr(fn, "__wrapped__")
31
+ def distinct(seq: Iterable[T]) -> list[T]:
32
+ return list(dict.fromkeys(seq))
33
+
34
+ def takewhile(iter: Iterable[tuple[bool, T]]) -> Iterable[T]:
35
+ for condition, value in iter:
36
+ if not condition:
37
+ return
38
+ yield value
39
+
40
+ class seekable(Generic[T]):
41
+ def __init__(self, iterable: Iterable[T]):
42
+ self.index = 0
43
+ self.source = iter(iterable)
44
+ self.sink: list[T] = []
45
+
46
+ def __iter__(self):
47
+ return self
48
+
49
+ def __next__(self) -> T:
50
+ if len(self.sink) > self.index:
51
+ item = self.sink[self.index]
52
+ else:
53
+ item = next(self.source)
54
+ self.sink.append(item)
55
+ self.index += 1
56
+ return item
57
+
58
+ def __bool__(self):
59
+ return bool(self.lookahead(1))
60
+
61
+ def seek(self, index: int) -> Self:
62
+ remainder = index - len(self.sink)
63
+ if remainder > 0:
64
+ next(islice(self, remainder, remainder), None)
65
+ self.index = max(0, min(index, len(self.sink)))
66
+ return self
67
+
68
+ def step(self, count: int) -> Self:
69
+ return self.seek(self.index + count)
70
+
71
+ @contextmanager
72
+ def freeze(self):
73
+ initial_index = self.index
74
+ try:
75
+ yield
76
+ finally:
77
+ self.seek(initial_index)
78
+
79
+ def lookahead(self, count: int) -> list[T]:
80
+ with self.freeze():
81
+ return list(islice(self, count))
23
82
 
24
83
  class AttrDict(dict):
25
84
  def __init__(self, *args, **kwargs):
26
85
  super().__init__(*args, **kwargs)
27
86
  self.__dict__ = self
28
87
 
29
- def __getattribute__(self, name: str) -> Any:
88
+ def __getattribute__(self, name: str):
30
89
  return super().__getattribute__(name)
31
90
 
32
- def __setattr__(self, name: str, value: Any) -> None:
33
- return super().__setattr__(name, value)
91
+ def __setattr__(self, name: str, value: object):
92
+ super().__setattr__(name, value)
34
93
 
35
- def set(self, d: dict) -> "AttrDict":
94
+ def set(self, d: dict) -> AttrDict:
36
95
  if not d:
37
96
  return self
38
97
  return AttrDict({**self, **d})
39
98
 
40
- def delete(self, *attrs: str) -> "AttrDict":
41
- d = AttrDict(self)
42
- for attr in attrs:
43
- del d[attr]
44
- return d
45
-
46
- def get_at(self, attrs: tuple[str, ...]) -> Any:
47
- d = self
99
+ def get_at(self: object, *attrs: str) -> object:
100
+ obj = self
48
101
  for attr in attrs:
49
- d = getattr(d, attr, None)
50
- return d
102
+ obj = getattr(obj, attr, None)
103
+ return obj
51
104
 
52
105
  class ContextVar(Generic[T]):
53
106
  def __init__(self, value: T):
@@ -60,26 +113,3 @@ class ContextVar(Generic[T]):
60
113
  yield
61
114
  finally:
62
115
  self.value = old
63
-
64
- class iterate_and_upcoming(Generic[T]):
65
- def __init__(self, it: Iterable[T]) -> None:
66
- self.it = iter(it)
67
- self.previous: tuple[()] | tuple[T] = ()
68
- self.tracked = self._tracked_iter()
69
-
70
- def __iter__(self):
71
- return self
72
-
73
- def __next__(self) -> tuple[T, Iterable[T]]:
74
- try:
75
- item = self.previous[0] if self.previous else next(self.it)
76
- self.previous = ()
77
- return item, self.tracked
78
- except StopIteration:
79
- self.tracked.close()
80
- raise
81
-
82
- def _tracked_iter(self):
83
- for x in self.it:
84
- self.previous = (x,)
85
- yield x
@@ -1,17 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpointer
3
- Version: 2.11.1
3
+ Version: 2.12.0
4
4
  Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
5
5
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
6
  Author: Hampus Hallman
7
- License: Copyright 2018-2025 Hampus Hallman
8
-
9
- Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
10
-
11
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
12
-
13
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
7
+ License-Expression: MIT
8
+ License-File: ATTRIBUTION.md
14
9
  License-File: LICENSE
10
+ Keywords: async,cache,caching,code-aware,decorator,fast,hashing,invalidation,memoization,memoize,memory,optimization,performance,workflow
15
11
  Classifier: Programming Language :: Python :: 3.11
16
12
  Classifier: Programming Language :: Python :: 3.12
17
13
  Classifier: Programming Language :: Python :: 3.13
@@ -20,9 +16,7 @@ Description-Content-Type: text/markdown
20
16
 
21
17
  # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![pypi](https://img.shields.io/pypi/pyversions/checkpointer)](https://pypi.org/project/checkpointer/)
22
18
 
23
- `checkpointer` is a Python library providing a decorator-based API for memoizing (caching) function results. It helps you skip redundant, computationally expensive operations, saving execution time and streamlining your workflows.
24
-
25
- It works with synchronous and asynchronous functions, supports multiple storage backends, and automatically invalidates caches when function code, dependencies, or captured variables change.
19
+ `checkpointer` is a Python library offering a decorator-based API for memoizing (caching) function results with code-aware cache invalidation. It works with sync and async functions, supports multiple storage backends, and refreshes caches automatically when your code or dependencies change - helping you maintain correctness, speed up execution, and smooth out your workflows by skipping redundant, costly operations.
26
20
 
27
21
  ## 📦 Installation
28
22
 
@@ -0,0 +1,17 @@
1
+ checkpointer/__init__.py,sha256=x8LbL-URPg-afn0O-HMxPsJkV9cmW-Cw5epKSCGClOM,889
2
+ checkpointer/checkpoint.py,sha256=2K2D_aYHpI1XBcNrWQxta06wRdUj81TeCXshBVxl7cA,9869
3
+ checkpointer/fn_ident.py,sha256=4qg4NIvcWRV5GduO70an_iu12dn8LLIlw3Q8F3gm0Mo,6826
4
+ checkpointer/import_mappings.py,sha256=ESqWvZTzYAmaVnJ6NulUvn3_8CInOOPmEKUXO2WD_WA,1794
5
+ checkpointer/object_hash.py,sha256=MkrwSJJYVlWOjIDpGfYpAeYFPgCJhXHYBdlKVcUy2kw,8145
6
+ checkpointer/print_checkpoint.py,sha256=uUQ493fJCaB4nhp4Ox60govSCiBTIPbBX15zt2QiRGo,1356
7
+ checkpointer/types.py,sha256=GFqbGACdDxzQX3bb2LmF9UxQVWOEisGvdtobnqCBAOA,1129
8
+ checkpointer/utils.py,sha256=rq7AjR0WJql1o8clKMdXj4p3wrPYMLyS6G11nvNNjFE,2934
9
+ checkpointer/storages/__init__.py,sha256=p-r4YrPXn505_S3qLrSXHSlsEtb13w_DFnCt9IiUomk,296
10
+ checkpointer/storages/memory_storage.py,sha256=aQRSOmAfS0UudubCpv8cdfu2ycM8mlsO9tFMcD2kmgo,1133
11
+ checkpointer/storages/pickle_storage.py,sha256=je1LM2lTSs5yzm25Apg5tJ9jU9T6nXCgD9SlqQRIFaM,1652
12
+ checkpointer/storages/storage.py,sha256=5Jel7VlmCG9gnIbZFKT1NrEiAePz8ZD8hfsD-tYEiP4,905
13
+ checkpointer-2.12.0.dist-info/METADATA,sha256=B-R8Aq2caTM7fclANY7MJVZJQaU0nB1Mdf7wMr6HSTI,10705
14
+ checkpointer-2.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
+ checkpointer-2.12.0.dist-info/licenses/ATTRIBUTION.md,sha256=WF6L7-sD4s9t9ytVJOhjhpDoZ6TrWpqE3_bMdDIeJxI,1078
16
+ checkpointer-2.12.0.dist-info/licenses/LICENSE,sha256=drXs6vIb7uW49r70UuMz2A1VtOCl626kiTbcmrar1Xo,1072
17
+ checkpointer-2.12.0.dist-info/RECORD,,
@@ -0,0 +1,33 @@
1
+ # Attribution and License Notices
2
+
3
+ This project includes code copied or adapted from third-party open-source projects. The following acknowledges the original sources and complies with their licensing requirements.
4
+
5
+ ---
6
+
7
+ ## Third-Party Code
8
+
9
+ ### more-itertools
10
+ - **Source:** https://github.com/more-itertools/more-itertools
11
+ - **Author:** Erik Rose
12
+ - **Copyright:** (c) 2012 Erik Rose
13
+ - **License:** MIT (https://github.com/more-itertools/more-itertools/blob/master/LICENSE)
14
+
15
+ ### colored
16
+ - **Source:** https://gitlab.com/dslackw/colored
17
+ - **Author:** Dimitris Zlatanidis
18
+ - **Copyright:** (c) 2014-2025 Dimitris Zlatanidis
19
+ - **License:** MIT (https://gitlab.com/dslackw/colored/-/blob/master/LICENSE.txt)
20
+
21
+ ---
22
+
23
+ ## License
24
+
25
+ This project is licensed under the MIT License. See the `LICENSE` file for details.
26
+
27
+ ---
28
+
29
+ ## Notes
30
+
31
+ - Third-party code is included under their original MIT licenses.
32
+ - This file documents those license notices, fulfilling attribution obligations.
33
+ - Source files with copied code may omit individual license headers in favor of this centralized attribution.
@@ -1,3 +1,5 @@
1
+ MIT License
2
+
1
3
  Copyright 2018-2025 Hampus Hallman
2
4
 
3
5
  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
@@ -1,168 +0,0 @@
1
- import asyncio
2
- import pytest
3
- from checkpointer import CheckpointError, checkpoint
4
- from .utils import AttrDict
5
-
6
- def global_multiply(a: int, b: int) -> int:
7
- return a * b
8
-
9
- @pytest.fixture(autouse=True)
10
- def run_before_and_after_tests(tmpdir):
11
- global checkpoint
12
- checkpoint = checkpoint(root_path=tmpdir)
13
- yield
14
-
15
- def test_basic_caching():
16
- @checkpoint
17
- def square(x: int) -> int:
18
- return x ** 2
19
-
20
- result1 = square(4)
21
- result2 = square(4)
22
-
23
- assert result1 == result2 == 16
24
-
25
- def test_cache_invalidation():
26
- @checkpoint
27
- def multiply(a: int, b: int):
28
- return a * b
29
-
30
- @checkpoint
31
- def helper(x: int):
32
- return multiply(x + 1, 2)
33
-
34
- @checkpoint
35
- def compute(a: int, b: int):
36
- return helper(a) + helper(b)
37
-
38
- result1 = compute(3, 4)
39
- assert result1 == 18
40
-
41
- def test_layered_caching():
42
- dev_checkpoint = checkpoint(when=True)
43
-
44
- @checkpoint(format="memory")
45
- @dev_checkpoint
46
- def expensive_function(x: int):
47
- return x ** 2
48
-
49
- assert expensive_function(4) == 16
50
- assert expensive_function(4) == 16
51
-
52
- def test_recursive_caching1():
53
- @checkpoint
54
- def fib(n: int) -> int:
55
- return fib(n - 1) + fib(n - 2) if n > 1 else n
56
-
57
- assert fib(10) == 55
58
- assert fib.get(10) == 55
59
- assert fib.get(5) == 5
60
-
61
- def test_recursive_caching2():
62
- @checkpoint
63
- def fib(n: int) -> int:
64
- return fib.fn(n - 1) + fib.fn(n - 2) if n > 1 else n
65
-
66
- assert fib(10) == 55
67
- assert fib.get(10) == 55
68
- with pytest.raises(CheckpointError):
69
- fib.get(5)
70
-
71
- @pytest.mark.asyncio
72
- async def test_async_caching():
73
- @checkpoint(format="memory")
74
- async def async_square(x: int) -> int:
75
- await asyncio.sleep(0.1)
76
- return x ** 2
77
-
78
- result1 = await async_square(3)
79
- result2 = await async_square(3)
80
- result3 = async_square.get(3)
81
-
82
- assert result1 == result2 == result3 == 9
83
-
84
- def test_force_recalculation():
85
- @checkpoint
86
- def square(x: int) -> int:
87
- return x ** 2
88
-
89
- assert square(5) == 25
90
- square.rerun(5)
91
- assert square.get(5) == 25
92
-
93
- def test_multi_layer_decorator():
94
- @checkpoint(format="memory")
95
- @checkpoint(format="pickle")
96
- def add(a: int, b: int) -> int:
97
- return a + b
98
-
99
- assert add(2, 3) == 5
100
- assert add.get(2, 3) == 5
101
-
102
- def test_capture():
103
- item_dict = AttrDict({"a": 1, "b": 1})
104
-
105
- @checkpoint(capture=True)
106
- def test_whole():
107
- return item_dict
108
-
109
- @checkpoint(capture=True)
110
- def test_a():
111
- return item_dict.a + 1
112
-
113
- init_hash_a = test_a.ident.captured_hash
114
- init_hash_whole = test_whole.ident.captured_hash
115
- item_dict.b += 1
116
- test_whole.reinit()
117
- test_a.reinit()
118
- assert test_whole.ident.captured_hash != init_hash_whole
119
- assert test_a.ident.captured_hash == init_hash_a
120
- item_dict.a += 1
121
- test_a.reinit()
122
- assert test_a.ident.captured_hash != init_hash_a
123
-
124
- def test_depends():
125
- def multiply_wrapper(a: int, b: int) -> int:
126
- return global_multiply(a, b)
127
-
128
- def helper(a: int, b: int) -> int:
129
- return multiply_wrapper(a + 1, b + 1)
130
-
131
- @checkpoint
132
- def test_a(a: int, b: int) -> int:
133
- return helper(a, b)
134
-
135
- @checkpoint
136
- def test_b(a: int, b: int) -> int:
137
- return test_a(a, b) + multiply_wrapper(a, b)
138
-
139
- assert set(test_a.depends) == {test_a.fn, helper, multiply_wrapper, global_multiply}
140
- assert set(test_b.depends) == {test_b.fn, test_a, multiply_wrapper, global_multiply}
141
-
142
- def test_lazy_init_1():
143
- @checkpoint
144
- def fn1(x: object) -> object:
145
- return fn2(x)
146
-
147
- @checkpoint
148
- def fn2(x: object) -> object:
149
- return fn1(x)
150
-
151
- assert set(fn1.depends) == {fn1.fn, fn2}
152
- assert set(fn2.depends) == {fn1, fn2.fn}
153
-
154
- def test_lazy_init_2():
155
- @checkpoint
156
- def fn1(x: object) -> object:
157
- return fn2(x)
158
-
159
- assert set(fn1.depends) == {fn1.fn}
160
-
161
- @checkpoint
162
- def fn2(x: object) -> object:
163
- return fn1(x)
164
-
165
- assert set(fn1.depends) == {fn1.fn}
166
- fn1.reinit()
167
- assert set(fn1.depends) == {fn1.fn, fn2}
168
- assert set(fn2.depends) == {fn1, fn2.fn}
@@ -1,16 +0,0 @@
1
- checkpointer/__init__.py,sha256=ayjFyHwvl_HRHwocY-hOJvAx0Ko5X9IMZrNT4CwfoMU,824
2
- checkpointer/checkpoint.py,sha256=fSIgZATbTMBwZWLeWH3wIzHoPRQ5LySeb9Ygi30rCg8,9415
3
- checkpointer/fn_ident.py,sha256=-5XbovQowVyYCFc7JdT9z1NoIEiL8h9fi7alF_34Ils,4470
4
- checkpointer/object_hash.py,sha256=YlyFupQrg3V2mpzTLfOqpqlZWhoSCHliScQ4cKd36T0,8133
5
- checkpointer/print_checkpoint.py,sha256=aJCeWMRJiIR3KpyPk_UOKTaD906kArGrmLGQ3LqcVgo,1369
6
- checkpointer/test_checkpointer.py,sha256=-EvsMMNOOiIxhTcG97LLX0jUMWp534ko7qCKDSFWiA0,3802
7
- checkpointer/types.py,sha256=_dxYqzqzV8GB_g-MQlN_Voie32syKy8u7RHbc0i4upY,338
8
- checkpointer/utils.py,sha256=0cGVSlTnABgs3jI1uHoTfz353kkGa-qtTfe7jG4NCr0,2192
9
- checkpointer/storages/__init__.py,sha256=en32nTUltpCSgz8RVGS_leIHC1Y1G89IqG1ZqAb6qUo,236
10
- checkpointer/storages/memory_storage.py,sha256=aQRSOmAfS0UudubCpv8cdfu2ycM8mlsO9tFMcD2kmgo,1133
11
- checkpointer/storages/pickle_storage.py,sha256=je1LM2lTSs5yzm25Apg5tJ9jU9T6nXCgD9SlqQRIFaM,1652
12
- checkpointer/storages/storage.py,sha256=5Jel7VlmCG9gnIbZFKT1NrEiAePz8ZD8hfsD-tYEiP4,905
13
- checkpointer-2.11.1.dist-info/METADATA,sha256=vpXrjRj8-yid19Lh_tU45cNdVkYalFj_RP3pRCm3S8A,11633
14
- checkpointer-2.11.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- checkpointer-2.11.1.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
16
- checkpointer-2.11.1.dist-info/RECORD,,