jaxsim 0.4.3.dev115__py3-none-any.whl → 0.4.3.dev129__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev115'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev115')
15
+ __version__ = version = '0.4.3.dev129'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev129')
jaxsim/api/contact.py CHANGED
@@ -9,7 +9,6 @@ import jaxsim.api as js
9
9
  import jaxsim.terrain
10
10
  import jaxsim.typing as jtp
11
11
  from jaxsim.math import Adjoint, Cross, Transform
12
- from jaxsim.rbda.contacts.soft import SoftContactsParams
13
12
 
14
13
  from .common import VelRepr
15
14
 
@@ -156,56 +155,43 @@ def collidable_point_dynamics(
156
155
  Instead, the 6D forces are returned in the active representation.
157
156
  """
158
157
 
159
- # Compute the position and linear velocities (mixed representation) of
160
- # all collidable points belonging to the robot.
161
- W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
162
-
163
158
  # Import privately the contacts classes.
164
- from jaxsim.rbda.contacts.relaxed_rigid import (
159
+ from jaxsim.rbda.contacts import (
165
160
  RelaxedRigidContacts,
166
161
  RelaxedRigidContactsState,
162
+ RigidContacts,
163
+ RigidContactsState,
164
+ SoftContacts,
165
+ SoftContactsState,
167
166
  )
168
- from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
169
- from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
170
167
 
171
168
  # Build the soft contact model.
172
169
  match model.contact_model:
173
170
 
174
171
  case SoftContacts():
175
-
176
172
  assert isinstance(model.contact_model, SoftContacts)
177
173
  assert isinstance(data.state.contact, SoftContactsState)
178
174
 
179
- # Build the contact model.
180
- soft_contacts = SoftContacts(
181
- parameters=data.contacts_params, terrain=model.terrain
182
- )
183
-
184
175
  # Compute the 6D force expressed in the inertial frame and applied to each
185
176
  # collidable point, and the corresponding material deformation rate.
186
177
  # Note that the material deformation rate is always returned in the mixed frame
187
178
  # C[W] = (W_p_C, [W]). This is convenient for integration purpose.
188
- W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
189
- position=W_p_Ci,
190
- velocity=W_ṗ_Ci,
191
- tangential_deformation=data.state.contact.tangential_deformation,
179
+ W_f_Ci, (CW_ṁ,) = model.contact_model.compute_contact_forces(
180
+ model=model, data=data
192
181
  )
182
+
183
+ # Create the dictionary of auxiliary data.
184
+ # This contact model considers the material deformation as additional state
185
+ # of the ODE system. We need to pass its dynamics to the integrator.
193
186
  aux_data = dict(m_dot=CW_ṁ)
194
187
 
195
188
  case RigidContacts():
196
189
  assert isinstance(model.contact_model, RigidContacts)
197
190
  assert isinstance(data.state.contact, RigidContactsState)
198
191
 
199
- # Build the contact model.
200
- rigid_contacts = RigidContacts(
201
- parameters=data.contacts_params, terrain=model.terrain
202
- )
203
-
204
192
  # Compute the 6D force expressed in the inertial frame and applied to each
205
193
  # collidable point.
206
- W_f_Ci, _ = rigid_contacts.compute_contact_forces(
207
- position=W_p_Ci,
208
- velocity=W_ṗ_Ci,
194
+ W_f_Ci, _ = model.contact_model.compute_contact_forces(
209
195
  model=model,
210
196
  data=data,
211
197
  link_forces=link_forces,
@@ -219,16 +205,9 @@ def collidable_point_dynamics(
219
205
  assert isinstance(model.contact_model, RelaxedRigidContacts)
220
206
  assert isinstance(data.state.contact, RelaxedRigidContactsState)
221
207
 
222
- # Build the contact model.
223
- relaxed_rigid_contacts = RelaxedRigidContacts(
224
- parameters=data.contacts_params, terrain=model.terrain
225
- )
226
-
227
208
  # Compute the 6D force expressed in the inertial frame and applied to each
228
209
  # collidable point.
229
- W_f_Ci, _ = relaxed_rigid_contacts.compute_contact_forces(
230
- position=W_p_Ci,
231
- velocity=W_ṗ_Ci,
210
+ W_f_Ci, _ = model.contact_model.compute_contact_forces(
232
211
  model=model,
233
212
  data=data,
234
213
  link_forces=link_forces,
@@ -318,7 +297,7 @@ def estimate_good_soft_contacts_parameters(
318
297
  number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
319
298
  damping_ratio: jtp.FloatLike = 1.0,
320
299
  max_penetration: jtp.FloatLike | None = None,
321
- ) -> SoftContactsParams:
300
+ ) -> jaxsim.rbda.contacts.SoftContactsParams:
322
301
  """
323
302
  Estimate good soft contacts parameters for the given model.
324
303
 
@@ -342,14 +321,13 @@ def estimate_good_soft_contacts_parameters(
342
321
  The user is encouraged to fine-tune the parameters based on the
343
322
  specific application.
344
323
  """
345
- from jaxsim.rbda.contacts.soft import SoftContactsParams
346
324
 
347
325
  def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
348
326
  """"""
349
327
 
350
328
  zero_data = js.data.JaxSimModelData.build(
351
329
  model=model,
352
- contacts_params=SoftContactsParams(),
330
+ contacts_params=jaxsim.rbda.contacts.SoftContactsParams(),
353
331
  )
354
332
 
355
333
  W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
@@ -368,16 +346,26 @@ def estimate_good_soft_contacts_parameters(
368
346
 
369
347
  nc = number_of_active_collidable_points_steady_state
370
348
 
371
- sc_parameters = SoftContactsParams.build_default_from_jaxsim_model(
372
- model=model,
373
- standard_gravity=standard_gravity,
374
- static_friction_coefficient=static_friction_coefficient,
375
- max_penetration=max_δ,
376
- number_of_active_collidable_points_steady_state=nc,
377
- damping_ratio=damping_ratio,
378
- )
349
+ match model.contact_model:
350
+
351
+ case jaxsim.rbda.contacts.SoftContacts():
352
+ assert isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts)
353
+
354
+ parameters = (
355
+ jaxsim.rbda.contacts.SoftContactsParams.build_default_from_jaxsim_model(
356
+ model=model,
357
+ standard_gravity=standard_gravity,
358
+ static_friction_coefficient=static_friction_coefficient,
359
+ max_penetration=max_δ,
360
+ number_of_active_collidable_points_steady_state=nc,
361
+ damping_ratio=damping_ratio,
362
+ )
363
+ )
364
+
365
+ case _:
366
+ parameters = model.contact_model.parameters
379
367
 
380
- return sc_parameters
368
+ return parameters
381
369
 
382
370
 
383
371
  @jax.jit
jaxsim/api/data.py CHANGED
@@ -13,7 +13,7 @@ import jaxsim.api as js
13
13
  import jaxsim.math
14
14
  import jaxsim.rbda
15
15
  import jaxsim.typing as jtp
16
- from jaxsim.rbda.contacts.soft import SoftContacts
16
+ from jaxsim.rbda.contacts import SoftContacts
17
17
  from jaxsim.utils import Mutability
18
18
  from jaxsim.utils.tracing import not_tracing
19
19
 
@@ -37,7 +37,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
37
37
 
38
38
  gravity: jtp.Array
39
39
 
40
- contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)
40
+ contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
41
41
 
42
42
  time_ns: jtp.Int = dataclasses.field(
43
43
  default_factory=lambda: jnp.array(
@@ -114,8 +114,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
114
114
  base_angular_velocity: jtp.Vector | None = None,
115
115
  joint_velocities: jtp.Vector | None = None,
116
116
  standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
117
- contact: jaxsim.rbda.ContactsState | None = None,
118
- contacts_params: jaxsim.rbda.ContactsParams | None = None,
117
+ contact: jaxsim.rbda.contacts.ContactsState | None = None,
118
+ contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
119
119
  velocity_representation: VelRepr = VelRepr.Inertial,
120
120
  time: jtp.FloatLike | None = None,
121
121
  ) -> JaxSimModelData:
@@ -185,17 +185,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
185
185
  )
186
186
  )
187
187
 
188
- if isinstance(model.contact_model, SoftContacts):
189
- contacts_params = (
190
- contacts_params
191
- if contacts_params is not None
192
- else js.contact.estimate_good_soft_contacts_parameters(
193
- model=model, standard_gravity=standard_gravity
194
- )
195
- )
196
- else:
197
- contacts_params = model.contact_model.parameters
198
-
199
188
  W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
200
189
  translation=base_position, quaternion=base_quaternion
201
190
  )
@@ -225,6 +214,15 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
225
214
  if not ode_state.valid(model=model):
226
215
  raise ValueError(ode_state)
227
216
 
217
+ if contacts_params is None:
218
+
219
+ if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
220
+ contacts_params = js.contact.estimate_good_soft_contacts_parameters(
221
+ model=model, standard_gravity=standard_gravity
222
+ )
223
+ else:
224
+ contacts_params = model.contact_model.parameters
225
+
228
226
  return JaxSimModelData(
229
227
  time_ns=time_ns,
230
228
  state=ode_state,
jaxsim/api/model.py CHANGED
@@ -36,7 +36,7 @@ class JaxSimModel(JaxsimDataclass):
36
36
  default=jaxsim.terrain.FlatTerrain(), repr=False
37
37
  )
38
38
 
39
- contact_model: jaxsim.rbda.ContactModel | None = dataclasses.field(
39
+ contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field(
40
40
  default=None, repr=False
41
41
  )
42
42
 
@@ -89,7 +89,7 @@ class JaxSimModel(JaxsimDataclass):
89
89
  model_name: str | None = None,
90
90
  *,
91
91
  terrain: jaxsim.terrain.Terrain | None = None,
92
- contact_model: jaxsim.rbda.ContactModel | None = None,
92
+ contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
93
93
  is_urdf: bool | None = None,
94
94
  considered_joints: Sequence[str] | None = None,
95
95
  ) -> JaxSimModel:
@@ -150,7 +150,7 @@ class JaxSimModel(JaxsimDataclass):
150
150
  model_name: str | None = None,
151
151
  *,
152
152
  terrain: jaxsim.terrain.Terrain | None = None,
153
- contact_model: jaxsim.rbda.ContactModel | None = None,
153
+ contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
154
154
  ) -> JaxSimModel:
155
155
  """
156
156
  Build a Model object from an intermediate model description.
@@ -169,14 +169,15 @@ class JaxSimModel(JaxsimDataclass):
169
169
  Returns:
170
170
  The built Model object.
171
171
  """
172
- from jaxsim.rbda.contacts.soft import SoftContacts
173
172
 
174
173
  # Set the model name (if not provided, use the one from the model description).
175
174
  model_name = model_name if model_name is not None else model_description.name
176
175
 
177
176
  # Set the terrain (if not provided, use the default flat terrain).
178
177
  terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default
179
- contact_model = contact_model or SoftContacts(terrain=terrain)
178
+ contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts(
179
+ terrain=terrain
180
+ )
180
181
 
181
182
  # Build the model.
182
183
  model = JaxSimModel(
@@ -1930,8 +1931,6 @@ def step(
1930
1931
  and the new state of the integrator.
1931
1932
  """
1932
1933
 
1933
- from jaxsim.rbda.contacts.rigid import RigidContacts
1934
-
1935
1934
  # Extract the integrator kwargs.
1936
1935
  # The following logic allows using integrators having kwargs colliding with the
1937
1936
  # kwargs of this step function.
@@ -1992,12 +1991,16 @@ def step(
1992
1991
  # Post process the simulation state, if needed.
1993
1992
  match model.contact_model:
1994
1993
 
1995
- # Rigid contact models use an impact model that produces a discontinuous model velocity.
1996
- # Hence here we need to reset the velocity after each impact to guarantee that
1994
+ # Rigid contact models use an impact model that produces discontinuous model velocities.
1995
+ # Hence, here we need to reset the velocity after each impact to guarantee that
1997
1996
  # the linear velocity of the active collidable points is zero.
1998
- case RigidContacts():
1999
- # Raise runtime error for not supported case in which Rigid contacts and Baumgarte stabilization
2000
- # enabled are used with ForwardEuler integrator.
1997
+ case jaxsim.rbda.contacts.RigidContacts():
1998
+ assert isinstance(
1999
+ data_tf.contacts_params, jaxsim.rbda.contacts.RigidContactsParams
2000
+ )
2001
+
2002
+ # Raise runtime error for not supported case in which Rigid contacts and
2003
+ # Baumgarte stabilization are enabled and used with ForwardEuler integrator.
2001
2004
  jaxsim.exceptions.raise_runtime_error_if(
2002
2005
  condition=jnp.logical_and(
2003
2006
  isinstance(
@@ -2013,23 +2016,38 @@ def step(
2013
2016
  )
2014
2017
 
2015
2018
  with data_tf.switch_velocity_representation(VelRepr.Mixed):
2016
- W_p_C = js.contact.collidable_point_positions(model, data_tf)
2017
- M = js.model.free_floating_mass_matrix(model, data_tf)
2019
+
2018
2020
  J_WC = js.contact.jacobian(model, data_tf)
2021
+ M = js.model.free_floating_mass_matrix(model, data_tf)
2022
+ W_p_C = js.contact.collidable_point_positions(model, data_tf)
2023
+
2024
+ # Compute the height of the terrain below each collidable point.
2019
2025
  px, py, _ = W_p_C.T
2020
2026
  terrain_height = jax.vmap(model.terrain.height)(px, py)
2021
- inactive_collidable_points, _ = RigidContacts.detect_contacts(
2022
- W_p_C=W_p_C,
2023
- terrain_height=terrain_height,
2027
+
2028
+ # Compute the contact state.
2029
+ inactive_collidable_points, _ = (
2030
+ jaxsim.rbda.contacts.RigidContacts.detect_contacts(
2031
+ W_p_C=W_p_C,
2032
+ terrain_height=terrain_height,
2033
+ )
2024
2034
  )
2025
- BW_nu_post_impact = RigidContacts.compute_impact_velocity(
2026
- data=data_tf,
2027
- inactive_collidable_points=inactive_collidable_points,
2028
- M=M,
2029
- J_WC=J_WC,
2035
+
2036
+ # Compute the impact velocity.
2037
+ # It may be discontinuous in case new contacts are made.
2038
+ BW_nu_post_impact = (
2039
+ jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
2040
+ data=data_tf,
2041
+ inactive_collidable_points=inactive_collidable_points,
2042
+ M=M,
2043
+ J_WC=J_WC,
2044
+ )
2030
2045
  )
2046
+
2047
+ # Reset the generalized velocity.
2031
2048
  data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
2032
2049
  data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
2050
+
2033
2051
  # Restore the input velocity representation.
2034
2052
  data_tf = data_tf.replace(
2035
2053
  velocity_representation=data.velocity_representation, validate=False
jaxsim/api/ode_data.py CHANGED
@@ -5,13 +5,15 @@ import jax_dataclasses
5
5
 
6
6
  import jaxsim.api as js
7
7
  import jaxsim.typing as jtp
8
- from jaxsim.rbda import ContactsState
9
- from jaxsim.rbda.contacts.relaxed_rigid import (
8
+ from jaxsim.rbda.contacts import (
9
+ ContactsState,
10
10
  RelaxedRigidContacts,
11
11
  RelaxedRigidContactsState,
12
+ RigidContacts,
13
+ RigidContactsState,
14
+ SoftContacts,
15
+ SoftContactsState,
12
16
  )
13
- from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
14
- from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
15
17
  from jaxsim.utils import JaxsimDataclass
16
18
 
17
19
  # =============================================================================
@@ -165,8 +167,11 @@ class ODEState(JaxsimDataclass):
165
167
 
166
168
  # Get the contact model from the `JaxSimModel`.
167
169
  match model.contact_model:
170
+
168
171
  case SoftContacts():
172
+
169
173
  tangential_deformation = kwargs.get("tangential_deformation", None)
174
+
170
175
  contact = SoftContactsState.build_from_jaxsim_model(
171
176
  model=model,
172
177
  **(
@@ -182,7 +187,7 @@ class ODEState(JaxsimDataclass):
182
187
  contact = RelaxedRigidContactsState.build()
183
188
 
184
189
  case _:
185
- raise ValueError("Unable to determine contact state class prefix.")
190
+ raise ValueError("Unsupported contact model.")
186
191
 
187
192
  return ODEState.build(
188
193
  model=model,
jaxsim/rbda/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
+ from . import contacts
1
2
  from .aba import aba
2
3
  from .collidable_points import collidable_points_pos_vel
3
- from .contacts.common import ContactModel, ContactsParams, ContactsState
4
4
  from .crba import crba
5
5
  from .forward_kinematics import forward_kinematics, forward_kinematics_model
6
6
  from .jacobian import (
@@ -1,3 +1,4 @@
1
+ from . import relaxed_rigid, rigid, soft
1
2
  from .common import ContactModel, ContactsParams, ContactsState
2
3
  from .relaxed_rigid import (
3
4
  RelaxedRigidContacts,
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import abc
4
4
  from typing import Any
5
5
 
6
+ import jaxsim.api as js
6
7
  import jaxsim.terrain
7
8
  import jaxsim.typing as jtp
8
9
  from jaxsim.utils import JaxsimDataclass
@@ -90,20 +91,48 @@ class ContactModel(JaxsimDataclass):
90
91
  @abc.abstractmethod
91
92
  def compute_contact_forces(
92
93
  self,
93
- position: jtp.VectorLike,
94
- velocity: jtp.VectorLike,
94
+ model: js.model.JaxSimModel,
95
+ data: js.data.JaxSimModelData,
95
96
  **kwargs,
96
97
  ) -> tuple[jtp.Vector, tuple[Any, ...]]:
97
98
  """
98
99
  Compute the contact forces.
99
100
 
100
101
  Args:
101
- position: The position of the collidable point w.r.t. the world frame.
102
- velocity:
103
- The linear velocity of the collidable point (linear component of the mixed 6D velocity).
102
+ model: The model to consider.
103
+ data: The data of the considered model.
104
104
 
105
105
  Returns:
106
106
  A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame,
107
107
  and as second element a tuple of optional additional information.
108
108
  """
109
+
109
110
  pass
111
+
112
+ def initialize_model_and_data(
113
+ self,
114
+ model: js.model.JaxSimModel,
115
+ data: js.data.JaxSimModelData,
116
+ validate: bool = True,
117
+ ) -> tuple[js.model.JaxSimModel, js.data.JaxSimModelData]:
118
+ """
119
+ Helper function to initialize the active model and data objects.
120
+
121
+ Args:
122
+ model: The robot model considered by the contact model.
123
+ data: The data of the considered robot model.
124
+ validate:
125
+ Whether to validate if the model and data objects have been
126
+ initialized with the current contact model.
127
+
128
+ Returns:
129
+ The initialized model and data objects.
130
+ """
131
+
132
+ with model.editable(validate=validate) as model_out:
133
+ model_out.contact_model = self
134
+
135
+ with data.editable(validate=validate) as data_out:
136
+ data_out.contacts_params = data.contacts_params
137
+
138
+ return model_out, data_out
@@ -169,12 +169,12 @@ class RelaxedRigidContactsState(ContactsState):
169
169
  return cls()
170
170
 
171
171
  @classmethod
172
- def zero(cls: type[Self]) -> Self:
172
+ def zero(cls: type[Self], **kwargs) -> Self:
173
173
  """Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`."""
174
174
 
175
175
  return cls.build()
176
176
 
177
- def valid(self, *, model: js.model.JaxSimModel) -> jtp.BoolLike:
177
+ def valid(self, **kwargs) -> jtp.BoolLike:
178
178
  return True
179
179
 
180
180
 
@@ -193,11 +193,9 @@ class RelaxedRigidContacts(ContactModel):
193
193
  @jax.jit
194
194
  def compute_contact_forces(
195
195
  self,
196
- position: jtp.VectorLike,
197
- velocity: jtp.VectorLike,
198
- *,
199
196
  model: js.model.JaxSimModel,
200
197
  data: js.data.JaxSimModelData,
198
+ *,
201
199
  link_forces: jtp.MatrixLike | None = None,
202
200
  joint_force_references: jtp.VectorLike | None = None,
203
201
  ) -> tuple[jtp.Vector, tuple[Any, ...]]:
@@ -205,10 +203,8 @@ class RelaxedRigidContacts(ContactModel):
205
203
  Compute the contact forces.
206
204
 
207
205
  Args:
208
- position: The position of the collidable point.
209
- velocity: The linear velocity of the collidable point.
210
- model: The `JaxSimModel` instance.
211
- data: The `JaxSimModelData` instance.
206
+ model: The model to consider.
207
+ data: The data of the considered model.
212
208
  link_forces:
213
209
  Optional `(n_links, 6)` matrix of external forces acting on the links,
214
210
  expressed in the same representation of data.
@@ -219,6 +215,11 @@ class RelaxedRigidContacts(ContactModel):
219
215
  A tuple containing the contact forces.
220
216
  """
221
217
 
218
+ # Initialize the model and data this contact model is operating on.
219
+ # This will raise an exception if either the contact model or the
220
+ # contact parameters are not compatible.
221
+ model, data = self.initialize_model_and_data(model=model, data=data)
222
+
222
223
  link_forces = (
223
224
  link_forces
224
225
  if link_forces is not None
@@ -247,6 +248,12 @@ class RelaxedRigidContacts(ContactModel):
247
248
 
248
249
  return jnp.dot(h, n̂)
249
250
 
251
+ # Compute the position and linear velocities (mixed representation) of
252
+ # all collidable points belonging to the robot.
253
+ position, velocity = js.contact.collidable_point_kinematics(
254
+ model=model, data=data
255
+ )
256
+
250
257
  # Compute the activation state of the collidable points
251
258
  δ = jax.vmap(_detect_contact)(*position.T)
252
259
 
@@ -92,12 +92,12 @@ class RigidContactsState(ContactsState):
92
92
  return cls()
93
93
 
94
94
  @classmethod
95
- def zero(cls: type[Self]) -> Self:
95
+ def zero(cls: type[Self], **kwargs) -> Self:
96
96
  """Build a zero `RigidContactsState` instance from a `JaxSimModel`."""
97
97
 
98
98
  return cls.build()
99
99
 
100
- def valid(self) -> jtp.BoolLike:
100
+ def valid(self, **kwargs) -> jtp.BoolLike:
101
101
  return True
102
102
 
103
103
 
@@ -219,11 +219,9 @@ class RigidContacts(ContactModel):
219
219
  @jax.jit
220
220
  def compute_contact_forces(
221
221
  self,
222
- position: jtp.VectorLike,
223
- velocity: jtp.VectorLike,
224
- *,
225
222
  model: js.model.JaxSimModel,
226
223
  data: js.data.JaxSimModelData,
224
+ *,
227
225
  link_forces: jtp.MatrixLike | None = None,
228
226
  joint_force_references: jtp.VectorLike | None = None,
229
227
  regularization_term: jtp.FloatLike = 1e-6,
@@ -233,10 +231,8 @@ class RigidContacts(ContactModel):
233
231
  Compute the contact forces.
234
232
 
235
233
  Args:
236
- position: The position of the collidable point.
237
- velocity: The linear velocity of the collidable point.
238
- model: The `JaxSimModel` instance.
239
- data: The `JaxSimModelData` instance.
234
+ model: The model to consider.
235
+ data: The data of the considered model.
240
236
  link_forces:
241
237
  Optional `(n_links, 6)` matrix of external forces acting on the links,
242
238
  expressed in the same representation of data.
@@ -245,11 +241,17 @@ class RigidContacts(ContactModel):
245
241
  regularization_term:
246
242
  The regularization term to add to the diagonal of the Delassus
247
243
  matrix for better numerical conditioning.
244
+ solver_tol: The convergence tolerance to consider in the QP solver.
248
245
 
249
246
  Returns:
250
247
  A tuple containing the contact forces.
251
248
  """
252
249
 
250
+ # Initialize the model and data this contact model is operating on.
251
+ # This will raise an exception if either the contact model or the
252
+ # contact parameters are not compatible.
253
+ model, data = self.initialize_model_and_data(model=model, data=data)
254
+
253
255
  # Import qpax just in this method
254
256
  import qpax
255
257
 
@@ -273,6 +275,12 @@ class RigidContacts(ContactModel):
273
275
  J̇_WC_BW = js.contact.jacobian_derivative(model=model, data=data)
274
276
  BW_ν = data.generalized_velocity()
275
277
 
278
+ # Compute the position and linear velocities (mixed representation) of
279
+ # all collidable points belonging to the robot.
280
+ position, velocity = js.contact.collidable_point_kinematics(
281
+ model=model, data=data
282
+ )
283
+
276
284
  terrain_height = jax.vmap(self.terrain.height)(position[:, 0], position[:, 1])
277
285
  n_collidable_points = model.kin_dyn_parameters.contact_parameters.point.shape[0]
278
286
 
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
+ import functools
4
5
 
5
6
  import jax
6
7
  import jax.numpy as jnp
@@ -198,30 +199,32 @@ class SoftContacts(ContactModel):
198
199
  default_factory=FlatTerrain
199
200
  )
200
201
 
201
- def compute_contact_forces(
202
- self,
202
+ @staticmethod
203
+ @functools.partial(jax.jit, static_argnames=("terrain",))
204
+ def hunt_crossley_contact_model(
203
205
  position: jtp.VectorLike,
204
206
  velocity: jtp.VectorLike,
205
- *,
206
207
  tangential_deformation: jtp.VectorLike,
207
- ) -> tuple[jtp.Vector, tuple[jtp.Vector]]:
208
+ terrain: Terrain,
209
+ K: jtp.FloatLike,
210
+ D: jtp.FloatLike,
211
+ mu: jtp.FloatLike,
212
+ p: jtp.FloatLike = 0.5,
213
+ q: jtp.FloatLike = 0.5,
214
+ ) -> tuple[jtp.Vector, jtp.Vector]:
208
215
 
209
216
  # Convert the input vectors to arrays.
210
217
  W_p_C = jnp.array(position, dtype=float).squeeze()
211
218
  W_ṗ_C = jnp.array(velocity, dtype=float).squeeze()
212
219
  m = jnp.array(tangential_deformation, dtype=float).squeeze()
213
220
 
214
- # Short name of parameters.
215
- K = self.parameters.K
216
- D = self.parameters.D
217
- μ = self.parameters.mu
221
+ # Use symbol for the static friction.
222
+ μ = mu
218
223
 
219
224
  # Compute the penetration depth, its rate, and the considered terrain normal.
220
- δ, δ̇, n̂ = self.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=self.terrain)
221
-
222
- # Get the exponents of the Hunt/Crossley model non-linear terms.
223
- p = self.parameters.p
224
- q = self.parameters.q
225
+ δ, δ̇, n̂ = SoftContacts.compute_penetration_data(
226
+ p=W_p_C, v=W_ṗ_C, terrain=terrain
227
+ )
225
228
 
226
229
  # There are few operations like computing the norm of a vector with zero length
227
230
  # or computing the square root of zero that are problematic in an AD context.
@@ -256,14 +259,15 @@ class SoftContacts(ContactModel):
256
259
  # Extract the tangential component of the velocity.
257
260
  v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂
258
261
 
259
- # Extract the tangential component of the material deformation.
260
- # This should not be necessary if the sticking-slipping transition occurs
261
- # in a terrain area with a locally constant normal. However, this assumption
262
- # is not true in general for highly uneven terrains.
262
+ # Extract the normal and tangential components of the material deformation.
263
263
  m_normal = jnp.dot(m, n̂) * n̂
264
264
  m_tangential = m - jnp.dot(m, n̂) * n̂
265
265
 
266
266
  # Compute the tangential force in the sticking case.
267
+ # Using the tangential component of the material deformation should not be
268
+ # necessary if the sticking-slipping transition occurs in a terrain area
269
+ # with a locally constant normal. However, this assumption is not true in
270
+ # general, especially for highly uneven terrains.
267
271
  f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential)
268
272
 
269
273
  # Detect the contact type (sticking or slipping).
@@ -298,6 +302,9 @@ class SoftContacts(ContactModel):
298
302
  # =====================================
299
303
 
300
304
  # Compute the derivative of the material deformation.
305
+ # Note that we included an additional relaxation of `m_normal` in the
306
+ # sticking case, so that the normal deformation that could have accumulated
307
+ # from a previous slipping phase can relax to zero.
301
308
  ṁ_no_contact = -(K / D) * m
302
309
  ṁ_sticking = v_tangential - (K / D) * m_normal
303
310
  ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq)
@@ -316,15 +323,79 @@ class SoftContacts(ContactModel):
316
323
  # Compute and return the final contact force
317
324
  # ==========================================
318
325
 
319
- # Sum the normal and tangential forces and create a mixed 6D force.
320
- CW_f = jnp.hstack([f_normal + f_tangential, jnp.zeros(3)])
326
+ # Sum the normal and tangential forces.
327
+ CW_fl = f_normal + f_tangential
328
+
329
+ return CW_fl, ṁ
330
+
331
+ @staticmethod
332
+ @functools.partial(jax.jit, static_argnames=("terrain",))
333
+ def compute_contact_force(
334
+ position: jtp.VectorLike,
335
+ velocity: jtp.VectorLike,
336
+ tangential_deformation: jtp.VectorLike,
337
+ parameters: SoftContactsParams,
338
+ terrain: Terrain,
339
+ ) -> tuple[jtp.Vector, jtp.Vector]:
340
+
341
+ CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model(
342
+ position=position,
343
+ velocity=velocity,
344
+ tangential_deformation=tangential_deformation,
345
+ terrain=terrain,
346
+ K=parameters.K,
347
+ D=parameters.D,
348
+ mu=parameters.mu,
349
+ p=parameters.p,
350
+ q=parameters.q,
351
+ )
352
+
353
+ # Pack a mixed 6D force.
354
+ CW_f = jnp.hstack([CW_fl, jnp.zeros(3)])
321
355
 
322
356
  # Compute the 6D force transform from the mixed to the inertial-fixed frame.
323
357
  W_Xf_CW = jaxsim.math.Adjoint.from_quaternion_and_translation(
324
- translation=W_p_C, inverse=True
358
+ translation=jnp.array(position), inverse=True
325
359
  ).T
326
360
 
327
- return W_Xf_CW @ CW_f, (ṁ,)
361
+ # Compute the 6D force in the inertial-fixed frame.
362
+ W_f = W_Xf_CW @ CW_f
363
+
364
+ return W_f, ṁ
365
+
366
+ @jax.jit
367
+ def compute_contact_forces(
368
+ self,
369
+ model: js.model.JaxSimModel,
370
+ data: js.data.JaxSimModelData,
371
+ ) -> tuple[jtp.Vector, tuple[jtp.Vector]]:
372
+
373
+ # Initialize the model and data this contact model is operating on.
374
+ # This will raise an exception if either the contact model or the
375
+ # contact parameters are not compatible.
376
+ model, data = self.initialize_model_and_data(model=model, data=data)
377
+
378
+ # Compute the position and linear velocities (mixed representation) of
379
+ # all collidable points belonging to the robot.
380
+ W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data)
381
+
382
+ # Extract the material deformation corresponding to the collidable points.
383
+ assert isinstance(data.state.contact, SoftContactsState)
384
+ m = data.state.contact.tangential_deformation
385
+
386
+ # Compute the contact forces for all collidable points.
387
+ # Since we treat them as independent, we can vmap the computation.
388
+ W_f, ṁ = jax.vmap(
389
+ lambda p, v, m: SoftContacts.compute_contact_force(
390
+ position=p,
391
+ velocity=v,
392
+ tangential_deformation=m,
393
+ parameters=self.parameters,
394
+ terrain=self.terrain,
395
+ )
396
+ )(W_p_C, W_ṗ_C, m)
397
+
398
+ return W_f, (ṁ,)
328
399
 
329
400
  @staticmethod
330
401
  @jax.jit
jaxsim/terrain/terrain.py CHANGED
@@ -155,7 +155,7 @@ class PlaneTerrain(FlatTerrain):
155
155
  (
156
156
  hash(self._height),
157
157
  HashedNumpyArray.hash_of_array(
158
- array=jnp.array(self._normal, dtype=float)
158
+ array=np.array(self._normal, dtype=float)
159
159
  ),
160
160
  )
161
161
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev115
3
+ Version: 0.4.3.dev129
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -1,20 +1,20 @@
1
1
  jaxsim/__init__.py,sha256=bSbpggIz5aG6QuGZLa0V2EfHjAOeucMxi-vIYxzLmN8,2788
2
- jaxsim/_version.py,sha256=5euYiJAuz4T6MQwktcrK-oclprll5Dr6ScaKSKchsa8,428
2
+ jaxsim/_version.py,sha256=TZYv9WjsK9pTtCnxiHUO_BF5LO5oh7sVdFFBwKnF98k,428
3
3
  jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
6
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
7
7
  jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
8
8
  jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
9
- jaxsim/api/contact.py,sha256=ocwsVS1jaBfrd81990hcgfS0-2xD8VVzDq7gdPguAUg,23087
10
- jaxsim/api/data.py,sha256=QldUHniJqKrdNtAcXuRaS9UyeslJ0Rjvb17UA0Ca5Tw,29008
9
+ jaxsim/api/contact.py,sha256=ktol-c2UHDqt3Hd6NpMhK86NlyadIyB5eY7chA5GhkY,22568
10
+ jaxsim/api/data.py,sha256=bSuBVKKksYtANLnxCkMQ3u2JGbn5mud7g8G5aLZawls,28988
11
11
  jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
12
12
  jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=ElahFk_RCcLvjTidH2qDOsY-m1gN1hXitCv4SvfgGYY,29260
14
14
  jaxsim/api/link.py,sha256=LAA6ZMQXkWomXeptURBtc7z3_xDZ2BBnBMhVrohh0bE,18621
15
- jaxsim/api/model.py,sha256=FLCk3fKUfLp3eCPrFEgPljTDp3zgIJnVW5Gwx0fWAog,67637
15
+ jaxsim/api/model.py,sha256=hDb8lVoRMhAWOMn1Q-MFHAa4wJY431-_nuLTKY0E3-Q,68200
16
16
  jaxsim/api/ode.py,sha256=gYSbtHWGCDP-IkUzQlH3t0fBKnK8qmxwhIvsbLG9lwU,13616
17
- jaxsim/api/ode_data.py,sha256=7RSoBhfCJdP6P9InQbDwdBVpClPMMuetewI-6AWm-_0,20276
17
+ jaxsim/api/ode_data.py,sha256=k1hVU1x8vuTVYdkf0cLhZ-oqeGJocXN2lLey-pS1_vo,20166
18
18
  jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
20
20
  jaxsim/integrators/common.py,sha256=78MBs89GxsL0wU2yAexjvBZt3HEtfZoGVIN9f0a8yTc,20305
@@ -44,7 +44,7 @@ jaxsim/parsers/descriptions/model.py,sha256=I2Vsbv8Josl4Le7b5rIvhqA2k9Bbv5JxMqwy
44
44
  jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
45
45
  jaxsim/parsers/rod/parser.py,sha256=9zBYTQF2vC4NO6HEZBvV8VTaGuZSdbd3v66BAcMpuVI,13923
46
46
  jaxsim/parsers/rod/utils.py,sha256=5DsF3OeePZGidOJ5GiFSZx-51uIdnFvMW9EK6SgOW6Q,5698
47
- jaxsim/rbda/__init__.py,sha256=H7DhXpxkPOi9lpUvg31IMHFfRafke1UoJLc5GQIdyhA,387
47
+ jaxsim/rbda/__init__.py,sha256=kmy4G9aMkrqPNGdLSaSV3k15dpF52vBEUQXDFDuKIxU,337
48
48
  jaxsim/rbda/aba.py,sha256=w7ciyxB0IsmueatT0C7PcBQEl9dyiH9oqJgIi3xeTUE,8983
49
49
  jaxsim/rbda/collidable_points.py,sha256=Rmf1DhflhOTYh9mDalv0agS0CGSbmfoOybwP2KzKuJ0,4883
50
50
  jaxsim/rbda/crba.py,sha256=bXkXESnVbv-lxhU1Y_i0rViEcQA4z2t2_jHwdVj5CBo,5049
@@ -52,19 +52,19 @@ jaxsim/rbda/forward_kinematics.py,sha256=2GmEoWsrioVl_SAbKRKfhOLz57pY4aR81PKRdul
52
52
  jaxsim/rbda/jacobian.py,sha256=p0EV_8cLzLVV-93VKznT7VPuRj8W7h7rQWkPlWJXfCA,11023
53
53
  jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
54
54
  jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
55
- jaxsim/rbda/contacts/__init__.py,sha256=Y1yT2zdgFa0zviZseI09wNaMcydH8TeoaWr6ehqzwdc,328
56
- jaxsim/rbda/contacts/common.py,sha256=CEmLS_PT44AOWKJ0bWrJJBqm2Q9v9LiqvL0rht63-ic,2605
57
- jaxsim/rbda/contacts/relaxed_rigid.py,sha256=f4g_qjoiml1gU9mqy6gEb8ljI3OSr4aoD2Li0J8dqQs,13708
58
- jaxsim/rbda/contacts/rigid.py,sha256=SP9qwkf35ybK0ZMiNhfIrP_tiiFN6jiG0ewBSBtR2Mc,15204
59
- jaxsim/rbda/contacts/soft.py,sha256=-d7zbMdKNq0aRT2zRXIu_Dbh8BL4VUnMDprz4Ddfwj0,16276
55
+ jaxsim/rbda/contacts/__init__.py,sha256=0UnO9ZR3BwdjQa276jOFbPi90pporr32LSc0qa9UUm4,369
56
+ jaxsim/rbda/contacts/common.py,sha256=-eM8d1kvJ2E_2_kAgZJk4s3x8vDZHNSyOAinwPmRmEk,3469
57
+ jaxsim/rbda/contacts/relaxed_rigid.py,sha256=VBD1FPjwpuKTfQ1bFKDlD_4xlEox4QhiatsjcrjDMNk,14025
58
+ jaxsim/rbda/contacts/rigid.py,sha256=6cU8kM8LMjEFbt8dtSg5nnz_uh4aD50sKw_svCzYUms,15633
59
+ jaxsim/rbda/contacts/soft.py,sha256=NzzCYw5rvK8Fx_qH3fiMzPgey-KoxmRe9xkF3fluidE,18866
60
60
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
61
- jaxsim/terrain/terrain.py,sha256=xUQg47yGxIOcTkLPbnO3sruEGBhoCd16j1evTGlmNjI,5010
61
+ jaxsim/terrain/terrain.py,sha256=Y0TGnUAGPuaeeSN8vbaSFjMXP5GYy3nxMCETjpUIMSA,5009
62
62
  jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
63
63
  jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
64
64
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
65
65
  jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
66
- jaxsim-0.4.3.dev115.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
- jaxsim-0.4.3.dev115.dist-info/METADATA,sha256=LG0pWBHUEVPV3MGVXaApN8SdiEUrmKTgJuD7PQmmyhI,17277
68
- jaxsim-0.4.3.dev115.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
69
- jaxsim-0.4.3.dev115.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
- jaxsim-0.4.3.dev115.dist-info/RECORD,,
66
+ jaxsim-0.4.3.dev129.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
+ jaxsim-0.4.3.dev129.dist-info/METADATA,sha256=oWWfWqv_MyqMsPtAZmGGEFc8qm_cuzzxTRGa7J9Vdhg,17277
68
+ jaxsim-0.4.3.dev129.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
69
+ jaxsim-0.4.3.dev129.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
+ jaxsim-0.4.3.dev129.dist-info/RECORD,,