checkpointer 2.1.0__tar.gz → 2.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. checkpointer-2.6.0/.python-version +1 -0
  2. {checkpointer-2.1.0 → checkpointer-2.6.0}/LICENSE +1 -1
  3. {checkpointer-2.1.0 → checkpointer-2.6.0}/PKG-INFO +36 -23
  4. {checkpointer-2.1.0 → checkpointer-2.6.0}/README.md +30 -19
  5. checkpointer-2.6.0/checkpointer/__init__.py +20 -0
  6. {checkpointer-2.1.0 → checkpointer-2.6.0}/checkpointer/checkpoint.py +65 -32
  7. checkpointer-2.6.0/checkpointer/fn_ident.py +94 -0
  8. checkpointer-2.6.0/checkpointer/object_hash.py +192 -0
  9. {checkpointer-2.1.0 → checkpointer-2.6.0}/checkpointer/storages/__init__.py +1 -1
  10. {checkpointer-2.1.0 → checkpointer-2.6.0}/checkpointer/storages/bcolz_storage.py +6 -7
  11. checkpointer-2.6.0/checkpointer/storages/memory_storage.py +39 -0
  12. checkpointer-2.6.0/checkpointer/storages/pickle_storage.py +45 -0
  13. checkpointer-2.1.0/checkpointer/types.py → checkpointer-2.6.0/checkpointer/storages/storage.py +9 -5
  14. checkpointer-2.6.0/checkpointer/test_checkpointer.py +159 -0
  15. checkpointer-2.6.0/checkpointer/utils.py +112 -0
  16. {checkpointer-2.1.0 → checkpointer-2.6.0}/pyproject.toml +21 -4
  17. checkpointer-2.6.0/uv.lock +519 -0
  18. checkpointer-2.1.0/checkpointer/__init__.py +0 -10
  19. checkpointer-2.1.0/checkpointer/function_body.py +0 -80
  20. checkpointer-2.1.0/checkpointer/storages/memory_storage.py +0 -25
  21. checkpointer-2.1.0/checkpointer/storages/pickle_storage.py +0 -31
  22. checkpointer-2.1.0/checkpointer/utils.py +0 -52
  23. checkpointer-2.1.0/uv.lock +0 -22
  24. {checkpointer-2.1.0 → checkpointer-2.6.0}/.gitignore +0 -0
  25. {checkpointer-2.1.0 → checkpointer-2.6.0}/checkpointer/print_checkpoint.py +0 -0
@@ -0,0 +1 @@
1
+ 3.12.7
@@ -1,4 +1,4 @@
1
- Copyright 2024 Hampus Hallman
1
+ Copyright 2018-2025 Hampus Hallman
2
2
 
3
3
  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4
4
 
@@ -1,18 +1,20 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: checkpointer
3
- Version: 2.1.0
3
+ Version: 2.6.0
4
4
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
5
5
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
6
  Author: Hampus Hallman
7
- License: Copyright 2024 Hampus Hallman
7
+ License: Copyright 2018-2025 Hampus Hallman
8
8
 
9
9
  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
10
10
 
11
11
  The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
12
12
 
13
13
  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
14
+ License-File: LICENSE
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
14
17
  Requires-Python: >=3.12
15
- Requires-Dist: relib
16
18
  Description-Content-Type: text/markdown
17
19
 
18
20
  # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![Python 3.12](https://img.shields.io/badge/python-3.12-blue)](https://pypi.org/project/checkpointer/)
@@ -28,6 +30,7 @@ Adding or removing `@checkpoint` doesn't change how your code works. You can app
28
30
  - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
29
31
  - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
30
32
  - 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
33
+ - ⚡ **Custom Argument Hashing**: Override argument hashing for speed or specialized hashing logic.
31
34
 
32
35
  ---
33
36
 
@@ -61,9 +64,9 @@ When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are
61
64
 
62
65
  Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
63
66
 
64
- 1. **Its source code**: Changes to the function's code update its hash.
65
- 2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
66
- 3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
67
+ 1. **Function Code**: The hash updates when the function's own source code changes.
68
+ 2. **Dependencies**: If the function calls other user-defined functions, changes in those dependencies also update the hash.
69
+ 3. **External Variables** *(with `capture=True`)*: Any global or closure-based variables used by the function are included in its hash, so changes to those variables also trigger cache invalidation.
67
70
 
68
71
  ### Example: Cache Invalidation
69
72
 
@@ -108,7 +111,7 @@ Layer caches by stacking checkpoints:
108
111
  @dev_checkpoint # Adds caching during development
109
112
  def some_expensive_function():
110
113
  print("Performing a time-consuming operation...")
111
- return sum(i * i for i in range(10**6))
114
+ return sum(i * i for i in range(10**8))
112
115
  ```
113
116
 
114
117
  - **In development**: Both `dev_checkpoint` and `memory` caches are active.
@@ -153,11 +156,21 @@ Access cached results without recalculating:
153
156
  stored_result = expensive_function.get(4)
154
157
  ```
155
158
 
159
+ ### Refresh Function Hash
160
+
161
+ If `capture=True`, you might need to re-hash a function during the same Python session. For that, call `reinit`:
162
+
163
+ ```python
164
+ expensive_function.reinit()
165
+ ```
166
+
167
+ This tells `checkpointer` to recalculate the function hash, reflecting changes in captured variables.
168
+
156
169
  ---
157
170
 
158
171
  ## Storage Backends
159
172
 
160
- `checkpointer` works with both built-in and custom storage backends, so you can use what's provided or roll your own as needed.
173
+ `checkpointer` works with built-in and custom storage backends, so you can use what's provided or roll your own as needed.
161
174
 
162
175
  ### Built-In Backends
163
176
 
@@ -189,10 +202,10 @@ from checkpointer import checkpoint, Storage
189
202
  from datetime import datetime
190
203
 
191
204
  class CustomStorage(Storage):
192
- def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
193
- def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
194
- def store(self, path, data): ... # Save the checkpoint data
195
- def load(self, path): ... # Return the checkpoint data
205
+ def exists(self, path) -> bool: ... # Check if a checkpoint exists
206
+ def checkpoint_date(self, path) -> datetime: ... # Get the checkpoint's timestamp
207
+ def store(self, path, data): ... # Save data to the checkpoint
208
+ def load(self, path): ... # Load data from the checkpoint
196
209
  def delete(self, path): ... # Delete the checkpoint
197
210
 
198
211
  @checkpoint(format=CustomStorage)
@@ -200,21 +213,21 @@ def custom_cached(x: int):
200
213
  return x ** 2
201
214
  ```
202
215
 
203
- Using a custom backend lets you tailor storage to your application, whether it involves databases, cloud storage, or custom file formats.
216
+ Use a custom backend to integrate with databases, cloud storage, or specialized file formats.
204
217
 
205
218
  ---
206
219
 
207
220
  ## Configuration Options ⚙️
208
221
 
209
- | Option | Type | Default | Description |
210
- |-----------------|-----------------------------------|----------------------|------------------------------------------------|
211
- | `capture` | `bool` | `False` | Include captured variables in function hashes. |
212
- | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
213
- | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
214
- | `when` | `bool` | `True` | Enable or disable checkpointing. |
215
- | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
216
- | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
217
- | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
222
+ | Option | Type | Default | Description |
223
+ |-----------------|-------------------------------------|----------------------|-----------------------------------------------------------|
224
+ | `capture` | `bool` | `False` | Include captured variables in function hashes. |
225
+ | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
226
+ | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
227
+ | `when` | `bool` | `True` | Enable or disable checkpointing. |
228
+ | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
229
+ | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
230
+ | `hash_by` | `Callable[..., Any]` | `None` | Custom function that transforms arguments before hashing. |
218
231
 
219
232
  ---
220
233
 
@@ -11,6 +11,7 @@ Adding or removing `@checkpoint` doesn't change how your code works. You can app
11
11
  - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
12
12
  - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
13
13
  - 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
14
+ - ⚡ **Custom Argument Hashing**: Override argument hashing for speed or specialized hashing logic.
14
15
 
15
16
  ---
16
17
 
@@ -44,9 +45,9 @@ When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are
44
45
 
45
46
  Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
46
47
 
47
- 1. **Its source code**: Changes to the function's code update its hash.
48
- 2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
49
- 3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
48
+ 1. **Function Code**: The hash updates when the function's own source code changes.
49
+ 2. **Dependencies**: If the function calls other user-defined functions, changes in those dependencies also update the hash.
50
+ 3. **External Variables** *(with `capture=True`)*: Any global or closure-based variables used by the function are included in its hash, so changes to those variables also trigger cache invalidation.
50
51
 
51
52
  ### Example: Cache Invalidation
52
53
 
@@ -91,7 +92,7 @@ Layer caches by stacking checkpoints:
91
92
  @dev_checkpoint # Adds caching during development
92
93
  def some_expensive_function():
93
94
  print("Performing a time-consuming operation...")
94
- return sum(i * i for i in range(10**6))
95
+ return sum(i * i for i in range(10**8))
95
96
  ```
96
97
 
97
98
  - **In development**: Both `dev_checkpoint` and `memory` caches are active.
@@ -136,11 +137,21 @@ Access cached results without recalculating:
136
137
  stored_result = expensive_function.get(4)
137
138
  ```
138
139
 
140
+ ### Refresh Function Hash
141
+
142
+ If `capture=True`, you might need to re-hash a function during the same Python session. For that, call `reinit`:
143
+
144
+ ```python
145
+ expensive_function.reinit()
146
+ ```
147
+
148
+ This tells `checkpointer` to recalculate the function hash, reflecting changes in captured variables.
149
+
139
150
  ---
140
151
 
141
152
  ## Storage Backends
142
153
 
143
- `checkpointer` works with both built-in and custom storage backends, so you can use what's provided or roll your own as needed.
154
+ `checkpointer` works with built-in and custom storage backends, so you can use what's provided or roll your own as needed.
144
155
 
145
156
  ### Built-In Backends
146
157
 
@@ -172,10 +183,10 @@ from checkpointer import checkpoint, Storage
172
183
  from datetime import datetime
173
184
 
174
185
  class CustomStorage(Storage):
175
- def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
176
- def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
177
- def store(self, path, data): ... # Save the checkpoint data
178
- def load(self, path): ... # Return the checkpoint data
186
+ def exists(self, path) -> bool: ... # Check if a checkpoint exists
187
+ def checkpoint_date(self, path) -> datetime: ... # Get the checkpoint's timestamp
188
+ def store(self, path, data): ... # Save data to the checkpoint
189
+ def load(self, path): ... # Load data from the checkpoint
179
190
  def delete(self, path): ... # Delete the checkpoint
180
191
 
181
192
  @checkpoint(format=CustomStorage)
@@ -183,21 +194,21 @@ def custom_cached(x: int):
183
194
  return x ** 2
184
195
  ```
185
196
 
186
- Using a custom backend lets you tailor storage to your application, whether it involves databases, cloud storage, or custom file formats.
197
+ Use a custom backend to integrate with databases, cloud storage, or specialized file formats.
187
198
 
188
199
  ---
189
200
 
190
201
  ## Configuration Options ⚙️
191
202
 
192
- | Option | Type | Default | Description |
193
- |-----------------|-----------------------------------|----------------------|------------------------------------------------|
194
- | `capture` | `bool` | `False` | Include captured variables in function hashes. |
195
- | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
196
- | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
197
- | `when` | `bool` | `True` | Enable or disable checkpointing. |
198
- | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
199
- | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
200
- | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
203
+ | Option | Type | Default | Description |
204
+ |-----------------|-------------------------------------|----------------------|-----------------------------------------------------------|
205
+ | `capture` | `bool` | `False` | Include captured variables in function hashes. |
206
+ | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
207
+ | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
208
+ | `when` | `bool` | `True` | Enable or disable checkpointing. |
209
+ | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
210
+ | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
211
+ | `hash_by` | `Callable[..., Any]` | `None` | Custom function that transforms arguments before hashing. |
201
212
 
202
213
  ---
203
214
 
@@ -0,0 +1,20 @@
1
+ import gc
2
+ import tempfile
3
+ from typing import Callable
4
+ from .checkpoint import Checkpointer, CheckpointError, CheckpointFn
5
+ from .object_hash import ObjectHash
6
+ from .storages import MemoryStorage, PickleStorage, Storage
7
+
8
+ create_checkpointer = Checkpointer
9
+ checkpoint = Checkpointer()
10
+ capture_checkpoint = Checkpointer(capture=True)
11
+ memory_checkpoint = Checkpointer(format="memory", verbosity=0)
12
+ tmp_checkpoint = Checkpointer(root_path=tempfile.gettempdir() + "/checkpoints")
13
+
14
+ def cleanup_all(invalidated=True, expired=True):
15
+ for obj in gc.get_objects():
16
+ if isinstance(obj, CheckpointFn):
17
+ obj.cleanup(invalidated=invalidated, expired=expired)
18
+
19
+ def get_function_hash(fn: Callable, capture=False) -> str:
20
+ return CheckpointFn(Checkpointer(capture=capture), fn).fn_hash
@@ -1,15 +1,15 @@
1
1
  from __future__ import annotations
2
2
  import inspect
3
- import relib.hashing as hashing
4
- from typing import Generic, TypeVar, Type, TypedDict, Callable, Unpack, Literal, Any, cast, overload
5
- from pathlib import Path
3
+ import re
6
4
  from datetime import datetime
7
5
  from functools import update_wrapper
8
- from .types import Storage
9
- from .function_body import get_function_hash
10
- from .utils import unwrap_fn, sync_resolve_coroutine, resolved_awaitable
11
- from .storages import STORAGE_MAP
6
+ from pathlib import Path
7
+ from typing import Any, Callable, Generic, Iterable, Literal, Type, TypedDict, TypeVar, Unpack, cast, overload
8
+ from .fn_ident import get_fn_ident
9
+ from .object_hash import ObjectHash
12
10
  from .print_checkpoint import print_checkpoint
11
+ from .storages import STORAGE_MAP, Storage
12
+ from .utils import resolved_awaitable, sync_resolve_coroutine, unwrap_fn
13
13
 
14
14
  Fn = TypeVar("Fn", bound=Callable)
15
15
 
@@ -23,7 +23,7 @@ class CheckpointerOpts(TypedDict, total=False):
23
23
  root_path: Path | str | None
24
24
  when: bool
25
25
  verbosity: Literal[0, 1]
26
- path: Callable[..., str] | None
26
+ hash_by: Callable | None
27
27
  should_expire: Callable[[datetime], bool] | None
28
28
  capture: bool
29
29
 
@@ -33,7 +33,7 @@ class Checkpointer:
33
33
  self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
34
34
  self.when = opts.get("when", True)
35
35
  self.verbosity = opts.get("verbosity", 1)
36
- self.path = opts.get("path")
36
+ self.hash_by = opts.get("hash_by")
37
37
  self.should_expire = opts.get("should_expire")
38
38
  self.capture = opts.get("capture", False)
39
39
 
@@ -50,37 +50,58 @@ class Checkpointer:
50
50
 
51
51
  class CheckpointFn(Generic[Fn]):
52
52
  def __init__(self, checkpointer: Checkpointer, fn: Fn):
53
- wrapped = unwrap_fn(fn)
54
- file_name = Path(wrapped.__code__.co_filename).name
55
- update_wrapper(cast(Callable, self), wrapped)
56
- storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
57
53
  self.checkpointer = checkpointer
58
54
  self.fn = fn
59
- self.fn_hash, self.depends = get_function_hash(wrapped, self.checkpointer.capture)
60
- self.fn_id = f"{file_name}/{wrapped.__name__}"
55
+
56
+ def _set_ident(self, force=False):
57
+ if not hasattr(self, "fn_hash_raw") or force:
58
+ self.fn_hash_raw, self.depends = get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
59
+ return self
60
+
61
+ def _lazyinit(self):
62
+ wrapped = unwrap_fn(self.fn)
63
+ fn_file = Path(wrapped.__code__.co_filename).name
64
+ fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
65
+ update_wrapper(cast(Callable, self), wrapped)
66
+ store_format = self.checkpointer.format
67
+ Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
68
+ deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
69
+ self.fn_hash = str(ObjectHash().write_text(self.fn_hash_raw, iter=deep_hashes))
70
+ self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
61
71
  self.is_async = inspect.iscoroutinefunction(wrapped)
62
- self.storage = storage(checkpointer)
72
+ self.storage = Storage(self)
73
+ self.cleanup = self.storage.cleanup
74
+
75
+ def __getattribute__(self, name: str) -> Any:
76
+ return object.__getattribute__(self, "_getattribute")(name)
77
+
78
+ def _getattribute(self, name: str) -> Any:
79
+ setattr(self, "_getattribute", super().__getattribute__)
80
+ self._lazyinit()
81
+ return self._getattribute(name)
82
+
83
+ def reinit(self, recursive=False):
84
+ pointfns = list(iterate_checkpoint_fns(self)) if recursive else [self]
85
+ for pointfn in pointfns:
86
+ pointfn._set_ident(True)
87
+ for pointfn in pointfns:
88
+ pointfn._lazyinit()
63
89
 
64
90
  def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
65
- if not callable(self.checkpointer.path):
66
- # TODO: use digest size before digesting instead of truncating the hash
67
- call_hash = hashing.hash((self.fn_hash, args, kw), "blake2b")[:32]
68
- return f"{self.fn_id}/{call_hash}"
69
- checkpoint_id = self.checkpointer.path(*args, **kw)
70
- if not isinstance(checkpoint_id, str):
71
- raise CheckpointError(f"path function must return a string, got {type(checkpoint_id)}")
72
- return checkpoint_id
91
+ hash_params = [self.checkpointer.hash_by(*args, **kw)] if self.checkpointer.hash_by else (args, kw)
92
+ call_hash = ObjectHash(self.fn_hash, *hash_params, digest_size=16)
93
+ return f"{self.fn_subdir}/{call_hash}"
73
94
 
74
95
  async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
75
96
  checkpoint_id = self.get_checkpoint_id(args, kw)
76
97
  checkpoint_path = self.checkpointer.root_path / checkpoint_id
77
- should_log = self.checkpointer.verbosity > 0
98
+ verbose = self.checkpointer.verbosity > 0
78
99
  refresh = rerun \
79
100
  or not self.storage.exists(checkpoint_path) \
80
101
  or (self.checkpointer.should_expire and self.checkpointer.should_expire(self.storage.checkpoint_date(checkpoint_path)))
81
102
 
82
103
  if refresh:
83
- print_checkpoint(should_log, "MEMORIZING", checkpoint_id, "blue")
104
+ print_checkpoint(verbose, "MEMORIZING", checkpoint_id, "blue")
84
105
  data = self.fn(*args, **kw)
85
106
  if inspect.iscoroutine(data):
86
107
  data = await data
@@ -89,12 +110,12 @@ class CheckpointFn(Generic[Fn]):
89
110
 
90
111
  try:
91
112
  data = self.storage.load(checkpoint_path)
92
- print_checkpoint(should_log, "REMEMBERED", checkpoint_id, "green")
113
+ print_checkpoint(verbose, "REMEMBERED", checkpoint_id, "green")
93
114
  return data
94
115
  except (EOFError, FileNotFoundError):
95
- print_checkpoint(should_log, "CORRUPTED", checkpoint_id, "yellow")
96
- self.storage.delete(checkpoint_path)
97
- return await self._store_on_demand(args, kw, rerun)
116
+ pass
117
+ print_checkpoint(verbose, "CORRUPTED", checkpoint_id, "yellow")
118
+ return await self._store_on_demand(args, kw, True)
98
119
 
99
120
  def _call(self, args: tuple, kw: dict, rerun=False):
100
121
  if not self.checkpointer.when:
@@ -107,8 +128,8 @@ class CheckpointFn(Generic[Fn]):
107
128
  try:
108
129
  val = self.storage.load(checkpoint_path)
109
130
  return resolved_awaitable(val) if self.is_async else val
110
- except:
111
- raise CheckpointError("Could not load checkpoint")
131
+ except Exception as ex:
132
+ raise CheckpointError("Could not load checkpoint") from ex
112
133
 
113
134
  def exists(self, *args: tuple, **kw: dict) -> bool:
114
135
  return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
@@ -116,3 +137,15 @@ class CheckpointFn(Generic[Fn]):
116
137
  __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
117
138
  rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
118
139
  get: Fn = cast(Fn, lambda self, *args, **kw: self._get(args, kw))
140
+
141
+ def __repr__(self) -> str:
142
+ return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
143
+
144
+ def iterate_checkpoint_fns(pointfn: CheckpointFn, visited: set[CheckpointFn] = set()) -> Iterable[CheckpointFn]:
145
+ visited = visited or set()
146
+ if pointfn not in visited:
147
+ yield pointfn
148
+ visited.add(pointfn)
149
+ for depend in pointfn.depends:
150
+ if isinstance(depend, CheckpointFn):
151
+ yield from iterate_checkpoint_fns(depend, visited)
@@ -0,0 +1,94 @@
1
+ import dis
2
+ import inspect
3
+ from collections.abc import Callable
4
+ from itertools import takewhile
5
+ from pathlib import Path
6
+ from types import CodeType, FunctionType, MethodType
7
+ from typing import Any, Generator, Type, TypeGuard
8
+ from .object_hash import ObjectHash
9
+ from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, transpose, unwrap_fn
10
+
11
+ cwd = Path.cwd()
12
+
13
+ def is_class(obj) -> TypeGuard[Type]:
14
+ # isinstance works too, but needlessly triggers _lazyinit()
15
+ return issubclass(type(obj), type)
16
+
17
+ def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
18
+ attr_path: tuple[str, ...] = ()
19
+ scope_obj = None
20
+ classvars: dict[str, dict[str, Type]] = {}
21
+ for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
22
+ if instr.opname in scope_vars and not attr_path:
23
+ attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
24
+ attr_path = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
25
+ elif instr.opname == "CALL":
26
+ obj = scope_vars.get_at(attr_path)
27
+ attr_path = ()
28
+ if is_class(obj):
29
+ scope_obj = obj
30
+ elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
31
+ load_key = instr.opname.replace("STORE", "LOAD")
32
+ classvars.setdefault(load_key, {})[instr.argval] = scope_obj
33
+ scope_obj = None
34
+ return classvars
35
+
36
+ def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Generator[tuple[tuple[str, ...], Any], None, None]:
37
+ classvars = extract_classvars(code, scope_vars)
38
+ scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
39
+ for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
40
+ if instr.opname in scope_vars:
41
+ attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
42
+ attr_path: tuple[str, ...] = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
43
+ val = scope_vars.get_at(attr_path)
44
+ if val is not None:
45
+ yield attr_path, val
46
+ for const in code.co_consts:
47
+ if isinstance(const, CodeType):
48
+ yield from extract_scope_values(const, scope_vars)
49
+
50
+ def get_self_value(fn: Callable) -> type | object | None:
51
+ if isinstance(fn, MethodType):
52
+ return fn.__self__
53
+ parts = tuple(fn.__qualname__.split(".")[:-1])
54
+ cls = parts and AttrDict(fn.__globals__).get_at(parts)
55
+ if is_class(cls):
56
+ return cls
57
+
58
+ def get_fn_captured_vals(fn: Callable) -> list[Any]:
59
+ self_value = get_self_value(fn)
60
+ scope_vars = AttrDict({
61
+ "LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
62
+ "LOAD_DEREF": AttrDict(get_cell_contents(fn)),
63
+ "LOAD_GLOBAL": AttrDict(fn.__globals__),
64
+ })
65
+ vals = dict(extract_scope_values(fn.__code__, scope_vars))
66
+ return list(vals.values())
67
+
68
+ def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
69
+ if not isinstance(candidate_fn, (FunctionType, MethodType)):
70
+ return False
71
+ fn_path = Path(inspect.getfile(candidate_fn)).resolve()
72
+ return cwd in fn_path.parents and ".venv" not in fn_path.parts
73
+
74
+ def get_depend_fns(fn: Callable, capture: bool, captured_vals_by_fn: dict[Callable, list[Any]] = {}) -> dict[Callable, list[Any]]:
75
+ from .checkpoint import CheckpointFn
76
+ captured_vals_by_fn = captured_vals_by_fn or {}
77
+ captured_vals = get_fn_captured_vals(fn)
78
+ captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)] * capture
79
+ child_fns = (unwrap_fn(val, checkpoint_fn=True) for val in captured_vals if callable(val))
80
+ for child_fn in child_fns:
81
+ if isinstance(child_fn, CheckpointFn):
82
+ captured_vals_by_fn[child_fn] = []
83
+ elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
84
+ get_depend_fns(child_fn, capture, captured_vals_by_fn)
85
+ return captured_vals_by_fn
86
+
87
+ def get_fn_ident(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
88
+ from .checkpoint import CheckpointFn
89
+ captured_vals_by_fn = get_depend_fns(fn, capture)
90
+ depends, depend_captured_vals = transpose(captured_vals_by_fn.items(), 2)
91
+ depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
92
+ unwrapped_depends = [fn for fn in depends if not isinstance(fn, CheckpointFn)]
93
+ fn_hash = str(ObjectHash(fn, unwrapped_depends).update(depend_captured_vals, tolerate_errors=True))
94
+ return fn_hash, depends