torch-einops-utils 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_einops_utils/__init__.py +1 -0
- torch_einops_utils/save_load.py +85 -0
- torch_einops_utils/torch_einops_utils.py +19 -0
- {torch_einops_utils-0.0.15.dist-info → torch_einops_utils-0.0.17.dist-info}/METADATA +1 -1
- torch_einops_utils-0.0.17.dist-info/RECORD +7 -0
- torch_einops_utils-0.0.15.dist-info/RECORD +0 -6
- {torch_einops_utils-0.0.15.dist-info → torch_einops_utils-0.0.17.dist-info}/WHEEL +0 -0
- {torch_einops_utils-0.0.15.dist-info → torch_einops_utils-0.0.17.dist-info}/licenses/LICENSE +0 -0
torch_einops_utils/__init__.py
CHANGED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from packaging import version as packaging_version
|
|
4
|
+
|
|
5
|
+
import pickle
|
|
6
|
+
from functools import wraps
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch.nn import Module
|
|
10
|
+
|
|
11
|
+
# helpers
|
|
12
|
+
|
|
13
|
+
def exists(v):
    """Report whether `v` holds a value, i.e. is anything other than None."""
    return not (v is None)
|
|
15
|
+
|
|
16
|
+
def save_load(
|
|
17
|
+
save_method_name = 'save',
|
|
18
|
+
load_method_name = 'load',
|
|
19
|
+
config_instance_var_name = '_config',
|
|
20
|
+
init_and_load_classmethod_name = 'init_and_load',
|
|
21
|
+
version: str | None = None
|
|
22
|
+
):
|
|
23
|
+
def _save_load(klass):
|
|
24
|
+
assert issubclass(klass, Module), 'save_load should decorate a subclass of torch.nn.Module'
|
|
25
|
+
|
|
26
|
+
_orig_init = klass.__init__
|
|
27
|
+
|
|
28
|
+
@wraps(_orig_init)
|
|
29
|
+
def __init__(self, *args, **kwargs):
|
|
30
|
+
_config = pickle.dumps((args, kwargs))
|
|
31
|
+
|
|
32
|
+
setattr(self, config_instance_var_name, _config)
|
|
33
|
+
_orig_init(self, *args, **kwargs)
|
|
34
|
+
|
|
35
|
+
def _save(self, path, overwrite = True):
|
|
36
|
+
path = Path(path)
|
|
37
|
+
assert overwrite or not path.exists()
|
|
38
|
+
|
|
39
|
+
pkg = dict(
|
|
40
|
+
model = self.state_dict(),
|
|
41
|
+
config = getattr(self, config_instance_var_name),
|
|
42
|
+
version = version,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
torch.save(pkg, str(path))
|
|
46
|
+
|
|
47
|
+
def _load(self, path, strict = True):
|
|
48
|
+
path = Path(path)
|
|
49
|
+
assert path.exists()
|
|
50
|
+
|
|
51
|
+
pkg = torch.load(str(path), map_location = 'cpu')
|
|
52
|
+
|
|
53
|
+
if exists(version) and exists(pkg['version']) and packaging_version.parse(version) != packaging_version.parse(pkg['version']):
|
|
54
|
+
print(f'loading saved model at version {pkg["version"]}, but current package version is {version}')
|
|
55
|
+
|
|
56
|
+
self.load_state_dict(pkg['model'], strict = strict)
|
|
57
|
+
|
|
58
|
+
# init and load from
|
|
59
|
+
# looks for a `config` key in the stored checkpoint, instantiating the model as well as loading the state dict
|
|
60
|
+
|
|
61
|
+
@classmethod
|
|
62
|
+
def _init_and_load_from(cls, path, strict = True):
|
|
63
|
+
path = Path(path)
|
|
64
|
+
assert path.exists()
|
|
65
|
+
pkg = torch.load(str(path), map_location = 'cpu')
|
|
66
|
+
|
|
67
|
+
assert 'config' in pkg, 'model configs were not found in this saved checkpoint'
|
|
68
|
+
|
|
69
|
+
config = pickle.loads(pkg['config'])
|
|
70
|
+
args, kwargs = config
|
|
71
|
+
model = cls(*args, **kwargs)
|
|
72
|
+
|
|
73
|
+
_load(model, path, strict = strict)
|
|
74
|
+
return model
|
|
75
|
+
|
|
76
|
+
# set decorated init as well as save, load, and init_and_load
|
|
77
|
+
|
|
78
|
+
klass.__init__ = __init__
|
|
79
|
+
setattr(klass, save_method_name, _save)
|
|
80
|
+
setattr(klass, load_method_name, _load)
|
|
81
|
+
setattr(klass, init_and_load_classmethod_name, _init_and_load_from)
|
|
82
|
+
|
|
83
|
+
return klass
|
|
84
|
+
|
|
85
|
+
return _save_load
|
|
@@ -61,6 +61,25 @@ def masked_mean(
|
|
|
61
61
|
|
|
62
62
|
return num / den.clamp(min = eps)
|
|
63
63
|
|
|
64
|
+
# shapes
|
|
65
|
+
|
|
66
|
+
def shape_with_replace(
|
|
67
|
+
t,
|
|
68
|
+
replace_dict: dict[int, int] | None = None
|
|
69
|
+
):
|
|
70
|
+
shape = t.shape
|
|
71
|
+
|
|
72
|
+
if not exists(replace_dict):
|
|
73
|
+
return shape
|
|
74
|
+
|
|
75
|
+
shape_list = list(shape)
|
|
76
|
+
|
|
77
|
+
for index, value in replace_dict.items():
|
|
78
|
+
assert index < len(shape_list)
|
|
79
|
+
shape_list[index] = value
|
|
80
|
+
|
|
81
|
+
return torch.Size(shape_list)
|
|
82
|
+
|
|
64
83
|
# slicing
|
|
65
84
|
|
|
66
85
|
def slice_at_dim(t, slc, dim = -1):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torch-einops-utils
|
|
3
|
-
Version: 0.0.15
|
|
3
|
+
Version: 0.0.17
|
|
4
4
|
Summary: Personal utility functions
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/torch-einops-utils/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/torch-einops-utils
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
torch_einops_utils/__init__.py,sha256=F029-BB58UkzkTnXs8odeQWfjkyKnvBBV-7fIg3Drj0,899
|
|
2
|
+
torch_einops_utils/save_load.py,sha256=K-i7nmLyXBHdAfBLN3rGQzI3NVf6RRwF_GcrKQnfQsc,2669
|
|
3
|
+
torch_einops_utils/torch_einops_utils.py,sha256=UpcNBm4XkYF8NCdGGCgD2H1nlikTdtU3JEFWBzMoJjw,5855
|
|
4
|
+
torch_einops_utils-0.0.17.dist-info/METADATA,sha256=OMNDwWMl8oFEH1K1RvL91U-NtjSk5yLSW7XlLIg2iZU,2139
|
|
5
|
+
torch_einops_utils-0.0.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
6
|
+
torch_einops_utils-0.0.17.dist-info/licenses/LICENSE,sha256=e6AOF7Z8EFdK3IdcL0x0fLw4cY7Q0d0kNR0o0TmBewM,1066
|
|
7
|
+
torch_einops_utils-0.0.17.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
torch_einops_utils/__init__.py,sha256=STbYPcW6LEF-ggmMVP_xkwKhyVCN4aHQC-xXGsDN6KA,875
|
|
2
|
-
torch_einops_utils/torch_einops_utils.py,sha256=PZ37JN6Nd7FPYxB_hBgdX6dNZzLpilvtymVtfcF9pFc,5503
|
|
3
|
-
torch_einops_utils-0.0.15.dist-info/METADATA,sha256=XiUbDajKh646ipm8cns-VY0GZVECVMWt3-yoVNhPZpU,2139
|
|
4
|
-
torch_einops_utils-0.0.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
torch_einops_utils-0.0.15.dist-info/licenses/LICENSE,sha256=e6AOF7Z8EFdK3IdcL0x0fLw4cY7Q0d0kNR0o0TmBewM,1066
|
|
6
|
-
torch_einops_utils-0.0.15.dist-info/RECORD,,
|
|
File without changes
|
{torch_einops_utils-0.0.15.dist-info → torch_einops_utils-0.0.17.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|