RRAEsTorch 0.1.0 (rraestorch-0.1.0-py3-none-any.whl)
- RRAEsTorch/AE_base/AE_base.py +104 -0
- RRAEsTorch/AE_base/__init__.py +1 -0
- RRAEsTorch/AE_classes/AE_classes.py +636 -0
- RRAEsTorch/AE_classes/__init__.py +1 -0
- RRAEsTorch/__init__.py +1 -0
- RRAEsTorch/config.py +95 -0
- RRAEsTorch/tests/test_AE_classes_CNN.py +76 -0
- RRAEsTorch/tests/test_AE_classes_MLP.py +73 -0
- RRAEsTorch/tests/test_fitting_CNN.py +109 -0
- RRAEsTorch/tests/test_fitting_MLP.py +133 -0
- RRAEsTorch/tests/test_mains.py +34 -0
- RRAEsTorch/tests/test_save.py +62 -0
- RRAEsTorch/tests/test_stable_SVD.py +37 -0
- RRAEsTorch/tests/test_wrappers.py +56 -0
- RRAEsTorch/trackers/__init__.py +1 -0
- RRAEsTorch/trackers/trackers.py +245 -0
- RRAEsTorch/training_classes/__init__.py +5 -0
- RRAEsTorch/training_classes/training_classes.py +977 -0
- RRAEsTorch/utilities/__init__.py +1 -0
- RRAEsTorch/utilities/utilities.py +1562 -0
- RRAEsTorch/wrappers/__init__.py +1 -0
- RRAEsTorch/wrappers/wrappers.py +237 -0
- rraestorch-0.1.0.dist-info/METADATA +90 -0
- rraestorch-0.1.0.dist-info/RECORD +27 -0
- rraestorch-0.1.0.dist-info/WHEEL +4 -0
- rraestorch-0.1.0.dist-info/licenses/LICENSE +21 -0
- rraestorch-0.1.0.dist-info/licenses/LICENSE copy +21 -0
RRAEsTorch/config.py
ADDED
@@ -0,0 +1,95 @@
+from RRAEsTorch.utilities import MLP_with_linear
+import numpy.random as random
+from RRAEsTorch.AE_base import AE_base as rraes
+import torch
+from torch.func import vmap
+
+class Autoencoder(torch.nn.Module):
+    _encode: MLP_with_linear
+    _decode: MLP_with_linear
+    _perform_in_latent: callable
+    _perform_in_latent: callable
+    map_latent: bool
+    norm_funcs: list
+    inv_norm_funcs: list
+    count: int
+
+    def __init__(
+        self,
+        in_size,
+        latent_size,
+        latent_size_after=None,
+        _encode=None,
+        _decode=None,
+        map_latent=True,
+        *,
+        count=1,
+        kwargs_enc={},
+        kwargs_dec={},
+        **kwargs,
+    ):
+        super().__init__()
+        if latent_size_after is None:
+            latent_size_after = latent_size
+
+        if _encode is None:
+            if "width_size" not in kwargs_enc.keys():
+                kwargs_enc["width_size"] = 64
+
+            if "depth" not in kwargs_enc.keys():
+                kwargs_enc["depth"] = 1
+
+            self._encode = MLP_with_linear(
+                in_size=in_size,
+                out_size=latent_size,
+                **kwargs_enc,
+            )
+
+        else:
+            self._encode = _encode
+
+        if not hasattr(self, "_perform_in_latent"):
+            self._perform_in_latent = lambda x, *args, **kwargs: x
+
+        if _decode is None:
+            if "width_size" not in kwargs_dec.keys():
+                kwargs_dec["width_size"] = 64
+            if "depth" not in kwargs_dec.keys():
+                kwargs_dec["depth"] = 6
+
+            self._decode = MLP_with_linear(
+                in_size=latent_size_after,
+                out_size=in_size,
+                **kwargs_dec,
+            )
+        else:
+            self._decode = _decode
+
+        self.count = count
+        self.map_latent = map_latent
+        self.inv_norm_funcs = ["decode"]
+        self.norm_funcs = ["encode", "latent"]
+
+    def encode(self, x, *args, **kwargs):
+        return self._encode(x, *args, **kwargs)
+
+    def decode(self, x, *args, **kwargs):
+        return self._decode(x, *args, **kwargs)
+
+    def perform_in_latent(self, y, *args, **kwargs):
+        if self.map_latent:
+            new_perform_in_latent = lambda x: self._perform_in_latent(
+                x, *args, **kwargs
+            )
+            for _ in range(self.count):
+                new_perform_in_latent = vmap(new_perform_in_latent, in_dims=-1, out_dims=-1)
+            return new_perform_in_latent(y)
+        return self._perform_in_latent(y, *args, **kwargs)
+
+    def forward(self, x, *args, **kwargs):
+        return self.decode(self.perform_in_latent(self.encode(x), *args, **kwargs))
+
+    def latent(self, x, *args, **kwargs):
+        return self.perform_in_latent(self.encode(x), *args, **kwargs)
+
+rraes.set_autoencoder_base(Autoencoder)
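For context, config.py binds the package's generic AE machinery to torch: a default MLP_with_linear encoder (depth 1) and decoder (depth 6) are built when none are passed, the latent map (identity by default) is vectorized with torch.func.vmap over the trailing axis `count` times, and the class is registered through rraes.set_autoencoder_base. A minimal usage sketch, not part of the package diff; the sizes are illustrative, and it assumes MLP_with_linear accepts plain (in_size,) vectors like a standard MLP:

    import torch
    from RRAEsTorch.config import Autoencoder

    model = Autoencoder(in_size=128, latent_size=16)  # default MLP encoder/decoder
    x = torch.randn(128)
    z = model.latent(x)   # encode(x), then the (identity) latent map under vmap
    x_hat = model(x)      # decode(perform_in_latent(encode(x)))
    assert z.shape == (16,) and x_hat.shape == (128,)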
RRAEsTorch/tests/test_AE_classes_CNN.py
ADDED
@@ -0,0 +1,76 @@
+import RRAEsTorch.config
+import pytest
+from RRAEsTorch.AE_classes import (
+    RRAE_CNN,
+    Vanilla_AE_CNN,
+    IRMAE_CNN,
+    LoRAE_CNN,
+)
+from RRAEsTorch.wrappers import vmap_wrap
+import numpy.random as random
+import numpy as np
+import torch
+
+methods = ["encode", "decode"]
+
+v_RRAE_CNN = vmap_wrap(RRAE_CNN, -1, 1, methods)
+v_Vanilla_AE_CNN = vmap_wrap(Vanilla_AE_CNN, -1, 1, methods)
+v_IRMAE_CNN = vmap_wrap(IRMAE_CNN, -1, 1, methods)
+v_LoRAE_CNN = vmap_wrap(LoRAE_CNN, -1, 1, methods)
+
+@pytest.mark.parametrize("width", (10, 17, 149))
+@pytest.mark.parametrize("height", (20,))
+@pytest.mark.parametrize("latent", (200,))
+@pytest.mark.parametrize("num_modes", (1,))
+@pytest.mark.parametrize("channels", (1, 3, 5))
+@pytest.mark.parametrize("num_samples", (10, 100))
+class Test_AEs_shapes:
+    def test_RRAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
+        x = random.normal(size=(channels, width, height, num_samples))
+        x = torch.tensor(x, dtype=torch.float32)
+        kwargs = {"kwargs_dec": {"stride": 2}}
+        model = v_RRAE_CNN(
+            x.shape[0], x.shape[1], x.shape[2], latent, num_modes, **kwargs
+        )
+        y = model.encode(x)
+        assert y.shape == (latent, num_samples)
+        y = model.latent(x, k_max=num_modes)
+        _, sing_vals, _ = torch.linalg.svd(y, full_matrices=False)
+        assert sing_vals[num_modes + 1] < 1e-5
+        assert y.shape == (latent, num_samples)
+        assert model.decode(y).shape == (channels, width, height, num_samples)
+
+    def test_Vanilla_CNN(self, latent, num_modes, width, height, channels, num_samples):
+        x = random.normal(size=(channels, width, height, num_samples))
+        x = torch.tensor(x, dtype=torch.float32)
+        kwargs = {"kwargs_dec": {"stride": 2}}
+        model = v_Vanilla_AE_CNN(
+            x.shape[0], x.shape[1], x.shape[2], latent, **kwargs
+        )
+        y = model.encode(x)
+        assert y.shape == (latent, num_samples)
+        x = model.decode(y)
+        assert x.shape == (channels, width, height, num_samples)
+
+
+    def test_IRMAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
+        x = random.normal(size=(channels, width, height, num_samples))
+        x = torch.tensor(x, dtype=torch.float32)
+        model = v_IRMAE_CNN(
+            x.shape[0], x.shape[1], x.shape[2], latent, linear_l=2
+        )
+        y = model.encode(x)
+        assert y.shape == (latent, num_samples)
+        assert len(model._encode.layers[-1].layers_l) == 2
+        x = model.decode(y)
+        assert x.shape == (channels, width, height, num_samples)
+
+    def test_LoRAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
+        x = random.normal(size=(channels, width, height, num_samples))
+        x = torch.tensor(x, dtype=torch.float32)
+        model = v_LoRAE_CNN(x.shape[0], x.shape[1], x.shape[2], latent)
+        y = model.encode(x)
+        assert y.shape == (latent, num_samples)
+        assert len(model._encode.layers[-1].layers_l) == 1
+        x = model.decode(y)
+        assert x.shape == (channels, width, height, num_samples)
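From the usage above, vmap_wrap(cls, -1, 1, methods) evidently returns a class whose listed methods ("encode" and "decode") are vectorized once over the last axis, matching the count-fold vmap(..., in_dims=-1, out_dims=-1) pattern in config.py; that is why a batch of shape (channels, width, height, num_samples) encodes to (latent, num_samples), sample by sample along the trailing axis.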
RRAEsTorch/tests/test_AE_classes_MLP.py
ADDED
@@ -0,0 +1,73 @@
+import RRAEsTorch.config
+import pytest
+from RRAEsTorch.AE_classes import (
+    RRAE_MLP,
+    Vanilla_AE_MLP,
+    IRMAE_MLP,
+    LoRAE_MLP,
+)
+from RRAEsTorch.wrappers import vmap_wrap
+import numpy.random as random
+import numpy as np
+import torch
+
+methods = ["encode", "decode"]
+
+v_RRAE_MLP = vmap_wrap(RRAE_MLP, -1, 1, methods)
+v_Vanilla_AE_MLP = vmap_wrap(Vanilla_AE_MLP, -1, 1, methods)
+v_IRMAE_MLP = vmap_wrap(IRMAE_MLP, -1, 1, methods)
+v_LoRAE_MLP = vmap_wrap(LoRAE_MLP, -1, 1, methods)
+
+@pytest.mark.parametrize("dim_D", (10, 15, 50))
+@pytest.mark.parametrize("latent", (200, 400, 800))
+@pytest.mark.parametrize("num_modes", (1, 2, 6))
+class Test_AEs_shapes:
+    def test_RRAE_MLP(self, latent, num_modes, dim_D):
+        x = random.normal(size=(500, dim_D))
+        x = torch.tensor(x, dtype=torch.float32)
+        model = v_RRAE_MLP(x.shape[0], latent, num_modes)
+        y = model.encode(x)
+        assert y.shape == (latent, dim_D)
+        y = model.perform_in_latent(y, k_max=num_modes)
+        _, sing_vals, _ = torch.linalg.svd(y, full_matrices=False)
+        assert sing_vals[num_modes + 1] < 1e-5
+        assert y.shape == (latent, dim_D)
+        assert model.decode(y).shape == (500, dim_D)
+
+    def test_Vanilla_MLP(self, latent, num_modes, dim_D):
+        x = random.normal(size=(500, dim_D))
+        x = torch.tensor(x, dtype=torch.float32)
+        model = v_Vanilla_AE_MLP(x.shape[0], latent)
+        y = model.encode(x)
+        assert y.shape == (latent, dim_D)
+        x = model.decode(y)
+        assert x.shape == (500, dim_D)
+
+    def test_IRMAE_MLP(self, latent, num_modes, dim_D):
+        x = random.normal(size=(500, dim_D))
+        x = torch.tensor(x, dtype=torch.float32)
+        model = v_IRMAE_MLP(x.shape[0], latent, linear_l=2)
+        y = model.encode(x)
+        assert y.shape == (latent, dim_D)
+        assert len(model._encode.layers_l) == 2
+        x = model.decode(y)
+        assert x.shape == (500, dim_D)
+
+    def test_LoRAE_MLP(self, latent, num_modes, dim_D):
+        x = random.normal(size=(500, dim_D))
+        x = torch.tensor(x, dtype=torch.float32)
+        model = v_LoRAE_MLP(x.shape[0], latent)
+        y = model.encode(x)
+        assert y.shape == (latent, dim_D)
+        assert len(model._encode.layers_l) == 1
+        x = model.decode(y)
+        assert x.shape == (500, dim_D)
+
+def test_getting_SVD_coeffs():
+    data = random.uniform(size=(500, 15))
+    data = torch.tensor(data, dtype=torch.float32)
+    model_s = v_RRAE_MLP(data.shape[0], 200, 3)
+    basis, coeffs = model_s.latent(data, k_max=3, get_basis_coeffs=True)
+    assert basis.shape == (200, 3)
+    assert coeffs.shape == (3, 15)
+
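A note on test_getting_SVD_coeffs: with k_max=3, the truncated latent of shape (latent, samples) = (200, 15) is returned in factored form, a (200, 3) basis and a (3, 15) coefficient matrix, consistent with a rank-3 factorization latent ≈ basis @ coeffs.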
RRAEsTorch/tests/test_fitting_CNN.py
ADDED
@@ -0,0 +1,109 @@
+import RRAEsTorch.config
+import pytest
+from RRAEsTorch.AE_classes import (
+    RRAE_CNN,
+    Vanilla_AE_CNN,
+    IRMAE_CNN,
+    LoRAE_CNN,
+)
+from RRAEsTorch.training_classes import RRAE_Trainor_class, Trainor_class, AE_Trainor_class
+import numpy.random as random
+import torch
+
+@pytest.mark.parametrize(
+    "model_cls, sh, lf",
+    [
+        (Vanilla_AE_CNN, (1, 2, 2, 10), "default"),
+        (LoRAE_CNN, (6, 16, 16, 10), "nuc"),
+    ],
+)
+def test_AE_fitting(model_cls, sh, lf):
+    x = random.normal(size=sh)
+    x = torch.tensor(x, dtype=torch.float32)
+    trainor = AE_Trainor_class(
+        x,
+        model_cls,
+        latent_size=100,
+        channels=x.shape[0],
+        width=x.shape[1],
+        height=x.shape[2],
+        samples=x.shape[-1],  # Only for weak
+        k_max=2,
+    )
+    kwargs = {
+        "step_st": [2],
+        "loss_kwargs": {
+            "lambda_nuc": 0.001,
+            "find_layer": lambda model: model.encode.layers[-2].layers[-1].weight,
+            "loss_type": lf
+        },
+    }
+    try:
+        trainor.fit(
+            x,
+            x,
+            verbose=False,
+            training_kwargs=kwargs,
+        )
+    except Exception as e:
+        assert False, f"Fitting failed with the following exception {repr(e)}"
+
+
+def test_IRMAE_fitting():
+    model_cls = IRMAE_CNN
+    lf = "default"
+    sh = (3, 12, 12, 10)
+    x = random.normal(size=sh)
+    x = torch.tensor(x, dtype=torch.float32)
+    trainor = AE_Trainor_class(
+        x,
+        model_cls,
+        latent_size=100,
+        channels=x.shape[0],
+        width=x.shape[1],
+        height=x.shape[2],
+        k_max=2,
+        linear_l=4,
+    )
+    kwargs = {"step_st": [2], "loss_type": lf}
+    try:
+        trainor.fit(
+            x,
+            x,
+            verbose=False,
+            training_kwargs=kwargs,
+        )
+    except Exception as e:
+        assert False, f"Fitting failed with the following exception {repr(e)}"
+
+def test_RRAE_fitting():
+    sh = (1, 20, 20, 10)
+    model_cls = RRAE_CNN
+    x = random.normal(size=sh)
+    x = torch.tensor(x, dtype=torch.float32)
+    trainor = RRAE_Trainor_class(
+        x,
+        model_cls,
+        latent_size=100,
+        channels=x.shape[0],
+        width=x.shape[1],
+        height=x.shape[2],
+        k_max=2,
+    )
+    training_kwargs = {
+        "step_st": [2],
+        "loss_type": "default"
+    }
+    ft_kwargs = {
+        "step_st": [2],
+    }
+    try:
+        trainor.fit(
+            x,
+            x,
+            verbose=False,
+            training_kwargs=training_kwargs,
+            ft_kwargs=ft_kwargs,
+        )
+    except Exception as e:
+        assert False, f"Fitting failed with the following exception {repr(e)}"
RRAEsTorch/tests/test_fitting_MLP.py
ADDED
@@ -0,0 +1,133 @@
+import RRAEsTorch.config
+import numpy.random as random
+import pytest
+from RRAEsTorch.AE_classes import (
+    RRAE_MLP,
+    Vanilla_AE_MLP,
+    IRMAE_MLP,
+    LoRAE_MLP,
+)
+from torchvision.ops import MLP
+from RRAEsTorch.training_classes import RRAE_Trainor_class, Trainor_class, AE_Trainor_class
+import torch
+
+@pytest.mark.parametrize(
+    "model_cls, sh, lf",
+    [
+        (Vanilla_AE_MLP, (500, 10), "default"),
+        (LoRAE_MLP, (500, 10), "nuc"),
+    ],
+)
+def test_fitting(model_cls, sh, lf):
+    x = random.normal(size=sh)
+    x = torch.tensor(x, dtype=torch.float32)
+    trainor = AE_Trainor_class(
+        x,
+        model_cls,
+        in_size=x.shape[0],
+        data_size=x.shape[-1],
+        samples=x.shape[-1],  # Only for weak
+        norm_in="meanstd",
+        norm_out="minmax",
+        out_train=x,
+        latent_size=2000,
+        k_max=2,
+    )
+    kwargs = {
+        "step_st": [2],
+        "loss_kwargs": {"lambda_nuc": 0.001, "find_layer": lambda model: model._encode.layers_l[0].weight},
+        "loss_type": lf
+    }
+    try:
+        trainor.fit(
+            x,
+            x,
+            verbose=False,
+            training_kwargs=kwargs,
+        )
+    except Exception as e:
+        assert False, f"Fitting failed with the following exception {repr(e)}"
+
+
+
+def test_RRAE_fitting():
+    sh = (500, 10)
+    model_cls = RRAE_MLP
+    x = random.normal(size=sh)
+    x = torch.tensor(x, dtype=torch.float32)
+    trainor = RRAE_Trainor_class(
+        x,
+        model_cls,
+        in_size=x.shape[0],
+        latent_size=2000,
+        k_max=2,
+    )
+    training_kwargs = {
+        "step_st": [2],
+        "loss_type": "RRAE"
+    }
+    ft_kwargs = {
+        "step_st": [2],
+    }
+    try:
+        trainor.fit(
+            x,
+            x,
+            verbose=False,
+            training_kwargs=training_kwargs,
+            ft_kwargs=ft_kwargs,
+        )
+    except Exception as e:
+        assert False, f"Fitting failed with the following exception {repr(e)}"
+
+def test_IRMAE_fitting():
+    model_cls = IRMAE_MLP
+    lf = "default"
+    sh = (500, 10)
+    x = random.normal(size=sh)
+    x = torch.tensor(x, dtype=torch.float32)
+    trainor = AE_Trainor_class(
+        x,
+        model_cls,
+        in_size=x.shape[0],
+        data_size=x.shape[-1],
+        latent_size=2000,
+        k_max=2,
+        linear_l=4,
+    )
+    kwargs = {"step_st": [2], "loss_type": lf}
+    try:
+        trainor.fit(
+            x,
+            x,
+            verbose=False,
+            training_kwargs=kwargs,
+        )
+    except Exception as e:
+        assert False, f"Fitting failed with the following exception {repr(e)}"
+
+def test_fitting():
+    sh = (50, 100)
+    model_cls = MLP
+    x = random.normal(size=sh)
+    x = torch.tensor(x, dtype=torch.float32)
+    trainor = Trainor_class(
+        x,
+        model_cls,
+        in_channels=x.shape[0],
+        hidden_channels=[100, x.shape[0]]
+    )
+    training_kwargs = {
+        "step_st": [2],
+        "loss_type": "default"
+    }
+
+    try:
+        trainor.fit(
+            x,
+            x,
+            verbose=False,
+            **training_kwargs,
+        )
+    except Exception as e:
+        assert False, f"Fitting failed with the following exception {repr(e)}"
RRAEsTorch/tests/test_mains.py
ADDED
@@ -0,0 +1,34 @@
+import subprocess
+import os
+import pytest
+import shutil
+
+def try_remove(name):
+    try:
+        shutil.rmtree(name)
+    except FileNotFoundError:
+        pass
+
+def run_script(script_name):
+    try:
+        result = subprocess.run(
+            ["python", script_name], check=True, capture_output=True, text=True
+        )
+        try_remove("shift")
+        try_remove("folder_name")
+        try_remove("2d_gaussian_shift_scale")
+        try_remove("gaussian_shift")
+        return result.stdout
+    except subprocess.CalledProcessError as e:
+        pytest.fail(f"Error running {script_name}:\n{e.stderr}")
+
+
+@pytest.mark.parametrize(
+    "script_name", ["main-MLP.py", "main-CNN.py", "general-MLP.py", "main-adap-CNN.py", "main-adap-MLP.py", "main-var-CNN.py", "main-CNN1D.py"]
+)
+def test_scripts(script_name):
+    if os.path.exists(script_name):
+        output = run_script(script_name)
+        assert output is not None
+    else:
+        pytest.fail(f"Script {script_name} not found")
RRAEsTorch/tests/test_save.py
ADDED
@@ -0,0 +1,62 @@
+import RRAEsTorch.config
+from RRAEsTorch.AE_classes import RRAE_CNN
+from RRAEsTorch.training_classes import RRAE_Trainor_class
+import numpy.random as random
+import torch
+
+def test_save():  # Only to test if saving/loading is causing a problem
+    data = random.normal(size=(1, 28, 28, 1))
+    data = torch.tensor(data, dtype=torch.float32)
+    model_cls = RRAE_CNN
+
+    trainor = RRAE_Trainor_class(
+        data,
+        model_cls,
+        latent_size=100,
+        channels=data.shape[0],
+        width=data.shape[1],
+        height=data.shape[2],
+        pre_func_inp=lambda x: x * 2 / 17,
+        pre_func_out=lambda x: x / 2,
+        k_max=2,
+    )
+
+    trainor.save_model("test_")
+    new_trainor = RRAE_Trainor_class()
+    new_trainor.load_model("test_", erase=True)
+    try:
+        pr = trainor.model(data[..., 0:1], k_max=2)
+    except Exception as e:
+        raise ValueError(f"Original trainor failed with following exception {e}")
+    try:
+        pr = new_trainor.model(data[..., 0:1], k_max=2)
+    except Exception as e:
+        raise ValueError(f"Failed with following exception {e}")
+
+
+def test_save_with_final_act():
+    data = random.normal(size=(1, 28, 28, 1))
+    data = torch.tensor(data, dtype=torch.float32)
+
+    model_cls = RRAE_CNN
+
+    trainor = RRAE_Trainor_class(
+        data,
+        model_cls,
+        latent_size=100,
+        channels=data.shape[0],
+        width=data.shape[1],
+        height=data.shape[2],
+        kwargs_dec={"final_activation": torch.sigmoid},
+        k_max=2,
+    )
+
+    trainor.save_model("test_")
+    new_trainor = RRAE_Trainor_class()
+    new_trainor.load_model("test_", erase=True)
+    try:
+        pr = new_trainor.model(data[..., 0:1], k_max=2)
+        assert torch.max(pr) <= 1.0, "Final activation not working"
+        assert torch.min(pr) >= 0.0, "Final activation not working"
+    except Exception as e:
+        raise ValueError(f"Failed with following exception {e}")
RRAEsTorch/tests/test_stable_SVD.py
ADDED
@@ -0,0 +1,37 @@
+from RRAEsTorch.utilities import stable_SVD
+from torch.linalg import svd as normal_svd
+import numpy.random as random
+import pytest
+import numpy as np
+import torch
+
+def stable_SVD_to_scalar(A):
+    U, s, Vt = stable_SVD(A)
+    return torch.linalg.norm((U * s) @ Vt)  # Any scalar depending on U, s, and Vt.
+
+def normal_svd_to_scalar(A):
+    U, s, Vt = normal_svd(A, full_matrices=False)
+    return torch.linalg.norm((U * s) @ Vt)  # Any scalar depending on U, s, and Vt.
+
+@pytest.mark.parametrize(
+    "length, width",
+    [(10, 10), (100, 10), (10, 100), (50000, 100), (1000, 1000), (100, 50000)],
+)
+def test_random_normal(length, width):
+    A = random.uniform(low=0.0, high=1.0, size=(length, width))
+    A = torch.tensor(A, dtype=torch.float32)
+
+    A = A.clone().detach().requires_grad_(True)
+
+    stable_value = stable_SVD_to_scalar(A)
+    stable_value.backward()
+    stable_grad = A.grad
+
+    A = A.clone().detach().requires_grad_(True)
+
+    normal_value = normal_svd_to_scalar(A)
+    normal_value.backward()
+    normal_grad = A.grad
+
+    assert torch.allclose(stable_value, normal_value, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(stable_grad, normal_grad, atol=1e-5, rtol=1e-5)
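This test checks that stable_SVD agrees with torch.linalg.svd in both value and gradient on well-conditioned random matrices. The usual motivation for a numerically "stable" SVD variant is taming the standard SVD backward pass, whose gradient blows up when singular values are nearly degenerate; that regime is presumably handled elsewhere, since these uniform random matrices do not exercise it.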
RRAEsTorch/tests/test_wrappers.py
ADDED
@@ -0,0 +1,56 @@
+from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
+from torchvision.ops import MLP
+import pytest
+import numpy as np
+import math
+import numpy.random as random
+import torch
+
+def test_vmap_wrapper():
+    # Usually MLP only accepts a vector; here we give
+    # a tensor and vectorize over the last axis twice
+    data = random.normal(size=(50, 60, 600))
+    data = torch.tensor(data, dtype=torch.float32)
+
+    model_cls = vmap_wrap(MLP, -1, 2)
+    model = model_cls(50, [64, 100])
+    try:
+        model(data)
+    except ValueError:
+        pytest.fail("Vmap wrapper is not working properly.")
+
+def test_norm_wrapper():
+    # Testing the keep_normalized kwarg
+    data = random.normal(size=(50,))
+    data = torch.tensor(data, dtype=torch.float32)
+    model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
+    model = model_cls(50, [64, 100])
+    try:
+        assert not torch.allclose(model(data), model(data, keep_normalized=True))
+    except AssertionError:
+        pytest.fail("The keep_normalized kwarg for norm wrapper is not behaving as expected.")
+
+    # Testing minmax with known mins and maxs
+    data = np.linspace(-1, 1, 100)
+    data = torch.tensor(data, dtype=torch.float32)
+    model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
+    model = model_cls(50, [64, 100])
+    try:
+        assert 0.55 == model.norm_in.default(None, 0.1)
+        assert -0.8 == model.inv_norm_out.default(None, 0.1)
+    except AssertionError:
+        pytest.fail("Something wrong with minmax wrapper.")
+
+    # Testing meanstd with known mean and std
+    data = random.normal(size=(50,))
+    data = (data - np.mean(data)) / np.std(data)
+    data = data * 2.0 + 1.0  # mean of 1 and std of 2
+    data = torch.tensor(data, dtype=torch.float32)
+
+    model_cls = norm_wrap(MLP, data, "meanstd", None, data, "meanstd", None)
+    model = model_cls(50, [64, 100])
+    try:
+        assert math.isclose(2, model.norm_in.default(None, 5), rel_tol=1e-1, abs_tol=1e-1)
+        assert math.isclose(7, model.inv_norm_out.default(None, 3), rel_tol=1e-1, abs_tol=1e-1)
+    except AssertionError:
+        pytest.fail("Something wrong with norm wrapper.")
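The hard-coded expectations follow from the usual affine normalizations. minmax maps x to (x - min) / (max - min), so with data spanning [-1, 1] the input 0.1 normalizes to (0.1 + 1) / 2 = 0.55, and the inverse map y * (max - min) + min sends 0.1 to 0.1 * 2 - 1 = -0.8. meanstd maps x to (x - mean) / std, so with mean 1 and std 2 the input 5 normalizes to (5 - 1) / 2 = 2, and the inverse map y * std + mean sends 3 to 3 * 2 + 1 = 7.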
RRAEsTorch/trackers/__init__.py
ADDED
@@ -0,0 +1 @@
+from .trackers import *