checkpointer 2.9.1__tar.gz → 2.10.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: checkpointer
- Version: 2.9.1
+ Version: 2.10.0
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
  Author: Hampus Hallman
@@ -48,19 +48,14 @@ result = expensive_function(4) # Loads from the cache
 
  ## 🧠 How It Works
 
- When a function decorated with `@checkpoint` is called:
+ When a `@checkpoint`-decorated function is called, `checkpointer` first computes a unique identifier (hash) for the call. This hash is derived from the function's source code, its dependencies, and the arguments passed.
 
- 1. `checkpointer` computes a unique identifier (hash) for the function call based on its source code, its dependencies, and the arguments passed.
- 2. It attempts to retrieve a cached result using this identifier.
- 3. If a cached result is found, it's returned immediately.
- 4. If no cached result exists or the cache has expired, the original function is executed, its result is stored, and then returned.
+ It then tries to retrieve a cached result using this ID. If a valid cached result is found, it's returned immediately. Otherwise, the original function executes, its result is stored, and then returned.
 
- ### ♻️ Automatic Cache Invalidation
+ Cache validity is determined by the function's hash, which automatically updates if:
 
- `checkpointer` ensures caches are invalidated automatically when the underlying computation changes. A function's hash, which determines cache validity, updates if:
-
- * **Function Code Changes**: The source code of the decorated function itself is modified.
- * **Dependencies Change**: Any user-defined function in its dependency tree (direct or indirect, even across modules or not decorated with `@checkpoint`) is modified.
+ * **Function Code Changes**: The decorated function's source code is modified.
+ * **Dependencies Change**: Any user-defined function in its dependency tree (direct or indirect, even across modules or not decorated) is modified.
  * **Captured Variables Change** (with `capture=True`): Global or closure-based variables used within the function are altered.
 
  **Example: Dependency Invalidation**
@@ -76,63 +71,63 @@ def helper(x):
 
  @checkpoint
  def compute(a, b):
-   # Depends on `helper`
+   # Depends on `helper` and `multiply`
    return helper(a) + helper(b)
  ```
 
- If `multiply` is modified, caches for both `helper` and `compute` will automatically be invalidated and recomputed upon their next call.
+ If `multiply` is modified, caches for both `helper` and `compute` are automatically invalidated and recomputed.
 
  ## 💡 Usage
 
  Once a function is decorated with `@checkpoint`, you can interact with its caching behavior using the following methods:
 
- * **`expensive_function(...)`**:
+ * **`expensive_function(...)`**:\
    Call the function normally. This will either compute and cache the result or load it from the cache if available.
 
- * **`expensive_function.rerun(...)`**:
+ * **`expensive_function.rerun(...)`**:\
    Forces the original function to execute, compute a new result, and overwrite any existing cached value for the given arguments.
 
- * **`expensive_function.fn(...)`**:
+ * **`expensive_function.fn(...)`**:\
    Calls the original, undecorated function directly, bypassing the cache entirely. This is particularly useful within recursive functions to prevent caching intermediate steps.
 
- * **`expensive_function.get(...)`**:
+ * **`expensive_function.get(...)`**:\
    Attempts to retrieve the cached result for the given arguments without executing the original function. Raises `CheckpointError` if no valid cached result exists.
 
- * **`expensive_function.exists(...)`**:
+ * **`expensive_function.exists(...)`**:\
    Checks if a cached result exists for the given arguments without attempting to compute or load it. Returns `True` if a valid checkpoint exists, `False` otherwise.
 
- * **`expensive_function.delete(...)`**:
+ * **`expensive_function.delete(...)`**:\
    Removes the cached entry for the specified arguments.
 
- * **`expensive_function.reinit()`**:
+ * **`expensive_function.reinit()`**:\
    Recalculates the function's internal hash. This is primarily used when `capture=True` and you need to update the cache based on changes to external variables within the same Python session.
 
  ## ⚙️ Configuration & Customization
 
  The `@checkpoint` decorator accepts the following parameters to customize its behavior:
 
- * **`format`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)
+ * **`format`** (Type: `str` or `checkpointer.Storage`, Default: `"pickle"`)\
    Defines the storage backend to use. Built-in options are `"pickle"` (disk-based, persistent) and `"memory"` (in-memory, non-persistent). You can also provide a custom `Storage` class.
 
- * **`root_path`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)
+ * **`root_path`** (Type: `str` or `pathlib.Path` or `None`, Default: `~/.cache/checkpoints`)\
    The base directory for storing disk-based checkpoints. This parameter is only relevant when `format` is set to `"pickle"`.
 
- * **`when`** (Type: `bool`, Default: `True`)
+ * **`when`** (Type: `bool`, Default: `True`)\
    A boolean flag to enable or disable checkpointing for the decorated function. This is particularly useful for toggling caching based on environment variables (e.g., `when=os.environ.get("ENABLE_CACHING", "false").lower() == "true"`).
 
- * **`capture`** (Type: `bool`, Default: `False`)
+ * **`capture`** (Type: `bool`, Default: `False`)\
    If set to `True`, `checkpointer` includes global or closure-based variables used by the function in its hash calculation. This ensures that changes to these external variables also trigger cache invalidation and recomputation.
 
- * **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)
+ * **`should_expire`** (Type: `Callable[[datetime.datetime], bool]`, Default: `None`)\
    A custom callable that receives the `datetime` timestamp of a cached result. It should return `True` if the cached result is considered expired and needs recomputation, or `False` otherwise.
 
- * **`hash_by`** (Type: `Callable[..., Any]`, Default: `None`)
+ * **`hash_by`** (Type: `Callable[..., Any]`, Default: `None`)\
    A custom callable that takes the function's arguments (`*args`, `**kwargs`) and returns a hashable object (or tuple of objects). This allows for custom argument normalization (e.g., sorting lists before hashing) or optimized hashing for complex input types, which can improve cache hit rates or speed up the hashing process.
 
- * **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)
+ * **`fn_hash`** (Type: `checkpointer.ObjectHash`, Default: `None`)\
    An optional parameter that takes an instance of `checkpointer.ObjectHash`. This allows you to override the automatically computed function hash, giving you explicit control over when the function's cache should be invalidated. You can pass any values relevant to your invalidation logic to `ObjectHash` (e.g., `ObjectHash(version_string, config_id, ...)`), as it can consistently hash most Python values.
 
- * **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)
+ * **`verbosity`** (Type: `int` (`0`, `1`, or `2`), Default: `1`)\
    Controls the level of logging output from `checkpointer`.
    * `0`: No output.
    * `1`: Shows when functions are computed and cached.
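The `hash_by` parameter described in the list above can be illustrated with a stdlib sketch (the helper names here are hypothetical, not part of checkpointer's API): sorting list arguments before hashing makes order-insensitive calls share a single cache entry.

```python
import hashlib
import pickle

def sorted_lists(*args, **kwargs):
    # A hash_by-style normalizer: sort list arguments so that
    # f([3, 1, 2]) and f([1, 2, 3]) map to the same call ID.
    norm_args = tuple(sorted(a) if isinstance(a, list) else a for a in args)
    return norm_args, tuple(sorted(kwargs.items()))

def call_id(*args, **kwargs) -> str:
    # Hash the normalized arguments, as checkpointer would hash
    # whatever object hash_by returns.
    return hashlib.sha256(pickle.dumps(sorted_lists(*args, **kwargs))).hexdigest()
```

With this normalizer, a call with `[3, 1, 2]` and a call with `[1, 2, 3]` produce the same ID and therefore hit the same cached result.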
@@ -156,8 +151,11 @@ class MyCustomStorage(Storage):
      fn_dir = self.checkpointer.root_path / self.fn_id()
      return (fn_dir / call_id).exists()
 
+   def store(self, call_id, data):
+     ...  # Store the serialized data for `call_id`
+     return data  # This method must return the data back to checkpointer
+
    def checkpoint_date(self, call_id): ...
-   def store(self, call_id, data): ...
    def load(self, call_id): ...
    def delete(self, call_id): ...
 
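The hunk above highlights the 2.10.0 contract change: a custom `store()` must now return the stored data. As a self-contained sketch of that interface (a hypothetical dict-backed class, not the packaged `MemoryStorage`, with method names as in the diff):

```python
from datetime import datetime
from typing import Any

class DictStorage:
    # Dict-backed sketch of the Storage methods shown in the diff above.
    def __init__(self):
        self.entries: dict[str, tuple[datetime, Any]] = {}

    def store(self, call_id: str, data: Any) -> Any:
        self.entries[call_id] = (datetime.now(), data)
        return data  # 2.10.0 contract: hand the data back to the caller

    def exists(self, call_id: str) -> bool:
        return call_id in self.entries

    def checkpoint_date(self, call_id: str) -> datetime:
        return self.entries[call_id][0]

    def load(self, call_id: str) -> Any:
        return self.entries[call_id][1]

    def delete(self, call_id: str) -> None:
        self.entries.pop(call_id, None)
```

Returning the data from `store()` lets the caller chain on the result (as the new `_resolve_awaitable` in this diff does with `await self.storage.store(...)`).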
 
@@ -1,11 +1,11 @@
  import gc
  import tempfile
  from typing import Callable
- from .checkpoint import Checkpointer, CheckpointError, CheckpointFn
+ from .checkpoint import CachedFunction, Checkpointer, CheckpointError
  from .object_hash import ObjectHash
  from .storages import MemoryStorage, PickleStorage, Storage
+ from .utils import AwaitableValue
 
- create_checkpointer = Checkpointer
  checkpoint = Checkpointer()
  capture_checkpoint = Checkpointer(capture=True)
  memory_checkpoint = Checkpointer(format="memory", verbosity=0)
@@ -14,8 +14,8 @@ static_checkpoint = Checkpointer(fn_hash=ObjectHash())
 
  def cleanup_all(invalidated=True, expired=True):
    for obj in gc.get_objects():
-     if isinstance(obj, CheckpointFn):
+     if isinstance(obj, CachedFunction):
        obj.cleanup(invalidated=invalidated, expired=expired)
 
  def get_function_hash(fn: Callable, capture=False) -> str:
-   return CheckpointFn(Checkpointer(capture=capture), fn).fn_hash
+   return CachedFunction(Checkpointer(capture=capture), fn).fn_hash
@@ -1,11 +1,13 @@
  from __future__ import annotations
  import inspect
  import re
- from contextlib import suppress
  from datetime import datetime
  from functools import cached_property, update_wrapper
  from pathlib import Path
- from typing import Awaitable, Callable, Generic, Iterable, Literal, ParamSpec, Type, TypedDict, TypeVar, Unpack, cast, overload
+ from typing import (
+   Awaitable, Callable, Concatenate, Generic, Iterable, Literal,
+   ParamSpec, Self, Type, TypedDict, TypeVar, Unpack, cast, overload,
+ )
  from .fn_ident import get_fn_ident
  from .object_hash import ObjectHash
  from .print_checkpoint import print_checkpoint
@@ -15,6 +17,7 @@ from .utils import AwaitableValue, unwrap_fn
  Fn = TypeVar("Fn", bound=Callable)
  P = ParamSpec("P")
  R = TypeVar("R")
+ C = TypeVar("C")
 
  DEFAULT_DIR = Path.home() / ".cache/checkpoints"
 
@@ -43,17 +46,17 @@ class Checkpointer:
      self.fn_hash = opts.get("fn_hash")
 
    @overload
-   def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CheckpointFn[Fn]: ...
+   def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CachedFunction[Fn]: ...
    @overload
    def __call__(self, fn: None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer: ...
-   def __call__(self, fn: Fn | None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer | CheckpointFn[Fn]:
+   def __call__(self, fn: Fn | None=None, **override_opts: Unpack[CheckpointerOpts]) -> Checkpointer | CachedFunction[Fn]:
      if override_opts:
        opts = CheckpointerOpts(**{**self.__dict__, **override_opts})
        return Checkpointer(**opts)(fn)
 
-     return CheckpointFn(self, fn) if callable(fn) else self
+     return CachedFunction(self, fn) if callable(fn) else self
 
- class CheckpointFn(Generic[Fn]):
+ class CachedFunction(Generic[Fn]):
    def __init__(self, checkpointer: Checkpointer, fn: Fn):
      wrapped = unwrap_fn(fn)
      fn_file = Path(wrapped.__code__.co_filename).name
@@ -65,6 +68,19 @@ class CheckpointFn(Generic[Fn]):
      self.fn_dir = f"{fn_file}/{fn_name}"
      self.storage = Storage(self)
      self.cleanup = self.storage.cleanup
+     self.bound = ()
+
+   @overload
+   def __get__(self: Self, instance: None, owner: Type[C]) -> Self: ...
+   @overload
+   def __get__(self: CachedFunction[Callable[Concatenate[C, P], R]], instance: C, owner: Type[C]) -> CachedFunction[Callable[P, R]]: ...
+   def __get__(self, instance, owner):
+     if instance is None:
+       return self
+     bound_fn = object.__new__(CachedFunction)
+     bound_fn.__dict__ |= self.__dict__
+     bound_fn.bound = (instance,)
+     return bound_fn
 
    @cached_property
    def ident_tuple(self) -> tuple[str, list[Callable]]:
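The `__get__` added in the hunk above is what lets a `@checkpoint`-decorated method bind to its instance (the instance is captured in `self.bound` and prepended to the call arguments). A stdlib sketch of the same descriptor pattern, with hypothetical class names standing in for `CachedFunction`:

```python
class BoundSketch:
    # Minimal sketch of CachedFunction.__get__: a non-function callable
    # that participates in the descriptor protocol to support methods.
    def __init__(self, fn):
        self.fn = fn
        self.bound = ()

    def __get__(self, instance, owner):
        if instance is None:
            return self  # class-level access returns the wrapper itself
        clone = object.__new__(BoundSketch)
        clone.__dict__ |= self.__dict__  # shallow-copy wrapper state
        clone.bound = (instance,)        # remember the receiver
        return clone

    def __call__(self, *args, **kw):
        # Prepend the bound instance, as _call does with full_args.
        return self.fn(*self.bound, *args, **kw)

class Point:
    def __init__(self, x):
        self.x = x

    @BoundSketch
    def double(self):
        return self.x * 2
```

Plain functions get this binding for free because functions are already descriptors; a wrapper object like `CachedFunction` has to implement `__get__` itself, which is exactly what this release adds.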
@@ -80,33 +96,33 @@ class CheckpointFn(Generic[Fn]):
 
    @cached_property
    def fn_hash(self) -> str:
-     fn_hash = self.checkpointer.fn_hash
      deep_hashes = [depend.fn_hash_raw for depend in self.deep_depends()]
-     return str(fn_hash or ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes))[:32]
+     fn_hash = ObjectHash(digest_size=16).write_text(self.fn_hash_raw, *deep_hashes)
+     return str(self.checkpointer.fn_hash or fn_hash)[:32]
 
-   def reinit(self, recursive=False) -> CheckpointFn[Fn]:
+   def reinit(self, recursive=False) -> CachedFunction[Fn]:
      depends = list(self.deep_depends()) if recursive else [self]
      for depend in depends:
-       with suppress(AttributeError):
-         del depend.ident_tuple, depend.fn_hash
+       self.__dict__.pop("fn_hash", None)
+       self.__dict__.pop("ident_tuple", None)
      for depend in depends:
        depend.fn_hash
      return self
 
    def get_call_id(self, args: tuple, kw: dict) -> str:
+     args = self.bound + args
      hash_by = self.checkpointer.hash_by
      hash_params = hash_by(*args, **kw) if hash_by else (args, kw)
      return str(ObjectHash(hash_params, digest_size=16))
 
-   async def _resolve_awaitable(self, checkpoint_id: str, awaitable: Awaitable):
-     data = await awaitable
-     self.storage.store(checkpoint_id, AwaitableValue(data))
-     return data
+   async def _resolve_awaitable(self, call_id: str, awaitable: Awaitable):
+     return await self.storage.store(call_id, AwaitableValue(await awaitable))
 
-   def _call(self, args: tuple, kw: dict, rerun=False):
+   def _call(self: CachedFunction[Callable[P, R]], args: tuple, kw: dict, rerun=False) -> R:
+     full_args = self.bound + args
      params = self.checkpointer
      if not params.when:
-       return self.fn(*args, **kw)
+       return self.fn(*full_args, **kw)
 
      call_id = self.get_call_id(args, kw)
      call_id_long = f"{self.fn_dir}/{self.fn_hash}/{call_id}"
@@ -117,12 +133,10 @@ class CheckpointFn(Generic[Fn]):
 
      if refresh:
        print_checkpoint(params.verbosity >= 1, "MEMORIZING", call_id_long, "blue")
-       data = self.fn(*args, **kw)
+       data = self.fn(*full_args, **kw)
        if inspect.isawaitable(data):
          return self._resolve_awaitable(call_id, data)
-       else:
-         self.storage.store(call_id, data)
-         return data
+       return self.storage.store(call_id, data)
 
      try:
        data = self.storage.load(call_id)
@@ -133,14 +147,16 @@ class CheckpointFn(Generic[Fn]):
        print_checkpoint(params.verbosity >= 1, "CORRUPTED", call_id_long, "yellow")
        return self._call(args, kw, True)
 
-   __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
-   rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
+   def __call__(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
+     return self._call(args, kw)
+
+   def rerun(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> R:
+     return self._call(args, kw, True)
 
    @overload
    def get(self: Callable[P, Awaitable[R]], *args: P.args, **kw: P.kwargs) -> R: ...
    @overload
    def get(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> R: ...
-
    def get(self, *args, **kw):
      call_id = self.get_call_id(args, kw)
      try:
@@ -149,22 +165,20 @@ class CheckpointFn(Generic[Fn]):
      except Exception as ex:
        raise CheckpointError("Could not load checkpoint") from ex
 
-   def exists(self: Callable[P, R], *args: P.args, **kw: P.kwargs) -> bool: # type: ignore
-     self = cast(CheckpointFn, self)
+   def exists(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs) -> bool:
      return self.storage.exists(self.get_call_id(args, kw))
 
-   def delete(self: Callable[P, R], *args: P.args, **kw: P.kwargs): # type: ignore
-     self = cast(CheckpointFn, self)
+   def delete(self: CachedFunction[Callable[P, R]], *args: P.args, **kw: P.kwargs):
      self.storage.delete(self.get_call_id(args, kw))
 
    def __repr__(self) -> str:
-     return f"<CheckpointFn {self.fn.__name__} {self.fn_hash[:6]}>"
+     return f"<CachedFunction {self.fn.__name__} {self.fn_hash[:6]}>"
 
-   def deep_depends(self, visited: set[CheckpointFn] = set()) -> Iterable[CheckpointFn]:
+   def deep_depends(self, visited: set[CachedFunction] = set()) -> Iterable[CachedFunction]:
      if self not in visited:
        yield self
        visited = visited or set()
        visited.add(self)
        for depend in self.depends:
-         if isinstance(depend, CheckpointFn):
+         if isinstance(depend, CachedFunction):
            yield from depend.deep_depends(visited)
@@ -8,7 +8,7 @@ from typing import Any, Iterable, Type, TypeGuard
  from .object_hash import ObjectHash
  from .utils import AttrDict, distinct, get_cell_contents, iterate_and_upcoming, transpose, unwrap_fn
 
- cwd = Path.cwd()
+ cwd = Path.cwd().resolve()
 
  def is_class(obj) -> TypeGuard[Type]:
    # isinstance works too, but needlessly triggers _lazyinit()
@@ -72,23 +72,23 @@ def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
    return cwd in fn_path.parents and ".venv" not in fn_path.parts
 
  def get_depend_fns(fn: Callable, capture: bool, captured_vals_by_fn: dict[Callable, list[Any]] = {}) -> dict[Callable, list[Any]]:
-   from .checkpoint import CheckpointFn
+   from .checkpoint import CachedFunction
    captured_vals_by_fn = captured_vals_by_fn or {}
    captured_vals = get_fn_captured_vals(fn)
    captured_vals_by_fn[fn] = [val for val in captured_vals if not callable(val)] * capture
-   child_fns = (unwrap_fn(val, checkpoint_fn=True) for val in captured_vals if callable(val))
+   child_fns = (unwrap_fn(val, cached_fn=True) for val in captured_vals if callable(val))
    for child_fn in child_fns:
-     if isinstance(child_fn, CheckpointFn):
+     if isinstance(child_fn, CachedFunction):
        captured_vals_by_fn[child_fn] = []
      elif child_fn not in captured_vals_by_fn and is_user_fn(child_fn):
        get_depend_fns(child_fn, capture, captured_vals_by_fn)
    return captured_vals_by_fn
 
  def get_fn_ident(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
-   from .checkpoint import CheckpointFn
+   from .checkpoint import CachedFunction
    captured_vals_by_fn = get_depend_fns(fn, capture)
    depends, depend_captured_vals = transpose(captured_vals_by_fn.items(), 2)
    depends = distinct(fn.__func__ if isinstance(fn, MethodType) else fn for fn in depends)
-   unwrapped_depends = [fn for fn in depends if not isinstance(fn, CheckpointFn)]
+   unwrapped_depends = [fn for fn in depends if not isinstance(fn, CachedFunction)]
    fn_hash = str(ObjectHash(fn, unwrapped_depends).update(depend_captured_vals, tolerate_errors=True))
    return fn_hash, depends
@@ -180,8 +180,10 @@ class ObjectHash:
    def _update_iterator(self, obj: Iterable) -> "ObjectHash":
      return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")
 
-   def _update_object(self, obj: object) -> "ObjectHash":
+   def _update_object(self, obj: Any) -> "ObjectHash":
      self.header("instance", encode_type_of(obj))
+     if hasattr(obj, "__objecthash__") and callable(obj.__objecthash__):
+       return self.header("objecthash").update(obj.__objecthash__())
      reduced = None
      with suppress(Exception):
        reduced = obj.__reduce_ex__(PROTOCOL)
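The new `__objecthash__` hook above lets an object decide what payload represents it in a hash, before the usual `__reduce_ex__` path is tried. A stdlib sketch of that dispatch (the `object_hash` helper is hypothetical; the real class is `checkpointer.ObjectHash`):

```python
import hashlib

def object_hash(obj) -> str:
    # Prefer a __objecthash__() payload, as the diff above does;
    # fall back to repr() for objects without the hook.
    hook = getattr(obj, "__objecthash__", None)
    payload = hook() if callable(hook) else repr(obj)
    return hashlib.sha256(str(payload).encode()).hexdigest()

class Config:
    def __init__(self, version, debug_flag):
        self.version = version
        self.debug_flag = debug_flag

    def __objecthash__(self):
        # Only `version` participates in the hash; debug_flag is ignored,
        # so toggling it won't invalidate any cache keyed on this object.
        return self.version
```

This is useful for argument objects that carry fields irrelevant to the computation: the hook keeps cache keys stable across changes to those fields.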
@@ -2,10 +2,8 @@ from typing import Type
  from .storage import Storage
  from .pickle_storage import PickleStorage
  from .memory_storage import MemoryStorage
- from .bcolz_storage import BcolzStorage
 
  STORAGE_MAP: dict[str, Type[Storage]] = {
    "pickle": PickleStorage,
    "memory": MemoryStorage,
-   "bcolz": BcolzStorage,
  }
@@ -11,6 +11,7 @@ class MemoryStorage(Storage):
 
    def store(self, call_id, data):
      self.get_dict()[call_id] = (datetime.now(), data)
+     return data
 
    def exists(self, call_id):
      return call_id in self.get_dict()
@@ -1,8 +1,12 @@
  import pickle
  import shutil
  from datetime import datetime
+ from pathlib import Path
  from .storage import Storage
 
+ def filedate(path: Path) -> datetime:
+   return datetime.fromtimestamp(path.stat().st_mtime)
+
  class PickleStorage(Storage):
    def get_path(self, call_id: str):
      return self.fn_dir() / f"{call_id}.pkl"
@@ -12,13 +16,14 @@ class PickleStorage(Storage):
      path.parent.mkdir(parents=True, exist_ok=True)
      with path.open("wb") as file:
        pickle.dump(data, file, -1)
+     return data
 
    def exists(self, call_id):
      return self.get_path(call_id).exists()
 
    def checkpoint_date(self, call_id):
      # Should use st_atime/access time?
-     return datetime.fromtimestamp(self.get_path(call_id).stat().st_mtime)
+     return filedate(self.get_path(call_id))
 
    def load(self, call_id):
      with self.get_path(call_id).open("rb") as file:
@@ -34,11 +39,11 @@ class PickleStorage(Storage):
      old_dirs = [path for path in fn_path.iterdir() if path.is_dir() and path != version_path]
      for path in old_dirs:
        shutil.rmtree(path)
-     print(f"Removed {len(old_dirs)} invalidated directories for {self.checkpoint_fn.__qualname__}")
+     print(f"Removed {len(old_dirs)} invalidated directories for {self.cached_fn.__qualname__}")
      if expired and self.checkpointer.should_expire:
        count = 0
-       for pkl_path in fn_path.rglob("*.pkl"):
-         if self.checkpointer.should_expire(self.checkpoint_date(pkl_path.stem)):
+       for pkl_path in fn_path.glob("**/*.pkl"):
+         if self.checkpointer.should_expire(filedate(pkl_path)):
            count += 1
-           self.delete(pkl_path.stem)
-       print(f"Removed {count} expired checkpoints for {self.checkpoint_fn.__qualname__}")
+           pkl_path.unlink(missing_ok=True)
+       print(f"Removed {count} expired checkpoints for {self.cached_fn.__qualname__}")
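The cleanup path above drives expiry by passing each checkpoint's file modification time to the user-supplied `should_expire` callable. A sketch of such a predicate (the 7-day threshold is illustrative, not a checkpointer default):

```python
from datetime import datetime, timedelta

def expire_after_7_days(checkpoint_time: datetime) -> bool:
    # should_expire-style predicate: True means the checkpoint is stale
    # and should be removed/recomputed.
    return datetime.now() - checkpoint_time > timedelta(days=7)
```

A predicate like this would be passed as `@checkpoint(should_expire=expire_after_7_days)` per the configuration section earlier in the diff.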
@@ -4,23 +4,23 @@ from pathlib import Path
 from datetime import datetime

 if TYPE_CHECKING:
-  from ..checkpoint import Checkpointer, CheckpointFn
+  from ..checkpoint import Checkpointer, CachedFunction

 class Storage:
   checkpointer: Checkpointer
-  checkpoint_fn: CheckpointFn
+  cached_fn: CachedFunction

-  def __init__(self, checkpoint_fn: CheckpointFn):
-    self.checkpointer = checkpoint_fn.checkpointer
-    self.checkpoint_fn = checkpoint_fn
+  def __init__(self, cached_fn: CachedFunction):
+    self.checkpointer = cached_fn.checkpointer
+    self.cached_fn = cached_fn

   def fn_id(self) -> str:
-    return f"{self.checkpoint_fn.fn_dir}/{self.checkpoint_fn.fn_hash}"
+    return f"{self.cached_fn.fn_dir}/{self.cached_fn.fn_hash}"

   def fn_dir(self) -> Path:
     return self.checkpointer.root_path / self.fn_id()

-  def store(self, call_id: str, data: Any) -> None: ...
+  def store(self, call_id: str, data: Any) -> Any: ...

   def exists(self, call_id: str) -> bool: ...

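Note the `store` signature change above: it now returns `Any` instead of `None`, so a backend can hand the stored value straight back to the caller (matching `return data` in `PickleStorage.store`). A minimal in-memory sketch of the updated contract (`MemoryStorage` is hypothetical, not part of the package):

```python
from typing import Any

class MemoryStorage:
  """Illustrates the new Storage contract: store() returns the data it saved."""
  def __init__(self):
    self._data: dict[str, Any] = {}

  def store(self, call_id: str, data: Any) -> Any:
    self._data[call_id] = data
    # Returning the value lets callers chain: return storage.store(id, result)
    return data

  def exists(self, call_id: str) -> bool:
    return call_id in self._data

  def load(self, call_id: str) -> Any:
    return self._data[call_id]
```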
@@ -1,7 +1,6 @@
 import asyncio
 import pytest
 from riprint import riprint as print
-from types import MethodType, MethodWrapperType
 from . import checkpoint
 from .checkpoint import CheckpointError
 from .utils import AttrDict
@@ -32,10 +32,10 @@ def get_cell_contents(fn: Callable) -> Iterable[tuple[str, Any]]:
   except ValueError:
     pass

-def unwrap_fn(fn: Fn, checkpoint_fn=False) -> Fn:
-  from .checkpoint import CheckpointFn
+def unwrap_fn(fn: Fn, cached_fn=False) -> Fn:
+  from .checkpoint import CachedFunction
   while True:
-    if (checkpoint_fn and isinstance(fn, CheckpointFn)) or not hasattr(fn, "__wrapped__"):
+    if (cached_fn and isinstance(fn, CachedFunction)) or not hasattr(fn, "__wrapped__"):
      return cast(Fn, fn)
    fn = getattr(fn, "__wrapped__")

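The `unwrap_fn` helper above walks the `__wrapped__` chain that `functools.wraps` leaves on decorated functions, optionally stopping early at a `CachedFunction`. A standalone sketch of the same idea (simplified: no early-exit parameter):

```python
import functools
from typing import Callable

def unwrap(fn: Callable) -> Callable:
  # Follow the __wrapped__ chain until the innermost function is reached.
  while hasattr(fn, "__wrapped__"):
    fn = fn.__wrapped__
  return fn

def base(x: int) -> int:
  return x * 2

@functools.wraps(base)
def wrapper(x: int) -> int:
  return base(x)
```

Here `unwrap(wrapper)` recovers `base`, because `functools.wraps` sets `wrapper.__wrapped__` to the wrapped function.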
@@ -1,6 +1,6 @@
 [project]
 name = "checkpointer"
-version = "2.9.1"
+version = "2.10.0"
 requires-python = ">=3.11"
 dependencies = []
 authors = [
@@ -8,7 +8,7 @@ resolution-markers = [

 [[package]]
 name = "checkpointer"
-version = "2.9.1"
+version = "2.10.0"
 source = { editable = "." }

 [package.dev-dependencies]
@@ -1,80 +0,0 @@
-import shutil
-from pathlib import Path
-from datetime import datetime
-from .storage import Storage
-
-def get_data_type_str(x):
-  if isinstance(x, tuple):
-    return "tuple"
-  elif isinstance(x, dict):
-    return "dict"
-  elif isinstance(x, list):
-    return "list"
-  elif isinstance(x, str) or not hasattr(x, "__len__"):
-    return "other"
-  else:
-    return "ndarray"
-
-def get_metapath(path: Path):
-  return path.with_name(f"{path.name}_meta")
-
-def insert_data(path: Path, data):
-  import bcolz
-  c = bcolz.carray(data, rootdir=path, mode="w")
-  c.flush()
-
-class BcolzStorage(Storage):
-  def exists(self, path):
-    return path.exists()
-
-  def checkpoint_date(self, path):
-    return datetime.fromtimestamp(path.stat().st_mtime)
-
-  def store(self, path, data):
-    metapath = get_metapath(path)
-    path.parent.mkdir(parents=True, exist_ok=True)
-    data_type_str = get_data_type_str(data)
-    if data_type_str == "tuple":
-      fields = list(range(len(data)))
-    elif data_type_str == "dict":
-      fields = sorted(data.keys())
-    else:
-      fields = []
-    meta_data = {"data_type_str": data_type_str, "fields": fields}
-    insert_data(metapath, meta_data)
-    if data_type_str in ["tuple", "dict"]:
-      for i in range(len(fields)):
-        child_path = Path(f"{path} ({i})")
-        self.store(child_path, data[fields[i]])
-    else:
-      insert_data(path, data)
-
-  def load(self, path):
-    import bcolz
-    metapath = get_metapath(path)
-    meta_data = bcolz.open(metapath)[:][0]
-    data_type_str = meta_data["data_type_str"]
-    if data_type_str in ["tuple", "dict"]:
-      fields = meta_data["fields"]
-      partitions = range(len(fields))
-      data = [self.load(Path(f"{path} ({i})")) for i in partitions]
-      if data_type_str == "tuple":
-        return tuple(data)
-      else:
-        return dict(zip(fields, data))
-    else:
-      data = bcolz.open(path)
-      if data_type_str == "list":
-        return list(data)
-      elif data_type_str == "other":
-        return data[0]
-      else:
-        return data[:]
-
-  def delete(self, path):
-    # NOTE: Not recursive
-    shutil.rmtree(get_metapath(path), ignore_errors=True)
-    shutil.rmtree(path, ignore_errors=True)
-
-  def cleanup(self, invalidated=True, expired=True):
-    raise NotImplementedError("cleanup() not implemented for bcolz storage")