jaxsim 0.2.1.dev62__py3-none-any.whl → 0.2.1.dev70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,22 +3,12 @@ from __future__ import annotations
3
3
  import copy
4
4
  import dataclasses
5
5
  import functools
6
- from typing import (
7
- Any,
8
- Callable,
9
- Dict,
10
- Iterable,
11
- List,
12
- NamedTuple,
13
- Optional,
14
- Sequence,
15
- Tuple,
16
- Union,
17
- )
6
+ from typing import Any, Callable, Iterable, NamedTuple, Sequence
18
7
 
19
8
  import numpy as np
20
9
  import numpy.typing as npt
21
10
 
11
+ import jaxsim.utils
22
12
  from jaxsim import logging
23
13
  from jaxsim.utils import Mutability
24
14
 
@@ -30,52 +20,49 @@ class RootPose(NamedTuple):
30
20
  Represents the root pose in a kinematic graph.
31
21
 
32
22
  Attributes:
33
- root_position (npt.NDArray): A NumPy array of shape (3,) representing the root's position.
34
- root_quaternion (npt.NDArray): A NumPy array of shape (4,) representing the root's quaternion.
23
+ root_position: The 3D position of the root link of the graph.
24
+ root_quaternion:
25
+ The quaternion representing the rotation of the root link of the graph.
26
+
27
+ Note:
28
+ The root link of the kinematic graph is the base link.
35
29
  """
36
30
 
37
31
  root_position: npt.NDArray = np.zeros(3)
38
32
  root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0])
39
33
 
40
- def __eq__(self, other):
41
- return (self.root_position == other.root_position).all() and (
42
- self.root_quaternion == other.root_quaternion
43
- ).all()
34
+ def __eq__(self, other: RootPose) -> bool:
35
+
36
+ if not isinstance(other, RootPose):
37
+ return False
38
+
39
+ return np.allclose(self.root_position, other.root_position) and np.allclose(
40
+ self.root_quaternion, other.root_quaternion
41
+ )
44
42
 
45
43
 
46
44
  @dataclasses.dataclass(frozen=True)
47
45
  class KinematicGraph(Sequence[descriptions.LinkDescription]):
48
46
  """
49
- Represents a kinematic graph of links and joints.
50
-
51
- Args:
52
- root (descriptions.LinkDescription): The root link of the kinematic graph.
53
- frames (List[descriptions.LinkDescription]): A list of frame links in the graph.
54
- joints (List[descriptions.JointDescription]): A list of joint descriptions in the graph.
55
- root_pose (RootPose): The root pose of the graph.
56
- transform_cache (Dict[str, npt.NDArray]): A dictionary to cache transformation matrices.
57
- extra_info (Dict[str, Any]): Additional information associated with the graph.
47
+ Class storing a kinematic graph having links as nodes and joints as edges.
58
48
 
59
49
  Attributes:
60
- links_dict (Dict[str, descriptions.LinkDescription]): A dictionary mapping link names to link descriptions.
61
- frames_dict (Dict[str, descriptions.LinkDescription]): A dictionary mapping frame names to frame link descriptions.
62
- joints_dict (Dict[str, descriptions.JointDescription]): A dictionary mapping joint names to joint descriptions.
63
- joints_connection_dict (Dict[Tuple[str, str], descriptions.JointDescription]): A dictionary mapping pairs of parent and child link names to joint descriptions.
50
+ root: The root node of the kinematic graph.
51
+ frames: List of frames rigidly attached to the graph nodes.
52
+ joints: List of joints connecting the graph nodes.
53
+ root_pose: The pose of the kinematic graph's root.
64
54
  """
65
55
 
66
56
  root: descriptions.LinkDescription
67
- frames: List[descriptions.LinkDescription] = dataclasses.field(default_factory=list)
68
- joints: List[descriptions.JointDescription] = dataclasses.field(
57
+ frames: list[descriptions.LinkDescription] = dataclasses.field(default_factory=list)
58
+ joints: list[descriptions.JointDescription] = dataclasses.field(
69
59
  default_factory=list
70
60
  )
71
61
 
72
- root_pose: RootPose = dataclasses.field(default_factory=RootPose)
73
-
74
- transform_cache: Dict[str, npt.NDArray] = dataclasses.field(
75
- repr=False, init=False, compare=False, default_factory=dict
76
- )
62
+ root_pose: RootPose = dataclasses.field(default_factory=lambda: RootPose())
77
63
 
78
- extra_info: Dict[str, Any] = dataclasses.field(
64
+ # Private attribute storing optional additional info.
65
+ _extra_info: dict[str, Any] = dataclasses.field(
79
66
  repr=False, compare=False, default_factory=dict
80
67
  )
81
68
 
@@ -86,142 +73,189 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
86
73
  )
87
74
 
88
75
  @functools.cached_property
89
- def links_dict(self) -> Dict[str, descriptions.LinkDescription]:
76
+ def links_dict(self) -> dict[str, descriptions.LinkDescription]:
90
77
  return {l.name: l for l in iter(self)}
91
78
 
92
79
  @functools.cached_property
93
- def frames_dict(self) -> Dict[str, descriptions.LinkDescription]:
80
+ def frames_dict(self) -> dict[str, descriptions.LinkDescription]:
94
81
  return {f.name: f for f in self.frames}
95
82
 
96
83
  @functools.cached_property
97
- def joints_dict(self) -> Dict[str, descriptions.JointDescription]:
84
+ def joints_dict(self) -> dict[str, descriptions.JointDescription]:
98
85
  return {j.name: j for j in self.joints}
99
86
 
100
87
  @functools.cached_property
101
88
  def joints_connection_dict(
102
89
  self,
103
- ) -> Dict[Tuple[str, str], descriptions.JointDescription]:
90
+ ) -> dict[tuple[str, str], descriptions.JointDescription]:
104
91
  return {(j.parent.name, j.child.name): j for j in self.joints}
105
92
 
106
- def __post_init__(self):
107
- """
108
- Post-initialization method to set various properties and validate the kinematic graph.
109
- """
110
- # Assign the link index traversing the graph with BFS.
111
- # Here we assume the model is fixed-base, therefore the base link will
112
- # have index 0. We will deal with the floating base in a later stage,
113
- # when this Model object is converted to the physics model.
93
+ def __post_init__(self) -> None:
94
+
95
+ # Assign the link index by traversing the graph with BFS.
96
+ # Here we assume the model being fixed-base, therefore the base link will
97
+ # have index 0. We will deal with the floating base in a later stage.
114
98
  for index, link in enumerate(self):
115
99
  link.mutable(validate=False).index = index
116
100
 
117
- # Order frames with their name
101
+ # Get the names of the links and frames.
102
+ link_names = [l.name for l in self]
103
+ frame_names = [f.name for f in self.frames]
104
+
105
+ # Make sure that they are unique.
106
+ assert len(link_names) == len(set(link_names))
107
+ assert len(frame_names) == len(set(frame_names))
108
+ assert set(link_names).isdisjoint(set(frame_names))
109
+
110
+ # Order frames with their name.
118
111
  super().__setattr__("frames", sorted(self.frames, key=lambda f: f.name))
119
112
 
120
113
  # Assign the frame index following the name-based indexing.
121
- # Also here, we assume the model is fixed-base, therefore the first frame will
122
- # have last_link_idx + 1. These frames are not part of the physics model.
114
+ # We assume the model being fixed-base, therefore the first frame will
115
+ # have last_link_idx + 1.
123
116
  for index, frame in enumerate(self.frames):
124
117
  with frame.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
125
118
  frame.index = int(index + len(self.link_names()))
126
119
 
127
- # Number joints so that their index matches their child link index
120
+ # Number joints so that their index matches their child link index.
121
+ # Therefore, the first joint has index 1.
128
122
  links_dict = {l.name: l for l in iter(self)}
129
123
  for joint in self.joints:
130
124
  with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
131
125
  joint.index = links_dict[joint.child.name].index
132
126
 
133
- # Check that joint indices are unique
127
+ # Check that joint indices are unique.
134
128
  assert len([j.index for j in self.joints]) == len(
135
129
  {j.index for j in self.joints}
136
130
  )
137
131
 
138
- # Order joints with their indices
132
+ # Order joints with their indices.
139
133
  super().__setattr__("joints", sorted(self.joints, key=lambda j: j.index))
140
134
 
141
135
  @staticmethod
142
136
  def build_from(
143
- links: List[descriptions.LinkDescription],
144
- joints: List[descriptions.JointDescription],
137
+ links: list[descriptions.LinkDescription],
138
+ joints: list[descriptions.JointDescription],
139
+ frames: list[descriptions.LinkDescription] | None = None,
145
140
  root_link_name: str | None = None,
146
141
  root_pose: RootPose = RootPose(),
147
- ) -> "KinematicGraph":
142
+ ) -> KinematicGraph:
148
143
  """
149
- Build a KinematicGraph from a list of links and joints.
144
+ Build a KinematicGraph from links, joints, and frames.
150
145
 
151
146
  Args:
152
- links (List[descriptions.LinkDescription]): A list of link descriptions.
153
- joints (List[descriptions.JointDescription]): A list of joint descriptions.
154
- root_link_name (str, optional): The name of the root link. If not provided, it's assumed to be the first link's name.
155
- root_pose (RootPose, optional): The root pose of the kinematic graph.
147
+ links: A list of link descriptions.
148
+ joints: A list of joint descriptions.
149
+ frames: A list of frame descriptions.
150
+ root_link_name:
151
+ The name of the root link. If not provided, it's assumed to be the
152
+ first link's name.
153
+ root_pose: The root pose of the kinematic graph.
156
154
 
157
155
  Returns:
158
- KinematicGraph: The constructed kinematic graph.
156
+ The resulting kinematic graph.
159
157
  """
158
+
159
+ # Consider the first link as the root link if not provided.
160
160
  if root_link_name is None:
161
161
  root_link_name = links[0].name
162
+ logging.debug(msg=f"Assuming '{root_link_name}' as the root link")
162
163
 
163
164
  # Couple links and joints and create the graph of links.
164
165
  # Note that the pose of the frames is not updated; it's the caller's
165
166
  # responsibility to update their pose if they want to use them.
166
- graph_root_node, graph_joints, graph_frames, unconnected_joints = (
167
- KinematicGraph.create_graph(
168
- links=links, joints=joints, root_link_name=root_link_name
169
- )
167
+ (
168
+ graph_root_node,
169
+ graph_joints,
170
+ graph_frames,
171
+ unconnected_links,
172
+ unconnected_joints,
173
+ unconnected_frames,
174
+ ) = KinematicGraph._create_graph(
175
+ links=links, joints=joints, root_link_name=root_link_name, frames=frames
170
176
  )
171
177
 
172
- for frame in graph_frames:
173
- logging.warning(msg=f"Ignoring unconnected link / frame: '{frame.name}'")
178
+ for link in unconnected_links:
179
+ logging.warning(msg=f"Ignoring unconnected link: '{link.name}'")
174
180
 
175
181
  for joint in unconnected_joints:
176
182
  logging.warning(msg=f"Ignoring unconnected joint: '{joint.name}'")
177
183
 
184
+ for frame in unconnected_frames:
185
+ logging.warning(msg=f"Ignoring unconnected frame: '{frame.name}'")
186
+
178
187
  return KinematicGraph(
179
188
  root=graph_root_node,
180
189
  joints=graph_joints,
181
- frames=[],
190
+ frames=graph_frames,
182
191
  root_pose=root_pose,
183
192
  _joints_removed=unconnected_joints,
184
193
  )
185
194
 
186
195
  @staticmethod
187
- def create_graph(
188
- links: List[descriptions.LinkDescription],
189
- joints: List[descriptions.JointDescription],
196
+ def _create_graph(
197
+ links: list[descriptions.LinkDescription],
198
+ joints: list[descriptions.JointDescription],
190
199
  root_link_name: str,
191
- ) -> Tuple[
200
+ frames: list[descriptions.LinkDescription] | None = None,
201
+ ) -> tuple[
192
202
  descriptions.LinkDescription,
193
- List[descriptions.JointDescription],
194
- List[descriptions.LinkDescription],
195
203
  list[descriptions.JointDescription],
204
+ list[descriptions.LinkDescription],
205
+ list[descriptions.LinkDescription],
206
+ list[descriptions.JointDescription],
207
+ list[descriptions.LinkDescription],
196
208
  ]:
197
209
  """
198
- Create a kinematic graph from the lists of parsed links and joints.
210
+ Low-level creator of kinematic graph components.
199
211
 
200
212
  Args:
201
- links (List[descriptions.LinkDescription]): A list of link descriptions.
202
- joints (List[descriptions.JointDescription]): A list of joint descriptions.
203
- root_link_name (str): The name of the root link.
213
+ links: A list of parsed link descriptions.
214
+ joints: A list of parsed joint descriptions.
215
+ root_link_name: The name of the root link used as root node of the graph.
216
+ frames: A list of parsed frame descriptions.
204
217
 
205
218
  Returns:
206
- A tuple containing the root node with the full kinematic graph as child nodes,
207
- the list of joints associated to graph nodes, the list of frames rigidly
208
- attached to graph nodes, and the list of joints not part of the graph.
219
+ A tuple containing the root node of the graph (defining the entire kinematic
220
+ tree by iterating on its child nodes), the list of joints representing the
221
+ actual graph edges, the list of frames rigidly attached to the graph nodes,
222
+ the list of unconnected links, the list of unconnected joints, and the list
223
+ of unconnected frames.
209
224
  """
210
225
 
211
- # Create a dict that maps link name to the link, for easy retrieval
212
- links_dict: Dict[str, descriptions.LinkDescription] = {
226
+ # Create a dictionary that maps the link name to the link, for easy retrieval.
227
+ links_dict: dict[str, descriptions.LinkDescription] = {
213
228
  l.name: l.mutable(validate=False) for l in links
214
229
  }
215
230
 
231
+ # Create an empty list of frames if not provided.
232
+ frames = frames if frames is not None else []
233
+
234
+ # Create a dictionary that maps the frame name to the frame, for easy retrieval.
235
+ frames_dict = {frame.name: frame for frame in frames}
236
+
237
+ # Check that our parser correctly resolved the frame's parent to be a link.
238
+ for frame in frames:
239
+ assert frame.parent.name != "", frame
240
+ assert frame.parent.name is not None, frame
241
+ assert frame.parent.name != "__model__", frame
242
+ assert frame.parent.name not in frames_dict, frame
243
+
244
+ # ===========================================================
245
+ # Populate the kinematic graph with links, joints, and frames
246
+ # ===========================================================
247
+
248
+ # Check the existence of the root link.
216
249
  if root_link_name not in links_dict:
217
250
  raise ValueError(root_link_name)
218
251
 
219
- # Reset the connections of the root link
252
+ # Reset the connections of the root link.
220
253
  for link in links_dict.values():
221
254
  link.children = []
222
255
 
223
- # Couple links and joints creating the final kinematic graph
256
+ # Couple links and joints creating the kinematic graph.
224
257
  for joint in joints:
258
+
225
259
  # Get the parent and child links of the joint
226
260
  parent_link = links_dict[joint.parent.name]
227
261
  child_link = links_dict[joint.child.name]
@@ -229,48 +263,81 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
229
263
  assert child_link.name == joint.child.name
230
264
  assert parent_link.name == joint.parent.name
231
265
 
232
- # Assign link parent
266
+ # Assign link's parent.
233
267
  child_link.parent = parent_link
234
268
 
235
- # Assign link children and make sure they are unique
269
+ # Assign link's children and make sure they are unique.
236
270
  if child_link.name not in {l.name for l in parent_link.children}:
237
271
  parent_link.children.append(child_link)
238
272
 
239
- # Collect all the links of the kinematic graph
273
+ # Collect all the links of the kinematic graph.
240
274
  all_links_in_graph = list(
241
275
  KinematicGraph.breadth_first_search(root=links_dict[root_link_name])
242
276
  )
277
+
278
+ # Get the names of all links in the kinematic graph.
243
279
  all_link_names_in_graph = [l.name for l in all_links_in_graph]
244
280
 
245
- # Collect all the joints not part of the kinematic graph
246
- removed_joints = [
247
- j
248
- for j in joints
249
- if not {j.parent.name, j.child.name}.issubset(all_link_names_in_graph)
281
+ # Collect all the joints of the kinematic graph.
282
+ all_joints_in_graph = [
283
+ joint
284
+ for joint in joints
285
+ if joint.parent.name in all_link_names_in_graph
286
+ and joint.child.name in all_link_names_in_graph
250
287
  ]
251
288
 
252
- for removed_joint in removed_joints:
253
- msg = "Joint '{}' has been removed for the graph because unconnected"
254
- logging.info(msg=msg.format(removed_joint.name))
289
+ # Get the names of all joints in the kinematic graph.
290
+ all_joint_names_in_graph = [j.name for j in all_joints_in_graph]
255
291
 
256
- # Store as frames all the links that are not part of the kinematic graph
257
- frames = list(set(links) - set(all_links_in_graph))
292
+ # Collect all the frames of the kinematic graph.
293
+ # Note: our parser ensures that the parent of a frame is not another frame.
294
+ all_frames_in_graph = [
295
+ frame for frame in frames if frame.parent.name in all_link_names_in_graph
296
+ ]
258
297
 
259
- # Update the frames. In particular, reset their children. The other properties
260
- # are kept as they are, and it's caller responsibility to update them if needed.
261
- for frame in frames:
262
- frame.children = []
263
- msg = f"Link '{frame.name}' became a frame"
264
- logging.info(msg=msg)
298
+ # Get the names of all frames in the kinematic graph.
299
+ all_frames_names_in_graph = [f.name for f in all_frames_in_graph]
300
+
301
+ # ============================
302
+ # Collect unconnected elements
303
+ # ============================
304
+
305
+ # Collect all the joints that are not part of the kinematic graph.
306
+ removed_joints = [j for j in joints if j.name not in all_joint_names_in_graph]
307
+
308
+ for joint in removed_joints:
309
+ msg = "Joint '{}' is unconnected and it will be removed"
310
+ logging.debug(msg=msg.format(joint.name))
311
+
312
+ # Collect all the links that are not part of the kinematic graph.
313
+ unconnected_links = [l for l in links if l.name not in all_link_names_in_graph]
314
+
315
+ # Update the unconnected links by removing their children. The other properties
316
+ # are left untouched, it's caller responsibility to post-process them if needed.
317
+ for link in unconnected_links:
318
+ link.children = []
319
+ msg = "Link '{}' won't be part of the kinematic graph because unconnected"
320
+ logging.debug(msg=msg.format(link.name))
321
+
322
+ # Collect all the frames that are not part of the kinematic graph.
323
+ unconnected_frames = [
324
+ f for f in frames if f.name not in all_frames_names_in_graph
325
+ ]
326
+
327
+ for frame in unconnected_frames:
328
+ msg = "Frame '{}' won't be part of the kinematic graph because unconnected"
329
+ logging.debug(msg=msg.format(frame.name))
265
330
 
266
331
  return (
267
332
  links_dict[root_link_name].mutable(mutable=False),
268
333
  list(set(joints) - set(removed_joints)),
269
- frames,
334
+ all_frames_in_graph,
335
+ unconnected_links,
270
336
  list(set(removed_joints)),
337
+ unconnected_frames,
271
338
  )
272
339
 
273
- def reduce(self, considered_joints: List[str]) -> KinematicGraph:
340
+ def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:
274
341
  """
275
342
  Reduce the kinematic graph by removing unspecified joints.
276
343
 
@@ -366,8 +433,8 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
366
433
  )
367
434
 
368
435
  # Pop the original two links from the dictionary...
369
- links_dict.pop(link_to_remove.name)
370
- links_dict.pop(parent_of_link_to_remove.name)
436
+ _ = links_dict.pop(link_to_remove.name)
437
+ _ = links_dict.pop(parent_of_link_to_remove.name)
371
438
 
372
439
  # ... and insert the lumped link (having the same name of the parent)
373
440
  links_dict[lumped_link.name] = lumped_link
@@ -377,11 +444,13 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
377
444
  links_dict[link_to_remove.name] = lumped_link
378
445
 
379
446
  # As a consequence of the back-insertion, we need to adjust the resulting
380
- # lumped link of links that have been removed previously
447
+ # lumped link of links that have been removed previously.
448
+ # Note: in the dictionary, only items whose key is not matching value.name
449
+ # are links that have been removed.
381
450
  for previously_removed_link_name in {
382
- k
383
- for k, v in links_dict.items()
384
- if k != v.name and v.name == link_to_remove.name
451
+ link_name
452
+ for link_name, link in links_dict.items()
453
+ if link_name != link.name and link.name == link_to_remove.name
385
454
  }:
386
455
  links_dict[previously_removed_link_name] = lumped_link
387
456
 
@@ -427,19 +496,31 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
427
496
 
428
497
  # Create the reduced graph data. We pass the full list of links so that those
429
498
  # that are not part of the graph will be returned as frames.
430
- reduced_root_node, reduced_joints, reduced_frames, unconnected_joints = (
431
- KinematicGraph.create_graph(
432
- links=list(full_graph_links_dict.values()),
433
- joints=[joints_dict[joint_name] for joint_name in considered_joints],
434
- root_link_name=full_graph.root.name,
435
- )
499
+ (
500
+ reduced_root_node,
501
+ reduced_joints,
502
+ reduced_frames,
503
+ unconnected_links,
504
+ unconnected_joints,
505
+ unconnected_frames,
506
+ ) = KinematicGraph._create_graph(
507
+ links=list(full_graph_links_dict.values()),
508
+ joints=[joints_dict[joint_name] for joint_name in considered_joints],
509
+ root_link_name=full_graph.root.name,
510
+ )
511
+
512
+ assert set(f.name for f in self.frames).isdisjoint(
513
+ set(f.name for f in unconnected_frames + reduced_frames)
436
514
  )
437
515
 
438
- # Create the reduced graph
516
+ for link in unconnected_links:
517
+ logging.debug(msg=f"Link '{link.name}' is unconnected and became a frame")
518
+
519
+ # Create the reduced graph.
439
520
  reduced_graph = KinematicGraph(
440
521
  root=reduced_root_node,
441
522
  joints=reduced_joints,
442
- frames=self.frames + reduced_frames,
523
+ frames=self.frames + unconnected_links + reduced_frames,
443
524
  root_pose=full_graph.root_pose,
444
525
  _joints_removed=(
445
526
  self._joints_removed
@@ -452,58 +533,77 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
452
533
  # 4. Resolve the pose of the frames wrt their reduced graph parent
453
534
  # ================================================================
454
535
 
455
- # Update frames properties using the transforms from the full graph
456
- for frame in reduced_graph.frames:
457
- # Get the link in which the removed link was lumped into
458
- new_parent_link = links_dict[frame.name]
536
+ # Build a new object to compute FK on the reduced graph.
537
+ fk_reduced = KinematicGraphTransforms(graph=reduced_graph)
459
538
 
460
- msg = f"New parent of frame '{frame.name}' is '{new_parent_link.name}'"
461
- logging.info(msg)
539
+ # We need to adjust the pose of the frames since their parent link
540
+ # could have been removed by the reduction process.
541
+ for frame in reduced_graph.frames:
462
542
 
463
- # Update the connection of the frame
464
- frame.parent = new_parent_link
465
- frame.pose = fk.relative_transform(
466
- relative_to=new_parent_link.name, name=frame.name
543
+ # Always find the real parent link of the frame
544
+ name_of_new_parent_link = fk_reduced.find_parent_link_of_frame(
545
+ name=frame.name
467
546
  )
547
+ assert name_of_new_parent_link in reduced_graph, name_of_new_parent_link
548
+
549
+ # Notify the user if the parent link has changed.
550
+ if name_of_new_parent_link != frame.parent.name:
551
+ msg = "New parent of frame '{}' is '{}'"
552
+ logging.debug(msg=msg.format(frame.name, name_of_new_parent_link))
553
+
554
+ # Always recompute the pose of the frame, and set zero inertial params.
555
+ with frame.mutable_context(jaxsim.utils.Mutability.MUTABLE_NO_VALIDATION):
468
556
 
469
- # Update frame data
470
- frame.mass = 0.0
471
- frame.inertia = np.zeros_like(frame.inertia)
557
+ # Update kinematic parameters of the frame.
558
+ # Note that here we compute the transform using the FK object of the
559
+ # full model, so that we are sure that the kinematic is not altered.
560
+ frame.pose = fk.relative_transform(
561
+ relative_to=name_of_new_parent_link, name=frame.name
562
+ )
563
+
564
+ # Update the parent link such that the pose is expressed in its frame.
565
+ frame.parent = reduced_graph.links_dict[name_of_new_parent_link]
472
566
 
473
- # Return the reduced graph
567
+ # Update dynamic parameters of the frame.
568
+ frame.mass = 0.0
569
+ frame.inertia = np.zeros_like(frame.inertia)
570
+
571
+ # Return the reduced graph.
474
572
  return reduced_graph
475
573
 
476
- def link_names(self) -> List[str]:
574
+ def link_names(self) -> list[str]:
477
575
  """
478
- Get the names of all links in the kinematic graph.
576
+ Get the names of all links in the kinematic graph (i.e. the nodes).
479
577
 
480
578
  Returns:
481
- List[str]: A list of link names.
579
+ The list of link names.
482
580
  """
483
581
  return list(self.links_dict.keys())
484
582
 
485
- def joint_names(self) -> List[str]:
583
+ def joint_names(self) -> list[str]:
486
584
  """
487
- Get the names of all joints in the kinematic graph.
585
+ Get the names of all joints in the kinematic graph (i.e. the edges).
488
586
 
489
587
  Returns:
490
- List[str]: A list of joint names.
588
+ The list of joint names.
491
589
  """
492
590
  return list(self.joints_dict.keys())
493
591
 
494
- def frame_names(self) -> List[str]:
592
+ def frame_names(self) -> list[str]:
495
593
  """
496
594
  Get the names of all frames in the kinematic graph.
497
595
 
498
596
  Returns:
499
- List[str]: A list of frame names.
597
+ The list of frame names.
500
598
  """
599
+
501
600
  return list(self.frames_dict.keys())
502
601
 
503
602
  def print_tree(self) -> None:
504
603
  """
505
604
  Print the tree structure of the kinematic graph.
506
605
  """
606
+
507
607
  import pptree
508
608
 
509
609
  root_node = self.root
@@ -518,21 +618,23 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
518
618
  @staticmethod
519
619
  def breadth_first_search(
520
620
  root: descriptions.LinkDescription,
521
- sort_children: Optional[Callable[[Any], Any]] = lambda link: link.name,
621
+ sort_children: Callable[[Any], Any] | None = lambda link: link.name,
522
622
  ) -> Iterable[descriptions.LinkDescription]:
523
623
  """
524
624
  Perform a breadth-first search (BFS) traversal of the kinematic graph.
525
625
 
526
626
  Args:
527
- root (descriptions.LinkDescription): The root link for BFS.
528
- sort_children (Optional[Callable[[Any], Any]]): A function to sort children of a node.
627
+ root: The root link for BFS.
628
+ sort_children: A function to sort children of a node.
529
629
 
530
630
  Yields:
531
- Iterable[descriptions.LinkDescription]: An iterable of link descriptions.
631
+ The links in the kinematic graph in BFS order.
532
632
  """
633
+
634
+ # Initialize the queue with the root node.
533
635
  queue = [root]
534
636
 
535
- # We assume that nodes have unique names, and mark a link as visited using
637
+ # We assume that nodes have unique names and mark a link as visited using
536
638
  # its name. This speeds up considerably object comparison.
537
639
  visited = []
538
640
  visited.append(root.name)
@@ -540,11 +642,14 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
540
642
  yield root
541
643
 
542
644
  while len(queue) > 0:
645
+
646
+ # Extract the first element of the queue.
543
647
  l = queue.pop(0)
544
648
 
545
649
  # Note: sorting the links with their name so that the order of children
546
- # insertion does not matter when assigning the link index
650
+ # insertion does not matter when assigning the link index.
547
651
  for child in sorted(l.children, key=sort_children):
652
+
548
653
  if child.name in visited:
549
654
  continue
550
655
 
@@ -566,7 +671,7 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
566
671
  def __len__(self) -> int:
567
672
  return len(list(iter(self)))
568
673
 
569
- def __contains__(self, item: Union[str, descriptions.LinkDescription]) -> bool:
674
+ def __contains__(self, item: str | descriptions.LinkDescription) -> bool:
570
675
  if isinstance(item, str):
571
676
  return item in self.link_names()
572
677
 
@@ -575,7 +680,7 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
575
680
 
576
681
  raise TypeError(type(item).__name__)
577
682
 
578
- def __getitem__(self, key: Union[int, str]) -> descriptions.LinkDescription:
683
+ def __getitem__(self, key: int | str) -> descriptions.LinkDescription:
579
684
  if isinstance(key, str):
580
685
  if key not in self.link_names():
581
686
  raise KeyError(key)
@@ -765,7 +870,7 @@ class KinematicGraphTransforms:
765
870
  @staticmethod
766
871
  def pre_H_suc(
767
872
  joint_type: descriptions.JointType,
768
- joint_axis: descriptions.JointGenericAxis,
873
+ joint_axis: npt.NDArray,
769
874
  joint_position: float | None = None,
770
875
  ) -> npt.NDArray:
771
876
 
@@ -776,3 +881,30 @@ class KinematicGraphTransforms:
776
881
  0
777
882
  ]
778
883
  )
884
+
885
+ def find_parent_link_of_frame(self, name: str) -> str:
886
+ """
887
+ Find the parent link of a frame.
888
+
889
+ Args:
890
+ name: The name of the frame.
891
+
892
+ Returns:
893
+ The name of the parent link of the frame.
894
+ """
895
+
896
+ try:
897
+ frame = self.graph.frames_dict[name]
898
+ except KeyError as e:
899
+ raise ValueError(f"Frame '{name}' not found in the kinematic graph") from e
900
+
901
+ match frame.parent.name:
902
+ case parent_name if parent_name in self.graph.links_dict:
903
+ return parent_name
904
+
905
+ case parent_name if parent_name in self.graph.frames_dict:
906
+ return self.find_parent_link_of_frame(name=parent_name)
907
+
908
+ case _:
909
+ msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent.name}'"
910
+ raise RuntimeError(msg)