warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -14,12 +14,12 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from functools import cached_property
17
- from typing import Any
17
+ from typing import Any, Optional
18
18
 
19
19
  import warp as wp
20
- from warp.fem.cache import TemporaryStore, borrow_temporary, cached_arg_value, dynamic_struct
21
- from warp.fem.types import NULL_ELEMENT_INDEX, ElementIndex
22
- from warp.fem.utils import masked_indices
20
+ from warp._src.fem import cache
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, ElementIndex
22
+ from warp._src.fem.utils import masked_indices
23
23
 
24
24
  from .geometry import Geometry
25
25
 
@@ -61,17 +61,27 @@ class GeometryPartition:
61
61
  def __str__(self) -> str:
62
62
  return self.name
63
63
 
64
+ @cache.cached_arg_value
64
65
  def cell_arg_value(self, device):
65
- raise NotImplementedError()
66
+ args = self.CellArg()
67
+ self.fill_cell_arg(args, device)
68
+ return args
66
69
 
67
70
  def fill_cell_arg(self, args: CellArg, device):
68
- raise NotImplementedError()
71
+ if self.cell_arg_value is __class__.cell_arg_value:
72
+ raise NotImplementedError()
73
+ args.assign(self.cell_arg_value(device))
69
74
 
75
+ @cache.cached_arg_value
70
76
  def side_arg_value(self, device):
71
- raise NotImplementedError()
77
+ args = self.SideArg()
78
+ self.fill_side_arg(args, device)
79
+ return args
72
80
 
73
81
  def fill_side_arg(self, args: SideArg, device):
74
- raise NotImplementedError()
82
+ if self.side_arg_value is __class__.side_arg_value:
83
+ raise NotImplementedError()
84
+ args.assign(self.side_arg_value(device))
75
85
 
76
86
  @staticmethod
77
87
  def cell_index(args: CellArg, partition_cell_index: int):
@@ -139,10 +149,6 @@ class WholeGeometryPartition(GeometryPartition):
139
149
  class CellArg:
140
150
  pass
141
151
 
142
- def cell_arg_value(self, device):
143
- arg = WholeGeometryPartition.CellArg()
144
- return arg
145
-
146
152
  def fill_cell_arg(self, args: CellArg, device):
147
153
  pass
148
154
 
@@ -169,12 +175,16 @@ class CellBasedGeometryPartition(GeometryPartition):
169
175
  ):
170
176
  super().__init__(geometry)
171
177
 
178
+ self._partition_side_indices: wp.array = None
179
+ self._boundary_side_indices: wp.array = None
180
+ self._frontier_side_indices: wp.array = None
181
+
172
182
  @cached_property
173
183
  def SideArg(self):
174
184
  return self._make_side_arg()
175
185
 
176
186
  def _make_side_arg(self):
177
- @dynamic_struct(suffix=self.name)
187
+ @cache.dynamic_struct(suffix=self.name)
178
188
  class SideArg:
179
189
  cell_arg: self.CellArg
180
190
  partition_side_indices: wp.array(dtype=int)
@@ -184,25 +194,19 @@ class CellBasedGeometryPartition(GeometryPartition):
184
194
  return SideArg
185
195
 
186
196
  def side_count(self) -> int:
187
- return self._partition_side_indices.array.shape[0]
197
+ return self._partition_side_indices.shape[0]
188
198
 
189
199
  def boundary_side_count(self) -> int:
190
- return self._boundary_side_indices.array.shape[0]
200
+ return self._boundary_side_indices.shape[0]
191
201
 
192
202
  def frontier_side_count(self) -> int:
193
- return self._frontier_side_indices.array.shape[0]
194
-
195
- @cached_arg_value
196
- def side_arg_value(self, device):
197
- arg = self.SideArg()
198
- self.fill_side_arg(arg, device)
199
- return arg
203
+ return self._frontier_side_indices.shape[0]
200
204
 
201
205
  def fill_side_arg(self, args: SideArg, device):
202
206
  self.fill_cell_arg(args.cell_arg, device)
203
- args.partition_side_indices = self._partition_side_indices.array.to(device)
204
- args.boundary_side_indices = self._boundary_side_indices.array.to(device)
205
- args.frontier_side_indices = self._frontier_side_indices.array.to(device)
207
+ args.partition_side_indices = self._partition_side_indices.to(device)
208
+ args.boundary_side_indices = self._boundary_side_indices.to(device)
209
+ args.frontier_side_indices = self._frontier_side_indices.to(device)
206
210
 
207
211
  @wp.func
208
212
  def side_index(args: Any, partition_side_index: int):
@@ -220,9 +224,20 @@ class CellBasedGeometryPartition(GeometryPartition):
220
224
  return args.frontier_side_indices[frontier_side_index]
221
225
 
222
226
  def compute_side_indices_from_cells(
223
- self, cell_arg_value: Any, cell_inclusion_test_func: wp.Function, device, temporary_store: TemporaryStore = None
227
+ self,
228
+ cell_arg_value: Any,
229
+ cell_inclusion_test_func: wp.Function,
230
+ device,
231
+ max_side_count: int = -1,
232
+ temporary_store: cache.TemporaryStore = None,
224
233
  ):
225
- from warp.fem import cache
234
+ self.side_arg_value.invalidate(self)
235
+
236
+ if max_side_count == 0:
237
+ self._partition_side_indices = cache.borrow_temporary(temporary_store, dtype=int, shape=(0,), device=device)
238
+ self._boundary_side_indices = self._partition_side_indices
239
+ self._frontier_side_indices = self._partition_side_indices
240
+ return
226
241
 
227
242
  cell_arg_type = next(iter(cell_inclusion_test_func.input_types.values()))
228
243
 
@@ -253,46 +268,61 @@ class CellBasedGeometryPartition(GeometryPartition):
253
268
  # Exactly one neighbor in partition; count as frontier side
254
269
  frontier_side_mask[side_index] = 1
255
270
 
256
- partition_side_mask = borrow_temporary(
271
+ partition_side_mask = cache.borrow_temporary(
257
272
  temporary_store,
258
273
  shape=(self.geometry.side_count(),),
259
274
  dtype=int,
260
275
  device=device,
261
276
  )
262
- boundary_side_mask = borrow_temporary(
277
+ boundary_side_mask = cache.borrow_temporary(
263
278
  temporary_store,
264
279
  shape=(self.geometry.side_count(),),
265
280
  dtype=int,
266
281
  device=device,
267
282
  )
268
- frontier_side_mask = borrow_temporary(
283
+ frontier_side_mask = cache.borrow_temporary(
269
284
  temporary_store,
270
285
  shape=(self.geometry.side_count(),),
271
286
  dtype=int,
272
287
  device=device,
273
288
  )
274
289
 
275
- partition_side_mask.array.zero_()
276
- boundary_side_mask.array.zero_()
277
- frontier_side_mask.array.zero_()
290
+ partition_side_mask.zero_()
291
+ boundary_side_mask.zero_()
292
+ frontier_side_mask.zero_()
278
293
 
279
294
  wp.launch(
280
- dim=partition_side_mask.array.shape[0],
295
+ dim=partition_side_mask.shape[0],
281
296
  kernel=count_sides,
282
297
  inputs=[
283
298
  self.geometry.side_arg_value(device),
284
299
  cell_arg_value,
285
- partition_side_mask.array,
286
- boundary_side_mask.array,
287
- frontier_side_mask.array,
300
+ partition_side_mask,
301
+ boundary_side_mask,
302
+ frontier_side_mask,
288
303
  ],
289
304
  device=device,
290
305
  )
291
306
 
292
307
  # Convert counts to indices
293
- self._partition_side_indices, _ = masked_indices(partition_side_mask.array, temporary_store=temporary_store)
294
- self._boundary_side_indices, _ = masked_indices(boundary_side_mask.array, temporary_store=temporary_store)
295
- self._frontier_side_indices, _ = masked_indices(frontier_side_mask.array, temporary_store=temporary_store)
308
+ self._partition_side_indices, _ = masked_indices(
309
+ partition_side_mask,
310
+ max_index_count=max_side_count,
311
+ local_to_global=self._partition_side_indices,
312
+ temporary_store=temporary_store,
313
+ )
314
+ self._boundary_side_indices, _ = masked_indices(
315
+ boundary_side_mask,
316
+ max_index_count=max_side_count,
317
+ local_to_global=self._boundary_side_indices,
318
+ temporary_store=temporary_store,
319
+ )
320
+ self._frontier_side_indices, _ = masked_indices(
321
+ frontier_side_mask,
322
+ max_index_count=max_side_count,
323
+ local_to_global=self._frontier_side_indices,
324
+ temporary_store=temporary_store,
325
+ )
296
326
 
297
327
  partition_side_mask.release()
298
328
  boundary_side_mask.release()
@@ -310,7 +340,7 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
310
340
  partition_rank: int,
311
341
  partition_count: int,
312
342
  device=None,
313
- temporary_store: TemporaryStore = None,
343
+ temporary_store: cache.TemporaryStore = None,
314
344
  ):
315
345
  """Creates a geometry partition by uniformly partionning cell indices
316
346
 
@@ -343,11 +373,6 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
343
373
  cell_begin: int
344
374
  cell_end: int
345
375
 
346
- def cell_arg_value(self, device):
347
- arg = LinearGeometryPartition.CellArg()
348
- self.fill_cell_arg(arg, device)
349
- return arg
350
-
351
376
  def fill_cell_arg(self, args: CellArg, device):
352
377
  args.cell_begin = self.cell_begin
353
378
  args.cell_end = self.cell_end
@@ -372,43 +397,76 @@ class LinearGeometryPartition(CellBasedGeometryPartition):
372
397
 
373
398
 
374
399
  class ExplicitGeometryPartition(CellBasedGeometryPartition):
375
- def __init__(self, geometry: Geometry, cell_mask: "wp.array(dtype=int)", temporary_store: TemporaryStore = None):
376
- """Creates a geometry partition by uniformly partionning cell indices
400
+ def __init__(
401
+ self,
402
+ geometry: Geometry,
403
+ cell_mask: "wp.array(dtype=int)",
404
+ max_cell_count: int = -1,
405
+ max_side_count: int = -1,
406
+ temporary_store: Optional[cache.TemporaryStore] = None,
407
+ ):
408
+ """Creates a geometry partition from an active cell mask
377
409
 
378
410
  Args:
379
411
  geometry: the geometry to partition
380
412
  cell_mask: warp array of length ``geometry.cell_count()`` indicating which cells are selected. Array values must be either ``1`` (selected) or ``0`` (not selected).
413
+ max_cell_count: if positive, will be used to limit the number of cells to avoid device/host synchronization
414
+ max_side_count: if positive, will be used to limit the number of sides to avoid device/host synchronization
381
415
  """
382
416
 
383
417
  super().__init__(geometry)
384
418
 
385
- self._cell_mask = cell_mask
386
- self._cells, self._partition_cells = masked_indices(self._cell_mask, temporary_store=temporary_store)
419
+ self._cells: wp.array = None
420
+ self._partition_cells: wp.array = None
421
+
422
+ self._max_cell_count = max_cell_count
423
+ self._max_side_count = max_side_count
424
+
425
+ self.rebuild(cell_mask, temporary_store)
426
+
427
+ def rebuild(
428
+ self,
429
+ cell_mask: "wp.array(dtype=int)",
430
+ temporary_store: Optional[cache.TemporaryStore] = None,
431
+ ):
432
+ """
433
+ Rebuilds the geometry partition from a new active cell mask
434
+
435
+ Args:
436
+ geometry: the geometry to partition
437
+ cell_mask: warp array of length ``geometry.cell_count()`` indicating which cells are selected. Array values must be either ``1`` (selected) or ``0`` (not selected).
438
+ max_cell_count: if positive, will be used to limit the number of cells to avoid device/host synchronization
439
+ max_side_count: if positive, will be used to limit the number of sides to avoid device/host synchronization
440
+ """
441
+ self.cell_arg_value.invalidate(self)
442
+
443
+ self._cells, self._partition_cells = masked_indices(
444
+ cell_mask,
445
+ local_to_global=self._cells,
446
+ global_to_local=self._partition_cells,
447
+ max_index_count=self._max_cell_count,
448
+ temporary_store=temporary_store,
449
+ )
387
450
 
388
451
  super().compute_side_indices_from_cells(
389
- self._cell_mask,
452
+ self.cell_arg_value(cell_mask.device),
390
453
  ExplicitGeometryPartition._cell_inclusion_test,
391
- self._cell_mask.device,
454
+ max_side_count=self._max_side_count,
455
+ device=cell_mask.device,
392
456
  temporary_store=temporary_store,
393
457
  )
394
458
 
395
459
  def cell_count(self) -> int:
396
- return self._cells.array.shape[0]
460
+ return self._cells.shape[0]
397
461
 
398
462
  @wp.struct
399
463
  class CellArg:
400
464
  cell_index: wp.array(dtype=int)
401
465
  partition_cell_index: wp.array(dtype=int)
402
466
 
403
- @cached_arg_value
404
- def cell_arg_value(self, device):
405
- arg = ExplicitGeometryPartition.CellArg()
406
- self.fill_cell_arg(arg, device)
407
- return arg
408
-
409
467
  def fill_cell_arg(self, args: CellArg, device):
410
- args.cell_index = self._cells.array.to(device)
411
- args.partition_cell_index = self._partition_cells.array.to(device)
468
+ args.cell_index = self._cells.to(device)
469
+ args.partition_cell_index = self._partition_cells.to(device)
412
470
 
413
471
  @wp.func
414
472
  def cell_index(args: CellArg, partition_cell_index: int):
@@ -419,5 +477,5 @@ class ExplicitGeometryPartition(CellBasedGeometryPartition):
419
477
  return args.partition_cell_index[cell_index]
420
478
 
421
479
  @wp.func
422
- def _cell_inclusion_test(mask: wp.array(dtype=int), cell_index: int):
423
- return mask[cell_index] > 0
480
+ def _cell_inclusion_test(arg: CellArg, cell_index: int):
481
+ return arg.partition_cell_index[cell_index] != NULL_ELEMENT_INDEX
@@ -16,16 +16,15 @@
16
16
  from typing import Any, Optional
17
17
 
18
18
  import warp as wp
19
- from warp.fem.cache import (
19
+ from warp._src.fem.cache import (
20
20
  TemporaryStore,
21
21
  borrow_temporary,
22
22
  borrow_temporary_like,
23
- cached_arg_value,
24
23
  )
25
- from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample
24
+ from warp._src.fem.types import OUTSIDE, Coords, ElementIndex, Sample
26
25
 
27
26
  from .closest_point import project_on_seg_at_origin
28
- from .element import LinearEdge, Square
27
+ from .element import Element
29
28
  from .geometry import Geometry
30
29
 
31
30
 
@@ -99,11 +98,11 @@ class Quadmesh(Geometry):
99
98
  def boundary_side_count(self):
100
99
  return self._boundary_edge_indices.shape[0]
101
100
 
102
- def reference_cell(self) -> Square:
103
- return Square()
101
+ def reference_cell(self) -> Element:
102
+ return Element.SQUARE
104
103
 
105
- def reference_side(self) -> LinearEdge:
106
- return LinearEdge()
104
+ def reference_side(self) -> Element:
105
+ return Element.LINE_SEGMENT
107
106
 
108
107
  @property
109
108
  def edge_quad_indices(self) -> wp.array:
@@ -126,30 +125,14 @@ class Quadmesh(Geometry):
126
125
  args.edge_vertex_indices = self._edge_vertex_indices.to(device)
127
126
  args.edge_quad_indices = self._edge_quad_indices.to(device)
128
127
 
129
- def cell_arg_value(self, device):
130
- args = self.CellArg()
131
- self.fill_cell_arg(args, device)
132
- return args
133
-
134
128
  def fill_cell_arg(self, args: "Quadmesh.CellArg", device):
135
129
  self.fill_cell_topo_arg(args.topology, device)
136
130
  args.positions = self.positions.to(device)
137
131
 
138
- def side_arg_value(self, device):
139
- args = self.SideArg()
140
- self.fill_side_arg(args, device)
141
- return args
142
-
143
132
  def fill_side_arg(self, args: "Quadmesh.SideArg", device):
144
133
  self.fill_side_topo_arg(args.topology, device)
145
134
  args.positions = self.positions.to(device)
146
135
 
147
- @cached_arg_value
148
- def side_index_arg_value(self, device) -> SideIndexArg:
149
- args = self.SideIndexArg()
150
- self.fill_side_index_arg(args, device)
151
- return args
152
-
153
136
  def fill_side_index_arg(self, args: "SideIndexArg", device):
154
137
  args.boundary_edge_indices = self._boundary_edge_indices.to(device)
155
138
 
@@ -211,8 +194,8 @@ class Quadmesh(Geometry):
211
194
  return args.boundary_edge_indices[boundary_side_index]
212
195
 
213
196
  def _build_topology(self, temporary_store: TemporaryStore):
214
- from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
215
- from warp.utils import array_scan
197
+ from warp._src.fem.utils import compress_node_indices, host_read_at_index, masked_indices
198
+ from warp._src.utils import array_scan
216
199
 
217
200
  device = self.quad_vertex_indices.device
218
201
 
@@ -223,7 +206,7 @@ class Quadmesh(Geometry):
223
206
  self._vertex_quad_indices = vertex_quad_indices.detach()
224
207
 
225
208
  vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
226
- vertex_start_edge_count.array.zero_()
209
+ vertex_start_edge_count.zero_()
227
210
  vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)
228
211
 
229
212
  vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(4 * self.cell_count()))
@@ -236,10 +219,10 @@ class Quadmesh(Geometry):
236
219
  kernel=Quadmesh2D._count_starting_edges_kernel,
237
220
  device=device,
238
221
  dim=self.cell_count(),
239
- inputs=[self.quad_vertex_indices, vertex_start_edge_count.array],
222
+ inputs=[self.quad_vertex_indices, vertex_start_edge_count],
240
223
  )
241
224
 
242
- array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)
225
+ array_scan(in_array=vertex_start_edge_count, out_array=vertex_start_edge_offsets, inclusive=False)
243
226
 
244
227
  # Count number of unique edges (deduplicate across faces)
245
228
  vertex_unique_edge_count = vertex_start_edge_count
@@ -251,21 +234,19 @@ class Quadmesh(Geometry):
251
234
  self._vertex_quad_offsets,
252
235
  self._vertex_quad_indices,
253
236
  self.quad_vertex_indices,
254
- vertex_start_edge_offsets.array,
255
- vertex_unique_edge_count.array,
256
- vertex_edge_ends.array,
257
- vertex_edge_quads.array,
237
+ vertex_start_edge_offsets,
238
+ vertex_unique_edge_count,
239
+ vertex_edge_ends,
240
+ vertex_edge_quads,
258
241
  ],
259
242
  )
260
243
 
261
244
  vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
262
- array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)
245
+ array_scan(in_array=vertex_start_edge_count, out_array=vertex_unique_edge_offsets, inclusive=False)
263
246
 
264
247
  # Get back edge count to host
265
248
  edge_count = int(
266
- host_read_at_index(
267
- vertex_unique_edge_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
268
- )
249
+ host_read_at_index(vertex_unique_edge_offsets, self.vertex_count() - 1, temporary_store=temporary_store)
269
250
  )
270
251
 
271
252
  self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)
@@ -279,14 +260,14 @@ class Quadmesh(Geometry):
279
260
  device=device,
280
261
  dim=self.vertex_count(),
281
262
  inputs=[
282
- vertex_start_edge_offsets.array,
283
- vertex_unique_edge_offsets.array,
284
- vertex_unique_edge_count.array,
285
- vertex_edge_ends.array,
286
- vertex_edge_quads.array,
263
+ vertex_start_edge_offsets,
264
+ vertex_unique_edge_offsets,
265
+ vertex_unique_edge_count,
266
+ vertex_edge_ends,
267
+ vertex_edge_quads,
287
268
  self._edge_vertex_indices,
288
269
  self._edge_quad_indices,
289
- boundary_mask.array,
270
+ boundary_mask,
290
271
  ],
291
272
  )
292
273
 
@@ -296,7 +277,7 @@ class Quadmesh(Geometry):
296
277
  vertex_edge_ends.release()
297
278
  vertex_edge_quads.release()
298
279
 
299
- boundary_edge_indices, _ = masked_indices(boundary_mask.array, temporary_store=temporary_store)
280
+ boundary_edge_indices, _ = masked_indices(boundary_mask, temporary_store=temporary_store)
300
281
  self._boundary_edge_indices = boundary_edge_indices.detach()
301
282
 
302
283
  boundary_mask.release()
@@ -457,7 +438,7 @@ class Quadmesh(Geometry):
457
438
  q = pos - p0
458
439
  e = args.positions[edge_idx[1]] - p0
459
440
 
460
- dist, t = project_on_seg_at_origin(q, e, wp.lengh_sq(e))
441
+ dist, t = project_on_seg_at_origin(q, e, wp.length_sq(e))
461
442
  return Coords(t, 0.0, 0.0), dist
462
443
 
463
444
  @wp.func