imt-ring 1.6.20__py3-none-any.whl → 1.6.22__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.20
3
+ Version: 1.6.22
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,12 +1,12 @@
1
1
  ring/__init__.py,sha256=ncBRdHvge6uqpzFyk8_HUQNx4kZxENpVXTJTEK_SjCg,5216
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=Ystn1EjTyOXBhVm5koroV_YPUYtFxrteJLd-XR3kEL8,33840
3
+ ring/base.py,sha256=yPdbPywwDllCRsJEbnLW4s9Z-bBD8qdxpEDYV3pCLP8,35296
4
4
  ring/maths.py,sha256=qPHH6TpHCK3TgExI98gNEySoSRKOwteN9McUlyUFipI,12207
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
- ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
8
- ring/algorithms/dynamics.py,sha256=GOedL1STj6oXcXgMA7dB4PabvCQxPBbirJQhXBRuKqE,10929
9
- ring/algorithms/jcalc.py,sha256=kQgrRE0XoUBrcdeFbUw_xKnL4m1P2G1Q1m4n7BX2yDk,30064
7
+ ring/algorithms/_random.py,sha256=6QwV7EAE2ckDi7rSGXT1XcPhR5R-NU-J8HVa5DoPBSQ,14011
8
+ ring/algorithms/dynamics.py,sha256=dpe-F3Yq4sY2dY6DQW3v7TnPLRdxdkePtdbGPQIrijg,10997
9
+ ring/algorithms/jcalc.py,sha256=pniA8uZutN6H_4sP_YxHbGfW-mWEEUimHGkwLC0fgms,35049
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=0xOzdQIc1kBF0CkoPXWWCx3MmV4SG3wj7knVnnMWq9M,18124
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=3pQ-Is_HBTQDkzESCNg9VfoP8wvseWmooryG8ERnu_A,366
@@ -15,9 +15,9 @@ ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXp
15
15
  ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
16
16
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
17
17
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
18
- ring/algorithms/generator/base.py,sha256=vxUdA0ZeSNH3SOanL51qVRvCiJrmsWQyQX0g2fdm3Rg,15825
18
+ ring/algorithms/generator/base.py,sha256=nFVUcFl7AILnyjuXf_YQJxI12MB4TzJ1SRFREKzqv8Q,16531
19
19
  ring/algorithms/generator/batch.py,sha256=9yFxVv11hij-fJXGPxA3zEh1bE2_jrZk0R7kyGaiM5c,2551
20
- ring/algorithms/generator/finalize_fns.py,sha256=EZ5p7fuZu0Zd0rHJzCVg3vy3U9ysny6TfQfolIGPERc,10029
20
+ ring/algorithms/generator/finalize_fns.py,sha256=nY2RKiLbHriTkdec94lc4UGSZKd0v547MDNn4dr8I3E,10398
21
21
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
22
22
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
23
23
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
@@ -46,7 +46,7 @@ ring/io/examples/exclude/standard_sys_rr_imp.xml,sha256=1K8aLe4n97qFMQaytXdHVHDf
46
46
  ring/io/examples/test_morph_system/four_seg_seg1.xml,sha256=XJvGtEnvedejs_OmCVfQULWJNK8MLDQQ3raqPNRCJbA,1283
47
47
  ring/io/examples/test_morph_system/four_seg_seg3.xml,sha256=HktN7_a_Ly3YflWit5W-WncxApWGMORAGnRXyMEqnoA,1265
48
48
  ring/io/xml/__init__.py,sha256=-3k6ffvFyc4zm0oTyVz3ez-o3Lb9bPp2sjwSub_K1AA,242
49
- ring/io/xml/abstract.py,sha256=ojsXgz15J4pI1FThUIZbB_Iw1wmR9cOHt5Thmcbih4I,9721
49
+ ring/io/xml/abstract.py,sha256=8Q2ebnUYLmuS9HJAQwDVrDTrRfD5z4G5RAB7MW8Oa60,9742
50
50
  ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
51
51
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
52
52
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
@@ -58,7 +58,7 @@ ring/ml/ml_utils.py,sha256=xqy9BnLy8IKVqkFS9mlZsGJXSbThI9zZxZ5rhl8LSI8,7144
58
58
  ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
59
59
  ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
60
60
  ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
61
- ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
61
+ ring/ml/train.py,sha256=XuUUB0NhvByGtZDtS_weyp-TKPG9ErnKixS4NqB8q6M,10822
62
62
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
63
63
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
64
64
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
86
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.20.dist-info/METADATA,sha256=uOgzTpTz7MgDEIGhL7onaD_r64DlL6ALSnaT5t8u3HY,4097
90
- imt_ring-1.6.20.dist-info/WHEEL,sha256=a7TGlA-5DaHMRrarXjVbQagU3Man_dCnGIWMJr5kRWo,91
91
- imt_ring-1.6.20.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.20.dist-info/RECORD,,
89
+ imt_ring-1.6.22.dist-info/METADATA,sha256=EHw8XK0JO8d6lQ-PiknLrM7V5ATyLkniMzBFydE6NA4,4097
90
+ imt_ring-1.6.22.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
91
+ imt_ring-1.6.22.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.22.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.4.0)
2
+ Generator: setuptools (75.5.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -199,7 +199,7 @@ def random_position_over_time(
199
199
  POS = jnp.zeros((int(T // t_min) + 1, 2))
200
200
  POS = POS.at[0, 1].set(POS_0)
201
201
 
202
- val_outer = (1, 0.0, 0.0, 0.0, 0.0, key, POS)
202
+ val_outer = (1, 0.0, 0.0, POS_0, POS_0, key, POS)
203
203
  end, *_, consume, POS = jax.lax.while_loop(cond_fn_outer, body_fn_outer, val_outer)
204
204
  POS = jnp.where(
205
205
  (jnp.arange(len(POS)) < end)[:, None],
@@ -190,11 +190,11 @@ def _spring_force(sys: base.System, q: jax.Array):
190
190
 
191
191
  def _calc_spring_force_per_link(_, __, q, zeropoint, typ):
192
192
  # cor is (free, p3d) stacked; free is (spherical, p3d) stacked
193
- if typ in ["free", "cor"]:
193
+ if base.System.joint_type_is_free_or_cor(typ):
194
194
  quat_force = _quaternion_spring_force(zeropoint[:4], q[:4])
195
195
  pos_force = zeropoint[4:] - q[4:]
196
196
  q_spring_force_link = jnp.concatenate((quat_force, pos_force))
197
- elif typ == "spherical":
197
+ elif base.System.joint_type_is_spherical(typ):
198
198
  q_spring_force_link = _quaternion_spring_force(zeropoint, q)
199
199
  else:
200
200
  q_spring_force_link = zeropoint - q
@@ -268,11 +268,11 @@ def _semi_implicit_euler_integration(
268
268
  q_next = []
269
269
 
270
270
  def q_integrate(_, __, q, qd, typ):
271
- if typ in ["free", "cor"]:
271
+ if sys.joint_type_is_free_or_cor(typ):
272
272
  quat_next = _strapdown_integration(q[:4], qd[:3], sys.dt)
273
273
  pos_next = q[4:] + qd[3:] * sys.dt
274
274
  q_next_i = jnp.concatenate((quat_next, pos_next))
275
- elif typ == "spherical":
275
+ elif sys.joint_type_is_spherical(typ):
276
276
  quat_next = _strapdown_integration(q, qd, sys.dt)
277
277
  q_next_i = quat_next
278
278
  else:
@@ -1,3 +1,4 @@
1
+ from dataclasses import replace
1
2
  from functools import partial
2
3
  import random
3
4
  from typing import Callable, Optional
@@ -34,6 +35,8 @@ class RCMG:
34
35
  add_y_relpose: bool = False,
35
36
  add_y_rootincl: bool = False,
36
37
  add_y_rootincl_kwargs: dict = dict(),
38
+ add_y_rootfull: bool = False,
39
+ add_y_rootfull_kwargs: dict = dict(),
37
40
  sys_ml: Optional[base.System] = None,
38
41
  randomize_positions: bool = False,
39
42
  randomize_motion_artifacts: bool = False,
@@ -73,6 +76,8 @@ class RCMG:
73
76
  add_y_relpose=add_y_relpose,
74
77
  add_y_rootincl=add_y_rootincl,
75
78
  add_y_rootincl_kwargs=add_y_rootincl_kwargs,
79
+ add_y_rootfull=add_y_rootfull,
80
+ add_y_rootfull_kwargs=add_y_rootfull_kwargs,
76
81
  sys_ml=sys_ml,
77
82
  randomize_positions=randomize_positions,
78
83
  randomize_motion_artifacts=randomize_motion_artifacts,
@@ -279,6 +284,8 @@ def _build_mconfig_batched_generator(
279
284
  add_y_relpose: bool,
280
285
  add_y_rootincl: bool,
281
286
  add_y_rootincl_kwargs: dict,
287
+ add_y_rootfull: bool,
288
+ add_y_rootfull_kwargs: dict,
282
289
  sys_ml: base.System,
283
290
  randomize_positions: bool,
284
291
  randomize_motion_artifacts: bool,
@@ -365,7 +372,11 @@ def _build_mconfig_batched_generator(
365
372
  if add_y_relpose:
366
373
  pipe.append(finalize_fns.RelPose(sys_noimu))
367
374
  if add_y_rootincl:
375
+ assert not add_y_rootfull
368
376
  pipe.append(finalize_fns.RootIncl(sys_noimu, **add_y_rootincl_kwargs))
377
+ if add_y_rootfull:
378
+ assert not add_y_rootincl
379
+ pipe.append(finalize_fns.RootFull(sys_noimu, **add_y_rootfull_kwargs))
369
380
  if use_link_number_in_Xy:
370
381
  pipe.append(finalize_fns.Names2Indices(sys_noimu))
371
382
 
@@ -436,7 +447,15 @@ def draw_random_q(
436
447
  draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
437
448
  if draw_fn is None:
438
449
  raise Exception(f"The joint type {link_type} has no draw fn specified.")
439
- q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
450
+
451
+ if link_type in config.joint_type_specific_overwrites:
452
+ _config = replace(
453
+ config, **config.joint_type_specific_overwrites[link_type]
454
+ )
455
+ else:
456
+ _config = config
457
+
458
+ q_link = draw_fn(_config, key_t, key_value, sys.dt, N, joint_params)
440
459
  # even revolute and prismatic joints must be 2d arrays
441
460
  q_link = q_link if q_link.ndim == 2 else q_link[:, None]
442
461
  q_list.append(q_link)
@@ -88,6 +88,18 @@ class RootIncl:
88
88
  return (X, y), (key, q, x, sys_x)
89
89
 
90
90
 
91
+ class RootFull:
92
+ def __init__(self, sys: base.System, **kwargs):
93
+ self.sys = sys
94
+ self.kwargs = kwargs
95
+
96
+ def __call__(self, Xy, extras):
97
+ (X, y), (key, q, x, sys_x) = Xy, extras
98
+ y_root_incl = sensors.root_full(self.sys, x, sys_x, **self.kwargs)
99
+ y = utils.dict_union(y, y_root_incl)
100
+ return (X, y), (key, q, x, sys_x)
101
+
102
+
91
103
  _default_imu_kwargs = dict(
92
104
  noisy=True,
93
105
  low_pass_filter_pos_f_cutoff=13.5,
ring/algorithms/jcalc.py CHANGED
@@ -40,6 +40,12 @@ class MotionConfig:
40
40
  dpos_max: float | TimeDependentFloat = 0.7
41
41
  pos_min: float | TimeDependentFloat = -2.5
42
42
  pos_max: float | TimeDependentFloat = +2.5
43
+ pos_min_p3d_x: float | TimeDependentFloat = -2.5
44
+ pos_max_p3d_x: float | TimeDependentFloat = +2.5
45
+ pos_min_p3d_y: float | TimeDependentFloat = -2.5
46
+ pos_max_p3d_y: float | TimeDependentFloat = +2.5
47
+ pos_min_p3d_z: float | TimeDependentFloat = -2.5
48
+ pos_max_p3d_z: float | TimeDependentFloat = +2.5
43
49
 
44
50
  # used by both `random_angle_*` and `random_pos_*`
45
51
  # only used if `randomized_interpolation` is set
@@ -59,6 +65,12 @@ class MotionConfig:
59
65
  ang0_max: float = jnp.pi
60
66
  pos0_min: float = 0.0
61
67
  pos0_max: float = 0.0
68
+ pos0_min_p3d_x: float = 0.0
69
+ pos0_max_p3d_x: float = 0.0
70
+ pos0_min_p3d_y: float = 0.0
71
+ pos0_max_p3d_y: float = 0.0
72
+ pos0_min_p3d_z: float = 0.0
73
+ pos0_max_p3d_z: float = 0.0
62
74
 
63
75
  # cor (center of rotation) custom fields
64
76
  cor_t_min: float = 0.2
@@ -67,6 +79,14 @@ class MotionConfig:
67
79
  cor_dpos_max: float | TimeDependentFloat = 0.5
68
80
  cor_pos_min: float | TimeDependentFloat = -0.4
69
81
  cor_pos_max: float | TimeDependentFloat = 0.4
82
+ cor_pos0_min: float = 0.0
83
+ cor_pos0_max: float = 0.0
84
+
85
+ # specify changes for this motionconfig and for specific joint types
86
+ # map of `link_types` -> dictionary of changes
87
+ joint_type_specific_overwrites: dict[str, dict[str, Any]] = field(
88
+ default_factory=lambda: dict()
89
+ )
70
90
 
71
91
  def is_feasible(self) -> bool:
72
92
  return _is_feasible_config1(self)
@@ -92,6 +112,9 @@ class MotionConfig:
92
112
  def overwrite_for_joint_type(joint_type: str, **changes) -> None:
93
113
  """Changes values of the `MotionConfig` used by the draw_fn for only a specific
94
114
  joint.
115
+ !!! Note
116
+ This applies these changes to *all* MotionConfigs for this joint type!
117
+ This takes precedence *over* `Motionconfig.joint_type_specific_overwrites`!
95
118
  """
96
119
  previous_changes = _overwrite_for_joint_type_changes[joint_type]
97
120
  for change in changes:
@@ -113,6 +136,56 @@ class MotionConfig:
113
136
  overwrite=True,
114
137
  )
115
138
 
139
+ @staticmethod
140
+ def overwrite_for_subsystem(
141
+ sys: base.System, link_name: str, **changes
142
+ ) -> base.System:
143
+ """Modifies motionconfig of all joints in subsystem with root `link_name`.
144
+ Note that if the subsystem contains a free joint then the jointtype will
145
+ will be re-named to `free_<link_name>`, then the RCMG flag `cor` will
146
+ potentially not work as expected because it searches for all joints of
147
+ type `free` to replace with `cor`. The workaround here is to change the
148
+ type already from `free` to `cor in the xml file.
149
+ This takes precedence *over* `Motionconfig.joint_type_specific_overwrites`!
150
+
151
+ Args:
152
+ sys (base.System): System object that gets updated
153
+ link_name (str): Root node of subsystem
154
+ changes: Changes to apply to the motionconfig
155
+
156
+ Return:
157
+ base.System: Updated system with new jointtypes
158
+ """
159
+ from ring.algorithms.generator.finalize_fns import _P_gains
160
+
161
+ # all bodies in the subsystem
162
+ bodies = sys.findall_bodies_subsystem(link_name) + [sys.name_to_idx(link_name)]
163
+
164
+ jts_subsys = set([sys.link_types[i] for i in bodies]) - set(["frozen"])
165
+ postfix = "_" + link_name
166
+ # create new joint types with updated motionconfig
167
+ for typ in jts_subsys:
168
+ register_new_joint_type(
169
+ typ + postfix,
170
+ get_joint_model(typ),
171
+ base.Q_WIDTHS[typ],
172
+ base.QD_WIDTHS[typ],
173
+ )
174
+ MotionConfig.overwrite_for_joint_type(typ + postfix, **changes)
175
+ _P_gains[typ + postfix] = _P_gains[typ]
176
+
177
+ # rename all jointtypes
178
+ new_link_types = [
179
+ (
180
+ sys.link_types[i] + postfix
181
+ if (i in bodies and sys.link_types[i] != "frozen")
182
+ else sys.link_types[i]
183
+ )
184
+ for i in range(sys.num_links())
185
+ ]
186
+ sys = sys.replace(link_types=new_link_types)
187
+ return sys
188
+
116
189
  @staticmethod
117
190
  def from_register(name: str) -> "MotionConfig":
118
191
  return _registered_motion_configs[name]
@@ -221,6 +294,37 @@ _registered_motion_configs = {
221
294
  }
222
295
 
223
296
 
297
+ def _joint_specific_overwrites_free_cor(
298
+ id: str, dang: float, dpos: float
299
+ ) -> MotionConfig:
300
+ changes = dict(
301
+ dang_max_free_spherical=dang,
302
+ dpos_max=dpos,
303
+ cor_dpos_max=dpos,
304
+ t_min=1.5,
305
+ t_max=15.0,
306
+ )
307
+ return replace(
308
+ _registered_motion_configs[id],
309
+ joint_type_specific_overwrites=dict(free=changes, cor=changes),
310
+ )
311
+
312
+
313
+ _registered_motion_configs.update(
314
+ {
315
+ f"{id}-S": _joint_specific_overwrites_free_cor(id, 0.2, 0.1)
316
+ for id in ["expSlow", "expFast", "hinUndHer", "standard"]
317
+ }
318
+ )
319
+ _registered_motion_configs.update(
320
+ {
321
+ f"{id}-S+": _joint_specific_overwrites_free_cor(id, 0.1, 0.05)
322
+ for id in ["expSlow", "expFast", "hinUndHer", "standard"]
323
+ }
324
+ )
325
+ del _joint_specific_overwrites_free_cor
326
+
327
+
224
328
  def _is_feasible_config1(c: MotionConfig) -> bool:
225
329
  t_min, t_max = c.t_min, _to_float(c.t_max, 0.0)
226
330
 
@@ -254,8 +358,26 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
254
358
  cond2 = inside_box_checks(
255
359
  _to_float(c.pos_min, 0.0), _to_float(c.pos_max, 0.0), c.pos0_min, c.pos0_max
256
360
  )
361
+ cond3 = inside_box_checks(
362
+ _to_float(c.pos_min_p3d_x, 0.0),
363
+ _to_float(c.pos_max_p3d_x, 0.0),
364
+ c.pos0_min_p3d_x,
365
+ c.pos0_max_p3d_x,
366
+ )
367
+ cond4 = inside_box_checks(
368
+ _to_float(c.pos_min_p3d_y, 0.0),
369
+ _to_float(c.pos_max_p3d_y, 0.0),
370
+ c.pos0_min_p3d_y,
371
+ c.pos0_max_p3d_y,
372
+ )
373
+ cond5 = inside_box_checks(
374
+ _to_float(c.pos_min_p3d_z, 0.0),
375
+ _to_float(c.pos_max_p3d_z, 0.0),
376
+ c.pos0_min_p3d_z,
377
+ c.pos0_max_p3d_z,
378
+ )
257
379
 
258
- return cond1 and cond2
380
+ return cond1 and cond2 and cond3 and cond4 and cond5
259
381
 
260
382
 
261
383
  def _find_interval(t: jax.Array, boundaries: jax.Array):
@@ -504,7 +626,11 @@ def _draw_pxyz(
504
626
  cor: bool = False,
505
627
  ) -> jax.Array:
506
628
  key_value, consume = jax.random.split(key_value)
507
- POS_0 = jax.random.uniform(consume, minval=config.pos0_min, maxval=config.pos0_max)
629
+ POS_0 = jax.random.uniform(
630
+ consume,
631
+ minval=config.cor_pos0_min if cor else config.pos0_min,
632
+ maxval=config.cor_pos0_max if cor else config.pos0_max,
633
+ )
508
634
  max_iter = 100
509
635
  return _random.random_position_over_time(
510
636
  key_value,
@@ -590,10 +716,27 @@ def _draw_p3d_and_cor(
590
716
  __: jax.Array,
591
717
  cor: bool,
592
718
  ) -> jax.Array:
593
- pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
594
- jax.random.split(key_value, 3)
595
- )
596
- return pos.T
719
+ keys = jax.random.split(key_value, 3)
720
+
721
+ def draw(key, xyz: str):
722
+ return _draw_pxyz(
723
+ replace(
724
+ config,
725
+ pos_min=getattr(config, f"pos_min_p3d_{xyz}"),
726
+ pos_max=getattr(config, f"pos_max_p3d_{xyz}"),
727
+ pos0_min=getattr(config, f"pos0_min_p3d_{xyz}"),
728
+ pos0_max=getattr(config, f"pos0_max_p3d_{xyz}"),
729
+ ),
730
+ None,
731
+ key,
732
+ dt,
733
+ N,
734
+ None,
735
+ cor,
736
+ )[:, None]
737
+
738
+ px, py, pz = draw(keys[0], "x"), draw(keys[1], "y"), draw(keys[2], "z")
739
+ return jnp.concat((px, py, pz), axis=-1)
597
740
 
598
741
 
599
742
  def _draw_p3d(
ring/base.py CHANGED
@@ -7,6 +7,7 @@ from jax.core import Tracer
7
7
  import jax.numpy as jnp
8
8
  from jax.tree_util import tree_map
9
9
  import numpy as np
10
+ import tree
10
11
  import tree_utils as tu
11
12
 
12
13
  import ring
@@ -590,6 +591,34 @@ class System(_Base):
590
591
 
591
592
  return sys
592
593
 
594
+ @staticmethod
595
+ def joint_type_simplification(typ: str) -> str:
596
+ if typ[:4] == "free":
597
+ if typ == "free_2d":
598
+ return "free_2d"
599
+ else:
600
+ return "free"
601
+ elif typ[:3] == "cor":
602
+ return "cor"
603
+ elif typ[:9] == "spherical":
604
+ return "spherical"
605
+ else:
606
+ return typ
607
+
608
+ @staticmethod
609
+ def joint_type_is_free_or_cor(typ: str) -> bool:
610
+ return System.joint_type_simplification(typ) in ["free", "cor"]
611
+
612
+ @staticmethod
613
+ def joint_type_is_spherical(typ: str) -> bool:
614
+ return System.joint_type_simplification(typ) == "spherical"
615
+
616
+ @staticmethod
617
+ def joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
618
+ return System.joint_type_is_free_or_cor(typ) or System.joint_type_is_spherical(
619
+ typ
620
+ )
621
+
593
622
  def findall_imus(self, names: bool = True) -> list[str] | list[int]:
594
623
  bodies = [name for name in self.link_names if name[:3] == "imu"]
595
624
  return bodies if names else [self.name_to_idx(n) for n in bodies]
@@ -618,10 +647,20 @@ class System(_Base):
618
647
  return self._bodies_indices_to_bodies_name(bodies) if names else bodies
619
648
 
620
649
  def children(self, name: str, names: bool = False) -> list[int] | list[str]:
650
+ "List all direct children of body, does not include body itself"
621
651
  p = self.name_to_idx(name)
622
652
  bodies = [i for i in range(self.num_links()) if self.link_parents[i] == p]
623
653
  return bodies if (not names) else [self.idx_to_name(i) for i in bodies]
624
654
 
655
+ def findall_bodies_subsystem(
656
+ self, name: str, names: bool = False
657
+ ) -> list[int] | list[str]:
658
+ "List all children and children's children; does not include body itself"
659
+ children = self.children(name, names=True)
660
+ grandchildren = [self.findall_bodies_subsystem(n, names=True) for n in children]
661
+ bodies = tree.flatten([children, grandchildren])
662
+ return bodies if names else [self.name_to_idx(n) for n in bodies]
663
+
625
664
  def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):
626
665
  """Scan `f` along each link in system whilst carrying along state.
627
666
 
@@ -889,7 +928,9 @@ def _parse_system(sys: System) -> System:
889
928
  assert d.size == a.size == s.size == qd_size, error_msg
890
929
  assert z.size == q_size, error_msg
891
930
 
892
- if typ in ["spherical", "free", "cor"] and not isinstance(z, Tracer):
931
+ if System.joint_type_is_free_or_cor_or_spherical(typ) and not isinstance(
932
+ z, Tracer
933
+ ):
893
934
  assert jnp.allclose(
894
935
  jnp.linalg.norm(z[:4]), 1.0
895
936
  ), f"not unit quat for link `{name}` of typ `{typ}` in model"
@@ -1030,7 +1071,7 @@ class State(_Base):
1030
1071
  def replace_by_unit_quat(_, idx_map, link_typ, link_idx):
1031
1072
  nonlocal q
1032
1073
 
1033
- if link_typ in ["free", "cor", "spherical"]:
1074
+ if sys.joint_type_is_free_or_cor_or_spherical(link_typ):
1034
1075
  q_idxs_link = idx_map["q"](link_idx)
1035
1076
  q = q.at[q_idxs_link.start].set(1.0)
1036
1077
 
ring/io/xml/abstract.py CHANGED
@@ -3,6 +3,7 @@ from typing import Tuple, TypeVar
3
3
  import jax
4
4
  import jax.numpy as jnp
5
5
  import numpy as np
6
+
6
7
  from ring import base
7
8
 
8
9
  T = TypeVar("T")
@@ -17,7 +18,7 @@ default_stiffness = lambda qd_size, **_: jnp.zeros((qd_size,))
17
18
 
18
19
  def default_zeropoint(q_size, link_typ: str, **_):
19
20
  zeropoint = jnp.zeros((q_size))
20
- if link_typ in ["spherical", "free", "cor"]:
21
+ if base.System.joint_type_is_free_or_cor_or_spherical(link_typ):
21
22
  # zeropoint then is unit quaternion and not zeros
22
23
  zeropoint = zeropoint.at[0].set(1.0)
23
24
  return zeropoint
ring/ml/train.py CHANGED
@@ -167,7 +167,10 @@ def train_fn(
167
167
  tbp=tbp,
168
168
  )
169
169
 
170
- default_callbacks = []
170
+ # always log, because we also want `i_epsiode` to be logged in wandb
171
+ default_callbacks = [
172
+ ml_callbacks.LogEpisodeTrainingLoopCallback(callback_kill_after_episode)
173
+ ]
171
174
  if metrices is not None:
172
175
  eval_fn = _build_eval_fn(metrices, filter, link_names)
173
176
  default_callbacks.append(_DefaultEvalFnCallback(eval_fn))
@@ -192,11 +195,6 @@ def train_fn(
192
195
  if callback_kill_if_nan:
193
196
  default_callbacks.append(ml_callbacks.NanKillRunCallback())
194
197
 
195
- # always log, because we also want `i_epsiode` to be logged in wandb
196
- default_callbacks.append(
197
- ml_callbacks.LogEpisodeTrainingLoopCallback(callback_kill_after_episode)
198
- )
199
-
200
198
  if callback_kill_after_seconds is not None:
201
199
  default_callbacks.append(
202
200
  ml_callbacks.TimingKillRunCallback(callback_kill_after_seconds)