jaxsim 0.6.2.dev182__py3-none-any.whl → 0.6.2.dev225__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
- import functools
5
4
  from collections.abc import Callable
6
5
  from typing import Any
7
6
 
@@ -11,12 +10,10 @@ import jax_dataclasses
11
10
  import optax
12
11
 
13
12
  import jaxsim.api as js
14
- import jaxsim.rbda.contacts
15
13
  import jaxsim.typing as jtp
16
14
  from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
17
- from jaxsim.terrain.terrain import Terrain
18
15
 
19
- from . import common
16
+ from . import common, soft
20
17
 
21
18
  try:
22
19
  from typing import Self
@@ -64,12 +61,12 @@ class RelaxedRigidContactsParams(common.ContactsParams):
64
61
  )
65
62
 
66
63
  # Stiffness
67
- stiffness: jtp.Float = dataclasses.field(
64
+ K: jtp.Float = dataclasses.field(
68
65
  default_factory=lambda: jnp.array(0.0, dtype=float)
69
66
  )
70
67
 
71
68
  # Damping
72
- damping: jtp.Float = dataclasses.field(
69
+ D: jtp.Float = dataclasses.field(
73
70
  default_factory=lambda: jnp.array(0.0, dtype=float)
74
71
  )
75
72
 
@@ -90,13 +87,16 @@ class RelaxedRigidContactsParams(common.ContactsParams):
90
87
  HashedNumpyArray(self.width),
91
88
  HashedNumpyArray(self.midpoint),
92
89
  HashedNumpyArray(self.power),
93
- HashedNumpyArray(self.stiffness),
94
- HashedNumpyArray(self.damping),
90
+ HashedNumpyArray(self.K),
91
+ HashedNumpyArray(self.D),
95
92
  HashedNumpyArray(self.mu),
96
93
  )
97
94
  )
98
95
 
99
96
  def __eq__(self, other: RelaxedRigidContactsParams) -> bool:
97
+ if not isinstance(other, RelaxedRigidContactsParams):
98
+ return False
99
+
100
100
  return hash(self) == hash(other)
101
101
 
102
102
  @classmethod
@@ -110,9 +110,10 @@ class RelaxedRigidContactsParams(common.ContactsParams):
110
110
  width: jtp.FloatLike | None = None,
111
111
  midpoint: jtp.FloatLike | None = None,
112
112
  power: jtp.FloatLike | None = None,
113
- stiffness: jtp.FloatLike | None = None,
114
- damping: jtp.FloatLike | None = None,
113
+ K: jtp.FloatLike | None = None,
114
+ D: jtp.FloatLike | None = None,
115
115
  mu: jtp.FloatLike | None = None,
116
+ **kwargs,
116
117
  ) -> Self:
117
118
  """Create a `RelaxedRigidContactsParams` instance."""
118
119
 
@@ -151,13 +152,11 @@ class RelaxedRigidContactsParams(common.ContactsParams):
151
152
  power=jnp.array(
152
153
  power if power is not None else default("power"), dtype=float
153
154
  ),
154
- stiffness=jnp.array(
155
- stiffness if stiffness is not None else default("stiffness"),
155
+ K=jnp.array(
156
+ K if K is not None else default("K"),
156
157
  dtype=float,
157
158
  ),
158
- damping=jnp.array(
159
- damping if damping is not None else default("damping"), dtype=float
160
- ),
159
+ D=jnp.array(D if D is not None else default("D"), dtype=float),
161
160
  mu=jnp.array(mu if mu is not None else default("mu"), dtype=float),
162
161
  )
163
162
 
@@ -243,6 +242,37 @@ class RelaxedRigidContacts(common.ContactModel):
243
242
  **kwargs,
244
243
  )
245
244
 
245
+ def update_contact_state(
246
+ self: type[Self], old_contact_state: dict[str, jtp.Array]
247
+ ) -> dict[str, jtp.Array]:
248
+ """
249
+ Update the contact state.
250
+
251
+ Args:
252
+ old_contact_state: The old contact state.
253
+
254
+ Returns:
255
+ The updated contact state.
256
+ """
257
+
258
+ return {}
259
+
260
+ def update_velocity_after_impact(
261
+ self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData
262
+ ) -> js.data.JaxSimModelData:
263
+ """
264
+ Update the velocity after an impact.
265
+
266
+ Args:
267
+ model: The robot model considered by the contact model.
268
+ data: The data of the considered model.
269
+
270
+ Returns:
271
+ The updated data of the considered model.
272
+ """
273
+
274
+ return data
275
+
246
276
  @jax.jit
247
277
  def compute_contact_forces(
248
278
  self,
@@ -314,11 +344,11 @@ class RelaxedRigidContacts(common.ContactModel):
314
344
  BW_ν = data.generalized_velocity
315
345
 
316
346
  BW_ν̇_free = jnp.hstack(
317
- js.ode.system_acceleration(
347
+ js.model.forward_dynamics_aba(
318
348
  model=model,
319
349
  data=data,
320
350
  link_forces=references.link_forces(model=model, data=data),
321
- joint_torques=references.joint_force_references(model=model),
351
+ joint_forces=references.joint_force_references(model=model),
322
352
  )
323
353
  )
324
354
 
@@ -341,7 +371,7 @@ class RelaxedRigidContacts(common.ContactModel):
341
371
  model=model,
342
372
  position_constraint=position_constraint,
343
373
  velocity_constraint=velocity,
344
- parameters=model.contacts_params,
374
+ parameters=model.contact_params,
345
375
  )
346
376
 
347
377
  # Compute the Delassus matrix and the free mixed linear acceleration of
@@ -425,7 +455,7 @@ class RelaxedRigidContacts(common.ContactModel):
425
455
 
426
456
  # Initialize the optimized forces with a linear Hunt/Crossley model.
427
457
  init_params = jax.vmap(
428
- lambda p, v: self._hunt_crossley_contact_model(
458
+ lambda p, v: soft.SoftContacts.hunt_crossley_contact_model(
429
459
  position=p,
430
460
  velocity=v,
431
461
  terrain=model.terrain,
@@ -460,16 +490,12 @@ class RelaxedRigidContacts(common.ContactModel):
460
490
  CW_fl_C = solution.reshape(-1, 3)
461
491
 
462
492
  # Convert the contact forces from mixed to inertial-fixed representation.
463
- W_f_C = jax.vmap(
464
- lambda CW_fl_C, W_H_C: (
465
- ModelDataWithVelocityRepresentation.other_representation_to_inertial(
466
- array=jnp.zeros(6).at[0:3].set(CW_fl_C),
467
- transform=W_H_C,
468
- other_representation=VelRepr.Mixed,
469
- is_force=True,
470
- )
471
- ),
472
- )(CW_fl_C, W_H_C)
493
+ W_f_C = ModelDataWithVelocityRepresentation.other_representation_to_inertial(
494
+ array=jnp.zeros((W_H_C.shape[0], 6)).at[:, :3].set(CW_fl_C),
495
+ transform=W_H_C,
496
+ other_representation=VelRepr.Mixed,
497
+ is_force=True,
498
+ )
473
499
 
474
500
  return W_f_C, {}
475
501
 
@@ -505,8 +531,8 @@ class RelaxedRigidContacts(common.ContactModel):
505
531
  "width",
506
532
  "midpoint",
507
533
  "power",
508
- "stiffness",
509
- "damping",
534
+ "K",
535
+ "D",
510
536
  "mu",
511
537
  )
512
538
  )
@@ -602,149 +628,3 @@ class RelaxedRigidContacts(common.ContactModel):
602
628
  )
603
629
 
604
630
  return a_ref, jnp.diag(R), K, D
605
-
606
- @staticmethod
607
- @functools.partial(jax.jit, static_argnames=("terrain",))
608
- def _hunt_crossley_contact_model(
609
- position: jtp.VectorLike,
610
- velocity: jtp.VectorLike,
611
- tangential_deformation: jtp.VectorLike,
612
- terrain: Terrain,
613
- K: jtp.FloatLike,
614
- D: jtp.FloatLike,
615
- mu: jtp.FloatLike,
616
- p: jtp.FloatLike = 0.5,
617
- q: jtp.FloatLike = 0.5,
618
- ) -> tuple[jtp.Vector, jtp.Vector]:
619
- """
620
- Compute the contact force using the Hunt/Crossley model.
621
-
622
- Args:
623
- position: The position of the collidable point.
624
- velocity: The velocity of the collidable point.
625
- tangential_deformation: The material deformation of the collidable point.
626
- terrain: The terrain model.
627
- K: The stiffness parameter.
628
- D: The damping parameter of the soft contacts model.
629
- mu: The static friction coefficient.
630
- p:
631
- The exponent p corresponding to the damping-related non-linearity
632
- of the Hunt/Crossley model.
633
- q:
634
- The exponent q corresponding to the spring-related non-linearity
635
- of the Hunt/Crossley model
636
-
637
- Returns:
638
- A tuple containing the computed contact force and the derivative of the
639
- material deformation.
640
- """
641
-
642
- # Convert the input vectors to arrays.
643
- W_p_C = jnp.array(position, dtype=float).squeeze()
644
- W_ṗ_C = jnp.array(velocity, dtype=float).squeeze()
645
- m = jnp.array(tangential_deformation, dtype=float).squeeze()
646
-
647
- # Use symbol for the static friction.
648
- μ = mu
649
-
650
- # Compute the penetration depth, its rate, and the considered terrain normal.
651
- δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain)
652
-
653
- # There are few operations like computing the norm of a vector with zero length
654
- # or computing the square root of zero that are problematic in an AD context.
655
- # To avoid these issues, we introduce a small tolerance ε to their arguments
656
- # and make sure that we do not check them against zero directly.
657
- ε = jnp.finfo(float).eps
658
-
659
- # Compute the powers of the penetration depth.
660
- # Inject ε to address AD issues in differentiating the square root when
661
- # p and q are fractional.
662
- δp = jnp.power(δ + ε, p)
663
- δq = jnp.power(δ + ε, q)
664
-
665
- # ========================
666
- # Compute the normal force
667
- # ========================
668
-
669
- # Non-linear spring-damper model (Hunt/Crossley model).
670
- # This is the force magnitude along the direction normal to the terrain.
671
- force_normal_mag = (K * δp) * δ + (D * δq) * δ̇
672
-
673
- # Depending on the magnitude of δ̇, the normal force could be negative.
674
- force_normal_mag = jnp.maximum(0.0, force_normal_mag)
675
-
676
- # Compute the 3D linear force in C[W] frame.
677
- f_normal = force_normal_mag * n̂
678
-
679
- # ============================
680
- # Compute the tangential force
681
- # ============================
682
-
683
- # Extract the tangential component of the velocity.
684
- v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂
685
-
686
- # Extract the normal and tangential components of the material deformation.
687
- m_normal = jnp.dot(m, n̂) * n̂
688
- m_tangential = m - jnp.dot(m, n̂) * n̂
689
-
690
- # Compute the tangential force in the sticking case.
691
- # Using the tangential component of the material deformation should not be
692
- # necessary if the sticking-slipping transition occurs in a terrain area
693
- # with a locally constant normal. However, this assumption is not true in
694
- # general, especially for highly uneven terrains.
695
- f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential)
696
-
697
- # Detect the contact type (sticking or slipping).
698
- # Note that if there is no contact, sticking is set to True, and this detail
699
- # is exploited in the computation of the `contact_status` variable.
700
- sticking = jnp.logical_or(
701
- δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2
702
- )
703
-
704
- # Compute the direction of the tangential force.
705
- # To prevent dividing by zero, we use a switch statement.
706
- norm = jaxsim.math.safe_norm(f_tangential)
707
- f_tangential_direction = f_tangential / (
708
- norm + jnp.finfo(float).eps * (norm == 0)
709
- )
710
-
711
- # Project the tangential force to the friction cone if slipping.
712
- f_tangential = jnp.where(
713
- sticking,
714
- f_tangential,
715
- jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
716
- )
717
-
718
- # Set the tangential force to zero if there is no contact.
719
- f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential)
720
-
721
- # =====================================
722
- # Compute the material deformation rate
723
- # =====================================
724
-
725
- # Compute the derivative of the material deformation.
726
- # Note that we included an additional relaxation of `m_normal` in the
727
- # sticking case, so that the normal deformation that could have accumulated
728
- # from a previous slipping phase can relax to zero.
729
- ṁ_no_contact = -(K / D) * m
730
- ṁ_sticking = v_tangential - (K / D) * m_normal
731
- ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq)
732
-
733
- # Compute the contact status:
734
- # 0: slipping
735
- # 1: sticking
736
- # 2: no contact
737
- contact_status = sticking.astype(int)
738
- contact_status += (δ <= 0).astype(int)
739
-
740
- # Select the right material deformation rate depending on the contact status.
741
- ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact)
742
-
743
- # ==========================================
744
- # Compute and return the final contact force
745
- # ==========================================
746
-
747
- # Sum the normal and tangential forces.
748
- CW_fl = f_normal + f_tangential
749
-
750
- return CW_fl, ṁ