RRAEsTorch 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,7 +13,6 @@ import os
13
13
  import time
14
14
  import dill
15
15
  import shutil
16
- from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
17
16
  from functools import partial
18
17
  from RRAEsTorch.trackers import (
19
18
  Null_Tracker,
@@ -26,6 +25,8 @@ from prettytable import PrettyTable
26
25
  import torch
27
26
  from torch.utils.data import TensorDataset, DataLoader
28
27
 
28
+ from RRAEsTorch.utilities import get_basis
29
+
29
30
  class Circular_list:
30
31
  """
31
32
  Creates a list of fixed size.
@@ -157,40 +158,17 @@ class Trainor_class:
157
158
  model_cls=None,
158
159
  folder="",
159
160
  file=None,
160
- out_train=None,
161
- norm_in="None",
162
- norm_out="None",
163
- methods_map=["__call__"],
164
- methods_norm_in=["__call__"],
165
- methods_norm_out=["__call__"],
166
- call_map_count=1,
167
- call_map_axis=-1,
168
161
  **kwargs,
169
162
  ):
170
163
  if model_cls is not None:
171
164
  orig_model_cls = model_cls
172
- model_cls = vmap_wrap(orig_model_cls, call_map_axis, call_map_count, methods_map)
173
- model_cls = norm_wrap(model_cls, in_train, norm_in, None, out_train, norm_out, None, methods_norm_in, methods_norm_out)
174
165
  self.model = model_cls(**kwargs)
175
- params_in = self.model.params_in
176
- params_out = self.model.params_out
177
166
  else:
178
167
  orig_model_cls = None
179
- params_in = None
180
- params_out = None
181
168
 
182
169
  self.all_kwargs = {
183
170
  "kwargs": kwargs,
184
- "params_in": params_in,
185
- "params_out": params_out,
186
- "norm_in": norm_in,
187
- "norm_out": norm_out,
188
- "call_map_axis": call_map_axis,
189
- "call_map_count": call_map_count,
190
- "orig_model_cls": orig_model_cls,
191
- "methods_map": methods_map,
192
- "methods_norm_in": methods_norm_in,
193
- "methods_norm_out": methods_norm_out,
171
+ "orig_model_cls": orig_model_cls
194
172
  }
195
173
 
196
174
  self.folder = folder
@@ -305,8 +283,8 @@ class Trainor_class:
305
283
  model = model.to(device)
306
284
 
307
285
  if input_val is not None:
308
- dataset_val = TensorDataset(input_val.permute(*range(input_val.ndim - 1, -1, -1)), output_val.permute(*range(output_val.ndim - 1, -1, -1)), torch.arange(0, input_val.shape[-1], 1))
309
- dataloader_val = DataLoader(dataset_val, batch_size=input_val.shape[-1], shuffle=False)
286
+ dataset_val = TensorDataset(input_val, output_val, torch.arange(0, input_val.shape[0], 1))
287
+ dataloader_val = DataLoader(dataset_val, batch_size=input_val.shape[0], shuffle=False)
310
288
 
311
289
  # Outler Loop
312
290
  for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
@@ -316,15 +294,13 @@ class Trainor_class:
316
294
  filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3
317
295
  )
318
296
 
319
- if (batch_size > input.shape[-1]) or batch_size == -1:
320
- print(f"Setting batch size to: {input.shape[-1]}")
321
- batch_size = input.shape[-1]
297
+ if (batch_size > input.shape[0]) or batch_size == -1:
298
+ print(f"Setting batch size to: {input.shape[0]}")
299
+ batch_size = input.shape[0]
322
300
 
323
301
  # Inner loop (batch)
324
- inputT = input.permute(*range(input.ndim - 1, -1, -1))
325
- outputT = output.permute(*range(output.ndim - 1, -1, -1))
326
302
 
327
- dataset = TensorDataset(inputT, outputT, torch.arange(0, input.shape[-1], 1))
303
+ dataset = TensorDataset(input, output, torch.arange(0, input.shape[0], 1))
328
304
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
329
305
  data_iter = iter(dataloader)
330
306
 
@@ -337,11 +313,9 @@ class Trainor_class:
337
313
  input_b, out_b, idx_b = next(data_iter)
338
314
 
339
315
  start_time = time.perf_counter() # Start time
340
- input_b = input_b.permute(*range(input_b.ndim - 1, -1, -1))
341
- out_b = self.model.norm_out.default(None, pre_func_out(out_b)) # Pre-process batch out values
342
- out_b = out_b.permute(*range(out_b.ndim - 1, -1, -1))
316
+ out_b = pre_func_out(out_b) # Pre-process batch out values
343
317
  input_b = pre_func_inp(input_b) # Pre-process batch input values
344
- epsilon = eps_fn(latent_size, input_b.shape[-1])
318
+ epsilon = eps_fn(latent_size, input_b.shape[0])
345
319
 
346
320
  step_kwargs = merge_dicts(loss_kwargs, track_params)
347
321
 
@@ -362,10 +336,7 @@ class Trainor_class:
362
336
  if input_val is not None:
363
337
  val_loss = []
364
338
  for input_vb, out_vb, idx_b in dataloader_val:
365
- input_vb = input_vb.permute(*range(input_vb.ndim - 1, -1, -1))
366
- out_vb = self.model.norm_out.default(None, pre_func_out(out_vb)) # Pre-process batch out values
367
- out_vb = out_vb.permute(*range(out_vb.ndim - 1, -1, -1))
368
-
339
+ out_vb = pre_func_out(out_vb)
369
340
  val_loss_batch = loss_fun(
370
341
  model, input_vb.to(device), out_vb.to(device), idx=idx_b, epsilon=None, **step_kwargs
371
342
  )[0]
@@ -510,18 +481,16 @@ class Trainor_class:
510
481
  hasattr(self, "batch_size") or batch_size is not None
511
482
  ), "You should either provide a batch_size or fit the model first."
512
483
 
513
- x_train_oT = x_train_o.permute(*range(x_train_o.ndim - 1, -1, -1))
514
- dataset = TensorDataset(x_train_oT)
484
+ dataset = TensorDataset(x_train_o)
515
485
  batch_size = self.batch_size if batch_size is None else batch_size
516
486
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
517
487
 
518
488
  pred = []
519
489
  for x_b in dataloader:
520
- x_bT = x_b[0].permute(*range(x_b[0].ndim - 1, -1, -1))
521
- pred_batch = call_func(x_bT)
490
+ pred_batch = call_func(x_b[0])
522
491
  pred.append(pred_batch)
523
492
 
524
- y_pred_train_o = torch.concatenate(pred, axis=-1)
493
+ y_pred_train_o = torch.concatenate(pred)
525
494
 
526
495
  self.error_train_o = (
527
496
  torch.linalg.norm(y_pred_train_o - y_train_o)
@@ -530,8 +499,8 @@ class Trainor_class:
530
499
  )
531
500
  print("Train error on original output: ", self.error_train_o)
532
501
 
533
- y_pred_train = self.model.norm_out.default(None, y_pred_train_o)
534
- y_train = self.model.norm_out.default(None, y_train_o)
502
+ y_pred_train = pre_func_out(y_pred_train_o)
503
+ y_train = pre_func_out(y_train_o)
535
504
  self.error_train = (
536
505
  torch.linalg.norm(y_pred_train - y_train) / torch.linalg.norm(y_train) * 100
537
506
  )
@@ -539,16 +508,14 @@ class Trainor_class:
539
508
 
540
509
  if x_test_o is not None:
541
510
  y_test_o = pre_func_out(y_test_o)
542
- x_test_oT = x_test_o.permute(*range(x_test_o.ndim - 1, -1, -1))
543
- dataset = TensorDataset(x_test_oT)
511
+ dataset = TensorDataset(x_test_o)
544
512
  batch_size = self.batch_size if batch_size is None else batch_size
545
513
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
546
514
  pred = []
547
515
  for x_b in dataloader:
548
- x_bT = x_b[0].permute(*range(x_b[0].ndim - 1, -1, -1))
549
- pred_batch = call_func(x_bT)
516
+ pred_batch = call_func(x_b[0])
550
517
  pred.append(pred_batch)
551
- y_pred_test_o = torch.concatenate(pred, axis=-1)
518
+ y_pred_test_o = torch.concatenate(pred)
552
519
  self.error_test_o = (
553
520
  torch.linalg.norm(y_pred_test_o - y_test_o)
554
521
  / torch.linalg.norm(y_test_o)
@@ -557,8 +524,8 @@ class Trainor_class:
557
524
 
558
525
  print("Test error on original output: ", self.error_test_o)
559
526
 
560
- y_test = self.model.norm_out.default(None, y_test_o)
561
- y_pred_test = self.model.norm_out.default(None, y_pred_test_o)
527
+ y_test = pre_func_out(y_test_o)
528
+ y_pred_test = pre_func_out(y_pred_test_o)
562
529
  self.error_test = (
563
530
  torch.linalg.norm(y_pred_test - y_test) / torch.linalg.norm(y_test) * 100
564
531
  )
@@ -652,27 +619,10 @@ class Trainor_class:
652
619
  else:
653
620
  orig_model_cls = orig_model_cls
654
621
  kwargs = self.all_kwargs["kwargs"]
655
- self.call_map_axis = self.all_kwargs["call_map_axis"]
656
- self.call_map_count = self.all_kwargs["call_map_count"]
657
- self.params_in = self.all_kwargs["params_in"]
658
- self.params_out = self.all_kwargs["params_out"]
659
- self.norm_in = self.all_kwargs["norm_in"]
660
- self.norm_out = self.all_kwargs["norm_out"]
661
- try:
662
- self.methods_map = self.all_kwargs["methods_map"]
663
- self.methods_norm_in = self.all_kwargs["methods_norm_in"]
664
- self.methods_norm_out = self.all_kwargs["methods_norm_out"]
665
- except:
666
- self.methods_map = ["encode", "decode"]
667
- self.methods_norm_in = ["encode"]
668
- self.methods_norm_out = ["decode"]
669
622
 
670
623
  kwargs.update(fn_kwargs)
671
-
672
- model_cls = vmap_wrap(orig_model_cls, self.call_map_axis, self.call_map_count, self.methods_map)
673
- model_cls = norm_wrap(model_cls, None, self.norm_in, self.params_in, None, self.norm_out, self.params_out, self.methods_norm_in, self.methods_norm_out)
674
624
 
675
- model = model_cls(**kwargs)
625
+ model = orig_model_cls(**kwargs)
676
626
  model.load_state_dict(save_dict["model_state_dict"])
677
627
  self.model = model.to(device)
678
628
  attributes = save_dict["attr"]
@@ -685,7 +635,7 @@ class Trainor_class:
685
635
 
686
636
  class AE_Trainor_class(Trainor_class):
687
637
  def __init__(self, *args, **kwargs):
688
- super().__init__(*args, methods_map=["encode", "decode"], methods_norm_in=["encode"], methods_norm_out=["decode"], **kwargs)
638
+ super().__init__(*args, **kwargs)
689
639
 
690
640
  def fit(self, *args, training_kwargs, **kwargs):
691
641
  if "pre_func_inp" not in kwargs:
@@ -844,11 +794,11 @@ class RRAE_Trainor_class(AE_Trainor_class):
844
794
  self.batch_size = 16 # default value
845
795
 
846
796
  if ft_kwargs:
847
- if "get_basis" in ft_kwargs:
848
- get_basis = ft_kwargs["get_basis"]
849
- ft_kwargs.pop("get_basis")
797
+ if "get_basis_bool" in ft_kwargs:
798
+ get_basis_bool = ft_kwargs["get_basis_bool"]
799
+ ft_kwargs.pop("get_basis_bool")
850
800
  else:
851
- get_basis = True
801
+ get_basis_bool = True
852
802
 
853
803
  if "ft_end_type" in ft_kwargs:
854
804
  ft_end_type = ft_kwargs["ft_end_type"]
@@ -865,8 +815,14 @@ class RRAE_Trainor_class(AE_Trainor_class):
865
815
 
866
816
  device = ft_kwargs.get("device", "cpu")
867
817
 
818
+ if "AE_func" in ft_kwargs:
819
+ AE_func = ft_kwargs["AE_func"]
820
+ ft_kwargs.pop("AE_func")
821
+ else:
822
+ AE_func = lambda m:m
823
+
868
824
  ft_model, ft_track_params = self.fine_tune_basis(
869
- None, args=args, kwargs=ft_kwargs, get_basis=get_basis, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs, device=device
825
+ None, args=args, kwargs=ft_kwargs, get_basis_bool=get_basis_bool, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs, device=device, AE_func=AE_func
870
826
  ) # fine tune basis
871
827
  self.ft_track_params = ft_track_params
872
828
  else:
@@ -874,7 +830,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
874
830
  ft_track_params = {}
875
831
  return model, track_params, ft_model, ft_track_params
876
832
 
877
- def fine_tune_basis(self, basis=None, get_basis=True, end_type="concat", basis_call_kwargs={}, device="cpu", *, args, kwargs):
833
+ def fine_tune_basis(self, basis=None, get_basis_bool=True, end_type="concat", basis_call_kwargs={}, device="cpu", AE_func=lambda m:m, *, args, kwargs):
878
834
 
879
835
  if "loss" in kwargs:
880
836
  norm_loss_ = kwargs["loss"]
@@ -885,45 +841,29 @@ class RRAE_Trainor_class(AE_Trainor_class):
885
841
  )
886
842
 
887
843
  if (basis is None):
888
- with torch.no_grad():
889
- if get_basis:
890
- inp = args[0] if len(args) > 0 else kwargs["input"]
891
-
892
- if "basis_batch_size" in kwargs:
893
- basis_batch_size = kwargs["basis_batch_size"]
894
- kwargs.pop("basis_batch_size")
895
- else:
896
- basis_batch_size = self.batch_size
844
+ inp = args[0] if len(args) > 0 else kwargs["input"]
897
845
 
898
- basis_kwargs = basis_call_kwargs | self.track_params
899
-
900
- inpT = inp.permute(*range(inp.ndim - 1, -1, -1))
901
- dataset = TensorDataset(inpT)
902
- dataloader = DataLoader(dataset, batch_size=basis_batch_size, shuffle=False)
903
- model = self.model.to(device)
904
-
905
- all_bases = []
906
-
907
- for inp_b in dataloader:
908
- inp_bT = inp_b[0].permute(*range(inp_b[0].ndim - 1, -1, -1))
909
- all_bases.append(model.latent(
910
- self.pre_func_inp(inp_bT.to(device)), get_basis_coeffs=True, **basis_kwargs
911
- )[0].to("cpu")
912
- )
913
- if end_type == "concat":
914
- all_bases = torch.concatenate(all_bases, axis=1)
915
- basis = torch.linalg.svd(all_bases, full_matrices=False)[0]
916
- self.basis = basis[:, : self.track_params["k_max"]]
917
- else:
918
- self.basis = all_bases
919
- else:
920
- bas = model.latent(self.pre_func_inp(inp[..., 0:1].to(device)), get_basis_coeffs=True, **self.track_params)[0].to("cpu")
921
- self.basis = torch.eye(bas.shape[0])
846
+ if "basis_batch_size" in kwargs:
847
+ basis_batch_size = kwargs["basis_batch_size"]
848
+ kwargs.pop("basis_batch_size")
849
+ else:
850
+ basis_batch_size = self.batch_size
851
+
852
+ model = self.model.to(device)
853
+ k_max = self.track_params["k_max"]
854
+ if isinstance(AE_func, list):
855
+ bases = []
856
+ for func in AE_func:
857
+ bases.append(get_basis(get_basis_bool, model, k_max, basis_batch_size, inp, end_type, device, basis_call_kwargs, self.pre_func_inp, func))
858
+ self.basis = bases
859
+ else:
860
+ self.basis = get_basis(get_basis_bool, model, k_max, basis_batch_size, inp, end_type, device, basis_call_kwargs, self.pre_func_inp, AE_func)
861
+
922
862
  else:
923
863
  self.basis = basis
924
864
 
925
865
  def loss_fun(model, input, out, idx, epsilon, basis):
926
- pred = model(input, epsilon=epsilon, apply_basis=basis, keep_normalized=True)
866
+ pred = model(input, epsilon=epsilon, apply_basis=basis)
927
867
  aux = {"loss": norm_loss_(pred, out)}
928
868
  return norm_loss_(pred, out), (aux, {})
929
869
 
@@ -935,7 +875,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
935
875
 
936
876
  kwargs.setdefault("loss_kwargs", {}).update({"basis": self.basis})
937
877
 
938
- fix_comp = lambda model: model._encode.parameters()
878
+ fix_comp = lambda model: AE_func(model)._encode.parameters()
939
879
  print("Fine tuning the basis ...")
940
880
  return super().fit(*args, fix_comp=fix_comp, training_kwargs=kwargs)
941
881
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: RRAEsTorch
3
- Version: 0.1.6
3
+ Version: 0.1.7
4
4
  Summary: A repo for RRAEs in PyTorch.
5
5
  Author-email: Jad Mounayer <jad.mounayer@outlook.com>
6
6
  License: MIT
@@ -0,0 +1,22 @@
1
+ RRAEsTorch/__init__.py,sha256=f234R6usRCqIgmBmiXyZNIHa7VrDe5E-KZO0Y6Ek5AQ,33
2
+ RRAEsTorch/config.py,sha256=bQPwc_2KTvhglH_WIRSb5_6CpUQQj9AGpfqBp8_kuys,2931
3
+ RRAEsTorch/AE_base/AE_base.py,sha256=Eeo_I7p5P-357rnOmCuFxosJgmBg4KPyMA8n70sTV7U,3368
4
+ RRAEsTorch/AE_base/__init__.py,sha256=95YfMgEWzIFAkm--Ci-a9YPSGfCs2PDAK2sbfScT7oo,24
5
+ RRAEsTorch/AE_classes/AE_classes.py,sha256=iRIcA7iTTDO9Ji90ZKDUmizBswqKRIbnLcwm1hTbVnY,18096
6
+ RRAEsTorch/AE_classes/__init__.py,sha256=inM2_YPJG8T-lwx-CUg-zL2EMltmROQAlNZeZmnvVGA,27
7
+ RRAEsTorch/tests/test_AE_classes_CNN.py,sha256=nxWdxWhpoVbYxMsr-U9mmd7vZ8s0pBaLsDvCsti2r-I,2913
8
+ RRAEsTorch/tests/test_AE_classes_MLP.py,sha256=Aa8PhhsWblw8tJATVEfw9DLFA5vK-0lSzG6BySwmTiM,2413
9
+ RRAEsTorch/tests/test_fitting_CNN.py,sha256=VlWXxFMDul1MTMt1TciezYZpKifLTl1YsIFUAlAxc1w,2796
10
+ RRAEsTorch/tests/test_fitting_MLP.py,sha256=FVw6ObzOZJG5CntHoW45A1me5QFgIBI64rYPxjh1ujM,3288
11
+ RRAEsTorch/tests/test_mains.py,sha256=ivTXP7NypSlgmB9VR5g0yq5VEuPZJGOibDqBMjOxHow,1021
12
+ RRAEsTorch/tests/test_save.py,sha256=KdXwoF3ao7UVr8tEJ7sEHqNtUqIYmuoVEAvt4FigXIs,1982
13
+ RRAEsTorch/tests/test_stable_SVD.py,sha256=OimHPqw4f22qndyRzwJfNvTzzjP2CM-yHtfXCqkMBuA,1230
14
+ RRAEsTorch/trackers/__init__.py,sha256=3c9qcUMZiUfVr93rxFp6l11lIDthyK3PCY_-P-sNX3I,25
15
+ RRAEsTorch/trackers/trackers.py,sha256=Pn1ejMxMjAtvgDazFFwa3qiZhogG5GtXj4UIIFiBpuY,9127
16
+ RRAEsTorch/training_classes/__init__.py,sha256=K_Id4yhw640jp2JN15-0E4wJi4sPadi1fFRgovMV3kw,101
17
+ RRAEsTorch/training_classes/training_classes.py,sha256=tZedyKvi3f-phiByIG820Z_oMsG5vk9xQK_OS_04jIM,34014
18
+ RRAEsTorch/utilities/__init__.py,sha256=NtlizCcRW4qcsULXxWfjPk265rLJst0-GqWLRah2yDY,26
19
+ rraestorch-0.1.7.dist-info/METADATA,sha256=fdYJFLO7CHpG2IxfaC1zTFwiA6Lf2cyPqHiueU0lvs0,3028
20
+ rraestorch-0.1.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
21
+ rraestorch-0.1.7.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
22
+ rraestorch-0.1.7.dist-info/RECORD,,
@@ -1,56 +0,0 @@
1
- from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
2
- from torchvision.ops import MLP
3
- import pytest
4
- import numpy as np
5
- import math
6
- import numpy.random as random
7
- import torch
8
-
9
- def test_vmap_wrapper():
10
- # Usually MLP only accepts a vector, here we give
11
- # a tensor and vectorize over the last axis twice
12
- data = random.normal(size=(50, 60, 600))
13
- data = torch.tensor(data, dtype=torch.float32)
14
-
15
- model_cls = vmap_wrap(MLP, -1, 2)
16
- model = model_cls(50, [64, 100])
17
- try:
18
- model(data)
19
- except ValueError:
20
- pytest.fail("Vmap wrapper is not working properly.")
21
-
22
- def test_norm_wrapper():
23
- # Testing the keep_normalized kwarg
24
- data = random.normal(size=(50,))
25
- data = torch.tensor(data, dtype=torch.float32)
26
- model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
27
- model = model_cls(50, [64, 100])
28
- try:
29
- assert not torch.allclose(model(data), model(data, keep_normalized=True))
30
- except AssertionError:
31
- pytest.fail("The keep_normalized kwarg for norm wrapper is not behaving as expected.")
32
-
33
- # Testing minmax with knwon mins and maxs
34
- data = np.linspace(-1, 1, 100)
35
- data = torch.tensor(data, dtype=torch.float32)
36
- model_cls = norm_wrap(MLP, data, "minmax", None, data, "minmax", None)
37
- model = model_cls(50, [64, 100])
38
- try:
39
- assert 0.55 == model.norm_in.default(None, 0.1)
40
- assert -0.8 == model.inv_norm_out.default(None, 0.1)
41
- except AssertionError:
42
- pytest.fail("Something wrong with minmax wrapper.")
43
-
44
- # Testing meanstd with knwon mean and std
45
- data = random.normal(size=(50,))
46
- data = (data-np.mean(data))/np.std(data)
47
- data = data*2.0 + 1.0 # mean of 1 and std of 2
48
- data = torch.tensor(data, dtype=torch.float32)
49
-
50
- model_cls = norm_wrap(MLP, data, "meanstd", None, data, "meanstd", None)
51
- model = model_cls(50, [64, 100])
52
- try:
53
- assert math.isclose(2, model.norm_in.default(None, 5), rel_tol=1e-1, abs_tol=1e-1)
54
- assert math.isclose(7, model.inv_norm_out.default(None, 3), rel_tol=1e-1, abs_tol=1e-1)
55
- except AssertionError:
56
- pytest.fail("Something wrong with norm wrapper.")