warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/sim/import_mjcf.py CHANGED
@@ -1,45 +1,69 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
-
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
8
15
 
9
16
  import math
10
17
  import os
11
18
  import re
12
19
  import xml.etree.ElementTree as ET
20
+ from typing import Union
13
21
 
14
22
  import numpy as np
15
23
 
16
24
  import warp as wp
25
+ from warp.sim.model import Mesh
17
26
 
18
27
 
19
28
  def parse_mjcf(
20
29
  mjcf_filename,
21
30
  builder,
22
31
  xform=None,
32
+ floating=False,
33
+ base_joint: Union[dict, str, None] = None,
23
34
  density=1000.0,
24
- stiffness=0.0,
25
- damping=0.0,
26
- contact_ke=1000.0,
27
- contact_kd=100.0,
28
- contact_kf=100.0,
35
+ stiffness=100.0,
36
+ damping=10.0,
37
+ armature=0.0,
38
+ armature_scale=1.0,
39
+ contact_ke=1.0e4,
40
+ contact_kd=1.0e3,
41
+ contact_kf=1.0e2,
29
42
  contact_ka=0.0,
30
- contact_mu=0.5,
43
+ contact_mu=0.25,
31
44
  contact_restitution=0.5,
32
45
  contact_thickness=0.0,
33
46
  limit_ke=100.0,
34
47
  limit_kd=10.0,
48
+ joint_limit_lower=-1e6,
49
+ joint_limit_upper=1e6,
35
50
  scale=1.0,
36
- armature=0.0,
37
- armature_scale=1.0,
51
+ hide_visuals=False,
52
+ parse_visuals_as_colliders=False,
38
53
  parse_meshes=True,
39
- enable_self_collisions=False,
40
54
  up_axis="Z",
55
+ ignore_names=(),
41
56
  ignore_classes=None,
57
+ visual_classes=("visual",),
58
+ collider_classes=("collision",),
59
+ no_class_as_colliders=True,
60
+ force_show_colliders=False,
61
+ enable_self_collisions=False,
62
+ ignore_inertial_definitions=True,
63
+ ensure_nonstatic_links=True,
64
+ static_link_mass=1e-2,
42
65
  collapse_fixed_joints=False,
66
+ verbose=False,
43
67
  ):
44
68
  """
45
69
  Parses MuJoCo XML (MJCF) file and adds the bodies and joints to the given ModelBuilder.
@@ -48,9 +72,13 @@ def parse_mjcf(
48
72
  mjcf_filename (str): The filename of the MuJoCo file to parse.
49
73
  builder (ModelBuilder): The :class:`ModelBuilder` to add the bodies and joints to.
50
74
  xform (:ref:`transform <transform>`): The transform to apply to the imported mechanism.
75
+ floating (bool): If True, the root body is a free joint. If False, the root body is connected via a fixed joint to the world, unless a `base_joint` is defined.
76
+ base_joint (Union[str, dict]): The joint by which the root body is connected to the world. This can be either a string defining the joint axes of a D6 joint with comma-separated positional and angular axis names (e.g. "px,py,rz" for a D6 joint with linear axes in x, y and an angular axis in z) or a dict with joint parameters (see :meth:`ModelBuilder.add_joint`).
51
77
  density (float): The density of the shapes in kg/m^3 which will be used to calculate the body mass and inertia.
52
78
  stiffness (float): The stiffness of the joints.
53
79
  damping (float): The damping of the joints.
80
+ armature (float): Default joint armature to use if `armature` has not been defined for a joint in the MJCF.
81
+ armature_scale (float): Scaling factor to apply to the MJCF-defined joint armature values.
54
82
  contact_ke (float): The stiffness of the shape contacts.
55
83
  contact_kd (float): The damping of the shape contacts.
56
84
  contact_kf (float): The friction stiffness of the shape contacts.
@@ -60,19 +88,25 @@ def parse_mjcf(
60
88
  contact_thickness (float): The thickness to add to the shape geometry.
61
89
  limit_ke (float): The stiffness of the joint limits.
62
90
  limit_kd (float): The damping of the joint limits.
91
+ joint_limit_lower (float): The default lower joint limit if not specified in the MJCF.
92
+ joint_limit_upper (float): The default upper joint limit if not specified in the MJCF.
63
93
  scale (float): The scaling factor to apply to the imported mechanism.
64
- armature (float): Default joint armature to use if `armature` has not been defined for a joint in the MJCF.
65
- armature_scale (float): Scaling factor to apply to the MJCF-defined joint armature values.
94
+ hide_visuals (bool): If True, hide visual shapes.
95
+ parse_visuals_as_colliders (bool): If True, the geometry defined under the `visual_classes` tags is used for collision handling instead of the `collider_classes` geometries.
66
96
  parse_meshes (bool): Whether geometries of type `"mesh"` should be parsed. If False, geometries of type `"mesh"` are ignored.
67
- enable_self_collisions (bool): If True, self-collisions are enabled.
68
97
  up_axis (str): The up axis of the mechanism. Can be either `"X"`, `"Y"` or `"Z"`. The default is `"Z"`.
98
+ ignore_names (List[str]): A list of regular expressions. Bodies and joints with a name matching one of the regular expressions will be ignored.
69
99
  ignore_classes (List[str]): A list of regular expressions. Bodies and joints with a class matching one of the regular expressions will be ignored.
100
+ visual_classes (List[str]): A list of regular expressions. Visual geometries with a class matching one of the regular expressions will be parsed.
101
+ collider_classes (List[str]): A list of regular expressions. Collision geometries with a class matching one of the regular expressions will be parsed.
102
+ no_class_as_colliders: If True, geometries without a class are parsed as collision geometries. If False, geometries without a class are parsed as visual geometries.
103
+ force_show_colliders (bool): If True, the collision shapes are always shown, even if there are visual shapes.
104
+ enable_self_collisions (bool): If True, self-collisions are enabled.
105
+ ignore_inertial_definitions (bool): If True, the inertial parameters defined in the MJCF are ignored and the inertia is calculated from the shape geometry.
106
+ ensure_nonstatic_links (bool): If True, links with zero mass are given a small mass (see `static_link_mass`) to ensure they are dynamic.
107
+ static_link_mass (float): The mass to assign to links with zero mass (if `ensure_nonstatic_links` is set to True).
70
108
  collapse_fixed_joints (bool): If True, fixed joints are removed and the respective bodies are merged.
71
-
72
- Note:
73
- The inertia and masses of the bodies are calculated from the shape geometry and the given density. The values defined in the MJCF are not respected at the moment.
74
-
75
- The handling of advanced features, such as MJCF classes, is still experimental.
109
+ verbose (bool): If True, print additional information about parsing the MJCF.
76
110
  """
77
111
  if xform is None:
78
112
  xform = wp.transform()
@@ -95,13 +129,15 @@ def parse_mjcf(
95
129
  }
96
130
 
97
131
  use_degrees = True # angles are in degrees by default
98
- euler_seq = [1, 2, 3] # XYZ by default
132
+ euler_seq = [0, 1, 2] # XYZ by default
99
133
 
100
134
  compiler = root.find("compiler")
101
135
  if compiler is not None:
102
136
  use_degrees = compiler.attrib.get("angle", "degree").lower() == "degree"
103
- euler_seq = ["xyz".index(c) + 1 for c in compiler.attrib.get("eulerseq", "xyz").lower()]
137
+ euler_seq = ["xyz".index(c) for c in compiler.attrib.get("eulerseq", "xyz").lower()]
104
138
  mesh_dir = compiler.attrib.get("meshdir", ".")
139
+ else:
140
+ mesh_dir = "."
105
141
 
106
142
  mesh_assets = {}
107
143
  for asset in root.findall("asset"):
@@ -111,11 +147,10 @@ def parse_mjcf(
111
147
  # handle stl relative paths
112
148
  if not os.path.isabs(fname):
113
149
  fname = os.path.abspath(os.path.join(mjcf_dirname, fname))
114
- if "name" in mesh.attrib:
115
- mesh_assets[mesh.attrib["name"]] = fname
116
- else:
117
- name = ".".join(os.path.basename(fname).split(".")[:-1])
118
- mesh_assets[name] = fname
150
+ name = mesh.attrib.get("name", ".".join(os.path.basename(fname).split(".")[:-1]))
151
+ s = mesh.attrib.get("scale", "1.0 1.0 1.0")
152
+ s = np.fromstring(s, sep=" ", dtype=np.float32)
153
+ mesh_assets[name] = {"file": fname, "scale": s}
119
154
 
120
155
  class_parent = {}
121
156
  class_children = {}
@@ -189,14 +224,14 @@ def parse_mjcf(
189
224
  euler = np.fromstring(attrib["euler"], sep=" ")
190
225
  if use_degrees:
191
226
  euler *= np.pi / 180
192
- return wp.quat_from_euler(euler, *euler_seq)
227
+ return wp.sim.quat_from_euler(wp.vec3(euler), *euler_seq)
193
228
  if "axisangle" in attrib:
194
229
  axisangle = np.fromstring(attrib["axisangle"], sep=" ")
195
230
  angle = axisangle[3]
196
231
  if use_degrees:
197
232
  angle *= np.pi / 180
198
233
  axis = wp.normalize(wp.vec3(*axisangle[:3]))
199
- return wp.quat_from_axis_angle(axis, angle)
234
+ return wp.quat_from_axis_angle(axis, float(angle))
200
235
  if "xyaxes" in attrib:
201
236
  xyaxes = np.fromstring(attrib["xyaxes"], sep=" ")
202
237
  xaxis = wp.normalize(wp.vec3(*xyaxes[:3]))
@@ -213,26 +248,209 @@ def parse_mjcf(
213
248
  return wp.quat_from_matrix(rot_matrix)
214
249
  return wp.quat_identity()
215
250
 
216
- def parse_mesh(geom):
217
- import trimesh
251
+ def parse_shapes(defaults, body_name, link, geoms, density, visible=True, just_visual=False, incoming_xform=None):
252
+ shapes = []
253
+ for geo_count, geom in enumerate(geoms):
254
+ geom_defaults = defaults
255
+ if "class" in geom.attrib:
256
+ geom_class = geom.attrib["class"]
257
+ ignore_geom = False
258
+ for pattern in ignore_classes:
259
+ if re.match(pattern, geom_class):
260
+ ignore_geom = True
261
+ break
262
+ if ignore_geom:
263
+ continue
264
+ if geom_class in class_defaults:
265
+ geom_defaults = merge_attrib(defaults, class_defaults[geom_class])
266
+ if "geom" in geom_defaults:
267
+ geom_attrib = merge_attrib(geom_defaults["geom"], geom.attrib)
268
+ else:
269
+ geom_attrib = geom.attrib
270
+
271
+ geom_name = geom_attrib.get("name", f"{body_name}_geom_{geo_count}{'_visual' if just_visual else ''}")
272
+ geom_type = geom_attrib.get("type", "sphere")
273
+ if "mesh" in geom_attrib:
274
+ geom_type = "mesh"
275
+
276
+ ignore_geom = False
277
+ for pattern in ignore_names:
278
+ if re.match(pattern, geom_name):
279
+ ignore_geom = True
280
+ break
281
+ if ignore_geom:
282
+ continue
283
+
284
+ geom_size = parse_vec(geom_attrib, "size", [1.0, 1.0, 1.0]) * scale
285
+ geom_pos = parse_vec(geom_attrib, "pos", (0.0, 0.0, 0.0)) * scale
286
+ geom_rot = parse_orientation(geom_attrib)
287
+ geom_density = parse_float(geom_attrib, "density", density)
288
+
289
+ if incoming_xform is not None:
290
+ geom_pos = wp.transform_point(incoming_xform, geom_pos)
291
+ geom_rot = incoming_xform.q * geom_rot
292
+
293
+ if geom_type == "sphere":
294
+ s = builder.add_shape_sphere(
295
+ link,
296
+ pos=geom_pos,
297
+ rot=geom_rot,
298
+ radius=geom_size[0],
299
+ density=geom_density,
300
+ is_visible=visible,
301
+ has_ground_collision=not just_visual,
302
+ has_shape_collision=not just_visual,
303
+ **contact_vars,
304
+ )
305
+ shapes.append(s)
306
+
307
+ elif geom_type == "box":
308
+ s = builder.add_shape_box(
309
+ link,
310
+ pos=geom_pos,
311
+ rot=geom_rot,
312
+ hx=geom_size[0],
313
+ hy=geom_size[1],
314
+ hz=geom_size[2],
315
+ density=geom_density,
316
+ is_visible=visible,
317
+ has_ground_collision=not just_visual,
318
+ has_shape_collision=not just_visual,
319
+ **contact_vars,
320
+ )
321
+ shapes.append(s)
322
+
323
+ elif geom_type == "mesh" and parse_meshes:
324
+ import trimesh
325
+
326
+ # use force='mesh' to load the mesh as a trimesh object
327
+ # with baked in transforms, e.g. from COLLADA files
328
+ stl_file = mesh_assets[geom_attrib["mesh"]]["file"]
329
+ m = trimesh.load(stl_file, force="mesh")
330
+ if "mesh" in geom_defaults:
331
+ mesh_scale = parse_vec(geom_defaults["mesh"], "scale", mesh_assets[geom_attrib["mesh"]]["scale"])
332
+ else:
333
+ mesh_scale = mesh_assets[geom_attrib["mesh"]]["scale"]
334
+ scaling = np.array(mesh_scale) * scale
335
+ # as per the Mujoco XML reference, ignore geom size attribute
336
+ assert len(geom_size) == 3, "need to specify size for mesh geom"
337
+
338
+ if hasattr(m, "geometry"):
339
+ # multiple meshes are contained in a scene
340
+ for m_geom in m.geometry.values():
341
+ m_vertices = np.array(m_geom.vertices, dtype=np.float32) * scaling
342
+ m_faces = np.array(m_geom.faces.flatten(), dtype=np.int32)
343
+ m_mesh = Mesh(m_vertices, m_faces)
344
+ s = builder.add_shape_mesh(
345
+ body=link,
346
+ pos=geom_pos,
347
+ rot=geom_rot,
348
+ mesh=m_mesh,
349
+ density=density,
350
+ is_visible=visible,
351
+ has_ground_collision=not just_visual,
352
+ has_shape_collision=not just_visual,
353
+ **contact_vars,
354
+ )
355
+ shapes.append(s)
356
+ else:
357
+ # a single mesh
358
+ m_vertices = np.array(m.vertices, dtype=np.float32) * scaling
359
+ m_faces = np.array(m.faces.flatten(), dtype=np.int32)
360
+ m_mesh = Mesh(m_vertices, m_faces)
361
+ s = builder.add_shape_mesh(
362
+ body=link,
363
+ pos=geom_pos,
364
+ rot=geom_rot,
365
+ mesh=m_mesh,
366
+ density=density,
367
+ is_visible=visible,
368
+ has_ground_collision=not just_visual,
369
+ has_shape_collision=not just_visual,
370
+ **contact_vars,
371
+ )
372
+ shapes.append(s)
373
+
374
+ elif geom_type in {"capsule", "cylinder"}:
375
+ if "fromto" in geom_attrib:
376
+ geom_fromto = parse_vec(geom_attrib, "fromto", (0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
377
+
378
+ start = wp.vec3(geom_fromto[0:3]) * scale
379
+ end = wp.vec3(geom_fromto[3:6]) * scale
380
+
381
+ # compute rotation to align the Warp capsule (along x-axis), with mjcf fromto direction
382
+ axis = wp.normalize(end - start)
383
+ angle = math.acos(wp.dot(axis, wp.vec3(0.0, 1.0, 0.0)))
384
+ axis = wp.normalize(wp.cross(axis, wp.vec3(0.0, 1.0, 0.0)))
385
+
386
+ geom_pos = (start + end) * 0.5
387
+ geom_rot = wp.quat_from_axis_angle(axis, -angle)
388
+
389
+ geom_radius = geom_size[0]
390
+ geom_height = wp.length(end - start) * 0.5
391
+ geom_up_axis = 1
392
+
393
+ else:
394
+ geom_radius = geom_size[0]
395
+ geom_height = geom_size[1]
396
+ geom_up_axis = up_axis
397
+
398
+ if geom_type == "cylinder":
399
+ s = builder.add_shape_cylinder(
400
+ link,
401
+ pos=geom_pos,
402
+ rot=geom_rot,
403
+ radius=geom_radius,
404
+ half_height=geom_height,
405
+ density=density,
406
+ up_axis=geom_up_axis,
407
+ is_visible=visible,
408
+ has_ground_collision=not just_visual,
409
+ has_shape_collision=not just_visual,
410
+ **contact_vars,
411
+ )
412
+ shapes.append(s)
413
+ else:
414
+ s = builder.add_shape_capsule(
415
+ link,
416
+ pos=geom_pos,
417
+ rot=geom_rot,
418
+ radius=geom_radius,
419
+ half_height=geom_height,
420
+ density=density,
421
+ up_axis=geom_up_axis,
422
+ is_visible=visible,
423
+ has_ground_collision=not just_visual,
424
+ has_shape_collision=not just_visual,
425
+ **contact_vars,
426
+ )
427
+ shapes.append(s)
218
428
 
219
- faces = []
220
- vertices = []
221
- stl_file = mesh_assets[geom["mesh"]]
222
- m = trimesh.load(stl_file)
429
+ elif geom_type == "plane":
430
+ normal = wp.quat_rotate(geom_rot, wp.vec3(0.0, 0.0, 1.0))
431
+ p = wp.dot(geom_pos, normal)
432
+ s = builder.add_shape_plane(
433
+ body=link,
434
+ plane=(*normal, p),
435
+ width=geom_size[0],
436
+ length=geom_size[1],
437
+ is_visible=visible,
438
+ has_ground_collision=False,
439
+ has_shape_collision=not just_visual,
440
+ **contact_vars,
441
+ )
442
+ shapes.append(s)
223
443
 
224
- for v in m.vertices:
225
- vertices.append(np.array(v) * scale)
444
+ else:
445
+ if verbose:
446
+ print(f"MJCF parsing shape {geom_name} issue: geom type {geom_type} is unsupported")
226
447
 
227
- for f in m.faces:
228
- faces.append(int(f[0]))
229
- faces.append(int(f[1]))
230
- faces.append(int(f[2]))
231
- return wp.sim.Mesh(vertices, faces), m.scale
448
+ return shapes
232
449
 
233
- def parse_body(body, parent, incoming_defaults: dict):
234
- body_class = body.get("childclass")
450
+ def parse_body(body, parent, incoming_defaults: dict, childclass: str = None):
451
+ body_class = body.get("class")
235
452
  if body_class is None:
453
+ body_class = childclass
236
454
  defaults = incoming_defaults
237
455
  else:
238
456
  for pattern in ignore_classes:
@@ -244,6 +462,7 @@ def parse_mjcf(
244
462
  else:
245
463
  body_attrib = body.attrib
246
464
  body_name = body_attrib["name"]
465
+ body_name = body_name.replace("-", "_") # ensure valid USD path
247
466
  body_pos = parse_vec(body_attrib, "pos", (0.0, 0.0, 0.0))
248
467
  body_ori = parse_orientation(body_attrib)
249
468
  if parent == -1:
@@ -263,11 +482,17 @@ def parse_mjcf(
263
482
  if len(freejoint_tags) > 0:
264
483
  joint_type = wp.sim.JOINT_FREE
265
484
  joint_name.append(freejoint_tags[0].attrib.get("name", f"{body_name}_freejoint"))
485
+ joint_armature.append(0.0)
266
486
  else:
267
487
  joints = body.findall("joint")
268
488
  for _i, joint in enumerate(joints):
269
- if "joint" in defaults:
270
- joint_attrib = merge_attrib(defaults["joint"], joint.attrib)
489
+ joint_defaults = defaults
490
+ if "class" in joint.attrib:
491
+ joint_class = joint.attrib["class"]
492
+ if joint_class in class_defaults:
493
+ joint_defaults = merge_attrib(joint_defaults, class_defaults[joint_class])
494
+ if "joint" in joint_defaults:
495
+ joint_attrib = merge_attrib(joint_defaults["joint"], joint.attrib)
271
496
  else:
272
497
  joint_attrib = joint.attrib
273
498
 
@@ -276,7 +501,7 @@ def parse_mjcf(
276
501
 
277
502
  joint_name.append(joint_attrib["name"])
278
503
  joint_pos.append(parse_vec(joint_attrib, "pos", (0.0, 0.0, 0.0)) * scale)
279
- joint_range = parse_vec(joint_attrib, "range", (-3.0, 3.0))
504
+ joint_range = parse_vec(joint_attrib, "range", (joint_limit_lower, joint_limit_upper))
280
505
  joint_armature.append(parse_float(joint_attrib, "armature", armature) * armature_scale)
281
506
 
282
507
  if joint_type_str == "free":
@@ -290,10 +515,12 @@ def parse_mjcf(
290
515
  if stiffness > 0.0 or "stiffness" in joint_attrib:
291
516
  mode = wp.sim.JOINT_MODE_TARGET_POSITION
292
517
  axis_vec = parse_vec(joint_attrib, "axis", (0.0, 0.0, 0.0))
293
- ax = wp.sim.model.JointAxis(
518
+ limit_lower = np.deg2rad(joint_range[0]) if is_angular and use_degrees else joint_range[0]
519
+ limit_upper = np.deg2rad(joint_range[1]) if is_angular and use_degrees else joint_range[1]
520
+ ax = wp.sim.JointAxis(
294
521
  axis=axis_vec,
295
- limit_lower=(np.deg2rad(joint_range[0]) if is_angular and use_degrees else joint_range[0]),
296
- limit_upper=(np.deg2rad(joint_range[1]) if is_angular and use_degrees else joint_range[1]),
522
+ limit_lower=limit_lower,
523
+ limit_upper=limit_upper,
297
524
  target_ke=parse_float(joint_attrib, "stiffness", stiffness),
298
525
  target_kd=parse_float(joint_attrib, "damping", damping),
299
526
  limit_ke=limit_ke,
@@ -326,23 +553,85 @@ def parse_mjcf(
326
553
  else:
327
554
  joint_type = wp.sim.JOINT_D6
328
555
 
329
- joint_pos = joint_pos[0] if len(joint_pos) > 0 else (0.0, 0.0, 0.0)
330
- builder.add_joint(
331
- joint_type,
332
- parent,
333
- link,
334
- linear_axes,
335
- angular_axes,
336
- name="_".join(joint_name),
337
- parent_xform=wp.transform(body_pos + joint_pos, body_ori),
338
- child_xform=wp.transform(joint_pos, wp.quat_identity()),
339
- armature=joint_armature[0] if len(joint_armature) > 0 else armature,
340
- )
556
+ if len(freejoint_tags) > 0 and parent == -1 and (base_joint is not None or floating is not None):
557
+ joint_pos = joint_pos[0] if len(joint_pos) > 0 else (0.0, 0.0, 0.0)
558
+ _xform = wp.transform(body_pos + joint_pos, body_ori)
559
+
560
+ if base_joint is not None:
561
+ # in case of a given base joint, the position is applied first, the rotation only
562
+ # after the base joint itself to not rotate its axis
563
+ base_parent_xform = wp.transform(_xform.p, wp.quat_identity())
564
+ base_child_xform = wp.transform((0.0, 0.0, 0.0), wp.quat_inverse(_xform.q))
565
+ if isinstance(base_joint, str):
566
+ axes = base_joint.lower().split(",")
567
+ axes = [ax.strip() for ax in axes]
568
+ linear_axes = [ax[-1] for ax in axes if ax[0] in {"l", "p"}]
569
+ angular_axes = [ax[-1] for ax in axes if ax[0] in {"a", "r"}]
570
+ axes = {
571
+ "x": [1.0, 0.0, 0.0],
572
+ "y": [0.0, 1.0, 0.0],
573
+ "z": [0.0, 0.0, 1.0],
574
+ }
575
+ builder.add_joint_d6(
576
+ linear_axes=[wp.sim.JointAxis(axes[a]) for a in linear_axes],
577
+ angular_axes=[wp.sim.JointAxis(axes[a]) for a in angular_axes],
578
+ parent_xform=base_parent_xform,
579
+ child_xform=base_child_xform,
580
+ parent=-1,
581
+ child=link,
582
+ name="base_joint",
583
+ )
584
+ elif isinstance(base_joint, dict):
585
+ base_joint["parent"] = -1
586
+ base_joint["child"] = root
587
+ base_joint["parent_xform"] = base_parent_xform
588
+ base_joint["child_xform"] = base_child_xform
589
+ base_joint["name"] = "base_joint"
590
+ builder.add_joint(**base_joint)
591
+ else:
592
+ raise ValueError(
593
+ "base_joint must be a comma-separated string of joint axes or a dict with joint parameters"
594
+ )
595
+ elif floating:
596
+ builder.add_joint_free(link, name="floating_base")
597
+
598
+ # set dofs to transform
599
+ start = builder.joint_q_start[link]
600
+
601
+ builder.joint_q[start + 0] = _xform.p[0]
602
+ builder.joint_q[start + 1] = _xform.p[1]
603
+ builder.joint_q[start + 2] = _xform.p[2]
604
+
605
+ builder.joint_q[start + 3] = _xform.q[0]
606
+ builder.joint_q[start + 4] = _xform.q[1]
607
+ builder.joint_q[start + 5] = _xform.q[2]
608
+ builder.joint_q[start + 6] = _xform.q[3]
609
+ else:
610
+ builder.add_joint_fixed(-1, link, parent_xform=_xform, name="fixed_base")
611
+
612
+ else:
613
+ joint_pos = joint_pos[0] if len(joint_pos) > 0 else (0.0, 0.0, 0.0)
614
+ if len(joint_name) == 0:
615
+ joint_name = [f"{body_name}_joint"]
616
+ builder.add_joint(
617
+ joint_type,
618
+ parent,
619
+ link,
620
+ linear_axes,
621
+ angular_axes,
622
+ name="_".join(joint_name),
623
+ parent_xform=wp.transform(body_pos + joint_pos, body_ori),
624
+ child_xform=wp.transform(joint_pos, wp.quat_identity()),
625
+ armature=joint_armature[0] if len(joint_armature) > 0 else armature,
626
+ )
341
627
 
342
628
  # -----------------
343
629
  # add shapes
344
630
 
345
- for geo_count, geom in enumerate(body.findall("geom")):
631
+ geoms = body.findall("geom")
632
+ visuals = []
633
+ colliders = []
634
+ for geo_count, geom in enumerate(geoms):
346
635
  geom_defaults = defaults
347
636
  if "class" in geom.attrib:
348
637
  geom_class = geom.attrib["class"]
@@ -361,125 +650,137 @@ def parse_mjcf(
361
650
  geom_attrib = geom.attrib
362
651
 
363
652
  geom_name = geom_attrib.get("name", f"{body_name}_geom_{geo_count}")
364
- geom_type = geom_attrib.get("type", "sphere")
365
- if "mesh" in geom_attrib:
366
- geom_type = "mesh"
367
653
 
368
- geom_size = parse_vec(geom_attrib, "size", [1.0, 1.0, 1.0]) * scale
369
- geom_pos = parse_vec(geom_attrib, "pos", (0.0, 0.0, 0.0)) * scale
370
- geom_rot = parse_orientation(geom_attrib)
371
- geom_density = parse_float(geom_attrib, "density", density)
372
-
373
- if geom_type == "sphere":
374
- builder.add_shape_sphere(
375
- link,
376
- pos=geom_pos,
377
- rot=geom_rot,
378
- radius=geom_size[0],
379
- density=geom_density,
380
- **contact_vars,
381
- )
382
-
383
- elif geom_type == "box":
384
- builder.add_shape_box(
385
- link,
386
- pos=geom_pos,
387
- rot=geom_rot,
388
- hx=geom_size[0],
389
- hy=geom_size[1],
390
- hz=geom_size[2],
391
- density=geom_density,
392
- **contact_vars,
393
- )
394
-
395
- elif geom_type == "mesh" and parse_meshes:
396
- mesh, _ = parse_mesh(geom_attrib)
397
- if "mesh" in defaults:
398
- mesh_scale = parse_vec(defaults["mesh"], "scale", [1.0, 1.0, 1.0])
399
- else:
400
- mesh_scale = [1.0, 1.0, 1.0]
401
- # as per the Mujoco XML reference, ignore geom size attribute
402
- assert len(geom_size) == 3, "need to specify size for mesh geom"
403
- builder.add_shape_mesh(
404
- body=link,
405
- pos=geom_pos,
406
- rot=geom_rot,
407
- mesh=mesh,
408
- scale=mesh_scale,
409
- density=density,
410
- **contact_vars,
411
- )
412
-
413
- elif geom_type in {"capsule", "cylinder"}:
414
- if "fromto" in geom_attrib:
415
- geom_fromto = parse_vec(geom_attrib, "fromto", (0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
416
-
417
- start = wp.vec3(geom_fromto[0:3]) * scale
418
- end = wp.vec3(geom_fromto[3:6]) * scale
419
-
420
- # compute rotation to align the Warp capsule (along x-axis), with mjcf fromto direction
421
- axis = wp.normalize(end - start)
422
- angle = math.acos(wp.dot(axis, wp.vec3(0.0, 1.0, 0.0)))
423
- axis = wp.normalize(wp.cross(axis, wp.vec3(0.0, 1.0, 0.0)))
424
-
425
- geom_pos = (start + end) * 0.5
426
- geom_rot = wp.quat_from_axis_angle(axis, -angle)
427
-
428
- geom_radius = geom_size[0]
429
- geom_height = wp.length(end - start) * 0.5
430
- geom_up_axis = 1
431
-
432
- else:
433
- geom_radius = geom_size[0]
434
- geom_height = geom_size[1]
435
- geom_up_axis = up_axis
436
-
437
- if geom_type == "cylinder":
438
- builder.add_shape_cylinder(
439
- link,
440
- pos=geom_pos,
441
- rot=geom_rot,
442
- radius=geom_radius,
443
- half_height=geom_height,
444
- density=density,
445
- up_axis=geom_up_axis,
446
- **contact_vars,
447
- )
654
+ if "class" in geom.attrib:
655
+ for pattern in visual_classes:
656
+ if re.match(pattern, geom_class):
657
+ visuals.append(geom)
658
+ break
659
+ for pattern in collider_classes:
660
+ if re.match(pattern, geom_class):
661
+ colliders.append(geom)
662
+ break
663
+ else:
664
+ no_class_class = "collision" if no_class_as_colliders else "visual"
665
+ if verbose:
666
+ print(f"MJCF parsing shape {geom_name} issue: no class defined for geom, assuming {no_class_class}")
667
+ if no_class_as_colliders:
668
+ colliders.append(geom)
448
669
  else:
449
- builder.add_shape_capsule(
450
- link,
451
- pos=geom_pos,
452
- rot=geom_rot,
453
- radius=geom_radius,
454
- half_height=geom_height,
455
- density=density,
456
- up_axis=geom_up_axis,
457
- **contact_vars,
458
- )
670
+ visuals.append(geom)
459
671
 
672
+ if parse_visuals_as_colliders:
673
+ colliders = visuals
674
+ else:
675
+ s = parse_shapes(
676
+ defaults, body_name, link, visuals, density=0.0, just_visual=True, visible=not hide_visuals
677
+ )
678
+ visual_shapes.extend(s)
679
+
680
+ show_colliders = force_show_colliders
681
+ if parse_visuals_as_colliders:
682
+ show_colliders = True
683
+ elif len(visuals) == 0:
684
+ # we need to show the collision shapes since there are no visual shapes
685
+ show_colliders = True
686
+
687
+ parse_shapes(defaults, body_name, link, colliders, density, visible=show_colliders)
688
+
689
+ m = builder.body_mass[link]
690
+ if not ignore_inertial_definitions and body.find("inertial") is not None:
691
+ inertial = body.find("inertial")
692
+ if "inertial" in defaults:
693
+ inertial_attrib = merge_attrib(defaults["inertial"], inertial.attrib)
460
694
  else:
461
- print(f"MJCF parsing shape {geom_name} issue: geom type {geom_type} is unsupported")
695
+ inertial_attrib = inertial.attrib
696
+ # overwrite inertial parameters if defined
697
+ inertial_pos = parse_vec(inertial_attrib, "pos", (0.0, 0.0, 0.0)) * scale
698
+ inertial_rot = parse_orientation(inertial_attrib)
699
+
700
+ inertial_frame = wp.transform(inertial_pos, inertial_rot)
701
+ com = inertial_frame.p
702
+ if inertial_attrib.get("diaginertia") is not None:
703
+ diaginertia = parse_vec(inertial_attrib, "diaginertia", None)
704
+ I_m = np.zeros((3, 3))
705
+ I_m[0, 0] = diaginertia[0] * scale**2
706
+ I_m[1, 1] = diaginertia[1] * scale**2
707
+ I_m[2, 2] = diaginertia[2] * scale**2
708
+ else:
709
+ fullinertia = inertial_attrib.get("fullinertia")
710
+ assert fullinertia is not None
711
+ fullinertia = np.fromstring(fullinertia, sep=" ", dtype=np.float32)
712
+ I_m = np.zeros((3, 3))
713
+ I_m[0, 0] = fullinertia[0] * scale**2
714
+ I_m[1, 1] = fullinertia[1] * scale**2
715
+ I_m[2, 2] = fullinertia[2] * scale**2
716
+ I_m[0, 1] = fullinertia[3] * scale**2
717
+ I_m[0, 2] = fullinertia[4] * scale**2
718
+ I_m[1, 2] = fullinertia[5] * scale**2
719
+ I_m[1, 0] = I_m[0, 1]
720
+ I_m[2, 0] = I_m[0, 2]
721
+ I_m[2, 1] = I_m[1, 2]
722
+ rot = wp.quat_to_matrix(inertial_frame.q)
723
+ I_m = rot @ wp.mat33(I_m)
724
+ m = float(inertial_attrib.get("mass", "0"))
725
+ builder.body_mass[link] = m
726
+ builder.body_inv_mass[link] = 1.0 / m if m > 0.0 else 0.0
727
+ builder.body_com[link] = com
728
+ builder.body_inertia[link] = I_m
729
+ if any(x for x in I_m):
730
+ builder.body_inv_inertia[link] = wp.inverse(I_m)
731
+ else:
732
+ builder.body_inv_inertia[link] = I_m
733
+ if m == 0.0 and ensure_nonstatic_links:
734
+ # set the mass to something nonzero to ensure the body is dynamic
735
+ m = static_link_mass
736
+ # cube with side length 0.5
737
+ I_m = wp.mat33(np.eye(3)) * m / 12.0 * (0.5 * scale) ** 2 * 2.0
738
+ I_m += wp.mat33(armature * np.eye(3))
739
+ builder.body_mass[link] = m
740
+ builder.body_inv_mass[link] = 1.0 / m
741
+ builder.body_inertia[link] = I_m
742
+ builder.body_inv_inertia[link] = wp.inverse(I_m)
462
743
 
463
744
  # -----------------
464
745
  # recurse
465
746
 
466
747
  for child in body.findall("body"):
467
- parse_body(child, link, defaults)
748
+ _childclass = body.get("childclass")
749
+ if _childclass is None:
750
+ _childclass = childclass
751
+ _incoming_defaults = defaults
752
+ else:
753
+ _incoming_defaults = merge_attrib(defaults, class_defaults[_childclass])
754
+ parse_body(child, link, _incoming_defaults, childclass=_childclass)
468
755
 
469
756
  # -----------------
470
757
  # start articulation
471
758
 
759
+ visual_shapes = []
472
760
  start_shape_count = len(builder.shape_geo_type)
473
761
  builder.add_articulation()
474
762
 
475
763
  world = root.find("worldbody")
476
764
  world_class = get_class(world)
477
765
  world_defaults = merge_attrib(class_defaults["__all__"], class_defaults.get(world_class, {}))
766
+
767
+ # -----------------
768
+ # add bodies
769
+
478
770
  for body in world.findall("body"):
479
771
  parse_body(body, -1, world_defaults)
480
772
 
773
+ # -----------------
774
+ # add static geoms
775
+
776
+ parse_shapes(world_defaults, "world", -1, world.findall("geom"), density, incoming_xform=xform)
777
+
481
778
  end_shape_count = len(builder.shape_geo_type)
482
779
 
780
+ for i in range(start_shape_count, end_shape_count):
781
+ for j in visual_shapes:
782
+ builder.shape_collision_filter_pairs.add((i, j))
783
+
483
784
  if not enable_self_collisions:
484
785
  for i in range(start_shape_count, end_shape_count):
485
786
  for j in range(i + 1, end_shape_count):