warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,855 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import warp as wp
19
+ from warp.fem.cache import (
20
+ TemporaryStore,
21
+ borrow_temporary,
22
+ borrow_temporary_like,
23
+ cached_arg_value,
24
+ )
25
+ from warp.fem.types import (
26
+ NULL_ELEMENT_INDEX,
27
+ OUTSIDE,
28
+ Coords,
29
+ ElementIndex,
30
+ Sample,
31
+ make_free_sample,
32
+ )
33
+
34
+ from .closest_point import project_on_tet_at_origin
35
+ from .element import Tetrahedron, Triangle
36
+ from .geometry import Geometry
37
+
38
+
39
@wp.struct
class TetmeshCellArg:
    """Device-side arguments for cell (tetrahedron) queries; assembled by ``Tetmesh.cell_arg_value``."""

    # vertex indices of each tet, indexed as [tet_index, 0..3]
    tet_vertex_indices: wp.array2d(dtype=int)
    # position of each mesh vertex
    positions: wp.array(dtype=wp.vec3)

    # for neighbor cell lookup
    # CSR-style vertex -> tet adjacency: the tets incident to vertex v are
    # vertex_tet_indices[vertex_tet_offsets[v] : vertex_tet_offsets[v + 1]]
    vertex_tet_offsets: wp.array(dtype=int)
    vertex_tet_indices: wp.array(dtype=int)

    # for global cell lookup
    # id of the tet BVH, or _NULL_BVH when no BVH has been built
    tet_bvh: wp.uint64
50
+
51
+
52
@wp.struct
class TetmeshSideArg:
    """Device-side arguments for side (triangular face) queries; assembled by ``Tetmesh.side_arg_value``."""

    # cell arguments, used for vertex positions and tet connectivity
    cell_arg: TetmeshCellArg
    # the three vertex indices of each face
    face_vertex_indices: wp.array(dtype=wp.vec3i)
    # (inner, outer) tet index of each face, read as [0] / [1] by the side_*_cell_index functions
    face_tet_indices: wp.array(dtype=wp.vec2i)
59
# Sentinel BVH id (uint64(-1)) stored in TetmeshCellArg.tet_bvh when no BVH has been built;
# Tetmesh._bvh_lookup skips the BVH query when it sees this value.
_NULL_BVH = wp.constant(wp.uint64(-1))
60
+
61
+
62
+ class Tetmesh(Geometry):
63
+ """Tetrahedral mesh geometry"""
64
+
65
+ dimension = 3
66
+
67
def __init__(
    self,
    tet_vertex_indices: wp.array,
    positions: wp.array,
    build_bvh: bool = False,
    temporary_store: Optional[TemporaryStore] = None,
):
    """
    Constructs a tetrahedral mesh.

    Args:
        tet_vertex_indices: warp array of shape (num_tets, 4) holding the four vertex indices of each tet
        positions: warp array of ``wp.vec3`` holding the 3d position of each vertex
        build_bvh: Whether to also build the tet BVH, which is necessary for the global `fem.lookup` operator to function without initial guess
        temporary_store: shared pool from which to allocate temporary arrays
    """

    self.tet_vertex_indices = tet_vertex_indices
    self.positions = positions

    # Connectivity arrays; presumably all populated by _build_topology below
    self._vertex_tet_offsets: wp.array = None
    self._vertex_tet_indices: wp.array = None
    self._face_vertex_indices: wp.array = None
    self._face_tet_indices: wp.array = None

    # Edge data is computed lazily, on first call to edge_count() / tet_edge_indices
    self._tet_edge_indices: wp.array = None
    self._edge_count = 0

    self._build_topology(temporary_store)
    self._make_default_dependent_implementations()

    # The BVH is optional: without it, lookups must rely on an initial guess
    self._tet_bvh: wp.Bvh = None
    if build_bvh:
        self._build_bvh()
100
+
101
def update_bvh(self, force_rebuild: bool = False):
    """
    Refits the BVH, or rebuilds it from scratch if `force_rebuild` is ``True``.

    A BVH is always built from scratch when none exists yet, regardless of `force_rebuild`.
    Otherwise the per-tet bounds are recomputed from the current vertex positions into the
    existing BVH bound arrays, and the hierarchy is refitted in place.
    """

    if self._tet_bvh is None or force_rebuild:
        # NOTE(review): rebuilding creates a new Bvh (new id); arg structs previously
        # returned by the cached `cell_arg_value` / `side_arg_value` may still hold the old id -- confirm
        return self._build_bvh()

    # Same launch configuration as _build_bvh, but writing into the BVH's own bound arrays
    wp.launch(
        Tetmesh._compute_tet_bounds,
        dim=self.cell_count(),
        device=self.positions.device,
        inputs=[self.tet_vertex_indices, self.positions, self._tet_bvh.lowers, self._tet_bvh.uppers],
    )
    self._tet_bvh.refit()
117
+
118
def _build_bvh(self, temporary_store: Optional[TemporaryStore] = None):
    """Build a fresh tet BVH from the per-tet bounding boxes of the current positions.

    `temporary_store` is accepted but not used. The bound arrays end up owned by the
    ``wp.Bvh`` (`update_bvh` later reads them back through ``lowers`` / ``uppers``), which is
    presumably why they are regular allocations rather than borrowed temporaries.
    """
    device = self.positions.device
    tet_count = self.cell_count()

    lowers = wp.array(shape=tet_count, dtype=wp.vec3, device=device)
    uppers = wp.array(shape=tet_count, dtype=wp.vec3, device=device)

    wp.launch(
        Tetmesh._compute_tet_bounds,
        device=device,
        dim=tet_count,
        inputs=[self.tet_vertex_indices, self.positions, lowers, uppers],
    )

    self._tet_bvh = wp.Bvh(lowers, uppers)
128
+
129
def cell_count(self):
    """Number of tetrahedra (rows of `tet_vertex_indices`)."""
    tet_shape = self.tet_vertex_indices.shape
    return tet_shape[0]
131
+
132
def vertex_count(self):
    """Number of mesh vertices (entries of `positions`)."""
    position_shape = self.positions.shape
    return position_shape[0]
134
+
135
def side_count(self):
    """Number of triangular faces (sides) in the mesh."""
    face_shape = self._face_vertex_indices.shape
    return face_shape[0]
137
+
138
def edge_count(self):
    """Number of mesh edges; triggers the lazy edge computation on first call."""
    if self._tet_edge_indices is not None:
        return self._edge_count

    self._compute_tet_edges()
    return self._edge_count
142
+
143
def boundary_side_count(self):
    """Number of boundary faces (entries of `_boundary_face_indices`)."""
    boundary_shape = self._boundary_face_indices.shape
    return boundary_shape[0]
145
+
146
def reference_cell(self) -> Tetrahedron:
    """Reference element for cells: a fresh ``Tetrahedron``."""
    reference = Tetrahedron()
    return reference
148
+
149
def reference_side(self) -> Triangle:
    """Reference element for sides: a fresh ``Triangle``."""
    reference = Triangle()
    return reference
151
+
152
@property
def tet_edge_indices(self) -> wp.array:
    """Per-tet edge indices, computed on first access and then cached."""
    if self._tet_edge_indices is not None:
        return self._tet_edge_indices

    self._compute_tet_edges()
    return self._tet_edge_indices
157
+
158
@property
def face_tet_indices(self) -> wp.array:
    """Array of ``wp.vec2i`` giving the (inner, outer) tet index of each face."""
    face_tets = self._face_tet_indices
    return face_tets
161
+
162
@property
def face_vertex_indices(self) -> wp.array:
    """Array of ``wp.vec3i`` giving the three vertex indices of each face."""
    face_vertices = self._face_vertex_indices
    return face_vertices
165
+
166
# Struct types used as device-side arguments for cell and side queries
CellArg, SideArg = TetmeshCellArg, TetmeshSideArg
168
+
169
@wp.struct
class SideIndexArg:
    """Device-side argument mapping boundary side numbers to face (side) indices."""

    # face index of each boundary side; filled by side_index_arg_value, read by boundary_side_index
    boundary_face_indices: wp.array(dtype=int)
172
+
173
+ # Geometry device interface
174
+
175
@cached_arg_value
def cell_arg_value(self, device) -> CellArg:
    """Assemble the cell argument struct, with every array moved to `device` via ``.to(device)``."""
    cell_arg = self.CellArg()

    cell_arg.tet_vertex_indices = self.tet_vertex_indices.to(device)
    cell_arg.positions = self.positions.to(device)
    cell_arg.vertex_tet_offsets = self._vertex_tet_offsets.to(device)
    cell_arg.vertex_tet_indices = self._vertex_tet_indices.to(device)

    # Without a BVH, store the sentinel so that _bvh_lookup skips the global query
    if self._tet_bvh is None:
        cell_arg.tet_bvh = _NULL_BVH
    else:
        cell_arg.tet_bvh = self._tet_bvh.id

    return cell_arg
187
+
188
@wp.func
def cell_position(args: CellArg, s: Sample):
    """World-space position of sample ``s`` inside tet ``s.element_index``.

    ``s.element_coords`` are the barycentric weights of tet vertices 1, 2 and 3;
    the weight of vertex 0 is one minus their sum.
    """
    tet_idx = args.tet_vertex_indices[s.element_index]
    w0 = 1.0 - s.element_coords[0] - s.element_coords[1] - s.element_coords[2]
    return (
        w0 * args.positions[tet_idx[0]]
        + s.element_coords[0] * args.positions[tet_idx[1]]
        + s.element_coords[1] * args.positions[tet_idx[2]]
        + s.element_coords[2] * args.positions[tet_idx[3]]
    )
198
+
199
@wp.func
def cell_deformation_gradient(args: CellArg, s: Sample):
    """Matrix whose columns are the tet edge vectors p1 - p0, p2 - p0, p3 - p0.

    Only ``s.element_index`` is used, so the result is constant over the tet.
    """
    p0 = args.positions[args.tet_vertex_indices[s.element_index, 0]]
    p1 = args.positions[args.tet_vertex_indices[s.element_index, 1]]
    p2 = args.positions[args.tet_vertex_indices[s.element_index, 2]]
    p3 = args.positions[args.tet_vertex_indices[s.element_index, 3]]
    return wp.matrix_from_cols(p1 - p0, p2 - p0, p3 - p0)
206
+
207
@wp.func
def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
    """Inverse of ``cell_deformation_gradient`` for the tet of sample ``s``.

    NOTE(review): there is no guard for degenerate (zero-volume) tets, whose matrix is singular.
    """
    return wp.inverse(Tetmesh.cell_deformation_gradient(args, s))
210
+
211
@wp.func
def _project_on_tet(args: CellArg, pos: wp.vec3, tet_index: int):
    """Project ``pos`` onto tet ``tet_index``; returns ``(dist, coords)``.

    The query point and tet are translated so that tet vertex 0 sits at the origin, and the
    result of ``project_on_tet_at_origin`` is returned unchanged. Callers keep the tet with the
    smallest ``dist``; presumably ``coords`` are the barycentric weights of vertices 1..3 of the
    projected point -- confirm against closest_point.py.
    """
    p0 = args.positions[args.tet_vertex_indices[tet_index, 0]]

    # query point and edge vectors relative to vertex 0
    q = pos - p0
    e1 = args.positions[args.tet_vertex_indices[tet_index, 1]] - p0
    e2 = args.positions[args.tet_vertex_indices[tet_index, 2]] - p0
    e3 = args.positions[args.tet_vertex_indices[tet_index, 3]] - p0

    dist, coords = project_on_tet_at_origin(q, e1, e2, e3)
    return dist, coords
222
+
223
@wp.func
def _bvh_lookup(args: CellArg, pos: wp.vec3):
    """Find the closest tet among those the BVH reports for point ``pos``.

    Returns ``(closest_dist, closest_tet, closest_coords)``. If no tet is visited (no BVH was
    built, i.e. ``args.tet_bvh == _NULL_BVH``, or ``pos`` lies outside all tet bounds), the result is
    ``(1.0e8, NULL_ELEMENT_INDEX, Coords(OUTSIDE))``.
    """
    # explicit type constructors: these locals are reassigned inside the query loop below
    closest_tet = int(NULL_ELEMENT_INDEX)
    closest_coords = Coords(OUTSIDE)
    closest_dist = float(1.0e8)

    if args.tet_bvh != _NULL_BVH:
        # the query box is the single point pos, so only tets whose bounds overlap pos are visited
        query = wp.bvh_query_aabb(args.tet_bvh, pos, pos)
        tet = int(0)
        while wp.bvh_query_next(query, tet):
            dist, coords = Tetmesh._project_on_tet(args, pos, tet)
            # `<=`: on ties, the tet visited last wins
            if dist <= closest_dist:
                closest_dist = dist
                closest_tet = tet
                closest_coords = coords

    return closest_dist, closest_tet, closest_coords
240
+
241
@wp.func
def cell_lookup(args: CellArg, pos: wp.vec3):
    """Global lookup of the tet containing (or closest to) ``pos``, using only the BVH.

    Returns a free sample; when the BVH query finds nothing (including when no BVH was built),
    its element index is ``NULL_ELEMENT_INDEX`` and its coords are ``OUTSIDE``.
    """
    closest_dist, closest_tet, closest_coords = Tetmesh._bvh_lookup(args, pos)

    return make_free_sample(closest_tet, closest_coords)
246
+
247
@wp.func
def cell_lookup(args: CellArg, pos: wp.vec3, guess: Sample):
    """Lookup of the tet containing (or closest to) ``pos``, with a local fallback around ``guess``.

    The BVH is queried first. If it yields no tet (no BVH built, or ``pos`` outside all tet
    bounds), every tet sharing a vertex with ``guess.element_index`` is projected onto and the
    closest one is kept. If ``guess`` is itself a null sample, no fallback search is attempted and
    the returned sample has ``NULL_ELEMENT_INDEX`` / ``OUTSIDE`` coords.
    """
    closest_dist, closest_tet, closest_coords = Tetmesh._bvh_lookup(args, pos)

    # guard on the guess: a null element index must not be used to index tet_vertex_indices
    if closest_tet == NULL_ELEMENT_INDEX and guess.element_index != NULL_ELEMENT_INDEX:
        # nothing found yet, bvh may not be available or outside mesh
        for v in range(4):
            vtx = args.tet_vertex_indices[guess.element_index, v]
            # CSR range of tets incident to this vertex of the guessed tet
            tet_beg = args.vertex_tet_offsets[vtx]
            tet_end = args.vertex_tet_offsets[vtx + 1]

            for t in range(tet_beg, tet_end):
                tet = args.vertex_tet_indices[t]
                dist, coords = Tetmesh._project_on_tet(args, pos, tet)
                if dist <= closest_dist:
                    closest_dist = dist
                    closest_tet = tet
                    closest_coords = coords

    return make_free_sample(closest_tet, closest_coords)
268
+
269
@cached_arg_value
def side_index_arg_value(self, device) -> SideIndexArg:
    """Assemble the boundary-side -> side index struct, with its array moved to `device`."""
    boundary_faces = self._boundary_face_indices.to(device)

    side_index_arg = self.SideIndexArg()
    side_index_arg.boundary_face_indices = boundary_faces
    return side_index_arg
276
+
277
@wp.func
def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
    """Boundary side to side index

    Returns the face (side) index stored for the ``boundary_side_index``-th boundary side.
    """

    return args.boundary_face_indices[boundary_side_index]
282
+
283
@cached_arg_value
def side_arg_value(self, device) -> SideArg:
    """Assemble the side argument struct for `device`.

    Embeds the (cached) cell argument struct and moves the face connectivity arrays to `device`.
    The return annotation is ``SideArg``: that is the struct actually built here.
    """
    args = self.SideArg()

    args.cell_arg = self.cell_arg_value(device)
    args.face_vertex_indices = self._face_vertex_indices.to(device)
    args.face_tet_indices = self._face_tet_indices.to(device)

    return args
292
+
293
@wp.func
def side_position(args: SideArg, s: Sample):
    """World-space position of sample ``s`` on face ``s.element_index``.

    ``s.element_coords`` are the barycentric weights of the three face vertices.
    """
    face_idx = args.face_vertex_indices[s.element_index]
    return (
        s.element_coords[0] * args.cell_arg.positions[face_idx[0]]
        + s.element_coords[1] * args.cell_arg.positions[face_idx[1]]
        + s.element_coords[2] * args.cell_arg.positions[face_idx[2]]
    )
301
+
302
@wp.func
def _side_vecs(args: SideArg, side_index: ElementIndex):
    """Return the two edge vectors ``(v1 - v0, v2 - v0)`` of face ``side_index``."""
    face_idx = args.face_vertex_indices[side_index]
    v0 = args.cell_arg.positions[face_idx[0]]
    v1 = args.cell_arg.positions[face_idx[1]]
    v2 = args.cell_arg.positions[face_idx[2]]

    return v1 - v0, v2 - v0
310
+
311
@wp.func
def side_deformation_gradient(args: SideArg, s: Sample):
    """3x2 matrix whose columns are the face edge vectors from ``_side_vecs`` (constant over the face)."""
    e1, e2 = Tetmesh._side_vecs(args, s.element_index)
    return wp.matrix_from_cols(e1, e2)
315
+
316
@wp.func
def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
    """Inner tet of face ``side_index`` (first component of ``face_tet_indices``)."""
    return arg.face_tet_indices[side_index][0]
319
+
320
@wp.func
def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
    """Outer tet of face ``side_index`` (second component of ``face_tet_indices``)."""
    return arg.face_tet_indices[side_index][1]
323
+
324
@wp.func
def face_to_tet_coords(args: SideArg, side_index: ElementIndex, tet_index: ElementIndex, side_coords: Coords):
    """Convert barycentric coordinates on face ``side_index`` into coordinates in tet ``tet_index``.

    Tet coordinates are the weights of tet vertices 1, 2 and 3 (vertex 0 is implicit). Each face
    vertex's weight is copied into the slot of the matching tet vertex. A face vertex that matches
    none of vertices 1..3 contributes only through the implicit vertex-0 weight.
    NOTE(review): assumes ``tet_index`` is a tet containing the face; this is not checked.
    """
    fvi = args.face_vertex_indices[side_index]

    tv1 = args.cell_arg.tet_vertex_indices[tet_index, 1]
    tv2 = args.cell_arg.tet_vertex_indices[tet_index, 2]
    tv3 = args.cell_arg.tet_vertex_indices[tet_index, 3]

    # explicit float constructors: these locals may be reassigned in the loop below
    c1 = float(0.0)
    c2 = float(0.0)
    c3 = float(0.0)

    for k in range(3):
        if tv1 == fvi[k]:
            c1 = side_coords[k]
        elif tv2 == fvi[k]:
            c2 = side_coords[k]
        elif tv3 == fvi[k]:
            c3 = side_coords[k]

    return Coords(c1, c2, c3)
345
+
346
@wp.func
def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    """Map face coordinates ``side_coords`` into the inner tet of face ``side_index``."""
    inner_cell_index = Tetmesh.side_inner_cell_index(args, side_index)
    return Tetmesh.face_to_tet_coords(args, side_index, inner_cell_index, side_coords)
350
+
351
@wp.func
def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    """Map face coordinates ``side_coords`` into the outer tet of face ``side_index``."""
    outer_cell_index = Tetmesh.side_outer_cell_index(args, side_index)
    return Tetmesh.face_to_tet_coords(args, side_index, outer_cell_index, side_coords)
355
+
356
@wp.func
def side_from_cell_coords(args: SideArg, side_index: ElementIndex, tet_index: ElementIndex, tet_coords: Coords):
    # Convert coordinates in tet `tet_index` into barycentric coordinates on face `side_index`.
    # `tet_coords` hold the weights of tet local vertices 1, 2, 3; the weight of local vertex 0 is
    # 1 - tet_coords[0] - tet_coords[1] - tet_coords[2].
    # Each face vertex picks up the weight of the tet vertex with the same global index (falling
    # back to local vertex 0 when it matches none of vertices 1-3).
    fvi = args.face_vertex_indices[side_index]

    tv1 = args.cell_arg.tet_vertex_indices[tet_index, 1]
    tv2 = args.cell_arg.tet_vertex_indices[tet_index, 2]
    tv3 = args.cell_arg.tet_vertex_indices[tet_index, 3]

    # Weight of face vertex 0
    if tv1 == fvi[0]:
        c0 = tet_coords[0]
    elif tv2 == fvi[0]:
        c0 = tet_coords[1]
    elif tv3 == fvi[0]:
        c0 = tet_coords[2]
    else:
        c0 = 1.0 - tet_coords[0] - tet_coords[1] - tet_coords[2]

    # Weight of face vertex 1
    if tv1 == fvi[1]:
        c1 = tet_coords[0]
    elif tv2 == fvi[1]:
        c1 = tet_coords[1]
    elif tv3 == fvi[1]:
        c1 = tet_coords[2]
    else:
        c1 = 1.0 - tet_coords[0] - tet_coords[1] - tet_coords[2]

    # Weight of face vertex 2
    if tv1 == fvi[2]:
        c2 = tet_coords[0]
    elif tv2 == fvi[2]:
        c2 = tet_coords[1]
    elif tv3 == fvi[2]:
        c2 = tet_coords[2]
    else:
        c2 = 1.0 - tet_coords[0] - tet_coords[1] - tet_coords[2]

    # The point lies on the face only if the weight of the tet vertex opposite the face is ~0,
    # i.e. the three face weights sum to ~1. The hard-coded 0.999 threshold is the tolerance;
    # otherwise the OUTSIDE sentinel coordinates are returned.
    return wp.where(c0 + c1 + c2 > 0.999, Coords(c0, c1, c2), Coords(OUTSIDE))
392
+
393
@wp.func
def side_to_cell_arg(side_arg: SideArg):
    # The side argument embeds the cell argument; hand that embedded value back.
    cell_arg = side_arg.cell_arg
    return cell_arg
396
+
397
def _build_topology(self, temporary_store: TemporaryStore):
    """Build the face (side) topology of the tetrahedral mesh.

    Assigns:
      - ``self._vertex_tet_offsets`` / ``self._vertex_tet_indices``: vertex -> tet adjacency,
        as returned by ``compress_node_indices``.
      - ``self._face_vertex_indices``: one ``vec3i`` per unique face, re-oriented by
        ``_flip_face_normals`` so that its normal does not point toward the centroid of the
        face's first (inner) tet.
      - ``self._face_tet_indices``: the tets sharing each face; both entries hold the same
        tet for boundary faces.
      - ``self._boundary_face_indices``: indices of the faces that belong to a single tet.

    Each face is owned by its smallest vertex. Its two other vertices are stored sorted, which
    lets the deduplication pass run one thread per vertex over a private segment of the
    scratch buffers.

    Args:
        temporary_store: store the scratch buffers are borrowed from; every buffer borrowed
            directly by this method is released before returning.
    """
    from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
    from warp.utils import array_scan

    device = self.tet_vertex_indices.device

    vertex_tet_offsets, vertex_tet_indices = compress_node_indices(
        self.vertex_count(), self.tet_vertex_indices, temporary_store=temporary_store
    )
    self._vertex_tet_offsets = vertex_tet_offsets.detach()
    self._vertex_tet_indices = vertex_tet_indices.detach()

    vertex_start_face_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
    vertex_start_face_count.array.zero_()
    vertex_start_face_offsets = borrow_temporary_like(vertex_start_face_count, temporary_store=temporary_store)

    # Every tet contributes 4 faces, which bounds the number of non-deduplicated faces
    vertex_face_other_vs = borrow_temporary(
        temporary_store, dtype=wp.vec2i, device=device, shape=(4 * self.cell_count())
    )
    vertex_face_tets = borrow_temporary(temporary_store, dtype=int, device=device, shape=(4 * self.cell_count(), 2))

    # Count the faces (duplicates included) owned by each vertex, i.e. whose smallest vertex it is
    wp.launch(
        kernel=Tetmesh._count_starting_faces_kernel,
        device=device,
        dim=self.cell_count(),
        inputs=[self.tet_vertex_indices, vertex_start_face_count.array],
    )

    array_scan(in_array=vertex_start_face_count.array, out_array=vertex_start_face_offsets.array, inclusive=False)

    # Deduplicate faces shared by two tets. The kernel overwrites the per-vertex counts in place
    # with unique face counts, hence the alias.
    vertex_unique_face_count = vertex_start_face_count
    wp.launch(
        kernel=Tetmesh._count_unique_starting_faces_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            self._vertex_tet_offsets,
            self._vertex_tet_indices,
            self.tet_vertex_indices,
            vertex_start_face_offsets.array,
            vertex_unique_face_count.array,
            vertex_face_other_vs.array,
            vertex_face_tets.array,
        ],
    )

    vertex_unique_face_offsets = borrow_temporary_like(vertex_start_face_offsets, temporary_store=temporary_store)
    array_scan(in_array=vertex_unique_face_count.array, out_array=vertex_unique_face_offsets.array, inclusive=False)

    # Total number of unique faces. The scan is exclusive, so its last entry omits the last
    # vertex's own count. That count is zero: the highest-index vertex is never the smallest
    # vertex of a face with distinct vertices.
    face_count = int(
        host_read_at_index(
            vertex_unique_face_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
        )
    )

    self._face_vertex_indices = wp.empty(shape=(face_count,), dtype=wp.vec3i, device=device)
    self._face_tet_indices = wp.empty(shape=(face_count,), dtype=wp.vec2i, device=device)

    boundary_mask = borrow_temporary(temporary_store, shape=(face_count,), dtype=int, device=device)

    # Copy the unique faces of each vertex into the compact output arrays and flag boundary faces
    wp.launch(
        kernel=Tetmesh._compress_faces_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            vertex_start_face_offsets.array,
            vertex_unique_face_offsets.array,
            vertex_unique_face_count.array,
            vertex_face_other_vs.array,
            vertex_face_tets.array,
            self._face_vertex_indices,
            self._face_tet_indices,
            boundary_mask.array,
        ],
    )

    vertex_start_face_offsets.release()
    vertex_unique_face_offsets.release()
    vertex_unique_face_count.release()  # also releases vertex_start_face_count (same temporary)
    vertex_face_other_vs.release()
    vertex_face_tets.release()

    # Orient face normals away from the first (inner) tet
    wp.launch(
        kernel=Tetmesh._flip_face_normals,
        device=device,
        dim=self.side_count(),
        inputs=[self._face_vertex_indices, self._face_tet_indices, self.tet_vertex_indices, self.positions],
    )

    boundary_face_indices, _ = masked_indices(boundary_mask.array)
    self._boundary_face_indices = boundary_face_indices.detach()

    # The mask is no longer needed once the boundary indices have been extracted
    boundary_mask.release()
493
+
494
def _compute_tet_edges(self, temporary_store: Optional[TemporaryStore] = None):
    """Enumerate the unique edges of the mesh and the edge indices of every tet.

    Assigns ``self._edge_count`` and ``self._tet_edge_indices``, a ``(cell_count, 6)`` int array.
    For tet ``t`` and ``k`` in 0..2:
      - column ``k`` holds the edge between local vertices ``k`` and ``(k + 1) % 3``;
      - column ``k + 3`` holds the edge between local vertices ``k`` and ``3``.

    Each edge is owned by its smaller endpoint, so deduplication runs one thread per vertex
    over that vertex's private segment of a scratch buffer.

    Reads the vertex -> tet adjacency (``self._vertex_tet_offsets`` and
    ``self._vertex_tet_indices``) assigned by ``_build_topology``, which must have run first.

    Args:
        temporary_store: optional store the scratch buffers are borrowed from; every buffer
            borrowed directly by this method is released before returning.
    """
    from warp.fem.utils import host_read_at_index
    from warp.utils import array_scan

    device = self.tet_vertex_indices.device

    vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
    vertex_start_edge_count.array.zero_()
    vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)

    # Every tet has 6 edges, which bounds the number of non-deduplicated edges
    vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(6 * self.cell_count()))

    # Count the tet edges (duplicates included) owned by each vertex, i.e. whose smaller endpoint it is
    wp.launch(
        kernel=Tetmesh._count_starting_edges_kernel,
        device=device,
        dim=self.cell_count(),
        inputs=[self.tet_vertex_indices, vertex_start_edge_count.array],
    )

    array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)

    # Deduplicate edges shared by several tets. The kernel overwrites the per-vertex counts in
    # place with unique edge counts, hence the alias.
    vertex_unique_edge_count = vertex_start_edge_count
    wp.launch(
        kernel=Tetmesh._count_unique_starting_edges_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            self._vertex_tet_offsets,
            self._vertex_tet_indices,
            self.tet_vertex_indices,
            vertex_start_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
        ],
    )

    vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
    array_scan(in_array=vertex_unique_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)

    # Total number of unique edges. The scan is exclusive, so its last entry omits the last
    # vertex's own count. That count is zero: the highest-index vertex is never the smaller
    # endpoint of an edge between distinct vertices.
    self._edge_count = int(
        host_read_at_index(
            vertex_unique_edge_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
        )
    )

    self._tet_edge_indices = wp.empty(dtype=int, device=device, shape=(self.cell_count(), 6))

    # Resolve each tet edge to its index in the compact unique-edge numbering
    wp.launch(
        kernel=Tetmesh._compress_edges_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            self._vertex_tet_offsets,
            self._vertex_tet_indices,
            self.tet_vertex_indices,
            vertex_start_edge_offsets.array,
            vertex_unique_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
            self._tet_edge_indices,
        ],
    )

    vertex_start_edge_offsets.release()
    vertex_unique_edge_offsets.release()
    vertex_unique_edge_count.release()  # also releases vertex_start_edge_count (same temporary)
    vertex_edge_ends.release()
569
+
570
@wp.kernel
def _count_starting_faces_kernel(
    tet_vertex_indices: wp.array2d(dtype=int), vertex_start_face_count: wp.array(dtype=int)
):
    # One thread per tet. For each of the tet's 4 faces, increment the counter of the face's
    # smallest vertex. Faces shared by two tets are counted once per tet; they are deduplicated
    # later by _count_unique_starting_faces_kernel.
    t = wp.tid()
    # Local vertex triples (k, k+1, k+2) mod 4 enumerate the four faces of the tet
    for k in range(4):
        vi = wp.vec3i(
            tet_vertex_indices[t, k], tet_vertex_indices[t, (k + 1) % 4], tet_vertex_indices[t, (k + 2) % 4]
        )
        vm = wp.min(vi)

        # NOTE(review): adds 1 per face vertex equal to the minimum. That is exactly once when the
        # face's vertices are distinct; only degenerate tets are over-counted, which presumably just
        # enlarges the scratch reservation.
        for i in range(3):
            if vm == vi[i]:
                wp.atomic_add(vertex_start_face_count, vm, 1)
584
+
585
@wp.func
def _find_face(
    needle: wp.vec2i,
    values: wp.array(dtype=wp.vec2i),
    beg: int,
    end: int,
):
    # Linear search for `needle` in values[beg:end].
    # Returns its absolute index in `values`, or -1 when it is absent.
    for pos in range(beg, end):
        if values[pos] == needle:
            return pos

    return -1
597
+
598
@wp.kernel
def _count_unique_starting_faces_kernel(
    vertex_tet_offsets: wp.array(dtype=int),
    vertex_tet_indices: wp.array(dtype=int),
    tet_vertex_indices: wp.array2d(dtype=int),
    vertex_start_face_offsets: wp.array(dtype=int),
    vertex_start_face_count: wp.array(dtype=int),
    face_other_vs: wp.array(dtype=wp.vec2i),
    face_tets: wp.array2d(dtype=int),
):
    # One thread per vertex v. Walk the tets containing v and, for every tet face whose smallest
    # vertex is v, record the face's two other vertices (sorted) in v's private segment of the
    # scratch buffers starting at vertex_start_face_offsets[v], skipping faces already recorded.
    # On exit, vertex_start_face_count[v] is overwritten with the number of unique faces owned by v.
    v = wp.tid()

    face_beg = vertex_start_face_offsets[v]

    tet_beg = vertex_tet_offsets[v]
    tet_end = vertex_tet_offsets[v + 1]

    face_cur = face_beg

    for tet in range(tet_beg, tet_end):
        t = vertex_tet_indices[tet]

        for k in range(4):
            vi = wp.vec3i(
                tet_vertex_indices[t, k], tet_vertex_indices[t, (k + 1) % 4], tet_vertex_indices[t, (k + 2) % 4]
            )
            min_v = wp.min(vi)

            if v == min_v:
                max_v = wp.max(vi)
                # Middle vertex = sum minus the two extremes
                mid_v = vi[0] + vi[1] + vi[2] - min_v - max_v
                other_v = wp.vec2i(mid_v, max_v)

                # Check if other_v has been seen
                seen_idx = Tetmesh._find_face(other_v, face_other_vs, face_beg, face_cur)

                if seen_idx == -1:
                    # First occurrence: store t in both tet slots, so that a face never seen
                    # again (a boundary face) ends up with identical slots
                    face_other_vs[face_cur] = other_v
                    face_tets[face_cur, 0] = t
                    face_tets[face_cur, 1] = t
                    face_cur += 1
                else:
                    # Second occurrence: record the neighboring tet
                    face_tets[seen_idx, 1] = t

    vertex_start_face_count[v] = face_cur - face_beg
643
+
644
@wp.kernel
def _compress_faces_kernel(
    vertex_start_face_offsets: wp.array(dtype=int),
    vertex_unique_face_offsets: wp.array(dtype=int),
    vertex_unique_face_count: wp.array(dtype=int),
    uncompressed_face_other_vs: wp.array(dtype=wp.vec2i),
    uncompressed_face_tets: wp.array2d(dtype=int),
    face_vertex_indices: wp.array(dtype=wp.vec3i),
    face_tet_indices: wp.array(dtype=wp.vec2i),
    boundary_mask: wp.array(dtype=int),
):
    # One thread per vertex v. Copy the first `unique_count` entries of v's scratch segment
    # (the deduplicated faces owned by v) to the compact output starting at
    # vertex_unique_face_offsets[v]. A face is flagged as boundary (mask 1) when both of its
    # tet slots hold the same tet.
    v = wp.tid()

    start_beg = vertex_start_face_offsets[v]
    unique_beg = vertex_unique_face_offsets[v]
    unique_count = vertex_unique_face_count[v]

    for f in range(unique_count):
        src_index = start_beg + f
        face_index = unique_beg + f

        face_vertex_indices[face_index] = wp.vec3i(
            v,
            uncompressed_face_other_vs[src_index][0],
            uncompressed_face_other_vs[src_index][1],
        )

        t0 = uncompressed_face_tets[src_index, 0]
        t1 = uncompressed_face_tets[src_index, 1]
        face_tet_indices[face_index] = wp.vec2i(t0, t1)
        if t0 == t1:
            boundary_mask[face_index] = 1
        else:
            boundary_mask[face_index] = 0
678
+
679
@wp.kernel
def _flip_face_normals(
    face_vertex_indices: wp.array(dtype=wp.vec3i),
    face_tet_indices: wp.array(dtype=wp.vec2i),
    tet_vertex_indices: wp.array2d(dtype=int),
    positions: wp.array(dtype=wp.vec3),
):
    # One thread per face. Orient the face so that its normal, cross(v1 - v0, v2 - v0), does not
    # point toward the centroid of its first (inner) tet. Orientation is flipped by swapping the
    # face's last two vertex indices.
    e = wp.tid()

    tet = face_tet_indices[e][0]

    tet_vidx = tet_vertex_indices[tet]
    face_vidx = face_vertex_indices[e]

    tet_centroid = (
        positions[tet_vidx[0]] + positions[tet_vidx[1]] + positions[tet_vidx[2]] + positions[tet_vidx[3]]
    ) / 4.0

    v0 = positions[face_vidx[0]]
    v1 = positions[face_vidx[1]]
    v2 = positions[face_vidx[2]]

    face_center = (v1 + v0 + v2) / 3.0
    face_normal = wp.cross(v1 - v0, v2 - v0)

    # if face normal points toward first tet centroid, flip indices
    if wp.dot(tet_centroid - face_center, face_normal) > 0.0:
        face_vertex_indices[e] = wp.vec3i(face_vidx[0], face_vidx[2], face_vidx[1])
707
+
708
@wp.kernel
def _count_starting_edges_kernel(
    tri_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
):
    # One thread per tet. For each of the tet's 6 edges, increment the counter of the edge's
    # smaller endpoint. Edges shared by several tets are counted once per tet; they are
    # deduplicated later by _count_unique_starting_edges_kernel.
    # NOTE(review): despite its name, `tri_vertex_indices` receives the (cell_count, 4) tet vertex
    # indices (it is launched with self.tet_vertex_indices); the name looks like a leftover.
    t = wp.tid()
    # Edges between local vertices (k, (k + 1) % 3), k in 0..2
    for k in range(3):
        v0 = tri_vertex_indices[t, k]
        v1 = tri_vertex_indices[t, (k + 1) % 3]

        if v0 < v1:
            wp.atomic_add(vertex_start_edge_count, v0, 1)
        else:
            wp.atomic_add(vertex_start_edge_count, v1, 1)

    # Edges between local vertices (k, 3), k in 0..2
    for k in range(3):
        v0 = tri_vertex_indices[t, k]
        v1 = tri_vertex_indices[t, 3]

        if v0 < v1:
            wp.atomic_add(vertex_start_edge_count, v0, 1)
        else:
            wp.atomic_add(vertex_start_edge_count, v1, 1)
730
+
731
@wp.func
def _find_edge(
    needle: int,
    values: wp.array(dtype=int),
    beg: int,
    end: int,
):
    # Linear search for `needle` in values[beg:end].
    # Returns its absolute index in `values`, or -1 when it is absent.
    for pos in range(beg, end):
        if values[pos] == needle:
            return pos

    return -1
743
+
744
@wp.kernel
def _count_unique_starting_edges_kernel(
    vertex_tet_offsets: wp.array(dtype=int),
    vertex_tet_indices: wp.array(dtype=int),
    tet_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_start_edge_count: wp.array(dtype=int),
    edge_ends: wp.array(dtype=int),
):
    # One thread per vertex v. Walk the tets containing v and, for every tet edge whose smaller
    # endpoint is v, record the larger endpoint in v's private segment of `edge_ends` (starting at
    # vertex_start_edge_offsets[v]), skipping endpoints already recorded. On exit,
    # vertex_start_edge_count[v] is overwritten with the number of unique edges owned by v.
    v = wp.tid()

    edge_beg = vertex_start_edge_offsets[v]

    tet_beg = vertex_tet_offsets[v]
    tet_end = vertex_tet_offsets[v + 1]

    edge_cur = edge_beg

    for tet in range(tet_beg, tet_end):
        t = vertex_tet_indices[tet]

        # Edges between local vertices (k, (k + 1) % 3)
        for k in range(3):
            v0 = tet_vertex_indices[t, k]
            v1 = tet_vertex_indices[t, (k + 1) % 3]

            if v == wp.min(v0, v1):
                other_v = wp.max(v0, v1)
                if Tetmesh._find_edge(other_v, edge_ends, edge_beg, edge_cur) == -1:
                    edge_ends[edge_cur] = other_v
                    edge_cur += 1

        # Edges between local vertices (k, 3)
        for k in range(3):
            v0 = tet_vertex_indices[t, k]
            v1 = tet_vertex_indices[t, 3]

            if v == wp.min(v0, v1):
                other_v = wp.max(v0, v1)
                if Tetmesh._find_edge(other_v, edge_ends, edge_beg, edge_cur) == -1:
                    edge_ends[edge_cur] = other_v
                    edge_cur += 1

    vertex_start_edge_count[v] = edge_cur - edge_beg
786
+
787
@wp.kernel
def _compress_edges_kernel(
    vertex_tet_offsets: wp.array(dtype=int),
    vertex_tet_indices: wp.array(dtype=int),
    tet_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_count: wp.array(dtype=int),
    uncompressed_edge_ends: wp.array(dtype=int),
    tet_edge_indices: wp.array2d(dtype=int),
):
    # One thread per vertex v. For every tet edge whose smaller endpoint is v, look up the larger
    # endpoint among the `unique_count` deduplicated entries of v's scratch segment. Its position in
    # that segment, shifted by vertex_unique_edge_offsets[v], is the edge's global index.
    # tet_edge_indices[t, k] receives edge (k, (k + 1) % 3); tet_edge_indices[t, k + 3] receives edge (k, 3).
    # Each tet edge is written by exactly one thread: the one for its smaller endpoint.
    v = wp.tid()

    uncompressed_beg = vertex_start_edge_offsets[v]

    unique_beg = vertex_unique_edge_offsets[v]
    unique_count = vertex_unique_edge_count[v]

    tet_beg = vertex_tet_offsets[v]
    tet_end = vertex_tet_offsets[v + 1]

    for tet in range(tet_beg, tet_end):
        t = vertex_tet_indices[tet]

        # Edges between local vertices (k, (k + 1) % 3)
        for k in range(3):
            v0 = tet_vertex_indices[t, k]
            v1 = tet_vertex_indices[t, (k + 1) % 3]

            if v == wp.min(v0, v1):
                other_v = wp.max(v0, v1)
                edge_id = (
                    Tetmesh._find_edge(
                        other_v, uncompressed_edge_ends, uncompressed_beg, uncompressed_beg + unique_count
                    )
                    - uncompressed_beg
                    + unique_beg
                )
                # Multi-index store, consistent with the other 2D array accesses; avoids building a row view
                tet_edge_indices[t, k] = edge_id

        # Edges between local vertices (k, 3)
        for k in range(3):
            v0 = tet_vertex_indices[t, k]
            v1 = tet_vertex_indices[t, 3]

            if v == wp.min(v0, v1):
                other_v = wp.max(v0, v1)
                edge_id = (
                    Tetmesh._find_edge(
                        other_v, uncompressed_edge_ends, uncompressed_beg, uncompressed_beg + unique_count
                    )
                    - uncompressed_beg
                    + unique_beg
                )
                tet_edge_indices[t, k + 3] = edge_id
840
+
841
@wp.kernel
def _compute_tet_bounds(
    tet_vertex_indices: wp.array2d(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    lowers: wp.array(dtype=wp.vec3),
    uppers: wp.array(dtype=wp.vec3),
):
    # One thread per tet: write the component-wise minimum / maximum of its four vertex
    # positions, i.e. the tet's axis-aligned bounding box, to lowers[t] / uppers[t].
    t = wp.tid()
    p0 = positions[tet_vertex_indices[t, 0]]
    p1 = positions[tet_vertex_indices[t, 1]]
    p2 = positions[tet_vertex_indices[t, 2]]
    p3 = positions[tet_vertex_indices[t, 3]]

    lowers[t] = wp.min(wp.min(p0, p1), wp.min(p2, p3))
    uppers[t] = wp.max(wp.max(p0, p1), wp.max(p2, p3))