imt-ring 1.6.34__py3-none-any.whl → 1.6.36__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.34.dist-info → imt_ring-1.6.36.dist-info}/METADATA +1 -1
- {imt_ring-1.6.34.dist-info → imt_ring-1.6.36.dist-info}/RECORD +6 -6
- ring/ml/callbacks.py +25 -9
- ring/ml/training_loop.py +16 -9
- {imt_ring-1.6.34.dist-info → imt_ring-1.6.36.dist-info}/WHEEL +0 -0
- {imt_ring-1.6.34.dist-info → imt_ring-1.6.36.dist-info}/top_level.txt +0 -0
@@ -53,13 +53,13 @@ ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,94
|
|
53
53
|
ring/io/xml/to_xml.py,sha256=Wo4iySLw9nM-iVW42AGvMRqjtU2qRc2FD_Zlc7w1IrE,3438
|
54
54
|
ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
|
55
55
|
ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
|
56
|
-
ring/ml/callbacks.py,sha256=
|
56
|
+
ring/ml/callbacks.py,sha256=oCPXl4_Zcw3g0KRgyyUDmdiGxV0phnDVc_t8rEG4Lls,13737
|
57
57
|
ring/ml/ml_utils.py,sha256=M--qkXRnhU7tHvgfTHfT9gyY0nhj3zMGEaK0X0drFLs,10915
|
58
58
|
ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
|
59
59
|
ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
|
60
60
|
ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
|
61
61
|
ring/ml/train.py,sha256=Da89HxiqXC7xuX2ldpTrJStqKWN-6Vcpml4PPQuihN4,10989
|
62
|
-
ring/ml/training_loop.py,sha256=
|
62
|
+
ring/ml/training_loop.py,sha256=yxuUua_4RExq_0GUYm4eUZJsBmtrwDSVL94bWUpYfdo,3586
|
63
63
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
64
64
|
ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
|
65
65
|
ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
|
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
|
|
86
86
|
ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
|
87
87
|
ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
88
88
|
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
89
|
-
imt_ring-1.6.
|
90
|
-
imt_ring-1.6.
|
91
|
-
imt_ring-1.6.
|
92
|
-
imt_ring-1.6.
|
89
|
+
imt_ring-1.6.36.dist-info/METADATA,sha256=a-uW_s0jWJEBX9kW1q36Br4SNXPK7eGIVhlsyKDWruE,4251
|
90
|
+
imt_ring-1.6.36.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
91
|
+
imt_ring-1.6.36.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
92
|
+
imt_ring-1.6.36.dist-info/RECORD,,
|
ring/ml/callbacks.py
CHANGED
@@ -388,6 +388,14 @@ class TimingKillRunCallback(training_loop.TrainingLoopCallback):
|
|
388
388
|
|
389
389
|
|
390
390
|
class CheckpointCallback(training_loop.TrainingLoopCallback):
|
391
|
+
def __init__(
|
392
|
+
self,
|
393
|
+
checkpoint_every: Optional[int] = None,
|
394
|
+
checkpoint_folder: str = "~/.ring_checkpoints",
|
395
|
+
):
|
396
|
+
self.checkpoint_every = checkpoint_every
|
397
|
+
self.checkpoint_folder = checkpoint_folder
|
398
|
+
|
391
399
|
def after_training_step(
|
392
400
|
self,
|
393
401
|
i_episode: int,
|
@@ -401,18 +409,26 @@ class CheckpointCallback(training_loop.TrainingLoopCallback):
|
|
401
409
|
self.params = params
|
402
410
|
self.opt_state = opt_state
|
403
411
|
|
412
|
+
if self.checkpoint_every is not None and (
|
413
|
+
(i_episode % self.checkpoint_every) == 0
|
414
|
+
):
|
415
|
+
self._create_checkpoint()
|
416
|
+
|
417
|
+
def _create_checkpoint(self):
|
418
|
+
path = parse_path(
|
419
|
+
self.checkpoint_folder, ml_utils.unique_id(), extension="pickle"
|
420
|
+
)
|
421
|
+
data = {"params": self.params, "opt_state": self.opt_state}
|
422
|
+
pickle_save(
|
423
|
+
obj=jax.device_get(data),
|
424
|
+
path=path,
|
425
|
+
overwrite=True,
|
426
|
+
)
|
427
|
+
|
404
428
|
def close(self):
|
405
429
|
# only checkpoint if run has been killed
|
406
430
|
if training_loop.recv_kill_run_signal():
|
407
|
-
|
408
|
-
"~/.ring_checkpoints", ml_utils.unique_id(), extension="pickle"
|
409
|
-
)
|
410
|
-
data = {"params": self.params, "opt_state": self.opt_state}
|
411
|
-
pickle_save(
|
412
|
-
obj=jax.device_get(data),
|
413
|
-
path=path,
|
414
|
-
overwrite=True,
|
415
|
-
)
|
431
|
+
self._create_checkpoint()
|
416
432
|
|
417
433
|
|
418
434
|
class WandbKillRun(training_loop.TrainingLoopCallback):
|
ring/ml/training_loop.py
CHANGED
@@ -2,11 +2,12 @@ import random
|
|
2
2
|
from typing import Optional
|
3
3
|
|
4
4
|
import jax
|
5
|
-
from ring.algorithms import Generator
|
6
|
-
from ring.ml import ml_utils
|
7
5
|
import tqdm
|
8
6
|
import tree_utils
|
9
7
|
|
8
|
+
from ring.algorithms import Generator
|
9
|
+
from ring.ml import ml_utils
|
10
|
+
|
10
11
|
_KILL_RUN = False
|
11
12
|
|
12
13
|
|
@@ -83,14 +84,20 @@ class TrainingLoop:
|
|
83
84
|
# reset the kill_run flag from previous runs
|
84
85
|
send_kill_run_signal(value=False)
|
85
86
|
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
if recv_kill_run_signal():
|
90
|
-
break
|
87
|
+
try:
|
88
|
+
for _ in tqdm.tqdm(range(n_episodes)):
|
89
|
+
self.step()
|
91
90
|
|
92
|
-
|
93
|
-
|
91
|
+
if recv_kill_run_signal():
|
92
|
+
break
|
93
|
+
except Exception as e:
|
94
|
+
print(
|
95
|
+
"Exception occured, attemping to .close() all callbacks before raising."
|
96
|
+
)
|
97
|
+
raise e
|
98
|
+
finally:
|
99
|
+
if close_afterwards:
|
100
|
+
self.close()
|
94
101
|
|
95
102
|
return recv_kill_run_signal()
|
96
103
|
|
File without changes
|
File without changes
|