checkpointer-2.10.1-py3-none-any.whl → checkpointer-2.11.0-py3-none-any.whl

checkpointer/__init__.py CHANGED
@@ -4,18 +4,18 @@ from typing import Callable
  from .checkpoint import CachedFunction, Checkpointer, CheckpointError
  from .object_hash import ObjectHash
  from .storages import MemoryStorage, PickleStorage, Storage
- from .utils import AwaitableValue
+ from .types import AwaitableValue, HashBy, NoHash

  checkpoint = Checkpointer()
  capture_checkpoint = Checkpointer(capture=True)
  memory_checkpoint = Checkpointer(format="memory", verbosity=0)
  tmp_checkpoint = Checkpointer(root_path=f"{tempfile.gettempdir()}/checkpoints")
- static_checkpoint = Checkpointer(fn_hash=ObjectHash())
+ static_checkpoint = Checkpointer(fn_hash_from=())

  def cleanup_all(invalidated=True, expired=True):
    for obj in gc.get_objects():
      if isinstance(obj, CachedFunction):
        obj.cleanup(invalidated=invalidated, expired=expired)

- def get_function_hash(fn: Callable, capture=False) -> str:
-   return CachedFunction(Checkpointer(capture=capture), fn).fn_hash
+ def get_function_hash(fn: Callable) -> str:
+   return CachedFunction(Checkpointer(), fn).ident.fn_hash
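Taken together, these changes mean `get_function_hash` no longer accepts a `capture` argument and the `fn_hash` option is replaced by `fn_hash_from`. A minimal usage sketch (the function and the values passed to `fn_hash_from` are illustrative, not taken from the package):

```python
from checkpointer import Checkpointer, get_function_hash

def my_fn(x: int) -> int:
    return x * 2

# 2.11.0 drops the `capture` parameter from get_function_hash:
fn_hash = get_function_hash(my_fn)

# fn_hash=ObjectHash(...) is replaced by fn_hash_from, which accepts any object
# and hashes it internally to decide when the cache should be invalidated:
versioned_checkpoint = Checkpointer(fn_hash_from=("my-pipeline", "v2"))

@versioned_checkpoint
def compute(x: int) -> int:
    return x * 2
```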
checkpointer/checkpoint.py CHANGED
@@ -1,18 +1,20 @@
  from __future__ import annotations
- import inspect
  import re
  from datetime import datetime
  from functools import cached_property, update_wrapper
+ from inspect import Parameter, Signature, iscoroutine, signature
  from pathlib import Path
  from typing import (
-   Awaitable, Callable, Concatenate, Generic, Iterable, Literal,
-   ParamSpec, Self, Type, TypedDict, TypeVar, Unpack, cast, overload,
+   Annotated, Callable, Concatenate, Coroutine, Generic,
+   Iterable, Literal, ParamSpec, Self, Type, TypedDict,
+   TypeVar, Unpack, cast, get_args, get_origin, overload,
  )
- from .fn_ident import get_fn_ident
+ from .fn_ident import RawFunctionIdent, get_fn_ident
  from .object_hash import ObjectHash
  from .print_checkpoint import print_checkpoint
  from .storages import STORAGE_MAP, Storage
- from .utils import AwaitableValue, unwrap_fn
+ from .types import AwaitableValue, HashBy
+ from .utils import unwrap_fn

  Fn = TypeVar("Fn", bound=Callable)
  P = ParamSpec("P")
@@ -29,10 +31,9 @@ class CheckpointerOpts(TypedDict, total=False):
    root_path: Path | str | None
    when: bool
    verbosity: Literal[0, 1, 2]
-   hash_by: Callable | None
    should_expire: Callable[[datetime], bool] | None
    capture: bool
-   fn_hash: ObjectHash | None
+   fn_hash_from: object

  class Checkpointer:
    def __init__(self, **opts: Unpack[CheckpointerOpts]):
@@ -40,10 +41,9 @@ class Checkpointer:
      self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
      self.when = opts.get("when", True)
      self.verbosity = opts.get("verbosity", 1)
-     self.hash_by = opts.get("hash_by")
      self.should_expire = opts.get("should_expire")
      self.capture = opts.get("capture", False)
-     self.fn_hash = opts.get("fn_hash")
+     self.fn_hash_from = opts.get("fn_hash_from")

    @overload
    def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CachedFunction[Fn]: ...
@@ -56,6 +56,35 @@ class Checkpointer:

      return CachedFunction(self, fn) if callable(fn) else self

+ class FunctionIdent:
+   """
+   Represents the identity and hash state of a cached function.
+   Separated from CachedFunction to prevent hash desynchronization
+   among bound instances when `.reinit()` is called.
+   """
+   def __init__(self, cached_fn: CachedFunction):
+     self.__dict__.clear()
+     self.cached_fn = cached_fn
+
+   @cached_property
+   def raw_ident(self) -> RawFunctionIdent:
+     return get_fn_ident(unwrap_fn(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
+
+   @cached_property
+   def fn_hash(self) -> str:
+     if self.cached_fn.checkpointer.fn_hash_from is not None:
+       return str(ObjectHash(self.cached_fn.checkpointer.fn_hash_from, digest_size=16))
+     deep_hashes = [depend.ident.raw_ident.fn_hash for depend in self.cached_fn.deep_depends()]
+     return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
+
+   @cached_property
+   def captured_hash(self) -> str:
+     deep_hashes = [depend.ident.raw_ident.captured_hash for depend in self.cached_fn.deep_depends()]
+     return str(ObjectHash().write_text(iter=deep_hashes))
+
+   def clear(self):
+     self.__init__(self.cached_fn)
+
  class CachedFunction(Generic[Fn]):
    def __init__(self, checkpointer: Checkpointer, fn: Fn):
      wrapped = unwrap_fn(fn)
@@ -70,6 +99,14 @@ class CachedFunction(Generic[Fn]):
      self.cleanup = self.storage.cleanup
      self.bound = ()

+     sig = signature(wrapped)
+     params = list(sig.parameters.items())
+     pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
+     self.arg_names = [name for name, param in params if param.kind in pos_params]
+     self.default_args = {name: param.default for name, param in params if param.default is not Parameter.empty}
+     self.hash_by_map = get_hash_by_map(sig)
+     self.ident = FunctionIdent(self)
+
    @overload
    def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
    @overload
@@ -82,41 +119,32 @@ class CachedFunction(Generic[Fn]):
      bound_fn.bound = (instance,)
      return bound_fn

-   @cached_property
-   def ident_tuple(self) -> tuple[str, list[Callable]]:
-     return get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
-
-   @property
-   def fn_hash_raw(self) -> str:
-     return self.ident_tuple[0]
-
    @property
    def depends(self) -> list[Callable]:
-     return self.ident_tuple[1]
-
-   @cached_property
-   def fn_hash(self) -> str:
-     deep_hashes = [depend.fn_hash_raw for depend in self.deep_depends()]
-     fn_hash = ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes)
-     return str(self.checkpointer.fn_hash or fn_hash)[:32]
+     return self.ident.raw_ident.depends

    def reinit(self, recursive=False) -> CachedFunction[Fn]:
-     depends = list(self.deep_depends()) if recursive else [self]
-     for depend in depends:
-       self.__dict__.pop("fn_hash", None)
-       self.__dict__.pop("ident_tuple", None)
-     for depend in depends:
-       depend.fn_hash
+     depend_idents = [depend.ident for depend in self.deep_depends()] if recursive else [self.ident]
+     for ident in depend_idents: ident.clear()
+     for ident in depend_idents: ident.fn_hash
      return self

-   def get_call_id(self, args: tuple, kw: dict) -> str:
+   def get_call_id(self, args: tuple, kw: dict[str, object]) -> str:
      args = self.bound + args
-     hash_by = self.checkpointer.hash_by
-     hash_params = hash_by(*args, **kw) if hash_by else (args, kw)
-     return str(ObjectHash(hash_params, digest_size=16))
-
-   async def _resolve_awaitable(self, call_id: str, awaitable: Awaitable):
-     return self.storage.store(call_id, AwaitableValue(await awaitable)).value
+     pos_args = args[len(self.arg_names):]
+     named_pos_args = dict(zip(self.arg_names, args))
+     named_args = {**self.default_args, **named_pos_args, **kw}
+     if hash_by_map := self.hash_by_map:
+       rest_hash_by = hash_by_map.get(b"**")
+       for key, value in named_args.items():
+         if hash_by := hash_by_map.get(key, rest_hash_by):
+           named_args[key] = hash_by(value)
+       if pos_hash_by := hash_by_map.get(b"*"):
+         pos_args = tuple(map(pos_hash_by, pos_args))
+     return str(ObjectHash(named_args, pos_args, self.ident.captured_hash, digest_size=16))
+
+   async def _resolve_coroutine(self, call_id: str, coroutine: Coroutine):
+     return self.storage.store(call_id, AwaitableValue(await coroutine)).value

    def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
      full_args = self.bound + args
@@ -125,7 +153,7 @@ class CachedFunction(Generic[Fn]):
        return self.fn(*full_args, **kw)

      call_id = self.get_call_id(args, kw)
-     call_id_long = f"{self.fn_dir}/{self.fn_hash}/{call_id}"
+     call_id_long = f"{self.fn_dir}/{self.ident.fn_hash}/{call_id}"

      refresh = rerun \
        or not self.storage.exists(call_id) \
@@ -134,8 +162,8 @@ class CachedFunction(Generic[Fn]):
      if refresh:
        print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id_long, "blue")
        data = self.fn(*full_args, **kw)
-       if inspect.isawaitable(data):
-         return self._resolve_awaitable(call_id, data)
+       if iscoroutine(data):
+         return self._resolve_coroutine(call_id, data)
        return self.storage.store(call_id, data)

      try:
@@ -154,7 +182,7 @@ class CachedFunction(Generic[Fn]):
      return self._call(args, kw, True)

    @overload
-   def get(self: Callable[P, Awaitable[R]], *args: P.args, **kw: P.kwargs) -> R: ...
+   def get(self: Callable[P, Coroutine[object, object, R]], *args: P.args, **kw: P.kwargs) -> R: ...
    @overload
    def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
    def get(self, *args, **kw):
@@ -172,7 +200,7 @@ class CachedFunction(Generic[Fn]):
      self.storage.delete(self.get_call_id(args, kw))

    def __repr__(self) -> str:
-     return f"<CachedFunction {self.fn.__name__} {self.fn_hash[:6]}>"
+     return f"<CachedFunction {self.fn.__name__} {self.ident.fn_hash[:6]}>"

    def deep_depends(self, visited: set[CachedFunction] = set()) -> Iterable[CachedFunction]:
      if self not in visited:
@@ -182,3 +210,20 @@ class CachedFunction(Generic[Fn]):
        for depend in self.depends:
          if isinstance(depend, CachedFunction):
            yield from depend.deep_depends(visited)
+
+ def hash_by_from_annotation(annotation: type) -> Callable[[object], object] | None:
+   if get_origin(annotation) is Annotated:
+     args = get_args(annotation)
+     metadata = args[1] if len(args) > 1 else None
+     if get_origin(metadata) is HashBy:
+       return get_args(metadata)[0]
+
+ def get_hash_by_map(sig: Signature) -> dict[str | bytes, Callable[[object], object]]:
+   hash_by_map = {}
+   for name, param in sig.parameters.items():
+     if param.kind == Parameter.VAR_POSITIONAL:
+       name = b"*"
+     elif param.kind == Parameter.VAR_KEYWORD:
+       name = b"**"
+     hash_by_map[name] = hash_by_from_annotation(param.annotation)
+   return hash_by_map if any(hash_by_map.values()) else {}
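The rewritten `get_call_id` normalizes a call into a mapping of parameter names to values (defaults merged in, positional arguments mapped onto their names, `*args` overflow kept separate) before hashing, so `f(1)` and `f(a=1)` produce the same call id. A self-contained sketch of that normalization step, using only `inspect.signature`; the helper name below is illustrative, not part of the package:

```python
from inspect import Parameter, signature

def normalize_call(fn, args: tuple, kw: dict) -> tuple[dict, tuple]:
    # Map positional arguments onto parameter names and merge in defaults,
    # so equivalent call styles normalize to the same structure.
    params = list(signature(fn).parameters.items())
    pos_kinds = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
    arg_names = [name for name, p in params if p.kind in pos_kinds]
    defaults = {name: p.default for name, p in params if p.default is not Parameter.empty}
    named = {**defaults, **dict(zip(arg_names, args)), **kw}
    extra_pos = args[len(arg_names):]  # overflow that lands in *args, hashed separately
    return named, extra_pos

def example(a, b=2, *rest, c=3):
    ...

# Same normalized form regardless of how the arguments were passed:
assert normalize_call(example, (1,), {}) == normalize_call(example, (), {"a": 1})
```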
checkpointer/fn_ident.py CHANGED
@@ -1,15 +1,19 @@
  import dis
  import inspect
- from collections.abc import Callable
  from itertools import takewhile
  from pathlib import Path
  from types import CodeType, FunctionType, MethodType
- from typing import Any, Iterable, Type, TypeGuard
+ from typing import Callable, Iterable, NamedTuple, Type, TypeGuard
  from .object_hash import ObjectHash
- from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, transpose, unwrap_fn
+ from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, unwrap_fn

  cwd = Path.cwd().resolve()

+ class RawFunctionIdent(NamedTuple):
+   fn_hash: str
+   captured_hash: str
+   depends: list[Callable]
+
  def is_class(obj) -> TypeGuard[Type]:
    # isinstance works too, but needlessly triggers _lazyinit()
    return issubclass(type(obj), type)
@@ -33,7 +37,7 @@ def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[st
        scope_obj = None
    return classvars

- def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...], Any]]:
+ def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...], object]]:
    classvars = extract_classvars(code, scope_vars)
    scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
    for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
@@ -55,7 +59,7 @@ def get_self_value(fn: Callable) -> type | object | None:
    if is_class(cls):
      return cls

- def get_fn_captured_vals(fn: Callable) -> list[Any]:
+ def get_fn_captured_vals(fn: Callable) -> list[object]:
    self_value = get_self_value(fn)
    scope_vars = AttrDict({
      "LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
@@ -71,24 +75,29 @@ def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
    fn_path = Path(inspect.getfile(candidate_fn)).resolve()
    return cwd in fn_path.parents and ".venv" not in fn_path.parts

- def get_depend_fns(fn: Callable, capture: bool, captured_vals_by_fn: dict[Callable, list[Any]] = {}) -> dict[Callable, list[Any]]:
+ def get_depend_fns(fn: Callable, captured_vals_by_fn: dict[Callable, list[object]] = {}) -> dict[Callable, list[object]]:
    from .checkpoint import CachedFunction
-   captured_vals_by_fn = captured_vals_by_fn or {}
    captured_vals = get_fn_captured_vals(fn)
-   captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)] * capture
-   child_fns = (unwrap_fn(val, cached_fn=True) for val in captured_vals if callable(val))
-   for child_fn in child_fns:
+   captured_vals_by_fn = captured_vals_by_fn or {}
+   captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)]
+   for val in captured_vals:
+     if not callable(val):
+       continue
+     child_fn = unwrap_fn(val, cached_fn=True)
      if isinstance(child_fn, CachedFunction):
        captured_vals_by_fn[child_fn] = []
      elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
-       get_depend_fns(child_fn, capture, captured_vals_by_fn)
+       get_depend_fns(child_fn, captured_vals_by_fn)
    return captured_vals_by_fn

- def get_fn_ident(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
+ def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
    from .checkpoint import CachedFunction
-   captured_vals_by_fn = get_depend_fns(fn, capture)
-   depends, depend_captured_vals = transpose(captured_vals_by_fn.items(), 2)
+   captured_vals_by_fn = get_depend_fns(fn)
+   depend_captured_vals = list(captured_vals_by_fn.values()) * capture
+   depends = captured_vals_by_fn.keys()
    depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
    unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
-   fn_hash = str(ObjectHash(fn, unwrapped_depends).update(depend_captured_vals, tolerate_errors=True))
-   return fn_hash, depends
+   assert fn == unwrapped_depends[0]
+   fn_hash = str(ObjectHash(iter=unwrapped_depends))
+   captured_hash = str(ObjectHash(iter=depend_captured_vals, tolerate_errors=True))
+   return RawFunctionIdent(fn_hash, captured_hash, depends)
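`get_depend_fns` still discovers a function's dependencies by inspecting its bytecode and captured values, but the identity is now split into a code-derived `fn_hash` and a separate `captured_hash` (only populated when `capture=True`). A rough illustration of the underlying idea, finding the global names a function loads, using only the standard `dis` module (the real implementation additionally resolves closure cells and class attributes and recurses into user-defined functions):

```python
import dis

def referenced_globals(fn) -> set[str]:
    # Names the function loads via LOAD_GLOBAL: a crude proxy for its dependencies.
    return {ins.argval for ins in dis.get_instructions(fn.__code__) if ins.opname == "LOAD_GLOBAL"}

def helper(x):
    return x + 1

def pipeline(x):
    return helper(x) * 2

assert "helper" in referenced_globals(pipeline)
```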
checkpointer/object_hash.py CHANGED
@@ -1,16 +1,19 @@
  import ctypes
  import hashlib
+ import inspect
  import io
  import re
  import sys
+ import tokenize
  from collections.abc import Iterable
  from contextlib import nullcontext, suppress
  from decimal import Decimal
+ from io import StringIO
  from itertools import chain
- from pickle import HIGHEST_PROTOCOL as PROTOCOL
+ from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
  from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
- from typing import Any, TypeVar
- from .utils import ContextVar, get_fn_body
+ from typing import Callable, TypeVar
+ from .utils import ContextVar

  np, torch = None, None

@@ -31,16 +34,16 @@ else:
  def encode_type(t: type | FunctionType) -> str:
    return f"{t.__module__}:{t.__qualname__}"

- def encode_type_of(v: Any) -> str:
+ def encode_type_of(v: object) -> str:
    return encode_type(type(v))

  class ObjectHashError(Exception):
-   def __init__(self, obj: Any, cause: Exception):
+   def __init__(self, obj: object, cause: Exception):
      super().__init__(f"{type(cause).__name__} error when hashing {obj}")
      self.obj = obj

  class ObjectHash:
-   def __init__(self, *objs: Any, iter: Iterable[Any] = (), digest_size=64, tolerate_errors=False) -> None:
+   def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerate_errors=False) -> None:
      self.hash = hashlib.blake2b(digest_size=digest_size)
      self.current: dict[int, int] = {}
      self.tolerate_errors = ContextVar(tolerate_errors)
@@ -59,7 +62,7 @@ class ObjectHash:
    def __eq__(self, value: object) -> bool:
      return isinstance(value, ObjectHash) and str(self) == str(value)

-   def nested_hash(self, *objs: Any) -> str:
+   def nested_hash(self, *objs: object) -> str:
      return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()

    def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> "ObjectHash":
@@ -70,10 +73,10 @@ class ObjectHash:
    def write_text(self, *data: str, iter: Iterable[str] = ()) -> "ObjectHash":
      return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))

-   def header(self, *args: Any) -> "ObjectHash":
+   def header(self, *args: object) -> "ObjectHash":
      return self.write_bytes(":".join(map(str, args)).encode())

-   def update(self, *objs: Any, iter: Iterable[Any] = (), tolerate_errors: bool | None=None) -> "ObjectHash":
+   def update(self, *objs: object, iter: Iterable[object] = (), tolerate_errors: bool | None=None) -> "ObjectHash":
      with nullcontext() if tolerate_errors is None else self.tolerate_errors.set(tolerate_errors):
        for obj in chain(objs, iter):
          try:
@@ -81,11 +84,11 @@ class ObjectHash:
          except Exception as ex:
            if self.tolerate_errors.value:
              self.header("error").update(type(ex))
-             continue
-           raise ObjectHashError(obj, ex) from ex
+           else:
+             raise ObjectHashError(obj, ex) from ex
      return self

-   def _update_one(self, obj: Any) -> None:
+   def _update_one(self, obj: object) -> None:
      match obj:
        case None:
          self.header("null")
@@ -142,7 +145,7 @@ class ObjectHash:
        case _ if np and isinstance(obj, np.ndarray):
          self.header("ndarray", encode_type_of(obj), obj.shape, obj.strides).update(obj.dtype)
          if obj.dtype.hasobject:
-           self.update(obj.__reduce_ex__(PROTOCOL))
+           self.update(obj.__reduce_ex__(PICKLE_PROTOCOL))
          else:
            array = np.ascontiguousarray(obj if obj.base is None else obj.base).view(np.uint8)
            self.write_bytes(array.data)
@@ -180,13 +183,14 @@ class ObjectHash:
    def _update_iterator(self, obj: Iterable) -> "ObjectHash":
      return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")

-   def _update_object(self, obj: Any) -> "ObjectHash":
+   def _update_object(self, obj: object) -> "ObjectHash":
      self.header("instance", encode_type_of(obj))
-     if hasattr(obj, "__objecthash__") and callable(obj.__objecthash__):
-       return self.header("objecthash").update(obj.__objecthash__())
+     get_hash = hasattr(obj, "__objecthash__") and getattr(obj, "__objecthash__")
+     if callable(get_hash):
+       return self.header("objecthash").update(get_hash())
      reduced = None
      with suppress(Exception):
-       reduced = obj.__reduce_ex__(PROTOCOL)
+       reduced = obj.__reduce_ex__(PICKLE_PROTOCOL)
      with suppress(Exception):
        reduced = reduced or obj.__reduce__()
      if isinstance(reduced, str):
@@ -206,3 +210,12 @@ class ObjectHash:
        return self._update_iterator(obj)
      repr_str = re.sub(r"\s+(at\s+0x[0-9a-fA-F]+)(>)$", r"\2", repr(obj))
      return self.header("repr").update(repr_str)
+
+ def get_fn_body(fn: Callable) -> str:
+   try:
+     source = inspect.getsource(fn)
+   except OSError:
+     return ""
+   tokens = tokenize.generate_tokens(StringIO(source).readline)
+   ignore_types = (tokenize.COMMENT, tokenize.NL)
+   return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
checkpointer/storages/storage.py CHANGED
@@ -15,7 +15,7 @@ class Storage:
      self.cached_fn = cached_fn

    def fn_id(self) -> str:
-     return f"{self.cached_fn.fn_dir}/{self.cached_fn.fn_hash}"
+     return f"{self.cached_fn.fn_dir}/{self.cached_fn.ident.fn_hash}"

    def fn_dir(self) -> Path:
      return self.checkpointer.root_path / self.fn_id()
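`Storage.fn_id()` now reads the function hash through `ident`, but the shape of a custom backend is otherwise unchanged. A toy in-memory sketch follows; only `store`, `exists`, `delete`, and `cleanup` appear in this diff, so the `load` and `checkpoint_date` method names below, and passing a `Storage` subclass via `format=`, are assumptions based on the README's custom-backend section rather than confirmed API:

```python
from datetime import datetime
from checkpointer import checkpoint, Storage

class DictStorage(Storage):
    _data: dict[str, tuple[datetime, object]] = {}

    def _key(self, call_id):
        return f"{self.fn_id()}/{call_id}"  # fn_id() = fn_dir/fn_hash, per this diff

    def store(self, call_id, data):
        self._data[self._key(call_id)] = (datetime.now(), data)
        return data  # checkpoint.py returns whatever store() hands back

    def exists(self, call_id):
        return self._key(call_id) in self._data

    def load(self, call_id):  # assumed method name
        return self._data[self._key(call_id)][1]

    def checkpoint_date(self, call_id):  # assumed method name
        return self._data[self._key(call_id)][0]

    def delete(self, call_id):
        self._data.pop(self._key(call_id), None)

    def cleanup(self, invalidated=True, expired=True):
        pass  # nothing to prune in this toy backend

@checkpoint(format=DictStorage)  # assumes format= accepts a Storage subclass
def add(a: int, b: int) -> int:
    return a + b
```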
checkpointer/test_checkpointer.py CHANGED
@@ -1,6 +1,6 @@
  import asyncio
  import pytest
- from . import CheckpointError, checkpoint
+ from checkpointer import CheckpointError, checkpoint
  from .utils import AttrDict

  def global_multiply(a: int, b: int) -> int:
@@ -110,16 +110,16 @@ def test_capture():
    def test_a():
      return item_dict.a + 1

-   init_hash_a = test_a.fn_hash
-   init_hash_whole = test_whole.fn_hash
+   init_hash_a = test_a.ident.captured_hash
+   init_hash_whole = test_whole.ident.captured_hash
    item_dict.b += 1
    test_whole.reinit()
    test_a.reinit()
-   assert test_whole.fn_hash != init_hash_whole
-   assert test_a.fn_hash == init_hash_a
+   assert test_whole.ident.captured_hash != init_hash_whole
+   assert test_a.ident.captured_hash == init_hash_a
    item_dict.a += 1
    test_a.reinit()
-   assert test_a.fn_hash != init_hash_a
+   assert test_a.ident.captured_hash != init_hash_a

  def test_depends():
    def multiply_wrapper(a: int, b: int) -> int:
checkpointer/types.py ADDED
@@ -0,0 +1,17 @@
+ from typing import Annotated, Callable, Generic, TypeVar
+
+ T = TypeVar("T")
+ Fn = TypeVar("Fn", bound=Callable)
+
+ class HashBy(Generic[Fn]):
+   pass
+
+ NoHash = Annotated[T, HashBy[lambda _: None]]
+
+ class AwaitableValue:
+   def __init__(self, value):
+     self.value = value
+
+   def __await__(self):
+     yield
+     return self.value
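`AwaitableValue` is what `_resolve_coroutine` stores for async functions: awaiting it yields control once and then produces the wrapped value, so a cached async result can be awaited like the original coroutine. A small sketch of that behavior:

```python
import asyncio
from checkpointer import AwaitableValue  # re-exported by checkpointer/__init__.py

async def main():
    value = await AwaitableValue(42)  # yields once, then returns the stored value
    print(value)  # 42

asyncio.run(main())
```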
checkpointer/utils.py CHANGED
@@ -1,7 +1,4 @@
- import inspect
- import tokenize
  from contextlib import contextmanager
- from io import StringIO
  from typing import Any, Callable, Generic, Iterable, TypeVar, cast

  T = TypeVar("T")
@@ -10,21 +7,6 @@ Fn = TypeVar("Fn", bound=Callable)
  def distinct(seq: Iterable[T]) -> list[T]:
    return list(dict.fromkeys(seq))

- def transpose(tuples, default_num_returns=0):
-   output = tuple(zip(*tuples))
-   if not output:
-     return ([],) * default_num_returns
-   return tuple(map(list, output))
-
- def get_fn_body(fn: Callable) -> str:
-   try:
-     source = inspect.getsource(fn)
-   except OSError:
-     return ""
-   tokens = tokenize.generate_tokens(StringIO(source).readline)
-   ignore_types = (tokenize.COMMENT, tokenize.NL)
-   return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
-
  def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
    for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
      try:
@@ -39,14 +21,6 @@ def unwrap_fn(fn: Fn, cached_fn=False) -> Fn:
      return cast(Fn, fn)
    fn = getattr(fn, "__wrapped__")

- class AwaitableValue:
-   def __init__(self, value):
-     self.value = value
-
-   def __await__(self):
-     yield
-     return self.value
-
  class AttrDict(dict):
    def __init__(self, *args, **kwargs):
      super().__init__(*args, **kwargs)
@@ -91,14 +65,19 @@ class iterate_and_upcoming(Generic[T]):
    def __init__(self, it: Iterable[T]) -> None:
      self.it = iter(it)
      self.previous: tuple[()] | tuple[T] = ()
+     self.tracked = self._tracked_iter()

    def __iter__(self):
      return self

    def __next__(self) -> tuple[T, Iterable[T]]:
-     item = self.previous[0] if self.previous else next(self.it)
-     self.previous = ()
-     return item, self._tracked_iter()
+     try:
+       item = self.previous[0] if self.previous else next(self.it)
+       self.previous = ()
+       return item, self.tracked
+     except StopIteration:
+       self.tracked.close()
+       raise

    def _tracked_iter(self):
      for x in self.it:
checkpointer-2.11.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: checkpointer
- Version: 2.10.1
+ Version: 2.11.0
  Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
  Author: Hampus Hallman
@@ -121,11 +121,8 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
  * **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
    A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.

- * **`hash_by`** (Type: `Callable[..., Any]`, Default: `None`)\
-   A custom callable that takes the function's arguments (`*args`, `**kwargs`) and returns a hashable object (or tuple of objects). This allows for custom argument normalization (e.g., sorting lists before hashing) or optimized hashing for complex input types, which can improve cache hit rates or speed up the hashing process.
-
- * **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)\
-   An optional parameter that takes an instance of `checkpointer.ObjectHash`. This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any values relevant to your invalidation logic to `ObjectHash` (e.g., `ObjectHash(version_string, config_id, ...)`), as it can consistently hash most Python values.
+ * **`fn_hash_from`** (Type: `Any`, Default: `None`)\
+   This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any object relevant to your invalidation logic (e.g., version strings, config parameters). The object you provide will be hashed internally by `checkpointer`.

  * **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
    Controls the level of logging output from `checkpointer`.
@@ -133,13 +130,45 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
    * `1`: Shows when functions are computed and cached.
    * `2`: Also shows when cached results are remembered (loaded from cache).

- ### 🗄️ Custom Storage Backends
+ ## 🔬 Customize Argument Hashing
+
+ You can customize how individual function arguments are hashed without changing their actual values when passed in.
+
+ * **`Annotated[T, HashBy[fn]]`**:\
+   Hashes the argument by applying `fn(argument)` before hashing. This enables custom normalization (e.g., sorting lists to ignore order) or optimized hashing for complex types, improving cache hit rates or speeding up hashing.
+
+ * **`NoHash[T]`**:\
+   Completely excludes the argument from hashing, so changes to it won’t trigger cache invalidation.
+
+ **Example:**
+
+ ```python
+ from typing import Annotated
+ from checkpointer import checkpoint, HashBy, NoHash
+ from pathlib import Path
+ import logging
+
+ def file_bytes(path: Path) -> bytes:
+     return path.read_bytes()
+
+ @checkpoint
+ def process(
+     numbers: Annotated[list[int], HashBy[sorted]],   # Hash by sorted list
+     data_file: Annotated[Path, HashBy[file_bytes]],  # Hash by file content
+     log: NoHash[logging.Logger],                     # Exclude logger from hashing
+ ):
+     ...
+ ```
+
+ In this example, the cache key for `numbers` ignores order, `data_file` is hashed based on its contents rather than path, and changes to `log` don’t affect caching.
+
+ ## 🗄️ Custom Storage Backends

  For integration with databases, cloud storage, or custom serialization, implement your own storage backend by inheriting from `checkpointer.Storage` and implementing its abstract methods.

  Within custom storage methods, `call_id` identifies calls by arguments. Use `self.fn_id()` to get the function's unique identity (name + hash/version), crucial for organizing stored checkpoints (e.g., by function version). Access global `Checkpointer` config via `self.checkpointer`.

- #### Example: Custom Storage Backend
+ **Example: Custom Storage Backend**

  ```python
  from checkpointer import checkpoint, Storage
checkpointer-2.11.0.dist-info/RECORD ADDED
@@ -0,0 +1,16 @@
+ checkpointer/__init__.py,sha256=ayjFyHwvl_HRHwocY-hOJvAx0Ko5X9IMZrNT4CwfoMU,824
+ checkpointer/checkpoint.py,sha256=zU67_PGrVCMP90clPTR7AA4vfxECqmeFr0jJElUL5iQ,9051
+ checkpointer/fn_ident.py,sha256=-5XbovQowVyYCFc7JdT9z1NoIEiL8h9fi7alF_34Ils,4470
+ checkpointer/object_hash.py,sha256=YlyFupQrg3V2mpzTLfOqpqlZWhoSCHliScQ4cKd36T0,8133
+ checkpointer/print_checkpoint.py,sha256=aJCeWMRJiIR3KpyPk_UOKTaD906kArGrmLGQ3LqcVgo,1369
+ checkpointer/test_checkpointer.py,sha256=-EvsMMNOOiIxhTcG97LLX0jUMWp534ko7qCKDSFWiA0,3802
+ checkpointer/types.py,sha256=rAdjZNn1-jk35df7UVtby_qlp-8_18ucXfVCtS4RI_M,323
+ checkpointer/utils.py,sha256=0cGVSlTnABgs3jI1uHoTfz353kkGa-qtTfe7jG4NCr0,2192
+ checkpointer/storages/__init__.py,sha256=en32nTUltpCSgz8RVGS_leIHC1Y1G89IqG1ZqAb6qUo,236
+ checkpointer/storages/memory_storage.py,sha256=Br30b1AyNOcNjjAaDui1mBjDKfhDbu--jV4WmJenzaE,1109
+ checkpointer/storages/pickle_storage.py,sha256=xS96q8TvwxH_TOZsiKmrMrhFwQKZCcaH7XOiECmuwl8,1628
+ checkpointer/storages/storage.py,sha256=5NWmnfsU2QY24NwKC5MZNe4h7pcYhADw-y800v17aYE,895
+ checkpointer-2.11.0.dist-info/METADATA,sha256=GFRKdyTOIQwvmzIIrkV2owT8NO-X2ONRXdnadY9j5xY,11617
+ checkpointer-2.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ checkpointer-2.11.0.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
+ checkpointer-2.11.0.dist-info/RECORD,,
checkpointer-2.10.1.dist-info/RECORD REMOVED
@@ -1,15 +0,0 @@
- checkpointer/__init__.py,sha256=FBdEwdyQS2tc4eQMwrb9qsFgjPu9QOyLOqgDK2c1rAI,837
- checkpointer/checkpoint.py,sha256=yysni7CKzQ2XRX0nvEd34ktNjUcHVZ0CnKMaqQX-f_8,6827
- checkpointer/fn_ident.py,sha256=hiPvm1lw9Aol6v-d5bjteWdg0VhrPQH4r1Wiow-Wzpo,4311
- checkpointer/object_hash.py,sha256=kfyhFFxvcjwkexw5TKkAtCyWU3PZXgsWGJ5dsOaQXmY,7695
- checkpointer/print_checkpoint.py,sha256=aJCeWMRJiIR3KpyPk_UOKTaD906kArGrmLGQ3LqcVgo,1369
- checkpointer/test_checkpointer.py,sha256=mvJu5EItLM95S6IdawgINVgqVCPBNx1bnMHJmwF_cwo,3731
- checkpointer/utils.py,sha256=_zvbkno7221_4l_FbROJMRrs-enx8BRS4FGxn8et9Ns,2751
- checkpointer/storages/__init__.py,sha256=en32nTUltpCSgz8RVGS_leIHC1Y1G89IqG1ZqAb6qUo,236
- checkpointer/storages/memory_storage.py,sha256=Br30b1AyNOcNjjAaDui1mBjDKfhDbu--jV4WmJenzaE,1109
- checkpointer/storages/pickle_storage.py,sha256=xS96q8TvwxH_TOZsiKmrMrhFwQKZCcaH7XOiECmuwl8,1628
- checkpointer/storages/storage.py,sha256=dR84h-IN644XLGe7invSlzLuXlFlU8QIbsDZa1xdtPM,889
- checkpointer-2.10.1.dist-info/METADATA,sha256=n7eiGDEoqRH-bEncwy27EM108HMwqVLHpDbYGY_PFZk,10908
- checkpointer-2.10.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- checkpointer-2.10.1.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
- checkpointer-2.10.1.dist-info/RECORD,,