jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/mujoco/model.py ADDED
@@ -0,0 +1,414 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import pathlib
5
+ from typing import Any, Callable
6
+
7
+ import mujoco as mj
8
+ import numpy as np
9
+ import numpy.typing as npt
10
+ from scipy.spatial.transform import Rotation
11
+
12
+ import jaxsim.typing as jtp
13
+
14
# Type of a terrain heightmap: a callable taking the planar coordinates (x, y)
# and returning the terrain height at that point.
HeightmapCallable = Callable[[jtp.FloatLike, jtp.FloatLike], jtp.FloatLike]
15
+
16
+
17
class MujocoModelHelper:
    """
    Helper class to create and interact with Mujoco models and data objects.

    The helper wraps an `mj.MjModel` together with an `mj.MjData` and exposes
    convenience getters and setters for the base, the joints, the bodies and
    the geometries of the model.
    """

    def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -> None:
        """
        Initialize the MujocoModelHelper object.

        Args:
            model: A Mujoco model object.
            data: A Mujoco data object. If None, a new one will be created.
        """

        self.model = model
        self.data = data if data is not None else mj.MjData(self.model)

        # Populate the data with kinematics.
        mj.mj_forward(self.model, self.data)

        # Keep the cache of this method local to the instance to improve GC.
        self.mask_qpos = functools.cache(self._mask_qpos)

    @staticmethod
    def build_from_xml(
        mjcf_description: str | pathlib.Path,
        assets: dict[str, Any] | None = None,
        heightmap: HeightmapCallable | None = None,
    ) -> MujocoModelHelper:
        """
        Build a Mujoco model from an XML description and an optional assets dictionary.

        Args:
            mjcf_description: A string containing the XML description of the Mujoco model
                or a path to a file containing the XML description.
            assets: An optional dictionary containing the assets of the model.
            heightmap: A function in two variables that returns the height of a terrain
                in the specified coordinate point. It is sampled to fill the data of
                the single height field defined in the model.

        Returns:
            A MujocoModelHelper object.

        Raises:
            ValueError: If a heightmap is given but the model does not define
                exactly one height field.
        """

        # Read the XML description if it's a path to file.
        if isinstance(mjcf_description, pathlib.Path):
            mjcf_description = mjcf_description.read_text()

        # Create the Mujoco model from the XML and, optionally, the assets dictionary.
        model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets)  # noqa

        if heightmap is not None:
            # The heightmap is sampled on the grid of a single height field.
            if model.nhfield != 1:
                raise ValueError(
                    f"A heightmap requires exactly one height field, found {model.nhfield}"
                )

            nrow = model.hfield_nrow.item()
            ncol = model.hfield_ncol.item()
            model.hfield_data = generate_hfield(heightmap, (nrow, ncol))

        return MujocoModelHelper(model=model, data=mj.MjData(model))

    def time(self) -> float:
        """Return the simulation time."""

        return self.data.time

    def timestep(self) -> float:
        """Return the simulation timestep."""

        return self.model.opt.timestep

    def gravity(self) -> npt.NDArray:
        """Return the 3D gravity vector."""

        return self.model.opt.gravity

    # =========================
    # Methods for the base link
    # =========================

    def is_floating_base(self) -> bool:
        """Return true if the model is floating-base."""

        # A body with no joints is considered a fixed-base model.
        # In fact, in mujoco, a floating-base model has a 6 DoFs first joint.
        if self.number_of_joints() == 0:
            return False

        # We just check that the first joint is a free (6 DoFs) joint.
        joint0_type = self.model.jnt_type[0]
        return joint0_type == mj.mjtJoint.mjJNT_FREE

    def is_fixed_base(self) -> bool:
        """Return true if the model is fixed-base."""

        return not self.is_floating_base()

    def base_link(self) -> str:
        """
        Return the name of the base link.

        This is the body with index 0 for fixed-base models and the body with
        index 1 for floating-base models.
        """

        return mj.mj_id2name(
            self.model, mj.mjtObj.mjOBJ_BODY, 0 if self.is_fixed_base() else 1
        )

    def base_position(self) -> npt.NDArray:
        """Return the 3D position of the base link."""

        return (
            self.data.qpos[:3]
            if self.is_floating_base()
            else self.body_position(body_name=self.base_link())
        )

    def base_orientation(self, dcm: bool = False) -> npt.NDArray:
        """
        Return the orientation of the base link.

        Args:
            dcm: Whether to return a 3x3 rotation matrix instead of a quaternion.

        Returns:
            The base orientation as a wxyz quaternion of shape (4,), or as a
            rotation matrix of shape (3, 3) if `dcm` is True.
        """

        if self.is_fixed_base():
            return self.body_orientation(body_name=self.base_link(), dcm=dcm)

        # In a floating-base model the free joint stores the base quaternion (wxyz)
        # in qpos[3:7], the same slots written by `set_base_orientation`.
        # Body 0 is the Mujoco world body, so its xquat/xmat never describe the base.
        W_Q_B = self.data.qpos[3:7]

        if not dcm:
            return W_Q_B

        # Scipy expects the quaternion in xyzw (scalar-last) order.
        return Rotation.from_quat(W_Q_B[[1, 2, 3, 0]]).as_matrix()

    def set_base_position(self, position: npt.NDArray) -> None:
        """
        Set the 3D position of the base link.

        Raises:
            ValueError: If the model is fixed-base or the position is not 3D.
        """

        if self.is_fixed_base():
            raise ValueError("The position of a fixed-base model cannot be set.")

        position = np.atleast_1d(np.array(position).squeeze())

        if position.size != 3:
            raise ValueError(f"Wrong position size ({position.size})")

        self.data.qpos[:3] = position

    def set_base_orientation(self, orientation: npt.NDArray, dcm: bool = False) -> None:
        """
        Set the orientation of the base link.

        Args:
            orientation: Either a unit wxyz quaternion or, if `dcm` is True,
                a 3x3 rotation matrix.
            dcm: Whether `orientation` is a rotation matrix.

        Raises:
            ValueError: If the model is fixed-base, or the orientation has the
                wrong shape or is not a valid rotation.
        """

        if self.is_fixed_base():
            raise ValueError("The orientation of a fixed-base model cannot be set.")

        orientation = (
            np.atleast_2d(np.array(orientation).squeeze())
            if dcm
            else np.atleast_1d(np.array(orientation).squeeze())
        )

        if orientation.shape != ((4,) if not dcm else (3, 3)):
            raise ValueError(f"Wrong orientation shape {orientation.shape}")

        def is_quaternion(Q):
            return np.allclose(np.linalg.norm(Q), 1.0)

        def is_dcm(R):
            return np.allclose(np.linalg.det(R), 1.0) and np.allclose(
                R.T @ R, np.eye(3)
            )

        if not (is_quaternion(orientation) if not dcm else is_dcm(orientation)):
            raise ValueError("The orientation is not a valid element of SO(3)")

        # Scipy returns xyzw quaternions, while Mujoco stores them as wxyz.
        W_Q_B = (
            Rotation.from_matrix(orientation).as_quat(canonical=True)[
                np.array([3, 0, 1, 2])
            ]
            if dcm
            else orientation
        )

        self.data.qpos[3:7] = W_Q_B

    # ==================
    # Methods for joints
    # ==================

    def number_of_joints(self) -> int:
        """Returns the number of joints in the model."""

        return self.model.njnt

    def number_of_dofs(self) -> int:
        """
        Returns the number of position coordinates of the model (`model.nq`).

        Note that this is the size of `qpos`, which is larger than the number of
        velocity DoFs when the model has free or ball joints.
        """

        return self.model.nq

    def joint_names(self) -> list[str]:
        """Returns the names of the joints in the model, excluding the free joint."""

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx)
            for idx in range(0 if self.is_fixed_base() else 1, self.number_of_joints())
        ]

    def joint_dofs(self, joint_name: str) -> int:
        """Returns the number of position coordinates (`qpos` entries) of a joint."""

        if joint_name not in self.joint_names():
            raise ValueError(f"Joint '{joint_name}' not found")

        return self.data.joint(joint_name).qpos.size

    def joint_position(self, joint_name: str) -> npt.NDArray:
        """Returns the position of a joint."""

        if joint_name not in self.joint_names():
            raise ValueError(f"Joint '{joint_name}' not found")

        return self.data.joint(joint_name).qpos

    def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray:
        """
        Returns the positions of the joints, concatenated in a single 1D array.

        Args:
            joint_names: The joints to read. Defaults to all the joints.
        """

        joint_names = joint_names if joint_names is not None else self.joint_names()

        # np.hstack cannot concatenate an empty sequence.
        if len(joint_names) == 0:
            return np.array([])

        return np.hstack(
            [self.joint_position(joint_name) for joint_name in joint_names]
        )

    def set_joint_position(
        self, joint_name: str, position: npt.NDArray | float
    ) -> None:
        """
        Sets the position of a joint.

        Raises:
            ValueError: If the joint does not exist or the size of `position`
                does not match the joint's number of position coordinates.
        """

        position = np.atleast_1d(np.array(position).squeeze())
        dofs = self.joint_dofs(joint_name=joint_name)

        if position.size != dofs:
            raise ValueError(
                f"Wrong position size ({position.size}) of "
                f"{dofs}-DoFs joint '{joint_name}'."
            )

        idx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)
        offset = self.model.jnt_qposadr[idx]

        self.data.qpos[offset : offset + dofs] = position

    def set_joint_positions(
        self, joint_names: list[str], positions: npt.NDArray | list[npt.NDArray]
    ) -> None:
        """
        Set the positions of multiple joints.

        Args:
            joint_names: The names of the joints.
            positions: Either a flat array with one entry per `qpos` element of
                the joints, or a list holding one array (or scalar) per joint.
        """

        mask = self.mask_qpos(joint_names=tuple(joint_names))

        # Flatten a per-joint list, possibly holding multi-element arrays,
        # into a single array matching the layout of the mask.
        if isinstance(positions, (list, tuple)):
            positions = (
                np.hstack([np.ravel(p) for p in positions])
                if len(positions) > 0
                else np.array([])
            )

        self.data.qpos[mask] = positions

    # ==================
    # Methods for bodies
    # ==================

    def number_of_bodies(self) -> int:
        """Returns the number of bodies in the model."""

        return self.model.nbody

    def body_names(self) -> list[str]:
        """Returns the names of the bodies in the model."""

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx)
            for idx in range(self.number_of_bodies())
        ]

    def body_position(self, body_name: str) -> npt.NDArray:
        """Returns the position of a body."""

        if body_name not in self.body_names():
            raise ValueError(f"Body '{body_name}' not found")

        return self.data.body(body_name).xpos

    def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:
        """
        Returns the orientation of a body.

        Args:
            body_name: The name of the body.
            dcm: Whether to return a (3, 3) rotation matrix instead of the
                wxyz quaternion.
        """

        if body_name not in self.body_names():
            raise ValueError(f"Body '{body_name}' not found")

        body = self.data.body(body_name)

        # `xmat` is stored flattened; reshape it like `geometry_orientation` does.
        return np.reshape(body.xmat, (3, 3)) if dcm else body.xquat

    # ======================
    # Methods for geometries
    # ======================

    def number_of_geometries(self) -> int:
        """Returns the number of geometries in the model."""

        return self.model.ngeom

    def geometry_names(self) -> list[str]:
        """Returns the names of the geometries in the model."""

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx)
            for idx in range(self.number_of_geometries())
        ]

    def geometry_position(self, geometry_name: str) -> npt.NDArray:
        """Returns the position of a geometry."""

        if geometry_name not in self.geometry_names():
            raise ValueError(f"Geometry '{geometry_name}' not found")

        return self.data.geom(geometry_name).xpos

    def geometry_orientation(
        self, geometry_name: str, dcm: bool = False
    ) -> npt.NDArray:
        """Returns the orientation of a geometry (wxyz quaternion or 3x3 matrix)."""

        if geometry_name not in self.geometry_names():
            raise ValueError(f"Geometry '{geometry_name}' not found")

        R = np.reshape(self.data.geom(geometry_name).xmat, (3, 3))

        if dcm:
            return R

        # Scipy returns xyzw quaternions; convert to the wxyz convention.
        q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True)
        return q_xyzw[[3, 0, 1, 2]]

    # ===============
    # Private methods
    # ===============

    def _mask_qpos(self, joint_names: tuple[str, ...]) -> npt.NDArray:
        """
        Create a mask to access the DoFs of the desired `joint_names` in the `qpos` array.

        Args:
            joint_names: A tuple containing the names of the joints.

        Returns:
            A 1D integer array containing the indices of the `qpos` array to access
            the DoFs of the desired `joint_names` (empty if no joint is given).

        Raises:
            ValueError: If any of the joints is not found.

        Note:
            This method takes a tuple of strings because we cache the output mask for
            each combination of joint names. We need a hashable object for the cache.
        """

        chunks = []

        # Each joint contributes the contiguous range of qpos indices starting at
        # its qpos address; multi-DoF joints span more than one entry.
        for joint_name in joint_names:
            dofs = self.joint_dofs(joint_name=joint_name)
            idx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)
            start = self.model.jnt_qposadr[idx]
            chunks.append(np.arange(start, start + dofs))

        # np.hstack cannot concatenate an empty sequence.
        if not chunks:
            return np.array([], dtype=int)

        return np.hstack(chunks)
385
+
386
+
387
def generate_hfield(
    heightmap: HeightmapCallable, size: tuple[int, int] = (10, 10)
) -> npt.NDArray:
    """
    Generates a flattened numpy array sampling a heightmap on a regular grid.

    The grid spans [0, 1] along both axes, with ``size[0]`` rows sampling the
    y axis and ``size[1]`` columns sampling the x axis. Before flattening
    (row-major), the map has the following format:
    ```
    heightmap[0, 0]          heightmap[0, 1]          ... heightmap[0, size[1]-1]
    heightmap[1, 0]          heightmap[1, 1]          ... heightmap[1, size[1]-1]
    ...
    heightmap[size[0]-1, 0]  heightmap[size[0]-1, 1]  ... heightmap[size[0]-1, size[1]-1]
    ```
    where ``heightmap[i, j]`` is the height evaluated at ``(x_j, y_i)``.

    Args:
        heightmap: A function that takes two arguments (x, y) and returns the height
            at that point.
        size: A tuple of two integers (rows, columns) representing the size of the grid.

    Returns:
        np.ndarray: The terrain heightmap, flattened to ``size[0] * size[1]`` entries.
    """

    rows, cols = size

    # Columns sample the x axis and rows sample the y axis, so that each of the
    # `rows` output rows contains `cols` values as documented above.
    x = np.linspace(0, 1, cols)
    y = np.linspace(0, 1, rows)

    # Generate the heightmap row by row.
    return np.array([[heightmap(xi, yi) for xi in x] for yi in y]).flatten()
jaxsim/mujoco/visualizer.py ADDED
@@ -0,0 +1,176 @@
1
+ import contextlib
2
+ import pathlib
3
+ from typing import ContextManager
4
+
5
+ import mediapy as media
6
+ import mujoco as mj
7
+ import mujoco.viewer
8
+ import numpy.typing as npt
9
+
10
+
11
class MujocoVideoRecorder:
    """
    Record offscreen renderings of a Mujoco simulation and write them to a video.
    """

    def __init__(
        self,
        model: mj.MjModel,
        data: mj.MjData,
        fps: int = 30,
        width: int | None = None,
        height: int | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the Mujoco video recorder.

        Args:
            model: The Mujoco model. Its offscreen width and height are updated
                in place to match the requested video size.
            data: The Mujoco data.
            fps: The frames per second.
            width: The width of the video. Defaults to the model's offscreen width.
            height: The height of the video. Defaults to the model's offscreen height.
            **kwargs: Additional arguments for the renderer.
        """

        width = width if width is not None else model.vis.global_.offwidth
        height = height if height is not None else model.vis.global_.offheight

        # Resize the model's offscreen buffer so that the renderer can produce
        # frames of the requested size.
        if model.vis.global_.offwidth != width:
            model.vis.global_.offwidth = width

        if model.vis.global_.offheight != height:
            model.vis.global_.offheight = height

        self.fps = fps
        self.frames: list[npt.NDArray] = []
        self.data: mj.MjData | None = None
        self.model: mj.MjModel | None = None
        self.reset(model=model, data=data)

        # Entries of `kwargs` take precedence over the computed width and height.
        self.renderer = mj.Renderer(
            model=self.model,
            **(dict(width=width, height=height) | kwargs),
        )

    def reset(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> None:
        """
        Clear the recorded frames and, optionally, replace the model and data.

        Args:
            model: The new Mujoco model. If None, the current one is kept.
            data: The new Mujoco data. If None, the current one is kept.
        """

        self.frames = []

        # NOTE(review): the renderer created in __init__ is not rebuilt here,
        # so it keeps referring to the original model; confirm this is intended.
        self.data = data if data is not None else self.data
        self.model = model if model is not None else self.model

    def render_frame(self, camera_name: str | None = None) -> npt.NDArray:
        """
        Render a frame.

        Args:
            camera_name: The camera to render from. Defaults to "track".

        Returns:
            The frame returned by the renderer.
        """

        camera_name = camera_name or "track"

        # Update the derived quantities of the data before rendering.
        mj.mj_forward(self.model, self.data)
        self.renderer.update_scene(data=self.data, camera=camera_name)

        return self.renderer.render()

    def record_frame(self, camera_name: str | None = None) -> None:
        """
        Render a frame and store it in the buffer.

        Args:
            camera_name: The camera to render from. Defaults to "track".
        """

        camera_name = camera_name or "track"

        frame = self.render_frame(camera_name=camera_name)
        self.frames.append(frame)

    def write_video(self, path: pathlib.Path | str, exist_ok: bool = False) -> None:
        """
        Write the recorded frames to a video file.

        Args:
            path: The path of the video file (a `pathlib.Path` or a string).
            exist_ok: Whether to overwrite an existing file.

        Raises:
            IsADirectoryError: If `path` is a directory.
            FileExistsError: If `path` exists and `exist_ok` is False.
        """

        path = pathlib.Path(path)

        if path.is_dir():
            raise IsADirectoryError(f"The path '{path}' is a directory.")

        if not exist_ok and path.is_file():
            raise FileExistsError(f"The file '{path}' already exists.")

        media.write_video(path=path, images=self.frames, fps=self.fps)

    @staticmethod
    def compute_down_sampling(original_fps: int, target_min_fps: int) -> int:
        """
        Return the integer down-sampling factor to reach at least the target fps.

        The factor is the largest integer divisor of `original_fps` such that the
        down-sampled fps is still greater than or equal to `target_min_fps`.

        Args:
            original_fps: The original fps.
            target_min_fps: The target minimum fps. It must be positive.

        Returns:
            The down-sampling factor.

        Raises:
            ValueError: If `target_min_fps` is not positive.
        """

        # With a non-positive target the loop condition below never becomes false.
        if target_min_fps <= 0:
            raise ValueError(f"The target fps must be positive ({target_min_fps}).")

        down_sampling = 1
        down_sampling_final = down_sampling

        while original_fps / (down_sampling + 1) >= target_min_fps:
            down_sampling = down_sampling + 1

            # Only keep factors producing an integer down-sampled fps.
            if int(original_fps / down_sampling) == original_fps / down_sampling:
                down_sampling_final = down_sampling

        return down_sampling_final
115
+
116
+
117
class MujocoVisualizer:
    """
    Interactive passive viewer for Mujoco models built on `mujoco.viewer`.

    The model and data given at construction are used as defaults by all the
    methods, and can be overridden on each call.
    """

    def __init__(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> None:
        """
        Initialize the Mujoco visualizer.

        Args:
            model: The Mujoco model.
            data: The Mujoco data.
        """

        self.data = data
        self.model = model

    def sync(
        self,
        viewer: mujoco.viewer.Handle,
        model: mj.MjModel | None = None,
        data: mj.MjData | None = None,
    ) -> None:
        """
        Updates the viewer with the current model and data.

        Args:
            viewer: The viewer handle to update.
            model: The Mujoco model. Defaults to the one stored in the object.
            data: The Mujoco data. Defaults to the one stored in the object.
        """

        data = data if data is not None else self.data
        model = model if model is not None else self.model

        # Recompute the derived quantities before pushing them to the viewer.
        mj.mj_forward(model, data)
        viewer.sync()

    def open_viewer(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> mj.viewer.Handle:
        """
        Opens a passive viewer with both the left and right UI panels hidden.

        Args:
            model: The Mujoco model. Defaults to the one stored in the object.
            data: The Mujoco data. Defaults to the one stored in the object.

        Returns:
            The handle of the opened viewer.
        """

        data = data if data is not None else self.data
        model = model if model is not None else self.model

        return mj.viewer.launch_passive(
            model, data, show_left_ui=False, show_right_ui=False
        )

    @contextlib.contextmanager
    def open(
        self,
        model: mj.MjModel | None = None,
        data: mj.MjData | None = None,
        close_on_exit: bool = True,
    ) -> ContextManager[mujoco.viewer.Handle]:
        """
        Context manager to open a viewer.

        Args:
            model: The Mujoco model. Defaults to the one stored in the object.
            data: The Mujoco data. Defaults to the one stored in the object.
            close_on_exit: Whether to close the viewer when leaving the context.

        Yields:
            The handle of the opened viewer.
        """

        handle = self.open_viewer(model=model, data=data)

        try:
            yield handle
        finally:
            if close_on_exit:
                handle.close()
jaxsim/parsers/descriptions/collision.py CHANGED
@@ -50,6 +50,14 @@ class CollidablePoint:
50
50
  enabled=self.enabled,
51
51
  )
52
52
 
53
def __eq__(self, other):
    """
    Compare two collidable points by parent link, position and enabled flag.

    Objects of a different type are not comparable: ``NotImplemented`` is
    returned so that Python falls back to its default comparison instead of
    failing on the missing attributes.
    """

    if not isinstance(other, type(self)):
        return NotImplemented

    return (
        self.parent_link == other.parent_link
        and (self.position == other.position).all()
        and self.enabled == other.enabled
    )
60
+
53
61
  def __str__(self):
54
62
  return (
55
63
  f"{self.__class__.__name__}("
@@ -93,6 +101,9 @@ class BoxCollision(CollisionShape):
93
101
 
94
102
  center: npt.NDArray
95
103
 
104
def __eq__(self, other):
    """
    Compare two box collisions by center and by the fields of the base shape.

    Returns ``NotImplemented`` for objects of a different type instead of
    failing on the missing ``center`` attribute.
    """

    if not isinstance(other, type(self)):
        return NotImplemented

    return (self.center == other.center).all() and super().__eq__(other)
106
+
96
107
 
97
108
  @dataclasses.dataclass
98
109
  class SphereCollision(CollisionShape):
@@ -105,3 +116,6 @@ class SphereCollision(CollisionShape):
105
116
  """
106
117
 
107
118
  center: npt.NDArray
119
+
120
def __eq__(self, other):
    """
    Compare two sphere collisions by center and by the fields of the base shape.

    Returns ``NotImplemented`` for objects of a different type instead of
    failing on the missing ``center`` attribute.
    """

    if not isinstance(other, type(self)):
        return NotImplemented

    return (self.center == other.center).all() and super().__eq__(other)
jaxsim/parsers/descriptions/link.py CHANGED
@@ -3,10 +3,10 @@ from typing import List
3
3
 
4
4
  import jax.numpy as jnp
5
5
  import jax_dataclasses
6
+ import jaxlie
6
7
  from jax_dataclasses import Static
7
8
 
8
9
  import jaxsim.typing as jtp
9
- from jaxsim.sixd import se3
10
10
  from jaxsim.utils import JaxsimDataclass
11
11
 
12
12
 
@@ -38,6 +38,17 @@ class LinkDescription(JaxsimDataclass):
38
38
  def __hash__(self) -> int:
39
39
  return hash(self.__repr__())
40
40
 
41
def __eq__(self, other) -> bool:
    """
    Compare two link descriptions field by field.

    The parent link is compared by name only. A recursive comparison of the
    parent would compare the parent's children, which contain this link again.
    When the two links belong to distinct (e.g. deep-copied) link trees, that
    recursion never terminates. The children are still compared recursively.

    Returns ``NotImplemented`` for objects of a different type.
    """

    if not isinstance(other, type(self)):
        return NotImplemented

    self_parent_name = self.parent.name if self.parent is not None else None
    other_parent_name = other.parent.name if other.parent is not None else None

    return (
        self.name == other.name
        and self.mass == other.mass
        and (self.inertia == other.inertia).all()
        and self.index == other.index
        and self_parent_name == other_parent_name
        and (self.pose == other.pose).all()
        and self.children == other.children
    )
51
+
41
52
  @property
42
53
  def name_and_index(self) -> str:
43
54
  """
@@ -67,7 +78,7 @@ class LinkDescription(JaxsimDataclass):
67
78
  I_removed = link.inertia
68
79
 
69
80
  # Create the SE3 object. Note the inverse.
70
- r_H_l = se3.SE3.from_matrix(lumped_H_removed).inverse()
81
+ r_H_l = jaxlie.SE3.from_matrix(lumped_H_removed).inverse()
71
82
  r_X_l = r_H_l.adjoint()
72
83
 
73
84
  # Move the inertia
jaxsim/parsers/kinematic_graph.py CHANGED
@@ -34,6 +34,11 @@ class RootPose(NamedTuple):
34
34
  root_position: npt.NDArray = np.zeros(3)
35
35
  root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0])
36
36
 
37
def __eq__(self, other):
    """
    Compare two root poses element-wise by position and quaternion.

    Returns ``NotImplemented`` for objects of a different type instead of
    failing on the missing attributes.
    """

    if not isinstance(other, type(self)):
        return NotImplemented

    return (self.root_position == other.root_position).all() and (
        self.root_quaternion == other.root_quaternion
    ).all()
41
+
37
42
 
38
43
  @dataclasses.dataclass(frozen=True)
39
44
  class KinematicGraph:
@@ -117,7 +122,7 @@ class KinematicGraph:
117
122
 
118
123
  # Check that joint indices are unique
119
124
  assert len([j.index for j in self.joints]) == len(
120
- set([j.index for j in self.joints])
125
+ {j.index for j in self.joints}
121
126
  )
122
127
 
123
128
  # Order joints with their indices
@@ -263,12 +268,12 @@ class KinematicGraph:
263
268
 
264
269
  # Return early if there is no action to take
265
270
  if len(joint_names_to_remove) == 0:
266
- logging.info(f"The kinematic graph doesn't need to be reduced")
271
+ logging.info("The kinematic graph doesn't need to be reduced")
267
272
  return copy.deepcopy(self)
268
273
 
269
274
  # Check if all considered joints are part of the full kinematic graph
270
275
  if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
271
- extra_j = set(considered_joints) - set([j.name for j in full_graph.joints])
276
+ extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
272
277
  msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
273
278
  raise ValueError(msg)
274
279