jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
  34. jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/__init__.py CHANGED
@@ -114,5 +114,5 @@ del _get_default_logging_level
114
114
  del _is_editable
115
115
 
116
116
  from . import terrain # isort:skip
117
- from . import api, integrators, logging, math, rbda
117
+ from . import api, logging, math, rbda
118
118
  from .api.common import VelRepr
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.6.2.dev2'
16
- __version_tuple__ = version_tuple = (0, 6, 2, 'dev2')
15
+ __version__ = version = '0.6.2.dev102'
16
+ __version_tuple__ = version_tuple = (0, 6, 2, 'dev102')
jaxsim/api/__init__.py CHANGED
@@ -1,13 +1,15 @@
1
1
  from . import common # isort:skip
2
2
  from . import model, data # isort:skip
3
3
  from . import (
4
+ actuation_model,
4
5
  com,
5
6
  contact,
7
+ contact_model,
6
8
  frame,
9
+ integrators,
7
10
  joint,
8
11
  kin_dyn_parameters,
9
12
  link,
10
13
  ode,
11
- ode_data,
12
14
  references,
13
15
  )
@@ -0,0 +1,96 @@
1
+ import jax.numpy as jnp
2
+
3
+ import jaxsim.api as js
4
+ import jaxsim.typing as jtp
5
+
6
+
7
+ def compute_resultant_torques(
8
+ model: js.model.JaxSimModel,
9
+ data: js.data.JaxSimModelData,
10
+ *,
11
+ joint_force_references: jtp.Vector | None = None,
12
+ ) -> jtp.Vector:
13
+ """
14
+ Compute the resultant torques acting on the joints.
15
+
16
+ Args:
17
+ model: The model to consider.
18
+ data: The data of the considered model.
19
+ joint_force_references: The joint force references to apply.
20
+
21
+ Returns:
22
+ The resultant torques acting on the joints.
23
+ """
24
+
25
+ # Build joint torques if not provided.
26
+ τ_references = (
27
+ jnp.atleast_1d(joint_force_references.squeeze())
28
+ if joint_force_references is not None
29
+ else jnp.zeros_like(data.joint_positions)
30
+ ).astype(float)
31
+
32
+ # ====================
33
+ # Enforce joint limits
34
+ # ====================
35
+
36
+ τ_position_limit = jnp.zeros_like(τ_references).astype(float)
37
+
38
+ if model.dofs() > 0:
39
+
40
+ # Stiffness and damper parameters for the joint position limits.
41
+ k_j = jnp.array(
42
+ model.kin_dyn_parameters.joint_parameters.position_limit_spring
43
+ ).astype(float)
44
+ d_j = jnp.array(
45
+ model.kin_dyn_parameters.joint_parameters.position_limit_damper
46
+ ).astype(float)
47
+
48
+ # Compute the joint position limit violations.
49
+ lower_violation = jnp.clip(
50
+ data.joint_positions
51
+ - model.kin_dyn_parameters.joint_parameters.position_limits_min,
52
+ max=0.0,
53
+ )
54
+
55
+ upper_violation = jnp.clip(
56
+ data.joint_positions
57
+ - model.kin_dyn_parameters.joint_parameters.position_limits_max,
58
+ min=0.0,
59
+ )
60
+
61
+ # Compute the joint position limit torque.
62
+ τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)
63
+
64
+ τ_position_limit -= (
65
+ jnp.positive(τ_position_limit) * jnp.diag(d_j) @ data.joint_velocities
66
+ )
67
+
68
+ # ====================
69
+ # Joint friction model
70
+ # ====================
71
+
72
+ τ_friction = jnp.zeros_like(τ_references).astype(float)
73
+
74
+ if model.dofs() > 0:
75
+
76
+ # Static and viscous joint friction parameters
77
+ kc = jnp.array(
78
+ model.kin_dyn_parameters.joint_parameters.friction_static
79
+ ).astype(float)
80
+ kv = jnp.array(
81
+ model.kin_dyn_parameters.joint_parameters.friction_viscous
82
+ ).astype(float)
83
+
84
+ # Compute the joint friction torque.
85
+ τ_friction = -(
86
+ jnp.diag(kc) @ jnp.sign(data.joint_velocities)
87
+ + jnp.diag(kv) @ data.joint_velocities
88
+ )
89
+
90
+ # ===============================
91
+ # Compute the total joint forces.
92
+ # ===============================
93
+
94
+ τ_total = τ_references + τ_friction + τ_position_limit
95
+
96
+ return τ_total
jaxsim/api/com.py CHANGED
@@ -26,8 +26,8 @@ def com_position(
26
26
 
27
27
  m = js.model.total_mass(model=model)
28
28
 
29
- W_H_L = js.model.forward_kinematics(model=model, data=data)
30
- W_H_B = data.base_transform()
29
+ W_H_L = data._link_transforms
30
+ W_H_B = data._base_transform
31
31
  B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)
32
32
 
33
33
  def B_p̃_LCoM(i) -> jtp.Vector:
@@ -98,7 +98,7 @@ def centroidal_momentum(
98
98
  and :math:`C = B` if the active velocity representation is body-fixed.
99
99
  """
100
100
 
101
- ν = data.generalized_velocity()
101
+ ν = data.generalized_velocity
102
102
  G_J = centroidal_momentum_jacobian(model=model, data=data)
103
103
 
104
104
  return G_J @ ν
@@ -134,7 +134,7 @@ def centroidal_momentum_jacobian(
134
134
  model=model, data=data, output_vel_repr=VelRepr.Body
135
135
  )
136
136
 
137
- W_H_B = data.base_transform()
137
+ W_H_B = data._base_transform
138
138
  B_H_W = jaxsim.math.Transform.inverse(W_H_B)
139
139
 
140
140
  W_p_CoM = com_position(model=model, data=data)
@@ -172,7 +172,7 @@ def locked_centroidal_spatial_inertia(
172
172
  with data.switch_velocity_representation(VelRepr.Body):
173
173
  B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data)
174
174
 
175
- W_H_B = data.base_transform()
175
+ W_H_B = data._base_transform
176
176
  W_p_CoM = com_position(model=model, data=data)
177
177
 
178
178
  match data.velocity_representation:
@@ -213,7 +213,7 @@ def average_centroidal_velocity(
213
213
  and :math:`[C] = [B]` if the active velocity representation is body-fixed.
214
214
  """
215
215
 
216
- ν = data.generalized_velocity()
216
+ ν = data.generalized_velocity
217
217
  G_J = average_centroidal_velocity_jacobian(model=model, data=data)
218
218
 
219
219
  return G_J @ ν
@@ -269,7 +269,7 @@ def bias_acceleration(
269
269
  """
270
270
 
271
271
  # Compute the pose of all links with forward kinematics.
272
- W_H_L = js.model.forward_kinematics(model=model, data=data)
272
+ W_H_L = data._link_transforms
273
273
 
274
274
  # Compute the bias acceleration of all links by zeroing the generalized velocity
275
275
  # in the active representation.
@@ -411,7 +411,7 @@ def bias_acceleration(
411
411
  case VelRepr.Body:
412
412
 
413
413
  GB_Xf_W = jaxsim.math.Adjoint.from_transform(
414
- transform=data.base_transform().at[0:3].set(W_p_CoM)
414
+ transform=data._base_transform.at[0:3].set(W_p_CoM)
415
415
  ).T
416
416
 
417
417
  GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias
jaxsim/api/contact.py CHANGED
@@ -42,12 +42,8 @@ def collidable_point_kinematics(
42
42
 
43
43
  W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
44
44
  model=model,
45
- base_position=data.base_position(),
46
- base_quaternion=data.base_orientation(dcm=False),
47
- joint_positions=data.joint_positions(model=model),
48
- base_linear_velocity=data.base_velocity()[0:3],
49
- base_angular_velocity=data.base_velocity()[3:6],
50
- joint_velocities=data.joint_velocities(model=model),
45
+ link_transforms=data._link_transforms,
46
+ link_velocities=data._link_velocities,
51
47
  )
52
48
 
53
49
  return W_p_Ci, W_ṗ_Ci
@@ -95,143 +91,6 @@ def collidable_point_velocities(
95
91
  return W_ṗ_Ci
96
92
 
97
93
 
98
- @jax.jit
99
- @js.common.named_scope
100
- def collidable_point_forces(
101
- model: js.model.JaxSimModel,
102
- data: js.data.JaxSimModelData,
103
- link_forces: jtp.MatrixLike | None = None,
104
- joint_force_references: jtp.VectorLike | None = None,
105
- **kwargs,
106
- ) -> jtp.Matrix:
107
- """
108
- Compute the 6D forces applied to each collidable point.
109
-
110
- Args:
111
- model: The model to consider.
112
- data: The data of the considered model.
113
- link_forces:
114
- The 6D external forces to apply to the links expressed in the same
115
- representation of data.
116
- joint_force_references:
117
- The joint force references to apply to the joints.
118
- kwargs: Additional keyword arguments to pass to the active contact model.
119
-
120
- Returns:
121
- The 6D forces applied to each collidable point expressed in the frame
122
- corresponding to the active representation.
123
- """
124
-
125
- f_Ci, _ = collidable_point_dynamics(
126
- model=model,
127
- data=data,
128
- link_forces=link_forces,
129
- joint_force_references=joint_force_references,
130
- **kwargs,
131
- )
132
-
133
- return f_Ci
134
-
135
-
136
- @jax.jit
137
- @js.common.named_scope
138
- def collidable_point_dynamics(
139
- model: js.model.JaxSimModel,
140
- data: js.data.JaxSimModelData,
141
- link_forces: jtp.MatrixLike | None = None,
142
- joint_force_references: jtp.VectorLike | None = None,
143
- **kwargs,
144
- ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
145
- r"""
146
- Compute the 6D force applied to each enabled collidable point.
147
-
148
- Args:
149
- model: The model to consider.
150
- data: The data of the considered model.
151
- link_forces:
152
- The 6D external forces to apply to the links expressed in the same
153
- representation of data.
154
- joint_force_references:
155
- The joint force references to apply to the joints.
156
- kwargs: Additional keyword arguments to pass to the active contact model.
157
-
158
- Returns:
159
- The 6D force applied to each enabled collidable point and additional data based
160
- on the contact model configured:
161
- - Soft: the material deformation rate.
162
- - Rigid: no additional data.
163
- - QuasiRigid: no additional data.
164
-
165
- Note:
166
- The material deformation rate is always returned in the mixed frame
167
- `C[W] = ({}^W \mathbf{p}_C, [W])`. This is convenient for integration purpose.
168
- Instead, the 6D forces are returned in the active representation.
169
- """
170
-
171
- # Build the common kw arguments to pass to the computation of the contact forces.
172
- common_kwargs = dict(
173
- link_forces=link_forces,
174
- joint_force_references=joint_force_references,
175
- )
176
-
177
- # Build the additional kwargs to pass to the computation of the contact forces.
178
- match model.contact_model:
179
-
180
- case contacts.SoftContacts():
181
-
182
- kwargs_contact_model = {}
183
-
184
- case contacts.RigidContacts():
185
-
186
- kwargs_contact_model = common_kwargs | kwargs
187
-
188
- case contacts.RelaxedRigidContacts():
189
-
190
- kwargs_contact_model = common_kwargs | kwargs
191
-
192
- case contacts.ViscoElasticContacts():
193
-
194
- kwargs_contact_model = common_kwargs | dict(dt=model.time_step) | kwargs
195
-
196
- case _:
197
- raise ValueError(f"Invalid contact model: {model.contact_model}")
198
-
199
- # Compute the contact forces with the active contact model.
200
- W_f_C, aux_data = model.contact_model.compute_contact_forces(
201
- model=model,
202
- data=data,
203
- **kwargs_contact_model,
204
- )
205
-
206
- # Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
207
- # associated to the enabled collidable point.
208
- # In inertial-fixed representation, the computation of these transforms
209
- # is not necessary and the conversion below becomes a no-op.
210
-
211
- # Get the indices of the enabled collidable points.
212
- indices_of_enabled_collidable_points = (
213
- model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
214
- )
215
-
216
- W_H_C = (
217
- js.contact.transforms(model=model, data=data)
218
- if data.velocity_representation is not VelRepr.Inertial
219
- else jnp.stack([jnp.eye(4)] * len(indices_of_enabled_collidable_points))
220
- )
221
-
222
- # Convert the 6D forces to the active representation.
223
- f_Ci = jax.vmap(
224
- lambda W_f_C, W_H_C: data.inertial_to_other_representation(
225
- array=W_f_C,
226
- other_representation=data.velocity_representation,
227
- transform=W_H_C,
228
- is_force=True,
229
- )
230
- )(W_f_C, W_H_C)
231
-
232
- return f_Ci, aux_data
233
-
234
-
235
94
  @functools.partial(jax.jit, static_argnames=["link_names"])
236
95
  @js.common.named_scope
237
96
  def in_contact(
@@ -305,11 +164,7 @@ def estimate_good_soft_contacts_parameters(
305
164
  def estimate_good_contact_parameters(
306
165
  model: js.model.JaxSimModel,
307
166
  *,
308
- standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
309
167
  static_friction_coefficient: jtp.FloatLike = 0.5,
310
- number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
311
- damping_ratio: jtp.FloatLike = 1.0,
312
- max_penetration: jtp.FloatLike | None = None,
313
168
  **kwargs,
314
169
  ) -> jaxsim.rbda.contacts.ContactParamsTypes:
315
170
  """
@@ -317,15 +172,7 @@ def estimate_good_contact_parameters(
317
172
 
318
173
  Args:
319
174
  model: The model to consider.
320
- standard_gravity: The standard gravity constant.
321
175
  static_friction_coefficient: The static friction coefficient.
322
- number_of_active_collidable_points_steady_state:
323
- The number of active collidable points in steady state supporting
324
- the weight of the robot.
325
- damping_ratio: The damping ratio.
326
- max_penetration:
327
- The maximum penetration allowed in steady state when the robot is
328
- supported by the configured number of active collidable points.
329
176
  kwargs:
330
177
  Additional model-specific parameters passed to the builder method of
331
178
  the parameters class.
@@ -343,82 +190,8 @@ def estimate_good_contact_parameters(
343
190
  specific application.
344
191
  """
345
192
 
346
- def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
347
- """
348
- Displacement between the CoM and the lowest collidable point using zero
349
- joint positions.
350
- """
351
-
352
- zero_data = js.data.JaxSimModelData.build(
353
- model=model,
354
- contacts_params=jaxsim.rbda.contacts.SoftContactsParams(),
355
- )
356
-
357
- W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
358
-
359
- if model.floating_base():
360
- W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
361
- return 2 * (W_pz_CoM - W_pz_C.min())
362
-
363
- return 2 * W_pz_CoM
364
-
365
- max_δ = (
366
- max_penetration
367
- if max_penetration is not None
368
- # Consider as default a 0.5% of the model height.
369
- else 0.005 * estimate_model_height(model=model)
370
- )
371
-
372
- nc = number_of_active_collidable_points_steady_state
373
-
374
193
  match model.contact_model:
375
194
 
376
- case contacts.SoftContacts():
377
- assert isinstance(model.contact_model, contacts.SoftContacts)
378
-
379
- parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model(
380
- model=model,
381
- standard_gravity=standard_gravity,
382
- static_friction_coefficient=static_friction_coefficient,
383
- max_penetration=max_δ,
384
- number_of_active_collidable_points_steady_state=nc,
385
- damping_ratio=damping_ratio,
386
- **kwargs,
387
- )
388
-
389
- case contacts.ViscoElasticContacts():
390
- assert isinstance(model.contact_model, contacts.ViscoElasticContacts)
391
-
392
- parameters = (
393
- contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model(
394
- model=model,
395
- standard_gravity=standard_gravity,
396
- static_friction_coefficient=static_friction_coefficient,
397
- max_penetration=max_δ,
398
- number_of_active_collidable_points_steady_state=nc,
399
- damping_ratio=damping_ratio,
400
- **kwargs,
401
- )
402
- )
403
-
404
- case contacts.RigidContacts():
405
- assert isinstance(model.contact_model, contacts.RigidContacts)
406
-
407
- # Disable Baumgarte stabilization by default since it does not play
408
- # well with the forward Euler integrator.
409
- K = kwargs.get("K", 0.0)
410
-
411
- parameters = contacts.RigidContactsParams.build(
412
- mu=static_friction_coefficient,
413
- **(
414
- dict(
415
- K=K,
416
- D=2 * jnp.sqrt(K),
417
- )
418
- | kwargs
419
- ),
420
- )
421
-
422
195
  case contacts.RelaxedRigidContacts():
423
196
  assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
424
197
 
@@ -463,9 +236,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
463
236
  )[indices_of_enabled_collidable_points]
464
237
 
465
238
  # Get the transforms of the parent link of all collidable points.
466
- W_H_L = js.model.forward_kinematics(model=model, data=data)[
467
- parent_link_idx_of_enabled_collidable_points
468
- ]
239
+ W_H_L = data._link_transforms[parent_link_idx_of_enabled_collidable_points]
469
240
 
470
241
  L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
471
242
  indices_of_enabled_collidable_points
@@ -615,7 +386,10 @@ def jacobian_derivative(
615
386
  ]
616
387
 
617
388
  # Get the transforms of all the parent links.
618
- W_H_Li = js.model.forward_kinematics(model=model, data=data)
389
+ W_H_Li = data._link_transforms
390
+
391
+ # Get the link velocities.
392
+ W_v_WLi = data._link_velocities
619
393
 
620
394
  # =====================================================
621
395
  # Compute quantities to adjust the input representation
@@ -643,9 +417,9 @@ def jacobian_derivative(
643
417
  Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)
644
418
 
645
419
  case VelRepr.Body:
646
- W_H_B = data.base_transform()
420
+ W_H_B = data._base_transform
647
421
  W_X_B = Adjoint.from_transform(transform=W_H_B)
648
- B_v_WB = data.base_velocity()
422
+ B_v_WB = data.base_velocity
649
423
  B_vx_WB = Cross.vx(B_v_WB)
650
424
  W_Ẋ_B = W_X_B @ B_vx_WB
651
425
 
@@ -653,10 +427,10 @@ def jacobian_derivative(
653
427
  Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)
654
428
 
655
429
  case VelRepr.Mixed:
656
- W_H_B = data.base_transform()
430
+ W_H_B = data._base_transform
657
431
  W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
658
432
  W_X_BW = Adjoint.from_transform(transform=W_H_BW)
659
- BW_v_WB = data.base_velocity()
433
+ BW_v_WB = data.base_velocity
660
434
  BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
661
435
  BW_vx_W_BW = Cross.vx(BW_v_W_BW)
662
436
  W_Ẋ_BW = W_X_BW @ BW_vx_W_BW
@@ -676,27 +450,16 @@ def jacobian_derivative(
676
450
  W_J_WL_W = js.model.generalized_free_floating_jacobian(
677
451
  model=model,
678
452
  data=data,
679
- output_vel_repr=VelRepr.Inertial,
680
453
  )
681
454
  # Compute the Jacobian derivative of the parent link in inertial representation.
682
455
  W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
683
456
  model=model,
684
457
  data=data,
685
- output_vel_repr=VelRepr.Inertial,
686
- )
687
-
688
- # Get the Jacobian of the enabled collidable points in the mixed representation.
689
- with data.switch_velocity_representation(VelRepr.Mixed):
690
- CW_J_WC_BW = jacobian(
691
- model=model,
692
- data=data,
693
- output_vel_repr=VelRepr.Mixed,
694
458
  )
695
459
 
696
460
  def compute_O_J̇_WC_I(
697
461
  L_p_C: jtp.Vector,
698
462
  parent_link_idx: jtp.Int,
699
- CW_J_WC_BW: jtp.Matrix,
700
463
  W_H_L: jtp.Matrix,
701
464
  ) -> jtp.Matrix:
702
465
 
@@ -711,9 +474,7 @@ def jacobian_derivative(
711
474
  L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
712
475
  W_H_C = W_H_L[parent_link_idx] @ L_H_C
713
476
  O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
714
- with data.switch_velocity_representation(VelRepr.Inertial):
715
- W_nu = data.generalized_velocity()
716
- W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu
477
+ W_v_WC = W_v_WLi[parent_link_idx]
717
478
  W_vx_WC = Cross.vx(W_v_WC)
718
479
  O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC # noqa: F841
719
480
 
@@ -723,8 +484,7 @@ def jacobian_derivative(
723
484
  W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
724
485
  CW_H_W = Transform.inverse(W_H_CW)
725
486
  O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W)
726
- with data.switch_velocity_representation(VelRepr.Mixed):
727
- CW_v_WC = CW_J_WC_BW @ data.generalized_velocity()
487
+ CW_v_WC = CW_X_W @ W_v_WLi[parent_link_idx]
728
488
  W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])
729
489
  W_vx_W_CW = Cross.vx(W_v_W_CW)
730
490
  O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW # noqa: F841
@@ -739,8 +499,8 @@ def jacobian_derivative(
739
499
 
740
500
  return O_J̇_WC_I
741
501
 
742
- O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, 0, None))(
743
- L_p_Ci, parent_link_idx_of_enabled_collidable_points, CW_J_WC_BW, W_H_Li
502
+ O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, None))(
503
+ L_p_Ci, parent_link_idx_of_enabled_collidable_points, W_H_Li
744
504
  )
745
505
 
746
506
  return O_J̇_WC
@@ -0,0 +1,101 @@
1
+ from __future__ import annotations
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
+ import jaxsim.api as js
7
+ import jaxsim.typing as jtp
8
+
9
+
10
+ @jax.jit
11
+ @js.common.named_scope
12
+ def link_contact_forces(
13
+ model: js.model.JaxSimModel,
14
+ data: js.data.JaxSimModelData,
15
+ *,
16
+ link_forces: jtp.MatrixLike | None = None,
17
+ joint_torques: jtp.VectorLike | None = None,
18
+ ) -> jtp.Matrix:
19
+ """
20
+ Compute the 6D contact forces of all links of the model in inertial representation.
21
+
22
+ Args:
23
+ model: The model to consider.
24
+ data: The data of the considered model.
25
+ link_forces:
26
+ The 6D external forces to apply to the links expressed in inertial representation
27
+ joint_torques:
28
+ The joint torques acting on the joints.
29
+
30
+ Returns:
31
+ A `(nL, 6)` array containing the stacked 6D contact forces of the links,
32
+ expressed in inertial representation.
33
+ """
34
+
35
+ # Compute the contact forces for each collidable point with the active contact model.
36
+ W_f_C, _ = model.contact_model.compute_contact_forces(
37
+ model=model,
38
+ data=data,
39
+ link_forces=link_forces,
40
+ joint_force_references=joint_torques,
41
+ )
42
+
43
+ # Compute the 6D forces applied to the links equivalent to the forces applied
44
+ # to the frames associated to the collidable points.
45
+ W_f_L = link_forces_from_contact_forces(
46
+ model=model, data=data, contact_forces=W_f_C
47
+ )
48
+
49
+ return W_f_L
50
+
51
+
52
+ @staticmethod
53
+ def link_forces_from_contact_forces(
54
+ model: js.model.JaxSimModel,
55
+ data: js.data.JaxSimModelData,
56
+ *,
57
+ contact_forces: jtp.MatrixLike,
58
+ ) -> jtp.Matrix:
59
+ """
60
+ Compute the link forces from the contact forces.
61
+
62
+ Args:
63
+ model: The robot model considered by the contact model.
64
+ data: The data of the considered model.
65
+ contact_forces: The contact forces computed by the contact model.
66
+
67
+ Returns:
68
+ The 6D contact forces applied to the links and expressed in the frame of
69
+ the velocity representation of data.
70
+ """
71
+
72
+ # Get the object storing the contact parameters of the model.
73
+ contact_parameters = model.kin_dyn_parameters.contact_parameters
74
+
75
+ # Extract the indices corresponding to the enabled collidable points.
76
+ indices_of_enabled_collidable_points = (
77
+ contact_parameters.indices_of_enabled_collidable_points
78
+ )
79
+
80
+ # Convert the contact forces to a JAX array.
81
+ W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())
82
+
83
+ # Construct the vector defining the parent link index of each collidable point.
84
+ # We use this vector to sum the 6D forces of all collidable points rigidly
85
+ # attached to the same link.
86
+ parent_link_index_of_collidable_points = jnp.array(
87
+ contact_parameters.body, dtype=int
88
+ )[indices_of_enabled_collidable_points]
89
+
90
+ # Create the mask that associate each collidable point to their parent link.
91
+ # We use this mask to sum the collidable points to the right link.
92
+ mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
93
+ model.number_of_links()
94
+ )
95
+
96
+ # Sum the forces of all collidable points rigidly attached to a body.
97
+ # Since the contact forces W_f_C are expressed in the world frame,
98
+ # we don't need any coordinate transformation.
99
+ W_f_L = mask.T @ W_f_C
100
+
101
+ return W_f_L