jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +64 -30
- jaxsim/math/cross.py +18 -9
- jaxsim/math/inertia.py +11 -9
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +59 -25
- jaxsim/math/rotation.py +30 -24
- jaxsim/math/skew.py +18 -7
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,901 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
|
5
|
+
import jax.lax
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import jax_dataclasses
|
8
|
+
import numpy as np
|
9
|
+
import numpy.typing as npt
|
10
|
+
from jax_dataclasses import Static
|
11
|
+
|
12
|
+
import jaxsim.typing as jtp
|
13
|
+
from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion
|
14
|
+
from jaxsim.parsers.descriptions import JointDescription, ModelDescription
|
15
|
+
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
|
16
|
+
|
17
|
+
|
18
|
+
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class KinDynParameters(JaxsimDataclass):
    r"""
    Class storing the kinematic and dynamic parameters of a model.

    Attributes:
        link_names: The names of the links.
        parent_array: The parent array :math:`\lambda(i)` of the model.
        support_body_array_bool:
            The boolean support parent array :math:`\kappa_{b}(i)` of the model.
        link_parameters: The parameters of the links.
        frame_parameters: The parameters of the frames.
        contact_parameters: The parameters of the collidable points.
        joint_model: The joint model of the model.
        joint_parameters: The parameters of the joints.
    """

    # Static
    link_names: Static[tuple[str, ...]]
    _parent_array: Static[HashedNumpyArray]
    _support_body_array_bool: Static[HashedNumpyArray]

    # Links
    link_parameters: LinkParameters

    # Contacts
    contact_parameters: ContactParameters

    # Frames
    frame_parameters: FrameParameters

    # Joints
    joint_model: JointModel
    joint_parameters: JointParameters | None

    @property
    def parent_array(self) -> jtp.Vector:
        r"""
        Return the parent array :math:`\lambda(i)` of the model.
        """
        return self._parent_array.get()

    @property
    def support_body_array_bool(self) -> jtp.Matrix:
        r"""
        Return the boolean support parent array :math:`\kappa_{b}(i)` of the model.
        """
        return self._support_body_array_bool.get()

    @staticmethod
    def build(model_description: ModelDescription) -> KinDynParameters:
        """
        Construct the kinematic and dynamic parameters of the model.

        Args:
            model_description: The parsed model description to consider.

        Returns:
            The kinematic and dynamic parameters of the model.

        Note:
            This class is meant to ease the management of parametric models in
            an automatic differentiation context.
        """

        # Extract the links ordered by their index.
        # The link index corresponds to the body index ∈ [0, num_bodies - 1].
        ordered_links = sorted(
            list(model_description.links_dict.values()),
            key=lambda l: l.index,
        )

        # Extract the joints ordered by their index.
        # The joint index matches the index of its child link, therefore it starts
        # from 1. Keep this in mind since this 1-indexing might introduce bugs.
        ordered_joints = sorted(
            list(model_description.joints_dict.values()),
            key=lambda j: j.index,
        )

        # ================
        # Links properties
        # ================

        # Create a list of link parameters objects.
        link_parameters_list = [
            LinkParameters.build_from_spatial_inertia(index=link.index, M=link.inertia)
            for link in ordered_links
        ]

        # Create a vectorized object of link parameters by stacking, leaf by leaf,
        # the per-link objects.
        link_parameters = jax.tree.map(lambda *l: jnp.stack(l), *link_parameters_list)

        # =================
        # Joints properties
        # =================

        # Create a list of joint parameters objects.
        joint_parameters_list = [
            JointParameters.build_from_joint_description(joint_description=joint)
            for joint in ordered_joints
        ]

        # Create a vectorized object of joint parameters.
        # Models without joints get an object whose leaves are empty arrays, since
        # stacking an empty list of objects is not possible.
        joint_parameters = (
            jax.tree.map(lambda *l: jnp.stack(l), *joint_parameters_list)
            if len(ordered_joints) > 0
            else JointParameters(
                index=jnp.array([], dtype=int),
                friction_static=jnp.array([], dtype=float),
                friction_viscous=jnp.array([], dtype=float),
                position_limits_min=jnp.array([], dtype=float),
                position_limits_max=jnp.array([], dtype=float),
                position_limit_spring=jnp.array([], dtype=float),
                position_limit_damper=jnp.array([], dtype=float),
            )
        )

        # Create an object that defines the joint model (parent-to-child transforms).
        joint_model = JointModel.build(description=model_description)

        # ===================
        # Contacts properties
        # ===================

        # Create the object storing the parameters of collidable points.
        # Note that, contrarily to LinkParameters and JointParameters, this object
        # is not built by stacking per-element objects with jax.tree.map. This is
        # because the "body" attribute of the object must be Static for JIT-related
        # reasons, and tree.map would not consider it as a leaf.
        contact_parameters = ContactParameters.build_from(
            model_description=model_description
        )

        # =================
        # Frames properties
        # =================

        # Create the object storing the parameters of frames.
        # Note that, contrarily to LinkParameters and JointParameters, this object
        # is not built by stacking per-element objects with jax.tree.map. This is
        # because the "name" attribute of the object must be Static for JIT-related
        # reasons, and tree.map would not consider it as a leaf.
        frame_parameters = FrameParameters.build_from(
            model_description=model_description
        )

        # ===============
        # Tree properties
        # ===============

        # Build the parent array λ(i) of the model, filling each entry at the
        # position given by the link index (rather than relying on the order in
        # which the links are visited). Links without a parent (the base link)
        # keep the -1 placeholder since their parent is not defined.
        parent_list = [-1] * len(ordered_links)
        for link in ordered_links:
            if link.parent is not None:
                parent_list[link.index] = link.parent.index
        parent_array = jnp.array(parent_list, dtype=int)

        # Instead of building the support parent array κ(i) for each link of the model,
        # that has a variable length depending on the number of links connecting the
        # root to the i-th link, we build the corresponding boolean version.
        # Given a link index i, the boolean support parent array κb(i) is an array
        # with the same number of elements of λ(i) having the i-th element set to True
        # if the i-th link is in the support parent array κ(i), False otherwise.
        # We store the boolean κb(i) as static attribute of the PyTree so that
        # algorithms that need to access it can be jit-compiled.
        def κb(link_index: jtp.IntLike) -> jtp.Vector:
            κb = jnp.zeros(len(ordered_links), dtype=bool)

            carry0 = κb, link_index

            # Walk the link indices from the last to the first one: whenever the
            # current index matches the active link, mark it and move to its parent.
            def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:

                κb, active_link_index = carry

                κb, active_link_index = jax.lax.cond(
                    pred=(i == active_link_index),
                    false_fun=lambda: (κb, active_link_index),
                    true_fun=lambda: (
                        κb.at[active_link_index].set(True),
                        parent_array[active_link_index],
                    ),
                )

                return (κb, active_link_index), None

            (κb, _), _ = jax.lax.scan(
                f=scan_body,
                init=carry0,
                xs=jnp.flip(jnp.arange(start=0, stop=len(ordered_links))),
            )

            return κb

        support_body_array_bool = jax.vmap(κb)(
            jnp.arange(start=0, stop=len(ordered_links))
        )

        # =================================
        # Build and return KinDynParameters
        # =================================

        return KinDynParameters(
            link_names=tuple(l.name for l in ordered_links),
            _parent_array=HashedNumpyArray(array=parent_array),
            _support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
            link_parameters=link_parameters,
            joint_model=joint_model,
            joint_parameters=joint_parameters,
            contact_parameters=contact_parameters,
            frame_parameters=frame_parameters,
        )

    def __eq__(self, other: object) -> bool:
        """
        Compare two objects by their hash.

        Only the quantities included in ``__hash__`` take part in the comparison.
        """

        if not isinstance(other, KinDynParameters):
            return False

        return hash(self) == hash(other)

    def __hash__(self) -> int:
        """
        Hash the static structure of the model.

        The hash combines the number of links and joints, the frame names and
        parent bodies, the parent array, and the boolean support array.
        """

        return hash(
            (
                hash(self.number_of_links()),
                hash(self.number_of_joints()),
                hash(self.frame_parameters.name),
                hash(self.frame_parameters.body),
                hash(self._parent_array),
                hash(self._support_body_array_bool),
            )
        )

    # =============================
    # Helpers to extract parameters
    # =============================

    def number_of_links(self) -> int:
        """
        Return the number of links of the model.

        Returns:
            The number of links of the model.
        """

        return len(self.link_names)

    def number_of_joints(self) -> int:
        """
        Return the number of joints of the model.

        Returns:
            The number of joints of the model.
        """

        # The first entry of the joint model refers to the (fixed or floating)
        # connection between the world and the base, which is not counted.
        return len(self.joint_model.joint_names) - 1

    def number_of_frames(self) -> int:
        """
        Return the number of frames of the model.

        Returns:
            The number of frames of the model.
        """

        return len(self.frame_parameters.name)

    def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector:
        r"""
        Return the support parent array :math:`\kappa(i)` of a link.

        Args:
            link_index: The index of the link.

        Returns:
            The support parent array :math:`\kappa(i)` of the link.

        Note:
            This method returns a variable-length vector. In jit-compiled functions,
            it's better to use the (static) boolean version `support_body_array_bool`.
        """

        return jnp.array(
            jnp.where(self.support_body_array_bool[link_index])[0], dtype=int
        )

    # ========================
    # Quantities used by RBDAs
    # ========================

    @jax.jit
    def links_spatial_inertia(self) -> jtp.Array:
        """
        Return the spatial inertia of all links of the model.

        Returns:
            The spatial inertia of all links of the model.
        """

        return jax.vmap(LinkParameters.spatial_inertia)(self.link_parameters)

    @jax.jit
    def tree_transforms(self) -> jtp.Array:
        r"""
        Return the tree transforms of the model.

        Returns:
            The stacked :math:`6 \times 6` adjoint matrices of the transforms
            :math:`{}^{\text{pre}(i)} H_{\lambda(i)}`
            of all joints of the model. The entry at index 0 is a zero matrix
            used as a placeholder for the base.
        """

        pre_Xi_λ = jax.vmap(
            lambda i: self.joint_model.parent_H_predecessor(joint_index=i)
            .inverse()
            .adjoint()
        )(jnp.arange(1, self.number_of_joints() + 1))

        return jnp.vstack(
            [
                jnp.zeros(shape=(1, 6, 6), dtype=float),
                pre_Xi_λ,
            ]
        )

    @jax.jit
    def joint_transforms(
        self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
    ) -> jtp.Array:
        r"""
        Return the transforms of the joints.

        Args:
            joint_positions: The joint positions.
            base_transform: The homogeneous matrix defining the base pose.

        Returns:
            The stacked transforms
            :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
            of each joint.
        """

        return self.joint_transforms_and_motion_subspaces(
            joint_positions=joint_positions,
            base_transform=base_transform,
        )[0]

    @jax.jit
    def joint_motion_subspaces(
        self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
    ) -> jtp.Array:
        r"""
        Return the motion subspaces of the joints.

        Args:
            joint_positions: The joint positions.
            base_transform: The homogeneous matrix defining the base pose.

        Returns:
            The stacked motion subspaces :math:`\mathbf{S}(s)` of each joint.
        """

        return self.joint_transforms_and_motion_subspaces(
            joint_positions=joint_positions,
            base_transform=base_transform,
        )[1]

    @jax.jit
    def joint_transforms_and_motion_subspaces(
        self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
    ) -> tuple[jtp.Array, jtp.Array]:
        r"""
        Return the transforms and the motion subspaces of the joints.

        Args:
            joint_positions: The joint positions.
            base_transform: The homogeneous matrix defining the base pose.

        Returns:
            A tuple containing the stacked transforms
            :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
            and the stacked motion subspaces :math:`\mathbf{S}(s)` of each joint.

        Note:
            The first transform, at index 0, provides the pose of the base link
            w.r.t. the world frame. For both floating-base and fixed-base systems,
            it takes into account the base pose and the optional transform
            between the root frame of the model and the base link.
        """

        # Rename the base transform.
        W_H_B = base_transform

        # Extract the parent-to-predecessor fixed transforms of the joints.
        λ_H_pre = jnp.vstack(
            [
                jnp.eye(4)[jnp.newaxis],
                self.joint_model.λ_H_pre[1 : 1 + self.number_of_joints()],
            ]
        )

        # Compute the transforms and motion subspaces of the joints.
        # vmap cannot be applied when there are no joints, hence the empty arrays.
        if self.number_of_joints() == 0:
            pre_H_suc_J, S_J = jnp.empty((0, 4, 4)), jnp.empty((0, 6, 1))
        else:
            pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)(
                jnp.array(self.joint_model.joint_types[1:]).astype(int),
                jnp.array(joint_positions),
                jnp.array([j.axis for j in self.joint_model.joint_axis]),
            )

        # Extract the transforms and motion subspaces of the joints.
        # We stack the base transform W_H_B at index 0, and a dummy motion subspace
        # for either the fixed or free-floating joint connecting the world to the base.
        pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J])
        S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])

        # Extract the successor-to-child fixed transforms.
        # Note that here we include also the index 0 since suc_H_child[0] stores the
        # optional pose of the base link w.r.t. the root frame of the model.
        # This is supported by SDF when the base link <pose> element is defined.
        suc_H_i = self.joint_model.suc_H_i[jnp.arange(0, 1 + self.number_of_joints())]

        # Compute the overall transforms from the parent to the child of each joint by
        # composing all the components of our joint model.
        i_X_λ = jax.vmap(
            lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: Adjoint.from_transform(
                transform=λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, inverse=True
            )
        )(λ_H_pre, pre_H_suc, suc_H_i)

        return i_X_λ, S

    # ============================
    # Helpers to update parameters
    # ============================

    def set_link_mass(
        self, link_index: jtp.IntLike, mass: jtp.FloatLike
    ) -> KinDynParameters:
        """
        Set the mass of a link.

        Args:
            link_index: The index of the link.
            mass: The mass of the link.

        Returns:
            The updated kinematic and dynamic parameters of the model.
        """

        link_parameters = self.link_parameters.replace(
            mass=self.link_parameters.mass.at[link_index].set(mass)
        )

        return self.replace(link_parameters=link_parameters)

    def set_link_inertia(
        self, link_index: jtp.IntLike, inertia: jtp.MatrixLike
    ) -> KinDynParameters:
        r"""
        Set the inertia tensor of a link.

        Args:
            link_index: The index of the link.
            inertia: The :math:`3 \times 3` inertia tensor of the link.

        Returns:
            The updated kinematic and dynamic parameters of the model.
        """

        inertia_elements = LinkParameters.flatten_inertia_tensor(I=inertia)

        link_parameters = self.link_parameters.replace(
            inertia_elements=self.link_parameters.inertia_elements.at[link_index].set(
                inertia_elements
            )
        )

        return self.replace(link_parameters=link_parameters)
|
500
|
+
|
501
|
+
|
502
|
+
@jax_dataclasses.pytree_dataclass
class JointParameters(JaxsimDataclass):
    """
    Container of the parameters of a joint.

    Attributes:
        index: The index of the joint.
        friction_static: The static friction of the joint.
        friction_viscous: The viscous friction of the joint.
        position_limits_min: The lower position limit of the joint.
        position_limits_max: The upper position limit of the joint.
        position_limit_spring: The spring constant of the position limit.
        position_limit_damper: The damper constant of the position limit.

    Note:
        Inside KinDynParameters, the attributes of this class hold the stacked
        (vectorized) parameters of all the joints of the model.
    """

    index: jtp.Int

    friction_static: jtp.Float
    friction_viscous: jtp.Float

    position_limits_min: jtp.Float
    position_limits_max: jtp.Float

    position_limit_spring: jtp.Float
    position_limit_damper: jtp.Float

    @staticmethod
    def build_from_joint_description(
        joint_description: JointDescription,
    ) -> JointParameters:
        """
        Create the joint parameters starting from a joint description.

        Args:
            joint_description: The joint description to convert.

        Returns:
            The JointParameters object built from the description.
        """

        def as_float_scalar(value) -> jtp.Float:
            # Wrap in a JAX array, drop singleton dimensions and cast to float.
            return jnp.array(value).squeeze().astype(float)

        # The two position limits are reordered so that min <= max regardless of
        # the order in which they are stored in the description.
        first_limit = joint_description.position_limit[0]
        second_limit = joint_description.position_limit[1]

        lower_limit = jnp.minimum(first_limit, second_limit)
        upper_limit = jnp.maximum(first_limit, second_limit)

        joint_index = jnp.array(joint_description.index).squeeze().astype(int)

        return JointParameters(
            index=joint_index,
            friction_static=as_float_scalar(joint_description.friction_static),
            friction_viscous=as_float_scalar(joint_description.friction_viscous),
            position_limits_min=lower_limit.astype(float),
            position_limits_max=upper_limit.astype(float),
            position_limit_spring=as_float_scalar(
                joint_description.position_limit_spring
            ),
            position_limit_damper=as_float_scalar(
                joint_description.position_limit_damper
            ),
        )
|
572
|
+
|
573
|
+
|
574
|
+
@jax_dataclasses.pytree_dataclass
class LinkParameters(JaxsimDataclass):
    r"""
    Class storing the parameters of a link.

    Attributes:
        index: The index of the link.
        mass: The mass of the link.
        inertia_elements:
            The unique elements of the :math:`3 \times 3` inertia tensor of the link,
            i.e. its upper triangle taken in the order of ``jnp.triu_indices(3)``:
            entries (0,0), (0,1), (0,2), (1,1), (1,2), (2,2).
        center_of_mass:
            The translation :math:`{}^L \mathbf{p}_{\text{CoM}}` between the origin
            of the link frame and the link's center of mass, expressed in the
            coordinates of the link frame.

    Note:
        This class is used inside KinDynParameters to store the vectorized set
        of link parameters.
    """

    index: jtp.Int

    mass: jtp.Float
    center_of_mass: jtp.Vector
    inertia_elements: jtp.Vector

    @staticmethod
    def build_from_spatial_inertia(index: jtp.IntLike, M: jtp.Matrix) -> LinkParameters:
        r"""
        Build a LinkParameters object from a :math:`6 \times 6` spatial inertia matrix.

        Args:
            index: The index of the link.
            M: The :math:`6 \times 6` spatial inertia matrix of the link.

        Returns:
            The LinkParameters object.
        """

        # Extract the link parameters from the 6D spatial inertia.
        m, L_p_CoM, I_CoM = Inertia.to_params(M=M)

        # Extract only the necessary elements of the inertia tensor.
        inertia_elements = LinkParameters.flatten_inertia_tensor(I=I_CoM)

        return LinkParameters(
            index=jnp.array(index).squeeze().astype(int),
            mass=jnp.array(m).squeeze().astype(float),
            center_of_mass=jnp.atleast_1d(jnp.array(L_p_CoM).squeeze()).astype(float),
            inertia_elements=jnp.atleast_1d(inertia_elements.squeeze()).astype(float),
        )

    @staticmethod
    def build_from_inertial_parameters(
        index: jtp.IntLike, m: jtp.FloatLike, I: jtp.MatrixLike, c: jtp.VectorLike
    ) -> LinkParameters:
        r"""
        Build a LinkParameters object from the inertial parameters of a link.

        Args:
            index: The index of the link.
            m: The mass of the link.
            I: The :math:`3 \times 3` inertia tensor of the link.
            c: The translation between the link frame and the link's center of mass.

        Returns:
            The LinkParameters object.
        """

        # Accept any array-like input (e.g. nested lists), not only arrays.
        I = jnp.asarray(I)
        c = jnp.asarray(c)

        # Extract only the necessary elements of the inertia tensor.
        inertia_elements = LinkParameters.flatten_inertia_tensor(I=I)

        return LinkParameters(
            index=jnp.array(index).squeeze().astype(int),
            mass=jnp.array(m).squeeze().astype(float),
            center_of_mass=jnp.atleast_1d(c.squeeze()).astype(float),
            inertia_elements=jnp.atleast_1d(inertia_elements.squeeze()).astype(float),
        )

    @staticmethod
    def build_from_flat_parameters(
        index: jtp.IntLike, parameters: jtp.VectorLike
    ) -> LinkParameters:
        """
        Build a LinkParameters object from a flat vector of parameters.

        Args:
            index: The index of the link.
            parameters:
                The flat vector of parameters, laid out as returned by
                `flat_parameters`: [mass, center of mass (3), inertia elements].

        Returns:
            The LinkParameters object.
        """

        # Accept any array-like input (e.g. a list), not only arrays.
        parameters = jnp.asarray(parameters)

        index = jnp.array(index).squeeze().astype(int)

        m = jnp.array(parameters[0]).squeeze().astype(float)
        c = jnp.atleast_1d(parameters[1:4].squeeze()).astype(float)
        inertia_elements = jnp.atleast_1d(parameters[4:].squeeze()).astype(float)

        return LinkParameters(
            index=index, mass=m, inertia_elements=inertia_elements, center_of_mass=c
        )

    @staticmethod
    def flat_parameters(params: LinkParameters) -> jtp.Vector:
        """
        Return the parameters of a link as a flat vector.

        Args:
            params: The link parameters.

        Returns:
            The parameters of the link as a flat vector, laid out as
            [mass, center of mass, inertia elements].
        """

        return (
            jnp.hstack(
                [params.mass, params.center_of_mass.squeeze(), params.inertia_elements]
            )
            .squeeze()
            .astype(float)
        )

    @staticmethod
    def inertia_tensor(params: LinkParameters) -> jtp.Matrix:
        r"""
        Return the :math:`3 \times 3` inertia tensor of a link.

        Args:
            params: The link parameters.

        Returns:
            The :math:`3 \times 3` inertia tensor of the link.
        """

        return LinkParameters.unflatten_inertia_tensor(
            inertia_elements=params.inertia_elements
        )

    @staticmethod
    def spatial_inertia(params: LinkParameters) -> jtp.Matrix:
        r"""
        Return the :math:`6 \times 6` spatial inertia matrix of a link.

        Args:
            params: The link parameters.

        Returns:
            The :math:`6 \times 6` spatial inertia matrix of the link.
        """

        return Inertia.to_sixd(
            mass=params.mass,
            I=LinkParameters.inertia_tensor(params),
            com=params.center_of_mass,
        )

    @staticmethod
    def flatten_inertia_tensor(I: jtp.Matrix) -> jtp.Vector:
        r"""
        Flatten a :math:`3 \times 3` inertia tensor into a vector of unique elements.

        Args:
            I: The :math:`3 \times 3` inertia tensor.

        Returns:
            The vector of the upper-triangular elements of the inertia tensor,
            in the order of ``jnp.triu_indices(3)``.
        """

        return jnp.atleast_1d(jnp.asarray(I)[jnp.triu_indices(3)].squeeze())

    @staticmethod
    def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix:
        r"""
        Unflatten a vector of unique elements into a :math:`3 \times 3` inertia tensor.

        Args:
            inertia_elements:
                The upper-triangular elements of the inertia tensor, in the
                order of ``jnp.triu_indices(3)``.

        Returns:
            The symmetric :math:`3 \times 3` inertia tensor.
        """

        # Fill the upper triangle (diagonal included).
        I_upper = (
            jnp.zeros([3, 3])
            .at[jnp.triu_indices(3)]
            .set(jnp.asarray(inertia_elements).squeeze())
        )

        # Mirror the strictly-upper part onto the lower triangle. This is done
        # structurally rather than by selecting on the values of the entries, so
        # that gradients w.r.t. every element flow to both symmetric entries even
        # when an element is exactly zero (e.g. products of inertia of a diagonal
        # inertia tensor).
        I = I_upper + jnp.triu(I_upper, k=1).T

        return jnp.atleast_2d(I).astype(float)
|
759
|
+
|
760
|
+
|
761
|
+
@jax_dataclasses.pytree_dataclass
class ContactParameters(JaxsimDataclass):
    """
    Class storing the contact parameters of a model.

    Attributes:
        body:
            A tuple of integers representing, for each collidable point, the index of
            the body (link) to which it is rigidly attached.
        point:
            The translations between the link frame and the collidable point, expressed
            in the coordinates of the parent link frame.
        enabled:
            A tuple of booleans representing, for each collidable point, whether it is
            enabled or not in contact models.

    Note:
        Contrarily to LinkParameters and JointParameters, this class is not meant
        to be created with vmap. This is because the `body` attribute must be `Static`.
    """

    body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple)

    point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([]))

    enabled: Static[tuple[bool, ...]] = dataclasses.field(default_factory=tuple)

    @property
    def indices_of_enabled_collidable_points(self) -> npt.NDArray:
        """
        Return the indices of the enabled collidable points.

        Returns:
            A NumPy array with the positions in `enabled` whose flag is True.
        """
        return np.where(np.array(self.enabled))[0]

    @staticmethod
    def build_from(model_description: ModelDescription) -> ContactParameters:
        """
        Build a ContactParameters object from a model description.

        Args:
            model_description: The model description to consider.

        Returns:
            The ContactParameters object. It is empty when the model has no
            collision shapes or no enabled collidable points.
        """

        if len(model_description.collision_shapes) == 0:
            return ContactParameters()

        # Get all the links so that we can take their updated index.
        links_dict = {link.name: link for link in model_description}

        # Get all the enabled collidable points of the model. Materialize them
        # since they are iterated twice below.
        collidable_points = list(model_description.all_enabled_collidable_points())

        # Collision shapes may exist while none of their points is enabled.
        # In that case there is nothing to stack, and `jnp.vstack([])` would
        # raise. Return an empty object, like the no-shapes case above.
        if len(collidable_points) == 0:
            return ContactParameters()

        # Extract the positions L_p_C of the collidable points w.r.t. the link frames
        # they are rigidly attached to.
        points = jnp.vstack([cp.position for cp in collidable_points])

        # Extract the indices of the links to which the collidable points are
        # rigidly attached.
        link_index_of_points = tuple(
            links_dict[cp.parent_link.name].index for cp in collidable_points
        )

        # Build the ContactParameters object. Only enabled points were
        # collected, so every flag is True.
        cp = ContactParameters(
            point=points,
            body=link_index_of_points,
            enabled=tuple(True for _ in link_index_of_points),
        )

        assert cp.point.shape[1] == 3, cp.point.shape[1]
        assert cp.point.shape[0] == len(cp.body), cp.point.shape[0]

        return cp
|
837
|
+
|
838
|
+
|
839
|
+
@jax_dataclasses.pytree_dataclass
class FrameParameters(JaxsimDataclass):
    """
    Class storing the frame parameters of a model.

    Attributes:
        name: A tuple of strings with the names of the frames.
        body:
            A tuple of integers representing, for each frame, the index of
            the body (link) to which it is rigidly attached.
        transform: The transforms of the frames w.r.t. their parent link.

    Note:
        Contrarily to LinkParameters and JointParameters, this class is not meant
        to be created with vmap. This is because the `name` attribute must be `Static`.
    """

    name: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple)

    body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple)

    transform: jtp.Array = dataclasses.field(default_factory=lambda: jnp.array([]))

    @staticmethod
    def build_from(model_description: ModelDescription) -> FrameParameters:
        """
        Create a FrameParameters object from a model description.

        Args:
            model_description: The model description to consider.

        Returns:
            The FrameParameters object. It is empty when the model has no frames.
        """

        frames = model_description.frames

        if len(frames) == 0:
            return FrameParameters()

        links = model_description.links_dict

        # Gather, for every frame, its name, the index of its parent link, and
        # its pose w.r.t. that link.
        frame_names = []
        parent_indices = []
        poses = []
        for frame in frames:
            frame_names.append(frame.name)
            parent_indices.append(links[frame.parent.name].index)
            poses.append(frame.pose)

        stacked_poses = jnp.atleast_3d(jnp.stack(poses))

        fp = FrameParameters(
            name=tuple(frame_names),
            transform=stacked_poses.astype(float),
            body=tuple(parent_indices),
        )

        assert fp.transform.shape[1:] == (4, 4), fp.transform.shape[1:]
        assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0]

        return fp
|