checkpointer 2.10.1__tar.gz → 2.11.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpointer-2.10.1 → checkpointer-2.11.1}/.gitignore +1 -0
- {checkpointer-2.10.1 → checkpointer-2.11.1}/PKG-INFO +45 -16
- {checkpointer-2.10.1 → checkpointer-2.11.1}/README.md +44 -15
- {checkpointer-2.10.1 → checkpointer-2.11.1}/checkpointer/__init__.py +4 -4
- checkpointer-2.11.1/checkpointer/checkpoint.py +236 -0
- {checkpointer-2.10.1 → checkpointer-2.11.1}/checkpointer/fn_ident.py +25 -16
- {checkpointer-2.10.1 → checkpointer-2.11.1}/checkpointer/object_hash.py +30 -17
- {checkpointer-2.10.1 → checkpointer-2.11.1}/checkpointer/storages/memory_storage.py +12 -12
- {checkpointer-2.10.1 → checkpointer-2.11.1}/checkpointer/storages/pickle_storage.py +12 -12
- {checkpointer-2.10.1 → checkpointer-2.11.1}/checkpointer/storages/storage.py +6 -6
- {checkpointer-2.10.1 → checkpointer-2.11.1}/checkpointer/test_checkpointer.py +6 -6
- checkpointer-2.11.1/checkpointer/types.py +17 -0
- {checkpointer-2.10.1 → checkpointer-2.11.1}/checkpointer/utils.py +8 -29
- {checkpointer-2.10.1 → checkpointer-2.11.1}/pyproject.toml +3 -2
- {checkpointer-2.10.1 → checkpointer-2.11.1}/uv.lock +50 -5
- checkpointer-2.10.1/checkpointer/checkpoint.py +0 -184
- {checkpointer-2.10.1 → checkpointer-2.11.1}/.python-version +0 -0
- {checkpointer-2.10.1 → checkpointer-2.11.1}/LICENSE +0 -0
- {checkpointer-2.10.1 → checkpointer-2.11.1}/checkpointer/print_checkpoint.py +0 -0
- {checkpointer-2.10.1 → checkpointer-2.11.1}/checkpointer/storages/__init__.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: checkpointer
|
3
|
-
Version: 2.10.1
|
3
|
+
Version: 2.11.1
|
4
4
|
Summary: checkpointer adds code-aware caching to Python functions, maintaining correctness and speeding up execution as your code changes.
|
5
5
|
Project-URL: Repository, https://github.com/Reddan/checkpointer.git
|
6
6
|
Author: Hampus Hallman
|
@@ -121,11 +121,8 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
|
|
121
121
|
* **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
|
122
122
|
A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
|
123
123
|
|
124
|
-
* **`
|
125
|
-
|
126
|
-
|
127
|
-
* **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)\
|
128
|
-
An optional parameter that takes an instance of `checkpointer.ObjectHash`. This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any values relevant to your invalidation logic to `ObjectHash` (e.g., `ObjectHash(version_string, config_id, ...)`), as it can consistently hash most Python values.
|
124
|
+
* **`fn_hash_from`** (Type: `Any`, Default: `None`)\
|
125
|
+
This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any object relevant to your invalidation logic (e.g., version strings, config parameters). The object you provide will be hashed internally by `checkpointer`.
|
129
126
|
|
130
127
|
* **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
|
131
128
|
Controls the level of logging output from `checkpointer`.
|
@@ -133,31 +130,63 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
|
|
133
130
|
* `1`: Shows when functions are computed and cached.
|
134
131
|
* `2`: Also shows when cached results are remembered (loaded from cache).
|
135
132
|
|
136
|
-
|
133
|
+
## 🔬 Customize Argument Hashing
|
134
|
+
|
135
|
+
You can customize how individual function arguments are hashed without changing their actual values when passed in.
|
136
|
+
|
137
|
+
* **`Annotated[T, HashBy[fn]]`**:\
|
138
|
+
Hashes the argument by applying `fn(argument)` before hashing. This enables custom normalization (e.g., sorting lists to ignore order) or optimized hashing for complex types, improving cache hit rates or speeding up hashing.
|
139
|
+
|
140
|
+
* **`NoHash[T]`**:\
|
141
|
+
Completely excludes the argument from hashing, so changes to it won’t trigger cache invalidation.
|
142
|
+
|
143
|
+
**Example:**
|
144
|
+
|
145
|
+
```python
|
146
|
+
from typing import Annotated
|
147
|
+
from checkpointer import checkpoint, HashBy, NoHash
|
148
|
+
from pathlib import Path
|
149
|
+
import logging
|
150
|
+
|
151
|
+
def file_bytes(path: Path) -> bytes:
|
152
|
+
return path.read_bytes()
|
153
|
+
|
154
|
+
@checkpoint
|
155
|
+
def process(
|
156
|
+
numbers: Annotated[list[int], HashBy[sorted]], # Hash by sorted list
|
157
|
+
data_file: Annotated[Path, HashBy[file_bytes]], # Hash by file content
|
158
|
+
log: NoHash[logging.Logger], # Exclude logger from hashing
|
159
|
+
):
|
160
|
+
...
|
161
|
+
```
|
162
|
+
|
163
|
+
In this example, the cache key for `numbers` ignores order, `data_file` is hashed based on its contents rather than path, and changes to `log` don’t affect caching.
|
164
|
+
|
165
|
+
## 🗄️ Custom Storage Backends
|
137
166
|
|
138
167
|
For integration with databases, cloud storage, or custom serialization, implement your own storage backend by inheriting from `checkpointer.Storage` and implementing its abstract methods.
|
139
168
|
|
140
|
-
Within custom storage methods, `
|
169
|
+
Within custom storage methods, `call_hash` identifies calls by arguments. Use `self.fn_id()` to get the function's unique identity (name + hash/version), crucial for organizing stored checkpoints (e.g., by function version). Access global `Checkpointer` config via `self.checkpointer`.
|
141
170
|
|
142
|
-
|
171
|
+
**Example: Custom Storage Backend**
|
143
172
|
|
144
173
|
```python
|
145
174
|
from checkpointer import checkpoint, Storage
|
146
175
|
from datetime import datetime
|
147
176
|
|
148
177
|
class MyCustomStorage(Storage):
|
149
|
-
def exists(self,
|
178
|
+
def exists(self, call_hash):
|
150
179
|
# Example: Constructing a path based on function ID and call hash
|
151
180
|
fn_dir = self.checkpointer.root_path / self.fn_id()
|
152
|
-
return (fn_dir /
|
181
|
+
return (fn_dir / call_hash).exists()
|
153
182
|
|
154
|
-
def store(self,
|
155
|
-
... # Store the serialized data for `
|
183
|
+
def store(self, call_hash, data):
|
184
|
+
... # Store the serialized data for `call_hash`
|
156
185
|
return data # This method must return the data back to checkpointer
|
157
186
|
|
158
|
-
def checkpoint_date(self,
|
159
|
-
def load(self,
|
160
|
-
def delete(self,
|
187
|
+
def checkpoint_date(self, call_hash): ...
|
188
|
+
def load(self, call_hash): ...
|
189
|
+
def delete(self, call_hash): ...
|
161
190
|
|
162
191
|
@checkpoint(format=MyCustomStorage)
|
163
192
|
def custom_cached_function(x: int):
|
@@ -101,11 +101,8 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
|
|
101
101
|
* **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
|
102
102
|
A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
|
103
103
|
|
104
|
-
* **`
|
105
|
-
|
106
|
-
|
107
|
-
* **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)\
|
108
|
-
An optional parameter that takes an instance of `checkpointer.ObjectHash`. This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any values relevant to your invalidation logic to `ObjectHash` (e.g., `ObjectHash(version_string, config_id, ...)`), as it can consistently hash most Python values.
|
104
|
+
* **`fn_hash_from`** (Type: `Any`, Default: `None`)\
|
105
|
+
This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any object relevant to your invalidation logic (e.g., version strings, config parameters). The object you provide will be hashed internally by `checkpointer`.
|
109
106
|
|
110
107
|
* **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
|
111
108
|
Controls the level of logging output from `checkpointer`.
|
@@ -113,31 +110,63 @@ The `@checkpoint` decorator accepts the following parameters to customize its be
|
|
113
110
|
* `1`: Shows when functions are computed and cached.
|
114
111
|
* `2`: Also shows when cached results are remembered (loaded from cache).
|
115
112
|
|
116
|
-
|
113
|
+
## 🔬 Customize Argument Hashing
|
114
|
+
|
115
|
+
You can customize how individual function arguments are hashed without changing their actual values when passed in.
|
116
|
+
|
117
|
+
* **`Annotated[T, HashBy[fn]]`**:\
|
118
|
+
Hashes the argument by applying `fn(argument)` before hashing. This enables custom normalization (e.g., sorting lists to ignore order) or optimized hashing for complex types, improving cache hit rates or speeding up hashing.
|
119
|
+
|
120
|
+
* **`NoHash[T]`**:\
|
121
|
+
Completely excludes the argument from hashing, so changes to it won’t trigger cache invalidation.
|
122
|
+
|
123
|
+
**Example:**
|
124
|
+
|
125
|
+
```python
|
126
|
+
from typing import Annotated
|
127
|
+
from checkpointer import checkpoint, HashBy, NoHash
|
128
|
+
from pathlib import Path
|
129
|
+
import logging
|
130
|
+
|
131
|
+
def file_bytes(path: Path) -> bytes:
|
132
|
+
return path.read_bytes()
|
133
|
+
|
134
|
+
@checkpoint
|
135
|
+
def process(
|
136
|
+
numbers: Annotated[list[int], HashBy[sorted]], # Hash by sorted list
|
137
|
+
data_file: Annotated[Path, HashBy[file_bytes]], # Hash by file content
|
138
|
+
log: NoHash[logging.Logger], # Exclude logger from hashing
|
139
|
+
):
|
140
|
+
...
|
141
|
+
```
|
142
|
+
|
143
|
+
In this example, the cache key for `numbers` ignores order, `data_file` is hashed based on its contents rather than path, and changes to `log` don’t affect caching.
|
144
|
+
|
145
|
+
## 🗄️ Custom Storage Backends
|
117
146
|
|
118
147
|
For integration with databases, cloud storage, or custom serialization, implement your own storage backend by inheriting from `checkpointer.Storage` and implementing its abstract methods.
|
119
148
|
|
120
|
-
Within custom storage methods, `
|
149
|
+
Within custom storage methods, `call_hash` identifies calls by arguments. Use `self.fn_id()` to get the function's unique identity (name + hash/version), crucial for organizing stored checkpoints (e.g., by function version). Access global `Checkpointer` config via `self.checkpointer`.
|
121
150
|
|
122
|
-
|
151
|
+
**Example: Custom Storage Backend**
|
123
152
|
|
124
153
|
```python
|
125
154
|
from checkpointer import checkpoint, Storage
|
126
155
|
from datetime import datetime
|
127
156
|
|
128
157
|
class MyCustomStorage(Storage):
|
129
|
-
def exists(self,
|
158
|
+
def exists(self, call_hash):
|
130
159
|
# Example: Constructing a path based on function ID and call hash
|
131
160
|
fn_dir = self.checkpointer.root_path / self.fn_id()
|
132
|
-
return (fn_dir /
|
161
|
+
return (fn_dir / call_hash).exists()
|
133
162
|
|
134
|
-
def store(self,
|
135
|
-
... # Store the serialized data for `
|
163
|
+
def store(self, call_hash, data):
|
164
|
+
... # Store the serialized data for `call_hash`
|
136
165
|
return data # This method must return the data back to checkpointer
|
137
166
|
|
138
|
-
def checkpoint_date(self,
|
139
|
-
def load(self,
|
140
|
-
def delete(self,
|
167
|
+
def checkpoint_date(self, call_hash): ...
|
168
|
+
def load(self, call_hash): ...
|
169
|
+
def delete(self, call_hash): ...
|
141
170
|
|
142
171
|
@checkpoint(format=MyCustomStorage)
|
143
172
|
def custom_cached_function(x: int):
|
@@ -4,18 +4,18 @@ from typing import Callable
|
|
4
4
|
from .checkpoint import CachedFunction, Checkpointer, CheckpointError
|
5
5
|
from .object_hash import ObjectHash
|
6
6
|
from .storages import MemoryStorage, PickleStorage, Storage
|
7
|
-
from .
|
7
|
+
from .types import AwaitableValue, HashBy, NoHash
|
8
8
|
|
9
9
|
checkpoint = Checkpointer()
|
10
10
|
capture_checkpoint = Checkpointer(capture=True)
|
11
11
|
memory_checkpoint = Checkpointer(format="memory", verbosity=0)
|
12
12
|
tmp_checkpoint = Checkpointer(root_path=f"{tempfile.gettempdir()}/checkpoints")
|
13
|
-
static_checkpoint = Checkpointer(
|
13
|
+
static_checkpoint = Checkpointer(fn_hash_from=())
|
14
14
|
|
15
15
|
def cleanup_all(invalidated=True, expired=True):
|
16
16
|
for obj in gc.get_objects():
|
17
17
|
if isinstance(obj, CachedFunction):
|
18
18
|
obj.cleanup(invalidated=invalidated, expired=expired)
|
19
19
|
|
20
|
-
def get_function_hash(fn: Callable
|
21
|
-
return CachedFunction(Checkpointer(
|
20
|
+
def get_function_hash(fn: Callable) -> str:
|
21
|
+
return CachedFunction(Checkpointer(), fn).ident.fn_hash
|
@@ -0,0 +1,236 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
import re
|
3
|
+
from datetime import datetime
|
4
|
+
from functools import cached_property, update_wrapper
|
5
|
+
from inspect import Parameter, Signature, iscoroutine, signature
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import (
|
8
|
+
Annotated, Callable, Concatenate, Coroutine, Generic,
|
9
|
+
Iterable, Literal, ParamSpec, Self, Type, TypedDict,
|
10
|
+
TypeVar, Unpack, cast, get_args, get_origin, overload,
|
11
|
+
)
|
12
|
+
from .fn_ident import RawFunctionIdent, get_fn_ident
|
13
|
+
from .object_hash import ObjectHash
|
14
|
+
from .print_checkpoint import print_checkpoint
|
15
|
+
from .storages import STORAGE_MAP, Storage
|
16
|
+
from .types import AwaitableValue, HashBy
|
17
|
+
from .utils import unwrap_fn
|
18
|
+
|
19
|
+
Fn = TypeVar("Fn", bound=Callable)
|
20
|
+
P = ParamSpec("P")
|
21
|
+
R = TypeVar("R")
|
22
|
+
C = TypeVar("C")
|
23
|
+
|
24
|
+
DEFAULT_DIR = Path.home() / ".cache/checkpoints"
|
25
|
+
|
26
|
+
class CheckpointError(Exception):
|
27
|
+
pass
|
28
|
+
|
29
|
+
class CheckpointerOpts(TypedDict, total=False):
|
30
|
+
format: Type[Storage] | Literal["pickle", "memory", "bcolz"]
|
31
|
+
root_path: Path | str | None
|
32
|
+
when: bool
|
33
|
+
verbosity: Literal[0, 1, 2]
|
34
|
+
should_expire: Callable[[datetime], bool] | None
|
35
|
+
capture: bool
|
36
|
+
fn_hash_from: object
|
37
|
+
|
38
|
+
class Checkpointer:
|
39
|
+
def __init__(self, **opts: Unpack[CheckpointerOpts]):
|
40
|
+
self.format = opts.get("format", "pickle")
|
41
|
+
self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
|
42
|
+
self.when = opts.get("when", True)
|
43
|
+
self.verbosity = opts.get("verbosity", 1)
|
44
|
+
self.should_expire = opts.get("should_expire")
|
45
|
+
self.capture = opts.get("capture", False)
|
46
|
+
self.fn_hash_from = opts.get("fn_hash_from")
|
47
|
+
|
48
|
+
@overload
|
49
|
+
def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CachedFunction[Fn]: ...
|
50
|
+
@overload
|
51
|
+
def __call__(self, fn: None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer: ...
|
52
|
+
def __call__(self, fn: Fn | None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer | CachedFunction[Fn]:
|
53
|
+
if override_opts:
|
54
|
+
opts = CheckpointerOpts(**{**self.__dict__, **override_opts})
|
55
|
+
return Checkpointer(**opts)(fn)
|
56
|
+
|
57
|
+
return CachedFunction(self, fn) if callable(fn) else self
|
58
|
+
|
59
|
+
class FunctionIdent:
|
60
|
+
"""
|
61
|
+
Represents the identity and hash state of a cached function.
|
62
|
+
Separated from CachedFunction to prevent hash desynchronization
|
63
|
+
among bound instances when `.reinit()` is called.
|
64
|
+
"""
|
65
|
+
def __init__(self, cached_fn: CachedFunction):
|
66
|
+
self.__dict__.clear()
|
67
|
+
self.cached_fn = cached_fn
|
68
|
+
|
69
|
+
@cached_property
|
70
|
+
def raw_ident(self) -> RawFunctionIdent:
|
71
|
+
return get_fn_ident(unwrap_fn(self.cached_fn.fn), self.cached_fn.checkpointer.capture)
|
72
|
+
|
73
|
+
@cached_property
|
74
|
+
def fn_hash(self) -> str:
|
75
|
+
if self.cached_fn.checkpointer.fn_hash_from is not None:
|
76
|
+
return str(ObjectHash(self.cached_fn.checkpointer.fn_hash_from, digest_size=16))
|
77
|
+
deep_hashes = [depend.ident.raw_ident.fn_hash for depend in self.cached_fn.deep_depends()]
|
78
|
+
return str(ObjectHash(digest_size=16).write_text(iter=deep_hashes))
|
79
|
+
|
80
|
+
@cached_property
|
81
|
+
def captured_hash(self) -> str:
|
82
|
+
deep_hashes = [depend.ident.raw_ident.captured_hash for depend in self.cached_fn.deep_depends()]
|
83
|
+
return str(ObjectHash().write_text(iter=deep_hashes))
|
84
|
+
|
85
|
+
def clear(self):
|
86
|
+
self.__init__(self.cached_fn)
|
87
|
+
|
88
|
+
class CachedFunction(Generic[Fn]):
|
89
|
+
def __init__(self, checkpointer: Checkpointer, fn: Fn):
|
90
|
+
wrapped = unwrap_fn(fn)
|
91
|
+
fn_file = Path(wrapped.__code__.co_filename).name
|
92
|
+
fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
|
93
|
+
Storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
|
94
|
+
update_wrapper(cast(Callable, self), wrapped)
|
95
|
+
self.checkpointer = checkpointer
|
96
|
+
self.fn = fn
|
97
|
+
self.fn_dir = f"{fn_file}/{fn_name}"
|
98
|
+
self.storage = Storage(self)
|
99
|
+
self.cleanup = self.storage.cleanup
|
100
|
+
self.bound = ()
|
101
|
+
|
102
|
+
sig = signature(wrapped)
|
103
|
+
params = list(sig.parameters.items())
|
104
|
+
pos_params = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
105
|
+
self.arg_names = [name for name, param in params if param.kind in pos_params]
|
106
|
+
self.default_args = {name: param.default for name, param in params if param.default is not Parameter.empty}
|
107
|
+
self.hash_by_map = get_hash_by_map(sig)
|
108
|
+
self.ident = FunctionIdent(self)
|
109
|
+
|
110
|
+
@overload
|
111
|
+
def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
|
112
|
+
@overload
|
113
|
+
def __get__(self: CachedFunction[Callable[Concatenate[C, P], R]], instance: C, owner: Type[C]) -> CachedFunction[Callable[P, R]]: ...
|
114
|
+
def __get__(self, instance, owner):
|
115
|
+
if instance is None:
|
116
|
+
return self
|
117
|
+
bound_fn = object.__new__(CachedFunction)
|
118
|
+
bound_fn.__dict__ |= self.__dict__
|
119
|
+
bound_fn.bound = (instance,)
|
120
|
+
return bound_fn
|
121
|
+
|
122
|
+
@property
|
123
|
+
def depends(self) -> list[Callable]:
|
124
|
+
return self.ident.raw_ident.depends
|
125
|
+
|
126
|
+
def reinit(self, recursive=False) -> CachedFunction[Fn]:
|
127
|
+
depend_idents = [depend.ident for depend in self.deep_depends()] if recursive else [self.ident]
|
128
|
+
for ident in depend_idents: ident.clear()
|
129
|
+
for ident in depend_idents: ident.fn_hash
|
130
|
+
return self
|
131
|
+
|
132
|
+
def get_call_hash(self, args: tuple, kw: dict[str, object]) -> str:
|
133
|
+
args = self.bound + args
|
134
|
+
pos_args = args[len(self.arg_names):]
|
135
|
+
named_pos_args = dict(zip(self.arg_names, args))
|
136
|
+
named_args = {**self.default_args, **named_pos_args, **kw}
|
137
|
+
if hash_by_map := self.hash_by_map:
|
138
|
+
rest_hash_by = hash_by_map.get(b"**")
|
139
|
+
for key, value in named_args.items():
|
140
|
+
if hash_by := hash_by_map.get(key, rest_hash_by):
|
141
|
+
named_args[key] = hash_by(value)
|
142
|
+
if pos_hash_by := hash_by_map.get(b"*"):
|
143
|
+
pos_args = tuple(map(pos_hash_by, pos_args))
|
144
|
+
return str(ObjectHash(named_args, pos_args, self.ident.captured_hash, digest_size=16))
|
145
|
+
|
146
|
+
async def _resolve_coroutine(self, call_hash: str, coroutine: Coroutine):
|
147
|
+
return self.storage.store(call_hash, AwaitableValue(await coroutine)).value
|
148
|
+
|
149
|
+
def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
|
150
|
+
full_args = self.bound + args
|
151
|
+
params = self.checkpointer
|
152
|
+
if not params.when:
|
153
|
+
return self.fn(*full_args, **kw)
|
154
|
+
|
155
|
+
call_hash = self.get_call_hash(args, kw)
|
156
|
+
call_hash_long = f"{self.fn_dir}/{self.ident.fn_hash}/{call_hash}"
|
157
|
+
|
158
|
+
refresh = rerun \
|
159
|
+
or not self.storage.exists(call_hash) \
|
160
|
+
or (params.should_expire and params.should_expire(self.storage.checkpoint_date(call_hash)))
|
161
|
+
|
162
|
+
if refresh:
|
163
|
+
print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_hash_long, "blue")
|
164
|
+
data = self.fn(*full_args, **kw)
|
165
|
+
if iscoroutine(data):
|
166
|
+
return self._resolve_coroutine(call_hash, data)
|
167
|
+
return self.storage.store(call_hash, data)
|
168
|
+
|
169
|
+
try:
|
170
|
+
data = self.storage.load(call_hash)
|
171
|
+
print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_hash_long, "green")
|
172
|
+
return data
|
173
|
+
except (EOFError, FileNotFoundError):
|
174
|
+
pass
|
175
|
+
print_checkpoint(params.verbosity >= 1, "CORRUPTED", call_hash_long, "yellow")
|
176
|
+
return self._call(args, kw, True)
|
177
|
+
|
178
|
+
def __call__(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
|
179
|
+
return self._call(args, kw)
|
180
|
+
|
181
|
+
def rerun(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
|
182
|
+
return self._call(args, kw, True)
|
183
|
+
|
184
|
+
def exists(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> bool:
|
185
|
+
return self.storage.exists(self.get_call_hash(args, kw))
|
186
|
+
|
187
|
+
def delete(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs):
|
188
|
+
self.storage.delete(self.get_call_hash(args, kw))
|
189
|
+
|
190
|
+
@overload
|
191
|
+
def get(self: Callable[P, Coroutine[object, object, R]], *args: P.args, **kw: P.kwargs) -> R: ...
|
192
|
+
@overload
|
193
|
+
def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
|
194
|
+
def get(self, *args, **kw):
|
195
|
+
call_hash = self.get_call_hash(args, kw)
|
196
|
+
try:
|
197
|
+
data = self.storage.load(call_hash)
|
198
|
+
return data.value if isinstance(data, AwaitableValue) else data
|
199
|
+
except Exception as ex:
|
200
|
+
raise CheckpointError("Could not load checkpoint") from ex
|
201
|
+
|
202
|
+
@overload
|
203
|
+
def _set(self: Callable[P, Coroutine[object, object, R]], value: AwaitableValue[R], *args: P.args, **kw: P.kwargs): ...
|
204
|
+
@overload
|
205
|
+
def _set(self: Callable[P, R], value: R, *args: P.args, **kw: P.kwargs): ...
|
206
|
+
def _set(self, value, *args, **kw):
|
207
|
+
self.storage.store(self.get_call_hash(args, kw), value)
|
208
|
+
|
209
|
+
def __repr__(self) -> str:
|
210
|
+
return f"<CachedFunction {self.fn.__name__} {self.ident.fn_hash[:6]}>"
|
211
|
+
|
212
|
+
def deep_depends(self, visited: set[CachedFunction] = set()) -> Iterable[CachedFunction]:
|
213
|
+
if self not in visited:
|
214
|
+
yield self
|
215
|
+
visited = visited or set()
|
216
|
+
visited.add(self)
|
217
|
+
for depend in self.depends:
|
218
|
+
if isinstance(depend, CachedFunction):
|
219
|
+
yield from depend.deep_depends(visited)
|
220
|
+
|
221
|
+
def hash_by_from_annotation(annotation: type) -> Callable[[object], object] | None:
|
222
|
+
if get_origin(annotation) is Annotated:
|
223
|
+
args = get_args(annotation)
|
224
|
+
metadata = args[1] if len(args) > 1 else None
|
225
|
+
if get_origin(metadata) is HashBy:
|
226
|
+
return get_args(metadata)[0]
|
227
|
+
|
228
|
+
def get_hash_by_map(sig: Signature) -> dict[str | bytes, Callable[[object], object]]:
|
229
|
+
hash_by_map = {}
|
230
|
+
for name, param in sig.parameters.items():
|
231
|
+
if param.kind == Parameter.VAR_POSITIONAL:
|
232
|
+
name = b"*"
|
233
|
+
elif param.kind == Parameter.VAR_KEYWORD:
|
234
|
+
name = b"**"
|
235
|
+
hash_by_map[name] = hash_by_from_annotation(param.annotation)
|
236
|
+
return hash_by_map if any(hash_by_map.values()) else {}
|
@@ -1,15 +1,19 @@
|
|
1
1
|
import dis
|
2
2
|
import inspect
|
3
|
-
from collections.abc import Callable
|
4
3
|
from itertools import takewhile
|
5
4
|
from pathlib import Path
|
6
5
|
from types import CodeType, FunctionType, MethodType
|
7
|
-
from typing import
|
6
|
+
from typing import Callable, Iterable, NamedTuple, Type, TypeGuard
|
8
7
|
from .object_hash import ObjectHash
|
9
|
-
from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming,
|
8
|
+
from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, unwrap_fn
|
10
9
|
|
11
10
|
cwd = Path.cwd().resolve()
|
12
11
|
|
12
|
+
class RawFunctionIdent(NamedTuple):
|
13
|
+
fn_hash: str
|
14
|
+
captured_hash: str
|
15
|
+
depends: list[Callable]
|
16
|
+
|
13
17
|
def is_class(obj) -> TypeGuard[Type]:
|
14
18
|
# isinstance works too, but needlessly triggers _lazyinit()
|
15
19
|
return issubclass(type(obj), type)
|
@@ -33,7 +37,7 @@ def extract_classvars(code: CodeType, scope_vars: AttrDict) -> dict[str, dict[st
|
|
33
37
|
scope_obj = None
|
34
38
|
return classvars
|
35
39
|
|
36
|
-
def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...],
|
40
|
+
def extract_scope_values(code: CodeType, scope_vars: AttrDict) -> Iterable[tuple[tuple[str, ...], object]]:
|
37
41
|
classvars = extract_classvars(code, scope_vars)
|
38
42
|
scope_vars = scope_vars.set({k: scope_vars[k].set(v) for k, v in classvars.items()})
|
39
43
|
for instr, upcoming_instrs in iterate_and_upcoming(dis.get_instructions(code)):
|
@@ -55,7 +59,7 @@ def get_self_value(fn: Callable) -> type | object | None:
|
|
55
59
|
if is_class(cls):
|
56
60
|
return cls
|
57
61
|
|
58
|
-
def get_fn_captured_vals(fn: Callable) -> list[
|
62
|
+
def get_fn_captured_vals(fn: Callable) -> list[object]:
|
59
63
|
self_value = get_self_value(fn)
|
60
64
|
scope_vars = AttrDict({
|
61
65
|
"LOAD_FAST": AttrDict({"self": self_value} if self_value else {}),
|
@@ -71,24 +75,29 @@ def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
|
|
71
75
|
fn_path = Path(inspect.getfile(candidate_fn)).resolve()
|
72
76
|
return cwd in fn_path.parents and ".venv" not in fn_path.parts
|
73
77
|
|
74
|
-
def get_depend_fns(fn: Callable,
|
78
|
+
def get_depend_fns(fn: Callable, captured_vals_by_fn: dict[Callable, list[object]] = {}) -> dict[Callable, list[object]]:
|
75
79
|
from .checkpoint import CachedFunction
|
76
|
-
captured_vals_by_fn = captured_vals_by_fn or {}
|
77
80
|
captured_vals = get_fn_captured_vals(fn)
|
78
|
-
captured_vals_by_fn
|
79
|
-
|
80
|
-
for
|
81
|
+
captured_vals_by_fn = captured_vals_by_fn or {}
|
82
|
+
captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)]
|
83
|
+
for val in captured_vals:
|
84
|
+
if not callable(val):
|
85
|
+
continue
|
86
|
+
child_fn = unwrap_fn(val, cached_fn=True)
|
81
87
|
if isinstance(child_fn, CachedFunction):
|
82
88
|
captured_vals_by_fn[child_fn] = []
|
83
89
|
elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
|
84
|
-
get_depend_fns(child_fn,
|
90
|
+
get_depend_fns(child_fn, captured_vals_by_fn)
|
85
91
|
return captured_vals_by_fn
|
86
92
|
|
87
|
-
def get_fn_ident(fn: Callable, capture: bool) ->
|
93
|
+
def get_fn_ident(fn: Callable, capture: bool) -> RawFunctionIdent:
|
88
94
|
from .checkpoint import CachedFunction
|
89
|
-
captured_vals_by_fn = get_depend_fns(fn
|
90
|
-
|
95
|
+
captured_vals_by_fn = get_depend_fns(fn)
|
96
|
+
depend_captured_vals = list(captured_vals_by_fn.values()) * capture
|
97
|
+
depends = captured_vals_by_fn.keys()
|
91
98
|
depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
|
92
99
|
unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
|
93
|
-
|
94
|
-
|
100
|
+
assert fn == unwrapped_depends[0]
|
101
|
+
fn_hash = str(ObjectHash(iter=unwrapped_depends))
|
102
|
+
captured_hash = str(ObjectHash(iter=depend_captured_vals, tolerate_errors=True))
|
103
|
+
return RawFunctionIdent(fn_hash, captured_hash, depends)
|
@@ -1,16 +1,19 @@
|
|
1
1
|
import ctypes
|
2
2
|
import hashlib
|
3
|
+
import inspect
|
3
4
|
import io
|
4
5
|
import re
|
5
6
|
import sys
|
7
|
+
import tokenize
|
6
8
|
from collections.abc import Iterable
|
7
9
|
from contextlib import nullcontext, suppress
|
8
10
|
from decimal import Decimal
|
11
|
+
from io import StringIO
|
9
12
|
from itertools import chain
|
10
|
-
from pickle import HIGHEST_PROTOCOL as
|
13
|
+
from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
|
11
14
|
from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
|
12
|
-
from typing import
|
13
|
-
from .utils import ContextVar
|
15
|
+
from typing import Callable, TypeVar
|
16
|
+
from .utils import ContextVar
|
14
17
|
|
15
18
|
np, torch = None, None
|
16
19
|
|
@@ -31,16 +34,16 @@ else:
|
|
31
34
|
def encode_type(t: type | FunctionType) -> str:
|
32
35
|
return f"{t.__module__}:{t.__qualname__}"
|
33
36
|
|
34
|
-
def encode_type_of(v:
|
37
|
+
def encode_type_of(v: object) -> str:
|
35
38
|
return encode_type(type(v))
|
36
39
|
|
37
40
|
class ObjectHashError(Exception):
|
38
|
-
def __init__(self, obj:
|
41
|
+
def __init__(self, obj: object, cause: Exception):
|
39
42
|
super().__init__(f"{type(cause).__name__} error when hashing {obj}")
|
40
43
|
self.obj = obj
|
41
44
|
|
42
45
|
class ObjectHash:
|
43
|
-
def __init__(self, *objs:
|
46
|
+
def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerate_errors=False) -> None:
|
44
47
|
self.hash = hashlib.blake2b(digest_size=digest_size)
|
45
48
|
self.current: dict[int, int] = {}
|
46
49
|
self.tolerate_errors = ContextVar(tolerate_errors)
|
@@ -59,7 +62,7 @@ class ObjectHash:
|
|
59
62
|
def __eq__(self, value: object) -> bool:
|
60
63
|
return isinstance(value, ObjectHash) and str(self) == str(value)
|
61
64
|
|
62
|
-
def nested_hash(self, *objs:
|
65
|
+
def nested_hash(self, *objs: object) -> str:
|
63
66
|
return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()
|
64
67
|
|
65
68
|
def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> "ObjectHash":
|
@@ -70,10 +73,10 @@ class ObjectHash:
|
|
70
73
|
def write_text(self, *data: str, iter: Iterable[str] = ()) -> "ObjectHash":
|
71
74
|
return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))
|
72
75
|
|
73
|
-
def header(self, *args:
|
76
|
+
def header(self, *args: object) -> "ObjectHash":
|
74
77
|
return self.write_bytes(":".join(map(str, args)).encode())
|
75
78
|
|
76
|
-
def update(self, *objs:
|
79
|
+
def update(self, *objs: object, iter: Iterable[object] = (), tolerate_errors: bool | None=None) -> "ObjectHash":
|
77
80
|
with nullcontext() if tolerate_errors is None else self.tolerate_errors.set(tolerate_errors):
|
78
81
|
for obj in chain(objs, iter):
|
79
82
|
try:
|
@@ -81,11 +84,11 @@ class ObjectHash:
|
|
81
84
|
except Exception as ex:
|
82
85
|
if self.tolerate_errors.value:
|
83
86
|
self.header("error").update(type(ex))
|
84
|
-
|
85
|
-
|
87
|
+
else:
|
88
|
+
raise ObjectHashError(obj, ex) from ex
|
86
89
|
return self
|
87
90
|
|
88
|
-
def _update_one(self, obj:
|
91
|
+
def _update_one(self, obj: object) -> None:
|
89
92
|
match obj:
|
90
93
|
case None:
|
91
94
|
self.header("null")
|
@@ -142,7 +145,7 @@ class ObjectHash:
|
|
142
145
|
case _ if np and isinstance(obj, np.ndarray):
|
143
146
|
self.header("ndarray", encode_type_of(obj), obj.shape, obj.strides).update(obj.dtype)
|
144
147
|
if obj.dtype.hasobject:
|
145
|
-
self.update(obj.__reduce_ex__(
|
148
|
+
self.update(obj.__reduce_ex__(PICKLE_PROTOCOL))
|
146
149
|
else:
|
147
150
|
array = np.ascontiguousarray(obj if obj.base is None else obj.base).view(np.uint8)
|
148
151
|
self.write_bytes(array.data)
|
@@ -180,13 +183,14 @@ class ObjectHash:
|
|
180
183
|
def _update_iterator(self, obj: Iterable) -> "ObjectHash":
|
181
184
|
return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")
|
182
185
|
|
183
|
-
def _update_object(self, obj:
|
186
|
+
def _update_object(self, obj: object) -> "ObjectHash":
|
184
187
|
self.header("instance", encode_type_of(obj))
|
185
|
-
|
186
|
-
|
188
|
+
get_hash = hasattr(obj, "__objecthash__") and getattr(obj, "__objecthash__")
|
189
|
+
if callable(get_hash):
|
190
|
+
return self.header("objecthash").update(get_hash())
|
187
191
|
reduced = None
|
188
192
|
with suppress(Exception):
|
189
|
-
reduced = obj.__reduce_ex__(
|
193
|
+
reduced = obj.__reduce_ex__(PICKLE_PROTOCOL)
|
190
194
|
with suppress(Exception):
|
191
195
|
reduced = reduced or obj.__reduce__()
|
192
196
|
if isinstance(reduced, str):
|
@@ -206,3 +210,12 @@ class ObjectHash:
|
|
206
210
|
return self._update_iterator(obj)
|
207
211
|
repr_str = re.sub(r"\s+(at\s+0x[0-9a-fA-F]+)(>)$", r"\2", repr(obj))
|
208
212
|
return self.header("repr").update(repr_str)
|
213
|
+
|
214
|
+
def get_fn_body(fn: Callable) -> str:
|
215
|
+
try:
|
216
|
+
source = inspect.getsource(fn)
|
217
|
+
except OSError:
|
218
|
+
return ""
|
219
|
+
tokens = tokenize.generate_tokens(StringIO(source).readline)
|
220
|
+
ignore_types = (tokenize.COMMENT, tokenize.NL)
|
221
|
+
return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
|
@@ -9,21 +9,21 @@ class MemoryStorage(Storage):
|
|
9
9
|
def get_dict(self):
|
10
10
|
return item_map.setdefault(self.fn_dir(), {})
|
11
11
|
|
12
|
-
def store(self,
|
13
|
-
self.get_dict()[
|
12
|
+
def store(self, call_hash, data):
|
13
|
+
self.get_dict()[call_hash] = (datetime.now(), data)
|
14
14
|
return data
|
15
15
|
|
16
|
-
def exists(self,
|
17
|
-
return
|
16
|
+
def exists(self, call_hash):
|
17
|
+
return call_hash in self.get_dict()
|
18
18
|
|
19
|
-
def checkpoint_date(self,
|
20
|
-
return self.get_dict()[
|
19
|
+
def checkpoint_date(self, call_hash):
|
20
|
+
return self.get_dict()[call_hash][0]
|
21
21
|
|
22
|
-
def load(self,
|
23
|
-
return self.get_dict()[
|
22
|
+
def load(self, call_hash):
|
23
|
+
return self.get_dict()[call_hash][1]
|
24
24
|
|
25
|
-
def delete(self,
|
26
|
-
self.get_dict().pop(
|
25
|
+
def delete(self, call_hash):
|
26
|
+
self.get_dict().pop(call_hash, None)
|
27
27
|
|
28
28
|
def cleanup(self, invalidated=True, expired=True):
|
29
29
|
curr_key = self.fn_dir()
|
@@ -32,6 +32,6 @@ class MemoryStorage(Storage):
|
|
32
32
|
if invalidated and key != curr_key:
|
33
33
|
del item_map[key]
|
34
34
|
elif expired and self.checkpointer.should_expire:
|
35
|
-
for
|
35
|
+
for call_hash, (date, _) in list(calldict.items()):
|
36
36
|
if self.checkpointer.should_expire(date):
|
37
|
-
del calldict[
|
37
|
+
del calldict[call_hash]
|
@@ -8,29 +8,29 @@ def filedate(path: Path) -> datetime:
|
|
8
8
|
return datetime.fromtimestamp(path.stat().st_mtime)
|
9
9
|
|
10
10
|
class PickleStorage(Storage):
|
11
|
-
def get_path(self,
|
12
|
-
return self.fn_dir() / f"{
|
11
|
+
def get_path(self, call_hash: str):
|
12
|
+
return self.fn_dir() / f"{call_hash}.pkl"
|
13
13
|
|
14
|
-
def store(self,
|
15
|
-
path = self.get_path(
|
14
|
+
def store(self, call_hash, data):
|
15
|
+
path = self.get_path(call_hash)
|
16
16
|
path.parent.mkdir(parents=True, exist_ok=True)
|
17
17
|
with path.open("wb") as file:
|
18
18
|
pickle.dump(data, file, -1)
|
19
19
|
return data
|
20
20
|
|
21
|
-
def exists(self,
|
22
|
-
return self.get_path(
|
21
|
+
def exists(self, call_hash):
|
22
|
+
return self.get_path(call_hash).exists()
|
23
23
|
|
24
|
-
def checkpoint_date(self,
|
24
|
+
def checkpoint_date(self, call_hash):
|
25
25
|
# Should use st_atime/access time?
|
26
|
-
return filedate(self.get_path(
|
26
|
+
return filedate(self.get_path(call_hash))
|
27
27
|
|
28
|
-
def load(self,
|
29
|
-
with self.get_path(
|
28
|
+
def load(self, call_hash):
|
29
|
+
with self.get_path(call_hash).open("rb") as file:
|
30
30
|
return pickle.load(file)
|
31
31
|
|
32
|
-
def delete(self,
|
33
|
-
self.get_path(
|
32
|
+
def delete(self, call_hash):
|
33
|
+
self.get_path(call_hash).unlink(missing_ok=True)
|
34
34
|
|
35
35
|
def cleanup(self, invalidated=True, expired=True):
|
36
36
|
version_path = self.fn_dir()
|
@@ -15,19 +15,19 @@ class Storage:
|
|
15
15
|
self.cached_fn = cached_fn
|
16
16
|
|
17
17
|
def fn_id(self) -> str:
|
18
|
-
return f"{self.cached_fn.fn_dir}/{self.cached_fn.fn_hash}"
|
18
|
+
return f"{self.cached_fn.fn_dir}/{self.cached_fn.ident.fn_hash}"
|
19
19
|
|
20
20
|
def fn_dir(self) -> Path:
|
21
21
|
return self.checkpointer.root_path / self.fn_id()
|
22
22
|
|
23
|
-
def store(self,
|
23
|
+
def store(self, call_hash: str, data: Any) -> Any: ...
|
24
24
|
|
25
|
-
def exists(self,
|
25
|
+
def exists(self, call_hash: str) -> bool: ...
|
26
26
|
|
27
|
-
def checkpoint_date(self,
|
27
|
+
def checkpoint_date(self, call_hash: str) -> datetime: ...
|
28
28
|
|
29
|
-
def load(self,
|
29
|
+
def load(self, call_hash: str) -> Any: ...
|
30
30
|
|
31
|
-
def delete(self,
|
31
|
+
def delete(self, call_hash: str) -> None: ...
|
32
32
|
|
33
33
|
def cleanup(self, invalidated=True, expired=True): ...
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import asyncio
|
2
2
|
import pytest
|
3
|
-
from
|
3
|
+
from checkpointer import CheckpointError, checkpoint
|
4
4
|
from .utils import AttrDict
|
5
5
|
|
6
6
|
def global_multiply(a: int, b: int) -> int:
|
@@ -110,16 +110,16 @@ def test_capture():
|
|
110
110
|
def test_a():
|
111
111
|
return item_dict.a + 1
|
112
112
|
|
113
|
-
init_hash_a = test_a.
|
114
|
-
init_hash_whole = test_whole.
|
113
|
+
init_hash_a = test_a.ident.captured_hash
|
114
|
+
init_hash_whole = test_whole.ident.captured_hash
|
115
115
|
item_dict.b += 1
|
116
116
|
test_whole.reinit()
|
117
117
|
test_a.reinit()
|
118
|
-
assert test_whole.
|
119
|
-
assert test_a.
|
118
|
+
assert test_whole.ident.captured_hash != init_hash_whole
|
119
|
+
assert test_a.ident.captured_hash == init_hash_a
|
120
120
|
item_dict.a += 1
|
121
121
|
test_a.reinit()
|
122
|
-
assert test_a.
|
122
|
+
assert test_a.ident.captured_hash != init_hash_a
|
123
123
|
|
124
124
|
def test_depends():
|
125
125
|
def multiply_wrapper(a: int, b: int) -> int:
|
@@ -0,0 +1,17 @@
|
|
1
|
+
from typing import Annotated, Callable, Generic, TypeVar
|
2
|
+
|
3
|
+
T = TypeVar("T")
|
4
|
+
Fn = TypeVar("Fn", bound=Callable)
|
5
|
+
|
6
|
+
class HashBy(Generic[Fn]):
|
7
|
+
pass
|
8
|
+
|
9
|
+
NoHash = Annotated[T, HashBy[lambda _: None]]
|
10
|
+
|
11
|
+
class AwaitableValue(Generic[T]):
|
12
|
+
def __init__(self, value: T):
|
13
|
+
self.value = value
|
14
|
+
|
15
|
+
def __await__(self):
|
16
|
+
yield
|
17
|
+
return self.value
|
@@ -1,7 +1,4 @@
|
|
1
|
-
import inspect
|
2
|
-
import tokenize
|
3
1
|
from contextlib import contextmanager
|
4
|
-
from io import StringIO
|
5
2
|
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
6
3
|
|
7
4
|
T = TypeVar("T")
|
@@ -10,21 +7,6 @@ Fn = TypeVar("Fn", bound=Callable)
|
|
10
7
|
def distinct(seq: Iterable[T]) -> list[T]:
|
11
8
|
return list(dict.fromkeys(seq))
|
12
9
|
|
13
|
-
def transpose(tuples, default_num_returns=0):
|
14
|
-
output = tuple(zip(*tuples))
|
15
|
-
if not output:
|
16
|
-
return ([],) * default_num_returns
|
17
|
-
return tuple(map(list, output))
|
18
|
-
|
19
|
-
def get_fn_body(fn: Callable) -> str:
|
20
|
-
try:
|
21
|
-
source = inspect.getsource(fn)
|
22
|
-
except OSError:
|
23
|
-
return ""
|
24
|
-
tokens = tokenize.generate_tokens(StringIO(source).readline)
|
25
|
-
ignore_types = (tokenize.COMMENT, tokenize.NL)
|
26
|
-
return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
|
27
|
-
|
28
10
|
def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
|
29
11
|
for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
|
30
12
|
try:
|
@@ -39,14 +21,6 @@ def unwrap_fn(fn: Fn, cached_fn=False) -> Fn:
|
|
39
21
|
return cast(Fn, fn)
|
40
22
|
fn = getattr(fn, "__wrapped__")
|
41
23
|
|
42
|
-
class AwaitableValue:
|
43
|
-
def __init__(self, value):
|
44
|
-
self.value = value
|
45
|
-
|
46
|
-
def __await__(self):
|
47
|
-
yield
|
48
|
-
return self.value
|
49
|
-
|
50
24
|
class AttrDict(dict):
|
51
25
|
def __init__(self, *args, **kwargs):
|
52
26
|
super().__init__(*args, **kwargs)
|
@@ -91,14 +65,19 @@ class iterate_and_upcoming(Generic[T]):
|
|
91
65
|
def __init__(self, it: Iterable[T]) -> None:
|
92
66
|
self.it = iter(it)
|
93
67
|
self.previous: tuple[()] | tuple[T] = ()
|
68
|
+
self.tracked = self._tracked_iter()
|
94
69
|
|
95
70
|
def __iter__(self):
|
96
71
|
return self
|
97
72
|
|
98
73
|
def __next__(self) -> tuple[T, Iterable[T]]:
|
99
|
-
|
100
|
-
|
101
|
-
|
74
|
+
try:
|
75
|
+
item = self.previous[0] if self.previous else next(self.it)
|
76
|
+
self.previous = ()
|
77
|
+
return item, self.tracked
|
78
|
+
except StopIteration:
|
79
|
+
self.tracked.close()
|
80
|
+
raise
|
102
81
|
|
103
82
|
def _tracked_iter(self):
|
104
83
|
for x in self.it:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "checkpointer"
|
3
|
-
version = "2.
|
3
|
+
version = "2.11.1"
|
4
4
|
requires-python = ">=3.11"
|
5
5
|
dependencies = []
|
6
6
|
authors = [
|
@@ -21,10 +21,11 @@ Repository = "https://github.com/Reddan/checkpointer.git"
|
|
21
21
|
[dependency-groups]
|
22
22
|
dev = [
|
23
23
|
"numpy>=2.2.1",
|
24
|
-
"omg>=1.3.
|
24
|
+
"omg>=1.3.6",
|
25
25
|
"poethepoet>=0.30.0",
|
26
26
|
"pytest>=8.3.5",
|
27
27
|
"pytest-asyncio>=0.26.0",
|
28
|
+
"rich>=14.0.0",
|
28
29
|
"torch>=2.5.1",
|
29
30
|
]
|
30
31
|
|
@@ -8,7 +8,7 @@ resolution-markers = [
|
|
8
8
|
|
9
9
|
[[package]]
|
10
10
|
name = "checkpointer"
|
11
|
-
version = "2.
|
11
|
+
version = "2.11.1"
|
12
12
|
source = { editable = "." }
|
13
13
|
|
14
14
|
[package.dev-dependencies]
|
@@ -18,6 +18,7 @@ dev = [
|
|
18
18
|
{ name = "poethepoet" },
|
19
19
|
{ name = "pytest" },
|
20
20
|
{ name = "pytest-asyncio" },
|
21
|
+
{ name = "rich" },
|
21
22
|
{ name = "torch" },
|
22
23
|
]
|
23
24
|
|
@@ -26,10 +27,11 @@ dev = [
|
|
26
27
|
[package.metadata.requires-dev]
|
27
28
|
dev = [
|
28
29
|
{ name = "numpy", specifier = ">=2.2.1" },
|
29
|
-
{ name = "omg", specifier = ">=1.3.
|
30
|
+
{ name = "omg", specifier = ">=1.3.6" },
|
30
31
|
{ name = "poethepoet", specifier = ">=0.30.0" },
|
31
32
|
{ name = "pytest", specifier = ">=8.3.5" },
|
32
33
|
{ name = "pytest-asyncio", specifier = ">=0.26.0" },
|
34
|
+
{ name = "rich", specifier = ">=14.0.0" },
|
33
35
|
{ name = "torch", specifier = ">=2.5.1" },
|
34
36
|
]
|
35
37
|
|
@@ -81,6 +83,18 @@ wheels = [
|
|
81
83
|
{ url = "https://files.pythonhosted.org/packages/bd/0f/2ba5fbcd631e3e88689309dbe978c5769e883e4b84ebfe7da30b43275c5a/jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb", size = 134596 },
|
82
84
|
]
|
83
85
|
|
86
|
+
[[package]]
|
87
|
+
name = "markdown-it-py"
|
88
|
+
version = "3.0.0"
|
89
|
+
source = { registry = "https://pypi.org/simple" }
|
90
|
+
dependencies = [
|
91
|
+
{ name = "mdurl" },
|
92
|
+
]
|
93
|
+
sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 }
|
94
|
+
wheels = [
|
95
|
+
{ url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 },
|
96
|
+
]
|
97
|
+
|
84
98
|
[[package]]
|
85
99
|
name = "markupsafe"
|
86
100
|
version = "3.0.2"
|
@@ -129,6 +143,15 @@ wheels = [
|
|
129
143
|
{ url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739 },
|
130
144
|
]
|
131
145
|
|
146
|
+
[[package]]
|
147
|
+
name = "mdurl"
|
148
|
+
version = "0.1.2"
|
149
|
+
source = { registry = "https://pypi.org/simple" }
|
150
|
+
sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 }
|
151
|
+
wheels = [
|
152
|
+
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
|
153
|
+
]
|
154
|
+
|
132
155
|
[[package]]
|
133
156
|
name = "mpmath"
|
134
157
|
version = "1.3.0"
|
@@ -307,14 +330,14 @@ wheels = [
|
|
307
330
|
|
308
331
|
[[package]]
|
309
332
|
name = "omg"
|
310
|
-
version = "1.3.
|
333
|
+
version = "1.3.6"
|
311
334
|
source = { registry = "https://pypi.org/simple" }
|
312
335
|
dependencies = [
|
313
336
|
{ name = "watchdog" },
|
314
337
|
]
|
315
|
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
338
|
+
sdist = { url = "https://files.pythonhosted.org/packages/65/06/da0a3778b7ff8f1333ed7ddc0931ffff3c86ab5cb8bc4a96a1d0edb8671b/omg-1.3.6.tar.gz", hash = "sha256:465a51b7576fa31ef313e2b9a77d57f5d4816fb0a14dca0fc5c09ff471074fe6", size = 14268 }
|
316
339
|
wheels = [
|
317
|
-
{ url = "https://files.pythonhosted.org/packages/
|
340
|
+
{ url = "https://files.pythonhosted.org/packages/dd/d2/87346e94dbecd3a65a09e2156c1adf30c162f31e69d0936343c3eff53e7a/omg-1.3.6-py3-none-any.whl", hash = "sha256:8e3ac99a18d5284ceef2ed98492d288d5f22ee2bb417591654a7d2433e196607", size = 7988 },
|
318
341
|
]
|
319
342
|
|
320
343
|
[[package]]
|
@@ -357,6 +380,15 @@ wheels = [
|
|
357
380
|
{ url = "https://files.pythonhosted.org/packages/06/e1/04f56c9d848d6135ca3328c5a2ca84d3303c358ad7828db290385e36a8cc/poethepoet-0.31.1-py3-none-any.whl", hash = "sha256:7fdfa0ac6074be9936723e7231b5bfaad2923e96c674a9857e81d326cf8ccdc2", size = 80238 },
|
358
381
|
]
|
359
382
|
|
383
|
+
[[package]]
|
384
|
+
name = "pygments"
|
385
|
+
version = "2.19.1"
|
386
|
+
source = { registry = "https://pypi.org/simple" }
|
387
|
+
sdist = { url = "https://files.pythonhosted.org/packages/7c/2d/c3338d48ea6cc0feb8446d8e6937e1408088a72a39937982cc6111d17f84/pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f", size = 4968581 }
|
388
|
+
wheels = [
|
389
|
+
{ url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293 },
|
390
|
+
]
|
391
|
+
|
360
392
|
[[package]]
|
361
393
|
name = "pytest"
|
362
394
|
version = "8.3.5"
|
@@ -419,6 +451,19 @@ wheels = [
|
|
419
451
|
{ url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 },
|
420
452
|
]
|
421
453
|
|
454
|
+
[[package]]
|
455
|
+
name = "rich"
|
456
|
+
version = "14.0.0"
|
457
|
+
source = { registry = "https://pypi.org/simple" }
|
458
|
+
dependencies = [
|
459
|
+
{ name = "markdown-it-py" },
|
460
|
+
{ name = "pygments" },
|
461
|
+
]
|
462
|
+
sdist = { url = "https://files.pythonhosted.org/packages/a1/53/830aa4c3066a8ab0ae9a9955976fb770fe9c6102117c8ec4ab3ea62d89e8/rich-14.0.0.tar.gz", hash = "sha256:82f1bc23a6a21ebca4ae0c45af9bdbc492ed20231dcb63f297d6d1021a9d5725", size = 224078 }
|
463
|
+
wheels = [
|
464
|
+
{ url = "https://files.pythonhosted.org/packages/0d/9b/63f4c7ebc259242c89b3acafdb37b41d1185c07ff0011164674e9076b491/rich-14.0.0-py3-none-any.whl", hash = "sha256:1c9491e1951aac09caffd42f448ee3d04e58923ffe14993f6e83068dc395d7e0", size = 243229 },
|
465
|
+
]
|
466
|
+
|
422
467
|
[[package]]
|
423
468
|
name = "setuptools"
|
424
469
|
version = "75.6.0"
|
@@ -1,184 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
import inspect
|
3
|
-
import re
|
4
|
-
from datetime import datetime
|
5
|
-
from functools import cached_property, update_wrapper
|
6
|
-
from pathlib import Path
|
7
|
-
from typing import (
|
8
|
-
Awaitable, Callable, Concatenate, Generic, Iterable, Literal,
|
9
|
-
ParamSpec, Self, Type, TypedDict, TypeVar, Unpack, cast, overload,
|
10
|
-
)
|
11
|
-
from .fn_ident import get_fn_ident
|
12
|
-
from .object_hash import ObjectHash
|
13
|
-
from .print_checkpoint import print_checkpoint
|
14
|
-
from .storages import STORAGE_MAP, Storage
|
15
|
-
from .utils import AwaitableValue, unwrap_fn
|
16
|
-
|
17
|
-
Fn = TypeVar("Fn", bound=Callable)
|
18
|
-
P = ParamSpec("P")
|
19
|
-
R = TypeVar("R")
|
20
|
-
C = TypeVar("C")
|
21
|
-
|
22
|
-
DEFAULT_DIR = Path.home() / ".cache/checkpoints"
|
23
|
-
|
24
|
-
class CheckpointError(Exception):
|
25
|
-
pass
|
26
|
-
|
27
|
-
class CheckpointerOpts(TypedDict, total=False):
|
28
|
-
format: Type[Storage] | Literal["pickle", "memory", "bcolz"]
|
29
|
-
root_path: Path | str | None
|
30
|
-
when: bool
|
31
|
-
verbosity: Literal[0, 1, 2]
|
32
|
-
hash_by: Callable | None
|
33
|
-
should_expire: Callable[[datetime], bool] | None
|
34
|
-
capture: bool
|
35
|
-
fn_hash: ObjectHash | None
|
36
|
-
|
37
|
-
class Checkpointer:
|
38
|
-
def __init__(self, **opts: Unpack[CheckpointerOpts]):
|
39
|
-
self.format = opts.get("format", "pickle")
|
40
|
-
self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
|
41
|
-
self.when = opts.get("when", True)
|
42
|
-
self.verbosity = opts.get("verbosity", 1)
|
43
|
-
self.hash_by = opts.get("hash_by")
|
44
|
-
self.should_expire = opts.get("should_expire")
|
45
|
-
self.capture = opts.get("capture", False)
|
46
|
-
self.fn_hash = opts.get("fn_hash")
|
47
|
-
|
48
|
-
@overload
|
49
|
-
def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CachedFunction[Fn]: ...
|
50
|
-
@overload
|
51
|
-
def __call__(self, fn: None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer: ...
|
52
|
-
def __call__(self, fn: Fn | None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer | CachedFunction[Fn]:
|
53
|
-
if override_opts:
|
54
|
-
opts = CheckpointerOpts(**{**self.__dict__, **override_opts})
|
55
|
-
return Checkpointer(**opts)(fn)
|
56
|
-
|
57
|
-
return CachedFunction(self, fn) if callable(fn) else self
|
58
|
-
|
59
|
-
class CachedFunction(Generic[Fn]):
|
60
|
-
def __init__(self, checkpointer: Checkpointer, fn: Fn):
|
61
|
-
wrapped = unwrap_fn(fn)
|
62
|
-
fn_file = Path(wrapped.__code__.co_filename).name
|
63
|
-
fn_name = re.sub(r"[^\w.]", "", wrapped.__qualname__)
|
64
|
-
Storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
|
65
|
-
update_wrapper(cast(Callable, self), wrapped)
|
66
|
-
self.checkpointer = checkpointer
|
67
|
-
self.fn = fn
|
68
|
-
self.fn_dir = f"{fn_file}/{fn_name}"
|
69
|
-
self.storage = Storage(self)
|
70
|
-
self.cleanup = self.storage.cleanup
|
71
|
-
self.bound = ()
|
72
|
-
|
73
|
-
@overload
|
74
|
-
def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
|
75
|
-
@overload
|
76
|
-
def __get__(self: CachedFunction[Callable[Concatenate[C, P], R]], instance: C, owner: Type[C]) -> CachedFunction[Callable[P, R]]: ...
|
77
|
-
def __get__(self, instance, owner):
|
78
|
-
if instance is None:
|
79
|
-
return self
|
80
|
-
bound_fn = object.__new__(CachedFunction)
|
81
|
-
bound_fn.__dict__ |= self.__dict__
|
82
|
-
bound_fn.bound = (instance,)
|
83
|
-
return bound_fn
|
84
|
-
|
85
|
-
@cached_property
|
86
|
-
def ident_tuple(self) -> tuple[str, list[Callable]]:
|
87
|
-
return get_fn_ident(unwrap_fn(self.fn), self.checkpointer.capture)
|
88
|
-
|
89
|
-
@property
|
90
|
-
def fn_hash_raw(self) -> str:
|
91
|
-
return self.ident_tuple[0]
|
92
|
-
|
93
|
-
@property
|
94
|
-
def depends(self) -> list[Callable]:
|
95
|
-
return self.ident_tuple[1]
|
96
|
-
|
97
|
-
@cached_property
|
98
|
-
def fn_hash(self) -> str:
|
99
|
-
deep_hashes = [depend.fn_hash_raw for depend in self.deep_depends()]
|
100
|
-
fn_hash = ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes)
|
101
|
-
return str(self.checkpointer.fn_hash or fn_hash)[:32]
|
102
|
-
|
103
|
-
def reinit(self, recursive=False) -> CachedFunction[Fn]:
|
104
|
-
depends = list(self.deep_depends()) if recursive else [self]
|
105
|
-
for depend in depends:
|
106
|
-
self.__dict__.pop("fn_hash", None)
|
107
|
-
self.__dict__.pop("ident_tuple", None)
|
108
|
-
for depend in depends:
|
109
|
-
depend.fn_hash
|
110
|
-
return self
|
111
|
-
|
112
|
-
def get_call_id(self, args: tuple, kw: dict) -> str:
|
113
|
-
args = self.bound + args
|
114
|
-
hash_by = self.checkpointer.hash_by
|
115
|
-
hash_params = hash_by(*args, **kw) if hash_by else (args, kw)
|
116
|
-
return str(ObjectHash(hash_params, digest_size=16))
|
117
|
-
|
118
|
-
async def _resolve_awaitable(self, call_id: str, awaitable: Awaitable):
|
119
|
-
return self.storage.store(call_id, AwaitableValue(await awaitable)).value
|
120
|
-
|
121
|
-
def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
|
122
|
-
full_args = self.bound + args
|
123
|
-
params = self.checkpointer
|
124
|
-
if not params.when:
|
125
|
-
return self.fn(*full_args, **kw)
|
126
|
-
|
127
|
-
call_id = self.get_call_id(args, kw)
|
128
|
-
call_id_long = f"{self.fn_dir}/{self.fn_hash}/{call_id}"
|
129
|
-
|
130
|
-
refresh = rerun \
|
131
|
-
or not self.storage.exists(call_id) \
|
132
|
-
or (params.should_expire and params.should_expire(self.storage.checkpoint_date(call_id)))
|
133
|
-
|
134
|
-
if refresh:
|
135
|
-
print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id_long, "blue")
|
136
|
-
data = self.fn(*full_args, **kw)
|
137
|
-
if inspect.isawaitable(data):
|
138
|
-
return self._resolve_awaitable(call_id, data)
|
139
|
-
return self.storage.store(call_id, data)
|
140
|
-
|
141
|
-
try:
|
142
|
-
data = self.storage.load(call_id)
|
143
|
-
print_checkpoint(params.verbosity >= 2, "REMEMBERED", call_id_long, "green")
|
144
|
-
return data
|
145
|
-
except (EOFError, FileNotFoundError):
|
146
|
-
pass
|
147
|
-
print_checkpoint(params.verbosity >= 1, "CORRUPTED", call_id_long, "yellow")
|
148
|
-
return self._call(args, kw, True)
|
149
|
-
|
150
|
-
def __call__(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
|
151
|
-
return self._call(args, kw)
|
152
|
-
|
153
|
-
def rerun(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
|
154
|
-
return self._call(args, kw, True)
|
155
|
-
|
156
|
-
@overload
|
157
|
-
def get(self: Callable[P, Awaitable[R]], *args: P.args, **kw: P.kwargs) -> R: ...
|
158
|
-
@overload
|
159
|
-
def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
|
160
|
-
def get(self, *args, **kw):
|
161
|
-
call_id = self.get_call_id(args, kw)
|
162
|
-
try:
|
163
|
-
data = self.storage.load(call_id)
|
164
|
-
return data.value if isinstance(data, AwaitableValue) else data
|
165
|
-
except Exception as ex:
|
166
|
-
raise CheckpointError("Could not load checkpoint") from ex
|
167
|
-
|
168
|
-
def exists(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> bool:
|
169
|
-
return self.storage.exists(self.get_call_id(args, kw))
|
170
|
-
|
171
|
-
def delete(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs):
|
172
|
-
self.storage.delete(self.get_call_id(args, kw))
|
173
|
-
|
174
|
-
def __repr__(self) -> str:
|
175
|
-
return f"<CachedFunction {self.fn.__name__} {self.fn_hash[:6]}>"
|
176
|
-
|
177
|
-
def deep_depends(self, visited: set[CachedFunction] = set()) -> Iterable[CachedFunction]:
|
178
|
-
if self not in visited:
|
179
|
-
yield self
|
180
|
-
visited = visited or set()
|
181
|
-
visited.add(self)
|
182
|
-
for depend in self.depends:
|
183
|
-
if isinstance(depend, CachedFunction):
|
184
|
-
yield from depend.deep_depends(visited)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|