RRAEsTorch 0.1.6__tar.gz → 0.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39):
  1. {rraestorch-0.1.6 → rraestorch-0.1.7}/.gitignore +4 -0
  2. {rraestorch-0.1.6 → rraestorch-0.1.7}/PKG-INFO +1 -1
  3. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/AE_classes/AE_classes.py +14 -12
  4. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/tests/test_AE_classes_CNN.py +20 -26
  5. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/tests/test_AE_classes_MLP.py +20 -28
  6. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/tests/test_fitting_CNN.py +14 -14
  7. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/tests/test_fitting_MLP.py +11 -13
  8. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/tests/test_save.py +11 -11
  9. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/training_classes/training_classes.py +55 -115
  10. {rraestorch-0.1.6 → rraestorch-0.1.7}/main-CNN.py +4 -7
  11. {rraestorch-0.1.6 → rraestorch-0.1.7}/main-CNN1D.py +7 -7
  12. {rraestorch-0.1.6 → rraestorch-0.1.7}/main-CNN3D.py +10 -15
  13. {rraestorch-0.1.6 → rraestorch-0.1.7}/main-MLP.py +2 -2
  14. {rraestorch-0.1.6 → rraestorch-0.1.7}/main-adap-CNN.py +4 -4
  15. {rraestorch-0.1.6 → rraestorch-0.1.7}/main-adap-MLP.py +2 -2
  16. {rraestorch-0.1.6 → rraestorch-0.1.7}/main-var-CNN.py +6 -6
  17. {rraestorch-0.1.6 → rraestorch-0.1.7}/main-var-CNN1D.py +9 -9
  18. {rraestorch-0.1.6 → rraestorch-0.1.7}/pyproject.toml +1 -1
  19. rraestorch-0.1.6/RRAEsTorch/tests/test_wrappers.py +0 -56
  20. rraestorch-0.1.6/RRAEsTorch/utilities/utilities.py +0 -1561
  21. rraestorch-0.1.6/RRAEsTorch/wrappers/__init__.py +0 -1
  22. rraestorch-0.1.6/RRAEsTorch/wrappers/wrappers.py +0 -237
  23. {rraestorch-0.1.6 → rraestorch-0.1.7}/.github/workflows/python-app.yml +0 -0
  24. {rraestorch-0.1.6 → rraestorch-0.1.7}/LICENSE +0 -0
  25. {rraestorch-0.1.6 → rraestorch-0.1.7}/README copy.md +0 -0
  26. {rraestorch-0.1.6 → rraestorch-0.1.7}/README.md +0 -0
  27. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/AE_base/AE_base.py +0 -0
  28. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/AE_base/__init__.py +0 -0
  29. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/AE_classes/__init__.py +0 -0
  30. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/__init__.py +0 -0
  31. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/config.py +0 -0
  32. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/tests/test_mains.py +0 -0
  33. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/tests/test_stable_SVD.py +0 -0
  34. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/trackers/__init__.py +0 -0
  35. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/trackers/trackers.py +0 -0
  36. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/training_classes/__init__.py +0 -0
  37. {rraestorch-0.1.6 → rraestorch-0.1.7}/RRAEsTorch/utilities/__init__.py +0 -0
  38. {rraestorch-0.1.6 → rraestorch-0.1.7}/general-MLP.py +0 -0
  39. {rraestorch-0.1.6 → rraestorch-0.1.7}/setup.cfg +0 -0
@@ -5,6 +5,10 @@ __pycache__/
5
5
 
6
6
  # C extensions
7
7
  *.so
8
+ *.mat
9
+ utilities.py
10
+ main_rugosity.py
11
+ nomacro
8
12
 
9
13
  # Distribution / packaging
10
14
  .Python
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: RRAEsTorch
3
- Version: 0.1.6
3
+ Version: 0.1.7
4
4
  Summary: A repo for RRAEs in PyTorch.
5
5
  Author-email: Jad Mounayer <jad.mounayer@outlook.com>
6
6
  License: MIT
@@ -6,7 +6,6 @@ from RRAEsTorch.utilities import (
6
6
  MLP_with_CNN3D_trans,
7
7
  stable_SVD,
8
8
  )
9
- from RRAEsTorch.wrappers import vmap_wrap
10
9
  import warnings
11
10
  from torch.nn import Linear
12
11
  from RRAEsTorch.AE_base import get_autoencoder_base
@@ -42,6 +41,8 @@ def latent_func_strong_RRAE(
42
41
  y_approx : jnp.array
43
42
  The latent space after the truncation.
44
43
  """
44
+ y = y.T # to get the number of samples in the last dimension, as expected by the SVD function
45
+
45
46
  if apply_basis is not None:
46
47
  if get_basis_coeffs:
47
48
  return apply_basis, apply_basis.T @ y
@@ -51,7 +52,7 @@ def latent_func_strong_RRAE(
51
52
  if get_right_sing:
52
53
  raise ValueError("Can not find right singular vector when projecting on basis")
53
54
  return apply_basis.T @ y
54
- return apply_basis @ apply_basis.T @ y
55
+ return (apply_basis @ apply_basis.T @ y).T
55
56
 
56
57
  k_max = -1 if k_max is None else k_max
57
58
 
@@ -97,7 +98,7 @@ def latent_func_strong_RRAE(
97
98
  sigs = None
98
99
  if ret:
99
100
  return u_now, coeffs, sigs
100
- return y_approx
101
+ return y_approx.T
101
102
 
102
103
  def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
103
104
  apply_basis = kwargs.get("apply_basis")
@@ -111,6 +112,9 @@ def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=F
111
112
  return latent_func_strong_RRAE(self, y, k_max, apply_basis=apply_basis, **kwargs)
112
113
 
113
114
  basis, coeffs = latent_func_strong_RRAE(self, y, k_max=k_max, get_basis_coeffs=True, apply_basis=apply_basis)
115
+
116
+ coeffs = coeffs.T # to get the number of samples as first dim
117
+
114
118
  if self.typ == "eye":
115
119
  mean = coeffs
116
120
  elif self.typ == "trainable":
@@ -121,7 +125,7 @@ def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=F
121
125
  logvar = self.lin_logvar(coeffs)
122
126
 
123
127
  if return_dist:
124
- return mean, logvar
128
+ return mean.T, logvar.T
125
129
 
126
130
  std = torch.exp(0.5 * logvar)
127
131
  if epsilon is not None:
@@ -132,8 +136,8 @@ def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=F
132
136
  z = mean
133
137
 
134
138
  if return_lat_dist:
135
- return basis @ z, mean, logvar
136
- return basis @ z
139
+ return z @ basis.T, mean.T, logvar.T
140
+ return z @ basis.T
137
141
 
138
142
 
139
143
  class RRAE_MLP(get_autoencoder_base()):
@@ -506,9 +510,8 @@ class VRRAE_CNN(CNN_Autoencoder):
506
510
  count=count,
507
511
  **kwargs,
508
512
  )
509
- v_Linear = vmap_wrap(Linear, -1, count=count)
510
- self.lin_mean = v_Linear(k_max, k_max)
511
- self.lin_logvar = v_Linear(k_max, k_max)
513
+ self.lin_mean = Linear(k_max, k_max)
514
+ self.lin_logvar = Linear(k_max, k_max)
512
515
  self.typ = typ
513
516
 
514
517
  def _perform_in_latent(self, y, *args, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
@@ -620,9 +623,8 @@ class VRRAE_CNN1D(CNN1D_Autoencoder):
620
623
  **kwargs,
621
624
  )
622
625
 
623
- v_Linear = vmap_wrap(Linear, -1, count=count)
624
- self.lin_mean = v_Linear(k_max, k_max,)
625
- self.lin_logvar = v_Linear(k_max, k_max)
626
+ self.lin_mean = Linear(k_max, k_max,)
627
+ self.lin_logvar = Linear(k_max, k_max)
626
628
  self.typ = typ
627
629
 
628
630
 
@@ -6,18 +6,12 @@ from RRAEsTorch.AE_classes import (
6
6
  IRMAE_CNN,
7
7
  LoRAE_CNN,
8
8
  )
9
- from RRAEsTorch.wrappers import vmap_wrap
10
9
  import numpy.random as random
11
10
  import numpy as np
12
11
  import torch
13
12
 
14
13
  methods = ["encode", "decode"]
15
14
 
16
- v_RRAE_CNN = vmap_wrap(RRAE_CNN, -1, 1, methods)
17
- v_Vanilla_AE_CNN = vmap_wrap(Vanilla_AE_CNN, -1, 1, methods)
18
- v_IRMAE_CNN = vmap_wrap(IRMAE_CNN, -1, 1, methods)
19
- v_LoRAE_CNN = vmap_wrap(LoRAE_CNN, -1, 1, methods)
20
-
21
15
  @pytest.mark.parametrize("width", (10, 17, 149))
22
16
  @pytest.mark.parametrize("height", (20,))
23
17
  @pytest.mark.parametrize("latent", (200,))
@@ -26,51 +20,51 @@ v_LoRAE_CNN = vmap_wrap(LoRAE_CNN, -1, 1, methods)
26
20
  @pytest.mark.parametrize("num_samples", (10, 100))
27
21
  class Test_AEs_shapes:
28
22
  def test_RRAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
29
- x = random.normal(size=(channels, width, height, num_samples))
23
+ x = random.normal(size=(num_samples, channels, width, height))
30
24
  x = torch.tensor(x, dtype=torch.float32)
31
25
  kwargs = {"kwargs_dec": {"stride": 2}}
32
- model = v_RRAE_CNN(
33
- x.shape[0], x.shape[1], x.shape[2], latent, num_modes, **kwargs
26
+ model = RRAE_CNN(
27
+ x.shape[1], x.shape[2], x.shape[3], latent, num_modes, **kwargs
34
28
  )
35
29
  y = model.encode(x)
36
- assert y.shape == (latent, num_samples)
30
+ assert y.shape == (num_samples, latent)
37
31
  y = model.latent(x, k_max=num_modes)
38
32
  _, sing_vals, _ = torch.linalg.svd(y, full_matrices=False)
39
33
  assert sing_vals[num_modes + 1] < 1e-5
40
- assert y.shape == (latent, num_samples)
41
- assert model.decode(y).shape == (channels, width, height, num_samples)
34
+ assert y.shape == (num_samples, latent)
35
+ assert model.decode(y).shape == (num_samples, channels, width, height)
42
36
 
43
37
  def test_Vanilla_CNN(self, latent, num_modes, width, height, channels, num_samples):
44
- x = random.normal(size=(channels, width, height, num_samples))
38
+ x = random.normal(size=(num_samples, channels, width, height))
45
39
  x = torch.tensor(x, dtype=torch.float32)
46
40
  kwargs = {"kwargs_dec": {"stride": 2}}
47
- model = v_Vanilla_AE_CNN(
48
- x.shape[0], x.shape[1], x.shape[2], latent, **kwargs
41
+ model = Vanilla_AE_CNN(
42
+ x.shape[1], x.shape[2], x.shape[3], latent, **kwargs
49
43
  )
50
44
  y = model.encode(x)
51
- assert y.shape == (latent, num_samples)
45
+ assert y.shape == (num_samples, latent)
52
46
  x = model.decode(y)
53
- assert x.shape == (channels, width, height, num_samples)
47
+ assert x.shape == (num_samples, channels, width, height)
54
48
 
55
49
 
56
50
  def test_IRMAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
57
- x = random.normal(size=(channels, width, height, num_samples))
51
+ x = random.normal(size=(num_samples, channels, width, height))
58
52
  x = torch.tensor(x, dtype=torch.float32)
59
- model = v_IRMAE_CNN(
60
- x.shape[0], x.shape[1], x.shape[2], latent, linear_l=2
53
+ model = IRMAE_CNN(
54
+ x.shape[1], x.shape[2], x.shape[3], latent, linear_l=2
61
55
  )
62
56
  y = model.encode(x)
63
- assert y.shape == (latent, num_samples)
57
+ assert y.shape == (num_samples, latent)
64
58
  assert len(model._encode.layers[-1].layers_l) == 2
65
59
  x = model.decode(y)
66
- assert x.shape == (channels, width, height, num_samples)
60
+ assert x.shape == (num_samples, channels, width, height)
67
61
 
68
62
  def test_LoRAE_CNN(self, latent, num_modes, width, height, channels, num_samples):
69
- x = random.normal(size=(channels, width, height, num_samples))
63
+ x = random.normal(size=(num_samples, channels, width, height))
70
64
  x = torch.tensor(x, dtype=torch.float32)
71
- model = v_LoRAE_CNN(x.shape[0], x.shape[1], x.shape[2], latent)
65
+ model = LoRAE_CNN(x.shape[1], x.shape[2], x.shape[3], latent)
72
66
  y = model.encode(x)
73
- assert y.shape == (latent, num_samples)
67
+ assert y.shape == (num_samples, latent)
74
68
  assert len(model._encode.layers[-1].layers_l) == 1
75
69
  x = model.decode(y)
76
- assert x.shape == (channels, width, height, num_samples)
70
+ assert x.shape == (num_samples, channels, width, height)
@@ -6,67 +6,59 @@ from RRAEsTorch.AE_classes import (
6
6
  IRMAE_MLP,
7
7
  LoRAE_MLP,
8
8
  )
9
- from RRAEsTorch.wrappers import vmap_wrap
10
9
  import numpy.random as random
11
10
  import numpy as np
12
11
  import torch
13
12
 
14
- methods = ["encode", "decode"]
15
-
16
- v_RRAE_MLP = vmap_wrap(RRAE_MLP, -1, 1, methods)
17
- v_Vanilla_AE_MLP = vmap_wrap(Vanilla_AE_MLP, -1, 1, methods)
18
- v_IRMAE_MLP = vmap_wrap(IRMAE_MLP, -1, 1, methods)
19
- v_LoRAE_MLP = vmap_wrap(LoRAE_MLP, -1, 1, methods)
20
-
21
13
  @pytest.mark.parametrize("dim_D", (10, 15, 50))
22
14
  @pytest.mark.parametrize("latent", (200, 400, 800))
23
15
  @pytest.mark.parametrize("num_modes", (1, 2, 6))
24
16
  class Test_AEs_shapes:
25
17
  def test_RRAE_MLP(self, latent, num_modes, dim_D):
26
- x = random.normal(size=(500, dim_D))
18
+ x = random.normal(size=(dim_D, 500))
27
19
  x = torch.tensor(x, dtype=torch.float32)
28
- model = v_RRAE_MLP(x.shape[0], latent, num_modes)
20
+ model = RRAE_MLP(x.shape[1], latent, num_modes)
29
21
  y = model.encode(x)
30
- assert y.shape == (latent, dim_D)
22
+ assert y.shape == (dim_D, latent)
31
23
  y = model.perform_in_latent(y, k_max=num_modes)
32
24
  _, sing_vals, _ = torch.linalg.svd(y, full_matrices=False)
33
- assert sing_vals[num_modes + 1] < 1e-5
34
- assert y.shape == (latent, dim_D)
35
- assert model.decode(y).shape == (500, dim_D)
25
+ assert sing_vals[num_modes + 1] < 1e-4
26
+ assert y.shape == (dim_D, latent)
27
+ assert model.decode(y).shape == (dim_D, 500)
36
28
 
37
29
  def test_Vanilla_MLP(self, latent, num_modes, dim_D):
38
- x = random.normal(size=(500, dim_D))
30
+ x = random.normal(size=(dim_D, 500))
39
31
  x = torch.tensor(x, dtype=torch.float32)
40
- model = v_Vanilla_AE_MLP(x.shape[0], latent)
32
+ model = Vanilla_AE_MLP(x.shape[1], latent)
41
33
  y = model.encode(x)
42
- assert y.shape == (latent, dim_D)
34
+ assert y.shape == (dim_D, latent)
43
35
  x = model.decode(y)
44
- assert x.shape == (500, dim_D)
36
+ assert x.shape == (dim_D, 500)
45
37
 
46
38
  def test_IRMAE_MLP(self, latent, num_modes, dim_D):
47
- x = random.normal(size=(500, dim_D))
39
+ x = random.normal(size=(dim_D, 500))
48
40
  x = torch.tensor(x, dtype=torch.float32)
49
- model = v_IRMAE_MLP(x.shape[0], latent, linear_l=2)
41
+ model = IRMAE_MLP(x.shape[1], latent, linear_l=2)
50
42
  y = model.encode(x)
51
- assert y.shape == (latent, dim_D)
43
+ assert y.shape == (dim_D, latent)
52
44
  assert len(model._encode.layers_l) == 2
53
45
  x = model.decode(y)
54
- assert x.shape == (500, dim_D)
46
+ assert x.shape == (dim_D, 500)
55
47
 
56
48
  def test_LoRAE_MLP(self, latent, num_modes, dim_D):
57
- x = random.normal(size=(500, dim_D))
49
+ x = random.normal(size=(dim_D, 500))
58
50
  x = torch.tensor(x, dtype=torch.float32)
59
- model = v_LoRAE_MLP(x.shape[0], latent)
51
+ model = LoRAE_MLP(x.shape[1], latent)
60
52
  y = model.encode(x)
61
- assert y.shape == (latent, dim_D)
53
+ assert y.shape == (dim_D, latent)
62
54
  assert len(model._encode.layers_l) == 1
63
55
  x = model.decode(y)
64
- assert x.shape == (500, dim_D)
56
+ assert x.shape == (dim_D, 500)
65
57
 
66
58
  def test_getting_SVD_coeffs():
67
- data = random.uniform(size=(500, 15))
59
+ data = random.uniform(size=(15, 500))
68
60
  data = torch.tensor(data, dtype=torch.float32)
69
- model_s = v_RRAE_MLP(data.shape[0], 200, 3)
61
+ model_s = RRAE_MLP(data.shape[1], 200, 3)
70
62
  basis, coeffs = model_s.latent(data, k_max=3, get_basis_coeffs=True)
71
63
  assert basis.shape == (200, 3)
72
64
  assert coeffs.shape == (3, 15)
@@ -13,8 +13,8 @@ import torch
13
13
  @pytest.mark.parametrize(
14
14
  "model_cls, sh, lf",
15
15
  [
16
- (Vanilla_AE_CNN, (1, 2, 2, 10), "default"),
17
- (LoRAE_CNN, (6, 16, 16, 10), "nuc"),
16
+ (Vanilla_AE_CNN, (10, 1, 2, 2), "default"),
17
+ (LoRAE_CNN, (10, 6, 16, 16), "nuc"),
18
18
  ],
19
19
  )
20
20
  def test_AE_fitting(model_cls, sh, lf):
@@ -24,10 +24,10 @@ def test_AE_fitting(model_cls, sh, lf):
24
24
  x,
25
25
  model_cls,
26
26
  latent_size=100,
27
- channels=x.shape[0],
28
- width=x.shape[1],
29
- height=x.shape[2],
30
- samples=x.shape[-1], # Only for weak
27
+ channels=x.shape[1],
28
+ width=x.shape[2],
29
+ height=x.shape[3],
30
+ samples=x.shape[0], # Only for weak
31
31
  k_max=2,
32
32
  )
33
33
  kwargs = {
@@ -52,16 +52,16 @@ def test_AE_fitting(model_cls, sh, lf):
52
52
  def test_IRMAE_fitting():
53
53
  model_cls = IRMAE_CNN
54
54
  lf = "default"
55
- sh = (3, 12, 12, 10)
55
+ sh = (10, 3, 12, 12)
56
56
  x = random.normal(size=sh)
57
57
  x = torch.tensor(x, dtype=torch.float32)
58
58
  trainor = AE_Trainor_class(
59
59
  x,
60
60
  model_cls,
61
61
  latent_size=100,
62
- channels=x.shape[0],
63
- width=x.shape[1],
64
- height=x.shape[2],
62
+ channels=x.shape[1],
63
+ width=x.shape[2],
64
+ height=x.shape[3],
65
65
  k_max=2,
66
66
  linear_l=4,
67
67
  )
@@ -77,7 +77,7 @@ def test_IRMAE_fitting():
77
77
  assert False, f"Fitting failed with the following exception {repr(e)}"
78
78
 
79
79
  def test_RRAE_fitting():
80
- sh = (1, 20, 20, 10)
80
+ sh = (10, 1, 20, 20)
81
81
  model_cls = RRAE_CNN
82
82
  x = random.normal(size=sh)
83
83
  x = torch.tensor(x, dtype=torch.float32)
@@ -85,9 +85,9 @@ def test_RRAE_fitting():
85
85
  x,
86
86
  model_cls,
87
87
  latent_size=100,
88
- channels=x.shape[0],
89
- width=x.shape[1],
90
- height=x.shape[2],
88
+ channels=x.shape[1],
89
+ width=x.shape[2],
90
+ height=x.shape[3],
91
91
  k_max=2,
92
92
  )
93
93
  training_kwargs = {
@@ -14,8 +14,8 @@ import torch
14
14
  @pytest.mark.parametrize(
15
15
  "model_cls, sh, lf",
16
16
  [
17
- (Vanilla_AE_MLP, (500, 10), "default"),
18
- (LoRAE_MLP, (500, 10), "nuc"),
17
+ (Vanilla_AE_MLP, (10, 500), "default"),
18
+ (LoRAE_MLP, (10, 500), "nuc"),
19
19
  ],
20
20
  )
21
21
  def test_fitting(model_cls, sh, lf):
@@ -24,9 +24,8 @@ def test_fitting(model_cls, sh, lf):
24
24
  trainor = AE_Trainor_class(
25
25
  x,
26
26
  model_cls,
27
- in_size=x.shape[0],
28
- data_size=x.shape[-1],
29
- samples=x.shape[-1], # Only for weak
27
+ in_size=x.shape[1],
28
+ samples=x.shape[0], # Only for weak
30
29
  norm_in="meanstd",
31
30
  norm_out="minmax",
32
31
  out_train=x,
@@ -51,14 +50,14 @@ def test_fitting(model_cls, sh, lf):
51
50
 
52
51
 
53
52
  def test_RRAE_fitting():
54
- sh = (500, 10)
53
+ sh = (10, 500)
55
54
  model_cls = RRAE_MLP
56
55
  x = random.normal(size=sh)
57
56
  x = torch.tensor(x, dtype=torch.float32)
58
57
  trainor = RRAE_Trainor_class(
59
58
  x,
60
59
  model_cls,
61
- in_size=x.shape[0],
60
+ in_size=x.shape[1],
62
61
  latent_size=2000,
63
62
  k_max=2,
64
63
  )
@@ -83,14 +82,13 @@ def test_RRAE_fitting():
83
82
  def test_IRMAE_fitting():
84
83
  model_cls = IRMAE_MLP
85
84
  lf = "default"
86
- sh = (500, 10)
85
+ sh = (10, 500)
87
86
  x = random.normal(size=sh)
88
87
  x = torch.tensor(x, dtype=torch.float32)
89
88
  trainor = AE_Trainor_class(
90
89
  x,
91
90
  model_cls,
92
- in_size=x.shape[0],
93
- data_size=x.shape[-1],
91
+ in_size=x.shape[1],
94
92
  latent_size=2000,
95
93
  k_max=2,
96
94
  linear_l=4,
@@ -107,15 +105,15 @@ def test_IRMAE_fitting():
107
105
  assert False, f"Fitting failed with the following exception {repr(e)}"
108
106
 
109
107
  def test_fitting():
110
- sh = (50, 100)
108
+ sh = (100, 50)
111
109
  model_cls = MLP
112
110
  x = random.normal(size=sh)
113
111
  x = torch.tensor(x, dtype=torch.float32)
114
112
  trainor = Trainor_class(
115
113
  x,
116
114
  model_cls,
117
- in_channels=x.shape[0],
118
- hidden_channels=[100, x.shape[0]]
115
+ in_channels=x.shape[1],
116
+ hidden_channels=[100, x.shape[1]]
119
117
  )
120
118
  training_kwargs = {
121
119
  "step_st": [2],
@@ -5,7 +5,7 @@ import numpy.random as random
5
5
  import torch
6
6
 
7
7
  def test_save(): # Only to test if saving/loading is causing a problem
8
- data = random.normal(size=(1, 28, 28, 1))
8
+ data = random.normal(size=(1, 1, 28, 28))
9
9
  data = torch.tensor(data, dtype=torch.float32)
10
10
  model_cls = RRAE_CNN
11
11
 
@@ -13,9 +13,9 @@ def test_save(): # Only to test if saving/loading is causing a problem
13
13
  data,
14
14
  model_cls,
15
15
  latent_size=100,
16
- channels=data.shape[0],
17
- width=data.shape[1],
18
- height=data.shape[2],
16
+ channels=data.shape[1],
17
+ width=data.shape[2],
18
+ height=data.shape[3],
19
19
  pre_func_inp=lambda x: x * 2 / 17,
20
20
  pre_func_out=lambda x: x / 2,
21
21
  k_max=2,
@@ -25,17 +25,17 @@ def test_save(): # Only to test if saving/loading is causing a problem
25
25
  new_trainor = RRAE_Trainor_class()
26
26
  new_trainor.load_model("test_", erase=True)
27
27
  try:
28
- pr = trainor.model(data[..., 0:1], k_max=2)
28
+ pr = trainor.model(data[0:1], k_max=2)
29
29
  except Exception as e:
30
30
  raise ValueError(f"Original trainor failed with following exception {e}")
31
31
  try:
32
- pr = new_trainor.model(data[..., 0:1], k_max=2)
32
+ pr = new_trainor.model(data[0:1], k_max=2)
33
33
  except Exception as e:
34
34
  raise ValueError(f"Failed with following exception {e}")
35
35
 
36
36
 
37
37
  def test_save_with_final_act():
38
- data = random.normal(size=(1, 28, 28, 1))
38
+ data = random.normal(size=(1, 1, 28, 28))
39
39
  data = torch.tensor(data, dtype=torch.float32)
40
40
 
41
41
  model_cls = RRAE_CNN
@@ -44,9 +44,9 @@ def test_save_with_final_act():
44
44
  data,
45
45
  model_cls,
46
46
  latent_size=100,
47
- channels=data.shape[0],
48
- width=data.shape[1],
49
- height=data.shape[2],
47
+ channels=data.shape[1],
48
+ width=data.shape[2],
49
+ height=data.shape[3],
50
50
  kwargs_dec={"final_activation": torch.sigmoid},
51
51
  k_max=2,
52
52
  )
@@ -55,7 +55,7 @@ def test_save_with_final_act():
55
55
  new_trainor = RRAE_Trainor_class()
56
56
  new_trainor.load_model("test_", erase=True)
57
57
  try:
58
- pr = new_trainor.model(data[..., 0:1], k_max=2)
58
+ pr = new_trainor.model(data[0:1], k_max=2)
59
59
  assert torch.max(pr) <= 1.0, "Final activation not working"
60
60
  assert torch.min(pr) >= 0.0, "Final activation not working"
61
61
  except Exception as e: