jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev366__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +86 -74
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/link.py +2 -2
  26. jaxsim/parsers/rod/utils.py +7 -8
  27. jaxsim/rbda/__init__.py +7 -0
  28. jaxsim/rbda/aba.py +295 -0
  29. jaxsim/rbda/collidable_points.py +142 -0
  30. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  31. jaxsim/rbda/forward_kinematics.py +113 -0
  32. jaxsim/rbda/jacobian.py +201 -0
  33. jaxsim/rbda/rnea.py +237 -0
  34. jaxsim/rbda/soft_contacts.py +296 -0
  35. jaxsim/rbda/utils.py +152 -0
  36. jaxsim/terrain/__init__.py +2 -0
  37. jaxsim/utils/__init__.py +1 -4
  38. jaxsim/utils/hashless.py +18 -0
  39. jaxsim/utils/jaxsim_dataclass.py +281 -30
  40. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/METADATA +4 -6
  41. jaxsim-0.2.dev366.dist-info/RECORD +64 -0
  42. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/WHEEL +1 -1
  43. jaxsim/high_level/__init__.py +0 -2
  44. jaxsim/high_level/common.py +0 -11
  45. jaxsim/high_level/joint.py +0 -148
  46. jaxsim/high_level/link.py +0 -259
  47. jaxsim/high_level/model.py +0 -1686
  48. jaxsim/math/conv.py +0 -114
  49. jaxsim/math/joint.py +0 -102
  50. jaxsim/math/plucker.py +0 -100
  51. jaxsim/physics/__init__.py +0 -12
  52. jaxsim/physics/algos/__init__.py +0 -0
  53. jaxsim/physics/algos/aba.py +0 -254
  54. jaxsim/physics/algos/aba_motors.py +0 -284
  55. jaxsim/physics/algos/forward_kinematics.py +0 -79
  56. jaxsim/physics/algos/jacobian.py +0 -98
  57. jaxsim/physics/algos/rnea.py +0 -180
  58. jaxsim/physics/algos/rnea_motors.py +0 -196
  59. jaxsim/physics/algos/soft_contacts.py +0 -523
  60. jaxsim/physics/algos/utils.py +0 -69
  61. jaxsim/physics/model/__init__.py +0 -0
  62. jaxsim/physics/model/ground_contact.py +0 -53
  63. jaxsim/physics/model/physics_model.py +0 -388
  64. jaxsim/physics/model/physics_model_state.py +0 -283
  65. jaxsim/simulation/__init__.py +0 -4
  66. jaxsim/simulation/integrators.py +0 -393
  67. jaxsim/simulation/ode.py +0 -290
  68. jaxsim/simulation/ode_data.py +0 -96
  69. jaxsim/simulation/ode_integration.py +0 -62
  70. jaxsim/simulation/simulator.py +0 -543
  71. jaxsim/simulation/simulator_callbacks.py +0 -79
  72. jaxsim/simulation/utils.py +0 -15
  73. jaxsim/sixd/__init__.py +0 -2
  74. jaxsim/utils/oop.py +0 -536
  75. jaxsim/utils/vmappable.py +0 -117
  76. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  77. /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
  78. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/LICENSE +0 -0
  79. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,777 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+
5
+ import jax.lax
6
+ import jax.numpy as jnp
7
+ import jax_dataclasses
8
+ import jaxlie
9
+ from jax_dataclasses import Static
10
+
11
+ import jaxsim.typing as jtp
12
+ from jaxsim.math import Inertia, JointModel, supported_joint_motion
13
+ from jaxsim.parsers.descriptions import JointDescription, ModelDescription
14
+ from jaxsim.utils import JaxsimDataclass
15
+
16
+
17
@jax_dataclasses.pytree_dataclass
class KynDynParameters(JaxsimDataclass):
    r"""
    Class storing the kinematic and dynamic parameters of a model.

    Attributes:
        link_names: The names of the links.
        parent_array: The parent array :math:`\lambda(i)` of the model.
        support_body_array_bool:
            The boolean support parent array :math:`\kappa_{b}(i)` of the model.
        link_parameters: The parameters of the links.
        contact_parameters: The parameters of the collidable points.
        joint_model: The joint model of the model.
        joint_parameters: The parameters of the joints.
    """

    # Static
    link_names: Static[tuple[str]]
    parent_array: Static[jtp.Vector]
    support_body_array_bool: Static[jtp.Matrix]

    # Links
    link_parameters: LinkParameters

    # Contacts
    contact_parameters: ContactParameters

    # Joints
    joint_model: JointModel
    joint_parameters: JointParameters | None

    @staticmethod
    def build(model_description: ModelDescription) -> KynDynParameters:
        """
        Construct the kinematic and dynamic parameters of the model.

        Args:
            model_description: The parsed model description to consider.

        Returns:
            The kinematic and dynamic parameters of the model.

        Note:
            This class is meant to ease the management of parametric models in
            an automatic differentiation context.
        """

        # Extract the links ordered by their index.
        # The link index corresponds to the body index ∈ [0, num_bodies - 1].
        ordered_links = sorted(
            list(model_description.links_dict.values()),
            key=lambda l: l.index,
        )

        # Extract the joints ordered by their index.
        # The joint index matches the index of its child link, therefore it starts
        # from 1. Keep this in mind since this 1-indexing might introduce bugs.
        ordered_joints = sorted(
            list(model_description.joints_dict.values()),
            key=lambda j: j.index,
        )

        # ================
        # Links properties
        # ================

        # Create a list of link parameters objects.
        link_parameters_list = [
            LinkParameters.build_from_spatial_inertia(index=link.index, M=link.inertia)
            for link in ordered_links
        ]

        # Create a vectorized object of link parameters by stacking every leaf.
        link_parameters = jax.tree_util.tree_map(
            lambda *l: jnp.stack(l), *link_parameters_list
        )

        # =================
        # Joints properties
        # =================

        # Create a list of joint parameters objects.
        joint_parameters_list = [
            JointParameters.build_from_joint_description(joint_description=joint)
            for joint in ordered_joints
        ]

        # Create a vectorized object of joint parameters.
        # tree_map cannot stack an empty list, so models without joints get an
        # object whose leaves are empty arrays.
        joint_parameters = (
            jax.tree_util.tree_map(lambda *l: jnp.stack(l), *joint_parameters_list)
            if len(ordered_joints) > 0
            else JointParameters(
                index=jnp.array([], dtype=int),
                friction_static=jnp.array([], dtype=float),
                friction_viscous=jnp.array([], dtype=float),
                position_limits_min=jnp.array([], dtype=float),
                position_limits_max=jnp.array([], dtype=float),
                position_limit_spring=jnp.array([], dtype=float),
                position_limit_damper=jnp.array([], dtype=float),
            )
        )

        # Create an object that defines the joint model (parent-to-child transforms).
        joint_model = JointModel.build(description=model_description)

        # ===================
        # Contacts properties
        # ===================

        # Create the object storing the parameters of collidable points.
        # Note that, contrarily to LinkParameters and JointParameters, this object
        # is not created by stacking per-element objects with tree_map. This is
        # because the "body" attribute of the object must be Static for JIT-related
        # reasons, and tree_map would not consider it as a leaf.
        contact_parameters = ContactParameters.build_from(
            model_description=model_description
        )

        # ===============
        # Tree properties
        # ===============

        # Build the parent array λ(i) of the model.
        # Note: the parent of the base link is not set since it's not defined.
        # The -1 placeholder is prepended at position 0, which assumes the base
        # link is the only parent-less link and has index 0 (TODO confirm in the
        # parser). Values follow link-index order because ordered_links is sorted.
        parent_array_dict = {
            link.index: link.parent.index
            for link in ordered_links
            if link.parent is not None
        }
        parent_array = jnp.array([-1] + list(parent_array_dict.values()), dtype=int)

        # Instead of building the support parent array κ(i) for each link of the model,
        # that has a variable length depending on the number of links connecting the
        # root to the i-th link, we build the corresponding boolean version.
        # Given a link index i, the boolean support parent array κb(i) is an array
        # with the same number of elements of λ(i) having the i-th element set to True
        # if the i-th link is in the support parent array κ(i), False otherwise.
        # We store the boolean κb(i) as static attribute of the PyTree so that
        # algorithms that need to access it can be jit-compiled.
        def κb(link_index: jtp.IntLike) -> jtp.Vector:
            κb = jnp.zeros(len(ordered_links), dtype=bool)

            carry0 = κb, link_index

            def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:

                κb, active_link_index = carry

                # When the scanned index reaches the currently active link, mark
                # it as a supporting body and move on to its parent.
                κb, active_link_index = jax.lax.cond(
                    pred=(i == active_link_index),
                    false_fun=lambda: (κb, active_link_index),
                    true_fun=lambda: (
                        κb.at[active_link_index].set(True),
                        parent_array[active_link_index],
                    ),
                )

                return (κb, active_link_index), None

            # Scan the link indices in decreasing order. This visits every
            # ancestor only if each parent has a lower index than its children,
            # presumably guaranteed by the link ordering (TODO confirm).
            (κb, _), _ = jax.lax.scan(
                f=scan_body,
                init=carry0,
                xs=jnp.flip(jnp.arange(start=0, stop=len(ordered_links))),
            )

            return κb

        support_body_array_bool = jax.vmap(κb)(
            jnp.arange(start=0, stop=len(ordered_links))
        )

        # =================================
        # Build and return KynDynParameters
        # =================================

        return KynDynParameters(
            link_names=tuple(l.name for l in ordered_links),
            parent_array=parent_array,
            support_body_array_bool=support_body_array_bool,
            link_parameters=link_parameters,
            joint_model=joint_model,
            joint_parameters=joint_parameters,
            contact_parameters=contact_parameters,
        )

    def __eq__(self, other: KynDynParameters) -> bool:
        """
        Compare the topology of two KynDynParameters objects.

        Only the number of links, the number of joints, and the parent array are
        compared. Link, joint, and contact parameters are ignored, which keeps
        this method consistent with `__hash__`.

        Args:
            other: The object to compare with.

        Returns:
            True if both objects describe the same kinematic tree, False otherwise.
        """

        if not isinstance(other, KynDynParameters):
            return False

        # The parent arrays are integer arrays: compare them exactly, and convert
        # the JAX result into a Python bool as the signature promises.
        return (
            self.number_of_links() == other.number_of_links()
            and self.number_of_joints() == other.number_of_joints()
            and bool(jnp.array_equal(self.parent_array, other.parent_array))
        )

    def __hash__(self) -> int:
        """
        Hash the topology of the model.

        Returns:
            A hash combining the number of links, the number of joints, and the
            parent array, i.e. the same quantities compared by `__eq__`.
        """

        h = (
            hash(self.number_of_links()),
            hash(self.number_of_joints()),
            hash(tuple(self.parent_array.tolist())),
        )

        return hash(h)

    # =============================
    # Helpers to extract parameters
    # =============================

    def number_of_links(self) -> int:
        """
        Return the number of links of the model.

        Returns:
            The number of links of the model.
        """

        return len(self.link_names)

    def number_of_joints(self) -> int:
        """
        Return the number of joints of the model.

        Returns:
            The number of joints of the model.
        """

        # One entry of joint_names is excluded from the count; presumably it is
        # the world-to-base placeholder at index 0 (TODO confirm in JointModel).
        return len(self.joint_model.joint_names) - 1

    def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector:
        r"""
        Return the support parent array :math:`\kappa(i)` of a link.

        Args:
            link_index: The index of the link.

        Returns:
            The support parent array :math:`\kappa(i)` of the link.

        Note:
            This method returns a variable-length vector. In jit-compiled functions,
            it's better to use the (static) boolean version `support_body_array_bool`.
        """

        return jnp.array(
            jnp.where(self.support_body_array_bool[link_index])[0], dtype=int
        )

    # ========================
    # Quantities used by RBDAs
    # ========================

    @jax.jit
    def links_spatial_inertia(self) -> jtp.Array:
        """
        Return the spatial inertia of all links of the model.

        Returns:
            The stacked 6×6 spatial inertia matrices of all links of the model.
        """

        return jax.vmap(LinkParameters.spatial_inertia)(self.link_parameters)

    @jax.jit
    def tree_transforms(self) -> jtp.Array:
        r"""
        Return the tree transforms of the model.

        Returns:
            The stacked 6×6 transforms
            :math:`{}^{\text{pre}(i)} \mathbf{X}_{\lambda(i)}`
            of all joints of the model, i.e. the adjoint of the inverse of the
            fixed parent-to-predecessor transform. Index 0 holds a zero
            placeholder since the base link has no parent.
        """

        # parent_H_predecessor is used as a raw 4×4 homogeneous matrix elsewhere
        # in this class, so wrap it in an SE3 object before inverting it.
        pre_Xi_λ = jax.vmap(
            lambda i: jaxlie.SE3.from_matrix(
                self.joint_model.parent_H_predecessor(joint_index=i)
            )
            .inverse()
            .adjoint()
        )(jnp.arange(1, self.number_of_joints() + 1))

        return jnp.vstack(
            [
                jnp.zeros(shape=(1, 6, 6), dtype=float),
                pre_Xi_λ,
            ]
        )

    @jax.jit
    def joint_transforms(
        self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
    ) -> jtp.Array:
        r"""
        Return the transforms of the joints.

        Args:
            joint_positions: The joint positions.
            base_transform: The homogeneous matrix defining the base pose.

        Returns:
            The stacked 6×6 transforms
            :math:`{}^{i} \mathbf{X}_{\lambda(i)}(s)`
            of each joint.
        """

        return self.joint_transforms_and_motion_subspaces(
            joint_positions=joint_positions,
            base_transform=base_transform,
        )[0]

    @jax.jit
    def joint_motion_subspaces(
        self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
    ) -> jtp.Array:
        r"""
        Return the motion subspaces of the joints.

        Args:
            joint_positions: The joint positions.
            base_transform: The homogeneous matrix defining the base pose.

        Returns:
            The stacked motion subspaces :math:`\mathbf{S}(s)` of each joint.
        """

        return self.joint_transforms_and_motion_subspaces(
            joint_positions=joint_positions,
            base_transform=base_transform,
        )[1]

    @jax.jit
    def joint_transforms_and_motion_subspaces(
        self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
    ) -> tuple[jtp.Array, jtp.Array]:
        r"""
        Return the transforms and the motion subspaces of the joints.

        Args:
            joint_positions: The joint positions.
            base_transform: The homogeneous matrix defining the base pose.

        Returns:
            A tuple containing the stacked 6×6 transforms
            :math:`{}^{i} \mathbf{X}_{\lambda(i)}(s)`
            and the stacked motion subspaces :math:`\mathbf{S}(s)` of each joint.

        Note:
            The first transform, at index 0, is built from the base pose
            (``base_transform``). For both floating-base and fixed-base systems,
            it also takes into account the optional transform between the root
            frame of the model and the base link. Like the other entries, it is
            returned as the adjoint of the inverse of the composed transform.
        """

        # Rename the base transform.
        W_H_B = base_transform

        # Extract the parent-to-predecessor fixed transforms of the joints.
        # Index 0 (world-to-base) uses the identity.
        λ_H_pre = jnp.vstack(
            [
                jnp.eye(4)[jnp.newaxis],
                jax.vmap(
                    lambda i: self.joint_model.parent_H_predecessor(joint_index=i)
                )(jnp.arange(1, 1 + self.number_of_joints())),
            ]
        )

        # Compute the transforms and motion subspaces of the joints.
        # TODO: understand how to use joint_indices to access joint_types, right now
        # it fails when used within a JIT context.
        pre_H_suc_and_S = [
            supported_joint_motion(
                joint_type=self.joint_model.joint_types[i + 1],
                joint_position=jnp.array(s),
            )
            for i, s in enumerate(jnp.array(joint_positions).astype(float))
        ]

        # Extract the transforms and motion subspaces of the joints.
        # We stack the base transform W_H_B at index 0, and a dummy motion subspace
        # for either the fixed or free-floating joint connecting the world to the base.
        pre_H_suc = jnp.stack([W_H_B] + [H for H, _ in pre_H_suc_and_S])
        S = jnp.stack([jnp.vstack(jnp.zeros(6))] + [S for _, S in pre_H_suc_and_S])

        # Extract the successor-to-child fixed transforms.
        # Note that here we include also the index 0 since suc_H_child[0] stores the
        # optional pose of the base link w.r.t. the root frame of the model.
        # This is supported by SDF when the base link <pose> element is defined.
        suc_H_i = jax.vmap(lambda i: self.joint_model.successor_H_child(joint_index=i))(
            jnp.arange(0, 1 + self.number_of_joints())
        )

        # Compute the overall transforms from the parent to the child of each joint by
        # composing all the components of our joint model, then convert the inverse
        # (child-to-parent pose) into its 6×6 adjoint.
        i_X_λ = jax.vmap(
            lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: jaxlie.SE3.from_matrix(
                λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i
            )
            .inverse()
            .adjoint()
        )(λ_H_pre, pre_H_suc, suc_H_i)

        return i_X_λ, S

    # ============================
    # Helpers to update parameters
    # ============================

    def set_link_mass(self, link_index: int, mass: jtp.FloatLike) -> KynDynParameters:
        """
        Set the mass of a link.

        Args:
            link_index: The index of the link.
            mass: The mass of the link.

        Returns:
            The updated kinematic and dynamic parameters of the model.
            The original object is not modified.
        """

        link_parameters = self.link_parameters.replace(
            mass=self.link_parameters.mass.at[link_index].set(mass)
        )

        return self.replace(link_parameters=link_parameters)

    def set_link_inertia(
        self, link_index: int, inertia: jtp.MatrixLike
    ) -> KynDynParameters:
        """
        Set the inertia tensor of a link.

        Args:
            link_index: The index of the link.
            inertia: The 3×3 inertia tensor of the link.

        Returns:
            The updated kinematic and dynamic parameters of the model.
            The original object is not modified.
        """

        # Only the upper-triangular elements of the tensor are stored.
        inertia_elements = LinkParameters.flatten_inertia_tensor(I=inertia)

        link_parameters = self.link_parameters.replace(
            inertia_elements=self.link_parameters.inertia_elements.at[link_index].set(
                inertia_elements
            )
        )

        return self.replace(link_parameters=link_parameters)
466
+
467
+
468
@jax_dataclasses.pytree_dataclass
class JointParameters(JaxsimDataclass):
    """
    Container of the parameters of a single joint.

    Attributes:
        index: The index of the joint.
        friction_static: The static friction of the joint.
        friction_viscous: The viscous friction of the joint.
        position_limits_min: The lower position limit of the joint.
        position_limits_max: The upper position limit of the joint.
        position_limit_spring: The spring constant of the position limit.
        position_limit_damper: The damper constant of the position limit.

    Note:
        KinDynParameters stacks one instance per joint into a single vectorized
        object holding the parameters of all joints.
    """

    index: jtp.Int

    friction_static: jtp.Float
    friction_viscous: jtp.Float

    position_limits_min: jtp.Float
    position_limits_max: jtp.Float

    position_limit_spring: jtp.Float
    position_limit_damper: jtp.Float

    @staticmethod
    def build_from_joint_description(
        joint_description: JointDescription,
    ) -> JointParameters:
        """
        Create the parameters of a joint from its parsed description.

        Args:
            joint_description: The joint description providing the values.

        Returns:
            A JointParameters object whose fields are 0-d arrays (int for the
            index, float for everything else).
        """

        def as_float_scalar(value) -> jtp.Float:
            # Squeeze to a 0-d array and cast to the default float dtype.
            return jnp.array(value).squeeze().astype(float)

        limits = joint_description.position_limit
        first_bound = limits[0]
        second_bound = limits[1]

        return JointParameters(
            index=jnp.array(joint_description.index).squeeze().astype(int),
            friction_static=as_float_scalar(joint_description.friction_static),
            friction_viscous=as_float_scalar(joint_description.friction_viscous),
            # Order the two bounds so that min <= max even if they are swapped.
            position_limits_min=jnp.minimum(first_bound, second_bound).astype(float),
            position_limits_max=jnp.maximum(first_bound, second_bound).astype(float),
            position_limit_spring=as_float_scalar(
                joint_description.position_limit_spring
            ),
            position_limit_damper=as_float_scalar(
                joint_description.position_limit_damper
            ),
        )
538
+
539
+
540
@jax_dataclasses.pytree_dataclass
class LinkParameters(JaxsimDataclass):
    r"""
    Class storing the parameters of a link.

    Attributes:
        index: The index of the link.
        mass: The mass of the link.
        center_of_mass:
            The translation :math:`{}^L \mathbf{p}_{\text{CoM}}` between the origin
            of the link frame and the link's center of mass, expressed in the
            coordinates of the link frame.
        inertia_elements:
            The unique (upper-triangular, row-major) elements of the 3×3 inertia
            tensor of the link.

    Note:
        This class is used inside KinDynParameters to store the vectorized set
        of link parameters.
    """

    index: jtp.Int

    mass: jtp.Float
    center_of_mass: jtp.Vector
    inertia_elements: jtp.Vector

    @staticmethod
    def build_from_spatial_inertia(index: jtp.IntLike, M: jtp.Matrix) -> LinkParameters:
        """
        Build a LinkParameters object from a 6×6 spatial inertia matrix.

        Args:
            index: The index of the link.
            M: The 6×6 spatial inertia matrix of the link.

        Returns:
            The LinkParameters object.
        """

        # Extract the link parameters from the 6D spatial inertia.
        m, L_p_CoM, I = Inertia.to_params(M=M)

        return LinkParameters(
            index=jnp.array(index).squeeze().astype(int),
            mass=jnp.array(m).squeeze().astype(float),
            center_of_mass=jnp.atleast_1d(jnp.array(L_p_CoM).squeeze()).astype(float),
            # Keep only the unique elements of the symmetric inertia tensor.
            inertia_elements=LinkParameters.flatten_inertia_tensor(I=I).astype(float),
        )

    @staticmethod
    def build_from_inertial_parameters(
        index: jtp.IntLike, m: jtp.FloatLike, I: jtp.MatrixLike, c: jtp.VectorLike
    ) -> LinkParameters:
        """
        Build a LinkParameters object from the inertial parameters of a link.

        Args:
            index: The index of the link.
            m: The mass of the link.
            I: The 3×3 inertia tensor of the link (any array-like).
            c: The translation between the link frame and the link's center of
                mass (any array-like with 3 elements).

        Returns:
            The LinkParameters object.
        """

        return LinkParameters(
            index=jnp.array(index).squeeze().astype(int),
            mass=jnp.array(m).squeeze().astype(float),
            # jnp.asarray lets plain lists/tuples be used, as the type hints allow.
            center_of_mass=jnp.atleast_1d(jnp.asarray(c).squeeze()).astype(float),
            inertia_elements=LinkParameters.flatten_inertia_tensor(I=I).astype(float),
        )

    @staticmethod
    def build_from_flat_parameters(
        index: jtp.IntLike, parameters: jtp.VectorLike
    ) -> LinkParameters:
        """
        Build a LinkParameters object from a flat vector of parameters.

        Args:
            index: The index of the link.
            parameters:
                The flat parameters laid out as returned by `flat_parameters`:
                the mass at position 0, the center of mass at positions 1-3, and
                the unique inertia tensor elements from position 4 onward.

        Returns:
            The LinkParameters object.
        """

        # jnp.asarray lets plain lists/tuples be sliced and squeezed.
        flat = jnp.asarray(parameters)

        index = jnp.array(index).squeeze().astype(int)

        m = jnp.array(flat[0]).squeeze().astype(float)
        c = jnp.atleast_1d(flat[1:4].squeeze()).astype(float)
        inertia_elements = jnp.atleast_1d(flat[4:].squeeze()).astype(float)

        return LinkParameters(
            index=index, mass=m, inertia_elements=inertia_elements, center_of_mass=c
        )

    @staticmethod
    def flat_parameters(params: LinkParameters) -> jtp.Vector:
        """
        Return the parameters of a link as a flat vector.

        Args:
            params: The link parameters.

        Returns:
            The vector [mass, center_of_mass, inertia_elements], the layout
            accepted by `build_from_flat_parameters`.
        """

        return (
            jnp.hstack(
                [params.mass, params.center_of_mass.squeeze(), params.inertia_elements]
            )
            .squeeze()
            .astype(float)
        )

    @staticmethod
    def inertia_tensor(params: LinkParameters) -> jtp.Matrix:
        """
        Return the 3×3 inertia tensor of a link.

        Args:
            params: The link parameters.

        Returns:
            The 3×3 inertia tensor of the link.
        """

        return LinkParameters.unflatten_inertia_tensor(
            inertia_elements=params.inertia_elements
        )

    @staticmethod
    def spatial_inertia(params: LinkParameters) -> jtp.Matrix:
        """
        Return the 6×6 spatial inertia matrix of a link.

        Args:
            params: The link parameters.

        Returns:
            The 6×6 spatial inertia matrix of the link.
        """

        return Inertia.to_sixd(
            mass=params.mass,
            I=LinkParameters.inertia_tensor(params),
            com=params.center_of_mass,
        )

    @staticmethod
    def flatten_inertia_tensor(I: jtp.Matrix) -> jtp.Vector:
        """
        Flatten a 3×3 inertia tensor into a vector of unique elements.

        Args:
            I: The 3×3 inertia tensor (any array-like).

        Returns:
            The upper-triangular elements of the tensor in row-major order.
        """

        return jnp.atleast_1d(jnp.asarray(I)[jnp.triu_indices(3)].squeeze())

    @staticmethod
    def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix:
        """
        Unflatten a vector of unique elements into a 3×3 inertia tensor.

        Args:
            inertia_elements:
                The upper-triangular elements of the tensor in row-major order.

        Returns:
            The symmetric 3×3 inertia tensor.
        """

        I_upper = (
            jnp.zeros([3, 3])
            .at[jnp.triu_indices(3)]
            .set(jnp.asarray(inertia_elements).squeeze())
        )

        # Mirror the strictly-upper part into the lower part. Unlike selecting
        # with jnp.where on the (possibly zero) values, this keeps the gradient
        # w.r.t. every element even when an off-diagonal product of inertia is 0.
        I = I_upper + jnp.triu(I_upper, k=1).T

        return jnp.atleast_2d(I).astype(float)
716
+
717
+
718
@jax_dataclasses.pytree_dataclass
class ContactParameters(JaxsimDataclass):
    """
    Class storing the contact parameters of a model.

    Attributes:
        body:
            A tuple of integers representing, for each collidable point, the index of
            the body (link) to which it is rigidly attached to.
        point:
            The translation between the link frame and the collidable point, expressed
            in the coordinates of the parent link frame.

    Note:
        Contrarily to LinkParameters and JointParameters, this class is not meant
        to be created with vmap. This is because the `body` attribute must be `Static`.
    """

    body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple)

    point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([]))

    @staticmethod
    def build_from(model_description: ModelDescription) -> ContactParameters:
        """
        Build a ContactParameters object from a model description.

        Args:
            model_description: The model description to consider.

        Returns:
            The ContactParameters object. It is empty when the model has no
            collision shapes or no enabled collidable points.
        """

        if len(model_description.collision_shapes) == 0:
            return ContactParameters()

        # Get all the links so that we can take their updated index.
        links_dict = {link.name: link for link in model_description}

        # Get all the enabled collidable points of the model.
        # Materialize them once since they are iterated twice below.
        collidable_points = list(model_description.all_enabled_collidable_points())

        # Collision shapes may exist while all their points are disabled; in that
        # case jnp.vstack below would fail on an empty list.
        if not collidable_points:
            return ContactParameters()

        # Extract the positions L_p_C of the collidable points w.r.t. the link frames
        # they are rigidly attached to.
        points = jnp.vstack([cp.position for cp in collidable_points])

        # Extract the indices of the links to which the collidable points are rigidly
        # attached to.
        link_index_of_points = tuple(
            links_dict[cp.parent_link.name].index for cp in collidable_points
        )

        # Build the ContactParameters object.
        cp = ContactParameters(point=points, body=link_index_of_points)  # noqa

        # Internal consistency checks: one 3D point per attached body index.
        assert cp.point.shape[1] == 3
        assert cp.point.shape[0] == len(cp.body)

        return cp