checkpointer 2.12.0__py3-none-any.whl → 2.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpointer/__init__.py CHANGED
@@ -1,5 +1,6 @@
  import gc
  import tempfile
+ from datetime import timedelta
  from typing import Callable
  from .checkpoint import CachedFunction, Checkpointer, CheckpointError, FunctionIdent
  from .object_hash import ObjectHash
@@ -8,8 +9,8 @@ from .types import AwaitableValue, Captured, CapturedOnce, CaptureMe, CaptureMeO

  checkpoint = Checkpointer()
  capture_checkpoint = Checkpointer(capture=True)
- memory_checkpoint = Checkpointer(format="memory", verbosity=0)
- tmp_checkpoint = Checkpointer(root_path=f"{tempfile.gettempdir()}/checkpoints")
+ memory_checkpoint = Checkpointer(storage="memory", verbosity=0)
+ tmp_checkpoint = Checkpointer(directory=f"{tempfile.gettempdir()}/checkpoints")
  static_checkpoint = Checkpointer(fn_hash_from=())

  def cleanup_all(invalidated=True, expired=True):
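Besides the default `checkpoint` instance, `__init__.py` exposes several preconfigured `Checkpointer` objects. A minimal usage sketch (the decorated function names are illustrative, not from the package):

```python
from checkpointer import checkpoint, memory_checkpoint, tmp_checkpoint

@checkpoint            # default: pickle storage under ~/.cache/checkpoints
def slow_square(x: int) -> int:
    return x ** 2

@memory_checkpoint     # in-memory storage, verbosity=0 (see definition above)
def scratch(x: int) -> int:
    return x + 1

@tmp_checkpoint        # pickle storage under the system temp directory
def staged(x: int) -> int:
    return x - 1
```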
checkpointer/checkpoint.py CHANGED
@@ -1,19 +1,19 @@
  from __future__ import annotations
  import re
- from datetime import datetime
+ from datetime import datetime, timedelta
  from functools import cached_property, update_wrapper
  from inspect import Parameter, iscoroutine, signature, unwrap
- from itertools import chain
  from pathlib import Path
  from typing import (
  Callable, Concatenate, Coroutine, Generic, Iterable,
- Literal, Self, Type, TypedDict, Unpack, cast, overload,
+ Literal, Self, Type, TypedDict, Unpack, overload,
  )
  from .fn_ident import Capturable, RawFunctionIdent, get_fn_ident
  from .object_hash import ObjectHash
  from .print_checkpoint import print_checkpoint
  from .storages import STORAGE_MAP, Storage, StorageType
  from .types import AwaitableValue, C, Coro, Fn, P, R, hash_by_from_annotation
+ from .utils import flatten

  DEFAULT_DIR = Path.home() / ".cache/checkpoints"

@@ -21,21 +21,21 @@ class CheckpointError(Exception):
  pass

  class CheckpointerOpts(TypedDict, total=False):
- format: Type[Storage] | StorageType
- root_path: Path | str | None
+ storage: Type[Storage] | StorageType
+ directory: Path | str | None
  when: bool
  verbosity: Literal[0, 1, 2]
- should_expire: Callable[[datetime], bool] | None
+ expiry: Callable[[datetime], bool] | timedelta | None
  capture: bool
  fn_hash_from: object

  class Checkpointer:
  def __init__(self, **opts: Unpack[CheckpointerOpts]):
- self.format = opts.get("format", "pickle")
- self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
+ self.storage = opts.get("storage", "pickle")
+ self.directory = Path(opts.get("directory", DEFAULT_DIR) or ".")
  self.when = opts.get("when", True)
  self.verbosity = opts.get("verbosity", 1)
- self.should_expire = opts.get("should_expire")
+ self.expiry = opts.get("expiry")
  self.capture = opts.get("capture", False)
  self.fn_hash_from = opts.get("fn_hash_from")
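For callers upgrading from 2.12.0, the renames above map the old constructor options onto the new ones (`format` → `storage`, `root_path` → `directory`, `should_expire` → `expiry`, which now also accepts a `timedelta`). A sketch of the 2.14.0 spelling; the paths and values are illustrative:

```python
from datetime import timedelta
from checkpointer import Checkpointer

# 2.12.0: Checkpointer(format="pickle", root_path="/tmp/ckpt", should_expire=...)
ckpt = Checkpointer(
    storage="pickle",          # was `format`
    directory="/tmp/ckpt",     # was `root_path`
    expiry=timedelta(days=7),  # was `should_expire`; a timedelta or Callable[[datetime], bool]
    verbosity=1,
)

@ckpt
def fetch_report(day: str) -> dict:
    ...
```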

@@ -56,24 +56,46 @@ class FunctionIdent:
  Separated from CachedFunction to prevent hash desynchronization
  among bound instances when `.reinit()` is called.
  """
- def __init__(self, cached_fn: CachedFunction):
- self.__dict__.clear()
+ __slots__ = (
+ "checkpointer", "cached_fn", "fn", "fn_dir", "pos_names",
+ "arg_names", "default_args", "hash_by_map", "__dict__",
+ )
+
+ def __init__(self, cached_fn: CachedFunction, checkpointer: Checkpointer, fn: Callable):
+ wrapped = unwrap(fn)
+ fn_file = Path(wrapped.__code__.co_filename).name
+ fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
+ params = list(signature(wrapped).parameters.values())
+ pos_param_types = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
+ named_param_types = (Parameter.KEYWORD_ONLY,) + pos_param_types
+ name_by_kind = {Parameter.VAR_POSITIONAL: b"*", Parameter.VAR_KEYWORD: b"**"}
+ self.checkpointer = checkpointer
  self.cached_fn = cached_fn
+ self.fn = fn
+ self.fn_dir = f"{fn_file}/{fn_name}"
+ self.pos_names = [param.name for param in params if param.kind in pos_param_types]
+ self.arg_names = {param.name for param in params if param.kind in named_param_types}
+ self.default_args = {param.name: param.default for param in params if param.default is not Parameter.empty}
+ self.hash_by_map = {
+ name_by_kind.get(param.kind, param.name): hash_by
+ for param in params
+ if (hash_by := hash_by_from_annotation(param.annotation))
+ }

  def reset(self):
- self.__init__(self.cached_fn)
+ self.__dict__.clear()

  def is_static(self) -> bool:
- return self.cached_fn.checkpointer.fn_hash_from is not None
+ return self.checkpointer.fn_hash_from is not None

  @cached_property
  def raw_ident(self) -> RawFunctionIdent:
- return get_fn_ident(unwrap(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
+ return get_fn_ident(unwrap(self.fn), self.checkpointer.capture)

  @cached_property
  def fn_hash(self) -> str:
  if self.is_static():
- return str(ObjectHash(self.cached_fn.checkpointer.fn_hash_from, digest_size=16))
+ return str(ObjectHash(self.checkpointer.fn_hash_from, digest_size=16))
  depends = self.deep_idents(past_static=False)
  deep_hashes = [d.fn_hash if d.is_static() else d.raw_ident.fn_hash for d in depends]
  return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
@@ -105,29 +127,21 @@ class FunctionIdent:

  class CachedFunction(Generic[Fn]):
  def __init__(self, checkpointer: Checkpointer, fn: Fn):
- wrapped = unwrap(fn)
- fn_file = Path(wrapped.__code__.co_filename).name
- fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
- Storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
- update_wrapper(cast(Callable, self), wrapped)
- self.checkpointer = checkpointer
- self.fn = fn
- self.fn_dir = f"{fn_file}/{fn_name}"
+ store_format = checkpointer.storage
+ Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
+ update_wrapper(self, unwrap(fn)) # type: ignore
+ self.ident = FunctionIdent(self, checkpointer, fn)
  self.storage = Storage(self)
- self.cleanup = self.storage.cleanup
  self.bound = ()

- params = list(signature(wrapped).parameters.values())
- pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
- self.arg_names = [param.name for param in params if param.kind in pos_params]
- self.default_args = {param.name: param.default for param in params if param.default is not Parameter.empty}
- self.hash_by_map = get_hash_by_map(params)
- self.ident = FunctionIdent(self)
-
  @overload
  def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
  @overload
- def __get__(self: CachedFunction[Callable[Concatenate[C, P], R]], instance: C, owner: Type[C]) -> CachedFunction[Callable[P, R]]: ...
+ def __get__(
+ self: CachedFunction[Callable[Concatenate[C, P], R]],
+ instance: C,
+ owner: Type[C],
+ ) -> CachedFunction[Callable[P, R]]: ...
  def __get__(self, instance, owner):
  if instance is None:
  return self
@@ -137,62 +151,65 @@ class CachedFunction(Generic[Fn]):
  return bound_fn

  @property
- def depends(self) -> list[Callable]:
- return self.ident.raw_ident.depends
+ def fn(self) -> Fn:
+ return self.ident.fn # type: ignore

- def reinit(self, recursive=False) -> CachedFunction[Fn]:
+ @property
+ def cleanup(self):
+ return self.storage.cleanup
+
+ def reinit(self, recursive=True) -> CachedFunction[Fn]:
  depend_idents = list(self.ident.deep_idents()) if recursive else [self.ident]
  for ident in depend_idents: ident.reset()
  for ident in depend_idents: ident.fn_hash
  return self

  def _get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
+ ident = self.ident
  args = self.bound + args
- pos_args = args[len(self.arg_names):]
- named_pos_args = dict(zip(self.arg_names, args))
- named_args = {**self.default_args, **named_pos_args, **kw}
- if hash_by_map := self.hash_by_map:
- rest_hash_by = hash_by_map.get(b"**")
- for key, value in named_args.items():
- if hash_by := hash_by_map.get(key, rest_hash_by):
- named_args[key] = hash_by(value)
- if pos_hash_by := hash_by_map.get(b"*"):
- pos_args = map(pos_hash_by, pos_args)
- named_args_iter = chain.from_iterable(sorted(named_args.items()))
- captured = chain.from_iterable(capturable.capture() for capturable in self.ident.capturables)
- obj_hash = ObjectHash(digest_size=16) \
- .update(iter=named_args_iter, header="NAMED") \
- .update(iter=pos_args, header="POS") \
- .update(iter=captured, header="CAPTURED")
- return str(obj_hash)
+ pos_args = args[len(ident.pos_names):]
+ named_pos_args = dict(zip(ident.pos_names, args))
+ named_args = {**ident.default_args, **named_pos_args, **kw}
+ for key, hash_by in ident.hash_by_map.items():
+ if isinstance(key, str):
+ named_args[key] = hash_by(named_args[key])
+ elif key == b"*":
+ pos_args = map(hash_by, pos_args)
+ elif key == b"**":
+ for key in kw.keys() - ident.arg_names:
+ named_args[key] = hash_by(named_args[key])
+ call_hash = ObjectHash(digest_size=16) \
+ .update(header="NAMED", iter=flatten(sorted(named_args.items()))) \
+ .update(header="POS", iter=pos_args) \
+ .update(header="CAPTURED", iter=flatten(c.capture() for c in ident.capturables))
+ return str(call_hash)

  def get_call_hash(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> str:
  return self._get_call_hash(args, kw)

- async def _resolve_coroutine(self, call_hash: str, coroutine: Coroutine):
+ async def _store_coroutine(self, call_hash: str, coroutine: Coroutine):
  return self.storage.store(call_hash, AwaitableValue(await coroutine)).value

  def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
  full_args = self.bound + args
- params = self.checkpointer
+ params = self.ident.checkpointer
+ storage = self.storage
  if not params.when:
  return self.fn(*full_args, **kw)

  call_hash = self._get_call_hash(args, kw)
- call_id = f"{self.storage.fn_id()}/{call_hash}"
- refresh = rerun \
- or not self.storage.exists(call_hash) \
- or (params.should_expire and params.should_expire(self.storage.checkpoint_date(call_hash)))
+ call_id = f"{storage.fn_id()}/{call_hash}"
+ refresh = rerun or not storage.exists(call_hash) or storage.expired(call_hash)

  if refresh:
  print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id, "blue")
  data = self.fn(*full_args, **kw)
  if iscoroutine(data):
- return self._resolve_coroutine(call_hash, data)
- return self.storage.store(call_hash, data)
+ return self._store_coroutine(call_hash, data)
+ return storage.store(call_hash, data)

  try:
- data = self.storage.load(call_hash)
+ data = storage.load(call_hash)
  print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_id, "green")
  return data
  except (EOFError, FileNotFoundError):
@@ -232,15 +249,6 @@ class CachedFunction(Generic[Fn]):
  self.storage.store(self._get_call_hash(args, kw), value)

  def __repr__(self) -> str:
- return f"<CachedFunction {self.fn.__name__} {self.ident.fn_hash[:6]}>"
-
- def get_hash_by_map(params: list[Parameter]) -> dict[str | bytes, Callable[[object], object]]:
- hash_by_map = {}
- for param in params:
- name = param.name
- if param.kind == Parameter.VAR_POSITIONAL:
- name = b"*"
- elif param.kind == Parameter.VAR_KEYWORD:
- name = b"**"
- hash_by_map[name] = hash_by_from_annotation(param.annotation)
- return hash_by_map if any(hash_by_map.values()) else {}
+ initialized = "fn_hash" in self.ident.__dict__
+ fn_hash = self.ident.fn_hash[:6] if initialized else "- uninitialized"
+ return f"<CachedFunction {self.fn.__name__} {fn_hash}>"
checkpointer/fn_ident.py CHANGED
@@ -2,11 +2,12 @@ import dis
  from inspect import Parameter, getmodule, signature, unwrap
  from types import CodeType, MethodType, ModuleType
  from typing import Annotated, Callable, Iterable, NamedTuple, Type, get_args, get_origin
+ from .fn_string import get_fn_aststr
  from .import_mappings import resolve_annotation
  from .object_hash import ObjectHash
  from .types import hash_by_from_annotation, is_capture_me, is_capture_me_once, to_none
  from .utils import (
- AttrDict, cwd, distinct, get_cell_contents,
+ cwd, distinct, get_at, get_cell_contents,
  get_file, is_class, is_user_fn, seekable, takewhile,
  )

@@ -28,77 +29,77 @@ class Capturable(NamedTuple):
  def capture(self) -> tuple[str, object]:
  if obj := self.hash:
  return self.key, obj
- obj = AttrDict.get_at(self.module, *self.attr_path)
+ obj = get_at(self.module, *self.attr_path)
  obj = self.hash_by(obj) if self.hash_by else obj
  return self.key, obj

  @staticmethod
  def new(module: ModuleType, attr_path: AttrPath, hash_by: Callable | None, capture_once: bool) -> "Capturable":
  file = str(get_file(module).relative_to(cwd))
- key = "-".join((file, *attr_path))
+ key = file + "/" + ".".join(attr_path)
  cap = Capturable(key, module, attr_path, hash_by)
  if not capture_once:
  return cap
  obj_hash = str(ObjectHash(cap.capture()[1]))
  return Capturable(key, module, attr_path, None, obj_hash)

- def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
+ def extract_classvars(code: CodeType, scope_vars: dict) -> dict[str, dict[str, Type]]:
  attr_path = AttrPath(())
  scope_obj = None
  classvars: dict[str, dict[str, Type]] = {}
  instructs = seekable(dis.get_instructions(code))
- for instr in instructs:
- if instr.opname in scope_vars and not attr_path:
+ for instruct in instructs:
+ if instruct.opname in scope_vars and not attr_path:
  attrs = takewhile((x.opname == "LOAD_ATTR", x.argval) for x in instructs)
- attr_path = AttrPath((instr.opname, instr.argval, *attrs))
+ attr_path = AttrPath((instruct.opname, instruct.argval, *attrs))
  instructs.step(-1)
- elif instr.opname == "CALL":
- obj = scope_vars.get_at(*attr_path)
+ elif instruct.opname == "CALL":
+ obj = get_at(scope_vars, *attr_path)
  attr_path = AttrPath(())
  if is_class(obj):
  scope_obj = obj
- elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
- load_key = instr.opname.replace("STORE", "LOAD")
- classvars.setdefault(load_key, {})[instr.argval] = scope_obj
+ elif instruct.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
+ load_key = instruct.opname.replace("STORE", "LOAD")
+ classvars.setdefault(load_key, {})[instruct.argval] = scope_obj
  scope_obj = None
  return classvars

- def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[AttrPath, object]]:
+ def extract_scope_values(code: CodeType, scope_vars: dict) -> Iterable[tuple[AttrPath, object]]:
  classvars = extract_classvars(code, scope_vars)
- scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
+ scope_vars = {**scope_vars, **{k: {**scope_vars[k], **v} for k, v in classvars.items()}}
  instructs = seekable(dis.get_instructions(code))
- for instr in instructs:
- if instr.opname in scope_vars:
+ for instruct in instructs:
+ if instruct.opname in scope_vars:
  attrs = takewhile((x.opname in ("LOAD_ATTR", "LOAD_METHOD"), x.argval) for x in instructs)
- attr_path = AttrPath((instr.opname, instr.argval, *attrs))
+ attr_path = AttrPath((instruct.opname, instruct.argval, *attrs))
  parent_path = attr_path[:-1]
  instructs.step(-1)
- obj = scope_vars.get_at(*attr_path)
+ obj = get_at(scope_vars, *attr_path)
  if obj is not None:
  yield attr_path, obj
  if callable(obj) and parent_path[1:]:
- parent_obj = scope_vars.get_at(*parent_path)
+ parent_obj = get_at(scope_vars, *parent_path)
  yield parent_path, parent_obj
  for const in code.co_consts:
  if isinstance(const, CodeType):
- next_deref = scope_vars.LOAD_DEREF.set(scope_vars.LOAD_FAST)
- next_scope_vars = AttrDict({**scope_vars, "LOAD_FAST": {}, "LOAD_DEREF": next_deref})
+ next_deref = {**scope_vars["LOAD_DEREF"], **scope_vars["LOAD_FAST"]}
+ next_scope_vars = {**scope_vars, "LOAD_FAST": {}, "LOAD_DEREF": next_deref}
  yield from extract_scope_values(const, next_scope_vars)

- def resolve_class_annotations(anno: object) -> Type | None:
+ def class_from_annotation(anno: object) -> Type | None:
  if anno in (None, Annotated):
  return None
- elif is_class(anno):
+ if is_class(anno):
  return anno
- elif get_origin(anno) is Annotated:
- return resolve_class_annotations(next(iter(get_args(anno)), None))
- return resolve_class_annotations(get_origin(anno))
+ if get_origin(anno) is Annotated:
+ return class_from_annotation(next(iter(get_args(anno)), None))
+ return class_from_annotation(get_origin(anno))

- def get_self_value(fn: Callable) -> type | object | None:
+ def get_self_value(fn: Callable) -> Type | object | None:
  if isinstance(fn, MethodType):
  return fn.__self__
  parts = fn.__qualname__.split(".")[:-1]
- cls = parts and AttrDict(fn.__globals__).get_at(*parts)
+ cls = parts and get_at(fn.__globals__, *parts)
  if is_class(cls):
  return cls

@@ -106,9 +107,9 @@ def get_capturables(fn: Callable, capture: bool, captured_vars: dict[AttrPath, o
  module = getmodule(fn)
  if not module or not is_user_fn(fn):
  return
- for (instruct, *attr_path), obj in captured_vars.items():
+ for (instruct_type, *attr_path), obj in captured_vars.items():
  attr_path = AttrPath(attr_path)
- if instruct == "LOAD_GLOBAL" and not callable(obj) and not isinstance(obj, ModuleType):
+ if instruct_type == "LOAD_GLOBAL" and not callable(obj) and not isinstance(obj, ModuleType):
  anno = resolve_annotation(module, ".".join(attr_path))
  if capture or is_capture_me(anno) or is_capture_me_once(anno):
  hash_by = hash_by_from_annotation(anno)
@@ -116,19 +117,20 @@ def get_capturables(fn: Callable, capture: bool, captured_vars: dict[AttrPath, o
  yield Capturable.new(module, attr_path, hash_by, is_capture_me_once(anno))

  def get_fn_captures(fn: Callable, capture: bool) -> tuple[list[Callable], list[Capturable]]:
- sig_scope = {
+ scope_vars_signature: dict[str, Type | object] = {
  param.name: class_anno
  for param in signature(fn).parameters.values()
  if param.annotation is not Parameter.empty
  if param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
- if (class_anno := resolve_class_annotations(param.annotation))
+ if (class_anno := class_from_annotation(param.annotation))
+ }
+ if self_obj := get_self_value(fn):
+ scope_vars_signature["self"] = self_obj
+ scope_vars = {
+ "LOAD_FAST": scope_vars_signature,
+ "LOAD_DEREF": dict(get_cell_contents(fn)),
+ "LOAD_GLOBAL": fn.__globals__,
  }
- self_value = get_self_value(fn)
- scope_vars = AttrDict({
- "LOAD_FAST": AttrDict({**sig_scope, "self": self_value} if self_value else sig_scope),
- "LOAD_DEREF": AttrDict(get_cell_contents(fn)),
- "LOAD_GLOBAL": AttrDict(fn.__globals__),
- })
  captured_vars = dict(extract_scope_values(fn.__code__, scope_vars))
  captured_callables = [obj for obj in captured_vars.values() if callable(obj)]
  capturables = list(get_capturables(fn, capture, captured_vars))
@@ -142,7 +144,7 @@ def get_depend_fns(fn: Callable, capture: bool, capturable_by_fn: CapturableByFn
  for depend_fn in captured_callables:
  depend_fn = unwrap(depend_fn, stop=lambda f: isinstance(f, CachedFunction))
  if isinstance(depend_fn, CachedFunction):
- capturable_by_fn[depend_fn] = []
+ capturable_by_fn[depend_fn.ident.cached_fn] = []
  elif depend_fn not in capturable_by_fn and is_user_fn(depend_fn):
  get_depend_fns(depend_fn, capture, capturable_by_fn)
  return capturable_by_fn
@@ -150,10 +152,10 @@ def get_depend_fns(fn: Callable, capture: bool, capturable_by_fn: CapturableByFn
  def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
  from .checkpoint import CachedFunction
  capturable_by_fn = get_depend_fns(fn, capture)
- capturables = {capt for capts in capturable_by_fn.values() for capt in capts}
  depends = capturable_by_fn.keys()
  depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
- unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
- assert fn == unwrapped_depends[0]
- fn_hash = str(ObjectHash(iter=unwrapped_depends))
+ depend_callables = [fn for fn in depends if not isinstance(fn, CachedFunction)]
+ fn_hash = str(ObjectHash(iter=map(get_fn_aststr, depend_callables)))
+ capturables = {capt for capts in capturable_by_fn.values() for capt in capts}
+ assert fn == depend_callables[0]
  return RawFunctionIdent(fn_hash, depends, capturables)
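`get_depend_fns` is what makes edits to plain helper functions propagate: every user-defined callable reached from a decorated function lands in `depends`, so its source contributes to the function identity hash. In short, the same behaviour the 2.12.0 README illustrates further down:

```python
from checkpointer import checkpoint

def multiply(a: int, b: int) -> int:
    return a * b

@checkpoint
def helper(x: int) -> int:
    return multiply(x + 1, 2)       # `multiply` is part of helper's identity hash

@checkpoint
def compute(a: int, b: int) -> int:
    return helper(a) + helper(b)    # depends on `helper`, and transitively on `multiply`

# Editing `multiply` changes the identity hash of both `helper` and `compute`,
# so both caches are recomputed on the next call.
```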
checkpointer/fn_string.py ADDED
@@ -0,0 +1,75 @@
+ import ast
+ import sys
+ from inspect import getsource
+ from textwrap import dedent
+ from typing import Callable
+ from .utils import drop_none, get_at
+
+ def get_decorator_path(node: ast.AST) -> tuple[str, ...]:
+ if isinstance(node, ast.Call):
+ return get_decorator_path(node.func)
+ elif isinstance(node, ast.Attribute):
+ return get_decorator_path(node.value) + (node.attr,)
+ elif isinstance(node, ast.Name):
+ return (node.id,)
+ else:
+ return ()
+
+ def is_lone_expression(node: ast.AST) -> bool:
+ # Filter out docstrings
+ return isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant)
+
+ class CleanFunctionTransform(ast.NodeTransformer):
+ def __init__(self, fn_globals: dict):
+ self.is_root = True
+ self.fn_globals = fn_globals
+
+ def is_checkpointer(self, node: ast.AST) -> bool:
+ from .checkpoint import Checkpointer
+ decorator = get_at(self.fn_globals, *get_decorator_path(node))
+ return isinstance(decorator, Checkpointer) or decorator is Checkpointer
+
+ def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef):
+ fn_type = type(node).__name__
+ fn_name = None if self.is_root else node.name
+ args_by_type = [
+ node.args.posonlyargs + node.args.args,
+ drop_none([node.args.vararg]),
+ sorted(node.args.kwonlyargs, key=lambda x: x.arg),
+ drop_none([node.args.kwarg]),
+ ]
+ arg_kind_names = ",".join(f"{i}:{arg.arg}" for i, args in enumerate(args_by_type) for arg in args)
+ header = " ".join(drop_none((fn_type, fn_name, arg_kind_names or None)))
+
+ self.is_root = False
+
+ return ast.List([
+ ast.Constant(header),
+ ast.List([child for child in node.decorator_list if not self.is_checkpointer(child)], ast.Load()),
+ ast.List([self.visit(child) for child in node.body if not is_lone_expression(child)], ast.Load()),
+ ], ast.Load())
+
+ def visit_AsyncFunctionDef(self, node):
+ return self.visit_FunctionDef(node)
+
+ def get_fn_aststr(fn: Callable) -> str:
+ try:
+ source = getsource(fn)
+ except OSError:
+ return ""
+ try:
+ tree = ast.parse(dedent(source), mode="exec")
+ tree = tree.body[0]
+ except SyntaxError:
+ # lambda functions can cause SyntaxError in ast.parse
+ return source.strip()
+
+ if fn.__name__ != "<lambda>":
+ tree = CleanFunctionTransform(fn.__globals__).visit(tree)
+ else:
+ tree = ast.List([node for node in ast.walk(tree) if isinstance(node, ast.Lambda)], ast.Load())
+
+ if sys.version_info >= (3, 13):
+ return ast.dump(tree, annotate_fields=True, show_empty=True) if False else ast.dump(tree, annotate_fields=False, show_empty=True)
+ else:
+ return ast.dump(tree, annotate_fields=False)
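Because `get_fn_aststr` reduces a function to an AST dump (with docstrings and `checkpointer` decorators stripped), formatting-only edits should leave the identity hash untouched. A rough way to observe this with the internal module, assuming the layout listed in the RECORD below and a saved script so `getsource` can find the code:

```python
from checkpointer.fn_string import get_fn_aststr

def version_a(x, y=2):
    """Docstrings are filtered out."""
    return (x + y) * 2  # comments never reach the AST

def version_b(x, y=2):
    return (x   +   y) * 2

# Same AST dump, so the same contribution to the function identity hash.
assert get_fn_aststr(version_a) == get_fn_aststr(version_b)
```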
checkpointer/object_hash.py CHANGED
@@ -5,32 +5,38 @@ import io
  import re
  import sys
  import tokenize
+ import sysconfig
+ from collections import OrderedDict
  from collections.abc import Iterable
  from contextlib import nullcontext, suppress
  from decimal import Decimal
  from io import StringIO
+ from inspect import getfile
  from itertools import chain
+ from pathlib import Path
  from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
- from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
+ from types import BuiltinFunctionType, FunctionType, GeneratorType, MappingProxyType, MethodType, ModuleType, UnionType
  from typing import Callable, Self, TypeVar
- from .utils import ContextVar
+ from .utils import ContextVar, flatten

  np, torch = None, None

- with suppress(Exception):
- import numpy as np
- with suppress(Exception):
- import torch
-
  class _Never:
  def __getattribute__(self, _: str):
  pass

+ with suppress(Exception):
+ import numpy as np
+ with suppress(Exception):
+ import torch
  if sys.version_info >= (3, 12):
  from typing import TypeAliasType
  else:
  TypeAliasType = _Never

+ nc = nullcontext()
+ stdlib = Path(sysconfig.get_paths()["stdlib"]).resolve()
+
  def encode_type(t: type | FunctionType) -> str:
  return f"{t.__module__}:{t.__qualname__}"

@@ -77,7 +83,7 @@ class ObjectHash:
  return self.write_bytes(":".join(map(str, args)).encode())

  def update(self, *objs: object, iter: Iterable[object] = (), tolerable: bool | None=None, header: str | None = None) -> Self:
- with nullcontext() if tolerable is None else self.tolerable.set(tolerable):
+ with nc if tolerable is None else self.tolerable.set(tolerable):
  for obj in chain(objs, iter):
  if header is not None:
  self.write_bytes(header.encode())
@@ -105,11 +111,11 @@ class ObjectHash:

  case set() | frozenset():
  try:
- items = sorted(obj)
  header = "set"
+ items = sorted(obj)
  except:
- items = sorted(map(self.nested_hash, obj))
  header = "set-unsortable"
+ items = sorted(map(self.nested_hash, obj))
  self.header(header, encode_type_of(obj), len(obj)).update(iter=items)

  case TypeVar():
@@ -125,7 +131,11 @@ class ObjectHash:
  self.header("builtin", obj.__qualname__)

  case FunctionType():
- self.header("function", encode_type(obj)).update(get_fn_body(obj), obj.__defaults__, obj.__kwdefaults__, obj.__annotations__)
+ fn_file = Path(getfile(obj)).resolve()
+ if fn_file.is_relative_to(stdlib):
+ self.header("function-std", obj.__qualname__)
+ else:
+ self.header("function", encode_type(obj)).update(get_fn_body(obj), obj.__defaults__, obj.__kwdefaults__, obj.__annotations__)

  case MethodType():
  self.header("method").update(obj.__func__, obj.__self__.__class__)
@@ -170,14 +180,16 @@ class ObjectHash:
  match obj:
  case list() | tuple():
  self.header("list", encode_type_of(obj), len(obj)).update(iter=obj)
- case dict():
- try:
- items = sorted(obj.items())
- header = "dict"
- except:
- items = sorted((self.nested_hash(key), val) for key, val in obj.items())
- header = "dict-unsortable"
- self.header(header, encode_type_of(obj), len(obj)).update(iter=chain.from_iterable(items))
+ case dict() | MappingProxyType():
+ header = "dict"
+ items = obj.items()
+ if not isinstance(obj, OrderedDict):
+ try:
+ items = sorted(items)
+ except:
+ header = "dict-unsortable"
+ items = sorted((self.nested_hash(key), val) for key, val in items)
+ self.header(header, encode_type_of(obj), len(obj)).update(iter=flatten(items))
  case _:
  self._update_object(obj)
  finally:
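One visible consequence of the reordered `dict` branch: plain dicts hash order-insensitively (items are sorted before hashing), while `OrderedDict` now keeps insertion order significant. A small sketch using `ObjectHash` as re-exported from the package root:

```python
from collections import OrderedDict
from checkpointer import ObjectHash

# Plain dicts: key order does not matter.
assert str(ObjectHash({"a": 1, "b": 2})) == str(ObjectHash({"b": 2, "a": 1}))

# OrderedDict: insertion order is now part of the hash.
assert str(ObjectHash(OrderedDict(a=1, b=2))) != str(ObjectHash(OrderedDict(b=2, a=1)))
```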
checkpointer/storages/memory_storage.py CHANGED
@@ -31,7 +31,7 @@ class MemoryStorage(Storage):
  if key.parent == curr_key.parent:
  if invalidated and key != curr_key:
  del item_map[key]
- elif expired and self.checkpointer.should_expire:
+ elif expired and self.checkpointer.expiry:
  for call_hash, (date, _) in list(calldict.items()):
- if self.checkpointer.should_expire(date):
+ if self.expired_dt(date):
  del calldict[call_hash]
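Both storages now defer to the new `expired_dt` helper (defined in `storages/storage.py` below), so a `timedelta` expiry behaves like an age check against `datetime.now()`. The two spellings below should be equivalent under that logic (the function bodies are placeholders):

```python
from datetime import datetime, timedelta
from checkpointer import checkpoint

@checkpoint(expiry=timedelta(days=7))
def weekly_report(week: int) -> dict:
    ...

# Callable form: recompute when the stored checkpoint is older than 7 days.
@checkpoint(expiry=lambda dt: dt < datetime.now() - timedelta(days=7))
def weekly_report_explicit(week: int) -> dict:
    ...
```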
checkpointer/storages/pickle_storage.py CHANGED
@@ -9,7 +9,7 @@ def filedate(path: Path) -> datetime:

  class PickleStorage(Storage):
  def get_path(self, call_hash: str):
- return self.fn_dir() / f"{call_hash}.pkl"
+ return self.fn_dir() / f"{call_hash[:2]}/{call_hash[2:]}.pkl"

  def store(self, call_hash, data):
  path = self.get_path(call_hash)
@@ -40,10 +40,10 @@ class PickleStorage(Storage):
  for path in old_dirs:
  shutil.rmtree(path)
  print(f"Removed {len(old_dirs)} invalidated directories for {self.cached_fn.__qualname__}")
- if expired and self.checkpointer.should_expire:
+ if expired and self.checkpointer.expiry:
  count = 0
  for pkl_path in fn_path.glob("**/*.pkl"):
- if self.checkpointer.should_expire(filedate(pkl_path)):
+ if self.expired_dt(filedate(pkl_path)):
  count += 1
  pkl_path.unlink(missing_ok=True)
  print(f"Removed {count} expired checkpoints for {self.cached_fn.__qualname__}")
checkpointer/storages/storage.py CHANGED
@@ -1,24 +1,38 @@
  from __future__ import annotations
  from typing import Any, TYPE_CHECKING
  from pathlib import Path
- from datetime import datetime
+ from datetime import datetime, timedelta

  if TYPE_CHECKING:
  from ..checkpoint import Checkpointer, CachedFunction

  class Storage:
  checkpointer: Checkpointer
- cached_fn: CachedFunction
+ ident: CachedFunction

  def __init__(self, cached_fn: CachedFunction):
- self.checkpointer = cached_fn.checkpointer
+ self.checkpointer = cached_fn.ident.checkpointer
  self.cached_fn = cached_fn

  def fn_id(self) -> str:
- return f"{self.cached_fn.fn_dir}/{self.cached_fn.ident.fn_hash}"
+ ident = self.cached_fn.ident
+ return f"{ident.fn_dir}/{ident.fn_hash}"

  def fn_dir(self) -> Path:
- return self.checkpointer.root_path / self.fn_id()
+ return self.checkpointer.directory / self.fn_id()
+
+ def expired(self, call_hash: str) -> bool:
+ if not self.checkpointer.expiry:
+ return False
+ return self.expired_dt(self.checkpoint_date(call_hash))
+
+ def expired_dt(self, dt: datetime) -> bool:
+ expiry = self.checkpointer.expiry
+ if isinstance(expiry, timedelta):
+ return dt < datetime.now() - expiry
+ else:
+ if TYPE_CHECKING: assert expiry
+ return expiry(dt)

  def store(self, call_hash: str, data: Any) -> Any: ...

checkpointer/utils.py CHANGED
@@ -1,13 +1,14 @@
  from __future__ import annotations
  import inspect
- from contextlib import contextmanager
- from itertools import islice
+ from contextlib import contextmanager, suppress
+ from itertools import chain, islice
  from pathlib import Path
  from types import FunctionType, MethodType, ModuleType
  from typing import Callable, Generic, Iterable, Self, Type, TypeGuard
  from .types import T

  cwd = Path.cwd().resolve()
+ flatten = chain.from_iterable

  def is_class(obj) -> TypeGuard[Type]:
  return isinstance(obj, type)
@@ -22,11 +23,12 @@ def is_user_fn(obj) -> TypeGuard[Callable]:
  return isinstance(obj, (FunctionType, MethodType)) and is_user_file(get_file(obj))

  def get_cell_contents(fn: Callable) -> Iterable[tuple[str, object]]:
- for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
- try:
+ for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ()):
+ with suppress(ValueError):
  yield (key, cell.cell_contents)
- except ValueError:
- pass
+
+ def drop_none(iterable: Iterable[T | None]) -> list[T]:
+ return [x for x in iterable if x is not None]

  def distinct(seq: Iterable[T]) -> list[T]:
  return list(dict.fromkeys(seq))
@@ -80,6 +82,14 @@ class seekable(Generic[T]):
  with self.freeze():
  return list(islice(self, count))

+ def get_at(obj: object, *attrs: str) -> object:
+ for attr in attrs:
+ if type(obj) is dict:
+ obj = obj.get(attr, None)
+ else:
+ obj = getattr(obj, attr, None)
+ return obj
+
  class AttrDict(dict):
  def __init__(self, *args, **kwargs):
  super().__init__(*args, **kwargs)
@@ -91,17 +101,6 @@ class AttrDict(dict):
  def __setattr__(self, name: str, value: object):
  super().__setattr__(name, value)

- def set(self, d: dict) -> AttrDict:
- if not d:
- return self
- return AttrDict({**self, **d})
-
- def get_at(self: object, *attrs: str) -> object:
- obj = self
- for attr in attrs:
- obj = getattr(obj, attr, None)
- return obj
-
  class ContextVar(Generic[T]):
  def __init__(self, value: T):
  self.value = value
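The new module-level `get_at` replaces `AttrDict.get_at`: it walks a path of names, using key lookup for plain dicts and `getattr` for everything else, and falls through to `None` when a step is missing. `flatten` is simply `chain.from_iterable`. A quick illustration:

```python
import math
from checkpointer.utils import flatten, get_at

config = {"model": {"optimizer": {"lr": 0.01}}}
assert get_at(config, "model", "optimizer", "lr") == 0.01
assert get_at(config, "model", "missing", "lr") is None  # missing step falls through to None
assert get_at(math, "pi") == math.pi                     # attribute access on non-dict objects

assert list(flatten([(1, 2), (3,)])) == [1, 2, 3]
```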
checkpointer-2.14.0.dist-info/METADATA ADDED
@@ -0,0 +1,243 @@
+ Metadata-Version: 2.4
+ Name: checkpointer
+ Version: 2.14.0
+ Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
+ Project-URL: Repository, https://github.com/Reddan/checkpointer.git
+ Author: Hampus Hallman
+ License-Expression: MIT
+ License-File: ATTRIBUTION.md
+ License-File: LICENSE
+ Keywords: async,cache,caching,data analysis,data processing,fast,hashing,invalidation,memoization,optimization,performance,workflow
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Requires-Python: >=3.11
+ Description-Content-Type: text/markdown
+
+ # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![pypi](https://img.shields.io/pypi/pyversions/checkpointer)](https://pypi.org/project/checkpointer/)
+
+ `checkpointer` is a Python library offering a decorator-based API for memoizing (caching) function results with code-aware cache invalidation. It works with sync and async functions, supports multiple storage backends, and invalidates caches automatically when your code or dependencies change - helping you maintain correctness, speed up execution, and smooth out your workflows by skipping redundant, costly operations.
+
+ ## 📦 Installation
+
+ ```bash
+ pip install checkpointer
+ ```
+
+ ## 🚀 Quick Start
+
+ Apply the `@checkpoint` decorator to any function:
+
+ ```python
+ from checkpointer import checkpoint
+
+ @checkpoint
+ def expensive_function(x: int) -> int:
+ print("Computing...")
+ return x ** 2
+
+ result = expensive_function(4) # Computes and stores the result
+ result = expensive_function(4) # Loads from the cache
+ ```
+
+ ## 🧠 How It Works
+
+ When you decorate a function with `@checkpoint` and call it, `checkpointer` computes a unique identifier that represents that specific call. This identifier is based on:
+
+ * The function's source code and all its user-defined dependencies,
+ * Global variables used by the function (if capturing is enabled or explicitly annotated),
+ * The actual arguments passed to the function.
+
+ `checkpointer` then looks up this identifier in its cache. If a valid cached result exists, it returns that immediately. Otherwise, it runs the original function, stores the result, and returns it.
+
+ `checkpointer` is designed to be flexible through features like:
+
+ * **Support for decorated methods**, correctly caching results bound to instances.
+ * **Support for decorated async functions**, compatible with any async runtime.
+ * **Robust hashing**, covering complex Python objects and large **NumPy**/**PyTorch** arrays via its internal `ObjectHash`.
+ * **Targeted hashing**, allowing you to optimize how arguments and captured variables are hashed.
+ * **Multi-layered caching**, letting you stack decorators for layered caching strategies without losing cache consistency.
+
+ ### 🚨 What Causes Cache Invalidation?
+
+ To ensure cache correctness, `checkpointer` tracks two types of hashes:
+
+ #### 1. Function Identity Hash (Computed Once Per Function)
+
+ This hash represents the decorated function itself and is computed once (usually on first invocation). It covers:
+
+ * **Function Code and Signature:**\
+ The actual logic and parameters of the function are hashed - but *not* parameter type annotations or formatting details like whitespace, newlines, comments, or trailing commas, which do **not** trigger invalidation.
+
+ * **Dependencies:**\
+ All user-defined functions and methods that the decorated function calls or relies on, including indirect dependencies, are included recursively. Dependencies are identified by:
+ * Inspecting the function's global scope for referenced functions and objects.
+ * Inferring from the function's argument type annotations.
+ * Analyzing object constructions and method calls to identify classes and methods used.
+
+ * **Exclusions:**\
+ Changes elsewhere in the module unrelated to the function or its dependencies do **not** cause invalidation.
+
+ #### 2. Call Hash (Computed on Every Function Call)
+
+ Every function call produces a call hash, combining:
+
+ * **Passed Arguments:**\
+ Includes positional and keyword arguments, combined with default values. Changing defaults alone doesn't necessarily trigger invalidation unless it affects actual call values.
+
+ * **Captured Global Variables:**\
+ When `capture=True` or explicit capture annotations are used, `checkpointer` includes referenced global variables in the call hash. Variables annotated with `CaptureMe` are hashed on every call, causing immediate cache invalidation if they change. Variables annotated with `CaptureMeOnce` are hashed only once per Python session, improving performance by avoiding repeated hashing.
+
+ * **Custom Argument Hashing:**\
+ Using `HashBy` annotations, arguments or captured variables can be transformed before hashing (e.g., sorting lists to ignore order), allowing more precise or efficient call hashes.
+
+ ## 💡 Usage
+
+ Once a function is decorated with `@checkpoint`, you can interact with its caching behavior using the following methods:
+
+ * **`expensive_function(...)`**:\
+ Call the function normally. This will compute and cache the result or load it from cache.
+
+ * **`expensive_function.rerun(...)`**:\
+ Force the original function to execute and overwrite any existing cached result.
+
+ * **`expensive_function.fn(...)`**:\
+ Call the undecorated function directly, bypassing the cache (useful in recursion to prevent caching intermediate steps).
+
+ * **`expensive_function.get(...)`**:\
+ Retrieve the cached result without executing the function. Raises `CheckpointError` if no valid cache exists.
+
+ * **`expensive_function.exists(...)`**:\
+ Check if a cached result exists without computing or loading it.
+
+ * **`expensive_function.delete(...)`**:\
+ Remove the cached entry for given arguments.
+
+ * **`expensive_function.reinit(recursive: bool = True)`**:\
+ Recalculate the function identity hash and recapture `CaptureMeOnce` variables, updating the cached function state within the same Python session.
+
+ ## ⚙️ Configuration & Customization
+
+ The `@checkpoint` decorator accepts the following parameters:
+
+ * **`storage`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)\
+ Storage backend to use: `"pickle"` (disk-based, persistent), `"memory"` (in-memory, non-persistent), or a custom `Storage` class.
+
+ * **`directory`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)\
+ Base directory for disk-based checkpoints (only for `"pickle"` storage).
+
+ * **`capture`** (Type: `bool`, Default: `False`)\
+ If `True`, includes global variables referenced by the function in call hashes (except those excluded via `NoHash`).
+
+ * **`expiry`** (Type: `Callable[[datetime.datetime], bool]` or `datetime.timedelta`, Default: `None`)\
+ Controls when cached results are recomputed. Pass a `timedelta` to expire results older than that age, or a callable that receives the `datetime` timestamp of a cached result and returns `True` if it should be treated as expired and recomputed.
+
+ * **`fn_hash_from`** (Type: `Any`, Default: `None`)\
+ Override the computed function identity hash with any hashable object you provide (e.g., version strings, config IDs). This gives you explicit control over the function's version and when its cache should be invalidated.
+
+ * **`when`** (Type: `bool`, Default: `True`)\
+ Enable or disable checkpointing dynamically, useful for environment-based toggling.
+
+ * **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
+ Controls the level of logging output from `checkpointer`.
+ * `0`: No output.
+ * `1`: Shows when functions are computed and cached.
+ * `2`: Also shows when cached results are remembered (loaded from cache).
+
+ ## 🔬 Customize Argument Hashing
+
+ You can customize how arguments are hashed without modifying the actual argument values to improve cache hit rates or speed up hashing.
+
+ * **`Annotated[T, HashBy[fn]]`**:\
+ Transform the argument via `fn(argument)` before hashing. Useful for normalization (e.g., sorting lists) or optimized hashing for complex inputs.
+
+ * **`NoHash[T]`**:\
+ Exclude the argument from hashing completely, so changes to it won't trigger cache invalidation.
+
+ **Example:**
+
+ ```python
+ from typing import Annotated
+ from checkpointer import checkpoint, HashBy, NoHash
+ from pathlib import Path
+ import logging
+
+ def file_bytes(path: Path) -> bytes:
+ return path.read_bytes()
+
+ @checkpoint
+ def process(
+ numbers: Annotated[list[int], HashBy[sorted]], # Hash by sorted list
+ data_file: Annotated[Path, HashBy[file_bytes]], # Hash by file content
+ log: NoHash[logging.Logger], # Exclude logger from hashing
+ ):
+ ...
+ ```
+
+ In this example, the hash for `numbers` ignores order, `data_file` is hashed based on its contents rather than path, and changes to `log` don't affect caching.
+
+ ## 🎯 Capturing Global Variables
+
+ `checkpointer` can include **captured global variables** in call hashes - these are globals your function reads during execution that may affect results.
+
+ Use `capture=True` on `@checkpoint` to capture **all** referenced globals (except those explicitly excluded with `NoHash`).
+
+ Alternatively, you can **opt-in selectively** by annotating globals with:
+
+ * **`CaptureMe[T]`**:\
+ Capture the variable on every call (triggers invalidation on changes).
+
+ * **`CaptureMeOnce[T]`**:\
+ Capture once per Python session (for expensive, immutable globals).
+
+ You can also combine these with `HashBy` to customize how captured variables are hashed (e.g., hash by subset of attributes).
+
+ **Example:**
+
+ ```python
+ from typing import Annotated
+ from checkpointer import checkpoint, CaptureMe, CaptureMeOnce, HashBy
+ from pathlib import Path
+
+ def file_bytes(path: Path) -> bytes:
+ return path.read_bytes()
+
+ captured_data: CaptureMe[Annotated[Path, HashBy[file_bytes]]] = Path("data.txt")
+ session_config: CaptureMeOnce[dict] = {"mode": "prod"}
+
+ @checkpoint
+ def process():
+ # `captured_data` is included in the call hash on every call, hashed by file content
+ # `session_config` is hashed once per session
+ ...
+ ```
+
+ ## 🗄️ Custom Storage Backends
+
+ Implement your own storage backend by subclassing `checkpointer.Storage` and overriding required methods.
+
+ Within storage methods, `call_hash` identifies calls by arguments. Use `self.fn_id()` to get function identity (name + hash/version), important for organizing checkpoints.
+
+ **Example:**
+
+ ```python
+ from checkpointer import checkpoint, Storage
+ from datetime import datetime
+
+ class MyCustomStorage(Storage):
+ def exists(self, call_hash):
+ fn_dir = self.checkpointer.directory / self.fn_id()
+ return (fn_dir / call_hash).exists()
+
+ def store(self, call_hash, data):
+ ... # Store serialized data
+ return data # Must return data to checkpointer
+
+ def checkpoint_date(self, call_hash): ...
+ def load(self, call_hash): ...
+ def delete(self, call_hash): ...
+
+ @checkpoint(storage=MyCustomStorage)
+ def custom_cached_function(x: int):
+ return x ** 2
+ ```
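Pulling the Usage section of the new README together, the cache-control methods can be exercised like this (the function is the README's own `expensive_function`; `cleanup` is exposed by the storage backends shown earlier in this diff):

```python
from checkpointer import CheckpointError, checkpoint

@checkpoint
def expensive_function(x: int) -> int:
    print("Computing...")
    return x ** 2

expensive_function(4)                   # compute and store
assert expensive_function.exists(4)     # a valid checkpoint exists
assert expensive_function.get(4) == 16  # load without executing
expensive_function.rerun(4)             # force recomputation and overwrite
expensive_function.delete(4)            # drop the cached entry

try:
    expensive_function.get(4)
except CheckpointError:
    print("no cached result")

expensive_function.cleanup(invalidated=True, expired=True)  # prune stale checkpoints
```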
checkpointer-2.14.0.dist-info/RECORD ADDED
@@ -0,0 +1,18 @@
+ checkpointer/__init__.py,sha256=mT706hd-BoTpQXYH_MiqreplONVDvjCWjwwGQ7oRe_8,921
+ checkpointer/checkpoint.py,sha256=SUHu5ADRtSrapspk2DpBLaYnAQ7yuTc05lUELig_Tpc,9972
+ checkpointer/fn_ident.py,sha256=Z2HtkoG8n0ddqNzjejnKW6Lo-ID2unKQZa1tOpRGujs,6911
+ checkpointer/fn_string.py,sha256=YldjAU91cwiZNwuE_AQN_DMiQCPAhGiiC3lPi0uvM0g,2579
+ checkpointer/import_mappings.py,sha256=ESqWvZTzYAmaVnJ6NulUvn3_8CInOOPmEKUXO2WD_WA,1794
+ checkpointer/object_hash.py,sha256=_VgyTEoe3H6268KBAOFZxSgWHb6otgG-3NzhHyehxe4,8595
+ checkpointer/print_checkpoint.py,sha256=uUQ493fJCaB4nhp4Ox60govSCiBTIPbBX15zt2QiRGo,1356
+ checkpointer/types.py,sha256=GFqbGACdDxzQX3bb2LmF9UxQVWOEisGvdtobnqCBAOA,1129
+ checkpointer/utils.py,sha256=CF28pGytCWV4URhRDFaIDawX8qmCAAByQp0ZkEwrA6c,3014
+ checkpointer/storages/__init__.py,sha256=p-r4YrPXn505_S3qLrSXHSlsEtb13w_DFnCt9IiUomk,296
+ checkpointer/storages/memory_storage.py,sha256=f242kLEsB4TcqEkEBmJGIewJE9mvv9aCAZ0V9pPB2EQ,1110
+ checkpointer/storages/pickle_storage.py,sha256=burbgX-ICPLNMAF9iU3HRZroxJibKfmgJLqTCm6YbbI,1649
+ checkpointer/storages/storage.py,sha256=PgpwW_VLkiwD361Kjsz8aoq5FHZLBTqF43yD2mw9dTI,1326
+ checkpointer-2.14.0.dist-info/METADATA,sha256=MSJ--2X9N_Ug8sd9VhffONSzhIPA_njmzJhCe3fa_l8,11215
+ checkpointer-2.14.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ checkpointer-2.14.0.dist-info/licenses/ATTRIBUTION.md,sha256=WF6L7-sD4s9t9ytVJOhjhpDoZ6TrWpqE3_bMdDIeJxI,1078
+ checkpointer-2.14.0.dist-info/licenses/LICENSE,sha256=drXs6vIb7uW49r70UuMz2A1VtOCl626kiTbcmrar1Xo,1072
+ checkpointer-2.14.0.dist-info/RECORD,,
checkpointer-2.12.0.dist-info/METADATA REMOVED
@@ -1,236 +0,0 @@
- Metadata-Version: 2.4
- Name: checkpointer
- Version: 2.12.0
- Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
- Project-URL: Repository, https://github.com/Reddan/checkpointer.git
- Author: Hampus Hallman
- License-Expression: MIT
- License-File: ATTRIBUTION.md
- License-File: LICENSE
- Keywords: async,cache,caching,code-aware,decorator,fast,hashing,invalidation,memoization,memoize,memory,optimization,performance,workflow
- Classifier: Programming Language :: Python :: 3.11
- Classifier: Programming Language :: Python :: 3.12
- Classifier: Programming Language :: Python :: 3.13
- Requires-Python: >=3.11
- Description-Content-Type: text/markdown
-
- # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![pypi](https://img.shields.io/pypi/pyversions/checkpointer)](https://pypi.org/project/checkpointer/)
-
- `checkpointer` is a Python library offering a decorator-based API for memoizing (caching) function results with code-aware cache invalidation. It works with sync and async functions, supports multiple storage backends, and refreshes caches automatically when your code or dependencies change - helping you maintain correctness, speed up execution, and smooth out your workflows by skipping redundant, costly operations.
-
- ## 📦 Installation
-
- ```bash
- pip install checkpointer
- ```
-
- ## 🚀 Quick Start
-
- Apply the `@checkpoint` decorator to any function:
-
- ```python
- from checkpointer import checkpoint
-
- @checkpoint
- def expensive_function(x: int) -> int:
- print("Computing...")
- return x ** 2
-
- result = expensive_function(4) # Computes and stores the result
- result = expensive_function(4) # Loads from the cache
- ```
-
- ## 🧠 How It Works
-
- When a `@checkpoint`-decorated function is called, `checkpointer` first computes a unique identifier (hash) for the call. This hash is derived from the function's source code, its dependencies, and the arguments passed.
-
- It then tries to retrieve a cached result using this ID. If a valid cached result is found, it's returned immediately. Otherwise, the original function executes, its result is stored, and then returned.
-
- Cache validity is determined by this function's hash, which automatically updates if:
-
- * **Function Code Changes**: The decorated function's source code is modified.
- * **Dependencies Change**: Any user-defined function in its dependency tree (direct or indirect, even across modules or not decorated) is modified.
- * **Captured Variables Change** (with `capture=True`): Global or closure-based variables used within the function are altered.
-
- **Example: Dependency Invalidation**
-
- ```python
- def multiply(a, b):
- return a * b
-
- @checkpoint
- def helper(x):
- # Depends on `multiply`
- return multiply(x + 1, 2)
-
- @checkpoint
- def compute(a, b):
- # Depends on `helper` and `multiply`
- return helper(a) + helper(b)
- ```
-
- If `multiply` is modified, caches for both `helper` and `compute` are automatically invalidated and recomputed.
-
- ## 💡 Usage
-
- Once a function is decorated with `@checkpoint`, you can interact with its caching behavior using the following methods:
-
- * **`expensive_function(...)`**:\
- Call the function normally. This will either compute and cache the result or load it from the cache if available.
-
- * **`expensive_function.rerun(...)`**:\
- Forces the original function to execute, compute a new result, and overwrite any existing cached value for the given arguments.
-
- * **`expensive_function.fn(...)`**:\
- Calls the original, undecorated function directly, bypassing the cache entirely. This is particularly useful within recursive functions to prevent caching intermediate steps.
-
- * **`expensive_function.get(...)`**:\
- Attempts to retrieve the cached result for the given arguments without executing the original function. Raises `CheckpointError` if no valid cached result exists.
-
- * **`expensive_function.exists(...)`**:\
- Checks if a cached result exists for the given arguments without attempting to compute or load it. Returns `True` if a valid checkpoint exists, `False` otherwise.
-
- * **`expensive_function.delete(...)`**:\
- Removes the cached entry for the specified arguments.
-
- * **`expensive_function.reinit()`**:\
- Recalculates the function's internal hash. This is primarily used when `capture=True` and you need to update the cache based on changes to external variables within the same Python session.
-
- ## ⚙️ Configuration & Customization
-
- The `@checkpoint` decorator accepts the following parameters to customize its behavior:
-
- * **`format`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)\
- Defines the storage backend to use. Built-in options are `"pickle"` (disk-based, persistent) and `"memory"` (in-memory, non-persistent). You can also provide a custom `Storage` class.
-
- * **`root_path`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)\
- The base directory for storing disk-based checkpoints. This parameter is only relevant when `format` is set to `"pickle"`.
-
- * **`when`** (Type: `bool`, Default: `True`)\
- A boolean flag to enable or disable checkpointing for the decorated function. This is particularly useful for toggling caching based on environment variables (e.g., `when=os.environ.get("ENABLE_CACHING", "false").lower() == "true"`).
-
- * **`capture`** (Type: `bool`, Default: `False`)\
- If set to `True`, `checkpointer` includes global or closure-based variables used by the function in its hash calculation. This ensures that changes to these external variables also trigger cache invalidation and recomputation.
-
- * **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
- A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
-
- * **`fn_hash_from`** (Type: `Any`, Default: `None`)\
- This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any object relevant to your invalidation logic (e.g., version strings, config parameters). The object you provide will be hashed internally by `checkpointer`.
-
- * **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
- Controls the level of logging output from `checkpointer`.
- * `0`: No output.
- * `1`: Shows when functions are computed and cached.
- * `2`: Also shows when cached results are remembered (loaded from cache).
-
- ## 🔬 Customize Argument Hashing
-
- You can customize how individual function arguments are hashed without changing their actual values when passed in.
-
- * **`Annotated[T, HashBy[fn]]`**:\
- Hashes the argument by applying `fn(argument)` before hashing. This enables custom normalization (e.g., sorting lists to ignore order) or optimized hashing for complex types, improving cache hit rates or speeding up hashing.
-
- * **`NoHash[T]`**:\
- Completely excludes the argument from hashing, so changes to it won’t trigger cache invalidation.
-
- **Example:**
-
- ```python
- from typing import Annotated
- from checkpointer import checkpoint, HashBy, NoHash
- from pathlib import Path
- import logging
-
- def file_bytes(path: Path) -> bytes:
- return path.read_bytes()
-
- @checkpoint
- def process(
- numbers: Annotated[list[int], HashBy[sorted]], # Hash by sorted list
- data_file: Annotated[Path, HashBy[file_bytes]], # Hash by file content
- log: NoHash[logging.Logger], # Exclude logger from hashing
- ):
- ...
- ```
-
- In this example, the cache key for `numbers` ignores order, `data_file` is hashed based on its contents rather than path, and changes to `log` don’t affect caching.
-
- ## 🗄️ Custom Storage Backends
-
- For integration with databases, cloud storage, or custom serialization, implement your own storage backend by inheriting from `checkpointer.Storage` and implementing its abstract methods.
-
- Within custom storage methods, `call_hash` identifies calls by arguments. Use `self.fn_id()` to get the function's unique identity (name + hash/version), crucial for organizing stored checkpoints (e.g., by function version). Access global `Checkpointer` config via `self.checkpointer`.
-
- **Example: Custom Storage Backend**
-
- ```python
- from checkpointer import checkpoint, Storage
- from datetime import datetime
-
- class MyCustomStorage(Storage):
- def exists(self, call_hash):
- # Example: Constructing a path based on function ID and call ID
- fn_dir = self.checkpointer.root_path / self.fn_id()
- return (fn_dir / call_hash).exists()
-
- def store(self, call_hash, data):
- ... # Store the serialized data for `call_hash`
- return data # This method must return the data back to checkpointer
-
- def checkpoint_date(self, call_hash): ...
- def load(self, call_hash): ...
- def delete(self, call_hash): ...
-
- @checkpoint(format=MyCustomStorage)
- def custom_cached_function(x: int):
- return x ** 2
- ```
-
- ## 🧱 Layered Caching
-
- You can apply multiple `@checkpoint` decorators to a single function to create layered caching strategies. `checkpointer` processes these decorators from bottom to top, meaning the decorator closest to the function definition is evaluated first.
-
- This is useful for scenarios like combining a fast, ephemeral cache (e.g., in-memory) with a persistent, slower cache (e.g., disk-based).
-
- **Example: Memory Cache over Disk Cache**
-
- ```python
- from checkpointer import checkpoint
-
- @checkpoint(format="memory") # Layer 2: Fast, ephemeral in-memory cache
- @checkpoint(format="pickle") # Layer 1: Persistent disk cache
- def some_expensive_operation():
- print("Performing a time-consuming operation...")
- return sum(i for i in range(10**7))
- ```
-
- ## ⚡ Async Support
-
- `checkpointer` works seamlessly with Python's `asyncio` and other async runtimes.
-
- ```python
- import asyncio
- from checkpointer import checkpoint
-
- @checkpoint
- async def async_compute_sum(a: int, b: int) -> int:
- print(f"Asynchronously computing {a} + {b}...")
- await asyncio.sleep(1)
- return a + b
-
- async def main():
- # First call computes and caches
- result1 = await async_compute_sum(3, 7)
- print(f"Result 1: {result1}")
-
- # Second call loads from cache
- result2 = await async_compute_sum(3, 7)
- print(f"Result 2: {result2}")
-
- # Retrieve from cache without re-running the async function
- result3 = async_compute_sum.get(3, 7)
- print(f"Result 3 (from cache): {result3}")
-
- asyncio.run(main())
- ```
checkpointer-2.12.0.dist-info/RECORD REMOVED
@@ -1,17 +0,0 @@
- checkpointer/__init__.py,sha256=x8LbL-URPg-afn0O-HMxPsJkV9cmW-Cw5epKSCGClOM,889
- checkpointer/checkpoint.py,sha256=2K2D_aYHpI1XBcNrWQxta06wRdUj81TeCXshBVxl7cA,9869
- checkpointer/fn_ident.py,sha256=4qg4NIvcWRV5GduO70an_iu12dn8LLIlw3Q8F3gm0Mo,6826
- checkpointer/import_mappings.py,sha256=ESqWvZTzYAmaVnJ6NulUvn3_8CInOOPmEKUXO2WD_WA,1794
- checkpointer/object_hash.py,sha256=MkrwSJJYVlWOjIDpGfYpAeYFPgCJhXHYBdlKVcUy2kw,8145
- checkpointer/print_checkpoint.py,sha256=uUQ493fJCaB4nhp4Ox60govSCiBTIPbBX15zt2QiRGo,1356
- checkpointer/types.py,sha256=GFqbGACdDxzQX3bb2LmF9UxQVWOEisGvdtobnqCBAOA,1129
- checkpointer/utils.py,sha256=rq7AjR0WJql1o8clKMdXj4p3wrPYMLyS6G11nvNNjFE,2934
- checkpointer/storages/__init__.py,sha256=p-r4YrPXn505_S3qLrSXHSlsEtb13w_DFnCt9IiUomk,296
- checkpointer/storages/memory_storage.py,sha256=aQRSOmAfS0UudubCpv8cdfu2ycM8mlsO9tFMcD2kmgo,1133
- checkpointer/storages/pickle_storage.py,sha256=je1LM2lTSs5yzm25Apg5tJ9jU9T6nXCgD9SlqQRIFaM,1652
- checkpointer/storages/storage.py,sha256=5Jel7VlmCG9gnIbZFKT1NrEiAePz8ZD8hfsD-tYEiP4,905
- checkpointer-2.12.0.dist-info/METADATA,sha256=B-R8Aq2caTM7fclANY7MJVZJQaU0nB1Mdf7wMr6HSTI,10705
- checkpointer-2.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- checkpointer-2.12.0.dist-info/licenses/ATTRIBUTION.md,sha256=WF6L7-sD4s9t9ytVJOhjhpDoZ6TrWpqE3_bMdDIeJxI,1078
- checkpointer-2.12.0.dist-info/licenses/LICENSE,sha256=drXs6vIb7uW49r70UuMz2A1VtOCl626kiTbcmrar1Xo,1072
- checkpointer-2.12.0.dist-info/RECORD,,