checkpointer 2.10.1__tar.gz → 2.11.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,4 @@
1
1
  __pycache__/
2
2
  /dist/
3
3
  .DS_Store
4
+ /*.py
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpointer
3
- Version: 2.10.1
3
+ Version: 2.11.1
4
4
  Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
5
5
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
6
  Author: Hampus Hallman
@@ -121,11 +121,8 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
121
121
  * **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
122
122
  A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
123
123
 
124
- * **`hash_by`** (Type: `Callable[..., Any]`, Default: `None`)\
125
- A custom callable that takes the function's arguments (`*args`, `**kwargs`) and returns a hashable object (or tuple of objects). This allows for custom argument normalization (e.g., sorting lists before hashing) or optimized hashing for complex input types, which can improve cache hit rates or speed up the hashing process.
126
-
127
- * **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)\
128
- An optional parameter that takes an instance of `checkpointer.ObjectHash`. This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any values relevant to your invalidation logic to `ObjectHash` (e.g., `ObjectHash(version_string, config_id, ...)`, as it can consistently hash most Python values.
124
+ * **`fn_hash_from`** (Type: `Any`, Default: `None`)\
125
+ This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any object relevant to your invalidation logic (e.g., version strings, config parameters). The object you provide will be hashed internally by `checkpointer`.
129
126
 
130
127
  * **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
131
128
  Controls the level of logging output from `checkpointer`.
@@ -133,31 +130,63 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
133
130
  * `1`: Shows when functions are computed and cached.
134
131
  * `2`: Also shows when cached results are remembered (loaded from cache).
135
132
 
136
- ### 🗄️ Custom Storage Backends
133
+ ## 🔬 Customize Argument Hashing
134
+
135
+ You can customize how individual function arguments are hashed without changing their actual values when passed in.
136
+
137
+ * **`Annotated[T, HashBy[fn]]`**:\
138
+ Hashes the argument by applying `fn(argument)` before hashing. This enables custom normalization (e.g., sorting lists to ignore order) or optimized hashing for complex types, improving cache hit rates or speeding up hashing.
139
+
140
+ * **`NoHash[T]`**:\
141
+ Completely excludes the argument from hashing, so changes to it won’t trigger cache invalidation.
142
+
143
+ **Example:**
144
+
145
+ ```python
146
+ from typing import Annotated
147
+ from checkpointer import checkpoint, HashBy, NoHash
148
+ from pathlib import Path
149
+ import logging
150
+
151
+ def file_bytes(path: Path) -> bytes:
152
+ return path.read_bytes()
153
+
154
+ @checkpoint
155
+ def process(
156
+ numbers: Annotated[list[int], HashBy[sorted]], # Hash by sorted list
157
+ data_file: Annotated[Path, HashBy[file_bytes]], # Hash by file content
158
+ log: NoHash[logging.Logger], # Exclude logger from hashing
159
+ ):
160
+ ...
161
+ ```
162
+
163
+ In this example, the cache key for `numbers` ignores order, `data_file` is hashed based on its contents rather than path, and changes to `log` don’t affect caching.
164
+
165
+ ## 🗄️ Custom Storage Backends
137
166
 
138
167
  For integration with databases, cloud storage, or custom serialization, implement your own storage backend by inheriting from `checkpointer.Storage` and implementing its abstract methods.
139
168
 
140
- Within custom storage methods, `call_id` identifies calls by arguments. Use `self.fn_id()` to get the function's unique identity (name + hash/version), crucial for organizing stored checkpoints (e.g., by function version). Access global `Checkpointer` config via `self.checkpointer`.
169
+ Within custom storage methods, `call_hash` identifies calls by arguments. Use `self.fn_id()` to get the function's unique identity (name + hash/version), crucial for organizing stored checkpoints (e.g., by function version). Access global `Checkpointer` config via `self.checkpointer`.
141
170
 
142
- #### Example: Custom Storage Backend
171
+ **Example: Custom Storage Backend**
143
172
 
144
173
  ```python
145
174
  from checkpointer import checkpoint, Storage
146
175
  from datetime import datetime
147
176
 
148
177
  class MyCustomStorage(Storage):
149
- def exists(self, call_id):
178
+ def exists(self, call_hash):
150
179
  # Example: Constructing a path based on function ID and call ID
151
180
  fn_dir = self.checkpointer.root_path / self.fn_id()
152
- return (fn_dir / call_id).exists()
181
+ return (fn_dir / call_hash).exists()
153
182
 
154
- def store(self, call_id, data):
155
- ... # Store the serialized data for `call_id`
183
+ def store(self, call_hash, data):
184
+ ... # Store the serialized data for `call_hash`
156
185
  return data # This method must return the data back to checkpointer
157
186
 
158
- def checkpoint_date(self, call_id): ...
159
- def load(self, call_id): ...
160
- def delete(self, call_id): ...
187
+ def checkpoint_date(self, call_hash): ...
188
+ def load(self, call_hash): ...
189
+ def delete(self, call_hash): ...
161
190
 
162
191
  @checkpoint(format=MyCustomStorage)
163
192
  def custom_cached_function(x: int):
@@ -101,11 +101,8 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
101
101
  * **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
102
102
  A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
103
103
 
104
- * **`hash_by`** (Type: `Callable[..., Any]`, Default: `None`)\
105
- A custom callable that takes the function's arguments (`*args`, `**kwargs`) and returns a hashable object (or tuple of objects). This allows for custom argument normalization (e.g., sorting lists before hashing) or optimized hashing for complex input types, which can improve cache hit rates or speed up the hashing process.
106
-
107
- * **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)\
108
- An optional parameter that takes an instance of `checkpointer.ObjectHash`. This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any values relevant to your invalidation logic to `ObjectHash` (e.g., `ObjectHash(version_string, config_id, ...)`, as it can consistently hash most Python values.
104
+ * **`fn_hash_from`** (Type: `Any`, Default: `None`)\
105
+ This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any object relevant to your invalidation logic (e.g., version strings, config parameters). The object you provide will be hashed internally by `checkpointer`.
109
106
 
110
107
  * **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
111
108
  Controls the level of logging output from `checkpointer`.
@@ -113,31 +110,63 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
113
110
  * `1`: Shows when functions are computed and cached.
114
111
  * `2`: Also shows when cached results are remembered (loaded from cache).
115
112
 
116
- ### 🗄️ Custom Storage Backends
113
+ ## 🔬 Customize Argument Hashing
114
+
115
+ You can customize how individual function arguments are hashed without changing their actual values when passed in.
116
+
117
+ * **`Annotated[T, HashBy[fn]]`**:\
118
+ Hashes the argument by applying `fn(argument)` before hashing. This enables custom normalization (e.g., sorting lists to ignore order) or optimized hashing for complex types, improving cache hit rates or speeding up hashing.
119
+
120
+ * **`NoHash[T]`**:\
121
+ Completely excludes the argument from hashing, so changes to it won’t trigger cache invalidation.
122
+
123
+ **Example:**
124
+
125
+ ```python
126
+ from typing import Annotated
127
+ from checkpointer import checkpoint, HashBy, NoHash
128
+ from pathlib import Path
129
+ import logging
130
+
131
+ def file_bytes(path: Path) -> bytes:
132
+ return path.read_bytes()
133
+
134
+ @checkpoint
135
+ def process(
136
+ numbers: Annotated[list[int], HashBy[sorted]], # Hash by sorted list
137
+ data_file: Annotated[Path, HashBy[file_bytes]], # Hash by file content
138
+ log: NoHash[logging.Logger], # Exclude logger from hashing
139
+ ):
140
+ ...
141
+ ```
142
+
143
+ In this example, the cache key for `numbers` ignores order, `data_file` is hashed based on its contents rather than path, and changes to `log` don’t affect caching.
144
+
145
+ ## 🗄️ Custom Storage Backends
117
146
 
118
147
  For integration with databases, cloud storage, or custom serialization, implement your own storage backend by inheriting from `checkpointer.Storage` and implementing its abstract methods.
119
148
 
120
- Within custom storage methods, `call_id` identifies calls by arguments. Use `self.fn_id()` to get the function's unique identity (name + hash/version), crucial for organizing stored checkpoints (e.g., by function version). Access global `Checkpointer` config via `self.checkpointer`.
149
+ Within custom storage methods, `call_hash` identifies calls by arguments. Use `self.fn_id()` to get the function's unique identity (name + hash/version), crucial for organizing stored checkpoints (e.g., by function version). Access global `Checkpointer` config via `self.checkpointer`.
121
150
 
122
- #### Example: Custom Storage Backend
151
+ **Example: Custom Storage Backend**
123
152
 
124
153
  ```python
125
154
  from checkpointer import checkpoint, Storage
126
155
  from datetime import datetime
127
156
 
128
157
  class MyCustomStorage(Storage):
129
- def exists(self, call_id):
158
+ def exists(self, call_hash):
130
159
  # Example: Constructing a path based on function ID and call ID
131
160
  fn_dir = self.checkpointer.root_path / self.fn_id()
132
- return (fn_dir / call_id).exists()
161
+ return (fn_dir / call_hash).exists()
133
162
 
134
- def store(self, call_id, data):
135
- ... # Store the serialized data for `call_id`
163
+ def store(self, call_hash, data):
164
+ ... # Store the serialized data for `call_hash`
136
165
  return data # This method must return the data back to checkpointer
137
166
 
138
- def checkpoint_date(self, call_id): ...
139
- def load(self, call_id): ...
140
- def delete(self, call_id): ...
167
+ def checkpoint_date(self, call_hash): ...
168
+ def load(self, call_hash): ...
169
+ def delete(self, call_hash): ...
141
170
 
142
171
  @checkpoint(format=MyCustomStorage)
143
172
  def custom_cached_function(x: int):
@@ -4,18 +4,18 @@ from typing import Callable
4
4
  from .checkpoint import CachedFunction, Checkpointer, CheckpointError
5
5
  from .object_hash import ObjectHash
6
6
  from .storages import MemoryStorage, PickleStorage, Storage
7
- from .utils import AwaitableValue
7
+ from .types import AwaitableValue, HashBy, NoHash
8
8
 
9
9
  checkpoint = Checkpointer()
10
10
  capture_checkpoint = Checkpointer(capture=True)
11
11
  memory_checkpoint = Checkpointer(format="memory", verbosity=0)
12
12
  tmp_checkpoint = Checkpointer(root_path=f"{tempfile.gettempdir()}/checkpoints")
13
- static_checkpoint = Checkpointer(fn_hash=ObjectHash())
13
+ static_checkpoint = Checkpointer(fn_hash_from=())
14
14
 
15
15
  def cleanup_all(invalidated=True, expired=True):
16
16
  for obj in gc.get_objects():
17
17
  if isinstance(obj, CachedFunction):
18
18
  obj.cleanup(invalidated=invalidated, expired=expired)
19
19
 
20
- def get_function_hash(fn: Callable, capture=False) -> str:
21
- return CachedFunction(Checkpointer(capture=capture), fn).fn_hash
20
+ def get_function_hash(fn: Callable) -> str:
21
+ return CachedFunction(Checkpointer(), fn).ident.fn_hash
@@ -0,0 +1,236 @@
1
+ from __future__ import annotations
2
+ import re
3
+ from datetime import datetime
4
+ from functools import cached_property, update_wrapper
5
+ from inspect import Parameter, Signature, iscoroutine, signature
6
+ from pathlib import Path
7
+ from typing import (
8
+ Annotated, Callable, Concatenate, Coroutine, Generic,
9
+ Iterable, Literal, ParamSpec, Self, Type, TypedDict,
10
+ TypeVar, Unpack, cast, get_args, get_origin, overload,
11
+ )
12
+ from .fn_ident import RawFunctionIdent, get_fn_ident
13
+ from .object_hash import ObjectHash
14
+ from .print_checkpoint import print_checkpoint
15
+ from .storages import STORAGE_MAP, Storage
16
+ from .types import AwaitableValue, HashBy
17
+ from .utils import unwrap_fn
18
+
19
+ Fn = TypeVar("Fn", bound=Callable)
20
+ P = ParamSpec("P")
21
+ R = TypeVar("R")
22
+ C = TypeVar("C")
23
+
24
+ DEFAULT_DIR = Path.home() / ".cache/checkpoints"
25
+
26
+ class CheckpointError(Exception):
27
+ pass
28
+
29
+ class CheckpointerOpts(TypedDict, total=False):
30
+ format: Type[Storage] | Literal["pickle", "memory", "bcolz"]
31
+ root_path: Path | str | None
32
+ when: bool
33
+ verbosity: Literal[0, 1, 2]
34
+ should_expire: Callable[[datetime], bool] | None
35
+ capture: bool
36
+ fn_hash_from: object
37
+
38
+ class Checkpointer:
39
+ def __init__(self, **opts: Unpack[CheckpointerOpts]):
40
+ self.format = opts.get("format", "pickle")
41
+ self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
42
+ self.when = opts.get("when", True)
43
+ self.verbosity = opts.get("verbosity", 1)
44
+ self.should_expire = opts.get("should_expire")
45
+ self.capture = opts.get("capture", False)
46
+ self.fn_hash_from = opts.get("fn_hash_from")
47
+
48
+ @overload
49
+ def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CachedFunction[Fn]: ...
50
+ @overload
51
+ def __call__(self, fn: None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer: ...
52
+ def __call__(self, fn: Fn | None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer | CachedFunction[Fn]:
53
+ if override_opts:
54
+ opts = CheckpointerOpts(**{**self.__dict__, **override_opts})
55
+ return Checkpointer(**opts)(fn)
56
+
57
+ return CachedFunction(self, fn) if callable(fn) else self
58
+
59
+ class FunctionIdent:
60
+ """
61
+ Represents the identity and hash state of a cached function.
62
+ Separated from CachedFunction to prevent hash desynchronization
63
+ among bound instances when `.reinit()` is called.
64
+ """
65
+ def __init__(self, cached_fn: CachedFunction):
66
+ self.__dict__.clear()
67
+ self.cached_fn = cached_fn
68
+
69
+ @cached_property
70
+ def raw_ident(self) -> RawFunctionIdent:
71
+ return get_fn_ident(unwrap_fn(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
72
+
73
+ @cached_property
74
+ def fn_hash(self) -> str:
75
+ if self.cached_fn.checkpointer.fn_hash_from is not None:
76
+ return str(ObjectHash(self.cached_fn.checkpointer.fn_hash_from, digest_size=16))
77
+ deep_hashes = [depend.ident.raw_ident.fn_hash for depend in self.cached_fn.deep_depends()]
78
+ return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
79
+
80
+ @cached_property
81
+ def captured_hash(self) -> str:
82
+ deep_hashes = [depend.ident.raw_ident.captured_hash for depend in self.cached_fn.deep_depends()]
83
+ return str(ObjectHash().write_text(iter=deep_hashes))
84
+
85
+ def clear(self):
86
+ self.__init__(self.cached_fn)
87
+
88
+ class CachedFunction(Generic[Fn]):
89
+ def __init__(self, checkpointer: Checkpointer, fn: Fn):
90
+ wrapped = unwrap_fn(fn)
91
+ fn_file = Path(wrapped.__code__.co_filename).name
92
+ fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
93
+ Storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
94
+ update_wrapper(cast(Callable, self), wrapped)
95
+ self.checkpointer = checkpointer
96
+ self.fn = fn
97
+ self.fn_dir = f"{fn_file}/{fn_name}"
98
+ self.storage = Storage(self)
99
+ self.cleanup = self.storage.cleanup
100
+ self.bound = ()
101
+
102
+ sig = signature(wrapped)
103
+ params = list(sig.parameters.items())
104
+ pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
105
+ self.arg_names = [name for name, param in params if param.kind in pos_params]
106
+ self.default_args = {name: param.default for name, param in params if param.default is not Parameter.empty}
107
+ self.hash_by_map = get_hash_by_map(sig)
108
+ self.ident = FunctionIdent(self)
109
+
110
+ @overload
111
+ def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
112
+ @overload
113
+ def __get__(self: CachedFunction[Callable[Concatenate[C, P], R]], instance: C, owner: Type[C]) -> CachedFunction[Callable[P, R]]: ...
114
+ def __get__(self, instance, owner):
115
+ if instance is None:
116
+ return self
117
+ bound_fn = object.__new__(CachedFunction)
118
+ bound_fn.__dict__ |= self.__dict__
119
+ bound_fn.bound = (instance,)
120
+ return bound_fn
121
+
122
+ @property
123
+ def depends(self) -> list[Callable]:
124
+ return self.ident.raw_ident.depends
125
+
126
+ def reinit(self, recursive=False) -> CachedFunction[Fn]:
127
+ depend_idents = [depend.ident for depend in self.deep_depends()] if recursive else [self.ident]
128
+ for ident in depend_idents: ident.clear()
129
+ for ident in depend_idents: ident.fn_hash
130
+ return self
131
+
132
+ def get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
133
+ args = self.bound + args
134
+ pos_args = args[len(self.arg_names):]
135
+ named_pos_args = dict(zip(self.arg_names, args))
136
+ named_args = {**self.default_args, **named_pos_args, **kw}
137
+ if hash_by_map := self.hash_by_map:
138
+ rest_hash_by = hash_by_map.get(b"**")
139
+ for key, value in named_args.items():
140
+ if hash_by := hash_by_map.get(key, rest_hash_by):
141
+ named_args[key] = hash_by(value)
142
+ if pos_hash_by := hash_by_map.get(b"*"):
143
+ pos_args = tuple(map(pos_hash_by, pos_args))
144
+ return str(ObjectHash(named_args, pos_args, self.ident.captured_hash, digest_size=16))
145
+
146
+ async def _resolve_coroutine(self, call_hash: str, coroutine: Coroutine):
147
+ return self.storage.store(call_hash, AwaitableValue(await coroutine)).value
148
+
149
+ def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
150
+ full_args = self.bound + args
151
+ params = self.checkpointer
152
+ if not params.when:
153
+ return self.fn(*full_args, **kw)
154
+
155
+ call_hash = self.get_call_hash(args, kw)
156
+ call_hash_long = f"{self.fn_dir}/{self.ident.fn_hash}/{call_hash}"
157
+
158
+ refresh = rerun \
159
+ or not self.storage.exists(call_hash) \
160
+ or (params.should_expire and params.should_expire(self.storage.checkpoint_date(call_hash)))
161
+
162
+ if refresh:
163
+ print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_hash_long, "blue")
164
+ data = self.fn(*full_args, **kw)
165
+ if iscoroutine(data):
166
+ return self._resolve_coroutine(call_hash, data)
167
+ return self.storage.store(call_hash, data)
168
+
169
+ try:
170
+ data = self.storage.load(call_hash)
171
+ print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_hash_long, "green")
172
+ return data
173
+ except (EOFError, FileNotFoundError):
174
+ pass
175
+ print_checkpoint(params.verbosity >= 1, "CORRUPTED", call_hash_long, "yellow")
176
+ return self._call(args, kw, True)
177
+
178
+ def __call__(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
179
+ return self._call(args, kw)
180
+
181
+ def rerun(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
182
+ return self._call(args, kw, True)
183
+
184
+ def exists(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> bool:
185
+ return self.storage.exists(self.get_call_hash(args, kw))
186
+
187
+ def delete(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs):
188
+ self.storage.delete(self.get_call_hash(args, kw))
189
+
190
+ @overload
191
+ def get(self: Callable[P, Coroutine[object, object, R]], *args: P.args, **kw: P.kwargs) -> R: ...
192
+ @overload
193
+ def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
194
+ def get(self, *args, **kw):
195
+ call_hash = self.get_call_hash(args, kw)
196
+ try:
197
+ data = self.storage.load(call_hash)
198
+ return data.value if isinstance(data, AwaitableValue) else data
199
+ except Exception as ex:
200
+ raise CheckpointError("Could not load checkpoint") from ex
201
+
202
+ @overload
203
+ def _set(self: Callable[P, Coroutine[object, object, R]], value: AwaitableValue[R], *args: P.args, **kw: P.kwargs): ...
204
+ @overload
205
+ def _set(self: Callable[P, R], value: R, *args: P.args, **kw: P.kwargs): ...
206
+ def _set(self, value, *args, **kw):
207
+ self.storage.store(self.get_call_hash(args, kw), value)
208
+
209
+ def __repr__(self) -> str:
210
+ return f"<CachedFunction {self.fn.__name__} {self.ident.fn_hash[:6]}>"
211
+
212
+ def deep_depends(self, visited: set[CachedFunction] = set()) -> Iterable[CachedFunction]:
213
+ if self not in visited:
214
+ yield self
215
+ visited = visited or set()
216
+ visited.add(self)
217
+ for depend in self.depends:
218
+ if isinstance(depend, CachedFunction):
219
+ yield from depend.deep_depends(visited)
220
+
221
+ def hash_by_from_annotation(annotation: type) -> Callable[[object], object] | None:
222
+ if get_origin(annotation) is Annotated:
223
+ args = get_args(annotation)
224
+ metadata = args[1] if len(args) > 1 else None
225
+ if get_origin(metadata) is HashBy:
226
+ return get_args(metadata)[0]
227
+
228
+ def get_hash_by_map(sig: Signature) -> dict[str | bytes, Callable[[object], object]]:
229
+ hash_by_map = {}
230
+ for name, param in sig.parameters.items():
231
+ if param.kind == Parameter.VAR_POSITIONAL:
232
+ name = b"*"
233
+ elif param.kind == Parameter.VAR_KEYWORD:
234
+ name = b"**"
235
+ hash_by_map[name] = hash_by_from_annotation(param.annotation)
236
+ return hash_by_map if any(hash_by_map.values()) else {}
@@ -1,15 +1,19 @@
1
1
  import dis
2
2
  import inspect
3
- from collections.abc import Callable
4
3
  from itertools import takewhile
5
4
  from pathlib import Path
6
5
  from types import CodeType, FunctionType, MethodType
7
- from typing import Any, Iterable, Type, TypeGuard
6
+ from typing import Callable, Iterable, NamedTuple, Type, TypeGuard
8
7
  from .object_hash import ObjectHash
9
- from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, transpose, unwrap_fn
8
+ from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, unwrap_fn
10
9
 
11
10
  cwd = Path.cwd().resolve()
12
11
 
12
+ class RawFunctionIdent(NamedTuple):
13
+ fn_hash: str
14
+ captured_hash: str
15
+ depends: list[Callable]
16
+
13
17
  def is_class(obj) -> TypeGuard[Type]:
14
18
  # isinstance works too, but needlessly triggers _lazyinit()
15
19
  return issubclass(type(obj), type)
@@ -33,7 +37,7 @@ def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[st
33
37
  scope_obj = None
34
38
  return classvars
35
39
 
36
- def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...], Any]]:
40
+ def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...], object]]:
37
41
  classvars = extract_classvars(code, scope_vars)
38
42
  scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
39
43
  for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
@@ -55,7 +59,7 @@ def get_self_value(fn: Callable) -> type | object | None:
55
59
  if is_class(cls):
56
60
  return cls
57
61
 
58
- def get_fn_captured_vals(fn: Callable) -> list[Any]:
62
+ def get_fn_captured_vals(fn: Callable) -> list[object]:
59
63
  self_value = get_self_value(fn)
60
64
  scope_vars = AttrDict({
61
65
  "LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
@@ -71,24 +75,29 @@ def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
71
75
  fn_path = Path(inspect.getfile(candidate_fn)).resolve()
72
76
  return cwd in fn_path.parents and ".venv" not in fn_path.parts
73
77
 
74
- def get_depend_fns(fn: Callable, capture: bool, captured_vals_by_fn: dict[Callable, list[Any]] = {}) -> dict[Callable, list[Any]]:
78
+ def get_depend_fns(fn: Callable, captured_vals_by_fn: dict[Callable, list[object]] = {}) -> dict[Callable, list[object]]:
75
79
  from .checkpoint import CachedFunction
76
- captured_vals_by_fn = captured_vals_by_fn or {}
77
80
  captured_vals = get_fn_captured_vals(fn)
78
- captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)] * capture
79
- child_fns = (unwrap_fn(val, cached_fn=True) for val in captured_vals if callable(val))
80
- for child_fn in child_fns:
81
+ captured_vals_by_fn = captured_vals_by_fn or {}
82
+ captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)]
83
+ for val in captured_vals:
84
+ if not callable(val):
85
+ continue
86
+ child_fn = unwrap_fn(val, cached_fn=True)
81
87
  if isinstance(child_fn, CachedFunction):
82
88
  captured_vals_by_fn[child_fn] = []
83
89
  elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
84
- get_depend_fns(child_fn, capture, captured_vals_by_fn)
90
+ get_depend_fns(child_fn, captured_vals_by_fn)
85
91
  return captured_vals_by_fn
86
92
 
87
- def get_fn_ident(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
93
+ def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
88
94
  from .checkpoint import CachedFunction
89
- captured_vals_by_fn = get_depend_fns(fn, capture)
90
- depends, depend_captured_vals = transpose(captured_vals_by_fn.items(), 2)
95
+ captured_vals_by_fn = get_depend_fns(fn)
96
+ depend_captured_vals = list(captured_vals_by_fn.values()) * capture
97
+ depends = captured_vals_by_fn.keys()
91
98
  depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
92
99
  unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
93
- fn_hash = str(ObjectHash(fn, unwrapped_depends).update(depend_captured_vals, tolerate_errors=True))
94
- return fn_hash, depends
100
+ assert fn == unwrapped_depends[0]
101
+ fn_hash = str(ObjectHash(iter=unwrapped_depends))
102
+ captured_hash = str(ObjectHash(iter=depend_captured_vals, tolerate_errors=True))
103
+ return RawFunctionIdent(fn_hash, captured_hash, depends)
@@ -1,16 +1,19 @@
1
1
  import ctypes
2
2
  import hashlib
3
+ import inspect
3
4
  import io
4
5
  import re
5
6
  import sys
7
+ import tokenize
6
8
  from collections.abc import Iterable
7
9
  from contextlib import nullcontext, suppress
8
10
  from decimal import Decimal
11
+ from io import StringIO
9
12
  from itertools import chain
10
- from pickle import HIGHEST_PROTOCOL as PROTOCOL
13
+ from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
11
14
  from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
12
- from typing import Any, TypeVar
13
- from .utils import ContextVar, get_fn_body
15
+ from typing import Callable, TypeVar
16
+ from .utils import ContextVar
14
17
 
15
18
  np, torch = None, None
16
19
 
@@ -31,16 +34,16 @@ else:
31
34
  def encode_type(t: type | FunctionType) -> str:
32
35
  return f"{t.__module__}:{t.__qualname__}"
33
36
 
34
- def encode_type_of(v: Any) -> str:
37
+ def encode_type_of(v: object) -> str:
35
38
  return encode_type(type(v))
36
39
 
37
40
  class ObjectHashError(Exception):
38
- def __init__(self, obj: Any, cause: Exception):
41
+ def __init__(self, obj: object, cause: Exception):
39
42
  super().__init__(f"{type(cause).__name__} error when hashing {obj}")
40
43
  self.obj = obj
41
44
 
42
45
  class ObjectHash:
43
- def __init__(self, *objs: Any, iter: Iterable[Any] = (), digest_size=64, tolerate_errors=False) -> None:
46
+ def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerate_errors=False) -> None:
44
47
  self.hash = hashlib.blake2b(digest_size=digest_size)
45
48
  self.current: dict[int, int] = {}
46
49
  self.tolerate_errors = ContextVar(tolerate_errors)
@@ -59,7 +62,7 @@ class ObjectHash:
59
62
  def __eq__(self, value: object) -> bool:
60
63
  return isinstance(value, ObjectHash) and str(self) == str(value)
61
64
 
62
- def nested_hash(self, *objs: Any) -> str:
65
+ def nested_hash(self, *objs: object) -> str:
63
66
  return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()
64
67
 
65
68
  def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> "ObjectHash":
@@ -70,10 +73,10 @@ class ObjectHash:
70
73
  def write_text(self, *data: str, iter: Iterable[str] = ()) -> "ObjectHash":
71
74
  return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))
72
75
 
73
- def header(self, *args: Any) -> "ObjectHash":
76
+ def header(self, *args: object) -> "ObjectHash":
74
77
  return self.write_bytes(":".join(map(str, args)).encode())
75
78
 
76
- def update(self, *objs: Any, iter: Iterable[Any] = (), tolerate_errors: bool | None=None) -> "ObjectHash":
79
+ def update(self, *objs: object, iter: Iterable[object] = (), tolerate_errors: bool | None=None) -> "ObjectHash":
77
80
  with nullcontext() if tolerate_errors is None else self.tolerate_errors.set(tolerate_errors):
78
81
  for obj in chain(objs, iter):
79
82
  try:
@@ -81,11 +84,11 @@ class ObjectHash:
81
84
  except Exception as ex:
82
85
  if self.tolerate_errors.value:
83
86
  self.header("error").update(type(ex))
84
- continue
85
- raise ObjectHashError(obj, ex) from ex
87
+ else:
88
+ raise ObjectHashError(obj, ex) from ex
86
89
  return self
87
90
 
88
- def _update_one(self, obj: Any) -> None:
91
+ def _update_one(self, obj: object) -> None:
89
92
  match obj:
90
93
  case None:
91
94
  self.header("null")
@@ -142,7 +145,7 @@ class ObjectHash:
142
145
  case _ if np and isinstance(obj, np.ndarray):
143
146
  self.header("ndarray", encode_type_of(obj), obj.shape, obj.strides).update(obj.dtype)
144
147
  if obj.dtype.hasobject:
145
- self.update(obj.__reduce_ex__(PROTOCOL))
148
+ self.update(obj.__reduce_ex__(PICKLE_PROTOCOL))
146
149
  else:
147
150
  array = np.ascontiguousarray(obj if obj.base is None else obj.base).view(np.uint8)
148
151
  self.write_bytes(array.data)
@@ -180,13 +183,14 @@ class ObjectHash:
180
183
  def _update_iterator(self, obj: Iterable) -> "ObjectHash":
181
184
  return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")
182
185
 
183
- def _update_object(self, obj: Any) -> "ObjectHash":
186
+ def _update_object(self, obj: object) -> "ObjectHash":
184
187
  self.header("instance", encode_type_of(obj))
185
- if hasattr(obj, "__objecthash__") and callable(obj.__objecthash__):
186
- return self.header("objecthash").update(obj.__objecthash__())
188
+ get_hash = hasattr(obj, "__objecthash__") and getattr(obj, "__objecthash__")
189
+ if callable(get_hash):
190
+ return self.header("objecthash").update(get_hash())
187
191
  reduced = None
188
192
  with suppress(Exception):
189
- reduced = obj.__reduce_ex__(PROTOCOL)
193
+ reduced = obj.__reduce_ex__(PICKLE_PROTOCOL)
190
194
  with suppress(Exception):
191
195
  reduced = reduced or obj.__reduce__()
192
196
  if isinstance(reduced, str):
@@ -206,3 +210,12 @@ class ObjectHash:
206
210
  return self._update_iterator(obj)
207
211
  repr_str = re.sub(r"\s+(at\s+0x[0-9a-fA-F]+)(>)$", r"\2", repr(obj))
208
212
  return self.header("repr").update(repr_str)
213
+
214
+ def get_fn_body(fn: Callable) -> str:
215
+ try:
216
+ source = inspect.getsource(fn)
217
+ except OSError:
218
+ return ""
219
+ tokens = tokenize.generate_tokens(StringIO(source).readline)
220
+ ignore_types = (tokenize.COMMENT, tokenize.NL)
221
+ return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
@@ -9,21 +9,21 @@ class MemoryStorage(Storage):
9
9
  def get_dict(self):
10
10
  return item_map.setdefault(self.fn_dir(), {})
11
11
 
12
- def store(self, call_id, data):
13
- self.get_dict()[call_id] = (datetime.now(), data)
12
+ def store(self, call_hash, data):
13
+ self.get_dict()[call_hash] = (datetime.now(), data)
14
14
  return data
15
15
 
16
- def exists(self, call_id):
17
- return call_id in self.get_dict()
16
+ def exists(self, call_hash):
17
+ return call_hash in self.get_dict()
18
18
 
19
- def checkpoint_date(self, call_id):
20
- return self.get_dict()[call_id][0]
19
+ def checkpoint_date(self, call_hash):
20
+ return self.get_dict()[call_hash][0]
21
21
 
22
- def load(self, call_id):
23
- return self.get_dict()[call_id][1]
22
+ def load(self, call_hash):
23
+ return self.get_dict()[call_hash][1]
24
24
 
25
- def delete(self, call_id):
26
- self.get_dict().pop(call_id, None)
25
+ def delete(self, call_hash):
26
+ self.get_dict().pop(call_hash, None)
27
27
 
28
28
  def cleanup(self, invalidated=True, expired=True):
29
29
  curr_key = self.fn_dir()
@@ -32,6 +32,6 @@ class MemoryStorage(Storage):
32
32
  if invalidated and key != curr_key:
33
33
  del item_map[key]
34
34
  elif expired and self.checkpointer.should_expire:
35
- for call_id, (date, _) in list(calldict.items()):
35
+ for call_hash, (date, _) in list(calldict.items()):
36
36
  if self.checkpointer.should_expire(date):
37
- del calldict[call_id]
37
+ del calldict[call_hash]
@@ -8,29 +8,29 @@ def filedate(path: Path) -> datetime:
8
8
  return datetime.fromtimestamp(path.stat().st_mtime)
9
9
 
10
10
  class PickleStorage(Storage):
11
- def get_path(self, call_id: str):
12
- return self.fn_dir() / f"{call_id}.pkl"
11
+ def get_path(self, call_hash: str):
12
+ return self.fn_dir() / f"{call_hash}.pkl"
13
13
 
14
- def store(self, call_id, data):
15
- path = self.get_path(call_id)
14
+ def store(self, call_hash, data):
15
+ path = self.get_path(call_hash)
16
16
  path.parent.mkdir(parents=True, exist_ok=True)
17
17
  with path.open("wb") as file:
18
18
  pickle.dump(data, file, -1)
19
19
  return data
20
20
 
21
- def exists(self, call_id):
22
- return self.get_path(call_id).exists()
21
+ def exists(self, call_hash):
22
+ return self.get_path(call_hash).exists()
23
23
 
24
- def checkpoint_date(self, call_id):
24
+ def checkpoint_date(self, call_hash):
25
25
  # Should use st_atime/access time?
26
- return filedate(self.get_path(call_id))
26
+ return filedate(self.get_path(call_hash))
27
27
 
28
- def load(self, call_id):
29
- with self.get_path(call_id).open("rb") as file:
28
+ def load(self, call_hash):
29
+ with self.get_path(call_hash).open("rb") as file:
30
30
  return pickle.load(file)
31
31
 
32
- def delete(self, call_id):
33
- self.get_path(call_id).unlink(missing_ok=True)
32
+ def delete(self, call_hash):
33
+ self.get_path(call_hash).unlink(missing_ok=True)
34
34
 
35
35
  def cleanup(self, invalidated=True, expired=True):
36
36
  version_path = self.fn_dir()
@@ -15,19 +15,19 @@ class Storage:
15
15
  self.cached_fn = cached_fn
16
16
 
17
17
  def fn_id(self) -> str:
18
- return f"{self.cached_fn.fn_dir}/{self.cached_fn.fn_hash}"
18
+ return f"{self.cached_fn.fn_dir}/{self.cached_fn.ident.fn_hash}"
19
19
 
20
20
  def fn_dir(self) -> Path:
21
21
  return self.checkpointer.root_path / self.fn_id()
22
22
 
23
- def store(self, call_id: str, data: Any) -> Any: ...
23
+ def store(self, call_hash: str, data: Any) -> Any: ...
24
24
 
25
- def exists(self, call_id: str) -> bool: ...
25
+ def exists(self, call_hash: str) -> bool: ...
26
26
 
27
- def checkpoint_date(self, call_id: str) -> datetime: ...
27
+ def checkpoint_date(self, call_hash: str) -> datetime: ...
28
28
 
29
- def load(self, call_id: str) -> Any: ...
29
+ def load(self, call_hash: str) -> Any: ...
30
30
 
31
- def delete(self, call_id: str) -> None: ...
31
+ def delete(self, call_hash: str) -> None: ...
32
32
 
33
33
  def cleanup(self, invalidated=True, expired=True): ...
@@ -1,6 +1,6 @@
1
1
  import asyncio
2
2
  import pytest
3
- from . import CheckpointError, checkpoint
3
+ from checkpointer import CheckpointError, checkpoint
4
4
  from .utils import AttrDict
5
5
 
6
6
  def global_multiply(a: int, b: int) -> int:
@@ -110,16 +110,16 @@ def test_capture():
110
110
  def test_a():
111
111
  return item_dict.a + 1
112
112
 
113
- init_hash_a = test_a.fn_hash
114
- init_hash_whole = test_whole.fn_hash
113
+ init_hash_a = test_a.ident.captured_hash
114
+ init_hash_whole = test_whole.ident.captured_hash
115
115
  item_dict.b += 1
116
116
  test_whole.reinit()
117
117
  test_a.reinit()
118
- assert test_whole.fn_hash != init_hash_whole
119
- assert test_a.fn_hash == init_hash_a
118
+ assert test_whole.ident.captured_hash != init_hash_whole
119
+ assert test_a.ident.captured_hash == init_hash_a
120
120
  item_dict.a += 1
121
121
  test_a.reinit()
122
- assert test_a.fn_hash != init_hash_a
122
+ assert test_a.ident.captured_hash != init_hash_a
123
123
 
124
124
  def test_depends():
125
125
  def multiply_wrapper(a: int, b: int) -> int:
@@ -0,0 +1,17 @@
1
+ from typing import Annotated, Callable, Generic, TypeVar
2
+
3
+ T = TypeVar("T")
4
+ Fn = TypeVar("Fn", bound=Callable)
5
+
6
+ class HashBy(Generic[Fn]):
7
+ pass
8
+
9
+ NoHash = Annotated[T, HashBy[lambda _: None]]
10
+
11
+ class AwaitableValue(Generic[T]):
12
+ def __init__(self, value: T):
13
+ self.value = value
14
+
15
+ def __await__(self):
16
+ yield
17
+ return self.value
@@ -1,7 +1,4 @@
1
- import inspect
2
- import tokenize
3
1
  from contextlib import contextmanager
4
- from io import StringIO
5
2
  from typing import Any, Callable, Generic, Iterable, TypeVar, cast
6
3
 
7
4
  T = TypeVar("T")
@@ -10,21 +7,6 @@ Fn = TypeVar("Fn", bound=Callable)
10
7
  def distinct(seq: Iterable[T]) -> list[T]:
11
8
  return list(dict.fromkeys(seq))
12
9
 
13
- def transpose(tuples, default_num_returns=0):
14
- output = tuple(zip(*tuples))
15
- if not output:
16
- return ([],) * default_num_returns
17
- return tuple(map(list, output))
18
-
19
- def get_fn_body(fn: Callable) -> str:
20
- try:
21
- source = inspect.getsource(fn)
22
- except OSError:
23
- return ""
24
- tokens = tokenize.generate_tokens(StringIO(source).readline)
25
- ignore_types = (tokenize.COMMENT, tokenize.NL)
26
- return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
27
-
28
10
  def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
29
11
  for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
30
12
  try:
@@ -39,14 +21,6 @@ def unwrap_fn(fn: Fn, cached_fn=False) -> Fn:
39
21
  return cast(Fn, fn)
40
22
  fn = getattr(fn, "__wrapped__")
41
23
 
42
- class AwaitableValue:
43
- def __init__(self, value):
44
- self.value = value
45
-
46
- def __await__(self):
47
- yield
48
- return self.value
49
-
50
24
  class AttrDict(dict):
51
25
  def __init__(self, *args, **kwargs):
52
26
  super().__init__(*args, **kwargs)
@@ -91,14 +65,19 @@ class iterate_and_upcoming(Generic[T]):
91
65
  def __init__(self, it: Iterable[T]) -> None:
92
66
  self.it = iter(it)
93
67
  self.previous: tuple[()] | tuple[T] = ()
68
+ self.tracked = self._tracked_iter()
94
69
 
95
70
  def __iter__(self):
96
71
  return self
97
72
 
98
73
  def __next__(self) -> tuple[T, Iterable[T]]:
99
- item = self.previous[0] if self.previous else next(self.it)
100
- self.previous = ()
101
- return item, self._tracked_iter()
74
+ try:
75
+ item = self.previous[0] if self.previous else next(self.it)
76
+ self.previous = ()
77
+ return item, self.tracked
78
+ except StopIteration:
79
+ self.tracked.close()
80
+ raise
102
81
 
103
82
  def _tracked_iter(self):
104
83
  for x in self.it:
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "checkpointer"
3
- version = "2.10.1"
3
+ version = "2.11.1"
4
4
  requires-python = ">=3.11"
5
5
  dependencies = []
6
6
  authors = [
@@ -21,10 +21,11 @@ Repository = "https://github.com/Reddan/checkpointer.git"
21
21
  [dependency-groups]
22
22
  dev = [
23
23
  "numpy>=2.2.1",
24
- "omg>=1.3.4",
24
+ "omg>=1.3.6",
25
25
  "poethepoet>=0.30.0",
26
26
  "pytest>=8.3.5",
27
27
  "pytest-asyncio>=0.26.0",
28
+ "rich>=14.0.0",
28
29
  "torch>=2.5.1",
29
30
  ]
30
31
 
@@ -8,7 +8,7 @@ resolution-markers = [
8
8
 
9
9
  [[package]]
10
10
  name = "checkpointer"
11
- version = "2.10.1"
11
+ version = "2.11.1"
12
12
  source = { editable = "." }
13
13
 
14
14
  [package.dev-dependencies]
@@ -18,6 +18,7 @@ dev = [
18
18
  { name = "poethepoet" },
19
19
  { name = "pytest" },
20
20
  { name = "pytest-asyncio" },
21
+ { name = "rich" },
21
22
  { name = "torch" },
22
23
  ]
23
24
 
@@ -26,10 +27,11 @@ dev = [
26
27
  [package.metadata.requires-dev]
27
28
  dev = [
28
29
  { name = "numpy", specifier = ">=2.2.1" },
29
- { name = "omg", specifier = ">=1.3.4" },
30
+ { name = "omg", specifier = ">=1.3.6" },
30
31
  { name = "poethepoet", specifier = ">=0.30.0" },
31
32
  { name = "pytest", specifier = ">=8.3.5" },
32
33
  { name = "pytest-asyncio", specifier = ">=0.26.0" },
34
+ { name = "rich", specifier = ">=14.0.0" },
33
35
  { name = "torch", specifier = ">=2.5.1" },
34
36
  ]
35
37
 
@@ -81,6 +83,18 @@ wheels = [
81
83
  { url = "https://files.pythonhosted.org/packages/bd/0f/2ba5fbcd631e3e88689309dbe978c5769e883e4b84ebfe7da30b43275c5a/jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb", size = 134596 },
82
84
  ]
83
85
 
86
+ [[package]]
87
+ name = "markdown-it-py"
88
+ version = "3.0.0"
89
+ source = { registry = "https://pypi.org/simple" }
90
+ dependencies = [
91
+ { name = "mdurl" },
92
+ ]
93
+ sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 }
94
+ wheels = [
95
+ { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 },
96
+ ]
97
+
84
98
  [[package]]
85
99
  name = "markupsafe"
86
100
  version = "3.0.2"
@@ -129,6 +143,15 @@ wheels = [
129
143
  { url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739 },
130
144
  ]
131
145
 
146
+ [[package]]
147
+ name = "mdurl"
148
+ version = "0.1.2"
149
+ source = { registry = "https://pypi.org/simple" }
150
+ sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 }
151
+ wheels = [
152
+ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
153
+ ]
154
+
132
155
  [[package]]
133
156
  name = "mpmath"
134
157
  version = "1.3.0"
@@ -307,14 +330,14 @@ wheels = [
307
330
 
308
331
  [[package]]
309
332
  name = "omg"
310
- version = "1.3.4"
333
+ version = "1.3.6"
311
334
  source = { registry = "https://pypi.org/simple" }
312
335
  dependencies = [
313
336
  { name = "watchdog" },
314
337
  ]
315
- sdist = { url = "https://files.pythonhosted.org/packages/0b/12/9037613ad41ad71f6342e18a7fdf8f19826dfe61555c45541d87045cd084/omg-1.3.4.tar.gz", hash = "sha256:31e5fd82c2d2fdc53e944b1cb9fb7f630919f9441b8822fb5ae108962fabb49b", size = 14193 }
338
+ sdist = { url = "https://files.pythonhosted.org/packages/65/06/da0a3778b7ff8f1333ed7ddc0931ffff3c86ab5cb8bc4a96a1d0edb8671b/omg-1.3.6.tar.gz", hash = "sha256:465a51b7576fa31ef313e2b9a77d57f5d4816fb0a14dca0fc5c09ff471074fe6", size = 14268 }
316
339
  wheels = [
317
- { url = "https://files.pythonhosted.org/packages/fd/88/4a4e1cc054f141744a60e17d76c8ea1aea7af7e125cae0713d9b0d5ec12f/omg-1.3.4-py3-none-any.whl", hash = "sha256:449e00d341b63afa23633f47cd751a1874ad2545162e5f9abbdd115525fd7a71", size = 7926 },
340
+ { url = "https://files.pythonhosted.org/packages/dd/d2/87346e94dbecd3a65a09e2156c1adf30c162f31e69d0936343c3eff53e7a/omg-1.3.6-py3-none-any.whl", hash = "sha256:8e3ac99a18d5284ceef2ed98492d288d5f22ee2bb417591654a7d2433e196607", size = 7988 },
318
341
  ]
319
342
 
320
343
  [[package]]
@@ -357,6 +380,15 @@ wheels = [
357
380
  { url = "https://files.pythonhosted.org/packages/06/e1/04f56c9d848d6135ca3328c5a2ca84d3303c358ad7828db290385e36a8cc/poethepoet-0.31.1-py3-none-any.whl", hash = "sha256:7fdfa0ac6074be9936723e7231b5bfaad2923e96c674a9857e81d326cf8ccdc2", size = 80238 },
358
381
  ]
359
382
 
383
+ [[package]]
384
+ name = "pygments"
385
+ version = "2.19.1"
386
+ source = { registry = "https://pypi.org/simple" }
387
+ sdist = { url = "https://files.pythonhosted.org/packages/7c/2d/c3338d48ea6cc0feb8446d8e6937e1408088a72a39937982cc6111d17f84/pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f", size = 4968581 }
388
+ wheels = [
389
+ { url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293 },
390
+ ]
391
+
360
392
  [[package]]
361
393
  name = "pytest"
362
394
  version = "8.3.5"
@@ -419,6 +451,19 @@ wheels = [
419
451
  { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 },
420
452
  ]
421
453
 
454
+ [[package]]
455
+ name = "rich"
456
+ version = "14.0.0"
457
+ source = { registry = "https://pypi.org/simple" }
458
+ dependencies = [
459
+ { name = "markdown-it-py" },
460
+ { name = "pygments" },
461
+ ]
462
+ sdist = { url = "https://files.pythonhosted.org/packages/a1/53/830aa4c3066a8ab0ae9a9955976fb770fe9c6102117c8ec4ab3ea62d89e8/rich-14.0.0.tar.gz", hash = "sha256:82f1bc23a6a21ebca4ae0c45af9bdbc492ed20231dcb63f297d6d1021a9d5725", size = 224078 }
463
+ wheels = [
464
+ { url = "https://files.pythonhosted.org/packages/0d/9b/63f4c7ebc259242c89b3acafdb37b41d1185c07ff0011164674e9076b491/rich-14.0.0-py3-none-any.whl", hash = "sha256:1c9491e1951aac09caffd42f448ee3d04e58923ffe14993f6e83068dc395d7e0", size = 243229 },
465
+ ]
466
+
422
467
  [[package]]
423
468
  name = "setuptools"
424
469
  version = "75.6.0"
@@ -1,184 +0,0 @@
1
- from __future__ import annotations
2
- import inspect
3
- import re
4
- from datetime import datetime
5
- from functools import cached_property, update_wrapper
6
- from pathlib import Path
7
- from typing import (
8
- Awaitable, Callable, Concatenate, Generic, Iterable, Literal,
9
- ParamSpec, Self, Type, TypedDict, TypeVar, Unpack, cast, overload,
10
- )
11
- from .fn_ident import get_fn_ident
12
- from .object_hash import ObjectHash
13
- from .print_checkpoint import print_checkpoint
14
- from .storages import STORAGE_MAP, Storage
15
- from .utils import AwaitableValue, unwrap_fn
16
-
17
- Fn = TypeVar("Fn", bound=Callable)
18
- P = ParamSpec("P")
19
- R = TypeVar("R")
20
- C = TypeVar("C")
21
-
22
- DEFAULT_DIR = Path.home() / ".cache/checkpoints"
23
-
24
- class CheckpointError(Exception):
25
- pass
26
-
27
- class CheckpointerOpts(TypedDict, total=False):
28
- format: Type[Storage] | Literal["pickle", "memory", "bcolz"]
29
- root_path: Path | str | None
30
- when: bool
31
- verbosity: Literal[0, 1, 2]
32
- hash_by: Callable | None
33
- should_expire: Callable[[datetime], bool] | None
34
- capture: bool
35
- fn_hash: ObjectHash | None
36
-
37
- class Checkpointer:
38
- def __init__(self, **opts: Unpack[CheckpointerOpts]):
39
- self.format = opts.get("format", "pickle")
40
- self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
41
- self.when = opts.get("when", True)
42
- self.verbosity = opts.get("verbosity", 1)
43
- self.hash_by = opts.get("hash_by")
44
- self.should_expire = opts.get("should_expire")
45
- self.capture = opts.get("capture", False)
46
- self.fn_hash = opts.get("fn_hash")
47
-
48
- @overload
49
- def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CachedFunction[Fn]: ...
50
- @overload
51
- def __call__(self, fn: None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer: ...
52
- def __call__(self, fn: Fn | None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer | CachedFunction[Fn]:
53
- if override_opts:
54
- opts = CheckpointerOpts(**{**self.__dict__, **override_opts})
55
- return Checkpointer(**opts)(fn)
56
-
57
- return CachedFunction(self, fn) if callable(fn) else self
58
-
59
- class CachedFunction(Generic[Fn]):
60
- def __init__(self, checkpointer: Checkpointer, fn: Fn):
61
- wrapped = unwrap_fn(fn)
62
- fn_file = Path(wrapped.__code__.co_filename).name
63
- fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
64
- Storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
65
- update_wrapper(cast(Callable, self), wrapped)
66
- self.checkpointer = checkpointer
67
- self.fn = fn
68
- self.fn_dir = f"{fn_file}/{fn_name}"
69
- self.storage = Storage(self)
70
- self.cleanup = self.storage.cleanup
71
- self.bound = ()
72
-
73
- @overload
74
- def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
75
- @overload
76
- def __get__(self: CachedFunction[Callable[Concatenate[C, P], R]], instance: C, owner: Type[C]) -> CachedFunction[Callable[P, R]]: ...
77
- def __get__(self, instance, owner):
78
- if instance is None:
79
- return self
80
- bound_fn = object.__new__(CachedFunction)
81
- bound_fn.__dict__ |= self.__dict__
82
- bound_fn.bound = (instance,)
83
- return bound_fn
84
-
85
- @cached_property
86
- def ident_tuple(self) -> tuple[str, list[Callable]]:
87
- return get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
88
-
89
- @property
90
- def fn_hash_raw(self) -> str:
91
- return self.ident_tuple[0]
92
-
93
- @property
94
- def depends(self) -> list[Callable]:
95
- return self.ident_tuple[1]
96
-
97
- @cached_property
98
- def fn_hash(self) -> str:
99
- deep_hashes = [depend.fn_hash_raw for depend in self.deep_depends()]
100
- fn_hash = ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes)
101
- return str(self.checkpointer.fn_hash or fn_hash)[:32]
102
-
103
- def reinit(self, recursive=False) -> CachedFunction[Fn]:
104
- depends = list(self.deep_depends()) if recursive else [self]
105
- for depend in depends:
106
- self.__dict__.pop("fn_hash", None)
107
- self.__dict__.pop("ident_tuple", None)
108
- for depend in depends:
109
- depend.fn_hash
110
- return self
111
-
112
- def get_call_id(self, args: tuple, kw: dict) -> str:
113
- args = self.bound + args
114
- hash_by = self.checkpointer.hash_by
115
- hash_params = hash_by(*args, **kw) if hash_by else (args, kw)
116
- return str(ObjectHash(hash_params, digest_size=16))
117
-
118
- async def _resolve_awaitable(self, call_id: str, awaitable: Awaitable):
119
- return self.storage.store(call_id, AwaitableValue(await awaitable)).value
120
-
121
- def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
122
- full_args = self.bound + args
123
- params = self.checkpointer
124
- if not params.when:
125
- return self.fn(*full_args, **kw)
126
-
127
- call_id = self.get_call_id(args, kw)
128
- call_id_long = f"{self.fn_dir}/{self.fn_hash}/{call_id}"
129
-
130
- refresh = rerun \
131
- or not self.storage.exists(call_id) \
132
- or (params.should_expire and params.should_expire(self.storage.checkpoint_date(call_id)))
133
-
134
- if refresh:
135
- print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id_long, "blue")
136
- data = self.fn(*full_args, **kw)
137
- if inspect.isawaitable(data):
138
- return self._resolve_awaitable(call_id, data)
139
- return self.storage.store(call_id, data)
140
-
141
- try:
142
- data = self.storage.load(call_id)
143
- print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_id_long, "green")
144
- return data
145
- except (EOFError, FileNotFoundError):
146
- pass
147
- print_checkpoint(params.verbosity >= 1, "CORRUPTED", call_id_long, "yellow")
148
- return self._call(args, kw, True)
149
-
150
- def __call__(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
151
- return self._call(args, kw)
152
-
153
- def rerun(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
154
- return self._call(args, kw, True)
155
-
156
- @overload
157
- def get(self: Callable[P, Awaitable[R]], *args: P.args, **kw: P.kwargs) -> R: ...
158
- @overload
159
- def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
160
- def get(self, *args, **kw):
161
- call_id = self.get_call_id(args, kw)
162
- try:
163
- data = self.storage.load(call_id)
164
- return data.value if isinstance(data, AwaitableValue) else data
165
- except Exception as ex:
166
- raise CheckpointError("Could not load checkpoint") from ex
167
-
168
- def exists(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> bool:
169
- return self.storage.exists(self.get_call_id(args, kw))
170
-
171
- def delete(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs):
172
- self.storage.delete(self.get_call_id(args, kw))
173
-
174
- def __repr__(self) -> str:
175
- return f"<CachedFunction {self.fn.__name__} {self.fn_hash[:6]}>"
176
-
177
- def deep_depends(self, visited: set[CachedFunction] = set()) -> Iterable[CachedFunction]:
178
- if self not in visited:
179
- yield self
180
- visited = visited or set()
181
- visited.add(self)
182
- for depend in self.depends:
183
- if isinstance(depend, CachedFunction):
184
- yield from depend.deep_depends(visited)
File without changes