checkpointer-2.0.2.tar.gz → checkpointer-2.5.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. {checkpointer-2.0.2 → checkpointer-2.5.0}/LICENSE +1 -1
  2. {checkpointer-2.0.2 → checkpointer-2.5.0}/PKG-INFO +49 -21
  3. {checkpointer-2.0.2 → checkpointer-2.5.0}/README.md +45 -17
  4. checkpointer-2.5.0/checkpointer/__init__.py +20 -0
  5. {checkpointer-2.0.2 → checkpointer-2.5.0}/checkpointer/checkpoint.py +73 -30
  6. checkpointer-2.5.0/checkpointer/fn_ident.py +94 -0
  7. checkpointer-2.5.0/checkpointer/object_hash.py +186 -0
  8. checkpointer-2.5.0/checkpointer/storages/__init__.py +11 -0
  9. {checkpointer-2.0.2 → checkpointer-2.5.0}/checkpointer/storages/bcolz_storage.py +6 -7
  10. checkpointer-2.5.0/checkpointer/storages/memory_storage.py +39 -0
  11. checkpointer-2.5.0/checkpointer/storages/pickle_storage.py +45 -0
  12. checkpointer-2.0.2/checkpointer/types.py → checkpointer-2.5.0/checkpointer/storages/storage.py +9 -5
  13. checkpointer-2.5.0/checkpointer/test_checkpointer.py +170 -0
  14. checkpointer-2.5.0/checkpointer/utils.py +112 -0
  15. {checkpointer-2.0.2 → checkpointer-2.5.0}/pyproject.toml +17 -4
  16. checkpointer-2.5.0/uv.lock +529 -0
  17. checkpointer-2.0.2/checkpointer/__init__.py +0 -9
  18. checkpointer-2.0.2/checkpointer/function_body.py +0 -46
  19. checkpointer-2.0.2/checkpointer/storages/memory_storage.py +0 -25
  20. checkpointer-2.0.2/checkpointer/storages/pickle_storage.py +0 -31
  21. checkpointer-2.0.2/checkpointer/utils.py +0 -17
  22. checkpointer-2.0.2/uv.lock +0 -22
  23. {checkpointer-2.0.2 → checkpointer-2.5.0}/.gitignore +0 -0
  24. {checkpointer-2.0.2 → checkpointer-2.5.0}/checkpointer/print_checkpoint.py +0 -0
@@ -1,4 +1,4 @@
1
- Copyright 2024 Hampus Hallman
1
+ Copyright 2018-2025 Hampus Hallman
2
2
 
3
3
  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4
4
 
@@ -1,25 +1,25 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: checkpointer
3
- Version: 2.0.2
3
+ Version: 2.5.0
4
4
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
5
5
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
6
  Author: Hampus Hallman
7
- License: Copyright 2024 Hampus Hallman
7
+ License: Copyright 2018-2025 Hampus Hallman
8
8
 
9
9
  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
10
10
 
11
11
  The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
12
12
 
13
13
  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
14
+ License-File: LICENSE
14
15
  Requires-Python: >=3.12
15
- Requires-Dist: relib
16
16
  Description-Content-Type: text/markdown
17
17
 
18
18
  # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![Python 3.12](https://img.shields.io/badge/python-3.12-blue)](https://pypi.org/project/checkpointer/)
19
19
 
20
20
  `checkpointer` is a Python library for memoizing function results. It provides a decorator-based API with support for multiple storage backends. Use it for computationally expensive operations where caching can save time, or during development to avoid waiting for redundant computations.
21
21
 
22
- Adding or removing `@checkpoint` doesn't change how your code works, and it can be applied to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
22
+ Adding or removing `@checkpoint` doesn't change how your code works. You can apply it to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
23
23
 
24
24
  ### Key Features:
25
25
  - 🗂️ **Multiple Storage Backends**: Built-in support for in-memory and pickle-based storage, or create your own.
@@ -27,6 +27,7 @@ Adding or removing `@checkpoint` doesn't change how your code works, and it can
27
27
  - 🔄 **Async and Sync Compatibility**: Works with synchronous functions and any Python async runtime (e.g., `asyncio`, `Trio`, `Curio`).
28
28
  - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
29
29
  - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
30
+ - 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
30
31
 
31
32
  ---
32
33
 
@@ -59,8 +60,10 @@ result = expensive_function(4) # Loads from the cache
59
60
  When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are hashed to create a unique identifier for each call. This identifier is used to store and retrieve cached results. If the same arguments are passed again, `checkpointer` loads the cached result instead of recomputing.
60
61
 
61
62
  Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
63
+
62
64
  1. **Its source code**: Changes to the function's code update its hash.
63
65
  2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
66
+ 3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
64
67
 
65
68
  ### Example: Cache Invalidation
66
69
 
@@ -105,7 +108,7 @@ Layer caches by stacking checkpoints:
105
108
  @dev_checkpoint # Adds caching during development
106
109
  def some_expensive_function():
107
110
  print("Performing a time-consuming operation...")
108
- return sum(i * i for i in range(10**6))
111
+ return sum(i * i for i in range(10**8))
109
112
  ```
110
113
 
111
114
  - **In development**: Both `dev_checkpoint` and `memory` caches are active.
@@ -115,7 +118,17 @@ def some_expensive_function():
115
118
 
116
119
  ## Usage
117
120
 
121
+ ### Basic Invocation and Caching
122
+
123
+ Call the decorated function as usual. On the first call, the result is computed and stored in the cache. Subsequent calls with the same arguments load the result from the cache:
124
+
125
+ ```python
126
+ result = expensive_function(4) # Computes and stores the result
127
+ result = expensive_function(4) # Loads the result from the cache
128
+ ```
129
+
118
130
  ### Force Recalculation
131
+
119
132
  Force a recalculation and overwrite the stored checkpoint:
120
133
 
121
134
  ```python
@@ -123,6 +136,7 @@ result = expensive_function.rerun(4)
123
136
  ```
124
137
 
125
138
  ### Call the Original Function
139
+
126
140
  Use `fn` to directly call the original, undecorated function:
127
141
 
128
142
  ```python
@@ -132,12 +146,25 @@ result = expensive_function.fn(4)
132
146
  This is especially useful **inside recursive functions** to avoid redundant caching of intermediate steps while still caching the final result.
133
147
 
134
148
  ### Retrieve Stored Checkpoints
149
+
135
150
  Access cached results without recalculating:
136
151
 
137
152
  ```python
138
153
  stored_result = expensive_function.get(4)
139
154
  ```
140
155
 
156
+ ### Refresh Function Hash
157
+
158
+ When using `capture=True`, changes to captured variables are included in the function's hash to determine cache invalidation. However, `checkpointer` does not automatically update this hash during a running Python session—it recalculates between sessions or when you explicitly refresh it.
159
+
160
+ Use the `reinit` method to manually refresh the function's hash within the same session:
161
+
162
+ ```python
163
+ expensive_function.reinit()
164
+ ```
165
+
166
+ This forces `checkpointer` to recalculate the hash of `expensive_function`, considering any changes to captured variables. It's useful when you've modified external variables that the function depends on and want the cache to reflect these changes immediately.
167
+
141
168
  ---
142
169
 
143
170
  ## Storage Backends
@@ -154,11 +181,11 @@ You can specify a storage backend using either its name (`"pickle"` or `"memory"
154
181
  ```python
155
182
  from checkpointer import checkpoint, PickleStorage, MemoryStorage
156
183
 
157
- @checkpoint(format="pickle") # Equivalent to format=PickleStorage
184
+ @checkpoint(format="pickle") # Short for format=PickleStorage
158
185
  def disk_cached(x: int) -> int:
159
186
  return x ** 2
160
187
 
161
- @checkpoint(format="memory") # Equivalent to format=MemoryStorage
188
+ @checkpoint(format="memory") # Short for format=MemoryStorage
162
189
  def memory_cached(x: int) -> int:
163
190
  return x * 10
164
191
  ```
@@ -174,9 +201,9 @@ from checkpointer import checkpoint, Storage
174
201
  from datetime import datetime
175
202
 
176
203
  class CustomStorage(Storage):
204
+ def store(self, path, data): ... # Save the checkpoint data
177
205
  def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
178
206
  def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
179
- def store(self, path, data): ... # Save the checkpoint data
180
207
  def load(self, path): ... # Return the checkpoint data
181
208
  def delete(self, path): ... # Delete the checkpoint
182
209
 
@@ -191,14 +218,15 @@ Using a custom backend lets you tailor storage to your application, whether it i
191
218
 
192
219
  ## Configuration Options ⚙️
193
220
 
194
- | Option | Type | Default | Description |
195
- |----------------|-------------------------------------|-------------|---------------------------------------------|
196
- | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
197
- | `root_path` | `Path`, `str`, or `None` | User Cache | Root directory for storing checkpoints. |
198
- | `when` | `bool` | `True` | Enable or disable checkpointing. |
199
- | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
200
- | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
201
- | `should_expire`| `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
221
+ | Option | Type | Default | Description |
222
+ |-----------------|-----------------------------------|----------------------|------------------------------------------------|
223
+ | `capture` | `bool` | `False` | Include captured variables in function hashes. |
224
+ | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
225
+ | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
226
+ | `when` | `bool` | `True` | Enable or disable checkpointing. |
227
+ | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
228
+ | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
229
+ | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
202
230
 
203
231
  ---
204
232
 
@@ -220,13 +248,13 @@ async def async_compute_sum(a: int, b: int) -> int:
220
248
 
221
249
  async def main():
222
250
  result1 = compute_square(5)
223
- print(result1)
251
+ print(result1) # Outputs 25
224
252
 
225
253
  result2 = await async_compute_sum(3, 7)
226
- print(result2)
254
+ print(result2) # Outputs 10
227
255
 
228
- result3 = async_compute_sum.get(3, 7)
229
- print(result3)
256
+ result3 = await async_compute_sum.get(3, 7)
257
+ print(result3) # Outputs 10
230
258
 
231
259
  asyncio.run(main())
232
260
  ```
@@ -2,7 +2,7 @@
2
2
 
3
3
  `checkpointer` is a Python library for memoizing function results. It provides a decorator-based API with support for multiple storage backends. Use it for computationally expensive operations where caching can save time, or during development to avoid waiting for redundant computations.
4
4
 
5
- Adding or removing `@checkpoint` doesn't change how your code works, and it can be applied to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
5
+ Adding or removing `@checkpoint` doesn't change how your code works. You can apply it to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
6
6
 
7
7
  ### Key Features:
8
8
  - 🗂️ **Multiple Storage Backends**: Built-in support for in-memory and pickle-based storage, or create your own.
@@ -10,6 +10,7 @@ Adding or removing `@checkpoint` doesn't change how your code works, and it can
10
10
  - 🔄 **Async and Sync Compatibility**: Works with synchronous functions and any Python async runtime (e.g., `asyncio`, `Trio`, `Curio`).
11
11
  - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
12
12
  - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
13
+ - 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
13
14
 
14
15
  ---
15
16
 
@@ -42,8 +43,10 @@ result = expensive_function(4) # Loads from the cache
42
43
  When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are hashed to create a unique identifier for each call. This identifier is used to store and retrieve cached results. If the same arguments are passed again, `checkpointer` loads the cached result instead of recomputing.
43
44
 
44
45
  Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
46
+
45
47
  1. **Its source code**: Changes to the function's code update its hash.
46
48
  2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
49
+ 3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
47
50
 
48
51
  ### Example: Cache Invalidation
49
52
 
@@ -88,7 +91,7 @@ Layer caches by stacking checkpoints:
88
91
  @dev_checkpoint # Adds caching during development
89
92
  def some_expensive_function():
90
93
  print("Performing a time-consuming operation...")
91
- return sum(i * i for i in range(10**6))
94
+ return sum(i * i for i in range(10**8))
92
95
  ```
93
96
 
94
97
  - **In development**: Both `dev_checkpoint` and `memory` caches are active.
@@ -98,7 +101,17 @@ def some_expensive_function():
98
101
 
99
102
  ## Usage
100
103
 
104
+ ### Basic Invocation and Caching
105
+
106
+ Call the decorated function as usual. On the first call, the result is computed and stored in the cache. Subsequent calls with the same arguments load the result from the cache:
107
+
108
+ ```python
109
+ result = expensive_function(4) # Computes and stores the result
110
+ result = expensive_function(4) # Loads the result from the cache
111
+ ```
112
+
101
113
  ### Force Recalculation
114
+
102
115
  Force a recalculation and overwrite the stored checkpoint:
103
116
 
104
117
  ```python
@@ -106,6 +119,7 @@ result = expensive_function.rerun(4)
106
119
  ```
107
120
 
108
121
  ### Call the Original Function
122
+
109
123
  Use `fn` to directly call the original, undecorated function:
110
124
 
111
125
  ```python
@@ -115,12 +129,25 @@ result = expensive_function.fn(4)
115
129
  This is especially useful **inside recursive functions** to avoid redundant caching of intermediate steps while still caching the final result.
116
130
 
117
131
  ### Retrieve Stored Checkpoints
132
+
118
133
  Access cached results without recalculating:
119
134
 
120
135
  ```python
121
136
  stored_result = expensive_function.get(4)
122
137
  ```
123
138
 
139
+ ### Refresh Function Hash
140
+
141
+ When using `capture=True`, changes to captured variables are included in the function's hash to determine cache invalidation. However, `checkpointer` does not automatically update this hash during a running Python session—it recalculates between sessions or when you explicitly refresh it.
142
+
143
+ Use the `reinit` method to manually refresh the function's hash within the same session:
144
+
145
+ ```python
146
+ expensive_function.reinit()
147
+ ```
148
+
149
+ This forces `checkpointer` to recalculate the hash of `expensive_function`, considering any changes to captured variables. It's useful when you've modified external variables that the function depends on and want the cache to reflect these changes immediately.
150
+
124
151
  ---
125
152
 
126
153
  ## Storage Backends
@@ -137,11 +164,11 @@ You can specify a storage backend using either its name (`"pickle"` or `"memory"
137
164
  ```python
138
165
  from checkpointer import checkpoint, PickleStorage, MemoryStorage
139
166
 
140
- @checkpoint(format="pickle") # Equivalent to format=PickleStorage
167
+ @checkpoint(format="pickle") # Short for format=PickleStorage
141
168
  def disk_cached(x: int) -> int:
142
169
  return x ** 2
143
170
 
144
- @checkpoint(format="memory") # Equivalent to format=MemoryStorage
171
+ @checkpoint(format="memory") # Short for format=MemoryStorage
145
172
  def memory_cached(x: int) -> int:
146
173
  return x * 10
147
174
  ```
@@ -157,9 +184,9 @@ from checkpointer import checkpoint, Storage
157
184
  from datetime import datetime
158
185
 
159
186
  class CustomStorage(Storage):
187
+ def store(self, path, data): ... # Save the checkpoint data
160
188
  def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
161
189
  def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
162
- def store(self, path, data): ... # Save the checkpoint data
163
190
  def load(self, path): ... # Return the checkpoint data
164
191
  def delete(self, path): ... # Delete the checkpoint
165
192
 
@@ -174,14 +201,15 @@ Using a custom backend lets you tailor storage to your application, whether it i
174
201
 
175
202
  ## Configuration Options ⚙️
176
203
 
177
- | Option | Type | Default | Description |
178
- |----------------|-------------------------------------|-------------|---------------------------------------------|
179
- | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
180
- | `root_path` | `Path`, `str`, or `None` | User Cache | Root directory for storing checkpoints. |
181
- | `when` | `bool` | `True` | Enable or disable checkpointing. |
182
- | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
183
- | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
184
- | `should_expire`| `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
204
+ | Option | Type | Default | Description |
205
+ |-----------------|-----------------------------------|----------------------|------------------------------------------------|
206
+ | `capture` | `bool` | `False` | Include captured variables in function hashes. |
207
+ | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
208
+ | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
209
+ | `when` | `bool` | `True` | Enable or disable checkpointing. |
210
+ | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
211
+ | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
212
+ | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
185
213
 
186
214
  ---
187
215
 
@@ -203,13 +231,13 @@ async def async_compute_sum(a: int, b: int) -> int:
203
231
 
204
232
  async def main():
205
233
  result1 = compute_square(5)
206
- print(result1)
234
+ print(result1) # Outputs 25
207
235
 
208
236
  result2 = await async_compute_sum(3, 7)
209
- print(result2)
237
+ print(result2) # Outputs 10
210
238
 
211
- result3 = async_compute_sum.get(3, 7)
212
- print(result3)
239
+ result3 = await async_compute_sum.get(3, 7)
240
+ print(result3) # Outputs 10
213
241
 
214
242
  asyncio.run(main())
215
243
  ```
@@ -0,0 +1,20 @@
1
+ import gc
2
+ import tempfile
3
+ from typing import Callable
4
+ from .checkpoint import Checkpointer, CheckpointError, CheckpointFn
5
+ from .object_hash import ObjectHash
6
+ from .storages import MemoryStorage, PickleStorage, Storage
7
+
8
+ create_checkpointer = Checkpointer
9
+ checkpoint = Checkpointer()
10
+ capture_checkpoint = Checkpointer(capture=True)
11
+ memory_checkpoint = Checkpointer(format="memory", verbosity=0)
12
+ tmp_checkpoint = Checkpointer(root_path=tempfile.gettempdir() + "/checkpoints")
13
+
14
+ def cleanup_all(invalidated=True, expired=True):
15
+ for obj in gc.get_objects():
16
+ if isinstance(obj, CheckpointFn):
17
+ obj.cleanup(invalidated=invalidated, expired=expired)
18
+
19
+ def get_function_hash(fn: Callable, capture=False) -> str:
20
+ return CheckpointFn(Checkpointer(capture=capture), fn).fn_hash
@@ -1,22 +1,19 @@
1
1
  from __future__ import annotations
2
2
  import inspect
3
- import relib.hashing as hashing
4
- from typing import Generic, TypeVar, Type, TypedDict, Callable, Unpack, Literal, Any, cast, overload
5
- from pathlib import Path
3
+ import re
6
4
  from datetime import datetime
7
5
  from functools import update_wrapper
8
- from .types import Storage
9
- from .function_body import get_function_hash
10
- from .utils import unwrap_fn, sync_resolve_coroutine
11
- from .storages.pickle_storage import PickleStorage
12
- from .storages.memory_storage import MemoryStorage
13
- from .storages.bcolz_storage import BcolzStorage
6
+ from pathlib import Path
7
+ from typing import Any, Callable, Generic, Iterable, Literal, Type, TypedDict, TypeVar, Unpack, cast, overload
8
+ from .fn_ident import get_fn_ident
9
+ from .object_hash import ObjectHash
14
10
  from .print_checkpoint import print_checkpoint
11
+ from .storages import STORAGE_MAP, Storage
12
+ from .utils import resolved_awaitable, sync_resolve_coroutine, unwrap_fn
15
13
 
16
14
  Fn = TypeVar("Fn", bound=Callable)
17
15
 
18
16
  DEFAULT_DIR = Path.home() / ".cache/checkpoints"
19
- STORAGE_MAP: dict[str, Type[Storage]] = {"memory": MemoryStorage, "pickle": PickleStorage, "bcolz": BcolzStorage}
20
17
 
21
18
  class CheckpointError(Exception):
22
19
  pass
@@ -28,6 +25,7 @@ class CheckpointerOpts(TypedDict, total=False):
28
25
  verbosity: Literal[0, 1]
29
26
  path: Callable[..., str] | None
30
27
  should_expire: Callable[[datetime], bool] | None
28
+ capture: bool
31
29
 
32
30
  class Checkpointer:
33
31
  def __init__(self, **opts: Unpack[CheckpointerOpts]):
@@ -37,6 +35,7 @@ class Checkpointer:
37
35
  self.verbosity = opts.get("verbosity", 1)
38
36
  self.path = opts.get("path")
39
37
  self.should_expire = opts.get("should_expire")
38
+ self.capture = opts.get("capture", False)
40
39
 
41
40
  @overload
42
41
  def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CheckpointFn[Fn]: ...
@@ -51,20 +50,47 @@ class Checkpointer:
51
50
 
52
51
  class CheckpointFn(Generic[Fn]):
53
52
  def __init__(self, checkpointer: Checkpointer, fn: Fn):
54
- wrapped = unwrap_fn(fn)
55
- file_name = Path(wrapped.__code__.co_filename).name
56
- update_wrapper(cast(Callable, self), wrapped)
57
- storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
58
53
  self.checkpointer = checkpointer
59
54
  self.fn = fn
60
- self.fn_hash = get_function_hash(wrapped)
61
- self.fn_id = f"{file_name}/{wrapped.__name__}"
55
+
56
+ def _set_ident(self, force=False):
57
+ if not hasattr(self, "fn_hash_raw") or force:
58
+ self.fn_hash_raw, self.depends = get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
59
+ return self
60
+
61
+ def _lazyinit(self):
62
+ wrapped = unwrap_fn(self.fn)
63
+ fn_file = Path(wrapped.__code__.co_filename).name
64
+ fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
65
+ update_wrapper(cast(Callable, self), wrapped)
66
+ store_format = self.checkpointer.format
67
+ Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
68
+ deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
69
+ self.fn_hash = str(ObjectHash().update_hash(self.fn_hash_raw, iter=deep_hashes))
70
+ self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
62
71
  self.is_async = inspect.iscoroutinefunction(wrapped)
63
- self.storage = storage(checkpointer)
72
+ self.storage = Storage(self)
73
+ self.cleanup = self.storage.cleanup
74
+
75
+ def __getattribute__(self, name: str) -> Any:
76
+ return object.__getattribute__(self, "_getattribute")(name)
77
+
78
+ def _getattribute(self, name: str) -> Any:
79
+ setattr(self, "_getattribute", super().__getattribute__)
80
+ self._lazyinit()
81
+ return self._getattribute(name)
82
+
83
+ def reinit(self, recursive=False):
84
+ pointfns = list(iterate_checkpoint_fns(self)) if recursive else [self]
85
+ for pointfn in pointfns:
86
+ pointfn._set_ident(True)
87
+ for pointfn in pointfns:
88
+ pointfn._lazyinit()
64
89
 
65
90
  def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
66
91
  if not callable(self.checkpointer.path):
67
- return f"{self.fn_id}/{hashing.hash([self.fn_hash, args, kw or 0])}"
92
+ call_hash = ObjectHash(self.fn_hash, args, kw, digest_size=16)
93
+ return f"{self.fn_subdir}/{call_hash}"
68
94
  checkpoint_id = self.checkpointer.path(*args, **kw)
69
95
  if not isinstance(checkpoint_id, str):
70
96
  raise CheckpointError(f"path function must return a string, got {type(checkpoint_id)}")
@@ -73,13 +99,13 @@ class CheckpointFn(Generic[Fn]):
73
99
  async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
74
100
  checkpoint_id = self.get_checkpoint_id(args, kw)
75
101
  checkpoint_path = self.checkpointer.root_path / checkpoint_id
76
- should_log = self.checkpointer.verbosity > 0
102
+ verbose = self.checkpointer.verbosity > 0
77
103
  refresh = rerun \
78
104
  or not self.storage.exists(checkpoint_path) \
79
105
  or (self.checkpointer.should_expire and self.checkpointer.should_expire(self.storage.checkpoint_date(checkpoint_path)))
80
106
 
81
107
  if refresh:
82
- print_checkpoint(should_log, "MEMORIZING", checkpoint_id, "blue")
108
+ print_checkpoint(verbose, "MEMORIZING", checkpoint_id, "blue")
83
109
  data = self.fn(*args, **kw)
84
110
  if inspect.iscoroutine(data):
85
111
  data = await data
@@ -88,12 +114,12 @@ class CheckpointFn(Generic[Fn]):
88
114
 
89
115
  try:
90
116
  data = self.storage.load(checkpoint_path)
91
- print_checkpoint(should_log, "REMEMBERED", checkpoint_id, "green")
117
+ print_checkpoint(verbose, "REMEMBERED", checkpoint_id, "green")
92
118
  return data
93
119
  except (EOFError, FileNotFoundError):
94
- print_checkpoint(should_log, "CORRUPTED", checkpoint_id, "yellow")
95
- self.storage.delete(checkpoint_path)
96
- return await self._store_on_demand(args, kw, rerun)
120
+ pass
121
+ print_checkpoint(verbose, "CORRUPTED", checkpoint_id, "yellow")
122
+ return await self._store_on_demand(args, kw, True)
97
123
 
98
124
  def _call(self, args: tuple, kw: dict, rerun=False):
99
125
  if not self.checkpointer.when:
@@ -101,12 +127,29 @@ class CheckpointFn(Generic[Fn]):
101
127
  coroutine = self._store_on_demand(args, kw, rerun)
102
128
  return coroutine if self.is_async else sync_resolve_coroutine(coroutine)
103
129
 
130
+ def _get(self, args, kw) -> Any:
131
+ checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
132
+ try:
133
+ val = self.storage.load(checkpoint_path)
134
+ return resolved_awaitable(val) if self.is_async else val
135
+ except Exception as ex:
136
+ raise CheckpointError("Could not load checkpoint") from ex
137
+
138
+ def exists(self, *args: tuple, **kw: dict) -> bool:
139
+ return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
140
+
104
141
  __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
105
142
  rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
143
+ get: Fn = cast(Fn, lambda self, *args, **kw: self._get(args, kw))
106
144
 
107
- def get(self, *args, **kw) -> Any:
108
- checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
109
- try:
110
- return self.storage.load(checkpoint_path)
111
- except:
112
- raise CheckpointError("Could not load checkpoint")
145
+ def __repr__(self) -> str:
146
+ return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
147
+
148
+ def iterate_checkpoint_fns(pointfn: CheckpointFn, visited: set[CheckpointFn] = set()) -> Iterable[CheckpointFn]:
149
+ visited = visited or set()
150
+ if pointfn not in visited:
151
+ yield pointfn
152
+ visited.add(pointfn)
153
+ for depend in pointfn.depends:
154
+ if isinstance(depend, CheckpointFn):
155
+ yield from iterate_checkpoint_fns(depend, visited)
@@ -0,0 +1,94 @@
1
+ import dis
2
+ import inspect
3
+ from collections.abc import Callable
4
+ from itertools import takewhile
5
+ from pathlib import Path
6
+ from types import CodeType, FunctionType, MethodType
7
+ from typing import Any, Generator, Type, TypeGuard
8
+ from .object_hash import ObjectHash
9
+ from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, transpose, unwrap_fn
10
+
11
+ cwd = Path.cwd()
12
+
13
+ def is_class(obj) -> TypeGuard[Type]:
14
+ # isinstance works too, but needlessly triggers __getattribute__
15
+ return issubclass(type(obj), type)
16
+
17
+ def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
18
+ attr_path: tuple[str, ...] = ()
19
+ scope_obj = None
20
+ classvars: dict[str, dict[str, Type]] = {}
21
+ for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
22
+ if instr.opname in scope_vars and not attr_path:
23
+ attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
24
+ attr_path = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
25
+ elif instr.opname == "CALL":
26
+ obj = scope_vars.get_at(attr_path)
27
+ attr_path = ()
28
+ if is_class(obj):
29
+ scope_obj = obj
30
+ elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
31
+ load_key = instr.opname.replace("STORE", "LOAD")
32
+ classvars.setdefault(load_key, {})[instr.argval] = scope_obj
33
+ scope_obj = None
34
+ return classvars
35
+
36
+ def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Generator[tuple[tuple[str, ...], Any], None, None]:
37
+ classvars = extract_classvars(code, scope_vars)
38
+ scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
39
+ for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
40
+ if instr.opname in scope_vars:
41
+ attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
42
+ attr_path: tuple[str, ...] = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
43
+ val = scope_vars.get_at(attr_path)
44
+ if val is not None:
45
+ yield attr_path, val
46
+ for const in code.co_consts:
47
+ if isinstance(const, CodeType):
48
+ yield from extract_scope_values(const, scope_vars)
49
+
50
+ def get_self_value(fn: Callable) -> type | object | None:
51
+ if isinstance(fn, MethodType):
52
+ return fn.__self__
53
+ parts = tuple(fn.__qualname__.split(".")[:-1])
54
+ cls = parts and AttrDict(fn.__globals__).get_at(parts)
55
+ if is_class(cls):
56
+ return cls
57
+
58
+ def get_fn_captured_vals(fn: Callable) -> list[Any]:
59
+ self_value = get_self_value(fn)
60
+ scope_vars = AttrDict({
61
+ "LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
62
+ "LOAD_DEREF": AttrDict(get_cell_contents(fn)),
63
+ "LOAD_GLOBAL": AttrDict(fn.__globals__),
64
+ })
65
+ vals = dict(extract_scope_values(fn.__code__, scope_vars))
66
+ return list(vals.values())
67
+
68
+ def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
69
+ if not isinstance(candidate_fn, (FunctionType, MethodType)):
70
+ return False
71
+ fn_path = Path(inspect.getfile(candidate_fn)).resolve()
72
+ return cwd in fn_path.parents and ".venv" not in fn_path.parts
73
+
74
+ def get_depend_fns(fn: Callable, capture: bool, captured_vals_by_fn: dict[Callable, list[Any]] = {}) -> dict[Callable, list[Any]]:
75
+ from .checkpoint import CheckpointFn
76
+ captured_vals_by_fn = captured_vals_by_fn or {}
77
+ captured_vals = get_fn_captured_vals(fn)
78
+ captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)] * capture
79
+ child_fns = (unwrap_fn(val, checkpoint_fn=True) for val in captured_vals if callable(val))
80
+ for child_fn in child_fns:
81
+ if isinstance(child_fn, CheckpointFn):
82
+ captured_vals_by_fn[child_fn] = []
83
+ elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
84
+ get_depend_fns(child_fn, capture, captured_vals_by_fn)
85
+ return captured_vals_by_fn
86
+
87
+ def get_fn_ident(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
88
+ from .checkpoint import CheckpointFn
89
+ captured_vals_by_fn = get_depend_fns(fn, capture)
90
+ depends, depend_captured_vals = transpose(captured_vals_by_fn.items(), 2)
91
+ depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
92
+ unwrapped_depends = [fn for fn in depends if not isinstance(fn, CheckpointFn)]
93
+ fn_hash = str(ObjectHash(fn, unwrapped_depends).update(depend_captured_vals, tolerate_errors=True))
94
+ return fn_hash, depends