RRAEsTorch 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RRAEsTorch/AE_base/AE_base.py +104 -0
- RRAEsTorch/AE_base/__init__.py +1 -0
- RRAEsTorch/AE_classes/AE_classes.py +636 -0
- RRAEsTorch/AE_classes/__init__.py +1 -0
- RRAEsTorch/__init__.py +1 -0
- RRAEsTorch/config.py +95 -0
- RRAEsTorch/tests/test_AE_classes_CNN.py +76 -0
- RRAEsTorch/tests/test_AE_classes_MLP.py +73 -0
- RRAEsTorch/tests/test_fitting_CNN.py +109 -0
- RRAEsTorch/tests/test_fitting_MLP.py +133 -0
- RRAEsTorch/tests/test_mains.py +34 -0
- RRAEsTorch/tests/test_save.py +62 -0
- RRAEsTorch/tests/test_stable_SVD.py +37 -0
- RRAEsTorch/tests/test_wrappers.py +56 -0
- RRAEsTorch/trackers/__init__.py +1 -0
- RRAEsTorch/trackers/trackers.py +245 -0
- RRAEsTorch/training_classes/__init__.py +5 -0
- RRAEsTorch/training_classes/training_classes.py +977 -0
- RRAEsTorch/utilities/__init__.py +1 -0
- RRAEsTorch/utilities/utilities.py +1562 -0
- RRAEsTorch/wrappers/__init__.py +1 -0
- RRAEsTorch/wrappers/wrappers.py +237 -0
- rraestorch-0.1.0.dist-info/METADATA +90 -0
- rraestorch-0.1.0.dist-info/RECORD +27 -0
- rraestorch-0.1.0.dist-info/WHEEL +4 -0
- rraestorch-0.1.0.dist-info/licenses/LICENSE +21 -0
- rraestorch-0.1.0.dist-info/licenses/LICENSE copy +21 -0
|
@@ -0,0 +1,636 @@
|
|
|
1
|
+
from RRAEsTorch.utilities import (
|
|
2
|
+
Sample,
|
|
3
|
+
CNNs_with_MLP,
|
|
4
|
+
MLP_with_CNNs_trans,
|
|
5
|
+
CNN3D_with_MLP,
|
|
6
|
+
MLP_with_CNN3D_trans,
|
|
7
|
+
stable_SVD,
|
|
8
|
+
)
|
|
9
|
+
from RRAEsTorch.wrappers import vmap_wrap
|
|
10
|
+
import jax.random as jrandom
|
|
11
|
+
import warnings
|
|
12
|
+
from torch.nn import Linear
|
|
13
|
+
from RRAEsTorch.AE_base import get_autoencoder_base
|
|
14
|
+
import torch
|
|
15
|
+
from torch.func import vmap
|
|
16
|
+
|
|
17
|
+
_identity = lambda x, *args, **kwargs: x
|
|
18
|
+
|
|
19
|
+
def latent_func_strong_RRAE(
    self,
    y,
    k_max=None,
    apply_basis=None,
    get_basis_coeffs=False,
    get_coeffs=False,
    get_right_sing=False,
    ret=False,
    *args,
    **kwargs,
):
    """Performing the truncated SVD in the latent space.

    Parameters
    ----------
    y : torch.Tensor
        The latent space (latent axis first, batch axis last — TODO confirm).
    k_max : int or list[int], optional
        The maximum number of modes to keep. If this is -1 (or None),
        the function will return y (i.e. all the modes).
    apply_basis : torch.Tensor, optional
        A fixed orthonormal basis; when given, y is projected on it
        instead of computing an SVD.
    get_basis_coeffs : bool
        Return ``(basis, coefficients)`` instead of the reconstruction.
    get_coeffs : bool
        Return only the coefficients.
    get_right_sing : bool
        Return the (truncated) right singular vectors; incompatible
        with ``apply_basis``.
    ret : bool
        Return the ``(u_now, coeffs, sigs)`` triple instead of the
        reconstruction (all ``None`` in this code path).

    Returns
    -------
    y_approx : torch.Tensor
        The latent space after the truncation (or bases/coefficients,
        depending on the flags above).
    """
    if apply_basis is not None:
        if get_basis_coeffs:
            return apply_basis, apply_basis.T @ y
        # Right singular vectors only exist for an SVD of y, not for a
        # projection on a fixed basis. (The original duplicated this
        # check twice in a row; one suffices.)
        if get_right_sing:
            raise ValueError("Can not find right singular vector when projecting on basis")
        if get_coeffs:
            return apply_basis.T @ y
        return apply_basis @ apply_basis.T @ y

    k_max = -1 if k_max is None else k_max

    if get_basis_coeffs or get_coeffs:
        u, s, v = stable_SVD(y)

        # Normalize to a list so several truncation ranks can be asked at once.
        if isinstance(k_max, int):
            k_max = [k_max]

        u_now = [u[:, :k] for k in k_max]
        coeffs = [torch.multiply(v[:k, :], torch.unsqueeze(s[:k], -1)) for k in k_max]

        if len(k_max) == 1:
            u_now = u_now[0]
            coeffs = coeffs[0]

        if get_coeffs:
            if get_right_sing:
                # BUG FIX: k_max is a list here, so the original
                # ``v[:k_max, :]`` raised a TypeError. Slice per rank and
                # unwrap the single-rank case like coeffs above.
                right = [v[:k, :] for k in k_max]
                return right[0] if len(right) == 1 else right
            return coeffs
        return u_now, coeffs

    if k_max != -1:
        u, s, v = stable_SVD(y)

        # (Dead ``k_max is None`` check removed: k_max was coerced to -1 above.)
        if isinstance(k_max, int):
            k_max = [k_max]

        # Rank-k reconstruction: (U_k * s_k) @ V_k for each requested rank.
        y_approx = [(u[..., :k] * s[:k]) @ v[:k] for k in k_max]

        if len(k_max) == 1:
            y_approx = y_approx[0]
    else:
        y_approx = y

    # BUG FIX: these placeholders were only defined in the no-truncation
    # branch, so ``ret=True`` with truncation raised a NameError.
    u_now = None
    coeffs = None
    sigs = None
    if ret:
        return u_now, coeffs, sigs
    return y_approx
|
|
102
|
+
|
|
103
|
+
def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
    """Variational version of the strong-RRAE latent step.

    Projects y on its truncated SVD basis, maps the modal coefficients to
    a Gaussian (mean, logvar) and reparameterizes with external noise.

    Parameters
    ----------
    y : torch.Tensor
        The latent space.
    k_max : int, optional
        Number of SVD modes to keep (forwarded to latent_func_strong_RRAE).
    epsilon : array-like, optional
        Pre-drawn standard-normal noise; when omitted, the mean is used
        (deterministic pass).
    return_dist : bool
        Return only ``(mean, logvar)``.
    return_lat_dist : bool
        Return ``(basis @ z, mean, logvar)``.

    Raises
    ------
    ValueError
        When distribution outputs are requested together with coefficient
        outputs, or when ``self.typ`` is unrecognized.
    """
    # Single pop with default replaces the original get-then-pop pair.
    apply_basis = kwargs.pop("apply_basis", None)

    if kwargs.get("get_coeffs") or kwargs.get("get_basis_coeffs"):
        if return_dist or return_lat_dist:
            raise ValueError(
                "Can not return a distribution when coefficients/basis are requested"
            )
        return latent_func_strong_RRAE(self, y, k_max, apply_basis=apply_basis, **kwargs)

    basis, coeffs = latent_func_strong_RRAE(self, y, k_max=k_max, get_basis_coeffs=True, apply_basis=apply_basis)
    if self.typ == "eye":
        # Identity head: the coefficients themselves are the mean.
        mean = coeffs
    elif self.typ == "trainable":
        mean = self.lin_mean(coeffs)
    else:
        raise ValueError("typ must be either 'eye' or 'trainable'")

    logvar = self.lin_logvar(coeffs)

    if return_dist:
        return mean, logvar

    std = torch.exp(0.5 * logvar)
    if epsilon is not None:
        if len(epsilon.shape) == 4:
            epsilon = epsilon[0, 0]  # to allow tpu sharding
        # as_tensor avoids the copy (and warning) torch.tensor() does when
        # epsilon is already a torch.Tensor.
        z = mean + torch.as_tensor(epsilon, dtype=torch.float32) * std
    else:
        z = mean

    if return_lat_dist:
        return basis @ z, mean, logvar
    return basis @ z
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class RRAE_MLP(get_autoencoder_base()):
    """Subclass of RRAEs with the strong formulation when the input
    is of dimension (data_size, batch_size).

    Attributes
    ----------
    encode : MLP_with_linear
        An MLP as the encoding function.
    decode : MLP_with_linear
        An MLP as the decoding function.
    perform_in_latent : function
        The function that performs operations in the latent space.
    k_max : int
        The maximum number of modes to keep in the latent space.
    """

    def __init__(
        self,
        in_size,
        latent_size,
        k_max,
        post_proc_func=None,
        *,
        kwargs_enc=None,
        kwargs_dec=None,
        **kwargs,
    ):
        # None-sentinel avoids the shared-mutable-default-dict pitfall.
        kwargs_enc = {} if kwargs_enc is None else kwargs_enc
        kwargs_dec = {} if kwargs_dec is None else kwargs_dec

        # The strong formulation truncates via SVD, so an extra linear
        # stack makes no sense here.
        if "linear_l" in kwargs:
            warnings.warn("linear_l can not be specified for Strong")
            kwargs.pop("linear_l")

        super().__init__(
            in_size,
            latent_size,
            map_latent=False,
            post_proc_func=post_proc_func,
            kwargs_enc=kwargs_enc,
            kwargs_dec=kwargs_dec,
            **kwargs,
        )

    def _perform_in_latent(self, y, *args, **kwargs):
        # Truncated-SVD projection of the latent space (strong RRAE).
        return latent_func_strong_RRAE(self, y, *args, **kwargs)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class Vanilla_AE_MLP(get_autoencoder_base()):
    """Vanilla Autoencoder.

    Subclass for the Vanilla AE, basically the strong RRAE with
    k_max = -1, hence returning all the modes with no truncation.
    """

    def __init__(
        self,
        in_size,
        latent_size,
        *,
        kwargs_enc={},
        kwargs_dec={},
        **kwargs,
    ):
        # A vanilla AE keeps all modes: strip any user-supplied k_max and
        # warn when it disagrees with -1.
        if "k_max" in kwargs:
            if kwargs["k_max"] != -1:
                warnings.warn(
                    "k_max can not be specified for Vanilla_AE_MLP, switching to -1 (all modes)"
                )
            del kwargs["k_max"]

        # The latent size is unchanged by the (identity) latent step.
        super().__init__(
            in_size,
            latent_size,
            latent_size,
            kwargs_enc=kwargs_enc,
            kwargs_dec=kwargs_dec,
            **kwargs,
        )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def sample(y, sample_cls, k_max=None, epsilon=None, *args, **kwargs):
    """Vectorize ``sample_cls`` over the batch (last) axis of ``y``.

    Parameters
    ----------
    y : tuple(torch.Tensor, torch.Tensor)
        ``(mean, logvar)`` pair, batched along the last axis.
    sample_cls : callable
        Sampler invoked as ``sample_cls(mean, logvar[, epsilon])`` per sample.
    k_max : unused
        Kept for signature compatibility with the latent functions.
    epsilon : torch.Tensor, optional
        Pre-drawn noise, batched along the last axis.

    Returns
    -------
    torch.Tensor
        The per-sample results, stacked along the last axis.
    """
    if epsilon is None:
        per_sample = lambda m, lv: sample_cls(m, lv, *args, **kwargs)
        return vmap(per_sample, in_dims=(-1, -1), out_dims=-1)(*y)
    else:
        per_sample = lambda m, lv, e: sample_cls(m, lv, e, *args, **kwargs)
        # BUG FIX: torch.func.vmap takes ``out_dims`` — ``out_axes`` is the
        # JAX spelling and raises a TypeError here.
        return vmap(per_sample, in_dims=(-1, -1, -1), out_dims=-1)(*y, epsilon)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class VAE_MLP(get_autoencoder_base()):
    """Variational autoencoder with MLP encoder/decoder.

    Two linear heads map the latent code to a mean and a log-variance,
    which are then reparameterized through the module-level ``sample``
    helper.

    Attributes
    ----------
    _sample : Sample
        The reparameterization sampler.
    lin_mean : Linear
        Linear head producing the mean.
    lin_logvar : Linear
        Linear head producing the log-variance.
    latent_size : int
        Dimension of the latent space.
    """

    _sample: Sample
    lin_mean: Linear
    lin_logvar: Linear
    latent_size: int

    def __init__(
        self,
        in_size,
        latent_size,
        *,
        kwargs_enc={},
        kwargs_dec={},
        **kwargs,
    ):
        # NOTE(review): attributes are assigned before super().__init__,
        # and **kwargs is accepted but not forwarded — both presumably
        # deliberate for the project's base class; verify.
        self.latent_size = latent_size
        self._sample = Sample(sample_dim=latent_size)
        self.lin_mean = Linear(latent_size, latent_size)
        self.lin_logvar = Linear(latent_size, latent_size)

        super().__init__(
            in_size,
            latent_size,
            map_latent=False,
            kwargs_enc=kwargs_enc,
            kwargs_dec=kwargs_dec,
        )

    def _perform_in_latent(self, y, *args, return_dist=False, **kwargs):
        # The batch axis is last, so both heads are vmapped over dim -1.
        # BUG FIX: torch.func.vmap takes ``out_dims``; the original used the
        # JAX spelling ``out_axes``, which raises a TypeError.
        mean = vmap(self.lin_mean, in_dims=-1, out_dims=-1)(y)
        logvar = vmap(self.lin_logvar, in_dims=-1, out_dims=-1)(y)
        if return_dist:
            return mean, logvar
        return sample((mean, logvar), self._sample, *args, **kwargs)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class IRMAE_MLP(get_autoencoder_base()):
    """Implicit rank-minimizing AE with MLP encoder/decoder.

    Forwards ``linear_l`` to the encoder so it appends extra linear
    layers after the MLP (presumably encouraging a low-rank latent —
    see the IRMAE literature; confirm against ``MLP_with_linear``).
    """

    def __init__(
        self,
        in_size,
        latent_size,
        linear_l=None,
        *,
        kwargs_enc=None,
        kwargs_dec=None,
        **kwargs,
    ):
        # None-sentinel avoids the shared-mutable-default-dict pitfall.
        kwargs_enc = {} if kwargs_enc is None else kwargs_enc
        kwargs_dec = {} if kwargs_dec is None else kwargs_dec

        assert linear_l is not None, "linear_l must be specified for IRMAE_MLP"

        # IRMAE keeps all modes: strip any user-supplied k_max.
        if "k_max" in kwargs:
            if kwargs["k_max"] != -1:
                warnings.warn(
                    "k_max can not be specified for the model proposed, switching to -1 (all modes)"
                )
            kwargs.pop("k_max")

        if "linear_l" in kwargs:
            raise ValueError("Specify linear_l in the constructor, not in kwargs")

        # Copy-merge rather than mutating the caller's dict.
        kwargs_enc = {**kwargs_enc, "linear_l": linear_l}

        super().__init__(
            in_size,
            latent_size,
            kwargs_enc=kwargs_enc,
            kwargs_dec=kwargs_dec,
            **kwargs,
        )
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
class LoRAE_MLP(IRMAE_MLP):
    """Low-rank AE: an IRMAE with exactly one extra linear layer
    (``linear_l`` fixed to 1)."""

    def __init__(
        self, in_size, latent_size, *, kwargs_enc={}, kwargs_dec={}, **kwargs
    ):
        if "linear_l" in kwargs:
            if kwargs["linear_l"] != 1:
                # BUG FIX: the message previously referred to LoRAE_CNN.
                raise ValueError("linear_l can not be specified for LoRAE_MLP")
            # BUG FIX: leaving linear_l in kwargs made the super() call
            # below raise "multiple values for keyword argument".
            kwargs.pop("linear_l")

        super().__init__(
            in_size,
            latent_size,
            linear_l=1,
            kwargs_enc=kwargs_enc,
            kwargs_dec=kwargs_dec,
            **kwargs,
        )
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
class CNN_Autoencoder(get_autoencoder_base()):
    """Convolutional autoencoder.

    Encodes with a CNN stack feeding an MLP, and decodes with an MLP
    feeding transposed convolutions, mirroring the encoder.
    """

    def __init__(
        self,
        channels,
        height,
        width,
        latent_size,
        latent_size_after=None,
        *,
        count=1,
        dimension=2,
        kwargs_enc={},
        kwargs_dec={},
        **kwargs,
    ):
        # Default: the latent step does not change the latent dimension.
        if latent_size_after is None:
            latent_size_after = latent_size

        encoder = CNNs_with_MLP(
            width=width,
            height=height,
            channels=channels,
            out=latent_size,
            dimension=dimension,
            **kwargs_enc,
        )
        decoder = MLP_with_CNNs_trans(
            width=width,
            height=height,
            inp=latent_size_after,
            channels=channels,
            dimension=dimension,
            **kwargs_dec,
        )

        # in_size is None: the encoder/decoder are supplied explicitly.
        super().__init__(
            None,
            latent_size,
            _encode=encoder,
            map_latent=False,
            _decode=decoder,
            count=count,
            **kwargs,
        )
|
|
369
|
+
|
|
370
|
+
class CNN3D_Autoencoder(get_autoencoder_base()):
|
|
371
|
+
def __init__(
|
|
372
|
+
self,
|
|
373
|
+
depth, # I add the depth
|
|
374
|
+
width,
|
|
375
|
+
height,
|
|
376
|
+
channels,
|
|
377
|
+
latent_size,
|
|
378
|
+
k_max=-1,
|
|
379
|
+
latent_size_after=None,
|
|
380
|
+
_perform_in_latent=_identity,
|
|
381
|
+
*,
|
|
382
|
+
kwargs_enc={},
|
|
383
|
+
kwargs_dec={},
|
|
384
|
+
**kwargs,
|
|
385
|
+
):
|
|
386
|
+
latent_size_after = (
|
|
387
|
+
latent_size if latent_size_after is None else latent_size_after
|
|
388
|
+
)
|
|
389
|
+
|
|
390
|
+
_encode = CNN3D_with_MLP(
|
|
391
|
+
depth=depth,
|
|
392
|
+
width=width,
|
|
393
|
+
height=height,
|
|
394
|
+
channels=channels,
|
|
395
|
+
out=latent_size,
|
|
396
|
+
**kwargs_enc,
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
_decode = MLP_with_CNN3D_trans(
|
|
400
|
+
depth=depth,
|
|
401
|
+
width=width,
|
|
402
|
+
height=height,
|
|
403
|
+
inp=latent_size_after,
|
|
404
|
+
channels=channels,
|
|
405
|
+
**kwargs_dec,
|
|
406
|
+
)
|
|
407
|
+
|
|
408
|
+
super().__init__(
|
|
409
|
+
None,
|
|
410
|
+
latent_size,
|
|
411
|
+
k_max=k_max,
|
|
412
|
+
_encode=_encode,
|
|
413
|
+
_perform_in_latent=_perform_in_latent,
|
|
414
|
+
map_latent=False,
|
|
415
|
+
_decode=_decode,
|
|
416
|
+
**kwargs,
|
|
417
|
+
)
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
class VAE_CNN(CNN_Autoencoder):
    """Variational CNN autoencoder.

    Two vmapped linear heads map the latent code to a mean and a
    log-variance; sampling reparameterizes with externally supplied noise.

    Attributes
    ----------
    lin_mean : Linear
        Head producing the mean (vmapped over the batch axis).
    lin_logvar : Linear
        Head producing the log-variance (vmapped over the batch axis).
    latent_size : int
        Dimension of the latent space.
    """

    lin_mean: Linear
    lin_logvar: Linear
    latent_size: int

    def __init__(self, channels, height, width, latent_size, *, count=1, **kwargs):
        # Linear layers wrapped so they vmap over the trailing (batch) axis.
        v_Linear = vmap_wrap(Linear, -1, count=count)
        self.lin_mean = v_Linear(latent_size, latent_size)
        self.lin_logvar = v_Linear(latent_size, latent_size)
        self.latent_size = latent_size
        super().__init__(
            channels,
            height,
            width,
            latent_size,
            count=count,
            **kwargs,
        )

    def _perform_in_latent(self, y, *args, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
        """Reparameterized Gaussian sampling in the latent space.

        Parameters
        ----------
        y : torch.Tensor
            The latent code.
        epsilon : torch.Tensor, optional
            Pre-drawn standard-normal noise; when omitted, the mean is
            returned (deterministic pass).
        return_dist : bool
            Return only ``(mean, logvar)``.
        return_lat_dist : bool
            Return ``(z, mean, logvar)``.
        """
        mean = self.lin_mean(y)
        logvar = self.lin_logvar(y)

        if return_dist:
            return mean, logvar

        std = torch.exp(0.5 * logvar)
        if epsilon is not None:
            # Idiomatic form of len(epsilon.shape) == 4.
            if epsilon.ndim == 4:
                epsilon = epsilon[0, 0]  # to allow tpu sharding
            z = mean + epsilon * std
        else:
            z = mean

        if return_lat_dist:
            return z, mean, logvar
        return z
|
|
457
|
+
|
|
458
|
+
class RRAE_CNN(CNN_Autoencoder):
    """Subclass of RRAEs with the strong formulation for inputs of
    dimension (channels, width, height).
    """

    def __init__(self, channels, height, width, latent_size, k_max, **kwargs):
        # NOTE(review): k_max is accepted but not forwarded to the base
        # class — presumably it is supplied per call by the training
        # loop; confirm against the training classes.
        super().__init__(channels, height, width, latent_size, **kwargs)

    def _perform_in_latent(self, y, *args, **kwargs):
        # Truncated-SVD projection of the latent space (strong RRAE).
        return latent_func_strong_RRAE(self, y, *args, **kwargs)
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
class RRAE_CNN3D(CNN3D_Autoencoder):
    # Strong-formulation RRAE over volumetric (3-D) inputs.

    def __init__(self, depth, width, height, channels, latent_size, k_max, **kwargs):

        # k_max is forwarded positionally into CNN3D_Autoencoder's k_max slot.
        # NOTE(review): latent_func_strong_RRAE is passed to the base AND
        # the method below overrides _perform_in_latent — the override
        # should win; confirm the base does not store the callable and
        # call it directly.
        super().__init__(
            depth,
            width,
            height,
            channels,
            latent_size,
            k_max,
            _perform_in_latent=latent_func_strong_RRAE,
            **kwargs,
        )

    def _perform_in_latent(self, y, *args, **kwargs):
        # Truncated-SVD projection of the latent space (strong RRAE).
        return latent_func_strong_RRAE(self, y, *args, **kwargs)
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
class VRRAE_CNN(CNN_Autoencoder):
    """Variational RRAE over image inputs.

    Truncated-SVD latent step with a reparameterized Gaussian over the
    modal coefficients (see ``latent_func_var_strong_RRAE``).

    Attributes
    ----------
    lin_mean : Linear
        Head producing the coefficient mean (used when typ="trainable").
    lin_logvar : Linear
        Head producing the coefficient log-variance.
    typ : str
        Either "eye" (identity mean) or "trainable" (learned mean).
    """

    lin_mean: Linear
    lin_logvar: Linear
    # BUG FIX: was annotated ``int`` but holds the string "eye"/"trainable".
    typ: str

    def __init__(self, channels, height, width, latent_size, k_max, typ="eye", *, count=1, **kwargs):
        super().__init__(
            channels,
            height,
            width,
            latent_size,
            count=count,
            **kwargs,
        )
        # Heads act on the k_max modal coefficients, vmapped over batch.
        v_Linear = vmap_wrap(Linear, -1, count=count)
        self.lin_mean = v_Linear(k_max, k_max)
        self.lin_logvar = v_Linear(k_max, k_max)
        self.typ = typ

    def _perform_in_latent(self, y, *args, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
        # Variational strong-RRAE latent step.
        return latent_func_var_strong_RRAE(self, y, k_max, epsilon, return_dist, return_lat_dist, **kwargs)
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
class Vanilla_AE_CNN(CNN_Autoencoder):
    """Vanilla Autoencoder.

    Subclass for the Vanilla AE, basically the strong RRAE with
    k_max = -1, hence returning all the modes with no truncation.
    """

    def __init__(self, channels, height, width, latent_size, **kwargs):
        # The vanilla AE has no extra linear stack; drop the option with
        # a warning if a caller passes it anyway.
        if "linear_l" in kwargs:
            warnings.warn("linear_l can not be specified for Vanilla_CNN")
            del kwargs["linear_l"]

        super().__init__(channels, height, width, latent_size, **kwargs)
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
class IRMAE_CNN(CNN_Autoencoder):
    """IRMAE with a convolutional encoder.

    Forwards ``linear_l`` to the encoder's internal MLP through
    ``kwargs_enc["kwargs_mlp"]``.
    """

    def __init__(
        self, channels, height, width, latent_size, linear_l=None, **kwargs
    ):

        assert linear_l is not None, "linear_l must be specified for IRMAE_CNN"

        # BUG FIX: the original mutated the caller's kwargs_enc dict in
        # place and clobbered any existing "kwargs_mlp" entry. Copy-merge
        # instead so the caller's dicts and options are preserved.
        kwargs_enc = dict(kwargs.get("kwargs_enc", {}))
        kwargs_mlp = dict(kwargs_enc.get("kwargs_mlp", {}))
        kwargs_mlp["linear_l"] = linear_l
        kwargs_enc["kwargs_mlp"] = kwargs_mlp
        kwargs["kwargs_enc"] = kwargs_enc

        super().__init__(
            channels,
            height,
            width,
            latent_size,
            **kwargs,
        )
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
class LoRAE_CNN(IRMAE_CNN):
    """Low-rank AE with a convolutional encoder: an IRMAE_CNN with
    exactly one extra linear layer (``linear_l`` fixed to 1)."""

    def __init__(self, channels, height, width, latent_size, **kwargs):

        if "linear_l" in kwargs:
            if kwargs["linear_l"] != 1:
                raise ValueError("linear_l can not be specified for LoRAE_CNN")
            # BUG FIX: leaving linear_l in kwargs made the super() call
            # below raise "multiple values for keyword argument".
            kwargs.pop("linear_l")

        super().__init__(
            channels, height, width, latent_size, linear_l=1, **kwargs
        )
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
class CNN1D_Autoencoder(CNN_Autoencoder):
    """1-D convolutional autoencoder.

    Reuses CNN_Autoencoder with ``dimension=1``: ``input_dim`` fills the
    height slot and width is passed as None (unused in 1-D).
    """

    def __init__(
        self,
        channels,
        input_dim,
        latent_size,
        latent_size_after=None,
        *,
        count=1,
        kwargs_enc={},
        kwargs_dec={},
        **kwargs,
    ):
        super().__init__(
            channels,
            input_dim,  # 'height' slot carries the 1-D length
            None,  # width is meaningless for dimension=1
            latent_size,
            latent_size_after=latent_size_after,
            count=count,
            dimension=1,
            kwargs_enc=kwargs_enc,
            kwargs_dec=kwargs_dec,
            # BUG FIX: extra keyword arguments were accepted but silently
            # dropped before; forward them to the base class.
            **kwargs,
        )
|
|
591
|
+
|
|
592
|
+
class RRAE_CNN1D(CNN1D_Autoencoder):
    """Subclass of RRAEs with the strong formulation for inputs of
    dimension (channels, width, height).
    """
    # NOTE(review): the docstring above is copy-pasted from the 2-D class;
    # inputs here are presumably (channels, input_dim) — confirm and fix.

    def __init__(self, channels, input_dim, latent_size, k_max, **kwargs):

        # NOTE(review): k_max is accepted but not forwarded to the base —
        # presumably supplied per call by the training loop; verify.
        super().__init__(
            channels,
            input_dim,
            latent_size,
            **kwargs,
        )

    def _perform_in_latent(self, y, *args, **kwargs):
        # Truncated-SVD projection of the latent space (strong RRAE).
        return latent_func_strong_RRAE(self, y, *args, **kwargs)
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
class VRRAE_CNN1D(CNN1D_Autoencoder):
    """Variational RRAE over 1-D inputs.

    Truncated-SVD latent step with a reparameterized Gaussian over the
    modal coefficients (see ``latent_func_var_strong_RRAE``).

    Attributes
    ----------
    lin_mean : Linear
        Head producing the coefficient mean (used when typ="trainable").
    lin_logvar : Linear
        Head producing the coefficient log-variance.
    typ : str
        Either "eye" (identity mean) or "trainable" (learned mean).
    """

    lin_mean: Linear
    lin_logvar: Linear
    # BUG FIX: was annotated ``int`` but holds the string "eye"/"trainable".
    typ: str

    def __init__(self, channels, input_dim, latent_size, k_max, typ="eye", *, count=1, **kwargs):
        # Heads act on the k_max modal coefficients, vmapped over batch.
        v_Linear = vmap_wrap(Linear, -1, count=count)
        self.lin_mean = v_Linear(k_max, k_max)
        self.lin_logvar = v_Linear(k_max, k_max)
        self.typ = typ
        super().__init__(
            channels,
            input_dim,
            latent_size,
            count=count,
            **kwargs,
        )

    def _perform_in_latent(self, y, *args, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
        # Variational strong-RRAE latent step.
        return latent_func_var_strong_RRAE(self, y, k_max, epsilon, return_dist, return_lat_dist, **kwargs)

    def get_basis_coeffs(self, x, *args, **kwargs):
        """Encode ``x`` and return the latent (basis, coefficients) pair."""
        return self.perform_in_latent(self.encode(x), *args, get_basis_coeffs=True, **kwargs)

    def decode_coeffs(self, c, basis, *args, **kwargs):
        """Reconstruct from coefficients ``c`` expressed in ``basis``."""
        return self.decode(basis @ c, *args, **kwargs)
|
|
636
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .AE_classes import *
|
RRAEsTorch/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .training_classes import *
|