warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,461 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import cached_property
17
+ from typing import ClassVar, Optional, Tuple, Type
18
+
19
+ import warp as wp
20
+ from warp._src.fem import cache
21
+ from warp._src.fem.geometry import DeformedGeometry, Geometry
22
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, NULL_NODE_INDEX, ElementIndex
23
+
24
# Public module path for the definitions in this file; this package also ships a thin
# `warp/fem/space/topology.py` re-export module.
# NOTE(review): presumably read by warp when naming/registering the kernels and device
# functions generated here — confirm against warp's module-resolution code.
_wp_module_name_ = "warp.fem.space.topology"
25
+
26
+
27
class SpaceTopology:
    """
    Interface class for defining the topology of a function space.

    The topology only considers the indices of the nodes in each element, and as such,
    the connectivity pattern of the function space.
    It does not specify the actual location of the nodes within the elements, or the valuation function.
    """

    dimension: int
    """Embedding dimension of the function space"""

    MAX_NODES_PER_ELEMENT: int
    """maximum number of interpolation nodes per element of the geometry.

    .. note:: This will change to be defined per-element in future versions
    """

    # Maps attribute names to factories producing default device functions for this topology.
    # The defaults assume every element has exactly MAX_NODES_PER_ELEMENT nodes with sign +1.
    # NOTE(review): installed via `cache.setup_dynamic_attributes` in __init__; confirm there
    # whether subclass overrides of these attributes take precedence over the factories.
    _dynamic_attribute_constructors: ClassVar = {
        "element_node_count": lambda obj: obj._make_constant_element_node_count(),
        "element_node_sign": lambda obj: obj._make_constant_element_node_sign(),
        "side_neighbor_node_counts": lambda obj: obj._make_constant_side_neighbor_node_counts(),
    }

    @wp.struct
    class TopologyArg:
        """Structure containing arguments to be passed to device functions"""

        pass

    def __init__(self, geometry: Geometry, max_nodes_per_element: int):
        """
        Args:
            geometry: Geometry whose cells carry the nodes of this topology.
            max_nodes_per_element: Upper bound on the number of nodes in any single element.
        """
        self._geometry = geometry
        self.dimension = geometry.dimension
        self.MAX_NODES_PER_ELEMENT = wp.constant(max_nodes_per_element)
        # Elements of a full-space topology are the geometry cells
        self.ElementArg = geometry.CellArg

        # `__class__` is the class lexically defining this method (SpaceTopology), not
        # type(self): subclasses such as TraceSpaceTopology call this again with their own table.
        cache.setup_dynamic_attributes(self, cls=__class__)

    @property
    def geometry(self) -> Geometry:
        """Underlying geometry"""
        return self._geometry

    def node_count(self) -> int:
        """Number of nodes in the interpolation basis"""
        raise NotImplementedError

    @cache.cached_arg_value
    def topo_arg_value(self, device) -> "TopologyArg":
        """Value of the topology argument structure to be passed to device functions"""
        # NOTE(review): caching behavior is defined by `cache.cached_arg_value` (presumably per device)
        arg = self.TopologyArg()
        self.fill_topo_arg(arg, device)
        return arg

    def fill_topo_arg(self, arg, device):
        """Populates ``arg`` for ``device``; the base topology has no arguments to fill."""
        pass

    @cached_property
    def name(self):
        """Identifier of the topology, also used as suffix for generated kernels and functions"""
        return f"{self.__class__.__name__}_{self.MAX_NODES_PER_ELEMENT}"

    def __str__(self):
        return self.name

    @staticmethod
    def element_node_count(
        geo_arg: "ElementArg",  # noqa: F821
        topo_arg: "TopologyArg",
        element_index: ElementIndex,
    ) -> int:
        """Returns the actual number of nodes in a given element"""
        raise NotImplementedError

    @staticmethod
    def element_node_index(
        geo_arg: "ElementArg",  # noqa: F821
        topo_arg: "TopologyArg",
        element_index: ElementIndex,
        node_index_in_elt: int,
    ) -> int:
        """Global node index for a given node in a given element"""
        raise NotImplementedError

    @staticmethod
    def side_neighbor_node_counts(
        side_arg: "ElementArg",  # noqa: F821
        side_index: ElementIndex,
    ) -> Tuple[int, int]:
        """Returns the number of nodes for both the inner and outer cells of a given side"""
        raise NotImplementedError

    def element_node_indices(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns a temporary array containing the global index for each node of each element

        Args:
            out: Optional pre-allocated output array of shape ``(cell_count, MAX_NODES_PER_ELEMENT)``
                and dtype ``int32``. If ``None``, a new array is allocated with ``wp.empty``
                (no explicit device).

        Returns:
            A 2D array whose row ``e`` holds the global node indices of cell ``e``.
            Slots past the cell's actual node count are set to ``NULL_NODE_INDEX``.

        Raises:
            ValueError: If ``out`` does not have the expected shape or data type.
        """

        MAX_NODES_PER_ELEMENT = self.MAX_NODES_PER_ELEMENT

        @cache.dynamic_kernel(suffix=self.name)
        def fill_element_node_indices(
            geo_cell_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_node_indices: wp.array2d(dtype=int),
        ):
            element_index = wp.tid()
            element_node_count = self.element_node_count(geo_cell_arg, topo_arg, element_index)
            for n in range(element_node_count):
                element_node_indices[element_index, n] = self.element_node_index(
                    geo_cell_arg, topo_arg, element_index, n
                )
            # Pad unused trailing slots: the output may come from `wp.empty`, so without this
            # elements with fewer than MAX_NODES_PER_ELEMENT nodes would expose uninitialized memory
            for n in range(element_node_count, MAX_NODES_PER_ELEMENT):
                element_node_indices[element_index, n] = NULL_NODE_INDEX

        shape = (self.geometry.cell_count(), MAX_NODES_PER_ELEMENT)
        if out is None:
            element_node_indices = wp.empty(
                shape=shape,
                dtype=int,
            )
        else:
            if out.shape != shape or out.dtype != wp.int32:
                raise ValueError(f"Out element node indices array must have shape {shape} and data type 'int32'")
            element_node_indices = out

        # One thread per cell, launched on the device that owns the output array
        wp.launch(
            dim=element_node_indices.shape[0],
            kernel=fill_element_node_indices,
            inputs=[
                self.geometry.cell_arg_value(device=element_node_indices.device),
                self.topo_arg_value(device=element_node_indices.device),
                element_node_indices,
            ],
            device=element_node_indices.device,
        )

        return element_node_indices

    # Interface generating trace space topology

    def trace(self) -> "TraceSpaceTopology":
        """Trace of the function space over lower-dimensional elements of the geometry"""

        return TraceSpaceTopology(self)

    @property
    def is_trace(self) -> bool:
        """Whether this topology is defined on the trace of the geometry"""
        return self.dimension == self.geometry.dimension - 1

    def full_space_topology(self) -> "SpaceTopology":
        """Returns the full space topology from which this topology is derived"""
        return self

    def __eq__(self, other: "SpaceTopology") -> bool:
        """Checks whether two topologies are compatible"""
        if not isinstance(other, SpaceTopology):
            # Defer to Python's fallback (identity / reflected __eq__) instead of raising
            # AttributeError for objects lacking `geometry` or `name`
            return NotImplemented
        return self.geometry == other.geometry and self.name == other.name

    def is_derived_from(self, other: "SpaceTopology") -> bool:
        """Checks whether two topologies are equal, or `self` is the trace of `other`"""
        if self.dimension == other.dimension:
            return self == other
        if self.dimension + 1 == other.dimension:
            return self.full_space_topology() == other
        return False

    def _make_constant_element_node_count(self):
        """Builds a device function returning MAX_NODES_PER_ELEMENT for every element"""
        NODES_PER_ELEMENT = wp.constant(self.MAX_NODES_PER_ELEMENT)

        @cache.dynamic_func(suffix=self.name)
        def constant_element_node_count(
            geo_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
        ):
            return NODES_PER_ELEMENT

        return constant_element_node_count

    def _make_constant_side_neighbor_node_counts(self):
        """Builds a device function returning MAX_NODES_PER_ELEMENT for both cells adjacent to a side"""
        NODES_PER_ELEMENT = wp.constant(self.MAX_NODES_PER_ELEMENT)

        @cache.dynamic_func(suffix=self.name)
        def constant_side_neighbor_node_counts(
            side_arg: self.geometry.SideArg,
            element_index: ElementIndex,
        ):
            return NODES_PER_ELEMENT, NODES_PER_ELEMENT

        return constant_side_neighbor_node_counts

    def _make_constant_element_node_sign(self):
        """Builds a device function returning a sign of 1.0 for every node of every element"""

        @cache.dynamic_func(suffix=self.name)
        def constant_element_node_sign(
            geo_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_element: int,
        ):
            return 1.0

        return constant_element_node_sign
225
+
226
class TraceSpaceTopology(SpaceTopology):
    """Auto-generated trace topology defining the node indices associated to the geometry sides.

    Elements of this topology are the sides of the underlying geometry. The local nodes of a side
    are the nodes of its inner cell (local indices ``[0, inner_count)``), followed by the nodes of
    its outer cell (local indices ``[inner_count, inner_count + outer_count)``), where both counts
    come from the wrapped topology's ``side_neighbor_node_counts``.
    """

    _dynamic_attribute_constructors: ClassVar = {
        "inner_cell_index": lambda obj: obj._make_inner_cell_index(),
        "outer_cell_index": lambda obj: obj._make_outer_cell_index(),
        "neighbor_cell_index": lambda obj: obj._make_neighbor_cell_index(),
        "element_node_index": lambda obj: obj._make_element_node_index(),
        "element_node_count": lambda obj: obj._make_element_node_count(),
        "element_node_sign": lambda obj: obj._make_element_node_sign(),
    }

    def __init__(self, topo: SpaceTopology):
        """Wrap the full-space topology ``topo`` into its trace on the geometry sides."""
        self._topo = topo

        # A side gathers the nodes of up to two cells, hence twice the per-cell maximum.
        super().__init__(topo.geometry, 2 * topo.MAX_NODES_PER_ELEMENT)

        self.dimension = topo.dimension - 1
        self.ElementArg = topo.geometry.SideArg

        # Topology arguments are shared with the full topology.
        self.TopologyArg = topo.TopologyArg
        self.topo_arg_value = topo.topo_arg_value
        self.fill_topo_arg = topo.fill_topo_arg

        # Not provided by trace topologies.
        self.side_neighbor_node_counts = None
        cache.setup_dynamic_attributes(self, cls=__class__)

    def node_count(self) -> int:
        """Number of nodes, identical to the full topology's node count."""
        return self._topo.node_count()

    @cached_property
    def name(self):
        return f"{self._topo.name}_Trace"

    def _make_inner_cell_index(self):
        """Build a Warp function mapping a side-local node to ``(inner cell, index in cell)``.

        Nodes belonging to the outer cell map to ``(NULL_ELEMENT_INDEX, NULL_NODE_INDEX)``.
        """

        @cache.dynamic_func(suffix=self.name)
        def inner_cell_index(side_arg: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            inner_count, outer_count = self._topo.side_neighbor_node_counts(side_arg, element_index)
            if node_index_in_elt >= inner_count:
                return NULL_ELEMENT_INDEX, NULL_NODE_INDEX
            return self.geometry.side_inner_cell_index(side_arg, element_index), node_index_in_elt

        return inner_cell_index

    def _make_outer_cell_index(self):
        """Build a Warp function mapping a side-local node to ``(outer cell, index in cell)``.

        Nodes belonging to the inner cell map to ``(NULL_ELEMENT_INDEX, NULL_NODE_INDEX)``.
        """

        @cache.dynamic_func(suffix=self.name)
        def outer_cell_index(side_arg: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            inner_count, outer_count = self._topo.side_neighbor_node_counts(side_arg, element_index)
            if node_index_in_elt < inner_count:
                return NULL_ELEMENT_INDEX, NULL_NODE_INDEX
            # Outer-cell nodes are stored after the inner-cell ones.
            return self.geometry.side_outer_cell_index(side_arg, element_index), node_index_in_elt - inner_count

        return outer_cell_index

    def _make_neighbor_cell_index(self):
        """Build a Warp function mapping a side-local node to the adjacent cell owning it.

        Returns ``(cell index, index in cell)``: the inner cell for the first ``inner_count`` nodes,
        the outer cell otherwise.
        """

        @cache.dynamic_func(suffix=self.name)
        def neighbor_cell_index(side_arg: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            inner_count, outer_count = self._topo.side_neighbor_node_counts(side_arg, element_index)
            if node_index_in_elt < inner_count:
                return self.geometry.side_inner_cell_index(side_arg, element_index), node_index_in_elt

            return (
                self.geometry.side_outer_cell_index(side_arg, element_index),
                node_index_in_elt - inner_count,
            )

        return neighbor_cell_index

    def _make_element_node_count(self):
        """Build a Warp function returning the number of nodes of a side (inner + outer counts)."""

        @cache.dynamic_func(suffix=self.name)
        def trace_element_node_count(
            geo_side_arg: self.geometry.SideArg,
            topo_arg: self._topo.TopologyArg,
            element_index: ElementIndex,
        ):
            inner_count, outer_count = self._topo.side_neighbor_node_counts(geo_side_arg, element_index)
            return inner_count + outer_count

        return trace_element_node_count

    def _make_element_node_index(self):
        """Build a Warp function returning the global node index of a side-local node.

        The side-local node is resolved to its owning cell, then looked up in the full topology.
        """

        @cache.dynamic_func(suffix=self.name)
        def trace_element_node_index(
            geo_side_arg: self.geometry.SideArg,
            topo_arg: self._topo.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.neighbor_cell_index(geo_side_arg, element_index, node_index_in_elt)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return self._topo.element_node_index(geo_cell_arg, topo_arg, cell_index, index_in_cell)

        return trace_element_node_index

    def _make_element_node_sign(self):
        """Build a Warp function returning the sign of a side-local node, as seen by its owning cell."""

        @cache.dynamic_func(suffix=self.name)
        def trace_element_node_sign(
            geo_side_arg: self.geometry.SideArg,
            topo_arg: self._topo.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.neighbor_cell_index(geo_side_arg, element_index, node_index_in_elt)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return self._topo.element_node_sign(geo_cell_arg, topo_arg, cell_index, index_in_cell)

        return trace_element_node_sign

    def full_space_topology(self) -> SpaceTopology:
        """Returns the full space topology from which this topology is derived"""
        return self._topo

    def __eq__(self, other: "TraceSpaceTopology") -> bool:
        # Two traces are equal when they derive from equal full topologies. For any other kind of
        # object, defer to the reflected comparison (and ultimately identity) instead of raising
        # AttributeError on the missing ``_topo`` attribute.
        if not isinstance(other, TraceSpaceTopology):
            return NotImplemented
        return self._topo == other._topo
342
+
343
+
344
class RegularDiscontinuousSpaceTopologyMixin:
    """Helper for defining discontinuous topologies (per-element nodes).

    Each element owns ``MAX_NODES_PER_ELEMENT`` nodes of its own. The nodes of element ``e``
    occupy the contiguous global range ``[e * MAX_NODES_PER_ELEMENT, (e + 1) * MAX_NODES_PER_ELEMENT)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The node-index function depends on ``name``/``MAX_NODES_PER_ELEMENT``, so it is built
        # once the base initializer has run.
        self.element_node_index = self._make_element_node_index()

    def node_count(self):
        """Total node count: one full set of element nodes per geometry cell."""
        cell_total = self.geometry.cell_count()
        nodes_per_cell = self.MAX_NODES_PER_ELEMENT
        return cell_total * nodes_per_cell

    @cached_property
    def name(self):
        geometry_name = self.geometry.name
        return f"{geometry_name}_D{self.MAX_NODES_PER_ELEMENT}"

    def _make_element_node_index(self):
        """Build the Warp function mapping ``(element, local node)`` to a global node index."""
        NODES_PER_ELEMENT = self.MAX_NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Element-major numbering: no node is shared between elements.
            return NODES_PER_ELEMENT * element_index + node_index_in_elt

        return element_node_index
371
+
372
+
373
class RegularDiscontinuousSpaceTopology(RegularDiscontinuousSpaceTopologyMixin, SpaceTopology):
    """Generic discontinuous topology: a :class:`SpaceTopology` whose elements each own their nodes."""
377
+
378
+
379
class DeformedGeometrySpaceTopology(SpaceTopology):
    """Topology over a :class:`DeformedGeometry` that forwards to a topology over the base geometry.

    Node counts and topology arguments are taken directly from ``base_topology``. Every node query
    is forwarded to it, passing the ``base_arg`` member of the deformed geometry's cell/side
    arguments.
    """

    _dynamic_attribute_constructors: ClassVar = {
        "element_node_index": lambda obj: obj._make_element_node_index(),
        "element_node_count": lambda obj: obj._make_element_node_count(),
        "element_node_sign": lambda obj: obj._make_element_node_sign(),
        "side_neighbor_node_counts": lambda obj: obj._make_side_neighbor_node_counts(),
    }

    def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
        self.base = base_topology
        super().__init__(geometry, base_topology.MAX_NODES_PER_ELEMENT)

        # Node numbering and topology data are those of the base topology.
        self.node_count = self.base.node_count
        self.topo_arg_value = self.base.topo_arg_value
        self.fill_topo_arg = self.base.fill_topo_arg
        self.TopologyArg = self.base.TopologyArg

        cache.setup_dynamic_attributes(self, cls=__class__)

    @cached_property
    def name(self):
        return f"{self.base.name}_{self.geometry.field.name}"

    def _make_element_node_index(self):
        """Build a Warp function forwarding ``element_node_index`` to the base topology."""

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self.base.element_node_index(elt_arg.base_arg, topo_arg, element_index, node_index_in_elt)

        return element_node_index

    def _make_element_node_count(self):
        """Build a Warp function forwarding ``element_node_count`` to the base topology."""

        @cache.dynamic_func(suffix=self.name)
        def element_node_count(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
        ):
            return self.base.element_node_count(elt_arg.base_arg, topo_arg, element_index)

        return element_node_count

    def _make_side_neighbor_node_counts(self):
        """Build a Warp function forwarding ``side_neighbor_node_counts`` to the base topology."""

        @cache.dynamic_func(suffix=self.name)
        def side_neighbor_node_counts(
            side_arg: self.geometry.SideArg,
            element_index: ElementIndex,
        ):
            inner_count, outer_count = self.base.side_neighbor_node_counts(side_arg.base_arg, element_index)
            return inner_count, outer_count

        return side_neighbor_node_counts

    def _make_element_node_sign(self):
        """Build a Warp function forwarding ``element_node_sign`` to the base topology."""

        @cache.dynamic_func(suffix=self.name)
        def element_node_sign(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self.base.element_node_sign(elt_arg.base_arg, topo_arg, element_index, node_index_in_elt)

        return element_node_sign
447
+
448
+
449
def forward_base_topology(topology_class: Type[SpaceTopology], geometry: Geometry, *args, **kwargs) -> SpaceTopology:
    """
    Construct a ``topology_class`` instance, unwrapping deformed geometries.

    If `geometry` is *not* a :class:`DeformedGeometry`, constructs a regular instance of `topology_class`
    over `geometry`, forwarding any additional arguments.

    If `geometry` *is* a :class:`DeformedGeometry`, constructs an instance of `topology_class` over the
    base (undeformed) geometry of `geometry`, then wraps it in a :class:`DeformedGeometrySpaceTopology`
    that forwards its calls to the underlying topology.
    """

    if isinstance(geometry, DeformedGeometry):
        base_topo = topology_class(geometry.base, *args, **kwargs)
        return DeformedGeometrySpaceTopology(geometry, base_topo)

    return topology_class(geometry, *args, **kwargs)
@@ -0,0 +1,193 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warp as wp
17
+ from warp._src.fem import cache
18
+ from warp._src.fem.geometry import Trimesh
19
+ from warp._src.fem.types import ElementIndex
20
+
21
+ from .shape import TriangleShapeFunction
22
+ from .topology import SpaceTopology, forward_base_topology
23
+
24
# NOTE(review): presumably tells Warp which module name to register this file's kernels and
# functions under (the public ``warp.fem`` path rather than ``warp._src``) — confirm.
_wp_module_name_ = "warp.fem.space.trimesh_function_space"
25
+
26
+
27
# Kernel-side topology data for TrimeshSpaceTopology, populated by its ``fill_topo_arg``.
@wp.struct
class TrimeshTopologyArg:
    # Pair of vertex indices of each mesh edge (side).
    edge_vertex_indices: wp.array(dtype=wp.vec2i)
    # For each triangle, the global indices of its three edges; column k is the edge joining the
    # triangle's local vertices k and (k + 1) mod 3.
    tri_edge_indices: wp.array2d(dtype=int)

    # Number of mesh vertices (``mesh.vertex_count()``).
    vertex_count: int
    # Number of mesh edges (``mesh.side_count()``).
    edge_count: int
34
+
35
+
36
class TrimeshSpaceTopology(SpaceTopology):
    """Space topology over a :class:`Trimesh` for a triangle shape function.

    Global node indices are laid out in three consecutive ranges: vertex nodes first
    (``VERTEX_NODE_COUNT`` per mesh vertex), then edge nodes (``EDGE_NODE_COUNT`` per mesh side),
    then interior nodes (``INTERIOR_NODE_COUNT`` per triangle).
    """

    TopologyArg = TrimeshTopologyArg

    def __init__(self, mesh: Trimesh, shape: TriangleShapeFunction):
        # NOTE(review): ``_shape`` is assigned before the base initializer, presumably because the
        # latter may query ``name`` (which reads ``_shape``) — confirm.
        self._shape = shape
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh

        self._compute_tri_edge_indices()
        self.element_node_index = self._make_element_node_index()
        self.element_node_sign = self._make_element_node_sign()

    @property
    def name(self):
        return f"{self.geometry.name}_{self._shape.name}"

    def fill_topo_arg(self, arg: TrimeshTopologyArg, device):
        """Populate the kernel-side topology argument ``arg`` with data copied to ``device``."""
        arg.tri_edge_indices = self._tri_edge_indices.to(device)
        arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)

        arg.vertex_count = self._mesh.vertex_count()
        arg.edge_count = self._mesh.side_count()

    def _compute_tri_edge_indices(self):
        """Compute, for each triangle, the global indices of its three edges.

        Launches one thread per mesh edge; each edge writes its own index into the
        ``_tri_edge_indices`` rows of its adjacent triangle(s).
        """
        self._tri_edge_indices = wp.empty(
            dtype=int, device=self._mesh.tri_vertex_indices.device, shape=(self._mesh.cell_count(), 3)
        )

        wp.launch(
            kernel=TrimeshSpaceTopology._compute_tri_edge_indices_kernel,
            dim=self._mesh.edge_tri_indices.shape,
            device=self._mesh.tri_vertex_indices.device,
            inputs=[
                self._mesh.edge_tri_indices,
                self._mesh.edge_vertex_indices,
                self._mesh.tri_vertex_indices,
                self._tri_edge_indices,
            ],
        )

    @wp.func
    def _find_edge_index_in_tri(
        edge_vtx: wp.vec2i,
        tri_vtx: wp.vec3i,
    ):
        # Local edge k joins triangle vertices k and k + 1; edges 0 and 1 are tested explicitly
        # and the remaining edge (vertices 2 and 0) is assumed otherwise.
        for k in range(2):
            if (edge_vtx[0] == tri_vtx[k] and edge_vtx[1] == tri_vtx[k + 1]) or (
                edge_vtx[1] == tri_vtx[k] and edge_vtx[0] == tri_vtx[k + 1]
            ):
                return k
        return 2

    @wp.kernel
    def _compute_tri_edge_indices_kernel(
        edge_tri_indices: wp.array(dtype=wp.vec2i),
        edge_vertex_indices: wp.array(dtype=wp.vec2i),
        tri_vertex_indices: wp.array2d(dtype=int),
        tri_edge_indices: wp.array2d(dtype=int),
    ):
        e = wp.tid()

        edge_vtx = edge_vertex_indices[e]
        edge_tris = edge_tri_indices[e]

        t0 = edge_tris[0]
        t0_vtx = wp.vec3i(tri_vertex_indices[t0, 0], tri_vertex_indices[t0, 1], tri_vertex_indices[t0, 2])
        t0_edge = TrimeshSpaceTopology._find_edge_index_in_tri(edge_vtx, t0_vtx)
        tri_edge_indices[t0, t0_edge] = e

        # Both slots hold the same triangle when the edge has a single neighbor
        # (presumably a boundary edge); write only once in that case.
        t1 = edge_tris[1]
        if t1 != t0:
            t1_vtx = wp.vec3i(tri_vertex_indices[t1, 0], tri_vertex_indices[t1, 1], tri_vertex_indices[t1, 2])
            t1_edge = TrimeshSpaceTopology._find_edge_index_in_tri(edge_vtx, t1_vtx)
            tri_edge_indices[t1, t1_edge] = e

    def node_count(self) -> int:
        """Total node count: vertex nodes, then edge nodes, then interior nodes."""
        return (
            self._mesh.vertex_count() * self._shape.VERTEX_NODE_COUNT
            + self._mesh.side_count() * self._shape.EDGE_NODE_COUNT
            + self._mesh.cell_count() * self._shape.INTERIOR_NODE_COUNT
        )

    def _make_element_node_index(self):
        """Build the Warp function mapping ``(triangle, local node)`` to a global node index."""
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        INTERIOR_NODES_PER_SIDE = self._shape.EDGE_NODE_COUNT
        INTERIOR_NODES_PER_CELL = self._shape.INTERIOR_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: self.geometry.CellArg,
            topo_arg: TrimeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Vertex nodes: first global range, indexed by mesh vertex.
            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == TriangleShapeFunction.VERTEX:
                    vertex = type_index // VERTEX_NODE_COUNT
                    vertex_node = type_index - VERTEX_NODE_COUNT * vertex
                    return geo_arg.topology.tri_vertex_indices[element_index][vertex] * VERTEX_NODE_COUNT + vertex_node

            global_offset = topo_arg.vertex_count * VERTEX_NODE_COUNT

            # Edge nodes: second global range, indexed by mesh edge.
            if wp.static(INTERIOR_NODES_PER_SIDE > 0):
                if node_type == TriangleShapeFunction.EDGE:
                    edge = type_index // INTERIOR_NODES_PER_SIDE
                    edge_node = type_index - INTERIOR_NODES_PER_SIDE * edge

                    global_edge_index = topo_arg.tri_edge_indices[element_index][edge]

                    # Edge nodes follow the global edge orientation; reverse the local ordering
                    # when the triangle's local edge does not start at the global edge's first vertex.
                    if (
                        topo_arg.edge_vertex_indices[global_edge_index][0]
                        != geo_arg.topology.tri_vertex_indices[element_index][edge]
                    ):
                        edge_node = INTERIOR_NODES_PER_SIDE - 1 - edge_node

                    return global_offset + INTERIOR_NODES_PER_SIDE * global_edge_index + edge_node

                global_offset += INTERIOR_NODES_PER_SIDE * topo_arg.edge_count

            # Interior nodes: last global range, indexed by triangle.
            return global_offset + INTERIOR_NODES_PER_CELL * element_index + type_index

        return element_node_index

    def _make_element_node_sign(self):
        """Build the Warp function returning the sign (``1.0`` or ``-1.0``) of a local node.

        Edge nodes get ``-1.0`` when the triangle traverses the edge opposite to its global
        orientation; all other nodes get ``1.0``.
        """
        INTERIOR_NODES_PER_SIDE = self._shape.EDGE_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_sign(
            geo_arg: self.geometry.CellArg,
            topo_arg: TrimeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Same compile-time guard as in element_node_index: shapes without edge nodes never
            # compile the division by INTERIOR_NODES_PER_SIDE (which would be zero).
            if wp.static(INTERIOR_NODES_PER_SIDE > 0):
                if node_type == TriangleShapeFunction.EDGE:
                    edge = type_index // INTERIOR_NODES_PER_SIDE

                    global_edge_index = topo_arg.tri_edge_indices[element_index][edge]
                    return wp.where(
                        topo_arg.edge_vertex_indices[global_edge_index][0]
                        == geo_arg.topology.tri_vertex_indices[element_index][edge],
                        1.0,
                        -1.0,
                    )

            return 1.0

        return element_node_sign
187
+
188
+
189
def make_trimesh_space_topology(mesh: Trimesh, shape: TriangleShapeFunction):
    """Create the space topology for ``shape`` over the triangle mesh ``mesh``.

    Deformed geometries are handled by :func:`forward_base_topology`: the topology is built over the
    base mesh and wrapped.

    Raises:
        ValueError: if ``shape`` is not a :class:`TriangleShapeFunction`.
    """
    if isinstance(shape, TriangleShapeFunction):
        return forward_base_topology(TrimeshSpaceTopology, mesh, shape)

    # Fall back to the type name so that objects without a ``name`` attribute still produce the
    # intended ValueError rather than an AttributeError.
    shape_name = getattr(shape, "name", type(shape).__name__)
    raise ValueError(f"Unsupported shape function {shape_name}")