checkpointer 2.7.0__tar.gz → 2.8.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpointer-2.7.0 → checkpointer-2.8.0}/PKG-INFO +2 -2
- {checkpointer-2.7.0 → checkpointer-2.8.0}/README.md +1 -1
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/__init__.py +2 -1
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/checkpoint.py +30 -19
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/object_hash.py +3 -0
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/test_checkpointer.py +17 -16
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/utils.py +10 -19
- {checkpointer-2.7.0 → checkpointer-2.8.0}/pyproject.toml +6 -3
- {checkpointer-2.7.0 → checkpointer-2.8.0}/uv.lock +9 -9
- {checkpointer-2.7.0 → checkpointer-2.8.0}/.gitignore +0 -0
- {checkpointer-2.7.0 → checkpointer-2.8.0}/.python-version +0 -0
- {checkpointer-2.7.0 → checkpointer-2.8.0}/LICENSE +0 -0
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/fn_ident.py +0 -0
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/print_checkpoint.py +0 -0
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/storages/__init__.py +0 -0
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/storages/bcolz_storage.py +0 -0
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/storages/memory_storage.py +0 -0
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/storages/pickle_storage.py +0 -0
- {checkpointer-2.7.0 → checkpointer-2.8.0}/checkpointer/storages/storage.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: checkpointer
|
3
|
-
Version: 2.7.0
|
3
|
+
Version: 2.8.0
|
4
4
|
Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
|
5
5
|
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
6
|
Author: Hampus Hallman
|
@@ -255,7 +255,7 @@ async def main():
|
|
255
255
|
result2 = await async_compute_sum(3, 7)
|
256
256
|
print(result2) # Outputs 10
|
257
257
|
|
258
|
-
result3 =
|
258
|
+
result3 = async_compute_sum.get(3, 7)
|
259
259
|
print(result3) # Outputs 10
|
260
260
|
|
261
261
|
asyncio.run(main())
|
@@ -9,7 +9,8 @@ create_checkpointer = Checkpointer
|
|
9
9
|
checkpoint = Checkpointer()
|
10
10
|
capture_checkpoint = Checkpointer(capture=True)
|
11
11
|
memory_checkpoint = Checkpointer(format="memory", verbosity=0)
|
12
|
-
tmp_checkpoint = Checkpointer(root_path=tempfile.gettempdir()
|
12
|
+
tmp_checkpoint = Checkpointer(root_path=f"{tempfile.gettempdir()}/checkpoints")
|
13
|
+
static_checkpoint = Checkpointer(fn_hash=ObjectHash())
|
13
14
|
|
14
15
|
def cleanup_all(invalidated=True, expired=True):
|
15
16
|
for obj in gc.get_objects():
|
@@ -4,14 +4,16 @@ import re
|
|
4
4
|
from datetime import datetime
|
5
5
|
from functools import update_wrapper
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import Any, Callable, Generic, Iterable, Literal, Type, TypedDict, TypeVar, Unpack, cast, overload
|
7
|
+
from typing import Any, Awaitable, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
|
8
8
|
from .fn_ident import get_fn_ident
|
9
9
|
from .object_hash import ObjectHash
|
10
10
|
from .print_checkpoint import print_checkpoint
|
11
11
|
from .storages import STORAGE_MAP, Storage
|
12
|
-
from .utils import
|
12
|
+
from .utils import AwaitableValue, unwrap_fn
|
13
13
|
|
14
14
|
Fn = TypeVar("Fn", bound=Callable)
|
15
|
+
P = ParamSpec("P")
|
16
|
+
R = TypeVar("R")
|
15
17
|
|
16
18
|
DEFAULT_DIR = Path.home() / ".cache/checkpoints"
|
17
19
|
|
@@ -70,7 +72,6 @@ class CheckpointFn(Generic[Fn]):
|
|
70
72
|
deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
|
71
73
|
self.fn_hash = str(params.fn_hash or ObjectHash().write_text(self.fn_hash_raw, *deep_hashes))
|
72
74
|
self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
|
73
|
-
self.is_async: bool = self.fn.is_async if isinstance(self.fn, CheckpointFn) else inspect.iscoroutinefunction(self.fn)
|
74
75
|
self.storage = Storage(self)
|
75
76
|
self.cleanup = self.storage.cleanup
|
76
77
|
|
@@ -96,7 +97,12 @@ class CheckpointFn(Generic[Fn]):
|
|
96
97
|
call_hash = ObjectHash(hash_params, digest_size=16)
|
97
98
|
return f"{self.fn_subdir}/{call_hash}"
|
98
99
|
|
99
|
-
async def
|
100
|
+
async def _resolve_awaitable(self, checkpoint_path: Path, awaitable: Awaitable):
|
101
|
+
data = await awaitable
|
102
|
+
self.storage.store(checkpoint_path, AwaitableValue(data))
|
103
|
+
return data
|
104
|
+
|
105
|
+
def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
|
100
106
|
params = self.checkpointer
|
101
107
|
checkpoint_id = self.get_checkpoint_id(args, kw)
|
102
108
|
checkpoint_path = params.root_path / checkpoint_id
|
@@ -107,10 +113,11 @@ class CheckpointFn(Generic[Fn]):
|
|
107
113
|
if refresh:
|
108
114
|
print_checkpoint(params.verbosity >= 1, "MEMORIZING", checkpoint_id, "blue")
|
109
115
|
data = self.fn(*args, **kw)
|
110
|
-
if inspect.
|
111
|
-
|
112
|
-
|
113
|
-
|
116
|
+
if inspect.isawaitable(data):
|
117
|
+
return self._resolve_awaitable(checkpoint_path, data)
|
118
|
+
else:
|
119
|
+
self.storage.store(checkpoint_path, data)
|
120
|
+
return data
|
114
121
|
|
115
122
|
try:
|
116
123
|
data = self.storage.load(checkpoint_path)
|
@@ -119,29 +126,33 @@ class CheckpointFn(Generic[Fn]):
|
|
119
126
|
except (EOFError, FileNotFoundError):
|
120
127
|
pass
|
121
128
|
print_checkpoint(params.verbosity >= 1, "CORRUPTED", checkpoint_id, "yellow")
|
122
|
-
return
|
129
|
+
return self._store_on_demand(args, kw, True)
|
123
130
|
|
124
131
|
def _call(self, args: tuple, kw: dict, rerun=False):
|
125
132
|
if not self.checkpointer.when:
|
126
133
|
return self.fn(*args, **kw)
|
127
|
-
|
128
|
-
|
134
|
+
return self._store_on_demand(args, kw, rerun)
|
135
|
+
|
136
|
+
__call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
|
137
|
+
rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
|
138
|
+
|
139
|
+
@overload
|
140
|
+
def get(self: Callable[P, Awaitable[R]], *args: P.args, **kw: P.kwargs) -> R: ...
|
141
|
+
@overload
|
142
|
+
def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
|
129
143
|
|
130
|
-
def
|
144
|
+
def get(self, *args, **kw):
|
131
145
|
checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
|
132
146
|
try:
|
133
|
-
|
134
|
-
return
|
147
|
+
data = self.storage.load(checkpoint_path)
|
148
|
+
return data.value if isinstance(data, AwaitableValue) else data
|
135
149
|
except Exception as ex:
|
136
150
|
raise CheckpointError("Could not load checkpoint") from ex
|
137
151
|
|
138
|
-
def exists(self, *args:
|
152
|
+
def exists(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> bool: # type: ignore
|
153
|
+
self = cast(CheckpointFn, self)
|
139
154
|
return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
|
140
155
|
|
141
|
-
__call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
|
142
|
-
rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
|
143
|
-
get: Fn = cast(Fn, lambda self, *args, **kw: self._get(args, kw))
|
144
|
-
|
145
156
|
def __repr__(self) -> str:
|
146
157
|
return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
|
147
158
|
|
@@ -56,6 +56,9 @@ class ObjectHash:
|
|
56
56
|
|
57
57
|
__str__ = hexdigest
|
58
58
|
|
59
|
+
def __eq__(self, value: object) -> bool:
|
60
|
+
return isinstance(value, ObjectHash) and str(self) == str(value)
|
61
|
+
|
59
62
|
def nested_hash(self, *objs: Any) -> str:
|
60
63
|
return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()
|
61
64
|
|
@@ -6,7 +6,7 @@ from . import checkpoint
|
|
6
6
|
from .checkpoint import CheckpointError
|
7
7
|
from .utils import AttrDict
|
8
8
|
|
9
|
-
def global_multiply(a, b):
|
9
|
+
def global_multiply(a: int, b: int) -> int:
|
10
10
|
return a * b
|
11
11
|
|
12
12
|
@pytest.fixture(autouse=True)
|
@@ -27,15 +27,15 @@ def test_basic_caching():
|
|
27
27
|
|
28
28
|
def test_cache_invalidation():
|
29
29
|
@checkpoint
|
30
|
-
def multiply(a, b):
|
30
|
+
def multiply(a: int, b: int):
|
31
31
|
return a * b
|
32
32
|
|
33
33
|
@checkpoint
|
34
|
-
def helper(x):
|
34
|
+
def helper(x: int):
|
35
35
|
return multiply(x + 1, 2)
|
36
36
|
|
37
37
|
@checkpoint
|
38
|
-
def compute(a, b):
|
38
|
+
def compute(a: int, b: int):
|
39
39
|
return helper(a) + helper(b)
|
40
40
|
|
41
41
|
result1 = compute(3, 4)
|
@@ -46,7 +46,7 @@ def test_layered_caching():
|
|
46
46
|
|
47
47
|
@checkpoint(format="memory")
|
48
48
|
@dev_checkpoint
|
49
|
-
def expensive_function(x):
|
49
|
+
def expensive_function(x: int):
|
50
50
|
return x ** 2
|
51
51
|
|
52
52
|
assert expensive_function(4) == 16
|
@@ -79,9 +79,10 @@ async def test_async_caching():
|
|
79
79
|
return x ** 2
|
80
80
|
|
81
81
|
result1 = await async_square(3)
|
82
|
-
result2 = await async_square
|
82
|
+
result2 = await async_square(3)
|
83
|
+
result3 = async_square.get(3)
|
83
84
|
|
84
|
-
assert result1 == result2 == 9
|
85
|
+
assert result1 == result2 == result3 == 9
|
85
86
|
|
86
87
|
def test_force_recalculation():
|
87
88
|
@checkpoint
|
@@ -95,7 +96,7 @@ def test_force_recalculation():
|
|
95
96
|
def test_multi_layer_decorator():
|
96
97
|
@checkpoint(format="memory")
|
97
98
|
@checkpoint(format="pickle")
|
98
|
-
def add(a, b):
|
99
|
+
def add(a: int, b: int) -> int:
|
99
100
|
return a + b
|
100
101
|
|
101
102
|
assert add(2, 3) == 5
|
@@ -124,18 +125,18 @@ def test_capture():
|
|
124
125
|
assert test_a.fn_hash != init_hash_a
|
125
126
|
|
126
127
|
def test_depends():
|
127
|
-
def multiply_wrapper(a, b):
|
128
|
+
def multiply_wrapper(a: int, b: int) -> int:
|
128
129
|
return global_multiply(a, b)
|
129
130
|
|
130
|
-
def helper(a, b):
|
131
|
+
def helper(a: int, b: int) -> int:
|
131
132
|
return multiply_wrapper(a + 1, b + 1)
|
132
133
|
|
133
134
|
@checkpoint
|
134
|
-
def test_a(a, b):
|
135
|
+
def test_a(a: int, b: int) -> int:
|
135
136
|
return helper(a, b)
|
136
137
|
|
137
138
|
@checkpoint
|
138
|
-
def test_b(a, b):
|
139
|
+
def test_b(a: int, b: int) -> int:
|
139
140
|
return test_a(a, b) + multiply_wrapper(a, b)
|
140
141
|
|
141
142
|
assert set(test_a.depends) == {test_a.fn, helper, multiply_wrapper, global_multiply}
|
@@ -143,17 +144,17 @@ def test_depends():
|
|
143
144
|
|
144
145
|
def test_lazy_init():
|
145
146
|
@checkpoint
|
146
|
-
def fn1(x):
|
147
|
+
def fn1(x: object) -> object:
|
147
148
|
return fn2(x)
|
148
149
|
|
149
150
|
@checkpoint
|
150
|
-
def fn2(x):
|
151
|
+
def fn2(x: object) -> object:
|
151
152
|
return fn1(x)
|
152
153
|
|
153
|
-
assert type(object.__getattribute__(fn1, "_getattribute"))
|
154
|
+
assert type(object.__getattribute__(fn1, "_getattribute")) is MethodType
|
154
155
|
with pytest.raises(AttributeError):
|
155
156
|
object.__getattribute__(fn1, "fn_hash")
|
156
157
|
assert fn1.fn_hash == object.__getattribute__(fn1, "fn_hash")
|
157
|
-
assert type(object.__getattribute__(fn1, "_getattribute"))
|
158
|
+
assert type(object.__getattribute__(fn1, "_getattribute")) is MethodWrapperType
|
158
159
|
assert set(fn1.depends) == {fn1.fn, fn2}
|
159
160
|
assert set(fn2.depends) == {fn1, fn2.fn}
|
@@ -2,11 +2,10 @@ import inspect
|
|
2
2
|
import tokenize
|
3
3
|
from contextlib import contextmanager
|
4
4
|
from io import StringIO
|
5
|
-
from
|
6
|
-
from typing import Any, Callable, Coroutine, Generator, Generic, Iterable, TypeVar, cast
|
5
|
+
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
7
6
|
|
8
7
|
T = TypeVar("T")
|
9
|
-
|
8
|
+
Fn = TypeVar("Fn", bound=Callable)
|
10
9
|
|
11
10
|
def distinct(seq: Iterable[T]) -> list[T]:
|
12
11
|
return list(dict.fromkeys(seq))
|
@@ -33,28 +32,20 @@ def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
|
|
33
32
|
except ValueError:
|
34
33
|
pass
|
35
34
|
|
36
|
-
def unwrap_fn(fn:
|
35
|
+
def unwrap_fn(fn: Fn, checkpoint_fn=False) -> Fn:
|
37
36
|
from .checkpoint import CheckpointFn
|
38
37
|
while True:
|
39
38
|
if (checkpoint_fn and isinstance(fn, CheckpointFn)) or not hasattr(fn, "__wrapped__"):
|
40
|
-
return cast(
|
39
|
+
return cast(Fn, fn)
|
41
40
|
fn = getattr(fn, "__wrapped__")
|
42
41
|
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
@coroutine
|
47
|
-
def coroutine_as_generator(coroutine: Coroutine[None, None, T]) -> Generator[None, None, T]:
|
48
|
-
val = yield from coroutine
|
49
|
-
return val
|
42
|
+
class AwaitableValue:
|
43
|
+
def __init__(self, value):
|
44
|
+
self.value = value
|
50
45
|
|
51
|
-
def
|
52
|
-
|
53
|
-
|
54
|
-
while True:
|
55
|
-
next(gen)
|
56
|
-
except StopIteration as ex:
|
57
|
-
return ex.value
|
46
|
+
def __await__(self):
|
47
|
+
yield
|
48
|
+
return self.value
|
58
49
|
|
59
50
|
class AttrDict(dict):
|
60
51
|
def __init__(self, *args, **kwargs):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "checkpointer"
|
3
|
-
version = "2.7.0"
|
3
|
+
version = "2.8.0"
|
4
4
|
requires-python = ">=3.11"
|
5
5
|
dependencies = []
|
6
6
|
authors = [
|
@@ -23,8 +23,8 @@ dev = [
|
|
23
23
|
"numpy>=2.2.1",
|
24
24
|
"omg>=1.2.3",
|
25
25
|
"poethepoet>=0.30.0",
|
26
|
-
"pytest>=8.3.
|
27
|
-
"pytest-asyncio>=0.
|
26
|
+
"pytest>=8.3.5",
|
27
|
+
"pytest-asyncio>=0.26.0",
|
28
28
|
"relib>=1.0.8",
|
29
29
|
"torch>=2.5.1",
|
30
30
|
]
|
@@ -39,3 +39,6 @@ build-backend = "hatchling.build"
|
|
39
39
|
|
40
40
|
[tool.hatch.build.targets.wheel]
|
41
41
|
packages = ["checkpointer", "checkpointer.storages"]
|
42
|
+
|
43
|
+
[tool.pytest.ini_options]
|
44
|
+
asyncio_default_fixture_loop_scope = "session"
|
@@ -8,7 +8,7 @@ resolution-markers = [
|
|
8
8
|
|
9
9
|
[[package]]
|
10
10
|
name = "checkpointer"
|
11
|
-
version = "2.7.0"
|
11
|
+
version = "2.8.0"
|
12
12
|
source = { editable = "." }
|
13
13
|
|
14
14
|
[package.dev-dependencies]
|
@@ -30,8 +30,8 @@ dev = [
|
|
30
30
|
{ name = "numpy", specifier = ">=2.2.1" },
|
31
31
|
{ name = "omg", specifier = ">=1.2.3" },
|
32
32
|
{ name = "poethepoet", specifier = ">=0.30.0" },
|
33
|
-
{ name = "pytest", specifier = ">=8.3.
|
34
|
-
{ name = "pytest-asyncio", specifier = ">=0.
|
33
|
+
{ name = "pytest", specifier = ">=8.3.5" },
|
34
|
+
{ name = "pytest-asyncio", specifier = ">=0.26.0" },
|
35
35
|
{ name = "relib", specifier = ">=1.0.8" },
|
36
36
|
{ name = "torch", specifier = ">=2.5.1" },
|
37
37
|
]
|
@@ -364,7 +364,7 @@ wheels = [
|
|
364
364
|
|
365
365
|
[[package]]
|
366
366
|
name = "pytest"
|
367
|
-
version = "8.3.
|
367
|
+
version = "8.3.5"
|
368
368
|
source = { registry = "https://pypi.org/simple" }
|
369
369
|
dependencies = [
|
370
370
|
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
@@ -372,21 +372,21 @@ dependencies = [
|
|
372
372
|
{ name = "packaging" },
|
373
373
|
{ name = "pluggy" },
|
374
374
|
]
|
375
|
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
375
|
+
sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 }
|
376
376
|
wheels = [
|
377
|
-
{ url = "https://files.pythonhosted.org/packages/
|
377
|
+
{ url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 },
|
378
378
|
]
|
379
379
|
|
380
380
|
[[package]]
|
381
381
|
name = "pytest-asyncio"
|
382
|
-
version = "0.
|
382
|
+
version = "0.26.0"
|
383
383
|
source = { registry = "https://pypi.org/simple" }
|
384
384
|
dependencies = [
|
385
385
|
{ name = "pytest" },
|
386
386
|
]
|
387
|
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
387
|
+
sdist = { url = "https://files.pythonhosted.org/packages/8e/c4/453c52c659521066969523e87d85d54139bbd17b78f09532fb8eb8cdb58e/pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f", size = 54156 }
|
388
388
|
wheels = [
|
389
|
-
{ url = "https://files.pythonhosted.org/packages/
|
389
|
+
{ url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694 },
|
390
390
|
]
|
391
391
|
|
392
392
|
[[package]]
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|