jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/mujoco/loaders.py CHANGED
@@ -1,19 +1,39 @@
1
1
  import pathlib
2
2
  import tempfile
3
3
  import warnings
4
+ from collections.abc import Sequence
4
5
  from typing import Any
5
6
 
6
7
  import mujoco as mj
8
+ import numpy as np
7
9
  import rod.urdf.exporter
8
10
  from lxml import etree as ET
9
11
 
12
+ from jaxsim import logging
13
+
14
+ from .utils import MujocoCamera
15
+
16
# Accepted forms for the `cameras` argument of the MJCF converters: either
# ready-made `MujocoCamera` objects or plain string-to-string dicts (which the
# converter turns into cameras), given singly or as a sequence.
MujocoCameraType = (
    MujocoCamera
    | Sequence[MujocoCamera]
    | dict[str, str]
    | Sequence[dict[str, str]]
)
19
+
10
20
 
11
21
  def load_rod_model(
12
22
  model_description: str | pathlib.Path | rod.Model,
13
23
  is_urdf: bool | None = None,
14
24
  model_name: str | None = None,
15
25
  ) -> rod.Model:
16
- """"""
26
+ """
27
+ Load a ROD model from a URDF/SDF file or a ROD model.
28
+
29
+ Args:
30
+ model_description: The URDF/SDF file or ROD model to load.
31
+ is_urdf: Whether to force parsing the model description as a URDF file.
32
+ model_name: The name of the model to load from the resource.
33
+
34
+ Returns:
35
+ rod.Model: The loaded ROD model.
36
+ """
17
37
 
18
38
  # Parse the SDF resource.
19
39
  sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf)
@@ -43,14 +63,77 @@ def load_rod_model(
43
63
  return models[model_name]
44
64
 
45
65
 
66
+ class ModelToMjcf:
67
+ """
68
+ Class to convert a URDF/SDF file or a ROD model to a Mujoco MJCF string.
69
+ """
70
+
71
+ @staticmethod
72
+ def convert(
73
+ model: str | pathlib.Path | rod.Model,
74
+ considered_joints: list[str] | None = None,
75
+ plane_normal: tuple[float, float, float] = (0, 0, 1),
76
+ heightmap: bool | None = None,
77
+ heightmap_samples_xy: tuple[int, int] = (101, 101),
78
+ cameras: MujocoCameraType = (),
79
+ ) -> tuple[str, dict[str, Any]]:
80
+ """
81
+ Convert a model to a Mujoco MJCF string.
82
+
83
+ Args:
84
+ model: The URDF/SDF file or ROD model to convert.
85
+ considered_joints: The list of joint names to consider in the conversion.
86
+ plane_normal: The normal vector of the plane.
87
+ heightmap: Whether to generate a heightmap.
88
+ heightmap_samples_xy: The number of points in the heightmap grid.
89
+ cameras: The custom cameras to add to the scene.
90
+
91
+ Returns:
92
+ A tuple containing the MJCF string and the dictionary of assets.
93
+ """
94
+
95
+ match model:
96
+ case rod.Model():
97
+ rod_model = model
98
+ case str() | pathlib.Path():
99
+ # Convert the JaxSim model to a ROD model.
100
+ rod_model = load_rod_model(
101
+ model_description=model,
102
+ is_urdf=None,
103
+ model_name=None,
104
+ )
105
+ case _:
106
+ raise TypeError(f"Unsupported type for 'model': {type(model)}")
107
+
108
+ # Convert the ROD model to MJCF.
109
+ return RodModelToMjcf.convert(
110
+ rod_model=rod_model,
111
+ considered_joints=considered_joints,
112
+ plane_normal=plane_normal,
113
+ heightmap=heightmap,
114
+ heightmap_samples_xy=heightmap_samples_xy,
115
+ cameras=cameras,
116
+ )
117
+
118
+
46
119
  class RodModelToMjcf:
47
- """"""
120
+ """
121
+ Class to convert a ROD model to a Mujoco MJCF string.
122
+ """
48
123
 
49
124
  @staticmethod
50
125
  def assets_from_rod_model(
51
126
  rod_model: rod.Model,
52
127
  ) -> dict[str, bytes]:
53
- """"""
128
+ """
129
+ Generate a dictionary of assets from a ROD model.
130
+
131
+ Args:
132
+ rod_model: The ROD model to extract the assets from.
133
+
134
+ Returns:
135
+ dict: A dictionary of assets.
136
+ """
54
137
 
55
138
  import resolve_robotics_uri_py
56
139
 
@@ -85,7 +168,17 @@ class RodModelToMjcf:
85
168
  base_link_name: str,
86
169
  floating_joint_name: str = "world_to_base",
87
170
  ) -> str:
88
- """"""
171
+ """
172
+ Add a floating joint to a URDF string.
173
+
174
+ Args:
175
+ urdf_string: The URDF string to modify.
176
+ base_link_name: The name of the base link to attach the floating joint.
177
+ floating_joint_name: The name of the floating joint to add.
178
+
179
+ Returns:
180
+ str: The modified URDF string.
181
+ """
89
182
 
90
183
  with tempfile.NamedTemporaryFile(mode="w+", suffix=".urdf") as urdf_file:
91
184
 
@@ -102,7 +195,7 @@ class RodModelToMjcf:
102
195
 
103
196
  if root.find(f".//joint[@name='{floating_joint_name}']") is not None:
104
197
  msg = f"The URDF already has a floating joint '{floating_joint_name}'"
105
- warnings.warn(msg)
198
+ warnings.warn(msg, stacklevel=2)
106
199
  return ET.tostring(root, pretty_print=True).decode()
107
200
 
108
201
  # Create the "world" link if it doesn't exist.
@@ -129,8 +222,25 @@ class RodModelToMjcf:
129
222
  def convert(
130
223
  rod_model: rod.Model,
131
224
  considered_joints: list[str] | None = None,
225
+ plane_normal: tuple[float, float, float] = (0, 0, 1),
226
+ heightmap: bool | None = None,
227
+ heightmap_samples_xy: tuple[int, int] = (101, 101),
228
+ cameras: MujocoCameraType = (),
132
229
  ) -> tuple[str, dict[str, Any]]:
133
- """"""
230
+ """
231
+ Convert a ROD model to a Mujoco MJCF string.
232
+
233
+ Args:
234
+ rod_model: The ROD model to convert.
235
+ considered_joints: The list of joint names to consider in the conversion.
236
+ plane_normal: The normal vector of the plane.
237
+ heightmap: Whether to generate a heightmap.
238
+ heightmap_samples_xy: The number of points in the heightmap grid.
239
+ cameras: The custom cameras to add to the scene.
240
+
241
+ Returns:
242
+ A tuple containing the MJCF string and the dictionary of assets.
243
+ """
134
244
 
135
245
  # -------------------------------------
136
246
  # Convert the model description to URDF
@@ -144,10 +254,9 @@ class RodModelToMjcf:
144
254
  )
145
255
 
146
256
  # If considered joints are passed, make sure that they are all part of the model.
147
- if considered_joints - set([j.name for j in rod_model.joints()]):
148
- extra_joints = set(considered_joints) - set(
149
- [j.name for j in rod_model.joints()]
150
- )
257
+ if considered_joints - {j.name for j in rod_model.joints()}:
258
+ extra_joints = set(considered_joints) - {j.name for j in rod_model.joints()}
259
+
151
260
  msg = f"Couldn't find the following joints in the model: '{extra_joints}'"
152
261
  raise ValueError(msg)
153
262
 
@@ -155,14 +264,14 @@ class RodModelToMjcf:
155
264
  joints_dict = {j.name: j for j in rod_model.joints()}
156
265
 
157
266
  # Convert all the joints not considered to fixed joints.
158
- for joint_name in set(j.name for j in rod_model.joints()) - considered_joints:
267
+ for joint_name in {j.name for j in rod_model.joints()} - considered_joints:
159
268
  joints_dict[joint_name].type = "fixed"
160
269
 
161
270
  # Convert the ROD model to URDF.
162
- urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string(
271
+ urdf_string = rod.urdf.exporter.UrdfExporter(
272
+ gazebo_preserve_fixed_joints=False, pretty=True
273
+ ).to_urdf_string(
163
274
  sdf=rod.Sdf(model=rod_model, version="1.7"),
164
- gazebo_preserve_fixed_joints=False,
165
- pretty=True,
166
275
  )
167
276
 
168
277
  # -------------------------------------
@@ -198,8 +307,6 @@ class RodModelToMjcf:
198
307
  )
199
308
 
200
309
  urdf_string = ET.tostring(root, pretty_print=True).decode()
201
- # print(urdf_string)
202
- # raise
203
310
 
204
311
  # ------------------------------
205
312
  # Post-process all dummy visuals
@@ -207,7 +314,6 @@ class RodModelToMjcf:
207
314
 
208
315
  parser = ET.XMLParser(remove_blank_text=True)
209
316
  root: ET._Element = ET.fromstring(text=urdf_string.encode(), parser=parser)
210
- import numpy as np
211
317
 
212
318
  # Give a tiny radius to all dummy spheres
213
319
  for geometry in root.findall(".//visual/geometry[sphere]"):
@@ -233,13 +339,13 @@ class RodModelToMjcf:
233
339
 
234
340
  # Load the URDF model into Mujoco.
235
341
  assets = RodModelToMjcf.assets_from_rod_model(rod_model=rod_model)
236
- mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets) # noqa
342
+ mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)
237
343
 
238
344
  # Get the joint names.
239
- mj_joint_names = set(
345
+ mj_joint_names = {
240
346
  mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
241
347
  for idx in range(mj_model.njnt)
242
- )
348
+ }
243
349
 
244
350
  # Check that the Mujoco model only has the considered joints.
245
351
  if mj_joint_names != considered_joints:
@@ -265,7 +371,7 @@ class RodModelToMjcf:
265
371
  root: ET._Element = tree.getroot()
266
372
 
267
373
  # Find the <mujoco> element (might be the root itself).
268
- mujoco_element: ET._Element = list(root.iter("mujoco"))[0]
374
+ mujoco_element: ET._Element = next(iter(root.iter("mujoco")))
269
375
 
270
376
  # --------------
271
377
  # Add the motors
@@ -310,7 +416,7 @@ class RodModelToMjcf:
310
416
  # Set alpha=0 to the color of all collision elements
311
417
  for geometry_element in mujoco_element.findall(".//geom[@rgba]"):
312
418
  if geometry_element.attrib.get("name") in collision_names:
313
- r, g, b, a = geometry_element.attrib["rgba"].split(" ")
419
+ r, g, b, _ = geometry_element.attrib["rgba"].split(" ")
314
420
  geometry_element.set("rgba", f"{r} {g} {b} 0")
315
421
 
316
422
  # -----------------------
@@ -358,6 +464,21 @@ class RodModelToMjcf:
358
464
  texuniform="true",
359
465
  )
360
466
 
467
+ _ = (
468
+ ET.SubElement(
469
+ asset_element,
470
+ "hfield",
471
+ name="terrain",
472
+ nrow=f"{int(heightmap_samples_xy[0])}",
473
+ ncol=f"{int(heightmap_samples_xy[1])}",
474
+ # The following 'size' is a placeholder, it is updated dynamically
475
+ # when a hfield/heightmap is stored into MjData.
476
+ size="1 1 1 1",
477
+ )
478
+ if heightmap
479
+ else None
480
+ )
481
+
361
482
  # ----------------------------------
362
483
  # Populate the scene with the assets
363
484
  # ----------------------------------
@@ -368,12 +489,14 @@ class RodModelToMjcf:
368
489
  worldbody_scene_element,
369
490
  "geom",
370
491
  name="floor",
371
- type="plane",
492
+ type="plane" if not heightmap else "hfield",
372
493
  size="0 0 0.05",
373
494
  material="plane_material",
374
495
  condim="3",
375
496
  contype="1",
376
497
  conaffinity="1",
498
+ zaxis=" ".join(map(str, plane_normal)),
499
+ **({"hfield": "terrain"} if heightmap else {}),
377
500
  )
378
501
 
379
502
  _ = ET.SubElement(
@@ -407,16 +530,28 @@ class RodModelToMjcf:
407
530
  raise RuntimeError("Failed to find the <worldbody> element of the model")
408
531
 
409
532
  # Camera attached to the model
533
+ # It can be manually copied from `python -m mujoco.viewer --mjcf=<URDF_PATH>`
410
534
  _ = ET.SubElement(
411
535
  worldbody_element,
412
536
  "camera",
413
537
  name="track",
414
538
  mode="trackcom",
415
- pos="1 0 5",
416
- zaxis="0 0 1",
539
+ pos="1.930 -2.279 0.556",
540
+ xyaxes="0.771 0.637 0.000 -0.116 0.140 0.983",
417
541
  fovy="60",
418
542
  )
419
543
 
544
+ # Add user-defined camera.
545
+ for camera in cameras if isinstance(cameras, Sequence) else [cameras]:
546
+
547
+ mj_camera = (
548
+ camera
549
+ if isinstance(camera, MujocoCamera)
550
+ else MujocoCamera.build(**camera)
551
+ )
552
+
553
+ _ = ET.SubElement(worldbody_element, "camera", mj_camera.asdict())
554
+
420
555
  # ------------------------------------------------
421
556
  # Add a light following the CoM of the first link
422
557
  # ------------------------------------------------
@@ -444,13 +579,35 @@ class RodModelToMjcf:
444
579
 
445
580
 
446
581
  class UrdfToMjcf:
582
+ """
583
+ Class to convert a URDF file to a Mujoco MJCF string.
584
+ """
585
+
447
586
  @staticmethod
448
587
  def convert(
449
588
  urdf: str | pathlib.Path,
450
589
  considered_joints: list[str] | None = None,
451
590
  model_name: str | None = None,
591
+ plane_normal: tuple[float, float, float] = (0, 0, 1),
592
+ heightmap: bool | None = None,
593
+ cameras: MujocoCameraType = (),
452
594
  ) -> tuple[str, dict[str, Any]]:
453
- """"""
595
+ """
596
+ Convert a URDF file to a Mujoco MJCF string.
597
+
598
+ Args:
599
+ urdf: The URDF file to convert.
600
+ considered_joints: The list of joint names to consider in the conversion.
601
+ model_name: The name of the model to convert.
602
+ plane_normal: The normal vector of the plane.
603
+ heightmap: Whether to generate a heightmap.
604
+ cameras: The list of cameras to add to the scene.
605
+
606
+ Returns:
607
+ tuple: A tuple containing the MJCF string and the assets dictionary.
608
+ """
609
+
610
+ logging.warning("This method is deprecated. Use 'ModelToMjcf.convert' instead.")
454
611
 
455
612
  # Get the ROD model.
456
613
  rod_model = load_rod_model(
@@ -461,18 +618,44 @@ class UrdfToMjcf:
461
618
 
462
619
  # Convert the ROD model to MJCF.
463
620
  return RodModelToMjcf.convert(
464
- rod_model=rod_model, considered_joints=considered_joints
621
+ rod_model=rod_model,
622
+ considered_joints=considered_joints,
623
+ plane_normal=plane_normal,
624
+ heightmap=heightmap,
625
+ cameras=cameras,
465
626
  )
466
627
 
467
628
 
468
629
  class SdfToMjcf:
630
+ """
631
+ Class to convert a SDF file to a Mujoco MJCF string.
632
+ """
633
+
469
634
  @staticmethod
470
635
  def convert(
471
636
  sdf: str | pathlib.Path,
472
637
  considered_joints: list[str] | None = None,
473
638
  model_name: str | None = None,
639
+ plane_normal: tuple[float, float, float] = (0, 0, 1),
640
+ heightmap: bool | None = None,
641
+ cameras: MujocoCameraType = (),
474
642
  ) -> tuple[str, dict[str, Any]]:
475
- """"""
643
+ """
644
+ Convert a SDF file to a Mujoco MJCF string.
645
+
646
+ Args:
647
+ sdf: The SDF file to convert.
648
+ considered_joints: The list of joint names to consider in the conversion.
649
+ model_name: The name of the model to convert.
650
+ plane_normal: The normal vector of the plane.
651
+ heightmap: Whether to generate a heightmap.
652
+ cameras: The list of cameras to add to the scene.
653
+
654
+ Returns:
655
+ tuple: A tuple containing the MJCF string and the assets dictionary.
656
+ """
657
+
658
+ logging.warning("This method is deprecated. Use 'ModelToMjcf.convert' instead.")
476
659
 
477
660
  # Get the ROD model.
478
661
  rod_model = load_rod_model(
@@ -483,5 +666,9 @@ class SdfToMjcf:
483
666
 
484
667
  # Convert the ROD model to MJCF.
485
668
  return RodModelToMjcf.convert(
486
- rod_model=rod_model, considered_joints=considered_joints
669
+ rod_model=rod_model,
670
+ considered_joints=considered_joints,
671
+ plane_normal=plane_normal,
672
+ heightmap=heightmap,
673
+ cameras=cameras,
487
674
  )