jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +86 -74
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/link.py +2 -2
  26. jaxsim/parsers/rod/utils.py +7 -8
  27. jaxsim/rbda/__init__.py +7 -0
  28. jaxsim/rbda/aba.py +295 -0
  29. jaxsim/rbda/collidable_points.py +142 -0
  30. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  31. jaxsim/rbda/forward_kinematics.py +113 -0
  32. jaxsim/rbda/jacobian.py +201 -0
  33. jaxsim/rbda/rnea.py +237 -0
  34. jaxsim/rbda/soft_contacts.py +296 -0
  35. jaxsim/rbda/utils.py +152 -0
  36. jaxsim/terrain/__init__.py +2 -0
  37. jaxsim/utils/__init__.py +1 -4
  38. jaxsim/utils/hashless.py +18 -0
  39. jaxsim/utils/jaxsim_dataclass.py +281 -30
  40. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  41. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  42. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  43. jaxsim/high_level/__init__.py +0 -2
  44. jaxsim/high_level/common.py +0 -11
  45. jaxsim/high_level/joint.py +0 -148
  46. jaxsim/high_level/link.py +0 -259
  47. jaxsim/high_level/model.py +0 -1686
  48. jaxsim/math/conv.py +0 -114
  49. jaxsim/math/joint.py +0 -102
  50. jaxsim/math/plucker.py +0 -100
  51. jaxsim/physics/__init__.py +0 -12
  52. jaxsim/physics/algos/__init__.py +0 -0
  53. jaxsim/physics/algos/aba.py +0 -254
  54. jaxsim/physics/algos/aba_motors.py +0 -284
  55. jaxsim/physics/algos/forward_kinematics.py +0 -79
  56. jaxsim/physics/algos/jacobian.py +0 -98
  57. jaxsim/physics/algos/rnea.py +0 -180
  58. jaxsim/physics/algos/rnea_motors.py +0 -196
  59. jaxsim/physics/algos/soft_contacts.py +0 -523
  60. jaxsim/physics/algos/utils.py +0 -69
  61. jaxsim/physics/model/__init__.py +0 -0
  62. jaxsim/physics/model/ground_contact.py +0 -53
  63. jaxsim/physics/model/physics_model.py +0 -388
  64. jaxsim/physics/model/physics_model_state.py +0 -283
  65. jaxsim/simulation/__init__.py +0 -4
  66. jaxsim/simulation/integrators.py +0 -393
  67. jaxsim/simulation/ode.py +0 -290
  68. jaxsim/simulation/ode_data.py +0 -96
  69. jaxsim/simulation/ode_integration.py +0 -62
  70. jaxsim/simulation/simulator.py +0 -543
  71. jaxsim/simulation/simulator_callbacks.py +0 -79
  72. jaxsim/simulation/utils.py +0 -15
  73. jaxsim/sixd/__init__.py +0 -2
  74. jaxsim/utils/oop.py +0 -536
  75. jaxsim/utils/vmappable.py +0 -117
  76. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  77. /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
  78. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  79. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/model.py CHANGED
@@ -13,15 +13,11 @@ import rod
13
13
  from jax_dataclasses import Static
14
14
 
15
15
  import jaxsim.api as js
16
- import jaxsim.physics.algos.aba
17
- import jaxsim.physics.algos.crba
18
- import jaxsim.physics.algos.forward_kinematics
19
- import jaxsim.physics.algos.rnea
20
- import jaxsim.physics.model.physics_model
16
+ import jaxsim.parsers.descriptions
21
17
  import jaxsim.typing as jtp
22
- from jaxsim.high_level.common import VelRepr
23
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
24
- from jaxsim.utils import JaxsimDataclass, Mutability
18
+ from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability
19
+
20
+ from .common import VelRepr
25
21
 
26
22
 
27
23
  @jax_dataclasses.pytree_dataclass
@@ -32,35 +28,22 @@ class JaxSimModel(JaxsimDataclass):
32
28
 
33
29
  model_name: Static[str]
34
30
 
35
- physics_model: jaxsim.physics.model.physics_model.PhysicsModel = dataclasses.field(
36
- repr=False
31
+ terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
32
+ default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
37
33
  )
38
34
 
39
- terrain: Static[Terrain] = dataclasses.field(default=FlatTerrain(), repr=False)
40
-
41
35
  built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
42
- repr=False, default=None
36
+ default=None, repr=False, compare=False, hash=False
43
37
  )
44
38
 
45
- _number_of_links: Static[int] = dataclasses.field(
46
- init=False, repr=False, default=None
47
- )
39
+ description: Static[
40
+ HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
41
+ ] = dataclasses.field(default=None, repr=False, compare=False, hash=False)
48
42
 
49
- _number_of_joints: Static[int] = dataclasses.field(
50
- init=False, repr=False, default=None
43
+ kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
44
+ dataclasses.field(default=None, repr=False, compare=False, hash=False)
51
45
  )
52
46
 
53
- def __post_init__(self):
54
-
55
- # These attributes are Static so that we can use `jax.vmap` and `jax.lax.scan`
56
- # over the all links and joints
57
- with self.mutable_context(
58
- mutability=Mutability.MUTABLE_NO_VALIDATION,
59
- restore_after_exception=False,
60
- ):
61
- self._number_of_links = len(self.physics_model.description.links_dict)
62
- self._number_of_joints = len(self.physics_model.description.joints_dict)
63
-
64
47
  # ========================
65
48
  # Initialization and state
66
49
  # ========================
@@ -69,7 +52,6 @@ class JaxSimModel(JaxsimDataclass):
69
52
  def build_from_model_description(
70
53
  model_description: str | pathlib.Path | rod.Model,
71
54
  model_name: str | None = None,
72
- gravity: jtp.Array = jaxsim.physics.default_gravity(),
73
55
  is_urdf: bool | None = None,
74
56
  considered_joints: list[str] | None = None,
75
57
  ) -> JaxSimModel:
@@ -83,7 +65,6 @@ class JaxSimModel(JaxsimDataclass):
83
65
  model_name:
84
66
  The optional name of the model that overrides the one in
85
67
  the description.
86
- gravity: The 3D gravity vector.
87
68
  is_urdf:
88
69
  Whether the model description is a URDF or an SDF. This is
89
70
  automatically inferred if the model description is a path to a file.
@@ -109,13 +90,10 @@ class JaxSimModel(JaxsimDataclass):
109
90
  considered_joints=considered_joints
110
91
  )
111
92
 
112
- # Create the physics model from the model description
113
- physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
114
- model_description=intermediate_description, gravity=gravity
115
- )
116
-
117
93
  # Build the model
118
- model = JaxSimModel.build(physics_model=physics_model, model_name=model_name)
94
+ model = JaxSimModel.build(
95
+ model_description=intermediate_description, model_name=model_name
96
+ )
119
97
 
120
98
  # Store the origin of the model, in case downstream logic needs it
121
99
  with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
@@ -125,14 +103,16 @@ class JaxSimModel(JaxsimDataclass):
125
103
 
126
104
  @staticmethod
127
105
  def build(
128
- physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
106
+ model_description: jaxsim.parsers.descriptions.ModelDescription,
129
107
  model_name: str | None = None,
130
108
  ) -> JaxSimModel:
131
109
  """
132
- Build a Model object from a physics model.
110
+ Build a Model object from an intermediate model description.
133
111
 
134
112
  Args:
135
- physics_model: The physics model.
113
+ model_description:
114
+ The intermediate model description defining the kinematics and dynamics
115
+ of the model.
136
116
  model_name:
137
117
  The optional name of the model overriding the physics model name.
138
118
 
@@ -141,12 +121,16 @@ class JaxSimModel(JaxsimDataclass):
141
121
  """
142
122
 
143
123
  # Set the model name (if not provided, use the one from the model description)
144
- model_name = (
145
- model_name if model_name is not None else physics_model.description.name
146
- )
124
+ model_name = model_name if model_name is not None else model_description.name
147
125
 
148
126
  # Build the model
149
- model = JaxSimModel(physics_model=physics_model, model_name=model_name) # noqa
127
+ model = JaxSimModel(
128
+ model_name=model_name,
129
+ description=HashlessObject(obj=model_description),
130
+ kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
131
+ model_description=model_description
132
+ ),
133
+ )
150
134
 
151
135
  return model
152
136
 
@@ -175,7 +159,7 @@ class JaxSimModel(JaxsimDataclass):
175
159
  The base link is included in the count and its index is always 0.
176
160
  """
177
161
 
178
- return self._number_of_links
162
+ return self.kin_dyn_parameters.number_of_links()
179
163
 
180
164
  def number_of_joints(self) -> jtp.Int:
181
165
  """
@@ -185,7 +169,7 @@ class JaxSimModel(JaxsimDataclass):
185
169
  The number of joints in the model.
186
170
  """
187
171
 
188
- return self._number_of_joints
172
+ return self.kin_dyn_parameters.number_of_joints()
189
173
 
190
174
  # =================
191
175
  # Base link methods
@@ -199,7 +183,7 @@ class JaxSimModel(JaxsimDataclass):
199
183
  True if the model is floating-base, False otherwise.
200
184
  """
201
185
 
202
- return self.physics_model.is_floating_base
186
+ return bool(self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6)
203
187
 
204
188
  def base_link(self) -> str:
205
189
  """
@@ -207,9 +191,12 @@ class JaxSimModel(JaxsimDataclass):
207
191
 
208
192
  Returns:
209
193
  The name of the base link.
194
+
195
+ Note:
196
+ By default, the base link is the root of the kinematic tree.
210
197
  """
211
198
 
212
- return self.physics_model.description.root.name
199
+ return self.link_names()[0]
213
200
 
214
201
  # =====================
215
202
  # Joint-related methods
@@ -227,7 +214,7 @@ class JaxSimModel(JaxsimDataclass):
227
214
  the number of joints. In the future, this could be different.
228
215
  """
229
216
 
230
- return len(self.physics_model.description.joints_dict)
217
+ return int(sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:]))
231
218
 
232
219
  def joint_names(self) -> tuple[str, ...]:
233
220
  """
@@ -237,7 +224,7 @@ class JaxSimModel(JaxsimDataclass):
237
224
  The names of the joints in the model.
238
225
  """
239
226
 
240
- return tuple(self.physics_model.description.joints_dict.keys())
227
+ return self.kin_dyn_parameters.joint_model.joint_names[1:]
241
228
 
242
229
  # ====================
243
230
  # Link-related methods
@@ -251,7 +238,7 @@ class JaxSimModel(JaxsimDataclass):
251
238
  The names of the links in the model.
252
239
  """
253
240
 
254
- return tuple(self.physics_model.description.links_dict.keys())
241
+ return self.kin_dyn_parameters.link_names
255
242
 
256
243
 
257
244
  # =====================
@@ -273,25 +260,17 @@ def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimMode
273
260
  return a copy of the input model.
274
261
  """
275
262
 
276
- if len(considered_joints) == 0:
277
- return model.copy()
278
-
279
263
  # Reduce the model description.
280
264
  # If considered_joints contains joints not existing in the model, the method
281
265
  # will raise an exception.
282
- reduced_intermediate_description = model.physics_model.description.reduce(
266
+ reduced_intermediate_description = model.description.obj.reduce(
283
267
  considered_joints=list(considered_joints)
284
268
  )
285
269
 
286
- # Create the physics model from the reduced model description
287
- physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
288
- model_description=reduced_intermediate_description,
289
- gravity=model.physics_model.gravity[0:3],
290
- )
291
-
292
270
  # Build the reduced model
293
271
  reduced_model = JaxSimModel.build(
294
- physics_model=physics_model, model_name=model.name()
272
+ model_description=reduced_intermediate_description,
273
+ model_name=model.name(),
295
274
  )
296
275
 
297
276
  # Store the origin of the model, in case downstream logic needs it
@@ -327,43 +306,21 @@ def total_mass(model: JaxSimModel) -> jtp.Float:
327
306
  )
328
307
 
329
308
 
330
- # ==============
331
- # Center of mass
332
- # ==============
333
-
334
-
335
309
  @jax.jit
336
- def com_position(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
310
+ def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:
337
311
  """
338
- Compute the position of the center of mass of the model.
312
+ Compute the spatial 6D inertia matrices of all links of the model.
339
313
 
340
314
  Args:
341
315
  model: The model to consider.
342
- data: The data of the considered model.
343
316
 
344
317
  Returns:
345
- The position of the center of mass of the model w.r.t. the world frame.
318
+ A 3D array containing the stacked spatial 6D inertia matrices of the links.
346
319
  """
347
320
 
348
- m = total_mass(model=model)
349
-
350
- W_H_L = forward_kinematics(model=model, data=data)
351
- W_H_B = data.base_transform()
352
- B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()
353
-
354
- def B_p̃_LCoM(i) -> jtp.Vector:
355
- m = js.link.mass(model=model, link_index=i)
356
- L_p_LCoM = js.link.com_position(
357
- model=model, data=data, link_index=i, in_link_frame=True
358
- )
359
- return m * B_H_W @ W_H_L[i] @ jnp.hstack([L_p_LCoM, 1])
360
-
361
- com_links = jax.vmap(B_p̃_LCoM)(jnp.arange(model.number_of_links()))
362
-
363
- B_p̃_CoM = (1 / m) * com_links.sum(axis=0)
364
- B_p̃_CoM = B_p̃_CoM.at[3].set(1)
365
-
366
- return (W_H_B @ B_p̃_CoM)[0:3].astype(float)
321
+ return jax.vmap(js.kin_dyn_parameters.LinkParameters.spatial_inertia)(
322
+ model.kin_dyn_parameters.link_parameters
323
+ )
367
324
 
368
325
 
369
326
  # ==============================
@@ -385,10 +342,11 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp
385
342
  The first axis is the link index.
386
343
  """
387
344
 
388
- W_H_LL = jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
389
- model=model.physics_model,
390
- q=data.state.physics_model.joint_positions,
391
- xfb=data.state.physics_model.xfb(),
345
+ W_H_LL = jaxsim.rbda.forward_kinematics_model(
346
+ model=model,
347
+ base_position=data.base_position(),
348
+ base_quaternion=data.base_orientation(dcm=False),
349
+ joint_positions=data.joint_positions(model=model),
392
350
  )
393
351
 
394
352
  return jnp.atleast_3d(W_H_LL).astype(float)
@@ -424,51 +382,64 @@ def generalized_free_floating_jacobian(
424
382
  output_vel_repr if output_vel_repr is not None else data.velocity_representation
425
383
  )
426
384
 
427
- # The body frame of the link.jacobian method is the link frame L.
428
- # In this method, we want instead to use the base link B as body frame.
429
- # Therefore, we always get the link jacobian having Inertial as output
430
- # representation, and then we convert it to the desired output representation.
431
- match output_vel_repr:
385
+ # Compute the doubly-left free-floating full jacobian.
386
+ B_J_full_WX_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
387
+ model=model,
388
+ joint_positions=data.joint_positions(),
389
+ )
390
+
391
+ # Update the input velocity representation such that `J_WL_I @ I_ν`.
392
+ match data.velocity_representation:
432
393
  case VelRepr.Inertial:
433
- to_output = lambda W_J_WL: W_J_WL
394
+ W_H_B = data.base_transform()
395
+ B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
396
+ B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
397
+ B_X_W, jnp.eye(model.dofs())
398
+ )
434
399
 
435
400
  case VelRepr.Body:
436
-
437
- def to_output(W_J_WL: jtp.Matrix) -> jtp.Matrix:
438
- W_H_B = data.base_transform()
439
- B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
440
- return B_X_W @ W_J_WL
401
+ B_J_full_WX_I = B_J_full_WX_B
441
402
 
442
403
  case VelRepr.Mixed:
404
+ W_R_B = data.base_orientation(dcm=True)
405
+ BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
406
+ B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
407
+ B_J_full_WX_I = B_J_full_WX_BW = (
408
+ B_J_full_WX_B
409
+ @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
410
+ )
411
+
412
+ case _:
413
+ raise ValueError(data.velocity_representation)
414
+
415
+ # Update the output velocity representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
416
+ match output_vel_repr:
417
+ case VelRepr.Inertial:
418
+ W_H_B = data.base_transform()
419
+ W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
420
+ O_J_full_WX_I = W_J_full_WX_I = W_X_B @ B_J_full_WX_I
443
421
 
444
- def to_output(W_J_WL: jtp.Matrix) -> jtp.Matrix:
445
- W_H_B = data.base_transform()
446
- W_H_BW = jnp.array(W_H_B).at[0:3, 0:3].set(jnp.eye(3))
447
- BW_X_W = jaxlie.SE3.from_matrix(W_H_BW).inverse().adjoint()
448
- return BW_X_W @ W_J_WL
422
+ case VelRepr.Body:
423
+ O_J_full_WX_I = B_J_full_WX_I
424
+
425
+ case VelRepr.Mixed:
426
+ W_R_B = data.base_orientation(dcm=True)
427
+ BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
428
+ BW_X_B = jaxlie.SE3.from_matrix(BW_H_B).adjoint()
429
+ O_J_full_WX_I = BW_J_full_WX_I = BW_X_B @ B_J_full_WX_I
449
430
 
450
431
  case _:
451
432
  raise ValueError(output_vel_repr)
452
433
 
453
- # Compute first the link jacobians having the active representation of `data`
454
- # as input representation (matching the one of ν), and inertial as output
455
- # representation (i.e. W_J_WL_C where C is C_ν).
456
- # Then, with to_output, we convert this jacobian to the desired output
457
- # representation, that can either be W (inertial), B (body), or B[W] (mixed).
458
- # This is necessary because for example the body-fixed free-floating jacobian
459
- # of a link is L_J_WL, but here being inside model we need B_J_WL.
460
- J_free_floating = jax.vmap(
461
- lambda i: to_output(
462
- W_J_WL=js.link.jacobian(
463
- model=model,
464
- data=data,
465
- link_index=i,
466
- output_vel_repr=VelRepr.Inertial,
467
- )
434
+ κ_bool = model.kin_dyn_parameters.support_body_array_bool
435
+
436
+ O_J_WL_I = jax.vmap(
437
+ lambda κ: jnp.where(
438
+ jnp.hstack([jnp.ones(5), κ]), O_J_full_WX_I, jnp.zeros_like(O_J_full_WX_I)
468
439
  )
469
- )(jnp.arange(model.number_of_links()))
440
+ )(κ_bool)
470
441
 
471
- return J_free_floating
442
+ return O_J_WL_I
472
443
 
473
444
 
474
445
  @functools.partial(jax.jit, static_argnames=["prefer_aba"])
@@ -477,7 +448,7 @@ def forward_dynamics(
477
448
  data: js.data.JaxSimModelData,
478
449
  *,
479
450
  joint_forces: jtp.VectorLike | None = None,
480
- external_forces: jtp.MatrixLike | None = None,
451
+ link_forces: jtp.MatrixLike | None = None,
481
452
  prefer_aba: float = True,
482
453
  ) -> tuple[jtp.Vector, jtp.Vector]:
483
454
  """
@@ -488,8 +459,8 @@ def forward_dynamics(
488
459
  data: The data of the considered model.
489
460
  joint_forces:
490
461
  The joint forces to consider as a vector of shape `(dofs,)`.
491
- external_forces:
492
- The external forces to consider as a matrix of shape `(nL, 6)`.
462
+ link_forces:
463
+ The link 6D forces consider as a matrix of shape `(nL, 6)`.
493
464
  The frame in which they are expressed must be `data.velocity_representation`.
494
465
  prefer_aba: Whether to prefer the ABA algorithm over the CRB one.
495
466
 
@@ -505,7 +476,7 @@ def forward_dynamics(
505
476
  model=model,
506
477
  data=data,
507
478
  joint_forces=joint_forces,
508
- external_forces=external_forces,
479
+ link_forces=link_forces,
509
480
  )
510
481
 
511
482
 
@@ -515,7 +486,7 @@ def forward_dynamics_aba(
515
486
  data: js.data.JaxSimModelData,
516
487
  *,
517
488
  joint_forces: jtp.VectorLike | None = None,
518
- external_forces: jtp.MatrixLike | None = None,
489
+ link_forces: jtp.MatrixLike | None = None,
519
490
  ) -> tuple[jtp.Vector, jtp.Vector]:
520
491
  """
521
492
  Compute the forward dynamics of the model with the ABA algorithm.
@@ -525,8 +496,8 @@ def forward_dynamics_aba(
525
496
  data: The data of the considered model.
526
497
  joint_forces:
527
498
  The joint forces to consider as a vector of shape `(dofs,)`.
528
- external_forces:
529
- The external forces to consider as a matrix of shape `(nL, 6)`.
499
+ link_forces:
500
+ The link 6D forces to consider as a matrix of shape `(nL, 6)`.
530
501
  The frame in which they are expressed must be `data.velocity_representation`.
531
502
 
532
503
  Returns:
@@ -535,63 +506,112 @@ def forward_dynamics_aba(
535
506
  considered joint forces and external forces.
536
507
  """
537
508
 
538
- # Build joint torques if not provided
509
+ # ============
510
+ # Prepare data
511
+ # ============
512
+
513
+ # Build joint forces, if not provided.
539
514
  τ = (
540
- joint_forces
515
+ jnp.atleast_1d(joint_forces.squeeze())
541
516
  if joint_forces is not None
542
517
  else jnp.zeros_like(data.joint_positions())
543
518
  )
544
519
 
545
- # Build external forces if not provided
546
- f_ext = (
547
- external_forces
548
- if external_forces is not None
520
+ # Build link forces, if not provided.
521
+ f_L = (
522
+ jnp.atleast_2d(link_forces.squeeze())
523
+ if link_forces is not None
549
524
  else jnp.zeros((model.number_of_links(), 6))
550
525
  )
551
526
 
552
- # Compute ABA
553
- W_v̇_WB, s̈ = jaxsim.physics.algos.aba.aba(
554
- model=model.physics_model,
555
- xfb=data.state.physics_model.xfb(),
556
- q=data.state.physics_model.joint_positions,
557
- qd=data.state.physics_model.joint_velocities,
558
- tau=τ,
559
- f_ext=f_ext,
527
+ # Create a references object that simplifies converting among representations.
528
+ references = js.references.JaxSimModelReferences.build(
529
+ model=model,
530
+ joint_force_references=τ,
531
+ link_forces=f_L,
532
+ data=data,
533
+ velocity_representation=data.velocity_representation,
560
534
  )
561
535
 
562
- def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC):
563
- C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
536
+ # Extract the link and joint serializations.
537
+ link_names = model.link_names()
538
+ joint_names = model.joint_names()
539
+
540
+ # Extract the state in inertial-fixed representation.
541
+ with data.switch_velocity_representation(VelRepr.Inertial):
542
+ W_p_B = data.base_position()
543
+ W_v_WB = data.base_velocity()
544
+ W_Q_B = data.base_orientation(dcm=False)
545
+ s = data.joint_positions(model=model, joint_names=joint_names)
546
+ ṡ = data.joint_velocities(model=model, joint_names=joint_names)
564
547
 
565
- if data.velocity_representation != VelRepr.Mixed:
566
- return C_X_W @ W_vd_WB
548
+ # Extract the inputs in inertial-fixed representation.
549
+ with references.switch_velocity_representation(VelRepr.Inertial):
550
+ W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
551
+ τ = references.joint_force_references(model=model, joint_names=joint_names)
552
+
553
+ # ========================
554
+ # Compute forward dynamics
555
+ # ========================
556
+
557
+ W_v̇_WB, s̈ = jaxsim.rbda.aba(
558
+ model=model,
559
+ base_position=W_p_B,
560
+ base_quaternion=W_Q_B,
561
+ joint_positions=s,
562
+ base_linear_velocity=W_v_WB[0:3],
563
+ base_angular_velocity=W_v_WB[3:6],
564
+ joint_velocities=ṡ,
565
+ joint_forces=τ,
566
+ link_forces=W_f_L,
567
+ standard_gravity=data.standard_gravity(),
568
+ )
569
+
570
+ # =============
571
+ # Adjust output
572
+ # =============
567
573
 
568
- from jaxsim.math.cross import Cross
574
+ def to_active(
575
+ W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector
576
+ ) -> jtp.Vector:
577
+ """
578
+ Helper to convert the inertial-fixed apparent base acceleration W_v̇_WB to
579
+ another representation C_v̇_WB expressed in a generic frame C.
580
+ """
569
581
 
570
- W_v_WC = jnp.hstack([W_vl_WC, jnp.zeros(3)])
571
- return C_X_W @ (W_vd_WB - Cross.vx(W_v_WC) @ W_v_WB)
582
+ from jaxsim.math import Cross
583
+
584
+ # In Mixed representation, we need to include a cross product in ℝ⁶.
585
+ # In Inertial and Body representations, the cross product is always zero.
586
+ C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
587
+ return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)
572
588
 
573
589
  match data.velocity_representation:
574
590
  case VelRepr.Inertial:
591
+ # In this case C=W
575
592
  W_H_C = W_H_W = jnp.eye(4)
576
- W_vl_WC = W_vl_WW = jnp.zeros(3)
593
+ W_v_WC = W_v_WW = jnp.zeros(6)
577
594
 
578
595
  case VelRepr.Body:
596
+ # In this case C=B
579
597
  W_H_C = W_H_B = data.base_transform()
580
- W_vl_WC = W_vl_WB = data.base_velocity()[0:3]
598
+ W_v_WC = W_v_WB
581
599
 
582
600
  case VelRepr.Mixed:
601
+ # In this case C=B[W]
583
602
  W_H_B = data.base_transform()
584
603
  W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
585
- W_vl_WC = W_vl_W_BW = data.base_velocity()[0:3]
604
+ W_ṗ_B = data.base_velocity()[0:3]
605
+ W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
586
606
 
587
607
  case _:
588
608
  raise ValueError(data.velocity_representation)
589
609
 
590
- # We need to convert the derivative of the base acceleration to the active
610
+ # We need to convert the derivative of the base velocity to the active
591
611
  # representation. In Mixed representation, this conversion is not a plain
592
612
  # transformation with just X, but it also involves a cross product in ℝ⁶.
593
613
  C_v̇_WB = to_active(
594
- W_vd_WB=W_v̇_WB.squeeze(),
614
+ W_v̇_WB=W_v̇_WB,
595
615
  W_H_C=W_H_C,
596
616
  W_v_WB=jnp.hstack(
597
617
  [
@@ -599,13 +619,16 @@ def forward_dynamics_aba(
599
619
  data.state.physics_model.base_angular_velocity,
600
620
  ]
601
621
  ),
602
- W_vl_WC=W_vl_WC,
622
+ W_v_WC=W_v_WC,
603
623
  )
604
624
 
605
- # Adjust shape
606
- s̈ = jnp.atleast_1d(s̈.squeeze())
625
+ # The ABA algorithm already returns a zero base 6D acceleration for
626
+ # fixed-based models. However, the to_active function introduces an
627
+ # additional acceleration component in Mixed representation.
628
+ # Here below we make sure that the base acceleration is zero.
629
+ C_v̇_WB = C_v̇_WB if model.floating_base() else jnp.zeros(6)
607
630
 
608
- return C_v̇_WB, s̈
631
+ return C_v̇_WB.astype(float), s̈.astype(float)
609
632
 
610
633
 
611
634
  @jax.jit
@@ -614,7 +637,7 @@ def forward_dynamics_crb(
614
637
  data: js.data.JaxSimModelData,
615
638
  *,
616
639
  joint_forces: jtp.VectorLike | None = None,
617
- external_forces: jtp.MatrixLike | None = None,
640
+ link_forces: jtp.MatrixLike | None = None,
618
641
  ) -> tuple[jtp.Vector, jtp.Vector]:
619
642
  """
620
643
  Compute the forward dynamics of the model with the CRB algorithm.
@@ -624,8 +647,8 @@ def forward_dynamics_crb(
624
647
  data: The data of the considered model.
625
648
  joint_forces:
626
649
  The joint forces to consider as a vector of shape `(dofs,)`.
627
- external_forces:
628
- The external forces to consider as a matrix of shape `(nL, 6)`.
650
+ link_forces:
651
+ The link 6D forces to consider as a matrix of shape `(nL, 6)`.
629
652
  The frame in which they are expressed must be `data.velocity_representation`.
630
653
 
631
654
  Returns:
@@ -638,6 +661,10 @@ def forward_dynamics_crb(
638
661
  models with a large number of degrees of freedom.
639
662
  """
640
663
 
664
+ # ============
665
+ # Prepare data
666
+ # ============
667
+
641
668
  # Build joint torques if not provided
642
669
  τ = (
643
670
  jnp.atleast_1d(joint_forces)
@@ -647,8 +674,8 @@ def forward_dynamics_crb(
647
674
 
648
675
  # Build external forces if not provided
649
676
  f = (
650
- jnp.atleast_2d(external_forces)
651
- if external_forces is not None
677
+ jnp.atleast_2d(link_forces)
678
+ if link_forces is not None
652
679
  else jnp.zeros(shape=(model.number_of_links(), 6))
653
680
  )
654
681
 
@@ -660,6 +687,10 @@ def forward_dynamics_crb(
660
687
 
661
688
  # TODO: invert the Mss block exploiting sparsity defined by the parent array λ(i)
662
689
 
690
+ # ========================
691
+ # Compute forward dynamics
692
+ # ========================
693
+
663
694
  if model.floating_base():
664
695
  # l: number of links.
665
696
  # g: generalized coordinates, 6 + number of joints.
@@ -675,6 +706,10 @@ def forward_dynamics_crb(
675
706
  v̇_WB = jnp.zeros(6)
676
707
  ν̇ = jnp.hstack([v̇_WB, s̈.squeeze()])
677
708
 
709
+ # =============
710
+ # Adjust output
711
+ # =============
712
+
678
713
  # Extract the base acceleration in the active representation.
679
714
  # Note that this is an apparent acceleration (relevant in Mixed representation),
680
715
  # therefore it cannot be always expressed in different frames with just a
@@ -702,9 +737,9 @@ def free_floating_mass_matrix(
702
737
  The free-floating mass matrix of the model.
703
738
  """
704
739
 
705
- M_body = jaxsim.physics.algos.crba.crba(
706
- model=model.physics_model,
707
- q=data.state.physics_model.joint_positions,
740
+ M_body = jaxsim.rbda.crba(
741
+ model=model,
742
+ joint_positions=data.state.physics_model.joint_positions,
708
743
  )
709
744
 
710
745
  match data.velocity_representation:
@@ -712,29 +747,17 @@ def free_floating_mass_matrix(
712
747
  return M_body
713
748
 
714
749
  case VelRepr.Inertial:
715
- zero_6n = jnp.zeros(shape=(6, model.dofs()))
716
- B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
717
750
 
718
- invT = jnp.vstack(
719
- [
720
- jnp.block([B_X_W, zero_6n]),
721
- jnp.block([zero_6n.T, jnp.eye(model.dofs())]),
722
- ]
723
- )
751
+ B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
752
+ invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
724
753
 
725
754
  return invT.T @ M_body @ invT
726
755
 
727
756
  case VelRepr.Mixed:
728
- zero_6n = jnp.zeros(shape=(6, model.dofs()))
729
- W_H_BW = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
730
- BW_X_W = jaxlie.SE3.from_matrix(W_H_BW).inverse().adjoint()
731
-
732
- invT = jnp.vstack(
733
- [
734
- jnp.block([BW_X_W, zero_6n]),
735
- jnp.block([zero_6n.T, jnp.eye(model.dofs())]),
736
- ]
737
- )
757
+
758
+ BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
759
+ B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
760
+ invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
738
761
 
739
762
  return invT.T @ M_body @ invT
740
763
 
@@ -747,9 +770,9 @@ def inverse_dynamics(
747
770
  model: JaxSimModel,
748
771
  data: js.data.JaxSimModelData,
749
772
  *,
750
- joint_accelerations: jtp.Vector | None = None,
751
- base_acceleration: jtp.Vector | None = None,
752
- external_forces: jtp.Matrix | None = None,
773
+ joint_accelerations: jtp.VectorLike | None = None,
774
+ base_acceleration: jtp.VectorLike | None = None,
775
+ link_forces: jtp.MatrixLike | None = None,
753
776
  ) -> tuple[jtp.Vector, jtp.Vector]:
754
777
  """
755
778
  Compute inverse dynamics with the RNEA algorithm.
@@ -761,8 +784,8 @@ def inverse_dynamics(
761
784
  The joint accelerations to consider as a vector of shape `(dofs,)`.
762
785
  base_acceleration:
763
786
  The base acceleration to consider as a vector of shape `(6,)`.
764
- external_forces:
765
- The external forces to consider as a matrix of shape `(nL, 6)`.
787
+ link_forces:
788
+ The link 6D forces to consider as a matrix of shape `(nL, 6)`.
766
789
  The frame in which they are expressed must be `data.velocity_representation`.
767
790
 
768
791
  Returns:
@@ -771,49 +794,62 @@ def inverse_dynamics(
771
794
  to obtain the considered joint accelerations.
772
795
  """
773
796
 
774
- # Build joint accelerations if not provided
775
- joint_accelerations = (
776
- joint_accelerations
797
+ # ============
798
+ # Prepare data
799
+ # ============
800
+
801
+ # Build joint accelerations, if not provided.
802
+ s̈ = (
803
+ jnp.atleast_1d(jnp.array(joint_accelerations).squeeze())
777
804
  if joint_accelerations is not None
778
805
  else jnp.zeros_like(data.joint_positions())
779
806
  )
780
807
 
781
- # Build base acceleration if not provided
782
- base_acceleration = (
783
- base_acceleration if base_acceleration is not None else jnp.zeros(6)
808
+ # Build base acceleration, if not provided.
809
+ v̇_WB = (
810
+ jnp.array(base_acceleration).squeeze()
811
+ if base_acceleration is not None
812
+ else jnp.zeros(6)
784
813
  )
785
814
 
786
- external_forces = (
787
- external_forces
788
- if external_forces is not None
815
+ # Build link forces, if not provided.
816
+ f_L = (
817
+ jnp.atleast_2d(jnp.array(link_forces).squeeze())
818
+ if link_forces is not None
789
819
  else jnp.zeros(shape=(model.number_of_links(), 6))
790
820
  )
791
821
 
792
- def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_vl_WC):
822
+ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
823
+ """
824
+ Helper to convert the active representation of the base acceleration C_v̇_WB
825
+ expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
826
+ """
827
+
828
+ from jaxsim.math import Cross
829
+
793
830
  W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
794
831
  C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
832
+ C_v_WC = C_X_W @ W_v_WC
795
833
 
796
- if data.velocity_representation != VelRepr.Mixed:
797
- return W_X_C @ C_v̇_WB
798
- else:
799
- from jaxsim.math.cross import Cross
800
-
801
- C_v_WC = C_X_W @ jnp.hstack([W_vl_WC, jnp.zeros(3)])
802
- return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB)
834
+ # In Mixed representation, we need to include a cross product in ℝ⁶.
835
+ # In Inertial and Body representations, the cross product is always zero.
836
+ return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB)
803
837
 
804
838
  match data.velocity_representation:
805
839
  case VelRepr.Inertial:
806
840
  W_H_C = W_H_W = jnp.eye(4)
807
- W_vl_WC = W_vl_WW = jnp.zeros(3)
841
+ W_v_WC = W_v_WW = jnp.zeros(6)
808
842
 
809
843
  case VelRepr.Body:
810
844
  W_H_C = W_H_B = data.base_transform()
811
- W_vl_WC = W_vl_WB = data.base_velocity()[0:3]
845
+ with data.switch_velocity_representation(VelRepr.Inertial):
846
+ W_v_WC = W_v_WB = data.base_velocity()
812
847
 
813
848
  case VelRepr.Mixed:
814
849
  W_H_B = data.base_transform()
815
850
  W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
816
- W_vl_WC = W_vl_W_BW = data.base_velocity()[0:3]
851
+ W_ṗ_B = data.base_velocity()[0:3]
852
+ W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
817
853
 
818
854
  case _:
819
855
  raise ValueError(data.velocity_representation)
@@ -822,35 +858,60 @@ def inverse_dynamics(
822
858
  # representation. In Mixed representation, this conversion is not a plain
823
859
  # transformation with just X, but it also involves a cross product in ℝ⁶.
824
860
  W_v̇_WB = to_inertial(
825
- C_v̇_WB=base_acceleration,
861
+ C_v̇_WB=v̇_WB,
826
862
  W_H_C=W_H_C,
827
863
  C_v_WB=data.base_velocity(),
828
- W_vl_WC=W_vl_WC,
864
+ W_v_WC=W_v_WC,
829
865
  )
830
866
 
867
+ # Create a references object that simplifies converting among representations.
831
868
  references = js.references.JaxSimModelReferences.build(
832
869
  model=model,
833
870
  data=data,
834
- link_forces=external_forces,
871
+ link_forces=f_L,
835
872
  velocity_representation=data.velocity_representation,
836
873
  )
837
874
 
838
- # Compute RNEA
875
+ # Extract the link and joint serializations.
876
+ link_names = model.link_names()
877
+ joint_names = model.joint_names()
878
+
879
+ # Extract the state in inertial-fixed representation.
880
+ with data.switch_velocity_representation(VelRepr.Inertial):
881
+ W_p_B = data.base_position()
882
+ W_v_WB = data.base_velocity()
883
+ W_Q_B = data.base_orientation(dcm=False)
884
+ s = data.joint_positions(model=model, joint_names=joint_names)
885
+ ṡ = data.joint_velocities(model=model, joint_names=joint_names)
886
+
887
+ # Extract the inputs in inertial-fixed representation.
839
888
  with references.switch_velocity_representation(VelRepr.Inertial):
840
- W_f_B, τ = jaxsim.physics.algos.rnea.rnea(
841
- model=model.physics_model,
842
- xfb=data.state.physics_model.xfb(),
843
- q=data.state.physics_model.joint_positions,
844
- qd=data.state.physics_model.joint_velocities,
845
- qdd=joint_accelerations,
846
- a0fb=W_v̇_WB,
847
- f_ext=references.link_forces(model=model, data=data),
848
- )
889
+ W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
890
+
891
+ # ========================
892
+ # Compute inverse dynamics
893
+ # ========================
849
894
 
850
- # Adjust shape
851
- τ = jnp.atleast_1d(τ.squeeze())
895
+ W_f_B, τ = jaxsim.rbda.rnea(
896
+ model=model,
897
+ base_position=W_p_B,
898
+ base_quaternion=W_Q_B,
899
+ joint_positions=s,
900
+ base_linear_velocity=W_v_WB[0:3],
901
+ base_angular_velocity=W_v_WB[3:6],
902
+ joint_velocities=ṡ,
903
+ base_linear_acceleration=W_v̇_WB[0:3],
904
+ base_angular_acceleration=W_v̇_WB[3:6],
905
+ joint_accelerations=s̈,
906
+ link_forces=W_f_L,
907
+ standard_gravity=data.standard_gravity(),
908
+ )
852
909
 
853
- # Express W_f_B in the active representation
910
+ # =============
911
+ # Adjust output
912
+ # =============
913
+
914
+ # Express W_f_B in the active representation.
854
915
  f_B = js.data.JaxSimModelData.inertial_to_other_representation(
855
916
  array=W_f_B,
856
917
  other_representation=data.velocity_representation,
@@ -905,7 +966,7 @@ def free_floating_gravity_forces(
905
966
  # Set zero inputs:
906
967
  joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
907
968
  base_acceleration=jnp.zeros(6),
908
- external_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
969
+ link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
909
970
  )
910
971
  ).astype(float)
911
972
 
@@ -948,18 +1009,20 @@ def free_floating_bias_forces(
948
1009
  data.state.physics_model.joint_positions
949
1010
  )
950
1011
 
951
- data_rnea.state.physics_model.base_linear_velocity = (
952
- data.state.physics_model.base_linear_velocity
953
- )
954
-
955
- data_rnea.state.physics_model.base_angular_velocity = (
956
- data.state.physics_model.base_angular_velocity
957
- )
958
-
959
1012
  data_rnea.state.physics_model.joint_velocities = (
960
1013
  data.state.physics_model.joint_velocities
961
1014
  )
962
1015
 
1016
+ # Make sure that base velocity is zero for fixed-base model.
1017
+ if model.floating_base():
1018
+ data_rnea.state.physics_model.base_linear_velocity = (
1019
+ data.state.physics_model.base_linear_velocity
1020
+ )
1021
+
1022
+ data_rnea.state.physics_model.base_angular_velocity = (
1023
+ data.state.physics_model.base_angular_velocity
1024
+ )
1025
+
963
1026
  return jnp.hstack(
964
1027
  inverse_dynamics(
965
1028
  model=model,
@@ -967,7 +1030,7 @@ def free_floating_bias_forces(
967
1030
  # Set zero inputs:
968
1031
  joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
969
1032
  base_acceleration=jnp.zeros(6),
970
- external_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
1033
+ link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
971
1034
  )
972
1035
  ).astype(float)
973
1036
 
@@ -977,6 +1040,24 @@ def free_floating_bias_forces(
977
1040
  # ==========================
978
1041
 
979
1042
 
1043
+ @jax.jit
1044
+ def locked_spatial_inertia(
1045
+ model: JaxSimModel, data: js.data.JaxSimModelData
1046
+ ) -> jtp.Matrix:
1047
+ """
1048
+ Compute the locked 6D inertia matrix of the model.
1049
+
1050
+ Args:
1051
+ model: The model to consider.
1052
+ data: The data of the considered model.
1053
+
1054
+ Returns:
1055
+ The locked 6D inertia matrix of the model.
1056
+ """
1057
+
1058
+ return total_momentum_jacobian(model=model, data=data)[:, 0:6]
1059
+
1060
+
980
1061
  @jax.jit
981
1062
  def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
982
1063
  """
@@ -987,34 +1068,221 @@ def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vec
987
1068
  data: The data of the considered model.
988
1069
 
989
1070
  Returns:
990
- The total momentum of the model.
1071
+ The total momentum of the model in the active velocity representation.
991
1072
  """
992
1073
 
993
- # Compute the momentum in body-fixed velocity representation.
994
- # Note: the first 6 rows of the mass matrix define the jacobian of the
995
- # floating-base momentum.
996
- with data.switch_velocity_representation(velocity_representation=VelRepr.Body):
997
- B_ν = data.generalized_velocity()
998
- M_B = free_floating_mass_matrix(model=model, data=data)
1074
+ ν = data.generalized_velocity()
1075
+ Jh = total_momentum_jacobian(model=model, data=data)
999
1076
 
1000
- # Compute the total momentum expressed in the base frame
1001
- B_h = M_B[0:6, :] @ B_ν
1077
+ return Jh @ ν
1002
1078
 
1003
- # Compute the 6D transformation matrix
1004
- W_H_B = data.base_transform()
1005
- B_X_W: jtp.Array = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
1006
1079
 
1007
- # Convert to inertial-fixed representation
1008
- # (its coordinates transform like 6D forces)
1009
- W_h = B_X_W.T @ B_h
1080
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
1081
+ def total_momentum_jacobian(
1082
+ model: JaxSimModel,
1083
+ data: js.data.JaxSimModelData,
1084
+ *,
1085
+ output_vel_repr: VelRepr | None = None,
1086
+ ) -> jtp.Matrix:
1087
+ """
1088
+ Compute the jacobian of the total momentum.
1010
1089
 
1011
- # Convert to the active representation of the model
1012
- return js.data.JaxSimModelData.inertial_to_other_representation(
1013
- array=W_h,
1014
- other_representation=data.velocity_representation,
1015
- transform=W_H_B,
1016
- is_force=True,
1017
- ).astype(float)
1090
+ Args:
1091
+ model: The model to consider.
1092
+ data: The data of the considered model.
1093
+ output_vel_repr: The output velocity representation of the jacobian.
1094
+
1095
+ Returns:
1096
+ The jacobian of the total momentum of the model in the active representation.
1097
+ """
1098
+
1099
+ output_vel_repr = (
1100
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
1101
+ )
1102
+
1103
+ if output_vel_repr is data.velocity_representation:
1104
+ return free_floating_mass_matrix(model=model, data=data)[0:6]
1105
+
1106
+ with data.switch_velocity_representation(VelRepr.Body):
1107
+ B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6]
1108
+
1109
+ match data.velocity_representation:
1110
+ case VelRepr.Body:
1111
+ B_Jh = B_Jh_B
1112
+
1113
+ case VelRepr.Inertial:
1114
+ B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
1115
+ B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
1116
+
1117
+ case VelRepr.Mixed:
1118
+ BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1119
+ B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
1120
+ B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
1121
+
1122
+ case _:
1123
+ raise ValueError(data.velocity_representation)
1124
+
1125
+ match output_vel_repr:
1126
+ case VelRepr.Body:
1127
+ return B_Jh
1128
+
1129
+ case VelRepr.Inertial:
1130
+ W_H_B = data.base_transform()
1131
+ B_Xv_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
1132
+ W_Xf_B = B_Xv_W.T
1133
+ W_Jh = W_Xf_B @ B_Jh
1134
+ return W_Jh
1135
+
1136
+ case VelRepr.Mixed:
1137
+ BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1138
+ B_Xv_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
1139
+ BW_Xf_B = B_Xv_BW.T
1140
+ BW_Jh = BW_Xf_B @ B_Jh
1141
+ return BW_Jh
1142
+
1143
+ case _:
1144
+ raise ValueError(output_vel_repr)
1145
+
1146
+
1147
+ @jax.jit
1148
+ def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
1149
+ """
1150
+ Compute the average velocity of the model.
1151
+
1152
+ Args:
1153
+ model: The model to consider.
1154
+ data: The data of the considered model.
1155
+
1156
+ Returns:
1157
+ The average velocity of the model computed in the base frame and expressed
1158
+ in the active representation.
1159
+ """
1160
+
1161
+ ν = data.generalized_velocity()
1162
+ J = average_velocity_jacobian(model=model, data=data)
1163
+
1164
+ return J @ ν
1165
+
1166
+
1167
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
1168
+ def average_velocity_jacobian(
1169
+ model: JaxSimModel,
1170
+ data: js.data.JaxSimModelData,
1171
+ *,
1172
+ output_vel_repr: VelRepr | None = None,
1173
+ ) -> jtp.Matrix:
1174
+ """
1175
+ Compute the Jacobian of the average velocity of the model.
1176
+
1177
+ Args:
1178
+ model: The model to consider.
1179
+ data: The data of the considered model.
1180
+ output_vel_repr: The output velocity representation of the jacobian.
1181
+
1182
+ Returns:
1183
+ The Jacobian of the average centroidal velocity of the model in the desired
1184
+ representation.
1185
+ """
1186
+
1187
+ output_vel_repr = (
1188
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
1189
+ )
1190
+
1191
+ # Depending on the velocity representation, the frame G is either G[W] or G[B].
1192
+ G_J = js.com.average_centroidal_velocity_jacobian(model=model, data=data)
1193
+
1194
+ match output_vel_repr:
1195
+
1196
+ case VelRepr.Inertial:
1197
+
1198
+ GW_J = G_J
1199
+ W_p_CoM = js.com.com_position(model=model, data=data)
1200
+
1201
+ W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
1202
+ W_X_GW = jaxlie.SE3.from_matrix(W_H_GW).adjoint()
1203
+
1204
+ return W_X_GW @ GW_J
1205
+
1206
+ case VelRepr.Body:
1207
+
1208
+ GB_J = G_J
1209
+ W_p_B = data.base_position()
1210
+ W_p_CoM = js.com.com_position(model=model, data=data)
1211
+ B_R_W = data.base_orientation(dcm=True).transpose()
1212
+
1213
+ B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))
1214
+ B_X_GB = jaxlie.SE3.from_matrix(B_H_GB).adjoint()
1215
+
1216
+ return B_X_GB @ GB_J
1217
+
1218
+ case VelRepr.Mixed:
1219
+
1220
+ GW_J = G_J
1221
+ W_p_B = data.base_position()
1222
+ W_p_CoM = js.com.com_position(model=model, data=data)
1223
+
1224
+ BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
1225
+ BW_X_GW = jaxlie.SE3.from_matrix(BW_H_GW).adjoint()
1226
+
1227
+ return BW_X_GW @ GW_J
1228
+
1229
+
1230
+ # ========================
1231
+ # Other dynamic quantities
1232
+ # ========================
1233
+
1234
+
1235
+ @jax.jit
1236
+ def link_contact_forces(
1237
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
1238
+ ) -> jtp.Matrix:
1239
+ """
1240
+ Compute the 6D contact forces of all links of the model.
1241
+
1242
+ Args:
1243
+ model: The model to consider.
1244
+ data: The data of the considered model.
1245
+
1246
+ Returns:
1247
+ A (nL, 6) array containing the stacked 6D contact forces of the links,
1248
+ expressed in the frame corresponding to the active representation.
1249
+ """
1250
+
1251
+ # Compute the 6D forces applied to each collidable point expressed in the
1252
+ # inertial frame.
1253
+ with data.switch_velocity_representation(VelRepr.Inertial):
1254
+ W_f_Ci = js.contact.collidable_point_forces(model=model, data=data)
1255
+
1256
+ # Construct the vector defining the parent link index of each collidable point.
1257
+ # We use this vector to sum the 6D forces of all collidable points rigidly
1258
+ # attached to the same link.
1259
+ parent_link_index_of_collidable_points = jnp.array(
1260
+ model.kin_dyn_parameters.contact_parameters.body, dtype=int
1261
+ )
1262
+
1263
+ # Sum the forces of all collidable points rigidly attached to a body.
1264
+ # Since the contact forces W_f_Ci are expressed in the world frame,
1265
+ # we don't need any coordinate transformation.
1266
+ W_f_Li = jax.vmap(
1267
+ lambda nc: (
1268
+ jnp.vstack(
1269
+ jnp.equal(parent_link_index_of_collidable_points, nc).astype(int)
1270
+ )
1271
+ * W_f_Ci
1272
+ ).sum(axis=0)
1273
+ )(jnp.arange(model.number_of_links()))
1274
+
1275
+ # Convert the 6D forces to the active representation.
1276
+ f_Li = jax.vmap(
1277
+ lambda W_f_L: data.inertial_to_other_representation(
1278
+ array=W_f_L,
1279
+ other_representation=data.velocity_representation,
1280
+ transform=data.base_transform(),
1281
+ is_force=True,
1282
+ )
1283
+ )(W_f_Li)
1284
+
1285
+ return f_Li
1018
1286
 
1019
1287
 
1020
1288
  # ======
@@ -1077,7 +1345,7 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F
1077
1345
 
1078
1346
  m = total_mass(model=model)
1079
1347
  gravity = data.gravity.squeeze()
1080
- W_p̃_CoM = jnp.hstack([com_position(model=model, data=data), 1])
1348
+ W_p̃_CoM = jnp.hstack([js.com.com_position(model=model, data=data), 1])
1081
1349
 
1082
1350
  U = -jnp.hstack([gravity, 0]) @ (m * W_p̃_CoM)
1083
1351
  return U.squeeze().astype(float)
@@ -1097,7 +1365,8 @@ def step(
1097
1365
  integrator: jaxsim.integrators.Integrator,
1098
1366
  integrator_state: dict[str, Any] | None = None,
1099
1367
  joint_forces: jtp.VectorLike | None = None,
1100
- external_forces: jtp.MatrixLike | None = None,
1368
+ link_forces: jtp.MatrixLike | None = None,
1369
+ **kwargs,
1101
1370
  ) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
1102
1371
  """
1103
1372
  Perform a simulation step.
@@ -1109,15 +1378,17 @@ def step(
1109
1378
  integrator: The integrator to use.
1110
1379
  integrator_state: The state of the integrator.
1111
1380
  joint_forces: The joint forces to consider.
1112
- external_forces:
1113
- The external forces to consider.
1381
+ link_forces:
1382
+ The link 6D forces to consider.
1114
1383
  The frame in which they are expressed must be `data.velocity_representation`.
1384
+ kwargs: Additional kwargs to pass to the integrator.
1115
1385
 
1116
1386
  Returns:
1117
1387
  A tuple containing the new data of the model
1118
1388
  and the new state of the integrator.
1119
1389
  """
1120
1390
 
1391
+ integrator_kwargs = kwargs if kwargs is not None else dict()
1121
1392
  integrator_state = integrator_state if integrator_state is not None else dict()
1122
1393
 
1123
1394
  # Extract the initial resources.
@@ -1128,10 +1399,12 @@ def step(
1128
1399
  # Step the dynamics forward.
1129
1400
  state_xf, integrator_state_xf = integrator.step(
1130
1401
  x0=state_x0,
1131
- t0=jnp.array(t0_ns * 1e9).astype(float),
1402
+ t0=jnp.array(t0_ns / 1e9).astype(float),
1132
1403
  dt=dt,
1133
1404
  params=integrator_state_x0,
1134
- **dict(joint_forces=joint_forces, external_forces=external_forces),
1405
+ **(
1406
+ dict(joint_forces=joint_forces, link_forces=link_forces) | integrator_kwargs
1407
+ ),
1135
1408
  )
1136
1409
 
1137
1410
  return (