jaxsim 0.2.dev56__py3-none-any.whl → 0.2.dev77__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.dev56'
16
- __version_tuple__ = version_tuple = (0, 2, 'dev56')
15
+ __version__ = version = '0.2.dev77'
16
+ __version_tuple__ = version_tuple = (0, 2, 'dev77')
@@ -0,0 +1,3 @@
1
+ from .loaders import RodModelToMjcf, SdfToMjcf, UrdfToMjcf
2
+ from .model import MujocoModelHelper
3
+ from .visualizer import MujocoVisualizer
@@ -0,0 +1,192 @@
1
+ import argparse
2
+ import pathlib
3
+ import sys
4
+ import time
5
+
6
+ import numpy as np
7
+
8
+ from . import MujocoModelHelper, MujocoVisualizer, SdfToMjcf, UrdfToMjcf
9
+
10
+ if __name__ == "__main__":
11
+
12
+ parser = argparse.ArgumentParser(
13
+ prog="jaxsim.mujoco",
14
+ description="Process URDF and SDF files for Mujoco usage.",
15
+ )
16
+
17
+ parser.add_argument(
18
+ "-d",
19
+ "--description",
20
+ required=True,
21
+ metavar="INPUT_FILE",
22
+ type=pathlib.Path,
23
+ help="Path to the URDF or SDF file.",
24
+ )
25
+
26
+ parser.add_argument(
27
+ "-m",
28
+ "--model-name",
29
+ metavar="NAME",
30
+ type=str,
31
+ default=None,
32
+ help="The target model of a SDF description if multiple models exists.",
33
+ )
34
+
35
+ parser.add_argument(
36
+ "-e",
37
+ "--export",
38
+ metavar="MJCF_FILE",
39
+ type=pathlib.Path,
40
+ default=None,
41
+ help="Path to the exported MJCF file.",
42
+ )
43
+
44
+ parser.add_argument(
45
+ "-f",
46
+ "--force",
47
+ action="store_true",
48
+ default=False,
49
+ help="Override the output MJCF file if it already exists (default: %(default)s).",
50
+ )
51
+
52
+ parser.add_argument(
53
+ "-p",
54
+ "--print",
55
+ action="store_true",
56
+ default=False,
57
+ help="Print in the stdout the exported MJCF string (default: %(default)s).",
58
+ )
59
+
60
+ parser.add_argument(
61
+ "-v",
62
+ "--visualize",
63
+ action="store_true",
64
+ default=False,
65
+ help="Visualize the description in the Mujoco viewer (default: %(default)s).",
66
+ )
67
+
68
+ parser.add_argument(
69
+ "-b",
70
+ "--base-position",
71
+ metavar=("x", "y", "z"),
72
+ nargs=3,
73
+ type=float,
74
+ default=None,
75
+ help="Override the base position (supports only floating-base models).",
76
+ )
77
+
78
+ parser.add_argument(
79
+ "-q",
80
+ "--base-quaternion",
81
+ metavar=("w", "x", "y", "z"),
82
+ nargs=4,
83
+ type=float,
84
+ default=None,
85
+ help="Override the base quaternion (supports only floating-base models).",
86
+ )
87
+
88
+ args = parser.parse_args()
89
+
90
+ # ==================
91
+ # Validate arguments
92
+ # ==================
93
+
94
+ # Expand the path of the URDF/SDF file if not absolute.
95
+ if args.description is not None:
96
+ args.description = (
97
+ (
98
+ args.description
99
+ if args.description.is_absolute()
100
+ else pathlib.Path.cwd() / args.description
101
+ )
102
+ .expanduser()
103
+ .absolute()
104
+ )
105
+
106
+ if not pathlib.Path(args.description).is_file():
107
+ msg = f"The URDF/SDF file '{args.description}' does not exist."
108
+ parser.error(msg)
109
+ sys.exit(1)
110
+
111
+ # Expand the path of the output MJCF file if not absolute.
112
+ if args.export is not None:
113
+ args.export = (
114
+ (
115
+ args.export
116
+ if args.export.is_absolute()
117
+ else pathlib.Path.cwd() / args.export
118
+ )
119
+ .expanduser()
120
+ .absolute()
121
+ )
122
+
123
+ if pathlib.Path(args.export).is_file() and not args.force:
124
+ msg = "The output file '{}' already exists, use '--force' to override."
125
+ parser.error(msg.format(args.export))
126
+ sys.exit(1)
127
+
128
+ # ================================================
129
+ # Load the URDF/SDF file and produce a MJCF string
130
+ # ================================================
131
+
132
+ match args.description.suffix.lower()[1:]:
133
+
134
+ case "urdf":
135
+ mjcf_string, assets = UrdfToMjcf().convert(urdf=args.description)
136
+
137
+ case "sdf":
138
+ mjcf_string, assets = SdfToMjcf().convert(
139
+ sdf=args.description, model_name=args.model_name
140
+ )
141
+
142
+ case _:
143
+ msg = f"The file extension '{args.description.suffix}' is not supported."
144
+ parser.error(msg)
145
+ sys.exit(1)
146
+
147
+ if args.print:
148
+ print(mjcf_string, flush=True)
149
+
150
+ # ========================================
151
+ # Write the MJCF string to the output file
152
+ # ========================================
153
+
154
+ if args.export is not None:
155
+ with open(args.export, "w+", encoding="utf-8") as file:
156
+ file.write(mjcf_string)
157
+
158
+ # =======================================
159
+ # Visualize the MJCF in the Mujoco viewer
160
+ # =======================================
161
+
162
+ if args.visualize:
163
+
164
+ mj_model_helper = MujocoModelHelper.build_from_xml(
165
+ mjcf_description=mjcf_string, assets=assets
166
+ )
167
+
168
+ viz = MujocoVisualizer(model=mj_model_helper.model, data=mj_model_helper.data)
169
+
170
+ with viz.open() as viewer:
171
+
172
+ with viewer.lock():
173
+ if args.base_position is not None:
174
+ mj_model_helper.set_base_position(
175
+ position=np.array(args.base_position)
176
+ )
177
+
178
+ if args.base_quaternion is not None:
179
+ mj_model_helper.set_base_orientation(
180
+ orientation=np.array(args.base_quaternion)
181
+ )
182
+
183
+ viz.sync(viewer=viewer)
184
+
185
+ while viewer.is_running():
186
+ time.sleep(0.500)
187
+
188
+ # =============================
189
+ # Exit the program with success
190
+ # =============================
191
+
192
+ sys.exit(0)
@@ -0,0 +1,475 @@
1
+ import pathlib
2
+ import tempfile
3
+ import warnings
4
+ from typing import Any
5
+
6
+ import mujoco as mj
7
+ import rod.urdf.exporter
8
+ from lxml import etree as ET
9
+
10
+
11
def load_rod_model(
    model_description: str | pathlib.Path | rod.Model,
    is_urdf: bool | None = None,
    model_name: str | None = None,
) -> rod.Model:
    """Load a single model from a URDF/SDF description using ROD.

    Args:
        model_description: The model description, passed as-is to ``rod.Sdf.load``.
        is_urdf: Whether the description is a URDF. Forwarded to ``rod.Sdf.load``;
            None presumably lets ROD detect the format (TODO confirm).
        model_name: The name of the model to select. It is mandatory when the
            description contains more than one model. When it contains exactly
            one model, it is only used to check that the name matches.

    Returns:
        The selected ROD model.

    Raises:
        RuntimeError: If the description does not contain any model.
        ValueError: If multiple models are present and ``model_name`` is None,
            or if no model named ``model_name`` exists.
    """

    # Parse the SDF resource.
    sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf)

    # Query the models only once.
    models = sdf_element.models()

    # Fail if the SDF resource has no model.
    if len(models) == 0:
        raise RuntimeError("Failed to find any model in the model description")

    # Return the model if there is only one.
    if len(models) == 1:
        if model_name is not None and models[0].name != model_name:
            raise ValueError(f"Model '{model_name}' not found in the description")

        return models[0]

    # Require users to specify the model name if there are multiple models.
    if model_name is None:
        msg = "The resource has multiple models. Please specify the model name."
        raise ValueError(msg)

    # Build a dictionary of models in the resource for easy access.
    models_by_name = {m.name: m for m in models}

    if model_name not in models_by_name:
        raise ValueError(f"Model '{model_name}' not found in the resource")

    return models_by_name[model_name]
44
+
45
+
46
class RodModelToMjcf:
    """Convert ROD models to MJCF descriptions that can be loaded in Mujoco.

    The conversion exports the ROD model to URDF, lets Mujoco parse it, and then
    post-processes the MJCF saved by Mujoco. The post-processing adds one motor
    per actuated joint, makes collision geometries transparent, and populates a
    simple scene (skybox, checkered floor, lights).
    """

    @staticmethod
    def assets_from_rod_model(
        rod_model: rod.Model,
    ) -> dict[str, bytes]:
        """Collect the mesh files referenced by the visuals/collisions of a model.

        Args:
            rod_model: The ROD model whose links are inspected.

        Returns:
            A dictionary mapping each mesh URI, as written in the model, to the
            raw bytes of the file it resolves to.
        """

        import resolve_robotics_uri_py

        assets_files = {}

        for link in rod_model.links():
            for element in [*link.visuals(), *link.collisions()]:
                mesh = element.geometry.mesh

                # Resolve each URI only once, even if shared by many elements.
                if mesh and mesh.uri and mesh.uri not in assets_files:
                    assets_files[mesh.uri] = (
                        resolve_robotics_uri_py.resolve_robotics_uri(mesh.uri)
                    )

        assets = {
            asset_name: asset.read_bytes() for asset_name, asset in assets_files.items()
        }

        return assets

    @staticmethod
    def add_floating_joint(
        urdf_string: str,
        base_link_name: str,
        floating_joint_name: str = "world_to_base",
    ) -> str:
        """Connect the base link of a URDF to a "world" link with a floating joint.

        Args:
            urdf_string: The URDF description.
            base_link_name: The link that becomes the child of the floating joint.
            floating_joint_name: The name of the floating joint to create.

        Returns:
            The URDF string with the "world" link (created if missing) and the
            floating joint. If a joint named ``floating_joint_name`` already
            exists, a warning is emitted and the re-serialized input is returned.

        Raises:
            ValueError: If the base link is not part of the URDF.
        """

        # Parse the URDF string as XML (etree). Bytes are passed so that an
        # optional XML encoding declaration is accepted by lxml.
        parser = ET.XMLParser(remove_blank_text=True)
        root: ET._Element = ET.fromstring(text=urdf_string.encode(), parser=parser)

        if root.find(f".//joint[@name='{floating_joint_name}']") is not None:
            msg = f"The URDF already has a floating joint '{floating_joint_name}'"
            warnings.warn(msg)
            return ET.tostring(root, pretty_print=True).decode()

        # Create the "world" link if it doesn't exist.
        if root.find(".//link[@name='world']") is None:
            _ = ET.SubElement(root, "link", name="world")

        # Check that the base link exists before creating the joint.
        if root.find(f".//link[@name='{base_link_name}']") is None:
            raise ValueError(f"Link '{base_link_name}' not found in the URDF")

        # Create the floating joint and attach it between "world" and the base link.
        world_to_base = ET.SubElement(
            root, "joint", name=floating_joint_name, type="floating"
        )
        ET.SubElement(world_to_base, "parent", link="world")
        ET.SubElement(world_to_base, "child", link=base_link_name)

        return ET.tostring(root, pretty_print=True).decode()

    @staticmethod
    def convert(
        rod_model: rod.Model,
        considered_joints: list[str] | None = None,
    ) -> tuple[str, dict[str, Any]]:
        """Convert a ROD model to an MJCF string.

        Args:
            rod_model: The ROD model to convert. It is not modified.
            considered_joints: The names of the joints to keep. All the other
                joints are converted to fixed joints. If None, all the non-fixed
                joints of the model are considered.

        Returns:
            A tuple with the MJCF string and the assets dictionary (mesh URI to
            file content) needed to load it in Mujoco.

        Raises:
            ValueError: If a considered joint is not part of the model, or if the
                Mujoco model does not have exactly the considered joints (plus
                "world_to_base" for floating-base models).
            RuntimeError: If the MJCF already has <actuator> elements, or if the
                <worldbody> of a floating-base model cannot be found.
        """

        import copy

        import numpy as np

        # Work on a private copy: the joints that are not considered are turned
        # into fixed joints below, and this must not leak into the caller's model.
        rod_model = copy.deepcopy(rod_model)

        # -------------------------------------
        # Convert the model description to URDF
        # -------------------------------------

        model_joint_names = {j.name for j in rod_model.joints()}

        # Consider all the non-fixed joints if not specified otherwise.
        considered_joints = set(
            considered_joints
            if considered_joints is not None
            else [j.name for j in rod_model.joints() if j.type != "fixed"]
        )

        # If considered joints are passed, make sure that they are all part of the model.
        if extra_joints := considered_joints - model_joint_names:
            msg = f"Couldn't find the following joints in the model: '{extra_joints}'"
            raise ValueError(msg)

        # Convert all the joints not considered to fixed joints.
        for joint in rod_model.joints():
            if joint.name not in considered_joints:
                joint.type = "fixed"

        # Convert the ROD model to URDF.
        urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string(
            sdf=rod.Sdf(model=rod_model, version="1.7"),
            gazebo_preserve_fixed_joints=False,
            pretty=True,
        )

        # -------------------------------------
        # Add a floating joint if floating-base
        # -------------------------------------

        if not rod_model.is_fixed_base():
            considered_joints |= {"world_to_base"}
            urdf_string = RodModelToMjcf.add_floating_joint(
                urdf_string=urdf_string,
                base_link_name=rod_model.get_canonical_link(),
                floating_joint_name="world_to_base",
            )

        # ---------------------------------------
        # Inject the <mujoco> element in the URDF
        # ---------------------------------------

        parser = ET.XMLParser(remove_blank_text=True)
        urdf_root: ET._Element = ET.fromstring(
            text=urdf_string.encode(), parser=parser
        )

        # Reuse an existing <mujoco> element if the URDF already has one.
        # (Explicit None check: lxml element truthiness depends on its children.)
        urdf_mujoco_element = urdf_root.find("./mujoco")
        if urdf_mujoco_element is None:
            urdf_mujoco_element = ET.SubElement(urdf_root, "mujoco")

        _ = ET.SubElement(
            urdf_mujoco_element,
            "compiler",
            balanceinertia="true",
            discardvisual="false",
        )

        # ------------------------------
        # Post-process all dummy visuals
        # ------------------------------

        # Give a tiny radius to all dummy (zero-radius) spheres.
        for geometry in urdf_root.findall(".//visual/geometry[sphere]"):
            sphere = geometry.find("./sphere")
            radius = np.array(sphere.attrib["radius"].split(), dtype=float)
            if np.allclose(radius, np.zeros(1)):
                sphere.set("radius", "0.001")

        # Give a tiny volume to all dummy (zero-size) boxes.
        for geometry in urdf_root.findall(".//visual/geometry[box]"):
            box = geometry.find("./box")
            size = np.array(box.attrib["size"].split(), dtype=float)
            if np.allclose(size, np.zeros(3)):
                box.set("size", "0.001 0.001 0.001")

        urdf_string = ET.tostring(urdf_root, pretty_print=True).decode()

        # ------------------------
        # Convert the URDF to MJCF
        # ------------------------

        # Load the URDF model into Mujoco.
        assets = RodModelToMjcf.assets_from_rod_model(rod_model=rod_model)
        mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)  # noqa

        # Get the joint names.
        mj_joint_names = {
            mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
            for idx in range(mj_model.njnt)
        }

        # Check that the Mujoco model only has the considered joints.
        if mj_joint_names != considered_joints:
            extra_joints = mj_joint_names ^ considered_joints
            msg = "The Mujoco model has the following extra/missing joints: '{}'"
            raise ValueError(msg.format(extra_joints))

        # Save the in-memory Mujoco model as MJCF and parse it back as XML (etree)
        # to post-process it. A temporary directory is used because an open
        # NamedTemporaryFile cannot be reopened by name on Windows. A fixed file
        # name avoids issues with model names containing path separators.
        with tempfile.TemporaryDirectory() as tmp_dir:
            mjcf_path = pathlib.Path(tmp_dir) / "model.xml"
            mj.mj_saveLastXML(str(mjcf_path), mj_model)

            parser = ET.XMLParser(remove_blank_text=True)
            tree = ET.parse(source=str(mjcf_path), parser=parser)

        # Get the root element.
        root: ET._Element = tree.getroot()

        # Find the <mujoco> element (might be the root itself).
        mujoco_element: ET._Element = list(root.iter("mujoco"))[0]

        # --------------
        # Add the motors
        # --------------

        if len(mujoco_element.findall(".//actuator")) > 0:
            raise RuntimeError("The model already has <actuator> elements.")

        # Add the actuator element.
        actuator_element = ET.SubElement(mujoco_element, "actuator")

        # Add a motor for each 1-DoF joint (MJCF joints default to "hinge").
        for joint_element in mujoco_element.findall(".//joint"):
            # Internal invariant: the joint names were validated above.
            assert (
                joint_element.attrib["name"] in considered_joints
            ), joint_element.attrib["name"]

            if joint_element.attrib.get("type", "hinge") in {"free", "ball"}:
                continue

            ET.SubElement(
                actuator_element,
                "motor",
                name=f"{joint_element.attrib['name']}_motor",
                joint=joint_element.attrib["name"],
                gear="1",
            )

        # ---------------------------------------------
        # Set full transparency of collision geometries
        # ---------------------------------------------

        # Get all the (optional) names of the URDF collision elements.
        collision_names = {
            c.attrib["name"]
            for c in urdf_root.findall(".//collision[geometry]")
            if "name" in c.attrib
        }

        # Set alpha=0 to the color of all collision elements.
        for geometry_element in mujoco_element.findall(".//geom[@rgba]"):
            if geometry_element.attrib.get("name") in collision_names:
                r, g, b, _ = geometry_element.attrib["rgba"].split()
                geometry_element.set("rgba", f"{r} {g} {b} 0")

        # -----------------------
        # Create the scene assets
        # -----------------------

        asset_element = mujoco_element.find(".//asset")
        if asset_element is None:
            asset_element = ET.SubElement(mujoco_element, "asset")

        _ = ET.SubElement(
            asset_element,
            "texture",
            type="skybox",
            builtin="gradient",
            rgb1="0.3 0.5 0.7",
            rgb2="0 0 0",
            width="512",
            height="512",
        )

        _ = ET.SubElement(
            asset_element,
            "texture",
            name="plane_texture",
            type="2d",
            builtin="checker",
            rgb1="0.1 0.2 0.3",
            rgb2="0.2 0.3 0.4",
            width="512",
            height="512",
            mark="cross",
            markrgb=".8 .8 .8",
        )

        _ = ET.SubElement(
            asset_element,
            "material",
            name="plane_material",
            texture="plane_texture",
            reflectance="0.2",
            texrepeat="5 5",
            texuniform="true",
        )

        # ----------------------------------
        # Populate the scene with the assets
        # ----------------------------------

        worldbody_scene_element = ET.SubElement(mujoco_element, "worldbody")

        _ = ET.SubElement(
            worldbody_scene_element,
            "geom",
            name="floor",
            type="plane",
            size="0 0 0.05",
            material="plane_material",
            condim="3",
            contype="1",
            conaffinity="1",
        )

        _ = ET.SubElement(
            worldbody_scene_element,
            "light",
            name="sun",
            mode="fixed",
            directional="true",
            castshadow="true",
            pos="0 0 10",
            dir="0 0 -1",
        )

        # ------------------------------------------------
        # Add a light following the CoM of the first link
        # ------------------------------------------------

        if not rod_model.is_fixed_base():

            worldbody_element = None

            # Find the <worldbody> element of our model by searching the one that contains
            # all the considered joints. This is needed because there might be multiple
            # <worldbody> elements inside <mujoco>.
            for wb in mujoco_element.findall(".//worldbody"):
                if all(
                    wb.find(f".//joint[@name='{j}']") is not None
                    for j in considered_joints
                ):
                    worldbody_element = wb
                    break

            if worldbody_element is None:
                raise RuntimeError(
                    "Failed to find the <worldbody> element of the model"
                )

            # Light attached to the model.
            _ = ET.SubElement(
                worldbody_element,
                "light",
                name="light_model",
                mode="targetbodycom",
                target=worldbody_element.find(".//body").attrib["name"],
                directional="false",
                castshadow="true",
                pos="1 0 5",
            )

        # --------------------------------
        # Return the resulting MJCF string
        # --------------------------------

        mjcf_string = ET.tostring(root, pretty_print=True).decode()
        return mjcf_string, assets
432
+
433
+
434
class UrdfToMjcf:
    """Convert URDF descriptions to MJCF strings through an intermediate ROD model."""

    @staticmethod
    def convert(
        urdf: str | pathlib.Path,
        considered_joints: list[str] | None = None,
        model_name: str | None = None,
    ) -> tuple[str, dict[str, Any]]:
        """Convert a URDF description to MJCF.

        Args:
            urdf: The URDF description, forwarded to ``load_rod_model``.
            considered_joints: The joints to keep actuated; the others become
                fixed. If None, all the non-fixed joints are kept.
            model_name: The optional name of the model to select.

        Returns:
            The MJCF string and its assets, as returned by ``RodModelToMjcf.convert``.
        """

        # Parse the URDF into a ROD model, then delegate the actual conversion.
        model = load_rod_model(
            model_name=model_name,
            is_urdf=True,
            model_description=urdf,
        )

        mjcf_string, assets = RodModelToMjcf.convert(
            considered_joints=considered_joints,
            rod_model=model,
        )

        return mjcf_string, assets
454
+
455
+
456
class SdfToMjcf:
    """Convert SDF descriptions to MJCF strings through an intermediate ROD model."""

    @staticmethod
    def convert(
        sdf: str | pathlib.Path,
        considered_joints: list[str] | None = None,
        model_name: str | None = None,
    ) -> tuple[str, dict[str, Any]]:
        """Convert an SDF description to MJCF.

        Args:
            sdf: The SDF description, forwarded to ``load_rod_model``.
            considered_joints: The joints to keep actuated; the others become
                fixed. If None, all the non-fixed joints are kept.
            model_name: The name of the model to select. It is required when
                the SDF contains multiple models.

        Returns:
            The MJCF string and its assets, as returned by ``RodModelToMjcf.convert``.
        """

        # Parse the SDF into a ROD model, then delegate the actual conversion.
        model = load_rod_model(
            model_name=model_name,
            is_urdf=False,
            model_description=sdf,
        )

        mjcf_string, assets = RodModelToMjcf.convert(
            considered_joints=considered_joints,
            rod_model=model,
        )

        return mjcf_string, assets