jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,258 +1,397 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import copy
|
2
4
|
import dataclasses
|
3
5
|
import functools
|
4
|
-
from
|
5
|
-
|
6
|
-
Callable,
|
7
|
-
Dict,
|
8
|
-
Iterable,
|
9
|
-
List,
|
10
|
-
NamedTuple,
|
11
|
-
Optional,
|
12
|
-
Tuple,
|
13
|
-
Union,
|
14
|
-
)
|
6
|
+
from collections.abc import Callable, Iterable, Iterator, Sequence
|
7
|
+
from typing import Any
|
15
8
|
|
16
9
|
import numpy as np
|
17
10
|
import numpy.typing as npt
|
18
11
|
|
12
|
+
import jaxsim.utils
|
19
13
|
from jaxsim import logging
|
20
14
|
from jaxsim.utils import Mutability
|
21
15
|
|
22
|
-
from . import
|
16
|
+
from .descriptions.joint import JointDescription, JointType
|
17
|
+
from .descriptions.link import LinkDescription
|
23
18
|
|
24
19
|
|
25
|
-
|
20
|
+
@dataclasses.dataclass
|
21
|
+
class RootPose:
|
26
22
|
"""
|
27
23
|
Represents the root pose in a kinematic graph.
|
28
24
|
|
29
25
|
Attributes:
|
30
|
-
root_position
|
31
|
-
root_quaternion
|
26
|
+
root_position: The 3D position of the root link of the graph.
|
27
|
+
root_quaternion:
|
28
|
+
The quaternion representing the rotation of the root link of the graph.
|
29
|
+
|
30
|
+
Note:
|
31
|
+
The root link of the kinematic graph is the base link.
|
32
32
|
"""
|
33
33
|
|
34
|
-
root_position: npt.NDArray = np.zeros(3)
|
35
|
-
|
34
|
+
root_position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3))
|
35
|
+
|
36
|
+
root_quaternion: npt.NDArray = dataclasses.field(
|
37
|
+
default_factory=lambda: np.array([1.0, 0, 0, 0])
|
38
|
+
)
|
39
|
+
|
40
|
+
def __hash__(self) -> int:
|
41
|
+
|
42
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
43
|
+
|
44
|
+
return hash(
|
45
|
+
(
|
46
|
+
HashedNumpyArray.hash_of_array(self.root_position),
|
47
|
+
HashedNumpyArray.hash_of_array(self.root_quaternion),
|
48
|
+
)
|
49
|
+
)
|
50
|
+
|
51
|
+
def __eq__(self, other: RootPose) -> bool:
|
52
|
+
|
53
|
+
if not isinstance(other, RootPose):
|
54
|
+
return False
|
55
|
+
|
56
|
+
if not np.allclose(self.root_position, other.root_position):
|
57
|
+
return False
|
58
|
+
|
59
|
+
if not np.allclose(self.root_quaternion, other.root_quaternion):
|
60
|
+
return False
|
61
|
+
|
62
|
+
return True
|
36
63
|
|
37
64
|
|
38
65
|
@dataclasses.dataclass(frozen=True)
|
39
|
-
class KinematicGraph:
|
66
|
+
class KinematicGraph(Sequence[LinkDescription]):
|
40
67
|
"""
|
41
|
-
|
42
|
-
|
43
|
-
Args:
|
44
|
-
root (descriptions.LinkDescription): The root link of the kinematic graph.
|
45
|
-
frames (List[descriptions.LinkDescription]): A list of frame links in the graph.
|
46
|
-
joints (List[descriptions.JointDescription]): A list of joint descriptions in the graph.
|
47
|
-
root_pose (RootPose): The root pose of the graph.
|
48
|
-
transform_cache (Dict[str, npt.NDArray]): A dictionary to cache transformation matrices.
|
49
|
-
extra_info (Dict[str, Any]): Additional information associated with the graph.
|
68
|
+
Class storing a kinematic graph having links as nodes and joints as edges.
|
50
69
|
|
51
70
|
Attributes:
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
71
|
+
root: The root node of the kinematic graph.
|
72
|
+
frames: List of frames rigidly attached to the graph nodes.
|
73
|
+
joints: List of joints connecting the graph nodes.
|
74
|
+
root_pose: The pose of the kinematic graph's root.
|
56
75
|
"""
|
57
76
|
|
58
|
-
root:
|
59
|
-
frames:
|
60
|
-
|
61
|
-
|
77
|
+
root: LinkDescription
|
78
|
+
frames: list[LinkDescription] = dataclasses.field(
|
79
|
+
default_factory=list, hash=False, compare=False
|
80
|
+
)
|
81
|
+
joints: list[JointDescription] = dataclasses.field(
|
82
|
+
default_factory=list, hash=False, compare=False
|
62
83
|
)
|
63
84
|
|
64
85
|
root_pose: RootPose = dataclasses.field(default_factory=RootPose)
|
65
86
|
|
66
|
-
|
67
|
-
|
87
|
+
# Private attribute storing optional additional info.
|
88
|
+
_extra_info: dict[str, Any] = dataclasses.field(
|
89
|
+
default_factory=dict, repr=False, hash=False, compare=False
|
68
90
|
)
|
69
91
|
|
70
|
-
|
71
|
-
|
92
|
+
# Private attribute storing the unconnected joints from the parsed model and
|
93
|
+
# the joints removed after model reduction.
|
94
|
+
_joints_removed: list[JointDescription] = dataclasses.field(
|
95
|
+
default_factory=list, repr=False, hash=False, compare=False
|
72
96
|
)
|
73
97
|
|
74
98
|
@functools.cached_property
|
75
|
-
def links_dict(self) ->
|
99
|
+
def links_dict(self) -> dict[str, LinkDescription]:
|
100
|
+
"""
|
101
|
+
Get a dictionary of links indexed by their name.
|
102
|
+
"""
|
76
103
|
return {l.name: l for l in iter(self)}
|
77
104
|
|
78
105
|
@functools.cached_property
|
79
|
-
def frames_dict(self) ->
|
106
|
+
def frames_dict(self) -> dict[str, LinkDescription]:
|
107
|
+
"""
|
108
|
+
Get a dictionary of frames indexed by their name.
|
109
|
+
"""
|
80
110
|
return {f.name: f for f in self.frames}
|
81
111
|
|
82
112
|
@functools.cached_property
|
83
|
-
def joints_dict(self) ->
|
113
|
+
def joints_dict(self) -> dict[str, JointDescription]:
|
114
|
+
"""
|
115
|
+
Get a dictionary of joints indexed by their name.
|
116
|
+
"""
|
84
117
|
return {j.name: j for j in self.joints}
|
85
118
|
|
86
119
|
@functools.cached_property
|
87
120
|
def joints_connection_dict(
|
88
121
|
self,
|
89
|
-
) ->
|
90
|
-
return {(j.parent.name, j.child.name): j for j in self.joints}
|
91
|
-
|
92
|
-
def __post_init__(self):
|
122
|
+
) -> dict[tuple[str, str], JointDescription]:
|
93
123
|
"""
|
94
|
-
|
124
|
+
Get a dictionary of joints indexed by the tuple (parent, child) link names.
|
95
125
|
"""
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
126
|
+
return {(j.parent.name, j.child.name): j for j in self.joints}
|
127
|
+
|
128
|
+
def __post_init__(self) -> None:
|
129
|
+
|
130
|
+
# Assign the link index by traversing the graph with BFS.
|
131
|
+
# Here we assume the model being fixed-base, therefore the base link will
|
132
|
+
# have index 0. We will deal with the floating base in a later stage.
|
100
133
|
for index, link in enumerate(self):
|
101
134
|
link.mutable(validate=False).index = index
|
102
135
|
|
103
|
-
#
|
136
|
+
# Get the names of the links, frames, and joints.
|
137
|
+
link_names = [l.name for l in self]
|
138
|
+
frame_names = [f.name for f in self.frames]
|
139
|
+
joint_names = [j.name for j in self.joints]
|
140
|
+
|
141
|
+
# Make sure that they are unique.
|
142
|
+
assert len(link_names) == len(set(link_names))
|
143
|
+
assert len(frame_names) == len(set(frame_names))
|
144
|
+
assert len(joint_names) == len(set(joint_names))
|
145
|
+
assert set(link_names).isdisjoint(set(frame_names))
|
146
|
+
assert set(link_names).isdisjoint(set(joint_names))
|
147
|
+
|
148
|
+
# Order frames with their name.
|
104
149
|
super().__setattr__("frames", sorted(self.frames, key=lambda f: f.name))
|
105
150
|
|
106
151
|
# Assign the frame index following the name-based indexing.
|
107
|
-
#
|
108
|
-
# have last_link_idx + 1.
|
152
|
+
# We assume the model being fixed-base, therefore the first frame will
|
153
|
+
# have last_link_idx + 1.
|
109
154
|
for index, frame in enumerate(self.frames):
|
110
|
-
frame.
|
155
|
+
with frame.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
156
|
+
frame.index = int(index + len(self.link_names()))
|
111
157
|
|
112
|
-
# Number joints so that their index matches their child link index
|
158
|
+
# Number joints so that their index matches their child link index.
|
159
|
+
# Therefore, the first joint has index 1.
|
113
160
|
links_dict = {l.name: l for l in iter(self)}
|
114
161
|
for joint in self.joints:
|
115
162
|
with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
116
163
|
joint.index = links_dict[joint.child.name].index
|
117
164
|
|
118
|
-
# Check that joint indices are unique
|
165
|
+
# Check that joint indices are unique.
|
119
166
|
assert len([j.index for j in self.joints]) == len(
|
120
167
|
{j.index for j in self.joints}
|
121
168
|
)
|
122
169
|
|
123
|
-
# Order joints with their indices
|
170
|
+
# Order joints with their indices.
|
124
171
|
super().__setattr__("joints", sorted(self.joints, key=lambda j: j.index))
|
125
172
|
|
126
173
|
@staticmethod
|
127
174
|
def build_from(
|
128
|
-
links:
|
129
|
-
joints:
|
175
|
+
links: list[LinkDescription],
|
176
|
+
joints: list[JointDescription],
|
177
|
+
frames: list[LinkDescription] | None = None,
|
130
178
|
root_link_name: str | None = None,
|
131
179
|
root_pose: RootPose = RootPose(),
|
132
|
-
) ->
|
180
|
+
) -> KinematicGraph:
|
133
181
|
"""
|
134
|
-
Build a KinematicGraph from
|
182
|
+
Build a KinematicGraph from links, joints, and frames.
|
135
183
|
|
136
184
|
Args:
|
137
|
-
links
|
138
|
-
joints
|
139
|
-
|
140
|
-
|
185
|
+
links: A list of link descriptions.
|
186
|
+
joints: A list of joint descriptions.
|
187
|
+
frames: A list of frame descriptions.
|
188
|
+
root_link_name:
|
189
|
+
The name of the root link. If not provided, it's assumed to be the
|
190
|
+
first link's name.
|
191
|
+
root_pose: The root pose of the kinematic graph.
|
141
192
|
|
142
193
|
Returns:
|
143
|
-
|
194
|
+
The resulting kinematic graph.
|
144
195
|
"""
|
196
|
+
|
197
|
+
# Consider the first link as the root link if not provided.
|
145
198
|
if root_link_name is None:
|
146
199
|
root_link_name = links[0].name
|
200
|
+
logging.debug(msg=f"Assuming '{root_link_name}' as the root link")
|
147
201
|
|
148
202
|
# Couple links and joints and create the graph of links.
|
149
|
-
# Note that the pose of the frames is not updated; it
|
203
|
+
# Note that the pose of the frames is not updated; it is the caller's
|
150
204
|
# responsibility to update their pose if they want to use them.
|
151
|
-
|
152
|
-
|
205
|
+
(
|
206
|
+
graph_root_node,
|
207
|
+
graph_joints,
|
208
|
+
graph_frames,
|
209
|
+
unconnected_links,
|
210
|
+
unconnected_joints,
|
211
|
+
unconnected_frames,
|
212
|
+
) = KinematicGraph._create_graph(
|
213
|
+
links=links, joints=joints, root_link_name=root_link_name, frames=frames
|
153
214
|
)
|
154
215
|
|
155
|
-
for
|
156
|
-
logging.warning(msg=f"Ignoring unconnected link
|
216
|
+
for link in unconnected_links:
|
217
|
+
logging.warning(msg=f"Ignoring unconnected link: '{link.name}'")
|
218
|
+
|
219
|
+
for joint in unconnected_joints:
|
220
|
+
logging.warning(msg=f"Ignoring unconnected joint: '{joint.name}'")
|
221
|
+
|
222
|
+
for frame in unconnected_frames:
|
223
|
+
logging.warning(msg=f"Ignoring unconnected frame: '{frame.name}'")
|
157
224
|
|
158
225
|
return KinematicGraph(
|
159
|
-
root=graph_root_node,
|
226
|
+
root=graph_root_node,
|
227
|
+
joints=graph_joints,
|
228
|
+
frames=graph_frames,
|
229
|
+
root_pose=root_pose,
|
230
|
+
_joints_removed=unconnected_joints,
|
160
231
|
)
|
161
232
|
|
162
233
|
@staticmethod
|
163
|
-
def
|
164
|
-
links:
|
165
|
-
joints:
|
234
|
+
def _create_graph(
|
235
|
+
links: list[LinkDescription],
|
236
|
+
joints: list[JointDescription],
|
166
237
|
root_link_name: str,
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
238
|
+
frames: list[LinkDescription] | None = None,
|
239
|
+
) -> tuple[
|
240
|
+
LinkDescription,
|
241
|
+
list[JointDescription],
|
242
|
+
list[LinkDescription],
|
243
|
+
list[LinkDescription],
|
244
|
+
list[JointDescription],
|
245
|
+
list[LinkDescription],
|
171
246
|
]:
|
172
247
|
"""
|
173
|
-
|
248
|
+
Low-level creator of kinematic graph components.
|
174
249
|
|
175
250
|
Args:
|
176
|
-
links
|
177
|
-
joints
|
178
|
-
root_link_name
|
251
|
+
links: A list of parsed link descriptions.
|
252
|
+
joints: A list of parsed joint descriptions.
|
253
|
+
root_link_name: The name of the root link used as root node of the graph.
|
254
|
+
frames: A list of parsed frame descriptions.
|
179
255
|
|
180
256
|
Returns:
|
181
|
-
|
182
|
-
|
257
|
+
A tuple containing the root node of the graph (defining the entire kinematic
|
258
|
+
tree by iterating on its child nodes), the list of joints representing the
|
259
|
+
actual graph edges, the list of frames rigidly attached to the graph nodes,
|
260
|
+
the list of unconnected links, the list of unconnected joints, and the list
|
261
|
+
of unconnected frames.
|
183
262
|
"""
|
184
263
|
|
185
|
-
# Create a
|
186
|
-
links_dict:
|
264
|
+
# Create a dictionary that maps the link name to the link, for easy retrieval.
|
265
|
+
links_dict: dict[str, LinkDescription] = {
|
187
266
|
l.name: l.mutable(validate=False) for l in links
|
188
267
|
}
|
189
268
|
|
269
|
+
# Create an empty list of frames if not provided.
|
270
|
+
frames = frames if frames is not None else []
|
271
|
+
|
272
|
+
# Create a dictionary that maps the frame name to the frame, for easy retrieval.
|
273
|
+
frames_dict = {frame.name: frame for frame in frames}
|
274
|
+
|
275
|
+
# Check that our parser correctly resolved the frame's parent to be a link.
|
276
|
+
for frame in frames:
|
277
|
+
assert frame.parent.name != "", frame
|
278
|
+
assert frame.parent.name is not None, frame
|
279
|
+
assert frame.parent.name != "__model__", frame
|
280
|
+
assert frame.parent.name not in frames_dict, frame
|
281
|
+
|
282
|
+
# ===========================================================
|
283
|
+
# Populate the kinematic graph with links, joints, and frames
|
284
|
+
# ===========================================================
|
285
|
+
|
286
|
+
# Check the existence of the root link.
|
190
287
|
if root_link_name not in links_dict:
|
191
288
|
raise ValueError(root_link_name)
|
192
289
|
|
193
|
-
# Reset the connections of the root link
|
290
|
+
# Reset the connections of the root link.
|
194
291
|
for link in links_dict.values():
|
195
|
-
link.children =
|
292
|
+
link.children = tuple()
|
196
293
|
|
197
|
-
# Couple links and joints creating the
|
294
|
+
# Couple links and joints creating the kinematic graph.
|
198
295
|
for joint in joints:
|
199
|
-
|
296
|
+
|
297
|
+
# Get the parent and child links of the joint.
|
200
298
|
parent_link = links_dict[joint.parent.name]
|
201
299
|
child_link = links_dict[joint.child.name]
|
202
300
|
|
203
301
|
assert child_link.name == joint.child.name
|
204
302
|
assert parent_link.name == joint.parent.name
|
205
303
|
|
206
|
-
# Assign link parent
|
304
|
+
# Assign link's parent.
|
207
305
|
child_link.parent = parent_link
|
208
306
|
|
209
|
-
# Assign link children and make sure they are unique
|
307
|
+
# Assign link's children and make sure they are unique.
|
210
308
|
if child_link.name not in {l.name for l in parent_link.children}:
|
211
|
-
parent_link.
|
309
|
+
with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION):
|
310
|
+
parent_link.children = (*parent_link.children, child_link)
|
212
311
|
|
213
|
-
# Collect all the links of the kinematic graph
|
312
|
+
# Collect all the links of the kinematic graph.
|
214
313
|
all_links_in_graph = list(
|
215
314
|
KinematicGraph.breadth_first_search(root=links_dict[root_link_name])
|
216
315
|
)
|
316
|
+
|
317
|
+
# Get the names of all links in the kinematic graph.
|
217
318
|
all_link_names_in_graph = [l.name for l in all_links_in_graph]
|
218
319
|
|
219
|
-
# Collect all the joints
|
220
|
-
|
221
|
-
|
222
|
-
for
|
223
|
-
if
|
320
|
+
# Collect all the joints of the kinematic graph.
|
321
|
+
all_joints_in_graph = [
|
322
|
+
joint
|
323
|
+
for joint in joints
|
324
|
+
if joint.parent.name in all_link_names_in_graph
|
325
|
+
and joint.child.name in all_link_names_in_graph
|
224
326
|
]
|
225
327
|
|
226
|
-
|
227
|
-
|
228
|
-
logging.info(msg=msg.format(removed_joint.name))
|
328
|
+
# Get the names of all joints in the kinematic graph.
|
329
|
+
all_joint_names_in_graph = [j.name for j in all_joints_in_graph]
|
229
330
|
|
230
|
-
#
|
231
|
-
|
331
|
+
# Collect all the frames of the kinematic graph.
|
332
|
+
# Note: our parser ensures that the parent of a frame is not another frame.
|
333
|
+
all_frames_in_graph = [
|
334
|
+
frame for frame in frames if frame.parent.name in all_link_names_in_graph
|
335
|
+
]
|
232
336
|
|
233
|
-
#
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
337
|
+
# Get the names of all frames in the kinematic graph.
|
338
|
+
all_frames_names_in_graph = [f.name for f in all_frames_in_graph]
|
339
|
+
|
340
|
+
# ============================
|
341
|
+
# Collect unconnected elements
|
342
|
+
# ============================
|
343
|
+
|
344
|
+
# Collect all the joints that are not part of the kinematic graph.
|
345
|
+
removed_joints = [j for j in joints if j.name not in all_joint_names_in_graph]
|
346
|
+
|
347
|
+
for joint in removed_joints:
|
348
|
+
msg = "Joint '{}' is unconnected and it will be removed"
|
349
|
+
logging.debug(msg=msg.format(joint.name))
|
350
|
+
|
351
|
+
# Collect all the links that are not part of the kinematic graph.
|
352
|
+
unconnected_links = [l for l in links if l.name not in all_link_names_in_graph]
|
353
|
+
|
354
|
+
# Update the unconnected links by removing their children. The other properties
|
355
|
+
# are left untouched, it's caller responsibility to post-process them if needed.
|
356
|
+
for link in unconnected_links:
|
357
|
+
link.children = tuple()
|
358
|
+
msg = "Link '{}' won't be part of the kinematic graph because unconnected"
|
359
|
+
logging.debug(msg=msg.format(link.name))
|
360
|
+
|
361
|
+
# Collect all the frames that are not part of the kinematic graph.
|
362
|
+
unconnected_frames = [
|
363
|
+
f for f in frames if f.name not in all_frames_names_in_graph
|
364
|
+
]
|
365
|
+
|
366
|
+
for frame in unconnected_frames:
|
367
|
+
msg = "Frame '{}' won't be part of the kinematic graph because unconnected"
|
368
|
+
logging.debug(msg=msg.format(frame.name))
|
239
369
|
|
240
370
|
return (
|
241
371
|
links_dict[root_link_name].mutable(mutable=False),
|
242
372
|
list(set(joints) - set(removed_joints)),
|
243
|
-
|
373
|
+
all_frames_in_graph,
|
374
|
+
unconnected_links,
|
375
|
+
list(set(removed_joints)),
|
376
|
+
unconnected_frames,
|
244
377
|
)
|
245
378
|
|
246
|
-
def reduce(self, considered_joints:
|
379
|
+
def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:
|
247
380
|
"""
|
248
|
-
Reduce the kinematic graph by removing
|
381
|
+
Reduce the kinematic graph by removing unspecified joints.
|
382
|
+
|
383
|
+
When a joint is removed, the mass and inertia of its child link are lumped
|
384
|
+
with those of its parent link, obtaining a new link that combines the two.
|
385
|
+
The description of the removed joint specifies the default angle (usually 0)
|
386
|
+
that is considered when the joint is removed.
|
249
387
|
|
250
388
|
Args:
|
251
|
-
considered_joints
|
389
|
+
considered_joints: A list of joint names to consider.
|
252
390
|
|
253
391
|
Returns:
|
254
|
-
|
392
|
+
The reduced kinematic graph.
|
255
393
|
"""
|
394
|
+
|
256
395
|
# The current object represents the complete kinematic graph
|
257
396
|
full_graph = self
|
258
397
|
|
@@ -263,11 +402,11 @@ class KinematicGraph:
|
|
263
402
|
|
264
403
|
# Return early if there is no action to take
|
265
404
|
if len(joint_names_to_remove) == 0:
|
266
|
-
logging.info(
|
405
|
+
logging.info("The kinematic graph doesn't need to be reduced")
|
267
406
|
return copy.deepcopy(self)
|
268
407
|
|
269
408
|
# Check if all considered joints are part of the full kinematic graph
|
270
|
-
if len(set(considered_joints) -
|
409
|
+
if len(set(considered_joints) - {j.name for j in full_graph.joints}) != 0:
|
271
410
|
extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
|
272
411
|
msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
|
273
412
|
raise ValueError(msg)
|
@@ -276,6 +415,9 @@ class KinematicGraph:
|
|
276
415
|
links_dict = copy.deepcopy(full_graph.links_dict)
|
277
416
|
joints_dict = copy.deepcopy(full_graph.joints_dict)
|
278
417
|
|
418
|
+
# Create the object to compute forward kinematics.
|
419
|
+
fk = KinematicGraphTransforms(graph=full_graph)
|
420
|
+
|
279
421
|
# The following steps are implemented below in order to create the reduced graph:
|
280
422
|
#
|
281
423
|
# 1. Lump the mass of the removed links into their parent
|
@@ -315,7 +457,7 @@ class KinematicGraph:
|
|
315
457
|
msg.format(
|
316
458
|
link_to_remove.name,
|
317
459
|
self.joints_connection_dict[
|
318
|
-
|
460
|
+
parent_of_link_to_remove.name, link_to_remove.name
|
319
461
|
].name,
|
320
462
|
parent_of_link_to_remove.name,
|
321
463
|
)
|
@@ -324,14 +466,14 @@ class KinematicGraph:
|
|
324
466
|
# Lump the link
|
325
467
|
lumped_link = parent_of_link_to_remove.lump_with(
|
326
468
|
link=link_to_remove,
|
327
|
-
lumped_H_removed=
|
469
|
+
lumped_H_removed=fk.relative_transform(
|
328
470
|
relative_to=parent_of_link_to_remove.name, name=link_to_remove.name
|
329
471
|
),
|
330
472
|
)
|
331
473
|
|
332
474
|
# Pop the original two links from the dictionary...
|
333
|
-
links_dict.pop(link_to_remove.name)
|
334
|
-
links_dict.pop(parent_of_link_to_remove.name)
|
475
|
+
_ = links_dict.pop(link_to_remove.name)
|
476
|
+
_ = links_dict.pop(parent_of_link_to_remove.name)
|
335
477
|
|
336
478
|
# ... and insert the lumped link (having the same name of the parent)
|
337
479
|
links_dict[lumped_link.name] = lumped_link
|
@@ -341,11 +483,13 @@ class KinematicGraph:
|
|
341
483
|
links_dict[link_to_remove.name] = lumped_link
|
342
484
|
|
343
485
|
# As a consequence of the back-insertion, we need to adjust the resulting
|
344
|
-
# lumped link of links that have been removed previously
|
486
|
+
# lumped link of links that have been removed previously.
|
487
|
+
# Note: in the dictionary, only items whose key is not matching value.name
|
488
|
+
# are links that have been removed.
|
345
489
|
for previously_removed_link_name in {
|
346
|
-
|
347
|
-
for
|
348
|
-
if
|
490
|
+
link_name
|
491
|
+
for link_name, link in links_dict.items()
|
492
|
+
if link_name != link.name and link.name == link_to_remove.name
|
349
493
|
}:
|
350
494
|
links_dict[previously_removed_link_name] = lumped_link
|
351
495
|
|
@@ -365,7 +509,7 @@ class KinematicGraph:
|
|
365
509
|
# Update the pose. Note that after the lumping process, the dict entry
|
366
510
|
# links_dict[joint.parent.name] contains the final lumped link
|
367
511
|
with joint.mutable_context(mutability=Mutability.MUTABLE):
|
368
|
-
joint.pose =
|
512
|
+
joint.pose = fk.relative_transform(
|
369
513
|
relative_to=links_dict[joint.parent.name].name, name=joint.name
|
370
514
|
)
|
371
515
|
with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
@@ -391,135 +535,114 @@ class KinematicGraph:
|
|
391
535
|
|
392
536
|
# Create the reduced graph data. We pass the full list of links so that those
|
393
537
|
# that are not part of the graph will be returned as frames.
|
394
|
-
|
538
|
+
(
|
539
|
+
reduced_root_node,
|
540
|
+
reduced_joints,
|
541
|
+
reduced_frames,
|
542
|
+
unconnected_links,
|
543
|
+
unconnected_joints,
|
544
|
+
unconnected_frames,
|
545
|
+
) = KinematicGraph._create_graph(
|
395
546
|
links=list(full_graph_links_dict.values()),
|
396
547
|
joints=[joints_dict[joint_name] for joint_name in considered_joints],
|
397
548
|
root_link_name=full_graph.root.name,
|
398
549
|
)
|
399
550
|
|
400
|
-
|
551
|
+
assert {f.name for f in self.frames}.isdisjoint(
|
552
|
+
{f.name for f in unconnected_frames + reduced_frames}
|
553
|
+
)
|
554
|
+
|
555
|
+
for link in unconnected_links:
|
556
|
+
logging.debug(msg=f"Link '{link.name}' is unconnected and became a frame")
|
557
|
+
|
558
|
+
# Create the reduced graph.
|
401
559
|
reduced_graph = KinematicGraph(
|
402
560
|
root=reduced_root_node,
|
403
561
|
joints=reduced_joints,
|
404
|
-
frames=reduced_frames,
|
562
|
+
frames=self.frames + unconnected_links + reduced_frames,
|
405
563
|
root_pose=full_graph.root_pose,
|
564
|
+
_joints_removed=(
|
565
|
+
self._joints_removed
|
566
|
+
+ unconnected_joints
|
567
|
+
+ [joints_dict[name] for name in joint_names_to_remove]
|
568
|
+
),
|
406
569
|
)
|
407
570
|
|
408
571
|
# ================================================================
|
409
572
|
# 4. Resolve the pose of the frames wrt their reduced graph parent
|
410
573
|
# ================================================================
|
411
574
|
|
412
|
-
#
|
413
|
-
|
414
|
-
# Get the link in which the removed link was lumped into
|
415
|
-
new_parent_link = links_dict[frame.name]
|
575
|
+
# Build a new object to compute FK on the reduced graph.
|
576
|
+
fk_reduced = KinematicGraphTransforms(graph=reduced_graph)
|
416
577
|
|
417
|
-
|
418
|
-
|
578
|
+
# We need to adjust the pose of the frames since their parent link
|
579
|
+
# could have been removed by the reduction process.
|
580
|
+
for frame in reduced_graph.frames:
|
419
581
|
|
420
|
-
#
|
421
|
-
|
422
|
-
|
423
|
-
relative_to=new_parent_link.name, name=frame.name
|
582
|
+
# Always find the real parent link of the frame
|
583
|
+
name_of_new_parent_link = fk_reduced.find_parent_link_of_frame(
|
584
|
+
name=frame.name
|
424
585
|
)
|
586
|
+
assert name_of_new_parent_link in reduced_graph, name_of_new_parent_link
|
425
587
|
|
426
|
-
#
|
427
|
-
|
428
|
-
|
588
|
+
# Notify the user if the parent link has changed.
|
589
|
+
if name_of_new_parent_link != frame.parent.name:
|
590
|
+
msg = "New parent of frame '{}' is '{}'"
|
591
|
+
logging.debug(msg=msg.format(frame.name, name_of_new_parent_link))
|
429
592
|
|
430
|
-
|
593
|
+
# Always recompute the pose of the frame, and set zero inertial params.
|
594
|
+
with frame.mutable_context(jaxsim.utils.Mutability.MUTABLE_NO_VALIDATION):
|
595
|
+
|
596
|
+
# Update kinematic parameters of the frame.
|
597
|
+
# Note that here we compute the transform using the FK object of the
|
598
|
+
# full model, so that we are sure that the kinematic is not altered.
|
599
|
+
frame.pose = fk.relative_transform(
|
600
|
+
relative_to=name_of_new_parent_link, name=frame.name
|
601
|
+
)
|
602
|
+
|
603
|
+
# Update the parent link such that the pose is expressed in its frame.
|
604
|
+
frame.parent = reduced_graph.links_dict[name_of_new_parent_link]
|
605
|
+
|
606
|
+
# Update dynamic parameters of the frame.
|
607
|
+
frame.mass = 0.0
|
608
|
+
frame.inertia = np.zeros_like(frame.inertia)
|
609
|
+
|
610
|
+
# Return the reduced graph.
|
431
611
|
return reduced_graph
|
432
612
|
|
433
|
-
def link_names(self) ->
|
613
|
+
def link_names(self) -> list[str]:
|
434
614
|
"""
|
435
|
-
Get the names of all links in the kinematic graph.
|
615
|
+
Get the names of all links in the kinematic graph (i.e. the nodes).
|
436
616
|
|
437
617
|
Returns:
|
438
|
-
|
618
|
+
The list of link names.
|
439
619
|
"""
|
440
620
|
return list(self.links_dict.keys())
|
441
621
|
|
442
|
-
def joint_names(self) ->
|
622
|
+
def joint_names(self) -> list[str]:
|
443
623
|
"""
|
444
|
-
Get the names of all joints in the kinematic graph.
|
624
|
+
Get the names of all joints in the kinematic graph (i.e. the edges).
|
445
625
|
|
446
626
|
Returns:
|
447
|
-
|
627
|
+
The list of joint names.
|
448
628
|
"""
|
449
629
|
return list(self.joints_dict.keys())
|
450
630
|
|
451
|
-
def frame_names(self) ->
|
631
|
+
def frame_names(self) -> list[str]:
|
452
632
|
"""
|
453
633
|
Get the names of all frames in the kinematic graph.
|
454
634
|
|
455
635
|
Returns:
|
456
|
-
|
636
|
+
The list of frame names.
|
457
637
|
"""
|
458
|
-
return list(self.frames_dict.keys())
|
459
|
-
|
460
|
-
def transform(self, name: str) -> npt.NDArray:
|
461
|
-
"""
|
462
|
-
Compute the transformation matrix for a given link, joint, or frame.
|
463
|
-
|
464
|
-
Args:
|
465
|
-
name (str): The name of the link, joint, or frame.
|
466
|
-
|
467
|
-
Returns:
|
468
|
-
npt.NDArray: The transformation matrix.
|
469
|
-
"""
|
470
|
-
if name in self.transform_cache:
|
471
|
-
return self.transform_cache[name]
|
472
|
-
|
473
|
-
if name in self.joint_names():
|
474
|
-
joint = self.joints_dict[name]
|
475
|
-
|
476
|
-
if joint.initial_position != 0.0:
|
477
|
-
msg = f"Ignoring unsupported initial position of joint '{name}'"
|
478
|
-
logging.warning(msg=msg)
|
479
|
-
|
480
|
-
transform = self.transform(name=joint.parent.name) @ joint.pose
|
481
|
-
self.transform_cache[name] = transform
|
482
|
-
return self.transform_cache[name]
|
483
|
-
|
484
|
-
if name in self.link_names():
|
485
|
-
link = self.links_dict[name]
|
486
638
|
|
487
|
-
|
488
|
-
return link.pose
|
489
|
-
|
490
|
-
parent_joint = self.joints_connection_dict[(link.parent.name, link.name)]
|
491
|
-
transform = self.transform(name=parent_joint.name) @ link.pose
|
492
|
-
self.transform_cache[name] = transform
|
493
|
-
return self.transform_cache[name]
|
494
|
-
|
495
|
-
# It can only be a plain frame
|
496
|
-
if name not in self.frame_names():
|
497
|
-
raise ValueError(name)
|
498
|
-
|
499
|
-
frame = self.frames_dict[name]
|
500
|
-
transform = self.transform(name=frame.parent.name) @ frame.pose
|
501
|
-
self.transform_cache[name] = transform
|
502
|
-
return self.transform_cache[name]
|
503
|
-
|
504
|
-
def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:
|
505
|
-
"""
|
506
|
-
Compute the relative transformation matrix between two elements in the kinematic graph.
|
507
|
-
|
508
|
-
Args:
|
509
|
-
relative_to (str): The name of the reference element.
|
510
|
-
name (str): The name of the element to compute the relative transformation for.
|
511
|
-
|
512
|
-
Returns:
|
513
|
-
npt.NDArray: The relative transformation matrix.
|
514
|
-
"""
|
515
|
-
return np.linalg.inv(self.transform(name=relative_to)) @ self.transform(
|
516
|
-
name=name
|
517
|
-
)
|
639
|
+
return list(self.frames_dict.keys())
|
518
640
|
|
519
641
|
def print_tree(self) -> None:
|
520
642
|
"""
|
521
643
|
Print the tree structure of the kinematic graph.
|
522
644
|
"""
|
645
|
+
|
523
646
|
import pptree
|
524
647
|
|
525
648
|
root_node = self.root
|
@@ -531,24 +654,37 @@ class KinematicGraph:
|
|
531
654
|
horizontal=True,
|
532
655
|
)
|
533
656
|
|
657
|
+
@property
|
658
|
+
def joints_removed(self) -> list[JointDescription]:
|
659
|
+
"""
|
660
|
+
Get the list of joints removed during the graph reduction.
|
661
|
+
|
662
|
+
Returns:
|
663
|
+
The list of removed joints.
|
664
|
+
"""
|
665
|
+
|
666
|
+
return self._joints_removed
|
667
|
+
|
534
668
|
@staticmethod
|
535
669
|
def breadth_first_search(
|
536
|
-
root:
|
537
|
-
sort_children:
|
538
|
-
) -> Iterable[
|
670
|
+
root: LinkDescription,
|
671
|
+
sort_children: Callable[[Any], Any] | None = lambda link: link.name,
|
672
|
+
) -> Iterable[LinkDescription]:
|
539
673
|
"""
|
540
674
|
Perform a breadth-first search (BFS) traversal of the kinematic graph.
|
541
675
|
|
542
676
|
Args:
|
543
|
-
root
|
544
|
-
sort_children
|
677
|
+
root: The root link for BFS.
|
678
|
+
sort_children: A function to sort children of a node.
|
545
679
|
|
546
680
|
Yields:
|
547
|
-
|
681
|
+
The links in the kinematic graph in BFS order.
|
548
682
|
"""
|
683
|
+
|
684
|
+
# Initialize the queue with the root node.
|
549
685
|
queue = [root]
|
550
686
|
|
551
|
-
# We assume that nodes have unique names
|
687
|
+
# We assume that nodes have unique names and mark a link as visited using
|
552
688
|
# its name. This speeds up considerably object comparison.
|
553
689
|
visited = []
|
554
690
|
visited.append(root.name)
|
@@ -556,11 +692,14 @@ class KinematicGraph:
|
|
556
692
|
yield root
|
557
693
|
|
558
694
|
while len(queue) > 0:
|
695
|
+
|
696
|
+
# Extract the first element of the queue.
|
559
697
|
l = queue.pop(0)
|
560
698
|
|
561
699
|
# Note: sorting the links with their name so that the order of children
|
562
|
-
# insertion does not matter when assigning the link index
|
700
|
+
# insertion does not matter when assigning the link index.
|
563
701
|
for child in sorted(l.children, key=sort_children):
|
702
|
+
|
564
703
|
if child.name in visited:
|
565
704
|
continue
|
566
705
|
|
@@ -569,25 +708,29 @@ class KinematicGraph:
|
|
569
708
|
|
570
709
|
yield child
|
571
710
|
|
572
|
-
|
711
|
+
# =================
|
712
|
+
# Sequence protocol
|
713
|
+
# =================
|
714
|
+
|
715
|
+
def __iter__(self) -> Iterator[LinkDescription]:
|
573
716
|
yield from KinematicGraph.breadth_first_search(root=self.root)
|
574
717
|
|
575
|
-
def __reversed__(self) -> Iterable[
|
718
|
+
def __reversed__(self) -> Iterable[LinkDescription]:
|
576
719
|
yield from reversed(list(iter(self)))
|
577
720
|
|
578
721
|
def __len__(self) -> int:
|
579
722
|
return len(list(iter(self)))
|
580
723
|
|
581
|
-
def __contains__(self, item:
|
724
|
+
def __contains__(self, item: str | LinkDescription) -> bool:
|
582
725
|
if isinstance(item, str):
|
583
726
|
return item in self.link_names()
|
584
727
|
|
585
|
-
if isinstance(item,
|
728
|
+
if isinstance(item, LinkDescription):
|
586
729
|
return item in set(iter(self))
|
587
730
|
|
588
731
|
raise TypeError(type(item).__name__)
|
589
732
|
|
590
|
-
def __getitem__(self, key:
|
733
|
+
def __getitem__(self, key: int | str) -> LinkDescription:
|
591
734
|
if isinstance(key, str):
|
592
735
|
if key not in self.link_names():
|
593
736
|
raise KeyError(key)
|
@@ -601,3 +744,242 @@ class KinematicGraph:
|
|
601
744
|
return list(iter(self))[key]
|
602
745
|
|
603
746
|
raise TypeError(type(key).__name__)
|
747
|
+
|
748
|
+
def count(self, value: LinkDescription) -> int:
|
749
|
+
"""
|
750
|
+
Count the occurrences of a link in the kinematic graph.
|
751
|
+
"""
|
752
|
+
return list(iter(self)).count(value)
|
753
|
+
|
754
|
+
def index(self, value: LinkDescription, start: int = 0, stop: int = -1) -> int:
|
755
|
+
"""
|
756
|
+
Find the index of a link in the kinematic graph.
|
757
|
+
"""
|
758
|
+
return list(iter(self)).index(value, start, stop)
|
759
|
+
|
760
|
+
|
761
|
+
# ====================
|
762
|
+
# Other useful classes
|
763
|
+
# ====================
|
764
|
+
|
765
|
+
|
766
|
+
@dataclasses.dataclass(frozen=True)
|
767
|
+
class KinematicGraphTransforms:
|
768
|
+
"""
|
769
|
+
Class to compute forward kinematics on a kinematic graph.
|
770
|
+
|
771
|
+
Attributes:
|
772
|
+
graph: The kinematic graph on which to compute forward kinematics.
|
773
|
+
"""
|
774
|
+
|
775
|
+
graph: KinematicGraph
|
776
|
+
|
777
|
+
_transform_cache: dict[str, npt.NDArray] = dataclasses.field(
|
778
|
+
default_factory=dict, init=False, repr=False, compare=False
|
779
|
+
)
|
780
|
+
|
781
|
+
_initial_joint_positions: dict[str, float] = dataclasses.field(
|
782
|
+
init=False, repr=False, compare=False
|
783
|
+
)
|
784
|
+
|
785
|
+
def __post_init__(self) -> None:
|
786
|
+
|
787
|
+
super().__setattr__(
|
788
|
+
"_initial_joint_positions",
|
789
|
+
{joint.name: joint.initial_position for joint in self.graph.joints},
|
790
|
+
)
|
791
|
+
|
792
|
+
@property
|
793
|
+
def initial_joint_positions(self) -> npt.NDArray:
|
794
|
+
"""
|
795
|
+
Get the initial joint positions of the kinematic graph.
|
796
|
+
"""
|
797
|
+
|
798
|
+
return np.atleast_1d(
|
799
|
+
np.array(list(self._initial_joint_positions.values()))
|
800
|
+
).astype(float)
|
801
|
+
|
802
|
+
@initial_joint_positions.setter
|
803
|
+
def initial_joint_positions(
|
804
|
+
self,
|
805
|
+
positions: npt.NDArray | Sequence,
|
806
|
+
joint_names: Sequence[str] | None = None,
|
807
|
+
) -> None:
|
808
|
+
|
809
|
+
joint_names = (
|
810
|
+
joint_names
|
811
|
+
if joint_names is not None
|
812
|
+
else list(self._initial_joint_positions.keys())
|
813
|
+
)
|
814
|
+
|
815
|
+
s = np.atleast_1d(np.array(positions).squeeze())
|
816
|
+
|
817
|
+
if s.size != len(joint_names):
|
818
|
+
raise ValueError(s.size, len(joint_names))
|
819
|
+
|
820
|
+
for joint_name in joint_names:
|
821
|
+
if joint_name not in self._initial_joint_positions:
|
822
|
+
raise ValueError(joint_name)
|
823
|
+
|
824
|
+
# Clear transform cache.
|
825
|
+
self._transform_cache.clear()
|
826
|
+
|
827
|
+
# Update initial joint positions.
|
828
|
+
for joint_name, position in zip(joint_names, s, strict=True):
|
829
|
+
self._initial_joint_positions[joint_name] = position
|
830
|
+
|
831
|
+
def transform(self, name: str) -> npt.NDArray:
|
832
|
+
"""
|
833
|
+
Compute the SE(3) transform of elements belonging to the kinematic graph.
|
834
|
+
|
835
|
+
Args:
|
836
|
+
name: The name of a link, a joint, or a frame.
|
837
|
+
|
838
|
+
Returns:
|
839
|
+
The 4x4 transform matrix of the element w.r.t. the model frame.
|
840
|
+
"""
|
841
|
+
|
842
|
+
# If the transform was already computed, return it.
|
843
|
+
if name in self._transform_cache:
|
844
|
+
return self._transform_cache[name]
|
845
|
+
|
846
|
+
# If the name is a joint, compute M_H_J transform.
|
847
|
+
if name in self.graph.joint_names():
|
848
|
+
|
849
|
+
# Get the joint.
|
850
|
+
joint = self.graph.joints_dict[name]
|
851
|
+
assert joint.name == name
|
852
|
+
|
853
|
+
# Get the transform of the parent link.
|
854
|
+
M_H_L = self.transform(name=joint.parent.name)
|
855
|
+
|
856
|
+
# Rename the pose of the predecessor joint frame w.r.t. its parent link.
|
857
|
+
L_H_pre = joint.pose
|
858
|
+
|
859
|
+
# Compute the joint transform from the predecessor to the successor frame.
|
860
|
+
pre_H_J = self.pre_H_suc(
|
861
|
+
joint_type=joint.jtype,
|
862
|
+
joint_axis=joint.axis,
|
863
|
+
joint_position=self._initial_joint_positions[joint.name],
|
864
|
+
)
|
865
|
+
|
866
|
+
# Compute the M_H_J transform.
|
867
|
+
self._transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J
|
868
|
+
return self._transform_cache[name]
|
869
|
+
|
870
|
+
# If the name is a link, compute M_H_L transform.
|
871
|
+
if name in self.graph.link_names():
|
872
|
+
|
873
|
+
# Get the link.
|
874
|
+
link = self.graph.links_dict[name]
|
875
|
+
|
876
|
+
# Handle the pose between the __model__ frame and the root link.
|
877
|
+
if link.name == self.graph.root.name:
|
878
|
+
M_H_B = link.pose
|
879
|
+
return M_H_B
|
880
|
+
|
881
|
+
# Get the joint between the link and its parent.
|
882
|
+
parent_joint = self.graph.joints_connection_dict[
|
883
|
+
link.parent.name, link.name
|
884
|
+
]
|
885
|
+
|
886
|
+
# Get the transform of the parent joint.
|
887
|
+
M_H_J = self.transform(name=parent_joint.name)
|
888
|
+
|
889
|
+
# Rename the pose of the link w.r.t. its parent joint.
|
890
|
+
J_H_L = link.pose
|
891
|
+
|
892
|
+
# Compute the M_H_L transform.
|
893
|
+
self._transform_cache[name] = M_H_J @ J_H_L
|
894
|
+
return self._transform_cache[name]
|
895
|
+
|
896
|
+
# It can only be a plain frame.
|
897
|
+
if name not in self.graph.frame_names():
|
898
|
+
raise ValueError(name)
|
899
|
+
|
900
|
+
# Get the frame.
|
901
|
+
frame = self.graph.frames_dict[name]
|
902
|
+
|
903
|
+
# Get the transform of the parent link.
|
904
|
+
M_H_L = self.transform(name=frame.parent.name)
|
905
|
+
|
906
|
+
# Rename the pose of the frame w.r.t. its parent link.
|
907
|
+
L_H_F = frame.pose
|
908
|
+
|
909
|
+
# Compute the M_H_F transform.
|
910
|
+
self._transform_cache[name] = M_H_L @ L_H_F
|
911
|
+
return self._transform_cache[name]
|
912
|
+
|
913
|
+
def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:
|
914
|
+
"""
|
915
|
+
Compute the SE(3) relative transform of elements belonging to the kinematic graph.
|
916
|
+
|
917
|
+
Args:
|
918
|
+
relative_to: The name of the reference element.
|
919
|
+
name: The name of a link, a joint, or a frame.
|
920
|
+
|
921
|
+
Returns:
|
922
|
+
The 4x4 transform matrix of the element w.r.t. the desired frame.
|
923
|
+
"""
|
924
|
+
|
925
|
+
import jaxsim.math
|
926
|
+
|
927
|
+
M_H_target = self.transform(name=name)
|
928
|
+
M_H_R = self.transform(name=relative_to)
|
929
|
+
|
930
|
+
# Compute the relative transform R_H_target, where R is the reference frame,
|
931
|
+
# and i the frame of the desired link|joint|frame.
|
932
|
+
return np.array(jaxsim.math.Transform.inverse(M_H_R)) @ M_H_target
|
933
|
+
|
934
|
+
@staticmethod
|
935
|
+
def pre_H_suc(
|
936
|
+
joint_type: JointType,
|
937
|
+
joint_axis: npt.NDArray,
|
938
|
+
joint_position: float | None = None,
|
939
|
+
) -> npt.NDArray:
|
940
|
+
"""
|
941
|
+
Compute the SE(3) transform from the predecessor to the successor frame.
|
942
|
+
|
943
|
+
Args:
|
944
|
+
joint_type: The type of the joint.
|
945
|
+
joint_axis: The axis of the joint.
|
946
|
+
joint_position: The position of the joint.
|
947
|
+
|
948
|
+
Returns:
|
949
|
+
The 4x4 transform matrix from the predecessor to the successor frame.
|
950
|
+
"""
|
951
|
+
|
952
|
+
import jaxsim.math
|
953
|
+
|
954
|
+
return np.array(
|
955
|
+
jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)[
|
956
|
+
0
|
957
|
+
]
|
958
|
+
)
|
959
|
+
|
960
|
+
def find_parent_link_of_frame(self, name: str) -> str:
|
961
|
+
"""
|
962
|
+
Find the parent link of a frame.
|
963
|
+
|
964
|
+
Args:
|
965
|
+
name: The name of the frame.
|
966
|
+
|
967
|
+
Returns:
|
968
|
+
The name of the parent link of the frame.
|
969
|
+
"""
|
970
|
+
|
971
|
+
try:
|
972
|
+
frame = self.graph.frames_dict[name]
|
973
|
+
except KeyError as e:
|
974
|
+
raise ValueError(f"Frame '{name}' not found in the kinematic graph") from e
|
975
|
+
|
976
|
+
match frame.parent.name:
|
977
|
+
case parent_name if parent_name in self.graph.links_dict:
|
978
|
+
return parent_name
|
979
|
+
|
980
|
+
case parent_name if parent_name in self.graph.frames_dict:
|
981
|
+
return self.find_parent_link_of_frame(name=parent_name)
|
982
|
+
|
983
|
+
case _:
|
984
|
+
msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent.name}'"
|
985
|
+
raise RuntimeError(msg)
|