warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,201 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Union
17
+
18
+ import warp as wp
19
+ from warp.fem import cache
20
+ from warp.fem.geometry import AdaptiveNanogrid, Nanogrid
21
+ from warp.fem.types import ElementIndex
22
+
23
+ from .shape import CubeShapeFunction
24
+ from .topology import SpaceTopology, forward_base_topology
25
+
26
+
27
@wp.struct
class NanogridTopologyArg:
    # Device-side arguments for the NanogridSpaceTopology node-index functions;
    # an instance is filled by NanogridSpaceTopology.topo_arg_value().
    # NOTE(review): field order presumably defines the struct layout — do not reorder.

    # Ids of the grids queried with wp.volume_lookup_index to number vertices,
    # faces and edges. face_grid / edge_grid hold -1 when the shape function
    # has no face / edge nodes.
    vertex_grid: wp.uint64
    face_grid: wp.uint64
    edge_grid: wp.uint64

    # Entity counts used to offset the edge, face and interior node ranges.
    # edge_count / face_count are 0 when the shape has no edge / face nodes.
    vertex_count: int
    edge_count: int
    face_count: int
36
+
37
+
38
class NanogridSpaceTopology(SpaceTopology):
    """Space topology for cube shape functions on ``Nanogrid`` and ``AdaptiveNanogrid`` geometries.

    Global node indices are laid out in consecutive ranges: all vertex nodes,
    then all edge nodes, then all face nodes, then all interior (per-cell) nodes.
    The index of a vertex, edge or face entity is obtained with
    ``wp.volume_lookup_index`` on the grid whose id is stored in
    :class:`NanogridTopologyArg`. The device function mapping a cell's local node
    index to a global node index is exposed as ``self.element_node_index``.
    """

    # Argument struct type consumed by the generated device functions.
    TopologyArg = NanogridTopologyArg

    def __init__(
        self,
        grid: Union[Nanogrid, AdaptiveNanogrid],
        shape: CubeShapeFunction,
    ):
        """Build the topology for ``shape`` on ``grid``.

        Args:
            grid: The ``Nanogrid`` or ``AdaptiveNanogrid`` geometry.
            shape: Cube shape function giving the number of nodes carried by
                each vertex, edge, face and cell interior.
        """
        self._shape = shape
        super().__init__(grid, shape.NODES_PER_ELEMENT)
        self._grid = grid

        self._vertex_grid = grid.vertex_grid.id

        # Edge/face lookup grids and counts are only needed when the shape
        # places nodes on edges/faces; otherwise a -1 id and a 0 count are stored.
        need_edge_indices = shape.EDGE_NODE_COUNT > 0
        need_face_indices = shape.FACE_NODE_COUNT > 0

        if isinstance(grid, Nanogrid):
            self._edge_grid = grid.edge_grid.id if need_edge_indices else -1
            self._face_grid = grid.face_grid.id if need_face_indices else -1
            self._edge_count = grid.edge_count() if need_edge_indices else 0
            # Faces of a Nanogrid are counted with side_count().
            self._face_count = grid.side_count() if need_face_indices else 0
        else:
            # Adaptive grids use the "stacked" edge/face grids and counts
            # (NOTE(review): presumably covering entities of every level — confirm).
            self._edge_grid = grid.stacked_edge_grid.id if need_edge_indices else -1
            self._face_grid = grid.stacked_face_grid.id if need_face_indices else -1
            self._edge_count = grid.stacked_edge_count() if need_edge_indices else 0
            self._face_count = grid.stacked_face_count() if need_face_indices else 0

        self.element_node_index = self._make_element_node_index()

    @property
    def name(self):
        """Identifier combining the geometry name and the shape function name.

        Also used as the suffix of the generated device functions.
        """
        return f"{self.geometry.name}_{self._shape.name}"

    @cache.cached_arg_value
    def topo_arg_value(self, device):
        """Return a :class:`NanogridTopologyArg` filled from this topology.

        ``device`` is not read in the body; caching is delegated to
        ``cache.cached_arg_value`` (NOTE(review): presumably keyed per device — confirm).
        """
        arg = NanogridTopologyArg()

        arg.vertex_grid = self._vertex_grid
        arg.face_grid = self._face_grid
        arg.edge_grid = self._edge_grid

        arg.vertex_count = self._grid.vertex_count()
        arg.face_count = self._face_count
        arg.edge_count = self._edge_count
        return arg

    def _make_element_node_index(self):
        """Create the device function mapping (cell, local node index) to a global node index.

        The returned function takes ``(geo_arg, topo_arg, element_index, node_index_in_elt)``.
        For a ``Nanogrid`` every cell is passed to the generic indexer with level 0;
        otherwise the level is read per cell from ``geo_arg.cell_level``.
        """
        element_node_index_generic = self._make_element_node_index_generic()

        # Nanogrid variant: cell coordinates from geo_arg, level fixed to 0.
        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: Nanogrid.CellArg,
            topo_arg: NanogridTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            ijk = geo_arg.cell_ijk[element_index]
            return element_node_index_generic(topo_arg, element_index, node_index_in_elt, ijk, 0)

        if isinstance(self._grid, Nanogrid):
            return element_node_index

        # AdaptiveNanogrid variant: each cell also carries its own level.
        @cache.dynamic_func(suffix=self.name)
        def element_node_index_adaptive(
            geo_arg: AdaptiveNanogrid.CellArg,
            topo_arg: NanogridTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            ijk = geo_arg.cell_ijk[element_index]
            level = int(geo_arg.cell_level[element_index])
            return element_node_index_generic(topo_arg, element_index, node_index_in_elt, ijk, level)

        return element_node_index_adaptive

    def node_count(self) -> int:
        """Return the total number of nodes.

        The total is the sum of vertex, edge, face and cell counts, each
        multiplied by the number of nodes the shape places on that entity type.
        """
        return (
            self._grid.vertex_count() * self._shape.VERTEX_NODE_COUNT
            + self._edge_count * self._shape.EDGE_NODE_COUNT
            + self._face_count * self._shape.FACE_NODE_COUNT
            + self._grid.cell_count() * self._shape.INTERIOR_NODE_COUNT
        )

    def _make_element_node_index_generic(self):
        """Create the device function computing a global node index from cell coordinates and level.

        The cell's ``ijk`` and ``level`` are explicit arguments, so the generated
        function is shared by the Nanogrid and AdaptiveNanogrid wrappers built in
        ``_make_element_node_index``.
        """
        # Captured as Python constants so they can be tested with wp.static below.
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
        FACE_NODE_COUNT = self._shape.FACE_NODE_COUNT
        INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index_generic(
            topo_arg: NanogridTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
            ijk: wp.vec3i,
            level: int,
        ):
            # Split the local node index into: entity type (vertex/edge/face/interior),
            # which entity of that type within the cell, and node index within it.
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Vertex nodes: range [0, vertex_count * VERTEX_NODE_COUNT).
            # wp.static resolves the condition at code generation, so branches for
            # node types the shape does not have are not emitted.
            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.VERTEX:
                    n_ijk = _cell_vertex_coord(ijk, level, type_instance)
                    return (
                        wp.volume_lookup_index(topo_arg.vertex_grid, n_ijk[0], n_ijk[1], n_ijk[2]) * VERTEX_NODE_COUNT
                        + type_index
                    )

            offset = topo_arg.vertex_count * VERTEX_NODE_COUNT

            # Edge nodes follow the vertex range.
            if wp.static(EDGE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.EDGE:
                    # type_instance: edge axis in bits >= 2, one of 4 edges along
                    # that axis in the low 2 bits.
                    axis = type_instance >> 2
                    node_offset = type_instance & 3

                    n_ijk = _cell_edge_coord(ijk, level, axis, node_offset)

                    edge_index = wp.volume_lookup_index(topo_arg.edge_grid, n_ijk[0], n_ijk[1], n_ijk[2])
                    return offset + EDGE_NODE_COUNT * edge_index + type_index

            offset += EDGE_NODE_COUNT * topo_arg.edge_count

            # Face nodes follow the edge range.
            if wp.static(FACE_NODE_COUNT > 0):
                if node_type == CubeShapeFunction.FACE:
                    # type_instance: face normal axis in bits >= 1, low bit selects
                    # the lower (0) or upper (1) face along that axis.
                    axis = type_instance >> 1
                    node_offset = type_instance & 1

                    n_ijk = _cell_face_coord(ijk, level, axis, node_offset)

                    face_index = wp.volume_lookup_index(topo_arg.face_grid, n_ijk[0], n_ijk[1], n_ijk[2])
                    return offset + FACE_NODE_COUNT * face_index + type_index

            offset += FACE_NODE_COUNT * topo_arg.face_count

            # Remaining case: interior nodes, INTERIOR_NODE_COUNT per cell, after the face range.
            return offset + INTERIOR_NODE_COUNT * element_index + type_index

        return element_node_index_generic
175
+
176
+
177
# Coordinates of corner ``n`` of a cell at ``cell_level``. Only the low three bits
# of ``n`` are used: bit 2 selects the +x corner, bit 1 the +y corner, bit 0 the
# +z corner. The unit corner offset is converted with AdaptiveNanogrid.fine_ijk
# (presumably to finest-level coordinates) before being added to ``cell_ijk``.
@wp.func
def _cell_vertex_coord(cell_ijk: wp.vec3i, cell_level: int, n: int):
    corner = wp.vec3i((n & 4) >> 2, (n & 2) >> 1, n & 1)
    return cell_ijk + AdaptiveNanogrid.fine_ijk(corner, cell_level)
180
+
181
+
182
# Encoded coordinates of one of the four edges along ``axis`` of a cell at
# ``cell_level``. Bit 1 of ``offset`` shifts along the first transverse axis,
# bit 0 along the second; axis and level are then packed by
# AdaptiveNanogrid.encode_axis_and_level.
@wp.func
def _cell_edge_coord(cell_ijk: wp.vec3i, cell_level: int, axis: int, offset: int):
    edge_ijk = AdaptiveNanogrid.coarse_ijk(cell_ijk, cell_level)

    first_transverse = (axis + 1) % 3
    second_transverse = (axis + 2) % 3
    edge_ijk[first_transverse] += offset >> 1
    edge_ijk[second_transverse] += offset & 1

    return AdaptiveNanogrid.encode_axis_and_level(edge_ijk, axis, cell_level)
188
+
189
+
190
# Encoded coordinates of the face with normal ``axis`` of a cell at ``cell_level``:
# ``offset`` (0 or 1) selects the lower or upper face along that axis, then axis
# and level are packed by AdaptiveNanogrid.encode_axis_and_level.
@wp.func
def _cell_face_coord(cell_ijk: wp.vec3i, cell_level: int, axis: int, offset: int):
    face_ijk = AdaptiveNanogrid.coarse_ijk(cell_ijk, cell_level)
    face_ijk[axis] += offset

    encoded = AdaptiveNanogrid.encode_axis_and_level(face_ijk, axis, cell_level)
    return encoded
195
+
196
+
197
def make_nanogrid_space_topology(grid: Union[Nanogrid, AdaptiveNanogrid], shape: CubeShapeFunction):
    """Create the space topology for a cube shape function on a Nanogrid or AdaptiveNanogrid.

    Args:
        grid: Geometry on which the topology is built.
        shape: Shape function; must be a :class:`CubeShapeFunction`.

    Returns:
        The result of ``forward_base_topology(NanogridSpaceTopology, grid, shape)``.

    Raises:
        ValueError: If ``shape`` is not a ``CubeShapeFunction``.
    """
    if not isinstance(shape, CubeShapeFunction):
        # Fall back to the type name so that arguments without a ``name``
        # attribute (e.g. None) still raise the intended ValueError rather
        # than an AttributeError from inside the error path.
        shape_name = getattr(shape, "name", type(shape).__name__)
        raise ValueError(f"Unsupported shape function {shape_name}")

    return forward_base_topology(NanogridSpaceTopology, grid, shape)
@@ -0,0 +1,367 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Optional
17
+
18
+ import warp as wp
19
+ import warp.fem.cache as cache
20
+ from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
21
+ from warp.fem.types import NULL_NODE_INDEX
22
+ from warp.fem.utils import compress_node_indices
23
+
24
+ from .function_space import FunctionSpace
25
+ from .topology import SpaceTopology
26
+
27
+ wp.set_module_options({"enable_backward": False})
28
+
29
+
30
class SpacePartition:
    """Subset of the nodes of a function space topology associated with a geometry partition.

    This base class only describes the interface. Concrete partitions override the node-count and
    index-mapping methods below; the base versions document the contract and return ``None``.
    """

    class PartitionArg:
        # Placeholder for the kernel-side argument type; concrete partitions redefine it as a ``wp.struct``
        pass

    def __init__(self, space_topology: SpaceTopology, geo_partition: GeometryPartition):
        self.geo_partition = geo_partition
        self.space_topology = space_topology

    def node_count(self):
        """Returns number of nodes in this partition"""
        return None

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""
        return None

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""
        return None

    def space_node_indices(self) -> wp.array:
        """Return the global function space indices for nodes in this partition"""
        return None

    def partition_arg_value(self, device):
        # Value of ``PartitionArg`` to pass to kernels on ``device``; overridden by concrete partitions
        return None

    @staticmethod
    def partition_node_index(args: "PartitionArg", space_node_index: int):
        """Returns the index in the partition of a function space node, or ``NULL_NODE_INDEX`` if it does not exist"""
        return None

    def __str__(self) -> str:
        return self.name

    @property
    def name(self) -> str:
        # Defaults to the concrete class name
        return type(self).__name__
63
+
64
+
65
class WholeSpacePartition(SpacePartition):
    """Space partition containing every node of the function space topology.

    Partition-local node indices coincide with global function space node indices.
    """

    @wp.struct
    class PartitionArg:
        # The identity mapping needs no kernel-side data
        pass

    def __init__(self, space_topology: SpaceTopology):
        super().__init__(space_topology, WholeGeometryPartition(space_topology.geometry))
        # Lazily allocated on the first call to space_node_indices()
        self._node_indices = None

    def node_count(self):
        """Returns number of nodes in this partition"""
        return self.space_topology.node_count()

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""
        return self.space_topology.node_count()

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""
        return self.space_topology.node_count()

    def space_node_indices(self):
        """Return the global function space indices for nodes in this partition

        The identity array ``[0, 1, ..., node_count() - 1]`` is filled on first call and cached.
        It is allocated without an explicit device (presumably the current default device).
        """
        if self._node_indices is None:
            self._node_indices = cache.borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
            wp.launch(kernel=self._iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array])
        return self._node_indices.array

    def partition_arg_value(self, device):
        return WholeSpacePartition.PartitionArg()

    @wp.func
    def partition_node_index(args: Any, space_node_index: int):
        # Identity mapping: the partition index is the space node index
        return space_node_index

    def __eq__(self, other: SpacePartition) -> bool:
        # Only another whole partition over the same topology is equivalent. A NodePartition covers a
        # strict subset of the nodes with a remapped local indexing, so it must not compare equal
        # even when it shares the same space topology.
        return isinstance(other, WholeSpacePartition) and self.space_topology == other.space_topology

    @property
    def name(self) -> str:
        return "Whole"

    @wp.kernel
    def _iota_kernel(indices: wp.array(dtype=int)):
        indices[wp.tid()] = wp.tid()
110
+
111
+
112
class NodeCategory:
    """Classification of function space nodes relative to a geometry partition.

    :class:`NodePartition` groups nodes by category in increasing value order. Owned categories
    therefore come first, followed by halo categories. ``EXTERIOR`` nodes are left out of the partition.
    """

    OWNED_INTERIOR = wp.constant(0)
    """Node is touched exclusively by this partition, not touched by any frontier side"""
    OWNED_FRONTIER = wp.constant(1)
    """Node is touched by a frontier side, but belongs to an element of this partition"""
    HALO_LOCAL_SIDE = wp.constant(2)
    """Node belongs to an element of another partition, but is touched by one of our frontier sides"""
    HALO_OTHER_SIDE = wp.constant(3)
    """Node belongs to an element of another partition, and is not touched by any of our frontier sides"""
    EXTERIOR = wp.constant(4)
    """Node is never referenced by this partition"""

    # Number of categories above; passed as the category count when compressing node indices
    COUNT = 5
125
+
126
+
127
class NodePartition(SpacePartition):
    """Space partition restricted to the nodes touched by a geometry partition.

    Each global function space node is assigned a :class:`NodeCategory`. Nodes are then grouped by
    category, so partition-local indices enumerate owned interior nodes first, then owned frontier
    nodes, then halo nodes. ``EXTERIOR`` nodes are not part of the partition and map to
    ``NULL_NODE_INDEX``.

    Args:
        space_topology: topology of the function space to partition
        geo_partition: geometry partition whose cells (and sides, for the halo) select the nodes
        with_halo: if True, also include halo nodes reached through the partition's sides
        device: Warp device on which to perform and store computations
        temporary_store: optional store from which temporary arrays are borrowed
    """

    @wp.struct
    class PartitionArg:
        # Maps each global function space node index to its partition index (or NULL_NODE_INDEX)
        space_to_partition: wp.array(dtype=int)

    def __init__(
        self,
        space_topology: SpaceTopology,
        geo_partition: GeometryPartition,
        with_halo: bool = True,
        device=None,
        temporary_store: cache.TemporaryStore = None,
    ):
        super().__init__(space_topology=space_topology, geo_partition=geo_partition)

        self._compute_node_indices_from_sides(device, with_halo, temporary_store)

    def node_count(self) -> int:
        """Returns number of nodes referenced by this partition, including exterior halo"""
        return int(self._category_offsets.array.numpy()[NodeCategory.HALO_OTHER_SIDE + 1])

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""
        return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_FRONTIER + 1])

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""
        return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_INTERIOR + 1])

    def space_node_indices(self):
        """Return the global function space indices for nodes in this partition"""
        return self._node_indices.array

    @cache.cached_arg_value
    def partition_arg_value(self, device):
        arg = NodePartition.PartitionArg()
        arg.space_to_partition = self._space_to_partition.array.to(device)
        return arg

    @wp.func
    def partition_node_index(args: PartitionArg, space_node_index: int):
        return args.space_to_partition[space_node_index]

    def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: cache.TemporaryStore):
        """Assign a :class:`NodeCategory` to every space node, then build the partition index maps.

        Kernels run in order:
        1. Nodes of partition cells become ``OWNED_INTERIOR``.
        2. With ``with_halo`` only: still-exterior nodes of partition sides become ``HALO_LOCAL_SIDE``.
        3. With ``with_halo`` only: for frontier-side nodes, exterior ones become ``HALO_OTHER_SIDE``
           and owned interior ones become ``OWNED_FRONTIER``.
        """
        trace_topology = self.space_topology.trace()

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_cells_kernel(
            geo_arg: self.geo_partition.geometry.CellArg,
            geo_partition_arg: self.geo_partition.CellArg,
            space_arg: self.space_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            partition_cell_index = wp.tid()

            cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)

            # Every node of a partition cell is owned; tag all of them as interior for now
            cell_node_count = self.space_topology.element_node_count(geo_arg, space_arg, cell_index)
            for n in range(cell_node_count):
                space_nidx = self.space_topology.element_node_index(geo_arg, space_arg, cell_index, n)
                node_mask[space_nidx] = NodeCategory.OWNED_INTERIOR

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_owned_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            partition_side_index = wp.tid()

            side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)

            side_node_count = trace_topology.element_node_count(geo_arg, space_arg, side_index)
            for n in range(side_node_count):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)

                # Side nodes not owned by a partition cell are halo nodes
                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_LOCAL_SIDE

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_frontier_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            frontier_side_index = wp.tid()

            side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)

            side_node_count = trace_topology.element_node_count(geo_arg, space_arg, side_index)
            for n in range(side_node_count):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_OTHER_SIDE
                elif node_mask[space_nidx] == NodeCategory.OWNED_INTERIOR:
                    node_mask[space_nidx] = NodeCategory.OWNED_FRONTIER

        node_category = cache.borrow_temporary(
            temporary_store,
            shape=(self.space_topology.node_count(),),
            dtype=int,
            device=device,
        )
        # Nodes not reached by any of the kernels below remain exterior
        node_category.array.fill_(value=NodeCategory.EXTERIOR)

        wp.launch(
            dim=self.geo_partition.cell_count(),
            kernel=node_category_from_cells_kernel,
            inputs=[
                self.geo_partition.geometry.cell_arg_value(device),
                self.geo_partition.cell_arg_value(device),
                self.space_topology.topo_arg_value(device),
                node_category.array,
            ],
            device=device,
        )

        if with_halo:
            wp.launch(
                dim=self.geo_partition.side_count(),
                kernel=node_category_from_owned_sides_kernel,
                inputs=[
                    self.geo_partition.geometry.side_arg_value(device),
                    self.geo_partition.side_arg_value(device),
                    self.space_topology.topo_arg_value(device),
                    node_category.array,
                ],
                device=device,
            )

            wp.launch(
                dim=self.geo_partition.frontier_side_count(),
                kernel=node_category_from_frontier_sides_kernel,
                inputs=[
                    self.geo_partition.geometry.side_arg_value(device),
                    self.geo_partition.side_arg_value(device),
                    self.space_topology.topo_arg_value(device),
                    node_category.array,
                ],
                device=device,
            )

        self._finalize_node_indices(node_category.array, temporary_store)

        node_category.release()

    def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: cache.TemporaryStore):
        """Build the host-side category offsets and the partition <-> space index maps.

        Sets ``self._category_offsets`` (host copy of per-category offsets), ``self._space_to_partition``
        (space node -> partition index, or ``NULL_NODE_INDEX``) and ``self._node_indices``
        (partition index -> space node, trimmed to ``node_count()`` entries).
        """
        category_offsets, node_indices = compress_node_indices(
            NodeCategory.COUNT, node_category, temporary_store=temporary_store
        )

        device = node_category.device
        with wp.ScopedDevice(device):
            # Copy offsets to cpu (pinned memory when the source device is CUDA)
            self._category_offsets = cache.borrow_temporary(
                temporary_store,
                shape=category_offsets.array.shape,
                dtype=category_offsets.array.dtype,
                pinned=device.is_cuda,
                device="cpu",
            )
            wp.copy(src=category_offsets.array, dest=self._category_offsets.array)
            copy_event = cache.capture_event()

            # Compute global to local indices
            self._space_to_partition = cache.borrow_temporary_like(node_indices, temporary_store)
            wp.launch(
                kernel=NodePartition._scatter_partition_indices,
                dim=self.space_topology.node_count(),
                device=device,
                inputs=[category_offsets.array, node_indices.array, self._space_to_partition.array],
            )

            # Copy to a shrunk-to-fit array
            cache.synchronize_event(copy_event)  # Transfer to host must be finished to access node_count()
            node_count = self.node_count()
            self._node_indices = cache.borrow_temporary(temporary_store, shape=(node_count,), dtype=int, device=device)
            wp.copy(dest=self._node_indices.array, src=node_indices.array, count=node_count)

        # Both intermediate temporaries are no longer needed; return them to the store
        node_indices.release()
        category_offsets.release()

    @wp.kernel
    def _scatter_partition_indices(
        category_offsets: wp.array(dtype=int),
        node_indices: wp.array(dtype=int),
        space_to_partition_indices: wp.array(dtype=int),
    ):
        local_idx = wp.tid()
        space_idx = node_indices[local_idx]

        local_node_count = category_offsets[NodeCategory.EXTERIOR]  # all but exterior nodes
        if local_idx < local_node_count:
            space_to_partition_indices[space_idx] = local_idx
        else:
            space_to_partition_indices[space_idx] = NULL_NODE_INDEX
327
+
328
+
329
def make_space_partition(
    space: Optional[FunctionSpace] = None,
    geometry_partition: Optional[GeometryPartition] = None,
    space_topology: Optional[SpaceTopology] = None,
    with_halo: bool = True,
    device=None,
    temporary_store: cache.TemporaryStore = None,
) -> SpacePartition:
    """Computes the subset of nodes from a function space topology that touch a geometry partition

    Either `space_topology` or `space` must be provided (and will be considered in that order).

    Args:
        space: (deprecated) the function space defining the topology if `space_topology` is ``None``.
        geometry_partition: The subset of the space geometry. If not provided, use the whole geometry.
        space_topology: the topology of the function space to consider. If ``None``, deduced from `space`.
        with_halo: if True, include the halo nodes (nodes from exterior frontier cells to the partition)
        device: Warp device on which to perform and store computations
        temporary_store: optional store from which temporary arrays are borrowed

    Returns:
        the resulting space partition: a :class:`NodePartition` when `geometry_partition` covers fewer
        cells than its geometry, otherwise a :class:`WholeSpacePartition`

    Raises:
        ValueError: if neither `space_topology` nor `space` is provided
    """

    if space_topology is None:
        if space is None:
            raise ValueError("Either `space_topology` or `space` must be provided")
        space_topology = space.topology

    space_topology = space_topology.full_space_topology()

    # A geometry partition spanning every cell is equivalent to the whole space
    if geometry_partition is not None and geometry_partition.cell_count() < geometry_partition.geometry.cell_count():
        return NodePartition(
            space_topology=space_topology,
            geo_partition=geometry_partition,
            with_halo=with_halo,
            device=device,
            temporary_store=temporary_store,
        )

    return WholeSpacePartition(space_topology)