warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/sim/collide.py CHANGED
@@ -394,7 +394,8 @@ def mesh_sdf(mesh: wp.uint64, point: wp.vec3, max_dist: float):
394
394
  face_u = float(0.0)
395
395
  face_v = float(0.0)
396
396
  sign = float(0.0)
397
- res = wp.mesh_query_point(mesh, point, max_dist, sign, face_index, face_u, face_v)
397
+ res = wp.mesh_query_point_sign_normal(mesh, point, max_dist, sign, face_index, face_u, face_v)
398
+
398
399
  if res:
399
400
  closest = wp.mesh_eval_position(mesh, face_index, face_u, face_v)
400
401
  return wp.length(point - closest) * sign
@@ -407,7 +408,8 @@ def closest_point_mesh(mesh: wp.uint64, point: wp.vec3, max_dist: float):
407
408
  face_u = float(0.0)
408
409
  face_v = float(0.0)
409
410
  sign = float(0.0)
410
- res = wp.mesh_query_point(mesh, point, max_dist, sign, face_index, face_u, face_v)
411
+ res = wp.mesh_query_point_sign_normal(mesh, point, max_dist, sign, face_index, face_u, face_v)
412
+
411
413
  if res:
412
414
  return wp.mesh_eval_position(mesh, face_index, face_u, face_v)
413
415
  # return arbitrary point from mesh
@@ -549,7 +551,9 @@ def create_soft_contacts(
549
551
  face_v = float(0.0)
550
552
  sign = float(0.0)
551
553
 
552
- if wp.mesh_query_point(mesh, wp.cw_div(x_local, geo_scale), margin, sign, face_index, face_u, face_v):
554
+ if wp.mesh_query_point_sign_normal(
555
+ mesh, wp.cw_div(x_local, geo_scale), margin + radius, sign, face_index, face_u, face_v
556
+ ):
553
557
  shape_p = wp.mesh_eval_position(mesh, face_index, face_u, face_v)
554
558
  shape_v = wp.mesh_eval_velocity(mesh, face_index, face_u, face_v)
555
559
 
@@ -557,9 +561,17 @@ def create_soft_contacts(
557
561
  shape_v = wp.cw_mul(shape_v, geo_scale)
558
562
 
559
563
  delta = x_local - shape_p
564
+
560
565
  d = wp.length(delta) * sign
561
566
  n = wp.normalize(delta) * sign
562
567
  v = shape_v
568
+
569
+ if geo_type == wp.sim.GEO_SDF:
570
+ volume = geo.source[shape_index]
571
+ xpred_local = wp.volume_world_to_index(volume, wp.cw_div(x_local, geo_scale))
572
+ nn = wp.vec3(0.0, 0.0, 0.0)
573
+ d = wp.volume_sample_grad_f(volume, xpred_local, wp.Volume.LINEAR, nn)
574
+ n = wp.normalize(nn)
563
575
 
564
576
  if geo_type == wp.sim.GEO_PLANE:
565
577
  d = plane_sdf(geo_scale[0], geo_scale[1], x_local)
@@ -941,8 +953,8 @@ def handle_contact_pairs(
941
953
  face_u = float(0.0)
942
954
  face_v = float(0.0)
943
955
  sign = float(0.0)
944
- max_dist = (thickness_a + thickness_b + rigid_contact_margin) / min_scale_b
945
- res = wp.mesh_query_point(
956
+ max_dist = (thickness_a + thickness_b + rigid_contact_margin) / geo_scale_b[0]
957
+ res = wp.mesh_query_point_sign_normal(
946
958
  mesh_b, wp.cw_div(query_b_local, geo_scale_b), max_dist, sign, face_index, face_u, face_v
947
959
  )
948
960
  if res:
@@ -1112,9 +1124,10 @@ def handle_contact_pairs(
1112
1124
  face_u = float(0.0)
1113
1125
  face_v = float(0.0)
1114
1126
  sign = float(0.0)
1115
- res = wp.mesh_query_point(
1127
+ res = wp.mesh_query_point_sign_normal(
1116
1128
  mesh_b, wp.cw_div(query_b_local, geo_scale_b), max_dist, sign, face_index, face_u, face_v
1117
1129
  )
1130
+
1118
1131
  if res:
1119
1132
  shape_p = wp.mesh_eval_position(mesh_b, face_index, face_u, face_v)
1120
1133
  shape_p = wp.cw_mul(shape_p, geo_scale_b)
@@ -1211,7 +1224,7 @@ def handle_contact_pairs(
1211
1224
  face_u = float(0.0)
1212
1225
  face_v = float(0.0)
1213
1226
  sign = float(0.0)
1214
- res = wp.mesh_query_point(
1227
+ res = wp.mesh_query_point_sign_normal(
1215
1228
  mesh_b, wp.cw_div(query_b_local, geo_scale_b), max_dist, sign, face_index, face_u, face_v
1216
1229
  )
1217
1230
 
@@ -1244,7 +1257,7 @@ def handle_contact_pairs(
1244
1257
  min_scale = min(min_scale_a, min_scale_b)
1245
1258
  max_dist = (rigid_contact_margin + thickness_a + thickness_b) / min_scale
1246
1259
 
1247
- res = wp.mesh_query_point(
1260
+ res = wp.mesh_query_point_sign_normal(
1248
1261
  mesh_b, wp.cw_div(query_b_local, geo_scale_b), max_dist, sign, face_index, face_u, face_v
1249
1262
  )
1250
1263
 
warp/sim/import_mjcf.py CHANGED
@@ -8,18 +8,17 @@
8
8
 
9
9
  import math
10
10
  import os
11
+ import re
11
12
  import xml.etree.ElementTree as ET
12
13
 
13
14
  import numpy as np
14
-
15
15
  import warp as wp
16
- from warp.sim.model import JOINT_COMPOUND, JOINT_UNIVERSAL
17
- from warp.sim.model import Mesh
18
16
 
19
17
 
20
18
  def parse_mjcf(
21
- filename,
19
+ mjcf_filename,
22
20
  builder,
21
+ xform=wp.transform(),
23
22
  density=1000.0,
24
23
  stiffness=0.0,
25
24
  damping=0.0,
@@ -30,69 +29,206 @@ def parse_mjcf(
30
29
  contact_restitution=0.5,
31
30
  limit_ke=100.0,
32
31
  limit_kd=10.0,
32
+ scale=1.0,
33
33
  armature=0.0,
34
34
  armature_scale=1.0,
35
- parse_meshes=False,
36
- enable_self_collisions=True,
35
+ parse_meshes=True,
36
+ enable_self_collisions=False,
37
+ up_axis="Z",
38
+ ignore_classes=[],
39
+ collapse_fixed_joints=False,
37
40
  ):
38
- file = ET.parse(filename)
41
+ """
42
+ Parses MuJoCo XML (MJCF) file and adds the bodies and joints to the given ModelBuilder.
43
+
44
+ Args:
45
+ mjcf_filename (str): The filename of the MuJoCo file to parse.
46
+ builder (ModelBuilder): The :class:`ModelBuilder` to add the bodies and joints to.
47
+ xform (:ref:`transform <transform>`): The transform to apply to the imported mechanism.
48
+ density (float): The density of the shapes in kg/m^3 which will be used to calculate the body mass and inertia.
49
+ stiffness (float): The stiffness of the joints.
50
+ damping (float): The damping of the joints.
51
+ contact_ke (float): The stiffness of the shape contacts (used by SemiImplicitIntegrator).
52
+ contact_kd (float): The damping of the shape contacts (used by SemiImplicitIntegrator).
53
+ contact_kf (float): The friction stiffness of the shape contacts (used by SemiImplicitIntegrator).
54
+ contact_mu (float): The friction coefficient of the shape contacts.
55
+ contact_restitution (float): The restitution coefficient of the shape contacts.
56
+ limit_ke (float): The stiffness of the joint limits (used by SemiImplicitIntegrator).
57
+ limit_kd (float): The damping of the joint limits (used by SemiImplicitIntegrator).
58
+ scale (float): The scaling factor to apply to the imported mechanism.
59
+ armature (float): Default joint armature to use if `armature` has not been defined for a joint in the MJCF.
60
+ armature_scale (float): Scaling factor to apply to the MJCF-defined joint armature values.
61
+ parse_meshes (bool): Whether geometries of type `"mesh"` should be parsed. If False, geometries of type `"mesh"` are ignored.
62
+ enable_self_collisions (bool): If True, self-collisions are enabled.
63
+ up_axis (str): The up axis of the mechanism. Can be either `"X"`, `"Y"` or `"Z"`. The default is `"Z"`.
64
+ ignore_classes (List[str]): A list of regular expressions. Bodies and joints with a class matching one of the regular expressions will be ignored.
65
+ collapse_fixed_joints (bool): If True, fixed joints are removed and the respective bodies are merged.
66
+
67
+ Note:
68
+ The inertia and masses of the bodies are calculated from the shape geometry and the given density. The values defined in the MJCF are not respected at the moment.
69
+
70
+ The handling of advanced features, such as MJCF classes, is still experimental.
71
+ """
72
+ mjcf_dirname = os.path.dirname(mjcf_filename)
73
+ file = ET.parse(mjcf_filename)
39
74
  root = file.getroot()
40
75
 
41
- type_map = {
42
- "ball": wp.sim.JOINT_BALL,
43
- "hinge": wp.sim.JOINT_REVOLUTE,
44
- "slide": wp.sim.JOINT_PRISMATIC,
45
- "free": wp.sim.JOINT_FREE,
46
- "fixed": wp.sim.JOINT_FIXED,
47
- }
48
-
49
- def parse_float(node, key, default):
50
- if key in node.attrib:
51
- return float(node.attrib[key])
76
+ use_degrees = True # angles are in degrees by default
77
+ euler_seq = [1, 2, 3] # XYZ by default
78
+
79
+ compiler = root.find("compiler")
80
+ if compiler is not None:
81
+ use_degrees = compiler.attrib.get("angle", "degree").lower() == "degree"
82
+ euler_seq = ["xyz".index(c) + 1 for c in compiler.attrib.get("eulerseq", "xyz").lower()]
83
+ mesh_dir = compiler.attrib.get("meshdir", ".")
84
+
85
+ mesh_assets = {}
86
+ for asset in root.findall("asset"):
87
+ for mesh in asset.findall("mesh"):
88
+ if "file" in mesh.attrib:
89
+ fname = os.path.join(mesh_dir, mesh.attrib["file"])
90
+ # handle stl relative paths
91
+ if not os.path.isabs(fname):
92
+ fname = os.path.abspath(os.path.join(mjcf_dirname, fname))
93
+ if "name" in mesh.attrib:
94
+ mesh_assets[mesh.attrib["name"]] = fname
95
+ else:
96
+ name = ".".join(os.path.basename(fname).split(".")[:-1])
97
+ mesh_assets[name] = fname
98
+
99
+ class_parent = {}
100
+ class_children = {}
101
+ class_defaults = {"__all__": {}}
102
+
103
+ def get_class(element):
104
+ return element.get("class", "__all__")
105
+
106
+ def parse_default(node, parent):
107
+ nonlocal class_parent
108
+ nonlocal class_children
109
+ nonlocal class_defaults
110
+ class_name = "__all__"
111
+ if "class" in node.attrib:
112
+ class_name = node.attrib["class"]
113
+ class_parent[class_name] = parent
114
+ parent = parent or "__all__"
115
+ if parent not in class_children:
116
+ class_children[parent] = []
117
+ class_children[parent].append(class_name)
118
+
119
+ if class_name not in class_defaults:
120
+ class_defaults[class_name] = {}
121
+ for child in node:
122
+ if child.tag == "default":
123
+ parse_default(child, node.get("class"))
124
+ else:
125
+ class_defaults[class_name][child.tag] = child.attrib
126
+
127
+ for default in root.findall("default"):
128
+ parse_default(default, None)
129
+
130
+ def merge_attrib(default_attrib: dict, incoming_attrib: dict):
131
+ attrib = default_attrib.copy()
132
+ attrib.update(incoming_attrib)
133
+ return attrib
134
+
135
+ if isinstance(up_axis, str):
136
+ up_axis = "XYZ".index(up_axis.upper())
137
+ sqh = np.sqrt(0.5)
138
+ if up_axis == 0:
139
+ xform = wp.transform(xform.p, wp.quat(0.0, 0.0, -sqh, sqh) * xform.q)
140
+ elif up_axis == 2:
141
+ xform = wp.transform(xform.p, wp.quat(sqh, 0.0, 0.0, -sqh) * xform.q)
142
+ # do not apply scaling to the root transform
143
+ xform = wp.transform(np.array(xform.p) / scale, xform.q)
144
+
145
+ def parse_float(attrib, key, default):
146
+ if key in attrib:
147
+ return float(attrib[key])
52
148
  else:
53
149
  return default
54
150
 
55
- def parse_vec(node, key, default):
56
- if key in node.attrib:
57
- return np.fromstring(node.attrib[key], sep=" ")
151
+ def parse_vec(attrib, key, default):
152
+ if key in attrib:
153
+ out = np.fromstring(attrib[key], sep=" ", dtype=np.float32)
58
154
  else:
59
- return np.array(default)
155
+ out = np.array(default, dtype=np.float32)
156
+
157
+ length = len(out)
158
+ if length == 1:
159
+ return wp.vec(len(default), wp.float32)(out[0], out[0], out[0])
160
+
161
+ return wp.vec(length, wp.float32)(out)
162
+
163
+ def parse_orientation(attrib):
164
+ if "quat" in attrib:
165
+ wxyz = np.fromstring(attrib["quat"], sep=" ")
166
+ return wp.normalize(wp.quat(*wxyz[1:], wxyz[0]))
167
+ if "euler" in attrib:
168
+ euler = np.fromstring(attrib["euler"], sep=" ")
169
+ if use_degrees:
170
+ euler *= np.pi / 180
171
+ return wp.quat_from_euler(euler, *euler_seq)
172
+ if "axisangle" in attrib:
173
+ axisangle = np.fromstring(attrib["axisangle"], sep=" ")
174
+ angle = axisangle[3]
175
+ if use_degrees:
176
+ angle *= np.pi / 180
177
+ axis = wp.normalize(wp.vec3(*axisangle[:3]))
178
+ return wp.quat_from_axis_angle(axis, angle)
179
+ if "xyaxes" in attrib:
180
+ xyaxes = np.fromstring(attrib["xyaxes"], sep=" ")
181
+ xaxis = wp.normalize(wp.vec3(*xyaxes[:3]))
182
+ zaxis = wp.normalize(wp.vec3(*xyaxes[3:]))
183
+ yaxis = wp.normalize(wp.cross(zaxis, xaxis))
184
+ rot_matrix = np.array([xaxis, yaxis, zaxis]).T
185
+ return wp.quat_from_matrix(rot_matrix)
186
+ if "zaxis" in attrib:
187
+ zaxis = np.fromstring(attrib["zaxis"], sep=" ")
188
+ zaxis = wp.normalize(wp.vec3(*zaxis))
189
+ xaxis = wp.normalize(wp.cross(wp.vec3(0, 0, 1), zaxis))
190
+ yaxis = wp.normalize(wp.cross(zaxis, xaxis))
191
+ rot_matrix = np.array([xaxis, yaxis, zaxis]).T
192
+ return wp.quat_from_matrix(rot_matrix)
193
+ return wp.quat_identity()
60
194
 
61
195
  def parse_mesh(geom):
62
196
  import trimesh
63
197
 
64
198
  faces = []
65
199
  vertices = []
66
- stl_file = next(
67
- filter(
68
- lambda m: m.attrib["name"] == geom.attrib["mesh"],
69
- root.find("asset").findall("mesh"),
70
- )
71
- ).attrib["file"]
72
- # handle stl relative paths
73
- if not os.path.isabs(stl_file):
74
- stl_file = os.path.join(os.path.dirname(filename), stl_file)
200
+ stl_file = mesh_assets[geom["mesh"]]
75
201
  m = trimesh.load(stl_file)
76
202
 
77
203
  for v in m.vertices:
78
- vertices.append(np.array(v))
204
+ vertices.append(np.array(v) * scale)
79
205
 
80
206
  for f in m.faces:
81
207
  faces.append(int(f[0]))
82
208
  faces.append(int(f[1]))
83
209
  faces.append(int(f[2]))
84
- return Mesh(vertices, faces), m.scale
85
-
86
- def parse_body(body, parent):
87
- body_name = body.attrib["name"]
88
- body_pos = parse_vec(body, "pos", (0.0, 0.0, 0.0))
89
- body_ori_euler = parse_vec(body, "euler", (0.0, 0.0, 0.0))
90
- if len(np.nonzero(body_ori_euler)[0]) > 0:
91
- body_axis = tuple(np.sign(body_ori_euler))
92
- body_angle = body_ori_euler[np.nonzero(body_ori_euler)[0].item()] / 180 * np.pi
93
- body_ori = wp.utils.quat_from_axis_angle(body_axis, body_angle)
210
+ return wp.sim.Mesh(vertices, faces), m.scale
211
+
212
+ def parse_body(body, parent, incoming_defaults: dict):
213
+ body_class = body.get("childclass")
214
+ if body_class is None:
215
+ defaults = incoming_defaults
94
216
  else:
95
- body_ori = wp.quat_identity()
217
+ for pattern in ignore_classes:
218
+ if re.match(pattern, body_class):
219
+ return
220
+ defaults = merge_attrib(incoming_defaults, class_defaults[body_class])
221
+ if "body" in defaults:
222
+ body_attrib = merge_attrib(defaults["body"], body.attrib)
223
+ else:
224
+ body_attrib = body.attrib
225
+ body_name = body_attrib["name"]
226
+ body_pos = parse_vec(body_attrib, "pos", (0.0, 0.0, 0.0))
227
+ body_ori = parse_orientation(body_attrib)
228
+ if parent == -1:
229
+ body_pos = wp.transform_point(xform, body_pos)
230
+ body_ori = xform.q * body_ori
231
+ body_pos *= scale
96
232
 
97
233
  joint_armature = []
98
234
  joint_name = []
@@ -102,43 +238,55 @@ def parse_mjcf(
102
238
  angular_axes = []
103
239
  joint_type = None
104
240
 
105
- joints = body.findall("joint")
106
- for i, joint in enumerate(joints):
107
- # default to hinge if not specified
108
- if "type" not in joint.attrib:
109
- joint.attrib["type"] = "hinge"
110
-
111
- joint_name.append(joint.attrib["name"])
112
- joint_pos.append(parse_vec(joint, "pos", (0.0, 0.0, 0.0)))
113
- # TODO parse joint (child transform) rotation?
114
- joint_range = parse_vec(joint, "range", (-3.0, 3.0))
115
- joint_armature.append(parse_float(joint, "armature", armature) * armature_scale)
116
-
117
- if joint.attrib["type"].lower() == "free":
118
- joint_type = wp.sim.JOINT_FREE
119
- break
120
- is_angular = joint.attrib["type"].lower() == "hinge"
121
- mode = wp.sim.JOINT_MODE_LIMIT
122
- if stiffness > 0.0 or "stiffness" in joint.attrib:
123
- mode = wp.sim.JOINT_MODE_TARGET_POSITION
124
- ax = wp.sim.model.JointAxis(
125
- axis=parse_vec(joint, "axis", (0.0, 0.0, 0.0)),
126
- limit_lower=(np.deg2rad(joint_range[0]) if is_angular else joint_range[0]),
127
- limit_upper=(np.deg2rad(joint_range[1]) if is_angular else joint_range[1]),
128
- target_ke=parse_float(joint, "stiffness", stiffness),
129
- target_kd=parse_float(joint, "damping", damping),
130
- limit_ke=limit_ke,
131
- limit_kd=limit_kd,
132
- mode=mode,
133
- )
134
- if is_angular:
135
- angular_axes.append(ax)
136
- else:
137
- linear_axes.append(ax)
241
+ freejoint_tags = body.findall("freejoint")
242
+ if len(freejoint_tags) > 0:
243
+ joint_type = wp.sim.JOINT_FREE
244
+ joint_name.append(freejoint_tags[0].attrib.get("name", f"{body_name}_freejoint"))
245
+ else:
246
+ joints = body.findall("joint")
247
+ for i, joint in enumerate(joints):
248
+ if "joint" in defaults:
249
+ joint_attrib = merge_attrib(defaults["joint"], joint.attrib)
250
+ else:
251
+ joint_attrib = joint.attrib
252
+
253
+ # default to hinge if not specified
254
+ joint_type_str = joint_attrib.get("type", "hinge")
255
+
256
+ joint_name.append(joint_attrib["name"])
257
+ joint_pos.append(parse_vec(joint_attrib, "pos", (0.0, 0.0, 0.0)) * scale)
258
+ joint_range = parse_vec(joint_attrib, "range", (-3.0, 3.0))
259
+ joint_armature.append(parse_float(joint_attrib, "armature", armature) * armature_scale)
260
+
261
+ if joint_type_str == "free":
262
+ joint_type = wp.sim.JOINT_FREE
263
+ break
264
+ if joint_type_str == "fixed":
265
+ joint_type = wp.sim.JOINT_FIXED
266
+ break
267
+ is_angular = joint_type_str == "hinge"
268
+ mode = wp.sim.JOINT_MODE_LIMIT
269
+ if stiffness > 0.0 or "stiffness" in joint_attrib:
270
+ mode = wp.sim.JOINT_MODE_TARGET_POSITION
271
+ axis_vec = parse_vec(joint_attrib, "axis", (0.0, 0.0, 0.0))
272
+ ax = wp.sim.model.JointAxis(
273
+ axis=axis_vec,
274
+ limit_lower=(np.deg2rad(joint_range[0]) if is_angular and use_degrees else joint_range[0]),
275
+ limit_upper=(np.deg2rad(joint_range[1]) if is_angular and use_degrees else joint_range[1]),
276
+ target_ke=parse_float(joint_attrib, "stiffness", stiffness),
277
+ target_kd=parse_float(joint_attrib, "damping", damping),
278
+ limit_ke=limit_ke,
279
+ limit_kd=limit_kd,
280
+ mode=mode,
281
+ )
282
+ if is_angular:
283
+ angular_axes.append(ax)
284
+ else:
285
+ linear_axes.append(ax)
138
286
 
139
287
  link = builder.add_body(
140
- origin=wp.transform_identity(), # will be evaluated in fk()
141
- armature=joint_armature[0],
288
+ origin=wp.transform(body_pos, body_ori), # will be evaluated in fk()
289
+ armature=joint_armature[0] if len(joint_armature) > 0 else armature,
142
290
  name=body_name,
143
291
  )
144
292
 
@@ -157,6 +305,7 @@ def parse_mjcf(
157
305
  else:
158
306
  joint_type = wp.sim.JOINT_D6
159
307
 
308
+ joint_pos = joint_pos[0] if len(joint_pos) > 0 else (0.0, 0.0, 0.0)
160
309
  builder.add_joint(
161
310
  joint_type,
162
311
  parent,
@@ -164,21 +313,40 @@ def parse_mjcf(
164
313
  linear_axes,
165
314
  angular_axes,
166
315
  name="_".join(joint_name),
167
- parent_xform=wp.transform(body_pos, body_ori),
168
- # child_xform=wp.transform(joint_pos[0], wp.quat_identity()),
316
+ parent_xform=wp.transform(body_pos + joint_pos, body_ori),
317
+ child_xform=wp.transform(joint_pos, wp.quat_identity()),
169
318
  )
170
319
 
171
320
  # -----------------
172
321
  # add shapes
173
322
 
174
- for geom in body.findall("geom"):
175
- geom_name = geom.attrib["name"]
176
- geom_type = geom.attrib["type"]
323
+ for geo_count, geom in enumerate(body.findall("geom")):
324
+ geom_defaults = defaults
325
+ if "class" in geom.attrib:
326
+ geom_class = geom.attrib["class"]
327
+ ignore_geom = False
328
+ for pattern in ignore_classes:
329
+ if re.match(pattern, geom_class):
330
+ ignore_geom = True
331
+ break
332
+ if ignore_geom:
333
+ continue
334
+ if geom_class in class_defaults:
335
+ geom_defaults = merge_attrib(defaults, class_defaults[geom_class])
336
+ if "geom" in geom_defaults:
337
+ geom_attrib = merge_attrib(geom_defaults["geom"], geom.attrib)
338
+ else:
339
+ geom_attrib = geom.attrib
340
+
341
+ geom_name = geom_attrib.get("name", f"{body_name}_geom_{geo_count}")
342
+ geom_type = geom_attrib.get("type", "sphere")
343
+ if "mesh" in geom_attrib:
344
+ geom_type = "mesh"
177
345
 
178
- geom_size = parse_vec(geom, "size", [1.0])
179
- geom_pos = parse_vec(geom, "pos", (0.0, 0.0, 0.0))
180
- geom_rot = parse_vec(geom, "quat", (0.0, 0.0, 0.0, 1.0))
181
- geom_density = parse_float(geom, "density", density)
346
+ geom_size = parse_vec(geom_attrib, "size", [1.0, 1.0, 1.0]) * scale
347
+ geom_pos = parse_vec(geom_attrib, "pos", (0.0, 0.0, 0.0)) * scale
348
+ geom_rot = parse_orientation(geom_attrib)
349
+ geom_density = parse_float(geom_attrib, "density", density)
182
350
 
183
351
  if geom_type == "sphere":
184
352
  builder.add_shape_sphere(
@@ -194,16 +362,36 @@ def parse_mjcf(
194
362
  restitution=contact_restitution,
195
363
  )
196
364
 
365
+ elif geom_type == "box":
366
+ builder.add_shape_box(
367
+ link,
368
+ pos=geom_pos,
369
+ rot=geom_rot,
370
+ hx=geom_size[0],
371
+ hy=geom_size[1],
372
+ hz=geom_size[2],
373
+ density=geom_density,
374
+ ke=contact_ke,
375
+ kd=contact_kd,
376
+ kf=contact_kf,
377
+ mu=contact_mu,
378
+ restitution=contact_restitution,
379
+ )
380
+
197
381
  elif geom_type == "mesh" and parse_meshes:
198
- mesh, scale = parse_mesh(geom)
199
- geom_size = tuple([scale * s for s in geom_size])
382
+ mesh, _ = parse_mesh(geom_attrib)
383
+ if "mesh" in defaults:
384
+ mesh_scale = parse_vec(defaults["mesh"], "scale", [1.0, 1.0, 1.0])
385
+ else:
386
+ mesh_scale = [1.0, 1.0, 1.0]
387
+ # as per the Mujoco XML reference, ignore geom size attribute
200
388
  assert len(geom_size) == 3, "need to specify size for mesh geom"
201
389
  builder.add_shape_mesh(
202
390
  body=link,
203
391
  pos=geom_pos,
204
392
  rot=geom_rot,
205
393
  mesh=mesh,
206
- scale=geom_size,
394
+ scale=mesh_scale,
207
395
  density=density,
208
396
  ke=contact_ke,
209
397
  kd=contact_kd,
@@ -212,32 +400,28 @@ def parse_mjcf(
212
400
  )
213
401
 
214
402
  elif geom_type in {"capsule", "cylinder"}:
215
- if "fromto" in geom.attrib:
216
- geom_fromto = parse_vec(geom, "fromto", (0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
403
+ if "fromto" in geom_attrib:
404
+ geom_fromto = parse_vec(geom_attrib, "fromto", (0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
217
405
 
218
- start = geom_fromto[0:3]
219
- end = geom_fromto[3:6]
406
+ start = wp.vec3(geom_fromto[0:3]) * scale
407
+ end = wp.vec3(geom_fromto[3:6]) * scale
220
408
 
221
409
  # compute rotation to align the Warp capsule (along x-axis), with mjcf fromto direction
222
410
  axis = wp.normalize(end - start)
223
- angle = math.acos(np.dot(axis, (0.0, 1.0, 0.0)))
224
- axis = wp.normalize(np.cross(axis, (0.0, 1.0, 0.0)))
411
+ angle = math.acos(wp.dot(axis, wp.vec3(0.0, 1.0, 0.0)))
412
+ axis = wp.normalize(wp.cross(axis, wp.vec3(0.0, 1.0, 0.0)))
225
413
 
226
414
  geom_pos = (start + end) * 0.5
227
415
  geom_rot = wp.quat_from_axis_angle(axis, -angle)
228
416
 
229
417
  geom_radius = geom_size[0]
230
- geom_height = np.linalg.norm(end - start) * 0.5
418
+ geom_height = wp.length(end - start) * 0.5
419
+ geom_up_axis = 1
231
420
 
232
421
  else:
233
422
  geom_radius = geom_size[0]
234
423
  geom_height = geom_size[1]
235
- geom_pos = parse_vec(geom, "pos", (0.0, 0.0, 0.0))
236
- # orientation along the z axis by default
237
- axis = np.array((0.0, 0.0, 1.0))
238
- angle = math.acos(np.dot(axis, (0.0, 1.0, 0.0)))
239
- axis = wp.normalize(np.cross(axis, (0.0, 1.0, 0.0)))
240
- geom_rot = wp.quat_from_axis_angle(axis, -angle)
424
+ geom_up_axis = up_axis
241
425
 
242
426
  if geom_type == "cylinder":
243
427
  builder.add_shape_cylinder(
@@ -252,6 +436,7 @@ def parse_mjcf(
252
436
  kf=contact_kf,
253
437
  mu=contact_mu,
254
438
  restitution=contact_restitution,
439
+ up_axis=geom_up_axis,
255
440
  )
256
441
  else:
257
442
  builder.add_shape_capsule(
@@ -266,16 +451,17 @@ def parse_mjcf(
266
451
  kf=contact_kf,
267
452
  mu=contact_mu,
268
453
  restitution=contact_restitution,
454
+ up_axis=geom_up_axis,
269
455
  )
270
456
 
271
457
  else:
272
- print("MJCF parsing issue: geom type", geom_type, "is unsupported")
458
+ print(f"MJCF parsing shape {geom_name} issue: geom type {geom_type} is unsupported")
273
459
 
274
460
  # -----------------
275
461
  # recurse
276
462
 
277
463
  for child in body.findall("body"):
278
- parse_body(child, link)
464
+ parse_body(child, link, defaults)
279
465
 
280
466
  # -----------------
281
467
  # start articulation
@@ -284,8 +470,10 @@ def parse_mjcf(
284
470
  builder.add_articulation()
285
471
 
286
472
  world = root.find("worldbody")
473
+ world_class = get_class(world)
474
+ world_defaults = merge_attrib(class_defaults["__all__"], class_defaults.get(world_class, {}))
287
475
  for body in world.findall("body"):
288
- parse_body(body, -1)
476
+ parse_body(body, -1, world_defaults)
289
477
 
290
478
  end_shape_count = len(builder.shape_geo_type)
291
479
 
@@ -293,3 +481,6 @@ def parse_mjcf(
293
481
  for i in range(start_shape_count, end_shape_count):
294
482
  for j in range(i + 1, end_shape_count):
295
483
  builder.shape_collision_filter_pairs.add((i, j))
484
+
485
+ if collapse_fixed_joints:
486
+ builder.collapse_fixed_joints()