checkpointer 2.12.0__py3-none-any.whl → 2.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpointer/__init__.py CHANGED
@@ -8,8 +8,8 @@ from .types import AwaitableValue, Captured, CapturedOnce, CaptureMe, CaptureMeO
8
8
 
9
9
  checkpoint = Checkpointer()
10
10
  capture_checkpoint = Checkpointer(capture=True)
11
- memory_checkpoint = Checkpointer(format="memory", verbosity=0)
12
- tmp_checkpoint = Checkpointer(root_path=f"{tempfile.gettempdir()}/checkpoints")
11
+ memory_checkpoint = Checkpointer(storage="memory", verbosity=0)
12
+ tmp_checkpoint = Checkpointer(directory=f"{tempfile.gettempdir()}/checkpoints")
13
13
  static_checkpoint = Checkpointer(fn_hash_from=())
14
14
 
15
15
  def cleanup_all(invalidated=True, expired=True):
checkpointer/checkpoint.py CHANGED
@@ -21,8 +21,8 @@ class CheckpointError(Exception):
21
21
  pass
22
22
 
23
23
  class CheckpointerOpts(TypedDict, total=False):
24
- format: Type[Storage] | StorageType
25
- root_path: Path | str | None
24
+ storage: Type[Storage] | StorageType
25
+ directory: Path | str | None
26
26
  when: bool
27
27
  verbosity: Literal[0, 1, 2]
28
28
  should_expire: Callable[[datetime], bool] | None
@@ -31,8 +31,8 @@ class CheckpointerOpts(TypedDict, total=False):
31
31
 
32
32
  class Checkpointer:
33
33
  def __init__(self, **opts: Unpack[CheckpointerOpts]):
34
- self.format = opts.get("format", "pickle")
35
- self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
34
+ self.storage = opts.get("storage", "pickle")
35
+ self.directory = Path(opts.get("directory", DEFAULT_DIR) or ".")
36
36
  self.when = opts.get("when", True)
37
37
  self.verbosity = opts.get("verbosity", 1)
38
38
  self.should_expire = opts.get("should_expire")
@@ -56,24 +56,46 @@ class FunctionIdent:
56
56
  Separated from CachedFunction to prevent hash desynchronization
57
57
  among bound instances when `.reinit()` is called.
58
58
  """
59
- def __init__(self, cached_fn: CachedFunction):
60
- self.__dict__.clear()
59
+ __slots__ = (
60
+ "checkpointer", "cached_fn", "fn", "fn_dir", "pos_names",
61
+ "arg_names", "default_args", "hash_by_map", "__dict__",
62
+ )
63
+
64
+ def __init__(self, cached_fn: CachedFunction, checkpointer: Checkpointer, fn: Callable):
65
+ wrapped = unwrap(fn)
66
+ fn_file = Path(wrapped.__code__.co_filename).name
67
+ fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
68
+ params = list(signature(wrapped).parameters.values())
69
+ pos_param_types = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
70
+ named_param_types = (Parameter.KEYWORD_ONLY,) + pos_param_types
71
+ name_by_kind = {Parameter.VAR_POSITIONAL: b"*", Parameter.VAR_KEYWORD: b"**"}
72
+ self.checkpointer = checkpointer
61
73
  self.cached_fn = cached_fn
74
+ self.fn = fn
75
+ self.fn_dir = f"{fn_file}/{fn_name}"
76
+ self.pos_names = [param.name for param in params if param.kind in pos_param_types]
77
+ self.arg_names = {param.name for param in params if param.kind in named_param_types}
78
+ self.default_args = {param.name: param.default for param in params if param.default is not Parameter.empty}
79
+ self.hash_by_map = {
80
+ name_by_kind.get(param.kind, param.name): hash_by
81
+ for param in params
82
+ if (hash_by := hash_by_from_annotation(param.annotation))
83
+ }
62
84
 
63
85
  def reset(self):
64
- self.__init__(self.cached_fn)
86
+ self.__dict__.clear()
65
87
 
66
88
  def is_static(self) -> bool:
67
- return self.cached_fn.checkpointer.fn_hash_from is not None
89
+ return self.checkpointer.fn_hash_from is not None
68
90
 
69
91
  @cached_property
70
92
  def raw_ident(self) -> RawFunctionIdent:
71
- return get_fn_ident(unwrap(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
93
+ return get_fn_ident(unwrap(self.fn), self.checkpointer.capture)
72
94
 
73
95
  @cached_property
74
96
  def fn_hash(self) -> str:
75
97
  if self.is_static():
76
- return str(ObjectHash(self.cached_fn.checkpointer.fn_hash_from, digest_size=16))
98
+ return str(ObjectHash(self.checkpointer.fn_hash_from, digest_size=16))
77
99
  depends = self.deep_idents(past_static=False)
78
100
  deep_hashes = [d.fn_hash if d.is_static() else d.raw_ident.fn_hash for d in depends]
79
101
  return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
@@ -105,29 +127,21 @@ class FunctionIdent:
105
127
 
106
128
  class CachedFunction(Generic[Fn]):
107
129
  def __init__(self, checkpointer: Checkpointer, fn: Fn):
108
- wrapped = unwrap(fn)
109
- fn_file = Path(wrapped.__code__.co_filename).name
110
- fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
111
- Storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
112
- update_wrapper(cast(Callable, self), wrapped)
113
- self.checkpointer = checkpointer
114
- self.fn = fn
115
- self.fn_dir = f"{fn_file}/{fn_name}"
130
+ store_format = checkpointer.storage
131
+ Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
132
+ update_wrapper(cast(Callable, self), unwrap(fn))
133
+ self.ident = FunctionIdent(self, checkpointer, fn)
116
134
  self.storage = Storage(self)
117
- self.cleanup = self.storage.cleanup
118
135
  self.bound = ()
119
136
 
120
- params = list(signature(wrapped).parameters.values())
121
- pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
122
- self.arg_names = [param.name for param in params if param.kind in pos_params]
123
- self.default_args = {param.name: param.default for param in params if param.default is not Parameter.empty}
124
- self.hash_by_map = get_hash_by_map(params)
125
- self.ident = FunctionIdent(self)
126
-
127
137
  @overload
128
138
  def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
129
139
  @overload
130
- def __get__(self: CachedFunction[Callable[Concatenate[C, P], R]], instance: C, owner: Type[C]) -> CachedFunction[Callable[P, R]]: ...
140
+ def __get__(
141
+ self: CachedFunction[Callable[Concatenate[C, P], R]],
142
+ instance: C,
143
+ owner: Type[C],
144
+ ) -> CachedFunction[Callable[P, R]]: ...
131
145
  def __get__(self, instance, owner):
132
146
  if instance is None:
133
147
  return self
@@ -137,8 +151,12 @@ class CachedFunction(Generic[Fn]):
137
151
  return bound_fn
138
152
 
139
153
  @property
140
- def depends(self) -> list[Callable]:
141
- return self.ident.raw_ident.depends
154
+ def fn(self) -> Fn:
155
+ return cast(Fn, self.ident.fn)
156
+
157
+ @property
158
+ def cleanup(self):
159
+ return self.storage.cleanup
142
160
 
143
161
  def reinit(self, recursive=False) -> CachedFunction[Fn]:
144
162
  depend_idents = list(self.ident.deep_idents()) if recursive else [self.ident]
@@ -147,52 +165,55 @@ class CachedFunction(Generic[Fn]):
147
165
  return self
148
166
 
149
167
  def _get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
168
+ ident = self.ident
150
169
  args = self.bound + args
151
- pos_args = args[len(self.arg_names):]
152
- named_pos_args = dict(zip(self.arg_names, args))
153
- named_args = {**self.default_args, **named_pos_args, **kw}
154
- if hash_by_map := self.hash_by_map:
155
- rest_hash_by = hash_by_map.get(b"**")
156
- for key, value in named_args.items():
157
- if hash_by := hash_by_map.get(key, rest_hash_by):
158
- named_args[key] = hash_by(value)
159
- if pos_hash_by := hash_by_map.get(b"*"):
160
- pos_args = map(pos_hash_by, pos_args)
170
+ pos_args = args[len(ident.pos_names):]
171
+ named_pos_args = dict(zip(ident.pos_names, args))
172
+ named_args = {**ident.default_args, **named_pos_args, **kw}
173
+ for key, hash_by in ident.hash_by_map.items():
174
+ if isinstance(key, str):
175
+ named_args[key] = hash_by(named_args[key])
176
+ elif key == b"*":
177
+ pos_args = map(hash_by, pos_args)
178
+ elif key == b"**":
179
+ for key in kw.keys() - ident.arg_names:
180
+ named_args[key] = hash_by(named_args[key])
161
181
  named_args_iter = chain.from_iterable(sorted(named_args.items()))
162
- captured = chain.from_iterable(capturable.capture() for capturable in self.ident.capturables)
163
- obj_hash = ObjectHash(digest_size=16) \
182
+ captured = chain.from_iterable(capturable.capture() for capturable in ident.capturables)
183
+ call_hash = ObjectHash(digest_size=16) \
164
184
  .update(iter=named_args_iter, header="NAMED") \
165
185
  .update(iter=pos_args, header="POS") \
166
186
  .update(iter=captured, header="CAPTURED")
167
- return str(obj_hash)
187
+ return str(call_hash)
168
188
 
169
189
  def get_call_hash(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> str:
170
190
  return self._get_call_hash(args, kw)
171
191
 
172
- async def _resolve_coroutine(self, call_hash: str, coroutine: Coroutine):
192
+ async def _store_coroutine(self, call_hash: str, coroutine: Coroutine):
173
193
  return self.storage.store(call_hash, AwaitableValue(await coroutine)).value
174
194
 
175
195
  def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
176
196
  full_args = self.bound + args
177
- params = self.checkpointer
197
+ params = self.ident.checkpointer
198
+ storage = self.storage
178
199
  if not params.when:
179
200
  return self.fn(*full_args, **kw)
180
201
 
181
202
  call_hash = self._get_call_hash(args, kw)
182
- call_id = f"{self.storage.fn_id()}/{call_hash}"
203
+ call_id = f"{storage.fn_id()}/{call_hash}"
183
204
  refresh = rerun \
184
- or not self.storage.exists(call_hash) \
185
- or (params.should_expire and params.should_expire(self.storage.checkpoint_date(call_hash)))
205
+ or not storage.exists(call_hash) \
206
+ or (params.should_expire and params.should_expire(storage.checkpoint_date(call_hash)))
186
207
 
187
208
  if refresh:
188
209
  print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id, "blue")
189
210
  data = self.fn(*full_args, **kw)
190
211
  if iscoroutine(data):
191
- return self._resolve_coroutine(call_hash, data)
192
- return self.storage.store(call_hash, data)
212
+ return self._store_coroutine(call_hash, data)
213
+ return storage.store(call_hash, data)
193
214
 
194
215
  try:
195
- data = self.storage.load(call_hash)
216
+ data = storage.load(call_hash)
196
217
  print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_id, "green")
197
218
  return data
198
219
  except (EOFError, FileNotFoundError):
@@ -232,15 +253,6 @@ class CachedFunction(Generic[Fn]):
232
253
  self.storage.store(self._get_call_hash(args, kw), value)
233
254
 
234
255
  def __repr__(self) -> str:
235
- return f"<CachedFunction {self.fn.__name__} {self.ident.fn_hash[:6]}>"
236
-
237
- def get_hash_by_map(params: list[Parameter]) -> dict[str | bytes, Callable[[object], object]]:
238
- hash_by_map = {}
239
- for param in params:
240
- name = param.name
241
- if param.kind == Parameter.VAR_POSITIONAL:
242
- name = b"*"
243
- elif param.kind == Parameter.VAR_KEYWORD:
244
- name = b"**"
245
- hash_by_map[name] = hash_by_from_annotation(param.annotation)
246
- return hash_by_map if any(hash_by_map.values()) else {}
256
+ initialized = "fn_hash" in self.ident.__dict__
257
+ fn_hash = self.ident.fn_hash[:6] if initialized else "- uninitialized"
258
+ return f"<CachedFunction {self.fn.__name__} {fn_hash}>"
checkpointer/fn_ident.py CHANGED
@@ -2,11 +2,12 @@ import dis
2
2
  from inspect import Parameter, getmodule, signature, unwrap
3
3
  from types import CodeType, MethodType, ModuleType
4
4
  from typing import Annotated, Callable, Iterable, NamedTuple, Type, get_args, get_origin
5
+ from .fn_string import get_fn_aststr
5
6
  from .import_mappings import resolve_annotation
6
7
  from .object_hash import ObjectHash
7
8
  from .types import hash_by_from_annotation, is_capture_me, is_capture_me_once, to_none
8
9
  from .utils import (
9
- AttrDict, cwd, distinct, get_cell_contents,
10
+ cwd, distinct, get_at, get_cell_contents,
10
11
  get_file, is_class, is_user_fn, seekable, takewhile,
11
12
  )
12
13
 
@@ -28,61 +29,61 @@ class Capturable(NamedTuple):
28
29
  def capture(self) -> tuple[str, object]:
29
30
  if obj := self.hash:
30
31
  return self.key, obj
31
- obj = AttrDict.get_at(self.module, *self.attr_path)
32
+ obj = get_at(self.module, *self.attr_path)
32
33
  obj = self.hash_by(obj) if self.hash_by else obj
33
34
  return self.key, obj
34
35
 
35
36
  @staticmethod
36
37
  def new(module: ModuleType, attr_path: AttrPath, hash_by: Callable | None, capture_once: bool) -> "Capturable":
37
38
  file = str(get_file(module).relative_to(cwd))
38
- key = "-".join((file, *attr_path))
39
+ key = file + "/" + ".".join(attr_path)
39
40
  cap = Capturable(key, module, attr_path, hash_by)
40
41
  if not capture_once:
41
42
  return cap
42
43
  obj_hash = str(ObjectHash(cap.capture()[1]))
43
44
  return Capturable(key, module, attr_path, None, obj_hash)
44
45
 
45
- def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
46
+ def extract_classvars(code: CodeType, scope_vars: dict) -> dict[str, dict[str, Type]]:
46
47
  attr_path = AttrPath(())
47
48
  scope_obj = None
48
49
  classvars: dict[str, dict[str, Type]] = {}
49
50
  instructs = seekable(dis.get_instructions(code))
50
- for instr in instructs:
51
- if instr.opname in scope_vars and not attr_path:
51
+ for instruct in instructs:
52
+ if instruct.opname in scope_vars and not attr_path:
52
53
  attrs = takewhile((x.opname == "LOAD_ATTR", x.argval) for x in instructs)
53
- attr_path = AttrPath((instr.opname, instr.argval, *attrs))
54
+ attr_path = AttrPath((instruct.opname, instruct.argval, *attrs))
54
55
  instructs.step(-1)
55
- elif instr.opname == "CALL":
56
- obj = scope_vars.get_at(*attr_path)
56
+ elif instruct.opname == "CALL":
57
+ obj = get_at(scope_vars, *attr_path)
57
58
  attr_path = AttrPath(())
58
59
  if is_class(obj):
59
60
  scope_obj = obj
60
- elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
61
- load_key = instr.opname.replace("STORE", "LOAD")
62
- classvars.setdefault(load_key, {})[instr.argval] = scope_obj
61
+ elif instruct.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
62
+ load_key = instruct.opname.replace("STORE", "LOAD")
63
+ classvars.setdefault(load_key, {})[instruct.argval] = scope_obj
63
64
  scope_obj = None
64
65
  return classvars
65
66
 
66
- def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[AttrPath, object]]:
67
+ def extract_scope_values(code: CodeType, scope_vars: dict) -> Iterable[tuple[AttrPath, object]]:
67
68
  classvars = extract_classvars(code, scope_vars)
68
- scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
69
+ scope_vars = {**scope_vars, **{k: {**scope_vars[k], **v} for k, v in classvars.items()}}
69
70
  instructs = seekable(dis.get_instructions(code))
70
- for instr in instructs:
71
- if instr.opname in scope_vars:
71
+ for instruct in instructs:
72
+ if instruct.opname in scope_vars:
72
73
  attrs = takewhile((x.opname in ("LOAD_ATTR", "LOAD_METHOD"), x.argval) for x in instructs)
73
- attr_path = AttrPath((instr.opname, instr.argval, *attrs))
74
+ attr_path = AttrPath((instruct.opname, instruct.argval, *attrs))
74
75
  parent_path = attr_path[:-1]
75
76
  instructs.step(-1)
76
- obj = scope_vars.get_at(*attr_path)
77
+ obj = get_at(scope_vars, *attr_path)
77
78
  if obj is not None:
78
79
  yield attr_path, obj
79
80
  if callable(obj) and parent_path[1:]:
80
- parent_obj = scope_vars.get_at(*parent_path)
81
+ parent_obj = get_at(scope_vars, *parent_path)
81
82
  yield parent_path, parent_obj
82
83
  for const in code.co_consts:
83
84
  if isinstance(const, CodeType):
84
- next_deref = scope_vars.LOAD_DEREF.set(scope_vars.LOAD_FAST)
85
- next_scope_vars = AttrDict({**scope_vars, "LOAD_FAST": {}, "LOAD_DEREF": next_deref})
85
+ next_deref = {**scope_vars["LOAD_DEREF"], **scope_vars["LOAD_FAST"]}
86
+ next_scope_vars = {**scope_vars, "LOAD_FAST": {}, "LOAD_DEREF": next_deref}
86
87
  yield from extract_scope_values(const, next_scope_vars)
87
88
 
88
89
  def resolve_class_annotations(anno: object) -> Type | None:
@@ -94,11 +95,11 @@ def resolve_class_annotations(anno: object) -> Type | None:
94
95
  return resolve_class_annotations(next(iter(get_args(anno)), None))
95
96
  return resolve_class_annotations(get_origin(anno))
96
97
 
97
- def get_self_value(fn: Callable) -> type | object | None:
98
+ def get_self_value(fn: Callable) -> Type | object | None:
98
99
  if isinstance(fn, MethodType):
99
100
  return fn.__self__
100
101
  parts = fn.__qualname__.split(".")[:-1]
101
- cls = parts and AttrDict(fn.__globals__).get_at(*parts)
102
+ cls = parts and get_at(fn.__globals__, *parts)
102
103
  if is_class(cls):
103
104
  return cls
104
105
 
@@ -116,19 +117,20 @@ def get_capturables(fn: Callable, capture: bool, captured_vars: dict[AttrPath, o
116
117
  yield Capturable.new(module, attr_path, hash_by, is_capture_me_once(anno))
117
118
 
118
119
  def get_fn_captures(fn: Callable, capture: bool) -> tuple[list[Callable], list[Capturable]]:
119
- sig_scope = {
120
+ scope_vars_signature: dict[str, Type | object] = {
120
121
  param.name: class_anno
121
122
  for param in signature(fn).parameters.values()
122
123
  if param.annotation is not Parameter.empty
123
124
  if param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
124
125
  if (class_anno := resolve_class_annotations(param.annotation))
125
126
  }
126
- self_value = get_self_value(fn)
127
- scope_vars = AttrDict({
128
- "LOAD_FAST": AttrDict({**sig_scope, "self": self_value} if self_value else sig_scope),
129
- "LOAD_DEREF": AttrDict(get_cell_contents(fn)),
130
- "LOAD_GLOBAL": AttrDict(fn.__globals__),
131
- })
127
+ if self_obj := get_self_value(fn):
128
+ scope_vars_signature["self"] = self_obj
129
+ scope_vars = {
130
+ "LOAD_FAST": scope_vars_signature,
131
+ "LOAD_DEREF": dict(get_cell_contents(fn)),
132
+ "LOAD_GLOBAL": fn.__globals__,
133
+ }
132
134
  captured_vars = dict(extract_scope_values(fn.__code__, scope_vars))
133
135
  captured_callables = [obj for obj in captured_vars.values() if callable(obj)]
134
136
  capturables = list(get_capturables(fn, capture, captured_vars))
@@ -142,7 +144,7 @@ def get_depend_fns(fn: Callable, capture: bool, capturable_by_fn: CapturableByFn
142
144
  for depend_fn in captured_callables:
143
145
  depend_fn = unwrap(depend_fn, stop=lambda f: isinstance(f, CachedFunction))
144
146
  if isinstance(depend_fn, CachedFunction):
145
- capturable_by_fn[depend_fn] = []
147
+ capturable_by_fn[depend_fn.ident.cached_fn] = []
146
148
  elif depend_fn not in capturable_by_fn and is_user_fn(depend_fn):
147
149
  get_depend_fns(depend_fn, capture, capturable_by_fn)
148
150
  return capturable_by_fn
@@ -153,7 +155,7 @@ def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
153
155
  capturables = {capt for capts in capturable_by_fn.values() for capt in capts}
154
156
  depends = capturable_by_fn.keys()
155
157
  depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
156
- unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
157
- assert fn == unwrapped_depends[0]
158
- fn_hash = str(ObjectHash(iter=unwrapped_depends))
158
+ depend_callables = [fn for fn in depends if not isinstance(fn, CachedFunction)]
159
+ assert fn == depend_callables[0]
160
+ fn_hash = str(ObjectHash(iter=map(get_fn_aststr, depend_callables)))
159
161
  return RawFunctionIdent(fn_hash, depends, capturables)
checkpointer/fn_string.py ADDED
@@ -0,0 +1,77 @@
1
+ import ast
2
+ import sys
3
+ from inspect import getsource
4
+ from textwrap import dedent
5
+ from typing import Callable
6
+ from .utils import drop_none, get_at
7
+
8
+ def get_decorator_path(node: ast.AST) -> tuple[str, ...]:
9
+ if isinstance(node, ast.Call):
10
+ return get_decorator_path(node.func)
11
+ elif isinstance(node, ast.Attribute):
12
+ return get_decorator_path(node.value) + (node.attr,)
13
+ elif isinstance(node, ast.Name):
14
+ return (node.id,)
15
+ else:
16
+ return ()
17
+
18
+ def is_empty_expression(node: ast.AST) -> bool:
19
+ # Filter out docstrings
20
+ return isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant)
21
+
22
+ class CleanFunctionTransform(ast.NodeTransformer):
23
+ def __init__(self, fn_globals: dict):
24
+ self.is_root = True
25
+ self.fn_globals = fn_globals
26
+
27
+ def is_checkpointer(self, node: ast.AST) -> bool:
28
+ from .checkpoint import Checkpointer
29
+ return isinstance(get_at(self.fn_globals, *get_decorator_path(node)), Checkpointer)
30
+
31
+ def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef):
32
+ fn_type = type(node).__name__
33
+ fn_name = None if self.is_root else node.name
34
+ args_by_type = [
35
+ node.args.posonlyargs + node.args.args,
36
+ drop_none([node.args.vararg]),
37
+ sorted(node.args.kwonlyargs, key=lambda x: x.arg),
38
+ drop_none([node.args.kwarg]),
39
+ ]
40
+ arg_kind_names = ",".join(f"{i}:{arg.arg}" for i, args in enumerate(args_by_type) for arg in args)
41
+ header = " ".join(drop_none((fn_type, fn_name, arg_kind_names or None)))
42
+
43
+ self.is_root = False
44
+
45
+ return ast.List([
46
+ ast.Constant(header),
47
+ ast.List([child for child in node.decorator_list if not self.is_checkpointer(child)], ast.Load()),
48
+ ast.List([self.visit(child) for child in node.body if not is_empty_expression(child)], ast.Load()),
49
+ ], ast.Load())
50
+
51
+ def visit_AsyncFunctionDef(self, node):
52
+ return self.visit_FunctionDef(node)
53
+
54
+ def get_fn_aststr(fn: Callable) -> str:
55
+ try:
56
+ source = getsource(fn)
57
+ except OSError:
58
+ return ""
59
+ try:
60
+ tree = ast.parse(dedent(source), mode="exec")
61
+ tree = tree.body[0]
62
+ except SyntaxError:
63
+ # lambda functions can cause SyntaxError in ast.parse
64
+ return source.strip()
65
+
66
+ if fn.__name__ != "<lambda>":
67
+ tree = CleanFunctionTransform(fn.__globals__).visit(tree)
68
+ else:
69
+ for node in ast.walk(tree):
70
+ if isinstance(node, ast.Lambda):
71
+ tree = node
72
+ break
73
+
74
+ if sys.version_info >= (3, 13):
75
+ return ast.dump(tree, annotate_fields=False, show_empty=True)
76
+ else:
77
+ return ast.dump(tree, annotate_fields=False)
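A rough sketch of what this new internal module provides (it is imported by `fn_ident.py` above and is not part of the public API): `get_fn_aststr` dumps a normalized AST of a function, so formatting-only differences yield the same string.

```python
# Sketch: both definitions produce identical AST dumps, since comments and
# whitespace never reach the AST and the root function's own name is omitted.
from checkpointer.fn_string import get_fn_aststr

def f(x):
    return x + 1  # a comment

def g(x):
    return x + 1

assert get_fn_aststr(f) == get_fn_aststr(g)
```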
checkpointer/object_hash.py CHANGED
@@ -5,32 +5,35 @@ import io
5
5
  import re
6
6
  import sys
7
7
  import tokenize
8
+ from collections import OrderedDict
8
9
  from collections.abc import Iterable
9
10
  from contextlib import nullcontext, suppress
10
11
  from decimal import Decimal
11
12
  from io import StringIO
12
13
  from itertools import chain
13
14
  from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
14
- from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
15
+ from types import BuiltinFunctionType, FunctionType, GeneratorType, MappingProxyType, MethodType, ModuleType, UnionType
15
16
  from typing import Callable, Self, TypeVar
16
17
  from .utils import ContextVar
17
18
 
18
19
  np, torch = None, None
19
20
 
20
- with suppress(Exception):
21
- import numpy as np
22
- with suppress(Exception):
23
- import torch
24
-
25
21
  class _Never:
26
22
  def __getattribute__(self, _: str):
27
23
  pass
28
24
 
25
+ with suppress(Exception):
26
+ import numpy as np
27
+ with suppress(Exception):
28
+ import torch
29
29
  if sys.version_info >= (3, 12):
30
30
  from typing import TypeAliasType
31
31
  else:
32
32
  TypeAliasType = _Never
33
33
 
34
+ flatten = chain.from_iterable
35
+ nc = nullcontext()
36
+
34
37
  def encode_type(t: type | FunctionType) -> str:
35
38
  return f"{t.__module__}:{t.__qualname__}"
36
39
 
@@ -77,7 +80,7 @@ class ObjectHash:
77
80
  return self.write_bytes(":".join(map(str, args)).encode())
78
81
 
79
82
  def update(self, *objs: object, iter: Iterable[object] = (), tolerable: bool | None=None, header: str | None = None) -> Self:
80
- with nullcontext() if tolerable is None else self.tolerable.set(tolerable):
83
+ with nc if tolerable is None else self.tolerable.set(tolerable):
81
84
  for obj in chain(objs, iter):
82
85
  if header is not None:
83
86
  self.write_bytes(header.encode())
@@ -105,11 +108,11 @@ class ObjectHash:
105
108
 
106
109
  case set() | frozenset():
107
110
  try:
108
- items = sorted(obj)
109
111
  header = "set"
112
+ items = sorted(obj)
110
113
  except:
111
- items = sorted(map(self.nested_hash, obj))
112
114
  header = "set-unsortable"
115
+ items = sorted(map(self.nested_hash, obj))
113
116
  self.header(header, encode_type_of(obj), len(obj)).update(iter=items)
114
117
 
115
118
  case TypeVar():
@@ -170,14 +173,16 @@ class ObjectHash:
170
173
  match obj:
171
174
  case list() | tuple():
172
175
  self.header("list", encode_type_of(obj), len(obj)).update(iter=obj)
173
- case dict():
174
- try:
175
- items = sorted(obj.items())
176
- header = "dict"
177
- except:
178
- items = sorted((self.nested_hash(key), val) for key, val in obj.items())
179
- header = "dict-unsortable"
180
- self.header(header, encode_type_of(obj), len(obj)).update(iter=chain.from_iterable(items))
176
+ case dict() | MappingProxyType():
177
+ header = "dict"
178
+ items = obj.items()
179
+ if not isinstance(obj, OrderedDict):
180
+ try:
181
+ items = sorted(items)
182
+ except:
183
+ header = "dict-unsortable"
184
+ items = sorted((self.nested_hash(key), val) for key, val in items)
185
+ self.header(header, encode_type_of(obj), len(obj)).update(iter=flatten(items))
181
186
  case _:
182
187
  self._update_object(obj)
183
188
  finally:
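A small sketch of the behavioral change in this hunk (`ObjectHash` is an internal class; the import path is assumed from the wheel layout): plain dicts are still hashed by sorted items, while `OrderedDict` now keeps its insertion order.

```python
from collections import OrderedDict
from checkpointer.object_hash import ObjectHash

# Plain dicts: items are sorted first, so key order does not affect the hash.
assert str(ObjectHash({"a": 1, "b": 2})) == str(ObjectHash({"b": 2, "a": 1}))

# OrderedDicts: insertion order is preserved, so these two hashes differ.
assert str(ObjectHash(OrderedDict(a=1, b=2))) != str(ObjectHash(OrderedDict(b=2, a=1)))
```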
checkpointer/storages/storage.py CHANGED
@@ -8,17 +8,18 @@ if TYPE_CHECKING:
8
8
 
9
9
  class Storage:
10
10
  checkpointer: Checkpointer
11
- cached_fn: CachedFunction
11
+ ident: CachedFunction
12
12
 
13
13
  def __init__(self, cached_fn: CachedFunction):
14
- self.checkpointer = cached_fn.checkpointer
14
+ self.checkpointer = cached_fn.ident.checkpointer
15
15
  self.cached_fn = cached_fn
16
16
 
17
17
  def fn_id(self) -> str:
18
- return f"{self.cached_fn.fn_dir}/{self.cached_fn.ident.fn_hash}"
18
+ ident = self.cached_fn.ident
19
+ return f"{ident.fn_dir}/{ident.fn_hash}"
19
20
 
20
21
  def fn_dir(self) -> Path:
21
- return self.checkpointer.root_path / self.fn_id()
22
+ return self.checkpointer.directory / self.fn_id()
22
23
 
23
24
  def store(self, call_hash: str, data: Any) -> Any: ...
24
25
 
checkpointer/utils.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
  import inspect
3
- from contextlib import contextmanager
3
+ from contextlib import contextmanager, suppress
4
4
  from itertools import islice
5
5
  from pathlib import Path
6
6
  from types import FunctionType, MethodType, ModuleType
@@ -23,10 +23,11 @@ def is_user_fn(obj) -> TypeGuard[Callable]:
23
23
 
24
24
  def get_cell_contents(fn: Callable) -> Iterable[tuple[str, object]]:
25
25
  for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
26
- try:
26
+ with suppress(ValueError):
27
27
  yield (key, cell.cell_contents)
28
- except ValueError:
29
- pass
28
+
29
+ def drop_none(iterable: Iterable[T | None]) -> list[T]:
30
+ return [x for x in iterable if x is not None]
30
31
 
31
32
  def distinct(seq: Iterable[T]) -> list[T]:
32
33
  return list(dict.fromkeys(seq))
@@ -80,6 +81,14 @@ class seekable(Generic[T]):
80
81
  with self.freeze():
81
82
  return list(islice(self, count))
82
83
 
84
+ def get_at(obj: object, *attrs: str) -> object:
85
+ for attr in attrs:
86
+ if type(obj) is dict:
87
+ obj = obj.get(attr, None)
88
+ else:
89
+ obj = getattr(obj, attr, None)
90
+ return obj
91
+
83
92
  class AttrDict(dict):
84
93
  def __init__(self, *args, **kwargs):
85
94
  super().__init__(*args, **kwargs)
@@ -91,17 +100,6 @@ class AttrDict(dict):
91
100
  def __setattr__(self, name: str, value: object):
92
101
  super().__setattr__(name, value)
93
102
 
94
- def set(self, d: dict) -> AttrDict:
95
- if not d:
96
- return self
97
- return AttrDict({**self, **d})
98
-
99
- def get_at(self: object, *attrs: str) -> object:
100
- obj = self
101
- for attr in attrs:
102
- obj = getattr(obj, attr, None)
103
- return obj
104
-
105
103
  class ContextVar(Generic[T]):
106
104
  def __init__(self, value: T):
107
105
  self.value = value
checkpointer-2.13.0.dist-info/METADATA ADDED
@@ -0,0 +1,260 @@
1
+ Metadata-Version: 2.4
2
+ Name: checkpointer
3
+ Version: 2.13.0
4
+ Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
5
+ Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
+ Author: Hampus Hallman
7
+ License-Expression: MIT
8
+ License-File: ATTRIBUTION.md
9
+ License-File: LICENSE
10
+ Keywords: async,cache,caching,data analysis,data processing,fast,hashing,invalidation,memoization,optimization,performance,workflow
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Classifier: Programming Language :: Python :: 3.13
14
+ Requires-Python: >=3.11
15
+ Description-Content-Type: text/markdown
16
+
17
+ # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![pypi](https://img.shields.io/pypi/pyversions/checkpointer)](https://pypi.org/project/checkpointer/)
18
+
19
+ `checkpointer` is a Python library offering a decorator-based API for memoizing (caching) function results with code-aware cache invalidation. It works with sync and async functions, supports multiple storage backends, and refreshes caches automatically when your code or dependencies change - helping you maintain correctness, speed up execution, and smooth out your workflows by skipping redundant, costly operations.
20
+
21
+ ## 📦 Installation
22
+
23
+ ```bash
24
+ pip install checkpointer
25
+ ```
26
+
27
+ ## 🚀 Quick Start
28
+
29
+ Apply the `@checkpoint` decorator to any function:
30
+
31
+ ```python
32
+ from checkpointer import checkpoint
33
+
34
+ @checkpoint
35
+ def expensive_function(x: int) -> int:
36
+ print("Computing...")
37
+ return x ** 2
38
+
39
+ result = expensive_function(4) # Computes and stores the result
40
+ result = expensive_function(4) # Loads from the cache
41
+ ```
42
+
43
+ ## 🧠 How It Works
44
+
45
+ When a `@checkpoint`-decorated function is called, `checkpointer` computes a unique identifier for the call. This identifier derives from the function's source code, its dependencies, captured variables, and the arguments passed.
46
+
47
+ It then tries to retrieve a cached result using this identifier. If a valid cached result is found, it's returned immediately. Otherwise, the original function executes, its result is stored, and then returned.
48
+
49
+ ### 🚨 What Triggers Cache Invalidation?
50
+
51
+ `checkpointer` maintains cache correctness using two types of hashes:
52
+
53
+ #### 1. Function Identity Hash (One-Time per Function)
54
+
55
+ This hash represents the decorated function itself and is computed once (usually on first invocation). It covers:
56
+
57
+ * **Decorated Function's Code:**\
58
+ The function's logic and signature (excluding parameter type annotations) are hashed. Formatting changes like whitespace, newlines, comments, or trailing commas do **not** cause invalidation.
59
+
60
+ * **Dependencies:**\
61
+ All user-defined functions and methods the function calls or uses are included recursively. Dependencies are detected by:
62
+ * Inspecting the function's global scope for referenced functions/objects.
63
+ * Inferring from argument type annotations.
64
+ * Analyzing object constructions and method calls to identify classes and methods used.
65
+
66
+ * **Top-Level Module Code:**\
67
+ Changes unrelated to the function or its dependencies in the module do **not** trigger invalidation.
68
+
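As in the previous version of this README (see the example removed further down in this diff), a change to a shared helper invalidates every decorated function that depends on it:

```python
from checkpointer import checkpoint

def multiply(a, b):
    return a * b

@checkpoint
def helper(x):
    # Depends on `multiply`
    return multiply(x + 1, 2)

@checkpoint
def compute(a, b):
    # Depends on `helper` and `multiply`
    return helper(a) + helper(b)
```

If `multiply` is modified, the identity hashes of both `helper` and `compute` change, so their caches are invalidated and recomputed.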
69
+ #### 2. Call Hash (Computed on Every Function Call)
70
+
71
+ Each function call's cache key (the **call hash**) combines:
72
+
73
+ * **Passed Arguments:**\
74
+ Includes positional and keyword arguments, combined with default values. Changing defaults alone doesn't necessarily trigger invalidation unless it affects actual call values.
75
+
76
+ * **Captured Global Variables:**\
77
+ When `capture=True` or explicit capture annotations are used, `checkpointer` hashes global variables referenced by the function:
78
+ * `CaptureMe` variables are hashed on every call, so changes trigger invalidation.
79
+ * `CaptureMeOnce` variables are hashed once per session for performance optimization.
80
+
81
+ * **Custom Argument Hashing:**\
82
+ Using `HashBy` annotations, arguments or captured variables can be transformed before hashing (e.g., sorting lists to ignore order), allowing more precise or efficient call hashes.
83
+
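A small sketch (not part of the original README) of how equivalent calls collapse to one call hash; `get_call_hash` is the helper on decorated functions shown in the `checkpoint.py` diff above, and positional arguments are mapped to parameter names and merged with defaults before hashing:

```python
from checkpointer import checkpoint

@checkpoint
def expensive_function(x: int, factor: int = 2) -> int:
    return x ** factor

# All three spellings resolve to the same named argument values,
# so they produce the same call hash and hit the same checkpoint.
h1 = expensive_function.get_call_hash(4)
h2 = expensive_function.get_call_hash(4, 2)
h3 = expensive_function.get_call_hash(x=4, factor=2)
assert h1 == h2 == h3
```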
84
+ ## 💡 Usage
85
+
86
+ Once a function is decorated with `@checkpoint`, you can interact with its caching behavior using the following methods:
87
+
88
+ * **`expensive_function(...)`**:\
89
+ Call the function normally. This will compute and cache the result or load it from cache.
90
+
91
+ * **`expensive_function.rerun(...)`**:\
92
+ Force the original function to execute and overwrite any existing cached result.
93
+
94
+ * **`expensive_function.fn(...)`**:\
95
+ Call the undecorated function directly, bypassing the cache (useful in recursion to prevent caching intermediate steps).
96
+
97
+ * **`expensive_function.get(...)`**:\
98
+ Retrieve the cached result without executing the function. Raises `CheckpointError` if no valid cache exists.
99
+
100
+ * **`expensive_function.exists(...)`**:\
101
+ Check if a cached result exists without computing or loading it.
102
+
103
+ * **`expensive_function.delete(...)`**:\
104
+ Remove the cached entry for given arguments.
105
+
106
+ * **`expensive_function.reinit(recursive: bool = False)`**:\
107
+ Recalculate the function identity hash and recapture `CaptureMeOnce` variables, updating the cached function state within the same Python session.
108
+
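A short sketch tying these methods together, reusing `expensive_function` from the Quick Start (this assumes `CheckpointError` is importable from the package root, as the error name above suggests):

```python
from checkpointer import checkpoint, CheckpointError

@checkpoint
def expensive_function(x: int) -> int:
    print("Computing...")
    return x ** 2

expensive_function(4)         # computes and caches 16
expensive_function.exists(4)  # True: a checkpoint now exists for x=4
expensive_function.get(4)     # 16, loaded without running the function
expensive_function.rerun(4)   # recomputes and overwrites the cached value
expensive_function.fn(4)      # runs the undecorated function, bypassing the cache
expensive_function.delete(4)  # removes the checkpoint for x=4

try:
    expensive_function.get(4)  # nothing is cached anymore
except CheckpointError:
    pass
```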
109
+ ## ⚙️ Configuration & Customization
110
+
111
+ The `@checkpoint` decorator accepts the following parameters:
112
+
113
+ * **`storage`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)\
114
+ Storage backend to use: `"pickle"` (disk-based, persistent), `"memory"` (in-memory, non-persistent), or a custom `Storage` class.
115
+
116
+ * **`directory`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)\
117
+ Base directory for disk-based checkpoints (only for `"pickle"` storage).
118
+
119
+ * **`when`** (Type: `bool`, Default: `True`)\
120
+ Enable or disable checkpointing dynamically, useful for environment-based toggling.
121
+
122
+ * **`capture`** (Type: `bool`, Default: `False`)\
123
+ If `True`, includes global variables referenced by the function in call hashes (except those excluded via `NoHash`).
124
+
125
+ * **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
126
+ A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
127
+
128
+ * **`fn_hash_from`** (Type: `Any`, Default: `None`)\
129
+ Override the computed function identity hash with any hashable object you provide (e.g., version strings, config IDs). This gives you explicit control over the function's version and when its cache should be invalidated.
130
+
131
+ * **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
132
+ Controls the level of logging output from `checkpointer`.
133
+ * `0`: No output.
134
+ * `1`: Shows when functions are computed and cached.
135
+ * `2`: Also shows when cached results are remembered (loaded from cache).
136
+
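Putting several of these options together (the directory and expiry rule below are illustrative choices, not defaults, and the expiry check assumes naive timestamps):

```python
from datetime import datetime, timedelta
from checkpointer import checkpoint

@checkpoint(
    storage="pickle",                 # disk-based backend (the default)
    directory="/tmp/my-checkpoints",  # where checkpoints are written
    verbosity=2,                      # also log cache hits
    should_expire=lambda stored_at: datetime.now() - stored_at > timedelta(days=7),
)
def expensive_function(x: int) -> int:
    return x ** 2
```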
137
+ ## 🔬 Customize Argument Hashing
138
+
139
+ You can customize how arguments are hashed without modifying the actual argument values to improve cache hit rates or speed up hashing.
140
+
141
+ * **`Annotated[T, HashBy[fn]]`**:\
142
+ Transform the argument via `fn(argument)` before hashing. Useful for normalization (e.g., sorting lists) or optimized hashing for complex inputs.
143
+
144
+ * **`NoHash[T]`**:\
145
+ Exclude the argument from hashing completely, so changes to it won't trigger cache invalidation.
146
+
147
+ **Example:**
148
+
149
+ ```python
150
+ from typing import Annotated
151
+ from checkpointer import checkpoint, HashBy, NoHash
152
+ from pathlib import Path
153
+ import logging
154
+
155
+ def file_bytes(path: Path) -> bytes:
156
+ return path.read_bytes()
157
+
158
+ @checkpoint
159
+ def process(
160
+ numbers: Annotated[list[int], HashBy[sorted]], # Hash by sorted list
161
+ data_file: Annotated[Path, HashBy[file_bytes]], # Hash by file content
162
+ log: NoHash[logging.Logger], # Exclude logger from hashing
163
+ ):
164
+ ...
165
+ ```
166
+
167
+ In this example, the hash for `numbers` ignores order, `data_file` is hashed based on its contents rather than path, and changes to `log` don't affect caching.
168
+
169
+ ## 🎯 Capturing Global Variables
170
+
171
+ `checkpointer` can include **captured global variables** in call hashes - these are globals your function reads during execution that may affect results.
172
+
173
+ Use `capture=True` on `@checkpoint` to capture **all** referenced globals (except those explicitly excluded with `NoHash`).
174
+
175
+ Alternatively, you can **opt-in selectively** by annotating globals with:
176
+
177
+ * **`CaptureMe[T]`**:\
178
+ Capture the variable on every call (triggers invalidation on changes).
179
+
180
+ * **`CaptureMeOnce[T]`**:\
181
+ Capture once per Python session (for expensive, immutable globals).
182
+
183
+ You can also combine these with `HashBy` to customize how captured variables are hashed (e.g., hash by subset of attributes).
184
+
185
+ **Example:**
186
+
187
+ ```python
188
+ from typing import Annotated
189
+ from checkpointer import checkpoint, CaptureMe, CaptureMeOnce, HashBy
190
+ from pathlib import Path
191
+
192
+ def file_bytes(path: Path) -> bytes:
193
+ return path.read_bytes()
194
+
195
+ captured_data: CaptureMe[Annotated[Path, HashBy[file_bytes]]] = Path("data.txt")
196
+ session_config: CaptureMeOnce[dict] = {"mode": "prod"}
197
+
198
+ @checkpoint
199
+ def process():
200
+ # `captured_data` is included in the call hash on every call, hashed by file content
201
+ # `session_config` is hashed once per session
202
+ ...
203
+ ```
204
+
205
+ ## 🗄️ Custom Storage Backends
206
+
207
+ Implement your own storage backend by subclassing `checkpointer.Storage` and overriding required methods.
208
+
209
+ Within storage methods, `call_hash` identifies calls by arguments. Use `self.fn_id()` to get function identity (name + hash/version), important for organizing checkpoints.
210
+
211
+ **Example:**
212
+
213
+ ```python
214
+ from checkpointer import checkpoint, Storage
215
+ from datetime import datetime
216
+
217
+ class MyCustomStorage(Storage):
218
+ def exists(self, call_hash):
219
+ fn_dir = self.checkpointer.directory / self.fn_id()
220
+ return (fn_dir / call_hash).exists()
221
+
222
+ def store(self, call_hash, data):
223
+ ... # Store serialized data
224
+ return data # Must return data to checkpointer
225
+
226
+ def checkpoint_date(self, call_hash): ...
227
+ def load(self, call_hash): ...
228
+ def delete(self, call_hash): ...
229
+
230
+ @checkpoint(storage=MyCustomStorage)
231
+ def custom_cached_function(x: int):
232
+ return x ** 2
233
+ ```
234
+
235
+ ## ⚡ Async Support
236
+
237
+ `checkpointer` works with Python's `asyncio` and other async runtimes.
238
+
239
+ ```python
240
+ import asyncio
241
+ from checkpointer import checkpoint
242
+
243
+ @checkpoint
244
+ async def async_compute_sum(a: int, b: int) -> int:
245
+ print(f"Asynchronously computing {a} + {b}...")
246
+ await asyncio.sleep(1)
247
+ return a + b
248
+
249
+ async def main():
250
+ result1 = await async_compute_sum(3, 7)
251
+ print(f"Result 1: {result1}")
252
+
253
+ result2 = await async_compute_sum(3, 7)
254
+ print(f"Result 2: {result2}")
255
+
256
+ result3 = async_compute_sum.get(3, 7)
257
+ print(f"Result 3 (from cache): {result3}")
258
+
259
+ asyncio.run(main())
260
+ ```
checkpointer-2.13.0.dist-info/RECORD ADDED
@@ -0,0 +1,18 @@
1
+ checkpointer/__init__.py,sha256=l14EbRTgmkPlJUJc-5uAjWmB4dQr6kXXOPAjHcqbKK8,890
2
+ checkpointer/checkpoint.py,sha256=Ylf0Yel9WiyHg7EoP355b2huS_yh0hgBUx2l3bOyI_c,10149
3
+ checkpointer/fn_ident.py,sha256=mZfGPSIidlZVHsG3Kc6jk8rTNDoN_k2tCkJmbSKRhdg,6921
4
+ checkpointer/fn_string.py,sha256=R1evcaBKoVP9SSDd741O1FoaC9SiaA3-kfn6QLxd9No,2532
5
+ checkpointer/import_mappings.py,sha256=ESqWvZTzYAmaVnJ6NulUvn3_8CInOOPmEKUXO2WD_WA,1794
6
+ checkpointer/object_hash.py,sha256=NXlhME87iA9rrMRjF_Au4SKXFVc14j-SiSEsGhH7M8s,8327
7
+ checkpointer/print_checkpoint.py,sha256=uUQ493fJCaB4nhp4Ox60govSCiBTIPbBX15zt2QiRGo,1356
8
+ checkpointer/types.py,sha256=GFqbGACdDxzQX3bb2LmF9UxQVWOEisGvdtobnqCBAOA,1129
9
+ checkpointer/utils.py,sha256=FEabT0jp7Bx2KN-uco0QDAsBJ6hgauKwEjKR4B8IKzk,2977
10
+ checkpointer/storages/__init__.py,sha256=p-r4YrPXn505_S3qLrSXHSlsEtb13w_DFnCt9IiUomk,296
11
+ checkpointer/storages/memory_storage.py,sha256=aQRSOmAfS0UudubCpv8cdfu2ycM8mlsO9tFMcD2kmgo,1133
12
+ checkpointer/storages/pickle_storage.py,sha256=je1LM2lTSs5yzm25Apg5tJ9jU9T6nXCgD9SlqQRIFaM,1652
13
+ checkpointer/storages/storage.py,sha256=9CZquRjw9lWpcCVCiU7E6P-eJxGBUVbNKWncsTRwmoc,916
14
+ checkpointer-2.13.0.dist-info/METADATA,sha256=vyBK4qpjq4c3S2yCCxjHTF2C5QtuQ3WDTWhEj-etz6U,10925
15
+ checkpointer-2.13.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ checkpointer-2.13.0.dist-info/licenses/ATTRIBUTION.md,sha256=WF6L7-sD4s9t9ytVJOhjhpDoZ6TrWpqE3_bMdDIeJxI,1078
17
+ checkpointer-2.13.0.dist-info/licenses/LICENSE,sha256=drXs6vIb7uW49r70UuMz2A1VtOCl626kiTbcmrar1Xo,1072
18
+ checkpointer-2.13.0.dist-info/RECORD,,
checkpointer-2.12.0.dist-info/METADATA DELETED
@@ -1,236 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: checkpointer
3
- Version: 2.12.0
4
- Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
5
- Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
- Author: Hampus Hallman
7
- License-Expression: MIT
8
- License-File: ATTRIBUTION.md
9
- License-File: LICENSE
10
- Keywords: async,cache,caching,code-aware,decorator,fast,hashing,invalidation,memoization,memoize,memory,optimization,performance,workflow
11
- Classifier: Programming Language :: Python :: 3.11
12
- Classifier: Programming Language :: Python :: 3.12
13
- Classifier: Programming Language :: Python :: 3.13
14
- Requires-Python: >=3.11
15
- Description-Content-Type: text/markdown
16
-
17
- # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![pypi](https://img.shields.io/pypi/pyversions/checkpointer)](https://pypi.org/project/checkpointer/)
18
-
19
- `checkpointer` is a Python library offering a decorator-based API for memoizing (caching) function results with code-aware cache invalidation. It works with sync and async functions, supports multiple storage backends, and refreshes caches automatically when your code or dependencies change - helping you maintain correctness, speed up execution, and smooth out your workflows by skipping redundant, costly operations.
20
-
21
- ## 📦 Installation
22
-
23
- ```bash
24
- pip install checkpointer
25
- ```
26
-
27
- ## 🚀 Quick Start
28
-
29
- Apply the `@checkpoint` decorator to any function:
30
-
31
- ```python
32
- from checkpointer import checkpoint
33
-
34
- @checkpoint
35
- def expensive_function(x: int) -> int:
36
- print("Computing...")
37
- return x ** 2
38
-
39
- result = expensive_function(4) # Computes and stores the result
40
- result = expensive_function(4) # Loads from the cache
41
- ```
42
-
43
- ## 🧠 How It Works
44
-
45
- When a `@checkpoint`-decorated function is called, `checkpointer` first computes a unique identifier (hash) for the call. This hash is derived from the function's source code, its dependencies, and the arguments passed.
46
-
47
- It then tries to retrieve a cached result using this ID. If a valid cached result is found, it's returned immediately. Otherwise, the original function executes, its result is stored, and then returned.
48
-
49
- Cache validity is determined by this function's hash, which automatically updates if:
50
-
51
- * **Function Code Changes**: The decorated function's source code is modified.
52
- * **Dependencies Change**: Any user-defined function in its dependency tree (direct or indirect, even across modules or not decorated) is modified.
53
- * **Captured Variables Change** (with `capture=True`): Global or closure-based variables used within the function are altered.
54
-
55
- **Example: Dependency Invalidation**
56
-
57
- ```python
58
- def multiply(a, b):
59
- return a * b
60
-
61
- @checkpoint
62
- def helper(x):
63
- # Depends on `multiply`
64
- return multiply(x + 1, 2)
65
-
66
- @checkpoint
67
- def compute(a, b):
68
- # Depends on `helper` and `multiply`
69
- return helper(a) + helper(b)
70
- ```
71
-
72
- If `multiply` is modified, caches for both `helper` and `compute` are automatically invalidated and recomputed.
73
-
74
- ## 💡 Usage
75
-
76
- Once a function is decorated with `@checkpoint`, you can interact with its caching behavior using the following methods:
77
-
78
- * **`expensive_function(...)`**:\
79
- Call the function normally. This will either compute and cache the result or load it from the cache if available.
80
-
81
- * **`expensive_function.rerun(...)`**:\
82
- Forces the original function to execute, compute a new result, and overwrite any existing cached value for the given arguments.
83
-
84
- * **`expensive_function.fn(...)`**:\
85
- Calls the original, undecorated function directly, bypassing the cache entirely. This is particularly useful within recursive functions to prevent caching intermediate steps.
86
-
87
- * **`expensive_function.get(...)`**:\
88
- Attempts to retrieve the cached result for the given arguments without executing the original function. Raises `CheckpointError` if no valid cached result exists.
89
-
90
- * **`expensive_function.exists(...)`**:\
91
- Checks if a cached result exists for the given arguments without attempting to compute or load it. Returns `True` if a valid checkpoint exists, `False` otherwise.
92
-
93
- * **`expensive_function.delete(...)`**:\
94
- Removes the cached entry for the specified arguments.
95
-
96
- * **`expensive_function.reinit()`**:\
97
- Recalculates the function's internal hash. This is primarily used when `capture=True` and you need to update the cache based on changes to external variables within the same Python session.
98
-
99
- ## ⚙️ Configuration & Customization
100
-
101
- The `@checkpoint` decorator accepts the following parameters to customize its behavior:
102
-
103
- * **`format`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)\
104
- Defines the storage backend to use. Built-in options are `"pickle"` (disk-based, persistent) and `"memory"` (in-memory, non-persistent). You can also provide a custom `Storage` class.
105
-
106
- * **`root_path`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)\
107
- The base directory for storing disk-based checkpoints. This parameter is only relevant when `format` is set to `"pickle"`.
108
-
109
- * **`when`** (Type: `bool`, Default: `True`)\
110
- A boolean flag to enable or disable checkpointing for the decorated function. This is particularly useful for toggling caching based on environment variables (e.g., `when=os.environ.get("ENABLE_CACHING", "false").lower() == "true"`).
111
-
112
- * **`capture`** (Type: `bool`, Default: `False`)\
113
- If set to `True`, `checkpointer` includes global or closure-based variables used by the function in its hash calculation. This ensures that changes to these external variables also trigger cache invalidation and recomputation.
114
-
115
- * **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
116
- A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
117
-
118
- * **`fn_hash_from`** (Type: `Any`, Default: `None`)\
119
- This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any object relevant to your invalidation logic (e.g., version strings, config parameters). The object you provide will be hashed internally by `checkpointer`.
120
-
121
- * **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
122
- Controls the level of logging output from `checkpointer`.
123
- * `0`: No output.
124
- * `1`: Shows when functions are computed and cached.
125
- * `2`: Also shows when cached results are remembered (loaded from cache).
126
-
127
- ## 🔬 Customize Argument Hashing
128
-
129
- You can customize how individual function arguments are hashed without changing their actual values when passed in.
130
-
131
- * **`Annotated[T, HashBy[fn]]`**:\
132
- Hashes the argument by applying `fn(argument)` before hashing. This enables custom normalization (e.g., sorting lists to ignore order) or optimized hashing for complex types, improving cache hit rates or speeding up hashing.
133
-
134
- * **`NoHash[T]`**:\
135
- Completely excludes the argument from hashing, so changes to it won’t trigger cache invalidation.
136
-
137
- **Example:**
138
-
139
- ```python
140
- from typing import Annotated
141
- from checkpointer import checkpoint, HashBy, NoHash
142
- from pathlib import Path
143
- import logging
144
-
145
- def file_bytes(path: Path) -> bytes:
146
- return path.read_bytes()
147
-
148
- @checkpoint
149
- def process(
150
- numbers: Annotated[list[int], HashBy[sorted]], # Hash by sorted list
151
- data_file: Annotated[Path, HashBy[file_bytes]], # Hash by file content
152
- log: NoHash[logging.Logger], # Exclude logger from hashing
153
- ):
154
- ...
155
- ```
156
-
157
- In this example, the cache key for `numbers` ignores order, `data_file` is hashed based on its contents rather than path, and changes to `log` don’t affect caching.
158
-
159
- ## 🗄️ Custom Storage Backends
160
-
161
- For integration with databases, cloud storage, or custom serialization, implement your own storage backend by inheriting from `checkpointer.Storage` and implementing its abstract methods.
162
-
163
- Within custom storage methods, `call_hash` identifies calls by arguments. Use `self.fn_id()` to get the function's unique identity (name + hash/version), crucial for organizing stored checkpoints (e.g., by function version). Access global `Checkpointer` config via `self.checkpointer`.
164
-
165
- **Example: Custom Storage Backend**
166
-
167
- ```python
168
- from checkpointer import checkpoint, Storage
169
- from datetime import datetime
170
-
171
- class MyCustomStorage(Storage):
172
- def exists(self, call_hash):
173
- # Example: Constructing a path based on function ID and call ID
174
- fn_dir = self.checkpointer.root_path / self.fn_id()
175
- return (fn_dir / call_hash).exists()
176
-
177
- def store(self, call_hash, data):
178
- ... # Store the serialized data for `call_hash`
179
- return data # This method must return the data back to checkpointer
180
-
181
- def checkpoint_date(self, call_hash): ...
182
- def load(self, call_hash): ...
183
- def delete(self, call_hash): ...
184
-
185
- @checkpoint(format=MyCustomStorage)
186
- def custom_cached_function(x: int):
187
- return x ** 2
188
- ```
189
-
190
- ## 🧱 Layered Caching
191
-
192
- You can apply multiple `@checkpoint` decorators to a single function to create layered caching strategies. `checkpointer` processes these decorators from bottom to top, meaning the decorator closest to the function definition is evaluated first.
193
-
194
- This is useful for scenarios like combining a fast, ephemeral cache (e.g., in-memory) with a persistent, slower cache (e.g., disk-based).
195
-
196
- **Example: Memory Cache over Disk Cache**
197
-
198
- ```python
199
- from checkpointer import checkpoint
200
-
201
- @checkpoint(format="memory") # Layer 2: Fast, ephemeral in-memory cache
202
- @checkpoint(format="pickle") # Layer 1: Persistent disk cache
203
- def some_expensive_operation():
204
- print("Performing a time-consuming operation...")
205
- return sum(i for i in range(10**7))
206
- ```
207
-
208
- ## ⚡ Async Support
209
-
210
- `checkpointer` works seamlessly with Python's `asyncio` and other async runtimes.
211
-
212
- ```python
213
- import asyncio
214
- from checkpointer import checkpoint
215
-
216
- @checkpoint
217
- async def async_compute_sum(a: int, b: int) -> int:
218
- print(f"Asynchronously computing {a} + {b}...")
219
- await asyncio.sleep(1)
220
- return a + b
221
-
222
- async def main():
223
- # First call computes and caches
224
- result1 = await async_compute_sum(3, 7)
225
- print(f"Result 1: {result1}")
226
-
227
- # Second call loads from cache
228
- result2 = await async_compute_sum(3, 7)
229
- print(f"Result 2: {result2}")
230
-
231
- # Retrieve from cache without re-running the async function
232
- result3 = async_compute_sum.get(3, 7)
233
- print(f"Result 3 (from cache): {result3}")
234
-
235
- asyncio.run(main())
236
- ```
checkpointer-2.12.0.dist-info/RECORD DELETED
@@ -1,17 +0,0 @@
1
- checkpointer/__init__.py,sha256=x8LbL-URPg-afn0O-HMxPsJkV9cmW-Cw5epKSCGClOM,889
2
- checkpointer/checkpoint.py,sha256=2K2D_aYHpI1XBcNrWQxta06wRdUj81TeCXshBVxl7cA,9869
3
- checkpointer/fn_ident.py,sha256=4qg4NIvcWRV5GduO70an_iu12dn8LLIlw3Q8F3gm0Mo,6826
4
- checkpointer/import_mappings.py,sha256=ESqWvZTzYAmaVnJ6NulUvn3_8CInOOPmEKUXO2WD_WA,1794
5
- checkpointer/object_hash.py,sha256=MkrwSJJYVlWOjIDpGfYpAeYFPgCJhXHYBdlKVcUy2kw,8145
6
- checkpointer/print_checkpoint.py,sha256=uUQ493fJCaB4nhp4Ox60govSCiBTIPbBX15zt2QiRGo,1356
7
- checkpointer/types.py,sha256=GFqbGACdDxzQX3bb2LmF9UxQVWOEisGvdtobnqCBAOA,1129
8
- checkpointer/utils.py,sha256=rq7AjR0WJql1o8clKMdXj4p3wrPYMLyS6G11nvNNjFE,2934
9
- checkpointer/storages/__init__.py,sha256=p-r4YrPXn505_S3qLrSXHSlsEtb13w_DFnCt9IiUomk,296
10
- checkpointer/storages/memory_storage.py,sha256=aQRSOmAfS0UudubCpv8cdfu2ycM8mlsO9tFMcD2kmgo,1133
11
- checkpointer/storages/pickle_storage.py,sha256=je1LM2lTSs5yzm25Apg5tJ9jU9T6nXCgD9SlqQRIFaM,1652
12
- checkpointer/storages/storage.py,sha256=5Jel7VlmCG9gnIbZFKT1NrEiAePz8ZD8hfsD-tYEiP4,905
13
- checkpointer-2.12.0.dist-info/METADATA,sha256=B-R8Aq2caTM7fclANY7MJVZJQaU0nB1Mdf7wMr6HSTI,10705
14
- checkpointer-2.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- checkpointer-2.12.0.dist-info/licenses/ATTRIBUTION.md,sha256=WF6L7-sD4s9t9ytVJOhjhpDoZ6TrWpqE3_bMdDIeJxI,1078
16
- checkpointer-2.12.0.dist-info/licenses/LICENSE,sha256=drXs6vIb7uW49r70UuMz2A1VtOCl626kiTbcmrar1Xo,1072
17
- checkpointer-2.12.0.dist-info/RECORD,,