warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -14,12 +14,12 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from functools import cached_property
17
- from typing import Any
17
+ from typing import Any, Optional
18
18
 
19
19
  import warp as wp
20
- from warp.fem.cache import TemporaryStore, borrow_temporary, cached_arg_value, dynamic_struct
21
- from warp.fem.types import NULL_ELEMENT_INDEX, ElementIndex
22
- from warp.fem.utils import masked_indices
20
+ from warp._src.fem import cache
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, ElementIndex
22
+ from warp._src.fem.utils import masked_indices
23
23
 
24
24
  from .geometry import Geometry
25
25
 
@@ -61,17 +61,27 @@ class GeometryPartition:
61
61
  def __str__(self) -> str:
62
62
  return self.name
63
63
 
64
+ @cache.cached_arg_value
64
65
  def cell_arg_value(self, device):
65
- raise NotImplementedError()
66
+ args = self.CellArg()
67
+ self.fill_cell_arg(args, device)
68
+ return args
66
69
 
67
70
  def fill_cell_arg(self, args: CellArg, device):
68
- raise NotImplementedError()
71
+ if self.cell_arg_value is __class__.cell_arg_value:
72
+ raise NotImplementedError()
73
+ args.assign(self.cell_arg_value(device))
69
74
 
75
+ @cache.cached_arg_value
70
76
  def side_arg_value(self, device):
71
- raise NotImplementedError()
77
+ args = self.SideArg()
78
+ self.fill_side_arg(args, device)
79
+ return args
72
80
 
73
81
  def fill_side_arg(self, args: SideArg, device):
74
- raise NotImplementedError()
82
+ if self.side_arg_value is __class__.side_arg_value:
83
+ raise NotImplementedError()
84
+ args.assign(self.side_arg_value(device))
75
85
 
76
86
  @staticmethod
77
87
  def cell_index(args: CellArg, partition_cell_index: int):
@@ -139,10 +149,6 @@ class WholeGeometryPartition(GeometryPartition):
139
149
  class CellArg:
140
150
  pass
141
151
 
142
- def cell_arg_value(self, device):
143
- arg = WholeGeometryPartition.CellArg()
144
- return arg
145
-
146
152
  def fill_cell_arg(self, args: CellArg, device):
147
153
  pass
148
154
 
@@ -169,12 +175,16 @@ class CellBasedGeometryPartition(GeometryPartition):
169
175
  ):
170
176
  super().__init__(geometry)
171
177
 
178
+ self._partition_side_indices: wp.array = None
179
+ self._boundary_side_indices: wp.array = None
180
+ self._frontier_side_indices: wp.array = None
181
+
172
182
  @cached_property
173
183
  def SideArg(self):
174
184
  return self._make_side_arg()
175
185
 
176
186
  def _make_side_arg(self):
177
- @dynamic_struct(suffix=self.name)
187
+ @cache.dynamic_struct(suffix=self.name)
178
188
  class SideArg:
179
189
  cell_arg: self.CellArg
180
190
  partition_side_indices: wp.array(dtype=int)
@@ -184,25 +194,19 @@ class CellBasedGeometryPartition(GeometryPartition):
184
194
  return SideArg
185
195
 
186
196
  def side_count(self) -> int:
187
- return self._partition_side_indices.array.shape[0]
197
+ return self._partition_side_indices.shape[0]
188
198
 
189
199
  def boundary_side_count(self) -> int:
190
- return self._boundary_side_indices.array.shape[0]
200
+ return self._boundary_side_indices.shape[0]
191
201
 
192
202
  def frontier_side_count(self) -> int:
193
- return self._frontier_side_indices.array.shape[0]
194
-
195
- @cached_arg_value
196
- def side_arg_value(self, device):
197
- arg = self.SideArg()
198
- self.fill_side_arg(arg, device)
199
- return arg
203
+ return self._frontier_side_indices.shape[0]
200
204
 
201
205
  def fill_side_arg(self, args: SideArg, device):
202
206
  self.fill_cell_arg(args.cell_arg, device)
203
- args.partition_side_indices = self._partition_side_indices.array.to(device)
204
- args.boundary_side_indices = self._boundary_side_indices.array.to(device)
205
- args.frontier_side_indices = self._frontier_side_indices.array.to(device)
207
+ args.partition_side_indices = self._partition_side_indices.to(device)
208
+ args.boundary_side_indices = self._boundary_side_indices.to(device)
209
+ args.frontier_side_indices = self._frontier_side_indices.to(device)
206
210
 
207
211
  @wp.func
208
212
  def side_index(args: Any, partition_side_index: int):
@@ -220,9 +224,20 @@ class CellBasedGeometryPartition(GeometryPartition):
220
224
  return args.frontier_side_indices[frontier_side_index]
221
225
 
222
226
  def compute_side_indices_from_cells(
223
- self, cell_arg_value: Any, cell_inclusion_test_func: wp.Function, device, temporary_store: TemporaryStore = None
227
+ self,
228
+ cell_arg_value: Any,
229
+ cell_inclusion_test_func: wp.Function,
230
+ device,
231
+ max_side_count: int = -1,
232
+ temporary_store: cache.TemporaryStore = None,
224
233
  ):
225
- from warp.fem import cache
234
+ self.side_arg_value.invalidate(self)
235
+
236
+ if max_side_count == 0:
237
+ self._partition_side_indices = cache.borrow_temporary(temporary_store, dtype=int, shape=(0,), device=device)
238
+ self._boundary_side_indices = self._partition_side_indices
239
+ self._frontier_side_indices = self._partition_side_indices
240
+ return
226
241
 
227
242
  cell_arg_type = next(iter(cell_inclusion_test_func.input_types.values()))
228
243
 
@@ -253,46 +268,61 @@ class CellBasedGeometryPartition(GeometryPartition):
253
268
  # Exactly one neighbor in partition; count as frontier side
254
269
  frontier_side_mask[side_index] = 1
255
270
 
256
- partition_side_mask = borrow_temporary(
271
+ partition_side_mask = cache.borrow_temporary(
257
272
  temporary_store,
258
273
  shape=(self.geometry.side_count(),),
259
274
  dtype=int,
260
275
  device=device,
261
276
  )
262
- boundary_side_mask = borrow_temporary(
277
+ boundary_side_mask = cache.borrow_temporary(
263
278
  temporary_store,
264
279
  shape=(self.geometry.side_count(),),
265
280
  dtype=int,
266
281
  device=device,
267
282
  )
268
- frontier_side_mask = borrow_temporary(
283
+ frontier_side_mask = cache.borrow_temporary(
269
284
  temporary_store,
270
285
  shape=(self.geometry.side_count(),),
271
286
  dtype=int,
272
287
  device=device,
273
288
  )
274
289
 
275
- partition_side_mask.array.zero_()
276
- boundary_side_mask.array.zero_()
277
- frontier_side_mask.array.zero_()
290
+ partition_side_mask.zero_()
291
+ boundary_side_mask.zero_()
292
+ frontier_side_mask.zero_()
278
293
 
279
294
  wp.launch(
280
- dim=partition_side_mask.array.shape[0],
295
+ dim=partition_side_mask.shape[0],
281
296
  kernel=count_sides,
282
297
  inputs=[
283
298
  self.geometry.side_arg_value(device),
284
299
  cell_arg_value,
285
- partition_side_mask.array,
286
- boundary_side_mask.array,
287
- frontier_side_mask.array,
300
+ partition_side_mask,
301
+ boundary_side_mask,
302
+ frontier_side_mask,
288
303
  ],
289
304
  device=device,
290
305
  )
291
306
 
292
307
  # Convert counts to indices
293
- self._partition_side_indices, _ = masked_indices(partition_side_mask.array, temporary_store=temporary_store)
294
- self._boundary_side_indices, _ = masked_indices(boundary_side_mask.array, temporary_store=temporary_store)
295
- self._frontier_side_indices, _ = masked_indices(frontier_side_mask.array, temporary_store=temporary_store)
308
+ self._partition_side_indices, _ = masked_indices(
309
+ partition_side_mask,
310
+ max_index_count=max_side_count,
311
+ local_to_global=self._partition_side_indices,
312
+ temporary_store=temporary_store,
313
+ )
314
+ self._boundary_side_indices, _ = masked_indices(
315
+ boundary_side_mask,
316
+ max_index_count=max_side_count,
317
+ local_to_global=self._boundary_side_indices,
318
+ temporary_store=temporary_store,
319
+ )
320
+ self._frontier_side_indices, _ = masked_indices(
321
+ frontier_side_mask,
322
+ max_index_count=max_side_count,
323
+ local_to_global=self._frontier_side_indices,
324
+ temporary_store=temporary_store,
325
+ )
296
326
 
297
327
  partition_side_mask.release()
298
328
  boundary_side_mask.release()
@@ -310,7 +340,7 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
310
340
  partition_rank: int,
311
341
  partition_count: int,
312
342
  device=None,
313
- temporary_store: TemporaryStore = None,
343
+ temporary_store: cache.TemporaryStore = None,
314
344
  ):
315
345
  """Creates a geometry partition by uniformly partionning cell indices
316
346
 
@@ -343,11 +373,6 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
343
373
  cell_begin: int
344
374
  cell_end: int
345
375
 
346
- def cell_arg_value(self, device):
347
- arg = LinearGeometryPartition.CellArg()
348
- self.fill_cell_arg(arg, device)
349
- return arg
350
-
351
376
  def fill_cell_arg(self, args: CellArg, device):
352
377
  args.cell_begin = self.cell_begin
353
378
  args.cell_end = self.cell_end
@@ -372,43 +397,76 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
372
397
 
373
398
 
374
399
  class ExplicitGeometryPartition(CellBasedGeometryPartition):
375
- def __init__(self, geometry: Geometry, cell_mask: "wp.array(dtype=int)", temporary_store: TemporaryStore = None):
376
- """Creates a geometry partition by uniformly partionning cell indices
400
+ def __init__(
401
+ self,
402
+ geometry: Geometry,
403
+ cell_mask: "wp.array(dtype=int)",
404
+ max_cell_count: int = -1,
405
+ max_side_count: int = -1,
406
+ temporary_store: Optional[cache.TemporaryStore] = None,
407
+ ):
408
+ """Creates a geometry partition from an active cell mask
377
409
 
378
410
  Args:
379
411
  geometry: the geometry to partition
380
412
  cell_mask: warp array of length ``geometry.cell_count()`` indicating which cells are selected. Array values must be either ``1`` (selected) or ``0`` (not selected).
413
+ max_cell_count: if positive, will be used to limit the number of cells to avoid device/host synchronization
414
+ max_side_count: if positive, will be used to limit the number of sides to avoid device/host synchronization
381
415
  """
382
416
 
383
417
  super().__init__(geometry)
384
418
 
385
- self._cell_mask = cell_mask
386
- self._cells, self._partition_cells = masked_indices(self._cell_mask, temporary_store=temporary_store)
419
+ self._cells: wp.array = None
420
+ self._partition_cells: wp.array = None
421
+
422
+ self._max_cell_count = max_cell_count
423
+ self._max_side_count = max_side_count
424
+
425
+ self.rebuild(cell_mask, temporary_store)
426
+
427
+ def rebuild(
428
+ self,
429
+ cell_mask: "wp.array(dtype=int)",
430
+ temporary_store: Optional[cache.TemporaryStore] = None,
431
+ ):
432
+ """
433
+ Rebuilds the geometry partition from a new active cell mask
434
+
435
+ Args:
436
+ geometry: the geometry to partition
437
+ cell_mask: warp array of length ``geometry.cell_count()`` indicating which cells are selected. Array values must be either ``1`` (selected) or ``0`` (not selected).
438
+ max_cell_count: if positive, will be used to limit the number of cells to avoid device/host synchronization
439
+ max_side_count: if positive, will be used to limit the number of sides to avoid device/host synchronization
440
+ """
441
+ self.cell_arg_value.invalidate(self)
442
+
443
+ self._cells, self._partition_cells = masked_indices(
444
+ cell_mask,
445
+ local_to_global=self._cells,
446
+ global_to_local=self._partition_cells,
447
+ max_index_count=self._max_cell_count,
448
+ temporary_store=temporary_store,
449
+ )
387
450
 
388
451
  super().compute_side_indices_from_cells(
389
- self._cell_mask,
452
+ self.cell_arg_value(cell_mask.device),
390
453
  ExplicitGeometryPartition._cell_inclusion_test,
391
- self._cell_mask.device,
454
+ max_side_count=self._max_side_count,
455
+ device=cell_mask.device,
392
456
  temporary_store=temporary_store,
393
457
  )
394
458
 
395
459
  def cell_count(self) -> int:
396
- return self._cells.array.shape[0]
460
+ return self._cells.shape[0]
397
461
 
398
462
  @wp.struct
399
463
  class CellArg:
400
464
  cell_index: wp.array(dtype=int)
401
465
  partition_cell_index: wp.array(dtype=int)
402
466
 
403
- @cached_arg_value
404
- def cell_arg_value(self, device):
405
- arg = ExplicitGeometryPartition.CellArg()
406
- self.fill_cell_arg(arg, device)
407
- return arg
408
-
409
467
  def fill_cell_arg(self, args: CellArg, device):
410
- args.cell_index = self._cells.array.to(device)
411
- args.partition_cell_index = self._partition_cells.array.to(device)
468
+ args.cell_index = self._cells.to(device)
469
+ args.partition_cell_index = self._partition_cells.to(device)
412
470
 
413
471
  @wp.func
414
472
  def cell_index(args: CellArg, partition_cell_index: int):
@@ -419,5 +477,5 @@ class ExplicitGeometryPartition(CellBasedGeometryPartition):
419
477
  return args.partition_cell_index[cell_index]
420
478
 
421
479
  @wp.func
422
- def _cell_inclusion_test(mask: wp.array(dtype=int), cell_index: int):
423
- return mask[cell_index] > 0
480
+ def _cell_inclusion_test(arg: CellArg, cell_index: int):
481
+ return arg.partition_cell_index[cell_index] != NULL_ELEMENT_INDEX
@@ -16,16 +16,15 @@
16
16
  from typing import Any, Optional
17
17
 
18
18
  import warp as wp
19
- from warp.fem.cache import (
19
+ from warp._src.fem.cache import (
20
20
  TemporaryStore,
21
21
  borrow_temporary,
22
22
  borrow_temporary_like,
23
- cached_arg_value,
24
23
  )
25
- from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample
24
+ from warp._src.fem.types import OUTSIDE, Coords, ElementIndex, Sample
26
25
 
27
26
  from .closest_point import project_on_seg_at_origin
28
- from .element import LinearEdge, Square
27
+ from .element import Element
29
28
  from .geometry import Geometry
30
29
 
31
30
 
@@ -99,11 +98,11 @@ class Quadmesh(Geometry):
99
98
  def boundary_side_count(self):
100
99
  return self._boundary_edge_indices.shape[0]
101
100
 
102
- def reference_cell(self) -> Square:
103
- return Square()
101
+ def reference_cell(self) -> Element:
102
+ return Element.SQUARE
104
103
 
105
- def reference_side(self) -> LinearEdge:
106
- return LinearEdge()
104
+ def reference_side(self) -> Element:
105
+ return Element.LINE_SEGMENT
107
106
 
108
107
  @property
109
108
  def edge_quad_indices(self) -> wp.array:
@@ -126,30 +125,14 @@ class Quadmesh(Geometry):
126
125
  args.edge_vertex_indices = self._edge_vertex_indices.to(device)
127
126
  args.edge_quad_indices = self._edge_quad_indices.to(device)
128
127
 
129
- def cell_arg_value(self, device):
130
- args = self.CellArg()
131
- self.fill_cell_arg(args, device)
132
- return args
133
-
134
128
  def fill_cell_arg(self, args: "Quadmesh.CellArg", device):
135
129
  self.fill_cell_topo_arg(args.topology, device)
136
130
  args.positions = self.positions.to(device)
137
131
 
138
- def side_arg_value(self, device):
139
- args = self.SideArg()
140
- self.fill_side_arg(args, device)
141
- return args
142
-
143
132
  def fill_side_arg(self, args: "Quadmesh.SideArg", device):
144
133
  self.fill_side_topo_arg(args.topology, device)
145
134
  args.positions = self.positions.to(device)
146
135
 
147
- @cached_arg_value
148
- def side_index_arg_value(self, device) -> SideIndexArg:
149
- args = self.SideIndexArg()
150
- self.fill_side_index_arg(args, device)
151
- return args
152
-
153
136
  def fill_side_index_arg(self, args: "SideIndexArg", device):
154
137
  args.boundary_edge_indices = self._boundary_edge_indices.to(device)
155
138
 
@@ -211,8 +194,8 @@ class Quadmesh(Geometry):
211
194
  return args.boundary_edge_indices[boundary_side_index]
212
195
 
213
196
  def _build_topology(self, temporary_store: TemporaryStore):
214
- from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
215
- from warp.utils import array_scan
197
+ from warp._src.fem.utils import compress_node_indices, host_read_at_index, masked_indices
198
+ from warp._src.utils import array_scan
216
199
 
217
200
  device = self.quad_vertex_indices.device
218
201
 
@@ -223,7 +206,7 @@ class Quadmesh(Geometry):
223
206
  self._vertex_quad_indices = vertex_quad_indices.detach()
224
207
 
225
208
  vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
226
- vertex_start_edge_count.array.zero_()
209
+ vertex_start_edge_count.zero_()
227
210
  vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)
228
211
 
229
212
  vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(4 * self.cell_count()))
@@ -236,10 +219,10 @@ class Quadmesh(Geometry):
236
219
  kernel=Quadmesh2D._count_starting_edges_kernel,
237
220
  device=device,
238
221
  dim=self.cell_count(),
239
- inputs=[self.quad_vertex_indices, vertex_start_edge_count.array],
222
+ inputs=[self.quad_vertex_indices, vertex_start_edge_count],
240
223
  )
241
224
 
242
- array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)
225
+ array_scan(in_array=vertex_start_edge_count, out_array=vertex_start_edge_offsets, inclusive=False)
243
226
 
244
227
  # Count number of unique edges (deduplicate across faces)
245
228
  vertex_unique_edge_count = vertex_start_edge_count
@@ -251,21 +234,19 @@ class Quadmesh(Geometry):
251
234
  self._vertex_quad_offsets,
252
235
  self._vertex_quad_indices,
253
236
  self.quad_vertex_indices,
254
- vertex_start_edge_offsets.array,
255
- vertex_unique_edge_count.array,
256
- vertex_edge_ends.array,
257
- vertex_edge_quads.array,
237
+ vertex_start_edge_offsets,
238
+ vertex_unique_edge_count,
239
+ vertex_edge_ends,
240
+ vertex_edge_quads,
258
241
  ],
259
242
  )
260
243
 
261
244
  vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
262
- array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)
245
+ array_scan(in_array=vertex_start_edge_count, out_array=vertex_unique_edge_offsets, inclusive=False)
263
246
 
264
247
  # Get back edge count to host
265
248
  edge_count = int(
266
- host_read_at_index(
267
- vertex_unique_edge_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
268
- )
249
+ host_read_at_index(vertex_unique_edge_offsets, self.vertex_count() - 1, temporary_store=temporary_store)
269
250
  )
270
251
 
271
252
  self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)
@@ -279,14 +260,14 @@ class Quadmesh(Geometry):
279
260
  device=device,
280
261
  dim=self.vertex_count(),
281
262
  inputs=[
282
- vertex_start_edge_offsets.array,
283
- vertex_unique_edge_offsets.array,
284
- vertex_unique_edge_count.array,
285
- vertex_edge_ends.array,
286
- vertex_edge_quads.array,
263
+ vertex_start_edge_offsets,
264
+ vertex_unique_edge_offsets,
265
+ vertex_unique_edge_count,
266
+ vertex_edge_ends,
267
+ vertex_edge_quads,
287
268
  self._edge_vertex_indices,
288
269
  self._edge_quad_indices,
289
- boundary_mask.array,
270
+ boundary_mask,
290
271
  ],
291
272
  )
292
273
 
@@ -296,7 +277,7 @@ class Quadmesh(Geometry):
296
277
  vertex_edge_ends.release()
297
278
  vertex_edge_quads.release()
298
279
 
299
- boundary_edge_indices, _ = masked_indices(boundary_mask.array, temporary_store=temporary_store)
280
+ boundary_edge_indices, _ = masked_indices(boundary_mask, temporary_store=temporary_store)
300
281
  self._boundary_edge_indices = boundary_edge_indices.detach()
301
282
 
302
283
  boundary_mask.release()
@@ -457,7 +438,7 @@ class Quadmesh(Geometry):
457
438
  q = pos - p0
458
439
  e = args.positions[edge_idx[1]] - p0
459
440
 
460
- dist, t = project_on_seg_at_origin(q, e, wp.lengh_sq(e))
441
+ dist, t = project_on_seg_at_origin(q, e, wp.length_sq(e))
461
442
  return Coords(t, 0.0, 0.0), dist
462
443
 
463
444
  @wp.func