RRAEsTorch 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RRAEsTorch/AE_base/AE_base.py +104 -0
- RRAEsTorch/AE_base/__init__.py +1 -0
- RRAEsTorch/AE_classes/AE_classes.py +636 -0
- RRAEsTorch/AE_classes/__init__.py +1 -0
- RRAEsTorch/__init__.py +1 -0
- RRAEsTorch/config.py +95 -0
- RRAEsTorch/tests/test_AE_classes_CNN.py +76 -0
- RRAEsTorch/tests/test_AE_classes_MLP.py +73 -0
- RRAEsTorch/tests/test_fitting_CNN.py +109 -0
- RRAEsTorch/tests/test_fitting_MLP.py +133 -0
- RRAEsTorch/tests/test_mains.py +34 -0
- RRAEsTorch/tests/test_save.py +62 -0
- RRAEsTorch/tests/test_stable_SVD.py +37 -0
- RRAEsTorch/tests/test_wrappers.py +56 -0
- RRAEsTorch/trackers/__init__.py +1 -0
- RRAEsTorch/trackers/trackers.py +245 -0
- RRAEsTorch/training_classes/__init__.py +5 -0
- RRAEsTorch/training_classes/training_classes.py +977 -0
- RRAEsTorch/utilities/__init__.py +1 -0
- RRAEsTorch/utilities/utilities.py +1562 -0
- RRAEsTorch/wrappers/__init__.py +1 -0
- RRAEsTorch/wrappers/wrappers.py +237 -0
- rraestorch-0.1.0.dist-info/METADATA +90 -0
- rraestorch-0.1.0.dist-info/RECORD +27 -0
- rraestorch-0.1.0.dist-info/WHEEL +4 -0
- rraestorch-0.1.0.dist-info/licenses/LICENSE +21 -0
- rraestorch-0.1.0.dist-info/licenses/LICENSE copy +21 -0
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from RRAEsTorch.utilities import MLP_with_linear
|
|
2
|
+
import jax.random as jrandom
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from torch.func import vmap
|
|
5
|
+
|
|
6
|
+
def set_autoencoder_base(cls):
    """Register ``cls`` as the value returned by ``get_autoencoder_base``.

    The value is stored in the module-level ``_AutoencoderBase`` name and
    replaces anything registered before.

    Args:
        cls: The autoencoder base to register.
    """
    globals()["_AutoencoderBase"] = cls
|
|
9
|
+
|
|
10
|
+
def get_autoencoder_base():
    """Return the registered autoencoder base, falling back to the default.

    Returns the value stored by ``set_autoencoder_base`` when it is truthy;
    otherwise returns ``_default_autoencoder`` itself (the factory function,
    not the class it builds).

    NOTE(review): ``set_autoencoder_base`` stores a class, while the fallback
    is a factory that has to be called to obtain a class -- confirm which of
    the two callers are meant to receive.
    """
    # ``_AutoencoderBase`` is never assigned at module level, so reading it
    # directly raised NameError until ``set_autoencoder_base`` had run and the
    # fallback below was unreachable. Look it up defensively instead.
    registered = globals().get("_AutoencoderBase")
    return registered or _default_autoencoder
|
|
13
|
+
|
|
14
|
+
def _default_autoencoder():
    """Build and return the default ``Autoencoder`` class.

    The class is defined inside this factory, so each call produces a new
    class object.
    """

    class Autoencoder(nn.Module):
        """Encoder -> latent-space operation -> decoder, as a ``torch.nn.Module``.

        By default both encoder and decoder are ``MLP_with_linear`` instances
        and the latent-space operation is the identity.
        """

        _encode: MLP_with_linear
        _decode: MLP_with_linear
        _perform_in_latent: callable
        map_latent: bool
        norm_funcs: list
        inv_norm_funcs: list
        count: int

        def __init__(
            self,
            in_size,
            latent_size,
            latent_size_after=None,
            _encode=None,
            _decode=None,
            map_latent=True,
            *,
            count=1,
            kwargs_enc=None,
            kwargs_dec=None,
            **kwargs,
        ):
            """Create the encoder/decoder pair.

            Args:
                in_size: ``in_size`` of the default encoder and ``out_size``
                    of the default decoder.
                latent_size: ``out_size`` of the default encoder.
                latent_size_after: ``in_size`` of the default decoder;
                    defaults to ``latent_size``.
                _encode: Encoder to use. When None, an ``MLP_with_linear``
                    is built with ``width_size=64`` and ``depth=1`` unless
                    overridden through ``kwargs_enc``.
                _decode: Decoder to use. When None, an ``MLP_with_linear``
                    is built with ``width_size=64`` and ``depth=6`` unless
                    overridden through ``kwargs_dec``.
                map_latent: If True, ``perform_in_latent`` vmaps the latent
                    operation over the last axis ``count`` times.
                count: Number of nested ``vmap`` wrappings.
                kwargs_enc: Extra keyword arguments for the default encoder.
                    The caller's dict is never modified.
                kwargs_dec: Extra keyword arguments for the default decoder.
                    The caller's dict is never modified.
                **kwargs: Accepted and ignored here.
            """
            super().__init__()

            if latent_size_after is None:
                latent_size_after = latent_size

            if _encode is None:
                # Merge into a fresh dict: caller-supplied keys win over the
                # defaults, and neither the caller's dict nor a shared default
                # object is mutated.
                enc_options = {"width_size": 64, "depth": 1, **(kwargs_enc or {})}
                self._encode = MLP_with_linear(
                    in_size=in_size,
                    out_size=latent_size,
                    **enc_options,
                )
            else:
                self._encode = _encode

            # Only install the identity latent op when nothing (e.g. a
            # subclass attribute) already provides ``_perform_in_latent``.
            if not hasattr(self, "_perform_in_latent"):
                self._perform_in_latent = lambda x, *args, **kwargs: x

            if _decode is None:
                dec_options = {"width_size": 64, "depth": 6, **(kwargs_dec or {})}
                self._decode = MLP_with_linear(
                    in_size=latent_size_after,
                    out_size=in_size,
                    **dec_options,
                )
            else:
                self._decode = _decode

            self.count = count
            self.map_latent = map_latent
            self.inv_norm_funcs = ["decode"]
            self.norm_funcs = ["encode", "latent"]

        def encode(self, x, *args, **kwargs):
            """Run the encoder on ``x``, forwarding any extra arguments."""
            return self._encode(x, *args, **kwargs)

        def decode(self, x, *args, **kwargs):
            """Run the decoder on ``x``, forwarding any extra arguments."""
            return self._decode(x, *args, **kwargs)

        def perform_in_latent(self, y, *args, **kwargs):
            """Apply the latent-space operation to ``y``.

            When ``map_latent`` is true, the operation is wrapped in
            ``count`` nested ``vmap`` calls with ``in_dims=-1, out_dims=-1``,
            so it is applied separately to each slice along the trailing
            ``count`` axes of ``y``. Otherwise it is called once on ``y``.
            Extra arguments are closed over (not mapped).
            """
            if not self.map_latent:
                return self._perform_in_latent(y, *args, **kwargs)

            def _apply(x):
                return self._perform_in_latent(x, *args, **kwargs)

            mapped = _apply
            for _ in range(self.count):
                mapped = vmap(mapped, in_dims=-1, out_dims=-1)
            return mapped(y)

        def forward(self, x, *args, **kwargs):
            """Encode, apply the latent operation, then decode.

            Extra arguments are passed only to ``perform_in_latent``.
            """
            return self.decode(self.perform_in_latent(self.encode(x), *args, **kwargs))

        def latent(self, x, *args, **kwargs):
            """Return the latent representation after the latent operation."""
            return self.perform_in_latent(self.encode(x), *args, **kwargs)

    return Autoencoder
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .AE_base import *
|