imt-ring 1.6.33__tar.gz → 1.6.35__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (119) hide show
  1. {imt_ring-1.6.33 → imt_ring-1.6.35}/PKG-INFO +1 -1
  2. {imt_ring-1.6.33 → imt_ring-1.6.35}/pyproject.toml +1 -1
  3. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/ml/callbacks.py +25 -9
  5. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/ml/train.py +5 -4
  6. {imt_ring-1.6.33 → imt_ring-1.6.35}/readme.md +0 -0
  7. {imt_ring-1.6.33 → imt_ring-1.6.35}/setup.cfg +0 -0
  8. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/imt_ring.egg-info/SOURCES.txt +0 -0
  9. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  10. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/imt_ring.egg-info/requires.txt +0 -0
  11. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/imt_ring.egg-info/top_level.txt +0 -0
  12. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/__init__.py +0 -0
  13. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algebra.py +0 -0
  14. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/__init__.py +0 -0
  15. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/_random.py +0 -0
  16. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  17. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  18. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  19. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/custom_joints/rsaddle_joint.py +0 -0
  20. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  21. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/dynamics.py +0 -0
  22. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/generator/__init__.py +0 -0
  23. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/generator/base.py +0 -0
  24. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/generator/batch.py +0 -0
  25. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/generator/finalize_fns.py +0 -0
  26. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  27. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/generator/pd_control.py +0 -0
  28. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/generator/setup_fns.py +0 -0
  29. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/generator/types.py +0 -0
  30. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/jcalc.py +0 -0
  31. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/kinematics.py +0 -0
  32. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/algorithms/sensors.py +0 -0
  33. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/base.py +0 -0
  34. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/__init__.py +0 -0
  35. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/branched.xml +0 -0
  36. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  37. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  38. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  39. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/inv_pendulum.xml +0 -0
  40. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  41. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/spherical_stiff.xml +0 -0
  42. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/symmetric.xml +0 -0
  43. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_all_1.xml +0 -0
  44. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_all_2.xml +0 -0
  45. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  46. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_control.xml +0 -0
  47. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  48. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_free.xml +0 -0
  49. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_kinematics.xml +0 -0
  50. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  51. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  52. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_randomize_position.xml +0 -0
  53. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_sensors.xml +0 -0
  54. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  55. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/examples.py +0 -0
  56. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/test_examples.py +0 -0
  57. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/xml/__init__.py +0 -0
  58. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/xml/abstract.py +0 -0
  59. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/xml/from_xml.py +0 -0
  60. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/xml/test_from_xml.py +0 -0
  61. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/xml/test_to_xml.py +0 -0
  62. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/io/xml/to_xml.py +0 -0
  63. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/maths.py +0 -0
  64. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/ml/__init__.py +0 -0
  65. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/ml/base.py +0 -0
  66. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/ml/ml_utils.py +0 -0
  67. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/ml/optimizer.py +0 -0
  68. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  69. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  70. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/ml/ringnet.py +0 -0
  71. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/ml/rnno_v1.py +0 -0
  72. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/ml/training_loop.py +0 -0
  73. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/rendering/__init__.py +0 -0
  74. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/rendering/base_render.py +0 -0
  75. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/rendering/mujoco_render.py +0 -0
  76. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/rendering/vispy_render.py +0 -0
  77. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/rendering/vispy_visuals.py +0 -0
  78. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/sim2real/__init__.py +0 -0
  79. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/sim2real/sim2real.py +0 -0
  80. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/spatial.py +0 -0
  81. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/sys_composer/__init__.py +0 -0
  82. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/sys_composer/delete_sys.py +0 -0
  83. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/sys_composer/inject_sys.py +0 -0
  84. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/sys_composer/morph_sys.py +0 -0
  85. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/__init__.py +0 -0
  86. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/backend.py +0 -0
  87. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/batchsize.py +0 -0
  88. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/colab.py +0 -0
  89. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/dataloader.py +0 -0
  90. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/dataloader_torch.py +0 -0
  91. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/hdf5.py +0 -0
  92. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/normalizer.py +0 -0
  93. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/path.py +0 -0
  94. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/randomize_sys.py +0 -0
  95. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/register_gym_envs/__init__.py +0 -0
  96. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/register_gym_envs/saddle.py +0 -0
  97. {imt_ring-1.6.33 → imt_ring-1.6.35}/src/ring/utils/utils.py +0 -0
  98. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_algebra.py +0 -0
  99. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_base.py +0 -0
  100. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_custom_joints.py +0 -0
  101. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_dynamics.py +0 -0
  102. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_generator.py +0 -0
  103. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_jcalc.py +0 -0
  104. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_jit.py +0 -0
  105. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_kinematics.py +0 -0
  106. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_maths.py +0 -0
  107. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_ml_utils.py +0 -0
  108. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_motion_artifacts.py +0 -0
  109. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_pd_control.py +0 -0
  110. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_quickstart_example.py +0 -0
  111. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_random.py +0 -0
  112. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_randomize.py +0 -0
  113. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_rcmg.py +0 -0
  114. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_render.py +0 -0
  115. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_sensors.py +0 -0
  116. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_sim2real.py +0 -0
  117. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_sys_composer.py +0 -0
  118. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_train.py +0 -0
  119. {imt_ring-1.6.33 → imt_ring-1.6.35}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.33
3
+ Version: 1.6.35
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.6.33"
7
+ version = "1.6.35"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.33
3
+ Version: 1.6.35
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -388,6 +388,14 @@ class TimingKillRunCallback(training_loop.TrainingLoopCallback):
388
388
 
389
389
 
390
390
  class CheckpointCallback(training_loop.TrainingLoopCallback):
391
+ def __init__(
392
+ self,
393
+ checkpoint_every: Optional[int] = None,
394
+ checkpoint_folder: str = "~/.ring_checkpoints",
395
+ ):
396
+ self.checkpoint_every = checkpoint_every
397
+ self.checkpoint_folder = checkpoint_folder
398
+
391
399
  def after_training_step(
392
400
  self,
393
401
  i_episode: int,
@@ -401,18 +409,26 @@ class CheckpointCallback(training_loop.TrainingLoopCallback):
401
409
  self.params = params
402
410
  self.opt_state = opt_state
403
411
 
412
+ if self.checkpoint_every is not None and (
413
+ (i_episode % self.checkpoint_every) == 0
414
+ ):
415
+ self._create_checkpoint()
416
+
417
+ def _create_checkpoint(self):
418
+ path = parse_path(
419
+ self.checkpoint_folder, ml_utils.unique_id(), extension="pickle"
420
+ )
421
+ data = {"params": self.params, "opt_state": self.opt_state}
422
+ pickle_save(
423
+ obj=jax.device_get(data),
424
+ path=path,
425
+ overwrite=True,
426
+ )
427
+
404
428
  def close(self):
405
429
  # only checkpoint if run has been killed
406
430
  if training_loop.recv_kill_run_signal():
407
- path = parse_path(
408
- "~/.ring_checkpoints", ml_utils.unique_id(), extension="pickle"
409
- )
410
- data = {"params": self.params, "opt_state": self.opt_state}
411
- pickle_save(
412
- obj=jax.device_get(data),
413
- path=path,
414
- overwrite=True,
415
- )
431
+ self._create_checkpoint()
416
432
 
417
433
 
418
434
  class WandbKillRun(training_loop.TrainingLoopCallback):
@@ -39,6 +39,7 @@ def _build_step_fn(
39
39
  filter: ml_base.AbstractFilter,
40
40
  optimizer,
41
41
  tbp,
42
+ skip_first_tbp_batch,
42
43
  ):
43
44
  """Build step function that optimizes filter parameters based on `metric_fn`.
44
45
  `initial_state` has shape (pmap, vmap, state_dim)"""
@@ -89,6 +90,8 @@ def _build_step_fn(
89
90
  ):
90
91
  (loss, state), grads = pmapped_loss_fn(params, state, X_tbp, y_tbp)
91
92
  debug_grads.append(grads)
93
+ if skip_first_tbp_batch and i == 0:
94
+ continue
92
95
  state = jax.lax.stop_gradient(state)
93
96
  params, opt_state = apply_grads(grads, params, opt_state)
94
97
 
@@ -119,6 +122,7 @@ def train_fn(
119
122
  loss_fn: LOSS_FN = _default_loss_fn,
120
123
  metrices: Optional[METRICES] = _default_metrices,
121
124
  link_names: Optional[list[str]] = None,
125
+ skip_first_tbp_batch: bool = False,
122
126
  ) -> bool:
123
127
  """Trains RNNO
124
128
 
@@ -161,10 +165,7 @@ def train_fn(
161
165
  opt_state = optimizer.init(filter_params)
162
166
 
163
167
  step_fn = _build_step_fn(
164
- loss_fn,
165
- filter,
166
- optimizer,
167
- tbp=tbp,
168
+ loss_fn, filter, optimizer, tbp=tbp, skip_first_tbp_batch=skip_first_tbp_batch
168
169
  )
169
170
 
170
171
  # always log, because we also want `i_epsiode` to be logged in wandb
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes