torch-einops-utils 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,7 @@
1
1
  from torch_einops_utils.torch_einops_utils import (
2
2
  maybe,
3
3
  masked_mean,
4
+ shape_with_replace,
4
5
  slice_at_dim,
5
6
  slice_left_at_dim,
6
7
  slice_right_at_dim
@@ -0,0 +1,85 @@
1
+ from __future__ import annotations
2
+ from pathlib import Path
3
+ from packaging import version as packaging_version
4
+
5
+ import pickle
6
+ from functools import wraps
7
+
8
+ import torch
9
+ from torch.nn import Module
10
+
11
+ # helpers
12
+
13
def exists(v):
    """Return True when `v` is anything other than None (falsy values such as 0 or '' still count as existing)."""
    is_missing = v is None
    return not is_missing
15
+
16
def save_load(
    save_method_name = 'save',
    load_method_name = 'load',
    config_instance_var_name = '_config',
    init_and_load_classmethod_name = 'init_and_load',
    version: str | None = None
):
    """
    Class decorator that adds checkpoint save / load methods to a `torch.nn.Module` subclass.

    The decorated class gets
      - a wrapped `__init__` that pickles its `(args, kwargs)` onto the instance
      - a save method, writing the state dict, the pickled init arguments and `version` with `torch.save`
      - a load method, restoring the state dict from such a checkpoint into an existing instance
      - a classmethod that re-instantiates the class from the stored init arguments, then loads the state dict

    Args:
        save_method_name: attribute name under which the save method is attached
        load_method_name: attribute name under which the load method is attached
        config_instance_var_name: instance attribute name holding the pickled `(args, kwargs)`
        init_and_load_classmethod_name: attribute name under which the init-and-load classmethod is attached
        version: optional version string stored in every checkpoint; on load it is compared
            against the checkpoint's version and a notice is printed when they differ

    Returns:
        the decorator to apply to the class (the class is modified in place and returned)
    """

    def _save_load(klass):
        assert issubclass(klass, Module), 'save_load should decorate a subclass of torch.nn.Module'

        _orig_init = klass.__init__

        @wraps(_orig_init)
        def __init__(self, *args, **kwargs):
            # init arguments are serialized up front, so they must be picklable
            _config = pickle.dumps((args, kwargs))

            setattr(self, config_instance_var_name, _config)
            _orig_init(self, *args, **kwargs)

        # private helpers shared by load and init_and_load, so a checkpoint is only read once

        def _read_pkg(path):
            path = Path(path)
            assert path.exists()

            return torch.load(str(path), map_location = 'cpu')

        def _load_from_pkg(self, pkg, strict = True):
            # checkpoints not produced by the save method may lack a `version` key, treat that as unversioned
            saved_version = pkg.get('version')

            if version is not None and saved_version is not None and packaging_version.parse(version) != packaging_version.parse(saved_version):
                print(f'loading saved model at version {saved_version}, but current package version is {version}')

            self.load_state_dict(pkg['model'], strict = strict)

        def _save(self, path, overwrite = True):
            """Write the state dict, pickled init arguments and version to `path`. With `overwrite = False`, asserts that `path` does not already exist."""
            path = Path(path)
            assert overwrite or not path.exists()

            pkg = dict(
                model = self.state_dict(),
                config = getattr(self, config_instance_var_name),
                version = version,
            )

            torch.save(pkg, str(path))

        def _load(self, path, strict = True):
            """Load the state dict stored at `path` into this instance. `strict` is forwarded to `load_state_dict`."""
            _load_from_pkg(self, _read_pkg(path), strict = strict)

        # init and load from
        # looks for a `config` key in the stored checkpoint, instantiating the model as well as loading the state dict

        @classmethod
        def _init_and_load_from(cls, path, strict = True):
            """Instantiate `cls` from the init arguments stored in the checkpoint at `path`, load its state dict, and return the model."""
            pkg = _read_pkg(path)

            assert 'config' in pkg, 'model configs were not found in this saved checkpoint'

            # NOTE: the stored config is unpickled, which can execute arbitrary code - only load checkpoints from trusted sources
            args, kwargs = pickle.loads(pkg['config'])
            model = cls(*args, **kwargs)

            _load_from_pkg(model, pkg, strict = strict)
            return model

        # set decorated init as well as save, load, and init_and_load

        klass.__init__ = __init__
        setattr(klass, save_method_name, _save)
        setattr(klass, load_method_name, _load)
        setattr(klass, init_and_load_classmethod_name, _init_and_load_from)

        return klass

    return _save_load
@@ -61,6 +61,25 @@ def masked_mean(
61
61
 
62
62
  return num / den.clamp(min = eps)
63
63
 
64
+ # shapes
65
+
66
def shape_with_replace(
    t,
    replace_dict: dict[int, int] | None = None
):
    """
    Return the shape of `t`, optionally with some dimension sizes substituted.

    Args:
        t: object exposing a `.shape` attribute
        replace_dict: mapping of dimension index -> replacement size; when None, `t.shape` is returned untouched

    Returns:
        `t.shape` as is when no replacements are given, otherwise a new `torch.Size` with the replacements applied
    """
    original_shape = t.shape

    if exists(replace_dict):
        dims = [*original_shape]
        num_dims = len(dims)

        for dim, size in replace_dict.items():
            # each index must fall below the number of dimensions
            assert dim < num_dims
            dims[dim] = size

        return torch.Size(dims)

    return original_shape
82
+
64
83
  # slicing
65
84
 
66
85
  def slice_at_dim(t, slc, dim = -1):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-einops-utils
3
- Version: 0.0.15
3
+ Version: 0.0.17
4
4
  Summary: Personal utility functions
5
5
  Project-URL: Homepage, https://pypi.org/project/torch-einops-utils/
6
6
  Project-URL: Repository, https://github.com/lucidrains/torch-einops-utils
@@ -0,0 +1,7 @@
1
+ torch_einops_utils/__init__.py,sha256=F029-BB58UkzkTnXs8odeQWfjkyKnvBBV-7fIg3Drj0,899
2
+ torch_einops_utils/save_load.py,sha256=K-i7nmLyXBHdAfBLN3rGQzI3NVf6RRwF_GcrKQnfQsc,2669
3
+ torch_einops_utils/torch_einops_utils.py,sha256=UpcNBm4XkYF8NCdGGCgD2H1nlikTdtU3JEFWBzMoJjw,5855
4
+ torch_einops_utils-0.0.17.dist-info/METADATA,sha256=OMNDwWMl8oFEH1K1RvL91U-NtjSk5yLSW7XlLIg2iZU,2139
5
+ torch_einops_utils-0.0.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
6
+ torch_einops_utils-0.0.17.dist-info/licenses/LICENSE,sha256=e6AOF7Z8EFdK3IdcL0x0fLw4cY7Q0d0kNR0o0TmBewM,1066
7
+ torch_einops_utils-0.0.17.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- torch_einops_utils/__init__.py,sha256=STbYPcW6LEF-ggmMVP_xkwKhyVCN4aHQC-xXGsDN6KA,875
2
- torch_einops_utils/torch_einops_utils.py,sha256=PZ37JN6Nd7FPYxB_hBgdX6dNZzLpilvtymVtfcF9pFc,5503
3
- torch_einops_utils-0.0.15.dist-info/METADATA,sha256=XiUbDajKh646ipm8cns-VY0GZVECVMWt3-yoVNhPZpU,2139
4
- torch_einops_utils-0.0.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- torch_einops_utils-0.0.15.dist-info/licenses/LICENSE,sha256=e6AOF7Z8EFdK3IdcL0x0fLw4cY7Q0d0kNR0o0TmBewM,1066
6
- torch_einops_utils-0.0.15.dist-info/RECORD,,