warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -19,11 +19,11 @@ import textwrap
19
19
  from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
20
20
 
21
21
  import warp as wp
22
- import warp.fem.operator as operator
23
- from warp.codegen import get_annotations
24
- from warp.fem import cache
25
- from warp.fem.domain import GeometryDomain
26
- from warp.fem.field import (
22
+ import warp._src.fem.operator as operator
23
+ from warp._src.codegen import Struct, StructInstance, get_annotations
24
+ from warp._src.fem import cache
25
+ from warp._src.fem.domain import GeometryDomain
26
+ from warp._src.fem.field import (
27
27
  DiscreteField,
28
28
  FieldLike,
29
29
  FieldRestriction,
@@ -34,18 +34,18 @@ from warp.fem.field import (
34
34
  TrialField,
35
35
  make_restriction,
36
36
  )
37
- from warp.fem.field.virtual import (
37
+ from warp._src.fem.field.virtual import (
38
38
  make_bilinear_dispatch_kernel,
39
39
  make_linear_dispatch_kernel,
40
40
  )
41
- from warp.fem.linalg import array_axpy, basis_coefficient
42
- from warp.fem.operator import (
41
+ from warp._src.fem.linalg import array_axpy, basis_coefficient
42
+ from warp._src.fem.operator import (
43
43
  Integrand,
44
44
  Operator,
45
45
  integrand,
46
46
  )
47
- from warp.fem.quadrature import Quadrature, RegularQuadrature
48
- from warp.fem.types import (
47
+ from warp._src.fem.quadrature import Quadrature, RegularQuadrature
48
+ from warp._src.fem.types import (
49
49
  NULL_DOF_INDEX,
50
50
  NULL_ELEMENT_INDEX,
51
51
  NULL_NODE_INDEX,
@@ -57,15 +57,15 @@ from warp.fem.types import (
57
57
  Sample,
58
58
  make_free_sample,
59
59
  )
60
- from warp.fem.utils import type_zero_element
61
- from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
62
- from warp.types import is_array, type_size
63
- from warp.utils import array_cast
60
+ from warp._src.fem.utils import type_zero_element
61
+ from warp._src.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
62
+ from warp._src.types import is_array, type_repr, type_scalar_type, type_size, type_to_warp
63
+ from warp._src.utils import array_cast, warn
64
64
 
65
65
 
66
66
  def _resolve_path(func, node):
67
67
  """
68
- Resolves variable and path from ast node/attribute (adapted from warp.codegen)
68
+ Resolves variable and path from ast node/attribute (adapted from warp._src.codegen)
69
69
  """
70
70
 
71
71
  modules = []
@@ -83,20 +83,20 @@ def _resolve_path(func, node):
83
83
  if len(path) == 0:
84
84
  return None, path
85
85
 
86
- # try and evaluate object path
86
+ name = path[0]
87
87
  try:
88
- # Look up the closure info and append it to adj.func.__globals__
89
- # in case you want to define a kernel inside a function and refer
90
- # to variables you've declared inside that function:
91
- capturedvars = dict(zip(func.__code__.co_freevars, [c.cell_contents for c in (func.__closure__ or [])]))
92
-
93
- vars_dict = {**func.__globals__, **capturedvars}
94
- func = eval(".".join(path), vars_dict)
95
- return func, path
96
- except (NameError, AttributeError):
97
- pass
88
+ # look up in closure variables
89
+ idx = func.__code__.co_freevars.index(name)
90
+ expr = func.__closure__[idx].cell_contents
91
+ except ValueError:
92
+ # look up in global variables
93
+ expr = func.__globals__.get(name)
94
+
95
+ for name in path[1:]:
96
+ if expr is not None:
97
+ expr = getattr(expr, name, None)
98
98
 
99
- return None, path
99
+ return expr, path
100
100
 
101
101
 
102
102
  class IntegrandVisitor(ast.NodeTransformer):
@@ -275,7 +275,7 @@ class IntegrandTransformer(IntegrandVisitor):
275
275
  try:
276
276
  # Retrieve the function pointer corresponding to the operator implementation for the field type
277
277
  pointer = operator.resolver(field)
278
- if not isinstance(pointer, wp.context.Function):
278
+ if not isinstance(pointer, wp.Function):
279
279
  raise NotImplementedError(operator.resolver.__name__)
280
280
 
281
281
  except (AttributeError, NotImplementedError) as e:
@@ -360,15 +360,13 @@ def _parse_integrand_arguments(
360
360
  trial_name = None
361
361
 
362
362
  argspec = integrand.argspec
363
- for arg in argspec.args:
364
- arg_type = argspec.annotations[arg]
363
+ for arg, arg_type in argspec.annotations.items():
365
364
  if arg_type == Field:
366
365
  try:
367
366
  field = fields[arg]
368
367
  except KeyError as err:
369
368
  raise ValueError(f"Missing field for argument '{arg}' of integrand '{integrand.name}'") from err
370
- if not isinstance(field, FieldLike):
371
- raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
369
+
372
370
  if isinstance(field, TestField):
373
371
  if test_name is not None:
374
372
  raise ValueError(f"More than one test field argument: '{test_name}' and '{arg}'")
@@ -377,28 +375,26 @@ def _parse_integrand_arguments(
377
375
  if trial_name is not None:
378
376
  raise ValueError(f"More than one trial field argument: '{trial_name}' and '{arg}'")
379
377
  trial_name = arg
378
+ elif not isinstance(field, FieldLike):
379
+ raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
380
+
380
381
  field_args[arg] = field
381
- elif arg_type == Domain:
382
+ continue
383
+
384
+ if arg in fields:
385
+ raise ValueError(
386
+ f"Cannot pass a field argument to '{arg}' of '{integrand.name}' which is not of type 'Field'"
387
+ )
388
+
389
+ if arg_type == Domain:
382
390
  if domain_name is not None:
383
391
  raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Domain")
384
- if arg in fields:
385
- raise ValueError(
386
- f"Domain argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
387
- )
388
392
  domain_name = arg
389
393
  elif arg_type == Sample:
390
394
  if sample_name is not None:
391
395
  raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Sample")
392
- if arg in fields:
393
- raise ValueError(
394
- f"Sample argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
395
- )
396
396
  sample_name = arg
397
397
  else:
398
- if arg in fields:
399
- raise ValueError(
400
- f"Cannot pass a field argument to '{arg}' of '{integrand.name}' with is not of type 'Field'"
401
- )
402
398
  value_args[arg] = arg_type
403
399
 
404
400
  return IntegrandArguments(field_args, value_args, domain_name, sample_name, test_name, trial_name)
@@ -438,10 +434,8 @@ def _notify_operator_usage(
438
434
  integrand: Integrand,
439
435
  field_args: Dict[str, FieldLike],
440
436
  ):
441
- for arg, field_ops in integrand.operators.items():
442
- if arg in field_args:
443
- # print(f"{arg} {field_args[arg].name} : {', '.join(op.name for op in field_ops)}")
444
- field_args[arg].notify_operator_usage(field_ops)
437
+ for arg, field in field_args.items():
438
+ field.notify_operator_usage(integrand.operators.get(arg, set()))
445
439
 
446
440
 
447
441
  def _gen_field_struct(field_args: Dict[str, FieldLike]):
@@ -615,8 +609,8 @@ def get_integrate_constant_kernel(
615
609
  integrand_func: wp.Function,
616
610
  domain: GeometryDomain,
617
611
  quadrature: Quadrature,
618
- FieldStruct: wp.codegen.Struct,
619
- ValueStruct: wp.codegen.Struct,
612
+ FieldStruct: Struct,
613
+ ValueStruct: Struct,
620
614
  accumulate_dtype,
621
615
  tile_size: int = _INTEGRATE_CONSTANT_TILE_SIZE,
622
616
  ):
@@ -641,10 +635,13 @@ def get_integrate_constant_kernel(
641
635
  domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
642
636
 
643
637
  if domain_element_index == NULL_ELEMENT_INDEX:
644
- val = zero_element()
638
+ element_index = NULL_ELEMENT_INDEX
645
639
  else:
646
640
  element_index = domain.element_index(domain_index_arg, domain_element_index)
647
641
 
642
+ if element_index == NULL_ELEMENT_INDEX:
643
+ val = zero_element()
644
+ else:
648
645
  qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
649
646
  qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
650
647
  qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
@@ -667,8 +664,8 @@ def get_integrate_linear_kernel(
667
664
  integrand_func: wp.Function,
668
665
  domain: GeometryDomain,
669
666
  quadrature: Quadrature,
670
- FieldStruct: wp.codegen.Struct,
671
- ValueStruct: wp.codegen.Struct,
667
+ FieldStruct: Struct,
668
+ ValueStruct: Struct,
672
669
  test: TestField,
673
670
  output_dtype,
674
671
  accumulate_dtype,
@@ -684,6 +681,9 @@ def get_integrate_linear_kernel(
684
681
  ):
685
682
  local_node_index, test_dof = wp.tid()
686
683
  node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
684
+ if node_index == NULL_NODE_INDEX:
685
+ return
686
+
687
687
  element_beg, element_end = test.space_restriction.node_element_range(test_arg, node_index)
688
688
 
689
689
  trial_dof_index = NULL_DOF_INDEX
@@ -725,8 +725,8 @@ def get_integrate_linear_kernel(
725
725
  def get_integrate_linear_nodal_kernel(
726
726
  integrand_func: wp.Function,
727
727
  domain: GeometryDomain,
728
- FieldStruct: wp.codegen.Struct,
729
- ValueStruct: wp.codegen.Struct,
728
+ FieldStruct: Struct,
729
+ ValueStruct: Struct,
730
730
  test: TestField,
731
731
  output_dtype,
732
732
  accumulate_dtype,
@@ -743,6 +743,9 @@ def get_integrate_linear_nodal_kernel(
743
743
  local_node_index, dof = wp.tid()
744
744
 
745
745
  partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
746
+ if partition_node_index == NULL_NODE_INDEX:
747
+ return
748
+
746
749
  element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
747
750
 
748
751
  trial_dof_index = NULL_DOF_INDEX
@@ -797,8 +800,8 @@ def get_integrate_linear_local_kernel(
797
800
  integrand_func: wp.Function,
798
801
  domain: GeometryDomain,
799
802
  quadrature: Quadrature,
800
- FieldStruct: wp.codegen.Struct,
801
- ValueStruct: wp.codegen.Struct,
803
+ FieldStruct: Struct,
804
+ ValueStruct: Struct,
802
805
  test: LocalTestField,
803
806
  ):
804
807
  def integrate_kernel_fn(
@@ -817,6 +820,8 @@ def get_integrate_linear_local_kernel(
817
820
  return
818
821
 
819
822
  element_index = domain.element_index(domain_index_arg, domain_element_index)
823
+ if element_index == NULL_ELEMENT_INDEX:
824
+ return
820
825
 
821
826
  qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
822
827
  qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
@@ -838,8 +843,8 @@ def get_integrate_bilinear_kernel(
838
843
  integrand_func: wp.Function,
839
844
  domain: GeometryDomain,
840
845
  quadrature: Quadrature,
841
- FieldStruct: wp.codegen.Struct,
842
- ValueStruct: wp.codegen.Struct,
846
+ FieldStruct: Struct,
847
+ ValueStruct: Struct,
843
848
  test: TestField,
844
849
  trial: TrialField,
845
850
  output_dtype,
@@ -863,6 +868,9 @@ def get_integrate_bilinear_kernel(
863
868
  test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()
864
869
 
865
870
  test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
871
+ if test_node_index == NULL_NODE_INDEX:
872
+ return
873
+
866
874
  element_beg, element_end = test.space_restriction.node_element_range(test_arg, test_node_index)
867
875
 
868
876
  trial_dof_index = DofIndex(trial_node, trial_dof)
@@ -934,8 +942,8 @@ def get_integrate_bilinear_kernel(
934
942
  def get_integrate_bilinear_nodal_kernel(
935
943
  integrand_func: wp.Function,
936
944
  domain: GeometryDomain,
937
- FieldStruct: wp.codegen.Struct,
938
- ValueStruct: wp.codegen.Struct,
945
+ FieldStruct: Struct,
946
+ ValueStruct: Struct,
939
947
  test: TestField,
940
948
  output_dtype,
941
949
  accumulate_dtype,
@@ -954,6 +962,11 @@ def get_integrate_bilinear_nodal_kernel(
954
962
  local_node_index, test_dof, trial_dof = wp.tid()
955
963
 
956
964
  partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
965
+ if partition_node_index == NULL_NODE_INDEX:
966
+ triplet_rows[local_node_index] = -1
967
+ triplet_cols[local_node_index] = -1
968
+ return
969
+
957
970
  element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
958
971
 
959
972
  val_sum = accumulate_dtype(0.0)
@@ -1009,8 +1022,8 @@ def get_integrate_bilinear_local_kernel(
1009
1022
  integrand_func: wp.Function,
1010
1023
  domain: GeometryDomain,
1011
1024
  quadrature: Quadrature,
1012
- FieldStruct: wp.codegen.Struct,
1013
- ValueStruct: wp.codegen.Struct,
1025
+ FieldStruct: Struct,
1026
+ ValueStruct: Struct,
1014
1027
  test: LocalTestField,
1015
1028
  trial: LocalTrialField,
1016
1029
  ):
@@ -1033,6 +1046,8 @@ def get_integrate_bilinear_local_kernel(
1033
1046
  return
1034
1047
 
1035
1048
  element_index = domain.element_index(domain_index_arg, domain_element_index)
1049
+ if element_index == NULL_ELEMENT_INDEX:
1050
+ return
1036
1051
 
1037
1052
  qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
1038
1053
  qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
@@ -1066,23 +1081,27 @@ def _generate_integrate_kernel(
1066
1081
  accumulate_dtype: type,
1067
1082
  kernel_options: Optional[Dict[str, Any]] = None,
1068
1083
  ) -> wp.Kernel:
1069
- output_dtype = wp.types.type_scalar_type(output_dtype)
1070
-
1071
- FieldStruct = _gen_field_struct(arguments.field_args)
1072
- ValueStruct = cache.get_argument_struct(arguments.value_args)
1084
+ output_dtype = type_scalar_type(output_dtype)
1073
1085
 
1074
1086
  _notify_operator_usage(integrand, arguments.field_args)
1075
1087
 
1076
1088
  # Check if kernel exist in cache
1077
- field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
1078
- kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{field_names}"
1089
+ field_names = tuple((k, f.name) for k, f in arguments.field_args.items())
1090
+ kernel_suffix = ("itg", field_names, cache.pod_type_key(output_dtype), cache.pod_type_key(accumulate_dtype))
1079
1091
 
1080
1092
  if quadrature is not None:
1081
- kernel_suffix += quadrature.name
1093
+ kernel_suffix = (quadrature.name, *kernel_suffix)
1082
1094
 
1083
- kernel = cache.get_integrand_kernel(integrand=integrand, suffix=kernel_suffix, kernel_options=kernel_options)
1095
+ kernel, field_arg_values, value_struct_values = cache.get_integrand_kernel(
1096
+ integrand=integrand,
1097
+ suffix=kernel_suffix,
1098
+ kernel_options=kernel_options,
1099
+ )
1084
1100
  if kernel is not None:
1085
- return kernel, FieldStruct, ValueStruct
1101
+ return kernel, field_arg_values, value_struct_values
1102
+
1103
+ FieldStruct = _gen_field_struct(arguments.field_args)
1104
+ ValueStruct = cache.get_argument_struct(arguments.value_args)
1086
1105
 
1087
1106
  # Not found in cache, transform integrand and generate kernel
1088
1107
  _check_field_compat(integrand, arguments, domain)
@@ -1165,7 +1184,7 @@ def _generate_integrate_kernel(
1165
1184
  accumulate_dtype=accumulate_dtype,
1166
1185
  )
1167
1186
 
1168
- kernel = cache.get_integrand_kernel(
1187
+ kernel, _FieldStruct, _ValueStruct = cache.get_integrand_kernel(
1169
1188
  integrand=integrand,
1170
1189
  kernel_fn=integrate_kernel_fn,
1171
1190
  suffix=kernel_suffix,
@@ -1175,9 +1194,11 @@ def _generate_integrate_kernel(
1175
1194
  arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
1176
1195
  )
1177
1196
  ],
1197
+ FieldStruct=FieldStruct,
1198
+ ValueStruct=ValueStruct,
1178
1199
  )
1179
1200
 
1180
- return kernel, FieldStruct, ValueStruct
1201
+ return kernel, FieldStruct(), ValueStruct()
1181
1202
 
1182
1203
 
1183
1204
  def _generate_auxiliary_kernels(
@@ -1220,8 +1241,8 @@ def _launch_integrate_kernel(
1220
1241
  integrand: Integrand,
1221
1242
  kernel: wp.Kernel,
1222
1243
  auxiliary_kernels: List[Tuple[wp.Kernel, int]],
1223
- FieldStruct: wp.codegen.Struct,
1224
- ValueStruct: wp.codegen.Struct,
1244
+ field_arg_values: StructInstance,
1245
+ value_struct_values: StructInstance,
1225
1246
  domain: GeometryDomain,
1226
1247
  quadrature: Quadrature,
1227
1248
  test: Optional[TestField],
@@ -1243,12 +1264,11 @@ def _launch_integrate_kernel(
1243
1264
  if quadrature is not None:
1244
1265
  qp_arg = quadrature.arg_value(device=device)
1245
1266
 
1246
- field_arg_values = FieldStruct()
1247
1267
  for k, v in fields.items():
1248
1268
  if not isinstance(v, GeometryDomain):
1249
1269
  v.fill_eval_arg(getattr(field_arg_values, k), device=device)
1250
1270
 
1251
- value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
1271
+ cache.populate_argument_struct(value_struct_values, values, func_name=integrand.name)
1252
1272
 
1253
1273
  # Constant form
1254
1274
  if test is None and trial is None:
@@ -1257,14 +1277,13 @@ def _launch_integrate_kernel(
1257
1277
  raise RuntimeError("Output array must be of size at least 1")
1258
1278
  accumulate_array = output
1259
1279
  else:
1260
- accumulate_temporary = cache.borrow_temporary(
1280
+ accumulate_array = cache.borrow_temporary(
1261
1281
  shape=(1),
1262
1282
  device=device,
1263
1283
  dtype=accumulate_dtype,
1264
1284
  temporary_store=temporary_store,
1265
1285
  requires_grad=output is not None and output.requires_grad,
1266
1286
  )
1267
- accumulate_array = accumulate_temporary.array
1268
1287
 
1269
1288
  if output != accumulate_array or not add_to_output:
1270
1289
  accumulate_array.zero_()
@@ -1315,21 +1334,17 @@ def _launch_integrate_kernel(
1315
1334
  output_shape = (test.space_partition.node_count(), test.node_dof_count)
1316
1335
  else:
1317
1336
  raise RuntimeError(
1318
- f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1337
+ f"Incompatible output type {type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1319
1338
  )
1320
1339
 
1321
- output_temporary = cache.borrow_temporary(
1340
+ output = cache.borrow_temporary(
1322
1341
  temporary_store=temporary_store,
1323
1342
  shape=output_shape,
1324
1343
  dtype=output_dtype,
1325
1344
  device=device,
1326
1345
  )
1327
1346
 
1328
- output = output_temporary.array
1329
-
1330
1347
  else:
1331
- output_temporary = None
1332
-
1333
1348
  if output.shape[0] < test.space_partition.node_count():
1334
1349
  raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
1335
1350
 
@@ -1337,7 +1352,7 @@ def _launch_integrate_kernel(
1337
1352
  if type_size(output_dtype) != test.node_dof_count:
1338
1353
  if type_size(output_dtype) != 1:
1339
1354
  raise RuntimeError(
1340
- f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1355
+ f"Incompatible output type {type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1341
1356
  )
1342
1357
  if output.ndim != 2 and output.shape[1] != test.node_dof_count:
1343
1358
  raise RuntimeError(
@@ -1355,7 +1370,7 @@ def _launch_integrate_kernel(
1355
1370
  capacity=array.capacity,
1356
1371
  device=array.device,
1357
1372
  shape=(test.space_partition.node_count(), test.node_dof_count),
1358
- dtype=wp.types.type_scalar_type(output_dtype),
1373
+ dtype=type_scalar_type(output_dtype),
1359
1374
  grad=None if array.grad is None else as_2d_array(array.grad),
1360
1375
  )
1361
1376
 
@@ -1387,7 +1402,7 @@ def _launch_integrate_kernel(
1387
1402
 
1388
1403
  wp.launch(
1389
1404
  kernel=kernel,
1390
- dim=local_result.array.shape,
1405
+ dim=local_result.shape,
1391
1406
  inputs=[
1392
1407
  qp_arg,
1393
1408
  quadrature.element_index_arg_value(device),
@@ -1395,13 +1410,13 @@ def _launch_integrate_kernel(
1395
1410
  domain_elt_index_arg,
1396
1411
  field_arg_values,
1397
1412
  value_struct_values,
1398
- local_result.array,
1413
+ local_result,
1399
1414
  ],
1400
1415
  device=device,
1401
1416
  )
1402
1417
 
1403
1418
  if test.TAYLOR_DOF_COUNT == 0:
1404
- wp.utils.warn(
1419
+ warn(
1405
1420
  f"Test field is never evaluated in integrand '{integrand.name}', result will be zero",
1406
1421
  category=UserWarning,
1407
1422
  stacklevel=2,
@@ -1418,7 +1433,7 @@ def _launch_integrate_kernel(
1418
1433
  domain_elt_index_arg,
1419
1434
  test_arg,
1420
1435
  test.space.space_arg_value(device),
1421
- local_result.array,
1436
+ local_result,
1422
1437
  output_view,
1423
1438
  ],
1424
1439
  device=device,
@@ -1442,9 +1457,6 @@ def _launch_integrate_kernel(
1442
1457
  device=device,
1443
1458
  )
1444
1459
 
1445
- if output_temporary is not None:
1446
- return output_temporary.detach()
1447
-
1448
1460
  return output
1449
1461
 
1450
1462
  # Bilinear form
@@ -1475,8 +1487,6 @@ def _launch_integrate_kernel(
1475
1487
  triplet_rows = triplet_rows_temp.array
1476
1488
  triplet_values = triplet_values_temp.array
1477
1489
 
1478
- triplet_values.zero_()
1479
-
1480
1490
  if nodal:
1481
1491
  wp.launch(
1482
1492
  kernel=kernel,
@@ -1524,13 +1534,13 @@ def _launch_integrate_kernel(
1524
1534
  domain_elt_index_arg,
1525
1535
  field_arg_values,
1526
1536
  value_struct_values,
1527
- local_result.array,
1537
+ local_result,
1528
1538
  ],
1529
1539
  device=device,
1530
1540
  )
1531
1541
 
1532
1542
  if test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT == 0:
1533
- wp.utils.warn(
1543
+ warn(
1534
1544
  f"Test and/or trial fields are never evaluated in integrand '{integrand.name}', result will be zero",
1535
1545
  category=UserWarning,
1536
1546
  stacklevel=2,
@@ -1557,7 +1567,7 @@ def _launch_integrate_kernel(
1557
1567
  trial_partition_arg,
1558
1568
  trial_topology_arg,
1559
1569
  trial.space.space_arg_value(device),
1560
- local_result.array,
1570
+ local_result,
1561
1571
  triplet_rows,
1562
1572
  triplet_cols,
1563
1573
  triplet_values,
@@ -1626,20 +1636,12 @@ def _launch_integrate_kernel(
1626
1636
 
1627
1637
 
1628
1638
  def _pick_assembly_strategy(
1629
- assembly: Optional[str], nodal: bool, operators: Dict[str, Set[Operator]], arguments: IntegrandArguments
1639
+ assembly: Optional[str], operators: Dict[str, Set[Operator]], arguments: IntegrandArguments
1630
1640
  ):
1631
1641
  if assembly is not None:
1632
1642
  if assembly not in ("generic", "nodal", "dispatch"):
1633
1643
  raise ValueError(f"Invalid assembly strategy'{assembly}'")
1634
1644
  return assembly
1635
- elif nodal is not None:
1636
- wp.utils.warn(
1637
- "'nodal' argument of `warp.fem.integrate` is deprecated and will be removed in a future version. Please use `assembly='nodal'` instead.",
1638
- category=DeprecationWarning,
1639
- stacklevel=2,
1640
- )
1641
- if nodal:
1642
- return "nodal"
1643
1645
 
1644
1646
  test_operators = operators.get(arguments.test_name, set())
1645
1647
  trial_operators = operators.get(arguments.trial_name, set())
@@ -1655,7 +1657,6 @@ def integrate(
1655
1657
  integrand: Integrand,
1656
1658
  domain: Optional[GeometryDomain] = None,
1657
1659
  quadrature: Optional[Quadrature] = None,
1658
- nodal: Optional[bool] = None,
1659
1660
  fields: Optional[Dict[str, FieldLike]] = None,
1660
1661
  values: Optional[Dict[str, Any]] = None,
1661
1662
  accumulate_dtype: type = wp.float64,
@@ -1675,7 +1676,6 @@ def integrate(
1675
1676
  integrand: Form to be integrated, must have :func:`integrand` decorator
1676
1677
  domain: Integration domain. If None, deduced from fields
1677
1678
  quadrature: Quadrature formula. If None, deduced from domain and fields degree.
1678
- nodal: Deprecated. Use the equivalent assembly="nodal" instead.
1679
1679
  fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
1680
1680
  values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
1681
1681
  temporary_store: shared pool from which to allocate temporary arrays
@@ -1738,9 +1738,9 @@ def integrate(
1738
1738
  _find_integrand_operators(integrand, arguments.field_args)
1739
1739
 
1740
1740
  if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
1741
- wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
1741
+ warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
1742
1742
 
1743
- assembly = _pick_assembly_strategy(assembly, nodal, arguments=arguments, operators=integrand.operators)
1743
+ assembly = _pick_assembly_strategy(assembly, arguments=arguments, operators=integrand.operators)
1744
1744
  # print("assembly for ", integrand.name, ":", strategy)
1745
1745
 
1746
1746
  if assembly == "dispatch":
@@ -1770,7 +1770,7 @@ def integrate(
1770
1770
  raise ValueError("Incompatible integration and quadrature domain")
1771
1771
 
1772
1772
  # Canonicalize types
1773
- accumulate_dtype = wp.types.type_to_warp(accumulate_dtype)
1773
+ accumulate_dtype = type_to_warp(accumulate_dtype)
1774
1774
  if output is not None:
1775
1775
  if isinstance(output, BsrMatrix):
1776
1776
  output_dtype = output.scalar_type
@@ -1779,9 +1779,9 @@ def integrate(
1779
1779
  elif output_dtype is None:
1780
1780
  output_dtype = accumulate_dtype
1781
1781
  else:
1782
- output_dtype = wp.types.type_to_warp(output_dtype)
1782
+ output_dtype = type_to_warp(output_dtype)
1783
1783
 
1784
- kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
1784
+ kernel, field_arg_values, value_struct_values = _generate_integrate_kernel(
1785
1785
  integrand=integrand,
1786
1786
  domain=domain,
1787
1787
  quadrature=quadrature,
@@ -1806,8 +1806,8 @@ def integrate(
1806
1806
  integrand=integrand,
1807
1807
  kernel=kernel,
1808
1808
  auxiliary_kernels=auxiliary_kernels,
1809
- FieldStruct=FieldStruct,
1810
- ValueStruct=ValueStruct,
1809
+ field_arg_values=field_arg_values,
1810
+ value_struct_values=value_struct_values,
1811
1811
  domain=domain,
1812
1812
  quadrature=quadrature,
1813
1813
  test=test,
@@ -1827,14 +1827,14 @@ def integrate(
1827
1827
  def get_interpolate_to_field_function(
1828
1828
  integrand_func: wp.Function,
1829
1829
  domain: GeometryDomain,
1830
- FieldStruct: wp.codegen.Struct,
1831
- ValueStruct: wp.codegen.Struct,
1830
+ FieldStruct: Struct,
1831
+ ValueStruct: Struct,
1832
1832
  dest: FieldRestriction,
1833
1833
  ):
1834
1834
  zero_value = type_zero_element(dest.space.dtype)
1835
1835
 
1836
1836
  def interpolate_to_field_fn(
1837
- local_node_index: int,
1837
+ partition_node_index: int,
1838
1838
  domain_arg: domain.ElementArg,
1839
1839
  domain_index_arg: domain.ElementIndexArg,
1840
1840
  dest_node_arg: dest.space_restriction.NodeArg,
@@ -1842,7 +1842,6 @@ def get_interpolate_to_field_function(
1842
1842
  fields: FieldStruct,
1843
1843
  values: ValueStruct,
1844
1844
  ):
1845
- partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1846
1845
  element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)
1847
1846
 
1848
1847
  test_dof_index = NULL_DOF_INDEX
@@ -1894,8 +1893,8 @@ def get_interpolate_to_field_function(
1894
1893
  def get_interpolate_to_field_kernel(
1895
1894
  interpolate_to_field_fn: wp.Function,
1896
1895
  domain: GeometryDomain,
1897
- FieldStruct: wp.codegen.Struct,
1898
- ValueStruct: wp.codegen.Struct,
1896
+ FieldStruct: Struct,
1897
+ ValueStruct: Struct,
1899
1898
  dest: FieldRestriction,
1900
1899
  ):
1901
1900
  @wp.func
@@ -1932,13 +1931,15 @@ def get_interpolate_to_field_kernel(
1932
1931
  ):
1933
1932
  local_node_index = wp.tid()
1934
1933
 
1934
+ partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1935
+ if partition_node_index == NULL_NODE_INDEX:
1936
+ return
1937
+
1935
1938
  val_sum, vol_sum = interpolate_to_field_fn(
1936
1939
  local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
1937
1940
  )
1938
1941
 
1939
1942
  if vol_sum > 0.0:
1940
- partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1941
-
1942
1943
  # Grab first element containing node; there must be at least one since vol_sum != 0
1943
1944
  element_index, node_index_in_element = _find_node_in_element(
1944
1945
  domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, partition_node_index
@@ -1959,8 +1960,8 @@ def get_interpolate_at_quadrature_kernel(
1959
1960
  integrand_func: wp.Function,
1960
1961
  domain: GeometryDomain,
1961
1962
  quadrature: Quadrature,
1962
- FieldStruct: wp.codegen.Struct,
1963
- ValueStruct: wp.codegen.Struct,
1963
+ FieldStruct: Struct,
1964
+ ValueStruct: Struct,
1964
1965
  value_type: type,
1965
1966
  ):
1966
1967
  def interpolate_at_quadrature_nonvalued_kernel_fn(
@@ -1978,6 +1979,8 @@ def get_interpolate_at_quadrature_kernel(
1978
1979
  return
1979
1980
 
1980
1981
  element_index = domain.element_index(domain_index_arg, domain_element_index)
1982
+ if element_index == NULL_ELEMENT_INDEX:
1983
+ return
1981
1984
 
1982
1985
  test_dof_index = NULL_DOF_INDEX
1983
1986
  trial_dof_index = NULL_DOF_INDEX
@@ -2004,6 +2007,8 @@ def get_interpolate_at_quadrature_kernel(
2004
2007
  return
2005
2008
 
2006
2009
  element_index = domain.element_index(domain_index_arg, domain_element_index)
2010
+ if element_index == NULL_ELEMENT_INDEX:
2011
+ return
2007
2012
 
2008
2013
  test_dof_index = NULL_DOF_INDEX
2009
2014
  trial_dof_index = NULL_DOF_INDEX
@@ -2022,8 +2027,8 @@ def get_interpolate_jacobian_at_quadrature_kernel(
2022
2027
  integrand_func: wp.Function,
2023
2028
  domain: GeometryDomain,
2024
2029
  quadrature: Quadrature,
2025
- FieldStruct: wp.codegen.Struct,
2026
- ValueStruct: wp.codegen.Struct,
2030
+ FieldStruct: Struct,
2031
+ ValueStruct: Struct,
2027
2032
  trial: TrialField,
2028
2033
  value_size: int,
2029
2034
  value_type: type,
@@ -2046,11 +2051,13 @@ def get_interpolate_jacobian_at_quadrature_kernel(
2046
2051
  ):
2047
2052
  qp_eval_index, trial_node, trial_dof = wp.tid()
2048
2053
  domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
2049
-
2050
2054
  if domain_element_index == NULL_ELEMENT_INDEX:
2051
2055
  return
2052
2056
 
2053
2057
  element_index = domain.element_index(domain_index_arg, domain_element_index)
2058
+ if element_index == NULL_ELEMENT_INDEX:
2059
+ return
2060
+
2054
2061
  if qp >= quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index):
2055
2062
  return
2056
2063
 
@@ -2090,8 +2097,8 @@ def get_interpolate_jacobian_at_quadrature_kernel(
2090
2097
  def get_interpolate_free_kernel(
2091
2098
  integrand_func: wp.Function,
2092
2099
  domain: GeometryDomain,
2093
- FieldStruct: wp.codegen.Struct,
2094
- ValueStruct: wp.codegen.Struct,
2100
+ FieldStruct: Struct,
2101
+ ValueStruct: Struct,
2095
2102
  value_type: type,
2096
2103
  ):
2097
2104
  def interpolate_free_nonvalued_kernel_fn(
@@ -2144,31 +2151,31 @@ def _generate_interpolate_kernel(
2144
2151
  arguments: IntegrandArguments,
2145
2152
  kernel_options: Optional[Dict[str, Any]] = None,
2146
2153
  ) -> wp.Kernel:
2147
- # Generate field struct
2148
- FieldStruct = _gen_field_struct(arguments.field_args)
2149
- ValueStruct = cache.get_argument_struct(arguments.value_args)
2150
-
2151
2154
  _notify_operator_usage(integrand, arguments.field_args)
2152
2155
 
2153
2156
  # Check if kernel exist in cache
2154
- field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
2157
+ field_names = tuple((k, f.name) for k, f in arguments.field_args.items())
2155
2158
  if isinstance(dest, FieldRestriction):
2156
- kernel_suffix = f"_itp_{field_names}_{dest.domain.name}_{dest.space_restriction.space_partition.name}"
2159
+ kernel_suffix = ("itp", *field_names, dest.domain.name, dest.space_restriction.space_partition.name)
2157
2160
  else:
2158
2161
  dest_dtype = dest.dtype if dest else None
2159
- type_str = wp.types.get_type_code(dest_dtype) if dest_dtype else ""
2162
+ type_str = cache.pod_type_key(dest_dtype) if dest_dtype else ""
2160
2163
  if quadrature is None:
2161
- kernel_suffix = f"_itp_{field_names}_{domain.name}_{type_str}"
2164
+ kernel_suffix = ("itp", *field_names, domain.name, type_str)
2162
2165
  else:
2163
- kernel_suffix = f"_itp_{field_names}_{domain.name}_{quadrature.name}_{type_str}"
2166
+ kernel_suffix = ("itp", *field_names, domain.name, quadrature.name, type_str)
2164
2167
 
2165
- kernel = cache.get_integrand_kernel(
2168
+ kernel, field_arg_values, value_struct_values = cache.get_integrand_kernel(
2166
2169
  integrand=integrand,
2167
2170
  suffix=kernel_suffix,
2168
2171
  kernel_options=kernel_options,
2169
2172
  )
2170
2173
  if kernel is not None:
2171
- return kernel, FieldStruct, ValueStruct
2174
+ return kernel, field_arg_values, value_struct_values
2175
+
2176
+ # Generate field struct
2177
+ FieldStruct = _gen_field_struct(arguments.field_args)
2178
+ ValueStruct = cache.get_argument_struct(arguments.value_args)
2172
2179
 
2173
2180
  # Not found in cache, transform integrand and generate kernel
2174
2181
  _check_field_compat(integrand, arguments, domain)
@@ -2235,7 +2242,7 @@ def _generate_interpolate_kernel(
2235
2242
  ValueStruct=ValueStruct,
2236
2243
  )
2237
2244
 
2238
- kernel = cache.get_integrand_kernel(
2245
+ kernel, _FieldStruct, _ValueStruct = cache.get_integrand_kernel(
2239
2246
  integrand=integrand,
2240
2247
  kernel_fn=interpolate_kernel_fn,
2241
2248
  suffix=kernel_suffix,
@@ -2245,16 +2252,18 @@ def _generate_interpolate_kernel(
2245
2252
  arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
2246
2253
  )
2247
2254
  ],
2255
+ FieldStruct=FieldStruct,
2256
+ ValueStruct=ValueStruct,
2248
2257
  )
2249
2258
 
2250
- return kernel, FieldStruct, ValueStruct
2259
+ return kernel, FieldStruct(), ValueStruct()
2251
2260
 
2252
2261
 
2253
2262
  def _launch_interpolate_kernel(
2254
2263
  integrand: Integrand,
2255
2264
  kernel: wp.kernel,
2256
- FieldStruct: wp.codegen.Struct,
2257
- ValueStruct: wp.codegen.Struct,
2265
+ field_arg_values: StructInstance,
2266
+ value_struct_values: StructInstance,
2258
2267
  domain: GeometryDomain,
2259
2268
  dest: Optional[Union[FieldRestriction, wp.array]],
2260
2269
  quadrature: Optional[Quadrature],
@@ -2270,12 +2279,10 @@ def _launch_interpolate_kernel(
2270
2279
  elt_arg = domain.element_arg_value(device=device)
2271
2280
  elt_index_arg = domain.element_index_arg_value(device=device)
2272
2281
 
2273
- field_arg_values = FieldStruct()
2274
2282
  for k, v in fields.items():
2275
2283
  if not isinstance(v, GeometryDomain):
2276
2284
  v.fill_eval_arg(getattr(field_arg_values, k), device=device)
2277
-
2278
- value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
2285
+ cache.populate_argument_struct(value_struct_values, values, func_name=integrand.name)
2279
2286
 
2280
2287
  if isinstance(dest, FieldRestriction):
2281
2288
  dest_node_arg = dest.space_restriction.node_arg_value(device=device)
@@ -2313,7 +2320,7 @@ def _launch_interpolate_kernel(
2313
2320
  qp_index_count = quadrature.total_point_count()
2314
2321
 
2315
2322
  if qp_eval_count != qp_index_count:
2316
- wp.utils.warn(
2323
+ warn(
2317
2324
  f"Quadrature used for interpolation of {integrand.name} has different number of evaluation and indexed points, this may lead to incorrect results",
2318
2325
  category=UserWarning,
2319
2326
  stacklevel=2,
@@ -2353,7 +2360,6 @@ def _launch_interpolate_kernel(
2353
2360
  triplet_rows = triplet_rows_temp.array
2354
2361
  triplet_values = triplet_values_temp.array
2355
2362
  triplet_rows.fill_(-1)
2356
- triplet_values.zero_()
2357
2363
 
2358
2364
  trial_partition_arg = trial.space_partition.partition_arg_value(device)
2359
2365
  trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
@@ -2470,9 +2476,9 @@ def interpolate(
2470
2476
  _find_integrand_operators(integrand, arguments.field_args)
2471
2477
 
2472
2478
  if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
2473
- wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
2479
+ warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
2474
2480
 
2475
- kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
2481
+ kernel, field_struct, value_struct = _generate_interpolate_kernel(
2476
2482
  integrand=integrand,
2477
2483
  domain=domain,
2478
2484
  dest=dest,
@@ -2484,8 +2490,8 @@ def interpolate(
2484
2490
  return _launch_interpolate_kernel(
2485
2491
  integrand=integrand,
2486
2492
  kernel=kernel,
2487
- FieldStruct=FieldStruct,
2488
- ValueStruct=ValueStruct,
2493
+ field_arg_values=field_struct,
2494
+ value_struct_values=value_struct,
2489
2495
  domain=domain,
2490
2496
  dest=dest,
2491
2497
  quadrature=quadrature,