warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package registry's release page for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,911 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import warp as wp
19
+ from warp.fem.cache import (
20
+ TemporaryStore,
21
+ borrow_temporary,
22
+ borrow_temporary_like,
23
+ cached_arg_value,
24
+ )
25
+ from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample
26
+
27
+ from .element import Cube, Square
28
+ from .geometry import Geometry
29
+
30
+
31
@wp.struct
class HexmeshCellArg:
    """Device-side arguments for cell (hexahedron) geometry queries."""

    # one row of 8 global vertex indices per hex
    hex_vertex_indices: wp.array2d(dtype=int)
    # 3D position of each mesh vertex
    positions: wp.array(dtype=wp.vec3)

    # for neighbor cell lookup
    # presumably vertex -> hex adjacency (offsets into vertex_hex_indices) -- TODO confirm in _build_topology
    vertex_hex_offsets: wp.array(dtype=int)
    vertex_hex_indices: wp.array(dtype=int)
39
+
40
+
41
@wp.struct
class HexmeshSideArg:
    """Device-side arguments for side (quadrilateral face) geometry queries."""

    cell_arg: HexmeshCellArg
    # 4 global vertex indices per face
    face_vertex_indices: wp.array(dtype=wp.vec4i)
    # (inner hex, outer hex) adjacent to each face
    face_hex_indices: wp.array(dtype=wp.vec2i)
    # (inner local face index, inner orientation, outer local face index, outer orientation)
    face_hex_face_orientation: wp.array(dtype=wp.vec4i)
47
+
48
+
49
# Local vertex indices (0..7) of each of the 6 hex faces, one row per face.
# Within a row, the vertices sit at local face coordinates (0,0), (1,0), (1,1), (0,1)
# of that face's frame as defined by _FACE_COORD_INDICES.
FACE_VERTEX_INDICES = wp.constant(
    wp.mat(shape=(6, 4), dtype=int)(
        [
            [0, 4, 7, 3],  # x = 0
            [1, 2, 6, 5],  # x = 1
            [0, 1, 5, 4],  # y = 0
            [3, 7, 6, 2],  # y = 1
            [0, 3, 2, 1],  # z = 0
            [4, 5, 6, 7],  # z = 1
        ]
    )
)
61
+
62
# Local vertex index pairs for the 12 hex edges.
# Rows 0-3: edges of the z = 0 face; rows 4-7: edges of the z = 1 face; rows 8-11: edges along z.
# Each pair starts at the vertex with the smaller reference coordinate along the edge's axis.
EDGE_VERTEX_INDICES = wp.constant(
    wp.mat(shape=(12, 2), dtype=int)(
        [
            [0, 1],
            [1, 2],
            [3, 2],
            [0, 3],
            [4, 5],
            [5, 6],
            [7, 6],
            [4, 7],
            [0, 4],
            [1, 5],
            [2, 6],
            [3, 7],
        ]
    )
)
80
+
81
# Orthogonal 2x2 transforms for face coordinates, one per (first vertex FV, winding) pair.
# Orientation index ``ori = 2 * FV + (0 if det > 0 else 1)``; the table is flattened so
# that entry ``ori`` occupies rows ``2 * ori`` and ``2 * ori + 1``.
FACE_ORIENTATION = [
    list(row)
    for rows in (
        ((1, 0), (0, 1)),  # FV: 0, det: +
        ((0, 1), (1, 0)),  # FV: 0, det: -
        ((0, -1), (1, 0)),  # FV: 1, det: +
        ((-1, 0), (0, 1)),  # FV: 1, det: -
        ((-1, 0), (0, -1)),  # FV: 2, det: +
        ((0, -1), (-1, 0)),  # FV: 2, det: -
        ((0, 1), (-1, 0)),  # FV: 3, det: +
        ((1, 0), (0, -1)),  # FV: 3, det: -
    )
    for row in rows
]
102
+
103
# Translation applied after FACE_ORIENTATION, indexed by first vertex (FV = ori // 2):
# the unit-square corners (0,0), (1,0), (1,1), (0,1), in that order.
FACE_TRANSLATION = [[0, 0], [1, 0], [1, 1], [0, 1]]
109
+
110
# local face coordinate system
# One row per hex face: [first in-face axis, second in-face axis, normal axis, side],
# where side is 0 if the face lies at normal coordinate 0 and 1 if it lies at 1.
# Comments list (u, v, outward normal coordinate) for each face.
_FACE_COORD_INDICES = wp.constant(
    wp.mat(shape=(6, 4), dtype=int)(
        [
            [2, 1, 0, 0],  # 0: z y -x
            [1, 2, 0, 1],  # 1: y z x-1
            [0, 2, 1, 0],  # 2: x z -y
            [2, 0, 1, 1],  # 3: z x y-1
            [1, 0, 2, 0],  # 4: y x -z
            [0, 1, 2, 1],  # 5: x y z-1
        ]
    )
)
123
+
124
# Float matrix constants built from FACE_ORIENTATION / FACE_TRANSLATION for use inside @wp.func code
_FACE_ORIENTATION_F = wp.constant(wp.mat(shape=(16, 2), dtype=float)(FACE_ORIENTATION))
_FACE_TRANSLATION_F = wp.constant(wp.mat(shape=(4, 2), dtype=float)(FACE_TRANSLATION))
126
+
127
+
128
+ class Hexmesh(Geometry):
129
+ """Hexahedral mesh geometry"""
130
+
131
+ dimension = 3
132
+
133
+ def __init__(
134
+ self, hex_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None
135
+ ):
136
+ """
137
+ Constructs a tetrahedral mesh.
138
+
139
+ Args:
140
+ hex_vertex_indices: warp array of shape (num_hexes, 8) containing vertex indices for each hex
141
+ following standard ordering (bottom face vertices in counter-clockwise order, then similarly for upper face)
142
+ positions: warp array of shape (num_vertices, 3) containing 3d position for each vertex
143
+ temporary_store: shared pool from which to allocate temporary arrays
144
+ """
145
+
146
+ self.hex_vertex_indices = hex_vertex_indices
147
+ self.positions = positions
148
+
149
+ self._face_vertex_indices: wp.array = None
150
+ self._face_hex_indices: wp.array = None
151
+ self._face_hex_face_orientation: wp.array = None
152
+ self._vertex_hex_offsets: wp.array = None
153
+ self._vertex_hex_indices: wp.array = None
154
+ self._hex_edge_indices: wp.array = None
155
+ self._edge_count = 0
156
+ self._build_topology(temporary_store)
157
+
158
+ self._make_default_dependent_implementations()
159
+
160
+ def cell_count(self):
161
+ return self.hex_vertex_indices.shape[0]
162
+
163
+ def vertex_count(self):
164
+ return self.positions.shape[0]
165
+
166
+ def side_count(self):
167
+ return self._face_vertex_indices.shape[0]
168
+
169
+ def edge_count(self):
170
+ if self._hex_edge_indices is None:
171
+ self._compute_hex_edges()
172
+ return self._edge_count
173
+
174
+ def boundary_side_count(self):
175
+ return self._boundary_face_indices.shape[0]
176
+
177
+ def reference_cell(self) -> Cube:
178
+ return Cube()
179
+
180
+ def reference_side(self) -> Square:
181
+ return Square()
182
+
183
+ @property
184
+ def hex_edge_indices(self) -> wp.array:
185
+ if self._hex_edge_indices is None:
186
+ self._compute_hex_edges()
187
+ return self._hex_edge_indices
188
+
189
+ @property
190
+ def face_hex_indices(self) -> wp.array:
191
+ return self._face_hex_indices
192
+
193
+ @property
194
+ def face_vertex_indices(self) -> wp.array:
195
+ return self._face_vertex_indices
196
+
197
    # Argument struct types exposed as class attributes; used in the device function signatures below
    CellArg = HexmeshCellArg
    SideArg = HexmeshSideArg
199
+
200
    @wp.struct
    class SideIndexArg:
        """Device-side arguments mapping boundary side indices to global face indices."""

        boundary_face_indices: wp.array(dtype=int)
203
+
204
+ # Geometry device interface
205
+
206
+ @cached_arg_value
207
+ def cell_arg_value(self, device) -> CellArg:
208
+ args = self.CellArg()
209
+
210
+ args.hex_vertex_indices = self.hex_vertex_indices.to(device)
211
+ args.positions = self.positions.to(device)
212
+ args.vertex_hex_offsets = self._vertex_hex_offsets.to(device)
213
+ args.vertex_hex_indices = self._vertex_hex_indices.to(device)
214
+
215
+ return args
216
+
217
    @wp.func
    def cell_position(args: CellArg, s: Sample):
        """Trilinear interpolation of the hex's 8 vertex positions at the sample's reference coordinates."""
        hex_idx = args.hex_vertex_indices[s.element_index]

        # per-axis weights toward the coordinate-1 (p) and coordinate-0 (m) sides
        w_p = s.element_coords
        w_m = Coords(1.0) - s.element_coords

        # local vertex : weight factor used for (x, y, z)
        # 0 : m m m
        # 1 : p m m
        # 2 : p p m
        # 3 : m p m
        # 4 : m m p
        # 5 : p m p
        # 6 : p p p
        # 7 : m p p

        return (
            w_m[0] * w_m[1] * w_m[2] * args.positions[hex_idx[0]]
            + w_p[0] * w_m[1] * w_m[2] * args.positions[hex_idx[1]]
            + w_p[0] * w_p[1] * w_m[2] * args.positions[hex_idx[2]]
            + w_m[0] * w_p[1] * w_m[2] * args.positions[hex_idx[3]]
            + w_m[0] * w_m[1] * w_p[2] * args.positions[hex_idx[4]]
            + w_p[0] * w_m[1] * w_p[2] * args.positions[hex_idx[5]]
            + w_p[0] * w_p[1] * w_p[2] * args.positions[hex_idx[6]]
            + w_m[0] * w_p[1] * w_p[2] * args.positions[hex_idx[7]]
        )
243
+
244
+ @wp.func
245
+ def cell_deformation_gradient(cell_arg: CellArg, s: Sample):
246
+ """Deformation gradient at `coords`"""
247
+ """Transposed deformation gradient at `coords`"""
248
+ hex_idx = cell_arg.hex_vertex_indices[s.element_index]
249
+
250
+ w_p = s.element_coords
251
+ w_m = Coords(1.0) - s.element_coords
252
+
253
+ return (
254
+ wp.outer(cell_arg.positions[hex_idx[0]], wp.vec3(-w_m[1] * w_m[2], -w_m[0] * w_m[2], -w_m[0] * w_m[1]))
255
+ + wp.outer(cell_arg.positions[hex_idx[1]], wp.vec3(w_m[1] * w_m[2], -w_p[0] * w_m[2], -w_p[0] * w_m[1]))
256
+ + wp.outer(cell_arg.positions[hex_idx[2]], wp.vec3(w_p[1] * w_m[2], w_p[0] * w_m[2], -w_p[0] * w_p[1]))
257
+ + wp.outer(cell_arg.positions[hex_idx[3]], wp.vec3(-w_p[1] * w_m[2], w_m[0] * w_m[2], -w_m[0] * w_p[1]))
258
+ + wp.outer(cell_arg.positions[hex_idx[4]], wp.vec3(-w_m[1] * w_p[2], -w_m[0] * w_p[2], w_m[0] * w_m[1]))
259
+ + wp.outer(cell_arg.positions[hex_idx[5]], wp.vec3(w_m[1] * w_p[2], -w_p[0] * w_p[2], w_p[0] * w_m[1]))
260
+ + wp.outer(cell_arg.positions[hex_idx[6]], wp.vec3(w_p[1] * w_p[2], w_p[0] * w_p[2], w_p[0] * w_p[1]))
261
+ + wp.outer(cell_arg.positions[hex_idx[7]], wp.vec3(-w_p[1] * w_p[2], w_m[0] * w_p[2], w_m[0] * w_p[1]))
262
+ )
263
+
264
+ @cached_arg_value
265
+ def side_index_arg_value(self, device) -> SideIndexArg:
266
+ args = self.SideIndexArg()
267
+
268
+ args.boundary_face_indices = self._boundary_face_indices.to(device)
269
+
270
+ return args
271
+
272
+ @wp.func
273
+ def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
274
+ """Boundary side to side index"""
275
+
276
+ return args.boundary_face_indices[boundary_side_index]
277
+
278
+ @cached_arg_value
279
+ def side_arg_value(self, device) -> CellArg:
280
+ args = self.SideArg()
281
+
282
+ args.cell_arg = self.cell_arg_value(device)
283
+ args.face_vertex_indices = self._face_vertex_indices.to(device)
284
+ args.face_hex_indices = self._face_hex_indices.to(device)
285
+ args.face_hex_face_orientation = self._face_hex_face_orientation.to(device)
286
+
287
+ return args
288
+
289
    @wp.func
    def side_position(args: SideArg, s: Sample):
        """Bilinear interpolation of the face's 4 vertex positions at the sample's side coordinates.

        Face vertices 0, 1, 2, 3 sit at side coordinates (0,0), (1,0), (1,1), (0,1).
        """
        face_idx = args.face_vertex_indices[s.element_index]

        w_p = s.element_coords
        w_m = Coords(1.0) - s.element_coords

        return (
            w_m[0] * w_m[1] * args.cell_arg.positions[face_idx[0]]
            + w_p[0] * w_m[1] * args.cell_arg.positions[face_idx[1]]
            + w_p[0] * w_p[1] * args.cell_arg.positions[face_idx[2]]
            + w_m[0] * w_p[1] * args.cell_arg.positions[face_idx[3]]
        )
302
+
303
    @wp.func
    def _side_deformation_vecs(args: SideArg, side_index: ElementIndex, coords: Coords):
        """Partial derivatives of the bilinear face position with respect to coords[0] (v1) and coords[1] (v2)."""
        face_idx = args.face_vertex_indices[side_index]

        p0 = args.cell_arg.positions[face_idx[0]]
        p1 = args.cell_arg.positions[face_idx[1]]
        p2 = args.cell_arg.positions[face_idx[2]]
        p3 = args.cell_arg.positions[face_idx[3]]

        w_p = coords
        w_m = Coords(1.0) - coords

        # d/dcoords[0]: blend of the bottom (p0->p1) and top (p3->p2) edge vectors
        v1 = w_m[1] * (p1 - p0) + w_p[1] * (p2 - p3)
        # d/dcoords[1]: blend of the left (p0->p3) and right (p1->p2) edge vectors
        v2 = w_p[0] * (p2 - p1) + w_m[0] * (p3 - p0)
        return v1, v2
318
+
319
+ @wp.func
320
+ def side_deformation_gradient(args: SideArg, s: Sample):
321
+ """Transposed side deformation gradient at `coords`"""
322
+ v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
323
+ return wp.matrix_from_cols(v1, v2)
324
+
325
+ @wp.func
326
+ def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
327
+ return arg.face_hex_indices[side_index][0]
328
+
329
+ @wp.func
330
+ def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
331
+ return arg.face_hex_indices[side_index][1]
332
+
333
    @wp.func
    def _hex_local_face_coords(hex_coords: Coords, face_index: int):
        """Project hex reference coordinates onto the local frame of hex face `face_index`.

        Returns the two in-face coordinates and a signed normal coordinate that is 0 on the
        face plane and positive on the side away from the hex interior.
        """
        # Coordinates in local face coordinates system
        # Sign of last coordinate (out of face)

        face_coords = wp.vec2(
            hex_coords[_FACE_COORD_INDICES[face_index, 0]], hex_coords[_FACE_COORD_INDICES[face_index, 1]]
        )

        normal_coord = hex_coords[_FACE_COORD_INDICES[face_index, 2]]
        # face at coordinate 0 -> -c ; face at coordinate 1 -> c - 1
        normal_coord = wp.where(_FACE_COORD_INDICES[face_index, 3] == 0, -normal_coord, normal_coord - 1.0)

        return face_coords, normal_coord
346
+
347
    @wp.func
    def _local_face_hex_coords(face_coords: wp.vec2, face_index: int):
        """Map local face coordinates back to hex reference coordinates on the plane of face `face_index`.

        The two in-face axes receive `face_coords`; the normal axis is pinned to 0 or 1.
        """
        # Coordinates in hex from local face coordinates system

        hex_coords = Coords()
        hex_coords[_FACE_COORD_INDICES[face_index, 0]] = face_coords[0]
        hex_coords[_FACE_COORD_INDICES[face_index, 1]] = face_coords[1]
        hex_coords[_FACE_COORD_INDICES[face_index, 2]] = wp.where(_FACE_COORD_INDICES[face_index, 3] == 0, 0.0, 1.0)

        return hex_coords
357
+
358
    @wp.func
    def _local_from_oriented_face_coords(ori: int, oriented_coords: Coords):
        """Map oriented side coordinates back to local face coordinates.

        Inverse of `_local_to_oriented_face_coords`: subtracts the first-vertex translation, then
        applies the transpose of the (orthogonal) orientation matrix `ori`. Returns a vec2.
        """
        fv = ori // 2  # first vertex index selects the translation
        return (oriented_coords[0] - _FACE_TRANSLATION_F[fv, 0]) * _FACE_ORIENTATION_F[2 * ori] + (
            oriented_coords[1] - _FACE_TRANSLATION_F[fv, 1]
        ) * _FACE_ORIENTATION_F[2 * ori + 1]
364
+
365
+ @wp.func
366
+ def _local_to_oriented_face_coords(ori: int, coords: wp.vec2):
367
+ fv = ori // 2
368
+ return Coords(
369
+ wp.dot(_FACE_ORIENTATION_F[2 * ori], coords) + _FACE_TRANSLATION_F[fv, 0],
370
+ wp.dot(_FACE_ORIENTATION_F[2 * ori + 1], coords) + _FACE_TRANSLATION_F[fv, 1],
371
+ 0.0,
372
+ )
373
+
374
+ @wp.func
375
+ def face_to_hex_coords(local_face_index: int, face_orientation: int, side_coords: Coords):
376
+ local_coords = Hexmesh._local_from_oriented_face_coords(face_orientation, side_coords)
377
+ return Hexmesh._local_face_hex_coords(local_coords, local_face_index)
378
+
379
+ @wp.func
380
+ def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
381
+ local_face_index = args.face_hex_face_orientation[side_index][0]
382
+ face_orientation = args.face_hex_face_orientation[side_index][1]
383
+
384
+ return Hexmesh.face_to_hex_coords(local_face_index, face_orientation, side_coords)
385
+
386
+ @wp.func
387
+ def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
388
+ local_face_index = args.face_hex_face_orientation[side_index][2]
389
+ face_orientation = args.face_hex_face_orientation[side_index][3]
390
+
391
+ return Hexmesh.face_to_hex_coords(local_face_index, face_orientation, side_coords)
392
+
393
+ @wp.func
394
+ def side_from_cell_coords(args: SideArg, side_index: ElementIndex, hex_index: ElementIndex, hex_coords: Coords):
395
+ if Hexmesh.side_inner_cell_index(args, side_index) == hex_index:
396
+ local_face_index = args.face_hex_face_orientation[side_index][0]
397
+ face_orientation = args.face_hex_face_orientation[side_index][1]
398
+ else:
399
+ local_face_index = args.face_hex_face_orientation[side_index][2]
400
+ face_orientation = args.face_hex_face_orientation[side_index][3]
401
+
402
+ face_coords, normal_coord = Hexmesh._hex_local_face_coords(hex_coords, local_face_index)
403
+ return wp.where(
404
+ normal_coord == 0.0, Hexmesh._local_to_oriented_face_coords(face_orientation, face_coords), Coords(OUTSIDE)
405
+ )
406
+
407
+ @wp.func
408
+ def side_to_cell_arg(side_arg: SideArg):
409
+ return side_arg.cell_arg
410
+
411
+ def _build_topology(self, temporary_store: TemporaryStore):
412
+ from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
413
+ from warp.utils import array_scan
414
+
415
+ device = self.hex_vertex_indices.device
416
+
417
+ vertex_hex_offsets, vertex_hex_indices = compress_node_indices(
418
+ self.vertex_count(), self.hex_vertex_indices, temporary_store=temporary_store
419
+ )
420
+ self._vertex_hex_offsets = vertex_hex_offsets.detach()
421
+ self._vertex_hex_indices = vertex_hex_indices.detach()
422
+
423
+ vertex_start_face_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
424
+ vertex_start_face_count.array.zero_()
425
+ vertex_start_face_offsets = borrow_temporary_like(vertex_start_face_count, temporary_store=temporary_store)
426
+
427
+ vertex_face_other_vs = borrow_temporary(
428
+ temporary_store, dtype=wp.vec3i, device=device, shape=(8 * self.cell_count())
429
+ )
430
+ vertex_face_hexes = borrow_temporary(
431
+ temporary_store, dtype=int, device=device, shape=(8 * self.cell_count(), 2)
432
+ )
433
+
434
+ # Count face edges starting at each vertex
435
+ wp.launch(
436
+ kernel=Hexmesh._count_starting_faces_kernel,
437
+ device=device,
438
+ dim=self.cell_count(),
439
+ inputs=[self.hex_vertex_indices, vertex_start_face_count.array],
440
+ )
441
+
442
+ array_scan(in_array=vertex_start_face_count.array, out_array=vertex_start_face_offsets.array, inclusive=False)
443
+
444
+ # Count number of unique edges (deduplicate across faces)
445
+ vertex_unique_face_count = vertex_start_face_count
446
+ wp.launch(
447
+ kernel=Hexmesh._count_unique_starting_faces_kernel,
448
+ device=device,
449
+ dim=self.vertex_count(),
450
+ inputs=[
451
+ self._vertex_hex_offsets,
452
+ self._vertex_hex_indices,
453
+ self.hex_vertex_indices,
454
+ vertex_start_face_offsets.array,
455
+ vertex_unique_face_count.array,
456
+ vertex_face_other_vs.array,
457
+ vertex_face_hexes.array,
458
+ ],
459
+ )
460
+
461
+ vertex_unique_face_offsets = borrow_temporary_like(vertex_start_face_offsets, temporary_store=temporary_store)
462
+ array_scan(in_array=vertex_start_face_count.array, out_array=vertex_unique_face_offsets.array, inclusive=False)
463
+
464
+ # Get back edge count to host
465
+ face_count = int(
466
+ host_read_at_index(
467
+ vertex_unique_face_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
468
+ )
469
+ )
470
+
471
+ self._face_vertex_indices = wp.empty(shape=(face_count,), dtype=wp.vec4i, device=device)
472
+ self._face_hex_indices = wp.empty(shape=(face_count,), dtype=wp.vec2i, device=device)
473
+ self._face_hex_face_orientation = wp.empty(shape=(face_count,), dtype=wp.vec4i, device=device)
474
+
475
+ boundary_mask = borrow_temporary(temporary_store, shape=(face_count,), dtype=int, device=device)
476
+
477
+ # Compress edge data
478
+ wp.launch(
479
+ kernel=Hexmesh._compress_faces_kernel,
480
+ device=device,
481
+ dim=self.vertex_count(),
482
+ inputs=[
483
+ vertex_start_face_offsets.array,
484
+ vertex_unique_face_offsets.array,
485
+ vertex_unique_face_count.array,
486
+ vertex_face_other_vs.array,
487
+ vertex_face_hexes.array,
488
+ self._face_vertex_indices,
489
+ self._face_hex_indices,
490
+ boundary_mask.array,
491
+ ],
492
+ )
493
+
494
+ vertex_start_face_offsets.release()
495
+ vertex_unique_face_offsets.release()
496
+ vertex_unique_face_count.release()
497
+ vertex_face_other_vs.release()
498
+ vertex_face_hexes.release()
499
+
500
+ # Flip normals if necessary
501
+ wp.launch(
502
+ kernel=Hexmesh._flip_face_normals,
503
+ device=device,
504
+ dim=self.side_count(),
505
+ inputs=[self._face_vertex_indices, self._face_hex_indices, self.hex_vertex_indices, self.positions],
506
+ )
507
+
508
+ # Compute and store face orientation
509
+ wp.launch(
510
+ kernel=Hexmesh._compute_face_orientation,
511
+ device=device,
512
+ dim=self.side_count(),
513
+ inputs=[
514
+ self._face_vertex_indices,
515
+ self._face_hex_indices,
516
+ self.hex_vertex_indices,
517
+ self._face_hex_face_orientation,
518
+ ],
519
+ )
520
+
521
+ boundary_face_indices, _ = masked_indices(boundary_mask.array)
522
+ self._boundary_face_indices = boundary_face_indices.detach()
523
+
524
+ def _compute_hex_edges(self, temporary_store: Optional[TemporaryStore] = None):
525
+ from warp.fem.utils import host_read_at_index
526
+ from warp.utils import array_scan
527
+
528
+ device = self.hex_vertex_indices.device
529
+
530
+ vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
531
+ vertex_start_edge_count.array.zero_()
532
+ vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)
533
+
534
+ vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(12 * self.cell_count()))
535
+
536
+ # Count face edges starting at each vertex
537
+ wp.launch(
538
+ kernel=Hexmesh._count_starting_edges_kernel,
539
+ device=device,
540
+ dim=self.cell_count(),
541
+ inputs=[self.hex_vertex_indices, vertex_start_edge_count.array],
542
+ )
543
+
544
+ array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)
545
+
546
+ # Count number of unique edges (deduplicate across faces)
547
+ vertex_unique_edge_count = vertex_start_edge_count
548
+ wp.launch(
549
+ kernel=Hexmesh._count_unique_starting_edges_kernel,
550
+ device=device,
551
+ dim=self.vertex_count(),
552
+ inputs=[
553
+ self._vertex_hex_offsets,
554
+ self._vertex_hex_indices,
555
+ self.hex_vertex_indices,
556
+ vertex_start_edge_offsets.array,
557
+ vertex_unique_edge_count.array,
558
+ vertex_edge_ends.array,
559
+ ],
560
+ )
561
+
562
+ vertex_unique_edge_offsets = borrow_temporary_like(
563
+ vertex_start_edge_offsets.array, temporary_store=temporary_store
564
+ )
565
+ array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)
566
+
567
+ # Get back edge count to host
568
+ self._edge_count = int(
569
+ host_read_at_index(
570
+ vertex_unique_edge_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
571
+ )
572
+ )
573
+
574
+ self._hex_edge_indices = wp.empty(
575
+ dtype=int, device=self.hex_vertex_indices.device, shape=(self.cell_count(), 12)
576
+ )
577
+
578
+ # Compress edge data
579
+ wp.launch(
580
+ kernel=Hexmesh._compress_edges_kernel,
581
+ device=device,
582
+ dim=self.vertex_count(),
583
+ inputs=[
584
+ self._vertex_hex_offsets,
585
+ self._vertex_hex_indices,
586
+ self.hex_vertex_indices,
587
+ vertex_start_edge_offsets.array,
588
+ vertex_unique_edge_offsets.array,
589
+ vertex_unique_edge_count.array,
590
+ vertex_edge_ends.array,
591
+ self._hex_edge_indices,
592
+ ],
593
+ )
594
+
595
+ vertex_start_edge_offsets.release()
596
+ vertex_unique_edge_offsets.release()
597
+ vertex_unique_edge_count.release()
598
+ vertex_edge_ends.release()
599
+
600
+ @wp.kernel
601
+ def _count_starting_faces_kernel(
602
+ hex_vertex_indices: wp.array2d(dtype=int), vertex_start_face_count: wp.array(dtype=int)
603
+ ):
604
+ t = wp.tid()
605
+ for k in range(6):
606
+ vi = wp.vec4i(
607
+ hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 0]],
608
+ hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 1]],
609
+ hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 2]],
610
+ hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 3]],
611
+ )
612
+ vm = wp.min(vi)
613
+
614
+ for i in range(4):
615
+ if vm == vi[i]:
616
+ wp.atomic_add(vertex_start_face_count, vm, 1)
617
+
618
+ @wp.func
619
+ def _face_sort(vidx: wp.vec4i, min_k: int):
620
+ v1 = vidx[(min_k + 1) % 4]
621
+ v2 = vidx[(min_k + 2) % 4]
622
+ v3 = vidx[(min_k + 3) % 4]
623
+
624
+ if v1 < v3:
625
+ return wp.vec3i(v1, v2, v3)
626
+ return wp.vec3i(v3, v2, v1)
627
+
628
+ @wp.func
629
+ def _find_face(
630
+ needle: wp.vec3i,
631
+ values: wp.array(dtype=wp.vec3i),
632
+ beg: int,
633
+ end: int,
634
+ ):
635
+ for i in range(beg, end):
636
+ if values[i] == needle:
637
+ return i
638
+
639
+ return -1
640
+
641
+ @wp.kernel
642
+ def _count_unique_starting_faces_kernel(
643
+ vertex_hex_offsets: wp.array(dtype=int),
644
+ vertex_hex_indices: wp.array(dtype=int),
645
+ hex_vertex_indices: wp.array2d(dtype=int),
646
+ vertex_start_face_offsets: wp.array(dtype=int),
647
+ vertex_start_face_count: wp.array(dtype=int),
648
+ face_other_vs: wp.array(dtype=wp.vec3i),
649
+ face_hexes: wp.array2d(dtype=int),
650
+ ):
651
+ v = wp.tid()
652
+
653
+ face_beg = vertex_start_face_offsets[v]
654
+
655
+ hex_beg = vertex_hex_offsets[v]
656
+ hex_end = vertex_hex_offsets[v + 1]
657
+
658
+ face_cur = face_beg
659
+
660
+ for hexa in range(hex_beg, hex_end):
661
+ hx = vertex_hex_indices[hexa]
662
+
663
+ for k in range(6):
664
+ vi = wp.vec4i(
665
+ hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 0]],
666
+ hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 1]],
667
+ hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 2]],
668
+ hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 3]],
669
+ )
670
+ min_i = int(wp.argmin(vi))
671
+
672
+ if v == vi[min_i]:
673
+ other_v = Hexmesh._face_sort(vi, min_i)
674
+
675
+ # Check if other_v has been seen
676
+ seen_idx = Hexmesh._find_face(other_v, face_other_vs, face_beg, face_cur)
677
+
678
+ if seen_idx == -1:
679
+ face_other_vs[face_cur] = other_v
680
+ face_hexes[face_cur, 0] = hx
681
+ face_hexes[face_cur, 1] = hx
682
+ face_cur += 1
683
+ else:
684
+ face_hexes[seen_idx, 1] = hx
685
+
686
+ vertex_start_face_count[v] = face_cur - face_beg
687
+
688
+ @wp.kernel
689
+ def _compress_faces_kernel(
690
+ vertex_start_face_offsets: wp.array(dtype=int),
691
+ vertex_unique_face_offsets: wp.array(dtype=int),
692
+ vertex_unique_face_count: wp.array(dtype=int),
693
+ uncompressed_face_other_vs: wp.array(dtype=wp.vec3i),
694
+ uncompressed_face_hexes: wp.array2d(dtype=int),
695
+ face_vertex_indices: wp.array(dtype=wp.vec4i),
696
+ face_hex_indices: wp.array(dtype=wp.vec2i),
697
+ boundary_mask: wp.array(dtype=int),
698
+ ):
699
+ v = wp.tid()
700
+
701
+ start_beg = vertex_start_face_offsets[v]
702
+ unique_beg = vertex_unique_face_offsets[v]
703
+ unique_count = vertex_unique_face_count[v]
704
+
705
+ for f in range(unique_count):
706
+ src_index = start_beg + f
707
+ face_index = unique_beg + f
708
+
709
+ face_vertex_indices[face_index] = wp.vec4i(
710
+ v,
711
+ uncompressed_face_other_vs[src_index][0],
712
+ uncompressed_face_other_vs[src_index][1],
713
+ uncompressed_face_other_vs[src_index][2],
714
+ )
715
+
716
+ hx0 = uncompressed_face_hexes[src_index, 0]
717
+ hx1 = uncompressed_face_hexes[src_index, 1]
718
+ face_hex_indices[face_index] = wp.vec2i(hx0, hx1)
719
+ if hx0 == hx1:
720
+ boundary_mask[face_index] = 1
721
+ else:
722
+ boundary_mask[face_index] = 0
723
+
724
+ @wp.kernel
725
+ def _flip_face_normals(
726
+ face_vertex_indices: wp.array(dtype=wp.vec4i),
727
+ face_hex_indices: wp.array(dtype=wp.vec2i),
728
+ hex_vertex_indices: wp.array2d(dtype=int),
729
+ positions: wp.array(dtype=wp.vec3),
730
+ ):
731
+ f = wp.tid()
732
+
733
+ hexa = face_hex_indices[f][0]
734
+
735
+ hex_vidx = hex_vertex_indices[hexa]
736
+ face_vidx = face_vertex_indices[f]
737
+
738
+ hex_centroid = (
739
+ positions[hex_vidx[0]]
740
+ + positions[hex_vidx[1]]
741
+ + positions[hex_vidx[2]]
742
+ + positions[hex_vidx[3]]
743
+ + positions[hex_vidx[4]]
744
+ + positions[hex_vidx[5]]
745
+ + positions[hex_vidx[6]]
746
+ + positions[hex_vidx[7]]
747
+ ) / 8.0
748
+
749
+ v0 = positions[face_vidx[0]]
750
+ v1 = positions[face_vidx[1]]
751
+ v2 = positions[face_vidx[2]]
752
+ v3 = positions[face_vidx[3]]
753
+
754
+ face_center = (v1 + v0 + v2 + v3) / 4.0
755
+ face_normal = wp.cross(v2 - v0, v3 - v1)
756
+
757
+ # if face normal points toward first tet centroid, flip indices
758
+ if wp.dot(hex_centroid - face_center, face_normal) > 0.0:
759
+ face_vertex_indices[f] = wp.vec4i(face_vidx[0], face_vidx[3], face_vidx[2], face_vidx[1])
760
+
761
+ @wp.func
762
+ def _find_face_orientation(face_vidx: wp.vec4i, hex_index: int, hex_vertex_indices: wp.array2d(dtype=int)):
763
+ hex_vidx = hex_vertex_indices[hex_index]
764
+
765
+ # Find local index in hex corresponding to face
766
+
767
+ face_min_i = int(wp.argmin(face_vidx))
768
+ face_other_v = Hexmesh._face_sort(face_vidx, face_min_i)
769
+
770
+ for k in range(6):
771
+ hex_face_vi = wp.vec4i(
772
+ hex_vidx[FACE_VERTEX_INDICES[k, 0]],
773
+ hex_vidx[FACE_VERTEX_INDICES[k, 1]],
774
+ hex_vidx[FACE_VERTEX_INDICES[k, 2]],
775
+ hex_vidx[FACE_VERTEX_INDICES[k, 3]],
776
+ )
777
+ hex_min_i = int(wp.argmin(hex_face_vi))
778
+ hex_other_v = Hexmesh._face_sort(hex_face_vi, hex_min_i)
779
+
780
+ if hex_other_v == face_other_v:
781
+ local_face_index = k
782
+ break
783
+
784
+ # Find starting vertex index
785
+ for k in range(4):
786
+ if face_vidx[k] == hex_face_vi[0]:
787
+ face_orientation = 2 * k
788
+ if face_vidx[(k + 1) % 4] != hex_face_vi[1]:
789
+ face_orientation += 1
790
+
791
+ return local_face_index, face_orientation
792
+
793
+ @wp.kernel
794
+ def _compute_face_orientation(
795
+ face_vertex_indices: wp.array(dtype=wp.vec4i),
796
+ face_hex_indices: wp.array(dtype=wp.vec2i),
797
+ hex_vertex_indices: wp.array2d(dtype=int),
798
+ face_hex_face_ori: wp.array(dtype=wp.vec4i),
799
+ ):
800
+ f = wp.tid()
801
+
802
+ face_vidx = face_vertex_indices[f]
803
+
804
+ hx0 = face_hex_indices[f][0]
805
+ local_face_0, ori_0 = Hexmesh._find_face_orientation(face_vidx, hx0, hex_vertex_indices)
806
+
807
+ hx1 = face_hex_indices[f][1]
808
+ if hx0 == hx1:
809
+ face_hex_face_ori[f] = wp.vec4i(local_face_0, ori_0, local_face_0, ori_0)
810
+ else:
811
+ local_face_1, ori_1 = Hexmesh._find_face_orientation(face_vidx, hx1, hex_vertex_indices)
812
+ face_hex_face_ori[f] = wp.vec4i(local_face_0, ori_0, local_face_1, ori_1)
813
+
814
+ @wp.kernel
815
+ def _count_starting_edges_kernel(
816
+ hex_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
817
+ ):
818
+ t = wp.tid()
819
+ for k in range(12):
820
+ v0 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 0]]
821
+ v1 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 1]]
822
+
823
+ if v0 < v1:
824
+ wp.atomic_add(vertex_start_edge_count, v0, 1)
825
+ else:
826
+ wp.atomic_add(vertex_start_edge_count, v1, 1)
827
+
828
+ @wp.func
829
+ def _find_edge(
830
+ needle: int,
831
+ values: wp.array(dtype=int),
832
+ beg: int,
833
+ end: int,
834
+ ):
835
+ for i in range(beg, end):
836
+ if values[i] == needle:
837
+ return i
838
+
839
+ return -1
840
+
841
+ @wp.kernel
842
+ def _count_unique_starting_edges_kernel(
843
+ vertex_hex_offsets: wp.array(dtype=int),
844
+ vertex_hex_indices: wp.array(dtype=int),
845
+ hex_vertex_indices: wp.array2d(dtype=int),
846
+ vertex_start_edge_offsets: wp.array(dtype=int),
847
+ vertex_start_edge_count: wp.array(dtype=int),
848
+ edge_ends: wp.array(dtype=int),
849
+ ):
850
+ v = wp.tid()
851
+
852
+ edge_beg = vertex_start_edge_offsets[v]
853
+
854
+ hex_beg = vertex_hex_offsets[v]
855
+ hex_end = vertex_hex_offsets[v + 1]
856
+
857
+ edge_cur = edge_beg
858
+
859
+ for tet in range(hex_beg, hex_end):
860
+ t = vertex_hex_indices[tet]
861
+
862
+ for k in range(12):
863
+ v0 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 0]]
864
+ v1 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 1]]
865
+
866
+ if v == wp.min(v0, v1):
867
+ other_v = wp.max(v0, v1)
868
+ if Hexmesh._find_edge(other_v, edge_ends, edge_beg, edge_cur) == -1:
869
+ edge_ends[edge_cur] = other_v
870
+ edge_cur += 1
871
+
872
+ vertex_start_edge_count[v] = edge_cur - edge_beg
873
+
874
+ @wp.kernel
875
+ def _compress_edges_kernel(
876
+ vertex_hex_offsets: wp.array(dtype=int),
877
+ vertex_hex_indices: wp.array(dtype=int),
878
+ hex_vertex_indices: wp.array2d(dtype=int),
879
+ vertex_start_edge_offsets: wp.array(dtype=int),
880
+ vertex_unique_edge_offsets: wp.array(dtype=int),
881
+ vertex_unique_edge_count: wp.array(dtype=int),
882
+ uncompressed_edge_ends: wp.array(dtype=int),
883
+ hex_edge_indices: wp.array2d(dtype=int),
884
+ ):
885
+ v = wp.tid()
886
+
887
+ uncompressed_beg = vertex_start_edge_offsets[v]
888
+
889
+ unique_beg = vertex_unique_edge_offsets[v]
890
+ unique_count = vertex_unique_edge_count[v]
891
+
892
+ hex_beg = vertex_hex_offsets[v]
893
+ hex_end = vertex_hex_offsets[v + 1]
894
+
895
+ for tet in range(hex_beg, hex_end):
896
+ t = vertex_hex_indices[tet]
897
+
898
+ for k in range(12):
899
+ v0 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 0]]
900
+ v1 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 1]]
901
+
902
+ if v == wp.min(v0, v1):
903
+ other_v = wp.max(v0, v1)
904
+ edge_id = (
905
+ Hexmesh._find_edge(
906
+ other_v, uncompressed_edge_ends, uncompressed_beg, uncompressed_beg + unique_count
907
+ )
908
+ - uncompressed_beg
909
+ + unique_beg
910
+ )
911
+ hex_edge_indices[t][k] = edge_id