jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/mujoco/model.py
CHANGED
@@ -1,12 +1,20 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import functools
|
2
4
|
import pathlib
|
5
|
+
from collections.abc import Callable, Sequence
|
3
6
|
from typing import Any
|
4
7
|
|
5
8
|
import mujoco as mj
|
6
9
|
import numpy as np
|
7
10
|
import numpy.typing as npt
|
11
|
+
import xmltodict
|
8
12
|
from scipy.spatial.transform import Rotation
|
9
13
|
|
14
|
+
import jaxsim.typing as jtp
|
15
|
+
|
16
|
+
HeightmapCallable = Callable[[jtp.FloatLike, jtp.FloatLike], jtp.FloatLike]
|
17
|
+
|
10
18
|
|
11
19
|
class MujocoModelHelper:
|
12
20
|
"""
|
@@ -14,34 +22,118 @@ class MujocoModelHelper:
|
|
14
22
|
"""
|
15
23
|
|
16
24
|
def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -> None:
|
17
|
-
"""
|
25
|
+
"""
|
26
|
+
Initialize the MujocoModelHelper object.
|
27
|
+
|
28
|
+
Args:
|
29
|
+
model: A Mujoco model object.
|
30
|
+
data: A Mujoco data object. If None, a new one will be created.
|
31
|
+
"""
|
18
32
|
|
19
33
|
self.model = model
|
20
34
|
self.data = data if data is not None else mj.MjData(self.model)
|
21
35
|
|
22
|
-
# Populate the data with kinematics
|
36
|
+
# Populate the data with kinematics.
|
23
37
|
mj.mj_forward(self.model, self.data)
|
24
38
|
|
25
|
-
# Keep the cache of this method local to improve GC
|
39
|
+
# Keep the cache of this method local to improve GC.
|
26
40
|
self.mask_qpos = functools.cache(self._mask_qpos)
|
27
41
|
|
28
42
|
@staticmethod
def build_from_xml(
    mjcf_description: str | pathlib.Path,
    assets: dict[str, Any] | None = None,
    heightmap: HeightmapCallable | None = None,
    heightmap_name: str = "terrain",
    heightmap_radius_xy: tuple[float, float] = (1.0, 1.0),
) -> MujocoModelHelper:
    """
    Build a Mujoco model from an MJCF description.

    Args:
        mjcf_description:
            A string containing the XML description of the Mujoco model
            or a path to a file containing the XML description.
        assets: An optional dictionary containing the assets of the model.
        heightmap:
            A function in two variables that returns the height of a terrain
            in the specified coordinate point.
        heightmap_name:
            The name of the ``<hfield>`` asset in the MJCF description whose
            configuration (``nrow``, ``ncol``, ``size``) is used to sample
            ``heightmap``.
        heightmap_radius_xy:
            The half-extension of the heightmap along the x and y axes, i.e.
            the sampled grid spans ``[-rx, rx] x [-ry, ry]``.

    Returns:
        A MujocoModelHelper object.

    Raises:
        ValueError:
            If ``heightmap`` is given but the MJCF does not define an
            ``<hfield>`` asset named ``heightmap_name``.
    """

    # Read the XML description if it is a path to file.
    mjcf_description = (
        mjcf_description.read_text()
        if isinstance(mjcf_description, pathlib.Path)
        else mjcf_description
    )

    # The sampled elevation data, only populated when a heightmap is given.
    hfield = None

    if heightmap is not None:

        mjcf_description_dict = xmltodict.parse(xml_input=mjcf_description)

        def as_list(value: Any) -> list[Any]:
            # xmltodict returns a dict for a single child element, a list for
            # repeated elements, and None for empty elements.
            if value is None:
                return []
            return value if isinstance(value, list) else [value]

        mujoco_element = mjcf_description_dict.get("mujoco") or {}

        # Find the hfield element to edit. It is modified in place inside the
        # parsed tree, so all other (possibly unnamed) hfields and all the
        # <asset> sections are preserved untouched when serializing back.
        hfield_element = next(
            (
                hfield_candidate
                for asset_element in as_list(mujoco_element.get("asset"))
                if isinstance(asset_element, dict)
                for hfield_candidate in as_list(asset_element.get("hfield"))
                if isinstance(hfield_candidate, dict)
                and hfield_candidate.get("@name") == heightmap_name
            ),
            None,
        )

        if hfield_element is None:
            raise ValueError(f"Heightmap '{heightmap_name}' not found in MJCF")

        nrow = int(hfield_element["@nrow"])
        ncol = int(hfield_element["@ncol"])

        # Generate the hfield by sampling the heightmap function.
        # Mujoco stores the elevation data row-major, with rows spanning the
        # y axis and columns spanning the x axis: x has `ncol` samples and
        # y has `nrow` samples.
        hfield = generate_hfield(
            heightmap=heightmap,
            samples_xy=(ncol, nrow),
            radius_xy=heightmap_radius_xy,
        )

        # Update dynamically the '/asset/hfield[@name=heightmap_name]@size' attribute
        # with the information of the sampled points.
        # This is necessary for correctly rendering the heightmap over the
        # specified xy area with the correct z elevation.
        size = [float(el) for el in hfield_element["@size"].split()]
        size[0], size[1] = heightmap_radius_xy
        size[2] = 1.0
        # The following could be zero but Mujoco complains if it's exactly zero.
        size[3] = max(0.000_001, -float(np.min(hfield)))

        # Replace the 'size' attribute (in place, inside the parsed tree).
        hfield_element["@size"] = " ".join(str(el) for el in size)

        # Serialize the updated MJCF to XML.
        mjcf_description = xmltodict.unparse(
            input_dict=mjcf_description_dict, pretty=True
        )

    # Create the Mujoco model from the XML and, optionally, the dictionary of assets.
    model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets)
    data = mj.MjData(model)

    # Store the sampled heightmap into the Mujoco model.
    if hfield is not None:
        # `hfield_data` is the concatenation of the data of all the hfields
        # of the model: write only the slice belonging to `heightmap_name`.
        hfield_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_HFIELD, heightmap_name)
        adr = int(model.hfield_adr[hfield_id])
        model.hfield_data[adr : adr + hfield.size] = hfield

    return MujocoModelHelper(model=model, data=data)
|
45
137
|
|
46
138
|
def time(self) -> float:
|
47
139
|
"""Return the simulation time."""
|
@@ -148,9 +240,9 @@ class MujocoModelHelper:
|
|
148
240
|
raise ValueError("The orientation is not a valid element of SO(3)")
|
149
241
|
|
150
242
|
W_Q_B = (
|
151
|
-
Rotation.from_matrix(orientation).as_quat(
|
152
|
-
|
153
|
-
|
243
|
+
Rotation.from_matrix(orientation).as_quat(
|
244
|
+
canonical=True, scalar_first=False
|
245
|
+
)
|
154
246
|
if dcm
|
155
247
|
else orientation
|
156
248
|
)
|
@@ -162,17 +254,17 @@ class MujocoModelHelper:
|
|
162
254
|
# ==================
|
163
255
|
|
164
256
|
def number_of_joints(self) -> int:
|
165
|
-
""""""
|
257
|
+
"""Return the number of joints in the model."""
|
166
258
|
|
167
259
|
return self.model.njnt
|
168
260
|
|
169
261
|
def number_of_dofs(self) -> int:
    """
    Return the size of the generalized position vector of the model.

    Note:
        This is ``model.nq``, i.e. the number of position coordinates stored
        in ``qpos``, which is what the joint-position setters index into.
        It can differ from the number of velocity DoFs (``model.nv``) when
        orientations are stored as quaternions (e.g. a floating base).
    """

    return self.model.nq
|
173
265
|
|
174
266
|
def joint_names(self) -> list[str]:
|
175
|
-
""""""
|
267
|
+
"""Return the names of the joints in the model."""
|
176
268
|
|
177
269
|
return [
|
178
270
|
mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx)
|
@@ -180,7 +272,7 @@ class MujocoModelHelper:
|
|
180
272
|
]
|
181
273
|
|
182
274
|
def joint_dofs(self, joint_name: str) -> int:
|
183
|
-
""""""
|
275
|
+
"""Return the number of DoFs of a joint."""
|
184
276
|
|
185
277
|
if joint_name not in self.joint_names():
|
186
278
|
raise ValueError(f"Joint '{joint_name}' not found")
|
@@ -188,7 +280,7 @@ class MujocoModelHelper:
|
|
188
280
|
return self.data.joint(joint_name).qpos.size
|
189
281
|
|
190
282
|
def joint_position(self, joint_name: str) -> npt.NDArray:
|
191
|
-
""""""
|
283
|
+
"""Return the position of a joint."""
|
192
284
|
|
193
285
|
if joint_name not in self.joint_names():
|
194
286
|
raise ValueError(f"Joint '{joint_name}' not found")
|
@@ -196,7 +288,7 @@ class MujocoModelHelper:
|
|
196
288
|
return self.data.joint(joint_name).qpos
|
197
289
|
|
198
290
|
def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray:
|
199
|
-
""""""
|
291
|
+
"""Return the positions of the joints."""
|
200
292
|
|
201
293
|
joint_names = joint_names if joint_names is not None else self.joint_names()
|
202
294
|
|
@@ -207,7 +299,7 @@ class MujocoModelHelper:
|
|
207
299
|
def set_joint_position(
|
208
300
|
self, joint_name: str, position: npt.NDArray | float
|
209
301
|
) -> None:
|
210
|
-
""""""
|
302
|
+
"""Set the position of a joint."""
|
211
303
|
|
212
304
|
position = np.atleast_1d(np.array(position).squeeze())
|
213
305
|
|
@@ -224,9 +316,9 @@ class MujocoModelHelper:
|
|
224
316
|
self.data.qpos[sl] = position
|
225
317
|
|
226
318
|
def set_joint_positions(
|
227
|
-
self, joint_names:
|
319
|
+
self, joint_names: Sequence[str], positions: npt.NDArray | list[npt.NDArray]
|
228
320
|
) -> None:
|
229
|
-
""""""
|
321
|
+
"""Set the positions of multiple joints."""
|
230
322
|
|
231
323
|
mask = self.mask_qpos(joint_names=tuple(joint_names))
|
232
324
|
self.data.qpos[mask] = positions
|
@@ -236,12 +328,12 @@ class MujocoModelHelper:
|
|
236
328
|
# ==================
|
237
329
|
|
238
330
|
def number_of_bodies(self) -> int:
|
239
|
-
""""""
|
331
|
+
"""Return the number of bodies in the model."""
|
240
332
|
|
241
333
|
return self.model.nbody
|
242
334
|
|
243
335
|
def body_names(self) -> list[str]:
|
244
|
-
""""""
|
336
|
+
"""Return the names of the bodies in the model."""
|
245
337
|
|
246
338
|
return [
|
247
339
|
mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx)
|
@@ -249,7 +341,7 @@ class MujocoModelHelper:
|
|
249
341
|
]
|
250
342
|
|
251
343
|
def body_position(self, body_name: str) -> npt.NDArray:
|
252
|
-
""""""
|
344
|
+
"""Return the position of a body."""
|
253
345
|
|
254
346
|
if body_name not in self.body_names():
|
255
347
|
raise ValueError(f"Body '{body_name}' not found")
|
@@ -257,7 +349,7 @@ class MujocoModelHelper:
|
|
257
349
|
return self.data.body(body_name).xpos
|
258
350
|
|
259
351
|
def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:
|
260
|
-
""""""
|
352
|
+
"""Return the orientation of a body."""
|
261
353
|
|
262
354
|
if body_name not in self.body_names():
|
263
355
|
raise ValueError(f"Body '{body_name}' not found")
|
@@ -271,12 +363,12 @@ class MujocoModelHelper:
|
|
271
363
|
# ======================
|
272
364
|
|
273
365
|
def number_of_geometries(self) -> int:
|
274
|
-
""""""
|
366
|
+
"""Return the number of geometries in the model."""
|
275
367
|
|
276
368
|
return self.model.ngeom
|
277
369
|
|
278
370
|
def geometry_names(self) -> list[str]:
|
279
|
-
""""""
|
371
|
+
"""Return the names of the geometries in the model."""
|
280
372
|
|
281
373
|
return [
|
282
374
|
mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx)
|
@@ -284,7 +376,7 @@ class MujocoModelHelper:
|
|
284
376
|
]
|
285
377
|
|
286
378
|
def geometry_position(self, geometry_name: str) -> npt.NDArray:
|
287
|
-
""""""
|
379
|
+
"""Return the position of a geometry."""
|
288
380
|
|
289
381
|
if geometry_name not in self.geometry_names():
|
290
382
|
raise ValueError(f"Geometry '{geometry_name}' not found")
|
@@ -294,7 +386,7 @@ class MujocoModelHelper:
|
|
294
386
|
def geometry_orientation(
|
295
387
|
self, geometry_name: str, dcm: bool = False
|
296
388
|
) -> npt.NDArray:
|
297
|
-
""""""
|
389
|
+
"""Return the orientation of a geometry."""
|
298
390
|
|
299
391
|
if geometry_name not in self.geometry_names():
|
300
392
|
raise ValueError(f"Geometry '{geometry_name}' not found")
|
@@ -304,8 +396,8 @@ class MujocoModelHelper:
|
|
304
396
|
if dcm:
|
305
397
|
return R
|
306
398
|
|
307
|
-
q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True)
|
308
|
-
return q_xyzw
|
399
|
+
q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True, scalar_first=False)
|
400
|
+
return q_xyzw
|
309
401
|
|
310
402
|
# ===============
|
311
403
|
# Private methods
|
@@ -346,7 +438,45 @@ class MujocoModelHelper:
|
|
346
438
|
for i in range(self.joint_dofs(joint_name=joint_name))
|
347
439
|
]
|
348
440
|
)
|
349
|
-
for idx, joint_name in zip(idxs, joint_names)
|
441
|
+
for idx, joint_name in zip(idxs, joint_names, strict=True)
|
350
442
|
]
|
351
443
|
).squeeze()
|
352
444
|
)
|
445
|
+
|
446
|
+
|
447
|
+
def generate_hfield(
    heightmap: HeightmapCallable,
    samples_xy: tuple[int, int] = (11, 11),
    radius_xy: tuple[float, float] = (1.0, 1.0),
) -> npt.NDArray:
    """
    Generate an array with elevation points sampled from a heightmap function.

    The heightmap is sampled over a regular grid spanning
    ``[-radius_xy[0], radius_xy[0]] x [-radius_xy[1], radius_xy[1]]``, with
    ``nx = samples_xy[0]`` points along x and ``ny = samples_xy[1]`` points
    along y. The samples are arranged in a matrix whose rows span y and whose
    columns span x, and then flattened in row-major order:

    ```
    h(x[0], y[0])      h(x[1], y[0])      ...  h(x[nx-1], y[0])
    h(x[0], y[1])      h(x[1], y[1])      ...  h(x[nx-1], y[1])
    ...
    h(x[0], y[ny-1])   h(x[1], y[ny-1])   ...  h(x[nx-1], y[ny-1])
    ```

    Args:
        heightmap:
            A function that takes two arguments (x, y) and returns the height
            at that point.
        samples_xy:
            The number of samples along the x and y axes, respectively.
        radius_xy:
            The half-extension of the sampled area along the x and y axes,
            respectively.

    Returns:
        A flat array of ``nx * ny`` elevations, where the element at index
        ``j * nx + i`` is ``heightmap(x[i], y[j])``.
    """

    # Generate the grid.
    x = np.linspace(-radius_xy[0], radius_xy[0], samples_xy[0])
    y = np.linspace(-radius_xy[1], radius_xy[1], samples_xy[1])

    # Generate the heightmap: one row per y sample, one column per x sample.
    return np.array([[heightmap(xi, yi) for xi in x] for yi in y]).flatten()
|
jaxsim/mujoco/utils.py
ADDED
@@ -0,0 +1,228 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from collections.abc import Sequence
|
5
|
+
|
6
|
+
import mujoco as mj
|
7
|
+
import numpy as np
|
8
|
+
import numpy.typing as npt
|
9
|
+
from scipy.spatial.transform import Rotation
|
10
|
+
|
11
|
+
from .model import MujocoModelHelper
|
12
|
+
|
13
|
+
|
14
|
+
def mujoco_data_from_jaxsim(
    mujoco_model: mj.MjModel,
    jaxsim_model,
    jaxsim_data,
    mujoco_data: mj.MjData | None = None,
    update_removed_joints: bool = True,
) -> mj.MjData:
    """
    Create a Mujoco data object from a JaxSim model and data objects.

    Args:
        mujoco_model: The Mujoco model object corresponding to the JaxSim model.
        jaxsim_model: The JaxSim model object from which the Mujoco model was created.
        jaxsim_data: The JaxSim data object containing the state of the model.
        mujoco_data: An optional Mujoco data object. If None, a new one will be created.
        update_removed_joints:
            If True, the positions of the joints that have been removed during the
            model reduction process will be set to their initial values.

    Returns:
        The Mujoco data object containing the state of the JaxSim model.

    Raises:
        ValueError:
            If `jaxsim_model` is not a JaxSimModel or `jaxsim_data` is not a
            JaxSimModelData.

    Note:
        This method is useful to initialize a Mujoco data object used for visualization
        with the state of a JaxSim model. In particular, this function takes care of
        initializing the positions of the joints that have been removed during the
        model reduction process. After the initial creation of the Mujoco data object,
        it's faster to update the state using an external MujocoModelHelper object.
    """

    # The package `jaxsim.mujoco` is supposed to be jax-independent.
    # We import all the JaxSim resources privately.
    import jaxsim.api as js

    if not isinstance(jaxsim_model, js.model.JaxSimModel):
        raise ValueError("The `jaxsim_model` argument must be a JaxSimModel object.")

    if not isinstance(jaxsim_data, js.data.JaxSimModelData):
        raise ValueError("The `jaxsim_data` argument must be a JaxSimModelData object.")

    # Create the helper to operate on the Mujoco model and data.
    model_helper = MujocoModelHelper(model=mujoco_model, data=mujoco_data)

    # If the model is fixed-base, the Mujoco model won't have the joint corresponding
    # to the floating base, and the helper would raise an exception.
    if jaxsim_model.floating_base():

        # Set the model position.
        model_helper.set_base_position(position=np.array(jaxsim_data.base_position()))

        # Set the model orientation.
        model_helper.set_base_orientation(
            orientation=np.array(jaxsim_data.base_orientation())
        )

    # Set the joint positions.
    if jaxsim_model.dofs() > 0:

        # Query the joint names once and reuse them for both arguments.
        jaxsim_joint_names = jaxsim_model.joint_names()

        model_helper.set_joint_positions(
            joint_names=list(jaxsim_joint_names),
            positions=np.array(
                jaxsim_data.joint_positions(
                    model=jaxsim_model, joint_names=jaxsim_joint_names
                )
            ),
        )

    # Updating these joints is not necessary after the first time.
    # Users can disable this update after initialization.
    if update_removed_joints:

        # Build the membership sets once instead of once per joint.
        jaxsim_joint_names_set = set(jaxsim_model.joint_names())
        mujoco_joint_names_set = set(model_helper.joint_names())

        # Create a dictionary with the joints that have been removed for various reasons
        # (like link lumping due to model reduction).
        joints_removed_dict = {
            j.name: j
            for j in jaxsim_model.description._joints_removed
            if j.name not in jaxsim_joint_names_set
        }

        # Set the positions of the removed joints that are still present
        # in the Mujoco model.
        for joint_name, joint in joints_removed_dict.items():

            if joint_name not in mujoco_joint_names_set:
                continue

            model_helper.set_joint_position(
                position=joint.initial_position,
                joint_name=joint_name,
            )

    # Return the mujoco data with updated kinematics.
    mj.mj_forward(mujoco_model, model_helper.data)

    return model_helper.data
|
109
|
+
|
110
|
+
|
111
|
+
@dataclasses.dataclass
class MujocoCamera:
    """
    Helper class storing parameters of a Mujoco camera.

    All the parameters are stored as strings, ready to be used as attributes
    of a ``<camera>`` MJCF element.

    Refer to the official documentation for more details:
    https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera
    """

    mode: str = "fixed"

    target: str | None = None
    fovy: str = "45"
    pos: str = "0 0 0"

    quat: str | None = None
    axisangle: str | None = None
    xyaxes: str | None = None
    zaxis: str | None = None
    euler: str | None = None

    name: str | None = None

    @classmethod
    def build(cls, **kwargs) -> MujocoCamera:
        """
        Build a Mujoco camera from keyword arguments.

        Args:
            **kwargs: The camera parameters, named as the dataclass fields.

        Returns:
            The Mujoco camera.

        Raises:
            ValueError: If any of the passed values is not a string.
        """

        if not all(isinstance(value, str) for value in kwargs.values()):
            raise ValueError(f"Values must be strings: {kwargs}")

        return cls(**kwargs)

    @staticmethod
    def build_from_target_view(
        camera_name: str,
        lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0),
        distance: float | int | npt.NDArray = 3,
        azimuth: float | int | npt.NDArray = 90,
        elevation: float | int | npt.NDArray = -45,
        fovy: float | int | npt.NDArray = 45,
        degrees: bool = True,
        **kwargs,
    ) -> MujocoCamera:
        """
        Create a custom camera that looks at a target point.

        Note:
            The choice of the parameters is easier if we imagine a target
            frame `T` whose origin is located at the lookat point and having the same
            orientation as the world frame `W`. We also introduce a camera frame `C`
            whose origin is located at the lower-left corner of the image, and having
            the x-axis pointing right and the y-axis pointing up in image coordinates.
            The camera renders what it sees in the -z direction of frame `C`.

        Args:
            camera_name: The name of the camera.
            lookat: The target point to look at (origin of `T`).
            distance:
                The distance from the target point (displacement between the origins
                of `T` and `C`).
            azimuth:
                The rotation around z of the camera. With an angle of 0, the camera
                would look at the target point towards the positive x-axis of `T`.
            elevation:
                The rotation around the x-axis of the camera frame `C`. Note that if
                you want to lift the view angle, the elevation is negative.
            fovy: The field of view of the camera.
            degrees: Whether the angles are in degrees or radians.
            **kwargs: Additional camera parameters (as strings).

        Returns:
            The custom camera.
        """

        # Start from a frame whose origin is located at the lookat point.
        # The initial rotation (-90 deg around z, then 90 deg around the new x)
        # aligns the camera frame convention (x right, y up, looking along -z)
        # so that the camera looks towards the positive x-axis of the world.
        W_H_C = np.eye(4)
        W_H_C[0:3, 3] = np.array(lookat)
        W_H_C[0:3, 0:3] = Rotation.from_euler(
            seq="ZX", angles=[-90, 90], degrees=True
        ).as_matrix()

        # Process the azimuth (rotation around the camera y-axis, which is
        # aligned with the world z-axis after the initial rotation).
        R_az = Rotation.from_euler(seq="Y", angles=azimuth, degrees=degrees).as_matrix()
        W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az

        # Process elevation.
        R_el = Rotation.from_euler(
            seq="X", angles=elevation, degrees=degrees
        ).as_matrix()
        W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el

        # Process distance: move backwards along the camera z-axis.
        tf_distance = np.eye(4)
        tf_distance[2, 3] = distance
        W_H_C = W_H_C @ tf_distance

        # Extract the position and the quaternion (Mujoco uses wxyz).
        p = W_H_C[0:3, 3]
        Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True)

        return MujocoCamera.build(
            name=camera_name,
            mode="fixed",
            fovy=f"{fovy if degrees else np.rad2deg(fovy)}",
            pos=" ".join(p.astype(str).tolist()),
            quat=" ".join(Q.astype(str).tolist()),
            **kwargs,
        )

    def asdict(self) -> dict[str, str]:
        """
        Convert the camera to a dictionary, omitting the unset (None) parameters.

        Returns:
            A dictionary mapping the parameter names to their string values.
        """

        return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
|