jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -129
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +87 -16
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +62 -24
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +607 -225
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -80
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -55
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev188.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,258 +1,397 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  import dataclasses
3
5
  import functools
4
- from typing import (
5
- Any,
6
- Callable,
7
- Dict,
8
- Iterable,
9
- List,
10
- NamedTuple,
11
- Optional,
12
- Tuple,
13
- Union,
14
- )
6
+ from collections.abc import Callable, Iterable, Iterator, Sequence
7
+ from typing import Any
15
8
 
16
9
  import numpy as np
17
10
  import numpy.typing as npt
18
11
 
12
+ import jaxsim.utils
19
13
  from jaxsim import logging
20
14
  from jaxsim.utils import Mutability
21
15
 
22
- from . import descriptions
16
+ from .descriptions.joint import JointDescription, JointType
17
+ from .descriptions.link import LinkDescription
23
18
 
24
19
 
25
- class RootPose(NamedTuple):
20
+ @dataclasses.dataclass
21
+ class RootPose:
26
22
  """
27
23
  Represents the root pose in a kinematic graph.
28
24
 
29
25
  Attributes:
30
- root_position (npt.NDArray): A NumPy array of shape (3,) representing the root's position.
31
- root_quaternion (npt.NDArray): A NumPy array of shape (4,) representing the root's quaternion.
26
+ root_position: The 3D position of the root link of the graph.
27
+ root_quaternion:
28
+ The quaternion representing the rotation of the root link of the graph.
29
+
30
+ Note:
31
+ The root link of the kinematic graph is the base link.
32
32
  """
33
33
 
34
- root_position: npt.NDArray = np.zeros(3)
35
- root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0])
34
+ root_position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3))
35
+
36
+ root_quaternion: npt.NDArray = dataclasses.field(
37
+ default_factory=lambda: np.array([1.0, 0, 0, 0])
38
+ )
39
+
40
+ def __hash__(self) -> int:
41
+
42
+ from jaxsim.utils.wrappers import HashedNumpyArray
43
+
44
+ return hash(
45
+ (
46
+ HashedNumpyArray.hash_of_array(self.root_position),
47
+ HashedNumpyArray.hash_of_array(self.root_quaternion),
48
+ )
49
+ )
50
+
51
+ def __eq__(self, other: RootPose) -> bool:
52
+
53
+ if not isinstance(other, RootPose):
54
+ return False
55
+
56
+ if not np.allclose(self.root_position, other.root_position):
57
+ return False
58
+
59
+ if not np.allclose(self.root_quaternion, other.root_quaternion):
60
+ return False
61
+
62
+ return True
36
63
 
37
64
 
38
65
  @dataclasses.dataclass(frozen=True)
39
- class KinematicGraph:
66
+ class KinematicGraph(Sequence[LinkDescription]):
40
67
  """
41
- Represents a kinematic graph of links and joints.
42
-
43
- Args:
44
- root (descriptions.LinkDescription): The root link of the kinematic graph.
45
- frames (List[descriptions.LinkDescription]): A list of frame links in the graph.
46
- joints (List[descriptions.JointDescription]): A list of joint descriptions in the graph.
47
- root_pose (RootPose): The root pose of the graph.
48
- transform_cache (Dict[str, npt.NDArray]): A dictionary to cache transformation matrices.
49
- extra_info (Dict[str, Any]): Additional information associated with the graph.
68
+ Class storing a kinematic graph having links as nodes and joints as edges.
50
69
 
51
70
  Attributes:
52
- links_dict (Dict[str, descriptions.LinkDescription]): A dictionary mapping link names to link descriptions.
53
- frames_dict (Dict[str, descriptions.LinkDescription]): A dictionary mapping frame names to frame link descriptions.
54
- joints_dict (Dict[str, descriptions.JointDescription]): A dictionary mapping joint names to joint descriptions.
55
- joints_connection_dict (Dict[Tuple[str, str], descriptions.JointDescription]): A dictionary mapping pairs of parent and child link names to joint descriptions.
71
+ root: The root node of the kinematic graph.
72
+ frames: List of frames rigidly attached to the graph nodes.
73
+ joints: List of joints connecting the graph nodes.
74
+ root_pose: The pose of the kinematic graph's root.
56
75
  """
57
76
 
58
- root: descriptions.LinkDescription
59
- frames: List[descriptions.LinkDescription] = dataclasses.field(default_factory=list)
60
- joints: List[descriptions.JointDescription] = dataclasses.field(
61
- default_factory=list
77
+ root: LinkDescription
78
+ frames: list[LinkDescription] = dataclasses.field(
79
+ default_factory=list, hash=False, compare=False
80
+ )
81
+ joints: list[JointDescription] = dataclasses.field(
82
+ default_factory=list, hash=False, compare=False
62
83
  )
63
84
 
64
85
  root_pose: RootPose = dataclasses.field(default_factory=RootPose)
65
86
 
66
- transform_cache: Dict[str, npt.NDArray] = dataclasses.field(
67
- repr=False, init=False, compare=False, default_factory=dict
87
+ # Private attribute storing optional additional info.
88
+ _extra_info: dict[str, Any] = dataclasses.field(
89
+ default_factory=dict, repr=False, hash=False, compare=False
68
90
  )
69
91
 
70
- extra_info: Dict[str, Any] = dataclasses.field(
71
- repr=False, compare=False, default_factory=dict
92
+ # Private attribute storing the unconnected joints from the parsed model and
93
+ # the joints removed after model reduction.
94
+ _joints_removed: list[JointDescription] = dataclasses.field(
95
+ default_factory=list, repr=False, hash=False, compare=False
72
96
  )
73
97
 
74
98
  @functools.cached_property
75
- def links_dict(self) -> Dict[str, descriptions.LinkDescription]:
99
+ def links_dict(self) -> dict[str, LinkDescription]:
100
+ """
101
+ Get a dictionary of links indexed by their name.
102
+ """
76
103
  return {l.name: l for l in iter(self)}
77
104
 
78
105
  @functools.cached_property
79
- def frames_dict(self) -> Dict[str, descriptions.LinkDescription]:
106
+ def frames_dict(self) -> dict[str, LinkDescription]:
107
+ """
108
+ Get a dictionary of frames indexed by their name.
109
+ """
80
110
  return {f.name: f for f in self.frames}
81
111
 
82
112
  @functools.cached_property
83
- def joints_dict(self) -> Dict[str, descriptions.JointDescription]:
113
+ def joints_dict(self) -> dict[str, JointDescription]:
114
+ """
115
+ Get a dictionary of joints indexed by their name.
116
+ """
84
117
  return {j.name: j for j in self.joints}
85
118
 
86
119
  @functools.cached_property
87
120
  def joints_connection_dict(
88
121
  self,
89
- ) -> Dict[Tuple[str, str], descriptions.JointDescription]:
90
- return {(j.parent.name, j.child.name): j for j in self.joints}
91
-
92
- def __post_init__(self):
122
+ ) -> dict[tuple[str, str], JointDescription]:
93
123
  """
94
- Post-initialization method to set various properties and validate the kinematic graph.
124
+ Get a dictionary of joints indexed by the tuple (parent, child) link names.
95
125
  """
96
- # Assign the link index traversing the graph with BFS.
97
- # Here we assume the model is fixed-base, therefore the base link will
98
- # have index 0. We will deal with the floating base in a later stage,
99
- # when this Model object is converted to the physics model.
126
+ return {(j.parent.name, j.child.name): j for j in self.joints}
127
+
128
+ def __post_init__(self) -> None:
129
+
130
+ # Assign the link index by traversing the graph with BFS.
131
+ # Here we assume the model being fixed-base, therefore the base link will
132
+ # have index 0. We will deal with the floating base in a later stage.
100
133
  for index, link in enumerate(self):
101
134
  link.mutable(validate=False).index = index
102
135
 
103
- # Order frames with their name
136
+ # Get the names of the links, frames, and joints.
137
+ link_names = [l.name for l in self]
138
+ frame_names = [f.name for f in self.frames]
139
+ joint_names = [j.name for j in self.joints]
140
+
141
+ # Make sure that they are unique.
142
+ assert len(link_names) == len(set(link_names))
143
+ assert len(frame_names) == len(set(frame_names))
144
+ assert len(joint_names) == len(set(joint_names))
145
+ assert set(link_names).isdisjoint(set(frame_names))
146
+ assert set(link_names).isdisjoint(set(joint_names))
147
+
148
+ # Order frames with their name.
104
149
  super().__setattr__("frames", sorted(self.frames, key=lambda f: f.name))
105
150
 
106
151
  # Assign the frame index following the name-based indexing.
107
- # Also here, we assume the model is fixed-base, therefore the first frame will
108
- # have last_link_idx + 1. These frames are not part of the physics model.
152
+ # We assume the model being fixed-base, therefore the first frame will
153
+ # have last_link_idx + 1.
109
154
  for index, frame in enumerate(self.frames):
110
- frame.index = index + len(self.link_names())
155
+ with frame.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
156
+ frame.index = int(index + len(self.link_names()))
111
157
 
112
- # Number joints so that their index matches their child link index
158
+ # Number joints so that their index matches their child link index.
159
+ # Therefore, the first joint has index 1.
113
160
  links_dict = {l.name: l for l in iter(self)}
114
161
  for joint in self.joints:
115
162
  with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
116
163
  joint.index = links_dict[joint.child.name].index
117
164
 
118
- # Check that joint indices are unique
165
+ # Check that joint indices are unique.
119
166
  assert len([j.index for j in self.joints]) == len(
120
167
  {j.index for j in self.joints}
121
168
  )
122
169
 
123
- # Order joints with their indices
170
+ # Order joints with their indices.
124
171
  super().__setattr__("joints", sorted(self.joints, key=lambda j: j.index))
125
172
 
126
173
  @staticmethod
127
174
  def build_from(
128
- links: List[descriptions.LinkDescription],
129
- joints: List[descriptions.JointDescription],
175
+ links: list[LinkDescription],
176
+ joints: list[JointDescription],
177
+ frames: list[LinkDescription] | None = None,
130
178
  root_link_name: str | None = None,
131
179
  root_pose: RootPose = RootPose(),
132
- ) -> "KinematicGraph":
180
+ ) -> KinematicGraph:
133
181
  """
134
- Build a KinematicGraph from a list of links and joints.
182
+ Build a KinematicGraph from links, joints, and frames.
135
183
 
136
184
  Args:
137
- links (List[descriptions.LinkDescription]): A list of link descriptions.
138
- joints (List[descriptions.JointDescription]): A list of joint descriptions.
139
- root_link_name (str, optional): The name of the root link. If not provided, it's assumed to be the first link's name.
140
- root_pose (RootPose, optional): The root pose of the kinematic graph.
185
+ links: A list of link descriptions.
186
+ joints: A list of joint descriptions.
187
+ frames: A list of frame descriptions.
188
+ root_link_name:
189
+ The name of the root link. If not provided, it's assumed to be the
190
+ first link's name.
191
+ root_pose: The root pose of the kinematic graph.
141
192
 
142
193
  Returns:
143
- KinematicGraph: The constructed kinematic graph.
194
+ The resulting kinematic graph.
144
195
  """
196
+
197
+ # Consider the first link as the root link if not provided.
145
198
  if root_link_name is None:
146
199
  root_link_name = links[0].name
200
+ logging.debug(msg=f"Assuming '{root_link_name}' as the root link")
147
201
 
148
202
  # Couple links and joints and create the graph of links.
149
- # Note that the pose of the frames is not updated; it's the caller's
203
+ # Note that the pose of the frames is not updated; it is the caller's
150
204
  # responsibility to update their pose if they want to use them.
151
- graph_root_node, graph_joints, graph_frames = KinematicGraph.create_graph(
152
- links=links, joints=joints, root_link_name=root_link_name
205
+ (
206
+ graph_root_node,
207
+ graph_joints,
208
+ graph_frames,
209
+ unconnected_links,
210
+ unconnected_joints,
211
+ unconnected_frames,
212
+ ) = KinematicGraph._create_graph(
213
+ links=links, joints=joints, root_link_name=root_link_name, frames=frames
153
214
  )
154
215
 
155
- for frame in graph_frames:
156
- logging.warning(msg=f"Ignoring unconnected link / frame: '{frame.name}'")
216
+ for link in unconnected_links:
217
+ logging.warning(msg=f"Ignoring unconnected link: '{link.name}'")
218
+
219
+ for joint in unconnected_joints:
220
+ logging.warning(msg=f"Ignoring unconnected joint: '{joint.name}'")
221
+
222
+ for frame in unconnected_frames:
223
+ logging.warning(msg=f"Ignoring unconnected frame: '{frame.name}'")
157
224
 
158
225
  return KinematicGraph(
159
- root=graph_root_node, joints=graph_joints, frames=[], root_pose=root_pose
226
+ root=graph_root_node,
227
+ joints=graph_joints,
228
+ frames=graph_frames,
229
+ root_pose=root_pose,
230
+ _joints_removed=unconnected_joints,
160
231
  )
161
232
 
162
233
  @staticmethod
163
- def create_graph(
164
- links: List[descriptions.LinkDescription],
165
- joints: List[descriptions.JointDescription],
234
+ def _create_graph(
235
+ links: list[LinkDescription],
236
+ joints: list[JointDescription],
166
237
  root_link_name: str,
167
- ) -> Tuple[
168
- descriptions.LinkDescription,
169
- List[descriptions.JointDescription],
170
- List[descriptions.LinkDescription],
238
+ frames: list[LinkDescription] | None = None,
239
+ ) -> tuple[
240
+ LinkDescription,
241
+ list[JointDescription],
242
+ list[LinkDescription],
243
+ list[LinkDescription],
244
+ list[JointDescription],
245
+ list[LinkDescription],
171
246
  ]:
172
247
  """
173
- Create a kinematic graph from lists of links and joints.
248
+ Low-level creator of kinematic graph components.
174
249
 
175
250
  Args:
176
- links (List[descriptions.LinkDescription]): A list of link descriptions.
177
- joints (List[descriptions.JointDescription]): A list of joint descriptions.
178
- root_link_name (str): The name of the root link.
251
+ links: A list of parsed link descriptions.
252
+ joints: A list of parsed joint descriptions.
253
+ root_link_name: The name of the root link used as root node of the graph.
254
+ frames: A list of parsed frame descriptions.
179
255
 
180
256
  Returns:
181
- Tuple[descriptions.LinkDescription, List[descriptions.JointDescription], List[descriptions.LinkDescription]]:
182
- A tuple containing the root link, list of joints, and list of frames in the graph.
257
+ A tuple containing the root node of the graph (defining the entire kinematic
258
+ tree by iterating on its child nodes), the list of joints representing the
259
+ actual graph edges, the list of frames rigidly attached to the graph nodes,
260
+ the list of unconnected links, the list of unconnected joints, and the list
261
+ of unconnected frames.
183
262
  """
184
263
 
185
- # Create a dict that maps link name to the link, for easy retrieval
186
- links_dict: Dict[str, descriptions.LinkDescription] = {
264
+ # Create a dictionary that maps the link name to the link, for easy retrieval.
265
+ links_dict: dict[str, LinkDescription] = {
187
266
  l.name: l.mutable(validate=False) for l in links
188
267
  }
189
268
 
269
+ # Create an empty list of frames if not provided.
270
+ frames = frames if frames is not None else []
271
+
272
+ # Create a dictionary that maps the frame name to the frame, for easy retrieval.
273
+ frames_dict = {frame.name: frame for frame in frames}
274
+
275
+ # Check that our parser correctly resolved the frame's parent to be a link.
276
+ for frame in frames:
277
+ assert frame.parent.name != "", frame
278
+ assert frame.parent.name is not None, frame
279
+ assert frame.parent.name != "__model__", frame
280
+ assert frame.parent.name not in frames_dict, frame
281
+
282
+ # ===========================================================
283
+ # Populate the kinematic graph with links, joints, and frames
284
+ # ===========================================================
285
+
286
+ # Check the existence of the root link.
190
287
  if root_link_name not in links_dict:
191
288
  raise ValueError(root_link_name)
192
289
 
193
- # Reset the connections of the root link
290
+ # Reset the connections of the root link.
194
291
  for link in links_dict.values():
195
- link.children = []
292
+ link.children = tuple()
196
293
 
197
- # Couple links and joints creating the final kinematic graph
294
+ # Couple links and joints creating the kinematic graph.
198
295
  for joint in joints:
199
- # Get the parent and child links of the joint
296
+
297
+ # Get the parent and child links of the joint.
200
298
  parent_link = links_dict[joint.parent.name]
201
299
  child_link = links_dict[joint.child.name]
202
300
 
203
301
  assert child_link.name == joint.child.name
204
302
  assert parent_link.name == joint.parent.name
205
303
 
206
- # Assign link parent
304
+ # Assign link's parent.
207
305
  child_link.parent = parent_link
208
306
 
209
- # Assign link children and make sure they are unique
307
+ # Assign link's children and make sure they are unique.
210
308
  if child_link.name not in {l.name for l in parent_link.children}:
211
- parent_link.children.append(child_link)
309
+ with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION):
310
+ parent_link.children = (*parent_link.children, child_link)
212
311
 
213
- # Collect all the links of the kinematic graph
312
+ # Collect all the links of the kinematic graph.
214
313
  all_links_in_graph = list(
215
314
  KinematicGraph.breadth_first_search(root=links_dict[root_link_name])
216
315
  )
316
+
317
+ # Get the names of all links in the kinematic graph.
217
318
  all_link_names_in_graph = [l.name for l in all_links_in_graph]
218
319
 
219
- # Collect all the joints not part of the kinematic graph
220
- removed_joints = [
221
- j
222
- for j in joints
223
- if not {j.parent.name, j.child.name}.issubset(all_link_names_in_graph)
320
+ # Collect all the joints of the kinematic graph.
321
+ all_joints_in_graph = [
322
+ joint
323
+ for joint in joints
324
+ if joint.parent.name in all_link_names_in_graph
325
+ and joint.child.name in all_link_names_in_graph
224
326
  ]
225
327
 
226
- for removed_joint in removed_joints:
227
- msg = "Joint '{}' has been removed for the graph because unconnected"
228
- logging.info(msg=msg.format(removed_joint.name))
328
+ # Get the names of all joints in the kinematic graph.
329
+ all_joint_names_in_graph = [j.name for j in all_joints_in_graph]
229
330
 
230
- # Store as frames all the links that are not part of the kinematic graph
231
- frames = list(set(links) - set(all_links_in_graph))
331
+ # Collect all the frames of the kinematic graph.
332
+ # Note: our parser ensures that the parent of a frame is not another frame.
333
+ all_frames_in_graph = [
334
+ frame for frame in frames if frame.parent.name in all_link_names_in_graph
335
+ ]
232
336
 
233
- # Update the frames. In particular, reset their children. The other properties
234
- # are kept as they are, and it's caller responsibility to update them if needed.
235
- for frame in frames:
236
- frame.children = []
237
- msg = f"Link '{frame.name}' became a frame"
238
- logging.info(msg=msg)
337
+ # Get the names of all frames in the kinematic graph.
338
+ all_frames_names_in_graph = [f.name for f in all_frames_in_graph]
339
+
340
+ # ============================
341
+ # Collect unconnected elements
342
+ # ============================
343
+
344
+ # Collect all the joints that are not part of the kinematic graph.
345
+ removed_joints = [j for j in joints if j.name not in all_joint_names_in_graph]
346
+
347
+ for joint in removed_joints:
348
+ msg = "Joint '{}' is unconnected and it will be removed"
349
+ logging.debug(msg=msg.format(joint.name))
350
+
351
+ # Collect all the links that are not part of the kinematic graph.
352
+ unconnected_links = [l for l in links if l.name not in all_link_names_in_graph]
353
+
354
+ # Update the unconnected links by removing their children. The other properties
355
+ # are left untouched, it's caller responsibility to post-process them if needed.
356
+ for link in unconnected_links:
357
+ link.children = tuple()
358
+ msg = "Link '{}' won't be part of the kinematic graph because unconnected"
359
+ logging.debug(msg=msg.format(link.name))
360
+
361
+ # Collect all the frames that are not part of the kinematic graph.
362
+ unconnected_frames = [
363
+ f for f in frames if f.name not in all_frames_names_in_graph
364
+ ]
365
+
366
+ for frame in unconnected_frames:
367
+ msg = "Frame '{}' won't be part of the kinematic graph because unconnected"
368
+ logging.debug(msg=msg.format(frame.name))
239
369
 
240
370
  return (
241
371
  links_dict[root_link_name].mutable(mutable=False),
242
372
  list(set(joints) - set(removed_joints)),
243
- frames,
373
+ all_frames_in_graph,
374
+ unconnected_links,
375
+ list(set(removed_joints)),
376
+ unconnected_frames,
244
377
  )
245
378
 
246
- def reduce(self, considered_joints: List[str]) -> "KinematicGraph":
379
+ def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:
247
380
  """
248
- Reduce the kinematic graph by removing specified joints and lumping the mass and inertia of removed links into their parent links.
381
+ Reduce the kinematic graph by removing unspecified joints.
382
+
383
+ When a joint is removed, the mass and inertia of its child link are lumped
384
+ with those of its parent link, obtaining a new link that combines the two.
385
+ The description of the removed joint specifies the default angle (usually 0)
386
+ that is considered when the joint is removed.
249
387
 
250
388
  Args:
251
- considered_joints (List[str]): A list of joint names to consider.
389
+ considered_joints: A list of joint names to consider.
252
390
 
253
391
  Returns:
254
- KinematicGraph: The reduced kinematic graph.
392
+ The reduced kinematic graph.
255
393
  """
394
+
256
395
  # The current object represents the complete kinematic graph
257
396
  full_graph = self
258
397
 
@@ -263,11 +402,11 @@ class KinematicGraph:
263
402
 
264
403
  # Return early if there is no action to take
265
404
  if len(joint_names_to_remove) == 0:
266
- logging.info(f"The kinematic graph doesn't need to be reduced")
405
+ logging.info("The kinematic graph doesn't need to be reduced")
267
406
  return copy.deepcopy(self)
268
407
 
269
408
  # Check if all considered joints are part of the full kinematic graph
270
- if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
409
+ if len(set(considered_joints) - {j.name for j in full_graph.joints}) != 0:
271
410
  extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
272
411
  msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
273
412
  raise ValueError(msg)
@@ -276,6 +415,9 @@ class KinematicGraph:
276
415
  links_dict = copy.deepcopy(full_graph.links_dict)
277
416
  joints_dict = copy.deepcopy(full_graph.joints_dict)
278
417
 
418
+ # Create the object to compute forward kinematics.
419
+ fk = KinematicGraphTransforms(graph=full_graph)
420
+
279
421
  # The following steps are implemented below in order to create the reduced graph:
280
422
  #
281
423
  # 1. Lump the mass of the removed links into their parent
@@ -315,7 +457,7 @@ class KinematicGraph:
315
457
  msg.format(
316
458
  link_to_remove.name,
317
459
  self.joints_connection_dict[
318
- (parent_of_link_to_remove.name, link_to_remove.name)
460
+ parent_of_link_to_remove.name, link_to_remove.name
319
461
  ].name,
320
462
  parent_of_link_to_remove.name,
321
463
  )
@@ -324,14 +466,14 @@ class KinematicGraph:
324
466
  # Lump the link
325
467
  lumped_link = parent_of_link_to_remove.lump_with(
326
468
  link=link_to_remove,
327
- lumped_H_removed=full_graph.relative_transform(
469
+ lumped_H_removed=fk.relative_transform(
328
470
  relative_to=parent_of_link_to_remove.name, name=link_to_remove.name
329
471
  ),
330
472
  )
331
473
 
332
474
  # Pop the original two links from the dictionary...
333
- links_dict.pop(link_to_remove.name)
334
- links_dict.pop(parent_of_link_to_remove.name)
475
+ _ = links_dict.pop(link_to_remove.name)
476
+ _ = links_dict.pop(parent_of_link_to_remove.name)
335
477
 
336
478
  # ... and insert the lumped link (having the same name of the parent)
337
479
  links_dict[lumped_link.name] = lumped_link
@@ -341,11 +483,13 @@ class KinematicGraph:
341
483
  links_dict[link_to_remove.name] = lumped_link
342
484
 
343
485
  # As a consequence of the back-insertion, we need to adjust the resulting
344
- # lumped link of links that have been removed previously
486
+ # lumped link of links that have been removed previously.
487
+ # Note: in the dictionary, only items whose key is not matching value.name
488
+ # are links that have been removed.
345
489
  for previously_removed_link_name in {
346
- k
347
- for k, v in links_dict.items()
348
- if k != v.name and v.name == link_to_remove.name
490
+ link_name
491
+ for link_name, link in links_dict.items()
492
+ if link_name != link.name and link.name == link_to_remove.name
349
493
  }:
350
494
  links_dict[previously_removed_link_name] = lumped_link
351
495
 
@@ -365,7 +509,7 @@ class KinematicGraph:
365
509
  # Update the pose. Note that after the lumping process, the dict entry
366
510
  # links_dict[joint.parent.name] contains the final lumped link
367
511
  with joint.mutable_context(mutability=Mutability.MUTABLE):
368
- joint.pose = full_graph.relative_transform(
512
+ joint.pose = fk.relative_transform(
369
513
  relative_to=links_dict[joint.parent.name].name, name=joint.name
370
514
  )
371
515
  with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
@@ -391,135 +535,114 @@ class KinematicGraph:
391
535
 
392
536
  # Create the reduced graph data. We pass the full list of links so that those
393
537
  # that are not part of the graph will be returned as frames.
394
- reduced_root_node, reduced_joints, reduced_frames = KinematicGraph.create_graph(
538
+ (
539
+ reduced_root_node,
540
+ reduced_joints,
541
+ reduced_frames,
542
+ unconnected_links,
543
+ unconnected_joints,
544
+ unconnected_frames,
545
+ ) = KinematicGraph._create_graph(
395
546
  links=list(full_graph_links_dict.values()),
396
547
  joints=[joints_dict[joint_name] for joint_name in considered_joints],
397
548
  root_link_name=full_graph.root.name,
398
549
  )
399
550
 
400
- # Create the reduced graph
551
+ assert {f.name for f in self.frames}.isdisjoint(
552
+ {f.name for f in unconnected_frames + reduced_frames}
553
+ )
554
+
555
+ for link in unconnected_links:
556
+ logging.debug(msg=f"Link '{link.name}' is unconnected and became a frame")
557
+
558
+ # Create the reduced graph.
401
559
  reduced_graph = KinematicGraph(
402
560
  root=reduced_root_node,
403
561
  joints=reduced_joints,
404
- frames=reduced_frames,
562
+ frames=self.frames + unconnected_links + reduced_frames,
405
563
  root_pose=full_graph.root_pose,
564
+ _joints_removed=(
565
+ self._joints_removed
566
+ + unconnected_joints
567
+ + [joints_dict[name] for name in joint_names_to_remove]
568
+ ),
406
569
  )
407
570
 
408
571
  # ================================================================
409
572
  # 4. Resolve the pose of the frames wrt their reduced graph parent
410
573
  # ================================================================
411
574
 
412
- # Update frames properties using the transforms from the full graph
413
- for frame in reduced_graph.frames:
414
- # Get the link in which the removed link was lumped into
415
- new_parent_link = links_dict[frame.name]
575
+ # Build a new object to compute FK on the reduced graph.
576
+ fk_reduced = KinematicGraphTransforms(graph=reduced_graph)
416
577
 
417
- msg = f"New parent of frame '{frame.name}' is '{new_parent_link.name}'"
418
- logging.info(msg)
578
+ # We need to adjust the pose of the frames since their parent link
579
+ # could have been removed by the reduction process.
580
+ for frame in reduced_graph.frames:
419
581
 
420
- # Update the connection of the frame
421
- frame.parent = new_parent_link
422
- frame.pose = full_graph.relative_transform(
423
- relative_to=new_parent_link.name, name=frame.name
582
+ # Always find the real parent link of the frame
583
+ name_of_new_parent_link = fk_reduced.find_parent_link_of_frame(
584
+ name=frame.name
424
585
  )
586
+ assert name_of_new_parent_link in reduced_graph, name_of_new_parent_link
425
587
 
426
- # Update frame data
427
- frame.mass = 0.0
428
- frame.inertia = np.zeros_like(frame.inertia)
588
+ # Notify the user if the parent link has changed.
589
+ if name_of_new_parent_link != frame.parent.name:
590
+ msg = "New parent of frame '{}' is '{}'"
591
+ logging.debug(msg=msg.format(frame.name, name_of_new_parent_link))
429
592
 
430
- # Return the reduced graph
593
+ # Always recompute the pose of the frame, and set zero inertial params.
594
+ with frame.mutable_context(jaxsim.utils.Mutability.MUTABLE_NO_VALIDATION):
595
+
596
+ # Update kinematic parameters of the frame.
597
+ # Note that here we compute the transform using the FK object of the
598
+ # full model, so that we are sure that the kinematic is not altered.
599
+ frame.pose = fk.relative_transform(
600
+ relative_to=name_of_new_parent_link, name=frame.name
601
+ )
602
+
603
+ # Update the parent link such that the pose is expressed in its frame.
604
+ frame.parent = reduced_graph.links_dict[name_of_new_parent_link]
605
+
606
+ # Update dynamic parameters of the frame.
607
+ frame.mass = 0.0
608
+ frame.inertia = np.zeros_like(frame.inertia)
609
+
610
+ # Return the reduced graph.
431
611
  return reduced_graph
432
612
 
433
- def link_names(self) -> List[str]:
613
+ def link_names(self) -> list[str]:
434
614
  """
435
- Get the names of all links in the kinematic graph.
615
+ Get the names of all links in the kinematic graph (i.e. the nodes).
436
616
 
437
617
  Returns:
438
- List[str]: A list of link names.
618
+ The list of link names.
439
619
  """
440
620
  return list(self.links_dict.keys())
441
621
 
442
- def joint_names(self) -> List[str]:
622
+ def joint_names(self) -> list[str]:
443
623
  """
444
- Get the names of all joints in the kinematic graph.
624
+ Get the names of all joints in the kinematic graph (i.e. the edges).
445
625
 
446
626
  Returns:
447
- List[str]: A list of joint names.
627
+ The list of joint names.
448
628
  """
449
629
  return list(self.joints_dict.keys())
450
630
 
451
- def frame_names(self) -> List[str]:
631
+ def frame_names(self) -> list[str]:
452
632
  """
453
633
  Get the names of all frames in the kinematic graph.
454
634
 
455
635
  Returns:
456
- List[str]: A list of frame names.
636
+ The list of frame names.
457
637
  """
458
- return list(self.frames_dict.keys())
459
-
460
- def transform(self, name: str) -> npt.NDArray:
461
- """
462
- Compute the transformation matrix for a given link, joint, or frame.
463
-
464
- Args:
465
- name (str): The name of the link, joint, or frame.
466
-
467
- Returns:
468
- npt.NDArray: The transformation matrix.
469
- """
470
- if name in self.transform_cache:
471
- return self.transform_cache[name]
472
-
473
- if name in self.joint_names():
474
- joint = self.joints_dict[name]
475
-
476
- if joint.initial_position != 0.0:
477
- msg = f"Ignoring unsupported initial position of joint '{name}'"
478
- logging.warning(msg=msg)
479
-
480
- transform = self.transform(name=joint.parent.name) @ joint.pose
481
- self.transform_cache[name] = transform
482
- return self.transform_cache[name]
483
-
484
- if name in self.link_names():
485
- link = self.links_dict[name]
486
638
 
487
- if link.name == self.root.name:
488
- return link.pose
489
-
490
- parent_joint = self.joints_connection_dict[(link.parent.name, link.name)]
491
- transform = self.transform(name=parent_joint.name) @ link.pose
492
- self.transform_cache[name] = transform
493
- return self.transform_cache[name]
494
-
495
- # It can only be a plain frame
496
- if name not in self.frame_names():
497
- raise ValueError(name)
498
-
499
- frame = self.frames_dict[name]
500
- transform = self.transform(name=frame.parent.name) @ frame.pose
501
- self.transform_cache[name] = transform
502
- return self.transform_cache[name]
503
-
504
- def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:
505
- """
506
- Compute the relative transformation matrix between two elements in the kinematic graph.
507
-
508
- Args:
509
- relative_to (str): The name of the reference element.
510
- name (str): The name of the element to compute the relative transformation for.
511
-
512
- Returns:
513
- npt.NDArray: The relative transformation matrix.
514
- """
515
- return np.linalg.inv(self.transform(name=relative_to)) @ self.transform(
516
- name=name
517
- )
639
+ return list(self.frames_dict.keys())
518
640
 
519
641
  def print_tree(self) -> None:
520
642
  """
521
643
  Print the tree structure of the kinematic graph.
522
644
  """
645
+
523
646
  import pptree
524
647
 
525
648
  root_node = self.root
@@ -531,24 +654,37 @@ class KinematicGraph:
531
654
  horizontal=True,
532
655
  )
533
656
 
657
+ @property
658
+ def joints_removed(self) -> list[JointDescription]:
659
+ """
660
+ Get the list of joints removed during the graph reduction.
661
+
662
+ Returns:
663
+ The list of removed joints.
664
+ """
665
+
666
+ return self._joints_removed
667
+
534
668
  @staticmethod
535
669
  def breadth_first_search(
536
- root: descriptions.LinkDescription,
537
- sort_children: Optional[Callable[[Any], Any]] = lambda link: link.name,
538
- ) -> Iterable[descriptions.LinkDescription]:
670
+ root: LinkDescription,
671
+ sort_children: Callable[[Any], Any] | None = lambda link: link.name,
672
+ ) -> Iterable[LinkDescription]:
539
673
  """
540
674
  Perform a breadth-first search (BFS) traversal of the kinematic graph.
541
675
 
542
676
  Args:
543
- root (descriptions.LinkDescription): The root link for BFS.
544
- sort_children (Optional[Callable[[Any], Any]]): A function to sort children of a node.
677
+ root: The root link for BFS.
678
+ sort_children: A function to sort children of a node.
545
679
 
546
680
  Yields:
547
- Iterable[descriptions.LinkDescription]: An iterable of link descriptions.
681
+ The links in the kinematic graph in BFS order.
548
682
  """
683
+
684
+ # Initialize the queue with the root node.
549
685
  queue = [root]
550
686
 
551
- # We assume that nodes have unique names, and mark a link as visited using
687
+ # We assume that nodes have unique names and mark a link as visited using
552
688
  # its name. This speeds up considerably object comparison.
553
689
  visited = []
554
690
  visited.append(root.name)
@@ -556,11 +692,14 @@ class KinematicGraph:
556
692
  yield root
557
693
 
558
694
  while len(queue) > 0:
695
+
696
+ # Extract the first element of the queue.
559
697
  l = queue.pop(0)
560
698
 
561
699
  # Note: sorting the links with their name so that the order of children
562
- # insertion does not matter when assigning the link index
700
+ # insertion does not matter when assigning the link index.
563
701
  for child in sorted(l.children, key=sort_children):
702
+
564
703
  if child.name in visited:
565
704
  continue
566
705
 
@@ -569,25 +708,29 @@ class KinematicGraph:
569
708
 
570
709
  yield child
571
710
 
572
- def __iter__(self) -> Iterable[descriptions.LinkDescription]:
711
+ # =================
712
+ # Sequence protocol
713
+ # =================
714
+
715
+ def __iter__(self) -> Iterator[LinkDescription]:
573
716
  yield from KinematicGraph.breadth_first_search(root=self.root)
574
717
 
575
- def __reversed__(self) -> Iterable[descriptions.LinkDescription]:
718
+ def __reversed__(self) -> Iterable[LinkDescription]:
576
719
  yield from reversed(list(iter(self)))
577
720
 
578
721
  def __len__(self) -> int:
579
722
  return len(list(iter(self)))
580
723
 
581
- def __contains__(self, item: Union[str, descriptions.LinkDescription]) -> bool:
724
+ def __contains__(self, item: str | LinkDescription) -> bool:
582
725
  if isinstance(item, str):
583
726
  return item in self.link_names()
584
727
 
585
- if isinstance(item, descriptions.LinkDescription):
728
+ if isinstance(item, LinkDescription):
586
729
  return item in set(iter(self))
587
730
 
588
731
  raise TypeError(type(item).__name__)
589
732
 
590
- def __getitem__(self, key: Union[int, str]) -> descriptions.LinkDescription:
733
+ def __getitem__(self, key: int | str) -> LinkDescription:
591
734
  if isinstance(key, str):
592
735
  if key not in self.link_names():
593
736
  raise KeyError(key)
@@ -601,3 +744,242 @@ class KinematicGraph:
601
744
  return list(iter(self))[key]
602
745
 
603
746
  raise TypeError(type(key).__name__)
747
+
748
+ def count(self, value: LinkDescription) -> int:
749
+ """
750
+ Count the occurrences of a link in the kinematic graph.
751
+ """
752
+ return list(iter(self)).count(value)
753
+
754
+ def index(self, value: LinkDescription, start: int = 0, stop: int = -1) -> int:
755
+ """
756
+ Find the index of a link in the kinematic graph.
757
+ """
758
+ return list(iter(self)).index(value, start, stop)
759
+
760
+
761
+ # ====================
762
+ # Other useful classes
763
+ # ====================
764
+
765
+
766
+ @dataclasses.dataclass(frozen=True)
767
+ class KinematicGraphTransforms:
768
+ """
769
+ Class to compute forward kinematics on a kinematic graph.
770
+
771
+ Attributes:
772
+ graph: The kinematic graph on which to compute forward kinematics.
773
+ """
774
+
775
+ graph: KinematicGraph
776
+
777
+ _transform_cache: dict[str, npt.NDArray] = dataclasses.field(
778
+ default_factory=dict, init=False, repr=False, compare=False
779
+ )
780
+
781
+ _initial_joint_positions: dict[str, float] = dataclasses.field(
782
+ init=False, repr=False, compare=False
783
+ )
784
+
785
+ def __post_init__(self) -> None:
786
+
787
+ super().__setattr__(
788
+ "_initial_joint_positions",
789
+ {joint.name: joint.initial_position for joint in self.graph.joints},
790
+ )
791
+
792
+ @property
793
+ def initial_joint_positions(self) -> npt.NDArray:
794
+ """
795
+ Get the initial joint positions of the kinematic graph.
796
+ """
797
+
798
+ return np.atleast_1d(
799
+ np.array(list(self._initial_joint_positions.values()))
800
+ ).astype(float)
801
+
802
+ @initial_joint_positions.setter
803
+ def initial_joint_positions(
804
+ self,
805
+ positions: npt.NDArray | Sequence,
806
+ joint_names: Sequence[str] | None = None,
807
+ ) -> None:
808
+
809
+ joint_names = (
810
+ joint_names
811
+ if joint_names is not None
812
+ else list(self._initial_joint_positions.keys())
813
+ )
814
+
815
+ s = np.atleast_1d(np.array(positions).squeeze())
816
+
817
+ if s.size != len(joint_names):
818
+ raise ValueError(s.size, len(joint_names))
819
+
820
+ for joint_name in joint_names:
821
+ if joint_name not in self._initial_joint_positions:
822
+ raise ValueError(joint_name)
823
+
824
+ # Clear transform cache.
825
+ self._transform_cache.clear()
826
+
827
+ # Update initial joint positions.
828
+ for joint_name, position in zip(joint_names, s, strict=True):
829
+ self._initial_joint_positions[joint_name] = position
830
+
831
+ def transform(self, name: str) -> npt.NDArray:
832
+ """
833
+ Compute the SE(3) transform of elements belonging to the kinematic graph.
834
+
835
+ Args:
836
+ name: The name of a link, a joint, or a frame.
837
+
838
+ Returns:
839
+ The 4x4 transform matrix of the element w.r.t. the model frame.
840
+ """
841
+
842
+ # If the transform was already computed, return it.
843
+ if name in self._transform_cache:
844
+ return self._transform_cache[name]
845
+
846
+ # If the name is a joint, compute M_H_J transform.
847
+ if name in self.graph.joint_names():
848
+
849
+ # Get the joint.
850
+ joint = self.graph.joints_dict[name]
851
+ assert joint.name == name
852
+
853
+ # Get the transform of the parent link.
854
+ M_H_L = self.transform(name=joint.parent.name)
855
+
856
+ # Rename the pose of the predecessor joint frame w.r.t. its parent link.
857
+ L_H_pre = joint.pose
858
+
859
+ # Compute the joint transform from the predecessor to the successor frame.
860
+ pre_H_J = self.pre_H_suc(
861
+ joint_type=joint.jtype,
862
+ joint_axis=joint.axis,
863
+ joint_position=self._initial_joint_positions[joint.name],
864
+ )
865
+
866
+ # Compute the M_H_J transform.
867
+ self._transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J
868
+ return self._transform_cache[name]
869
+
870
+ # If the name is a link, compute M_H_L transform.
871
+ if name in self.graph.link_names():
872
+
873
+ # Get the link.
874
+ link = self.graph.links_dict[name]
875
+
876
+ # Handle the pose between the __model__ frame and the root link.
877
+ if link.name == self.graph.root.name:
878
+ M_H_B = link.pose
879
+ return M_H_B
880
+
881
+ # Get the joint between the link and its parent.
882
+ parent_joint = self.graph.joints_connection_dict[
883
+ link.parent.name, link.name
884
+ ]
885
+
886
+ # Get the transform of the parent joint.
887
+ M_H_J = self.transform(name=parent_joint.name)
888
+
889
+ # Rename the pose of the link w.r.t. its parent joint.
890
+ J_H_L = link.pose
891
+
892
+ # Compute the M_H_L transform.
893
+ self._transform_cache[name] = M_H_J @ J_H_L
894
+ return self._transform_cache[name]
895
+
896
+ # It can only be a plain frame.
897
+ if name not in self.graph.frame_names():
898
+ raise ValueError(name)
899
+
900
+ # Get the frame.
901
+ frame = self.graph.frames_dict[name]
902
+
903
+ # Get the transform of the parent link.
904
+ M_H_L = self.transform(name=frame.parent.name)
905
+
906
+ # Rename the pose of the frame w.r.t. its parent link.
907
+ L_H_F = frame.pose
908
+
909
+ # Compute the M_H_F transform.
910
+ self._transform_cache[name] = M_H_L @ L_H_F
911
+ return self._transform_cache[name]
912
+
913
+ def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:
914
+ """
915
+ Compute the SE(3) relative transform of elements belonging to the kinematic graph.
916
+
917
+ Args:
918
+ relative_to: The name of the reference element.
919
+ name: The name of a link, a joint, or a frame.
920
+
921
+ Returns:
922
+ The 4x4 transform matrix of the element w.r.t. the desired frame.
923
+ """
924
+
925
+ import jaxsim.math
926
+
927
+ M_H_target = self.transform(name=name)
928
+ M_H_R = self.transform(name=relative_to)
929
+
930
+ # Compute the relative transform R_H_target, where R is the reference frame,
931
+ # and i the frame of the desired link|joint|frame.
932
+ return np.array(jaxsim.math.Transform.inverse(M_H_R)) @ M_H_target
933
+
934
+ @staticmethod
935
+ def pre_H_suc(
936
+ joint_type: JointType,
937
+ joint_axis: npt.NDArray,
938
+ joint_position: float | None = None,
939
+ ) -> npt.NDArray:
940
+ """
941
+ Compute the SE(3) transform from the predecessor to the successor frame.
942
+
943
+ Args:
944
+ joint_type: The type of the joint.
945
+ joint_axis: The axis of the joint.
946
+ joint_position: The position of the joint.
947
+
948
+ Returns:
949
+ The 4x4 transform matrix from the predecessor to the successor frame.
950
+ """
951
+
952
+ import jaxsim.math
953
+
954
+ return np.array(
955
+ jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)[
956
+ 0
957
+ ]
958
+ )
959
+
960
+ def find_parent_link_of_frame(self, name: str) -> str:
961
+ """
962
+ Find the parent link of a frame.
963
+
964
+ Args:
965
+ name: The name of the frame.
966
+
967
+ Returns:
968
+ The name of the parent link of the frame.
969
+ """
970
+
971
+ try:
972
+ frame = self.graph.frames_dict[name]
973
+ except KeyError as e:
974
+ raise ValueError(f"Frame '{name}' not found in the kinematic graph") from e
975
+
976
+ match frame.parent.name:
977
+ case parent_name if parent_name in self.graph.links_dict:
978
+ return parent_name
979
+
980
+ case parent_name if parent_name in self.graph.frames_dict:
981
+ return self.find_parent_link_of_frame(name=parent_name)
982
+
983
+ case _:
984
+ msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent.name}'"
985
+ raise RuntimeError(msg)