jaxsim 0.4.3.dev105__py3-none-any.whl → 0.4.3.dev118__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev105'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev105')
15
+ __version__ = version = '0.4.3.dev118'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev118')
jaxsim/api/contact.py CHANGED
@@ -186,7 +186,9 @@ def collidable_point_dynamics(
186
186
  # Note that the material deformation rate is always returned in the mixed frame
187
187
  # C[W] = (W_p_C, [W]). This is convenient for integration purpose.
188
188
  W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
189
- W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
189
+ position=W_p_Ci,
190
+ velocity=W_ṗ_Ci,
191
+ tangential_deformation=data.state.contact.tangential_deformation,
190
192
  )
191
193
  aux_data = dict(m_dot=CW_ṁ)
192
194
 
@@ -0,0 +1,8 @@
1
+ from .common import ContactModel, ContactsParams, ContactsState
2
+ from .relaxed_rigid import (
3
+ RelaxedRigidContacts,
4
+ RelaxedRigidContactsParams,
5
+ RelaxedRigidContactsState,
6
+ )
7
+ from .rigid import RigidContacts, RigidContactsParams, RigidContactsState
8
+ from .soft import SoftContacts, SoftContactsParams, SoftContactsState
@@ -7,15 +7,20 @@ import jaxsim.terrain
7
7
  import jaxsim.typing as jtp
8
8
  from jaxsim.utils import JaxsimDataclass
9
9
 
10
+ try:
11
+ from typing import Self
12
+ except ImportError:
13
+ from typing_extensions import Self
10
14
 
11
- class ContactsState(abc.ABC):
15
+
16
+ class ContactsState(JaxsimDataclass):
12
17
  """
13
18
  Abstract class storing the state of the contacts model.
14
19
  """
15
20
 
16
21
  @classmethod
17
22
  @abc.abstractmethod
18
- def build(cls, **kwargs) -> ContactsState:
23
+ def build(cls: type[Self], **kwargs) -> Self:
19
24
  """
20
25
  Build the contact state object.
21
26
 
@@ -26,7 +31,7 @@ class ContactsState(abc.ABC):
26
31
 
27
32
  @classmethod
28
33
  @abc.abstractmethod
29
- def zero(cls, **kwargs) -> ContactsState:
34
+ def zero(cls: type[Self], **kwargs) -> Self:
30
35
  """
31
36
  Build a zero contact state.
32
37
 
@@ -36,7 +41,7 @@ class ContactsState(abc.ABC):
36
41
  pass
37
42
 
38
43
  @abc.abstractmethod
39
- def valid(self, **kwargs) -> bool:
44
+ def valid(self, **kwargs) -> jtp.BoolLike:
40
45
  """
41
46
  Check if the contacts state is valid.
42
47
  """
@@ -50,18 +55,20 @@ class ContactsParams(JaxsimDataclass):
50
55
 
51
56
  @classmethod
52
57
  @abc.abstractmethod
53
- def build(cls) -> ContactsParams:
58
+ def build(cls: type[Self], **kwargs) -> Self:
54
59
  """
55
60
  Create a `ContactsParams` instance with specified parameters.
61
+
56
62
  Returns:
57
63
  The `ContactsParams` instance.
58
64
  """
59
65
  pass
60
66
 
61
67
  @abc.abstractmethod
62
- def valid(self, *args, **kwargs) -> bool:
68
+ def valid(self, **kwargs) -> jtp.BoolLike:
63
69
  """
64
70
  Check if the parameters are valid.
71
+
65
72
  Returns:
66
73
  True if the parameters are valid, False otherwise.
67
74
  """
@@ -83,8 +90,8 @@ class ContactModel(JaxsimDataclass):
83
90
  @abc.abstractmethod
84
91
  def compute_contact_forces(
85
92
  self,
86
- position: jtp.Vector,
87
- velocity: jtp.Vector,
93
+ position: jtp.VectorLike,
94
+ velocity: jtp.VectorLike,
88
95
  **kwargs,
89
96
  ) -> tuple[jtp.Vector, tuple[Any, ...]]:
90
97
  """
@@ -16,6 +16,11 @@ from jaxsim.terrain.terrain import FlatTerrain, Terrain
16
16
 
17
17
  from .common import ContactModel, ContactsParams, ContactsState
18
18
 
19
+ try:
20
+ from typing import Self
21
+ except ImportError:
22
+ from typing_extensions import Self
23
+
19
24
 
20
25
  @jax_dataclasses.pytree_dataclass
21
26
  class RelaxedRigidContactsParams(ContactsParams):
@@ -106,7 +111,8 @@ class RelaxedRigidContactsParams(ContactsParams):
106
111
 
107
112
  @classmethod
108
113
  def build(
109
- cls,
114
+ cls: type[Self],
115
+ *,
110
116
  time_constant: jtp.FloatLike | None = None,
111
117
  damping_coefficient: jtp.FloatLike | None = None,
112
118
  d_min: jtp.FloatLike | None = None,
@@ -119,7 +125,7 @@ class RelaxedRigidContactsParams(ContactsParams):
119
125
  mu: jtp.FloatLike | None = None,
120
126
  max_iterations: jtp.IntLike | None = None,
121
127
  tolerance: jtp.FloatLike | None = None,
122
- ) -> RelaxedRigidContactsParams:
128
+ ) -> Self:
123
129
  """Create a `RelaxedRigidContactsParams` instance"""
124
130
 
125
131
  return cls(
@@ -132,7 +138,8 @@ class RelaxedRigidContactsParams(ContactsParams):
132
138
  }
133
139
  )
134
140
 
135
- def valid(self) -> bool:
141
+ def valid(self) -> jtp.BoolLike:
142
+
136
143
  return bool(
137
144
  jnp.all(self.time_constant >= 0.0)
138
145
  and jnp.all(self.damping_coefficient > 0.0)
@@ -155,18 +162,19 @@ class RelaxedRigidContactsState(ContactsState):
155
162
  def __eq__(self, other: RelaxedRigidContactsState) -> bool:
156
163
  return hash(self) == hash(other)
157
164
 
158
- @staticmethod
159
- def build() -> RelaxedRigidContactsState:
165
+ @classmethod
166
+ def build(cls: type[Self]) -> Self:
160
167
  """Create a `RelaxedRigidContactsState` instance"""
161
168
 
162
- return RelaxedRigidContactsState()
169
+ return cls()
163
170
 
164
- @staticmethod
165
- def zero(model: js.model.JaxSimModel) -> RelaxedRigidContactsState:
171
+ @classmethod
172
+ def zero(cls: type[Self], **kwargs) -> Self:
166
173
  """Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`."""
167
- return RelaxedRigidContactsState.build()
168
174
 
169
- def valid(self, model: js.model.JaxSimModel) -> bool:
175
+ return cls.build()
176
+
177
+ def valid(self, **kwargs) -> jtp.BoolLike:
170
178
  return True
171
179
 
172
180
 
@@ -182,10 +190,12 @@ class RelaxedRigidContacts(ContactModel):
182
190
  default_factory=FlatTerrain
183
191
  )
184
192
 
193
+ @jax.jit
185
194
  def compute_contact_forces(
186
195
  self,
187
- position: jtp.Vector,
188
- velocity: jtp.Vector,
196
+ position: jtp.VectorLike,
197
+ velocity: jtp.VectorLike,
198
+ *,
189
199
  model: js.model.JaxSimModel,
190
200
  data: js.data.JaxSimModelData,
191
201
  link_forces: jtp.MatrixLike | None = None,
@@ -14,6 +14,11 @@ from jaxsim.terrain import FlatTerrain, Terrain
14
14
 
15
15
  from .common import ContactModel, ContactsParams, ContactsState
16
16
 
17
+ try:
18
+ from typing import Self
19
+ except ImportError:
20
+ from typing_extensions import Self
21
+
17
22
 
18
23
  @jax_dataclasses.pytree_dataclass
19
24
  class RigidContactsParams(ContactsParams):
@@ -50,19 +55,22 @@ class RigidContactsParams(ContactsParams):
50
55
 
51
56
  @classmethod
52
57
  def build(
53
- cls,
58
+ cls: type[Self],
59
+ *,
54
60
  mu: jtp.FloatLike | None = None,
55
61
  K: jtp.FloatLike | None = None,
56
62
  D: jtp.FloatLike | None = None,
57
- ) -> RigidContactsParams:
63
+ ) -> Self:
58
64
  """Create a `RigidContactParams` instance"""
59
- return RigidContactsParams(
65
+
66
+ return cls(
60
67
  mu=mu or cls.__dataclass_fields__["mu"].default,
61
68
  K=K or cls.__dataclass_fields__["K"].default,
62
69
  D=D or cls.__dataclass_fields__["D"].default,
63
70
  )
64
71
 
65
- def valid(self) -> bool:
72
+ def valid(self) -> jtp.BoolLike:
73
+
66
74
  return bool(
67
75
  jnp.all(self.mu >= 0.0)
68
76
  and jnp.all(self.K >= 0.0)
@@ -77,18 +85,19 @@ class RigidContactsState(ContactsState):
77
85
  def __eq__(self, other: RigidContactsState) -> bool:
78
86
  return hash(self) == hash(other)
79
87
 
80
- @staticmethod
81
- def build(**kwargs) -> RigidContactsState:
88
+ @classmethod
89
+ def build(cls: type[Self]) -> Self:
82
90
  """Create a `RigidContactsState` instance"""
83
91
 
84
- return RigidContactsState()
92
+ return cls()
85
93
 
86
- @staticmethod
87
- def zero(**kwargs) -> RigidContactsState:
94
+ @classmethod
95
+ def zero(cls: type[Self], **kwargs) -> Self:
88
96
  """Build a zero `RigidContactsState` instance from a `JaxSimModel`."""
89
- return RigidContactsState.build()
90
97
 
91
- def valid(self, **kwargs) -> bool:
98
+ return cls.build()
99
+
100
+ def valid(self, **kwargs) -> jtp.BoolLike:
92
101
  return True
93
102
 
94
103
 
@@ -117,7 +126,8 @@ class RigidContacts(ContactModel):
117
126
  terrain_height: The height of the terrain at the collidable point position.
118
127
 
119
128
  Returns:
120
- A tuple containing the activation state of the collidable points and the contact penetration depth h.
129
+ A tuple containing the activation state of the collidable points
130
+ and the contact penetration depth h.
121
131
  """
122
132
 
123
133
  # TODO: reduce code duplication with js.contact.in_contact
@@ -154,8 +164,8 @@ class RigidContacts(ContactModel):
154
164
 
155
165
  Args:
156
166
  inactive_collidable_points: The activation state of the collidable points.
157
- M: The mass matrix of the system.
158
- J_WC: The Jacobian matrix of the collidable points.
167
+ M: The mass matrix of the system (in mixed representation).
168
+ J_WC: The Jacobian matrix of the collidable points (in mixed representation).
159
169
  data: The `JaxSimModelData` instance.
160
170
  """
161
171
 
@@ -206,10 +216,12 @@ class RigidContacts(ContactModel):
206
216
 
207
217
  return BW_ν_post_impact
208
218
 
219
+ @jax.jit
209
220
  def compute_contact_forces(
210
221
  self,
211
- position: jtp.Vector,
212
- velocity: jtp.Vector,
222
+ position: jtp.VectorLike,
223
+ velocity: jtp.VectorLike,
224
+ *,
213
225
  model: js.model.JaxSimModel,
214
226
  data: js.data.JaxSimModelData,
215
227
  link_forces: jtp.MatrixLike | None = None,
@@ -7,12 +7,18 @@ import jax.numpy as jnp
7
7
  import jax_dataclasses
8
8
 
9
9
  import jaxsim.api as js
10
+ import jaxsim.math
10
11
  import jaxsim.typing as jtp
11
- from jaxsim.math import Skew, StandardGravity
12
+ from jaxsim.math import StandardGravity
12
13
  from jaxsim.terrain import FlatTerrain, Terrain
13
14
 
14
15
  from .common import ContactModel, ContactsParams, ContactsState
15
16
 
17
+ try:
18
+ from typing import Self
19
+ except ImportError:
20
+ from typing_extensions import Self
21
+
16
22
 
17
23
  @jax_dataclasses.pytree_dataclass
18
24
  class SoftContactsParams(ContactsParams):
@@ -30,6 +36,14 @@ class SoftContactsParams(ContactsParams):
30
36
  default_factory=lambda: jnp.array(0.5, dtype=float)
31
37
  )
32
38
 
39
+ p: jtp.Float = dataclasses.field(
40
+ default_factory=lambda: jnp.array(0.5, dtype=float)
41
+ )
42
+
43
+ q: jtp.Float = dataclasses.field(
44
+ default_factory=lambda: jnp.array(0.5, dtype=float)
45
+ )
46
+
33
47
  def __hash__(self) -> int:
34
48
 
35
49
  from jaxsim.utils.wrappers import HashedNumpyArray
@@ -39,6 +53,8 @@ class SoftContactsParams(ContactsParams):
39
53
  HashedNumpyArray.hash_of_array(self.K),
40
54
  HashedNumpyArray.hash_of_array(self.D),
41
55
  HashedNumpyArray.hash_of_array(self.mu),
56
+ HashedNumpyArray.hash_of_array(self.p),
57
+ HashedNumpyArray.hash_of_array(self.q),
42
58
  )
43
59
  )
44
60
 
@@ -49,10 +65,16 @@ class SoftContactsParams(ContactsParams):
49
65
 
50
66
  return hash(self) == hash(other)
51
67
 
52
- @staticmethod
68
+ @classmethod
53
69
  def build(
54
- K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5
55
- ) -> SoftContactsParams:
70
+ cls: type[Self],
71
+ *,
72
+ K: jtp.FloatLike = 1e6,
73
+ D: jtp.FloatLike = 2_000,
74
+ mu: jtp.FloatLike = 0.5,
75
+ p: jtp.FloatLike = 0.5,
76
+ q: jtp.FloatLike = 0.5,
77
+ ) -> Self:
56
78
  """
57
79
  Create a SoftContactsParams instance with specified parameters.
58
80
 
@@ -60,6 +82,12 @@ class SoftContactsParams(ContactsParams):
60
82
  K: The stiffness parameter.
61
83
  D: The damping parameter of the soft contacts model.
62
84
  mu: The static friction coefficient.
85
+ p:
86
+ The exponent p corresponding to the damping-related non-linearity
87
+ of the Hunt/Crossley model.
88
+ q:
89
+ The exponent q corresponding to the spring-related non-linearity
90
+ of the Hunt/Crossley model
63
91
 
64
92
  Returns:
65
93
  A SoftContactsParams instance with the specified parameters.
@@ -69,10 +97,13 @@ class SoftContactsParams(ContactsParams):
69
97
  K=jnp.array(K, dtype=float),
70
98
  D=jnp.array(D, dtype=float),
71
99
  mu=jnp.array(mu, dtype=float),
100
+ p=jnp.array(p, dtype=float),
101
+ q=jnp.array(q, dtype=float),
72
102
  )
73
103
 
74
- @staticmethod
104
+ @classmethod
75
105
  def build_default_from_jaxsim_model(
106
+ cls: type[Self],
76
107
  model: js.model.JaxSimModel,
77
108
  *,
78
109
  standard_gravity: jtp.FloatLike = StandardGravity,
@@ -80,6 +111,8 @@ class SoftContactsParams(ContactsParams):
80
111
  max_penetration: jtp.FloatLike = 0.001,
81
112
  number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
82
113
  damping_ratio: jtp.FloatLike = 1.0,
114
+ p: jtp.FloatLike = 0.5,
115
+ q: jtp.FloatLike = 0.5,
83
116
  ) -> SoftContactsParams:
84
117
  """
85
118
  Create a SoftContactsParams instance with good default parameters.
@@ -94,6 +127,12 @@ class SoftContactsParams(ContactsParams):
94
127
  The number of contacts supporting the weight of the model
95
128
  in steady state.
96
129
  damping_ratio: The ratio controlling the damping behavior.
130
+ p:
131
+ The exponent p corresponding to the damping-related non-linearity
132
+ of the Hunt/Crossley model.
133
+ q:
134
+ The exponent q corresponding to the spring-related non-linearity
135
+ of the Hunt/Crossley model
97
136
 
98
137
  Returns:
99
138
  A `SoftContactsParams` instance with the specified parameters.
@@ -126,9 +165,9 @@ class SoftContactsParams(ContactsParams):
126
165
  critical_damping = 2 * jnp.sqrt(K * m)
127
166
  D = ξ * critical_damping
128
167
 
129
- return SoftContactsParams.build(K=K, D=D, mu=μc)
168
+ return SoftContactsParams.build(K=K, D=D, mu=μc, p=p, q=q)
130
169
 
131
- def valid(self) -> bool:
170
+ def valid(self) -> jtp.BoolLike:
132
171
  """
133
172
  Check if the parameters are valid.
134
173
 
@@ -136,11 +175,15 @@ class SoftContactsParams(ContactsParams):
136
175
  `True` if the parameters are valid, `False` otherwise.
137
176
  """
138
177
 
139
- return (
140
- jnp.all(self.K >= 0.0)
141
- and jnp.all(self.D >= 0.0)
142
- and jnp.all(self.mu >= 0.0)
143
- )
178
+ return jnp.hstack(
179
+ [
180
+ self.K >= 0.0,
181
+ self.D >= 0.0,
182
+ self.mu >= 0.0,
183
+ self.p >= 0.0,
184
+ self.q >= 0.0,
185
+ ]
186
+ ).all()
144
187
 
145
188
 
146
189
  @jax_dataclasses.pytree_dataclass
@@ -157,179 +200,158 @@ class SoftContacts(ContactModel):
157
200
 
158
201
  def compute_contact_forces(
159
202
  self,
160
- position: jtp.Vector,
161
- velocity: jtp.Vector,
162
- tangential_deformation: jtp.Vector,
203
+ position: jtp.VectorLike,
204
+ velocity: jtp.VectorLike,
205
+ *,
206
+ tangential_deformation: jtp.VectorLike,
163
207
  ) -> tuple[jtp.Vector, tuple[jtp.Vector]]:
164
- """
165
- Compute the contact forces and material deformation rate.
166
208
 
167
- Args:
168
- position: The position of the collidable point.
169
- velocity: The linear velocity of the collidable point.
170
- tangential_deformation: The tangential deformation.
171
-
172
- Returns:
173
- A tuple containing the contact force and material deformation rate.
174
- """
209
+ # Convert the input vectors to arrays.
210
+ W_p_C = jnp.array(position, dtype=float).squeeze()
211
+ W_ṗ_C = jnp.array(velocity, dtype=float).squeeze()
212
+ m = jnp.array(tangential_deformation, dtype=float).squeeze()
175
213
 
176
- # Short name of parameters
214
+ # Short name of parameters.
177
215
  K = self.parameters.K
178
216
  D = self.parameters.D
179
217
  μ = self.parameters.mu
180
218
 
181
- # Material 3D tangential deformation and its derivative
182
- m = tangential_deformation.squeeze()
183
- ṁ = jnp.zeros_like(m)
219
+ # Compute the penetration depth, its rate, and the considered terrain normal.
220
+ δ, δ̇, n̂ = self.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=self.terrain)
184
221
 
185
- # Note: all the small hardcoded tolerances in this method have been introduced
186
- # to allow jax differentiating through this algorithm. They should not affect
187
- # the accuracy of the simulation, although they might make it less readable.
222
+ # Get the exponents of the Hunt/Crossley model non-linear terms.
223
+ p = self.parameters.p
224
+ q = self.parameters.q
225
+
226
+ # There are few operations like computing the norm of a vector with zero length
227
+ # or computing the square root of zero that are problematic in an AD context.
228
+ # To avoid these issues, we introduce a small tolerance ε to their arguments
229
+ # and make sure that we do not check them against zero directly.
230
+ ε = jnp.finfo(float).eps
231
+
232
+ # Compute the powers of the penetration depth.
233
+ # Inject ε to address AD issues in differentiating the square root when
234
+ # p and q are fractional.
235
+ δp = jnp.power(δ + ε, p)
236
+ δq = jnp.power(δ + ε, q)
188
237
 
189
238
  # ========================
190
- # Normal force computation
239
+ # Compute the normal force
191
240
  # ========================
192
241
 
193
- # Unpack the position of the collidable point.
194
- px, py, pz = W_p_C = position.squeeze()
195
- W_ṗ_C = velocity.squeeze()
242
+ # Non-linear spring-damper model (Hunt/Crossley model).
243
+ # This is the force magnitude along the direction normal to the terrain.
244
+ force_normal_mag = (K * δp) * δ + (D * δq) * δ̇
196
245
 
197
- # Compute the terrain normal and the contact depth.
198
- = self.terrain.normal(x=px, y=py).squeeze()
199
- h = jnp.array([0, 0, self.terrain.height(x=px, y=py) - pz])
246
+ # Depending on the magnitude of δ̇, the normal force could be negative.
247
+ force_normal_mag = jnp.maximum(0.0, force_normal_mag)
200
248
 
201
- # Compute the penetration depth normal to the terrain.
202
- δ = jnp.maximum(0.0, jnp.dot(h,))
249
+ # Compute the 3D linear force in C[W] frame.
250
+ f_normal = force_normal_mag *
251
+
252
+ # ============================
253
+ # Compute the tangential force
254
+ # ============================
255
+
256
+ # Extract the tangential component of the velocity.
257
+ v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂
258
+
259
+ # Extract the tangential component of the material deformation.
260
+ # This should not be necessary if the sticking-slipping transition occurs
261
+ # in a terrain area with a locally constant normal. However, this assumption
262
+ # is not true in general for highly uneven terrains.
263
+ m_normal = jnp.dot(m, n̂) * n̂
264
+ m_tangential = m - jnp.dot(m, n̂) * n̂
265
+
266
+ # Compute the tangential force in the sticking case.
267
+ f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential)
268
+
269
+ # Detect the contact type (sticking or slipping).
270
+ # Note that if there is no contact, sticking is set to True, and this detail
271
+ # is exploited in the computation of the `contact_status` variable.
272
+ sticking = jnp.logical_or(
273
+ δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2
274
+ )
203
275
 
204
- # Compute the penetration normal velocity.
205
- δ̇ = -jnp.dot(W_ṗ_C, n̂)
276
+ # Compute the direction of the tangential force.
277
+ # To prevent dividing by zero, we use a switch statement.
278
+ # The ε, instead, is needed to make AD happy.
279
+ f_tangential_direction = jnp.where(
280
+ f_tangential.dot(f_tangential) != 0,
281
+ f_tangential / jnp.linalg.norm(f_tangential + ε),
282
+ jnp.zeros(3),
283
+ )
206
284
 
207
- # Non-linear spring-damper model.
208
- # This is the force magnitude along the direction normal to the terrain.
209
- force_normal_mag = jax.lax.select(
210
- pred=δ >= 1e-9,
211
- on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇),
212
- on_false=jnp.array(0.0),
285
+ # Project the tangential force to the friction cone if slipping.
286
+ f_tangential = jnp.where(
287
+ sticking,
288
+ f_tangential,
289
+ jnp.minimum(μ * force_normal_mag, jnp.linalg.norm(f_tangential + ε))
290
+ * f_tangential_direction,
213
291
  )
214
292
 
215
- # Prevent negative normal forces that might occur when δ̇ is largely negative.
216
- force_normal_mag = jnp.maximum(0.0, force_normal_mag)
293
+ # Set the tangential force to zero if there is no contact.
294
+ f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential)
217
295
 
218
- # Compute the 3D linear force in C[W] frame.
219
- force_normal = force_normal_mag *
296
+ # =====================================
297
+ # Compute the material deformation rate
298
+ # =====================================
220
299
 
221
- # ====================================
222
- # No friction and no tangential forces
223
- # ====================================
300
+ # Compute the derivative of the material deformation.
301
+ ṁ_no_contact = -(K / D) * m
302
+ ṁ_sticking = v_tangential - (K / D) * m_normal
303
+ ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq)
224
304
 
225
- # Compute the adjoint C[W]->W for transforming 6D forces from mixed to inertial.
226
- # Note: this is equal to the 6D velocities transform: CW_X_W.transpose().
227
- W_Xf_CW = jnp.vstack(
228
- [
229
- jnp.block([jnp.eye(3), jnp.zeros(shape=(3, 3))]),
230
- jnp.block([Skew.wedge(W_p_C), jnp.eye(3)]),
231
- ]
232
- )
305
+ # Compute the contact status:
306
+ # 0: slipping
307
+ # 1: sticking
308
+ # 2: no contact
309
+ contact_status = sticking.astype(int)
310
+ contact_status += (δ <= 0).astype(int)
233
311
 
234
- def with_no_friction():
235
- # Compute 6D mixed force in C[W].
236
- CW_f_lin = force_normal
237
- CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)])
238
-
239
- # Compute lin-ang 6D forces (inertial representation).
240
- W_f = W_Xf_CW @ CW_f
241
-
242
- return W_f, (ṁ,)
243
-
244
- # =========================
245
- # Compute tangential forces
246
- # =========================
247
-
248
- def with_friction():
249
- # Initialize the tangential deformation rate ṁ.
250
- # For inactive contacts with m≠0, this is the dynamics of the material
251
- # relaxation converging exponentially to steady state.
252
- ṁ = (-K / D) * m
253
-
254
- # Check if the collidable point is below ground.
255
- # Note: when δ=0, we consider the point still not it contact such that
256
- # we prevent divisions by 0 in the computations below.
257
- active_contact = pz < self.terrain.height(x=px, y=py)
258
-
259
- def above_terrain():
260
- return jnp.zeros(6), (ṁ,)
261
-
262
- def below_terrain():
263
- # Decompose the velocity in normal and tangential components.
264
- v_normal = jnp.dot(W_ṗ_C, n̂) * n̂
265
- v_tangential = W_ṗ_C - v_normal
266
-
267
- # Compute the tangential force. If inside the friction cone, the contact.
268
- f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)
269
-
270
- def sticking_contact():
271
- # Sum the normal and tangential forces, and create the 6D force.
272
- CW_f_stick = force_normal + f_tangential
273
- CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)])
274
-
275
- # In this case the 3D material deformation is the tangential velocity.
276
- ṁ = v_tangential
277
-
278
- # Return the 6D force in the contact frame and
279
- # the deformation derivative.
280
- return CW_f, ṁ
281
-
282
- def slipping_contact():
283
- # Project the force to the friction cone boundary.
284
- f_tangential_projected = (μ * force_normal_mag) * (
285
- f_tangential / jnp.maximum(jnp.linalg.norm(f_tangential), 1e-9)
286
- )
287
-
288
- # Sum the normal and tangential forces, and create the 6D force.
289
- CW_f_slip = force_normal + f_tangential_projected
290
- CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)])
291
-
292
- # Correct the material deformation derivative for slipping contacts.
293
- # Basically we compute ṁ such that we get `f_tangential` on the cone
294
- # given the current (m, δ).
295
- ε = 1e-9
296
- δε = jnp.maximum(δ, ε)
297
- α = -K * jnp.sqrt(δε)
298
- β = -D * jnp.sqrt(δε)
299
- ṁ = (f_tangential_projected - α * m) / β
300
-
301
- # Return the 6D force in the contact frame and
302
- # the deformation derivative.
303
- return CW_f, ṁ
304
-
305
- CW_f, ṁ = jax.lax.cond(
306
- pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2,
307
- true_fun=lambda _: slipping_contact(),
308
- false_fun=lambda _: sticking_contact(),
309
- operand=None,
310
- )
311
-
312
- # Express the 6D force in the world frame.
313
- W_f = W_Xf_CW @ CW_f
314
-
315
- # Return the 6D force in the world frame and the deformation derivative.
316
- return W_f, (ṁ,)
317
-
318
- # (W_f, (ṁ,))
319
- return jax.lax.cond(
320
- pred=active_contact,
321
- true_fun=lambda _: below_terrain(),
322
- false_fun=lambda _: above_terrain(),
323
- operand=None,
324
- )
312
+ # Select the right material deformation rate depending on the contact status.
313
+ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact)
325
314
 
326
- # (W_f, (ṁ,))
327
- return jax.lax.cond(
328
- pred=(μ == 0.0),
329
- true_fun=lambda _: with_no_friction(),
330
- false_fun=lambda _: with_friction(),
331
- operand=None,
332
- )
315
+ # ==========================================
316
+ # Compute and return the final contact force
317
+ # ==========================================
318
+
319
+ # Sum the normal and tangential forces and create a mixed 6D force.
320
+ CW_f = jnp.hstack([f_normal + f_tangential, jnp.zeros(3)])
321
+
322
+ # Compute the 6D force transform from the mixed to the inertial-fixed frame.
323
+ W_Xf_CW = jaxsim.math.Adjoint.from_quaternion_and_translation(
324
+ translation=W_p_C, inverse=True
325
+ ).T
326
+
327
+ return W_Xf_CW @ CW_f, (ṁ,)
328
+
329
+ @staticmethod
330
+ @jax.jit
331
+ def compute_penetration_data(
332
+ p: jtp.VectorLike,
333
+ v: jtp.VectorLike,
334
+ terrain: jaxsim.terrain.Terrain,
335
+ ) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
336
+
337
+ # Pre-process the position and the linear velocity of the collidable point.
338
+ W_ṗ_C = jnp.array(v).squeeze()
339
+ px, py, pz = jnp.array(p).squeeze()
340
+
341
+ # Compute the terrain normal and the contact depth.
342
+ n̂ = terrain.normal(x=px, y=py).squeeze()
343
+ h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz])
344
+
345
+ # Compute the penetration depth normal to the terrain.
346
+ δ = jnp.maximum(0.0, jnp.dot(h, n̂))
347
+
348
+ # Compute the penetration normal velocity.
349
+ δ̇ = -jnp.dot(W_ṗ_C, n̂)
350
+
351
+ # Enforce the penetration rate to be zero when the penetration depth is zero.
352
+ δ̇ = jnp.where(δ > 0, δ̇, 0.0)
353
+
354
+ return δ, δ̇, n̂
333
355
 
334
356
 
335
357
  @jax_dataclasses.pytree_dataclass
@@ -346,21 +368,24 @@ class SoftContactsState(ContactsState):
346
368
  tangential_deformation: jtp.Matrix
347
369
 
348
370
  def __hash__(self) -> int:
371
+
349
372
  return hash(
350
373
  tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist())
351
374
  )
352
375
 
353
- def __eq__(self, other: SoftContactsState) -> bool:
354
- if not isinstance(other, SoftContactsState):
376
+ def __eq__(self: Self, other: Self) -> bool:
377
+
378
+ if not isinstance(other, type(self)):
355
379
  return False
356
380
 
357
381
  return hash(self) == hash(other)
358
382
 
359
- @staticmethod
383
+ @classmethod
360
384
  def build_from_jaxsim_model(
385
+ cls: type[Self],
361
386
  model: js.model.JaxSimModel | None = None,
362
- tangential_deformation: jtp.Matrix | None = None,
363
- ) -> SoftContactsState:
387
+ tangential_deformation: jtp.MatrixLike | None = None,
388
+ ) -> Self:
364
389
  """
365
390
  Build a `SoftContactsState` from a `JaxSimModel`.
366
391
 
@@ -376,18 +401,20 @@ class SoftContactsState(ContactsState):
376
401
  `JaxSimModel` and initialized to zero.
377
402
  """
378
403
 
379
- return SoftContactsState.build(
404
+ return cls.build(
380
405
  tangential_deformation=tangential_deformation,
381
406
  number_of_collidable_points=len(
382
407
  model.kin_dyn_parameters.contact_parameters.body
383
408
  ),
384
409
  )
385
410
 
386
- @staticmethod
411
+ @classmethod
387
412
  def build(
388
- tangential_deformation: jtp.Matrix | None = None,
413
+ cls: type[Self],
414
+ *,
415
+ tangential_deformation: jtp.MatrixLike | None = None,
389
416
  number_of_collidable_points: int | None = None,
390
- ) -> SoftContactsState:
417
+ ) -> Self:
391
418
  """
392
419
  Create a `SoftContactsState`.
393
420
 
@@ -402,10 +429,10 @@ class SoftContactsState(ContactsState):
402
429
  """
403
430
 
404
431
  tangential_deformation = (
405
- tangential_deformation
432
+ jnp.atleast_2d(tangential_deformation)
406
433
  if tangential_deformation is not None
407
434
  else jnp.zeros(shape=(number_of_collidable_points, 3))
408
- )
435
+ ).astype(float)
409
436
 
410
437
  if tangential_deformation.shape[1] != 3:
411
438
  raise RuntimeError("The tangential deformation matrix must have 3 columns.")
@@ -418,12 +445,10 @@ class SoftContactsState(ContactsState):
418
445
  msg += "in the tangential deformation matrix."
419
446
  raise RuntimeError(msg)
420
447
 
421
- return SoftContactsState(
422
- tangential_deformation=jnp.array(tangential_deformation).astype(float)
423
- )
448
+ return cls(tangential_deformation=tangential_deformation)
424
449
 
425
- @staticmethod
426
- def zero(model: js.model.JaxSimModel) -> SoftContactsState:
450
+ @classmethod
451
+ def zero(cls: type[Self], *, model: js.model.JaxSimModel) -> Self:
427
452
  """
428
453
  Build a zero `SoftContactsState` from a `JaxSimModel`.
429
454
 
@@ -434,9 +459,9 @@ class SoftContactsState(ContactsState):
434
459
  A zero `SoftContactsState` instance.
435
460
  """
436
461
 
437
- return SoftContactsState.build_from_jaxsim_model(model=model)
462
+ return cls.build_from_jaxsim_model(model=model)
438
463
 
439
- def valid(self, model: js.model.JaxSimModel) -> bool:
464
+ def valid(self, *, model: js.model.JaxSimModel) -> jtp.BoolLike:
440
465
  """
441
466
  Check if the `SoftContactsState` is valid for a given `JaxSimModel`.
442
467
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev105
3
+ Version: 0.4.3.dev118
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -1,12 +1,12 @@
1
1
  jaxsim/__init__.py,sha256=bSbpggIz5aG6QuGZLa0V2EfHjAOeucMxi-vIYxzLmN8,2788
2
- jaxsim/_version.py,sha256=1Pw2WIpSfIcMXUmFaqKP5hkuhIK6fZ2PSjrJqsDco98,428
2
+ jaxsim/_version.py,sha256=SpzKi91xA9hgMqjpOts0HnzKduXJ_CeCiMYlA0iOEgY,428
3
3
  jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
6
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
7
7
  jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
8
8
  jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
9
- jaxsim/api/contact.py,sha256=BQMIMHBFYiHe_LVx_bwxKCpy20uiy0V-NljHfYXWhI0,23013
9
+ jaxsim/api/contact.py,sha256=ocwsVS1jaBfrd81990hcgfS0-2xD8VVzDq7gdPguAUg,23087
10
10
  jaxsim/api/data.py,sha256=QldUHniJqKrdNtAcXuRaS9UyeslJ0Rjvb17UA0Ca5Tw,29008
11
11
  jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
12
12
  jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
@@ -52,19 +52,19 @@ jaxsim/rbda/forward_kinematics.py,sha256=2GmEoWsrioVl_SAbKRKfhOLz57pY4aR81PKRdul
52
52
  jaxsim/rbda/jacobian.py,sha256=p0EV_8cLzLVV-93VKznT7VPuRj8W7h7rQWkPlWJXfCA,11023
53
53
  jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
54
54
  jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
55
- jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
56
- jaxsim/rbda/contacts/common.py,sha256=VwAs742futAmLnDgbaOuLzNDBFiKDfYItdEZ4UcFgzE,2467
57
- jaxsim/rbda/contacts/relaxed_rigid.py,sha256=deTC0M2a_RER7iwVpxLCfuSlgBLqkTmHQdOJ4169IR4,13646
58
- jaxsim/rbda/contacts/rigid.py,sha256=BYvQgoaG5yKo2N1SLlXgjP6cb1OrJ1BawGXkJf0Hhi0,15060
59
- jaxsim/rbda/contacts/soft.py,sha256=_wvb5iZDjGcVg6rNQelN4LZN7qSC2NIp0HdKvZmlGfk,15647
55
+ jaxsim/rbda/contacts/__init__.py,sha256=Y1yT2zdgFa0zviZseI09wNaMcydH8TeoaWr6ehqzwdc,328
56
+ jaxsim/rbda/contacts/common.py,sha256=CEmLS_PT44AOWKJ0bWrJJBqm2Q9v9LiqvL0rht63-ic,2605
57
+ jaxsim/rbda/contacts/relaxed_rigid.py,sha256=VCFU2WG7MoDA6eo72VOLzNhXd1ZEy78F-52JVNASAcU,13696
58
+ jaxsim/rbda/contacts/rigid.py,sha256=QBjgBIJR6jz3w_dd-ZEex6K1c9_Cwwl5xrZALHzz5Zo,15224
59
+ jaxsim/rbda/contacts/soft.py,sha256=-d7zbMdKNq0aRT2zRXIu_Dbh8BL4VUnMDprz4Ddfwj0,16276
60
60
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
61
61
  jaxsim/terrain/terrain.py,sha256=xUQg47yGxIOcTkLPbnO3sruEGBhoCd16j1evTGlmNjI,5010
62
62
  jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
63
63
  jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
64
64
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
65
65
  jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
66
- jaxsim-0.4.3.dev105.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
- jaxsim-0.4.3.dev105.dist-info/METADATA,sha256=jYNJIkcAgei9qg7OLD7-E0HHYZRb3h4qFdmaVnyHXGo,17277
68
- jaxsim-0.4.3.dev105.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
69
- jaxsim-0.4.3.dev105.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
- jaxsim-0.4.3.dev105.dist-info/RECORD,,
66
+ jaxsim-0.4.3.dev118.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
+ jaxsim-0.4.3.dev118.dist-info/METADATA,sha256=SH1vKEjwiZxKVzUXu2BqZTNkEjIgXylaoaNuYwqiQQY,17277
68
+ jaxsim-0.4.3.dev118.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
69
+ jaxsim-0.4.3.dev118.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
+ jaxsim-0.4.3.dev118.dist-info/RECORD,,