RRAEsTorch 0.1.5__tar.gz → 0.1.6__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (41)
  1. {rraestorch-0.1.5 → rraestorch-0.1.6}/PKG-INFO +1 -2
  2. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/AE_classes/AE_classes.py +7 -5
  3. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/training_classes/training_classes.py +35 -18
  4. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/utilities/utilities.py +0 -1
  5. {rraestorch-0.1.5 → rraestorch-0.1.6}/main-CNN.py +6 -2
  6. {rraestorch-0.1.5 → rraestorch-0.1.6}/main-CNN1D.py +7 -3
  7. {rraestorch-0.1.5 → rraestorch-0.1.6}/main-CNN3D.py +9 -1
  8. {rraestorch-0.1.5 → rraestorch-0.1.6}/main-MLP.py +7 -2
  9. {rraestorch-0.1.5 → rraestorch-0.1.6}/main-adap-CNN.py +7 -3
  10. {rraestorch-0.1.5 → rraestorch-0.1.6}/main-adap-MLP.py +7 -3
  11. {rraestorch-0.1.5 → rraestorch-0.1.6}/main-var-CNN.py +9 -4
  12. {rraestorch-0.1.5 → rraestorch-0.1.6}/main-var-CNN1D.py +9 -4
  13. {rraestorch-0.1.5 → rraestorch-0.1.6}/pyproject.toml +1 -1
  14. rraestorch-0.1.5/.gitignore copy +0 -207
  15. rraestorch-0.1.5/LICENSE copy +0 -21
  16. {rraestorch-0.1.5 → rraestorch-0.1.6}/.github/workflows/python-app.yml +0 -0
  17. {rraestorch-0.1.5 → rraestorch-0.1.6}/.gitignore +0 -0
  18. {rraestorch-0.1.5 → rraestorch-0.1.6}/LICENSE +0 -0
  19. {rraestorch-0.1.5 → rraestorch-0.1.6}/README copy.md +0 -0
  20. {rraestorch-0.1.5 → rraestorch-0.1.6}/README.md +0 -0
  21. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/AE_base/AE_base.py +0 -0
  22. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/AE_base/__init__.py +0 -0
  23. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/AE_classes/__init__.py +0 -0
  24. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/__init__.py +0 -0
  25. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/config.py +0 -0
  26. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/tests/test_AE_classes_CNN.py +0 -0
  27. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/tests/test_AE_classes_MLP.py +0 -0
  28. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/tests/test_fitting_CNN.py +0 -0
  29. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/tests/test_fitting_MLP.py +0 -0
  30. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/tests/test_mains.py +0 -0
  31. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/tests/test_save.py +0 -0
  32. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/tests/test_stable_SVD.py +0 -0
  33. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/tests/test_wrappers.py +0 -0
  34. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/trackers/__init__.py +0 -0
  35. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/trackers/trackers.py +0 -0
  36. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/training_classes/__init__.py +0 -0
  37. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/utilities/__init__.py +0 -0
  38. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/wrappers/__init__.py +0 -0
  39. {rraestorch-0.1.5 → rraestorch-0.1.6}/RRAEsTorch/wrappers/wrappers.py +0 -0
  40. {rraestorch-0.1.5 → rraestorch-0.1.6}/general-MLP.py +0 -0
  41. {rraestorch-0.1.5 → rraestorch-0.1.6}/setup.cfg +0 -0
@@ -1,11 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: RRAEsTorch
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: A repo for RRAEs in PyTorch.
5
5
  Author-email: Jad Mounayer <jad.mounayer@outlook.com>
6
6
  License: MIT
7
7
  License-File: LICENSE
8
- License-File: LICENSE copy
9
8
  Requires-Python: >=3.10
10
9
  Requires-Dist: dill
11
10
  Requires-Dist: jaxtyping
@@ -127,7 +127,7 @@ def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=F
127
127
  if epsilon is not None:
128
128
  if len(epsilon.shape) == 4:
129
129
  epsilon = epsilon[0, 0] # to allow tpu sharding
130
- z = mean + torch.tensor(epsilon, dtype=torch.float32) * std
130
+ z = mean + epsilon * std
131
131
  else:
132
132
  z = mean
133
133
 
@@ -612,10 +612,6 @@ class VRRAE_CNN1D(CNN1D_Autoencoder):
612
612
  typ: int
613
613
 
614
614
  def __init__(self, channels, input_dim, latent_size, k_max, typ="eye", *, count=1, **kwargs):
615
- v_Linear = vmap_wrap(Linear, -1, count=count)
616
- self.lin_mean = v_Linear(k_max, k_max,)
617
- self.lin_logvar = v_Linear(k_max, k_max)
618
- self.typ = typ
619
615
  super().__init__(
620
616
  channels,
621
617
  input_dim,
@@ -623,6 +619,12 @@ class VRRAE_CNN1D(CNN1D_Autoencoder):
623
619
  count=count,
624
620
  **kwargs,
625
621
  )
622
+
623
+ v_Linear = vmap_wrap(Linear, -1, count=count)
624
+ self.lin_mean = v_Linear(k_max, k_max,)
625
+ self.lin_logvar = v_Linear(k_max, k_max)
626
+ self.typ = typ
627
+
626
628
 
627
629
  def _perform_in_latent(self, y, *args, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
628
630
  return latent_func_var_strong_RRAE(self, y, k_max, epsilon, return_dist, return_lat_dist, **kwargs)
@@ -229,7 +229,8 @@ class Trainor_class:
229
229
  save_losses=False,
230
230
  input_val=None,
231
231
  output_val=None,
232
- latent_size=0
232
+ latent_size=0,
233
+ device="cpu"
233
234
  ):
234
235
  assert isinstance(input, torch.Tensor), "Input should be a torch tensor"
235
236
  assert isinstance(output, torch.Tensor), "Output should be a torch tensor"
@@ -301,7 +302,12 @@ class Trainor_class:
301
302
  all_losses = []
302
303
 
303
304
 
304
-
305
+ model = model.to(device)
306
+
307
+ if input_val is not None:
308
+ dataset_val = TensorDataset(input_val.permute(*range(input_val.ndim - 1, -1, -1)), output_val.permute(*range(output_val.ndim - 1, -1, -1)), torch.arange(0, input_val.shape[-1], 1))
309
+ dataloader_val = DataLoader(dataset_val, batch_size=input_val.shape[-1], shuffle=False)
310
+
305
311
  # Outler Loop
306
312
  for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
307
313
  try:
@@ -340,6 +346,9 @@ class Trainor_class:
340
346
  step_kwargs = merge_dicts(loss_kwargs, track_params)
341
347
 
342
348
  # Compute loss
349
+ input_b = input_b.to(device)
350
+ out_b = out_b.to(device)
351
+
343
352
  loss, model, optimizer_tr, (aux, extra_track) = make_step(
344
353
  model,
345
354
  input_b,
@@ -351,12 +360,17 @@ class Trainor_class:
351
360
  )
352
361
 
353
362
  if input_val is not None:
363
+ val_loss = []
364
+ for input_vb, out_vb, idx_b in dataloader_val:
365
+ input_vb = input_vb.permute(*range(input_vb.ndim - 1, -1, -1))
366
+ out_vb = self.model.norm_out.default(None, pre_func_out(out_vb)) # Pre-process batch out values
367
+ out_vb = out_vb.permute(*range(out_vb.ndim - 1, -1, -1))
354
368
 
355
-
356
- idx = np.arange(input_val.shape[-1])
357
- val_loss, _ = loss_fun(
358
- model, input_val, output_val, idx=idx, epsilon=None, **step_kwargs
359
- )
369
+ val_loss_batch = loss_fun(
370
+ model, input_vb.to(device), out_vb.to(device), idx=idx_b, epsilon=None, **step_kwargs
371
+ )[0]
372
+ val_loss.append(val_loss_batch.item())
373
+ val_loss = sum(val_loss) / len(val_loss)
360
374
  aux["val_loss"] = val_loss
361
375
  else:
362
376
  aux["val_loss"] = None
@@ -618,7 +632,7 @@ class Trainor_class:
618
632
  dill.dump(obj, f)
619
633
  print(f"Object saved in {filename}")
620
634
 
621
- def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None, **fn_kwargs):
635
+ def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None, device="cpu",**fn_kwargs):
622
636
  """NOTE: fn_kwargs defines the functions of the model
623
637
  (e.g. final_activation, inner activation), if
624
638
  needed to be saved/loaded on different devices/OS.
@@ -660,7 +674,7 @@ class Trainor_class:
660
674
 
661
675
  model = model_cls(**kwargs)
662
676
  model.load_state_dict(save_dict["model_state_dict"])
663
- self.model = model
677
+ self.model = model.to(device)
664
678
  attributes = save_dict["attr"]
665
679
 
666
680
  for key in attributes:
@@ -849,8 +863,10 @@ class RRAE_Trainor_class(AE_Trainor_class):
849
863
  ft_end_type = "concat"
850
864
  basis_call_kwargs = {}
851
865
 
866
+ device = ft_kwargs.get("device", "cpu")
867
+
852
868
  ft_model, ft_track_params = self.fine_tune_basis(
853
- None, args=args, kwargs=ft_kwargs, get_basis=get_basis, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs
869
+ None, args=args, kwargs=ft_kwargs, get_basis=get_basis, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs, device=device
854
870
  ) # fine tune basis
855
871
  self.ft_track_params = ft_track_params
856
872
  else:
@@ -858,7 +874,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
858
874
  ft_track_params = {}
859
875
  return model, track_params, ft_model, ft_track_params
860
876
 
861
- def fine_tune_basis(self, basis=None, get_basis=True, end_type="concat", basis_call_kwargs={}, *, args, kwargs):
877
+ def fine_tune_basis(self, basis=None, get_basis=True, end_type="concat", basis_call_kwargs={}, device="cpu", *, args, kwargs):
862
878
 
863
879
  if "loss" in kwargs:
864
880
  norm_loss_ = kwargs["loss"]
@@ -884,24 +900,24 @@ class RRAE_Trainor_class(AE_Trainor_class):
884
900
  inpT = inp.permute(*range(inp.ndim - 1, -1, -1))
885
901
  dataset = TensorDataset(inpT)
886
902
  dataloader = DataLoader(dataset, batch_size=basis_batch_size, shuffle=False)
903
+ model = self.model.to(device)
887
904
 
888
905
  all_bases = []
889
906
 
890
907
  for inp_b in dataloader:
891
908
  inp_bT = inp_b[0].permute(*range(inp_b[0].ndim - 1, -1, -1))
892
- all_bases.append(self.model.latent(
893
- self.pre_func_inp(inp_bT), get_basis_coeffs=True, **basis_kwargs
894
- )[0]
909
+ all_bases.append(model.latent(
910
+ self.pre_func_inp(inp_bT.to(device)), get_basis_coeffs=True, **basis_kwargs
911
+ )[0].to("cpu")
895
912
  )
896
913
  if end_type == "concat":
897
914
  all_bases = torch.concatenate(all_bases, axis=1)
898
- print(all_bases.shape)
899
915
  basis = torch.linalg.svd(all_bases, full_matrices=False)[0]
900
916
  self.basis = basis[:, : self.track_params["k_max"]]
901
917
  else:
902
918
  self.basis = all_bases
903
919
  else:
904
- bas = self.model.latent(self.pre_func_inp(inp[..., 0:1]), get_basis_coeffs=True, **self.track_params)[0]
920
+ bas = model.latent(self.pre_func_inp(inp[..., 0:1].to(device)), get_basis_coeffs=True, **self.track_params)[0].to("cpu")
905
921
  self.basis = torch.eye(bas.shape[0])
906
922
  else:
907
923
  self.basis = basis
@@ -932,10 +948,10 @@ class RRAE_Trainor_class(AE_Trainor_class):
932
948
  batch_size=None,
933
949
  pre_func_inp=lambda x: x,
934
950
  pre_func_out=lambda x: x,
935
- call_func=None,
951
+ device="cpu",
936
952
  ):
937
953
 
938
- call_func = lambda x: self.model(pre_func_inp(x), apply_basis=self.basis, epsilon=None)
954
+ call_func = lambda x: self.model(pre_func_inp(x.to(device)), apply_basis=self.basis.to(device), epsilon=None).to("cpu")
939
955
  res = super().evaluate(
940
956
  x_train_o,
941
957
  y_train_o,
@@ -945,6 +961,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
945
961
  call_func=call_func,
946
962
  pre_func_inp=pre_func_inp,
947
963
  pre_func_out=pre_func_out,
964
+ device=device,
948
965
  )
949
966
  return res
950
967
 
@@ -1510,7 +1510,6 @@ class MLP_with_CNN3D_trans(torch.nn.Module):
1510
1510
  self.out_after_mlp = out_after_mlp
1511
1511
 
1512
1512
  def forward(self, x, *args, **kwargs):
1513
- print(x.shape)
1514
1513
  x = self.layers[0](x)
1515
1514
  x = torch.reshape(x, (self.out_after_mlp, self.first_D0, self.first_D1, self.first_D2))
1516
1515
  x = self.layers[1](x)
@@ -68,6 +68,8 @@ if __name__ == "__main__":
68
68
  # you need to specify training kw arguments (first stage of training with SVD to
69
69
  # find the basis), and fine-tuning kw arguments (second stage of training with the
70
70
  # basis found in the first stage).
71
+ device = "cuda" if torch.cuda.is_available() else "cpu"
72
+
71
73
  training_kwargs = {
72
74
  "step_st": [2,], # Increase those to train well
73
75
  "batch_size_st": [20, 64],
@@ -75,6 +77,7 @@ if __name__ == "__main__":
75
77
  "print_every": 1,
76
78
  # "save_every": 789,
77
79
  "loss_type": loss_type,
80
+ "device": device,
78
81
  "save_losses": True # if you want to save losses to plot them later
79
82
  }
80
83
 
@@ -83,6 +86,7 @@ if __name__ == "__main__":
83
86
  "batch_size_st": [20],
84
87
  "lr_st": [1e-4, 1e-6, 1e-7, 1e-8],
85
88
  "print_every": 100,
89
+ "device": device,
86
90
  # "save_every": 50,
87
91
  }
88
92
 
@@ -103,12 +107,12 @@ if __name__ == "__main__":
103
107
  # trainor.plot_training_losses(idx=0) # to plot both training and validation losses
104
108
 
105
109
  preds = trainor.evaluate(
106
- x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out
110
+ x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out, device
107
111
  )
108
112
  # NOTE: preds are not saved so uncomment last line if you want to save/plot etc.
109
113
 
110
114
  trainor.save_model(kwargs=kwargs)
111
-
115
+
112
116
  # Uncomment the following line if you want to hold the session to check your
113
117
  # results in the console.
114
118
  # pdb.set_trace()
@@ -51,8 +51,8 @@ if __name__ == "__main__":
51
51
  k_max=k_max,
52
52
  folder=f"{problem}",
53
53
  file=f"{method}_{problem}_test.pkl",
54
- norm_in="minmax",
55
- norm_out="minmax",
54
+ norm_in="None",
55
+ norm_out="None",
56
56
  out_train=x_train,
57
57
  kwargs_enc={
58
58
  "width_CNNs": [32, 64, 128],
@@ -73,6 +73,8 @@ if __name__ == "__main__":
73
73
  # you need to specify training kw arguments (first stage of training with SVD to
74
74
  # find the basis), and fine-tuning kw arguments (second stage of training with the
75
75
  # basis found in the first stage).
76
+ device = "cuda" if torch.cuda.is_available() else "cpu"
77
+
76
78
  training_kwargs = {
77
79
  "step_st": [2], # Increase those to train well
78
80
  "batch_size_st": [64, 64],
@@ -80,6 +82,7 @@ if __name__ == "__main__":
80
82
  "print_every": 1,
81
83
  # "save_every": 789,
82
84
  "loss_type": loss_type,
85
+ "device": device,
83
86
  }
84
87
 
85
88
  ft_kwargs = {
@@ -87,6 +90,7 @@ if __name__ == "__main__":
87
90
  "batch_size_st": [64],
88
91
  "lr_st": [1e-4, 1e-6, 1e-7, 1e-8],
89
92
  "print_every": 100,
93
+ "device": device,
90
94
  # "save_every": 50,
91
95
  }
92
96
 
@@ -100,7 +104,7 @@ if __name__ == "__main__":
100
104
  pre_func_out=pre_func_out,
101
105
  )
102
106
  preds = trainor.evaluate(
103
- x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out
107
+ x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out, device
104
108
  )
105
109
  # NOTE: preds are not saved so uncomment last line if you want to save/plot etc.
106
110
 
@@ -3,6 +3,7 @@ import RRAEsTorch.config
3
3
  from RRAEsTorch.AE_classes import *
4
4
  from RRAEsTorch.training_classes import RRAE_Trainor_class
5
5
  import numpy.random as random
6
+ import torch
6
7
 
7
8
  if __name__ == "__main__":
8
9
 
@@ -67,6 +68,8 @@ if __name__ == "__main__":
67
68
  # you need to specify training kw arguments (first stage of training with SVD to
68
69
  # find the basis), and fine-tuning kw arguments (second stage of training with the
69
70
  # basis found in the first stage).
71
+ device = "cuda" if torch.cuda.is_available() else "cpu"
72
+
70
73
  training_kwargs = {
71
74
  "step_st": [2,], # Increase those to train well
72
75
  "batch_size_st": [64, 64],
@@ -74,6 +77,7 @@ if __name__ == "__main__":
74
77
  "print_every": 1,
75
78
  # "save_every": 789,
76
79
  "loss_type": loss_type,
80
+ "device": device,
77
81
  }
78
82
 
79
83
  ft_kwargs = {
@@ -81,6 +85,7 @@ if __name__ == "__main__":
81
85
  "batch_size_st": [20],
82
86
  "lr_st": [1e-4, 1e-6, 1e-7, 1e-8],
83
87
  "print_every": 100,
88
+ "device": device,
84
89
  # "save_every": 50,
85
90
  }
86
91
 
@@ -93,7 +98,10 @@ if __name__ == "__main__":
93
98
  #pre_func_inp=pre_func_inp,
94
99
  #pre_func_out=pre_func_out,
95
100
  )
96
-
101
+
102
+ preds = trainor.evaluate(
103
+ x_train, y_train, x_test, y_test, None, device=device
104
+ )
97
105
  # NOTE: preds are not saved so uncomment last line if you want to save/plot etc.
98
106
 
99
107
  #trainor.save_model(kwargs=kwargs)
@@ -3,6 +3,7 @@ import RRAEsTorch.config # Include this in all your scripts
3
3
  from RRAEsTorch.AE_classes import *
4
4
  from RRAEsTorch.training_classes import RRAE_Trainor_class # , Trainor_class
5
5
  from RRAEsTorch.utilities import get_data
6
+ import torch
6
7
 
7
8
  if __name__ == "__main__":
8
9
  # Step 1: Get the data - replace this with your own data of the same shape.
@@ -61,13 +62,16 @@ if __name__ == "__main__":
61
62
  # you need to specify training kw arguments (first stage of training with SVD to
62
63
  # find the basis), and fine-tuning kw arguments (second stage of training with the
63
64
  # basis found in the first stage).
65
+ device = "cuda" if torch.cuda.is_available() else "cpu"
66
+
64
67
  training_kwargs = {
65
68
  "step_st": [2], # Increase this to train well (e.g. 2000)
66
69
  "batch_size_st": [64, 64, 64, 64, 64],
67
70
  "lr_st": [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
68
71
  "print_every": 1,
69
72
  "loss_type": loss_type,
70
- "save_losses":True
73
+ "save_losses":True,
74
+ "device": device,
71
75
  }
72
76
 
73
77
  ft_kwargs = {
@@ -75,6 +79,7 @@ if __name__ == "__main__":
75
79
  "batch_size_st": [64],
76
80
  "lr_st": [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
77
81
  "print_every": 1,
82
+ "device": device,
78
83
  }
79
84
 
80
85
  # Step 6: Train the model and get the predictions.
@@ -97,7 +102,7 @@ if __name__ == "__main__":
97
102
 
98
103
 
99
104
  preds = trainor.evaluate(
100
- x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out
105
+ x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out, device
101
106
  )
102
107
 
103
108
 
@@ -7,7 +7,7 @@ from RRAEsTorch.AE_classes import RRAE_CNN
7
7
  from RRAEsTorch.training_classes import RRAE_Trainor_class
8
8
  from RRAEsTorch.trackers import RRAE_gen_Tracker, RRAE_fixed_Tracker, RRAE_pars_Tracker
9
9
  from RRAEsTorch.utilities import get_data
10
-
10
+ import torch
11
11
 
12
12
  if __name__ == "__main__":
13
13
  # Step 1: Get the data - replace this with your own data of the same shape.
@@ -69,13 +69,16 @@ if __name__ == "__main__":
69
69
  },
70
70
  )
71
71
 
72
+ device = "cuda" if torch.cuda.is_available() else "cpu"
73
+
72
74
  training_kwargs = {
73
75
  "step_st": [2,], # increase for a very big value
74
76
  "batch_size_st": [64,],
75
77
  "lr_st": [1e-3, 1e-4, 1e-7, 1e-8],
76
78
  "print_every": 1,
77
79
  "loss_type": loss_type,
78
- "tracker": RRAE_gen_Tracker(k_init=k_max, patience_init=50)
80
+ "tracker": RRAE_gen_Tracker(k_init=k_max, patience_init=50),
81
+ "device": device,
79
82
  }
80
83
 
81
84
  # The tracker above will specify the adaptive scheme to be used. Gen means generic and it
@@ -89,6 +92,7 @@ if __name__ == "__main__":
89
92
  "batch_size_st": [64],
90
93
  "lr_st": [1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
91
94
  "print_every": 1,
95
+ "device": device,
92
96
  }
93
97
 
94
98
  trainor.fit(
@@ -101,7 +105,7 @@ if __name__ == "__main__":
101
105
  )
102
106
 
103
107
  preds = trainor.evaluate(
104
- x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out
108
+ x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out, device
105
109
  )
106
110
 
107
111
  # Uncomment the following line if you want to hold the session to check your
@@ -7,7 +7,7 @@ from RRAEsTorch.AE_classes import RRAE_MLP
7
7
  from RRAEsTorch.training_classes import RRAE_Trainor_class, Trainor_class
8
8
  from RRAEsTorch.trackers import RRAE_gen_Tracker, RRAE_fixed_Tracker, RRAE_pars_Tracker
9
9
  from RRAEsTorch.utilities import get_data
10
-
10
+ import torch
11
11
 
12
12
  if __name__ == "__main__":
13
13
  # Step 1: Get the data - replace this with your own data of the same shape.
@@ -68,13 +68,16 @@ if __name__ == "__main__":
68
68
  # you need to specify training kw arguments (first stage of training with SVD to
69
69
  # find the basis), and fine-tuning kw arguments (second stage of training with the
70
70
  # basis found in the first stage).
71
+ device = "cuda" if torch.cuda.is_available() else "cpu"
72
+
71
73
  training_kwargs = {
72
74
  "step_st": [2], # Increase those to train better
73
75
  "batch_size_st": [64, 64, 64, 64, 64],
74
76
  "lr_st": [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
75
77
  "print_every": 1,
76
78
  "loss_type": loss_type,
77
- "tracker": RRAE_gen_Tracker(k_init=k_max, patience_init=200)
79
+ "tracker": RRAE_gen_Tracker(k_init=k_max, patience_init=200),
80
+ "device": device,
78
81
  }
79
82
 
80
83
  # The tracker above will specify the adaptive scheme to be used. Gen means generic and it
@@ -88,6 +91,7 @@ if __name__ == "__main__":
88
91
  "batch_size_st": [64],
89
92
  "lr_st": [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
90
93
  "print_every": 1,
94
+ "device": device,
91
95
  }
92
96
 
93
97
  # Step 6: Train the model and get the predictions.
@@ -101,7 +105,7 @@ if __name__ == "__main__":
101
105
  )
102
106
 
103
107
  preds = trainor.evaluate(
104
- x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out
108
+ x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out, device
105
109
  )
106
110
 
107
111
  # Uncomment the following line if you want to hold the session to check your
@@ -4,6 +4,7 @@ from RRAEsTorch.AE_classes import *
4
4
  from RRAEsTorch.training_classes import RRAE_Trainor_class, Trainor_class # , Trainor_class
5
5
  from RRAEsTorch.utilities import get_data
6
6
  import numpy as np
7
+ import torch
7
8
 
8
9
 
9
10
  if __name__ == "__main__":
@@ -43,9 +44,9 @@ if __name__ == "__main__":
43
44
 
44
45
  match method:
45
46
  case "VRRAE":
46
- eps_fn = lambda lat, bs: np.random.normal(0, 1, size=(1, 1, k_max, bs))
47
+ eps_fn = lambda lat, bs: torch.tensor(np.random.normal(0, 1, size=(1, 1, k_max, bs)), dtype=torch.float32, device=device)
47
48
  case "VAE":
48
- eps_fn = lambda lat, bs: np.random.normal(size=(1, 1, lat, bs))
49
+ eps_fn = lambda lat, bs: torch.tensor(np.random.normal(size=(1, 1, lat, bs)), dtype=torch.float32, device=device)
49
50
 
50
51
  # Step 3: Specify the archietectures' parameters:
51
52
  latent_size = 200 # latent space dimension
@@ -86,6 +87,8 @@ if __name__ == "__main__":
86
87
  # you need to specify training kw arguments (first stage of training with SVD to
87
88
  # find the basis), and fine-tuning kw arguments (second stage of training with the
88
89
  # basis found in the first stage).
90
+ device = "cuda" if torch.cuda.is_available() else "cpu"
91
+
89
92
  training_kwargs = {
90
93
  "step_st": [2], # 7680*data_size/64
91
94
  "batch_size_st": [64],
@@ -94,6 +97,7 @@ if __name__ == "__main__":
94
97
  "loss_type": loss_type,
95
98
  "loss_kwargs": {"beta": 0.001},
96
99
  "eps_fn": eps_fn,
100
+ "device": device,
97
101
  }
98
102
 
99
103
 
@@ -102,7 +106,8 @@ if __name__ == "__main__":
102
106
  "batch_size_st": [64],
103
107
  "lr_st": [1e-4, 1e-6, 1e-7, 1e-8],
104
108
  "print_every": 1,
105
- "eps_fn": eps_fn
109
+ "eps_fn": eps_fn,
110
+ "device": device,
106
111
  }
107
112
 
108
113
 
@@ -122,7 +127,7 @@ if __name__ == "__main__":
122
127
  trainor.save_model()
123
128
 
124
129
  preds = trainor.evaluate(
125
- x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out
130
+ x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out, device
126
131
  )
127
132
 
128
133
  # pdb.set_trace()
@@ -4,6 +4,7 @@ from RRAEsTorch.AE_classes import *
4
4
  from RRAEsTorch.training_classes import RRAE_Trainor_class, Trainor_class # , Trainor_class
5
5
  from RRAEsTorch.utilities import get_data
6
6
  import numpy as np
7
+ import torch
7
8
 
8
9
 
9
10
  if __name__ == "__main__":
@@ -45,9 +46,9 @@ if __name__ == "__main__":
45
46
 
46
47
  match method:
47
48
  case "VRRAE":
48
- eps_fn = lambda lat, bs: np.random.normal(0, 1, size=(1, 1, k_max, bs))
49
+ eps_fn = lambda lat, bs: torch.tensor(np.random.normal(0, 1, size=(1, 1, k_max, bs)), dtype=torch.float32, device=device)
49
50
  case "VAE":
50
- eps_fn = lambda lat, bs: np.random.normal(size=(1, 1, lat, bs))
51
+ eps_fn = lambda lat, bs: torch.tensor(np.random.normal(size=(1, 1, lat, bs)), dtype=torch.float32, device=device)
51
52
 
52
53
  # Step 3: Specify the archietectures' parameters:
53
54
  latent_size = 200 # latent space dimension
@@ -88,6 +89,8 @@ if __name__ == "__main__":
88
89
  # you need to specify training kw arguments (first stage of training with SVD to
89
90
  # find the basis), and fine-tuning kw arguments (second stage of training with the
90
91
  # basis found in the first stage).
92
+ device = "cuda" if torch.cuda.is_available() else "cpu"
93
+
91
94
  training_kwargs = {
92
95
  "step_st": [2], # 7680*data_size/64
93
96
  "batch_size_st": [64],
@@ -96,6 +99,7 @@ if __name__ == "__main__":
96
99
  "loss_type": loss_type,
97
100
  "loss_kwargs": {"beta": 0.001},
98
101
  "eps_fn": eps_fn,
102
+ "device": device,
99
103
  }
100
104
 
101
105
 
@@ -104,7 +108,8 @@ if __name__ == "__main__":
104
108
  "batch_size_st": [64],
105
109
  "lr_st": [1e-4, 1e-6, 1e-7, 1e-8],
106
110
  "print_every": 1,
107
- "eps_fn": eps_fn
111
+ "eps_fn": eps_fn,
112
+ "device": device
108
113
  }
109
114
 
110
115
 
@@ -124,7 +129,7 @@ if __name__ == "__main__":
124
129
  trainor.save_model()
125
130
 
126
131
  preds = trainor.evaluate(
127
- x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out
132
+ x_train, y_train, x_test, y_test, None, pre_func_inp, pre_func_out, device
128
133
  )
129
134
 
130
135
  # pdb.set_trace()
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "RRAEsTorch"
7
- version = "0.1.5"
7
+ version = "0.1.6"
8
8
  description= " A repo for RRAEs in PyTorch."
9
9
  readme="README.md"
10
10
  requires-python= ">=3.10"
@@ -1,207 +0,0 @@
1
- # Byte-compiled / optimized / DLL files
2
- __pycache__/
3
- *.py[codz]
4
- *$py.class
5
-
6
- # C extensions
7
- *.so
8
-
9
- # Distribution / packaging
10
- .Python
11
- build/
12
- develop-eggs/
13
- dist/
14
- downloads/
15
- eggs/
16
- .eggs/
17
- lib/
18
- lib64/
19
- parts/
20
- sdist/
21
- var/
22
- wheels/
23
- share/python-wheels/
24
- *.egg-info/
25
- .installed.cfg
26
- *.egg
27
- MANIFEST
28
-
29
- # PyInstaller
30
- # Usually these files are written by a python script from a template
31
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
- *.manifest
33
- *.spec
34
-
35
- # Installer logs
36
- pip-log.txt
37
- pip-delete-this-directory.txt
38
-
39
- # Unit test / coverage reports
40
- htmlcov/
41
- .tox/
42
- .nox/
43
- .coverage
44
- .coverage.*
45
- .cache
46
- nosetests.xml
47
- coverage.xml
48
- *.cover
49
- *.py.cover
50
- .hypothesis/
51
- .pytest_cache/
52
- cover/
53
-
54
- # Translations
55
- *.mo
56
- *.pot
57
-
58
- # Django stuff:
59
- *.log
60
- local_settings.py
61
- db.sqlite3
62
- db.sqlite3-journal
63
-
64
- # Flask stuff:
65
- instance/
66
- .webassets-cache
67
-
68
- # Scrapy stuff:
69
- .scrapy
70
-
71
- # Sphinx documentation
72
- docs/_build/
73
-
74
- # PyBuilder
75
- .pybuilder/
76
- target/
77
-
78
- # Jupyter Notebook
79
- .ipynb_checkpoints
80
-
81
- # IPython
82
- profile_default/
83
- ipython_config.py
84
-
85
- # pyenv
86
- # For a library or package, you might want to ignore these files since the code is
87
- # intended to run in multiple environments; otherwise, check them in:
88
- # .python-version
89
-
90
- # pipenv
91
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
- # install all needed dependencies.
95
- #Pipfile.lock
96
-
97
- # UV
98
- # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
- # This is especially recommended for binary packages to ensure reproducibility, and is more
100
- # commonly ignored for libraries.
101
- #uv.lock
102
-
103
- # poetry
104
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
- # This is especially recommended for binary packages to ensure reproducibility, and is more
106
- # commonly ignored for libraries.
107
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
- #poetry.lock
109
- #poetry.toml
110
-
111
- # pdm
112
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
- # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
- # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
- #pdm.lock
116
- #pdm.toml
117
- .pdm-python
118
- .pdm-build/
119
-
120
- # pixi
121
- # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
- #pixi.lock
123
- # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
- # in the .venv directory. It is recommended not to include this directory in version control.
125
- .pixi
126
-
127
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
- __pypackages__/
129
-
130
- # Celery stuff
131
- celerybeat-schedule
132
- celerybeat.pid
133
-
134
- # SageMath parsed files
135
- *.sage.py
136
-
137
- # Environments
138
- .env
139
- .envrc
140
- .venv
141
- env/
142
- venv/
143
- ENV/
144
- env.bak/
145
- venv.bak/
146
-
147
- # Spyder project settings
148
- .spyderproject
149
- .spyproject
150
-
151
- # Rope project settings
152
- .ropeproject
153
-
154
- # mkdocs documentation
155
- /site
156
-
157
- # mypy
158
- .mypy_cache/
159
- .dmypy.json
160
- dmypy.json
161
-
162
- # Pyre type checker
163
- .pyre/
164
-
165
- # pytype static type analyzer
166
- .pytype/
167
-
168
- # Cython debug symbols
169
- cython_debug/
170
-
171
- # PyCharm
172
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
- # and can be added to the global gitignore or merged into this file. For a more nuclear
175
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
- #.idea/
177
-
178
- # Abstra
179
- # Abstra is an AI-powered process automation framework.
180
- # Ignore directories containing user credentials, local state, and settings.
181
- # Learn more at https://abstra.io/docs
182
- .abstra/
183
-
184
- # Visual Studio Code
185
- # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
- # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
- # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
- # you could uncomment the following to ignore the entire vscode folder
189
- # .vscode/
190
-
191
- # Ruff stuff:
192
- .ruff_cache/
193
-
194
- # PyPI configuration file
195
- .pypirc
196
-
197
- # Cursor
198
- # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
- # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
- # refer to https://docs.cursor.com/context/ignore-files
201
- .cursorignore
202
- .cursorindexingignore
203
-
204
- # Marimo
205
- marimo/_static/
206
- marimo/_lsp/
207
- __marimo__/
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2026 Jad Mounayer
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes