RRAEsTorch 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
RRAEsTorch/config.py ADDED
@@ -0,0 +1,95 @@
1
+ from RRAEsTorch.utilities import MLP_with_linear
2
+ import numpy.random as random
3
+ from RRAEsTorch.AE_base import AE_base as rraes
4
+ import torch
5
+ from torch.func import vmap
6
+
7
class Autoencoder(torch.nn.Module):
    """Encoder -> latent operation -> decoder module.

    The encoder and decoder default to ``MLP_with_linear`` networks sized from
    ``in_size`` / ``latent_size``; pass ``_encode`` / ``_decode`` to use custom
    modules instead. The latent operation is ``self._perform_in_latent``: a
    subclass may provide it, otherwise it defaults to the identity.
    """

    _encode: MLP_with_linear
    _decode: MLP_with_linear
    _perform_in_latent: callable
    map_latent: bool
    norm_funcs: list
    inv_norm_funcs: list
    count: int

    def __init__(
        self,
        in_size,
        latent_size,
        latent_size_after=None,
        _encode=None,
        _decode=None,
        map_latent=True,
        *,
        count=1,
        kwargs_enc=None,
        kwargs_dec=None,
        **kwargs,
    ):
        """Build (or adopt) the encoder and decoder.

        Args:
            in_size: input size of the default encoder and output size of the
                default decoder.
            latent_size: output size of the default encoder.
            latent_size_after: input size of the default decoder; defaults to
                ``latent_size``.
            _encode: optional ready-made encoder replacing the default MLP.
            _decode: optional ready-made decoder replacing the default MLP.
            map_latent: if True, ``perform_in_latent`` is vmapped over the last
                axis ``count`` times.
            count: number of nested vmaps used when ``map_latent`` is True.
            kwargs_enc: extra keyword arguments for the default encoder MLP
                (``width_size`` defaults to 64, ``depth`` to 1). Not mutated.
            kwargs_dec: extra keyword arguments for the default decoder MLP
                (``width_size`` defaults to 64, ``depth`` to 6). Not mutated.
            **kwargs: accepted and ignored by this class.
        """
        super().__init__()
        if latent_size_after is None:
            latent_size_after = latent_size

        if _encode is None:
            # Fresh dict: neither the caller's dict nor a shared default is
            # modified when the defaults are filled in.
            enc_options = {"width_size": 64, "depth": 1, **(kwargs_enc or {})}
            self._encode = MLP_with_linear(
                in_size=in_size,
                out_size=latent_size,
                **enc_options,
            )
        else:
            self._encode = _encode

        # A subclass may define its own latent operation; otherwise identity.
        if not hasattr(self, "_perform_in_latent"):
            self._perform_in_latent = lambda x, *args, **kwargs: x

        if _decode is None:
            dec_options = {"width_size": 64, "depth": 6, **(kwargs_dec or {})}
            self._decode = MLP_with_linear(
                in_size=latent_size_after,
                out_size=in_size,
                **dec_options,
            )
        else:
            self._decode = _decode

        self.count = count
        self.map_latent = map_latent
        # Method names; presumably consulted by the normalization wrappers to
        # pick which methods get (inverse-)normalized — confirm in wrappers.
        self.inv_norm_funcs = ["decode"]
        self.norm_funcs = ["encode", "latent"]

    def encode(self, x, *args, **kwargs):
        """Apply the encoder to ``x``."""
        return self._encode(x, *args, **kwargs)

    def decode(self, x, *args, **kwargs):
        """Apply the decoder to ``x``."""
        return self._decode(x, *args, **kwargs)

    def perform_in_latent(self, y, *args, **kwargs):
        """Apply ``_perform_in_latent`` to the latent codes ``y``.

        When ``map_latent`` is True, the operation (with ``args``/``kwargs``
        bound) is wrapped in ``count`` nested ``vmap`` calls over the last axis
        for both input and output, i.e. it runs independently on each slice
        along the trailing ``count`` axes. Otherwise it is called once on ``y``.
        """
        if self.map_latent:
            mapped = lambda x: self._perform_in_latent(x, *args, **kwargs)
            for _ in range(self.count):
                mapped = vmap(mapped, in_dims=-1, out_dims=-1)
            return mapped(y)
        return self._perform_in_latent(y, *args, **kwargs)

    def forward(self, x, *args, **kwargs):
        """Encode ``x``, apply the latent operation, then decode."""
        return self.decode(self.perform_in_latent(self.encode(x), *args, **kwargs))

    def latent(self, x, *args, **kwargs):
        """Return the latent codes of ``x`` after the latent operation."""
        return self.perform_in_latent(self.encode(x), *args, **kwargs)
94
+
95
+ rraes.set_autoencoder_base(Autoencoder)
@@ -0,0 +1,76 @@
1
+ import RRAEsTorch.config
2
+ import pytest
3
+ from RRAEsTorch.AE_classes import (
4
+ RRAE_CNN,
5
+ Vanilla_AE_CNN,
6
+ IRMAE_CNN,
7
+ LoRAE_CNN,
8
+ )
9
+ from RRAEsTorch.wrappers import vmap_wrap
10
+ import numpy.random as random
11
+ import numpy as np
12
+ import torch
13
+
14
# Model methods that vmap_wrap vectorizes.
methods = ["encode", "decode"]

# Wrap every CNN model so encode/decode are mapped once over the last axis;
# the tests below put the samples on that trailing axis.
v_RRAE_CNN, v_Vanilla_AE_CNN, v_IRMAE_CNN, v_LoRAE_CNN = (
    vmap_wrap(model_cls, -1, 1, methods)
    for model_cls in (RRAE_CNN, Vanilla_AE_CNN, IRMAE_CNN, LoRAE_CNN)
)
20
+
21
@pytest.mark.parametrize("width", (10, 17, 149))
@pytest.mark.parametrize("height", (20,))
@pytest.mark.parametrize("latent", (200,))
@pytest.mark.parametrize("num_modes", (1,))
@pytest.mark.parametrize("channels", (1, 3, 5))
@pytest.mark.parametrize("num_samples", (10, 100))
class Test_AEs_shapes:
    """Shape checks for the vmapped CNN autoencoders.

    Inputs are ``(channels, width, height, num_samples)`` with samples on the
    last axis; encodings are expected to be ``(latent, num_samples)``.
    """

    def test_RRAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
        x = random.normal(size=(channels, width, height, num_samples))
        x = torch.tensor(x, dtype=torch.float32)
        kwargs = {"kwargs_dec": {"stride": 2}}
        model = v_RRAE_CNN(
            x.shape[0], x.shape[1], x.shape[2], latent, num_modes, **kwargs
        )
        y = model.encode(x)
        assert y.shape == (latent, num_samples)
        y = model.latent(x, k_max=num_modes)
        _, sing_vals, _ = torch.linalg.svd(y, full_matrices=False)
        # A latent truncated to num_modes modes has rank <= num_modes, so
        # sing_vals[num_modes] (0-based) is the first value that must vanish.
        # Tolerance is relative because float32 round-off scales with |y|.
        assert sing_vals[num_modes] <= 1e-5 * sing_vals[0]
        assert y.shape == (latent, num_samples)
        assert model.decode(y).shape == (channels, width, height, num_samples)

    def test_Vanilla_CNN(self, latent, num_modes, width, height, channels, num_samples):
        x = random.normal(size=(channels, width, height, num_samples))
        x = torch.tensor(x, dtype=torch.float32)
        kwargs = {"kwargs_dec": {"stride": 2}}
        model = v_Vanilla_AE_CNN(
            x.shape[0], x.shape[1], x.shape[2], latent, **kwargs
        )
        y = model.encode(x)
        assert y.shape == (latent, num_samples)
        x = model.decode(y)
        assert x.shape == (channels, width, height, num_samples)

    def test_IRMAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
        x = random.normal(size=(channels, width, height, num_samples))
        x = torch.tensor(x, dtype=torch.float32)
        model = v_IRMAE_CNN(
            x.shape[0], x.shape[1], x.shape[2], latent, linear_l=2
        )
        y = model.encode(x)
        assert y.shape == (latent, num_samples)
        # linear_l=2 should give two extra linear layers at the encoder's end.
        assert len(model._encode.layers[-1].layers_l) == 2
        x = model.decode(y)
        assert x.shape == (channels, width, height, num_samples)

    def test_LoRAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
        x = random.normal(size=(channels, width, height, num_samples))
        x = torch.tensor(x, dtype=torch.float32)
        model = v_LoRAE_CNN(x.shape[0], x.shape[1], x.shape[2], latent)
        y = model.encode(x)
        assert y.shape == (latent, num_samples)
        assert len(model._encode.layers[-1].layers_l) == 1
        x = model.decode(y)
        assert x.shape == (channels, width, height, num_samples)
@@ -0,0 +1,73 @@
1
+ import RRAEsTorch.config
2
+ import pytest
3
+ from RRAEsTorch.AE_classes import (
4
+ RRAE_MLP,
5
+ Vanilla_AE_MLP,
6
+ IRMAE_MLP,
7
+ LoRAE_MLP,
8
+ )
9
+ from RRAEsTorch.wrappers import vmap_wrap
10
+ import numpy.random as random
11
+ import numpy as np
12
+ import torch
13
+
14
# Model methods that vmap_wrap vectorizes.
methods = ["encode", "decode"]

# Wrap every MLP model so encode/decode are mapped once over the last axis;
# the tests below put the samples on that trailing axis.
v_RRAE_MLP, v_Vanilla_AE_MLP, v_IRMAE_MLP, v_LoRAE_MLP = (
    vmap_wrap(model_cls, -1, 1, methods)
    for model_cls in (RRAE_MLP, Vanilla_AE_MLP, IRMAE_MLP, LoRAE_MLP)
)
20
+
21
@pytest.mark.parametrize("dim_D", (10, 15, 50))
@pytest.mark.parametrize("latent", (200, 400, 800))
@pytest.mark.parametrize("num_modes", (1, 2, 6))
class Test_AEs_shapes:
    """Shape checks for the vmapped MLP autoencoders.

    Inputs are ``(500, dim_D)`` with samples on the last axis; encodings are
    expected to be ``(latent, dim_D)``.
    """

    def test_RRAE_MLP(self, latent, num_modes, dim_D):
        x = random.normal(size=(500, dim_D))
        x = torch.tensor(x, dtype=torch.float32)
        model = v_RRAE_MLP(x.shape[0], latent, num_modes)
        y = model.encode(x)
        assert y.shape == (latent, dim_D)
        y = model.perform_in_latent(y, k_max=num_modes)
        _, sing_vals, _ = torch.linalg.svd(y, full_matrices=False)
        # A latent truncated to num_modes modes has rank <= num_modes, so
        # sing_vals[num_modes] (0-based) is the first value that must vanish.
        # Tolerance is relative because float32 round-off scales with |y|.
        assert sing_vals[num_modes] <= 1e-5 * sing_vals[0]
        assert y.shape == (latent, dim_D)
        assert model.decode(y).shape == (500, dim_D)

    def test_Vanilla_MLP(self, latent, num_modes, dim_D):
        x = random.normal(size=(500, dim_D))
        x = torch.tensor(x, dtype=torch.float32)
        model = v_Vanilla_AE_MLP(x.shape[0], latent)
        y = model.encode(x)
        assert y.shape == (latent, dim_D)
        x = model.decode(y)
        assert x.shape == (500, dim_D)

    def test_IRMAE_MLP(self, latent, num_modes, dim_D):
        x = random.normal(size=(500, dim_D))
        x = torch.tensor(x, dtype=torch.float32)
        model = v_IRMAE_MLP(x.shape[0], latent, linear_l=2)
        y = model.encode(x)
        assert y.shape == (latent, dim_D)
        assert len(model._encode.layers_l) == 2
        x = model.decode(y)
        assert x.shape == (500, dim_D)

    def test_LoRAE_MLP(self, latent, num_modes, dim_D):
        x = random.normal(size=(500, dim_D))
        x = torch.tensor(x, dtype=torch.float32)
        model = v_LoRAE_MLP(x.shape[0], latent)
        y = model.encode(x)
        assert y.shape == (latent, dim_D)
        assert len(model._encode.layers_l) == 1
        x = model.decode(y)
        assert x.shape == (500, dim_D)
65
+
66
def test_getting_SVD_coeffs():
    """``latent(..., get_basis_coeffs=True)`` returns (basis, coefficients)."""
    n_samples, latent_size, k = 15, 200, 3
    data = torch.tensor(random.uniform(size=(500, n_samples)), dtype=torch.float32)
    model_s = v_RRAE_MLP(data.shape[0], latent_size, k)
    basis, coeffs = model_s.latent(data, k_max=k, get_basis_coeffs=True)
    # basis spans the k retained modes; coeffs has one column per sample.
    assert basis.shape == (latent_size, k)
    assert coeffs.shape == (k, n_samples)
73
+
@@ -0,0 +1,109 @@
1
+ import RRAEsTorch.config
2
+ import pytest
3
+ from RRAEsTorch.AE_classes import (
4
+ RRAE_CNN,
5
+ Vanilla_AE_CNN,
6
+ IRMAE_CNN,
7
+ LoRAE_CNN,
8
+ )
9
+ from RRAEsTorch.training_classes import RRAE_Trainor_class, Trainor_class, AE_Trainor_class
10
+ import numpy.random as random
11
+ import torch
12
+
13
@pytest.mark.parametrize(
    "model_cls, sh, lf",
    [
        (Vanilla_AE_CNN, (1, 2, 2, 10), "default"),
        (LoRAE_CNN, (6, 16, 16, 10), "nuc"),
    ],
)
def test_AE_fitting(model_cls, sh, lf):
    """A short fit of each CNN autoencoder must complete without raising."""
    x = random.normal(size=sh)
    x = torch.tensor(x, dtype=torch.float32)
    trainor = AE_Trainor_class(
        x,
        model_cls,
        latent_size=100,
        channels=x.shape[0],
        width=x.shape[1],
        height=x.shape[2],
        samples=x.shape[-1],  # Only for weak
        k_max=2,
    )
    kwargs = {
        "step_st": [2],
        "loss_kwargs": {
            "lambda_nuc": 0.001,
            # NOTE(review): `model.encode` is a method on Autoencoder, and the
            # MLP fitting test uses `model._encode...` instead — confirm this
            # attribute path is valid for the CNN classes.
            "find_layer": lambda model: model.encode.layers[-2].layers[-1].weight,
            "loss_type": lf,
        },
    }
    try:
        trainor.fit(
            x,
            x,
            verbose=False,
            training_kwargs=kwargs,
        )
    except Exception as e:
        # pytest.fail, unlike `assert False`, is not stripped under `python -O`.
        pytest.fail(f"Fitting failed with the following exception {repr(e)}")
50
+
51
+
52
def test_IRMAE_fitting():
    """A short fit of IRMAE_CNN (4 extra linear layers) must not raise."""
    model_cls = IRMAE_CNN
    lf = "default"
    sh = (3, 12, 12, 10)
    x = random.normal(size=sh)
    x = torch.tensor(x, dtype=torch.float32)
    trainor = AE_Trainor_class(
        x,
        model_cls,
        latent_size=100,
        channels=x.shape[0],
        width=x.shape[1],
        height=x.shape[2],
        k_max=2,
        linear_l=4,
    )
    kwargs = {"step_st": [2], "loss_type": lf}
    try:
        trainor.fit(
            x,
            x,
            verbose=False,
            training_kwargs=kwargs,
        )
    except Exception as e:
        # pytest.fail, unlike `assert False`, is not stripped under `python -O`.
        pytest.fail(f"Fitting failed with the following exception {repr(e)}")
78
+
79
def test_RRAE_fitting():
    """A short fit plus fine-tuning of RRAE_CNN must not raise."""
    sh = (1, 20, 20, 10)
    model_cls = RRAE_CNN
    x = random.normal(size=sh)
    x = torch.tensor(x, dtype=torch.float32)
    trainor = RRAE_Trainor_class(
        x,
        model_cls,
        latent_size=100,
        channels=x.shape[0],
        width=x.shape[1],
        height=x.shape[2],
        k_max=2,
    )
    training_kwargs = {
        "step_st": [2],
        "loss_type": "default",
    }
    ft_kwargs = {
        "step_st": [2],
    }
    try:
        trainor.fit(
            x,
            x,
            verbose=False,
            training_kwargs=training_kwargs,
            ft_kwargs=ft_kwargs,
        )
    except Exception as e:
        # pytest.fail, unlike `assert False`, is not stripped under `python -O`.
        pytest.fail(f"Fitting failed with the following exception {repr(e)}")
@@ -0,0 +1,133 @@
1
+ import RRAEsTorch.config
2
+ import numpy.random as random
3
+ import pytest
4
+ from RRAEsTorch.AE_classes import (
5
+ RRAE_MLP,
6
+ Vanilla_AE_MLP,
7
+ IRMAE_MLP,
8
+ LoRAE_MLP,
9
+ )
10
+ from torchvision.ops import MLP
11
+ from RRAEsTorch.training_classes import RRAE_Trainor_class, Trainor_class, AE_Trainor_class
12
+ import torch
13
+
14
@pytest.mark.parametrize(
    "model_cls, sh, lf",
    [
        (Vanilla_AE_MLP, (500, 10), "default"),
        (LoRAE_MLP, (500, 10), "nuc"),
    ],
)
def test_fitting(model_cls, sh, lf):
    """A short fit of each MLP autoencoder with normalization must not raise."""
    x = random.normal(size=sh)
    x = torch.tensor(x, dtype=torch.float32)
    trainor = AE_Trainor_class(
        x,
        model_cls,
        in_size=x.shape[0],
        data_size=x.shape[-1],
        samples=x.shape[-1],  # Only for weak
        norm_in="meanstd",
        norm_out="minmax",
        out_train=x,
        latent_size=2000,
        k_max=2,
    )
    kwargs = {
        "step_st": [2],
        "loss_kwargs": {
            "lambda_nuc": 0.001,
            "find_layer": lambda model: model._encode.layers_l[0].weight,
        },
        "loss_type": lf,
    }
    try:
        trainor.fit(
            x,
            x,
            verbose=False,
            training_kwargs=kwargs,
        )
    except Exception as e:
        # pytest.fail, unlike `assert False`, is not stripped under `python -O`.
        pytest.fail(f"Fitting failed with the following exception {repr(e)}")
50
+
51
+
52
+
53
def test_RRAE_fitting():
    """A short fit plus fine-tuning of RRAE_MLP must not raise."""
    sh = (500, 10)
    model_cls = RRAE_MLP
    x = random.normal(size=sh)
    x = torch.tensor(x, dtype=torch.float32)
    trainor = RRAE_Trainor_class(
        x,
        model_cls,
        in_size=x.shape[0],
        latent_size=2000,
        k_max=2,
    )
    training_kwargs = {
        "step_st": [2],
        "loss_type": "RRAE",
    }
    ft_kwargs = {
        "step_st": [2],
    }
    try:
        trainor.fit(
            x,
            x,
            verbose=False,
            training_kwargs=training_kwargs,
            ft_kwargs=ft_kwargs,
        )
    except Exception as e:
        # pytest.fail, unlike `assert False`, is not stripped under `python -O`.
        pytest.fail(f"Fitting failed with the following exception {repr(e)}")
82
+
83
def test_IRMAE_fitting():
    """A short fit of IRMAE_MLP (4 extra linear layers) must not raise."""
    model_cls = IRMAE_MLP
    lf = "default"
    sh = (500, 10)
    x = random.normal(size=sh)
    x = torch.tensor(x, dtype=torch.float32)
    trainor = AE_Trainor_class(
        x,
        model_cls,
        in_size=x.shape[0],
        data_size=x.shape[-1],
        latent_size=2000,
        k_max=2,
        linear_l=4,
    )
    kwargs = {"step_st": [2], "loss_type": lf}
    try:
        trainor.fit(
            x,
            x,
            verbose=False,
            training_kwargs=kwargs,
        )
    except Exception as e:
        # pytest.fail, unlike `assert False`, is not stripped under `python -O`.
        pytest.fail(f"Fitting failed with the following exception {repr(e)}")
108
+
109
def test_trainor_fitting():
    """A short fit of a plain torchvision MLP through Trainor_class must not raise.

    Named distinctly from the parametrized ``test_fitting`` above: reusing that
    name rebinds it at module level, so pytest would only collect this one.
    """
    sh = (50, 100)
    model_cls = MLP
    x = random.normal(size=sh)
    x = torch.tensor(x, dtype=torch.float32)
    trainor = Trainor_class(
        x,
        model_cls,
        in_channels=x.shape[0],
        hidden_channels=[100, x.shape[0]],
    )
    training_kwargs = {
        "step_st": [2],
        "loss_type": "default",
    }

    try:
        trainor.fit(
            x,
            x,
            verbose=False,
            **training_kwargs,
        )
    except Exception as e:
        # pytest.fail, unlike `assert False`, is not stripped under `python -O`.
        pytest.fail(f"Fitting failed with the following exception {repr(e)}")
@@ -0,0 +1,34 @@
1
+ import subprocess
2
+ import os
3
+ import pytest
4
+ import shutil
5
+
6
def try_remove(name):
    """Recursively delete the directory *name*; a missing path is not an error."""
    try:
        shutil.rmtree(name)
    except FileNotFoundError:
        # Nothing to clean up.
        return
11
+
12
def run_script(script_name):
    """Run *script_name* with the current interpreter and return its stdout.

    A non-zero exit status fails the calling test, including the script's
    stderr. The output folders listed below are removed afterwards whether
    the script succeeded or not, so a failing script leaves nothing behind.
    """
    import sys  # local: only needed to locate the running interpreter

    try:
        result = subprocess.run(
            # sys.executable (not a bare "python") runs the script in the same
            # interpreter/virtualenv that is running the tests.
            [sys.executable, script_name],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        pytest.fail(f"Error running {script_name}:\n{e.stderr}")
    finally:
        for folder in ("shift", "folder_name", "2d_gaussian_shift_scale", "gaussian_shift"):
            try_remove(folder)
    return result.stdout
24
+
25
+
26
@pytest.mark.parametrize(
    "script_name",
    [
        "main-MLP.py",
        "main-CNN.py",
        "general-MLP.py",
        "main-adap-CNN.py",
        "main-adap-MLP.py",
        "main-var-CNN.py",
        "main-CNN1D.py",
    ],
)
def test_scripts(script_name):
    """Each example script must exist (relative to the cwd) and run cleanly."""
    if not os.path.exists(script_name):
        pytest.fail(f"Script {script_name} not found")
    assert run_script(script_name) is not None
@@ -0,0 +1,62 @@
1
+ import RRAEsTorch.config
2
+ from RRAEsTorch.AE_classes import RRAE_CNN
3
+ from RRAEsTorch.training_classes import RRAE_Trainor_class
4
+ import numpy.random as random
5
+ import torch
6
+
7
def test_save():
    """Smoke-test save_model/load_model: both trainors' models must still run.

    Only checks that saving/loading does not break the forward pass.
    """
    data = torch.tensor(random.normal(size=(1, 28, 28, 1)), dtype=torch.float32)

    trainor = RRAE_Trainor_class(
        data,
        RRAE_CNN,
        latent_size=100,
        channels=data.shape[0],
        width=data.shape[1],
        height=data.shape[2],
        pre_func_inp=lambda x: x * 2 / 17,
        pre_func_out=lambda x: x / 2,
        k_max=2,
    )

    trainor.save_model("test_")
    new_trainor = RRAE_Trainor_class()
    new_trainor.load_model("test_", erase=True)

    first_sample = data[..., 0:1]
    try:
        trainor.model(first_sample, k_max=2)
    except Exception as e:
        raise ValueError(f"Original trainor failed with following exception {e}")
    try:
        new_trainor.model(first_sample, k_max=2)
    except Exception as e:
        raise ValueError(f"Failed with following exception {e}")
35
+
36
+
37
def test_save_with_final_act():
    """A decoder ``final_activation`` (sigmoid) must survive save/load.

    After reloading, predictions must lie in [0, 1].
    """
    data = random.normal(size=(1, 28, 28, 1))
    data = torch.tensor(data, dtype=torch.float32)

    model_cls = RRAE_CNN

    trainor = RRAE_Trainor_class(
        data,
        model_cls,
        latent_size=100,
        channels=data.shape[0],
        width=data.shape[1],
        height=data.shape[2],
        kwargs_dec={"final_activation": torch.sigmoid},
        k_max=2,
    )

    trainor.save_model("test_")
    new_trainor = RRAE_Trainor_class()
    new_trainor.load_model("test_", erase=True)
    try:
        pr = new_trainor.model(data[..., 0:1], k_max=2)
    except Exception as e:
        raise ValueError(f"Failed with following exception {e}") from e
    # Kept outside the try so a range violation is reported as an assertion
    # failure instead of being re-wrapped as a ValueError.
    assert torch.max(pr) <= 1.0, "Final activation not working"
    assert torch.min(pr) >= 0.0, "Final activation not working"
@@ -0,0 +1,37 @@
1
+ from RRAEsTorch.utilities import stable_SVD
2
+ from torch.linalg import svd as normal_svd
3
+ import numpy.random as random
4
+ import pytest
5
+ import numpy as np
6
+ import torch
7
+
8
def stable_SVD_to_scalar(A):
    """Reduce ``stable_SVD(A)`` to a scalar so its gradient can be compared.

    Returns the norm of ``U @ diag(s) @ Vt``; any scalar that depends on all
    three factors would serve.
    """
    U, s, Vt = stable_SVD(A)
    scaled_U = U * s  # broadcasting s over columns == U @ diag(s)
    return torch.linalg.norm(scaled_U @ Vt)
11
+
12
def normal_svd_to_scalar(A):
    """Reference counterpart of ``stable_SVD_to_scalar`` using torch.linalg.svd.

    Returns the norm of the reduced-SVD reconstruction ``U @ diag(s) @ Vt``.
    """
    U, s, Vt = normal_svd(A, full_matrices=False)
    scaled_U = U * s  # broadcasting s over columns == U @ diag(s)
    return torch.linalg.norm(scaled_U @ Vt)
15
+
16
@pytest.mark.parametrize(
    "length, width",
    [(10, 10), (100, 10), (10, 100), (50000, 100), (1000, 1000), (100, 50000)],
)
def test_random_normal(length, width):
    """stable_SVD must match torch.linalg.svd in value and in gradient.

    Despite the test name, the input matrix is drawn from U(0, 1).
    """
    base = torch.tensor(
        random.uniform(low=0.0, high=1.0, size=(length, width)), dtype=torch.float32
    )

    # Evaluate each reduction on its own fresh leaf tensor, keeping the
    # value and the gradient it produced.
    outcomes = []
    for to_scalar in (stable_SVD_to_scalar, normal_svd_to_scalar):
        leaf = base.clone().detach().requires_grad_(True)
        value = to_scalar(leaf)
        value.backward()
        outcomes.append((value, leaf.grad))

    (stable_value, stable_grad), (normal_value, normal_grad) = outcomes
    assert torch.allclose(stable_value, normal_value, atol=1e-5, rtol=1e-5)
    assert torch.allclose(stable_grad, normal_grad, atol=1e-5, rtol=1e-5)
@@ -0,0 +1,56 @@
1
+ from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
2
+ from torchvision.ops import MLP
3
+ import pytest
4
+ import numpy as np
5
+ import math
6
+ import numpy.random as random
7
+ import torch
8
+
9
def test_vmap_wrapper():
    """vmap_wrap lets a vector-input MLP consume a rank-3 tensor.

    The MLP normally takes a single feature vector; ``vmap_wrap(MLP, -1, 2)``
    vectorizes it over the last axis twice.
    """
    data = torch.tensor(random.normal(size=(50, 60, 600)), dtype=torch.float32)
    model = vmap_wrap(MLP, -1, 2)(50, [64, 100])
    try:
        model(data)
    except ValueError:
        pytest.fail("Vmap wrapper is not working properly.")
21
+
22
def test_norm_wrapper():
    """Check norm_wrap's ``keep_normalized`` kwarg and its minmax/meanstd maps."""
    # keep_normalized=True must change the output compared with the default call.
    data = random.normal(size=(50,))
    data = torch.tensor(data, dtype=torch.float32)
    model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
    model = model_cls(50, [64, 100])
    assert not torch.allclose(model(data), model(data, keep_normalized=True)), (
        "The keep_normalized kwarg for norm wrapper is not behaving as expected."
    )

    # Testing minmax with known mins and maxs: data spans [-1, 1], so
    # norm(0.1) = (0.1 + 1) / 2 = 0.55 and inv_norm(0.1) = 0.1 * 2 - 1 = -0.8.
    # Compared with a tolerance because the maps run in float32.
    data = np.linspace(-1, 1, 100)
    data = torch.tensor(data, dtype=torch.float32)
    model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
    model = model_cls(50, [64, 100])
    assert math.isclose(0.55, model.norm_in.default(None, 0.1), abs_tol=1e-6), (
        "Something wrong with minmax wrapper."
    )
    assert math.isclose(-0.8, model.inv_norm_out.default(None, 0.1), abs_tol=1e-6), (
        "Something wrong with minmax wrapper."
    )

    # Testing meanstd with known mean (1) and std (2).
    data = random.normal(size=(50,))
    data = (data - np.mean(data)) / np.std(data)
    data = data * 2.0 + 1.0  # mean of 1 and std of 2
    data = torch.tensor(data, dtype=torch.float32)

    model_cls = norm_wrap(MLP, data, "meanstd", None, data, "meanstd", None)
    model = model_cls(50, [64, 100])
    assert math.isclose(2, model.norm_in.default(None, 5), rel_tol=1e-1, abs_tol=1e-1), (
        "Something wrong with norm wrapper."
    )
    assert math.isclose(7, model.inv_norm_out.default(None, 3), rel_tol=1e-1, abs_tol=1e-1), (
        "Something wrong with norm wrapper."
    )
@@ -0,0 +1 @@
1
+ from .trackers import *