jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -129
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +87 -16
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +62 -24
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +607 -225
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -80
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -55
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev188.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,18 @@
1
1
  import contextlib
2
2
  import pathlib
3
- from typing import ContextManager
3
+ from collections.abc import Sequence
4
4
 
5
5
  import mediapy as media
6
6
  import mujoco as mj
7
7
  import mujoco.viewer
8
+ import numpy as np
8
9
  import numpy.typing as npt
9
10
 
10
11
 
11
12
  class MujocoVideoRecorder:
12
- """"""
13
+ """
14
+ Video recorder for the MuJoCo passive viewer.
15
+ """
13
16
 
14
17
  def __init__(
15
18
  self,
@@ -20,7 +23,17 @@ class MujocoVideoRecorder:
20
23
  height: int | None = None,
21
24
  **kwargs,
22
25
  ) -> None:
23
- """"""
26
+ """
27
+ Initialize the Mujoco video recorder.
28
+
29
+ Args:
30
+ model: The Mujoco model.
31
+ data: The Mujoco data.
32
+ fps: The frames per second.
33
+ width: The width of the video.
34
+ height: The height of the video.
35
+ **kwargs: Additional arguments for the renderer.
36
+ """
24
37
 
25
38
  width = width if width is not None else model.vis.global_.offwidth
26
39
  height = height if height is not None else model.vis.global_.offheight
@@ -45,31 +58,32 @@ class MujocoVideoRecorder:
45
58
  def reset(
46
59
  self, model: mj.MjModel | None = None, data: mj.MjData | None = None
47
60
  ) -> None:
48
- """"""
61
+ """Reset the model and data."""
49
62
 
50
63
  self.frames = []
51
64
 
52
65
  self.data = data if data is not None else self.data
53
66
  self.model = model if model is not None else self.model
54
67
 
55
- def render_frame(self, camera_name: str | None = None) -> npt.NDArray:
56
- """"""
57
- camera_name = camera_name or "track"
68
+ def render_frame(self, camera_name: str = "track") -> npt.NDArray:
69
+ """Render a frame."""
58
70
 
59
71
  mujoco.mj_forward(self.model, self.data)
60
72
  self.renderer.update_scene(data=self.data, camera=camera_name)
61
73
 
62
74
  return self.renderer.render()
63
75
 
64
- def record_frame(self, camera_name: str | None = None) -> None:
65
- """"""
66
- camera_name = camera_name or "track"
76
+ def record_frame(self, camera_name: str = "track") -> None:
77
+ """Store a frame in the buffer."""
67
78
 
68
79
  frame = self.render_frame(camera_name=camera_name)
69
80
  self.frames.append(frame)
70
81
 
71
82
  def write_video(self, path: pathlib.Path, exist_ok: bool = False) -> None:
72
- """"""
83
+ """Write the video to a file."""
84
+
85
+ # Resolve the path to the video.
86
+ path = path.expanduser().resolve()
73
87
 
74
88
  if path.is_dir():
75
89
  raise IsADirectoryError(f"The path '{path}' is a directory.")
@@ -77,7 +91,7 @@ class MujocoVideoRecorder:
77
91
  if not exist_ok and path.is_file():
78
92
  raise FileExistsError(f"The file '{path}' already exists.")
79
93
 
80
- media.write_video(path=path, images=self.frames, fps=self.fps)
94
+ media.write_video(path=path, images=np.array(self.frames), fps=self.fps)
81
95
 
82
96
  @staticmethod
83
97
  def compute_down_sampling(original_fps: int, target_min_fps: int) -> int:
@@ -105,12 +119,20 @@ class MujocoVideoRecorder:
105
119
 
106
120
 
107
121
  class MujocoVisualizer:
108
- """"""
122
+ """
123
+ Visualizer for the MuJoCo passive viewer.
124
+ """
109
125
 
110
126
  def __init__(
111
127
  self, model: mj.MjModel | None = None, data: mj.MjData | None = None
112
128
  ) -> None:
113
- """"""
129
+ """
130
+ Initialize the Mujoco visualizer.
131
+
132
+ Args:
133
+ model: The Mujoco model.
134
+ data: The Mujoco data.
135
+ """
114
136
 
115
137
  self.data = data
116
138
  self.model = model
@@ -121,7 +143,7 @@ class MujocoVisualizer:
121
143
  model: mj.MjModel | None = None,
122
144
  data: mj.MjData | None = None,
123
145
  ) -> None:
124
- """"""
146
+ """Update the viewer with the current model and data."""
125
147
 
126
148
  data = data if data is not None else self.data
127
149
  model = model if model is not None else self.model
@@ -130,15 +152,18 @@ class MujocoVisualizer:
130
152
  viewer.sync()
131
153
 
132
154
  def open_viewer(
133
- self, model: mj.MjModel | None = None, data: mj.MjData | None = None
155
+ self,
156
+ model: mj.MjModel | None = None,
157
+ data: mj.MjData | None = None,
158
+ show_left_ui: bool = False,
134
159
  ) -> mj.viewer.Handle:
135
- """"""
160
+ """Open a viewer."""
136
161
 
137
162
  data = data if data is not None else self.data
138
163
  model = model if model is not None else self.model
139
164
 
140
165
  handle = mj.viewer.launch_passive(
141
- model, data, show_left_ui=False, show_right_ui=False
166
+ model, data, show_left_ui=show_left_ui, show_right_ui=False
142
167
  )
143
168
 
144
169
  return handle
@@ -148,13 +173,73 @@ class MujocoVisualizer:
148
173
  self,
149
174
  model: mj.MjModel | None = None,
150
175
  data: mj.MjData | None = None,
176
+ *,
177
+ show_left_ui: bool = False,
151
178
  close_on_exit: bool = True,
152
- ) -> ContextManager[mujoco.viewer.Handle]:
153
- """"""
179
+ lookat: Sequence[float | int] | npt.NDArray | None = None,
180
+ distance: float | int | npt.NDArray | None = None,
181
+ azimuth: float | int | npt.NDArray | None = None,
182
+ elevation: float | int | npt.NDArray | None = None,
183
+ ) -> contextlib.AbstractContextManager[mujoco.viewer.Handle]:
184
+ """
185
+ Context manager to open the Mujoco passive viewer.
186
+
187
+ Note:
188
+ Refer to the Mujoco documentation for details of the camera options:
189
+ https://mujoco.readthedocs.io/en/stable/XMLreference.html#visual-global
190
+ """
191
+
192
+ handle = self.open_viewer(model=model, data=data, show_left_ui=show_left_ui)
154
193
 
155
- handle = self.open_viewer(model=model, data=data)
194
+ handle = MujocoVisualizer.setup_viewer_camera(
195
+ viewer=handle,
196
+ lookat=lookat,
197
+ distance=distance,
198
+ azimuth=azimuth,
199
+ elevation=elevation,
200
+ )
156
201
 
157
202
  try:
158
203
  yield handle
159
204
  finally:
160
- handle.close() if close_on_exit else None
205
+ _ = handle.close() if close_on_exit else None
206
+
207
+ @staticmethod
208
+ def setup_viewer_camera(
209
+ viewer: mj.viewer.Handle,
210
+ *,
211
+ lookat: Sequence[float | int] | npt.NDArray | None,
212
+ distance: float | int | npt.NDArray | None = None,
213
+ azimuth: float | int | npt.NDArray | None = None,
214
+ elevation: float | int | npt.NDArray | None = None,
215
+ ) -> mj.viewer.Handle:
216
+ """
217
+ Configure the initial viewpoint of the Mujoco passive viewer.
218
+
219
+ Note:
220
+ Refer to the Mujoco documentation for details of the camera options:
221
+ https://mujoco.readthedocs.io/en/stable/XMLreference.html#visual-global
222
+
223
+ Returns:
224
+ The viewer with configured camera.
225
+ """
226
+
227
+ if lookat is not None:
228
+
229
+ lookat_array = np.array(lookat, dtype=float).squeeze()
230
+
231
+ if lookat_array.size != 3:
232
+ raise ValueError(lookat)
233
+
234
+ viewer.cam.lookat = lookat_array
235
+
236
+ if distance is not None:
237
+ viewer.cam.distance = float(distance)
238
+
239
+ if azimuth is not None:
240
+ viewer.cam.azimuth = float(azimuth) % 360
241
+
242
+ if elevation is not None:
243
+ viewer.cam.elevation = float(elevation)
244
+
245
+ return viewer
@@ -1 +0,0 @@
1
- from . import descriptions, kinematic_graph
@@ -1,4 +1,10 @@
1
- from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision
2
- from .joint import JointDescription, JointDescriptor, JointGenericAxis, JointType
1
+ from .collision import (
2
+ BoxCollision,
3
+ CollidablePoint,
4
+ CollisionShape,
5
+ MeshCollision,
6
+ SphereCollision,
7
+ )
8
+ from .joint import JointDescription, JointGenericAxis, JointType
3
9
  from .link import LinkDescription
4
10
  from .model import ModelDescription
@@ -1,11 +1,13 @@
1
+ from __future__ import annotations
2
+
1
3
  import abc
2
4
  import dataclasses
3
- from typing import List
4
5
 
5
6
  import jax.numpy as jnp
6
7
  import numpy as np
7
8
  import numpy.typing as npt
8
9
 
10
+ import jaxsim.typing as jtp
9
11
  from jaxsim import logging
10
12
 
11
13
  from .link import LinkDescription
@@ -17,10 +19,9 @@ class CollidablePoint:
17
19
  Represents a collidable point associated with a parent link.
18
20
 
19
21
  Attributes:
20
- parent_link (LinkDescription): The parent link to which the collidable point is attached.
21
- position (npt.NDArray): The position of the collidable point relative to the parent link.
22
- enabled (bool): A flag indicating whether the collidable point is enabled for collision detection.
23
-
22
+ parent_link: The parent link to which the collidable point is attached.
23
+ position: The position of the collidable point relative to the parent link.
24
+ enabled: A flag indicating whether the collidable point is enabled for collision detection.
24
25
  """
25
26
 
26
27
  parent_link: LinkDescription
@@ -29,7 +30,7 @@ class CollidablePoint:
29
30
 
30
31
  def change_link(
31
32
  self, new_link: LinkDescription, new_H_old: npt.NDArray
32
- ) -> "CollidablePoint":
33
+ ) -> CollidablePoint:
33
34
  """
34
35
  Move the collidable point to a new parent link.
35
36
 
@@ -39,8 +40,8 @@ class CollidablePoint:
39
40
 
40
41
  Returns:
41
42
  CollidablePoint: A new collidable point associated with the new parent link.
42
-
43
43
  """
44
+
44
45
  msg = f"Moving collidable point: {self.parent_link.name} -> {new_link.name}"
45
46
  logging.debug(msg=msg)
46
47
 
@@ -50,7 +51,24 @@ class CollidablePoint:
50
51
  enabled=self.enabled,
51
52
  )
52
53
 
53
- def __str__(self):
54
+ def __hash__(self) -> int:
55
+
56
+ return hash(
57
+ (
58
+ hash(self.parent_link),
59
+ hash(tuple(self.position.tolist())),
60
+ hash(self.enabled),
61
+ )
62
+ )
63
+
64
+ def __eq__(self, other: CollidablePoint) -> bool:
65
+
66
+ if not isinstance(other, CollidablePoint):
67
+ return False
68
+
69
+ return hash(self) == hash(other)
70
+
71
+ def __str__(self) -> str:
54
72
  return (
55
73
  f"{self.__class__.__name__}("
56
74
  + f"parent_link={self.parent_link.name}"
@@ -66,11 +84,10 @@ class CollisionShape(abc.ABC):
66
84
  Abstract base class for representing collision shapes.
67
85
 
68
86
  Attributes:
69
- collidable_points (List[CollidablePoint]): A list of collidable points associated with the collision shape.
70
-
87
+ collidable_points: A list of collidable points associated with the collision shape.
71
88
  """
72
89
 
73
- collidable_points: List[CollidablePoint]
90
+ collidable_points: tuple[CollidablePoint]
74
91
 
75
92
  def __str__(self):
76
93
  return (
@@ -87,11 +104,25 @@ class BoxCollision(CollisionShape):
87
104
  Represents a box-shaped collision shape.
88
105
 
89
106
  Attributes:
90
- center (npt.NDArray): The center of the box in the local frame of the collision shape.
91
-
107
+ center: The center of the box in the local frame of the collision shape.
92
108
  """
93
109
 
94
- center: npt.NDArray
110
+ center: jtp.VectorLike
111
+
112
+ def __hash__(self) -> int:
113
+ return hash(
114
+ (
115
+ hash(super()),
116
+ hash(tuple(self.center.tolist())),
117
+ )
118
+ )
119
+
120
+ def __eq__(self, other: BoxCollision) -> bool:
121
+
122
+ if not isinstance(other, BoxCollision):
123
+ return False
124
+
125
+ return hash(self) == hash(other)
95
126
 
96
127
 
97
128
  @dataclasses.dataclass
@@ -100,8 +131,48 @@ class SphereCollision(CollisionShape):
100
131
  Represents a spherical collision shape.
101
132
 
102
133
  Attributes:
103
- center (npt.NDArray): The center of the sphere in the local frame of the collision shape.
134
+ center: The center of the sphere in the local frame of the collision shape.
135
+ """
136
+
137
+ center: jtp.VectorLike
138
+
139
+ def __hash__(self) -> int:
140
+ return hash(
141
+ (
142
+ hash(super()),
143
+ hash(tuple(self.center.tolist())),
144
+ )
145
+ )
146
+
147
+ def __eq__(self, other: BoxCollision) -> bool:
148
+
149
+ if not isinstance(other, BoxCollision):
150
+ return False
151
+
152
+ return hash(self) == hash(other)
153
+
104
154
 
155
+ @dataclasses.dataclass
156
+ class MeshCollision(CollisionShape):
157
+ """
158
+ Represents a mesh-shaped collision shape.
159
+
160
+ Attributes:
161
+ center: The center of the mesh in the local frame of the collision shape.
105
162
  """
106
163
 
107
- center: npt.NDArray
164
+ center: jtp.VectorLike
165
+
166
+ def __hash__(self) -> int:
167
+ return hash(
168
+ (
169
+ hash(tuple(self.center.tolist())),
170
+ hash(self.collidable_points),
171
+ )
172
+ )
173
+
174
+ def __eq__(self, other: MeshCollision) -> bool:
175
+ if not isinstance(other, MeshCollision):
176
+ return False
177
+
178
+ return hash(self) == hash(other)
@@ -1,137 +1,130 @@
1
+ from __future__ import annotations
2
+
1
3
  import dataclasses
2
- import enum
3
- from typing import Tuple, Union
4
+ from typing import ClassVar
4
5
 
5
6
  import jax_dataclasses
6
7
  import numpy as np
7
- import numpy.typing as npt
8
8
 
9
+ import jaxsim.typing as jtp
9
10
  from jaxsim.utils import JaxsimDataclass, Mutability
10
11
 
11
12
  from .link import LinkDescription
12
13
 
13
14
 
14
- class JointType(enum.IntEnum):
15
+ @dataclasses.dataclass(frozen=True)
16
+ class JointType:
15
17
  """
16
- Enumeration of joint types for robot joints.
17
-
18
- Args:
19
- F: Fixed joint (no movement).
20
- R: Revolute joint (rotation).
21
- P: Prismatic joint (translation).
22
- Rx: Revolute joint with rotation about the X-axis.
23
- Ry: Revolute joint with rotation about the Y-axis.
24
- Rz: Revolute joint with rotation about the Z-axis.
25
- Px: Prismatic joint with translation along the X-axis.
26
- Py: Prismatic joint with translation along the Y-axis.
27
- Pz: Prismatic joint with translation along the Z-axis.
18
+ Enumeration of joint types.
28
19
  """
29
20
 
30
- F = enum.auto() # Fixed
31
- R = enum.auto() # Revolute
32
- P = enum.auto() # Prismatic
33
-
34
- # Revolute joints, single axis
35
- Rx = enum.auto()
36
- Ry = enum.auto()
37
- Rz = enum.auto()
38
-
39
- # Prismatic joints, single axis
40
- Px = enum.auto()
41
- Py = enum.auto()
42
- Pz = enum.auto()
21
+ Fixed: ClassVar[int] = 0
22
+ Revolute: ClassVar[int] = 1
23
+ Prismatic: ClassVar[int] = 2
43
24
 
44
25
 
45
- @dataclasses.dataclass
46
- class JointDescriptor:
26
+ @jax_dataclasses.pytree_dataclass
27
+ class JointGenericAxis:
47
28
  """
48
- Description of a joint type with a specific code.
49
-
50
- Args:
51
- code (JointType): The code representing the joint type.
52
-
29
+ A joint requiring the specification of a 3D axis.
53
30
  """
54
31
 
55
- code: JointType
32
+ # The axis of rotation or translation of the joint (must have norm 1).
33
+ axis: jtp.Vector
56
34
 
57
35
  def __hash__(self) -> int:
58
- return hash(self.__repr__())
59
-
60
-
61
- @dataclasses.dataclass
62
- class JointGenericAxis(JointDescriptor):
63
- """
64
- Description of a joint type with a generic axis.
65
-
66
- Attributes:
67
- axis (npt.NDArray): The axis of rotation or translation for the joint.
68
36
 
69
- """
70
-
71
- axis: npt.NDArray
37
+ return hash(tuple(self.axis.tolist()))
72
38
 
73
- def __post_init__(self):
74
- if np.allclose(self.axis, 0.0):
75
- raise ValueError(self.axis)
39
+ def __eq__(self, other: JointGenericAxis) -> bool:
76
40
 
77
- def __eq__(self, other):
78
- return super().__eq__(other) and np.allclose(self.axis, other.axis)
41
+ if not isinstance(other, JointGenericAxis):
42
+ return False
79
43
 
80
- def __hash__(self) -> int:
81
- return hash(self.__repr__())
44
+ return hash(self) == hash(other)
82
45
 
83
46
 
84
- @jax_dataclasses.pytree_dataclass
47
+ @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
85
48
  class JointDescription(JaxsimDataclass):
86
49
  """
87
50
  In-memory description of a robot link.
88
51
 
89
52
  Attributes:
90
- name (str): The name of the joint.
91
- axis (npt.NDArray): The axis of rotation or translation for the joint.
92
- pose (npt.NDArray): The pose transformation matrix of the joint.
93
- jtype (Union[JointType, JointDescriptor]): The type of the joint.
94
- child (LinkDescription): The child link attached to the joint.
95
- parent (LinkDescription): The parent link attached to the joint.
96
- index (Optional[int]): An optional index for the joint.
97
- friction_static (float): The static friction coefficient for the joint.
98
- friction_viscous (float): The viscous friction coefficient for the joint.
99
- position_limit_damper (float): The damper coefficient for position limits.
100
- position_limit_spring (float): The spring coefficient for position limits.
101
- position_limit (Tuple[float, float]): The position limits for the joint.
102
- initial_position (Union[float, npt.NDArray]): The initial position of the joint.
103
-
53
+ name: The name of the joint.
54
+ axis: The axis of rotation or translation for the joint.
55
+ pose: The pose transformation matrix of the joint.
56
+ jtype: The type of the joint.
57
+ child: The child link attached to the joint.
58
+ parent: The parent link attached to the joint.
59
+ index: An optional index for the joint.
60
+ friction_static: The static friction coefficient for the joint.
61
+ friction_viscous: The viscous friction coefficient for the joint.
62
+ position_limit_damper: The damper coefficient for position limits.
63
+ position_limit_spring: The spring coefficient for position limits.
64
+ position_limit: The position limits for the joint.
65
+ initial_position: The initial position of the joint.
104
66
  """
105
67
 
106
68
  name: jax_dataclasses.Static[str]
107
- axis: npt.NDArray
108
- pose: npt.NDArray
109
- jtype: jax_dataclasses.Static[Union[JointType, JointDescriptor]]
69
+ axis: jtp.Vector
70
+ pose: jtp.Matrix
71
+ jtype: jax_dataclasses.Static[jtp.IntLike]
110
72
  child: LinkDescription = dataclasses.dataclass(repr=False)
111
73
  parent: LinkDescription = dataclasses.dataclass(repr=False)
112
74
 
113
- index: int | None = None
75
+ index: jtp.IntLike | None = None
76
+
77
+ friction_static: jtp.FloatLike = 0.0
78
+ friction_viscous: jtp.FloatLike = 0.0
114
79
 
115
- friction_static: float = 0.0
116
- friction_viscous: float = 0.0
80
+ position_limit_damper: jtp.FloatLike = 0.0
81
+ position_limit_spring: jtp.FloatLike = 0.0
117
82
 
118
- position_limit_damper: float = 0.0
119
- position_limit_spring: float = 0.0
83
+ position_limit: tuple[jtp.FloatLike, jtp.FloatLike] = (0.0, 0.0)
84
+ initial_position: jtp.FloatLike | jtp.VectorLike = 0.0
120
85
 
121
- position_limit: Tuple[float, float] = (0.0, 0.0)
122
- initial_position: Union[float, npt.NDArray] = 0.0
86
+ motor_inertia: jtp.FloatLike = 0.0
87
+ motor_viscous_friction: jtp.FloatLike = 0.0
88
+ motor_gear_ratio: jtp.FloatLike = 1.0
123
89
 
124
- motor_inertia: float = 0.0
125
- motor_viscous_friction: float = 0.0
126
- motor_gear_ratio: float = 1.0
90
+ def __post_init__(self) -> None:
127
91
 
128
- def __post_init__(self):
129
92
  if self.axis is not None:
93
+
130
94
  with self.mutable_context(
131
95
  mutability=Mutability.MUTABLE, restore_after_exception=False
132
96
  ):
133
97
  norm_of_axis = np.linalg.norm(self.axis)
134
98
  self.axis = self.axis / norm_of_axis
135
99
 
100
+ def __eq__(self, other: JointDescription) -> bool:
101
+
102
+ if not isinstance(other, JointDescription):
103
+ return False
104
+
105
+ return hash(self) == hash(other)
106
+
136
107
  def __hash__(self) -> int:
137
- return hash(self.__repr__())
108
+
109
+ from jaxsim.utils.wrappers import HashedNumpyArray
110
+
111
+ return hash(
112
+ (
113
+ hash(self.name),
114
+ HashedNumpyArray.hash_of_array(self.axis),
115
+ HashedNumpyArray.hash_of_array(self.pose),
116
+ hash(int(self.jtype)),
117
+ hash(self.child),
118
+ hash(self.parent),
119
+ hash(int(self.index)) if self.index is not None else 0,
120
+ HashedNumpyArray.hash_of_array(self.friction_static),
121
+ HashedNumpyArray.hash_of_array(self.friction_viscous),
122
+ HashedNumpyArray.hash_of_array(self.position_limit_damper),
123
+ HashedNumpyArray.hash_of_array(self.position_limit_spring),
124
+ HashedNumpyArray.hash_of_array(self.position_limit),
125
+ HashedNumpyArray.hash_of_array(self.initial_position),
126
+ HashedNumpyArray.hash_of_array(self.motor_inertia),
127
+ HashedNumpyArray.hash_of_array(self.motor_viscous_friction),
128
+ HashedNumpyArray.hash_of_array(self.motor_gear_ratio),
129
+ ),
130
+ )