checkpointer 2.13.0__tar.gz → 2.14.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpointer-2.13.0 → checkpointer-2.14.0}/PKG-INFO +33 -50
- {checkpointer-2.13.0 → checkpointer-2.14.0}/README.md +32 -49
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/__init__.py +1 -0
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/checkpoint.py +12 -16
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/fn_ident.py +10 -10
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/fn_string.py +5 -7
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/object_hash.py +10 -3
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/storages/memory_storage.py +2 -2
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/storages/pickle_storage.py +3 -3
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/storages/storage.py +14 -1
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/utils.py +3 -2
- {checkpointer-2.13.0 → checkpointer-2.14.0}/pyproject.toml +1 -1
- {checkpointer-2.13.0 → checkpointer-2.14.0}/uv.lock +1 -1
- {checkpointer-2.13.0 → checkpointer-2.14.0}/.gitignore +0 -0
- {checkpointer-2.13.0 → checkpointer-2.14.0}/.python-version +0 -0
- {checkpointer-2.13.0 → checkpointer-2.14.0}/ATTRIBUTION.md +0 -0
- {checkpointer-2.13.0 → checkpointer-2.14.0}/LICENSE +0 -0
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/import_mappings.py +0 -0
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/print_checkpoint.py +0 -0
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/storages/__init__.py +0 -0
- {checkpointer-2.13.0 → checkpointer-2.14.0}/checkpointer/types.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: checkpointer
|
3
|
-
Version: 2.13.0
|
3
|
+
Version: 2.14.0
|
4
4
|
Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
|
5
5
|
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
6
|
Author: Hampus Hallman
|
@@ -16,7 +16,7 @@ Description-Content-Type: text/markdown
|
|
16
16
|
|
17
17
|
# checkpointer · [](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [](https://pypi.org/project/checkpointer/) [](https://pypi.org/project/checkpointer/)
|
18
18
|
|
19
|
-
`checkpointer` is a Python library offering a decorator-based API for memoizing (caching) function results with code-aware cache invalidation. It works with sync and async functions, supports multiple storage backends, and
|
19
|
+
`checkpointer` is a Python library offering a decorator-based API for memoizing (caching) function results with code-aware cache invalidation. It works with sync and async functions, supports multiple storage backends, and invalidates caches automatically when your code or dependencies change - helping you maintain correctness, speed up execution, and smooth out your workflows by skipping redundant, costly operations.
|
20
20
|
|
21
21
|
## 📦 Installation
|
22
22
|
|
@@ -42,41 +42,51 @@ result = expensive_function(4) # Loads from the cache
|
|
42
42
|
|
43
43
|
## 🧠 How It Works
|
44
44
|
|
45
|
-
When a `@checkpoint
|
45
|
+
When you decorate a function with `@checkpoint` and call it, `checkpointer` computes a unique identifier that represents that specific call. This identifier is based on:
|
46
46
|
|
47
|
-
|
47
|
+
* The function's source code and all its user-defined dependencies,
|
48
|
+
* Global variables used by the function (if capturing is enabled or explicitly annotated),
|
49
|
+
* The actual arguments passed to the function.
|
48
50
|
|
49
|
-
|
51
|
+
`checkpointer` then looks up this identifier in its cache. If a valid cached result exists, it returns that immediately. Otherwise, it runs the original function, stores the result, and returns it.
|
50
52
|
|
51
|
-
`checkpointer`
|
53
|
+
`checkpointer` is designed to be flexible through features like:
|
52
54
|
|
53
|
-
|
55
|
+
* **Support for decorated methods**, correctly caching results bound to instances.
|
56
|
+
* **Support for decorated async functions**, compatible with any async runtime.
|
57
|
+
* **Robust hashing**, covering complex Python objects and large **NumPy**/**PyTorch** arrays via its internal `ObjectHash`.
|
58
|
+
* **Targeted hashing**, allowing you to optimize how arguments and captured variables are hashed.
|
59
|
+
* **Multi-layered caching**, letting you stack decorators for layered caching strategies without losing cache consistency.
|
60
|
+
|
61
|
+
### 🚨 What Causes Cache Invalidation?
|
62
|
+
|
63
|
+
To ensure cache correctness, `checkpointer` tracks two types of hashes:
|
64
|
+
|
65
|
+
#### 1. Function Identity Hash (Computed Once Per Function)
|
54
66
|
|
55
67
|
This hash represents the decorated function itself and is computed once (usually on first invocation). It covers:
|
56
68
|
|
57
|
-
* **
|
58
|
-
The
|
69
|
+
* **Function Code and Signature:**\
|
70
|
+
The actual logic and parameters of the function are hashed. Parameter type annotations and formatting details such as whitespace, newlines, comments, and trailing commas are *not* hashed, so changing them does **not** trigger invalidation.
|
59
71
|
|
60
72
|
* **Dependencies:**\
|
61
|
-
All user-defined functions and methods the function calls or
|
62
|
-
* Inspecting the function's global scope for referenced functions
|
63
|
-
* Inferring from argument type annotations.
|
73
|
+
All user-defined functions and methods that the decorated function calls or relies on, including indirect dependencies, are included recursively. Dependencies are identified by:
|
74
|
+
* Inspecting the function's global scope for referenced functions and objects.
|
75
|
+
* Inferring from the function's argument type annotations.
|
64
76
|
* Analyzing object constructions and method calls to identify classes and methods used.
|
65
77
|
|
66
|
-
* **
|
67
|
-
Changes unrelated to the function or its dependencies
|
78
|
+
* **Exclusions:**\
|
79
|
+
Changes elsewhere in the module unrelated to the function or its dependencies do **not** cause invalidation.
|
68
80
|
|
69
81
|
#### 2. Call Hash (Computed on Every Function Call)
|
70
82
|
|
71
|
-
|
83
|
+
Every function call produces a call hash, combining:
|
72
84
|
|
73
85
|
* **Passed Arguments:**\
|
74
86
|
Includes positional and keyword arguments, combined with default values. Changing defaults alone doesn't necessarily trigger invalidation unless it affects actual call values.
|
75
87
|
|
76
88
|
* **Captured Global Variables:**\
|
77
|
-
When `capture=True` or explicit capture annotations are used, `checkpointer`
|
78
|
-
* `CaptureMe` variables are hashed on every call, so changes trigger invalidation.
|
79
|
-
* `CaptureMeOnce` variables are hashed once per session for performance optimization.
|
89
|
+
When `capture=True` or explicit capture annotations are used, `checkpointer` includes referenced global variables in the call hash. Variables annotated with `CaptureMe` are hashed on every call, causing immediate cache invalidation if they change. Variables annotated with `CaptureMeOnce` are hashed only once per Python session, improving performance by avoiding repeated hashing.
|
80
90
|
|
81
91
|
* **Custom Argument Hashing:**\
|
82
92
|
Using `HashBy` annotations, arguments or captured variables can be transformed before hashing (e.g., sorting lists to ignore order), allowing more precise or efficient call hashes.
|
@@ -103,7 +113,7 @@ Once a function is decorated with `@checkpoint`, you can interact with its cachi
|
|
103
113
|
* **`expensive_function.delete(...)`**:\
|
104
114
|
Remove the cached entry for given arguments.
|
105
115
|
|
106
|
-
* **`expensive_function.reinit(recursive: bool =
|
116
|
+
* **`expensive_function.reinit(recursive: bool = True)`**:\
|
107
117
|
Recalculate the function identity hash and recapture `CaptureMeOnce` variables, updating the cached function state within the same Python session.
|
108
118
|
|
109
119
|
## ⚙️ Configuration & Customization
|
@@ -116,18 +126,18 @@ The `@checkpoint` decorator accepts the following parameters:
|
|
116
126
|
* **`directory`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)\
|
117
127
|
Base directory for disk-based checkpoints (only for `"pickle"` storage).
|
118
128
|
|
119
|
-
* **`when`** (Type: `bool`, Default: `True`)\
|
120
|
-
Enable or disable checkpointing dynamically, useful for environment-based toggling.
|
121
|
-
|
122
129
|
* **`capture`** (Type: `bool`, Default: `False`)\
|
123
130
|
If `True`, includes global variables referenced by the function in call hashes (except those excluded via `NoHash`).
|
124
131
|
|
125
|
-
* **`
|
132
|
+
* **`expiry`** (Type: `Callable[[datetime.datetime], bool]` or `datetime.timedelta`, Default: `None`)\
|
126
133
|
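Either a `datetime.timedelta`, in which case cached results older than that duration are considered expired, or a custom callable that receives the `datetime` timestamp of a cached result. The callable should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.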
|
127
134
|
|
128
135
|
* **`fn_hash_from`** (Type: `Any`, Default: `None`)\
|
129
136
|
Override the computed function identity hash with any hashable object you provide (e.g., version strings, config IDs). This gives you explicit control over the function's version and when its cache should be invalidated.
|
130
137
|
|
138
|
+
* **`when`** (Type: `bool`, Default: `True`)\
|
139
|
+
Enable or disable checkpointing dynamically, useful for environment-based toggling.
|
140
|
+
|
131
141
|
* **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
|
132
142
|
Controls the level of logging output from `checkpointer`.
|
133
143
|
* `0`: No output.
|
@@ -231,30 +241,3 @@ class MyCustomStorage(Storage):
|
|
231
241
|
def custom_cached_function(x: int):
|
232
242
|
return x ** 2
|
233
243
|
```
|
234
|
-
|
235
|
-
## ⚡ Async Support
|
236
|
-
|
237
|
-
`checkpointer` works with Python's `asyncio` and other async runtimes.
|
238
|
-
|
239
|
-
```python
|
240
|
-
import asyncio
|
241
|
-
from checkpointer import checkpoint
|
242
|
-
|
243
|
-
@checkpoint
|
244
|
-
async def async_compute_sum(a: int, b: int) -> int:
|
245
|
-
print(f"Asynchronously computing {a} + {b}...")
|
246
|
-
await asyncio.sleep(1)
|
247
|
-
return a + b
|
248
|
-
|
249
|
-
async def main():
|
250
|
-
result1 = await async_compute_sum(3, 7)
|
251
|
-
print(f"Result 1: {result1}")
|
252
|
-
|
253
|
-
result2 = await async_compute_sum(3, 7)
|
254
|
-
print(f"Result 2: {result2}")
|
255
|
-
|
256
|
-
result3 = async_compute_sum.get(3, 7)
|
257
|
-
print(f"Result 3 (from cache): {result3}")
|
258
|
-
|
259
|
-
asyncio.run(main())
|
260
|
-
```
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# checkpointer · [](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [](https://pypi.org/project/checkpointer/) [](https://pypi.org/project/checkpointer/)
|
2
2
|
|
3
|
-
`checkpointer` is a Python library offering a decorator-based API for memoizing (caching) function results with code-aware cache invalidation. It works with sync and async functions, supports multiple storage backends, and
|
3
|
+
`checkpointer` is a Python library offering a decorator-based API for memoizing (caching) function results with code-aware cache invalidation. It works with sync and async functions, supports multiple storage backends, and invalidates caches automatically when your code or dependencies change - helping you maintain correctness, speed up execution, and smooth out your workflows by skipping redundant, costly operations.
|
4
4
|
|
5
5
|
## 📦 Installation
|
6
6
|
|
@@ -26,41 +26,51 @@ result = expensive_function(4) # Loads from the cache
|
|
26
26
|
|
27
27
|
## 🧠 How It Works
|
28
28
|
|
29
|
-
When a `@checkpoint
|
29
|
+
When you decorate a function with `@checkpoint` and call it, `checkpointer` computes a unique identifier that represents that specific call. This identifier is based on:
|
30
30
|
|
31
|
-
|
31
|
+
* The function's source code and all its user-defined dependencies,
|
32
|
+
* Global variables used by the function (if capturing is enabled or explicitly annotated),
|
33
|
+
* The actual arguments passed to the function.
|
32
34
|
|
33
|
-
|
35
|
+
`checkpointer` then looks up this identifier in its cache. If a valid cached result exists, it returns that immediately. Otherwise, it runs the original function, stores the result, and returns it.
|
34
36
|
|
35
|
-
`checkpointer`
|
37
|
+
`checkpointer` is designed to be flexible through features like:
|
36
38
|
|
37
|
-
|
39
|
+
* **Support for decorated methods**, correctly caching results bound to instances.
|
40
|
+
* **Support for decorated async functions**, compatible with any async runtime.
|
41
|
+
* **Robust hashing**, covering complex Python objects and large **NumPy**/**PyTorch** arrays via its internal `ObjectHash`.
|
42
|
+
* **Targeted hashing**, allowing you to optimize how arguments and captured variables are hashed.
|
43
|
+
* **Multi-layered caching**, letting you stack decorators for layered caching strategies without losing cache consistency.
|
44
|
+
|
45
|
+
### 🚨 What Causes Cache Invalidation?
|
46
|
+
|
47
|
+
To ensure cache correctness, `checkpointer` tracks two types of hashes:
|
48
|
+
|
49
|
+
#### 1. Function Identity Hash (Computed Once Per Function)
|
38
50
|
|
39
51
|
This hash represents the decorated function itself and is computed once (usually on first invocation). It covers:
|
40
52
|
|
41
|
-
* **
|
42
|
-
The
|
53
|
+
* **Function Code and Signature:**\
|
54
|
+
The actual logic and parameters of the function are hashed. Parameter type annotations and formatting details such as whitespace, newlines, comments, and trailing commas are *not* hashed, so changing them does **not** trigger invalidation.
|
43
55
|
|
44
56
|
* **Dependencies:**\
|
45
|
-
All user-defined functions and methods the function calls or
|
46
|
-
* Inspecting the function's global scope for referenced functions
|
47
|
-
* Inferring from argument type annotations.
|
57
|
+
All user-defined functions and methods that the decorated function calls or relies on, including indirect dependencies, are included recursively. Dependencies are identified by:
|
58
|
+
* Inspecting the function's global scope for referenced functions and objects.
|
59
|
+
* Inferring from the function's argument type annotations.
|
48
60
|
* Analyzing object constructions and method calls to identify classes and methods used.
|
49
61
|
|
50
|
-
* **
|
51
|
-
Changes unrelated to the function or its dependencies
|
62
|
+
* **Exclusions:**\
|
63
|
+
Changes elsewhere in the module unrelated to the function or its dependencies do **not** cause invalidation.
|
52
64
|
|
53
65
|
#### 2. Call Hash (Computed on Every Function Call)
|
54
66
|
|
55
|
-
|
67
|
+
Every function call produces a call hash, combining:
|
56
68
|
|
57
69
|
* **Passed Arguments:**\
|
58
70
|
Includes positional and keyword arguments, combined with default values. Changing defaults alone doesn't necessarily trigger invalidation unless it affects actual call values.
|
59
71
|
|
60
72
|
* **Captured Global Variables:**\
|
61
|
-
When `capture=True` or explicit capture annotations are used, `checkpointer`
|
62
|
-
* `CaptureMe` variables are hashed on every call, so changes trigger invalidation.
|
63
|
-
* `CaptureMeOnce` variables are hashed once per session for performance optimization.
|
73
|
+
When `capture=True` or explicit capture annotations are used, `checkpointer` includes referenced global variables in the call hash. Variables annotated with `CaptureMe` are hashed on every call, causing immediate cache invalidation if they change. Variables annotated with `CaptureMeOnce` are hashed only once per Python session, improving performance by avoiding repeated hashing.
|
64
74
|
|
65
75
|
* **Custom Argument Hashing:**\
|
66
76
|
Using `HashBy` annotations, arguments or captured variables can be transformed before hashing (e.g., sorting lists to ignore order), allowing more precise or efficient call hashes.
|
@@ -87,7 +97,7 @@ Once a function is decorated with `@checkpoint`, you can interact with its cachi
|
|
87
97
|
* **`expensive_function.delete(...)`**:\
|
88
98
|
Remove the cached entry for given arguments.
|
89
99
|
|
90
|
-
* **`expensive_function.reinit(recursive: bool =
|
100
|
+
* **`expensive_function.reinit(recursive: bool = True)`**:\
|
91
101
|
Recalculate the function identity hash and recapture `CaptureMeOnce` variables, updating the cached function state within the same Python session.
|
92
102
|
|
93
103
|
## ⚙️ Configuration & Customization
|
@@ -100,18 +110,18 @@ The `@checkpoint` decorator accepts the following parameters:
|
|
100
110
|
* **`directory`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)\
|
101
111
|
Base directory for disk-based checkpoints (only for `"pickle"` storage).
|
102
112
|
|
103
|
-
* **`when`** (Type: `bool`, Default: `True`)\
|
104
|
-
Enable or disable checkpointing dynamically, useful for environment-based toggling.
|
105
|
-
|
106
113
|
* **`capture`** (Type: `bool`, Default: `False`)\
|
107
114
|
If `True`, includes global variables referenced by the function in call hashes (except those excluded via `NoHash`).
|
108
115
|
|
109
|
-
* **`
|
116
|
+
* **`expiry`** (Type: `Callable[[datetime.datetime], bool]` or `datetime.timedelta`, Default: `None`)\
|
110
117
|
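Either a `datetime.timedelta`, in which case cached results older than that duration are considered expired, or a custom callable that receives the `datetime` timestamp of a cached result. The callable should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.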
|
111
118
|
|
112
119
|
* **`fn_hash_from`** (Type: `Any`, Default: `None`)\
|
113
120
|
Override the computed function identity hash with any hashable object you provide (e.g., version strings, config IDs). This gives you explicit control over the function's version and when its cache should be invalidated.
|
114
121
|
|
122
|
+
* **`when`** (Type: `bool`, Default: `True`)\
|
123
|
+
Enable or disable checkpointing dynamically, useful for environment-based toggling.
|
124
|
+
|
115
125
|
* **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
|
116
126
|
Controls the level of logging output from `checkpointer`.
|
117
127
|
* `0`: No output.
|
@@ -215,30 +225,3 @@ class MyCustomStorage(Storage):
|
|
215
225
|
def custom_cached_function(x: int):
|
216
226
|
return x ** 2
|
217
227
|
```
|
218
|
-
|
219
|
-
## ⚡ Async Support
|
220
|
-
|
221
|
-
`checkpointer` works with Python's `asyncio` and other async runtimes.
|
222
|
-
|
223
|
-
```python
|
224
|
-
import asyncio
|
225
|
-
from checkpointer import checkpoint
|
226
|
-
|
227
|
-
@checkpoint
|
228
|
-
async def async_compute_sum(a: int, b: int) -> int:
|
229
|
-
print(f"Asynchronously computing {a} + {b}...")
|
230
|
-
await asyncio.sleep(1)
|
231
|
-
return a + b
|
232
|
-
|
233
|
-
async def main():
|
234
|
-
result1 = await async_compute_sum(3, 7)
|
235
|
-
print(f"Result 1: {result1}")
|
236
|
-
|
237
|
-
result2 = await async_compute_sum(3, 7)
|
238
|
-
print(f"Result 2: {result2}")
|
239
|
-
|
240
|
-
result3 = async_compute_sum.get(3, 7)
|
241
|
-
print(f"Result 3 (from cache): {result3}")
|
242
|
-
|
243
|
-
asyncio.run(main())
|
244
|
-
```
|
@@ -1,19 +1,19 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import re
|
3
|
-
from datetime import datetime
|
3
|
+
from datetime import datetime, timedelta
|
4
4
|
from functools import cached_property, update_wrapper
|
5
5
|
from inspect import Parameter, iscoroutine, signature, unwrap
|
6
|
-
from itertools import chain
|
7
6
|
from pathlib import Path
|
8
7
|
from typing import (
|
9
8
|
Callable, Concatenate, Coroutine, Generic, Iterable,
|
10
|
-
Literal, Self, Type, TypedDict, Unpack,
|
9
|
+
Literal, Self, Type, TypedDict, Unpack, overload,
|
11
10
|
)
|
12
11
|
from .fn_ident import Capturable, RawFunctionIdent, get_fn_ident
|
13
12
|
from .object_hash import ObjectHash
|
14
13
|
from .print_checkpoint import print_checkpoint
|
15
14
|
from .storages import STORAGE_MAP, Storage, StorageType
|
16
15
|
from .types import AwaitableValue, C, Coro, Fn, P, R, hash_by_from_annotation
|
16
|
+
from .utils import flatten
|
17
17
|
|
18
18
|
DEFAULT_DIR = Path.home() / ".cache/checkpoints"
|
19
19
|
|
@@ -25,7 +25,7 @@ class CheckpointerOpts(TypedDict, total=False):
|
|
25
25
|
directory: Path | str | None
|
26
26
|
when: bool
|
27
27
|
verbosity: Literal[0, 1, 2]
|
28
|
-
|
28
|
+
expiry: Callable[[datetime], bool] | timedelta | None
|
29
29
|
capture: bool
|
30
30
|
fn_hash_from: object
|
31
31
|
|
@@ -35,7 +35,7 @@ class Checkpointer:
|
|
35
35
|
self.directory = Path(opts.get("directory", DEFAULT_DIR) or ".")
|
36
36
|
self.when = opts.get("when", True)
|
37
37
|
self.verbosity = opts.get("verbosity", 1)
|
38
|
-
self.
|
38
|
+
self.expiry = opts.get("expiry")
|
39
39
|
self.capture = opts.get("capture", False)
|
40
40
|
self.fn_hash_from = opts.get("fn_hash_from")
|
41
41
|
|
@@ -129,7 +129,7 @@ class CachedFunction(Generic[Fn]):
|
|
129
129
|
def __init__(self, checkpointer: Checkpointer, fn: Fn):
|
130
130
|
store_format = checkpointer.storage
|
131
131
|
Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
|
132
|
-
update_wrapper(
|
132
|
+
update_wrapper(self, unwrap(fn)) # type: ignore
|
133
133
|
self.ident = FunctionIdent(self, checkpointer, fn)
|
134
134
|
self.storage = Storage(self)
|
135
135
|
self.bound = ()
|
@@ -152,13 +152,13 @@ class CachedFunction(Generic[Fn]):
|
|
152
152
|
|
153
153
|
@property
|
154
154
|
def fn(self) -> Fn:
|
155
|
-
return
|
155
|
+
return self.ident.fn # type: ignore
|
156
156
|
|
157
157
|
@property
|
158
158
|
def cleanup(self):
|
159
159
|
return self.storage.cleanup
|
160
160
|
|
161
|
-
def reinit(self, recursive=
|
161
|
+
def reinit(self, recursive=True) -> CachedFunction[Fn]:
|
162
162
|
depend_idents = list(self.ident.deep_idents()) if recursive else [self.ident]
|
163
163
|
for ident in depend_idents: ident.reset()
|
164
164
|
for ident in depend_idents: ident.fn_hash
|
@@ -178,12 +178,10 @@ class CachedFunction(Generic[Fn]):
|
|
178
178
|
elif key == b"**":
|
179
179
|
for key in kw.keys() - ident.arg_names:
|
180
180
|
named_args[key] = hash_by(named_args[key])
|
181
|
-
named_args_iter = chain.from_iterable(sorted(named_args.items()))
|
182
|
-
captured = chain.from_iterable(capturable.capture() for capturable in ident.capturables)
|
183
181
|
call_hash = ObjectHash(digest_size=16) \
|
184
|
-
.update(
|
185
|
-
.update(
|
186
|
-
.update(
|
182
|
+
.update(header="NAMED", iter=flatten(sorted(named_args.items()))) \
|
183
|
+
.update(header="POS", iter=pos_args) \
|
184
|
+
.update(header="CAPTURED", iter=flatten(c.capture() for c in ident.capturables))
|
187
185
|
return str(call_hash)
|
188
186
|
|
189
187
|
def get_call_hash(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> str:
|
@@ -201,9 +199,7 @@ class CachedFunction(Generic[Fn]):
|
|
201
199
|
|
202
200
|
call_hash = self._get_call_hash(args, kw)
|
203
201
|
call_id = f"{storage.fn_id()}/{call_hash}"
|
204
|
-
refresh = rerun
|
205
|
-
or not storage.exists(call_hash) \
|
206
|
-
or (params.should_expire and params.should_expire(storage.checkpoint_date(call_hash)))
|
202
|
+
refresh = rerun or not storage.exists(call_hash) or storage.expired(call_hash)
|
207
203
|
|
208
204
|
if refresh:
|
209
205
|
print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id, "blue")
|
@@ -86,14 +86,14 @@ def extract_scope_values(code: CodeType, scope_vars: dict) -> Iterable[tuple[Att
|
|
86
86
|
next_scope_vars = {**scope_vars, "LOAD_FAST": {}, "LOAD_DEREF": next_deref}
|
87
87
|
yield from extract_scope_values(const, next_scope_vars)
|
88
88
|
|
89
|
-
def
|
89
|
+
def class_from_annotation(anno: object) -> Type | None:
|
90
90
|
if anno in (None, Annotated):
|
91
91
|
return None
|
92
|
-
|
92
|
+
if is_class(anno):
|
93
93
|
return anno
|
94
|
-
|
95
|
-
return
|
96
|
-
return
|
94
|
+
if get_origin(anno) is Annotated:
|
95
|
+
return class_from_annotation(next(iter(get_args(anno)), None))
|
96
|
+
return class_from_annotation(get_origin(anno))
|
97
97
|
|
98
98
|
def get_self_value(fn: Callable) -> Type | object | None:
|
99
99
|
if isinstance(fn, MethodType):
|
@@ -107,9 +107,9 @@ def get_capturables(fn: Callable, capture: bool, captured_vars: dict[AttrPath, o
|
|
107
107
|
module = getmodule(fn)
|
108
108
|
if not module or not is_user_fn(fn):
|
109
109
|
return
|
110
|
-
for (
|
110
|
+
for (instruct_type, *attr_path), obj in captured_vars.items():
|
111
111
|
attr_path = AttrPath(attr_path)
|
112
|
-
if
|
112
|
+
if instruct_type == "LOAD_GLOBAL" and not callable(obj) and not isinstance(obj, ModuleType):
|
113
113
|
anno = resolve_annotation(module, ".".join(attr_path))
|
114
114
|
if capture or is_capture_me(anno) or is_capture_me_once(anno):
|
115
115
|
hash_by = hash_by_from_annotation(anno)
|
@@ -122,7 +122,7 @@ def get_fn_captures(fn: Callable, capture: bool) -> tuple[list[Callable], list[C
|
|
122
122
|
for param in signature(fn).parameters.values()
|
123
123
|
if param.annotation is not Parameter.empty
|
124
124
|
if param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
|
125
|
-
if (class_anno :=
|
125
|
+
if (class_anno := class_from_annotation(param.annotation))
|
126
126
|
}
|
127
127
|
if self_obj := get_self_value(fn):
|
128
128
|
scope_vars_signature["self"] = self_obj
|
@@ -152,10 +152,10 @@ def get_depend_fns(fn: Callable, capture: bool, capturable_by_fn: CapturableByFn
|
|
152
152
|
def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
|
153
153
|
from .checkpoint import CachedFunction
|
154
154
|
capturable_by_fn = get_depend_fns(fn, capture)
|
155
|
-
capturables = {capt for capts in capturable_by_fn.values() for capt in capts}
|
156
155
|
depends = capturable_by_fn.keys()
|
157
156
|
depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
|
158
157
|
depend_callables = [fn for fn in depends if not isinstance(fn, CachedFunction)]
|
159
|
-
assert fn == depend_callables[0]
|
160
158
|
fn_hash = str(ObjectHash(iter=map(get_fn_aststr, depend_callables)))
|
159
|
+
capturables = {capt for capts in capturable_by_fn.values() for capt in capts}
|
160
|
+
assert fn == depend_callables[0]
|
161
161
|
return RawFunctionIdent(fn_hash, depends, capturables)
|
@@ -15,7 +15,7 @@ def get_decorator_path(node: ast.AST) -> tuple[str, ...]:
|
|
15
15
|
else:
|
16
16
|
return ()
|
17
17
|
|
18
|
-
def
|
18
|
+
def is_lone_expression(node: ast.AST) -> bool:
|
19
19
|
# Filter out docstrings
|
20
20
|
return isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant)
|
21
21
|
|
@@ -26,7 +26,8 @@ class CleanFunctionTransform(ast.NodeTransformer):
|
|
26
26
|
|
27
27
|
def is_checkpointer(self, node: ast.AST) -> bool:
|
28
28
|
from .checkpoint import Checkpointer
|
29
|
-
|
29
|
+
decorator = get_at(self.fn_globals, *get_decorator_path(node))
|
30
|
+
return isinstance(decorator, Checkpointer) or decorator is Checkpointer
|
30
31
|
|
31
32
|
def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef):
|
32
33
|
fn_type = type(node).__name__
|
@@ -45,7 +46,7 @@ class CleanFunctionTransform(ast.NodeTransformer):
|
|
45
46
|
return ast.List([
|
46
47
|
ast.Constant(header),
|
47
48
|
ast.List([child for child in node.decorator_list if not self.is_checkpointer(child)], ast.Load()),
|
48
|
-
ast.List([self.visit(child) for child in node.body if not
|
49
|
+
ast.List([self.visit(child) for child in node.body if not is_lone_expression(child)], ast.Load()),
|
49
50
|
], ast.Load())
|
50
51
|
|
51
52
|
def visit_AsyncFunctionDef(self, node):
|
@@ -66,10 +67,7 @@ def get_fn_aststr(fn: Callable) -> str:
|
|
66
67
|
if fn.__name__ != "<lambda>":
|
67
68
|
tree = CleanFunctionTransform(fn.__globals__).visit(tree)
|
68
69
|
else:
|
69
|
-
for node in ast.walk(tree)
|
70
|
-
if isinstance(node, ast.Lambda):
|
71
|
-
tree = node
|
72
|
-
break
|
70
|
+
tree = ast.List([node for node in ast.walk(tree) if isinstance(node, ast.Lambda)], ast.Load())
|
73
71
|
|
74
72
|
if sys.version_info >= (3, 13):
|
75
73
|
return ast.dump(tree, annotate_fields=False, show_empty=True)
|
@@ -5,16 +5,19 @@ import io
|
|
5
5
|
import re
|
6
6
|
import sys
|
7
7
|
import tokenize
|
8
|
+
import sysconfig
|
8
9
|
from collections import OrderedDict
|
9
10
|
from collections.abc import Iterable
|
10
11
|
from contextlib import nullcontext, suppress
|
11
12
|
from decimal import Decimal
|
12
13
|
from io import StringIO
|
14
|
+
from inspect import getfile
|
13
15
|
from itertools import chain
|
16
|
+
from pathlib import Path
|
14
17
|
from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
|
15
18
|
from types import BuiltinFunctionType, FunctionType, GeneratorType, MappingProxyType, MethodType, ModuleType, UnionType
|
16
19
|
from typing import Callable, Self, TypeVar
|
17
|
-
from .utils import ContextVar
|
20
|
+
from .utils import ContextVar, flatten
|
18
21
|
|
19
22
|
np, torch = None, None
|
20
23
|
|
@@ -31,8 +34,8 @@ if sys.version_info >= (3, 12):
|
|
31
34
|
else:
|
32
35
|
TypeAliasType = _Never
|
33
36
|
|
34
|
-
flatten = chain.from_iterable
|
35
37
|
nc = nullcontext()
|
38
|
+
stdlib = Path(sysconfig.get_paths()["stdlib"]).resolve()
|
36
39
|
|
37
40
|
def encode_type(t: type | FunctionType) -> str:
|
38
41
|
return f"{t.__module__}:{t.__qualname__}"
|
@@ -128,7 +131,11 @@ class ObjectHash:
|
|
128
131
|
self.header("builtin", obj.__qualname__)
|
129
132
|
|
130
133
|
case FunctionType():
|
131
|
-
|
134
|
+
fn_file = Path(getfile(obj)).resolve()
|
135
|
+
if fn_file.is_relative_to(stdlib):
|
136
|
+
self.header("function-std", obj.__qualname__)
|
137
|
+
else:
|
138
|
+
self.header("function", encode_type(obj)).update(get_fn_body(obj), obj.__defaults__, obj.__kwdefaults__, obj.__annotations__)
|
132
139
|
|
133
140
|
case MethodType():
|
134
141
|
self.header("method").update(obj.__func__, obj.__self__.__class__)
|
@@ -31,7 +31,7 @@ class MemoryStorage(Storage):
|
|
31
31
|
if key.parent == curr_key.parent:
|
32
32
|
if invalidated and key != curr_key:
|
33
33
|
del item_map[key]
|
34
|
-
elif expired and self.checkpointer.
|
34
|
+
elif expired and self.checkpointer.expiry:
|
35
35
|
for call_hash, (date, _) in list(calldict.items()):
|
36
|
-
if self.
|
36
|
+
if self.expired_dt(date):
|
37
37
|
del calldict[call_hash]
|
@@ -9,7 +9,7 @@ def filedate(path: Path) -> datetime:
|
|
9
9
|
|
10
10
|
class PickleStorage(Storage):
|
11
11
|
def get_path(self, call_hash: str):
|
12
|
-
return self.fn_dir() / f"{call_hash}.pkl"
|
12
|
+
return self.fn_dir() / f"{call_hash[:2]}/{call_hash[2:]}.pkl"
|
13
13
|
|
14
14
|
def store(self, call_hash, data):
|
15
15
|
path = self.get_path(call_hash)
|
@@ -40,10 +40,10 @@ class PickleStorage(Storage):
|
|
40
40
|
for path in old_dirs:
|
41
41
|
shutil.rmtree(path)
|
42
42
|
print(f"Removed {len(old_dirs)} invalidated directories for {self.cached_fn.__qualname__}")
|
43
|
-
if expired and self.checkpointer.
|
43
|
+
if expired and self.checkpointer.expiry:
|
44
44
|
count = 0
|
45
45
|
for pkl_path in fn_path.glob("**/*.pkl"):
|
46
|
-
if self.
|
46
|
+
if self.expired_dt(filedate(pkl_path)):
|
47
47
|
count += 1
|
48
48
|
pkl_path.unlink(missing_ok=True)
|
49
49
|
print(f"Removed {count} expired checkpoints for {self.cached_fn.__qualname__}")
|
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
from typing import Any, TYPE_CHECKING
|
3
3
|
from pathlib import Path
|
4
|
-
from datetime import datetime
|
4
|
+
from datetime import datetime, timedelta
|
5
5
|
|
6
6
|
if TYPE_CHECKING:
|
7
7
|
from ..checkpoint import Checkpointer, CachedFunction
|
@@ -21,6 +21,19 @@ class Storage:
|
|
21
21
|
def fn_dir(self) -> Path:
|
22
22
|
return self.checkpointer.directory / self.fn_id()
|
23
23
|
|
24
|
+
def expired(self, call_hash: str) -> bool:
|
25
|
+
if not self.checkpointer.expiry:
|
26
|
+
return False
|
27
|
+
return self.expired_dt(self.checkpoint_date(call_hash))
|
28
|
+
|
29
|
+
def expired_dt(self, dt: datetime) -> bool:
|
30
|
+
expiry = self.checkpointer.expiry
|
31
|
+
if isinstance(expiry, timedelta):
|
32
|
+
return dt < datetime.now() - expiry
|
33
|
+
else:
|
34
|
+
if TYPE_CHECKING: assert expiry
|
35
|
+
return expiry(dt)
|
36
|
+
|
24
37
|
def store(self, call_hash: str, data: Any) -> Any: ...
|
25
38
|
|
26
39
|
def exists(self, call_hash: str) -> bool: ...
|
@@ -1,13 +1,14 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import inspect
|
3
3
|
from contextlib import contextmanager, suppress
|
4
|
-
from itertools import islice
|
4
|
+
from itertools import chain, islice
|
5
5
|
from pathlib import Path
|
6
6
|
from types import FunctionType, MethodType, ModuleType
|
7
7
|
from typing import Callable, Generic, Iterable, Self, Type, TypeGuard
|
8
8
|
from .types import T
|
9
9
|
|
10
10
|
cwd = Path.cwd().resolve()
|
11
|
+
flatten = chain.from_iterable
|
11
12
|
|
12
13
|
def is_class(obj) -> TypeGuard[Type]:
|
13
14
|
return isinstance(obj, type)
|
@@ -22,7 +23,7 @@ def is_user_fn(obj) -> TypeGuard[Callable]:
|
|
22
23
|
return isinstance(obj, (FunctionType, MethodType)) and is_user_file(get_file(obj))
|
23
24
|
|
24
25
|
def get_cell_contents(fn: Callable) -> Iterable[tuple[str, object]]:
|
25
|
-
for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or
|
26
|
+
for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ()):
|
26
27
|
with suppress(ValueError):
|
27
28
|
yield (key, cell.cell_contents)
|
28
29
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|