RRAEsTorch 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RRAEsTorch/AE_classes/AE_classes.py +18 -14
- RRAEsTorch/tests/test_AE_classes_CNN.py +20 -26
- RRAEsTorch/tests/test_AE_classes_MLP.py +20 -28
- RRAEsTorch/tests/test_fitting_CNN.py +14 -14
- RRAEsTorch/tests/test_fitting_MLP.py +11 -13
- RRAEsTorch/tests/test_save.py +11 -11
- RRAEsTorch/training_classes/training_classes.py +78 -121
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.7.dist-info}/METADATA +1 -2
- rraestorch-0.1.7.dist-info/RECORD +22 -0
- RRAEsTorch/tests/test_wrappers.py +0 -56
- RRAEsTorch/utilities/utilities.py +0 -1562
- RRAEsTorch/wrappers/__init__.py +0 -1
- RRAEsTorch/wrappers/wrappers.py +0 -237
- rraestorch-0.1.5.dist-info/RECORD +0 -27
- rraestorch-0.1.5.dist-info/licenses/LICENSE copy +0 -21
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.7.dist-info}/WHEEL +0 -0
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,7 +6,6 @@ from RRAEsTorch.utilities import (
|
|
|
6
6
|
MLP_with_CNN3D_trans,
|
|
7
7
|
stable_SVD,
|
|
8
8
|
)
|
|
9
|
-
from RRAEsTorch.wrappers import vmap_wrap
|
|
10
9
|
import warnings
|
|
11
10
|
from torch.nn import Linear
|
|
12
11
|
from RRAEsTorch.AE_base import get_autoencoder_base
|
|
@@ -42,6 +41,8 @@ def latent_func_strong_RRAE(
|
|
|
42
41
|
y_approx : jnp.array
|
|
43
42
|
The latent space after the truncation.
|
|
44
43
|
"""
|
|
44
|
+
y = y.T # to get the number of samples in the last dimension, as expected by the SVD function
|
|
45
|
+
|
|
45
46
|
if apply_basis is not None:
|
|
46
47
|
if get_basis_coeffs:
|
|
47
48
|
return apply_basis, apply_basis.T @ y
|
|
@@ -51,7 +52,7 @@ def latent_func_strong_RRAE(
|
|
|
51
52
|
if get_right_sing:
|
|
52
53
|
raise ValueError("Can not find right singular vector when projecting on basis")
|
|
53
54
|
return apply_basis.T @ y
|
|
54
|
-
return apply_basis @ apply_basis.T @ y
|
|
55
|
+
return (apply_basis @ apply_basis.T @ y).T
|
|
55
56
|
|
|
56
57
|
k_max = -1 if k_max is None else k_max
|
|
57
58
|
|
|
@@ -97,7 +98,7 @@ def latent_func_strong_RRAE(
|
|
|
97
98
|
sigs = None
|
|
98
99
|
if ret:
|
|
99
100
|
return u_now, coeffs, sigs
|
|
100
|
-
return y_approx
|
|
101
|
+
return y_approx.T
|
|
101
102
|
|
|
102
103
|
def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
|
|
103
104
|
apply_basis = kwargs.get("apply_basis")
|
|
@@ -111,6 +112,9 @@ def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=F
|
|
|
111
112
|
return latent_func_strong_RRAE(self, y, k_max, apply_basis=apply_basis, **kwargs)
|
|
112
113
|
|
|
113
114
|
basis, coeffs = latent_func_strong_RRAE(self, y, k_max=k_max, get_basis_coeffs=True, apply_basis=apply_basis)
|
|
115
|
+
|
|
116
|
+
coeffs = coeffs.T # to get the number of samples as first dim
|
|
117
|
+
|
|
114
118
|
if self.typ == "eye":
|
|
115
119
|
mean = coeffs
|
|
116
120
|
elif self.typ == "trainable":
|
|
@@ -121,19 +125,19 @@ def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=F
|
|
|
121
125
|
logvar = self.lin_logvar(coeffs)
|
|
122
126
|
|
|
123
127
|
if return_dist:
|
|
124
|
-
return mean, logvar
|
|
128
|
+
return mean.T, logvar.T
|
|
125
129
|
|
|
126
130
|
std = torch.exp(0.5 * logvar)
|
|
127
131
|
if epsilon is not None:
|
|
128
132
|
if len(epsilon.shape) == 4:
|
|
129
133
|
epsilon = epsilon[0, 0] # to allow tpu sharding
|
|
130
|
-
z = mean +
|
|
134
|
+
z = mean + epsilon * std
|
|
131
135
|
else:
|
|
132
136
|
z = mean
|
|
133
137
|
|
|
134
138
|
if return_lat_dist:
|
|
135
|
-
return
|
|
136
|
-
return
|
|
139
|
+
return z @ basis.T, mean.T, logvar.T
|
|
140
|
+
return z @ basis.T
|
|
137
141
|
|
|
138
142
|
|
|
139
143
|
class RRAE_MLP(get_autoencoder_base()):
|
|
@@ -506,9 +510,8 @@ class VRRAE_CNN(CNN_Autoencoder):
|
|
|
506
510
|
count=count,
|
|
507
511
|
**kwargs,
|
|
508
512
|
)
|
|
509
|
-
|
|
510
|
-
self.
|
|
511
|
-
self.lin_logvar = v_Linear(k_max, k_max)
|
|
513
|
+
self.lin_mean = Linear(k_max, k_max)
|
|
514
|
+
self.lin_logvar = Linear(k_max, k_max)
|
|
512
515
|
self.typ = typ
|
|
513
516
|
|
|
514
517
|
def _perform_in_latent(self, y, *args, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
|
|
@@ -612,10 +615,6 @@ class VRRAE_CNN1D(CNN1D_Autoencoder):
|
|
|
612
615
|
typ: int
|
|
613
616
|
|
|
614
617
|
def __init__(self, channels, input_dim, latent_size, k_max, typ="eye", *, count=1, **kwargs):
|
|
615
|
-
v_Linear = vmap_wrap(Linear, -1, count=count)
|
|
616
|
-
self.lin_mean = v_Linear(k_max, k_max,)
|
|
617
|
-
self.lin_logvar = v_Linear(k_max, k_max)
|
|
618
|
-
self.typ = typ
|
|
619
618
|
super().__init__(
|
|
620
619
|
channels,
|
|
621
620
|
input_dim,
|
|
@@ -623,6 +622,11 @@ class VRRAE_CNN1D(CNN1D_Autoencoder):
|
|
|
623
622
|
count=count,
|
|
624
623
|
**kwargs,
|
|
625
624
|
)
|
|
625
|
+
|
|
626
|
+
self.lin_mean = Linear(k_max, k_max,)
|
|
627
|
+
self.lin_logvar = Linear(k_max, k_max)
|
|
628
|
+
self.typ = typ
|
|
629
|
+
|
|
626
630
|
|
|
627
631
|
def _perform_in_latent(self, y, *args, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
|
|
628
632
|
return latent_func_var_strong_RRAE(self, y, k_max, epsilon, return_dist, return_lat_dist, **kwargs)
|
|
@@ -6,18 +6,12 @@ from RRAEsTorch.AE_classes import (
|
|
|
6
6
|
IRMAE_CNN,
|
|
7
7
|
LoRAE_CNN,
|
|
8
8
|
)
|
|
9
|
-
from RRAEsTorch.wrappers import vmap_wrap
|
|
10
9
|
import numpy.random as random
|
|
11
10
|
import numpy as np
|
|
12
11
|
import torch
|
|
13
12
|
|
|
14
13
|
methods = ["encode", "decode"]
|
|
15
14
|
|
|
16
|
-
v_RRAE_CNN = vmap_wrap(RRAE_CNN, -1, 1, methods)
|
|
17
|
-
v_Vanilla_AE_CNN = vmap_wrap(Vanilla_AE_CNN, -1, 1, methods)
|
|
18
|
-
v_IRMAE_CNN = vmap_wrap(IRMAE_CNN, -1, 1, methods)
|
|
19
|
-
v_LoRAE_CNN = vmap_wrap(LoRAE_CNN, -1, 1, methods)
|
|
20
|
-
|
|
21
15
|
@pytest.mark.parametrize("width", (10, 17, 149))
|
|
22
16
|
@pytest.mark.parametrize("height", (20,))
|
|
23
17
|
@pytest.mark.parametrize("latent", (200,))
|
|
@@ -26,51 +20,51 @@ v_LoRAE_CNN = vmap_wrap(LoRAE_CNN, -1, 1, methods)
|
|
|
26
20
|
@pytest.mark.parametrize("num_samples", (10, 100))
|
|
27
21
|
class Test_AEs_shapes:
|
|
28
22
|
def test_RRAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
|
|
29
|
-
x = random.normal(size=(channels, width, height
|
|
23
|
+
x = random.normal(size=(num_samples, channels, width, height))
|
|
30
24
|
x = torch.tensor(x, dtype=torch.float32)
|
|
31
25
|
kwargs = {"kwargs_dec": {"stride": 2}}
|
|
32
|
-
model =
|
|
33
|
-
x.shape[
|
|
26
|
+
model = RRAE_CNN(
|
|
27
|
+
x.shape[1], x.shape[2], x.shape[3], latent, num_modes, **kwargs
|
|
34
28
|
)
|
|
35
29
|
y = model.encode(x)
|
|
36
|
-
assert y.shape == (
|
|
30
|
+
assert y.shape == (num_samples, latent)
|
|
37
31
|
y = model.latent(x, k_max=num_modes)
|
|
38
32
|
_, sing_vals, _ = torch.linalg.svd(y, full_matrices=False)
|
|
39
33
|
assert sing_vals[num_modes + 1] < 1e-5
|
|
40
|
-
assert y.shape == (
|
|
41
|
-
assert model.decode(y).shape == (channels, width, height
|
|
34
|
+
assert y.shape == (num_samples, latent)
|
|
35
|
+
assert model.decode(y).shape == (num_samples, channels, width, height)
|
|
42
36
|
|
|
43
37
|
def test_Vanilla_CNN(self, latent, num_modes, width, height, channels, num_samples):
|
|
44
|
-
x = random.normal(size=(channels, width, height
|
|
38
|
+
x = random.normal(size=(num_samples, channels, width, height))
|
|
45
39
|
x = torch.tensor(x, dtype=torch.float32)
|
|
46
40
|
kwargs = {"kwargs_dec": {"stride": 2}}
|
|
47
|
-
model =
|
|
48
|
-
x.shape[
|
|
41
|
+
model = Vanilla_AE_CNN(
|
|
42
|
+
x.shape[1], x.shape[2], x.shape[3], latent, **kwargs
|
|
49
43
|
)
|
|
50
44
|
y = model.encode(x)
|
|
51
|
-
assert y.shape == (
|
|
45
|
+
assert y.shape == (num_samples, latent)
|
|
52
46
|
x = model.decode(y)
|
|
53
|
-
assert x.shape == (channels, width, height
|
|
47
|
+
assert x.shape == (num_samples, channels, width, height)
|
|
54
48
|
|
|
55
49
|
|
|
56
50
|
def test_IRMAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
|
|
57
|
-
x = random.normal(size=(channels, width, height
|
|
51
|
+
x = random.normal(size=(num_samples, channels, width, height))
|
|
58
52
|
x = torch.tensor(x, dtype=torch.float32)
|
|
59
|
-
model =
|
|
60
|
-
x.shape[
|
|
53
|
+
model = IRMAE_CNN(
|
|
54
|
+
x.shape[1], x.shape[2], x.shape[3], latent, linear_l=2
|
|
61
55
|
)
|
|
62
56
|
y = model.encode(x)
|
|
63
|
-
assert y.shape == (
|
|
57
|
+
assert y.shape == (num_samples, latent)
|
|
64
58
|
assert len(model._encode.layers[-1].layers_l) == 2
|
|
65
59
|
x = model.decode(y)
|
|
66
|
-
assert x.shape == (channels, width, height
|
|
60
|
+
assert x.shape == (num_samples, channels, width, height)
|
|
67
61
|
|
|
68
62
|
def test_LoRAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
|
|
69
|
-
x = random.normal(size=(channels, width, height
|
|
63
|
+
x = random.normal(size=(num_samples, channels, width, height))
|
|
70
64
|
x = torch.tensor(x, dtype=torch.float32)
|
|
71
|
-
model =
|
|
65
|
+
model = LoRAE_CNN(x.shape[1], x.shape[2], x.shape[3], latent)
|
|
72
66
|
y = model.encode(x)
|
|
73
|
-
assert y.shape == (
|
|
67
|
+
assert y.shape == (num_samples, latent)
|
|
74
68
|
assert len(model._encode.layers[-1].layers_l) == 1
|
|
75
69
|
x = model.decode(y)
|
|
76
|
-
assert x.shape == (channels, width, height
|
|
70
|
+
assert x.shape == (num_samples, channels, width, height)
|
|
@@ -6,67 +6,59 @@ from RRAEsTorch.AE_classes import (
|
|
|
6
6
|
IRMAE_MLP,
|
|
7
7
|
LoRAE_MLP,
|
|
8
8
|
)
|
|
9
|
-
from RRAEsTorch.wrappers import vmap_wrap
|
|
10
9
|
import numpy.random as random
|
|
11
10
|
import numpy as np
|
|
12
11
|
import torch
|
|
13
12
|
|
|
14
|
-
methods = ["encode", "decode"]
|
|
15
|
-
|
|
16
|
-
v_RRAE_MLP = vmap_wrap(RRAE_MLP, -1, 1, methods)
|
|
17
|
-
v_Vanilla_AE_MLP = vmap_wrap(Vanilla_AE_MLP, -1, 1, methods)
|
|
18
|
-
v_IRMAE_MLP = vmap_wrap(IRMAE_MLP, -1, 1, methods)
|
|
19
|
-
v_LoRAE_MLP = vmap_wrap(LoRAE_MLP, -1, 1, methods)
|
|
20
|
-
|
|
21
13
|
@pytest.mark.parametrize("dim_D", (10, 15, 50))
|
|
22
14
|
@pytest.mark.parametrize("latent", (200, 400, 800))
|
|
23
15
|
@pytest.mark.parametrize("num_modes", (1, 2, 6))
|
|
24
16
|
class Test_AEs_shapes:
|
|
25
17
|
def test_RRAE_MLP(self, latent, num_modes, dim_D):
|
|
26
|
-
x = random.normal(size=(
|
|
18
|
+
x = random.normal(size=(dim_D, 500))
|
|
27
19
|
x = torch.tensor(x, dtype=torch.float32)
|
|
28
|
-
model =
|
|
20
|
+
model = RRAE_MLP(x.shape[1], latent, num_modes)
|
|
29
21
|
y = model.encode(x)
|
|
30
|
-
assert y.shape == (
|
|
22
|
+
assert y.shape == (dim_D, latent)
|
|
31
23
|
y = model.perform_in_latent(y, k_max=num_modes)
|
|
32
24
|
_, sing_vals, _ = torch.linalg.svd(y, full_matrices=False)
|
|
33
|
-
assert sing_vals[num_modes + 1] < 1e-
|
|
34
|
-
assert y.shape == (
|
|
35
|
-
assert model.decode(y).shape == (
|
|
25
|
+
assert sing_vals[num_modes + 1] < 1e-4
|
|
26
|
+
assert y.shape == (dim_D, latent)
|
|
27
|
+
assert model.decode(y).shape == (dim_D, 500)
|
|
36
28
|
|
|
37
29
|
def test_Vanilla_MLP(self, latent, num_modes, dim_D):
|
|
38
|
-
x = random.normal(size=(
|
|
30
|
+
x = random.normal(size=(dim_D, 500))
|
|
39
31
|
x = torch.tensor(x, dtype=torch.float32)
|
|
40
|
-
model =
|
|
32
|
+
model = Vanilla_AE_MLP(x.shape[1], latent)
|
|
41
33
|
y = model.encode(x)
|
|
42
|
-
assert y.shape == (
|
|
34
|
+
assert y.shape == (dim_D, latent)
|
|
43
35
|
x = model.decode(y)
|
|
44
|
-
assert x.shape == (
|
|
36
|
+
assert x.shape == (dim_D, 500)
|
|
45
37
|
|
|
46
38
|
def test_IRMAE_MLP(self, latent, num_modes, dim_D):
|
|
47
|
-
x = random.normal(size=(
|
|
39
|
+
x = random.normal(size=(dim_D, 500))
|
|
48
40
|
x = torch.tensor(x, dtype=torch.float32)
|
|
49
|
-
model =
|
|
41
|
+
model = IRMAE_MLP(x.shape[1], latent, linear_l=2)
|
|
50
42
|
y = model.encode(x)
|
|
51
|
-
assert y.shape == (
|
|
43
|
+
assert y.shape == (dim_D, latent)
|
|
52
44
|
assert len(model._encode.layers_l) == 2
|
|
53
45
|
x = model.decode(y)
|
|
54
|
-
assert x.shape == (
|
|
46
|
+
assert x.shape == (dim_D, 500)
|
|
55
47
|
|
|
56
48
|
def test_LoRAE_MLP(self, latent, num_modes, dim_D):
|
|
57
|
-
x = random.normal(size=(
|
|
49
|
+
x = random.normal(size=(dim_D, 500))
|
|
58
50
|
x = torch.tensor(x, dtype=torch.float32)
|
|
59
|
-
model =
|
|
51
|
+
model = LoRAE_MLP(x.shape[1], latent)
|
|
60
52
|
y = model.encode(x)
|
|
61
|
-
assert y.shape == (
|
|
53
|
+
assert y.shape == (dim_D, latent)
|
|
62
54
|
assert len(model._encode.layers_l) == 1
|
|
63
55
|
x = model.decode(y)
|
|
64
|
-
assert x.shape == (
|
|
56
|
+
assert x.shape == (dim_D, 500)
|
|
65
57
|
|
|
66
58
|
def test_getting_SVD_coeffs():
|
|
67
|
-
data = random.uniform(size=(
|
|
59
|
+
data = random.uniform(size=(15, 500))
|
|
68
60
|
data = torch.tensor(data, dtype=torch.float32)
|
|
69
|
-
model_s =
|
|
61
|
+
model_s = RRAE_MLP(data.shape[1], 200, 3)
|
|
70
62
|
basis, coeffs = model_s.latent(data, k_max=3, get_basis_coeffs=True)
|
|
71
63
|
assert basis.shape == (200, 3)
|
|
72
64
|
assert coeffs.shape == (3, 15)
|
|
@@ -13,8 +13,8 @@ import torch
|
|
|
13
13
|
@pytest.mark.parametrize(
|
|
14
14
|
"model_cls, sh, lf",
|
|
15
15
|
[
|
|
16
|
-
(Vanilla_AE_CNN, (1, 2, 2
|
|
17
|
-
(LoRAE_CNN, (6, 16, 16
|
|
16
|
+
(Vanilla_AE_CNN, (10, 1, 2, 2), "default"),
|
|
17
|
+
(LoRAE_CNN, (10, 6, 16, 16), "nuc"),
|
|
18
18
|
],
|
|
19
19
|
)
|
|
20
20
|
def test_AE_fitting(model_cls, sh, lf):
|
|
@@ -24,10 +24,10 @@ def test_AE_fitting(model_cls, sh, lf):
|
|
|
24
24
|
x,
|
|
25
25
|
model_cls,
|
|
26
26
|
latent_size=100,
|
|
27
|
-
channels=x.shape[
|
|
28
|
-
width=x.shape[
|
|
29
|
-
height=x.shape[
|
|
30
|
-
samples=x.shape[
|
|
27
|
+
channels=x.shape[1],
|
|
28
|
+
width=x.shape[2],
|
|
29
|
+
height=x.shape[3],
|
|
30
|
+
samples=x.shape[0], # Only for weak
|
|
31
31
|
k_max=2,
|
|
32
32
|
)
|
|
33
33
|
kwargs = {
|
|
@@ -52,16 +52,16 @@ def test_AE_fitting(model_cls, sh, lf):
|
|
|
52
52
|
def test_IRMAE_fitting():
|
|
53
53
|
model_cls = IRMAE_CNN
|
|
54
54
|
lf = "default"
|
|
55
|
-
sh = (3, 12, 12
|
|
55
|
+
sh = (10, 3, 12, 12)
|
|
56
56
|
x = random.normal(size=sh)
|
|
57
57
|
x = torch.tensor(x, dtype=torch.float32)
|
|
58
58
|
trainor = AE_Trainor_class(
|
|
59
59
|
x,
|
|
60
60
|
model_cls,
|
|
61
61
|
latent_size=100,
|
|
62
|
-
channels=x.shape[
|
|
63
|
-
width=x.shape[
|
|
64
|
-
height=x.shape[
|
|
62
|
+
channels=x.shape[1],
|
|
63
|
+
width=x.shape[2],
|
|
64
|
+
height=x.shape[3],
|
|
65
65
|
k_max=2,
|
|
66
66
|
linear_l=4,
|
|
67
67
|
)
|
|
@@ -77,7 +77,7 @@ def test_IRMAE_fitting():
|
|
|
77
77
|
assert False, f"Fitting failed with the following exception {repr(e)}"
|
|
78
78
|
|
|
79
79
|
def test_RRAE_fitting():
|
|
80
|
-
sh = (1, 20, 20
|
|
80
|
+
sh = (10, 1, 20, 20)
|
|
81
81
|
model_cls = RRAE_CNN
|
|
82
82
|
x = random.normal(size=sh)
|
|
83
83
|
x = torch.tensor(x, dtype=torch.float32)
|
|
@@ -85,9 +85,9 @@ def test_RRAE_fitting():
|
|
|
85
85
|
x,
|
|
86
86
|
model_cls,
|
|
87
87
|
latent_size=100,
|
|
88
|
-
channels=x.shape[
|
|
89
|
-
width=x.shape[
|
|
90
|
-
height=x.shape[
|
|
88
|
+
channels=x.shape[1],
|
|
89
|
+
width=x.shape[2],
|
|
90
|
+
height=x.shape[3],
|
|
91
91
|
k_max=2,
|
|
92
92
|
)
|
|
93
93
|
training_kwargs = {
|
|
@@ -14,8 +14,8 @@ import torch
|
|
|
14
14
|
@pytest.mark.parametrize(
|
|
15
15
|
"model_cls, sh, lf",
|
|
16
16
|
[
|
|
17
|
-
(Vanilla_AE_MLP, (
|
|
18
|
-
(LoRAE_MLP, (
|
|
17
|
+
(Vanilla_AE_MLP, (10, 500), "default"),
|
|
18
|
+
(LoRAE_MLP, (10, 500), "nuc"),
|
|
19
19
|
],
|
|
20
20
|
)
|
|
21
21
|
def test_fitting(model_cls, sh, lf):
|
|
@@ -24,9 +24,8 @@ def test_fitting(model_cls, sh, lf):
|
|
|
24
24
|
trainor = AE_Trainor_class(
|
|
25
25
|
x,
|
|
26
26
|
model_cls,
|
|
27
|
-
in_size=x.shape[
|
|
28
|
-
|
|
29
|
-
samples=x.shape[-1], # Only for weak
|
|
27
|
+
in_size=x.shape[1],
|
|
28
|
+
samples=x.shape[0], # Only for weak
|
|
30
29
|
norm_in="meanstd",
|
|
31
30
|
norm_out="minmax",
|
|
32
31
|
out_train=x,
|
|
@@ -51,14 +50,14 @@ def test_fitting(model_cls, sh, lf):
|
|
|
51
50
|
|
|
52
51
|
|
|
53
52
|
def test_RRAE_fitting():
|
|
54
|
-
sh = (
|
|
53
|
+
sh = (10, 500)
|
|
55
54
|
model_cls = RRAE_MLP
|
|
56
55
|
x = random.normal(size=sh)
|
|
57
56
|
x = torch.tensor(x, dtype=torch.float32)
|
|
58
57
|
trainor = RRAE_Trainor_class(
|
|
59
58
|
x,
|
|
60
59
|
model_cls,
|
|
61
|
-
in_size=x.shape[
|
|
60
|
+
in_size=x.shape[1],
|
|
62
61
|
latent_size=2000,
|
|
63
62
|
k_max=2,
|
|
64
63
|
)
|
|
@@ -83,14 +82,13 @@ def test_RRAE_fitting():
|
|
|
83
82
|
def test_IRMAE_fitting():
|
|
84
83
|
model_cls = IRMAE_MLP
|
|
85
84
|
lf = "default"
|
|
86
|
-
sh = (
|
|
85
|
+
sh = (10, 500)
|
|
87
86
|
x = random.normal(size=sh)
|
|
88
87
|
x = torch.tensor(x, dtype=torch.float32)
|
|
89
88
|
trainor = AE_Trainor_class(
|
|
90
89
|
x,
|
|
91
90
|
model_cls,
|
|
92
|
-
in_size=x.shape[
|
|
93
|
-
data_size=x.shape[-1],
|
|
91
|
+
in_size=x.shape[1],
|
|
94
92
|
latent_size=2000,
|
|
95
93
|
k_max=2,
|
|
96
94
|
linear_l=4,
|
|
@@ -107,15 +105,15 @@ def test_IRMAE_fitting():
|
|
|
107
105
|
assert False, f"Fitting failed with the following exception {repr(e)}"
|
|
108
106
|
|
|
109
107
|
def test_fitting():
|
|
110
|
-
sh = (
|
|
108
|
+
sh = (100, 50)
|
|
111
109
|
model_cls = MLP
|
|
112
110
|
x = random.normal(size=sh)
|
|
113
111
|
x = torch.tensor(x, dtype=torch.float32)
|
|
114
112
|
trainor = Trainor_class(
|
|
115
113
|
x,
|
|
116
114
|
model_cls,
|
|
117
|
-
in_channels=x.shape[
|
|
118
|
-
hidden_channels=[100, x.shape[
|
|
115
|
+
in_channels=x.shape[1],
|
|
116
|
+
hidden_channels=[100, x.shape[1]]
|
|
119
117
|
)
|
|
120
118
|
training_kwargs = {
|
|
121
119
|
"step_st": [2],
|
RRAEsTorch/tests/test_save.py
CHANGED
|
@@ -5,7 +5,7 @@ import numpy.random as random
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
7
7
|
def test_save(): # Only to test if saving/loading is causing a problem
|
|
8
|
-
data = random.normal(size=(1,
|
|
8
|
+
data = random.normal(size=(1, 1, 28, 28))
|
|
9
9
|
data = torch.tensor(data, dtype=torch.float32)
|
|
10
10
|
model_cls = RRAE_CNN
|
|
11
11
|
|
|
@@ -13,9 +13,9 @@ def test_save(): # Only to test if saving/loading is causing a problem
|
|
|
13
13
|
data,
|
|
14
14
|
model_cls,
|
|
15
15
|
latent_size=100,
|
|
16
|
-
channels=data.shape[
|
|
17
|
-
width=data.shape[
|
|
18
|
-
height=data.shape[
|
|
16
|
+
channels=data.shape[1],
|
|
17
|
+
width=data.shape[2],
|
|
18
|
+
height=data.shape[3],
|
|
19
19
|
pre_func_inp=lambda x: x * 2 / 17,
|
|
20
20
|
pre_func_out=lambda x: x / 2,
|
|
21
21
|
k_max=2,
|
|
@@ -25,17 +25,17 @@ def test_save(): # Only to test if saving/loading is causing a problem
|
|
|
25
25
|
new_trainor = RRAE_Trainor_class()
|
|
26
26
|
new_trainor.load_model("test_", erase=True)
|
|
27
27
|
try:
|
|
28
|
-
pr = trainor.model(data[
|
|
28
|
+
pr = trainor.model(data[0:1], k_max=2)
|
|
29
29
|
except Exception as e:
|
|
30
30
|
raise ValueError(f"Original trainor failed with following exception {e}")
|
|
31
31
|
try:
|
|
32
|
-
pr = new_trainor.model(data[
|
|
32
|
+
pr = new_trainor.model(data[0:1], k_max=2)
|
|
33
33
|
except Exception as e:
|
|
34
34
|
raise ValueError(f"Failed with following exception {e}")
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
def test_save_with_final_act():
|
|
38
|
-
data = random.normal(size=(1,
|
|
38
|
+
data = random.normal(size=(1, 1, 28, 28))
|
|
39
39
|
data = torch.tensor(data, dtype=torch.float32)
|
|
40
40
|
|
|
41
41
|
model_cls = RRAE_CNN
|
|
@@ -44,9 +44,9 @@ def test_save_with_final_act():
|
|
|
44
44
|
data,
|
|
45
45
|
model_cls,
|
|
46
46
|
latent_size=100,
|
|
47
|
-
channels=data.shape[
|
|
48
|
-
width=data.shape[
|
|
49
|
-
height=data.shape[
|
|
47
|
+
channels=data.shape[1],
|
|
48
|
+
width=data.shape[2],
|
|
49
|
+
height=data.shape[3],
|
|
50
50
|
kwargs_dec={"final_activation": torch.sigmoid},
|
|
51
51
|
k_max=2,
|
|
52
52
|
)
|
|
@@ -55,7 +55,7 @@ def test_save_with_final_act():
|
|
|
55
55
|
new_trainor = RRAE_Trainor_class()
|
|
56
56
|
new_trainor.load_model("test_", erase=True)
|
|
57
57
|
try:
|
|
58
|
-
pr = new_trainor.model(data[
|
|
58
|
+
pr = new_trainor.model(data[0:1], k_max=2)
|
|
59
59
|
assert torch.max(pr) <= 1.0, "Final activation not working"
|
|
60
60
|
assert torch.min(pr) >= 0.0, "Final activation not working"
|
|
61
61
|
except Exception as e:
|