RRAEsTorch 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RRAEsTorch/AE_classes/AE_classes.py +18 -14
- RRAEsTorch/tests/test_AE_classes_CNN.py +20 -26
- RRAEsTorch/tests/test_AE_classes_MLP.py +20 -28
- RRAEsTorch/tests/test_fitting_CNN.py +14 -14
- RRAEsTorch/tests/test_fitting_MLP.py +11 -13
- RRAEsTorch/tests/test_save.py +11 -11
- RRAEsTorch/training_classes/training_classes.py +78 -121
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.7.dist-info}/METADATA +1 -2
- rraestorch-0.1.7.dist-info/RECORD +22 -0
- RRAEsTorch/tests/test_wrappers.py +0 -56
- RRAEsTorch/utilities/utilities.py +0 -1562
- RRAEsTorch/wrappers/__init__.py +0 -1
- RRAEsTorch/wrappers/wrappers.py +0 -237
- rraestorch-0.1.5.dist-info/RECORD +0 -27
- rraestorch-0.1.5.dist-info/licenses/LICENSE copy +0 -21
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.7.dist-info}/WHEEL +0 -0
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -13,7 +13,6 @@ import os
|
|
|
13
13
|
import time
|
|
14
14
|
import dill
|
|
15
15
|
import shutil
|
|
16
|
-
from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
|
|
17
16
|
from functools import partial
|
|
18
17
|
from RRAEsTorch.trackers import (
|
|
19
18
|
Null_Tracker,
|
|
@@ -26,6 +25,8 @@ from prettytable import PrettyTable
|
|
|
26
25
|
import torch
|
|
27
26
|
from torch.utils.data import TensorDataset, DataLoader
|
|
28
27
|
|
|
28
|
+
from RRAEsTorch.utilities import get_basis
|
|
29
|
+
|
|
29
30
|
class Circular_list:
|
|
30
31
|
"""
|
|
31
32
|
Creates a list of fixed size.
|
|
@@ -157,40 +158,17 @@ class Trainor_class:
|
|
|
157
158
|
model_cls=None,
|
|
158
159
|
folder="",
|
|
159
160
|
file=None,
|
|
160
|
-
out_train=None,
|
|
161
|
-
norm_in="None",
|
|
162
|
-
norm_out="None",
|
|
163
|
-
methods_map=["__call__"],
|
|
164
|
-
methods_norm_in=["__call__"],
|
|
165
|
-
methods_norm_out=["__call__"],
|
|
166
|
-
call_map_count=1,
|
|
167
|
-
call_map_axis=-1,
|
|
168
161
|
**kwargs,
|
|
169
162
|
):
|
|
170
163
|
if model_cls is not None:
|
|
171
164
|
orig_model_cls = model_cls
|
|
172
|
-
model_cls = vmap_wrap(orig_model_cls, call_map_axis, call_map_count, methods_map)
|
|
173
|
-
model_cls = norm_wrap(model_cls, in_train, norm_in, None, out_train, norm_out, None, methods_norm_in, methods_norm_out)
|
|
174
165
|
self.model = model_cls(**kwargs)
|
|
175
|
-
params_in = self.model.params_in
|
|
176
|
-
params_out = self.model.params_out
|
|
177
166
|
else:
|
|
178
167
|
orig_model_cls = None
|
|
179
|
-
params_in = None
|
|
180
|
-
params_out = None
|
|
181
168
|
|
|
182
169
|
self.all_kwargs = {
|
|
183
170
|
"kwargs": kwargs,
|
|
184
|
-
"params_in": params_in,
|
|
185
|
-
"params_out": params_out,
|
|
186
|
-
"norm_in": norm_in,
|
|
187
|
-
"norm_out": norm_out,
|
|
188
|
-
"call_map_axis": call_map_axis,
|
|
189
|
-
"call_map_count": call_map_count,
|
|
190
|
-
"orig_model_cls": orig_model_cls,
|
|
191
|
-
"methods_map": methods_map,
|
|
192
|
-
"methods_norm_in": methods_norm_in,
|
|
193
|
-
"methods_norm_out": methods_norm_out,
|
|
171
|
+
"orig_model_cls": orig_model_cls
|
|
194
172
|
}
|
|
195
173
|
|
|
196
174
|
self.folder = folder
|
|
@@ -229,7 +207,8 @@ class Trainor_class:
|
|
|
229
207
|
save_losses=False,
|
|
230
208
|
input_val=None,
|
|
231
209
|
output_val=None,
|
|
232
|
-
latent_size=0
|
|
210
|
+
latent_size=0,
|
|
211
|
+
device="cpu"
|
|
233
212
|
):
|
|
234
213
|
assert isinstance(input, torch.Tensor), "Input should be a torch tensor"
|
|
235
214
|
assert isinstance(output, torch.Tensor), "Output should be a torch tensor"
|
|
@@ -301,7 +280,12 @@ class Trainor_class:
|
|
|
301
280
|
all_losses = []
|
|
302
281
|
|
|
303
282
|
|
|
304
|
-
|
|
283
|
+
model = model.to(device)
|
|
284
|
+
|
|
285
|
+
if input_val is not None:
|
|
286
|
+
dataset_val = TensorDataset(input_val, output_val, torch.arange(0, input_val.shape[0], 1))
|
|
287
|
+
dataloader_val = DataLoader(dataset_val, batch_size=input_val.shape[0], shuffle=False)
|
|
288
|
+
|
|
305
289
|
# Outler Loop
|
|
306
290
|
for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
|
|
307
291
|
try:
|
|
@@ -310,15 +294,13 @@ class Trainor_class:
|
|
|
310
294
|
filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3
|
|
311
295
|
)
|
|
312
296
|
|
|
313
|
-
if (batch_size > input.shape[
|
|
314
|
-
print(f"Setting batch size to: {input.shape[
|
|
315
|
-
batch_size = input.shape[
|
|
297
|
+
if (batch_size > input.shape[0]) or batch_size == -1:
|
|
298
|
+
print(f"Setting batch size to: {input.shape[0]}")
|
|
299
|
+
batch_size = input.shape[0]
|
|
316
300
|
|
|
317
301
|
# Inner loop (batch)
|
|
318
|
-
inputT = input.permute(*range(input.ndim - 1, -1, -1))
|
|
319
|
-
outputT = output.permute(*range(output.ndim - 1, -1, -1))
|
|
320
302
|
|
|
321
|
-
dataset = TensorDataset(
|
|
303
|
+
dataset = TensorDataset(input, output, torch.arange(0, input.shape[0], 1))
|
|
322
304
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
323
305
|
data_iter = iter(dataloader)
|
|
324
306
|
|
|
@@ -331,15 +313,16 @@ class Trainor_class:
|
|
|
331
313
|
input_b, out_b, idx_b = next(data_iter)
|
|
332
314
|
|
|
333
315
|
start_time = time.perf_counter() # Start time
|
|
334
|
-
|
|
335
|
-
out_b = self.model.norm_out.default(None, pre_func_out(out_b)) # Pre-process batch out values
|
|
336
|
-
out_b = out_b.permute(*range(out_b.ndim - 1, -1, -1))
|
|
316
|
+
out_b = pre_func_out(out_b) # Pre-process batch out values
|
|
337
317
|
input_b = pre_func_inp(input_b) # Pre-process batch input values
|
|
338
|
-
epsilon = eps_fn(latent_size, input_b.shape[
|
|
318
|
+
epsilon = eps_fn(latent_size, input_b.shape[0])
|
|
339
319
|
|
|
340
320
|
step_kwargs = merge_dicts(loss_kwargs, track_params)
|
|
341
321
|
|
|
342
322
|
# Compute loss
|
|
323
|
+
input_b = input_b.to(device)
|
|
324
|
+
out_b = out_b.to(device)
|
|
325
|
+
|
|
343
326
|
loss, model, optimizer_tr, (aux, extra_track) = make_step(
|
|
344
327
|
model,
|
|
345
328
|
input_b,
|
|
@@ -351,12 +334,14 @@ class Trainor_class:
|
|
|
351
334
|
)
|
|
352
335
|
|
|
353
336
|
if input_val is not None:
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
337
|
+
val_loss = []
|
|
338
|
+
for input_vb, out_vb, idx_b in dataloader_val:
|
|
339
|
+
out_vb = pre_func_out(out_vb)
|
|
340
|
+
val_loss_batch = loss_fun(
|
|
341
|
+
model, input_vb.to(device), out_vb.to(device), idx=idx_b, epsilon=None, **step_kwargs
|
|
342
|
+
)[0]
|
|
343
|
+
val_loss.append(val_loss_batch.item())
|
|
344
|
+
val_loss = sum(val_loss) / len(val_loss)
|
|
360
345
|
aux["val_loss"] = val_loss
|
|
361
346
|
else:
|
|
362
347
|
aux["val_loss"] = None
|
|
@@ -496,18 +481,16 @@ class Trainor_class:
|
|
|
496
481
|
hasattr(self, "batch_size") or batch_size is not None
|
|
497
482
|
), "You should either provide a batch_size or fit the model first."
|
|
498
483
|
|
|
499
|
-
|
|
500
|
-
dataset = TensorDataset(x_train_oT)
|
|
484
|
+
dataset = TensorDataset(x_train_o)
|
|
501
485
|
batch_size = self.batch_size if batch_size is None else batch_size
|
|
502
486
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
503
487
|
|
|
504
488
|
pred = []
|
|
505
489
|
for x_b in dataloader:
|
|
506
|
-
|
|
507
|
-
pred_batch = call_func(x_bT)
|
|
490
|
+
pred_batch = call_func(x_b[0])
|
|
508
491
|
pred.append(pred_batch)
|
|
509
492
|
|
|
510
|
-
y_pred_train_o = torch.concatenate(pred
|
|
493
|
+
y_pred_train_o = torch.concatenate(pred)
|
|
511
494
|
|
|
512
495
|
self.error_train_o = (
|
|
513
496
|
torch.linalg.norm(y_pred_train_o - y_train_o)
|
|
@@ -516,8 +499,8 @@ class Trainor_class:
|
|
|
516
499
|
)
|
|
517
500
|
print("Train error on original output: ", self.error_train_o)
|
|
518
501
|
|
|
519
|
-
y_pred_train =
|
|
520
|
-
y_train =
|
|
502
|
+
y_pred_train = pre_func_out(y_pred_train_o)
|
|
503
|
+
y_train = pre_func_out(y_train_o)
|
|
521
504
|
self.error_train = (
|
|
522
505
|
torch.linalg.norm(y_pred_train - y_train) / torch.linalg.norm(y_train) * 100
|
|
523
506
|
)
|
|
@@ -525,16 +508,14 @@ class Trainor_class:
|
|
|
525
508
|
|
|
526
509
|
if x_test_o is not None:
|
|
527
510
|
y_test_o = pre_func_out(y_test_o)
|
|
528
|
-
|
|
529
|
-
dataset = TensorDataset(x_test_oT)
|
|
511
|
+
dataset = TensorDataset(x_test_o)
|
|
530
512
|
batch_size = self.batch_size if batch_size is None else batch_size
|
|
531
513
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
532
514
|
pred = []
|
|
533
515
|
for x_b in dataloader:
|
|
534
|
-
|
|
535
|
-
pred_batch = call_func(x_bT)
|
|
516
|
+
pred_batch = call_func(x_b[0])
|
|
536
517
|
pred.append(pred_batch)
|
|
537
|
-
y_pred_test_o = torch.concatenate(pred
|
|
518
|
+
y_pred_test_o = torch.concatenate(pred)
|
|
538
519
|
self.error_test_o = (
|
|
539
520
|
torch.linalg.norm(y_pred_test_o - y_test_o)
|
|
540
521
|
/ torch.linalg.norm(y_test_o)
|
|
@@ -543,8 +524,8 @@ class Trainor_class:
|
|
|
543
524
|
|
|
544
525
|
print("Test error on original output: ", self.error_test_o)
|
|
545
526
|
|
|
546
|
-
y_test =
|
|
547
|
-
y_pred_test =
|
|
527
|
+
y_test = pre_func_out(y_test_o)
|
|
528
|
+
y_pred_test = pre_func_out(y_pred_test_o)
|
|
548
529
|
self.error_test = (
|
|
549
530
|
torch.linalg.norm(y_pred_test - y_test) / torch.linalg.norm(y_test) * 100
|
|
550
531
|
)
|
|
@@ -618,7 +599,7 @@ class Trainor_class:
|
|
|
618
599
|
dill.dump(obj, f)
|
|
619
600
|
print(f"Object saved in {filename}")
|
|
620
601
|
|
|
621
|
-
def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None, **fn_kwargs):
|
|
602
|
+
def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None, device="cpu",**fn_kwargs):
|
|
622
603
|
"""NOTE: fn_kwargs defines the functions of the model
|
|
623
604
|
(e.g. final_activation, inner activation), if
|
|
624
605
|
needed to be saved/loaded on different devices/OS.
|
|
@@ -638,29 +619,12 @@ class Trainor_class:
|
|
|
638
619
|
else:
|
|
639
620
|
orig_model_cls = orig_model_cls
|
|
640
621
|
kwargs = self.all_kwargs["kwargs"]
|
|
641
|
-
self.call_map_axis = self.all_kwargs["call_map_axis"]
|
|
642
|
-
self.call_map_count = self.all_kwargs["call_map_count"]
|
|
643
|
-
self.params_in = self.all_kwargs["params_in"]
|
|
644
|
-
self.params_out = self.all_kwargs["params_out"]
|
|
645
|
-
self.norm_in = self.all_kwargs["norm_in"]
|
|
646
|
-
self.norm_out = self.all_kwargs["norm_out"]
|
|
647
|
-
try:
|
|
648
|
-
self.methods_map = self.all_kwargs["methods_map"]
|
|
649
|
-
self.methods_norm_in = self.all_kwargs["methods_norm_in"]
|
|
650
|
-
self.methods_norm_out = self.all_kwargs["methods_norm_out"]
|
|
651
|
-
except:
|
|
652
|
-
self.methods_map = ["encode", "decode"]
|
|
653
|
-
self.methods_norm_in = ["encode"]
|
|
654
|
-
self.methods_norm_out = ["decode"]
|
|
655
622
|
|
|
656
623
|
kwargs.update(fn_kwargs)
|
|
657
|
-
|
|
658
|
-
model_cls = vmap_wrap(orig_model_cls, self.call_map_axis, self.call_map_count, self.methods_map)
|
|
659
|
-
model_cls = norm_wrap(model_cls, None, self.norm_in, self.params_in, None, self.norm_out, self.params_out, self.methods_norm_in, self.methods_norm_out)
|
|
660
624
|
|
|
661
|
-
model =
|
|
625
|
+
model = orig_model_cls(**kwargs)
|
|
662
626
|
model.load_state_dict(save_dict["model_state_dict"])
|
|
663
|
-
self.model = model
|
|
627
|
+
self.model = model.to(device)
|
|
664
628
|
attributes = save_dict["attr"]
|
|
665
629
|
|
|
666
630
|
for key in attributes:
|
|
@@ -671,7 +635,7 @@ class Trainor_class:
|
|
|
671
635
|
|
|
672
636
|
class AE_Trainor_class(Trainor_class):
|
|
673
637
|
def __init__(self, *args, **kwargs):
|
|
674
|
-
super().__init__(*args,
|
|
638
|
+
super().__init__(*args, **kwargs)
|
|
675
639
|
|
|
676
640
|
def fit(self, *args, training_kwargs, **kwargs):
|
|
677
641
|
if "pre_func_inp" not in kwargs:
|
|
@@ -830,11 +794,11 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
830
794
|
self.batch_size = 16 # default value
|
|
831
795
|
|
|
832
796
|
if ft_kwargs:
|
|
833
|
-
if "
|
|
834
|
-
|
|
835
|
-
ft_kwargs.pop("
|
|
797
|
+
if "get_basis_bool" in ft_kwargs:
|
|
798
|
+
get_basis_bool = ft_kwargs["get_basis_bool"]
|
|
799
|
+
ft_kwargs.pop("get_basis_bool")
|
|
836
800
|
else:
|
|
837
|
-
|
|
801
|
+
get_basis_bool = True
|
|
838
802
|
|
|
839
803
|
if "ft_end_type" in ft_kwargs:
|
|
840
804
|
ft_end_type = ft_kwargs["ft_end_type"]
|
|
@@ -849,8 +813,16 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
849
813
|
ft_end_type = "concat"
|
|
850
814
|
basis_call_kwargs = {}
|
|
851
815
|
|
|
816
|
+
device = ft_kwargs.get("device", "cpu")
|
|
817
|
+
|
|
818
|
+
if "AE_func" in ft_kwargs:
|
|
819
|
+
AE_func = ft_kwargs["AE_func"]
|
|
820
|
+
ft_kwargs.pop("AE_func")
|
|
821
|
+
else:
|
|
822
|
+
AE_func = lambda m:m
|
|
823
|
+
|
|
852
824
|
ft_model, ft_track_params = self.fine_tune_basis(
|
|
853
|
-
None, args=args, kwargs=ft_kwargs,
|
|
825
|
+
None, args=args, kwargs=ft_kwargs, get_basis_bool=get_basis_bool, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs, device=device, AE_func=AE_func
|
|
854
826
|
) # fine tune basis
|
|
855
827
|
self.ft_track_params = ft_track_params
|
|
856
828
|
else:
|
|
@@ -858,7 +830,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
858
830
|
ft_track_params = {}
|
|
859
831
|
return model, track_params, ft_model, ft_track_params
|
|
860
832
|
|
|
861
|
-
def fine_tune_basis(self, basis=None,
|
|
833
|
+
def fine_tune_basis(self, basis=None, get_basis_bool=True, end_type="concat", basis_call_kwargs={}, device="cpu", AE_func=lambda m:m, *, args, kwargs):
|
|
862
834
|
|
|
863
835
|
if "loss" in kwargs:
|
|
864
836
|
norm_loss_ = kwargs["loss"]
|
|
@@ -869,45 +841,29 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
869
841
|
)
|
|
870
842
|
|
|
871
843
|
if (basis is None):
|
|
872
|
-
|
|
873
|
-
if get_basis:
|
|
874
|
-
inp = args[0] if len(args) > 0 else kwargs["input"]
|
|
875
|
-
|
|
876
|
-
if "basis_batch_size" in kwargs:
|
|
877
|
-
basis_batch_size = kwargs["basis_batch_size"]
|
|
878
|
-
kwargs.pop("basis_batch_size")
|
|
879
|
-
else:
|
|
880
|
-
basis_batch_size = self.batch_size
|
|
844
|
+
inp = args[0] if len(args) > 0 else kwargs["input"]
|
|
881
845
|
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
print(all_bases.shape)
|
|
899
|
-
basis = torch.linalg.svd(all_bases, full_matrices=False)[0]
|
|
900
|
-
self.basis = basis[:, : self.track_params["k_max"]]
|
|
901
|
-
else:
|
|
902
|
-
self.basis = all_bases
|
|
903
|
-
else:
|
|
904
|
-
bas = self.model.latent(self.pre_func_inp(inp[..., 0:1]), get_basis_coeffs=True, **self.track_params)[0]
|
|
905
|
-
self.basis = torch.eye(bas.shape[0])
|
|
846
|
+
if "basis_batch_size" in kwargs:
|
|
847
|
+
basis_batch_size = kwargs["basis_batch_size"]
|
|
848
|
+
kwargs.pop("basis_batch_size")
|
|
849
|
+
else:
|
|
850
|
+
basis_batch_size = self.batch_size
|
|
851
|
+
|
|
852
|
+
model = self.model.to(device)
|
|
853
|
+
k_max = self.track_params["k_max"]
|
|
854
|
+
if isinstance(AE_func, list):
|
|
855
|
+
bases = []
|
|
856
|
+
for func in AE_func:
|
|
857
|
+
bases.append(get_basis(get_basis_bool, model, k_max, basis_batch_size, inp, end_type, device, basis_call_kwargs, self.pre_func_inp, func))
|
|
858
|
+
self.basis = bases
|
|
859
|
+
else:
|
|
860
|
+
self.basis = get_basis(get_basis_bool, model, k_max, basis_batch_size, inp, end_type, device, basis_call_kwargs, self.pre_func_inp, AE_func)
|
|
861
|
+
|
|
906
862
|
else:
|
|
907
863
|
self.basis = basis
|
|
908
864
|
|
|
909
865
|
def loss_fun(model, input, out, idx, epsilon, basis):
|
|
910
|
-
pred = model(input, epsilon=epsilon, apply_basis=basis
|
|
866
|
+
pred = model(input, epsilon=epsilon, apply_basis=basis)
|
|
911
867
|
aux = {"loss": norm_loss_(pred, out)}
|
|
912
868
|
return norm_loss_(pred, out), (aux, {})
|
|
913
869
|
|
|
@@ -919,7 +875,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
919
875
|
|
|
920
876
|
kwargs.setdefault("loss_kwargs", {}).update({"basis": self.basis})
|
|
921
877
|
|
|
922
|
-
fix_comp = lambda model: model._encode.parameters()
|
|
878
|
+
fix_comp = lambda model: AE_func(model)._encode.parameters()
|
|
923
879
|
print("Fine tuning the basis ...")
|
|
924
880
|
return super().fit(*args, fix_comp=fix_comp, training_kwargs=kwargs)
|
|
925
881
|
|
|
@@ -932,10 +888,10 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
932
888
|
batch_size=None,
|
|
933
889
|
pre_func_inp=lambda x: x,
|
|
934
890
|
pre_func_out=lambda x: x,
|
|
935
|
-
|
|
891
|
+
device="cpu",
|
|
936
892
|
):
|
|
937
893
|
|
|
938
|
-
call_func = lambda x: self.model(pre_func_inp(x), apply_basis=self.basis, epsilon=None)
|
|
894
|
+
call_func = lambda x: self.model(pre_func_inp(x.to(device)), apply_basis=self.basis.to(device), epsilon=None).to("cpu")
|
|
939
895
|
res = super().evaluate(
|
|
940
896
|
x_train_o,
|
|
941
897
|
y_train_o,
|
|
@@ -945,6 +901,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
945
901
|
call_func=call_func,
|
|
946
902
|
pre_func_inp=pre_func_inp,
|
|
947
903
|
pre_func_out=pre_func_out,
|
|
904
|
+
device=device,
|
|
948
905
|
)
|
|
949
906
|
return res
|
|
950
907
|
|
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: RRAEsTorch
|
|
3
|
-
Version: 0.1.5
|
|
3
|
+
Version: 0.1.7
|
|
4
4
|
Summary: A repo for RRAEs in PyTorch.
|
|
5
5
|
Author-email: Jad Mounayer <jad.mounayer@outlook.com>
|
|
6
6
|
License: MIT
|
|
7
7
|
License-File: LICENSE
|
|
8
|
-
License-File: LICENSE copy
|
|
9
8
|
Requires-Python: >=3.10
|
|
10
9
|
Requires-Dist: dill
|
|
11
10
|
Requires-Dist: jaxtyping
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
RRAEsTorch/__init__.py,sha256=f234R6usRCqIgmBmiXyZNIHa7VrDe5E-KZO0Y6Ek5AQ,33
|
|
2
|
+
RRAEsTorch/config.py,sha256=bQPwc_2KTvhglH_WIRSb5_6CpUQQj9AGpfqBp8_kuys,2931
|
|
3
|
+
RRAEsTorch/AE_base/AE_base.py,sha256=Eeo_I7p5P-357rnOmCuFxosJgmBg4KPyMA8n70sTV7U,3368
|
|
4
|
+
RRAEsTorch/AE_base/__init__.py,sha256=95YfMgEWzIFAkm--Ci-a9YPSGfCs2PDAK2sbfScT7oo,24
|
|
5
|
+
RRAEsTorch/AE_classes/AE_classes.py,sha256=iRIcA7iTTDO9Ji90ZKDUmizBswqKRIbnLcwm1hTbVnY,18096
|
|
6
|
+
RRAEsTorch/AE_classes/__init__.py,sha256=inM2_YPJG8T-lwx-CUg-zL2EMltmROQAlNZeZmnvVGA,27
|
|
7
|
+
RRAEsTorch/tests/test_AE_classes_CNN.py,sha256=nxWdxWhpoVbYxMsr-U9mmd7vZ8s0pBaLsDvCsti2r-I,2913
|
|
8
|
+
RRAEsTorch/tests/test_AE_classes_MLP.py,sha256=Aa8PhhsWblw8tJATVEfw9DLFA5vK-0lSzG6BySwmTiM,2413
|
|
9
|
+
RRAEsTorch/tests/test_fitting_CNN.py,sha256=VlWXxFMDul1MTMt1TciezYZpKifLTl1YsIFUAlAxc1w,2796
|
|
10
|
+
RRAEsTorch/tests/test_fitting_MLP.py,sha256=FVw6ObzOZJG5CntHoW45A1me5QFgIBI64rYPxjh1ujM,3288
|
|
11
|
+
RRAEsTorch/tests/test_mains.py,sha256=ivTXP7NypSlgmB9VR5g0yq5VEuPZJGOibDqBMjOxHow,1021
|
|
12
|
+
RRAEsTorch/tests/test_save.py,sha256=KdXwoF3ao7UVr8tEJ7sEHqNtUqIYmuoVEAvt4FigXIs,1982
|
|
13
|
+
RRAEsTorch/tests/test_stable_SVD.py,sha256=OimHPqw4f22qndyRzwJfNvTzzjP2CM-yHtfXCqkMBuA,1230
|
|
14
|
+
RRAEsTorch/trackers/__init__.py,sha256=3c9qcUMZiUfVr93rxFp6l11lIDthyK3PCY_-P-sNX3I,25
|
|
15
|
+
RRAEsTorch/trackers/trackers.py,sha256=Pn1ejMxMjAtvgDazFFwa3qiZhogG5GtXj4UIIFiBpuY,9127
|
|
16
|
+
RRAEsTorch/training_classes/__init__.py,sha256=K_Id4yhw640jp2JN15-0E4wJi4sPadi1fFRgovMV3kw,101
|
|
17
|
+
RRAEsTorch/training_classes/training_classes.py,sha256=tZedyKvi3f-phiByIG820Z_oMsG5vk9xQK_OS_04jIM,34014
|
|
18
|
+
RRAEsTorch/utilities/__init__.py,sha256=NtlizCcRW4qcsULXxWfjPk265rLJst0-GqWLRah2yDY,26
|
|
19
|
+
rraestorch-0.1.7.dist-info/METADATA,sha256=fdYJFLO7CHpG2IxfaC1zTFwiA6Lf2cyPqHiueU0lvs0,3028
|
|
20
|
+
rraestorch-0.1.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
21
|
+
rraestorch-0.1.7.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
|
|
22
|
+
rraestorch-0.1.7.dist-info/RECORD,,
|
|
@@ -1,56 +0,0 @@
|
|
|
1
|
-
from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
|
|
2
|
-
from torchvision.ops import MLP
|
|
3
|
-
import pytest
|
|
4
|
-
import numpy as np
|
|
5
|
-
import math
|
|
6
|
-
import numpy.random as random
|
|
7
|
-
import torch
|
|
8
|
-
|
|
9
|
-
def test_vmap_wrapper():
|
|
10
|
-
# Usually MLP only accepts a vector, here we give
|
|
11
|
-
# a tensor and vectorize over the last axis twice
|
|
12
|
-
data = random.normal(size=(50, 60, 600))
|
|
13
|
-
data = torch.tensor(data, dtype=torch.float32)
|
|
14
|
-
|
|
15
|
-
model_cls = vmap_wrap(MLP, -1, 2)
|
|
16
|
-
model = model_cls(50, [64, 100])
|
|
17
|
-
try:
|
|
18
|
-
model(data)
|
|
19
|
-
except ValueError:
|
|
20
|
-
pytest.fail("Vmap wrapper is not working properly.")
|
|
21
|
-
|
|
22
|
-
def test_norm_wrapper():
|
|
23
|
-
# Testing the keep_normalized kwarg
|
|
24
|
-
data = random.normal(size=(50,))
|
|
25
|
-
data = torch.tensor(data, dtype=torch.float32)
|
|
26
|
-
model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
|
|
27
|
-
model = model_cls(50, [64, 100])
|
|
28
|
-
try:
|
|
29
|
-
assert not torch.allclose(model(data), model(data, keep_normalized=True))
|
|
30
|
-
except AssertionError:
|
|
31
|
-
pytest.fail("The keep_normalized kwarg for norm wrapper is not behaving as expected.")
|
|
32
|
-
|
|
33
|
-
# Testing minmax with knwon mins and maxs
|
|
34
|
-
data = np.linspace(-1, 1, 100)
|
|
35
|
-
data = torch.tensor(data, dtype=torch.float32)
|
|
36
|
-
model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
|
|
37
|
-
model = model_cls(50, [64, 100])
|
|
38
|
-
try:
|
|
39
|
-
assert 0.55 == model.norm_in.default(None, 0.1)
|
|
40
|
-
assert -0.8 == model.inv_norm_out.default(None, 0.1)
|
|
41
|
-
except AssertionError:
|
|
42
|
-
pytest.fail("Something wrong with minmax wrapper.")
|
|
43
|
-
|
|
44
|
-
# Testing meanstd with knwon mean and std
|
|
45
|
-
data = random.normal(size=(50,))
|
|
46
|
-
data = (data-np.mean(data))/np.std(data)
|
|
47
|
-
data = data*2.0 + 1.0 # mean of 1 and std of 2
|
|
48
|
-
data = torch.tensor(data, dtype=torch.float32)
|
|
49
|
-
|
|
50
|
-
model_cls = norm_wrap(MLP, data, "meanstd", None, data, "meanstd", None)
|
|
51
|
-
model = model_cls(50, [64, 100])
|
|
52
|
-
try:
|
|
53
|
-
assert math.isclose(2, model.norm_in.default(None, 5), rel_tol=1e-1, abs_tol=1e-1)
|
|
54
|
-
assert math.isclose(7, model.inv_norm_out.default(None, 3), rel_tol=1e-1, abs_tol=1e-1)
|
|
55
|
-
except AssertionError:
|
|
56
|
-
pytest.fail("Something wrong with norm wrapper.")
|