checkpointer 2.7.1__py3-none-any.whl → 2.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpointer/checkpoint.py +28 -20
- checkpointer/test_checkpointer.py +3 -2
- checkpointer/utils.py +10 -19
- {checkpointer-2.7.1.dist-info → checkpointer-2.8.1.dist-info}/METADATA +2 -2
- {checkpointer-2.7.1.dist-info → checkpointer-2.8.1.dist-info}/RECORD +7 -7
- {checkpointer-2.7.1.dist-info → checkpointer-2.8.1.dist-info}/WHEEL +0 -0
- {checkpointer-2.7.1.dist-info → checkpointer-2.8.1.dist-info}/licenses/LICENSE +0 -0
checkpointer/checkpoint.py
CHANGED
@@ -4,12 +4,12 @@ import re
|
|
4
4
|
from datetime import datetime
|
5
5
|
from functools import update_wrapper
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import Any, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
|
7
|
+
from typing import Any, Awaitable, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
|
8
8
|
from .fn_ident import get_fn_ident
|
9
9
|
from .object_hash import ObjectHash
|
10
10
|
from .print_checkpoint import print_checkpoint
|
11
11
|
from .storages import STORAGE_MAP, Storage
|
12
|
-
from .utils import
|
12
|
+
from .utils import AwaitableValue, unwrap_fn
|
13
13
|
|
14
14
|
Fn = TypeVar("Fn", bound=Callable)
|
15
15
|
P = ParamSpec("P")
|
@@ -70,9 +70,8 @@ class CheckpointFn(Generic[Fn]):
|
|
70
70
|
update_wrapper(cast(Callable, self), wrapped)
|
71
71
|
Storage = STORAGE_MAP[params.format] if isinstance(params.format, str) else params.format
|
72
72
|
deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
|
73
|
-
self.fn_hash = str(params.fn_hash or ObjectHash().write_text(self.fn_hash_raw, *deep_hashes))
|
74
|
-
self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:
|
75
|
-
self.is_async: bool = self.fn.is_async if isinstance(self.fn, CheckpointFn) else inspect.iscoroutinefunction(self.fn)
|
73
|
+
self.fn_hash = str(params.fn_hash or ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes))
|
74
|
+
self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:32]}"
|
76
75
|
self.storage = Storage(self)
|
77
76
|
self.cleanup = self.storage.cleanup
|
78
77
|
|
@@ -98,7 +97,12 @@ class CheckpointFn(Generic[Fn]):
|
|
98
97
|
call_hash = ObjectHash(hash_params, digest_size=16)
|
99
98
|
return f"{self.fn_subdir}/{call_hash}"
|
100
99
|
|
101
|
-
async def
|
100
|
+
async def _resolve_awaitable(self, checkpoint_path: Path, awaitable: Awaitable):
|
101
|
+
data = await awaitable
|
102
|
+
self.storage.store(checkpoint_path, AwaitableValue(data))
|
103
|
+
return data
|
104
|
+
|
105
|
+
def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
|
102
106
|
params = self.checkpointer
|
103
107
|
checkpoint_id = self.get_checkpoint_id(args, kw)
|
104
108
|
checkpoint_path = params.root_path / checkpoint_id
|
@@ -109,10 +113,11 @@ class CheckpointFn(Generic[Fn]):
|
|
109
113
|
if refresh:
|
110
114
|
print_checkpoint(params.verbosity >= 1, "MEMORIZING", checkpoint_id, "blue")
|
111
115
|
data = self.fn(*args, **kw)
|
112
|
-
if inspect.
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
+
if inspect.isawaitable(data):
|
117
|
+
return self._resolve_awaitable(checkpoint_path, data)
|
118
|
+
else:
|
119
|
+
self.storage.store(checkpoint_path, data)
|
120
|
+
return data
|
116
121
|
|
117
122
|
try:
|
118
123
|
data = self.storage.load(checkpoint_path)
|
@@ -121,20 +126,26 @@ class CheckpointFn(Generic[Fn]):
|
|
121
126
|
except (EOFError, FileNotFoundError):
|
122
127
|
pass
|
123
128
|
print_checkpoint(params.verbosity >= 1, "CORRUPTED", checkpoint_id, "yellow")
|
124
|
-
return
|
129
|
+
return self._store_on_demand(args, kw, True)
|
125
130
|
|
126
131
|
def _call(self, args: tuple, kw: dict, rerun=False):
|
127
132
|
if not self.checkpointer.when:
|
128
133
|
return self.fn(*args, **kw)
|
129
|
-
|
130
|
-
return coroutine if self.is_async else sync_resolve_coroutine(coroutine)
|
134
|
+
return self._store_on_demand(args, kw, rerun)
|
131
135
|
|
132
|
-
|
133
|
-
|
136
|
+
__call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
|
137
|
+
rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
|
138
|
+
|
139
|
+
@overload
|
140
|
+
def get(self: Callable[P, Awaitable[R]], *args: P.args, **kw: P.kwargs) -> R: ...
|
141
|
+
@overload
|
142
|
+
def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
|
143
|
+
|
144
|
+
def get(self, *args, **kw):
|
134
145
|
checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
|
135
146
|
try:
|
136
|
-
|
137
|
-
return
|
147
|
+
data = self.storage.load(checkpoint_path)
|
148
|
+
return data.value if isinstance(data, AwaitableValue) else data
|
138
149
|
except Exception as ex:
|
139
150
|
raise CheckpointError("Could not load checkpoint") from ex
|
140
151
|
|
@@ -142,9 +153,6 @@ class CheckpointFn(Generic[Fn]):
|
|
142
153
|
self = cast(CheckpointFn, self)
|
143
154
|
return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
|
144
155
|
|
145
|
-
__call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
|
146
|
-
rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
|
147
|
-
|
148
156
|
def __repr__(self) -> str:
|
149
157
|
return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
|
150
158
|
|
@@ -79,9 +79,10 @@ async def test_async_caching():
|
|
79
79
|
return x ** 2
|
80
80
|
|
81
81
|
result1 = await async_square(3)
|
82
|
-
result2 = await async_square
|
82
|
+
result2 = await async_square(3)
|
83
|
+
result3 = async_square.get(3)
|
83
84
|
|
84
|
-
assert result1 == result2 == 9
|
85
|
+
assert result1 == result2 == result3 == 9
|
85
86
|
|
86
87
|
def test_force_recalculation():
|
87
88
|
@checkpoint
|
checkpointer/utils.py
CHANGED
@@ -2,11 +2,10 @@ import inspect
|
|
2
2
|
import tokenize
|
3
3
|
from contextlib import contextmanager
|
4
4
|
from io import StringIO
|
5
|
-
from
|
6
|
-
from typing import Any, Callable, Coroutine, Generator, Generic, Iterable, TypeVar, cast
|
5
|
+
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
7
6
|
|
8
7
|
T = TypeVar("T")
|
9
|
-
|
8
|
+
Fn = TypeVar("Fn", bound=Callable)
|
10
9
|
|
11
10
|
def distinct(seq: Iterable[T]) -> list[T]:
|
12
11
|
return list(dict.fromkeys(seq))
|
@@ -33,28 +32,20 @@ def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
|
|
33
32
|
except ValueError:
|
34
33
|
pass
|
35
34
|
|
36
|
-
def unwrap_fn(fn:
|
35
|
+
def unwrap_fn(fn: Fn, checkpoint_fn=False) -> Fn:
|
37
36
|
from .checkpoint import CheckpointFn
|
38
37
|
while True:
|
39
38
|
if (checkpoint_fn and isinstance(fn, CheckpointFn)) or not hasattr(fn, "__wrapped__"):
|
40
|
-
return cast(
|
39
|
+
return cast(Fn, fn)
|
41
40
|
fn = getattr(fn, "__wrapped__")
|
42
41
|
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
@coroutine
|
47
|
-
def coroutine_as_generator(coroutine: Coroutine[None, None, T]) -> Generator[None, None, T]:
|
48
|
-
val = yield from coroutine
|
49
|
-
return val
|
42
|
+
class AwaitableValue:
|
43
|
+
def __init__(self, value):
|
44
|
+
self.value = value
|
50
45
|
|
51
|
-
def
|
52
|
-
|
53
|
-
|
54
|
-
while True:
|
55
|
-
next(gen)
|
56
|
-
except StopIteration as ex:
|
57
|
-
return ex.value
|
46
|
+
def __await__(self):
|
47
|
+
yield
|
48
|
+
return self.value
|
58
49
|
|
59
50
|
class AttrDict(dict):
|
60
51
|
def __init__(self, *args, **kwargs):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: checkpointer
|
3
|
-
Version: 2.7.1
|
3
|
+
Version: 2.8.1
|
4
4
|
Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
|
5
5
|
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
6
|
Author: Hampus Hallman
|
@@ -255,7 +255,7 @@ async def main():
|
|
255
255
|
result2 = await async_compute_sum(3, 7)
|
256
256
|
print(result2) # Outputs 10
|
257
257
|
|
258
|
-
result3 =
|
258
|
+
result3 = async_compute_sum.get(3, 7)
|
259
259
|
print(result3) # Outputs 10
|
260
260
|
|
261
261
|
asyncio.run(main())
|
@@ -1,16 +1,16 @@
|
|
1
1
|
checkpointer/__init__.py,sha256=HRLsQ24ZhxgmDcHchZ-hX6wA0NMCSedGA0NmCnUdS_c,832
|
2
|
-
checkpointer/checkpoint.py,sha256=
|
2
|
+
checkpointer/checkpoint.py,sha256=FZSKQVIXj_Enja0253WDUauO8siUHwlzcr4frHMJzB0,6538
|
3
3
|
checkpointer/fn_ident.py,sha256=SWaksNCTlskMom0ztqjECSRjZYPWXUA1p1ZCb-9tWo0,4297
|
4
4
|
checkpointer/object_hash.py,sha256=cxuWRDrg4F9wC18aC12zOZYOPv3bk2Qf6tZ0_WgAb6Y,7484
|
5
5
|
checkpointer/print_checkpoint.py,sha256=aJCeWMRJiIR3KpyPk_UOKTaD906kArGrmLGQ3LqcVgo,1369
|
6
|
-
checkpointer/test_checkpointer.py,sha256=
|
7
|
-
checkpointer/utils.py,sha256=
|
6
|
+
checkpointer/test_checkpointer.py,sha256=DM5f1Ci4z7MhDnxK-2Z-V3a3ntzBJqORQvCFROAf2SI,3807
|
7
|
+
checkpointer/utils.py,sha256=2yk1ksKszXofwnSqBrNCFRSR4C3YLPEAdHUFf-cviRU,2755
|
8
8
|
checkpointer/storages/__init__.py,sha256=Kl4Og5jhYxn6m3tB_kTMsabf4_eWVLmFVAoC-pikNQE,301
|
9
9
|
checkpointer/storages/bcolz_storage.py,sha256=3QkSUSeG5s2kFuVV_LZpzMn1A5E7kqC7jk7w35c0NyQ,2314
|
10
10
|
checkpointer/storages/memory_storage.py,sha256=S5ayOZE_CyaFQJ-vSgObTanldPzG3gh3NksjNAc7vsk,1282
|
11
11
|
checkpointer/storages/pickle_storage.py,sha256=idh9sBMdWuyvS220oa_7bAUpc9Xo9v6Ud9aYKGWasUs,1593
|
12
12
|
checkpointer/storages/storage.py,sha256=_m18Z8TKrdAbi6YYYQmuNOnhna4RB2sJDn1v3liaU3U,721
|
13
|
-
checkpointer-2.
|
14
|
-
checkpointer-2.
|
15
|
-
checkpointer-2.
|
16
|
-
checkpointer-2.
|
13
|
+
checkpointer-2.8.1.dist-info/METADATA,sha256=z0SbrcyKvHTOgJ6oq0dshV25fBAsqGjMCl_uWmxSFRM,10600
|
14
|
+
checkpointer-2.8.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
15
|
+
checkpointer-2.8.1.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
|
16
|
+
checkpointer-2.8.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|