checkpointer 2.11.2__py3-none-any.whl → 2.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,77 @@
1
+ import ast
2
+ import sys
3
+ from inspect import getsource
4
+ from textwrap import dedent
5
+ from typing import Callable
6
+ from .utils import drop_none, get_at
7
+
8
def get_decorator_path(node: ast.AST) -> tuple[str, ...]:
    """Return the dotted-name path of a decorator expression as a tuple.

    Calls are unwrapped to their callee, each attribute access contributes its
    attribute name, and a bare name contributes its identifier. Any other node
    type ends the walk and contributes nothing.
    """
    # Names are collected innermost-last while walking outward-in, so the
    # list is built in reverse and flipped at the end.
    reversed_parts: list[str] = []
    current = node
    while True:
        if isinstance(current, ast.Call):
            current = current.func
        elif isinstance(current, ast.Attribute):
            reversed_parts.append(current.attr)
            current = current.value
        elif isinstance(current, ast.Name):
            reversed_parts.append(current.id)
            break
        else:
            break
    return tuple(reversed(reversed_parts))
17
+
18
def is_empty_expression(node: ast.AST) -> bool:
    """Tell whether ``node`` is an expression statement holding only a constant.

    Docstrings have exactly this shape, so this predicate is used to filter
    them out.
    """
    if not isinstance(node, ast.Expr):
        return False
    return isinstance(node.value, ast.Constant)
21
+
22
class CleanFunctionTransform(ast.NodeTransformer):
    """Rewrite function definitions into a compact, list-based AST form.

    Every (async) function definition is replaced by an ``ast.List`` holding
    a header constant, the decorators that do not resolve to a
    ``Checkpointer`` instance, and the transformed body with constant-only
    expression statements removed.
    """

    def __init__(self, fn_globals: dict):
        self.fn_globals = fn_globals
        # True until the first function definition has been visited; the
        # root function's name is left out of its header.
        self.is_root = True

    def is_checkpointer(self, node: ast.AST) -> bool:
        """Return True if decorator ``node`` resolves to a ``Checkpointer`` in the function's globals."""
        from .checkpoint import Checkpointer
        decorator_path = get_decorator_path(node)
        resolved = get_at(self.fn_globals, *decorator_path)
        return isinstance(resolved, Checkpointer)

    def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef):
        """Replace a function definition by ``[header, decorators, body]``."""
        signature = node.args
        # Argument groups by kind: positional, *args, keyword-only (sorted
        # by name), **kwargs. The group index becomes the prefix of each name.
        args_by_type = [
            signature.posonlyargs + signature.args,
            drop_none([signature.vararg]),
            sorted(signature.kwonlyargs, key=lambda x: x.arg),
            drop_none([signature.kwarg]),
        ]
        labelled_args = []
        for kind_index, group in enumerate(args_by_type):
            for arg in group:
                labelled_args.append(f"{kind_index}:{arg.arg}")
        arg_kind_names = ",".join(labelled_args)

        fn_type = type(node).__name__
        fn_name = None if self.is_root else node.name
        header = " ".join(drop_none((fn_type, fn_name, arg_kind_names or None)))

        # Nested definitions visited from here on keep their names.
        self.is_root = False

        kept_decorators = [child for child in node.decorator_list if not self.is_checkpointer(child)]
        cleaned_body = [self.visit(child) for child in node.body if not is_empty_expression(child)]
        return ast.List(
            [
                ast.Constant(header),
                ast.List(kept_decorators, ast.Load()),
                ast.List(cleaned_body, ast.Load()),
            ],
            ast.Load(),
        )

    def visit_AsyncFunctionDef(self, node):
        """Handle async definitions exactly like regular ones."""
        return self.visit_FunctionDef(node)
53
+
54
def get_fn_aststr(fn: Callable) -> str:
    """Return a string form of ``fn``'s AST, or a fallback if it cannot be built.

    Returns:
        - ``""`` when no Python source is available for ``fn``;
        - the stripped raw source when that source does not parse on its own;
        - otherwise an ``ast.dump`` (without field names) of the cleaned
          function definition, or of the first lambda node found when ``fn``
          is a lambda.
    """
    try:
        source = getsource(fn)
    except (OSError, TypeError):
        # OSError: the source cannot be retrieved.
        # TypeError: the object is a built-in or otherwise has no Python
        # source; treat it the same way instead of crashing.
        return ""
    try:
        tree = ast.parse(dedent(source), mode="exec").body[0]
    except SyntaxError:
        # lambda functions can cause SyntaxError in ast.parse
        return source.strip()

    if fn.__name__ != "<lambda>":
        tree = CleanFunctionTransform(fn.__globals__).visit(tree)
    else:
        # The source of a lambda is the whole enclosing statement; keep only
        # the first lambda node, or the statement itself if none is found.
        tree = next((node for node in ast.walk(tree) if isinstance(node, ast.Lambda)), tree)

    if sys.version_info >= (3, 13):
        # Python 3.13 omits empty fields by default; show_empty=True is
        # passed so they are included in the dump.
        return ast.dump(tree, annotate_fields=False, show_empty=True)
    return ast.dump(tree, annotate_fields=False)
@@ -0,0 +1,47 @@
1
+ import ast
2
+ import inspect
3
+ import sys
4
+ from types import ModuleType
5
+ from typing import Iterable, Type
6
+ from .utils import cwd, get_file, is_user_file
7
+
8
+ ImportTarget = tuple[str, str | None]
9
+
10
+ cache: dict[tuple[str, int], dict[str, ImportTarget]] = {}
11
+
12
def generate_import_mappings(module: ModuleType) -> Iterable[tuple[str, ImportTarget]]:
    """Yield ``(local_name, (target_module, target_attr))`` for each import in ``module``.

    Nothing is yielded for modules whose file is not a user file. Plain
    ``import x`` statements yield a target attribute of ``None``. Relative
    ``from`` imports are resolved against the module's path relative to
    ``cwd``.
    """
    mod_path = get_file(module)
    if not is_user_file(mod_path):
        return
    mod_parts = list(mod_path.with_suffix("").relative_to(cwd).parts)
    tree = ast.parse(inspect.getsource(module))
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                local_name = alias.asname or alias.name
                yield (local_name, (alias.name, None))
        elif isinstance(node, ast.ImportFrom):
            target_mod = node.module or ""
            if node.level > 0:
                # Drop one trailing path component per leading dot, then
                # append the (possibly empty) module named after the dots.
                relative_parts = target_mod.split(".") if target_mod else []
                target_mod = ".".join(mod_parts[:-node.level] + relative_parts)
            for alias in node.names:
                local_name = alias.asname or alias.name
                yield (local_name, (target_mod, alias.name))
31
+
32
def get_import_mappings(module: ModuleType) -> dict[str, ImportTarget]:
    """Return the import mappings of ``module``, computing them at most once.

    Results are stored in the module-level ``cache`` under the key
    ``(module.__name__, id(module))``.
    """
    cache_key = (module.__name__, id(module))
    cached = cache.get(cache_key)
    # Compare against None rather than truthiness: an empty mapping (module
    # without imports, or a non-user module) is a valid cached result and
    # must not trigger re-parsing on every call.
    if cached is not None:
        return cached
    import_mappings = dict(generate_import_mappings(module))
    return cache.setdefault(cache_key, import_mappings)
38
+
39
+ def resolve_annotation(module: ModuleType, attr_name: str | None) -> Type | None:
40
+ if not attr_name:
41
+ return None
42
+ if anno := module.__annotations__.get(attr_name):
43
+ return anno
44
+ if next_pair := get_import_mappings(module).get(attr_name):
45
+ next_module_name, next_attr_name = next_pair
46
+ if next_module := sys.modules.get(next_module_name):
47
+ return resolve_annotation(next_module, next_attr_name)
@@ -5,32 +5,35 @@ import io
5
5
  import re
6
6
  import sys
7
7
  import tokenize
8
+ from collections import OrderedDict
8
9
  from collections.abc import Iterable
9
10
  from contextlib import nullcontext, suppress
10
11
  from decimal import Decimal
11
12
  from io import StringIO
12
13
  from itertools import chain
13
14
  from pickle import HIGHEST_PROTOCOL as PICKLE_PROTOCOL
14
- from types import BuiltinFunctionType, FunctionType, GeneratorType, MethodType, ModuleType, UnionType
15
- from typing import Callable, TypeVar
15
+ from types import BuiltinFunctionType, FunctionType, GeneratorType, MappingProxyType, MethodType, ModuleType, UnionType
16
+ from typing import Callable, Self, TypeVar
16
17
  from .utils import ContextVar
17
18
 
18
19
  np, torch = None, None
19
20
 
20
- with suppress(Exception):
21
- import numpy as np
22
- with suppress(Exception):
23
- import torch
24
-
25
21
  class _Never:
26
22
  def __getattribute__(self, _: str):
27
23
  pass
28
24
 
25
+ with suppress(Exception):
26
+ import numpy as np
27
+ with suppress(Exception):
28
+ import torch
29
29
  if sys.version_info >= (3, 12):
30
30
  from typing import TypeAliasType
31
31
  else:
32
32
  TypeAliasType = _Never
33
33
 
34
+ flatten = chain.from_iterable
35
+ nc = nullcontext()
36
+
34
37
  def encode_type(t: type | FunctionType) -> str:
35
38
  return f"{t.__module__}:{t.__qualname__}"
36
39
 
@@ -43,14 +46,14 @@ class ObjectHashError(Exception):
43
46
  self.obj = obj
44
47
 
45
48
  class ObjectHash:
46
- def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerate_errors=False) -> None:
49
+ def __init__(self, *objs: object, iter: Iterable[object] = (), digest_size=64, tolerable=False) -> None:
47
50
  self.hash = hashlib.blake2b(digest_size=digest_size)
48
51
  self.current: dict[int, int] = {}
49
- self.tolerate_errors = ContextVar(tolerate_errors)
52
+ self.tolerable = ContextVar(tolerable)
50
53
  self.update(iter=chain(objs, iter))
51
54
 
52
55
  def copy(self) -> "ObjectHash":
53
- new = ObjectHash(tolerate_errors=self.tolerate_errors.value)
56
+ new = ObjectHash(tolerable=self.tolerable.value)
54
57
  new.hash = self.hash.copy()
55
58
  return new
56
59
 
@@ -63,26 +66,29 @@ class ObjectHash:
63
66
  return isinstance(value, ObjectHash) and str(self) == str(value)
64
67
 
65
68
  def nested_hash(self, *objs: object) -> str:
66
- return ObjectHash(iter=objs, tolerate_errors=self.tolerate_errors.value).hexdigest()
69
+ return ObjectHash(iter=objs, tolerable=self.tolerable.value).hexdigest()
67
70
 
68
- def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> "ObjectHash":
71
+ def write_bytes(self, *data: bytes, iter: Iterable[bytes] = ()) -> Self:
69
72
  for d in chain(data, iter):
70
73
  self.hash.update(d)
71
74
  return self
72
75
 
73
- def write_text(self, *data: str, iter: Iterable[str] = ()) -> "ObjectHash":
76
+ def write_text(self, *data: str, iter: Iterable[str] = ()) -> Self:
74
77
  return self.write_bytes(iter=(d.encode() for d in chain(data, iter)))
75
78
 
76
- def header(self, *args: object) -> "ObjectHash":
79
+ def header(self, *args: object) -> Self:
77
80
  return self.write_bytes(":".join(map(str, args)).encode())
78
81
 
79
- def update(self, *objs: object, iter: Iterable[object] = (), tolerate_errors: bool | None=None) -> "ObjectHash":
80
- with nullcontext() if tolerate_errors is None else self.tolerate_errors.set(tolerate_errors):
82
+ def update(self, *objs: object, iter: Iterable[object] = (), tolerable: bool | None=None, header: str | None = None) -> Self:
83
+ with nc if tolerable is None else self.tolerable.set(tolerable):
81
84
  for obj in chain(objs, iter):
85
+ if header is not None:
86
+ self.write_bytes(header.encode())
87
+ header = None
82
88
  try:
83
89
  self._update_one(obj)
84
90
  except Exception as ex:
85
- if self.tolerate_errors.value:
91
+ if self.tolerable.value:
86
92
  self.header("error").update(type(ex))
87
93
  else:
88
94
  raise ObjectHashError(obj, ex) from ex
@@ -102,11 +108,11 @@ class ObjectHash:
102
108
 
103
109
  case set() | frozenset():
104
110
  try:
105
- items = sorted(obj)
106
111
  header = "set"
112
+ items = sorted(obj)
107
113
  except:
108
- items = sorted(map(self.nested_hash, obj))
109
114
  header = "set-unsortable"
115
+ items = sorted(map(self.nested_hash, obj))
110
116
  self.header(header, encode_type_of(obj), len(obj)).update(iter=items)
111
117
 
112
118
  case TypeVar():
@@ -167,23 +173,25 @@ class ObjectHash:
167
173
  match obj:
168
174
  case list() | tuple():
169
175
  self.header("list", encode_type_of(obj), len(obj)).update(iter=obj)
170
- case dict():
171
- try:
172
- items = sorted(obj.items())
173
- header = "dict"
174
- except:
175
- items = sorted((self.nested_hash(key), val) for key, val in obj.items())
176
- header = "dict-unsortable"
177
- self.header(header, encode_type_of(obj), len(obj)).update(iter=chain.from_iterable(items))
176
+ case dict() | MappingProxyType():
177
+ header = "dict"
178
+ items = obj.items()
179
+ if not isinstance(obj, OrderedDict):
180
+ try:
181
+ items = sorted(items)
182
+ except:
183
+ header = "dict-unsortable"
184
+ items = sorted((self.nested_hash(key), val) for key, val in items)
185
+ self.header(header, encode_type_of(obj), len(obj)).update(iter=flatten(items))
178
186
  case _:
179
187
  self._update_object(obj)
180
188
  finally:
181
189
  del self.current[id(obj)]
182
190
 
183
- def _update_iterator(self, obj: Iterable) -> "ObjectHash":
191
+ def _update_iterator(self, obj: Iterable) -> Self:
184
192
  return self.header("iterator", encode_type_of(obj)).update(iter=obj).header("iterator-end")
185
193
 
186
- def _update_object(self, obj: object) -> "ObjectHash":
194
+ def _update_object(self, obj: object) -> Self:
187
195
  self.header("instance", encode_type_of(obj))
188
196
  get_hash = hasattr(obj, "__objecthash__") and getattr(obj, "__objecthash__")
189
197
  if callable(get_hash):
@@ -1,9 +1,11 @@
1
- from typing import Type
2
- from .storage import Storage
3
- from .pickle_storage import PickleStorage
1
+ from typing import Literal, Type
4
2
  from .memory_storage import MemoryStorage
3
+ from .pickle_storage import PickleStorage
4
+ from .storage import Storage
5
+
6
+ StorageType = Literal["pickle", "memory"]
5
7
 
6
- STORAGE_MAP: dict[str, Type[Storage]] = {
8
+ STORAGE_MAP: dict[StorageType, Type[Storage]] = {
7
9
  "pickle": PickleStorage,
8
10
  "memory": MemoryStorage,
9
11
  }
@@ -8,17 +8,18 @@ if TYPE_CHECKING:
8
8
 
9
9
  class Storage:
10
10
  checkpointer: Checkpointer
11
- cached_fn: CachedFunction
11
+ ident: CachedFunction
12
12
 
13
13
  def __init__(self, cached_fn: CachedFunction):
14
- self.checkpointer = cached_fn.checkpointer
14
+ self.checkpointer = cached_fn.ident.checkpointer
15
15
  self.cached_fn = cached_fn
16
16
 
17
17
  def fn_id(self) -> str:
18
- return f"{self.cached_fn.fn_dir}/{self.cached_fn.ident.fn_hash}"
18
+ ident = self.cached_fn.ident
19
+ return f"{ident.fn_dir}/{ident.fn_hash}"
19
20
 
20
21
  def fn_dir(self) -> Path:
21
- return self.checkpointer.root_path / self.fn_id()
22
+ return self.checkpointer.directory / self.fn_id()
22
23
 
23
24
  def store(self, call_hash: str, data: Any) -> Any: ...
24
25
 
checkpointer/types.py CHANGED
@@ -1,4 +1,7 @@
1
- from typing import Annotated, Callable, Coroutine, Generic, ParamSpec, TypeVar
1
+ from typing import (
2
+ Annotated, Callable, Coroutine, Generic,
3
+ ParamSpec, TypeVar, get_args, get_origin,
4
+ )
2
5
 
3
6
  Fn = TypeVar("Fn", bound=Callable)
4
7
  P = ParamSpec("P")
@@ -9,7 +12,32 @@ T = TypeVar("T")
9
12
  class HashBy(Generic[Fn]):
10
13
  pass
11
14
 
12
- NoHash = Annotated[T, HashBy[lambda _: None]]
15
class Captured:
    """Marker class used as ``Annotated`` metadata; ``is_capture_me`` looks for it."""
17
+
18
class CapturedOnce:
    """Marker class used as ``Annotated`` metadata; ``is_capture_me_once`` looks for it."""
20
+
21
def to_none(_):
    """Ignore the argument and return ``None``."""
    result = None
    return result
23
+
24
def get_annotated_args(anno: object) -> tuple[object, ...]:
    """Return the type arguments of an ``Annotated[...]`` form, else an empty tuple.

    For ``Annotated[T, m1, m2]`` the result is ``(T, m1, m2)``.
    """
    if get_origin(anno) is not Annotated:
        return ()
    return get_args(anno)
26
+
27
def hash_by_from_annotation(anno: object) -> Callable[[object], object] | None:
    """Return the first type argument of the first ``HashBy[...]`` found in ``anno``.

    Only the arguments of an ``Annotated[...]`` form are inspected. Returns
    ``None`` when there is no ``HashBy[...]`` entry.
    """
    hash_by_params = (
        get_args(arg)
        for arg in get_annotated_args(anno)
        if get_origin(arg) is HashBy
    )
    first_params = next(hash_by_params, None)
    return None if first_params is None else first_params[0]
31
+
32
def is_capture_me(anno: object) -> bool:
    """Tell whether ``anno`` is an ``Annotated[...]`` form that includes ``Captured``."""
    annotated_args = get_annotated_args(anno)
    return Captured in annotated_args
34
+
35
def is_capture_me_once(anno: object) -> bool:
    """Tell whether ``anno`` is an ``Annotated[...]`` form that includes ``CapturedOnce``."""
    annotated_args = get_annotated_args(anno)
    return CapturedOnce in annotated_args
37
+
38
+ NoHash = Annotated[T, HashBy[to_none]]
39
+ CaptureMe = Annotated[T, Captured]
40
+ CaptureMeOnce = Annotated[T, CapturedOnce]
13
41
  Coro = Coroutine[object, object, R]
14
42
 
15
43
  class AwaitableValue(Generic[T]):
checkpointer/utils.py CHANGED
@@ -1,23 +1,93 @@
1
- from contextlib import contextmanager
2
- from typing import Callable, Generic, Iterable, cast
3
- from .types import Fn, T
1
+ from __future__ import annotations
2
+ import inspect
3
+ from contextlib import contextmanager, suppress
4
+ from itertools import islice
5
+ from pathlib import Path
6
+ from types import FunctionType, MethodType, ModuleType
7
+ from typing import Callable, Generic, Iterable, Self, Type, TypeGuard
8
+ from .types import T
4
9
 
5
- def distinct(seq: Iterable[T]) -> list[T]:
6
- return list(dict.fromkeys(seq))
10
+ cwd = Path.cwd().resolve()
11
+
12
+ def is_class(obj) -> TypeGuard[Type]:
13
+ return isinstance(obj, type)
14
+
15
+ def get_file(obj: Callable | ModuleType) -> Path:
16
+ return Path(inspect.getfile(obj)).resolve()
17
+
18
def is_user_file(path: Path) -> bool:
    """Tell whether ``path`` lies under ``cwd`` and has no ``.venv`` component."""
    inside_cwd = cwd in path.parents
    return inside_cwd and ".venv" not in path.parts
20
+
21
+ def is_user_fn(obj) -> TypeGuard[Callable]:
22
+ return isinstance(obj, (FunctionType, MethodType)) and is_user_file(get_file(obj))
7
23
 
8
24
  def get_cell_contents(fn: Callable) -> Iterable[tuple[str, object]]:
9
25
  for key, cell in zip(fn.__code__.co_freevars, fn.__closure__ or []):
10
- try:
26
+ with suppress(ValueError):
11
27
  yield (key, cell.cell_contents)
12
- except ValueError:
13
- pass
14
28
 
15
- def unwrap_fn(fn: Fn, cached_fn=False) -> Fn:
16
- from .checkpoint import CachedFunction
17
- while True:
18
- if (cached_fn and isinstance(fn, CachedFunction)) or not hasattr(fn, "__wrapped__"):
19
- return cast(Fn, fn)
20
- fn = getattr(fn, "__wrapped__")
29
def drop_none(iterable: Iterable[T | None]) -> list[T]:
    """Return the items of ``iterable`` as a list, leaving out every ``None``.

    Other falsy values such as ``0`` or ``""`` are kept.
    """
    kept = []
    for item in iterable:
        if item is None:
            continue
        kept.append(item)
    return kept
31
+
32
def distinct(seq: Iterable[T]) -> list[T]:
    """Return the unique items of ``seq`` in order of first appearance."""
    # dict keys are unique and keep insertion order.
    first_seen = dict.fromkeys(seq)
    return [*first_seen]
34
+
35
def takewhile(iter: Iterable[tuple[bool, T]]) -> Iterable[T]:
    """Yield the values of ``(condition, value)`` pairs until a condition is falsy.

    The pair with the first falsy condition and everything after it are
    not yielded.
    """
    for keep_going, item in iter:
        if keep_going:
            yield item
        else:
            break
40
+
41
+ class seekable(Generic[T]):
42
+ def __init__(self, iterable: Iterable[T]):
43
+ self.index = 0
44
+ self.source = iter(iterable)
45
+ self.sink: list[T] = []
46
+
47
+ def __iter__(self):
48
+ return self
49
+
50
+ def __next__(self) -> T:
51
+ if len(self.sink) > self.index:
52
+ item = self.sink[self.index]
53
+ else:
54
+ item = next(self.source)
55
+ self.sink.append(item)
56
+ self.index += 1
57
+ return item
58
+
59
+ def __bool__(self):
60
+ return bool(self.lookahead(1))
61
+
62
+ def seek(self, index: int) -> Self:
63
+ remainder = index - len(self.sink)
64
+ if remainder > 0:
65
+ next(islice(self, remainder, remainder), None)
66
+ self.index = max(0, min(index, len(self.sink)))
67
+ return self
68
+
69
+ def step(self, count: int) -> Self:
70
+ return self.seek(self.index + count)
71
+
72
+ @contextmanager
73
+ def freeze(self):
74
+ initial_index = self.index
75
+ try:
76
+ yield
77
+ finally:
78
+ self.seek(initial_index)
79
+
80
+ def lookahead(self, count: int) -> list[T]:
81
+ with self.freeze():
82
+ return list(islice(self, count))
83
+
84
def get_at(obj: object, *attrs: str) -> object:
    """Follow ``attrs`` one step at a time starting from ``obj``.

    At each step a plain ``dict`` (exact type, not a subclass) is indexed by
    key, and anything else is accessed by attribute. A missing key or
    attribute gives ``None`` for that step.
    """
    current = obj
    for name in attrs:
        current = current.get(name, None) if type(current) is dict else getattr(current, name, None)
    return current
21
91
 
22
92
  class AttrDict(dict):
23
93
  def __init__(self, *args, **kwargs):
@@ -30,23 +100,6 @@ class AttrDict(dict):
30
100
  def __setattr__(self, name: str, value: object):
31
101
  super().__setattr__(name, value)
32
102
 
33
- def set(self, d: dict) -> "AttrDict":
34
- if not d:
35
- return self
36
- return AttrDict({**self, **d})
37
-
38
- def delete(self, *attrs: str) -> "AttrDict":
39
- d = AttrDict(self)
40
- for attr in attrs:
41
- del d[attr]
42
- return d
43
-
44
- def get_at(self, attrs: tuple[str, ...]) -> object:
45
- d = self
46
- for attr in attrs:
47
- d = getattr(d, attr, None)
48
- return d
49
-
50
103
  class ContextVar(Generic[T]):
51
104
  def __init__(self, value: T):
52
105
  self.value = value
@@ -58,26 +111,3 @@ class ContextVar(Generic[T]):
58
111
  yield
59
112
  finally:
60
113
  self.value = old
61
-
62
- class iterate_and_upcoming(Generic[T]):
63
- def __init__(self, it: Iterable[T]) -> None:
64
- self.it = iter(it)
65
- self.previous: tuple[()] | tuple[T] = ()
66
- self.tracked = self._tracked_iter()
67
-
68
- def __iter__(self):
69
- return self
70
-
71
- def __next__(self) -> tuple[T, Iterable[T]]:
72
- try:
73
- item = self.previous[0] if self.previous else next(self.it)
74
- self.previous = ()
75
- return item, self.tracked
76
- except StopIteration:
77
- self.tracked.close()
78
- raise
79
-
80
- def _tracked_iter(self):
81
- for x in self.it:
82
- self.previous = (x,)
83
- yield x