checkpointer 2.5.0__tar.gz → 2.6.0__tar.gz

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1 @@
+ 3.12.7
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: checkpointer
- Version: 2.5.0
+ Version: 2.6.0
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
  Author: Hampus Hallman
@@ -12,6 +12,8 @@ License: Copyright 2018-2025 Hampus Hallman

  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  License-File: LICENSE
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
  Requires-Python: >=3.12
  Description-Content-Type: text/markdown

@@ -28,6 +30,7 @@ Adding or removing `@checkpoint` doesn't change how your code works. You can app
  - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
  - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
  - 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
+ - ⚡ **Custom Argument Hashing**: Override argument hashing for speed or specialized hashing logic.

  ---

@@ -61,9 +64,9 @@ When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are

  Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:

- 1. **Its source code**: Changes to the function's code update its hash.
- 2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
- 3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
+ 1. **Function Code**: The hash updates when the function's own source code changes.
+ 2. **Dependencies**: If the function calls other user-defined functions, changes in those dependencies also update the hash.
+ 3. **External Variables** *(with `capture=True`)*: Any global or closure-based variables used by the function are included in its hash, so changes to those variables also trigger cache invalidation.

  ### Example: Cache Invalidation

@@ -155,21 +158,19 @@ stored_result = expensive_function.get(4)

  ### Refresh Function Hash

- When using `capture=True`, changes to captured variables are included in the function's hash to determine cache invalidation. However, `checkpointer` does not automatically update this hash during a running Python session—it recalculates between sessions or when you explicitly refresh it.
-
- Use the `reinit` method to manually refresh the function's hash within the same session:
+ If `capture=True`, you might need to re-hash a function during the same Python session. For that, call `reinit`:

  ```python
  expensive_function.reinit()
  ```

- This forces `checkpointer` to recalculate the hash of `expensive_function`, considering any changes to captured variables. It's useful when you've modified external variables that the function depends on and want the cache to reflect these changes immediately.
+ This tells `checkpointer` to recalculate the function hash, reflecting changes in captured variables.

  ---

  ## Storage Backends

- `checkpointer` works with both built-in and custom storage backends, so you can use what's provided or roll your own as needed.
+ `checkpointer` works with built-in and custom storage backends, so you can use what's provided or roll your own as needed.

  ### Built-In Backends

@@ -201,10 +202,10 @@ from checkpointer import checkpoint, Storage
  from datetime import datetime

  class CustomStorage(Storage):
- def store(self, path, data): ... # Save the checkpoint data
- def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
- def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
- def load(self, path): ... # Return the checkpoint data
+ def exists(self, path) -> bool: ... # Check if a checkpoint exists
+ def checkpoint_date(self, path) -> datetime: ... # Get the checkpoint's timestamp
+ def store(self, path, data): ... # Save data to the checkpoint
+ def load(self, path): ... # Load data from the checkpoint
  def delete(self, path): ... # Delete the checkpoint

  @checkpoint(format=CustomStorage)
@@ -212,21 +213,21 @@ def custom_cached(x: int):
  return x ** 2
  ```

- Using a custom backend lets you tailor storage to your application, whether it involves databases, cloud storage, or custom file formats.
+ Use a custom backend to integrate with databases, cloud storage, or specialized file formats.

  ---

  ## Configuration Options ⚙️

- | Option          | Type                              | Default              | Description                                    |
- |-----------------|-----------------------------------|----------------------|------------------------------------------------|
- | `capture`       | `bool`                            | `False`              | Include captured variables in function hashes. |
- | `format`        | `"pickle"`, `"memory"`, `Storage` | `"pickle"`           | Storage backend format.                        |
- | `root_path`     | `Path`, `str`, or `None`          | ~/.cache/checkpoints | Root directory for storing checkpoints.        |
- | `when`          | `bool`                            | `True`               | Enable or disable checkpointing.               |
- | `verbosity`     | `0` or `1`                        | `1`                  | Logging verbosity.                             |
- | `path`          | `Callable[..., str]`              | `None`               | Custom path for checkpoint storage.            |
- | `should_expire` | `Callable[[datetime], bool]`      | `None`               | Custom expiration logic.                       |
+ | Option          | Type                                | Default              | Description                                                |
+ |-----------------|-------------------------------------|----------------------|-----------------------------------------------------------|
+ | `capture`       | `bool`                              | `False`              | Include captured variables in function hashes.             |
+ | `format`        | `"pickle"`, `"memory"`, `Storage`   | `"pickle"`           | Storage backend format.                                    |
+ | `root_path`     | `Path`, `str`, or `None`            | ~/.cache/checkpoints | Root directory for storing checkpoints.                    |
+ | `when`          | `bool`                              | `True`               | Enable or disable checkpointing.                           |
+ | `verbosity`     | `0` or `1`                          | `1`                  | Logging verbosity.                                         |
+ | `should_expire` | `Callable[[datetime], bool]`        | `None`               | Custom expiration logic.                                   |
+ | `hash_by`       | `Callable[..., Any]`                | `None`               | Custom function that transforms arguments before hashing.  |

  ---

@@ -11,6 +11,7 @@ Adding or removing `@checkpoint` doesn't change how your code works. You can app
  - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
  - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
  - 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
+ - ⚡ **Custom Argument Hashing**: Override argument hashing for speed or specialized hashing logic.

  ---

@@ -44,9 +45,9 @@ When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are

  Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:

- 1. **Its source code**: Changes to the function's code update its hash.
- 2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
- 3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
+ 1. **Function Code**: The hash updates when the function's own source code changes.
+ 2. **Dependencies**: If the function calls other user-defined functions, changes in those dependencies also update the hash.
+ 3. **External Variables** *(with `capture=True`)*: Any global or closure-based variables used by the function are included in its hash, so changes to those variables also trigger cache invalidation.

  ### Example: Cache Invalidation

@@ -138,21 +139,19 @@ stored_result = expensive_function.get(4)

  ### Refresh Function Hash

- When using `capture=True`, changes to captured variables are included in the function's hash to determine cache invalidation. However, `checkpointer` does not automatically update this hash during a running Python session—it recalculates between sessions or when you explicitly refresh it.
-
- Use the `reinit` method to manually refresh the function's hash within the same session:
+ If `capture=True`, you might need to re-hash a function during the same Python session. For that, call `reinit`:

  ```python
  expensive_function.reinit()
  ```

- This forces `checkpointer` to recalculate the hash of `expensive_function`, considering any changes to captured variables. It's useful when you've modified external variables that the function depends on and want the cache to reflect these changes immediately.
+ This tells `checkpointer` to recalculate the function hash, reflecting changes in captured variables.

  ---

  ## Storage Backends

- `checkpointer` works with both built-in and custom storage backends, so you can use what's provided or roll your own as needed.
+ `checkpointer` works with built-in and custom storage backends, so you can use what's provided or roll your own as needed.

  ### Built-In Backends

@@ -184,10 +183,10 @@ from checkpointer import checkpoint, Storage
  from datetime import datetime

  class CustomStorage(Storage):
- def store(self, path, data): ... # Save the checkpoint data
- def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
- def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
- def load(self, path): ... # Return the checkpoint data
+ def exists(self, path) -> bool: ... # Check if a checkpoint exists
+ def checkpoint_date(self, path) -> datetime: ... # Get the checkpoint's timestamp
+ def store(self, path, data): ... # Save data to the checkpoint
+ def load(self, path): ... # Load data from the checkpoint
  def delete(self, path): ... # Delete the checkpoint

  @checkpoint(format=CustomStorage)
@@ -195,21 +194,21 @@ def custom_cached(x: int):
  return x ** 2
  ```

- Using a custom backend lets you tailor storage to your application, whether it involves databases, cloud storage, or custom file formats.
+ Use a custom backend to integrate with databases, cloud storage, or specialized file formats.

  ---

  ## Configuration Options ⚙️

- | Option          | Type                              | Default              | Description                                    |
- |-----------------|-----------------------------------|----------------------|------------------------------------------------|
- | `capture`       | `bool`                            | `False`              | Include captured variables in function hashes. |
- | `format`        | `"pickle"`, `"memory"`, `Storage` | `"pickle"`           | Storage backend format.                        |
- | `root_path`     | `Path`, `str`, or `None`          | ~/.cache/checkpoints | Root directory for storing checkpoints.        |
- | `when`          | `bool`                            | `True`               | Enable or disable checkpointing.               |
- | `verbosity`     | `0` or `1`                        | `1`                  | Logging verbosity.                             |
- | `path`          | `Callable[..., str]`              | `None`               | Custom path for checkpoint storage.            |
- | `should_expire` | `Callable[[datetime], bool]`      | `None`               | Custom expiration logic.                       |
+ | Option          | Type                                | Default              | Description                                                |
+ |-----------------|-------------------------------------|----------------------|-----------------------------------------------------------|
+ | `capture`       | `bool`                              | `False`              | Include captured variables in function hashes.             |
+ | `format`        | `"pickle"`, `"memory"`, `Storage`   | `"pickle"`           | Storage backend format.                                    |
+ | `root_path`     | `Path`, `str`, or `None`            | ~/.cache/checkpoints | Root directory for storing checkpoints.                    |
+ | `when`          | `bool`                              | `True`               | Enable or disable checkpointing.                           |
+ | `verbosity`     | `0` or `1`                          | `1`                  | Logging verbosity.                                         |
+ | `should_expire` | `Callable[[datetime], bool]`        | `None`               | Custom expiration logic.                                   |
+ | `hash_by`       | `Callable[..., Any]`                | `None`               | Custom function that transforms arguments before hashing.  |

  ---

@@ -23,7 +23,7 @@ class CheckpointerOpts(TypedDict, total=False):
  root_path: Path | str | None
  when: bool
  verbosity: Literal[0, 1]
- path: Callable[..., str] | None
+ hash_by: Callable | None
  should_expire: Callable[[datetime], bool] | None
  capture: bool

@@ -33,7 +33,7 @@ class Checkpointer:
  self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
  self.when = opts.get("when", True)
  self.verbosity = opts.get("verbosity", 1)
- self.path = opts.get("path")
+ self.hash_by = opts.get("hash_by")
  self.should_expire = opts.get("should_expire")
  self.capture = opts.get("capture", False)

@@ -66,7 +66,7 @@ class CheckpointFn(Generic[Fn]):
  store_format = self.checkpointer.format
  Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
  deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
- self.fn_hash = str(ObjectHash().update_hash(self.fn_hash_raw, iter=deep_hashes))
+ self.fn_hash = str(ObjectHash().write_text(self.fn_hash_raw, iter=deep_hashes))
  self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
  self.is_async = inspect.iscoroutinefunction(wrapped)
  self.storage = Storage(self)
@@ -88,13 +88,9 @@ class CheckpointFn(Generic[Fn]):
  pointfn._lazyinit()

  def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
- if not callable(self.checkpointer.path):
- call_hash = ObjectHash(self.fn_hash, args, kw, digest_size=16)
- return f"{self.fn_subdir}/{call_hash}"
- checkpoint_id = self.checkpointer.path(*args, **kw)
- if not isinstance(checkpoint_id, str):
- raise CheckpointError(f"path function must return a string, got {type(checkpoint_id)}")
- return checkpoint_id
+ hash_params = [self.checkpointer.hash_by(*args, **kw)] if self.checkpointer.hash_by else (args, kw)
+ call_hash = ObjectHash(self.fn_hash, *hash_params, digest_size=16)
+ return f"{self.fn_subdir}/{call_hash}"

  async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
  checkpoint_id = self.get_checkpoint_id(args, kw)
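The rewritten `get_checkpoint_id` above boils down to: hash the function hash together with either the raw `(args, kw)` pair or the single value returned by `hash_by`. Here is a stdlib-only sketch of that key-derivation idea — `checkpoint_key` and `normalize` are illustrative names, not the library's API:

```python
import hashlib
import pickle

def checkpoint_key(fn_hash: str, args: tuple, kw: dict, hash_by=None) -> str:
    # Mirror the 2.6.0 logic: if a hash_by transform is given, hash its
    # single return value instead of the raw (args, kw) pair.
    hash_params = [hash_by(*args, **kw)] if hash_by else (args, kw)
    h = hashlib.blake2b(digest_size=16)
    h.update(fn_hash.encode())
    for param in hash_params:
        h.update(pickle.dumps(param))
    return h.hexdigest()

# A hash_by that normalizes arguments collapses equivalent calls
# (positional vs keyword, int vs float) onto one checkpoint key.
def normalize(a, b=0):
    return (float(a), float(b))

k1 = checkpoint_key("abc", (1,), {"b": 2}, hash_by=normalize)
k2 = checkpoint_key("abc", (1.0, 2.0), {}, hash_by=normalize)
assert k1 == k2
```

This also shows why `hash_by` can speed things up: hashing a small normalized value (say, a sorted tuple of IDs) can be much cheaper than deep-hashing large argument objects.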
@@ -11,7 +11,7 @@ from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming,
  cwd = Path.cwd()

  def is_class(obj) -> TypeGuard[Type]:
- # isinstance works too, but needlessly triggers __getattribute__
+ # isinstance works too, but needlessly triggers _lazyinit()
  return issubclass(type(obj), type)

  def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
@@ -3,7 +3,7 @@ import hashlib
  import io
  import re
  from collections.abc import Iterable
- from contextlib import nullcontext
+ from contextlib import nullcontext, suppress
  from decimal import Decimal
  from itertools import chain
  from pickle import HIGHEST_PROTOCOL as PROTOCOL
@@ -11,19 +11,18 @@ from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType,
  from typing import Any, TypeAliasType, TypeVar
  from .utils import ContextVar, get_fn_body

- try:
+ np, torch = None, None
+
+ with suppress(Exception):
  import numpy as np
- except:
- np = None
- try:
+
+ with suppress(Exception):
  import torch
- except:
- torch = None

  def encode_type(t: type | FunctionType) -> str:
  return f"{t.__module__}:{t.__qualname__}"

- def encode_val(v: Any) -> str:
+ def encode_type_of(v: Any) -> str:
  return encode_type(type(v))

  class ObjectHashError(Exception):
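The hunk above swaps bare `try`/`except` import guards for `contextlib.suppress`. The same optional-dependency pattern can be shown in isolation — this is a generic sketch, not checkpointer's code:

```python
from contextlib import suppress

# Preassign the fallbacks, then let suppress() swallow whatever is
# raised at import time for optional dependencies. This only works at
# module scope: a successful import simply rebinds the global name.
np, torch = None, None

with suppress(Exception):
    import numpy as np

with suppress(Exception):
    import torch

# Downstream code branches on truthiness instead of re-importing:
def describe(obj) -> str:
    if np is not None and isinstance(obj, np.ndarray):
        return "ndarray"
    return type(obj).__name__
```

Suppressing `Exception` rather than just `ImportError` also covers dependencies that import successfully but raise during their own initialization.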
@@ -32,11 +31,11 @@ class ObjectHashError(Exception):
  self.obj = obj

  class ObjectHash:
- def __init__(self, *obj: Any, iter: Iterable[Any] = [], digest_size=64, tolerate_errors=False) -> None:
+ def __init__(self, *objs: Any, iter: Iterable[Any] = (), digest_size=64, tolerate_errors=False) -> None:
  self.hash = hashlib.blake2b(digest_size=digest_size)
  self.current: dict[int, int] = {}
  self.tolerate_errors = ContextVar(tolerate_errors)
- self.update(iter=chain(obj, iter))
+ self.update(iter=chain(objs, iter))

  def copy(self) -> "ObjectHash":
  new = ObjectHash(tolerate_errors=self.tolerate_errors.value)
@@ -48,15 +47,21 @@ class ObjectHash:

  __str__ = hexdigest

- def update_hash(self, *data: bytes | str, iter: Iterable[bytes | str] = []) -> "ObjectHash":
+ def nested_hash(self, *objs: Any) -> str:
+ return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()
+
+ def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> "ObjectHash":
  for d in chain(data, iter):
- self.hash.update(d.encode() if isinstance(d, str) else d)
+ self.hash.update(d)
  return self

+ def write_text(self, *data: str, iter: Iterable[str] = ()) -> "ObjectHash":
+ return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))
+
  def header(self, *args: Any) -> "ObjectHash":
- return self.update_hash(":".join(map(str, args)))
+ return self.write_bytes(":".join(map(str, args)).encode())

- def update(self, *objs: Any, iter: Iterable[Any] = [], tolerate_errors: bool | None=None) -> "ObjectHash":
+ def update(self, *objs: Any, iter: Iterable[Any] = (), tolerate_errors: bool | None=None) -> "ObjectHash":
  with nullcontext() if tolerate_errors is None else self.tolerate_errors.set(tolerate_errors):
  for obj in chain(objs, iter):
  try:
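The `write_bytes`/`write_text` split above replaces the old `update_hash`, which accepted `bytes | str` and branched per item. The same incremental-hash pattern can be sketched with a stand-alone class — `MiniHash` is hypothetical, not the library's `ObjectHash`:

```python
import hashlib
from collections.abc import Iterable
from itertools import chain

class MiniHash:
    def __init__(self, digest_size: int = 16):
        self.hash = hashlib.blake2b(digest_size=digest_size)

    def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> "MiniHash":
        # Single code path: callers are responsible for encoding.
        for d in chain(data, iter):
            self.hash.update(d)
        return self

    def write_text(self, *data: str, iter: Iterable[str] = ()) -> "MiniHash":
        # Text is just bytes after encoding; delegate to write_bytes.
        return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))

    def hexdigest(self) -> str:
        return self.hash.hexdigest()

# Feeding the same text through either path yields the same digest.
a = MiniHash().write_text("fn-hash").hexdigest()
b = MiniHash().write_bytes(b"fn-hash").hexdigest()
assert a == b
```

Separating the two methods moves the `isinstance` check out of the hot loop and makes each call site's intent explicit.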
@@ -74,19 +79,20 @@ class ObjectHash:
  self.header("null")

  case bool() | int() | float() | complex() | Decimal() | ObjectHash():
- self.header("number", encode_val(obj), obj)
+ self.header("number", encode_type_of(obj), obj)

  case str() | bytes() | bytearray() | memoryview():
- self.header("bytes", encode_val(obj), len(obj)).update_hash(obj)
+ b = obj.encode() if isinstance(obj, str) else obj
+ self.header("bytes", encode_type_of(obj), len(b)).write_bytes(b)

  case set() | frozenset():
- self.header("set", encode_val(obj), len(obj))
  try:
  items = sorted(obj)
+ header = "set"
  except:
- self.header("unsortable")
- items = sorted(str(ObjectHash(item, tolerate_errors=self.tolerate_errors.value)) for item in obj)
- self.update(iter=items)
+ items = sorted(map(self.nested_hash, obj))
+ header = "set-unsortable"
+ self.header(header, encode_type_of(obj), len(obj)).update(iter=items)

  case TypeVar():
  self.header("TypeVar").update(obj.__name__, obj.__bound__, obj.__constraints__, obj.__contravariant__, obj.__covariant__)
@@ -113,7 +119,7 @@ class ObjectHash:
  self.header("generator", obj.__qualname__)._update_iterator(obj)

  case io.TextIOWrapper() | io.FileIO() | io.BufferedRandom() | io.BufferedWriter() | io.BufferedReader():
- self.header("file", encode_val(obj)).update(obj.name, obj.mode, obj.tell())
+ self.header("file", encode_type_of(obj)).update(obj.name, obj.mode, obj.tell())

  case type():
  self.header("type", encode_type(obj))
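The new set-hashing branch above handles sets whose elements aren't mutually comparable: each element is hashed independently and the hex digests are sorted, with a distinct `set-unsortable` header so the two modes can't collide. A self-contained sketch of that determinism trick (simplified, with `repr`-based element hashing standing in for the library's recursive hashing):

```python
import hashlib

def _elem_hash(x) -> str:
    # Stand-in for hashing one element recursively.
    return hashlib.blake2b(repr(x).encode(), digest_size=8).hexdigest()

def stable_set_digest(s: set) -> str:
    try:
        # Homogeneous sets: sorting the elements gives a canonical order.
        items = [repr(x) for x in sorted(s)]
        header = "set"
    except TypeError:
        # Mixed types can't be sorted directly, but their per-element
        # digests can: sort those instead for an order-independent result.
        items = sorted(_elem_hash(x) for x in s)
        header = "set-unsortable"
    payload = ":".join([header, *items]).encode()
    return hashlib.blake2b(payload, digest_size=8).hexdigest()

# Insertion order no longer matters, even for an unsortable mixed set:
assert stable_set_digest({1, "a", 2.5}) == stable_set_digest({"a", 2.5, 1})
```

The distinct headers matter because a sortable set and an unsortable set could otherwise hash their (differently derived) item lists to the same stream.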
@@ -122,20 +128,20 @@ class ObjectHash:
  self.header("dtype").update(obj.__class__, obj.descr)

  case _ if np and isinstance(obj, np.ndarray):
- self.header("ndarray", encode_val(obj), obj.shape, obj.strides).update(obj.dtype)
+ self.header("ndarray", encode_type_of(obj), obj.shape, obj.strides).update(obj.dtype)
  if obj.dtype.hasobject:
  self.update(obj.__reduce_ex__(PROTOCOL))
  else:
  array = np.ascontiguousarray(obj if obj.base is None else obj.base).view(np.uint8)
- self.update_hash(array.data)
+ self.write_bytes(array.data)

  case _ if torch and isinstance(obj, torch.Tensor):
- self.header("tensor", encode_val(obj), str(obj.dtype), tuple(obj.shape), obj.stride(), str(obj.device))
+ self.header("tensor", encode_type_of(obj), obj.dtype, tuple(obj.shape), obj.stride(), obj.device)
  if obj.device.type != "cpu":
  obj = obj.cpu()
  storage = obj.storage()
- buffer = (ctypes.c_ubyte * (storage.nbytes())).from_address(storage.data_ptr())
- self.update_hash(memoryview(buffer))
+ buffer = (ctypes.c_ubyte * storage.nbytes()).from_address(storage.data_ptr())
+ self.write_bytes(memoryview(buffer))

  case _ if id(obj) in self.current:
  self.header("circular", self.current[id(obj)])
@@ -145,36 +151,36 @@ class ObjectHash:
  self.current[id(obj)] = len(self.current)
  match obj:
  case list() | tuple():
- self.header("list", encode_val(obj), len(obj)).update(iter=obj)
+ self.header("list", encode_type_of(obj), len(obj)).update(iter=obj)
  case dict():
  try:
  items = sorted(obj.items())
+ header = "dict"
  except:
- items = sorted((str(ObjectHash(key, tolerate_errors=self.tolerate_errors.value)), val) for key, val in obj.items())
- self.header("dict", encode_val(obj), len(obj)).update(iter=chain.from_iterable(items))
+ items = sorted((self.nested_hash(key), val) for key, val in obj.items())
+ header = "dict-unsortable"
+ self.header(header, encode_type_of(obj), len(obj)).update(iter=chain.from_iterable(items))
  case _:
  self._update_object(obj)
  finally:
  del self.current[id(obj)]

- def _update_iterator(self, obj: Iterable) -> None:
- self.header("iterator", encode_val(obj)).update(iter=obj).header(b"iterator-end")
+ def _update_iterator(self, obj: Iterable) -> "ObjectHash":
+ return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")

  def _update_object(self, obj: object) -> "ObjectHash":
- self.header("instance", encode_val(obj))
- try:
- reduced = obj.__reduce_ex__(PROTOCOL) if hasattr(obj, "__reduce_ex__") else obj.__reduce__()
- except:
- reduced = None
+ self.header("instance", encode_type_of(obj))
+ reduced = None
+ with suppress(Exception):
+ reduced = obj.__reduce_ex__(PROTOCOL)
+ with suppress(Exception):
+ reduced = reduced or obj.__reduce__()
  if isinstance(reduced, str):
  return self.header("reduce-str").update(reduced)
  if reduced:
  reduced = list(reduced)
  it = reduced.pop(3) if len(reduced) >= 4 else None
- self.header("reduce").update(reduced)
- if it is not None:
- self._update_iterator(it)
- return self
+ return self.header("reduce").update(reduced)._update_iterator(it or ())
  if state := hasattr(obj, "__getstate__") and obj.__getstate__():
  return self.header("getstate").update(state)
  if len(getattr(obj, "__slots__", [])):
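The `_update_object` fallback above leans on pickle's reduce protocol to obtain a stable structural description of arbitrary instances. What `__reduce_ex__` actually returns for a plain class can be inspected directly — `Point` below is just a throwaway example:

```python
from pickle import HIGHEST_PROTOCOL as PROTOCOL

class Point:
    def __init__(self, x: int, y: int):
        self.x, self.y = x, y

p = Point(1, 2)
reduced = p.__reduce_ex__(PROTOCOL)
# The reduce tuple contains a reconstructor callable, its arguments,
# and (for simple classes) the instance __dict__ as state. Hashing the
# tuple therefore covers class identity and attribute values.
state = reduced[2] if len(reduced) >= 3 else None
assert state == {"x": 1, "y": 2}
```

Element index 3 of a longer reduce tuple is an iterator of list items, which is why `_update_object` pops it and feeds it through `_update_iterator` separately.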
@@ -83,17 +83,6 @@ async def test_async_caching():

  assert result1 == result2 == 9

- def test_custom_path_caching():
- def custom_path(a, b):
- return f"add/{a}-{b}"
-
- @checkpoint(path=custom_path)
- def add(a, b):
- return a + b
-
- add(3, 4)
- assert (checkpoint.root_path / "add/3-4.pkl").exists()
-
  def test_force_recalculation():
  @checkpoint
  def square(x: int) -> int:
@@ -1,6 +1,6 @@
  [project]
  name = "checkpointer"
- version = "2.5.0"
+ version = "2.6.0"
  requires-python = ">=3.12"
  dependencies = []
  authors = [
@@ -9,6 +9,10 @@ authors = [
  description = "A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation"
  readme = "README.md"
  license = {file = "LICENSE"}
+ classifiers = [
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ ]

  [project.urls]
  Repository = "https://github.com/Reddan/checkpointer.git"
@@ -3,7 +3,7 @@ requires-python = ">=3.12"

  [[package]]
  name = "checkpointer"
- version = "2.5.0"
+ version = "2.6.0"
  source = { editable = "." }

  [package.dev-dependencies]
@@ -177,7 +177,6 @@ name = "nvidia-cublas-cu12"
177
177
  version = "12.4.5.8"
178
178
  source = { registry = "https://pypi.org/simple" }
179
179
  wheels = [
180
- { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771 },
181
180
  { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 },
182
181
  ]
183
182
 
@@ -186,7 +185,6 @@ name = "nvidia-cuda-cupti-cu12"
186
185
  version = "12.4.127"
187
186
  source = { registry = "https://pypi.org/simple" }
188
187
  wheels = [
189
- { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556 },
190
188
  { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 },
191
189
  ]
192
190
 
@@ -195,7 +193,6 @@ name = "nvidia-cuda-nvrtc-cu12"
195
193
  version = "12.4.127"
196
194
  source = { registry = "https://pypi.org/simple" }
197
195
  wheels = [
198
- { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372 },
199
196
  { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 },
200
197
  ]
201
198
 
@@ -204,7 +201,6 @@ name = "nvidia-cuda-runtime-cu12"
204
201
  version = "12.4.127"
205
202
  source = { registry = "https://pypi.org/simple" }
206
203
  wheels = [
207
- { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177 },
208
204
  { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 },
209
205
  ]
210
206
 
@@ -227,7 +223,6 @@ dependencies = [
227
223
  { name = "nvidia-nvjitlink-cu12" },
228
224
  ]
229
225
  wheels = [
230
- { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 },
231
226
  { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 },
232
227
  ]
233
228
 
@@ -236,7 +231,6 @@ name = "nvidia-curand-cu12"
236
231
  version = "10.3.5.147"
237
232
  source = { registry = "https://pypi.org/simple" }
238
233
  wheels = [
239
- { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811 },
240
234
  { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 },
241
235
  ]
242
236
 
@@ -250,7 +244,6 @@ dependencies = [
250
244
  { name = "nvidia-nvjitlink-cu12" },
251
245
  ]
252
246
  wheels = [
253
- { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 },
254
247
  { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 },
255
248
  ]
256
249
 
@@ -262,7 +255,6 @@ dependencies = [
  { name = "nvidia-nvjitlink-cu12" },
  ]
  wheels = [
- { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 },
  { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 },
  ]
 
@@ -279,7 +271,6 @@ name = "nvidia-nvjitlink-cu12"
  version = "12.4.127"
  source = { registry = "https://pypi.org/simple" }
  wheels = [
- { url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510 },
  { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 },
  ]
 
@@ -288,7 +279,6 @@ name = "nvidia-nvtx-cu12"
  version = "12.4.127"
  source = { registry = "https://pypi.org/simple" }
  wheels = [
- { url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417 },
  { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 },
  ]
 
@@ -459,21 +449,21 @@ dependencies = [
  { name = "fsspec" },
  { name = "jinja2" },
  { name = "networkx" },
- { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
- { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
- { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
- { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
- { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
- { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
- { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
- { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
- { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
- { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
- { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
- { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
+ { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+ { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
  { name = "setuptools" },
  { name = "sympy" },
- { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" },
+ { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" },
  { name = "typing-extensions" },
  ]
  wheels = [
File without changes