imt-ring 1.6.34__py3-none-any.whl → 1.6.36__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.34
3
+ Version: 1.6.36
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -53,13 +53,13 @@ ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,94
53
53
  ring/io/xml/to_xml.py,sha256=Wo4iySLw9nM-iVW42AGvMRqjtU2qRc2FD_Zlc7w1IrE,3438
54
54
  ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
55
55
  ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
56
- ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
+ ring/ml/callbacks.py,sha256=oCPXl4_Zcw3g0KRgyyUDmdiGxV0phnDVc_t8rEG4Lls,13737
57
57
  ring/ml/ml_utils.py,sha256=M--qkXRnhU7tHvgfTHfT9gyY0nhj3zMGEaK0X0drFLs,10915
58
58
  ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
59
59
  ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
60
60
  ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
61
61
  ring/ml/train.py,sha256=Da89HxiqXC7xuX2ldpTrJStqKWN-6Vcpml4PPQuihN4,10989
62
- ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
62
+ ring/ml/training_loop.py,sha256=yxuUua_4RExq_0GUYm4eUZJsBmtrwDSVL94bWUpYfdo,3586
63
63
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
64
64
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
65
65
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
86
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.34.dist-info/METADATA,sha256=D7FXQFI8b4iXaJiWkM4MvHNqzncvFh_wn4UpVK8iqMs,4251
90
- imt_ring-1.6.34.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
- imt_ring-1.6.34.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.34.dist-info/RECORD,,
89
+ imt_ring-1.6.36.dist-info/METADATA,sha256=a-uW_s0jWJEBX9kW1q36Br4SNXPK7eGIVhlsyKDWruE,4251
90
+ imt_ring-1.6.36.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
+ imt_ring-1.6.36.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.36.dist-info/RECORD,,
ring/ml/callbacks.py CHANGED
@@ -388,6 +388,14 @@ class TimingKillRunCallback(training_loop.TrainingLoopCallback):
388
388
 
389
389
 
390
390
  class CheckpointCallback(training_loop.TrainingLoopCallback):
391
+ def __init__(
392
+ self,
393
+ checkpoint_every: Optional[int] = None,
394
+ checkpoint_folder: str = "~/.ring_checkpoints",
395
+ ):
396
+ self.checkpoint_every = checkpoint_every
397
+ self.checkpoint_folder = checkpoint_folder
398
+
391
399
  def after_training_step(
392
400
  self,
393
401
  i_episode: int,
@@ -401,18 +409,26 @@ class CheckpointCallback(training_loop.TrainingLoopCallback):
401
409
  self.params = params
402
410
  self.opt_state = opt_state
403
411
 
412
+ if self.checkpoint_every is not None and (
413
+ (i_episode % self.checkpoint_every) == 0
414
+ ):
415
+ self._create_checkpoint()
416
+
417
+ def _create_checkpoint(self):
418
+ path = parse_path(
419
+ self.checkpoint_folder, ml_utils.unique_id(), extension="pickle"
420
+ )
421
+ data = {"params": self.params, "opt_state": self.opt_state}
422
+ pickle_save(
423
+ obj=jax.device_get(data),
424
+ path=path,
425
+ overwrite=True,
426
+ )
427
+
404
428
  def close(self):
405
429
  # only checkpoint if run has been killed
406
430
  if training_loop.recv_kill_run_signal():
407
- path = parse_path(
408
- "~/.ring_checkpoints", ml_utils.unique_id(), extension="pickle"
409
- )
410
- data = {"params": self.params, "opt_state": self.opt_state}
411
- pickle_save(
412
- obj=jax.device_get(data),
413
- path=path,
414
- overwrite=True,
415
- )
431
+ self._create_checkpoint()
416
432
 
417
433
 
418
434
  class WandbKillRun(training_loop.TrainingLoopCallback):
ring/ml/training_loop.py CHANGED
@@ -2,11 +2,12 @@ import random
2
2
  from typing import Optional
3
3
 
4
4
  import jax
5
- from ring.algorithms import Generator
6
- from ring.ml import ml_utils
7
5
  import tqdm
8
6
  import tree_utils
9
7
 
8
+ from ring.algorithms import Generator
9
+ from ring.ml import ml_utils
10
+
10
11
  _KILL_RUN = False
11
12
 
12
13
 
@@ -83,14 +84,20 @@ class TrainingLoop:
83
84
  # reset the kill_run flag from previous runs
84
85
  send_kill_run_signal(value=False)
85
86
 
86
- for _ in tqdm.tqdm(range(n_episodes)):
87
- self.step()
88
-
89
- if recv_kill_run_signal():
90
- break
87
+ try:
88
+ for _ in tqdm.tqdm(range(n_episodes)):
89
+ self.step()
91
90
 
92
- if close_afterwards:
93
- self.close()
91
+ if recv_kill_run_signal():
92
+ break
93
+ except Exception as e:
94
+ print(
95
+ "Exception occured, attemping to .close() all callbacks before raising."
96
+ )
97
+ raise e
98
+ finally:
99
+ if close_afterwards:
100
+ self.close()
94
101
 
95
102
  return recv_kill_run_signal()
96
103