imt-ring 1.3.6__tar.gz → 1.3.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. {imt_ring-1.3.6 → imt_ring-1.3.8}/PKG-INFO +1 -1
  2. {imt_ring-1.3.6 → imt_ring-1.3.8}/pyproject.toml +1 -1
  3. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/imt_ring.egg-info/SOURCES.txt +1 -0
  5. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/generator/batch.py +1 -1
  6. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/generator/motion_artifacts.py +6 -2
  7. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/ml/__init__.py +15 -3
  8. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/ml/callbacks.py +3 -2
  9. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/ml/ringnet.py +10 -2
  10. imt_ring-1.3.8/src/ring/ml/rnno_v1.py +37 -0
  11. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/ml/train.py +9 -8
  12. {imt_ring-1.3.6 → imt_ring-1.3.8}/readme.md +0 -0
  13. {imt_ring-1.3.6 → imt_ring-1.3.8}/setup.cfg +0 -0
  14. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  15. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/imt_ring.egg-info/requires.txt +0 -0
  16. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/imt_ring.egg-info/top_level.txt +0 -0
  17. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/__init__.py +0 -0
  18. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algebra.py +0 -0
  19. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/__init__.py +0 -0
  20. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/_random.py +0 -0
  21. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  22. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  23. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  24. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  25. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/dynamics.py +0 -0
  26. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/generator/__init__.py +0 -0
  27. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/generator/base.py +0 -0
  28. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/generator/pd_control.py +0 -0
  29. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/generator/randomize.py +0 -0
  30. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/generator/transforms.py +0 -0
  31. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/generator/types.py +0 -0
  32. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/jcalc.py +0 -0
  33. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/kinematics.py +0 -0
  34. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/algorithms/sensors.py +0 -0
  35. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/base.py +0 -0
  36. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/__init__.py +0 -0
  37. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/branched.xml +0 -0
  38. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  39. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  40. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  41. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/inv_pendulum.xml +0 -0
  42. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  43. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/spherical_stiff.xml +0 -0
  44. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/symmetric.xml +0 -0
  45. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_all_1.xml +0 -0
  46. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_all_2.xml +0 -0
  47. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  48. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_control.xml +0 -0
  49. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  50. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_free.xml +0 -0
  51. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_kinematics.xml +0 -0
  52. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  53. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  54. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_randomize_position.xml +0 -0
  55. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_sensors.xml +0 -0
  56. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  57. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/examples.py +0 -0
  58. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/test_examples.py +0 -0
  59. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/xml/__init__.py +0 -0
  60. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/xml/abstract.py +0 -0
  61. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/xml/from_xml.py +0 -0
  62. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/xml/test_from_xml.py +0 -0
  63. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/xml/test_to_xml.py +0 -0
  64. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/io/xml/to_xml.py +0 -0
  65. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/maths.py +0 -0
  66. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/ml/base.py +0 -0
  67. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/ml/ml_utils.py +0 -0
  68. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/ml/optimizer.py +0 -0
  69. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  70. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/ml/training_loop.py +0 -0
  71. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/rendering/__init__.py +0 -0
  72. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/rendering/base_render.py +0 -0
  73. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/rendering/mujoco_render.py +0 -0
  74. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/rendering/vispy_render.py +0 -0
  75. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/rendering/vispy_visuals.py +0 -0
  76. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/sim2real/__init__.py +0 -0
  77. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/sim2real/sim2real.py +0 -0
  78. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/spatial.py +0 -0
  79. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/sys_composer/__init__.py +0 -0
  80. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/sys_composer/delete_sys.py +0 -0
  81. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/sys_composer/inject_sys.py +0 -0
  82. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/sys_composer/morph_sys.py +0 -0
  83. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/utils/__init__.py +0 -0
  84. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/utils/batchsize.py +0 -0
  85. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/utils/colab.py +0 -0
  86. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/utils/hdf5.py +0 -0
  87. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/utils/normalizer.py +0 -0
  88. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/utils/path.py +0 -0
  89. {imt_ring-1.3.6 → imt_ring-1.3.8}/src/ring/utils/utils.py +0 -0
  90. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_algebra.py +0 -0
  91. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_base.py +0 -0
  92. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_custom_joints.py +0 -0
  93. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_dynamics.py +0 -0
  94. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_generator.py +0 -0
  95. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_jcalc.py +0 -0
  96. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_jit.py +0 -0
  97. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_kinematics.py +0 -0
  98. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_maths.py +0 -0
  99. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_ml_utils.py +0 -0
  100. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_motion_artifacts.py +0 -0
  101. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_pd_control.py +0 -0
  102. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_random.py +0 -0
  103. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_randomize.py +0 -0
  104. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_rcmg.py +0 -0
  105. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_render.py +0 -0
  106. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_sensors.py +0 -0
  107. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_sim2real.py +0 -0
  108. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_sys_composer.py +0 -0
  109. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_train.py +0 -0
  110. {imt_ring-1.3.6 → imt_ring-1.3.8}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.6
3
+ Version: 1.3.8
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.3.6"
7
+ version = "1.3.8"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.6
3
+ Version: 1.3.8
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -64,6 +64,7 @@ src/ring/ml/callbacks.py
64
64
  src/ring/ml/ml_utils.py
65
65
  src/ring/ml/optimizer.py
66
66
  src/ring/ml/ringnet.py
67
+ src/ring/ml/rnno_v1.py
67
68
  src/ring/ml/train.py
68
69
  src/ring/ml/training_loop.py
69
70
  src/ring/ml/params/0x13e3518065c21cd8.pickle
@@ -147,7 +147,7 @@ def _data_fn_from_paths(
147
147
  paths = [utils.parse_path(p, mkdir=False) for p in paths]
148
148
 
149
149
  extensions = list(set([Path(p).suffix for p in paths]))
150
- assert len(extensions) == 1
150
+ assert len(extensions) == 1, f"{extensions}"
151
151
 
152
152
  if extensions[0] == ".h5":
153
153
  N = sum([utils.hdf5_load_length(p) for p in paths])
@@ -180,9 +180,13 @@ def setup_fn_randomize_damping_stiffness_factory(
180
180
  link_spring_stiffness = link_spring_stiffness.at[slice].set(stif)
181
181
  link_damping = link_damping.at[slice].set(damp)
182
182
 
183
- assert len(imus_surely_rigid) == len(triggered_surely_rigid)
183
+ assert len(imus_surely_rigid) == len(
184
+ triggered_surely_rigid
185
+ ), f"{imus_surely_rigid}, {triggered_surely_rigid}"
184
186
  for imu_surely_rigid in imus_surely_rigid:
185
- assert imu_surely_rigid in triggered_surely_rigid
187
+ assert (
188
+ imu_surely_rigid in triggered_surely_rigid
189
+ ), f"{imus_surely_rigid} not in {triggered_surely_rigid}"
186
190
 
187
191
  return sys.replace(
188
192
  link_damping=link_damping, link_spring_stiffness=link_spring_stiffness
@@ -3,6 +3,7 @@ from . import callbacks
3
3
  from . import ml_utils
4
4
  from . import optimizer
5
5
  from . import ringnet
6
+ from . import rnno_v1
6
7
  from . import train
7
8
  from . import training_loop
8
9
  from .base import AbstractFilter
@@ -42,17 +43,28 @@ def RNNO(
42
43
  params=None,
43
44
  eval: bool = True,
44
45
  samp_freq: float | None = None,
46
+ v1: bool = False,
45
47
  **kwargs,
46
48
  ):
47
49
  assert "message_dim" not in kwargs
48
50
  assert "link_output_normalize" not in kwargs
49
51
  assert "link_output_dim" not in kwargs
50
52
 
53
+ if v1:
54
+ kwargs.update(
55
+ dict(forward_factory=rnno_v1.rnno_v1_forward_factory, output_dim=output_dim)
56
+ )
57
+ else:
58
+ kwargs.update(
59
+ dict(
60
+ message_dim=0,
61
+ link_output_normalize=False,
62
+ link_output_dim=output_dim,
63
+ )
64
+ )
65
+
51
66
  ringnet = RING( # noqa: F811
52
67
  params=params,
53
- message_dim=0,
54
- link_output_normalize=False,
55
- link_output_dim=output_dim,
56
68
  **kwargs,
57
69
  )
58
70
  ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
@@ -245,7 +245,8 @@ class SaveParamsTrainingLoopCallback(training_loop.TrainingLoopCallback):
245
245
  else:
246
246
  value = "{:.2f}".format(ele.value).replace(".", ",")
247
247
  filename = parse_path(
248
- self.path_to_file + f"_episode={ele.episode}_value={value}",
248
+ str(Path(self.path_to_file).with_suffix(""))
249
+ + f"_episode={ele.episode}_value={value}",
249
250
  extension="pickle",
250
251
  )
251
252
 
@@ -404,7 +405,7 @@ class CheckpointCallback(training_loop.TrainingLoopCallback):
404
405
  # only checkpoint if run has been killed
405
406
  if training_loop.recv_kill_run_signal():
406
407
  path = parse_path(
407
- "~/.xxy_checkpoints", ml_utils.unique_id(), extension="pickle"
408
+ "~/.ring_checkpoints", ml_utils.unique_id(), extension="pickle"
408
409
  )
409
410
  data = {"params": self.params, "opt_state": self.opt_state}
410
411
  pickle_save(
@@ -191,8 +191,16 @@ class LSTM(hk.RNNCore):
191
191
 
192
192
 
193
193
  class RING(ml_base.AbstractFilter):
194
- def __init__(self, params=None, lam=None, jit: bool = True, name=None, **kwargs):
195
- self.forward_lam_factory = partial(make_ring, **kwargs)
194
+ def __init__(
195
+ self,
196
+ params=None,
197
+ lam=None,
198
+ jit: bool = True,
199
+ name=None,
200
+ forward_factory=make_ring,
201
+ **kwargs,
202
+ ):
203
+ self.forward_lam_factory = partial(forward_factory, **kwargs)
196
204
  self.params = self._load_params(params)
197
205
  self.lam = lam
198
206
  self._name = name
@@ -0,0 +1,37 @@
1
+ from typing import Sequence
2
+
3
+ import haiku as hk
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+
8
+ def rnno_v1_forward_factory(
9
+ output_dim: int,
10
+ rnn_layers: Sequence[int] = (400, 300),
11
+ linear_layers: Sequence[int] = (200, 100, 50, 50, 25, 25),
12
+ layernorm: bool = True,
13
+ act_fn_linear=jax.nn.relu,
14
+ act_fn_rnn=jax.nn.elu,
15
+ ):
16
+ @hk.without_apply_rng
17
+ @hk.transform_with_state
18
+ def forward_fn(X):
19
+ assert X.shape[-2] == 1
20
+
21
+ for i, n_units in enumerate(rnn_layers):
22
+ state = hk.get_state(f"rnn_{i}", shape=[n_units], init=jnp.zeros)
23
+ X, state = hk.dynamic_unroll(hk.GRU(n_units), X, state)
24
+ hk.set_state(f"rnn_{i}", state)
25
+
26
+ if layernorm:
27
+ X = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)(X)
28
+ X = act_fn_rnn(X)
29
+
30
+ for n_units in linear_layers:
31
+ X = hk.Linear(n_units)(X)
32
+ X = act_fn_linear(X)
33
+
34
+ y = hk.Linear(output_dim)(X)
35
+ return y[..., None, :]
36
+
37
+ return forward_fn
@@ -5,6 +5,8 @@ from typing import Callable, Optional, Tuple
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import optax
8
+ import tree_utils
9
+
8
10
  from ring import maths
9
11
  from ring.algorithms.generator import types
10
12
  from ring.ml import base as ml_base
@@ -13,10 +15,7 @@ from ring.ml import ml_utils
13
15
  from ring.ml import training_loop
14
16
  from ring.utils import distribute_batchsize
15
17
  from ring.utils import expand_batchsize
16
- from ring.utils import parse_path
17
18
  from ring.utils import pickle_load
18
- import tree_utils
19
-
20
19
  import wandb
21
20
 
22
21
  # (T, N, F) -> Scalar
@@ -142,15 +141,17 @@ def train_fn(
142
141
  Whether or not the training run was killed by a callback.
143
142
  """
144
143
 
144
+ filter = filter.nojit()
145
+
145
146
  if checkpoint is not None:
146
147
  checkpoint = Path(checkpoint).with_suffix(".pickle")
147
148
  recv_checkpoint: dict = pickle_load(checkpoint)
148
- filter.params = recv_checkpoint["params"]
149
+ filter_params = recv_checkpoint["params"]
149
150
  opt_state = recv_checkpoint["opt_state"]
151
+ del recv_checkpoint
152
+ else:
153
+ filter_params = filter.search_attr("params")
150
154
 
151
- filter = filter.nojit()
152
-
153
- filter_params = filter.search_attr("params")
154
155
  if filter_params is None:
155
156
  X, _ = generator(jax.random.PRNGKey(1))
156
157
  filter_params, _ = filter.init(X=X, seed=seed_network)
@@ -215,7 +216,7 @@ def train_fn(
215
216
 
216
217
  callbacks_all.append(
217
218
  ml_callbacks.SaveParamsTrainingLoopCallback(
218
- path_to_file=parse_path(callback_save_params, extension=""),
219
+ path_to_file=callback_save_params,
219
220
  last_n_params=3,
220
221
  track_metrices=callback_save_params_track_metrices,
221
222
  cleanup=False,
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes