checkpointer 2.7.0__tar.gz → 2.8.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpointer
3
- Version: 2.7.0
3
+ Version: 2.8.0
4
4
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
5
5
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
6
  Author: Hampus Hallman
@@ -255,7 +255,7 @@ async def main():
255
255
  result2 = await async_compute_sum(3, 7)
256
256
  print(result2) # Outputs 10
257
257
 
258
- result3 = await async_compute_sum.get(3, 7)
258
+ result3 = async_compute_sum.get(3, 7)
259
259
  print(result3) # Outputs 10
260
260
 
261
261
  asyncio.run(main())
@@ -235,7 +235,7 @@ async def main():
235
235
  result2 = await async_compute_sum(3, 7)
236
236
  print(result2) # Outputs 10
237
237
 
238
- result3 = await async_compute_sum.get(3, 7)
238
+ result3 = async_compute_sum.get(3, 7)
239
239
  print(result3) # Outputs 10
240
240
 
241
241
  asyncio.run(main())
@@ -9,7 +9,8 @@ create_checkpointer = Checkpointer
9
9
  checkpoint = Checkpointer()
10
10
  capture_checkpoint = Checkpointer(capture=True)
11
11
  memory_checkpoint = Checkpointer(format="memory", verbosity=0)
12
- tmp_checkpoint = Checkpointer(root_path=tempfile.gettempdir() + "/checkpoints")
12
+ tmp_checkpoint = Checkpointer(root_path=f"{tempfile.gettempdir()}/checkpoints")
13
+ static_checkpoint = Checkpointer(fn_hash=ObjectHash())
13
14
 
14
15
  def cleanup_all(invalidated=True, expired=True):
15
16
  for obj in gc.get_objects():
@@ -4,14 +4,16 @@ import re
4
4
  from datetime import datetime
5
5
  from functools import update_wrapper
6
6
  from pathlib import Path
7
- from typing import Any, Callable, Generic, Iterable, Literal, Type, TypedDict, TypeVar, Unpack, cast, overload
7
+ from typing import Any, Awaitable, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
8
8
  from .fn_ident import get_fn_ident
9
9
  from .object_hash import ObjectHash
10
10
  from .print_checkpoint import print_checkpoint
11
11
  from .storages import STORAGE_MAP, Storage
12
- from .utils import resolved_awaitable, sync_resolve_coroutine, unwrap_fn
12
+ from .utils import AwaitableValue, unwrap_fn
13
13
 
14
14
  Fn = TypeVar("Fn", bound=Callable)
15
+ P = ParamSpec("P")
16
+ R = TypeVar("R")
15
17
 
16
18
  DEFAULT_DIR = Path.home() / ".cache/checkpoints"
17
19
 
@@ -70,7 +72,6 @@ class CheckpointFn(Generic[Fn]):
70
72
  deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
71
73
  self.fn_hash = str(params.fn_hash or ObjectHash().write_text(self.fn_hash_raw, *deep_hashes))
72
74
  self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
73
- self.is_async: bool = self.fn.is_async if isinstance(self.fn, CheckpointFn) else inspect.iscoroutinefunction(self.fn)
74
75
  self.storage = Storage(self)
75
76
  self.cleanup = self.storage.cleanup
76
77
 
@@ -96,7 +97,12 @@ class CheckpointFn(Generic[Fn]):
96
97
  call_hash = ObjectHash(hash_params, digest_size=16)
97
98
  return f"{self.fn_subdir}/{call_hash}"
98
99
 
99
- async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
100
+ async def _resolve_awaitable(self, checkpoint_path: Path, awaitable: Awaitable):
101
+ data = await awaitable
102
+ self.storage.store(checkpoint_path, AwaitableValue(data))
103
+ return data
104
+
105
+ def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
100
106
  params = self.checkpointer
101
107
  checkpoint_id = self.get_checkpoint_id(args, kw)
102
108
  checkpoint_path = params.root_path / checkpoint_id
@@ -107,10 +113,11 @@ class CheckpointFn(Generic[Fn]):
107
113
  if refresh:
108
114
  print_checkpoint(params.verbosity >= 1, "MEMORIZING", checkpoint_id, "blue")
109
115
  data = self.fn(*args, **kw)
110
- if inspect.iscoroutine(data):
111
- data = await data
112
- self.storage.store(checkpoint_path, data)
113
- return data
116
+ if inspect.isawaitable(data):
117
+ return self._resolve_awaitable(checkpoint_path, data)
118
+ else:
119
+ self.storage.store(checkpoint_path, data)
120
+ return data
114
121
 
115
122
  try:
116
123
  data = self.storage.load(checkpoint_path)
@@ -119,29 +126,33 @@ class CheckpointFn(Generic[Fn]):
119
126
  except (EOFError, FileNotFoundError):
120
127
  pass
121
128
  print_checkpoint(params.verbosity >= 1, "CORRUPTED", checkpoint_id, "yellow")
122
- return await self._store_on_demand(args, kw, True)
129
+ return self._store_on_demand(args, kw, True)
123
130
 
124
131
  def _call(self, args: tuple, kw: dict, rerun=False):
125
132
  if not self.checkpointer.when:
126
133
  return self.fn(*args, **kw)
127
- coroutine = self._store_on_demand(args, kw, rerun)
128
- return coroutine if self.is_async else sync_resolve_coroutine(coroutine)
134
+ return self._store_on_demand(args, kw, rerun)
135
+
136
+ __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
137
+ rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
138
+
139
+ @overload
140
+ def get(self: Callable[P, Awaitable[R]], *args: P.args, **kw: P.kwargs) -> R: ...
141
+ @overload
142
+ def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
129
143
 
130
- def _get(self, args, kw) -> Any:
144
+ def get(self, *args, **kw):
131
145
  checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
132
146
  try:
133
- val = self.storage.load(checkpoint_path)
134
- return resolved_awaitable(val) if self.is_async else val
147
+ data = self.storage.load(checkpoint_path)
148
+ return data.value if isinstance(data, AwaitableValue) else data
135
149
  except Exception as ex:
136
150
  raise CheckpointError("Could not load checkpoint") from ex
137
151
 
138
- def exists(self, *args: tuple, **kw: dict) -> bool:
152
+ def exists(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> bool: # type: ignore
153
+ self = cast(CheckpointFn, self)
139
154
  return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
140
155
 
141
- __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
142
- rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
143
- get: Fn = cast(Fn, lambda self, *args, **kw: self._get(args, kw))
144
-
145
156
  def __repr__(self) -> str:
146
157
  return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
147
158
 
@@ -56,6 +56,9 @@ class ObjectHash:
56
56
 
57
57
  __str__ = hexdigest
58
58
 
59
+ def __eq__(self, value: object) -> bool:
60
+ return isinstance(value, ObjectHash) and str(self) == str(value)
61
+
59
62
  def nested_hash(self, *objs: Any) -> str:
60
63
  return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()
61
64
 
@@ -6,7 +6,7 @@ from . import checkpoint
6
6
  from .checkpoint import CheckpointError
7
7
  from .utils import AttrDict
8
8
 
9
- def global_multiply(a, b):
9
+ def global_multiply(a: int, b: int) -> int:
10
10
  return a * b
11
11
 
12
12
  @pytest.fixture(autouse=True)
@@ -27,15 +27,15 @@ def test_basic_caching():
27
27
 
28
28
  def test_cache_invalidation():
29
29
  @checkpoint
30
- def multiply(a, b):
30
+ def multiply(a: int, b: int):
31
31
  return a * b
32
32
 
33
33
  @checkpoint
34
- def helper(x):
34
+ def helper(x: int):
35
35
  return multiply(x + 1, 2)
36
36
 
37
37
  @checkpoint
38
- def compute(a, b):
38
+ def compute(a: int, b: int):
39
39
  return helper(a) + helper(b)
40
40
 
41
41
  result1 = compute(3, 4)
@@ -46,7 +46,7 @@ def test_layered_caching():
46
46
 
47
47
  @checkpoint(format="memory")
48
48
  @dev_checkpoint
49
- def expensive_function(x):
49
+ def expensive_function(x: int):
50
50
  return x ** 2
51
51
 
52
52
  assert expensive_function(4) == 16
@@ -79,9 +79,10 @@ async def test_async_caching():
79
79
  return x ** 2
80
80
 
81
81
  result1 = await async_square(3)
82
- result2 = await async_square.get(3)
82
+ result2 = await async_square(3)
83
+ result3 = async_square.get(3)
83
84
 
84
- assert result1 == result2 == 9
85
+ assert result1 == result2 == result3 == 9
85
86
 
86
87
  def test_force_recalculation():
87
88
  @checkpoint
@@ -95,7 +96,7 @@ def test_force_recalculation():
95
96
  def test_multi_layer_decorator():
96
97
  @checkpoint(format="memory")
97
98
  @checkpoint(format="pickle")
98
- def add(a, b):
99
+ def add(a: int, b: int) -> int:
99
100
  return a + b
100
101
 
101
102
  assert add(2, 3) == 5
@@ -124,18 +125,18 @@ def test_capture():
124
125
  assert test_a.fn_hash != init_hash_a
125
126
 
126
127
  def test_depends():
127
- def multiply_wrapper(a, b):
128
+ def multiply_wrapper(a: int, b: int) -> int:
128
129
  return global_multiply(a, b)
129
130
 
130
- def helper(a, b):
131
+ def helper(a: int, b: int) -> int:
131
132
  return multiply_wrapper(a + 1, b + 1)
132
133
 
133
134
  @checkpoint
134
- def test_a(a, b):
135
+ def test_a(a: int, b: int) -> int:
135
136
  return helper(a, b)
136
137
 
137
138
  @checkpoint
138
- def test_b(a, b):
139
+ def test_b(a: int, b: int) -> int:
139
140
  return test_a(a, b) + multiply_wrapper(a, b)
140
141
 
141
142
  assert set(test_a.depends) == {test_a.fn, helper, multiply_wrapper, global_multiply}
@@ -143,17 +144,17 @@ def test_depends():
143
144
 
144
145
  def test_lazy_init():
145
146
  @checkpoint
146
- def fn1(x):
147
+ def fn1(x: object) -> object:
147
148
  return fn2(x)
148
149
 
149
150
  @checkpoint
150
- def fn2(x):
151
+ def fn2(x: object) -> object:
151
152
  return fn1(x)
152
153
 
153
- assert type(object.__getattribute__(fn1, "_getattribute")) == MethodType
154
+ assert type(object.__getattribute__(fn1, "_getattribute")) is MethodType
154
155
  with pytest.raises(AttributeError):
155
156
  object.__getattribute__(fn1, "fn_hash")
156
157
  assert fn1.fn_hash == object.__getattribute__(fn1, "fn_hash")
157
- assert type(object.__getattribute__(fn1, "_getattribute")) == MethodWrapperType
158
+ assert type(object.__getattribute__(fn1, "_getattribute")) is MethodWrapperType
158
159
  assert set(fn1.depends) == {fn1.fn, fn2}
159
160
  assert set(fn2.depends) == {fn1, fn2.fn}
@@ -2,11 +2,10 @@ import inspect
2
2
  import tokenize
3
3
  from contextlib import contextmanager
4
4
  from io import StringIO
5
- from types import coroutine
6
- from typing import Any, Callable, Coroutine, Generator, Generic, Iterable, TypeVar, cast
5
+ from typing import Any, Callable, Generic, Iterable, TypeVar, cast
7
6
 
8
7
  T = TypeVar("T")
9
- T_Callable = TypeVar("T_Callable", bound=Callable)
8
+ Fn = TypeVar("Fn", bound=Callable)
10
9
 
11
10
  def distinct(seq: Iterable[T]) -> list[T]:
12
11
  return list(dict.fromkeys(seq))
@@ -33,28 +32,20 @@ def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
33
32
  except ValueError:
34
33
  pass
35
34
 
36
- def unwrap_fn(fn: T_Callable, checkpoint_fn=False) -> T_Callable:
35
+ def unwrap_fn(fn: Fn, checkpoint_fn=False) -> Fn:
37
36
  from .checkpoint import CheckpointFn
38
37
  while True:
39
38
  if (checkpoint_fn and isinstance(fn, CheckpointFn)) or not hasattr(fn, "__wrapped__"):
40
- return cast(T_Callable, fn)
39
+ return cast(Fn, fn)
41
40
  fn = getattr(fn, "__wrapped__")
42
41
 
43
- async def resolved_awaitable(value: T) -> T:
44
- return value
45
-
46
- @coroutine
47
- def coroutine_as_generator(coroutine: Coroutine[None, None, T]) -> Generator[None, None, T]:
48
- val = yield from coroutine
49
- return val
42
+ class AwaitableValue:
43
+ def __init__(self, value):
44
+ self.value = value
50
45
 
51
- def sync_resolve_coroutine(coroutine: Coroutine[None, None, T]) -> T:
52
- gen = cast(Generator, coroutine_as_generator(coroutine))
53
- try:
54
- while True:
55
- next(gen)
56
- except StopIteration as ex:
57
- return ex.value
46
+ def __await__(self):
47
+ yield
48
+ return self.value
58
49
 
59
50
  class AttrDict(dict):
60
51
  def __init__(self, *args, **kwargs):
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "checkpointer"
3
- version = "2.7.0"
3
+ version = "2.8.0"
4
4
  requires-python = ">=3.11"
5
5
  dependencies = []
6
6
  authors = [
@@ -23,8 +23,8 @@ dev = [
23
23
  "numpy>=2.2.1",
24
24
  "omg>=1.2.3",
25
25
  "poethepoet>=0.30.0",
26
- "pytest>=8.3.3",
27
- "pytest-asyncio>=0.24.0",
26
+ "pytest>=8.3.5",
27
+ "pytest-asyncio>=0.26.0",
28
28
  "relib>=1.0.8",
29
29
  "torch>=2.5.1",
30
30
  ]
@@ -39,3 +39,6 @@ build-backend = "hatchling.build"
39
39
 
40
40
  [tool.hatch.build.targets.wheel]
41
41
  packages = ["checkpointer", "checkpointer.storages"]
42
+
43
+ [tool.pytest.ini_options]
44
+ asyncio_default_fixture_loop_scope = "session"
@@ -8,7 +8,7 @@ resolution-markers = [
8
8
 
9
9
  [[package]]
10
10
  name = "checkpointer"
11
- version = "2.7.0"
11
+ version = "2.8.0"
12
12
  source = { editable = "." }
13
13
 
14
14
  [package.dev-dependencies]
@@ -30,8 +30,8 @@ dev = [
30
30
  { name = "numpy", specifier = ">=2.2.1" },
31
31
  { name = "omg", specifier = ">=1.2.3" },
32
32
  { name = "poethepoet", specifier = ">=0.30.0" },
33
- { name = "pytest", specifier = ">=8.3.3" },
34
- { name = "pytest-asyncio", specifier = ">=0.24.0" },
33
+ { name = "pytest", specifier = ">=8.3.5" },
34
+ { name = "pytest-asyncio", specifier = ">=0.26.0" },
35
35
  { name = "relib", specifier = ">=1.0.8" },
36
36
  { name = "torch", specifier = ">=2.5.1" },
37
37
  ]
@@ -364,7 +364,7 @@ wheels = [
364
364
 
365
365
  [[package]]
366
366
  name = "pytest"
367
- version = "8.3.4"
367
+ version = "8.3.5"
368
368
  source = { registry = "https://pypi.org/simple" }
369
369
  dependencies = [
370
370
  { name = "colorama", marker = "sys_platform == 'win32'" },
@@ -372,21 +372,21 @@ dependencies = [
372
372
  { name = "packaging" },
373
373
  { name = "pluggy" },
374
374
  ]
375
- sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 }
375
+ sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 }
376
376
  wheels = [
377
- { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 },
377
+ { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 },
378
378
  ]
379
379
 
380
380
  [[package]]
381
381
  name = "pytest-asyncio"
382
- version = "0.25.0"
382
+ version = "0.26.0"
383
383
  source = { registry = "https://pypi.org/simple" }
384
384
  dependencies = [
385
385
  { name = "pytest" },
386
386
  ]
387
- sdist = { url = "https://files.pythonhosted.org/packages/94/18/82fcb4ee47d66d99f6cd1efc0b11b2a25029f303c599a5afda7c1bca4254/pytest_asyncio-0.25.0.tar.gz", hash = "sha256:8c0610303c9e0442a5db8604505fc0f545456ba1528824842b37b4a626cbf609", size = 53298 }
387
+ sdist = { url = "https://files.pythonhosted.org/packages/8e/c4/453c52c659521066969523e87d85d54139bbd17b78f09532fb8eb8cdb58e/pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f", size = 54156 }
388
388
  wheels = [
389
- { url = "https://files.pythonhosted.org/packages/88/56/2ee0cab25c11d4e38738a2a98c645a8f002e2ecf7b5ed774c70d53b92bb1/pytest_asyncio-0.25.0-py3-none-any.whl", hash = "sha256:db5432d18eac6b7e28b46dcd9b69921b55c3b1086e85febfe04e70b18d9e81b3", size = 19245 },
389
+ { url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694 },
390
390
  ]
391
391
 
392
392
  [[package]]
File without changes
File without changes