jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -129
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +87 -16
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +62 -24
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +607 -225
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -80
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -55
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev188.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,901 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+
5
+ import jax.lax
6
+ import jax.numpy as jnp
7
+ import jax_dataclasses
8
+ import numpy as np
9
+ import numpy.typing as npt
10
+ from jax_dataclasses import Static
11
+
12
+ import jaxsim.typing as jtp
13
+ from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion
14
+ from jaxsim.parsers.descriptions import JointDescription, ModelDescription
15
+ from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
16
+
17
+
18
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class KinDynParameters(JaxsimDataclass):
    r"""
    Class storing the kinematic and dynamic parameters of a model.

    Attributes:
        link_names: The names of the links, sorted by link index.
        parent_array: The parent array :math:`\lambda(i)` of the model.
        support_body_array_bool:
            The boolean support parent array :math:`\kappa_{b}(i)` of the model.
        link_parameters: The parameters of the links.
        frame_parameters: The parameters of the frames.
        contact_parameters: The parameters of the collidable points.
        joint_model: The joint model of the model.
        joint_parameters: The parameters of the joints.

    Note:
        Equality and hashing only consider structural properties of the model:
        number of links and joints, frame names and parent bodies, parent array
        and boolean support array. Numerical parameters such as link masses and
        inertias do not affect them.
    """

    # Static
    link_names: Static[tuple[str, ...]]
    _parent_array: Static[HashedNumpyArray]
    _support_body_array_bool: Static[HashedNumpyArray]

    # Links
    link_parameters: LinkParameters

    # Contacts
    contact_parameters: ContactParameters

    # Frames
    frame_parameters: FrameParameters

    # Joints
    joint_model: JointModel
    joint_parameters: JointParameters | None

    @property
    def parent_array(self) -> jtp.Vector:
        r"""
        Return the parent array :math:`\lambda(i)` of the model.

        When built with `KinDynParameters.build`, the entry of the base link
        (index 0) is -1 since its parent is not defined.
        """
        return self._parent_array.get()

    @property
    def support_body_array_bool(self) -> jtp.Matrix:
        r"""
        Return the boolean support parent array :math:`\kappa_{b}(i)` of the model.

        Row ``i`` is a boolean mask over all links that is True for the links
        belonging to the support of link ``i`` (link ``i`` included).
        """
        return self._support_body_array_bool.get()

    @staticmethod
    def build(model_description: ModelDescription) -> KinDynParameters:
        """
        Construct the kinematic and dynamic parameters of the model.

        Args:
            model_description: The parsed model description to consider.

        Returns:
            The kinematic and dynamic parameters of the model.

        Note:
            This class is meant to ease the management of parametric models in
            an automatic differentiation context.
        """

        # Extract the links ordered by their index.
        # The link index corresponds to the body index ∈ [0, num_bodies - 1].
        ordered_links = sorted(
            list(model_description.links_dict.values()),
            key=lambda l: l.index,
        )

        # Extract the joints ordered by their index.
        # The joint index matches the index of its child link, therefore it starts
        # from 1. Keep this in mind since this 1-indexing might introduce bugs.
        ordered_joints = sorted(
            list(model_description.joints_dict.values()),
            key=lambda j: j.index,
        )

        # ================
        # Links properties
        # ================

        # Create a list of link parameters objects.
        link_parameters_list = [
            LinkParameters.build_from_spatial_inertia(index=link.index, M=link.inertia)
            for link in ordered_links
        ]

        # Create a vectorized object of link parameters by stacking each leaf
        # of the per-link PyTrees along a new leading axis.
        link_parameters = jax.tree.map(lambda *l: jnp.stack(l), *link_parameters_list)

        # =================
        # Joints properties
        # =================

        # Create a list of joint parameters objects.
        joint_parameters_list = [
            JointParameters.build_from_joint_description(joint_description=joint)
            for joint in ordered_joints
        ]

        # Create a vectorized object of joint parameters.
        # Models without joints get an object with empty leaves, since
        # jax.tree.map would have nothing to stack.
        joint_parameters = (
            jax.tree.map(lambda *l: jnp.stack(l), *joint_parameters_list)
            if len(ordered_joints) > 0
            else JointParameters(
                index=jnp.array([], dtype=int),
                friction_static=jnp.array([], dtype=float),
                friction_viscous=jnp.array([], dtype=float),
                position_limits_min=jnp.array([], dtype=float),
                position_limits_max=jnp.array([], dtype=float),
                position_limit_spring=jnp.array([], dtype=float),
                position_limit_damper=jnp.array([], dtype=float),
            )
        )

        # Create an object that defines the joint model (parent-to-child transforms).
        joint_model = JointModel.build(description=model_description)

        # ===================
        # Contacts properties
        # ===================

        # Create the object storing the parameters of collidable points.
        # Unlike LinkParameters and JointParameters, this object is not created by
        # stacking per-element objects: its "body" attribute must be Static for
        # JIT-related reasons, and tree_map would not consider it as a leaf.
        contact_parameters = ContactParameters.build_from(
            model_description=model_description
        )

        # =================
        # Frames properties
        # =================

        # Create the object storing the parameters of frames.
        # Unlike LinkParameters and JointParameters, this object is not created by
        # stacking per-element objects: its "name" attribute must be Static for
        # JIT-related reasons, and tree_map would not consider it as a leaf.
        frame_parameters = FrameParameters.build_from(
            model_description=model_description
        )

        # ===============
        # Tree properties
        # ===============

        # Build the parent array λ(i) of the model.
        # Note: the parent of the base link is not defined, therefore the
        # first entry is set to -1.
        parent_array_dict = {
            link.index: link.parent.index
            for link in ordered_links
            if link.parent is not None
        }
        parent_array = jnp.array([-1, *list(parent_array_dict.values())], dtype=int)

        # Instead of building the support parent array κ(i) for each link of the model,
        # that has a variable length depending on the number of links connecting the
        # root to the i-th link, we build the corresponding boolean version.
        # Given a link index i, the boolean support parent array κb(i) is an array
        # with the same number of elements of λ(i) having the j-th element set to True
        # if the j-th link is in the support parent array κ(i), False otherwise.
        # We store the boolean κb(i) as static attribute of the PyTree so that
        # algorithms that need to access it can be jit-compiled.
        def compute_κb(link_index: jtp.IntLike) -> jtp.Vector:

            κb = jnp.zeros(len(ordered_links), dtype=bool)

            carry0 = κb, link_index

            def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:

                κb, active_link_index = carry

                # Links are visited from the highest index down to 0. Whenever the
                # visited index matches the active link, mark it and move up to
                # its parent.
                κb, active_link_index = jax.lax.cond(
                    pred=(i == active_link_index),
                    false_fun=lambda: (κb, active_link_index),
                    true_fun=lambda: (
                        κb.at[active_link_index].set(True),
                        parent_array[active_link_index],
                    ),
                )

                return (κb, active_link_index), None

            (κb, _), _ = jax.lax.scan(
                f=scan_body,
                init=carry0,
                xs=jnp.flip(jnp.arange(start=0, stop=len(ordered_links))),
            )

            return κb

        support_body_array_bool = jax.vmap(compute_κb)(
            jnp.arange(start=0, stop=len(ordered_links))
        )

        # =================================
        # Build and return KinDynParameters
        # =================================

        return KinDynParameters(
            link_names=tuple(l.name for l in ordered_links),
            _parent_array=HashedNumpyArray(array=parent_array),
            _support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
            link_parameters=link_parameters,
            joint_model=joint_model,
            joint_parameters=joint_parameters,
            contact_parameters=contact_parameters,
            frame_parameters=frame_parameters,
        )

    def __eq__(self, other: object) -> bool:
        """
        Compare two objects by their structural hash.

        Returns False for objects that are not KinDynParameters.
        """

        if not isinstance(other, KinDynParameters):
            return False

        return hash(self) == hash(other)

    def __hash__(self) -> int:
        """
        Hash the structural (static) properties of the model.
        """

        return hash(
            (
                hash(self.number_of_links()),
                hash(self.number_of_joints()),
                hash(self.frame_parameters.name),
                hash(self.frame_parameters.body),
                hash(self._parent_array),
                hash(self._support_body_array_bool),
            )
        )

    # =============================
    # Helpers to extract parameters
    # =============================

    def number_of_links(self) -> int:
        """
        Return the number of links of the model.

        Returns:
            The number of links of the model.
        """

        return len(self.link_names)

    def number_of_joints(self) -> int:
        """
        Return the number of joints of the model.

        Returns:
            The number of joints of the model.

        Note:
            The joint model stores an additional entry at index 0 that is not a
            joint of the model, hence the subtraction.
        """

        return len(self.joint_model.joint_names) - 1

    def number_of_frames(self) -> int:
        """
        Return the number of frames of the model.

        Returns:
            The number of frames of the model.
        """

        return len(self.frame_parameters.name)

    def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector:
        r"""
        Return the support parent array :math:`\kappa(i)` of a link.

        Args:
            link_index: The index of the link.

        Returns:
            The support parent array :math:`\kappa(i)` of the link.

        Note:
            This method returns a variable-length vector. In jit-compiled functions,
            it's better to use the (static) boolean version `support_body_array_bool`.
        """

        return jnp.array(
            jnp.where(self.support_body_array_bool[link_index])[0], dtype=int
        )

    # ========================
    # Quantities used by RBDAs
    # ========================

    @jax.jit
    def links_spatial_inertia(self) -> jtp.Array:
        """
        Return the spatial inertia of all links of the model.

        Returns:
            The stacked :math:`6 \times 6` spatial inertia matrices of all links.
        """

        return jax.vmap(LinkParameters.spatial_inertia)(self.link_parameters)

    @jax.jit
    def tree_transforms(self) -> jtp.Array:
        r"""
        Return the fixed tree transforms of the model, in adjoint form.

        Returns:
            An array of shape ``(1 + n_joints, 6, 6)`` whose entry ``i >= 1``
            is the adjoint matrix
            :math:`{}^{\text{pre}(i)} \mathbf{X}_{\lambda(i)}`
            of the i-th joint. Entry 0 is a zero-filled placeholder, since
            there is no joint connecting the base to a parent link.
        """

        pre_Xi_λ = jax.vmap(
            lambda i: self.joint_model.parent_H_predecessor(joint_index=i)
            .inverse()
            .adjoint()
        )(jnp.arange(1, self.number_of_joints() + 1))

        return jnp.vstack(
            [
                jnp.zeros(shape=(1, 6, 6), dtype=float),
                pre_Xi_λ,
            ]
        )

    @jax.jit
    def joint_transforms(
        self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
    ) -> jtp.Array:
        r"""
        Return the transforms of the joints, in adjoint form.

        Args:
            joint_positions: The joint positions.
            base_transform: The homogeneous matrix defining the base pose.

        Returns:
            The stacked adjoint transforms
            :math:`{}^{i} \mathbf{X}_{\lambda(i)}(s)`
            of each joint, with the base pose at index 0.
        """

        return self.joint_transforms_and_motion_subspaces(
            joint_positions=joint_positions,
            base_transform=base_transform,
        )[0]

    @jax.jit
    def joint_motion_subspaces(
        self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
    ) -> jtp.Array:
        r"""
        Return the motion subspaces of the joints.

        Args:
            joint_positions: The joint positions.
            base_transform: The homogeneous matrix defining the base pose.

        Returns:
            The stacked motion subspaces :math:`\mathbf{S}(s)` of each joint.
            Entry 0 is a zero placeholder for the world-to-base joint.
        """

        return self.joint_transforms_and_motion_subspaces(
            joint_positions=joint_positions,
            base_transform=base_transform,
        )[1]

    @jax.jit
    def joint_transforms_and_motion_subspaces(
        self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
    ) -> tuple[jtp.Array, jtp.Array]:
        r"""
        Return the transforms and the motion subspaces of the joints.

        Args:
            joint_positions: The joint positions.
            base_transform: The homogeneous matrix defining the base pose.

        Returns:
            A tuple containing the stacked adjoint transforms
            :math:`{}^{i} \mathbf{X}_{\lambda(i)}(s)`
            and the stacked motion subspaces :math:`\mathbf{S}(s)` of each joint.

        Note:
            The first transform, at index 0, provides the pose of the base link
            w.r.t. the world frame. For both floating-base and fixed-base systems,
            it takes into account the base pose and the optional transform
            between the root frame of the model and the base link.
        """

        # Rename the base transform.
        W_H_B = base_transform

        # Extract the parent-to-predecessor fixed transforms of the joints.
        # Index 0 is the identity since the base pose is handled by W_H_B.
        λ_H_pre = jnp.vstack(
            [
                jnp.eye(4)[jnp.newaxis],
                self.joint_model.λ_H_pre[1 : 1 + self.number_of_joints()],
            ]
        )

        # Compute the transforms and motion subspaces of the joints.
        # vmap cannot be applied when there are no joints, hence the empty arrays.
        if self.number_of_joints() == 0:
            pre_H_suc_J, S_J = jnp.empty((0, 4, 4)), jnp.empty((0, 6, 1))
        else:
            pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)(
                jnp.array(self.joint_model.joint_types[1:]).astype(int),
                jnp.array(joint_positions),
                jnp.array([j.axis for j in self.joint_model.joint_axis]),
            )

        # Extract the transforms and motion subspaces of the joints.
        # We stack the base transform W_H_B at index 0, and a dummy motion subspace
        # for either the fixed or free-floating joint connecting the world to the base.
        pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J])
        S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])

        # Extract the successor-to-child fixed transforms.
        # Note that here we include also the index 0 since suc_H_child[0] stores the
        # optional pose of the base link w.r.t. the root frame of the model.
        # This is supported by SDF when the base link <pose> element is defined.
        suc_H_i = self.joint_model.suc_H_i[jnp.arange(0, 1 + self.number_of_joints())]

        # Compute the overall transforms from the parent to the child of each joint by
        # composing all the components of our joint model. The composition gives
        # λ_H_i; inverting it yields i_H_λ, returned in adjoint (6x6) form.
        i_X_λ = jax.vmap(
            lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: Adjoint.from_transform(
                transform=λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, inverse=True
            )
        )(λ_H_pre, pre_H_suc, suc_H_i)

        return i_X_λ, S

    # ============================
    # Helpers to update parameters
    # ============================

    def set_link_mass(
        self, link_index: jtp.IntLike, mass: jtp.FloatLike
    ) -> KinDynParameters:
        """
        Set the mass of a link.

        Args:
            link_index: The index of the link.
            mass: The mass of the link.

        Returns:
            The updated kinematic and dynamic parameters of the model.
            The original object is not modified.
        """

        link_parameters = self.link_parameters.replace(
            mass=self.link_parameters.mass.at[link_index].set(mass)
        )

        return self.replace(link_parameters=link_parameters)

    def set_link_inertia(
        self, link_index: jtp.IntLike, inertia: jtp.MatrixLike
    ) -> KinDynParameters:
        r"""
        Set the inertia tensor of a link.

        Args:
            link_index: The index of the link.
            inertia: The :math:`3 \times 3` inertia tensor of the link.

        Returns:
            The updated kinematic and dynamic parameters of the model.
            The original object is not modified.
        """

        # Only the upper-triangular elements of the tensor are stored.
        inertia_elements = LinkParameters.flatten_inertia_tensor(I=inertia)

        link_parameters = self.link_parameters.replace(
            inertia_elements=self.link_parameters.inertia_elements.at[link_index].set(
                inertia_elements
            )
        )

        return self.replace(link_parameters=link_parameters)
500
+
501
+
502
@jax_dataclasses.pytree_dataclass
class JointParameters(JaxsimDataclass):
    """
    Container of the per-joint parameters.

    Attributes:
        index: The index of the joint.
        friction_static: The static friction of the joint.
        friction_viscous: The viscous friction of the joint.
        position_limits_min: The lower position limit of the joint.
        position_limits_max: The upper position limit of the joint.
        position_limit_spring: The spring constant of the position limit.
        position_limit_damper: The damper constant of the position limit.

    Note:
        KinDynParameters stores a stacked (vectorized) instance of this class,
        holding the parameters of all joints along the leading axis.
    """

    index: jtp.Int

    friction_static: jtp.Float
    friction_viscous: jtp.Float

    position_limits_min: jtp.Float
    position_limits_max: jtp.Float

    position_limit_spring: jtp.Float
    position_limit_damper: jtp.Float

    @staticmethod
    def build_from_joint_description(
        joint_description: JointDescription,
    ) -> JointParameters:
        """
        Create the parameters of a single joint from its parsed description.

        Args:
            joint_description: The joint description to consider.

        Returns:
            The JointParameters object. The two position limits are sorted, so
            that the min/max fields are consistent regardless of their order in
            the description.
        """

        def to_float(value) -> jtp.Float:
            # Squeeze the value to a scalar (when possible) and cast it to float.
            return jnp.array(value).squeeze().astype(float)

        first_limit = joint_description.position_limit[0]
        second_limit = joint_description.position_limit[1]

        return JointParameters(
            index=jnp.array(joint_description.index).squeeze().astype(int),
            friction_static=to_float(joint_description.friction_static),
            friction_viscous=to_float(joint_description.friction_viscous),
            position_limits_min=jnp.minimum(first_limit, second_limit).astype(float),
            position_limits_max=jnp.maximum(first_limit, second_limit).astype(float),
            position_limit_spring=to_float(joint_description.position_limit_spring),
            position_limit_damper=to_float(joint_description.position_limit_damper),
        )
572
+
573
+
574
@jax_dataclasses.pytree_dataclass
class LinkParameters(JaxsimDataclass):
    r"""
    Class storing the parameters of a link.

    Attributes:
        index: The index of the link.
        mass: The mass of the link.
        inertia_elements:
            The unique (upper-triangular, row-major) elements of the
            :math:`3 \times 3` inertia tensor of the link.
        center_of_mass:
            The translation :math:`{}^L \mathbf{p}_{\text{CoM}}` between the origin
            of the link frame and the link's center of mass, expressed in the
            coordinates of the link frame.

    Note:
        This class is used inside KinDynParameters to store the vectorized set
        of link parameters.
    """

    index: jtp.Int

    mass: jtp.Float
    center_of_mass: jtp.Vector
    inertia_elements: jtp.Vector

    @staticmethod
    def build_from_spatial_inertia(index: jtp.IntLike, M: jtp.Matrix) -> LinkParameters:
        r"""
        Build a LinkParameters object from a :math:`6 \times 6` spatial inertia matrix.

        Args:
            index: The index of the link.
            M: The :math:`6 \times 6` spatial inertia matrix of the link.

        Returns:
            The LinkParameters object.
        """

        # Extract the link parameters from the 6D spatial inertia.
        m, L_p_CoM, I_CoM = Inertia.to_params(M=M)

        # Extract only the necessary elements of the inertia tensor.
        inertia_elements = LinkParameters.flatten_inertia_tensor(I=I_CoM)

        return LinkParameters(
            index=jnp.array(index).squeeze().astype(int),
            mass=jnp.array(m).squeeze().astype(float),
            center_of_mass=jnp.atleast_1d(jnp.array(L_p_CoM).squeeze()).astype(float),
            inertia_elements=inertia_elements.astype(float),
        )

    @staticmethod
    def build_from_inertial_parameters(
        index: jtp.IntLike, m: jtp.FloatLike, I: jtp.MatrixLike, c: jtp.VectorLike
    ) -> LinkParameters:
        r"""
        Build a LinkParameters object from the inertial parameters of a link.

        Args:
            index: The index of the link.
            m: The mass of the link.
            I: The :math:`3 \times 3` inertia tensor of the link.
            c: The translation between the link frame and the link's center of mass.

        Returns:
            The LinkParameters object.

        Note:
            ``I`` and ``c`` may be any array-like (lists, tuples, NumPy or JAX
            arrays); they are converted to JAX arrays before being processed.
        """

        # Extract only the necessary elements of the inertia tensor.
        inertia_elements = LinkParameters.flatten_inertia_tensor(I=I)

        return LinkParameters(
            index=jnp.array(index).squeeze().astype(int),
            mass=jnp.array(m).squeeze().astype(float),
            center_of_mass=jnp.atleast_1d(jnp.asarray(c).squeeze()).astype(float),
            inertia_elements=inertia_elements.astype(float),
        )

    @staticmethod
    def build_from_flat_parameters(
        index: jtp.IntLike, parameters: jtp.VectorLike
    ) -> LinkParameters:
        """
        Build a LinkParameters object from a flat vector of parameters.

        Args:
            index: The index of the link.
            parameters:
                The flat vector of parameters, laid out as
                ``[mass, com_x, com_y, com_z, *inertia_elements]``
                (the same layout produced by `flat_parameters`).

        Returns:
            The LinkParameters object.
        """

        # Accept any array-like input (the slicing below requires an array).
        parameters = jnp.asarray(parameters)

        index = jnp.array(index).squeeze().astype(int)

        m = jnp.array(parameters[0]).squeeze().astype(float)
        c = jnp.atleast_1d(parameters[1:4].squeeze()).astype(float)
        inertia_elements = jnp.atleast_1d(parameters[4:].squeeze()).astype(float)

        return LinkParameters(
            index=index, mass=m, inertia_elements=inertia_elements, center_of_mass=c
        )

    @staticmethod
    def flat_parameters(params: LinkParameters) -> jtp.Vector:
        """
        Return the parameters of a link as a flat vector.

        Args:
            params: The link parameters.

        Returns:
            The parameters of the link as a flat vector, laid out as
            ``[mass, *center_of_mass, *inertia_elements]``.
        """

        return (
            jnp.hstack(
                [params.mass, params.center_of_mass.squeeze(), params.inertia_elements]
            )
            .squeeze()
            .astype(float)
        )

    @staticmethod
    def inertia_tensor(params: LinkParameters) -> jtp.Matrix:
        r"""
        Return the :math:`3 \times 3` inertia tensor of a link.

        Args:
            params: The link parameters.

        Returns:
            The :math:`3 \times 3` inertia tensor of the link.
        """

        return LinkParameters.unflatten_inertia_tensor(
            inertia_elements=params.inertia_elements
        )

    @staticmethod
    def spatial_inertia(params: LinkParameters) -> jtp.Matrix:
        r"""
        Return the :math:`6 \times 6` spatial inertia matrix of a link.

        Args:
            params: The link parameters.

        Returns:
            The :math:`6 \times 6` spatial inertia matrix of the link.
        """

        return Inertia.to_sixd(
            mass=params.mass,
            I=LinkParameters.inertia_tensor(params),
            com=params.center_of_mass,
        )

    @staticmethod
    def flatten_inertia_tensor(I: jtp.Matrix) -> jtp.Vector:
        r"""
        Flatten a :math:`3 \times 3` inertia tensor into a vector of unique elements.

        Args:
            I: The :math:`3 \times 3` inertia tensor (any array-like).

        Returns:
            The upper-triangular elements of the inertia tensor, in row-major order.
        """

        return jnp.atleast_1d(jnp.asarray(I)[jnp.triu_indices(3)].squeeze())

    @staticmethod
    def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix:
        r"""
        Unflatten a vector of unique elements into a :math:`3 \times 3` inertia tensor.

        Args:
            inertia_elements:
                The upper-triangular elements of the inertia tensor, in row-major
                order (the layout produced by `flatten_inertia_tensor`).

        Returns:
            The symmetric :math:`3 \times 3` inertia tensor.
        """

        I_upper = (
            jnp.zeros([3, 3])
            .at[jnp.triu_indices(3)]
            .set(jnp.asarray(inertia_elements).squeeze())
        )

        # Mirror the strictly upper-triangular part into the lower triangle.
        # Unlike a value-based selection such as where(I, I, I.T), this
        # construction propagates gradients to both symmetric entries even when
        # an element is exactly zero (e.g. zero products of inertia).
        I = I_upper + jnp.triu(I_upper, k=1).T

        return jnp.atleast_2d(I).astype(float)
759
+
760
+
761
@jax_dataclasses.pytree_dataclass
class ContactParameters(JaxsimDataclass):
    """
    Class storing the contact parameters of a model.

    Attributes:
        body:
            A tuple of integers representing, for each collidable point, the index of
            the body (link) to which it is rigidly attached to.
        point:
            The translations between the link frame and the collidable point, expressed
            in the coordinates of the parent link frame.
        enabled:
            A tuple of booleans representing, for each collidable point, whether it is
            enabled or not in contact models.

    Note:
        Unlike LinkParameters and JointParameters, this class is not meant to be
        created by stacking per-element objects, because the `body` attribute
        must be `Static`.
    """

    body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple)

    point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([]))

    enabled: Static[tuple[bool, ...]] = dataclasses.field(default_factory=tuple)

    @property
    def indices_of_enabled_collidable_points(self) -> npt.NDArray:
        """
        Return the indices of the enabled collidable points.
        """
        return np.where(np.array(self.enabled))[0]

    @staticmethod
    def build_from(model_description: ModelDescription) -> ContactParameters:
        """
        Build a ContactParameters object from a model description.

        Args:
            model_description: The model description to consider.

        Returns:
            The ContactParameters object. An empty object is returned when the
            model has no collision shapes or no enabled collidable points.
        """

        if len(model_description.collision_shapes) == 0:
            return ContactParameters()

        # Get all the links so that we can take their updated index.
        links_dict = {link.name: link for link in model_description}

        # Get all the enabled collidable points of the model.
        collidable_points = model_description.all_enabled_collidable_points()

        # Collision shapes may exist while none of their points is enabled.
        # In that case there is nothing to stack (jnp.vstack([]) raises), so
        # return the same empty object used for models without collision shapes.
        if len(collidable_points) == 0:
            return ContactParameters()

        # Extract the positions L_p_C of the collidable points w.r.t. the link frames
        # they are rigidly attached to.
        points = jnp.vstack([cp.position for cp in collidable_points])

        # Extract the indices of the links to which the collidable points are rigidly
        # attached to.
        link_index_of_points = tuple(
            links_dict[cp.parent_link.name].index for cp in collidable_points
        )

        # Build the ContactParameters object.
        # All the extracted points are enabled by construction.
        cp = ContactParameters(
            point=points,
            body=link_index_of_points,
            enabled=tuple(True for _ in link_index_of_points),
        )

        # Internal consistency checks: one 3D position per collidable point.
        assert cp.point.shape[1] == 3, cp.point.shape[1]
        assert cp.point.shape[0] == len(cp.body), cp.point.shape[0]

        return cp
837
+
838
+
839
@jax_dataclasses.pytree_dataclass
class FrameParameters(JaxsimDataclass):
    """
    Container of the parameters of the additional frames of a model.

    Attributes:
        name: A tuple with the names of the frames.
        body:
            A tuple of integers holding, for each frame, the index of the
            body (link) the frame is rigidly attached to.
        transform:
            The stacked homogeneous transforms of the frames w.r.t. their
            parent link.

    Note:
        Unlike LinkParameters and JointParameters, instances of this class are
        not obtained by stacking per-element objects, because the `name`
        attribute has to be `Static`.
    """

    name: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple)

    body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple)

    transform: jtp.Array = dataclasses.field(default_factory=lambda: jnp.array([]))

    @staticmethod
    def build_from(model_description: ModelDescription) -> FrameParameters:
        """
        Collect the parameters of all frames found in a model description.

        Args:
            model_description: The model description to consider.

        Returns:
            The FrameParameters object, empty if the model defines no frames.
        """

        frames = model_description.frames

        if len(frames) == 0:
            return FrameParameters()

        links = model_description.links_dict

        # Gather, frame by frame, its name, its parent link index and its pose.
        frame_names, parent_indices, poses = zip(
            *(
                (frame.name, links[frame.parent.name].index, frame.pose)
                for frame in frames
            )
        )

        stacked_poses = jnp.atleast_3d(jnp.stack(poses)).astype(float)

        fp = FrameParameters(
            name=tuple(frame_names),
            transform=stacked_poses,
            body=tuple(parent_indices),
        )

        # Sanity checks: one 4x4 transform per frame.
        assert fp.transform.shape[1:] == (4, 4), fp.transform.shape[1:]
        assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0]

        return fp