torch-einops-utils 0.1.2__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/PKG-INFO +1 -1
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/pyproject.toml +1 -1
- torch_einops_utils-0.1.5/tests/test_nn.py +75 -0
- torch_einops_utils-0.1.5/torch_einops_utils/nn.py +52 -0
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/torch_einops_utils/save_load.py +12 -6
- torch_einops_utils-0.1.2/tests/test_nn.py +0 -28
- torch_einops_utils-0.1.2/torch_einops_utils/nn.py +0 -20
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/.gitignore +0 -0
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/LICENSE +0 -0
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/README.md +0 -0
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/tests/test_device.py +0 -0
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/tests/test_save_load.py +0 -0
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/tests/test_utils.py +0 -0
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/torch_einops_utils/__init__.py +0 -0
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/torch_einops_utils/device.py +0 -0
- {torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/torch_einops_utils/torch_einops_utils.py +0 -0
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch_einops_utils.nn import Sequential, Lambda, Identity, Residual, count_parameters
|
|
4
|
+
|
|
5
|
+
def test_sequential():
|
|
6
|
+
# Test that it filters out None
|
|
7
|
+
seq = Sequential(
|
|
8
|
+
nn.Linear(10, 10),
|
|
9
|
+
None,
|
|
10
|
+
nn.ReLU()
|
|
11
|
+
)
|
|
12
|
+
assert len(seq) == 2
|
|
13
|
+
|
|
14
|
+
# Test forward pass
|
|
15
|
+
x = torch.randn(2, 10)
|
|
16
|
+
out = seq(x)
|
|
17
|
+
assert out.shape == (2, 10)
|
|
18
|
+
|
|
19
|
+
def test_lambda():
|
|
20
|
+
fn = lambda x: x * 2
|
|
21
|
+
lam = Lambda(fn)
|
|
22
|
+
x = torch.tensor([1., 2., 3.])
|
|
23
|
+
assert torch.allclose(lam(x), torch.tensor([2., 4., 6.]))
|
|
24
|
+
|
|
25
|
+
def test_identity():
|
|
26
|
+
ident = Identity()
|
|
27
|
+
x = torch.tensor([1., 2., 3.])
|
|
28
|
+
assert torch.allclose(ident(x), x)
|
|
29
|
+
|
|
30
|
+
def test_residual():
|
|
31
|
+
fn = lambda x: x * 2
|
|
32
|
+
res = Residual(fn)
|
|
33
|
+
x = torch.tensor([1., 2., 3.])
|
|
34
|
+
assert torch.allclose(res(x), torch.tensor([3., 6., 9.]))
|
|
35
|
+
|
|
36
|
+
def fn_tuple(x):
|
|
37
|
+
return x * 2, x * 3, dict(a = x * 4)
|
|
38
|
+
|
|
39
|
+
res_tuple = Residual(fn_tuple)
|
|
40
|
+
out1, out2, out3 = res_tuple(x)
|
|
41
|
+
assert torch.allclose(out1, torch.tensor([3., 6., 9.]))
|
|
42
|
+
assert torch.allclose(out2, torch.tensor([3., 6., 9.]))
|
|
43
|
+
assert torch.allclose(out3['a'], torch.tensor([4., 8., 12.]))
|
|
44
|
+
|
|
45
|
+
def test_count_parameters():
|
|
46
|
+
model = nn.Linear(10, 10)
|
|
47
|
+
assert count_parameters(model) == 110
|
|
48
|
+
|
|
49
|
+
# Test requires_grad filter
|
|
50
|
+
model.bias.requires_grad_(False)
|
|
51
|
+
assert count_parameters(model) == 110
|
|
52
|
+
assert count_parameters(model, requires_grad=True) == 100
|
|
53
|
+
assert count_parameters(model, requires_grad=False) == 10
|
|
54
|
+
|
|
55
|
+
# Test as a decorator
|
|
56
|
+
@count_parameters
|
|
57
|
+
class MyModel(nn.Module):
|
|
58
|
+
def __init__(self):
|
|
59
|
+
super().__init__()
|
|
60
|
+
self.linear = nn.Linear(10, 10)
|
|
61
|
+
self.linear.bias.requires_grad_(False)
|
|
62
|
+
|
|
63
|
+
my_model = MyModel()
|
|
64
|
+
assert my_model.num_parameters == 110
|
|
65
|
+
|
|
66
|
+
# Test as a decorator with kwargs
|
|
67
|
+
@count_parameters(requires_grad=True)
|
|
68
|
+
class MyModelTrainable(nn.Module):
|
|
69
|
+
def __init__(self):
|
|
70
|
+
super().__init__()
|
|
71
|
+
self.linear = nn.Linear(10, 10)
|
|
72
|
+
self.linear.bias.requires_grad_(False)
|
|
73
|
+
|
|
74
|
+
my_model_trainable = MyModelTrainable()
|
|
75
|
+
assert my_model_trainable.num_parameters == 100
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Callable
|
|
4
|
+
from functools import partial
|
|
5
|
+
|
|
6
|
+
from torch.nn import Module
|
|
7
|
+
from torch.nn import Sequential as PyTorchSequential
|
|
8
|
+
|
|
9
|
+
from torch_einops_utils.torch_einops_utils import exists, tree_flatten_with_inverse
|
|
10
|
+
|
|
11
|
+
# helpers / functions
|
|
12
|
+
|
|
13
|
+
def count_parameters(model_or_class = None, *, requires_grad: bool | None = None):
|
|
14
|
+
if not exists(model_or_class):
|
|
15
|
+
return partial(count_parameters, requires_grad = requires_grad)
|
|
16
|
+
|
|
17
|
+
def _count(model):
|
|
18
|
+
return sum(p.numel() for p in model.parameters() if not exists(requires_grad) or p.requires_grad == requires_grad)
|
|
19
|
+
|
|
20
|
+
if isinstance(model_or_class, type) and issubclass(model_or_class, Module):
|
|
21
|
+
model_or_class.num_parameters = property(_count)
|
|
22
|
+
return model_or_class
|
|
23
|
+
|
|
24
|
+
return _count(model_or_class)
|
|
25
|
+
|
|
26
|
+
def Sequential(*modules):
|
|
27
|
+
return PyTorchSequential(*filter(exists, modules))
|
|
28
|
+
|
|
29
|
+
# classes
|
|
30
|
+
|
|
31
|
+
class Identity(Module):
|
|
32
|
+
def forward(self, t, *args, **kwargs):
|
|
33
|
+
return t
|
|
34
|
+
|
|
35
|
+
class Lambda(Module):
|
|
36
|
+
def __init__(self, fn: Callable):
|
|
37
|
+
super().__init__()
|
|
38
|
+
self.fn = fn
|
|
39
|
+
|
|
40
|
+
def forward(self, t, *args, **kwargs):
|
|
41
|
+
return self.fn(t, *args, **kwargs)
|
|
42
|
+
|
|
43
|
+
class Residual(Module):
|
|
44
|
+
def __init__(self, fn: Callable):
|
|
45
|
+
super().__init__()
|
|
46
|
+
self.fn = fn
|
|
47
|
+
|
|
48
|
+
def forward(self, x, *args, **kwargs):
|
|
49
|
+
out = self.fn(x, *args, **kwargs)
|
|
50
|
+
|
|
51
|
+
(first, *rest), inverse = tree_flatten_with_inverse(out)
|
|
52
|
+
return inverse((first + x, *rest))
|
|
@@ -67,8 +67,10 @@ def save_load(
|
|
|
67
67
|
setattr(self, config_instance_var_name, (args, kwargs))
|
|
68
68
|
_orig_init(self, *args, **kwargs)
|
|
69
69
|
|
|
70
|
-
def _save(self, path, overwrite = True):
|
|
71
|
-
|
|
70
|
+
def _save(self, path: str | Path, overwrite = True):
|
|
71
|
+
if isinstance(path, str):
|
|
72
|
+
path = Path(path)
|
|
73
|
+
|
|
72
74
|
assert overwrite or not path.exists()
|
|
73
75
|
|
|
74
76
|
config = getattr(self, config_instance_var_name)
|
|
@@ -80,8 +82,10 @@ def save_load(
|
|
|
80
82
|
|
|
81
83
|
torch.save(pkg, str(path))
|
|
82
84
|
|
|
83
|
-
def _load(self, path, strict = True):
|
|
84
|
-
|
|
85
|
+
def _load(self, path: str | Path, strict = True):
|
|
86
|
+
if isinstance(path, str):
|
|
87
|
+
path = Path(path)
|
|
88
|
+
|
|
85
89
|
assert path.exists()
|
|
86
90
|
|
|
87
91
|
pkg = torch.load(str(path), map_location = 'cpu')
|
|
@@ -95,8 +99,10 @@ def save_load(
|
|
|
95
99
|
# looks for a `config` key in the stored checkpoint, instantiating the model as well as loading the state dict
|
|
96
100
|
|
|
97
101
|
@classmethod
|
|
98
|
-
def _init_and_load_from(cls, path, strict = True):
|
|
99
|
-
|
|
102
|
+
def _init_and_load_from(cls, path: str | Path, strict = True):
|
|
103
|
+
if isinstance(path, str):
|
|
104
|
+
path = Path(path)
|
|
105
|
+
|
|
100
106
|
assert path.exists()
|
|
101
107
|
pkg = torch.load(str(path), map_location = 'cpu')
|
|
102
108
|
|
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from torch import nn
|
|
3
|
-
from torch_einops_utils.nn import Sequential, Lambda, Identity
|
|
4
|
-
|
|
5
|
-
def test_sequential():
|
|
6
|
-
# Test that it filters out None
|
|
7
|
-
seq = Sequential(
|
|
8
|
-
nn.Linear(10, 10),
|
|
9
|
-
None,
|
|
10
|
-
nn.ReLU()
|
|
11
|
-
)
|
|
12
|
-
assert len(seq) == 2
|
|
13
|
-
|
|
14
|
-
# Test forward pass
|
|
15
|
-
x = torch.randn(2, 10)
|
|
16
|
-
out = seq(x)
|
|
17
|
-
assert out.shape == (2, 10)
|
|
18
|
-
|
|
19
|
-
def test_lambda():
|
|
20
|
-
fn = lambda x: x * 2
|
|
21
|
-
lam = Lambda(fn)
|
|
22
|
-
x = torch.tensor([1., 2., 3.])
|
|
23
|
-
assert torch.allclose(lam(x), torch.tensor([2., 4., 6.]))
|
|
24
|
-
|
|
25
|
-
def test_identity():
|
|
26
|
-
ident = Identity()
|
|
27
|
-
x = torch.tensor([1., 2., 3.])
|
|
28
|
-
assert torch.allclose(ident(x), x)
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
from typing import Callable
|
|
3
|
-
|
|
4
|
-
from torch.nn import Module, Sequential as PyTorchSequential
|
|
5
|
-
from torch_einops_utils.torch_einops_utils import exists
|
|
6
|
-
|
|
7
|
-
def Sequential(*modules):
|
|
8
|
-
return PyTorchSequential(*filter(exists, modules))
|
|
9
|
-
|
|
10
|
-
class Identity(Module):
|
|
11
|
-
def forward(self, t, *args, **kwargs):
|
|
12
|
-
return t
|
|
13
|
-
|
|
14
|
-
class Lambda(Module):
|
|
15
|
-
def __init__(self, fn: Callable):
|
|
16
|
-
super().__init__()
|
|
17
|
-
self.fn = fn
|
|
18
|
-
|
|
19
|
-
def forward(self, t, *args, **kwargs):
|
|
20
|
-
return self.fn(t, *args, **kwargs)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{torch_einops_utils-0.1.2 → torch_einops_utils-0.1.5}/torch_einops_utils/torch_einops_utils.py
RENAMED
|
File without changes
|