imt-ring 1.6.34__py3-none-any.whl → 1.6.35__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.34
3
+ Version: 1.6.35
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -53,7 +53,7 @@ ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,94
53
53
  ring/io/xml/to_xml.py,sha256=Wo4iySLw9nM-iVW42AGvMRqjtU2qRc2FD_Zlc7w1IrE,3438
54
54
  ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
55
55
  ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
56
- ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
+ ring/ml/callbacks.py,sha256=oCPXl4_Zcw3g0KRgyyUDmdiGxV0phnDVc_t8rEG4Lls,13737
57
57
  ring/ml/ml_utils.py,sha256=M--qkXRnhU7tHvgfTHfT9gyY0nhj3zMGEaK0X0drFLs,10915
58
58
  ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
59
59
  ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
86
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.34.dist-info/METADATA,sha256=D7FXQFI8b4iXaJiWkM4MvHNqzncvFh_wn4UpVK8iqMs,4251
90
- imt_ring-1.6.34.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
- imt_ring-1.6.34.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.34.dist-info/RECORD,,
89
+ imt_ring-1.6.35.dist-info/METADATA,sha256=mwRs6Hzb4M39Ld4PRamHOYPxtbu-Ug2kOLE03M1l8hI,4251
90
+ imt_ring-1.6.35.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
+ imt_ring-1.6.35.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.35.dist-info/RECORD,,
ring/ml/callbacks.py CHANGED
@@ -388,6 +388,14 @@ class TimingKillRunCallback(training_loop.TrainingLoopCallback):
388
388
 
389
389
 
390
390
  class CheckpointCallback(training_loop.TrainingLoopCallback):
391
+ def __init__(
392
+ self,
393
+ checkpoint_every: Optional[int] = None,
394
+ checkpoint_folder: str = "~/.ring_checkpoints",
395
+ ):
396
+ self.checkpoint_every = checkpoint_every
397
+ self.checkpoint_folder = checkpoint_folder
398
+
391
399
  def after_training_step(
392
400
  self,
393
401
  i_episode: int,
@@ -401,18 +409,26 @@ class CheckpointCallback(training_loop.TrainingLoopCallback):
401
409
  self.params = params
402
410
  self.opt_state = opt_state
403
411
 
412
+ if self.checkpoint_every is not None and (
413
+ (i_episode % self.checkpoint_every) == 0
414
+ ):
415
+ self._create_checkpoint()
416
+
417
+ def _create_checkpoint(self):
418
+ path = parse_path(
419
+ self.checkpoint_folder, ml_utils.unique_id(), extension="pickle"
420
+ )
421
+ data = {"params": self.params, "opt_state": self.opt_state}
422
+ pickle_save(
423
+ obj=jax.device_get(data),
424
+ path=path,
425
+ overwrite=True,
426
+ )
427
+
404
428
  def close(self):
405
429
  # only checkpoint if run has been killed
406
430
  if training_loop.recv_kill_run_signal():
407
- path = parse_path(
408
- "~/.ring_checkpoints", ml_utils.unique_id(), extension="pickle"
409
- )
410
- data = {"params": self.params, "opt_state": self.opt_state}
411
- pickle_save(
412
- obj=jax.device_get(data),
413
- path=path,
414
- overwrite=True,
415
- )
431
+ self._create_checkpoint()
416
432
 
417
433
 
418
434
  class WandbKillRun(training_loop.TrainingLoopCallback):