checkpointer 2.11.2__py3-none-any.whl → 2.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpointer/__init__.py +4 -4
- checkpointer/checkpoint.py +125 -107
- checkpointer/fn_ident.py +129 -71
- checkpointer/fn_string.py +77 -0
- checkpointer/import_mappings.py +47 -0
- checkpointer/object_hash.py +37 -29
- checkpointer/storages/__init__.py +6 -4
- checkpointer/storages/storage.py +5 -4
- checkpointer/types.py +30 -2
- checkpointer/utils.py +84 -54
- checkpointer-2.13.0.dist-info/METADATA +260 -0
- checkpointer-2.13.0.dist-info/RECORD +18 -0
- checkpointer-2.13.0.dist-info/licenses/ATTRIBUTION.md +33 -0
- {checkpointer-2.11.2.dist-info → checkpointer-2.13.0.dist-info}/licenses/LICENSE +2 -0
- checkpointer/test_checkpointer.py +0 -168
- checkpointer-2.11.2.dist-info/METADATA +0 -240
- checkpointer-2.11.2.dist-info/RECORD +0 -16
- {checkpointer-2.11.2.dist-info → checkpointer-2.13.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,77 @@
|
|
1
|
+
import ast
|
2
|
+
import sys
|
3
|
+
from inspect import getsource
|
4
|
+
from textwrap import dedent
|
5
|
+
from typing import Callable
|
6
|
+
from .utils import drop_none, get_at
|
7
|
+
|
8
|
+
def get_decorator_path(node: ast.AST) -> tuple[str, ...]:
|
9
|
+
if isinstance(node, ast.Call):
|
10
|
+
return get_decorator_path(node.func)
|
11
|
+
elif isinstance(node, ast.Attribute):
|
12
|
+
return get_decorator_path(node.value) + (node.attr,)
|
13
|
+
elif isinstance(node, ast.Name):
|
14
|
+
return (node.id,)
|
15
|
+
else:
|
16
|
+
return ()
|
17
|
+
|
18
|
+
def is_empty_expression(node: ast.AST) -> bool:
    """Return True for a statement consisting only of a constant expression.

    This is how docstrings appear in the AST (``Expr(Constant(...))``); other
    inert constant statements such as ``...`` or ``0`` match as well.
    """
    if not isinstance(node, ast.Expr):
        return False
    return isinstance(node.value, ast.Constant)
|
21
|
+
|
22
|
+
class CleanFunctionTransform(ast.NodeTransformer):
    """Rewrite function definitions into a compact nested-list AST for fingerprinting.

    Every ``FunctionDef`` / ``AsyncFunctionDef`` (the outermost one and any
    nested ones) is replaced by
    ``ast.List([Constant(header), List(decorators), List(body)])`` where:

    - ``header`` is ``"<node type> [<name>] [<args>]"``. The name is omitted for
      the outermost definition. ``<args>`` is a comma-separated list of
      ``"<kind>:<arg name>"`` with kind 0 = positional, 1 = ``*args``,
      2 = keyword-only (sorted by name), 3 = ``**kwargs``. Annotations and
      default values are not part of the output.
    - decorators that resolve to a ``Checkpointer`` instance via ``fn_globals``
      are dropped.
    - constant expression statements (docstrings, ``...``) are dropped from the
      body; the remaining statements are visited recursively.
    """

    def __init__(self, fn_globals: dict):
        # Namespace used to resolve decorator names like ``checkpoint`` or ``mod.checkpoint``.
        self.fn_globals = fn_globals
        # True until the outermost function definition has been visited.
        self.is_root = True

    def is_checkpointer(self, node: ast.AST) -> bool:
        """Return True if the decorator expression ``node`` resolves to a ``Checkpointer``."""
        # Local import, presumably to avoid an import cycle with .checkpoint.
        from .checkpoint import Checkpointer

        target = get_at(self.fn_globals, *get_decorator_path(node))
        return isinstance(target, Checkpointer)

    def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef):
        spec = node.args
        arg_groups = (
            [*spec.posonlyargs, *spec.args],
            drop_none([spec.vararg]),
            sorted(spec.kwonlyargs, key=lambda a: a.arg),
            drop_none([spec.kwarg]),
        )
        arg_summary = ",".join(
            f"{kind}:{arg.arg}"
            for kind, group in enumerate(arg_groups)
            for arg in group
        )

        header_parts = [type(node).__name__]
        if not self.is_root:
            header_parts.append(node.name)
        if arg_summary:
            header_parts.append(arg_summary)

        # Every definition visited from here on is nested, so it keeps its name.
        self.is_root = False

        decorators = [dec for dec in node.decorator_list if not self.is_checkpointer(dec)]
        body = [self.visit(stmt) for stmt in node.body if not is_empty_expression(stmt)]
        return ast.List(
            [
                ast.Constant(" ".join(header_parts)),
                ast.List(decorators, ast.Load()),
                ast.List(body, ast.Load()),
            ],
            ast.Load(),
        )

    def visit_AsyncFunctionDef(self, node):
        return self.visit_FunctionDef(node)
|
53
|
+
|
54
|
+
def get_fn_aststr(fn: Callable) -> str:
    """Return a normalized AST dump of ``fn``'s source, used as a code fingerprint.

    - Regular functions are parsed and normalized by ``CleanFunctionTransform``.
    - Lambdas are reduced to the first ``ast.Lambda`` node on their source lines.
    - If the retrieved source lines do not parse (e.g. a lambda embedded in a
      multi-line expression), the stripped raw source text is returned.
    - If no source can be retrieved at all, an empty string is returned.
    """
    try:
        source = getsource(fn)
    except (OSError, TypeError):
        # OSError: source file unavailable (REPL, exec'd code, missing file).
        # TypeError: the object has no Python source (built-ins, C extensions).
        return ""
    try:
        tree = ast.parse(dedent(source), mode="exec").body[0]
    except SyntaxError:
        # getsource() returns whole lines; for a lambda inside a larger
        # expression those lines may not form a complete statement.
        return source.strip()

    if fn.__name__ != "<lambda>":
        tree = CleanFunctionTransform(fn.__globals__).visit(tree)
    else:
        # NOTE: if several lambdas share the same source line(s), the first one wins.
        tree = next((node for node in ast.walk(tree) if isinstance(node, ast.Lambda)), tree)

    if sys.version_info >= (3, 13):
        # 3.13 omits empty fields from ast.dump by default; show_empty=True keeps
        # the output identical to earlier versions.
        return ast.dump(tree, annotate_fields=False, show_empty=True)
    return ast.dump(tree, annotate_fields=False)
|
@@ -0,0 +1,47 @@
|
|
1
|
+
import ast
|
2
|
+
import inspect
|
3
|
+
import sys
|
4
|
+
from types import ModuleType
|
5
|
+
from typing import Iterable, Type
|
6
|
+
from .utils import cwd, get_file, is_user_file
|
7
|
+
|
8
|
+
# Where an imported name comes from: (module name, attribute name). The attribute
# is None for ``import x`` / ``import x as y``, which binds the module itself.
ImportTarget = tuple[str, str | None]
|
9
|
+
|
10
|
+
# Memo for get_import_mappings, keyed by (module.__name__, id(module)); entries are never evicted.
cache: dict[tuple[str, int], dict[str, ImportTarget]] = {}
|
11
|
+
|
12
|
+
def generate_import_mappings(module: ModuleType) -> Iterable[tuple[str, ImportTarget]]:
    """Yield ``(bound_name, (module_name, attr_name))`` for each import in ``module``'s source.

    Only modules whose file passes ``is_user_file`` are scanned. Any other module
    yields nothing, including built-in modules without a source file and modules
    whose source cannot be read. The whole AST is walked, so imports nested in
    functions or ``if`` blocks are included.

    ``import a.b as c`` yields ``("c", ("a.b", None))`` and ``from m import x as y``
    yields ``("y", ("m", "x"))``. Relative ``from`` imports are turned into an
    absolute module name.
    """
    try:
        mod_path = get_file(module)
    except TypeError:
        # inspect.getfile() raises TypeError for built-in modules such as ``sys``.
        return
    if not is_user_file(mod_path):
        return
    try:
        source = inspect.getsource(module)
    except OSError:
        return
    mod_parts = list(mod_path.with_suffix("").relative_to(cwd).parts)
    package = getattr(module, "__package__", None)
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                yield (alias.asname or alias.name, (alias.name, None))
        elif isinstance(node, ast.ImportFrom):
            target_mod = node.module or ""
            if node.level > 0:
                target_mod = _resolve_relative_import(mod_parts, package, node.level, target_mod)
            for alias in node.names:
                yield (alias.asname or alias.name, (target_mod, alias.name))


def _resolve_relative_import(mod_parts: list[str], package: str | None, level: int, target_mod: str) -> str:
    """Turn ``from <level dots><target_mod> import ...`` into an absolute module name.

    The importing module's ``__package__`` is preferred because it is the name
    Python itself uses, so the result matches ``sys.modules`` keys even when the
    package root is not ``cwd`` (e.g. a ``src/`` layout). If ``__package__`` is
    unset or cannot resolve the import, fall back to the file path relative to
    ``cwd`` (``mod_parts``).
    """
    if package:
        from importlib.util import resolve_name

        try:
            return resolve_name("." * level + target_mod, package)
        except ImportError:
            # e.g. "attempted relative import beyond top-level package": use the fallback.
            pass
    target_parts = target_mod.split(".") if target_mod else []
    return ".".join(mod_parts[:-level] + target_parts)
|
31
|
+
|
32
|
+
def get_import_mappings(module: ModuleType) -> dict[str, ImportTarget]:
    """Return the memoized ``{bound_name: ImportTarget}`` mapping for ``module``.

    Results are cached per module object. Empty mappings are cached as well, so
    non-user modules and modules without imports are inspected only once.
    """
    cache_key = (module.__name__, id(module))
    cached = cache.get(cache_key)
    if cached is None:
        # An empty dict is a valid cached result; only a missing entry triggers work.
        cached = cache.setdefault(cache_key, dict(generate_import_mappings(module)))
    return cached
|
38
|
+
|
39
|
+
def resolve_annotation(module: ModuleType, attr_name: str | None) -> Type | None:
|
40
|
+
if not attr_name:
|
41
|
+
return None
|
42
|
+
if anno := module.__annotations__.get(attr_name):
|
43
|
+
return anno
|
44
|
+
if next_pair := get_import_mappings(module).get(attr_name):
|
45
|
+
next_module_name, next_attr_name = next_pair
|
46
|
+
if next_module := sys.modules.get(next_module_name):
|
47
|
+
return resolve_annotation(next_module, next_attr_name)
|
checkpointer/object_hash.py
CHANGED
@@ -5,32 +5,35 @@ import io
|
|
5
5
|
import re
|
6
6
|
import sys
|
7
7
|
import tokenize
|
8
|
+
from collections import OrderedDict
|
8
9
|
from collections.abc import Iterable
|
9
10
|
from contextlib import nullcontext, suppress
|
10
11
|
from decimal import Decimal
|
11
12
|
from io import StringIO
|
12
13
|
from itertools import chain
|
13
14
|
from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
|
14
|
-
from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
|
15
|
-
from typing import Callable, TypeVar
|
15
|
+
from types import BuiltinFunctionType, FunctionType, GeneratorType, MappingProxyType, MethodType, ModuleType, UnionType
|
16
|
+
from typing import Callable, Self, TypeVar
|
16
17
|
from .utils import ContextVar
|
17
18
|
|
18
19
|
# Optional dependencies: stay None unless the guarded imports further down succeed.
np, torch = None, None
|
19
20
|
|
20
|
-
with suppress(Exception):
|
21
|
-
import numpy as np
|
22
|
-
with suppress(Exception):
|
23
|
-
import torch
|
24
|
-
|
25
21
|
class _Never:
|
26
22
|
def __getattribute__(self, _: str):
|
27
23
|
pass
|
28
24
|
|
25
|
+
with suppress(Exception):
|
26
|
+
import numpy as np
|
27
|
+
with suppress(Exception):
|
28
|
+
import torch
|
29
29
|
if sys.version_info >= (3, 12):
|
30
30
|
from typing import TypeAliasType
|
31
31
|
else:
|
32
32
|
TypeAliasType = _Never
|
33
33
|
|
34
|
+
# Concatenate an iterable of iterables one level deep (used to flatten dict item pairs).
flatten = chain.from_iterable
|
35
|
+
# Shared no-op context manager; ObjectHash.update uses it when ``tolerable`` is None.
nc = nullcontext()
|
36
|
+
|
34
37
|
def encode_type(t: type | FunctionType) -> str:
|
35
38
|
return f"{t.__module__}:{t.__qualname__}"
|
36
39
|
|
@@ -43,14 +46,14 @@ class ObjectHashError(Exception):
|
|
43
46
|
self.obj = obj
|
44
47
|
|
45
48
|
class ObjectHash:
|
46
|
-
def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64,
|
49
|
+
def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerable=False) -> None:
|
47
50
|
self.hash = hashlib.blake2b(digest_size=digest_size)
|
48
51
|
self.current: dict[int, int] = {}
|
49
|
-
self.
|
52
|
+
self.tolerable = ContextVar(tolerable)
|
50
53
|
self.update(iter=chain(objs, iter))
|
51
54
|
|
52
55
|
def copy(self) -> "ObjectHash":
|
53
|
-
new = ObjectHash(
|
56
|
+
new = ObjectHash(tolerable=self.tolerable.value)
|
54
57
|
new.hash = self.hash.copy()
|
55
58
|
return new
|
56
59
|
|
@@ -63,26 +66,29 @@ class ObjectHash:
|
|
63
66
|
return isinstance(value, ObjectHash) and str(self) == str(value)
|
64
67
|
|
65
68
|
def nested_hash(self, *objs: object) -> str:
|
66
|
-
return ObjectHash(iter=objs,
|
69
|
+
return ObjectHash(iter=objs, tolerable=self.tolerable.value).hexdigest()
|
67
70
|
|
68
|
-
def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) ->
|
71
|
+
def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> Self:
|
69
72
|
for d in chain(data, iter):
|
70
73
|
self.hash.update(d)
|
71
74
|
return self
|
72
75
|
|
73
|
-
def write_text(self, *data: str, iter: Iterable[str] = ()) ->
|
76
|
+
def write_text(self, *data: str, iter: Iterable[str] = ()) -> Self:
|
74
77
|
return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))
|
75
78
|
|
76
|
-
def header(self, *args: object) ->
|
79
|
+
def header(self, *args: object) -> Self:
|
77
80
|
return self.write_bytes(":".join(map(str, args)).encode())
|
78
81
|
|
79
|
-
def update(self, *objs: object, iter: Iterable[object] = (),
|
80
|
-
with
|
82
|
+
def update(self, *objs: object, iter: Iterable[object] = (), tolerable: bool | None=None, header: str | None = None) -> Self:
|
83
|
+
with nc if tolerable is None else self.tolerable.set(tolerable):
|
81
84
|
for obj in chain(objs, iter):
|
85
|
+
if header is not None:
|
86
|
+
self.write_bytes(header.encode())
|
87
|
+
header = None
|
82
88
|
try:
|
83
89
|
self._update_one(obj)
|
84
90
|
except Exception as ex:
|
85
|
-
if self.
|
91
|
+
if self.tolerable.value:
|
86
92
|
self.header("error").update(type(ex))
|
87
93
|
else:
|
88
94
|
raise ObjectHashError(obj, ex) from ex
|
@@ -102,11 +108,11 @@ class ObjectHash:
|
|
102
108
|
|
103
109
|
case set() | frozenset():
|
104
110
|
try:
|
105
|
-
items = sorted(obj)
|
106
111
|
header = "set"
|
112
|
+
items = sorted(obj)
|
107
113
|
except:
|
108
|
-
items = sorted(map(self.nested_hash, obj))
|
109
114
|
header = "set-unsortable"
|
115
|
+
items = sorted(map(self.nested_hash, obj))
|
110
116
|
self.header(header, encode_type_of(obj), len(obj)).update(iter=items)
|
111
117
|
|
112
118
|
case TypeVar():
|
@@ -167,23 +173,25 @@ class ObjectHash:
|
|
167
173
|
match obj:
|
168
174
|
case list() | tuple():
|
169
175
|
self.header("list", encode_type_of(obj), len(obj)).update(iter=obj)
|
170
|
-
case dict():
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
176
|
+
case dict() | MappingProxyType():
|
177
|
+
header = "dict"
|
178
|
+
items = obj.items()
|
179
|
+
if not isinstance(obj, OrderedDict):
|
180
|
+
try:
|
181
|
+
items = sorted(items)
|
182
|
+
except:
|
183
|
+
header = "dict-unsortable"
|
184
|
+
items = sorted((self.nested_hash(key), val) for key, val in items)
|
185
|
+
self.header(header, encode_type_of(obj), len(obj)).update(iter=flatten(items))
|
178
186
|
case _:
|
179
187
|
self._update_object(obj)
|
180
188
|
finally:
|
181
189
|
del self.current[id(obj)]
|
182
190
|
|
183
|
-
def _update_iterator(self, obj: Iterable) ->
|
191
|
+
def _update_iterator(self, obj: Iterable) -> Self:
|
184
192
|
return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")
|
185
193
|
|
186
|
-
def _update_object(self, obj: object) ->
|
194
|
+
def _update_object(self, obj: object) -> Self:
|
187
195
|
self.header("instance", encode_type_of(obj))
|
188
196
|
get_hash = hasattr(obj, "__objecthash__") and getattr(obj, "__objecthash__")
|
189
197
|
if callable(get_hash):
|
@@ -1,9 +1,11 @@
|
|
1
|
-
from typing import Type
|
2
|
-
from .storage import Storage
|
3
|
-
from .pickle_storage import PickleStorage
|
1
|
+
from typing import Literal, Type
|
4
2
|
from .memory_storage import MemoryStorage
|
3
|
+
from .pickle_storage import PickleStorage
|
4
|
+
from .storage import Storage
|
5
|
+
|
6
|
+
# Names of the built-in storage backends (the keys of STORAGE_MAP).
StorageType = Literal["pickle", "memory"]
|
5
7
|
|
6
|
-
STORAGE_MAP: dict[
|
8
|
+
# Maps each StorageType name to the Storage subclass that implements it.
STORAGE_MAP: dict[StorageType, Type[Storage]] = dict(
    pickle=PickleStorage,
    memory=MemoryStorage,
)
|
checkpointer/storages/storage.py
CHANGED
@@ -8,17 +8,18 @@ if TYPE_CHECKING:
|
|
8
8
|
|
9
9
|
class Storage:
|
10
10
|
checkpointer: Checkpointer
|
11
|
-
|
11
|
+
ident: CachedFunction
|
12
12
|
|
13
13
|
def __init__(self, cached_fn: CachedFunction):
|
14
|
-
self.checkpointer = cached_fn.checkpointer
|
14
|
+
self.checkpointer = cached_fn.ident.checkpointer
|
15
15
|
self.cached_fn = cached_fn
|
16
16
|
|
17
17
|
def fn_id(self) -> str:
|
18
|
-
|
18
|
+
ident = self.cached_fn.ident
|
19
|
+
return f"{ident.fn_dir}/{ident.fn_hash}"
|
19
20
|
|
20
21
|
def fn_dir(self) -> Path:
|
21
|
-
return self.checkpointer.
|
22
|
+
return self.checkpointer.directory / self.fn_id()
|
22
23
|
|
23
24
|
def store(self, call_hash: str, data: Any) -> Any: ...
|
24
25
|
|
checkpointer/types.py
CHANGED
@@ -1,4 +1,7 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import (
|
2
|
+
Annotated, Callable, Coroutine, Generic,
|
3
|
+
ParamSpec, TypeVar, get_args, get_origin,
|
4
|
+
)
|
2
5
|
|
3
6
|
# Type variable for callables (bound to Callable); parameterizes HashBy.
Fn = TypeVar("Fn", bound=Callable)
|
4
7
|
# Parameter specification variable for typing callables that preserve a signature.
P = ParamSpec("P")
|
@@ -9,7 +12,32 @@ T = TypeVar("T")
|
|
9
12
|
class HashBy(Generic[Fn]):
    """Marker for ``Annotated`` metadata: ``HashBy[fn]`` carries a hashing function.

    Only the subscripted form is inspected (``hash_by_from_annotation`` extracts
    ``fn``); ``NoHash`` pairs it with ``to_none``. Presumably the value is hashed
    as ``fn(value)`` instead of the value itself — confirm in the hashing code.
    """
|
11
14
|
|
12
|
-
|
15
|
+
class Captured:
    """Marker placed in ``Annotated`` metadata by ``CaptureMe``; detected by ``is_capture_me``."""
|
17
|
+
|
18
|
+
class CapturedOnce:
    """Marker placed in ``Annotated`` metadata by ``CaptureMeOnce``; detected by ``is_capture_me_once``."""
|
20
|
+
|
21
|
+
def to_none(_):
    """Ignore the argument and return ``None`` (the ``HashBy`` function behind ``NoHash``)."""
    result = None
    return result
|
23
|
+
|
24
|
+
def get_annotated_args(anno: object) -> tuple[object, ...]:
    """Return ``get_args(anno)`` for an ``Annotated[...]`` type, else an empty tuple.

    For ``Annotated[T, m1, m2]`` this is ``(T, m1, m2)``: the base type first,
    followed by the metadata.
    """
    if get_origin(anno) is not Annotated:
        return ()
    return get_args(anno)
|
26
|
+
|
27
|
+
def hash_by_from_annotation(anno: object) -> Callable[[object], object] | None:
    """Return ``fn`` from the first ``HashBy[fn]`` among ``anno``'s ``Annotated`` arguments.

    Returns ``None`` if ``anno`` is not ``Annotated`` or carries no ``HashBy[...]``.
    """
    hash_fns = (
        get_args(meta)[0]
        for meta in get_annotated_args(anno)
        if get_origin(meta) is HashBy
    )
    return next(hash_fns, None)
|
31
|
+
|
32
|
+
def is_capture_me(anno: object) -> bool:
    """Return True if ``anno`` is ``Annotated[...]`` and ``Captured`` is among its arguments."""
    annotated_args = get_annotated_args(anno)
    return Captured in annotated_args
|
34
|
+
|
35
|
+
def is_capture_me_once(anno: object) -> bool:
    """Return True if ``anno`` is ``Annotated[...]`` and ``CapturedOnce`` is among its arguments."""
    annotated_args = get_annotated_args(anno)
    return CapturedOnce in annotated_args
|
37
|
+
|
38
|
+
# ``NoHash[T]``: attaches HashBy[to_none], so the value's hash function maps it to
# None (presumably excluding it from the hash — confirm in the hashing code).
NoHash = Annotated[T, HashBy[to_none]]
|
39
|
+
# ``CaptureMe[T]``: attaches the ``Captured`` marker, detected by ``is_capture_me``.
CaptureMe = Annotated[T, Captured]
|
40
|
+
# ``CaptureMeOnce[T]``: attaches the ``CapturedOnce`` marker, detected by ``is_capture_me_once``.
CaptureMeOnce = Annotated[T, CapturedOnce]
|
13
41
|
# Alias for a coroutine whose result type is R.
Coro = Coroutine[object, object, R]
|
14
42
|
|
15
43
|
class AwaitableValue(Generic[T]):
|
checkpointer/utils.py
CHANGED
@@ -1,23 +1,93 @@
|
|
1
|
-
from
|
2
|
-
|
3
|
-
from
|
1
|
+
from __future__ import annotations
|
2
|
+
import inspect
|
3
|
+
from contextlib import contextmanager, suppress
|
4
|
+
from itertools import islice
|
5
|
+
from pathlib import Path
|
6
|
+
from types import FunctionType, MethodType, ModuleType
|
7
|
+
from typing import Callable, Generic, Iterable, Self, Type, TypeGuard
|
8
|
+
from .types import T
|
4
9
|
|
5
|
-
|
6
|
-
|
10
|
+
# Resolved working directory, captured once at import time; is_user_file treats
# files below it as user code. Later os.chdir() calls do not change it.
cwd = Path.cwd().resolve()
|
11
|
+
|
12
|
+
def is_class(obj) -> TypeGuard[Type]:
|
13
|
+
return isinstance(obj, type)
|
14
|
+
|
15
|
+
def get_file(obj: Callable | ModuleType) -> Path:
    """Return the absolute, symlink-resolved path of the file that defines ``obj``.

    ``inspect.getfile`` raises ``TypeError`` for built-in modules, classes and functions.
    """
    filename = inspect.getfile(obj)
    return Path(filename).resolve()
|
17
|
+
|
18
|
+
def is_user_file(path: Path) -> bool:
    """Return True if ``path`` belongs to the user's project rather than a dependency.

    A file counts as user code when it lies below ``cwd`` and none of its
    directories is a virtual environment or package install location. Besides
    ``.venv``, ``site-packages`` and ``dist-packages`` are excluded. This keeps
    environments with other names (``venv``, ``env``, conda envs) and
    ``pip install --target`` directories inside the project from being
    treated as user code.
    """
    excluded_dirs = {".venv", "site-packages", "dist-packages"}
    return cwd in path.parents and excluded_dirs.isdisjoint(path.parts)
|
20
|
+
|
21
|
+
def is_user_fn(obj) -> TypeGuard[Callable]:
    """Return True for a Python function or bound method whose defining file is a user file."""
    if not isinstance(obj, (FunctionType, MethodType)):
        return False
    return is_user_file(get_file(obj))
|
7
23
|
|
8
24
|
def get_cell_contents(fn: Callable) -> Iterable[tuple[str, object]]:
    """Yield ``(name, value)`` for each closure variable of ``fn``.

    Empty cells (variables not yet assigned in the enclosing scope) raise
    ``ValueError`` on access and are skipped.
    """
    cells = fn.__closure__ or []
    for name, cell in zip(fn.__code__.co_freevars, cells):
        try:
            yield name, cell.cell_contents
        except ValueError:
            pass
|
14
28
|
|
15
|
-
def
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
29
|
+
def drop_none(iterable: Iterable[T | None]) -> list[T]:
    """Return the items of ``iterable`` that are not ``None``, preserving order.

    Other falsy values (``0``, ``""``, ``False``) are kept.
    """
    kept: list[T] = []
    for item in iterable:
        if item is not None:
            kept.append(item)
    return kept
|
31
|
+
|
32
|
+
def distinct(seq: Iterable[T]) -> list[T]:
    """Return the unique items of ``seq`` in first-seen order (items must be hashable)."""
    seen: dict[T, None] = {}
    for item in seq:
        seen.setdefault(item, None)
    return list(seen)
|
34
|
+
|
35
|
+
def takewhile(iter: Iterable[tuple[bool, T]]) -> Iterable[T]:
    """Yield values from ``(condition, value)`` pairs until the first falsy condition.

    Unlike ``itertools.takewhile``, each condition comes paired with its value.
    The pair whose condition fails is consumed but not yielded.
    """
    for keep_going, item in iter:
        if keep_going:
            yield item
        else:
            break
|
40
|
+
|
41
|
+
class seekable(Generic[T]):
|
42
|
+
def __init__(self, iterable: Iterable[T]):
|
43
|
+
self.index = 0
|
44
|
+
self.source = iter(iterable)
|
45
|
+
self.sink: list[T] = []
|
46
|
+
|
47
|
+
def __iter__(self):
|
48
|
+
return self
|
49
|
+
|
50
|
+
def __next__(self) -> T:
|
51
|
+
if len(self.sink) > self.index:
|
52
|
+
item = self.sink[self.index]
|
53
|
+
else:
|
54
|
+
item = next(self.source)
|
55
|
+
self.sink.append(item)
|
56
|
+
self.index += 1
|
57
|
+
return item
|
58
|
+
|
59
|
+
def __bool__(self):
|
60
|
+
return bool(self.lookahead(1))
|
61
|
+
|
62
|
+
def seek(self, index: int) -> Self:
|
63
|
+
remainder = index - len(self.sink)
|
64
|
+
if remainder > 0:
|
65
|
+
next(islice(self, remainder, remainder), None)
|
66
|
+
self.index = max(0, min(index, len(self.sink)))
|
67
|
+
return self
|
68
|
+
|
69
|
+
def step(self, count: int) -> Self:
|
70
|
+
return self.seek(self.index + count)
|
71
|
+
|
72
|
+
@contextmanager
|
73
|
+
def freeze(self):
|
74
|
+
initial_index = self.index
|
75
|
+
try:
|
76
|
+
yield
|
77
|
+
finally:
|
78
|
+
self.seek(initial_index)
|
79
|
+
|
80
|
+
def lookahead(self, count: int) -> list[T]:
|
81
|
+
with self.freeze():
|
82
|
+
return list(islice(self, count))
|
83
|
+
|
84
|
+
def get_at(obj: object, *attrs: str) -> object:
    """Follow a chain of keys/attributes starting at ``obj``.

    At each step an exact ``dict`` (not a subclass) is indexed with ``.get``;
    any other object is read with ``getattr``. A missing step gives ``None``,
    and the remaining names are then looked up on ``None``.
    """
    current = obj
    for name in attrs:
        current = current.get(name) if type(current) is dict else getattr(current, name, None)
    return current
|
21
91
|
|
22
92
|
class AttrDict(dict):
|
23
93
|
def __init__(self, *args, **kwargs):
|
@@ -30,23 +100,6 @@ class AttrDict(dict):
|
|
30
100
|
def __setattr__(self, name: str, value: object):
|
31
101
|
super().__setattr__(name, value)
|
32
102
|
|
33
|
-
def set(self, d: dict) -> "AttrDict":
|
34
|
-
if not d:
|
35
|
-
return self
|
36
|
-
return AttrDict({**self, **d})
|
37
|
-
|
38
|
-
def delete(self, *attrs: str) -> "AttrDict":
|
39
|
-
d = AttrDict(self)
|
40
|
-
for attr in attrs:
|
41
|
-
del d[attr]
|
42
|
-
return d
|
43
|
-
|
44
|
-
def get_at(self, attrs: tuple[str, ...]) -> object:
|
45
|
-
d = self
|
46
|
-
for attr in attrs:
|
47
|
-
d = getattr(d, attr, None)
|
48
|
-
return d
|
49
|
-
|
50
103
|
class ContextVar(Generic[T]):
|
51
104
|
def __init__(self, value: T):
|
52
105
|
self.value = value
|
@@ -58,26 +111,3 @@ class ContextVar(Generic[T]):
|
|
58
111
|
yield
|
59
112
|
finally:
|
60
113
|
self.value = old
|
61
|
-
|
62
|
-
class iterate_and_upcoming(Generic[T]):
|
63
|
-
def __init__(self, it: Iterable[T]) -> None:
|
64
|
-
self.it = iter(it)
|
65
|
-
self.previous: tuple[()] | tuple[T] = ()
|
66
|
-
self.tracked = self._tracked_iter()
|
67
|
-
|
68
|
-
def __iter__(self):
|
69
|
-
return self
|
70
|
-
|
71
|
-
def __next__(self) -> tuple[T, Iterable[T]]:
|
72
|
-
try:
|
73
|
-
item = self.previous[0] if self.previous else next(self.it)
|
74
|
-
self.previous = ()
|
75
|
-
return item, self.tracked
|
76
|
-
except StopIteration:
|
77
|
-
self.tracked.close()
|
78
|
-
raise
|
79
|
-
|
80
|
-
def _tracked_iter(self):
|
81
|
-
for x in self.it:
|
82
|
-
self.previous = (x,)
|
83
|
-
yield x
|