RRAEsTorch 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,7 +13,6 @@ import os
13
13
  import time
14
14
  import dill
15
15
  import shutil
16
- from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
17
16
  from functools import partial
18
17
  from RRAEsTorch.trackers import (
19
18
  Null_Tracker,
@@ -26,6 +25,8 @@ from prettytable import PrettyTable
26
25
  import torch
27
26
  from torch.utils.data import TensorDataset, DataLoader
28
27
 
28
+ from RRAEsTorch.utilities import get_basis
29
+
29
30
  class Circular_list:
30
31
  """
31
32
  Creates a list of fixed size.
@@ -157,40 +158,17 @@ class Trainor_class:
157
158
  model_cls=None,
158
159
  folder="",
159
160
  file=None,
160
- out_train=None,
161
- norm_in="None",
162
- norm_out="None",
163
- methods_map=["__call__"],
164
- methods_norm_in=["__call__"],
165
- methods_norm_out=["__call__"],
166
- call_map_count=1,
167
- call_map_axis=-1,
168
161
  **kwargs,
169
162
  ):
170
163
  if model_cls is not None:
171
164
  orig_model_cls = model_cls
172
- model_cls = vmap_wrap(orig_model_cls, call_map_axis, call_map_count, methods_map)
173
- model_cls = norm_wrap(model_cls, in_train, norm_in, None, out_train, norm_out, None, methods_norm_in, methods_norm_out)
174
165
  self.model = model_cls(**kwargs)
175
- params_in = self.model.params_in
176
- params_out = self.model.params_out
177
166
  else:
178
167
  orig_model_cls = None
179
- params_in = None
180
- params_out = None
181
168
 
182
169
  self.all_kwargs = {
183
170
  "kwargs": kwargs,
184
- "params_in": params_in,
185
- "params_out": params_out,
186
- "norm_in": norm_in,
187
- "norm_out": norm_out,
188
- "call_map_axis": call_map_axis,
189
- "call_map_count": call_map_count,
190
- "orig_model_cls": orig_model_cls,
191
- "methods_map": methods_map,
192
- "methods_norm_in": methods_norm_in,
193
- "methods_norm_out": methods_norm_out,
171
+ "orig_model_cls": orig_model_cls
194
172
  }
195
173
 
196
174
  self.folder = folder
@@ -229,7 +207,8 @@ class Trainor_class:
229
207
  save_losses=False,
230
208
  input_val=None,
231
209
  output_val=None,
232
- latent_size=0
210
+ latent_size=0,
211
+ device="cpu"
233
212
  ):
234
213
  assert isinstance(input, torch.Tensor), "Input should be a torch tensor"
235
214
  assert isinstance(output, torch.Tensor), "Output should be a torch tensor"
@@ -301,7 +280,12 @@ class Trainor_class:
301
280
  all_losses = []
302
281
 
303
282
 
304
-
283
+ model = model.to(device)
284
+
285
+ if input_val is not None:
286
+ dataset_val = TensorDataset(input_val, output_val, torch.arange(0, input_val.shape[0], 1))
287
+ dataloader_val = DataLoader(dataset_val, batch_size=input_val.shape[0], shuffle=False)
288
+
305
289
  # Outler Loop
306
290
  for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
307
291
  try:
@@ -310,15 +294,13 @@ class Trainor_class:
310
294
  filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3
311
295
  )
312
296
 
313
- if (batch_size > input.shape[-1]) or batch_size == -1:
314
- print(f"Setting batch size to: {input.shape[-1]}")
315
- batch_size = input.shape[-1]
297
+ if (batch_size > input.shape[0]) or batch_size == -1:
298
+ print(f"Setting batch size to: {input.shape[0]}")
299
+ batch_size = input.shape[0]
316
300
 
317
301
  # Inner loop (batch)
318
- inputT = input.permute(*range(input.ndim - 1, -1, -1))
319
- outputT = output.permute(*range(output.ndim - 1, -1, -1))
320
302
 
321
- dataset = TensorDataset(inputT, outputT, torch.arange(0, input.shape[-1], 1))
303
+ dataset = TensorDataset(input, output, torch.arange(0, input.shape[0], 1))
322
304
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
323
305
  data_iter = iter(dataloader)
324
306
 
@@ -331,15 +313,16 @@ class Trainor_class:
331
313
  input_b, out_b, idx_b = next(data_iter)
332
314
 
333
315
  start_time = time.perf_counter() # Start time
334
- input_b = input_b.permute(*range(input_b.ndim - 1, -1, -1))
335
- out_b = self.model.norm_out.default(None, pre_func_out(out_b)) # Pre-process batch out values
336
- out_b = out_b.permute(*range(out_b.ndim - 1, -1, -1))
316
+ out_b = pre_func_out(out_b) # Pre-process batch out values
337
317
  input_b = pre_func_inp(input_b) # Pre-process batch input values
338
- epsilon = eps_fn(latent_size, input_b.shape[-1])
318
+ epsilon = eps_fn(latent_size, input_b.shape[0])
339
319
 
340
320
  step_kwargs = merge_dicts(loss_kwargs, track_params)
341
321
 
342
322
  # Compute loss
323
+ input_b = input_b.to(device)
324
+ out_b = out_b.to(device)
325
+
343
326
  loss, model, optimizer_tr, (aux, extra_track) = make_step(
344
327
  model,
345
328
  input_b,
@@ -351,12 +334,14 @@ class Trainor_class:
351
334
  )
352
335
 
353
336
  if input_val is not None:
354
-
355
-
356
- idx = np.arange(input_val.shape[-1])
357
- val_loss, _ = loss_fun(
358
- model, input_val, output_val, idx=idx, epsilon=None, **step_kwargs
359
- )
337
+ val_loss = []
338
+ for input_vb, out_vb, idx_b in dataloader_val:
339
+ out_vb = pre_func_out(out_vb)
340
+ val_loss_batch = loss_fun(
341
+ model, input_vb.to(device), out_vb.to(device), idx=idx_b, epsilon=None, **step_kwargs
342
+ )[0]
343
+ val_loss.append(val_loss_batch.item())
344
+ val_loss = sum(val_loss) / len(val_loss)
360
345
  aux["val_loss"] = val_loss
361
346
  else:
362
347
  aux["val_loss"] = None
@@ -496,18 +481,16 @@ class Trainor_class:
496
481
  hasattr(self, "batch_size") or batch_size is not None
497
482
  ), "You should either provide a batch_size or fit the model first."
498
483
 
499
- x_train_oT = x_train_o.permute(*range(x_train_o.ndim - 1, -1, -1))
500
- dataset = TensorDataset(x_train_oT)
484
+ dataset = TensorDataset(x_train_o)
501
485
  batch_size = self.batch_size if batch_size is None else batch_size
502
486
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
503
487
 
504
488
  pred = []
505
489
  for x_b in dataloader:
506
- x_bT = x_b[0].permute(*range(x_b[0].ndim - 1, -1, -1))
507
- pred_batch = call_func(x_bT)
490
+ pred_batch = call_func(x_b[0])
508
491
  pred.append(pred_batch)
509
492
 
510
- y_pred_train_o = torch.concatenate(pred, axis=-1)
493
+ y_pred_train_o = torch.concatenate(pred)
511
494
 
512
495
  self.error_train_o = (
513
496
  torch.linalg.norm(y_pred_train_o - y_train_o)
@@ -516,8 +499,8 @@ class Trainor_class:
516
499
  )
517
500
  print("Train error on original output: ", self.error_train_o)
518
501
 
519
- y_pred_train = self.model.norm_out.default(None, y_pred_train_o)
520
- y_train = self.model.norm_out.default(None, y_train_o)
502
+ y_pred_train = pre_func_out(y_pred_train_o)
503
+ y_train = pre_func_out(y_train_o)
521
504
  self.error_train = (
522
505
  torch.linalg.norm(y_pred_train - y_train) / torch.linalg.norm(y_train) * 100
523
506
  )
@@ -525,16 +508,14 @@ class Trainor_class:
525
508
 
526
509
  if x_test_o is not None:
527
510
  y_test_o = pre_func_out(y_test_o)
528
- x_test_oT = x_test_o.permute(*range(x_test_o.ndim - 1, -1, -1))
529
- dataset = TensorDataset(x_test_oT)
511
+ dataset = TensorDataset(x_test_o)
530
512
  batch_size = self.batch_size if batch_size is None else batch_size
531
513
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
532
514
  pred = []
533
515
  for x_b in dataloader:
534
- x_bT = x_b[0].permute(*range(x_b[0].ndim - 1, -1, -1))
535
- pred_batch = call_func(x_bT)
516
+ pred_batch = call_func(x_b[0])
536
517
  pred.append(pred_batch)
537
- y_pred_test_o = torch.concatenate(pred, axis=-1)
518
+ y_pred_test_o = torch.concatenate(pred)
538
519
  self.error_test_o = (
539
520
  torch.linalg.norm(y_pred_test_o - y_test_o)
540
521
  / torch.linalg.norm(y_test_o)
@@ -543,8 +524,8 @@ class Trainor_class:
543
524
 
544
525
  print("Test error on original output: ", self.error_test_o)
545
526
 
546
- y_test = self.model.norm_out.default(None, y_test_o)
547
- y_pred_test = self.model.norm_out.default(None, y_pred_test_o)
527
+ y_test = pre_func_out(y_test_o)
528
+ y_pred_test = pre_func_out(y_pred_test_o)
548
529
  self.error_test = (
549
530
  torch.linalg.norm(y_pred_test - y_test) / torch.linalg.norm(y_test) * 100
550
531
  )
@@ -618,7 +599,7 @@ class Trainor_class:
618
599
  dill.dump(obj, f)
619
600
  print(f"Object saved in {filename}")
620
601
 
621
- def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None, **fn_kwargs):
602
+ def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None, device="cpu",**fn_kwargs):
622
603
  """NOTE: fn_kwargs defines the functions of the model
623
604
  (e.g. final_activation, inner activation), if
624
605
  needed to be saved/loaded on different devices/OS.
@@ -638,29 +619,12 @@ class Trainor_class:
638
619
  else:
639
620
  orig_model_cls = orig_model_cls
640
621
  kwargs = self.all_kwargs["kwargs"]
641
- self.call_map_axis = self.all_kwargs["call_map_axis"]
642
- self.call_map_count = self.all_kwargs["call_map_count"]
643
- self.params_in = self.all_kwargs["params_in"]
644
- self.params_out = self.all_kwargs["params_out"]
645
- self.norm_in = self.all_kwargs["norm_in"]
646
- self.norm_out = self.all_kwargs["norm_out"]
647
- try:
648
- self.methods_map = self.all_kwargs["methods_map"]
649
- self.methods_norm_in = self.all_kwargs["methods_norm_in"]
650
- self.methods_norm_out = self.all_kwargs["methods_norm_out"]
651
- except:
652
- self.methods_map = ["encode", "decode"]
653
- self.methods_norm_in = ["encode"]
654
- self.methods_norm_out = ["decode"]
655
622
 
656
623
  kwargs.update(fn_kwargs)
657
-
658
- model_cls = vmap_wrap(orig_model_cls, self.call_map_axis, self.call_map_count, self.methods_map)
659
- model_cls = norm_wrap(model_cls, None, self.norm_in, self.params_in, None, self.norm_out, self.params_out, self.methods_norm_in, self.methods_norm_out)
660
624
 
661
- model = model_cls(**kwargs)
625
+ model = orig_model_cls(**kwargs)
662
626
  model.load_state_dict(save_dict["model_state_dict"])
663
- self.model = model
627
+ self.model = model.to(device)
664
628
  attributes = save_dict["attr"]
665
629
 
666
630
  for key in attributes:
@@ -671,7 +635,7 @@ class Trainor_class:
671
635
 
672
636
  class AE_Trainor_class(Trainor_class):
673
637
  def __init__(self, *args, **kwargs):
674
- super().__init__(*args, methods_map=["encode", "decode"], methods_norm_in=["encode"], methods_norm_out=["decode"], **kwargs)
638
+ super().__init__(*args, **kwargs)
675
639
 
676
640
  def fit(self, *args, training_kwargs, **kwargs):
677
641
  if "pre_func_inp" not in kwargs:
@@ -830,11 +794,11 @@ class RRAE_Trainor_class(AE_Trainor_class):
830
794
  self.batch_size = 16 # default value
831
795
 
832
796
  if ft_kwargs:
833
- if "get_basis" in ft_kwargs:
834
- get_basis = ft_kwargs["get_basis"]
835
- ft_kwargs.pop("get_basis")
797
+ if "get_basis_bool" in ft_kwargs:
798
+ get_basis_bool = ft_kwargs["get_basis_bool"]
799
+ ft_kwargs.pop("get_basis_bool")
836
800
  else:
837
- get_basis = True
801
+ get_basis_bool = True
838
802
 
839
803
  if "ft_end_type" in ft_kwargs:
840
804
  ft_end_type = ft_kwargs["ft_end_type"]
@@ -849,8 +813,16 @@ class RRAE_Trainor_class(AE_Trainor_class):
849
813
  ft_end_type = "concat"
850
814
  basis_call_kwargs = {}
851
815
 
816
+ device = ft_kwargs.get("device", "cpu")
817
+
818
+ if "AE_func" in ft_kwargs:
819
+ AE_func = ft_kwargs["AE_func"]
820
+ ft_kwargs.pop("AE_func")
821
+ else:
822
+ AE_func = lambda m:m
823
+
852
824
  ft_model, ft_track_params = self.fine_tune_basis(
853
- None, args=args, kwargs=ft_kwargs, get_basis=get_basis, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs
825
+ None, args=args, kwargs=ft_kwargs, get_basis_bool=get_basis_bool, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs, device=device, AE_func=AE_func
854
826
  ) # fine tune basis
855
827
  self.ft_track_params = ft_track_params
856
828
  else:
@@ -858,7 +830,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
858
830
  ft_track_params = {}
859
831
  return model, track_params, ft_model, ft_track_params
860
832
 
861
- def fine_tune_basis(self, basis=None, get_basis=True, end_type="concat", basis_call_kwargs={}, *, args, kwargs):
833
+ def fine_tune_basis(self, basis=None, get_basis_bool=True, end_type="concat", basis_call_kwargs={}, device="cpu", AE_func=lambda m:m, *, args, kwargs):
862
834
 
863
835
  if "loss" in kwargs:
864
836
  norm_loss_ = kwargs["loss"]
@@ -869,45 +841,29 @@ class RRAE_Trainor_class(AE_Trainor_class):
869
841
  )
870
842
 
871
843
  if (basis is None):
872
- with torch.no_grad():
873
- if get_basis:
874
- inp = args[0] if len(args) > 0 else kwargs["input"]
875
-
876
- if "basis_batch_size" in kwargs:
877
- basis_batch_size = kwargs["basis_batch_size"]
878
- kwargs.pop("basis_batch_size")
879
- else:
880
- basis_batch_size = self.batch_size
844
+ inp = args[0] if len(args) > 0 else kwargs["input"]
881
845
 
882
- basis_kwargs = basis_call_kwargs | self.track_params
883
-
884
- inpT = inp.permute(*range(inp.ndim - 1, -1, -1))
885
- dataset = TensorDataset(inpT)
886
- dataloader = DataLoader(dataset, batch_size=basis_batch_size, shuffle=False)
887
-
888
- all_bases = []
889
-
890
- for inp_b in dataloader:
891
- inp_bT = inp_b[0].permute(*range(inp_b[0].ndim - 1, -1, -1))
892
- all_bases.append(self.model.latent(
893
- self.pre_func_inp(inp_bT), get_basis_coeffs=True, **basis_kwargs
894
- )[0]
895
- )
896
- if end_type == "concat":
897
- all_bases = torch.concatenate(all_bases, axis=1)
898
- print(all_bases.shape)
899
- basis = torch.linalg.svd(all_bases, full_matrices=False)[0]
900
- self.basis = basis[:, : self.track_params["k_max"]]
901
- else:
902
- self.basis = all_bases
903
- else:
904
- bas = self.model.latent(self.pre_func_inp(inp[..., 0:1]), get_basis_coeffs=True, **self.track_params)[0]
905
- self.basis = torch.eye(bas.shape[0])
846
+ if "basis_batch_size" in kwargs:
847
+ basis_batch_size = kwargs["basis_batch_size"]
848
+ kwargs.pop("basis_batch_size")
849
+ else:
850
+ basis_batch_size = self.batch_size
851
+
852
+ model = self.model.to(device)
853
+ k_max = self.track_params["k_max"]
854
+ if isinstance(AE_func, list):
855
+ bases = []
856
+ for func in AE_func:
857
+ bases.append(get_basis(get_basis_bool, model, k_max, basis_batch_size, inp, end_type, device, basis_call_kwargs, self.pre_func_inp, func))
858
+ self.basis = bases
859
+ else:
860
+ self.basis = get_basis(get_basis_bool, model, k_max, basis_batch_size, inp, end_type, device, basis_call_kwargs, self.pre_func_inp, AE_func)
861
+
906
862
  else:
907
863
  self.basis = basis
908
864
 
909
865
  def loss_fun(model, input, out, idx, epsilon, basis):
910
- pred = model(input, epsilon=epsilon, apply_basis=basis, keep_normalized=True)
866
+ pred = model(input, epsilon=epsilon, apply_basis=basis)
911
867
  aux = {"loss": norm_loss_(pred, out)}
912
868
  return norm_loss_(pred, out), (aux, {})
913
869
 
@@ -919,7 +875,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
919
875
 
920
876
  kwargs.setdefault("loss_kwargs", {}).update({"basis": self.basis})
921
877
 
922
- fix_comp = lambda model: model._encode.parameters()
878
+ fix_comp = lambda model: AE_func(model)._encode.parameters()
923
879
  print("Fine tuning the basis ...")
924
880
  return super().fit(*args, fix_comp=fix_comp, training_kwargs=kwargs)
925
881
 
@@ -932,10 +888,10 @@ class RRAE_Trainor_class(AE_Trainor_class):
932
888
  batch_size=None,
933
889
  pre_func_inp=lambda x: x,
934
890
  pre_func_out=lambda x: x,
935
- call_func=None,
891
+ device="cpu",
936
892
  ):
937
893
 
938
- call_func = lambda x: self.model(pre_func_inp(x), apply_basis=self.basis, epsilon=None)
894
+ call_func = lambda x: self.model(pre_func_inp(x.to(device)), apply_basis=self.basis.to(device), epsilon=None).to("cpu")
939
895
  res = super().evaluate(
940
896
  x_train_o,
941
897
  y_train_o,
@@ -945,6 +901,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
945
901
  call_func=call_func,
946
902
  pre_func_inp=pre_func_inp,
947
903
  pre_func_out=pre_func_out,
904
+ device=device,
948
905
  )
949
906
  return res
950
907
 
@@ -1,11 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: RRAEsTorch
3
- Version: 0.1.5
3
+ Version: 0.1.7
4
4
  Summary: A repo for RRAEs in PyTorch.
5
5
  Author-email: Jad Mounayer <jad.mounayer@outlook.com>
6
6
  License: MIT
7
7
  License-File: LICENSE
8
- License-File: LICENSE copy
9
8
  Requires-Python: >=3.10
10
9
  Requires-Dist: dill
11
10
  Requires-Dist: jaxtyping
@@ -0,0 +1,22 @@
1
+ RRAEsTorch/__init__.py,sha256=f234R6usRCqIgmBmiXyZNIHa7VrDe5E-KZO0Y6Ek5AQ,33
2
+ RRAEsTorch/config.py,sha256=bQPwc_2KTvhglH_WIRSb5_6CpUQQj9AGpfqBp8_kuys,2931
3
+ RRAEsTorch/AE_base/AE_base.py,sha256=Eeo_I7p5P-357rnOmCuFxosJgmBg4KPyMA8n70sTV7U,3368
4
+ RRAEsTorch/AE_base/__init__.py,sha256=95YfMgEWzIFAkm--Ci-a9YPSGfCs2PDAK2sbfScT7oo,24
5
+ RRAEsTorch/AE_classes/AE_classes.py,sha256=iRIcA7iTTDO9Ji90ZKDUmizBswqKRIbnLcwm1hTbVnY,18096
6
+ RRAEsTorch/AE_classes/__init__.py,sha256=inM2_YPJG8T-lwx-CUg-zL2EMltmROQAlNZeZmnvVGA,27
7
+ RRAEsTorch/tests/test_AE_classes_CNN.py,sha256=nxWdxWhpoVbYxMsr-U9mmd7vZ8s0pBaLsDvCsti2r-I,2913
8
+ RRAEsTorch/tests/test_AE_classes_MLP.py,sha256=Aa8PhhsWblw8tJATVEfw9DLFA5vK-0lSzG6BySwmTiM,2413
9
+ RRAEsTorch/tests/test_fitting_CNN.py,sha256=VlWXxFMDul1MTMt1TciezYZpKifLTl1YsIFUAlAxc1w,2796
10
+ RRAEsTorch/tests/test_fitting_MLP.py,sha256=FVw6ObzOZJG5CntHoW45A1me5QFgIBI64rYPxjh1ujM,3288
11
+ RRAEsTorch/tests/test_mains.py,sha256=ivTXP7NypSlgmB9VR5g0yq5VEuPZJGOibDqBMjOxHow,1021
12
+ RRAEsTorch/tests/test_save.py,sha256=KdXwoF3ao7UVr8tEJ7sEHqNtUqIYmuoVEAvt4FigXIs,1982
13
+ RRAEsTorch/tests/test_stable_SVD.py,sha256=OimHPqw4f22qndyRzwJfNvTzzjP2CM-yHtfXCqkMBuA,1230
14
+ RRAEsTorch/trackers/__init__.py,sha256=3c9qcUMZiUfVr93rxFp6l11lIDthyK3PCY_-P-sNX3I,25
15
+ RRAEsTorch/trackers/trackers.py,sha256=Pn1ejMxMjAtvgDazFFwa3qiZhogG5GtXj4UIIFiBpuY,9127
16
+ RRAEsTorch/training_classes/__init__.py,sha256=K_Id4yhw640jp2JN15-0E4wJi4sPadi1fFRgovMV3kw,101
17
+ RRAEsTorch/training_classes/training_classes.py,sha256=tZedyKvi3f-phiByIG820Z_oMsG5vk9xQK_OS_04jIM,34014
18
+ RRAEsTorch/utilities/__init__.py,sha256=NtlizCcRW4qcsULXxWfjPk265rLJst0-GqWLRah2yDY,26
19
+ rraestorch-0.1.7.dist-info/METADATA,sha256=fdYJFLO7CHpG2IxfaC1zTFwiA6Lf2cyPqHiueU0lvs0,3028
20
+ rraestorch-0.1.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
21
+ rraestorch-0.1.7.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
22
+ rraestorch-0.1.7.dist-info/RECORD,,
@@ -1,56 +0,0 @@
1
- from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
2
- from torchvision.ops import MLP
3
- import pytest
4
- import numpy as np
5
- import math
6
- import numpy.random as random
7
- import torch
8
-
9
- def test_vmap_wrapper():
10
- # Usually MLP only accepts a vector, here we give
11
- # a tensor and vectorize over the last axis twice
12
- data = random.normal(size=(50, 60, 600))
13
- data = torch.tensor(data, dtype=torch.float32)
14
-
15
- model_cls = vmap_wrap(MLP, -1, 2)
16
- model = model_cls(50, [64, 100])
17
- try:
18
- model(data)
19
- except ValueError:
20
- pytest.fail("Vmap wrapper is not working properly.")
21
-
22
- def test_norm_wrapper():
23
- # Testing the keep_normalized kwarg
24
- data = random.normal(size=(50,))
25
- data = torch.tensor(data, dtype=torch.float32)
26
- model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
27
- model = model_cls(50, [64, 100])
28
- try:
29
- assert not torch.allclose(model(data), model(data, keep_normalized=True))
30
- except AssertionError:
31
- pytest.fail("The keep_normalized kwarg for norm wrapper is not behaving as expected.")
32
-
33
- # Testing minmax with knwon mins and maxs
34
- data = np.linspace(-1, 1, 100)
35
- data = torch.tensor(data, dtype=torch.float32)
36
- model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
37
- model = model_cls(50, [64, 100])
38
- try:
39
- assert 0.55 == model.norm_in.default(None, 0.1)
40
- assert -0.8 == model.inv_norm_out.default(None, 0.1)
41
- except AssertionError:
42
- pytest.fail("Something wrong with minmax wrapper.")
43
-
44
- # Testing meanstd with knwon mean and std
45
- data = random.normal(size=(50,))
46
- data = (data-np.mean(data))/np.std(data)
47
- data = data*2.0 + 1.0 # mean of 1 and std of 2
48
- data = torch.tensor(data, dtype=torch.float32)
49
-
50
- model_cls = norm_wrap(MLP, data, "meanstd", None, data, "meanstd", None)
51
- model = model_cls(50, [64, 100])
52
- try:
53
- assert math.isclose(2, model.norm_in.default(None, 5), rel_tol=1e-1, abs_tol=1e-1)
54
- assert math.isclose(7, model.inv_norm_out.default(None, 3), rel_tol=1e-1, abs_tol=1e-1)
55
- except AssertionError:
56
- pytest.fail("Something wrong with norm wrapper.")