checkpointer 2.8.1__tar.gz → 2.9.0__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpointer
3
- Version: 2.8.1
3
+ Version: 2.9.0
4
4
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
5
5
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
6
  Author: Hampus Hallman
@@ -203,11 +203,11 @@ from checkpointer import checkpoint, Storage
203
203
  from datetime import datetime
204
204
 
205
205
  class CustomStorage(Storage):
206
- def exists(self, path) -> bool: ... # Check if a checkpoint exists
207
- def checkpoint_date(self, path) -> datetime: ... # Get the checkpoint's timestamp
208
- def store(self, path, data): ... # Save data to the checkpoint
209
- def load(self, path): ... # Load data from the checkpoint
210
- def delete(self, path): ... # Delete the checkpoint
206
+ def exists(self, call_id) -> bool: ... # Check if a checkpoint exists
207
+ def checkpoint_date(self, call_id) -> datetime: ... # Get the checkpoint's timestamp
208
+ def store(self, call_id, data): ... # Save data to the checkpoint
209
+ def load(self, call_id): ... # Load data from the checkpoint
210
+ def delete(self, call_id): ... # Delete the checkpoint
211
211
 
212
212
  @checkpoint(format=CustomStorage)
213
213
  def custom_cached(x: int):
@@ -183,11 +183,11 @@ from checkpointer import checkpoint, Storage
183
183
  from datetime import datetime
184
184
 
185
185
  class CustomStorage(Storage):
186
- def exists(self, path) -> bool: ... # Check if a checkpoint exists
187
- def checkpoint_date(self, path) -> datetime: ... # Get the checkpoint's timestamp
188
- def store(self, path, data): ... # Save data to the checkpoint
189
- def load(self, path): ... # Load data from the checkpoint
190
- def delete(self, path): ... # Delete the checkpoint
186
+ def exists(self, call_id) -> bool: ... # Check if a checkpoint exists
187
+ def checkpoint_date(self, call_id) -> datetime: ... # Get the checkpoint's timestamp
188
+ def store(self, call_id, data): ... # Save data to the checkpoint
189
+ def load(self, call_id): ... # Load data from the checkpoint
190
+ def delete(self, call_id): ... # Delete the checkpoint
191
191
 
192
192
  @checkpoint(format=CustomStorage)
193
193
  def custom_cached(x: int):
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
  import inspect
3
3
  import re
4
+ from contextlib import suppress
4
5
  from datetime import datetime
5
- from functools import update_wrapper
6
+ from functools import cached_property, update_wrapper
6
7
  from pathlib import Path
7
- from typing import Any, Awaitable, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
8
+ from typing import Awaitable, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
8
9
  from .fn_ident import get_fn_ident
9
10
  from .object_hash import ObjectHash
10
11
  from .print_checkpoint import print_checkpoint
@@ -54,84 +55,83 @@ class Checkpointer:
54
55
 
55
56
  class CheckpointFn(Generic[Fn]):
56
57
  def __init__(self, checkpointer: Checkpointer, fn: Fn):
57
- self.checkpointer = checkpointer
58
- self.fn = fn
59
-
60
- def _set_ident(self, force=False):
61
- if not hasattr(self, "fn_hash_raw") or force:
62
- self.fn_hash_raw, self.depends = get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
63
- return self
64
-
65
- def _lazyinit(self):
66
- params = self.checkpointer
67
- wrapped = unwrap_fn(self.fn)
58
+ wrapped = unwrap_fn(fn)
68
59
  fn_file = Path(wrapped.__code__.co_filename).name
69
60
  fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
61
+ Storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
70
62
  update_wrapper(cast(Callable, self), wrapped)
71
- Storage = STORAGE_MAP[params.format] if isinstance(params.format, str) else params.format
72
- deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
73
- self.fn_hash = str(params.fn_hash or ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes))
74
- self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:32]}"
63
+ self.checkpointer = checkpointer
64
+ self.fn = fn
75
65
  self.storage = Storage(self)
76
66
  self.cleanup = self.storage.cleanup
67
+ self.fn_dir = f"{fn_file}/{fn_name}"
68
+
69
+ @cached_property
70
+ def ident_tuple(self) -> tuple[str, list[Callable]]:
71
+ return get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
72
+
73
+ @property
74
+ def fn_hash_raw(self) -> str:
75
+ return self.ident_tuple[0]
77
76
 
78
- def __getattribute__(self, name: str) -> Any:
79
- return object.__getattribute__(self, "_getattribute")(name)
77
+ @property
78
+ def depends(self) -> list[Callable]:
79
+ return self.ident_tuple[1]
80
80
 
81
- def _getattribute(self, name: str) -> Any:
82
- setattr(self, "_getattribute", super().__getattribute__)
83
- self._lazyinit()
84
- return self._getattribute(name)
81
+ @cached_property
82
+ def fn_hash(self) -> str:
83
+ fn_hash = self.checkpointer.fn_hash
84
+ deep_hashes = [depend.fn_hash_raw for depend in self.deep_depends()]
85
+ return str(fn_hash or ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes))[:32]
85
86
 
86
87
  def reinit(self, recursive=False) -> CheckpointFn[Fn]:
87
- pointfns = list(iterate_checkpoint_fns(self)) if recursive else [self]
88
- for pointfn in pointfns:
89
- pointfn._set_ident(True)
90
- for pointfn in pointfns:
91
- pointfn._lazyinit()
88
+ depends = list(self.deep_depends()) if recursive else [self]
89
+ for depend in depends:
90
+ with suppress(AttributeError):
91
+ del depend.ident_tuple, depend.fn_hash
92
+ for depend in depends:
93
+ depend.fn_hash
92
94
  return self
93
95
 
94
- def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
96
+ def get_call_id(self, args: tuple, kw: dict) -> str:
95
97
  hash_by = self.checkpointer.hash_by
96
98
  hash_params = hash_by(*args, **kw) if hash_by else (args, kw)
97
- call_hash = ObjectHash(hash_params, digest_size=16)
98
- return f"{self.fn_subdir}/{call_hash}"
99
+ return str(ObjectHash(hash_params, digest_size=16))
99
100
 
100
- async def _resolve_awaitable(self, checkpoint_path: Path, awaitable: Awaitable):
101
+ async def _resolve_awaitable(self, checkpoint_id: str, awaitable: Awaitable):
101
102
  data = await awaitable
102
- self.storage.store(checkpoint_path, AwaitableValue(data))
103
+ self.storage.store(checkpoint_id, AwaitableValue(data))
103
104
  return data
104
105
 
105
- def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
106
+ def _call(self, args: tuple, kw: dict, rerun=False):
106
107
  params = self.checkpointer
107
- checkpoint_id = self.get_checkpoint_id(args, kw)
108
- checkpoint_path = params.root_path / checkpoint_id
108
+ if not params.when:
109
+ return self.fn(*args, **kw)
110
+
111
+ call_id = self.get_call_id(args, kw)
112
+ call_id_long = f"{self.fn_dir}/{self.fn_hash}/{call_id}"
113
+
109
114
  refresh = rerun \
110
- or not self.storage.exists(checkpoint_path) \
111
- or (params.should_expire and params.should_expire(self.storage.checkpoint_date(checkpoint_path)))
115
+ or not self.storage.exists(call_id) \
116
+ or (params.should_expire and params.should_expire(self.storage.checkpoint_date(call_id)))
112
117
 
113
118
  if refresh:
114
- print_checkpoint(params.verbosity >= 1, "MEMORIZING", checkpoint_id, "blue")
119
+ print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id_long, "blue")
115
120
  data = self.fn(*args, **kw)
116
121
  if inspect.isawaitable(data):
117
- return self._resolve_awaitable(checkpoint_path, data)
122
+ return self._resolve_awaitable(call_id, data)
118
123
  else:
119
- self.storage.store(checkpoint_path, data)
124
+ self.storage.store(call_id, data)
120
125
  return data
121
126
 
122
127
  try:
123
- data = self.storage.load(checkpoint_path)
124
- print_checkpoint(params.verbosity >= 2, "REMEMBERED", checkpoint_id, "green")
128
+ data = self.storage.load(call_id)
129
+ print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_id_long, "green")
125
130
  return data
126
131
  except (EOFError, FileNotFoundError):
127
132
  pass
128
- print_checkpoint(params.verbosity >= 1, "CORRUPTED", checkpoint_id, "yellow")
129
- return self._store_on_demand(args, kw, True)
130
-
131
- def _call(self, args: tuple, kw: dict, rerun=False):
132
- if not self.checkpointer.when:
133
- return self.fn(*args, **kw)
134
- return self._store_on_demand(args, kw, rerun)
133
+ print_checkpoint(params.verbosity >= 1, "CORRUPTED", call_id_long, "yellow")
134
+ return self._call(args, kw, True)
135
135
 
136
136
  __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
137
137
  rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
@@ -142,25 +142,29 @@ class CheckpointFn(Generic[Fn]):
142
142
  def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
143
143
 
144
144
  def get(self, *args, **kw):
145
- checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
145
+ call_id = self.get_call_id(args, kw)
146
146
  try:
147
- data = self.storage.load(checkpoint_path)
147
+ data = self.storage.load(call_id)
148
148
  return data.value if isinstance(data, AwaitableValue) else data
149
149
  except Exception as ex:
150
150
  raise CheckpointError("Could not load checkpoint") from ex
151
151
 
152
152
  def exists(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> bool: # type: ignore
153
153
  self = cast(CheckpointFn, self)
154
- return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
154
+ return self.storage.exists(self.get_call_id(args, kw))
155
+
156
+ def delete(self: Callable[P, R], *args: P.args, **kw: P.kwargs): # type: ignore
157
+ self = cast(CheckpointFn, self)
158
+ self.storage.delete(self.get_call_id(args, kw))
155
159
 
156
160
  def __repr__(self) -> str:
157
161
  return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
158
162
 
159
- def iterate_checkpoint_fns(pointfn: CheckpointFn, visited: set[CheckpointFn] = set()) -> Iterable[CheckpointFn]:
160
- visited = visited or set()
161
- if pointfn not in visited:
162
- yield pointfn
163
- visited.add(pointfn)
164
- for depend in pointfn.depends:
165
- if isinstance(depend, CheckpointFn):
166
- yield from iterate_checkpoint_fns(depend, visited)
163
+ def deep_depends(self, visited: set[CheckpointFn] = set()) -> Iterable[CheckpointFn]:
164
+ if self not in visited:
165
+ yield self
166
+ visited = visited or set()
167
+ visited.add(self)
168
+ for depend in self.depends:
169
+ if isinstance(depend, CheckpointFn):
170
+ yield from depend.deep_depends(visited)
@@ -0,0 +1,36 @@
1
+ from typing import Any
2
+ from pathlib import Path
3
+ from datetime import datetime
4
+ from .storage import Storage
5
+
6
+ item_map: dict[Path, dict[str, tuple[datetime, Any]]] = {}
7
+
8
+ class MemoryStorage(Storage):
9
+ def get_dict(self):
10
+ return item_map.setdefault(self.dir(), {})
11
+
12
+ def store(self, call_id, data):
13
+ self.get_dict()[call_id] = (datetime.now(), data)
14
+
15
+ def exists(self, call_id):
16
+ return call_id in self.get_dict()
17
+
18
+ def checkpoint_date(self, call_id):
19
+ return self.get_dict()[call_id][0]
20
+
21
+ def load(self, call_id):
22
+ return self.get_dict()[call_id][1]
23
+
24
+ def delete(self, call_id):
25
+ self.get_dict().pop(call_id, None)
26
+
27
+ def cleanup(self, invalidated=True, expired=True):
28
+ curr_key = self.dir()
29
+ for key, calldict in list(item_map.items()):
30
+ if key.parent == curr_key.parent:
31
+ if invalidated and key != curr_key:
32
+ del item_map[key]
33
+ elif expired and self.checkpointer.should_expire:
34
+ for call_id, (date, _) in list(calldict.items()):
35
+ if self.checkpointer.should_expire(date):
36
+ del calldict[call_id]
@@ -1,35 +1,34 @@
1
1
  import pickle
2
2
  import shutil
3
- from pathlib import Path
4
3
  from datetime import datetime
5
4
  from .storage import Storage
6
5
 
7
- def get_path(path: Path):
8
- return path.with_name(f"{path.name}.pkl")
9
-
10
6
  class PickleStorage(Storage):
11
- def store(self, path, data):
12
- full_path = get_path(path)
13
- full_path.parent.mkdir(parents=True, exist_ok=True)
14
- with full_path.open("wb") as file:
7
+ def get_path(self, call_id: str):
8
+ return self.dir() / f"{call_id}.pkl"
9
+
10
+ def store(self, call_id, data):
11
+ path = self.get_path(call_id)
12
+ path.parent.mkdir(parents=True, exist_ok=True)
13
+ with path.open("wb") as file:
15
14
  pickle.dump(data, file, -1)
16
15
 
17
- def exists(self, path):
18
- return get_path(path).exists()
16
+ def exists(self, call_id):
17
+ return self.get_path(call_id).exists()
19
18
 
20
- def checkpoint_date(self, path):
19
+ def checkpoint_date(self, call_id):
21
20
  # Should use st_atime/access time?
22
- return datetime.fromtimestamp(get_path(path).stat().st_mtime)
21
+ return datetime.fromtimestamp(self.get_path(call_id).stat().st_mtime)
23
22
 
24
- def load(self, path):
25
- with get_path(path).open("rb") as file:
23
+ def load(self, call_id):
24
+ with self.get_path(call_id).open("rb") as file:
26
25
  return pickle.load(file)
27
26
 
28
- def delete(self, path):
29
- get_path(path).unlink(missing_ok=True)
27
+ def delete(self, call_id):
28
+ self.get_path(call_id).unlink(missing_ok=True)
30
29
 
31
30
  def cleanup(self, invalidated=True, expired=True):
32
- version_path = self.checkpointer.root_path.resolve() / self.checkpoint_fn.fn_subdir
31
+ version_path = self.dir()
33
32
  fn_path = version_path.parent
34
33
  if invalidated:
35
34
  old_dirs = [path for path in fn_path.iterdir() if path.is_dir() and path != version_path]
@@ -39,8 +38,7 @@ class PickleStorage(Storage):
39
38
  if expired and self.checkpointer.should_expire:
40
39
  count = 0
41
40
  for pkl_path in fn_path.rglob("*.pkl"):
42
- path = pkl_path.with_suffix("")
43
- if self.checkpointer.should_expire(self.checkpoint_date(path)):
41
+ if self.checkpointer.should_expire(self.checkpoint_date(pkl_path.stem)):
44
42
  count += 1
45
- self.delete(path)
43
+ self.delete(pkl_path.stem)
46
44
  print(f"Removed {count} expired checkpoints for {self.checkpoint_fn.__qualname__}")
@@ -14,14 +14,17 @@ class Storage:
14
14
  self.checkpointer = checkpoint_fn.checkpointer
15
15
  self.checkpoint_fn = checkpoint_fn
16
16
 
17
- def store(self, path: Path, data: Any) -> None: ...
17
+ def dir(self) -> Path:
18
+ return self.checkpointer.root_path / self.checkpoint_fn.fn_dir / self.checkpoint_fn.fn_hash
18
19
 
19
- def exists(self, path: Path) -> bool: ...
20
+ def store(self, call_id: str, data: Any) -> None: ...
20
21
 
21
- def checkpoint_date(self, path: Path) -> datetime: ...
22
+ def exists(self, call_id: str) -> bool: ...
22
23
 
23
- def load(self, path: Path) -> Any: ...
24
+ def checkpoint_date(self, call_id: str) -> datetime: ...
24
25
 
25
- def delete(self, path: Path) -> None: ...
26
+ def load(self, call_id: str) -> Any: ...
27
+
28
+ def delete(self, call_id: str) -> None: ...
26
29
 
27
30
  def cleanup(self, invalidated=True, expired=True): ...
@@ -1,3 +1,10 @@
1
+ """
2
+ TODO: Add tests for:
3
+ - Checkpointing with different formats (pickle, memory, etc.)
4
+ - Classes and methods - instances and classes
5
+ - reinit deep depends
6
+ """
7
+
1
8
  import asyncio
2
9
  import pytest
3
10
  from riprint import riprint as print
@@ -142,7 +149,7 @@ def test_depends():
142
149
  assert set(test_a.depends) == {test_a.fn, helper, multiply_wrapper, global_multiply}
143
150
  assert set(test_b.depends) == {test_b.fn, test_a, multiply_wrapper, global_multiply}
144
151
 
145
- def test_lazy_init():
152
+ def test_lazy_init_1():
146
153
  @checkpoint
147
154
  def fn1(x: object) -> object:
148
155
  return fn2(x)
@@ -151,10 +158,21 @@ def test_lazy_init():
151
158
  def fn2(x: object) -> object:
152
159
  return fn1(x)
153
160
 
154
- assert type(object.__getattribute__(fn1, "_getattribute")) is MethodType
155
- with pytest.raises(AttributeError):
156
- object.__getattribute__(fn1, "fn_hash")
157
- assert fn1.fn_hash == object.__getattribute__(fn1, "fn_hash")
158
- assert type(object.__getattribute__(fn1, "_getattribute")) is MethodWrapperType
161
+ assert set(fn1.depends) == {fn1.fn, fn2}
162
+ assert set(fn2.depends) == {fn1, fn2.fn}
163
+
164
+ def test_lazy_init_2():
165
+ @checkpoint
166
+ def fn1(x: object) -> object:
167
+ return fn2(x)
168
+
169
+ assert set(fn1.depends) == {fn1.fn}
170
+
171
+ @checkpoint
172
+ def fn2(x: object) -> object:
173
+ return fn1(x)
174
+
175
+ assert set(fn1.depends) == {fn1.fn}
176
+ fn1.reinit()
159
177
  assert set(fn1.depends) == {fn1.fn, fn2}
160
178
  assert set(fn2.depends) == {fn1, fn2.fn}
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "checkpointer"
3
- version = "2.8.1"
3
+ version = "2.9.0"
4
4
  requires-python = ">=3.11"
5
5
  dependencies = []
6
6
  authors = [
@@ -8,7 +8,7 @@ resolution-markers = [
8
8
 
9
9
  [[package]]
10
10
  name = "checkpointer"
11
- version = "2.8.1"
11
+ version = "2.9.0"
12
12
  source = { editable = "." }
13
13
 
14
14
  [package.dev-dependencies]
@@ -1,39 +0,0 @@
1
- from typing import Any
2
- from pathlib import Path
3
- from datetime import datetime
4
- from .storage import Storage
5
-
6
- item_map: dict[Path, dict[str, tuple[datetime, Any]]] = {}
7
-
8
- def get_short_path(path: Path):
9
- return path.parts[-1]
10
-
11
- class MemoryStorage(Storage):
12
- def get_dict(self):
13
- return item_map.setdefault(self.checkpointer.root_path / self.checkpoint_fn.fn_subdir, {})
14
-
15
- def store(self, path, data):
16
- self.get_dict()[get_short_path(path)] = (datetime.now(), data)
17
-
18
- def exists(self, path):
19
- return get_short_path(path) in self.get_dict()
20
-
21
- def checkpoint_date(self, path):
22
- return self.get_dict()[get_short_path(path)][0]
23
-
24
- def load(self, path):
25
- return self.get_dict()[get_short_path(path)][1]
26
-
27
- def delete(self, path):
28
- del self.get_dict()[get_short_path(path)]
29
-
30
- def cleanup(self, invalidated=True, expired=True):
31
- curr_key = self.checkpointer.root_path / self.checkpoint_fn.fn_subdir
32
- for key, calldict in list(item_map.items()):
33
- if key.parent == curr_key.parent:
34
- if invalidated and key != curr_key:
35
- del item_map[key]
36
- elif expired and self.checkpointer.should_expire:
37
- for callid, (date, _) in list(calldict.items()):
38
- if self.checkpointer.should_expire(date):
39
- del calldict[callid]
File without changes
File without changes