jaxsim 0.4.3.dev231__py3-none-any.whl → 0.4.3.dev245__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  import abc
4
4
  import functools
5
- from typing import Any
6
5
 
7
6
  import jax
8
7
  import jax.numpy as jnp
@@ -10,6 +9,7 @@ import jax.numpy as jnp
10
9
  import jaxsim.api as js
11
10
  import jaxsim.terrain
12
11
  import jaxsim.typing as jtp
12
+ from jaxsim.api.common import ModelDataWithVelocityRepresentation
13
13
  from jaxsim.utils import JaxsimDataclass
14
14
 
15
15
  try:
@@ -131,7 +131,7 @@ class ContactModel(JaxsimDataclass):
131
131
  model: js.model.JaxSimModel,
132
132
  data: js.data.JaxSimModelData,
133
133
  **kwargs,
134
- ) -> tuple[jtp.Matrix, tuple[Any, ...]]:
134
+ ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
135
135
  """
136
136
  Compute the contact forces.
137
137
 
@@ -142,11 +142,145 @@ class ContactModel(JaxsimDataclass):
142
142
  Returns:
143
143
  A tuple containing as first element the computed 6D contact force applied to
144
144
  the contact points and expressed in the world frame, and as second element
145
- a tuple of optional additional information.
145
+ a dictionary of optional additional information.
146
146
  """
147
147
 
148
148
  pass
149
149
 
150
+ def compute_link_contact_forces(
151
+ self,
152
+ model: js.model.JaxSimModel,
153
+ data: js.data.JaxSimModelData,
154
+ **kwargs,
155
+ ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
156
+ """
157
+ Compute the link contact forces.
158
+
159
+ Args:
160
+ model: The robot model considered by the contact model.
161
+ data: The data of the considered model.
162
+
163
+ Returns:
164
+ A tuple containing as first element the 6D contact force applied to the
165
+ links and expressed in the frame of the velocity representation of data,
166
+ and as second element a dictionary of optional additional information.
167
+ """
168
+
169
+ # Compute the contact forces expressed in the inertial frame.
170
+ # This function, contrarily to `compute_contact_forces`, already handles how
171
+ # the optional kwargs should be passed to the specific contact models.
172
+ W_f_C, aux_dict = js.contact.collidable_point_dynamics(
173
+ model=model, data=data, **kwargs
174
+ )
175
+
176
+ # Compute the 6D forces applied to the links equivalent to the forces applied
177
+ # to the frames associated to the collidable points.
178
+ with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
179
+
180
+ W_f_L = self.link_forces_from_contact_forces(
181
+ model=model, data=data, contact_forces=W_f_C
182
+ )
183
+
184
+ # Store the link forces in the references object for easy conversion.
185
+ references = js.references.JaxSimModelReferences.build(
186
+ model=model,
187
+ data=data,
188
+ link_forces=W_f_L,
189
+ velocity_representation=jaxsim.VelRepr.Inertial,
190
+ )
191
+
192
+ # Convert the link forces to the frame corresponding to the velocity
193
+ # representation of data.
194
+ with references.switch_velocity_representation(data.velocity_representation):
195
+ f_L = references.link_forces(model=model, data=data)
196
+
197
+ return f_L, aux_dict
198
+
199
+ @staticmethod
200
+ def link_forces_from_contact_forces(
201
+ model: js.model.JaxSimModel,
202
+ data: js.data.JaxSimModelData,
203
+ *,
204
+ contact_forces: jtp.MatrixLike,
205
+ ) -> jtp.Matrix:
206
+ """
207
+ Compute the link forces from the contact forces.
208
+
209
+ Args:
210
+ model: The robot model considered by the contact model.
211
+ data: The data of the considered model.
212
+ contact_forces: The contact forces computed by the contact model.
213
+
214
+ Returns:
215
+ The 6D contact forces applied to the links and expressed in the frame of
216
+ the velocity representation of data.
217
+ """
218
+
219
+ # Convert the contact forces to a JAX array.
220
+ f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())
221
+
222
+ # Get the pose of the enabled collidable points.
223
+ W_H_C = js.contact.transforms(model=model, data=data)
224
+
225
+ # Convert the contact forces to inertial-fixed representation.
226
+ W_f_C = jax.vmap(
227
+ lambda f_C, W_H_C: (
228
+ ModelDataWithVelocityRepresentation.other_representation_to_inertial(
229
+ array=f_C,
230
+ other_representation=data.velocity_representation,
231
+ transform=W_H_C,
232
+ is_force=True,
233
+ )
234
+ )
235
+ )(f_C, W_H_C)
236
+
237
+ # Get the object storing the contact parameters of the model.
238
+ contact_parameters = model.kin_dyn_parameters.contact_parameters
239
+
240
+ # Extract the indices corresponding to the enabled collidable points.
241
+ indices_of_enabled_collidable_points = (
242
+ contact_parameters.indices_of_enabled_collidable_points
243
+ )
244
+
245
+ # Construct the vector defining the parent link index of each collidable point.
246
+ # We use this vector to sum the 6D forces of all collidable points rigidly
247
+ # attached to the same link.
248
+ parent_link_index_of_collidable_points = jnp.array(
249
+ contact_parameters.body, dtype=int
250
+ )[indices_of_enabled_collidable_points]
251
+
252
+ # Create the mask that associate each collidable point to their parent link.
253
+ # We use this mask to sum the collidable points to the right link.
254
+ mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
255
+ model.number_of_links()
256
+ )
257
+
258
+ # Sum the forces of all collidable points rigidly attached to a body.
259
+ # Since the contact forces W_f_C are expressed in the world frame,
260
+ # we don't need any coordinate transformation.
261
+ W_f_L = mask.T @ W_f_C
262
+
263
+ # Compute the link transforms.
264
+ W_H_L = (
265
+ js.model.forward_kinematics(model=model, data=data)
266
+ if data.velocity_representation is not jaxsim.VelRepr.Inertial
267
+ else jnp.zeros(shape=(model.number_of_links(), 4, 4))
268
+ )
269
+
270
+ # Convert the inertial-fixed link forces to the velocity representation of data.
271
+ f_L = jax.vmap(
272
+ lambda W_f_L, W_H_L: (
273
+ ModelDataWithVelocityRepresentation.inertial_to_other_representation(
274
+ array=W_f_L,
275
+ other_representation=data.velocity_representation,
276
+ transform=W_H_L,
277
+ is_force=True,
278
+ )
279
+ )
280
+ )(W_f_L, W_H_L)
281
+
282
+ return f_L
283
+
150
284
  @classmethod
151
285
  def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
152
286
  """
@@ -120,19 +120,44 @@ class RelaxedRigidContactsParams(common.ContactsParams):
120
120
 
121
121
  return cls(
122
122
  time_constant=jnp.array(
123
- time_constant or default("time_constant"), dtype=float
123
+ (
124
+ time_constant
125
+ if time_constant is not None
126
+ else default("time_constant")
127
+ ),
128
+ dtype=float,
124
129
  ),
125
130
  damping_coefficient=jnp.array(
126
- damping_coefficient or default("damping_coefficient"), dtype=float
131
+ (
132
+ damping_coefficient
133
+ if damping_coefficient is not None
134
+ else default("damping_coefficient")
135
+ ),
136
+ dtype=float,
137
+ ),
138
+ d_min=jnp.array(
139
+ d_min if d_min is not None else default("d_min"), dtype=float
140
+ ),
141
+ d_max=jnp.array(
142
+ d_max if d_max is not None else default("d_max"), dtype=float
143
+ ),
144
+ width=jnp.array(
145
+ width if width is not None else default("width"), dtype=float
146
+ ),
147
+ midpoint=jnp.array(
148
+ midpoint if midpoint is not None else default("midpoint"), dtype=float
127
149
  ),
128
- d_min=jnp.array(d_min or default("d_min"), dtype=float),
129
- d_max=jnp.array(d_max or default("d_max"), dtype=float),
130
- width=jnp.array(width or default("width"), dtype=float),
131
- midpoint=jnp.array(midpoint or default("midpoint"), dtype=float),
132
- power=jnp.array(power or default("power"), dtype=float),
133
- stiffness=jnp.array(stiffness or default("stiffness"), dtype=float),
134
- damping=jnp.array(damping or default("damping"), dtype=float),
135
- mu=jnp.array(mu or default("mu"), dtype=float),
150
+ power=jnp.array(
151
+ power if power is not None else default("power"), dtype=float
152
+ ),
153
+ stiffness=jnp.array(
154
+ stiffness if stiffness is not None else default("stiffness"),
155
+ dtype=float,
156
+ ),
157
+ damping=jnp.array(
158
+ damping if damping is not None else default("damping"), dtype=float
159
+ ),
160
+ mu=jnp.array(mu if mu is not None else default("mu"), dtype=float),
136
161
  )
137
162
 
138
163
  def valid(self) -> jtp.BoolLike:
@@ -210,7 +235,9 @@ class RelaxedRigidContacts(common.ContactModel):
210
235
 
211
236
  # Create the solver options to set by combining the default solver options
212
237
  # with the user-provided solver options.
213
- solver_options = default_solver_options | (solver_options or {})
238
+ solver_options = default_solver_options | (
239
+ solver_options if solver_options is not None else {}
240
+ )
214
241
 
215
242
  # Make sure that the solver options are hashable.
216
243
  # We need to check this because the solver options are static.
@@ -223,9 +250,15 @@ class RelaxedRigidContacts(common.ContactModel):
223
250
 
224
251
  return cls(
225
252
  parameters=(
226
- parameters or cls.__dataclass_fields__["parameters"].default_factory()
253
+ parameters
254
+ if parameters is not None
255
+ else cls.__dataclass_fields__["parameters"].default_factory()
256
+ ),
257
+ terrain=(
258
+ terrain
259
+ if terrain is not None
260
+ else cls.__dataclass_fields__["terrain"].default_factory()
227
261
  ),
228
- terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
229
262
  _solver_options_keys=tuple(solver_options.keys()),
230
263
  _solver_options_values=tuple(solver_options.values()),
231
264
  )
@@ -238,7 +271,7 @@ class RelaxedRigidContacts(common.ContactModel):
238
271
  *,
239
272
  link_forces: jtp.MatrixLike | None = None,
240
273
  joint_force_references: jtp.VectorLike | None = None,
241
- ) -> tuple[jtp.Matrix, tuple]:
274
+ ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
242
275
  """
243
276
  Compute the contact forces.
244
277
 
@@ -458,7 +491,7 @@ class RelaxedRigidContacts(common.ContactModel):
458
491
  ),
459
492
  )(CW_fl_C, W_H_C)
460
493
 
461
- return W_f_C, ()
494
+ return W_f_C, {}
462
495
 
463
496
  @staticmethod
464
497
  def _regularizers(
@@ -66,9 +66,17 @@ class RigidContactsParams(ContactsParams):
66
66
  """Create a `RigidContactParams` instance"""
67
67
 
68
68
  return cls(
69
- mu=mu or cls.__dataclass_fields__["mu"].default,
70
- K=K or cls.__dataclass_fields__["K"].default,
71
- D=D or cls.__dataclass_fields__["D"].default,
69
+ mu=jnp.array(
70
+ mu
71
+ if mu is not None
72
+ else cls.__dataclass_fields__["mu"].default_factory()
73
+ ).astype(float),
74
+ K=jnp.array(
75
+ K if K is not None else cls.__dataclass_fields__["K"].default_factory()
76
+ ).astype(float),
77
+ D=jnp.array(
78
+ D if D is not None else cls.__dataclass_fields__["D"].default_factory()
79
+ ).astype(float),
72
80
  )
73
81
 
74
82
  def valid(self) -> jtp.BoolLike:
@@ -147,7 +155,9 @@ class RigidContacts(ContactModel):
147
155
 
148
156
  # Create the solver options to set by combining the default solver options
149
157
  # with the user-provided solver options.
150
- solver_options = default_solver_options | (solver_options or {})
158
+ solver_options = default_solver_options | (
159
+ solver_options if solver_options is not None else {}
160
+ )
151
161
 
152
162
  # Make sure that the solver options are hashable.
153
163
  # We need to check this because the solver options are static.
@@ -160,12 +170,19 @@ class RigidContacts(ContactModel):
160
170
 
161
171
  return cls(
162
172
  parameters=(
163
- parameters or cls.__dataclass_fields__["parameters"].default_factory()
173
+ parameters
174
+ if parameters is not None
175
+ else cls.__dataclass_fields__["parameters"].default_factory()
176
+ ),
177
+ terrain=(
178
+ terrain
179
+ if terrain is not None
180
+ else cls.__dataclass_fields__["terrain"].default_factory()
164
181
  ),
165
- terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
166
182
  regularization_delassus=float(
167
183
  regularization_delassus
168
- or cls.__dataclass_fields__["regularization_delassus"].default
184
+ if regularization_delassus is not None
185
+ else cls.__dataclass_fields__["regularization_delassus"].default
169
186
  ),
170
187
  _solver_options_keys=tuple(solver_options.keys()),
171
188
  _solver_options_values=tuple(solver_options.values()),
@@ -242,7 +259,7 @@ class RigidContacts(ContactModel):
242
259
  *,
243
260
  link_forces: jtp.MatrixLike | None = None,
244
261
  joint_force_references: jtp.VectorLike | None = None,
245
- ) -> tuple[jtp.Matrix, tuple]:
262
+ ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
246
263
  """
247
264
  Compute the contact forces.
248
265
 
@@ -402,7 +419,7 @@ class RigidContacts(ContactModel):
402
419
  ),
403
420
  )(CW_fl_C, W_H_C)
404
421
 
405
- return W_f_C, ()
422
+ return W_f_C, {}
406
423
 
407
424
  @staticmethod
408
425
  def _delassus_matrix(
@@ -237,9 +237,13 @@ class SoftContacts(common.ContactModel):
237
237
  else cls.__dataclass_fields__["parameters"].default_factory()
238
238
  )
239
239
 
240
- return SoftContacts(
240
+ return cls(
241
241
  parameters=parameters,
242
- terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
242
+ terrain=(
243
+ terrain
244
+ if terrain is not None
245
+ else cls.__dataclass_fields__["terrain"].default_factory()
246
+ ),
243
247
  )
244
248
 
245
249
  @classmethod
@@ -423,7 +427,7 @@ class SoftContacts(common.ContactModel):
423
427
  self,
424
428
  model: js.model.JaxSimModel,
425
429
  data: js.data.JaxSimModelData,
426
- ) -> tuple[jtp.Matrix, tuple[jtp.Matrix]]:
430
+ ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
427
431
  """
428
432
  Compute the contact forces.
429
433
 
@@ -433,7 +437,7 @@ class SoftContacts(common.ContactModel):
433
437
 
434
438
  Returns:
435
439
  A tuple containing as first element the computed contact forces, and as
436
- second element the derivative of the material deformation.
440
+ second element a dictionary with derivative of the material deformation.
437
441
  """
438
442
 
439
443
  # Initialize the model and data this contact model is operating on.
@@ -460,4 +464,4 @@ class SoftContacts(common.ContactModel):
460
464
  )
461
465
  )(W_p_C, W_ṗ_C, m)
462
466
 
463
- return W_f, (ṁ,)
467
+ return W_f, dict(m_dot=ṁ)
@@ -13,6 +13,7 @@ import jaxsim.api as js
13
13
  import jaxsim.exceptions
14
14
  import jaxsim.typing as jtp
15
15
  from jaxsim import logging
16
+ from jaxsim.api.common import ModelDataWithVelocityRepresentation
16
17
  from jaxsim.math import StandardGravity
17
18
  from jaxsim.terrain import FlatTerrain, Terrain
18
19
 
@@ -235,11 +236,17 @@ class ViscoElasticContacts(common.ContactModel):
235
236
  else cls.__dataclass_fields__["parameters"].default_factory()
236
237
  )
237
238
 
238
- return ViscoElasticContacts(
239
+ return cls(
239
240
  parameters=parameters,
240
- terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
241
+ terrain=(
242
+ terrain
243
+ if terrain is not None
244
+ else cls.__dataclass_fields__["terrain"].default_factory()
245
+ ),
241
246
  max_squarings=int(
242
- max_squarings or cls.__dataclass_fields__["max_squarings"].default
247
+ max_squarings
248
+ if max_squarings is not None
249
+ else cls.__dataclass_fields__["max_squarings"].default
243
250
  ),
244
251
  )
245
252
 
@@ -266,7 +273,7 @@ class ViscoElasticContacts(common.ContactModel):
266
273
  dt: jtp.FloatLike | None = None,
267
274
  link_forces: jtp.MatrixLike | None = None,
268
275
  joint_force_references: jtp.VectorLike | None = None,
269
- ) -> tuple[jtp.Matrix, tuple[jtp.Matrix, jtp.Matrix]]:
276
+ ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
270
277
  """
271
278
  Compute the contact forces.
272
279
 
@@ -291,7 +298,7 @@ class ViscoElasticContacts(common.ContactModel):
291
298
  Returns:
292
299
  A tuple containing as first element the computed 6D contact force applied to
293
300
  the contact point and expressed in the world frame, and as second element
294
- a tuple of optional additional information.
301
+ a dictionary of optional additional information.
295
302
  """
296
303
 
297
304
  # Initialize the model and data this contact model is operating on.
@@ -315,8 +322,8 @@ class ViscoElasticContacts(common.ContactModel):
315
322
  model=model,
316
323
  data=data,
317
324
  dt=jnp.array(dt).astype(float),
318
- joint_force_references=joint_force_references,
319
325
  link_forces=link_forces,
326
+ joint_force_references=joint_force_references,
320
327
  indices_of_enabled_collidable_points=indices_of_enabled_collidable_points,
321
328
  max_squarings=self.max_squarings,
322
329
  )
@@ -334,11 +341,13 @@ class ViscoElasticContacts(common.ContactModel):
334
341
 
335
342
  # Vmapped transformation from mixed to inertial-fixed representation.
336
343
  compute_forces_inertial_fixed_vmap = jax.vmap(
337
- lambda CW_fl_C, W_H_C: data.other_representation_to_inertial(
338
- array=jnp.zeros(6).at[0:3].set(CW_fl_C),
339
- other_representation=jaxsim.VelRepr.Mixed,
340
- transform=W_H_C,
341
- is_force=True,
344
+ lambda CW_fl_C, W_H_C: (
345
+ ModelDataWithVelocityRepresentation.other_representation_to_inertial(
346
+ array=jnp.zeros(6).at[0:3].set(CW_fl_C),
347
+ other_representation=jaxsim.VelRepr.Mixed,
348
+ transform=W_H_C,
349
+ is_force=True,
350
+ )
342
351
  )
343
352
  )
344
353
 
@@ -347,7 +356,7 @@ class ViscoElasticContacts(common.ContactModel):
347
356
  lambda CW_fl: compute_forces_inertial_fixed_vmap(CW_fl, W_H_C)
348
357
  )(jnp.stack([CW_f̅l, CW_fl̿]))
349
358
 
350
- return W_f̅_C, (W_f̿_C, m_tf)
359
+ return W_f̅_C, dict(W_f_avg2_C=W_f̿_C, m_tf=m_tf)
351
360
 
352
361
  @staticmethod
353
362
  @functools.partial(jax.jit, static_argnames=("max_squarings",))
@@ -407,8 +416,8 @@ class ViscoElasticContacts(common.ContactModel):
407
416
  A, b, A_sc, b_sc = ViscoElasticContacts._contact_points_dynamics(
408
417
  model=model,
409
418
  data=data,
410
- joint_force_references=joint_force_references,
411
419
  link_forces=link_forces,
420
+ joint_force_references=joint_force_references,
412
421
  indices_of_enabled_collidable_points=indices,
413
422
  p_t0=p_t0,
414
423
  v_t0=v_t0,
@@ -657,8 +666,8 @@ class ViscoElasticContacts(common.ContactModel):
657
666
  BW_v̇_free_WB, s̈_free = js.ode.system_acceleration(
658
667
  model=model,
659
668
  data=data,
660
- joint_force_references=references.joint_force_references(model=model),
661
669
  link_forces=references.link_forces(model=model, data=data),
670
+ joint_force_references=references.joint_force_references(model=model),
662
671
  )
663
672
 
664
673
  # Pack the free system acceleration in mixed representation.
@@ -688,7 +697,20 @@ class ViscoElasticContacts(common.ContactModel):
688
697
  parameters: ViscoElasticContactsParams,
689
698
  terrain: Terrain,
690
699
  ) -> tuple[jtp.Matrix, jtp.Vector]:
691
- """"""
700
+ """
701
+ Linearize the Hunt/Crossley contact model at the initial state.
702
+
703
+ Args:
704
+ position: The position of the contact point.
705
+ velocity: The velocity of the contact point.
706
+ tangential_deformation: The tangential deformation of the contact point.
707
+ parameters: The parameters of the contact model.
708
+ terrain: The considered terrain.
709
+
710
+ Returns:
711
+ A tuple containing the `A` matrix and the `b` vector of the linear system
712
+ corresponding to the contact dynamics linearized at the initial state.
713
+ """
692
714
 
693
715
  # Initialize the state at which the model is linearized.
694
716
  p0 = jnp.array(position, dtype=float).squeeze()
@@ -969,58 +991,67 @@ def step(
969
991
  assert isinstance(model.contact_model, ViscoElasticContacts)
970
992
  assert isinstance(data.contacts_params, ViscoElasticContactsParams)
971
993
 
994
+ # Compute the contact forces in inertial-fixed representation.
995
+ # TODO: understand what's wrong in other representations.
996
+ data_inertial_fixed = data.replace(
997
+ velocity_representation=jaxsim.VelRepr.Inertial, validate=False
998
+ )
999
+
1000
+ # Create the references object.
1001
+ references = js.references.JaxSimModelReferences.build(
1002
+ model=model,
1003
+ data=data,
1004
+ link_forces=link_forces,
1005
+ joint_force_references=joint_force_references,
1006
+ velocity_representation=data.velocity_representation,
1007
+ )
1008
+
972
1009
  # Initialize the time step.
973
1010
  dt = dt if dt is not None else model.time_step
974
1011
 
975
1012
  # Compute the contact forces with the exponential integrator.
976
- W_f̅_C, (W_f̿_C, m_tf) = model.contact_model.compute_contact_forces(
1013
+ W_f̅_C, aux_data = model.contact_model.compute_contact_forces(
977
1014
  model=model,
978
- data=data,
1015
+ data=data_inertial_fixed,
979
1016
  dt=jnp.array(dt).astype(float),
980
- link_forces=link_forces,
981
- joint_force_references=joint_force_references,
1017
+ link_forces=references.link_forces(model=model, data=data),
1018
+ joint_force_references=references.joint_force_references(model=model),
982
1019
  )
983
1020
 
1021
+ # Extract the final material deformation and the average of average forces
1022
+ # from the dictionary containing auxiliary data.
1023
+ m_tf = aux_data["m_tf"]
1024
+ W_f̿_C = aux_data["W_f_avg2_C"]
1025
+
984
1026
  # ===============================
985
1027
  # Compute the link contact forces
986
1028
  # ===============================
987
1029
 
988
- # Extract the indices corresponding to the enabled collidable points.
989
- # The visco-elastic contact model computed only their contact forces.
990
- indices_of_enabled_collidable_points = (
991
- model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
992
- )
1030
+ # Get the link contact forces by summing the forces of contact points belonging
1031
+ # to the same link.
1032
+ W_f̅_L, W_f̿_L = jax.vmap(
1033
+ lambda W_f_C: model.contact_model.link_forces_from_contact_forces(
1034
+ model=model, data=data_inertial_fixed, contact_forces=W_f_C
1035
+ )
1036
+ )(jnp.stack([W_f̅_C, W_f̿_C]))
993
1037
 
994
1038
  # Compute the link transforms.
995
- W_H_L = js.model.forward_kinematics(model=model, data=data)
996
-
997
- # Construct the vector defining the parent link index of each collidable point.
998
- # We use this vector to sum the 6D forces of all collidable points rigidly
999
- # attached to the same link.
1000
- parent_link_index_of_collidable_points = jnp.array(
1001
- model.kin_dyn_parameters.contact_parameters.body, dtype=int
1002
- )[indices_of_enabled_collidable_points]
1003
-
1004
- # Create the mask that associate each collidable point to their parent link.
1005
- # We use this mask to sum the collidable points to the right link.
1006
- mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
1007
- model.number_of_links()
1039
+ W_H_L = (
1040
+ js.model.forward_kinematics(model=model, data=data)
1041
+ if data.velocity_representation is not jaxsim.VelRepr.Inertial
1042
+ else jnp.zeros(shape=(model.number_of_links(), 4, 4))
1008
1043
  )
1009
1044
 
1010
- # Sum the forces of all collidable points rigidly attached to a body.
1011
- # Since the contact forces W_f_C are expressed in the world frame,
1012
- # we don't need any coordinate transformation.
1013
- W_f̅_L = mask.T @ W_f̅_C
1014
- W_f̿_L = mask.T @ W_f̿_C
1015
-
1016
- # For integration purpose, we need these average of averages expressed in
1045
+ # For integration purpose, we need the average of average forces expressed in
1017
1046
  # mixed representation.
1018
1047
  LW_f̿_L = jax.vmap(
1019
- lambda W_f_L, W_H_L: data.inertial_to_other_representation(
1020
- array=W_f_L,
1021
- other_representation=jaxsim.VelRepr.Mixed,
1022
- transform=W_H_L,
1023
- is_force=True,
1048
+ lambda W_f_L, W_H_L: (
1049
+ ModelDataWithVelocityRepresentation.inertial_to_other_representation(
1050
+ array=W_f_L,
1051
+ other_representation=jaxsim.VelRepr.Mixed,
1052
+ transform=W_H_L,
1053
+ is_force=True,
1054
+ )
1024
1055
  )
1025
1056
  )(W_f̿_L, W_H_L)
1026
1057
 
@@ -1032,10 +1063,10 @@ def step(
1032
1063
  data_tf: js.data.JaxSimModelData = (
1033
1064
  model.contact_model.integrate_data_with_average_contact_forces(
1034
1065
  model=model,
1035
- data=data,
1066
+ data=data_inertial_fixed,
1036
1067
  dt=dt,
1037
- link_forces=link_forces,
1038
- joint_force_references=joint_force_references,
1068
+ link_forces=references.link_forces(model=model, data=data),
1069
+ joint_force_references=references.joint_force_references(model=model),
1039
1070
  average_link_contact_forces_inertial=W_f̅_L,
1040
1071
  average_of_average_link_contact_forces_mixed=LW_f̿_L,
1041
1072
  )
@@ -1046,10 +1077,21 @@ def step(
1046
1077
  # be much more accurate than the one computed with the discrete soft contacts.
1047
1078
  with data_tf.mutable_context():
1048
1079
 
1080
+ # Extract the indices corresponding to the enabled collidable points.
1081
+ # The visco-elastic contact model computed only their contact forces.
1082
+ indices_of_enabled_collidable_points = (
1083
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
1084
+ )
1085
+
1049
1086
  data_tf.state.extended |= {
1050
1087
  "tangential_deformation": data_tf.state.extended["tangential_deformation"]
1051
1088
  .at[indices_of_enabled_collidable_points]
1052
1089
  .set(m_tf)
1053
1090
  }
1054
1091
 
1092
+ # Restore the original velocity representation.
1093
+ data_tf = data_tf.replace(
1094
+ velocity_representation=data.velocity_representation, validate=False
1095
+ )
1096
+
1055
1097
  return data_tf, {}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev231
3
+ Version: 0.4.3.dev245
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>