RRAEsTorch 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
- RRAEsTorch/AE_classes/AE_classes.py +7 -5
- RRAEsTorch/training_classes/training_classes.py +35 -18
- RRAEsTorch/utilities/utilities.py +0 -1
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.6.dist-info}/METADATA +1 -2
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.6.dist-info}/RECORD +7 -8
- rraestorch-0.1.5.dist-info/licenses/LICENSE copy +0 -21
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.6.dist-info}/WHEEL +0 -0
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.6.dist-info}/licenses/LICENSE +0 -0
RRAEsTorch/AE_classes/AE_classes.py

@@ -127,7 +127,7 @@ def latent_func_var_strong_RRAE(self, y, k_max=None, epsilon=None, return_dist=F
     if epsilon is not None:
         if len(epsilon.shape) == 4:
             epsilon = epsilon[0, 0]  # to allow tpu sharding
-        z = mean +
+        z = mean + epsilon * std
     else:
         z = mean

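The restored line is the standard VAE reparameterization, z = mean + epsilon * std, with the noise tensor injected by the caller (and collapsed from a 4-D sharded shape first). A minimal standalone sketch of the same pattern, assuming the usual std = exp(0.5 * logvar) convention; the hunk itself does not show how std is derived, and the function name here is hypothetical:

    import torch

    def reparameterize(mean, logvar, epsilon=None):
        if epsilon is not None:
            if epsilon.ndim == 4:
                epsilon = epsilon[0, 0]  # collapse a sharded noise tensor, as in the hunk
            std = torch.exp(0.5 * logvar)  # assumed convention, not shown in the diff
            return mean + epsilon * std   # the line the release completes
        return mean  # deterministic path when no noise is supplied

    z = reparameterize(torch.zeros(8, 4), torch.zeros(8, 4), torch.randn(8, 4))

Passing epsilon in from outside keeps sampling reproducible under a fixed seed and lets callers disable it entirely with epsilon=None.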
@@ -612,10 +612,6 @@ class VRRAE_CNN1D(CNN1D_Autoencoder):
     typ: int

     def __init__(self, channels, input_dim, latent_size, k_max, typ="eye", *, count=1, **kwargs):
-        v_Linear = vmap_wrap(Linear, -1, count=count)
-        self.lin_mean = v_Linear(k_max, k_max,)
-        self.lin_logvar = v_Linear(k_max, k_max)
-        self.typ = typ
         super().__init__(
             channels,
             input_dim,
@@ -623,6 +619,12 @@ class VRRAE_CNN1D(CNN1D_Autoencoder):
             count=count,
             **kwargs,
         )
+
+        v_Linear = vmap_wrap(Linear, -1, count=count)
+        self.lin_mean = v_Linear(k_max, k_max,)
+        self.lin_logvar = v_Linear(k_max, k_max)
+        self.typ = typ
+

     def _perform_in_latent(self, y, *args, k_max=None, epsilon=None, return_dist=False, return_lat_dist=False, **kwargs):
         return latent_func_var_strong_RRAE(self, y, k_max, epsilon, return_dist, return_lat_dist, **kwargs)
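The submodule construction moves from before super().__init__(...) to after it. One plausible reason, assuming Linear here ultimately derives from torch.nn.Module (the diff does not state the motive): nn.Module.__init__ creates the internal _parameters/_modules registries, and assigning a submodule before the parent constructor runs raises an AttributeError. A minimal reproduction with plain nn.Linear and a hypothetical class:

    import torch
    from torch import nn

    class Demo(nn.Module):
        def __init__(self, k_max):
            # Assigning self.lin_mean here, before super().__init__(), would raise:
            # AttributeError: cannot assign module before Module.__init__() call
            super().__init__()
            # After super().__init__() the registries exist and submodules attach normally.
            self.lin_mean = nn.Linear(k_max, k_max)
            self.lin_logvar = nn.Linear(k_max, k_max)

    m = Demo(4)
    print(len(list(m.parameters())))  # 4: a weight and a bias for each Linear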
RRAEsTorch/training_classes/training_classes.py

@@ -229,7 +229,8 @@ class Trainor_class:
         save_losses=False,
         input_val=None,
         output_val=None,
-        latent_size=0
+        latent_size=0,
+        device="cpu"
     ):
         assert isinstance(input, torch.Tensor), "Input should be a torch tensor"
         assert isinstance(output, torch.Tensor), "Output should be a torch tensor"
@@ -301,7 +302,12 @@ class Trainor_class:
         all_losses = []


-
+        model = model.to(device)
+
+        if input_val is not None:
+            dataset_val = TensorDataset(input_val.permute(*range(input_val.ndim - 1, -1, -1)), output_val.permute(*range(output_val.ndim - 1, -1, -1)), torch.arange(0, input_val.shape[-1], 1))
+            dataloader_val = DataLoader(dataset_val, batch_size=input_val.shape[-1], shuffle=False)
+
         # Outler Loop
         for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
             try:
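The new validation path mirrors the training data layout: samples live on the last axis, so each tensor is flipped with permute(*range(t.ndim - 1, -1, -1)) (a full axis reversal) before it enters TensorDataset, and flipped back per batch. A small sketch of that round trip with hypothetical shapes and names; full reversal is an involution, so applying it twice recovers the original layout:

    import torch
    from torch.utils.data import TensorDataset, DataLoader

    x_val = torch.randn(3, 5, 100)  # features first, 100 samples on the last axis
    y_val = torch.randn(2, 100)

    rev = lambda t: t.permute(*range(t.ndim - 1, -1, -1))  # full axis reversal

    dataset = TensorDataset(rev(x_val), rev(y_val), torch.arange(x_val.shape[-1]))
    loader = DataLoader(dataset, batch_size=x_val.shape[-1], shuffle=False)

    for xb, yb, idx in loader:
        xb = rev(xb)  # back to (3, 5, batch): samples on the last axis again
        print(xb.shape, yb.shape, idx.shape)

Note that batch_size=input_val.shape[-1] makes the loader yield a single batch covering the whole validation set.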
@@ -340,6 +346,9 @@ class Trainor_class:
                 step_kwargs = merge_dicts(loss_kwargs, track_params)

                 # Compute loss
+                input_b = input_b.to(device)
+                out_b = out_b.to(device)
+
                 loss, model, optimizer_tr, (aux, extra_track) = make_step(
                     model,
                     input_b,
@@ -351,12 +360,17 @@ class Trainor_class:
                 )

                 if input_val is not None:
+                    val_loss = []
+                    for input_vb, out_vb, idx_b in dataloader_val:
+                        input_vb = input_vb.permute(*range(input_vb.ndim - 1, -1, -1))
+                        out_vb = self.model.norm_out.default(None, pre_func_out(out_vb))  # Pre-process batch out values
+                        out_vb = out_vb.permute(*range(out_vb.ndim - 1, -1, -1))

-
-
-
-
-                    )
+                        val_loss_batch = loss_fun(
+                            model, input_vb.to(device), out_vb.to(device), idx=idx_b, epsilon=None, **step_kwargs
+                        )[0]
+                        val_loss.append(val_loss_batch.item())
+                    val_loss = sum(val_loss) / len(val_loss)
                     aux["val_loss"] = val_loss
                 else:
                     aux["val_loss"] = None
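The epoch metric is then the plain mean of per-batch scalar losses; .item() detaches each value, so no graph is kept alive across the loop. A compact sketch of the same accounting with hypothetical mean_val_loss and loss_fn names; the torch.no_grad() guard is an extra safeguard shown here, not something the hunk itself adds:

    import torch

    def mean_val_loss(model, loader, loss_fn, device="cpu"):
        losses = []
        with torch.no_grad():  # validation needs no gradients
            for xb, yb in loader:
                losses.append(loss_fn(model, xb.to(device), yb.to(device)).item())
        return sum(losses) / len(losses)  # exact when all batches are equal-sized

Since the validation loader above is built with a single full-size batch, the unweighted mean is exact here regardless.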
@@ -618,7 +632,7 @@ class Trainor_class:
             dill.dump(obj, f)
         print(f"Object saved in {filename}")

-    def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None,
+    def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None, device="cpu",**fn_kwargs):
        """NOTE: fn_kwargs defines the functions of the model
        (e.g. final_activation, inner activation), if
        needed to be saved/loaded on different devices/OS.
@@ -660,7 +674,7 @@ class Trainor_class:

         model = model_cls(**kwargs)
         model.load_state_dict(save_dict["model_state_dict"])
-        self.model = model
+        self.model = model.to(device)
         attributes = save_dict["attr"]

         for key in attributes:
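Rebuilding the model on the CPU and only then moving it with .to(device) is the portable order of operations. The trainer serializes with dill rather than torch.save, so this is only an analogous sketch of the load-then-move pattern using the stock torch API (function name and dict keys hypothetical, mirroring the hunk):

    import torch

    def load_model(path, model_cls, device="cpu", **kwargs):
        # map_location guards against checkpoints pickled on another device
        save_dict = torch.load(path, map_location="cpu")
        model = model_cls(**kwargs)
        model.load_state_dict(save_dict["model_state_dict"])
        return model.to(device)  # move once, after the weights are in place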
@@ -849,8 +863,10 @@ class RRAE_Trainor_class(AE_Trainor_class):
                 ft_end_type = "concat"
                 basis_call_kwargs = {}

+            device = ft_kwargs.get("device", "cpu")
+
             ft_model, ft_track_params = self.fine_tune_basis(
-                None, args=args, kwargs=ft_kwargs, get_basis=get_basis, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs
+                None, args=args, kwargs=ft_kwargs, get_basis=get_basis, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs, device=device
             )  # fine tune basis
             self.ft_track_params = ft_track_params
         else:
@@ -858,7 +874,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
             ft_track_params = {}
         return model, track_params, ft_model, ft_track_params

-    def fine_tune_basis(self, basis=None, get_basis=True, end_type="concat", basis_call_kwargs={}, *, args, kwargs):
+    def fine_tune_basis(self, basis=None, get_basis=True, end_type="concat", basis_call_kwargs={}, device="cpu", *, args, kwargs):

        if "loss" in kwargs:
            norm_loss_ = kwargs["loss"]
@@ -884,24 +900,24 @@ class RRAE_Trainor_class(AE_Trainor_class):
                 inpT = inp.permute(*range(inp.ndim - 1, -1, -1))
                 dataset = TensorDataset(inpT)
                 dataloader = DataLoader(dataset, batch_size=basis_batch_size, shuffle=False)
+                model = self.model.to(device)

                 all_bases = []

                 for inp_b in dataloader:
                     inp_bT = inp_b[0].permute(*range(inp_b[0].ndim - 1, -1, -1))
-                    all_bases.append(
-                        self.pre_func_inp(inp_bT), get_basis_coeffs=True, **basis_kwargs
-                    )[0]
+                    all_bases.append(model.latent(
+                        self.pre_func_inp(inp_bT.to(device)), get_basis_coeffs=True, **basis_kwargs
+                    )[0].to("cpu")
                     )
                 if end_type == "concat":
                     all_bases = torch.concatenate(all_bases, axis=1)
-                    print(all_bases.shape)
                     basis = torch.linalg.svd(all_bases, full_matrices=False)[0]
                     self.basis = basis[:, : self.track_params["k_max"]]
                 else:
                     self.basis = all_bases
             else:
-                bas =
+                bas = model.latent(self.pre_func_inp(inp[..., 0:1].to(device)), get_basis_coeffs=True, **self.track_params)[0].to("cpu")
                 self.basis = torch.eye(bas.shape[0])
         else:
             self.basis = basis
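The fine-tuning pass now evaluates model.latent(...) on the compute device but stages every block of basis coefficients back on the CPU before the concatenation and reduced SVD, which bounds accelerator memory at one batch. A self-contained sketch of the reduce-then-truncate step (function and argument names hypothetical):

    import torch

    def build_basis(blocks, k_max):
        # blocks: per-batch coefficient matrices, already moved to the CPU
        all_bases = torch.concatenate(blocks, axis=1)
        # Reduced SVD: the columns of U are orthonormal and span the coefficients
        u = torch.linalg.svd(all_bases, full_matrices=False)[0]
        return u[:, :k_max]  # keep the k_max dominant directions

    basis = build_basis([torch.randn(16, 8) for _ in range(4)], k_max=5)
    print(basis.shape)  # torch.Size([16, 5])

The dropped print(all_bases.shape) was debug output; the same release removes a similar print from utilities.py below.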
@@ -932,10 +948,10 @@ class RRAE_Trainor_class(AE_Trainor_class):
         batch_size=None,
         pre_func_inp=lambda x: x,
         pre_func_out=lambda x: x,
-
+        device="cpu",
     ):

-        call_func = lambda x: self.model(pre_func_inp(x), apply_basis=self.basis, epsilon=None)
+        call_func = lambda x: self.model(pre_func_inp(x.to(device)), apply_basis=self.basis.to(device), epsilon=None).to("cpu")
         res = super().evaluate(
             x_train_o,
             y_train_o,
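evaluate now confines all device traffic to the call_func closure: the input and the basis hop to the device, the result comes back to the CPU, and everything around the call stays device-agnostic. A minimal sketch of that boundary pattern with a hypothetical model:

    import torch
    from torch import nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(8, 8).to(device)

    # Callers hand in CPU tensors and get CPU tensors back;
    # the device round trip lives entirely inside the closure.
    call_func = lambda x: model(x.to(device)).to("cpu")

    out = call_func(torch.randn(4, 8))
    print(out.device)  # cpu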
@@ -945,6 +961,7 @@ class RRAE_Trainor_class(AE_Trainor_class):
             call_func=call_func,
             pre_func_inp=pre_func_inp,
             pre_func_out=pre_func_out,
+            device=device,
         )
         return res

RRAEsTorch/utilities/utilities.py

@@ -1510,7 +1510,6 @@ class MLP_with_CNN3D_trans(torch.nn.Module):
         self.out_after_mlp = out_after_mlp

     def forward(self, x, *args, **kwargs):
-        print(x.shape)
         x = self.layers[0](x)
         x = torch.reshape(x, (self.out_after_mlp, self.first_D0, self.first_D1, self.first_D2))
         x = self.layers[1](x)
{rraestorch-0.1.5.dist-info → rraestorch-0.1.6.dist-info}/METADATA

@@ -1,11 +1,10 @@
 Metadata-Version: 2.4
 Name: RRAEsTorch
-Version: 0.1.5
+Version: 0.1.6
 Summary: A repo for RRAEs in PyTorch.
 Author-email: Jad Mounayer <jad.mounayer@outlook.com>
 License: MIT
 License-File: LICENSE
-License-File: LICENSE copy
 Requires-Python: >=3.10
 Requires-Dist: dill
 Requires-Dist: jaxtyping
{rraestorch-0.1.5.dist-info → rraestorch-0.1.6.dist-info}/RECORD

@@ -2,7 +2,7 @@ RRAEsTorch/__init__.py,sha256=f234R6usRCqIgmBmiXyZNIHa7VrDe5E-KZO0Y6Ek5AQ,33
 RRAEsTorch/config.py,sha256=bQPwc_2KTvhglH_WIRSb5_6CpUQQj9AGpfqBp8_kuys,2931
 RRAEsTorch/AE_base/AE_base.py,sha256=Eeo_I7p5P-357rnOmCuFxosJgmBg4KPyMA8n70sTV7U,3368
 RRAEsTorch/AE_base/__init__.py,sha256=95YfMgEWzIFAkm--Ci-a9YPSGfCs2PDAK2sbfScT7oo,24
-RRAEsTorch/AE_classes/AE_classes.py,sha256=
+RRAEsTorch/AE_classes/AE_classes.py,sha256=oDpDzQasPbtK2L9vDLiG4VQdKH02VRCagOYT1-FAldo,18063
 RRAEsTorch/AE_classes/__init__.py,sha256=inM2_YPJG8T-lwx-CUg-zL2EMltmROQAlNZeZmnvVGA,27
 RRAEsTorch/tests/test_AE_classes_CNN.py,sha256=bEE9JnTo84t9w0a4kw1W74L51eLGjBB8trrlAG938RE,3182
 RRAEsTorch/tests/test_AE_classes_MLP.py,sha256=Cr1_uP7lag6RPQC1UhN2O7RFW4BEx1cd0Z-Y6VgrWRg,2718

@@ -15,13 +15,12 @@ RRAEsTorch/tests/test_wrappers.py,sha256=Ike4IfMUx2Qic3f3_cBikgFPEU1WW5TuH1jT_r2
 RRAEsTorch/trackers/__init__.py,sha256=3c9qcUMZiUfVr93rxFp6l11lIDthyK3PCY_-P-sNX3I,25
 RRAEsTorch/trackers/trackers.py,sha256=Pn1ejMxMjAtvgDazFFwa3qiZhogG5GtXj4UIIFiBpuY,9127
 RRAEsTorch/training_classes/__init__.py,sha256=K_Id4yhw640jp2JN15-0E4wJi4sPadi1fFRgovMV3kw,101
-RRAEsTorch/training_classes/training_classes.py,sha256=
+RRAEsTorch/training_classes/training_classes.py,sha256=HU8Ksz1-2WwOMuwyGiWdkQ_vrrgBEwaeQT4avs4jd2E,37870
 RRAEsTorch/utilities/__init__.py,sha256=NtlizCcRW4qcsULXxWfjPk265rLJst0-GqWLRah2yDY,26
-RRAEsTorch/utilities/utilities.py,sha256=
+RRAEsTorch/utilities/utilities.py,sha256=JfLkAPEC8fzwgM32LEcXVe0tA4C7UBgsrkuh6noUA_4,53372
 RRAEsTorch/wrappers/__init__.py,sha256=txiLh4ylnuvPlapagz7DiAslmjllOzTqwCDL2dFr6dM,44
 RRAEsTorch/wrappers/wrappers.py,sha256=9Rmq2RS_EkZvsg96SKrt1HFIP35sF0xyPI0goV0ujOs,9659
-rraestorch-0.1.
-rraestorch-0.1.
-rraestorch-0.1.
-rraestorch-0.1.
-rraestorch-0.1.5.dist-info/RECORD,,
+rraestorch-0.1.6.dist-info/METADATA,sha256=BfnB-vhx0m-d79hCd8UgZT0c8GPHFRwVn8x-M8k_h6E,3028
+rraestorch-0.1.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+rraestorch-0.1.6.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
+rraestorch-0.1.6.dist-info/RECORD,,
rraestorch-0.1.5.dist-info/licenses/LICENSE copy (deleted)

@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2026 Jad Mounayer
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
{rraestorch-0.1.5.dist-info → rraestorch-0.1.6.dist-info}/WHEEL: file without changes

{rraestorch-0.1.5.dist-info → rraestorch-0.1.6.dist-info}/licenses/LICENSE: file without changes