RRAEsTorch 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RRAEsTorch/AE_classes/AE_classes.py +14 -12
- RRAEsTorch/tests/test_AE_classes_CNN.py +20 -26
- RRAEsTorch/tests/test_AE_classes_MLP.py +20 -28
- RRAEsTorch/tests/test_fitting_CNN.py +14 -14
- RRAEsTorch/tests/test_fitting_MLP.py +11 -13
- RRAEsTorch/tests/test_save.py +11 -11
- RRAEsTorch/training_classes/training_classes.py +55 -115
- {rraestorch-0.1.6.dist-info → rraestorch-0.1.7.dist-info}/METADATA +1 -1
- rraestorch-0.1.7.dist-info/RECORD +22 -0
- RRAEsTorch/tests/test_wrappers.py +0 -56
- RRAEsTorch/utilities/utilities.py +0 -1561
- RRAEsTorch/wrappers/__init__.py +0 -1
- RRAEsTorch/wrappers/wrappers.py +0 -237
- rraestorch-0.1.6.dist-info/RECORD +0 -26
- {rraestorch-0.1.6.dist-info → rraestorch-0.1.7.dist-info}/WHEEL +0 -0
- {rraestorch-0.1.6.dist-info → rraestorch-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -13,7 +13,6 @@ import os
|
|
|
13
13
|
import time
|
|
14
14
|
import dill
|
|
15
15
|
import shutil
|
|
16
|
-
from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
|
|
17
16
|
from functools import partial
|
|
18
17
|
from RRAEsTorch.trackers import (
|
|
19
18
|
Null_Tracker,
|
|
@@ -26,6 +25,8 @@ from prettytable import PrettyTable
|
|
|
26
25
|
import torch
|
|
27
26
|
from torch.utils.data import TensorDataset, DataLoader
|
|
28
27
|
|
|
28
|
+
from RRAEsTorch.utilities import get_basis
|
|
29
|
+
|
|
29
30
|
class Circular_list:
|
|
30
31
|
"""
|
|
31
32
|
Creates a list of fixed size.
|
|
@@ -157,40 +158,17 @@ class Trainor_class:
|
|
|
157
158
|
model_cls=None,
|
|
158
159
|
folder="",
|
|
159
160
|
file=None,
|
|
160
|
-
out_train=None,
|
|
161
|
-
norm_in="None",
|
|
162
|
-
norm_out="None",
|
|
163
|
-
methods_map=["__call__"],
|
|
164
|
-
methods_norm_in=["__call__"],
|
|
165
|
-
methods_norm_out=["__call__"],
|
|
166
|
-
call_map_count=1,
|
|
167
|
-
call_map_axis=-1,
|
|
168
161
|
**kwargs,
|
|
169
162
|
):
|
|
170
163
|
if model_cls is not None:
|
|
171
164
|
orig_model_cls = model_cls
|
|
172
|
-
model_cls = vmap_wrap(orig_model_cls, call_map_axis, call_map_count, methods_map)
|
|
173
|
-
model_cls = norm_wrap(model_cls, in_train, norm_in, None, out_train, norm_out, None, methods_norm_in, methods_norm_out)
|
|
174
165
|
self.model = model_cls(**kwargs)
|
|
175
|
-
params_in = self.model.params_in
|
|
176
|
-
params_out = self.model.params_out
|
|
177
166
|
else:
|
|
178
167
|
orig_model_cls = None
|
|
179
|
-
params_in = None
|
|
180
|
-
params_out = None
|
|
181
168
|
|
|
182
169
|
self.all_kwargs = {
|
|
183
170
|
"kwargs": kwargs,
|
|
184
|
-
"
|
|
185
|
-
"params_out": params_out,
|
|
186
|
-
"norm_in": norm_in,
|
|
187
|
-
"norm_out": norm_out,
|
|
188
|
-
"call_map_axis": call_map_axis,
|
|
189
|
-
"call_map_count": call_map_count,
|
|
190
|
-
"orig_model_cls": orig_model_cls,
|
|
191
|
-
"methods_map": methods_map,
|
|
192
|
-
"methods_norm_in": methods_norm_in,
|
|
193
|
-
"methods_norm_out": methods_norm_out,
|
|
171
|
+
"orig_model_cls": orig_model_cls
|
|
194
172
|
}
|
|
195
173
|
|
|
196
174
|
self.folder = folder
|
|
@@ -305,8 +283,8 @@ class Trainor_class:
|
|
|
305
283
|
model = model.to(device)
|
|
306
284
|
|
|
307
285
|
if input_val is not None:
|
|
308
|
-
dataset_val = TensorDataset(input_val
|
|
309
|
-
dataloader_val = DataLoader(dataset_val, batch_size=input_val.shape[
|
|
286
|
+
dataset_val = TensorDataset(input_val, output_val, torch.arange(0, input_val.shape[0], 1))
|
|
287
|
+
dataloader_val = DataLoader(dataset_val, batch_size=input_val.shape[0], shuffle=False)
|
|
310
288
|
|
|
311
289
|
# Outler Loop
|
|
312
290
|
for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
|
|
@@ -316,15 +294,13 @@ class Trainor_class:
|
|
|
316
294
|
filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3
|
|
317
295
|
)
|
|
318
296
|
|
|
319
|
-
if (batch_size > input.shape[
|
|
320
|
-
print(f"Setting batch size to: {input.shape[
|
|
321
|
-
batch_size = input.shape[
|
|
297
|
+
if (batch_size > input.shape[0]) or batch_size == -1:
|
|
298
|
+
print(f"Setting batch size to: {input.shape[0]}")
|
|
299
|
+
batch_size = input.shape[0]
|
|
322
300
|
|
|
323
301
|
# Inner loop (batch)
|
|
324
|
-
inputT = input.permute(*range(input.ndim - 1, -1, -1))
|
|
325
|
-
outputT = output.permute(*range(output.ndim - 1, -1, -1))
|
|
326
302
|
|
|
327
|
-
dataset = TensorDataset(
|
|
303
|
+
dataset = TensorDataset(input, output, torch.arange(0, input.shape[0], 1))
|
|
328
304
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
329
305
|
data_iter = iter(dataloader)
|
|
330
306
|
|
|
@@ -337,11 +313,9 @@ class Trainor_class:
|
|
|
337
313
|
input_b, out_b, idx_b = next(data_iter)
|
|
338
314
|
|
|
339
315
|
start_time = time.perf_counter() # Start time
|
|
340
|
-
|
|
341
|
-
out_b = self.model.norm_out.default(None, pre_func_out(out_b)) # Pre-process batch out values
|
|
342
|
-
out_b = out_b.permute(*range(out_b.ndim - 1, -1, -1))
|
|
316
|
+
out_b = pre_func_out(out_b) # Pre-process batch out values
|
|
343
317
|
input_b = pre_func_inp(input_b) # Pre-process batch input values
|
|
344
|
-
epsilon = eps_fn(latent_size, input_b.shape[
|
|
318
|
+
epsilon = eps_fn(latent_size, input_b.shape[0])
|
|
345
319
|
|
|
346
320
|
step_kwargs = merge_dicts(loss_kwargs, track_params)
|
|
347
321
|
|
|
@@ -362,10 +336,7 @@ class Trainor_class:
|
|
|
362
336
|
if input_val is not None:
|
|
363
337
|
val_loss = []
|
|
364
338
|
for input_vb, out_vb, idx_b in dataloader_val:
|
|
365
|
-
|
|
366
|
-
out_vb = self.model.norm_out.default(None, pre_func_out(out_vb)) # Pre-process batch out values
|
|
367
|
-
out_vb = out_vb.permute(*range(out_vb.ndim - 1, -1, -1))
|
|
368
|
-
|
|
339
|
+
out_vb = pre_func_out(out_vb)
|
|
369
340
|
val_loss_batch = loss_fun(
|
|
370
341
|
model, input_vb.to(device), out_vb.to(device), idx=idx_b, epsilon=None, **step_kwargs
|
|
371
342
|
)[0]
|
|
@@ -510,18 +481,16 @@ class Trainor_class:
|
|
|
510
481
|
hasattr(self, "batch_size") or batch_size is not None
|
|
511
482
|
), "You should either provide a batch_size or fit the model first."
|
|
512
483
|
|
|
513
|
-
|
|
514
|
-
dataset = TensorDataset(x_train_oT)
|
|
484
|
+
dataset = TensorDataset(x_train_o)
|
|
515
485
|
batch_size = self.batch_size if batch_size is None else batch_size
|
|
516
486
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
517
487
|
|
|
518
488
|
pred = []
|
|
519
489
|
for x_b in dataloader:
|
|
520
|
-
|
|
521
|
-
pred_batch = call_func(x_bT)
|
|
490
|
+
pred_batch = call_func(x_b[0])
|
|
522
491
|
pred.append(pred_batch)
|
|
523
492
|
|
|
524
|
-
y_pred_train_o = torch.concatenate(pred
|
|
493
|
+
y_pred_train_o = torch.concatenate(pred)
|
|
525
494
|
|
|
526
495
|
self.error_train_o = (
|
|
527
496
|
torch.linalg.norm(y_pred_train_o - y_train_o)
|
|
@@ -530,8 +499,8 @@ class Trainor_class:
|
|
|
530
499
|
)
|
|
531
500
|
print("Train error on original output: ", self.error_train_o)
|
|
532
501
|
|
|
533
|
-
y_pred_train =
|
|
534
|
-
y_train =
|
|
502
|
+
y_pred_train = pre_func_out(y_pred_train_o)
|
|
503
|
+
y_train = pre_func_out(y_train_o)
|
|
535
504
|
self.error_train = (
|
|
536
505
|
torch.linalg.norm(y_pred_train - y_train) / torch.linalg.norm(y_train) * 100
|
|
537
506
|
)
|
|
@@ -539,16 +508,14 @@ class Trainor_class:
|
|
|
539
508
|
|
|
540
509
|
if x_test_o is not None:
|
|
541
510
|
y_test_o = pre_func_out(y_test_o)
|
|
542
|
-
|
|
543
|
-
dataset = TensorDataset(x_test_oT)
|
|
511
|
+
dataset = TensorDataset(x_test_o)
|
|
544
512
|
batch_size = self.batch_size if batch_size is None else batch_size
|
|
545
513
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
546
514
|
pred = []
|
|
547
515
|
for x_b in dataloader:
|
|
548
|
-
|
|
549
|
-
pred_batch = call_func(x_bT)
|
|
516
|
+
pred_batch = call_func(x_b[0])
|
|
550
517
|
pred.append(pred_batch)
|
|
551
|
-
y_pred_test_o = torch.concatenate(pred
|
|
518
|
+
y_pred_test_o = torch.concatenate(pred)
|
|
552
519
|
self.error_test_o = (
|
|
553
520
|
torch.linalg.norm(y_pred_test_o - y_test_o)
|
|
554
521
|
/ torch.linalg.norm(y_test_o)
|
|
@@ -557,8 +524,8 @@ class Trainor_class:
|
|
|
557
524
|
|
|
558
525
|
print("Test error on original output: ", self.error_test_o)
|
|
559
526
|
|
|
560
|
-
y_test =
|
|
561
|
-
y_pred_test =
|
|
527
|
+
y_test = pre_func_out(y_test_o)
|
|
528
|
+
y_pred_test = pre_func_out(y_pred_test_o)
|
|
562
529
|
self.error_test = (
|
|
563
530
|
torch.linalg.norm(y_pred_test - y_test) / torch.linalg.norm(y_test) * 100
|
|
564
531
|
)
|
|
@@ -652,27 +619,10 @@ class Trainor_class:
|
|
|
652
619
|
else:
|
|
653
620
|
orig_model_cls = orig_model_cls
|
|
654
621
|
kwargs = self.all_kwargs["kwargs"]
|
|
655
|
-
self.call_map_axis = self.all_kwargs["call_map_axis"]
|
|
656
|
-
self.call_map_count = self.all_kwargs["call_map_count"]
|
|
657
|
-
self.params_in = self.all_kwargs["params_in"]
|
|
658
|
-
self.params_out = self.all_kwargs["params_out"]
|
|
659
|
-
self.norm_in = self.all_kwargs["norm_in"]
|
|
660
|
-
self.norm_out = self.all_kwargs["norm_out"]
|
|
661
|
-
try:
|
|
662
|
-
self.methods_map = self.all_kwargs["methods_map"]
|
|
663
|
-
self.methods_norm_in = self.all_kwargs["methods_norm_in"]
|
|
664
|
-
self.methods_norm_out = self.all_kwargs["methods_norm_out"]
|
|
665
|
-
except:
|
|
666
|
-
self.methods_map = ["encode", "decode"]
|
|
667
|
-
self.methods_norm_in = ["encode"]
|
|
668
|
-
self.methods_norm_out = ["decode"]
|
|
669
622
|
|
|
670
623
|
kwargs.update(fn_kwargs)
|
|
671
|
-
|
|
672
|
-
model_cls = vmap_wrap(orig_model_cls, self.call_map_axis, self.call_map_count, self.methods_map)
|
|
673
|
-
model_cls = norm_wrap(model_cls, None, self.norm_in, self.params_in, None, self.norm_out, self.params_out, self.methods_norm_in, self.methods_norm_out)
|
|
674
624
|
|
|
675
|
-
model =
|
|
625
|
+
model = orig_model_cls(**kwargs)
|
|
676
626
|
model.load_state_dict(save_dict["model_state_dict"])
|
|
677
627
|
self.model = model.to(device)
|
|
678
628
|
attributes = save_dict["attr"]
|
|
@@ -685,7 +635,7 @@ class Trainor_class:
|
|
|
685
635
|
|
|
686
636
|
class AE_Trainor_class(Trainor_class):
|
|
687
637
|
def __init__(self, *args, **kwargs):
|
|
688
|
-
super().__init__(*args,
|
|
638
|
+
super().__init__(*args, **kwargs)
|
|
689
639
|
|
|
690
640
|
def fit(self, *args, training_kwargs, **kwargs):
|
|
691
641
|
if "pre_func_inp" not in kwargs:
|
|
@@ -844,11 +794,11 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
844
794
|
self.batch_size = 16 # default value
|
|
845
795
|
|
|
846
796
|
if ft_kwargs:
|
|
847
|
-
if "
|
|
848
|
-
|
|
849
|
-
ft_kwargs.pop("
|
|
797
|
+
if "get_basis_bool" in ft_kwargs:
|
|
798
|
+
get_basis_bool = ft_kwargs["get_basis_bool"]
|
|
799
|
+
ft_kwargs.pop("get_basis_bool")
|
|
850
800
|
else:
|
|
851
|
-
|
|
801
|
+
get_basis_bool = True
|
|
852
802
|
|
|
853
803
|
if "ft_end_type" in ft_kwargs:
|
|
854
804
|
ft_end_type = ft_kwargs["ft_end_type"]
|
|
@@ -865,8 +815,14 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
865
815
|
|
|
866
816
|
device = ft_kwargs.get("device", "cpu")
|
|
867
817
|
|
|
818
|
+
if "AE_func" in ft_kwargs:
|
|
819
|
+
AE_func = ft_kwargs["AE_func"]
|
|
820
|
+
ft_kwargs.pop("AE_func")
|
|
821
|
+
else:
|
|
822
|
+
AE_func = lambda m:m
|
|
823
|
+
|
|
868
824
|
ft_model, ft_track_params = self.fine_tune_basis(
|
|
869
|
-
None, args=args, kwargs=ft_kwargs,
|
|
825
|
+
None, args=args, kwargs=ft_kwargs, get_basis_bool=get_basis_bool, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs, device=device, AE_func=AE_func
|
|
870
826
|
) # fine tune basis
|
|
871
827
|
self.ft_track_params = ft_track_params
|
|
872
828
|
else:
|
|
@@ -874,7 +830,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
874
830
|
ft_track_params = {}
|
|
875
831
|
return model, track_params, ft_model, ft_track_params
|
|
876
832
|
|
|
877
|
-
def fine_tune_basis(self, basis=None,
|
|
833
|
+
def fine_tune_basis(self, basis=None, get_basis_bool=True, end_type="concat", basis_call_kwargs={}, device="cpu", AE_func=lambda m:m, *, args, kwargs):
|
|
878
834
|
|
|
879
835
|
if "loss" in kwargs:
|
|
880
836
|
norm_loss_ = kwargs["loss"]
|
|
@@ -885,45 +841,29 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
885
841
|
)
|
|
886
842
|
|
|
887
843
|
if (basis is None):
|
|
888
|
-
|
|
889
|
-
if get_basis:
|
|
890
|
-
inp = args[0] if len(args) > 0 else kwargs["input"]
|
|
891
|
-
|
|
892
|
-
if "basis_batch_size" in kwargs:
|
|
893
|
-
basis_batch_size = kwargs["basis_batch_size"]
|
|
894
|
-
kwargs.pop("basis_batch_size")
|
|
895
|
-
else:
|
|
896
|
-
basis_batch_size = self.batch_size
|
|
844
|
+
inp = args[0] if len(args) > 0 else kwargs["input"]
|
|
897
845
|
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
all_bases = torch.concatenate(all_bases, axis=1)
|
|
915
|
-
basis = torch.linalg.svd(all_bases, full_matrices=False)[0]
|
|
916
|
-
self.basis = basis[:, : self.track_params["k_max"]]
|
|
917
|
-
else:
|
|
918
|
-
self.basis = all_bases
|
|
919
|
-
else:
|
|
920
|
-
bas = model.latent(self.pre_func_inp(inp[..., 0:1].to(device)), get_basis_coeffs=True, **self.track_params)[0].to("cpu")
|
|
921
|
-
self.basis = torch.eye(bas.shape[0])
|
|
846
|
+
if "basis_batch_size" in kwargs:
|
|
847
|
+
basis_batch_size = kwargs["basis_batch_size"]
|
|
848
|
+
kwargs.pop("basis_batch_size")
|
|
849
|
+
else:
|
|
850
|
+
basis_batch_size = self.batch_size
|
|
851
|
+
|
|
852
|
+
model = self.model.to(device)
|
|
853
|
+
k_max = self.track_params["k_max"]
|
|
854
|
+
if isinstance(AE_func, list):
|
|
855
|
+
bases = []
|
|
856
|
+
for func in AE_func:
|
|
857
|
+
bases.append(get_basis(get_basis_bool, model, k_max, basis_batch_size, inp, end_type, device, basis_call_kwargs, self.pre_func_inp, func))
|
|
858
|
+
self.basis = bases
|
|
859
|
+
else:
|
|
860
|
+
self.basis = get_basis(get_basis_bool, model, k_max, basis_batch_size, inp, end_type, device, basis_call_kwargs, self.pre_func_inp, AE_func)
|
|
861
|
+
|
|
922
862
|
else:
|
|
923
863
|
self.basis = basis
|
|
924
864
|
|
|
925
865
|
def loss_fun(model, input, out, idx, epsilon, basis):
|
|
926
|
-
pred = model(input, epsilon=epsilon, apply_basis=basis
|
|
866
|
+
pred = model(input, epsilon=epsilon, apply_basis=basis)
|
|
927
867
|
aux = {"loss": norm_loss_(pred, out)}
|
|
928
868
|
return norm_loss_(pred, out), (aux, {})
|
|
929
869
|
|
|
@@ -935,7 +875,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
935
875
|
|
|
936
876
|
kwargs.setdefault("loss_kwargs", {}).update({"basis": self.basis})
|
|
937
877
|
|
|
938
|
-
fix_comp = lambda model: model._encode.parameters()
|
|
878
|
+
fix_comp = lambda model: AE_func(model)._encode.parameters()
|
|
939
879
|
print("Fine tuning the basis ...")
|
|
940
880
|
return super().fit(*args, fix_comp=fix_comp, training_kwargs=kwargs)
|
|
941
881
|
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
RRAEsTorch/__init__.py,sha256=f234R6usRCqIgmBmiXyZNIHa7VrDe5E-KZO0Y6Ek5AQ,33
|
|
2
|
+
RRAEsTorch/config.py,sha256=bQPwc_2KTvhglH_WIRSb5_6CpUQQj9AGpfqBp8_kuys,2931
|
|
3
|
+
RRAEsTorch/AE_base/AE_base.py,sha256=Eeo_I7p5P-357rnOmCuFxosJgmBg4KPyMA8n70sTV7U,3368
|
|
4
|
+
RRAEsTorch/AE_base/__init__.py,sha256=95YfMgEWzIFAkm--Ci-a9YPSGfCs2PDAK2sbfScT7oo,24
|
|
5
|
+
RRAEsTorch/AE_classes/AE_classes.py,sha256=iRIcA7iTTDO9Ji90ZKDUmizBswqKRIbnLcwm1hTbVnY,18096
|
|
6
|
+
RRAEsTorch/AE_classes/__init__.py,sha256=inM2_YPJG8T-lwx-CUg-zL2EMltmROQAlNZeZmnvVGA,27
|
|
7
|
+
RRAEsTorch/tests/test_AE_classes_CNN.py,sha256=nxWdxWhpoVbYxMsr-U9mmd7vZ8s0pBaLsDvCsti2r-I,2913
|
|
8
|
+
RRAEsTorch/tests/test_AE_classes_MLP.py,sha256=Aa8PhhsWblw8tJATVEfw9DLFA5vK-0lSzG6BySwmTiM,2413
|
|
9
|
+
RRAEsTorch/tests/test_fitting_CNN.py,sha256=VlWXxFMDul1MTMt1TciezYZpKifLTl1YsIFUAlAxc1w,2796
|
|
10
|
+
RRAEsTorch/tests/test_fitting_MLP.py,sha256=FVw6ObzOZJG5CntHoW45A1me5QFgIBI64rYPxjh1ujM,3288
|
|
11
|
+
RRAEsTorch/tests/test_mains.py,sha256=ivTXP7NypSlgmB9VR5g0yq5VEuPZJGOibDqBMjOxHow,1021
|
|
12
|
+
RRAEsTorch/tests/test_save.py,sha256=KdXwoF3ao7UVr8tEJ7sEHqNtUqIYmuoVEAvt4FigXIs,1982
|
|
13
|
+
RRAEsTorch/tests/test_stable_SVD.py,sha256=OimHPqw4f22qndyRzwJfNvTzzjP2CM-yHtfXCqkMBuA,1230
|
|
14
|
+
RRAEsTorch/trackers/__init__.py,sha256=3c9qcUMZiUfVr93rxFp6l11lIDthyK3PCY_-P-sNX3I,25
|
|
15
|
+
RRAEsTorch/trackers/trackers.py,sha256=Pn1ejMxMjAtvgDazFFwa3qiZhogG5GtXj4UIIFiBpuY,9127
|
|
16
|
+
RRAEsTorch/training_classes/__init__.py,sha256=K_Id4yhw640jp2JN15-0E4wJi4sPadi1fFRgovMV3kw,101
|
|
17
|
+
RRAEsTorch/training_classes/training_classes.py,sha256=tZedyKvi3f-phiByIG820Z_oMsG5vk9xQK_OS_04jIM,34014
|
|
18
|
+
RRAEsTorch/utilities/__init__.py,sha256=NtlizCcRW4qcsULXxWfjPk265rLJst0-GqWLRah2yDY,26
|
|
19
|
+
rraestorch-0.1.7.dist-info/METADATA,sha256=fdYJFLO7CHpG2IxfaC1zTFwiA6Lf2cyPqHiueU0lvs0,3028
|
|
20
|
+
rraestorch-0.1.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
21
|
+
rraestorch-0.1.7.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
|
|
22
|
+
rraestorch-0.1.7.dist-info/RECORD,,
|
|
@@ -1,56 +0,0 @@
|
|
|
1
|
-
from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
|
|
2
|
-
from torchvision.ops import MLP
|
|
3
|
-
import pytest
|
|
4
|
-
import numpy as np
|
|
5
|
-
import math
|
|
6
|
-
import numpy.random as random
|
|
7
|
-
import torch
|
|
8
|
-
|
|
9
|
-
def test_vmap_wrapper():
|
|
10
|
-
# Usually MLP only accepts a vector, here we give
|
|
11
|
-
# a tensor and vectorize over the last axis twice
|
|
12
|
-
data = random.normal(size=(50, 60, 600))
|
|
13
|
-
data = torch.tensor(data, dtype=torch.float32)
|
|
14
|
-
|
|
15
|
-
model_cls = vmap_wrap(MLP, -1, 2)
|
|
16
|
-
model = model_cls(50, [64, 100])
|
|
17
|
-
try:
|
|
18
|
-
model(data)
|
|
19
|
-
except ValueError:
|
|
20
|
-
pytest.fail("Vmap wrapper is not working properly.")
|
|
21
|
-
|
|
22
|
-
def test_norm_wrapper():
|
|
23
|
-
# Testing the keep_normalized kwarg
|
|
24
|
-
data = random.normal(size=(50,))
|
|
25
|
-
data = torch.tensor(data, dtype=torch.float32)
|
|
26
|
-
model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
|
|
27
|
-
model = model_cls(50, [64, 100])
|
|
28
|
-
try:
|
|
29
|
-
assert not torch.allclose(model(data), model(data, keep_normalized=True))
|
|
30
|
-
except AssertionError:
|
|
31
|
-
pytest.fail("The keep_normalized kwarg for norm wrapper is not behaving as expected.")
|
|
32
|
-
|
|
33
|
-
# Testing minmax with knwon mins and maxs
|
|
34
|
-
data = np.linspace(-1, 1, 100)
|
|
35
|
-
data = torch.tensor(data, dtype=torch.float32)
|
|
36
|
-
model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
|
|
37
|
-
model = model_cls(50, [64, 100])
|
|
38
|
-
try:
|
|
39
|
-
assert 0.55 == model.norm_in.default(None, 0.1)
|
|
40
|
-
assert -0.8 == model.inv_norm_out.default(None, 0.1)
|
|
41
|
-
except AssertionError:
|
|
42
|
-
pytest.fail("Something wrong with minmax wrapper.")
|
|
43
|
-
|
|
44
|
-
# Testing meanstd with knwon mean and std
|
|
45
|
-
data = random.normal(size=(50,))
|
|
46
|
-
data = (data-np.mean(data))/np.std(data)
|
|
47
|
-
data = data*2.0 + 1.0 # mean of 1 and std of 2
|
|
48
|
-
data = torch.tensor(data, dtype=torch.float32)
|
|
49
|
-
|
|
50
|
-
model_cls = norm_wrap(MLP, data, "meanstd", None, data, "meanstd", None)
|
|
51
|
-
model = model_cls(50, [64, 100])
|
|
52
|
-
try:
|
|
53
|
-
assert math.isclose(2, model.norm_in.default(None, 5), rel_tol=1e-1, abs_tol=1e-1)
|
|
54
|
-
assert math.isclose(7, model.inv_norm_out.default(None, 3), rel_tol=1e-1, abs_tol=1e-1)
|
|
55
|
-
except AssertionError:
|
|
56
|
-
pytest.fail("Something wrong with norm wrapper.")
|