jaxsim 0.6.2.dev182__py3-none-any.whl → 0.6.2.dev225__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/api/integrators.py CHANGED
@@ -22,7 +22,7 @@ def semi_implicit_euler_integration(
22
22
  with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
23
23
 
24
24
  # Compute the system acceleration
25
- W_v̇_WB, s̈ = js.ode.system_acceleration(
25
+ W_v̇_WB, s̈, contact_state_derivative = js.ode.system_acceleration(
26
26
  model=model,
27
27
  data=data,
28
28
  link_forces=link_forces,
@@ -58,12 +58,18 @@ def semi_implicit_euler_integration(
58
58
  W_p_B = data.base_position + dt * W_ṗ_B
59
59
  W_Q_B = data.base_orientation + dt * W_Q̇_B
60
60
 
61
- base_quaternion_norm = jaxsim.math.safe_norm(W_Q_B)
61
+ base_quaternion_norm = jaxsim.math.safe_norm(W_Q_B, axis=-1)
62
62
 
63
63
  W_Q_B = W_Q_B / jnp.where(base_quaternion_norm == 0, 1.0, base_quaternion_norm)
64
64
 
65
65
  s = data.joint_positions + dt * ṡ
66
66
 
67
+ integrated_contact_state = jax.tree.map(
68
+ lambda x, x_dot: x + dt * x_dot,
69
+ data.contact_state,
70
+ contact_state_derivative,
71
+ )
72
+
67
73
  # TODO: Avoid double replace, e.g. by computing cached value here
68
74
  data = dataclasses.replace(
69
75
  data,
@@ -73,6 +79,7 @@ def semi_implicit_euler_integration(
73
79
  _joint_velocities=ṡ,
74
80
  _base_linear_velocity=W_v_B[0:3],
75
81
  _base_angular_velocity=W_ω_WB,
82
+ contact_state=integrated_contact_state,
76
83
  )
77
84
 
78
85
  # Update the cached computations.
@@ -104,7 +111,7 @@ def rk4_integration(
104
111
  joint_torques=joint_torques,
105
112
  )
106
113
 
107
- base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion)
114
+ base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion, axis=-1)
108
115
  base_quaternion = data._base_quaternion / jnp.where(
109
116
  base_quaternion_norm == 0, 1.0, base_quaternion_norm
110
117
  )
@@ -116,6 +123,7 @@ def rk4_integration(
116
123
  base_linear_velocity=data._base_linear_velocity,
117
124
  base_angular_velocity=data._base_angular_velocity,
118
125
  joint_velocities=data._joint_velocities,
126
+ contact_state=data.contact_state,
119
127
  )
120
128
 
121
129
  euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt
@@ -136,14 +144,13 @@ def rk4_integration(
136
144
 
137
145
  data_tf = dataclasses.replace(
138
146
  data,
139
- **{
140
- "_base_position": x_tf["base_position"],
141
- "_base_quaternion": x_tf["base_quaternion"],
142
- "_joint_positions": x_tf["joint_positions"],
143
- "_base_linear_velocity": x_tf["base_linear_velocity"],
144
- "_base_angular_velocity": x_tf["base_angular_velocity"],
145
- "_joint_velocities": x_tf["joint_velocities"],
146
- },
147
+ _base_position=x_tf["base_position"],
148
+ _base_quaternion=x_tf["base_quaternion"],
149
+ _joint_positions=x_tf["joint_positions"],
150
+ _base_linear_velocity=x_tf["base_linear_velocity"],
151
+ _base_angular_velocity=x_tf["base_angular_velocity"],
152
+ _joint_velocities=x_tf["joint_velocities"],
153
+ contact_state=x_tf["contact_state"],
147
154
  )
148
155
 
149
156
  return data_tf.replace(model=model)
jaxsim/api/model.py CHANGED
@@ -47,13 +47,13 @@ class JaxSimModel(JaxsimDataclass):
47
47
  default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
48
48
  )
49
49
 
50
- gravity: Static[float] = jaxsim.math.STANDARD_GRAVITY
50
+ gravity: Static[float] = -jaxsim.math.STANDARD_GRAVITY
51
51
 
52
52
  contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field(
53
53
  default=None, repr=False
54
54
  )
55
55
 
56
- contacts_params: Static[jaxsim.rbda.contacts.ContactsParams] = dataclasses.field(
56
+ contact_params: Static[jaxsim.rbda.contacts.ContactsParams] = dataclasses.field(
57
57
  default=None, repr=False
58
58
  )
59
59
 
@@ -177,9 +177,9 @@ class JaxSimModel(JaxsimDataclass):
177
177
  time_step=time_step,
178
178
  terrain=terrain,
179
179
  contact_model=contact_model,
180
- contacts_params=contact_params,
180
+ contact_params=contact_params,
181
181
  integrator=integrator,
182
- gravity=gravity,
182
+ gravity=-gravity,
183
183
  )
184
184
 
185
185
  # Store the origin of the model, in case downstream logic needs it.
@@ -197,7 +197,7 @@ class JaxSimModel(JaxsimDataclass):
197
197
  time_step: jtp.FloatLike | None = None,
198
198
  terrain: jaxsim.terrain.Terrain | None = None,
199
199
  contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
200
- contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
200
+ contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,
201
201
  integrator: IntegratorType | None = None,
202
202
  gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
203
203
  ) -> JaxSimModel:
@@ -217,8 +217,8 @@ class JaxSimModel(JaxsimDataclass):
217
217
  The optional name of the model overriding the physics model name.
218
218
  contact_model:
219
219
  The contact model to consider.
220
- If not specified, a soft contacts model is used.
221
- contacts_params: The parameters of the soft contacts.
220
+ If not specified, a relaxed-constraints rigid contacts model is used.
221
+ contact_params: The parameters of the contact model.
222
222
  integrator: The integrator to use for the simulation.
223
223
  gravity: The gravity constant.
224
224
 
@@ -252,8 +252,8 @@ class JaxSimModel(JaxsimDataclass):
252
252
  else jaxsim.rbda.contacts.RelaxedRigidContacts.build()
253
253
  )
254
254
 
255
- if contacts_params is None:
256
- contacts_params = contact_model._parameters_class()
255
+ if contact_params is None:
256
+ contact_params = contact_model._parameters_class()
257
257
 
258
258
  # Consider the default integrator if not specified.
259
259
  integrator = (
@@ -271,7 +271,7 @@ class JaxSimModel(JaxsimDataclass):
271
271
  time_step=time_step,
272
272
  terrain=terrain,
273
273
  contact_model=contact_model,
274
- contacts_params=contacts_params,
274
+ contact_params=contact_params,
275
275
  integrator=integrator,
276
276
  gravity=gravity,
277
277
  # The following is wrapped as hashless since it's a static argument, and we
@@ -473,7 +473,7 @@ def reduce(
473
473
  time_step=model.time_step,
474
474
  terrain=model.terrain,
475
475
  contact_model=model.contact_model,
476
- contacts_params=model.contacts_params,
476
+ contact_params=model.contact_params,
477
477
  gravity=model.gravity,
478
478
  integrator=model.integrator,
479
479
  )
@@ -2069,14 +2069,12 @@ def step(
2069
2069
  )
2070
2070
 
2071
2071
  # Get the external forces in inertial-fixed representation.
2072
- W_f_L_external = jax.vmap(
2073
- lambda f_L, W_H_L: js.data.JaxSimModelData.other_representation_to_inertial(
2074
- f_L,
2075
- other_representation=data.velocity_representation,
2076
- transform=W_H_L,
2077
- is_force=True,
2078
- )
2079
- )(O_f_L_external, data._link_transforms)
2072
+ W_f_L_external = js.data.JaxSimModelData.other_representation_to_inertial(
2073
+ O_f_L_external,
2074
+ other_representation=data.velocity_representation,
2075
+ transform=data._link_transforms,
2076
+ is_force=True,
2077
+ )
2080
2078
 
2081
2079
  τ_references = jnp.atleast_1d(
2082
2080
  jnp.array(joint_force_references, dtype=float).squeeze()
@@ -2092,29 +2090,6 @@ def step(
2092
2090
  model, data, joint_force_references=τ_references
2093
2091
  )
2094
2092
 
2095
- # ======================
2096
- # Compute contact forces
2097
- # ======================
2098
-
2099
- W_f_L_terrain = jnp.zeros_like(W_f_L_external)
2100
-
2101
- if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
2102
-
2103
- # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
2104
- # with the terrain.
2105
- W_f_L_terrain = js.contact_model.link_contact_forces(
2106
- model=model,
2107
- data=data,
2108
- link_forces=W_f_L_external,
2109
- joint_torques=τ_total,
2110
- )
2111
-
2112
- # ==============================
2113
- # Compute the total link forces
2114
- # ==============================
2115
-
2116
- W_f_L_total = W_f_L_external + W_f_L_terrain
2117
-
2118
2093
  # =============================
2119
2094
  # Advance the simulation state
2120
2095
  # =============================
@@ -2124,7 +2099,14 @@ def step(
2124
2099
  integrator_fn = _INTEGRATORS_MAP[model.integrator]
2125
2100
 
2126
2101
  data_tf = integrator_fn(
2127
- model=model, data=data, link_forces=W_f_L_total, joint_torques=τ_total
2102
+ model=model,
2103
+ data=data,
2104
+ link_forces=W_f_L_external,
2105
+ joint_torques=τ_total,
2106
+ )
2107
+
2108
+ data_tf = model.contact_model.update_velocity_after_impact(
2109
+ model=model, data=data_tf
2128
2110
  )
2129
2111
 
2130
2112
  return data_tf
jaxsim/api/ode.py CHANGED
@@ -46,12 +46,36 @@ def system_acceleration(
46
46
  else jnp.zeros((model.number_of_links(), 6))
47
47
  ).astype(float)
48
48
 
49
+ # ======================
50
+ # Compute contact forces
51
+ # ======================
52
+
53
+ W_f_L_terrain = jnp.zeros_like(f_L)
54
+ contact_state_derivative = {}
55
+
56
+ if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
57
+
58
+ # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
59
+ # with the terrain.
60
+ W_f_L_terrain, contact_state_derivative = js.contact.link_contact_forces(
61
+ model=model,
62
+ data=data,
63
+ link_forces=f_L,
64
+ joint_torques=joint_torques,
65
+ )
66
+
67
+ W_f_L_total = f_L + W_f_L_terrain
68
+
69
+ # Update the contact state data. This is necessary only for the contact models
70
+ # that require propagation and integration of contact state.
71
+ contact_state = model.contact_model.update_contact_state(contact_state_derivative)
72
+
49
73
  # Store the link forces in a references object.
50
74
  references = js.references.JaxSimModelReferences.build(
51
75
  model=model,
52
76
  data=data,
53
77
  velocity_representation=data.velocity_representation,
54
- link_forces=f_L,
78
+ link_forces=W_f_L_total,
55
79
  )
56
80
 
57
81
  # Compute forward dynamics.
@@ -68,13 +92,12 @@ def system_acceleration(
68
92
  link_forces=references.link_forces(model=model, data=data),
69
93
  )
70
94
 
71
- return v̇_WB, s̈
95
+ return v̇_WB, s̈, contact_state
72
96
 
73
97
 
74
98
  @jax.jit
75
99
  @js.common.named_scope
76
100
  def system_position_dynamics(
77
- model: js.model.JaxSimModel,
78
101
  data: js.data.JaxSimModelData,
79
102
  baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
80
103
  ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
@@ -82,7 +105,6 @@ def system_position_dynamics(
82
105
  Compute the dynamics of the system position.
83
106
 
84
107
  Args:
85
- model: The model to consider.
86
108
  data: The data of the considered model.
87
109
  baumgarte_quaternion_regularization:
88
110
  The Baumgarte regularization coefficient for adjusting the quaternion norm.
@@ -144,7 +166,7 @@ def system_dynamics(
144
166
  """
145
167
 
146
168
  with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
147
- W_v̇_WB, s̈ = system_acceleration(
169
+ W_v̇_WB, s̈, contact_state_derivative = system_acceleration(
148
170
  model=model,
149
171
  data=data,
150
172
  joint_torques=joint_torques,
@@ -152,7 +174,6 @@ def system_dynamics(
152
174
  )
153
175
 
154
176
  W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
155
- model=model,
156
177
  data=data,
157
178
  baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
158
179
  )
@@ -164,4 +185,5 @@ def system_dynamics(
164
185
  base_linear_velocity=W_v̇_WB[0:3],
165
186
  base_angular_velocity=W_v̇_WB[3:6],
166
187
  joint_velocities=s̈,
188
+ contact_state=contact_state_derivative,
167
189
  )
jaxsim/api/references.py CHANGED
@@ -434,24 +434,17 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
434
434
  if not_tracing(forces) and not data.valid(model=model):
435
435
  raise ValueError("The provided data is not valid for the model")
436
436
 
437
- # Helper function to convert a single 6D force to the inertial representation
438
- # considering as body the link (i.e. L_f_L and LW_f_L).
439
- def convert_using_link_frame(
440
- f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike
441
- ) -> jtp.Matrix:
442
-
443
- return jax.vmap(
444
- lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial(
445
- array=f_L,
446
- other_representation=self.velocity_representation,
447
- transform=W_H_L,
448
- is_force=True,
449
- )
450
- )(f_L, W_H_L)
437
+ W_H_L = data._link_transforms
451
438
 
439
+ # Convert a single 6D force to the inertial representation
440
+ # considering as body the link (i.e. L_f_L and LW_f_L).
452
441
  # The f_L input is either L_f_L or LW_f_L, depending on the representation.
453
- W_H_L = data._link_transforms
454
- W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])
442
+ W_f_L = JaxSimModelReferences.other_representation_to_inertial(
443
+ array=f_L,
444
+ other_representation=self.velocity_representation,
445
+ transform=W_H_L[link_idxs] if model.number_of_links() > 1 else W_H_L,
446
+ is_force=True,
447
+ )
455
448
 
456
449
  return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L))
457
450
 
jaxsim/math/__init__.py CHANGED
@@ -11,4 +11,4 @@ from .joint_model import JointModel, supported_joint_motion # isort:skip
11
11
 
12
12
 
13
13
  # Define the default standard gravity constant.
14
- STANDARD_GRAVITY = -9.81
14
+ STANDARD_GRAVITY = 9.81
jaxsim/math/adjoint.py CHANGED
@@ -55,13 +55,13 @@ class Adjoint:
55
55
  The 6x6 adjoint matrix.
56
56
  """
57
57
 
58
- A_H_B = jnp.reshape(transform, (-1, 4, 4))
58
+ A_H_B = transform
59
59
 
60
60
  return (
61
61
  jaxlie.SE3.from_matrix(matrix=A_H_B).adjoint()
62
62
  if not inverse
63
63
  else jaxlie.SE3.from_matrix(matrix=A_H_B).inverse().adjoint()
64
- ).reshape(transform.shape[:-2] + (6, 6))
64
+ )
65
65
 
66
66
  @staticmethod
67
67
  def from_rotation_and_translation(
jaxsim/math/transform.py CHANGED
@@ -36,8 +36,8 @@ class Transform:
36
36
  W_Q_B = jnp.array(quaternion).astype(float)
37
37
  W_p_B = jnp.array(translation).astype(float)
38
38
 
39
- assert W_p_B.size == 3
40
- assert W_Q_B.size == 4
39
+ assert W_p_B.shape[-1] == 3
40
+ assert W_Q_B.shape[-1] == 4
41
41
 
42
42
  A_R_B = jaxlie.SO3(wxyz=W_Q_B)
43
43
  A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()
jaxsim/math/utils.py CHANGED
@@ -3,7 +3,7 @@ import jax.numpy as jnp
3
3
  import jaxsim.typing as jtp
4
4
 
5
5
 
6
- def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
6
+ def safe_norm(array: jtp.ArrayLike, *, axis=None, keepdims: bool = False) -> jtp.Array:
7
7
  """
8
8
  Compute an array norm handling NaNs and making sure that
9
9
  it is safe to get the gradient.
@@ -11,6 +11,7 @@ def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
11
11
  Args:
12
12
  array: The array for which to compute the norm.
13
13
  axis: The axis for which to compute the norm.
14
+ keepdims: Whether to keep the dimensions of the input
14
15
 
15
16
  Returns:
16
17
  The norm of the array with handling for zero arrays to avoid NaNs.
@@ -24,7 +25,7 @@ def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
24
25
  array = jnp.where(is_zero, jnp.ones_like(array), array)
25
26
 
26
27
  # Compute the norm of the array along the specified axis.
27
- norm = jnp.linalg.norm(array, axis=axis)
28
+ norm = jnp.linalg.norm(array, axis=axis, keepdims=keepdims)
28
29
 
29
30
  # Use `jnp.where` to set the norm to 0.0 where the input array was all zeros.
30
31
  # This usage supports potential batch processing for future scalability.
@@ -64,7 +64,7 @@ class MujocoVideoRecorder:
64
64
 
65
65
  self.frames = []
66
66
 
67
- self.data = [data] if data is not None else self.data
67
+ self.data = data if data is not None else self.data
68
68
  self.data = self.data if isinstance(self.data, list) else [self.data]
69
69
 
70
70
  self.model = model if model is not None else self.model
@@ -973,7 +973,7 @@ class KinematicGraphTransforms:
973
973
 
974
974
  if frame.parent_name in self.graph.links_dict:
975
975
  return frame.parent_name
976
- elif frame.parent_name in self.graph.frames_dict:
976
+ if frame.parent_name in self.graph.frames_dict:
977
977
  return self.find_parent_link_of_frame(name=frame.parent_name)
978
978
 
979
979
  msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent_name}'"
jaxsim/rbda/__init__.py CHANGED
@@ -2,7 +2,7 @@ from . import contacts
2
2
  from .aba import aba
3
3
  from .collidable_points import collidable_points_pos_vel
4
4
  from .crba import crba
5
- from .forward_kinematics import forward_kinematics, forward_kinematics_model
5
+ from .forward_kinematics import forward_kinematics_model
6
6
  from .jacobian import (
7
7
  jacobian,
8
8
  jacobian_derivative_full_doubly_left,
@@ -1,5 +1,9 @@
1
- from . import relaxed_rigid
1
+ from . import relaxed_rigid, rigid, soft
2
2
  from .common import ContactModel, ContactsParams
3
3
  from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
4
+ from .rigid import RigidContacts, RigidContactsParams
5
+ from .soft import SoftContacts, SoftContactsParams
4
6
 
5
- ContactParamsTypes = RelaxedRigidContactsParams
7
+ ContactParamsTypes = (
8
+ SoftContactsParams | RigidContactsParams | RelaxedRigidContactsParams
9
+ )
@@ -9,6 +9,7 @@ import jax.numpy as jnp
9
9
  import jaxsim.api as js
10
10
  import jaxsim.terrain
11
11
  import jaxsim.typing as jtp
12
+ from jaxsim.math import STANDARD_GRAVITY
12
13
  from jaxsim.utils import JaxsimDataclass
13
14
 
14
15
  try:
@@ -80,6 +81,86 @@ class ContactsParams(JaxsimDataclass):
80
81
  """
81
82
  pass
82
83
 
84
+ def build_default_from_jaxsim_model(
85
+ self: type[Self],
86
+ model: js.model.JaxSimModel,
87
+ *,
88
+ stiffness: jtp.FloatLike | None = None,
89
+ damping: jtp.FloatLike | None = None,
90
+ standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,
91
+ static_friction_coefficient: jtp.FloatLike = 0.5,
92
+ max_penetration: jtp.FloatLike = 0.001,
93
+ number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
94
+ damping_ratio: jtp.FloatLike = 1.0,
95
+ p: jtp.FloatLike = 0.5,
96
+ q: jtp.FloatLike = 0.5,
97
+ **kwargs,
98
+ ) -> Self:
99
+ """
100
+ Create a `ContactsParams` instance with default parameters.
101
+
102
+ Args:
103
+ model: The robot model considered by the contact model.
104
+ stiffness: The stiffness of the contact model.
105
+ damping: The damping of the contact model.
106
+ standard_gravity: The standard gravity acceleration.
107
+ static_friction_coefficient: The static friction coefficient.
108
+ max_penetration: The maximum penetration depth.
109
+ number_of_active_collidable_points_steady_state:
110
+ The number of active collidable points in steady state.
111
+ damping_ratio: The damping ratio.
112
+ p: The first parameter of the contact model.
113
+ q: The second parameter of the contact model.
114
+ **kwargs: Optional additional arguments.
115
+
116
+ Returns:
117
+ The `ContactsParams` instance.
118
+
119
+ Note:
120
+ The `stiffness` is intended as the terrain stiffness in the Soft Contacts model,
121
+ while it is the Baumgarte stabilization stiffness in the Rigid Contacts model.
122
+
123
+ The `damping` is intended as the terrain damping in the Soft Contacts model,
124
+ while it is the Baumgarte stabilization damping in the Rigid Contacts model.
125
+
126
+ The `damping_ratio` parameter allows to operate on the following conditions:
127
+ - ξ > 1.0: over-damped
128
+ - ξ = 1.0: critically damped
129
+ - ξ < 1.0: under-damped
130
+ """
131
+
132
+ # Use symbols for input parameters.
133
+ ξ = damping_ratio
134
+ δ_max = max_penetration
135
+ μc = static_friction_coefficient
136
+
137
+ # Compute the total mass of the model.
138
+ m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()
139
+
140
+ # Rename the standard gravity.
141
+ g = standard_gravity
142
+
143
+ # Compute the average support force on each collidable point.
144
+ f_average = m * g / number_of_active_collidable_points_steady_state
145
+
146
+ # Compute the stiffness to get the desired steady-state penetration.
147
+ # Note that this is dependent on the non-linear exponent used in
148
+ # the damping term of the Hunt/Crossley model.
149
+ K = f_average / jnp.power(δ_max, 1 + p) if stiffness is None else stiffness
150
+
151
+ # Compute the damping using the damping ratio.
152
+ critical_damping = 2 * jnp.sqrt(K * m)
153
+ D = ξ * critical_damping if damping is None else damping
154
+
155
+ return self.build(
156
+ K=K,
157
+ D=D,
158
+ mu=μc,
159
+ p=p,
160
+ q=q,
161
+ **kwargs,
162
+ )
163
+
83
164
  @abc.abstractmethod
84
165
  def valid(self, **kwargs) -> jtp.BoolLike:
85
166
  """
@@ -156,7 +237,7 @@ class ContactModel(JaxsimDataclass):
156
237
  return {}
157
238
 
158
239
  @property
159
- def _parameters_class(cls) -> type[ContactsParams]:
240
+ def _parameters_class(self) -> type[ContactsParams]:
160
241
  """
161
242
  Return the class of the contact parameters.
162
243
 
@@ -168,8 +249,37 @@ class ContactModel(JaxsimDataclass):
168
249
  return getattr(
169
250
  importlib.import_module("jaxsim.rbda.contacts"),
170
251
  (
171
- cls.__name__ + "Params"
172
- if isinstance(cls, type)
173
- else cls.__class__.__name__ + "Params"
252
+ self.__name__ + "Params"
253
+ if isinstance(self, type)
254
+ else self.__class__.__name__ + "Params"
174
255
  ),
175
256
  )
257
+
258
+ @abc.abstractmethod
259
+ def update_contact_state(
260
+ self: type[Self], old_contact_state: dict[str, jtp.Array]
261
+ ) -> dict[str, jtp.Array]:
262
+ """
263
+ Update the contact state.
264
+
265
+ Args:
266
+ old_contact_state: The old contact state.
267
+
268
+ Returns:
269
+ The updated contact state.
270
+ """
271
+
272
+ @abc.abstractmethod
273
+ def update_velocity_after_impact(
274
+ self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData
275
+ ) -> js.data.JaxSimModelData:
276
+ """
277
+ Update the velocity after an impact.
278
+
279
+ Args:
280
+ model: The robot model considered by the contact model.
281
+ data: The data of the considered model.
282
+
283
+ Returns:
284
+ The updated data of the considered model.
285
+ """