jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,263 +1,397 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  import dataclasses
3
5
  import functools
4
- from typing import (
5
- Any,
6
- Callable,
7
- Dict,
8
- Iterable,
9
- List,
10
- NamedTuple,
11
- Optional,
12
- Tuple,
13
- Union,
14
- )
6
+ from collections.abc import Callable, Iterable, Iterator, Sequence
7
+ from typing import Any
15
8
 
16
9
  import numpy as np
17
10
  import numpy.typing as npt
18
11
 
12
+ import jaxsim.utils
19
13
  from jaxsim import logging
20
14
  from jaxsim.utils import Mutability
21
15
 
22
- from . import descriptions
16
+ from .descriptions.joint import JointDescription, JointType
17
+ from .descriptions.link import LinkDescription
23
18
 
24
19
 
25
- class RootPose(NamedTuple):
20
+ @dataclasses.dataclass
21
+ class RootPose:
26
22
  """
27
23
  Represents the root pose in a kinematic graph.
28
24
 
29
25
  Attributes:
30
- root_position (npt.NDArray): A NumPy array of shape (3,) representing the root's position.
31
- root_quaternion (npt.NDArray): A NumPy array of shape (4,) representing the root's quaternion.
26
+ root_position: The 3D position of the root link of the graph.
27
+ root_quaternion:
28
+ The quaternion representing the rotation of the root link of the graph.
29
+
30
+ Note:
31
+ The root link of the kinematic graph is the base link.
32
32
  """
33
33
 
34
- root_position: npt.NDArray = np.zeros(3)
35
- root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0])
34
+ root_position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3))
35
+
36
+ root_quaternion: npt.NDArray = dataclasses.field(
37
+ default_factory=lambda: np.array([1.0, 0, 0, 0])
38
+ )
39
+
40
+ def __hash__(self) -> int:
41
+
42
+ from jaxsim.utils.wrappers import HashedNumpyArray
43
+
44
+ return hash(
45
+ (
46
+ HashedNumpyArray.hash_of_array(self.root_position),
47
+ HashedNumpyArray.hash_of_array(self.root_quaternion),
48
+ )
49
+ )
50
+
51
+ def __eq__(self, other: RootPose) -> bool:
52
+
53
+ if not isinstance(other, RootPose):
54
+ return False
55
+
56
+ if not np.allclose(self.root_position, other.root_position):
57
+ return False
58
+
59
+ if not np.allclose(self.root_quaternion, other.root_quaternion):
60
+ return False
36
61
 
37
- def __eq__(self, other):
38
- return (self.root_position == other.root_position).all() and (
39
- self.root_quaternion == other.root_quaternion
40
- ).all()
62
+ return True
41
63
 
42
64
 
43
65
  @dataclasses.dataclass(frozen=True)
44
- class KinematicGraph:
66
+ class KinematicGraph(Sequence[LinkDescription]):
45
67
  """
46
- Represents a kinematic graph of links and joints.
47
-
48
- Args:
49
- root (descriptions.LinkDescription): The root link of the kinematic graph.
50
- frames (List[descriptions.LinkDescription]): A list of frame links in the graph.
51
- joints (List[descriptions.JointDescription]): A list of joint descriptions in the graph.
52
- root_pose (RootPose): The root pose of the graph.
53
- transform_cache (Dict[str, npt.NDArray]): A dictionary to cache transformation matrices.
54
- extra_info (Dict[str, Any]): Additional information associated with the graph.
68
+ Class storing a kinematic graph having links as nodes and joints as edges.
55
69
 
56
70
  Attributes:
57
- links_dict (Dict[str, descriptions.LinkDescription]): A dictionary mapping link names to link descriptions.
58
- frames_dict (Dict[str, descriptions.LinkDescription]): A dictionary mapping frame names to frame link descriptions.
59
- joints_dict (Dict[str, descriptions.JointDescription]): A dictionary mapping joint names to joint descriptions.
60
- joints_connection_dict (Dict[Tuple[str, str], descriptions.JointDescription]): A dictionary mapping pairs of parent and child link names to joint descriptions.
71
+ root: The root node of the kinematic graph.
72
+ frames: List of frames rigidly attached to the graph nodes.
73
+ joints: List of joints connecting the graph nodes.
74
+ root_pose: The pose of the kinematic graph's root.
61
75
  """
62
76
 
63
- root: descriptions.LinkDescription
64
- frames: List[descriptions.LinkDescription] = dataclasses.field(default_factory=list)
65
- joints: List[descriptions.JointDescription] = dataclasses.field(
66
- default_factory=list
77
+ root: LinkDescription
78
+ frames: list[LinkDescription] = dataclasses.field(
79
+ default_factory=list, hash=False, compare=False
80
+ )
81
+ joints: list[JointDescription] = dataclasses.field(
82
+ default_factory=list, hash=False, compare=False
67
83
  )
68
84
 
69
85
  root_pose: RootPose = dataclasses.field(default_factory=RootPose)
70
86
 
71
- transform_cache: Dict[str, npt.NDArray] = dataclasses.field(
72
- repr=False, init=False, compare=False, default_factory=dict
87
+ # Private attribute storing optional additional info.
88
+ _extra_info: dict[str, Any] = dataclasses.field(
89
+ default_factory=dict, repr=False, hash=False, compare=False
73
90
  )
74
91
 
75
- extra_info: Dict[str, Any] = dataclasses.field(
76
- repr=False, compare=False, default_factory=dict
92
+ # Private attribute storing the unconnected joints from the parsed model and
93
+ # the joints removed after model reduction.
94
+ _joints_removed: list[JointDescription] = dataclasses.field(
95
+ default_factory=list, repr=False, hash=False, compare=False
77
96
  )
78
97
 
79
98
  @functools.cached_property
80
- def links_dict(self) -> Dict[str, descriptions.LinkDescription]:
99
+ def links_dict(self) -> dict[str, LinkDescription]:
100
+ """
101
+ Get a dictionary of links indexed by their name.
102
+ """
81
103
  return {l.name: l for l in iter(self)}
82
104
 
83
105
  @functools.cached_property
84
- def frames_dict(self) -> Dict[str, descriptions.LinkDescription]:
106
+ def frames_dict(self) -> dict[str, LinkDescription]:
107
+ """
108
+ Get a dictionary of frames indexed by their name.
109
+ """
85
110
  return {f.name: f for f in self.frames}
86
111
 
87
112
  @functools.cached_property
88
- def joints_dict(self) -> Dict[str, descriptions.JointDescription]:
113
+ def joints_dict(self) -> dict[str, JointDescription]:
114
+ """
115
+ Get a dictionary of joints indexed by their name.
116
+ """
89
117
  return {j.name: j for j in self.joints}
90
118
 
91
119
  @functools.cached_property
92
120
  def joints_connection_dict(
93
121
  self,
94
- ) -> Dict[Tuple[str, str], descriptions.JointDescription]:
95
- return {(j.parent.name, j.child.name): j for j in self.joints}
96
-
97
- def __post_init__(self):
122
+ ) -> dict[tuple[str, str], JointDescription]:
98
123
  """
99
- Post-initialization method to set various properties and validate the kinematic graph.
124
+ Get a dictionary of joints indexed by the tuple (parent, child) link names.
100
125
  """
101
- # Assign the link index traversing the graph with BFS.
102
- # Here we assume the model is fixed-base, therefore the base link will
103
- # have index 0. We will deal with the floating base in a later stage,
104
- # when this Model object is converted to the physics model.
126
+ return {(j.parent.name, j.child.name): j for j in self.joints}
127
+
128
+ def __post_init__(self) -> None:
129
+
130
+ # Assign the link index by traversing the graph with BFS.
131
+ # Here we assume the model being fixed-base, therefore the base link will
132
+ # have index 0. We will deal with the floating base in a later stage.
105
133
  for index, link in enumerate(self):
106
134
  link.mutable(validate=False).index = index
107
135
 
108
- # Order frames with their name
136
+ # Get the names of the links, frames, and joints.
137
+ link_names = [l.name for l in self]
138
+ frame_names = [f.name for f in self.frames]
139
+ joint_names = [j.name for j in self.joints]
140
+
141
+ # Make sure that they are unique.
142
+ assert len(link_names) == len(set(link_names))
143
+ assert len(frame_names) == len(set(frame_names))
144
+ assert len(joint_names) == len(set(joint_names))
145
+ assert set(link_names).isdisjoint(set(frame_names))
146
+ assert set(link_names).isdisjoint(set(joint_names))
147
+
148
+ # Order frames with their name.
109
149
  super().__setattr__("frames", sorted(self.frames, key=lambda f: f.name))
110
150
 
111
151
  # Assign the frame index following the name-based indexing.
112
- # Also here, we assume the model is fixed-base, therefore the first frame will
113
- # have last_link_idx + 1. These frames are not part of the physics model.
152
+ # We assume the model being fixed-base, therefore the first frame will
153
+ # have last_link_idx + 1.
114
154
  for index, frame in enumerate(self.frames):
115
- frame.index = index + len(self.link_names())
155
+ with frame.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
156
+ frame.index = int(index + len(self.link_names()))
116
157
 
117
- # Number joints so that their index matches their child link index
158
+ # Number joints so that their index matches their child link index.
159
+ # Therefore, the first joint has index 1.
118
160
  links_dict = {l.name: l for l in iter(self)}
119
161
  for joint in self.joints:
120
162
  with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
121
163
  joint.index = links_dict[joint.child.name].index
122
164
 
123
- # Check that joint indices are unique
165
+ # Check that joint indices are unique.
124
166
  assert len([j.index for j in self.joints]) == len(
125
167
  {j.index for j in self.joints}
126
168
  )
127
169
 
128
- # Order joints with their indices
170
+ # Order joints with their indices.
129
171
  super().__setattr__("joints", sorted(self.joints, key=lambda j: j.index))
130
172
 
131
173
  @staticmethod
132
174
  def build_from(
133
- links: List[descriptions.LinkDescription],
134
- joints: List[descriptions.JointDescription],
175
+ links: list[LinkDescription],
176
+ joints: list[JointDescription],
177
+ frames: list[LinkDescription] | None = None,
135
178
  root_link_name: str | None = None,
136
179
  root_pose: RootPose = RootPose(),
137
- ) -> "KinematicGraph":
180
+ ) -> KinematicGraph:
138
181
  """
139
- Build a KinematicGraph from a list of links and joints.
182
+ Build a KinematicGraph from links, joints, and frames.
140
183
 
141
184
  Args:
142
- links (List[descriptions.LinkDescription]): A list of link descriptions.
143
- joints (List[descriptions.JointDescription]): A list of joint descriptions.
144
- root_link_name (str, optional): The name of the root link. If not provided, it's assumed to be the first link's name.
145
- root_pose (RootPose, optional): The root pose of the kinematic graph.
185
+ links: A list of link descriptions.
186
+ joints: A list of joint descriptions.
187
+ frames: A list of frame descriptions.
188
+ root_link_name:
189
+ The name of the root link. If not provided, it's assumed to be the
190
+ first link's name.
191
+ root_pose: The root pose of the kinematic graph.
146
192
 
147
193
  Returns:
148
- KinematicGraph: The constructed kinematic graph.
194
+ The resulting kinematic graph.
149
195
  """
196
+
197
+ # Consider the first link as the root link if not provided.
150
198
  if root_link_name is None:
151
199
  root_link_name = links[0].name
200
+ logging.debug(msg=f"Assuming '{root_link_name}' as the root link")
152
201
 
153
202
  # Couple links and joints and create the graph of links.
154
- # Note that the pose of the frames is not updated; it's the caller's
203
+ # Note that the pose of the frames is not updated; it is the caller's
155
204
  # responsibility to update their pose if they want to use them.
156
- graph_root_node, graph_joints, graph_frames = KinematicGraph.create_graph(
157
- links=links, joints=joints, root_link_name=root_link_name
205
+ (
206
+ graph_root_node,
207
+ graph_joints,
208
+ graph_frames,
209
+ unconnected_links,
210
+ unconnected_joints,
211
+ unconnected_frames,
212
+ ) = KinematicGraph._create_graph(
213
+ links=links, joints=joints, root_link_name=root_link_name, frames=frames
158
214
  )
159
215
 
160
- for frame in graph_frames:
161
- logging.warning(msg=f"Ignoring unconnected link / frame: '{frame.name}'")
216
+ for link in unconnected_links:
217
+ logging.warning(msg=f"Ignoring unconnected link: '{link.name}'")
218
+
219
+ for joint in unconnected_joints:
220
+ logging.warning(msg=f"Ignoring unconnected joint: '{joint.name}'")
221
+
222
+ for frame in unconnected_frames:
223
+ logging.warning(msg=f"Ignoring unconnected frame: '{frame.name}'")
162
224
 
163
225
  return KinematicGraph(
164
- root=graph_root_node, joints=graph_joints, frames=[], root_pose=root_pose
226
+ root=graph_root_node,
227
+ joints=graph_joints,
228
+ frames=graph_frames,
229
+ root_pose=root_pose,
230
+ _joints_removed=unconnected_joints,
165
231
  )
166
232
 
167
233
  @staticmethod
168
- def create_graph(
169
- links: List[descriptions.LinkDescription],
170
- joints: List[descriptions.JointDescription],
234
+ def _create_graph(
235
+ links: list[LinkDescription],
236
+ joints: list[JointDescription],
171
237
  root_link_name: str,
172
- ) -> Tuple[
173
- descriptions.LinkDescription,
174
- List[descriptions.JointDescription],
175
- List[descriptions.LinkDescription],
238
+ frames: list[LinkDescription] | None = None,
239
+ ) -> tuple[
240
+ LinkDescription,
241
+ list[JointDescription],
242
+ list[LinkDescription],
243
+ list[LinkDescription],
244
+ list[JointDescription],
245
+ list[LinkDescription],
176
246
  ]:
177
247
  """
178
- Create a kinematic graph from lists of links and joints.
248
+ Low-level creator of kinematic graph components.
179
249
 
180
250
  Args:
181
- links (List[descriptions.LinkDescription]): A list of link descriptions.
182
- joints (List[descriptions.JointDescription]): A list of joint descriptions.
183
- root_link_name (str): The name of the root link.
251
+ links: A list of parsed link descriptions.
252
+ joints: A list of parsed joint descriptions.
253
+ root_link_name: The name of the root link used as root node of the graph.
254
+ frames: A list of parsed frame descriptions.
184
255
 
185
256
  Returns:
186
- Tuple[descriptions.LinkDescription, List[descriptions.JointDescription], List[descriptions.LinkDescription]]:
187
- A tuple containing the root link, list of joints, and list of frames in the graph.
257
+ A tuple containing the root node of the graph (defining the entire kinematic
258
+ tree by iterating on its child nodes), the list of joints representing the
259
+ actual graph edges, the list of frames rigidly attached to the graph nodes,
260
+ the list of unconnected links, the list of unconnected joints, and the list
261
+ of unconnected frames.
188
262
  """
189
263
 
190
- # Create a dict that maps link name to the link, for easy retrieval
191
- links_dict: Dict[str, descriptions.LinkDescription] = {
264
+ # Create a dictionary that maps the link name to the link, for easy retrieval.
265
+ links_dict: dict[str, LinkDescription] = {
192
266
  l.name: l.mutable(validate=False) for l in links
193
267
  }
194
268
 
269
+ # Create an empty list of frames if not provided.
270
+ frames = frames if frames is not None else []
271
+
272
+ # Create a dictionary that maps the frame name to the frame, for easy retrieval.
273
+ frames_dict = {frame.name: frame for frame in frames}
274
+
275
+ # Check that our parser correctly resolved the frame's parent to be a link.
276
+ for frame in frames:
277
+ assert frame.parent.name != "", frame
278
+ assert frame.parent.name is not None, frame
279
+ assert frame.parent.name != "__model__", frame
280
+ assert frame.parent.name not in frames_dict, frame
281
+
282
+ # ===========================================================
283
+ # Populate the kinematic graph with links, joints, and frames
284
+ # ===========================================================
285
+
286
+ # Check the existence of the root link.
195
287
  if root_link_name not in links_dict:
196
288
  raise ValueError(root_link_name)
197
289
 
198
- # Reset the connections of the root link
290
+ # Reset the connections of the root link.
199
291
  for link in links_dict.values():
200
- link.children = []
292
+ link.children = tuple()
201
293
 
202
- # Couple links and joints creating the final kinematic graph
294
+ # Couple links and joints creating the kinematic graph.
203
295
  for joint in joints:
204
- # Get the parent and child links of the joint
296
+
297
+ # Get the parent and child links of the joint.
205
298
  parent_link = links_dict[joint.parent.name]
206
299
  child_link = links_dict[joint.child.name]
207
300
 
208
301
  assert child_link.name == joint.child.name
209
302
  assert parent_link.name == joint.parent.name
210
303
 
211
- # Assign link parent
304
+ # Assign link's parent.
212
305
  child_link.parent = parent_link
213
306
 
214
- # Assign link children and make sure they are unique
307
+ # Assign link's children and make sure they are unique.
215
308
  if child_link.name not in {l.name for l in parent_link.children}:
216
- parent_link.children.append(child_link)
309
+ with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION):
310
+ parent_link.children = (*parent_link.children, child_link)
217
311
 
218
- # Collect all the links of the kinematic graph
312
+ # Collect all the links of the kinematic graph.
219
313
  all_links_in_graph = list(
220
314
  KinematicGraph.breadth_first_search(root=links_dict[root_link_name])
221
315
  )
316
+
317
+ # Get the names of all links in the kinematic graph.
222
318
  all_link_names_in_graph = [l.name for l in all_links_in_graph]
223
319
 
224
- # Collect all the joints not part of the kinematic graph
225
- removed_joints = [
226
- j
227
- for j in joints
228
- if not {j.parent.name, j.child.name}.issubset(all_link_names_in_graph)
320
+ # Collect all the joints of the kinematic graph.
321
+ all_joints_in_graph = [
322
+ joint
323
+ for joint in joints
324
+ if joint.parent.name in all_link_names_in_graph
325
+ and joint.child.name in all_link_names_in_graph
229
326
  ]
230
327
 
231
- for removed_joint in removed_joints:
232
- msg = "Joint '{}' has been removed for the graph because unconnected"
233
- logging.info(msg=msg.format(removed_joint.name))
328
+ # Get the names of all joints in the kinematic graph.
329
+ all_joint_names_in_graph = [j.name for j in all_joints_in_graph]
234
330
 
235
- # Store as frames all the links that are not part of the kinematic graph
236
- frames = list(set(links) - set(all_links_in_graph))
331
+ # Collect all the frames of the kinematic graph.
332
+ # Note: our parser ensures that the parent of a frame is not another frame.
333
+ all_frames_in_graph = [
334
+ frame for frame in frames if frame.parent.name in all_link_names_in_graph
335
+ ]
237
336
 
238
- # Update the frames. In particular, reset their children. The other properties
239
- # are kept as they are, and it's caller responsibility to update them if needed.
240
- for frame in frames:
241
- frame.children = []
242
- msg = f"Link '{frame.name}' became a frame"
243
- logging.info(msg=msg)
337
+ # Get the names of all frames in the kinematic graph.
338
+ all_frames_names_in_graph = [f.name for f in all_frames_in_graph]
339
+
340
+ # ============================
341
+ # Collect unconnected elements
342
+ # ============================
343
+
344
+ # Collect all the joints that are not part of the kinematic graph.
345
+ removed_joints = [j for j in joints if j.name not in all_joint_names_in_graph]
346
+
347
+ for joint in removed_joints:
348
+ msg = "Joint '{}' is unconnected and it will be removed"
349
+ logging.debug(msg=msg.format(joint.name))
350
+
351
+ # Collect all the links that are not part of the kinematic graph.
352
+ unconnected_links = [l for l in links if l.name not in all_link_names_in_graph]
353
+
354
+ # Update the unconnected links by removing their children. The other properties
355
+ # are left untouched, it's caller responsibility to post-process them if needed.
356
+ for link in unconnected_links:
357
+ link.children = tuple()
358
+ msg = "Link '{}' won't be part of the kinematic graph because unconnected"
359
+ logging.debug(msg=msg.format(link.name))
360
+
361
+ # Collect all the frames that are not part of the kinematic graph.
362
+ unconnected_frames = [
363
+ f for f in frames if f.name not in all_frames_names_in_graph
364
+ ]
365
+
366
+ for frame in unconnected_frames:
367
+ msg = "Frame '{}' won't be part of the kinematic graph because unconnected"
368
+ logging.debug(msg=msg.format(frame.name))
244
369
 
245
370
  return (
246
371
  links_dict[root_link_name].mutable(mutable=False),
247
372
  list(set(joints) - set(removed_joints)),
248
- frames,
373
+ all_frames_in_graph,
374
+ unconnected_links,
375
+ list(set(removed_joints)),
376
+ unconnected_frames,
249
377
  )
250
378
 
251
- def reduce(self, considered_joints: List[str]) -> "KinematicGraph":
379
+ def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:
252
380
  """
253
- Reduce the kinematic graph by removing specified joints and lumping the mass and inertia of removed links into their parent links.
381
+ Reduce the kinematic graph by removing unspecified joints.
382
+
383
+ When a joint is removed, the mass and inertia of its child link are lumped
384
+ with those of its parent link, obtaining a new link that combines the two.
385
+ The description of the removed joint specifies the default angle (usually 0)
386
+ that is considered when the joint is removed.
254
387
 
255
388
  Args:
256
- considered_joints (List[str]): A list of joint names to consider.
389
+ considered_joints: A list of joint names to consider.
257
390
 
258
391
  Returns:
259
- KinematicGraph: The reduced kinematic graph.
392
+ The reduced kinematic graph.
260
393
  """
394
+
261
395
  # The current object represents the complete kinematic graph
262
396
  full_graph = self
263
397
 
@@ -268,11 +402,11 @@ class KinematicGraph:
268
402
 
269
403
  # Return early if there is no action to take
270
404
  if len(joint_names_to_remove) == 0:
271
- logging.info(f"The kinematic graph doesn't need to be reduced")
405
+ logging.info("The kinematic graph doesn't need to be reduced")
272
406
  return copy.deepcopy(self)
273
407
 
274
408
  # Check if all considered joints are part of the full kinematic graph
275
- if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
409
+ if len(set(considered_joints) - {j.name for j in full_graph.joints}) != 0:
276
410
  extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
277
411
  msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
278
412
  raise ValueError(msg)
@@ -281,6 +415,9 @@ class KinematicGraph:
281
415
  links_dict = copy.deepcopy(full_graph.links_dict)
282
416
  joints_dict = copy.deepcopy(full_graph.joints_dict)
283
417
 
418
+ # Create the object to compute forward kinematics.
419
+ fk = KinematicGraphTransforms(graph=full_graph)
420
+
284
421
  # The following steps are implemented below in order to create the reduced graph:
285
422
  #
286
423
  # 1. Lump the mass of the removed links into their parent
@@ -320,7 +457,7 @@ class KinematicGraph:
320
457
  msg.format(
321
458
  link_to_remove.name,
322
459
  self.joints_connection_dict[
323
- (parent_of_link_to_remove.name, link_to_remove.name)
460
+ parent_of_link_to_remove.name, link_to_remove.name
324
461
  ].name,
325
462
  parent_of_link_to_remove.name,
326
463
  )
@@ -329,14 +466,14 @@ class KinematicGraph:
329
466
  # Lump the link
330
467
  lumped_link = parent_of_link_to_remove.lump_with(
331
468
  link=link_to_remove,
332
- lumped_H_removed=full_graph.relative_transform(
469
+ lumped_H_removed=fk.relative_transform(
333
470
  relative_to=parent_of_link_to_remove.name, name=link_to_remove.name
334
471
  ),
335
472
  )
336
473
 
337
474
  # Pop the original two links from the dictionary...
338
- links_dict.pop(link_to_remove.name)
339
- links_dict.pop(parent_of_link_to_remove.name)
475
+ _ = links_dict.pop(link_to_remove.name)
476
+ _ = links_dict.pop(parent_of_link_to_remove.name)
340
477
 
341
478
  # ... and insert the lumped link (having the same name of the parent)
342
479
  links_dict[lumped_link.name] = lumped_link
@@ -346,11 +483,13 @@ class KinematicGraph:
346
483
  links_dict[link_to_remove.name] = lumped_link
347
484
 
348
485
  # As a consequence of the back-insertion, we need to adjust the resulting
349
- # lumped link of links that have been removed previously
486
+ # lumped link of links that have been removed previously.
487
+ # Note: in the dictionary, only items whose key is not matching value.name
488
+ # are links that have been removed.
350
489
  for previously_removed_link_name in {
351
- k
352
- for k, v in links_dict.items()
353
- if k != v.name and v.name == link_to_remove.name
490
+ link_name
491
+ for link_name, link in links_dict.items()
492
+ if link_name != link.name and link.name == link_to_remove.name
354
493
  }:
355
494
  links_dict[previously_removed_link_name] = lumped_link
356
495
 
@@ -370,7 +509,7 @@ class KinematicGraph:
370
509
  # Update the pose. Note that after the lumping process, the dict entry
371
510
  # links_dict[joint.parent.name] contains the final lumped link
372
511
  with joint.mutable_context(mutability=Mutability.MUTABLE):
373
- joint.pose = full_graph.relative_transform(
512
+ joint.pose = fk.relative_transform(
374
513
  relative_to=links_dict[joint.parent.name].name, name=joint.name
375
514
  )
376
515
  with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
@@ -396,135 +535,114 @@ class KinematicGraph:
396
535
 
397
536
  # Create the reduced graph data. We pass the full list of links so that those
398
537
  # that are not part of the graph will be returned as frames.
399
- reduced_root_node, reduced_joints, reduced_frames = KinematicGraph.create_graph(
538
+ (
539
+ reduced_root_node,
540
+ reduced_joints,
541
+ reduced_frames,
542
+ unconnected_links,
543
+ unconnected_joints,
544
+ unconnected_frames,
545
+ ) = KinematicGraph._create_graph(
400
546
  links=list(full_graph_links_dict.values()),
401
547
  joints=[joints_dict[joint_name] for joint_name in considered_joints],
402
548
  root_link_name=full_graph.root.name,
403
549
  )
404
550
 
405
- # Create the reduced graph
551
+ assert {f.name for f in self.frames}.isdisjoint(
552
+ {f.name for f in unconnected_frames + reduced_frames}
553
+ )
554
+
555
+ for link in unconnected_links:
556
+ logging.debug(msg=f"Link '{link.name}' is unconnected and became a frame")
557
+
558
+ # Create the reduced graph.
406
559
  reduced_graph = KinematicGraph(
407
560
  root=reduced_root_node,
408
561
  joints=reduced_joints,
409
- frames=reduced_frames,
562
+ frames=self.frames + unconnected_links + reduced_frames,
410
563
  root_pose=full_graph.root_pose,
564
+ _joints_removed=(
565
+ self._joints_removed
566
+ + unconnected_joints
567
+ + [joints_dict[name] for name in joint_names_to_remove]
568
+ ),
411
569
  )
412
570
 
413
571
  # ================================================================
414
572
  # 4. Resolve the pose of the frames wrt their reduced graph parent
415
573
  # ================================================================
416
574
 
417
- # Update frames properties using the transforms from the full graph
418
- for frame in reduced_graph.frames:
419
- # Get the link in which the removed link was lumped into
420
- new_parent_link = links_dict[frame.name]
575
+ # Build a new object to compute FK on the reduced graph.
576
+ fk_reduced = KinematicGraphTransforms(graph=reduced_graph)
421
577
 
422
- msg = f"New parent of frame '{frame.name}' is '{new_parent_link.name}'"
423
- logging.info(msg)
578
+ # We need to adjust the pose of the frames since their parent link
579
+ # could have been removed by the reduction process.
580
+ for frame in reduced_graph.frames:
424
581
 
425
- # Update the connection of the frame
426
- frame.parent = new_parent_link
427
- frame.pose = full_graph.relative_transform(
428
- relative_to=new_parent_link.name, name=frame.name
582
+ # Always find the real parent link of the frame
583
+ name_of_new_parent_link = fk_reduced.find_parent_link_of_frame(
584
+ name=frame.name
429
585
  )
586
+ assert name_of_new_parent_link in reduced_graph, name_of_new_parent_link
430
587
 
431
- # Update frame data
432
- frame.mass = 0.0
433
- frame.inertia = np.zeros_like(frame.inertia)
588
+ # Notify the user if the parent link has changed.
589
+ if name_of_new_parent_link != frame.parent.name:
590
+ msg = "New parent of frame '{}' is '{}'"
591
+ logging.debug(msg=msg.format(frame.name, name_of_new_parent_link))
434
592
 
435
- # Return the reduced graph
593
+ # Always recompute the pose of the frame, and set zero inertial params.
594
+ with frame.mutable_context(jaxsim.utils.Mutability.MUTABLE_NO_VALIDATION):
595
+
596
+ # Update kinematic parameters of the frame.
597
+ # Note that here we compute the transform using the FK object of the
598
+ # full model, so that we are sure that the kinematic is not altered.
599
+ frame.pose = fk.relative_transform(
600
+ relative_to=name_of_new_parent_link, name=frame.name
601
+ )
602
+
603
+ # Update the parent link such that the pose is expressed in its frame.
604
+ frame.parent = reduced_graph.links_dict[name_of_new_parent_link]
605
+
606
+ # Update dynamic parameters of the frame.
607
+ frame.mass = 0.0
608
+ frame.inertia = np.zeros_like(frame.inertia)
609
+
610
+ # Return the reduced graph.
436
611
  return reduced_graph
437
612
 
438
- def link_names(self) -> List[str]:
613
+ def link_names(self) -> list[str]:
439
614
  """
440
- Get the names of all links in the kinematic graph.
615
+ Get the names of all links in the kinematic graph (i.e. the nodes).
441
616
 
442
617
  Returns:
443
- List[str]: A list of link names.
618
+ The list of link names.
444
619
  """
445
620
  return list(self.links_dict.keys())
446
621
 
447
- def joint_names(self) -> List[str]:
622
+ def joint_names(self) -> list[str]:
448
623
  """
449
- Get the names of all joints in the kinematic graph.
624
+ Get the names of all joints in the kinematic graph (i.e. the edges).
450
625
 
451
626
  Returns:
452
- List[str]: A list of joint names.
627
+ The list of joint names.
453
628
  """
454
629
  return list(self.joints_dict.keys())
455
630
 
456
- def frame_names(self) -> List[str]:
631
+ def frame_names(self) -> list[str]:
457
632
  """
458
633
  Get the names of all frames in the kinematic graph.
459
634
 
460
635
  Returns:
461
- List[str]: A list of frame names.
636
+ The list of frame names.
462
637
  """
463
- return list(self.frames_dict.keys())
464
-
465
- def transform(self, name: str) -> npt.NDArray:
466
- """
467
- Compute the transformation matrix for a given link, joint, or frame.
468
-
469
- Args:
470
- name (str): The name of the link, joint, or frame.
471
-
472
- Returns:
473
- npt.NDArray: The transformation matrix.
474
- """
475
- if name in self.transform_cache:
476
- return self.transform_cache[name]
477
-
478
- if name in self.joint_names():
479
- joint = self.joints_dict[name]
480
-
481
- if joint.initial_position != 0.0:
482
- msg = f"Ignoring unsupported initial position of joint '{name}'"
483
- logging.warning(msg=msg)
484
-
485
- transform = self.transform(name=joint.parent.name) @ joint.pose
486
- self.transform_cache[name] = transform
487
- return self.transform_cache[name]
488
-
489
- if name in self.link_names():
490
- link = self.links_dict[name]
491
638
 
492
- if link.name == self.root.name:
493
- return link.pose
494
-
495
- parent_joint = self.joints_connection_dict[(link.parent.name, link.name)]
496
- transform = self.transform(name=parent_joint.name) @ link.pose
497
- self.transform_cache[name] = transform
498
- return self.transform_cache[name]
499
-
500
- # It can only be a plain frame
501
- if name not in self.frame_names():
502
- raise ValueError(name)
503
-
504
- frame = self.frames_dict[name]
505
- transform = self.transform(name=frame.parent.name) @ frame.pose
506
- self.transform_cache[name] = transform
507
- return self.transform_cache[name]
508
-
509
- def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:
510
- """
511
- Compute the relative transformation matrix between two elements in the kinematic graph.
512
-
513
- Args:
514
- relative_to (str): The name of the reference element.
515
- name (str): The name of the element to compute the relative transformation for.
516
-
517
- Returns:
518
- npt.NDArray: The relative transformation matrix.
519
- """
520
- return np.linalg.inv(self.transform(name=relative_to)) @ self.transform(
521
- name=name
522
- )
639
+ return list(self.frames_dict.keys())
523
640
 
524
641
  def print_tree(self) -> None:
525
642
  """
526
643
  Print the tree structure of the kinematic graph.
527
644
  """
645
+
528
646
  import pptree
529
647
 
530
648
  root_node = self.root
@@ -536,24 +654,37 @@ class KinematicGraph:
536
654
  horizontal=True,
537
655
  )
538
656
 
657
+ @property
658
+ def joints_removed(self) -> list[JointDescription]:
659
+ """
660
+ Get the list of joints removed during the graph reduction.
661
+
662
+ Returns:
663
+ The list of removed joints.
664
+ """
665
+
666
+ return self._joints_removed
667
+
539
668
  @staticmethod
540
669
  def breadth_first_search(
541
- root: descriptions.LinkDescription,
542
- sort_children: Optional[Callable[[Any], Any]] = lambda link: link.name,
543
- ) -> Iterable[descriptions.LinkDescription]:
670
+ root: LinkDescription,
671
+ sort_children: Callable[[Any], Any] | None = lambda link: link.name,
672
+ ) -> Iterable[LinkDescription]:
544
673
  """
545
674
  Perform a breadth-first search (BFS) traversal of the kinematic graph.
546
675
 
547
676
  Args:
548
- root (descriptions.LinkDescription): The root link for BFS.
549
- sort_children (Optional[Callable[[Any], Any]]): A function to sort children of a node.
677
+ root: The root link for BFS.
678
+ sort_children: A function to sort children of a node.
550
679
 
551
680
  Yields:
552
- Iterable[descriptions.LinkDescription]: An iterable of link descriptions.
681
+ The links in the kinematic graph in BFS order.
553
682
  """
683
+
684
+ # Initialize the queue with the root node.
554
685
  queue = [root]
555
686
 
556
- # We assume that nodes have unique names, and mark a link as visited using
687
+ # We assume that nodes have unique names and mark a link as visited using
557
688
  # its name. This speeds up considerably object comparison.
558
689
  visited = []
559
690
  visited.append(root.name)
@@ -561,11 +692,14 @@ class KinematicGraph:
561
692
  yield root
562
693
 
563
694
  while len(queue) > 0:
695
+
696
+ # Extract the first element of the queue.
564
697
  l = queue.pop(0)
565
698
 
566
699
  # Note: sorting the links with their name so that the order of children
567
- # insertion does not matter when assigning the link index
700
+ # insertion does not matter when assigning the link index.
568
701
  for child in sorted(l.children, key=sort_children):
702
+
569
703
  if child.name in visited:
570
704
  continue
571
705
 
@@ -574,25 +708,29 @@ class KinematicGraph:
574
708
 
575
709
  yield child
576
710
 
577
- def __iter__(self) -> Iterable[descriptions.LinkDescription]:
711
+ # =================
712
+ # Sequence protocol
713
+ # =================
714
+
715
+ def __iter__(self) -> Iterator[LinkDescription]:
578
716
  yield from KinematicGraph.breadth_first_search(root=self.root)
579
717
 
580
- def __reversed__(self) -> Iterable[descriptions.LinkDescription]:
718
+ def __reversed__(self) -> Iterable[LinkDescription]:
581
719
  yield from reversed(list(iter(self)))
582
720
 
583
721
  def __len__(self) -> int:
584
722
  return len(list(iter(self)))
585
723
 
586
- def __contains__(self, item: Union[str, descriptions.LinkDescription]) -> bool:
724
+ def __contains__(self, item: str | LinkDescription) -> bool:
587
725
  if isinstance(item, str):
588
726
  return item in self.link_names()
589
727
 
590
- if isinstance(item, descriptions.LinkDescription):
728
+ if isinstance(item, LinkDescription):
591
729
  return item in set(iter(self))
592
730
 
593
731
  raise TypeError(type(item).__name__)
594
732
 
595
- def __getitem__(self, key: Union[int, str]) -> descriptions.LinkDescription:
733
+ def __getitem__(self, key: int | str) -> LinkDescription:
596
734
  if isinstance(key, str):
597
735
  if key not in self.link_names():
598
736
  raise KeyError(key)
@@ -606,3 +744,242 @@ class KinematicGraph:
606
744
  return list(iter(self))[key]
607
745
 
608
746
  raise TypeError(type(key).__name__)
747
+
748
+ def count(self, value: LinkDescription) -> int:
749
+ """
750
+ Count the occurrences of a link in the kinematic graph.
751
+ """
752
+ return list(iter(self)).count(value)
753
+
754
+ def index(self, value: LinkDescription, start: int = 0, stop: int = -1) -> int:
755
+ """
756
+ Find the index of a link in the kinematic graph.
757
+ """
758
+ return list(iter(self)).index(value, start, stop)
759
+
760
+
761
+ # ====================
762
+ # Other useful classes
763
+ # ====================
764
+
765
+
766
+ @dataclasses.dataclass(frozen=True)
767
+ class KinematicGraphTransforms:
768
+ """
769
+ Class to compute forward kinematics on a kinematic graph.
770
+
771
+ Attributes:
772
+ graph: The kinematic graph on which to compute forward kinematics.
773
+ """
774
+
775
+ graph: KinematicGraph
776
+
777
+ _transform_cache: dict[str, npt.NDArray] = dataclasses.field(
778
+ default_factory=dict, init=False, repr=False, compare=False
779
+ )
780
+
781
+ _initial_joint_positions: dict[str, float] = dataclasses.field(
782
+ init=False, repr=False, compare=False
783
+ )
784
+
785
+ def __post_init__(self) -> None:
786
+
787
+ super().__setattr__(
788
+ "_initial_joint_positions",
789
+ {joint.name: joint.initial_position for joint in self.graph.joints},
790
+ )
791
+
792
+ @property
793
+ def initial_joint_positions(self) -> npt.NDArray:
794
+ """
795
+ Get the initial joint positions of the kinematic graph.
796
+ """
797
+
798
+ return np.atleast_1d(
799
+ np.array(list(self._initial_joint_positions.values()))
800
+ ).astype(float)
801
+
802
+ @initial_joint_positions.setter
803
+ def initial_joint_positions(
804
+ self,
805
+ positions: npt.NDArray | Sequence,
806
+ joint_names: Sequence[str] | None = None,
807
+ ) -> None:
808
+
809
+ joint_names = (
810
+ joint_names
811
+ if joint_names is not None
812
+ else list(self._initial_joint_positions.keys())
813
+ )
814
+
815
+ s = np.atleast_1d(np.array(positions).squeeze())
816
+
817
+ if s.size != len(joint_names):
818
+ raise ValueError(s.size, len(joint_names))
819
+
820
+ for joint_name in joint_names:
821
+ if joint_name not in self._initial_joint_positions:
822
+ raise ValueError(joint_name)
823
+
824
+ # Clear transform cache.
825
+ self._transform_cache.clear()
826
+
827
+ # Update initial joint positions.
828
+ for joint_name, position in zip(joint_names, s, strict=True):
829
+ self._initial_joint_positions[joint_name] = position
830
+
831
+ def transform(self, name: str) -> npt.NDArray:
832
+ """
833
+ Compute the SE(3) transform of elements belonging to the kinematic graph.
834
+
835
+ Args:
836
+ name: The name of a link, a joint, or a frame.
837
+
838
+ Returns:
839
+ The 4x4 transform matrix of the element w.r.t. the model frame.
840
+ """
841
+
842
+ # If the transform was already computed, return it.
843
+ if name in self._transform_cache:
844
+ return self._transform_cache[name]
845
+
846
+ # If the name is a joint, compute M_H_J transform.
847
+ if name in self.graph.joint_names():
848
+
849
+ # Get the joint.
850
+ joint = self.graph.joints_dict[name]
851
+ assert joint.name == name
852
+
853
+ # Get the transform of the parent link.
854
+ M_H_L = self.transform(name=joint.parent.name)
855
+
856
+ # Rename the pose of the predecessor joint frame w.r.t. its parent link.
857
+ L_H_pre = joint.pose
858
+
859
+ # Compute the joint transform from the predecessor to the successor frame.
860
+ pre_H_J = self.pre_H_suc(
861
+ joint_type=joint.jtype,
862
+ joint_axis=joint.axis,
863
+ joint_position=self._initial_joint_positions[joint.name],
864
+ )
865
+
866
+ # Compute the M_H_J transform.
867
+ self._transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J
868
+ return self._transform_cache[name]
869
+
870
+ # If the name is a link, compute M_H_L transform.
871
+ if name in self.graph.link_names():
872
+
873
+ # Get the link.
874
+ link = self.graph.links_dict[name]
875
+
876
+ # Handle the pose between the __model__ frame and the root link.
877
+ if link.name == self.graph.root.name:
878
+ M_H_B = link.pose
879
+ return M_H_B
880
+
881
+ # Get the joint between the link and its parent.
882
+ parent_joint = self.graph.joints_connection_dict[
883
+ link.parent.name, link.name
884
+ ]
885
+
886
+ # Get the transform of the parent joint.
887
+ M_H_J = self.transform(name=parent_joint.name)
888
+
889
+ # Rename the pose of the link w.r.t. its parent joint.
890
+ J_H_L = link.pose
891
+
892
+ # Compute the M_H_L transform.
893
+ self._transform_cache[name] = M_H_J @ J_H_L
894
+ return self._transform_cache[name]
895
+
896
+ # It can only be a plain frame.
897
+ if name not in self.graph.frame_names():
898
+ raise ValueError(name)
899
+
900
+ # Get the frame.
901
+ frame = self.graph.frames_dict[name]
902
+
903
+ # Get the transform of the parent link.
904
+ M_H_L = self.transform(name=frame.parent.name)
905
+
906
+ # Rename the pose of the frame w.r.t. its parent link.
907
+ L_H_F = frame.pose
908
+
909
+ # Compute the M_H_F transform.
910
+ self._transform_cache[name] = M_H_L @ L_H_F
911
+ return self._transform_cache[name]
912
+
913
+ def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:
914
+ """
915
+ Compute the SE(3) relative transform of elements belonging to the kinematic graph.
916
+
917
+ Args:
918
+ relative_to: The name of the reference element.
919
+ name: The name of a link, a joint, or a frame.
920
+
921
+ Returns:
922
+ The 4x4 transform matrix of the element w.r.t. the desired frame.
923
+ """
924
+
925
+ import jaxsim.math
926
+
927
+ M_H_target = self.transform(name=name)
928
+ M_H_R = self.transform(name=relative_to)
929
+
930
+ # Compute the relative transform R_H_target, where R is the reference frame,
931
+ # and i the frame of the desired link|joint|frame.
932
+ return np.array(jaxsim.math.Transform.inverse(M_H_R)) @ M_H_target
933
+
934
+ @staticmethod
935
+ def pre_H_suc(
936
+ joint_type: JointType,
937
+ joint_axis: npt.NDArray,
938
+ joint_position: float | None = None,
939
+ ) -> npt.NDArray:
940
+ """
941
+ Compute the SE(3) transform from the predecessor to the successor frame.
942
+
943
+ Args:
944
+ joint_type: The type of the joint.
945
+ joint_axis: The axis of the joint.
946
+ joint_position: The position of the joint.
947
+
948
+ Returns:
949
+ The 4x4 transform matrix from the predecessor to the successor frame.
950
+ """
951
+
952
+ import jaxsim.math
953
+
954
+ return np.array(
955
+ jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)[
956
+ 0
957
+ ]
958
+ )
959
+
960
+ def find_parent_link_of_frame(self, name: str) -> str:
961
+ """
962
+ Find the parent link of a frame.
963
+
964
+ Args:
965
+ name: The name of the frame.
966
+
967
+ Returns:
968
+ The name of the parent link of the frame.
969
+ """
970
+
971
+ try:
972
+ frame = self.graph.frames_dict[name]
973
+ except KeyError as e:
974
+ raise ValueError(f"Frame '{name}' not found in the kinematic graph") from e
975
+
976
+ match frame.parent.name:
977
+ case parent_name if parent_name in self.graph.links_dict:
978
+ return parent_name
979
+
980
+ case parent_name if parent_name in self.graph.frames_dict:
981
+ return self.find_parent_link_of_frame(name=parent_name)
982
+
983
+ case _:
984
+ msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent.name}'"
985
+ raise RuntimeError(msg)