imt-ring 1.6.20__tar.gz → 1.6.22__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (119) hide show
  1. {imt_ring-1.6.20 → imt_ring-1.6.22}/PKG-INFO +1 -1
  2. {imt_ring-1.6.20 → imt_ring-1.6.22}/pyproject.toml +1 -1
  3. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/_random.py +1 -1
  5. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/dynamics.py +4 -4
  6. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/generator/base.py +20 -1
  7. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/generator/finalize_fns.py +12 -0
  8. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/jcalc.py +149 -6
  9. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/base.py +43 -2
  10. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/xml/abstract.py +2 -1
  11. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/ml/train.py +4 -6
  12. {imt_ring-1.6.20 → imt_ring-1.6.22}/readme.md +0 -0
  13. {imt_ring-1.6.20 → imt_ring-1.6.22}/setup.cfg +0 -0
  14. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/imt_ring.egg-info/SOURCES.txt +0 -0
  15. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  16. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/imt_ring.egg-info/requires.txt +0 -0
  17. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/imt_ring.egg-info/top_level.txt +0 -0
  18. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/__init__.py +0 -0
  19. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algebra.py +0 -0
  20. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/__init__.py +0 -0
  21. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  22. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  23. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  24. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/custom_joints/rsaddle_joint.py +0 -0
  25. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  26. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/generator/__init__.py +0 -0
  27. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/generator/batch.py +0 -0
  28. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  29. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/generator/pd_control.py +0 -0
  30. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/generator/setup_fns.py +0 -0
  31. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/generator/types.py +0 -0
  32. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/kinematics.py +0 -0
  33. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/algorithms/sensors.py +0 -0
  34. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/__init__.py +0 -0
  35. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/branched.xml +0 -0
  36. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  37. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  38. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  39. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/inv_pendulum.xml +0 -0
  40. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  41. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/spherical_stiff.xml +0 -0
  42. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/symmetric.xml +0 -0
  43. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_all_1.xml +0 -0
  44. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_all_2.xml +0 -0
  45. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  46. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_control.xml +0 -0
  47. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  48. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_free.xml +0 -0
  49. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_kinematics.xml +0 -0
  50. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  51. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  52. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_randomize_position.xml +0 -0
  53. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_sensors.xml +0 -0
  54. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  55. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/examples.py +0 -0
  56. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/test_examples.py +0 -0
  57. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/xml/__init__.py +0 -0
  58. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/xml/from_xml.py +0 -0
  59. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/xml/test_from_xml.py +0 -0
  60. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/xml/test_to_xml.py +0 -0
  61. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/io/xml/to_xml.py +0 -0
  62. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/maths.py +0 -0
  63. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/ml/__init__.py +0 -0
  64. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/ml/base.py +0 -0
  65. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/ml/callbacks.py +0 -0
  66. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/ml/ml_utils.py +0 -0
  67. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/ml/optimizer.py +0 -0
  68. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  69. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  70. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/ml/ringnet.py +0 -0
  71. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/ml/rnno_v1.py +0 -0
  72. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/ml/training_loop.py +0 -0
  73. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/rendering/__init__.py +0 -0
  74. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/rendering/base_render.py +0 -0
  75. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/rendering/mujoco_render.py +0 -0
  76. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/rendering/vispy_render.py +0 -0
  77. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/rendering/vispy_visuals.py +0 -0
  78. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/sim2real/__init__.py +0 -0
  79. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/sim2real/sim2real.py +0 -0
  80. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/spatial.py +0 -0
  81. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/sys_composer/__init__.py +0 -0
  82. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/sys_composer/delete_sys.py +0 -0
  83. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/sys_composer/inject_sys.py +0 -0
  84. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/sys_composer/morph_sys.py +0 -0
  85. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/__init__.py +0 -0
  86. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/backend.py +0 -0
  87. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/batchsize.py +0 -0
  88. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/colab.py +0 -0
  89. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/dataloader.py +0 -0
  90. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/dataloader_torch.py +0 -0
  91. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/hdf5.py +0 -0
  92. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/normalizer.py +0 -0
  93. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/path.py +0 -0
  94. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/randomize_sys.py +0 -0
  95. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/register_gym_envs/__init__.py +0 -0
  96. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/register_gym_envs/saddle.py +0 -0
  97. {imt_ring-1.6.20 → imt_ring-1.6.22}/src/ring/utils/utils.py +0 -0
  98. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_algebra.py +0 -0
  99. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_base.py +0 -0
  100. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_custom_joints.py +0 -0
  101. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_dynamics.py +0 -0
  102. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_generator.py +0 -0
  103. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_jcalc.py +0 -0
  104. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_jit.py +0 -0
  105. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_kinematics.py +0 -0
  106. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_maths.py +0 -0
  107. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_ml_utils.py +0 -0
  108. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_motion_artifacts.py +0 -0
  109. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_pd_control.py +0 -0
  110. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_quickstart_example.py +0 -0
  111. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_random.py +0 -0
  112. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_randomize.py +0 -0
  113. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_rcmg.py +0 -0
  114. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_render.py +0 -0
  115. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_sensors.py +0 -0
  116. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_sim2real.py +0 -0
  117. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_sys_composer.py +0 -0
  118. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_train.py +0 -0
  119. {imt_ring-1.6.20 → imt_ring-1.6.22}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.20
3
+ Version: 1.6.22
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.6.20"
7
+ version = "1.6.22"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.20
3
+ Version: 1.6.22
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -199,7 +199,7 @@ def random_position_over_time(
199
199
  POS = jnp.zeros((int(T // t_min) + 1, 2))
200
200
  POS = POS.at[0, 1].set(POS_0)
201
201
 
202
- val_outer = (1, 0.0, 0.0, 0.0, 0.0, key, POS)
202
+ val_outer = (1, 0.0, 0.0, POS_0, POS_0, key, POS)
203
203
  end, *_, consume, POS = jax.lax.while_loop(cond_fn_outer, body_fn_outer, val_outer)
204
204
  POS = jnp.where(
205
205
  (jnp.arange(len(POS)) < end)[:, None],
@@ -190,11 +190,11 @@ def _spring_force(sys: base.System, q: jax.Array):
190
190
 
191
191
  def _calc_spring_force_per_link(_, __, q, zeropoint, typ):
192
192
  # cor is (free, p3d) stacked; free is (spherical, p3d) stacked
193
- if typ in ["free", "cor"]:
193
+ if base.System.joint_type_is_free_or_cor(typ):
194
194
  quat_force = _quaternion_spring_force(zeropoint[:4], q[:4])
195
195
  pos_force = zeropoint[4:] - q[4:]
196
196
  q_spring_force_link = jnp.concatenate((quat_force, pos_force))
197
- elif typ == "spherical":
197
+ elif base.System.joint_type_is_spherical(typ):
198
198
  q_spring_force_link = _quaternion_spring_force(zeropoint, q)
199
199
  else:
200
200
  q_spring_force_link = zeropoint - q
@@ -268,11 +268,11 @@ def _semi_implicit_euler_integration(
268
268
  q_next = []
269
269
 
270
270
  def q_integrate(_, __, q, qd, typ):
271
- if typ in ["free", "cor"]:
271
+ if sys.joint_type_is_free_or_cor(typ):
272
272
  quat_next = _strapdown_integration(q[:4], qd[:3], sys.dt)
273
273
  pos_next = q[4:] + qd[3:] * sys.dt
274
274
  q_next_i = jnp.concatenate((quat_next, pos_next))
275
- elif typ == "spherical":
275
+ elif sys.joint_type_is_spherical(typ):
276
276
  quat_next = _strapdown_integration(q, qd, sys.dt)
277
277
  q_next_i = quat_next
278
278
  else:
@@ -1,3 +1,4 @@
1
+ from dataclasses import replace
1
2
  from functools import partial
2
3
  import random
3
4
  from typing import Callable, Optional
@@ -34,6 +35,8 @@ class RCMG:
34
35
  add_y_relpose: bool = False,
35
36
  add_y_rootincl: bool = False,
36
37
  add_y_rootincl_kwargs: dict = dict(),
38
+ add_y_rootfull: bool = False,
39
+ add_y_rootfull_kwargs: dict = dict(),
37
40
  sys_ml: Optional[base.System] = None,
38
41
  randomize_positions: bool = False,
39
42
  randomize_motion_artifacts: bool = False,
@@ -73,6 +76,8 @@ class RCMG:
73
76
  add_y_relpose=add_y_relpose,
74
77
  add_y_rootincl=add_y_rootincl,
75
78
  add_y_rootincl_kwargs=add_y_rootincl_kwargs,
79
+ add_y_rootfull=add_y_rootfull,
80
+ add_y_rootfull_kwargs=add_y_rootfull_kwargs,
76
81
  sys_ml=sys_ml,
77
82
  randomize_positions=randomize_positions,
78
83
  randomize_motion_artifacts=randomize_motion_artifacts,
@@ -279,6 +284,8 @@ def _build_mconfig_batched_generator(
279
284
  add_y_relpose: bool,
280
285
  add_y_rootincl: bool,
281
286
  add_y_rootincl_kwargs: dict,
287
+ add_y_rootfull: bool,
288
+ add_y_rootfull_kwargs: dict,
282
289
  sys_ml: base.System,
283
290
  randomize_positions: bool,
284
291
  randomize_motion_artifacts: bool,
@@ -365,7 +372,11 @@ def _build_mconfig_batched_generator(
365
372
  if add_y_relpose:
366
373
  pipe.append(finalize_fns.RelPose(sys_noimu))
367
374
  if add_y_rootincl:
375
+ assert not add_y_rootfull
368
376
  pipe.append(finalize_fns.RootIncl(sys_noimu, **add_y_rootincl_kwargs))
377
+ if add_y_rootfull:
378
+ assert not add_y_rootincl
379
+ pipe.append(finalize_fns.RootFull(sys_noimu, **add_y_rootfull_kwargs))
369
380
  if use_link_number_in_Xy:
370
381
  pipe.append(finalize_fns.Names2Indices(sys_noimu))
371
382
 
@@ -436,7 +447,15 @@ def draw_random_q(
436
447
  draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
437
448
  if draw_fn is None:
438
449
  raise Exception(f"The joint type {link_type} has no draw fn specified.")
439
- q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
450
+
451
+ if link_type in config.joint_type_specific_overwrites:
452
+ _config = replace(
453
+ config, **config.joint_type_specific_overwrites[link_type]
454
+ )
455
+ else:
456
+ _config = config
457
+
458
+ q_link = draw_fn(_config, key_t, key_value, sys.dt, N, joint_params)
440
459
  # even revolute and prismatic joints must be 2d arrays
441
460
  q_link = q_link if q_link.ndim == 2 else q_link[:, None]
442
461
  q_list.append(q_link)
@@ -88,6 +88,18 @@ class RootIncl:
88
88
  return (X, y), (key, q, x, sys_x)
89
89
 
90
90
 
91
+ class RootFull:
92
+ def __init__(self, sys: base.System, **kwargs):
93
+ self.sys = sys
94
+ self.kwargs = kwargs
95
+
96
+ def __call__(self, Xy, extras):
97
+ (X, y), (key, q, x, sys_x) = Xy, extras
98
+ y_root_incl = sensors.root_full(self.sys, x, sys_x, **self.kwargs)
99
+ y = utils.dict_union(y, y_root_incl)
100
+ return (X, y), (key, q, x, sys_x)
101
+
102
+
91
103
  _default_imu_kwargs = dict(
92
104
  noisy=True,
93
105
  low_pass_filter_pos_f_cutoff=13.5,
@@ -40,6 +40,12 @@ class MotionConfig:
40
40
  dpos_max: float | TimeDependentFloat = 0.7
41
41
  pos_min: float | TimeDependentFloat = -2.5
42
42
  pos_max: float | TimeDependentFloat = +2.5
43
+ pos_min_p3d_x: float | TimeDependentFloat = -2.5
44
+ pos_max_p3d_x: float | TimeDependentFloat = +2.5
45
+ pos_min_p3d_y: float | TimeDependentFloat = -2.5
46
+ pos_max_p3d_y: float | TimeDependentFloat = +2.5
47
+ pos_min_p3d_z: float | TimeDependentFloat = -2.5
48
+ pos_max_p3d_z: float | TimeDependentFloat = +2.5
43
49
 
44
50
  # used by both `random_angle_*` and `random_pos_*`
45
51
  # only used if `randomized_interpolation` is set
@@ -59,6 +65,12 @@ class MotionConfig:
59
65
  ang0_max: float = jnp.pi
60
66
  pos0_min: float = 0.0
61
67
  pos0_max: float = 0.0
68
+ pos0_min_p3d_x: float = 0.0
69
+ pos0_max_p3d_x: float = 0.0
70
+ pos0_min_p3d_y: float = 0.0
71
+ pos0_max_p3d_y: float = 0.0
72
+ pos0_min_p3d_z: float = 0.0
73
+ pos0_max_p3d_z: float = 0.0
62
74
 
63
75
  # cor (center of rotation) custom fields
64
76
  cor_t_min: float = 0.2
@@ -67,6 +79,14 @@ class MotionConfig:
67
79
  cor_dpos_max: float | TimeDependentFloat = 0.5
68
80
  cor_pos_min: float | TimeDependentFloat = -0.4
69
81
  cor_pos_max: float | TimeDependentFloat = 0.4
82
+ cor_pos0_min: float = 0.0
83
+ cor_pos0_max: float = 0.0
84
+
85
+ # specify changes for this motionconfig and for specific joint types
86
+ # map of `link_types` -> dictionary of changes
87
+ joint_type_specific_overwrites: dict[str, dict[str, Any]] = field(
88
+ default_factory=lambda: dict()
89
+ )
70
90
 
71
91
  def is_feasible(self) -> bool:
72
92
  return _is_feasible_config1(self)
@@ -92,6 +112,9 @@ class MotionConfig:
92
112
  def overwrite_for_joint_type(joint_type: str, **changes) -> None:
93
113
  """Changes values of the `MotionConfig` used by the draw_fn for only a specific
94
114
  joint.
115
+ !!! Note
116
+ This applies these changes to *all* MotionConfigs for this joint type!
117
+ This takes precedence *over* `Motionconfig.joint_type_specific_overwrites`!
95
118
  """
96
119
  previous_changes = _overwrite_for_joint_type_changes[joint_type]
97
120
  for change in changes:
@@ -113,6 +136,56 @@ class MotionConfig:
113
136
  overwrite=True,
114
137
  )
115
138
 
139
+ @staticmethod
140
+ def overwrite_for_subsystem(
141
+ sys: base.System, link_name: str, **changes
142
+ ) -> base.System:
143
+ """Modifies motionconfig of all joints in subsystem with root `link_name`.
144
+ Note that if the subsystem contains a free joint then the jointtype will
145
+ will be re-named to `free_<link_name>`, then the RCMG flag `cor` will
146
+ potentially not work as expected because it searches for all joints of
147
+ type `free` to replace with `cor`. The workaround here is to change the
148
+ type already from `free` to `cor in the xml file.
149
+ This takes precedence *over* `Motionconfig.joint_type_specific_overwrites`!
150
+
151
+ Args:
152
+ sys (base.System): System object that gets updated
153
+ link_name (str): Root node of subsystem
154
+ changes: Changes to apply to the motionconfig
155
+
156
+ Return:
157
+ base.System: Updated system with new jointtypes
158
+ """
159
+ from ring.algorithms.generator.finalize_fns import _P_gains
160
+
161
+ # all bodies in the subsystem
162
+ bodies = sys.findall_bodies_subsystem(link_name) + [sys.name_to_idx(link_name)]
163
+
164
+ jts_subsys = set([sys.link_types[i] for i in bodies]) - set(["frozen"])
165
+ postfix = "_" + link_name
166
+ # create new joint types with updated motionconfig
167
+ for typ in jts_subsys:
168
+ register_new_joint_type(
169
+ typ + postfix,
170
+ get_joint_model(typ),
171
+ base.Q_WIDTHS[typ],
172
+ base.QD_WIDTHS[typ],
173
+ )
174
+ MotionConfig.overwrite_for_joint_type(typ + postfix, **changes)
175
+ _P_gains[typ + postfix] = _P_gains[typ]
176
+
177
+ # rename all jointtypes
178
+ new_link_types = [
179
+ (
180
+ sys.link_types[i] + postfix
181
+ if (i in bodies and sys.link_types[i] != "frozen")
182
+ else sys.link_types[i]
183
+ )
184
+ for i in range(sys.num_links())
185
+ ]
186
+ sys = sys.replace(link_types=new_link_types)
187
+ return sys
188
+
116
189
  @staticmethod
117
190
  def from_register(name: str) -> "MotionConfig":
118
191
  return _registered_motion_configs[name]
@@ -221,6 +294,37 @@ _registered_motion_configs = {
221
294
  }
222
295
 
223
296
 
297
+ def _joint_specific_overwrites_free_cor(
298
+ id: str, dang: float, dpos: float
299
+ ) -> MotionConfig:
300
+ changes = dict(
301
+ dang_max_free_spherical=dang,
302
+ dpos_max=dpos,
303
+ cor_dpos_max=dpos,
304
+ t_min=1.5,
305
+ t_max=15.0,
306
+ )
307
+ return replace(
308
+ _registered_motion_configs[id],
309
+ joint_type_specific_overwrites=dict(free=changes, cor=changes),
310
+ )
311
+
312
+
313
+ _registered_motion_configs.update(
314
+ {
315
+ f"{id}-S": _joint_specific_overwrites_free_cor(id, 0.2, 0.1)
316
+ for id in ["expSlow", "expFast", "hinUndHer", "standard"]
317
+ }
318
+ )
319
+ _registered_motion_configs.update(
320
+ {
321
+ f"{id}-S+": _joint_specific_overwrites_free_cor(id, 0.1, 0.05)
322
+ for id in ["expSlow", "expFast", "hinUndHer", "standard"]
323
+ }
324
+ )
325
+ del _joint_specific_overwrites_free_cor
326
+
327
+
224
328
  def _is_feasible_config1(c: MotionConfig) -> bool:
225
329
  t_min, t_max = c.t_min, _to_float(c.t_max, 0.0)
226
330
 
@@ -254,8 +358,26 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
254
358
  cond2 = inside_box_checks(
255
359
  _to_float(c.pos_min, 0.0), _to_float(c.pos_max, 0.0), c.pos0_min, c.pos0_max
256
360
  )
361
+ cond3 = inside_box_checks(
362
+ _to_float(c.pos_min_p3d_x, 0.0),
363
+ _to_float(c.pos_max_p3d_x, 0.0),
364
+ c.pos0_min_p3d_x,
365
+ c.pos0_max_p3d_x,
366
+ )
367
+ cond4 = inside_box_checks(
368
+ _to_float(c.pos_min_p3d_y, 0.0),
369
+ _to_float(c.pos_max_p3d_y, 0.0),
370
+ c.pos0_min_p3d_y,
371
+ c.pos0_max_p3d_y,
372
+ )
373
+ cond5 = inside_box_checks(
374
+ _to_float(c.pos_min_p3d_z, 0.0),
375
+ _to_float(c.pos_max_p3d_z, 0.0),
376
+ c.pos0_min_p3d_z,
377
+ c.pos0_max_p3d_z,
378
+ )
257
379
 
258
- return cond1 and cond2
380
+ return cond1 and cond2 and cond3 and cond4 and cond5
259
381
 
260
382
 
261
383
  def _find_interval(t: jax.Array, boundaries: jax.Array):
@@ -504,7 +626,11 @@ def _draw_pxyz(
504
626
  cor: bool = False,
505
627
  ) -> jax.Array:
506
628
  key_value, consume = jax.random.split(key_value)
507
- POS_0 = jax.random.uniform(consume, minval=config.pos0_min, maxval=config.pos0_max)
629
+ POS_0 = jax.random.uniform(
630
+ consume,
631
+ minval=config.cor_pos0_min if cor else config.pos0_min,
632
+ maxval=config.cor_pos0_max if cor else config.pos0_max,
633
+ )
508
634
  max_iter = 100
509
635
  return _random.random_position_over_time(
510
636
  key_value,
@@ -590,10 +716,27 @@ def _draw_p3d_and_cor(
590
716
  __: jax.Array,
591
717
  cor: bool,
592
718
  ) -> jax.Array:
593
- pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
594
- jax.random.split(key_value, 3)
595
- )
596
- return pos.T
719
+ keys = jax.random.split(key_value, 3)
720
+
721
+ def draw(key, xyz: str):
722
+ return _draw_pxyz(
723
+ replace(
724
+ config,
725
+ pos_min=getattr(config, f"pos_min_p3d_{xyz}"),
726
+ pos_max=getattr(config, f"pos_max_p3d_{xyz}"),
727
+ pos0_min=getattr(config, f"pos0_min_p3d_{xyz}"),
728
+ pos0_max=getattr(config, f"pos0_max_p3d_{xyz}"),
729
+ ),
730
+ None,
731
+ key,
732
+ dt,
733
+ N,
734
+ None,
735
+ cor,
736
+ )[:, None]
737
+
738
+ px, py, pz = draw(keys[0], "x"), draw(keys[1], "y"), draw(keys[2], "z")
739
+ return jnp.concat((px, py, pz), axis=-1)
597
740
 
598
741
 
599
742
  def _draw_p3d(
@@ -7,6 +7,7 @@ from jax.core import Tracer
7
7
  import jax.numpy as jnp
8
8
  from jax.tree_util import tree_map
9
9
  import numpy as np
10
+ import tree
10
11
  import tree_utils as tu
11
12
 
12
13
  import ring
@@ -590,6 +591,34 @@ class System(_Base):
590
591
 
591
592
  return sys
592
593
 
594
+ @staticmethod
595
+ def joint_type_simplification(typ: str) -> str:
596
+ if typ[:4] == "free":
597
+ if typ == "free_2d":
598
+ return "free_2d"
599
+ else:
600
+ return "free"
601
+ elif typ[:3] == "cor":
602
+ return "cor"
603
+ elif typ[:9] == "spherical":
604
+ return "spherical"
605
+ else:
606
+ return typ
607
+
608
+ @staticmethod
609
+ def joint_type_is_free_or_cor(typ: str) -> bool:
610
+ return System.joint_type_simplification(typ) in ["free", "cor"]
611
+
612
+ @staticmethod
613
+ def joint_type_is_spherical(typ: str) -> bool:
614
+ return System.joint_type_simplification(typ) == "spherical"
615
+
616
+ @staticmethod
617
+ def joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
618
+ return System.joint_type_is_free_or_cor(typ) or System.joint_type_is_spherical(
619
+ typ
620
+ )
621
+
593
622
  def findall_imus(self, names: bool = True) -> list[str] | list[int]:
594
623
  bodies = [name for name in self.link_names if name[:3] == "imu"]
595
624
  return bodies if names else [self.name_to_idx(n) for n in bodies]
@@ -618,10 +647,20 @@ class System(_Base):
618
647
  return self._bodies_indices_to_bodies_name(bodies) if names else bodies
619
648
 
620
649
  def children(self, name: str, names: bool = False) -> list[int] | list[str]:
650
+ "List all direct children of body, does not include body itself"
621
651
  p = self.name_to_idx(name)
622
652
  bodies = [i for i in range(self.num_links()) if self.link_parents[i] == p]
623
653
  return bodies if (not names) else [self.idx_to_name(i) for i in bodies]
624
654
 
655
+ def findall_bodies_subsystem(
656
+ self, name: str, names: bool = False
657
+ ) -> list[int] | list[str]:
658
+ "List all children and children's children; does not include body itself"
659
+ children = self.children(name, names=True)
660
+ grandchildren = [self.findall_bodies_subsystem(n, names=True) for n in children]
661
+ bodies = tree.flatten([children, grandchildren])
662
+ return bodies if names else [self.name_to_idx(n) for n in bodies]
663
+
625
664
  def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):
626
665
  """Scan `f` along each link in system whilst carrying along state.
627
666
 
@@ -889,7 +928,9 @@ def _parse_system(sys: System) -> System:
889
928
  assert d.size == a.size == s.size == qd_size, error_msg
890
929
  assert z.size == q_size, error_msg
891
930
 
892
- if typ in ["spherical", "free", "cor"] and not isinstance(z, Tracer):
931
+ if System.joint_type_is_free_or_cor_or_spherical(typ) and not isinstance(
932
+ z, Tracer
933
+ ):
893
934
  assert jnp.allclose(
894
935
  jnp.linalg.norm(z[:4]), 1.0
895
936
  ), f"not unit quat for link `{name}` of typ `{typ}` in model"
@@ -1030,7 +1071,7 @@ class State(_Base):
1030
1071
  def replace_by_unit_quat(_, idx_map, link_typ, link_idx):
1031
1072
  nonlocal q
1032
1073
 
1033
- if link_typ in ["free", "cor", "spherical"]:
1074
+ if sys.joint_type_is_free_or_cor_or_spherical(link_typ):
1034
1075
  q_idxs_link = idx_map["q"](link_idx)
1035
1076
  q = q.at[q_idxs_link.start].set(1.0)
1036
1077
 
@@ -3,6 +3,7 @@ from typing import Tuple, TypeVar
3
3
  import jax
4
4
  import jax.numpy as jnp
5
5
  import numpy as np
6
+
6
7
  from ring import base
7
8
 
8
9
  T = TypeVar("T")
@@ -17,7 +18,7 @@ default_stiffness = lambda qd_size, **_: jnp.zeros((qd_size,))
17
18
 
18
19
  def default_zeropoint(q_size, link_typ: str, **_):
19
20
  zeropoint = jnp.zeros((q_size))
20
- if link_typ in ["spherical", "free", "cor"]:
21
+ if base.System.joint_type_is_free_or_cor_or_spherical(link_typ):
21
22
  # zeropoint then is unit quaternion and not zeros
22
23
  zeropoint = zeropoint.at[0].set(1.0)
23
24
  return zeropoint
@@ -167,7 +167,10 @@ def train_fn(
167
167
  tbp=tbp,
168
168
  )
169
169
 
170
- default_callbacks = []
170
+ # always log, because we also want `i_epsiode` to be logged in wandb
171
+ default_callbacks = [
172
+ ml_callbacks.LogEpisodeTrainingLoopCallback(callback_kill_after_episode)
173
+ ]
171
174
  if metrices is not None:
172
175
  eval_fn = _build_eval_fn(metrices, filter, link_names)
173
176
  default_callbacks.append(_DefaultEvalFnCallback(eval_fn))
@@ -192,11 +195,6 @@ def train_fn(
192
195
  if callback_kill_if_nan:
193
196
  default_callbacks.append(ml_callbacks.NanKillRunCallback())
194
197
 
195
- # always log, because we also want `i_epsiode` to be logged in wandb
196
- default_callbacks.append(
197
- ml_callbacks.LogEpisodeTrainingLoopCallback(callback_kill_after_episode)
198
- )
199
-
200
198
  if callback_kill_after_seconds is not None:
201
199
  default_callbacks.append(
202
200
  ml_callbacks.TimingKillRunCallback(callback_kill_after_seconds)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes