jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,156 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jaxlie
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.math import Adjoint, Skew
8
+
9
+ from . import utils
10
+
11
+
12
def collidable_points_pos_vel(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.Vector,
    base_quaternion: jtp.Vector,
    joint_positions: jtp.Vector,
    base_linear_velocity: jtp.Vector,
    base_angular_velocity: jtp.Vector,
    joint_velocities: jtp.Vector,
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute world-frame positions and linear velocities of the enabled collidable points.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.

    Returns:
        A tuple containing the position and linear velocity of the enabled
        collidable points.
    """

    contact_parameters = model.kin_dyn_parameters.contact_parameters

    # Select only the collidable points that are enabled.
    enabled = contact_parameters.indices_of_enabled_collidable_points

    # Index of the link each enabled point is attached to, and the point
    # coordinates stored in the contact parameters.
    parent_link_of_point = jnp.array(contact_parameters.body, dtype=int)[enabled]
    L_p_C = contact_parameters.point[enabled]

    # Nothing to compute when no collidable point is enabled.
    # NOTE(review): the two placeholders have different shapes (a 0-d zero and
    # an empty 1-d array) -- confirm that callers rely on exactly this.
    if len(enabled) == 0:
        return jnp.array(0).astype(float), jnp.empty(0).astype(float)

    W_p_B, W_Q_B, s, W_v_WB, s_dot, _, _, _, _, _ = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
    )

    # Parent array: parent_of[i] is the parent link of link i.
    # Note: parent_of[0] must not be used, it's initialized to -1.
    parent_of = model.kin_dyn_parameters.parent_array

    # Homogeneous transform of the base.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3(wxyz=W_Q_B),
        translation=W_p_B,
    )

    # Parent-to-child adjoints and joint motion subspaces. They describe the
    # relative kinematics of the whole model, base transform included, for both
    # floating-base and fixed-base models.
    i_X_parent, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    n_links = model.number_of_links()

    # Buffer of the world -> link transforms, seeded with the base pose.
    W_X_i = jnp.zeros(shape=(n_links, 6, 6))
    W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_parent[0]))

    # Buffer of the 6D inertial-fixed link velocities, seeded with the base velocity.
    W_v_Wi = jnp.zeros(shape=(n_links, 6))
    W_v_Wi = W_v_Wi.at[0].set(W_v_WB)

    # ====================
    # Propagate kinematics
    # ====================

    KinematicsCarry = tuple[jtp.Matrix, jtp.Matrix]

    def propagate_kinematics(
        carry: KinematicsCarry, i: jtp.Int
    ) -> tuple[KinematicsCarry, None]:

        transforms, velocities = carry
        parent = parent_of[i]

        # World -> child transform, through the parent -> child transform.
        W_X_link = transforms[parent] @ Adjoint.inverse(adjoint=i_X_parent[i])
        transforms = transforms.at[i].set(W_X_link)

        # Joint i drives link i and is stored at index i - 1.
        joint_velocity = (S[i] * s_dot[i - 1]).squeeze()

        # Propagate the 6D velocity from the parent link.
        W_v_link = velocities[parent] + W_X_link @ joint_velocity
        velocities = velocities.at[i].set(W_v_link)

        return (transforms, velocities), None

    # With a single link there is nothing to propagate beyond the base.
    if n_links > 1:
        (W_X_i, W_v_Wi), _ = jax.lax.scan(
            f=propagate_kinematics,
            init=(W_X_i, W_v_Wi),
            xs=jnp.arange(start=1, stop=n_links),
        )

    # ==================================================
    # Compute position and velocity of collidable points
    # ==================================================

    def process_point_kinematics(
        Li_p_C: jtp.Vector, parent_body: jtp.Int
    ) -> tuple[jtp.Vector, jtp.Vector]:

        # Position of the collidable point, via homogeneous coordinates.
        W_H_L = Adjoint.to_transform(adjoint=W_X_i[parent_body])
        W_p_C = (W_H_L @ jnp.hstack([Li_p_C, 1]))[0:3]

        # Linear part of the mixed velocity C[W]_v_{W,C}.
        to_point_linear = jnp.block(
            [jnp.eye(3), -Skew.wedge(vector=W_p_C).squeeze()]
        )
        CW_vl_WC = to_point_linear @ W_v_Wi[parent_body].squeeze()

        return W_p_C, CW_vl_WC

    # Process all the collidable points in parallel.
    W_p_Ci, CW_vl_WCi = jax.vmap(process_point_kinematics)(
        L_p_C,
        parent_link_of_point,
    )

    return W_p_Ci, CW_vl_WCi
@@ -0,0 +1,13 @@
1
+ from . import relaxed_rigid, rigid, soft, visco_elastic
2
+ from .common import ContactModel, ContactsParams
3
+ from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
4
+ from .rigid import RigidContacts, RigidContactsParams
5
+ from .soft import SoftContacts, SoftContactsParams
6
+ from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams
7
+
8
# Union of the parameter classes of the contact models imported above.
ContactParamsTypes = (
    SoftContactsParams
    | RigidContactsParams
    | RelaxedRigidContactsParams
    | ViscoElasticContactsParams
)
@@ -0,0 +1,313 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ import functools
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+
9
+ import jaxsim.api as js
10
+ import jaxsim.terrain
11
+ import jaxsim.typing as jtp
12
+ from jaxsim.api.common import ModelDataWithVelocityRepresentation
13
+ from jaxsim.utils import JaxsimDataclass
14
+
15
+ try:
16
+ from typing import Self
17
+ except ImportError:
18
+ from typing_extensions import Self
19
+
20
+
21
@functools.partial(jax.jit, static_argnames=("terrain",))
def compute_penetration_data(
    p: jtp.VectorLike,
    v: jtp.VectorLike,
    terrain: jaxsim.terrain.Terrain,
) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
    """
    Compute depth, rate, and terrain normal of a collidable point's penetration.

    Args:
        p: The position of the collidable point.
        v:
            The linear velocity of the point (linear component of the mixed 6D
            velocity of the implicit frame `C = (W_p_C, [W])` associated to the
            point).
        terrain: The considered terrain.

    Returns:
        A tuple containing the penetration depth, the penetration velocity,
        and the considered terrain normal.
    """

    # Flatten the inputs and split the position into its coordinates.
    point_velocity = jnp.array(v).squeeze()
    x, y, z = jnp.array(p).squeeze()

    # Query the terrain below the point: its normal first, then the vertical
    # offset between the terrain surface and the point.
    normal = terrain.normal(x=x, y=y).squeeze()
    vertical_offset = jnp.array([0, 0, terrain.height(x=x, y=y) - z])

    # Penetration depth along the terrain normal, clipped at zero.
    depth = jnp.maximum(0.0, jnp.dot(vertical_offset, normal))

    # Penetration rate along the terrain normal; it is forced to zero whenever
    # the point is not penetrating.
    normal_rate = -jnp.dot(point_velocity, normal)
    depth_rate = jnp.where(depth > 0, normal_rate, 0.0)

    return depth, depth_rate, normal
60
+
61
+
62
class ContactsParams(JaxsimDataclass):
    """
    Abstract base class of the parameters of a contact model.

    Note:
        Only the tunable parameters of a contact model belong here, i.e. those
        that can be changed during runtime. Static parameters of the contact
        model, if any, should be stored in the corresponding `ContactModel`
        class instead.
    """

    @classmethod
    @abc.abstractmethod
    def build(cls: type[Self], **kwargs) -> Self:
        """
        Create a `ContactsParams` instance with specified parameters.

        Args:
            **kwargs: The parameters accepted by the concrete class.

        Returns:
            The `ContactsParams` instance.
        """

    @abc.abstractmethod
    def valid(self, **kwargs) -> jtp.BoolLike:
        """
        Check if the parameters are valid.

        Args:
            **kwargs: Optional arguments accepted by the concrete class.

        Returns:
            True if the parameters are valid, False otherwise.
        """
93
+
94
+
95
class ContactModel(JaxsimDataclass):
    """
    Abstract base class of a contact model.
    """

    @classmethod
    @abc.abstractmethod
    def build(
        cls: type[Self],
        **kwargs,
    ) -> Self:
        """
        Create a `ContactModel` instance with specified parameters.

        Args:
            **kwargs: The parameters accepted by the concrete class.

        Returns:
            The `ContactModel` instance.
        """

    @abc.abstractmethod
    def compute_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        **kwargs,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.
            **kwargs: Optional additional arguments, specific to the contact model.

        Returns:
            A tuple containing as first element the computed 6D contact force
            applied to the contact points and expressed in the world frame, and
            as second element a dictionary of optional additional information.
        """

    def compute_link_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        **kwargs,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the link contact forces.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.
            **kwargs: Optional additional arguments, specific to the contact model.

        Returns:
            A tuple containing as first element the 6D contact force applied to
            the links and expressed in the frame of the velocity representation
            of data, and as second element a dictionary of optional additional
            information.
        """

        inertial = jaxsim.VelRepr.Inertial

        # Contact forces expressed in the inertial frame. Unlike
        # `compute_contact_forces`, this function already takes care of passing
        # the optional kwargs to the specific contact model.
        W_f_C, aux_dict = js.contact.collidable_point_dynamics(
            model=model, data=data, **kwargs
        )

        # Turn the forces applied to the frames of the collidable points into
        # the equivalent 6D forces applied to the links.
        with data.switch_velocity_representation(inertial):
            W_f_L = self.link_forces_from_contact_forces(
                model=model, data=data, contact_forces=W_f_C
            )

        # A references object stores the link forces and converts them easily.
        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            link_forces=W_f_L,
            velocity_representation=inertial,
        )

        # Read the link forces back in the velocity representation of data.
        with references.switch_velocity_representation(data.velocity_representation):
            f_L = references.link_forces(model=model, data=data)

        return f_L, aux_dict

    @staticmethod
    def link_forces_from_contact_forces(
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        *,
        contact_forces: jtp.MatrixLike,
    ) -> jtp.Matrix:
        """
        Compute the link forces from the contact forces.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.
            contact_forces: The contact forces computed by the contact model.

        Returns:
            The 6D contact forces applied to the links and expressed in the
            frame of the velocity representation of data.
        """

        # Object storing the contact parameters of the model.
        contact_parameters = model.kin_dyn_parameters.contact_parameters

        # Indices of the enabled collidable points.
        enabled = contact_parameters.indices_of_enabled_collidable_points

        # Contact forces as a 2D JAX array, one row per collidable point.
        f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())

        # Pose of the enabled collidable points.
        W_H_C = js.contact.transforms(model=model, data=data)[enabled]

        def to_inertial(force: jtp.Vector, transform: jtp.Matrix) -> jtp.Vector:
            # Express a single contact force in inertial-fixed representation.
            return ModelDataWithVelocityRepresentation.other_representation_to_inertial(
                array=force,
                other_representation=data.velocity_representation,
                transform=transform,
                is_force=True,
            )

        W_f_C = jax.vmap(to_inertial)(f_C, W_H_C)

        # Parent link index of each enabled collidable point. It is used to sum
        # the 6D forces of all the points rigidly attached to the same link.
        parent_link_of_point = jnp.array(contact_parameters.body, dtype=int)[enabled]

        n_links = model.number_of_links()

        # Boolean mask associating each collidable point (rows) to its parent
        # link (columns).
        mask = parent_link_of_point[:, jnp.newaxis] == jnp.arange(n_links)

        # Sum the forces of all the collidable points attached to each link.
        # The forces W_f_C are already expressed in the world frame, therefore
        # no coordinate transformation is needed.
        W_f_L = mask.T @ W_f_C

        # Link transforms. The forward kinematics is computed only when the
        # velocity representation of data is not the inertial one; otherwise a
        # zero placeholder with the same shape is passed to the conversion.
        if data.velocity_representation is not jaxsim.VelRepr.Inertial:
            W_H_L = js.model.forward_kinematics(model=model, data=data)
        else:
            W_H_L = jnp.zeros(shape=(n_links, 4, 4))

        def from_inertial(force: jtp.Vector, transform: jtp.Matrix) -> jtp.Vector:
            # Express a single link force in the velocity representation of data.
            return ModelDataWithVelocityRepresentation.inertial_to_other_representation(
                array=force,
                other_representation=data.velocity_representation,
                transform=transform,
                is_force=True,
            )

        f_L = jax.vmap(from_inertial)(W_f_L, W_H_L)

        return f_L

    @classmethod
    def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
        """
        Build zero state variables of the contact model.

        Args:
            model: The robot model considered by the contact model.

        Note:
            There are contact models that require to extend the state vector of
            the integrated ODE system with additional variables. Our integrators
            are capable of operating on a generic state, as long as it is a
            PyTree. This method builds the zero state variables of the contact
            model as a dictionary of JAX arrays. This base implementation
            returns an empty dictionary.

        Returns:
            A dictionary storing the zero state variables of the contact model.
        """

        return dict()

    @property
    def _parameters_class(cls) -> type[ContactsParams]:
        """
        Return the class of the contact parameters.

        The class is looked up by name in `jaxsim.rbda.contacts`: it is the name
        of the contact model class followed by the `Params` suffix.

        Returns:
            The class of the contact parameters.
        """
        import importlib

        contacts_module = importlib.import_module("jaxsim.rbda.contacts")

        # Accept both a class and an instance as the receiver.
        model_class = cls if isinstance(cls, type) else cls.__class__

        return getattr(contacts_module, model_class.__name__ + "Params")