RRAEsTorch 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,4 @@
1
1
  from RRAEsTorch.utilities import MLP_with_linear
2
- import jax.random as jrandom
3
2
  import torch.nn as nn
4
3
  from torch.func import vmap
5
4
 
@@ -25,7 +25,6 @@ import matplotlib.pyplot as plt
25
25
  from prettytable import PrettyTable
26
26
  import torch
27
27
  from torch.utils.data import TensorDataset, DataLoader
28
- from adabelief_pytorch import AdaBelief
29
28
 
30
29
  class Circular_list:
31
30
  """
@@ -221,7 +220,7 @@ class Trainor_class:
221
220
  tracker=Null_Tracker(),
222
221
  stagn_window=20,
223
222
  eps_fn=lambda lat, bs: None,
224
- optimizer=AdaBelief,
223
+ optimizer=torch.optim.Adam,
225
224
  verbatim = {
226
225
  "print_type": "std",
227
226
  "window_size" : 5,
@@ -307,7 +306,7 @@ class Trainor_class:
307
306
  for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
308
307
  try:
309
308
  t_t = 0.0 # Zero time
310
- optimizer = torch.optim.Adam(
309
+ optimizer_tr = optimizer(
311
310
  filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3
312
311
  )
313
312
 
@@ -341,11 +340,11 @@ class Trainor_class:
341
340
  step_kwargs = merge_dicts(loss_kwargs, track_params)
342
341
 
343
342
  # Compute loss
344
- loss, model, optimizer, (aux, extra_track) = make_step(
343
+ loss, model, optimizer_tr, (aux, extra_track) = make_step(
345
344
  model,
346
345
  input_b,
347
346
  out_b,
348
- optimizer,
347
+ optimizer_tr,
349
348
  idx_b,
350
349
  epsilon,
351
350
  **step_kwargs,
@@ -391,7 +390,7 @@ class Trainor_class:
391
390
  self.del_file(f"checkpoint_k_{track_params.get('k_max')}")
392
391
  model = self.model
393
392
 
394
- optimizer = torch.optim.Adam(
393
+ optimizer_tr = optimizer(
395
394
  filter(lambda p: p.requires_grad, model.parameters()), lr=lr
396
395
  )
397
396
 
@@ -974,4 +973,4 @@ class RRAE_Trainor_class(AE_Trainor_class):
974
973
  latent_func=call_func,
975
974
  decode_func=decode_func,
976
975
  norm_out_func=norm_out_func,
977
- )
976
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: RRAEsTorch
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: A repo for RRAEs in PyTorch.
5
5
  Author-email: Jad Mounayer <jad.mounayer@outlook.com>
6
6
  License: MIT
@@ -1,6 +1,6 @@
1
1
  RRAEsTorch/__init__.py,sha256=f234R6usRCqIgmBmiXyZNIHa7VrDe5E-KZO0Y6Ek5AQ,33
2
2
  RRAEsTorch/config.py,sha256=bQPwc_2KTvhglH_WIRSb5_6CpUQQj9AGpfqBp8_kuys,2931
3
- RRAEsTorch/AE_base/AE_base.py,sha256=4nnuv3VbrakkvpmcT9NqRBuFUZsC1BGMpQWdpxEcGyk,3398
3
+ RRAEsTorch/AE_base/AE_base.py,sha256=Eeo_I7p5P-357rnOmCuFxosJgmBg4KPyMA8n70sTV7U,3368
4
4
  RRAEsTorch/AE_base/__init__.py,sha256=95YfMgEWzIFAkm--Ci-a9YPSGfCs2PDAK2sbfScT7oo,24
5
5
  RRAEsTorch/AE_classes/AE_classes.py,sha256=LTIKobJ5FXOPwGZLRhs82NKeuQzvXzj2YzdhuBLi2i0,18108
6
6
  RRAEsTorch/AE_classes/__init__.py,sha256=inM2_YPJG8T-lwx-CUg-zL2EMltmROQAlNZeZmnvVGA,27
@@ -15,13 +15,13 @@ RRAEsTorch/tests/test_wrappers.py,sha256=Ike4IfMUx2Qic3f3_cBikgFPEU1WW5TuH1jT_r2
15
15
  RRAEsTorch/trackers/__init__.py,sha256=3c9qcUMZiUfVr93rxFp6l11lIDthyK3PCY_-P-sNX3I,25
16
16
  RRAEsTorch/trackers/trackers.py,sha256=Pn1ejMxMjAtvgDazFFwa3qiZhogG5GtXj4UIIFiBpuY,9127
17
17
  RRAEsTorch/training_classes/__init__.py,sha256=K_Id4yhw640jp2JN15-0E4wJi4sPadi1fFRgovMV3kw,101
18
- RRAEsTorch/training_classes/training_classes.py,sha256=l1wPWZghYwzz_wxgrBi2fmOumgjbCz3qaaIcDks-GCg,36668
18
+ RRAEsTorch/training_classes/training_classes.py,sha256=a7JjhCrH7s7VmVCsvKj768Ciq-tdbh6E_B9aG1kw7vc,36634
19
19
  RRAEsTorch/utilities/__init__.py,sha256=NtlizCcRW4qcsULXxWfjPk265rLJst0-GqWLRah2yDY,26
20
20
  RRAEsTorch/utilities/utilities.py,sha256=FzJWV9oFPF9sL9MC2m7euMqMKxCuLUEukzLfU0cF2to,53396
21
21
  RRAEsTorch/wrappers/__init__.py,sha256=txiLh4ylnuvPlapagz7DiAslmjllOzTqwCDL2dFr6dM,44
22
22
  RRAEsTorch/wrappers/wrappers.py,sha256=9Rmq2RS_EkZvsg96SKrt1HFIP35sF0xyPI0goV0ujOs,9659
23
- rraestorch-0.1.0.dist-info/METADATA,sha256=Wm_NiD_XGPLm8bHVI-U2Rt7aDnNAQ1tkoeeo3pzWarQ,3055
24
- rraestorch-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
25
- rraestorch-0.1.0.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
26
- rraestorch-0.1.0.dist-info/licenses/LICENSE copy,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
27
- rraestorch-0.1.0.dist-info/RECORD,,
23
+ rraestorch-0.1.2.dist-info/METADATA,sha256=Gf8TImFT4nLlgXlaHjJPvzlj75Q7nBVrOT29eanTubA,3055
24
+ rraestorch-0.1.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
25
+ rraestorch-0.1.2.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
26
+ rraestorch-0.1.2.dist-info/licenses/LICENSE copy,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
27
+ rraestorch-0.1.2.dist-info/RECORD,,