warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/fem/integrate.py ADDED
@@ -0,0 +1,2335 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ast
17
+ import inspect
18
+ import textwrap
19
+ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
20
+
21
+ import warp as wp
22
+ from warp.codegen import get_annotations
23
+ from warp.fem import cache
24
+ from warp.fem.domain import GeometryDomain
25
+ from warp.fem.field import (
26
+ DiscreteField,
27
+ FieldLike,
28
+ FieldRestriction,
29
+ GeometryField,
30
+ LocalTestField,
31
+ LocalTrialField,
32
+ TestField,
33
+ TrialField,
34
+ make_restriction,
35
+ )
36
+ from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
37
+ from warp.fem.linalg import array_axpy, basis_coefficient
38
+ from warp.fem.operator import Integrand, Operator, at_node, integrand
39
+ from warp.fem.quadrature import Quadrature, RegularQuadrature
40
+ from warp.fem.types import (
41
+ NULL_DOF_INDEX,
42
+ NULL_ELEMENT_INDEX,
43
+ NULL_NODE_INDEX,
44
+ OUTSIDE,
45
+ Coords,
46
+ DofIndex,
47
+ Domain,
48
+ Field,
49
+ Sample,
50
+ make_free_sample,
51
+ )
52
+ from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
53
+ from warp.types import type_length
54
+ from warp.utils import array_cast
55
+
56
+
57
def _resolve_path(func, node):
    """
    Resolves variable and path from ast node/attribute (adapted from warp.codegen)

    Walks an ``ast.Attribute`` chain (e.g. ``fem.grad``) down to its root ``ast.Name``
    and evaluates the resulting dotted path in the scope of ``func``, i.e. its globals
    overlaid with its closure variables.

    Args:
        func: Python function whose globals and closure provide the lookup scope.
        node: AST expression to resolve, typically the ``func`` member of an ``ast.Call``.

    Returns:
        A ``(value, path)`` tuple. ``path`` is the list of dotted-path components
        extracted from ``node``. ``value`` is the object the path evaluates to, or
        ``None`` when it cannot be resolved: no path, a chain not rooted in a plain
        name, an unknown name or attribute, or an empty closure cell.
    """

    modules = []

    while isinstance(node, ast.Attribute):
        modules.append(node.attr)
        node = node.value

    rooted_in_name = isinstance(node, ast.Name)
    if rooted_in_name:
        modules.append(node.id)

    # reverse list since ast presents it in backward order
    path = [*reversed(modules)]

    # Nothing to resolve. Also bail out when the chain hangs off a non-name expression
    # (e.g. ``g().attr``): evaluating the truncated path would look up an unrelated symbol.
    if len(path) == 0 or not rooted_in_name:
        return None, path

    # Closure variables let integrands defined inside a function refer to variables
    # declared in that function; they take precedence over module globals.
    # An empty cell (variable deleted or not yet assigned) would raise ValueError on
    # access, so it is mapped to None, mirroring warp.codegen's handling.
    captured_vars = {}
    for name, cell in zip(func.__code__.co_freevars, func.__closure__ or ()):
        try:
            captured_vars[name] = cell.cell_contents
        except ValueError:
            captured_vars[name] = None

    vars_dict = {**func.__globals__, **captured_vars}

    # try and evaluate object path
    try:
        return eval(".".join(path), vars_dict), path
    except (NameError, AttributeError):
        return None, path
91
+
92
+
93
class IntegrandVisitor(ast.NodeTransformer):
    """
    Base AST visitor for integrand source code that keeps track of field and domain symbols.

    Starting from the integrand's Field/Domain arguments, it follows simple variable
    assignments (``g = f``) and operator calls whose result is itself a field
    (``op(f)`` when ``op.field_result`` is set). It dispatches two hooks that
    subclasses must implement:

    - ``_process_operator_call`` for every operator applied to a tracked symbol, either
      through the ``op(field, ...)`` syntax or the ``field(...)`` default-operator shortcut;
    - ``_process_integrand_call`` for every call to another :class:`Integrand`.
    """

    class FieldInfo(NamedTuple):
        # Field (or domain) object bound to the tracked symbol
        field: FieldLike
        # Abstract annotation type from the integrand signature (e.g. ``Field`` or ``Domain``)
        abstract_type: type
        # Concrete argument struct type used when specializing the integrand
        concrete_type: type
        # Name of the integrand argument the symbol derives from; operator results append ".<operator name>"
        root_arg_name: str

    def __init__(
        self,
        integrand: Integrand,
        field_info: Dict[str, FieldInfo],
    ):
        self._integrand = integrand
        # Symbol name -> FieldInfo. Copied so assignments seen while visiting do not leak into the caller's dict.
        self._field_symbols = field_info.copy()
        # AST call node -> FieldInfo, for operator calls whose result is itself a field
        self._field_nodes = {}

    @staticmethod
    def _build_field_info(integrand: Integrand, field_args: Dict[str, FieldLike]):
        """Build the initial symbol table, mapping each argument name in ``field_args`` to its FieldInfo."""

        def get_concrete_type(field: Union[FieldLike, Domain]):
            # Fields are specialized with their ElementEvalArg struct, other objects (domains) with ElementArg
            if isinstance(field, FieldLike):
                return field.ElementEvalArg
            return field.ElementArg

        return {
            name: IntegrandVisitor.FieldInfo(
                field=field,
                abstract_type=integrand.argspec.annotations[name],
                concrete_type=get_concrete_type(field),
                root_arg_name=name,
            )
            for name, field in field_args.items()
        }

    def _get_field_info(self, node: ast.expr):
        """Return the FieldInfo associated with an expression node, or None if it does not denote a field/domain."""
        field_info = self._field_nodes.get(node)
        if field_info is None and isinstance(node, ast.Name):
            field_info = self._field_symbols.get(node.id)

        return field_info

    def _process_operator_call(
        self, call: ast.Call, callee: Union[str, Operator], operator: Operator, field_info: FieldInfo
    ):
        """Hook called for each operator applied to a tracked field; must be implemented by subclasses."""
        raise NotImplementedError

    def _process_integrand_call(self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, FieldInfo]):
        """Hook called for each call to another integrand; must be implemented by subclasses."""
        raise NotImplementedError

    def visit_Call(self, call: ast.Call):
        # Visit arguments (and nested calls) first so that field-valued sub-expressions are known
        call = self.generic_visit(call)

        callee = getattr(call.func, "id", None)
        if callee in self._field_symbols:
            # Shortcut for evaluating fields as f(x...): dispatch to the field type's default call operator
            field_info = self._field_symbols[callee]
            default_operator = field_info.abstract_type.call_operator

            self._process_operator_call(call, callee, default_operator, field_info)

            return call

        func, _ = _resolve_path(self._integrand.func, call.func)

        if isinstance(func, Operator) and len(call.args) > 0:
            # Evaluating operators as op(field, x, ...)
            field_info = self._get_field_info(call.args[0])
            if field_info is not None:
                self._process_operator_call(call, func, func, field_info)

                if func.field_result:
                    # The operator returns a field itself; remember it so it can be used as an argument later
                    res = func.field_result(field_info.field)
                    self._field_nodes[call] = IntegrandVisitor.FieldInfo(
                        field=res[0],
                        abstract_type=res[1],
                        concrete_type=res[2],
                        root_arg_name=f"{field_info.root_arg_name}.{func.name}",
                    )

        if isinstance(func, Integrand):
            callee_field_args = self._get_callee_field_args(func, call.args)
            self._process_integrand_call(call, func, callee_field_args)

        return call

    def visit_Assign(self, node: ast.Assign):
        node = self.generic_visit(node)

        # Check if we're assigning a field; if so, the target name now denotes that field
        src_field_info = self._get_field_info(node.value)
        if src_field_info is not None:
            if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
                raise NotImplementedError("warp.fem Fields and Domains may only be assigned to simple variables")

            self._field_symbols[node.targets[0].id] = src_field_info

        return node

    def _get_callee_field_args(self, callee: Integrand, args: List[ast.AST]):
        """
        Match the field/domain arguments of a call site to the Field/Domain parameters of ``callee``.

        Field-like call-site arguments are assigned, in positional order, to the callee
        parameters annotated with ``Field`` or ``Domain``.

        Raises:
            TypeError: if a passed argument's abstract type does not match the parameter
                annotation, or if fewer field/domain arguments are passed than the callee expects.
        """

        # Field-like arguments at the call site, in positional order
        call_site_field_args = (info for info in map(self._get_field_info, args) if info is not None)

        # Pass to callee in same order
        callee_field_args = {}
        for arg in callee.argspec.args:
            arg_type = callee.argspec.annotations[arg]
            if arg_type not in (Field, Domain):
                continue

            passed_field_info = next(call_site_field_args, None)
            if passed_field_info is None:
                raise TypeError(
                    f"Missing {arg_type.__name__} argument '{arg}' when calling '{callee.name}'"
                )
            if passed_field_info.abstract_type != arg_type:
                raise TypeError(
                    f"Attempting to pass a {passed_field_info.abstract_type.__name__} to argument '{arg}' of '{callee.name}' expecting a {arg_type.__name__}"
                )
            callee_field_args[arg] = passed_field_info

        return callee_field_args
209
+
210
+
211
class IntegrandOperatorParser(IntegrandVisitor):
    """
    Read-only pass that reports every operator applied to a field in an integrand,
    recursing into the integrands it calls.

    For each operator call found, ``callback(field_info, operator)`` is invoked.
    """

    def __init__(
        self,
        integrand: Integrand,
        field_info: Dict[str, IntegrandVisitor.FieldInfo],
        callback: Optional[Callable],
    ):
        super().__init__(integrand, field_info)
        self._operator_callback = callback

    def _process_operator_call(
        self, call: ast.Call, callee: Union[str, Operator], operator: Operator, field_info: IntegrandVisitor.FieldInfo
    ):
        # Without a callback there is nothing to report
        if self._operator_callback is not None:
            self._operator_callback(field_info, operator)

    def _process_integrand_call(
        self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
    ):
        # Recurse into the called integrand, with its Field/Domain parameters bound to the call-site fields
        callee_parser = IntegrandOperatorParser(callee, callee_field_args, callback=self._operator_callback)
        callee_parser._apply()

    def _apply(self):
        """Parse the integrand's Python source and visit it."""
        source = textwrap.dedent(inspect.getsource(self._integrand.func))
        tree = ast.parse(source)
        self.visit(tree)

    @staticmethod
    def apply(
        integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Optional[Callable] = None
    ) -> None:
        """
        Report the operators applied to the fields of ``integrand``.

        Args:
            integrand: Integrand to inspect.
            field_args: Field (or domain) objects bound to the integrand's Field/Domain arguments.
            operator_callback: Called as ``operator_callback(field_info, operator)`` for each operator call.
                If ``None``, the integrand is still parsed (and argument errors raised) but nothing is reported.
        """
        field_info = IntegrandVisitor._build_field_info(integrand, field_args)
        IntegrandOperatorParser(integrand, field_info, callback=operator_callback)._apply()
239
+
240
+
241
class IntegrandTransformer(IntegrandVisitor):
    """
    Rewrites an integrand's AST so that operator and integrand calls dispatch to
    implementations resolved for the concrete field types, and builds the resulting
    specialized Warp function.
    """

    def _process_operator_call(
        self, call: ast.Call, callee: Union[str, Operator], operator: Operator, field_info: IntegrandVisitor.FieldInfo
    ):
        field = field_info.field

        try:
            # Retrieve the function pointer corresponding to the operator implementation for the field type
            pointer = operator.resolver(field)
            if not isinstance(pointer, wp.context.Function):
                raise NotImplementedError(operator.resolver.__name__)

        except (AttributeError, NotImplementedError) as e:
            raise TypeError(
                f"Operator {operator.func.__name__} is not defined for {field_info.abstract_type.__name__} {field.name}"
            ) from e

        # Update the ast Call node to use the new function pointer
        call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())

        # Save the pointer as an attribute that can be accessed from the calling scope.
        # For usual operator call syntax, we can use the operator itself, but for the
        # shortcut default operator syntax, we store it on the callee's concrete type
        if isinstance(callee, Operator):
            setattr(callee, pointer.key, pointer)
        else:
            setattr(field_info.concrete_type, pointer.key, pointer)

            # Shortcut syntax f(x...): the field itself must also be passed as first argument
            call.args = [ast.Name(id=callee, ctx=ast.Load()), *call.args]

    def _process_integrand_call(
        self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
    ):
        # Specialize the called integrand for the call-site fields, then redirect the call to that specialization
        specialized = IntegrandTransformer(callee, callee_field_args)._apply()
        call.func = ast.Attribute(
            value=call.func,
            attr=specialized.key,
            ctx=ast.Load(),
        )

    def _apply(self) -> wp.Function:
        """Build (or fetch from cache) the integrand function specialized for the tracked field types."""
        field_info = self._field_symbols

        # Specialize field argument types: abstract Field/Domain annotations become concrete struct types
        annotations = self._integrand.argspec.annotations.copy()
        annotations.update({name: info.concrete_type for name, info in field_info.items()})

        # The specialized function is keyed by the names of the fields it was generated for
        suffix = "_".join(info.field.name for info in field_info.values())
        func = cache.get_integrand_function(
            integrand=self._integrand,
            suffix=suffix,
            annotations=annotations,
            code_transformers=[self],
        )

        # Expose the specialization on the integrand so rewritten call sites `integrand.<key>(...)` resolve it
        setattr(self._integrand, func.key, func)

        return func

    @staticmethod
    def apply(integrand: Integrand, field_args: Dict[str, FieldLike]) -> wp.Function:
        """Return ``integrand`` specialized for the given Field/Domain argument objects."""
        field_info = IntegrandVisitor._build_field_info(integrand, field_args)
        return IntegrandTransformer(integrand, field_info)._apply()
310
+
311
+
312
class IntegrandArguments(NamedTuple):
    """Classification of an integrand's arguments by their annotated type."""

    # Objects bound to the integrand's field arguments, keyed by argument name
    field_args: Dict[str, Union[FieldLike, GeometryDomain]]
    # Remaining (plain value) arguments, mapped to their annotated type
    value_args: Dict[str, Any]
    # Name of the argument annotated with Domain, or None if there is none
    domain_name: Optional[str]
    # Name of the argument annotated with Sample, or None if there is none
    sample_name: Optional[str]
    # Name of the field argument bound to a TestField, or None if there is none
    test_name: Optional[str]
    # Name of the field argument bound to a TrialField, or None if there is none
    trial_name: Optional[str]
319
+
320
+
321
def _parse_integrand_arguments(
    integrand: Integrand,
    fields: Dict[str, FieldLike],
):
    """
    Classify the arguments of ``integrand`` according to their annotated types.

    Args:
        integrand: Integrand whose ``argspec`` is inspected.
        fields: Field objects supplied by the caller, keyed by integrand argument name.

    Returns:
        An ``IntegrandArguments`` tuple. Field arguments are mapped to the supplied
        objects, other value arguments to their annotated type. The domain, sample,
        test and trial argument names are ``None`` when the integrand has no such argument.

    Raises:
        ValueError: if a Field argument is missing from ``fields`` or is not a ``FieldLike``,
            if more than one test or trial field is passed, or if ``fields`` provides a value
            for a Domain, Sample or non-Field argument.
        SyntaxError: if the integrand declares more than one Domain or Sample argument.
    """
    field_args = {}
    value_args = {}

    domain_name = None
    sample_name = None
    test_name = None
    trial_name = None

    argspec = integrand.argspec
    for arg in argspec.args:
        arg_type = argspec.annotations[arg]
        if arg_type == Field:
            try:
                field = fields[arg]
            except KeyError as err:
                raise ValueError(f"Missing field for argument '{arg}' of integrand '{integrand.name}'") from err
            if not isinstance(field, FieldLike):
                raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
            if isinstance(field, TestField):
                if test_name is not None:
                    raise ValueError(f"More than one test field argument: '{test_name}' and '{arg}'")
                test_name = arg
            elif isinstance(field, TrialField):
                if trial_name is not None:
                    raise ValueError(f"More than one trial field argument: '{trial_name}' and '{arg}'")
                trial_name = arg
            field_args[arg] = field
        elif arg_type == Domain:
            if domain_name is not None:
                raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Domain")
            if arg in fields:
                raise ValueError(
                    f"Domain argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
                )
            domain_name = arg
        elif arg_type == Sample:
            if sample_name is not None:
                raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Sample")
            if arg in fields:
                raise ValueError(
                    f"Sample argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
                )
            sample_name = arg
        else:
            # Plain value argument: only its type is recorded, values are provided at launch
            if arg in fields:
                raise ValueError(
                    f"Cannot pass a field argument to '{arg}' of '{integrand.name}' which is not of type 'Field'"
                )
            value_args[arg] = arg_type

    return IntegrandArguments(
        field_args=field_args,
        value_args=value_args,
        domain_name=domain_name,
        sample_name=sample_name,
        test_name=test_name,
        trial_name=trial_name,
    )
377
+
378
+
379
def _check_field_compat(integrand: Integrand, arguments: IntegrandArguments, domain: GeometryDomain):
    """Check that geometry-bound field arguments live on the integration domain.

    Only ``GeometryField`` arguments are checked, and only when ``domain`` is not None.
    The ``integrand`` parameter is not used by the check.

    Raises:
        ValueError: If a field is defined on a different geometry, or on a different element kind
            (cells vs. sides), than the integration domain.
    """
    for name, field in arguments.field_args.items():
        if not isinstance(field, GeometryField) or domain is None:
            continue

        if field.geometry != domain.geometry:
            raise ValueError(f"Field '{name}' must be defined on the same geometry as the integration domain")

        if field.element_kind != domain.element_kind:
            raise ValueError(
                f"Field '{name}' is not defined on the same kind of elements (cells or sides) as the integration domain. Maybe a forgotten `.trace()`?"
            )
389
+
390
+
391
def _find_integrand_operators(integrand: Integrand, field_args: Dict[str, FieldLike]):
    """Populate ``integrand.operators`` with the set of operators applied to each field argument.

    The result maps each root field argument name to the set of operators used on it. It is computed
    only once per integrand, because it does not depend on the concrete field types.
    """
    if integrand.operators is not None:
        return

    used_operators = {}

    def record_operator(field: IntegrandVisitor.FieldInfo, op: Operator):
        used_operators.setdefault(field.root_arg_name, set()).add(op)

    IntegrandOperatorParser.apply(integrand, field_args, operator_callback=record_operator)

    integrand.operators = used_operators
407
+
408
+
409
def _notify_operator_usage(
    integrand: Integrand,
    field_args: Dict[str, FieldLike],
):
    """Tell each bound field which operators the integrand applies to it.

    Relies on ``integrand.operators`` having been populated beforehand. Operator entries whose
    argument name is not in ``field_args`` are skipped.
    """
    for arg_name, ops in integrand.operators.items():
        if arg_name not in field_args:
            continue
        field_args[arg_name].notify_operator_usage(ops)
417
+
418
+
419
def _gen_field_struct(field_args: Dict[str, FieldLike]):
    """Build (or fetch from cache) a Warp struct holding the evaluation arguments of all fields.

    The struct has one member per entry of ``field_args``, typed with that field's ``EvalArg`` struct
    and default-initialized with an ``EvalArg()`` instance. ``GeometryDomain`` entries are skipped.
    """

    class Fields:
        pass

    member_types = get_annotations(Fields)

    for field_name, field in field_args.items():
        if isinstance(field, GeometryDomain):
            continue
        setattr(Fields, field_name, field.EvalArg())
        member_types[field_name] = field.EvalArg

    try:
        Fields.__annotations__ = member_types
    except AttributeError:
        # NOTE(review): ``Fields.__dict__`` is a read-only mappingproxy, so this fallback looks like it
        # would raise as well; kept as-is.
        Fields.__dict__.__annotations__ = member_types

    # Cache suffix built from member names and their struct class names
    suffix = "_".join(f"{member}_{struct.cls.__qualname__}" for member, struct in member_types.items())

    return cache.get_struct(Fields, suffix=suffix)
439
+
440
+
441
+ def _get_trial_arg():
442
+ pass
443
+
444
+
445
+ def _get_test_arg():
446
+ pass
447
+
448
+
449
class PassFieldArgsToIntegrand(ast.NodeTransformer):
    """AST transformer that wires the integrand call inside a generated integration kernel.

    Every call to ``<func_name>(...)`` has its arguments replaced by one expression per integrand
    parameter, in the order given by ``arg_names``:

    - the Domain parameter receives ``<domain_var_name>``;
    - the Sample parameter receives ``<sample_var_name>``;
    - each field parameter receives ``<func_name>.<field_wrappers_attr>.<name>(<domain_var_name>, <fields_var_name>.<name>)``;
    - every other (value) parameter receives ``<values_var_name>.<name>``.

    Calls to the ``_get_test_arg()`` / ``_get_trial_arg()`` markers are replaced with
    ``<fields_var_name>.<test name>`` / ``<fields_var_name>.<trial name>``.
    """

    class _FieldWrappers:
        """Attribute bag mapping field argument names to their ``ElementEvalArg`` constructors."""

    def __init__(
        self,
        arg_names: List[str],
        parsed_args: IntegrandArguments,
        integrand_func: wp.Function,
        func_name: str = "integrand_func",
        fields_var_name: str = "fields",
        values_var_name: str = "values",
        domain_var_name: str = "domain_arg",
        sample_var_name: str = "sample",
        field_wrappers_attr: str = "_field_wrappers",
    ):
        self._arg_names = arg_names

        # Parsed integrand signature
        self._field_args = parsed_args.field_args
        self._value_args = parsed_args.value_args
        self._domain_name = parsed_args.domain_name
        self._sample_name = parsed_args.sample_name
        self._test_name = parsed_args.test_name
        self._trial_name = parsed_args.trial_name

        # Variable names used by the kernel template
        self._func_name = func_name
        self._fields_var_name = fields_var_name
        self._values_var_name = values_var_name
        self._domain_var_name = domain_var_name
        self._sample_var_name = sample_var_name

        self._field_wrappers_attr = field_wrappers_attr
        self._register_integrand_field_wrappers(integrand_func, parsed_args.field_args)

    def _register_integrand_field_wrappers(self, integrand_func: wp.Function, fields: Dict[str, FieldLike]):
        """Attach each field's ``ElementEvalArg`` to ``integrand_func`` under ``field_wrappers_attr``.

        The domain argument is passed only once to the root kernel; these wrappers combine it with
        each field's evaluation argument so nested integrand calls can receive it.
        """
        wrappers = PassFieldArgsToIntegrand._FieldWrappers()
        for field_name, field in fields.items():
            if not isinstance(field, FieldLike):
                continue
            setattr(wrappers, field_name, field.ElementEvalArg)
        setattr(integrand_func, self._field_wrappers_attr, wrappers)

    @staticmethod
    def _load_name(name):
        """Return an ``ast.Name`` node reading variable ``name``."""
        return ast.Name(id=name, ctx=ast.Load())

    @classmethod
    def _load_attr(cls, var_name, attr):
        """Return an AST node reading ``<var_name>.<attr>``."""
        return ast.Attribute(value=cls._load_name(var_name), attr=attr, ctx=ast.Load())

    def _integrand_arg_node(self, arg_name):
        """Return the expression node forwarded to the integrand for parameter ``arg_name``."""
        if arg_name == self._domain_name:
            return self._load_name(self._domain_var_name)
        if arg_name == self._sample_name:
            return self._load_name(self._sample_var_name)
        if arg_name in self._field_args:
            wrapper = ast.Attribute(
                value=self._load_attr(self._func_name, self._field_wrappers_attr),
                attr=arg_name,
                ctx=ast.Load(),
            )
            return ast.Call(
                func=wrapper,
                args=[
                    self._load_name(self._domain_var_name),
                    self._load_attr(self._fields_var_name, arg_name),
                ],
                keywords=[],
            )
        if arg_name in self._value_args:
            return self._load_attr(self._values_var_name, arg_name)
        raise RuntimeError(f"Unhandled argument {arg_name}")

    def visit_Call(self, call: ast.Call):
        call = self.generic_visit(call)

        callee_name = getattr(call.func, "id", None)

        if callee_name == self._func_name:
            # Substitute the template's placeholder arguments, keeping the same list object
            call.args.clear()
            for arg_name in self._arg_names:
                call.args.append(self._integrand_arg_node(arg_name))
            return call

        if callee_name == _get_test_arg.__name__:
            return self._load_attr(self._fields_var_name, self._test_name)

        if callee_name == _get_trial_arg.__name__:
            return self._load_attr(self._fields_var_name, self._trial_name)

        return call
557
+
558
+
559
+ def _combined_kernel_options(integrand_options: Optional[Dict[str, Any]], call_site_options: Optional[Dict[str, Any]]):
560
+ if integrand_options is None:
561
+ return {} if call_site_options is None else call_site_options
562
+
563
+ options = integrand_options.copy()
564
+ if call_site_options is not None:
565
+ options.update(call_site_options)
566
+ return options
567
+
568
+
569
def get_integrate_constant_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    accumulate_dtype,
):
    """Return the Python source function of a kernel integrating a form with no test or trial field.

    The returned function is a template: ``_generate_integrate_kernel`` passes it to
    ``cache.get_integrand_kernel`` together with ``PassFieldArgsToIntegrand``. That transformer matches
    the local names ``integrand_func``, ``domain_arg``, ``sample``, ``fields`` and ``values`` by default,
    so these names must not be changed.

    Each thread handles one quadrature evaluation point and atomically adds
    ``qp_weight * element_measure * integrand`` to ``result[0]``.
    """

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=accumulate_dtype),
    ):
        qp_eval_index = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        # Evaluation points that map to no element contribute nothing
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)

        qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        # No test/trial field: both dof indices are null
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        vol = domain.element_measure(domain_arg, sample)

        # Arguments are rewritten by PassFieldArgsToIntegrand to match the integrand signature
        val = integrand_func(sample, fields, values)

        # All threads accumulate into a single entry
        wp.atomic_add(result, 0, accumulate_dtype(qp_weight * vol * val))

    return integrate_kernel_fn
608
+
609
+
610
def get_integrate_linear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Return the Python source function of a kernel integrating a linear form (test field only).

    Each thread handles one (local test node, test dof) pair. It loops over the elements adjacent to
    that node and over their quadrature points, sums ``qp_weight * element_measure * integrand``, and
    adds the sum to ``result[node_index, test_dof]``.

    The local names ``integrand_func``, ``domain_arg``, ``sample``, ``fields`` and ``values`` are
    matched by ``PassFieldArgsToIntegrand`` and must not be renamed.
    """

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, test_dof = wp.tid()
        node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
        element_beg, element_end = test.space_restriction.node_element_range(test_arg, node_index)

        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        # Loop over the elements that contain this test node
        for n in range(element_beg, element_end):
            node_element_index = test.space_restriction.node_element_index(test_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)

            qp_point_count = quadrature.point_count(
                domain_arg, qp_arg, node_element_index.domain_element_index, element_index
            )
            for k in range(qp_point_count):
                qp_index = quadrature.point_index(
                    domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
                )
                qp_coords = quadrature.point_coords(
                    domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
                )
                qp_weight = quadrature.point_weight(
                    domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
                )

                vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

                sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(qp_weight * vol * val)

        # Non-atomic update: each (node, dof) entry is written by a single thread
        result[node_index, test_dof] += output_dtype(val_sum)

    return integrate_kernel_fn
667
+
668
+
669
def get_integrate_linear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Return the Python source function of a kernel integrating a linear form with nodal quadrature.

    Used when no quadrature is given. The integrand is evaluated at the test space's nodes, weighted
    by ``node_quadrature_weight``. Each thread handles one (local test node, dof) pair, sums the
    contributions of the node over its adjacent elements, and adds the sum to
    ``result[partition_node_index, dof]``.

    The local names ``integrand_func``, ``domain_arg``, ``sample``, ``fields`` and ``values`` are
    matched by ``PassFieldArgsToIntegrand`` and must not be renamed. ``_get_test_arg()`` is a marker
    replaced with the test field's evaluation argument.
    """

    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        test_topo_arg: test.space.topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, dof = wp.tid()

        partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
        element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)

        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        for n in range(element_beg, element_end):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # The node index is computed from the first adjacent element only and reused afterwards
            if n == element_beg:
                node_index = test.space.topology.element_node_index(
                    domain_arg, test_topo_arg, element_index, node_element_index.node_index_in_element
                )

            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Skip elements for which the node's coordinates are reported as OUTSIDE
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                test_dof_index = DofIndex(node_element_index.node_index_in_element, dof)

                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        result[partition_node_index, dof] += output_dtype(val_sum)

    return integrate_kernel_fn
738
+
739
+
740
def get_integrate_linear_local_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: LocalTestField,
):
    """Return the Python source function of a kernel computing per-point linear-form contributions.

    Used for ``LocalTestField`` tests. Each thread handles one (evaluation point, Taylor dof,
    test dof) triple and writes ``qp_weight * element_measure * integrand`` to
    ``result[qp_eval_index, taylor_dof, test_dof]``. No accumulation happens here.

    The local names ``integrand_func``, ``domain_arg``, ``sample``, ``fields`` and ``values`` are
    matched by ``PassFieldArgsToIntegrand`` and must not be renamed.
    """

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array3d(dtype=float),
    ):
        qp_eval_index, taylor_dof, test_dof = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)

        # NOTE(review): result entries for null elements are left unwritten here -- confirm the
        # consumer of `result` ignores or zero-initializes them.
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)

        qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

        trial_dof_index = NULL_DOF_INDEX
        # For local test fields, the "node" slot of the dof index holds the Taylor dof
        test_dof_index = DofIndex(taylor_dof, test_dof)

        sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        val = integrand_func(sample, fields, values)
        result[qp_eval_index, taylor_dof, test_dof] = qp_weight * vol * val

    return integrate_kernel_fn
779
+
780
+
781
def get_integrate_bilinear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    trial: TrialField,
    output_dtype,
    accumulate_dtype,
):
    """Return the Python source function of a kernel assembling a bilinear form as COO triplets.

    Each thread handles (local test node, trial node in element, test dof, trial dof). For each element
    adjacent to the test node, it integrates the form over the element's quadrature points and writes
    one block value at offset ``element * MAX_NODES_PER_ELEMENT + trial_node``. Here ``element`` is the
    position in the test node's element range. Threads whose ``trial_node`` exceeds the element's trial
    node count integrate zero points and record a ``NULL_NODE_INDEX`` column.

    The local names ``integrand_func``, ``domain_arg``, ``sample``, ``fields`` and ``values`` are
    matched by ``PassFieldArgsToIntegrand`` and must not be renamed.
    """
    MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()

        test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
        element_beg, element_end = test.space_restriction.node_element_range(test_arg, test_node_index)

        trial_dof_index = DofIndex(trial_node, trial_dof)

        for element in range(element_beg, element_end):
            test_element_index = test.space_restriction.node_element_index(test_arg, element)
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)

            element_trial_node_count = trial.space.topology.element_node_count(
                domain_arg, trial_topology_arg, element_index
            )
            # Trial node slots beyond this element's node count integrate no points
            qp_point_count = wp.where(
                trial_node < element_trial_node_count,
                quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
                0,
            )

            test_dof_index = DofIndex(
                test_element_index.node_index_in_element,
                test_dof,
            )

            val_sum = accumulate_dtype(0.0)

            for k in range(qp_point_count):
                qp_index = quadrature.point_index(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                coords = quadrature.point_coords(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )

                qp_weight = quadrature.point_weight(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))

                sample = Sample(
                    element_index,
                    coords,
                    qp_index,
                    qp_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                val = integrand_func(sample, fields, values)
                val_sum += accumulate_dtype(qp_weight * vol * val)

            # One triplet block per (adjacent element, trial node slot)
            block_offset = element * MAX_NODES_PER_ELEMENT + trial_node
            triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)

            # Set row and column indices
            # Only the (0, 0) dof thread writes the block's row/column indices
            if test_dof == 0 and trial_dof == 0:
                if trial_node < element_trial_node_count:
                    trial_node_index = trial.space_partition.partition_node_index(
                        trial_partition_arg,
                        trial.space.topology.element_node_index(
                            domain_arg, trial_topology_arg, element_index, trial_node
                        ),
                    )
                else:
                    trial_node_index = NULL_NODE_INDEX  # will get ignored when converting to bsr
                triplet_rows[block_offset] = test_node_index
                triplet_cols[block_offset] = trial_node_index

    return integrate_kernel_fn
876
+
877
+
878
def get_integrate_bilinear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Return the Python source function of a kernel assembling a bilinear form with nodal quadrature.

    Test and trial dofs share the same node, so the kernel produces one diagonal block per local test
    node. Both ``triplet_rows`` and ``triplet_cols`` are set to the node's partition index. Each thread
    handles one (local node, test dof, trial dof) triple.

    The local names ``integrand_func``, ``domain_arg``, ``sample``, ``fields`` and ``values`` are
    matched by ``PassFieldArgsToIntegrand`` and must not be renamed. ``_get_test_arg()`` is a marker
    replaced with the test field's evaluation argument.
    """

    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        test_topo_arg: test.space.topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        local_node_index, test_dof, trial_dof = wp.tid()

        partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
        element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)

        val_sum = accumulate_dtype(0.0)

        for n in range(element_beg, element_end):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # The node index is computed from the first adjacent element only and reused afterwards
            if n == element_beg:
                node_index = test.space.topology.element_node_index(
                    domain_arg, test_topo_arg, element_index, node_element_index.node_index_in_element
                )

            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Skip elements for which the node's coordinates are reported as OUTSIDE
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                # Test and trial dofs refer to the same node in the element
                test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
                trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)

                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
        triplet_rows[local_node_index] = partition_node_index
        triplet_cols[local_node_index] = partition_node_index

    return integrate_kernel_fn
950
+
951
+
952
def get_integrate_bilinear_local_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: LocalTestField,
    trial: LocalTrialField,
):
    """Return the Python source function of a kernel computing per-point bilinear-form contributions.

    Used for local test/trial fields. Each thread handles one (evaluation point, test dof, trial dof,
    trial Taylor dof) tuple and loops over the test Taylor dofs. Each value is written to
    ``result[qp_eval_index, test_dof, trial_dof, test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof]``.

    The local names ``integrand_func``, ``domain_arg``, ``sample``, ``fields`` and ``values`` are
    matched by ``PassFieldArgsToIntegrand`` and must not be renamed.
    """
    TEST_TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
    TRIAL_TAYLOR_DOF_COUNT = trial.TAYLOR_DOF_COUNT

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array4d(dtype=float),
    ):
        qp_eval_index, test_dof, trial_dof, trial_taylor_dof = wp.tid()

        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)

        qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
        qp_vol = vol * qp_weight

        trial_dof_index = DofIndex(trial_taylor_dof, trial_dof)

        for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
            # Flattened (test Taylor dof, trial Taylor dof) index
            taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof

            test_dof_index = DofIndex(test_taylor_dof, test_dof)

            sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            val = integrand_func(sample, fields, values)
            result[qp_eval_index, test_dof, trial_dof, taylor_dof] = qp_vol * val

    return integrate_kernel_fn
1000
+
1001
+
1002
def _generate_integrate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    quadrature: Quadrature,
    arguments: IntegrandArguments,
    test: Optional[TestField],
    trial: Optional[TrialField],
    output_dtype: type,
    accumulate_dtype: type,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> "tuple[wp.Kernel, wp.codegen.Struct, wp.codegen.Struct]":
    """Build (or fetch from cache) the integration kernel for an integrand and its bound arguments.

    The kernel template is chosen from the presence of test/trial fields (constant, linear or bilinear
    form), from whether a quadrature is given (``None`` selects nodal integration), and from whether
    the test field is a ``LocalTestField``.

    Operator usage is reported to the bound fields on every call, including cache hits. Field/domain
    compatibility is only checked when the kernel is not already cached.

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple. The two structs are the argument types the
        kernel expects for field and value arguments. (The previous annotation ``wp.Kernel`` did not
        match the returned tuple.)
    """
    output_dtype = wp.types.type_scalar_type(output_dtype)

    FieldStruct = _gen_field_struct(arguments.field_args)
    ValueStruct = cache.get_argument_struct(arguments.value_args)

    _notify_operator_usage(integrand, arguments.field_args)

    # Check if kernel exist in cache
    field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
    kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{field_names}"

    if quadrature is not None:
        kernel_suffix += quadrature.name

    kernel = cache.get_integrand_kernel(integrand=integrand, suffix=kernel_suffix, kernel_options=kernel_options)
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Not found in cache, transform integrand and generate kernel
    _check_field_compat(integrand, arguments, domain)

    integrand_func = IntegrandTransformer.apply(integrand, arguments.field_args)

    nodal = quadrature is None

    if test is None and trial is None:
        integrate_kernel_fn = get_integrate_constant_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            accumulate_dtype=accumulate_dtype,
        )
    elif trial is None:
        if nodal:
            integrate_kernel_fn = get_integrate_linear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        elif isinstance(test, LocalTestField):
            integrate_kernel_fn = get_integrate_linear_local_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
            )
        else:
            integrate_kernel_fn = get_integrate_linear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
    else:
        if nodal:
            integrate_kernel_fn = get_integrate_bilinear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        elif isinstance(test, LocalTestField):
            integrate_kernel_fn = get_integrate_bilinear_local_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                trial=trial,
            )
        else:
            integrate_kernel_fn = get_integrate_bilinear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                trial=trial,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )

    # The template's placeholder `integrand_func(sample, fields, values)` call is rewired to the
    # integrand's actual signature by PassFieldArgsToIntegrand
    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=integrate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
            )
        ],
    )

    return kernel, FieldStruct, ValueStruct
1125
+
1126
+
1127
+ def _launch_integrate_kernel(
1128
+ integrand: Integrand,
1129
+ kernel: wp.Kernel,
1130
+ FieldStruct: wp.codegen.Struct,
1131
+ ValueStruct: wp.codegen.Struct,
1132
+ domain: GeometryDomain,
1133
+ quadrature: Quadrature,
1134
+ test: Optional[TestField],
1135
+ trial: Optional[TrialField],
1136
+ fields: Dict[str, FieldLike],
1137
+ values: Dict[str, Any],
1138
+ accumulate_dtype: type,
1139
+ temporary_store: Optional[cache.TemporaryStore],
1140
+ output_dtype: type,
1141
+ output: Optional[Union[wp.array, BsrMatrix]],
1142
+ add_to_output: bool,
1143
+ bsr_options: Optional[Dict[str, Any]],
1144
+ device,
1145
+ ):
1146
+ # Set-up launch arguments
1147
+ domain_elt_arg = domain.element_arg_value(device=device)
1148
+ domain_elt_index_arg = domain.element_index_arg_value(device=device)
1149
+
1150
+ if quadrature is not None:
1151
+ qp_arg = quadrature.arg_value(device=device)
1152
+
1153
+ field_arg_values = FieldStruct()
1154
+ for k, v in fields.items():
1155
+ if not isinstance(v, GeometryDomain):
1156
+ setattr(field_arg_values, k, v.eval_arg_value(device=device))
1157
+
1158
+ value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
1159
+
1160
+ # Constant form
1161
+ if test is None and trial is None:
1162
+ if output is not None and output.dtype == accumulate_dtype:
1163
+ if output.size < 1:
1164
+ raise RuntimeError("Output array must be of size at least 1")
1165
+ accumulate_array = output
1166
+ else:
1167
+ accumulate_temporary = cache.borrow_temporary(
1168
+ shape=(1),
1169
+ device=device,
1170
+ dtype=accumulate_dtype,
1171
+ temporary_store=temporary_store,
1172
+ requires_grad=output is not None and output.requires_grad,
1173
+ )
1174
+ accumulate_array = accumulate_temporary.array
1175
+
1176
+ if output != accumulate_array or not add_to_output:
1177
+ accumulate_array.zero_()
1178
+
1179
+ wp.launch(
1180
+ kernel=kernel,
1181
+ dim=quadrature.evaluation_point_count(),
1182
+ inputs=[
1183
+ qp_arg,
1184
+ quadrature.element_index_arg_value(device),
1185
+ domain_elt_arg,
1186
+ domain_elt_index_arg,
1187
+ field_arg_values,
1188
+ value_struct_values,
1189
+ accumulate_array,
1190
+ ],
1191
+ device=device,
1192
+ )
1193
+
1194
+ if output == accumulate_array:
1195
+ return output
1196
+ if output is None:
1197
+ return accumulate_array.numpy()[0]
1198
+
1199
+ if add_to_output:
1200
+ # accumulate dtype is distinct from output dtype
1201
+ array_axpy(x=accumulate_array, y=output)
1202
+ else:
1203
+ array_cast(in_array=accumulate_array, out_array=output)
1204
+ return output
1205
+
1206
+ test_arg = test.space_restriction.node_arg(device=device)
1207
+ nodal = quadrature is None
1208
+
1209
+ # Linear form
1210
+ if trial is None:
1211
+ # If an output array is provided with the correct type, accumulate directly into it
1212
+ # Otherwise, grab a temporary array
1213
+ if output is None:
1214
+ if type_length(output_dtype) == test.node_dof_count:
1215
+ output_shape = (test.space_partition.node_count(),)
1216
+ elif type_length(output_dtype) == 1:
1217
+ output_shape = (test.space_partition.node_count(), test.node_dof_count)
1218
+ else:
1219
+ raise RuntimeError(
1220
+ f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1221
+ )
1222
+
1223
+ output_temporary = cache.borrow_temporary(
1224
+ temporary_store=temporary_store,
1225
+ shape=output_shape,
1226
+ dtype=output_dtype,
1227
+ device=device,
1228
+ )
1229
+
1230
+ output = output_temporary.array
1231
+
1232
+ else:
1233
+ output_temporary = None
1234
+
1235
+ if output.shape[0] < test.space_partition.node_count():
1236
+ raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
1237
+
1238
+ output_dtype = output.dtype
1239
+ if type_length(output_dtype) != test.node_dof_count:
1240
+ if type_length(output_dtype) != 1:
1241
+ raise RuntimeError(
1242
+ f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1243
+ )
1244
+ if output.ndim != 2 and output.shape[1] != test.node_dof_count:
1245
+ raise RuntimeError(
1246
+ f"Incompatible output array shape, last dimension must be of size {test.node_dof_count}"
1247
+ )
1248
+
1249
+ # Launch the integration on the kernel on a 2d scalar view of the actual array
1250
+ if not add_to_output:
1251
+ output.zero_()
1252
+
1253
+ def as_2d_array(array):
1254
+ return wp.array(
1255
+ data=None,
1256
+ ptr=array.ptr,
1257
+ capacity=array.capacity,
1258
+ device=array.device,
1259
+ shape=(test.space_partition.node_count(), test.node_dof_count),
1260
+ dtype=wp.types.type_scalar_type(output_dtype),
1261
+ grad=None if array.grad is None else as_2d_array(array.grad),
1262
+ )
1263
+
1264
+ output_view = output if output.ndim == 2 else as_2d_array(output)
1265
+
1266
+ if nodal:
1267
+ wp.launch(
1268
+ kernel=kernel,
1269
+ dim=(test.space_restriction.node_count(), test.node_dof_count),
1270
+ inputs=[
1271
+ domain_elt_arg,
1272
+ domain_elt_index_arg,
1273
+ test_arg,
1274
+ test.space.topology.topo_arg_value(device),
1275
+ field_arg_values,
1276
+ value_struct_values,
1277
+ output_view,
1278
+ ],
1279
+ device=device,
1280
+ )
1281
+ elif isinstance(test, LocalTestField):
1282
+ local_result = cache.borrow_temporary(
1283
+ temporary_store=temporary_store,
1284
+ device=device,
1285
+ requires_grad=output.requires_grad,
1286
+ shape=(quadrature.evaluation_point_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
1287
+ dtype=float,
1288
+ )
1289
+
1290
+ wp.launch(
1291
+ kernel=kernel,
1292
+ dim=local_result.array.shape,
1293
+ inputs=[
1294
+ qp_arg,
1295
+ quadrature.element_index_arg_value(device),
1296
+ domain_elt_arg,
1297
+ domain_elt_index_arg,
1298
+ field_arg_values,
1299
+ value_struct_values,
1300
+ local_result.array,
1301
+ ],
1302
+ device=device,
1303
+ )
1304
+
1305
+ dispatch_kernel = make_linear_dispatch_kernel(test, quadrature, accumulate_dtype)
1306
+ wp.launch(
1307
+ kernel=dispatch_kernel,
1308
+ dim=(test.space_restriction.node_count(), test.node_dof_count),
1309
+ inputs=[
1310
+ qp_arg,
1311
+ domain_elt_arg,
1312
+ domain_elt_index_arg,
1313
+ test_arg,
1314
+ test.global_field.eval_arg_value(device),
1315
+ local_result.array,
1316
+ output_view,
1317
+ ],
1318
+ device=device,
1319
+ )
1320
+
1321
+ local_result.release()
1322
+
1323
+ else:
1324
+ wp.launch(
1325
+ kernel=kernel,
1326
+ dim=(test.space_restriction.node_count(), test.node_dof_count),
1327
+ inputs=[
1328
+ qp_arg,
1329
+ domain_elt_arg,
1330
+ domain_elt_index_arg,
1331
+ test_arg,
1332
+ field_arg_values,
1333
+ value_struct_values,
1334
+ output_view,
1335
+ ],
1336
+ device=device,
1337
+ )
1338
+
1339
+ if output_temporary is not None:
1340
+ return output_temporary.detach()
1341
+
1342
+ return output
1343
+
1344
+ # Bilinear form
1345
+
1346
+ if test.node_dof_count == 1 and trial.node_dof_count == 1:
1347
+ block_type = output_dtype
1348
+ else:
1349
+ block_type = cache.cached_mat_type(shape=(test.node_dof_count, trial.node_dof_count), dtype=output_dtype)
1350
+
1351
+ if nodal:
1352
+ nnz = test.space_restriction.node_count()
1353
+ else:
1354
+ nnz = test.space_restriction.total_node_element_count() * trial.space.topology.MAX_NODES_PER_ELEMENT
1355
+
1356
+ triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
1357
+ triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
1358
+ triplet_values_temp = cache.borrow_temporary(
1359
+ temporary_store,
1360
+ shape=(
1361
+ nnz,
1362
+ test.node_dof_count,
1363
+ trial.node_dof_count,
1364
+ ),
1365
+ dtype=output_dtype,
1366
+ device=device,
1367
+ )
1368
+ triplet_cols = triplet_cols_temp.array
1369
+ triplet_rows = triplet_rows_temp.array
1370
+ triplet_values = triplet_values_temp.array
1371
+
1372
+ triplet_values.zero_()
1373
+
1374
+ if nodal:
1375
+ wp.launch(
1376
+ kernel=kernel,
1377
+ dim=triplet_values.shape,
1378
+ inputs=[
1379
+ domain_elt_arg,
1380
+ domain_elt_index_arg,
1381
+ test_arg,
1382
+ test.space.topology.topo_arg_value(device),
1383
+ field_arg_values,
1384
+ value_struct_values,
1385
+ triplet_rows,
1386
+ triplet_cols,
1387
+ triplet_values,
1388
+ ],
1389
+ device=device,
1390
+ )
1391
+ elif isinstance(test, LocalTestField):
1392
+ local_result = cache.borrow_temporary(
1393
+ temporary_store=temporary_store,
1394
+ device=device,
1395
+ requires_grad=False,
1396
+ shape=(
1397
+ quadrature.evaluation_point_count(),
1398
+ test.value_dof_count,
1399
+ trial.value_dof_count,
1400
+ test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
1401
+ ),
1402
+ dtype=float,
1403
+ )
1404
+
1405
+ wp.launch(
1406
+ kernel=kernel,
1407
+ dim=(
1408
+ quadrature.evaluation_point_count(),
1409
+ test.value_dof_count,
1410
+ trial.value_dof_count,
1411
+ trial.TAYLOR_DOF_COUNT,
1412
+ ),
1413
+ inputs=[
1414
+ qp_arg,
1415
+ quadrature.element_index_arg_value(device),
1416
+ domain_elt_arg,
1417
+ domain_elt_index_arg,
1418
+ field_arg_values,
1419
+ value_struct_values,
1420
+ local_result.array,
1421
+ ],
1422
+ device=device,
1423
+ )
1424
+
1425
+ vec_array_shape = (*local_result.array.shape[:-1], test.TAYLOR_DOF_COUNT)
1426
+ vec_array_dtype = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)
1427
+ local_result_as_vec = wp.array(
1428
+ data=None,
1429
+ ptr=local_result.array.ptr,
1430
+ capacity=local_result.array.capacity,
1431
+ device=local_result.array.device,
1432
+ shape=vec_array_shape,
1433
+ dtype=vec_array_dtype,
1434
+ )
1435
+
1436
+ dispatch_kernel = make_bilinear_dispatch_kernel(test, trial, quadrature, accumulate_dtype)
1437
+
1438
+ trial_partition_arg = trial.space_partition.partition_arg_value(device)
1439
+ trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
1440
+ wp.launch(
1441
+ kernel=dispatch_kernel,
1442
+ dim=(
1443
+ test.space_restriction.node_count(),
1444
+ test.node_dof_count,
1445
+ trial.node_dof_count,
1446
+ trial.space.topology.MAX_NODES_PER_ELEMENT,
1447
+ ),
1448
+ inputs=[
1449
+ qp_arg,
1450
+ domain_elt_arg,
1451
+ domain_elt_index_arg,
1452
+ test_arg,
1453
+ test.global_field.eval_arg_value(device),
1454
+ trial_partition_arg,
1455
+ trial_topology_arg,
1456
+ trial.global_field.eval_arg_value(device),
1457
+ local_result_as_vec,
1458
+ triplet_rows,
1459
+ triplet_cols,
1460
+ triplet_values,
1461
+ ],
1462
+ device=device,
1463
+ )
1464
+
1465
+ local_result.release()
1466
+
1467
+ else:
1468
+ trial_partition_arg = trial.space_partition.partition_arg_value(device)
1469
+ trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
1470
+ wp.launch(
1471
+ kernel=kernel,
1472
+ dim=(
1473
+ test.space_restriction.node_count(),
1474
+ trial.space.topology.MAX_NODES_PER_ELEMENT,
1475
+ test.node_dof_count,
1476
+ trial.node_dof_count,
1477
+ ),
1478
+ inputs=[
1479
+ qp_arg,
1480
+ domain_elt_arg,
1481
+ domain_elt_index_arg,
1482
+ test_arg,
1483
+ trial_partition_arg,
1484
+ trial_topology_arg,
1485
+ field_arg_values,
1486
+ value_struct_values,
1487
+ triplet_rows,
1488
+ triplet_cols,
1489
+ triplet_values,
1490
+ ],
1491
+ device=device,
1492
+ )
1493
+
1494
+ if output is not None:
1495
+ if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
1496
+ raise RuntimeError(
1497
+ f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
1498
+ )
1499
+
1500
+ if output is None or add_to_output:
1501
+ bsr_result = bsr_zeros(
1502
+ rows_of_blocks=test.space_partition.node_count(),
1503
+ cols_of_blocks=trial.space_partition.node_count(),
1504
+ block_type=block_type,
1505
+ device=device,
1506
+ )
1507
+ else:
1508
+ bsr_result = output
1509
+
1510
+ bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
1511
+
1512
+ # Do not wait for garbage collection
1513
+ triplet_values_temp.release()
1514
+ triplet_rows_temp.release()
1515
+ triplet_cols_temp.release()
1516
+
1517
+ if add_to_output:
1518
+ output += bsr_result
1519
+ else:
1520
+ output = bsr_result
1521
+
1522
+ return output
1523
+
1524
+
1525
def _pick_assembly_strategy(
    assembly: Optional[str], nodal: bool, operators: Dict[str, Set[Operator]], arguments: IntegrandArguments
):
    """Select the assembly strategy used by :func:`integrate`.

    Args:
        assembly: Explicitly requested strategy, one of ``"generic"``, ``"nodal"`` or ``"dispatch"``,
            or ``None`` to pick one automatically.
        nodal: Legacy flag; when ``assembly`` is ``None`` and ``nodal`` is True, ``"nodal"`` is selected.
        operators: Operators used by the integrand, keyed by integrand argument name.
        arguments: Parsed integrand arguments, used to look up the test and trial argument names.

    Returns:
        The name of the selected strategy: ``assembly`` itself if provided, ``"nodal"`` if ``nodal`` is set,
        otherwise ``"generic"`` or ``"dispatch"``.

    Raises:
        ValueError: If ``assembly`` is not one of the supported strategies.
    """
    if assembly is not None:
        if assembly not in ("generic", "nodal", "dispatch"):
            raise ValueError(f"Invalid assembly strategy '{assembly}'")
        return assembly
    if nodal:
        return "nodal"

    # The "dispatch" strategy is incompatible with the `at_node` operator on test or trial functions,
    # so fall back to the "generic" strategy whenever either of them uses it.
    test_operators = operators.get(arguments.test_name, {})
    trial_operators = operators.get(arguments.trial_name, {})
    uses_at_node = at_node in test_operators or at_node in trial_operators

    return "generic" if uses_at_node else "dispatch"
1540
+
1541
+
1542
def integrate(
    integrand: Integrand,
    domain: Optional[GeometryDomain] = None,
    quadrature: Optional[Quadrature] = None,
    nodal: bool = False,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    accumulate_dtype: type = wp.float64,
    output_dtype: Optional[type] = None,
    output: Optional[Union[BsrMatrix, wp.array]] = None,
    device=None,
    temporary_store: Optional[cache.TemporaryStore] = None,
    kernel_options: Optional[Dict[str, Any]] = None,
    assembly: Optional[str] = None,
    add: bool = False,
    bsr_options: Optional[Dict[str, Any]] = None,
):
    """
    Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.

    Args:
        integrand: Form to be integrated, must have :func:`integrand` decorator
        domain: Integration domain. If None, deduced from `quadrature`, or otherwise from the test field
        quadrature: Quadrature formula. If None, a regular quadrature over the integration domain is used, with
            an order equal to the sum of the degrees of the fields (ignored for nodal assembly).
        nodal: Deprecated. Use the equivalent assembly="nodal" instead.
        fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
        temporary_store: shared pool from which to allocate temporary arrays
        accumulate_dtype: Scalar type to be used for accumulating integration samples
        output: Sparse matrix or warp array into which to store the result of the integration
        output_dtype: Scalar type for returned results if `output` is not provided. If None, defaults to `accumulate_dtype`
        device: Device on which to perform the integration
        kernel_options: Overloaded options to be passed to the kernel builder (e.g., ``{"enable_backward": True}``)
        assembly: Specifies the strategy for assembling the integrated vector or matrix:
            - "nodal": For linear or bilinear forms, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
            - "generic": Single-pass integration and shape-function evaluation. Makes no assumption about the integrand's content, but may lead to many redundant computations.
            - "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` operator on test or trial functions.
            - `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
        add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
        bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`

    Returns:
        A scalar for constant forms, an array for linear forms, or a sparse matrix for bilinear forms.
        When `output` is provided, it is filled and returned.

    Raises:
        ValueError: If `integrand` is not decorated, if a trial field is given without a test field, if the
            test and trial domains differ, if no domain can be deduced, if ``add=True`` without `output`,
            or if the arguments are inconsistent with the selected assembly strategy.
        NotImplementedError: If the integration domain differs from the test field domain.
    """
    if fields is None:
        fields = {}

    if values is None:
        values = {}

    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")

    arguments = _parse_integrand_arguments(integrand, fields)

    test = None
    if arguments.test_name:
        test = arguments.field_args[arguments.test_name]
    trial = None
    if arguments.trial_name:
        if test is None:
            raise ValueError("A trial field cannot be provided without a test field")
        trial = arguments.field_args[arguments.trial_name]
        if test.domain != trial.domain:
            raise ValueError("Incompatible test and trial domains")

    # Deduce the integration domain: quadrature domain first, then test field domain
    if domain is None:
        if quadrature is not None:
            domain = quadrature.domain
        elif test is not None:
            domain = test.domain

    if domain is None:
        raise ValueError("Must provide at least one of domain, quadrature, or test field")
    if test is not None and domain != test.domain:
        raise NotImplementedError("Mixing integration and test domain is not supported yet")

    if add and output is None:
        raise ValueError("An 'output' array or matrix needs to be provided for add=True")

    # Pass the integration domain to the integrand's domain argument, if it declares one
    if arguments.domain_name is not None:
        arguments.field_args[arguments.domain_name] = domain

    # Presumably populates `integrand.operators`, which is read just below
    _find_integrand_operators(integrand, arguments.field_args)

    assembly = _pick_assembly_strategy(assembly, nodal, arguments=arguments, operators=integrand.operators)

    if assembly == "dispatch":
        # Dispatch assembly evaluates the form at quadrature points first, then dispatches to nodes;
        # the test/trial fields are replaced by their local counterparts for that purpose
        if test is not None:
            test = LocalTestField(test)
            arguments.field_args[arguments.test_name] = test
        if trial is not None:
            trial = LocalTrialField(trial)
            arguments.field_args[arguments.trial_name] = trial

    if assembly == "nodal":
        if quadrature is not None:
            raise ValueError("Cannot specify quadrature for nodal integration")

        if test is None:
            raise ValueError("Nodal integration requires specifying a test function")

        if trial is not None and test.space_partition != trial.space_partition:
            raise ValueError(
                "Bilinear nodal integration requires test and trial to be defined on the same function space"
            )
    else:
        if quadrature is None:
            # Default quadrature order: sum of the degrees of all fields passed to the integrand
            order = sum(field.degree for field in fields.values())
            quadrature = RegularQuadrature(domain=domain, order=order)
        elif domain != quadrature.domain:
            raise ValueError("Incompatible integration and quadrature domain")

    # Canonicalize types; a provided output container dictates the output scalar type
    accumulate_dtype = wp.types.type_to_warp(accumulate_dtype)
    if output is not None:
        if isinstance(output, BsrMatrix):
            output_dtype = output.scalar_type
        else:
            output_dtype = output.dtype
    elif output_dtype is None:
        output_dtype = accumulate_dtype
    else:
        output_dtype = wp.types.type_to_warp(output_dtype)

    kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
        integrand=integrand,
        domain=domain,
        quadrature=quadrature,
        arguments=arguments,
        test=test,
        trial=trial,
        accumulate_dtype=accumulate_dtype,
        output_dtype=output_dtype,
        kernel_options=kernel_options,
    )

    return _launch_integrate_kernel(
        integrand=integrand,
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        domain=domain,
        quadrature=quadrature,
        test=test,
        trial=trial,
        fields=arguments.field_args,
        values=values,
        accumulate_dtype=accumulate_dtype,
        temporary_store=temporary_store,
        output_dtype=output_dtype,
        output=output,
        add_to_output=add,
        bsr_options=bsr_options,
        device=device,
    )
1697
+
1698
+
1699
def get_interpolate_to_field_function(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the per-node function used when interpolating an integrand into a destination field.

    The returned (undecorated) Python function takes a local node index of ``dest.space_restriction`` and
    returns ``(val_sum, vol_sum)``: the sum of ``element_measure * integrand value`` and the sum of
    ``element_measure`` over the elements in the node's element range where the node coordinates are
    not ``OUTSIDE``. Dividing the two yields a volume-weighted average.

    Args:
        integrand_func: Transformed integrand, called as ``integrand_func(sample, fields, values)``.
        domain: Geometry domain providing element indices and element measures.
        FieldStruct: Struct type holding the field arguments.
        ValueStruct: Struct type holding the value arguments.
        dest: Destination field restriction whose nodes are evaluated.

    Returns:
        The ``interpolate_to_field_fn`` Python function.
    """
    value_type = dest.space.dtype

    def interpolate_to_field_fn(
        local_node_index: int,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        # Range of (node, element) entries associated with this node in the destination restriction
        partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
        element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)

        # Samples are not tied to a test or trial degree of freedom
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX
        node_weight = 1.0

        # Volume-weighted average across elements
        # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces

        val_sum = value_type(0.0)
        vol_sum = float(0.0)

        for n in range(element_beg, element_end):
            node_element_index = dest.space_restriction.node_element_index(dest_node_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # The node index is only computed from the first element of the range and reused afterwards
            # NOTE(review): presumably identical for every element sharing this node -- confirm
            if n == element_beg:
                node_index = dest.space.topology.element_node_index(
                    domain_arg, dest_eval_arg.topology_arg, element_index, node_element_index.node_index_in_element
                )

            coords = dest.space.node_coords_in_element(
                domain_arg,
                dest_eval_arg.space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )

            # Skip elements in which the node has no valid local coordinates
            if coords[0] != OUTSIDE:
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                vol_sum += vol
                val_sum += vol * val

        return val_sum, vol_sum

    return interpolate_to_field_fn
1764
+
1765
+
1766
def get_interpolate_to_field_kernel(
    interpolate_to_field_fn: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the kernel function that writes interpolated values into the nodes of a destination field.

    Each thread (``wp.tid()``) handles one local node of ``dest.space_restriction``. It calls
    ``interpolate_to_field_fn`` to get the weighted value sum and total weight for that node. When the total
    weight is positive, it stores the weighted average through ``dest.field.set_node_value``; otherwise the
    node is left untouched.

    Args:
        interpolate_to_field_fn: Per-node function returning ``(val_sum, vol_sum)``.
        domain: Geometry domain providing element indices.
        FieldStruct: Struct type holding the field arguments.
        ValueStruct: Struct type holding the value arguments.
        dest: Destination field restriction.

    Returns:
        The undecorated ``interpolate_to_field_kernel_fn`` Python function.
    """

    @wp.func
    def _find_node_in_element(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        partition_node_index: int,
    ):
        # Returns (element_index, node_index_in_element) for the first element of the node's range
        # in which the node coordinates are not OUTSIDE, or (NULL_ELEMENT_INDEX, NULL_NODE_INDEX) if none
        element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)

        for n in range(element_beg, element_end):
            node_element_index = dest.space_restriction.node_element_index(dest_node_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
            coords = dest.space.node_coords_in_element(
                domain_arg,
                dest_eval_arg.space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )
            if coords[0] != OUTSIDE:
                return element_index, node_element_index.node_index_in_element

        return NULL_ELEMENT_INDEX, NULL_NODE_INDEX

    def interpolate_to_field_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        local_node_index = wp.tid()

        val_sum, vol_sum = interpolate_to_field_fn(
            local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
        )

        # Only write nodes that received a positive total weight
        if vol_sum > 0.0:
            partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)

            # Grab first element containing node; there must be at least one since vol_sum != 0
            element_index, node_index_in_element = _find_node_in_element(
                domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, partition_node_index
            )
            dest.field.set_node_value(
                domain_arg,
                dest_eval_arg,
                element_index,
                node_index_in_element,
                partition_node_index,
                val_sum / vol_sum,
            )

    return interpolate_to_field_kernel_fn
1828
+
1829
+
1830
def get_interpolate_at_quadrature_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    value_type: type,
):
    """Build the kernel function evaluating an integrand at quadrature evaluation points.

    Each thread handles the evaluation point ``wp.tid()``. Points with a ``NULL_ELEMENT_INDEX`` element are
    skipped.

    Args:
        integrand_func: Transformed integrand, called as ``integrand_func(sample, fields, values)``.
        domain: Geometry domain used to map domain element indices to element indices.
        quadrature: Quadrature providing the evaluation points, coordinates, weights and point indices.
        FieldStruct: Struct type holding the field arguments.
        ValueStruct: Struct type holding the value arguments.
        value_type: dtype of the ``result`` array. If None, the returned kernel only calls the integrand and
            never writes to ``result``; otherwise the integrand value is stored at ``result[qp_index]``.

    Returns:
        One of the two undecorated Python kernel functions, depending on ``value_type``.
    """

    def interpolate_at_quadrature_nonvalued_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=float),
    ):
        # `result` is unused here; presumably kept so both variants share the same launch signature
        qp_eval_index = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        # Integrand is evaluated for its side effects only; the returned value is discarded
        integrand_func(sample, fields, values)

    def interpolate_at_quadrature_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=value_type),
    ):
        qp_eval_index = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        # Store the integrand value at the quadrature point index (not the evaluation index)
        result[qp_index] = integrand_func(sample, fields, values)

    return interpolate_at_quadrature_nonvalued_kernel_fn if value_type is None else interpolate_at_quadrature_kernel_fn
1891
+
1892
+
1893
def get_interpolate_jacobian_at_quadrature_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    trial: TrialField,
    value_size: int,
    value_type: type,
):
    """Build the kernel function filling sparse triplets for the Jacobian of an integrand w.r.t. a trial field.

    The kernel uses a 3D thread index ``(qp_eval_index, trial_node, trial_dof)``. For each thread, the
    integrand is evaluated with the trial degree of freedom ``DofIndex(trial_node, trial_dof)`` active. Its
    ``value_size`` coefficients are written to ``triplet_values[block_offset, :, trial_dof]``, with
    ``block_offset = qp_index * MAX_NODES_PER_ELEMENT + trial_node``. The thread with ``trial_dof == 0``
    also writes the row (quadrature point index) and column (trial partition node index) of that block.

    Args:
        integrand_func: Transformed integrand, called as ``integrand_func(sample, fields, values)``.
        domain: Geometry domain used to map domain element indices to element indices.
        quadrature: Quadrature providing the evaluation points.
        FieldStruct: Struct type holding the field arguments.
        ValueStruct: Struct type holding the value arguments.
        trial: Trial field whose degrees of freedom define the Jacobian columns.
        value_size: Number of scalar coefficients extracted from each integrand value.
        value_type: Scalar dtype of the ``triplet_values`` array.

    Returns:
        The undecorated ``interpolate_jacobian_kernel_fn`` Python function.
    """
    MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT
    VALUE_SIZE = wp.constant(value_size)

    def interpolate_jacobian_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=value_type),
    ):
        qp_eval_index, trial_node, trial_dof = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)

        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)
        # Skip local point indices beyond this element's quadrature point count
        if qp >= quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index):
            return

        element_trial_node_count = trial.space.topology.element_node_count(
            domain_arg, trial_topology_arg, element_index
        )

        qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        # Each quadrature point owns MAX_NODES_PER_ELEMENT consecutive triplet blocks, one per trial node
        block_offset = qp_index * MAX_NODES_PER_ELEMENT + trial_node

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = DofIndex(trial_node, trial_dof)

        sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        val = integrand_func(sample, fields, values)

        for k in range(VALUE_SIZE):
            triplet_values[block_offset, k, trial_dof] = basis_coefficient(val, k)

        # Row/column indices are shared by all dofs of the block; only one thread writes them
        if trial_dof == 0:
            if trial_node < element_trial_node_count:
                trial_node_index = trial.space_partition.partition_node_index(
                    trial_partition_arg,
                    trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
                )
            else:
                trial_node_index = NULL_NODE_INDEX  # will get ignored when converting to bsr
            triplet_rows[block_offset] = qp_index
            triplet_cols[block_offset] = trial_node_index

    return interpolate_jacobian_kernel_fn
1960
+
1961
+
1962
def get_interpolate_free_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    value_type: type,
):
    """Build the kernel function evaluating an integrand at samples not attached to any element.

    Each thread ``qp_index = wp.tid()`` builds a sample with ``NULL_ELEMENT_INDEX``, coordinates set to
    ``OUTSIDE`` and a uniform weight ``1 / dim``.

    Args:
        integrand_func: Transformed integrand, called as ``integrand_func(sample, fields, values)``.
        domain: Geometry domain providing the ``ElementArg`` type of the kernel signature.
        FieldStruct: Struct type holding the field arguments.
        ValueStruct: Struct type holding the value arguments.
        value_type: dtype of the ``result`` array. If None, the returned kernel only calls the integrand;
            otherwise the integrand value is stored at ``result[qp_index]``.

    Returns:
        One of the two undecorated Python kernel functions, depending on ``value_type``.
    """

    def interpolate_free_nonvalued_kernel_fn(
        dim: int,
        domain_arg: domain.ElementArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=float),
    ):
        # `result` is unused here; presumably kept so both variants share the same launch signature
        qp_index = wp.tid()
        qp_weight = 1.0 / float(dim)
        element_index = NULL_ELEMENT_INDEX
        coords = Coords(OUTSIDE)

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        # Integrand is evaluated for its side effects only; the returned value is discarded
        integrand_func(sample, fields, values)

    def interpolate_free_kernel_fn(
        dim: int,
        domain_arg: domain.ElementArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=value_type),
    ):
        qp_index = wp.tid()
        qp_weight = 1.0 / float(dim)
        element_index = NULL_ELEMENT_INDEX
        coords = Coords(OUTSIDE)

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)

        result[qp_index] = integrand_func(sample, fields, values)

    return interpolate_free_nonvalued_kernel_fn if value_type is None else interpolate_free_kernel_fn
2007
+
2008
+
2009
def _generate_interpolate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    dest: Optional[Union[FieldLike, wp.array, BsrMatrix]],
    quadrature: Optional[Quadrature],
    arguments: IntegrandArguments,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> tuple:
    """Build, or fetch from the integrand kernel cache, the kernel used to interpolate ``integrand``.

    The kernel variant depends on ``dest`` and ``quadrature``:

    - ``dest`` is a ``FieldRestriction``: per-node weighted average written into the destination field;
    - otherwise, with a ``quadrature``: evaluation at quadrature points, either writing values into the
      ``dest`` array, or, when the integrand has a trial argument, writing Jacobian triplets sized from the
      ``dest`` sparse matrix (``block_shape`` and ``scalar_type``);
    - otherwise: "free" evaluation at samples not attached to any element.

    Args:
        integrand: Integrand to interpolate.
        domain: Geometry domain of the interpolation.
        dest: Destination field restriction, array, sparse matrix, or None.
        quadrature: Optional quadrature at whose points the integrand is evaluated.
        arguments: Parsed integrand arguments.
        kernel_options: Options forwarded to the kernel builder.

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple; the struct types are needed to build launch arguments.
    """
    # Generate field struct
    FieldStruct = _gen_field_struct(arguments.field_args)
    ValueStruct = cache.get_argument_struct(arguments.value_args)

    _notify_operator_usage(integrand, arguments.field_args)

    # Check if kernel exists in cache; the suffix encodes everything the generated code depends on
    field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
    if isinstance(dest, FieldRestriction):
        kernel_suffix = f"_itp_{field_names}_{dest.domain.name}_{dest.space_restriction.space_partition.name}"
    else:
        # NOTE(review): `if dest` relies on truthiness; a zero-length `wp.array` may evaluate as falsy here
        # and select the non-valued kernel -- confirm this is intended
        dest_dtype = dest.dtype if dest else None
        type_str = wp.types.get_type_code(dest_dtype) if dest_dtype else ""
        if quadrature is None:
            kernel_suffix = f"_itp_{field_names}_{domain.name}_{type_str}"
        else:
            kernel_suffix = f"_itp_{field_names}_{domain.name}_{quadrature.name}_{type_str}"

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Not found in cache, transform integrand and generate kernel
    _check_field_compat(integrand, arguments, domain)

    integrand_func = IntegrandTransformer.apply(integrand, arguments.field_args)

    # Generate interpolation kernel
    if isinstance(dest, FieldRestriction):
        # need to split into kernel + function for differentiability
        interpolate_fn = get_interpolate_to_field_function(
            integrand_func,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

        interpolate_fn = cache.get_integrand_function(
            integrand=integrand,
            func=interpolate_fn,
            suffix=kernel_suffix,
            code_transformers=[
                PassFieldArgsToIntegrand(
                    arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
                )
            ],
        )

        interpolate_kernel_fn = get_interpolate_to_field_kernel(
            interpolate_fn,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    elif quadrature is not None:
        if arguments.trial_name:
            # Jacobian w.r.t. the trial field: `dest` is expected to be a sparse matrix here
            trial = arguments.field_args[arguments.trial_name]
            interpolate_kernel_fn = get_interpolate_jacobian_at_quadrature_kernel(
                integrand_func,
                domain=domain,
                quadrature=quadrature,
                FieldStruct=FieldStruct,
                ValueStruct=ValueStruct,
                trial=trial,
                value_size=dest.block_shape[0],
                value_type=dest.scalar_type,
            )
        else:
            interpolate_kernel_fn = get_interpolate_at_quadrature_kernel(
                integrand_func,
                domain=domain,
                quadrature=quadrature,
                value_type=dest_dtype,
                FieldStruct=FieldStruct,
                ValueStruct=ValueStruct,
            )
    else:
        interpolate_kernel_fn = get_interpolate_free_kernel(
            integrand_func,
            domain=domain,
            value_type=dest_dtype,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=interpolate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
            )
        ],
    )

    return kernel, FieldStruct, ValueStruct
2121
+
2122
+
2123
def _launch_interpolate_kernel(
    integrand: Integrand,
    kernel: wp.kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    dest: Optional[Union[FieldRestriction, wp.array]],
    quadrature: Optional[Quadrature],
    dim: int,
    trial: Optional[TrialField],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    temporary_store: Optional[cache.TemporaryStore],
    bsr_options: Optional[Dict[str, Any]],
    device,
) -> None:
    """Launch an interpolation kernel with the argument layout matching its destination kind.

    Four launch configurations are handled, checked in this order:

    - ``dest`` is a :class:`FieldRestriction`: one thread per node of ``dest.space_restriction``;
      the restriction's node arguments and ``dest.field``'s evaluation arguments are passed to the kernel.
    - ``quadrature`` is ``None``: ``dim`` threads, with ``dest`` passed as the output array.
    - ``trial`` is ``None``: one thread per quadrature evaluation point, with ``dest`` as the output array.
    - otherwise: block triplets are written by the kernel into temporary arrays, then stored
      into the ``dest`` sparse matrix with :func:`bsr_set_from_triplets`.

    Args:
        integrand: Integrand being interpolated; its ``name`` is forwarded to ``cache.populate_argument_struct``.
        kernel: Kernel generated for this integrand/destination combination.
        FieldStruct: Struct type holding the field evaluation arguments.
        ValueStruct: Struct type holding the additional value arguments.
        domain: Geometry domain providing the element and element-index arguments.
        dest: Destination restriction, output array, or (when ``trial`` is given) sparse matrix.
        quadrature: Quadrature defining the sample points, or ``None``.
        dim: Number of threads used when neither a restriction nor a quadrature is given.
        trial: Trial field, or ``None``.
        fields: Field (and domain) arguments keyed by integrand parameter name.
        values: Additional value arguments keyed by integrand parameter name.
        temporary_store: Pool from which the triplet arrays are borrowed.
        bsr_options: Extra keyword arguments forwarded to :func:`bsr_set_from_triplets`.
        device: Device on which argument values are built and kernels are launched.

    Raises:
        RuntimeError: when a trial field is used and ``dest`` does not have one row of blocks per
            quadrature point, one column of blocks per trial partition node, or blocks with
            ``trial.node_dof_count`` columns.
    """
    # Launch arguments shared by every configuration
    elt_arg = domain.element_arg_value(device=device)
    elt_index_arg = domain.element_index_arg_value(device=device)

    # `fields` may also contain domain entries; only actual fields contribute evaluation arguments
    field_arg_values = FieldStruct()
    for k, v in fields.items():
        if not isinstance(v, GeometryDomain):
            setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)

    if isinstance(dest, FieldRestriction):
        # One thread per node of the destination restriction
        dest_node_arg = dest.space_restriction.node_arg(device=device)
        dest_eval_arg = dest.field.eval_arg_value(device=device)

        wp.launch(
            kernel=kernel,
            dim=dest.space_restriction.node_count(),
            inputs=[
                elt_arg,
                elt_index_arg,
                dest_node_arg,
                dest_eval_arg,
                field_arg_values,
                value_struct_values,
            ],
            device=device,
        )
        return

    if quadrature is None:
        # Free interpolation: `dim` samples, results written to the `dest` array
        wp.launch(
            kernel=kernel,
            dim=dim,
            inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
            device=device,
        )
        return

    qp_arg = quadrature.arg_value(device)
    qp_element_index_arg = quadrature.element_index_arg_value(device)
    if trial is None:
        # One thread per quadrature evaluation point, results written to the `dest` array
        wp.launch(
            kernel=kernel,
            dim=quadrature.evaluation_point_count(),
            inputs=[qp_arg, qp_element_index_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
            device=device,
        )
        return

    # Trial-field interpolation: the kernel emits block triplets, assembled afterwards into `dest`
    point_count = quadrature.total_point_count()
    trial_node_count = trial.space_partition.node_count()
    max_nodes_per_element = trial.space.topology.MAX_NODES_PER_ELEMENT

    if dest.nrow != point_count or dest.ncol != trial_node_count:
        raise RuntimeError(
            f"'dest' matrix must have {point_count} rows and {trial_node_count} columns of blocks"
        )
    if dest.block_shape[1] != trial.node_dof_count:
        raise RuntimeError(f"'dest' matrix blocks must have {trial.node_dof_count} columns")

    # One triplet slot per (quadrature point, element node) pair
    nnz = point_count * max_nodes_per_element

    triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_values_temp = cache.borrow_temporary(
        temporary_store,
        dtype=dest.scalar_type,
        shape=(nnz, *dest.block_shape),
        device=device,
    )
    triplet_rows = triplet_rows_temp.array
    triplet_cols = triplet_cols_temp.array
    triplet_values = triplet_values_temp.array
    # Slots the kernel does not write keep row index -1.
    # NOTE(review): presumably bsr_set_from_triplets discards such entries; confirm.
    triplet_rows.fill_(-1)
    triplet_values.zero_()

    trial_partition_arg = trial.space_partition.partition_arg_value(device)
    trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)

    wp.launch(
        kernel=kernel,
        dim=(quadrature.evaluation_point_count(), max_nodes_per_element, trial.node_dof_count),
        inputs=[
            qp_arg,
            qp_element_index_arg,
            elt_arg,
            elt_index_arg,
            trial_partition_arg,
            trial_topology_arg,
            field_arg_values,
            value_struct_values,
            triplet_rows,
            triplet_cols,
            triplet_values,
        ],
        device=device,
    )

    bsr_set_from_triplets(dest, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
2235
+
2236
+
2237
# Integrand that simply evaluates `field` at sample `s`.
# `interpolate` substitutes it when it is given a field instead of an integrand,
# passing that field under the parameter name "field".
@integrand
def _identity_field(field: Field, s: Sample):
    return field(s)
2240
+
2241
+
2242
def interpolate(
    integrand: Union[Integrand, FieldLike],
    dest: Optional[Union[DiscreteField, FieldRestriction, BsrMatrix, wp.array]] = None,
    quadrature: Optional[Quadrature] = None,
    dim: int = 0,
    domain: Optional[Domain] = None,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    device=None,
    kernel_options: Optional[Dict[str, Any]] = None,
    temporary_store: Optional[cache.TemporaryStore] = None,
    bsr_options: Optional[Dict[str, Any]] = None,
):
    """
    Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.

    Args:
        integrand: Function to be interpolated: either a function with :func:`warp.fem.integrand` decorator or a field.
            When a field is given, it is evaluated directly at the samples, and the `fields` and `values` arguments are ignored.
        dest: Where to store the interpolation result. Can be either

         - a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
         - a normal warp ``array``, or ``None``. In this case, the interpolation samples will be determined by the `quadrature` or `dim` arguments, in that order.
         - a :class:`warp.sparse.BsrMatrix`, required when the integrand takes a trial field. In this case `quadrature` must be provided,
           and the matrix must have one row of blocks per quadrature point and one column of blocks per node of the trial field's space partition.
        quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
        dim: Number of interpolation samples if `dest` is not a discrete field or restriction and `quadrature` is ``None``.
          In this case, the ``Sample`` passed to the `integrand` will be invalid, but the sample point index ``s.qp_index`` can be used to define custom interpolation logic.
        domain: Interpolation domain. Used to build the restriction when `dest` is a discrete field, and required when `dest`
          is neither a discrete field nor a field restriction and `quadrature` is ``None``.
        fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
        device: Device on which to perform the interpolation
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
        temporary_store: shared pool from which to allocate temporary arrays
        bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`

    Raises:
        ValueError: if `integrand` is neither an integrand nor a field, if it takes a test field, if a trial field is used
            without both a `quadrature` and a :class:`warp.sparse.BsrMatrix` `dest`, or if no interpolation domain can be determined.
        RuntimeError: if the block dimensions of a :class:`warp.sparse.BsrMatrix` `dest` do not match the quadrature and trial field.
    """

    if isinstance(integrand, FieldLike):
        # Interpolating a field directly: wrap it in the identity integrand
        fields = {"field": integrand}
        values = {}
        integrand = _identity_field

    if fields is None:
        fields = {}

    if values is None:
        values = {}

    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @integrand decorator")

    arguments = _parse_integrand_arguments(integrand, fields)
    if arguments.test_name:
        raise ValueError(f"Test field '{arguments.test_name}' may not be used for interpolation")
    if arguments.trial_name and (quadrature is None or not isinstance(dest, BsrMatrix)):
        raise ValueError(
            f"Interpolation using trial field '{arguments.trial_name}' requires 'quadrature' to be provided and 'dest' to be a `warp.sparse.BsrMatrix`"
        )

    if isinstance(dest, DiscreteField):
        dest = make_restriction(dest, domain=domain)

    # The destination restriction, then the quadrature, take precedence over the `domain` argument
    if isinstance(dest, FieldRestriction):
        domain = dest.domain
    elif quadrature is not None:
        domain = quadrature.domain

    if domain is None:
        # Without a domain, the launch cannot build element arguments; fail early with a clear message
        raise ValueError(
            "An interpolation 'domain' must be provided when 'dest' is not a discrete field or field restriction and 'quadrature' is None"
        )

    if arguments.domain_name:
        arguments.field_args[arguments.domain_name] = domain

    _find_integrand_operators(integrand, arguments.field_args)

    kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
        integrand=integrand,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        arguments=arguments,
        kernel_options=kernel_options,
    )

    return _launch_interpolate_kernel(
        integrand=integrand,
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        dim=dim,
        trial=fields.get(arguments.trial_name),
        fields=arguments.field_args,
        values=values,
        temporary_store=temporary_store,
        bsr_options=bsr_options,
        device=device,
    )