checkpointer 2.0.2.tar.gz → 2.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpointer-2.0.2 → checkpointer-2.1.0}/PKG-INFO +32 -16
- {checkpointer-2.0.2 → checkpointer-2.1.0}/README.md +31 -15
- {checkpointer-2.0.2 → checkpointer-2.1.0}/checkpointer/__init__.py +1 -0
- {checkpointer-2.0.2 → checkpointer-2.1.0}/checkpointer/checkpoint.py +18 -12
- checkpointer-2.1.0/checkpointer/function_body.py +80 -0
- checkpointer-2.1.0/checkpointer/storages/__init__.py +11 -0
- checkpointer-2.1.0/checkpointer/utils.py +52 -0
- {checkpointer-2.0.2 → checkpointer-2.1.0}/pyproject.toml +1 -1
- {checkpointer-2.0.2 → checkpointer-2.1.0}/uv.lock +1 -1
- checkpointer-2.0.2/checkpointer/function_body.py +0 -46
- checkpointer-2.0.2/checkpointer/utils.py +0 -17
- {checkpointer-2.0.2 → checkpointer-2.1.0}/.gitignore +0 -0
- {checkpointer-2.0.2 → checkpointer-2.1.0}/LICENSE +0 -0
- {checkpointer-2.0.2 → checkpointer-2.1.0}/checkpointer/print_checkpoint.py +0 -0
- {checkpointer-2.0.2 → checkpointer-2.1.0}/checkpointer/storages/bcolz_storage.py +0 -0
- {checkpointer-2.0.2 → checkpointer-2.1.0}/checkpointer/storages/memory_storage.py +0 -0
- {checkpointer-2.0.2 → checkpointer-2.1.0}/checkpointer/storages/pickle_storage.py +0 -0
- {checkpointer-2.0.2 → checkpointer-2.1.0}/checkpointer/types.py +0 -0
**{checkpointer-2.0.2 → checkpointer-2.1.0}/PKG-INFO**

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: checkpointer
-Version: 2.0.2
+Version: 2.1.0
 Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
 Project-URL: Repository, https://github.com/Reddan/checkpointer.git
 Author: Hampus Hallman
````
````diff
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
 
 `checkpointer` is a Python library for memoizing function results. It provides a decorator-based API with support for multiple storage backends. Use it for computationally expensive operations where caching can save time, or during development to avoid waiting for redundant computations.
 
-Adding or removing `@checkpoint` doesn't change how your code works
+Adding or removing `@checkpoint` doesn't change how your code works. You can apply it to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
 
 ### Key Features:
 - 🗂️ **Multiple Storage Backends**: Built-in support for in-memory and pickle-based storage, or create your own.
````
````diff
@@ -27,6 +27,7 @@ Adding or removing `@checkpoint` doesn't change how your code works, and it can
 - 🔄 **Async and Sync Compatibility**: Works with synchronous functions and any Python async runtime (e.g., `asyncio`, `Trio`, `Curio`).
 - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
 - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
+- 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
 
 ---
 
````
````diff
@@ -59,8 +60,10 @@ result = expensive_function(4) # Loads from the cache
 When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are hashed to create a unique identifier for each call. This identifier is used to store and retrieve cached results. If the same arguments are passed again, `checkpointer` loads the cached result instead of recomputing.
 
 Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
+
 1. **Its source code**: Changes to the function's code update its hash.
 2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
+3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
 
 ### Example: Cache Invalidation
 
````
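For illustration, here is a minimal sketch of the capture-based invalidation added in the hunk above; the decorator usage follows the README's examples, and the function and variable names are hypothetical:

```python
from checkpointer import checkpoint

MULTIPLIER = 3  # module-level value captured by the function below

@checkpoint(capture=True)
def scaled(x: int) -> int:
  return x * MULTIPLIER

scaled(2)  # computed and stored on first call
# With capture=True, editing MULTIPLIER in the source changes the function's hash,
# so a later run recomputes instead of loading the stale cached result.
```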
````diff
@@ -115,7 +118,17 @@ def some_expensive_function():
 
 ## Usage
 
+### Basic Invocation and Caching
+
+Call the decorated function as usual. On the first call, the result is computed and stored in the cache. Subsequent calls with the same arguments load the result from the cache:
+
+```python
+result = expensive_function(4) # Computes and stores the result
+result = expensive_function(4) # Loads the result from the cache
+```
+
 ### Force Recalculation
+
 Force a recalculation and overwrite the stored checkpoint:
 
 ```python
````
````diff
@@ -123,6 +136,7 @@ result = expensive_function.rerun(4)
 ```
 
 ### Call the Original Function
+
 Use `fn` to directly call the original, undecorated function:
 
 ```python
````
````diff
@@ -132,6 +146,7 @@ result = expensive_function.fn(4)
 This is especially useful **inside recursive functions** to avoid redundant caching of intermediate steps while still caching the final result.
 
 ### Retrieve Stored Checkpoints
+
 Access cached results without recalculating:
 
 ```python
````
````diff
@@ -154,11 +169,11 @@ You can specify a storage backend using either its name (`"pickle"` or `"memory"
 ```python
 from checkpointer import checkpoint, PickleStorage, MemoryStorage
 
-@checkpoint(format="pickle") #
+@checkpoint(format="pickle") # Short for format=PickleStorage
 def disk_cached(x: int) -> int:
   return x ** 2
 
-@checkpoint(format="memory") #
+@checkpoint(format="memory") # Short for format=MemoryStorage
 def memory_cached(x: int) -> int:
   return x * 10
 ```
````
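The shorthand comments above have an equivalent explicit form where the storage class is passed directly; a small sketch (the function name is made up):

```python
from checkpointer import checkpoint, PickleStorage

@checkpoint(format=PickleStorage)  # same as format="pickle"
def disk_cached_explicit(x: int) -> int:
  return x ** 2
```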
````diff
@@ -191,14 +206,15 @@ Using a custom backend lets you tailor storage to your application, whether it i
 
 ## Configuration Options ⚙️
 
-| Option
-|
-| `
-| `
-| `
-| `
-| `
-| `
+| Option          | Type                              | Default              | Description                                     |
+|-----------------|-----------------------------------|----------------------|-------------------------------------------------|
+| `capture`       | `bool`                            | `False`              | Include captured variables in function hashes.  |
+| `format`        | `"pickle"`, `"memory"`, `Storage` | `"pickle"`           | Storage backend format.                         |
+| `root_path`     | `Path`, `str`, or `None`          | ~/.cache/checkpoints | Root directory for storing checkpoints.         |
+| `when`          | `bool`                            | `True`               | Enable or disable checkpointing.                |
+| `verbosity`     | `0` or `1`                        | `1`                  | Logging verbosity.                              |
+| `path`          | `Callable[..., str]`              | `None`               | Custom path for checkpoint storage.             |
+| `should_expire` | `Callable[[datetime], bool]`      | `None`               | Custom expiration logic.                        |
 
 ---
 
````
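As a rough sketch of how several of these options can be combined on one decorator (the function, directory, and expiry window are invented for illustration):

```python
from datetime import datetime, timedelta
from checkpointer import checkpoint

@checkpoint(
  format="pickle",
  root_path="/tmp/my-checkpoints",             # hypothetical directory
  path=lambda user_id: f"profiles/{user_id}",  # custom per-call storage path
  should_expire=lambda stored_at: datetime.now() - stored_at > timedelta(days=7),
  verbosity=0,
)
def fetch_profile(user_id: int) -> dict:
  return {"id": user_id}
```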
````diff
@@ -220,13 +236,13 @@ async def async_compute_sum(a: int, b: int) -> int:
 
 async def main():
   result1 = compute_square(5)
-  print(result1)
+  print(result1) # Outputs 25
 
   result2 = await async_compute_sum(3, 7)
-  print(result2)
+  print(result2) # Outputs 10
 
-  result3 = async_compute_sum.get(3, 7)
-  print(result3)
+  result3 = await async_compute_sum.get(3, 7)
+  print(result3) # Outputs 10
 
 asyncio.run(main())
 ```
````
**{checkpointer-2.0.2 → checkpointer-2.1.0}/README.md**

The README.md hunks repeat the long-description changes already shown for PKG-INFO, offset by 17 lines: @@ -2,7 +2,7 @@, @@ -10,6 +10,7 @@, @@ -42,8 +43,10 @@, @@ -98,7 +101,17 @@, @@ -106,6 +119,7 @@, @@ -115,6 +129,7 @@, @@ -137,11 +152,11 @@, @@ -174,14 +189,15 @@, and @@ -203,13 +219,13 @@ contain the same additions and removals as the corresponding PKG-INFO hunks above.
**{checkpointer-2.0.2 → checkpointer-2.1.0}/checkpointer/__init__.py**

````diff
@@ -5,5 +5,6 @@ import tempfile
 
 create_checkpointer = Checkpointer
 checkpoint = Checkpointer()
+capture_checkpoint = Checkpointer(capture=True)
 memory_checkpoint = Checkpointer(format="memory", verbosity=0)
 tmp_checkpoint = Checkpointer(root_path=tempfile.gettempdir() + "/checkpoints")
````
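A brief usage sketch for the new `capture_checkpoint` instance added here (it is a `Checkpointer` preconfigured with `capture=True`; the function and variable below are hypothetical):

```python
from checkpointer import capture_checkpoint

threshold = 0.5  # captured global; included in the hash because capture=True

@capture_checkpoint
def above_threshold(score: float) -> bool:
  return score > threshold
```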
**{checkpointer-2.0.2 → checkpointer-2.1.0}/checkpointer/checkpoint.py**

````diff
@@ -7,16 +7,13 @@ from datetime import datetime
 from functools import update_wrapper
 from .types import Storage
 from .function_body import get_function_hash
-from .utils import unwrap_fn, sync_resolve_coroutine
-from .storages
-from .storages.memory_storage import MemoryStorage
-from .storages.bcolz_storage import BcolzStorage
+from .utils import unwrap_fn, sync_resolve_coroutine, resolved_awaitable
+from .storages import STORAGE_MAP
 from .print_checkpoint import print_checkpoint
 
 Fn = TypeVar("Fn", bound=Callable)
 
 DEFAULT_DIR = Path.home() / ".cache/checkpoints"
-STORAGE_MAP: dict[str, Type[Storage]] = {"memory": MemoryStorage, "pickle": PickleStorage, "bcolz": BcolzStorage}
 
 class CheckpointError(Exception):
   pass
````
````diff
@@ -28,6 +25,7 @@ class CheckpointerOpts(TypedDict, total=False):
   verbosity: Literal[0, 1]
   path: Callable[..., str] | None
   should_expire: Callable[[datetime], bool] | None
+  capture: bool
 
 class Checkpointer:
   def __init__(self, **opts: Unpack[CheckpointerOpts]):
````
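Since `CheckpointerOpts` is a `TypedDict`, a shared option set can be kept as a plain dict and splatted into `Checkpointer`; a hedged sketch using the module path from this diff (the option values and function are illustrative):

```python
from checkpointer.checkpoint import Checkpointer, CheckpointerOpts

shared_opts: CheckpointerOpts = {"capture": True, "verbosity": 0}
my_checkpoint = Checkpointer(**shared_opts)

@my_checkpoint
def heavy(x: int) -> int:
  return x * x
```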
````diff
@@ -37,6 +35,7 @@ class Checkpointer:
     self.verbosity = opts.get("verbosity", 1)
     self.path = opts.get("path")
     self.should_expire = opts.get("should_expire")
+    self.capture = opts.get("capture", False)
 
   @overload
   def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CheckpointFn[Fn]: ...
````
````diff
@@ -57,14 +56,16 @@ class CheckpointFn(Generic[Fn]):
     storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
     self.checkpointer = checkpointer
     self.fn = fn
-    self.fn_hash = get_function_hash(wrapped)
+    self.fn_hash, self.depends = get_function_hash(wrapped, self.checkpointer.capture)
     self.fn_id = f"{file_name}/{wrapped.__name__}"
     self.is_async = inspect.iscoroutinefunction(wrapped)
     self.storage = storage(checkpointer)
 
   def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
     if not callable(self.checkpointer.path):
-
+      # TODO: use digest size before digesting instead of truncating the hash
+      call_hash = hashing.hash((self.fn_hash, args, kw), "blake2b")[:32]
+      return f"{self.fn_id}/{call_hash}"
     checkpoint_id = self.checkpointer.path(*args, **kw)
     if not isinstance(checkpoint_id, str):
       raise CheckpointError(f"path function must return a string, got {type(checkpoint_id)}")
````
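To make the default checkpoint id concrete, this is roughly what the branch above computes when no custom `path` is configured; the values are placeholders, and the `relib.hashing.hash` call mirrors the code:

```python
from relib import hashing

fn_hash = "<function hash from get_function_hash>"  # placeholder
args, kw = (4,), {}
call_hash = hashing.hash((fn_hash, args, kw), "blake2b")[:32]
checkpoint_id = f"my_module/expensive_function/{call_hash}"  # "{file_name}/{fn name}/{call hash}"
```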
````diff
@@ -101,12 +102,17 @@ class CheckpointFn(Generic[Fn]):
     coroutine = self._store_on_demand(args, kw, rerun)
     return coroutine if self.is_async else sync_resolve_coroutine(coroutine)
 
-
-  rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
-
-  def get(self, *args, **kw) -> Any:
+  def _get(self, args, kw) -> Any:
     checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
     try:
-
+      val = self.storage.load(checkpoint_path)
+      return resolved_awaitable(val) if self.is_async else val
     except:
       raise CheckpointError("Could not load checkpoint")
+
+  def exists(self, *args: tuple, **kw: dict) -> bool:
+    return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
+
+  __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
+  rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
+  get: Fn = cast(Fn, lambda self, *args, **kw: self._get(args, kw))
````
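A short usage sketch for the accessors defined above (`expensive_function` is hypothetical; for async functions `.get()` returns an awaitable, matching the README's asyncio example):

```python
from checkpointer import checkpoint

@checkpoint
def expensive_function(x: int) -> int:
  return x ** 2

expensive_function(4)                 # compute and store
if expensive_function.exists(4):
  cached = expensive_function.get(4)  # load without recomputing; raises CheckpointError if missing
```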
**checkpointer-2.1.0/checkpointer/function_body.py** (new file)

````diff
@@ -0,0 +1,80 @@
+from __future__ import annotations
+import dis
+import inspect
+import tokenize
+from io import StringIO
+from collections.abc import Callable
+from itertools import chain, takewhile
+from operator import itemgetter
+from pathlib import Path
+from typing import Any, TypeGuard, TYPE_CHECKING
+from types import CodeType, FunctionType
+from relib import transpose, hashing, merge_dicts, drop_none
+from .utils import unwrap_fn, iterate_and_upcoming, get_cell_contents, AttrDict, get_at_attr
+
+if TYPE_CHECKING:
+  from .checkpoint import CheckpointFn
+
+cwd = Path.cwd()
+
+def extract_scope_values(code: CodeType, scope_vars: dict[str, Any], closure = False) -> dict[tuple[str, ...], Any]:
+  opname = "LOAD_GLOBAL" if not closure else "LOAD_DEREF"
+  scope_values_by_path: dict[tuple[str, ...], Any] = {}
+  instructions = list(dis.get_instructions(code))
+
+  for instr, upcoming_instrs in iterate_and_upcoming(instructions):
+    if instr.opname == opname:
+      name = instr.argval
+      attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
+      attr_path = (name, *(instr.argval for instr in attrs))
+      scope_values_by_path[attr_path] = get_at_attr(scope_vars, attr_path)
+
+  children = (extract_scope_values(const, scope_vars, closure) for const in code.co_consts if isinstance(const, CodeType))
+  return merge_dicts(scope_values_by_path, *children)
+
+def get_fn_captured_vals(fn: Callable) -> list[Any]:
+  closure_scope = {k: get_cell_contents(v) for k, v in zip(fn.__code__.co_freevars, fn.__closure__ or [])}
+  global_vals = extract_scope_values(fn.__code__, AttrDict(fn.__globals__), closure=False)
+  closure_vals = extract_scope_values(fn.__code__, AttrDict(closure_scope), closure=True)
+  sorted_items = chain(sorted(global_vals.items()), sorted(closure_vals.items()))
+  return drop_none(map(itemgetter(1), sorted_items))
+
+def get_fn_body(fn: Callable) -> str:
+  source = "".join(inspect.getsourcelines(fn)[0])
+  tokens = tokenize.generate_tokens(StringIO(source).readline)
+  ignore_types = (tokenize.COMMENT, tokenize.NL)
+  return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
+
+def get_fn_path(fn: Callable) -> Path:
+  return Path(inspect.getfile(fn)).resolve()
+
+def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
+  return isinstance(candidate_fn, FunctionType) \
+    and cwd in get_fn_path(candidate_fn).parents
+
+def append_fn_depends(checkpoint_fns: set[CheckpointFn], captured_vals_by_fn: dict[Callable, list[Any]], fn: Callable, capture: bool) -> None:
+  from .checkpoint import CheckpointFn
+  captured_vals = get_fn_captured_vals(fn)
+  captured_vals_by_fn[fn] = [v for v in captured_vals if capture and not callable(v)]
+  callables = [unwrap_fn(val, checkpoint_fn=True) for val in captured_vals if callable(val)]
+  depends = {val for val in callables if is_user_fn(val)}
+  checkpoint_fns.update({val for val in callables if isinstance(val, CheckpointFn)})
+  not_appended = depends - captured_vals_by_fn.keys()
+  captured_vals_by_fn.update({fn: [] for fn in not_appended})
+  for child_fn in not_appended:
+    append_fn_depends(checkpoint_fns, captured_vals_by_fn, child_fn, capture)
+
+def get_depend_fns(fn: Callable, capture: bool) -> tuple[set[CheckpointFn], dict[Callable, list[Any]]]:
+  checkpoint_fns: set[CheckpointFn] = set()
+  captured_vals_by_fn: dict[Callable, list[Any]] = {}
+  append_fn_depends(checkpoint_fns, captured_vals_by_fn, fn, capture)
+  return checkpoint_fns, captured_vals_by_fn
+
+def get_function_hash(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
+  checkpoint_fns, captured_vals_by_fn = get_depend_fns(fn, capture)
+  checkpoint_fns = sorted(checkpoint_fns, key=lambda fn: unwrap_fn(fn).__qualname__)
+  checkpoint_hashes = [check.fn_hash for check in checkpoint_fns]
+  depend_fns, depend_captured_vals = transpose(sorted(captured_vals_by_fn.items(), key=lambda x: x[0].__qualname__), 2)
+  fn_bodies = list(map(get_fn_body, [fn] + depend_fns))
+  fn_hash = hashing.hash((fn_bodies, depend_captured_vals, checkpoint_hashes), "blake2b")
+  return fn_hash, checkpoint_fns + depend_fns
````
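Run as a script inside the project directory (so `is_user_fn` accepts the helpers), the new hashing behaves roughly like this sketch; the names, values, and printed contents are illustrative only:

```python
from checkpointer.function_body import get_function_hash

factor = 10  # captured global of `helper`

def helper(x: int) -> int:
  return x * factor

def pipeline(x: int) -> int:
  return helper(x) + 1

fn_hash, depends = get_function_hash(pipeline, capture=True)
print(fn_hash)                          # blake2b digest over sources, captured values, dependency hashes
print([fn.__name__ for fn in depends])  # includes helper, discovered through pipeline's bytecode
```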
**checkpointer-2.1.0/checkpointer/storages/__init__.py** (new file)

````diff
@@ -0,0 +1,11 @@
+from typing import Type
+from ..types import Storage
+from .pickle_storage import PickleStorage
+from .memory_storage import MemoryStorage
+from .bcolz_storage import BcolzStorage
+
+STORAGE_MAP: dict[str, Type[Storage]] = {
+  "pickle": PickleStorage,
+  "memory": MemoryStorage,
+  "bcolz": BcolzStorage,
+}
````
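This mirrors the lookup in `CheckpointFn.__init__` shown in the checkpoint.py hunks above, where a string `format` resolves to a storage class; a tiny sketch:

```python
from checkpointer.storages import STORAGE_MAP

fmt = "memory"
storage_cls = STORAGE_MAP[fmt] if isinstance(fmt, str) else fmt  # same pattern as CheckpointFn
print(storage_cls.__name__)  # MemoryStorage
```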
**checkpointer-2.1.0/checkpointer/utils.py** (new file)

````diff
@@ -0,0 +1,52 @@
+from typing import Generator, Coroutine, Iterable, Any, cast
+from types import CellType, coroutine
+from itertools import islice
+
+class AttrDict(dict):
+  def __init__(self, *args, **kwargs):
+    super(AttrDict, self).__init__(*args, **kwargs)
+    self.__dict__ = self
+
+  def __getattribute__(self, name: str) -> Any:
+    return super().__getattribute__(name)
+
+def unwrap_fn[T](fn: T, checkpoint_fn=False) -> T:
+  from .checkpoint import CheckpointFn
+  while hasattr(fn, "__wrapped__"):
+    if checkpoint_fn and isinstance(fn, CheckpointFn):
+      return fn
+    fn = getattr(fn, "__wrapped__")
+  return fn
+
+@coroutine
+def coroutine_as_generator[T](coroutine: Coroutine[None, None, T]) -> Generator[None, None, T]:
+  val = yield from coroutine
+  return val
+
+def sync_resolve_coroutine[T](coroutine: Coroutine[None, None, T]) -> T:
+  gen = cast(Generator, coroutine_as_generator(coroutine))
+  try:
+    while True: next(gen)
+  except StopIteration as ex:
+    return ex.value
+
+async def resolved_awaitable[T](value: T) -> T:
+  return value
+
+def iterate_and_upcoming[T](l: list[T]) -> Iterable[tuple[T, Iterable[T]]]:
+  for i, item in enumerate(l):
+    yield item, islice(l, i + 1, None)
+
+def get_at_attr(d: dict, keys: tuple[str, ...]) -> Any:
+  try:
+    for key in keys:
+      d = getattr(d, key)
+  except AttributeError:
+    return None
+  return d
+
+def get_cell_contents(cell: CellType) -> Any:
+  try:
+    return cell.cell_contents
+  except ValueError:
+    return None
````
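As a quick illustration of the new `iterate_and_upcoming` helper that the bytecode scan in function_body.py relies on (import path taken from this diff):

```python
from checkpointer.utils import iterate_and_upcoming

for item, upcoming in iterate_and_upcoming(["a", "b", "c"]):
  print(item, list(upcoming))
# a ['b', 'c']
# b ['c']
# c []
```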
**checkpointer-2.0.2/checkpointer/function_body.py** (removed)

````diff
@@ -1,46 +0,0 @@
-import inspect
-import relib.hashing as hashing
-from collections.abc import Callable
-from types import FunctionType, CodeType
-from pathlib import Path
-from .utils import unwrap_fn
-
-cwd = Path.cwd()
-
-def get_fn_path(fn: Callable) -> Path:
-  return Path(inspect.getfile(fn)).resolve()
-
-def get_function_body(fn: Callable) -> str:
-  # TODO: Strip comments
-  lines = inspect.getsourcelines(fn)[0]
-  lines = [line.rstrip() for line in lines]
-  lines = [line for line in lines if line]
-  return "\n".join(lines)
-
-def get_code_children(code: CodeType) -> list[str]:
-  consts = [const for const in code.co_consts if isinstance(const, CodeType)]
-  children = [child for const in consts for child in get_code_children(const)]
-  return list(code.co_names) + children
-
-def is_user_fn(candidate_fn, cleared_fns: set[Callable]) -> bool:
-  return isinstance(candidate_fn, FunctionType) \
-    and candidate_fn not in cleared_fns \
-    and cwd in get_fn_path(candidate_fn).parents
-
-def append_fn_children(cleared_fns: set[Callable], fn: Callable) -> None:
-  code_children = get_code_children(fn.__code__)
-  fn_children = [unwrap_fn(fn.__globals__.get(co_name, None)) for co_name in code_children]
-  fn_children = [child for child in fn_children if is_user_fn(child, cleared_fns)]
-  cleared_fns.update(fn_children)
-  for child_fn in fn_children:
-    append_fn_children(cleared_fns, child_fn)
-
-def get_fn_children(fn: Callable) -> list[Callable]:
-  cleared_fns: set[Callable] = set()
-  append_fn_children(cleared_fns, fn)
-  return sorted(cleared_fns, key=lambda fn: fn.__name__)
-
-def get_function_hash(fn: Callable) -> str:
-  fns = [fn] + get_fn_children(fn)
-  fn_bodies = list(map(get_function_body, fns))
-  return hashing.hash(fn_bodies)
````
**checkpointer-2.0.2/checkpointer/utils.py** (removed)

````diff
@@ -1,17 +0,0 @@
-import types
-
-def unwrap_fn[T](fn: T) -> T:
-  while hasattr(fn, "__wrapped__"):
-    fn = getattr(fn, "__wrapped__")
-  return fn
-
-@types.coroutine
-def coroutine_as_generator(coroutine):
-  val = yield from coroutine
-  return val
-
-def sync_resolve_coroutine(coroutine):
-  try:
-    next(coroutine_as_generator(coroutine))
-  except StopIteration as ex:
-    return ex.value
````
The remaining files listed above (.gitignore, LICENSE, checkpointer/print_checkpoint.py, checkpointer/storages/bcolz_storage.py, checkpointer/storages/memory_storage.py, checkpointer/storages/pickle_storage.py, and checkpointer/types.py) are carried over without changes.