checkpointer 2.0.2__tar.gz → 2.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpointer-2.0.2 → checkpointer-2.5.0}/LICENSE +1 -1
- {checkpointer-2.0.2 → checkpointer-2.5.0}/PKG-INFO +49 -21
- {checkpointer-2.0.2 → checkpointer-2.5.0}/README.md +45 -17
- checkpointer-2.5.0/checkpointer/__init__.py +20 -0
- {checkpointer-2.0.2 → checkpointer-2.5.0}/checkpointer/checkpoint.py +73 -30
- checkpointer-2.5.0/checkpointer/fn_ident.py +94 -0
- checkpointer-2.5.0/checkpointer/object_hash.py +186 -0
- checkpointer-2.5.0/checkpointer/storages/__init__.py +11 -0
- {checkpointer-2.0.2 → checkpointer-2.5.0}/checkpointer/storages/bcolz_storage.py +6 -7
- checkpointer-2.5.0/checkpointer/storages/memory_storage.py +39 -0
- checkpointer-2.5.0/checkpointer/storages/pickle_storage.py +45 -0
- checkpointer-2.0.2/checkpointer/types.py → checkpointer-2.5.0/checkpointer/storages/storage.py +9 -5
- checkpointer-2.5.0/checkpointer/test_checkpointer.py +170 -0
- checkpointer-2.5.0/checkpointer/utils.py +112 -0
- {checkpointer-2.0.2 → checkpointer-2.5.0}/pyproject.toml +17 -4
- checkpointer-2.5.0/uv.lock +529 -0
- checkpointer-2.0.2/checkpointer/__init__.py +0 -9
- checkpointer-2.0.2/checkpointer/function_body.py +0 -46
- checkpointer-2.0.2/checkpointer/storages/memory_storage.py +0 -25
- checkpointer-2.0.2/checkpointer/storages/pickle_storage.py +0 -31
- checkpointer-2.0.2/checkpointer/utils.py +0 -17
- checkpointer-2.0.2/uv.lock +0 -22
- {checkpointer-2.0.2 → checkpointer-2.5.0}/.gitignore +0 -0
- {checkpointer-2.0.2 → checkpointer-2.5.0}/checkpointer/print_checkpoint.py +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
Copyright
|
1
|
+
Copyright 2018-2025 Hampus Hallman
|
2
2
|
|
3
3
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
4
4
|
|
@@ -1,25 +1,25 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: checkpointer
|
3
|
-
Version: 2.0
|
3
|
+
Version: 2.5.0
|
4
4
|
Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
|
5
5
|
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
6
|
Author: Hampus Hallman
|
7
|
-
License: Copyright
|
7
|
+
License: Copyright 2018-2025 Hampus Hallman
|
8
8
|
|
9
9
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
10
10
|
|
11
11
|
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
12
12
|
|
13
13
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
14
|
+
License-File: LICENSE
|
14
15
|
Requires-Python: >=3.12
|
15
|
-
Requires-Dist: relib
|
16
16
|
Description-Content-Type: text/markdown
|
17
17
|
|
18
18
|
# checkpointer · [](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [](https://pypi.org/project/checkpointer/) [](https://pypi.org/project/checkpointer/)
|
19
19
|
|
20
20
|
`checkpointer` is a Python library for memoizing function results. It provides a decorator-based API with support for multiple storage backends. Use it for computationally expensive operations where caching can save time, or during development to avoid waiting for redundant computations.
|
21
21
|
|
22
|
-
Adding or removing `@checkpoint` doesn't change how your code works
|
22
|
+
Adding or removing `@checkpoint` doesn't change how your code works. You can apply it to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
|
23
23
|
|
24
24
|
### Key Features:
|
25
25
|
- 🗂️ **Multiple Storage Backends**: Built-in support for in-memory and pickle-based storage, or create your own.
|
@@ -27,6 +27,7 @@ Adding or removing `@checkpoint` doesn't change how your code works, and it can
|
|
27
27
|
- 🔄 **Async and Sync Compatibility**: Works with synchronous functions and any Python async runtime (e.g., `asyncio`, `Trio`, `Curio`).
|
28
28
|
- ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
|
29
29
|
- 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
|
30
|
+
- 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
|
30
31
|
|
31
32
|
---
|
32
33
|
|
@@ -59,8 +60,10 @@ result = expensive_function(4) # Loads from the cache
|
|
59
60
|
When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are hashed to create a unique identifier for each call. This identifier is used to store and retrieve cached results. If the same arguments are passed again, `checkpointer` loads the cached result instead of recomputing.
|
60
61
|
|
61
62
|
Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
|
63
|
+
|
62
64
|
1. **Its source code**: Changes to the function's code update its hash.
|
63
65
|
2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
|
66
|
+
3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
|
64
67
|
|
65
68
|
### Example: Cache Invalidation
|
66
69
|
|
@@ -105,7 +108,7 @@ Layer caches by stacking checkpoints:
|
|
105
108
|
@dev_checkpoint # Adds caching during development
|
106
109
|
def some_expensive_function():
|
107
110
|
print("Performing a time-consuming operation...")
|
108
|
-
return sum(i * i for i in range(10**
|
111
|
+
return sum(i * i for i in range(10**8))
|
109
112
|
```
|
110
113
|
|
111
114
|
- **In development**: Both `dev_checkpoint` and `memory` caches are active.
|
@@ -115,7 +118,17 @@ def some_expensive_function():
|
|
115
118
|
|
116
119
|
## Usage
|
117
120
|
|
121
|
+
### Basic Invocation and Caching
|
122
|
+
|
123
|
+
Call the decorated function as usual. On the first call, the result is computed and stored in the cache. Subsequent calls with the same arguments load the result from the cache:
|
124
|
+
|
125
|
+
```python
|
126
|
+
result = expensive_function(4) # Computes and stores the result
|
127
|
+
result = expensive_function(4) # Loads the result from the cache
|
128
|
+
```
|
129
|
+
|
118
130
|
### Force Recalculation
|
131
|
+
|
119
132
|
Force a recalculation and overwrite the stored checkpoint:
|
120
133
|
|
121
134
|
```python
|
@@ -123,6 +136,7 @@ result = expensive_function.rerun(4)
|
|
123
136
|
```
|
124
137
|
|
125
138
|
### Call the Original Function
|
139
|
+
|
126
140
|
Use `fn` to directly call the original, undecorated function:
|
127
141
|
|
128
142
|
```python
|
@@ -132,12 +146,25 @@ result = expensive_function.fn(4)
|
|
132
146
|
This is especially useful **inside recursive functions** to avoid redundant caching of intermediate steps while still caching the final result.
|
133
147
|
|
134
148
|
### Retrieve Stored Checkpoints
|
149
|
+
|
135
150
|
Access cached results without recalculating:
|
136
151
|
|
137
152
|
```python
|
138
153
|
stored_result = expensive_function.get(4)
|
139
154
|
```
|
140
155
|
|
156
|
+
### Refresh Function Hash
|
157
|
+
|
158
|
+
When using `capture=True`, changes to captured variables are included in the function's hash to determine cache invalidation. However, `checkpointer` computes this hash once, when the decorated function is first used, and does not update it during a running Python session—it is recalculated in a new session or when you explicitly refresh it.
|
159
|
+
|
160
|
+
Use the `reinit` method to manually refresh the function's hash within the same session:
|
161
|
+
|
162
|
+
```python
|
163
|
+
expensive_function.reinit()
|
164
|
+
```
|
165
|
+
|
166
|
+
This forces `checkpointer` to recalculate the hash of `expensive_function`, considering any changes to captured variables. It's useful when you've modified external variables that the function depends on and want the cache to reflect these changes immediately. Pass `recursive=True` to also refresh the hashes of any checkpointed functions it depends on.
|
167
|
+
|
141
168
|
---
|
142
169
|
|
143
170
|
## Storage Backends
|
@@ -154,11 +181,11 @@ You can specify a storage backend using either its name (`"pickle"` or `"memory"
|
|
154
181
|
```python
|
155
182
|
from checkpointer import checkpoint, PickleStorage, MemoryStorage
|
156
183
|
|
157
|
-
@checkpoint(format="pickle") #
|
184
|
+
@checkpoint(format="pickle") # Short for format=PickleStorage
|
158
185
|
def disk_cached(x: int) -> int:
|
159
186
|
return x ** 2
|
160
187
|
|
161
|
-
@checkpoint(format="memory") #
|
188
|
+
@checkpoint(format="memory") # Short for format=MemoryStorage
|
162
189
|
def memory_cached(x: int) -> int:
|
163
190
|
return x * 10
|
164
191
|
```
|
@@ -174,9 +201,9 @@ from checkpointer import checkpoint, Storage
|
|
174
201
|
from datetime import datetime
|
175
202
|
|
176
203
|
class CustomStorage(Storage):
|
204
|
+
def store(self, path, data): ... # Save the checkpoint data
|
177
205
|
def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
|
178
206
|
def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
|
179
|
-
def store(self, path, data): ... # Save the checkpoint data
|
180
207
|
def load(self, path): ... # Return the checkpoint data
|
181
208
|
def delete(self, path): ... # Delete the checkpoint
|
182
209
|
|
@@ -191,14 +218,15 @@ Using a custom backend lets you tailor storage to your application, whether it i
|
|
191
218
|
|
192
219
|
## Configuration Options ⚙️
|
193
220
|
|
194
|
-
| Option
|
195
|
-
|
196
|
-
| `
|
197
|
-
| `
|
198
|
-
| `
|
199
|
-
| `
|
200
|
-
| `
|
201
|
-
| `
|
221
|
+
| Option | Type | Default | Description |
|
222
|
+
|-----------------|-----------------------------------|----------------------|------------------------------------------------|
|
223
|
+
| `capture` | `bool` | `False` | Include captured variables in function hashes. |
|
224
|
+
| `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
|
225
|
+
| `root_path` | `Path`, `str`, or `None` | `~/.cache/checkpoints` | Root directory for storing checkpoints. |
|
226
|
+
| `when` | `bool` | `True` | Enable or disable checkpointing. |
|
227
|
+
| `verbosity` | `0` or `1` | `1` | Logging verbosity. |
|
228
|
+
| `path` | `Callable[..., str]` | `None` | Function that receives the call's arguments and returns the checkpoint path (a string) under `root_path`. |
|
229
|
+
| `should_expire` | `Callable[[datetime], bool]` | `None` | Receives the checkpoint's creation date; return `True` to treat it as expired and recompute. |
|
202
230
|
|
203
231
|
---
|
204
232
|
|
@@ -220,13 +248,13 @@ async def async_compute_sum(a: int, b: int) -> int:
|
|
220
248
|
|
221
249
|
async def main():
|
222
250
|
result1 = compute_square(5)
|
223
|
-
print(result1)
|
251
|
+
print(result1) # Outputs 25
|
224
252
|
|
225
253
|
result2 = await async_compute_sum(3, 7)
|
226
|
-
print(result2)
|
254
|
+
print(result2) # Outputs 10
|
227
255
|
|
228
|
-
result3 = async_compute_sum.get(3, 7)
|
229
|
-
print(result3)
|
256
|
+
result3 = await async_compute_sum.get(3, 7)
|
257
|
+
print(result3) # Outputs 10
|
230
258
|
|
231
259
|
asyncio.run(main())
|
232
260
|
```
|
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
`checkpointer` is a Python library for memoizing function results. It provides a decorator-based API with support for multiple storage backends. Use it for computationally expensive operations where caching can save time, or during development to avoid waiting for redundant computations.
|
4
4
|
|
5
|
-
Adding or removing `@checkpoint` doesn't change how your code works
|
5
|
+
Adding or removing `@checkpoint` doesn't change how your code works. You can apply it to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
|
6
6
|
|
7
7
|
### Key Features:
|
8
8
|
- 🗂️ **Multiple Storage Backends**: Built-in support for in-memory and pickle-based storage, or create your own.
|
@@ -10,6 +10,7 @@ Adding or removing `@checkpoint` doesn't change how your code works, and it can
|
|
10
10
|
- 🔄 **Async and Sync Compatibility**: Works with synchronous functions and any Python async runtime (e.g., `asyncio`, `Trio`, `Curio`).
|
11
11
|
- ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
|
12
12
|
- 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
|
13
|
+
- 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
|
13
14
|
|
14
15
|
---
|
15
16
|
|
@@ -42,8 +43,10 @@ result = expensive_function(4) # Loads from the cache
|
|
42
43
|
When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are hashed to create a unique identifier for each call. This identifier is used to store and retrieve cached results. If the same arguments are passed again, `checkpointer` loads the cached result instead of recomputing.
|
43
44
|
|
44
45
|
Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
|
46
|
+
|
45
47
|
1. **Its source code**: Changes to the function's code update its hash.
|
46
48
|
2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
|
49
|
+
3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
|
47
50
|
|
48
51
|
### Example: Cache Invalidation
|
49
52
|
|
@@ -88,7 +91,7 @@ Layer caches by stacking checkpoints:
|
|
88
91
|
@dev_checkpoint # Adds caching during development
|
89
92
|
def some_expensive_function():
|
90
93
|
print("Performing a time-consuming operation...")
|
91
|
-
return sum(i * i for i in range(10**
|
94
|
+
return sum(i * i for i in range(10**8))
|
92
95
|
```
|
93
96
|
|
94
97
|
- **In development**: Both `dev_checkpoint` and `memory` caches are active.
|
@@ -98,7 +101,17 @@ def some_expensive_function():
|
|
98
101
|
|
99
102
|
## Usage
|
100
103
|
|
104
|
+
### Basic Invocation and Caching
|
105
|
+
|
106
|
+
Call the decorated function as usual. On the first call, the result is computed and stored in the cache. Subsequent calls with the same arguments load the result from the cache:
|
107
|
+
|
108
|
+
```python
|
109
|
+
result = expensive_function(4) # Computes and stores the result
|
110
|
+
result = expensive_function(4) # Loads the result from the cache
|
111
|
+
```
|
112
|
+
|
101
113
|
### Force Recalculation
|
114
|
+
|
102
115
|
Force a recalculation and overwrite the stored checkpoint:
|
103
116
|
|
104
117
|
```python
|
@@ -106,6 +119,7 @@ result = expensive_function.rerun(4)
|
|
106
119
|
```
|
107
120
|
|
108
121
|
### Call the Original Function
|
122
|
+
|
109
123
|
Use `fn` to directly call the original, undecorated function:
|
110
124
|
|
111
125
|
```python
|
@@ -115,12 +129,25 @@ result = expensive_function.fn(4)
|
|
115
129
|
This is especially useful **inside recursive functions** to avoid redundant caching of intermediate steps while still caching the final result.
|
116
130
|
|
117
131
|
### Retrieve Stored Checkpoints
|
132
|
+
|
118
133
|
Access cached results without recalculating:
|
119
134
|
|
120
135
|
```python
|
121
136
|
stored_result = expensive_function.get(4)
|
122
137
|
```
|
123
138
|
|
139
|
+
### Refresh Function Hash
|
140
|
+
|
141
|
+
When using `capture=True`, changes to captured variables are included in the function's hash to determine cache invalidation. However, `checkpointer` computes this hash once, when the decorated function is first used, and does not update it during a running Python session—it is recalculated in a new session or when you explicitly refresh it.
|
142
|
+
|
143
|
+
Use the `reinit` method to manually refresh the function's hash within the same session:
|
144
|
+
|
145
|
+
```python
|
146
|
+
expensive_function.reinit()
|
147
|
+
```
|
148
|
+
|
149
|
+
This forces `checkpointer` to recalculate the hash of `expensive_function`, considering any changes to captured variables. It's useful when you've modified external variables that the function depends on and want the cache to reflect these changes immediately. Pass `recursive=True` to also refresh the hashes of any checkpointed functions it depends on.
|
150
|
+
|
124
151
|
---
|
125
152
|
|
126
153
|
## Storage Backends
|
@@ -137,11 +164,11 @@ You can specify a storage backend using either its name (`"pickle"` or `"memory"
|
|
137
164
|
```python
|
138
165
|
from checkpointer import checkpoint, PickleStorage, MemoryStorage
|
139
166
|
|
140
|
-
@checkpoint(format="pickle") #
|
167
|
+
@checkpoint(format="pickle") # Short for format=PickleStorage
|
141
168
|
def disk_cached(x: int) -> int:
|
142
169
|
return x ** 2
|
143
170
|
|
144
|
-
@checkpoint(format="memory") #
|
171
|
+
@checkpoint(format="memory") # Short for format=MemoryStorage
|
145
172
|
def memory_cached(x: int) -> int:
|
146
173
|
return x * 10
|
147
174
|
```
|
@@ -157,9 +184,9 @@ from checkpointer import checkpoint, Storage
|
|
157
184
|
from datetime import datetime
|
158
185
|
|
159
186
|
class CustomStorage(Storage):
|
187
|
+
def store(self, path, data): ... # Save the checkpoint data
|
160
188
|
def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
|
161
189
|
def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
|
162
|
-
def store(self, path, data): ... # Save the checkpoint data
|
163
190
|
def load(self, path): ... # Return the checkpoint data
|
164
191
|
def delete(self, path): ... # Delete the checkpoint
|
165
192
|
|
@@ -174,14 +201,15 @@ Using a custom backend lets you tailor storage to your application, whether it i
|
|
174
201
|
|
175
202
|
## Configuration Options ⚙️
|
176
203
|
|
177
|
-
| Option
|
178
|
-
|
179
|
-
| `
|
180
|
-
| `
|
181
|
-
| `
|
182
|
-
| `
|
183
|
-
| `
|
184
|
-
| `
|
204
|
+
| Option | Type | Default | Description |
|
205
|
+
|-----------------|-----------------------------------|----------------------|------------------------------------------------|
|
206
|
+
| `capture` | `bool` | `False` | Include captured variables in function hashes. |
|
207
|
+
| `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
|
208
|
+
| `root_path` | `Path`, `str`, or `None` | `~/.cache/checkpoints` | Root directory for storing checkpoints. |
|
209
|
+
| `when` | `bool` | `True` | Enable or disable checkpointing. |
|
210
|
+
| `verbosity` | `0` or `1` | `1` | Logging verbosity. |
|
211
|
+
| `path` | `Callable[..., str]` | `None` | Function that receives the call's arguments and returns the checkpoint path (a string) under `root_path`. |
|
212
|
+
| `should_expire` | `Callable[[datetime], bool]` | `None` | Receives the checkpoint's creation date; return `True` to treat it as expired and recompute. |
|
185
213
|
|
186
214
|
---
|
187
215
|
|
@@ -203,13 +231,13 @@ async def async_compute_sum(a: int, b: int) -> int:
|
|
203
231
|
|
204
232
|
async def main():
|
205
233
|
result1 = compute_square(5)
|
206
|
-
print(result1)
|
234
|
+
print(result1) # Outputs 25
|
207
235
|
|
208
236
|
result2 = await async_compute_sum(3, 7)
|
209
|
-
print(result2)
|
237
|
+
print(result2) # Outputs 10
|
210
238
|
|
211
|
-
result3 = async_compute_sum.get(3, 7)
|
212
|
-
print(result3)
|
239
|
+
result3 = await async_compute_sum.get(3, 7)
|
240
|
+
print(result3) # Outputs 10
|
213
241
|
|
214
242
|
asyncio.run(main())
|
215
243
|
```
|
@@ -0,0 +1,20 @@
|
|
1
|
+
import gc
|
2
|
+
import tempfile
|
3
|
+
from typing import Callable
|
4
|
+
from .checkpoint import Checkpointer, CheckpointError, CheckpointFn
|
5
|
+
from .object_hash import ObjectHash
|
6
|
+
from .storages import MemoryStorage, PickleStorage, Storage
|
7
|
+
|
8
|
+
create_checkpointer = Checkpointer
|
9
|
+
checkpoint = Checkpointer()
|
10
|
+
capture_checkpoint = Checkpointer(capture=True)
|
11
|
+
memory_checkpoint = Checkpointer(format="memory", verbosity=0)
|
12
|
+
tmp_checkpoint = Checkpointer(root_path=tempfile.gettempdir() + "/checkpoints")
|
13
|
+
|
14
|
+
def cleanup_all(invalidated=True, expired=True):
|
15
|
+
for obj in gc.get_objects():
|
16
|
+
if isinstance(obj, CheckpointFn):
|
17
|
+
obj.cleanup(invalidated=invalidated, expired=expired)
|
18
|
+
|
19
|
+
def get_function_hash(fn: Callable, capture=False) -> str:
|
20
|
+
return CheckpointFn(Checkpointer(capture=capture), fn).fn_hash
|
@@ -1,22 +1,19 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import inspect
|
3
|
-
import
|
4
|
-
from typing import Generic, TypeVar, Type, TypedDict, Callable, Unpack, Literal, Any, cast, overload
|
5
|
-
from pathlib import Path
|
3
|
+
import re
|
6
4
|
from datetime import datetime
|
7
5
|
from functools import update_wrapper
|
8
|
-
from
|
9
|
-
from
|
10
|
-
from .
|
11
|
-
from .
|
12
|
-
from .storages.memory_storage import MemoryStorage
|
13
|
-
from .storages.bcolz_storage import BcolzStorage
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Any, Callable, Generic, Iterable, Literal, Type, TypedDict, TypeVar, Unpack, cast, overload
|
8
|
+
from .fn_ident import get_fn_ident
|
9
|
+
from .object_hash import ObjectHash
|
14
10
|
from .print_checkpoint import print_checkpoint
|
11
|
+
from .storages import STORAGE_MAP, Storage
|
12
|
+
from .utils import resolved_awaitable, sync_resolve_coroutine, unwrap_fn
|
15
13
|
|
16
14
|
Fn = TypeVar("Fn", bound=Callable)
|
17
15
|
|
18
16
|
DEFAULT_DIR = Path.home() / ".cache/checkpoints"
|
19
|
-
STORAGE_MAP: dict[str, Type[Storage]] = {"memory": MemoryStorage, "pickle": PickleStorage, "bcolz": BcolzStorage}
|
20
17
|
|
21
18
|
class CheckpointError(Exception):
|
22
19
|
pass
|
@@ -28,6 +25,7 @@ class CheckpointerOpts(TypedDict, total=False):
|
|
28
25
|
verbosity: Literal[0, 1]
|
29
26
|
path: Callable[..., str] | None
|
30
27
|
should_expire: Callable[[datetime], bool] | None
|
28
|
+
capture: bool
|
31
29
|
|
32
30
|
class Checkpointer:
|
33
31
|
def __init__(self, **opts: Unpack[CheckpointerOpts]):
|
@@ -37,6 +35,7 @@ class Checkpointer:
|
|
37
35
|
self.verbosity = opts.get("verbosity", 1)
|
38
36
|
self.path = opts.get("path")
|
39
37
|
self.should_expire = opts.get("should_expire")
|
38
|
+
self.capture = opts.get("capture", False)
|
40
39
|
|
41
40
|
@overload
|
42
41
|
def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CheckpointFn[Fn]: ...
|
@@ -51,20 +50,47 @@ class Checkpointer:
|
|
51
50
|
|
52
51
|
class CheckpointFn(Generic[Fn]):
|
53
52
|
def __init__(self, checkpointer: Checkpointer, fn: Fn):
|
54
|
-
wrapped = unwrap_fn(fn)
|
55
|
-
file_name = Path(wrapped.__code__.co_filename).name
|
56
|
-
update_wrapper(cast(Callable, self), wrapped)
|
57
|
-
storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
|
58
53
|
self.checkpointer = checkpointer
|
59
54
|
self.fn = fn
|
60
|
-
|
61
|
-
|
55
|
+
|
56
|
+
def _set_ident(self, force=False):
|
57
|
+
if not hasattr(self, "fn_hash_raw") or force:
|
58
|
+
self.fn_hash_raw, self.depends = get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
|
59
|
+
return self
|
60
|
+
|
61
|
+
def _lazyinit(self):
|
62
|
+
wrapped = unwrap_fn(self.fn)
|
63
|
+
fn_file = Path(wrapped.__code__.co_filename).name
|
64
|
+
fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
|
65
|
+
update_wrapper(cast(Callable, self), wrapped)
|
66
|
+
store_format = self.checkpointer.format
|
67
|
+
Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
|
68
|
+
deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
|
69
|
+
self.fn_hash = str(ObjectHash().update_hash(self.fn_hash_raw, iter=deep_hashes))
|
70
|
+
self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
|
62
71
|
self.is_async = inspect.iscoroutinefunction(wrapped)
|
63
|
-
self.storage =
|
72
|
+
self.storage = Storage(self)
|
73
|
+
self.cleanup = self.storage.cleanup
|
74
|
+
|
75
|
+
def __getattribute__(self, name: str) -> Any:
|
76
|
+
return object.__getattribute__(self, "_getattribute")(name)
|
77
|
+
|
78
|
+
def _getattribute(self, name: str) -> Any:
|
79
|
+
setattr(self, "_getattribute", super().__getattribute__)
|
80
|
+
self._lazyinit()
|
81
|
+
return self._getattribute(name)
|
82
|
+
|
83
|
+
def reinit(self, recursive=False):
|
84
|
+
pointfns = list(iterate_checkpoint_fns(self)) if recursive else [self]
|
85
|
+
for pointfn in pointfns:
|
86
|
+
pointfn._set_ident(True)
|
87
|
+
for pointfn in pointfns:
|
88
|
+
pointfn._lazyinit()
|
64
89
|
|
65
90
|
def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
|
66
91
|
if not callable(self.checkpointer.path):
|
67
|
-
|
92
|
+
call_hash = ObjectHash(self.fn_hash, args, kw, digest_size=16)
|
93
|
+
return f"{self.fn_subdir}/{call_hash}"
|
68
94
|
checkpoint_id = self.checkpointer.path(*args, **kw)
|
69
95
|
if not isinstance(checkpoint_id, str):
|
70
96
|
raise CheckpointError(f"path function must return a string, got {type(checkpoint_id)}")
|
@@ -73,13 +99,13 @@ class CheckpointFn(Generic[Fn]):
|
|
73
99
|
async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
|
74
100
|
checkpoint_id = self.get_checkpoint_id(args, kw)
|
75
101
|
checkpoint_path = self.checkpointer.root_path / checkpoint_id
|
76
|
-
|
102
|
+
verbose = self.checkpointer.verbosity > 0
|
77
103
|
refresh = rerun \
|
78
104
|
or not self.storage.exists(checkpoint_path) \
|
79
105
|
or (self.checkpointer.should_expire and self.checkpointer.should_expire(self.storage.checkpoint_date(checkpoint_path)))
|
80
106
|
|
81
107
|
if refresh:
|
82
|
-
print_checkpoint(
|
108
|
+
print_checkpoint(verbose, "MEMORIZING", checkpoint_id, "blue")
|
83
109
|
data = self.fn(*args, **kw)
|
84
110
|
if inspect.iscoroutine(data):
|
85
111
|
data = await data
|
@@ -88,12 +114,12 @@ class CheckpointFn(Generic[Fn]):
|
|
88
114
|
|
89
115
|
try:
|
90
116
|
data = self.storage.load(checkpoint_path)
|
91
|
-
print_checkpoint(
|
117
|
+
print_checkpoint(verbose, "REMEMBERED", checkpoint_id, "green")
|
92
118
|
return data
|
93
119
|
except (EOFError, FileNotFoundError):
|
94
|
-
|
95
|
-
|
96
|
-
|
120
|
+
pass
|
121
|
+
print_checkpoint(verbose, "CORRUPTED", checkpoint_id, "yellow")
|
122
|
+
return await self._store_on_demand(args, kw, True)
|
97
123
|
|
98
124
|
def _call(self, args: tuple, kw: dict, rerun=False):
|
99
125
|
if not self.checkpointer.when:
|
@@ -101,12 +127,29 @@ class CheckpointFn(Generic[Fn]):
|
|
101
127
|
coroutine = self._store_on_demand(args, kw, rerun)
|
102
128
|
return coroutine if self.is_async else sync_resolve_coroutine(coroutine)
|
103
129
|
|
130
|
+
def _get(self, args, kw) -> Any:
|
131
|
+
checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
|
132
|
+
try:
|
133
|
+
val = self.storage.load(checkpoint_path)
|
134
|
+
return resolved_awaitable(val) if self.is_async else val
|
135
|
+
except Exception as ex:
|
136
|
+
raise CheckpointError("Could not load checkpoint") from ex
|
137
|
+
|
138
|
+
def exists(self, *args: tuple, **kw: dict) -> bool:
|
139
|
+
return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
|
140
|
+
|
104
141
|
__call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
|
105
142
|
rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
|
143
|
+
get: Fn = cast(Fn, lambda self, *args, **kw: self._get(args, kw))
|
106
144
|
|
107
|
-
def
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
145
|
+
def __repr__(self) -> str:
|
146
|
+
return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
|
147
|
+
|
148
|
+
def iterate_checkpoint_fns(pointfn: CheckpointFn, visited: set[CheckpointFn] = set()) -> Iterable[CheckpointFn]:
|
149
|
+
visited = visited or set()
|
150
|
+
if pointfn not in visited:
|
151
|
+
yield pointfn
|
152
|
+
visited.add(pointfn)
|
153
|
+
for depend in pointfn.depends:
|
154
|
+
if isinstance(depend, CheckpointFn):
|
155
|
+
yield from iterate_checkpoint_fns(depend, visited)
|
@@ -0,0 +1,94 @@
|
|
1
|
+
import dis
|
2
|
+
import inspect
|
3
|
+
from collections.abc import Callable
|
4
|
+
from itertools import takewhile
|
5
|
+
from pathlib import Path
|
6
|
+
from types import CodeType, FunctionType, MethodType
|
7
|
+
from typing import Any, Generator, Type, TypeGuard
|
8
|
+
from .object_hash import ObjectHash
|
9
|
+
from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, transpose, unwrap_fn
|
10
|
+
|
11
|
+
cwd = Path.cwd()
|
12
|
+
|
13
|
+
def is_class(obj) -> TypeGuard[Type]:
|
14
|
+
# isinstance works too, but needlessly triggers __getattribute__
|
15
|
+
return issubclass(type(obj), type)
|
16
|
+
|
17
|
+
def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
|
18
|
+
attr_path: tuple[str, ...] = ()
|
19
|
+
scope_obj = None
|
20
|
+
classvars: dict[str, dict[str, Type]] = {}
|
21
|
+
for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
|
22
|
+
if instr.opname in scope_vars and not attr_path:
|
23
|
+
attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
|
24
|
+
attr_path = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
|
25
|
+
elif instr.opname == "CALL":
|
26
|
+
obj = scope_vars.get_at(attr_path)
|
27
|
+
attr_path = ()
|
28
|
+
if is_class(obj):
|
29
|
+
scope_obj = obj
|
30
|
+
elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
|
31
|
+
load_key = instr.opname.replace("STORE", "LOAD")
|
32
|
+
classvars.setdefault(load_key, {})[instr.argval] = scope_obj
|
33
|
+
scope_obj = None
|
34
|
+
return classvars
|
35
|
+
|
36
|
+
def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Generator[tuple[tuple[str, ...], Any], None, None]:
|
37
|
+
classvars = extract_classvars(code, scope_vars)
|
38
|
+
scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
|
39
|
+
for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
|
40
|
+
if instr.opname in scope_vars:
|
41
|
+
attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
|
42
|
+
attr_path: tuple[str, ...] = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
|
43
|
+
val = scope_vars.get_at(attr_path)
|
44
|
+
if val is not None:
|
45
|
+
yield attr_path, val
|
46
|
+
for const in code.co_consts:
|
47
|
+
if isinstance(const, CodeType):
|
48
|
+
yield from extract_scope_values(const, scope_vars)
|
49
|
+
|
50
|
+
def get_self_value(fn: Callable) -> type | object | None:
|
51
|
+
if isinstance(fn, MethodType):
|
52
|
+
return fn.__self__
|
53
|
+
parts = tuple(fn.__qualname__.split(".")[:-1])
|
54
|
+
cls = parts and AttrDict(fn.__globals__).get_at(parts)
|
55
|
+
if is_class(cls):
|
56
|
+
return cls
|
57
|
+
|
58
|
+
def get_fn_captured_vals(fn: Callable) -> list[Any]:
|
59
|
+
self_value = get_self_value(fn)
|
60
|
+
scope_vars = AttrDict({
|
61
|
+
"LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
|
62
|
+
"LOAD_DEREF": AttrDict(get_cell_contents(fn)),
|
63
|
+
"LOAD_GLOBAL": AttrDict(fn.__globals__),
|
64
|
+
})
|
65
|
+
vals = dict(extract_scope_values(fn.__code__, scope_vars))
|
66
|
+
return list(vals.values())
|
67
|
+
|
68
|
+
def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
|
69
|
+
if not isinstance(candidate_fn, (FunctionType, MethodType)):
|
70
|
+
return False
|
71
|
+
fn_path = Path(inspect.getfile(candidate_fn)).resolve()
|
72
|
+
return cwd in fn_path.parents and ".venv" not in fn_path.parts
|
73
|
+
|
74
|
+
def get_depend_fns(fn: Callable, capture: bool, captured_vals_by_fn: dict[Callable, list[Any]] = {}) -> dict[Callable, list[Any]]:
|
75
|
+
from .checkpoint import CheckpointFn
|
76
|
+
captured_vals_by_fn = captured_vals_by_fn or {}
|
77
|
+
captured_vals = get_fn_captured_vals(fn)
|
78
|
+
captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)] * capture
|
79
|
+
child_fns = (unwrap_fn(val, checkpoint_fn=True) for val in captured_vals if callable(val))
|
80
|
+
for child_fn in child_fns:
|
81
|
+
if isinstance(child_fn, CheckpointFn):
|
82
|
+
captured_vals_by_fn[child_fn] = []
|
83
|
+
elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
|
84
|
+
get_depend_fns(child_fn, capture, captured_vals_by_fn)
|
85
|
+
return captured_vals_by_fn
|
86
|
+
|
87
|
+
def get_fn_ident(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
|
88
|
+
from .checkpoint import CheckpointFn
|
89
|
+
captured_vals_by_fn = get_depend_fns(fn, capture)
|
90
|
+
depends, depend_captured_vals = transpose(captured_vals_by_fn.items(), 2)
|
91
|
+
depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
|
92
|
+
unwrapped_depends = [fn for fn in depends if not isinstance(fn, CheckpointFn)]
|
93
|
+
fn_hash = str(ObjectHash(fn, unwrapped_depends).update(depend_captured_vals, tolerate_errors=True))
|
94
|
+
return fn_hash, depends
|