torch-einops-utils 0.1.2__tar.gz → 0.1.5__tar.gz

This diff shows the differences between the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-einops-utils
3
- Version: 0.1.2
3
+ Version: 0.1.5
4
4
  Summary: Personal utility functions
5
5
  Project-URL: Homepage, https://pypi.org/project/torch-einops-utils/
6
6
  Project-URL: Repository, https://github.com/lucidrains/torch-einops-utils
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "torch-einops-utils"
3
- version = "0.1.2"
3
+ version = "0.1.5"
4
4
  description = "Personal utility functions"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,75 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch_einops_utils.nn import Sequential, Lambda, Identity, Residual, count_parameters
4
+
5
+ def test_sequential():
6
+ # Test that it filters out None
7
+ seq = Sequential(
8
+ nn.Linear(10, 10),
9
+ None,
10
+ nn.ReLU()
11
+ )
12
+ assert len(seq) == 2
13
+
14
+ # Test forward pass
15
+ x = torch.randn(2, 10)
16
+ out = seq(x)
17
+ assert out.shape == (2, 10)
18
+
19
+ def test_lambda():
20
+ fn = lambda x: x * 2
21
+ lam = Lambda(fn)
22
+ x = torch.tensor([1., 2., 3.])
23
+ assert torch.allclose(lam(x), torch.tensor([2., 4., 6.]))
24
+
25
+ def test_identity():
26
+ ident = Identity()
27
+ x = torch.tensor([1., 2., 3.])
28
+ assert torch.allclose(ident(x), x)
29
+
30
+ def test_residual():
31
+ fn = lambda x: x * 2
32
+ res = Residual(fn)
33
+ x = torch.tensor([1., 2., 3.])
34
+ assert torch.allclose(res(x), torch.tensor([3., 6., 9.]))
35
+
36
+ def fn_tuple(x):
37
+ return x * 2, x * 3, dict(a = x * 4)
38
+
39
+ res_tuple = Residual(fn_tuple)
40
+ out1, out2, out3 = res_tuple(x)
41
+ assert torch.allclose(out1, torch.tensor([3., 6., 9.]))
42
+ assert torch.allclose(out2, torch.tensor([3., 6., 9.]))
43
+ assert torch.allclose(out3['a'], torch.tensor([4., 8., 12.]))
44
+
45
+ def test_count_parameters():
46
+ model = nn.Linear(10, 10)
47
+ assert count_parameters(model) == 110
48
+
49
+ # Test requires_grad filter
50
+ model.bias.requires_grad_(False)
51
+ assert count_parameters(model) == 110
52
+ assert count_parameters(model, requires_grad=True) == 100
53
+ assert count_parameters(model, requires_grad=False) == 10
54
+
55
+ # Test as a decorator
56
+ @count_parameters
57
+ class MyModel(nn.Module):
58
+ def __init__(self):
59
+ super().__init__()
60
+ self.linear = nn.Linear(10, 10)
61
+ self.linear.bias.requires_grad_(False)
62
+
63
+ my_model = MyModel()
64
+ assert my_model.num_parameters == 110
65
+
66
+ # Test as a decorator with kwargs
67
+ @count_parameters(requires_grad=True)
68
+ class MyModelTrainable(nn.Module):
69
+ def __init__(self):
70
+ super().__init__()
71
+ self.linear = nn.Linear(10, 10)
72
+ self.linear.bias.requires_grad_(False)
73
+
74
+ my_model_trainable = MyModelTrainable()
75
+ assert my_model_trainable.num_parameters == 100
@@ -0,0 +1,52 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable
4
+ from functools import partial
5
+
6
+ from torch.nn import Module
7
+ from torch.nn import Sequential as PyTorchSequential
8
+
9
+ from torch_einops_utils.torch_einops_utils import exists, tree_flatten_with_inverse
10
+
11
+ # helpers / functions
12
+
13
+ def count_parameters(model_or_class = None, *, requires_grad: bool | None = None):
14
+ if not exists(model_or_class):
15
+ return partial(count_parameters, requires_grad = requires_grad)
16
+
17
+ def _count(model):
18
+ return sum(p.numel() for p in model.parameters() if not exists(requires_grad) or p.requires_grad == requires_grad)
19
+
20
+ if isinstance(model_or_class, type) and issubclass(model_or_class, Module):
21
+ model_or_class.num_parameters = property(_count)
22
+ return model_or_class
23
+
24
+ return _count(model_or_class)
25
+
26
+ def Sequential(*modules):
27
+ return PyTorchSequential(*filter(exists, modules))
28
+
29
+ # classes
30
+
31
+ class Identity(Module):
32
+ def forward(self, t, *args, **kwargs):
33
+ return t
34
+
35
+ class Lambda(Module):
36
+ def __init__(self, fn: Callable):
37
+ super().__init__()
38
+ self.fn = fn
39
+
40
+ def forward(self, t, *args, **kwargs):
41
+ return self.fn(t, *args, **kwargs)
42
+
43
+ class Residual(Module):
44
+ def __init__(self, fn: Callable):
45
+ super().__init__()
46
+ self.fn = fn
47
+
48
+ def forward(self, x, *args, **kwargs):
49
+ out = self.fn(x, *args, **kwargs)
50
+
51
+ (first, *rest), inverse = tree_flatten_with_inverse(out)
52
+ return inverse((first + x, *rest))
@@ -67,8 +67,10 @@ def save_load(
67
67
  setattr(self, config_instance_var_name, (args, kwargs))
68
68
  _orig_init(self, *args, **kwargs)
69
69
 
70
- def _save(self, path, overwrite = True):
71
- path = Path(path)
70
+ def _save(self, path: str | Path, overwrite = True):
71
+ if isinstance(path, str):
72
+ path = Path(path)
73
+
72
74
  assert overwrite or not path.exists()
73
75
 
74
76
  config = getattr(self, config_instance_var_name)
@@ -80,8 +82,10 @@ def save_load(
80
82
 
81
83
  torch.save(pkg, str(path))
82
84
 
83
- def _load(self, path, strict = True):
84
- path = Path(path)
85
+ def _load(self, path: str | Path, strict = True):
86
+ if isinstance(path, str):
87
+ path = Path(path)
88
+
85
89
  assert path.exists()
86
90
 
87
91
  pkg = torch.load(str(path), map_location = 'cpu')
@@ -95,8 +99,10 @@ def save_load(
95
99
  # looks for a `config` key in the stored checkpoint, instantiating the model as well as loading the state dict
96
100
 
97
101
  @classmethod
98
- def _init_and_load_from(cls, path, strict = True):
99
- path = Path(path)
102
+ def _init_and_load_from(cls, path: str | Path, strict = True):
103
+ if isinstance(path, str):
104
+ path = Path(path)
105
+
100
106
  assert path.exists()
101
107
  pkg = torch.load(str(path), map_location = 'cpu')
102
108
 
@@ -1,28 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch_einops_utils.nn import Sequential, Lambda, Identity
4
-
5
- def test_sequential():
6
- # Test that it filters out None
7
- seq = Sequential(
8
- nn.Linear(10, 10),
9
- None,
10
- nn.ReLU()
11
- )
12
- assert len(seq) == 2
13
-
14
- # Test forward pass
15
- x = torch.randn(2, 10)
16
- out = seq(x)
17
- assert out.shape == (2, 10)
18
-
19
- def test_lambda():
20
- fn = lambda x: x * 2
21
- lam = Lambda(fn)
22
- x = torch.tensor([1., 2., 3.])
23
- assert torch.allclose(lam(x), torch.tensor([2., 4., 6.]))
24
-
25
- def test_identity():
26
- ident = Identity()
27
- x = torch.tensor([1., 2., 3.])
28
- assert torch.allclose(ident(x), x)
@@ -1,20 +0,0 @@
1
- from __future__ import annotations
2
- from typing import Callable
3
-
4
- from torch.nn import Module, Sequential as PyTorchSequential
5
- from torch_einops_utils.torch_einops_utils import exists
6
-
7
- def Sequential(*modules):
8
- return PyTorchSequential(*filter(exists, modules))
9
-
10
- class Identity(Module):
11
- def forward(self, t, *args, **kwargs):
12
- return t
13
-
14
- class Lambda(Module):
15
- def __init__(self, fn: Callable):
16
- super().__init__()
17
- self.fn = fn
18
-
19
- def forward(self, t, *args, **kwargs):
20
- return self.fn(t, *args, **kwargs)