checkpointer 2.5.0__tar.gz → 2.6.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+ 3.12.7
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpointer
3
- Version: 2.5.0
3
+ Version: 2.6.1
4
4
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
5
5
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
6
  Author: Hampus Hallman
@@ -12,10 +12,12 @@ License: Copyright 2018-2025 Hampus Hallman
12
12
 
13
13
  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
14
14
  License-File: LICENSE
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
15
17
  Requires-Python: >=3.12
16
18
  Description-Content-Type: text/markdown
17
19
 
18
- # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![Python 3.12](https://img.shields.io/badge/python-3.12-blue)](https://pypi.org/project/checkpointer/)
20
+ # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![pypi](https://img.shields.io/pypi/pyversions/checkpointer)](https://pypi.org/project/checkpointer/)
19
21
 
20
22
  `checkpointer` is a Python library for memoizing function results. It provides a decorator-based API with support for multiple storage backends. Use it for computationally expensive operations where caching can save time, or during development to avoid waiting for redundant computations.
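For orientation, a minimal usage sketch based on the decorator API described in this README (the function body and values are illustrative only):

```python
from checkpointer import checkpoint

@checkpoint
def expensive_function(x: int) -> int:
    # First call per distinct argument computes and stores the result;
    # subsequent calls with the same argument load the stored checkpoint.
    return x ** 2

expensive_function(4)  # computed and memoized
expensive_function(4)  # served from the checkpoint
```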
21
23
 
@@ -28,6 +30,7 @@ Adding or removing `@checkpoint` doesn't change how your code works. You can app
28
30
  - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
29
31
  - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
30
32
  - 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
33
+ - ⚡ **Custom Argument Hashing**: Override argument hashing for speed or specialized hashing logic.
31
34
 
32
35
  ---
33
36
 
@@ -61,9 +64,9 @@ When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are
61
64
 
62
65
  Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
63
66
 
64
- 1. **Its source code**: Changes to the function's code update its hash.
65
- 2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
66
- 3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
67
+ 1. **Function Code**: The hash updates when the function's own source code changes.
68
+ 2. **Dependencies**: If the function calls other user-defined functions, changes in those dependencies also update the hash.
69
+ 3. **External Variables** *(with `capture=True`)*: Any global or closure-based variables used by the function are included in its hash, so changes to those variables also trigger cache invalidation.
67
70
 
68
71
  ### Example: Cache Invalidation
69
72
 
@@ -155,21 +158,19 @@ stored_result = expensive_function.get(4)
155
158
 
156
159
  ### Refresh Function Hash
157
160
 
158
- When using `capture=True`, changes to captured variables are included in the function's hash to determine cache invalidation. However, `checkpointer` does not automatically update this hash during a running Python session—it recalculates between sessions or when you explicitly refresh it.
159
-
160
- Use the `reinit` method to manually refresh the function's hash within the same session:
161
+ If `capture=True`, you might need to re-hash a function during the same Python session. For that, call `reinit`:
161
162
 
162
163
  ```python
163
164
  expensive_function.reinit()
164
165
  ```
165
166
 
166
- This forces `checkpointer` to recalculate the hash of `expensive_function`, considering any changes to captured variables. It's useful when you've modified external variables that the function depends on and want the cache to reflect these changes immediately.
167
+ This tells `checkpointer` to recalculate the function hash, reflecting changes in captured variables.
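A short sketch of this workflow, assuming a module-level variable is captured with `capture=True` (names are illustrative):

```python
from checkpointer import checkpoint

factor = 10  # captured global variable

@checkpoint(capture=True)
def scaled(x: int) -> int:
    return x * factor

scaled(2)        # memoized under a hash that includes `factor`
factor = 20      # the hash is not refreshed automatically mid-session...
scaled.reinit()  # ...so refresh it explicitly
scaled(2)        # recomputed, because the function hash has changed
```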
167
168
 
168
169
  ---
169
170
 
170
171
  ## Storage Backends
171
172
 
172
- `checkpointer` works with both built-in and custom storage backends, so you can use what's provided or roll your own as needed.
173
+ `checkpointer` works with built-in and custom storage backends, so you can use what's provided or roll your own as needed.
173
174
 
174
175
  ### Built-In Backends
175
176
 
@@ -201,10 +202,10 @@ from checkpointer import checkpoint, Storage
201
202
  from datetime import datetime
202
203
 
203
204
  class CustomStorage(Storage):
204
- def store(self, path, data): ... # Save the checkpoint data
205
- def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
206
- def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
207
- def load(self, path): ... # Return the checkpoint data
205
+ def exists(self, path) -> bool: ... # Check if a checkpoint exists
206
+ def checkpoint_date(self, path) -> datetime: ... # Get the checkpoint's timestamp
207
+ def store(self, path, data): ... # Save data to the checkpoint
208
+ def load(self, path): ... # Load data from the checkpoint
208
209
  def delete(self, path): ... # Delete the checkpoint
209
210
 
210
211
  @checkpoint(format=CustomStorage)
@@ -212,21 +213,21 @@ def custom_cached(x: int):
212
213
  return x ** 2
213
214
  ```
214
215
 
215
- Using a custom backend lets you tailor storage to your application, whether it involves databases, cloud storage, or custom file formats.
216
+ Use a custom backend to integrate with databases, cloud storage, or specialized file formats.
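As a rough illustration, a dict-backed backend implementing the methods listed above could look like this; it is a sketch that assumes `path` is hashable and keeps state in process memory, not a reference implementation:

```python
from datetime import datetime
from checkpointer import checkpoint, Storage

class DictStorage(Storage):
    """Illustrative in-memory backend; state is shared at class level."""
    _data: dict = {}
    _dates: dict = {}

    def exists(self, path) -> bool:
        return path in self._data

    def checkpoint_date(self, path) -> datetime:
        return self._dates[path]

    def store(self, path, data):
        self._data[path] = data
        self._dates[path] = datetime.now()

    def load(self, path):
        return self._data[path]

    def delete(self, path):
        self._data.pop(path, None)
        self._dates.pop(path, None)

@checkpoint(format=DictStorage)
def cubed(x: int) -> int:
    return x ** 3
```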
216
217
 
217
218
  ---
218
219
 
219
220
  ## Configuration Options ⚙️
220
221
 
221
- | Option | Type | Default | Description |
222
- |-----------------|-----------------------------------|----------------------|------------------------------------------------|
223
- | `capture` | `bool` | `False` | Include captured variables in function hashes. |
224
- | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
225
- | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
226
- | `when` | `bool` | `True` | Enable or disable checkpointing. |
227
- | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
228
- | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
229
- | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
222
+ | Option | Type | Default | Description |
223
+ |-----------------|-------------------------------------|----------------------|-----------------------------------------------------------|
224
+ | `capture` | `bool` | `False` | Include captured variables in function hashes. |
225
+ | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
226
+ | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
227
+ | `when` | `bool` | `True` | Enable or disable checkpointing. |
228
+ | `verbosity` | `0`, `1` or `2` | `1` | Logging verbosity. |
229
+ | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
230
+ | `hash_by` | `Callable[..., Any]` | `None` | Custom function that transforms arguments before hashing. |
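A hedged example combining several of these options (the function and the argument transform are illustrative):

```python
from datetime import datetime, timedelta
from checkpointer import checkpoint

@checkpoint(
    capture=True,    # include captured/global variables in the function hash
    verbosity=0,     # silence checkpoint logging
    hash_by=lambda items: tuple(sorted(items)),  # hash a normalized view of the argument
    should_expire=lambda stored_at: datetime.now() - stored_at > timedelta(days=7),
)
def total(items: list[int]) -> int:
    return sum(items)

total([3, 1, 2])
total([2, 3, 1])  # hashes the same after `hash_by`, so the first checkpoint is reused
```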
230
231
 
231
232
  ---
232
233
 
@@ -1,4 +1,4 @@
1
- # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![Python 3.12](https://img.shields.io/badge/python-3.12-blue)](https://pypi.org/project/checkpointer/)
1
+ # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![pypi](https://img.shields.io/pypi/pyversions/checkpointer)](https://pypi.org/project/checkpointer/)
2
2
 
3
3
  `checkpointer` is a Python library for memoizing function results. It provides a decorator-based API with support for multiple storage backends. Use it for computationally expensive operations where caching can save time, or during development to avoid waiting for redundant computations.
4
4
 
@@ -11,6 +11,7 @@ Adding or removing `@checkpoint` doesn't change how your code works. You can app
11
11
  - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
12
12
  - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
13
13
  - 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
14
+ - ⚡ **Custom Argument Hashing**: Override argument hashing for speed or specialized hashing logic.
14
15
 
15
16
  ---
16
17
 
@@ -44,9 +45,9 @@ When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are
44
45
 
45
46
  Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
46
47
 
47
- 1. **Its source code**: Changes to the function's code update its hash.
48
- 2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
49
- 3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
48
+ 1. **Function Code**: The hash updates when the function's own source code changes.
49
+ 2. **Dependencies**: If the function calls other user-defined functions, changes in those dependencies also update the hash.
50
+ 3. **External Variables** *(with `capture=True`)*: Any global or closure-based variables used by the function are included in its hash, so changes to those variables also trigger cache invalidation.
50
51
 
51
52
  ### Example: Cache Invalidation
52
53
 
@@ -138,21 +139,19 @@ stored_result = expensive_function.get(4)
138
139
 
139
140
  ### Refresh Function Hash
140
141
 
141
- When using `capture=True`, changes to captured variables are included in the function's hash to determine cache invalidation. However, `checkpointer` does not automatically update this hash during a running Python session—it recalculates between sessions or when you explicitly refresh it.
142
-
143
- Use the `reinit` method to manually refresh the function's hash within the same session:
142
+ If `capture=True`, you might need to re-hash a function during the same Python session. For that, call `reinit`:
144
143
 
145
144
  ```python
146
145
  expensive_function.reinit()
147
146
  ```
148
147
 
149
- This forces `checkpointer` to recalculate the hash of `expensive_function`, considering any changes to captured variables. It's useful when you've modified external variables that the function depends on and want the cache to reflect these changes immediately.
148
+ This tells `checkpointer` to recalculate the function hash, reflecting changes in captured variables.
150
149
 
151
150
  ---
152
151
 
153
152
  ## Storage Backends
154
153
 
155
- `checkpointer` works with both built-in and custom storage backends, so you can use what's provided or roll your own as needed.
154
+ `checkpointer` works with built-in and custom storage backends, so you can use what's provided or roll your own as needed.
156
155
 
157
156
  ### Built-In Backends
158
157
 
@@ -184,10 +183,10 @@ from checkpointer import checkpoint, Storage
184
183
  from datetime import datetime
185
184
 
186
185
  class CustomStorage(Storage):
187
- def store(self, path, data): ... # Save the checkpoint data
188
- def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
189
- def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
190
- def load(self, path): ... # Return the checkpoint data
186
+ def exists(self, path) -> bool: ... # Check if a checkpoint exists
187
+ def checkpoint_date(self, path) -> datetime: ... # Get the checkpoint's timestamp
188
+ def store(self, path, data): ... # Save data to the checkpoint
189
+ def load(self, path): ... # Load data from the checkpoint
191
190
  def delete(self, path): ... # Delete the checkpoint
192
191
 
193
192
  @checkpoint(format=CustomStorage)
@@ -195,21 +194,21 @@ def custom_cached(x: int):
195
194
  return x ** 2
196
195
  ```
197
196
 
198
- Using a custom backend lets you tailor storage to your application, whether it involves databases, cloud storage, or custom file formats.
197
+ Use a custom backend to integrate with databases, cloud storage, or specialized file formats.
199
198
 
200
199
  ---
201
200
 
202
201
  ## Configuration Options ⚙️
203
202
 
204
- | Option | Type | Default | Description |
205
- |-----------------|-----------------------------------|----------------------|------------------------------------------------|
206
- | `capture` | `bool` | `False` | Include captured variables in function hashes. |
207
- | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
208
- | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
209
- | `when` | `bool` | `True` | Enable or disable checkpointing. |
210
- | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
211
- | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
212
- | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
203
+ | Option | Type | Default | Description |
204
+ |-----------------|-------------------------------------|----------------------|-----------------------------------------------------------|
205
+ | `capture` | `bool` | `False` | Include captured variables in function hashes. |
206
+ | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
207
+ | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
208
+ | `when` | `bool` | `True` | Enable or disable checkpointing. |
209
+ | `verbosity` | `0`, `1` or `2` | `1` | Logging verbosity. |
210
+ | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
211
+ | `hash_by` | `Callable[..., Any]` | `None` | Custom function that transforms arguments before hashing. |
213
212
 
214
213
  ---
215
214
 
@@ -22,10 +22,11 @@ class CheckpointerOpts(TypedDict, total=False):
22
22
  format: Type[Storage] | Literal["pickle", "memory", "bcolz"]
23
23
  root_path: Path | str | None
24
24
  when: bool
25
- verbosity: Literal[0, 1]
26
- path: Callable[..., str] | None
25
+ verbosity: Literal[0, 1, 2]
26
+ hash_by: Callable | None
27
27
  should_expire: Callable[[datetime], bool] | None
28
28
  capture: bool
29
+ fn_hash: str | None
29
30
 
30
31
  class Checkpointer:
31
32
  def __init__(self, **opts: Unpack[CheckpointerOpts]):
@@ -33,9 +34,10 @@ class Checkpointer:
33
34
  self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
34
35
  self.when = opts.get("when", True)
35
36
  self.verbosity = opts.get("verbosity", 1)
36
- self.path = opts.get("path")
37
+ self.hash_by = opts.get("hash_by")
37
38
  self.should_expire = opts.get("should_expire")
38
39
  self.capture = opts.get("capture", False)
40
+ self.fn_hash = opts.get("fn_hash")
39
41
 
40
42
  @overload
41
43
  def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CheckpointFn[Fn]: ...
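The new `fn_hash` option is not documented in the README above; from this diff it appears to let a caller supply the function hash directly instead of the automatically computed one. A speculative sketch, based only on the `fn_hash: str | None` option visible here:

```python
from checkpointer import checkpoint

# Speculative: pin the function hash so existing checkpoints survive code edits.
@checkpoint(fn_hash="my-pipeline-step-v1")
def expensive_step(x):
    ...
```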
@@ -66,9 +68,9 @@ class CheckpointFn(Generic[Fn]):
66
68
  store_format = self.checkpointer.format
67
69
  Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
68
70
  deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
69
- self.fn_hash = str(ObjectHash().update_hash(self.fn_hash_raw, iter=deep_hashes))
71
+ self.fn_hash = self.checkpointer.fn_hash or str(ObjectHash().write_text(self.fn_hash_raw, *deep_hashes))
70
72
  self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
71
- self.is_async = inspect.iscoroutinefunction(wrapped)
73
+ self.is_async: bool = self.fn.is_async if isinstance(self.fn, CheckpointFn) else inspect.iscoroutinefunction(self.fn)
72
74
  self.storage = Storage(self)
73
75
  self.cleanup = self.storage.cleanup
74
76
 
@@ -80,32 +82,29 @@ class CheckpointFn(Generic[Fn]):
80
82
  self._lazyinit()
81
83
  return self._getattribute(name)
82
84
 
83
- def reinit(self, recursive=False):
85
+ def reinit(self, recursive=False) -> CheckpointFn[Fn]:
84
86
  pointfns = list(iterate_checkpoint_fns(self)) if recursive else [self]
85
87
  for pointfn in pointfns:
86
88
  pointfn._set_ident(True)
87
89
  for pointfn in pointfns:
88
90
  pointfn._lazyinit()
91
+ return self
89
92
 
90
93
  def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
91
- if not callable(self.checkpointer.path):
92
- call_hash = ObjectHash(self.fn_hash, args, kw, digest_size=16)
93
- return f"{self.fn_subdir}/{call_hash}"
94
- checkpoint_id = self.checkpointer.path(*args, **kw)
95
- if not isinstance(checkpoint_id, str):
96
- raise CheckpointError(f"path function must return a string, got {type(checkpoint_id)}")
97
- return checkpoint_id
94
+ hash_params = [self.checkpointer.hash_by(*args, **kw)] if self.checkpointer.hash_by else (args, kw)
95
+ call_hash = ObjectHash(self.fn_hash, *hash_params, digest_size=16)
96
+ return f"{self.fn_subdir}/{call_hash}"
98
97
 
99
98
  async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
100
99
  checkpoint_id = self.get_checkpoint_id(args, kw)
101
100
  checkpoint_path = self.checkpointer.root_path / checkpoint_id
102
- verbose = self.checkpointer.verbosity > 0
101
+ verbosity = self.checkpointer.verbosity
103
102
  refresh = rerun \
104
103
  or not self.storage.exists(checkpoint_path) \
105
104
  or (self.checkpointer.should_expire and self.checkpointer.should_expire(self.storage.checkpoint_date(checkpoint_path)))
106
105
 
107
106
  if refresh:
108
- print_checkpoint(verbose, "MEMORIZING", checkpoint_id, "blue")
107
+ print_checkpoint(verbosity >= 1, "MEMORIZING", checkpoint_id, "blue")
109
108
  data = self.fn(*args, **kw)
110
109
  if inspect.iscoroutine(data):
111
110
  data = await data
@@ -114,11 +113,11 @@ class CheckpointFn(Generic[Fn]):
114
113
 
115
114
  try:
116
115
  data = self.storage.load(checkpoint_path)
117
- print_checkpoint(verbose, "REMEMBERED", checkpoint_id, "green")
116
+ print_checkpoint(verbosity >= 2, "REMEMBERED", checkpoint_id, "green")
118
117
  return data
119
118
  except (EOFError, FileNotFoundError):
120
119
  pass
121
- print_checkpoint(verbose, "CORRUPTED", checkpoint_id, "yellow")
120
+ print_checkpoint(verbosity >= 1, "CORRUPTED", checkpoint_id, "yellow")
122
121
  return await self._store_on_demand(args, kw, True)
123
122
 
124
123
  def _call(self, args: tuple, kw: dict, rerun=False):
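Per the thresholds above, `verbosity=2` additionally logs cache hits; a minimal illustration (exact log text aside):

```python
from checkpointer import checkpoint

@checkpoint(verbosity=2)
def double(x: int) -> int:
    return x * 2

double(3)  # "MEMORIZING" is logged at verbosity >= 1
double(3)  # "REMEMBERED" is logged on the cache hit at verbosity >= 2
```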
@@ -4,14 +4,14 @@ from collections.abc import Callable
4
4
  from itertools import takewhile
5
5
  from pathlib import Path
6
6
  from types import CodeType, FunctionType, MethodType
7
- from typing import Any, Generator, Type, TypeGuard
7
+ from typing import Any, Iterable, Type, TypeGuard
8
8
  from .object_hash import ObjectHash
9
9
  from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, transpose, unwrap_fn
10
10
 
11
11
  cwd = Path.cwd()
12
12
 
13
13
  def is_class(obj) -> TypeGuard[Type]:
14
- # isinstance works too, but needlessly triggers __getattribute__
14
+ # isinstance works too, but needlessly triggers _lazyinit()
15
15
  return issubclass(type(obj), type)
16
16
 
17
17
  def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
@@ -33,7 +33,7 @@ def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[st
33
33
  scope_obj = None
34
34
  return classvars
35
35
 
36
- def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Generator[tuple[tuple[str, ...], Any], None, None]:
36
+ def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...], Any]]:
37
37
  classvars = extract_classvars(code, scope_vars)
38
38
  scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
39
39
  for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
@@ -3,7 +3,7 @@ import hashlib
3
3
  import io
4
4
  import re
5
5
  from collections.abc import Iterable
6
- from contextlib import nullcontext
6
+ from contextlib import nullcontext, suppress
7
7
  from decimal import Decimal
8
8
  from itertools import chain
9
9
  from pickle import HIGHEST_PROTOCOL as PROTOCOL
@@ -11,19 +11,18 @@ from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType,
11
11
  from typing import Any, TypeAliasType, TypeVar
12
12
  from .utils import ContextVar, get_fn_body
13
13
 
14
- try:
14
+ np, torch = None, None
15
+
16
+ with suppress(Exception):
15
17
  import numpy as np
16
- except:
17
- np = None
18
- try:
18
+
19
+ with suppress(Exception):
19
20
  import torch
20
- except:
21
- torch = None
22
21
 
23
22
  def encode_type(t: type | FunctionType) -> str:
24
23
  return f"{t.__module__}:{t.__qualname__}"
25
24
 
26
- def encode_val(v: Any) -> str:
25
+ def encode_type_of(v: Any) -> str:
27
26
  return encode_type(type(v))
28
27
 
29
28
  class ObjectHashError(Exception):
@@ -32,11 +31,11 @@ class ObjectHashError(Exception):
32
31
  self.obj = obj
33
32
 
34
33
  class ObjectHash:
35
- def __init__(self, *obj: Any, iter: Iterable[Any] = [], digest_size=64, tolerate_errors=False) -> None:
34
+ def __init__(self, *objs: Any, iter: Iterable[Any] = (), digest_size=64, tolerate_errors=False) -> None:
36
35
  self.hash = hashlib.blake2b(digest_size=digest_size)
37
36
  self.current: dict[int, int] = {}
38
37
  self.tolerate_errors = ContextVar(tolerate_errors)
39
- self.update(iter=chain(obj, iter))
38
+ self.update(iter=chain(objs, iter))
40
39
 
41
40
  def copy(self) -> "ObjectHash":
42
41
  new = ObjectHash(tolerate_errors=self.tolerate_errors.value)
@@ -48,15 +47,21 @@ class ObjectHash:
48
47
 
49
48
  __str__ = hexdigest
50
49
 
51
- def update_hash(self, *data: bytes | str, iter: Iterable[bytes | str] = []) -> "ObjectHash":
50
+ def nested_hash(self, *objs: Any) -> str:
51
+ return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()
52
+
53
+ def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> "ObjectHash":
52
54
  for d in chain(data, iter):
53
- self.hash.update(d.encode() if isinstance(d, str) else d)
55
+ self.hash.update(d)
54
56
  return self
55
57
 
58
+ def write_text(self, *data: str, iter: Iterable[str] = ()) -> "ObjectHash":
59
+ return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))
60
+
56
61
  def header(self, *args: Any) -> "ObjectHash":
57
- return self.update_hash(":".join(map(str, args)))
62
+ return self.write_bytes(":".join(map(str, args)).encode())
58
63
 
59
- def update(self, *objs: Any, iter: Iterable[Any] = [], tolerate_errors: bool | None=None) -> "ObjectHash":
64
+ def update(self, *objs: Any, iter: Iterable[Any] = (), tolerate_errors: bool | None=None) -> "ObjectHash":
60
65
  with nullcontext() if tolerate_errors is None else self.tolerate_errors.set(tolerate_errors):
61
66
  for obj in chain(objs, iter):
62
67
  try:
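For reference, a small sketch of the renamed hashing helpers shown above; the import path follows the module layout visible in this diff and is an assumption:

```python
from checkpointer.object_hash import ObjectHash  # assumed import path

h = ObjectHash("seed", digest_size=16)  # hashes positional objects on construction
h.write_text("alpha", "beta")           # replaces the old update_hash(str) path
h.write_bytes(b"\x00\x01")
print(str(h))                           # hex digest via __str__ = hexdigest
```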
@@ -74,19 +79,20 @@ class ObjectHash:
74
79
  self.header("null")
75
80
 
76
81
  case bool() | int() | float() | complex() | Decimal() | ObjectHash():
77
- self.header("number", encode_val(obj), obj)
82
+ self.header("number", encode_type_of(obj), obj)
78
83
 
79
84
  case str() | bytes() | bytearray() | memoryview():
80
- self.header("bytes", encode_val(obj), len(obj)).update_hash(obj)
85
+ b = obj.encode() if isinstance(obj, str) else obj
86
+ self.header("bytes", encode_type_of(obj), len(b)).write_bytes(b)
81
87
 
82
88
  case set() | frozenset():
83
- self.header("set", encode_val(obj), len(obj))
84
89
  try:
85
90
  items = sorted(obj)
91
+ header = "set"
86
92
  except:
87
- self.header("unsortable")
88
- items = sorted(str(ObjectHash(item, tolerate_errors=self.tolerate_errors.value)) for item in obj)
89
- self.update(iter=items)
93
+ items = sorted(map(self.nested_hash, obj))
94
+ header = "set-unsortable"
95
+ self.header(header, encode_type_of(obj), len(obj)).update(iter=items)
90
96
 
91
97
  case TypeVar():
92
98
  self.header("TypeVar").update(obj.__name__, obj.__bound__, obj.__constraints__, obj.__contravariant__, obj.__covariant__)
@@ -113,7 +119,7 @@ class ObjectHash:
113
119
  self.header("generator", obj.__qualname__)._update_iterator(obj)
114
120
 
115
121
  case io.TextIOWrapper() | io.FileIO() | io.BufferedRandom() | io.BufferedWriter() | io.BufferedReader():
116
- self.header("file", encode_val(obj)).update(obj.name, obj.mode, obj.tell())
122
+ self.header("file", encode_type_of(obj)).update(obj.name, obj.mode, obj.tell())
117
123
 
118
124
  case type():
119
125
  self.header("type", encode_type(obj))
@@ -122,20 +128,20 @@ class ObjectHash:
122
128
  self.header("dtype").update(obj.__class__, obj.descr)
123
129
 
124
130
  case _ if np and isinstance(obj, np.ndarray):
125
- self.header("ndarray", encode_val(obj), obj.shape, obj.strides).update(obj.dtype)
131
+ self.header("ndarray", encode_type_of(obj), obj.shape, obj.strides).update(obj.dtype)
126
132
  if obj.dtype.hasobject:
127
133
  self.update(obj.__reduce_ex__(PROTOCOL))
128
134
  else:
129
135
  array = np.ascontiguousarray(obj if obj.base is None else obj.base).view(np.uint8)
130
- self.update_hash(array.data)
136
+ self.write_bytes(array.data)
131
137
 
132
138
  case _ if torch and isinstance(obj, torch.Tensor):
133
- self.header("tensor", encode_val(obj), str(obj.dtype), tuple(obj.shape), obj.stride(), str(obj.device))
139
+ self.header("tensor", encode_type_of(obj), obj.dtype, tuple(obj.shape), obj.stride(), obj.device)
134
140
  if obj.device.type != "cpu":
135
141
  obj = obj.cpu()
136
142
  storage = obj.storage()
137
- buffer = (ctypes.c_ubyte * (storage.nbytes())).from_address(storage.data_ptr())
138
- self.update_hash(memoryview(buffer))
143
+ buffer = (ctypes.c_ubyte * storage.nbytes()).from_address(storage.data_ptr())
144
+ self.write_bytes(memoryview(buffer))
139
145
 
140
146
  case _ if id(obj) in self.current:
141
147
  self.header("circular", self.current[id(obj)])
@@ -145,36 +151,36 @@ class ObjectHash:
145
151
  self.current[id(obj)] = len(self.current)
146
152
  match obj:
147
153
  case list() | tuple():
148
- self.header("list", encode_val(obj), len(obj)).update(iter=obj)
154
+ self.header("list", encode_type_of(obj), len(obj)).update(iter=obj)
149
155
  case dict():
150
156
  try:
151
157
  items = sorted(obj.items())
158
+ header = "dict"
152
159
  except:
153
- items = sorted((str(ObjectHash(key, tolerate_errors=self.tolerate_errors.value)), val) for key, val in obj.items())
154
- self.header("dict", encode_val(obj), len(obj)).update(iter=chain.from_iterable(items))
160
+ items = sorted((self.nested_hash(key), val) for key, val in obj.items())
161
+ header = "dict-unsortable"
162
+ self.header(header, encode_type_of(obj), len(obj)).update(iter=chain.from_iterable(items))
155
163
  case _:
156
164
  self._update_object(obj)
157
165
  finally:
158
166
  del self.current[id(obj)]
159
167
 
160
- def _update_iterator(self, obj: Iterable) -> None:
161
- self.header("iterator", encode_val(obj)).update(iter=obj).header(b"iterator-end")
168
+ def _update_iterator(self, obj: Iterable) -> "ObjectHash":
169
+ return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")
162
170
 
163
171
  def _update_object(self, obj: object) -> "ObjectHash":
164
- self.header("instance", encode_val(obj))
165
- try:
166
- reduced = obj.__reduce_ex__(PROTOCOL) if hasattr(obj, "__reduce_ex__") else obj.__reduce__()
167
- except:
168
- reduced = None
172
+ self.header("instance", encode_type_of(obj))
173
+ reduced = None
174
+ with suppress(Exception):
175
+ reduced = obj.__reduce_ex__(PROTOCOL)
176
+ with suppress(Exception):
177
+ reduced = reduced or obj.__reduce__()
169
178
  if isinstance(reduced, str):
170
179
  return self.header("reduce-str").update(reduced)
171
180
  if reduced:
172
181
  reduced = list(reduced)
173
182
  it = reduced.pop(3) if len(reduced) >= 4 else None
174
- self.header("reduce").update(reduced)
175
- if it is not None:
176
- self._update_iterator(it)
177
- return self
183
+ return self.header("reduce").update(reduced)._update_iterator(it or ())
178
184
  if state := hasattr(obj, "__getstate__") and obj.__getstate__():
179
185
  return self.header("getstate").update(state)
180
186
  if len(getattr(obj, "__slots__", [])):
@@ -18,6 +18,7 @@ class PickleStorage(Storage):
18
18
  return get_path(path).exists()
19
19
 
20
20
  def checkpoint_date(self, path):
21
+ # Should use st_atime/access time?
21
22
  return datetime.fromtimestamp(get_path(path).stat().st_mtime)
22
23
 
23
24
  def load(self, path):
@@ -83,17 +83,6 @@ async def test_async_caching():
83
83
 
84
84
  assert result1 == result2 == 9
85
85
 
86
- def test_custom_path_caching():
87
- def custom_path(a, b):
88
- return f"add/{a}-{b}"
89
-
90
- @checkpoint(path=custom_path)
91
- def add(a, b):
92
- return a + b
93
-
94
- add(3, 4)
95
- assert (checkpoint.root_path / "add/3-4.pkl").exists()
96
-
97
86
  def test_force_recalculation():
98
87
  @checkpoint
99
88
  def square(x: int) -> int:
@@ -23,7 +23,7 @@ def get_fn_body(fn: Callable) -> str:
23
23
  ignore_types = (tokenize.COMMENT, tokenize.NL)
24
24
  return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
25
25
 
26
- def get_cell_contents(fn: Callable) -> Generator[tuple[str, Any], None, None]:
26
+ def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
27
27
  for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
28
28
  try:
29
29
  yield (key, cell.cell_contents)
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "checkpointer"
3
- version = "2.5.0"
3
+ version = "2.6.1"
4
4
  requires-python = ">=3.12"
5
5
  dependencies = []
6
6
  authors = [
@@ -9,6 +9,10 @@ authors = [
9
9
  description = "A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation"
10
10
  readme = "README.md"
11
11
  license = {file = "LICENSE"}
12
+ classifiers = [
13
+ "Programming Language :: Python :: 3.12",
14
+ "Programming Language :: Python :: 3.13",
15
+ ]
12
16
 
13
17
  [project.urls]
14
18
  Repository = "https://github.com/Reddan/checkpointer.git"
@@ -1,9 +1,10 @@
1
1
  version = 1
2
+ revision = 1
2
3
  requires-python = ">=3.12"
3
4
 
4
5
  [[package]]
5
6
  name = "checkpointer"
6
- version = "2.5.0"
7
+ version = "2.6.1"
7
8
  source = { editable = "." }
8
9
 
9
10
  [package.dev-dependencies]
@@ -177,7 +178,6 @@ name = "nvidia-cublas-cu12"
177
178
  version = "12.4.5.8"
178
179
  source = { registry = "https://pypi.org/simple" }
179
180
  wheels = [
180
- { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771 },
181
181
  { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 },
182
182
  ]
183
183
 
@@ -186,7 +186,6 @@ name = "nvidia-cuda-cupti-cu12"
186
186
  version = "12.4.127"
187
187
  source = { registry = "https://pypi.org/simple" }
188
188
  wheels = [
189
- { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556 },
190
189
  { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 },
191
190
  ]
192
191
 
@@ -195,7 +194,6 @@ name = "nvidia-cuda-nvrtc-cu12"
195
194
  version = "12.4.127"
196
195
  source = { registry = "https://pypi.org/simple" }
197
196
  wheels = [
198
- { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372 },
199
197
  { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 },
200
198
  ]
201
199
 
@@ -204,7 +202,6 @@ name = "nvidia-cuda-runtime-cu12"
204
202
  version = "12.4.127"
205
203
  source = { registry = "https://pypi.org/simple" }
206
204
  wheels = [
207
- { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177 },
208
205
  { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 },
209
206
  ]
210
207
 
@@ -227,7 +224,6 @@ dependencies = [
227
224
  { name = "nvidia-nvjitlink-cu12" },
228
225
  ]
229
226
  wheels = [
230
- { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 },
231
227
  { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 },
232
228
  ]
233
229
 
@@ -236,7 +232,6 @@ name = "nvidia-curand-cu12"
236
232
  version = "10.3.5.147"
237
233
  source = { registry = "https://pypi.org/simple" }
238
234
  wheels = [
239
- { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811 },
240
235
  { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 },
241
236
  ]
242
237
 
@@ -250,7 +245,6 @@ dependencies = [
250
245
  { name = "nvidia-nvjitlink-cu12" },
251
246
  ]
252
247
  wheels = [
253
- { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 },
254
248
  { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 },
255
249
  ]
256
250
 
@@ -262,7 +256,6 @@ dependencies = [
262
256
  { name = "nvidia-nvjitlink-cu12" },
263
257
  ]
264
258
  wheels = [
265
- { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 },
266
259
  { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 },
267
260
  ]
268
261
 
@@ -279,7 +272,6 @@ name = "nvidia-nvjitlink-cu12"
279
272
  version = "12.4.127"
280
273
  source = { registry = "https://pypi.org/simple" }
281
274
  wheels = [
282
- { url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510 },
283
275
  { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 },
284
276
  ]
285
277
 
@@ -288,7 +280,6 @@ name = "nvidia-nvtx-cu12"
288
280
  version = "12.4.127"
289
281
  source = { registry = "https://pypi.org/simple" }
290
282
  wheels = [
291
- { url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417 },
292
283
  { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 },
293
284
  ]
294
285
 
@@ -459,21 +450,21 @@ dependencies = [
459
450
  { name = "fsspec" },
460
451
  { name = "jinja2" },
461
452
  { name = "networkx" },
462
- { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
463
- { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
464
- { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
465
- { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
466
- { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
467
- { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
468
- { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
469
- { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
470
- { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
471
- { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
472
- { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
473
- { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
453
+ { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
454
+ { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
455
+ { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
456
+ { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
457
+ { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
458
+ { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
459
+ { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
460
+ { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
461
+ { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
462
+ { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
463
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
464
+ { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
474
465
  { name = "setuptools" },
475
466
  { name = "sympy" },
476
- { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" },
467
+ { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" },
477
468
  { name = "typing-extensions" },
478
469
  ]
479
470
  wheels = [