RRAEsTorch 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RRAEsTorch/training_classes/training_classes.py +6 -7
- {rraestorch-0.1.0.dist-info → rraestorch-0.1.1.dist-info}/METADATA +1 -1
- {rraestorch-0.1.0.dist-info → rraestorch-0.1.1.dist-info}/RECORD +6 -6
- {rraestorch-0.1.0.dist-info → rraestorch-0.1.1.dist-info}/WHEEL +0 -0
- {rraestorch-0.1.0.dist-info → rraestorch-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {rraestorch-0.1.0.dist-info → rraestorch-0.1.1.dist-info}/licenses/LICENSE copy +0 -0
|
@@ -25,7 +25,6 @@ import matplotlib.pyplot as plt
|
|
|
25
25
|
from prettytable import PrettyTable
|
|
26
26
|
import torch
|
|
27
27
|
from torch.utils.data import TensorDataset, DataLoader
|
|
28
|
-
from adabelief_pytorch import AdaBelief
|
|
29
28
|
|
|
30
29
|
class Circular_list:
|
|
31
30
|
"""
|
|
@@ -221,7 +220,7 @@ class Trainor_class:
|
|
|
221
220
|
tracker=Null_Tracker(),
|
|
222
221
|
stagn_window=20,
|
|
223
222
|
eps_fn=lambda lat, bs: None,
|
|
224
|
-
optimizer=
|
|
223
|
+
optimizer=torch.optim.Adam,
|
|
225
224
|
verbatim = {
|
|
226
225
|
"print_type": "std",
|
|
227
226
|
"window_size" : 5,
|
|
@@ -307,7 +306,7 @@ class Trainor_class:
|
|
|
307
306
|
for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
|
|
308
307
|
try:
|
|
309
308
|
t_t = 0.0 # Zero time
|
|
310
|
-
|
|
309
|
+
optimizer_tr = optimizer(
|
|
311
310
|
filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3
|
|
312
311
|
)
|
|
313
312
|
|
|
@@ -341,11 +340,11 @@ class Trainor_class:
|
|
|
341
340
|
step_kwargs = merge_dicts(loss_kwargs, track_params)
|
|
342
341
|
|
|
343
342
|
# Compute loss
|
|
344
|
-
loss, model,
|
|
343
|
+
loss, model, optimizer_tr, (aux, extra_track) = make_step(
|
|
345
344
|
model,
|
|
346
345
|
input_b,
|
|
347
346
|
out_b,
|
|
348
|
-
|
|
347
|
+
optimizer_tr,
|
|
349
348
|
idx_b,
|
|
350
349
|
epsilon,
|
|
351
350
|
**step_kwargs,
|
|
@@ -391,7 +390,7 @@ class Trainor_class:
|
|
|
391
390
|
self.del_file(f"checkpoint_k_{track_params.get('k_max')}")
|
|
392
391
|
model = self.model
|
|
393
392
|
|
|
394
|
-
|
|
393
|
+
optimizer_tr = optimizer(
|
|
395
394
|
filter(lambda p: p.requires_grad, model.parameters()), lr=lr
|
|
396
395
|
)
|
|
397
396
|
|
|
@@ -974,4 +973,4 @@ class RRAE_Trainor_class(AE_Trainor_class):
|
|
|
974
973
|
latent_func=call_func,
|
|
975
974
|
decode_func=decode_func,
|
|
976
975
|
norm_out_func=norm_out_func,
|
|
977
|
-
)
|
|
976
|
+
)
|
|
@@ -15,13 +15,13 @@ RRAEsTorch/tests/test_wrappers.py,sha256=Ike4IfMUx2Qic3f3_cBikgFPEU1WW5TuH1jT_r2
|
|
|
15
15
|
RRAEsTorch/trackers/__init__.py,sha256=3c9qcUMZiUfVr93rxFp6l11lIDthyK3PCY_-P-sNX3I,25
|
|
16
16
|
RRAEsTorch/trackers/trackers.py,sha256=Pn1ejMxMjAtvgDazFFwa3qiZhogG5GtXj4UIIFiBpuY,9127
|
|
17
17
|
RRAEsTorch/training_classes/__init__.py,sha256=K_Id4yhw640jp2JN15-0E4wJi4sPadi1fFRgovMV3kw,101
|
|
18
|
-
RRAEsTorch/training_classes/training_classes.py,sha256=
|
|
18
|
+
RRAEsTorch/training_classes/training_classes.py,sha256=a7JjhCrH7s7VmVCsvKj768Ciq-tdbh6E_B9aG1kw7vc,36634
|
|
19
19
|
RRAEsTorch/utilities/__init__.py,sha256=NtlizCcRW4qcsULXxWfjPk265rLJst0-GqWLRah2yDY,26
|
|
20
20
|
RRAEsTorch/utilities/utilities.py,sha256=FzJWV9oFPF9sL9MC2m7euMqMKxCuLUEukzLfU0cF2to,53396
|
|
21
21
|
RRAEsTorch/wrappers/__init__.py,sha256=txiLh4ylnuvPlapagz7DiAslmjllOzTqwCDL2dFr6dM,44
|
|
22
22
|
RRAEsTorch/wrappers/wrappers.py,sha256=9Rmq2RS_EkZvsg96SKrt1HFIP35sF0xyPI0goV0ujOs,9659
|
|
23
|
-
rraestorch-0.1.
|
|
24
|
-
rraestorch-0.1.
|
|
25
|
-
rraestorch-0.1.
|
|
26
|
-
rraestorch-0.1.
|
|
27
|
-
rraestorch-0.1.
|
|
23
|
+
rraestorch-0.1.1.dist-info/METADATA,sha256=oyhQOf9j7F3MbYYqclhO4glRYRDvLV-3UwwuTlQZark,3055
|
|
24
|
+
rraestorch-0.1.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
25
|
+
rraestorch-0.1.1.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
|
|
26
|
+
rraestorch-0.1.1.dist-info/licenses/LICENSE copy,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
|
|
27
|
+
rraestorch-0.1.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|