jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1.dev401.dist-info/METADATA +0 -167
- jaxsim-0.1.dev401.dist-info/RECORD +0 -64
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/mujoco/model.py
ADDED
@@ -0,0 +1,414 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import functools
|
4
|
+
import pathlib
|
5
|
+
from typing import Any, Callable
|
6
|
+
|
7
|
+
import mujoco as mj
|
8
|
+
import numpy as np
|
9
|
+
import numpy.typing as npt
|
10
|
+
from scipy.spatial.transform import Rotation
|
11
|
+
|
12
|
+
import jaxsim.typing as jtp
|
13
|
+
|
14
|
+
# Type of a terrain height function: it takes two planar coordinates (x, y) and
# returns the terrain height at that point (see `generate_hfield`).
HeightmapCallable = Callable[[jtp.FloatLike, jtp.FloatLike], jtp.FloatLike]
|
15
|
+
|
16
|
+
|
17
|
+
class MujocoModelHelper:
    """
    Helper class to create and interact with Mujoco models and data objects.
    """

    def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -> None:
        """
        Initialize the MujocoModelHelper object.

        Args:
            model: A Mujoco model object.
            data: A Mujoco data object. If None, a new one will be created.
        """

        self.model = model
        self.data = data if data is not None else mj.MjData(self.model)

        # Populate the data with kinematics
        mj.mj_forward(self.model, self.data)

        # Keep the cache of this method local to improve GC
        self.mask_qpos = functools.cache(self._mask_qpos)

    @staticmethod
    def build_from_xml(
        mjcf_description: str | pathlib.Path,
        assets: dict[str, Any] | None = None,
        heightmap: HeightmapCallable | None = None,
    ) -> MujocoModelHelper:
        """
        Build a Mujoco model from an XML description and an optional assets dictionary.

        Args:
            mjcf_description: A string containing the XML description of the Mujoco model
                or a path to a file containing the XML description.
            assets: An optional dictionary containing the assets of the model.
            heightmap: A function in two variables that returns the height of a terrain
                in the specified coordinate point. The model must define exactly one
                height field, whose data is overwritten with the sampled heightmap.

        Returns:
            A MujocoModelHelper object.

        Raises:
            ValueError: If a heightmap is given but the model does not define exactly
                one height field.
        """

        # Read the XML description if it's a path to file
        mjcf_description = (
            mjcf_description.read_text()
            if isinstance(mjcf_description, pathlib.Path)
            else mjcf_description
        )

        # Create the Mujoco model from the XML and, optionally, the assets dictionary
        model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets)  # noqa

        if heightmap is not None:
            # The sizes below are read with `.item()`, which only works when the
            # model defines a single height field: fail early with a clear message.
            if model.nhfield != 1:
                raise ValueError(
                    "A heightmap requires exactly one height field in the model, "
                    f"found {model.nhfield}"
                )

            nrow = model.hfield_nrow.item()
            ncol = model.hfield_ncol.item()
            new_hfield = generate_hfield(heightmap, (nrow, ncol))
            model.hfield_data = new_hfield

        # Create the data only after the model has been finalized.
        return MujocoModelHelper(model=model, data=mj.MjData(model))

    def time(self) -> float:
        """Return the simulation time."""

        return self.data.time

    def timestep(self) -> float:
        """Return the simulation timestep."""

        return self.model.opt.timestep

    def gravity(self) -> npt.NDArray:
        """Return the 3D gravity vector."""

        return self.model.opt.gravity

    # =========================
    # Methods for the base link
    # =========================

    def is_floating_base(self) -> bool:
        """Return true if the model is floating-base."""

        # A body with no joints is considered a fixed-base model.
        # In fact, in mujoco, a floating-base model has a 6 DoFs first joint.
        if self.number_of_joints() == 0:
            return False

        # We just check that the first joint has 6 DoFs.
        joint0_type = self.model.jnt_type[0]
        return bool(joint0_type == mj.mjtJoint.mjJNT_FREE)

    def is_fixed_base(self) -> bool:
        """Return true if the model is fixed-base."""

        return not self.is_floating_base()

    def base_link(self) -> str:
        """Return the name of the base link."""

        return mj.mj_id2name(
            self.model, mj.mjtObj.mjOBJ_BODY, 0 if self.is_fixed_base() else 1
        )

    def base_position(self) -> npt.NDArray:
        """Return the 3D position of the base link."""

        return (
            self.data.qpos[:3]
            if self.is_floating_base()
            else self.body_position(body_name=self.base_link())
        )

    def base_orientation(self, dcm: bool = False) -> npt.NDArray:
        """
        Return the orientation of the base link.

        Args:
            dcm: If True, return the 3x3 rotation matrix; otherwise, return the
                quaternion in wxyz order.

        Returns:
            The orientation of the base link.
        """

        if self.is_fixed_base():
            orientation = self.body_orientation(body_name=self.base_link(), dcm=dcm)

            # `body_orientation` returns the body's `xmat` entry as-is; reshape it so
            # that both branches of this method return a 3x3 matrix when dcm=True.
            return np.reshape(orientation, (3, 3)) if dcm else orientation

        # For floating-base models the base pose lives in the free joint's qpos,
        # like the position returned by `base_position`. The quaternion is stored
        # in wxyz order (as written by `set_base_orientation`). Note that
        # `base_link()` identifies the floating base as body 1, so indexing body 0
        # of `data.xquat`/`data.xmat` would not describe the base link.
        W_Q_B = self.data.qpos[3:7]

        if not dcm:
            return W_Q_B

        # scipy expects the quaternion in xyzw order.
        return Rotation.from_quat(W_Q_B[np.array([1, 2, 3, 0])]).as_matrix()

    def set_base_position(self, position: npt.NDArray) -> None:
        """Set the 3D position of the base link."""

        if self.is_fixed_base():
            raise ValueError("The position of a fixed-base model cannot be set.")

        position = np.atleast_1d(np.array(position).squeeze())

        if position.size != 3:
            raise ValueError(f"Wrong position size ({position.size})")

        self.data.qpos[:3] = position

    def set_base_orientation(self, orientation: npt.NDArray, dcm: bool = False) -> None:
        """
        Set the orientation of the base link.

        Args:
            orientation: Either a wxyz quaternion or, if `dcm` is True, a 3x3
                rotation matrix.
            dcm: Whether `orientation` is a rotation matrix.

        Raises:
            ValueError: If the model is fixed-base, or if the orientation has the
                wrong shape or is not a valid element of SO(3).
        """

        if self.is_fixed_base():
            raise ValueError("The orientation of a fixed-base model cannot be set.")

        orientation = (
            np.atleast_2d(np.array(orientation).squeeze())
            if dcm
            else np.atleast_1d(np.array(orientation).squeeze())
        )

        if orientation.shape != ((4,) if not dcm else (3, 3)):
            raise ValueError(f"Wrong orientation shape {orientation.shape}")

        def is_quaternion(Q):
            return np.allclose(np.linalg.norm(Q), 1.0)

        def is_dcm(R):
            return np.allclose(np.linalg.det(R), 1.0) and np.allclose(
                R.T @ R, np.eye(3)
            )

        if not (is_quaternion(orientation) if not dcm else is_dcm(orientation)):
            raise ValueError("The orientation is not a valid element of SO(3)")

        # Convert a rotation matrix to a wxyz quaternion (scipy returns xyzw).
        W_Q_B = (
            Rotation.from_matrix(orientation).as_quat(canonical=True)[
                np.array([3, 0, 1, 2])
            ]
            if dcm
            else orientation
        )

        self.data.qpos[3:7] = W_Q_B

    # ==================
    # Methods for joints
    # ==================

    def number_of_joints(self) -> int:
        """Returns the number of joints in the model."""

        return self.model.njnt

    def number_of_dofs(self) -> int:
        """
        Returns the number of DoFs in the model.

        Note:
            This is the size of the `qpos` array (`model.nq`), consistent with
            `joint_dofs`, which also counts `qpos` entries.
        """

        return self.model.nq

    def joint_names(self) -> list[str]:
        """Returns the names of the joints in the model (excluding the free joint)."""

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx)
            for idx in range(0 if self.is_fixed_base() else 1, self.number_of_joints())
        ]

    def joint_dofs(self, joint_name: str) -> int:
        """Returns the number of DoFs of a joint (its number of `qpos` entries)."""

        if joint_name not in self.joint_names():
            raise ValueError(f"Joint '{joint_name}' not found")

        return self.data.joint(joint_name).qpos.size

    def joint_position(self, joint_name: str) -> npt.NDArray:
        """Returns the position of a joint."""

        if joint_name not in self.joint_names():
            raise ValueError(f"Joint '{joint_name}' not found")

        return self.data.joint(joint_name).qpos

    def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray:
        """Returns the positions of the joints (all joints if `joint_names` is None)."""

        joint_names = joint_names if joint_names is not None else self.joint_names()

        return np.hstack(
            [self.joint_position(joint_name) for joint_name in joint_names]
        )

    def set_joint_position(
        self, joint_name: str, position: npt.NDArray | float
    ) -> None:
        """Sets the position of a joint."""

        position = np.atleast_1d(np.array(position).squeeze())

        if position.size != self.joint_dofs(joint_name=joint_name):
            raise ValueError(
                f"Wrong position size ({position.size}) of "
                f"{self.joint_dofs(joint_name=joint_name)}-DoFs joint '{joint_name}'."
            )

        idx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)
        offset = self.model.jnt_qposadr[idx]

        sl = np.s_[offset : offset + self.joint_dofs(joint_name=joint_name)]
        self.data.qpos[sl] = position

    def set_joint_positions(
        self, joint_names: list[str], positions: npt.NDArray | list[npt.NDArray]
    ) -> None:
        """Set the positions of multiple joints."""

        mask = self.mask_qpos(joint_names=tuple(joint_names))
        self.data.qpos[mask] = positions

    # ==================
    # Methods for bodies
    # ==================

    def number_of_bodies(self) -> int:
        """Returns the number of bodies in the model."""

        return self.model.nbody

    def body_names(self) -> list[str]:
        """Returns the names of the bodies in the model."""

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx)
            for idx in range(self.number_of_bodies())
        ]

    def body_position(self, body_name: str) -> npt.NDArray:
        """Returns the position of a body."""

        if body_name not in self.body_names():
            raise ValueError(f"Body '{body_name}' not found")

        return self.data.body(body_name).xpos

    def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:
        """
        Returns the orientation of a body.

        When `dcm` is True, the body's `xmat` entry is returned as stored in the
        data object (not reshaped); otherwise its `xquat` quaternion is returned.
        """

        if body_name not in self.body_names():
            raise ValueError(f"Body '{body_name}' not found")

        return (
            self.data.body(body_name).xmat if dcm else self.data.body(body_name).xquat
        )

    # ======================
    # Methods for geometries
    # ======================

    def number_of_geometries(self) -> int:
        """Returns the number of geometries in the model."""

        return self.model.ngeom

    def geometry_names(self) -> list[str]:
        """Returns the names of the geometries in the model."""

        return [
            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx)
            for idx in range(self.number_of_geometries())
        ]

    def geometry_position(self, geometry_name: str) -> npt.NDArray:
        """Returns the position of a geometry."""

        if geometry_name not in self.geometry_names():
            raise ValueError(f"Geometry '{geometry_name}' not found")

        return self.data.geom(geometry_name).xpos

    def geometry_orientation(
        self, geometry_name: str, dcm: bool = False
    ) -> npt.NDArray:
        """Returns the orientation of a geometry (3x3 matrix or wxyz quaternion)."""

        if geometry_name not in self.geometry_names():
            raise ValueError(f"Geometry '{geometry_name}' not found")

        # Positional shape argument: the `newshape=` keyword is deprecated in NumPy.
        R = np.reshape(self.data.geom(geometry_name).xmat, (3, 3))

        if dcm:
            return R

        q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True)
        return q_xyzw[[3, 0, 1, 2]]

    # ===============
    # Private methods
    # ===============

    def _mask_qpos(self, joint_names: tuple[str, ...]) -> npt.NDArray:
        """
        Create a mask to access the DoFs of the desired `joint_names` in the `qpos` array.

        Args:
            joint_names: A tuple containing the names of the joints.

        Returns:
            A 1D array containing the indices of the `qpos` array to access the DoFs of
            the desired `joint_names`. It is empty if `joint_names` is empty.

        Note:
            This method takes a tuple of strings because we cache the output mask for
            each combination of joint names. We need a hashable object for the cache.
        """

        # np.hstack requires at least one array: handle the empty selection here.
        if len(joint_names) == 0:
            return np.array([], dtype=int)

        # Get the indices of the joints in `joint_names`.
        idxs = [
            mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)
            for joint_name in joint_names
        ]

        # We first get the index of each joint in the qpos array, and for those that
        # have multiple DoFs, we expand their mask by appending new elements.
        # Finally, we flatten the list of arrays to a single array, that is the
        # final qpos mask accessing all the DoFs of the desired `joint_names`.
        return np.atleast_1d(
            np.hstack(
                [
                    np.array(
                        [
                            self.model.jnt_qposadr[idx] + i
                            for i in range(self.joint_dofs(joint_name=joint_name))
                        ]
                    )
                    for idx, joint_name in zip(idxs, joint_names)
                ]
            ).squeeze()
        )
|
385
|
+
|
386
|
+
|
387
|
+
def generate_hfield(
    heightmap: HeightmapCallable, size: tuple[int, int] = (10, 10)
) -> npt.NDArray:
    """
    Generate a flattened numpy array representing a terrain heightmap.

    The heightmap is sampled on a regular grid over the unit square, with
    ``size[0]`` rows spanning the ``y`` coordinate and ``size[1]`` columns spanning
    the ``x`` coordinate. The returned array is the row-major flattening of:

    ```
    h[0, 0]          h[0, 1]          ... h[0, size[1]-1]
    h[1, 0]          h[1, 1]          ... h[1, size[1]-1]
    ...
    h[size[0]-1, 0]  h[size[0]-1, 1]  ... h[size[0]-1, size[1]-1]
    ```

    where ``h[i, j] = heightmap(x[j], y[i])``, ``y = linspace(0, 1, size[0])`` and
    ``x = linspace(0, 1, size[1])``.

    Args:
        heightmap: A function that takes two arguments (x, y) and returns the height
            at that point.
        size: A tuple ``(nrow, ncol)`` with the number of rows and columns of the grid.

    Returns:
        np.ndarray: The flattened terrain heightmap, with ``size[0] * size[1]`` elements.
    """

    nrow, ncol = size

    # Rows span y and columns span x, so that the flattened array has `nrow`
    # rows of `ncol` elements each (previously the two sizes were swapped, which
    # transposed the layout of non-square grids).
    y = np.linspace(0, 1, nrow)
    x = np.linspace(0, 1, ncol)

    # Generate the heightmap row by row.
    return np.array([[heightmap(xi, yi) for xi in x] for yi in y]).flatten()
|
@@ -0,0 +1,176 @@
|
|
1
|
+
import contextlib
|
2
|
+
import pathlib
|
3
|
+
from typing import ContextManager
|
4
|
+
|
5
|
+
import mediapy as media
|
6
|
+
import mujoco as mj
|
7
|
+
import mujoco.viewer
|
8
|
+
import numpy.typing as npt
|
9
|
+
|
10
|
+
|
11
|
+
class MujocoVideoRecorder:
    """
    Render frames of a Mujoco simulation with an offscreen renderer and write them
    to a video file.
    """

    def __init__(
        self,
        model: mj.MjModel,
        data: mj.MjData,
        fps: int = 30,
        width: int | None = None,
        height: int | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the Mujoco video recorder.

        Args:
            model: The Mujoco model.
            data: The Mujoco data.
            fps: The frames per second.
            width: The width of the video. Defaults to the model's offscreen width.
            height: The height of the video. Defaults to the model's offscreen height.
            **kwargs: Additional arguments for the renderer.
        """

        width = width if width is not None else model.vis.global_.offwidth
        height = height if height is not None else model.vis.global_.offheight

        self.fps = fps
        self.frames: list[npt.NDArray] = []
        self.data: mujoco.MjData | None = None
        self.model: mujoco.MjModel | None = None

        # Arguments used to (re)build the renderer. Entries of `kwargs` take
        # precedence over the default width and height.
        self._renderer_kwargs = dict(width=width, height=height) | kwargs
        self.renderer: mujoco.Renderer | None = None

        # Sets the model and data, and builds the renderer for the model.
        self.reset(model=model, data=data)

    def reset(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> None:
        """
        Reset the model and data, and clear the recorded frames.

        The renderer is created for a specific model, so passing a model different
        from the current one also rebuilds the renderer.
        """

        self.frames = []

        self.data = data if data is not None else self.data

        if model is not None and model is not self.model:
            self.model = model
            self._build_renderer()

    def _build_renderer(self) -> None:
        """Create the offscreen renderer for the current model."""

        width = self._renderer_kwargs["width"]
        height = self._renderer_kwargs["height"]

        # Match the model's offscreen buffer size to the requested frame size.
        if self.model.vis.global_.offwidth != width:
            self.model.vis.global_.offwidth = width

        if self.model.vis.global_.offheight != height:
            self.model.vis.global_.offheight = height

        self.renderer = mujoco.Renderer(model=self.model, **self._renderer_kwargs)

    def render_frame(self, camera_name: str | None = None) -> npt.NDArray:
        """Renders a frame from the given camera (default: "track")."""
        camera_name = camera_name or "track"

        mujoco.mj_forward(self.model, self.data)
        self.renderer.update_scene(data=self.data, camera=camera_name)

        return self.renderer.render()

    def record_frame(self, camera_name: str | None = None) -> None:
        """Renders a frame and stores it in the buffer."""
        camera_name = camera_name or "track"

        frame = self.render_frame(camera_name=camera_name)
        self.frames.append(frame)

    def write_video(self, path: pathlib.Path | str, exist_ok: bool = False) -> None:
        """
        Writes the recorded frames to a video file.

        Args:
            path: The path of the video file.
            exist_ok: Whether an existing file can be overwritten.

        Raises:
            IsADirectoryError: If `path` is a directory.
            FileExistsError: If `path` exists and `exist_ok` is False.
        """

        # Accept plain strings as well as pathlib.Path objects.
        path = pathlib.Path(path)

        if path.is_dir():
            raise IsADirectoryError(f"The path '{path}' is a directory.")

        if not exist_ok and path.is_file():
            raise FileExistsError(f"The file '{path}' already exists.")

        media.write_video(path=path, images=self.frames, fps=self.fps)

    @staticmethod
    def compute_down_sampling(original_fps: int, target_min_fps: int) -> int:
        """
        Return the integer down-sampling factor to reach at least the target fps.

        The returned factor is the largest one that divides `original_fps` exactly
        while keeping the resulting fps not lower than `target_min_fps`.

        Args:
            original_fps: The original fps.
            target_min_fps: The target minimum fps.

        Returns:
            The down-sampling factor.

        Raises:
            ValueError: If `target_min_fps` is not positive (the search would never
                terminate).
        """

        if target_min_fps <= 0:
            raise ValueError(f"The target fps must be positive ({target_min_fps})")

        down_sampling = 1
        down_sampling_final = down_sampling

        while original_fps / (down_sampling + 1) >= target_min_fps:
            down_sampling = down_sampling + 1

            # Only keep factors that produce an integer fps.
            if original_fps % down_sampling == 0:
                down_sampling_final = down_sampling

        return down_sampling_final
|
115
|
+
|
116
|
+
|
117
|
+
class MujocoVisualizer:
    """
    Interactive visualizer of Mujoco models based on the passive `mujoco.viewer`.

    The model and data given to the constructor are used as defaults by all the
    methods, and can be overridden on each call.
    """

    def __init__(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> None:
        """
        Initialize the Mujoco visualizer.

        Args:
            model: The Mujoco model.
            data: The Mujoco data.
        """

        self.data = data
        self.model = model

    def _resolve(
        self, model: mj.MjModel | None, data: mj.MjData | None
    ) -> tuple[mj.MjModel, mj.MjData]:
        """Return the given model and data, falling back to the stored ones."""

        data = data if data is not None else self.data
        model = model if model is not None else self.model

        # Without this check, None would be forwarded to Mujoco and fail there with
        # a much less informative TypeError.
        if model is None or data is None:
            raise TypeError(
                "A Mujoco model and data object are required, either passed to the "
                "constructor or to this method."
            )

        return model, data

    def sync(
        self,
        viewer: mujoco.viewer.Handle,
        model: mj.MjModel | None = None,
        data: mj.MjData | None = None,
    ) -> None:
        """Updates the viewer with the current model and data."""

        model, data = self._resolve(model=model, data=data)

        mj.mj_forward(model, data)
        viewer.sync()

    def open_viewer(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> mj.viewer.Handle:
        """Opens a passive viewer without the side UI panels."""

        model, data = self._resolve(model=model, data=data)

        handle = mj.viewer.launch_passive(
            model, data, show_left_ui=False, show_right_ui=False
        )

        return handle

    @contextlib.contextmanager
    def open(
        self,
        model: mj.MjModel | None = None,
        data: mj.MjData | None = None,
        close_on_exit: bool = True,
    ) -> ContextManager[mujoco.viewer.Handle]:
        """
        Context manager to open a viewer.

        Args:
            model: The Mujoco model (defaults to the stored one).
            data: The Mujoco data (defaults to the stored one).
            close_on_exit: Whether to close the viewer when leaving the context.
        """

        handle = self.open_viewer(model=model, data=data)

        try:
            yield handle
        finally:
            if close_on_exit:
                handle.close()
|
@@ -50,6 +50,14 @@ class CollidablePoint:
|
|
50
50
|
enabled=self.enabled,
|
51
51
|
)
|
52
52
|
|
53
|
+
def __eq__(self, other):
|
54
|
+
retval = (
|
55
|
+
self.parent_link == other.parent_link
|
56
|
+
and (self.position == other.position).all()
|
57
|
+
and self.enabled == other.enabled
|
58
|
+
)
|
59
|
+
return retval
|
60
|
+
|
53
61
|
def __str__(self):
|
54
62
|
return (
|
55
63
|
f"{self.__class__.__name__}("
|
@@ -93,6 +101,9 @@ class BoxCollision(CollisionShape):
|
|
93
101
|
|
94
102
|
center: npt.NDArray
|
95
103
|
|
104
|
+
def __eq__(self, other):
    """
    Compare two box collisions by center and by the fields compared by the parent
    class; return NotImplemented if `other` is not of the same type.
    """

    import numpy as np

    if not isinstance(other, type(self)):
        return NotImplemented

    # np.array_equal returns False on shape mismatch instead of failing.
    return np.array_equal(self.center, other.center) and super().__eq__(other)
|
106
|
+
|
96
107
|
|
97
108
|
@dataclasses.dataclass
|
98
109
|
class SphereCollision(CollisionShape):
|
@@ -105,3 +116,6 @@ class SphereCollision(CollisionShape):
|
|
105
116
|
"""
|
106
117
|
|
107
118
|
center: npt.NDArray
|
119
|
+
|
120
|
+
def __eq__(self, other):
    """
    Compare two sphere collisions by center and by the fields compared by the
    parent class; return NotImplemented if `other` is not of the same type.
    """

    import numpy as np

    if not isinstance(other, type(self)):
        return NotImplemented

    # np.array_equal returns False on shape mismatch instead of failing.
    return np.array_equal(self.center, other.center) and super().__eq__(other)
|
@@ -3,10 +3,10 @@ from typing import List
|
|
3
3
|
|
4
4
|
import jax.numpy as jnp
|
5
5
|
import jax_dataclasses
|
6
|
+
import jaxlie
|
6
7
|
from jax_dataclasses import Static
|
7
8
|
|
8
9
|
import jaxsim.typing as jtp
|
9
|
-
from jaxsim.sixd import se3
|
10
10
|
from jaxsim.utils import JaxsimDataclass
|
11
11
|
|
12
12
|
|
@@ -38,6 +38,17 @@ class LinkDescription(JaxsimDataclass):
|
|
38
38
|
def __hash__(self) -> int:
|
39
39
|
return hash(self.__repr__())
|
40
40
|
|
41
|
+
def __eq__(self, other) -> bool:
|
42
|
+
return (
|
43
|
+
self.name == other.name
|
44
|
+
and self.mass == other.mass
|
45
|
+
and (self.inertia == other.inertia).all()
|
46
|
+
and self.index == other.index
|
47
|
+
and self.parent == other.parent
|
48
|
+
and (self.pose == other.pose).all()
|
49
|
+
and self.children == other.children
|
50
|
+
)
|
51
|
+
|
41
52
|
@property
|
42
53
|
def name_and_index(self) -> str:
|
43
54
|
"""
|
@@ -67,7 +78,7 @@ class LinkDescription(JaxsimDataclass):
|
|
67
78
|
I_removed = link.inertia
|
68
79
|
|
69
80
|
# Create the SE3 object. Note the inverse.
|
70
|
-
r_H_l =
|
81
|
+
r_H_l = jaxlie.SE3.from_matrix(lumped_H_removed).inverse()
|
71
82
|
r_X_l = r_H_l.adjoint()
|
72
83
|
|
73
84
|
# Move the inertia
|
@@ -34,6 +34,11 @@ class RootPose(NamedTuple):
|
|
34
34
|
root_position: npt.NDArray = np.zeros(3)
|
35
35
|
root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0])
|
36
36
|
|
37
|
+
def __eq__(self, other):
|
38
|
+
return (self.root_position == other.root_position).all() and (
|
39
|
+
self.root_quaternion == other.root_quaternion
|
40
|
+
).all()
|
41
|
+
|
37
42
|
|
38
43
|
@dataclasses.dataclass(frozen=True)
|
39
44
|
class KinematicGraph:
|
@@ -117,7 +122,7 @@ class KinematicGraph:
|
|
117
122
|
|
118
123
|
# Check that joint indices are unique
|
119
124
|
assert len([j.index for j in self.joints]) == len(
|
120
|
-
|
125
|
+
{j.index for j in self.joints}
|
121
126
|
)
|
122
127
|
|
123
128
|
# Order joints with their indices
|
@@ -263,12 +268,12 @@ class KinematicGraph:
|
|
263
268
|
|
264
269
|
# Return early if there is no action to take
|
265
270
|
if len(joint_names_to_remove) == 0:
|
266
|
-
logging.info(
|
271
|
+
logging.info("The kinematic graph doesn't need to be reduced")
|
267
272
|
return copy.deepcopy(self)
|
268
273
|
|
269
274
|
# Check if all considered joints are part of the full kinematic graph
|
270
275
|
if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
|
271
|
-
extra_j = set(considered_joints) -
|
276
|
+
extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
|
272
277
|
msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
|
273
278
|
raise ValueError(msg)
|
274
279
|
|