jaxsim 0.5.1.dev126__py3-none-any.whl → 0.5.1.dev133__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. jaxsim/_version.py +2 -2
  2. jaxsim/api/com.py +1 -1
  3. jaxsim/api/common.py +1 -1
  4. jaxsim/api/contact.py +3 -0
  5. jaxsim/api/data.py +2 -1
  6. jaxsim/api/kin_dyn_parameters.py +18 -1
  7. jaxsim/api/model.py +7 -4
  8. jaxsim/api/ode.py +21 -1
  9. jaxsim/exceptions.py +8 -0
  10. jaxsim/integrators/common.py +60 -2
  11. jaxsim/integrators/fixed_step.py +21 -0
  12. jaxsim/integrators/variable_step.py +44 -0
  13. jaxsim/math/adjoint.py +13 -10
  14. jaxsim/math/cross.py +6 -2
  15. jaxsim/math/inertia.py +8 -4
  16. jaxsim/math/quaternion.py +10 -6
  17. jaxsim/math/rotation.py +6 -3
  18. jaxsim/math/skew.py +2 -2
  19. jaxsim/math/transform.py +3 -0
  20. jaxsim/math/utils.py +2 -2
  21. jaxsim/mujoco/loaders.py +17 -7
  22. jaxsim/mujoco/model.py +15 -15
  23. jaxsim/mujoco/utils.py +6 -1
  24. jaxsim/mujoco/visualizer.py +11 -7
  25. jaxsim/parsers/descriptions/collision.py +7 -4
  26. jaxsim/parsers/descriptions/joint.py +16 -14
  27. jaxsim/parsers/descriptions/model.py +1 -1
  28. jaxsim/parsers/kinematic_graph.py +38 -0
  29. jaxsim/parsers/rod/meshes.py +5 -5
  30. jaxsim/parsers/rod/parser.py +1 -1
  31. jaxsim/parsers/rod/utils.py +11 -0
  32. jaxsim/rbda/contacts/common.py +2 -0
  33. jaxsim/rbda/contacts/relaxed_rigid.py +7 -4
  34. jaxsim/rbda/contacts/rigid.py +8 -4
  35. jaxsim/rbda/contacts/soft.py +37 -0
  36. jaxsim/rbda/contacts/visco_elastic.py +1 -0
  37. jaxsim/terrain/terrain.py +52 -0
  38. jaxsim/utils/jaxsim_dataclass.py +3 -3
  39. jaxsim/utils/tracing.py +2 -2
  40. jaxsim/utils/wrappers.py +9 -0
  41. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev133.dist-info}/METADATA +1 -1
  42. jaxsim-0.5.1.dev133.dist-info/RECORD +74 -0
  43. jaxsim-0.5.1.dev126.dist-info/RECORD +0 -74
  44. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev133.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev133.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.5.1.dev126.dist-info → jaxsim-0.5.1.dev133.dist-info}/top_level.txt +0 -0
jaxsim/math/rotation.py CHANGED
@@ -8,6 +8,9 @@ from .utils import safe_norm
8
8
 
9
9
 
10
10
  class Rotation:
11
+ """
12
+ A utility class for rotation matrix operations.
13
+ """
11
14
 
12
15
  @staticmethod
13
16
  def x(theta: jtp.Float) -> jtp.Matrix:
@@ -15,7 +18,7 @@ class Rotation:
15
18
  Generate a 3D rotation matrix around the X-axis.
16
19
 
17
20
  Args:
18
- theta (jtp.Float): Rotation angle in radians.
21
+ theta: Rotation angle in radians.
19
22
 
20
23
  Returns:
21
24
  jtp.Matrix: 3D rotation matrix.
@@ -29,7 +32,7 @@ class Rotation:
29
32
  Generate a 3D rotation matrix around the Y-axis.
30
33
 
31
34
  Args:
32
- theta (jtp.Float): Rotation angle in radians.
35
+ theta: Rotation angle in radians.
33
36
 
34
37
  Returns:
35
38
  jtp.Matrix: 3D rotation matrix.
@@ -43,7 +46,7 @@ class Rotation:
43
46
  Generate a 3D rotation matrix around the Z-axis.
44
47
 
45
48
  Args:
46
- theta (jtp.Float): Rotation angle in radians.
49
+ theta: Rotation angle in radians.
47
50
 
48
51
  Returns:
49
52
  jtp.Matrix: 3D rotation matrix.
jaxsim/math/skew.py CHANGED
@@ -14,7 +14,7 @@ class Skew:
14
14
  Compute the skew-symmetric matrix (wedge operator) of a 3D vector.
15
15
 
16
16
  Args:
17
- vector (jtp.Vector): A 3D vector.
17
+ vector: A 3D vector.
18
18
 
19
19
  Returns:
20
20
  jtp.Matrix: The skew-symmetric matrix corresponding to the input vector.
@@ -31,7 +31,7 @@ class Skew:
31
31
  Extract the 3D vector from a skew-symmetric matrix (vee operator).
32
32
 
33
33
  Args:
34
- matrix (jtp.Matrix): A 3x3 skew-symmetric matrix.
34
+ matrix: A 3x3 skew-symmetric matrix.
35
35
 
36
36
  Returns:
37
37
  jtp.Vector: The 3D vector extracted from the input matrix.
jaxsim/math/transform.py CHANGED
@@ -5,6 +5,9 @@ import jaxsim.typing as jtp
5
5
 
6
6
 
7
7
  class Transform:
8
+ """
9
+ A utility class for transformation matrix operations.
10
+ """
8
11
 
9
12
  @staticmethod
10
13
  def from_quaternion_and_translation(
jaxsim/math/utils.py CHANGED
@@ -5,8 +5,8 @@ import jaxsim.typing as jtp
5
5
 
6
6
  def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
7
7
  """
8
- Provides a calculation for an array norm so that it is safe
9
- to compute the gradient and handle NaNs.
8
+ Compute an array norm handling NaNs and making sure that
9
+ it is safe to get the gradient.
10
10
 
11
11
  Args:
12
12
  array: The array for which to compute the norm.
jaxsim/mujoco/loaders.py CHANGED
@@ -22,7 +22,7 @@ def load_rod_model(
22
22
  model_name: str | None = None,
23
23
  ) -> rod.Model:
24
24
  """
25
- Loads a ROD model from a URDF/SDF file or a ROD model.
25
+ Load a ROD model from a URDF/SDF file or a ROD model.
26
26
 
27
27
  Args:
28
28
  model_description: The URDF/SDF file or ROD model to load.
@@ -62,14 +62,16 @@ def load_rod_model(
62
62
 
63
63
 
64
64
  class RodModelToMjcf:
65
- """"""
65
+ """
66
+ Class to convert a ROD model to a Mujoco MJCF string.
67
+ """
66
68
 
67
69
  @staticmethod
68
70
  def assets_from_rod_model(
69
71
  rod_model: rod.Model,
70
72
  ) -> dict[str, bytes]:
71
73
  """
72
- Generates a dictionary of assets from a ROD model.
74
+ Generate a dictionary of assets from a ROD model.
73
75
 
74
76
  Args:
75
77
  rod_model: The ROD model to extract the assets from.
@@ -112,7 +114,7 @@ class RodModelToMjcf:
112
114
  floating_joint_name: str = "world_to_base",
113
115
  ) -> str:
114
116
  """
115
- Adds a floating joint to a URDF string.
117
+ Add a floating joint to a URDF string.
116
118
 
117
119
  Args:
118
120
  urdf_string: The URDF string to modify.
@@ -171,7 +173,7 @@ class RodModelToMjcf:
171
173
  cameras: MujocoCameraType = (),
172
174
  ) -> tuple[str, dict[str, Any]]:
173
175
  """
174
- Converts a ROD model to a Mujoco MJCF string.
176
+ Convert a ROD model to a Mujoco MJCF string.
175
177
 
176
178
  Args:
177
179
  rod_model: The ROD model to convert.
@@ -522,6 +524,10 @@ class RodModelToMjcf:
522
524
 
523
525
 
524
526
  class UrdfToMjcf:
527
+ """
528
+ Class to convert a URDF file to a Mujoco MJCF string.
529
+ """
530
+
525
531
  @staticmethod
526
532
  def convert(
527
533
  urdf: str | pathlib.Path,
@@ -532,7 +538,7 @@ class UrdfToMjcf:
532
538
  cameras: MujocoCameraType = (),
533
539
  ) -> tuple[str, dict[str, Any]]:
534
540
  """
535
- Converts a URDF file to a Mujoco MJCF string.
541
+ Convert a URDF file to a Mujoco MJCF string.
536
542
 
537
543
  Args:
538
544
  urdf: The URDF file to convert.
@@ -564,6 +570,10 @@ class UrdfToMjcf:
564
570
 
565
571
 
566
572
  class SdfToMjcf:
573
+ """
574
+ Class to convert a SDF file to a Mujoco MJCF string.
575
+ """
576
+
567
577
  @staticmethod
568
578
  def convert(
569
579
  sdf: str | pathlib.Path,
@@ -574,7 +584,7 @@ class SdfToMjcf:
574
584
  cameras: MujocoCameraType = (),
575
585
  ) -> tuple[str, dict[str, Any]]:
576
586
  """
577
- Converts a SDF file to a Mujoco MJCF string.
587
+ Convert a SDF file to a Mujoco MJCF string.
578
588
 
579
589
  Args:
580
590
  sdf: The SDF file to convert.
jaxsim/mujoco/model.py CHANGED
@@ -254,17 +254,17 @@ class MujocoModelHelper:
254
254
  # ==================
255
255
 
256
256
  def number_of_joints(self) -> int:
257
- """Returns the number of joints in the model."""
257
+ """Return the number of joints in the model."""
258
258
 
259
259
  return self.model.njnt
260
260
 
261
261
  def number_of_dofs(self) -> int:
262
- """Returns the number of DoFs in the model."""
262
+ """Return the number of DoFs in the model."""
263
263
 
264
264
  return self.model.nq
265
265
 
266
266
  def joint_names(self) -> list[str]:
267
- """Returns the names of the joints in the model."""
267
+ """Return the names of the joints in the model."""
268
268
 
269
269
  return [
270
270
  mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx)
@@ -272,7 +272,7 @@ class MujocoModelHelper:
272
272
  ]
273
273
 
274
274
  def joint_dofs(self, joint_name: str) -> int:
275
- """Returns the number of DoFs of a joint."""
275
+ """Return the number of DoFs of a joint."""
276
276
 
277
277
  if joint_name not in self.joint_names():
278
278
  raise ValueError(f"Joint '{joint_name}' not found")
@@ -280,7 +280,7 @@ class MujocoModelHelper:
280
280
  return self.data.joint(joint_name).qpos.size
281
281
 
282
282
  def joint_position(self, joint_name: str) -> npt.NDArray:
283
- """Returns the position of a joint."""
283
+ """Return the position of a joint."""
284
284
 
285
285
  if joint_name not in self.joint_names():
286
286
  raise ValueError(f"Joint '{joint_name}' not found")
@@ -288,7 +288,7 @@ class MujocoModelHelper:
288
288
  return self.data.joint(joint_name).qpos
289
289
 
290
290
  def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray:
291
- """Returns the positions of the joints."""
291
+ """Return the positions of the joints."""
292
292
 
293
293
  joint_names = joint_names if joint_names is not None else self.joint_names()
294
294
 
@@ -299,7 +299,7 @@ class MujocoModelHelper:
299
299
  def set_joint_position(
300
300
  self, joint_name: str, position: npt.NDArray | float
301
301
  ) -> None:
302
- """Sets the position of a joint."""
302
+ """Set the position of a joint."""
303
303
 
304
304
  position = np.atleast_1d(np.array(position).squeeze())
305
305
 
@@ -328,12 +328,12 @@ class MujocoModelHelper:
328
328
  # ==================
329
329
 
330
330
  def number_of_bodies(self) -> int:
331
- """Returns the number of bodies in the model."""
331
+ """Return the number of bodies in the model."""
332
332
 
333
333
  return self.model.nbody
334
334
 
335
335
  def body_names(self) -> list[str]:
336
- """Returns the names of the bodies in the model."""
336
+ """Return the names of the bodies in the model."""
337
337
 
338
338
  return [
339
339
  mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx)
@@ -341,7 +341,7 @@ class MujocoModelHelper:
341
341
  ]
342
342
 
343
343
  def body_position(self, body_name: str) -> npt.NDArray:
344
- """Returns the position of a body."""
344
+ """Return the position of a body."""
345
345
 
346
346
  if body_name not in self.body_names():
347
347
  raise ValueError(f"Body '{body_name}' not found")
@@ -349,7 +349,7 @@ class MujocoModelHelper:
349
349
  return self.data.body(body_name).xpos
350
350
 
351
351
  def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:
352
- """Returns the orientation of a body."""
352
+ """Return the orientation of a body."""
353
353
 
354
354
  if body_name not in self.body_names():
355
355
  raise ValueError(f"Body '{body_name}' not found")
@@ -363,12 +363,12 @@ class MujocoModelHelper:
363
363
  # ======================
364
364
 
365
365
  def number_of_geometries(self) -> int:
366
- """Returns the number of geometries in the model."""
366
+ """Return the number of geometries in the model."""
367
367
 
368
368
  return self.model.ngeom
369
369
 
370
370
  def geometry_names(self) -> list[str]:
371
- """Returns the names of the geometries in the model."""
371
+ """Return the names of the geometries in the model."""
372
372
 
373
373
  return [
374
374
  mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx)
@@ -376,7 +376,7 @@ class MujocoModelHelper:
376
376
  ]
377
377
 
378
378
  def geometry_position(self, geometry_name: str) -> npt.NDArray:
379
- """Returns the position of a geometry."""
379
+ """Return the position of a geometry."""
380
380
 
381
381
  if geometry_name not in self.geometry_names():
382
382
  raise ValueError(f"Geometry '{geometry_name}' not found")
@@ -386,7 +386,7 @@ class MujocoModelHelper:
386
386
  def geometry_orientation(
387
387
  self, geometry_name: str, dcm: bool = False
388
388
  ) -> npt.NDArray:
389
- """Returns the orientation of a geometry."""
389
+ """Return the orientation of a geometry."""
390
390
 
391
391
  if geometry_name not in self.geometry_names():
392
392
  raise ValueError(f"Geometry '{geometry_name}' not found")
jaxsim/mujoco/utils.py CHANGED
@@ -133,6 +133,9 @@ class MujocoCamera:
133
133
 
134
134
  @classmethod
135
135
  def build(cls, **kwargs) -> MujocoCamera:
136
+ """
137
+ Build a Mujoco camera from a dictionary.
138
+ """
136
139
 
137
140
  if not all(isinstance(value, str) for value in kwargs.values()):
138
141
  raise ValueError(f"Values must be strings: {kwargs}")
@@ -219,5 +222,7 @@ class MujocoCamera:
219
222
  )
220
223
 
221
224
  def asdict(self) -> dict[str, str]:
222
-
225
+ """
226
+ Convert the camera to a dictionary.
227
+ """
223
228
  return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
jaxsim/mujoco/visualizer.py CHANGED
@@ -10,7 +10,9 @@ import numpy.typing as npt
10
10
 
11
11
 
12
12
  class MujocoVideoRecorder:
13
- """"""
13
+ """
14
+ Video recorder for the MuJoCo passive viewer.
15
+ """
14
16
 
15
17
  def __init__(
16
18
  self,
@@ -64,7 +66,7 @@ class MujocoVideoRecorder:
64
66
  self.model = model if model is not None else self.model
65
67
 
66
68
  def render_frame(self, camera_name: str = "track") -> npt.NDArray:
67
- """Renders a frame."""
69
+ """Render a frame."""
68
70
 
69
71
  mujoco.mj_forward(self.model, self.data)
70
72
  self.renderer.update_scene(data=self.data, camera=camera_name)
@@ -72,13 +74,13 @@ class MujocoVideoRecorder:
72
74
  return self.renderer.render()
73
75
 
74
76
  def record_frame(self, camera_name: str = "track") -> None:
75
- """Stores a frame in the buffer."""
77
+ """Store a frame in the buffer."""
76
78
 
77
79
  frame = self.render_frame(camera_name=camera_name)
78
80
  self.frames.append(frame)
79
81
 
80
82
  def write_video(self, path: pathlib.Path, exist_ok: bool = False) -> None:
81
- """Writes the video to a file."""
83
+ """Write the video to a file."""
82
84
 
83
85
  # Resolve the path to the video.
84
86
  path = path.expanduser().resolve()
@@ -117,7 +119,9 @@ class MujocoVideoRecorder:
117
119
 
118
120
 
119
121
  class MujocoVisualizer:
120
- """"""
122
+ """
123
+ Visualizer for the MuJoCo passive viewer.
124
+ """
121
125
 
122
126
  def __init__(
123
127
  self, model: mj.MjModel | None = None, data: mj.MjData | None = None
@@ -139,7 +143,7 @@ class MujocoVisualizer:
139
143
  model: mj.MjModel | None = None,
140
144
  data: mj.MjData | None = None,
141
145
  ) -> None:
142
- """Updates the viewer with the current model and data."""
146
+ """Update the viewer with the current model and data."""
143
147
 
144
148
  data = data if data is not None else self.data
145
149
  model = model if model is not None else self.model
@@ -150,7 +154,7 @@ class MujocoVisualizer:
150
154
  def open_viewer(
151
155
  self, model: mj.MjModel | None = None, data: mj.MjData | None = None
152
156
  ) -> mj.viewer.Handle:
153
- """Opens a viewer."""
157
+ """Open a viewer."""
154
158
 
155
159
  data = data if data is not None else self.data
156
160
  model = model if model is not None else self.model
jaxsim/parsers/descriptions/collision.py CHANGED
@@ -22,7 +22,6 @@ class CollidablePoint:
22
22
  parent_link: The parent link to which the collidable point is attached.
23
23
  position: The position of the collidable point relative to the parent link.
24
24
  enabled: A flag indicating whether the collidable point is enabled for collision detection.
25
-
26
25
  """
27
26
 
28
27
  parent_link: LinkDescription
@@ -86,7 +85,6 @@ class CollisionShape(abc.ABC):
86
85
 
87
86
  Attributes:
88
87
  collidable_points: A list of collidable points associated with the collision shape.
89
-
90
88
  """
91
89
 
92
90
  collidable_points: tuple[CollidablePoint]
@@ -107,7 +105,6 @@ class BoxCollision(CollisionShape):
107
105
 
108
106
  Attributes:
109
107
  center: The center of the box in the local frame of the collision shape.
110
-
111
108
  """
112
109
 
113
110
  center: jtp.VectorLike
@@ -135,7 +132,6 @@ class SphereCollision(CollisionShape):
135
132
 
136
133
  Attributes:
137
134
  center: The center of the sphere in the local frame of the collision shape.
138
-
139
135
  """
140
136
 
141
137
  center: jtp.VectorLike
@@ -158,6 +154,13 @@ class SphereCollision(CollisionShape):
158
154
 
159
155
  @dataclasses.dataclass
160
156
  class MeshCollision(CollisionShape):
157
+ """
158
+ Represents a mesh-shaped collision shape.
159
+
160
+ Attributes:
161
+ center: The center of the mesh in the local frame of the collision shape.
162
+ """
163
+
161
164
  center: jtp.VectorLike
162
165
 
163
166
  def __hash__(self) -> int:
jaxsim/parsers/descriptions/joint.py CHANGED
@@ -14,6 +14,9 @@ from .link import LinkDescription
14
14
 
15
15
  @dataclasses.dataclass(frozen=True)
16
16
  class JointType:
17
+ """
18
+ Enumeration of joint types.
19
+ """
17
20
 
18
21
  Fixed: ClassVar[int] = 0
19
22
  Revolute: ClassVar[int] = 1
@@ -47,20 +50,19 @@ class JointDescription(JaxsimDataclass):
47
50
  In-memory description of a robot link.
48
51
 
49
52
  Attributes:
50
- name (str): The name of the joint.
51
- axis (npt.NDArray): The axis of rotation or translation for the joint.
52
- pose (npt.NDArray): The pose transformation matrix of the joint.
53
- jtype (JointType): The type of the joint.
54
- child (LinkDescription): The child link attached to the joint.
55
- parent (LinkDescription): The parent link attached to the joint.
56
- index (Optional[int]): An optional index for the joint.
57
- friction_static (float): The static friction coefficient for the joint.
58
- friction_viscous (float): The viscous friction coefficient for the joint.
59
- position_limit_damper (float): The damper coefficient for position limits.
60
- position_limit_spring (float): The spring coefficient for position limits.
61
- position_limit (Tuple[float, float]): The position limits for the joint.
62
- initial_position (Union[float, npt.NDArray]): The initial position of the joint.
63
-
53
+ name: The name of the joint.
54
+ axis: The axis of rotation or translation for the joint.
55
+ pose: The pose transformation matrix of the joint.
56
+ jtype: The type of the joint.
57
+ child: The child link attached to the joint.
58
+ parent: The parent link attached to the joint.
59
+ index: An optional index for the joint.
60
+ friction_static: The static friction coefficient for the joint.
61
+ friction_viscous: The viscous friction coefficient for the joint.
62
+ position_limit_damper: The damper coefficient for position limits.
63
+ position_limit_spring: The spring coefficient for position limits.
64
+ position_limit: The position limits for the joint.
65
+ initial_position: The initial position of the joint.
64
66
  """
65
67
 
66
68
  name: jax_dataclasses.Static[str]
jaxsim/parsers/descriptions/model.py CHANGED
@@ -158,7 +158,7 @@ class ModelDescription(KinematicGraph):
158
158
  Reduce the model by removing specified joints.
159
159
 
160
160
  Args:
161
- The joint names to consider.
161
+ considered_joints: Sequence of joint names to consider.
162
162
 
163
163
  Returns:
164
164
  A `ModelDescription` instance that only includes the considered joints.
jaxsim/parsers/kinematic_graph.py CHANGED
@@ -97,20 +97,32 @@ class KinematicGraph(Sequence[LinkDescription]):
97
97
 
98
98
  @functools.cached_property
99
99
  def links_dict(self) -> dict[str, LinkDescription]:
100
+ """
101
+ Get a dictionary of links indexed by their name.
102
+ """
100
103
  return {l.name: l for l in iter(self)}
101
104
 
102
105
  @functools.cached_property
103
106
  def frames_dict(self) -> dict[str, LinkDescription]:
107
+ """
108
+ Get a dictionary of frames indexed by their name.
109
+ """
104
110
  return {f.name: f for f in self.frames}
105
111
 
106
112
  @functools.cached_property
107
113
  def joints_dict(self) -> dict[str, JointDescription]:
114
+ """
115
+ Get a dictionary of joints indexed by their name.
116
+ """
108
117
  return {j.name: j for j in self.joints}
109
118
 
110
119
  @functools.cached_property
111
120
  def joints_connection_dict(
112
121
  self,
113
122
  ) -> dict[tuple[str, str], JointDescription]:
123
+ """
124
+ Get a dictionary of joints indexed by the tuple (parent, child) link names.
125
+ """
114
126
  return {(j.parent.name, j.child.name): j for j in self.joints}
115
127
 
116
128
  def __post_init__(self) -> None:
@@ -734,9 +746,15 @@ class KinematicGraph(Sequence[LinkDescription]):
734
746
  raise TypeError(type(key).__name__)
735
747
 
736
748
  def count(self, value: LinkDescription) -> int:
749
+ """
750
+ Count the occurrences of a link in the kinematic graph.
751
+ """
737
752
  return list(iter(self)).count(value)
738
753
 
739
754
  def index(self, value: LinkDescription, start: int = 0, stop: int = -1) -> int:
755
+ """
756
+ Find the index of a link in the kinematic graph.
757
+ """
740
758
  return list(iter(self)).index(value, start, stop)
741
759
 
742
760
 
@@ -747,6 +765,12 @@ class KinematicGraph(Sequence[LinkDescription]):
747
765
 
748
766
  @dataclasses.dataclass(frozen=True)
749
767
  class KinematicGraphTransforms:
768
+ """
769
+ Class to compute forward kinematics on a kinematic graph.
770
+
771
+ Attributes:
772
+ graph: The kinematic graph on which to compute forward kinematics.
773
+ """
750
774
 
751
775
  graph: KinematicGraph
752
776
 
@@ -767,6 +791,9 @@ class KinematicGraphTransforms:
767
791
 
768
792
  @property
769
793
  def initial_joint_positions(self) -> npt.NDArray:
794
+ """
795
+ Get the initial joint positions of the kinematic graph.
796
+ """
770
797
 
771
798
  return np.atleast_1d(
772
799
  np.array(list(self._initial_joint_positions.values()))
@@ -910,6 +937,17 @@ class KinematicGraphTransforms:
910
937
  joint_axis: npt.NDArray,
911
938
  joint_position: float | None = None,
912
939
  ) -> npt.NDArray:
940
+ """
941
+ Compute the SE(3) transform from the predecessor to the successor frame.
942
+
943
+ Args:
944
+ joint_type: The type of the joint.
945
+ joint_axis: The axis of the joint.
946
+ joint_position: The position of the joint.
947
+
948
+ Returns:
949
+ The 4x4 transform matrix from the predecessor to the successor frame.
950
+ """
913
951
 
914
952
  import jaxsim.math
915
953
 
jaxsim/parsers/rod/meshes.py CHANGED
@@ -6,14 +6,14 @@ VALID_AXIS = {"x": 0, "y": 1, "z": 2}
6
6
 
7
7
  def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray:
8
8
  """
9
- Extracts the vertices of a mesh as points.
9
+ Extract the vertices of a mesh as points.
10
10
  """
11
11
  return mesh.vertices
12
12
 
13
13
 
14
14
  def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray:
15
15
  """
16
- Extracts N random points from the surface of a mesh.
16
+ Extract N random points from the surface of a mesh.
17
17
 
18
18
  Args:
19
19
  mesh: The mesh from which to extract points.
@@ -30,7 +30,7 @@ def extract_points_uniform_surface_sampling(
30
30
  mesh: trimesh.Trimesh, n: int
31
31
  ) -> np.ndarray:
32
32
  """
33
- Extracts N uniformly sampled points from the surface of a mesh.
33
+ Extract N uniformly sampled points from the surface of a mesh.
34
34
 
35
35
  Args:
36
36
  mesh: The mesh from which to extract points.
@@ -47,7 +47,7 @@ def extract_points_select_points_over_axis(
47
47
  mesh: trimesh.Trimesh, axis: str, direction: str, n: int
48
48
  ) -> np.ndarray:
49
49
  """
50
- Extracts N points from a mesh along a specified axis. The points are selected based on their position along the axis.
50
+ Extract N points from a mesh along a specified axis. The points are selected based on their position along the axis.
51
51
 
52
52
  Args:
53
53
  mesh: The mesh from which to extract points.
@@ -75,7 +75,7 @@ def extract_points_aap(
75
75
  lower: float | None = None,
76
76
  ) -> np.ndarray:
77
77
  """
78
- Extracts points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis.
78
+ Extract points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis.
79
79
 
80
80
  Args:
81
81
  mesh: The mesh from which to extract points.
jaxsim/parsers/rod/parser.py CHANGED
@@ -364,7 +364,7 @@ def build_model_description(
364
364
  is_urdf: bool | None = None,
365
365
  ) -> descriptions.ModelDescription:
366
366
  """
367
- Builds a model description from an SDF/URDF resource.
367
+ Build a model description from an SDF/URDF resource.
368
368
 
369
369
  Args:
370
370
  model_description: A path to an SDF/URDF file, a string containing its content,
jaxsim/parsers/rod/utils.py CHANGED
@@ -223,6 +223,17 @@ def create_mesh_collision(
223
223
  link_description: descriptions.LinkDescription,
224
224
  method: MeshMappingMethod = None,
225
225
  ) -> descriptions.MeshCollision:
226
+ """
227
+ Create a mesh collision from an SDF collision element.
228
+
229
+ Args:
230
+ collision: The SDF collision element.
231
+ link_description: The link description.
232
+ method: The method to use for mesh wrapping.
233
+
234
+ Returns:
235
+ The mesh collision description.
236
+ """
226
237
 
227
238
  file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri))
228
239
  file_type = file.suffix.replace(".", "")
jaxsim/rbda/contacts/common.py CHANGED
@@ -125,6 +125,7 @@ class ContactModel(JaxsimDataclass):
125
125
  Args:
126
126
  model: The robot model considered by the contact model.
127
127
  data: The data of the considered model.
128
+ **kwargs: Optional additional arguments, specific to the contact model.
128
129
 
129
130
  Returns:
130
131
  A tuple containing as first element the computed 6D contact force applied to
@@ -146,6 +147,7 @@ class ContactModel(JaxsimDataclass):
146
147
  Args:
147
148
  model: The robot model considered by the contact model.
148
149
  data: The data of the considered model.
150
+ **kwargs: Optional additional arguments, specific to the contact model.
149
151
 
150
152
  Returns:
151
153
  A tuple containing as first element the 6D contact force applied to the