warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,953 @@
1
+ from typing import Optional
2
+
3
+ import warp as wp
4
+ from warp.fem.cache import (
5
+ TemporaryStore,
6
+ borrow_temporary,
7
+ borrow_temporary_like,
8
+ cached_arg_value,
9
+ )
10
+ from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
11
+
12
+ from .element import Cube, Square
13
+ from .geometry import Geometry
14
+
15
+
16
+ @wp.struct
17
+ class HexmeshCellArg:
18
+ hex_vertex_indices: wp.array2d(dtype=int)
19
+ positions: wp.array(dtype=wp.vec3)
20
+
21
+ # for neighbor cell lookup
22
+ vertex_hex_offsets: wp.array(dtype=int)
23
+ vertex_hex_indices: wp.array(dtype=int)
24
+
25
+
26
+ @wp.struct
27
+ class HexmeshSideArg:
28
+ cell_arg: HexmeshCellArg
29
+ face_vertex_indices: wp.array(dtype=wp.vec4i)
30
+ face_hex_indices: wp.array(dtype=wp.vec2i)
31
+ face_hex_face_orientation: wp.array(dtype=wp.vec4i)
32
+
33
+
34
+ _mat32 = wp.mat(shape=(3, 2), dtype=float)
35
+
36
+ FACE_VERTEX_INDICES = wp.constant(
37
+ wp.mat(shape=(6, 4), dtype=int)(
38
+ [
39
+ [0, 4, 7, 3], # x = 0
40
+ [1, 2, 6, 5], # x = 1
41
+ [0, 1, 5, 4], # y = 0
42
+ [3, 7, 6, 2], # y = 1
43
+ [0, 3, 2, 1], # z = 0
44
+ [4, 5, 6, 7], # z = 1
45
+ ]
46
+ )
47
+ )
48
+
49
+ EDGE_VERTEX_INDICES = wp.constant(
50
+ wp.mat(shape=(12, 2), dtype=int)(
51
+ [
52
+ [0, 1],
53
+ [1, 2],
54
+ [3, 2],
55
+ [0, 3],
56
+ [4, 5],
57
+ [5, 6],
58
+ [7, 6],
59
+ [4, 7],
60
+ [0, 4],
61
+ [1, 5],
62
+ [2, 6],
63
+ [3, 7],
64
+ ]
65
+ )
66
+ )
67
+
68
+ # orthogonal transform for face coordinates given first vertex + winding
69
+ # (two rows per entry)
70
+
71
+ FACE_ORIENTATION = [
72
+ [1, 0], # FV: 0, det: +
73
+ [0, 1],
74
+ [0, 1], # FV: 0, det: -
75
+ [1, 0],
76
+ [0, -1], # FV: 1, det: +
77
+ [1, 0],
78
+ [-1, 0], # FV: 1, det: -
79
+ [0, 1],
80
+ [-1, 0], # FV: 2, det: +
81
+ [0, -1],
82
+ [0, -1], # FV: 2, det: -
83
+ [-1, 0],
84
+ [0, 1], # FV: 3, det: +
85
+ [-1, 0],
86
+ [1, 0], # FV: 3, det: -
87
+ [0, -1],
88
+ ]
89
+
90
+ FACE_TRANSLATION = [
91
+ [0, 0],
92
+ [1, 0],
93
+ [1, 1],
94
+ [0, 1],
95
+ ]
96
+
97
+ # local face coordinate system
98
+ _FACE_COORD_INDICES = wp.constant(
99
+ wp.mat(shape=(6, 4), dtype=int)(
100
+ [
101
+ [2, 1, 0, 0], # 0: z y -x
102
+ [1, 2, 0, 1], # 1: y z x-1
103
+ [0, 2, 1, 0], # 2: x z -y
104
+ [2, 0, 1, 1], # 3: z x y-1
105
+ [1, 0, 2, 0], # 4: y x -z
106
+ [0, 1, 2, 1], # 5: x y z-1
107
+ ]
108
+ )
109
+ )
110
+
111
+ _FACE_ORIENTATION_F = wp.constant(wp.mat(shape=(16, 2), dtype=float)(FACE_ORIENTATION))
112
+ _FACE_TRANSLATION_F = wp.constant(wp.mat(shape=(4, 2), dtype=float)(FACE_TRANSLATION))
113
+
114
+
115
+ class Hexmesh(Geometry):
116
+ """Hexahedral mesh geometry"""
117
+
118
+ dimension = 3
119
+
120
+ def __init__(
121
+ self, hex_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None
122
+ ):
123
+ """
124
+ Constructs a hexahedral mesh.
125
+
126
+ Args:
127
+ hex_vertex_indices: warp array of shape (num_hexes, 8) containing vertex indices for each hex
128
+ following standard ordering (bottom face vertices in counter-clockwise order, then similarly for upper face)
129
+ positions: warp array of shape (num_vertices, 3) containing 3d position for each vertex
130
+ temporary_store: shared pool from which to allocate temporary arrays
131
+ """
132
+
133
+ self.hex_vertex_indices = hex_vertex_indices
134
+ self.positions = positions
135
+
136
+ self._face_vertex_indices: wp.array = None
137
+ self._face_hex_indices: wp.array = None
138
+ self._face_hex_face_orientation: wp.array = None
139
+ self._vertex_hex_offsets: wp.array = None
140
+ self._vertex_hex_indices: wp.array = None
141
+ self._hex_edge_indices: wp.array = None
142
+ self._edge_count = 0
143
+ self._build_topology(temporary_store)
144
+
145
+ def cell_count(self):
146
+ return self.hex_vertex_indices.shape[0]
147
+
148
+ def vertex_count(self):
149
+ return self.positions.shape[0]
150
+
151
+ def side_count(self):
152
+ return self._face_vertex_indices.shape[0]
153
+
154
+ def edge_count(self):
155
+ if self._hex_edge_indices is None:
156
+ self._compute_hex_edges()
157
+ return self._edge_count
158
+
159
+ def boundary_side_count(self):
160
+ return self._boundary_face_indices.shape[0]
161
+
162
+ def reference_cell(self) -> Cube:
163
+ return Cube()
164
+
165
+ def reference_side(self) -> Square:
166
+ return Square()
167
+
168
+ @property
169
+ def hex_edge_indices(self) -> wp.array:
170
+ if self._hex_edge_indices is None:
171
+ self._compute_hex_edges()
172
+ return self._hex_edge_indices
173
+
174
+ @property
175
+ def face_hex_indices(self) -> wp.array:
176
+ return self._face_hex_indices
177
+
178
+ @property
179
+ def face_vertex_indices(self) -> wp.array:
180
+ return self._face_vertex_indices
181
+
182
+ CellArg = HexmeshCellArg
183
+ SideArg = HexmeshSideArg
184
+
185
+ @wp.struct
186
+ class SideIndexArg:
187
+ boundary_face_indices: wp.array(dtype=int)
188
+
189
+ # Geometry device interface
190
+
191
+ @cached_arg_value
192
+ def cell_arg_value(self, device) -> CellArg:
193
+ args = self.CellArg()
194
+
195
+ args.hex_vertex_indices = self.hex_vertex_indices.to(device)
196
+ args.positions = self.positions.to(device)
197
+ args.vertex_hex_offsets = self._vertex_hex_offsets.to(device)
198
+ args.vertex_hex_indices = self._vertex_hex_indices.to(device)
199
+
200
+ return args
201
+
202
+ @wp.func
203
+ def cell_position(args: CellArg, s: Sample):
204
+ hex_idx = args.hex_vertex_indices[s.element_index]
205
+
206
+ w_p = s.element_coords
207
+ w_m = Coords(1.0) - s.element_coords
208
+
209
+ # 0 : m m m
210
+ # 1 : p m m
211
+ # 2 : p p m
212
+ # 3 : m p m
213
+ # 4 : m m p
214
+ # 5 : p m p
215
+ # 6 : p p p
216
+ # 7 : m p p
217
+
218
+ return (
219
+ w_m[0] * w_m[1] * w_m[2] * args.positions[hex_idx[0]]
220
+ + w_p[0] * w_m[1] * w_m[2] * args.positions[hex_idx[1]]
221
+ + w_p[0] * w_p[1] * w_m[2] * args.positions[hex_idx[2]]
222
+ + w_m[0] * w_p[1] * w_m[2] * args.positions[hex_idx[3]]
223
+ + w_m[0] * w_m[1] * w_p[2] * args.positions[hex_idx[4]]
224
+ + w_p[0] * w_m[1] * w_p[2] * args.positions[hex_idx[5]]
225
+ + w_p[0] * w_p[1] * w_p[2] * args.positions[hex_idx[6]]
226
+ + w_m[0] * w_p[1] * w_p[2] * args.positions[hex_idx[7]]
227
+ )
228
+
229
+ @wp.func
230
+ def cell_deformation_gradient(cell_arg: CellArg, s: Sample):
231
+ """Deformation gradient at `coords`"""
232
+ """Transposed deformation gradient at `coords`"""
233
+ hex_idx = cell_arg.hex_vertex_indices[s.element_index]
234
+
235
+ w_p = s.element_coords
236
+ w_m = Coords(1.0) - s.element_coords
237
+
238
+ return (
239
+ wp.outer(cell_arg.positions[hex_idx[0]], wp.vec3(-w_m[1] * w_m[2], -w_m[0] * w_m[2], -w_m[0] * w_m[1]))
240
+ + wp.outer(cell_arg.positions[hex_idx[1]], wp.vec3(w_m[1] * w_m[2], -w_p[0] * w_m[2], -w_p[0] * w_m[1]))
241
+ + wp.outer(cell_arg.positions[hex_idx[2]], wp.vec3(w_p[1] * w_m[2], w_p[0] * w_m[2], -w_p[0] * w_p[1]))
242
+ + wp.outer(cell_arg.positions[hex_idx[3]], wp.vec3(-w_p[1] * w_m[2], w_m[0] * w_m[2], -w_m[0] * w_p[1]))
243
+ + wp.outer(cell_arg.positions[hex_idx[4]], wp.vec3(-w_m[1] * w_p[2], -w_m[0] * w_p[2], w_m[0] * w_m[1]))
244
+ + wp.outer(cell_arg.positions[hex_idx[5]], wp.vec3(w_m[1] * w_p[2], -w_p[0] * w_p[2], w_p[0] * w_m[1]))
245
+ + wp.outer(cell_arg.positions[hex_idx[6]], wp.vec3(w_p[1] * w_p[2], w_p[0] * w_p[2], w_p[0] * w_p[1]))
246
+ + wp.outer(cell_arg.positions[hex_idx[7]], wp.vec3(-w_p[1] * w_p[2], w_m[0] * w_p[2], w_m[0] * w_p[1]))
247
+ )
248
+
249
+ @wp.func
250
+ def cell_inverse_deformation_gradient(cell_arg: CellArg, s: Sample):
251
+ return wp.inverse(Hexmesh.cell_deformation_gradient(cell_arg, s))
252
+
253
+ @wp.func
254
+ def cell_measure(args: CellArg, s: Sample):
255
+ return wp.abs(wp.determinant(Hexmesh.cell_deformation_gradient(args, s)))
256
+
257
+ @wp.func
258
+ def cell_normal(args: CellArg, s: Sample):
259
+ return wp.vec3(0.0)
260
+
261
+ @cached_arg_value
262
+ def side_index_arg_value(self, device) -> SideIndexArg:
263
+ args = self.SideIndexArg()
264
+
265
+ args.boundary_face_indices = self._boundary_face_indices.to(device)
266
+
267
+ return args
268
+
269
+ @wp.func
270
+ def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
271
+ """Boundary side to side index"""
272
+
273
+ return args.boundary_face_indices[boundary_side_index]
274
+
275
+ @cached_arg_value
276
+ def side_arg_value(self, device) -> CellArg:
277
+ args = self.SideArg()
278
+
279
+ args.cell_arg = self.cell_arg_value(device)
280
+ args.face_vertex_indices = self._face_vertex_indices.to(device)
281
+ args.face_hex_indices = self._face_hex_indices.to(device)
282
+ args.face_hex_face_orientation = self._face_hex_face_orientation.to(device)
283
+
284
+ return args
285
+
286
+ @wp.func
287
+ def side_position(args: SideArg, s: Sample):
288
+ face_idx = args.face_vertex_indices[s.element_index]
289
+
290
+ w_p = s.element_coords
291
+ w_m = Coords(1.0) - s.element_coords
292
+
293
+ return (
294
+ w_m[0] * w_m[1] * args.cell_arg.positions[face_idx[0]]
295
+ + w_p[0] * w_m[1] * args.cell_arg.positions[face_idx[1]]
296
+ + w_p[0] * w_p[1] * args.cell_arg.positions[face_idx[2]]
297
+ + w_m[0] * w_p[1] * args.cell_arg.positions[face_idx[3]]
298
+ )
299
+
300
+ @wp.func
301
+ def _side_deformation_vecs(args: SideArg, side_index: ElementIndex, coords: Coords):
302
+ face_idx = args.face_vertex_indices[side_index]
303
+
304
+ p0 = args.cell_arg.positions[face_idx[0]]
305
+ p1 = args.cell_arg.positions[face_idx[1]]
306
+ p2 = args.cell_arg.positions[face_idx[2]]
307
+ p3 = args.cell_arg.positions[face_idx[3]]
308
+
309
+ w_p = coords
310
+ w_m = Coords(1.0) - coords
311
+
312
+ v1 = w_m[1] * (p1 - p0) + w_p[1] * (p2 - p3)
313
+ v2 = w_p[0] * (p2 - p1) + w_m[0] * (p3 - p0)
314
+ return v1, v2
315
+
316
+ @wp.func
317
+ def side_deformation_gradient(args: SideArg, s:Sample):
318
+ """Transposed side deformation gradient at `coords`"""
319
+ v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
320
+ return _mat32(v1, v2)
321
+
322
+ @wp.func
323
+ def side_inner_inverse_deformation_gradient(args: SideArg, s:Sample):
324
+ cell_index = Hexmesh.side_inner_cell_index(args, s.element_index)
325
+ cell_coords = Hexmesh.side_inner_cell_coords(args, s.element_index, s.element_coords)
326
+ return Hexmesh.cell_inverse_deformation_gradient(args.cell_arg, make_free_sample(cell_index, cell_coords))
327
+
328
+ @wp.func
329
+ def side_outer_inverse_deformation_gradient(args: SideArg, s:Sample):
330
+ cell_index = Hexmesh.side_outer_cell_index(args, s.element_index)
331
+ cell_coords = Hexmesh.side_outer_cell_coords(args, s.element_index, s.element_coords)
332
+ return Hexmesh.cell_inverse_deformation_gradient(args.cell_arg, make_free_sample(cell_index, cell_coords))
333
+
334
+ @wp.func
335
+ def side_measure(args: SideArg, s: Sample):
336
+ v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
337
+ return wp.length(wp.cross(v1, v2))
338
+
339
+ @wp.func
340
+ def side_measure_ratio(args: SideArg, s: Sample):
341
+ inner = Hexmesh.side_inner_cell_index(args, s.element_index)
342
+ outer = Hexmesh.side_outer_cell_index(args, s.element_index)
343
+ inner_coords = Hexmesh.side_inner_cell_coords(args, s.element_index, s.element_coords)
344
+ outer_coords = Hexmesh.side_outer_cell_coords(args, s.element_index, s.element_coords)
345
+ return Hexmesh.side_measure(args, s) / wp.min(
346
+ Hexmesh.cell_measure(args.cell_arg, make_free_sample(inner, inner_coords)),
347
+ Hexmesh.cell_measure(args.cell_arg, make_free_sample(outer, outer_coords)),
348
+ )
349
+
350
+ @wp.func
351
+ def side_normal(args: SideArg, s: Sample):
352
+ v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
353
+ return wp.normalize(wp.cross(v1, v2))
354
+
355
+ @wp.func
356
+ def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
357
+ return arg.face_hex_indices[side_index][0]
358
+
359
+ @wp.func
360
+ def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
361
+ return arg.face_hex_indices[side_index][1]
362
+
363
+ @wp.func
364
+ def _hex_local_face_coords(hex_coords: Coords, face_index: int):
365
+ # Coordinates in local face coordinate system
366
+ # Sign of last coordinate (out of face)
367
+
368
+ face_coords = wp.vec2(
369
+ hex_coords[_FACE_COORD_INDICES[face_index, 0]], hex_coords[_FACE_COORD_INDICES[face_index, 1]]
370
+ )
371
+
372
+ normal_coord = hex_coords[_FACE_COORD_INDICES[face_index, 2]]
373
+ normal_coord = wp.select(_FACE_COORD_INDICES[face_index, 3] == 0, normal_coord - 1.0, -normal_coord)
374
+
375
+ return face_coords, normal_coord
376
+
377
+ @wp.func
378
+ def _local_face_hex_coords(face_coords: wp.vec2, face_index: int):
379
+ # Coordinates in hex from local face coordinate system
380
+
381
+ hex_coords = Coords()
382
+ hex_coords[_FACE_COORD_INDICES[face_index, 0]] = face_coords[0]
383
+ hex_coords[_FACE_COORD_INDICES[face_index, 1]] = face_coords[1]
384
+ hex_coords[_FACE_COORD_INDICES[face_index, 2]] = wp.select(_FACE_COORD_INDICES[face_index, 3] == 0, 1.0, 0.0)
385
+
386
+ return hex_coords
387
+
388
+ @wp.func
389
+ def _local_from_oriented_face_coords(ori: int, oriented_coords: Coords):
390
+ fv = ori // 2
391
+ return (oriented_coords[0] - _FACE_TRANSLATION_F[fv, 0]) * _FACE_ORIENTATION_F[2 * ori] + (
392
+ oriented_coords[1] - _FACE_TRANSLATION_F[fv, 1]
393
+ ) * _FACE_ORIENTATION_F[2 * ori + 1]
394
+
395
+ @wp.func
396
+ def _local_to_oriented_face_coords(ori: int, coords: wp.vec2):
397
+ fv = ori // 2
398
+ return Coords(
399
+ wp.dot(_FACE_ORIENTATION_F[2 * ori], coords) + _FACE_TRANSLATION_F[fv, 0],
400
+ wp.dot(_FACE_ORIENTATION_F[2 * ori + 1], coords) + _FACE_TRANSLATION_F[fv, 1],
401
+ 0.0,
402
+ )
403
+
404
+ @wp.func
405
+ def face_to_hex_coords(local_face_index: int, face_orientation: int, side_coords: Coords):
406
+ local_coords = Hexmesh._local_from_oriented_face_coords(face_orientation, side_coords)
407
+ return Hexmesh._local_face_hex_coords(local_coords, local_face_index)
408
+
409
@wp.func
def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    # Components 0 and 1 of the per-side record hold the local face index and
    # the face orientation relative to the inner cell
    face_and_ori = args.face_hex_face_orientation[side_index]

    return Hexmesh.face_to_hex_coords(face_and_ori[0], face_and_ori[1], side_coords)
415
+
416
@wp.func
def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    # Components 2 and 3 of the per-side record hold the local face index and
    # the face orientation relative to the outer cell
    face_and_ori = args.face_hex_face_orientation[side_index]

    return Hexmesh.face_to_hex_coords(face_and_ori[2], face_and_ori[3], side_coords)
422
+
423
@wp.func
def side_from_cell_coords(args: SideArg, side_index: ElementIndex, hex_index: ElementIndex, hex_coords: Coords):
    # Convert hex cell coordinates to coordinates on the given side.
    # Yields Coords(OUTSIDE) unless the out-of-face coordinate is exactly zero.
    face_and_ori = args.face_hex_face_orientation[side_index]

    # Use the (face, orientation) pair matching the cell the coordinates belong to
    if Hexmesh.side_inner_cell_index(args, side_index) == hex_index:
        local_face_index = face_and_ori[0]
        face_orientation = face_and_ori[1]
    else:
        local_face_index = face_and_ori[2]
        face_orientation = face_and_ori[3]

    face_coords, normal_coord = Hexmesh._hex_local_face_coords(hex_coords, local_face_index)
    oriented_coords = Hexmesh._local_to_oriented_face_coords(face_orientation, face_coords)

    return wp.select(normal_coord == 0.0, Coords(OUTSIDE), oriented_coords)
436
+
437
@wp.func
def side_to_cell_arg(side_arg: SideArg):
    # The cell-level argument struct is embedded in the side-level one
    cell_arg = side_arg.cell_arg
    return cell_arg
440
+
441
def _build_topology(self, temporary_store: TemporaryStore):
    """Build the vertex-to-hex adjacency and the face connectivity of the mesh.

    Sets ``_vertex_hex_offsets`` and ``_vertex_hex_indices``, then the per-face
    arrays ``_face_vertex_indices``, ``_face_hex_indices``,
    ``_face_hex_face_orientation`` and ``_boundary_face_indices``.

    Each face is attributed to its lowest-index vertex, so that a face shared
    by two hexes is stored only once.

    Args:
        temporary_store: Store from which scratch arrays are borrowed.
    """
    from warp.fem.utils import compress_node_indices, masked_indices
    from warp.utils import array_scan

    device = self.hex_vertex_indices.device
    vertex_count = self.vertex_count()
    cell_count = self.cell_count()

    # Hexes adjacent to each vertex, stored as per-vertex offsets into a flat index array
    vertex_hex_offsets, vertex_hex_indices, _, __ = compress_node_indices(
        vertex_count, self.hex_vertex_indices, temporary_store=temporary_store
    )
    self._vertex_hex_offsets = vertex_hex_offsets.detach()
    self._vertex_hex_indices = vertex_hex_indices.detach()

    # Scratch storage for per-vertex counts/offsets and for the not yet deduplicated faces
    vertex_start_face_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=vertex_count)
    vertex_start_face_count.array.zero_()
    vertex_start_face_offsets = borrow_temporary_like(vertex_start_face_count, temporary_store=temporary_store)

    vertex_face_other_vs = borrow_temporary(temporary_store, dtype=wp.vec3i, device=device, shape=(8 * cell_count))
    vertex_face_hexes = borrow_temporary(temporary_store, dtype=int, device=device, shape=(8 * cell_count, 2))

    # Count faces starting at each vertex (with duplicates across neighbouring hexes)
    wp.launch(
        kernel=Hexmesh._count_starting_faces_kernel,
        device=device,
        dim=cell_count,
        inputs=[self.hex_vertex_indices, vertex_start_face_count.array],
    )

    array_scan(in_array=vertex_start_face_count.array, out_array=vertex_start_face_offsets.array, inclusive=False)

    # Count number of unique faces (deduplicate across hexes).
    # The count buffer is reused in place: the kernel overwrites each vertex's entry.
    vertex_unique_face_count = vertex_start_face_count
    unique_count_inputs = [
        self._vertex_hex_offsets,
        self._vertex_hex_indices,
        self.hex_vertex_indices,
        vertex_start_face_offsets.array,
        vertex_unique_face_count.array,
        vertex_face_other_vs.array,
        vertex_face_hexes.array,
    ]
    wp.launch(
        kernel=Hexmesh._count_unique_starting_faces_kernel,
        device=device,
        dim=vertex_count,
        inputs=unique_count_inputs,
    )

    vertex_unique_face_offsets = borrow_temporary_like(vertex_start_face_offsets, temporary_store=temporary_store)
    array_scan(in_array=vertex_start_face_count.array, out_array=vertex_unique_face_offsets.array, inclusive=False)

    # Get back face count to host.
    # The last vertex cannot be the lowest-index vertex of any face, so its count is
    # zero and the last entry of the exclusive prefix sum is the total face count.
    last_vertex = vertex_count - 1
    if device.is_cuda:
        face_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
        wp.copy(dest=face_count.array, src=vertex_unique_face_offsets.array, src_offset=last_vertex, count=1)
        wp.synchronize_stream(wp.get_stream(device))
        face_count = int(face_count.array.numpy()[0])
    else:
        face_count = int(vertex_unique_face_offsets.array.numpy()[last_vertex])

    self._face_vertex_indices = wp.empty(shape=(face_count,), dtype=wp.vec4i, device=device)
    self._face_hex_indices = wp.empty(shape=(face_count,), dtype=wp.vec2i, device=device)
    self._face_hex_face_orientation = wp.empty(shape=(face_count,), dtype=wp.vec4i, device=device)

    boundary_mask = borrow_temporary(temporary_store, shape=(face_count,), dtype=int, device=device)

    # Compress face data into the final, duplicate-free arrays
    compress_inputs = [
        vertex_start_face_offsets.array,
        vertex_unique_face_offsets.array,
        vertex_unique_face_count.array,
        vertex_face_other_vs.array,
        vertex_face_hexes.array,
        self._face_vertex_indices,
        self._face_hex_indices,
        boundary_mask.array,
    ]
    wp.launch(
        kernel=Hexmesh._compress_faces_kernel,
        device=device,
        dim=vertex_count,
        inputs=compress_inputs,
    )

    # Scratch buffers are no longer needed past this point
    vertex_start_face_offsets.release()
    vertex_unique_face_offsets.release()
    vertex_unique_face_count.release()
    vertex_face_other_vs.release()
    vertex_face_hexes.release()

    # Flip normals if necessary
    wp.launch(
        kernel=Hexmesh._flip_face_normals,
        device=device,
        dim=self.side_count(),
        inputs=[self._face_vertex_indices, self._face_hex_indices, self.hex_vertex_indices, self.positions],
    )

    # Compute and store face orientation
    orientation_inputs = [
        self._face_vertex_indices,
        self._face_hex_indices,
        self.hex_vertex_indices,
        self._face_hex_face_orientation,
    ]
    wp.launch(
        kernel=Hexmesh._compute_face_orientation,
        device=device,
        dim=self.side_count(),
        inputs=orientation_inputs,
    )

    # Gather the indices of the faces flagged as boundary
    boundary_face_indices, _ = masked_indices(boundary_mask.array)
    self._boundary_face_indices = boundary_face_indices.detach()
558
+
559
def _compute_hex_edges(self, temporary_store: Optional[TemporaryStore] = None):
    """Enumerate the unique edges of the mesh and index the 12 edges of each hex.

    Sets ``_edge_count`` (number of unique edges) and ``_hex_edge_indices``
    (array of shape ``(cell_count, 12)``). Reads ``_vertex_hex_offsets`` and
    ``_vertex_hex_indices``, which ``_build_topology`` assigns.

    Each edge is attributed to its lowest-index end vertex, so that an edge
    shared by several hexes is counted only once.

    Args:
        temporary_store: Optional store from which scratch arrays are borrowed.
    """
    from warp.utils import array_scan

    device = self.hex_vertex_indices.device
    vertex_count = self.vertex_count()
    cell_count = self.cell_count()

    vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=vertex_count)
    vertex_start_edge_count.array.zero_()
    vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)

    vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(12 * cell_count))

    # Count edges starting at each vertex (with duplicates across neighbouring hexes)
    wp.launch(
        kernel=Hexmesh._count_starting_edges_kernel,
        device=device,
        dim=cell_count,
        inputs=[self.hex_vertex_indices, vertex_start_edge_count.array],
    )

    array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)

    # Count number of unique edges (deduplicate across hexes).
    # The count buffer is reused in place: the kernel overwrites each vertex's entry.
    vertex_unique_edge_count = vertex_start_edge_count
    wp.launch(
        kernel=Hexmesh._count_unique_starting_edges_kernel,
        device=device,
        dim=vertex_count,
        inputs=[
            self._vertex_hex_offsets,
            self._vertex_hex_indices,
            self.hex_vertex_indices,
            vertex_start_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
        ],
    )

    # Pass the temporary handle itself, as _build_topology does for its face offsets
    vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
    array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)

    # Get back edge count to host.
    # Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
    if device.is_cuda:
        edge_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
        wp.copy(
            dest=edge_count.array,
            src=vertex_unique_edge_offsets.array,
            src_offset=vertex_count - 1,
            count=1,
        )
        wp.synchronize_stream(wp.get_stream(device))
        self._edge_count = int(edge_count.array.numpy()[0])
    else:
        self._edge_count = int(vertex_unique_edge_offsets.array.numpy()[vertex_count - 1])

    self._hex_edge_indices = wp.empty(dtype=int, device=device, shape=(cell_count, 12))

    # Compress edge data: assign its final index to each edge of each hex
    wp.launch(
        kernel=Hexmesh._compress_edges_kernel,
        device=device,
        dim=vertex_count,
        inputs=[
            self._vertex_hex_offsets,
            self._vertex_hex_indices,
            self.hex_vertex_indices,
            vertex_start_edge_offsets.array,
            vertex_unique_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
            self._hex_edge_indices,
        ],
    )

    # Scratch buffers are no longer needed past this point
    vertex_start_edge_offsets.release()
    vertex_unique_edge_offsets.release()
    vertex_unique_edge_count.release()
    vertex_edge_ends.release()
641
+
642
@wp.kernel
def _count_starting_faces_kernel(
    hex_vertex_indices: wp.array2d(dtype=int), vertex_start_face_count: wp.array(dtype=int)
):
    # For each of the 6 faces of the hex handled by this thread, increment the
    # counter of the face's lowest-index vertex.
    hx = wp.tid()
    for face in range(6):
        face_vidx = wp.vec4i(
            hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 0]],
            hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 1]],
            hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 2]],
            hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 3]],
        )
        lowest = wp.min(face_vidx)

        # One increment per corner holding the lowest index
        for corner in range(4):
            if face_vidx[corner] == lowest:
                wp.atomic_add(vertex_start_face_count, lowest, 1)
659
+
660
@wp.func
def _face_sort(vidx: wp.vec4i, min_k: int):
    # Return the three vertices following corner `min_k` around the face, in the
    # traversal direction that puts the smaller of the two neighbours of
    # `min_k` first. This gives a key independent of the face winding.
    after = vidx[(min_k + 1) % 4]
    across = vidx[(min_k + 2) % 4]
    before = vidx[(min_k + 3) % 4]

    if before <= after:
        return wp.vec3i(before, across, after)
    return wp.vec3i(after, across, before)
669
+
670
@wp.func
def _find_face(
    needle: wp.vec3i,
    values: wp.array(dtype=wp.vec3i),
    beg: int,
    end: int,
):
    # Linear search of `needle` within values[beg:end];
    # returns the index of the first match, or -1 if there is none
    for candidate in range(beg, end):
        if values[candidate] == needle:
            return candidate

    return -1
682
+
683
@wp.kernel
def _count_unique_starting_faces_kernel(
    vertex_hex_offsets: wp.array(dtype=int),
    vertex_hex_indices: wp.array(dtype=int),
    hex_vertex_indices: wp.array2d(dtype=int),
    vertex_start_face_offsets: wp.array(dtype=int),
    vertex_start_face_count: wp.array(dtype=int),
    face_other_vs: wp.array(dtype=wp.vec3i),
    face_hexes: wp.array2d(dtype=int),
):
    # For the vertex handled by this thread, gather the distinct faces whose
    # lowest-index vertex is this vertex, record the hexes they belong to,
    # and overwrite the vertex's count with the number of distinct faces.
    v = wp.tid()

    hex_beg = vertex_hex_offsets[v]
    hex_end = vertex_hex_offsets[v + 1]

    # Faces of this vertex are written contiguously starting at face_beg
    face_beg = vertex_start_face_offsets[v]
    face_cur = face_beg

    for slot in range(hex_beg, hex_end):
        hx = vertex_hex_indices[slot]

        for face in range(6):
            face_vidx = wp.vec4i(
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 0]],
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 1]],
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 2]],
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[face, 3]],
            )
            lowest_corner = int(wp.argmin(face_vidx))

            if face_vidx[lowest_corner] == v:
                other_v = Hexmesh._face_sort(face_vidx, lowest_corner)

                # Check whether this face has already been recorded for this vertex
                seen_idx = Hexmesh._find_face(other_v, face_other_vs, face_beg, face_cur)

                if seen_idx != -1:
                    # Already recorded from another hex: store this hex as the second one
                    face_hexes[seen_idx, 1] = hx
                else:
                    # New face: both hex slots start out as the current hex
                    face_other_vs[face_cur] = other_v
                    face_hexes[face_cur, 0] = hx
                    face_hexes[face_cur, 1] = hx
                    face_cur += 1

    vertex_start_face_count[v] = face_cur - face_beg
729
+
730
@wp.kernel
def _compress_faces_kernel(
    vertex_start_face_offsets: wp.array(dtype=int),
    vertex_unique_face_offsets: wp.array(dtype=int),
    vertex_unique_face_count: wp.array(dtype=int),
    uncompressed_face_other_vs: wp.array(dtype=wp.vec3i),
    uncompressed_face_hexes: wp.array2d(dtype=int),
    face_vertex_indices: wp.array(dtype=wp.vec4i),
    face_hex_indices: wp.array(dtype=wp.vec2i),
    boundary_mask: wp.array(dtype=int),
):
    # Copy the unique faces of the vertex handled by this thread from the
    # uncompressed storage to their final location, and flag boundary faces.
    v = wp.tid()

    unique_count = vertex_unique_face_count[v]
    src_beg = vertex_start_face_offsets[v]
    dst_beg = vertex_unique_face_offsets[v]

    for f in range(unique_count):
        src_index = src_beg + f
        face_index = dst_beg + f

        # The face vertices are the owning vertex followed by the three stored ones
        other_vs = uncompressed_face_other_vs[src_index]
        face_vertex_indices[face_index] = wp.vec4i(v, other_vs[0], other_vs[1], other_vs[2])

        hx0 = uncompressed_face_hexes[src_index, 0]
        hx1 = uncompressed_face_hexes[src_index, 1]
        face_hex_indices[face_index] = wp.vec2i(hx0, hx1)

        # A face recorded from a single hex has two identical hex entries
        if hx0 != hx1:
            boundary_mask[face_index] = 0
        else:
            boundary_mask[face_index] = 1
765
+
766
@wp.kernel
def _flip_face_normals(
    face_vertex_indices: wp.array(dtype=wp.vec4i),
    face_hex_indices: wp.array(dtype=wp.vec2i),
    hex_vertex_indices: wp.array2d(dtype=int),
    positions: wp.array(dtype=wp.vec3),
):
    # Reverse the winding of the face handled by this thread whenever its
    # normal points toward the centroid of the first hex attached to it.
    f = wp.tid()

    first_hex = face_hex_indices[f][0]

    hex_vidx = hex_vertex_indices[first_hex]
    face_vidx = face_vertex_indices[f]

    hex_centroid = (
        positions[hex_vidx[0]]
        + positions[hex_vidx[1]]
        + positions[hex_vidx[2]]
        + positions[hex_vidx[3]]
        + positions[hex_vidx[4]]
        + positions[hex_vidx[5]]
        + positions[hex_vidx[6]]
        + positions[hex_vidx[7]]
    ) / 8.0

    p0 = positions[face_vidx[0]]
    p1 = positions[face_vidx[1]]
    p2 = positions[face_vidx[2]]
    p3 = positions[face_vidx[3]]

    face_center = (p1 + p0 + p2 + p3) / 4.0

    # Normal direction from the cross product of the two face diagonals
    diagonal_02 = p2 - p0
    diagonal_13 = p3 - p1
    face_normal = wp.cross(diagonal_02, diagonal_13)

    # if face normal points toward first hex centroid, flip indices
    to_centroid = hex_centroid - face_center
    if wp.dot(to_centroid, face_normal) > 0.0:
        face_vertex_indices[f] = wp.vec4i(face_vidx[0], face_vidx[3], face_vidx[2], face_vidx[1])
802
+
803
@wp.func
def _find_face_orientation(face_vidx: wp.vec4i, hex_index: int, hex_vertex_indices: wp.array2d(dtype=int)):
    # Locate the face with vertices `face_vidx` among the 6 faces of hex
    # `hex_index`, and encode how its vertex order relates to the hex's own
    # ordering of that face. Returns (local face index, orientation code).
    hex_vidx = hex_vertex_indices[hex_index]

    # Find local index in hex corresponding to face, by comparing winding-independent keys

    face_min_i = int(wp.argmin(face_vidx))
    face_other_v = Hexmesh._face_sort(face_vidx, face_min_i)

    for face in range(6):
        hex_face_vi = wp.vec4i(
            hex_vidx[FACE_VERTEX_INDICES[face, 0]],
            hex_vidx[FACE_VERTEX_INDICES[face, 1]],
            hex_vidx[FACE_VERTEX_INDICES[face, 2]],
            hex_vidx[FACE_VERTEX_INDICES[face, 3]],
        )
        hex_min_i = int(wp.argmin(hex_face_vi))
        hex_other_v = Hexmesh._face_sort(hex_face_vi, hex_min_i)

        if hex_other_v == face_other_v:
            local_face_index = face
            break

    # Find starting vertex index: orientation is twice the position in `face_vidx`
    # of the hex face's first vertex, plus one if the next vertex does not match
    for corner in range(4):
        if face_vidx[corner] == hex_face_vi[0]:
            face_orientation = 2 * corner
            if face_vidx[(corner + 1) % 4] != hex_face_vi[1]:
                face_orientation += 1

    return local_face_index, face_orientation
834
+
835
@wp.kernel
def _compute_face_orientation(
    face_vertex_indices: wp.array(dtype=wp.vec4i),
    face_hex_indices: wp.array(dtype=wp.vec2i),
    hex_vertex_indices: wp.array2d(dtype=int),
    face_hex_face_ori: wp.array(dtype=wp.vec4i),
):
    # For the face handled by this thread, store (local face index, orientation)
    # relative to its first hex, followed by the same pair relative to its second hex.
    f = wp.tid()

    face_vidx = face_vertex_indices[f]

    hx0 = face_hex_indices[f][0]
    hx1 = face_hex_indices[f][1]

    local_face_0, ori_0 = Hexmesh._find_face_orientation(face_vidx, hx0, hex_vertex_indices)

    if hx0 != hx1:
        local_face_1, ori_1 = Hexmesh._find_face_orientation(face_vidx, hx1, hex_vertex_indices)
        face_hex_face_ori[f] = wp.vec4i(local_face_0, ori_0, local_face_1, ori_1)
    else:
        # Both hex entries are identical: duplicate the first pair
        face_hex_face_ori[f] = wp.vec4i(local_face_0, ori_0, local_face_0, ori_0)
855
+
856
@wp.kernel
def _count_starting_edges_kernel(
    hex_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
):
    # For each of the 12 edges of the hex handled by this thread, increment the
    # counter of the edge's lowest-index end vertex.
    hx = wp.tid()
    for edge in range(12):
        end_0 = hex_vertex_indices[hx, EDGE_VERTEX_INDICES[edge, 0]]
        end_1 = hex_vertex_indices[hx, EDGE_VERTEX_INDICES[edge, 1]]

        wp.atomic_add(vertex_start_edge_count, wp.min(end_0, end_1), 1)
869
+
870
@wp.func
def _find_edge(
    needle: int,
    values: wp.array(dtype=int),
    beg: int,
    end: int,
):
    # Linear search of `needle` within values[beg:end];
    # returns the index of the first match, or -1 if there is none
    for candidate in range(beg, end):
        if values[candidate] == needle:
            return candidate

    return -1
882
+
883
@wp.kernel
def _count_unique_starting_edges_kernel(
    vertex_hex_offsets: wp.array(dtype=int),
    vertex_hex_indices: wp.array(dtype=int),
    hex_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_start_edge_count: wp.array(dtype=int),
    edge_ends: wp.array(dtype=int),
):
    # For the vertex handled by this thread, gather the distinct edges whose
    # lowest-index end is this vertex (storing their other end), and overwrite
    # the vertex's count with the number of distinct edges.
    v = wp.tid()

    hex_beg = vertex_hex_offsets[v]
    hex_end = vertex_hex_offsets[v + 1]

    # Edge ends of this vertex are written contiguously starting at edge_beg
    edge_beg = vertex_start_edge_offsets[v]
    edge_cur = edge_beg

    for slot in range(hex_beg, hex_end):
        hx = vertex_hex_indices[slot]

        for edge in range(12):
            end_0 = hex_vertex_indices[hx, EDGE_VERTEX_INDICES[edge, 0]]
            end_1 = hex_vertex_indices[hx, EDGE_VERTEX_INDICES[edge, 1]]

            if v == wp.min(end_0, end_1):
                other_v = wp.max(end_0, end_1)
                # Only record ends that have not been met through another hex
                if Hexmesh._find_edge(other_v, edge_ends, edge_beg, edge_cur) == -1:
                    edge_ends[edge_cur] = other_v
                    edge_cur += 1

    vertex_start_edge_count[v] = edge_cur - edge_beg
915
+
916
@wp.kernel
def _compress_edges_kernel(
    vertex_hex_offsets: wp.array(dtype=int),
    vertex_hex_indices: wp.array(dtype=int),
    hex_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_count: wp.array(dtype=int),
    uncompressed_edge_ends: wp.array(dtype=int),
    hex_edge_indices: wp.array2d(dtype=int),
):
    # For the vertex handled by this thread, write the final index of every
    # edge it owns into the per-hex edge table of each adjacent hex.
    v = wp.tid()

    hex_beg = vertex_hex_offsets[v]
    hex_end = vertex_hex_offsets[v + 1]

    uncompressed_beg = vertex_start_edge_offsets[v]
    uncompressed_end = uncompressed_beg + vertex_unique_edge_count[v]
    unique_beg = vertex_unique_edge_offsets[v]

    for slot in range(hex_beg, hex_end):
        hx = vertex_hex_indices[slot]

        for edge in range(12):
            end_0 = hex_vertex_indices[hx, EDGE_VERTEX_INDICES[edge, 0]]
            end_1 = hex_vertex_indices[hx, EDGE_VERTEX_INDICES[edge, 1]]

            if v == wp.min(end_0, end_1):
                other_v = wp.max(end_0, end_1)
                # Position of the edge among this vertex's unique edges,
                # shifted to the vertex's range in the compressed numbering
                found = Hexmesh._find_edge(other_v, uncompressed_edge_ends, uncompressed_beg, uncompressed_end)
                hex_edge_indices[hx][edge] = found - uncompressed_beg + unique_beg