jaxsim 0.2.dev56__py3-none-any.whl → 0.2.dev77__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/mujoco/model.py ADDED
@@ -0,0 +1,352 @@
1
+ import functools
2
+ import pathlib
3
+ from typing import Any
4
+
5
+ import mujoco as mj
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+ from scipy.spatial.transform import Rotation
9
+
10
+
11
class MujocoModelHelper:
    """
    Helper class to create and interact with Mujoco models and data objects.

    The helper stores a `mj.MjModel` together with its `mj.MjData` and exposes
    accessors for the base link, the joints, the bodies, and the geometries.
    Quaternions are handled in the scalar-first (w, x, y, z) order, which is
    why the scalar-last quaternions produced by scipy are re-ordered.
    """

    def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -> None:
        """
        Initialize the helper.

        Args:
            model: The Mujoco model.
            data: The Mujoco data. If None, a new data object is created
                from `model`.
        """

        self.model = model
        self.data = data if data is not None else mj.MjData(self.model)

        # Populate the data with kinematics
        mj.mj_forward(self.model, self.data)

        # Keep the cache of this method local to improve GC
        self.mask_qpos = functools.cache(self._mask_qpos)

    @staticmethod
    def build_from_xml(
        mjcf_description: str | pathlib.Path, assets: dict[str, Any] | None = None
    ) -> "MujocoModelHelper":
        """
        Build a helper from an MJCF description.

        Args:
            mjcf_description: Either the XML string of the MJCF description or
                the path to a file containing it.
            assets: Optional dictionary of assets forwarded to Mujoco.

        Returns:
            A helper wrapping the new model and a fresh data object.
        """

        # Read the XML description if it's a path to file
        if isinstance(mjcf_description, pathlib.Path):
            mjcf_description = mjcf_description.read_text()

        # Create the Mujoco model from the XML and, optionally, the assets dictionary
        model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets)  # noqa

        return MujocoModelHelper(model=model, data=mj.MjData(model))

    def time(self) -> float:
        """Return the simulation time."""

        return self.data.time

    def timestep(self) -> float:
        """Return the simulation timestep."""

        return self.model.opt.timestep

    def gravity(self) -> npt.NDArray:
        """Return the 3D gravity vector."""

        return self.model.opt.gravity

    # =========================
    # Methods for the base link
    # =========================

    def is_floating_base(self) -> bool:
        """Return true if the model is floating-base."""

        # A body with no joints is considered a fixed-base model.
        # In fact, in mujoco, a floating-base model has a 6 DoFs first joint.
        if self.number_of_joints() == 0:
            return False

        # We just check that the first joint has 6 DoFs.
        joint0_type = self.model.jnt_type[0]
        return bool(joint0_type == mj.mjtJoint.mjJNT_FREE)

    def is_fixed_base(self) -> bool:
        """Return true if the model is fixed-base."""

        return not self.is_floating_base()

    def base_link(self) -> str:
        """Return the name of the base link."""

        return mj.mj_id2name(
            self.model, mj.mjtObj.mjOBJ_BODY, self._base_body_index()
        )

    def base_position(self) -> npt.NDArray:
        """
        Return the 3D position of the base link.

        For floating-base models the position is read from the first three
        entries of `qpos`, otherwise the position of the base body is returned.
        """

        if self.is_floating_base():
            return self.data.qpos[:3]

        return self.body_position(body_name=self.base_link())

    def base_orientation(self, dcm: bool = False) -> npt.NDArray:
        """
        Return the orientation of the base link.

        Args:
            dcm: Whether to return a rotation matrix instead of a quaternion.

        Returns:
            For floating-base models, either the 3x3 rotation matrix or the
            quaternion of the base body. For fixed-base models, the output of
            `body_orientation` for the base link.
        """

        if self.is_fixed_base():
            return self.body_orientation(body_name=self.base_link(), dcm=dcm)

        # Index 0 of the body arrays is the Mujoco world body, whose orientation
        # is always the identity. The floating base is the body reported by
        # `base_link()`, therefore we index the kinematics with the same body.
        base_idx = self._base_body_index()

        if dcm:
            # Positional shape: the `newshape` keyword is deprecated in recent NumPy.
            return np.reshape(self.data.xmat[base_idx], (3, 3))

        return self.data.xquat[base_idx]

    def set_base_position(self, position: npt.NDArray) -> None:
        """
        Set the 3D position of the base link.

        Args:
            position: The 3D position, any array-like that squeezes to 3 elements.

        Raises:
            ValueError: If the model is fixed-base or the position size is wrong.
        """

        if self.is_fixed_base():
            raise ValueError("The position of a fixed-base model cannot be set.")

        position = np.atleast_1d(np.array(position).squeeze())

        if position.size != 3:
            raise ValueError(f"Wrong position size ({position.size})")

        self.data.qpos[:3] = position

    def set_base_orientation(self, orientation: npt.NDArray, dcm: bool = False) -> None:
        """
        Set the orientation of the base link.

        Args:
            orientation: Either a quaternion with 4 elements (written to `qpos`
                as it is) or, if `dcm` is true, a 3x3 rotation matrix.
            dcm: Whether `orientation` is a rotation matrix.

        Raises:
            ValueError: If the model is fixed-base, if the shape of the
                orientation is wrong, or if it is not a unit quaternion /
                a proper rotation matrix.
        """

        if self.is_fixed_base():
            raise ValueError("The orientation of a fixed-base model cannot be set.")

        orientation = (
            np.atleast_2d(np.array(orientation).squeeze())
            if dcm
            else np.atleast_1d(np.array(orientation).squeeze())
        )

        if orientation.shape != ((4,) if not dcm else (3, 3)):
            raise ValueError(f"Wrong orientation shape {orientation.shape}")

        def is_quaternion(Q):
            return np.allclose(np.linalg.norm(Q), 1.0)

        def is_dcm(R):
            return np.allclose(np.linalg.det(R), 1.0) and np.allclose(
                R.T @ R, np.eye(3)
            )

        if not (is_quaternion(orientation) if not dcm else is_dcm(orientation)):
            raise ValueError("The orientation is not a valid element of SO(3)")

        # Scipy returns scalar-last quaternions (x, y, z, w): move the scalar
        # component to the front before writing it to `qpos`.
        W_Q_B = (
            Rotation.from_matrix(orientation).as_quat(canonical=True)[
                np.array([3, 0, 1, 2])
            ]
            if dcm
            else orientation
        )

        self.data.qpos[3:7] = W_Q_B

    # ==================
    # Methods for joints
    # ==================

    def number_of_joints(self) -> int:
        """Return the number of joints of the model (`model.njnt`)."""

        return self.model.njnt

    def number_of_dofs(self) -> int:
        """
        Return the number of DoFs of the model.

        Note:
            The value is `model.nq`, i.e. the size of the `qpos` array.
            NOTE(review): for models with a free joint this may differ from
            the size of the velocity vector -- confirm what callers expect.
        """

        return self.model.nq

    def joint_names(self) -> list[str]:
        """
        Return the names of the joints.

        For floating-base models the first joint (the one checked by
        `is_floating_base`) is skipped.
        """

        first_idx = 0 if self.is_fixed_base() else 1

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx)
            for idx in range(first_idx, self.number_of_joints())
        ]

    def joint_dofs(self, joint_name: str) -> int:
        """
        Return the number of `qpos` entries of a joint.

        Raises:
            ValueError: If the joint is not part of `joint_names()`.
        """

        if joint_name not in self.joint_names():
            raise ValueError(f"Joint '{joint_name}' not found")

        return self.data.joint(joint_name).qpos.size

    def joint_position(self, joint_name: str) -> npt.NDArray:
        """
        Return the position of a joint as exposed by the data object.

        Raises:
            ValueError: If the joint is not part of `joint_names()`.
        """

        if joint_name not in self.joint_names():
            raise ValueError(f"Joint '{joint_name}' not found")

        return self.data.joint(joint_name).qpos

    def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray:
        """
        Return the stacked positions of multiple joints.

        Args:
            joint_names: The names of the joints. If None, all the joints
                returned by `joint_names()` are considered.

        Returns:
            A 1D array with the joint positions in the order of `joint_names`.
            The array is empty if there are no joints to read.
        """

        joint_names = joint_names if joint_names is not None else self.joint_names()

        # `np.hstack` raises on an empty sequence: handle models without joints.
        if len(joint_names) == 0:
            return np.zeros(0)

        return np.hstack(
            [self.joint_position(joint_name) for joint_name in joint_names]
        )

    def set_joint_position(
        self, joint_name: str, position: npt.NDArray | float
    ) -> None:
        """
        Set the position of a joint.

        Args:
            joint_name: The name of the joint.
            position: The new position, with as many elements as the joint DoFs.

        Raises:
            ValueError: If the joint is not found or the position size is wrong.
        """

        position = np.atleast_1d(np.array(position).squeeze())

        # Query the DoFs once, this call also validates the joint name.
        dofs = self.joint_dofs(joint_name=joint_name)

        if position.size != dofs:
            raise ValueError(
                f"Wrong position size ({position.size}) of "
                f"{dofs}-DoFs joint '{joint_name}'."
            )

        idx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)
        offset = self.model.jnt_qposadr[idx]

        sl = np.s_[offset : offset + dofs]
        self.data.qpos[sl] = position

    def set_joint_positions(
        self, joint_names: list[str], positions: npt.NDArray | list[npt.NDArray]
    ) -> None:
        """
        Set the positions of multiple joints.

        Args:
            joint_names: The names of the joints.
            positions: The values assigned to the `qpos` entries selected by
                the mask of `joint_names`.
        """

        # The mask is cached, therefore the names must be passed as a hashable tuple.
        mask = self.mask_qpos(joint_names=tuple(joint_names))
        self.data.qpos[mask] = positions

    # ==================
    # Methods for bodies
    # ==================

    def number_of_bodies(self) -> int:
        """Return the number of bodies of the model (`model.nbody`)."""

        return self.model.nbody

    def body_names(self) -> list[str]:
        """Return the names of all the bodies, ordered by body index."""

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx)
            for idx in range(self.number_of_bodies())
        ]

    def body_position(self, body_name: str) -> npt.NDArray:
        """
        Return the position (`xpos`) of a body.

        Raises:
            ValueError: If the body is not found.
        """

        if body_name not in self.body_names():
            raise ValueError(f"Body '{body_name}' not found")

        return self.data.body(body_name).xpos

    def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:
        """
        Return the orientation of a body.

        Args:
            body_name: The name of the body.
            dcm: Whether to return the body `xmat` instead of its `xquat`.

        Returns:
            The `xmat` or `xquat` field of the body as exposed by the data
            object. NOTE(review): unlike `geometry_orientation`, `xmat` is
            not reshaped here -- confirm the shape expected by callers.

        Raises:
            ValueError: If the body is not found.
        """

        if body_name not in self.body_names():
            raise ValueError(f"Body '{body_name}' not found")

        body = self.data.body(body_name)
        return body.xmat if dcm else body.xquat

    # ======================
    # Methods for geometries
    # ======================

    def number_of_geometries(self) -> int:
        """Return the number of geometries of the model (`model.ngeom`)."""

        return self.model.ngeom

    def geometry_names(self) -> list[str]:
        """Return the names of all the geometries, ordered by geometry index."""

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx)
            for idx in range(self.number_of_geometries())
        ]

    def geometry_position(self, geometry_name: str) -> npt.NDArray:
        """
        Return the position (`xpos`) of a geometry.

        Raises:
            ValueError: If the geometry is not found.
        """

        if geometry_name not in self.geometry_names():
            raise ValueError(f"Geometry '{geometry_name}' not found")

        return self.data.geom(geometry_name).xpos

    def geometry_orientation(
        self, geometry_name: str, dcm: bool = False
    ) -> npt.NDArray:
        """
        Return the orientation of a geometry.

        Args:
            geometry_name: The name of the geometry.
            dcm: Whether to return a rotation matrix instead of a quaternion.

        Returns:
            The 3x3 rotation matrix if `dcm` is true, otherwise the
            scalar-first quaternion computed from it.

        Raises:
            ValueError: If the geometry is not found.
        """

        if geometry_name not in self.geometry_names():
            raise ValueError(f"Geometry '{geometry_name}' not found")

        # Positional shape: the `newshape` keyword is deprecated in recent NumPy.
        R = np.reshape(self.data.geom(geometry_name).xmat, (3, 3))

        if dcm:
            return R

        # Scipy returns (x, y, z, w): move the scalar component to the front.
        q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True)
        return q_xyzw[[3, 0, 1, 2]]

    # ===============
    # Private methods
    # ===============

    def _base_body_index(self) -> int:
        """
        Return the index of the body considered as base link.

        Fixed-base models use body 0, floating-base models use body 1.
        """

        return 0 if self.is_fixed_base() else 1

    def _mask_qpos(self, joint_names: tuple[str, ...]) -> npt.NDArray:
        """
        Create a mask to access the DoFs of the desired `joint_names` in the `qpos` array.

        Args:
            joint_names: A tuple containing the names of the joints.

        Returns:
            A 1D array containing the indices of the `qpos` array to access the DoFs of
            the desired `joint_names`. The array is empty if no joint is given.

        Note:
            This method takes a tuple of strings because we cache the output mask for
            each combination of joint names. We need a hashable object for the cache.
        """

        # `np.hstack` raises on an empty sequence: return an empty integer mask,
        # so that it can still be used to index the `qpos` array.
        if len(joint_names) == 0:
            return np.zeros(0, dtype=int)

        # Get the indices of the joints in `joint_names`.
        idxs = [
            mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)
            for joint_name in joint_names
        ]

        # We first get the index of each joint in the qpos array, and for those that
        # have multiple DoFs, we expand their mask by appending new elements.
        # Finally, we flatten the list of arrays to a single array, that is the
        # final qpos mask accessing all the DoFs of the desired `joint_names`.
        # Note that `joint_dofs` also validates the name of each joint.
        return np.atleast_1d(
            np.hstack(
                [
                    np.array(
                        [
                            self.model.jnt_qposadr[idx] + i
                            for i in range(self.joint_dofs(joint_name=joint_name))
                        ]
                    )
                    for idx, joint_name in zip(idxs, joint_names)
                ]
            ).squeeze()
        )
@@ -0,0 +1,152 @@
1
+ import contextlib
2
+ import pathlib
3
+ from typing import ContextManager
4
+
5
+ import mediapy as media
6
+ import mujoco as mj
7
+ import mujoco.viewer
8
+ import numpy.typing as npt
9
+
10
+
11
class MujocoVideoRecorder:
    """
    Record frames rendered from a Mujoco model and write them to a video file.

    Frames are accumulated in memory by `render_frame` and written to disk
    by `write_video`.
    """

    def __init__(
        self,
        model: mj.MjModel,
        data: mj.MjData,
        fps: int = 30,
        width: int | None = None,
        height: int | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the recorder and its renderer.

        Args:
            model: The Mujoco model.
            data: The Mujoco data.
            fps: The frames per second of the written video.
            width: The width of the rendered frames. If None, the offscreen
                width stored in the model is used.
            height: The height of the rendered frames. If None, the offscreen
                height stored in the model is used.
            **kwargs: Extra arguments forwarded to the renderer. They take
                precedence over `width` and `height`.
        """

        width = width if width is not None else model.vis.global_.offwidth
        height = height if height is not None else model.vis.global_.offheight

        # Align the offscreen buffer size stored in the model with the frame size.
        if model.vis.global_.offwidth != width:
            model.vis.global_.offwidth = width

        if model.vis.global_.offheight != height:
            model.vis.global_.offheight = height

        self.fps = fps
        self.frames: list[npt.NDArray] = []
        self.data: mj.MjData | None = None
        self.model: mj.MjModel | None = None
        self.reset(model=model, data=data)

        self.renderer = mj.Renderer(
            model=self.model,
            **(dict(width=width, height=height) | kwargs),
        )

    def reset(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> None:
        """
        Discard the recorded frames and optionally replace model and data.

        Args:
            model: The new Mujoco model. If None, the current one is kept.
            data: The new Mujoco data. If None, the current one is kept.

        Note:
            The renderer is not re-created by this method.
        """

        self.frames = []

        self.data = data if data is not None else self.data
        self.model = model if model is not None else self.model

    def render_frame(self, camera_name: str | None = None) -> None:
        """
        Render a frame of the current data and append it to the recorded frames.

        Args:
            camera_name: The name of the camera used to render the frame.
                If None, the default camera of the renderer (-1) is used.
        """

        mj.mj_forward(self.model, self.data)

        # -1 is the default value of the `camera` argument of `update_scene`,
        # so that omitting the camera name keeps the previous behavior.
        camera = camera_name if camera_name is not None else -1
        self.renderer.update_scene(data=self.data, camera=camera)

        self.frames.append(self.renderer.render())

    def write_video(self, path: pathlib.Path, exist_ok: bool = False) -> None:
        """
        Write the recorded frames to a video file.

        Args:
            path: The path of the output video file.
            exist_ok: Whether an existing file can be overwritten.

        Raises:
            IsADirectoryError: If `path` is a directory.
            FileExistsError: If the file exists and `exist_ok` is false.
        """

        # Accept also plain strings, the checks below need a `pathlib.Path`.
        path = pathlib.Path(path)

        if path.is_dir():
            raise IsADirectoryError(f"The path '{path}' is a directory.")

        if not exist_ok and path.is_file():
            raise FileExistsError(f"The file '{path}' already exists.")

        media.write_video(path=path, images=self.frames, fps=self.fps)

    @staticmethod
    def compute_down_sampling(original_fps: int, target_min_fps: int) -> int:
        """
        Return the integer down-sampling factor to reach at least the target fps.

        Only the factors that divide `original_fps` into an integer number of
        frames per second are considered. The largest of them keeping the
        resulting fps greater or equal than `target_min_fps` is returned.

        Args:
            original_fps: The original fps.
            target_min_fps: The target minimum fps.

        Returns:
            The down-sampling factor. It is 1 if no larger factor is valid.

        Raises:
            ValueError: If `target_min_fps` is not positive.
        """

        # A non-positive target would never terminate the search below.
        if target_min_fps <= 0:
            raise ValueError(
                f"The target minimum fps must be positive ({target_min_fps})"
            )

        down_sampling = 1
        down_sampling_final = down_sampling

        while original_fps / (down_sampling + 1) >= target_min_fps:
            down_sampling = down_sampling + 1

            # Keep only the factors producing an integer fps.
            if int(original_fps / down_sampling) == original_fps / down_sampling:
                down_sampling_final = down_sampling

        return down_sampling_final
97
+
98
+
99
class MujocoVisualizer:
    """
    Small wrapper around the passive Mujoco viewer.

    The model and data passed at construction are used as defaults by all
    the methods, and can be overridden on each call.
    """

    def __init__(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> None:
        """
        Store the default model and data.

        Args:
            model: The default Mujoco model, if any.
            data: The default Mujoco data, if any.
        """

        self.data = data
        self.model = model

    def sync(
        self,
        viewer: mujoco.viewer.Handle,
        model: mj.MjModel | None = None,
        data: mj.MjData | None = None,
    ) -> None:
        """
        Run the forward computation and synchronize the viewer.

        Args:
            viewer: The handle of the viewer to synchronize.
            model: The Mujoco model. If None, the stored one is used.
            data: The Mujoco data. If None, the stored one is used.
        """

        # Fall back to the objects stored in the visualizer.
        if data is None:
            data = self.data

        if model is None:
            model = self.model

        mj.mj_forward(model, data)
        viewer.sync()

    def open_viewer(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> mj.viewer.Handle:
        """
        Launch a passive viewer with both side panels hidden.

        Args:
            model: The Mujoco model. If None, the stored one is used.
            data: The Mujoco data. If None, the stored one is used.

        Returns:
            The handle of the launched viewer.
        """

        # Fall back to the objects stored in the visualizer.
        if data is None:
            data = self.data

        if model is None:
            model = self.model

        return mj.viewer.launch_passive(
            model, data, show_left_ui=False, show_right_ui=False
        )

    @contextlib.contextmanager
    def open(
        self,
        model: mj.MjModel | None = None,
        data: mj.MjData | None = None,
        close_on_exit: bool = True,
    ) -> ContextManager[mujoco.viewer.Handle]:
        """
        Context manager opening a viewer and yielding its handle.

        Args:
            model: The Mujoco model. If None, the stored one is used.
            data: The Mujoco data. If None, the stored one is used.
            close_on_exit: Whether to close the viewer when the context exits.

        Yields:
            The handle of the opened viewer.
        """

        viewer_handle = self.open_viewer(model=model, data=data)

        try:
            yield viewer_handle
        finally:
            # The viewer is closed also when the body of the context raises.
            if close_on_exit:
                viewer_handle.close()