warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.1__py3-none-macosx_10_13_universal2.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/sim/import_mjcf.py CHANGED
@@ -17,36 +17,53 @@ import math
17
17
  import os
18
18
  import re
19
19
  import xml.etree.ElementTree as ET
20
+ from typing import Union
20
21
 
21
22
  import numpy as np
22
23
 
23
24
  import warp as wp
25
+ from warp.sim.model import Mesh
24
26
 
25
27
 
26
28
  def parse_mjcf(
27
29
  mjcf_filename,
28
30
  builder,
29
31
  xform=None,
32
+ floating=False,
33
+ base_joint: Union[dict, str, None] = None,
30
34
  density=1000.0,
31
- stiffness=0.0,
32
- damping=0.0,
33
- contact_ke=1000.0,
34
- contact_kd=100.0,
35
- contact_kf=100.0,
35
+ stiffness=100.0,
36
+ damping=10.0,
37
+ armature=0.0,
38
+ armature_scale=1.0,
39
+ contact_ke=1.0e4,
40
+ contact_kd=1.0e3,
41
+ contact_kf=1.0e2,
36
42
  contact_ka=0.0,
37
- contact_mu=0.5,
43
+ contact_mu=0.25,
38
44
  contact_restitution=0.5,
39
45
  contact_thickness=0.0,
40
46
  limit_ke=100.0,
41
47
  limit_kd=10.0,
48
+ joint_limit_lower=-1e6,
49
+ joint_limit_upper=1e6,
42
50
  scale=1.0,
43
- armature=0.0,
44
- armature_scale=1.0,
51
+ hide_visuals=False,
52
+ parse_visuals_as_colliders=False,
45
53
  parse_meshes=True,
46
- enable_self_collisions=False,
47
54
  up_axis="Z",
55
+ ignore_names=(),
48
56
  ignore_classes=None,
57
+ visual_classes=("visual",),
58
+ collider_classes=("collision",),
59
+ no_class_as_colliders=True,
60
+ force_show_colliders=False,
61
+ enable_self_collisions=False,
62
+ ignore_inertial_definitions=True,
63
+ ensure_nonstatic_links=True,
64
+ static_link_mass=1e-2,
49
65
  collapse_fixed_joints=False,
66
+ verbose=False,
50
67
  ):
51
68
  """
52
69
  Parses MuJoCo XML (MJCF) file and adds the bodies and joints to the given ModelBuilder.
@@ -55,9 +72,13 @@ def parse_mjcf(
55
72
  mjcf_filename (str): The filename of the MuJoCo file to parse.
56
73
  builder (ModelBuilder): The :class:`ModelBuilder` to add the bodies and joints to.
57
74
  xform (:ref:`transform <transform>`): The transform to apply to the imported mechanism.
75
+ floating (bool): If True, the root body is a free joint. If False, the root body is connected via a fixed joint to the world, unless a `base_joint` is defined.
76
+ base_joint (Union[str, dict]): The joint by which the root body is connected to the world. This can be either a string defining the joint axes of a D6 joint with comma-separated positional and angular axis names (e.g. "px,py,rz" for a D6 joint with linear axes in x, y and an angular axis in z) or a dict with joint parameters (see :meth:`ModelBuilder.add_joint`).
58
77
  density (float): The density of the shapes in kg/m^3 which will be used to calculate the body mass and inertia.
59
78
  stiffness (float): The stiffness of the joints.
60
79
  damping (float): The damping of the joints.
80
+ armature (float): Default joint armature to use if `armature` has not been defined for a joint in the MJCF.
81
+ armature_scale (float): Scaling factor to apply to the MJCF-defined joint armature values.
61
82
  contact_ke (float): The stiffness of the shape contacts.
62
83
  contact_kd (float): The damping of the shape contacts.
63
84
  contact_kf (float): The friction stiffness of the shape contacts.
@@ -67,19 +88,25 @@ def parse_mjcf(
67
88
  contact_thickness (float): The thickness to add to the shape geometry.
68
89
  limit_ke (float): The stiffness of the joint limits.
69
90
  limit_kd (float): The damping of the joint limits.
91
+ joint_limit_lower (float): The default lower joint limit if not specified in the MJCF.
92
+ joint_limit_upper (float): The default upper joint limit if not specified in the MJCF.
70
93
  scale (float): The scaling factor to apply to the imported mechanism.
71
- armature (float): Default joint armature to use if `armature` has not been defined for a joint in the MJCF.
72
- armature_scale (float): Scaling factor to apply to the MJCF-defined joint armature values.
94
+ hide_visuals (bool): If True, hide visual shapes.
95
+ parse_visuals_as_colliders (bool): If True, the geometry defined under the `visual_classes` tags is used for collision handling instead of the `collider_classes` geometries.
73
96
  parse_meshes (bool): Whether geometries of type `"mesh"` should be parsed. If False, geometries of type `"mesh"` are ignored.
74
- enable_self_collisions (bool): If True, self-collisions are enabled.
75
97
  up_axis (str): The up axis of the mechanism. Can be either `"X"`, `"Y"` or `"Z"`. The default is `"Z"`.
98
+ ignore_names (List[str]): A list of regular expressions. Bodies and joints with a name matching one of the regular expressions will be ignored.
76
99
  ignore_classes (List[str]): A list of regular expressions. Bodies and joints with a class matching one of the regular expressions will be ignored.
100
+ visual_classes (List[str]): A list of regular expressions. Visual geometries with a class matching one of the regular expressions will be parsed.
101
+ collider_classes (List[str]): A list of regular expressions. Collision geometries with a class matching one of the regular expressions will be parsed.
102
+ no_class_as_colliders: If True, geometries without a class are parsed as collision geometries. If False, geometries without a class are parsed as visual geometries.
103
+ force_show_colliders (bool): If True, the collision shapes are always shown, even if there are visual shapes.
104
+ enable_self_collisions (bool): If True, self-collisions are enabled.
105
+ ignore_inertial_definitions (bool): If True, the inertial parameters defined in the MJCF are ignored and the inertia is calculated from the shape geometry.
106
+ ensure_nonstatic_links (bool): If True, links with zero mass are given a small mass (see `static_link_mass`) to ensure they are dynamic.
107
+ static_link_mass (float): The mass to assign to links with zero mass (if `ensure_nonstatic_links` is set to True).
77
108
  collapse_fixed_joints (bool): If True, fixed joints are removed and the respective bodies are merged.
78
-
79
- Note:
80
- The inertia and masses of the bodies are calculated from the shape geometry and the given density. The values defined in the MJCF are not respected at the moment.
81
-
82
- The handling of advanced features, such as MJCF classes, is still experimental.
109
+ verbose (bool): If True, print additional information about parsing the MJCF.
83
110
  """
84
111
  if xform is None:
85
112
  xform = wp.transform()
@@ -102,13 +129,15 @@ def parse_mjcf(
102
129
  }
103
130
 
104
131
  use_degrees = True # angles are in degrees by default
105
- euler_seq = [1, 2, 3] # XYZ by default
132
+ euler_seq = [0, 1, 2] # XYZ by default
106
133
 
107
134
  compiler = root.find("compiler")
108
135
  if compiler is not None:
109
136
  use_degrees = compiler.attrib.get("angle", "degree").lower() == "degree"
110
- euler_seq = ["xyz".index(c) + 1 for c in compiler.attrib.get("eulerseq", "xyz").lower()]
137
+ euler_seq = ["xyz".index(c) for c in compiler.attrib.get("eulerseq", "xyz").lower()]
111
138
  mesh_dir = compiler.attrib.get("meshdir", ".")
139
+ else:
140
+ mesh_dir = "."
112
141
 
113
142
  mesh_assets = {}
114
143
  for asset in root.findall("asset"):
@@ -118,11 +147,10 @@ def parse_mjcf(
118
147
  # handle stl relative paths
119
148
  if not os.path.isabs(fname):
120
149
  fname = os.path.abspath(os.path.join(mjcf_dirname, fname))
121
- if "name" in mesh.attrib:
122
- mesh_assets[mesh.attrib["name"]] = fname
123
- else:
124
- name = ".".join(os.path.basename(fname).split(".")[:-1])
125
- mesh_assets[name] = fname
150
+ name = mesh.attrib.get("name", ".".join(os.path.basename(fname).split(".")[:-1]))
151
+ s = mesh.attrib.get("scale", "1.0 1.0 1.0")
152
+ s = np.fromstring(s, sep=" ", dtype=np.float32)
153
+ mesh_assets[name] = {"file": fname, "scale": s}
126
154
 
127
155
  class_parent = {}
128
156
  class_children = {}
@@ -196,14 +224,14 @@ def parse_mjcf(
196
224
  euler = np.fromstring(attrib["euler"], sep=" ")
197
225
  if use_degrees:
198
226
  euler *= np.pi / 180
199
- return wp.quat_from_euler(euler, *euler_seq)
227
+ return wp.sim.quat_from_euler(wp.vec3(euler), *euler_seq)
200
228
  if "axisangle" in attrib:
201
229
  axisangle = np.fromstring(attrib["axisangle"], sep=" ")
202
230
  angle = axisangle[3]
203
231
  if use_degrees:
204
232
  angle *= np.pi / 180
205
233
  axis = wp.normalize(wp.vec3(*axisangle[:3]))
206
- return wp.quat_from_axis_angle(axis, angle)
234
+ return wp.quat_from_axis_angle(axis, float(angle))
207
235
  if "xyaxes" in attrib:
208
236
  xyaxes = np.fromstring(attrib["xyaxes"], sep=" ")
209
237
  xaxis = wp.normalize(wp.vec3(*xyaxes[:3]))
@@ -220,26 +248,209 @@ def parse_mjcf(
220
248
  return wp.quat_from_matrix(rot_matrix)
221
249
  return wp.quat_identity()
222
250
 
223
- def parse_mesh(geom):
224
- import trimesh
251
+ def parse_shapes(defaults, body_name, link, geoms, density, visible=True, just_visual=False, incoming_xform=None):
252
+ shapes = []
253
+ for geo_count, geom in enumerate(geoms):
254
+ geom_defaults = defaults
255
+ if "class" in geom.attrib:
256
+ geom_class = geom.attrib["class"]
257
+ ignore_geom = False
258
+ for pattern in ignore_classes:
259
+ if re.match(pattern, geom_class):
260
+ ignore_geom = True
261
+ break
262
+ if ignore_geom:
263
+ continue
264
+ if geom_class in class_defaults:
265
+ geom_defaults = merge_attrib(defaults, class_defaults[geom_class])
266
+ if "geom" in geom_defaults:
267
+ geom_attrib = merge_attrib(geom_defaults["geom"], geom.attrib)
268
+ else:
269
+ geom_attrib = geom.attrib
225
270
 
226
- faces = []
227
- vertices = []
228
- stl_file = mesh_assets[geom["mesh"]]
229
- m = trimesh.load(stl_file)
271
+ geom_name = geom_attrib.get("name", f"{body_name}_geom_{geo_count}{'_visual' if just_visual else ''}")
272
+ geom_type = geom_attrib.get("type", "sphere")
273
+ if "mesh" in geom_attrib:
274
+ geom_type = "mesh"
230
275
 
231
- for v in m.vertices:
232
- vertices.append(np.array(v) * scale)
276
+ ignore_geom = False
277
+ for pattern in ignore_names:
278
+ if re.match(pattern, geom_name):
279
+ ignore_geom = True
280
+ break
281
+ if ignore_geom:
282
+ continue
233
283
 
234
- for f in m.faces:
235
- faces.append(int(f[0]))
236
- faces.append(int(f[1]))
237
- faces.append(int(f[2]))
238
- return wp.sim.Mesh(vertices, faces), m.scale
284
+ geom_size = parse_vec(geom_attrib, "size", [1.0, 1.0, 1.0]) * scale
285
+ geom_pos = parse_vec(geom_attrib, "pos", (0.0, 0.0, 0.0)) * scale
286
+ geom_rot = parse_orientation(geom_attrib)
287
+ geom_density = parse_float(geom_attrib, "density", density)
288
+
289
+ if incoming_xform is not None:
290
+ geom_pos = wp.transform_point(incoming_xform, geom_pos)
291
+ geom_rot = incoming_xform.q * geom_rot
292
+
293
+ if geom_type == "sphere":
294
+ s = builder.add_shape_sphere(
295
+ link,
296
+ pos=geom_pos,
297
+ rot=geom_rot,
298
+ radius=geom_size[0],
299
+ density=geom_density,
300
+ is_visible=visible,
301
+ has_ground_collision=not just_visual,
302
+ has_shape_collision=not just_visual,
303
+ **contact_vars,
304
+ )
305
+ shapes.append(s)
306
+
307
+ elif geom_type == "box":
308
+ s = builder.add_shape_box(
309
+ link,
310
+ pos=geom_pos,
311
+ rot=geom_rot,
312
+ hx=geom_size[0],
313
+ hy=geom_size[1],
314
+ hz=geom_size[2],
315
+ density=geom_density,
316
+ is_visible=visible,
317
+ has_ground_collision=not just_visual,
318
+ has_shape_collision=not just_visual,
319
+ **contact_vars,
320
+ )
321
+ shapes.append(s)
322
+
323
+ elif geom_type == "mesh" and parse_meshes:
324
+ import trimesh
325
+
326
+ # use force='mesh' to load the mesh as a trimesh object
327
+ # with baked in transforms, e.g. from COLLADA files
328
+ stl_file = mesh_assets[geom_attrib["mesh"]]["file"]
329
+ m = trimesh.load(stl_file, force="mesh")
330
+ if "mesh" in geom_defaults:
331
+ mesh_scale = parse_vec(geom_defaults["mesh"], "scale", mesh_assets[geom_attrib["mesh"]]["scale"])
332
+ else:
333
+ mesh_scale = mesh_assets[geom_attrib["mesh"]]["scale"]
334
+ scaling = np.array(mesh_scale) * scale
335
+ # as per the Mujoco XML reference, ignore geom size attribute
336
+ assert len(geom_size) == 3, "need to specify size for mesh geom"
337
+
338
+ if hasattr(m, "geometry"):
339
+ # multiple meshes are contained in a scene
340
+ for m_geom in m.geometry.values():
341
+ m_vertices = np.array(m_geom.vertices, dtype=np.float32) * scaling
342
+ m_faces = np.array(m_geom.faces.flatten(), dtype=np.int32)
343
+ m_mesh = Mesh(m_vertices, m_faces)
344
+ s = builder.add_shape_mesh(
345
+ body=link,
346
+ pos=geom_pos,
347
+ rot=geom_rot,
348
+ mesh=m_mesh,
349
+ density=density,
350
+ is_visible=visible,
351
+ has_ground_collision=not just_visual,
352
+ has_shape_collision=not just_visual,
353
+ **contact_vars,
354
+ )
355
+ shapes.append(s)
356
+ else:
357
+ # a single mesh
358
+ m_vertices = np.array(m.vertices, dtype=np.float32) * scaling
359
+ m_faces = np.array(m.faces.flatten(), dtype=np.int32)
360
+ m_mesh = Mesh(m_vertices, m_faces)
361
+ s = builder.add_shape_mesh(
362
+ body=link,
363
+ pos=geom_pos,
364
+ rot=geom_rot,
365
+ mesh=m_mesh,
366
+ density=density,
367
+ is_visible=visible,
368
+ has_ground_collision=not just_visual,
369
+ has_shape_collision=not just_visual,
370
+ **contact_vars,
371
+ )
372
+ shapes.append(s)
373
+
374
+ elif geom_type in {"capsule", "cylinder"}:
375
+ if "fromto" in geom_attrib:
376
+ geom_fromto = parse_vec(geom_attrib, "fromto", (0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
377
+
378
+ start = wp.vec3(geom_fromto[0:3]) * scale
379
+ end = wp.vec3(geom_fromto[3:6]) * scale
380
+
381
+ # compute rotation to align the Warp capsule (along x-axis), with mjcf fromto direction
382
+ axis = wp.normalize(end - start)
383
+ angle = math.acos(wp.dot(axis, wp.vec3(0.0, 1.0, 0.0)))
384
+ axis = wp.normalize(wp.cross(axis, wp.vec3(0.0, 1.0, 0.0)))
385
+
386
+ geom_pos = (start + end) * 0.5
387
+ geom_rot = wp.quat_from_axis_angle(axis, -angle)
388
+
389
+ geom_radius = geom_size[0]
390
+ geom_height = wp.length(end - start) * 0.5
391
+ geom_up_axis = 1
392
+
393
+ else:
394
+ geom_radius = geom_size[0]
395
+ geom_height = geom_size[1]
396
+ geom_up_axis = up_axis
397
+
398
+ if geom_type == "cylinder":
399
+ s = builder.add_shape_cylinder(
400
+ link,
401
+ pos=geom_pos,
402
+ rot=geom_rot,
403
+ radius=geom_radius,
404
+ half_height=geom_height,
405
+ density=density,
406
+ up_axis=geom_up_axis,
407
+ is_visible=visible,
408
+ has_ground_collision=not just_visual,
409
+ has_shape_collision=not just_visual,
410
+ **contact_vars,
411
+ )
412
+ shapes.append(s)
413
+ else:
414
+ s = builder.add_shape_capsule(
415
+ link,
416
+ pos=geom_pos,
417
+ rot=geom_rot,
418
+ radius=geom_radius,
419
+ half_height=geom_height,
420
+ density=density,
421
+ up_axis=geom_up_axis,
422
+ is_visible=visible,
423
+ has_ground_collision=not just_visual,
424
+ has_shape_collision=not just_visual,
425
+ **contact_vars,
426
+ )
427
+ shapes.append(s)
239
428
 
240
- def parse_body(body, parent, incoming_defaults: dict):
241
- body_class = body.get("childclass")
429
+ elif geom_type == "plane":
430
+ normal = wp.quat_rotate(geom_rot, wp.vec3(0.0, 0.0, 1.0))
431
+ p = wp.dot(geom_pos, normal)
432
+ s = builder.add_shape_plane(
433
+ body=link,
434
+ plane=(*normal, p),
435
+ width=geom_size[0],
436
+ length=geom_size[1],
437
+ is_visible=visible,
438
+ has_ground_collision=False,
439
+ has_shape_collision=not just_visual,
440
+ **contact_vars,
441
+ )
442
+ shapes.append(s)
443
+
444
+ else:
445
+ if verbose:
446
+ print(f"MJCF parsing shape {geom_name} issue: geom type {geom_type} is unsupported")
447
+
448
+ return shapes
449
+
450
+ def parse_body(body, parent, incoming_defaults: dict, childclass: str = None):
451
+ body_class = body.get("class")
242
452
  if body_class is None:
453
+ body_class = childclass
243
454
  defaults = incoming_defaults
244
455
  else:
245
456
  for pattern in ignore_classes:
@@ -251,6 +462,7 @@ def parse_mjcf(
251
462
  else:
252
463
  body_attrib = body.attrib
253
464
  body_name = body_attrib["name"]
465
+ body_name = body_name.replace("-", "_") # ensure valid USD path
254
466
  body_pos = parse_vec(body_attrib, "pos", (0.0, 0.0, 0.0))
255
467
  body_ori = parse_orientation(body_attrib)
256
468
  if parent == -1:
@@ -270,11 +482,17 @@ def parse_mjcf(
270
482
  if len(freejoint_tags) > 0:
271
483
  joint_type = wp.sim.JOINT_FREE
272
484
  joint_name.append(freejoint_tags[0].attrib.get("name", f"{body_name}_freejoint"))
485
+ joint_armature.append(0.0)
273
486
  else:
274
487
  joints = body.findall("joint")
275
488
  for _i, joint in enumerate(joints):
276
- if "joint" in defaults:
277
- joint_attrib = merge_attrib(defaults["joint"], joint.attrib)
489
+ joint_defaults = defaults
490
+ if "class" in joint.attrib:
491
+ joint_class = joint.attrib["class"]
492
+ if joint_class in class_defaults:
493
+ joint_defaults = merge_attrib(joint_defaults, class_defaults[joint_class])
494
+ if "joint" in joint_defaults:
495
+ joint_attrib = merge_attrib(joint_defaults["joint"], joint.attrib)
278
496
  else:
279
497
  joint_attrib = joint.attrib
280
498
 
@@ -283,7 +501,7 @@ def parse_mjcf(
283
501
 
284
502
  joint_name.append(joint_attrib["name"])
285
503
  joint_pos.append(parse_vec(joint_attrib, "pos", (0.0, 0.0, 0.0)) * scale)
286
- joint_range = parse_vec(joint_attrib, "range", (-3.0, 3.0))
504
+ joint_range = parse_vec(joint_attrib, "range", (joint_limit_lower, joint_limit_upper))
287
505
  joint_armature.append(parse_float(joint_attrib, "armature", armature) * armature_scale)
288
506
 
289
507
  if joint_type_str == "free":
@@ -297,10 +515,12 @@ def parse_mjcf(
297
515
  if stiffness > 0.0 or "stiffness" in joint_attrib:
298
516
  mode = wp.sim.JOINT_MODE_TARGET_POSITION
299
517
  axis_vec = parse_vec(joint_attrib, "axis", (0.0, 0.0, 0.0))
300
- ax = wp.sim.model.JointAxis(
518
+ limit_lower = np.deg2rad(joint_range[0]) if is_angular and use_degrees else joint_range[0]
519
+ limit_upper = np.deg2rad(joint_range[1]) if is_angular and use_degrees else joint_range[1]
520
+ ax = wp.sim.JointAxis(
301
521
  axis=axis_vec,
302
- limit_lower=(np.deg2rad(joint_range[0]) if is_angular and use_degrees else joint_range[0]),
303
- limit_upper=(np.deg2rad(joint_range[1]) if is_angular and use_degrees else joint_range[1]),
522
+ limit_lower=limit_lower,
523
+ limit_upper=limit_upper,
304
524
  target_ke=parse_float(joint_attrib, "stiffness", stiffness),
305
525
  target_kd=parse_float(joint_attrib, "damping", damping),
306
526
  limit_ke=limit_ke,
@@ -333,23 +553,85 @@ def parse_mjcf(
333
553
  else:
334
554
  joint_type = wp.sim.JOINT_D6
335
555
 
336
- joint_pos = joint_pos[0] if len(joint_pos) > 0 else (0.0, 0.0, 0.0)
337
- builder.add_joint(
338
- joint_type,
339
- parent,
340
- link,
341
- linear_axes,
342
- angular_axes,
343
- name="_".join(joint_name),
344
- parent_xform=wp.transform(body_pos + joint_pos, body_ori),
345
- child_xform=wp.transform(joint_pos, wp.quat_identity()),
346
- armature=joint_armature[0] if len(joint_armature) > 0 else armature,
347
- )
556
+ if len(freejoint_tags) > 0 and parent == -1 and (base_joint is not None or floating is not None):
557
+ joint_pos = joint_pos[0] if len(joint_pos) > 0 else (0.0, 0.0, 0.0)
558
+ _xform = wp.transform(body_pos + joint_pos, body_ori)
559
+
560
+ if base_joint is not None:
561
+ # in case of a given base joint, the position is applied first, the rotation only
562
+ # after the base joint itself to not rotate its axis
563
+ base_parent_xform = wp.transform(_xform.p, wp.quat_identity())
564
+ base_child_xform = wp.transform((0.0, 0.0, 0.0), wp.quat_inverse(_xform.q))
565
+ if isinstance(base_joint, str):
566
+ axes = base_joint.lower().split(",")
567
+ axes = [ax.strip() for ax in axes]
568
+ linear_axes = [ax[-1] for ax in axes if ax[0] in {"l", "p"}]
569
+ angular_axes = [ax[-1] for ax in axes if ax[0] in {"a", "r"}]
570
+ axes = {
571
+ "x": [1.0, 0.0, 0.0],
572
+ "y": [0.0, 1.0, 0.0],
573
+ "z": [0.0, 0.0, 1.0],
574
+ }
575
+ builder.add_joint_d6(
576
+ linear_axes=[wp.sim.JointAxis(axes[a]) for a in linear_axes],
577
+ angular_axes=[wp.sim.JointAxis(axes[a]) for a in angular_axes],
578
+ parent_xform=base_parent_xform,
579
+ child_xform=base_child_xform,
580
+ parent=-1,
581
+ child=link,
582
+ name="base_joint",
583
+ )
584
+ elif isinstance(base_joint, dict):
585
+ base_joint["parent"] = -1
586
+ base_joint["child"] = root
587
+ base_joint["parent_xform"] = base_parent_xform
588
+ base_joint["child_xform"] = base_child_xform
589
+ base_joint["name"] = "base_joint"
590
+ builder.add_joint(**base_joint)
591
+ else:
592
+ raise ValueError(
593
+ "base_joint must be a comma-separated string of joint axes or a dict with joint parameters"
594
+ )
595
+ elif floating:
596
+ builder.add_joint_free(link, name="floating_base")
597
+
598
+ # set dofs to transform
599
+ start = builder.joint_q_start[link]
600
+
601
+ builder.joint_q[start + 0] = _xform.p[0]
602
+ builder.joint_q[start + 1] = _xform.p[1]
603
+ builder.joint_q[start + 2] = _xform.p[2]
604
+
605
+ builder.joint_q[start + 3] = _xform.q[0]
606
+ builder.joint_q[start + 4] = _xform.q[1]
607
+ builder.joint_q[start + 5] = _xform.q[2]
608
+ builder.joint_q[start + 6] = _xform.q[3]
609
+ else:
610
+ builder.add_joint_fixed(-1, link, parent_xform=_xform, name="fixed_base")
611
+
612
+ else:
613
+ joint_pos = joint_pos[0] if len(joint_pos) > 0 else (0.0, 0.0, 0.0)
614
+ if len(joint_name) == 0:
615
+ joint_name = [f"{body_name}_joint"]
616
+ builder.add_joint(
617
+ joint_type,
618
+ parent,
619
+ link,
620
+ linear_axes,
621
+ angular_axes,
622
+ name="_".join(joint_name),
623
+ parent_xform=wp.transform(body_pos + joint_pos, body_ori),
624
+ child_xform=wp.transform(joint_pos, wp.quat_identity()),
625
+ armature=joint_armature[0] if len(joint_armature) > 0 else armature,
626
+ )
348
627
 
349
628
  # -----------------
350
629
  # add shapes
351
630
 
352
- for geo_count, geom in enumerate(body.findall("geom")):
631
+ geoms = body.findall("geom")
632
+ visuals = []
633
+ colliders = []
634
+ for geo_count, geom in enumerate(geoms):
353
635
  geom_defaults = defaults
354
636
  if "class" in geom.attrib:
355
637
  geom_class = geom.attrib["class"]
@@ -368,125 +650,137 @@ def parse_mjcf(
368
650
  geom_attrib = geom.attrib
369
651
 
370
652
  geom_name = geom_attrib.get("name", f"{body_name}_geom_{geo_count}")
371
- geom_type = geom_attrib.get("type", "sphere")
372
- if "mesh" in geom_attrib:
373
- geom_type = "mesh"
374
-
375
- geom_size = parse_vec(geom_attrib, "size", [1.0, 1.0, 1.0]) * scale
376
- geom_pos = parse_vec(geom_attrib, "pos", (0.0, 0.0, 0.0)) * scale
377
- geom_rot = parse_orientation(geom_attrib)
378
- geom_density = parse_float(geom_attrib, "density", density)
379
-
380
- if geom_type == "sphere":
381
- builder.add_shape_sphere(
382
- link,
383
- pos=geom_pos,
384
- rot=geom_rot,
385
- radius=geom_size[0],
386
- density=geom_density,
387
- **contact_vars,
388
- )
389
-
390
- elif geom_type == "box":
391
- builder.add_shape_box(
392
- link,
393
- pos=geom_pos,
394
- rot=geom_rot,
395
- hx=geom_size[0],
396
- hy=geom_size[1],
397
- hz=geom_size[2],
398
- density=geom_density,
399
- **contact_vars,
400
- )
401
-
402
- elif geom_type == "mesh" and parse_meshes:
403
- mesh, _ = parse_mesh(geom_attrib)
404
- if "mesh" in defaults:
405
- mesh_scale = parse_vec(defaults["mesh"], "scale", [1.0, 1.0, 1.0])
406
- else:
407
- mesh_scale = [1.0, 1.0, 1.0]
408
- # as per the Mujoco XML reference, ignore geom size attribute
409
- assert len(geom_size) == 3, "need to specify size for mesh geom"
410
- builder.add_shape_mesh(
411
- body=link,
412
- pos=geom_pos,
413
- rot=geom_rot,
414
- mesh=mesh,
415
- scale=mesh_scale,
416
- density=density,
417
- **contact_vars,
418
- )
419
-
420
- elif geom_type in {"capsule", "cylinder"}:
421
- if "fromto" in geom_attrib:
422
- geom_fromto = parse_vec(geom_attrib, "fromto", (0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
423
-
424
- start = wp.vec3(geom_fromto[0:3]) * scale
425
- end = wp.vec3(geom_fromto[3:6]) * scale
426
653
 
427
- # compute rotation to align the Warp capsule (along x-axis), with mjcf fromto direction
428
- axis = wp.normalize(end - start)
429
- angle = math.acos(wp.dot(axis, wp.vec3(0.0, 1.0, 0.0)))
430
- axis = wp.normalize(wp.cross(axis, wp.vec3(0.0, 1.0, 0.0)))
431
-
432
- geom_pos = (start + end) * 0.5
433
- geom_rot = wp.quat_from_axis_angle(axis, -angle)
434
-
435
- geom_radius = geom_size[0]
436
- geom_height = wp.length(end - start) * 0.5
437
- geom_up_axis = 1
438
-
439
- else:
440
- geom_radius = geom_size[0]
441
- geom_height = geom_size[1]
442
- geom_up_axis = up_axis
443
-
444
- if geom_type == "cylinder":
445
- builder.add_shape_cylinder(
446
- link,
447
- pos=geom_pos,
448
- rot=geom_rot,
449
- radius=geom_radius,
450
- half_height=geom_height,
451
- density=density,
452
- up_axis=geom_up_axis,
453
- **contact_vars,
454
- )
654
+ if "class" in geom.attrib:
655
+ for pattern in visual_classes:
656
+ if re.match(pattern, geom_class):
657
+ visuals.append(geom)
658
+ break
659
+ for pattern in collider_classes:
660
+ if re.match(pattern, geom_class):
661
+ colliders.append(geom)
662
+ break
663
+ else:
664
+ no_class_class = "collision" if no_class_as_colliders else "visual"
665
+ if verbose:
666
+ print(f"MJCF parsing shape {geom_name} issue: no class defined for geom, assuming {no_class_class}")
667
+ if no_class_as_colliders:
668
+ colliders.append(geom)
455
669
  else:
456
- builder.add_shape_capsule(
457
- link,
458
- pos=geom_pos,
459
- rot=geom_rot,
460
- radius=geom_radius,
461
- half_height=geom_height,
462
- density=density,
463
- up_axis=geom_up_axis,
464
- **contact_vars,
465
- )
670
+ visuals.append(geom)
466
671
 
672
+ if parse_visuals_as_colliders:
673
+ colliders = visuals
674
+ else:
675
+ s = parse_shapes(
676
+ defaults, body_name, link, visuals, density=0.0, just_visual=True, visible=not hide_visuals
677
+ )
678
+ visual_shapes.extend(s)
679
+
680
+ show_colliders = force_show_colliders
681
+ if parse_visuals_as_colliders:
682
+ show_colliders = True
683
+ elif len(visuals) == 0:
684
+ # we need to show the collision shapes since there are no visual shapes
685
+ show_colliders = True
686
+
687
+ parse_shapes(defaults, body_name, link, colliders, density, visible=show_colliders)
688
+
689
+ m = builder.body_mass[link]
690
+ if not ignore_inertial_definitions and body.find("inertial") is not None:
691
+ inertial = body.find("inertial")
692
+ if "inertial" in defaults:
693
+ inertial_attrib = merge_attrib(defaults["inertial"], inertial.attrib)
467
694
  else:
468
- print(f"MJCF parsing shape {geom_name} issue: geom type {geom_type} is unsupported")
695
+ inertial_attrib = inertial.attrib
696
+ # overwrite inertial parameters if defined
697
+ inertial_pos = parse_vec(inertial_attrib, "pos", (0.0, 0.0, 0.0)) * scale
698
+ inertial_rot = parse_orientation(inertial_attrib)
699
+
700
+ inertial_frame = wp.transform(inertial_pos, inertial_rot)
701
+ com = inertial_frame.p
702
+ if inertial_attrib.get("diaginertia") is not None:
703
+ diaginertia = parse_vec(inertial_attrib, "diaginertia", None)
704
+ I_m = np.zeros((3, 3))
705
+ I_m[0, 0] = diaginertia[0] * scale**2
706
+ I_m[1, 1] = diaginertia[1] * scale**2
707
+ I_m[2, 2] = diaginertia[2] * scale**2
708
+ else:
709
+ fullinertia = inertial_attrib.get("fullinertia")
710
+ assert fullinertia is not None
711
+ fullinertia = np.fromstring(fullinertia, sep=" ", dtype=np.float32)
712
+ I_m = np.zeros((3, 3))
713
+ I_m[0, 0] = fullinertia[0] * scale**2
714
+ I_m[1, 1] = fullinertia[1] * scale**2
715
+ I_m[2, 2] = fullinertia[2] * scale**2
716
+ I_m[0, 1] = fullinertia[3] * scale**2
717
+ I_m[0, 2] = fullinertia[4] * scale**2
718
+ I_m[1, 2] = fullinertia[5] * scale**2
719
+ I_m[1, 0] = I_m[0, 1]
720
+ I_m[2, 0] = I_m[0, 2]
721
+ I_m[2, 1] = I_m[1, 2]
722
+ rot = wp.quat_to_matrix(inertial_frame.q)
723
+ I_m = rot @ wp.mat33(I_m)
724
+ m = float(inertial_attrib.get("mass", "0"))
725
+ builder.body_mass[link] = m
726
+ builder.body_inv_mass[link] = 1.0 / m if m > 0.0 else 0.0
727
+ builder.body_com[link] = com
728
+ builder.body_inertia[link] = I_m
729
+ if any(x for x in I_m):
730
+ builder.body_inv_inertia[link] = wp.inverse(I_m)
731
+ else:
732
+ builder.body_inv_inertia[link] = I_m
733
+ if m == 0.0 and ensure_nonstatic_links:
734
+ # set the mass to something nonzero to ensure the body is dynamic
735
+ m = static_link_mass
736
+ # cube with side length 0.5
737
+ I_m = wp.mat33(np.eye(3)) * m / 12.0 * (0.5 * scale) ** 2 * 2.0
738
+ I_m += wp.mat33(armature * np.eye(3))
739
+ builder.body_mass[link] = m
740
+ builder.body_inv_mass[link] = 1.0 / m
741
+ builder.body_inertia[link] = I_m
742
+ builder.body_inv_inertia[link] = wp.inverse(I_m)
469
743
 
470
744
  # -----------------
471
745
  # recurse
472
746
 
473
747
  for child in body.findall("body"):
474
- parse_body(child, link, defaults)
748
+ _childclass = body.get("childclass")
749
+ if _childclass is None:
750
+ _childclass = childclass
751
+ _incoming_defaults = defaults
752
+ else:
753
+ _incoming_defaults = merge_attrib(defaults, class_defaults[_childclass])
754
+ parse_body(child, link, _incoming_defaults, childclass=_childclass)
475
755
 
476
756
  # -----------------
477
757
  # start articulation
478
758
 
759
+ visual_shapes = []
479
760
  start_shape_count = len(builder.shape_geo_type)
480
761
  builder.add_articulation()
481
762
 
482
763
  world = root.find("worldbody")
483
764
  world_class = get_class(world)
484
765
  world_defaults = merge_attrib(class_defaults["__all__"], class_defaults.get(world_class, {}))
766
+
767
+ # -----------------
768
+ # add bodies
769
+
485
770
  for body in world.findall("body"):
486
771
  parse_body(body, -1, world_defaults)
487
772
 
773
+ # -----------------
774
+ # add static geoms
775
+
776
+ parse_shapes(world_defaults, "world", -1, world.findall("geom"), density, incoming_xform=xform)
777
+
488
778
  end_shape_count = len(builder.shape_geo_type)
489
779
 
780
+ for i in range(start_shape_count, end_shape_count):
781
+ for j in visual_shapes:
782
+ builder.shape_collision_filter_pairs.add((i, j))
783
+
490
784
  if not enable_self_collisions:
491
785
  for i in range(start_shape_count, end_shape_count):
492
786
  for j in range(i + 1, end_shape_count):