warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300):
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/sim/collide.py CHANGED
@@ -561,9 +561,17 @@ def create_soft_contacts(
561
561
  shape_v = wp.cw_mul(shape_v, geo_scale)
562
562
 
563
563
  delta = x_local - shape_p
564
+
564
565
  d = wp.length(delta) * sign
565
566
  n = wp.normalize(delta) * sign
566
567
  v = shape_v
568
+
569
+ if geo_type == wp.sim.GEO_SDF:
570
+ volume = geo.source[shape_index]
571
+ xpred_local = wp.volume_world_to_index(volume, wp.cw_div(x_local, geo_scale))
572
+ nn = wp.vec3(0.0, 0.0, 0.0)
573
+ d = wp.volume_sample_grad_f(volume, xpred_local, wp.Volume.LINEAR, nn)
574
+ n = wp.normalize(nn)
567
575
 
568
576
  if geo_type == wp.sim.GEO_PLANE:
569
577
  d = plane_sdf(geo_scale[0], geo_scale[1], x_local)
warp/sim/import_mjcf.py CHANGED
@@ -8,18 +8,17 @@
8
8
 
9
9
  import math
10
10
  import os
11
+ import re
11
12
  import xml.etree.ElementTree as ET
12
13
 
13
14
  import numpy as np
14
-
15
15
  import warp as wp
16
- from warp.sim.model import JOINT_COMPOUND, JOINT_UNIVERSAL
17
- from warp.sim.model import Mesh
18
16
 
19
17
 
20
18
  def parse_mjcf(
21
- filename,
19
+ mjcf_filename,
22
20
  builder,
21
+ xform=wp.transform(),
23
22
  density=1000.0,
24
23
  stiffness=0.0,
25
24
  damping=0.0,
@@ -30,69 +29,206 @@ def parse_mjcf(
30
29
  contact_restitution=0.5,
31
30
  limit_ke=100.0,
32
31
  limit_kd=10.0,
32
+ scale=1.0,
33
33
  armature=0.0,
34
34
  armature_scale=1.0,
35
- parse_meshes=False,
36
- enable_self_collisions=True,
35
+ parse_meshes=True,
36
+ enable_self_collisions=False,
37
+ up_axis="Z",
38
+ ignore_classes=[],
39
+ collapse_fixed_joints=False,
37
40
  ):
38
- file = ET.parse(filename)
41
+ """
42
+ Parses MuJoCo XML (MJCF) file and adds the bodies and joints to the given ModelBuilder.
43
+
44
+ Args:
45
+ mjcf_filename (str): The filename of the MuJoCo file to parse.
46
+ builder (ModelBuilder): The :class:`ModelBuilder` to add the bodies and joints to.
47
+ xform (:ref:`transform <transform>`): The transform to apply to the imported mechanism.
48
+ density (float): The density of the shapes in kg/m^3 which will be used to calculate the body mass and inertia.
49
+ stiffness (float): The stiffness of the joints.
50
+ damping (float): The damping of the joints.
51
+ contact_ke (float): The stiffness of the shape contacts (used by SemiImplicitIntegrator).
52
+ contact_kd (float): The damping of the shape contacts (used by SemiImplicitIntegrator).
53
+ contact_kf (float): The friction stiffness of the shape contacts (used by SemiImplicitIntegrator).
54
+ contact_mu (float): The friction coefficient of the shape contacts.
55
+ contact_restitution (float): The restitution coefficient of the shape contacts.
56
+ limit_ke (float): The stiffness of the joint limits (used by SemiImplicitIntegrator).
57
+ limit_kd (float): The damping of the joint limits (used by SemiImplicitIntegrator).
58
+ scale (float): The scaling factor to apply to the imported mechanism.
59
+ armature (float): Default joint armature to use if `armature` has not been defined for a joint in the MJCF.
60
+ armature_scale (float): Scaling factor to apply to the MJCF-defined joint armature values.
61
+ parse_meshes (bool): Whether geometries of type `"mesh"` should be parsed. If False, geometries of type `"mesh"` are ignored.
62
+ enable_self_collisions (bool): If True, self-collisions are enabled.
63
+ up_axis (str): The up axis of the mechanism. Can be either `"X"`, `"Y"` or `"Z"`. The default is `"Z"`.
64
+ ignore_classes (List[str]): A list of regular expressions. Bodies and joints with a class matching one of the regular expressions will be ignored.
65
+ collapse_fixed_joints (bool): If True, fixed joints are removed and the respective bodies are merged.
66
+
67
+ Note:
68
+ The inertia and masses of the bodies are calculated from the shape geometry and the given density. The values defined in the MJCF are not respected at the moment.
69
+
70
+ The handling of advanced features, such as MJCF classes, is still experimental.
71
+ """
72
+ mjcf_dirname = os.path.dirname(mjcf_filename)
73
+ file = ET.parse(mjcf_filename)
39
74
  root = file.getroot()
40
75
 
41
- type_map = {
42
- "ball": wp.sim.JOINT_BALL,
43
- "hinge": wp.sim.JOINT_REVOLUTE,
44
- "slide": wp.sim.JOINT_PRISMATIC,
45
- "free": wp.sim.JOINT_FREE,
46
- "fixed": wp.sim.JOINT_FIXED,
47
- }
48
-
49
- def parse_float(node, key, default):
50
- if key in node.attrib:
51
- return float(node.attrib[key])
76
+ use_degrees = True # angles are in degrees by default
77
+ euler_seq = [1, 2, 3] # XYZ by default
78
+
79
+ compiler = root.find("compiler")
80
+ if compiler is not None:
81
+ use_degrees = compiler.attrib.get("angle", "degree").lower() == "degree"
82
+ euler_seq = ["xyz".index(c) + 1 for c in compiler.attrib.get("eulerseq", "xyz").lower()]
83
+ mesh_dir = compiler.attrib.get("meshdir", ".")
84
+
85
+ mesh_assets = {}
86
+ for asset in root.findall("asset"):
87
+ for mesh in asset.findall("mesh"):
88
+ if "file" in mesh.attrib:
89
+ fname = os.path.join(mesh_dir, mesh.attrib["file"])
90
+ # handle stl relative paths
91
+ if not os.path.isabs(fname):
92
+ fname = os.path.abspath(os.path.join(mjcf_dirname, fname))
93
+ if "name" in mesh.attrib:
94
+ mesh_assets[mesh.attrib["name"]] = fname
95
+ else:
96
+ name = ".".join(os.path.basename(fname).split(".")[:-1])
97
+ mesh_assets[name] = fname
98
+
99
+ class_parent = {}
100
+ class_children = {}
101
+ class_defaults = {"__all__": {}}
102
+
103
+ def get_class(element):
104
+ return element.get("class", "__all__")
105
+
106
+ def parse_default(node, parent):
107
+ nonlocal class_parent
108
+ nonlocal class_children
109
+ nonlocal class_defaults
110
+ class_name = "__all__"
111
+ if "class" in node.attrib:
112
+ class_name = node.attrib["class"]
113
+ class_parent[class_name] = parent
114
+ parent = parent or "__all__"
115
+ if parent not in class_children:
116
+ class_children[parent] = []
117
+ class_children[parent].append(class_name)
118
+
119
+ if class_name not in class_defaults:
120
+ class_defaults[class_name] = {}
121
+ for child in node:
122
+ if child.tag == "default":
123
+ parse_default(child, node.get("class"))
124
+ else:
125
+ class_defaults[class_name][child.tag] = child.attrib
126
+
127
+ for default in root.findall("default"):
128
+ parse_default(default, None)
129
+
130
+ def merge_attrib(default_attrib: dict, incoming_attrib: dict):
131
+ attrib = default_attrib.copy()
132
+ attrib.update(incoming_attrib)
133
+ return attrib
134
+
135
+ if isinstance(up_axis, str):
136
+ up_axis = "XYZ".index(up_axis.upper())
137
+ sqh = np.sqrt(0.5)
138
+ if up_axis == 0:
139
+ xform = wp.transform(xform.p, wp.quat(0.0, 0.0, -sqh, sqh) * xform.q)
140
+ elif up_axis == 2:
141
+ xform = wp.transform(xform.p, wp.quat(sqh, 0.0, 0.0, -sqh) * xform.q)
142
+ # do not apply scaling to the root transform
143
+ xform = wp.transform(np.array(xform.p) / scale, xform.q)
144
+
145
+ def parse_float(attrib, key, default):
146
+ if key in attrib:
147
+ return float(attrib[key])
52
148
  else:
53
149
  return default
54
150
 
55
- def parse_vec(node, key, default):
56
- if key in node.attrib:
57
- return np.fromstring(node.attrib[key], sep=" ")
151
+ def parse_vec(attrib, key, default):
152
+ if key in attrib:
153
+ out = np.fromstring(attrib[key], sep=" ", dtype=np.float32)
58
154
  else:
59
- return np.array(default)
155
+ out = np.array(default, dtype=np.float32)
156
+
157
+ length = len(out)
158
+ if length == 1:
159
+ return wp.vec(len(default), wp.float32)(out[0], out[0], out[0])
160
+
161
+ return wp.vec(length, wp.float32)(out)
162
+
163
+ def parse_orientation(attrib):
164
+ if "quat" in attrib:
165
+ wxyz = np.fromstring(attrib["quat"], sep=" ")
166
+ return wp.normalize(wp.quat(*wxyz[1:], wxyz[0]))
167
+ if "euler" in attrib:
168
+ euler = np.fromstring(attrib["euler"], sep=" ")
169
+ if use_degrees:
170
+ euler *= np.pi / 180
171
+ return wp.quat_from_euler(euler, *euler_seq)
172
+ if "axisangle" in attrib:
173
+ axisangle = np.fromstring(attrib["axisangle"], sep=" ")
174
+ angle = axisangle[3]
175
+ if use_degrees:
176
+ angle *= np.pi / 180
177
+ axis = wp.normalize(wp.vec3(*axisangle[:3]))
178
+ return wp.quat_from_axis_angle(axis, angle)
179
+ if "xyaxes" in attrib:
180
+ xyaxes = np.fromstring(attrib["xyaxes"], sep=" ")
181
+ xaxis = wp.normalize(wp.vec3(*xyaxes[:3]))
182
+ zaxis = wp.normalize(wp.vec3(*xyaxes[3:]))
183
+ yaxis = wp.normalize(wp.cross(zaxis, xaxis))
184
+ rot_matrix = np.array([xaxis, yaxis, zaxis]).T
185
+ return wp.quat_from_matrix(rot_matrix)
186
+ if "zaxis" in attrib:
187
+ zaxis = np.fromstring(attrib["zaxis"], sep=" ")
188
+ zaxis = wp.normalize(wp.vec3(*zaxis))
189
+ xaxis = wp.normalize(wp.cross(wp.vec3(0, 0, 1), zaxis))
190
+ yaxis = wp.normalize(wp.cross(zaxis, xaxis))
191
+ rot_matrix = np.array([xaxis, yaxis, zaxis]).T
192
+ return wp.quat_from_matrix(rot_matrix)
193
+ return wp.quat_identity()
60
194
 
61
195
  def parse_mesh(geom):
62
196
  import trimesh
63
197
 
64
198
  faces = []
65
199
  vertices = []
66
- stl_file = next(
67
- filter(
68
- lambda m: m.attrib["name"] == geom.attrib["mesh"],
69
- root.find("asset").findall("mesh"),
70
- )
71
- ).attrib["file"]
72
- # handle stl relative paths
73
- if not os.path.isabs(stl_file):
74
- stl_file = os.path.join(os.path.dirname(filename), stl_file)
200
+ stl_file = mesh_assets[geom["mesh"]]
75
201
  m = trimesh.load(stl_file)
76
202
 
77
203
  for v in m.vertices:
78
- vertices.append(np.array(v))
204
+ vertices.append(np.array(v) * scale)
79
205
 
80
206
  for f in m.faces:
81
207
  faces.append(int(f[0]))
82
208
  faces.append(int(f[1]))
83
209
  faces.append(int(f[2]))
84
- return Mesh(vertices, faces), m.scale
85
-
86
- def parse_body(body, parent):
87
- body_name = body.attrib["name"]
88
- body_pos = parse_vec(body, "pos", (0.0, 0.0, 0.0))
89
- body_ori_euler = parse_vec(body, "euler", (0.0, 0.0, 0.0))
90
- if len(np.nonzero(body_ori_euler)[0]) > 0:
91
- body_axis = tuple(np.sign(body_ori_euler))
92
- body_angle = body_ori_euler[np.nonzero(body_ori_euler)[0].item()] / 180 * np.pi
93
- body_ori = wp.utils.quat_from_axis_angle(body_axis, body_angle)
210
+ return wp.sim.Mesh(vertices, faces), m.scale
211
+
212
+ def parse_body(body, parent, incoming_defaults: dict):
213
+ body_class = body.get("childclass")
214
+ if body_class is None:
215
+ defaults = incoming_defaults
94
216
  else:
95
- body_ori = wp.quat_identity()
217
+ for pattern in ignore_classes:
218
+ if re.match(pattern, body_class):
219
+ return
220
+ defaults = merge_attrib(incoming_defaults, class_defaults[body_class])
221
+ if "body" in defaults:
222
+ body_attrib = merge_attrib(defaults["body"], body.attrib)
223
+ else:
224
+ body_attrib = body.attrib
225
+ body_name = body_attrib["name"]
226
+ body_pos = parse_vec(body_attrib, "pos", (0.0, 0.0, 0.0))
227
+ body_ori = parse_orientation(body_attrib)
228
+ if parent == -1:
229
+ body_pos = wp.transform_point(xform, body_pos)
230
+ body_ori = xform.q * body_ori
231
+ body_pos *= scale
96
232
 
97
233
  joint_armature = []
98
234
  joint_name = []
@@ -102,43 +238,55 @@ def parse_mjcf(
102
238
  angular_axes = []
103
239
  joint_type = None
104
240
 
105
- joints = body.findall("joint")
106
- for i, joint in enumerate(joints):
107
- # default to hinge if not specified
108
- if "type" not in joint.attrib:
109
- joint.attrib["type"] = "hinge"
110
-
111
- joint_name.append(joint.attrib["name"])
112
- joint_pos.append(parse_vec(joint, "pos", (0.0, 0.0, 0.0)))
113
- # TODO parse joint (child transform) rotation?
114
- joint_range = parse_vec(joint, "range", (-3.0, 3.0))
115
- joint_armature.append(parse_float(joint, "armature", armature) * armature_scale)
116
-
117
- if joint.attrib["type"].lower() == "free":
118
- joint_type = wp.sim.JOINT_FREE
119
- break
120
- is_angular = joint.attrib["type"].lower() == "hinge"
121
- mode = wp.sim.JOINT_MODE_LIMIT
122
- if stiffness > 0.0 or "stiffness" in joint.attrib:
123
- mode = wp.sim.JOINT_MODE_TARGET_POSITION
124
- ax = wp.sim.model.JointAxis(
125
- axis=parse_vec(joint, "axis", (0.0, 0.0, 0.0)),
126
- limit_lower=(np.deg2rad(joint_range[0]) if is_angular else joint_range[0]),
127
- limit_upper=(np.deg2rad(joint_range[1]) if is_angular else joint_range[1]),
128
- target_ke=parse_float(joint, "stiffness", stiffness),
129
- target_kd=parse_float(joint, "damping", damping),
130
- limit_ke=limit_ke,
131
- limit_kd=limit_kd,
132
- mode=mode,
133
- )
134
- if is_angular:
135
- angular_axes.append(ax)
136
- else:
137
- linear_axes.append(ax)
241
+ freejoint_tags = body.findall("freejoint")
242
+ if len(freejoint_tags) > 0:
243
+ joint_type = wp.sim.JOINT_FREE
244
+ joint_name.append(freejoint_tags[0].attrib.get("name", f"{body_name}_freejoint"))
245
+ else:
246
+ joints = body.findall("joint")
247
+ for i, joint in enumerate(joints):
248
+ if "joint" in defaults:
249
+ joint_attrib = merge_attrib(defaults["joint"], joint.attrib)
250
+ else:
251
+ joint_attrib = joint.attrib
252
+
253
+ # default to hinge if not specified
254
+ joint_type_str = joint_attrib.get("type", "hinge")
255
+
256
+ joint_name.append(joint_attrib["name"])
257
+ joint_pos.append(parse_vec(joint_attrib, "pos", (0.0, 0.0, 0.0)) * scale)
258
+ joint_range = parse_vec(joint_attrib, "range", (-3.0, 3.0))
259
+ joint_armature.append(parse_float(joint_attrib, "armature", armature) * armature_scale)
260
+
261
+ if joint_type_str == "free":
262
+ joint_type = wp.sim.JOINT_FREE
263
+ break
264
+ if joint_type_str == "fixed":
265
+ joint_type = wp.sim.JOINT_FIXED
266
+ break
267
+ is_angular = joint_type_str == "hinge"
268
+ mode = wp.sim.JOINT_MODE_LIMIT
269
+ if stiffness > 0.0 or "stiffness" in joint_attrib:
270
+ mode = wp.sim.JOINT_MODE_TARGET_POSITION
271
+ axis_vec = parse_vec(joint_attrib, "axis", (0.0, 0.0, 0.0))
272
+ ax = wp.sim.model.JointAxis(
273
+ axis=axis_vec,
274
+ limit_lower=(np.deg2rad(joint_range[0]) if is_angular and use_degrees else joint_range[0]),
275
+ limit_upper=(np.deg2rad(joint_range[1]) if is_angular and use_degrees else joint_range[1]),
276
+ target_ke=parse_float(joint_attrib, "stiffness", stiffness),
277
+ target_kd=parse_float(joint_attrib, "damping", damping),
278
+ limit_ke=limit_ke,
279
+ limit_kd=limit_kd,
280
+ mode=mode,
281
+ )
282
+ if is_angular:
283
+ angular_axes.append(ax)
284
+ else:
285
+ linear_axes.append(ax)
138
286
 
139
287
  link = builder.add_body(
140
- origin=wp.transform_identity(), # will be evaluated in fk()
141
- armature=joint_armature[0],
288
+ origin=wp.transform(body_pos, body_ori), # will be evaluated in fk()
289
+ armature=joint_armature[0] if len(joint_armature) > 0 else armature,
142
290
  name=body_name,
143
291
  )
144
292
 
@@ -157,6 +305,7 @@ def parse_mjcf(
157
305
  else:
158
306
  joint_type = wp.sim.JOINT_D6
159
307
 
308
+ joint_pos = joint_pos[0] if len(joint_pos) > 0 else (0.0, 0.0, 0.0)
160
309
  builder.add_joint(
161
310
  joint_type,
162
311
  parent,
@@ -164,21 +313,40 @@ def parse_mjcf(
164
313
  linear_axes,
165
314
  angular_axes,
166
315
  name="_".join(joint_name),
167
- parent_xform=wp.transform(body_pos, body_ori),
168
- # child_xform=wp.transform(joint_pos[0], wp.quat_identity()),
316
+ parent_xform=wp.transform(body_pos + joint_pos, body_ori),
317
+ child_xform=wp.transform(joint_pos, wp.quat_identity()),
169
318
  )
170
319
 
171
320
  # -----------------
172
321
  # add shapes
173
322
 
174
- for geom in body.findall("geom"):
175
- geom_name = geom.attrib["name"]
176
- geom_type = geom.attrib["type"]
323
+ for geo_count, geom in enumerate(body.findall("geom")):
324
+ geom_defaults = defaults
325
+ if "class" in geom.attrib:
326
+ geom_class = geom.attrib["class"]
327
+ ignore_geom = False
328
+ for pattern in ignore_classes:
329
+ if re.match(pattern, geom_class):
330
+ ignore_geom = True
331
+ break
332
+ if ignore_geom:
333
+ continue
334
+ if geom_class in class_defaults:
335
+ geom_defaults = merge_attrib(defaults, class_defaults[geom_class])
336
+ if "geom" in geom_defaults:
337
+ geom_attrib = merge_attrib(geom_defaults["geom"], geom.attrib)
338
+ else:
339
+ geom_attrib = geom.attrib
340
+
341
+ geom_name = geom_attrib.get("name", f"{body_name}_geom_{geo_count}")
342
+ geom_type = geom_attrib.get("type", "sphere")
343
+ if "mesh" in geom_attrib:
344
+ geom_type = "mesh"
177
345
 
178
- geom_size = parse_vec(geom, "size", [1.0])
179
- geom_pos = parse_vec(geom, "pos", (0.0, 0.0, 0.0))
180
- geom_rot = parse_vec(geom, "quat", (0.0, 0.0, 0.0, 1.0))
181
- geom_density = parse_float(geom, "density", density)
346
+ geom_size = parse_vec(geom_attrib, "size", [1.0, 1.0, 1.0]) * scale
347
+ geom_pos = parse_vec(geom_attrib, "pos", (0.0, 0.0, 0.0)) * scale
348
+ geom_rot = parse_orientation(geom_attrib)
349
+ geom_density = parse_float(geom_attrib, "density", density)
182
350
 
183
351
  if geom_type == "sphere":
184
352
  builder.add_shape_sphere(
@@ -194,16 +362,36 @@ def parse_mjcf(
194
362
  restitution=contact_restitution,
195
363
  )
196
364
 
365
+ elif geom_type == "box":
366
+ builder.add_shape_box(
367
+ link,
368
+ pos=geom_pos,
369
+ rot=geom_rot,
370
+ hx=geom_size[0],
371
+ hy=geom_size[1],
372
+ hz=geom_size[2],
373
+ density=geom_density,
374
+ ke=contact_ke,
375
+ kd=contact_kd,
376
+ kf=contact_kf,
377
+ mu=contact_mu,
378
+ restitution=contact_restitution,
379
+ )
380
+
197
381
  elif geom_type == "mesh" and parse_meshes:
198
- mesh, scale = parse_mesh(geom)
199
- geom_size = tuple([scale * s for s in geom_size])
382
+ mesh, _ = parse_mesh(geom_attrib)
383
+ if "mesh" in defaults:
384
+ mesh_scale = parse_vec(defaults["mesh"], "scale", [1.0, 1.0, 1.0])
385
+ else:
386
+ mesh_scale = [1.0, 1.0, 1.0]
387
+ # as per the Mujoco XML reference, ignore geom size attribute
200
388
  assert len(geom_size) == 3, "need to specify size for mesh geom"
201
389
  builder.add_shape_mesh(
202
390
  body=link,
203
391
  pos=geom_pos,
204
392
  rot=geom_rot,
205
393
  mesh=mesh,
206
- scale=geom_size,
394
+ scale=mesh_scale,
207
395
  density=density,
208
396
  ke=contact_ke,
209
397
  kd=contact_kd,
@@ -212,32 +400,28 @@ def parse_mjcf(
212
400
  )
213
401
 
214
402
  elif geom_type in {"capsule", "cylinder"}:
215
- if "fromto" in geom.attrib:
216
- geom_fromto = parse_vec(geom, "fromto", (0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
403
+ if "fromto" in geom_attrib:
404
+ geom_fromto = parse_vec(geom_attrib, "fromto", (0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
217
405
 
218
- start = geom_fromto[0:3]
219
- end = geom_fromto[3:6]
406
+ start = wp.vec3(geom_fromto[0:3]) * scale
407
+ end = wp.vec3(geom_fromto[3:6]) * scale
220
408
 
221
409
  # compute rotation to align the Warp capsule (along x-axis), with mjcf fromto direction
222
410
  axis = wp.normalize(end - start)
223
- angle = math.acos(np.dot(axis, (0.0, 1.0, 0.0)))
224
- axis = wp.normalize(np.cross(axis, (0.0, 1.0, 0.0)))
411
+ angle = math.acos(wp.dot(axis, wp.vec3(0.0, 1.0, 0.0)))
412
+ axis = wp.normalize(wp.cross(axis, wp.vec3(0.0, 1.0, 0.0)))
225
413
 
226
414
  geom_pos = (start + end) * 0.5
227
415
  geom_rot = wp.quat_from_axis_angle(axis, -angle)
228
416
 
229
417
  geom_radius = geom_size[0]
230
- geom_height = np.linalg.norm(end - start) * 0.5
418
+ geom_height = wp.length(end - start) * 0.5
419
+ geom_up_axis = 1
231
420
 
232
421
  else:
233
422
  geom_radius = geom_size[0]
234
423
  geom_height = geom_size[1]
235
- geom_pos = parse_vec(geom, "pos", (0.0, 0.0, 0.0))
236
- # orientation along the z axis by default
237
- axis = np.array((0.0, 0.0, 1.0))
238
- angle = math.acos(np.dot(axis, (0.0, 1.0, 0.0)))
239
- axis = wp.normalize(np.cross(axis, (0.0, 1.0, 0.0)))
240
- geom_rot = wp.quat_from_axis_angle(axis, -angle)
424
+ geom_up_axis = up_axis
241
425
 
242
426
  if geom_type == "cylinder":
243
427
  builder.add_shape_cylinder(
@@ -252,6 +436,7 @@ def parse_mjcf(
252
436
  kf=contact_kf,
253
437
  mu=contact_mu,
254
438
  restitution=contact_restitution,
439
+ up_axis=geom_up_axis,
255
440
  )
256
441
  else:
257
442
  builder.add_shape_capsule(
@@ -266,16 +451,17 @@ def parse_mjcf(
266
451
  kf=contact_kf,
267
452
  mu=contact_mu,
268
453
  restitution=contact_restitution,
454
+ up_axis=geom_up_axis,
269
455
  )
270
456
 
271
457
  else:
272
- print("MJCF parsing issue: geom type", geom_type, "is unsupported")
458
+ print(f"MJCF parsing shape {geom_name} issue: geom type {geom_type} is unsupported")
273
459
 
274
460
  # -----------------
275
461
  # recurse
276
462
 
277
463
  for child in body.findall("body"):
278
- parse_body(child, link)
464
+ parse_body(child, link, defaults)
279
465
 
280
466
  # -----------------
281
467
  # start articulation
@@ -284,8 +470,10 @@ def parse_mjcf(
284
470
  builder.add_articulation()
285
471
 
286
472
  world = root.find("worldbody")
473
+ world_class = get_class(world)
474
+ world_defaults = merge_attrib(class_defaults["__all__"], class_defaults.get(world_class, {}))
287
475
  for body in world.findall("body"):
288
- parse_body(body, -1)
476
+ parse_body(body, -1, world_defaults)
289
477
 
290
478
  end_shape_count = len(builder.shape_geo_type)
291
479
 
@@ -293,3 +481,6 @@ def parse_mjcf(
293
481
  for i in range(start_shape_count, end_shape_count):
294
482
  for j in range(i + 1, end_shape_count):
295
483
  builder.shape_collision_filter_pairs.add((i, j))
484
+
485
+ if collapse_fixed_joints:
486
+ builder.collapse_fixed_joints()