warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,693 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import cached_property
17
+ from typing import Any, ClassVar
18
+
19
+ import warp as wp
20
+ from warp._src.codegen import Struct
21
+ from warp._src.fem import cache
22
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, ElementKind, Sample, make_free_sample
23
+
24
+ from .element import Element
25
+
26
_wp_module_name_ = "warp.fem.geometry.geometry"

# Warp matrix type with 3 rows and 2 columns of floats.
_mat32 = wp.mat(shape=(3, 2), dtype=float)

# uint64 zero; by its name, the id value meaning "no BVH".
_NULL_BVH_ID = wp.uint64(0)

# Tuning values for a coordinate-lookup procedure (iteration cap, step, tolerance);
# presumably used by code outside this view -- TODO confirm at the call sites.
_COORD_LOOKUP_ITERATIONS = 24
_COORD_LOOKUP_STEP = 1.0
_COORD_LOOKUP_EPS = 2.0**-20

# Lower / upper padding bounds, presumably applied to BVH bounding boxes (TODO confirm).
_BVH_MIN_PADDING = 2.0**-16
_BVH_MAX_PADDING = 2.0**16
36
+
37
+
38
class Geometry:
    """
    Interface class for discrete geometries.

    A geometry is made of cells and sides. A side is either on the boundary, or interior
    (shared between cells).
    """

    # Class-level default of 0; presumably the embedding-space dimension that concrete
    # geometries override (TODO confirm -- the cell manifold dimension is `cell_dimension`).
    dimension: int = 0

    # Class-level default for BVH storage; not assigned anywhere in the visible code.
    _bvhs = None
48
+
49
def cell_count(self):
    """Return how many cells the geometry contains. Must be provided by subclasses."""
    raise NotImplementedError()

def side_count(self):
    """Return how many sides (boundary and interior) the geometry contains. Must be provided by subclasses."""
    raise NotImplementedError()

def boundary_side_count(self):
    """Return how many boundary sides -- sides touching a single neighbour cell -- the geometry has.

    Must be provided by subclasses.
    """
    raise NotImplementedError()
60
+
61
def reference_cell(self) -> Element:
    """Return the prototypical element describing a cell. Must be provided by subclasses."""
    raise NotImplementedError()

def reference_side(self) -> Element:
    """Return the prototypical element describing a side. Must be provided by subclasses."""
    raise NotImplementedError()
68
+
69
@property
def cell_dimension(self) -> int:
    """Manifold dimension of the geometry cells, read from the reference cell's prototype."""
    prototype = self.reference_cell().prototype
    return prototype.dimension

@property
def base(self) -> "Geometry":
    """Geometry from which this one derives its topology; ``self`` unless a subclass overrides it."""
    return self

@property
def name(self) -> str:
    """Human-readable name of the geometry; defaults to the concrete class name."""
    cls = self.__class__
    return cls.__name__

def __str__(self) -> str:
    """String form of the geometry, which is simply its :attr:`name`."""
    return self.name
85
+
86
# Argument structure types. They are only annotated here, never assigned; the default
# `*_arg_value` methods instantiate them via `self.CellArg()` etc., so concrete geometries
# must supply them.
CellArg: Struct
"""Structure containing arguments to be passed to device functions evaluating cell-related quantities"""

SideArg: Struct
"""Structure containing arguments to be passed to device functions evaluating side-related quantities"""

SideIndexArg: Struct
"""Structure containing arguments to be passed to device functions for indexing sides"""
94
+
95
@cache.cached_arg_value
def cell_arg_value(self, device) -> "Geometry.CellArg":
    """Value of the arguments to be passed to cell-related device functions.

    The default implementation default-constructs a ``CellArg`` and lets
    :meth:`fill_cell_arg` populate it. The ``cache.cached_arg_value`` decorator
    presumably caches the result per device (see that helper).

    Subclasses must override at least one of ``cell_arg_value`` / ``fill_cell_arg``.
    """
    args = self.CellArg()
    self.fill_cell_arg(args, device)
    return args

def fill_cell_arg(self, args: "Geometry.CellArg", device):
    """Fill the arguments to be passed to cell-related device functions.

    The default implementation copies the value returned by :meth:`cell_arg_value`.

    Raises:
        NotImplementedError: if the subclass overrides neither ``cell_arg_value``
            nor ``fill_cell_arg`` (the two defaults delegate to each other).
    """
    # `self.cell_arg_value` creates a new bound method on every access, so comparing it with
    # `is` against the class-level function is always False. Compare the underlying function
    # instead, so that the default/default case raises here rather than recursing.
    # Assumes cache.cached_arg_value returns a plain function; otherwise getattr yields None
    # and we fall through, as the original code did.
    if getattr(self.cell_arg_value, "__func__", None) is __class__.cell_arg_value:
        raise NotImplementedError(f"{type(self).__name__} must override cell_arg_value or fill_cell_arg")
    args.assign(self.cell_arg_value(device))
107
+
108
@staticmethod
def cell_position(args: "Geometry.CellArg", s: "Sample"):
    """Device function: world-space position of the cell sample point ``s``. Abstract."""
    raise NotImplementedError()

@staticmethod
def cell_deformation_gradient(args: "Geometry.CellArg", s: "Sample"):
    """Device function: transpose of the gradient of world position w.r.t. the reference cell. Abstract."""
    raise NotImplementedError()

@staticmethod
def cell_inverse_deformation_gradient(args: "Geometry.CellArg", s: "Sample"):
    """Device function: inverse deformation gradient of the cell at ``s``.

    This is the matrix that, applied on the right, maps a gradient w.r.t. cell space to a
    gradient w.r.t. world space. Abstract.
    """
    raise NotImplementedError()

@staticmethod
def cell_measure(args: "Geometry.CellArg", s: "Sample"):
    """Device function: measure determinant (e.g. volume, area) of the cell at ``s``. Abstract."""
    raise NotImplementedError()
129
+
130
# Default Warp device function for the cell measure ratio: always returns 1.0.
# `args` is typed `Any`, so it is generic over the argument struct type.
@wp.func
def cell_measure_ratio(args: Any, s: Sample):
    return 1.0
133
+
134
@staticmethod
def cell_normal(args: "Geometry.CellArg", s: "Sample"):
    """Device function: normal of the cell element at sample point ``s``.

    Zero when the element has the same dimension as the embedding space. Abstract.
    """
    raise NotImplementedError()
140
+
141
@cache.cached_arg_value
def side_arg_value(self, device) -> "Geometry.SideArg":
    """Value of the arguments to be passed to side-related device functions.

    The default implementation default-constructs a ``SideArg`` and lets
    :meth:`fill_side_arg` populate it. The ``cache.cached_arg_value`` decorator
    presumably caches the result per device (see that helper).

    Subclasses must override at least one of ``side_arg_value`` / ``fill_side_arg``.
    """
    args = self.SideArg()
    self.fill_side_arg(args, device)
    return args

def fill_side_arg(self, args: "Geometry.SideArg", device):
    """Fill the arguments to be passed to side-related device functions.

    The default implementation copies the value returned by :meth:`side_arg_value`.

    Raises:
        NotImplementedError: if the subclass overrides neither ``side_arg_value``
            nor ``fill_side_arg`` (the two defaults delegate to each other).
    """
    # A bound method is never identical to the class-level function, so compare the
    # underlying function; otherwise the default/default case recurses instead of raising.
    if getattr(self.side_arg_value, "__func__", None) is __class__.side_arg_value:
        raise NotImplementedError(f"{type(self).__name__} must override side_arg_value or fill_side_arg")
    args.assign(self.side_arg_value(device))
153
+
154
@cache.cached_arg_value
def side_index_arg_value(self, device) -> "Geometry.SideIndexArg":
    """Value of the arguments to be passed to side-indexing device functions (e.g. ``boundary_side_index``).

    The default implementation default-constructs a ``SideIndexArg`` and lets
    :meth:`fill_side_index_arg` populate it. The ``cache.cached_arg_value`` decorator
    presumably caches the result per device (see that helper).

    Subclasses must override at least one of ``side_index_arg_value`` / ``fill_side_index_arg``.
    """
    args = self.SideIndexArg()
    self.fill_side_index_arg(args, device)
    return args

def fill_side_index_arg(self, args: "Geometry.SideIndexArg", device):
    """Fill the arguments to be passed to side-indexing device functions.

    The default implementation copies the value returned by :meth:`side_index_arg_value`.

    Raises:
        NotImplementedError: if the subclass overrides neither ``side_index_arg_value``
            nor ``fill_side_index_arg`` (the two defaults delegate to each other).
    """
    # A bound method is never identical to the class-level function, so compare the
    # underlying function; otherwise the default/default case recurses instead of raising.
    if getattr(self.side_index_arg_value, "__func__", None) is __class__.side_index_arg_value:
        raise NotImplementedError(
            f"{type(self).__name__} must override side_index_arg_value or fill_side_index_arg"
        )
    args.assign(self.side_index_arg_value(device))
166
+
167
@staticmethod
def boundary_side_index(args: "Geometry.SideIndexArg", boundary_side_index: int):
    """Device function: map a boundary side number to the corresponding side index. Abstract."""
    raise NotImplementedError()

@staticmethod
def side_position(args: "Geometry.SideArg", s: "Sample"):
    """Device function: world-space position of the side at sample point ``s``. Abstract."""
    raise NotImplementedError()

@staticmethod
def side_deformation_gradient(args: "Geometry.SideArg", s: "Sample"):
    """Device function: gradient of world position w.r.t. the reference side. Abstract."""
    raise NotImplementedError()
181
+
182
@staticmethod
def side_inner_inverse_deformation_gradient(args: "Geometry.SideArg", side_index: ElementIndex, coords: Coords):
    """Device function returning the matrix right-transforming a gradient w.r.t. inner cell space to a gradient w.r.t. world space
    (i.e. the inverse deformation gradient)

    Args:
        args: side-related argument structure (previously mis-annotated as ``"Geometry.Siderg"``)
        side_index: index of the side
        coords: coordinates of the point on the side
    """
    raise NotImplementedError

@staticmethod
def side_outer_inverse_deformation_gradient(args: "Geometry.SideArg", side_index: ElementIndex, coords: Coords):
    """Device function returning the matrix right-transforming a gradient w.r.t. outer cell space to a gradient w.r.t. world space
    (i.e. the inverse deformation gradient)

    Args:
        args: side-related argument structure, as for every other ``side_*`` device function
            (previously annotated ``"Geometry.CellArg"``)
        side_index: index of the side
        coords: coordinates of the point on the side
    """
    raise NotImplementedError
195
+
196
@staticmethod
def side_measure(args: "Geometry.SideArg", s: "Sample"):
    """Device function: measure determinant (e.g. volume, area) of the side at ``s``. Abstract."""
    raise NotImplementedError()

@staticmethod
def side_measure_ratio(args: "Geometry.SideArg", s: "Sample"):
    """Device function: ratio of a side's measure to that of its neighbour cells. Abstract."""
    raise NotImplementedError()

@staticmethod
def side_normal(args: "Geometry.SideArg", s: "Sample"):
    """Device function: normal of the side element at sample point ``s``. Abstract."""
    raise NotImplementedError()
210
+
211
@staticmethod
def side_inner_cell_index(args: "Geometry.SideArg", side_index: ElementIndex):
    """Device function: index of the inner cell adjacent to ``side_index``. Abstract."""
    raise NotImplementedError()

@staticmethod
def side_outer_cell_index(args: "Geometry.SideArg", side_index: ElementIndex):
    """Device function: index of the outer cell adjacent to ``side_index``. Abstract."""
    raise NotImplementedError()

@staticmethod
def side_inner_cell_coords(args: "Geometry.SideArg", side_index: ElementIndex, side_coords: Coords):
    """Device function: express a point given in side coordinates in the inner cell's coordinates. Abstract."""
    raise NotImplementedError()

@staticmethod
def side_outer_cell_coords(args: "Geometry.SideArg", side_index: ElementIndex, side_coords: Coords):
    """Device function: express a point given in side coordinates in the outer cell's coordinates. Abstract."""
    raise NotImplementedError()

@staticmethod
def side_from_cell_coords(
    args: "Geometry.SideArg",
    side_index: ElementIndex,
    element_index: ElementIndex,
    element_coords: Coords,
):
    """Device function: convert coordinates on a cell to coordinates on a side.

    Returns ``OUTSIDE`` when the conversion is not possible. Abstract.
    """
    raise NotImplementedError()
240
+
241
@staticmethod
def side_to_cell_arg(side_arg: "Geometry.SideArg"):
    """Device function: turn a side-related argument value into a cell-related one.

    Used to promote trace samples to the full space. Abstract.
    """
    raise NotImplementedError()
245
+
246
# Default implementations for dependent quantities
# Can be overridden in derived classes if more efficient implementations exist

# Maps each dependent attribute name to a factory that receives the geometry instance
# and returns the default implementation built by the matching ``_make_*`` method.
# Consumed by ``cache.setup_dynamic_attributes`` in ``_make_default_dependent_implementations``.
# NOTE(review): presumably a factory is only invoked when the derived class does not
# already provide the attribute — confirm against ``cache.setup_dynamic_attributes``.
# Some factories (e.g. ``_make_cell_normal``, ``_make_side_normal``) may return ``None``.
_dynamic_attribute_constructors: ClassVar = {
    "cell_inverse_deformation_gradient": lambda obj: obj._make_cell_inverse_deformation_gradient(),
    "cell_measure": lambda obj: obj._make_cell_measure(),
    "cell_normal": lambda obj: obj._make_cell_normal(),
    "side_inverse_deformation_gradient": lambda obj: obj._make_side_inverse_deformation_gradient(),
    "side_inner_inverse_deformation_gradient": lambda obj: obj._make_side_inner_inverse_deformation_gradient(),
    "side_outer_inverse_deformation_gradient": lambda obj: obj._make_side_outer_inverse_deformation_gradient(),
    "side_measure": lambda obj: obj._make_side_measure(),
    "side_measure_ratio": lambda obj: obj._make_side_measure_ratio(),
    "side_normal": lambda obj: obj._make_side_normal(),
    "compute_cell_bounds": lambda obj: obj._make_compute_cell_bounds(),
}
261
+
262
def _make_default_dependent_implementations(self):
    """Set up the default dependent attributes on this geometry instance.

    Delegates to ``cache.setup_dynamic_attributes``, scoped to the class defining this
    method (through the ``__class__`` cell). Presumably this uses the factories in
    ``_dynamic_attribute_constructors`` — confirm in ``cache``.
    """
    defining_class = __class__
    cache.setup_dynamic_attributes(self, cls=defining_class)
264
+
265
@wp.func
def _element_measure(F: wp.vec2):
    """Measure of a 1D element embedded in 2D: the Euclidean length of its tangent ``F``."""
    tangent_length = wp.length(F)
    return tangent_length
268
+
269
@wp.func
def _element_measure(F: wp.vec3):
    """Measure of a 1D element embedded in 3D: the Euclidean length of its tangent ``F``."""
    tangent_length = wp.length(F)
    return tangent_length
272
+
273
@wp.func
def _element_measure(F: _mat32):
    """Measure of a 2D element embedded in 3D: the norm of the cross product of the columns of ``F``."""
    # Rows of the transpose are the two tangent columns of F
    columns = wp.transpose(F)
    return wp.length(wp.cross(columns[0], columns[1]))
278
+
279
@wp.func
def _element_measure(F: wp.mat33):
    """Measure of a full-dimensional 3D element: the absolute Jacobian determinant of ``F``."""
    jacobian = wp.determinant(F)
    return wp.abs(jacobian)
282
+
283
@wp.func
def _element_measure(F: wp.mat22):
    """Measure of a full-dimensional 2D element: the absolute Jacobian determinant of ``F``."""
    jacobian = wp.determinant(F)
    return wp.abs(jacobian)
286
+
287
@wp.func
def _element_normal(F: wp.vec2):
    """Unit normal of a 1D element in 2D: its tangent ``F`` rotated by a quarter turn, then normalized."""
    rotated = wp.vec2(F[1], -F[0])
    return wp.normalize(rotated)
290
+
291
@wp.func
def _element_normal(F: _mat32):
    """Unit normal of a 2D element in 3D: the normalized cross product of the columns of ``F``."""
    # Rows of the transpose are the two tangent columns of F
    columns = wp.transpose(F)
    return wp.normalize(wp.cross(columns[0], columns[1]))
296
+
297
def _make_cell_measure(self):
    """Build the default ``cell_measure`` device function.

    The measure at a sample is the measure of the cell deformation gradient
    (see ``_element_measure``) multiplied by the measure of the reference cell.
    """
    reference_cell = self.reference_cell().prototype
    REF_MEASURE = wp.constant(reference_cell.measure())

    @cache.dynamic_func(suffix=self.name)
    def cell_measure(args: self.CellArg, s: Sample):
        return REF_MEASURE * Geometry._element_measure(self.cell_deformation_gradient(args, s))

    return cell_measure
306
+
307
def _make_cell_normal(self):
    """Build the default ``cell_normal`` device function, or return ``None``.

    - Full-dimensional cells (cell dimension == geometry dimension) have no normal;
      the returned function yields a zero vector.
    - Codimension-one cells (curves in 2D, surfaces in 3D) use the normalized normal
      of their deformation gradient (see ``_element_normal``).
    - Any other codimension is unsupported and ``None`` is returned.
    """
    cell_dim = self.reference_cell().prototype.dimension
    geo_dim = self.dimension
    # Reuse the shared vector-type cache (as done for ``pos_type`` elsewhere in this class)
    # rather than creating a fresh ``wp.vec`` type on every call
    normal_vec = cache.cached_vec_type(geo_dim, dtype=float)

    @cache.dynamic_func(suffix=self.name)
    def zero_normal(args: self.CellArg, s: Sample):
        return normal_vec(0.0)

    @cache.dynamic_func(suffix=self.name)
    def cell_hyperplane_normal(args: self.CellArg, s: Sample):
        F = self.cell_deformation_gradient(args, s)
        return Geometry._element_normal(F)

    if cell_dim == geo_dim:
        return zero_normal
    if cell_dim == geo_dim - 1:
        return cell_hyperplane_normal

    return None
327
+
328
def _make_cell_inverse_deformation_gradient(self):
    """Build the default ``cell_inverse_deformation_gradient`` device function.

    Full-dimensional cells use the true inverse of the deformation gradient ``F``;
    lower-dimensional (embedded) cells use the left pseudo-inverse ``(F^T F)^-1 F^T``.
    """
    is_full_dimensional = self.reference_cell().prototype.dimension == self.dimension

    @cache.dynamic_func(suffix=self.name)
    def cell_inverse_deformation_gradient(cell_arg: self.CellArg, s: Sample):
        return wp.inverse(self.cell_deformation_gradient(cell_arg, s))

    @cache.dynamic_func(suffix=self.name)
    def cell_pseudoinverse_deformation_gradient(cell_arg: self.CellArg, s: Sample):
        F = self.cell_deformation_gradient(cell_arg, s)
        F_T = wp.transpose(F)
        return wp.inverse(F_T * F) * F_T

    if is_full_dimensional:
        return cell_inverse_deformation_gradient
    return cell_pseudoinverse_deformation_gradient
343
+
344
def _make_side_inverse_deformation_gradient(self):
    """Build the default ``side_inverse_deformation_gradient`` device function.

    Three cases, depending on the side dimension:

    - same as the geometry dimension: true inverse of the side deformation gradient;
    - 1 (edges): ``F`` is a tangent vector, and its pseudo-inverse is the row ``F / |F|^2``;
    - otherwise: the left pseudo-inverse ``(F^T F)^-1 F^T``.
    """
    side_dim = self.reference_side().prototype.dimension

    if side_dim == self.dimension:

        @cache.dynamic_func(suffix=self.name)
        def side_inverse_deformation_gradient(side_arg: self.SideArg, s: Sample):
            return wp.inverse(self.side_deformation_gradient(side_arg, s))

        result = side_inverse_deformation_gradient

    elif side_dim == 1:

        @cache.dynamic_func(suffix=self.name)
        def edge_pseudoinverse_deformation_gradient(side_arg: self.SideArg, s: Sample):
            tangent = self.side_deformation_gradient(side_arg, s)
            return wp.matrix_from_rows(tangent / wp.dot(tangent, tangent))

        result = edge_pseudoinverse_deformation_gradient

    else:

        @cache.dynamic_func(suffix=self.name)
        def side_pseudoinverse_deformation_gradient(side_arg: self.SideArg, s: Sample):
            F = self.side_deformation_gradient(side_arg, s)
            F_T = wp.transpose(F)
            return wp.inverse(F_T * F) * F_T

        result = side_pseudoinverse_deformation_gradient

    return result
372
+
373
def _make_side_measure(self):
    """Build the default ``side_measure`` device function.

    The measure at a sample is the measure of the side deformation gradient
    (see ``_element_measure``) multiplied by the measure of the reference side.
    """
    reference_side = self.reference_side().prototype
    REF_MEASURE = wp.constant(reference_side.measure())

    @cache.dynamic_func(suffix=self.name)
    def side_measure(args: self.SideArg, s: Sample):
        return REF_MEASURE * Geometry._element_measure(self.side_deformation_gradient(args, s))

    return side_measure
382
+
383
def _make_side_measure_ratio(self):
    """Build the default ``side_measure_ratio`` device function.

    The side measure is divided by the smaller of the measures of the inner and outer
    cells. Each cell measure is evaluated at the side sample's location within that cell.
    """

    @cache.dynamic_func(suffix=self.name)
    def side_measure_ratio(args: self.SideArg, s: Sample):
        side = s.element_index
        cell_arg = self.side_to_cell_arg(args)

        inner_sample = make_free_sample(
            self.side_inner_cell_index(args, side),
            self.side_inner_cell_coords(args, side, s.element_coords),
        )
        outer_sample = make_free_sample(
            self.side_outer_cell_index(args, side),
            self.side_outer_cell_coords(args, side, s.element_coords),
        )

        smallest_cell_measure = wp.min(
            self.cell_measure(cell_arg, inner_sample),
            self.cell_measure(cell_arg, outer_sample),
        )
        return self.side_measure(args, s) / smallest_cell_measure

    return side_measure_ratio
397
+
398
def _make_side_normal(self):
    """Build the default ``side_normal`` device function, or return ``None``.

    Only codimension-one sides (side dimension == geometry dimension - 1) get a default,
    computed as the normalized normal of the side deformation gradient.
    """
    is_hyperplane = self.reference_side().prototype.dimension == self.dimension - 1

    @cache.dynamic_func(suffix=self.name)
    def hyperplane_normal(args: self.SideArg, s: Sample):
        return Geometry._element_normal(self.side_deformation_gradient(args, s))

    return hyperplane_normal if is_hyperplane else None
411
+
412
def _make_side_inner_inverse_deformation_gradient(self):
    """Build the default ``side_inner_inverse_deformation_gradient`` device function.

    Evaluates the inner cell's inverse deformation gradient at the point of that cell
    corresponding to the side sample.
    """

    @cache.dynamic_func(suffix=self.name)
    def side_inner_inverse_deformation_gradient(args: self.SideArg, s: Sample):
        side = s.element_index
        inner_sample = make_free_sample(
            self.side_inner_cell_index(args, side),
            self.side_inner_cell_coords(args, side, s.element_coords),
        )
        return self.cell_inverse_deformation_gradient(self.side_to_cell_arg(args), inner_sample)

    return side_inner_inverse_deformation_gradient
421
+
422
def _make_side_outer_inverse_deformation_gradient(self):
    """Build the default ``side_outer_inverse_deformation_gradient`` device function.

    Evaluates the outer cell's inverse deformation gradient at the point of that cell
    corresponding to the side sample.
    """

    @cache.dynamic_func(suffix=self.name)
    def side_outer_inverse_deformation_gradient(args: self.SideArg, s: Sample):
        side = s.element_index
        outer_sample = make_free_sample(
            self.side_outer_cell_index(args, side),
            self.side_outer_cell_coords(args, side, s.element_coords),
        )
        return self.cell_inverse_deformation_gradient(self.side_to_cell_arg(args), outer_sample)

    return side_outer_inverse_deformation_gradient
431
+
432
def _make_element_coordinates(self, element_kind: ElementKind, assume_linear: bool = False):
    """Build a device function mapping a world position to reference coordinates in one element.

    The returned ``element_coordinates(args, element_index, pos)`` runs Newton iterations
    starting from the reference element center. Each step applies the element's inverse
    deformation gradient to the position residual ``pos - x``. The result is not projected
    back onto the element, so it may lie outside it.

    Args:
        element_kind: ``ElementKind.CELL`` to work on cells; any other value works on sides.
        assume_linear: If ``True``, a single full Newton step is taken, which is exact when
            the element map is affine. Otherwise up to ``_COORD_LOOKUP_ITERATIONS`` damped
            steps (``_COORD_LOOKUP_STEP``) are taken. The loop stops early once the squared
            step length drops below ``_COORD_LOOKUP_EPS``.
    """
    pos_type = cache.cached_vec_type(self.dimension, dtype=float)

    # Select cell- or side-specific reference element, argument type and device functions
    if element_kind == ElementKind.CELL:
        ref_elt = self.reference_cell().prototype
        arg_type = self.CellArg
        elt_pos = self.cell_position
        elt_inv_grad = self.cell_inverse_deformation_gradient
    else:
        ref_elt = self.reference_side().prototype
        arg_type = self.SideArg
        elt_pos = self.side_position
        elt_inv_grad = self.side_inverse_deformation_gradient

    # Newton initial guess: center of the reference element
    elt_center = Coords(ref_elt.center())

    ITERATIONS = 1 if assume_linear else _COORD_LOOKUP_ITERATIONS
    STEP = 1.0 if assume_linear else _COORD_LOOKUP_STEP

    # Suffix includes element kind and linearity so each variant gets its own cached function
    @cache.dynamic_func(suffix=f"{self.name}{element_kind}{assume_linear}")
    def element_coordinates(args: arg_type, element_index: ElementIndex, pos: pos_type):
        coords = elt_center

        # Newton loop (single iteration in linear case)
        for _k in range(ITERATIONS):
            s = make_free_sample(element_index, coords)
            x = elt_pos(args, s)
            dc = elt_inv_grad(args, s) * (pos - x)
            # Convergence test is compiled out entirely in the linear case
            if wp.static(not assume_linear):
                if wp.length_sq(dc) < _COORD_LOOKUP_EPS:
                    break
            # coord_delta converts the step into the reference element's coordinate representation
            coords = coords + ref_elt.coord_delta(STEP * dc)

        return coords

    return element_coordinates
468
+
469
def _make_cell_coordinates(self, assume_linear: bool = False):
    """Build a device function mapping world positions to (unprojected) reference coordinates in a cell."""
    kind = ElementKind.CELL
    return self._make_element_coordinates(element_kind=kind, assume_linear=assume_linear)
471
+
472
def _make_side_coordinates(self, assume_linear: bool = False):
    """Build a device function mapping world positions to (unprojected) reference coordinates on a side."""
    kind = ElementKind.SIDE
    return self._make_element_coordinates(element_kind=kind, assume_linear=assume_linear)
474
+
475
def _make_element_closest_point(self, element_kind: ElementKind, assume_linear: bool = False):
    """Build a device function finding the point of an element closest to a world position.

    The returned function ``(args, cell_index, pos)`` yields a tuple
    ``(coords, squared_distance)``. ``coords`` are reference coordinates inside the element
    (kept feasible by ``ref_elt.project``). ``squared_distance`` is ``|pos - x|^2``, where
    ``x`` is the world position at ``coords``.

    Args:
        element_kind: ``ElementKind.CELL`` to work on cells; any other value works on sides.
        assume_linear: Forwarded to ``_make_element_coordinates`` for the initial,
            unconstrained guess.
    """
    pos_type = cache.cached_vec_type(self.dimension, dtype=float)

    element_coordinates = self._make_element_coordinates(element_kind=element_kind, assume_linear=assume_linear)

    if element_kind == ElementKind.CELL:
        ref_elt = self.reference_cell().prototype
        arg_type = self.CellArg
        elt_pos = self.cell_position
        elt_def_grad = self.cell_deformation_gradient
    else:
        ref_elt = self.reference_side().prototype
        arg_type = self.SideArg
        elt_pos = self.side_position
        elt_def_grad = self.side_deformation_gradient

    # Note: the device function is named ``cell_closest_point`` for sides as well;
    # the suffix (which includes ``element_kind``) keeps the variants distinct
    @cache.dynamic_func(suffix=f"{self.name}{element_kind}{assume_linear}")
    def cell_closest_point(args: arg_type, cell_index: ElementIndex, pos: pos_type):
        # First get unconstrained coordinates, may use newton for this
        coords = element_coordinates(args, cell_index, pos)

        # Now do projected gradient
        # For interior points should exit at first iteration
        for _k in range(_COORD_LOOKUP_ITERATIONS):
            cur_coords = coords
            s = make_free_sample(cell_index, cur_coords)
            x = elt_pos(args, s)

            F = elt_def_grad(args, s)
            # Squared Frobenius norm of F, used to scale the gradient step
            # NOTE(review): for 1D sides F is a vector (see the edge branch of
            # _make_side_inverse_deformation_gradient); confirm wp.ddot accepts vectors
            F_scale = wp.ddot(F, F)

            dc = (pos - x) @ F  # gradient step
            coords = ref_elt.project(cur_coords + ref_elt.coord_delta(dc / F_scale))

            if wp.length_sq(coords - cur_coords) < _COORD_LOOKUP_EPS:
                break

        # Return the coordinates at which x was last evaluated (not the latest projected
        # update), so the returned distance corresponds to the returned coordinates
        return cur_coords, wp.length_sq(pos - x)

    return cell_closest_point
515
+
516
def _make_cell_closest_point(self, assume_linear: bool = False):
    """Build a device function returning the closest point of a cell to a position, with its squared distance."""
    kind = ElementKind.CELL
    return self._make_element_closest_point(element_kind=kind, assume_linear=assume_linear)
518
+
519
def _make_side_closest_point(self, assume_linear: bool = False):
    """Build a device function returning the closest point of a side to a position, with its squared distance."""
    kind = ElementKind.SIDE
    return self._make_element_closest_point(element_kind=kind, assume_linear=assume_linear)
521
+
522
def make_filtered_cell_lookup(self, filter_func: wp.Function = None):
    """Build a BVH-based device function locating the cell closest to a world position.

    The returned ``cell_lookup(args, pos, max_dist, filter_data, filter_target)`` queries
    the cell BVH with an axis-aligned box of half-size ``pad`` around ``pos``. ``pad``
    starts at ``max(max_dist, 1.0) * _BVH_MIN_PADDING`` and grows by a factor of 4, up to
    ``_BVH_MAX_PADDING``, until a candidate is found. Among the candidates, the cell whose
    closest point has the smallest squared distance (at most ``pad * pad``) is kept.

    If no BVH is available, or nothing is found, the returned sample has
    ``NULL_ELEMENT_INDEX`` and ``OUTSIDE`` coordinates.

    Args:
        filter_func: Optional device function ``filter_func(filter_data, cell_index)``;
            cells for which it differs from ``filter_target`` are skipped. When ``None``,
            ``filter_data`` and ``filter_target`` are ignored.
    """
    # The filter key is part of the suffix so each filter gets its own cached function
    suffix = f"{self.name}{filter_func.key if filter_func is not None else ''}"
    pos_type = cache.cached_vec_type(self.dimension, dtype=float)

    @cache.dynamic_func(suffix=suffix)
    def cell_lookup(args: self.CellArg, pos: pos_type, max_dist: float, filter_data: Any, filter_target: Any):
        closest_cell = int(NULL_ELEMENT_INDEX)
        closest_coords = Coords(OUTSIDE)

        bvh_id = self.cell_bvh_id(args)
        if bvh_id != _NULL_BVH_ID:
            pad = wp.max(max_dist, 1.0) * _BVH_MIN_PADDING

            # query with increasing bbox size until we find an element
            # or reach the max distance bound
            while closest_cell == NULL_ELEMENT_INDEX:
                # _bvh_vec lifts 2D positions to 3D for the BVH query
                query = wp.bvh_query_aabb(bvh_id, _bvh_vec(pos) - wp.vec3(pad), _bvh_vec(pos) + wp.vec3(pad))
                cell_index = int(0)
                # Only accept candidates within the current padding radius (squared)
                closest_dist = float(pad * pad)

                while wp.bvh_query_next(query, cell_index):
                    # Filtering is compiled in only when a filter function is provided
                    if wp.static(filter_func is not None):
                        if filter_func(filter_data, cell_index) != filter_target:
                            continue

                    coords, dist = self.cell_closest_point(args, cell_index, pos)
                    if dist <= closest_dist:
                        closest_dist = dist
                        closest_cell = cell_index
                        closest_coords = coords

                if pad >= _BVH_MAX_PADDING:
                    break
                pad = wp.min(4.0 * pad, _BVH_MAX_PADDING)

        return make_free_sample(closest_cell, closest_coords)

    return cell_lookup
560
+
561
@cached_property
def cell_lookup(self) -> wp.Function:
    """Device function locating the cell containing, or closest to, a world position.

    Lookups use the per-device cell BVH (see ``build_bvh``). Without one they return a
    sample with ``NULL_ELEMENT_INDEX``. Every overload takes the cell argument ``args`` and
    the query position ``pos``, and returns a ``Sample``:

    - ``cell_lookup(args, pos)``
    - ``cell_lookup(args, pos, max_dist)``: ``max_dist`` seeds the initial BVH query padding
    - ``cell_lookup(args, pos, guess)``: uses the distance from the position of sample
      ``guess`` to ``pos`` as ``max_dist``
    - ``cell_lookup(args, pos, max_dist, filter_array, filter_target)`` and
      ``cell_lookup(args, pos, filter_array, filter_target)``: only consider cells ``i``
      with ``filter_array[i] == filter_target``
    """
    unfiltered_cell_lookup = self.make_filtered_cell_lookup(filter_func=None)

    # overloads
    # Placeholder filter arguments; ignored by the unfiltered variant
    null_filter_data = 0
    null_filter_target = 0

    # Shared by all overloads below
    pos_type = cache.cached_vec_type(self.dimension, dtype=float)

    @cache.dynamic_func(suffix=self.name, allow_overloads=True)
    def cell_lookup(args: self.CellArg, pos: pos_type, max_dist: float):
        return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)

    @cache.dynamic_func(suffix=self.name, allow_overloads=True)
    def cell_lookup(args: self.CellArg, pos: pos_type, guess: Sample):
        guess_pos = self.cell_position(args, guess)
        max_dist = wp.length(guess_pos - pos)
        return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)

    @cache.dynamic_func(suffix=self.name, allow_overloads=True)
    def cell_lookup(args: self.CellArg, pos: pos_type):
        max_dist = 0.0
        return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)

    # array filtering variants
    filtered_cell_lookup = self.make_filtered_cell_lookup(filter_func=_array_load)

    @cache.dynamic_func(suffix=self.name, allow_overloads=True)
    def cell_lookup(
        args: self.CellArg, pos: pos_type, max_dist: float, filter_array: wp.array(dtype=Any), filter_target: Any
    ):
        return filtered_cell_lookup(args, pos, max_dist, filter_array, filter_target)

    @cache.dynamic_func(suffix=self.name, allow_overloads=True)
    def cell_lookup(args: self.CellArg, pos: pos_type, filter_array: wp.array(dtype=Any), filter_target: Any):
        max_dist = 0.0
        return filtered_cell_lookup(args, pos, max_dist, filter_array, filter_target)

    # All overloads share the same key; returning the last definition returns the overload set
    return cell_lookup
602
+
603
def _make_compute_cell_bounds(self):
    """Build the kernel writing each cell's bounding box into ``lowers``/``uppers``.

    Thread ``i`` handles cell ``i``; 2D bounds are lifted to 3D via ``_bvh_vec`` so they
    can feed ``wp.Bvh``.
    """

    @cache.dynamic_kernel(suffix=self.name)
    def compute_cell_bounds(
        args: self.CellArg,
        lowers: wp.array(dtype=wp.vec3),
        uppers: wp.array(dtype=wp.vec3),
    ):
        cell_index = wp.tid()
        lower, upper = self.cell_bounds(args, cell_index)
        lowers[cell_index] = _bvh_vec(lower)
        uppers[cell_index] = _bvh_vec(upper)

    return compute_cell_bounds
616
+
617
def supports_cell_lookup(self, device) -> bool:
    """Return whether a BVH exists for ``device``; ``cell_lookup`` needs one to find cells."""
    has_bvh = self.bvh_id(device) != _NULL_BVH_ID
    return has_bvh
619
+
620
def update_bvh(self, device=None, force_rebuild: bool = False):
    """
    Refits the BVH, or rebuilds it from scratch if `force_rebuild` is ``True``.

    A full rebuild also happens when no BVH exists yet for ``device``, or when the
    number of cells no longer matches the stored bounds arrays. Refitting in that case
    would write out of bounds or leave cells unindexed.

    Args:
        device: Device whose BVH should be updated (defaults to the current device).
        force_rebuild: If ``True``, discard any existing BVH and build a new one.
    """
    device = wp.get_device(device)

    bvh = None if self._bvhs is None else self._bvhs.get(device.ordinal)
    if force_rebuild or bvh is None or bvh.lowers.shape[0] != self.cell_count():
        return self.build_bvh(device)

    # Recompute the cell bounds in place, then refit; the BVH id is unchanged, so
    # cached argument values remain valid
    wp.launch(
        self.compute_cell_bounds,
        dim=self.cell_count(),
        device=device,
        inputs=[self.cell_arg_value(device=device)],
        outputs=[
            bvh.lowers,
            bvh.uppers,
        ],
    )

    bvh.refit()
645
+
646
def build_bvh(self, device=None):
    """Build a new BVH over the cell bounding boxes on ``device`` and store it.

    The BVH replaces any previous one stored for the same device ordinal. The cached cell
    and side argument values for ``device`` are invalidated afterwards. Cell arguments
    carry the BVH id (read back through ``cell_bvh_id``), which changes on every rebuild.
    """
    device = wp.get_device(device)
    cell_count = self.cell_count()

    lowers = wp.array(shape=cell_count, dtype=wp.vec3, device=device)
    uppers = wp.array(shape=cell_count, dtype=wp.vec3, device=device)
    wp.launch(
        self.compute_cell_bounds,
        dim=cell_count,
        device=device,
        inputs=[self.cell_arg_value(device=device)],
        outputs=[lowers, uppers],
    )

    if self._bvhs is None:
        self._bvhs = {}
    self._bvhs[device.ordinal] = wp.Bvh(lowers, uppers)

    for cached_arg_value in (Geometry.cell_arg_value, Geometry.side_arg_value):
        cached_arg_value.invalidate(self, device)
670
+
671
def bvh_id(self, device):
    """Return the id of the BVH built for ``device``, or ``_NULL_BVH_ID`` if there is none."""
    bvhs = self._bvhs
    if bvhs is not None:
        bvh = bvhs.get(wp.get_device(device).ordinal)
        if bvh is not None:
            return bvh.id
    return _NULL_BVH_ID
679
+
680
+
681
@wp.func
def _bvh_vec(v: wp.vec3):
    """Convert a geometry position to the 3D vector type used by BVH queries (identity for 3D)."""
    return v
684
+
685
+
686
@wp.func
def _bvh_vec(v: wp.vec2):
    """Lift a 2D geometry position to 3D for BVH queries, with the third component set to zero."""
    x = v[0]
    y = v[1]
    return wp.vec3(x, y, 0.0)
689
+
690
+
691
@wp.func
def _array_load(arr: wp.array(dtype=Any), idx: int):
    """Filter function for ``make_filtered_cell_lookup``: return the element of ``arr`` at ``idx``."""
    value = arr[idx]
    return value