checkpointer 1.2.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpointer/__init__.py CHANGED
@@ -1,10 +1,9 @@
1
- import os
2
-
3
- from .checkpoint import create_checkpointer, read_only, default_dir
4
- from .storage import store_on_demand, read_from_store
1
+ from .checkpoint import Checkpointer, CheckpointFn, CheckpointError
2
+ from .types import Storage
5
3
  from .function_body import get_function_hash
4
+ import tempfile
6
5
 
7
- storage_dir = os.environ.get('CHECKPOINTS_DIR', default_dir)
8
- verbosity = int(os.environ.get('CHECKPOINTS_VERBOSITY', '1'))
9
-
10
- checkpoint = create_checkpointer(root_path=storage_dir, verbosity=verbosity)
6
+ create_checkpointer = Checkpointer
7
+ checkpoint = Checkpointer()
8
+ memory_checkpoint = Checkpointer(format="memory")
9
+ tmp_checkpoint = Checkpointer(root_path=tempfile.gettempdir() + "/checkpoints")
@@ -1,68 +1,114 @@
1
1
  import inspect
2
- from collections import namedtuple
3
- from pathlib import Path
4
- from functools import wraps
5
2
  import relib.hashing as hashing
6
- from . import storage
3
+ from typing import Generic, TypeVar, TypedDict, Callable, Unpack, Literal, Union, Any, cast, overload
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+ from functools import update_wrapper
7
+ from .types import Storage
7
8
  from .function_body import get_function_hash
8
- from .utils import unwrap_func, sync_resolve_coroutine
9
-
10
- default_dir = Path.home() / '.checkpoints'
11
-
12
- def get_invoke_path(func, function_hash, args, kwargs, path):
13
- if type(path) == str:
14
- return path
15
- elif callable(path):
16
- return path(*args, **kwargs)
17
- else:
18
- hash = hashing.hash([function_hash, args, kwargs or 0])
19
- file_name = Path(func.__code__.co_filename).name
20
- name = func.__name__
21
- return file_name + '/' + name + '/' + hash
22
-
23
- def create_checkpointer_from_config(config):
24
- def checkpoint(opt_func=None, format=config.format, path=None, should_expire=None, when=True):
25
- def receive_func(func):
26
- if not (config.when and when):
27
- return func
28
-
29
- config_ = config._replace(format=format)
30
- is_async = inspect.iscoroutinefunction(func)
31
- unwrapped_func = unwrap_func(func)
32
- function_hash = get_function_hash(unwrapped_func)
33
-
34
- @wraps(unwrapped_func)
35
- def wrapper(*args, **kwargs):
36
- compute = lambda: func(*args, **kwargs)
37
- recheck = kwargs.pop('recheck', False)
38
- invoke_path = get_invoke_path(unwrapped_func, function_hash, args, kwargs, path)
39
- coroutine = storage.store_on_demand(compute, invoke_path, config_, recheck, should_expire)
40
- if is_async:
41
- return coroutine
42
- else:
43
- return sync_resolve_coroutine(coroutine)
44
-
45
- wrapper.checkpoint_config = config_
46
-
47
- return wrapper
48
-
49
- return receive_func(opt_func) if callable(opt_func) else receive_func
50
-
51
- return checkpoint
52
-
53
- def create_checkpointer(format='pickle', root_path=default_dir, when=True, verbosity=1):
54
- root_path = None if root_path is None else Path(root_path)
55
- opts = locals()
56
- CheckpointerConfig = namedtuple('CheckpointerConfig', sorted(opts))
57
- config = CheckpointerConfig(**opts)
58
- return create_checkpointer_from_config(config)
59
-
60
- def read_only(wrapper_func, config, format='pickle', path=None):
61
- func = unwrap_func(wrapper_func)
62
- function_hash = get_function_hash(func)
63
-
64
- def wrapper(*args, **kwargs):
65
- invoke_path = get_invoke_path(func, function_hash, args, kwargs, path)
66
- return storage.read_from_store(invoke_path, config, storage=format)
67
-
68
- return wrapper
9
+ from .utils import unwrap_fn, sync_resolve_coroutine
10
+ from .storages.pickle_storage import PickleStorage
11
+ from .storages.memory_storage import MemoryStorage
12
+ from .storages.bcolz_storage import BcolzStorage
13
+ from .print_checkpoint import print_checkpoint
14
+
15
+ Fn = TypeVar("Fn", bound=Callable)
16
+
17
+ DEFAULT_DIR = Path.home() / ".cache/checkpoints"
18
+ STORAGE_MAP = {"memory": MemoryStorage, "pickle": PickleStorage, "bcolz": BcolzStorage}
19
+
20
+ class CheckpointError(Exception):
21
+ pass
22
+
23
+ class CheckpointerOpts(TypedDict, total=False):
24
+ format: Storage | Literal["pickle", "memory", "bcolz"]
25
+ root_path: Path | str | None
26
+ when: bool
27
+ verbosity: Literal[0, 1]
28
+ path: Callable[..., str] | None
29
+ should_expire: Callable[[datetime], bool] | None
30
+
31
+ class Checkpointer:
32
+ def __init__(self, **opts: Unpack[CheckpointerOpts]):
33
+ self.format = opts.get("format", "pickle")
34
+ self.root_path = Path(opts.get("root_path", DEFAULT_DIR) or ".")
35
+ self.when = opts.get("when", True)
36
+ self.verbosity = opts.get("verbosity", 1)
37
+ self.path = opts.get("path")
38
+ self.should_expire = opts.get("should_expire")
39
+
40
+ def get_storage(self) -> Storage:
41
+ return STORAGE_MAP[self.format] if isinstance(self.format, str) else self.format
42
+
43
+ @overload
44
+ def __call__(self, fn: Fn, **override_opts: Unpack[CheckpointerOpts]) -> "CheckpointFn[Fn]": ...
45
+ @overload
46
+ def __call__(self, fn: None=None, **override_opts: Unpack[CheckpointerOpts]) -> "Checkpointer": ...
47
+ def __call__(self, fn: Fn | None=None, **override_opts: Unpack[CheckpointerOpts]) -> Union["Checkpointer", "CheckpointFn[Fn]"]:
48
+ if override_opts:
49
+ opts = CheckpointerOpts(**{**self.__dict__, **override_opts})
50
+ return Checkpointer(**opts)(fn)
51
+
52
+ return CheckpointFn(self, fn) if callable(fn) else self
53
+
54
+ class CheckpointFn(Generic[Fn]):
55
+ def __init__(self, checkpointer: Checkpointer, fn: Fn):
56
+ wrapped = unwrap_fn(fn)
57
+ file_name = Path(wrapped.__code__.co_filename).name
58
+ update_wrapper(cast(Callable, self), wrapped)
59
+ self.checkpointer = checkpointer
60
+ self.fn = fn
61
+ self.fn_hash = get_function_hash(wrapped)
62
+ self.fn_id = f"{file_name}/{wrapped.__name__}"
63
+ self.is_async = inspect.iscoroutinefunction(wrapped)
64
+
65
+ def get_checkpoint_id(self, args: tuple, kw: dict) -> str:
66
+ if not callable(self.checkpointer.path):
67
+ return f"{self.fn_id}/{hashing.hash([self.fn_hash, args, kw or 0])}"
68
+ checkpoint_id = self.checkpointer.path(*args, **kw)
69
+ if not isinstance(checkpoint_id, str):
70
+ raise CheckpointError(f"path function must return a string, got {type(checkpoint_id)}")
71
+ return checkpoint_id
72
+
73
+ async def _store_on_demand(self, args: tuple, kw: dict, rerun: bool):
74
+ checkpoint_id = self.get_checkpoint_id(args, kw)
75
+ checkpoint_path = self.checkpointer.root_path / checkpoint_id
76
+ storage = self.checkpointer.get_storage()
77
+ should_log = storage is not MemoryStorage and self.checkpointer.verbosity > 0
78
+ refresh = rerun \
79
+ or not storage.exists(checkpoint_path) \
80
+ or (self.checkpointer.should_expire and self.checkpointer.should_expire(storage.checkpoint_date(checkpoint_path)))
81
+
82
+ if refresh:
83
+ print_checkpoint(should_log, "MEMORIZING", checkpoint_id, "blue")
84
+ data = self.fn(*args, **kw)
85
+ if inspect.iscoroutine(data):
86
+ data = await data
87
+ storage.store(checkpoint_path, data)
88
+ return data
89
+
90
+ try:
91
+ data = storage.load(checkpoint_path)
92
+ print_checkpoint(should_log, "REMEMBERED", checkpoint_id, "green")
93
+ return data
94
+ except (EOFError, FileNotFoundError):
95
+ print_checkpoint(should_log, "CORRUPTED", checkpoint_id, "yellow")
96
+ storage.delete(checkpoint_path)
97
+ return await self._store_on_demand(args, kw, rerun)
98
+
99
+ def _call(self, args: tuple, kw: dict, rerun=False):
100
+ if not self.checkpointer.when:
101
+ return self.fn(*args, **kw)
102
+ coroutine = self._store_on_demand(args, kw, rerun)
103
+ return coroutine if self.is_async else sync_resolve_coroutine(coroutine)
104
+
105
+ __call__: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw))
106
+ rerun: Fn = cast(Fn, lambda self, *args, **kw: self._call(args, kw, True))
107
+
108
+ def get(self, *args, **kw) -> Any:
109
+ checkpoint_path = self.checkpointer.root_path / self.get_checkpoint_id(args, kw)
110
+ storage = self.checkpointer.get_storage()
111
+ try:
112
+ return storage.load(checkpoint_path)
113
+ except:
114
+ raise CheckpointError("Could not load checkpoint")
@@ -1,49 +1,46 @@
1
1
  import inspect
2
- from types import FunctionType, CodeType
3
2
  import relib.hashing as hashing
3
+ from collections.abc import Callable
4
+ from types import FunctionType, CodeType
4
5
  from pathlib import Path
5
- from .utils import unwrap_func
6
+ from .utils import unwrap_fn
6
7
 
7
8
  cwd = Path.cwd()
8
9
 
9
- def get_fn_path(fn):
10
- return Path(inspect.getfile(fn)).absolute()
10
+ def get_fn_path(fn: Callable) -> Path:
11
+ return Path(inspect.getfile(fn)).resolve()
11
12
 
12
- def get_function_body(fn):
13
+ def get_function_body(fn: Callable) -> str:
13
14
  # TODO: Strip comments
14
15
  lines = inspect.getsourcelines(fn)[0]
15
16
  lines = [line.rstrip() for line in lines]
16
17
  lines = [line for line in lines if line]
17
- return '\n'.join(lines)
18
+ return "\n".join(lines)
18
19
 
19
- def get_code_children(__code__):
20
- consts = [const for const in __code__.co_consts if isinstance(const, CodeType)]
20
+ def get_code_children(code: CodeType) -> list[str]:
21
+ consts = [const for const in code.co_consts if isinstance(const, CodeType)]
21
22
  children = [child for const in consts for child in get_code_children(const)]
22
- return list(__code__.co_names) + children
23
+ return list(code.co_names) + children
23
24
 
24
- def is_user_fn(candidate_fn, cleared_fns):
25
+ def is_user_fn(candidate_fn, cleared_fns: set[Callable]) -> bool:
25
26
  return isinstance(candidate_fn, FunctionType) \
26
27
  and candidate_fn not in cleared_fns \
27
28
  and cwd in get_fn_path(candidate_fn).parents
28
29
 
29
- def append_fn_children(fn, cleared_fns):
30
+ def append_fn_children(cleared_fns: set[Callable], fn: Callable) -> None:
30
31
  code_children = get_code_children(fn.__code__)
31
- fn_children = [unwrap_func(fn.__globals__.get(co_name, None)) for co_name in code_children]
32
+ fn_children = [unwrap_fn(fn.__globals__.get(co_name, None)) for co_name in code_children]
32
33
  fn_children = [child for child in fn_children if is_user_fn(child, cleared_fns)]
33
-
34
- for fn in fn_children:
35
- cleared_fns.add(fn)
36
-
34
+ cleared_fns.update(fn_children)
37
35
  for child_fn in fn_children:
38
- append_fn_children(child_fn, cleared_fns)
36
+ append_fn_children(cleared_fns, child_fn)
39
37
 
40
- def get_fn_children(fn):
41
- cleared_fns = set()
42
- append_fn_children(fn, cleared_fns)
38
+ def get_fn_children(fn: Callable) -> list[Callable]:
39
+ cleared_fns: set[Callable] = set()
40
+ append_fn_children(cleared_fns, fn)
43
41
  return sorted(cleared_fns, key=lambda fn: fn.__name__)
44
42
 
45
- def get_function_hash(fn):
43
+ def get_function_hash(fn: Callable) -> str:
46
44
  fns = [fn] + get_fn_children(fn)
47
45
  fn_bodies = list(map(get_function_body, fns))
48
- fn_bodies_hash = hashing.hash(fn_bodies)
49
- return fn_bodies_hash
46
+ return hashing.hash(fn_bodies)
@@ -0,0 +1,52 @@
1
+ import io
2
+ import os
3
+ import sys
4
+ from typing import Literal
5
+
6
+ Color = Literal[
7
+ "black", "grey", "red", "green", "yellow", "blue", "magenta",
8
+ "cyan", "light_grey", "dark_grey", "light_red", "light_green",
9
+ "light_yellow", "light_blue", "light_magenta", "light_cyan", "white",
10
+ ]
11
+
12
+ COLOR_MAP: dict[Color, int] = {
13
+ "black": 30,
14
+ "grey": 30,
15
+ "red": 31,
16
+ "green": 32,
17
+ "yellow": 33,
18
+ "blue": 34,
19
+ "magenta": 35,
20
+ "cyan": 36,
21
+ "light_grey": 37,
22
+ "dark_grey": 90,
23
+ "light_red": 91,
24
+ "light_green": 92,
25
+ "light_yellow": 93,
26
+ "light_blue": 94,
27
+ "light_magenta": 95,
28
+ "light_cyan": 96,
29
+ "white": 97,
30
+ }
31
+
32
+ def allow_color() -> bool:
33
+ if "NO_COLOR" in os.environ or os.environ.get("TERM") == "dumb" or not hasattr(sys.stdout, "fileno"):
34
+ return False
35
+ try:
36
+ return os.isatty(sys.stdout.fileno())
37
+ except io.UnsupportedOperation:
38
+ return sys.stdout.isatty()
39
+
40
+ def colored_(text: str, color: Color | None = None, on_color: Color | None = None) -> str:
41
+ if color:
42
+ text = f"\033[{COLOR_MAP[color]}m{text}"
43
+ if on_color:
44
+ text = f"\033[{COLOR_MAP[on_color] + 10}m{text}"
45
+ return text + "\033[0m"
46
+
47
+ noop = lambda text, *a, **k: text
48
+ colored = colored_ if allow_color() else noop
49
+
50
+ def print_checkpoint(should_log: bool, title: str, text: str, color: Color):
51
+ if should_log:
52
+ print(f"{colored(f" {title} ", "grey", color)} {colored(text, color)}")
@@ -1,92 +1,86 @@
1
1
  import shutil
2
2
  from pathlib import Path
3
3
  from datetime import datetime
4
+ from ..types import Storage
4
5
 
5
6
  def get_data_type_str(x):
6
7
  if isinstance(x, tuple):
7
- return 'tuple'
8
+ return "tuple"
8
9
  elif isinstance(x, dict):
9
- return 'dict'
10
+ return "dict"
10
11
  elif isinstance(x, list):
11
- return 'list'
12
- elif isinstance(x, str) or not hasattr(x, '__len__'):
13
- return 'other'
12
+ return "list"
13
+ elif isinstance(x, str) or not hasattr(x, "__len__"):
14
+ return "other"
14
15
  else:
15
- return 'ndarray'
16
+ return "ndarray"
16
17
 
17
- def get_paths(root_path, invoke_path):
18
- full_path = Path(invoke_path) if root_path is None else root_path / invoke_path
19
- meta_full_path = full_path.with_name(full_path.name + '_meta')
20
- return full_path, meta_full_path
18
+ def get_metapath(path: Path):
19
+ return path.with_name(f"{path.name}_meta")
21
20
 
22
- def get_collection_timestamp(config, path):
21
+ def insert_data(path: Path, data):
23
22
  import bcolz
24
- _, meta_full_path = get_paths(config.root_path, path)
25
- meta_data = bcolz.open(meta_full_path)[:][0]
26
- return meta_data['created']
27
-
28
- def get_is_expired(config, path):
29
- try:
30
- get_collection_timestamp(config, path)
31
- return False
32
- except (FileNotFoundError, EOFError):
33
- return True
34
-
35
- def should_expire(config, path, expire_fn):
36
- return expire_fn(get_collection_timestamp(config, path))
37
-
38
- def insert_data(path, data):
39
- import bcolz
40
- c = bcolz.carray(data, rootdir=path, mode='w')
23
+ c = bcolz.carray(data, rootdir=path, mode="w")
41
24
  c.flush()
42
25
 
43
- def store_data(config, path, data, expire_in=None):
44
- full_path, meta_full_path = get_paths(config.root_path, path)
45
- full_path.parent.mkdir(parents=True, exist_ok=True)
46
- created = datetime.now()
47
- data_type_str = get_data_type_str(data)
48
- if data_type_str == 'tuple':
49
- fields = list(range(len(data)))
50
- elif data_type_str == 'dict':
51
- fields = sorted(data.keys())
52
- else:
53
- fields = []
54
- meta_data = {'created': created, 'data_type_str': data_type_str, 'fields': fields}
55
- insert_data(meta_full_path, meta_data)
56
- if data_type_str in ['tuple', 'dict']:
57
- for i in range(len(fields)):
58
- sub_path = f"{path} ({i})"
59
- store_data(config, sub_path, data[fields[i]])
60
- else:
61
- insert_data(full_path, data)
62
- return data
26
+ class BcolzStorage(Storage):
27
+ @staticmethod
28
+ def exists(path):
29
+ return path.exists()
63
30
 
64
- def load_data(config, path):
65
- import bcolz
66
- full_path, meta_full_path = get_paths(config.root_path, path)
67
- meta_data = bcolz.open(meta_full_path)[:][0]
68
- data_type_str = meta_data['data_type_str']
69
- if data_type_str in ['tuple', 'dict']:
70
- fields = meta_data['fields']
71
- partitions = range(len(fields))
72
- data = [load_data(config, f"{path} ({i})") for i in partitions]
73
- if data_type_str == 'tuple':
74
- return tuple(data)
31
+ @staticmethod
32
+ def checkpoint_date(path):
33
+ return datetime.fromtimestamp(path.stat().st_mtime)
34
+
35
+ @staticmethod
36
+ def store(path, data):
37
+ metapath = get_metapath(path)
38
+ path.parent.mkdir(parents=True, exist_ok=True)
39
+ data_type_str = get_data_type_str(data)
40
+ if data_type_str == "tuple":
41
+ fields = list(range(len(data)))
42
+ elif data_type_str == "dict":
43
+ fields = sorted(data.keys())
75
44
  else:
76
- return dict(zip(fields, data))
77
- else:
78
- data = bcolz.open(full_path)
79
- if data_type_str == 'list':
80
- return list(data)
81
- elif data_type_str == 'other':
82
- return data[0]
45
+ fields = []
46
+ meta_data = {"data_type_str": data_type_str, "fields": fields}
47
+ insert_data(metapath, meta_data)
48
+ if data_type_str in ["tuple", "dict"]:
49
+ for i in range(len(fields)):
50
+ child_path = Path(f"{path} ({i})")
51
+ BcolzStorage.store(child_path, data[fields[i]])
52
+ else:
53
+ insert_data(path, data)
54
+
55
+ @staticmethod
56
+ def load(path):
57
+ import bcolz
58
+ metapath = get_metapath(path)
59
+ meta_data = bcolz.open(metapath)[:][0]
60
+ data_type_str = meta_data["data_type_str"]
61
+ if data_type_str in ["tuple", "dict"]:
62
+ fields = meta_data["fields"]
63
+ partitions = range(len(fields))
64
+ data = [BcolzStorage.load(Path(f"{path} ({i})")) for i in partitions]
65
+ if data_type_str == "tuple":
66
+ return tuple(data)
67
+ else:
68
+ return dict(zip(fields, data))
83
69
  else:
84
- return data[:]
70
+ data = bcolz.open(path)
71
+ if data_type_str == "list":
72
+ return list(data)
73
+ elif data_type_str == "other":
74
+ return data[0]
75
+ else:
76
+ return data[:]
85
77
 
86
- def delete_data(config, path):
87
- full_path, meta_full_path = get_paths(config.root_path, path)
88
- try:
89
- shutil.rmtree(meta_full_path)
90
- shutil.rmtree(full_path)
91
- except FileNotFoundError:
92
- pass
78
+ @staticmethod
79
+ def delete(path):
80
+ # NOTE: Not recursive
81
+ metapath = get_metapath(path)
82
+ try:
83
+ shutil.rmtree(metapath)
84
+ shutil.rmtree(path)
85
+ except FileNotFoundError:
86
+ pass
@@ -1,18 +1,28 @@
1
1
  from datetime import datetime
2
+ from ..types import Storage
2
3
 
3
4
  store = {}
4
5
  date_stored = {}
5
6
 
6
- def get_is_expired(config, path):
7
- return path not in store
7
+ class MemoryStorage(Storage):
8
+ @staticmethod
9
+ def exists(path):
10
+ return str(path) in store
8
11
 
9
- def should_expire(config, path, expire_fn):
10
- return expire_fn(date_stored[path])
12
+ @staticmethod
13
+ def checkpoint_date(path):
14
+ return date_stored[str(path)]
11
15
 
12
- def store_data(config, path, data):
13
- store[path] = data
14
- date_stored[path] = datetime.now()
15
- return data
16
+ @staticmethod
17
+ def store(path, data):
18
+ store[str(path)] = data
19
+ date_stored[str(path)] = datetime.now()
16
20
 
17
- def load_data(config, path):
18
- return store[path]
21
+ @staticmethod
22
+ def load(path):
23
+ return store[str(path)]
24
+
25
+ @staticmethod
26
+ def delete(path):
27
+ del store[str(path)]
28
+ del date_stored[str(path)]
@@ -1,49 +1,36 @@
1
1
  import pickle
2
2
  from pathlib import Path
3
3
  from datetime import datetime
4
+ from ..types import Storage
4
5
 
5
- def get_paths(root_path, invoke_path):
6
- p = Path(invoke_path) if root_path is None else root_path / invoke_path
7
- meta_full_path = p.with_name(p.name + '_meta.pkl')
8
- pkl_full_path = p.with_name(p.name + '.pkl')
9
- return meta_full_path, pkl_full_path
6
+ def get_path(path: Path):
7
+ return path.with_name(f"{path.name}.pkl")
10
8
 
11
- def get_collection_timestamp(config, path):
12
- meta_full_path, pkl_full_path = get_paths(config.root_path, path)
13
- with meta_full_path.open('rb') as file:
14
- meta_data = pickle.load(file)
15
- return meta_data['created']
9
+ class PickleStorage(Storage):
10
+ @staticmethod
11
+ def exists(path):
12
+ return get_path(path).exists()
16
13
 
17
- def get_is_expired(config, path):
18
- try:
19
- get_collection_timestamp(config, path)
20
- return False
21
- except (FileNotFoundError, EOFError):
22
- return True
14
+ @staticmethod
15
+ def checkpoint_date(path):
16
+ return datetime.fromtimestamp(get_path(path).stat().st_mtime)
23
17
 
24
- def should_expire(config, path, expire_fn):
25
- return expire_fn(get_collection_timestamp(config, path))
18
+ @staticmethod
19
+ def store(path, data):
20
+ full_path = get_path(path)
21
+ full_path.parent.mkdir(parents=True, exist_ok=True)
22
+ with full_path.open("wb") as file:
23
+ pickle.dump(data, file, -1)
26
24
 
27
- def store_data(config, path, data):
28
- created = datetime.now()
29
- meta_data = {'created': created}
30
- meta_full_path, pkl_full_path = get_paths(config.root_path, path)
31
- pkl_full_path.parent.mkdir(parents=True, exist_ok=True)
32
- with pkl_full_path.open('wb') as file:
33
- pickle.dump(data, file, -1)
34
- with meta_full_path.open('wb') as file:
35
- pickle.dump(meta_data, file, -1)
36
- return data
25
+ @staticmethod
26
+ def load(path):
27
+ full_path = get_path(path)
28
+ with full_path.open("rb") as file:
29
+ return pickle.load(file)
37
30
 
38
- def load_data(config, path):
39
- _, full_path = get_paths(config.root_path, path)
40
- with full_path.open('rb') as file:
41
- return pickle.load(file)
42
-
43
- def delete_data(config, path):
44
- meta_full_path, pkl_full_path = get_paths(config.root_path, path)
45
- try:
46
- meta_full_path.unlink()
47
- pkl_full_path.unlink()
48
- except FileNotFoundError:
49
- pass
31
+ @staticmethod
32
+ def delete(path):
33
+ try:
34
+ get_path(path).unlink()
35
+ except FileNotFoundError:
36
+ pass
checkpointer/types.py ADDED
@@ -0,0 +1,19 @@
1
+ from typing import Protocol, Any
2
+ from pathlib import Path
3
+ from datetime import datetime
4
+
5
+ class Storage(Protocol):
6
+ @staticmethod
7
+ def exists(path: Path) -> bool: ...
8
+
9
+ @staticmethod
10
+ def checkpoint_date(path: Path) -> datetime: ...
11
+
12
+ @staticmethod
13
+ def store(path: Path, data: Any) -> None: ...
14
+
15
+ @staticmethod
16
+ def load(path: Path) -> Any: ...
17
+
18
+ @staticmethod
19
+ def delete(path: Path) -> None: ...
checkpointer/utils.py CHANGED
@@ -1,9 +1,9 @@
1
1
  import types
2
2
 
3
- def unwrap_func(func):
4
- while hasattr(func, '__wrapped__'):
5
- func = func.__wrapped__
6
- return func
3
+ def unwrap_fn[T](fn: T) -> T:
4
+ while hasattr(fn, "__wrapped__"):
5
+ fn = getattr(fn, "__wrapped__")
6
+ return fn
7
7
 
8
8
  @types.coroutine
9
9
  def coroutine_as_generator(coroutine):
@@ -0,0 +1,270 @@
1
+ Metadata-Version: 2.3
2
+ Name: checkpointer
3
+ Version: 2.0.1
4
+ Summary: A Python library for memoizing function results with support for multiple storage backends, async runtimes, and automatic cache invalidation
5
+ Project-URL: Repository, https://github.com/Reddan/checkpointer.git
6
+ Author: Hampus Hallman
7
+ License: Copyright 2024 Hampus Hallman
8
+
9
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
10
+
11
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
14
+ Requires-Python: >=3.12
15
+ Requires-Dist: relib
16
+ Description-Content-Type: text/markdown
17
+
18
+ # checkpointer · [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/Reddan/checkpointer/blob/master/LICENSE) [![pypi](https://img.shields.io/pypi/v/checkpointer)](https://pypi.org/project/checkpointer/) [![Python 3.12](https://img.shields.io/badge/python-3.12-blue)](https://pypi.org/project/checkpointer/)
19
+
20
+ `checkpointer` is a Python library for memoizing function results. It simplifies caching by providing a decorator-based API and supports various storage backends. It's designed for computationally expensive operations where caching can save time, or during development to avoid waiting for redundant computations. ⚡️
21
+
22
+ Adding or removing `@checkpoint` doesn't change how your code works, and it can be applied to any function, including ones you've already written, without altering their behavior or introducing side effects. The original function remains unchanged and can still be called directly when needed.
23
+
24
+ ### Key Features:
25
+ - 🗂️ **Multiple Storage Backends**: Supports in-memory, pickle, or your own custom storage.
26
+ - 🎯 **Simple Decorator API**: Apply `@checkpoint` to functions.
27
+ - 🔄 **Async and Sync Compatibility**: Works with synchronous functions and any Python async runtime (e.g., `asyncio`, `Trio`, `Curio`).
28
+ - ⏲️ **Custom Expiration Logic**: Automatically invalidate old checkpoints.
29
+ - 📂 **Flexible Path Configuration**: Control where checkpoints are stored.
30
+
31
+ ---
32
+
33
+ ## Installation
34
+
35
+ ```bash
36
+ pip install checkpointer
37
+ ```
38
+
39
+ ---
40
+
41
+ ## Quick Start 🚀
42
+
43
+ ```python
44
+ from checkpointer import checkpoint
45
+
46
+ @checkpoint
47
+ def expensive_function(x: int) -> int:
48
+ print("Computing...")
49
+ return x ** 2
50
+
51
+ result = expensive_function(4) # Computes and stores result
52
+ result = expensive_function(4) # Loads from checkpoint
53
+ ```
54
+
55
+ ---
56
+
57
+ ## How It Works
58
+
59
+ When you use `@checkpoint`, the function's **arguments** (`args`, `kwargs`) are hashed to create a unique identifier for each call. This identifier is used to store and retrieve cached results. If the same arguments are passed again, `checkpointer` will return the cached result instead of recomputing.
60
+
61
+ Additionally, `checkpointer` ensures that caches are invalidated when a function’s implementation or any of its dependencies change. Each function is assigned a hash based on:
62
+ 1. **Its source code**: Changes to the function’s code update its hash.
63
+ 2. **Dependent functions**: If a function calls others, changes to those will also update the hash.
64
+
65
+ ### Example: Cache Invalidation by Function Dependencies
66
+
67
+ ```python
68
+ def multiply(a, b):
69
+ return a * b
70
+
71
+ @checkpoint
72
+ def helper(x):
73
+ return multiply(x + 1, 2)
74
+
75
+ @checkpoint
76
+ def compute(a, b):
77
+ return helper(a) + helper(b)
78
+ ```
79
+
80
+ If you change `multiply`, the checkpoints for both `helper` and `compute` will be invalidated and recomputed.
81
+
82
+ ---
83
+
84
+ ## Parameterization
85
+
86
+ ### Global Configuration
87
+
88
+ You can configure a custom `Checkpointer`:
89
+
90
+ ```python
91
+ from checkpointer import checkpoint
92
+
93
+ checkpoint = checkpoint(format="memory", root_path="/tmp/checkpoints")
94
+ ```
95
+
96
+ Extend this configuration by calling itself again:
97
+
98
+ ```python
99
+ extended_checkpoint = checkpoint(format="pickle", verbosity=0)
100
+ ```
101
+
102
+ ### Per-Function Customization
103
+
104
+ ```python
105
+ @checkpoint(format="pickle", verbosity=0)
106
+ def my_function(x, y):
107
+ return x + y
108
+ ```
109
+
110
+ ### Combining Configurations
111
+
112
+ ```python
113
+ checkpoint = checkpoint(format="memory", verbosity=1)
114
+ quiet_checkpoint = checkpoint(verbosity=0)
115
+ pickle_checkpoint = checkpoint(format="pickle", root_path="/tmp/pickle_checkpoints")
116
+
117
+ @checkpoint
118
+ def compute_square(n: int) -> int:
119
+ return n ** 2
120
+
121
+ @quiet_checkpoint
122
+ def compute_quietly(n: int) -> int:
123
+ return n ** 3
124
+
125
+ @pickle_checkpoint
126
+ def compute_sum(a: int, b: int) -> int:
127
+ return a + b
128
+ ```
129
+
130
+ ### Layered Caching
131
+
132
+ ```python
133
+ IS_DEVELOPMENT = True # Toggle based on environment
134
+
135
+ dev_checkpoint = checkpoint(when=IS_DEVELOPMENT)
136
+
137
+ @checkpoint(format="memory")
138
+ @dev_checkpoint
139
+ def some_expensive_function():
140
+ print("Performing a time-consuming operation...")
141
+ return sum(i * i for i in range(10**6))
142
+ ```
143
+
144
+ - In development: Both `dev_checkpoint` and `memory` caches are active.
145
+ - In production: Only the `memory` cache is active.
146
+
147
+ ---
148
+
149
+ ## Usage
150
+
151
+ ### Force Recalculation
152
+ Use `rerun` to force a recalculation and overwrite the stored checkpoint:
153
+
154
+ ```python
155
+ result = expensive_function.rerun(4)
156
+ ```
157
+
158
+ ### Bypass Checkpointer
159
+ Use `fn` to directly call the original, undecorated function:
160
+
161
+ ```python
162
+ result = expensive_function.fn(4)
163
+ ```
164
+
165
+ This is especially useful **inside recursive functions**. By using `.fn` within the function itself, you avoid redundant caching of intermediate recursive calls while still caching the final result at the top level.
166
+
167
+ ### Retrieve Stored Checkpoints
168
+ Access stored results without recalculating:
169
+
170
+ ```python
171
+ stored_result = expensive_function.get(4)
172
+ ```
173
+
174
+ ---
175
+
176
+ ## Storage Backends
177
+
178
+ `checkpointer` supports flexible storage backends, including built-in options and custom implementations.
179
+
180
+ ### Built-In Backends
181
+
182
+ 1. **PickleStorage**: Saves checkpoints to disk using Python's `pickle` module.
183
+ 2. **MemoryStorage**: Caches checkpoints in memory for fast, non-persistent use.
184
+
185
+ To use these backends, pass either `"pickle"` or `PickleStorage` (and similarly for `"memory"` or `MemoryStorage`) to the `format` parameter:
186
+ ```python
187
+ from checkpointer import checkpoint, PickleStorage, MemoryStorage
188
+
189
+ @checkpoint(format="pickle") # Equivalent to format=PickleStorage
190
+ def disk_cached(x: int) -> int:
191
+ return x ** 2
192
+
193
+ @checkpoint(format="memory") # Equivalent to format=MemoryStorage
194
+ def memory_cached(x: int) -> int:
195
+ return x * 10
196
+ ```
197
+
198
+ ### Custom Storage Backends
199
+
200
+ Create custom storage backends by implementing methods for storing, loading, and managing checkpoints. For example, a custom storage backend might use a database, cloud storage, or a specialized format.
201
+
202
+ Example usage:
203
+ ```python
204
+ from checkpointer import checkpoint, Storage
205
+ from typing import Any
206
+ from pathlib import Path
207
+ from datetime import datetime
208
+
209
+ class CustomStorage(Storage): # Optional for type hinting
210
+ @staticmethod
211
+ def exists(path: Path) -> bool: ...
212
+ @staticmethod
213
+ def checkpoint_date(path: Path) -> datetime: ...
214
+ @staticmethod
215
+ def store(path: Path, data: Any) -> None: ...
216
+ @staticmethod
217
+ def load(path: Path) -> Any: ...
218
+ @staticmethod
219
+ def delete(path: Path) -> None: ...
220
+
221
+ @checkpoint(format=CustomStorage)
222
+ def custom_cached(x: int):
223
+ return x ** 2
224
+ ```
225
+
226
+ This flexibility allows you to adapt `checkpointer` to meet any storage requirement, whether persistent or in-memory.
227
+
228
+ ---
229
+
230
+ ## Configuration Options ⚙️
231
+
232
+ | Option | Type | Default | Description |
233
+ |----------------|-------------------------------------|-------------|---------------------------------------------|
234
+ | `format` | `"pickle"`, `"memory"`, `Storage` | `"pickle"` | Storage backend format. |
235
+ | `root_path` | `Path`, `str`, or `None` | User Cache | Root directory for storing checkpoints. |
236
+ | `when` | `bool` | `True` | Enable or disable checkpointing. |
237
+ | `verbosity` | `0` or `1` | `1` | Logging verbosity. |
238
+ | `path` | `Callable[..., str]` | `None` | Custom path for checkpoint storage. |
239
+ | `should_expire`| `Callable[[datetime], bool]` | `None` | Custom expiration logic. |
240
+
241
+ ---
242
+
243
+ ## Full Example 🛠️
244
+
245
+ ```python
246
+ import asyncio
247
+ from checkpointer import checkpoint
248
+
249
+ @checkpoint
250
+ def compute_square(n: int) -> int:
251
+ print(f"Computing {n}^2...")
252
+ return n ** 2
253
+
254
+ @checkpoint(format="memory")
255
+ async def async_compute_sum(a: int, b: int) -> int:
256
+ await asyncio.sleep(1)
257
+ return a + b
258
+
259
+ async def main():
260
+ result1 = compute_square(5)
261
+ print(result1)
262
+
263
+ result2 = await async_compute_sum(3, 7)
264
+ print(result2)
265
+
266
+ result3 = async_compute_sum.get(3, 7)
267
+ print(result3)
268
+
269
+ asyncio.run(main())
270
+ ```
@@ -0,0 +1,13 @@
1
+ checkpointer/__init__.py,sha256=2o-pOMXC_wVcjDtyjyapAdeTh6jyYwKYE0--C5XsKdc,350
2
+ checkpointer/checkpoint.py,sha256=-09sz8sZdYFwxfb8_O3L2PmdCN_lDXdcwKKTkFlOAtw,4715
3
+ checkpointer/function_body.py,sha256=92mnTY9d_JhKnKugeySYRP6qhU4fH6F6zesb7h2pEi0,1720
4
+ checkpointer/print_checkpoint.py,sha256=21aeqgM9CMjNAJyScqFmXCWWfh3jBIn7o7i5zJkZGaA,1369
5
+ checkpointer/types.py,sha256=n1AspKywTQhurCy7V_3t1HKIxYm0T6qOwuoDYfamO0E,408
6
+ checkpointer/utils.py,sha256=UrQt689UHUjl7kXpTbUCGkHUgQZllByX2rbuvZdt9vk,368
7
+ checkpointer/storages/bcolz_storage.py,sha256=F1JahTAgYmSpeE5mL1kPcANWTVxDgvb2YY8fgWRxt2U,2286
8
+ checkpointer/storages/memory_storage.py,sha256=EmXwscJ2D31Sekr4n0ONNaeiQWMf7SHfpHoVwRb1Ec8,534
9
+ checkpointer/storages/pickle_storage.py,sha256=YOndlnUdCaRUDWkzvQrU79j6FkGyp44WrSjl4kIs8RA,837
10
+ checkpointer-2.0.1.dist-info/METADATA,sha256=yHkb_PR1Js26cqT5UW1g2rmar_RnRabNjw4cE59tSlA,9568
11
+ checkpointer-2.0.1.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
12
+ checkpointer-2.0.1.dist-info/licenses/LICENSE,sha256=0cmUKqBotzbBcysIexd52AhjwbphhlGYiWbvg5l2QAU,1054
13
+ checkpointer-2.0.1.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.25.0
2
+ Generator: hatchling 1.26.3
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
checkpointer/storage.py DELETED
@@ -1,60 +0,0 @@
1
- import inspect
2
- from termcolor import colored
3
- from .storages import memory_storage, pickle_storage, bcolz_storage
4
-
5
- storages = {
6
- 'memory': memory_storage,
7
- 'pickle': pickle_storage,
8
- 'bcolz': bcolz_storage,
9
- }
10
-
11
- initialized_storages = set()
12
-
13
- def create_logger(should_log):
14
- def log(color, title, text):
15
- if should_log:
16
- title_log = colored(f' {title} ', 'grey', 'on_' + color)
17
- rest_log = colored(text, color)
18
- print(title_log + ' ' + rest_log)
19
- return log
20
-
21
- def get_storage(storage):
22
- if type(storage) == str:
23
- storage = storages[storage]
24
- if storage not in initialized_storages:
25
- if hasattr(storage, 'initialize'):
26
- storage.initialize()
27
- initialized_storages.add(storage)
28
- return storage
29
-
30
- async def store_on_demand(get_data, name, config, force=False, should_expire=None):
31
- storage = get_storage(config.format)
32
- should_log = storage != memory_storage and config.verbosity != 0
33
- log = create_logger(should_log)
34
- refresh = force \
35
- or storage.get_is_expired(config, name) \
36
- or (should_expire and storage.should_expire(config, name, should_expire))
37
-
38
- if refresh:
39
- log('blue', 'MEMORIZING', name)
40
- data = get_data()
41
- if inspect.iscoroutine(data):
42
- data = await data
43
- return storage.store_data(config, name, data)
44
- else:
45
- try:
46
- data = storage.load_data(config, name)
47
- log('green', 'REMEMBERED', name)
48
- return data
49
- except (EOFError, FileNotFoundError):
50
- log('yellow', 'CORRUPTED', name)
51
- storage.delete_data(config, name)
52
- result = await store_on_demand(get_data, name, config, force, should_expire)
53
- return result
54
-
55
- def read_from_store(name, config, storage='pickle'):
56
- storage = get_storage(storage)
57
- try:
58
- return storage.load_data(config, name)
59
- except:
60
- return None
@@ -1,16 +0,0 @@
1
- Metadata-Version: 2.3
2
- Name: checkpointer
3
- Version: 1.2.0
4
- Project-URL: Repository, https://github.com/Reddan/checkpointer.git
5
- Author: Hampus Hallman
6
- License: Copyright 2024 Hampus Hallman
7
-
8
- Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
9
-
10
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
11
-
12
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
13
- License-File: LICENSE
14
- Requires-Python: >=3.8
15
- Requires-Dist: relib
16
- Requires-Dist: termcolor
@@ -1,12 +0,0 @@
1
- checkpointer/__init__.py,sha256=_RYcKsZbeUf08KZ-DcXlNn4eAMWh2LN9o-KvchYVmmk,380
2
- checkpointer/checkpoint.py,sha256=FyL78HvAvPtgl-esiAkt-CdekT18J2Sh0SriMtX4QLc,2367
3
- checkpointer/function_body.py,sha256=vBZNdPuF8gp11Z_NvjoClmMToHR8ynY7tPykE2u25oE,1577
4
- checkpointer/storage.py,sha256=Ofuh0dKF5vk4_B4djt3Q6qyZhIO5f59uCNCZjMrto0U,1782
5
- checkpointer/utils.py,sha256=CC3-W0RgHEP92zt9atAM2gP0fDrxANOdyMH8F5xaIdU,357
6
- checkpointer/storages/bcolz_storage.py,sha256=Yk7FI75noe9hZBWVFIRetiFSR7tkzbryYlBmxX-lVlw,2728
7
- checkpointer/storages/memory_storage.py,sha256=S4SgKSApbQE-pxxKRWLNJqyZMRQwaw5-N0DOIsZM7mE,364
8
- checkpointer/storages/pickle_storage.py,sha256=zcnX1GG6XPHvVxi7gCab5oFxKoz5E7LZHYH74VL1hkY,1542
9
- checkpointer-1.2.0.dist-info/METADATA,sha256=9oGYe0jvzZfuMtleqqpGOBlJspbxxJTynXv7vuA2G6o,1349
10
- checkpointer-1.2.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
11
- checkpointer-1.2.0.dist-info/licenses/LICENSE,sha256=0cmUKqBotzbBcysIexd52AhjwbphhlGYiWbvg5l2QAU,1054
12
- checkpointer-1.2.0.dist-info/RECORD,,