checkpointer 2.1.0__tar.gz → 2.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpointer-2.6.0/.python-version +1 -0
- {checkpointer-2.1.0 → checkpointer-2.6.0}/LICENSE +1 -1
- {checkpointer-2.1.0 → checkpointer-2.6.0}/PKG-INFO +36 -23
- {checkpointer-2.1.0 → checkpointer-2.6.0}/README.md +30 -19
- checkpointer-2.6.0/checkpointer/__init__.py +20 -0
- {checkpointer-2.1.0 → checkpointer-2.6.0}/checkpointer/checkpoint.py +65 -32
- checkpointer-2.6.0/checkpointer/fn_ident.py +94 -0
- checkpointer-2.6.0/checkpointer/object_hash.py +192 -0
- {checkpointer-2.1.0 → checkpointer-2.6.0}/checkpointer/storages/__init__.py +1 -1
- {checkpointer-2.1.0 → checkpointer-2.6.0}/checkpointer/storages/bcolz_storage.py +6 -7
- checkpointer-2.6.0/checkpointer/storages/memory_storage.py +39 -0
- checkpointer-2.6.0/checkpointer/storages/pickle_storage.py +45 -0
- checkpointer-2.1.0/checkpointer/types.py → checkpointer-2.6.0/checkpointer/storages/storage.py +9 -5
- checkpointer-2.6.0/checkpointer/test_checkpointer.py +159 -0
- checkpointer-2.6.0/checkpointer/utils.py +112 -0
- {checkpointer-2.1.0 → checkpointer-2.6.0}/pyproject.toml +21 -4
- checkpointer-2.6.0/uv.lock +519 -0
- checkpointer-2.1.0/checkpointer/__init__.py +0 -10
- checkpointer-2.1.0/checkpointer/function_body.py +0 -80
- checkpointer-2.1.0/checkpointer/storages/memory_storage.py +0 -25
- checkpointer-2.1.0/checkpointer/storages/pickle_storage.py +0 -31
- checkpointer-2.1.0/checkpointer/utils.py +0 -52
- checkpointer-2.1.0/uv.lock +0 -22
- {checkpointer-2.1.0 → checkpointer-2.6.0}/.gitignore +0 -0
- {checkpointer-2.1.0 → checkpointer-2.6.0}/checkpointer/print_checkpoint.py +0 -0
@@ -0,0 +1 @@
|
|
1
|
+
3.12.7
|
@@ -1,4 +1,4 @@
|
|
1
|
-
Copyright
|
1
|
+
Copyright 2018-2025 Hampus Hallman
|
2
2
|
|
3
3
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
4
4
|
|
@@ -1,18 +1,20 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: checkpointer
|
3
|
-
Version: 2.1.0
|
3
|
+
Version: 2.6.0
|
4
4
|
Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
|
5
5
|
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
6
|
Author: Hampus Hallman
|
7
|
-
License: Copyright
|
7
|
+
License: Copyright 2018-2025 Hampus Hallman
|
8
8
|
|
9
9
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
10
10
|
|
11
11
|
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
12
12
|
|
13
13
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
14
|
+
License-File: LICENSE
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
16
|
+
Classifier: Programming Language :: Python :: 3.13
|
14
17
|
Requires-Python: >=3.12
|
15
|
-
Requires-Dist: relib
|
16
18
|
Description-Content-Type: text/markdown
|
17
19
|
|
18
20
|
# checkpointer · [](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [](https://pypi.org/project/checkpointer/) [](https://pypi.org/project/checkpointer/)
|
@@ -28,6 +30,7 @@ Adding or removing `@checkpoint` doesn't change how your code works. You can app
|
|
28
30
|
- ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
|
29
31
|
- 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
|
30
32
|
- 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
|
33
|
+
- ⚡ **Custom Argument Hashing**: Override argument hashing for speed or specialized hashing logic.
|
31
34
|
|
32
35
|
---
|
33
36
|
|
@@ -61,9 +64,9 @@ When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are
|
|
61
64
|
|
62
65
|
Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
|
63
66
|
|
64
|
-
1. **
|
65
|
-
2. **
|
66
|
-
3. **
|
67
|
+
1. **Function Code**: The hash updates when the function’s own source code changes.
|
68
|
+
2. **Dependencies**: If the function calls other user-defined functions, changes in those dependencies also update the hash.
|
69
|
+
3. **External Variables** *(with `capture=True`)*: Any global or closure-based variables used by the function are included in its hash, so changes to those variables also trigger cache invalidation.
|
67
70
|
|
68
71
|
### Example: Cache Invalidation
|
69
72
|
|
@@ -108,7 +111,7 @@ Layer caches by stacking checkpoints:
|
|
108
111
|
@dev_checkpoint # Adds caching during development
|
109
112
|
def some_expensive_function():
|
110
113
|
print("Performing a time-consuming operation...")
|
111
|
-
return sum(i * i for i in range(10**
|
114
|
+
return sum(i * i for i in range(10**8))
|
112
115
|
```
|
113
116
|
|
114
117
|
- **In development**: Both `dev_checkpoint` and `memory` caches are active.
|
@@ -153,11 +156,21 @@ Access cached results without recalculating:
|
|
153
156
|
stored_result = expensive_function.get(4)
|
154
157
|
```
|
155
158
|
|
159
|
+
### Refresh Function Hash
|
160
|
+
|
161
|
+
If `capture=True`, you might need to re-hash a function during the same Python session. For that, call `reinit`:
|
162
|
+
|
163
|
+
```python
|
164
|
+
expensive_function.reinit()
|
165
|
+
```
|
166
|
+
|
167
|
+
This tells `checkpointer` to recalculate the function hash, reflecting changes in captured variables.
|
168
|
+
|
156
169
|
---
|
157
170
|
|
158
171
|
## Storage Backends
|
159
172
|
|
160
|
-
`checkpointer` works with
|
173
|
+
`checkpointer` works with built-in and custom storage backends, so you can use what's provided or roll your own as needed.
|
161
174
|
|
162
175
|
### Built-In Backends
|
163
176
|
|
@@ -189,10 +202,10 @@ from checkpointer import checkpoint, Storage
|
|
189
202
|
from datetime import datetime
|
190
203
|
|
191
204
|
class CustomStorage(Storage):
|
192
|
-
def exists(self, path) -> bool: ... # Check if a checkpoint exists
|
193
|
-
def checkpoint_date(self, path) -> datetime: ... #
|
194
|
-
def store(self, path, data): ... # Save the checkpoint
|
195
|
-
def load(self, path): ... #
|
205
|
+
def exists(self, path) -> bool: ... # Check if a checkpoint exists
|
206
|
+
def checkpoint_date(self, path) -> datetime: ... # Get the checkpoint's timestamp
|
207
|
+
def store(self, path, data): ... # Save data to the checkpoint
|
208
|
+
def load(self, path): ... # Load data from the checkpoint
|
196
209
|
def delete(self, path): ... # Delete the checkpoint
|
197
210
|
|
198
211
|
@checkpoint(format=CustomStorage)
|
@@ -200,21 +213,21 @@ def custom_cached(x: int):
|
|
200
213
|
return x ** 2
|
201
214
|
```
|
202
215
|
|
203
|
-
|
216
|
+
Use a custom backend to integrate with databases, cloud storage, or specialized file formats.
|
204
217
|
|
205
218
|
---
|
206
219
|
|
207
220
|
## Configuration Options ⚙️
|
208
221
|
|
209
|
-
| Option | Type
|
210
|
-
|
211
|
-
| `capture` | `bool`
|
212
|
-
| `format` | `"pickle"`, `"memory"`, `Storage`
|
213
|
-
| `root_path` | `Path`, `str`, or `None`
|
214
|
-
| `when` | `bool`
|
215
|
-
| `verbosity` | `0` or `1`
|
216
|
-
| `
|
217
|
-
| `
|
222
|
+
| Option | Type | Default | Description |
|
223
|
+
|-----------------|-------------------------------------|----------------------|-----------------------------------------------------------|
|
224
|
+
| `capture` | `bool` | `False` | Include captured variables in function hashes. |
|
225
|
+
| `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
|
226
|
+
| `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
|
227
|
+
| `when` | `bool` | `True` | Enable or disable checkpointing. |
|
228
|
+
| `verbosity` | `0` or `1` | `1` | Logging verbosity. |
|
229
|
+
| `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
|
230
|
+
| `hash_by` | `Callable[..., Any]` | `None` | Custom function that transforms arguments before hashing. |
|
218
231
|
|
219
232
|
---
|
220
233
|
|
@@ -11,6 +11,7 @@ Adding or removing `@checkpoint` doesn't change how your code works. You can app
|
|
11
11
|
- ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
|
12
12
|
- 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
|
13
13
|
- 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
|
14
|
+
- ⚡ **Custom Argument Hashing**: Override argument hashing for speed or specialized hashing logic.
|
14
15
|
|
15
16
|
---
|
16
17
|
|
@@ -44,9 +45,9 @@ When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are
|
|
44
45
|
|
45
46
|
Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
|
46
47
|
|
47
|
-
1. **
|
48
|
-
2. **
|
49
|
-
3. **
|
48
|
+
1. **Function Code**: The hash updates when the function’s own source code changes.
|
49
|
+
2. **Dependencies**: If the function calls other user-defined functions, changes in those dependencies also update the hash.
|
50
|
+
3. **External Variables** *(with `capture=True`)*: Any global or closure-based variables used by the function are included in its hash, so changes to those variables also trigger cache invalidation.
|
50
51
|
|
51
52
|
### Example: Cache Invalidation
|
52
53
|
|
@@ -91,7 +92,7 @@ Layer caches by stacking checkpoints:
|
|
91
92
|
@dev_checkpoint # Adds caching during development
|
92
93
|
def some_expensive_function():
|
93
94
|
print("Performing a time-consuming operation...")
|
94
|
-
return sum(i * i for i in range(10**
|
95
|
+
return sum(i * i for i in range(10**8))
|
95
96
|
```
|
96
97
|
|
97
98
|
- **In development**: Both `dev_checkpoint` and `memory` caches are active.
|
@@ -136,11 +137,21 @@ Access cached results without recalculating:
|
|
136
137
|
stored_result = expensive_function.get(4)
|
137
138
|
```
|
138
139
|
|
140
|
+
### Refresh Function Hash
|
141
|
+
|
142
|
+
If `capture=True`, you might need to re-hash a function during the same Python session. For that, call `reinit`:
|
143
|
+
|
144
|
+
```python
|
145
|
+
expensive_function.reinit()
|
146
|
+
```
|
147
|
+
|
148
|
+
This tells `checkpointer` to recalculate the function hash, reflecting changes in captured variables.
|
149
|
+
|
139
150
|
---
|
140
151
|
|
141
152
|
## Storage Backends
|
142
153
|
|
143
|
-
`checkpointer` works with
|
154
|
+
`checkpointer` works with built-in and custom storage backends, so you can use what's provided or roll your own as needed.
|
144
155
|
|
145
156
|
### Built-In Backends
|
146
157
|
|
@@ -172,10 +183,10 @@ from checkpointer import checkpoint, Storage
|
|
172
183
|
from datetime import datetime
|
173
184
|
|
174
185
|
class CustomStorage(Storage):
|
175
|
-
def exists(self, path) -> bool: ... # Check if a checkpoint exists
|
176
|
-
def checkpoint_date(self, path) -> datetime: ... #
|
177
|
-
def store(self, path, data): ... # Save the checkpoint
|
178
|
-
def load(self, path): ... #
|
186
|
+
def exists(self, path) -> bool: ... # Check if a checkpoint exists
|
187
|
+
def checkpoint_date(self, path) -> datetime: ... # Get the checkpoint's timestamp
|
188
|
+
def store(self, path, data): ... # Save data to the checkpoint
|
189
|
+
def load(self, path): ... # Load data from the checkpoint
|
179
190
|
def delete(self, path): ... # Delete the checkpoint
|
180
191
|
|
181
192
|
@checkpoint(format=CustomStorage)
|
@@ -183,21 +194,21 @@ def custom_cached(x: int):
|
|
183
194
|
return x ** 2
|
184
195
|
```
|
185
196
|
|
186
|
-
|
197
|
+
Use a custom backend to integrate with databases, cloud storage, or specialized file formats.
|
187
198
|
|
188
199
|
---
|
189
200
|
|
190
201
|
## Configuration Options ⚙️
|
191
202
|
|
192
|
-
| Option | Type
|
193
|
-
|
194
|
-
| `capture` | `bool`
|
195
|
-
| `format` | `"pickle"`, `"memory"`, `Storage`
|
196
|
-
| `root_path` | `Path`, `str`, or `None`
|
197
|
-
| `when` | `bool`
|
198
|
-
| `verbosity` | `0` or `1`
|
199
|
-
| `
|
200
|
-
| `
|
203
|
+
| Option | Type | Default | Description |
|
204
|
+
|-----------------|-------------------------------------|----------------------|-----------------------------------------------------------|
|
205
|
+
| `capture` | `bool` | `False` | Include captured variables in function hashes. |
|
206
|
+
| `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
|
207
|
+
| `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
|
208
|
+
| `when` | `bool` | `True` | Enable or disable checkpointing. |
|
209
|
+
| `verbosity` | `0` or `1` | `1` | Logging verbosity. |
|
210
|
+
| `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
|
211
|
+
| `hash_by` | `Callable[..., Any]` | `None` | Custom function that transforms arguments before hashing. |
|
201
212
|
|
202
213
|
---
|
203
214
|
|
@@ -0,0 +1,20 @@
|
|
1
|
+
import gc
|
2
|
+
import tempfile
|
3
|
+
from typing import Callable
|
4
|
+
from .checkpoint import Checkpointer, CheckpointError, CheckpointFn
|
5
|
+
from .object_hash import ObjectHash
|
6
|
+
from .storages import MemoryStorage, PickleStorage, Storage
|
7
|
+
|
8
|
+
# Public factory alias: call it with options to build a custom decorator.
create_checkpointer = Checkpointer

# Ready-made decorators for common configurations.
checkpoint = Checkpointer()  # all default options
capture_checkpoint = Checkpointer(capture=True)  # also hash captured variables
memory_checkpoint = Checkpointer(format="memory", verbosity=0)  # "memory" backend, verbosity 0
tmp_checkpoint = Checkpointer(root_path=f"{tempfile.gettempdir()}/checkpoints")  # stored under the system temp dir
|
13
|
+
|
14
|
+
def cleanup_all(invalidated=True, expired=True):
    """Call ``cleanup`` on every ``CheckpointFn`` instance that is currently alive.

    Instances are discovered by scanning the objects tracked by the garbage
    collector, so decorated functions never have to be registered explicitly.

    Args:
        invalidated: forwarded unchanged to each ``CheckpointFn.cleanup`` call.
        expired: forwarded unchanged to each ``CheckpointFn.cleanup`` call.
    """
    live_pointfns = (candidate for candidate in gc.get_objects() if isinstance(candidate, CheckpointFn))
    for pointfn in live_pointfns:
        pointfn.cleanup(invalidated=invalidated, expired=expired)
|
18
|
+
|
19
|
+
def get_function_hash(fn: Callable, capture=False) -> str:
    """Return the ``fn_hash`` that a ``CheckpointFn`` wrapping ``fn`` would use.

    A throwaway ``Checkpointer(capture=capture)`` and ``CheckpointFn`` are
    built around ``fn``. Reading ``fn_hash`` triggers the wrapper's lazy
    initialisation, which is where the hash is computed.

    Args:
        fn: the function to hash.
        capture: whether captured variables contribute to the hash.
    """
    throwaway_checkpointer = Checkpointer(capture=capture)
    wrapper = CheckpointFn(throwaway_checkpointer, fn)
    return wrapper.fn_hash
|
@@ -1,15 +1,15 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import inspect
|
3
|
-
import
|
4
|
-
from typing import Generic, TypeVar, Type, TypedDict, Callable, Unpack, Literal, Any, cast, overload
|
5
|
-
from pathlib import Path
|
3
|
+
import re
|
6
4
|
from datetime import datetime
|
7
5
|
from functools import update_wrapper
|
8
|
-
from
|
9
|
-
from
|
10
|
-
from .
|
11
|
-
from .
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Any, Callable, Generic, Iterable, Literal, Type, TypedDict, TypeVar, Unpack, cast, overload
|
8
|
+
from .fn_ident import get_fn_ident
|
9
|
+
from .object_hash import ObjectHash
|
12
10
|
from .print_checkpoint import print_checkpoint
|
11
|
+
from .storages import STORAGE_MAP, Storage
|
12
|
+
from .utils import resolved_awaitable, sync_resolve_coroutine, unwrap_fn
|
13
13
|
|
14
14
|
Fn = TypeVar("Fn", bound=Callable)
|
15
15
|
|
@@ -23,7 +23,7 @@ class CheckpointerOpts(TypedDict, total=False):
|
|
23
23
|
root_path: Path | str | None
|
24
24
|
when: bool
|
25
25
|
verbosity: Literal[0, 1]
|
26
|
-
|
26
|
+
hash_by: Callable | None
|
27
27
|
should_expire: Callable[[datetime], bool] | None
|
28
28
|
capture: bool
|
29
29
|
|
@@ -33,7 +33,7 @@ class Checkpointer:
|
|
33
33
|
self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
|
34
34
|
self.when = opts.get("when", True)
|
35
35
|
self.verbosity = opts.get("verbosity", 1)
|
36
|
-
self.
|
36
|
+
self.hash_by = opts.get("hash_by")
|
37
37
|
self.should_expire = opts.get("should_expire")
|
38
38
|
self.capture = opts.get("capture", False)
|
39
39
|
|
@@ -50,37 +50,58 @@ class Checkpointer:
|
|
50
50
|
|
51
51
|
class CheckpointFn(Generic[Fn]):
|
52
52
|
def __init__(self, checkpointer: Checkpointer, fn: Fn):
|
53
|
-
wrapped = unwrap_fn(fn)
|
54
|
-
file_name = Path(wrapped.__code__.co_filename).name
|
55
|
-
update_wrapper(cast(Callable, self), wrapped)
|
56
|
-
storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
|
57
53
|
self.checkpointer = checkpointer
|
58
54
|
self.fn = fn
|
59
|
-
|
60
|
-
|
55
|
+
|
56
|
+
def _set_ident(self, force=False):
|
57
|
+
if not hasattr(self, "fn_hash_raw") or force:
|
58
|
+
self.fn_hash_raw, self.depends = get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
|
59
|
+
return self
|
60
|
+
|
61
|
+
def _lazyinit(self):
|
62
|
+
wrapped = unwrap_fn(self.fn)
|
63
|
+
fn_file = Path(wrapped.__code__.co_filename).name
|
64
|
+
fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
|
65
|
+
update_wrapper(cast(Callable, self), wrapped)
|
66
|
+
store_format = self.checkpointer.format
|
67
|
+
Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
|
68
|
+
deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
|
69
|
+
self.fn_hash = str(ObjectHash().write_text(self.fn_hash_raw, iter=deep_hashes))
|
70
|
+
self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
|
61
71
|
self.is_async = inspect.iscoroutinefunction(wrapped)
|
62
|
-
self.storage =
|
72
|
+
self.storage = Storage(self)
|
73
|
+
self.cleanup = self.storage.cleanup
|
74
|
+
|
75
|
+
def __getattribute__(self, name: str) -> Any:
|
76
|
+
return object.__getattribute__(self, "_getattribute")(name)
|
77
|
+
|
78
|
+
def _getattribute(self, name: str) -> Any:
|
79
|
+
setattr(self, "_getattribute", super().__getattribute__)
|
80
|
+
self._lazyinit()
|
81
|
+
return self._getattribute(name)
|
82
|
+
|
83
|
+
def reinit(self, recursive=False):
|
84
|
+
pointfns = list(iterate_checkpoint_fns(self)) if recursive else [self]
|
85
|
+
for pointfn in pointfns:
|
86
|
+
pointfn._set_ident(True)
|
87
|
+
for pointfn in pointfns:
|
88
|
+
pointfn._lazyinit()
|
63
89
|
|
64
90
|
def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
return f"{self.fn_id}/{call_hash}"
|
69
|
-
checkpoint_id = self.checkpointer.path(*args, **kw)
|
70
|
-
if not isinstance(checkpoint_id, str):
|
71
|
-
raise CheckpointError(f"path function must return a string, got {type(checkpoint_id)}")
|
72
|
-
return checkpoint_id
|
91
|
+
hash_params = [self.checkpointer.hash_by(*args, **kw)] if self.checkpointer.hash_by else (args, kw)
|
92
|
+
call_hash = ObjectHash(self.fn_hash, *hash_params, digest_size=16)
|
93
|
+
return f"{self.fn_subdir}/{call_hash}"
|
73
94
|
|
74
95
|
async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
|
75
96
|
checkpoint_id = self.get_checkpoint_id(args, kw)
|
76
97
|
checkpoint_path = self.checkpointer.root_path / checkpoint_id
|
77
|
-
|
98
|
+
verbose = self.checkpointer.verbosity > 0
|
78
99
|
refresh = rerun \
|
79
100
|
or not self.storage.exists(checkpoint_path) \
|
80
101
|
or (self.checkpointer.should_expire and self.checkpointer.should_expire(self.storage.checkpoint_date(checkpoint_path)))
|
81
102
|
|
82
103
|
if refresh:
|
83
|
-
print_checkpoint(
|
104
|
+
print_checkpoint(verbose, "MEMORIZING", checkpoint_id, "blue")
|
84
105
|
data = self.fn(*args, **kw)
|
85
106
|
if inspect.iscoroutine(data):
|
86
107
|
data = await data
|
@@ -89,12 +110,12 @@ class CheckpointFn(Generic[Fn]):
|
|
89
110
|
|
90
111
|
try:
|
91
112
|
data = self.storage.load(checkpoint_path)
|
92
|
-
print_checkpoint(
|
113
|
+
print_checkpoint(verbose, "REMEMBERED", checkpoint_id, "green")
|
93
114
|
return data
|
94
115
|
except (EOFError, FileNotFoundError):
|
95
|
-
|
96
|
-
|
97
|
-
|
116
|
+
pass
|
117
|
+
print_checkpoint(verbose, "CORRUPTED", checkpoint_id, "yellow")
|
118
|
+
return await self._store_on_demand(args, kw, True)
|
98
119
|
|
99
120
|
def _call(self, args: tuple, kw: dict, rerun=False):
|
100
121
|
if not self.checkpointer.when:
|
@@ -107,8 +128,8 @@ class CheckpointFn(Generic[Fn]):
|
|
107
128
|
try:
|
108
129
|
val = self.storage.load(checkpoint_path)
|
109
130
|
return resolved_awaitable(val) if self.is_async else val
|
110
|
-
except:
|
111
|
-
raise CheckpointError("Could not load checkpoint")
|
131
|
+
except Exception as ex:
|
132
|
+
raise CheckpointError("Could not load checkpoint") from ex
|
112
133
|
|
113
134
|
def exists(self, *args: tuple, **kw: dict) -> bool:
|
114
135
|
return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
|
@@ -116,3 +137,15 @@ class CheckpointFn(Generic[Fn]):
|
|
116
137
|
__call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
|
117
138
|
rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
|
118
139
|
get: Fn = cast(Fn, lambda self, *args, **kw: self._get(args, kw))
|
140
|
+
|
141
|
+
def __repr__(self) -> str:
|
142
|
+
return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
|
143
|
+
|
144
|
+
def iterate_checkpoint_fns(pointfn: CheckpointFn, visited: set[CheckpointFn] | None = None) -> Iterable[CheckpointFn]:
    """Yield ``pointfn`` and, depth-first, every ``CheckpointFn`` reachable via ``depends``.

    Each object is yielded at most once, which also protects against
    dependency cycles.

    Args:
        pointfn: the root ``CheckpointFn``.
        visited: objects already yielded; shared across the recursion. ``None``
            (the default) starts a fresh traversal. A set passed in explicitly,
            even an empty one, is used and updated in place.
    """
    # A None sentinel replaces the previous shared mutable default ``set()``.
    if visited is None:
        visited = set()
    if pointfn in visited:
        return
    # Yield before reading ``pointfn.depends``. Callers such as ``_lazyinit``
    # call ``_set_ident()`` on each yielded object, and that is what sets ``depends``.
    yield pointfn
    visited.add(pointfn)
    for depend in pointfn.depends:
        if isinstance(depend, CheckpointFn):
            yield from iterate_checkpoint_fns(depend, visited)
|
@@ -0,0 +1,94 @@
|
|
1
|
+
import dis
|
2
|
+
import inspect
|
3
|
+
from collections.abc import Callable
|
4
|
+
from itertools import takewhile
|
5
|
+
from pathlib import Path
|
6
|
+
from types import CodeType, FunctionType, MethodType
|
7
|
+
from typing import Any, Generator, Type, TypeGuard
|
8
|
+
from .object_hash import ObjectHash
|
9
|
+
from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, transpose, unwrap_fn
|
10
|
+
|
11
|
+
# Working directory captured once at import time. It is resolved so that it is
# comparable with the resolve()d source paths that is_user_fn checks against
# it; otherwise a symlinked or junctioned working directory would never match.
cwd = Path.cwd().resolve()
|
12
|
+
|
13
|
+
def is_class(obj) -> TypeGuard[Type]:
|
14
|
+
# isinstance works too, but needlessly triggers _lazyinit()
|
15
|
+
return issubclass(type(obj), type)
|
16
|
+
|
17
|
+
def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
|
18
|
+
attr_path: tuple[str, ...] = ()
|
19
|
+
scope_obj = None
|
20
|
+
classvars: dict[str, dict[str, Type]] = {}
|
21
|
+
for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
|
22
|
+
if instr.opname in scope_vars and not attr_path:
|
23
|
+
attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
|
24
|
+
attr_path = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
|
25
|
+
elif instr.opname == "CALL":
|
26
|
+
obj = scope_vars.get_at(attr_path)
|
27
|
+
attr_path = ()
|
28
|
+
if is_class(obj):
|
29
|
+
scope_obj = obj
|
30
|
+
elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
|
31
|
+
load_key = instr.opname.replace("STORE", "LOAD")
|
32
|
+
classvars.setdefault(load_key, {})[instr.argval] = scope_obj
|
33
|
+
scope_obj = None
|
34
|
+
return classvars
|
35
|
+
|
36
|
+
def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Generator[tuple[tuple[str, ...], Any], None, None]:
    """Yield ``(lookup_path, value)`` for each resolvable outside-name read in ``code``.

    ``lookup_path`` is ``(load_opname, name, *attribute_names)``. It is built
    from a load instruction whose opname is a key of ``scope_vars``, followed
    by any consecutive ``LOAD_ATTR`` instructions. Paths that resolve to
    ``None`` are skipped.

    Variables assigned from class calls (see ``extract_classvars``) are first
    added to ``scope_vars``. Nested code objects found in ``code.co_consts``
    are then scanned recursively with the augmented scope.
    """
    class_locals = extract_classvars(code, scope_vars)
    overrides = {load_op: scope_vars[load_op].set(names) for load_op, names in class_locals.items()}
    scope_vars = scope_vars.set(overrides)
    for instr, rest in iterate_and_upcoming(dis.get_instructions(code)):
        if instr.opname not in scope_vars:
            continue
        attr_instrs = takewhile(lambda nxt: nxt.opname == "LOAD_ATTR", rest)
        lookup_path: tuple[str, ...] = (instr.opname, instr.argval, *(str(a.argval) for a in attr_instrs))
        resolved = scope_vars.get_at(lookup_path)
        if resolved is not None:
            yield lookup_path, resolved
    # Recurse into nested code objects (e.g. inner functions and lambdas).
    nested_codes = (const for const in code.co_consts if isinstance(const, CodeType))
    for nested in nested_codes:
        yield from extract_scope_values(nested, scope_vars)
|
49
|
+
|
50
|
+
def get_self_value(fn: Callable) -> type | object | None:
|
51
|
+
if isinstance(fn, MethodType):
|
52
|
+
return fn.__self__
|
53
|
+
parts = tuple(fn.__qualname__.split(".")[:-1])
|
54
|
+
cls = parts and AttrDict(fn.__globals__).get_at(parts)
|
55
|
+
if is_class(cls):
|
56
|
+
return cls
|
57
|
+
|
58
|
+
def get_fn_captured_vals(fn: Callable) -> list[Any]:
    """Return the values ``fn`` reads from outside its own locals.

    Names are resolved through three scopes:
    - ``LOAD_FAST``: only ``self``, taken from ``get_self_value(fn)``.
    - ``LOAD_DEREF``: from ``get_cell_contents(fn)``, presumably the closure cells.
    - ``LOAD_GLOBAL``: from ``fn.__globals__``.

    Attribute chains and nested code objects are followed. One value is kept
    per distinct lookup path; a later value for the same path overwrites an
    earlier one.
    """
    self_value = get_self_value(fn)
    scope_vars = AttrDict({
        # Compare with None, not truthiness. A bound instance whose
        # __bool__/__len__ is falsy (e.g. an empty container subclass) is still
        # a real ``self`` whose attribute reads must be tracked.
        "LOAD_FAST": AttrDict({"self": self_value} if self_value is not None else {}),
        "LOAD_DEREF": AttrDict(get_cell_contents(fn)),
        "LOAD_GLOBAL": AttrDict(fn.__globals__),
    })
    vals = dict(extract_scope_values(fn.__code__, scope_vars))
    return list(vals.values())
|
67
|
+
|
68
|
+
def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
|
69
|
+
if not isinstance(candidate_fn, (FunctionType, MethodType)):
|
70
|
+
return False
|
71
|
+
fn_path = Path(inspect.getfile(candidate_fn)).resolve()
|
72
|
+
return cwd in fn_path.parents and ".venv" not in fn_path.parts
|
73
|
+
|
74
|
+
def get_depend_fns(fn: Callable, capture: bool, captured_vals_by_fn: dict[Callable, list[Any]] | None = None) -> dict[Callable, list[Any]]:
    """Collect ``fn`` and, recursively, the user functions it references.

    Args:
        fn: the function to analyse.
        capture: when truthy, each function's non-callable captured values are
            recorded; otherwise every recorded list is empty.
        captured_vals_by_fn: accumulator shared across the recursion. ``None``
            (the default) starts a fresh dict. A dict passed in explicitly,
            even an empty one, is filled in place.

    Returns:
        A mapping from each discovered function to its recorded captured
        values. ``CheckpointFn`` dependencies are recorded with an empty list
        and not descended into. Other callables are descended into only when
        ``is_user_fn`` accepts them.
    """
    from .checkpoint import CheckpointFn  # local import: checkpoint.py imports this module
    # A None sentinel replaces the previous shared mutable default ``{}``.
    if captured_vals_by_fn is None:
        captured_vals_by_fn = {}
    captured_vals = get_fn_captured_vals(fn)
    # Conditional instead of ``list * capture``, which duplicated entries for
    # truthy ints other than 1.
    captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)] if capture else []
    child_fns = (unwrap_fn(val, checkpoint_fn=True) for val in captured_vals if callable(val))
    for child_fn in child_fns:
        if isinstance(child_fn, CheckpointFn):
            captured_vals_by_fn[child_fn] = []
        elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
            get_depend_fns(child_fn, capture, captured_vals_by_fn)
    return captured_vals_by_fn
|
86
|
+
|
87
|
+
def get_fn_ident(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
    """Compute the identity hash of ``fn`` and the functions it depends on.

    Returns:
        A tuple ``(fn_hash, depends)``:
        - ``depends``: every dependency found by ``get_depend_fns``, including
          ``fn`` itself. Bound methods are replaced by their ``__func__``, and
          the result is passed through ``distinct``.
        - ``fn_hash``: an ``ObjectHash`` over ``fn``, the dependencies that are
          not ``CheckpointFn`` objects, and all recorded captured values.
          Errors while hashing captured values are tolerated.
    """
    from .checkpoint import CheckpointFn  # local import: checkpoint.py imports this module
    vals_by_fn = get_depend_fns(fn, capture)
    depend_fns, depend_captured_vals = transpose(vals_by_fn.items(), 2)
    depends = distinct(
        dep.__func__ if isinstance(dep, MethodType) else dep
        for dep in depend_fns
    )
    # CheckpointFn dependencies carry their own hashes, so they are excluded here.
    plain_depends = [dep for dep in depends if not isinstance(dep, CheckpointFn)]
    hasher = ObjectHash(fn, plain_depends).update(depend_captured_vals, tolerate_errors=True)
    return str(hasher), depends
|