jaxsim 0.3.1.dev17__py3-none-any.whl → 0.3.1.dev40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.3.1.dev17'
16
- __version_tuple__ = version_tuple = (0, 3, 1, 'dev17')
15
+ __version__ = version = '0.3.1.dev40'
16
+ __version_tuple__ = version_tuple = (0, 3, 1, 'dev40')
jaxsim/api/contact.py CHANGED
@@ -1,11 +1,14 @@
1
+ from __future__ import annotations
2
+
1
3
  import functools
2
4
 
3
5
  import jax
4
6
  import jax.numpy as jnp
5
7
 
6
8
  import jaxsim.api as js
7
- import jaxsim.rbda
9
+ import jaxsim.terrain
8
10
  import jaxsim.typing as jtp
11
+ from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsParams
9
12
 
10
13
  from .common import VelRepr
11
14
 
@@ -135,17 +138,23 @@ def collidable_point_dynamics(
135
138
  W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
136
139
 
137
140
  # Build the soft contact model.
138
- soft_contacts = jaxsim.rbda.SoftContacts(
139
- parameters=data.soft_contacts_params, terrain=model.terrain
140
- )
141
+ match model.contact_model:
142
+ case s if isinstance(s, SoftContacts):
143
+ # Build the contact model.
144
+ soft_contacts = SoftContacts(
145
+ parameters=data.contacts_params, terrain=model.terrain
146
+ )
147
+
148
+ # Compute the 6D force expressed in the inertial frame and applied to each
149
+ # collidable point, and the corresponding material deformation rate.
150
+ # Note that the material deformation rate is always returned in the mixed frame
151
+ # C[W] = (W_p_C, [W]). This is convenient for integration purpose.
152
+ W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
153
+ W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
154
+ )
141
155
 
142
- # Compute the 6D force expressed in the inertial frame and applied to each
143
- # collidable point, and the corresponding material deformation rate.
144
- # Note that the material deformation rate is always returned in the mixed frame
145
- # C[W] = (W_p_C, [W]). This is convenient for integration purpose.
146
- W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.contact_model)(
147
- W_p_Ci, W_ṗ_Ci, data.state.soft_contacts.tangential_deformation
148
- )
156
+ case _:
157
+ raise ValueError("Invalid contact model {}".format(model.contact_model))
149
158
 
150
159
  # Convert the 6D forces to the active representation.
151
160
  f_Ci = jax.vmap(
@@ -213,7 +222,7 @@ def estimate_good_soft_contacts_parameters(
213
222
  number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
214
223
  damping_ratio: jtp.FloatLike = 1.0,
215
224
  max_penetration: jtp.FloatLike | None = None,
216
- ) -> jaxsim.rbda.soft_contacts.SoftContactsParams:
225
+ ) -> SoftContactsParams:
217
226
  """
218
227
  Estimate good soft contacts parameters for the given model.
219
228
 
@@ -237,13 +246,14 @@ def estimate_good_soft_contacts_parameters(
237
246
  The user is encouraged to fine-tune the parameters based on the
238
247
  specific application.
239
248
  """
249
+ from jaxsim.rbda.contacts.soft import SoftContactsParams
240
250
 
241
251
  def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
242
252
  """"""
243
253
 
244
254
  zero_data = js.data.JaxSimModelData.build(
245
255
  model=model,
246
- soft_contacts_params=jaxsim.rbda.soft_contacts.SoftContactsParams(),
256
+ contacts_params=SoftContactsParams(),
247
257
  )
248
258
 
249
259
  W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
@@ -262,15 +272,13 @@ def estimate_good_soft_contacts_parameters(
262
272
 
263
273
  nc = number_of_active_collidable_points_steady_state
264
274
 
265
- sc_parameters = (
266
- jaxsim.rbda.soft_contacts.SoftContactsParams.build_default_from_jaxsim_model(
267
- model=model,
268
- standard_gravity=standard_gravity,
269
- static_friction_coefficient=static_friction_coefficient,
270
- max_penetration=max_δ,
271
- number_of_active_collidable_points_steady_state=nc,
272
- damping_ratio=damping_ratio,
273
- )
275
+ sc_parameters = SoftContactsParams.build_default_from_jaxsim_model(
276
+ model=model,
277
+ standard_gravity=standard_gravity,
278
+ static_friction_coefficient=static_friction_coefficient,
279
+ max_penetration=max_δ,
280
+ number_of_active_collidable_points_steady_state=nc,
281
+ damping_ratio=damping_ratio,
274
282
  )
275
283
 
276
284
  return sc_parameters
jaxsim/api/data.py CHANGED
@@ -14,6 +14,7 @@ import jaxsim.api as js
14
14
  import jaxsim.rbda
15
15
  import jaxsim.typing as jtp
16
16
  from jaxsim.math import Quaternion
17
+ from jaxsim.rbda.contacts.soft import SoftContacts
17
18
  from jaxsim.utils import Mutability
18
19
  from jaxsim.utils.tracing import not_tracing
19
20
 
@@ -37,7 +38,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
37
38
 
38
39
  gravity: jtp.Array
39
40
 
40
- soft_contacts_params: jaxsim.rbda.SoftContactsParams = dataclasses.field(repr=False)
41
+ contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)
41
42
 
42
43
  time_ns: jtp.Int = dataclasses.field(
43
44
  default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
@@ -51,8 +52,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
51
52
  (
52
53
  hash(self.state),
53
54
  HashedNumpyArray.hash_of_array(self.gravity),
54
- hash(self.soft_contacts_params),
55
55
  HashedNumpyArray.hash_of_array(self.time_ns),
56
+ hash(self.contacts_params),
56
57
  )
57
58
  )
58
59
 
@@ -112,8 +113,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
112
113
  base_angular_velocity: jtp.Vector | None = None,
113
114
  joint_velocities: jtp.Vector | None = None,
114
115
  standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
115
- soft_contacts_state: js.ode_data.SoftContactsState | None = None,
116
- soft_contacts_params: jaxsim.rbda.SoftContactsParams | None = None,
116
+ contact: jaxsim.rbda.ContactsState | None = None,
117
+ contacts_params: jaxsim.rbda.ContactsParams | None = None,
117
118
  velocity_representation: VelRepr = VelRepr.Inertial,
118
119
  time: jtp.FloatLike | None = None,
119
120
  ) -> JaxSimModelData:
@@ -131,8 +132,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
131
132
  The base angular velocity in the selected representation.
132
133
  joint_velocities: The joint velocities.
133
134
  standard_gravity: The standard gravity constant.
134
- soft_contacts_state: The state of the soft contacts.
135
- soft_contacts_params: The parameters of the soft contacts.
135
+ contact: The state of the soft contacts.
136
+ contacts_params: The parameters of the soft contacts.
136
137
  velocity_representation: The velocity representation to use.
137
138
  time: The time at which the state is created.
138
139
 
@@ -178,13 +179,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
178
179
  else jnp.array(0, dtype=jnp.uint64)
179
180
  )
180
181
 
181
- soft_contacts_params = (
182
- soft_contacts_params
183
- if soft_contacts_params is not None
184
- else js.contact.estimate_good_soft_contacts_parameters(
185
- model=model, standard_gravity=standard_gravity
182
+ if isinstance(model.contact_model, SoftContacts):
183
+ contacts_params = (
184
+ contacts_params
185
+ if contacts_params is not None
186
+ else js.contact.estimate_good_soft_contacts_parameters(
187
+ model=model, standard_gravity=standard_gravity
188
+ )
186
189
  )
187
- )
190
+ else:
191
+ contacts_params = model.contact_model.parameters
188
192
 
189
193
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
190
194
  translation=base_position,
@@ -209,8 +213,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
209
213
  base_angular_velocity=v_WB[3:6].astype(float),
210
214
  joint_velocities=joint_velocities.astype(float),
211
215
  tangential_deformation=(
212
- soft_contacts_state.tangential_deformation
213
- if soft_contacts_state is not None
216
+ contact.tangential_deformation
217
+ if contact is not None and isinstance(model.contact_model, SoftContacts)
214
218
  else None
215
219
  ),
216
220
  )
@@ -222,7 +226,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
222
226
  time_ns=time_ns,
223
227
  state=ode_state,
224
228
  gravity=gravity.astype(float),
225
- soft_contacts_params=soft_contacts_params,
229
+ contacts_params=contacts_params,
226
230
  velocity_representation=velocity_representation,
227
231
  )
228
232
 
@@ -652,7 +656,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
652
656
 
653
657
  return self.reset_base_velocity(
654
658
  base_velocity=jnp.hstack(
655
- [linear_velocity.squeeze(), self.base_velocity()[3:6]]
659
+ [
660
+ linear_velocity.squeeze(),
661
+ self.base_velocity()[3:6],
662
+ ]
656
663
  ),
657
664
  velocity_representation=velocity_representation,
658
665
  )
@@ -680,7 +687,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
680
687
 
681
688
  return self.reset_base_velocity(
682
689
  base_velocity=jnp.hstack(
683
- [self.base_velocity()[0:3], angular_velocity.squeeze()]
690
+ [
691
+ self.base_velocity()[0:3],
692
+ angular_velocity.squeeze(),
693
+ ]
684
694
  ),
685
695
  velocity_representation=velocity_representation,
686
696
  )
jaxsim/api/model.py CHANGED
@@ -34,6 +34,10 @@ class JaxSimModel(JaxsimDataclass):
34
34
  default=jaxsim.terrain.FlatTerrain(), repr=False
35
35
  )
36
36
 
37
+ contact_model: jaxsim.rbda.ContactModel | None = dataclasses.field(
38
+ default=None, repr=False
39
+ )
40
+
37
41
  kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
38
42
  dataclasses.field(default=None, repr=False)
39
43
  )
@@ -69,6 +73,7 @@ class JaxSimModel(JaxsimDataclass):
69
73
  (
70
74
  hash(self.model_name),
71
75
  hash(self.kin_dyn_parameters),
76
+ hash(self.contact_model),
72
77
  )
73
78
  )
74
79
 
@@ -82,6 +87,7 @@ class JaxSimModel(JaxsimDataclass):
82
87
  model_name: str | None = None,
83
88
  *,
84
89
  terrain: jaxsim.terrain.Terrain | None = None,
90
+ contact_model: jaxsim.rbda.ContactModel | None = None,
85
91
  is_urdf: bool | None = None,
86
92
  considered_joints: Sequence[str] | None = None,
87
93
  ) -> JaxSimModel:
@@ -127,6 +133,7 @@ class JaxSimModel(JaxsimDataclass):
127
133
  model_description=intermediate_description,
128
134
  model_name=model_name,
129
135
  terrain=terrain,
136
+ contact_model=contact_model,
130
137
  )
131
138
 
132
139
  # Store the origin of the model, in case downstream logic needs it
@@ -141,6 +148,7 @@ class JaxSimModel(JaxsimDataclass):
141
148
  model_name: str | None = None,
142
149
  *,
143
150
  terrain: jaxsim.terrain.Terrain | None = None,
151
+ contact_model: jaxsim.rbda.ContactModel | None = None,
144
152
  ) -> JaxSimModel:
145
153
  """
146
154
  Build a Model object from an intermediate model description.
@@ -153,22 +161,30 @@ class JaxSimModel(JaxsimDataclass):
153
161
  The optional name of the model overriding the physics model name.
154
162
  terrain:
155
163
  The optional terrain to consider.
164
+ contact_model:
165
+ The optional contact model to consider. If None, the soft contact model is used.
156
166
 
157
167
  Returns:
158
168
  The built Model object.
159
169
  """
170
+ from jaxsim.rbda.contacts.soft import SoftContacts
160
171
 
161
172
  # Set the model name (if not provided, use the one from the model description)
162
173
  model_name = model_name if model_name is not None else model_description.name
163
174
 
164
- # Build the model.
175
+ # Set the terrain (if not provided, use the default flat terrain)
176
+ terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default
177
+ contact_model = contact_model or SoftContacts(terrain=terrain)
178
+
179
+ # Build the model
165
180
  model = JaxSimModel(
166
181
  model_name=model_name,
167
182
  _description=wrappers.HashlessObject(obj=model_description),
168
183
  kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
169
184
  model_description=model_description
170
185
  ),
171
- terrain=terrain or JaxSimModel.__dataclass_fields__["terrain"].default,
186
+ terrain=terrain,
187
+ contact_model=contact_model,
172
188
  )
173
189
 
174
190
  return model
@@ -350,6 +366,7 @@ def reduce(
350
366
  model_description=reduced_intermediate_description,
351
367
  model_name=model.name(),
352
368
  terrain=model.terrain,
369
+ contact_model=model.contact_model,
353
370
  )
354
371
 
355
372
  # Store the origin of the model, in case downstream logic needs it
jaxsim/api/ode.py CHANGED
@@ -132,7 +132,7 @@ def system_velocity_dynamics(
132
132
  W_f_Ci = None
133
133
 
134
134
  # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
135
- ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
135
+ ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float)
136
136
 
137
137
  if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
138
138
  # Compute the 6D forces applied to each collidable point and the
jaxsim/api/ode_data.py CHANGED
@@ -5,6 +5,8 @@ import jax_dataclasses
5
5
 
6
6
  import jaxsim.api as js
7
7
  import jaxsim.typing as jtp
8
+ from jaxsim.rbda import ContactsState
9
+ from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
8
10
  from jaxsim.utils import JaxsimDataclass
9
11
 
10
12
  # =============================================================================
@@ -116,11 +118,11 @@ class ODEState(JaxsimDataclass):
116
118
 
117
119
  Attributes:
118
120
  physics_model: The state of the physics model.
119
- soft_contacts: The state of the soft-contacts model.
121
+ contact: The state of the contacts model.
120
122
  """
121
123
 
122
124
  physics_model: PhysicsModelState
123
- soft_contacts: SoftContactsState
125
+ contact: ContactsState
124
126
 
125
127
  @staticmethod
126
128
  def build_from_jaxsim_model(
@@ -158,6 +160,20 @@ class ODEState(JaxsimDataclass):
158
160
  `JaxSimModel` and initialized to zero.
159
161
  """
160
162
 
163
+ # Get the contact model from the `JaxSimModel`
164
+ match model.contact_model:
165
+ case SoftContacts():
166
+ contact = SoftContactsState.build_from_jaxsim_model(
167
+ model=model,
168
+ **(
169
+ dict(tangential_deformation=tangential_deformation)
170
+ if tangential_deformation is not None
171
+ else dict()
172
+ ),
173
+ )
174
+ case _:
175
+ raise ValueError("Unable to determine contact state class prefix.")
176
+
161
177
  return ODEState.build(
162
178
  model=model,
163
179
  physics_model_state=PhysicsModelState.build_from_jaxsim_model(
@@ -169,24 +185,21 @@ class ODEState(JaxsimDataclass):
169
185
  base_linear_velocity=base_linear_velocity,
170
186
  base_angular_velocity=base_angular_velocity,
171
187
  ),
172
- soft_contacts_state=SoftContactsState.build_from_jaxsim_model(
173
- model=model,
174
- tangential_deformation=tangential_deformation,
175
- ),
188
+ contact=contact,
176
189
  )
177
190
 
178
191
  @staticmethod
179
192
  def build(
180
193
  physics_model_state: PhysicsModelState | None = None,
181
- soft_contacts_state: SoftContactsState | None = None,
194
+ contact: ContactsState | None = None,
182
195
  model: js.model.JaxSimModel | None = None,
183
196
  ) -> ODEState:
184
197
  """
185
- Build an `ODEState` from a `PhysicsModelState` and a `SoftContactsState`.
198
+ Build an `ODEState` from a `PhysicsModelState` and a `ContactsState`.
186
199
 
187
200
  Args:
188
201
  physics_model_state: The state of the physics model.
189
- soft_contacts_state: The state of the soft-contacts model.
202
+ contact: The state of the contacts model.
190
203
  model: The `JaxSimModel` associated with the ODE state.
191
204
 
192
205
  Returns:
@@ -199,15 +212,16 @@ class ODEState(JaxsimDataclass):
199
212
  else PhysicsModelState.zero(model=model)
200
213
  )
201
214
 
202
- soft_contacts_state = (
203
- soft_contacts_state
204
- if soft_contacts_state is not None
205
- else SoftContactsState.zero(model=model)
206
- )
215
+ # Get the contact model from the `JaxSimModel`
216
+ match contact:
217
+ case SoftContactsState():
218
+ pass
219
+ case None:
220
+ contact = SoftContactsState.zero(model=model)
221
+ case _:
222
+ raise ValueError("Unable to determine contact state class prefix.")
207
223
 
208
- return ODEState(
209
- physics_model=physics_model_state, soft_contacts=soft_contacts_state
210
- )
224
+ return ODEState(physics_model=physics_model_state, contact=contact)
211
225
 
212
226
  @staticmethod
213
227
  def zero(model: js.model.JaxSimModel) -> ODEState:
@@ -236,9 +250,7 @@ class ODEState(JaxsimDataclass):
236
250
  `True` if the ODE state is valid for the given model, `False` otherwise.
237
251
  """
238
252
 
239
- return self.physics_model.valid(model=model) and self.soft_contacts.valid(
240
- model=model
241
- )
253
+ return self.physics_model.valid(model=model) and self.contact.valid(model=model)
242
254
 
243
255
 
244
256
  # ==================================================
@@ -595,135 +607,3 @@ class PhysicsModelInput(JaxsimDataclass):
595
607
  return False
596
608
 
597
609
  return True
598
-
599
-
600
- # ===========================================
601
- # Define the state of the soft-contacts model
602
- # ===========================================
603
-
604
-
605
- @jax_dataclasses.pytree_dataclass
606
- class SoftContactsState(JaxsimDataclass):
607
- """
608
- Class storing the state of the soft contacts model.
609
-
610
- Attributes:
611
- tangential_deformation:
612
- The matrix of 3D tangential material deformations corresponding to
613
- each collidable point.
614
- """
615
-
616
- tangential_deformation: jtp.Matrix
617
-
618
- def __hash__(self) -> int:
619
-
620
- from jaxsim.utils.wrappers import HashedNumpyArray
621
-
622
- return HashedNumpyArray.hash_of_array(self.tangential_deformation)
623
-
624
- def __eq__(self, other: SoftContactsState) -> bool:
625
-
626
- if not isinstance(other, SoftContactsState):
627
- return False
628
-
629
- return hash(self) == hash(other)
630
-
631
- @staticmethod
632
- def build_from_jaxsim_model(
633
- model: js.model.JaxSimModel | None = None,
634
- tangential_deformation: jtp.Matrix | None = None,
635
- ) -> SoftContactsState:
636
- """
637
- Build a `SoftContactsState` from a `JaxSimModel`.
638
-
639
- Args:
640
- model: The `JaxSimModel` associated with the soft contacts state.
641
- tangential_deformation: The matrix of 3D tangential material deformations.
642
-
643
- Returns:
644
- The `SoftContactsState` built from the `JaxSimModel`.
645
-
646
- Note:
647
- If any of the state components are not provided, they are built from the
648
- `JaxSimModel` and initialized to zero.
649
- """
650
-
651
- return SoftContactsState.build(
652
- tangential_deformation=tangential_deformation,
653
- number_of_collidable_points=len(
654
- model.kin_dyn_parameters.contact_parameters.body
655
- ),
656
- )
657
-
658
- @staticmethod
659
- def build(
660
- tangential_deformation: jtp.Matrix | None = None,
661
- number_of_collidable_points: int | None = None,
662
- ) -> SoftContactsState:
663
- """
664
- Create a `SoftContactsState`.
665
-
666
- Args:
667
- tangential_deformation:
668
- The matrix of 3D tangential material deformations corresponding to
669
- each collidable point.
670
- number_of_collidable_points: The number of collidable points.
671
-
672
- Returns:
673
- A `SoftContactsState` instance.
674
- """
675
-
676
- tangential_deformation = (
677
- tangential_deformation
678
- if tangential_deformation is not None
679
- else jnp.zeros(shape=(number_of_collidable_points, 3))
680
- )
681
-
682
- if tangential_deformation.shape[1] != 3:
683
- raise RuntimeError("The tangential deformation matrix must have 3 columns.")
684
-
685
- if (
686
- number_of_collidable_points is not None
687
- and tangential_deformation.shape[0] != number_of_collidable_points
688
- ):
689
- msg = "The number of collidable points must match the number of rows "
690
- msg += "in the tangential deformation matrix."
691
- raise RuntimeError(msg)
692
-
693
- return SoftContactsState(
694
- tangential_deformation=jnp.array(tangential_deformation).astype(float)
695
- )
696
-
697
- @staticmethod
698
- def zero(model: js.model.JaxSimModel) -> SoftContactsState:
699
- """
700
- Build a zero `SoftContactsState` from a `JaxSimModel`.
701
-
702
- Args:
703
- model: The `JaxSimModel` associated with the soft contacts state.
704
-
705
- Returns:
706
- A zero `SoftContactsState` instance.
707
- """
708
-
709
- return SoftContactsState.build_from_jaxsim_model(model=model)
710
-
711
- def valid(self, model: js.model.JaxSimModel) -> bool:
712
- """
713
- Check if the `SoftContactsState` is valid for a given `JaxSimModel`.
714
-
715
- Args:
716
- model: The `JaxSimModel` to validate the `SoftContactsState` against.
717
-
718
- Returns:
719
- `True` if the soft contacts state is valid for the given `JaxSimModel`,
720
- `False` otherwise.
721
- """
722
-
723
- shape = self.tangential_deformation.shape
724
- expected = (len(model.kin_dyn_parameters.contact_parameters.body), 3)
725
-
726
- if shape != expected:
727
- return False
728
-
729
- return True
jaxsim/exceptions.py ADDED
@@ -0,0 +1,63 @@
1
+ import jax
2
+
3
+
4
+ def raise_if(
5
+ condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs
6
+ ) -> None:
7
+ """
8
+ Raise a host-side exception if a condition is met. Useful in jit-compiled functions.
9
+
10
+ Args:
11
+ condition:
12
+ The boolean condition of the evaluated expression that triggers
13
+ the exception during runtime.
14
+ exception: The type of exception to raise.
15
+ msg:
16
+ The message to display when the exception is raised. The message can be a
17
+ format string (fmt), whose fields are filled with the args and kwargs.
18
+ """
19
+
20
+ # Check early that the format string is well-formed.
21
+ try:
22
+ _ = msg.format(*args, **kwargs)
23
+ except Exception as e:
24
+ msg = "Error in formatting exception message with args={} and kwargs={}"
25
+ raise ValueError(msg.format(args, kwargs)) from e
26
+
27
+ def _raise_exception(condition: bool, *args, **kwargs) -> None:
28
+ """The function called by the JAX callback."""
29
+
30
+ if condition:
31
+ raise exception(msg.format(*args, **kwargs))
32
+
33
+ def _callback(args, kwargs) -> None:
34
+ """The function that calls the JAX callback, executed only when needed."""
35
+
36
+ jax.debug.callback(_raise_exception, condition, *args, **kwargs)
37
+
38
+ # Since running a callable on the host is expensive, we prevent its execution
39
+ # if the condition is False with a low-level conditional expression.
40
+ def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:
41
+ return jax.lax.cond(
42
+ condition,
43
+ _callback,
44
+ lambda args, kwargs: None,
45
+ args,
46
+ kwargs,
47
+ )
48
+
49
+ return _run_callback_only_if_condition_is_true(*args, **kwargs)
50
+
51
+
52
+ def raise_runtime_error_if(
53
+ condition: bool | jax.Array, msg: str, *args, **kwargs
54
+ ) -> None:
55
+
56
+ return raise_if(condition, RuntimeError, msg, *args, **kwargs)
57
+
58
+
59
+ def raise_value_error_if(
60
+ condition: bool | jax.Array, msg: str, *args, **kwargs
61
+ ) -> None:
62
+
63
+ return raise_if(condition, ValueError, msg, *args, **kwargs)
jaxsim/rbda/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from .aba import aba
2
2
  from .collidable_points import collidable_points_pos_vel
3
+ from .contacts.common import ContactModel, ContactsParams, ContactsState
3
4
  from .crba import crba
4
5
  from .forward_kinematics import forward_kinematics, forward_kinematics_model
5
6
  from .jacobian import (
@@ -8,4 +9,3 @@ from .jacobian import (
8
9
  jacobian_full_doubly_left,
9
10
  )
10
11
  from .rnea import rnea
11
- from .soft_contacts import SoftContacts, SoftContactsParams
File without changes
@@ -0,0 +1,101 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ from typing import Any
5
+
6
+ import jaxsim.terrain
7
+ import jaxsim.typing as jtp
8
+
9
+
10
+ class ContactsState(abc.ABC):
11
+ """
12
+ Abstract class storing the state of the contacts model.
13
+ """
14
+
15
+ @classmethod
16
+ @abc.abstractmethod
17
+ def build(cls, **kwargs) -> ContactsState:
18
+ """
19
+ Build the contact state object.
20
+
21
+ Returns:
22
+ The contact state object.
23
+ """
24
+ pass
25
+
26
+ @classmethod
27
+ @abc.abstractmethod
28
+ def zero(cls, **kwargs) -> ContactsState:
29
+ """
30
+ Build a zero contact state.
31
+
32
+ Returns:
33
+ The zero contact state.
34
+ """
35
+ pass
36
+
37
+ @abc.abstractmethod
38
+ def valid(self, **kwargs) -> bool:
39
+ """
40
+ Check if the contacts state is valid.
41
+ """
42
+ pass
43
+
44
+
45
+ class ContactsParams(abc.ABC):
46
+ """
47
+ Abstract class representing the parameters of a contact model.
48
+ """
49
+
50
+ @classmethod
51
+ @abc.abstractmethod
52
+ def build(cls) -> ContactsParams:
53
+ """
54
+ Create a `ContactsParams` instance with specified parameters.
55
+ Returns:
56
+ The `ContactsParams` instance.
57
+ """
58
+ pass
59
+
60
+ @abc.abstractmethod
61
+ def valid(self, *args, **kwargs) -> bool:
62
+ """
63
+ Check if the parameters are valid.
64
+ Returns:
65
+ True if the parameters are valid, False otherwise.
66
+ """
67
+ pass
68
+
69
+
70
+ class ContactModel(abc.ABC):
71
+ """
72
+ Abstract class representing a contact model.
73
+
74
+ Attributes:
75
+ parameters: The parameters of the contact model.
76
+ terrain: The terrain model.
77
+ """
78
+
79
+ parameters: ContactsParams
80
+ terrain: jaxsim.terrain.Terrain
81
+
82
+ @abc.abstractmethod
83
+ def compute_contact_forces(
84
+ self,
85
+ position: jtp.Vector,
86
+ velocity: jtp.Vector,
87
+ **kwargs,
88
+ ) -> tuple[jtp.Vector, tuple[Any, ...]]:
89
+ """
90
+ Compute the contact forces.
91
+
92
+ Args:
93
+ position: The position of the collidable point w.r.t. the world frame.
94
+ velocity:
95
+ The linear velocity of the collidable point (linear component of the mixed 6D velocity).
96
+
97
+ Returns:
98
+ A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame,
99
+ and as second element a tuple of optional additional information.
100
+ """
101
+ pass
@@ -10,11 +10,12 @@ import jaxsim.api as js
10
10
  import jaxsim.typing as jtp
11
11
  from jaxsim.math import Skew, StandardGravity
12
12
  from jaxsim.terrain import FlatTerrain, Terrain
13
- from jaxsim.utils import JaxsimDataclass
13
+
14
+ from .common import ContactModel, ContactsParams, ContactsState
14
15
 
15
16
 
16
17
  @jax_dataclasses.pytree_dataclass
17
- class SoftContactsParams(JaxsimDataclass):
18
+ class SoftContactsParams(ContactsParams):
18
19
  """Parameters of the soft contacts model."""
19
20
 
20
21
  K: jtp.Float = dataclasses.field(
@@ -127,9 +128,23 @@ class SoftContactsParams(JaxsimDataclass):
127
128
 
128
129
  return SoftContactsParams.build(K=K, D=D, mu=μc)
129
130
 
131
+ def valid(self) -> bool:
132
+ """
133
+ Check if the parameters are valid.
134
+
135
+ Returns:
136
+ `True` if the parameters are valid, `False` otherwise.
137
+ """
138
+
139
+ return (
140
+ jnp.all(self.K >= 0.0)
141
+ and jnp.all(self.D >= 0.0)
142
+ and jnp.all(self.mu >= 0.0)
143
+ )
144
+
130
145
 
131
146
  @jax_dataclasses.pytree_dataclass
132
- class SoftContacts:
147
+ class SoftContacts(ContactModel):
133
148
  """Soft contacts model."""
134
149
 
135
150
  parameters: SoftContactsParams = dataclasses.field(
@@ -138,12 +153,12 @@ class SoftContacts:
138
153
 
139
154
  terrain: Terrain = dataclasses.field(default_factory=FlatTerrain)
140
155
 
141
- def contact_model(
156
+ def compute_contact_forces(
142
157
  self,
143
158
  position: jtp.Vector,
144
159
  velocity: jtp.Vector,
145
160
  tangential_deformation: jtp.Vector,
146
- ) -> tuple[jtp.Vector, jtp.Vector]:
161
+ ) -> tuple[jtp.Vector, tuple[jtp.Vector, None]]:
147
162
  """
148
163
  Compute the contact forces and material deformation rate.
149
164
 
@@ -222,7 +237,7 @@ class SoftContacts:
222
237
  # Compute lin-ang 6D forces (inertial representation)
223
238
  W_f = W_Xf_CW @ CW_f
224
239
 
225
- return W_f, ṁ
240
+ return W_f, (ṁ,)
226
241
 
227
242
  # =========================
228
243
  # Compute tangential forces
@@ -240,7 +255,7 @@ class SoftContacts:
240
255
  active_contact = pz < self.terrain.height(x=px, y=py)
241
256
 
242
257
  def above_terrain():
243
- return jnp.zeros(6), ṁ
258
+ return jnp.zeros(6), (ṁ,)
244
259
 
245
260
  def below_terrain():
246
261
  # Decompose the velocity in normal and tangential components
@@ -296,9 +311,9 @@ class SoftContacts:
296
311
  W_f = W_Xf_CW @ CW_f
297
312
 
298
313
  # Return the 6D force in the world frame and the deformation derivative
299
- return W_f, ṁ
314
+ return W_f, (ṁ,)
300
315
 
301
- # (W_f, ṁ)
316
+ # (W_f, (ṁ,))
302
317
  return jax.lax.cond(
303
318
  pred=active_contact,
304
319
  true_fun=lambda _: below_terrain(),
@@ -313,3 +328,128 @@ class SoftContacts:
313
328
  false_fun=lambda _: with_friction(),
314
329
  operand=None,
315
330
  )
331
+
332
+
333
+ @jax_dataclasses.pytree_dataclass
334
+ class SoftContactsState(ContactsState):
335
+ """
336
+ Class storing the state of the soft contacts model.
337
+
338
+ Attributes:
339
+ tangential_deformation:
340
+ The matrix of 3D tangential material deformations corresponding to
341
+ each collidable point.
342
+ """
343
+
344
+ tangential_deformation: jtp.Matrix
345
+
346
+ def __hash__(self) -> int:
347
+ return hash(
348
+ tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist())
349
+ )
350
+
351
+ def __eq__(self, other: SoftContactsState) -> bool:
352
+ if not isinstance(other, SoftContactsState):
353
+ return False
354
+
355
+ return hash(self) == hash(other)
356
+
357
+ @staticmethod
358
+ def build_from_jaxsim_model(
359
+ model: js.model.JaxSimModel | None = None,
360
+ tangential_deformation: jtp.Matrix | None = None,
361
+ ) -> SoftContactsState:
362
+ """
363
+ Build a `SoftContactsState` from a `JaxSimModel`.
364
+
365
+ Args:
366
+ model: The `JaxSimModel` associated with the soft contacts state.
367
+ tangential_deformation: The matrix of 3D tangential material deformations.
368
+
369
+ Returns:
370
+ The `SoftContactsState` built from the `JaxSimModel`.
371
+
372
+ Note:
373
+ If any of the state components are not provided, they are built from the
374
+ `JaxSimModel` and initialized to zero.
375
+ """
376
+
377
+ return SoftContactsState.build(
378
+ tangential_deformation=tangential_deformation,
379
+ number_of_collidable_points=len(
380
+ model.kin_dyn_parameters.contact_parameters.body
381
+ ),
382
+ )
383
+
384
+ @staticmethod
385
+ def build(
386
+ tangential_deformation: jtp.Matrix | None = None,
387
+ number_of_collidable_points: int | None = None,
388
+ ) -> SoftContactsState:
389
+ """
390
+ Create a `SoftContactsState`.
391
+
392
+ Args:
393
+ tangential_deformation:
394
+ The matrix of 3D tangential material deformations corresponding to
395
+ each collidable point.
396
+ number_of_collidable_points: The number of collidable points.
397
+
398
+ Returns:
399
+ A `SoftContactsState` instance.
400
+ """
401
+
402
+ tangential_deformation = (
403
+ tangential_deformation
404
+ if tangential_deformation is not None
405
+ else jnp.zeros(shape=(number_of_collidable_points, 3))
406
+ )
407
+
408
+ if tangential_deformation.shape[1] != 3:
409
+ raise RuntimeError("The tangential deformation matrix must have 3 columns.")
410
+
411
+ if (
412
+ number_of_collidable_points is not None
413
+ and tangential_deformation.shape[0] != number_of_collidable_points
414
+ ):
415
+ msg = "The number of collidable points must match the number of rows "
416
+ msg += "in the tangential deformation matrix."
417
+ raise RuntimeError(msg)
418
+
419
+ return SoftContactsState(
420
+ tangential_deformation=jnp.array(tangential_deformation).astype(float)
421
+ )
422
+
423
+ @staticmethod
424
+ def zero(model: js.model.JaxSimModel) -> SoftContactsState:
425
+ """
426
+ Build a zero `SoftContactsState` from a `JaxSimModel`.
427
+
428
+ Args:
429
+ model: The `JaxSimModel` associated with the soft contacts state.
430
+
431
+ Returns:
432
+ A zero `SoftContactsState` instance.
433
+ """
434
+
435
+ return SoftContactsState.build_from_jaxsim_model(model=model)
436
+
437
+ def valid(self, model: js.model.JaxSimModel) -> bool:
438
+ """
439
+ Check if the `SoftContactsState` is valid for a given `JaxSimModel`.
440
+
441
+ Args:
442
+ model: The `JaxSimModel` to validate the `SoftContactsState` against.
443
+
444
+ Returns:
445
+ `True` if the soft contacts state is valid for the given `JaxSimModel`,
446
+ `False` otherwise.
447
+ """
448
+
449
+ shape = self.tangential_deformation.shape
450
+ expected = (len(model.kin_dyn_parameters.contact_parameters.body), 3)
451
+
452
+ if shape != expected:
453
+ return False
454
+
455
+ return True
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.3.1.dev17
3
+ Version: 0.3.1.dev40
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -1,19 +1,20 @@
1
1
  jaxsim/__init__.py,sha256=xzuTuZrgKdWLqqDzbvqzm2cJrEtAbepOeUqDu7ByVek,2621
2
- jaxsim/_version.py,sha256=EQQfkY5WXMHFjdRnYAQqABGWC0VK4dlpuNh_wr1KxYA,426
2
+ jaxsim/_version.py,sha256=JHeHgaRnZLSH9QDd807jRnlOgX9sn4w_CGqGSxA8lL0,426
3
+ jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
3
4
  jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
4
5
  jaxsim/typing.py,sha256=cl7HHQCeP3mHmtF6EuQZcCjGvDmc_AryMWntP_lRBGg,722
5
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
6
7
  jaxsim/api/com.py,sha256=Yof6otFi-mLWAs1rqjmeNJTOWIH9gn7BdU5EIjiL6Ts,13481
7
8
  jaxsim/api/common.py,sha256=bqQ__pIQZbh-j8rkoHUkYHAgGiJnDzjHG-q4Ny0OOYQ,6646
8
- jaxsim/api/contact.py,sha256=79kcdq7C1_kWgxd1QWBabBhIPkwWEVLk-Fiz9kh-4so,12800
9
- jaxsim/api/data.py,sha256=fkVDBV1tODRYIaRb2N15l34InAcnzNygMGG1KFiIU2w,27307
9
+ jaxsim/api/contact.py,sha256=soB28vqmzUwE6CN36TU4keASWZoSWE2_zhJLXA8yw2E,13132
10
+ jaxsim/api/data.py,sha256=oAJ2suPeQLQZGpHZi98g6UZp1VcoDtuqT_aZBpynA30,27582
10
11
  jaxsim/api/frame.py,sha256=vSbFHL4WtKPySxunNoZLlM_aDuJXZtf8CSBKku63BAs,6178
11
12
  jaxsim/api/joint.py,sha256=-5DogPg4g4mmLckyVIVNjwv-Rxz0IWS7_md9nDlhPWA,4581
12
13
  jaxsim/api/kin_dyn_parameters.py,sha256=AEpDg9kihbKUN9PA8pNrAruSuWFUC-k_GGxtlcdcDiQ,29215
13
14
  jaxsim/api/link.py,sha256=MdMWaMpM5Dj5JHK8uwHZ4zR4Fjq3R4asi2sGTxk1OAs,16647
14
- jaxsim/api/model.py,sha256=iuNYsn4xIfX36smmZpwM2O5eftT7ioDQtb6mSUqWu6Q,59759
15
- jaxsim/api/ode.py,sha256=luTQJsIXUtCp_81dR42X7WrMvwrXtYbyJiqss29v7zA,10786
16
- jaxsim/api/ode_data.py,sha256=FxUIV5qDNOg_OiOXWs3UrhDgKhGmTKcbHqgr4NX5bv0,23290
15
+ jaxsim/api/model.py,sha256=HAnrlgPDl5CCZQzQ84AfjC_DZjmrCzBKEDodE6hyLf8,60518
16
+ jaxsim/api/ode.py,sha256=xQL53ppnKweMQWRNm5gGR8FTjqRVzds8WKg9js9k5TA,10780
17
+ jaxsim/api/ode_data.py,sha256=Sa2i1zZhqyQqIGv1jarTmmU-W9HhTw-DErs12kFA1GA,19737
17
18
  jaxsim/api/references.py,sha256=UA6kSQVBoq-bXSo99EOELf-_MD5MTy2zS0GtG3wQ410,16618
18
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
19
20
  jaxsim/integrators/common.py,sha256=9HXRVFo95Mpt6RcVhBrOfvOO7mDxqbkXeg_lKUibEFY,20693
@@ -43,23 +44,25 @@ jaxsim/parsers/descriptions/model.py,sha256=vfubtW68CUdgcbCHPcgKy0_BxzKQhhM8ycbC
43
44
  jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
44
45
  jaxsim/parsers/rod/parser.py,sha256=4COuhkAYv4-GIpCqvkXEJWpDEQczEkBM3KwpqX48Rek,13514
45
46
  jaxsim/parsers/rod/utils.py,sha256=KSjgy6WsmTrD5HZEA2x8hOBSRU4bUGOOHzxKkeFO5r8,5721
46
- jaxsim/rbda/__init__.py,sha256=MqEZwzu8SHPAlIFHmSXmCjehuOJGRX58OrBVAbBVMwg,374
47
+ jaxsim/rbda/__init__.py,sha256=H7DhXpxkPOi9lpUvg31IMHFfRafke1UoJLc5GQIdyhA,387
47
48
  jaxsim/rbda/aba.py,sha256=0OoCzHhf1v-qqr1y5PIrD7_mPwAlid0fjXxUrIa5E_s,9118
48
49
  jaxsim/rbda/collidable_points.py,sha256=4ZNJbEj2nEi15jBLR-GNbdaqKgkN58FBgqd_TXupEgg,4948
49
50
  jaxsim/rbda/crba.py,sha256=awsWEQXLE0UPEXIcZCVsAqBEPjyahMNzY9ux6nE1l-s,4739
50
51
  jaxsim/rbda/forward_kinematics.py,sha256=94W7TUXvZjMb-99CyYR8pObuxIYYX9B_dtRZqsNcThs,3418
51
52
  jaxsim/rbda/jacobian.py,sha256=M79bGir-2w_iJ2GurYhOGgMfJnp7ZMOCW6AeeWKK8iM,10745
52
53
  jaxsim/rbda/rnea.py,sha256=DjwkvXQVUSUclM3Uy3UPZ2tao91R5dGd4o7TsS2qObI,7650
53
- jaxsim/rbda/soft_contacts.py,sha256=0hx9JT4R1X2PPjhZ1EDizBR1gGoCFCtKYu86SeuIvvA,11269
54
54
  jaxsim/rbda/utils.py,sha256=zpbFM2Iq8cntku0BFVu9nfEqZhInCWi9D2INT6MFEI8,5003
55
+ jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
56
+ jaxsim/rbda/contacts/common.py,sha256=iMKLP30Qft9eGTiHo2iY-UoACJjg1JphA9_pW8wRdjc,2410
57
+ jaxsim/rbda/contacts/soft.py,sha256=TvjGrKFmk6IIip-D3WLcOr9hWjlmF11-ULPkAqJKTZY,15601
55
58
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
56
59
  jaxsim/terrain/terrain.py,sha256=UXQCt7TCkq6GkM8bOZu44pNTpf-FZWiKN6VE4kb4kFk,2342
57
60
  jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
58
61
  jaxsim/utils/jaxsim_dataclass.py,sha256=h26timZ_XrBL_Q_oymv-DkQd-EcUiHn8QexAaZXBY9c,11396
59
62
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
60
63
  jaxsim/utils/wrappers.py,sha256=QIJitSoljrKR_U4T3ewCJPT3DTh-tPZsRsg0t_MH93E,3896
61
- jaxsim-0.3.1.dev17.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
62
- jaxsim-0.3.1.dev17.dist-info/METADATA,sha256=zRsMl96hDJt919NgrEuxkhye1S8X20bi_nWdPzJiptU,9739
63
- jaxsim-0.3.1.dev17.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
64
- jaxsim-0.3.1.dev17.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
65
- jaxsim-0.3.1.dev17.dist-info/RECORD,,
64
+ jaxsim-0.3.1.dev40.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
65
+ jaxsim-0.3.1.dev40.dist-info/METADATA,sha256=ds7zFF0BWvBs8TLnHcQ54-diuMpka_Y5NKlF-vDK2nA,9739
66
+ jaxsim-0.3.1.dev40.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
67
+ jaxsim-0.3.1.dev40.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
68
+ jaxsim-0.3.1.dev40.dist-info/RECORD,,