warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +882 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1435 -379
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +3 -3
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +521 -250
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +18 -17
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +578 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
@@ -15,13 +15,16 @@
15
15
 
16
16
  from typing import Any
17
17
 
18
- import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.geometry import Geometry
21
- from warp.fem.types import Coords, ElementIndex, ElementKind, Sample, make_free_sample
18
+ from warp._src.codegen import Struct, StructInstance
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import Geometry
21
+ from warp._src.fem.types import Coords, ElementIndex, ElementKind, Sample, make_free_sample
22
+ from warp._src.types import type_is_matrix, type_is_vector, type_size
22
23
 
23
24
  from .topology import SpaceTopology
24
25
 
26
+ _wp_module_name_ = "warp.fem.space.function_space"
27
+
25
28
 
26
29
  class FunctionSpace:
27
30
  """
@@ -41,7 +44,7 @@ class FunctionSpace:
41
44
  dof_dtype: type
42
45
  """Data type of the degrees of freedom of each node"""
43
46
 
44
- SpaceArg: wp.codegen.Struct
47
+ SpaceArg: Struct
45
48
  """Structure containing arguments to be passed to device function"""
46
49
 
47
50
  LocalValueMap: type
@@ -71,7 +74,7 @@ class FunctionSpace:
71
74
  """Number of nodes in the interpolation basis"""
72
75
  return self.topology.node_count()
73
76
 
74
- def space_arg_value(self, device) -> wp.codegen.StructInstance:
77
+ def space_arg_value(self, device) -> StructInstance:
75
78
  """Value of the arguments to be passed to device functions"""
76
79
  raise NotImplementedError
77
80
 
@@ -123,13 +126,13 @@ class FunctionSpace:
123
126
 
124
127
  def gradient_valid(self) -> bool:
125
128
  """Whether gradient operator can be computed. Only for scalar and vector fields as higher-order tensors are not supported yet"""
126
- return not wp.types.type_is_matrix(self.dtype)
129
+ return not type_is_matrix(self.dtype)
127
130
 
128
131
  def divergence_valid(self) -> bool:
129
132
  """Whether divergence of this field can be computed. Only for vector and tensor fields with same dimension as embedding geometry"""
130
- if wp.types.type_is_vector(self.dtype):
131
- return wp.types.type_size(self.dtype) == self.geometry.dimension
132
- if wp.types.type_is_matrix(self.dtype):
133
+ if type_is_vector(self.dtype):
134
+ return type_size(self.dtype) == self.geometry.dimension
135
+ if type_is_matrix(self.dtype):
133
136
  return self.dtype._shape_[0] == self.geometry.dimension
134
137
  return False
135
138
 
@@ -243,7 +246,7 @@ class FunctionSpace:
243
246
  - node_weight: weight associated to the node, as given per `element_(inn|out)er_weight`
244
247
  - local_value_map: local transformation from node space to world space, as given per `local_map_value_(inn|out)er`
245
248
  """
246
- raise NotADirectoryError
249
+ raise NotImplementedError
247
250
 
248
251
  def space_gradient(
249
252
  dof_value: "FunctionSpace.dof_dtype",
@@ -257,7 +260,7 @@ class FunctionSpace:
257
260
  - dof_value: node value in the degrees-of-freedom basis
258
261
  - node_weight_gradient: gradient of the weight associated to the node, as given per `element_(inn|out)er_weight_gradient`
259
262
  - local_value_map: local transformation from node space to world space, as given per `local_map_value_(inn|out)er`
260
- - grad_transform: transform mapping the reference space gradient to worls-space gradient (inverse deformation gradient)
263
+ - grad_transform: transform mapping the reference-space gradient to world-space gradient (inverse deformation gradient)
261
264
  """
262
265
  raise NotImplementedError
263
266
 
@@ -273,7 +276,7 @@ class FunctionSpace:
273
276
  - dof_value: node value in the degrees-of-freedom basis
274
277
  - node_weight_gradient: gradient of the weight associated to the node, as given per `element_(inn|out)er_weight_gradient`
275
278
  - local_value_map: local transformation from node space to world space, as given per `local_map_value_(inn|out)er`
276
- - grad_transform: transform mapping the reference space gradient to worls-space gradient (inverse deformation gradient)
279
+ - grad_transform: transform mapping the reference-space gradient to world-space gradient (inverse deformation gradient)
277
280
  """
278
281
  raise NotImplementedError
279
282
 
@@ -16,14 +16,16 @@
16
16
  import numpy as np
17
17
 
18
18
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.geometry import Grid2D
21
- from warp.fem.polynomial import is_closed
22
- from warp.fem.types import NULL_NODE_INDEX, ElementIndex
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import Grid2D
21
+ from warp._src.fem.polynomial import is_closed
22
+ from warp._src.fem.types import NULL_NODE_INDEX, ElementIndex
23
23
 
24
24
  from .shape import SquareBipolynomialShapeFunctions, SquareShapeFunction
25
25
  from .topology import SpaceTopology, forward_base_topology
26
26
 
27
+ _wp_module_name_ = "warp.fem.space.grid_2d_function_space"
28
+
27
29
 
28
30
  class Grid2DSpaceTopology(SpaceTopology):
29
31
  def __init__(self, grid: Grid2D, shape: SquareShapeFunction):
@@ -38,9 +40,6 @@ class Grid2DSpaceTopology(SpaceTopology):
38
40
  def name(self):
39
41
  return f"{self.geometry.name}_{self._shape.name}"
40
42
 
41
- def topo_arg_value(self, device):
42
- return self.geometry.side_arg_value(device)
43
-
44
43
  def fill_topo_arg(self, arg: Grid2D.SideArg, device):
45
44
  self.geometry.fill_side_arg(arg, device)
46
45
 
@@ -16,10 +16,10 @@
16
16
  import numpy as np
17
17
 
18
18
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.geometry import Grid3D
21
- from warp.fem.polynomial import is_closed
22
- from warp.fem.types import ElementIndex
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import Grid3D
21
+ from warp._src.fem.polynomial import is_closed
22
+ from warp._src.fem.types import ElementIndex
23
23
 
24
24
  from .shape import (
25
25
  CubeShapeFunction,
@@ -27,6 +27,8 @@ from .shape import (
27
27
  )
28
28
  from .topology import SpaceTopology, forward_base_topology
29
29
 
30
+ _wp_module_name_ = "warp.fem.space.grid_3d_function_space"
31
+
30
32
 
31
33
  class Grid3DSpaceTopology(SpaceTopology):
32
34
  def __init__(self, grid: Grid3D, shape: CubeShapeFunction):
@@ -14,18 +14,20 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import warp as wp
17
- from warp.fem import cache
18
- from warp.fem.geometry import Hexmesh
19
- from warp.fem.geometry.hexmesh import (
17
+ from warp._src.fem import cache
18
+ from warp._src.fem.geometry import Hexmesh
19
+ from warp._src.fem.geometry.hexmesh import (
20
20
  EDGE_VERTEX_INDICES,
21
21
  FACE_ORIENTATION,
22
22
  FACE_TRANSLATION,
23
23
  )
24
- from warp.fem.types import ElementIndex
24
+ from warp._src.fem.types import ElementIndex
25
25
 
26
26
  from .shape import CubeShapeFunction
27
27
  from .topology import SpaceTopology, forward_base_topology
28
28
 
29
+ _wp_module_name_ = "warp.fem.space.hexmesh_function_space"
30
+
29
31
  _FACE_ORIENTATION_I = wp.constant(wp.mat(shape=(16, 2), dtype=int)(FACE_ORIENTATION))
30
32
  _FACE_TRANSLATION_I = wp.constant(wp.mat(shape=(4, 2), dtype=int)(FACE_TRANSLATION))
31
33
 
@@ -82,12 +84,6 @@ class HexmeshSpaceTopology(SpaceTopology):
82
84
  def name(self):
83
85
  return f"{self.geometry.name}_{self._shape.name}"
84
86
 
85
- @cache.cached_arg_value
86
- def topo_arg_value(self, device):
87
- arg = HexmeshTopologyArg()
88
- self.fill_topo_arg(arg, device)
89
- return arg
90
-
91
87
  def fill_topo_arg(self, arg: HexmeshTopologyArg, device):
92
88
  arg.hex_edge_indices = self._hex_edge_indices.to(device)
93
89
  arg.hex_face_indices = self._hex_face_indices.to(device)
@@ -16,13 +16,15 @@
16
16
  from typing import Union
17
17
 
18
18
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.geometry import AdaptiveNanogrid, Nanogrid
21
- from warp.fem.types import ElementIndex
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import AdaptiveNanogrid, Nanogrid
21
+ from warp._src.fem.types import ElementIndex
22
22
 
23
23
  from .shape import CubeShapeFunction
24
24
  from .topology import SpaceTopology, forward_base_topology
25
25
 
26
+ _wp_module_name_ = "warp.fem.space.nanogrid_function_space"
27
+
26
28
 
27
29
  @wp.struct
28
30
  class NanogridTopologyArg:
@@ -69,12 +71,6 @@ class NanogridSpaceTopology(SpaceTopology):
69
71
  def name(self):
70
72
  return f"{self.geometry.name}_{self._shape.name}"
71
73
 
72
- @cache.cached_arg_value
73
- def topo_arg_value(self, device):
74
- arg = NanogridTopologyArg()
75
- self.fill_topo_arg(arg, device)
76
- return arg
77
-
78
74
  def fill_topo_arg(self, arg, device):
79
75
  arg.vertex_grid = self._vertex_grid
80
76
  arg.face_grid = self._face_grid
@@ -16,14 +16,16 @@
16
16
  from typing import Any, Optional
17
17
 
18
18
  import warp as wp
19
- import warp.fem.cache as cache
20
- from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
21
- from warp.fem.types import NULL_NODE_INDEX
22
- from warp.fem.utils import compress_node_indices
19
+ from warp._src.fem import cache
20
+ from warp._src.fem.geometry import GeometryPartition, WholeGeometryPartition
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, NULL_NODE_INDEX
22
+ from warp._src.fem.utils import compress_node_indices
23
23
 
24
24
  from .function_space import FunctionSpace
25
25
  from .topology import SpaceTopology
26
26
 
27
+ _wp_module_name_ = "warp.fem.space.partition"
28
+
27
29
  wp.set_module_options({"enable_backward": False})
28
30
 
29
31
 
@@ -47,9 +49,16 @@ class SpacePartition:
47
49
  def space_node_indices(self) -> wp.array:
48
50
  """Return the global function space indices for nodes in this partition"""
49
51
 
50
- def partition_arg_value(self, device):
52
+ def rebuild(self, device: Optional = None, temporary_store: Optional[cache.TemporaryStore] = None):
53
+ """Rebuild the space partition indices"""
51
54
  pass
52
55
 
56
+ @cache.cached_arg_value
57
+ def partition_arg_value(self, device):
58
+ arg = self.PartitionArg()
59
+ self.fill_partition_arg(arg, device)
60
+ return arg
61
+
53
62
  def fill_partition_arg(self, arg, device):
54
63
  pass
55
64
 
@@ -90,8 +99,8 @@ class WholeSpacePartition(SpacePartition):
90
99
  """Return the global function space indices for nodes in this partition"""
91
100
  if self._node_indices is None:
92
101
  self._node_indices = cache.borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
93
- wp.launch(kernel=self._iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array])
94
- return self._node_indices.array
102
+ wp.launch(kernel=self._iota_kernel, dim=self.node_count(), inputs=[self._node_indices])
103
+ return self._node_indices
95
104
 
96
105
  def partition_arg_value(self, device):
97
106
  return WholeSpacePartition.PartitionArg()
@@ -140,45 +149,56 @@ class NodePartition(SpacePartition):
140
149
  space_topology: SpaceTopology,
141
150
  geo_partition: GeometryPartition,
142
151
  with_halo: bool = True,
152
+ max_node_count: int = -1,
143
153
  device=None,
144
- temporary_store: cache.TemporaryStore = None,
154
+ temporary_store: Optional[cache.TemporaryStore] = None,
145
155
  ):
146
156
  super().__init__(space_topology=space_topology, geo_partition=geo_partition)
147
157
 
148
- self._compute_node_indices_from_sides(device, with_halo, temporary_store)
158
+ if max_node_count >= 0:
159
+ max_node_count = min(max_node_count, space_topology.node_count())
160
+
161
+ self._max_node_count = max_node_count
162
+ self._with_halo = with_halo
163
+
164
+ self._category_offsets: wp.array = None
165
+ """Offsets for each node category"""
166
+ self._node_indices: wp.array = None
167
+ """Mapping from local partition node indices to global space node indices"""
168
+ self._space_to_partition: wp.array = None
169
+ """Mapping from global space node indices to local partition node indices"""
170
+
171
+ self.rebuild(device, temporary_store)
172
+
173
+ def rebuild(self, device: Optional = None, temporary_store: Optional[cache.TemporaryStore] = None):
174
+ self._compute_node_indices_from_sides(device, self._with_halo, self._max_node_count, temporary_store)
149
175
 
150
176
  def node_count(self) -> int:
151
177
  """Returns number of nodes referenced by this partition, including exterior halo"""
152
- return int(self._category_offsets.array.numpy()[NodeCategory.HALO_OTHER_SIDE + 1])
178
+ return int(self._category_offsets.numpy()[NodeCategory.HALO_OTHER_SIDE + 1])
153
179
 
154
180
  def owned_node_count(self) -> int:
155
181
  """Returns number of nodes in this partition, excluding exterior halo"""
156
- return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_FRONTIER + 1])
182
+ return int(self._category_offsets.numpy()[NodeCategory.OWNED_FRONTIER + 1])
157
183
 
158
184
  def interior_node_count(self) -> int:
159
185
  """Returns number of interior nodes in this partition"""
160
- return int(self._category_offsets.array.numpy()[NodeCategory.OWNED_INTERIOR + 1])
186
+ return int(self._category_offsets.numpy()[NodeCategory.OWNED_INTERIOR + 1])
161
187
 
162
188
  def space_node_indices(self):
163
189
  """Return the global function space indices for nodes in this partition"""
164
- return self._node_indices.array
165
-
166
- @cache.cached_arg_value
167
- def partition_arg_value(self, device):
168
- arg = NodePartition.PartitionArg()
169
- self.fill_partition_arg(arg, device)
170
- return arg
190
+ return self._node_indices
171
191
 
172
192
  def fill_partition_arg(self, arg, device):
173
- arg.space_to_partition = self._space_to_partition.array.to(device)
193
+ arg.space_to_partition = self._space_to_partition.to(device)
174
194
 
175
195
  @wp.func
176
196
  def partition_node_index(args: PartitionArg, space_node_index: int):
177
197
  return args.space_to_partition[space_node_index]
178
198
 
179
- def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: cache.TemporaryStore):
180
- from warp.fem import cache
181
-
199
+ def _compute_node_indices_from_sides(
200
+ self, device, with_halo: bool, max_node_count: int, temporary_store: cache.TemporaryStore
201
+ ):
182
202
  trace_topology = self.space_topology.trace()
183
203
 
184
204
  @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
@@ -191,6 +211,8 @@ class NodePartition(SpacePartition):
191
211
  partition_cell_index = wp.tid()
192
212
 
193
213
  cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)
214
+ if cell_index == NULL_ELEMENT_INDEX:
215
+ return
194
216
 
195
217
  cell_node_count = self.space_topology.element_node_count(geo_arg, space_arg, cell_index)
196
218
  for n in range(cell_node_count):
@@ -207,6 +229,8 @@ class NodePartition(SpacePartition):
207
229
  partition_side_index = wp.tid()
208
230
 
209
231
  side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)
232
+ if side_index == NULL_ELEMENT_INDEX:
233
+ return
210
234
 
211
235
  side_node_count = trace_topology.element_node_count(geo_arg, space_arg, side_index)
212
236
  for n in range(side_node_count):
@@ -225,6 +249,8 @@ class NodePartition(SpacePartition):
225
249
  frontier_side_index = wp.tid()
226
250
 
227
251
  side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)
252
+ if side_index == NULL_ELEMENT_INDEX:
253
+ return
228
254
 
229
255
  side_node_count = trace_topology.element_node_count(geo_arg, space_arg, side_index)
230
256
  for n in range(side_node_count):
@@ -240,7 +266,7 @@ class NodePartition(SpacePartition):
240
266
  dtype=int,
241
267
  device=device,
242
268
  )
243
- node_category.array.fill_(value=NodeCategory.EXTERIOR)
269
+ node_category.fill_(value=NodeCategory.EXTERIOR)
244
270
 
245
271
  wp.launch(
246
272
  dim=self.geo_partition.cell_count(),
@@ -249,7 +275,7 @@ class NodePartition(SpacePartition):
249
275
  self.geo_partition.geometry.cell_arg_value(device),
250
276
  self.geo_partition.cell_arg_value(device),
251
277
  self.space_topology.topo_arg_value(device),
252
- node_category.array,
278
+ node_category,
253
279
  ],
254
280
  device=device,
255
281
  )
@@ -262,7 +288,7 @@ class NodePartition(SpacePartition):
262
288
  self.geo_partition.geometry.side_arg_value(device),
263
289
  self.geo_partition.side_arg_value(device),
264
290
  self.space_topology.topo_arg_value(device),
265
- node_category.array,
291
+ node_category,
266
292
  ],
267
293
  device=device,
268
294
  )
@@ -274,53 +300,73 @@ class NodePartition(SpacePartition):
274
300
  self.geo_partition.geometry.side_arg_value(device),
275
301
  self.geo_partition.side_arg_value(device),
276
302
  self.space_topology.topo_arg_value(device),
277
- node_category.array,
303
+ node_category,
278
304
  ],
279
305
  device=device,
280
306
  )
281
307
 
282
- self._finalize_node_indices(node_category.array, temporary_store)
308
+ with wp.ScopedDevice(device):
309
+ self._finalize_node_indices(node_category, max_node_count, temporary_store)
283
310
 
284
311
  node_category.release()
285
312
 
286
- def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: cache.TemporaryStore):
313
+ def _finalize_node_indices(
314
+ self, node_category: wp.array(dtype=int), max_node_count: int, temporary_store: cache.TemporaryStore
315
+ ):
287
316
  category_offsets, node_indices = compress_node_indices(
288
317
  NodeCategory.COUNT, node_category, temporary_store=temporary_store
289
318
  )
290
-
291
- # Copy offsets to cpu
292
319
  device = node_category.device
293
- with wp.ScopedDevice(device):
294
- self._category_offsets = cache.borrow_temporary(
295
- temporary_store,
296
- shape=category_offsets.array.shape,
297
- dtype=category_offsets.array.dtype,
298
- pinned=device.is_cuda,
299
- device="cpu",
300
- )
301
- wp.copy(src=category_offsets.array, dest=self._category_offsets.array)
320
+
321
+ if max_node_count >= 0:
322
+ if self._category_offsets is None:
323
+ self._category_offsets = cache.borrow_temporary(
324
+ temporary_store,
325
+ shape=(NodeCategory.COUNT + 1,),
326
+ dtype=category_offsets.dtype,
327
+ device="cpu",
328
+ )
329
+ self._category_offsets.fill_(max_node_count)
330
+ copy_event = None
331
+ else:
332
+ # Copy offsets to cpu
333
+ if self._category_offsets is None:
334
+ self._category_offsets = cache.borrow_temporary(
335
+ temporary_store,
336
+ shape=(NodeCategory.COUNT + 1,),
337
+ dtype=category_offsets.dtype,
338
+ pinned=device.is_cuda,
339
+ device="cpu",
340
+ )
341
+ wp.copy(src=category_offsets, dest=self._category_offsets, count=NodeCategory.COUNT + 1)
302
342
  copy_event = cache.capture_event()
303
343
 
304
- # Compute global to local indices
344
+ # Compute global to local indices
345
+ if self._space_to_partition is None or self._space_to_partition.shape != node_indices.shape:
305
346
  self._space_to_partition = cache.borrow_temporary_like(node_indices, temporary_store)
306
- wp.launch(
307
- kernel=NodePartition._scatter_partition_indices,
308
- dim=self.space_topology.node_count(),
309
- device=device,
310
- inputs=[category_offsets.array, node_indices.array, self._space_to_partition.array],
311
- )
312
347
 
313
- # Copy to shrinked-to-fit array
348
+ wp.launch(
349
+ kernel=NodePartition._scatter_partition_indices,
350
+ dim=self.space_topology.node_count(),
351
+ device=device,
352
+ inputs=[max_node_count, category_offsets, node_indices, self._space_to_partition],
353
+ )
354
+
355
+ if copy_event is not None:
314
356
  cache.synchronize_event(copy_event) # Transfer to host must be finished to access node_count()
357
+
358
+ # Copy to shrunk-to-fit array
359
+ if self._node_indices is None or self._node_indices.shape[0] != self.node_count():
315
360
  self._node_indices = cache.borrow_temporary(
316
- temporary_store, shape=(self.node_count()), dtype=int, device=device
361
+ temporary_store, shape=(self.node_count(),), dtype=int, device=device
317
362
  )
318
- wp.copy(dest=self._node_indices.array, src=node_indices.array, count=self.node_count())
319
363
 
320
- node_indices.release()
364
+ wp.copy(dest=self._node_indices, src=node_indices, count=self.node_count())
365
+ node_indices.release()
321
366
 
322
367
  @wp.kernel
323
368
  def _scatter_partition_indices(
369
+ max_node_count: int,
324
370
  category_offsets: wp.array(dtype=int),
325
371
  node_indices: wp.array(dtype=int),
326
372
  space_to_partition_indices: wp.array(dtype=int),
@@ -329,6 +375,17 @@ class NodePartition(SpacePartition):
329
375
  space_idx = node_indices[local_idx]
330
376
 
331
377
  local_node_count = category_offsets[NodeCategory.EXTERIOR] # all but exterior nodes
378
+ if max_node_count >= 0:
379
+ if local_node_count > max_node_count:
380
+ if local_idx == 0:
381
+ wp.printf(
382
+ "Number of space partition nodes exceeded the %d limit; increase `max_node_count` to %d.\n",
383
+ max_node_count,
384
+ local_node_count,
385
+ )
386
+
387
+ local_node_count = max_node_count
388
+
332
389
  if local_idx < local_node_count:
333
390
  space_to_partition_indices[space_idx] = local_idx
334
391
  else:
@@ -340,6 +397,7 @@ def make_space_partition(
340
397
  geometry_partition: Optional[GeometryPartition] = None,
341
398
  space_topology: Optional[SpaceTopology] = None,
342
399
  with_halo: bool = True,
400
+ max_node_count: int = -1,
343
401
  device=None,
344
402
  temporary_store: cache.TemporaryStore = None,
345
403
  ) -> SpacePartition:
@@ -352,6 +410,7 @@ def make_space_partition(
352
410
  geometry_partition: The subset of the space geometry. If not provided, use the whole geometry.
353
411
  space_topology: the topology of the function space to consider. If ``None``, deduced from `space`.
354
412
  with_halo: if True, include the halo nodes (nodes from exterior frontier cells to the partition)
413
+ max_node_count: if positive, will be used to limit the number of nodes to avoid device/host synchronization.
355
414
  device: Warp device on which to perform and store computations
356
415
 
357
416
  Returns:
@@ -363,14 +422,14 @@ def make_space_partition(
363
422
 
364
423
  space_topology = space_topology.full_space_topology()
365
424
 
366
- if geometry_partition is not None:
367
- if geometry_partition.cell_count() < geometry_partition.geometry.cell_count():
368
- return NodePartition(
369
- space_topology=space_topology,
370
- geo_partition=geometry_partition,
371
- with_halo=with_halo,
372
- device=device,
373
- temporary_store=temporary_store,
374
- )
425
+ if geometry_partition is not None and not isinstance(geometry_partition, WholeGeometryPartition):
426
+ return NodePartition(
427
+ space_topology=space_topology,
428
+ geo_partition=geometry_partition,
429
+ with_halo=with_halo,
430
+ max_node_count=max_node_count,
431
+ device=device,
432
+ temporary_store=temporary_store,
433
+ )
375
434
 
376
435
  return WholeSpacePartition(space_topology)
@@ -14,14 +14,16 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import warp as wp
17
- from warp.fem import cache
18
- from warp.fem.geometry import Quadmesh2D
19
- from warp.fem.polynomial import is_closed
20
- from warp.fem.types import NULL_NODE_INDEX, ElementIndex
17
+ from warp._src.fem import cache
18
+ from warp._src.fem.geometry import Quadmesh2D
19
+ from warp._src.fem.polynomial import is_closed
20
+ from warp._src.fem.types import NULL_NODE_INDEX, ElementIndex
21
21
 
22
22
  from .shape import SquareShapeFunction
23
23
  from .topology import SpaceTopology, forward_base_topology
24
24
 
25
+ _wp_module_name_ = "warp.fem.space.quadmesh_function_space"
26
+
25
27
 
26
28
  @wp.struct
27
29
  class Quadmesh2DTopologyArg:
@@ -52,12 +54,6 @@ class QuadmeshSpaceTopology(SpaceTopology):
52
54
  def name(self):
53
55
  return f"{self.geometry.name}_{self._shape.name}"
54
56
 
55
- @cache.cached_arg_value
56
- def topo_arg_value(self, device):
57
- arg = Quadmesh2DTopologyArg()
58
- self.fill_topo_arg(arg, device)
59
- return arg
60
-
61
57
  def fill_topo_arg(self, arg: Quadmesh2DTopologyArg, device):
62
58
  arg.quad_edge_indices = self._quad_edge_indices.to(device)
63
59
  arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)