checkpointer 2.9.1__tar.gz → 2.10.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpointer-2.9.1 → checkpointer-2.10.0}/PKG-INFO +27 -29
- {checkpointer-2.9.1 → checkpointer-2.10.0}/README.md +26 -28
- {checkpointer-2.9.1 → checkpointer-2.10.0}/checkpointer/__init__.py +4 -4
- {checkpointer-2.9.1 → checkpointer-2.10.0}/checkpointer/checkpoint.py +45 -31
- {checkpointer-2.9.1 → checkpointer-2.10.0}/checkpointer/fn_ident.py +6 -6
- {checkpointer-2.9.1 → checkpointer-2.10.0}/checkpointer/object_hash.py +3 -1
- {checkpointer-2.9.1 → checkpointer-2.10.0}/checkpointer/storages/__init__.py +0 -2
- {checkpointer-2.9.1 → checkpointer-2.10.0}/checkpointer/storages/memory_storage.py +1 -0
- {checkpointer-2.9.1 → checkpointer-2.10.0}/checkpointer/storages/pickle_storage.py +11 -6
- {checkpointer-2.9.1 → checkpointer-2.10.0}/checkpointer/storages/storage.py +7 -7
- {checkpointer-2.9.1 → checkpointer-2.10.0}/checkpointer/test_checkpointer.py +0 -1
- {checkpointer-2.9.1 → checkpointer-2.10.0}/checkpointer/utils.py +3 -3
- {checkpointer-2.9.1 → checkpointer-2.10.0}/pyproject.toml +1 -1
- {checkpointer-2.9.1 → checkpointer-2.10.0}/uv.lock +1 -1
- checkpointer-2.9.1/checkpointer/storages/bcolz_storage.py +0 -80
- {checkpointer-2.9.1 → checkpointer-2.10.0}/.gitignore +0 -0
- {checkpointer-2.9.1 → checkpointer-2.10.0}/.python-version +0 -0
- {checkpointer-2.9.1 → checkpointer-2.10.0}/LICENSE +0 -0
- {checkpointer-2.9.1 → checkpointer-2.10.0}/checkpointer/print_checkpoint.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: checkpointer
|
3
|
-
Version: 2.9.1
|
3
|
+
Version: 2.10.0
|
4
4
|
Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
|
5
5
|
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
6
|
Author: Hampus Hallman
|
@@ -48,19 +48,14 @@ result = expensive_function(4) # Loads from the cache
|
|
48
48
|
|
49
49
|
## 🧠 How It Works
|
50
50
|
|
51
|
-
When a function
|
51
|
+
When a `@checkpoint`-decorated function is called, `checkpointer` first computes a unique identifier (hash) for the call. This hash is derived from the function's source code, its dependencies, and the arguments passed.
|
52
52
|
|
53
|
-
|
54
|
-
2. It attempts to retrieve a cached result using this identifier.
|
55
|
-
3. If a cached result is found, it's returned immediately.
|
56
|
-
4. If no cached result exists or the cache has expired, the original function is executed, its result is stored, and then returned.
|
53
|
+
It then tries to retrieve a cached result using this ID. If a valid cached result is found, it's returned immediately. Otherwise, the original function executes, its result is stored, and then returned.
|
57
54
|
|
58
|
-
|
55
|
+
Cache validity is determined by the decorated function's hash, which automatically updates if:
|
59
56
|
|
60
|
-
|
61
|
-
|
62
|
-
* **Function Code Changes**: The source code of the decorated function itself is modified.
|
63
|
-
* **Dependencies Change**: Any user-defined function in its dependency tree (direct or indirect, even across modules or not decorated with `@checkpoint`) is modified.
|
57
|
+
* **Function Code Changes**: The decorated function's source code is modified.
|
58
|
+
* **Dependencies Change**: Any user-defined function in its dependency tree (direct or indirect, even across modules, and even if it is not itself decorated) is modified.
|
64
59
|
* **Captured Variables Change** (with `capture=True`): Global or closure-based variables used within the function are altered.
|
65
60
|
|
66
61
|
**Example: Dependency Invalidation**
|
@@ -76,63 +71,63 @@ def helper(x):
|
|
76
71
|
|
77
72
|
@checkpoint
|
78
73
|
def compute(a, b):
|
79
|
-
# Depends on `helper`
|
74
|
+
# Depends on `helper` and `multiply`
|
80
75
|
return helper(a) + helper(b)
|
81
76
|
```
|
82
77
|
|
83
|
-
If `multiply` is modified, caches for both `helper` and `compute`
|
78
|
+
If `multiply` is modified, caches for both `helper` and `compute` are automatically invalidated and recomputed.
|
84
79
|
|
85
80
|
## 💡 Usage
|
86
81
|
|
87
82
|
Once a function is decorated with `@checkpoint`, you can interact with its caching behavior using the following methods:
|
88
83
|
|
89
|
-
* **`expensive_function(...)
|
84
|
+
* **`expensive_function(...)`**:\
|
90
85
|
Call the function normally. This will either compute and cache the result or load it from the cache if available.
|
91
86
|
|
92
|
-
* **`expensive_function.rerun(...)
|
87
|
+
* **`expensive_function.rerun(...)`**:\
|
93
88
|
Forces the original function to execute, compute a new result, and overwrite any existing cached value for the given arguments.
|
94
89
|
|
95
|
-
* **`expensive_function.fn(...)
|
90
|
+
* **`expensive_function.fn(...)`**:\
|
96
91
|
Calls the original, undecorated function directly, bypassing the cache entirely. This is particularly useful within recursive functions to prevent caching intermediate steps.
|
97
92
|
|
98
|
-
* **`expensive_function.get(...)
|
93
|
+
* **`expensive_function.get(...)`**:\
|
99
94
|
Attempts to retrieve the cached result for the given arguments without executing the original function. Raises `CheckpointError` if no valid cached result exists.
|
100
95
|
|
101
|
-
* **`expensive_function.exists(...)
|
96
|
+
* **`expensive_function.exists(...)`**:\
|
102
97
|
Checks if a cached result exists for the given arguments without attempting to compute or load it. Returns `True` if a valid checkpoint exists, `False` otherwise.
|
103
98
|
|
104
|
-
* **`expensive_function.delete(...)
|
99
|
+
* **`expensive_function.delete(...)`**:\
|
105
100
|
Removes the cached entry for the specified arguments.
|
106
101
|
|
107
|
-
* **`expensive_function.reinit()
|
102
|
+
* **`expensive_function.reinit()`**:\
|
108
103
|
Recalculates the function's internal hash. This is primarily used when `capture=True` and you need to update the cache based on changes to external variables within the same Python session.
|
109
104
|
|
110
105
|
## ⚙️ Configuration & Customization
|
111
106
|
|
112
107
|
The `@checkpoint` decorator accepts the following parameters to customize its behavior:
|
113
108
|
|
114
|
-
* **`format`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)
|
109
|
+
* **`format`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)\
|
115
110
|
Defines the storage backend to use. Built-in options are `"pickle"` (disk-based, persistent) and `"memory"` (in-memory, non-persistent). You can also provide a custom `Storage` class.
|
116
111
|
|
117
|
-
* **`root_path`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)
|
112
|
+
* **`root_path`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)\
|
118
113
|
The base directory for storing disk-based checkpoints. This parameter is only relevant when `format` is set to `"pickle"`.
|
119
114
|
|
120
|
-
* **`when`** (Type: `bool`, Default: `True`)
|
115
|
+
* **`when`** (Type: `bool`, Default: `True`)\
|
121
116
|
A boolean flag to enable or disable checkpointing for the decorated function. This is particularly useful for toggling caching based on environment variables (e.g., `when=os.environ.get("ENABLE_CACHING", "false").lower() == "true"`).
|
122
117
|
|
123
|
-
* **`capture`** (Type: `bool`, Default: `False`)
|
118
|
+
* **`capture`** (Type: `bool`, Default: `False`)\
|
124
119
|
If set to `True`, `checkpointer` includes global or closure-based variables used by the function in its hash calculation. This ensures that changes to these external variables also trigger cache invalidation and recomputation.
|
125
120
|
|
126
|
-
* **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)
|
121
|
+
* **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
|
127
122
|
A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
|
128
123
|
|
129
|
-
* **`hash_by`** (Type: `Callable[..., Any]`, Default: `None`)
|
124
|
+
* **`hash_by`** (Type: `Callable[..., Any]`, Default: `None`)\
|
130
125
|
A custom callable that takes the function's arguments (`*args`, `**kwargs`) and returns a hashable object (or tuple of objects). This allows for custom argument normalization (e.g., sorting lists before hashing) or optimized hashing for complex input types, which can improve cache hit rates or speed up the hashing process.
|
131
126
|
|
132
|
-
* **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)
|
127
|
+
* **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)\
|
133
128
|
An optional parameter that takes an instance of `checkpointer.ObjectHash`. This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any values relevant to your invalidation logic to `ObjectHash` (e.g., `ObjectHash(version_string, config_id, ...)`), as it can consistently hash most Python values.
|
134
129
|
|
135
|
-
* **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)
|
130
|
+
* **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
|
136
131
|
Controls the level of logging output from `checkpointer`.
|
137
132
|
* `0`: No output.
|
138
133
|
* `1`: Shows when functions are computed and cached.
|
@@ -156,8 +151,11 @@ class MyCustomStorage(Storage):
|
|
156
151
|
fn_dir = self.checkpointer.root_path / self.fn_id()
|
157
152
|
return (fn_dir / call_id).exists()
|
158
153
|
|
154
|
+
def store(self, call_id, data):
|
155
|
+
... # Store the serialized data for `call_id`
|
156
|
+
return data # This method must return the data back to checkpointer
|
157
|
+
|
159
158
|
def checkpoint_date(self, call_id): ...
|
160
|
-
def store(self, call_id, data): ...
|
161
159
|
def load(self, call_id): ...
|
162
160
|
def delete(self, call_id): ...
|
163
161
|
|
@@ -28,19 +28,14 @@ result = expensive_function(4) # Loads from the cache
|
|
28
28
|
|
29
29
|
## 🧠 How It Works
|
30
30
|
|
31
|
-
When a function
|
31
|
+
When a `@checkpoint`-decorated function is called, `checkpointer` first computes a unique identifier (hash) for the call. This hash is derived from the function's source code, its dependencies, and the arguments passed.
|
32
32
|
|
33
|
-
|
34
|
-
2. It attempts to retrieve a cached result using this identifier.
|
35
|
-
3. If a cached result is found, it's returned immediately.
|
36
|
-
4. If no cached result exists or the cache has expired, the original function is executed, its result is stored, and then returned.
|
33
|
+
It then tries to retrieve a cached result using this ID. If a valid cached result is found, it's returned immediately. Otherwise, the original function executes, its result is stored, and then returned.
|
37
34
|
|
38
|
-
|
35
|
+
Cache validity is determined by the decorated function's hash, which automatically updates if:
|
39
36
|
|
40
|
-
|
41
|
-
|
42
|
-
* **Function Code Changes**: The source code of the decorated function itself is modified.
|
43
|
-
* **Dependencies Change**: Any user-defined function in its dependency tree (direct or indirect, even across modules or not decorated with `@checkpoint`) is modified.
|
37
|
+
* **Function Code Changes**: The decorated function's source code is modified.
|
38
|
+
* **Dependencies Change**: Any user-defined function in its dependency tree (direct or indirect, even across modules, and even if it is not itself decorated) is modified.
|
44
39
|
* **Captured Variables Change** (with `capture=True`): Global or closure-based variables used within the function are altered.
|
45
40
|
|
46
41
|
**Example: Dependency Invalidation**
|
@@ -56,63 +51,63 @@ def helper(x):
|
|
56
51
|
|
57
52
|
@checkpoint
|
58
53
|
def compute(a, b):
|
59
|
-
# Depends on `helper`
|
54
|
+
# Depends on `helper` and `multiply`
|
60
55
|
return helper(a) + helper(b)
|
61
56
|
```
|
62
57
|
|
63
|
-
If `multiply` is modified, caches for both `helper` and `compute`
|
58
|
+
If `multiply` is modified, caches for both `helper` and `compute` are automatically invalidated and recomputed.
|
64
59
|
|
65
60
|
## 💡 Usage
|
66
61
|
|
67
62
|
Once a function is decorated with `@checkpoint`, you can interact with its caching behavior using the following methods:
|
68
63
|
|
69
|
-
* **`expensive_function(...)
|
64
|
+
* **`expensive_function(...)`**:\
|
70
65
|
Call the function normally. This will either compute and cache the result or load it from the cache if available.
|
71
66
|
|
72
|
-
* **`expensive_function.rerun(...)
|
67
|
+
* **`expensive_function.rerun(...)`**:\
|
73
68
|
Forces the original function to execute, compute a new result, and overwrite any existing cached value for the given arguments.
|
74
69
|
|
75
|
-
* **`expensive_function.fn(...)
|
70
|
+
* **`expensive_function.fn(...)`**:\
|
76
71
|
Calls the original, undecorated function directly, bypassing the cache entirely. This is particularly useful within recursive functions to prevent caching intermediate steps.
|
77
72
|
|
78
|
-
* **`expensive_function.get(...)
|
73
|
+
* **`expensive_function.get(...)`**:\
|
79
74
|
Attempts to retrieve the cached result for the given arguments without executing the original function. Raises `CheckpointError` if no valid cached result exists.
|
80
75
|
|
81
|
-
* **`expensive_function.exists(...)
|
76
|
+
* **`expensive_function.exists(...)`**:\
|
82
77
|
Checks if a cached result exists for the given arguments without attempting to compute or load it. Returns `True` if a valid checkpoint exists, `False` otherwise.
|
83
78
|
|
84
|
-
* **`expensive_function.delete(...)
|
79
|
+
* **`expensive_function.delete(...)`**:\
|
85
80
|
Removes the cached entry for the specified arguments.
|
86
81
|
|
87
|
-
* **`expensive_function.reinit()
|
82
|
+
* **`expensive_function.reinit()`**:\
|
88
83
|
Recalculates the function's internal hash. This is primarily used when `capture=True` and you need to update the cache based on changes to external variables within the same Python session.
|
89
84
|
|
90
85
|
## ⚙️ Configuration & Customization
|
91
86
|
|
92
87
|
The `@checkpoint` decorator accepts the following parameters to customize its behavior:
|
93
88
|
|
94
|
-
* **`format`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)
|
89
|
+
* **`format`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)\
|
95
90
|
Defines the storage backend to use. Built-in options are `"pickle"` (disk-based, persistent) and `"memory"` (in-memory, non-persistent). You can also provide a custom `Storage` class.
|
96
91
|
|
97
|
-
* **`root_path`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)
|
92
|
+
* **`root_path`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)\
|
98
93
|
The base directory for storing disk-based checkpoints. This parameter is only relevant when `format` is set to `"pickle"`.
|
99
94
|
|
100
|
-
* **`when`** (Type: `bool`, Default: `True`)
|
95
|
+
* **`when`** (Type: `bool`, Default: `True`)\
|
101
96
|
A boolean flag to enable or disable checkpointing for the decorated function. This is particularly useful for toggling caching based on environment variables (e.g., `when=os.environ.get("ENABLE_CACHING", "false").lower() == "true"`).
|
102
97
|
|
103
|
-
* **`capture`** (Type: `bool`, Default: `False`)
|
98
|
+
* **`capture`** (Type: `bool`, Default: `False`)\
|
104
99
|
If set to `True`, `checkpointer` includes global or closure-based variables used by the function in its hash calculation. This ensures that changes to these external variables also trigger cache invalidation and recomputation.
|
105
100
|
|
106
|
-
* **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)
|
101
|
+
* **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
|
107
102
|
A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
|
108
103
|
|
109
|
-
* **`hash_by`** (Type: `Callable[..., Any]`, Default: `None`)
|
104
|
+
* **`hash_by`** (Type: `Callable[..., Any]`, Default: `None`)\
|
110
105
|
A custom callable that takes the function's arguments (`*args`, `**kwargs`) and returns a hashable object (or tuple of objects). This allows for custom argument normalization (e.g., sorting lists before hashing) or optimized hashing for complex input types, which can improve cache hit rates or speed up the hashing process.
|
111
106
|
|
112
|
-
* **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)
|
107
|
+
* **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)\
|
113
108
|
An optional parameter that takes an instance of `checkpointer.ObjectHash`. This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any values relevant to your invalidation logic to `ObjectHash` (e.g., `ObjectHash(version_string, config_id, ...)`), as it can consistently hash most Python values.
|
114
109
|
|
115
|
-
* **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)
|
110
|
+
* **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
|
116
111
|
Controls the level of logging output from `checkpointer`.
|
117
112
|
* `0`: No output.
|
118
113
|
* `1`: Shows when functions are computed and cached.
|
@@ -136,8 +131,11 @@ class MyCustomStorage(Storage):
|
|
136
131
|
fn_dir = self.checkpointer.root_path / self.fn_id()
|
137
132
|
return (fn_dir / call_id).exists()
|
138
133
|
|
134
|
+
def store(self, call_id, data):
|
135
|
+
... # Store the serialized data for `call_id`
|
136
|
+
return data # This method must return the data back to checkpointer
|
137
|
+
|
139
138
|
def checkpoint_date(self, call_id): ...
|
140
|
-
def store(self, call_id, data): ...
|
141
139
|
def load(self, call_id): ...
|
142
140
|
def delete(self, call_id): ...
|
143
141
|
|
@@ -1,11 +1,11 @@
|
|
1
1
|
import gc
|
2
2
|
import tempfile
|
3
3
|
from typing import Callable
|
4
|
-
from .checkpoint import Checkpointer, CheckpointError
|
4
|
+
from .checkpoint import CachedFunction, Checkpointer, CheckpointError
|
5
5
|
from .object_hash import ObjectHash
|
6
6
|
from .storages import MemoryStorage, PickleStorage, Storage
|
7
|
+
from .utils import AwaitableValue
|
7
8
|
|
8
|
-
create_checkpointer = Checkpointer
|
9
9
|
checkpoint = Checkpointer()
|
10
10
|
capture_checkpoint = Checkpointer(capture=True)
|
11
11
|
memory_checkpoint = Checkpointer(format="memory", verbosity=0)
|
@@ -14,8 +14,8 @@ static_checkpoint = Checkpointer(fn_hash=ObjectHash())
|
|
14
14
|
|
15
15
|
def cleanup_all(invalidated=True, expired=True):
|
16
16
|
for obj in gc.get_objects():
|
17
|
-
if isinstance(obj,
|
17
|
+
if isinstance(obj, CachedFunction):
|
18
18
|
obj.cleanup(invalidated=invalidated, expired=expired)
|
19
19
|
|
20
20
|
def get_function_hash(fn: Callable, capture=False) -> str:
|
21
|
-
return
|
21
|
+
return CachedFunction(Checkpointer(capture=capture), fn).fn_hash
|
@@ -1,11 +1,13 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import inspect
|
3
3
|
import re
|
4
|
-
from contextlib import suppress
|
5
4
|
from datetime import datetime
|
6
5
|
from functools import cached_property, update_wrapper
|
7
6
|
from pathlib import Path
|
8
|
-
from typing import
|
7
|
+
from typing import (
|
8
|
+
Awaitable, Callable, Concatenate, Generic, Iterable, Literal,
|
9
|
+
ParamSpec, Self, Type, TypedDict, TypeVar, Unpack, cast, overload,
|
10
|
+
)
|
9
11
|
from .fn_ident import get_fn_ident
|
10
12
|
from .object_hash import ObjectHash
|
11
13
|
from .print_checkpoint import print_checkpoint
|
@@ -15,6 +17,7 @@ from .utils import AwaitableValue, unwrap_fn
|
|
15
17
|
Fn = TypeVar("Fn", bound=Callable)
|
16
18
|
P = ParamSpec("P")
|
17
19
|
R = TypeVar("R")
|
20
|
+
C = TypeVar("C")
|
18
21
|
|
19
22
|
DEFAULT_DIR = Path.home() / ".cache/checkpoints"
|
20
23
|
|
@@ -43,17 +46,17 @@ class Checkpointer:
|
|
43
46
|
self.fn_hash = opts.get("fn_hash")
|
44
47
|
|
45
48
|
@overload
|
46
|
-
def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) ->
|
49
|
+
def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CachedFunction[Fn]: ...
|
47
50
|
@overload
|
48
51
|
def __call__(self, fn: None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer: ...
|
49
|
-
def __call__(self, fn: Fn | None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer |
|
52
|
+
def __call__(self, fn: Fn | None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer | CachedFunction[Fn]:
|
50
53
|
if override_opts:
|
51
54
|
opts = CheckpointerOpts(**{**self.__dict__, **override_opts})
|
52
55
|
return Checkpointer(**opts)(fn)
|
53
56
|
|
54
|
-
return
|
57
|
+
return CachedFunction(self, fn) if callable(fn) else self
|
55
58
|
|
56
|
-
class
|
59
|
+
class CachedFunction(Generic[Fn]):
|
57
60
|
def __init__(self, checkpointer: Checkpointer, fn: Fn):
|
58
61
|
wrapped = unwrap_fn(fn)
|
59
62
|
fn_file = Path(wrapped.__code__.co_filename).name
|
@@ -65,6 +68,19 @@ class CheckpointFn(Generic[Fn]):
|
|
65
68
|
self.fn_dir = f"{fn_file}/{fn_name}"
|
66
69
|
self.storage = Storage(self)
|
67
70
|
self.cleanup = self.storage.cleanup
|
71
|
+
self.bound = ()
|
72
|
+
|
73
|
+
@overload
|
74
|
+
def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
|
75
|
+
@overload
|
76
|
+
def __get__(self: CachedFunction[Callable[Concatenate[C, P], R]], instance: C, owner: Type[C]) -> CachedFunction[Callable[P, R]]: ...
|
77
|
+
def __get__(self, instance, owner):
|
78
|
+
if instance is None:
|
79
|
+
return self
|
80
|
+
bound_fn = object.__new__(CachedFunction)
|
81
|
+
bound_fn.__dict__ |= self.__dict__
|
82
|
+
bound_fn.bound = (instance,)
|
83
|
+
return bound_fn
|
68
84
|
|
69
85
|
@cached_property
|
70
86
|
def ident_tuple(self) -> tuple[str, list[Callable]]:
|
@@ -80,33 +96,33 @@ class CheckpointFn(Generic[Fn]):
|
|
80
96
|
|
81
97
|
@cached_property
|
82
98
|
def fn_hash(self) -> str:
|
83
|
-
fn_hash = self.checkpointer.fn_hash
|
84
99
|
deep_hashes = [depend.fn_hash_raw for depend in self.deep_depends()]
|
85
|
-
|
100
|
+
fn_hash = ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes)
|
101
|
+
return str(self.checkpointer.fn_hash or fn_hash)[:32]
|
86
102
|
|
87
|
-
def reinit(self, recursive=False) ->
|
103
|
+
def reinit(self, recursive=False) -> CachedFunction[Fn]:
|
88
104
|
depends = list(self.deep_depends()) if recursive else [self]
|
89
105
|
for depend in depends:
|
90
|
-
|
91
|
-
|
106
|
+
self.__dict__.pop("fn_hash", None)
|
107
|
+
self.__dict__.pop("ident_tuple", None)
|
92
108
|
for depend in depends:
|
93
109
|
depend.fn_hash
|
94
110
|
return self
|
95
111
|
|
96
112
|
def get_call_id(self, args: tuple, kw: dict) -> str:
|
113
|
+
args = self.bound + args
|
97
114
|
hash_by = self.checkpointer.hash_by
|
98
115
|
hash_params = hash_by(*args, **kw) if hash_by else (args, kw)
|
99
116
|
return str(ObjectHash(hash_params, digest_size=16))
|
100
117
|
|
101
|
-
async def _resolve_awaitable(self,
|
102
|
-
|
103
|
-
self.storage.store(checkpoint_id, AwaitableValue(data))
|
104
|
-
return data
|
118
|
+
async def _resolve_awaitable(self, call_id: str, awaitable: Awaitable):
|
119
|
+
return await self.storage.store(call_id, AwaitableValue(await awaitable))
|
105
120
|
|
106
|
-
def _call(self, args: tuple, kw: dict, rerun=False):
|
121
|
+
def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
|
122
|
+
full_args = self.bound + args
|
107
123
|
params = self.checkpointer
|
108
124
|
if not params.when:
|
109
|
-
return self.fn(*
|
125
|
+
return self.fn(*full_args, **kw)
|
110
126
|
|
111
127
|
call_id = self.get_call_id(args, kw)
|
112
128
|
call_id_long = f"{self.fn_dir}/{self.fn_hash}/{call_id}"
|
@@ -117,12 +133,10 @@ class CheckpointFn(Generic[Fn]):
|
|
117
133
|
|
118
134
|
if refresh:
|
119
135
|
print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id_long, "blue")
|
120
|
-
data = self.fn(*
|
136
|
+
data = self.fn(*full_args, **kw)
|
121
137
|
if inspect.isawaitable(data):
|
122
138
|
return self._resolve_awaitable(call_id, data)
|
123
|
-
|
124
|
-
self.storage.store(call_id, data)
|
125
|
-
return data
|
139
|
+
return self.storage.store(call_id, data)
|
126
140
|
|
127
141
|
try:
|
128
142
|
data = self.storage.load(call_id)
|
@@ -133,14 +147,16 @@ class CheckpointFn(Generic[Fn]):
|
|
133
147
|
print_checkpoint(params.verbosity >= 1, "CORRUPTED", call_id_long, "yellow")
|
134
148
|
return self._call(args, kw, True)
|
135
149
|
|
136
|
-
__call__:
|
137
|
-
|
150
|
+
def __call__(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
|
151
|
+
return self._call(args, kw)
|
152
|
+
|
153
|
+
def rerun(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
|
154
|
+
return self._call(args, kw, True)
|
138
155
|
|
139
156
|
@overload
|
140
157
|
def get(self: Callable[P, Awaitable[R]], *args: P.args, **kw: P.kwargs) -> R: ...
|
141
158
|
@overload
|
142
159
|
def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
|
143
|
-
|
144
160
|
def get(self, *args, **kw):
|
145
161
|
call_id = self.get_call_id(args, kw)
|
146
162
|
try:
|
@@ -149,22 +165,20 @@ class CheckpointFn(Generic[Fn]):
|
|
149
165
|
except Exception as ex:
|
150
166
|
raise CheckpointError("Could not load checkpoint") from ex
|
151
167
|
|
152
|
-
def exists(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> bool:
|
153
|
-
self = cast(CheckpointFn, self)
|
168
|
+
def exists(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> bool:
|
154
169
|
return self.storage.exists(self.get_call_id(args, kw))
|
155
170
|
|
156
|
-
def delete(self: Callable[P, R], *args: P.args, **kw: P.kwargs):
|
157
|
-
self = cast(CheckpointFn, self)
|
171
|
+
def delete(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs):
|
158
172
|
self.storage.delete(self.get_call_id(args, kw))
|
159
173
|
|
160
174
|
def __repr__(self) -> str:
|
161
|
-
return f"<
|
175
|
+
return f"<CachedFunction {self.fn.__name__} {self.fn_hash[:6]}>"
|
162
176
|
|
163
|
-
def deep_depends(self, visited: set[
|
177
|
+
def deep_depends(self, visited: set[CachedFunction] = set()) -> Iterable[CachedFunction]:
|
164
178
|
if self not in visited:
|
165
179
|
yield self
|
166
180
|
visited = visited or set()
|
167
181
|
visited.add(self)
|
168
182
|
for depend in self.depends:
|
169
|
-
if isinstance(depend,
|
183
|
+
if isinstance(depend, CachedFunction):
|
170
184
|
yield from depend.deep_depends(visited)
|
@@ -8,7 +8,7 @@ from typing import Any, Iterable, Type, TypeGuard
|
|
8
8
|
from .object_hash import ObjectHash
|
9
9
|
from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, transpose, unwrap_fn
|
10
10
|
|
11
|
-
cwd = Path.cwd()
|
11
|
+
cwd = Path.cwd().resolve()
|
12
12
|
|
13
13
|
def is_class(obj) -> TypeGuard[Type]:
|
14
14
|
# isinstance works too, but needlessly triggers _lazyinit()
|
@@ -72,23 +72,23 @@ def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
|
|
72
72
|
return cwd in fn_path.parents and ".venv" not in fn_path.parts
|
73
73
|
|
74
74
|
def get_depend_fns(fn: Callable, capture: bool, captured_vals_by_fn: dict[Callable, list[Any]] = {}) -> dict[Callable, list[Any]]:
|
75
|
-
from .checkpoint import
|
75
|
+
from .checkpoint import CachedFunction
|
76
76
|
captured_vals_by_fn = captured_vals_by_fn or {}
|
77
77
|
captured_vals = get_fn_captured_vals(fn)
|
78
78
|
captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)] * capture
|
79
|
-
child_fns = (unwrap_fn(val,
|
79
|
+
child_fns = (unwrap_fn(val, cached_fn=True) for val in captured_vals if callable(val))
|
80
80
|
for child_fn in child_fns:
|
81
|
-
if isinstance(child_fn,
|
81
|
+
if isinstance(child_fn, CachedFunction):
|
82
82
|
captured_vals_by_fn[child_fn] = []
|
83
83
|
elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
|
84
84
|
get_depend_fns(child_fn, capture, captured_vals_by_fn)
|
85
85
|
return captured_vals_by_fn
|
86
86
|
|
87
87
|
def get_fn_ident(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
|
88
|
-
from .checkpoint import
|
88
|
+
from .checkpoint import CachedFunction
|
89
89
|
captured_vals_by_fn = get_depend_fns(fn, capture)
|
90
90
|
depends, depend_captured_vals = transpose(captured_vals_by_fn.items(), 2)
|
91
91
|
depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
|
92
|
-
unwrapped_depends = [fn for fn in depends if not isinstance(fn,
|
92
|
+
unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
|
93
93
|
fn_hash = str(ObjectHash(fn, unwrapped_depends).update(depend_captured_vals, tolerate_errors=True))
|
94
94
|
return fn_hash, depends
|
@@ -180,8 +180,10 @@ class ObjectHash:
|
|
180
180
|
def _update_iterator(self, obj: Iterable) -> "ObjectHash":
|
181
181
|
return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")
|
182
182
|
|
183
|
-
def _update_object(self, obj:
|
183
|
+
def _update_object(self, obj: Any) -> "ObjectHash":
|
184
184
|
self.header("instance", encode_type_of(obj))
|
185
|
+
if hasattr(obj, "__objecthash__") and callable(obj.__objecthash__):
|
186
|
+
return self.header("objecthash").update(obj.__objecthash__())
|
185
187
|
reduced = None
|
186
188
|
with suppress(Exception):
|
187
189
|
reduced = obj.__reduce_ex__(PROTOCOL)
|
@@ -2,10 +2,8 @@ from typing import Type
|
|
2
2
|
from .storage import Storage
|
3
3
|
from .pickle_storage import PickleStorage
|
4
4
|
from .memory_storage import MemoryStorage
|
5
|
-
from .bcolz_storage import BcolzStorage
|
6
5
|
|
7
6
|
STORAGE_MAP: dict[str, Type[Storage]] = {
|
8
7
|
"pickle": PickleStorage,
|
9
8
|
"memory": MemoryStorage,
|
10
|
-
"bcolz": BcolzStorage,
|
11
9
|
}
|
@@ -1,8 +1,12 @@
|
|
1
1
|
import pickle
|
2
2
|
import shutil
|
3
3
|
from datetime import datetime
|
4
|
+
from pathlib import Path
|
4
5
|
from .storage import Storage
|
5
6
|
|
7
|
+
def filedate(path: Path) -> datetime:
|
8
|
+
return datetime.fromtimestamp(path.stat().st_mtime)
|
9
|
+
|
6
10
|
class PickleStorage(Storage):
|
7
11
|
def get_path(self, call_id: str):
|
8
12
|
return self.fn_dir() / f"{call_id}.pkl"
|
@@ -12,13 +16,14 @@ class PickleStorage(Storage):
|
|
12
16
|
path.parent.mkdir(parents=True, exist_ok=True)
|
13
17
|
with path.open("wb") as file:
|
14
18
|
pickle.dump(data, file, -1)
|
19
|
+
return data
|
15
20
|
|
16
21
|
def exists(self, call_id):
|
17
22
|
return self.get_path(call_id).exists()
|
18
23
|
|
19
24
|
def checkpoint_date(self, call_id):
|
20
25
|
# Should use st_atime/access time?
|
21
|
-
return
|
26
|
+
return filedate(self.get_path(call_id))
|
22
27
|
|
23
28
|
def load(self, call_id):
|
24
29
|
with self.get_path(call_id).open("rb") as file:
|
@@ -34,11 +39,11 @@ class PickleStorage(Storage):
|
|
34
39
|
old_dirs = [path for path in fn_path.iterdir() if path.is_dir() and path != version_path]
|
35
40
|
for path in old_dirs:
|
36
41
|
shutil.rmtree(path)
|
37
|
-
print(f"Removed {len(old_dirs)} invalidated directories for {self.
|
42
|
+
print(f"Removed {len(old_dirs)} invalidated directories for {self.cached_fn.__qualname__}")
|
38
43
|
if expired and self.checkpointer.should_expire:
|
39
44
|
count = 0
|
40
|
-
for pkl_path in fn_path.
|
41
|
-
if self.checkpointer.should_expire(
|
45
|
+
for pkl_path in fn_path.glob("**/*.pkl"):
|
46
|
+
if self.checkpointer.should_expire(filedate(pkl_path)):
|
42
47
|
count += 1
|
43
|
-
|
44
|
-
print(f"Removed {count} expired checkpoints for {self.
|
48
|
+
pkl_path.unlink(missing_ok=True)
|
49
|
+
print(f"Removed {count} expired checkpoints for {self.cached_fn.__qualname__}")
|
@@ -4,23 +4,23 @@ from pathlib import Path
|
|
4
4
|
from datetime import datetime
|
5
5
|
|
6
6
|
if TYPE_CHECKING:
|
7
|
-
from ..checkpoint import Checkpointer,
|
7
|
+
from ..checkpoint import Checkpointer, CachedFunction
|
8
8
|
|
9
9
|
class Storage:
|
10
10
|
checkpointer: Checkpointer
|
11
|
-
|
11
|
+
cached_fn: CachedFunction
|
12
12
|
|
13
|
-
def __init__(self,
|
14
|
-
self.checkpointer =
|
15
|
-
self.
|
13
|
+
def __init__(self, cached_fn: CachedFunction):
|
14
|
+
self.checkpointer = cached_fn.checkpointer
|
15
|
+
self.cached_fn = cached_fn
|
16
16
|
|
17
17
|
def fn_id(self) -> str:
|
18
|
-
return f"{self.
|
18
|
+
return f"{self.cached_fn.fn_dir}/{self.cached_fn.fn_hash}"
|
19
19
|
|
20
20
|
def fn_dir(self) -> Path:
|
21
21
|
return self.checkpointer.root_path / self.fn_id()
|
22
22
|
|
23
|
-
def store(self, call_id: str, data: Any) ->
|
23
|
+
def store(self, call_id: str, data: Any) -> Any: ...
|
24
24
|
|
25
25
|
def exists(self, call_id: str) -> bool: ...
|
26
26
|
|
@@ -32,10 +32,10 @@ def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
|
|
32
32
|
except ValueError:
|
33
33
|
pass
|
34
34
|
|
35
|
-
def unwrap_fn(fn: Fn,
|
36
|
-
from .checkpoint import
|
35
|
+
def unwrap_fn(fn: Fn, cached_fn=False) -> Fn:
|
36
|
+
from .checkpoint import CachedFunction
|
37
37
|
while True:
|
38
|
-
if (
|
38
|
+
if (cached_fn and isinstance(fn, CachedFunction)) or not hasattr(fn, "__wrapped__"):
|
39
39
|
return cast(Fn, fn)
|
40
40
|
fn = getattr(fn, "__wrapped__")
|
41
41
|
|
@@ -1,80 +0,0 @@
|
|
1
|
-
import shutil
|
2
|
-
from pathlib import Path
|
3
|
-
from datetime import datetime
|
4
|
-
from .storage import Storage
|
5
|
-
|
6
|
-
def get_data_type_str(x):
|
7
|
-
if isinstance(x, tuple):
|
8
|
-
return "tuple"
|
9
|
-
elif isinstance(x, dict):
|
10
|
-
return "dict"
|
11
|
-
elif isinstance(x, list):
|
12
|
-
return "list"
|
13
|
-
elif isinstance(x, str) or not hasattr(x, "__len__"):
|
14
|
-
return "other"
|
15
|
-
else:
|
16
|
-
return "ndarray"
|
17
|
-
|
18
|
-
def get_metapath(path: Path):
|
19
|
-
return path.with_name(f"{path.name}_meta")
|
20
|
-
|
21
|
-
def insert_data(path: Path, data):
|
22
|
-
import bcolz
|
23
|
-
c = bcolz.carray(data, rootdir=path, mode="w")
|
24
|
-
c.flush()
|
25
|
-
|
26
|
-
class BcolzStorage(Storage):
|
27
|
-
def exists(self, path):
|
28
|
-
return path.exists()
|
29
|
-
|
30
|
-
def checkpoint_date(self, path):
|
31
|
-
return datetime.fromtimestamp(path.stat().st_mtime)
|
32
|
-
|
33
|
-
def store(self, path, data):
|
34
|
-
metapath = get_metapath(path)
|
35
|
-
path.parent.mkdir(parents=True, exist_ok=True)
|
36
|
-
data_type_str = get_data_type_str(data)
|
37
|
-
if data_type_str == "tuple":
|
38
|
-
fields = list(range(len(data)))
|
39
|
-
elif data_type_str == "dict":
|
40
|
-
fields = sorted(data.keys())
|
41
|
-
else:
|
42
|
-
fields = []
|
43
|
-
meta_data = {"data_type_str": data_type_str, "fields": fields}
|
44
|
-
insert_data(metapath, meta_data)
|
45
|
-
if data_type_str in ["tuple", "dict"]:
|
46
|
-
for i in range(len(fields)):
|
47
|
-
child_path = Path(f"{path} ({i})")
|
48
|
-
self.store(child_path, data[fields[i]])
|
49
|
-
else:
|
50
|
-
insert_data(path, data)
|
51
|
-
|
52
|
-
def load(self, path):
|
53
|
-
import bcolz
|
54
|
-
metapath = get_metapath(path)
|
55
|
-
meta_data = bcolz.open(metapath)[:][0]
|
56
|
-
data_type_str = meta_data["data_type_str"]
|
57
|
-
if data_type_str in ["tuple", "dict"]:
|
58
|
-
fields = meta_data["fields"]
|
59
|
-
partitions = range(len(fields))
|
60
|
-
data = [self.load(Path(f"{path} ({i})")) for i in partitions]
|
61
|
-
if data_type_str == "tuple":
|
62
|
-
return tuple(data)
|
63
|
-
else:
|
64
|
-
return dict(zip(fields, data))
|
65
|
-
else:
|
66
|
-
data = bcolz.open(path)
|
67
|
-
if data_type_str == "list":
|
68
|
-
return list(data)
|
69
|
-
elif data_type_str == "other":
|
70
|
-
return data[0]
|
71
|
-
else:
|
72
|
-
return data[:]
|
73
|
-
|
74
|
-
def delete(self, path):
|
75
|
-
# NOTE: Not recursive
|
76
|
-
shutil.rmtree(get_metapath(path), ignore_errors=True)
|
77
|
-
shutil.rmtree(path, ignore_errors=True)
|
78
|
-
|
79
|
-
def cleanup(self, invalidated=True, expired=True):
|
80
|
-
raise NotImplementedError("cleanup() not implemented for bcolz storage")
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|