jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/model.py ADDED
@@ -0,0 +1,1633 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import functools
5
+ import pathlib
6
+ from typing import Any
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import jax_dataclasses
11
+ import jaxlie
12
+ import rod
13
+ from jax_dataclasses import Static
14
+
15
+ import jaxsim.api as js
16
+ import jaxsim.parsers.descriptions
17
+ import jaxsim.typing as jtp
18
+ from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability
19
+
20
+ from .common import VelRepr
21
+
22
+
23
@jax_dataclasses.pytree_dataclass
class JaxSimModel(JaxsimDataclass):
    """
    The JaxSim model defining the kinematics and dynamics of a robot.
    """

    # Name of the model (static, i.e. part of the pytree structure).
    model_name: Static[str]

    # Terrain considered by the model; defaults to a flat terrain.
    terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
        default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
    )

    # Original resource the model was built from (path, string, or rod model),
    # stored in case downstream logic needs it.
    built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
        default=None, repr=False, compare=False, hash=False
    )

    # Intermediate model description, wrapped so that it does not take part in hashing.
    description: Static[
        HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
    ] = dataclasses.field(default=None, repr=False, compare=False, hash=False)

    # Kinematic and dynamic parameters built from the model description.
    kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
        dataclasses.field(default=None, repr=False, compare=False, hash=False)
    )

    # ========================
    # Initialization and state
    # ========================

    @staticmethod
    def build_from_model_description(
        model_description: str | pathlib.Path | rod.Model,
        model_name: str | None = None,
        *,
        terrain: jaxsim.terrain.Terrain | None = None,
        is_urdf: bool | None = None,
        considered_joints: list[str] | None = None,
    ) -> JaxSimModel:
        """
        Build a Model object from a model description.

        Args:
            model_description:
                A path to an SDF/URDF file, a string containing
                its content, or a pre-parsed/pre-built rod model.
            model_name:
                The optional name of the model that overrides the one in
                the description.
            terrain:
                The optional terrain to consider.
            is_urdf:
                Whether the model description is a URDF or an SDF. This is
                automatically inferred if the model description is a path to a file.
            considered_joints:
                The list of joints to consider. If None, all joints are considered.

        Returns:
            The built Model object.
        """

        import jaxsim.parsers.rod

        # Parse the input resource (either a path to file or a string with the URDF/SDF)
        # and build the -intermediate- model description.
        intermediate_description = jaxsim.parsers.rod.build_model_description(
            model_description=model_description, is_urdf=is_urdf
        )

        # Lump links together if not all joints are considered.
        # Note: this procedure assigns a zero position to all joints not considered.
        if considered_joints is not None:
            intermediate_description = intermediate_description.reduce(
                considered_joints=considered_joints
            )

        # Build the model.
        model = JaxSimModel.build(
            model_description=intermediate_description,
            model_name=model_name,
            terrain=terrain,
        )

        # Store the origin of the model, in case downstream logic needs it.
        with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            model.built_from = model_description

        return model

    @staticmethod
    def build(
        model_description: jaxsim.parsers.descriptions.ModelDescription,
        model_name: str | None = None,
        *,
        terrain: jaxsim.terrain.Terrain | None = None,
    ) -> JaxSimModel:
        """
        Build a Model object from an intermediate model description.

        Args:
            model_description:
                The intermediate model description defining the kinematics and dynamics
                of the model.
            model_name:
                The optional name of the model overriding the physics model name.
            terrain:
                The optional terrain to consider. If None, the default terrain of
                the dataclass field is used.

        Returns:
            The built Model object.
        """

        # Set the model name (if not provided, use the one from the model description).
        model_name = model_name if model_name is not None else model_description.name

        # Use an explicit None check: a user-provided terrain must never be swapped
        # for the default one just because it evaluates as falsy.
        if terrain is None:
            terrain = JaxSimModel.__dataclass_fields__["terrain"].default

        # Build the model.
        model = JaxSimModel(
            model_name=model_name,
            description=HashlessObject(obj=model_description),
            kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
                model_description=model_description
            ),
            terrain=terrain,
        )

        return model

    # ==========
    # Properties
    # ==========

    def name(self) -> str:
        """
        Return the name of the model.

        Returns:
            The name of the model.
        """

        return self.model_name

    def number_of_links(self) -> jtp.Int:
        """
        Return the number of links in the model.

        Returns:
            The number of links in the model.

        Note:
            The base link is included in the count and its index is always 0.
        """

        return self.kin_dyn_parameters.number_of_links()

    def number_of_joints(self) -> jtp.Int:
        """
        Return the number of joints in the model.

        Returns:
            The number of joints in the model.
        """

        return self.kin_dyn_parameters.number_of_joints()

    # =================
    # Base link methods
    # =================

    def floating_base(self) -> bool:
        """
        Return whether the model has a floating base.

        Returns:
            True if the model is floating-base, False otherwise.
        """

        # The first entry of the joint DoFs refers to the base: 6 DoFs means free.
        return bool(self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6)

    def base_link(self) -> str:
        """
        Return the name of the base link.

        Returns:
            The name of the base link.

        Note:
            By default, the base link is the root of the kinematic tree.
        """

        return self.link_names()[0]

    # =====================
    # Joint-related methods
    # =====================

    def dofs(self) -> int:
        """
        Return the number of degrees of freedom of the model.

        Returns:
            The number of degrees of freedom of the model.

        Note:
            We do not yet support multi-DoF joints, therefore this is always equal to
            the number of joints. In the future, this could be different.
        """

        # Skip the first entry, which refers to the base and not to a joint.
        return int(sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:]))

    def joint_names(self) -> tuple[str, ...]:
        """
        Return the names of the joints in the model.

        Returns:
            The names of the joints in the model.
        """

        # Skip the first entry, which refers to the base and not to a joint.
        return self.kin_dyn_parameters.joint_model.joint_names[1:]

    # ====================
    # Link-related methods
    # ====================

    def link_names(self) -> tuple[str, ...]:
        """
        Return the names of the links in the model.

        Returns:
            The names of the links in the model.
        """

        return self.kin_dyn_parameters.link_names
253
+
254
+
255
+ # =====================
256
+ # Model post-processing
257
+ # =====================
258
+
259
+
260
def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimModel:
    """
    Reduce the model by lumping together the links connected by removed joints.

    Args:
        model: The model to reduce.
        considered_joints: The sequence of joints to consider.

    Returns:
        The reduced model. It keeps the name, the terrain and the `built_from`
        origin of the input model.

    Note:
        If considered_joints contains joints not existing in the model, the method
        will raise an exception. If considered_joints is empty, the method will
        return a copy of the input model.
    """

    # Reduce the model description.
    # If considered_joints contains joints not existing in the model, the method
    # will raise an exception.
    reduced_intermediate_description = model.description.obj.reduce(
        considered_joints=list(considered_joints)
    )

    # Build the reduced model. The terrain must be forwarded explicitly: without
    # it, `JaxSimModel.build` would fall back to the default terrain and silently
    # discard a custom terrain of the original model.
    reduced_model = JaxSimModel.build(
        model_description=reduced_intermediate_description,
        model_name=model.name(),
        terrain=model.terrain,
    )

    # Store the origin of the model, in case downstream logic needs it.
    with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
        reduced_model.built_from = model.built_from

    return reduced_model
292
+
293
+
294
+ # ===================
295
+ # Inertial properties
296
+ # ===================
297
+
298
+
299
@jax.jit
def total_mass(model: JaxSimModel) -> jtp.Float:
    """
    Compute the total mass of the model.

    Args:
        model: The model to consider.

    Returns:
        The total mass of the model, i.e. the sum of the masses of all its links.
    """

    all_link_indices = jnp.arange(model.number_of_links())

    # Query the mass of every link in a single vectorized call.
    link_masses = jax.vmap(
        lambda link_index: js.link.mass(model=model, link_index=link_index)
    )(all_link_indices)

    return link_masses.sum().astype(float)
318
+
319
+
320
@jax.jit
def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:
    """
    Compute the spatial 6D inertia matrices of all links of the model.

    Args:
        model: The model to consider.

    Returns:
        A 3D array containing the stacked spatial 6D inertia matrices of the links.
    """

    compute_spatial_inertia = js.kin_dyn_parameters.LinkParameters.spatial_inertia
    stacked_link_parameters = model.kin_dyn_parameters.link_parameters

    # vmap maps over the leading axis of the link parameters pytree
    # (presumably one entry per link).
    return jax.vmap(compute_spatial_inertia)(stacked_link_parameters)
335
+
336
+
337
+ # ==============================
338
+ # Rigid Body Dynamics Algorithms
339
+ # ==============================
340
+
341
+
342
@jax.jit
def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
    """
    Compute the SE(3) transforms from the world frame to the frames of all links.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        A (nL, 4, 4) array containing the stacked SE(3) transforms of the links.
        The first axis is the link index.
    """

    # Gather the configuration of the floating base and the joints.
    W_p_B = data.base_position()
    W_Q_B = data.base_orientation(dcm=False)
    s = data.joint_positions(model=model)

    W_H_LL = jaxsim.rbda.forward_kinematics_model(
        model=model,
        base_position=W_p_B,
        base_quaternion=W_Q_B,
        joint_positions=s,
    )

    return jnp.atleast_3d(W_H_LL).astype(float)
364
+
365
+
366
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
367
+ def generalized_free_floating_jacobian(
368
+ model: JaxSimModel,
369
+ data: js.data.JaxSimModelData,
370
+ *,
371
+ output_vel_repr: VelRepr | None = None,
372
+ ) -> jtp.Matrix:
373
+ """
374
+ Compute the free-floating jacobians of all links.
375
+
376
+ Args:
377
+ model: The model to consider.
378
+ data: The data of the considered model.
379
+ output_vel_repr:
380
+ The output velocity representation of the free-floating jacobians.
381
+
382
+ Returns:
383
+ The `(nL, 6, 6+dofs)` array containing the stacked free-floating
384
+ jacobians of the links. The first axis is the link index.
385
+
386
+ Note:
387
+ The v-stacked version of the returned Jacobian array together with the
388
+ flattened 6D forces of the links, are useful to compute the `J.T @ f`
389
+ product of the multi-body EoM.
390
+ """
391
+
392
+ output_vel_repr = (
393
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
394
+ )
395
+
396
+ # Compute the doubly-left free-floating full jacobian.
397
+ B_J_full_WX_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
398
+ model=model,
399
+ joint_positions=data.joint_positions(),
400
+ )
401
+
402
+ # Update the input velocity representation such that `J_WL_I @ I_ν`.
403
+ match data.velocity_representation:
404
+ case VelRepr.Inertial:
405
+ W_H_B = data.base_transform()
406
+ B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
407
+ B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
408
+ B_X_W, jnp.eye(model.dofs())
409
+ )
410
+
411
+ case VelRepr.Body:
412
+ B_J_full_WX_I = B_J_full_WX_B
413
+
414
+ case VelRepr.Mixed:
415
+ W_R_B = data.base_orientation(dcm=True)
416
+ BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
417
+ B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
418
+ B_J_full_WX_I = B_J_full_WX_BW = (
419
+ B_J_full_WX_B
420
+ @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
421
+ )
422
+
423
+ case _:
424
+ raise ValueError(data.velocity_representation)
425
+
426
+ # Update the output velocity representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
427
+ match output_vel_repr:
428
+ case VelRepr.Inertial:
429
+ W_H_B = data.base_transform()
430
+ W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
431
+ O_J_full_WX_I = W_J_full_WX_I = W_X_B @ B_J_full_WX_I
432
+
433
+ case VelRepr.Body:
434
+ O_J_full_WX_I = B_J_full_WX_I
435
+
436
+ case VelRepr.Mixed:
437
+ W_R_B = data.base_orientation(dcm=True)
438
+ BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
439
+ BW_X_B = jaxlie.SE3.from_matrix(BW_H_B).adjoint()
440
+ O_J_full_WX_I = BW_J_full_WX_I = BW_X_B @ B_J_full_WX_I
441
+
442
+ case _:
443
+ raise ValueError(output_vel_repr)
444
+
445
+ κ_bool = model.kin_dyn_parameters.support_body_array_bool
446
+
447
+ O_J_WL_I = jax.vmap(
448
+ lambda κ: jnp.where(
449
+ jnp.hstack([jnp.ones(5), κ]), O_J_full_WX_I, jnp.zeros_like(O_J_full_WX_I)
450
+ )
451
+ )(κ_bool)
452
+
453
+ return O_J_WL_I
454
+
455
+
456
@functools.partial(jax.jit, static_argnames=["prefer_aba"])
def forward_dynamics(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    prefer_aba: bool = True,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute the forward dynamics of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces:
            The joint forces to consider as a vector of shape `(dofs,)`.
        link_forces:
            The link 6D forces to consider as a matrix of shape `(nL, 6)`.
            The frame in which they are expressed must be `data.velocity_representation`.
        prefer_aba:
            Whether to prefer the ABA algorithm over the CRB one. This is a static
            argument of the jitted function, so changing it triggers recompilation.

    Returns:
        A tuple containing the 6D acceleration in the active representation of the
        base link and the joint accelerations resulting from the application of the
        considered joint forces and external forces.
    """

    # The selection happens at trace time since `prefer_aba` is static.
    forward_dynamics_fn = forward_dynamics_aba if prefer_aba else forward_dynamics_crb

    return forward_dynamics_fn(
        model=model,
        data=data,
        joint_forces=joint_forces,
        link_forces=link_forces,
    )
492
+
493
+
494
+ @jax.jit
495
+ def forward_dynamics_aba(
496
+ model: JaxSimModel,
497
+ data: js.data.JaxSimModelData,
498
+ *,
499
+ joint_forces: jtp.VectorLike | None = None,
500
+ link_forces: jtp.MatrixLike | None = None,
501
+ ) -> tuple[jtp.Vector, jtp.Vector]:
502
+ """
503
+ Compute the forward dynamics of the model with the ABA algorithm.
504
+
505
+ Args:
506
+ model: The model to consider.
507
+ data: The data of the considered model.
508
+ joint_forces:
509
+ The joint forces to consider as a vector of shape `(dofs,)`.
510
+ link_forces:
511
+ The link 6D forces to consider as a matrix of shape `(nL, 6)`.
512
+ The frame in which they are expressed must be `data.velocity_representation`.
513
+
514
+ Returns:
515
+ A tuple containing the 6D acceleration in the active representation of the
516
+ base link and the joint accelerations resulting from the application of the
517
+ considered joint forces and external forces.
518
+ """
519
+
520
+ # ============
521
+ # Prepare data
522
+ # ============
523
+
524
+ # Build joint forces, if not provided.
525
+ τ = (
526
+ jnp.atleast_1d(joint_forces.squeeze())
527
+ if joint_forces is not None
528
+ else jnp.zeros_like(data.joint_positions())
529
+ )
530
+
531
+ # Build link forces, if not provided.
532
+ f_L = (
533
+ jnp.atleast_2d(link_forces.squeeze())
534
+ if link_forces is not None
535
+ else jnp.zeros((model.number_of_links(), 6))
536
+ )
537
+
538
+ # Create a references object that simplifies converting among representations.
539
+ references = js.references.JaxSimModelReferences.build(
540
+ model=model,
541
+ joint_force_references=τ,
542
+ link_forces=f_L,
543
+ data=data,
544
+ velocity_representation=data.velocity_representation,
545
+ )
546
+
547
+ # Extract the link and joint serializations.
548
+ link_names = model.link_names()
549
+ joint_names = model.joint_names()
550
+
551
+ # Extract the state in inertial-fixed representation.
552
+ with data.switch_velocity_representation(VelRepr.Inertial):
553
+ W_p_B = data.base_position()
554
+ W_v_WB = data.base_velocity()
555
+ W_Q_B = data.base_orientation(dcm=False)
556
+ s = data.joint_positions(model=model, joint_names=joint_names)
557
+ ṡ = data.joint_velocities(model=model, joint_names=joint_names)
558
+
559
+ # Extract the inputs in inertial-fixed representation.
560
+ with references.switch_velocity_representation(VelRepr.Inertial):
561
+ W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
562
+ τ = references.joint_force_references(model=model, joint_names=joint_names)
563
+
564
+ # ========================
565
+ # Compute forward dynamics
566
+ # ========================
567
+
568
+ W_v̇_WB, s̈ = jaxsim.rbda.aba(
569
+ model=model,
570
+ base_position=W_p_B,
571
+ base_quaternion=W_Q_B,
572
+ joint_positions=s,
573
+ base_linear_velocity=W_v_WB[0:3],
574
+ base_angular_velocity=W_v_WB[3:6],
575
+ joint_velocities=ṡ,
576
+ joint_forces=τ,
577
+ link_forces=W_f_L,
578
+ standard_gravity=data.standard_gravity(),
579
+ )
580
+
581
+ # =============
582
+ # Adjust output
583
+ # =============
584
+
585
+ def to_active(
586
+ W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector
587
+ ) -> jtp.Vector:
588
+ """
589
+ Helper to convert the inertial-fixed apparent base acceleration W_v̇_WB to
590
+ another representation C_v̇_WB expressed in a generic frame C.
591
+ """
592
+
593
+ from jaxsim.math import Cross
594
+
595
+ # In Mixed representation, we need to include a cross product in ℝ⁶.
596
+ # In Inertial and Body representations, the cross product is always zero.
597
+ C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
598
+ return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)
599
+
600
+ match data.velocity_representation:
601
+ case VelRepr.Inertial:
602
+ # In this case C=W
603
+ W_H_C = W_H_W = jnp.eye(4)
604
+ W_v_WC = W_v_WW = jnp.zeros(6)
605
+
606
+ case VelRepr.Body:
607
+ # In this case C=B
608
+ W_H_C = W_H_B = data.base_transform()
609
+ W_v_WC = W_v_WB
610
+
611
+ case VelRepr.Mixed:
612
+ # In this case C=B[W]
613
+ W_H_B = data.base_transform()
614
+ W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
615
+ W_ṗ_B = data.base_velocity()[0:3]
616
+ W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
617
+
618
+ case _:
619
+ raise ValueError(data.velocity_representation)
620
+
621
+ # We need to convert the derivative of the base velocity to the active
622
+ # representation. In Mixed representation, this conversion is not a plain
623
+ # transformation with just X, but it also involves a cross product in ℝ⁶.
624
+ C_v̇_WB = to_active(
625
+ W_v̇_WB=W_v̇_WB,
626
+ W_H_C=W_H_C,
627
+ W_v_WB=jnp.hstack(
628
+ [
629
+ data.state.physics_model.base_linear_velocity,
630
+ data.state.physics_model.base_angular_velocity,
631
+ ]
632
+ ),
633
+ W_v_WC=W_v_WC,
634
+ )
635
+
636
+ # The ABA algorithm already returns a zero base 6D acceleration for
637
+ # fixed-based models. However, the to_active function introduces an
638
+ # additional acceleration component in Mixed representation.
639
+ # Here below we make sure that the base acceleration is zero.
640
+ C_v̇_WB = C_v̇_WB if model.floating_base() else jnp.zeros(6)
641
+
642
+ return C_v̇_WB.astype(float), s̈.astype(float)
643
+
644
+
645
@jax.jit
def forward_dynamics_crb(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute the forward dynamics of the model with the CRB algorithm.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces:
            The joint forces to consider as a vector of shape `(dofs,)`.
        link_forces:
            The link 6D forces to consider as a matrix of shape `(nL, 6)`.
            The frame in which they are expressed must be `data.velocity_representation`.

    Returns:
        A tuple containing the 6D acceleration in the active representation of the
        base link and the joint accelerations resulting from the application of the
        considered joint forces and external forces.

    Note:
        Compared to ABA, this method could be significantly slower, especially for
        models with a large number of degrees of freedom.
    """

    # ============
    # Prepare data
    # ============

    # Build joint torques if not provided. Inputs are normalized like in the ABA
    # variant so that both accept the same array-like inputs.
    τ = (
        jnp.atleast_1d(jnp.asarray(joint_forces).squeeze())
        if joint_forces is not None
        else jnp.zeros_like(data.joint_positions())
    )

    # Build external forces if not provided.
    f = (
        jnp.atleast_2d(jnp.asarray(link_forces).squeeze())
        if link_forces is not None
        else jnp.zeros(shape=(model.number_of_links(), 6))
    )

    # Compute terms of the floating-base EoM, all in the active representation.
    M = free_floating_mass_matrix(model=model, data=data)
    h = free_floating_bias_forces(model=model, data=data)
    J = generalized_free_floating_jacobian(model=model, data=data)

    # TODO: invert the Mss block exploiting sparsity defined by the parent array λ(i)

    # ========================
    # Compute forward dynamics
    # ========================

    # Einsum index labels (letters only, as required by NumPy's einsum):
    #   l: link index, s: 6D spatial component,
    #   g: generalized coordinate (6 + dofs), j: joint coordinate.
    if model.floating_base():
        # Selector matrix mapping joint forces to generalized forces.
        S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T

        JTf = jnp.einsum("lsg,ls->g", J, f)
        ν̇ = jnp.linalg.solve(M, S @ τ - h + JTf)

    else:
        # Only the joint block of the EoM is solved; the base does not move.
        JTf = jnp.einsum("lsj,ls->j", J[:, :, 6:], f)
        s̈ = jnp.linalg.solve(M[6:, 6:], τ - h[6:] + JTf)

        v̇_WB = jnp.zeros(6)
        ν̇ = jnp.hstack([v̇_WB, s̈.squeeze()])

    # =============
    # Adjust output
    # =============

    # Extract the base acceleration in the active representation.
    # Note that this is an apparent acceleration (relevant in Mixed representation),
    # therefore it cannot be always expressed in different frames with just a
    # 6D transformation X.
    v̇_WB = ν̇[0:6].squeeze().astype(float)

    # Extract the joint accelerations.
    s̈ = jnp.atleast_1d(ν̇[6:].squeeze()).astype(float)

    return v̇_WB, s̈
734
+
735
+
736
@jax.jit
def free_floating_mass_matrix(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Compute the free-floating mass matrix of the model with the CRBA algorithm.

    The CRBA returns the mass matrix in body-fixed representation. For the other
    representations it is converted with the congruence `invT.T @ M_body @ invT`,
    where `invT` is block-diagonal with a 6D base transform and a joint identity.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The free-floating mass matrix of the model.

    Raises:
        ValueError: If the active velocity representation is not supported.
    """

    M_body = jaxsim.rbda.crba(
        model=model,
        joint_positions=data.state.physics_model.joint_positions,
    )

    active_repr = data.velocity_representation

    if active_repr == VelRepr.Body:
        return M_body

    if active_repr == VelRepr.Inertial:
        # Adjoint of the inverse world-to-base transform, i.e. B_X_W.
        B_X_C = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()

    elif active_repr == VelRepr.Mixed:
        # B[W] has the base origin and the world orientation: drop the translation.
        BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
        B_X_C = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()

    else:
        raise ValueError(data.velocity_representation)

    invT = jax.scipy.linalg.block_diag(B_X_C, jnp.eye(model.dofs()))

    return invT.T @ M_body @ invT
777
+
778
+
779
+ @jax.jit
780
+ def inverse_dynamics(
781
+ model: JaxSimModel,
782
+ data: js.data.JaxSimModelData,
783
+ *,
784
+ joint_accelerations: jtp.VectorLike | None = None,
785
+ base_acceleration: jtp.VectorLike | None = None,
786
+ link_forces: jtp.MatrixLike | None = None,
787
+ ) -> tuple[jtp.Vector, jtp.Vector]:
788
+ """
789
+ Compute inverse dynamics with the RNEA algorithm.
790
+
791
+ Args:
792
+ model: The model to consider.
793
+ data: The data of the considered model.
794
+ joint_accelerations:
795
+ The joint accelerations to consider as a vector of shape `(dofs,)`.
796
+ base_acceleration:
797
+ The base acceleration to consider as a vector of shape `(6,)`.
798
+ link_forces:
799
+ The link 6D forces to consider as a matrix of shape `(nL, 6)`.
800
+ The frame in which they are expressed must be `data.velocity_representation`.
801
+
802
+ Returns:
803
+ A tuple containing the 6D force in the active representation applied to the
804
+ base to obtain the considered base acceleration, and the joint forces to apply
805
+ to obtain the considered joint accelerations.
806
+ """
807
+
808
+ # ============
809
+ # Prepare data
810
+ # ============
811
+
812
+ # Build joint accelerations, if not provided.
813
+ s̈ = (
814
+ jnp.atleast_1d(jnp.array(joint_accelerations).squeeze())
815
+ if joint_accelerations is not None
816
+ else jnp.zeros_like(data.joint_positions())
817
+ )
818
+
819
+ # Build base acceleration, if not provided.
820
+ v̇_WB = (
821
+ jnp.array(base_acceleration).squeeze()
822
+ if base_acceleration is not None
823
+ else jnp.zeros(6)
824
+ )
825
+
826
+ # Build link forces, if not provided.
827
+ f_L = (
828
+ jnp.atleast_2d(jnp.array(link_forces).squeeze())
829
+ if link_forces is not None
830
+ else jnp.zeros(shape=(model.number_of_links(), 6))
831
+ )
832
+
833
+ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
834
+ """
835
+ Helper to convert the active representation of the base acceleration C_v̇_WB
836
+ expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
837
+ """
838
+
839
+ from jaxsim.math import Cross
840
+
841
+ W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
842
+ C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
843
+ C_v_WC = C_X_W @ W_v_WC
844
+
845
+ # In Mixed representation, we need to include a cross product in ℝ⁶.
846
+ # In Inertial and Body representations, the cross product is always zero.
847
+ return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB)
848
+
849
+ match data.velocity_representation:
850
+ case VelRepr.Inertial:
851
+ W_H_C = W_H_W = jnp.eye(4)
852
+ W_v_WC = W_v_WW = jnp.zeros(6)
853
+
854
+ case VelRepr.Body:
855
+ W_H_C = W_H_B = data.base_transform()
856
+ with data.switch_velocity_representation(VelRepr.Inertial):
857
+ W_v_WC = W_v_WB = data.base_velocity()
858
+
859
+ case VelRepr.Mixed:
860
+ W_H_B = data.base_transform()
861
+ W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
862
+ W_ṗ_B = data.base_velocity()[0:3]
863
+ W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
864
+
865
+ case _:
866
+ raise ValueError(data.velocity_representation)
867
+
868
+ # We need to convert the derivative of the base acceleration to the Inertial
869
+ # representation. In Mixed representation, this conversion is not a plain
870
+ # transformation with just X, but it also involves a cross product in ℝ⁶.
871
+ W_v̇_WB = to_inertial(
872
+ C_v̇_WB=v̇_WB,
873
+ W_H_C=W_H_C,
874
+ C_v_WB=data.base_velocity(),
875
+ W_v_WC=W_v_WC,
876
+ )
877
+
878
+ # Create a references object that simplifies converting among representations.
879
+ references = js.references.JaxSimModelReferences.build(
880
+ model=model,
881
+ data=data,
882
+ link_forces=f_L,
883
+ velocity_representation=data.velocity_representation,
884
+ )
885
+
886
+ # Extract the link and joint serializations.
887
+ link_names = model.link_names()
888
+ joint_names = model.joint_names()
889
+
890
+ # Extract the state in inertial-fixed representation.
891
+ with data.switch_velocity_representation(VelRepr.Inertial):
892
+ W_p_B = data.base_position()
893
+ W_v_WB = data.base_velocity()
894
+ W_Q_B = data.base_orientation(dcm=False)
895
+ s = data.joint_positions(model=model, joint_names=joint_names)
896
+ ṡ = data.joint_velocities(model=model, joint_names=joint_names)
897
+
898
+ # Extract the inputs in inertial-fixed representation.
899
+ with references.switch_velocity_representation(VelRepr.Inertial):
900
+ W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
901
+
902
+ # ========================
903
+ # Compute inverse dynamics
904
+ # ========================
905
+
906
+ W_f_B, τ = jaxsim.rbda.rnea(
907
+ model=model,
908
+ base_position=W_p_B,
909
+ base_quaternion=W_Q_B,
910
+ joint_positions=s,
911
+ base_linear_velocity=W_v_WB[0:3],
912
+ base_angular_velocity=W_v_WB[3:6],
913
+ joint_velocities=ṡ,
914
+ base_linear_acceleration=W_v̇_WB[0:3],
915
+ base_angular_acceleration=W_v̇_WB[3:6],
916
+ joint_accelerations=s̈,
917
+ link_forces=W_f_L,
918
+ standard_gravity=data.standard_gravity(),
919
+ )
920
+
921
+ # =============
922
+ # Adjust output
923
+ # =============
924
+
925
+ # Express W_f_B in the active representation.
926
+ f_B = js.data.JaxSimModelData.inertial_to_other_representation(
927
+ array=W_f_B,
928
+ other_representation=data.velocity_representation,
929
+ transform=data.base_transform(),
930
+ is_force=True,
931
+ ).squeeze()
932
+
933
+ return f_B.astype(float), τ.astype(float)
934
+
935
+
936
@jax.jit
def free_floating_gravity_forces(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the free-floating gravity forces :math:`g(\mathbf{q})` of the model.

    The forces are obtained by running inverse dynamics on a zeroed copy of the
    state in which only the generalized position (base pose and joint positions)
    is taken from ``data``. Velocities, accelerations and link forces are zero.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The free-floating gravity forces of the model, i.e. the base 6D force
        followed by the joint torques, with the base part in the active velocity
        representation of ``data``.
    """

    # Zeroed state sharing the velocity representation of the input data.
    zeroed = js.data.JaxSimModelData.zero(
        model=model, velocity_representation=data.velocity_representation
    )

    # Copy only the generalized position into the zeroed state.
    with zeroed.mutable_context(
        mutability=Mutability.MUTABLE, restore_after_exception=False
    ):
        src = data.state.physics_model
        dst = zeroed.state.physics_model

        dst.base_position = src.base_position
        dst.base_quaternion = src.base_quaternion
        dst.joint_positions = src.joint_positions

    # All the inputs of the inverse dynamics are zero.
    f_B, τ = inverse_dynamics(
        model=model,
        data=zeroed,
        joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
        base_acceleration=jnp.zeros(6),
        link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
    )

    return jnp.hstack([f_B, τ]).astype(float)
983
+
984
+
985
@jax.jit
def free_floating_bias_forces(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the free-floating bias forces :math:`h(\mathbf{q}, \boldsymbol{\nu})`
    of the model.

    The forces are obtained by running inverse dynamics on a zeroed copy of the
    state that keeps the generalized position and velocity of ``data``, with
    zero accelerations and zero link forces.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The free-floating bias forces of the model, i.e. the base 6D force
        followed by the joint torques.
    """

    # Zeroed state sharing the velocity representation of the input data.
    zeroed = js.data.JaxSimModelData.zero(
        model=model, velocity_representation=data.velocity_representation
    )

    with zeroed.mutable_context(
        mutability=Mutability.MUTABLE, restore_after_exception=False
    ):
        src = data.state.physics_model
        dst = zeroed.state.physics_model

        # Generalized position.
        dst.base_position = src.base_position
        dst.base_quaternion = src.base_quaternion
        dst.joint_positions = src.joint_positions

        # Generalized velocity. The base velocity is copied only for floating-base
        # models, so that it remains zero for fixed-base models.
        dst.joint_velocities = src.joint_velocities

        if model.floating_base():
            dst.base_linear_velocity = src.base_linear_velocity
            dst.base_angular_velocity = src.base_angular_velocity

    # All the inputs of the inverse dynamics are zero.
    f_B, τ = inverse_dynamics(
        model=model,
        data=zeroed,
        joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
        base_acceleration=jnp.zeros(6),
        link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
    )

    return jnp.hstack([f_B, τ]).astype(float)
1047
+
1048
+
1049
+ # ==========================
1050
+ # Other kinematic quantities
1051
+ # ==========================
1052
+
1053
+
1054
@jax.jit
def locked_spatial_inertia(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Compute the locked 6D inertia matrix of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The locked 6D inertia matrix of the model, i.e. the block of the total
        momentum Jacobian multiplying the base velocity.
    """

    Jh = total_momentum_jacobian(model=model, data=data)

    # The first six columns correspond to the base velocity part of ν.
    return Jh[:, :6]
1070
+
1071
+
1072
@jax.jit
def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
    """
    Compute the total momentum of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The total momentum of the model in the active velocity representation,
        obtained by mapping the generalized velocity through the momentum Jacobian.
    """

    generalized_velocity = data.generalized_velocity()
    momentum_jacobian = total_momentum_jacobian(model=model, data=data)

    return momentum_jacobian @ generalized_velocity
1089
+
1090
+
1091
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
1092
+ def total_momentum_jacobian(
1093
+ model: JaxSimModel,
1094
+ data: js.data.JaxSimModelData,
1095
+ *,
1096
+ output_vel_repr: VelRepr | None = None,
1097
+ ) -> jtp.Matrix:
1098
+ """
1099
+ Compute the jacobian of the total momentum.
1100
+
1101
+ Args:
1102
+ model: The model to consider.
1103
+ data: The data of the considered model.
1104
+ output_vel_repr: The output velocity representation of the jacobian.
1105
+
1106
+ Returns:
1107
+ The jacobian of the total momentum of the model in the active representation.
1108
+ """
1109
+
1110
+ output_vel_repr = (
1111
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
1112
+ )
1113
+
1114
+ if output_vel_repr is data.velocity_representation:
1115
+ return free_floating_mass_matrix(model=model, data=data)[0:6]
1116
+
1117
+ with data.switch_velocity_representation(VelRepr.Body):
1118
+ B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6]
1119
+
1120
+ match data.velocity_representation:
1121
+ case VelRepr.Body:
1122
+ B_Jh = B_Jh_B
1123
+
1124
+ case VelRepr.Inertial:
1125
+ B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
1126
+ B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
1127
+
1128
+ case VelRepr.Mixed:
1129
+ BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1130
+ B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
1131
+ B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
1132
+
1133
+ case _:
1134
+ raise ValueError(data.velocity_representation)
1135
+
1136
+ match output_vel_repr:
1137
+ case VelRepr.Body:
1138
+ return B_Jh
1139
+
1140
+ case VelRepr.Inertial:
1141
+ W_H_B = data.base_transform()
1142
+ B_Xv_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
1143
+ W_Xf_B = B_Xv_W.T
1144
+ W_Jh = W_Xf_B @ B_Jh
1145
+ return W_Jh
1146
+
1147
+ case VelRepr.Mixed:
1148
+ BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1149
+ B_Xv_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
1150
+ BW_Xf_B = B_Xv_BW.T
1151
+ BW_Jh = BW_Xf_B @ B_Jh
1152
+ return BW_Jh
1153
+
1154
+ case _:
1155
+ raise ValueError(output_vel_repr)
1156
+
1157
+
1158
@jax.jit
def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
    """
    Compute the average velocity of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The average velocity of the model computed in the base frame and expressed
        in the active representation, obtained by mapping the generalized velocity
        through the average velocity Jacobian.
    """

    generalized_velocity = data.generalized_velocity()
    average_velocity_jac = average_velocity_jacobian(model=model, data=data)

    return average_velocity_jac @ generalized_velocity
1176
+
1177
+
1178
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
1179
+ def average_velocity_jacobian(
1180
+ model: JaxSimModel,
1181
+ data: js.data.JaxSimModelData,
1182
+ *,
1183
+ output_vel_repr: VelRepr | None = None,
1184
+ ) -> jtp.Matrix:
1185
+ """
1186
+ Compute the Jacobian of the average velocity of the model.
1187
+
1188
+ Args:
1189
+ model: The model to consider.
1190
+ data: The data of the considered model.
1191
+ output_vel_repr: The output velocity representation of the jacobian.
1192
+
1193
+ Returns:
1194
+ The Jacobian of the average centroidal velocity of the model in the desired
1195
+ representation.
1196
+ """
1197
+
1198
+ output_vel_repr = (
1199
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
1200
+ )
1201
+
1202
+ # Depending on the velocity representation, the frame G is either G[W] or G[B].
1203
+ G_J = js.com.average_centroidal_velocity_jacobian(model=model, data=data)
1204
+
1205
+ match output_vel_repr:
1206
+
1207
+ case VelRepr.Inertial:
1208
+
1209
+ GW_J = G_J
1210
+ W_p_CoM = js.com.com_position(model=model, data=data)
1211
+
1212
+ W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
1213
+ W_X_GW = jaxlie.SE3.from_matrix(W_H_GW).adjoint()
1214
+
1215
+ return W_X_GW @ GW_J
1216
+
1217
+ case VelRepr.Body:
1218
+
1219
+ GB_J = G_J
1220
+ W_p_B = data.base_position()
1221
+ W_p_CoM = js.com.com_position(model=model, data=data)
1222
+ B_R_W = data.base_orientation(dcm=True).transpose()
1223
+
1224
+ B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))
1225
+ B_X_GB = jaxlie.SE3.from_matrix(B_H_GB).adjoint()
1226
+
1227
+ return B_X_GB @ GB_J
1228
+
1229
+ case VelRepr.Mixed:
1230
+
1231
+ GW_J = G_J
1232
+ W_p_B = data.base_position()
1233
+ W_p_CoM = js.com.com_position(model=model, data=data)
1234
+
1235
+ BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
1236
+ BW_X_GW = jaxlie.SE3.from_matrix(BW_H_GW).adjoint()
1237
+
1238
+ return BW_X_GW @ GW_J
1239
+
1240
+
1241
+ # ========================
1242
+ # Other dynamic quantities
1243
+ # ========================
1244
+
1245
+
1246
@jax.jit
def link_bias_accelerations(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
) -> jtp.Matrix:
    r"""
    Compute the bias accelerations of the links of the model.

    The accelerations are obtained with a forward pass similar to the one of RNEA,
    propagating body-fixed velocities and apparent accelerations from the base to
    the leaves with zero joint accelerations and a zero base acceleration expressed
    in the active representation. The result is then converted link-by-link to the
    active velocity representation of ``data``.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The bias accelerations of the links of the model, as an array of shape
        ``(number_of_links, 6)`` expressed in the active representation.

    Raises:
        ValueError: If the active velocity representation is not supported.

    Note:
        This function computes the component of the total 6D acceleration not due to
        the joint or base acceleration.
        It is often called :math:`\dot{J} \boldsymbol{\nu}`.
    """

    # ================================================
    # Compute the body-fixed zero base 6D acceleration
    # ================================================

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.from_quaternion_xyzw(
            xyzw=jaxsim.math.Quaternion.to_xyzw(wxyz=data.base_orientation())
        ),
        translation=data.base_position(),
    ).as_matrix()

    def other_representation_to_inertial(
        C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
    ) -> jtp.Vector:
        """
        Helper to convert the active representation of the base acceleration C_v̇_WB
        expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
        """

        W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
        C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()

        # In Mixed representation, we need to include a cross product in ℝ⁶.
        # In Inertial and Body representations, the cross product is always zero.
        return W_X_C @ (C_v̇_WB + jaxsim.math.Cross.vx(C_X_W @ W_v_WC) @ C_v_WB)

    # Here we initialize a zero 6D acceleration in the active representation, and
    # convert it to inertial-fixed. This is a useful intermediate representation
    # because the apparent acceleration W_v̇_WB is equal to the intrinsic acceleration
    # W_a_WB, and intrinsic accelerations can be expressed in different frames through
    # a simple C_X_W 6D transform.
    match data.velocity_representation:
        case VelRepr.Inertial:
            W_H_C = W_H_W = jnp.eye(4)
            W_v_WC = W_v_WW = jnp.zeros(6)
            with data.switch_velocity_representation(VelRepr.Inertial):
                C_v_WB = W_v_WB = data.base_velocity()

        case VelRepr.Body:
            W_H_C = W_H_B
            with data.switch_velocity_representation(VelRepr.Inertial):
                W_v_WC = W_v_WB = data.base_velocity()
            with data.switch_velocity_representation(VelRepr.Body):
                C_v_WB = B_v_WB = data.base_velocity()

        case VelRepr.Mixed:
            # Frame B[W]: origin of B, orientation of W.
            W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
            W_H_C = W_H_BW
            # The velocity of B[W] w.r.t. W is purely linear: the base origin velocity.
            with data.switch_velocity_representation(VelRepr.Mixed):
                W_ṗ_B = data.base_velocity()[0:3]
                W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
            # NOTE(review): this second Mixed context is redundant and could be
            # merged with the one above.
            with data.switch_velocity_representation(VelRepr.Mixed):
                C_v_WB = BW_v_WB = data.base_velocity()
        case _:
            raise ValueError(data.velocity_representation)

    # Convert a zero 6D acceleration from the active representation to inertial-fixed.
    W_v̇_WB = other_representation_to_inertial(
        C_v̇_WB=jnp.zeros(6), C_v_WB=C_v_WB, W_H_C=W_H_C, W_v_WC=W_v_WC
    )

    # ===================================
    # Initialize buffers and prepare data
    # ===================================

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute 6D transforms of the base velocity.
    B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=data.joint_positions(), base_transform=W_H_B
    )

    # Allocate the buffer to store the body-fixed link velocities.
    L_v_WL = jnp.zeros(shape=(model.number_of_links(), 6))

    # Store the base velocity.
    with data.switch_velocity_representation(VelRepr.Body):
        B_v_WB = data.base_velocity()
        L_v_WL = L_v_WL.at[0].set(B_v_WB)

    # Get the joint velocities.
    ṡ = data.joint_velocities(model=model, joint_names=model.joint_names())

    # Allocate the buffer to store the body-fixed link accelerations,
    # and initialize the base acceleration.
    # The base entry is the inertial-fixed acceleration mapped to the base frame.
    L_v̇_WL = jnp.zeros(shape=(model.number_of_links(), 6))
    L_v̇_WL = L_v̇_WL.at[0].set(B_X_W @ W_v̇_WB)

    # ======================================
    # Propagate accelerations and velocities
    # ======================================

    # The computation of the bias forces is similar to the forward pass of RNEA,
    # this time with zero base and joint accelerations. Furthermore, here we do
    # not remove gravity during the propagation.

    # Initialize the loop.
    Carry = tuple[jtp.MatrixJax, jtp.MatrixJax]
    carry0: Carry = (L_v_WL, L_v̇_WL)

    def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]:
        # Initialize index and unpack the carry.
        # Link i is connected to its parent by joint ii = i - 1.
        ii = i - 1
        v, a = carry

        # Get the motion subspace of the joint.
        Si = S[i].squeeze()

        # Project the joint velocity into its motion subspace.
        vJ = Si * ṡ[ii]

        # Propagate the link body-fixed velocity.
        # NOTE(review): reading v[λ[i]] assumes the parent is processed before the
        # child (λ[i] < i); presumably guaranteed by the parent array ordering.
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        # Propagate the link body-fixed acceleration considering zero joint acceleration.
        s̈ = 0.0
        a_i = i_X_λi[i] @ a[λ[i]] + Si * s̈ + jaxsim.math.Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        return (v, a), None

    # Compute the body-fixed velocity and body-fixed apparent acceleration of the links.
    # With a single link there is nothing to propagate, so the scan is skipped.
    (L_v_WL, L_v̇_WL), _ = (
        jax.lax.scan(
            f=propagate_accelerations,
            init=carry0,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(L_v_WL, L_v̇_WL), None]
    )

    # ===================================================================
    # Convert the body-fixed 6D acceleration to the active representation
    # ===================================================================

    def body_to_other_representation(
        L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector
    ) -> jtp.Vector:
        """
        Helper to convert the body-fixed apparent acceleration L_v̇_WL to
        another representation C_v̇_WL expressed in a generic frame C.
        """

        # In Mixed representation, we need to include a cross product in ℝ⁶.
        # In Inertial and Body representations, the cross product is always zero.
        C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L)
        return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL)

    match data.velocity_representation:
        case VelRepr.Body:
            # Identity transforms and zero relative velocities: no conversion.
            C_H_L = L_H_L = jnp.stack([jnp.eye(4)] * model.number_of_links())
            L_v_CL = L_v_LL = jnp.zeros(shape=(model.number_of_links(), 6))

        case VelRepr.Inertial:
            C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data)
            L_v_CL = L_v_WL

        case VelRepr.Mixed:
            # Frame L[W]: origin of L, orientation of W (translation zeroed).
            W_H_L = js.model.forward_kinematics(model=model, data=data)
            LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
            C_H_L = LW_H_L
            # Velocity of L relative to L[W]: only the angular part of L_v_WL.
            L_v_CL = L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))(L_v_WL)

        case _:
            raise ValueError(data.velocity_representation)

    # Convert from body-fixed to the active representation.
    O_v̇_WL = jax.vmap(body_to_other_representation)(
        L_v̇_WL=L_v̇_WL, L_v_WL=L_v_WL, C_H_L=C_H_L, L_v_CL=L_v_CL
    )

    return O_v̇_WL
1449
+
1450
+
1451
@jax.jit
def link_contact_forces(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Compute the 6D contact forces of all links of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        A (nL, 6) array containing the stacked 6D contact forces of the links,
        expressed in the frame corresponding to the active representation.

    Note:
        The force of each link is converted using the pose of that link, so that
        in body-fixed and mixed representations it is expressed in a frame
        associated to the link itself rather than to the base.
    """

    # Compute the 6D forces applied to each collidable point expressed in the
    # inertial frame.
    with data.switch_velocity_representation(VelRepr.Inertial):
        W_f_Ci = js.contact.collidable_point_forces(model=model, data=data)

    # Construct the vector defining the parent link index of each collidable point.
    parent_link_index_of_collidable_points = jnp.array(
        model.kin_dyn_parameters.contact_parameters.body, dtype=int
    )

    # Boolean mask associating each collidable point (rows) to its parent link
    # (columns). It is used to sum the 6D forces of all collidable points rigidly
    # attached to the same link.
    mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
        model.number_of_links()
    )

    # Sum the forces of all collidable points rigidly attached to a link.
    # Since the contact forces W_f_Ci are expressed in the world frame,
    # we don't need any coordinate transformation.
    W_f_L = mask.T.astype(W_f_Ci.dtype) @ W_f_Ci

    # Poses of all links. Each link force is converted with the transform of its
    # own link, not with the base transform.
    W_H_L = js.model.forward_kinematics(model=model, data=data)

    # Convert the 6D forces to the active representation.
    f_L = jax.vmap(
        lambda W_f_Li, W_H_Li: data.inertial_to_other_representation(
            array=W_f_Li,
            other_representation=data.velocity_representation,
            transform=W_H_Li,
            is_force=True,
        )
    )(W_f_L, W_H_L)

    return f_L
1502
+
1503
+
1504
+ # ======
1505
+ # Energy
1506
+ # ======
1507
+
1508
+
1509
@jax.jit
def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the mechanical energy of the model.

    The mechanical energy is the sum of the kinetic and the potential energy.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The mechanical energy of the model.
    """

    kinetic = kinetic_energy(model=model, data=data)
    potential = potential_energy(model=model, data=data)

    total = kinetic + potential
    return total.astype(float)
1526
+
1527
+
1528
@jax.jit
def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the kinetic energy of the model.

    The energy is evaluated as the quadratic form ``0.5 * ν^T M ν`` using the
    body-fixed generalized velocity and mass matrix, regardless of the active
    representation of ``data``.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The kinetic energy of the model.
    """

    with data.switch_velocity_representation(velocity_representation=VelRepr.Body):
        ν = data.generalized_velocity()
        M = free_floating_mass_matrix(model=model, data=data)

    energy = 0.5 * ν.T @ M @ ν
    return energy.squeeze().astype(float)
1547
+
1548
+
1549
@jax.jit
def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the potential energy of the model.

    The energy is computed as ``-[g, 0] · (m [p_CoM, 1])`` using homogeneous
    coordinates for the gravity vector and the CoM position.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The potential energy of the model.
    """

    m = total_mass(model=model)
    g = data.gravity.squeeze()
    W_p_CoM = js.com.com_position(model=model, data=data)

    # Homogeneous representations of the gravity vector and the CoM position.
    g̃ = jnp.hstack([g, 0])
    W_p̃_CoM = jnp.hstack([W_p_CoM, 1])

    energy = -g̃ @ (m * W_p̃_CoM)
    return energy.squeeze().astype(float)
1568
+
1569
+
1570
+ # ==========
1571
+ # Simulation
1572
+ # ==========
1573
+
1574
+
1575
@jax.jit
def step(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    dt: jtp.FloatLike,
    integrator: jaxsim.integrators.Integrator,
    integrator_state: dict[str, Any] | None = None,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    **kwargs,
) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
    """
    Perform a simulation step.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        dt: The time step to consider.
        integrator: The integrator to use.
        integrator_state: The state of the integrator.
        joint_forces: The joint forces to consider.
        link_forces:
            The link 6D forces to consider.
            The frame in which they are expressed must be `data.velocity_representation`.
        kwargs:
            Additional kwargs to pass to the integrator. They take precedence over
            ``joint_forces`` and ``link_forces`` if they use the same names.

    Returns:
        A tuple containing the new data of the model
        and the new state of the integrator.
    """

    # `**kwargs` is always a (possibly empty) dict.
    integrator_kwargs = kwargs
    integrator_state = integrator_state if integrator_state is not None else dict()

    # Extract the initial resources.
    t0_ns = data.time_ns
    state_x0 = data.state
    integrator_state_x0 = integrator_state

    # Step the dynamics forward.
    state_xf, integrator_state_xf = integrator.step(
        x0=state_x0,
        t0=jnp.array(t0_ns / 1e9).astype(float),
        dt=dt,
        params=integrator_state_x0,
        **(
            dict(joint_forces=joint_forces, link_forces=link_forces) | integrator_kwargs
        ),
    )

    # Convert the time step to integer nanoseconds rounding to the nearest value.
    # A plain cast truncates: e.g. a single-precision dt=1e-4 gives 99999.997 ns,
    # which would become 99999 ns and make the simulation time drift every step.
    dt_ns = jnp.round(jnp.array(dt * 1e9)).astype(jnp.uint64)

    return (
        # Store the new state of the model and the new time.
        data.replace(
            state=state_xf,
            time_ns=t0_ns + dt_ns,
        ),
        integrator_state_xf,
    )