jaxsim 0.5.1.dev126__py3-none-any.whl → 0.5.1.dev139__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. jaxsim/__init__.py +0 -7
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/com.py +1 -1
  4. jaxsim/api/common.py +1 -1
  5. jaxsim/api/contact.py +3 -0
  6. jaxsim/api/data.py +2 -1
  7. jaxsim/api/kin_dyn_parameters.py +18 -1
  8. jaxsim/api/model.py +7 -4
  9. jaxsim/api/ode.py +21 -1
  10. jaxsim/exceptions.py +8 -0
  11. jaxsim/integrators/common.py +72 -11
  12. jaxsim/integrators/fixed_step.py +91 -40
  13. jaxsim/integrators/variable_step.py +117 -46
  14. jaxsim/math/adjoint.py +19 -10
  15. jaxsim/math/cross.py +6 -2
  16. jaxsim/math/inertia.py +8 -4
  17. jaxsim/math/quaternion.py +10 -6
  18. jaxsim/math/rotation.py +6 -3
  19. jaxsim/math/skew.py +2 -2
  20. jaxsim/math/transform.py +12 -4
  21. jaxsim/math/utils.py +2 -2
  22. jaxsim/mujoco/loaders.py +17 -7
  23. jaxsim/mujoco/model.py +15 -15
  24. jaxsim/mujoco/utils.py +6 -1
  25. jaxsim/mujoco/visualizer.py +11 -7
  26. jaxsim/parsers/descriptions/collision.py +7 -4
  27. jaxsim/parsers/descriptions/joint.py +16 -14
  28. jaxsim/parsers/descriptions/model.py +1 -1
  29. jaxsim/parsers/kinematic_graph.py +38 -0
  30. jaxsim/parsers/rod/meshes.py +5 -5
  31. jaxsim/parsers/rod/parser.py +1 -1
  32. jaxsim/parsers/rod/utils.py +11 -0
  33. jaxsim/rbda/contacts/common.py +2 -0
  34. jaxsim/rbda/contacts/relaxed_rigid.py +7 -4
  35. jaxsim/rbda/contacts/rigid.py +8 -4
  36. jaxsim/rbda/contacts/soft.py +37 -0
  37. jaxsim/rbda/contacts/visco_elastic.py +1 -0
  38. jaxsim/terrain/terrain.py +52 -0
  39. jaxsim/utils/jaxsim_dataclass.py +3 -3
  40. jaxsim/utils/tracing.py +2 -2
  41. jaxsim/utils/wrappers.py +9 -0
  42. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/METADATA +1 -1
  43. jaxsim-0.5.1.dev139.dist-info/RECORD +74 -0
  44. jaxsim-0.5.1.dev126.dist-info/RECORD +0 -74
  45. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/LICENSE +0 -0
  46. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/WHEEL +0 -0
  47. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev139.dist-info}/top_level.txt +0 -0
jaxsim/mujoco/model.py CHANGED
@@ -254,17 +254,17 @@ class MujocoModelHelper:
254
254
  # ==================
255
255
 
256
256
  def number_of_joints(self) -> int:
257
- """Returns the number of joints in the model."""
257
+ """Return the number of joints in the model."""
258
258
 
259
259
  return self.model.njnt
260
260
 
261
261
  def number_of_dofs(self) -> int:
262
- """Returns the number of DoFs in the model."""
262
+ """Return the number of DoFs in the model."""
263
263
 
264
264
  return self.model.nq
265
265
 
266
266
  def joint_names(self) -> list[str]:
267
- """Returns the names of the joints in the model."""
267
+ """Return the names of the joints in the model."""
268
268
 
269
269
  return [
270
270
  mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx)
@@ -272,7 +272,7 @@ class MujocoModelHelper:
272
272
  ]
273
273
 
274
274
  def joint_dofs(self, joint_name: str) -> int:
275
- """Returns the number of DoFs of a joint."""
275
+ """Return the number of DoFs of a joint."""
276
276
 
277
277
  if joint_name not in self.joint_names():
278
278
  raise ValueError(f"Joint '{joint_name}' not found")
@@ -280,7 +280,7 @@ class MujocoModelHelper:
280
280
  return self.data.joint(joint_name).qpos.size
281
281
 
282
282
  def joint_position(self, joint_name: str) -> npt.NDArray:
283
- """Returns the position of a joint."""
283
+ """Return the position of a joint."""
284
284
 
285
285
  if joint_name not in self.joint_names():
286
286
  raise ValueError(f"Joint '{joint_name}' not found")
@@ -288,7 +288,7 @@ class MujocoModelHelper:
288
288
  return self.data.joint(joint_name).qpos
289
289
 
290
290
  def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray:
291
- """Returns the positions of the joints."""
291
+ """Return the positions of the joints."""
292
292
 
293
293
  joint_names = joint_names if joint_names is not None else self.joint_names()
294
294
 
@@ -299,7 +299,7 @@ class MujocoModelHelper:
299
299
  def set_joint_position(
300
300
  self, joint_name: str, position: npt.NDArray | float
301
301
  ) -> None:
302
- """Sets the position of a joint."""
302
+ """Set the position of a joint."""
303
303
 
304
304
  position = np.atleast_1d(np.array(position).squeeze())
305
305
 
@@ -328,12 +328,12 @@ class MujocoModelHelper:
328
328
  # ==================
329
329
 
330
330
  def number_of_bodies(self) -> int:
331
- """Returns the number of bodies in the model."""
331
+ """Return the number of bodies in the model."""
332
332
 
333
333
  return self.model.nbody
334
334
 
335
335
  def body_names(self) -> list[str]:
336
- """Returns the names of the bodies in the model."""
336
+ """Return the names of the bodies in the model."""
337
337
 
338
338
  return [
339
339
  mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx)
@@ -341,7 +341,7 @@ class MujocoModelHelper:
341
341
  ]
342
342
 
343
343
  def body_position(self, body_name: str) -> npt.NDArray:
344
- """Returns the position of a body."""
344
+ """Return the position of a body."""
345
345
 
346
346
  if body_name not in self.body_names():
347
347
  raise ValueError(f"Body '{body_name}' not found")
@@ -349,7 +349,7 @@ class MujocoModelHelper:
349
349
  return self.data.body(body_name).xpos
350
350
 
351
351
  def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:
352
- """Returns the orientation of a body."""
352
+ """Return the orientation of a body."""
353
353
 
354
354
  if body_name not in self.body_names():
355
355
  raise ValueError(f"Body '{body_name}' not found")
@@ -363,12 +363,12 @@ class MujocoModelHelper:
363
363
  # ======================
364
364
 
365
365
  def number_of_geometries(self) -> int:
366
- """Returns the number of geometries in the model."""
366
+ """Return the number of geometries in the model."""
367
367
 
368
368
  return self.model.ngeom
369
369
 
370
370
  def geometry_names(self) -> list[str]:
371
- """Returns the names of the geometries in the model."""
371
+ """Return the names of the geometries in the model."""
372
372
 
373
373
  return [
374
374
  mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx)
@@ -376,7 +376,7 @@ class MujocoModelHelper:
376
376
  ]
377
377
 
378
378
  def geometry_position(self, geometry_name: str) -> npt.NDArray:
379
- """Returns the position of a geometry."""
379
+ """Return the position of a geometry."""
380
380
 
381
381
  if geometry_name not in self.geometry_names():
382
382
  raise ValueError(f"Geometry '{geometry_name}' not found")
@@ -386,7 +386,7 @@ class MujocoModelHelper:
386
386
  def geometry_orientation(
387
387
  self, geometry_name: str, dcm: bool = False
388
388
  ) -> npt.NDArray:
389
- """Returns the orientation of a geometry."""
389
+ """Return the orientation of a geometry."""
390
390
 
391
391
  if geometry_name not in self.geometry_names():
392
392
  raise ValueError(f"Geometry '{geometry_name}' not found")
jaxsim/mujoco/utils.py CHANGED
@@ -133,6 +133,9 @@ class MujocoCamera:
133
133
 
134
134
  @classmethod
135
135
  def build(cls, **kwargs) -> MujocoCamera:
136
+ """
137
+ Build a Mujoco camera from a dictionary.
138
+ """
136
139
 
137
140
  if not all(isinstance(value, str) for value in kwargs.values()):
138
141
  raise ValueError(f"Values must be strings: {kwargs}")
@@ -219,5 +222,7 @@ class MujocoCamera:
219
222
  )
220
223
 
221
224
  def asdict(self) -> dict[str, str]:
222
-
225
+ """
226
+ Convert the camera to a dictionary.
227
+ """
223
228
  return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
@@ -10,7 +10,9 @@ import numpy.typing as npt
10
10
 
11
11
 
12
12
  class MujocoVideoRecorder:
13
- """"""
13
+ """
14
+ Video recorder for the MuJoCo passive viewer.
15
+ """
14
16
 
15
17
  def __init__(
16
18
  self,
@@ -64,7 +66,7 @@ class MujocoVideoRecorder:
64
66
  self.model = model if model is not None else self.model
65
67
 
66
68
  def render_frame(self, camera_name: str = "track") -> npt.NDArray:
67
- """Renders a frame."""
69
+ """Render a frame."""
68
70
 
69
71
  mujoco.mj_forward(self.model, self.data)
70
72
  self.renderer.update_scene(data=self.data, camera=camera_name)
@@ -72,13 +74,13 @@ class MujocoVideoRecorder:
72
74
  return self.renderer.render()
73
75
 
74
76
  def record_frame(self, camera_name: str = "track") -> None:
75
- """Stores a frame in the buffer."""
77
+ """Store a frame in the buffer."""
76
78
 
77
79
  frame = self.render_frame(camera_name=camera_name)
78
80
  self.frames.append(frame)
79
81
 
80
82
  def write_video(self, path: pathlib.Path, exist_ok: bool = False) -> None:
81
- """Writes the video to a file."""
83
+ """Write the video to a file."""
82
84
 
83
85
  # Resolve the path to the video.
84
86
  path = path.expanduser().resolve()
@@ -117,7 +119,9 @@ class MujocoVideoRecorder:
117
119
 
118
120
 
119
121
  class MujocoVisualizer:
120
- """"""
122
+ """
123
+ Visualizer for the MuJoCo passive viewer.
124
+ """
121
125
 
122
126
  def __init__(
123
127
  self, model: mj.MjModel | None = None, data: mj.MjData | None = None
@@ -139,7 +143,7 @@ class MujocoVisualizer:
139
143
  model: mj.MjModel | None = None,
140
144
  data: mj.MjData | None = None,
141
145
  ) -> None:
142
- """Updates the viewer with the current model and data."""
146
+ """Update the viewer with the current model and data."""
143
147
 
144
148
  data = data if data is not None else self.data
145
149
  model = model if model is not None else self.model
@@ -150,7 +154,7 @@ class MujocoVisualizer:
150
154
  def open_viewer(
151
155
  self, model: mj.MjModel | None = None, data: mj.MjData | None = None
152
156
  ) -> mj.viewer.Handle:
153
- """Opens a viewer."""
157
+ """Open a viewer."""
154
158
 
155
159
  data = data if data is not None else self.data
156
160
  model = model if model is not None else self.model
@@ -22,7 +22,6 @@ class CollidablePoint:
22
22
  parent_link: The parent link to which the collidable point is attached.
23
23
  position: The position of the collidable point relative to the parent link.
24
24
  enabled: A flag indicating whether the collidable point is enabled for collision detection.
25
-
26
25
  """
27
26
 
28
27
  parent_link: LinkDescription
@@ -86,7 +85,6 @@ class CollisionShape(abc.ABC):
86
85
 
87
86
  Attributes:
88
87
  collidable_points: A list of collidable points associated with the collision shape.
89
-
90
88
  """
91
89
 
92
90
  collidable_points: tuple[CollidablePoint]
@@ -107,7 +105,6 @@ class BoxCollision(CollisionShape):
107
105
 
108
106
  Attributes:
109
107
  center: The center of the box in the local frame of the collision shape.
110
-
111
108
  """
112
109
 
113
110
  center: jtp.VectorLike
@@ -135,7 +132,6 @@ class SphereCollision(CollisionShape):
135
132
 
136
133
  Attributes:
137
134
  center: The center of the sphere in the local frame of the collision shape.
138
-
139
135
  """
140
136
 
141
137
  center: jtp.VectorLike
@@ -158,6 +154,13 @@ class SphereCollision(CollisionShape):
158
154
 
159
155
  @dataclasses.dataclass
160
156
  class MeshCollision(CollisionShape):
157
+ """
158
+ Represents a mesh-shaped collision shape.
159
+
160
+ Attributes:
161
+ center: The center of the mesh in the local frame of the collision shape.
162
+ """
163
+
161
164
  center: jtp.VectorLike
162
165
 
163
166
  def __hash__(self) -> int:
@@ -14,6 +14,9 @@ from .link import LinkDescription
14
14
 
15
15
  @dataclasses.dataclass(frozen=True)
16
16
  class JointType:
17
+ """
18
+ Enumeration of joint types.
19
+ """
17
20
 
18
21
  Fixed: ClassVar[int] = 0
19
22
  Revolute: ClassVar[int] = 1
@@ -47,20 +50,19 @@ class JointDescription(JaxsimDataclass):
47
50
  In-memory description of a robot link.
48
51
 
49
52
  Attributes:
50
- name (str): The name of the joint.
51
- axis (npt.NDArray): The axis of rotation or translation for the joint.
52
- pose (npt.NDArray): The pose transformation matrix of the joint.
53
- jtype (JointType): The type of the joint.
54
- child (LinkDescription): The child link attached to the joint.
55
- parent (LinkDescription): The parent link attached to the joint.
56
- index (Optional[int]): An optional index for the joint.
57
- friction_static (float): The static friction coefficient for the joint.
58
- friction_viscous (float): The viscous friction coefficient for the joint.
59
- position_limit_damper (float): The damper coefficient for position limits.
60
- position_limit_spring (float): The spring coefficient for position limits.
61
- position_limit (Tuple[float, float]): The position limits for the joint.
62
- initial_position (Union[float, npt.NDArray]): The initial position of the joint.
63
-
53
+ name: The name of the joint.
54
+ axis: The axis of rotation or translation for the joint.
55
+ pose: The pose transformation matrix of the joint.
56
+ jtype: The type of the joint.
57
+ child: The child link attached to the joint.
58
+ parent: The parent link attached to the joint.
59
+ index: An optional index for the joint.
60
+ friction_static: The static friction coefficient for the joint.
61
+ friction_viscous: The viscous friction coefficient for the joint.
62
+ position_limit_damper: The damper coefficient for position limits.
63
+ position_limit_spring: The spring coefficient for position limits.
64
+ position_limit: The position limits for the joint.
65
+ initial_position: The initial position of the joint.
64
66
  """
65
67
 
66
68
  name: jax_dataclasses.Static[str]
@@ -158,7 +158,7 @@ class ModelDescription(KinematicGraph):
158
158
  Reduce the model by removing specified joints.
159
159
 
160
160
  Args:
161
- The joint names to consider.
161
+ considered_joints: Sequence of joint names to consider.
162
162
 
163
163
  Returns:
164
164
  A `ModelDescription` instance that only includes the considered joints.
@@ -97,20 +97,32 @@ class KinematicGraph(Sequence[LinkDescription]):
97
97
 
98
98
  @functools.cached_property
99
99
  def links_dict(self) -> dict[str, LinkDescription]:
100
+ """
101
+ Get a dictionary of links indexed by their name.
102
+ """
100
103
  return {l.name: l for l in iter(self)}
101
104
 
102
105
  @functools.cached_property
103
106
  def frames_dict(self) -> dict[str, LinkDescription]:
107
+ """
108
+ Get a dictionary of frames indexed by their name.
109
+ """
104
110
  return {f.name: f for f in self.frames}
105
111
 
106
112
  @functools.cached_property
107
113
  def joints_dict(self) -> dict[str, JointDescription]:
114
+ """
115
+ Get a dictionary of joints indexed by their name.
116
+ """
108
117
  return {j.name: j for j in self.joints}
109
118
 
110
119
  @functools.cached_property
111
120
  def joints_connection_dict(
112
121
  self,
113
122
  ) -> dict[tuple[str, str], JointDescription]:
123
+ """
124
+ Get a dictionary of joints indexed by the tuple (parent, child) link names.
125
+ """
114
126
  return {(j.parent.name, j.child.name): j for j in self.joints}
115
127
 
116
128
  def __post_init__(self) -> None:
@@ -734,9 +746,15 @@ class KinematicGraph(Sequence[LinkDescription]):
734
746
  raise TypeError(type(key).__name__)
735
747
 
736
748
  def count(self, value: LinkDescription) -> int:
749
+ """
750
+ Count the occurrences of a link in the kinematic graph.
751
+ """
737
752
  return list(iter(self)).count(value)
738
753
 
739
754
  def index(self, value: LinkDescription, start: int = 0, stop: int = -1) -> int:
755
+ """
756
+ Find the index of a link in the kinematic graph.
757
+ """
740
758
  return list(iter(self)).index(value, start, stop)
741
759
 
742
760
 
@@ -747,6 +765,12 @@ class KinematicGraph(Sequence[LinkDescription]):
747
765
 
748
766
  @dataclasses.dataclass(frozen=True)
749
767
  class KinematicGraphTransforms:
768
+ """
769
+ Class to compute forward kinematics on a kinematic graph.
770
+
771
+ Attributes:
772
+ graph: The kinematic graph on which to compute forward kinematics.
773
+ """
750
774
 
751
775
  graph: KinematicGraph
752
776
 
@@ -767,6 +791,9 @@ class KinematicGraphTransforms:
767
791
 
768
792
  @property
769
793
  def initial_joint_positions(self) -> npt.NDArray:
794
+ """
795
+ Get the initial joint positions of the kinematic graph.
796
+ """
770
797
 
771
798
  return np.atleast_1d(
772
799
  np.array(list(self._initial_joint_positions.values()))
@@ -910,6 +937,17 @@ class KinematicGraphTransforms:
910
937
  joint_axis: npt.NDArray,
911
938
  joint_position: float | None = None,
912
939
  ) -> npt.NDArray:
940
+ """
941
+ Compute the SE(3) transform from the predecessor to the successor frame.
942
+
943
+ Args:
944
+ joint_type: The type of the joint.
945
+ joint_axis: The axis of the joint.
946
+ joint_position: The position of the joint.
947
+
948
+ Returns:
949
+ The 4x4 transform matrix from the predecessor to the successor frame.
950
+ """
913
951
 
914
952
  import jaxsim.math
915
953
 
@@ -6,14 +6,14 @@ VALID_AXIS = {"x": 0, "y": 1, "z": 2}
6
6
 
7
7
  def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray:
8
8
  """
9
- Extracts the vertices of a mesh as points.
9
+ Extract the vertices of a mesh as points.
10
10
  """
11
11
  return mesh.vertices
12
12
 
13
13
 
14
14
  def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray:
15
15
  """
16
- Extracts N random points from the surface of a mesh.
16
+ Extract N random points from the surface of a mesh.
17
17
 
18
18
  Args:
19
19
  mesh: The mesh from which to extract points.
@@ -30,7 +30,7 @@ def extract_points_uniform_surface_sampling(
30
30
  mesh: trimesh.Trimesh, n: int
31
31
  ) -> np.ndarray:
32
32
  """
33
- Extracts N uniformly sampled points from the surface of a mesh.
33
+ Extract N uniformly sampled points from the surface of a mesh.
34
34
 
35
35
  Args:
36
36
  mesh: The mesh from which to extract points.
@@ -47,7 +47,7 @@ def extract_points_select_points_over_axis(
47
47
  mesh: trimesh.Trimesh, axis: str, direction: str, n: int
48
48
  ) -> np.ndarray:
49
49
  """
50
- Extracts N points from a mesh along a specified axis. The points are selected based on their position along the axis.
50
+ Extract N points from a mesh along a specified axis. The points are selected based on their position along the axis.
51
51
 
52
52
  Args:
53
53
  mesh: The mesh from which to extract points.
@@ -75,7 +75,7 @@ def extract_points_aap(
75
75
  lower: float | None = None,
76
76
  ) -> np.ndarray:
77
77
  """
78
- Extracts points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis.
78
+ Extract points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis.
79
79
 
80
80
  Args:
81
81
  mesh: The mesh from which to extract points.
@@ -364,7 +364,7 @@ def build_model_description(
364
364
  is_urdf: bool | None = None,
365
365
  ) -> descriptions.ModelDescription:
366
366
  """
367
- Builds a model description from an SDF/URDF resource.
367
+ Build a model description from an SDF/URDF resource.
368
368
 
369
369
  Args:
370
370
  model_description: A path to an SDF/URDF file, a string containing its content,
@@ -223,6 +223,17 @@ def create_mesh_collision(
223
223
  link_description: descriptions.LinkDescription,
224
224
  method: MeshMappingMethod = None,
225
225
  ) -> descriptions.MeshCollision:
226
+ """
227
+ Create a mesh collision from an SDF collision element.
228
+
229
+ Args:
230
+ collision: The SDF collision element.
231
+ link_description: The link description.
232
+ method: The method to use for mesh wrapping.
233
+
234
+ Returns:
235
+ The mesh collision description.
236
+ """
226
237
 
227
238
  file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri))
228
239
  file_type = file.suffix.replace(".", "")
@@ -125,6 +125,7 @@ class ContactModel(JaxsimDataclass):
125
125
  Args:
126
126
  model: The robot model considered by the contact model.
127
127
  data: The data of the considered model.
128
+ **kwargs: Optional additional arguments, specific to the contact model.
128
129
 
129
130
  Returns:
130
131
  A tuple containing as first element the computed 6D contact force applied to
@@ -146,6 +147,7 @@ class ContactModel(JaxsimDataclass):
146
147
  Args:
147
148
  model: The robot model considered by the contact model.
148
149
  data: The data of the considered model.
150
+ **kwargs: Optional additional arguments, specific to the contact model.
149
151
 
150
152
  Returns:
151
153
  A tuple containing as first element the 6D contact force applied to the
@@ -112,7 +112,7 @@ class RelaxedRigidContactsParams(common.ContactsParams):
112
112
  damping: jtp.FloatLike | None = None,
113
113
  mu: jtp.FloatLike | None = None,
114
114
  ) -> Self:
115
- """Create a `RelaxedRigidContactsParams` instance"""
115
+ """Create a `RelaxedRigidContactsParams` instance."""
116
116
 
117
117
  def default(name: str):
118
118
  return cls.__dataclass_fields__[name].default_factory()
@@ -160,6 +160,7 @@ class RelaxedRigidContactsParams(common.ContactsParams):
160
160
  )
161
161
 
162
162
  def valid(self) -> jtp.BoolLike:
163
+ """Check if the parameters are valid."""
163
164
 
164
165
  return bool(
165
166
  jnp.all(self.time_constant >= 0.0)
@@ -187,6 +188,7 @@ class RelaxedRigidContacts(common.ContactModel):
187
188
 
188
189
  @property
189
190
  def solver_options(self) -> dict[str, Any]:
191
+ """Get the solver options."""
190
192
 
191
193
  return dict(
192
194
  zip(
@@ -207,6 +209,7 @@ class RelaxedRigidContacts(common.ContactModel):
207
209
 
208
210
  Args:
209
211
  solver_options: The options to pass to the L-BFGS solver.
212
+ **kwargs: The parameters of the relaxed rigid contacts model.
210
213
 
211
214
  Returns:
212
215
  The `RelaxedRigidContacts` instance.
@@ -483,8 +486,8 @@ class RelaxedRigidContacts(common.ContactModel):
483
486
 
484
487
  Args:
485
488
  model: The jaxsim model.
486
- penetration: The point position in the constraint frame.
487
- velocity: The point velocity in the constraint frame.
489
+ position_constraint: The position of the collidable points in the constraint frame.
490
+ velocity_constraint: The velocity of the collidable points in the constraint frame.
488
491
  parameters: The parameters of the relaxed rigid contacts model.
489
492
 
490
493
  Returns:
@@ -526,7 +529,7 @@ class RelaxedRigidContacts(common.ContactModel):
526
529
  vel: jtp.Vector,
527
530
  ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector]:
528
531
  """
529
- Calculates impedance and offset acceleration in constraint frame.
532
+ Calculate impedance and offset acceleration in constraint frame.
530
533
 
531
534
  Args:
532
535
  pos: position in constraint frame.
@@ -62,7 +62,7 @@ class RigidContactsParams(ContactsParams):
62
62
  K: jtp.FloatLike | None = None,
63
63
  D: jtp.FloatLike | None = None,
64
64
  ) -> Self:
65
- """Create a `RigidContactParams` instance"""
65
+ """Create a `RigidContactParams` instance."""
66
66
 
67
67
  return cls(
68
68
  mu=jnp.array(
@@ -79,7 +79,7 @@ class RigidContactsParams(ContactsParams):
79
79
  )
80
80
 
81
81
  def valid(self) -> jtp.BoolLike:
82
-
82
+ """Check if the parameters are valid."""
83
83
  return bool(
84
84
  jnp.all(self.mu >= 0.0)
85
85
  and jnp.all(self.K >= 0.0)
@@ -104,6 +104,7 @@ class RigidContacts(ContactModel):
104
104
 
105
105
  @property
106
106
  def solver_options(self) -> dict[str, Any]:
107
+ """Get the solver options as a dictionary."""
107
108
 
108
109
  return dict(
109
110
  zip(
@@ -127,6 +128,7 @@ class RigidContacts(ContactModel):
127
128
  regularization_delassus:
128
129
  The regularization term to add to the diagonal of the Delassus matrix.
129
130
  solver_options: The options to pass to the QP solver.
131
+ **kwargs: Extra arguments which are ignored.
130
132
 
131
133
  Returns:
132
134
  The `RigidContacts` instance.
@@ -173,7 +175,8 @@ class RigidContacts(ContactModel):
173
175
  J_WC: jtp.MatrixLike,
174
176
  data: js.data.JaxSimModelData,
175
177
  ) -> jtp.Vector:
176
- """Returns the new velocity of the system after a potential impact.
178
+ """
179
+ Return the new velocity of the system after a potential impact.
177
180
 
178
181
  Args:
179
182
  inactive_collidable_points: The activation state of the collidable points.
@@ -413,7 +416,8 @@ class RigidContacts(ContactModel):
413
416
  inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
414
417
  ) -> jtp.Matrix:
415
418
  """
416
- Compute the inequality constraint matrix for a single collidable point
419
+ Compute the inequality constraint matrix for a single collidable point.
420
+
417
421
  Rows 0-3: enforce the friction pyramid constraint,
418
422
  Row 4: last one is for the non negativity of the vertical force
419
423
  Row 5: contact complementarity condition
@@ -207,6 +207,7 @@ class SoftContacts(common.ContactModel):
207
207
  model:
208
208
  The robot model considered by the contact model.
209
209
  If passed, it is used to estimate good default parameters.
210
+ **kwargs: Additional parameters to pass to the contact model.
210
211
 
211
212
  Returns:
212
213
  The `SoftContacts` instance.
@@ -244,6 +245,28 @@ class SoftContacts(common.ContactModel):
244
245
  p: jtp.FloatLike = 0.5,
245
246
  q: jtp.FloatLike = 0.5,
246
247
  ) -> tuple[jtp.Vector, jtp.Vector]:
248
+ """
249
+ Compute the contact force using the Hunt/Crossley model.
250
+
251
+ Args:
252
+ position: The position of the collidable point.
253
+ velocity: The velocity of the collidable point.
254
+ tangential_deformation: The material deformation of the collidable point.
255
+ terrain: The terrain model.
256
+ K: The stiffness parameter.
257
+ D: The damping parameter of the soft contacts model.
258
+ mu: The static friction coefficient.
259
+ p:
260
+ The exponent p corresponding to the damping-related non-linearity
261
+ of the Hunt/Crossley model.
262
+ q:
263
+ The exponent q corresponding to the spring-related non-linearity
264
+ of the Hunt/Crossley model
265
+
266
+ Returns:
267
+ A tuple containing the computed contact force and the derivative of the
268
+ material deformation.
269
+ """
247
270
 
248
271
  # Convert the input vectors to arrays.
249
272
  W_p_C = jnp.array(position, dtype=float).squeeze()
@@ -364,6 +387,20 @@ class SoftContacts(common.ContactModel):
364
387
  parameters: SoftContactsParams,
365
388
  terrain: Terrain,
366
389
  ) -> tuple[jtp.Vector, jtp.Vector]:
390
+ """
391
+ Compute the contact force.
392
+
393
+ Args:
394
+ position: The position of the collidable point.
395
+ velocity: The velocity of the collidable point.
396
+ tangential_deformation: The material deformation of the collidable point.
397
+ parameters: The parameters of the soft contacts model.
398
+ terrain: The terrain model.
399
+
400
+ Returns:
401
+ A tuple containing the computed contact force and the derivative of the
402
+ material deformation.
403
+ """
367
404
 
368
405
  CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model(
369
406
  position=position,
@@ -206,6 +206,7 @@ class ViscoElasticContacts(common.ContactModel):
206
206
  If passed, it is used to estimate good default parameters.
207
207
  max_squarings:
208
208
  The maximum number of squarings performed in the matrix exponential.
209
+ **kwargs: Extra arguments to ignore.
209
210
 
210
211
  Returns:
211
212
  The `ViscoElasticContacts` instance.