checkpointer 2.7.1__py3-none-any.whl → 2.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,12 +4,12 @@ import re
4
4
  from datetime import datetime
5
5
  from functools import update_wrapper
6
6
  from pathlib import Path
7
- from typing import Any, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
7
+ from typing import Any, Awaitable, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
8
8
  from .fn_ident import get_fn_ident
9
9
  from .object_hash import ObjectHash
10
10
  from .print_checkpoint import print_checkpoint
11
11
  from .storages import STORAGE_MAP, Storage
12
- from .utils import resolved_awaitable, sync_resolve_coroutine, unwrap_fn
12
+ from .utils import AwaitableValue, unwrap_fn
13
13
 
14
14
  Fn = TypeVar("Fn", bound=Callable)
15
15
  P = ParamSpec("P")
@@ -70,9 +70,8 @@ class CheckpointFn(Generic[Fn]):
70
70
  update_wrapper(cast(Callable, self), wrapped)
71
71
  Storage = STORAGE_MAP[params.format] if isinstance(params.format, str) else params.format
72
72
  deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
73
- self.fn_hash = str(params.fn_hash or ObjectHash().write_text(self.fn_hash_raw, *deep_hashes))
74
- self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
75
- self.is_async: bool = self.fn.is_async if isinstance(self.fn, CheckpointFn) else inspect.iscoroutinefunction(self.fn)
73
+ self.fn_hash = str(params.fn_hash or ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes))
74
+ self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:32]}"
76
75
  self.storage = Storage(self)
77
76
  self.cleanup = self.storage.cleanup
78
77
 
@@ -98,7 +97,12 @@ class CheckpointFn(Generic[Fn]):
98
97
  call_hash = ObjectHash(hash_params, digest_size=16)
99
98
  return f"{self.fn_subdir}/{call_hash}"
100
99
 
101
- async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
100
+ async def _resolve_awaitable(self, checkpoint_path: Path, awaitable: Awaitable):
101
+ data = await awaitable
102
+ self.storage.store(checkpoint_path, AwaitableValue(data))
103
+ return data
104
+
105
+ def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
102
106
  params = self.checkpointer
103
107
  checkpoint_id = self.get_checkpoint_id(args, kw)
104
108
  checkpoint_path = params.root_path / checkpoint_id
@@ -109,10 +113,11 @@ class CheckpointFn(Generic[Fn]):
109
113
  if refresh:
110
114
  print_checkpoint(params.verbosity >= 1, "MEMORIZING", checkpoint_id, "blue")
111
115
  data = self.fn(*args, **kw)
112
- if inspect.iscoroutine(data):
113
- data = await data
114
- self.storage.store(checkpoint_path, data)
115
- return data
116
+ if inspect.isawaitable(data):
117
+ return self._resolve_awaitable(checkpoint_path, data)
118
+ else:
119
+ self.storage.store(checkpoint_path, data)
120
+ return data
116
121
 
117
122
  try:
118
123
  data = self.storage.load(checkpoint_path)
@@ -121,20 +126,26 @@ class CheckpointFn(Generic[Fn]):
121
126
  except (EOFError, FileNotFoundError):
122
127
  pass
123
128
  print_checkpoint(params.verbosity >= 1, "CORRUPTED", checkpoint_id, "yellow")
124
- return await self._store_on_demand(args, kw, True)
129
+ return self._store_on_demand(args, kw, True)
125
130
 
126
131
  def _call(self, args: tuple, kw: dict, rerun=False):
127
132
  if not self.checkpointer.when:
128
133
  return self.fn(*args, **kw)
129
- coroutine = self._store_on_demand(args, kw, rerun)
130
- return coroutine if self.is_async else sync_resolve_coroutine(coroutine)
134
+ return self._store_on_demand(args, kw, rerun)
131
135
 
132
- def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: # type: ignore
133
- self = cast(CheckpointFn, self)
136
+ __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
137
+ rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
138
+
139
+ @overload
140
+ def get(self: Callable[P, Awaitable[R]], *args: P.args, **kw: P.kwargs) -> R: ...
141
+ @overload
142
+ def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
143
+
144
+ def get(self, *args, **kw):
134
145
  checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
135
146
  try:
136
- val = self.storage.load(checkpoint_path)
137
- return cast(R, resolved_awaitable(val) if self.is_async else val)
147
+ data = self.storage.load(checkpoint_path)
148
+ return data.value if isinstance(data, AwaitableValue) else data
138
149
  except Exception as ex:
139
150
  raise CheckpointError("Could not load checkpoint") from ex
140
151
 
@@ -142,9 +153,6 @@ class CheckpointFn(Generic[Fn]):
142
153
  self = cast(CheckpointFn, self)
143
154
  return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
144
155
 
145
- __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
146
- rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
147
-
148
156
  def __repr__(self) -> str:
149
157
  return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
150
158
 
@@ -79,9 +79,10 @@ async def test_async_caching():
79
79
  return x ** 2
80
80
 
81
81
  result1 = await async_square(3)
82
- result2 = await async_square.get(3)
82
+ result2 = await async_square(3)
83
+ result3 = async_square.get(3)
83
84
 
84
- assert result1 == result2 == 9
85
+ assert result1 == result2 == result3 == 9
85
86
 
86
87
  def test_force_recalculation():
87
88
  @checkpoint
checkpointer/utils.py CHANGED
@@ -2,11 +2,10 @@ import inspect
2
2
  import tokenize
3
3
  from contextlib import contextmanager
4
4
  from io import StringIO
5
- from types import coroutine
6
- from typing import Any, Callable, Coroutine, Generator, Generic, Iterable, TypeVar, cast
5
+ from typing import Any, Callable, Generic, Iterable, TypeVar, cast
7
6
 
8
7
  T = TypeVar("T")
9
- T_Callable = TypeVar("T_Callable", bound=Callable)
8
+ Fn = TypeVar("Fn", bound=Callable)
10
9
 
11
10
  def distinct(seq: Iterable[T]) -> list[T]:
12
11
  return list(dict.fromkeys(seq))
@@ -33,28 +32,20 @@ def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
33
32
  except ValueError:
34
33
  pass
35
34
 
36
- def unwrap_fn(fn: T_Callable, checkpoint_fn=False) -> T_Callable:
35
+ def unwrap_fn(fn: Fn, checkpoint_fn=False) -> Fn:
37
36
  from .checkpoint import CheckpointFn
38
37
  while True:
39
38
  if (checkpoint_fn and isinstance(fn, CheckpointFn)) or not hasattr(fn, "__wrapped__"):
40
- return cast(T_Callable, fn)
39
+ return cast(Fn, fn)
41
40
  fn = getattr(fn, "__wrapped__")
42
41
 
43
- async def resolved_awaitable(value: T) -> T:
44
- return value
45
-
46
- @coroutine
47
- def coroutine_as_generator(coroutine: Coroutine[None, None, T]) -> Generator[None, None, T]:
48
- val = yield from coroutine
49
- return val
42
+ class AwaitableValue:
43
+ def __init__(self, value):
44
+ self.value = value
50
45
 
51
- def sync_resolve_coroutine(coroutine: Coroutine[None, None, T]) -> T:
52
- gen = cast(Generator, coroutine_as_generator(coroutine))
53
- try:
54
- while True:
55
- next(gen)
56
- except StopIteration as ex:
57
- return ex.value
46
+ def __await__(self):
47
+ yield
48
+ return self.value
58
49
 
59
50
  class AttrDict(dict):
60
51
  def __init__(self, *args, **kwargs):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpointer
3
- Version: 2.7.1
3
+ Version: 2.8.1
4
4
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
5
5
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
6
  Author: Hampus Hallman
@@ -255,7 +255,7 @@ async def main():
255
255
  result2 = await async_compute_sum(3, 7)
256
256
  print(result2) # Outputs 10
257
257
 
258
- result3 = await async_compute_sum.get(3, 7)
258
+ result3 = async_compute_sum.get(3, 7)
259
259
  print(result3) # Outputs 10
260
260
 
261
261
  asyncio.run(main())
@@ -1,16 +1,16 @@
1
1
  checkpointer/__init__.py,sha256=HRLsQ24ZhxgmDcHchZ-hX6wA0NMCSedGA0NmCnUdS_c,832
2
- checkpointer/checkpoint.py,sha256=H1RQJIIuEmRcXd42Y-qUfyU3z0OGp3smY0LoKQN0IYU,6425
2
+ checkpointer/checkpoint.py,sha256=FZSKQVIXj_Enja0253WDUauO8siUHwlzcr4frHMJzB0,6538
3
3
  checkpointer/fn_ident.py,sha256=SWaksNCTlskMom0ztqjECSRjZYPWXUA1p1ZCb-9tWo0,4297
4
4
  checkpointer/object_hash.py,sha256=cxuWRDrg4F9wC18aC12zOZYOPv3bk2Qf6tZ0_WgAb6Y,7484
5
5
  checkpointer/print_checkpoint.py,sha256=aJCeWMRJiIR3KpyPk_UOKTaD906kArGrmLGQ3LqcVgo,1369
6
- checkpointer/test_checkpointer.py,sha256=VdINXiaA_BoDdVYEB73ctfQ42fw3EDoYa9vYacoB13A,3768
7
- checkpointer/utils.py,sha256=E1AV96NTh3tuVmgPrr0JSKZaokw-Jely5Y6-NjlMCp8,3141
6
+ checkpointer/test_checkpointer.py,sha256=DM5f1Ci4z7MhDnxK-2Z-V3a3ntzBJqORQvCFROAf2SI,3807
7
+ checkpointer/utils.py,sha256=2yk1ksKszXofwnSqBrNCFRSR4C3YLPEAdHUFf-cviRU,2755
8
8
  checkpointer/storages/__init__.py,sha256=Kl4Og5jhYxn6m3tB_kTMsabf4_eWVLmFVAoC-pikNQE,301
9
9
  checkpointer/storages/bcolz_storage.py,sha256=3QkSUSeG5s2kFuVV_LZpzMn1A5E7kqC7jk7w35c0NyQ,2314
10
10
  checkpointer/storages/memory_storage.py,sha256=S5ayOZE_CyaFQJ-vSgObTanldPzG3gh3NksjNAc7vsk,1282
11
11
  checkpointer/storages/pickle_storage.py,sha256=idh9sBMdWuyvS220oa_7bAUpc9Xo9v6Ud9aYKGWasUs,1593
12
12
  checkpointer/storages/storage.py,sha256=_m18Z8TKrdAbi6YYYQmuNOnhna4RB2sJDn1v3liaU3U,721
13
- checkpointer-2.7.1.dist-info/METADATA,sha256=wEdB7ZEnYVW3-bcwFBpZAQXbv76MNU1PvUpKMhW1Ids,10606
14
- checkpointer-2.7.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- checkpointer-2.7.1.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
16
- checkpointer-2.7.1.dist-info/RECORD,,
13
+ checkpointer-2.8.1.dist-info/METADATA,sha256=z0SbrcyKvHTOgJ6oq0dshV25fBAsqGjMCl_uWmxSFRM,10600
14
+ checkpointer-2.8.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
+ checkpointer-2.8.1.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
16
+ checkpointer-2.8.1.dist-info/RECORD,,