warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/utils.py CHANGED
@@ -5,215 +5,37 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- import os
9
- import math
10
- import timeit
11
8
  import cProfile
9
+ import sys
10
+ import timeit
11
+ import warnings
12
+ from typing import Any
13
+
12
14
  import numpy as np
13
- from typing import Union, Tuple
14
15
 
15
16
  import warp as wp
17
+ import warp.types
16
18
 
17
19
 
18
- def length(a):
19
- return np.linalg.norm(a)
20
-
21
-
22
- def length_sq(a):
23
- return np.dot(a, a)
24
-
25
-
26
- def cross(a, b):
27
- return np.array((a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]), dtype=np.float32)
28
-
29
-
30
- # NumPy has no normalize() method..
31
- def normalize(v):
32
- norm = np.linalg.norm(v)
33
- if norm == 0.0:
34
- return v
35
- return v / norm
36
-
37
-
38
- def skew(v):
39
- return np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
40
-
41
-
42
- # math utils
43
- # def quat(i, j, k, w):
44
- # return np.array([i, j, k, w])
45
-
46
-
47
- def quat_identity():
48
- return np.array((0.0, 0.0, 0.0, 1.0))
49
-
50
-
51
- def quat_inverse(q):
52
- return np.array((-q[0], -q[1], -q[2], q[3]))
53
-
54
-
55
- def quat_from_axis_angle(axis, angle):
56
- v = normalize(np.array(axis))
57
-
58
- half = angle * 0.5
59
- w = math.cos(half)
60
-
61
- sin_theta_over_two = math.sin(half)
62
- v *= sin_theta_over_two
63
-
64
- return np.array((v[0], v[1], v[2], w))
65
-
66
-
67
- def quat_to_axis_angle(quat):
68
- w2 = quat[3] * quat[3]
69
- if w2 > 1 - 1e-7:
70
- return np.zeros(3), 0.0
71
-
72
- angle = 2 * np.arccos(quat[3])
73
- xyz = quat[:3] / np.sqrt(1 - w2)
74
- return xyz, angle
75
-
76
-
77
- # quat_rotate a vector
78
- def quat_rotate(q, x):
79
- x = np.array(x)
80
- axis = np.array((q[0], q[1], q[2]))
81
- return x * (2.0 * q[3] * q[3] - 1.0) + np.cross(axis, x) * q[3] * 2.0 + axis * np.dot(axis, x) * 2.0
82
-
83
-
84
- # multiply two quats
85
- def quat_multiply(a, b):
86
- return np.array(
87
- (
88
- a[3] * b[0] + b[3] * a[0] + a[1] * b[2] - b[1] * a[2],
89
- a[3] * b[1] + b[3] * a[1] + a[2] * b[0] - b[2] * a[0],
90
- a[3] * b[2] + b[3] * a[2] + a[0] * b[1] - b[0] * a[1],
91
- a[3] * b[3] - a[0] * b[0] - a[1] * b[1] - a[2] * b[2],
92
- )
93
- )
94
-
95
-
96
- # convert to mat33
97
- def quat_to_matrix(q):
98
- c1 = quat_rotate(q, np.array((1.0, 0.0, 0.0)))
99
- c2 = quat_rotate(q, np.array((0.0, 1.0, 0.0)))
100
- c3 = quat_rotate(q, np.array((0.0, 0.0, 1.0)))
101
-
102
- return np.array([c1, c2, c3]).T
103
-
104
-
105
- def quat_rpy(roll, pitch, yaw):
106
- cy = math.cos(yaw * 0.5)
107
- sy = math.sin(yaw * 0.5)
108
- cr = math.cos(roll * 0.5)
109
- sr = math.sin(roll * 0.5)
110
- cp = math.cos(pitch * 0.5)
111
- sp = math.sin(pitch * 0.5)
112
-
113
- w = cy * cr * cp + sy * sr * sp
114
- x = cy * sr * cp - sy * cr * sp
115
- y = cy * cr * sp + sy * sr * cp
116
- z = sy * cr * cp - cy * sr * sp
117
-
118
- return (x, y, z, w)
119
-
120
-
121
- def quat_from_matrix(m):
122
- tr = m[0, 0] + m[1, 1] + m[2, 2]
123
- h = 0.0
124
-
125
- if tr >= 0.0:
126
- h = math.sqrt(tr + 1.0)
127
- w = 0.5 * h
128
- h = 0.5 / h
129
-
130
- x = (m[2, 1] - m[1, 2]) * h
131
- y = (m[0, 2] - m[2, 0]) * h
132
- z = (m[1, 0] - m[0, 1]) * h
133
-
134
- else:
135
- i = 0
136
- if m[1, 1] > m[0, 0]:
137
- i = 1
138
- if m[2, 2] > m[i, i]:
139
- i = 2
140
-
141
- if i == 0:
142
- h = math.sqrt((m[0, 0] - (m[1, 1] + m[2, 2])) + 1.0)
143
- x = 0.5 * h
144
- h = 0.5 / h
145
-
146
- y = (m[0, 1] + m[1, 0]) * h
147
- z = (m[2, 0] + m[0, 2]) * h
148
- w = (m[2, 1] - m[1, 2]) * h
149
-
150
- elif i == 1:
151
- h = math.sqrt((m[1, 1] - (m[2, 2] + m[0, 0])) + 1.0)
152
- y = 0.5 * h
153
- h = 0.5 / h
154
-
155
- z = (m[1, 2] + m[2, 1]) * h
156
- x = (m[0, 1] + m[1, 0]) * h
157
- w = (m[0, 2] - m[2, 0]) * h
158
-
159
- elif i == 2:
160
- h = math.sqrt((m[2, 2] - (m[0, 0] + m[1, 1])) + 1.0)
161
- z = 0.5 * h
162
- h = 0.5 / h
163
-
164
- x = (m[2, 0] + m[0, 2]) * h
165
- y = (m[1, 2] + m[2, 1]) * h
166
- w = (m[1, 0] - m[0, 1]) * h
167
-
168
- return normalize(np.array([x, y, z, w]))
169
-
170
-
171
- # rigid body transform
172
-
173
-
174
- # def transform(x, r):
175
- # return (np.array(x), np.array(r))
176
-
177
-
178
- def transform_identity():
179
- return wp.transform(np.array((0.0, 0.0, 0.0)), quat_identity())
180
-
181
-
182
- # se(3) -> SE(3), Park & Lynch pg. 105, screw in [w, v] normalized form
183
- def transform_exp(s, angle):
184
- w = np.array(s[0:3])
185
- v = np.array(s[3:6])
186
-
187
- if length(w) < 1.0:
188
- r = quat_identity()
189
- else:
190
- r = quat_from_axis_angle(w, angle)
191
-
192
- t = v * angle + (1.0 - math.cos(angle)) * np.cross(w, v) + (angle - math.sin(angle)) * np.cross(w, np.cross(w, v))
193
-
194
- return (t, r)
195
-
196
-
197
- def transform_inverse(t):
198
- q_inv = quat_inverse(t.q)
199
- return wp.transform(-quat_rotate(q_inv, t.p), q_inv)
200
-
201
-
202
- def transform_vector(t, v):
203
- return quat_rotate(t.q, v)
20
+ warnings_seen = set()
204
21
 
205
22
 
206
- def transform_point(t, p):
207
- return np.array(t.p) + quat_rotate(t.q, p)
23
+ def warp_showwarning(message, category, filename, lineno, file=None, line=None):
24
+ """Version of warnings.showwarning that always prints to sys.stdout."""
25
+ sys.stdout.write(warnings.formatwarning(message, category, filename, lineno, line=line))
208
26
 
209
27
 
210
- def transform_multiply(t, u):
211
- return wp.transform(quat_rotate(t.q, u.p) + t.p, quat_multiply(t.q, u.q))
28
+ def warn(message, category=None, stacklevel=1):
29
+ if (category, message) in warnings_seen:
30
+ return
212
31
 
32
+ with warnings.catch_warnings():
33
+ warnings.simplefilter("default") # Change the filter in this process
34
+ warnings.showwarning = warp_showwarning
35
+ warnings.warn(message, category, stacklevel + 1) # Increment stacklevel by 1 since we are in a wrapper
213
36
 
214
- # flatten an array of transforms (p,q) format to a 7-vector
215
- def transform_flatten(t):
216
- return np.array([*t.p, *t.q])
37
+ if category is DeprecationWarning:
38
+ warnings_seen.add((category, message))
217
39
 
218
40
 
219
41
  # expand a 7-vec to a tuple of arrays
@@ -221,183 +43,368 @@ def transform_expand(t):
221
43
  return wp.transform(np.array(t[0:3]), np.array(t[3:7]))
222
44
 
223
45
 
224
- # convert array of transforms to a array of 7-vecs
225
- def transform_flatten_list(xforms):
226
- exp = lambda t: transform_flatten(t)
227
- return list(map(exp, xforms))
228
-
229
-
230
- def transform_expand_list(xforms):
231
- exp = lambda t: transform_expand(t)
232
- return list(map(exp, xforms))
233
-
234
-
235
- def transform_inertia(m, I, p, q):
46
+ @wp.func
47
+ def quat_between_vectors(a: wp.vec3, b: wp.vec3) -> wp.quat:
236
48
  """
237
- Transforms the inertia tensor described by the given mass and 3x3 inertia
238
- matrix to a new frame described by the given position and orientation.
49
+ Compute the quaternion that rotates vector a to vector b
239
50
  """
240
- R = quat_to_matrix(q)
241
-
242
- # Steiner's theorem
243
- return R @ I @ R.T + m * (np.dot(p, p) * np.eye(3) - np.outer(p, p))
244
-
51
+ a = wp.normalize(a)
52
+ b = wp.normalize(b)
53
+ c = wp.cross(a, b)
54
+ d = wp.dot(a, b)
55
+ q = wp.quat(c[0], c[1], c[2], 1.0 + d)
56
+ return wp.normalize(q)
245
57
 
246
- # spatial operators
247
58
 
59
+ def array_scan(in_array, out_array, inclusive=True):
60
+ if in_array.device != out_array.device:
61
+ raise RuntimeError("Array storage devices do not match")
248
62
 
249
- # AdT
250
- def spatial_adjoint(t):
251
- R = quat_to_matrix(t.q)
252
- w = skew(t.p)
63
+ if in_array.size != out_array.size:
64
+ raise RuntimeError("Array storage sizes do not match")
253
65
 
254
- A = np.zeros((6, 6))
255
- A[0:3, 0:3] = R
256
- A[3:6, 0:3] = np.dot(w, R)
257
- A[3:6, 3:6] = R
66
+ if in_array.dtype != out_array.dtype:
67
+ raise RuntimeError("Array data types do not match")
258
68
 
259
- return A
69
+ if in_array.size == 0:
70
+ return
260
71
 
72
+ from warp.context import runtime
261
73
 
262
- # (AdT)^-T
263
- def spatial_adjoint_dual(t):
264
- R = quat_to_matrix(t.q)
265
- w = skew(t.p)
74
+ if in_array.device.is_cpu:
75
+ if in_array.dtype == wp.int32:
76
+ runtime.core.array_scan_int_host(in_array.ptr, out_array.ptr, in_array.size, inclusive)
77
+ elif in_array.dtype == wp.float32:
78
+ runtime.core.array_scan_float_host(in_array.ptr, out_array.ptr, in_array.size, inclusive)
79
+ else:
80
+ raise RuntimeError("Unsupported data type")
81
+ elif in_array.device.is_cuda:
82
+ if in_array.dtype == wp.int32:
83
+ runtime.core.array_scan_int_device(in_array.ptr, out_array.ptr, in_array.size, inclusive)
84
+ elif in_array.dtype == wp.float32:
85
+ runtime.core.array_scan_float_device(in_array.ptr, out_array.ptr, in_array.size, inclusive)
86
+ else:
87
+ raise RuntimeError("Unsupported data type")
266
88
 
267
- A = np.zeros((6, 6))
268
- A[0:3, 0:3] = R
269
- A[0:3, 3:6] = np.dot(w, R)
270
- A[3:6, 3:6] = R
271
89
 
272
- return A
90
+ def radix_sort_pairs(keys, values, count: int):
91
+ if keys.device != values.device:
92
+ raise RuntimeError("Array storage devices do not match")
273
93
 
94
+ if count == 0:
95
+ return
274
96
 
275
- # AdT*s
276
- def transform_twist(t_ab, s_b):
277
- return np.dot(spatial_adjoint(t_ab), s_b)
97
+ if keys.size < 2 * count or values.size < 2 * count:
98
+ raise RuntimeError("Array storage must be large enough to contain 2*count elements")
278
99
 
100
+ from warp.context import runtime
279
101
 
280
- # AdT^{-T}*s
281
- def transform_wrench(t_ab, f_b):
282
- return np.dot(spatial_adjoint_dual(t_ab), f_b)
102
+ if keys.device.is_cpu:
103
+ if keys.dtype == wp.int32 and values.dtype == wp.int32:
104
+ runtime.core.radix_sort_pairs_int_host(keys.ptr, values.ptr, count)
105
+ else:
106
+ raise RuntimeError("Unsupported data type")
107
+ elif keys.device.is_cuda:
108
+ if keys.dtype == wp.int32 and values.dtype == wp.int32:
109
+ runtime.core.radix_sort_pairs_int_device(keys.ptr, values.ptr, count)
110
+ else:
111
+ raise RuntimeError("Unsupported data type")
283
112
 
284
113
 
285
- # transform spatial inertia (6x6) in b frame to a frame
286
- def transform_spatial_inertia(t_ab, I_b):
287
- t_ba = transform_inverse(t_ab)
114
+ def runlength_encode(values, run_values, run_lengths, run_count=None, value_count=None):
115
+ if run_values.device != values.device or run_lengths.device != values.device:
116
+ raise RuntimeError("Array storage devices do not match")
288
117
 
289
- # todo: write specialized method
290
- I_a = np.dot(np.dot(spatial_adjoint(t_ba).T, I_b), spatial_adjoint(t_ba))
291
- return I_a
118
+ if value_count is None:
119
+ value_count = values.size
292
120
 
121
+ if run_values.size < value_count or run_lengths.size < value_count:
122
+ raise RuntimeError("Output array storage sizes must be at least equal to value_count")
293
123
 
294
- def translate_twist(p_ab, s_b):
295
- w = s_b[0:3]
296
- v = np.cross(p_ab, s_b[0:3]) + s_b[3:6]
124
+ if values.dtype != run_values.dtype:
125
+ raise RuntimeError("values and run_values data types do not match")
297
126
 
298
- return np.array((*w, *v))
127
+ if run_lengths.dtype != wp.int32:
128
+ raise RuntimeError("run_lengths array must be of type int32")
299
129
 
130
+ # User can provide a device output array for storing the number of runs
131
+ # For convenience, if no such array is provided, number of runs is returned on host
132
+ if run_count is None:
133
+ if value_count == 0:
134
+ return 0
135
+ run_count = wp.empty(shape=(1,), dtype=int, device=values.device)
136
+ host_return = True
137
+ else:
138
+ if run_count.device != values.device:
139
+ raise RuntimeError("run_count storage device does not match other arrays")
140
+ if run_count.dtype != wp.int32:
141
+ raise RuntimeError("run_count array must be of type int32")
142
+ if value_count == 0:
143
+ run_count.zero_()
144
+ return 0
145
+ host_return = False
300
146
 
301
- def translate_wrench(p_ab, s_b):
302
- w = s_b[0:3] + np.cross(p_ab, s_b[3:6])
303
- v = s_b[3:6]
147
+ from warp.context import runtime
304
148
 
305
- return np.array((*w, *v))
149
+ if values.device.is_cpu:
150
+ if values.dtype == wp.int32:
151
+ runtime.core.runlength_encode_int_host(
152
+ values.ptr, run_values.ptr, run_lengths.ptr, run_count.ptr, value_count
153
+ )
154
+ else:
155
+ raise RuntimeError("Unsupported data type")
156
+ elif values.device.is_cuda:
157
+ if values.dtype == wp.int32:
158
+ runtime.core.runlength_encode_int_device(
159
+ values.ptr, run_values.ptr, run_lengths.ptr, run_count.ptr, value_count
160
+ )
161
+ else:
162
+ raise RuntimeError("Unsupported data type")
306
163
 
164
+ if host_return:
165
+ return int(run_count.numpy()[0])
307
166
 
308
- # def spatial_vector(v=(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)):
309
- # return np.array(v)
310
167
 
168
def array_sum(values, out=None, value_count=None, axis=None):
    """Sum the elements of a warp array whose scalar type is float32 or float64.

    Args:
        values: Input array. ``type_length`` of its dtype is forwarded to the native
            routine (presumably so vector/matrix dtypes are summed component-wise).
        out: Optional output array. It must be on the same device as ``values``, have
            the same dtype, and have shape ``(1,)`` when ``axis`` is None, otherwise
            ``values.shape`` with ``axis`` collapsed to 1.
        value_count: Number of elements to sum (along ``axis`` when given). Defaults
            to all available elements; larger values are rejected.
        axis: Optional axis to reduce along. Negative values count from the end.

    Returns:
        If ``out`` is None: ``out.numpy()[0]`` (a host value) when ``axis`` is None,
        otherwise the newly allocated device array. If ``out`` is provided: ``out``.

    Raises:
        RuntimeError: on an invalid ``axis`` or ``value_count``, a mismatched ``out``
            array, or an unsupported data type or device.
    """
    if axis is not None:
        ndim = values.ndim
        if not -ndim <= axis < ndim:
            raise RuntimeError(f"axis {axis} is out of range for an array with {ndim} dimension(s)")
        # Normalize so the output-shape computation below, which compares against
        # non-negative enumerate() indices, collapses the axis actually reduced.
        axis %= ndim

    available = values.size if axis is None else values.shape[axis]
    if value_count is None:
        value_count = available
    elif value_count > available:
        # The native routine would otherwise read past the end of the array.
        raise RuntimeError(f"value_count ({value_count}) exceeds the number of available values ({available})")

    if axis is None:
        output_shape = (1,)
    else:
        output_shape = tuple(1 if ax == axis else dim for ax, dim in enumerate(values.shape))

    type_length = wp.types.type_length(values.dtype)
    scalar_type = wp.types.type_scalar_type(values.dtype)

    # The caller may provide a device output array for storing the sum(s).
    # For convenience, if none is provided, one is allocated and (for a full
    # reduction) the sum is returned on the host.
    if out is None:
        host_return = True
        out = wp.empty(shape=output_shape, dtype=values.dtype, device=values.device)
    else:
        host_return = False
        if out.device != values.device:
            raise RuntimeError("out storage device should match values array")
        if out.dtype != values.dtype:
            raise RuntimeError(f"out array should have type {values.dtype.__name__}")
        if out.shape != output_shape:
            raise RuntimeError(f"out array should have shape {output_shape}")

    if value_count == 0:
        out.zero_()
        if axis is None and host_return:
            return out.numpy()[0]
        return out

    from warp.context import runtime

    if values.device.is_cpu:
        if scalar_type == wp.float32:
            native_func = runtime.core.array_sum_float_host
        elif scalar_type == wp.float64:
            native_func = runtime.core.array_sum_double_host
        else:
            raise RuntimeError("Unsupported data type")
    elif values.device.is_cuda:
        if scalar_type == wp.float32:
            native_func = runtime.core.array_sum_float_device
        elif scalar_type == wp.float64:
            native_func = runtime.core.array_sum_double_device
        else:
            raise RuntimeError("Unsupported data type")
    else:
        raise RuntimeError("Unsupported device")

    if axis is None:
        # NOTE: the stride is the element size, i.e. ``values`` is treated as contiguous.
        stride = wp.types.type_size_in_bytes(values.dtype)
        native_func(values.ptr, out.ptr, value_count, stride, type_length)

        if host_return:
            return out.numpy()[0]
        return out

    # One native reduction per 1-D slice along ``axis``. ``idx`` is 0 along
    # ``axis``, so the byte offsets point at the first element of each slice.
    stride = values.strides[axis]
    for idx in np.ndindex(output_shape):
        out_offset = sum(i * s for i, s in zip(idx, out.strides))
        val_offset = sum(i * s for i, s in zip(idx, values.strides))

        native_func(
            values.ptr + val_offset,
            out.ptr + out_offset,
            value_count,
            stride,
            type_length,
        )

    return out
338
246
 
339
247
 
340
- def spatial_matrix_from_inertia(I, m):
341
- G = spatial_matrix()
248
def array_inner(a, b, out=None, count=None, axis=None):
    """Compute the inner product of two warp arrays via the native ``array_inner_*`` routines.

    Args:
        a, b: Input arrays with the same size, device and dtype (scalar type float32
            or float64). When ``axis`` is given they must also have the same shape.
        out: Optional output array. It must be on the same device, have dtype equal
            to the scalar type of ``a``, and have shape ``(1,)`` when ``axis`` is None,
            otherwise ``a.shape`` with ``axis`` collapsed to 1.
        count: Number of elements to combine (along ``axis`` when given). Defaults to
            all available elements; larger values are rejected.
        axis: Optional axis to reduce along. Negative values count from the end.

    Returns:
        If ``out`` is None: a host value when ``axis`` is None (``0.0`` for an empty
        reduction), otherwise the newly allocated device array. If ``out`` is
        provided: ``out``.

    Raises:
        RuntimeError: on mismatched inputs, an invalid ``axis`` or ``count``, a
            mismatched ``out`` array, or an unsupported data type or device.
    """
    if a.size != b.size:
        raise RuntimeError("Array storage sizes do not match")

    if a.device != b.device:
        raise RuntimeError("Array storage devices do not match")

    if a.dtype != b.dtype:
        raise RuntimeError("Array data types do not match")

    if axis is not None:
        # Slices of ``b`` are located using indices over ``a``'s shape, so the shapes
        # (not just the sizes) must agree.
        if a.shape != b.shape:
            raise RuntimeError("Array shapes do not match")
        ndim = a.ndim
        if not -ndim <= axis < ndim:
            raise RuntimeError(f"axis {axis} is out of range for an array with {ndim} dimension(s)")
        # Normalize so the output-shape computation below collapses the right axis.
        axis %= ndim

    available = a.size if axis is None else a.shape[axis]
    if count is None:
        count = available
    elif count > available:
        # The native routine would otherwise read past the end of the arrays.
        raise RuntimeError(f"count ({count}) exceeds the number of available values ({available})")

    if axis is None:
        output_shape = (1,)
    else:
        output_shape = tuple(1 if ax == axis else dim for ax, dim in enumerate(a.shape))

    type_length = wp.types.type_length(a.dtype)
    scalar_type = wp.types.type_scalar_type(a.dtype)

    # The caller may provide a device output array for storing the inner product(s).
    # For convenience, if none is provided, one is allocated and (for a full
    # reduction) the result is returned on the host.
    if out is None:
        host_return = True
        out = wp.empty(shape=output_shape, dtype=scalar_type, device=a.device)
    else:
        host_return = False
        if out.device != a.device:
            raise RuntimeError("out storage device should match values array")
        if out.dtype != scalar_type:
            raise RuntimeError(f"out array should have type {scalar_type.__name__}")
        if out.shape != output_shape:
            raise RuntimeError(f"out array should have shape {output_shape}")

    if count == 0:
        if axis is None and host_return:
            return 0.0
        out.zero_()
        return out

    from warp.context import runtime

    if a.device.is_cpu:
        if scalar_type == wp.float32:
            native_func = runtime.core.array_inner_float_host
        elif scalar_type == wp.float64:
            native_func = runtime.core.array_inner_double_host
        else:
            raise RuntimeError("Unsupported data type")
    elif a.device.is_cuda:
        if scalar_type == wp.float32:
            native_func = runtime.core.array_inner_float_device
        elif scalar_type == wp.float64:
            native_func = runtime.core.array_inner_double_device
        else:
            raise RuntimeError("Unsupported data type")
    else:
        raise RuntimeError("Unsupported device")

    if axis is None:
        # NOTE: the strides are the element size, i.e. both arrays are treated as contiguous.
        stride_a = wp.types.type_size_in_bytes(a.dtype)
        stride_b = wp.types.type_size_in_bytes(b.dtype)
        native_func(a.ptr, b.ptr, out.ptr, count, stride_a, stride_b, type_length)

        if host_return:
            return out.numpy()[0]
        return out

    # One native reduction per 1-D slice along ``axis``. ``idx`` is 0 along
    # ``axis``, so the byte offsets point at the first element of each slice.
    stride_a = a.strides[axis]
    stride_b = b.strides[axis]

    for idx in np.ndindex(output_shape):
        out_offset = sum(i * s for i, s in zip(idx, out.strides))
        a_offset = sum(i * s for i, s in zip(idx, a.strides))
        b_offset = sum(i * s for i, s in zip(idx, b.strides))

        native_func(
            a.ptr + a_offset,
            b.ptr + b_offset,
            out.ptr + out_offset,
            count,
            stride_a,
            stride_b,
            type_length,
        )

    return out
341
+
342
+
343
# Element-wise conversion kernel launched by array_cast: each thread writes
# src[i] converted to dest's element type into dest[i]. The arguments are
# annotated with ``Any``, so their concrete array types are left unspecified
# here (presumably resolved per launch by Warp's generic-kernel support).
@wp.kernel
def _array_cast_kernel(
    dest: Any,
    src: Any,
):
    i = wp.tid()
    # Construct a value of dest's dtype from the source element.
    dest[i] = dest.dtype(src[i])
350
+
351
+
352
def array_cast(in_array, out_array, count=None):
    """Copy ``in_array`` into ``out_array``, converting element types when they differ.

    When the arrays differ in number of dimensions or in the shape of their element
    type (e.g. ``vec3`` vs. ``float32``), both are flattened and reinterpreted as
    1-D arrays of their scalar types, and the cast is performed scalar by scalar.

    Args:
        in_array: Source array.
        out_array: Destination array; must be on the same device as ``in_array``.
        count: Optional number of ``in_array`` elements to convert. A partial count
            is only supported when the (possibly flattened) arrays are 1-D.

    Raises:
        RuntimeError: if the devices differ, ``count`` exceeds the storage of either
            array, a multi-dimensional element-wise cast has mismatched shapes, or a
            partial cast of a multi-dimensional array is requested.
    """
    if in_array.device != out_array.device:
        raise RuntimeError("Array storage devices do not match")

    in_array_data_shape = getattr(in_array.dtype, "_shape_", ())
    out_array_data_shape = getattr(out_array.dtype, "_shape_", ())

    if in_array.ndim != out_array.ndim or in_array_data_shape != out_array_data_shape:
        # Number of dimensions or data type shape do not match.
        # Flatten arrays and do cast at the scalar level
        in_array = in_array.flatten()
        out_array = out_array.flatten()

        in_array_data_length = wp.types.type_length(in_array.dtype)
        out_array_data_length = wp.types.type_length(out_array.dtype)
        in_array_scalar_type = wp.types.type_scalar_type(in_array.dtype)
        out_array_scalar_type = wp.types.type_scalar_type(out_array.dtype)

        # Non-owning 1-D views over the same memory, typed as plain scalars.
        in_array = wp.array(
            data=None,
            ptr=in_array.ptr,
            capacity=in_array.capacity,
            owner=False,
            device=in_array.device,
            dtype=in_array_scalar_type,
            shape=in_array.shape[0] * in_array_data_length,
        )

        out_array = wp.array(
            data=None,
            ptr=out_array.ptr,
            capacity=out_array.capacity,
            owner=False,
            device=out_array.device,
            dtype=out_array_scalar_type,
            shape=out_array.shape[0] * out_array_data_length,
        )

        # ``count`` was given in source elements; convert it to source scalars.
        if count is not None:
            count *= in_array_data_length

    if count is None:
        count = in_array.size
    elif count > in_array.size:
        raise RuntimeError(f"count ({count}) exceeds the size of the input array ({in_array.size})")

    # Guard against writing past the end of the destination storage.
    if count > out_array.size:
        raise RuntimeError("Array storage sizes do not match")

    if in_array.ndim == 1:
        dim = count
    elif count < in_array.size:
        raise RuntimeError("Partial cast is not supported for arrays with more than one dimension")
    else:
        dim = in_array.shape

    if in_array.dtype == out_array.dtype:
        # Same data type, can simply copy
        wp.copy(dest=out_array, src=in_array, count=count)
    else:
        # The launch grid is ``in_array.shape`` and the kernel indexes ``out_array``
        # with the same thread index, so multi-dimensional shapes must agree.
        if in_array.ndim > 1 and in_array.shape != out_array.shape:
            raise RuntimeError("Array shapes do not match")
        wp.launch(kernel=_array_cast_kernel, dim=dim, inputs=[out_array, in_array], device=out_array.device)
401
408
 
402
409
 
403
410
  # code snippet for invoking cProfile
@@ -411,6 +418,25 @@ def array_scan(in_array, out_array, inclusive=True):
411
418
  # exit(0)
412
419
 
413
420
 
421
+ # helper kernels for initializing NVDB volumes from a dense array
422
# Store each vec3 value of a dense 3-D array into the volume identified by the
# ``volume`` handle, at index coordinates (i, j, k). Presumably launched with
# dim == values.shape — verify at call sites.
@wp.kernel
def copy_dense_volume_to_nano_vdb_v(volume: wp.uint64, values: wp.array(dtype=wp.vec3, ndim=3)):
    i, j, k = wp.tid()
    wp.volume_store_v(volume, i, j, k, values[i, j, k])
426
+
427
+
428
# Store each float32 value of a dense 3-D array into the volume identified by
# the ``volume`` handle, at index coordinates (i, j, k). Presumably launched with
# dim == values.shape — verify at call sites.
@wp.kernel
def copy_dense_volume_to_nano_vdb_f(volume: wp.uint64, values: wp.array(dtype=wp.float32, ndim=3)):
    i, j, k = wp.tid()
    wp.volume_store_f(volume, i, j, k, values[i, j, k])
432
+
433
+
434
# Store each int32 value of a dense 3-D array into the volume identified by the
# ``volume`` handle, at index coordinates (i, j, k). Presumably launched with
# dim == values.shape — verify at call sites.
@wp.kernel
def copy_dense_volume_to_nano_vdb_i(volume: wp.uint64, values: wp.array(dtype=wp.int32, ndim=3)):
    i, j, k = wp.tid()
    wp.volume_store_i(volume, i, j, k, values[i, j, k])
438
+
439
+
414
440
  # represent an edge between v0, v1 with connected faces f0, f1, and opposite vertex o0, and o1
415
441
  # winding is such that first tri can be reconstructed as {v0, v1, o0}, and second tri as { v1, v0, o1 }
416
442
  class MeshEdge:
@@ -454,11 +480,8 @@ class MeshAdjacency:
454
480
 
455
481
  self.edges[key] = edge
456
482
 
457
- def opposite_vertex(self, edge):
458
- pass
459
-
460
483
 
461
- def mem_report():
484
+ def mem_report(): #pragma: no cover
462
485
  def _mem_report(tensors, mem_type):
463
486
  """Print the selected tensors of type
464
487
  There are two major storage types in our major concern:
@@ -494,6 +517,7 @@ def mem_report():
494
517
  print("Type: %s Total Tensors: %d \tUsed Memory Space: %.2f MBytes" % (mem_type, total_numel, total_mem))
495
518
 
496
519
  import gc
520
+
497
521
  import torch
498
522
 
499
523
  gc.collect()
@@ -509,35 +533,6 @@ def mem_report():
509
533
  print("=" * LEN)
510
534
 
511
535
 
512
- def lame_parameters(E, nu):
513
- l = (E * nu) / ((1.0 + nu) * (1.0 - 2.0 * nu))
514
- mu = E / (2.0 * (1.0 + nu))
515
-
516
- return (l, mu)
517
-
518
-
519
- # **Deprecated: use ScopedDevice instead
520
- # ensures that correct CUDA is set for the guards lifetime
521
- # restores the previous CUDA context on exit
522
- class ScopedCudaGuard:
523
- def __init__(self):
524
- import warnings
525
-
526
- warnings.warn("ScopedCudaGuard is deprecated, use ScopedDevice instead")
527
-
528
- if wp.context.runtime.cuda_devices:
529
- self.device = wp.context.runtime.initial_cuda_device
530
- else:
531
- self.device = None
532
-
533
- def __enter__(self):
534
- if self.device is not None:
535
- self.device.context_guard.__enter__()
536
-
537
- def __exit__(self, exc_type, exc_value, traceback):
538
- if self.device is not None:
539
- self.device.context_guard.__exit__(exc_type, exc_value, traceback)
540
-
541
536
 
542
537
  class ScopedDevice:
543
538
  def __init__(self, device):
@@ -642,7 +637,8 @@ class ScopedTimer:
642
637
  return
643
638
 
644
639
  self.start = timeit.default_timer()
645
- ScopedTimer.indent += 1
640
+ if self.print:
641
+ ScopedTimer.indent += 1
646
642
 
647
643
  if self.detailed:
648
644
  self.cp = cProfile.Profile()
@@ -679,3 +675,17 @@ class ScopedTimer:
679
675
  print("{}{} took {:.2f} ms".format(indent, self.name, self.elapsed))
680
676
 
681
677
  ScopedTimer.indent -= 1
678
+
679
+
680
+ # helper kernels for adj_matmul
681
# In-place scaled accumulation over a 2-D launch: x[i, j] += beta * acc[i, j].
# Element and scalar types are left generic via ``Any``.
@wp.kernel
def add_kernel_2d(x: wp.array2d(dtype=Any), acc: wp.array2d(dtype=Any), beta: Any):
    i, j = wp.tid()

    # Written as an explicit read-modify-write of x; each thread touches only its own element.
    x[i,j] = x[i,j] + beta * acc[i,j]
686
+
687
# In-place scaled accumulation over a 3-D launch: x[i, j, k] += beta * acc[i, j, k].
# Element and scalar types are left generic via ``Any``.
@wp.kernel
def add_kernel_3d(x: wp.array3d(dtype=Any), acc: wp.array3d(dtype=Any), beta: Any):
    i, j, k = wp.tid()

    # Written as an explicit read-modify-write of x; each thread touches only its own element.
    x[i,j,k] = x[i,j,k] + beta * acc[i,j,k]