checkpointer 2.7.1__tar.gz → 2.8.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpointer
3
- Version: 2.7.1
3
+ Version: 2.8.0
4
4
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
5
5
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
6
  Author: Hampus Hallman
@@ -255,7 +255,7 @@ async def main():
255
255
  result2 = await async_compute_sum(3, 7)
256
256
  print(result2) # Outputs 10
257
257
 
258
- result3 = await async_compute_sum.get(3, 7)
258
+ result3 = async_compute_sum.get(3, 7)
259
259
  print(result3) # Outputs 10
260
260
 
261
261
  asyncio.run(main())
@@ -235,7 +235,7 @@ async def main():
235
235
  result2 = await async_compute_sum(3, 7)
236
236
  print(result2) # Outputs 10
237
237
 
238
- result3 = await async_compute_sum.get(3, 7)
238
+ result3 = async_compute_sum.get(3, 7)
239
239
  print(result3) # Outputs 10
240
240
 
241
241
  asyncio.run(main())
@@ -4,12 +4,12 @@ import re
4
4
  from datetime import datetime
5
5
  from functools import update_wrapper
6
6
  from pathlib import Path
7
- from typing import Any, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
7
+ from typing import Any, Awaitable, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
8
8
  from .fn_ident import get_fn_ident
9
9
  from .object_hash import ObjectHash
10
10
  from .print_checkpoint import print_checkpoint
11
11
  from .storages import STORAGE_MAP, Storage
12
- from .utils import resolved_awaitable, sync_resolve_coroutine, unwrap_fn
12
+ from .utils import AwaitableValue, unwrap_fn
13
13
 
14
14
  Fn = TypeVar("Fn", bound=Callable)
15
15
  P = ParamSpec("P")
@@ -72,7 +72,6 @@ class CheckpointFn(Generic[Fn]):
72
72
  deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
73
73
  self.fn_hash = str(params.fn_hash or ObjectHash().write_text(self.fn_hash_raw, *deep_hashes))
74
74
  self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
75
- self.is_async: bool = self.fn.is_async if isinstance(self.fn, CheckpointFn) else inspect.iscoroutinefunction(self.fn)
76
75
  self.storage = Storage(self)
77
76
  self.cleanup = self.storage.cleanup
78
77
 
@@ -98,7 +97,12 @@ class CheckpointFn(Generic[Fn]):
98
97
  call_hash = ObjectHash(hash_params, digest_size=16)
99
98
  return f"{self.fn_subdir}/{call_hash}"
100
99
 
101
- async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
100
+ async def _resolve_awaitable(self, checkpoint_path: Path, awaitable: Awaitable):
101
+ data = await awaitable
102
+ self.storage.store(checkpoint_path, AwaitableValue(data))
103
+ return data
104
+
105
+ def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
102
106
  params = self.checkpointer
103
107
  checkpoint_id = self.get_checkpoint_id(args, kw)
104
108
  checkpoint_path = params.root_path / checkpoint_id
@@ -109,10 +113,11 @@ class CheckpointFn(Generic[Fn]):
109
113
  if refresh:
110
114
  print_checkpoint(params.verbosity >= 1, "MEMORIZING", checkpoint_id, "blue")
111
115
  data = self.fn(*args, **kw)
112
- if inspect.iscoroutine(data):
113
- data = await data
114
- self.storage.store(checkpoint_path, data)
115
- return data
116
+ if inspect.isawaitable(data):
117
+ return self._resolve_awaitable(checkpoint_path, data)
118
+ else:
119
+ self.storage.store(checkpoint_path, data)
120
+ return data
116
121
 
117
122
  try:
118
123
  data = self.storage.load(checkpoint_path)
@@ -121,20 +126,26 @@ class CheckpointFn(Generic[Fn]):
121
126
  except (EOFError, FileNotFoundError):
122
127
  pass
123
128
  print_checkpoint(params.verbosity >= 1, "CORRUPTED", checkpoint_id, "yellow")
124
- return await self._store_on_demand(args, kw, True)
129
+ return self._store_on_demand(args, kw, True)
125
130
 
126
131
  def _call(self, args: tuple, kw: dict, rerun=False):
127
132
  if not self.checkpointer.when:
128
133
  return self.fn(*args, **kw)
129
- coroutine = self._store_on_demand(args, kw, rerun)
130
- return coroutine if self.is_async else sync_resolve_coroutine(coroutine)
134
+ return self._store_on_demand(args, kw, rerun)
131
135
 
132
- def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: # type: ignore
133
- self = cast(CheckpointFn, self)
136
+ __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
137
+ rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
138
+
139
+ @overload
140
+ def get(self: Callable[P, Awaitable[R]], *args: P.args, **kw: P.kwargs) -> R: ...
141
+ @overload
142
+ def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
143
+
144
+ def get(self, *args, **kw):
134
145
  checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
135
146
  try:
136
- val = self.storage.load(checkpoint_path)
137
- return cast(R, resolved_awaitable(val) if self.is_async else val)
147
+ data = self.storage.load(checkpoint_path)
148
+ return data.value if isinstance(data, AwaitableValue) else data
138
149
  except Exception as ex:
139
150
  raise CheckpointError("Could not load checkpoint") from ex
140
151
 
@@ -142,9 +153,6 @@ class CheckpointFn(Generic[Fn]):
142
153
  self = cast(CheckpointFn, self)
143
154
  return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
144
155
 
145
- __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
146
- rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
147
-
148
156
  def __repr__(self) -> str:
149
157
  return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
150
158
 
@@ -79,9 +79,10 @@ async def test_async_caching():
79
79
  return x ** 2
80
80
 
81
81
  result1 = await async_square(3)
82
- result2 = await async_square.get(3)
82
+ result2 = await async_square(3)
83
+ result3 = async_square.get(3)
83
84
 
84
- assert result1 == result2 == 9
85
+ assert result1 == result2 == result3 == 9
85
86
 
86
87
  def test_force_recalculation():
87
88
  @checkpoint
@@ -2,11 +2,10 @@ import inspect
2
2
  import tokenize
3
3
  from contextlib import contextmanager
4
4
  from io import StringIO
5
- from types import coroutine
6
- from typing import Any, Callable, Coroutine, Generator, Generic, Iterable, TypeVar, cast
5
+ from typing import Any, Callable, Generic, Iterable, TypeVar, cast
7
6
 
8
7
  T = TypeVar("T")
9
- T_Callable = TypeVar("T_Callable", bound=Callable)
8
+ Fn = TypeVar("Fn", bound=Callable)
10
9
 
11
10
  def distinct(seq: Iterable[T]) -> list[T]:
12
11
  return list(dict.fromkeys(seq))
@@ -33,28 +32,20 @@ def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
33
32
  except ValueError:
34
33
  pass
35
34
 
36
- def unwrap_fn(fn: T_Callable, checkpoint_fn=False) -> T_Callable:
35
+ def unwrap_fn(fn: Fn, checkpoint_fn=False) -> Fn:
37
36
  from .checkpoint import CheckpointFn
38
37
  while True:
39
38
  if (checkpoint_fn and isinstance(fn, CheckpointFn)) or not hasattr(fn, "__wrapped__"):
40
- return cast(T_Callable, fn)
39
+ return cast(Fn, fn)
41
40
  fn = getattr(fn, "__wrapped__")
42
41
 
43
- async def resolved_awaitable(value: T) -> T:
44
- return value
45
-
46
- @coroutine
47
- def coroutine_as_generator(coroutine: Coroutine[None, None, T]) -> Generator[None, None, T]:
48
- val = yield from coroutine
49
- return val
42
+ class AwaitableValue:
43
+ def __init__(self, value):
44
+ self.value = value
50
45
 
51
- def sync_resolve_coroutine(coroutine: Coroutine[None, None, T]) -> T:
52
- gen = cast(Generator, coroutine_as_generator(coroutine))
53
- try:
54
- while True:
55
- next(gen)
56
- except StopIteration as ex:
57
- return ex.value
46
+ def __await__(self):
47
+ yield
48
+ return self.value
58
49
 
59
50
  class AttrDict(dict):
60
51
  def __init__(self, *args, **kwargs):
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "checkpointer"
3
- version = "2.7.1"
3
+ version = "2.8.0"
4
4
  requires-python = ">=3.11"
5
5
  dependencies = []
6
6
  authors = [
@@ -23,8 +23,8 @@ dev = [
23
23
  "numpy>=2.2.1",
24
24
  "omg>=1.2.3",
25
25
  "poethepoet>=0.30.0",
26
- "pytest>=8.3.3",
27
- "pytest-asyncio>=0.24.0",
26
+ "pytest>=8.3.5",
27
+ "pytest-asyncio>=0.26.0",
28
28
  "relib>=1.0.8",
29
29
  "torch>=2.5.1",
30
30
  ]
@@ -39,3 +39,6 @@ build-backend = "hatchling.build"
39
39
 
40
40
  [tool.hatch.build.targets.wheel]
41
41
  packages = ["checkpointer", "checkpointer.storages"]
42
+
43
+ [tool.pytest.ini_options]
44
+ asyncio_default_fixture_loop_scope = "session"
@@ -8,7 +8,7 @@ resolution-markers = [
8
8
 
9
9
  [[package]]
10
10
  name = "checkpointer"
11
- version = "2.7.1"
11
+ version = "2.8.0"
12
12
  source = { editable = "." }
13
13
 
14
14
  [package.dev-dependencies]
@@ -30,8 +30,8 @@ dev = [
30
30
  { name = "numpy", specifier = ">=2.2.1" },
31
31
  { name = "omg", specifier = ">=1.2.3" },
32
32
  { name = "poethepoet", specifier = ">=0.30.0" },
33
- { name = "pytest", specifier = ">=8.3.3" },
34
- { name = "pytest-asyncio", specifier = ">=0.24.0" },
33
+ { name = "pytest", specifier = ">=8.3.5" },
34
+ { name = "pytest-asyncio", specifier = ">=0.26.0" },
35
35
  { name = "relib", specifier = ">=1.0.8" },
36
36
  { name = "torch", specifier = ">=2.5.1" },
37
37
  ]
@@ -364,7 +364,7 @@ wheels = [
364
364
 
365
365
  [[package]]
366
366
  name = "pytest"
367
- version = "8.3.4"
367
+ version = "8.3.5"
368
368
  source = { registry = "https://pypi.org/simple" }
369
369
  dependencies = [
370
370
  { name = "colorama", marker = "sys_platform == 'win32'" },
@@ -372,21 +372,21 @@ dependencies = [
372
372
  { name = "packaging" },
373
373
  { name = "pluggy" },
374
374
  ]
375
- sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 }
375
+ sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 }
376
376
  wheels = [
377
- { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 },
377
+ { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 },
378
378
  ]
379
379
 
380
380
  [[package]]
381
381
  name = "pytest-asyncio"
382
- version = "0.25.0"
382
+ version = "0.26.0"
383
383
  source = { registry = "https://pypi.org/simple" }
384
384
  dependencies = [
385
385
  { name = "pytest" },
386
386
  ]
387
- sdist = { url = "https://files.pythonhosted.org/packages/94/18/82fcb4ee47d66d99f6cd1efc0b11b2a25029f303c599a5afda7c1bca4254/pytest_asyncio-0.25.0.tar.gz", hash = "sha256:8c0610303c9e0442a5db8604505fc0f545456ba1528824842b37b4a626cbf609", size = 53298 }
387
+ sdist = { url = "https://files.pythonhosted.org/packages/8e/c4/453c52c659521066969523e87d85d54139bbd17b78f09532fb8eb8cdb58e/pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f", size = 54156 }
388
388
  wheels = [
389
- { url = "https://files.pythonhosted.org/packages/88/56/2ee0cab25c11d4e38738a2a98c645a8f002e2ecf7b5ed774c70d53b92bb1/pytest_asyncio-0.25.0-py3-none-any.whl", hash = "sha256:db5432d18eac6b7e28b46dcd9b69921b55c3b1086e85febfe04e70b18d9e81b3", size = 19245 },
389
+ { url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694 },
390
390
  ]
391
391
 
392
392
  [[package]]
File without changes
File without changes