warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,956 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import warp as wp
19
+ from warp._src.fem.cache import (
20
+ TemporaryStore,
21
+ borrow_temporary,
22
+ borrow_temporary_like,
23
+ )
24
+ from warp._src.fem.types import OUTSIDE, Coords, ElementIndex, Sample
25
+
26
+ from .element import Element
27
+ from .geometry import Geometry
28
+
29
+ _wp_module_name_ = "warp.fem.geometry.hexmesh"
30
+
31
+
32
@wp.struct
class HexmeshCellArg:
    """Device-side arguments for evaluating per-cell quantities of a :class:`Hexmesh`."""

    # Vertex indices of each hexahedron, one row of 8 indices per hex
    hex_vertex_indices: wp.array2d(dtype=int)
    # 3d position of each mesh vertex
    positions: wp.array(dtype=wp.vec3)

    # Identifier of the hex BVH, used for global cell lookup
    hex_bvh: wp.uint64
+
40
+
41
@wp.struct
class HexmeshSideArg:
    """Device-side arguments for evaluating per-side (face) quantities of a :class:`Hexmesh`."""

    # Cell-level arguments (hex vertex indices, vertex positions, BVH id)
    cell_arg: HexmeshCellArg
    # The four vertex indices of each face
    face_vertex_indices: wp.array(dtype=wp.vec4i)
    # The two hexes adjacent to each face (presumably inner/outer — confirm in topology construction)
    face_hex_indices: wp.array(dtype=wp.vec2i)
    # Per-face orientation data relative to the adjacent hexes (layout set by topology construction)
    face_hex_face_orientation: wp.array(dtype=wp.vec4i)
47
+
48
+
49
# Hex vertices bounding each face, in rows ordered x = 0, x = 1, y = 0, y = 1, z = 0, z = 1.
# With the reference vertex layout 0=(0,0,0) 1=(1,0,0) 2=(1,1,0) 3=(0,1,0) 4=(0,0,1) 5=(1,0,1) 6=(1,1,1) 7=(0,1,1),
# each face is wound counter-clockwise when seen from outside the cell (right-hand rule gives the outward normal).
FACE_VERTEX_INDICES = wp.constant(
    wp.mat(shape=(6, 4), dtype=int)(
        [
            [0, 4, 7, 3],  # x = 0  (first edge along +z)
            [1, 2, 6, 5],  # x = 1  (first edge along +y)
            [0, 1, 5, 4],  # y = 0  (first edge along +x)
            [3, 7, 6, 2],  # y = 1  (first edge along +z)
            [0, 3, 2, 1],  # z = 0  (first edge along +y)
            [4, 5, 6, 7],  # z = 1  (first edge along +x)
        ]
    )
)
61
+
62
# Vertex pairs of the 12 hex edges, each pair ordered along the positive axis direction.
# Rows 0-3 lie on the z = 0 face, rows 4-7 on the z = 1 face, rows 8-11 are parallel to z.
EDGE_VERTEX_INDICES = wp.constant(
    wp.mat(shape=(12, 2), dtype=int)(
        [
            # z = 0 face
            [0, 1],  # +x, y = 0
            [1, 2],  # +y, x = 1
            [3, 2],  # +x, y = 1
            [0, 3],  # +y, x = 0
            # z = 1 face
            [4, 5],  # +x, y = 0
            [5, 6],  # +y, x = 1
            [7, 6],  # +x, y = 1
            [4, 7],  # +y, x = 0
            # vertical edges
            [0, 4],  # +z at (0, 0)
            [1, 5],  # +z at (1, 0)
            [2, 6],  # +z at (1, 1)
            [3, 7],  # +z at (0, 1)
        ]
    )
)
80
+
81
# Orthogonal 2x2 transforms for face coordinates, selected by the face's first vertex and winding.
# Entry k = 2 * first_vertex + (0 if det == +1 else 1) occupies rows 2k and 2k + 1.
FACE_ORIENTATION = [
    # first vertex 0
    [1, 0], [0, 1],  # det +1
    [0, 1], [1, 0],  # det -1
    # first vertex 1
    [0, -1], [1, 0],  # det +1
    [-1, 0], [0, 1],  # det -1
    # first vertex 2
    [-1, 0], [0, -1],  # det +1
    [0, -1], [-1, 0],  # det -1
    # first vertex 3
    [0, 1], [-1, 0],  # det +1
    [1, 0], [0, -1],  # det -1
]
102
+
103
# Face-local (u, v) corner of each face vertex, in the same order as the bilinear weights of
# `side_position`: vertex 0 -> (0, 0), 1 -> (1, 0), 2 -> (1, 1), 3 -> (0, 1)
FACE_TRANSLATION = [[0, 0], [1, 0], [1, 1], [0, 1]]
109
+
110
# Local coordinate frame of each face, rows ordered as in FACE_VERTEX_INDICES (x = 0, x = 1, y = 0, ...):
#   columns 0-1: cell coordinate axes used as the face's (u, v); u x v points out of the cell
#   column 2:    cell coordinate axis normal to the face
#   column 3:    end of that axis the face lies on (0 or 1)
# The trailing comments give u, v and (presumably) the signed outward offset from the face.
_FACE_COORD_INDICES = wp.constant(
    wp.mat(shape=(6, 4), dtype=int)(
        [
            [2, 1, 0, 0],  # x = 0: u = z, v = y, offset -x
            [1, 2, 0, 1],  # x = 1: u = y, v = z, offset x-1
            [0, 2, 1, 0],  # y = 0: u = x, v = z, offset -y
            [2, 0, 1, 1],  # y = 1: u = z, v = x, offset y-1
            [1, 0, 2, 0],  # z = 0: u = y, v = x, offset -z
            [0, 1, 2, 1],  # z = 1: u = x, v = y, offset z-1
        ]
    )
)
123
+
124
# Floating-point Warp constant copies of the integer face tables:
# 16x2 orientation rows (two per entry) and 4x2 corner translations
_FACE_ORIENTATION_F = wp.constant(wp.mat(shape=(16, 2), dtype=float)(FACE_ORIENTATION))
_FACE_TRANSLATION_F = wp.constant(wp.mat(shape=(4, 2), dtype=float)(FACE_TRANSLATION))
126
+
127
+
128
+ class Hexmesh(Geometry):
129
+ """Hexahedral mesh geometry"""
130
+
131
+ dimension = 3
132
+
133
+ def __init__(
134
+ self,
135
+ hex_vertex_indices: wp.array,
136
+ positions: wp.array,
137
+ assume_parallelepiped_cells=False,
138
+ build_bvh: bool = False,
139
+ temporary_store: Optional[TemporaryStore] = None,
140
+ ):
141
+ """
142
+ Constructs a hexahedral mesh.
143
+
144
+ Args:
145
+ hex_vertex_indices: warp array of shape (num_hexes, 8) containing vertex indices for each hex
146
+ following standard ordering (bottom face vertices in counter-clockwise order, then similarly for upper face)
147
+ positions: warp array of shape (num_vertices, 3) containing 3d position for each vertex
148
+ assume_parallelepiped: If true, assume that all cells are parallelepipeds (cheaper position/gradient evaluations)
149
+ build_bvh: Whether to also build the hex BVH, which is necessary for the global `fem.lookup` operator
150
+ temporary_store: shared pool from which to allocate temporary arrays
151
+ """
152
+
153
+ self.hex_vertex_indices = hex_vertex_indices
154
+ self.positions = positions
155
+ self.parallelepiped_cells = assume_parallelepiped_cells
156
+
157
+ self._face_vertex_indices: wp.array = None
158
+ self._face_hex_indices: wp.array = None
159
+ self._face_hex_face_orientation: wp.array = None
160
+ self._vertex_hex_offsets: wp.array = None
161
+ self._vertex_hex_indices: wp.array = None
162
+ self._hex_edge_indices: wp.array = None
163
+ self._edge_count = 0
164
+ self._build_topology(temporary_store)
165
+
166
+ # Use cheaper variants if we know that cells are parallelepipeds (i.e. linearly transformed)
167
+ # (Cells only, not as much difference for sides)
168
+ self.cell_position = (
169
+ self._cell_position_parallelepiped if assume_parallelepiped_cells else self._cell_position_generic
170
+ )
171
+ self.cell_deformation_gradient = (
172
+ self._cell_deformation_gradient_parallelepiped
173
+ if assume_parallelepiped_cells
174
+ else self._cell_deformation_gradient_generic
175
+ )
176
+
177
+ self._make_default_dependent_implementations()
178
+ self.cell_coordinates = self._make_cell_coordinates(assume_linear=assume_parallelepiped_cells)
179
+ self.side_coordinates = self._make_side_coordinates(assume_linear=assume_parallelepiped_cells)
180
+ self.cell_closest_point = self._make_cell_closest_point(assume_linear=assume_parallelepiped_cells)
181
+ self.side_closest_point = self._make_side_closest_point(assume_linear=assume_parallelepiped_cells)
182
+
183
+ if build_bvh:
184
+ self.build_bvh(self.positions.device)
185
+
186
+ def cell_count(self):
187
+ return self.hex_vertex_indices.shape[0]
188
+
189
+ def vertex_count(self):
190
+ return self.positions.shape[0]
191
+
192
+ def side_count(self):
193
+ return self._face_vertex_indices.shape[0]
194
+
195
+ def edge_count(self):
196
+ if self._hex_edge_indices is None:
197
+ self._compute_hex_edges()
198
+ return self._edge_count
199
+
200
+ def boundary_side_count(self):
201
+ return self._boundary_face_indices.shape[0]
202
+
203
+ def reference_cell(self) -> Element:
204
+ return Element.CUBE
205
+
206
+ def reference_side(self) -> Element:
207
+ return Element.SQUARE
208
+
209
+ @property
210
+ def hex_edge_indices(self) -> wp.array:
211
+ if self._hex_edge_indices is None:
212
+ self._compute_hex_edges()
213
+ return self._hex_edge_indices
214
+
215
+ @property
216
+ def face_hex_indices(self) -> wp.array:
217
+ return self._face_hex_indices
218
+
219
+ @property
220
+ def face_vertex_indices(self) -> wp.array:
221
+ return self._face_vertex_indices
222
+
223
+ CellArg = HexmeshCellArg
224
+ SideArg = HexmeshSideArg
225
+
226
+ @wp.struct
227
+ class SideIndexArg:
228
+ boundary_face_indices: wp.array(dtype=int)
229
+
230
+ # Geometry device interface
231
+
232
+ def fill_cell_arg(self, args: CellArg, device):
233
+ args.hex_vertex_indices = self.hex_vertex_indices.to(device)
234
+ args.positions = self.positions.to(device)
235
+ args.hex_bvh = self.bvh_id(device)
236
+
237
+ @wp.func
238
+ def _cell_position_generic(args: CellArg, s: Sample):
239
+ hex_idx = args.hex_vertex_indices[s.element_index]
240
+
241
+ w_p = s.element_coords
242
+ w_m = Coords(1.0) - s.element_coords
243
+
244
+ # 0 : m m m
245
+ # 1 : p m m
246
+ # 2 : p p m
247
+ # 3 : m p m
248
+ # 4 : m m p
249
+ # 5 : p m p
250
+ # 6 : p p p
251
+ # 7 : m p p
252
+
253
+ return (
254
+ w_m[0] * w_m[1] * w_m[2] * args.positions[hex_idx[0]]
255
+ + w_p[0] * w_m[1] * w_m[2] * args.positions[hex_idx[1]]
256
+ + w_p[0] * w_p[1] * w_m[2] * args.positions[hex_idx[2]]
257
+ + w_m[0] * w_p[1] * w_m[2] * args.positions[hex_idx[3]]
258
+ + w_m[0] * w_m[1] * w_p[2] * args.positions[hex_idx[4]]
259
+ + w_p[0] * w_m[1] * w_p[2] * args.positions[hex_idx[5]]
260
+ + w_p[0] * w_p[1] * w_p[2] * args.positions[hex_idx[6]]
261
+ + w_m[0] * w_p[1] * w_p[2] * args.positions[hex_idx[7]]
262
+ )
263
+
264
+ @wp.func
265
+ def _cell_position_parallelepiped(args: CellArg, s: Sample):
266
+ hex_idx = args.hex_vertex_indices[s.element_index]
267
+ w = s.element_coords
268
+ p0 = args.positions[hex_idx[0]]
269
+ p1 = args.positions[hex_idx[1]]
270
+ p2 = args.positions[hex_idx[3]]
271
+ p3 = args.positions[hex_idx[4]]
272
+ return w[0] * p1 + w[1] * p2 + w[2] * p3 + (1.0 - w[0] - w[1] - w[2]) * p0
273
+
@wp.func
def _cell_deformation_gradient_generic(cell_arg: CellArg, s: Sample):
    """Deformation gradient (Jacobian of the trilinear map) at `s.element_coords`.

    Sums, over the 8 hex vertices, the outer product of the vertex position
    with the gradient of that vertex's trilinear weight with respect to the
    element coordinates. Column ``j`` of the result is therefore the derivative
    of the world position with respect to element coordinate ``j``.
    """
    hex_idx = cell_arg.hex_vertex_indices[s.element_index]

    # w_p[i] = coords[i], w_m[i] = 1 - coords[i]
    w_p = s.element_coords
    w_m = Coords(1.0) - s.element_coords

    # Each vec3 is d(weight_vertex)/d(coords) for the corresponding vertex
    return (
        wp.outer(cell_arg.positions[hex_idx[0]], wp.vec3(-w_m[1] * w_m[2], -w_m[0] * w_m[2], -w_m[0] * w_m[1]))
        + wp.outer(cell_arg.positions[hex_idx[1]], wp.vec3(w_m[1] * w_m[2], -w_p[0] * w_m[2], -w_p[0] * w_m[1]))
        + wp.outer(cell_arg.positions[hex_idx[2]], wp.vec3(w_p[1] * w_m[2], w_p[0] * w_m[2], -w_p[0] * w_p[1]))
        + wp.outer(cell_arg.positions[hex_idx[3]], wp.vec3(-w_p[1] * w_m[2], w_m[0] * w_m[2], -w_m[0] * w_p[1]))
        + wp.outer(cell_arg.positions[hex_idx[4]], wp.vec3(-w_m[1] * w_p[2], -w_m[0] * w_p[2], w_m[0] * w_m[1]))
        + wp.outer(cell_arg.positions[hex_idx[5]], wp.vec3(w_m[1] * w_p[2], -w_p[0] * w_p[2], w_p[0] * w_m[1]))
        + wp.outer(cell_arg.positions[hex_idx[6]], wp.vec3(w_p[1] * w_p[2], w_p[0] * w_p[2], w_p[0] * w_p[1]))
        + wp.outer(cell_arg.positions[hex_idx[7]], wp.vec3(-w_p[1] * w_p[2], w_m[0] * w_p[2], w_m[0] * w_p[1]))
    )
292
+
@wp.func
def _cell_deformation_gradient_parallelepiped(cell_arg: CellArg, s: Sample):
    """Constant deformation gradient of a parallelepiped hexahedron.

    Columns are the edge vectors from hex vertex 0 to hex vertices 1, 3 and 4.
    Only `s.element_index` is read; the element coordinates do not matter.
    """
    hex_idx = cell_arg.hex_vertex_indices[s.element_index]

    p0 = cell_arg.positions[hex_idx[0]]
    p1 = cell_arg.positions[hex_idx[1]]
    # Note: local names p2 / p3 hold hex vertices 3 and 4
    p2 = cell_arg.positions[hex_idx[3]]
    p3 = cell_arg.positions[hex_idx[4]]
    return wp.matrix_from_cols(p1 - p0, p2 - p0, p3 - p0)
303
+
def fill_side_index_arg(self, args: SideIndexArg, device):
    """Populate `args` with this geometry's boundary-face index array, moved to `device`."""
    boundary_faces = self._boundary_face_indices
    args.boundary_face_indices = boundary_faces.to(device)
306
+
@wp.func
def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
    """Map the `boundary_side_index`-th boundary side to its global side (face) index."""

    return args.boundary_face_indices[boundary_side_index]
312
+
def fill_side_arg(self, args: SideArg, device):
    """Populate side-level kernel arguments on `device`.

    The embedded cell arguments are filled first. Each face connectivity array
    is then copied from the same-named, underscore-prefixed attribute of this
    geometry, in a fixed order.
    """
    self.fill_cell_arg(args.cell_arg, device)
    for name in ("face_vertex_indices", "face_hex_indices", "face_hex_face_orientation"):
        setattr(args, name, getattr(self, "_" + name).to(device))
318
+
@wp.func
def side_position(args: SideArg, s: Sample):
    """World position of sample `s` on a face, by bilinear interpolation.

    Face vertices 0..3 sit at face coordinates (0,0), (1,0), (1,1) and (0,1).
    Only the first two components of `s.element_coords` contribute.
    """
    face_idx = args.face_vertex_indices[s.element_index]

    w_p = s.element_coords
    w_m = Coords(1.0) - s.element_coords

    return (
        w_m[0] * w_m[1] * args.cell_arg.positions[face_idx[0]]
        + w_p[0] * w_m[1] * args.cell_arg.positions[face_idx[1]]
        + w_p[0] * w_p[1] * args.cell_arg.positions[face_idx[2]]
        + w_m[0] * w_p[1] * args.cell_arg.positions[face_idx[3]]
    )
332
+
@wp.func
def _side_deformation_vecs(args: SideArg, side_index: ElementIndex, coords: Coords):
    """Tangent vectors of the bilinear face map at `coords`.

    Returns ``(v1, v2)``, the derivatives of the bilinearly interpolated face
    position with respect to face coordinates 0 and 1.
    """
    face_idx = args.face_vertex_indices[side_index]

    p0 = args.cell_arg.positions[face_idx[0]]
    p1 = args.cell_arg.positions[face_idx[1]]
    p2 = args.cell_arg.positions[face_idx[2]]
    p3 = args.cell_arg.positions[face_idx[3]]

    w_p = coords
    w_m = Coords(1.0) - coords

    # d/du of (1-u)(1-v)p0 + u(1-v)p1 + uv p2 + (1-u)v p3
    v1 = w_m[1] * (p1 - p0) + w_p[1] * (p2 - p3)
    # d/dv of the same bilinear map
    v2 = w_p[0] * (p2 - p1) + w_m[0] * (p3 - p0)
    return v1, v2
348
+
@wp.func
def side_deformation_gradient(args: SideArg, s: Sample):
    """Side deformation gradient at `s.element_coords`.

    Builds a matrix whose columns are ``(v1, v2)``, the derivatives of the face
    position with respect to face coordinates 0 and 1.
    """
    # NOTE(review): earlier docs called this the "transposed" gradient, but the
    # tangents are stored as columns -- confirm which layout callers expect.
    v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
    return wp.matrix_from_cols(v1, v2)
354
+
@wp.func
def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
    """Index of the first hex adjacent to face `side_index` (slot 0 of `face_hex_indices`)."""
    return arg.face_hex_indices[side_index][0]
358
+
@wp.func
def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
    """Index of the second hex adjacent to face `side_index` (slot 1 of `face_hex_indices`).

    For boundary faces both slots hold the same hex, so this equals the inner cell.
    """
    return arg.face_hex_indices[side_index][1]
362
+
@wp.func
def _hex_local_face_coords(hex_coords: Coords, face_index: int):
    """Project hex element coordinates onto local face `face_index`.

    Returns ``(face_coords, normal_coord)``:

    - ``face_coords`` holds the two in-face coordinates.
    - ``normal_coord`` is the signed offset from the face plane along the
      face's normal axis. It is 0 on the face and negative for coordinates
      between the face and the opposite face of the unit cube.
    """
    # _FACE_COORD_INDICES[face, 0] and [face, 1] select the in-face axes,
    # [face, 2] the normal axis; [face, 3] == 0 means the face lies at
    # normal coordinate 0, otherwise at 1.

    face_coords = wp.vec2(
        hex_coords[_FACE_COORD_INDICES[face_index, 0]], hex_coords[_FACE_COORD_INDICES[face_index, 1]]
    )

    normal_coord = hex_coords[_FACE_COORD_INDICES[face_index, 2]]
    normal_coord = wp.where(_FACE_COORD_INDICES[face_index, 3] == 0, -normal_coord, normal_coord - 1.0)

    return face_coords, normal_coord
376
+
@wp.func
def _local_face_hex_coords(face_coords: wp.vec2, face_index: int):
    """Lift 2D local face coordinates to hex element coordinates on face `face_index`.

    The in-face axes receive `face_coords`. The normal axis is set to 0.0 or
    1.0 according to which side of the unit cube the face lies on.
    """

    hex_coords = Coords()
    hex_coords[_FACE_COORD_INDICES[face_index, 0]] = face_coords[0]
    hex_coords[_FACE_COORD_INDICES[face_index, 1]] = face_coords[1]
    hex_coords[_FACE_COORD_INDICES[face_index, 2]] = wp.where(_FACE_COORD_INDICES[face_index, 3] == 0, 0.0, 1.0)

    return hex_coords
387
+
@wp.func
def _local_from_oriented_face_coords(ori: int, oriented_coords: Coords):
    """Convert side (oriented) face coordinates to the hex's local face frame.

    `ori` is the orientation code from `_find_face_orientation`, ``2 * k + flip``.
    ``ori // 2`` selects the translation; rows ``2 * ori`` and ``2 * ori + 1``
    of _FACE_ORIENTATION_F are the reorientation rows. After removing the
    translation, the transpose of that row matrix is applied; this is
    presumably the inverse of `_local_to_oriented_face_coords` if the table
    encodes rotations/reflections (confirm).
    Returns a 2D vector.
    """
    fv = ori // 2
    return (oriented_coords[0] - _FACE_TRANSLATION_F[fv, 0]) * _FACE_ORIENTATION_F[2 * ori] + (
        oriented_coords[1] - _FACE_TRANSLATION_F[fv, 1]
    ) * _FACE_ORIENTATION_F[2 * ori + 1]
394
+
@wp.func
def _local_to_oriented_face_coords(ori: int, coords: wp.vec2):
    """Convert local face coordinates of a hex to side (oriented) face coordinates.

    Applies rows ``2 * ori`` and ``2 * ori + 1`` of _FACE_ORIENTATION_F, then
    adds the translation for ``ori // 2``. The third component is 0.
    """
    fv = ori // 2
    return Coords(
        wp.dot(_FACE_ORIENTATION_F[2 * ori], coords) + _FACE_TRANSLATION_F[fv, 0],
        wp.dot(_FACE_ORIENTATION_F[2 * ori + 1], coords) + _FACE_TRANSLATION_F[fv, 1],
        0.0,
    )
403
+
@wp.func
def face_to_hex_coords(local_face_index: int, face_orientation: int, side_coords: Coords):
    """Map side coordinates to hex element coordinates.

    The hex sees this face as its local face `local_face_index`, with
    orientation code `face_orientation`.
    """
    local_coords = Hexmesh._local_from_oriented_face_coords(face_orientation, side_coords)
    return Hexmesh._local_face_hex_coords(local_coords, local_face_index)
408
+
@wp.func
def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    """Element coordinates, in the inner hex, of the point at `side_coords` on face `side_index`."""
    # Entries 0/1 of face_hex_face_orientation describe the face as seen by the inner hex
    local_face_index = args.face_hex_face_orientation[side_index][0]
    face_orientation = args.face_hex_face_orientation[side_index][1]

    return Hexmesh.face_to_hex_coords(local_face_index, face_orientation, side_coords)
415
+
@wp.func
def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    """Element coordinates, in the outer hex, of the point at `side_coords` on face `side_index`."""
    # Entries 2/3 of face_hex_face_orientation describe the face as seen by the outer hex
    local_face_index = args.face_hex_face_orientation[side_index][2]
    face_orientation = args.face_hex_face_orientation[side_index][3]

    return Hexmesh.face_to_hex_coords(local_face_index, face_orientation, side_coords)
422
+
@wp.func
def side_from_cell_coords(args: SideArg, side_index: ElementIndex, hex_index: ElementIndex, hex_coords: Coords):
    """Convert element coordinates in hex `hex_index` to side coordinates on face `side_index`.

    Uses the inner-hex orientation when `hex_index` is the inner cell, and the
    outer-hex orientation otherwise. The caller is assumed to pass one of the
    two adjacent hexes.

    Returns ``Coords(OUTSIDE)`` unless the point lies exactly on the face
    plane (exact float comparison with 0.0).
    """
    if Hexmesh.side_inner_cell_index(args, side_index) == hex_index:
        local_face_index = args.face_hex_face_orientation[side_index][0]
        face_orientation = args.face_hex_face_orientation[side_index][1]
    else:
        local_face_index = args.face_hex_face_orientation[side_index][2]
        face_orientation = args.face_hex_face_orientation[side_index][3]

    face_coords, normal_coord = Hexmesh._hex_local_face_coords(hex_coords, local_face_index)
    return wp.where(
        normal_coord == 0.0, Hexmesh._local_to_oriented_face_coords(face_orientation, face_coords), Coords(OUTSIDE)
    )
436
+
@wp.func
def side_to_cell_arg(side_arg: SideArg):
    """Return the cell-level arguments embedded in the side arguments."""
    return side_arg.cell_arg
440
+
def _build_topology(self, temporary_store: TemporaryStore):
    """Build vertex-to-hex adjacency and face connectivity for the mesh.

    Populates:

    - ``_vertex_hex_offsets`` / ``_vertex_hex_indices`` (vertex -> hex adjacency),
    - ``_face_vertex_indices``, ``_face_hex_indices`` and ``_face_hex_face_orientation``,
    - ``_boundary_face_indices``.

    Each face is owned by its minimum vertex index. Faces are deduplicated per
    owner vertex and compacted into contiguous arrays. They are then re-wound
    so their normal points away from the first adjacent hex, and finally
    tagged with the local face index and orientation in each adjacent hex.

    Args:
        temporary_store: pool used for the scratch allocations below.
    """
    from warp._src.fem.utils import compress_node_indices, host_read_at_index, masked_indices
    from warp._src.utils import array_scan

    device = self.hex_vertex_indices.device

    # Vertex -> hex adjacency: hexes of vertex v are
    # indices[offsets[v]:offsets[v + 1]] (CSR-style)
    vertex_hex_offsets, vertex_hex_indices = compress_node_indices(
        self.vertex_count(), self.hex_vertex_indices, temporary_store=temporary_store
    )
    self._vertex_hex_offsets = vertex_hex_offsets.detach()
    self._vertex_hex_indices = vertex_hex_indices.detach()

    vertex_start_face_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
    vertex_start_face_count.zero_()
    vertex_start_face_offsets = borrow_temporary_like(vertex_start_face_count, temporary_store=temporary_store)

    # Per-candidate-face scratch: the other 3 vertices (canonical order) and the two adjacent hexes.
    # NOTE(review): sized 8 * cell_count, presumably an upper bound on candidate faces (6 per hex).
    vertex_face_other_vs = borrow_temporary(
        temporary_store, dtype=wp.vec3i, device=device, shape=(8 * self.cell_count())
    )
    vertex_face_hexes = borrow_temporary(
        temporary_store, dtype=int, device=device, shape=(8 * self.cell_count(), 2)
    )

    # Count candidate faces owned by each vertex (shared faces counted once per hex)
    wp.launch(
        kernel=Hexmesh._count_starting_faces_kernel,
        device=device,
        dim=self.cell_count(),
        inputs=[self.hex_vertex_indices, vertex_start_face_count],
    )

    array_scan(in_array=vertex_start_face_count, out_array=vertex_start_face_offsets, inclusive=False)

    # Count number of unique faces per owner vertex (deduplicate faces shared by two hexes).
    # The kernel overwrites the count array in place, hence this alias.
    vertex_unique_face_count = vertex_start_face_count
    wp.launch(
        kernel=Hexmesh._count_unique_starting_faces_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            self._vertex_hex_offsets,
            self._vertex_hex_indices,
            self.hex_vertex_indices,
            vertex_start_face_offsets,
            vertex_unique_face_count,
            vertex_face_other_vs,
            vertex_face_hexes,
        ],
    )

    vertex_unique_face_offsets = borrow_temporary_like(vertex_start_face_offsets, temporary_store=temporary_store)
    array_scan(in_array=vertex_start_face_count, out_array=vertex_unique_face_offsets, inclusive=False)

    # Get back face count to host.
    # The exclusive-scan value at the last vertex is the total only if that vertex owns no face.
    # This holds when the highest-index vertex cannot be the minimum of a face with distinct vertices.
    # NOTE(review): confirm that invariant for degenerate inputs.
    face_count = int(
        host_read_at_index(vertex_unique_face_offsets, self.vertex_count() - 1, temporary_store=temporary_store)
    )

    self._face_vertex_indices = wp.empty(shape=(face_count,), dtype=wp.vec4i, device=device)
    self._face_hex_indices = wp.empty(shape=(face_count,), dtype=wp.vec2i, device=device)
    self._face_hex_face_orientation = wp.empty(shape=(face_count,), dtype=wp.vec4i, device=device)

    boundary_mask = borrow_temporary(temporary_store, shape=(face_count,), dtype=int, device=device)

    # Compress face data into contiguous arrays; boundary_mask flags faces with a single adjacent hex
    wp.launch(
        kernel=Hexmesh._compress_faces_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            vertex_start_face_offsets,
            vertex_unique_face_offsets,
            vertex_unique_face_count,
            vertex_face_other_vs,
            vertex_face_hexes,
            self._face_vertex_indices,
            self._face_hex_indices,
            boundary_mask,
        ],
    )

    vertex_start_face_offsets.release()
    vertex_unique_face_offsets.release()
    vertex_unique_face_count.release()
    vertex_face_other_vs.release()
    vertex_face_hexes.release()

    # Re-wind faces where needed so their normal points away from the first adjacent hex
    wp.launch(
        kernel=Hexmesh._flip_face_normals,
        device=device,
        dim=self.side_count(),
        inputs=[self._face_vertex_indices, self._face_hex_indices, self.hex_vertex_indices, self.positions],
    )

    # Compute and store, for both adjacent hexes, the local face index and orientation code
    wp.launch(
        kernel=Hexmesh._compute_face_orientation,
        device=device,
        dim=self.side_count(),
        inputs=[
            self._face_vertex_indices,
            self._face_hex_indices,
            self.hex_vertex_indices,
            self._face_hex_face_orientation,
        ],
    )

    # NOTE(review): boundary_mask is not explicitly released in this method -- confirm it is reclaimed.
    boundary_face_indices, _ = masked_indices(boundary_mask)
    self._boundary_face_indices = boundary_face_indices.detach()
551
+
def _compute_hex_edges(self, temporary_store: Optional[TemporaryStore] = None):
    """Assign a global index to every unique edge and record it per hex.

    Sets ``_edge_count`` and ``_hex_edge_indices``, an array of shape
    ``(cell_count, 12)`` that maps each local hex edge to its global index.
    Reads ``_vertex_hex_offsets`` / ``_vertex_hex_indices``, which must
    already be populated (they are set by `_build_topology`).

    Each edge is owned by its smaller endpoint. Edges are deduplicated per
    owner vertex, then numbered contiguously.

    Args:
        temporary_store: optional pool used for scratch allocations.
    """
    from warp._src.fem.utils import host_read_at_index
    from warp._src.utils import array_scan

    device = self.hex_vertex_indices.device

    vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
    vertex_start_edge_count.zero_()
    vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)

    # Other endpoint of each candidate edge; 12 edges per hex
    vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(12 * self.cell_count()))

    # Count candidate edges owned by each vertex (shared edges counted once per hex)
    wp.launch(
        kernel=Hexmesh._count_starting_edges_kernel,
        device=device,
        dim=self.cell_count(),
        inputs=[self.hex_vertex_indices, vertex_start_edge_count],
    )

    array_scan(in_array=vertex_start_edge_count, out_array=vertex_start_edge_offsets, inclusive=False)

    # Count number of unique edges (deduplicate across hexes).
    # The kernel overwrites the count array in place, hence this alias.
    vertex_unique_edge_count = vertex_start_edge_count
    wp.launch(
        kernel=Hexmesh._count_unique_starting_edges_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            self._vertex_hex_offsets,
            self._vertex_hex_indices,
            self.hex_vertex_indices,
            vertex_start_edge_offsets,
            vertex_unique_edge_count,
            vertex_edge_ends,
        ],
    )

    vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
    array_scan(in_array=vertex_start_edge_count, out_array=vertex_unique_edge_offsets, inclusive=False)

    # Get back edge count to host.
    # The exclusive-scan value at the last vertex is the total because the highest-index
    # vertex cannot be the smaller endpoint of an edge with distinct endpoints.
    self._edge_count = int(
        host_read_at_index(vertex_unique_edge_offsets, self.vertex_count() - 1, temporary_store=temporary_store)
    )

    self._hex_edge_indices = wp.empty(
        dtype=int, device=self.hex_vertex_indices.device, shape=(self.cell_count(), 12)
    )

    # Compress edge data: write the global edge index of every (hex, local edge) pair
    wp.launch(
        kernel=Hexmesh._compress_edges_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            self._vertex_hex_offsets,
            self._vertex_hex_indices,
            self.hex_vertex_indices,
            vertex_start_edge_offsets,
            vertex_unique_edge_offsets,
            vertex_unique_edge_count,
            vertex_edge_ends,
            self._hex_edge_indices,
        ],
    )

    vertex_start_edge_offsets.release()
    vertex_unique_edge_offsets.release()
    vertex_unique_edge_count.release()
    vertex_edge_ends.release()
623
+
@wp.kernel
def _count_starting_faces_kernel(
    hex_vertex_indices: wp.array2d(dtype=int), vertex_start_face_count: wp.array(dtype=int)
):
    """One thread per hex: count each of its 6 faces against the face's minimum vertex.

    Faces shared by two hexes are counted twice, so the result is an upper
    bound that `_count_unique_starting_faces_kernel` later refines.
    """
    t = wp.tid()
    for k in range(6):
        vi = wp.vec4i(
            hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 0]],
            hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 1]],
            hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 2]],
            hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 3]],
        )
        vm = wp.min(vi)

        # One increment per occurrence of the minimum vertex
        # (exactly one for a face with distinct vertices)
        for i in range(4):
            if vm == vi[i]:
                wp.atomic_add(vertex_start_face_count, vm, 1)
641
+
@wp.func
def _face_sort(vidx: wp.vec4i, min_k: int):
    """Canonical key for a quad face, given the position `min_k` of its minimum vertex.

    Returns the other three vertices in cyclic order starting right after
    `min_k`, reversed if needed so that the first entry is smaller than the
    last. Two quads with the same vertex cycle thus produce the same key,
    whatever their winding direction.
    """
    v1 = vidx[(min_k + 1) % 4]
    v2 = vidx[(min_k + 2) % 4]
    v3 = vidx[(min_k + 3) % 4]

    if v1 < v3:
        return wp.vec3i(v1, v2, v3)
    return wp.vec3i(v3, v2, v1)
651
+
@wp.func
def _find_face(
    needle: wp.vec3i,
    values: wp.array(dtype=wp.vec3i),
    beg: int,
    end: int,
):
    """Linear search for `needle` in ``values[beg:end]``; return its index, or -1 if absent."""
    for i in range(beg, end):
        if values[i] == needle:
            return i

    return -1
664
+
@wp.kernel
def _count_unique_starting_faces_kernel(
    vertex_hex_offsets: wp.array(dtype=int),
    vertex_hex_indices: wp.array(dtype=int),
    hex_vertex_indices: wp.array2d(dtype=int),
    vertex_start_face_offsets: wp.array(dtype=int),
    vertex_start_face_count: wp.array(dtype=int),
    face_other_vs: wp.array(dtype=wp.vec3i),
    face_hexes: wp.array2d(dtype=int),
):
    """One thread per vertex `v`: deduplicate the faces owned by `v`.

    Visits every hex containing `v` and looks at each face whose minimum
    vertex is `v`:

    - A face not seen before is appended at ``vertex_start_face_offsets[v] + n``,
      storing its canonical key and the current hex in both hex slots.
    - A face already seen gets the current hex in slot 1.

    Finally, ``vertex_start_face_count[v]`` is overwritten with the number of
    unique faces.
    """
    v = wp.tid()

    face_beg = vertex_start_face_offsets[v]

    hex_beg = vertex_hex_offsets[v]
    hex_end = vertex_hex_offsets[v + 1]

    face_cur = face_beg

    for hexa in range(hex_beg, hex_end):
        hx = vertex_hex_indices[hexa]

        for k in range(6):
            vi = wp.vec4i(
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 0]],
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 1]],
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 2]],
                hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 3]],
            )
            min_i = int(wp.argmin(vi))

            # Only handle faces owned by this vertex
            if v == vi[min_i]:
                other_v = Hexmesh._face_sort(vi, min_i)

                # Check if other_v has been seen
                seen_idx = Hexmesh._find_face(other_v, face_other_vs, face_beg, face_cur)

                if seen_idx == -1:
                    face_other_vs[face_cur] = other_v
                    # Both slots hold hx until a second hex is found (boundary otherwise)
                    face_hexes[face_cur, 0] = hx
                    face_hexes[face_cur, 1] = hx
                    face_cur += 1
                else:
                    face_hexes[seen_idx, 1] = hx

    vertex_start_face_count[v] = face_cur - face_beg
711
+
@wp.kernel
def _compress_faces_kernel(
    vertex_start_face_offsets: wp.array(dtype=int),
    vertex_unique_face_offsets: wp.array(dtype=int),
    vertex_unique_face_count: wp.array(dtype=int),
    uncompressed_face_other_vs: wp.array(dtype=wp.vec3i),
    uncompressed_face_hexes: wp.array2d(dtype=int),
    face_vertex_indices: wp.array(dtype=wp.vec4i),
    face_hex_indices: wp.array(dtype=wp.vec2i),
    boundary_mask: wp.array(dtype=int),
):
    """One thread per vertex `v`: copy the unique faces owned by `v` into the compact face arrays.

    Each face's vertices are written as ``(v, other0, other1, other2)``. Its two
    adjacent hexes go to `face_hex_indices`. `boundary_mask` is set to 1 when
    both hex slots are equal (a single adjacent hex), and to 0 otherwise.
    """
    v = wp.tid()

    start_beg = vertex_start_face_offsets[v]
    unique_beg = vertex_unique_face_offsets[v]
    unique_count = vertex_unique_face_count[v]

    for f in range(unique_count):
        src_index = start_beg + f
        face_index = unique_beg + f

        face_vertex_indices[face_index] = wp.vec4i(
            v,
            uncompressed_face_other_vs[src_index][0],
            uncompressed_face_other_vs[src_index][1],
            uncompressed_face_other_vs[src_index][2],
        )

        hx0 = uncompressed_face_hexes[src_index, 0]
        hx1 = uncompressed_face_hexes[src_index, 1]
        face_hex_indices[face_index] = wp.vec2i(hx0, hx1)
        if hx0 == hx1:
            boundary_mask[face_index] = 1
        else:
            boundary_mask[face_index] = 0
747
+
@wp.kernel
def _flip_face_normals(
    face_vertex_indices: wp.array(dtype=wp.vec4i),
    face_hex_indices: wp.array(dtype=wp.vec2i),
    hex_vertex_indices: wp.array2d(dtype=int),
    positions: wp.array(dtype=wp.vec3),
):
    """One thread per face: re-wind the face so its normal points away from its first adjacent hex.

    The normal is the cross product of the face diagonals. When it points
    toward the centroid of ``face_hex_indices[f][0]``, vertices 1 and 3 are
    swapped, which reverses the winding.
    """
    f = wp.tid()

    hexa = face_hex_indices[f][0]

    hex_vidx = hex_vertex_indices[hexa]
    face_vidx = face_vertex_indices[f]

    hex_centroid = (
        positions[hex_vidx[0]]
        + positions[hex_vidx[1]]
        + positions[hex_vidx[2]]
        + positions[hex_vidx[3]]
        + positions[hex_vidx[4]]
        + positions[hex_vidx[5]]
        + positions[hex_vidx[6]]
        + positions[hex_vidx[7]]
    ) / 8.0

    v0 = positions[face_vidx[0]]
    v1 = positions[face_vidx[1]]
    v2 = positions[face_vidx[2]]
    v3 = positions[face_vidx[3]]

    face_center = (v1 + v0 + v2 + v3) / 4.0
    face_normal = wp.cross(v2 - v0, v3 - v1)

    # If the face normal points toward the first hex's centroid, flip the winding
    if wp.dot(hex_centroid - face_center, face_normal) > 0.0:
        face_vertex_indices[f] = wp.vec4i(face_vidx[0], face_vidx[3], face_vidx[2], face_vidx[1])
784
+
@wp.func
def _find_face_orientation(face_vidx: wp.vec4i, hex_index: int, hex_vertex_indices: wp.array2d(dtype=int)):
    """Locate face `face_vidx` within hex `hex_index` and compute its orientation code.

    Returns ``(local_face_index, face_orientation)``:

    - ``local_face_index`` is the hex-local face (0..5) with the same
      canonical key.
    - ``face_orientation`` is ``2 * k + flip``. Here ``k`` is the position in
      `face_vidx` of the hex face's first vertex. ``flip`` is 1 when the next
      vertex of `face_vidx` differs from the hex face's second vertex, i.e.
      when the windings are opposite.
    """
    # NOTE(review): local_face_index / face_orientation are only assigned inside
    # the loops below; if no hex face matches they are undefined.
    hex_vidx = hex_vertex_indices[hex_index]

    # Find local index in hex corresponding to face

    face_min_i = int(wp.argmin(face_vidx))
    face_other_v = Hexmesh._face_sort(face_vidx, face_min_i)

    for k in range(6):
        hex_face_vi = wp.vec4i(
            hex_vidx[FACE_VERTEX_INDICES[k, 0]],
            hex_vidx[FACE_VERTEX_INDICES[k, 1]],
            hex_vidx[FACE_VERTEX_INDICES[k, 2]],
            hex_vidx[FACE_VERTEX_INDICES[k, 3]],
        )
        hex_min_i = int(wp.argmin(hex_face_vi))
        hex_other_v = Hexmesh._face_sort(hex_face_vi, hex_min_i)

        if hex_other_v == face_other_v:
            local_face_index = k
            break

    # Find starting vertex index
    # (hex_face_vi still holds the matched face thanks to the break above)
    for k in range(4):
        if face_vidx[k] == hex_face_vi[0]:
            face_orientation = 2 * k
            if face_vidx[(k + 1) % 4] != hex_face_vi[1]:
                face_orientation += 1

    return local_face_index, face_orientation
816
+
@wp.kernel
def _compute_face_orientation(
    face_vertex_indices: wp.array(dtype=wp.vec4i),
    face_hex_indices: wp.array(dtype=wp.vec2i),
    hex_vertex_indices: wp.array2d(dtype=int),
    face_hex_face_ori: wp.array(dtype=wp.vec4i),
):
    """One thread per face: store ``(local_face_0, ori_0, local_face_1, ori_1)``.

    Each pair gives the face's local index and orientation code in the first
    and second adjacent hex. For boundary faces (same hex in both slots) the
    first pair is duplicated.
    """
    f = wp.tid()

    face_vidx = face_vertex_indices[f]

    hx0 = face_hex_indices[f][0]
    local_face_0, ori_0 = Hexmesh._find_face_orientation(face_vidx, hx0, hex_vertex_indices)

    hx1 = face_hex_indices[f][1]
    if hx0 == hx1:
        face_hex_face_ori[f] = wp.vec4i(local_face_0, ori_0, local_face_0, ori_0)
    else:
        local_face_1, ori_1 = Hexmesh._find_face_orientation(face_vidx, hx1, hex_vertex_indices)
        face_hex_face_ori[f] = wp.vec4i(local_face_0, ori_0, local_face_1, ori_1)
837
+
@wp.kernel
def _count_starting_edges_kernel(
    hex_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
):
    """One thread per hex: count each of its 12 edges against the edge's smaller endpoint.

    Edges shared by several hexes are counted once per hex, so this is an
    upper bound refined by `_count_unique_starting_edges_kernel`.
    """
    t = wp.tid()
    for k in range(12):
        v0 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 0]]
        v1 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 1]]

        if v0 < v1:
            wp.atomic_add(vertex_start_edge_count, v0, 1)
        else:
            wp.atomic_add(vertex_start_edge_count, v1, 1)
851
+
@wp.func
def _find_edge(
    needle: int,
    values: wp.array(dtype=int),
    beg: int,
    end: int,
):
    """Linear search for `needle` in ``values[beg:end]``; return its index, or -1 if absent."""
    for i in range(beg, end):
        if values[i] == needle:
            return i

    return -1
864
+
@wp.kernel
def _count_unique_starting_edges_kernel(
    vertex_hex_offsets: wp.array(dtype=int),
    vertex_hex_indices: wp.array(dtype=int),
    hex_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_start_edge_count: wp.array(dtype=int),
    edge_ends: wp.array(dtype=int),
):
    """One thread per vertex `v`: deduplicate the edges whose smaller endpoint is `v`.

    Unique other endpoints are appended at ``vertex_start_edge_offsets[v] + n``.
    ``vertex_start_edge_count[v]`` is then overwritten with the number of
    unique edges.
    """
    v = wp.tid()

    edge_beg = vertex_start_edge_offsets[v]

    hex_beg = vertex_hex_offsets[v]
    hex_end = vertex_hex_offsets[v + 1]

    edge_cur = edge_beg

    # Note: `tet` / `t` iterate over hexes adjacent to v (names kept from the tet mesh variant)
    for tet in range(hex_beg, hex_end):
        t = vertex_hex_indices[tet]

        for k in range(12):
            v0 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 0]]
            v1 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 1]]

            if v == wp.min(v0, v1):
                other_v = wp.max(v0, v1)
                if Hexmesh._find_edge(other_v, edge_ends, edge_beg, edge_cur) == -1:
                    edge_ends[edge_cur] = other_v
                    edge_cur += 1

    vertex_start_edge_count[v] = edge_cur - edge_beg
897
+
@wp.kernel
def _compress_edges_kernel(
    vertex_hex_offsets: wp.array(dtype=int),
    vertex_hex_indices: wp.array(dtype=int),
    hex_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_count: wp.array(dtype=int),
    uncompressed_edge_ends: wp.array(dtype=int),
    hex_edge_indices: wp.array2d(dtype=int),
):
    """One thread per vertex `v`: write global edge indices for edges whose smaller endpoint is `v`.

    For every hex containing `v`, each local edge owned by `v` is looked up
    among `v`'s unique edge ends. Its global index is that position plus
    ``vertex_unique_edge_offsets[v]``. Only the thread of the edge's smaller
    endpoint writes a given ``(hex, local edge)`` entry.
    """
    v = wp.tid()

    uncompressed_beg = vertex_start_edge_offsets[v]

    unique_beg = vertex_unique_edge_offsets[v]
    unique_count = vertex_unique_edge_count[v]

    hex_beg = vertex_hex_offsets[v]
    hex_end = vertex_hex_offsets[v + 1]

    # Note: `tet` / `t` iterate over hexes adjacent to v
    for tet in range(hex_beg, hex_end):
        t = vertex_hex_indices[tet]

        for k in range(12):
            v0 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 0]]
            v1 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 1]]

            if v == wp.min(v0, v1):
                other_v = wp.max(v0, v1)
                # Unique ends of v occupy [uncompressed_beg, uncompressed_beg + unique_count)
                edge_id = (
                    Hexmesh._find_edge(
                        other_v, uncompressed_edge_ends, uncompressed_beg, uncompressed_beg + unique_count
                    )
                    - uncompressed_beg
                    + unique_beg
                )
                hex_edge_indices[t][k] = edge_id
936
+
@wp.func
def cell_bvh_id(cell_arg: HexmeshCellArg):
    """Return the BVH identifier stored in the cell arguments (`hex_bvh`)."""
    return cell_arg.hex_bvh
940
+
@wp.func
def cell_bounds(cell_arg: HexmeshCellArg, cell_index: ElementIndex):
    """Axis-aligned bounding box of hex `cell_index`.

    Returns ``(lower, upper)``, the component-wise min and max over the hex's
    8 vertex positions.
    """
    vidx = cell_arg.hex_vertex_indices[cell_index]
    p0 = cell_arg.positions[vidx[0]]
    p1 = cell_arg.positions[vidx[1]]
    p2 = cell_arg.positions[vidx[2]]
    p3 = cell_arg.positions[vidx[3]]
    lo0, up0 = wp.min(wp.min(p0, p1), wp.min(p2, p3)), wp.max(wp.max(p0, p1), wp.max(p2, p3))

    p4 = cell_arg.positions[vidx[4]]
    p5 = cell_arg.positions[vidx[5]]
    p6 = cell_arg.positions[vidx[6]]
    p7 = cell_arg.positions[vidx[7]]
    lo1, up1 = wp.min(wp.min(p4, p5), wp.min(p6, p7)), wp.max(wp.max(p4, p5), wp.max(p6, p7))

    return wp.min(lo0, lo1), wp.max(up0, up1)