warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429):
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,806 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Optional
17
+
18
+ import warp as wp
19
+ from warp.fem.cache import (
20
+ TemporaryStore,
21
+ borrow_temporary,
22
+ borrow_temporary_like,
23
+ cached_arg_value,
24
+ )
25
+ from warp.fem.types import (
26
+ NULL_ELEMENT_INDEX,
27
+ OUTSIDE,
28
+ Coords,
29
+ ElementIndex,
30
+ Sample,
31
+ make_free_sample,
32
+ )
33
+
34
+ from .closest_point import project_on_tri_at_origin
35
+ from .element import LinearEdge, Triangle
36
+ from .geometry import Geometry
37
+
38
+
39
@wp.struct
class TrimeshCellArg:
    # Cell-level (triangle) topology data handed to the device-side lookup functions.

    # vertex indices of each triangle, read as tri_vertex_indices[tri_index, corner]
    tri_vertex_indices: wp.array2d(dtype=int)

    # for neighbor cell lookup
    # triangles incident to vertex `v` are vertex_tri_indices[t]
    # for t in [vertex_tri_offsets[v], vertex_tri_offsets[v + 1])
    vertex_tri_offsets: wp.array(dtype=int)
    vertex_tri_indices: wp.array(dtype=int)

    # for global cell lookup
    # BVH identifier; compared against `_NULL_BVH` to detect that no BVH is available
    tri_bvh: wp.uint64
49
+
50
+
51
@wp.struct
class TrimeshSideArg:
    # Side-level (edge) topology data; embeds the cell-level topology as `cell_arg`.

    cell_arg: TrimeshCellArg
    # the two end vertices of each edge, read as edge_vertex_indices[side_index][0 or 1]
    edge_vertex_indices: wp.array(dtype=wp.vec2i)
    # presumably the two triangles adjacent to each edge -- TODO confirm against the edge-compression kernel
    edge_tri_indices: wp.array(dtype=wp.vec2i)
56
+
57
+
58
# Sentinel BVH identifier (0), used when the mesh has no triangle BVH usable on the requested device
_NULL_BVH = wp.constant(wp.uint64(0))
59
+
60
+
61
@wp.func
def _bvh_vec(v: wp.vec3):
    # vec3 overload: the position is already three-dimensional and is returned unchanged
    return v
64
+
65
+
66
@wp.func
def _bvh_vec(v: wp.vec2):
    # vec2 overload: lift the 2D position to a vec3 by appending a zero third component
    x = v[0]
    y = v[1]
    return wp.vec3(x, y, 0.0)
69
+
70
+
71
+ class Trimesh(Geometry):
72
+ """Triangular mesh geometry"""
73
+
74
def __init__(
    self,
    tri_vertex_indices: wp.array,
    positions: wp.array,
    build_bvh: bool = False,
    temporary_store: Optional[TemporaryStore] = None,
):
    """
    Constructs a D-dimensional triangular mesh.

    Args:
        tri_vertex_indices: warp array of shape (num_tris, 3) containing vertex indices for each tri
        positions: warp array of shape (num_vertices, D) containing the position of each vertex
        build_bvh: Whether to also build the triangle BVH, which is necessary for the global `fem.lookup` operator to function without initial guess
        temporary_store: shared pool from which to allocate temporary arrays
    """

    self.tri_vertex_indices = tri_vertex_indices
    self.positions = positions

    # Topology arrays start empty and are populated by `_build_topology`
    self._edge_vertex_indices: wp.array = None
    self._edge_tri_indices: wp.array = None
    self._vertex_tri_offsets: wp.array = None
    self._vertex_tri_indices: wp.array = None
    self._build_topology(temporary_store)

    # The BVH is optional and only constructed on request
    self._tri_bvh: wp.Bvh = None
    if build_bvh:
        self._build_bvh()

    # Flip edges so that normals point away from inner cell
    orient_inputs = [
        self._edge_vertex_indices,
        self._edge_tri_indices,
        self.tri_vertex_indices,
        self.positions,
    ]
    wp.launch(
        kernel=self._orient_edges,
        device=positions.device,
        dim=self.side_count(),
        inputs=orient_inputs,
    )

    self._make_default_dependent_implementations()
113
+
114
def update_bvh(self, force_rebuild: bool = False):
    """
    Refits the BVH, or rebuilds it from scratch if `force_rebuild` is ``True``.

    The BVH is also built from scratch when none exists yet.

    Args:
        force_rebuild: if ``True``, discard the current BVH and build a new one instead of refitting
    """

    if self._tri_bvh is None or force_rebuild:
        # The builder method is `_build_bvh`; `build_bvh` is only the name of a constructor flag
        self._build_bvh()
        return

    # Recompute the per-triangle bounds in place in the arrays owned by the existing BVH,
    # using the same launch configuration as `_build_bvh`
    wp.launch(
        Trimesh._compute_tri_bounds,
        device=self.positions.device,
        dim=self.cell_count(),
        inputs=[
            self.tri_vertex_indices,
            self.positions,
            self._tri_bvh.lowers,
            self._tri_bvh.uppers,
        ],
    )
    self._tri_bvh.refit()
130
+
131
def _build_bvh(self, temporary_store: Optional[TemporaryStore] = None):
    """Builds the triangle BVH from freshly computed per-triangle bounds.

    Args:
        temporary_store: accepted for interface symmetry; not used by this method
    """
    device = self.positions.device
    tri_count = self.cell_count()

    # One lower/upper bound per triangle, filled in by the bounds kernel
    lowers = wp.array(shape=tri_count, dtype=wp.vec3, device=device)
    uppers = wp.array(shape=tri_count, dtype=wp.vec3, device=device)

    bounds_inputs = [self.tri_vertex_indices, self.positions, lowers, uppers]
    wp.launch(
        Trimesh._compute_tri_bounds,
        device=device,
        dim=tri_count,
        inputs=bounds_inputs,
    )

    self._tri_bvh = wp.Bvh(lowers, uppers)
141
+
142
def cell_count(self):
    """Number of cells (triangles): the leading dimension of `tri_vertex_indices`."""
    tri_shape = self.tri_vertex_indices.shape
    return tri_shape[0]
144
+
145
def vertex_count(self):
    """Number of vertices: the leading dimension of `positions`."""
    position_shape = self.positions.shape
    return position_shape[0]
147
+
148
def side_count(self):
    """Number of sides (edges): the leading dimension of the edge vertex index array."""
    edge_shape = self._edge_vertex_indices.shape
    return edge_shape[0]
150
+
151
def boundary_side_count(self):
    """Number of boundary sides: the leading dimension of the boundary edge index array."""
    boundary_shape = self._boundary_edge_indices.shape
    return boundary_shape[0]
153
+
154
def reference_cell(self) -> Triangle:
    """Returns a new `Triangle` reference element describing the cells of this geometry."""
    cell_element = Triangle()
    return cell_element
156
+
157
def reference_side(self) -> LinearEdge:
    """Returns a new `LinearEdge` reference element describing the sides of this geometry."""
    side_element = LinearEdge()
    return side_element
159
+
160
@property
def edge_tri_indices(self) -> wp.array:
    """Read-only access to the per-edge triangle index array built by `_build_topology`."""
    return self._edge_tri_indices
163
+
164
@property
def edge_vertex_indices(self) -> wp.array:
    """Read-only access to the per-edge vertex index array built by `_build_topology`."""
    return self._edge_vertex_indices
167
+
168
@wp.struct
class SideIndexArg:
    # Argument structure for mapping a boundary side index to a side index.

    # side (edge) index of each boundary side, read as boundary_edge_indices[boundary_side_index]
    boundary_edge_indices: wp.array(dtype=int)
171
+
172
@cached_arg_value
def _cell_topo_arg_value(self, device):
    """Assembles the `TrimeshCellArg` topology structure with its arrays moved to `device`.

    The `tri_bvh` field is not set here; it is assigned by the callers.
    """
    topo = TrimeshCellArg()

    # Triangle connectivity
    topo.tri_vertex_indices = self.tri_vertex_indices.to(device)

    # Vertex-to-triangle adjacency
    topo.vertex_tri_offsets = self._vertex_tri_offsets.to(device)
    topo.vertex_tri_indices = self._vertex_tri_indices.to(device)

    return topo
181
+
182
@cached_arg_value
def _side_topo_arg_value(self, device):
    """Assembles the `TrimeshSideArg` topology structure with its arrays moved to `device`."""
    topo = TrimeshSideArg()

    # Embedded cell topology comes first, then the per-edge arrays
    topo.cell_arg = self._cell_topo_arg_value(device)
    topo.edge_vertex_indices = self._edge_vertex_indices.to(device)
    topo.edge_tri_indices = self._edge_tri_indices.to(device)

    return topo
191
+
192
def _bvh_id(self, device):
    """Returns the identifier of the triangle BVH, or `_NULL_BVH` if it cannot be used on `device`.

    The BVH is considered unusable when it has not been built or when it lives on another device.
    """
    bvh = self._tri_bvh
    if bvh is None:
        return _NULL_BVH
    if bvh.device != device:
        return _NULL_BVH
    return bvh.id
196
+
197
def cell_arg_value(self, device):
    """Builds the cell argument structure (topology, positions and BVH id) for `device`."""
    cell_arg = self.CellArg()

    cell_arg.topology = self._cell_topo_arg_value(device)
    cell_arg.positions = self.positions.to(device)
    # The BVH id is written through the structure's own `topology` member, after that member is assigned
    cell_arg.topology.tri_bvh = self._bvh_id(device)

    return cell_arg
205
+
206
def side_arg_value(self, device):
    """Builds the side argument structure (topology, positions and BVH id) for `device`."""
    side_arg = self.SideArg()

    side_arg.topology = self._side_topo_arg_value(device)
    side_arg.positions = self.positions.to(device)
    # The BVH id lives on the cell topology embedded in the side topology
    side_arg.topology.cell_arg.tri_bvh = self._bvh_id(device)

    return side_arg
214
+
215
@wp.func
def _project_on_tri(args: TrimeshCellArg, positions: wp.array(dtype=Any), pos: Any, tri_index: int):
    # Projects `pos` onto triangle `tri_index`, expressed relative to the triangle's first vertex.
    # Returns the (dist, coords) pair produced by `project_on_tri_at_origin`.
    origin = positions[args.tri_vertex_indices[tri_index, 0]]

    # Edge vectors from the first vertex to the second and third vertices
    edge_a = positions[args.tri_vertex_indices[tri_index, 1]] - origin
    edge_b = positions[args.tri_vertex_indices[tri_index, 2]] - origin

    # Query position in the same origin-centered frame
    rel_pos = pos - origin

    dist, coords = project_on_tri_at_origin(rel_pos, edge_a, edge_b)
    return dist, coords
225
+
226
@wp.func
def _bvh_lookup(args: TrimeshCellArg, positions: wp.array(dtype=Any), pos: Any):
    # Global lookup: among the triangles returned by a BVH query at `pos`, keeps the one
    # with the smallest projection distance. Without a BVH the "not found" defaults are returned.
    best_tri = int(NULL_ELEMENT_INDEX)
    best_coords = Coords(OUTSIDE)
    best_dist = float(1.0e8)

    if args.tri_bvh != _NULL_BVH:
        # Degenerate box query: lower and upper bounds are both the query point
        query_point = _bvh_vec(pos)
        query = wp.bvh_query_aabb(args.tri_bvh, query_point, query_point)

        candidate = int(0)
        while wp.bvh_query_next(query, candidate):
            dist, coords = Trimesh._project_on_tri(args, positions, pos, candidate)
            if dist <= best_dist:
                best_dist = dist
                best_tri = candidate
                best_coords = coords

    return best_dist, best_tri, best_coords
244
+
245
@wp.func
def _cell_neighbor_lookup(args: TrimeshCellArg, positions: wp.array(dtype=Any), pos: Any, cell_index: int):
    # Local lookup: visits every triangle sharing a vertex with `cell_index` and keeps the one
    # with the smallest projection distance to `pos`.
    # Returns (closest_dist, closest_tri, closest_coords).

    # All three results get explicit "not found" defaults, matching `_bvh_lookup`, so that
    # the returned triangle and coordinates are well defined even if no candidate is accepted
    closest_tri = int(NULL_ELEMENT_INDEX)
    closest_coords = Coords(OUTSIDE)
    closest_dist = float(1.0e8)

    for v in range(3):
        vtx = args.tri_vertex_indices[cell_index, v]

        # Range of triangles incident to this vertex
        tri_beg = args.vertex_tri_offsets[vtx]
        tri_end = args.vertex_tri_offsets[vtx + 1]

        for t in range(tri_beg, tri_end):
            tri = args.vertex_tri_indices[t]
            dist, coords = Trimesh._project_on_tri(args, positions, pos, tri)
            if dist <= closest_dist:
                closest_dist = dist
                closest_tri = tri
                closest_coords = coords

    return closest_dist, closest_tri, closest_coords
263
+
264
@cached_arg_value
def side_index_arg_value(self, device) -> SideIndexArg:
    """Builds the `SideIndexArg` structure with the boundary edge indices moved to `device`."""
    index_arg = self.SideIndexArg()

    boundary_edges = self._boundary_edge_indices.to(device)
    index_arg.boundary_edge_indices = boundary_edges

    return index_arg
271
+
272
@wp.func
def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
    """Boundary side to side index"""

    # Direct table lookup in the boundary edge index array
    side = args.boundary_edge_indices[boundary_side_index]
    return side
277
+
278
@wp.func
def _edge_to_tri_coords(
    args: TrimeshSideArg, side_index: ElementIndex, tri_index: ElementIndex, side_coords: Coords
):
    # Converts a coordinate along edge `side_index` into coordinates of triangle `tri_index`.
    # The edge start vertex receives weight (1 - s) and the edge end vertex weight s, with
    # s = side_coords[0]; each weight goes in the slot of the matching triangle corner.
    edge_vertices = args.edge_vertex_indices[side_index]
    tri_vertices = args.cell_arg.tri_vertex_indices[tri_index]

    corner0 = tri_vertices[0]
    corner1 = tri_vertices[1]

    s = side_coords[0]

    w0 = float(0.0)
    w1 = float(0.0)
    w2 = float(0.0)

    # Edge start vertex: any vertex that is neither corner 0 nor corner 1 is taken to be corner 2
    if edge_vertices[0] == corner0:
        w0 = 1.0 - s
    elif edge_vertices[0] == corner1:
        w1 = 1.0 - s
    else:
        w2 = 1.0 - s

    # Edge end vertex, same corner matching
    if edge_vertices[1] == corner0:
        w0 = s
    elif edge_vertices[1] == corner1:
        w1 = s
    else:
        w2 = s

    return Coords(w0, w1, w2)
307
+
308
@wp.func
def _tri_to_edge_coords(
    args: TrimeshSideArg,
    side_index: ElementIndex,
    tri_index: ElementIndex,
    tri_coords: Coords,
):
    # Converts coordinates of triangle `tri_index` into a coordinate along edge `side_index`,
    # or returns Coords(OUTSIDE) when the point does not lie on that edge.
    edge_vertices = args.edge_vertex_indices[side_index]
    tri_vertices = args.cell_arg.tri_vertex_indices[tri_index]

    # Both corner slots default to 2, so only corners 0 and 1 need to be tested explicitly
    start_corner = int(2)
    end_corner = int(2)

    for corner in range(2):
        vertex = tri_vertices[corner]
        if edge_vertices[1] == vertex:
            end_corner = corner
        elif edge_vertices[0] == vertex:
            start_corner = corner

    # The point is on the edge when the weights of the two edge vertices sum to (almost) one
    edge_weight = tri_coords[start_corner] + tri_coords[end_corner]
    return wp.where(edge_weight > 0.999, Coords(tri_coords[end_corner], 0.0, 0.0), Coords(OUTSIDE))
329
+
330
def _build_topology(self, temporary_store: Optional[TemporaryStore]):
    """Builds the vertex-to-triangle adjacency, the edge arrays and the boundary edge list.

    Sets `_vertex_tri_offsets`, `_vertex_tri_indices`, `_edge_vertex_indices`,
    `_edge_tri_indices` and `_boundary_edge_indices`. All intermediate buffers are
    borrowed from `temporary_store` and released before returning.

    Args:
        temporary_store: pool from which temporary arrays are borrowed; forwarded as-is to the helpers
    """
    # Imported locally rather than at module level (presumably to avoid an import cycle -- TODO confirm)
    from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
    from warp.utils import array_scan

    device = self.tri_vertex_indices.device

    # Vertex-to-triangle adjacency in offsets/indices form
    vertex_tri_offsets, vertex_tri_indices = compress_node_indices(
        self.vertex_count(), self.tri_vertex_indices, temporary_store=temporary_store
    )
    self._vertex_tri_offsets = vertex_tri_offsets.detach()
    self._vertex_tri_indices = vertex_tri_indices.detach()

    # Per-vertex counters (zero-initialized) and the matching offsets buffer
    vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
    vertex_start_edge_count.array.zero_()
    vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)

    # Scratch storage sized for three edges per triangle
    vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(3 * self.cell_count()))
    vertex_edge_tris = borrow_temporary(temporary_store, dtype=int, device=device, shape=(3 * self.cell_count(), 2))

    # Count face edges starting at each vertex
    wp.launch(
        kernel=Trimesh._count_starting_edges_kernel,
        device=device,
        dim=self.cell_count(),
        inputs=[self.tri_vertex_indices, vertex_start_edge_count.array],
    )

    # Exclusive scan turns the per-vertex counts into start offsets
    array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)

    # Count number of unique edges (deduplicate across faces)
    # NOTE: this is an alias, not a copy -- the same temporary is passed to the kernel below
    # under its new name
    vertex_unique_edge_count = vertex_start_edge_count
    wp.launch(
        kernel=Trimesh._count_unique_starting_edges_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            self._vertex_tri_offsets,
            self._vertex_tri_indices,
            self.tri_vertex_indices,
            vertex_start_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
            vertex_edge_tris.array,
        ],
    )

    # Second exclusive scan, over the same (aliased) count buffer after the kernel above has run
    vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
    array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)

    # Get back edge count to host
    # NOTE(review): this reads the exclusive-scan entry of the last vertex, which equals the total
    # only if the last vertex starts no edge -- presumably guaranteed by how the counting kernels
    # pick an edge's start vertex; confirm against those kernels
    edge_count = int(
        host_read_at_index(
            vertex_unique_edge_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
        )
    )

    self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)
    self._edge_tri_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)

    # One integer flag per edge, filled by the compression kernel
    boundary_mask = borrow_temporary(temporary_store=temporary_store, shape=(edge_count,), dtype=int, device=device)

    # Compress edge data
    wp.launch(
        kernel=Trimesh._compress_edges_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            vertex_start_edge_offsets.array,
            vertex_unique_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
            vertex_edge_tris.array,
            self._edge_vertex_indices,
            self._edge_tri_indices,
            boundary_mask.array,
        ],
    )

    # Return the scratch buffers to the store; `vertex_unique_edge_count` is the same temporary
    # as `vertex_start_edge_count`, so that buffer is released exactly once
    vertex_start_edge_offsets.release()
    vertex_unique_edge_offsets.release()
    vertex_unique_edge_count.release()
    vertex_edge_ends.release()
    vertex_edge_tris.release()

    # Indices of the flagged edges become the boundary edge list
    boundary_edge_indices, _ = masked_indices(boundary_mask.array, temporary_store=temporary_store)
    self._boundary_edge_indices = boundary_edge_indices.detach()

    boundary_mask.release()
418
+
419
@wp.kernel
def _count_starting_edges_kernel(
    tri_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
):
    """Count, for each vertex, the triangle edges for which it is the lowest-index endpoint.

    One thread per triangle; edges shared between triangles are counted once per triangle.
    """
    tri = wp.tid()
    for corner in range(3):
        a = tri_vertex_indices[tri, corner]
        b = tri_vertex_indices[tri, (corner + 1) % 3]

        # Attribute the edge to its lowest-index endpoint
        wp.atomic_add(vertex_start_edge_count, wp.min(a, b), 1)
432
+
433
@wp.func
def _find(
    needle: int,
    values: wp.array(dtype=int),
    beg: int,
    end: int,
):
    """Linear search for ``needle`` in ``values[beg:end]``.

    Returns the index of the first match, or -1 if there is none.
    """
    for cursor in range(beg, end):
        if needle == values[cursor]:
            return cursor

    return -1
445
+
446
@wp.kernel
def _count_unique_starting_edges_kernel(
    vertex_tri_offsets: wp.array(dtype=int),
    vertex_tri_indices: wp.array(dtype=int),
    tri_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_start_edge_count: wp.array(dtype=int),
    edge_ends: wp.array(dtype=int),
    edge_tris: wp.array2d(dtype=int),
):
    """Deduplicate the edges owned by each vertex (one thread per vertex).

    For vertex ``v``, the unique edges whose lowest-index endpoint is ``v`` are
    written to ``edge_ends`` / ``edge_tris`` starting at
    ``vertex_start_edge_offsets[v]``, and ``vertex_start_edge_count[v]`` is
    overwritten with the number of unique edges found.
    """
    v = wp.tid()

    first_slot = vertex_start_edge_offsets[v]
    next_slot = first_slot

    tri_beg = vertex_tri_offsets[v]
    tri_end = vertex_tri_offsets[v + 1]

    for slot in range(tri_beg, tri_end):
        t = vertex_tri_indices[slot]

        for k in range(3):
            a = tri_vertex_indices[t, k]
            b = tri_vertex_indices[t, (k + 1) % 3]

            # Only handle edges for which v is the lowest-index endpoint
            if wp.min(a, b) == v:
                far_v = wp.max(a, b)

                # Has an edge towards far_v already been recorded for this vertex?
                known = Trimesh._find(far_v, edge_ends, first_slot, next_slot)

                if known != -1:
                    # Edge seen before: t becomes its second adjacent triangle
                    edge_tris[known, 1] = t
                else:
                    # New edge: both triangle slots start out as t
                    edge_ends[next_slot] = far_v
                    edge_tris[next_slot, 0] = t
                    edge_tris[next_slot, 1] = t
                    next_slot += 1

    vertex_start_edge_count[v] = next_slot - first_slot
487
+
488
@wp.kernel
def _compress_edges_kernel(
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_count: wp.array(dtype=int),
    uncompressed_edge_ends: wp.array(dtype=int),
    uncompressed_edge_tris: wp.array2d(dtype=int),
    edge_vertex_indices: wp.array(dtype=wp.vec2i),
    edge_tri_indices: wp.array(dtype=wp.vec2i),
    boundary_mask: wp.array(dtype=int),
):
    """Copy each vertex's unique edges into the final contiguous edge arrays.

    One thread per vertex. An edge whose two recorded triangles are identical
    is flagged with 1 in ``boundary_mask``, other edges with 0.
    """
    v = wp.tid()

    src_base = vertex_start_edge_offsets[v]
    dst_base = vertex_unique_edge_offsets[v]
    n_edges = vertex_unique_edge_count[v]

    for k in range(n_edges):
        src = src_base + k
        dst = dst_base + k

        tri_a = uncompressed_edge_tris[src, 0]
        tri_b = uncompressed_edge_tris[src, 1]

        edge_vertex_indices[dst] = wp.vec2i(v, uncompressed_edge_ends[src])
        edge_tri_indices[dst] = wp.vec2i(tri_a, tri_b)

        # Same triangle on both sides means only one triangle touches the edge
        boundary_mask[dst] = wp.where(tri_a == tri_b, 1, 0)
518
+
519
@wp.kernel
def _compute_tri_bounds(
    tri_vertex_indices: wp.array2d(dtype=int),
    positions: wp.array(dtype=wp.vec2),
    lowers: wp.array(dtype=wp.vec3),
    uppers: wp.array(dtype=wp.vec3),
):
    """Write the axis-aligned bounding box of each triangle (one thread per triangle).

    Vertex positions are passed through ``_bvh_vec`` before taking the
    component-wise minimum (``lowers``) and maximum (``uppers``).
    """
    t = wp.tid()
    a = _bvh_vec(positions[tri_vertex_indices[t, 0]])
    b = _bvh_vec(positions[tri_vertex_indices[t, 1]])
    c = _bvh_vec(positions[tri_vertex_indices[t, 2]])

    lo_x = wp.min(wp.min(a[0], b[0]), c[0])
    lo_y = wp.min(wp.min(a[1], b[1]), c[1])
    lo_z = wp.min(wp.min(a[2], b[2]), c[2])
    lowers[t] = wp.vec3(lo_x, lo_y, lo_z)

    hi_x = wp.max(wp.max(a[0], b[0]), c[0])
    hi_y = wp.max(wp.max(a[1], b[1]), c[1])
    hi_z = wp.max(wp.max(a[2], b[2]), c[2])
    uppers[t] = wp.vec3(hi_x, hi_y, hi_z)
541
+
542
+
543
# Device-side argument bundle for cell (triangle) functions of a 2D mesh.
# Field order matters: the struct is constructed positionally.
@wp.struct
class Trimesh2DCellArg:
    # Dimension-independent connectivity data (TrimeshCellArg is declared elsewhere in this file)
    topology: TrimeshCellArg
    # 2D vertex positions, indexed by vertex index
    positions: wp.array(dtype=wp.vec2)
547
+
548
+
549
# Device-side argument bundle for side (edge) functions of a 2D mesh.
@wp.struct
class Trimesh2DSideArg:
    # Dimension-independent edge connectivity data (TrimeshSideArg is declared elsewhere in this file)
    topology: TrimeshSideArg
    # 2D vertex positions, indexed by vertex index
    positions: wp.array(dtype=wp.vec2)
553
+
554
+
555
class Trimesh2D(Trimesh):
    """2D Triangular mesh geometry"""

    dimension = 2
    CellArg = Trimesh2DCellArg
    SideArg = Trimesh2DSideArg

    @wp.func
    def cell_position(args: CellArg, s: Sample):
        """Position of sample ``s``: its three element coordinates weight the triangle's vertices."""
        vidx = args.topology.tri_vertex_indices[s.element_index]
        w = s.element_coords
        return w[0] * args.positions[vidx[0]] + w[1] * args.positions[vidx[1]] + w[2] * args.positions[vidx[2]]

    @wp.func
    def cell_deformation_gradient(args: CellArg, s: Sample):
        """Matrix whose columns are the two triangle edge vectors issued from the first vertex."""
        vidx = args.topology.tri_vertex_indices[s.element_index]
        origin = args.positions[vidx[0]]
        to_second = args.positions[vidx[1]] - origin
        to_third = args.positions[vidx[2]] - origin
        return wp.matrix_from_cols(to_second, to_third)

    @wp.func
    def cell_lookup(args: CellArg, pos: wp.vec2):
        """Build a sample from the triangle and coordinates returned by the BVH lookup for ``pos``."""
        dist, tri, coords = Trimesh._bvh_lookup(args.topology, args.positions, pos)

        return make_free_sample(tri, coords)

    @wp.func
    def cell_lookup(args: CellArg, pos: wp.vec2, guess: Sample):
        """Same as the two-argument overload, with a fallback around ``guess``.

        When the BVH lookup yields no triangle, the neighbor lookup starting
        from the guess element is used instead.
        """
        dist, tri, coords = Trimesh._bvh_lookup(args.topology, args.positions, pos)

        if tri == NULL_ELEMENT_INDEX:
            dist, tri, coords = Trimesh._cell_neighbor_lookup(
                args.topology, args.positions, pos, guess.element_index
            )

        return make_free_sample(tri, coords)

    @wp.func
    def side_position(args: SideArg, s: Sample):
        """Position of sample ``s`` on its edge, interpolated linearly between the edge's two vertices."""
        vidx = args.topology.edge_vertex_indices[s.element_index]
        u = s.element_coords[0]
        return (1.0 - u) * args.positions[vidx[0]] + u * args.positions[vidx[1]]

    @wp.func
    def side_deformation_gradient(args: SideArg, s: Sample):
        """Edge vector, from the first to the second edge vertex."""
        vidx = args.topology.edge_vertex_indices[s.element_index]
        beg = args.positions[vidx[0]]
        end = args.positions[vidx[1]]
        return end - beg

    @wp.func
    def side_normal(args: SideArg, s: Sample):
        """Unit normal of the edge: the edge vector rotated by a quarter turn, ``(x, y) -> (-y, x)``."""
        vidx = args.topology.edge_vertex_indices[s.element_index]
        along = args.positions[vidx[1]] - args.positions[vidx[0]]

        return wp.normalize(wp.vec2(-along[1], along[0]))

    @wp.func
    def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
        """First of the two triangle indices stored for the edge."""
        return arg.topology.edge_tri_indices[side_index][0]

    @wp.func
    def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
        """Second of the two triangle indices stored for the edge."""
        return arg.topology.edge_tri_indices[side_index][1]

    @wp.func
    def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
        """Convert edge coordinates to coordinates in the edge's inner triangle."""
        return Trimesh._edge_to_tri_coords(
            args.topology, side_index, Trimesh2D.side_inner_cell_index(args, side_index), side_coords
        )

    @wp.func
    def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
        """Convert edge coordinates to coordinates in the edge's outer triangle."""
        return Trimesh._edge_to_tri_coords(
            args.topology, side_index, Trimesh2D.side_outer_cell_index(args, side_index), side_coords
        )

    @wp.func
    def side_from_cell_coords(
        args: SideArg,
        side_index: ElementIndex,
        tri_index: ElementIndex,
        tri_coords: Coords,
    ):
        """Convert coordinates in triangle ``tri_index`` to coordinates on edge ``side_index``."""
        return Trimesh._tri_to_edge_coords(args.topology, side_index, tri_index, tri_coords)

    @wp.func
    def side_to_cell_arg(side_arg: SideArg):
        """Assemble the cell argument struct from the side argument's cell topology and positions."""
        return Trimesh2DCellArg(side_arg.topology.cell_arg, side_arg.positions)

    @wp.kernel
    def _orient_edges(
        edge_vertex_indices: wp.array(dtype=wp.vec2i),
        edge_tri_indices: wp.array(dtype=wp.vec2i),
        tri_vertex_indices: wp.array2d(dtype=int),
        positions: wp.array(dtype=wp.vec2),
    ):
        """Orient each edge so that its normal does not point toward its first triangle.

        One thread per edge; the two vertex indices of the edge are swapped in
        place when needed.
        """
        e = wp.tid()

        edge_vidx = edge_vertex_indices[e]
        tri_vidx = tri_vertex_indices[edge_tri_indices[e][0]]

        centroid = (positions[tri_vidx[0]] + positions[tri_vidx[1]] + positions[tri_vidx[2]]) / 3.0

        beg = positions[edge_vidx[0]]
        end = positions[edge_vidx[1]]

        midpoint = 0.5 * (end + beg)
        along = end - beg
        normal = wp.vec2(-along[1], along[0])

        # A normal pointing toward the first triangle's centroid means the edge is reversed
        if wp.dot(centroid - midpoint, normal) > 0.0:
            edge_vertex_indices[e] = wp.vec2i(edge_vidx[1], edge_vidx[0])
676
+
677
+
678
# Device-side argument bundle for cell (triangle) functions of a 3D mesh.
# Field order matters: the struct is constructed positionally.
@wp.struct
class Trimesh3DCellArg:
    # Dimension-independent connectivity data (TrimeshCellArg is declared elsewhere in this file)
    topology: TrimeshCellArg
    # 3D vertex positions, indexed by vertex index
    positions: wp.array(dtype=wp.vec3)
682
+
683
+
684
# Device-side argument bundle for side (edge) functions of a 3D mesh.
@wp.struct
class Trimesh3DSideArg:
    # Dimension-independent edge connectivity data (TrimeshSideArg is declared elsewhere in this file)
    topology: TrimeshSideArg
    # 3D vertex positions, indexed by vertex index
    positions: wp.array(dtype=wp.vec3)
688
+
689
+
690
class Trimesh3D(Trimesh):
    """3D Triangular mesh geometry"""

    dimension = 3
    CellArg = Trimesh3DCellArg
    SideArg = Trimesh3DSideArg

    @wp.func
    def cell_position(args: CellArg, s: Sample):
        """Position of sample ``s``: its three element coordinates weight the triangle's vertices."""
        vidx = args.topology.tri_vertex_indices[s.element_index]
        w = s.element_coords
        return w[0] * args.positions[vidx[0]] + w[1] * args.positions[vidx[1]] + w[2] * args.positions[vidx[2]]

    @wp.func
    def cell_deformation_gradient(args: CellArg, s: Sample):
        """Matrix whose columns are the two triangle edge vectors issued from the first vertex."""
        vidx = args.topology.tri_vertex_indices[s.element_index]
        origin = args.positions[vidx[0]]
        to_second = args.positions[vidx[1]] - origin
        to_third = args.positions[vidx[2]] - origin
        return wp.matrix_from_cols(to_second, to_third)

    @wp.func
    def cell_lookup(args: CellArg, pos: wp.vec3):
        """Build a sample from the triangle and coordinates returned by the BVH lookup for ``pos``."""
        dist, tri, coords = Trimesh._bvh_lookup(args.topology, args.positions, pos)

        return make_free_sample(tri, coords)

    @wp.func
    def cell_lookup(args: CellArg, pos: wp.vec3, guess: Sample):
        """Same as the two-argument overload, with a fallback around ``guess``.

        When the BVH lookup yields no triangle, the neighbor lookup starting
        from the guess element is used instead.
        """
        dist, tri, coords = Trimesh._bvh_lookup(args.topology, args.positions, pos)

        if tri == NULL_ELEMENT_INDEX:
            dist, tri, coords = Trimesh._cell_neighbor_lookup(
                args.topology, args.positions, pos, guess.element_index
            )

        return make_free_sample(tri, coords)

    @wp.func
    def side_position(args: SideArg, s: Sample):
        """Position of sample ``s`` on its edge, interpolated linearly between the edge's two vertices."""
        vidx = args.topology.edge_vertex_indices[s.element_index]
        u = s.element_coords[0]
        return (1.0 - u) * args.positions[vidx[0]] + u * args.positions[vidx[1]]

    @wp.func
    def side_deformation_gradient(args: SideArg, s: Sample):
        """Edge vector, from the first to the second edge vertex."""
        vidx = args.topology.edge_vertex_indices[s.element_index]
        beg = args.positions[vidx[0]]
        end = args.positions[vidx[1]]
        return end - beg

    @wp.func
    def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
        """First of the two triangle indices stored for the edge."""
        return arg.topology.edge_tri_indices[side_index][0]

    @wp.func
    def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
        """Second of the two triangle indices stored for the edge."""
        return arg.topology.edge_tri_indices[side_index][1]

    @wp.func
    def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
        """Convert edge coordinates to coordinates in the edge's inner triangle."""
        return Trimesh._edge_to_tri_coords(
            args.topology, side_index, Trimesh3D.side_inner_cell_index(args, side_index), side_coords
        )

    @wp.func
    def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
        """Convert edge coordinates to coordinates in the edge's outer triangle."""
        return Trimesh._edge_to_tri_coords(
            args.topology, side_index, Trimesh3D.side_outer_cell_index(args, side_index), side_coords
        )

    @wp.func
    def side_from_cell_coords(
        args: SideArg,
        side_index: ElementIndex,
        tri_index: ElementIndex,
        tri_coords: Coords,
    ):
        """Convert coordinates in triangle ``tri_index`` to coordinates on edge ``side_index``."""
        return Trimesh._tri_to_edge_coords(args.topology, side_index, tri_index, tri_coords)

    @wp.func
    def side_to_cell_arg(side_arg: SideArg):
        """Assemble the cell argument struct from the side argument's cell topology and positions."""
        return Trimesh3DCellArg(side_arg.topology.cell_arg, side_arg.positions)

    @wp.kernel
    def _orient_edges(
        edge_vertex_indices: wp.array(dtype=wp.vec2i),
        edge_tri_indices: wp.array(dtype=wp.vec2i),
        tri_vertex_indices: wp.array2d(dtype=int),
        positions: wp.array(dtype=wp.vec3),
    ):
        """Orient each edge so that its in-plane normal does not point toward its first triangle.

        One thread per edge; the two vertex indices of the edge are swapped in
        place when needed. The edge normal is the cross product of the edge
        vector with the first triangle's normal.
        """
        e = wp.tid()

        edge_vidx = edge_vertex_indices[e]
        tri_vidx = tri_vertex_indices[edge_tri_indices[e][0]]

        t0 = positions[tri_vidx[0]]
        t1 = positions[tri_vidx[1]]
        t2 = positions[tri_vidx[2]]

        centroid = (t0 + t1 + t2) / 3.0
        plane_normal = wp.cross(t1 - t0, t2 - t0)

        beg = positions[edge_vidx[0]]
        end = positions[edge_vidx[1]]

        midpoint = 0.5 * (end + beg)
        along = end - beg
        normal = wp.cross(along, plane_normal)

        # A normal pointing toward the first triangle's centroid means the edge is reversed
        if wp.dot(centroid - midpoint, normal) > 0.0:
            edge_vertex_indices[e] = wp.vec2i(edge_vidx[1], edge_vidx[0])