warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -15,10 +15,11 @@
15
15
 
16
16
  from typing import Any
17
17
 
18
- import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.geometry import Geometry
21
- from warp.fem.types import Coords, ElementIndex, ElementKind, Sample, make_free_sample
18
+ from warp._src.codegen import Struct, StructInstance
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import Geometry
21
+ from warp._src.fem.types import Coords, ElementIndex, ElementKind, Sample, make_free_sample
22
+ from warp._src.types import type_is_matrix, type_is_vector, type_size
22
23
 
23
24
  from .topology import SpaceTopology
24
25
 
@@ -41,7 +42,7 @@ class FunctionSpace:
41
42
  dof_dtype: type
42
43
  """Data type of the degrees of freedom of each node"""
43
44
 
44
- SpaceArg: wp.codegen.Struct
45
+ SpaceArg: Struct
45
46
  """Structure containing arguments to be passed to device function"""
46
47
 
47
48
  LocalValueMap: type
@@ -71,7 +72,7 @@ class FunctionSpace:
71
72
  """Number of nodes in the interpolation basis"""
72
73
  return self.topology.node_count()
73
74
 
74
- def space_arg_value(self, device) -> wp.codegen.StructInstance:
75
+ def space_arg_value(self, device) -> StructInstance:
75
76
  """Value of the arguments to be passed to device functions"""
76
77
  raise NotImplementedError
77
78
 
@@ -123,13 +124,13 @@ class FunctionSpace:
123
124
 
124
125
  def gradient_valid(self) -> bool:
125
126
  """Whether gradient operator can be computed. Only for scalar and vector fields as higher-order tensors are not supported yet"""
126
- return not wp.types.type_is_matrix(self.dtype)
127
+ return not type_is_matrix(self.dtype)
127
128
 
128
129
  def divergence_valid(self) -> bool:
129
130
  """Whether divergence of this field can be computed. Only for vector and tensor fields with same dimension as embedding geometry"""
130
- if wp.types.type_is_vector(self.dtype):
131
- return wp.types.type_size(self.dtype) == self.geometry.dimension
132
- if wp.types.type_is_matrix(self.dtype):
131
+ if type_is_vector(self.dtype):
132
+ return type_size(self.dtype) == self.geometry.dimension
133
+ if type_is_matrix(self.dtype):
133
134
  return self.dtype._shape_[0] == self.geometry.dimension
134
135
  return False
135
136
 
@@ -243,7 +244,7 @@ class FunctionSpace:
243
244
  - node_weight: weight associated to the node, as given per `element_(inn|out)er_weight`
244
245
  - local_value_map: local transformation from node space to world space, as given per `local_map_value_(inn|out)er`
245
246
  """
246
- raise NotADirectoryError
247
+ raise NotImplementedError
247
248
 
248
249
  def space_gradient(
249
250
  dof_value: "FunctionSpace.dof_dtype",
@@ -257,7 +258,7 @@ class FunctionSpace:
257
258
  - dof_value: node value in the degrees-of-freedom basis
258
259
  - node_weight_gradient: gradient of the weight associated to the node, as given per `element_(inn|out)er_weight_gradient`
259
260
  - local_value_map: local transformation from node space to world space, as given per `local_map_value_(inn|out)er`
260
- - grad_transform: transform mapping the reference space gradient to worls-space gradient (inverse deformation gradient)
261
+ - grad_transform: transform mapping the reference-space gradient to world-space gradient (inverse deformation gradient)
261
262
  """
262
263
  raise NotImplementedError
263
264
 
@@ -273,7 +274,7 @@ class FunctionSpace:
273
274
  - dof_value: node value in the degrees-of-freedom basis
274
275
  - node_weight_gradient: gradient of the weight associated to the node, as given per `element_(inn|out)er_weight_gradient`
275
276
  - local_value_map: local transformation from node space to world space, as given per `local_map_value_(inn|out)er`
276
- - grad_transform: transform mapping the reference space gradient to worls-space gradient (inverse deformation gradient)
277
+ - grad_transform: transform mapping the reference-space gradient to world-space gradient (inverse deformation gradient)
277
278
  """
278
279
  raise NotImplementedError
279
280
 
@@ -16,10 +16,10 @@
16
16
  import numpy as np
17
17
 
18
18
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.geometry import Grid2D
21
- from warp.fem.polynomial import is_closed
22
- from warp.fem.types import NULL_NODE_INDEX, ElementIndex
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import Grid2D
21
+ from warp._src.fem.polynomial import is_closed
22
+ from warp._src.fem.types import NULL_NODE_INDEX, ElementIndex
23
23
 
24
24
  from .shape import SquareBipolynomialShapeFunctions, SquareShapeFunction
25
25
  from .topology import SpaceTopology, forward_base_topology
@@ -38,9 +38,6 @@ class Grid2DSpaceTopology(SpaceTopology):
38
38
  def name(self):
39
39
  return f"{self.geometry.name}_{self._shape.name}"
40
40
 
41
- def topo_arg_value(self, device):
42
- return self.geometry.side_arg_value(device)
43
-
44
41
  def fill_topo_arg(self, arg: Grid2D.SideArg, device):
45
42
  self.geometry.fill_side_arg(arg, device)
46
43
 
@@ -16,10 +16,10 @@
16
16
  import numpy as np
17
17
 
18
18
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.geometry import Grid3D
21
- from warp.fem.polynomial import is_closed
22
- from warp.fem.types import ElementIndex
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import Grid3D
21
+ from warp._src.fem.polynomial import is_closed
22
+ from warp._src.fem.types import ElementIndex
23
23
 
24
24
  from .shape import (
25
25
  CubeShapeFunction,
@@ -14,14 +14,14 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import warp as wp
17
- from warp.fem import cache
18
- from warp.fem.geometry import Hexmesh
19
- from warp.fem.geometry.hexmesh import (
17
+ from warp._src.fem import cache
18
+ from warp._src.fem.geometry import Hexmesh
19
+ from warp._src.fem.geometry.hexmesh import (
20
20
  EDGE_VERTEX_INDICES,
21
21
  FACE_ORIENTATION,
22
22
  FACE_TRANSLATION,
23
23
  )
24
- from warp.fem.types import ElementIndex
24
+ from warp._src.fem.types import ElementIndex
25
25
 
26
26
  from .shape import CubeShapeFunction
27
27
  from .topology import SpaceTopology, forward_base_topology
@@ -82,12 +82,6 @@ class HexmeshSpaceTopology(SpaceTopology):
82
82
  def name(self):
83
83
  return f"{self.geometry.name}_{self._shape.name}"
84
84
 
85
- @cache.cached_arg_value
86
- def topo_arg_value(self, device):
87
- arg = HexmeshTopologyArg()
88
- self.fill_topo_arg(arg, device)
89
- return arg
90
-
91
85
  def fill_topo_arg(self, arg: HexmeshTopologyArg, device):
92
86
  arg.hex_edge_indices = self._hex_edge_indices.to(device)
93
87
  arg.hex_face_indices = self._hex_face_indices.to(device)
@@ -16,9 +16,9 @@
16
16
  from typing import Union
17
17
 
18
18
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.geometry import AdaptiveNanogrid, Nanogrid
21
- from warp.fem.types import ElementIndex
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import AdaptiveNanogrid, Nanogrid
21
+ from warp._src.fem.types import ElementIndex
22
22
 
23
23
  from .shape import CubeShapeFunction
24
24
  from .topology import SpaceTopology, forward_base_topology
@@ -69,12 +69,6 @@ class NanogridSpaceTopology(SpaceTopology):
69
69
  def name(self):
70
70
  return f"{self.geometry.name}_{self._shape.name}"
71
71
 
72
- @cache.cached_arg_value
73
- def topo_arg_value(self, device):
74
- arg = NanogridTopologyArg()
75
- self.fill_topo_arg(arg, device)
76
- return arg
77
-
78
72
  def fill_topo_arg(self, arg, device):
79
73
  arg.vertex_grid = self._vertex_grid
80
74
  arg.face_grid = self._face_grid
@@ -16,10 +16,10 @@
16
16
  from typing import Any, Optional
17
17
 
18
18
  import warp as wp
19
- import warp.fem.cache as cache
20
- from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
21
- from warp.fem.types import NULL_NODE_INDEX
22
- from warp.fem.utils import compress_node_indices
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import GeometryPartition, WholeGeometryPartition
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, NULL_NODE_INDEX
22
+ from warp._src.fem.utils import compress_node_indices
23
23
 
24
24
  from .function_space import FunctionSpace
25
25
  from .topology import SpaceTopology
@@ -47,9 +47,16 @@ class SpacePartition:
47
47
  def space_node_indices(self) -> wp.array:
48
48
  """Return the global function space indices for nodes in this partition"""
49
49
 
50
- def partition_arg_value(self, device):
50
+ def rebuild(self, device: Optional = None, temporary_store: Optional[cache.TemporaryStore] = None):
51
+ """Rebuild the space partition indices"""
51
52
  pass
52
53
 
54
+ @cache.cached_arg_value
55
+ def partition_arg_value(self, device):
56
+ arg = self.PartitionArg()
57
+ self.fill_partition_arg(arg, device)
58
+ return arg
59
+
53
60
  def fill_partition_arg(self, arg, device):
54
61
  pass
55
62
 
@@ -90,8 +97,8 @@ class WholeSpacePartition(SpacePartition):
90
97
  """Return the global function space indices for nodes in this partition"""
91
98
  if self._node_indices is None:
92
99
  self._node_indices = cache.borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
93
- wp.launch(kernel=self._iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array])
94
- return self._node_indices.array
100
+ wp.launch(kernel=self._iota_kernel, dim=self.node_count(), inputs=[self._node_indices])
101
+ return self._node_indices
95
102
 
96
103
  def partition_arg_value(self, device):
97
104
  return WholeSpacePartition.PartitionArg()
@@ -140,45 +147,56 @@ class NodePartition(SpacePartition):
140
147
  space_topology: SpaceTopology,
141
148
  geo_partition: GeometryPartition,
142
149
  with_halo: bool = True,
150
+ max_node_count: int = -1,
143
151
  device=None,
144
- temporary_store: cache.TemporaryStore = None,
152
+ temporary_store: Optional[cache.TemporaryStore] = None,
145
153
  ):
146
154
  super().__init__(space_topology=space_topology, geo_partition=geo_partition)
147
155
 
148
- self._compute_node_indices_from_sides(device, with_halo, temporary_store)
156
+ if max_node_count >= 0:
157
+ max_node_count = min(max_node_count, space_topology.node_count())
158
+
159
+ self._max_node_count = max_node_count
160
+ self._with_halo = with_halo
161
+
162
+ self._category_offsets: wp.array = None
163
+ """Offsets for each node category"""
164
+ self._node_indices: wp.array = None
165
+ """Mapping from local partition node indices to global space node indices"""
166
+ self._space_to_partition: wp.array = None
167
+ """Mapping from global space node indices to local partition node indices"""
168
+
169
+ self.rebuild(device, temporary_store)
170
+
171
+ def rebuild(self, device: Optional = None, temporary_store: Optional[cache.TemporaryStore] = None):
172
+ self._compute_node_indices_from_sides(device, self._with_halo, self._max_node_count, temporary_store)
149
173
 
150
174
  def node_count(self) -> int:
151
175
  """Returns number of nodes referenced by this partition, including exterior halo"""
152
- return int(self._category_offsets.array.numpy()[NodeCategory.HALO_OTHER_SIDE + 1])
176
+ return int(self._category_offsets.numpy()[NodeCategory.HALO_OTHER_SIDE + 1])
153
177
 
154
178
  def owned_node_count(self) -> int:
155
179
  """Returns number of nodes in this partition, excluding exterior halo"""
156
- return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_FRONTIER + 1])
180
+ return int(self._category_offsets.numpy()[NodeCategory.OWNED_FRONTIER + 1])
157
181
 
158
182
  def interior_node_count(self) -> int:
159
183
  """Returns number of interior nodes in this partition"""
160
- return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_INTERIOR + 1])
184
+ return int(self._category_offsets.numpy()[NodeCategory.OWNED_INTERIOR + 1])
161
185
 
162
186
  def space_node_indices(self):
163
187
  """Return the global function space indices for nodes in this partition"""
164
- return self._node_indices.array
165
-
166
- @cache.cached_arg_value
167
- def partition_arg_value(self, device):
168
- arg = NodePartition.PartitionArg()
169
- self.fill_partition_arg(arg, device)
170
- return arg
188
+ return self._node_indices
171
189
 
172
190
  def fill_partition_arg(self, arg, device):
173
- arg.space_to_partition = self._space_to_partition.array.to(device)
191
+ arg.space_to_partition = self._space_to_partition.to(device)
174
192
 
175
193
  @wp.func
176
194
  def partition_node_index(args: PartitionArg, space_node_index: int):
177
195
  return args.space_to_partition[space_node_index]
178
196
 
179
- def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: cache.TemporaryStore):
180
- from warp.fem import cache
181
-
197
+ def _compute_node_indices_from_sides(
198
+ self, device, with_halo: bool, max_node_count: int, temporary_store: cache.TemporaryStore
199
+ ):
182
200
  trace_topology = self.space_topology.trace()
183
201
 
184
202
  @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
@@ -191,6 +209,8 @@ class NodePartition(SpacePartition):
191
209
  partition_cell_index = wp.tid()
192
210
 
193
211
  cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)
212
+ if cell_index == NULL_ELEMENT_INDEX:
213
+ return
194
214
 
195
215
  cell_node_count = self.space_topology.element_node_count(geo_arg, space_arg, cell_index)
196
216
  for n in range(cell_node_count):
@@ -207,6 +227,8 @@ class NodePartition(SpacePartition):
207
227
  partition_side_index = wp.tid()
208
228
 
209
229
  side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)
230
+ if side_index == NULL_ELEMENT_INDEX:
231
+ return
210
232
 
211
233
  side_node_count = trace_topology.element_node_count(geo_arg, space_arg, side_index)
212
234
  for n in range(side_node_count):
@@ -225,6 +247,8 @@ class NodePartition(SpacePartition):
225
247
  frontier_side_index = wp.tid()
226
248
 
227
249
  side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)
250
+ if side_index == NULL_ELEMENT_INDEX:
251
+ return
228
252
 
229
253
  side_node_count = trace_topology.element_node_count(geo_arg, space_arg, side_index)
230
254
  for n in range(side_node_count):
@@ -240,7 +264,7 @@ class NodePartition(SpacePartition):
240
264
  dtype=int,
241
265
  device=device,
242
266
  )
243
- node_category.array.fill_(value=NodeCategory.EXTERIOR)
267
+ node_category.fill_(value=NodeCategory.EXTERIOR)
244
268
 
245
269
  wp.launch(
246
270
  dim=self.geo_partition.cell_count(),
@@ -249,7 +273,7 @@ class NodePartition(SpacePartition):
249
273
  self.geo_partition.geometry.cell_arg_value(device),
250
274
  self.geo_partition.cell_arg_value(device),
251
275
  self.space_topology.topo_arg_value(device),
252
- node_category.array,
276
+ node_category,
253
277
  ],
254
278
  device=device,
255
279
  )
@@ -262,7 +286,7 @@ class NodePartition(SpacePartition):
262
286
  self.geo_partition.geometry.side_arg_value(device),
263
287
  self.geo_partition.side_arg_value(device),
264
288
  self.space_topology.topo_arg_value(device),
265
- node_category.array,
289
+ node_category,
266
290
  ],
267
291
  device=device,
268
292
  )
@@ -274,53 +298,73 @@ class NodePartition(SpacePartition):
274
298
  self.geo_partition.geometry.side_arg_value(device),
275
299
  self.geo_partition.side_arg_value(device),
276
300
  self.space_topology.topo_arg_value(device),
277
- node_category.array,
301
+ node_category,
278
302
  ],
279
303
  device=device,
280
304
  )
281
305
 
282
- self._finalize_node_indices(node_category.array, temporary_store)
306
+ with wp.ScopedDevice(device):
307
+ self._finalize_node_indices(node_category, max_node_count, temporary_store)
283
308
 
284
309
  node_category.release()
285
310
 
286
- def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: cache.TemporaryStore):
311
+ def _finalize_node_indices(
312
+ self, node_category: wp.array(dtype=int), max_node_count: int, temporary_store: cache.TemporaryStore
313
+ ):
287
314
  category_offsets, node_indices = compress_node_indices(
288
315
  NodeCategory.COUNT, node_category, temporary_store=temporary_store
289
316
  )
290
-
291
- # Copy offsets to cpu
292
317
  device = node_category.device
293
- with wp.ScopedDevice(device):
294
- self._category_offsets = cache.borrow_temporary(
295
- temporary_store,
296
- shape=category_offsets.array.shape,
297
- dtype=category_offsets.array.dtype,
298
- pinned=device.is_cuda,
299
- device="cpu",
300
- )
301
- wp.copy(src=category_offsets.array, dest=self._category_offsets.array)
318
+
319
+ if max_node_count >= 0:
320
+ if self._category_offsets is None:
321
+ self._category_offsets = cache.borrow_temporary(
322
+ temporary_store,
323
+ shape=(NodeCategory.COUNT + 1,),
324
+ dtype=category_offsets.dtype,
325
+ device="cpu",
326
+ )
327
+ self._category_offsets.fill_(max_node_count)
328
+ copy_event = None
329
+ else:
330
+ # Copy offsets to cpu
331
+ if self._category_offsets is None:
332
+ self._category_offsets = cache.borrow_temporary(
333
+ temporary_store,
334
+ shape=(NodeCategory.COUNT + 1,),
335
+ dtype=category_offsets.dtype,
336
+ pinned=device.is_cuda,
337
+ device="cpu",
338
+ )
339
+ wp.copy(src=category_offsets, dest=self._category_offsets, count=NodeCategory.COUNT + 1)
302
340
  copy_event = cache.capture_event()
303
341
 
304
- # Compute global to local indices
342
+ # Compute global to local indices
343
+ if self._space_to_partition is None or self._space_to_partition.shape != node_indices.shape:
305
344
  self._space_to_partition = cache.borrow_temporary_like(node_indices, temporary_store)
306
- wp.launch(
307
- kernel=NodePartition._scatter_partition_indices,
308
- dim=self.space_topology.node_count(),
309
- device=device,
310
- inputs=[category_offsets.array, node_indices.array, self._space_to_partition.array],
311
- )
312
345
 
313
- # Copy to shrinked-to-fit array
346
+ wp.launch(
347
+ kernel=NodePartition._scatter_partition_indices,
348
+ dim=self.space_topology.node_count(),
349
+ device=device,
350
+ inputs=[max_node_count, category_offsets, node_indices, self._space_to_partition],
351
+ )
352
+
353
+ if copy_event is not None:
314
354
  cache.synchronize_event(copy_event) # Transfer to host must be finished to access node_count()
355
+
356
+ # Copy to shrunk-to-fit array
357
+ if self._node_indices is None or self._node_indices.shape[0] != self.node_count():
315
358
  self._node_indices = cache.borrow_temporary(
316
- temporary_store, shape=(self.node_count()), dtype=int, device=device
359
+ temporary_store, shape=(self.node_count(),), dtype=int, device=device
317
360
  )
318
- wp.copy(dest=self._node_indices.array, src=node_indices.array, count=self.node_count())
319
361
 
320
- node_indices.release()
362
+ wp.copy(dest=self._node_indices, src=node_indices, count=self.node_count())
363
+ node_indices.release()
321
364
 
322
365
  @wp.kernel
323
366
  def _scatter_partition_indices(
367
+ max_node_count: int,
324
368
  category_offsets: wp.array(dtype=int),
325
369
  node_indices: wp.array(dtype=int),
326
370
  space_to_partition_indices: wp.array(dtype=int),
@@ -329,6 +373,17 @@ class NodePartition(SpacePartition):
329
373
  space_idx = node_indices[local_idx]
330
374
 
331
375
  local_node_count = category_offsets[NodeCategory.EXTERIOR] # all but exterior nodes
376
+ if max_node_count >= 0:
377
+ if local_node_count > max_node_count:
378
+ if local_idx == 0:
379
+ wp.printf(
380
+ "Number of space partition nodes exceeded the %d limit; increase `max_node_count` to %d.\n",
381
+ max_node_count,
382
+ local_node_count,
383
+ )
384
+
385
+ local_node_count = max_node_count
386
+
332
387
  if local_idx < local_node_count:
333
388
  space_to_partition_indices[space_idx] = local_idx
334
389
  else:
@@ -340,6 +395,7 @@ def make_space_partition(
340
395
  geometry_partition: Optional[GeometryPartition] = None,
341
396
  space_topology: Optional[SpaceTopology] = None,
342
397
  with_halo: bool = True,
398
+ max_node_count: int = -1,
343
399
  device=None,
344
400
  temporary_store: cache.TemporaryStore = None,
345
401
  ) -> SpacePartition:
@@ -352,6 +408,7 @@ def make_space_partition(
352
408
  geometry_partition: The subset of the space geometry. If not provided, use the whole geometry.
353
409
  space_topology: the topology of the function space to consider. If ``None``, deduced from `space`.
354
410
  with_halo: if True, include the halo nodes (nodes from exterior frontier cells to the partition)
411
+ max_node_count: if positive, will be used to limit the number of nodes to avoid device/host synchronization.
355
412
  device: Warp device on which to perform and store computations
356
413
 
357
414
  Returns:
@@ -363,14 +420,14 @@ def make_space_partition(
363
420
 
364
421
  space_topology = space_topology.full_space_topology()
365
422
 
366
- if geometry_partition is not None:
367
- if geometry_partition.cell_count() < geometry_partition.geometry.cell_count():
368
- return NodePartition(
369
- space_topology=space_topology,
370
- geo_partition=geometry_partition,
371
- with_halo=with_halo,
372
- device=device,
373
- temporary_store=temporary_store,
374
- )
423
+ if geometry_partition is not None and not isinstance(geometry_partition, WholeGeometryPartition):
424
+ return NodePartition(
425
+ space_topology=space_topology,
426
+ geo_partition=geometry_partition,
427
+ with_halo=with_halo,
428
+ max_node_count=max_node_count,
429
+ device=device,
430
+ temporary_store=temporary_store,
431
+ )
375
432
 
376
433
  return WholeSpacePartition(space_topology)
@@ -14,10 +14,10 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import warp as wp
17
- from warp.fem import cache
18
- from warp.fem.geometry import Quadmesh2D
19
- from warp.fem.polynomial import is_closed
20
- from warp.fem.types import NULL_NODE_INDEX, ElementIndex
17
+ from warp._src.fem import cache
18
+ from warp._src.fem.geometry import Quadmesh2D
19
+ from warp._src.fem.polynomial import is_closed
20
+ from warp._src.fem.types import NULL_NODE_INDEX, ElementIndex
21
21
 
22
22
  from .shape import SquareShapeFunction
23
23
  from .topology import SpaceTopology, forward_base_topology
@@ -52,12 +52,6 @@ class QuadmeshSpaceTopology(SpaceTopology):
52
52
  def name(self):
53
53
  return f"{self.geometry.name}_{self._shape.name}"
54
54
 
55
- @cache.cached_arg_value
56
- def topo_arg_value(self, device):
57
- arg = Quadmesh2DTopologyArg()
58
- self.fill_topo_arg(arg, device)
59
- return arg
60
-
61
55
  def fill_topo_arg(self, arg: Quadmesh2DTopologyArg, device):
62
56
  arg.quad_edge_indices = self._quad_edge_indices.to(device)
63
57
  arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)