warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (346):
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +882 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1435 -379
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +3 -3
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +521 -250
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +18 -17
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +578 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
@@ -19,11 +19,11 @@ import textwrap
19
19
  from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
20
20
 
21
21
  import warp as wp
22
- import warp.fem.operator as operator
23
- from warp.codegen import get_annotations
24
- from warp.fem import cache
25
- from warp.fem.domain import GeometryDomain
26
- from warp.fem.field import (
22
+ import warp._src.fem.operator as operator
23
+ from warp._src.codegen import Struct, StructInstance, get_annotations
24
+ from warp._src.fem import cache
25
+ from warp._src.fem.domain import GeometryDomain
26
+ from warp._src.fem.field import (
27
27
  DiscreteField,
28
28
  FieldLike,
29
29
  FieldRestriction,
@@ -34,18 +34,18 @@ from warp.fem.field import (
34
34
  TrialField,
35
35
  make_restriction,
36
36
  )
37
- from warp.fem.field.virtual import (
37
+ from warp._src.fem.field.virtual import (
38
38
  make_bilinear_dispatch_kernel,
39
39
  make_linear_dispatch_kernel,
40
40
  )
41
- from warp.fem.linalg import array_axpy, basis_coefficient
42
- from warp.fem.operator import (
41
+ from warp._src.fem.linalg import array_axpy, basis_coefficient
42
+ from warp._src.fem.operator import (
43
43
  Integrand,
44
44
  Operator,
45
45
  integrand,
46
46
  )
47
- from warp.fem.quadrature import Quadrature, RegularQuadrature
48
- from warp.fem.types import (
47
+ from warp._src.fem.quadrature import Quadrature, RegularQuadrature
48
+ from warp._src.fem.types import (
49
49
  NULL_DOF_INDEX,
50
50
  NULL_ELEMENT_INDEX,
51
51
  NULL_NODE_INDEX,
@@ -57,15 +57,17 @@ from warp.fem.types import (
57
57
  Sample,
58
58
  make_free_sample,
59
59
  )
60
- from warp.fem.utils import type_zero_element
61
- from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
62
- from warp.types import is_array, type_size
63
- from warp.utils import array_cast
60
+ from warp._src.fem.utils import type_zero_element
61
+ from warp._src.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
62
+ from warp._src.types import is_array, type_repr, type_scalar_type, type_size, type_to_warp
63
+ from warp._src.utils import array_cast, warn
64
+
65
+ _wp_module_name_ = "warp.fem.integrate"
64
66
 
65
67
 
66
68
  def _resolve_path(func, node):
67
69
  """
68
- Resolves variable and path from ast node/attribute (adapted from warp.codegen)
70
+ Resolves variable and path from ast node/attribute (adapted from warp._src.codegen)
69
71
  """
70
72
 
71
73
  modules = []
@@ -83,20 +85,20 @@ def _resolve_path(func, node):
83
85
  if len(path) == 0:
84
86
  return None, path
85
87
 
86
- # try and evaluate object path
88
+ name = path[0]
87
89
  try:
88
- # Look up the closure info and append it to adj.func.__globals__
89
- # in case you want to define a kernel inside a function and refer
90
- # to variables you've declared inside that function:
91
- capturedvars = dict(zip(func.__code__.co_freevars, [c.cell_contents for c in (func.__closure__ or [])]))
92
-
93
- vars_dict = {**func.__globals__, **capturedvars}
94
- func = eval(".".join(path), vars_dict)
95
- return func, path
96
- except (NameError, AttributeError):
97
- pass
90
+ # look up in closure variables
91
+ idx = func.__code__.co_freevars.index(name)
92
+ expr = func.__closure__[idx].cell_contents
93
+ except ValueError:
94
+ # look up in global variables
95
+ expr = func.__globals__.get(name)
96
+
97
+ for name in path[1:]:
98
+ if expr is not None:
99
+ expr = getattr(expr, name, None)
98
100
 
99
- return None, path
101
+ return expr, path
100
102
 
101
103
 
102
104
  class IntegrandVisitor(ast.NodeTransformer):
@@ -275,7 +277,7 @@ class IntegrandTransformer(IntegrandVisitor):
275
277
  try:
276
278
  # Retrieve the function pointer corresponding to the operator implementation for the field type
277
279
  pointer = operator.resolver(field)
278
- if not isinstance(pointer, wp.context.Function):
280
+ if not isinstance(pointer, wp.Function):
279
281
  raise NotImplementedError(operator.resolver.__name__)
280
282
 
281
283
  except (AttributeError, NotImplementedError) as e:
@@ -360,15 +362,13 @@ def _parse_integrand_arguments(
360
362
  trial_name = None
361
363
 
362
364
  argspec = integrand.argspec
363
- for arg in argspec.args:
364
- arg_type = argspec.annotations[arg]
365
+ for arg, arg_type in argspec.annotations.items():
365
366
  if arg_type == Field:
366
367
  try:
367
368
  field = fields[arg]
368
369
  except KeyError as err:
369
370
  raise ValueError(f"Missing field for argument '{arg}' of integrand '{integrand.name}'") from err
370
- if not isinstance(field, FieldLike):
371
- raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
371
+
372
372
  if isinstance(field, TestField):
373
373
  if test_name is not None:
374
374
  raise ValueError(f"More than one test field argument: '{test_name}' and '{arg}'")
@@ -377,28 +377,26 @@ def _parse_integrand_arguments(
377
377
  if trial_name is not None:
378
378
  raise ValueError(f"More than one trial field argument: '{trial_name}' and '{arg}'")
379
379
  trial_name = arg
380
+ elif not isinstance(field, FieldLike):
381
+ raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
382
+
380
383
  field_args[arg] = field
381
- elif arg_type == Domain:
384
+ continue
385
+
386
+ if arg in fields:
387
+ raise ValueError(
388
+ f"Cannot pass a field argument to '{arg}' of '{integrand.name}' which is not of type 'Field'"
389
+ )
390
+
391
+ if arg_type == Domain:
382
392
  if domain_name is not None:
383
393
  raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Domain")
384
- if arg in fields:
385
- raise ValueError(
386
- f"Domain argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
387
- )
388
394
  domain_name = arg
389
395
  elif arg_type == Sample:
390
396
  if sample_name is not None:
391
397
  raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Sample")
392
- if arg in fields:
393
- raise ValueError(
394
- f"Sample argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
395
- )
396
398
  sample_name = arg
397
399
  else:
398
- if arg in fields:
399
- raise ValueError(
400
- f"Cannot pass a field argument to '{arg}' of '{integrand.name}' with is not of type 'Field'"
401
- )
402
400
  value_args[arg] = arg_type
403
401
 
404
402
  return IntegrandArguments(field_args, value_args, domain_name, sample_name, test_name, trial_name)
@@ -438,10 +436,8 @@ def _notify_operator_usage(
438
436
  integrand: Integrand,
439
437
  field_args: Dict[str, FieldLike],
440
438
  ):
441
- for arg, field_ops in integrand.operators.items():
442
- if arg in field_args:
443
- # print(f"{arg} {field_args[arg].name} : {', '.join(op.name for op in field_ops)}")
444
- field_args[arg].notify_operator_usage(field_ops)
439
+ for arg, field in field_args.items():
440
+ field.notify_operator_usage(integrand.operators.get(arg, set()))
445
441
 
446
442
 
447
443
  def _gen_field_struct(field_args: Dict[str, FieldLike]):
@@ -615,8 +611,8 @@ def get_integrate_constant_kernel(
615
611
  integrand_func: wp.Function,
616
612
  domain: GeometryDomain,
617
613
  quadrature: Quadrature,
618
- FieldStruct: wp.codegen.Struct,
619
- ValueStruct: wp.codegen.Struct,
614
+ FieldStruct: Struct,
615
+ ValueStruct: Struct,
620
616
  accumulate_dtype,
621
617
  tile_size: int = _INTEGRATE_CONSTANT_TILE_SIZE,
622
618
  ):
@@ -641,10 +637,13 @@ def get_integrate_constant_kernel(
641
637
  domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
642
638
 
643
639
  if domain_element_index == NULL_ELEMENT_INDEX:
644
- val = zero_element()
640
+ element_index = NULL_ELEMENT_INDEX
645
641
  else:
646
642
  element_index = domain.element_index(domain_index_arg, domain_element_index)
647
643
 
644
+ if element_index == NULL_ELEMENT_INDEX:
645
+ val = zero_element()
646
+ else:
648
647
  qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
649
648
  qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
650
649
  qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
@@ -667,8 +666,8 @@ def get_integrate_linear_kernel(
667
666
  integrand_func: wp.Function,
668
667
  domain: GeometryDomain,
669
668
  quadrature: Quadrature,
670
- FieldStruct: wp.codegen.Struct,
671
- ValueStruct: wp.codegen.Struct,
669
+ FieldStruct: Struct,
670
+ ValueStruct: Struct,
672
671
  test: TestField,
673
672
  output_dtype,
674
673
  accumulate_dtype,
@@ -684,6 +683,9 @@ def get_integrate_linear_kernel(
684
683
  ):
685
684
  local_node_index, test_dof = wp.tid()
686
685
  node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
686
+ if node_index == NULL_NODE_INDEX:
687
+ return
688
+
687
689
  element_beg, element_end = test.space_restriction.node_element_range(test_arg, node_index)
688
690
 
689
691
  trial_dof_index = NULL_DOF_INDEX
@@ -725,8 +727,8 @@ def get_integrate_linear_kernel(
725
727
  def get_integrate_linear_nodal_kernel(
726
728
  integrand_func: wp.Function,
727
729
  domain: GeometryDomain,
728
- FieldStruct: wp.codegen.Struct,
729
- ValueStruct: wp.codegen.Struct,
730
+ FieldStruct: Struct,
731
+ ValueStruct: Struct,
730
732
  test: TestField,
731
733
  output_dtype,
732
734
  accumulate_dtype,
@@ -743,6 +745,9 @@ def get_integrate_linear_nodal_kernel(
743
745
  local_node_index, dof = wp.tid()
744
746
 
745
747
  partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
748
+ if partition_node_index == NULL_NODE_INDEX:
749
+ return
750
+
746
751
  element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
747
752
 
748
753
  trial_dof_index = NULL_DOF_INDEX
@@ -797,8 +802,8 @@ def get_integrate_linear_local_kernel(
797
802
  integrand_func: wp.Function,
798
803
  domain: GeometryDomain,
799
804
  quadrature: Quadrature,
800
- FieldStruct: wp.codegen.Struct,
801
- ValueStruct: wp.codegen.Struct,
805
+ FieldStruct: Struct,
806
+ ValueStruct: Struct,
802
807
  test: LocalTestField,
803
808
  ):
804
809
  def integrate_kernel_fn(
@@ -817,6 +822,8 @@ def get_integrate_linear_local_kernel(
817
822
  return
818
823
 
819
824
  element_index = domain.element_index(domain_index_arg, domain_element_index)
825
+ if element_index == NULL_ELEMENT_INDEX:
826
+ return
820
827
 
821
828
  qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
822
829
  qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
@@ -838,8 +845,8 @@ def get_integrate_bilinear_kernel(
838
845
  integrand_func: wp.Function,
839
846
  domain: GeometryDomain,
840
847
  quadrature: Quadrature,
841
- FieldStruct: wp.codegen.Struct,
842
- ValueStruct: wp.codegen.Struct,
848
+ FieldStruct: Struct,
849
+ ValueStruct: Struct,
843
850
  test: TestField,
844
851
  trial: TrialField,
845
852
  output_dtype,
@@ -863,6 +870,9 @@ def get_integrate_bilinear_kernel(
863
870
  test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()
864
871
 
865
872
  test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
873
+ if test_node_index == NULL_NODE_INDEX:
874
+ return
875
+
866
876
  element_beg, element_end = test.space_restriction.node_element_range(test_arg, test_node_index)
867
877
 
868
878
  trial_dof_index = DofIndex(trial_node, trial_dof)
@@ -934,8 +944,8 @@ def get_integrate_bilinear_kernel(
934
944
  def get_integrate_bilinear_nodal_kernel(
935
945
  integrand_func: wp.Function,
936
946
  domain: GeometryDomain,
937
- FieldStruct: wp.codegen.Struct,
938
- ValueStruct: wp.codegen.Struct,
947
+ FieldStruct: Struct,
948
+ ValueStruct: Struct,
939
949
  test: TestField,
940
950
  output_dtype,
941
951
  accumulate_dtype,
@@ -954,6 +964,11 @@ def get_integrate_bilinear_nodal_kernel(
954
964
  local_node_index, test_dof, trial_dof = wp.tid()
955
965
 
956
966
  partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
967
+ if partition_node_index == NULL_NODE_INDEX:
968
+ triplet_rows[local_node_index] = -1
969
+ triplet_cols[local_node_index] = -1
970
+ return
971
+
957
972
  element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
958
973
 
959
974
  val_sum = accumulate_dtype(0.0)
@@ -1009,8 +1024,8 @@ def get_integrate_bilinear_local_kernel(
1009
1024
  integrand_func: wp.Function,
1010
1025
  domain: GeometryDomain,
1011
1026
  quadrature: Quadrature,
1012
- FieldStruct: wp.codegen.Struct,
1013
- ValueStruct: wp.codegen.Struct,
1027
+ FieldStruct: Struct,
1028
+ ValueStruct: Struct,
1014
1029
  test: LocalTestField,
1015
1030
  trial: LocalTrialField,
1016
1031
  ):
@@ -1033,6 +1048,8 @@ def get_integrate_bilinear_local_kernel(
1033
1048
  return
1034
1049
 
1035
1050
  element_index = domain.element_index(domain_index_arg, domain_element_index)
1051
+ if element_index == NULL_ELEMENT_INDEX:
1052
+ return
1036
1053
 
1037
1054
  qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
1038
1055
  qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
@@ -1066,23 +1083,27 @@ def _generate_integrate_kernel(
1066
1083
  accumulate_dtype: type,
1067
1084
  kernel_options: Optional[Dict[str, Any]] = None,
1068
1085
  ) -> wp.Kernel:
1069
- output_dtype = wp.types.type_scalar_type(output_dtype)
1070
-
1071
- FieldStruct = _gen_field_struct(arguments.field_args)
1072
- ValueStruct = cache.get_argument_struct(arguments.value_args)
1086
+ output_dtype = type_scalar_type(output_dtype)
1073
1087
 
1074
1088
  _notify_operator_usage(integrand, arguments.field_args)
1075
1089
 
1076
1090
  # Check if kernel exist in cache
1077
- field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
1078
- kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{field_names}"
1091
+ field_names = tuple((k, f.name) for k, f in arguments.field_args.items())
1092
+ kernel_suffix = ("itg", field_names, cache.pod_type_key(output_dtype), cache.pod_type_key(accumulate_dtype))
1079
1093
 
1080
1094
  if quadrature is not None:
1081
- kernel_suffix += quadrature.name
1095
+ kernel_suffix = (quadrature.name, *kernel_suffix)
1082
1096
 
1083
- kernel = cache.get_integrand_kernel(integrand=integrand, suffix=kernel_suffix, kernel_options=kernel_options)
1097
+ kernel, field_arg_values, value_struct_values = cache.get_integrand_kernel(
1098
+ integrand=integrand,
1099
+ suffix=kernel_suffix,
1100
+ kernel_options=kernel_options,
1101
+ )
1084
1102
  if kernel is not None:
1085
- return kernel, FieldStruct, ValueStruct
1103
+ return kernel, field_arg_values, value_struct_values
1104
+
1105
+ FieldStruct = _gen_field_struct(arguments.field_args)
1106
+ ValueStruct = cache.get_argument_struct(arguments.value_args)
1086
1107
 
1087
1108
  # Not found in cache, transform integrand and generate kernel
1088
1109
  _check_field_compat(integrand, arguments, domain)
@@ -1165,7 +1186,7 @@ def _generate_integrate_kernel(
1165
1186
  accumulate_dtype=accumulate_dtype,
1166
1187
  )
1167
1188
 
1168
- kernel = cache.get_integrand_kernel(
1189
+ kernel, _FieldStruct, _ValueStruct = cache.get_integrand_kernel(
1169
1190
  integrand=integrand,
1170
1191
  kernel_fn=integrate_kernel_fn,
1171
1192
  suffix=kernel_suffix,
@@ -1175,9 +1196,11 @@ def _generate_integrate_kernel(
1175
1196
  arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
1176
1197
  )
1177
1198
  ],
1199
+ FieldStruct=FieldStruct,
1200
+ ValueStruct=ValueStruct,
1178
1201
  )
1179
1202
 
1180
- return kernel, FieldStruct, ValueStruct
1203
+ return kernel, FieldStruct(), ValueStruct()
1181
1204
 
1182
1205
 
1183
1206
  def _generate_auxiliary_kernels(
@@ -1220,8 +1243,8 @@ def _launch_integrate_kernel(
1220
1243
  integrand: Integrand,
1221
1244
  kernel: wp.Kernel,
1222
1245
  auxiliary_kernels: List[Tuple[wp.Kernel, int]],
1223
- FieldStruct: wp.codegen.Struct,
1224
- ValueStruct: wp.codegen.Struct,
1246
+ field_arg_values: StructInstance,
1247
+ value_struct_values: StructInstance,
1225
1248
  domain: GeometryDomain,
1226
1249
  quadrature: Quadrature,
1227
1250
  test: Optional[TestField],
@@ -1243,12 +1266,11 @@ def _launch_integrate_kernel(
1243
1266
  if quadrature is not None:
1244
1267
  qp_arg = quadrature.arg_value(device=device)
1245
1268
 
1246
- field_arg_values = FieldStruct()
1247
1269
  for k, v in fields.items():
1248
1270
  if not isinstance(v, GeometryDomain):
1249
1271
  v.fill_eval_arg(getattr(field_arg_values, k), device=device)
1250
1272
 
1251
- value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
1273
+ cache.populate_argument_struct(value_struct_values, values, func_name=integrand.name)
1252
1274
 
1253
1275
  # Constant form
1254
1276
  if test is None and trial is None:
@@ -1257,14 +1279,13 @@ def _launch_integrate_kernel(
1257
1279
  raise RuntimeError("Output array must be of size at least 1")
1258
1280
  accumulate_array = output
1259
1281
  else:
1260
- accumulate_temporary = cache.borrow_temporary(
1282
+ accumulate_array = cache.borrow_temporary(
1261
1283
  shape=(1),
1262
1284
  device=device,
1263
1285
  dtype=accumulate_dtype,
1264
1286
  temporary_store=temporary_store,
1265
1287
  requires_grad=output is not None and output.requires_grad,
1266
1288
  )
1267
- accumulate_array = accumulate_temporary.array
1268
1289
 
1269
1290
  if output != accumulate_array or not add_to_output:
1270
1291
  accumulate_array.zero_()
@@ -1315,21 +1336,17 @@ def _launch_integrate_kernel(
1315
1336
  output_shape = (test.space_partition.node_count(), test.node_dof_count)
1316
1337
  else:
1317
1338
  raise RuntimeError(
1318
- f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1339
+ f"Incompatible output type {type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1319
1340
  )
1320
1341
 
1321
- output_temporary = cache.borrow_temporary(
1342
+ output = cache.borrow_temporary(
1322
1343
  temporary_store=temporary_store,
1323
1344
  shape=output_shape,
1324
1345
  dtype=output_dtype,
1325
1346
  device=device,
1326
1347
  )
1327
1348
 
1328
- output = output_temporary.array
1329
-
1330
1349
  else:
1331
- output_temporary = None
1332
-
1333
1350
  if output.shape[0] < test.space_partition.node_count():
1334
1351
  raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
1335
1352
 
@@ -1337,7 +1354,7 @@ def _launch_integrate_kernel(
1337
1354
  if type_size(output_dtype) != test.node_dof_count:
1338
1355
  if type_size(output_dtype) != 1:
1339
1356
  raise RuntimeError(
1340
- f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1357
+ f"Incompatible output type {type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1341
1358
  )
1342
1359
  if output.ndim != 2 and output.shape[1] != test.node_dof_count:
1343
1360
  raise RuntimeError(
@@ -1355,7 +1372,7 @@ def _launch_integrate_kernel(
1355
1372
  capacity=array.capacity,
1356
1373
  device=array.device,
1357
1374
  shape=(test.space_partition.node_count(), test.node_dof_count),
1358
- dtype=wp.types.type_scalar_type(output_dtype),
1375
+ dtype=type_scalar_type(output_dtype),
1359
1376
  grad=None if array.grad is None else as_2d_array(array.grad),
1360
1377
  )
1361
1378
 
@@ -1387,7 +1404,7 @@ def _launch_integrate_kernel(
1387
1404
 
1388
1405
  wp.launch(
1389
1406
  kernel=kernel,
1390
- dim=local_result.array.shape,
1407
+ dim=local_result.shape,
1391
1408
  inputs=[
1392
1409
  qp_arg,
1393
1410
  quadrature.element_index_arg_value(device),
@@ -1395,13 +1412,13 @@ def _launch_integrate_kernel(
1395
1412
  domain_elt_index_arg,
1396
1413
  field_arg_values,
1397
1414
  value_struct_values,
1398
- local_result.array,
1415
+ local_result,
1399
1416
  ],
1400
1417
  device=device,
1401
1418
  )
1402
1419
 
1403
1420
  if test.TAYLOR_DOF_COUNT == 0:
1404
- wp.utils.warn(
1421
+ warn(
1405
1422
  f"Test field is never evaluated in integrand '{integrand.name}', result will be zero",
1406
1423
  category=UserWarning,
1407
1424
  stacklevel=2,
@@ -1418,7 +1435,7 @@ def _launch_integrate_kernel(
1418
1435
  domain_elt_index_arg,
1419
1436
  test_arg,
1420
1437
  test.space.space_arg_value(device),
1421
- local_result.array,
1438
+ local_result,
1422
1439
  output_view,
1423
1440
  ],
1424
1441
  device=device,
@@ -1442,9 +1459,6 @@ def _launch_integrate_kernel(
1442
1459
  device=device,
1443
1460
  )
1444
1461
 
1445
- if output_temporary is not None:
1446
- return output_temporary.detach()
1447
-
1448
1462
  return output
1449
1463
 
1450
1464
  # Bilinear form
@@ -1475,8 +1489,6 @@ def _launch_integrate_kernel(
1475
1489
  triplet_rows = triplet_rows_temp.array
1476
1490
  triplet_values = triplet_values_temp.array
1477
1491
 
1478
- triplet_values.zero_()
1479
-
1480
1492
  if nodal:
1481
1493
  wp.launch(
1482
1494
  kernel=kernel,
@@ -1524,13 +1536,13 @@ def _launch_integrate_kernel(
1524
1536
  domain_elt_index_arg,
1525
1537
  field_arg_values,
1526
1538
  value_struct_values,
1527
- local_result.array,
1539
+ local_result,
1528
1540
  ],
1529
1541
  device=device,
1530
1542
  )
1531
1543
 
1532
1544
  if test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT == 0:
1533
- wp.utils.warn(
1545
+ warn(
1534
1546
  f"Test and/or trial fields are never evaluated in integrand '{integrand.name}', result will be zero",
1535
1547
  category=UserWarning,
1536
1548
  stacklevel=2,
@@ -1557,7 +1569,7 @@ def _launch_integrate_kernel(
1557
1569
  trial_partition_arg,
1558
1570
  trial_topology_arg,
1559
1571
  trial.space.space_arg_value(device),
1560
- local_result.array,
1572
+ local_result,
1561
1573
  triplet_rows,
1562
1574
  triplet_cols,
1563
1575
  triplet_values,
@@ -1626,20 +1638,12 @@ def _launch_integrate_kernel(
1626
1638
 
1627
1639
 
1628
1640
  def _pick_assembly_strategy(
1629
- assembly: Optional[str], nodal: bool, operators: Dict[str, Set[Operator]], arguments: IntegrandArguments
1641
+ assembly: Optional[str], operators: Dict[str, Set[Operator]], arguments: IntegrandArguments
1630
1642
  ):
1631
1643
  if assembly is not None:
1632
1644
  if assembly not in ("generic", "nodal", "dispatch"):
1633
1645
  raise ValueError(f"Invalid assembly strategy'{assembly}'")
1634
1646
  return assembly
1635
- elif nodal is not None:
1636
- wp.utils.warn(
1637
- "'nodal' argument of `warp.fem.integrate` is deprecated and will be removed in a future version. Please use `assembly='nodal'` instead.",
1638
- category=DeprecationWarning,
1639
- stacklevel=2,
1640
- )
1641
- if nodal:
1642
- return "nodal"
1643
1647
 
1644
1648
  test_operators = operators.get(arguments.test_name, set())
1645
1649
  trial_operators = operators.get(arguments.trial_name, set())
@@ -1655,7 +1659,6 @@ def integrate(
1655
1659
  integrand: Integrand,
1656
1660
  domain: Optional[GeometryDomain] = None,
1657
1661
  quadrature: Optional[Quadrature] = None,
1658
- nodal: Optional[bool] = None,
1659
1662
  fields: Optional[Dict[str, FieldLike]] = None,
1660
1663
  values: Optional[Dict[str, Any]] = None,
1661
1664
  accumulate_dtype: type = wp.float64,
@@ -1675,7 +1678,6 @@ def integrate(
1675
1678
  integrand: Form to be integrated, must have :func:`integrand` decorator
1676
1679
  domain: Integration domain. If None, deduced from fields
1677
1680
  quadrature: Quadrature formula. If None, deduced from domain and fields degree.
1678
- nodal: Deprecated. Use the equivalent assembly="nodal" instead.
1679
1681
  fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
1680
1682
  values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
1681
1683
  temporary_store: shared pool from which to allocate temporary arrays
@@ -1738,9 +1740,9 @@ def integrate(
1738
1740
  _find_integrand_operators(integrand, arguments.field_args)
1739
1741
 
1740
1742
  if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
1741
- wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
1743
+ warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
1742
1744
 
1743
- assembly = _pick_assembly_strategy(assembly, nodal, arguments=arguments, operators=integrand.operators)
1745
+ assembly = _pick_assembly_strategy(assembly, arguments=arguments, operators=integrand.operators)
1744
1746
  # print("assembly for ", integrand.name, ":", strategy)
1745
1747
 
1746
1748
  if assembly == "dispatch":
@@ -1770,7 +1772,7 @@ def integrate(
1770
1772
  raise ValueError("Incompatible integration and quadrature domain")
1771
1773
 
1772
1774
  # Canonicalize types
1773
- accumulate_dtype = wp.types.type_to_warp(accumulate_dtype)
1775
+ accumulate_dtype = type_to_warp(accumulate_dtype)
1774
1776
  if output is not None:
1775
1777
  if isinstance(output, BsrMatrix):
1776
1778
  output_dtype = output.scalar_type
@@ -1779,9 +1781,9 @@ def integrate(
1779
1781
  elif output_dtype is None:
1780
1782
  output_dtype = accumulate_dtype
1781
1783
  else:
1782
- output_dtype = wp.types.type_to_warp(output_dtype)
1784
+ output_dtype = type_to_warp(output_dtype)
1783
1785
 
1784
- kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
1786
+ kernel, field_arg_values, value_struct_values = _generate_integrate_kernel(
1785
1787
  integrand=integrand,
1786
1788
  domain=domain,
1787
1789
  quadrature=quadrature,
@@ -1806,8 +1808,8 @@ def integrate(
1806
1808
  integrand=integrand,
1807
1809
  kernel=kernel,
1808
1810
  auxiliary_kernels=auxiliary_kernels,
1809
- FieldStruct=FieldStruct,
1810
- ValueStruct=ValueStruct,
1811
+ field_arg_values=field_arg_values,
1812
+ value_struct_values=value_struct_values,
1811
1813
  domain=domain,
1812
1814
  quadrature=quadrature,
1813
1815
  test=test,
@@ -1827,14 +1829,14 @@ def integrate(
1827
1829
  def get_interpolate_to_field_function(
1828
1830
  integrand_func: wp.Function,
1829
1831
  domain: GeometryDomain,
1830
- FieldStruct: wp.codegen.Struct,
1831
- ValueStruct: wp.codegen.Struct,
1832
+ FieldStruct: Struct,
1833
+ ValueStruct: Struct,
1832
1834
  dest: FieldRestriction,
1833
1835
  ):
1834
1836
  zero_value = type_zero_element(dest.space.dtype)
1835
1837
 
1836
1838
  def interpolate_to_field_fn(
1837
- local_node_index: int,
1839
+ partition_node_index: int,
1838
1840
  domain_arg: domain.ElementArg,
1839
1841
  domain_index_arg: domain.ElementIndexArg,
1840
1842
  dest_node_arg: dest.space_restriction.NodeArg,
@@ -1842,7 +1844,6 @@ def get_interpolate_to_field_function(
1842
1844
  fields: FieldStruct,
1843
1845
  values: ValueStruct,
1844
1846
  ):
1845
- partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1846
1847
  element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)
1847
1848
 
1848
1849
  test_dof_index = NULL_DOF_INDEX
@@ -1894,8 +1895,8 @@ def get_interpolate_to_field_function(
1894
1895
  def get_interpolate_to_field_kernel(
1895
1896
  interpolate_to_field_fn: wp.Function,
1896
1897
  domain: GeometryDomain,
1897
- FieldStruct: wp.codegen.Struct,
1898
- ValueStruct: wp.codegen.Struct,
1898
+ FieldStruct: Struct,
1899
+ ValueStruct: Struct,
1899
1900
  dest: FieldRestriction,
1900
1901
  ):
1901
1902
  @wp.func
@@ -1932,13 +1933,15 @@ def get_interpolate_to_field_kernel(
1932
1933
  ):
1933
1934
  local_node_index = wp.tid()
1934
1935
 
1936
+ partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1937
+ if partition_node_index == NULL_NODE_INDEX:
1938
+ return
1939
+
1935
1940
  val_sum, vol_sum = interpolate_to_field_fn(
1936
1941
  local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
1937
1942
  )
1938
1943
 
1939
1944
  if vol_sum > 0.0:
1940
- partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1941
-
1942
1945
  # Grab first element containing node; there must be at least one since vol_sum != 0
1943
1946
  element_index, node_index_in_element = _find_node_in_element(
1944
1947
  domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, partition_node_index
@@ -1959,8 +1962,8 @@ def get_interpolate_at_quadrature_kernel(
1959
1962
  integrand_func: wp.Function,
1960
1963
  domain: GeometryDomain,
1961
1964
  quadrature: Quadrature,
1962
- FieldStruct: wp.codegen.Struct,
1963
- ValueStruct: wp.codegen.Struct,
1965
+ FieldStruct: Struct,
1966
+ ValueStruct: Struct,
1964
1967
  value_type: type,
1965
1968
  ):
1966
1969
  def interpolate_at_quadrature_nonvalued_kernel_fn(
@@ -1978,6 +1981,8 @@ def get_interpolate_at_quadrature_kernel(
1978
1981
  return
1979
1982
 
1980
1983
  element_index = domain.element_index(domain_index_arg, domain_element_index)
1984
+ if element_index == NULL_ELEMENT_INDEX:
1985
+ return
1981
1986
 
1982
1987
  test_dof_index = NULL_DOF_INDEX
1983
1988
  trial_dof_index = NULL_DOF_INDEX
@@ -2004,6 +2009,8 @@ def get_interpolate_at_quadrature_kernel(
2004
2009
  return
2005
2010
 
2006
2011
  element_index = domain.element_index(domain_index_arg, domain_element_index)
2012
+ if element_index == NULL_ELEMENT_INDEX:
2013
+ return
2007
2014
 
2008
2015
  test_dof_index = NULL_DOF_INDEX
2009
2016
  trial_dof_index = NULL_DOF_INDEX
@@ -2022,8 +2029,8 @@ def get_interpolate_jacobian_at_quadrature_kernel(
2022
2029
  integrand_func: wp.Function,
2023
2030
  domain: GeometryDomain,
2024
2031
  quadrature: Quadrature,
2025
- FieldStruct: wp.codegen.Struct,
2026
- ValueStruct: wp.codegen.Struct,
2032
+ FieldStruct: Struct,
2033
+ ValueStruct: Struct,
2027
2034
  trial: TrialField,
2028
2035
  value_size: int,
2029
2036
  value_type: type,
@@ -2046,11 +2053,13 @@ def get_interpolate_jacobian_at_quadrature_kernel(
2046
2053
  ):
2047
2054
  qp_eval_index, trial_node, trial_dof = wp.tid()
2048
2055
  domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
2049
-
2050
2056
  if domain_element_index == NULL_ELEMENT_INDEX:
2051
2057
  return
2052
2058
 
2053
2059
  element_index = domain.element_index(domain_index_arg, domain_element_index)
2060
+ if element_index == NULL_ELEMENT_INDEX:
2061
+ return
2062
+
2054
2063
  if qp >= quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index):
2055
2064
  return
2056
2065
 
@@ -2090,8 +2099,8 @@ def get_interpolate_jacobian_at_quadrature_kernel(
2090
2099
  def get_interpolate_free_kernel(
2091
2100
  integrand_func: wp.Function,
2092
2101
  domain: GeometryDomain,
2093
- FieldStruct: wp.codegen.Struct,
2094
- ValueStruct: wp.codegen.Struct,
2102
+ FieldStruct: Struct,
2103
+ ValueStruct: Struct,
2095
2104
  value_type: type,
2096
2105
  ):
2097
2106
  def interpolate_free_nonvalued_kernel_fn(
@@ -2144,31 +2153,31 @@ def _generate_interpolate_kernel(
2144
2153
  arguments: IntegrandArguments,
2145
2154
  kernel_options: Optional[Dict[str, Any]] = None,
2146
2155
  ) -> wp.Kernel:
2147
- # Generate field struct
2148
- FieldStruct = _gen_field_struct(arguments.field_args)
2149
- ValueStruct = cache.get_argument_struct(arguments.value_args)
2150
-
2151
2156
  _notify_operator_usage(integrand, arguments.field_args)
2152
2157
 
2153
2158
  # Check if kernel exist in cache
2154
- field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
2159
+ field_names = tuple((k, f.name) for k, f in arguments.field_args.items())
2155
2160
  if isinstance(dest, FieldRestriction):
2156
- kernel_suffix = f"_itp_{field_names}_{dest.domain.name}_{dest.space_restriction.space_partition.name}"
2161
+ kernel_suffix = ("itp", *field_names, dest.domain.name, dest.space_restriction.space_partition.name)
2157
2162
  else:
2158
2163
  dest_dtype = dest.dtype if dest else None
2159
- type_str = wp.types.get_type_code(dest_dtype) if dest_dtype else ""
2164
+ type_str = cache.pod_type_key(dest_dtype) if dest_dtype else ""
2160
2165
  if quadrature is None:
2161
- kernel_suffix = f"_itp_{field_names}_{domain.name}_{type_str}"
2166
+ kernel_suffix = ("itp", *field_names, domain.name, type_str)
2162
2167
  else:
2163
- kernel_suffix = f"_itp_{field_names}_{domain.name}_{quadrature.name}_{type_str}"
2168
+ kernel_suffix = ("itp", *field_names, domain.name, quadrature.name, type_str)
2164
2169
 
2165
- kernel = cache.get_integrand_kernel(
2170
+ kernel, field_arg_values, value_struct_values = cache.get_integrand_kernel(
2166
2171
  integrand=integrand,
2167
2172
  suffix=kernel_suffix,
2168
2173
  kernel_options=kernel_options,
2169
2174
  )
2170
2175
  if kernel is not None:
2171
- return kernel, FieldStruct, ValueStruct
2176
+ return kernel, field_arg_values, value_struct_values
2177
+
2178
+ # Generate field struct
2179
+ FieldStruct = _gen_field_struct(arguments.field_args)
2180
+ ValueStruct = cache.get_argument_struct(arguments.value_args)
2172
2181
 
2173
2182
  # Not found in cache, transform integrand and generate kernel
2174
2183
  _check_field_compat(integrand, arguments, domain)
@@ -2235,7 +2244,7 @@ def _generate_interpolate_kernel(
2235
2244
  ValueStruct=ValueStruct,
2236
2245
  )
2237
2246
 
2238
- kernel = cache.get_integrand_kernel(
2247
+ kernel, _FieldStruct, _ValueStruct = cache.get_integrand_kernel(
2239
2248
  integrand=integrand,
2240
2249
  kernel_fn=interpolate_kernel_fn,
2241
2250
  suffix=kernel_suffix,
@@ -2245,16 +2254,18 @@ def _generate_interpolate_kernel(
2245
2254
  arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
2246
2255
  )
2247
2256
  ],
2257
+ FieldStruct=FieldStruct,
2258
+ ValueStruct=ValueStruct,
2248
2259
  )
2249
2260
 
2250
- return kernel, FieldStruct, ValueStruct
2261
+ return kernel, FieldStruct(), ValueStruct()
2251
2262
 
2252
2263
 
2253
2264
  def _launch_interpolate_kernel(
2254
2265
  integrand: Integrand,
2255
2266
  kernel: wp.kernel,
2256
- FieldStruct: wp.codegen.Struct,
2257
- ValueStruct: wp.codegen.Struct,
2267
+ field_arg_values: StructInstance,
2268
+ value_struct_values: StructInstance,
2258
2269
  domain: GeometryDomain,
2259
2270
  dest: Optional[Union[FieldRestriction, wp.array]],
2260
2271
  quadrature: Optional[Quadrature],
@@ -2270,12 +2281,10 @@ def _launch_interpolate_kernel(
2270
2281
  elt_arg = domain.element_arg_value(device=device)
2271
2282
  elt_index_arg = domain.element_index_arg_value(device=device)
2272
2283
 
2273
- field_arg_values = FieldStruct()
2274
2284
  for k, v in fields.items():
2275
2285
  if not isinstance(v, GeometryDomain):
2276
2286
  v.fill_eval_arg(getattr(field_arg_values, k), device=device)
2277
-
2278
- value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
2287
+ cache.populate_argument_struct(value_struct_values, values, func_name=integrand.name)
2279
2288
 
2280
2289
  if isinstance(dest, FieldRestriction):
2281
2290
  dest_node_arg = dest.space_restriction.node_arg_value(device=device)
@@ -2313,7 +2322,7 @@ def _launch_interpolate_kernel(
2313
2322
  qp_index_count = quadrature.total_point_count()
2314
2323
 
2315
2324
  if qp_eval_count != qp_index_count:
2316
- wp.utils.warn(
2325
+ warn(
2317
2326
  f"Quadrature used for interpolation of {integrand.name} has different number of evaluation and indexed points, this may lead to incorrect results",
2318
2327
  category=UserWarning,
2319
2328
  stacklevel=2,
@@ -2353,7 +2362,6 @@ def _launch_interpolate_kernel(
2353
2362
  triplet_rows = triplet_rows_temp.array
2354
2363
  triplet_values = triplet_values_temp.array
2355
2364
  triplet_rows.fill_(-1)
2356
- triplet_values.zero_()
2357
2365
 
2358
2366
  trial_partition_arg = trial.space_partition.partition_arg_value(device)
2359
2367
  trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
@@ -2470,9 +2478,9 @@ def interpolate(
2470
2478
  _find_integrand_operators(integrand, arguments.field_args)
2471
2479
 
2472
2480
  if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
2473
- wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
2481
+ warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
2474
2482
 
2475
- kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
2483
+ kernel, field_struct, value_struct = _generate_interpolate_kernel(
2476
2484
  integrand=integrand,
2477
2485
  domain=domain,
2478
2486
  dest=dest,
@@ -2484,8 +2492,8 @@ def interpolate(
2484
2492
  return _launch_interpolate_kernel(
2485
2493
  integrand=integrand,
2486
2494
  kernel=kernel,
2487
- FieldStruct=FieldStruct,
2488
- ValueStruct=ValueStruct,
2495
+ field_arg_values=field_struct,
2496
+ value_struct_values=value_struct,
2489
2497
  domain=domain,
2490
2498
  dest=dest,
2491
2499
  quadrature=quadrature,