jaxsim 0.4.3.dev186__py3-none-any.whl → 0.4.3.dev200__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev186'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev186')
15
+ __version__ = version = '0.4.3.dev200'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev200')
jaxsim/api/contact.py CHANGED
@@ -36,11 +36,10 @@ def collidable_point_kinematics(
36
36
  the linear component of the mixed 6D frame velocity.
37
37
  """
38
38
 
39
- from jaxsim.rbda import collidable_points
40
-
41
39
  # Switch to inertial-fixed since the RBDAs expect velocities in this representation.
42
40
  with data.switch_velocity_representation(VelRepr.Inertial):
43
- W_p_Ci, W_ṗ_Ci = collidable_points.collidable_points_pos_vel(
41
+
42
+ W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
44
43
  model=model,
45
44
  base_position=data.base_position(),
46
45
  base_quaternion=data.base_orientation(dcm=False),
@@ -304,6 +303,15 @@ def in_contact(
304
303
 
305
304
 
306
305
  def estimate_good_soft_contacts_parameters(
306
+ *args, **kwargs
307
+ ) -> jaxsim.rbda.contacts.ContactParamsTypes:
308
+
309
+ msg = "This method is deprecated, please use `{}`."
310
+ logging.warning(msg.format(estimate_good_contact_parameters.__name__))
311
+ return estimate_good_contact_parameters(*args, **kwargs)
312
+
313
+
314
+ def estimate_good_contact_parameters(
307
315
  model: js.model.JaxSimModel,
308
316
  *,
309
317
  standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
@@ -312,14 +320,9 @@ def estimate_good_soft_contacts_parameters(
312
320
  damping_ratio: jtp.FloatLike = 1.0,
313
321
  max_penetration: jtp.FloatLike | None = None,
314
322
  **kwargs,
315
- ) -> (
316
- jaxsim.rbda.contacts.RelaxedRigidContactsParams
317
- | jaxsim.rbda.contacts.RigidContactsParams
318
- | jaxsim.rbda.contacts.SoftContactsParams
319
- | jaxsim.rbda.contacts.ViscoElasticContactsParams
320
- ):
323
+ ) -> jaxsim.rbda.contacts.ContactParamsTypes:
321
324
  """
322
- Estimate good parameters for soft-like contact models.
325
+ Estimate good contact parameters.
323
326
 
324
327
  Args:
325
328
  model: The model to consider.
@@ -332,12 +335,19 @@ def estimate_good_soft_contacts_parameters(
332
335
  max_penetration:
333
336
  The maximum penetration allowed in steady state when the robot is
334
337
  supported by the configured number of active collidable points.
338
+ kwargs:
339
+ Additional model-specific parameters passed to the builder method of
340
+ the parameters class.
335
341
 
336
342
  Returns:
337
- The estimated good soft contacts parameters.
343
+ The estimated good contacts parameters.
344
+
345
+ Note:
346
+ This is primarily a convenience function for soft-like contact models.
347
+ However, it provides with some good default parameters also for the other ones.
338
348
 
339
349
  Note:
340
- This method provides a good starting point for the soft contacts parameters.
350
+ This method provides a good set of contacts parameters.
341
351
  The user is encouraged to fine-tune the parameters based on the
342
352
  specific application.
343
353
  """
@@ -364,6 +374,7 @@ def estimate_good_soft_contacts_parameters(
364
374
  max_δ = (
365
375
  max_penetration
366
376
  if max_penetration is not None
377
+ # Consider as default a 0.5% of the model height.
367
378
  else 0.005 * estimate_model_height(model=model)
368
379
  )
369
380
 
@@ -381,8 +392,11 @@ def estimate_good_soft_contacts_parameters(
381
392
  max_penetration=max_δ,
382
393
  number_of_active_collidable_points_steady_state=nc,
383
394
  damping_ratio=damping_ratio,
384
- p=model.contact_model.parameters.p,
385
- q=model.contact_model.parameters.q,
395
+ **dict(
396
+ p=model.contact_model.parameters.p,
397
+ q=model.contact_model.parameters.q,
398
+ )
399
+ | kwargs,
386
400
  )
387
401
 
388
402
  case contacts.ViscoElasticContacts():
@@ -396,15 +410,40 @@ def estimate_good_soft_contacts_parameters(
396
410
  max_penetration=max_δ,
397
411
  number_of_active_collidable_points_steady_state=nc,
398
412
  damping_ratio=damping_ratio,
399
- p=model.contact_model.parameters.p,
400
- q=model.contact_model.parameters.q,
401
- **kwargs,
413
+ **dict(
414
+ p=model.contact_model.parameters.p,
415
+ q=model.contact_model.parameters.q,
416
+ )
417
+ | kwargs,
402
418
  )
403
419
  )
404
420
 
421
+ case contacts.RigidContacts():
422
+ assert isinstance(model.contact_model, contacts.RigidContacts)
423
+
424
+ # Disable Baumgarte stabilization by default since it does not play
425
+ # well with the forward Euler integrator.
426
+ K = kwargs.get("K", 0.0)
427
+
428
+ parameters = contacts.RigidContactsParams.build(
429
+ mu=static_friction_coefficient,
430
+ **dict(
431
+ K=K,
432
+ D=2 * jnp.sqrt(K),
433
+ )
434
+ | kwargs,
435
+ )
436
+
437
+ case contacts.RelaxedRigidContacts():
438
+ assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
439
+
440
+ parameters = contacts.RelaxedRigidContactsParams.build(
441
+ mu=static_friction_coefficient,
442
+ **kwargs,
443
+ )
444
+
405
445
  case _:
406
- logging.warning("The active contact model is not soft-like, no-op.")
407
- parameters = model.contact_model.parameters
446
+ raise ValueError(f"Invalid contact model: {model.contact_model}")
408
447
 
409
448
  return parameters
410
449
 
jaxsim/api/data.py CHANGED
@@ -34,7 +34,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
34
34
 
35
35
  state: ODEState
36
36
 
37
- gravity: jtp.Array
37
+ gravity: jtp.Vector
38
38
 
39
39
  contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
40
40
 
@@ -224,7 +224,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
224
224
  jaxsim.rbda.contacts.SoftContacts
225
225
  | jaxsim.rbda.contacts.ViscoElasticContacts,
226
226
  ):
227
- contacts_params = js.contact.estimate_good_soft_contacts_parameters(
227
+
228
+ contacts_params = js.contact.estimate_good_contact_parameters(
228
229
  model=model, standard_gravity=standard_gravity
229
230
  )
230
231
 
jaxsim/api/model.py CHANGED
@@ -40,6 +40,8 @@ class JaxSimModel(JaxsimDataclass):
40
40
  default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
41
41
  )
42
42
 
43
+ # Note that this is the default contact model.
44
+ # Its parameters, if any, are then overridden from those stored in JaxSimModelData.
43
45
  contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field(
44
46
  default=None, repr=False
45
47
  )
@@ -2044,24 +2046,18 @@ def step(
2044
2046
  M = js.model.free_floating_mass_matrix(model, data_tf)
2045
2047
  W_p_C = js.contact.collidable_point_positions(model, data_tf)
2046
2048
 
2047
- # Compute the height of the terrain below each collidable point.
2048
- px, py, _ = W_p_C.T
2049
- terrain_height = jax.vmap(model.terrain.height)(px, py)
2050
-
2051
- # Compute the contact state.
2052
- inactive_collidable_points, _ = (
2053
- jaxsim.rbda.contacts.RigidContacts.detect_contacts(
2054
- W_p_C=W_p_C,
2055
- terrain_height=terrain_height,
2056
- )
2057
- )
2049
+ # Compute the penetration depth of the collidable points.
2050
+ δ, *_ = jax.vmap(
2051
+ jaxsim.rbda.contacts.common.compute_penetration_data,
2052
+ in_axes=(0, 0, None),
2053
+ )(W_p_C, jnp.zeros_like(W_p_C), model.terrain)
2058
2054
 
2059
2055
  # Compute the impact velocity.
2060
2056
  # It may be discontinuous in case new contacts are made.
2061
2057
  BW_nu_post_impact = (
2062
2058
  jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
2063
2059
  data=data_tf,
2064
- inactive_collidable_points=inactive_collidable_points,
2060
+ inactive_collidable_points=(δ <= 0),
2065
2061
  M=M,
2066
2062
  J_WC=J_WC,
2067
2063
  )
@@ -4,3 +4,10 @@ from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
4
4
  from .rigid import RigidContacts, RigidContactsParams
5
5
  from .soft import SoftContacts, SoftContactsParams
6
6
  from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams
7
+
8
+ ContactParamsTypes = (
9
+ SoftContactsParams
10
+ | RigidContactsParams
11
+ | RelaxedRigidContactsParams
12
+ | ViscoElasticContactsParams
13
+ )
@@ -1,8 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import abc
4
+ import functools
4
5
  from typing import Any
5
6
 
7
+ import jax
8
+ import jax.numpy as jnp
9
+
6
10
  import jaxsim.api as js
7
11
  import jaxsim.terrain
8
12
  import jaxsim.typing as jtp
@@ -14,6 +18,47 @@ except ImportError:
14
18
  from typing_extensions import Self
15
19
 
16
20
 
21
+ @functools.partial(jax.jit, static_argnames=("terrain",))
22
+ def compute_penetration_data(
23
+ p: jtp.VectorLike,
24
+ v: jtp.VectorLike,
25
+ terrain: jaxsim.terrain.Terrain,
26
+ ) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
27
+ """
28
+ Compute the penetration data (depth, rate, and terrain normal) of a collidable point.
29
+
30
+ Args:
31
+ p: The position of the collidable point.
32
+ v:
33
+ The linear velocity of the point (linear component of the mixed 6D velocity
34
+ of the implicit frame `C = (W_p_C, [W])` associated to the point).
35
+ terrain: The considered terrain.
36
+
37
+ Returns:
38
+ A tuple containing the penetration depth, the penetration velocity,
39
+ and the considered terrain normal.
40
+ """
41
+
42
+ # Pre-process the position and the linear velocity of the collidable point.
43
+ W_ṗ_C = jnp.array(v).squeeze()
44
+ px, py, pz = jnp.array(p).squeeze()
45
+
46
+ # Compute the terrain normal and the contact depth.
47
+ n̂ = terrain.normal(x=px, y=py).squeeze()
48
+ h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz])
49
+
50
+ # Compute the penetration depth normal to the terrain.
51
+ δ = jnp.maximum(0.0, jnp.dot(h, n̂))
52
+
53
+ # Compute the penetration normal velocity.
54
+ δ_dot = -jnp.dot(W_ṗ_C, n̂)
55
+
56
+ # Enforce the penetration rate to be zero when the penetration depth is zero.
57
+ δ_dot = jnp.where(δ > 0, δ_dot, 0.0)
58
+
59
+ return δ, δ_dot, n̂
60
+
61
+
17
62
  class ContactsParams(JaxsimDataclass):
18
63
  """
19
64
  Abstract class representing the parameters of a contact model.
@@ -86,7 +131,7 @@ class ContactModel(JaxsimDataclass):
86
131
  model: js.model.JaxSimModel,
87
132
  data: js.data.JaxSimModelData,
88
133
  **kwargs,
89
- ) -> tuple[jtp.Vector, tuple[Any, ...]]:
134
+ ) -> tuple[jtp.Matrix, tuple[Any, ...]]:
90
135
  """
91
136
  Compute the contact forces.
92
137
 
@@ -95,8 +140,9 @@ class ContactModel(JaxsimDataclass):
95
140
  data: The data of the considered model.
96
141
 
97
142
  Returns:
98
- A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame,
99
- and as second element a tuple of optional additional information.
143
+ A tuple containing as first element the computed 6D contact force applied to
144
+ the contact points and expressed in the world frame, and as second element
145
+ a tuple of optional additional information.
100
146
  """
101
147
 
102
148
  pass
@@ -12,11 +12,10 @@ import optax
12
12
  import jaxsim.api as js
13
13
  import jaxsim.typing as jtp
14
14
  from jaxsim import logging
15
- from jaxsim.api.common import VelRepr
16
- from jaxsim.math import Adjoint
15
+ from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
17
16
  from jaxsim.terrain.terrain import FlatTerrain, Terrain
18
17
 
19
- from .common import ContactModel, ContactsParams
18
+ from . import common
20
19
 
21
20
  try:
22
21
  from typing import Self
@@ -25,7 +24,7 @@ except ImportError:
25
24
 
26
25
 
27
26
  @jax_dataclasses.pytree_dataclass
28
- class RelaxedRigidContactsParams(ContactsParams):
27
+ class RelaxedRigidContactsParams(common.ContactsParams):
29
28
  """Parameters of the relaxed rigid contacts model."""
30
29
 
31
30
  # Time constant
@@ -116,14 +115,24 @@ class RelaxedRigidContactsParams(ContactsParams):
116
115
  ) -> Self:
117
116
  """Create a `RelaxedRigidContactsParams` instance"""
118
117
 
118
+ def default(name: str):
119
+ return cls.__dataclass_fields__[name].default_factory()
120
+
119
121
  return cls(
120
- **{
121
- field: jnp.array(locals().get(field, default), dtype=default.dtype)
122
- for field, default in map(
123
- lambda f: (f, cls.__dataclass_fields__[f].default),
124
- filter(lambda f: f != "__mutability__", cls.__dataclass_fields__),
125
- )
126
- }
122
+ time_constant=jnp.array(
123
+ time_constant or default("time_constant"), dtype=float
124
+ ),
125
+ damping_coefficient=jnp.array(
126
+ damping_coefficient or default("damping_coefficient"), dtype=float
127
+ ),
128
+ d_min=jnp.array(d_min or default("d_min"), dtype=float),
129
+ d_max=jnp.array(d_max or default("d_max"), dtype=float),
130
+ width=jnp.array(width or default("width"), dtype=float),
131
+ midpoint=jnp.array(midpoint or default("midpoint"), dtype=float),
132
+ power=jnp.array(power or default("power"), dtype=float),
133
+ stiffness=jnp.array(stiffness or default("stiffness"), dtype=float),
134
+ damping=jnp.array(damping or default("damping"), dtype=float),
135
+ mu=jnp.array(mu or default("mu"), dtype=float),
127
136
  )
128
137
 
129
138
  def valid(self) -> jtp.BoolLike:
@@ -142,7 +151,7 @@ class RelaxedRigidContactsParams(ContactsParams):
142
151
 
143
152
 
144
153
  @jax_dataclasses.pytree_dataclass
145
- class RelaxedRigidContacts(ContactModel):
154
+ class RelaxedRigidContacts(common.ContactModel):
146
155
  """Relaxed rigid contacts model."""
147
156
 
148
157
  parameters: RelaxedRigidContactsParams = dataclasses.field(
@@ -229,7 +238,7 @@ class RelaxedRigidContacts(ContactModel):
229
238
  *,
230
239
  link_forces: jtp.MatrixLike | None = None,
231
240
  joint_force_references: jtp.VectorLike | None = None,
232
- ) -> tuple[jtp.Vector, tuple[Any, ...]]:
241
+ ) -> tuple[jtp.Matrix, tuple]:
233
242
  """
234
243
  Compute the contact forces.
235
244
 
@@ -243,22 +252,23 @@ class RelaxedRigidContacts(ContactModel):
243
252
  Optional `(n_joints,)` vector of joint forces.
244
253
 
245
254
  Returns:
246
- A tuple containing the contact forces.
255
+ A tuple containing as first element the computed contact forces.
247
256
  """
248
257
 
249
258
  # Initialize the model and data this contact model is operating on.
250
259
  # This will raise an exception if either the contact model or the
251
260
  # contact parameters are not compatible.
252
261
  model, data = self.initialize_model_and_data(model=model, data=data)
262
+ assert isinstance(data.contacts_params, RelaxedRigidContactsParams)
253
263
 
254
- link_forces = (
255
- link_forces
264
+ link_forces = jnp.atleast_2d(
265
+ jnp.array(link_forces, dtype=float).squeeze()
256
266
  if link_forces is not None
257
267
  else jnp.zeros((model.number_of_links(), 6))
258
268
  )
259
269
 
260
- joint_force_references = (
261
- joint_force_references
270
+ joint_force_references = jnp.atleast_1d(
271
+ jnp.array(joint_force_references, dtype=float).squeeze()
262
272
  if joint_force_references is not None
263
273
  else jnp.zeros(model.number_of_joints())
264
274
  )
@@ -271,10 +281,10 @@ class RelaxedRigidContacts(ContactModel):
271
281
  joint_force_references=joint_force_references,
272
282
  )
273
283
 
274
- def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
284
+ def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
275
285
  x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))
276
286
 
277
- n̂ = self.terrain.normal(x=x, y=y).squeeze()
287
+ n̂ = model.terrain.normal(x=x, y=y).squeeze()
278
288
  h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])
279
289
 
280
290
  return jnp.dot(h, n̂)
@@ -286,19 +296,19 @@ class RelaxedRigidContacts(ContactModel):
286
296
  )
287
297
 
288
298
  # Compute the activation state of the collidable points
289
- δ = jax.vmap(_detect_contact)(*position.T)
299
+ δ = jax.vmap(detect_contact)(*position.T)
300
+
301
+ # Compute the transforms of the implicit frames corresponding to the
302
+ # collidable points.
303
+ W_H_C = js.contact.transforms(model=model, data=data)
290
304
 
291
305
  with (
292
306
  references.switch_velocity_representation(VelRepr.Mixed),
293
307
  data.switch_velocity_representation(VelRepr.Mixed),
294
308
  ):
295
- M = js.model.free_floating_mass_matrix(model=model, data=data)
296
- Jl_WC = jnp.vstack(
297
- jax.vmap(lambda J, height: J * (height < 0))(
298
- js.contact.jacobian(model=model, data=data)[:, :3, :], δ
299
- )
300
- )
301
- W_H_C = js.contact.transforms(model=model, data=data)
309
+
310
+ BW_ν = data.generalized_velocity()
311
+
302
312
  BW_ν̇_free = jnp.hstack(
303
313
  js.ode.system_acceleration(
304
314
  model=model,
@@ -309,20 +319,31 @@ class RelaxedRigidContacts(ContactModel):
309
319
  ),
310
320
  )
311
321
  )
312
- BW_ν = data.generalized_velocity()
322
+
323
+ M = js.model.free_floating_mass_matrix(model=model, data=data)
324
+
325
+ Jl_WC = jnp.vstack(
326
+ jax.vmap(lambda J, height: J * (height < 0))(
327
+ js.contact.jacobian(model=model, data=data)[:, :3, :], δ
328
+ )
329
+ )
330
+
313
331
  J̇_WC = jnp.vstack(
314
332
  jax.vmap(lambda J̇, height: J̇ * (height < 0))(
315
333
  js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
316
334
  ),
317
335
  )
318
336
 
319
- a_ref, R, K, D = self._regularizers(
320
- model=model,
321
- penetration=δ,
322
- velocity=velocity,
323
- parameters=self.parameters,
324
- )
337
+ # Compute the regularization terms.
338
+ a_ref, R, K, D = self._regularizers(
339
+ model=model,
340
+ penetration=δ,
341
+ velocity=velocity,
342
+ parameters=data.contacts_params,
343
+ )
325
344
 
345
+ # Compute the Delassus matrix and the free mixed linear acceleration of
346
+ # the collidable points.
326
347
  G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]
327
348
  CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν
328
349
 
@@ -330,26 +351,40 @@ class RelaxedRigidContacts(ContactModel):
330
351
  A = G + R
331
352
  b = CW_al_free_WC - a_ref
332
353
 
354
+ # Create the objective function to minimize as a lambda computing the cost
355
+ # from the optimized variables x.
333
356
  objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b))
334
357
 
358
+ # ========================================
359
+ # Helper function to run the L-BFGS solver
360
+ # ========================================
361
+
335
362
  def run_optimization(
336
- init_params: jtp.Array,
363
+ init_params: jtp.Vector,
337
364
  fun: Callable,
338
- opt: optax.GradientTransformation,
339
- maxiter: jtp.Int,
340
- tol: jtp.Float,
341
- **kwargs,
342
- ):
365
+ opt: optax.GradientTransformationExtraArgs,
366
+ maxiter: int,
367
+ tol: float,
368
+ ) -> tuple[jtp.Vector, optax.OptState]:
369
+
370
+ # Get the function to compute the loss and the gradient w.r.t. its inputs.
343
371
  value_and_grad_fn = optax.value_and_grad_from_state(fun)
344
372
 
345
- def step(carry):
373
+ # Initialize the carry of the following loop.
374
+ OptimizationCarry = tuple[jtp.Vector, optax.OptState]
375
+ init_carry: OptimizationCarry = (init_params, opt.init(params=init_params))
376
+
377
+ def step(carry: OptimizationCarry) -> OptimizationCarry:
378
+
346
379
  params, state = carry
380
+
347
381
  value, grad = value_and_grad_fn(
348
382
  params,
349
383
  state=state,
350
384
  A=A,
351
385
  b=b,
352
386
  )
387
+
353
388
  updates, state = opt.update(
354
389
  updates=grad,
355
390
  state=state,
@@ -360,22 +395,32 @@ class RelaxedRigidContacts(ContactModel):
360
395
  A=A,
361
396
  b=b,
362
397
  )
398
+
363
399
  params = optax.apply_updates(params, updates)
400
+
364
401
  return params, state
365
402
 
366
- def continuing_criterion(carry):
403
+ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
404
+
367
405
  _, state = carry
406
+
368
407
  iter_num = optax.tree_utils.tree_get(state, "count")
369
408
  grad = optax.tree_utils.tree_get(state, "grad")
370
409
  err = optax.tree_utils.tree_l2_norm(grad)
410
+
371
411
  return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))
372
412
 
373
- init_carry = (init_params, opt.init(init_params))
374
413
  final_params, final_state = jax.lax.while_loop(
375
414
  continuing_criterion, step, init_carry
376
415
  )
416
+
377
417
  return final_params, final_state
378
418
 
419
+ # ======================================
420
+ # Compute the contact forces with L-BFGS
421
+ # ======================================
422
+
423
+ # Initialize the optimized forces with a linear Hunt/Crossley model.
379
424
  init_params = (
380
425
  K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
381
426
  + D[:, jnp.newaxis] * velocity
@@ -390,28 +435,30 @@ class RelaxedRigidContacts(ContactModel):
390
435
  maxiter = solver_options.pop("maxiter")
391
436
 
392
437
  # Compute the 3D linear force in C[W] frame.
393
- CW_f_Ci, _ = run_optimization(
438
+ solution, _ = run_optimization(
394
439
  init_params=init_params,
395
- A=A,
396
- b=b,
397
- maxiter=maxiter,
398
- opt=optax.lbfgs(**solver_options),
399
440
  fun=objective,
441
+ opt=optax.lbfgs(**solver_options),
400
442
  tol=tol,
443
+ maxiter=maxiter,
401
444
  )
402
445
 
403
- CW_f_Ci = CW_f_Ci.reshape((-1, 3))
404
-
405
- def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
406
- W_Xf_CW = Adjoint.from_transform(
407
- W_H_C.at[0:3, 0:3].set(jnp.eye(3)),
408
- inverse=True,
409
- ).T
410
- return W_Xf_CW @ jnp.hstack([CW_fl, jnp.zeros(3)])
411
-
412
- W_f_C = jax.vmap(mixed_to_inertial)(W_H_C, CW_f_Ci)
446
+ # Reshape the optimized solution to be a matrix of 3D contact forces.
447
+ CW_fl_C = solution.reshape(-1, 3)
448
+
449
+ # Convert the contact forces from mixed to inertial-fixed representation.
450
+ W_f_C = jax.vmap(
451
+ lambda CW_fl_C, W_H_C: (
452
+ ModelDataWithVelocityRepresentation.other_representation_to_inertial(
453
+ array=jnp.zeros(6).at[0:3].set(CW_fl_C),
454
+ transform=W_H_C,
455
+ other_representation=VelRepr.Mixed,
456
+ is_force=True,
457
+ )
458
+ ),
459
+ )(CW_fl_C, W_H_C)
413
460
 
414
- return W_f_C, (None,)
461
+ return W_f_C, ()
415
462
 
416
463
  @staticmethod
417
464
  def _regularizers(
@@ -433,13 +480,28 @@ class RelaxedRigidContacts(ContactModel):
433
480
  A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping.
434
481
  """
435
482
 
436
- Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ, *_ = jax_dataclasses.astuple(
437
- parameters
483
+ # Extract the parameters of the contact model.
484
+ Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ = (
485
+ getattr(parameters, field)
486
+ for field in (
487
+ "time_constant",
488
+ "damping_coefficient",
489
+ "d_min",
490
+ "d_max",
491
+ "width",
492
+ "midpoint",
493
+ "power",
494
+ "stiffness",
495
+ "damping",
496
+ "mu",
497
+ )
438
498
  )
439
499
 
440
- def _imp_aref(
441
- penetration: jtp.Array,
442
- velocity: jtp.Array,
500
+ # Compute the 6D inertia matrices of all links.
501
+ M_L = js.model.link_spatial_inertia_matrices(model=model)
502
+
503
+ def imp_aref(
504
+ penetration: jtp.Array, velocity: jtp.Array
443
505
  ) -> tuple[jtp.Array, jtp.Array]:
444
506
  """
445
507
  Calculates impedance and offset acceleration in constraint frame.
@@ -474,7 +536,7 @@ class RelaxedRigidContacts(ContactModel):
474
536
 
475
537
  return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)
476
538
 
477
- def _compute_row(
539
+ def compute_row(
478
540
  *,
479
541
  link_idx: jtp.Float,
480
542
  penetration: jtp.Array,
@@ -482,7 +544,7 @@ class RelaxedRigidContacts(ContactModel):
482
544
  ) -> tuple[jtp.Array, jtp.Array]:
483
545
 
484
546
  # Compute the reference acceleration.
485
- ξ, a_ref, K, D = _imp_aref(
547
+ ξ, a_ref, K, D = imp_aref(
486
548
  penetration=penetration,
487
549
  velocity=velocity,
488
550
  )
@@ -496,12 +558,10 @@ class RelaxedRigidContacts(ContactModel):
496
558
 
497
559
  return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K, D))
498
560
 
499
- M_L = js.model.link_spatial_inertia_matrices(model=model)
500
-
501
561
  a_ref, R, K, D = jax.tree.map(
502
- jnp.concatenate,
503
- (
504
- *jax.vmap(_compute_row)(
562
+ f=jnp.concatenate,
563
+ tree=(
564
+ *jax.vmap(compute_row)(
505
565
  link_idx=jnp.array(
506
566
  model.kin_dyn_parameters.contact_parameters.body
507
567
  ),
@@ -510,4 +570,5 @@ class RelaxedRigidContacts(ContactModel):
510
570
  ),
511
571
  ),
512
572
  )
573
+
513
574
  return a_ref, jnp.diag(R), K, D
@@ -13,6 +13,7 @@ from jaxsim import logging
13
13
  from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
14
14
  from jaxsim.terrain import FlatTerrain, Terrain
15
15
 
16
+ from . import common
16
17
  from .common import ContactModel, ContactsParams
17
18
 
18
19
  try:
@@ -170,46 +171,6 @@ class RigidContacts(ContactModel):
170
171
  _solver_options_values=tuple(solver_options.values()),
171
172
  )
172
173
 
173
- @staticmethod
174
- def detect_contacts(
175
- W_p_C: jtp.ArrayLike,
176
- terrain_height: jtp.ArrayLike,
177
- ) -> tuple[jtp.Vector, jtp.Vector]:
178
- """
179
- Detect contacts between the collidable points and the terrain.
180
-
181
- Args:
182
- W_p_C: The position of the collidable points.
183
- terrain_height: The height of the terrain at the collidable point position.
184
-
185
- Returns:
186
- A tuple containing the activation state of the collidable points
187
- and the contact penetration depth h.
188
- """
189
-
190
- # TODO: reduce code duplication with js.contact.in_contact
191
- def detect_contact(
192
- W_p_C: jtp.ArrayLike,
193
- terrain_height: jtp.FloatLike,
194
- ) -> tuple[jtp.Bool, jtp.Float]:
195
- """
196
- Detect contacts between the collidable points and the terrain.
197
- """
198
-
199
- # Unpack the position of the collidable point.
200
- _, _, pz = W_p_C.squeeze()
201
-
202
- inactive = pz > terrain_height
203
-
204
- # Compute contact penetration depth
205
- h = jnp.maximum(0.0, terrain_height - pz)
206
-
207
- return inactive, h
208
-
209
- inactive_collidable_points, h = jax.vmap(detect_contact)(W_p_C, terrain_height)
210
-
211
- return inactive_collidable_points, h
212
-
213
174
  @staticmethod
214
175
  def compute_impact_velocity(
215
176
  inactive_collidable_points: jtp.ArrayLike,
@@ -281,7 +242,7 @@ class RigidContacts(ContactModel):
281
242
  *,
282
243
  link_forces: jtp.MatrixLike | None = None,
283
244
  joint_force_references: jtp.VectorLike | None = None,
284
- ) -> tuple[jtp.Vector, tuple[Any, ...]]:
245
+ ) -> tuple[jtp.Matrix, tuple]:
285
246
  """
286
247
  Compute the contact forces.
287
248
 
@@ -295,36 +256,41 @@ class RigidContacts(ContactModel):
295
256
  Optional `(n_joints,)` vector of joint forces.
296
257
 
297
258
  Returns:
298
- A tuple containing the contact forces.
259
+ A tuple containing as first element the computed contact forces.
299
260
  """
300
261
 
301
262
  # Initialize the model and data this contact model is operating on.
302
263
  # This will raise an exception if either the contact model or the
303
264
  # contact parameters are not compatible.
304
265
  model, data = self.initialize_model_and_data(model=model, data=data)
266
+ assert isinstance(data.contacts_params, RigidContactsParams)
305
267
 
306
- # Import qpax just in this method
268
+ # Import qpax privately just in this method.
307
269
  import qpax
308
270
 
309
- link_forces = (
310
- link_forces
271
+ link_forces = jnp.atleast_2d(
272
+ jnp.array(link_forces, dtype=float).squeeze()
311
273
  if link_forces is not None
312
274
  else jnp.zeros((model.number_of_links(), 6))
313
275
  )
314
276
 
315
- joint_force_references = (
316
- joint_force_references
277
+ joint_force_references = jnp.atleast_1d(
278
+ jnp.array(joint_force_references, dtype=float).squeeze()
317
279
  if joint_force_references is not None
318
280
  else jnp.zeros((model.number_of_joints(),))
319
281
  )
320
282
 
321
- # Compute kin-dyn quantities used in the contact model
283
+ # Compute kin-dyn quantities used in the contact model.
322
284
  with data.switch_velocity_representation(VelRepr.Mixed):
285
+
286
+ BW_ν = data.generalized_velocity()
287
+
323
288
  M = js.model.free_floating_mass_matrix(model=model, data=data)
289
+
324
290
  J_WC = js.contact.jacobian(model=model, data=data)
291
+ J̇_WC = js.contact.jacobian_derivative(model=model, data=data)
292
+
325
293
  W_H_C = js.contact.transforms(model=model, data=data)
326
- J̇_WC_BW = js.contact.jacobian_derivative(model=model, data=data)
327
- BW_ν = data.generalized_velocity()
328
294
 
329
295
  # Compute the position and linear velocities (mixed representation) of
330
296
  # all collidable points belonging to the robot.
@@ -332,23 +298,16 @@ class RigidContacts(ContactModel):
332
298
  model=model, data=data
333
299
  )
334
300
 
335
- terrain_height = jax.vmap(self.terrain.height)(position[:, 0], position[:, 1])
336
- n_collidable_points = model.kin_dyn_parameters.contact_parameters.point.shape[0]
301
+ # Get the number of collidable points.
302
+ n_collidable_points = len(model.kin_dyn_parameters.contact_parameters.body)
337
303
 
338
- # Compute the activation state of the collidable points
339
- inactive_collidable_points, h = RigidContacts.detect_contacts(
340
- W_p_C=position,
341
- terrain_height=terrain_height,
342
- )
343
-
344
- # Compute the Delassus matrix.
345
- delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)
346
-
347
- # Add regularization for better numerical conditioning.
348
- delassus_matrix = delassus_matrix + self.regularization_delassus * jnp.eye(
349
- delassus_matrix.shape[0]
304
+ # Compute the penetration depth and velocity of the collidable points.
305
+ # Note that this function considers the penetration in the normal direction.
306
+ δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
307
+ position, velocity, model.terrain
350
308
  )
351
309
 
310
+ # Build a references object to simplify converting link forces.
352
311
  references = js.references.JaxSimModelReferences.build(
353
312
  model=model,
354
313
  data=data,
@@ -357,10 +316,12 @@ class RigidContacts(ContactModel):
357
316
  joint_force_references=joint_force_references,
358
317
  )
359
318
 
319
+ # Compute the generalized free acceleration.
360
320
  with (
361
321
  references.switch_velocity_representation(VelRepr.Mixed),
362
322
  data.switch_velocity_representation(VelRepr.Mixed),
363
323
  ):
324
+
364
325
  BW_ν̇_free = jnp.hstack(
365
326
  js.ode.system_acceleration(
366
327
  model=model,
@@ -372,64 +333,74 @@ class RigidContacts(ContactModel):
372
333
  )
373
334
  )
374
335
 
336
+ # Compute the free linear acceleration of the collidable points.
337
+ # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C.
375
338
  free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
376
339
  BW_nu=BW_ν,
377
340
  BW_nu_dot=BW_ν̇_free,
378
341
  CW_J_WC_BW=J_WC,
379
- CW_J_dot_WC_BW=J̇_WC_BW,
342
+ CW_J_dot_WC_BW=J̇_WC,
380
343
  ).flatten()
381
344
 
382
- # Compute stabilization term
383
- ḣ = velocity[:, 2].squeeze()
345
+ # Compute stabilization term.
384
346
  baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term(
385
- inactive_collidable_points=inactive_collidable_points,
386
- h=h,
387
- ḣ=ḣ,
388
- K=self.parameters.K,
389
- D=self.parameters.D,
347
+ inactive_collidable_points=(δ <= 0),
348
+ δ=δ,
349
+ δ_dot=δ_dot,
350
+ n=n̂,
351
+ K=data.contacts_params.K,
352
+ D=data.contacts_params.D,
390
353
  ).flatten()
391
354
 
392
- free_contact_acc -= baumgarte_term
355
+ # Compute the Delassus matrix.
356
+ delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)
357
+
358
+ # Initialize regularization term of the Delassus matrix for
359
+ # better numerical conditioning.
360
+ Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0])
361
+
362
+ # Construct the quadratic cost function.
363
+ Q = delassus_matrix + Iε
364
+ q = free_contact_acc - baumgarte_term
393
365
 
394
- # Setup optimization problem
395
- Q = delassus_matrix
396
- q = free_contact_acc
366
+ # Construct the inequality constraints.
397
367
  G = RigidContacts._compute_ineq_constraint_matrix(
398
- inactive_collidable_points=inactive_collidable_points, mu=self.parameters.mu
368
+ inactive_collidable_points=(δ <= 0), mu=data.contacts_params.mu
399
369
  )
400
370
  h_bounds = RigidContacts._compute_ineq_bounds(
401
371
  n_collidable_points=n_collidable_points
402
372
  )
373
+
374
+ # Construct the equality constraints.
403
375
  A = jnp.zeros((0, 3 * n_collidable_points))
404
376
  b = jnp.zeros((0,))
405
377
 
406
- # Solve the optimization problem
407
- solution, *_ = qpax.solve_qp(
378
+ # Solve the following optimization problem with qpax:
379
+ #
380
+ # min_{x} 0.5 x⊤ Q x + q⊤ x
381
+ #
382
+ # s.t. A x = b
383
+ # G x ≤ h
384
+ #
385
+ # TODO: add possibility to notify if the QP problem did not converge.
386
+ solution, _, _, _, converged, _ = qpax.solve_qp( # noqa: F841
408
387
  Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options
409
388
  )
410
389
 
411
- f_C_lin = solution.reshape(-1, 3)
412
-
413
- # Transform linear contact forces to 6D
414
- CW_f_C = jnp.hstack(
415
- (
416
- f_C_lin,
417
- jnp.zeros((f_C_lin.shape[0], 3)),
418
- )
419
- )
390
+ # Reshape the optimized solution to be a matrix of 3D contact forces.
391
+ CW_fl_C = solution.reshape(-1, 3)
420
392
 
421
- # Transform the contact forces to inertial-fixed representation
393
+ # Convert the contact forces from mixed to inertial-fixed representation.
422
394
  W_f_C = jax.vmap(
423
- lambda CW_f_C, W_H_C: ModelDataWithVelocityRepresentation.other_representation_to_inertial(
424
- array=CW_f_C,
425
- transform=W_H_C,
426
- other_representation=VelRepr.Mixed,
427
- is_force=True,
395
+ lambda CW_fl_C, W_H_C: (
396
+ ModelDataWithVelocityRepresentation.other_representation_to_inertial(
397
+ array=jnp.zeros(6).at[0:3].set(CW_fl_C),
398
+ transform=W_H_C,
399
+ other_representation=VelRepr.Mixed,
400
+ is_force=True,
401
+ )
428
402
  ),
429
- )(
430
- CW_f_C,
431
- W_H_C,
432
- )
403
+ )(CW_fl_C, W_H_C)
433
404
 
434
405
  return W_f_C, ()
435
406
 
@@ -438,6 +409,7 @@ class RigidContacts(ContactModel):
438
409
  M: jtp.MatrixLike,
439
410
  J_WC: jtp.MatrixLike,
440
411
  ) -> jtp.Matrix:
412
+
441
413
  sl = jnp.s_[:, 0:3, :]
442
414
  J_WC_lin = jnp.vstack(J_WC[sl])
443
415
 
@@ -448,6 +420,7 @@ class RigidContacts(ContactModel):
448
420
  def _compute_ineq_constraint_matrix(
449
421
  inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
450
422
  ) -> jtp.Matrix:
423
+
451
424
  def compute_G_single_point(mu: float, c: float) -> jtp.Matrix:
452
425
  """
453
426
  Compute the inequality constraint matrix for a single collidable point
@@ -475,6 +448,7 @@ class RigidContacts(ContactModel):
475
448
 
476
449
  @staticmethod
477
450
  def _compute_ineq_bounds(n_collidable_points: jtp.FloatLike) -> jtp.Vector:
451
+
478
452
  n_constraints = 6 * n_collidable_points
479
453
  return jnp.zeros(shape=(n_constraints,))
480
454
 
@@ -485,45 +459,50 @@ class RigidContacts(ContactModel):
485
459
  CW_J_WC_BW: jtp.MatrixLike,
486
460
  CW_J_dot_WC_BW: jtp.MatrixLike,
487
461
  ) -> jtp.Matrix:
488
- CW_J̇_WC_BW = CW_J_dot_WC_BW
462
+
489
463
  BW_ν = BW_nu
490
464
  BW_ν̇ = BW_nu_dot
465
+ CW_J̇_WC_BW = CW_J_dot_WC_BW
491
466
 
467
+ # Compute the linear acceleration of the collidable points.
468
+ # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C.
492
469
  CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇
493
- CW_a_WC = CW_a_WC.reshape(-1, 6)
494
470
 
471
+ CW_a_WC = CW_a_WC.reshape(-1, 6)
495
472
  return CW_a_WC[:, 0:3].squeeze()
496
473
 
497
474
  @staticmethod
498
475
  def _compute_baumgarte_stabilization_term(
499
476
  inactive_collidable_points: jtp.ArrayLike,
500
- h: jtp.ArrayLike,
501
- ḣ: jtp.ArrayLike,
477
+ δ: jtp.ArrayLike,
478
+ δ_dot: jtp.ArrayLike,
479
+ n: jtp.ArrayLike,
502
480
  K: jtp.FloatLike,
503
481
  D: jtp.FloatLike,
504
482
  ) -> jtp.Array:
505
- def baumgarte_stabilization(
483
+
484
+ def baumgarte_stabilization_of_single_point(
506
485
  inactive: jtp.BoolLike,
507
- h: jtp.FloatLike,
508
- ḣ: jtp.FloatLike,
486
+ δ: jtp.FloatLike,
487
+ δ_dot: jtp.FloatLike,
488
+ n: jtp.ArrayLike,
509
489
  k_baumgarte: jtp.FloatLike,
510
490
  d_baumgarte: jtp.FloatLike,
511
491
  ) -> jtp.Array:
492
+
512
493
  baumgarte_term = jax.lax.cond(
513
494
  inactive,
514
- lambda h, ḣ, K, D: jnp.zeros(shape=(3,)),
515
- lambda h, ḣ, K, D: jnp.zeros(shape=(3,)).at[2].set(K * h + D * ),
516
- *(
517
- h,
518
- ḣ,
519
- k_baumgarte,
520
- d_baumgarte,
521
- ),
495
+ lambda δ, δ_dot, n, K, D: jnp.zeros(3),
496
+ # This is equivalent to: K*(pT - p)⋅n̂ + D*(0 - v)⋅n̂,
497
+ # where pT is the point on the terrain surface vertical to p.
498
+ lambda δ, δ_dot, n, K, D: (K * δ + D * δ_dot) * n,
499
+ *(δ, δ_dot, n, k_baumgarte, d_baumgarte),
522
500
  )
501
+
523
502
  return baumgarte_term
524
503
 
525
504
  baumgarte_term = jax.vmap(
526
- baumgarte_stabilization, in_axes=(0, 0, 0, None, None)
527
- )(inactive_collidable_points, h, ḣ, K, D)
505
+ baumgarte_stabilization_of_single_point, in_axes=(0, 0, 0, 0, None, None)
506
+ )(inactive_collidable_points, δ, δ_dot, n, K, D)
528
507
 
529
508
  return baumgarte_term
@@ -14,7 +14,7 @@ from jaxsim import logging
14
14
  from jaxsim.math import StandardGravity
15
15
  from jaxsim.terrain import FlatTerrain, Terrain
16
16
 
17
- from .common import ContactModel, ContactsParams
17
+ from . import common
18
18
 
19
19
  try:
20
20
  from typing import Self
@@ -23,7 +23,7 @@ except ImportError:
23
23
 
24
24
 
25
25
  @jax_dataclasses.pytree_dataclass
26
- class SoftContactsParams(ContactsParams):
26
+ class SoftContactsParams(common.ContactsParams):
27
27
  """Parameters of the soft contacts model."""
28
28
 
29
29
  K: jtp.Float = dataclasses.field(
@@ -161,7 +161,9 @@ class SoftContactsParams(ContactsParams):
161
161
  f_average = m * g / number_of_active_collidable_points_steady_state
162
162
 
163
163
  # Compute the stiffness to get the desired steady-state penetration.
164
- K = f_average / jnp.power(δ_max, 3 / 2)
164
+ # Note that this is dependent on the non-linear exponent used in
165
+ # the damping term of the Hunt/Crossley model.
166
+ K = f_average / jnp.power(δ_max, 1 + p)
165
167
 
166
168
  # Compute the damping using the damping ratio.
167
169
  critical_damping = 2 * jnp.sqrt(K * m)
@@ -189,7 +191,7 @@ class SoftContactsParams(ContactsParams):
189
191
 
190
192
 
191
193
  @jax_dataclasses.pytree_dataclass
192
- class SoftContacts(ContactModel):
194
+ class SoftContacts(common.ContactModel):
193
195
  """Soft contacts model."""
194
196
 
195
197
  parameters: SoftContactsParams = dataclasses.field(
@@ -277,9 +279,7 @@ class SoftContacts(ContactModel):
277
279
  μ = mu
278
280
 
279
281
  # Compute the penetration depth, its rate, and the considered terrain normal.
280
- δ, δ̇, n̂ = SoftContacts.compute_penetration_data(
281
- p=W_p_C, v=W_ṗ_C, terrain=terrain
282
- )
282
+ δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain)
283
283
 
284
284
  # There are few operations like computing the norm of a vector with zero length
285
285
  # or computing the square root of zero that are problematic in an AD context.
@@ -423,7 +423,18 @@ class SoftContacts(ContactModel):
423
423
  self,
424
424
  model: js.model.JaxSimModel,
425
425
  data: js.data.JaxSimModelData,
426
- ) -> tuple[jtp.Vector, tuple[jtp.Vector]]:
426
+ ) -> tuple[jtp.Matrix, tuple[jtp.Matrix]]:
427
+ """
428
+ Compute the contact forces.
429
+
430
+ Args:
431
+ model: The model to consider.
432
+ data: The data of the considered model.
433
+
434
+ Returns:
435
+ A tuple containing as first element the computed contact forces, and as
436
+ second element the derivative of the material deformation.
437
+ """
427
438
 
428
439
  # Initialize the model and data this contact model is operating on.
429
440
  # This will raise an exception if either the contact model or the
@@ -444,36 +455,9 @@ class SoftContacts(ContactModel):
444
455
  position=p,
445
456
  velocity=v,
446
457
  tangential_deformation=m,
447
- parameters=self.parameters,
448
- terrain=self.terrain,
458
+ parameters=data.contacts_params,
459
+ terrain=model.terrain,
449
460
  )
450
461
  )(W_p_C, W_ṗ_C, m)
451
462
 
452
463
  return W_f, (ṁ,)
453
-
454
- @staticmethod
455
- @jax.jit
456
- def compute_penetration_data(
457
- p: jtp.VectorLike,
458
- v: jtp.VectorLike,
459
- terrain: jaxsim.terrain.Terrain,
460
- ) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
461
-
462
- # Pre-process the position and the linear velocity of the collidable point.
463
- W_ṗ_C = jnp.array(v).squeeze()
464
- px, py, pz = jnp.array(p).squeeze()
465
-
466
- # Compute the terrain normal and the contact depth.
467
- n̂ = terrain.normal(x=px, y=py).squeeze()
468
- h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz])
469
-
470
- # Compute the penetration depth normal to the terrain.
471
- δ = jnp.maximum(0.0, jnp.dot(h, n̂))
472
-
473
- # Compute the penetration normal velocity.
474
- δ̇ = -jnp.dot(W_ṗ_C, n̂)
475
-
476
- # Enforce the penetration rate to be zero when the penetration depth is zero.
477
- δ̇ = jnp.where(δ > 0, δ̇, 0.0)
478
-
479
- return δ, δ̇, n̂
@@ -195,7 +195,7 @@ class ViscoElasticContacts(common.ContactModel):
195
195
  default_factory=FlatTerrain
196
196
  )
197
197
 
198
- max_squarings: jax_dataclasses.Static[int] = 25
198
+ max_squarings: jax_dataclasses.Static[int] = dataclasses.field(default=25)
199
199
 
200
200
  @classmethod
201
201
  def build(
@@ -239,7 +239,7 @@ class ViscoElasticContacts(common.ContactModel):
239
239
  parameters=parameters,
240
240
  terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
241
241
  max_squarings=int(
242
- max_squarings or cls.__dataclass_fields__["max_squarings"].default()
242
+ max_squarings or cls.__dataclass_fields__["max_squarings"].default
243
243
  ),
244
244
  )
245
245
 
@@ -266,7 +266,7 @@ class ViscoElasticContacts(common.ContactModel):
266
266
  dt: jtp.FloatLike | None = None,
267
267
  link_forces: jtp.MatrixLike | None = None,
268
268
  joint_force_references: jtp.VectorLike | None = None,
269
- ) -> tuple[jtp.Vector, tuple[Any, ...]]:
269
+ ) -> tuple[jtp.Matrix, tuple[jtp.Matrix, jtp.Matrix]]:
270
270
  """
271
271
  Compute the contact forces.
272
272
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev186
3
+ Version: 0.4.3.dev200
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -1,18 +1,18 @@
1
1
  jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
2
- jaxsim/_version.py,sha256=K0Qt3IiihQ28Vnsxow16UNJNyDGyw4M94790Je5aXw8,428
2
+ jaxsim/_version.py,sha256=WDziMJEeSmuE81cozOtxmazlb4qAX6VPTrKOR0f3akg,428
3
3
  jaxsim/exceptions.py,sha256=vSoScaRD4nvh6jltgK9Ry5pKnE0O5hb4_yI_pk_fvR8,2175
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
6
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
7
7
  jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
8
8
  jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
9
- jaxsim/api/contact.py,sha256=2qBIStXWxIJTrW3Eyx6UPQcTXDXalOyasgD7Pk1_v1E,24486
10
- jaxsim/api/data.py,sha256=sfu_uJtYRQIf_sg9IWzR95McRdZgtHwArAuzF6KD-1A,28939
9
+ jaxsim/api/contact.py,sha256=Egc62310ljn5goXlswwJYSB-LyW6M5gmPoT_a3mkd7U,25812
10
+ jaxsim/api/data.py,sha256=gQX6hfEaw0ooJYvpr5f8UvEJwqhtflEK_NHWn9XgTZY,28935
11
11
  jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
12
12
  jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=thJbz9XhpXgom23S6MXX2ugxGoAD-k947ZMAHDisy2w,29620
14
14
  jaxsim/api/link.py,sha256=LAA6ZMQXkWomXeptURBtc7z3_xDZ2BBnBMhVrohh0bE,18621
15
- jaxsim/api/model.py,sha256=-Au3Xdm3TJLyKN_r06pr9G99zmzjNhDI4KZz4xox7iE,69783
15
+ jaxsim/api/model.py,sha256=s2i4obxMjZ_XntJgT0dEV57LCo0GIC7VppUnxsqC1fc,69704
16
16
  jaxsim/api/ode.py,sha256=J_WuaoPl3ZY-yvTrCQun-rQoIAv_duynSXAGxqx93sg,14211
17
17
  jaxsim/api/ode_data.py,sha256=1SD-x-lYk_YSEnVpxTLd69uOKC0mFUj44ZqpSmEDOxw,20190
18
18
  jaxsim/api/references.py,sha256=fW77LitZ8DYgT6ZmUInJfm5luBV1mTcqcNRiC_i79og,20862
@@ -53,20 +53,20 @@ jaxsim/rbda/forward_kinematics.py,sha256=2GmEoWsrioVl_SAbKRKfhOLz57pY4aR81PKRdul
53
53
  jaxsim/rbda/jacobian.py,sha256=p0EV_8cLzLVV-93VKznT7VPuRj8W7h7rQWkPlWJXfCA,11023
54
54
  jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
55
55
  jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
56
- jaxsim/rbda/contacts/__init__.py,sha256=8-JvjjuCkGf-ORMNnTe641zfamagmiFqZWzNO3cneWE,362
57
- jaxsim/rbda/contacts/common.py,sha256=_yrxTM16Je9ck5aM95ndk8Kwu_oijxG9Jaf1jEjHEYw,4332
58
- jaxsim/rbda/contacts/relaxed_rigid.py,sha256=Ob5LdKe3D7tGlIdT4LamJ6_F0j5pzUmWNYoWqy8Di98,17169
59
- jaxsim/rbda/contacts/rigid.py,sha256=1TTiGXSOipO8l5FDTtxqRNo1ArCNtDg-Yr3olPgBLGs,17588
60
- jaxsim/rbda/contacts/soft.py,sha256=TMCUDtFmNIae04LCla57iXMjdt9F5qTFjYEnP5NdLFg,16809
61
- jaxsim/rbda/contacts/visco_elastic.py,sha256=tfcH4_JFox6_6PyR29kLlqc8pYN8WCslYnClxV7TnSU,39780
56
+ jaxsim/rbda/contacts/__init__.py,sha256=L5MM-2pv76YPGzxExdz2EErgGBATuAjYnNHlq5QOySs,503
57
+ jaxsim/rbda/contacts/common.py,sha256=iywCQtesrnrwywRQv8cjyot2bG11dT_iONyF8OJztIA,5798
58
+ jaxsim/rbda/contacts/relaxed_rigid.py,sha256=TR81tJ4ipcpvPnwlfkpyNDhvWizpEG542SFVu_CwHRU,19614
59
+ jaxsim/rbda/contacts/rigid.py,sha256=3aDPFrIm2_QpKKRpTqJJk8qBK-W63gq7Arc8WDVAcHc,17382
60
+ jaxsim/rbda/contacts/soft.py,sha256=6eFgV2hJK793RZfoY8oSqw-zC1UqFldaE0hfGHELnmU,16325
61
+ jaxsim/rbda/contacts/visco_elastic.py,sha256=wATvBhLrV-7IyVLJhW7OaMg_HDAmczl_8MnYm3wuqSc,39819
62
62
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
63
63
  jaxsim/terrain/terrain.py,sha256=K91HEzPqTSyNrc_j1KfAAEF_5oDeuk_-jnnZGrcMEcY,5015
64
64
  jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
65
65
  jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
66
66
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
67
67
  jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
68
- jaxsim-0.4.3.dev186.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
- jaxsim-0.4.3.dev186.dist-info/METADATA,sha256=YYb7FonjyOeyop0Ni-0dB0ijfk215c0-PZo6k9v6JAo,17276
70
- jaxsim-0.4.3.dev186.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
71
- jaxsim-0.4.3.dev186.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
- jaxsim-0.4.3.dev186.dist-info/RECORD,,
68
+ jaxsim-0.4.3.dev200.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
+ jaxsim-0.4.3.dev200.dist-info/METADATA,sha256=NUJ6GXIFFK-y9-p-M2OTTdI3g7utYHa3Lsg2VXSXtoI,17276
70
+ jaxsim-0.4.3.dev200.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
71
+ jaxsim-0.4.3.dev200.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
+ jaxsim-0.4.3.dev200.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.1.0)
2
+ Generator: setuptools (75.2.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5