checkpointer 2.0.2__tar.gz → 2.1.0__tar.gz

@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: checkpointer
- Version: 2.0.2
+ Version: 2.1.0
  Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
  Project-URL: Repository, https://github.com/Reddan/checkpointer.git
  Author: Hampus Hallman
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
 
  `checkpointer` is a Python library for memoizing function results. It provides a decorator-based API with support for multiple storage backends. Use it for computationally expensive operations where caching can save time, or during development to avoid waiting for redundant computations.
 
- Adding or removing `@checkpoint` doesn't change how your code works, and it can be applied to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
+ Adding or removing `@checkpoint` doesn't change how your code works. You can apply it to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
 
  ### Key Features:
  - 🗂️ **Multiple Storage Backends**: Built-in support for in-memory and pickle-based storage, or create your own.
@@ -27,6 +27,7 @@ Adding or removing `@checkpoint` doesn't change how your code works, and it can
  - 🔄 **Async and Sync Compatibility**: Works with synchronous functions and any Python async runtime (e.g., `asyncio`, `Trio`, `Curio`).
  - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
  - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
+ - 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
 
  ---
 
@@ -59,8 +60,10 @@ result = expensive_function(4) # Loads from the cache
  When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are hashed to create a unique identifier for each call. This identifier is used to store and retrieve cached results. If the same arguments are passed again, `checkpointer` loads the cached result instead of recomputing.
 
  Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
+
  1. **Its source code**: Changes to the function's code update its hash.
  2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
+ 3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
 
  ### Example: Cache Invalidation
 
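The third hash input is new in 2.1.0. For closures, `get_fn_captured_vals` (in the new module further down) pairs `fn.__code__.co_freevars` with `fn.__closure__` and reads each cell, treating an empty cell as `None`. The standalone sketch below reproduces only that closure-reading step with the standard library. `read_cell`, `read_closure`, `make_scaler`, and `before_assignment` are hypothetical names and are not part of `checkpointer`.

```python
from types import CellType
from typing import Any, Callable


def read_cell(cell: CellType) -> Any:
    # Accessing an empty cell (variable not assigned yet) raises ValueError
    try:
        return cell.cell_contents
    except ValueError:
        return None


def read_closure(fn: Callable) -> dict[str, Any]:
    # co_freevars lists captured names in the same order as fn.__closure__
    return {
        name: read_cell(cell)
        for name, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ())
    }


def make_scaler(factor: int) -> Callable[[int], int]:
    def scale(x: int) -> int:
        return x * factor
    return scale


def before_assignment() -> dict[str, Any]:
    def inner():
        return late
    seen = read_closure(inner)  # `late` has no value yet, so its cell is empty
    late = 1
    return seen


print(read_closure(make_scaler(4)))  # {'factor': 4}
print(before_assignment())           # {'late': None}
```

Global variables are handled separately in the library: it scans the function's bytecode for the names it loads and looks them up in `fn.__globals__`.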
@@ -115,7 +118,17 @@ def some_expensive_function():
 
  ## Usage
 
+ ### Basic Invocation and Caching
+
+ Call the decorated function as usual. On the first call, the result is computed and stored in the cache. Subsequent calls with the same arguments load the result from the cache:
+
+ ```python
+ result = expensive_function(4) # Computes and stores the result
+ result = expensive_function(4) # Loads the result from the cache
+ ```
+
  ### Force Recalculation
+
  Force a recalculation and overwrite the stored checkpoint:
 
  ```python
@@ -123,6 +136,7 @@ result = expensive_function.rerun(4)
  ```
 
  ### Call the Original Function
+
  Use `fn` to directly call the original, undecorated function:
 
  ```python
@@ -132,6 +146,7 @@ result = expensive_function.fn(4)
  This is especially useful **inside recursive functions** to avoid redundant caching of intermediate steps while still caching the final result.
 
  ### Retrieve Stored Checkpoints
+
  Access cached results without recalculating:
 
  ```python
@@ -154,11 +169,11 @@ You can specify a storage backend using either its name (`"pickle"` or `"memory"
  ```python
  from checkpointer import checkpoint, PickleStorage, MemoryStorage
 
- @checkpoint(format="pickle") # Equivalent to format=PickleStorage
+ @checkpoint(format="pickle") # Short for format=PickleStorage
  def disk_cached(x: int) -> int:
  return x ** 2
 
- @checkpoint(format="memory") # Equivalent to format=MemoryStorage
+ @checkpoint(format="memory") # Short for format=MemoryStorage
  def memory_cached(x: int) -> int:
  return x * 10
  ```
@@ -191,14 +206,15 @@ Using a custom backend lets you tailor storage to your application, whether it i
 
  ## Configuration Options ⚙️
 
- | Option | Type | Default | Description |
- |----------------|-------------------------------------|-------------|---------------------------------------------|
- | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
- | `root_path` | `Path`, `str`, or `None` | User Cache | Root directory for storing checkpoints. |
- | `when` | `bool` | `True` | Enable or disable checkpointing. |
- | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
- | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
- | `should_expire`| `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
+ | Option | Type | Default | Description |
+ |-----------------|-----------------------------------|----------------------|------------------------------------------------|
+ | `capture` | `bool` | `False` | Include captured variables in function hashes. |
+ | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
+ | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
+ | `when` | `bool` | `True` | Enable or disable checkpointing. |
+ | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
+ | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
+ | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
 
  ---
 
@@ -220,13 +236,13 @@ async def async_compute_sum(a: int, b: int) -> int:
 
  async def main():
  result1 = compute_square(5)
- print(result1)
+ print(result1) # Outputs 25
 
  result2 = await async_compute_sum(3, 7)
- print(result2)
+ print(result2) # Outputs 10
 
- result3 = async_compute_sum.get(3, 7)
- print(result3)
+ result3 = await async_compute_sum.get(3, 7)
+ print(result3) # Outputs 10
 
  asyncio.run(main())
  ```
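The README now awaits `async_compute_sum.get(3, 7)`. In 2.1.0, `_get` wraps the loaded value with `resolved_awaitable` when the decorated function is async, so `.get` returns an awaitable for async functions and a plain value for sync ones. The sketch below mimics that branch. `CACHE` and `get` are hypothetical stand-ins for the storage lookup; only the shape of `resolved_awaitable` matches the helper added in `utils`.

```python
import asyncio
from typing import Any, TypeVar

T = TypeVar("T")


async def resolved_awaitable(value: T) -> T:
    # Awaiting this simply hands the value back
    return value


CACHE = {(3, 7): 10}  # hypothetical stand-in for a stored checkpoint


def get(is_async: bool, *args: int) -> Any:
    val = CACHE[args]
    # Async callers get something they can `await`; sync callers get the value
    return resolved_awaitable(val) if is_async else val


async def main() -> int:
    return await get(True, 3, 7)


print(get(False, 3, 7))     # 10
print(asyncio.run(main()))  # 10
```

Because an `async def` helper has no runtime dependency, the same wrapper works under any event loop that can await a plain coroutine.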
@@ -2,7 +2,7 @@
 
  `checkpointer` is a Python library for memoizing function results. It provides a decorator-based API with support for multiple storage backends. Use it for computationally expensive operations where caching can save time, or during development to avoid waiting for redundant computations.
 
- Adding or removing `@checkpoint` doesn't change how your code works, and it can be applied to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
+ Adding or removing `@checkpoint` doesn't change how your code works. You can apply it to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
 
  ### Key Features:
  - 🗂️ **Multiple Storage Backends**: Built-in support for in-memory and pickle-based storage, or create your own.
@@ -10,6 +10,7 @@ Adding or removing `@checkpoint` doesn't change how your code works, and it can
  - 🔄 **Async and Sync Compatibility**: Works with synchronous functions and any Python async runtime (e.g., `asyncio`, `Trio`, `Curio`).
  - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
  - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
+ - 📦 **Captured Variables Handling**: Optionally include captured variables in cache invalidation.
 
  ---
 
@@ -42,8 +43,10 @@ result = expensive_function(4) # Loads from the cache
  When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are hashed to create a unique identifier for each call. This identifier is used to store and retrieve cached results. If the same arguments are passed again, `checkpointer` loads the cached result instead of recomputing.
 
  Additionally, `checkpointer` ensures that caches are invalidated when a function's implementation or any of its dependencies change. Each function is assigned a hash based on:
+
  1. **Its source code**: Changes to the function's code update its hash.
  2. **Dependent functions**: If a function calls others, changes in those dependencies will also update the hash.
+ 3. **Captured variables**: (Optional) If `capture=True`, changes to captured variables and global variables will also update the hash.
 
  ### Example: Cache Invalidation
 
@@ -98,7 +101,17 @@ def some_expensive_function():
 
  ## Usage
 
+ ### Basic Invocation and Caching
+
+ Call the decorated function as usual. On the first call, the result is computed and stored in the cache. Subsequent calls with the same arguments load the result from the cache:
+
+ ```python
+ result = expensive_function(4) # Computes and stores the result
+ result = expensive_function(4) # Loads the result from the cache
+ ```
+
  ### Force Recalculation
+
  Force a recalculation and overwrite the stored checkpoint:
 
  ```python
@@ -106,6 +119,7 @@ result = expensive_function.rerun(4)
  ```
 
  ### Call the Original Function
+
  Use `fn` to directly call the original, undecorated function:
 
  ```python
@@ -115,6 +129,7 @@ result = expensive_function.fn(4)
  This is especially useful **inside recursive functions** to avoid redundant caching of intermediate steps while still caching the final result.
 
  ### Retrieve Stored Checkpoints
+
  Access cached results without recalculating:
 
  ```python
@@ -137,11 +152,11 @@ You can specify a storage backend using either its name (`"pickle"` or `"memory"
  ```python
  from checkpointer import checkpoint, PickleStorage, MemoryStorage
 
- @checkpoint(format="pickle") # Equivalent to format=PickleStorage
+ @checkpoint(format="pickle") # Short for format=PickleStorage
  def disk_cached(x: int) -> int:
  return x ** 2
 
- @checkpoint(format="memory") # Equivalent to format=MemoryStorage
+ @checkpoint(format="memory") # Short for format=MemoryStorage
  def memory_cached(x: int) -> int:
  return x * 10
  ```
@@ -174,14 +189,15 @@ Using a custom backend lets you tailor storage to your application, whether it i
 
  ## Configuration Options ⚙️
 
- | Option | Type | Default | Description |
- |----------------|-------------------------------------|-------------|---------------------------------------------|
- | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
- | `root_path` | `Path`, `str`, or `None` | User Cache | Root directory for storing checkpoints. |
- | `when` | `bool` | `True` | Enable or disable checkpointing. |
- | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
- | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
- | `should_expire`| `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
+ | Option | Type | Default | Description |
+ |-----------------|-----------------------------------|----------------------|------------------------------------------------|
+ | `capture` | `bool` | `False` | Include captured variables in function hashes. |
+ | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
+ | `root_path` | `Path`, `str`, or `None` | ~/.cache/checkpoints | Root directory for storing checkpoints. |
+ | `when` | `bool` | `True` | Enable or disable checkpointing. |
+ | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
+ | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
+ | `should_expire` | `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
 
  ---
 
@@ -203,13 +219,13 @@ async def async_compute_sum(a: int, b: int) -> int:
 
  async def main():
  result1 = compute_square(5)
- print(result1)
+ print(result1) # Outputs 25
 
  result2 = await async_compute_sum(3, 7)
- print(result2)
+ print(result2) # Outputs 10
 
- result3 = async_compute_sum.get(3, 7)
- print(result3)
+ result3 = await async_compute_sum.get(3, 7)
+ print(result3) # Outputs 10
 
  asyncio.run(main())
  ```
@@ -5,5 +5,6 @@ import tempfile
 
  create_checkpointer = Checkpointer
  checkpoint = Checkpointer()
+ capture_checkpoint = Checkpointer(capture=True)
  memory_checkpoint = Checkpointer(format="memory", verbosity=0)
  tmp_checkpoint = Checkpointer(root_path=tempfile.gettempdir() + "/checkpoints")
@@ -7,16 +7,13 @@ from datetime import datetime
  from functools import update_wrapper
  from .types import Storage
  from .function_body import get_function_hash
- from .utils import unwrap_fn, sync_resolve_coroutine
- from .storages.pickle_storage import PickleStorage
- from .storages.memory_storage import MemoryStorage
- from .storages.bcolz_storage import BcolzStorage
+ from .utils import unwrap_fn, sync_resolve_coroutine, resolved_awaitable
+ from .storages import STORAGE_MAP
  from .print_checkpoint import print_checkpoint
 
  Fn = TypeVar("Fn", bound=Callable)
 
  DEFAULT_DIR = Path.home() / ".cache/checkpoints"
- STORAGE_MAP: dict[str, Type[Storage]] = {"memory": MemoryStorage, "pickle": PickleStorage, "bcolz": BcolzStorage}
 
  class CheckpointError(Exception):
  pass
@@ -28,6 +25,7 @@ class CheckpointerOpts(TypedDict, total=False):
  verbosity: Literal[0, 1]
  path: Callable[..., str] | None
  should_expire: Callable[[datetime], bool] | None
+ capture: bool
 
  class Checkpointer:
  def __init__(self, **opts: Unpack[CheckpointerOpts]):
@@ -37,6 +35,7 @@ class Checkpointer:
  self.verbosity = opts.get("verbosity", 1)
  self.path = opts.get("path")
  self.should_expire = opts.get("should_expire")
+ self.capture = opts.get("capture", False)
 
  @overload
  def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> CheckpointFn[Fn]: ...
@@ -57,14 +56,16 @@ class CheckpointFn(Generic[Fn]):
  storage = STORAGE_MAP[checkpointer.format] if isinstance(checkpointer.format, str) else checkpointer.format
  self.checkpointer = checkpointer
  self.fn = fn
- self.fn_hash = get_function_hash(wrapped)
+ self.fn_hash, self.depends = get_function_hash(wrapped, self.checkpointer.capture)
  self.fn_id = f"{file_name}/{wrapped.__name__}"
  self.is_async = inspect.iscoroutinefunction(wrapped)
  self.storage = storage(checkpointer)
 
  def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
  if not callable(self.checkpointer.path):
- return f"{self.fn_id}/{hashing.hash([self.fn_hash, args, kw or 0])}"
+ # TODO: use digest size before digesting instead of truncating the hash
+ call_hash = hashing.hash((self.fn_hash, args, kw), "blake2b")[:32]
+ return f"{self.fn_id}/{call_hash}"
  checkpoint_id = self.checkpointer.path(*args, **kw)
  if not isinstance(checkpoint_id, str):
  raise CheckpointError(f"path function must return a string, got {type(checkpoint_id)}")
@@ -101,12 +102,17 @@ class CheckpointFn(Generic[Fn]):
  coroutine = self._store_on_demand(args, kw, rerun)
  return coroutine if self.is_async else sync_resolve_coroutine(coroutine)
 
- __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
- rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
-
- def get(self, *args, **kw) -> Any:
+ def _get(self, args, kw) -> Any:
  checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
  try:
- return self.storage.load(checkpoint_path)
+ val = self.storage.load(checkpoint_path)
+ return resolved_awaitable(val) if self.is_async else val
  except:
  raise CheckpointError("Could not load checkpoint")
+
+ def exists(self, *args: tuple, **kw: dict) -> bool:
+ return self.storage.exists(self.checkpointer.root_path / self.get_checkpoint_id(args, kw))
+
+ __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
+ rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
+ get: Fn = cast(Fn, lambda self, *args, **kw: self._get(args, kw))
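The new `get_checkpoint_id` slices a BLAKE2b hex digest down to 32 characters. The TODO above it suggests setting the digest size before hashing instead. These two approaches are not interchangeable. BLAKE2's output length is part of its parameter block, so a 16-byte BLAKE2b digest (32 hex characters) is a different value from the first 32 hex characters of the default 64-byte digest.

The sketch below uses `hashlib` directly, with a hypothetical `repr`-based payload. The library itself goes through `relib.hashing`, whose serialization is not shown in this diff. Resolving the TODO would therefore presumably change every existing checkpoint id.

```python
import hashlib

# Hypothetical serialization of (fn_hash, args, kw); relib's real encoding may differ
payload = repr(("fn-hash", (3, 7), {})).encode()

# Current approach: full 64-byte digest, hex string cut to 32 characters
truncated = hashlib.blake2b(payload).hexdigest()[:32]

# TODO approach: ask BLAKE2b for a 16-byte digest up front (also 32 hex characters)
sized = hashlib.blake2b(payload, digest_size=16).hexdigest()

print(len(truncated), len(sized))  # 32 32
print(truncated == sized)          # False
```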
@@ -0,0 +1,80 @@
+ from __future__ import annotations
+ import dis
+ import inspect
+ import tokenize
+ from io import StringIO
+ from collections.abc import Callable
+ from itertools import chain, takewhile
+ from operator import itemgetter
+ from pathlib import Path
+ from typing import Any, TypeGuard, TYPE_CHECKING
+ from types import CodeType, FunctionType
+ from relib import transpose, hashing, merge_dicts, drop_none
+ from .utils import unwrap_fn, iterate_and_upcoming, get_cell_contents, AttrDict, get_at_attr
+
+ if TYPE_CHECKING:
+ from .checkpoint import CheckpointFn
+
+ cwd = Path.cwd()
+
+ def extract_scope_values(code: CodeType, scope_vars: dict[str, Any], closure = False) -> dict[tuple[str, ...], Any]:
+ opname = "LOAD_GLOBAL" if not closure else "LOAD_DEREF"
+ scope_values_by_path: dict[tuple[str, ...], Any] = {}
+ instructions = list(dis.get_instructions(code))
+
+ for instr, upcoming_instrs in iterate_and_upcoming(instructions):
+ if instr.opname == opname:
+ name = instr.argval
+ attrs = takewhile(lambda instr: instr.opname == "LOAD_ATTR", upcoming_instrs)
+ attr_path = (name, *(instr.argval for instr in attrs))
+ scope_values_by_path[attr_path] = get_at_attr(scope_vars, attr_path)
+
+ children = (extract_scope_values(const, scope_vars, closure) for const in code.co_consts if isinstance(const, CodeType))
+ return merge_dicts(scope_values_by_path, *children)
+
+ def get_fn_captured_vals(fn: Callable) -> list[Any]:
+ closure_scope = {k: get_cell_contents(v) for k, v in zip(fn.__code__.co_freevars, fn.__closure__ or [])}
+ global_vals = extract_scope_values(fn.__code__, AttrDict(fn.__globals__), closure=False)
+ closure_vals = extract_scope_values(fn.__code__, AttrDict(closure_scope), closure=True)
+ sorted_items = chain(sorted(global_vals.items()), sorted(closure_vals.items()))
+ return drop_none(map(itemgetter(1), sorted_items))
+
+ def get_fn_body(fn: Callable) -> str:
+ source = "".join(inspect.getsourcelines(fn)[0])
+ tokens = tokenize.generate_tokens(StringIO(source).readline)
+ ignore_types = (tokenize.COMMENT, tokenize.NL)
+ return "".join("\0" + token.string for token in tokens if token.type not in ignore_types)
+
+ def get_fn_path(fn: Callable) -> Path:
+ return Path(inspect.getfile(fn)).resolve()
+
+ def is_user_fn(candidate_fn) -> TypeGuard[Callable]:
+ return isinstance(candidate_fn, FunctionType) \
+ and cwd in get_fn_path(candidate_fn).parents
+
+ def append_fn_depends(checkpoint_fns: set[CheckpointFn], captured_vals_by_fn: dict[Callable, list[Any]], fn: Callable, capture: bool) -> None:
+ from .checkpoint import CheckpointFn
+ captured_vals = get_fn_captured_vals(fn)
+ captured_vals_by_fn[fn] = [v for v in captured_vals if capture and not callable(v)]
+ callables = [unwrap_fn(val, checkpoint_fn=True) for val in captured_vals if callable(val)]
+ depends = {val for val in callables if is_user_fn(val)}
+ checkpoint_fns.update({val for val in callables if isinstance(val, CheckpointFn)})
+ not_appended = depends - captured_vals_by_fn.keys()
+ captured_vals_by_fn.update({fn: [] for fn in not_appended})
+ for child_fn in not_appended:
+ append_fn_depends(checkpoint_fns, captured_vals_by_fn, child_fn, capture)
+
+ def get_depend_fns(fn: Callable, capture: bool) -> tuple[set[CheckpointFn], dict[Callable, list[Any]]]:
+ checkpoint_fns: set[CheckpointFn] = set()
+ captured_vals_by_fn: dict[Callable, list[Any]] = {}
+ append_fn_depends(checkpoint_fns, captured_vals_by_fn, fn, capture)
+ return checkpoint_fns, captured_vals_by_fn
+
+ def get_function_hash(fn: Callable, capture: bool) -> tuple[str, list[Callable]]:
+ checkpoint_fns, captured_vals_by_fn = get_depend_fns(fn, capture)
+ checkpoint_fns = sorted(checkpoint_fns, key=lambda fn: unwrap_fn(fn).__qualname__)
+ checkpoint_hashes = [check.fn_hash for check in checkpoint_fns]
+ depend_fns, depend_captured_vals = transpose(sorted(captured_vals_by_fn.items(), key=lambda x: x[0].__qualname__), 2)
+ fn_bodies = list(map(get_fn_body, [fn] + depend_fns))
+ fn_hash = hashing.hash((fn_bodies, depend_captured_vals, checkpoint_hashes), "blake2b")
+ return fn_hash, checkpoint_fns + depend_fns
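`extract_scope_values` finds what a function reads by walking its bytecode with `dis.get_instructions`. It records each `LOAD_GLOBAL` (or `LOAD_DEREF` for closures) together with any `LOAD_ATTR` instructions that immediately follow, and recurses into nested code objects found in `co_consts`.

The standalone sketch below collects only the global attribute paths. It does not resolve values or handle closures. `global_paths`, `area`, `total`, and `SCALE` are hypothetical names. Opcode names vary across CPython versions: the package requires 3.12+, and the `LOAD_METHOD` check is included only so the sketch also behaves on 3.11.

```python
import dis
import math
from types import CodeType

SCALE = 3  # hypothetical module-level value a function might capture


def global_paths(code: CodeType) -> set[tuple[str, ...]]:
    """Collect (global, attr, ...) paths read by `code` and its nested code objects."""
    instrs = list(dis.get_instructions(code))
    paths: set[tuple[str, ...]] = set()
    for i, instr in enumerate(instrs):
        if instr.opname != "LOAD_GLOBAL":
            continue
        path = [instr.argval]
        # Follow an attribute chain such as math.pi (LOAD_METHOD exists only before 3.12)
        for nxt in instrs[i + 1:]:
            if nxt.opname not in ("LOAD_ATTR", "LOAD_METHOD"):
                break
            path.append(nxt.argval)
        paths.add(tuple(path))
    # Lambdas, nested defs and generator expressions are separate code objects in co_consts
    for const in code.co_consts:
        if isinstance(const, CodeType):
            paths |= global_paths(const)
    return paths


def area(r: float) -> float:
    return math.pi * r * r * SCALE


def total(xs: list[int]) -> int:
    return sum(map(lambda x: x * SCALE, xs))


print(sorted(global_paths(area.__code__)))  # [('SCALE',), ('math', 'pi')]
```

Recording the full path (`math.pi` rather than `math`) allows a specific attribute value to be hashed instead of a whole module.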
@@ -0,0 +1,11 @@
+ from typing import Type
+ from ..types import Storage
+ from .pickle_storage import PickleStorage
+ from .memory_storage import MemoryStorage
+ from .bcolz_storage import BcolzStorage
+
+ STORAGE_MAP: dict[str, Type[Storage]] = {
+ "pickle": PickleStorage,
+ "memory": MemoryStorage,
+ "bcolz": BcolzStorage,
+ }
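With the registry moved into its own module, `CheckpointFn.__init__` resolves the `format` option with a single expression: a string is looked up in `STORAGE_MAP`, and anything else is treated as a storage class. The sketch below reproduces that lookup. The classes are empty placeholders, not the real `checkpointer` storages, and `MyStorage` and `resolve_storage` are hypothetical.

```python
from typing import Type, Union


# Placeholder classes standing in for checkpointer's real Storage subclasses
class Storage: ...
class PickleStorage(Storage): ...
class MemoryStorage(Storage): ...
class MyStorage(Storage): ...  # a hypothetical user-defined backend


STORAGE_MAP: dict[str, Type[Storage]] = {
    "pickle": PickleStorage,
    "memory": MemoryStorage,
}


def resolve_storage(fmt: Union[str, Type[Storage]]) -> Type[Storage]:
    # Names go through the registry; classes are used as-is
    return STORAGE_MAP[fmt] if isinstance(fmt, str) else fmt


print(resolve_storage("memory").__name__)   # MemoryStorage
print(resolve_storage(MyStorage).__name__)  # MyStorage
```

With this expression, an unregistered name surfaces as a plain `KeyError` rather than a library-specific error.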
@@ -0,0 +1,52 @@
+ from typing import Generator, Coroutine, Iterable, Any, cast
+ from types import CellType, coroutine
+ from itertools import islice
+
+ class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+ def __getattribute__(self, name: str) -> Any:
+ return super().__getattribute__(name)
+
+ def unwrap_fn[T](fn: T, checkpoint_fn=False) -> T:
+ from .checkpoint import CheckpointFn
+ while hasattr(fn, "__wrapped__"):
+ if checkpoint_fn and isinstance(fn, CheckpointFn):
+ return fn
+ fn = getattr(fn, "__wrapped__")
+ return fn
+
+ @coroutine
+ def coroutine_as_generator[T](coroutine: Coroutine[None, None, T]) -> Generator[None, None, T]:
+ val = yield from coroutine
+ return val
+
+ def sync_resolve_coroutine[T](coroutine: Coroutine[None, None, T]) -> T:
+ gen = cast(Generator, coroutine_as_generator(coroutine))
+ try:
+ while True: next(gen)
+ except StopIteration as ex:
+ return ex.value
+
+ async def resolved_awaitable[T](value: T) -> T:
+ return value
+
+ def iterate_and_upcoming[T](l: list[T]) -> Iterable[tuple[T, Iterable[T]]]:
+ for i, item in enumerate(l):
+ yield item, islice(l, i + 1, None)
+
+ def get_at_attr(d: dict, keys: tuple[str, ...]) -> Any:
+ try:
+ for key in keys:
+ d = getattr(d, key)
+ except AttributeError:
+ return None
+ return d
+
+ def get_cell_contents(cell: CellType) -> Any:
+ try:
+ return cell.cell_contents
+ except ValueError:
+ return None
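The removed 2.0.2 `sync_resolve_coroutine` (further down) called `next` exactly once. A coroutine that suspended more than once therefore came back as `None` instead of its result. The 2.1.0 version keeps advancing until `StopIteration`.

The sketch below reproduces the difference using `coroutine.send(None)` rather than the library's `coroutine_as_generator` wrapper. `suspend`, `work`, `quick`, `resolve_once`, and `resolve_loop` are hypothetical names. Driving a coroutine this way only works when nothing it awaits needs a real event loop: awaiting an `asyncio` future this way raises a `RuntimeError`.

```python
import types


@types.coroutine
def suspend():
    # A bare yield suspends the awaiting coroutine once, without an event loop
    yield


async def work() -> int:
    await suspend()
    await suspend()
    return 42


async def quick() -> int:
    return 7


def resolve_once(coro):
    # Like the 2.0.2 helper: advance a single time
    try:
        coro.send(None)
    except StopIteration as ex:
        return ex.value
    coro.close()  # still suspended, so the result is never produced
    return None


def resolve_loop(coro):
    # Like the 2.1.0 helper: keep advancing until the coroutine returns
    try:
        while True:
            coro.send(None)
    except StopIteration as ex:
        return ex.value


print(resolve_once(work()), resolve_loop(work()))  # None 42
```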
@@ -1,6 +1,6 @@
  [project]
  name = "checkpointer"
- version = "2.0.2"
+ version = "2.1.0"
  requires-python = ">=3.12"
  dependencies = [
  "relib",
@@ -3,7 +3,7 @@ requires-python = ">=3.12"
 
  [[package]]
  name = "checkpointer"
- version = "2.0.2"
+ version = "2.1.0"
  source = { editable = "." }
  dependencies = [
  { name = "relib" },
@@ -1,46 +0,0 @@
- import inspect
- import relib.hashing as hashing
- from collections.abc import Callable
- from types import FunctionType, CodeType
- from pathlib import Path
- from .utils import unwrap_fn
-
- cwd = Path.cwd()
-
- def get_fn_path(fn: Callable) -> Path:
- return Path(inspect.getfile(fn)).resolve()
-
- def get_function_body(fn: Callable) -> str:
- # TODO: Strip comments
- lines = inspect.getsourcelines(fn)[0]
- lines = [line.rstrip() for line in lines]
- lines = [line for line in lines if line]
- return "\n".join(lines)
-
- def get_code_children(code: CodeType) -> list[str]:
- consts = [const for const in code.co_consts if isinstance(const, CodeType)]
- children = [child for const in consts for child in get_code_children(const)]
- return list(code.co_names) + children
-
- def is_user_fn(candidate_fn, cleared_fns: set[Callable]) -> bool:
- return isinstance(candidate_fn, FunctionType) \
- and candidate_fn not in cleared_fns \
- and cwd in get_fn_path(candidate_fn).parents
-
- def append_fn_children(cleared_fns: set[Callable], fn: Callable) -> None:
- code_children = get_code_children(fn.__code__)
- fn_children = [unwrap_fn(fn.__globals__.get(co_name, None)) for co_name in code_children]
- fn_children = [child for child in fn_children if is_user_fn(child, cleared_fns)]
- cleared_fns.update(fn_children)
- for child_fn in fn_children:
- append_fn_children(cleared_fns, child_fn)
-
- def get_fn_children(fn: Callable) -> list[Callable]:
- cleared_fns: set[Callable] = set()
- append_fn_children(cleared_fns, fn)
- return sorted(cleared_fns, key=lambda fn: fn.__name__)
-
- def get_function_hash(fn: Callable) -> str:
- fns = [fn] + get_fn_children(fn)
- fn_bodies = list(map(get_function_body, fns))
- return hashing.hash(fn_bodies)
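The removed `get_function_body` only stripped trailing whitespace and blank lines, and it carried a `# TODO: Strip comments`. Its 2.1.0 replacement, `get_fn_body`, tokenizes the source instead and drops `COMMENT` and `NL` tokens. As a result, comment edits, blank lines, and spacing between tokens no longer alter a function's hash, while real code changes still do.

The sketch below applies both approaches to source strings rather than calling `inspect.getsourcelines`. `old_body`, `normalized_body`, and the sample sources are hypothetical. The `\0` separator and the ignored token types mirror `get_fn_body`.

```python
import tokenize
from io import StringIO


def old_body(source: str) -> str:
    # 2.0.2 behaviour: strip trailing whitespace, drop empty lines, keep comments
    lines = [line.rstrip() for line in source.splitlines()]
    return "\n".join(line for line in lines if line)


def normalized_body(source: str) -> str:
    # 2.1.0 behaviour: drop comments and non-logical newlines (blank or
    # comment-only lines); spacing between tokens never appears in token.string
    tokens = tokenize.generate_tokens(StringIO(source).readline)
    ignore_types = (tokenize.COMMENT, tokenize.NL)
    return "".join("\0" + tok.string for tok in tokens if tok.type not in ignore_types)


original = "def f(x):\n    return x + 1\n"
commented = "def f( x ):  # adds one\n    # body follows\n\n    return x + 1\n"
changed = "def f(x):\n    return x + 2\n"

print(old_body(original) == old_body(commented))                # False
print(normalized_body(original) == normalized_body(commented))  # True
print(normalized_body(original) == normalized_body(changed))    # False
```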
@@ -1,17 +0,0 @@
- import types
-
- def unwrap_fn[T](fn: T) -> T:
- while hasattr(fn, "__wrapped__"):
- fn = getattr(fn, "__wrapped__")
- return fn
-
- @types.coroutine
- def coroutine_as_generator(coroutine):
- val = yield from coroutine
- return val
-
- def sync_resolve_coroutine(coroutine):
- try:
- next(coroutine_as_generator(coroutine))
- except StopIteration as ex:
- return ex.value
File without changes
File without changes