RRAEsTorch 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -127,7 +127,7 @@ def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=F
127
127
  if epsilon is not None:
128
128
  if len(epsilon.shape) == 4:
129
129
  epsilon = epsilon[0, 0] # to allow tpu sharding
130
- z = mean + torch.tensor(epsilon, dtype=torch.float32) * std
130
+ z = mean + epsilon * std
131
131
  else:
132
132
  z = mean
133
133
 
@@ -612,10 +612,6 @@ class VRRAE_CNN1D(CNN1D_Autoencoder):
612
612
  typ: int
613
613
 
614
614
  def __init__(self, channels, input_dim, latent_size, k_max, typ="eye", *, count=1, **kwargs):
615
- v_Linear = vmap_wrap(Linear, -1, count=count)
616
- self.lin_mean = v_Linear(k_max, k_max,)
617
- self.lin_logvar = v_Linear(k_max, k_max)
618
- self.typ = typ
619
615
  super().__init__(
620
616
  channels,
621
617
  input_dim,
@@ -623,6 +619,12 @@ class VRRAE_CNN1D(CNN1D_Autoencoder):
623
619
  count=count,
624
620
  **kwargs,
625
621
  )
622
+
623
+ v_Linear = vmap_wrap(Linear, -1, count=count)
624
+ self.lin_mean = v_Linear(k_max, k_max,)
625
+ self.lin_logvar = v_Linear(k_max, k_max)
626
+ self.typ = typ
627
+
626
628
 
627
629
  def _perform_in_latent(self, y, *args, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
628
630
  return latent_func_var_strong_RRAE(self, y, k_max, epsilon, return_dist, return_lat_dist, **kwargs)
@@ -229,7 +229,8 @@ class Trainor_class:
229
229
  save_losses=False,
230
230
  input_val=None,
231
231
  output_val=None,
232
- latent_size=0
232
+ latent_size=0,
233
+ device="cpu"
233
234
  ):
234
235
  assert isinstance(input, torch.Tensor), "Input should be a torch tensor"
235
236
  assert isinstance(output, torch.Tensor), "Output should be a torch tensor"
@@ -301,7 +302,12 @@ class Trainor_class:
301
302
  all_losses = []
302
303
 
303
304
 
304
-
305
+ model = model.to(device)
306
+
307
+ if input_val is not None:
308
+ dataset_val = TensorDataset(input_val.permute(*range(input_val.ndim - 1, -1, -1)), output_val.permute(*range(output_val.ndim - 1, -1, -1)), torch.arange(0, input_val.shape[-1], 1))
309
+ dataloader_val = DataLoader(dataset_val, batch_size=input_val.shape[-1], shuffle=False)
310
+
305
311
  # Outer Loop
306
312
  for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
307
313
  try:
@@ -340,6 +346,9 @@ class Trainor_class:
340
346
  step_kwargs = merge_dicts(loss_kwargs, track_params)
341
347
 
342
348
  # Compute loss
349
+ input_b = input_b.to(device)
350
+ out_b = out_b.to(device)
351
+
343
352
  loss, model, optimizer_tr, (aux, extra_track) = make_step(
344
353
  model,
345
354
  input_b,
@@ -351,12 +360,17 @@ class Trainor_class:
351
360
  )
352
361
 
353
362
  if input_val is not None:
363
+ val_loss = []
364
+ for input_vb, out_vb, idx_b in dataloader_val:
365
+ input_vb = input_vb.permute(*range(input_vb.ndim - 1, -1, -1))
366
+ out_vb = self.model.norm_out.default(None, pre_func_out(out_vb)) # Pre-process batch out values
367
+ out_vb = out_vb.permute(*range(out_vb.ndim - 1, -1, -1))
354
368
 
355
-
356
- idx = np.arange(input_val.shape[-1])
357
- val_loss, _ = loss_fun(
358
- model, input_val, output_val, idx=idx, epsilon=None, **step_kwargs
359
- )
369
+ val_loss_batch = loss_fun(
370
+ model, input_vb.to(device), out_vb.to(device), idx=idx_b, epsilon=None, **step_kwargs
371
+ )[0]
372
+ val_loss.append(val_loss_batch.item())
373
+ val_loss = sum(val_loss) / len(val_loss)
360
374
  aux["val_loss"] = val_loss
361
375
  else:
362
376
  aux["val_loss"] = None
@@ -618,7 +632,7 @@ class Trainor_class:
618
632
  dill.dump(obj, f)
619
633
  print(f"Object saved in {filename}")
620
634
 
621
- def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None, **fn_kwargs):
635
+ def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None, device="cpu",**fn_kwargs):
622
636
  """NOTE: fn_kwargs defines the functions of the model
623
637
  (e.g. final_activation, inner activation), if
624
638
  needed to be saved/loaded on different devices/OS.
@@ -660,7 +674,7 @@ class Trainor_class:
660
674
 
661
675
  model = model_cls(**kwargs)
662
676
  model.load_state_dict(save_dict["model_state_dict"])
663
- self.model = model
677
+ self.model = model.to(device)
664
678
  attributes = save_dict["attr"]
665
679
 
666
680
  for key in attributes:
@@ -849,8 +863,10 @@ class RRAE_Trainor_class(AE_Trainor_class):
849
863
  ft_end_type = "concat"
850
864
  basis_call_kwargs = {}
851
865
 
866
+ device = ft_kwargs.get("device", "cpu")
867
+
852
868
  ft_model, ft_track_params = self.fine_tune_basis(
853
- None, args=args, kwargs=ft_kwargs, get_basis=get_basis, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs
869
+ None, args=args, kwargs=ft_kwargs, get_basis=get_basis, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs, device=device
854
870
  ) # fine tune basis
855
871
  self.ft_track_params = ft_track_params
856
872
  else:
@@ -858,7 +874,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
858
874
  ft_track_params = {}
859
875
  return model, track_params, ft_model, ft_track_params
860
876
 
861
- def fine_tune_basis(self, basis=None, get_basis=True, end_type="concat", basis_call_kwargs={}, *, args, kwargs):
877
+ def fine_tune_basis(self, basis=None, get_basis=True, end_type="concat", basis_call_kwargs={}, device="cpu", *, args, kwargs):
862
878
 
863
879
  if "loss" in kwargs:
864
880
  norm_loss_ = kwargs["loss"]
@@ -884,24 +900,24 @@ class RRAE_Trainor_class(AE_Trainor_class):
884
900
  inpT = inp.permute(*range(inp.ndim - 1, -1, -1))
885
901
  dataset = TensorDataset(inpT)
886
902
  dataloader = DataLoader(dataset, batch_size=basis_batch_size, shuffle=False)
903
+ model = self.model.to(device)
887
904
 
888
905
  all_bases = []
889
906
 
890
907
  for inp_b in dataloader:
891
908
  inp_bT = inp_b[0].permute(*range(inp_b[0].ndim - 1, -1, -1))
892
- all_bases.append(self.model.latent(
893
- self.pre_func_inp(inp_bT), get_basis_coeffs=True, **basis_kwargs
894
- )[0]
909
+ all_bases.append(model.latent(
910
+ self.pre_func_inp(inp_bT.to(device)), get_basis_coeffs=True, **basis_kwargs
911
+ )[0].to("cpu")
895
912
  )
896
913
  if end_type == "concat":
897
914
  all_bases = torch.concatenate(all_bases, axis=1)
898
- print(all_bases.shape)
899
915
  basis = torch.linalg.svd(all_bases, full_matrices=False)[0]
900
916
  self.basis = basis[:, : self.track_params["k_max"]]
901
917
  else:
902
918
  self.basis = all_bases
903
919
  else:
904
- bas = self.model.latent(self.pre_func_inp(inp[..., 0:1]), get_basis_coeffs=True, **self.track_params)[0]
920
+ bas = model.latent(self.pre_func_inp(inp[..., 0:1].to(device)), get_basis_coeffs=True, **self.track_params)[0].to("cpu")
905
921
  self.basis = torch.eye(bas.shape[0])
906
922
  else:
907
923
  self.basis = basis
@@ -932,10 +948,10 @@ class RRAE_Trainor_class(AE_Trainor_class):
932
948
  batch_size=None,
933
949
  pre_func_inp=lambda x: x,
934
950
  pre_func_out=lambda x: x,
935
- call_func=None,
951
+ device="cpu",
936
952
  ):
937
953
 
938
- call_func = lambda x: self.model(pre_func_inp(x), apply_basis=self.basis, epsilon=None)
954
+ call_func = lambda x: self.model(pre_func_inp(x.to(device)), apply_basis=self.basis.to(device), epsilon=None).to("cpu")
939
955
  res = super().evaluate(
940
956
  x_train_o,
941
957
  y_train_o,
@@ -945,6 +961,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
945
961
  call_func=call_func,
946
962
  pre_func_inp=pre_func_inp,
947
963
  pre_func_out=pre_func_out,
964
+ device=device,
948
965
  )
949
966
  return res
950
967
 
@@ -1510,7 +1510,6 @@ class MLP_with_CNN3D_trans(torch.nn.Module):
1510
1510
  self.out_after_mlp = out_after_mlp
1511
1511
 
1512
1512
  def forward(self, x, *args, **kwargs):
1513
- print(x.shape)
1514
1513
  x = self.layers[0](x)
1515
1514
  x = torch.reshape(x, (self.out_after_mlp, self.first_D0, self.first_D1, self.first_D2))
1516
1515
  x = self.layers[1](x)
@@ -1,11 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: RRAEsTorch
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: A repo for RRAEs in PyTorch.
5
5
  Author-email: Jad Mounayer <jad.mounayer@outlook.com>
6
6
  License: MIT
7
7
  License-File: LICENSE
8
- License-File: LICENSE copy
9
8
  Requires-Python: >=3.10
10
9
  Requires-Dist: dill
11
10
  Requires-Dist: jaxtyping
@@ -2,7 +2,7 @@ RRAEsTorch/__init__.py,sha256=f234R6usRCqIgmBmiXyZNIHa7VrDe5E-KZO0Y6Ek5AQ,33
2
2
  RRAEsTorch/config.py,sha256=bQPwc_2KTvhglH_WIRSb5_6CpUQQj9AGpfqBp8_kuys,2931
3
3
  RRAEsTorch/AE_base/AE_base.py,sha256=Eeo_I7p5P-357rnOmCuFxosJgmBg4KPyMA8n70sTV7U,3368
4
4
  RRAEsTorch/AE_base/__init__.py,sha256=95YfMgEWzIFAkm--Ci-a9YPSGfCs2PDAK2sbfScT7oo,24
5
- RRAEsTorch/AE_classes/AE_classes.py,sha256=HUiCH5iXhdMSZIQiulIktIwCpDjfPAXPrLiMpUYHg4M,18078
5
+ RRAEsTorch/AE_classes/AE_classes.py,sha256=oDpDzQasPbtK2L9vDLiG4VQdKH02VRCagOYT1-FAldo,18063
6
6
  RRAEsTorch/AE_classes/__init__.py,sha256=inM2_YPJG8T-lwx-CUg-zL2EMltmROQAlNZeZmnvVGA,27
7
7
  RRAEsTorch/tests/test_AE_classes_CNN.py,sha256=bEE9JnTo84t9w0a4kw1W74L51eLGjBB8trrlAG938RE,3182
8
8
  RRAEsTorch/tests/test_AE_classes_MLP.py,sha256=Cr1_uP7lag6RPQC1UhN2O7RFW4BEx1cd0Z-Y6VgrWRg,2718
@@ -15,13 +15,12 @@ RRAEsTorch/tests/test_wrappers.py,sha256=Ike4IfMUx2Qic3f3_cBikgFPEU1WW5TuH1jT_r2
15
15
  RRAEsTorch/trackers/__init__.py,sha256=3c9qcUMZiUfVr93rxFp6l11lIDthyK3PCY_-P-sNX3I,25
16
16
  RRAEsTorch/trackers/trackers.py,sha256=Pn1ejMxMjAtvgDazFFwa3qiZhogG5GtXj4UIIFiBpuY,9127
17
17
  RRAEsTorch/training_classes/__init__.py,sha256=K_Id4yhw640jp2JN15-0E4wJi4sPadi1fFRgovMV3kw,101
18
- RRAEsTorch/training_classes/training_classes.py,sha256=a7JjhCrH7s7VmVCsvKj768Ciq-tdbh6E_B9aG1kw7vc,36634
18
+ RRAEsTorch/training_classes/training_classes.py,sha256=HU8Ksz1-2WwOMuwyGiWdkQ_vrrgBEwaeQT4avs4jd2E,37870
19
19
  RRAEsTorch/utilities/__init__.py,sha256=NtlizCcRW4qcsULXxWfjPk265rLJst0-GqWLRah2yDY,26
20
- RRAEsTorch/utilities/utilities.py,sha256=FzJWV9oFPF9sL9MC2m7euMqMKxCuLUEukzLfU0cF2to,53396
20
+ RRAEsTorch/utilities/utilities.py,sha256=JfLkAPEC8fzwgM32LEcXVe0tA4C7UBgsrkuh6noUA_4,53372
21
21
  RRAEsTorch/wrappers/__init__.py,sha256=txiLh4ylnuvPlapagz7DiAslmjllOzTqwCDL2dFr6dM,44
22
22
  RRAEsTorch/wrappers/wrappers.py,sha256=9Rmq2RS_EkZvsg96SKrt1HFIP35sF0xyPI0goV0ujOs,9659
23
- rraestorch-0.1.5.dist-info/METADATA,sha256=uL69l4DSjgNliJfX3uPyrZqrH2qvqZAu55DE1jfF5KQ,3055
24
- rraestorch-0.1.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
25
- rraestorch-0.1.5.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
26
- rraestorch-0.1.5.dist-info/licenses/LICENSE copy,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
27
- rraestorch-0.1.5.dist-info/RECORD,,
23
+ rraestorch-0.1.6.dist-info/METADATA,sha256=BfnB-vhx0m-d79hCd8UgZT0c8GPHFRwVn8x-M8k_h6E,3028
24
+ rraestorch-0.1.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
25
+ rraestorch-0.1.6.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
26
+ rraestorch-0.1.6.dist-info/RECORD,,
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2026 Jad Mounayer
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.