imt-ring 1.6.15__py3-none-any.whl → 1.6.17__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.15
3
+ Version: 1.6.17
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -32,11 +32,11 @@ Requires-Dist: pytest-xdist ; extra == 'dev'
32
32
  Requires-Dist: nbmake ; extra == 'dev'
33
33
 
34
34
  <p align="center">
35
- <img src="https://raw.githubusercontent.com/SimiPixel/ring/main/docs/img/icon.svg" height="200" />
35
+ <img src="https://raw.githubusercontent.com/simon-bachhuber/ring/main/docs/img/icon.svg" height="200" />
36
36
  </p>
37
37
 
38
38
  # Recurrent Inertial Graph-based Estimator (RING)
39
- <img src="https://raw.githubusercontent.com/SimiPixel/ring/main/docs/img/coverage_badge.svg" height="20" />
39
+ <img src="https://raw.githubusercontent.com/simon-bachhuber/ring/main/docs/img/coverage_badge.svg" height="20" />
40
40
 
41
41
  ## Installation
42
42
 
@@ -1,4 +1,4 @@
1
- ring/__init__.py,sha256=k7tL-XgggUwWxHCXyv60rQn-OcXHPg82QcIUkKLEd-c,5057
1
+ ring/__init__.py,sha256=ncBRdHvge6uqpzFyk8_HUQNx4kZxENpVXTJTEK_SjCg,5216
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
3
  ring/base.py,sha256=Ystn1EjTyOXBhVm5koroV_YPUYtFxrteJLd-XR3kEL8,33840
4
4
  ring/maths.py,sha256=qPHH6TpHCK3TgExI98gNEySoSRKOwteN9McUlyUFipI,12207
@@ -6,17 +6,18 @@ ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
7
  ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
8
8
  ring/algorithms/dynamics.py,sha256=GOedL1STj6oXcXgMA7dB4PabvCQxPBbirJQhXBRuKqE,10929
9
- ring/algorithms/jcalc.py,sha256=bwfVH3qKEnUs6RFgEEeUBnecpBt-nf8cesJbNGDrE7g,28974
9
+ ring/algorithms/jcalc.py,sha256=kQgrRE0XoUBrcdeFbUw_xKnL4m1P2G1Q1m4n7BX2yDk,30064
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=0xOzdQIc1kBF0CkoPXWWCx3MmV4SG3wj7knVnnMWq9M,18124
12
- ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
12
+ ring/algorithms/custom_joints/__init__.py,sha256=3pQ-Is_HBTQDkzESCNg9VfoP8wvseWmooryG8ERnu_A,366
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
+ ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
15
16
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
17
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
17
18
  ring/algorithms/generator/base.py,sha256=vxUdA0ZeSNH3SOanL51qVRvCiJrmsWQyQX0g2fdm3Rg,15825
18
19
  ring/algorithms/generator/batch.py,sha256=9yFxVv11hij-fJXGPxA3zEh1bE2_jrZk0R7kyGaiM5c,2551
19
- ring/algorithms/generator/finalize_fns.py,sha256=559sGXs06n46p-eme0SE8hn0lXwGT0P2r3-52ElTldo,9861
20
+ ring/algorithms/generator/finalize_fns.py,sha256=EZ5p7fuZu0Zd0rHJzCVg3vy3U9ysny6TfQfolIGPERc,10029
20
21
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
21
22
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
22
23
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
@@ -55,7 +56,7 @@ ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
55
56
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
57
  ring/ml/ml_utils.py,sha256=1GXJfeoXbwCbRdYA2np3CbJpSupaw4eyf3quh9y4BO0,6462
57
58
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
- ring/ml/ringnet.py,sha256=Oud23uKmcvFtwNKdEu2KMMvNAFzJM_yBSRNz2a3CjL4,8670
59
+ ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
59
60
  ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
60
61
  ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
61
62
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
@@ -84,7 +85,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
84
85
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
85
86
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
86
87
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
87
- imt_ring-1.6.15.dist-info/METADATA,sha256=zG-f_woph73I5ErczEeJaYXZLNahC-oNXDtzj26f1Po,3821
88
- imt_ring-1.6.15.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
89
- imt_ring-1.6.15.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
90
- imt_ring-1.6.15.dist-info/RECORD,,
88
+ imt_ring-1.6.17.dist-info/METADATA,sha256=j0IKIyc6qAgz9059Z4-b46hi7qTn_ffI7oLw_OrD_Tk,3833
89
+ imt_ring-1.6.17.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
90
+ imt_ring-1.6.17.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
91
+ imt_ring-1.6.17.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.1.0)
2
+ Generator: setuptools (75.3.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
ring/__init__.py CHANGED
@@ -121,6 +121,7 @@ _UNIQUE_ID = None
121
121
  def setup(
122
122
  rr_joint_kwargs: None | dict = dict(),
123
123
  rr_imp_joint_kwargs: None | dict = dict(),
124
+ rsaddle_joint_kwargs: None | dict = dict(),
124
125
  suntay_joint_kwargs: None | dict = None,
125
126
  train_timing_start: None | float = None,
126
127
  unique_id: None | str = None,
@@ -138,6 +139,9 @@ def setup(
138
139
  if rr_imp_joint_kwargs is not None:
139
140
  custom_joints.register_rr_imp_joint(**rr_imp_joint_kwargs)
140
141
 
142
+ if rsaddle_joint_kwargs is not None:
143
+ custom_joints.register_rsaddle_joint(**rsaddle_joint_kwargs)
144
+
141
145
  if suntay_joint_kwargs is not None:
142
146
  custom_joints.register_suntay(**suntay_joint_kwargs)
143
147
 
@@ -1,5 +1,6 @@
1
1
  from .rr_imp_joint import register_rr_imp_joint
2
2
  from .rr_joint import register_rr_joint
3
+ from .rsaddle_joint import register_rsaddle_joint
3
4
  from .suntay import ConstantValue_DrawnFnPair
4
5
  from .suntay import GP_DrawFnPair
5
6
  from .suntay import MLP_DrawnFnPair
@@ -0,0 +1,40 @@
1
+ import jax.numpy as jnp
2
+
3
+ import ring
4
+ from ring import maths
5
+ from ring.algorithms.jcalc import _draw_saddle
6
+ from ring.algorithms.jcalc import _p_control_term_rxyz
7
+ from ring.algorithms.jcalc import _qd_from_q_cartesian
8
+
9
+
10
+ def register_rsaddle_joint():
11
+ def _transform(q, params):
12
+ axes = params["joint_axes"]
13
+ rot1 = maths.quat_rot_axis(axes[0], q[0])
14
+ rot2 = maths.quat_rot_axis(axes[1], q[1])
15
+ rot = maths.quat_mul(rot2, rot1)
16
+ return ring.Transform.create(rot=rot)
17
+
18
+ def _motion_fn_gen(i: int):
19
+ def _motion_fn(params):
20
+ axis = params["joint_axes"][i]
21
+ return ring.base.Motion.create(ang=axis)
22
+
23
+ return _motion_fn
24
+
25
+ joint_model = ring.JointModel(
26
+ _transform,
27
+ motion=[_motion_fn_gen(i) for i in range(2)],
28
+ rcmg_draw_fn=_draw_saddle,
29
+ p_control_term=_p_control_term_rxyz,
30
+ qd_from_q=_qd_from_q_cartesian,
31
+ init_joint_params=_draw_random_joint_axes,
32
+ )
33
+
34
+ ring.register_new_joint_type("rsaddle", joint_model, 2, overwrite=True)
35
+
36
+
37
+ def _draw_random_joint_axes(key):
38
+ return dict(
39
+ joint_axes=maths.rotate(jnp.array([1.0, 0, 0]), maths.quat_random(key, (2,)))
40
+ )
@@ -160,6 +160,7 @@ _P_gains = {
160
160
  "spherical": jnp.array(3 * [P_rot]),
161
161
  "p3d": jnp.array(3 * [P_pos]),
162
162
  "saddle": jnp.array([P_rot, P_rot]),
163
+ "rsaddle": jnp.array([P_rot, P_rot]),
163
164
  "frozen": jnp.array([]),
164
165
  "suntay": jnp.array([P_rot]),
165
166
  }
@@ -182,13 +183,16 @@ class DynamicalSimulation:
182
183
 
183
184
  @staticmethod
184
185
  def assert_test_system(sys: base.System) -> None:
185
- "test that system has no zero mass bodies and no joints without damping"
186
+ "test that system has no zero mass leaf bodies and no joints without damping"
186
187
 
187
188
  def f(_, __, n, m, d):
188
- assert d.size == 0 or m > 0, (
189
- "Dynamic simulation is set to `True` which requires masses >= 0, "
190
- f"but found body `{n}` with mass={float(m[0])}. This can lead to NaNs."
191
- )
189
+ is_leaf_body = len(sys.children(n)) == 0
190
+ if is_leaf_body:
191
+ assert d.size == 0 or m > 0, (
192
+ "Dynamic simulation is set to `True` which requires masses >= 0, "
193
+ f"but found body `{n}` with mass={float(m[0])}. This can lead to "
194
+ "NaNs."
195
+ )
192
196
 
193
197
  assert d.size == 0 or all(d > 0.0), (
194
198
  "Dynamic simulation is set to `True` which requires dampings > 0, "
ring/algorithms/jcalc.py CHANGED
@@ -1,3 +1,4 @@
1
+ from collections import defaultdict
1
2
  from dataclasses import asdict
2
3
  from dataclasses import dataclass
3
4
  from dataclasses import field
@@ -87,11 +88,39 @@ class MotionConfig:
87
88
  assert nomotion_config.is_feasible()
88
89
  return nomotion_config
89
90
 
91
+ @staticmethod
92
+ def overwrite_for_joint_type(joint_type: str, **changes) -> None:
93
+ """Changes values of the `MotionConfig` used by the draw_fn for only a specific
94
+ joint.
95
+ """
96
+ previous_changes = _overwrite_for_joint_type_changes[joint_type]
97
+ for change in changes:
98
+ assert change not in previous_changes, f"For jointtype={joint_type} you "
99
+ f"previously changed the value={change}. You can't change it again, this "
100
+ "is not supported."
101
+ previous_changes.update(changes)
102
+
103
+ jm = get_joint_model(joint_type)
104
+
105
+ def draw_fn(config, *args):
106
+ return jm.rcmg_draw_fn(replace(config, **changes), *args)
107
+
108
+ register_new_joint_type(
109
+ joint_type,
110
+ replace(jm, rcmg_draw_fn=draw_fn),
111
+ base.Q_WIDTHS[joint_type],
112
+ base.QD_WIDTHS[joint_type],
113
+ overwrite=True,
114
+ )
115
+
90
116
  @staticmethod
91
117
  def from_register(name: str) -> "MotionConfig":
92
118
  return _registered_motion_configs[name]
93
119
 
94
120
 
121
+ _overwrite_for_joint_type_changes: dict[str, dict] = defaultdict(lambda: dict())
122
+
123
+
95
124
  _registered_motion_configs = {
96
125
  "hinUndHer": MotionConfig(
97
126
  t_min=0.3,
@@ -222,7 +251,9 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
222
251
  def inside_box_checks(x_min, x_max, x0_min, x0_max) -> bool:
223
252
  return (x0_min >= x_min) and (x0_max <= x_max)
224
253
 
225
- cond2 = inside_box_checks(c.pos_min, c.pos_max, c.pos0_min, c.pos0_max)
254
+ cond2 = inside_box_checks(
255
+ _to_float(c.pos_min, 0.0), _to_float(c.pos_max, 0.0), c.pos0_min, c.pos0_max
256
+ )
226
257
 
227
258
  return cond1 and cond2
228
259
 
ring/ml/ringnet.py CHANGED
@@ -87,6 +87,7 @@ def make_ring(
87
87
  link_output_normalize: bool = True,
88
88
  link_output_transform: Optional[Callable] = None,
89
89
  layernorm: bool = True,
90
+ layernorm_trainable: bool = True,
90
91
  ) -> SimpleNamespace:
91
92
 
92
93
  if link_output_normalize:
@@ -104,7 +105,11 @@ def make_ring(
104
105
  )
105
106
 
106
107
  inner_cell = StackedRNNCell(
107
- celltype, hidden_state_dim, stack_rnn_cells, layernorm=layernorm
108
+ celltype,
109
+ hidden_state_dim,
110
+ stack_rnn_cells,
111
+ layernorm=layernorm,
112
+ layernorm_trainable=layernorm_trainable,
108
113
  )
109
114
  send_output = hk.nets.MLP([hidden_state_dim, link_output_dim])
110
115
  state = hk.get_state(
@@ -143,6 +148,7 @@ class StackedRNNCell(hk.Module):
143
148
  hidden_state_dim,
144
149
  stacks: int,
145
150
  layernorm: bool = False,
151
+ layernorm_trainable: bool = True,
146
152
  name: str | None = None,
147
153
  ):
148
154
  super().__init__(name)
@@ -150,6 +156,7 @@ class StackedRNNCell(hk.Module):
150
156
 
151
157
  self.cells = [cell(hidden_state_dim) for _ in range(stacks)]
152
158
  self.layernorm = layernorm
159
+ self.layernorm_trainable = layernorm_trainable
153
160
 
154
161
  def __call__(self, x, state):
155
162
  output = x
@@ -159,7 +166,9 @@ class StackedRNNCell(hk.Module):
159
166
  next_state.append(next_state_i)
160
167
 
161
168
  if self.layernorm:
162
- output = hk.LayerNorm(-1, True, True)(output)
169
+ output = hk.LayerNorm(
170
+ -1, self.layernorm_trainable, self.layernorm_trainable
171
+ )(output)
163
172
 
164
173
  return output, jnp.stack(next_state)
165
174