jaxsim 0.3.1.dev64__py3-none-any.whl → 0.3.1.dev113__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. jaxsim/__init__.py +5 -5
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/com.py +3 -4
  4. jaxsim/api/common.py +11 -11
  5. jaxsim/api/contact.py +11 -3
  6. jaxsim/api/data.py +3 -6
  7. jaxsim/api/frame.py +9 -10
  8. jaxsim/api/kin_dyn_parameters.py +25 -28
  9. jaxsim/api/link.py +12 -12
  10. jaxsim/api/model.py +47 -43
  11. jaxsim/api/ode.py +19 -12
  12. jaxsim/api/ode_data.py +11 -11
  13. jaxsim/integrators/common.py +17 -20
  14. jaxsim/integrators/fixed_step.py +10 -10
  15. jaxsim/integrators/variable_step.py +13 -13
  16. jaxsim/math/__init__.py +2 -1
  17. jaxsim/math/joint_model.py +2 -1
  18. jaxsim/math/quaternion.py +3 -9
  19. jaxsim/math/transform.py +2 -2
  20. jaxsim/mujoco/loaders.py +5 -5
  21. jaxsim/mujoco/model.py +6 -6
  22. jaxsim/mujoco/visualizer.py +3 -0
  23. jaxsim/parsers/__init__.py +0 -1
  24. jaxsim/parsers/descriptions/joint.py +1 -1
  25. jaxsim/parsers/descriptions/link.py +3 -4
  26. jaxsim/parsers/descriptions/model.py +1 -1
  27. jaxsim/parsers/kinematic_graph.py +38 -39
  28. jaxsim/parsers/rod/parser.py +14 -14
  29. jaxsim/parsers/rod/utils.py +9 -11
  30. jaxsim/rbda/aba.py +6 -12
  31. jaxsim/rbda/collidable_points.py +8 -7
  32. jaxsim/rbda/contacts/soft.py +29 -27
  33. jaxsim/rbda/crba.py +3 -3
  34. jaxsim/rbda/forward_kinematics.py +1 -1
  35. jaxsim/rbda/jacobian.py +8 -8
  36. jaxsim/rbda/rnea.py +3 -3
  37. jaxsim/rbda/utils.py +1 -1
  38. jaxsim/terrain/terrain.py +100 -22
  39. jaxsim/typing.py +21 -24
  40. jaxsim/utils/jaxsim_dataclass.py +4 -4
  41. jaxsim/utils/wrappers.py +5 -1
  42. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev113.dist-info}/METADATA +1 -1
  43. jaxsim-0.3.1.dev113.dist-info/RECORD +68 -0
  44. jaxsim-0.3.1.dev64.dist-info/RECORD +0 -68
  45. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev113.dist-info}/LICENSE +0 -0
  46. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev113.dist-info}/WHEEL +0 -0
  47. {jaxsim-0.3.1.dev64.dist-info → jaxsim-0.3.1.dev113.dist-info}/top_level.txt +0 -0
@@ -312,9 +312,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
312
312
 
313
313
  # Clip the estimated initial step size to the given bounds, if necessary.
314
314
  self.params["dt0"] = jnp.clip(
315
- a=self.params["dt0"],
316
- a_min=jnp.minimum(self.dt_min, self.params["dt0"]),
317
- a_max=jnp.minimum(self.dt_max, self.params["dt0"]),
315
+ self.params["dt0"],
316
+ jnp.minimum(self.dt_min, self.params["dt0"]),
317
+ jnp.minimum(self.dt_max, self.params["dt0"]),
318
318
  )
319
319
 
320
320
  # =========================================================
@@ -371,7 +371,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
371
371
 
372
372
  # Shrink the Δt every time by the safety factor (even when accepted).
373
373
  # The β parameters define the bounds of the timestep update factor.
374
- safety = jnp.clip(self.safety, a_min=0.0, a_max=1.0)
374
+ safety = jnp.clip(self.safety, 0.0, 1.0)
375
375
  β_min = jnp.maximum(0.0, self.beta_min)
376
376
  β_max = jnp.maximum(β_min, self.beta_max)
377
377
 
@@ -383,9 +383,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
383
383
  # In case of acceptance, Δt_next could either be larger than Δt0,
384
384
  # or slightly smaller than Δt0 depending on the safety factor.
385
385
  Δt_next = Δt0 * jnp.clip(
386
- a=safety * jnp.power(1 / local_error, 1 / (q + 1)),
387
- a_min=β_min,
388
- a_max=β_max,
386
+ safety * jnp.power(1 / local_error, 1 / (q + 1)),
387
+ β_min,
388
+ β_max,
389
389
  )
390
390
 
391
391
  def accept_step():
@@ -545,14 +545,14 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
545
545
  @jax_dataclasses.pytree_dataclass
546
546
  class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
547
547
 
548
- A: ClassVar[jax.typing.ArrayLike] = jnp.array(
548
+ A: ClassVar[jtp.Matrix] = jnp.array(
549
549
  [
550
550
  [0, 0],
551
551
  [1, 0],
552
552
  ]
553
553
  ).astype(float)
554
554
 
555
- b: ClassVar[jax.typing.ArrayLike] = (
555
+ b: ClassVar[jtp.Matrix] = (
556
556
  jnp.atleast_2d(
557
557
  jnp.array(
558
558
  [
@@ -565,7 +565,7 @@ class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
565
565
  .transpose()
566
566
  )
567
567
 
568
- c: ClassVar[jax.typing.ArrayLike] = jnp.array(
568
+ c: ClassVar[jtp.Vector] = jnp.array(
569
569
  [0, 1],
570
570
  ).astype(float)
571
571
 
@@ -578,7 +578,7 @@ class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
578
578
  @jax_dataclasses.pytree_dataclass
579
579
  class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
580
580
 
581
- A: ClassVar[jax.typing.ArrayLike] = jnp.array(
581
+ A: ClassVar[jtp.Matrix] = jnp.array(
582
582
  [
583
583
  [0, 0, 0, 0],
584
584
  [1 / 2, 0, 0, 0],
@@ -587,7 +587,7 @@ class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mi
587
587
  ]
588
588
  ).astype(float)
589
589
 
590
- b: ClassVar[jax.typing.ArrayLike] = (
590
+ b: ClassVar[jtp.Matrix] = (
591
591
  jnp.atleast_2d(
592
592
  jnp.array(
593
593
  [
@@ -600,7 +600,7 @@ class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mi
600
600
  .transpose()
601
601
  )
602
602
 
603
- c: ClassVar[jax.typing.ArrayLike] = jnp.array(
603
+ c: ClassVar[jtp.Vector] = jnp.array(
604
604
  [0, 1 / 2, 3 / 4, 1],
605
605
  ).astype(float)
606
606
 
jaxsim/math/__init__.py CHANGED
@@ -4,8 +4,9 @@ StandardGravity = 9.81
4
4
  from .adjoint import Adjoint
5
5
  from .cross import Cross
6
6
  from .inertia import Inertia
7
- from .joint_model import JointModel, supported_joint_motion
8
7
  from .quaternion import Quaternion
9
8
  from .rotation import Rotation
10
9
  from .skew import Skew
11
10
  from .transform import Transform
11
+
12
+ from .joint_model import JointModel, supported_joint_motion # isort:skip
@@ -11,6 +11,7 @@ from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescri
11
11
  from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms
12
12
 
13
13
  from .rotation import Rotation
14
+ from .transform import Transform
14
15
 
15
16
 
16
17
  @jax_dataclasses.pytree_dataclass
@@ -162,7 +163,7 @@ class JointModel:
162
163
  joint_index=joint_index, joint_position=joint_position
163
164
  )
164
165
 
165
- i_Hi_λ = jaxlie.SE3.from_matrix(λ_Hi_i).inverse().as_matrix()
166
+ i_Hi_λ = Transform.inverse(λ_Hi_i)
166
167
 
167
168
  return i_Hi_λ, S
168
169
 
jaxsim/math/quaternion.py CHANGED
@@ -58,9 +58,7 @@ class Quaternion:
58
58
  Returns:
59
59
  jtp.Vector: Quaternion in XYZW representation.
60
60
  """
61
- return Quaternion.to_wxyz(
62
- xyzw=jaxlie.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw()
63
- )
61
+ return jaxlie.SO3.from_matrix(matrix=dcm).wxyz
64
62
 
65
63
  @staticmethod
66
64
  def derivative(
@@ -165,12 +163,8 @@ class Quaternion:
165
163
  # Integrate the quaternion on the manifold.
166
164
  W_Q_B_tf = jax.lax.select(
167
165
  pred=omega_in_body_fixed,
168
- on_true=Quaternion.to_wxyz(
169
- xyzw=(W_Q_B_t0 @ jaxlie.SO3.exp(tangent=dt * ω_AB)).as_quaternion_xyzw()
170
- ),
171
- on_false=Quaternion.to_wxyz(
172
- xyzw=(jaxlie.SO3.exp(tangent=dt * ω_AB) @ W_Q_B_t0).as_quaternion_xyzw()
173
- ),
166
+ on_true=(W_Q_B_t0 @ jaxlie.SO3.exp(tangent=dt * ω_AB)).wxyz,
167
+ on_false=(jaxlie.SO3.exp(tangent=dt * ω_AB) @ W_Q_B_t0).wxyz,
174
168
  )
175
169
 
176
170
  return W_Q_B_tf
jaxsim/math/transform.py CHANGED
@@ -46,8 +46,8 @@ class Transform:
46
46
 
47
47
  @staticmethod
48
48
  def from_rotation_and_translation(
49
- rotation: jtp.MatrixLike,
50
- translation: jtp.VectorLike,
49
+ rotation: jtp.MatrixLike = jnp.eye(3),
50
+ translation: jtp.VectorLike = jnp.zeros(3),
51
51
  inverse: jtp.BoolLike = False,
52
52
  ) -> jtp.Matrix:
53
53
  """
jaxsim/mujoco/loaders.py CHANGED
@@ -160,7 +160,7 @@ class RodModelToMjcf:
160
160
  considered_joints: list[str] | None = None,
161
161
  plane_normal: tuple[float, float, float] = (0, 0, 1),
162
162
  heightmap: bool | None = None,
163
- cameras: list[dict[str, str]] | dict[str, str] = None,
163
+ cameras: list[dict[str, str]] | dict[str, str] | None = None,
164
164
  ) -> tuple[str, dict[str, Any]]:
165
165
  """
166
166
  Converts a ROD model to a Mujoco MJCF string.
@@ -274,7 +274,7 @@ class RodModelToMjcf:
274
274
 
275
275
  # Load the URDF model into Mujoco.
276
276
  assets = RodModelToMjcf.assets_from_rod_model(rod_model=rod_model)
277
- mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets) # noqa
277
+ mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)
278
278
 
279
279
  # Get the joint names.
280
280
  mj_joint_names = set(
@@ -306,7 +306,7 @@ class RodModelToMjcf:
306
306
  root: ET._Element = tree.getroot()
307
307
 
308
308
  # Find the <mujoco> element (might be the root itself).
309
- mujoco_element: ET._Element = list(root.iter("mujoco"))[0]
309
+ mujoco_element: ET._Element = next(iter(root.iter("mujoco")))
310
310
 
311
311
  # --------------
312
312
  # Add the motors
@@ -516,7 +516,7 @@ class UrdfToMjcf:
516
516
  model_name: str | None = None,
517
517
  plane_normal: tuple[float, float, float] = (0, 0, 1),
518
518
  heightmap: bool | None = None,
519
- cameras: list[dict[str, str]] | dict[str, str] = None,
519
+ cameras: list[dict[str, str]] | dict[str, str] | None = None,
520
520
  ) -> tuple[str, dict[str, Any]]:
521
521
  """
522
522
  Converts a URDF file to a Mujoco MJCF string.
@@ -558,7 +558,7 @@ class SdfToMjcf:
558
558
  model_name: str | None = None,
559
559
  plane_normal: tuple[float, float, float] = (0, 0, 1),
560
560
  heightmap: bool | None = None,
561
- cameras: list[dict[str, str]] | dict[str, str] = None,
561
+ cameras: list[dict[str, str]] | dict[str, str] | None = None,
562
562
  ) -> tuple[str, dict[str, Any]]:
563
563
  """
564
564
  Converts a SDF file to a Mujoco MJCF string.
jaxsim/mujoco/model.py CHANGED
@@ -31,16 +31,16 @@ class MujocoModelHelper:
31
31
  self.model = model
32
32
  self.data = data if data is not None else mj.MjData(self.model)
33
33
 
34
- # Populate the data with kinematics
34
+ # Populate the data with kinematics.
35
35
  mj.mj_forward(self.model, self.data)
36
36
 
37
- # Keep the cache of this method local to improve GC
37
+ # Keep the cache of this method local to improve GC.
38
38
  self.mask_qpos = functools.cache(self._mask_qpos)
39
39
 
40
40
  @staticmethod
41
41
  def build_from_xml(
42
42
  mjcf_description: str | pathlib.Path,
43
- assets: dict[str, Any] = None,
43
+ assets: dict[str, Any] | None = None,
44
44
  heightmap: HeightmapCallable | None = None,
45
45
  ) -> MujocoModelHelper:
46
46
  """
@@ -56,15 +56,15 @@ class MujocoModelHelper:
56
56
  A MujocoModelHelper object.
57
57
  """
58
58
 
59
- # Read the XML description if it's a path to file
59
+ # Read the XML description if it is a path to file.
60
60
  mjcf_description = (
61
61
  mjcf_description.read_text()
62
62
  if isinstance(mjcf_description, pathlib.Path)
63
63
  else mjcf_description
64
64
  )
65
65
 
66
- # Create the Mujoco model from the XML and, optionally, the assets dictionary
67
- model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets) # noqa
66
+ # Create the Mujoco model from the XML and, optionally, the assets dictionary.
67
+ model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets)
68
68
  data = mj.MjData(model)
69
69
 
70
70
  if heightmap:
@@ -81,6 +81,9 @@ class MujocoVideoRecorder:
81
81
  def write_video(self, path: pathlib.Path, exist_ok: bool = False) -> None:
82
82
  """Writes the video to a file."""
83
83
 
84
+ # Resolve the path to the video.
85
+ path = path.expanduser().resolve()
86
+
84
87
  if path.is_dir():
85
88
  raise IsADirectoryError(f"The path '{path}' is a directory.")
86
89
 
@@ -1 +0,0 @@
1
- from . import descriptions, kinematic_graph
@@ -26,7 +26,7 @@ class JointGenericAxis:
26
26
  A joint requiring the specification of a 3D axis.
27
27
  """
28
28
 
29
- #: The axis of rotation or translation of the joint (must have norm 1).
29
+ # The axis of rotation or translation of the joint (must have norm 1).
30
30
  axis: jtp.Vector
31
31
 
32
32
  def __hash__(self) -> int:
@@ -4,11 +4,11 @@ import dataclasses
4
4
 
5
5
  import jax.numpy as jnp
6
6
  import jax_dataclasses
7
- import jaxlie
8
7
  import numpy as np
9
8
  from jax_dataclasses import Static
10
9
 
11
10
  import jaxsim.typing as jtp
11
+ from jaxsim.math import Adjoint
12
12
  from jaxsim.utils import JaxsimDataclass
13
13
 
14
14
 
@@ -102,12 +102,11 @@ class LinkDescription(JaxsimDataclass):
102
102
  The combined link.
103
103
  """
104
104
 
105
- # Get the 6D inertia of the link to remove
105
+ # Get the 6D inertia of the link to remove.
106
106
  I_removed = link.inertia
107
107
 
108
108
  # Create the SE3 object. Note the inverse.
109
- r_H_l = jaxlie.SE3.from_matrix(lumped_H_removed).inverse()
110
- r_X_l = r_H_l.adjoint()
109
+ r_X_l = Adjoint.from_transform(transform=lumped_H_removed, inverse=True)
111
110
 
112
111
  # Move the inertia
113
112
  I_removed_in_lumped_frame = r_X_l.transpose() @ I_removed @ r_X_l
@@ -176,7 +176,7 @@ class ModelDescription(KinematicGraph):
176
176
  frames=self.frames,
177
177
  collisions=tuple(self.collision_shapes),
178
178
  fixed_base=self.fixed_base,
179
- base_link_name=list(iter(self))[0].name,
179
+ base_link_name=next(iter(self)).name,
180
180
  model_pose=self.root_pose,
181
181
  considered_joints=considered_joints,
182
182
  )
@@ -12,7 +12,8 @@ import jaxsim.utils
12
12
  from jaxsim import logging
13
13
  from jaxsim.utils import Mutability
14
14
 
15
- from . import descriptions
15
+ from .descriptions.joint import JointDescription, JointType
16
+ from .descriptions.link import LinkDescription
16
17
 
17
18
 
18
19
  @dataclasses.dataclass
@@ -61,7 +62,7 @@ class RootPose:
61
62
 
62
63
 
63
64
  @dataclasses.dataclass(frozen=True)
64
- class KinematicGraph(Sequence[descriptions.LinkDescription]):
65
+ class KinematicGraph(Sequence[LinkDescription]):
65
66
  """
66
67
  Class storing a kinematic graph having links as nodes and joints as edges.
67
68
 
@@ -72,11 +73,11 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
72
73
  root_pose: The pose of the kinematic graph's root.
73
74
  """
74
75
 
75
- root: descriptions.LinkDescription
76
- frames: list[descriptions.LinkDescription] = dataclasses.field(
76
+ root: LinkDescription
77
+ frames: list[LinkDescription] = dataclasses.field(
77
78
  default_factory=list, hash=False, compare=False
78
79
  )
79
- joints: list[descriptions.JointDescription] = dataclasses.field(
80
+ joints: list[JointDescription] = dataclasses.field(
80
81
  default_factory=list, hash=False, compare=False
81
82
  )
82
83
 
@@ -89,26 +90,26 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
89
90
 
90
91
  # Private attribute storing the unconnected joints from the parsed model and
91
92
  # the joints removed after model reduction.
92
- _joints_removed: list[descriptions.JointDescription] = dataclasses.field(
93
+ _joints_removed: list[JointDescription] = dataclasses.field(
93
94
  default_factory=list, repr=False, hash=False, compare=False
94
95
  )
95
96
 
96
97
  @functools.cached_property
97
- def links_dict(self) -> dict[str, descriptions.LinkDescription]:
98
+ def links_dict(self) -> dict[str, LinkDescription]:
98
99
  return {l.name: l for l in iter(self)}
99
100
 
100
101
  @functools.cached_property
101
- def frames_dict(self) -> dict[str, descriptions.LinkDescription]:
102
+ def frames_dict(self) -> dict[str, LinkDescription]:
102
103
  return {f.name: f for f in self.frames}
103
104
 
104
105
  @functools.cached_property
105
- def joints_dict(self) -> dict[str, descriptions.JointDescription]:
106
+ def joints_dict(self) -> dict[str, JointDescription]:
106
107
  return {j.name: j for j in self.joints}
107
108
 
108
109
  @functools.cached_property
109
110
  def joints_connection_dict(
110
111
  self,
111
- ) -> dict[tuple[str, str], descriptions.JointDescription]:
112
+ ) -> dict[tuple[str, str], JointDescription]:
112
113
  return {(j.parent.name, j.child.name): j for j in self.joints}
113
114
 
114
115
  def __post_init__(self) -> None:
@@ -158,9 +159,9 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
158
159
 
159
160
  @staticmethod
160
161
  def build_from(
161
- links: list[descriptions.LinkDescription],
162
- joints: list[descriptions.JointDescription],
163
- frames: list[descriptions.LinkDescription] | None = None,
162
+ links: list[LinkDescription],
163
+ joints: list[JointDescription],
164
+ frames: list[LinkDescription] | None = None,
164
165
  root_link_name: str | None = None,
165
166
  root_pose: RootPose = RootPose(),
166
167
  ) -> KinematicGraph:
@@ -186,7 +187,7 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
186
187
  logging.debug(msg=f"Assuming '{root_link_name}' as the root link")
187
188
 
188
189
  # Couple links and joints and create the graph of links.
189
- # Note that the pose of the frames is not updated; it's the caller's
190
+ # Note that the pose of the frames is not updated; it is the caller's
190
191
  # responsibility to update their pose if they want to use them.
191
192
  (
192
193
  graph_root_node,
@@ -218,17 +219,17 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
218
219
 
219
220
  @staticmethod
220
221
  def _create_graph(
221
- links: list[descriptions.LinkDescription],
222
- joints: list[descriptions.JointDescription],
222
+ links: list[LinkDescription],
223
+ joints: list[JointDescription],
223
224
  root_link_name: str,
224
- frames: list[descriptions.LinkDescription] | None = None,
225
+ frames: list[LinkDescription] | None = None,
225
226
  ) -> tuple[
226
- descriptions.LinkDescription,
227
- list[descriptions.JointDescription],
228
- list[descriptions.LinkDescription],
229
- list[descriptions.LinkDescription],
230
- list[descriptions.JointDescription],
231
- list[descriptions.LinkDescription],
227
+ LinkDescription,
228
+ list[JointDescription],
229
+ list[LinkDescription],
230
+ list[LinkDescription],
231
+ list[JointDescription],
232
+ list[LinkDescription],
232
233
  ]:
233
234
  """
234
235
  Low-level creator of kinematic graph components.
@@ -248,7 +249,7 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
248
249
  """
249
250
 
250
251
  # Create a dictionary that maps the link name to the link, for easy retrieval.
251
- links_dict: dict[str, descriptions.LinkDescription] = {
252
+ links_dict: dict[str, LinkDescription] = {
252
253
  l.name: l.mutable(validate=False) for l in links
253
254
  }
254
255
 
@@ -280,7 +281,7 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
280
281
  # Couple links and joints creating the kinematic graph.
281
282
  for joint in joints:
282
283
 
283
- # Get the parent and child links of the joint
284
+ # Get the parent and child links of the joint.
284
285
  parent_link = links_dict[joint.parent.name]
285
286
  child_link = links_dict[joint.child.name]
286
287
 
@@ -293,7 +294,7 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
293
294
  # Assign link's children and make sure they are unique.
294
295
  if child_link.name not in {l.name for l in parent_link.children}:
295
296
  with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION):
296
- parent_link.children = parent_link.children + (child_link,)
297
+ parent_link.children = (*parent_link.children, child_link)
297
298
 
298
299
  # Collect all the links of the kinematic graph.
299
300
  all_links_in_graph = list(
@@ -641,7 +642,7 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
641
642
  )
642
643
 
643
644
  @property
644
- def joints_removed(self) -> list[descriptions.JointDescription]:
645
+ def joints_removed(self) -> list[JointDescription]:
645
646
  """
646
647
  Get the list of joints removed during the graph reduction.
647
648
 
@@ -653,9 +654,9 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
653
654
 
654
655
  @staticmethod
655
656
  def breadth_first_search(
656
- root: descriptions.LinkDescription,
657
+ root: LinkDescription,
657
658
  sort_children: Callable[[Any], Any] | None = lambda link: link.name,
658
- ) -> Iterable[descriptions.LinkDescription]:
659
+ ) -> Iterable[LinkDescription]:
659
660
  """
660
661
  Perform a breadth-first search (BFS) traversal of the kinematic graph.
661
662
 
@@ -698,25 +699,25 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
698
699
  # Sequence protocol
699
700
  # =================
700
701
 
701
- def __iter__(self) -> Iterable[descriptions.LinkDescription]:
702
+ def __iter__(self) -> Iterable[LinkDescription]:
702
703
  yield from KinematicGraph.breadth_first_search(root=self.root)
703
704
 
704
- def __reversed__(self) -> Iterable[descriptions.LinkDescription]:
705
+ def __reversed__(self) -> Iterable[LinkDescription]:
705
706
  yield from reversed(list(iter(self)))
706
707
 
707
708
  def __len__(self) -> int:
708
709
  return len(list(iter(self)))
709
710
 
710
- def __contains__(self, item: str | descriptions.LinkDescription) -> bool:
711
+ def __contains__(self, item: str | LinkDescription) -> bool:
711
712
  if isinstance(item, str):
712
713
  return item in self.link_names()
713
714
 
714
- if isinstance(item, descriptions.LinkDescription):
715
+ if isinstance(item, LinkDescription):
715
716
  return item in set(iter(self))
716
717
 
717
718
  raise TypeError(type(item).__name__)
718
719
 
719
- def __getitem__(self, key: int | str) -> descriptions.LinkDescription:
720
+ def __getitem__(self, key: int | str) -> LinkDescription:
720
721
  if isinstance(key, str):
721
722
  if key not in self.link_names():
722
723
  raise KeyError(key)
@@ -731,12 +732,10 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
731
732
 
732
733
  raise TypeError(type(key).__name__)
733
734
 
734
- def count(self, value: descriptions.LinkDescription) -> int:
735
+ def count(self, value: LinkDescription) -> int:
735
736
  return list(iter(self)).count(value)
736
737
 
737
- def index(
738
- self, value: descriptions.LinkDescription, start: int = 0, stop: int = -1
739
- ) -> int:
738
+ def index(self, value: LinkDescription, start: int = 0, stop: int = -1) -> int:
740
739
  return list(iter(self)).index(value, start, stop)
741
740
 
742
741
 
@@ -906,7 +905,7 @@ class KinematicGraphTransforms:
906
905
 
907
906
  @staticmethod
908
907
  def pre_H_suc(
909
- joint_type: descriptions.JointType,
908
+ joint_type: JointType,
910
909
  joint_axis: npt.NDArray,
911
910
  joint_position: float | None = None,
912
911
  ) -> npt.NDArray:
@@ -54,19 +54,19 @@ def extract_model_data(
54
54
  if isinstance(model_description, rod.Model):
55
55
  sdf_model = model_description
56
56
  else:
57
- # Parse the SDF resource
57
+ # Parse the SDF resource.
58
58
  sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf)
59
59
 
60
60
  if len(sdf_element.models()) == 0:
61
61
  raise RuntimeError("Failed to find any model in SDF resource")
62
62
 
63
- # Assume the SDF resource has only one model, or the desired model name is given
63
+ # Assume the SDF resource has only one model, or the desired model name is given.
64
64
  sdf_models = {m.name: m for m in sdf_element.models()}
65
65
  sdf_model = (
66
66
  sdf_element.models()[0] if len(sdf_models) == 1 else sdf_models[model_name]
67
67
  )
68
68
 
69
- # Log model name
69
+ # Log model name.
70
70
  logging.debug(msg=f"Found model '{sdf_model.name}' in SDF resource")
71
71
 
72
72
  # Jaxsim supports only models compatible with URDF, i.e. those having all links
@@ -75,7 +75,7 @@ def extract_model_data(
75
75
  # pose is expressed wrt the parent link they are rigidly attached to.
76
76
  sdf_model.switch_frame_convention(frame_convention=rod.FrameConvention.Urdf)
77
77
 
78
- # Log type of base link
78
+ # Log type of base link.
79
79
  logging.debug(
80
80
  msg="Model '{}' is {}".format(
81
81
  sdf_model.name,
@@ -83,7 +83,7 @@ def extract_model_data(
83
83
  )
84
84
  )
85
85
 
86
- # Log detected base link
86
+ # Log detected base link.
87
87
  logging.debug(msg=f"Considering '{sdf_model.get_canonical_link()}' as base link")
88
88
 
89
89
  # Pose of the model
@@ -101,7 +101,7 @@ def extract_model_data(
101
101
  # Parse links
102
102
  # ===========
103
103
 
104
- # Parse the links (unconnected)
104
+ # Parse the links (unconnected).
105
105
  links = [
106
106
  descriptions.LinkDescription(
107
107
  name=l.name,
@@ -113,14 +113,14 @@ def extract_model_data(
113
113
  if l.inertial.mass > 0
114
114
  ]
115
115
 
116
- # Create a dictionary to find easily links
116
+ # Create a dictionary to find easily links.
117
117
  links_dict: Dict[str, descriptions.LinkDescription] = {l.name: l for l in links}
118
118
 
119
119
  # ============
120
120
  # Parse frames
121
121
  # ============
122
122
 
123
- # Parse the frames (unconnected)
123
+ # Parse the frames (unconnected).
124
124
  frames = [
125
125
  descriptions.LinkDescription(
126
126
  name=f.name,
@@ -138,7 +138,7 @@ def extract_model_data(
138
138
  # =========================
139
139
 
140
140
  # In this case, we need to get the pose of the joint that connects the base link
141
- # to the world and combine their pose
141
+ # to the world and combine their pose.
142
142
  if sdf_model.is_fixed_base():
143
143
  # Create a massless word link
144
144
  world_link = descriptions.LinkDescription(
@@ -200,7 +200,7 @@ def extract_model_data(
200
200
  # Parse joints
201
201
  # ============
202
202
 
203
- # Check that all joint poses are expressed w.r.t. their parent link
203
+ # Check that all joint poses are expressed w.r.t. their parent link.
204
204
  for j in sdf_model.joints():
205
205
  if j.pose is None:
206
206
  continue
@@ -215,7 +215,7 @@ def extract_model_data(
215
215
  msg = "Pose of joint '{}' is not expressed wrt its parent link '{}'"
216
216
  raise ValueError(msg.format(j.name, j.parent))
217
217
 
218
- # Parse the joints
218
+ # Parse the joints.
219
219
  joints = [
220
220
  descriptions.JointDescription(
221
221
  name=j.name,
@@ -278,10 +278,10 @@ def extract_model_data(
278
278
  and j.child in links_dict.keys()
279
279
  ]
280
280
 
281
- # Create a dictionary to find the parent joint of the links
281
+ # Create a dictionary to find the parent joint of the links.
282
282
  joint_dict = {j.child.name: j.name for j in joints}
283
283
 
284
- # Check that all the link poses are expressed wrt their parent joint
284
+ # Check that all the link poses are expressed wrt their parent joint.
285
285
  for l in sdf_model.links():
286
286
  if l.name not in links_dict:
287
287
  continue
@@ -354,7 +354,7 @@ def build_model_description(
354
354
  The parsed model description.
355
355
  """
356
356
 
357
- # Parse data from the SDF assuming it contains a single model
357
+ # Parse data from the SDF assuming it contains a single model.
358
358
  sdf_data = extract_model_data(
359
359
  model_description=model_description, model_name=None, is_urdf=is_urdf
360
360
  )