warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468):
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,199 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Union
17
+
18
+ import warp as wp
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import AdaptiveNanogrid, Nanogrid
21
+ from warp._src.fem.types import ElementIndex
22
+
23
+ from .shape import CubeShapeFunction
24
+ from .topology import SpaceTopology, forward_base_topology
25
+
26
+ _wp_module_name_ = "warp.fem.space.nanogrid_function_space"
27
+
28
+
29
@wp.struct
class NanogridTopologyArg:
    """Arguments consumed by the node-indexing Warp functions of ``NanogridSpaceTopology``."""

    # Volume ids handed to ``wp.volume_lookup_index`` to number vertices, faces and edges.
    # The edge/face ids are filled from a ``-1`` placeholder when the shape function
    # carries no edge/face nodes (see ``NanogridSpaceTopology.__init__``).
    vertex_grid: wp.uint64
    face_grid: wp.uint64
    edge_grid: wp.uint64

    # Entity counts used to compute the starting offset of each node category
    # (vertex nodes first, then edge nodes, then face nodes, then interior nodes).
    vertex_count: int
    edge_count: int
    face_count: int
38
+
39
+
40
class NanogridSpaceTopology(SpaceTopology):
    """Space topology numbering the nodes of a cube shape function over a nanogrid.

    Works for both ``Nanogrid`` and ``AdaptiveNanogrid`` geometries. Global node
    indices are laid out in four consecutive ranges: vertex nodes, edge nodes,
    face nodes, then cell-interior nodes.
    """

    TopologyArg = NanogridTopologyArg

    def __init__(
        self,
        grid: Union[Nanogrid, AdaptiveNanogrid],
        shape: CubeShapeFunction,
    ):
        """Record the grid ids and entity counts needed to index the nodes of ``shape``.

        Args:
            grid: The (possibly adaptive) nanogrid geometry.
            shape: Cube shape function defining the per-vertex/edge/face/interior node counts.
        """
        # ``_shape`` must be set before the base constructor runs only if the base
        # class queries it; it is assigned first here in any case.
        self._shape = shape
        super().__init__(grid, shape.NODES_PER_ELEMENT)
        self._grid = grid

        self._vertex_grid = grid.vertex_grid.id

        # Edge and face grids are only queried when the shape function has such nodes
        need_edge_indices = shape.EDGE_NODE_COUNT > 0
        need_face_indices = shape.FACE_NODE_COUNT > 0

        if isinstance(grid, Nanogrid):
            # Uniform grid: plain edge/face grids; -1 / 0 act as "unused" placeholders
            self._edge_grid = grid.edge_grid.id if need_edge_indices else -1
            self._face_grid = grid.face_grid.id if need_face_indices else -1
            self._edge_count = grid.edge_count() if need_edge_indices else 0
            self._face_count = grid.side_count() if need_face_indices else 0
        else:
            # Adaptive grid: use the "stacked" edge/face grids and counts instead
            self._edge_grid = grid.stacked_edge_grid.id if need_edge_indices else -1
            self._face_grid = grid.stacked_face_grid.id if need_face_indices else -1
            self._edge_count = grid.stacked_edge_count() if need_edge_indices else 0
            self._face_count = grid.stacked_face_count() if need_face_indices else 0

        # Warp function mapping (element, local node) to a global node index
        self.element_node_index = self._make_element_node_index()

    @property
    def name(self):
        """Name built from the geometry name and the shape function name."""
        return f"{self.geometry.name}_{self._shape.name}"

    def fill_topo_arg(self, arg, device):
        """Copy the stored grid ids and entity counts into ``arg``.

        ``device`` is not used by this implementation.
        """
        arg.vertex_grid = self._vertex_grid
        arg.face_grid = self._face_grid
        arg.edge_grid = self._edge_grid

        arg.vertex_count = self._grid.vertex_count()
        arg.face_count = self._face_count
        arg.edge_count = self._edge_count

    def _make_element_node_index(self):
        """Build the ``element_node_index`` Warp function matching the grid type.

        Returns the uniform-grid variant (refinement level fixed to 0) for a
        ``Nanogrid``, and otherwise a variant reading the per-cell level from
        ``geo_arg.cell_level``.
        """
        element_node_index_generic = self._make_element_node_index_generic()

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: Nanogrid.CellArg,
            topo_arg: NanogridTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            ijk = geo_arg.cell_ijk[element_index]
            # Uniform grid: every cell is passed with level 0
            return element_node_index_generic(topo_arg, element_index, node_index_in_elt, ijk, 0)

        if isinstance(self._grid, Nanogrid):
            return element_node_index

        @cache.dynamic_func(suffix=self.name)
        def element_node_index_adaptive(
            geo_arg: AdaptiveNanogrid.CellArg,
            topo_arg: NanogridTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            ijk = geo_arg.cell_ijk[element_index]
            level = int(geo_arg.cell_level[element_index])
            return element_node_index_generic(topo_arg, element_index, node_index_in_elt, ijk, level)

        return element_node_index_adaptive

    def node_count(self) -> int:
        """Total number of nodes: vertex + edge + face + cell-interior nodes."""
        return (
            self._grid.vertex_count() * self._shape.VERTEX_NODE_COUNT
            + self._edge_count * self._shape.EDGE_NODE_COUNT
            + self._face_count * self._shape.FACE_NODE_COUNT
            + self._grid.cell_count() * self._shape.INTERIOR_NODE_COUNT
        )

    def _make_element_node_index_generic(self):
        """Build the Warp function computing a global node index from cell coordinates and level."""
        # Captured as plain Python values so they can be used in ``wp.static`` conditions
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        FACE_NODE_COUNT = self._shape.FACE_NODE_COUNT
        INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index_generic(
            topo_arg: NanogridTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
            ijk: wp.vec3i,
            level: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.VERTEX:
                    # ``type_instance`` selects the cell corner
                    n_ijk = _cell_vertex_coord(ijk, level, type_instance)
                    return (
                        wp.volume_lookup_index(topo_arg.vertex_grid, n_ijk[0], n_ijk[1], n_ijk[2]) * VERTEX_NODE_COUNT
                        + type_index
                    )

            # Edge nodes are numbered after all vertex nodes
            offset = topo_arg.vertex_count * VERTEX_NODE_COUNT

            if wp.static(EDGE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.EDGE:
                    # ``type_instance`` packs the edge axis (upper bits) and one of 4 edges along it (low 2 bits)
                    axis = type_instance >> 2
                    node_offset = type_instance & 3

                    n_ijk = _cell_edge_coord(ijk, level, axis, node_offset)

                    edge_index = wp.volume_lookup_index(topo_arg.edge_grid, n_ijk[0], n_ijk[1], n_ijk[2])
                    return offset + EDGE_NODE_COUNT * edge_index + type_index

                # Face nodes are numbered after all edge nodes
                offset += EDGE_NODE_COUNT * topo_arg.edge_count

            if wp.static(FACE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.FACE:
                    # ``type_instance`` packs the face axis (upper bits) and one of 2 faces along it (low bit)
                    axis = type_instance >> 1
                    node_offset = type_instance & 1

                    n_ijk = _cell_face_coord(ijk, level, axis, node_offset)

                    face_index = wp.volume_lookup_index(topo_arg.face_grid, n_ijk[0], n_ijk[1], n_ijk[2])
                    return offset + FACE_NODE_COUNT * face_index + type_index

                # Interior nodes are numbered after all face nodes
                offset += FACE_NODE_COUNT * topo_arg.face_count

            # Remaining case: node interior to the cell, indexed by the element itself
            return offset + INTERIOR_NODE_COUNT * element_index + type_index

        return element_node_index_generic
173
+
174
+
175
@wp.func
def _cell_vertex_coord(cell_ijk: wp.vec3i, cell_level: int, n: int):
    """Grid coordinates of corner ``n`` (0-7) of the cell at ``cell_ijk`` with level ``cell_level``."""
    # Unpack the three low bits of ``n`` into a 0/1 offset per axis (bit 2 -> x, bit 1 -> y, bit 0 -> z)
    corner_bits = wp.vec3i((n & 4) >> 2, (n & 2) >> 1, n & 1)
    # Scale the unit corner offset according to the cell level
    corner_offset = AdaptiveNanogrid.fine_ijk(corner_bits, cell_level)
    return cell_ijk + corner_offset
178
+
179
+
180
@wp.func
def _cell_edge_coord(cell_ijk: wp.vec3i, cell_level: int, axis: int, offset: int):
    """Encoded lookup coordinates of one of the four cell edges running along ``axis``.

    ``offset`` (0-3) selects the edge: its high bit shifts along ``(axis + 1) % 3``
    and its low bit shifts along ``(axis + 2) % 3``.
    """
    coarse = AdaptiveNanogrid.coarse_ijk(cell_ijk, cell_level)

    # The two axes transverse to the edge direction
    first_transverse = (axis + 1) % 3
    second_transverse = (axis + 2) % 3

    coarse[first_transverse] += offset >> 1
    coarse[second_transverse] += offset & 1

    return AdaptiveNanogrid.encode_axis_and_level(coarse, axis, cell_level)
186
+
187
+
188
@wp.func
def _cell_face_coord(cell_ijk: wp.vec3i, cell_level: int, axis: int, offset: int):
    """Encoded lookup coordinates of one of the two cell faces normal to ``axis``.

    ``offset`` (0 or 1) is added to the coarse cell coordinate along ``axis``.
    """
    coarse = AdaptiveNanogrid.coarse_ijk(cell_ijk, cell_level)
    coarse[axis] += offset
    encoded = AdaptiveNanogrid.encode_axis_and_level(coarse, axis, cell_level)
    return encoded
193
+
194
+
195
def make_nanogrid_space_topology(grid: Union[Nanogrid, AdaptiveNanogrid], shape: CubeShapeFunction):
    """Create the space topology for ``shape`` over a (possibly adaptive) nanogrid.

    Args:
        grid: The nanogrid geometry to build the topology on.
        shape: The shape function; only ``CubeShapeFunction`` instances are supported.

    Returns:
        The result of ``forward_base_topology(NanogridSpaceTopology, grid, shape)``.

    Raises:
        ValueError: If ``shape`` is not a ``CubeShapeFunction``.
    """
    if isinstance(shape, CubeShapeFunction):
        return forward_base_topology(NanogridSpaceTopology, grid, shape)

    # Fall back to the type name so that an object without a ``name`` attribute
    # (e.g. ``None``) still produces the intended ValueError, not an AttributeError.
    shape_name = getattr(shape, "name", type(shape).__name__)
    raise ValueError(f"Unsupported shape function {shape_name}")
@@ -0,0 +1,435 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Optional
17
+
18
+ import warp as wp
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import GeometryPartition, WholeGeometryPartition
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, NULL_NODE_INDEX
22
+ from warp._src.fem.utils import compress_node_indices
23
+
24
+ from .function_space import FunctionSpace
25
+ from .topology import SpaceTopology
26
+
27
# Module name declared for the Warp kernels and functions of this file.
# NOTE(review): presumably read by Warp so that they are registered under the
# public "warp.fem" path rather than the private "warp._src" one — confirm.
_wp_module_name_ = "warp.fem.space.partition"

# Set the "enable_backward" module option to False for this module
wp.set_module_options({"enable_backward": False})
30
+
31
+
32
class SpacePartition:
    """Base interface for a set of function space nodes attached to a geometry partition.

    Subclasses provide the node counts, the partition-to-space index mapping and
    the ``PartitionArg`` structure consumed by :meth:`partition_node_index`.
    """

    class PartitionArg:
        """Argument structure handed to :meth:`partition_node_index`; empty in the base class."""

    def __init__(self, space_topology: SpaceTopology, geo_partition: GeometryPartition):
        self.geo_partition = geo_partition
        self.space_topology = space_topology

    def node_count(self):
        """Number of nodes in this partition"""

    def owned_node_count(self) -> int:
        """Number of nodes in this partition, not counting the exterior halo"""

    def interior_node_count(self) -> int:
        """Number of interior nodes in this partition"""

    def space_node_indices(self) -> wp.array:
        """Global function space indices of the nodes of this partition"""

    def rebuild(self, device: Optional = None, temporary_store: Optional[cache.TemporaryStore] = None):
        """Recompute the space partition indices; does nothing in the base class"""

    @cache.cached_arg_value
    def partition_arg_value(self, device):
        """Create a ``PartitionArg`` instance and populate it through :meth:`fill_partition_arg`"""
        partition_arg = self.PartitionArg()
        self.fill_partition_arg(partition_arg, device)
        return partition_arg

    def fill_partition_arg(self, arg, device):
        """Populate `arg` for `device`; does nothing in the base class"""

    @staticmethod
    def partition_node_index(args: "PartitionArg", space_node_index: int):
        """Index in the partition of a function space node, or ``NULL_NODE_INDEX`` if it does not exist"""

    def __str__(self) -> str:
        return self.name

    @property
    def name(self) -> str:
        """Name of the partition; defaults to the name of the concrete class"""
        return self.__class__.__name__
75
+
76
+
77
class WholeSpacePartition(SpacePartition):
    """Space partition made of all the nodes of a space topology.

    Partition node indices coincide with space node indices, so every node
    count is simply the node count of the topology.
    """

    @wp.struct
    class PartitionArg:
        pass

    def __init__(self, space_topology: SpaceTopology):
        whole_geometry = WholeGeometryPartition(space_topology.geometry)
        super().__init__(space_topology, whole_geometry)

        # Created on first call to `space_node_indices()`
        self._node_indices = None

    def node_count(self):
        """Number of nodes in this partition (all the nodes of the topology)"""
        return self.space_topology.node_count()

    def owned_node_count(self) -> int:
        """Number of nodes in this partition, not counting the exterior halo"""
        return self.space_topology.node_count()

    def interior_node_count(self) -> int:
        """Number of interior nodes in this partition"""
        return self.space_topology.node_count()

    def space_node_indices(self):
        """Global function space indices of the nodes of this partition (``0 .. node_count - 1``)"""
        if self._node_indices is not None:
            return self._node_indices

        # First request: allocate the array and fill it with consecutive indices
        self._node_indices = cache.borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
        wp.launch(kernel=self._iota_kernel, dim=self.node_count(), inputs=[self._node_indices])
        return self._node_indices

    def partition_arg_value(self, device):
        """Return a new, empty ``PartitionArg``"""
        return WholeSpacePartition.PartitionArg()

    def fill_partition_arg(self, arg, device):
        """Nothing to fill, the argument structure has no fields"""

    @wp.func
    def partition_node_index(args: Any, space_node_index: int):
        # Partition and space numberings coincide
        return space_node_index

    def __eq__(self, other: SpacePartition) -> bool:
        if not isinstance(other, WholeSpacePartition):
            return False
        return self.space_topology == other.space_topology

    @property
    def name(self) -> str:
        return "Whole"

    @wp.kernel
    def _iota_kernel(indices: wp.array(dtype=int)):
        # Each thread writes its own index
        i = wp.tid()
        indices[i] = i
125
+
126
+
127
class NodeCategory:
    """Integer tags describing how a function space node relates to a partition.

    The tags are written to the per-node category array by the kernels of
    :class:`NodePartition`. Their numeric order matters:
    ``NodePartition`` reads its node counts from ``_category_offsets`` at
    positions derived from these values.
    """

    OWNED_INTERIOR = wp.constant(0)
    """Node is touched exclusively by this partition, not touched by frontier side"""
    OWNED_FRONTIER = wp.constant(1)
    """Node is touched by a frontier side, but belongs to an element of this partition"""
    HALO_LOCAL_SIDE = wp.constant(2)
    """Node belongs to an element of another partition, but is touched by one of our frontier side"""
    HALO_OTHER_SIDE = wp.constant(3)
    """Node belongs to an element of another partition, and is not touched by one of our frontier side"""
    EXTERIOR = wp.constant(4)
    """Node is never referenced by this partition"""

    # Number of categories above (plain Python int, not a Warp constant)
    COUNT = 5
140
+
141
+
142
class NodePartition(SpacePartition):
    """Space partition made of the function space nodes referenced by a geometry partition.

    Every node of the space topology is tagged with a :class:`NodeCategory`.
    ``compress_node_indices`` then provides per-category offsets and the list of
    node indices, from which the partition-to-space and space-to-partition
    mappings are built.
    """

    @wp.struct
    class PartitionArg:
        # Partition index of each space node, ``NULL_NODE_INDEX`` when the node
        # is not part of the partition (see `_scatter_partition_indices`)
        space_to_partition: wp.array(dtype=int)

    def __init__(
        self,
        space_topology: SpaceTopology,
        geo_partition: GeometryPartition,
        with_halo: bool = True,
        max_node_count: int = -1,
        device=None,
        temporary_store: Optional[cache.TemporaryStore] = None,
    ):
        """Build the partition and compute its node indices.

        Args:
            space_topology: topology of the function space to partition
            geo_partition: geometry partition selecting the cells and sides to consider
            with_halo: if True, also categorize nodes from the sides and frontier sides of the partition
            max_node_count: if non-negative, upper bound on the number of partition nodes
              (clamped to the node count of the topology); the counts are then not read back from the device
            device: Warp device on which computations are launched
            temporary_store: store from which temporary arrays are borrowed
        """
        super().__init__(space_topology=space_topology, geo_partition=geo_partition)

        # A limit larger than the total number of nodes is meaningless
        if max_node_count >= 0:
            max_node_count = min(max_node_count, space_topology.node_count())

        self._max_node_count = max_node_count
        self._with_halo = with_halo

        self._category_offsets: wp.array = None
        """Offsets for each node category"""
        self._node_indices: wp.array = None
        """Mapping from local partition node indices to global space node indices"""
        self._space_to_partition: wp.array = None
        """Mapping from global space node indices to local partition node indices"""

        self.rebuild(device, temporary_store)

    def rebuild(self, device: Optional = None, temporary_store: Optional[cache.TemporaryStore] = None):
        """Recompute the node categories and index mappings with the settings given at construction"""
        self._compute_node_indices_from_sides(device, self._with_halo, self._max_node_count, temporary_store)

    def node_count(self) -> int:
        """Returns number of nodes referenced by this partition, including exterior halo"""
        # `_category_offsets` lives on the host; the entry following the last
        # halo category is where exterior nodes start.
        return int(self._category_offsets.numpy()[NodeCategory.HALO_OTHER_SIDE + 1])

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""
        return int(self._category_offsets.numpy()[NodeCategory.OWNED_FRONTIER + 1])

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""
        return int(self._category_offsets.numpy()[NodeCategory.OWNED_INTERIOR + 1])

    def space_node_indices(self):
        """Return the global function space indices for nodes in this partition"""
        return self._node_indices

    def fill_partition_arg(self, arg, device):
        """Set the space-to-partition mapping of `arg`, moved to `device`"""
        arg.space_to_partition = self._space_to_partition.to(device)

    @wp.func
    def partition_node_index(args: PartitionArg, space_node_index: int):
        # Direct lookup in the space-to-partition mapping
        return args.space_to_partition[space_node_index]

    def _compute_node_indices_from_sides(
        self, device, with_halo: bool, max_node_count: int, temporary_store: cache.TemporaryStore
    ):
        """Tag every space node with a :class:`NodeCategory`, then finalize the index mappings.

        All nodes start as ``EXTERIOR``. Nodes of the partition cells become
        ``OWNED_INTERIOR``. When `with_halo` is True, two more kernels run over
        the partition sides and the frontier sides to tag halo and frontier nodes.

        Args:
            device: Warp device on which the kernels are launched
            with_halo: whether to run the two side-based kernels
            max_node_count: forwarded to :meth:`_finalize_node_indices`
            temporary_store: store from which temporary arrays are borrowed
        """
        trace_topology = self.space_topology.trace()

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_cells_kernel(
            geo_arg: self.geo_partition.geometry.CellArg,
            geo_partition_arg: self.geo_partition.CellArg,
            space_arg: self.space_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            # One thread per partition cell: tag all nodes of the cell as owned
            partition_cell_index = wp.tid()

            cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)
            if cell_index == NULL_ELEMENT_INDEX:
                return

            cell_node_count = self.space_topology.element_node_count(geo_arg, space_arg, cell_index)
            for n in range(cell_node_count):
                space_nidx = self.space_topology.element_node_index(geo_arg, space_arg, cell_index, n)
                node_mask[space_nidx] = NodeCategory.OWNED_INTERIOR

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_owned_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            # One thread per partition side: nodes of the side that no partition
            # cell has claimed become halo nodes on the local side
            partition_side_index = wp.tid()

            side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)
            if side_index == NULL_ELEMENT_INDEX:
                return

            side_node_count = trace_topology.element_node_count(geo_arg, space_arg, side_index)
            for n in range(side_node_count):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)

                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_LOCAL_SIDE

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_frontier_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            # One thread per frontier side: still-untagged nodes become halo nodes
            # on the other side, owned interior nodes are promoted to frontier
            frontier_side_index = wp.tid()

            side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)
            if side_index == NULL_ELEMENT_INDEX:
                return

            side_node_count = trace_topology.element_node_count(geo_arg, space_arg, side_index)
            for n in range(side_node_count):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_OTHER_SIDE
                elif node_mask[space_nidx] == NodeCategory.OWNED_INTERIOR:
                    node_mask[space_nidx] = NodeCategory.OWNED_FRONTIER

        # One category per node of the full space topology, exterior by default
        node_category = cache.borrow_temporary(
            temporary_store,
            shape=(self.space_topology.node_count(),),
            dtype=int,
            device=device,
        )
        node_category.fill_(value=NodeCategory.EXTERIOR)

        wp.launch(
            dim=self.geo_partition.cell_count(),
            kernel=node_category_from_cells_kernel,
            inputs=[
                self.geo_partition.geometry.cell_arg_value(device),
                self.geo_partition.cell_arg_value(device),
                self.space_topology.topo_arg_value(device),
                node_category,
            ],
            device=device,
        )

        if with_halo:
            # NOTE(review): both side kernels declare `space_arg` with the trace
            # topology's `TopologyArg` but receive the full topology's argument
            # value — assumes the two share the same structure; confirm.
            wp.launch(
                dim=self.geo_partition.side_count(),
                kernel=node_category_from_owned_sides_kernel,
                inputs=[
                    self.geo_partition.geometry.side_arg_value(device),
                    self.geo_partition.side_arg_value(device),
                    self.space_topology.topo_arg_value(device),
                    node_category,
                ],
                device=device,
            )

            wp.launch(
                dim=self.geo_partition.frontier_side_count(),
                kernel=node_category_from_frontier_sides_kernel,
                inputs=[
                    self.geo_partition.geometry.side_arg_value(device),
                    self.geo_partition.side_arg_value(device),
                    self.space_topology.topo_arg_value(device),
                    node_category,
                ],
                device=device,
            )

        with wp.ScopedDevice(device):
            self._finalize_node_indices(node_category, max_node_count, temporary_store)

        node_category.release()

    def _finalize_node_indices(
        self, node_category: wp.array(dtype=int), max_node_count: int, temporary_store: cache.TemporaryStore
    ):
        """Turn per-node categories into category offsets and index mappings.

        Args:
            node_category: :class:`NodeCategory` value of each space node
            max_node_count: if non-negative, every host-side offset is set to this value instead of
              being copied back from the device
            temporary_store: store from which temporary arrays are borrowed
        """
        category_offsets, node_indices = compress_node_indices(
            NodeCategory.COUNT, node_category, temporary_store=temporary_store
        )
        device = node_category.device

        if max_node_count >= 0:
            # Fixed-size mode: no device-to-host copy, all offsets (hence all
            # node counts) are reported as `max_node_count`
            if self._category_offsets is None:
                self._category_offsets = cache.borrow_temporary(
                    temporary_store,
                    shape=(NodeCategory.COUNT + 1,),
                    dtype=category_offsets.dtype,
                    device="cpu",
                )
            self._category_offsets.fill_(max_node_count)
            copy_event = None
        else:
            # Copy offsets to cpu
            if self._category_offsets is None:
                self._category_offsets = cache.borrow_temporary(
                    temporary_store,
                    shape=(NodeCategory.COUNT + 1,),
                    dtype=category_offsets.dtype,
                    pinned=device.is_cuda,
                    device="cpu",
                )
            wp.copy(src=category_offsets, dest=self._category_offsets, count=NodeCategory.COUNT + 1)
            copy_event = cache.capture_event()

        # Compute global to local indices
        if self._space_to_partition is None or self._space_to_partition.shape != node_indices.shape:
            self._space_to_partition = cache.borrow_temporary_like(node_indices, temporary_store)

        # Launched before waiting on the copy event, as the kernel reads the
        # device-side offsets and does not need the host copy
        wp.launch(
            kernel=NodePartition._scatter_partition_indices,
            dim=self.space_topology.node_count(),
            device=device,
            inputs=[max_node_count, category_offsets, node_indices, self._space_to_partition],
        )

        if copy_event is not None:
            cache.synchronize_event(copy_event)  # Transfer to host must be finished to access node_count()

        # Copy to shrunk-to-fit array
        if self._node_indices is None or self._node_indices.shape[0] != self.node_count():
            self._node_indices = cache.borrow_temporary(
                temporary_store, shape=(self.node_count(),), dtype=int, device=device
            )

        wp.copy(dest=self._node_indices, src=node_indices, count=self.node_count())
        node_indices.release()

    @wp.kernel
    def _scatter_partition_indices(
        max_node_count: int,
        category_offsets: wp.array(dtype=int),
        node_indices: wp.array(dtype=int),
        space_to_partition_indices: wp.array(dtype=int),
    ):
        # One thread per entry of `node_indices`: invert the partition-to-space
        # mapping, writing NULL_NODE_INDEX for entries past the partition size
        local_idx = wp.tid()
        space_idx = node_indices[local_idx]

        local_node_count = category_offsets[NodeCategory.EXTERIOR]  # all but exterior nodes
        if max_node_count >= 0:
            if local_node_count > max_node_count:
                # Report the overflow once (first thread only), then truncate
                if local_idx == 0:
                    wp.printf(
                        "Number of space partition nodes exceeded the %d limit; increase `max_node_count` to %d.\n",
                        max_node_count,
                        local_node_count,
                    )

                local_node_count = max_node_count

        if local_idx < local_node_count:
            space_to_partition_indices[space_idx] = local_idx
        else:
            space_to_partition_indices[space_idx] = NULL_NODE_INDEX
393
+
394
+
395
def make_space_partition(
    space: Optional[FunctionSpace] = None,
    geometry_partition: Optional[GeometryPartition] = None,
    space_topology: Optional[SpaceTopology] = None,
    with_halo: bool = True,
    max_node_count: int = -1,
    device=None,
    temporary_store: cache.TemporaryStore = None,
) -> SpacePartition:
    """Computes the subset of nodes from a function space topology that touch a geometry partition

    Either `space_topology` or `space` must be provided (and will be considered in that order).

    Args:
        space: (deprecated) the function space defining the topology if `space_topology` is ``None``.
        geometry_partition: The subset of the space geometry. If not provided, use the whole geometry.
        space_topology: the topology of the function space to consider. If ``None``, deduced from `space`.
        with_halo: if True, include the halo nodes (nodes from exterior frontier cells to the partition)
        max_node_count: if non-negative, will be used to limit the number of nodes to avoid device/host synchronization.
        device: Warp device on which to perform and store computations
        temporary_store: shared pool from which to borrow temporary arrays

    Returns:
        the resulting space partition; a :class:`WholeSpacePartition` when `geometry_partition` is ``None``
        or a :class:`WholeGeometryPartition`, a :class:`NodePartition` otherwise

    Raises:
        ValueError: if neither `space_topology` nor `space` is provided
    """

    if space_topology is None:
        # Fail with an explicit message rather than an attribute error on `None`
        if space is None:
            raise ValueError("Either `space_topology` or `space` must be provided")
        space_topology = space.topology

    # Partitions are always defined over the full (non-restricted) topology
    space_topology = space_topology.full_space_topology()

    if geometry_partition is not None and not isinstance(geometry_partition, WholeGeometryPartition):
        return NodePartition(
            space_topology=space_topology,
            geo_partition=geometry_partition,
            with_halo=with_halo,
            max_node_count=max_node_count,
            device=device,
            temporary_store=temporary_store,
        )

    # No proper sub-partition requested: all nodes belong to the partition.
    # `with_halo`, `max_node_count`, `device` and `temporary_store` are not used in this case.
    return WholeSpacePartition(space_topology)