checkpointer 2.6.2__py3-none-any.whl → 2.7.1__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
checkpointer/__init__.py CHANGED
@@ -9,7 +9,8 @@ create_checkpointer = Checkpointer
9
9
  checkpoint = Checkpointer()
10
10
  capture_checkpoint = Checkpointer(capture=True)
11
11
  memory_checkpoint = Checkpointer(format="memory", verbosity=0)
12
- tmp_checkpoint = Checkpointer(root_path=tempfile.gettempdir() + "/checkpoints")
12
+ tmp_checkpoint = Checkpointer(root_path=f"{tempfile.gettempdir()}/checkpoints")
13
+ static_checkpoint = Checkpointer(fn_hash=ObjectHash())
13
14
 
14
15
  def cleanup_all(invalidated=True, expired=True):
15
16
  for obj in gc.get_objects():
@@ -4,7 +4,7 @@ import re
4
4
  from datetime import datetime
5
5
  from functools import update_wrapper
6
6
  from pathlib import Path
7
- from typing import Any, Callable, Generic, Iterable, Literal, Type, TypedDict, TypeVar, Unpack, cast, overload
7
+ from typing import Any, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
8
8
  from .fn_ident import get_fn_ident
9
9
  from .object_hash import ObjectHash
10
10
  from .print_checkpoint import print_checkpoint
@@ -12,6 +12,8 @@ from .storages import STORAGE_MAP, Storage
12
12
  from .utils import resolved_awaitable, sync_resolve_coroutine, unwrap_fn
13
13
 
14
14
  Fn = TypeVar("Fn", bound=Callable)
15
+ P = ParamSpec("P")
16
+ R = TypeVar("R")
15
17
 
16
18
  DEFAULT_DIR = Path.home() / ".cache/checkpoints"
17
19
 
@@ -26,7 +28,7 @@ class CheckpointerOpts(TypedDict, total=False):
26
28
  hash_by: Callable | None
27
29
  should_expire: Callable[[datetime], bool] | None
28
30
  capture: bool
29
- fn_hash: str | None
31
+ fn_hash: ObjectHash | None
30
32
 
31
33
  class Checkpointer:
32
34
  def __init__(self, **opts: Unpack[CheckpointerOpts]):
@@ -61,14 +63,14 @@ class CheckpointFn(Generic[Fn]):
61
63
  return self
62
64
 
63
65
  def _lazyinit(self):
66
+ params = self.checkpointer
64
67
  wrapped = unwrap_fn(self.fn)
65
68
  fn_file = Path(wrapped.__code__.co_filename).name
66
69
  fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
67
70
  update_wrapper(cast(Callable, self), wrapped)
68
- store_format = self.checkpointer.format
69
- Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
71
+ Storage = STORAGE_MAP[params.format] if isinstance(params.format, str) else params.format
70
72
  deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
71
- self.fn_hash = self.checkpointer.fn_hash or str(ObjectHash().write_text(self.fn_hash_raw, *deep_hashes))
73
+ self.fn_hash = str(params.fn_hash or ObjectHash().write_text(self.fn_hash_raw, *deep_hashes))
72
74
  self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
73
75
  self.is_async: bool = self.fn.is_async if isinstance(self.fn, CheckpointFn) else inspect.iscoroutinefunction(self.fn)
74
76
  self.storage = Storage(self)
@@ -91,20 +93,21 @@ class CheckpointFn(Generic[Fn]):
91
93
  return self
92
94
 
93
95
  def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
94
- hash_params = [self.checkpointer.hash_by(*args, **kw)] if self.checkpointer.hash_by else (args, kw)
95
- call_hash = ObjectHash(self.fn_hash, *hash_params, digest_size=16)
96
+ hash_by = self.checkpointer.hash_by
97
+ hash_params = hash_by(*args, **kw) if hash_by else (args, kw)
98
+ call_hash = ObjectHash(hash_params, digest_size=16)
96
99
  return f"{self.fn_subdir}/{call_hash}"
97
100
 
98
101
  async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
102
+ params = self.checkpointer
99
103
  checkpoint_id = self.get_checkpoint_id(args, kw)
100
- checkpoint_path = self.checkpointer.root_path / checkpoint_id
101
- verbosity = self.checkpointer.verbosity
104
+ checkpoint_path = params.root_path / checkpoint_id
102
105
  refresh = rerun \
103
106
  or not self.storage.exists(checkpoint_path) \
104
- or (self.checkpointer.should_expire and self.checkpointer.should_expire(self.storage.checkpoint_date(checkpoint_path)))
107
+ or (params.should_expire and params.should_expire(self.storage.checkpoint_date(checkpoint_path)))
105
108
 
106
109
  if refresh:
107
- print_checkpoint(verbosity >= 1, "MEMORIZING", checkpoint_id, "blue")
110
+ print_checkpoint(params.verbosity >= 1, "MEMORIZING", checkpoint_id, "blue")
108
111
  data = self.fn(*args, **kw)
109
112
  if inspect.iscoroutine(data):
110
113
  data = await data
@@ -113,11 +116,11 @@ class CheckpointFn(Generic[Fn]):
113
116
 
114
117
  try:
115
118
  data = self.storage.load(checkpoint_path)
116
- print_checkpoint(verbosity >= 2, "REMEMBERED", checkpoint_id, "green")
119
+ print_checkpoint(params.verbosity >= 2, "REMEMBERED", checkpoint_id, "green")
117
120
  return data
118
121
  except (EOFError, FileNotFoundError):
119
122
  pass
120
- print_checkpoint(verbosity >= 1, "CORRUPTED", checkpoint_id, "yellow")
123
+ print_checkpoint(params.verbosity >= 1, "CORRUPTED", checkpoint_id, "yellow")
121
124
  return await self._store_on_demand(args, kw, True)
122
125
 
123
126
  def _call(self, args: tuple, kw: dict, rerun=False):
@@ -126,20 +129,21 @@ class CheckpointFn(Generic[Fn]):
126
129
  coroutine = self._store_on_demand(args, kw, rerun)
127
130
  return coroutine if self.is_async else sync_resolve_coroutine(coroutine)
128
131
 
129
- def _get(self, args, kw) -> Any:
132
+ def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: # type: ignore
133
+ self = cast(CheckpointFn, self)
130
134
  checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
131
135
  try:
132
136
  val = self.storage.load(checkpoint_path)
133
- return resolved_awaitable(val) if self.is_async else val
137
+ return cast(R, resolved_awaitable(val) if self.is_async else val)
134
138
  except Exception as ex:
135
139
  raise CheckpointError("Could not load checkpoint") from ex
136
140
 
137
- def exists(self, *args: tuple, **kw: dict) -> bool:
141
+ def exists(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> bool: # type: ignore
142
+ self = cast(CheckpointFn, self)
138
143
  return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
139
144
 
140
145
  __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
141
146
  rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
142
- get: Fn = cast(Fn, lambda self, *args, **kw: self._get(args, kw))
143
147
 
144
148
  def __repr__(self) -> str:
145
149
  return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
@@ -56,6 +56,9 @@ class ObjectHash:
56
56
 
57
57
  __str__ = hexdigest
58
58
 
59
+ def __eq__(self, value: object) -> bool:
60
+ return isinstance(value, ObjectHash) and str(self) == str(value)
61
+
59
62
  def nested_hash(self, *objs: Any) -> str:
60
63
  return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()
61
64
 
@@ -6,7 +6,7 @@ from . import checkpoint
6
6
  from .checkpoint import CheckpointError
7
7
  from .utils import AttrDict
8
8
 
9
- def global_multiply(a, b):
9
+ def global_multiply(a: int, b: int) -> int:
10
10
  return a * b
11
11
 
12
12
  @pytest.fixture(autouse=True)
@@ -27,15 +27,15 @@ def test_basic_caching():
27
27
 
28
28
  def test_cache_invalidation():
29
29
  @checkpoint
30
- def multiply(a, b):
30
+ def multiply(a: int, b: int):
31
31
  return a * b
32
32
 
33
33
  @checkpoint
34
- def helper(x):
34
+ def helper(x: int):
35
35
  return multiply(x + 1, 2)
36
36
 
37
37
  @checkpoint
38
- def compute(a, b):
38
+ def compute(a: int, b: int):
39
39
  return helper(a) + helper(b)
40
40
 
41
41
  result1 = compute(3, 4)
@@ -46,7 +46,7 @@ def test_layered_caching():
46
46
 
47
47
  @checkpoint(format="memory")
48
48
  @dev_checkpoint
49
- def expensive_function(x):
49
+ def expensive_function(x: int):
50
50
  return x ** 2
51
51
 
52
52
  assert expensive_function(4) == 16
@@ -95,7 +95,7 @@ def test_force_recalculation():
95
95
  def test_multi_layer_decorator():
96
96
  @checkpoint(format="memory")
97
97
  @checkpoint(format="pickle")
98
- def add(a, b):
98
+ def add(a: int, b: int) -> int:
99
99
  return a + b
100
100
 
101
101
  assert add(2, 3) == 5
@@ -124,18 +124,18 @@ def test_capture():
124
124
  assert test_a.fn_hash != init_hash_a
125
125
 
126
126
  def test_depends():
127
- def multiply_wrapper(a, b):
127
+ def multiply_wrapper(a: int, b: int) -> int:
128
128
  return global_multiply(a, b)
129
129
 
130
- def helper(a, b):
130
+ def helper(a: int, b: int) -> int:
131
131
  return multiply_wrapper(a + 1, b + 1)
132
132
 
133
133
  @checkpoint
134
- def test_a(a, b):
134
+ def test_a(a: int, b: int) -> int:
135
135
  return helper(a, b)
136
136
 
137
137
  @checkpoint
138
- def test_b(a, b):
138
+ def test_b(a: int, b: int) -> int:
139
139
  return test_a(a, b) + multiply_wrapper(a, b)
140
140
 
141
141
  assert set(test_a.depends) == {test_a.fn, helper, multiply_wrapper, global_multiply}
@@ -143,17 +143,17 @@ def test_depends():
143
143
 
144
144
  def test_lazy_init():
145
145
  @checkpoint
146
- def fn1(x):
146
+ def fn1(x: object) -> object:
147
147
  return fn2(x)
148
148
 
149
149
  @checkpoint
150
- def fn2(x):
150
+ def fn2(x: object) -> object:
151
151
  return fn1(x)
152
152
 
153
- assert type(object.__getattribute__(fn1, "_getattribute")) == MethodType
153
+ assert type(object.__getattribute__(fn1, "_getattribute")) is MethodType
154
154
  with pytest.raises(AttributeError):
155
155
  object.__getattribute__(fn1, "fn_hash")
156
156
  assert fn1.fn_hash == object.__getattribute__(fn1, "fn_hash")
157
- assert type(object.__getattribute__(fn1, "_getattribute")) == MethodWrapperType
157
+ assert type(object.__getattribute__(fn1, "_getattribute")) is MethodWrapperType
158
158
  assert set(fn1.depends) == {fn1.fn, fn2}
159
159
  assert set(fn2.depends) == {fn1, fn2.fn}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpointer
3
- Version: 2.6.2
3
+ Version: 2.7.1
4
4
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
5
5
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
6
  Author: Hampus Hallman
@@ -1,16 +1,16 @@
1
- checkpointer/__init__.py,sha256=ZJ6frUNgkklUi85b5uXTyTfRzMvZgQOJY-ZOnu7jh78,777
2
- checkpointer/checkpoint.py,sha256=3ohNeAqzZLCT9IHJ7nsDbtZ8rjI0rdApbqp8KowDRtc,6332
1
+ checkpointer/__init__.py,sha256=HRLsQ24ZhxgmDcHchZ-hX6wA0NMCSedGA0NmCnUdS_c,832
2
+ checkpointer/checkpoint.py,sha256=H1RQJIIuEmRcXd42Y-qUfyU3z0OGp3smY0LoKQN0IYU,6425
3
3
  checkpointer/fn_ident.py,sha256=SWaksNCTlskMom0ztqjECSRjZYPWXUA1p1ZCb-9tWo0,4297
4
- checkpointer/object_hash.py,sha256=rcHzVYZAeygLyqeKv1NODIDp0M_knLuDZefcBV_7ln4,7371
4
+ checkpointer/object_hash.py,sha256=cxuWRDrg4F9wC18aC12zOZYOPv3bk2Qf6tZ0_WgAb6Y,7484
5
5
  checkpointer/print_checkpoint.py,sha256=aJCeWMRJiIR3KpyPk_UOKTaD906kArGrmLGQ3LqcVgo,1369
6
- checkpointer/test_checkpointer.py,sha256=uJ2Pg9Miq1W0l28eNlRhMjuT_R8c-ygYwp3KP3VW8Os,3600
6
+ checkpointer/test_checkpointer.py,sha256=VdINXiaA_BoDdVYEB73ctfQ42fw3EDoYa9vYacoB13A,3768
7
7
  checkpointer/utils.py,sha256=E1AV96NTh3tuVmgPrr0JSKZaokw-Jely5Y6-NjlMCp8,3141
8
8
  checkpointer/storages/__init__.py,sha256=Kl4Og5jhYxn6m3tB_kTMsabf4_eWVLmFVAoC-pikNQE,301
9
9
  checkpointer/storages/bcolz_storage.py,sha256=3QkSUSeG5s2kFuVV_LZpzMn1A5E7kqC7jk7w35c0NyQ,2314
10
10
  checkpointer/storages/memory_storage.py,sha256=S5ayOZE_CyaFQJ-vSgObTanldPzG3gh3NksjNAc7vsk,1282
11
11
  checkpointer/storages/pickle_storage.py,sha256=idh9sBMdWuyvS220oa_7bAUpc9Xo9v6Ud9aYKGWasUs,1593
12
12
  checkpointer/storages/storage.py,sha256=_m18Z8TKrdAbi6YYYQmuNOnhna4RB2sJDn1v3liaU3U,721
13
- checkpointer-2.6.2.dist-info/METADATA,sha256=fpt1kQbTghSeSIbDranjL0-iGYaA3LbWFlyAfSTgJB0,10606
14
- checkpointer-2.6.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- checkpointer-2.6.2.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
16
- checkpointer-2.6.2.dist-info/RECORD,,
13
+ checkpointer-2.7.1.dist-info/METADATA,sha256=wEdB7ZEnYVW3-bcwFBpZAQXbv76MNU1PvUpKMhW1Ids,10606
14
+ checkpointer-2.7.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
+ checkpointer-2.7.1.dist-info/licenses/LICENSE,sha256=9xVsdtv_-uSyY9Xl9yujwAPm4-mjcCLeVy-ljwXEWbo,1059
16
+ checkpointer-2.7.1.dist-info/RECORD,,