checkpointer 2.1.0.tar.gz → 2.5.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. {checkpointer-2.1.0 → checkpointer-2.5.0}/LICENSE +1 -1
  2. {checkpointer-2.1.0 → checkpointer-2.5.0}/PKG-INFO +18 -6
  3. {checkpointer-2.1.0 → checkpointer-2.5.0}/README.md +14 -2
  4. checkpointer-2.5.0/checkpointer/__init__.py +20 -0
  5. {checkpointer-2.1.0 → checkpointer-2.5.0}/checkpointer/checkpoint.py +62 -25
  6. checkpointer-2.5.0/checkpointer/fn_ident.py +94 -0
  7. checkpointer-2.5.0/checkpointer/object_hash.py +186 -0
  8. {checkpointer-2.1.0 → checkpointer-2.5.0}/checkpointer/storages/__init__.py +1 -1
  9. {checkpointer-2.1.0 → checkpointer-2.5.0}/checkpointer/storages/bcolz_storage.py +6 -7
  10. checkpointer-2.5.0/checkpointer/storages/memory_storage.py +39 -0
  11. checkpointer-2.5.0/checkpointer/storages/pickle_storage.py +45 -0
  12. checkpointer-2.1.0/checkpointer/types.py → checkpointer-2.5.0/checkpointer/storages/storage.py +9 -5
  13. checkpointer-2.5.0/checkpointer/test_checkpointer.py +170 -0
  14. checkpointer-2.5.0/checkpointer/utils.py +112 -0
  15. {checkpointer-2.1.0 → checkpointer-2.5.0}/pyproject.toml +17 -4
  16. checkpointer-2.5.0/uv.lock +529 -0
  17. checkpointer-2.1.0/checkpointer/__init__.py +0 -10
  18. checkpointer-2.1.0/checkpointer/function_body.py +0 -80
  19. checkpointer-2.1.0/checkpointer/storages/memory_storage.py +0 -25
  20. checkpointer-2.1.0/checkpointer/storages/pickle_storage.py +0 -31
  21. checkpointer-2.1.0/checkpointer/utils.py +0 -52
  22. checkpointer-2.1.0/uv.lock +0 -22
  23. {checkpointer-2.1.0 → checkpointer-2.5.0}/.gitignore +0 -0
  24. {checkpointer-2.1.0 → checkpointer-2.5.0}/checkpointer/print_checkpoint.py +0 -0
{checkpointer-2.1.0 → checkpointer-2.5.0}/LICENSE
@@ -1,4 +1,4 @@
- Copyright 2024 Hampus Hallman
+ Copyright 2018-2025 Hampus Hallman

  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

{checkpointer-2.1.0 → checkpointer-2.5.0}/PKG-INFO
@@ -1,18 +1,18 @@
- Metadata-Version: 2.3
+ Metadata-Version: 2.4
  Name: checkpointer
- Version: 2.1.0
+ Version: 2.5.0
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
  Author: Hampus Hallman
- License: Copyright 2024 Hampus Hallman
+ License: Copyright 2018-2025 Hampus Hallman

  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

  The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ License-File: LICENSE
  Requires-Python: >=3.12
- Requires-Dist: relib
  Description-Content-Type: text/markdown

  # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![Python 3.12](https://img.shields.io/badge/python-3.12-blue)](https://pypi.org/project/checkpointer/)
@@ -108,7 +108,7 @@ Layer caches by stacking checkpoints:
  @dev_checkpoint # Adds caching during development
  def some_expensive_function():
    print("Performing a time-consuming operation...")
-   return sum(i * i for i in range(10**6))
+   return sum(i * i for i in range(10**8))
  ```

  - **In development**: Both `dev_checkpoint` and `memory` caches are active.
@@ -153,6 +153,18 @@ Access cached results without recalculating:
  stored_result = expensive_function.get(4)
  ```

+ ### Refresh Function Hash
+
+ When using `capture=True`, changes to captured variables are included in the function's hash to determine cache invalidation. However, `checkpointer` does not automatically update this hash during a running Python session—it recalculates between sessions or when you explicitly refresh it.
+
+ Use the `reinit` method to manually refresh the function's hash within the same session:
+
+ ```python
+ expensive_function.reinit()
+ ```
+
+ This forces `checkpointer` to recalculate the hash of `expensive_function`, considering any changes to captured variables. It's useful when you've modified external variables that the function depends on and want the cache to reflect these changes immediately.
+
  ---

  ## Storage Backends
@@ -189,9 +201,9 @@ from checkpointer import checkpoint, Storage
  from datetime import datetime

  class CustomStorage(Storage):
+   def store(self, path, data): ... # Save the checkpoint data
    def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
    def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
-   def store(self, path, data): ... # Save the checkpoint data
    def load(self, path): ... # Return the checkpoint data
    def delete(self, path): ... # Delete the checkpoint

{checkpointer-2.1.0 → checkpointer-2.5.0}/README.md
@@ -91,7 +91,7 @@ Layer caches by stacking checkpoints:
  @dev_checkpoint # Adds caching during development
  def some_expensive_function():
    print("Performing a time-consuming operation...")
-   return sum(i * i for i in range(10**6))
+   return sum(i * i for i in range(10**8))
  ```

  - **In development**: Both `dev_checkpoint` and `memory` caches are active.
@@ -136,6 +136,18 @@ Access cached results without recalculating:
  stored_result = expensive_function.get(4)
  ```

+ ### Refresh Function Hash
+
+ When using `capture=True`, changes to captured variables are included in the function's hash to determine cache invalidation. However, `checkpointer` does not automatically update this hash during a running Python session—it recalculates between sessions or when you explicitly refresh it.
+
+ Use the `reinit` method to manually refresh the function's hash within the same session:
+
+ ```python
+ expensive_function.reinit()
+ ```
+
+ This forces `checkpointer` to recalculate the hash of `expensive_function`, considering any changes to captured variables. It's useful when you've modified external variables that the function depends on and want the cache to reflect these changes immediately.
+
  ---

  ## Storage Backends
@@ -172,9 +184,9 @@ from checkpointer import checkpoint, Storage
  from datetime import datetime

  class CustomStorage(Storage):
+   def store(self, path, data): ... # Save the checkpoint data
    def exists(self, path) -> bool: ... # Check if a checkpoint exists at the given path
    def checkpoint_date(self, path) -> datetime: ... # Return the date the checkpoint was created
-   def store(self, path, data): ... # Save the checkpoint data
    def load(self, path): ... # Return the checkpoint data
    def delete(self, path): ... # Delete the checkpoint

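The reordered `CustomStorage` stub above spells out the full storage interface. As a rough illustration only (the `Storage` base class lives in `checkpointer/storages/storage.py`, whose diff is not shown here, and `DictStorage` is a hypothetical name), an in-memory implementation of that interface might look like:

```python
from datetime import datetime
from checkpointer import Checkpointer, Storage

class DictStorage(Storage):
  """Hypothetical example backend keeping checkpoints in a plain dict."""
  _items: dict = {}

  def store(self, path, data):
    self._items[path] = (datetime.now(), data)

  def exists(self, path) -> bool:
    return path in self._items

  def checkpoint_date(self, path) -> datetime:
    return self._items[path][0]

  def load(self, path):
    return self._items[path][1]

  def delete(self, path):
    self._items.pop(path, None)

# A Storage subclass can be passed directly as the format,
# matching the STORAGE_MAP fallback in checkpoint.py below.
checkpoint = Checkpointer(format=DictStorage)
```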
checkpointer-2.5.0/checkpointer/__init__.py (new)
@@ -0,0 +1,20 @@
+ import gc
+ import tempfile
+ from typing import Callable
+ from .checkpoint import Checkpointer, CheckpointError, CheckpointFn
+ from .object_hash import ObjectHash
+ from .storages import MemoryStorage, PickleStorage, Storage
+
+ create_checkpointer = Checkpointer
+ checkpoint = Checkpointer()
+ capture_checkpoint = Checkpointer(capture=True)
+ memory_checkpoint = Checkpointer(format="memory", verbosity=0)
+ tmp_checkpoint = Checkpointer(root_path=tempfile.gettempdir() + "/checkpoints")
+
+ def cleanup_all(invalidated=True, expired=True):
+   for obj in gc.get_objects():
+     if isinstance(obj, CheckpointFn):
+       obj.cleanup(invalidated=invalidated, expired=expired)
+
+ def get_function_hash(fn: Callable, capture=False) -> str:
+   return CheckpointFn(Checkpointer(capture=capture), fn).fn_hash
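The new module surface is small enough to exercise directly. A minimal sketch of the 2.5.0 top-level helpers (the `slow_square` and `plain` functions are illustrative, not from the package):

```python
from checkpointer import checkpoint, cleanup_all, get_function_hash

@checkpoint
def slow_square(x: int) -> int:
  return x * x

slow_square(4)  # first call: computed and stored via the default pickle backend
slow_square(4)  # second call: loaded from the cache

def plain(x):
  return x + 1

print(get_function_hash(plain))  # the hash that drives cache invalidation
cleanup_all()  # sweep invalidated/expired checkpoints for every live CheckpointFn
```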
{checkpointer-2.1.0 → checkpointer-2.5.0}/checkpointer/checkpoint.py
@@ -1,15 +1,15 @@
  from __future__ import annotations
  import inspect
- import relib.hashing as hashing
- from typing import Generic, TypeVar, Type, TypedDict, Callable, Unpack, Literal, Any, cast, overload
- from pathlib import Path
+ import re
  from datetime import datetime
  from functools import update_wrapper
- from .types import Storage
- from .function_body import get_function_hash
- from .utils import unwrap_fn, sync_resolve_coroutine, resolved_awaitable
- from .storages import STORAGE_MAP
+ from pathlib import Path
+ from typing import Any, Callable, Generic, Iterable, Literal, Type, TypedDict, TypeVar, Unpack, cast, overload
+ from .fn_ident import get_fn_ident
+ from .object_hash import ObjectHash
  from .print_checkpoint import print_checkpoint
+ from .storages import STORAGE_MAP, Storage
+ from .utils import resolved_awaitable, sync_resolve_coroutine, unwrap_fn

  Fn = TypeVar("Fn", bound=Callable)

@@ -50,22 +50,47 @@ class Checkpointer:

  class CheckpointFn(Generic[Fn]):
    def __init__(self, checkpointer: Checkpointer, fn: Fn):
-     wrapped = unwrap_fn(fn)
-     file_name = Path(wrapped.__code__.co_filename).name
-     update_wrapper(cast(Callable, self), wrapped)
-     storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
      self.checkpointer = checkpointer
      self.fn = fn
-     self.fn_hash, self.depends = get_function_hash(wrapped, self.checkpointer.capture)
-     self.fn_id = f"{file_name}/{wrapped.__name__}"
+
+   def _set_ident(self, force=False):
+     if not hasattr(self, "fn_hash_raw") or force:
+       self.fn_hash_raw, self.depends = get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
+     return self
+
+   def _lazyinit(self):
+     wrapped = unwrap_fn(self.fn)
+     fn_file = Path(wrapped.__code__.co_filename).name
+     fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
+     update_wrapper(cast(Callable, self), wrapped)
+     store_format = self.checkpointer.format
+     Storage = STORAGE_MAP[store_format] if isinstance(store_format, str) else store_format
+     deep_hashes = [child._set_ident().fn_hash_raw for child in iterate_checkpoint_fns(self)]
+     self.fn_hash = str(ObjectHash().update_hash(self.fn_hash_raw, iter=deep_hashes))
+     self.fn_subdir = f"{fn_file}/{fn_name}/{self.fn_hash[:16]}"
      self.is_async = inspect.iscoroutinefunction(wrapped)
-     self.storage = storage(checkpointer)
+     self.storage = Storage(self)
+     self.cleanup = self.storage.cleanup
+
+   def __getattribute__(self, name: str) -> Any:
+     return object.__getattribute__(self, "_getattribute")(name)
+
+   def _getattribute(self, name: str) -> Any:
+     setattr(self, "_getattribute", super().__getattribute__)
+     self._lazyinit()
+     return self._getattribute(name)
+
+   def reinit(self, recursive=False):
+     pointfns = list(iterate_checkpoint_fns(self)) if recursive else [self]
+     for pointfn in pointfns:
+       pointfn._set_ident(True)
+     for pointfn in pointfns:
+       pointfn._lazyinit()

    def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
      if not callable(self.checkpointer.path):
-       # TODO: use digest size before digesting instead of truncating the hash
-       call_hash = hashing.hash((self.fn_hash, args, kw), "blake2b")[:32]
-       return f"{self.fn_id}/{call_hash}"
+       call_hash = ObjectHash(self.fn_hash, args, kw, digest_size=16)
+       return f"{self.fn_subdir}/{call_hash}"
      checkpoint_id = self.checkpointer.path(*args, **kw)
      if not isinstance(checkpoint_id, str):
        raise CheckpointError(f"path function must return a string, got {type(checkpoint_id)}")
@@ -74,13 +99,13 @@ class CheckpointFn(Generic[Fn]):
    async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
      checkpoint_id = self.get_checkpoint_id(args, kw)
      checkpoint_path = self.checkpointer.root_path / checkpoint_id
-     should_log = self.checkpointer.verbosity > 0
+     verbose = self.checkpointer.verbosity > 0
      refresh = rerun \
        or not self.storage.exists(checkpoint_path) \
        or (self.checkpointer.should_expire and self.checkpointer.should_expire(self.storage.checkpoint_date(checkpoint_path)))

      if refresh:
-       print_checkpoint(should_log, "MEMORIZING", checkpoint_id, "blue")
+       print_checkpoint(verbose, "MEMORIZING", checkpoint_id, "blue")
        data = self.fn(*args, **kw)
        if inspect.iscoroutine(data):
          data = await data
@@ -89,12 +114,12 @@ class CheckpointFn(Generic[Fn]):

      try:
        data = self.storage.load(checkpoint_path)
-       print_checkpoint(should_log, "REMEMBERED", checkpoint_id, "green")
+       print_checkpoint(verbose, "REMEMBERED", checkpoint_id, "green")
        return data
      except (EOFError, FileNotFoundError):
-       print_checkpoint(should_log, "CORRUPTED", checkpoint_id, "yellow")
-       self.storage.delete(checkpoint_path)
-       return await self._store_on_demand(args, kw, rerun)
+       pass
+     print_checkpoint(verbose, "CORRUPTED", checkpoint_id, "yellow")
+     return await self._store_on_demand(args, kw, True)

    def _call(self, args: tuple, kw: dict, rerun=False):
      if not self.checkpointer.when:
@@ -107,8 +132,8 @@ class CheckpointFn(Generic[Fn]):
      try:
        val = self.storage.load(checkpoint_path)
        return resolved_awaitable(val) if self.is_async else val
-     except:
-       raise CheckpointError("Could not load checkpoint")
+     except Exception as ex:
+       raise CheckpointError("Could not load checkpoint") from ex

    def exists(self, *args: tuple, **kw: dict) -> bool:
      return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
@@ -116,3 +141,15 @@ class CheckpointFn(Generic[Fn]):
    __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
    rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
    get: Fn = cast(Fn, lambda self, *args, **kw: self._get(args, kw))
+
+   def __repr__(self) -> str:
+     return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
+
+ def iterate_checkpoint_fns(pointfn: CheckpointFn, visited: set[CheckpointFn] = set()) -> Iterable[CheckpointFn]:
+   visited = visited or set()
+   if pointfn not in visited:
+     yield pointfn
+     visited.add(pointfn)
+     for depend in pointfn.depends:
+       if isinstance(depend, CheckpointFn):
+         yield from iterate_checkpoint_fns(depend, visited)
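The new `reinit` ties into the lazy `_set_ident`/`_lazyinit` pair above: it re-derives `fn_hash` (recursively over dependent checkpoints, if requested) without recreating the decorated function. A sketch of the intended use with `capture=True`, assuming an illustrative captured global:

```python
from checkpointer import capture_checkpoint

multiplier = 3

@capture_checkpoint
def scale(x: int) -> int:
  return x * multiplier  # `multiplier` is captured into the hash

scale(2)                      # MEMORIZING: computed and stored
multiplier = 4                # changing a captured value mid-session...
scale.reinit()                # ...takes effect only after an explicit refresh
scale(2)                      # MEMORIZING again under the new fn_hash
scale.reinit(recursive=True)  # also refresh any dependent CheckpointFns
```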
checkpointer-2.5.0/checkpointer/fn_ident.py (new)
@@ -0,0 +1,94 @@
+ import dis
+ import inspect
+ from collections.abc import Callable
+ from itertools import takewhile
+ from pathlib import Path
+ from types import CodeType, FunctionType, MethodType
+ from typing import Any, Generator, Type, TypeGuard
+ from .object_hash import ObjectHash
+ from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, transpose, unwrap_fn
+
+ cwd = Path.cwd()
+
+ def is_class(obj) -> TypeGuard[Type]:
+   # isinstance works too, but needlessly triggers __getattribute__
+   return issubclass(type(obj), type)
+
+ def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[str, Type]]:
+   attr_path: tuple[str, ...] = ()
+   scope_obj = None
+   classvars: dict[str, dict[str, Type]] = {}
+   for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
+     if instr.opname in scope_vars and not attr_path:
+       attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
+       attr_path = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
+     elif instr.opname == "CALL":
+       obj = scope_vars.get_at(attr_path)
+       attr_path = ()
+       if is_class(obj):
+         scope_obj = obj
+     elif instr.opname in ("STORE_FAST", "STORE_DEREF") and scope_obj:
+       load_key = instr.opname.replace("STORE", "LOAD")
+       classvars.setdefault(load_key, {})[instr.argval] = scope_obj
+       scope_obj = None
+   return classvars
+
+ def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Generator[tuple[tuple[str, ...], Any], None, None]:
+   classvars = extract_classvars(code, scope_vars)
+   scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
+   for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
+     if instr.opname in scope_vars:
+       attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
+       attr_path: tuple[str, ...] = (instr.opname, instr.argval, *(str(x.argval) for x in attrs))
+       val = scope_vars.get_at(attr_path)
+       if val is not None:
+         yield attr_path, val
+   for const in code.co_consts:
+     if isinstance(const, CodeType):
+       yield from extract_scope_values(const, scope_vars)
+
+ def get_self_value(fn: Callable) -> type | object | None:
+   if isinstance(fn, MethodType):
+     return fn.__self__
+   parts = tuple(fn.__qualname__.split(".")[:-1])
+   cls = parts and AttrDict(fn.__globals__).get_at(parts)
+   if is_class(cls):
+     return cls
+
+ def get_fn_captured_vals(fn: Callable) -> list[Any]:
+   self_value = get_self_value(fn)
+   scope_vars = AttrDict({
+     "LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
+     "LOAD_DEREF": AttrDict(get_cell_contents(fn)),
+     "LOAD_GLOBAL": AttrDict(fn.__globals__),
+   })
+   vals = dict(extract_scope_values(fn.__code__, scope_vars))
+   return list(vals.values())
+
+ def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
+   if not isinstance(candidate_fn, (FunctionType, MethodType)):
+     return False
+   fn_path = Path(inspect.getfile(candidate_fn)).resolve()
+   return cwd in fn_path.parents and ".venv" not in fn_path.parts
+
+ def get_depend_fns(fn: Callable, capture: bool, captured_vals_by_fn: dict[Callable, list[Any]] = {}) -> dict[Callable, list[Any]]:
+   from .checkpoint import CheckpointFn
+   captured_vals_by_fn = captured_vals_by_fn or {}
+   captured_vals = get_fn_captured_vals(fn)
+   captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)] * capture
+   child_fns = (unwrap_fn(val, checkpoint_fn=True) for val in captured_vals if callable(val))
+   for child_fn in child_fns:
+     if isinstance(child_fn, CheckpointFn):
+       captured_vals_by_fn[child_fn] = []
+     elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
+       get_depend_fns(child_fn, capture, captured_vals_by_fn)
+   return captured_vals_by_fn
+
+ def get_fn_ident(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
+   from .checkpoint import CheckpointFn
+   captured_vals_by_fn = get_depend_fns(fn, capture)
+   depends, depend_captured_vals = transpose(captured_vals_by_fn.items(), 2)
+   depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
+   unwrapped_depends = [fn for fn in depends if not isinstance(fn, CheckpointFn)]
+   fn_hash = str(ObjectHash(fn, unwrapped_depends).update(depend_captured_vals, tolerate_errors=True))
+   return fn_hash, depends
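`get_fn_ident` walks the bytecode for captured globals, closure cells, and user-defined child functions, so a function's identity shifts when any of those change. A sketch of the observable effect through the public `get_function_hash` helper (the `factor`/`scaled` names are illustrative):

```python
from checkpointer import get_function_hash

factor = 2

def scaled(x):
  return x * factor  # `factor` is read via LOAD_GLOBAL, so it is captured

h1 = get_function_hash(scaled, capture=True)
factor = 3
h2 = get_function_hash(scaled, capture=True)
assert h1 != h2  # with capture=True the captured value feeds the hash
```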
checkpointer-2.5.0/checkpointer/object_hash.py (new)
@@ -0,0 +1,186 @@
+ import ctypes
+ import hashlib
+ import io
+ import re
+ from collections.abc import Iterable
+ from contextlib import nullcontext
+ from decimal import Decimal
+ from itertools import chain
+ from pickle import HIGHEST_PROTOCOL as PROTOCOL
+ from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
+ from typing import Any, TypeAliasType, TypeVar
+ from .utils import ContextVar, get_fn_body
+
+ try:
+   import numpy as np
+ except:
+   np = None
+ try:
+   import torch
+ except:
+   torch = None
+
+ def encode_type(t: type | FunctionType) -> str:
+   return f"{t.__module__}:{t.__qualname__}"
+
+ def encode_val(v: Any) -> str:
+   return encode_type(type(v))
+
+ class ObjectHashError(Exception):
+   def __init__(self, obj: Any, cause: Exception):
+     super().__init__(f"{type(cause).__name__} error when hashing {obj}")
+     self.obj = obj
+
+ class ObjectHash:
+   def __init__(self, *obj: Any, iter: Iterable[Any] = [], digest_size=64, tolerate_errors=False) -> None:
+     self.hash = hashlib.blake2b(digest_size=digest_size)
+     self.current: dict[int, int] = {}
+     self.tolerate_errors = ContextVar(tolerate_errors)
+     self.update(iter=chain(obj, iter))
+
+   def copy(self) -> "ObjectHash":
+     new = ObjectHash(tolerate_errors=self.tolerate_errors.value)
+     new.hash = self.hash.copy()
+     return new
+
+   def hexdigest(self) -> str:
+     return self.hash.hexdigest()
+
+   __str__ = hexdigest
+
+   def update_hash(self, *data: bytes | str, iter: Iterable[bytes | str] = []) -> "ObjectHash":
+     for d in chain(data, iter):
+       self.hash.update(d.encode() if isinstance(d, str) else d)
+     return self
+
+   def header(self, *args: Any) -> "ObjectHash":
+     return self.update_hash(":".join(map(str, args)))
+
+   def update(self, *objs: Any, iter: Iterable[Any] = [], tolerate_errors: bool | None=None) -> "ObjectHash":
+     with nullcontext() if tolerate_errors is None else self.tolerate_errors.set(tolerate_errors):
+       for obj in chain(objs, iter):
+         try:
+           self._update_one(obj)
+         except Exception as ex:
+           if self.tolerate_errors.value:
+             self.header("error").update(type(ex))
+             continue
+           raise ObjectHashError(obj, ex) from ex
+     return self
+
+   def _update_one(self, obj: Any) -> None:
+     match obj:
+       case None:
+         self.header("null")
+
+       case bool() | int() | float() | complex() | Decimal() | ObjectHash():
+         self.header("number", encode_val(obj), obj)
+
+       case str() | bytes() | bytearray() | memoryview():
+         self.header("bytes", encode_val(obj), len(obj)).update_hash(obj)
+
+       case set() | frozenset():
+         self.header("set", encode_val(obj), len(obj))
+         try:
+           items = sorted(obj)
+         except:
+           self.header("unsortable")
+           items = sorted(str(ObjectHash(item, tolerate_errors=self.tolerate_errors.value)) for item in obj)
+         self.update(iter=items)
+
+       case TypeVar():
+         self.header("TypeVar").update(obj.__name__, obj.__bound__, obj.__constraints__, obj.__contravariant__, obj.__covariant__)
+
+       case TypeAliasType():
+         self.header("TypeAliasType").update(obj.__name__, obj.__value__)
+
+       case UnionType():
+         self.header("UnionType").update(obj.__args__)
+
+       case BuiltinFunctionType():
+         self.header("builtin", obj.__qualname__)
+
+       case FunctionType():
+         self.header("function", encode_type(obj)).update(get_fn_body(obj), obj.__defaults__, obj.__kwdefaults__, obj.__annotations__)
+
+       case MethodType():
+         self.header("method").update(obj.__func__, obj.__self__.__class__)
+
+       case ModuleType():
+         self.header("module", obj.__name__, obj.__file__)
+
+       case GeneratorType():
+         self.header("generator", obj.__qualname__)._update_iterator(obj)
+
+       case io.TextIOWrapper() | io.FileIO() | io.BufferedRandom() | io.BufferedWriter() | io.BufferedReader():
+         self.header("file", encode_val(obj)).update(obj.name, obj.mode, obj.tell())
+
+       case type():
+         self.header("type", encode_type(obj))
+
+       case _ if np and isinstance(obj, np.dtype):
+         self.header("dtype").update(obj.__class__, obj.descr)
+
+       case _ if np and isinstance(obj, np.ndarray):
+         self.header("ndarray", encode_val(obj), obj.shape, obj.strides).update(obj.dtype)
+         if obj.dtype.hasobject:
+           self.update(obj.__reduce_ex__(PROTOCOL))
+         else:
+           array = np.ascontiguousarray(obj if obj.base is None else obj.base).view(np.uint8)
+           self.update_hash(array.data)
+
+       case _ if torch and isinstance(obj, torch.Tensor):
+         self.header("tensor", encode_val(obj), str(obj.dtype), tuple(obj.shape), obj.stride(), str(obj.device))
+         if obj.device.type != "cpu":
+           obj = obj.cpu()
+         storage = obj.storage()
+         buffer = (ctypes.c_ubyte * (storage.nbytes())).from_address(storage.data_ptr())
+         self.update_hash(memoryview(buffer))
+
+       case _ if id(obj) in self.current:
+         self.header("circular", self.current[id(obj)])
+
+       case _:
+         try:
+           self.current[id(obj)] = len(self.current)
+           match obj:
+             case list() | tuple():
+               self.header("list", encode_val(obj), len(obj)).update(iter=obj)
+             case dict():
+               try:
+                 items = sorted(obj.items())
+               except:
+                 items = sorted((str(ObjectHash(key, tolerate_errors=self.tolerate_errors.value)), val) for key, val in obj.items())
+               self.header("dict", encode_val(obj), len(obj)).update(iter=chain.from_iterable(items))
+             case _:
+               self._update_object(obj)
+         finally:
+           del self.current[id(obj)]
+
+   def _update_iterator(self, obj: Iterable) -> None:
+     self.header("iterator", encode_val(obj)).update(iter=obj).header(b"iterator-end")
+
+   def _update_object(self, obj: object) -> "ObjectHash":
+     self.header("instance", encode_val(obj))
+     try:
+       reduced = obj.__reduce_ex__(PROTOCOL) if hasattr(obj, "__reduce_ex__") else obj.__reduce__()
+     except:
+       reduced = None
+     if isinstance(reduced, str):
+       return self.header("reduce-str").update(reduced)
+     if reduced:
+       reduced = list(reduced)
+       it = reduced.pop(3) if len(reduced) >= 4 else None
+       self.header("reduce").update(reduced)
+       if it is not None:
+         self._update_iterator(it)
+       return self
+     if state := hasattr(obj, "__getstate__") and obj.__getstate__():
+       return self.header("getstate").update(state)
+     if len(getattr(obj, "__slots__", [])):
+       slots = {slot: getattr(obj, slot, None) for slot in getattr(obj, "__slots__")}
+       return self.header("slots").update(slots)
+     if d := getattr(obj, "__dict__", {}):
+       return self.header("dict").update(d)
+     repr_str = re.sub(r"\s+(at\s+0x[0-9a-fA-F]+)(>)$", r"\2", repr(obj))
+     return self.header("repr").update(repr_str)
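`ObjectHash` replaces the old `relib.hashing` dependency: it hashes by structure rather than object identity, which is what makes it usable for cache keys. A quick sketch of the semantics implied by the code above:

```python
from checkpointer import ObjectHash  # re-exported in __init__.py above

h1 = str(ObjectHash({"a": [1, 2]}, digest_size=16))
h2 = str(ObjectHash({"a": [1, 2]}, digest_size=16))
assert h1 == h2  # equal values give equal digests across distinct objects

h3 = str(ObjectHash({"a": [1, 2, 3]}, digest_size=16))
assert h1 != h3  # any structural change moves the digest

# digest_size is passed straight to blake2b, so call sites like
# get_checkpoint_id can pick a short digest up front instead of
# truncating afterwards (the old TODO in checkpoint.py).
```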
{checkpointer-2.1.0 → checkpointer-2.5.0}/checkpointer/storages/__init__.py
@@ -1,5 +1,5 @@
  from typing import Type
- from ..types import Storage
+ from .storage import Storage
  from .pickle_storage import PickleStorage
  from .memory_storage import MemoryStorage
  from .bcolz_storage import BcolzStorage
{checkpointer-2.1.0 → checkpointer-2.5.0}/checkpointer/storages/bcolz_storage.py
@@ -1,7 +1,7 @@
  import shutil
  from pathlib import Path
  from datetime import datetime
- from ..types import Storage
+ from .storage import Storage

  def get_data_type_str(x):
    if isinstance(x, tuple):
@@ -73,9 +73,8 @@ class BcolzStorage(Storage):

    def delete(self, path):
      # NOTE: Not recursive
-     metapath = get_metapath(path)
-     try:
-       shutil.rmtree(metapath)
-       shutil.rmtree(path)
-     except FileNotFoundError:
-       pass
+     shutil.rmtree(get_metapath(path), ignore_errors=True)
+     shutil.rmtree(path, ignore_errors=True)
+
+   def cleanup(self, invalidated=True, expired=True):
+     raise NotImplementedError("cleanup() not implemented for bcolz storage")
checkpointer-2.5.0/checkpointer/storages/memory_storage.py (new)
@@ -0,0 +1,39 @@
+ from typing import Any
+ from pathlib import Path
+ from datetime import datetime
+ from .storage import Storage
+
+ item_map: dict[Path, dict[str, tuple[datetime, Any]]] = {}
+
+ def get_short_path(path: Path):
+   return path.parts[-1]
+
+ class MemoryStorage(Storage):
+   def get_dict(self):
+     return item_map.setdefault(self.checkpointer.root_path / self.checkpoint_fn.fn_subdir, {})
+
+   def store(self, path, data):
+     self.get_dict()[get_short_path(path)] = (datetime.now(), data)
+
+   def exists(self, path):
+     return get_short_path(path) in self.get_dict()
+
+   def checkpoint_date(self, path):
+     return self.get_dict()[get_short_path(path)][0]
+
+   def load(self, path):
+     return self.get_dict()[get_short_path(path)][1]
+
+   def delete(self, path):
+     del self.get_dict()[get_short_path(path)]
+
+   def cleanup(self, invalidated=True, expired=True):
+     curr_key = self.checkpointer.root_path / self.checkpoint_fn.fn_subdir
+     for key, calldict in list(item_map.items()):
+       if key.parent == curr_key.parent:
+         if invalidated and key != curr_key:
+           del item_map[key]
+         elif expired and self.checkpointer.should_expire:
+           for callid, (date, _) in list(calldict.items()):
+             if self.checkpointer.should_expire(date):
+               del calldict[callid]
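`MemoryStorage` keeps everything in the process-wide `item_map`, keyed by the checkpoint's `fn_subdir`, so caches vanish when the interpreter exits. A minimal sketch using the `memory_checkpoint` preset from `__init__.py` above (`fib` is an illustrative function):

```python
from checkpointer import memory_checkpoint

@memory_checkpoint
def fib(n: int) -> int:
  # each distinct n is memoized in-process; nothing is written to disk
  return n if n < 2 else fib(n - 1) + fib(n - 2)

print(fib(30))
```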
checkpointer-2.5.0/checkpointer/storages/pickle_storage.py (new)
@@ -0,0 +1,45 @@
+ import pickle
+ import shutil
+ from pathlib import Path
+ from datetime import datetime
+ from .storage import Storage
+
+ def get_path(path: Path):
+   return path.with_name(f"{path.name}.pkl")
+
+ class PickleStorage(Storage):
+   def store(self, path, data):
+     full_path = get_path(path)
+     full_path.parent.mkdir(parents=True, exist_ok=True)
+     with full_path.open("wb") as file:
+       pickle.dump(data, file, -1)
+
+   def exists(self, path):
+     return get_path(path).exists()
+
+   def checkpoint_date(self, path):
+     return datetime.fromtimestamp(get_path(path).stat().st_mtime)
+
+   def load(self, path):
+     with get_path(path).open("rb") as file:
+       return pickle.load(file)
+
+   def delete(self, path):
+     get_path(path).unlink(missing_ok=True)
+
+   def cleanup(self, invalidated=True, expired=True):
+     version_path = self.checkpointer.root_path.resolve() / self.checkpoint_fn.fn_subdir
+     fn_path = version_path.parent
+     if invalidated:
+       old_dirs = [path for path in fn_path.iterdir() if path.is_dir() and path != version_path]
+       for path in old_dirs:
+         shutil.rmtree(path)
+       print(f"Removed {len(old_dirs)} invalidated directories for {self.checkpoint_fn.__qualname__}")
+     if expired and self.checkpointer.should_expire:
+       count = 0
+       for pkl_path in fn_path.rglob("*.pkl"):
+         path = pkl_path.with_suffix("")
+         if self.checkpointer.should_expire(self.checkpoint_date(path)):
+           count += 1
+           self.delete(path)
+       print(f"Removed {count} expired checkpoints for {self.checkpoint_fn.__qualname__}")