imt-ring 1.6.21__py3-none-any.whl → 1.6.22__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.21
3
+ Version: 1.6.22
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,12 +1,12 @@
1
1
  ring/__init__.py,sha256=ncBRdHvge6uqpzFyk8_HUQNx4kZxENpVXTJTEK_SjCg,5216
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=Ystn1EjTyOXBhVm5koroV_YPUYtFxrteJLd-XR3kEL8,33840
3
+ ring/base.py,sha256=yPdbPywwDllCRsJEbnLW4s9Z-bBD8qdxpEDYV3pCLP8,35296
4
4
  ring/maths.py,sha256=qPHH6TpHCK3TgExI98gNEySoSRKOwteN9McUlyUFipI,12207
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
7
  ring/algorithms/_random.py,sha256=6QwV7EAE2ckDi7rSGXT1XcPhR5R-NU-J8HVa5DoPBSQ,14011
8
- ring/algorithms/dynamics.py,sha256=GOedL1STj6oXcXgMA7dB4PabvCQxPBbirJQhXBRuKqE,10929
9
- ring/algorithms/jcalc.py,sha256=kQgrRE0XoUBrcdeFbUw_xKnL4m1P2G1Q1m4n7BX2yDk,30064
8
+ ring/algorithms/dynamics.py,sha256=dpe-F3Yq4sY2dY6DQW3v7TnPLRdxdkePtdbGPQIrijg,10997
9
+ ring/algorithms/jcalc.py,sha256=pniA8uZutN6H_4sP_YxHbGfW-mWEEUimHGkwLC0fgms,35049
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=0xOzdQIc1kBF0CkoPXWWCx3MmV4SG3wj7knVnnMWq9M,18124
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=3pQ-Is_HBTQDkzESCNg9VfoP8wvseWmooryG8ERnu_A,366
@@ -15,7 +15,7 @@ ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXp
15
15
  ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
16
16
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
17
17
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
18
- ring/algorithms/generator/base.py,sha256=Us25eMLQwB6rIVZuq1afwe2JqGSA8yKBIE9tr87HLfk,16270
18
+ ring/algorithms/generator/base.py,sha256=nFVUcFl7AILnyjuXf_YQJxI12MB4TzJ1SRFREKzqv8Q,16531
19
19
  ring/algorithms/generator/batch.py,sha256=9yFxVv11hij-fJXGPxA3zEh1bE2_jrZk0R7kyGaiM5c,2551
20
20
  ring/algorithms/generator/finalize_fns.py,sha256=nY2RKiLbHriTkdec94lc4UGSZKd0v547MDNn4dr8I3E,10398
21
21
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
@@ -46,7 +46,7 @@ ring/io/examples/exclude/standard_sys_rr_imp.xml,sha256=1K8aLe4n97qFMQaytXdHVHDf
46
46
  ring/io/examples/test_morph_system/four_seg_seg1.xml,sha256=XJvGtEnvedejs_OmCVfQULWJNK8MLDQQ3raqPNRCJbA,1283
47
47
  ring/io/examples/test_morph_system/four_seg_seg3.xml,sha256=HktN7_a_Ly3YflWit5W-WncxApWGMORAGnRXyMEqnoA,1265
48
48
  ring/io/xml/__init__.py,sha256=-3k6ffvFyc4zm0oTyVz3ez-o3Lb9bPp2sjwSub_K1AA,242
49
- ring/io/xml/abstract.py,sha256=ojsXgz15J4pI1FThUIZbB_Iw1wmR9cOHt5Thmcbih4I,9721
49
+ ring/io/xml/abstract.py,sha256=8Q2ebnUYLmuS9HJAQwDVrDTrRfD5z4G5RAB7MW8Oa60,9742
50
50
  ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
51
51
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
52
52
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
86
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.21.dist-info/METADATA,sha256=A7R1aabkG59Lz2ANQI7eOdFSAAeOPeAfvhgsAKzLw7w,4097
90
- imt_ring-1.6.21.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
91
- imt_ring-1.6.21.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.21.dist-info/RECORD,,
89
+ imt_ring-1.6.22.dist-info/METADATA,sha256=EHw8XK0JO8d6lQ-PiknLrM7V5ATyLkniMzBFydE6NA4,4097
90
+ imt_ring-1.6.22.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
91
+ imt_ring-1.6.22.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.22.dist-info/RECORD,,
@@ -190,11 +190,11 @@ def _spring_force(sys: base.System, q: jax.Array):
190
190
 
191
191
  def _calc_spring_force_per_link(_, __, q, zeropoint, typ):
192
192
  # cor is (free, p3d) stacked; free is (spherical, p3d) stacked
193
- if typ in ["free", "cor"]:
193
+ if base.System.joint_type_is_free_or_cor(typ):
194
194
  quat_force = _quaternion_spring_force(zeropoint[:4], q[:4])
195
195
  pos_force = zeropoint[4:] - q[4:]
196
196
  q_spring_force_link = jnp.concatenate((quat_force, pos_force))
197
- elif typ == "spherical":
197
+ elif base.System.joint_type_is_spherical(typ):
198
198
  q_spring_force_link = _quaternion_spring_force(zeropoint, q)
199
199
  else:
200
200
  q_spring_force_link = zeropoint - q
@@ -268,11 +268,11 @@ def _semi_implicit_euler_integration(
268
268
  q_next = []
269
269
 
270
270
  def q_integrate(_, __, q, qd, typ):
271
- if typ in ["free", "cor"]:
271
+ if sys.joint_type_is_free_or_cor(typ):
272
272
  quat_next = _strapdown_integration(q[:4], qd[:3], sys.dt)
273
273
  pos_next = q[4:] + qd[3:] * sys.dt
274
274
  q_next_i = jnp.concatenate((quat_next, pos_next))
275
- elif typ == "spherical":
275
+ elif sys.joint_type_is_spherical(typ):
276
276
  quat_next = _strapdown_integration(q, qd, sys.dt)
277
277
  q_next_i = quat_next
278
278
  else:
@@ -1,3 +1,4 @@
1
+ from dataclasses import replace
1
2
  from functools import partial
2
3
  import random
3
4
  from typing import Callable, Optional
@@ -446,7 +447,15 @@ def draw_random_q(
446
447
  draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
447
448
  if draw_fn is None:
448
449
  raise Exception(f"The joint type {link_type} has no draw fn specified.")
449
- q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
450
+
451
+ if link_type in config.joint_type_specific_overwrites:
452
+ _config = replace(
453
+ config, **config.joint_type_specific_overwrites[link_type]
454
+ )
455
+ else:
456
+ _config = config
457
+
458
+ q_link = draw_fn(_config, key_t, key_value, sys.dt, N, joint_params)
450
459
  # even revolute and prismatic joints must be 2d arrays
451
460
  q_link = q_link if q_link.ndim == 2 else q_link[:, None]
452
461
  q_list.append(q_link)
ring/algorithms/jcalc.py CHANGED
@@ -40,6 +40,12 @@ class MotionConfig:
40
40
  dpos_max: float | TimeDependentFloat = 0.7
41
41
  pos_min: float | TimeDependentFloat = -2.5
42
42
  pos_max: float | TimeDependentFloat = +2.5
43
+ pos_min_p3d_x: float | TimeDependentFloat = -2.5
44
+ pos_max_p3d_x: float | TimeDependentFloat = +2.5
45
+ pos_min_p3d_y: float | TimeDependentFloat = -2.5
46
+ pos_max_p3d_y: float | TimeDependentFloat = +2.5
47
+ pos_min_p3d_z: float | TimeDependentFloat = -2.5
48
+ pos_max_p3d_z: float | TimeDependentFloat = +2.5
43
49
 
44
50
  # used by both `random_angle_*` and `random_pos_*`
45
51
  # only used if `randomized_interpolation` is set
@@ -59,6 +65,12 @@ class MotionConfig:
59
65
  ang0_max: float = jnp.pi
60
66
  pos0_min: float = 0.0
61
67
  pos0_max: float = 0.0
68
+ pos0_min_p3d_x: float = 0.0
69
+ pos0_max_p3d_x: float = 0.0
70
+ pos0_min_p3d_y: float = 0.0
71
+ pos0_max_p3d_y: float = 0.0
72
+ pos0_min_p3d_z: float = 0.0
73
+ pos0_max_p3d_z: float = 0.0
62
74
 
63
75
  # cor (center of rotation) custom fields
64
76
  cor_t_min: float = 0.2
@@ -67,6 +79,14 @@ class MotionConfig:
67
79
  cor_dpos_max: float | TimeDependentFloat = 0.5
68
80
  cor_pos_min: float | TimeDependentFloat = -0.4
69
81
  cor_pos_max: float | TimeDependentFloat = 0.4
82
+ cor_pos0_min: float = 0.0
83
+ cor_pos0_max: float = 0.0
84
+
85
+ # specify changes for this motionconfig and for specific joint types
86
+ # map of `link_types` -> dictionary of changes
87
+ joint_type_specific_overwrites: dict[str, dict[str, Any]] = field(
88
+ default_factory=lambda: dict()
89
+ )
70
90
 
71
91
  def is_feasible(self) -> bool:
72
92
  return _is_feasible_config1(self)
@@ -92,6 +112,9 @@ class MotionConfig:
92
112
  def overwrite_for_joint_type(joint_type: str, **changes) -> None:
93
113
  """Changes values of the `MotionConfig` used by the draw_fn for only a specific
94
114
  joint.
115
+ !!! Note
116
+ This applies these changes to *all* MotionConfigs for this joint type!
117
+ This takes precedence *over* `Motionconfig.joint_type_specific_overwrites`!
95
118
  """
96
119
  previous_changes = _overwrite_for_joint_type_changes[joint_type]
97
120
  for change in changes:
@@ -113,6 +136,56 @@ class MotionConfig:
113
136
  overwrite=True,
114
137
  )
115
138
 
139
+ @staticmethod
140
+ def overwrite_for_subsystem(
141
+ sys: base.System, link_name: str, **changes
142
+ ) -> base.System:
143
+ """Modifies motionconfig of all joints in subsystem with root `link_name`.
144
+ Note that if the subsystem contains a free joint then the jointtype will
145
+ will be re-named to `free_<link_name>`, then the RCMG flag `cor` will
146
+ potentially not work as expected because it searches for all joints of
147
+ type `free` to replace with `cor`. The workaround here is to change the
148
+ type already from `free` to `cor in the xml file.
149
+ This takes precedence *over* `Motionconfig.joint_type_specific_overwrites`!
150
+
151
+ Args:
152
+ sys (base.System): System object that gets updated
153
+ link_name (str): Root node of subsystem
154
+ changes: Changes to apply to the motionconfig
155
+
156
+ Return:
157
+ base.System: Updated system with new jointtypes
158
+ """
159
+ from ring.algorithms.generator.finalize_fns import _P_gains
160
+
161
+ # all bodies in the subsystem
162
+ bodies = sys.findall_bodies_subsystem(link_name) + [sys.name_to_idx(link_name)]
163
+
164
+ jts_subsys = set([sys.link_types[i] for i in bodies]) - set(["frozen"])
165
+ postfix = "_" + link_name
166
+ # create new joint types with updated motionconfig
167
+ for typ in jts_subsys:
168
+ register_new_joint_type(
169
+ typ + postfix,
170
+ get_joint_model(typ),
171
+ base.Q_WIDTHS[typ],
172
+ base.QD_WIDTHS[typ],
173
+ )
174
+ MotionConfig.overwrite_for_joint_type(typ + postfix, **changes)
175
+ _P_gains[typ + postfix] = _P_gains[typ]
176
+
177
+ # rename all jointtypes
178
+ new_link_types = [
179
+ (
180
+ sys.link_types[i] + postfix
181
+ if (i in bodies and sys.link_types[i] != "frozen")
182
+ else sys.link_types[i]
183
+ )
184
+ for i in range(sys.num_links())
185
+ ]
186
+ sys = sys.replace(link_types=new_link_types)
187
+ return sys
188
+
116
189
  @staticmethod
117
190
  def from_register(name: str) -> "MotionConfig":
118
191
  return _registered_motion_configs[name]
@@ -221,6 +294,37 @@ _registered_motion_configs = {
221
294
  }
222
295
 
223
296
 
297
+ def _joint_specific_overwrites_free_cor(
298
+ id: str, dang: float, dpos: float
299
+ ) -> MotionConfig:
300
+ changes = dict(
301
+ dang_max_free_spherical=dang,
302
+ dpos_max=dpos,
303
+ cor_dpos_max=dpos,
304
+ t_min=1.5,
305
+ t_max=15.0,
306
+ )
307
+ return replace(
308
+ _registered_motion_configs[id],
309
+ joint_type_specific_overwrites=dict(free=changes, cor=changes),
310
+ )
311
+
312
+
313
+ _registered_motion_configs.update(
314
+ {
315
+ f"{id}-S": _joint_specific_overwrites_free_cor(id, 0.2, 0.1)
316
+ for id in ["expSlow", "expFast", "hinUndHer", "standard"]
317
+ }
318
+ )
319
+ _registered_motion_configs.update(
320
+ {
321
+ f"{id}-S+": _joint_specific_overwrites_free_cor(id, 0.1, 0.05)
322
+ for id in ["expSlow", "expFast", "hinUndHer", "standard"]
323
+ }
324
+ )
325
+ del _joint_specific_overwrites_free_cor
326
+
327
+
224
328
  def _is_feasible_config1(c: MotionConfig) -> bool:
225
329
  t_min, t_max = c.t_min, _to_float(c.t_max, 0.0)
226
330
 
@@ -254,8 +358,26 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
254
358
  cond2 = inside_box_checks(
255
359
  _to_float(c.pos_min, 0.0), _to_float(c.pos_max, 0.0), c.pos0_min, c.pos0_max
256
360
  )
361
+ cond3 = inside_box_checks(
362
+ _to_float(c.pos_min_p3d_x, 0.0),
363
+ _to_float(c.pos_max_p3d_x, 0.0),
364
+ c.pos0_min_p3d_x,
365
+ c.pos0_max_p3d_x,
366
+ )
367
+ cond4 = inside_box_checks(
368
+ _to_float(c.pos_min_p3d_y, 0.0),
369
+ _to_float(c.pos_max_p3d_y, 0.0),
370
+ c.pos0_min_p3d_y,
371
+ c.pos0_max_p3d_y,
372
+ )
373
+ cond5 = inside_box_checks(
374
+ _to_float(c.pos_min_p3d_z, 0.0),
375
+ _to_float(c.pos_max_p3d_z, 0.0),
376
+ c.pos0_min_p3d_z,
377
+ c.pos0_max_p3d_z,
378
+ )
257
379
 
258
- return cond1 and cond2
380
+ return cond1 and cond2 and cond3 and cond4 and cond5
259
381
 
260
382
 
261
383
  def _find_interval(t: jax.Array, boundaries: jax.Array):
@@ -504,7 +626,11 @@ def _draw_pxyz(
504
626
  cor: bool = False,
505
627
  ) -> jax.Array:
506
628
  key_value, consume = jax.random.split(key_value)
507
- POS_0 = jax.random.uniform(consume, minval=config.pos0_min, maxval=config.pos0_max)
629
+ POS_0 = jax.random.uniform(
630
+ consume,
631
+ minval=config.cor_pos0_min if cor else config.pos0_min,
632
+ maxval=config.cor_pos0_max if cor else config.pos0_max,
633
+ )
508
634
  max_iter = 100
509
635
  return _random.random_position_over_time(
510
636
  key_value,
@@ -590,10 +716,27 @@ def _draw_p3d_and_cor(
590
716
  __: jax.Array,
591
717
  cor: bool,
592
718
  ) -> jax.Array:
593
- pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
594
- jax.random.split(key_value, 3)
595
- )
596
- return pos.T
719
+ keys = jax.random.split(key_value, 3)
720
+
721
+ def draw(key, xyz: str):
722
+ return _draw_pxyz(
723
+ replace(
724
+ config,
725
+ pos_min=getattr(config, f"pos_min_p3d_{xyz}"),
726
+ pos_max=getattr(config, f"pos_max_p3d_{xyz}"),
727
+ pos0_min=getattr(config, f"pos0_min_p3d_{xyz}"),
728
+ pos0_max=getattr(config, f"pos0_max_p3d_{xyz}"),
729
+ ),
730
+ None,
731
+ key,
732
+ dt,
733
+ N,
734
+ None,
735
+ cor,
736
+ )[:, None]
737
+
738
+ px, py, pz = draw(keys[0], "x"), draw(keys[1], "y"), draw(keys[2], "z")
739
+ return jnp.concat((px, py, pz), axis=-1)
597
740
 
598
741
 
599
742
  def _draw_p3d(
ring/base.py CHANGED
@@ -7,6 +7,7 @@ from jax.core import Tracer
7
7
  import jax.numpy as jnp
8
8
  from jax.tree_util import tree_map
9
9
  import numpy as np
10
+ import tree
10
11
  import tree_utils as tu
11
12
 
12
13
  import ring
@@ -590,6 +591,34 @@ class System(_Base):
590
591
 
591
592
  return sys
592
593
 
594
+ @staticmethod
595
+ def joint_type_simplification(typ: str) -> str:
596
+ if typ[:4] == "free":
597
+ if typ == "free_2d":
598
+ return "free_2d"
599
+ else:
600
+ return "free"
601
+ elif typ[:3] == "cor":
602
+ return "cor"
603
+ elif typ[:9] == "spherical":
604
+ return "spherical"
605
+ else:
606
+ return typ
607
+
608
+ @staticmethod
609
+ def joint_type_is_free_or_cor(typ: str) -> bool:
610
+ return System.joint_type_simplification(typ) in ["free", "cor"]
611
+
612
+ @staticmethod
613
+ def joint_type_is_spherical(typ: str) -> bool:
614
+ return System.joint_type_simplification(typ) == "spherical"
615
+
616
+ @staticmethod
617
+ def joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
618
+ return System.joint_type_is_free_or_cor(typ) or System.joint_type_is_spherical(
619
+ typ
620
+ )
621
+
593
622
  def findall_imus(self, names: bool = True) -> list[str] | list[int]:
594
623
  bodies = [name for name in self.link_names if name[:3] == "imu"]
595
624
  return bodies if names else [self.name_to_idx(n) for n in bodies]
@@ -618,10 +647,20 @@ class System(_Base):
618
647
  return self._bodies_indices_to_bodies_name(bodies) if names else bodies
619
648
 
620
649
  def children(self, name: str, names: bool = False) -> list[int] | list[str]:
650
+ "List all direct children of body, does not include body itself"
621
651
  p = self.name_to_idx(name)
622
652
  bodies = [i for i in range(self.num_links()) if self.link_parents[i] == p]
623
653
  return bodies if (not names) else [self.idx_to_name(i) for i in bodies]
624
654
 
655
+ def findall_bodies_subsystem(
656
+ self, name: str, names: bool = False
657
+ ) -> list[int] | list[str]:
658
+ "List all children and children's children; does not include body itself"
659
+ children = self.children(name, names=True)
660
+ grandchildren = [self.findall_bodies_subsystem(n, names=True) for n in children]
661
+ bodies = tree.flatten([children, grandchildren])
662
+ return bodies if names else [self.name_to_idx(n) for n in bodies]
663
+
625
664
  def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):
626
665
  """Scan `f` along each link in system whilst carrying along state.
627
666
 
@@ -889,7 +928,9 @@ def _parse_system(sys: System) -> System:
889
928
  assert d.size == a.size == s.size == qd_size, error_msg
890
929
  assert z.size == q_size, error_msg
891
930
 
892
- if typ in ["spherical", "free", "cor"] and not isinstance(z, Tracer):
931
+ if System.joint_type_is_free_or_cor_or_spherical(typ) and not isinstance(
932
+ z, Tracer
933
+ ):
893
934
  assert jnp.allclose(
894
935
  jnp.linalg.norm(z[:4]), 1.0
895
936
  ), f"not unit quat for link `{name}` of typ `{typ}` in model"
@@ -1030,7 +1071,7 @@ class State(_Base):
1030
1071
  def replace_by_unit_quat(_, idx_map, link_typ, link_idx):
1031
1072
  nonlocal q
1032
1073
 
1033
- if link_typ in ["free", "cor", "spherical"]:
1074
+ if sys.joint_type_is_free_or_cor_or_spherical(link_typ):
1034
1075
  q_idxs_link = idx_map["q"](link_idx)
1035
1076
  q = q.at[q_idxs_link.start].set(1.0)
1036
1077
 
ring/io/xml/abstract.py CHANGED
@@ -3,6 +3,7 @@ from typing import Tuple, TypeVar
3
3
  import jax
4
4
  import jax.numpy as jnp
5
5
  import numpy as np
6
+
6
7
  from ring import base
7
8
 
8
9
  T = TypeVar("T")
@@ -17,7 +18,7 @@ default_stiffness = lambda qd_size, **_: jnp.zeros((qd_size,))
17
18
 
18
19
  def default_zeropoint(q_size, link_typ: str, **_):
19
20
  zeropoint = jnp.zeros((q_size))
20
- if link_typ in ["spherical", "free", "cor"]:
21
+ if base.System.joint_type_is_free_or_cor_or_spherical(link_typ):
21
22
  # zeropoint then is unit quaternion and not zeros
22
23
  zeropoint = zeropoint.at[0].set(1.0)
23
24
  return zeropoint