checkpointer 2.7.1__tar.gz → 2.8.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpointer-2.7.1 → checkpointer-2.8.1}/PKG-INFO +2 -2
- {checkpointer-2.7.1 → checkpointer-2.8.1}/README.md +1 -1
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/checkpoint.py +28 -20
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/test_checkpointer.py +3 -2
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/utils.py +10 -19
- {checkpointer-2.7.1 → checkpointer-2.8.1}/pyproject.toml +6 -3
- {checkpointer-2.7.1 → checkpointer-2.8.1}/uv.lock +9 -9
- {checkpointer-2.7.1 → checkpointer-2.8.1}/.gitignore +0 -0
- {checkpointer-2.7.1 → checkpointer-2.8.1}/.python-version +0 -0
- {checkpointer-2.7.1 → checkpointer-2.8.1}/LICENSE +0 -0
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/__init__.py +0 -0
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/fn_ident.py +0 -0
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/object_hash.py +0 -0
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/print_checkpoint.py +0 -0
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/storages/__init__.py +0 -0
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/storages/bcolz_storage.py +0 -0
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/storages/memory_storage.py +0 -0
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/storages/pickle_storage.py +0 -0
- {checkpointer-2.7.1 → checkpointer-2.8.1}/checkpointer/storages/storage.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: checkpointer
|
3
|
-
Version: 2.7.1
|
3
|
+
Version: 2.8.1
|
4
4
|
Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
|
5
5
|
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
6
|
Author: Hampus Hallman
|
@@ -255,7 +255,7 @@ async def main():
|
|
255
255
|
result2 = await async_compute_sum(3, 7)
|
256
256
|
print(result2) # Outputs 10
|
257
257
|
|
258
|
-
result3 =
|
258
|
+
result3 = async_compute_sum.get(3, 7)
|
259
259
|
print(result3) # Outputs 10
|
260
260
|
|
261
261
|
asyncio.run(main())
|
@@ -4,12 +4,12 @@ import re
|
|
4
4
|
from datetime import datetime
|
5
5
|
from functools import update_wrapper
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import Any, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
|
7
|
+
from typing import Any, Awaitable, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
|
8
8
|
from .fn_ident import get_fn_ident
|
9
9
|
from .object_hash import ObjectHash
|
10
10
|
from .print_checkpoint import print_checkpoint
|
11
11
|
from .storages import STORAGE_MAP, Storage
|
12
|
-
from .utils import
|
12
|
+
from .utils import AwaitableValue, unwrap_fn
|
13
13
|
|
14
14
|
Fn = TypeVar("Fn", bound=Callable)
|
15
15
|
P = ParamSpec("P")
|
@@ -70,9 +70,8 @@ class CheckpointFn(Generic[Fn]):
|
|
70
70
|
update_wrapper(cast(Callable, self), wrapped)
|
71
71
|
Storage = STORAGE_MAP[params.format] if isinstance(params.format, str) else params.format
|
72
72
|
deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
|
73
|
-
self.fn_hash = str(params.fn_hash or ObjectHash().write_text(self.fn_hash_raw, *deep_hashes))
|
74
|
-
self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:
|
75
|
-
self.is_async: bool = self.fn.is_async if isinstance(self.fn, CheckpointFn) else inspect.iscoroutinefunction(self.fn)
|
73
|
+
self.fn_hash = str(params.fn_hash or ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes))
|
74
|
+
self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:32]}"
|
76
75
|
self.storage = Storage(self)
|
77
76
|
self.cleanup = self.storage.cleanup
|
78
77
|
|
@@ -98,7 +97,12 @@ class CheckpointFn(Generic[Fn]):
|
|
98
97
|
call_hash = ObjectHash(hash_params, digest_size=16)
|
99
98
|
return f"{self.fn_subdir}/{call_hash}"
|
100
99
|
|
101
|
-
async def
|
100
|
+
async def _resolve_awaitable(self, checkpoint_path: Path, awaitable: Awaitable):
|
101
|
+
data = await awaitable
|
102
|
+
self.storage.store(checkpoint_path, AwaitableValue(data))
|
103
|
+
return data
|
104
|
+
|
105
|
+
def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
|
102
106
|
params = self.checkpointer
|
103
107
|
checkpoint_id = self.get_checkpoint_id(args, kw)
|
104
108
|
checkpoint_path = params.root_path / checkpoint_id
|
@@ -109,10 +113,11 @@ class CheckpointFn(Generic[Fn]):
|
|
109
113
|
if refresh:
|
110
114
|
print_checkpoint(params.verbosity >= 1, "MEMORIZING", checkpoint_id, "blue")
|
111
115
|
data = self.fn(*args, **kw)
|
112
|
-
if inspect.
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
+
if inspect.isawaitable(data):
|
117
|
+
return self._resolve_awaitable(checkpoint_path, data)
|
118
|
+
else:
|
119
|
+
self.storage.store(checkpoint_path, data)
|
120
|
+
return data
|
116
121
|
|
117
122
|
try:
|
118
123
|
data = self.storage.load(checkpoint_path)
|
@@ -121,20 +126,26 @@ class CheckpointFn(Generic[Fn]):
|
|
121
126
|
except (EOFError, FileNotFoundError):
|
122
127
|
pass
|
123
128
|
print_checkpoint(params.verbosity >= 1, "CORRUPTED", checkpoint_id, "yellow")
|
124
|
-
return
|
129
|
+
return self._store_on_demand(args, kw, True)
|
125
130
|
|
126
131
|
def _call(self, args: tuple, kw: dict, rerun=False):
|
127
132
|
if not self.checkpointer.when:
|
128
133
|
return self.fn(*args, **kw)
|
129
|
-
|
130
|
-
return coroutine if self.is_async else sync_resolve_coroutine(coroutine)
|
134
|
+
return self._store_on_demand(args, kw, rerun)
|
131
135
|
|
132
|
-
|
133
|
-
|
136
|
+
__call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
|
137
|
+
rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
|
138
|
+
|
139
|
+
@overload
|
140
|
+
def get(self: Callable[P, Awaitable[R]], *args: P.args, **kw: P.kwargs) -> R: ...
|
141
|
+
@overload
|
142
|
+
def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
|
143
|
+
|
144
|
+
def get(self, *args, **kw):
|
134
145
|
checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
|
135
146
|
try:
|
136
|
-
|
137
|
-
return
|
147
|
+
data = self.storage.load(checkpoint_path)
|
148
|
+
return data.value if isinstance(data, AwaitableValue) else data
|
138
149
|
except Exception as ex:
|
139
150
|
raise CheckpointError("Could not load checkpoint") from ex
|
140
151
|
|
@@ -142,9 +153,6 @@ class CheckpointFn(Generic[Fn]):
|
|
142
153
|
self = cast(CheckpointFn, self)
|
143
154
|
return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
|
144
155
|
|
145
|
-
__call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
|
146
|
-
rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
|
147
|
-
|
148
156
|
def __repr__(self) -> str:
|
149
157
|
return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
|
150
158
|
|
@@ -79,9 +79,10 @@ async def test_async_caching():
|
|
79
79
|
return x ** 2
|
80
80
|
|
81
81
|
result1 = await async_square(3)
|
82
|
-
result2 = await async_square
|
82
|
+
result2 = await async_square(3)
|
83
|
+
result3 = async_square.get(3)
|
83
84
|
|
84
|
-
assert result1 == result2 == 9
|
85
|
+
assert result1 == result2 == result3 == 9
|
85
86
|
|
86
87
|
def test_force_recalculation():
|
87
88
|
@checkpoint
|
@@ -2,11 +2,10 @@ import inspect
|
|
2
2
|
import tokenize
|
3
3
|
from contextlib import contextmanager
|
4
4
|
from io import StringIO
|
5
|
-
from
|
6
|
-
from typing import Any, Callable, Coroutine, Generator, Generic, Iterable, TypeVar, cast
|
5
|
+
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
7
6
|
|
8
7
|
T = TypeVar("T")
|
9
|
-
|
8
|
+
Fn = TypeVar("Fn", bound=Callable)
|
10
9
|
|
11
10
|
def distinct(seq: Iterable[T]) -> list[T]:
|
12
11
|
return list(dict.fromkeys(seq))
|
@@ -33,28 +32,20 @@ def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
|
|
33
32
|
except ValueError:
|
34
33
|
pass
|
35
34
|
|
36
|
-
def unwrap_fn(fn:
|
35
|
+
def unwrap_fn(fn: Fn, checkpoint_fn=False) -> Fn:
|
37
36
|
from .checkpoint import CheckpointFn
|
38
37
|
while True:
|
39
38
|
if (checkpoint_fn and isinstance(fn, CheckpointFn)) or not hasattr(fn, "__wrapped__"):
|
40
|
-
return cast(
|
39
|
+
return cast(Fn, fn)
|
41
40
|
fn = getattr(fn, "__wrapped__")
|
42
41
|
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
@coroutine
|
47
|
-
def coroutine_as_generator(coroutine: Coroutine[None, None, T]) -> Generator[None, None, T]:
|
48
|
-
val = yield from coroutine
|
49
|
-
return val
|
42
|
+
class AwaitableValue:
|
43
|
+
def __init__(self, value):
|
44
|
+
self.value = value
|
50
45
|
|
51
|
-
def
|
52
|
-
|
53
|
-
|
54
|
-
while True:
|
55
|
-
next(gen)
|
56
|
-
except StopIteration as ex:
|
57
|
-
return ex.value
|
46
|
+
def __await__(self):
|
47
|
+
yield
|
48
|
+
return self.value
|
58
49
|
|
59
50
|
class AttrDict(dict):
|
60
51
|
def __init__(self, *args, **kwargs):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "checkpointer"
|
3
|
-
version = "2.7.1"
|
3
|
+
version = "2.8.1"
|
4
4
|
requires-python = ">=3.11"
|
5
5
|
dependencies = []
|
6
6
|
authors = [
|
@@ -23,8 +23,8 @@ dev = [
|
|
23
23
|
"numpy>=2.2.1",
|
24
24
|
"omg>=1.2.3",
|
25
25
|
"poethepoet>=0.30.0",
|
26
|
-
"pytest>=8.3.
|
27
|
-
"pytest-asyncio>=0.
|
26
|
+
"pytest>=8.3.5",
|
27
|
+
"pytest-asyncio>=0.26.0",
|
28
28
|
"relib>=1.0.8",
|
29
29
|
"torch>=2.5.1",
|
30
30
|
]
|
@@ -39,3 +39,6 @@ build-backend = "hatchling.build"
|
|
39
39
|
|
40
40
|
[tool.hatch.build.targets.wheel]
|
41
41
|
packages = ["checkpointer", "checkpointer.storages"]
|
42
|
+
|
43
|
+
[tool.pytest.ini_options]
|
44
|
+
asyncio_default_fixture_loop_scope = "session"
|
@@ -8,7 +8,7 @@ resolution-markers = [
|
|
8
8
|
|
9
9
|
[[package]]
|
10
10
|
name = "checkpointer"
|
11
|
-
version = "2.7.1"
|
11
|
+
version = "2.8.1"
|
12
12
|
source = { editable = "." }
|
13
13
|
|
14
14
|
[package.dev-dependencies]
|
@@ -30,8 +30,8 @@ dev = [
|
|
30
30
|
{ name = "numpy", specifier = ">=2.2.1" },
|
31
31
|
{ name = "omg", specifier = ">=1.2.3" },
|
32
32
|
{ name = "poethepoet", specifier = ">=0.30.0" },
|
33
|
-
{ name = "pytest", specifier = ">=8.3.
|
34
|
-
{ name = "pytest-asyncio", specifier = ">=0.
|
33
|
+
{ name = "pytest", specifier = ">=8.3.5" },
|
34
|
+
{ name = "pytest-asyncio", specifier = ">=0.26.0" },
|
35
35
|
{ name = "relib", specifier = ">=1.0.8" },
|
36
36
|
{ name = "torch", specifier = ">=2.5.1" },
|
37
37
|
]
|
@@ -364,7 +364,7 @@ wheels = [
|
|
364
364
|
|
365
365
|
[[package]]
|
366
366
|
name = "pytest"
|
367
|
-
version = "8.3.
|
367
|
+
version = "8.3.5"
|
368
368
|
source = { registry = "https://pypi.org/simple" }
|
369
369
|
dependencies = [
|
370
370
|
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
@@ -372,21 +372,21 @@ dependencies = [
|
|
372
372
|
{ name = "packaging" },
|
373
373
|
{ name = "pluggy" },
|
374
374
|
]
|
375
|
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
375
|
+
sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 }
|
376
376
|
wheels = [
|
377
|
-
{ url = "https://files.pythonhosted.org/packages/
|
377
|
+
{ url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 },
|
378
378
|
]
|
379
379
|
|
380
380
|
[[package]]
|
381
381
|
name = "pytest-asyncio"
|
382
|
-
version = "0.
|
382
|
+
version = "0.26.0"
|
383
383
|
source = { registry = "https://pypi.org/simple" }
|
384
384
|
dependencies = [
|
385
385
|
{ name = "pytest" },
|
386
386
|
]
|
387
|
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
387
|
+
sdist = { url = "https://files.pythonhosted.org/packages/8e/c4/453c52c659521066969523e87d85d54139bbd17b78f09532fb8eb8cdb58e/pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f", size = 54156 }
|
388
388
|
wheels = [
|
389
|
-
{ url = "https://files.pythonhosted.org/packages/
|
389
|
+
{ url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694 },
|
390
390
|
]
|
391
391
|
|
392
392
|
[[package]]
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|