RRAEsTorch 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,636 @@
1
+ from RRAEsTorch.utilities import (
2
+ Sample,
3
+ CNNs_with_MLP,
4
+ MLP_with_CNNs_trans,
5
+ CNN3D_with_MLP,
6
+ MLP_with_CNN3D_trans,
7
+ stable_SVD,
8
+ )
9
+ from RRAEsTorch.wrappers import vmap_wrap
10
+ import jax.random as jrandom
11
+ import warnings
12
+ from torch.nn import Linear
13
+ from RRAEsTorch.AE_base import get_autoencoder_base
14
+ import torch
15
+ from torch.func import vmap
16
+
17
+ _identity = lambda x, *args, **kwargs: x
18
+
19
def latent_func_strong_RRAE(
    self,
    y,
    k_max=None,
    apply_basis=None,
    get_basis_coeffs=False,
    get_coeffs=False,
    get_right_sing=False,
    ret=False,
    *args,
    **kwargs,
):
    """Perform the truncated SVD in the latent space.

    Parameters
    ----------
    y : torch.Tensor
        The latent space, batch on the last axis.
    k_max : int or list of int, optional
        The maximum number of modes to keep. If -1 (or None), the
        function returns ``y`` unchanged (all modes).
    apply_basis : torch.Tensor, optional
        A fixed basis to project on instead of computing the SVD.
    get_basis_coeffs : bool
        If True, return ``(basis, coefficients)``.
    get_coeffs : bool
        If True, return only the coefficients.
    get_right_sing : bool
        With ``get_coeffs``, return the right singular vectors instead
        of the sigma-scaled coefficients.
    ret : bool
        If True, return the ``(u_now, coeffs, sigs)`` bookkeeping tuple
        (all None unless produced on the path taken).

    Returns
    -------
    torch.Tensor or tuple
        The truncated latent space, or basis/coefficients depending on
        the flags above.
    """
    # Projection on a user-supplied basis: no SVD required.
    if apply_basis is not None:
        if get_basis_coeffs:
            return apply_basis, apply_basis.T @ y
        if get_coeffs:
            if get_right_sing:
                raise ValueError(
                    "Can not find right singular vector when projecting on basis"
                )
            return apply_basis.T @ y
        return apply_basis @ apply_basis.T @ y

    k_max = -1 if k_max is None else k_max

    if get_basis_coeffs or get_coeffs:
        u, s, v = stable_SVD(y)

        # Allow several truncation ranks to be requested at once.
        if isinstance(k_max, int):
            k_max = [k_max]

        u_now = [u[:, :k] for k in k_max]
        right_sing = [v[:k, :] for k in k_max]
        coeffs = [torch.multiply(v[:k, :], torch.unsqueeze(s[:k], -1)) for k in k_max]

        if len(k_max) == 1:
            u_now = u_now[0]
            coeffs = coeffs[0]
            right_sing = right_sing[0]

        if get_coeffs:
            if get_right_sing:
                # Fix: the original sliced v with the listified k_max
                # (``v[:k_max, :]``), which raises a TypeError.
                return right_sing
            return coeffs
        return u_now, coeffs

    # Defined up-front so the ``ret`` path cannot hit a NameError when
    # truncation was performed (the original only bound these in the
    # no-truncation branch).
    u_now = None
    coeffs = None
    sigs = None

    if k_max != -1:
        u, s, v = stable_SVD(y)

        if isinstance(k_max, int):
            k_max = [k_max]

        # Rank-k reconstruction for each requested rank.
        y_approx = [(u[..., :k] * s[:k]) @ v[:k] for k in k_max]

        if len(k_max) == 1:
            y_approx = y_approx[0]
    else:
        y_approx = y

    if ret:
        return u_now, coeffs, sigs
    return y_approx
102
+
103
def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
    """Variational version of the strong RRAE latent map.

    Projects the latent space on the truncated SVD basis, maps the
    coefficients through the model's mean / log-variance heads and
    applies the reparameterization trick with the supplied noise.

    Parameters
    ----------
    y : torch.Tensor
        The latent space.
    k_max : int, optional
        Truncation rank forwarded to ``latent_func_strong_RRAE``.
    epsilon : array-like, optional
        Standard-normal noise; if None the mean is used deterministically.
    return_dist : bool
        If True, return ``(mean, logvar)`` only.
    return_lat_dist : bool
        If True, return ``(latent, mean, logvar)``.

    Returns
    -------
    torch.Tensor or tuple
        The sampled latent representation (see flags above).
    """
    # kwargs may carry a fixed basis to project on instead of the SVD.
    apply_basis = kwargs.pop("apply_basis", None)

    if kwargs.get("get_coeffs") or kwargs.get("get_basis_coeffs"):
        if return_dist or return_lat_dist:
            raise ValueError
        return latent_func_strong_RRAE(self, y, k_max, apply_basis=apply_basis, **kwargs)

    basis, coeffs = latent_func_strong_RRAE(self, y, k_max=k_max, get_basis_coeffs=True, apply_basis=apply_basis)
    if self.typ == "eye":
        mean = coeffs
    elif self.typ == "trainable":
        mean = self.lin_mean(coeffs)
    else:
        raise ValueError("typ must be either 'eye' or 'trainable'")

    logvar = self.lin_logvar(coeffs)

    if return_dist:
        return mean, logvar

    std = torch.exp(0.5 * logvar)
    if epsilon is not None:
        if len(epsilon.shape) == 4:
            epsilon = epsilon[0, 0]  # to allow tpu sharding
        # as_tensor avoids an unnecessary copy (and the UserWarning)
        # when epsilon is already a torch.Tensor.
        z = mean + torch.as_tensor(epsilon, dtype=torch.float32) * std
    else:
        z = mean

    if return_lat_dist:
        return basis @ z, mean, logvar
    return basis @ z
138
+
139
+
140
class RRAE_MLP(get_autoencoder_base()):
    """Subclass of RRAEs with the strong formulation when the input
    is of dimension (data_size, batch_size).

    Attributes
    ----------
    encode : MLP_with_linear
        An MLP as the encoding function.
    decode : MLP_with_linear
        An MLP as the decoding function.
    perform_in_latent : function
        The function that performs operations in the latent space.
    k_max : int
        The maximum number of modes to keep in the latent space.
    """

    def __init__(
        self,
        in_size,
        latent_size,
        k_max,
        post_proc_func=None,
        *,
        kwargs_enc={},
        kwargs_dec={},
        **kwargs,
    ):
        # The strong formulation performs the SVD itself; extra linear
        # layers in the encoder cannot be requested.
        if "linear_l" in kwargs:
            warnings.warn("linear_l can not be specified for Strong")
            kwargs.pop("linear_l")

        # Forward k_max so the base class actually truncates; previously
        # it was accepted here and silently dropped (the base accepts a
        # k_max kwarg, cf. CNN3D_Autoencoder).
        super().__init__(
            in_size,
            latent_size,
            k_max=k_max,
            map_latent=False,
            post_proc_func=post_proc_func,
            kwargs_enc=kwargs_enc,
            kwargs_dec=kwargs_dec,
            **kwargs,
        )

    def _perform_in_latent(self, y, *args, **kwargs):
        """Truncated-SVD latent operation (strong RRAE)."""
        return latent_func_strong_RRAE(self, y, *args, **kwargs)
184
+
185
+
186
class Vanilla_AE_MLP(get_autoencoder_base()):
    """Vanilla MLP autoencoder.

    Equivalent to the strong RRAE with ``k_max = -1``: every latent mode
    is kept and no truncation is performed.
    """

    def __init__(
        self,
        in_size,
        latent_size,
        *,
        kwargs_enc={},
        kwargs_dec={},
        **kwargs,
    ):
        # The vanilla model always keeps all modes; warn if the caller
        # tried to request a truncation rank, then discard it.
        if "k_max" in kwargs:
            if kwargs["k_max"] != -1:
                warnings.warn(
                    "k_max can not be specified for Vanilla_AE_MLP, switching to -1 (all modes)"
                )
            del kwargs["k_max"]

        # The decoder input size equals the latent size (no truncation).
        super().__init__(
            in_size,
            latent_size,
            latent_size,
            kwargs_enc=kwargs_enc,
            kwargs_dec=kwargs_dec,
            **kwargs,
        )
219
+
220
+
221
def sample(y, sample_cls, k_max=None, epsilon=None, *args, **kwargs):
    """Vectorize ``sample_cls`` over the batch (last) axis.

    Parameters
    ----------
    y : tuple of torch.Tensor
        ``(mean, logvar)`` tensors with the batch on the last axis.
    sample_cls : callable
        Sampler invoked as ``sample_cls(mean, logvar[, epsilon])``.
    k_max : unused
        Kept for signature compatibility with the other latent funcs.
    epsilon : torch.Tensor, optional
        Pre-drawn noise, batched on the last axis.

    Returns
    -------
    torch.Tensor
        The per-sample results, stacked on the last axis.
    """
    if epsilon is None:
        def draw(m, lv):
            return sample_cls(m, lv, *args, **kwargs)

        return vmap(draw, in_dims=(-1, -1), out_dims=-1)(*y)

    def draw(m, lv, e):
        return sample_cls(m, lv, e, *args, **kwargs)

    # Fix: torch.func.vmap takes ``out_dims``; the original used the JAX
    # kwarg name ``out_axes`` here, which raises a TypeError.
    return vmap(draw, in_dims=(-1, -1, -1), out_dims=-1)(*y, epsilon)
230
+
231
+
232
class VAE_MLP(get_autoencoder_base()):
    """Variational autoencoder with MLP encoder/decoder.

    The latent code is mapped through linear mean / log-variance heads
    (applied column-wise over the batch axis) and sampled with the
    reparameterization trick via ``Sample``.
    """

    _sample: Sample      # reparameterization sampler
    lin_mean: Linear     # latent -> mean head
    lin_logvar: Linear   # latent -> log-variance head
    latent_size: int

    def __init__(
        self,
        in_size,
        latent_size,
        *,
        kwargs_enc={},
        kwargs_dec={},
        **kwargs,
    ):
        # NOTE(review): attributes are assigned before super().__init__();
        # this relies on the base returned by get_autoencoder_base()
        # tolerating pre-init attribute assignment -- confirm.
        self.latent_size = latent_size
        self._sample = Sample(sample_dim=latent_size)
        self.lin_mean = Linear(latent_size, latent_size)
        self.lin_logvar = Linear(latent_size, latent_size)

        # NOTE(review): **kwargs is accepted but not forwarded to super();
        # extra keyword arguments are silently dropped -- confirm intended.
        super().__init__(
            in_size,
            latent_size,
            map_latent=False,
            kwargs_enc=kwargs_enc,
            kwargs_dec=kwargs_dec,
        )

    def _perform_in_latent(self, y, *args, return_dist=False, **kwargs):
        """Map latents to (mean, logvar) per batch column, then sample.

        Returns ``(mean, logvar)`` when ``return_dist`` is True.
        """
        # Fix: torch.func.vmap takes ``out_dims``; the original passed the
        # JAX kwarg name ``out_axes``, which raises a TypeError.
        y = vmap(self.lin_mean, in_dims=-1, out_dims=-1)(y), vmap(
            self.lin_logvar, in_dims=-1, out_dims=-1
        )(y)
        if return_dist:
            return y[0], y[1]
        return sample(y, self._sample, *args, **kwargs)
267
+
268
+
269
class IRMAE_MLP(get_autoencoder_base()):
    """IRMAE variant of the MLP autoencoder.

    The encoder ends with ``linear_l`` extra linear layers; no latent
    truncation is performed (all modes are kept).
    """

    def __init__(
        self,
        in_size,
        latent_size,
        linear_l=None,
        *,
        kwargs_enc={},
        kwargs_dec={},
        **kwargs,
    ):
        assert linear_l is not None, "linear_l must be specified for IRMAE_MLP"

        # IRMAE keeps every mode; drop any user-supplied k_max and warn
        # if it requested an actual truncation.
        if kwargs.pop("k_max", -1) != -1:
            warnings.warn(
                "k_max can not be specified for the model proposed, switching to -1 (all modes)"
            )

        if "linear_l" in kwargs:
            raise ValueError("Specify linear_l in the constructor, not in kwargs")

        super().__init__(
            in_size,
            latent_size,
            kwargs_enc={**kwargs_enc, "linear_l": linear_l},
            kwargs_dec=kwargs_dec,
            **kwargs,
        )
302
+
303
+
304
class LoRAE_MLP(IRMAE_MLP):
    """LoRAE: an IRMAE_MLP with ``linear_l`` fixed to 1."""

    def __init__(
        self, in_size, latent_size, *, kwargs_enc={}, kwargs_dec={}, **kwargs
    ):
        # linear_l is fixed to 1; tolerate an explicit linear_l=1 but
        # reject anything else, and pop it so super() does not receive
        # the kwarg twice (previously a TypeError).
        if "linear_l" in kwargs:
            if kwargs["linear_l"] != 1:
                # Fix: the message wrongly referred to LoRAE_CNN.
                raise ValueError("linear_l can not be specified for LoRAE_MLP")
            kwargs.pop("linear_l")

        super().__init__(
            in_size,
            latent_size,
            linear_l=1,
            kwargs_enc=kwargs_enc,
            kwargs_dec=kwargs_dec,
            **kwargs,
        )
320
+
321
+
322
+
323
class CNN_Autoencoder(get_autoencoder_base()):
    """Convolutional autoencoder.

    Builds a CNN encoder that ends in an MLP (``CNNs_with_MLP``) and a
    decoder that starts from an MLP and applies transposed convolutions
    (``MLP_with_CNNs_trans``), then defers to the autoencoder base class
    with latent mapping disabled.

    Parameters
    ----------
    channels, height, width : int
        Input dimensions.
    latent_size : int
        Size of the encoder output.
    latent_size_after : int, optional
        Size of the decoder input; defaults to ``latent_size``.
    count : int
        Forwarded to the base class.  # presumably an ensemble/vmap count -- confirm
    dimension : int
        Convolution dimensionality (1 or 2).
    """

    def __init__(
        self,
        channels,
        height,
        width,
        latent_size,
        latent_size_after=None,
        *,
        count=1,
        dimension=2,
        kwargs_enc={},
        kwargs_dec={},
        **kwargs,
    ):
        # Decoder input defaults to the encoder output size.
        latent_size_after = (
            latent_size if latent_size_after is None else latent_size_after
        )

        _encode = CNNs_with_MLP(
            width=width,
            height=height,
            channels=channels,
            out=latent_size,
            dimension=dimension,
            **kwargs_enc,
        )

        _decode = MLP_with_CNNs_trans(
            width=width,
            height=height,
            inp=latent_size_after,
            channels=channels,
            dimension=dimension,
            **kwargs_dec,
        )

        # in_size is None because the encoder/decoder are supplied
        # explicitly rather than built by the base class.
        super().__init__(
            None,
            latent_size,
            _encode=_encode,
            map_latent=False,
            _decode=_decode,
            count=count,
            **kwargs,
        )
369
+
370
class CNN3D_Autoencoder(get_autoencoder_base()):
    """3-D convolutional autoencoder.

    Same structure as ``CNN_Autoencoder`` but with 3-D convolutions and
    an extra ``depth`` input dimension; a custom latent operation and a
    truncation rank ``k_max`` can be supplied.
    """

    def __init__(
        self,
        depth,  # extra input dimension compared to the 2-D case
        width,
        height,
        channels,
        latent_size,
        k_max=-1,
        latent_size_after=None,
        _perform_in_latent=_identity,
        *,
        kwargs_enc={},
        kwargs_dec={},
        **kwargs,
    ):
        # Decoder input defaults to the encoder output size.
        latent_size_after = (
            latent_size if latent_size_after is None else latent_size_after
        )

        _encode = CNN3D_with_MLP(
            depth=depth,
            width=width,
            height=height,
            channels=channels,
            out=latent_size,
            **kwargs_enc,
        )

        _decode = MLP_with_CNN3D_trans(
            depth=depth,
            width=width,
            height=height,
            inp=latent_size_after,
            channels=channels,
            **kwargs_dec,
        )

        # in_size is None because the encoder/decoder are supplied
        # explicitly rather than built by the base class.
        super().__init__(
            None,
            latent_size,
            k_max=k_max,
            _encode=_encode,
            _perform_in_latent=_perform_in_latent,
            map_latent=False,
            _decode=_decode,
            **kwargs,
        )
418
+
419
+
420
class VAE_CNN(CNN_Autoencoder):
    """Variational CNN autoencoder.

    Adds mean / log-variance linear heads (vmapped over the batch axis)
    on top of the CNN encoder and samples the latent with the
    reparameterization trick.
    """

    lin_mean: Linear     # latent -> mean head
    lin_logvar: Linear   # latent -> log-variance head
    latent_size: int

    def __init__(self, channels, height, width, latent_size, *, count=1, **kwargs):
        # NOTE(review): attributes are assigned before super().__init__();
        # this relies on the base class tolerating pre-init attribute
        # assignment -- confirm.
        v_Linear = vmap_wrap(Linear, -1, count=count)
        self.lin_mean = v_Linear(latent_size, latent_size)
        self.lin_logvar = v_Linear(latent_size, latent_size)
        self.latent_size = latent_size
        super().__init__(
            channels,
            height,
            width,
            latent_size,
            count=count,
            **kwargs,
        )

    def _perform_in_latent(self, y, *args, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
        """Map latents to (mean, logvar) and sample z = mean + eps * std.

        Returns ``(mean, logvar)`` when ``return_dist``;
        ``(z, mean, logvar)`` when ``return_lat_dist``; otherwise ``z``.
        Without ``epsilon`` the mean is returned deterministically.
        """
        mean = self.lin_mean(y)
        logvar = self.lin_logvar(y)

        if return_dist:
            return mean, logvar

        std = torch.exp(0.5 * logvar)
        if epsilon is not None:
            # assumes a 4-D epsilon carries two leading sharding axes -- TODO confirm
            if len(epsilon.shape) == 4:
                epsilon = epsilon[0, 0]  # to allow tpu sharding
            z = mean + epsilon * std
        else:
            z = mean

        if return_lat_dist:
            return z, mean, logvar
        return z
457
+
458
class RRAE_CNN(CNN_Autoencoder):
    """Subclass of RRAEs with the strong formulation for inputs of
    dimension (channels, width, height).
    """

    def __init__(self, channels, height, width, latent_size, k_max, **kwargs):
        # Forward k_max so the base autoencoder actually truncates;
        # previously it was accepted here and silently dropped.
        super().__init__(
            channels,
            height,
            width,
            latent_size,
            k_max=k_max,
            **kwargs,
        )

    def _perform_in_latent(self, y, *args, **kwargs):
        """Truncated-SVD latent operation (strong RRAE)."""
        return latent_func_strong_RRAE(self, y, *args, **kwargs)
475
+
476
+
477
class RRAE_CNN3D(CNN3D_Autoencoder):
    """Strong-formulation RRAE for 3-D CNN inputs."""

    def __init__(self, depth, width, height, channels, latent_size, k_max, **kwargs):
        # NOTE: _perform_in_latent is both passed to the base here and
        # overridden by the method below; the kwarg is redundant with
        # the override.
        super().__init__(
            depth,
            width,
            height,
            channels,
            latent_size,
            k_max,
            _perform_in_latent=latent_func_strong_RRAE,
            **kwargs,
        )

    def _perform_in_latent(self, y, *args, **kwargs):
        """Truncated-SVD latent operation (strong RRAE)."""
        return latent_func_strong_RRAE(self, y, *args, **kwargs)
494
+
495
+
496
class VRRAE_CNN(CNN_Autoencoder):
    """Variational strong-RRAE CNN.

    The SVD-projected coefficients are passed through mean /
    log-variance heads (of size ``k_max``) and reparameterized.
    """

    lin_mean: Linear     # coefficient -> mean head (size k_max)
    lin_logvar: Linear   # coefficient -> log-variance head
    typ: str             # "eye" (identity mean) or "trainable"

    def __init__(self, channels, height, width, latent_size, k_max, typ="eye", *, count=1, **kwargs):
        super().__init__(
            channels,
            height,
            width,
            latent_size,
            count=count,
            **kwargs,
        )
        v_Linear = vmap_wrap(Linear, -1, count=count)
        self.lin_mean = v_Linear(k_max, k_max)
        self.lin_logvar = v_Linear(k_max, k_max)
        self.typ = typ

    def _perform_in_latent(self, y, *args, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
        """Delegate to the variational strong-RRAE latent routine."""
        return latent_func_var_strong_RRAE(self, y, k_max, epsilon, return_dist, return_lat_dist, **kwargs)
517
+
518
+
519
class Vanilla_AE_CNN(CNN_Autoencoder):
    """Vanilla CNN autoencoder.

    Equivalent to the strong RRAE with ``k_max = -1``: every latent mode
    is kept and no truncation is performed.
    """

    def __init__(self, channels, height, width, latent_size, **kwargs):
        # The vanilla model has no extra linear layers in the encoder;
        # warn and discard a linear_l request.
        try:
            del kwargs["linear_l"]
        except KeyError:
            pass
        else:
            warnings.warn("linear_l can not be specified for Vanilla_CNN")

        super().__init__(channels, height, width, latent_size, **kwargs)
532
+
533
+
534
class IRMAE_CNN(CNN_Autoencoder):
    """IRMAE CNN autoencoder.

    The encoder's MLP head ends with ``linear_l`` extra linear layers
    (passed through ``kwargs_enc["kwargs_mlp"]``).
    """

    def __init__(
        self, channels, height, width, latent_size, linear_l=None, **kwargs
    ):
        assert linear_l is not None, "linear_l must be specified for IRMAE_CNN"

        # Build a fresh kwargs_enc instead of mutating the caller's dict
        # in place (the original aliased and modified it).
        kwargs_enc = dict(kwargs.get("kwargs_enc", {}))
        kwargs_enc["kwargs_mlp"] = {"linear_l": linear_l}
        kwargs["kwargs_enc"] = kwargs_enc

        super().__init__(
            channels,
            height,
            width,
            latent_size,
            **kwargs,
        )
554
+
555
+
556
class LoRAE_CNN(IRMAE_CNN):
    """LoRAE CNN: an IRMAE_CNN with ``linear_l`` fixed to 1."""

    def __init__(self, channels, height, width, latent_size, **kwargs):
        # linear_l is fixed to 1; tolerate an explicit linear_l=1 but
        # reject anything else, and pop it so super() does not receive
        # the kwarg twice (previously a TypeError).
        if "linear_l" in kwargs:
            if kwargs["linear_l"] != 1:
                raise ValueError("linear_l can not be specified for LoRAE_CNN")
            kwargs.pop("linear_l")

        super().__init__(
            channels, height, width, latent_size, linear_l=1, **kwargs
        )
566
+
567
+
568
class CNN1D_Autoencoder(CNN_Autoencoder):
    """1-D convolutional autoencoder.

    A ``CNN_Autoencoder`` with ``dimension=1``; the single spatial axis
    is passed as the height and the width is None.
    """

    def __init__(
        self,
        channels,
        input_dim,
        latent_size,
        latent_size_after=None,
        *,
        count=1,
        kwargs_enc={},
        kwargs_dec={},
        **kwargs,
    ):
        # Forward **kwargs to the base class; previously they were
        # accepted here but silently discarded.
        super().__init__(
            channels,
            input_dim,
            None,  # no width axis in the 1-D case
            latent_size,
            latent_size_after=latent_size_after,
            count=count,
            dimension=1,
            kwargs_enc=kwargs_enc,
            kwargs_dec=kwargs_dec,
            **kwargs,
        )
591
+
592
class RRAE_CNN1D(CNN1D_Autoencoder):
    """Subclass of RRAEs with the strong formulation for inputs of
    dimension (channels, input_dim).
    """

    def __init__(self, channels, input_dim, latent_size, k_max, **kwargs):
        # Forward k_max so the base autoencoder can truncate; previously
        # it was accepted here and silently dropped.
        super().__init__(
            channels,
            input_dim,
            latent_size,
            k_max=k_max,
            **kwargs,
        )

    def _perform_in_latent(self, y, *args, **kwargs):
        """Truncated-SVD latent operation (strong RRAE)."""
        return latent_func_strong_RRAE(self, y, *args, **kwargs)
608
+
609
+
610
class VRRAE_CNN1D(CNN1D_Autoencoder):
    """Variational strong-RRAE 1-D CNN.

    The SVD-projected coefficients are passed through mean /
    log-variance heads (of size ``k_max``) and reparameterized.
    """

    lin_mean: Linear     # coefficient -> mean head (size k_max)
    lin_logvar: Linear   # coefficient -> log-variance head
    typ: str             # "eye" (identity mean) or "trainable"

    def __init__(self, channels, input_dim, latent_size, k_max, typ="eye", *, count=1, **kwargs):
        # NOTE(review): attributes are assigned before super().__init__();
        # this relies on the base class tolerating pre-init attribute
        # assignment -- confirm.
        v_Linear = vmap_wrap(Linear, -1, count=count)
        self.lin_mean = v_Linear(k_max, k_max,)
        self.lin_logvar = v_Linear(k_max, k_max)
        self.typ = typ
        super().__init__(
            channels,
            input_dim,
            latent_size,
            count=count,
            **kwargs,
        )

    def _perform_in_latent(self, y, *args, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
        """Delegate to the variational strong-RRAE latent routine."""
        return latent_func_var_strong_RRAE(self, y, k_max, epsilon, return_dist, return_lat_dist, **kwargs)

    def get_basis_coeffs(self, x, *args, **kwargs):
        """Encode ``x`` and return the latent (basis, coefficients) pair."""
        return self.perform_in_latent(self.encode(x), *args, get_basis_coeffs=True, **kwargs)

    def decode_coeffs(self, c, basis, *args, **kwargs):
        """Reconstruct from coefficients ``c`` expressed in ``basis``."""
        return self.decode(basis @ c, *args, **kwargs)
636
+
@@ -0,0 +1 @@
1
+ from .AE_classes import *
RRAEsTorch/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .training_classes import *