imt-ring 1.6.21__tar.gz → 1.6.23__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (119) hide show
  1. {imt_ring-1.6.21 → imt_ring-1.6.23}/PKG-INFO +1 -1
  2. {imt_ring-1.6.21 → imt_ring-1.6.23}/pyproject.toml +1 -1
  3. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/_random.py +47 -34
  5. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/dynamics.py +4 -4
  6. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/generator/base.py +10 -1
  7. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/jcalc.py +158 -6
  8. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/base.py +43 -2
  9. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/xml/abstract.py +2 -1
  10. {imt_ring-1.6.21 → imt_ring-1.6.23}/readme.md +0 -0
  11. {imt_ring-1.6.21 → imt_ring-1.6.23}/setup.cfg +0 -0
  12. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/imt_ring.egg-info/SOURCES.txt +0 -0
  13. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  14. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/imt_ring.egg-info/requires.txt +0 -0
  15. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/imt_ring.egg-info/top_level.txt +0 -0
  16. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/__init__.py +0 -0
  17. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algebra.py +0 -0
  18. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/__init__.py +0 -0
  19. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  20. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  21. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  22. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/custom_joints/rsaddle_joint.py +0 -0
  23. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  24. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/generator/__init__.py +0 -0
  25. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/generator/batch.py +0 -0
  26. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/generator/finalize_fns.py +0 -0
  27. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  28. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/generator/pd_control.py +0 -0
  29. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/generator/setup_fns.py +0 -0
  30. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/generator/types.py +0 -0
  31. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/kinematics.py +0 -0
  32. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/algorithms/sensors.py +0 -0
  33. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/__init__.py +0 -0
  34. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/branched.xml +0 -0
  35. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  36. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  37. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  38. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/inv_pendulum.xml +0 -0
  39. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  40. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/spherical_stiff.xml +0 -0
  41. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/symmetric.xml +0 -0
  42. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_all_1.xml +0 -0
  43. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_all_2.xml +0 -0
  44. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  45. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_control.xml +0 -0
  46. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  47. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_free.xml +0 -0
  48. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_kinematics.xml +0 -0
  49. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  50. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  51. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_randomize_position.xml +0 -0
  52. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_sensors.xml +0 -0
  53. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  54. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/examples.py +0 -0
  55. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/test_examples.py +0 -0
  56. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/xml/__init__.py +0 -0
  57. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/xml/from_xml.py +0 -0
  58. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/xml/test_from_xml.py +0 -0
  59. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/xml/test_to_xml.py +0 -0
  60. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/io/xml/to_xml.py +0 -0
  61. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/maths.py +0 -0
  62. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/ml/__init__.py +0 -0
  63. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/ml/base.py +0 -0
  64. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/ml/callbacks.py +0 -0
  65. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/ml/ml_utils.py +0 -0
  66. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/ml/optimizer.py +0 -0
  67. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  68. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  69. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/ml/ringnet.py +0 -0
  70. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/ml/rnno_v1.py +0 -0
  71. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/ml/train.py +0 -0
  72. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/ml/training_loop.py +0 -0
  73. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/rendering/__init__.py +0 -0
  74. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/rendering/base_render.py +0 -0
  75. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/rendering/mujoco_render.py +0 -0
  76. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/rendering/vispy_render.py +0 -0
  77. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/rendering/vispy_visuals.py +0 -0
  78. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/sim2real/__init__.py +0 -0
  79. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/sim2real/sim2real.py +0 -0
  80. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/spatial.py +0 -0
  81. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/sys_composer/__init__.py +0 -0
  82. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/sys_composer/delete_sys.py +0 -0
  83. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/sys_composer/inject_sys.py +0 -0
  84. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/sys_composer/morph_sys.py +0 -0
  85. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/__init__.py +0 -0
  86. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/backend.py +0 -0
  87. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/batchsize.py +0 -0
  88. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/colab.py +0 -0
  89. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/dataloader.py +0 -0
  90. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/dataloader_torch.py +0 -0
  91. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/hdf5.py +0 -0
  92. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/normalizer.py +0 -0
  93. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/path.py +0 -0
  94. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/randomize_sys.py +0 -0
  95. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/register_gym_envs/__init__.py +0 -0
  96. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/register_gym_envs/saddle.py +0 -0
  97. {imt_ring-1.6.21 → imt_ring-1.6.23}/src/ring/utils/utils.py +0 -0
  98. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_algebra.py +0 -0
  99. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_base.py +0 -0
  100. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_custom_joints.py +0 -0
  101. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_dynamics.py +0 -0
  102. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_generator.py +0 -0
  103. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_jcalc.py +0 -0
  104. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_jit.py +0 -0
  105. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_kinematics.py +0 -0
  106. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_maths.py +0 -0
  107. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_ml_utils.py +0 -0
  108. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_motion_artifacts.py +0 -0
  109. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_pd_control.py +0 -0
  110. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_quickstart_example.py +0 -0
  111. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_random.py +0 -0
  112. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_randomize.py +0 -0
  113. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_rcmg.py +0 -0
  114. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_render.py +0 -0
  115. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_sensors.py +0 -0
  116. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_sim2real.py +0 -0
  117. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_sys_composer.py +0 -0
  118. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_train.py +0 -0
  119. {imt_ring-1.6.21 → imt_ring-1.6.23}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.21
3
+ Version: 1.6.23
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.6.21"
7
+ version = "1.6.23"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.21
3
+ Version: 1.6.23
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -35,6 +35,9 @@ def random_angle_over_time(
35
35
  randomized_interpolation: bool = False,
36
36
  range_of_motion: bool = False,
37
37
  range_of_motion_method: str = "uniform",
38
+ # this value has nothing to do with `range_of_motion` flag
39
+ # this forces the value to stay within [ANG_0 - rom_halfsize, ANG_0 + rom_halfsize]
40
+ rom_halfsize: float | TimeDependentFloat = 2 * jnp.pi,
38
41
  cdf_bins_min: int = 5,
39
42
  cdf_bins_max: Optional[int] = None,
40
43
  interpolation_method: str = "cosine",
@@ -44,9 +47,14 @@ def random_angle_over_time(
44
47
 
45
48
  key_t, consume_t = random.split(key_t)
46
49
  key_ang, consume_ang = random.split(key_ang)
50
+ rom_halfsize_float = _to_float(rom_halfsize, t)
51
+ rom_lower = ANG_0 - rom_halfsize_float
52
+ rom_upper = ANG_0 + rom_halfsize_float
47
53
  dt, phi = _resolve_range_of_motion(
48
54
  range_of_motion,
49
55
  range_of_motion_method,
56
+ rom_lower,
57
+ rom_upper,
50
58
  _to_float(dang_min, t),
51
59
  _to_float(dang_max, t),
52
60
  _to_float(delta_ang_min, t),
@@ -251,6 +259,8 @@ def _clip_to_pi(phi):
251
259
  def _resolve_range_of_motion(
252
260
  range_of_motion,
253
261
  range_of_motion_method,
262
+ rom_lower: float,
263
+ rom_upper: float,
254
264
  dang_min,
255
265
  dang_max,
256
266
  delta_ang_min,
@@ -265,44 +275,47 @@ def _resolve_range_of_motion(
265
275
  def _next_phi(key, dt):
266
276
  key, consume = random.split(key)
267
277
 
268
- if range_of_motion:
269
- if range_of_motion_method == "coinflip":
270
- probs = jnp.array([0.5, 0.5])
271
- elif range_of_motion_method == "uniform":
272
- p = 0.5 * (1 - prev_phi / jnp.pi)
273
- probs = jnp.array([p, (1 - p)])
274
- elif range_of_motion_method[:7] == "sigmoid":
275
- scale = 1.5
276
- provided_params = range_of_motion_method.split("-")
277
- if len(provided_params) == 2:
278
- scale = float(provided_params[-1])
279
- hardcut = jnp.pi - 0.01
280
- p = jnp.where(
281
- prev_phi > hardcut,
282
- 0.0,
283
- jnp.where(
284
- prev_phi < -hardcut, 1.0, jax.nn.sigmoid(-scale * prev_phi)
285
- ),
286
- )
287
- probs = jnp.array([p, (1 - p)])
288
- else:
289
- raise NotImplementedError
278
+ # legacy reasons, without range of motion the `sign` value, so going
279
+ # left or right is 50-50 for free joints and spherical joints
280
+ if not range_of_motion:
281
+ range_of_motion_method = "coinflip"
282
+
283
+ if range_of_motion_method == "coinflip":
284
+ probs = jnp.array([0.5, 0.5])
285
+ elif range_of_motion_method == "uniform":
286
+ p = 0.5 * (1 - prev_phi / jnp.pi)
287
+ probs = jnp.array([p, (1 - p)])
288
+ elif range_of_motion_method[:7] == "sigmoid":
289
+ scale = 1.5
290
+ provided_params = range_of_motion_method.split("-")
291
+ if len(provided_params) == 2:
292
+ scale = float(provided_params[-1])
293
+ hardcut = jnp.pi - 0.01
294
+ p = jnp.where(
295
+ prev_phi > hardcut,
296
+ 0.0,
297
+ jnp.where(prev_phi < -hardcut, 1.0, jax.nn.sigmoid(-scale * prev_phi)),
298
+ )
299
+ probs = jnp.array([p, (1 - p)])
300
+ else:
301
+ raise NotImplementedError
290
302
 
291
- sign = random.choice(consume, jnp.array([1.0, -1.0]), p=probs)
292
- lower = _clip_to_pi(prev_phi + sign * dang_min * dt)
293
- upper = _clip_to_pi(prev_phi + sign * dang_max * dt)
303
+ sign = random.choice(consume, jnp.array([1.0, -1.0]), p=probs)
304
+ lower = prev_phi + sign * dang_min * dt
305
+ upper = prev_phi + sign * dang_max * dt
294
306
 
295
- # swap if lower > upper
296
- lower, upper = jnp.sort(jnp.hstack((lower, upper)))
307
+ if range_of_motion:
308
+ lower, upper = _clip_to_pi(lower), _clip_to_pi(upper)
297
309
 
298
- key, consume = random.split(key)
299
- return random.uniform(consume, minval=lower, maxval=upper)
310
+ # swap if lower > upper
311
+ lower, upper = jnp.sort(jnp.hstack((lower, upper)))
300
312
 
301
- else:
302
- dphi = random.uniform(consume, minval=dang_min, maxval=dang_max) * dt
303
- key, consume = random.split(key)
304
- sign = random.choice(consume, jnp.array([1.0, -1.0]))
305
- return prev_phi + sign * dphi
313
+ # clip bounds given by the angular velocity bounds to the rom bounds
314
+ lower = jnp.clip(lower, a_min=rom_lower)
315
+ upper = jnp.clip(upper, a_max=rom_upper)
316
+
317
+ key, consume = random.split(key)
318
+ return random.uniform(consume, minval=lower, maxval=upper)
306
319
 
307
320
  def body_fn(val):
308
321
  key_t, key_ang, _, _, i = val
@@ -190,11 +190,11 @@ def _spring_force(sys: base.System, q: jax.Array):
190
190
 
191
191
  def _calc_spring_force_per_link(_, __, q, zeropoint, typ):
192
192
  # cor is (free, p3d) stacked; free is (spherical, p3d) stacked
193
- if typ in ["free", "cor"]:
193
+ if base.System.joint_type_is_free_or_cor(typ):
194
194
  quat_force = _quaternion_spring_force(zeropoint[:4], q[:4])
195
195
  pos_force = zeropoint[4:] - q[4:]
196
196
  q_spring_force_link = jnp.concatenate((quat_force, pos_force))
197
- elif typ == "spherical":
197
+ elif base.System.joint_type_is_spherical(typ):
198
198
  q_spring_force_link = _quaternion_spring_force(zeropoint, q)
199
199
  else:
200
200
  q_spring_force_link = zeropoint - q
@@ -268,11 +268,11 @@ def _semi_implicit_euler_integration(
268
268
  q_next = []
269
269
 
270
270
  def q_integrate(_, __, q, qd, typ):
271
- if typ in ["free", "cor"]:
271
+ if sys.joint_type_is_free_or_cor(typ):
272
272
  quat_next = _strapdown_integration(q[:4], qd[:3], sys.dt)
273
273
  pos_next = q[4:] + qd[3:] * sys.dt
274
274
  q_next_i = jnp.concatenate((quat_next, pos_next))
275
- elif typ == "spherical":
275
+ elif sys.joint_type_is_spherical(typ):
276
276
  quat_next = _strapdown_integration(q, qd, sys.dt)
277
277
  q_next_i = quat_next
278
278
  else:
@@ -1,3 +1,4 @@
1
+ from dataclasses import replace
1
2
  from functools import partial
2
3
  import random
3
4
  from typing import Callable, Optional
@@ -446,7 +447,15 @@ def draw_random_q(
446
447
  draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
447
448
  if draw_fn is None:
448
449
  raise Exception(f"The joint type {link_type} has no draw fn specified.")
449
- q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
450
+
451
+ if link_type in config.joint_type_specific_overwrites:
452
+ _config = replace(
453
+ config, **config.joint_type_specific_overwrites[link_type]
454
+ )
455
+ else:
456
+ _config = config
457
+
458
+ q_link = draw_fn(_config, key_t, key_value, sys.dt, N, joint_params)
450
459
  # even revolute and prismatic joints must be 2d arrays
451
460
  q_link = q_link if q_link.ndim == 2 else q_link[:, None]
452
461
  q_list.append(q_link)
@@ -40,6 +40,12 @@ class MotionConfig:
40
40
  dpos_max: float | TimeDependentFloat = 0.7
41
41
  pos_min: float | TimeDependentFloat = -2.5
42
42
  pos_max: float | TimeDependentFloat = +2.5
43
+ pos_min_p3d_x: float | TimeDependentFloat = -2.5
44
+ pos_max_p3d_x: float | TimeDependentFloat = +2.5
45
+ pos_min_p3d_y: float | TimeDependentFloat = -2.5
46
+ pos_max_p3d_y: float | TimeDependentFloat = +2.5
47
+ pos_min_p3d_z: float | TimeDependentFloat = -2.5
48
+ pos_max_p3d_z: float | TimeDependentFloat = +2.5
43
49
 
44
50
  # used by both `random_angle_*` and `random_pos_*`
45
51
  # only used if `randomized_interpolation` is set
@@ -54,11 +60,22 @@ class MotionConfig:
54
60
  range_of_motion_hinge: bool = True
55
61
  range_of_motion_hinge_method: str = "uniform"
56
62
 
63
+ # this value has nothing to do with `range_of_motion` flag
64
+ # this forces the value to stay within [ANG_0 - rom_halfsize, ANG_0 + rom_halfsize]
65
+ # used only by the `_draw_rxyz` function
66
+ rom_halfsize: float | TimeDependentFloat = 2 * jnp.pi
67
+
57
68
  # initial value of joints
58
69
  ang0_min: float = -jnp.pi
59
70
  ang0_max: float = jnp.pi
60
71
  pos0_min: float = 0.0
61
72
  pos0_max: float = 0.0
73
+ pos0_min_p3d_x: float = 0.0
74
+ pos0_max_p3d_x: float = 0.0
75
+ pos0_min_p3d_y: float = 0.0
76
+ pos0_max_p3d_y: float = 0.0
77
+ pos0_min_p3d_z: float = 0.0
78
+ pos0_max_p3d_z: float = 0.0
62
79
 
63
80
  # cor (center of rotation) custom fields
64
81
  cor_t_min: float = 0.2
@@ -67,6 +84,14 @@ class MotionConfig:
67
84
  cor_dpos_max: float | TimeDependentFloat = 0.5
68
85
  cor_pos_min: float | TimeDependentFloat = -0.4
69
86
  cor_pos_max: float | TimeDependentFloat = 0.4
87
+ cor_pos0_min: float = 0.0
88
+ cor_pos0_max: float = 0.0
89
+
90
+ # specify changes for this motionconfig and for specific joint types
91
+ # map of `link_types` -> dictionary of changes
92
+ joint_type_specific_overwrites: dict[str, dict[str, Any]] = field(
93
+ default_factory=lambda: dict()
94
+ )
70
95
 
71
96
  def is_feasible(self) -> bool:
72
97
  return _is_feasible_config1(self)
@@ -92,6 +117,9 @@ class MotionConfig:
92
117
  def overwrite_for_joint_type(joint_type: str, **changes) -> None:
93
118
  """Changes values of the `MotionConfig` used by the draw_fn for only a specific
94
119
  joint.
120
+ !!! Note
121
+ This applies these changes to *all* MotionConfigs for this joint type!
122
+ This takes precedence *over* `Motionconfig.joint_type_specific_overwrites`!
95
123
  """
96
124
  previous_changes = _overwrite_for_joint_type_changes[joint_type]
97
125
  for change in changes:
@@ -113,6 +141,56 @@ class MotionConfig:
113
141
  overwrite=True,
114
142
  )
115
143
 
144
+ @staticmethod
145
+ def overwrite_for_subsystem(
146
+ sys: base.System, link_name: str, **changes
147
+ ) -> base.System:
148
+ """Modifies motionconfig of all joints in subsystem with root `link_name`.
149
+ Note that if the subsystem contains a free joint then the jointtype will
150
+ will be re-named to `free_<link_name>`, then the RCMG flag `cor` will
151
+ potentially not work as expected because it searches for all joints of
152
+ type `free` to replace with `cor`. The workaround here is to change the
153
+ type already from `free` to `cor in the xml file.
154
+ This takes precedence *over* `Motionconfig.joint_type_specific_overwrites`!
155
+
156
+ Args:
157
+ sys (base.System): System object that gets updated
158
+ link_name (str): Root node of subsystem
159
+ changes: Changes to apply to the motionconfig
160
+
161
+ Return:
162
+ base.System: Updated system with new jointtypes
163
+ """
164
+ from ring.algorithms.generator.finalize_fns import _P_gains
165
+
166
+ # all bodies in the subsystem
167
+ bodies = sys.findall_bodies_subsystem(link_name) + [sys.name_to_idx(link_name)]
168
+
169
+ jts_subsys = set([sys.link_types[i] for i in bodies]) - set(["frozen"])
170
+ postfix = "_" + link_name
171
+ # create new joint types with updated motionconfig
172
+ for typ in jts_subsys:
173
+ register_new_joint_type(
174
+ typ + postfix,
175
+ get_joint_model(typ),
176
+ base.Q_WIDTHS[typ],
177
+ base.QD_WIDTHS[typ],
178
+ )
179
+ MotionConfig.overwrite_for_joint_type(typ + postfix, **changes)
180
+ _P_gains[typ + postfix] = _P_gains[typ]
181
+
182
+ # rename all jointtypes
183
+ new_link_types = [
184
+ (
185
+ sys.link_types[i] + postfix
186
+ if (i in bodies and sys.link_types[i] != "frozen")
187
+ else sys.link_types[i]
188
+ )
189
+ for i in range(sys.num_links())
190
+ ]
191
+ sys = sys.replace(link_types=new_link_types)
192
+ return sys
193
+
116
194
  @staticmethod
117
195
  def from_register(name: str) -> "MotionConfig":
118
196
  return _registered_motion_configs[name]
@@ -221,6 +299,37 @@ _registered_motion_configs = {
221
299
  }
222
300
 
223
301
 
302
+ def _joint_specific_overwrites_free_cor(
303
+ id: str, dang: float, dpos: float
304
+ ) -> MotionConfig:
305
+ changes = dict(
306
+ dang_max_free_spherical=dang,
307
+ dpos_max=dpos,
308
+ cor_dpos_max=dpos,
309
+ t_min=1.5,
310
+ t_max=15.0,
311
+ )
312
+ return replace(
313
+ _registered_motion_configs[id],
314
+ joint_type_specific_overwrites=dict(free=changes, cor=changes),
315
+ )
316
+
317
+
318
+ _registered_motion_configs.update(
319
+ {
320
+ f"{id}-S": _joint_specific_overwrites_free_cor(id, 0.2, 0.1)
321
+ for id in ["expSlow", "expFast", "hinUndHer", "standard"]
322
+ }
323
+ )
324
+ _registered_motion_configs.update(
325
+ {
326
+ f"{id}-S+": _joint_specific_overwrites_free_cor(id, 0.1, 0.05)
327
+ for id in ["expSlow", "expFast", "hinUndHer", "standard"]
328
+ }
329
+ )
330
+ del _joint_specific_overwrites_free_cor
331
+
332
+
224
333
  def _is_feasible_config1(c: MotionConfig) -> bool:
225
334
  t_min, t_max = c.t_min, _to_float(c.t_max, 0.0)
226
335
 
@@ -254,8 +363,29 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
254
363
  cond2 = inside_box_checks(
255
364
  _to_float(c.pos_min, 0.0), _to_float(c.pos_max, 0.0), c.pos0_min, c.pos0_max
256
365
  )
366
+ cond3 = inside_box_checks(
367
+ _to_float(c.pos_min_p3d_x, 0.0),
368
+ _to_float(c.pos_max_p3d_x, 0.0),
369
+ c.pos0_min_p3d_x,
370
+ c.pos0_max_p3d_x,
371
+ )
372
+ cond4 = inside_box_checks(
373
+ _to_float(c.pos_min_p3d_y, 0.0),
374
+ _to_float(c.pos_max_p3d_y, 0.0),
375
+ c.pos0_min_p3d_y,
376
+ c.pos0_max_p3d_y,
377
+ )
378
+ cond5 = inside_box_checks(
379
+ _to_float(c.pos_min_p3d_z, 0.0),
380
+ _to_float(c.pos_max_p3d_z, 0.0),
381
+ c.pos0_min_p3d_z,
382
+ c.pos0_max_p3d_z,
383
+ )
384
+
385
+ # test that the delta_ang_min is smaller than 2*rom_halfsize
386
+ cond6 = _to_float(c.delta_ang_min, 0.0) < 2 * _to_float(c.rom_halfsize, 0.0)
257
387
 
258
- return cond1 and cond2
388
+ return cond1 and cond2 and cond3 and cond4 and cond5 and cond6
259
389
 
260
390
 
261
391
  def _find_interval(t: jax.Array, boundaries: jax.Array):
@@ -488,6 +618,7 @@ def _draw_rxyz(
488
618
  config.randomized_interpolation_angle,
489
619
  config.range_of_motion_hinge if enable_range_of_motion else False,
490
620
  config.range_of_motion_hinge_method,
621
+ config.rom_halfsize,
491
622
  config.cdf_bins_min,
492
623
  config.cdf_bins_max,
493
624
  config.interpolation_method,
@@ -504,7 +635,11 @@ def _draw_pxyz(
504
635
  cor: bool = False,
505
636
  ) -> jax.Array:
506
637
  key_value, consume = jax.random.split(key_value)
507
- POS_0 = jax.random.uniform(consume, minval=config.pos0_min, maxval=config.pos0_max)
638
+ POS_0 = jax.random.uniform(
639
+ consume,
640
+ minval=config.cor_pos0_min if cor else config.pos0_min,
641
+ maxval=config.cor_pos0_max if cor else config.pos0_max,
642
+ )
508
643
  max_iter = 100
509
644
  return _random.random_position_over_time(
510
645
  key_value,
@@ -590,10 +725,27 @@ def _draw_p3d_and_cor(
590
725
  __: jax.Array,
591
726
  cor: bool,
592
727
  ) -> jax.Array:
593
- pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
594
- jax.random.split(key_value, 3)
595
- )
596
- return pos.T
728
+ keys = jax.random.split(key_value, 3)
729
+
730
+ def draw(key, xyz: str):
731
+ return _draw_pxyz(
732
+ replace(
733
+ config,
734
+ pos_min=getattr(config, f"pos_min_p3d_{xyz}"),
735
+ pos_max=getattr(config, f"pos_max_p3d_{xyz}"),
736
+ pos0_min=getattr(config, f"pos0_min_p3d_{xyz}"),
737
+ pos0_max=getattr(config, f"pos0_max_p3d_{xyz}"),
738
+ ),
739
+ None,
740
+ key,
741
+ dt,
742
+ N,
743
+ None,
744
+ cor,
745
+ )[:, None]
746
+
747
+ px, py, pz = draw(keys[0], "x"), draw(keys[1], "y"), draw(keys[2], "z")
748
+ return jnp.concat((px, py, pz), axis=-1)
597
749
 
598
750
 
599
751
  def _draw_p3d(
@@ -7,6 +7,7 @@ from jax.core import Tracer
7
7
  import jax.numpy as jnp
8
8
  from jax.tree_util import tree_map
9
9
  import numpy as np
10
+ import tree
10
11
  import tree_utils as tu
11
12
 
12
13
  import ring
@@ -590,6 +591,34 @@ class System(_Base):
590
591
 
591
592
  return sys
592
593
 
594
+ @staticmethod
595
+ def joint_type_simplification(typ: str) -> str:
596
+ if typ[:4] == "free":
597
+ if typ == "free_2d":
598
+ return "free_2d"
599
+ else:
600
+ return "free"
601
+ elif typ[:3] == "cor":
602
+ return "cor"
603
+ elif typ[:9] == "spherical":
604
+ return "spherical"
605
+ else:
606
+ return typ
607
+
608
+ @staticmethod
609
+ def joint_type_is_free_or_cor(typ: str) -> bool:
610
+ return System.joint_type_simplification(typ) in ["free", "cor"]
611
+
612
+ @staticmethod
613
+ def joint_type_is_spherical(typ: str) -> bool:
614
+ return System.joint_type_simplification(typ) == "spherical"
615
+
616
+ @staticmethod
617
+ def joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
618
+ return System.joint_type_is_free_or_cor(typ) or System.joint_type_is_spherical(
619
+ typ
620
+ )
621
+
593
622
  def findall_imus(self, names: bool = True) -> list[str] | list[int]:
594
623
  bodies = [name for name in self.link_names if name[:3] == "imu"]
595
624
  return bodies if names else [self.name_to_idx(n) for n in bodies]
@@ -618,10 +647,20 @@ class System(_Base):
618
647
  return self._bodies_indices_to_bodies_name(bodies) if names else bodies
619
648
 
620
649
  def children(self, name: str, names: bool = False) -> list[int] | list[str]:
650
+ "List all direct children of body, does not include body itself"
621
651
  p = self.name_to_idx(name)
622
652
  bodies = [i for i in range(self.num_links()) if self.link_parents[i] == p]
623
653
  return bodies if (not names) else [self.idx_to_name(i) for i in bodies]
624
654
 
655
+ def findall_bodies_subsystem(
656
+ self, name: str, names: bool = False
657
+ ) -> list[int] | list[str]:
658
+ "List all children and children's children; does not include body itself"
659
+ children = self.children(name, names=True)
660
+ grandchildren = [self.findall_bodies_subsystem(n, names=True) for n in children]
661
+ bodies = tree.flatten([children, grandchildren])
662
+ return bodies if names else [self.name_to_idx(n) for n in bodies]
663
+
625
664
  def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):
626
665
  """Scan `f` along each link in system whilst carrying along state.
627
666
 
@@ -889,7 +928,9 @@ def _parse_system(sys: System) -> System:
889
928
  assert d.size == a.size == s.size == qd_size, error_msg
890
929
  assert z.size == q_size, error_msg
891
930
 
892
- if typ in ["spherical", "free", "cor"] and not isinstance(z, Tracer):
931
+ if System.joint_type_is_free_or_cor_or_spherical(typ) and not isinstance(
932
+ z, Tracer
933
+ ):
893
934
  assert jnp.allclose(
894
935
  jnp.linalg.norm(z[:4]), 1.0
895
936
  ), f"not unit quat for link `{name}` of typ `{typ}` in model"
@@ -1030,7 +1071,7 @@ class State(_Base):
1030
1071
  def replace_by_unit_quat(_, idx_map, link_typ, link_idx):
1031
1072
  nonlocal q
1032
1073
 
1033
- if link_typ in ["free", "cor", "spherical"]:
1074
+ if sys.joint_type_is_free_or_cor_or_spherical(link_typ):
1034
1075
  q_idxs_link = idx_map["q"](link_idx)
1035
1076
  q = q.at[q_idxs_link.start].set(1.0)
1036
1077
 
@@ -3,6 +3,7 @@ from typing import Tuple, TypeVar
3
3
  import jax
4
4
  import jax.numpy as jnp
5
5
  import numpy as np
6
+
6
7
  from ring import base
7
8
 
8
9
  T = TypeVar("T")
@@ -17,7 +18,7 @@ default_stiffness = lambda qd_size, **_: jnp.zeros((qd_size,))
17
18
 
18
19
  def default_zeropoint(q_size, link_typ: str, **_):
19
20
  zeropoint = jnp.zeros((q_size))
20
- if link_typ in ["spherical", "free", "cor"]:
21
+ if base.System.joint_type_is_free_or_cor_or_spherical(link_typ):
21
22
  # zeropoint then is unit quaternion and not zeros
22
23
  zeropoint = zeropoint.at[0].set(1.0)
23
24
  return zeropoint
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes