warp-lang 1.9.0__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2302 -307
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1546 -224
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +3 -3
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +581 -280
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +18 -17
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +580 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
@@ -14,15 +14,17 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from functools import cached_property
17
- from typing import Any
17
+ from typing import Any, Optional
18
18
 
19
19
  import warp as wp
20
- from warp.fem.cache import TemporaryStore, borrow_temporary, cached_arg_value, dynamic_struct
21
- from warp.fem.types import NULL_ELEMENT_INDEX, ElementIndex
22
- from warp.fem.utils import masked_indices
20
+ from warp._src.fem import cache
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, ElementIndex
22
+ from warp._src.fem.utils import masked_indices
23
23
 
24
24
  from .geometry import Geometry
25
25
 
26
+ _wp_module_name_ = "warp.fem.geometry.partition"
27
+
26
28
  wp.set_module_options({"enable_backward": False})
27
29
 
28
30
 
@@ -61,17 +63,27 @@ class GeometryPartition:
61
63
  def __str__(self) -> str:
62
64
  return self.name
63
65
 
66
+ @cache.cached_arg_value
64
67
  def cell_arg_value(self, device):
65
- raise NotImplementedError()
68
+ args = self.CellArg()
69
+ self.fill_cell_arg(args, device)
70
+ return args
66
71
 
67
72
  def fill_cell_arg(self, args: CellArg, device):
68
- raise NotImplementedError()
73
+ if self.cell_arg_value is __class__.cell_arg_value:
74
+ raise NotImplementedError()
75
+ args.assign(self.cell_arg_value(device))
69
76
 
77
+ @cache.cached_arg_value
70
78
  def side_arg_value(self, device):
71
- raise NotImplementedError()
79
+ args = self.SideArg()
80
+ self.fill_side_arg(args, device)
81
+ return args
72
82
 
73
83
  def fill_side_arg(self, args: SideArg, device):
74
- raise NotImplementedError()
84
+ if self.side_arg_value is __class__.side_arg_value:
85
+ raise NotImplementedError()
86
+ args.assign(self.side_arg_value(device))
75
87
 
76
88
  @staticmethod
77
89
  def cell_index(args: CellArg, partition_cell_index: int):
@@ -139,10 +151,6 @@ class WholeGeometryPartition(GeometryPartition):
139
151
  class CellArg:
140
152
  pass
141
153
 
142
- def cell_arg_value(self, device):
143
- arg = WholeGeometryPartition.CellArg()
144
- return arg
145
-
146
154
  def fill_cell_arg(self, args: CellArg, device):
147
155
  pass
148
156
 
@@ -169,12 +177,16 @@ class CellBasedGeometryPartition(GeometryPartition):
169
177
  ):
170
178
  super().__init__(geometry)
171
179
 
180
+ self._partition_side_indices: wp.array = None
181
+ self._boundary_side_indices: wp.array = None
182
+ self._frontier_side_indices: wp.array = None
183
+
172
184
  @cached_property
173
185
  def SideArg(self):
174
186
  return self._make_side_arg()
175
187
 
176
188
  def _make_side_arg(self):
177
- @dynamic_struct(suffix=self.name)
189
+ @cache.dynamic_struct(suffix=self.name)
178
190
  class SideArg:
179
191
  cell_arg: self.CellArg
180
192
  partition_side_indices: wp.array(dtype=int)
@@ -184,25 +196,19 @@ class CellBasedGeometryPartition(GeometryPartition):
184
196
  return SideArg
185
197
 
186
198
  def side_count(self) -> int:
187
- return self._partition_side_indices.array.shape[0]
199
+ return self._partition_side_indices.shape[0]
188
200
 
189
201
  def boundary_side_count(self) -> int:
190
- return self._boundary_side_indices.array.shape[0]
202
+ return self._boundary_side_indices.shape[0]
191
203
 
192
204
  def frontier_side_count(self) -> int:
193
- return self._frontier_side_indices.array.shape[0]
194
-
195
- @cached_arg_value
196
- def side_arg_value(self, device):
197
- arg = self.SideArg()
198
- self.fill_side_arg(arg, device)
199
- return arg
205
+ return self._frontier_side_indices.shape[0]
200
206
 
201
207
  def fill_side_arg(self, args: SideArg, device):
202
208
  self.fill_cell_arg(args.cell_arg, device)
203
- args.partition_side_indices = self._partition_side_indices.array.to(device)
204
- args.boundary_side_indices = self._boundary_side_indices.array.to(device)
205
- args.frontier_side_indices = self._frontier_side_indices.array.to(device)
209
+ args.partition_side_indices = self._partition_side_indices.to(device)
210
+ args.boundary_side_indices = self._boundary_side_indices.to(device)
211
+ args.frontier_side_indices = self._frontier_side_indices.to(device)
206
212
 
207
213
  @wp.func
208
214
  def side_index(args: Any, partition_side_index: int):
@@ -220,9 +226,20 @@ class CellBasedGeometryPartition(GeometryPartition):
220
226
  return args.frontier_side_indices[frontier_side_index]
221
227
 
222
228
  def compute_side_indices_from_cells(
223
- self, cell_arg_value: Any, cell_inclusion_test_func: wp.Function, device, temporary_store: TemporaryStore = None
229
+ self,
230
+ cell_arg_value: Any,
231
+ cell_inclusion_test_func: wp.Function,
232
+ device,
233
+ max_side_count: int = -1,
234
+ temporary_store: cache.TemporaryStore = None,
224
235
  ):
225
- from warp.fem import cache
236
+ self.side_arg_value.invalidate(self)
237
+
238
+ if max_side_count == 0:
239
+ self._partition_side_indices = cache.borrow_temporary(temporary_store, dtype=int, shape=(0,), device=device)
240
+ self._boundary_side_indices = self._partition_side_indices
241
+ self._frontier_side_indices = self._partition_side_indices
242
+ return
226
243
 
227
244
  cell_arg_type = next(iter(cell_inclusion_test_func.input_types.values()))
228
245
 
@@ -253,46 +270,61 @@ class CellBasedGeometryPartition(GeometryPartition):
253
270
  # Exactly one neighbor in partition; count as frontier side
254
271
  frontier_side_mask[side_index] = 1
255
272
 
256
- partition_side_mask = borrow_temporary(
273
+ partition_side_mask = cache.borrow_temporary(
257
274
  temporary_store,
258
275
  shape=(self.geometry.side_count(),),
259
276
  dtype=int,
260
277
  device=device,
261
278
  )
262
- boundary_side_mask = borrow_temporary(
279
+ boundary_side_mask = cache.borrow_temporary(
263
280
  temporary_store,
264
281
  shape=(self.geometry.side_count(),),
265
282
  dtype=int,
266
283
  device=device,
267
284
  )
268
- frontier_side_mask = borrow_temporary(
285
+ frontier_side_mask = cache.borrow_temporary(
269
286
  temporary_store,
270
287
  shape=(self.geometry.side_count(),),
271
288
  dtype=int,
272
289
  device=device,
273
290
  )
274
291
 
275
- partition_side_mask.array.zero_()
276
- boundary_side_mask.array.zero_()
277
- frontier_side_mask.array.zero_()
292
+ partition_side_mask.zero_()
293
+ boundary_side_mask.zero_()
294
+ frontier_side_mask.zero_()
278
295
 
279
296
  wp.launch(
280
- dim=partition_side_mask.array.shape[0],
297
+ dim=partition_side_mask.shape[0],
281
298
  kernel=count_sides,
282
299
  inputs=[
283
300
  self.geometry.side_arg_value(device),
284
301
  cell_arg_value,
285
- partition_side_mask.array,
286
- boundary_side_mask.array,
287
- frontier_side_mask.array,
302
+ partition_side_mask,
303
+ boundary_side_mask,
304
+ frontier_side_mask,
288
305
  ],
289
306
  device=device,
290
307
  )
291
308
 
292
309
  # Convert counts to indices
293
- self._partition_side_indices, _ = masked_indices(partition_side_mask.array, temporary_store=temporary_store)
294
- self._boundary_side_indices, _ = masked_indices(boundary_side_mask.array, temporary_store=temporary_store)
295
- self._frontier_side_indices, _ = masked_indices(frontier_side_mask.array, temporary_store=temporary_store)
310
+ self._partition_side_indices, _ = masked_indices(
311
+ partition_side_mask,
312
+ max_index_count=max_side_count,
313
+ local_to_global=self._partition_side_indices,
314
+ temporary_store=temporary_store,
315
+ )
316
+ self._boundary_side_indices, _ = masked_indices(
317
+ boundary_side_mask,
318
+ max_index_count=max_side_count,
319
+ local_to_global=self._boundary_side_indices,
320
+ temporary_store=temporary_store,
321
+ )
322
+ self._frontier_side_indices, _ = masked_indices(
323
+ frontier_side_mask,
324
+ max_index_count=max_side_count,
325
+ local_to_global=self._frontier_side_indices,
326
+ temporary_store=temporary_store,
327
+ )
296
328
 
297
329
  partition_side_mask.release()
298
330
  boundary_side_mask.release()
@@ -310,7 +342,7 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
310
342
  partition_rank: int,
311
343
  partition_count: int,
312
344
  device=None,
313
- temporary_store: TemporaryStore = None,
345
+ temporary_store: cache.TemporaryStore = None,
314
346
  ):
315
347
  """Creates a geometry partition by uniformly partionning cell indices
316
348
 
@@ -343,11 +375,6 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
343
375
  cell_begin: int
344
376
  cell_end: int
345
377
 
346
- def cell_arg_value(self, device):
347
- arg = LinearGeometryPartition.CellArg()
348
- self.fill_cell_arg(arg, device)
349
- return arg
350
-
351
378
  def fill_cell_arg(self, args: CellArg, device):
352
379
  args.cell_begin = self.cell_begin
353
380
  args.cell_end = self.cell_end
@@ -372,43 +399,76 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
372
399
 
373
400
 
374
401
  class ExplicitGeometryPartition(CellBasedGeometryPartition):
375
- def __init__(self, geometry: Geometry, cell_mask: "wp.array(dtype=int)", temporary_store: TemporaryStore = None):
376
- """Creates a geometry partition by uniformly partionning cell indices
402
+ def __init__(
403
+ self,
404
+ geometry: Geometry,
405
+ cell_mask: "wp.array(dtype=int)",
406
+ max_cell_count: int = -1,
407
+ max_side_count: int = -1,
408
+ temporary_store: Optional[cache.TemporaryStore] = None,
409
+ ):
410
+ """Creates a geometry partition from an active cell mask
377
411
 
378
412
  Args:
379
413
  geometry: the geometry to partition
380
414
  cell_mask: warp array of length ``geometry.cell_count()`` indicating which cells are selected. Array values must be either ``1`` (selected) or ``0`` (not selected).
415
+ max_cell_count: if positive, will be used to limit the number of cells to avoid device/host synchronization
416
+ max_side_count: if positive, will be used to limit the number of sides to avoid device/host synchronization
381
417
  """
382
418
 
383
419
  super().__init__(geometry)
384
420
 
385
- self._cell_mask = cell_mask
386
- self._cells, self._partition_cells = masked_indices(self._cell_mask, temporary_store=temporary_store)
421
+ self._cells: wp.array = None
422
+ self._partition_cells: wp.array = None
423
+
424
+ self._max_cell_count = max_cell_count
425
+ self._max_side_count = max_side_count
426
+
427
+ self.rebuild(cell_mask, temporary_store)
428
+
429
+ def rebuild(
430
+ self,
431
+ cell_mask: "wp.array(dtype=int)",
432
+ temporary_store: Optional[cache.TemporaryStore] = None,
433
+ ):
434
+ """
435
+ Rebuilds the geometry partition from a new active cell mask
436
+
437
+ Args:
438
+ geometry: the geometry to partition
439
+ cell_mask: warp array of length ``geometry.cell_count()`` indicating which cells are selected. Array values must be either ``1`` (selected) or ``0`` (not selected).
440
+ max_cell_count: if positive, will be used to limit the number of cells to avoid device/host synchronization
441
+ max_side_count: if positive, will be used to limit the number of sides to avoid device/host synchronization
442
+ """
443
+ self.cell_arg_value.invalidate(self)
444
+
445
+ self._cells, self._partition_cells = masked_indices(
446
+ cell_mask,
447
+ local_to_global=self._cells,
448
+ global_to_local=self._partition_cells,
449
+ max_index_count=self._max_cell_count,
450
+ temporary_store=temporary_store,
451
+ )
387
452
 
388
453
  super().compute_side_indices_from_cells(
389
- self._cell_mask,
454
+ self.cell_arg_value(cell_mask.device),
390
455
  ExplicitGeometryPartition._cell_inclusion_test,
391
- self._cell_mask.device,
456
+ max_side_count=self._max_side_count,
457
+ device=cell_mask.device,
392
458
  temporary_store=temporary_store,
393
459
  )
394
460
 
395
461
  def cell_count(self) -> int:
396
- return self._cells.array.shape[0]
462
+ return self._cells.shape[0]
397
463
 
398
464
  @wp.struct
399
465
  class CellArg:
400
466
  cell_index: wp.array(dtype=int)
401
467
  partition_cell_index: wp.array(dtype=int)
402
468
 
403
- @cached_arg_value
404
- def cell_arg_value(self, device):
405
- arg = ExplicitGeometryPartition.CellArg()
406
- self.fill_cell_arg(arg, device)
407
- return arg
408
-
409
469
  def fill_cell_arg(self, args: CellArg, device):
410
- args.cell_index = self._cells.array.to(device)
411
- args.partition_cell_index = self._partition_cells.array.to(device)
470
+ args.cell_index = self._cells.to(device)
471
+ args.partition_cell_index = self._partition_cells.to(device)
412
472
 
413
473
  @wp.func
414
474
  def cell_index(args: CellArg, partition_cell_index: int):
@@ -419,5 +479,5 @@ class ExplicitGeometryPartition(CellBasedGeometryPartition):
419
479
  return args.partition_cell_index[cell_index]
420
480
 
421
481
  @wp.func
422
- def _cell_inclusion_test(mask: wp.array(dtype=int), cell_index: int):
423
- return mask[cell_index] > 0
482
+ def _cell_inclusion_test(arg: CellArg, cell_index: int):
483
+ return arg.partition_cell_index[cell_index] != NULL_ELEMENT_INDEX
@@ -16,18 +16,19 @@
16
16
  from typing import Any, Optional
17
17
 
18
18
  import warp as wp
19
- from warp.fem.cache import (
19
+ from warp._src.fem.cache import (
20
20
  TemporaryStore,
21
21
  borrow_temporary,
22
22
  borrow_temporary_like,
23
- cached_arg_value,
24
23
  )
25
- from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample
24
+ from warp._src.fem.types import OUTSIDE, Coords, ElementIndex, Sample
26
25
 
27
26
  from .closest_point import project_on_seg_at_origin
28
- from .element import LinearEdge, Square
27
+ from .element import Element
29
28
  from .geometry import Geometry
30
29
 
30
+ _wp_module_name_ = "warp.fem.geometry.quadmesh"
31
+
31
32
 
32
33
  @wp.struct
33
34
  class QuadmeshCellArg:
@@ -99,11 +100,11 @@ class Quadmesh(Geometry):
99
100
  def boundary_side_count(self):
100
101
  return self._boundary_edge_indices.shape[0]
101
102
 
102
- def reference_cell(self) -> Square:
103
- return Square()
103
+ def reference_cell(self) -> Element:
104
+ return Element.SQUARE
104
105
 
105
- def reference_side(self) -> LinearEdge:
106
- return LinearEdge()
106
+ def reference_side(self) -> Element:
107
+ return Element.LINE_SEGMENT
107
108
 
108
109
  @property
109
110
  def edge_quad_indices(self) -> wp.array:
@@ -126,30 +127,14 @@ class Quadmesh(Geometry):
126
127
  args.edge_vertex_indices = self._edge_vertex_indices.to(device)
127
128
  args.edge_quad_indices = self._edge_quad_indices.to(device)
128
129
 
129
- def cell_arg_value(self, device):
130
- args = self.CellArg()
131
- self.fill_cell_arg(args, device)
132
- return args
133
-
134
130
  def fill_cell_arg(self, args: "Quadmesh.CellArg", device):
135
131
  self.fill_cell_topo_arg(args.topology, device)
136
132
  args.positions = self.positions.to(device)
137
133
 
138
- def side_arg_value(self, device):
139
- args = self.SideArg()
140
- self.fill_side_arg(args, device)
141
- return args
142
-
143
134
  def fill_side_arg(self, args: "Quadmesh.SideArg", device):
144
135
  self.fill_side_topo_arg(args.topology, device)
145
136
  args.positions = self.positions.to(device)
146
137
 
147
- @cached_arg_value
148
- def side_index_arg_value(self, device) -> SideIndexArg:
149
- args = self.SideIndexArg()
150
- self.fill_side_index_arg(args, device)
151
- return args
152
-
153
138
  def fill_side_index_arg(self, args: "SideIndexArg", device):
154
139
  args.boundary_edge_indices = self._boundary_edge_indices.to(device)
155
140
 
@@ -211,8 +196,8 @@ class Quadmesh(Geometry):
211
196
  return args.boundary_edge_indices[boundary_side_index]
212
197
 
213
198
  def _build_topology(self, temporary_store: TemporaryStore):
214
- from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
215
- from warp.utils import array_scan
199
+ from warp._src.fem.utils import compress_node_indices, host_read_at_index, masked_indices
200
+ from warp._src.utils import array_scan
216
201
 
217
202
  device = self.quad_vertex_indices.device
218
203
 
@@ -223,7 +208,7 @@ class Quadmesh(Geometry):
223
208
  self._vertex_quad_indices = vertex_quad_indices.detach()
224
209
 
225
210
  vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
226
- vertex_start_edge_count.array.zero_()
211
+ vertex_start_edge_count.zero_()
227
212
  vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)
228
213
 
229
214
  vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(4 * self.cell_count()))
@@ -236,10 +221,10 @@ class Quadmesh(Geometry):
236
221
  kernel=Quadmesh2D._count_starting_edges_kernel,
237
222
  device=device,
238
223
  dim=self.cell_count(),
239
- inputs=[self.quad_vertex_indices, vertex_start_edge_count.array],
224
+ inputs=[self.quad_vertex_indices, vertex_start_edge_count],
240
225
  )
241
226
 
242
- array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)
227
+ array_scan(in_array=vertex_start_edge_count, out_array=vertex_start_edge_offsets, inclusive=False)
243
228
 
244
229
  # Count number of unique edges (deduplicate across faces)
245
230
  vertex_unique_edge_count = vertex_start_edge_count
@@ -251,21 +236,19 @@ class Quadmesh(Geometry):
251
236
  self._vertex_quad_offsets,
252
237
  self._vertex_quad_indices,
253
238
  self.quad_vertex_indices,
254
- vertex_start_edge_offsets.array,
255
- vertex_unique_edge_count.array,
256
- vertex_edge_ends.array,
257
- vertex_edge_quads.array,
239
+ vertex_start_edge_offsets,
240
+ vertex_unique_edge_count,
241
+ vertex_edge_ends,
242
+ vertex_edge_quads,
258
243
  ],
259
244
  )
260
245
 
261
246
  vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
262
- array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)
247
+ array_scan(in_array=vertex_start_edge_count, out_array=vertex_unique_edge_offsets, inclusive=False)
263
248
 
264
249
  # Get back edge count to host
265
250
  edge_count = int(
266
- host_read_at_index(
267
- vertex_unique_edge_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
268
- )
251
+ host_read_at_index(vertex_unique_edge_offsets, self.vertex_count() - 1, temporary_store=temporary_store)
269
252
  )
270
253
 
271
254
  self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)
@@ -279,14 +262,14 @@ class Quadmesh(Geometry):
279
262
  device=device,
280
263
  dim=self.vertex_count(),
281
264
  inputs=[
282
- vertex_start_edge_offsets.array,
283
- vertex_unique_edge_offsets.array,
284
- vertex_unique_edge_count.array,
285
- vertex_edge_ends.array,
286
- vertex_edge_quads.array,
265
+ vertex_start_edge_offsets,
266
+ vertex_unique_edge_offsets,
267
+ vertex_unique_edge_count,
268
+ vertex_edge_ends,
269
+ vertex_edge_quads,
287
270
  self._edge_vertex_indices,
288
271
  self._edge_quad_indices,
289
- boundary_mask.array,
272
+ boundary_mask,
290
273
  ],
291
274
  )
292
275
 
@@ -296,7 +279,7 @@ class Quadmesh(Geometry):
296
279
  vertex_edge_ends.release()
297
280
  vertex_edge_quads.release()
298
281
 
299
- boundary_edge_indices, _ = masked_indices(boundary_mask.array, temporary_store=temporary_store)
282
+ boundary_edge_indices, _ = masked_indices(boundary_mask, temporary_store=temporary_store)
300
283
  self._boundary_edge_indices = boundary_edge_indices.detach()
301
284
 
302
285
  boundary_mask.release()
@@ -457,7 +440,7 @@ class Quadmesh(Geometry):
457
440
  q = pos - p0
458
441
  e = args.positions[edge_idx[1]] - p0
459
442
 
460
- dist, t = project_on_seg_at_origin(q, e, wp.lengh_sq(e))
443
+ dist, t = project_on_seg_at_origin(q, e, wp.length_sq(e))
461
444
  return Coords(t, 0.0, 0.0), dist
462
445
 
463
446
  @wp.func