jaxsim 0.2.dev56__py3-none-any.whl → 0.2.dev77__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +475 -0
- jaxsim/mujoco/model.py +352 -0
- jaxsim/mujoco/visualizer.py +152 -0
- jaxsim/simulation/integrators.py +90 -343
- jaxsim/simulation/ode_integration.py +3 -16
- {jaxsim-0.2.dev56.dist-info → jaxsim-0.2.dev77.dist-info}/METADATA +6 -1
- {jaxsim-0.2.dev56.dist-info → jaxsim-0.2.dev77.dist-info}/RECORD +13 -8
- {jaxsim-0.2.dev56.dist-info → jaxsim-0.2.dev77.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev56.dist-info → jaxsim-0.2.dev77.dist-info}/WHEEL +0 -0
- {jaxsim-0.2.dev56.dist-info → jaxsim-0.2.dev77.dist-info}/top_level.txt +0 -0
jaxsim/mujoco/model.py
ADDED
@@ -0,0 +1,352 @@
|
|
1
|
+
import functools
|
2
|
+
import pathlib
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
import mujoco as mj
|
6
|
+
import numpy as np
|
7
|
+
import numpy.typing as npt
|
8
|
+
from scipy.spatial.transform import Rotation
|
9
|
+
|
10
|
+
|
11
|
+
class MujocoModelHelper:
    """
    Helper class to create and interact with Mujoco models and data objects.
    """

    def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -> None:
        """
        Initialize the helper.

        Args:
            model: The Mujoco model.
            data: The Mujoco data associated to `model`. If None, a new
                `mj.MjData` object is created from `model`.

        Note:
            `mj.mj_forward` is called on construction, so the kinematic
            quantities stored in `data` (e.g. body poses) are populated.
        """

        self.model = model
        self.data = data if data is not None else mj.MjData(self.model)

        # Populate the data with kinematics
        mj.mj_forward(self.model, self.data)

        # Keep the cache of this method local to improve GC
        self.mask_qpos = functools.cache(self._mask_qpos)

    @staticmethod
    def build_from_xml(
        mjcf_description: str | pathlib.Path, assets: dict[str, Any] | None = None
    ) -> "MujocoModelHelper":
        """
        Build a helper from an MJCF description.

        Args:
            mjcf_description: The MJCF description, either as an XML string or
                as a `pathlib.Path` to a file containing it.
            assets: An optional dictionary of assets forwarded to
                `mj.MjModel.from_xml_string`.

        Returns:
            A helper wrapping the new model and a freshly created data object.
        """

        # Read the XML description if it's a path to file
        if isinstance(mjcf_description, pathlib.Path):
            mjcf_description = mjcf_description.read_text()

        # Create the Mujoco model from the XML and, optionally, the assets dictionary
        model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets)  # noqa

        return MujocoModelHelper(model=model, data=mj.MjData(model))

    def time(self) -> float:
        """Return the simulation time."""

        return self.data.time

    def timestep(self) -> float:
        """Return the simulation timestep."""

        return self.model.opt.timestep

    def gravity(self) -> npt.NDArray:
        """Return the 3D gravity vector."""

        return self.model.opt.gravity

    # =========================
    # Methods for the base link
    # =========================

    def is_floating_base(self) -> bool:
        """Return true if the model is floating-base."""

        # A model with no joints is considered a fixed-base model.
        # In fact, in mujoco, a floating-base model has a 6 DoFs first joint.
        if self.number_of_joints() == 0:
            return False

        # We just check that the first joint is a free joint (6 DoFs).
        joint0_type = self.model.jnt_type[0]
        return joint0_type == mj.mjtJoint.mjJNT_FREE

    def is_fixed_base(self) -> bool:
        """Return true if the model is fixed-base."""

        return not self.is_floating_base()

    def base_link(self) -> str:
        """
        Return the name of the base link.

        Body 0 is returned for fixed-base models and body 1 (the first body
        after the world body) for floating-base models.
        """

        return mj.mj_id2name(
            self.model, mj.mjtObj.mjOBJ_BODY, 0 if self.is_fixed_base() else 1
        )

    def base_position(self) -> npt.NDArray:
        """Return the 3D position of the base link."""

        return (
            self.data.qpos[:3]
            if self.is_floating_base()
            else self.body_position(body_name=self.base_link())
        )

    def base_orientation(self, dcm: bool = False) -> npt.NDArray:
        """
        Return the orientation of the base link.

        Args:
            dcm: If True, return the (3, 3) rotation matrix; otherwise return the
                quaternion in wxyz order.

        Returns:
            The orientation of the base link.
        """

        if self.is_fixed_base():
            return self.body_orientation(body_name=self.base_link(), dcm=dcm)

        # For floating-base models the base orientation is the wxyz quaternion
        # stored in qpos[3:7] (the same slot written by `set_base_orientation`),
        # consistently with `base_position` reading qpos[:3]. Note that body 0
        # is the world body, so `data.xquat[0]`/`data.xmat[0]` is always identity.
        W_Q_B = self.data.qpos[3:7]

        if not dcm:
            return W_Q_B

        # SciPy expects quaternions in xyzw order.
        return Rotation.from_quat(W_Q_B[[1, 2, 3, 0]]).as_matrix()

    def set_base_position(self, position: npt.NDArray) -> None:
        """
        Set the 3D position of the base link.

        Args:
            position: The 3D position of the base link.

        Raises:
            ValueError: If the model is fixed-base or `position` has not 3 elements.
        """

        if self.is_fixed_base():
            raise ValueError("The position of a fixed-base model cannot be set.")

        position = np.atleast_1d(np.array(position).squeeze())

        if position.size != 3:
            raise ValueError(f"Wrong position size ({position.size})")

        self.data.qpos[:3] = position

    def set_base_orientation(self, orientation: npt.NDArray, dcm: bool = False) -> None:
        """
        Set the orientation of the base link.

        Args:
            orientation: Either a wxyz unit quaternion (`dcm=False`) or a (3, 3)
                rotation matrix (`dcm=True`).
            dcm: Whether `orientation` is a rotation matrix.

        Raises:
            ValueError: If the model is fixed-base, or if `orientation` has the
                wrong shape or is not a valid element of SO(3).
        """

        if self.is_fixed_base():
            raise ValueError("The orientation of a fixed-base model cannot be set.")

        orientation = (
            np.atleast_2d(np.array(orientation).squeeze())
            if dcm
            else np.atleast_1d(np.array(orientation).squeeze())
        )

        if orientation.shape != ((4,) if not dcm else (3, 3)):
            raise ValueError(f"Wrong orientation shape {orientation.shape}")

        def is_quaternion(Q):
            return np.allclose(np.linalg.norm(Q), 1.0)

        def is_dcm(R):
            return np.allclose(np.linalg.det(R), 1.0) and np.allclose(
                R.T @ R, np.eye(3)
            )

        if not (is_quaternion(orientation) if not dcm else is_dcm(orientation)):
            raise ValueError("The orientation is not a valid element of SO(3)")

        # SciPy returns xyzw quaternions, while qpos stores them as wxyz.
        W_Q_B = (
            Rotation.from_matrix(orientation).as_quat(canonical=True)[
                np.array([3, 0, 1, 2])
            ]
            if dcm
            else orientation
        )

        self.data.qpos[3:7] = W_Q_B

    # ==================
    # Methods for joints
    # ==================

    def number_of_joints(self) -> int:
        """Return the number of joints of the model, including a free joint if any."""

        return self.model.njnt

    def number_of_dofs(self) -> int:
        """
        Return the size of the `qpos` array of the model.

        Note:
            This is `model.nq`, consistently with `joint_dofs` that counts the
            `qpos` entries of each joint.
        """

        return self.model.nq

    def joint_names(self) -> list[str]:
        """Return the names of the joints, excluding the free joint of floating-base models."""

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx)
            for idx in range(0 if self.is_fixed_base() else 1, self.number_of_joints())
        ]

    def joint_dofs(self, joint_name: str) -> int:
        """
        Return the number of `qpos` entries of a joint.

        Raises:
            ValueError: If the joint is not part of `joint_names()`.
        """

        if joint_name not in self.joint_names():
            raise ValueError(f"Joint '{joint_name}' not found")

        return self.data.joint(joint_name).qpos.size

    def joint_position(self, joint_name: str) -> npt.NDArray:
        """
        Return the position of a joint.

        Raises:
            ValueError: If the joint is not part of `joint_names()`.
        """

        if joint_name not in self.joint_names():
            raise ValueError(f"Joint '{joint_name}' not found")

        return self.data.joint(joint_name).qpos

    def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray:
        """
        Return the stacked positions of the given joints.

        Args:
            joint_names: The names of the joints. If None, all joints returned
                by `joint_names()` are considered.

        Returns:
            A 1D array with the positions of the joints (empty if no joints).
        """

        joint_names = joint_names if joint_names is not None else self.joint_names()

        # np.hstack raises on an empty sequence, e.g. for models without joints.
        if len(joint_names) == 0:
            return np.array([], dtype=float)

        return np.hstack(
            [self.joint_position(joint_name) for joint_name in joint_names]
        )

    def set_joint_position(
        self, joint_name: str, position: npt.NDArray | float
    ) -> None:
        """
        Set the position of a joint.

        Args:
            joint_name: The name of the joint.
            position: The position, with as many elements as the joint DoFs.

        Raises:
            ValueError: If the joint is unknown or `position` has the wrong size.
        """

        position = np.atleast_1d(np.array(position).squeeze())

        # Also validates that the joint exists.
        dofs = self.joint_dofs(joint_name=joint_name)

        if position.size != dofs:
            raise ValueError(
                f"Wrong position size ({position.size}) of "
                f"{dofs}-DoFs joint '{joint_name}'."
            )

        idx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)
        offset = self.model.jnt_qposadr[idx]

        self.data.qpos[offset : offset + dofs] = position

    def set_joint_positions(
        self, joint_names: list[str], positions: npt.NDArray | list[npt.NDArray]
    ) -> None:
        """
        Set the positions of multiple joints at once.

        Args:
            joint_names: The names of the joints.
            positions: The positions, ordered as the `qpos` entries of `joint_names`.
        """

        mask = self.mask_qpos(joint_names=tuple(joint_names))
        self.data.qpos[mask] = positions

    # ==================
    # Methods for bodies
    # ==================

    def number_of_bodies(self) -> int:
        """Return the number of bodies of the model, including the world body."""

        return self.model.nbody

    def body_names(self) -> list[str]:
        """Return the names of all the bodies of the model."""

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx)
            for idx in range(self.number_of_bodies())
        ]

    def body_position(self, body_name: str) -> npt.NDArray:
        """
        Return the 3D world position of a body.

        Raises:
            ValueError: If the body is not found.
        """

        if body_name not in self.body_names():
            raise ValueError(f"Body '{body_name}' not found")

        return self.data.body(body_name).xpos

    def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:
        """
        Return the world orientation of a body.

        Args:
            body_name: The name of the body.
            dcm: If True, return the (3, 3) rotation matrix; otherwise return the
                wxyz quaternion.

        Raises:
            ValueError: If the body is not found.
        """

        if body_name not in self.body_names():
            raise ValueError(f"Body '{body_name}' not found")

        body = self.data.body(body_name)

        # `xmat` is stored flattened; reshape it consistently with the other
        # orientation getters of this class.
        return np.reshape(body.xmat, (3, 3)) if dcm else body.xquat

    # ======================
    # Methods for geometries
    # ======================

    def number_of_geometries(self) -> int:
        """Return the number of geometries of the model."""

        return self.model.ngeom

    def geometry_names(self) -> list[str]:
        """Return the names of all the geometries of the model."""

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx)
            for idx in range(self.number_of_geometries())
        ]

    def geometry_position(self, geometry_name: str) -> npt.NDArray:
        """
        Return the 3D world position of a geometry.

        Raises:
            ValueError: If the geometry is not found.
        """

        if geometry_name not in self.geometry_names():
            raise ValueError(f"Geometry '{geometry_name}' not found")

        return self.data.geom(geometry_name).xpos

    def geometry_orientation(
        self, geometry_name: str, dcm: bool = False
    ) -> npt.NDArray:
        """
        Return the world orientation of a geometry.

        Args:
            geometry_name: The name of the geometry.
            dcm: If True, return the (3, 3) rotation matrix; otherwise return the
                wxyz quaternion.

        Raises:
            ValueError: If the geometry is not found.
        """

        if geometry_name not in self.geometry_names():
            raise ValueError(f"Geometry '{geometry_name}' not found")

        R = np.reshape(self.data.geom(geometry_name).xmat, (3, 3))

        if dcm:
            return R

        # SciPy returns xyzw quaternions; convert to wxyz.
        q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True)
        return q_xyzw[[3, 0, 1, 2]]

    # ===============
    # Private methods
    # ===============

    def _mask_qpos(self, joint_names: tuple[str, ...]) -> npt.NDArray:
        """
        Create a mask to access the DoFs of the desired `joint_names` in the `qpos` array.

        Args:
            joint_names: A tuple containing the names of the joints.

        Returns:
            A 1D array containing the indices of the `qpos` array to access the DoFs of
            the desired `joint_names` (empty if `joint_names` is empty).

        Raises:
            ValueError: If any joint is not part of `joint_names()`.

        Note:
            This method takes a tuple of strings because we cache the output mask for
            each combination of joint names. We need a hashable object for the cache.
        """

        # np.hstack raises on an empty sequence.
        if len(joint_names) == 0:
            return np.array([], dtype=int)

        masks = []

        for joint_name in joint_names:
            # Validates the joint name before resolving its id (mj_name2id
            # returns -1 for unknown names instead of raising).
            dofs = self.joint_dofs(joint_name=joint_name)

            idx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)
            offset = self.model.jnt_qposadr[idx]

            # Joints with multiple DoFs occupy consecutive qpos entries.
            masks.append(np.arange(offset, offset + dofs))

        return np.hstack(masks)
|
jaxsim/mujoco/visualizer.py
ADDED
@@ -0,0 +1,152 @@
|
|
1
|
+
import contextlib
|
2
|
+
import pathlib
|
3
|
+
from typing import ContextManager
|
4
|
+
|
5
|
+
import mediapy as media
|
6
|
+
import mujoco as mj
|
7
|
+
import mujoco.viewer
|
8
|
+
import numpy.typing as npt
|
9
|
+
|
10
|
+
|
11
|
+
class MujocoVideoRecorder:
    """
    Record frames of a Mujoco simulation with an offscreen renderer and write
    them to a video file.
    """

    def __init__(
        self,
        model: mj.MjModel,
        data: mj.MjData,
        fps: int = 30,
        width: int | None = None,
        height: int | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the recorder.

        Args:
            model: The Mujoco model to render.
            data: The Mujoco data to render.
            fps: The frame rate of the written video.
            width: The width of the frames. Defaults to the offscreen width of `model`.
            height: The height of the frames. Defaults to the offscreen height of `model`.
            **kwargs: Additional keyword arguments forwarded to `mujoco.Renderer`;
                they take precedence over `width` and `height`.

        Note:
            The offscreen framebuffer size of `model` (`model.vis.global_`) is
            updated in place to match `width` and `height`.
        """

        width = width if width is not None else model.vis.global_.offwidth
        height = height if height is not None else model.vis.global_.offheight

        if model.vis.global_.offwidth != width:
            model.vis.global_.offwidth = width

        if model.vis.global_.offheight != height:
            model.vis.global_.offheight = height

        self.fps = fps
        self.frames: list[npt.NDArray] = []
        self.data: mujoco.MjData | None = None
        self.model: mujoco.MjModel | None = None
        self.reset(model=model, data=data)

        self.renderer = mujoco.Renderer(
            model=self.model,
            **(dict(width=width, height=height) | kwargs),
        )

    def reset(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> None:
        """
        Clear the recorded frames and, optionally, replace the model and data.

        Args:
            model: The new model. If None, the current one is kept.
            data: The new data. If None, the current one is kept.
        """

        # NOTE(review): the renderer created in `__init__` is not rebuilt here,
        # so passing a different `model` keeps rendering with the old one.
        self.frames = []

        self.data = data if data is not None else self.data
        self.model = model if model is not None else self.model

    def render_frame(self, camera_name: str | None = None) -> None:
        """
        Render the current state of the data and store the resulting frame.

        Args:
            camera_name: The name of the camera to render from. If None, the
                default free camera of the renderer is used.
        """

        mujoco.mj_forward(self.model, self.data)

        # -1 is the renderer's default (free) camera.
        camera = camera_name if camera_name is not None else -1
        self.renderer.update_scene(data=self.data, camera=camera)

        self.frames.append(self.renderer.render())

    def write_video(self, path: pathlib.Path | str, exist_ok: bool = False) -> None:
        """
        Write the recorded frames to a video file.

        Args:
            path: The path of the video file.
            exist_ok: Whether to overwrite an existing file.

        Raises:
            IsADirectoryError: If `path` is a directory.
            FileExistsError: If `path` exists and `exist_ok` is False.
        """

        # Accept plain strings as well as Path objects.
        path = pathlib.Path(path)

        if path.is_dir():
            raise IsADirectoryError(f"The path '{path}' is a directory.")

        if not exist_ok and path.is_file():
            raise FileExistsError(f"The file '{path}' already exists.")

        media.write_video(path=path, images=self.frames, fps=self.fps)

    @staticmethod
    def compute_down_sampling(original_fps: int, target_min_fps: int) -> int:
        """
        Return the integer down-sampling factor to reach at least the target fps.

        The returned factor is the largest one that divides `original_fps`
        exactly while keeping the resulting fps at least `target_min_fps`.

        Args:
            original_fps: The original fps.
            target_min_fps: The target minimum fps. Must be positive.

        Returns:
            The down-sampling factor (1 if no larger factor qualifies).

        Raises:
            ValueError: If `target_min_fps` is not positive.
        """

        # A non-positive target would make the loop below never terminate.
        if target_min_fps <= 0:
            raise ValueError(
                f"The target minimum fps must be positive ({target_min_fps})"
            )

        down_sampling = 1
        down_sampling_final = down_sampling

        while original_fps / (down_sampling + 1) >= target_min_fps:
            down_sampling = down_sampling + 1

            # Keep only factors that produce an integer fps.
            if int(original_fps / down_sampling) == original_fps / down_sampling:
                down_sampling_final = down_sampling

        return down_sampling_final
|
97
|
+
|
98
|
+
|
99
|
+
class MujocoVisualizer:
    """
    Visualize Mujoco models and data in the passive interactive viewer.
    """

    def __init__(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> None:
        """
        Initialize the visualizer.

        Args:
            model: The default model used when a method is called without one.
            data: The default data used when a method is called without one.
        """

        self.data = data
        self.model = model

    def _resolve(
        self, model: mj.MjModel | None, data: mj.MjData | None
    ) -> tuple[mj.MjModel, mj.MjData]:
        """Return the given model and data, falling back to the stored ones."""

        data = data if data is not None else self.data
        model = model if model is not None else self.model

        if model is None or data is None:
            raise ValueError(
                "Both a Mujoco model and data must be passed or stored in the visualizer"
            )

        return model, data

    def sync(
        self,
        viewer: mujoco.viewer.Handle,
        model: mj.MjModel | None = None,
        data: mj.MjData | None = None,
    ) -> None:
        """
        Update the kinematics of the data and synchronize the viewer.

        Args:
            viewer: The handle of the passive viewer.
            model: The model. If None, the stored one is used.
            data: The data. If None, the stored one is used.

        Raises:
            ValueError: If no model or data is available.
        """

        model, data = self._resolve(model=model, data=data)

        mj.mj_forward(model, data)
        viewer.sync()

    def open_viewer(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> mj.viewer.Handle:
        """
        Launch the passive viewer with both the left and right UI panels hidden.

        Args:
            model: The model. If None, the stored one is used.
            data: The data. If None, the stored one is used.

        Returns:
            The handle of the passive viewer.

        Raises:
            ValueError: If no model or data is available.
        """

        model, data = self._resolve(model=model, data=data)

        return mj.viewer.launch_passive(
            model, data, show_left_ui=False, show_right_ui=False
        )

    @contextlib.contextmanager
    def open(
        self,
        model: mj.MjModel | None = None,
        data: mj.MjData | None = None,
        close_on_exit: bool = True,
    ) -> ContextManager[mujoco.viewer.Handle]:
        """
        Context manager that opens the passive viewer and yields its handle.

        Args:
            model: The model. If None, the stored one is used.
            data: The data. If None, the stored one is used.
            close_on_exit: Whether to close the viewer when leaving the context.

        Yields:
            The handle of the passive viewer.
        """

        handle = self.open_viewer(model=model, data=data)

        try:
            yield handle
        finally:
            if close_on_exit:
                handle.close()
|