RRAEsTorch-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,109 @@
+ from RRAEsTorch.utilities import MLP_with_linear
+ import torch.nn as nn
+ from torch.func import vmap
+
+ # Registry slot for a user-supplied base class; stays None until set.
+ _AutoencoderBase = None
+
+ def set_autoencoder_base(cls):
+     global _AutoencoderBase
+     _AutoencoderBase = cls
+
+ def get_autoencoder_base():
+     # Return the registered base class, or build the default one;
+     # calling _default_autoencoder() makes both branches yield a class.
+     return _AutoencoderBase or _default_autoencoder()
+
+ def _default_autoencoder():
+     class Autoencoder(nn.Module):
+         _encode: MLP_with_linear
+         _decode: MLP_with_linear
+         _perform_in_latent: callable
+         map_latent: bool
+         norm_funcs: list
+         inv_norm_funcs: list
+         count: int
+
+         def __init__(
+             self,
+             in_size,
+             latent_size,
+             latent_size_after=None,
+             _encode=None,
+             _decode=None,
+             map_latent=True,
+             *,
+             count=1,
+             kwargs_enc=None,
+             kwargs_dec=None,
+             **kwargs,
+         ):
+             super().__init__()
+
+             # None defaults avoid a shared mutable dict; copy so caller dicts aren't mutated below.
+             kwargs_enc = {} if kwargs_enc is None else dict(kwargs_enc)
+             kwargs_dec = {} if kwargs_dec is None else dict(kwargs_dec)
+
+             if latent_size_after is None:
+                 latent_size_after = latent_size
+
+             if _encode is None:
+                 # Default encoder: a shallow MLP ending in a linear layer.
+                 kwargs_enc.setdefault("width_size", 64)
+                 kwargs_enc.setdefault("depth", 1)
+                 self._encode = MLP_with_linear(
+                     in_size=in_size,
+                     out_size=latent_size,
+                     **kwargs_enc,
+                 )
+             else:
+                 self._encode = _encode
+
+             # Identity map unless a subclass defines _perform_in_latent.
+             if not hasattr(self, "_perform_in_latent"):
+                 self._perform_in_latent = lambda x, *args, **kwargs: x
+
+             if _decode is None:
+                 # Default decoder: a deeper MLP mapping the latent back to in_size.
+                 kwargs_dec.setdefault("width_size", 64)
+                 kwargs_dec.setdefault("depth", 6)
+                 self._decode = MLP_with_linear(
+                     in_size=latent_size_after,
+                     out_size=in_size,
+                     **kwargs_dec,
+                 )
+             else:
+                 self._decode = _decode
+
+             self.count = count
+             self.map_latent = map_latent
+             self.inv_norm_funcs = ["decode"]
+             self.norm_funcs = ["encode", "latent"]
+
+         def encode(self, x, *args, **kwargs):
+             return self._encode(x, *args, **kwargs)
+
+         def decode(self, x, *args, **kwargs):
+             return self._decode(x, *args, **kwargs)
+
+         def perform_in_latent(self, y, *args, **kwargs):
+             if self.map_latent:
+                 new_perform_in_latent = lambda x: self._perform_in_latent(
+                     x, *args, **kwargs
+                 )
+                 # Nesting vmap count times maps the op over the last count axes.
+                 for _ in range(self.count):
+                     new_perform_in_latent = vmap(
+                         new_perform_in_latent, in_dims=-1, out_dims=-1
+                     )
+                 return new_perform_in_latent(y)
+             return self._perform_in_latent(y, *args, **kwargs)
+
+         def forward(self, x, *args, **kwargs):
+             # Full pass: encode -> latent-space op -> decode.
+             return self.decode(self.perform_in_latent(self.encode(x), *args, **kwargs))
+
+         def latent(self, x, *args, **kwargs):
+             return self.perform_in_latent(self.encode(x), *args, **kwargs)
+
+     return Autoencoder
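Aside on the nested vmap in perform_in_latent above: each wrap of torch.func.vmap with in_dims=-1 and out_dims=-1 maps the wrapped function over one more trailing axis, so count controls how many trailing axes are treated as batch-like before _perform_in_latent sees the data. A standalone sketch, separate from the package, with arbitrary shapes:

    import torch
    from torch.func import vmap

    f = lambda v: v.sum()                 # consumes a 1-D vector, returns a scalar
    g = vmap(f, in_dims=-1, out_dims=-1)  # maps f over the last axis
    h = vmap(g, in_dims=-1, out_dims=-1)  # ...and over the next-to-last as well

    print(g(torch.randn(3, 5)).shape)     # torch.Size([5])
    print(h(torch.randn(3, 5, 7)).shape)  # torch.Size([5, 7])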
@@ -0,0 +1 @@
+ from .AE_base import *
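For orientation, a hypothetical usage sketch. It assumes the second hunk above is the package __init__.py (so the public names are re-exported at the top level) and that MLP_with_linear, whose definition this diff does not show, behaves like a standard MLP acting on the last axis of its input:

    import torch
    from RRAEsTorch import get_autoencoder_base

    Autoencoder = get_autoencoder_base()  # the default class unless one was registered
    model = Autoencoder(in_size=128, latent_size=8)

    x = torch.randn(32, 128)              # (batch, features); assumed layout
    z = model.latent(x)                   # encode, then the (default identity) latent op
    x_hat = model(x)                      # encode -> latent op -> decode

Calling set_autoencoder_base(MyClass) beforehand would make get_autoencoder_base() return MyClass instead of the default class.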