imt-ring 1.6.33__py3-none-any.whl → 1.6.35__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.33.dist-info → imt_ring-1.6.35.dist-info}/METADATA +1 -1
- {imt_ring-1.6.33.dist-info → imt_ring-1.6.35.dist-info}/RECORD +6 -6
- ring/ml/callbacks.py +25 -9
- ring/ml/train.py +5 -4
- {imt_ring-1.6.33.dist-info → imt_ring-1.6.35.dist-info}/WHEEL +0 -0
- {imt_ring-1.6.33.dist-info → imt_ring-1.6.35.dist-info}/top_level.txt +0 -0
@@ -53,12 +53,12 @@ ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,94
|
|
53
53
|
ring/io/xml/to_xml.py,sha256=Wo4iySLw9nM-iVW42AGvMRqjtU2qRc2FD_Zlc7w1IrE,3438
|
54
54
|
ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
|
55
55
|
ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
|
56
|
-
ring/ml/callbacks.py,sha256=
|
56
|
+
ring/ml/callbacks.py,sha256=oCPXl4_Zcw3g0KRgyyUDmdiGxV0phnDVc_t8rEG4Lls,13737
|
57
57
|
ring/ml/ml_utils.py,sha256=M--qkXRnhU7tHvgfTHfT9gyY0nhj3zMGEaK0X0drFLs,10915
|
58
58
|
ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
|
59
59
|
ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
|
60
60
|
ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
|
61
|
-
ring/ml/train.py,sha256
|
61
|
+
ring/ml/train.py,sha256=Da89HxiqXC7xuX2ldpTrJStqKWN-6Vcpml4PPQuihN4,10989
|
62
62
|
ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
63
63
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
64
64
|
ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
|
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
|
|
86
86
|
ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
|
87
87
|
ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
88
88
|
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
89
|
-
imt_ring-1.6.
|
90
|
-
imt_ring-1.6.
|
91
|
-
imt_ring-1.6.
|
92
|
-
imt_ring-1.6.
|
89
|
+
imt_ring-1.6.35.dist-info/METADATA,sha256=mwRs6Hzb4M39Ld4PRamHOYPxtbu-Ug2kOLE03M1l8hI,4251
|
90
|
+
imt_ring-1.6.35.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
91
|
+
imt_ring-1.6.35.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
92
|
+
imt_ring-1.6.35.dist-info/RECORD,,
|
ring/ml/callbacks.py
CHANGED
@@ -388,6 +388,14 @@ class TimingKillRunCallback(training_loop.TrainingLoopCallback):
|
|
388
388
|
|
389
389
|
|
390
390
|
class CheckpointCallback(training_loop.TrainingLoopCallback):
|
391
|
+
def __init__(
|
392
|
+
self,
|
393
|
+
checkpoint_every: Optional[int] = None,
|
394
|
+
checkpoint_folder: str = "~/.ring_checkpoints",
|
395
|
+
):
|
396
|
+
self.checkpoint_every = checkpoint_every
|
397
|
+
self.checkpoint_folder = checkpoint_folder
|
398
|
+
|
391
399
|
def after_training_step(
|
392
400
|
self,
|
393
401
|
i_episode: int,
|
@@ -401,18 +409,26 @@ class CheckpointCallback(training_loop.TrainingLoopCallback):
|
|
401
409
|
self.params = params
|
402
410
|
self.opt_state = opt_state
|
403
411
|
|
412
|
+
if self.checkpoint_every is not None and (
|
413
|
+
(i_episode % self.checkpoint_every) == 0
|
414
|
+
):
|
415
|
+
self._create_checkpoint()
|
416
|
+
|
417
|
+
def _create_checkpoint(self):
|
418
|
+
path = parse_path(
|
419
|
+
self.checkpoint_folder, ml_utils.unique_id(), extension="pickle"
|
420
|
+
)
|
421
|
+
data = {"params": self.params, "opt_state": self.opt_state}
|
422
|
+
pickle_save(
|
423
|
+
obj=jax.device_get(data),
|
424
|
+
path=path,
|
425
|
+
overwrite=True,
|
426
|
+
)
|
427
|
+
|
404
428
|
def close(self):
|
405
429
|
# only checkpoint if run has been killed
|
406
430
|
if training_loop.recv_kill_run_signal():
|
407
|
-
|
408
|
-
"~/.ring_checkpoints", ml_utils.unique_id(), extension="pickle"
|
409
|
-
)
|
410
|
-
data = {"params": self.params, "opt_state": self.opt_state}
|
411
|
-
pickle_save(
|
412
|
-
obj=jax.device_get(data),
|
413
|
-
path=path,
|
414
|
-
overwrite=True,
|
415
|
-
)
|
431
|
+
self._create_checkpoint()
|
416
432
|
|
417
433
|
|
418
434
|
class WandbKillRun(training_loop.TrainingLoopCallback):
|
ring/ml/train.py
CHANGED
@@ -39,6 +39,7 @@ def _build_step_fn(
|
|
39
39
|
filter: ml_base.AbstractFilter,
|
40
40
|
optimizer,
|
41
41
|
tbp,
|
42
|
+
skip_first_tbp_batch,
|
42
43
|
):
|
43
44
|
"""Build step function that optimizes filter parameters based on `metric_fn`.
|
44
45
|
`initial_state` has shape (pmap, vmap, state_dim)"""
|
@@ -89,6 +90,8 @@ def _build_step_fn(
|
|
89
90
|
):
|
90
91
|
(loss, state), grads = pmapped_loss_fn(params, state, X_tbp, y_tbp)
|
91
92
|
debug_grads.append(grads)
|
93
|
+
if skip_first_tbp_batch and i == 0:
|
94
|
+
continue
|
92
95
|
state = jax.lax.stop_gradient(state)
|
93
96
|
params, opt_state = apply_grads(grads, params, opt_state)
|
94
97
|
|
@@ -119,6 +122,7 @@ def train_fn(
|
|
119
122
|
loss_fn: LOSS_FN = _default_loss_fn,
|
120
123
|
metrices: Optional[METRICES] = _default_metrices,
|
121
124
|
link_names: Optional[list[str]] = None,
|
125
|
+
skip_first_tbp_batch: bool = False,
|
122
126
|
) -> bool:
|
123
127
|
"""Trains RNNO
|
124
128
|
|
@@ -161,10 +165,7 @@ def train_fn(
|
|
161
165
|
opt_state = optimizer.init(filter_params)
|
162
166
|
|
163
167
|
step_fn = _build_step_fn(
|
164
|
-
loss_fn,
|
165
|
-
filter,
|
166
|
-
optimizer,
|
167
|
-
tbp=tbp,
|
168
|
+
loss_fn, filter, optimizer, tbp=tbp, skip_first_tbp_batch=skip_first_tbp_batch
|
168
169
|
)
|
169
170
|
|
170
171
|
# always log, because we also want `i_epsiode` to be logged in wandb
|
File without changes
|
File without changes
|