warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2507 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ast
17
+ import inspect
18
+ import textwrap
19
+ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
20
+
21
+ import warp as wp
22
+ import warp._src.fem.operator as operator
23
+ from warp._src.codegen import Struct, StructInstance, get_annotations
24
+ from warp._src.fem import cache
25
+ from warp._src.fem.domain import GeometryDomain
26
+ from warp._src.fem.field import (
27
+ DiscreteField,
28
+ FieldLike,
29
+ FieldRestriction,
30
+ GeometryField,
31
+ LocalTestField,
32
+ LocalTrialField,
33
+ TestField,
34
+ TrialField,
35
+ make_restriction,
36
+ )
37
+ from warp._src.fem.field.virtual import (
38
+ make_bilinear_dispatch_kernel,
39
+ make_linear_dispatch_kernel,
40
+ )
41
+ from warp._src.fem.linalg import array_axpy, basis_coefficient
42
+ from warp._src.fem.operator import (
43
+ Integrand,
44
+ Operator,
45
+ integrand,
46
+ )
47
+ from warp._src.fem.quadrature import Quadrature, RegularQuadrature
48
+ from warp._src.fem.types import (
49
+ NULL_DOF_INDEX,
50
+ NULL_ELEMENT_INDEX,
51
+ NULL_NODE_INDEX,
52
+ OUTSIDE,
53
+ Coords,
54
+ DofIndex,
55
+ Domain,
56
+ Field,
57
+ Sample,
58
+ make_free_sample,
59
+ )
60
+ from warp._src.fem.utils import type_zero_element
61
+ from warp._src.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
62
+ from warp._src.types import is_array, type_repr, type_scalar_type, type_size, type_to_warp
63
+ from warp._src.utils import array_cast, warn
64
+
65
+ _wp_module_name_ = "warp.fem.integrate"
66
+
67
+
68
def _resolve_path(func, node):
    """
    Resolves variable and path from ast node/attribute (adapted from warp._src.codegen)

    Args:
        func: Python function whose closure variables and globals are searched
            for the root name of the path.
        node: AST expression to resolve. Only a bare ``ast.Name`` or a chain of
            ``ast.Attribute`` nodes rooted at an ``ast.Name`` can be resolved.

    Returns:
        A ``(value, path)`` tuple. ``path`` is the list of names collected from
        the node, root first. ``value`` is the object obtained by looking up
        ``path[0]`` in the closure (or, failing that, the globals) of ``func``
        and then following the remaining names with ``getattr``; it is ``None``
        if any step cannot be resolved or if the chain is not rooted at a name.
    """

    names = []

    while isinstance(node, ast.Attribute):
        names.append(node.attr)
        node = node.value

    rooted_at_name = isinstance(node, ast.Name)
    if rooted_at_name:
        names.append(node.id)

    # reverse the list since the ast presents the names in backward order
    path = names[::-1]

    if not rooted_at_name:
        # Either there is no name at all, or the attribute chain hangs off an
        # arbitrary expression (call result, subscript, ...). The trailing
        # attribute names are not variables of `func`, so do not look them up.
        return None, path

    root = path[0]
    try:
        # look up in closure variables
        # (`index` raises ValueError when `root` is not a free variable, and
        # reading an empty cell raises ValueError too)
        idx = func.__code__.co_freevars.index(root)
        value = func.__closure__[idx].cell_contents
    except ValueError:
        # look up in global variables
        value = func.__globals__.get(root)

    for attr in path[1:]:
        if value is None:
            break
        value = getattr(value, attr, None)

    return value, path
102
+
103
+
104
class IntegrandVisitor(ast.NodeTransformer):
    """AST visitor that tracks which names and expressions of an integrand refer to fields or domains.

    Calls that evaluate an operator on a tracked field, and calls to other
    integrands, are forwarded to ``_process_operator_call`` and
    ``_process_integrand_call`` respectively. Those two hooks are not defined
    on this class; subclasses must provide them.
    """

    class FieldInfo(NamedTuple):
        # field or domain object bound to the symbol
        field: FieldLike
        # type the symbol is annotated with in the integrand signature
        abstract_type: type
        # argument type used for that field or domain
        concrete_type: type
        # name of the originating argument, extended with ".<operator name>"
        # for values produced by field-returning operators
        root_arg_name: str
        # name of the argument in the integrand currently being visited
        local_arg_name: str

    def __init__(
        self,
        integrand: Integrand,
        field_info: Dict[str, FieldInfo],
    ):
        """
        Args:
            integrand: integrand whose function is being visited
            field_info: initial mapping from symbol name to field information; copied, not modified
        """
        self._integrand = integrand
        # symbol name -> FieldInfo, extended when a field is assigned to a new variable
        self._field_symbols = field_info.copy()
        # AST call node -> FieldInfo, for operator calls that themselves yield a field
        self._field_nodes = {}
        # argument name -> annotation AST node, filled by `visit_FunctionDef`
        self._field_arg_annotation_nodes = {}

    @staticmethod
    def _build_field_info(integrand: Integrand, field_args: Dict[str, FieldLike]):
        """Builds the ``FieldInfo`` mapping for the top-level field and domain arguments of ``integrand``."""

        def get_concrete_type(field: Union[FieldLike, Domain]):
            if isinstance(field, FieldLike):
                return field.ElementEvalArg
            elif isinstance(field, GeometryDomain):
                return field.DomainArg
            return field.ElementArg

        return {
            name: IntegrandVisitor.FieldInfo(
                field=field,
                abstract_type=integrand.argspec.annotations[name],
                concrete_type=get_concrete_type(field),
                root_arg_name=name,
                local_arg_name=name,
            )
            for name, field in field_args.items()
        }

    def _get_field_info(self, node: ast.expr):
        """Returns the ``FieldInfo`` associated with ``node``, or ``None`` if it is not a tracked field.

        Recorded call nodes are looked up first, then plain names in the symbol table.
        """
        field_info = self._field_nodes.get(node)
        if field_info is None and isinstance(node, ast.Name):
            field_info = self._field_symbols.get(node.id)

        return field_info

    def visit_Call(self, call: ast.Call):
        """Processes field evaluations, operator calls and nested integrand calls."""
        # visit children first so that nested calls are already processed
        call = self.generic_visit(call)

        callee = getattr(call.func, "id", None)
        if callee in self._field_symbols:
            # Shortcut for evaluating fields as f(x...)
            field_info = self._field_symbols[callee]

            # Replace with default call operator
            default_operator = field_info.abstract_type.call_operator

            self._process_operator_call(call, callee, default_operator, field_info)

            return call

        func, _ = _resolve_path(self._integrand.func, call.func)

        if isinstance(func, Operator) and len(call.args) > 0:
            # Evaluating operators as op(field, x, ...)
            field_info = self._get_field_info(call.args[0])
            if field_info is not None:
                self._process_operator_call(call, func, func, field_info)

                if func.field_result:
                    # The operator yields a field: remember the call node so that
                    # further operators may be applied to its result
                    res = func.field_result(field_info.field)
                    self._field_nodes[call] = IntegrandVisitor.FieldInfo(
                        field=res[0],
                        abstract_type=res[1],
                        concrete_type=res[2],
                        local_arg_name=field_info.local_arg_name,
                        root_arg_name=f"{field_info.root_arg_name}.{func.name}",
                    )

        if isinstance(func, Integrand):
            callee_field_args = self._get_callee_field_args(func, call.args)
            self._process_integrand_call(call, func, callee_field_args)

        return call

    def visit_Assign(self, node: ast.Assign):
        """Propagates field information through assignments of the form ``name = <field expression>``.

        Raises:
            NotImplementedError: if a field is assigned to anything other than a single simple variable
        """
        node = self.generic_visit(node)

        # Check if we're assigning a field
        src_field_info = self._get_field_info(node.value)
        if src_field_info is not None:
            if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
                raise NotImplementedError("warp.fem Fields and Domains may only be assigned to simple variables")

            self._field_symbols[node.targets[0].id] = src_field_info

        return node

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Records the annotation AST node of each argument in ``node.args.args``, then visits the body."""
        # record field arg annotation nodes
        for arg in node.args.args:
            self._field_arg_annotation_nodes[arg.arg] = arg.annotation

        return self.generic_visit(node)

    def _get_callee_field_args(self, callee: Integrand, args: List[ast.AST]):
        """Maps the field values passed at a call site to the field arguments of ``callee``.

        Field values are matched to the ``Field``/``Domain``-annotated arguments
        of ``callee`` in order of appearance; non-field call arguments are ignored.

        Raises:
            TypeError: if fewer field values are passed than ``callee`` expects, or
                if a passed value's abstract type differs from the argument annotation
        """
        # Get field types for call site arguments
        call_site_field_args: List[IntegrandVisitor.FieldInfo] = []
        for arg in args:
            field_info = self._get_field_info(arg)
            if field_info is not None:
                call_site_field_args.append(field_info)

        # reversed so that `pop()` yields the values in call order
        call_site_field_args.reverse()

        # Pass to callee in same order
        callee_field_args = {}
        for arg in callee.argspec.args:
            arg_type = callee.argspec.annotations[arg]
            if arg_type in (Field, Domain):
                if not call_site_field_args:
                    # report the problem instead of a bare "pop from empty list"
                    raise TypeError(
                        f"No {arg_type.__name__} value was passed for argument '{arg}' of '{callee.name}'"
                    )
                passed_field_info = call_site_field_args.pop()
                if passed_field_info.abstract_type != arg_type:
                    raise TypeError(
                        f"Attempting to pass a {passed_field_info.abstract_type.__name__} to argument '{arg}' of '{callee.name}' expecting a {arg_type.__name__}"
                    )
                callee_field_args[arg] = IntegrandVisitor.FieldInfo(
                    field=passed_field_info.field,
                    abstract_type=passed_field_info.abstract_type,
                    concrete_type=passed_field_info.concrete_type,
                    local_arg_name=arg,
                    root_arg_name=passed_field_info.root_arg_name,
                )

        return callee_field_args
239
+
240
+
241
class IntegrandOperatorParser(IntegrandVisitor):
    """Visitor that reports every operator applied to a field, without using the transformed AST.

    The callback, when provided, is invoked as ``callback(field_info, operator)``
    for each operator call found in the integrand and in the integrands it calls.
    """

    def __init__(self, integrand: Integrand, field_info: Dict[str, IntegrandVisitor.FieldInfo], callback: Callable):
        super().__init__(integrand, field_info)
        self._operator_callback = callback

    def _process_operator_call(
        self, call: ast.Call, callee: Union[str, Operator], operator: Operator, field_info: IntegrandVisitor.FieldInfo
    ):
        # `apply` allows the callback to be omitted, so tolerate None here
        if self._operator_callback is not None:
            self._operator_callback(field_info, operator)

    def _process_integrand_call(
        self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
    ):
        # Recurse into the called integrand, forwarding the same callback.
        # `callee_field_args` has already been computed by `visit_Call`.
        callee_parser = IntegrandOperatorParser(callee, callee_field_args, callback=self._operator_callback)
        callee_parser._apply()

    def _apply(self):
        """Parses the source of the integrand's function and visits the resulting tree."""
        source = textwrap.dedent(inspect.getsource(self._integrand.func))
        tree = ast.parse(source)
        self.visit(tree)

    @staticmethod
    def apply(
        integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Optional[Callable] = None
    ) -> None:
        """Visits ``integrand`` and reports the operators applied to its field arguments.

        Args:
            integrand: integrand to inspect
            field_args: mapping from argument name to the field or domain bound to it
            operator_callback: optional callable invoked as ``operator_callback(field_info, operator)``
                for each operator call on a field
        """
        field_info = IntegrandVisitor._build_field_info(integrand, field_args)
        IntegrandOperatorParser(integrand, field_info, callback=operator_callback)._apply()
269
+
270
+
271
class IntegrandTransformer(IntegrandVisitor):
    """Visitor rewriting an integrand's calls so that they target concrete implementations.

    Operator calls on field arguments are redirected to the function returned by the
    operator's resolver for the concrete field, and calls to nested integrands are
    redirected to their own transformed function.
    """

    def _process_operator_call(
        self, call: ast.Call, callee: Union[str, Operator], operator: Operator, field_info: IntegrandVisitor.FieldInfo
    ):
        """Redirect an operator call to the implementation resolved for ``field_info.field``.

        Raises:
            TypeError: If the operator's resolver does not yield a ``wp.Function`` for the field.
        """
        field = field_info.field

        try:
            # Retrieve the function pointer corresponding to the operator implementation for the field type
            pointer = operator.resolver(field)
            if not isinstance(pointer, wp.Function):
                raise NotImplementedError(operator.resolver.__name__)

        except (AttributeError, NotImplementedError) as e:
            raise TypeError(
                f"Operator {operator.func.__name__} is not defined for {field_info.abstract_type.__name__} {field.name}"
            ) from e

        # Save the pointer as an attribute that can be accessed from the calling scope
        # (use the annotation node of the argument this field is constructed from)
        callee_node = self._field_arg_annotation_nodes[field_info.local_arg_name]
        setattr(self._field_symbols[field_info.local_arg_name].abstract_type, pointer.key, pointer)
        call.func = ast.Attribute(value=callee_node, attr=pointer.key, ctx=ast.Load())

        # For shortcut default operator syntax, insert callee as first argument
        if not isinstance(callee, Operator):
            call.args = [ast.Name(id=callee, ctx=ast.Load()), *call.args]

        # Replace first argument with selected attribute.
        # The expression context is set explicitly, like for every other node built in this module.
        if operator.attr:
            call.args[0] = ast.Attribute(value=call.args[0], attr=operator.attr, ctx=ast.Load())

    def _process_integrand_call(
        self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
    ):
        """Transform the nested integrand and call the resulting function through an attribute of the callee."""
        callee_field_args = self._get_callee_field_args(callee, call.args)
        transformer = IntegrandTransformer(callee, callee_field_args)
        key = transformer._apply().key
        call.func = ast.Attribute(
            value=call.func,
            attr=key,
            ctx=ast.Load(),
        )

    def _apply(self) -> wp.Function:
        """Return the integrand function specialized for the concrete field types.

        The function is obtained from the cache with this visitor registered as code
        transformer, then stored on the integrand under the function's key.
        """
        # Transform field evaluation calls
        field_info = self._field_symbols

        # Specialize field argument types
        argspec = self._integrand.argspec
        annotations = argspec.annotations.copy()
        annotations.update({name: f.concrete_type for name, f in field_info.items()})

        suffix = "_".join([f.field.name for f in field_info.values()])
        func = cache.get_integrand_function(
            integrand=self._integrand,
            suffix=suffix,
            annotations=annotations,
            code_transformers=[self],
        )

        # Expose the specialized function as an attribute of the integrand,
        # which is how `_process_integrand_call` refers to it from the caller's code
        setattr(self._integrand, func.key, func)

        return func

    @staticmethod
    def apply(integrand: Integrand, field_args: Dict[str, FieldLike]) -> wp.Function:
        """Transform ``integrand`` for the given fields and return the specialized function."""
        field_info = IntegrandVisitor._build_field_info(integrand, field_args)
        return IntegrandTransformer(integrand, field_info)._apply()
340
+
341
+
342
class IntegrandArguments(NamedTuple):
    """Arguments of an integrand sorted by role, as produced by ``_parse_integrand_arguments``."""

    # Arguments annotated as `Field`, mapped to the object passed for them
    field_args: Dict[str, Union[FieldLike, GeometryDomain]]
    # Remaining arguments (neither `Field`, `Domain` nor `Sample`), mapped to their annotated type
    value_args: Dict[str, Any]
    # Name of the argument annotated as `Domain`, or None if the integrand has none
    domain_name: Optional[str]
    # Name of the argument annotated as `Sample`, or None if the integrand has none
    sample_name: Optional[str]
    # Name of the field argument bound to a `TestField`, or None
    test_name: Optional[str]
    # Name of the field argument bound to a `TrialField`, or None
    trial_name: Optional[str]
349
+
350
+
351
def _parse_integrand_arguments(
    integrand: Integrand,
    fields: Dict[str, FieldLike],
):
    """Sort the annotated arguments of ``integrand`` by role.

    Arguments annotated as ``Field`` are matched with the entries of ``fields``;
    ``Domain`` and ``Sample`` arguments are recorded by name; every other
    annotated argument is treated as a plain value.

    Raises:
        ValueError: If a ``Field`` argument has no entry in ``fields``, if more than one
            test or trial field is passed, if a passed field is not a ``FieldLike``, or if
            ``fields`` contains an entry for an argument that is not annotated as ``Field``.
        SyntaxError: If the integrand declares more than one ``Domain`` or ``Sample`` argument.

    Returns:
        An ``IntegrandArguments`` tuple.
    """
    field_by_name = {}
    type_by_value_name = {}

    domain_arg = None
    sample_arg = None
    test_arg = None
    trial_arg = None

    for name, annotation in integrand.argspec.annotations.items():
        if annotation == Field:
            try:
                passed = fields[name]
            except KeyError as err:
                raise ValueError(f"Missing field for argument '{name}' of integrand '{integrand.name}'") from err

            # At most one test field and one trial field may be passed
            if isinstance(passed, TestField):
                if test_arg is not None:
                    raise ValueError(f"More than one test field argument: '{test_arg}' and '{name}'")
                test_arg = name
            elif isinstance(passed, TrialField):
                if trial_arg is not None:
                    raise ValueError(f"More than one trial field argument: '{trial_arg}' and '{name}'")
                trial_arg = name
            elif not isinstance(passed, FieldLike):
                raise ValueError(f"Passed field argument '{name}' is not a proper Field")

            field_by_name[name] = passed
        else:
            # Only arguments annotated as `Field` may receive a field
            if name in fields:
                raise ValueError(
                    f"Cannot pass a field argument to '{name}' of '{integrand.name}' which is not of type 'Field'"
                )

            if annotation == Domain:
                if domain_arg is not None:
                    raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Domain")
                domain_arg = name
            elif annotation == Sample:
                if sample_arg is not None:
                    raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Sample")
                sample_arg = name
            else:
                type_by_value_name[name] = annotation

    return IntegrandArguments(field_by_name, type_by_value_name, domain_arg, sample_arg, test_arg, trial_arg)
403
+
404
+
405
def _check_field_compat(integrand: Integrand, arguments: IntegrandArguments, domain: GeometryDomain):
    """Verify that geometry fields passed to the integrand match the integration domain.

    Only ``GeometryField`` arguments are checked, and only when ``domain`` is not ``None``.

    Raises:
        ValueError: If a field's geometry or element kind differs from that of ``domain``.
    """
    for arg_name, candidate in arguments.field_args.items():
        if not isinstance(candidate, GeometryField) or domain is None:
            continue

        if candidate.geometry != domain.geometry:
            raise ValueError(f"Field '{arg_name}' must be defined on the same geometry as the integration domain")

        if candidate.element_kind != domain.element_kind:
            raise ValueError(
                f"Field '{arg_name}' is not defined on the same kind of elements (cells or sides) as the integration domain. Maybe a forgotten `.trace()`?"
            )
415
+
416
+
417
def _find_integrand_operators(integrand: Integrand, field_args: Dict[str, FieldLike]):
    """Populate ``integrand.operators`` if it has not been built yet.

    The resulting dictionary maps the name of a root field argument to the set of
    operators applied to it. It does not depend on the concrete field types, so it
    is computed at most once per integrand.
    """
    if integrand.operators is not None:
        # Already built by a previous call
        return

    used_operators = {}

    def operator_callback(field: IntegrandVisitor.FieldInfo, op: Operator):
        # Group operators by the root-level argument the field originates from
        used_operators.setdefault(field.root_arg_name, set()).add(op)

    IntegrandOperatorParser.apply(integrand, field_args, operator_callback=operator_callback)

    integrand.operators = used_operators
433
+
434
+
435
def _notify_operator_usage(
    integrand: Integrand,
    field_args: Dict[str, FieldLike],
):
    """Tell each passed field which operators the integrand applies to it.

    Fields with no entry in ``integrand.operators`` are notified with an empty set.
    """
    for arg_name in field_args:
        operators_for_arg = integrand.operators.get(arg_name, set())
        field_args[arg_name].notify_operator_usage(operators_for_arg)
441
+
442
+
443
def _gen_field_struct(field_args: Dict[str, FieldLike]):
    """Build the struct type holding the evaluation argument of every passed field.

    One member is declared per entry of ``field_args``, typed with the field's
    ``EvalArg``; ``GeometryDomain`` entries are skipped. The struct is obtained from
    the cache, keyed by a suffix combining member names and argument struct class names.
    """

    class Fields:
        pass

    member_types = get_annotations(Fields)

    for member, field in field_args.items():
        if not isinstance(field, GeometryDomain):
            setattr(Fields, member, field.EvalArg())
            member_types[member] = field.EvalArg

    try:
        Fields.__annotations__ = member_types
    except AttributeError:
        # NOTE(review): fallback kept as in the original; assigning an attribute on
        # `Fields.__dict__` looks like it cannot succeed -- confirm intent
        Fields.__dict__.__annotations__ = member_types

    suffix_parts = [f"{member}_{eval_arg.cls.__qualname__}" for member, eval_arg in member_types.items()]
    suffix = "_".join(suffix_parts)

    return cache.get_struct(Fields, suffix=suffix)
463
+
464
+
465
def _get_trial_arg():
    """Placeholder referenced by name inside kernel sources.

    ``PassFieldArgsToIntegrand.visit_Call`` replaces calls to this function with an
    access to the trial field member of the fields struct; it does nothing when called.
    """
    return None
467
+
468
+
469
def _get_test_arg():
    """Placeholder referenced by name inside kernel sources.

    ``PassFieldArgsToIntegrand.visit_Call`` replaces calls to this function with an
    access to the test field member of the fields struct; it does nothing when called.
    """
    return None
471
+
472
+
473
class PassFieldArgsToIntegrand(ast.NodeTransformer):
    """AST pass wiring kernel-level variables to the arguments of the integrand call.

    In the visited source:

    - the argument list of every call to ``func_name`` is rebuilt from the integrand's
      own signature (``arg_names``), each argument being taken from the kernel's
      sample, fields struct, values struct or domain variables;
    - calls to ``_get_test_arg`` / ``_get_trial_arg`` are replaced with the matching
      member of the fields struct.
    """

    def __init__(
        self,
        arg_names: List[str],
        parsed_args: IntegrandArguments,
        integrand_func: wp.Function,
        func_name: str = "integrand_func",
        fields_var_name: str = "fields",
        values_var_name: str = "values",
        domain_var_name: str = "domain_arg",
        domain_index_var_name: str = "domain_index_arg",
        sample_var_name: str = "sample",
        field_wrappers_attr: str = "_field_wrappers",
    ):
        # Integrand signature, in declaration order
        self._arg_names = arg_names

        # Role of each integrand argument
        self._field_args = parsed_args.field_args
        self._value_args = parsed_args.value_args
        self._domain_name = parsed_args.domain_name
        self._sample_name = parsed_args.sample_name
        self._test_name = parsed_args.test_name
        self._trial_name = parsed_args.trial_name

        # Names used in the kernel source
        self._func_name = func_name
        self._fields_var_name = fields_var_name
        self._values_var_name = values_var_name
        self._domain_var_name = domain_var_name
        self._domain_index_var_name = domain_index_var_name
        self._sample_var_name = sample_var_name

        # Must be set before registering, as registration stores the wrappers under this attribute
        self._field_wrappers_attr = field_wrappers_attr
        self._register_integrand_field_wrappers(integrand_func, parsed_args.field_args)

    class _FieldWrappers:
        pass

    def _register_integrand_field_wrappers(self, integrand_func: wp.Function, fields: Dict[str, FieldLike]):
        # Mechanism to pass the geometry argument only once to the root kernel
        # Field wrappers are used to forward it to all fields in nested integrand calls
        wrappers = PassFieldArgsToIntegrand._FieldWrappers()
        for arg_name, passed in fields.items():
            if isinstance(passed, FieldLike):
                wrapper = passed.ElementEvalArg
            elif isinstance(passed, GeometryDomain):
                wrapper = passed.DomainArg
            else:
                continue
            setattr(wrappers, arg_name, wrapper)
        setattr(integrand_func, self._field_wrappers_attr, wrappers)

    @staticmethod
    def _load_name(identifier):
        # Node reading the variable `identifier`
        return ast.Name(id=identifier, ctx=ast.Load())

    @staticmethod
    def _load_member(owner, member):
        # Node reading `owner.member`, where `owner` is a variable name
        return ast.Attribute(
            value=ast.Name(id=owner, ctx=ast.Load()),
            attr=member,
            ctx=ast.Load(),
        )

    def _emit_field_wrapper_call(self, field_name, *data_arguments):
        # Builds `<func>.<wrappers attr>.<field_name>(<domain var>, *data_arguments)`
        wrappers_node = ast.Attribute(
            value=ast.Name(id=self._func_name, ctx=ast.Load()),
            attr=self._field_wrappers_attr,
            ctx=ast.Load(),
        )
        wrapper_node = ast.Attribute(
            value=wrappers_node,
            attr=field_name,
            ctx=ast.Load(),
        )
        return ast.Call(
            func=wrapper_node,
            args=[
                ast.Name(id=self._domain_var_name, ctx=ast.Load()),
                *data_arguments,
            ],
            keywords=[],
        )

    def _emit_integrand_argument(self, arg):
        # Expression to pass for the integrand argument named `arg`
        if arg == self._domain_name:
            return self._emit_field_wrapper_call(arg, self._load_name(self._domain_index_var_name))
        if arg == self._sample_name:
            return self._load_name(self._sample_var_name)
        if arg in self._field_args:
            return self._emit_field_wrapper_call(arg, self._load_member(self._fields_var_name, arg))
        if arg in self._value_args:
            return self._load_member(self._values_var_name, arg)
        raise RuntimeError(f"Unhandled argument {arg}")

    def visit_Call(self, call: ast.Call):
        call = self.generic_visit(call)

        target = getattr(call.func, "id", None)

        if target == self._func_name:
            # Replace function arguments with our generated structs
            call.args.clear()
            for arg in self._arg_names:
                call.args.append(self._emit_integrand_argument(arg))
            return call

        if target == _get_test_arg.__name__:
            return self._load_member(self._fields_var_name, self._test_name)

        if target == _get_trial_arg.__name__:
            return self._load_member(self._fields_var_name, self._trial_name)

        return call
595
+
596
+
597
def _combined_kernel_options(integrand_options: Optional[Dict[str, Any]], call_site_options: Optional[Dict[str, Any]]):
    """Merge integrand-level and call-site kernel options.

    Call-site options take precedence over integrand options. Neither input is
    modified, and the result is always a new dictionary, so callers may mutate it
    without affecting the dictionaries they passed in.

    Args:
        integrand_options: Options attached to the integrand, or ``None``.
        call_site_options: Options given at the call site, or ``None``.

    Returns:
        A new dictionary with the merged options (empty if both inputs are ``None``).
    """
    options = {} if integrand_options is None else integrand_options.copy()
    if call_site_options is not None:
        options.update(call_site_options)
    return options
605
+
606
+
607
# Default `tile_size` of `get_integrate_constant_kernel`: number of evaluation points
# handled per block, whose values are summed together before being added to the result
_INTEGRATE_CONSTANT_TILE_SIZE = 256
608
+
609
+
610
def get_integrate_constant_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: Struct,
    ValueStruct: Struct,
    accumulate_dtype,
    tile_size: int = _INTEGRATE_CONSTANT_TILE_SIZE,
):
    """Return the kernel function integrating an integrand without test nor trial field.

    Each thread evaluates the integrand at one quadrature evaluation point; values are
    summed per tile and the tile sums are atomically added to ``result[0]``.

    Args:
        integrand_func: Transformed integrand function to evaluate.
        domain: Integration domain.
        quadrature: Quadrature providing the evaluation points.
        FieldStruct: Type of the ``fields`` kernel argument.
        ValueStruct: Type of the ``values`` kernel argument.
        accumulate_dtype: Type in which values are accumulated.
        tile_size: Number of evaluation points per block.

    Returns:
        The (not yet compiled) kernel function.
    """
    zero_element = type_zero_element(accumulate_dtype)

    def integrate_kernel_fn(
        qp_count: int,
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=accumulate_dtype),
    ):
        # Flat index of the evaluation point handled by this thread
        block_index, lane = wp.tid()
        qp_eval_index = block_index * tile_size + lane

        # Threads past the last evaluation point do not return early: they contribute
        # a zero value so that all lanes take part in the tile sum below
        if qp_eval_index >= qp_count:
            domain_element_index, qp = NULL_ELEMENT_INDEX, 0
        else:
            domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)

        if domain_element_index == NULL_ELEMENT_INDEX:
            element_index = NULL_ELEMENT_INDEX
        else:
            element_index = domain.element_index(domain_index_arg, domain_element_index)

        if element_index == NULL_ELEMENT_INDEX:
            val = zero_element()
        else:
            qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
            qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

            # No test nor trial function for constant forms
            test_dof_index = NULL_DOF_INDEX
            trial_dof_index = NULL_DOF_INDEX

            sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            vol = domain.element_measure(domain_arg, sample)

            # The argument list of `integrand_func` is a placeholder, rewritten by
            # `PassFieldArgsToIntegrand` when the kernel is generated
            val = accumulate_dtype(qp_weight * vol * integrand_func(sample, fields, values))

        # Sum over the tile, then a single atomic add per tile
        tile_integral = wp.tile_sum(wp.tile(val))
        wp.tile_atomic_add(result, tile_integral, offset=0)

    return integrate_kernel_fn
663
+
664
+
665
def get_integrate_linear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: Struct,
    ValueStruct: Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Return the kernel function integrating a linear form (test field only) with a quadrature.

    Each thread handles one ``(node, test dof)`` pair: it loops over the elements listed
    for the node by the test space restriction and over their quadrature points,
    accumulates the weighted integrand values in ``accumulate_dtype`` and adds the
    total to ``result[node_index, test_dof]``.

    Returns:
        The (not yet compiled) kernel function.
    """

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, test_dof = wp.tid()
        node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
        if node_index == NULL_NODE_INDEX:
            return

        # Range of (element, node-in-element) entries associated with this node
        element_beg, element_end = test.space_restriction.node_element_range(test_arg, node_index)

        # Linear form: no trial function
        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        for n in range(element_beg, element_end):
            node_element_index = test.space_restriction.node_element_index(test_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)

            qp_point_count = quadrature.point_count(
                domain_arg, qp_arg, node_element_index.domain_element_index, element_index
            )
            for k in range(qp_point_count):
                qp_index = quadrature.point_index(
                    domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
                )
                qp_coords = quadrature.point_coords(
                    domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
                )
                qp_weight = quadrature.point_weight(
                    domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
                )

                vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

                sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
                # The argument list of `integrand_func` is a placeholder, rewritten by
                # `PassFieldArgsToIntegrand` when the kernel is generated
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(qp_weight * vol * val)

        # Added to (not assigned into) the output entry
        result[node_index, test_dof] += output_dtype(val_sum)

    return integrate_kernel_fn
725
+
726
+
727
def get_integrate_linear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: Struct,
    ValueStruct: Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Return the kernel function integrating a linear form using the test space nodes as quadrature.

    Each thread handles one ``(node, dof)`` pair. For every element listed for the node,
    the integrand is evaluated at the node's coordinates in that element, weighted by the
    node quadrature weight and the element measure; the accumulated total is added to
    ``result[partition_node_index, dof]``.

    Returns:
        The (not yet compiled) kernel function.
    """

    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        test_topo_arg: test.space.topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, dof = wp.tid()

        partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
        if partition_node_index == NULL_NODE_INDEX:
            return

        element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)

        # Linear form: no trial function
        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        for n in range(element_beg, element_end):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # The node index is only looked up on the first element and reused afterwards
            if n == element_beg:
                node_index = test.space.topology.element_node_index(
                    domain_arg, test_topo_arg, element_index, node_element_index.node_index_in_element
                )

            # `_get_test_arg()` is a placeholder, replaced with the test field member of
            # `fields` by `PassFieldArgsToIntegrand`
            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg().space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )

            # Skip elements for which the node coordinates are flagged as OUTSIDE
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg().space_arg,
                    element_index,
                    node_element_index.node_index_in_element,
                )

                test_dof_index = DofIndex(node_element_index.node_index_in_element, dof)

                # The node index and node weight stand in for the quadrature point index and weight
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        result[partition_node_index, dof] += output_dtype(val_sum)

    return integrate_kernel_fn
799
+
800
+
801
def get_integrate_linear_local_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: Struct,
    ValueStruct: Struct,
    test: LocalTestField,
):
    """Return the kernel function evaluating a linear form per quadrature evaluation point.

    Unlike the node-based kernels, nothing is accumulated here: each thread writes the
    weighted integrand value for one ``(evaluation point, taylor_dof, test_dof)`` triple
    to ``result``. ``test`` is not referenced in the kernel body.

    Returns:
        The (not yet compiled) kernel function.
    """

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array3d(dtype=float),
    ):
        qp_eval_index, taylor_dof, test_dof = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)

        # Nothing is written for evaluation points without a valid element
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)
        if element_index == NULL_ELEMENT_INDEX:
            return

        qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

        trial_dof_index = NULL_DOF_INDEX
        # For local test fields, the first `DofIndex` component is `taylor_dof` rather than a node index
        test_dof_index = DofIndex(taylor_dof, test_dof)

        sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        # The argument list of `integrand_func` is a placeholder, rewritten by
        # `PassFieldArgsToIntegrand` when the kernel is generated
        val = integrand_func(sample, fields, values)
        result[qp_eval_index, taylor_dof, test_dof] = qp_weight * vol * val

    return integrate_kernel_fn
842
+
843
+
844
def get_integrate_bilinear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: Struct,
    ValueStruct: Struct,
    test: TestField,
    trial: TrialField,
    output_dtype,
    accumulate_dtype,
):
    """Return the kernel function integrating a bilinear form with a quadrature.

    The result is produced as triplets (row index, column index, value block). Each
    thread handles one ``(test node, trial node in element, test dof, trial dof)``
    combination and, for every element listed for the test node, writes one entry of
    the value block at offset ``element * MAX_NODES_PER_ELEMENT + trial_node``.

    Returns:
        The (not yet compiled) kernel function.
    """
    MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()

        test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
        if test_node_index == NULL_NODE_INDEX:
            return

        element_beg, element_end = test.space_restriction.node_element_range(test_arg, test_node_index)

        trial_dof_index = DofIndex(trial_node, trial_dof)

        for element in range(element_beg, element_end):
            test_element_index = test.space_restriction.node_element_index(test_arg, element)
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)

            element_trial_node_count = trial.space.topology.element_node_count(
                domain_arg, trial_topology_arg, element_index
            )
            # Use zero quadrature points when this element has fewer trial nodes than
            # `trial_node`, so that the block value written below is zero
            qp_point_count = wp.where(
                trial_node < element_trial_node_count,
                quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
                0,
            )

            test_dof_index = DofIndex(
                test_element_index.node_index_in_element,
                test_dof,
            )

            val_sum = accumulate_dtype(0.0)

            for k in range(qp_point_count):
                qp_index = quadrature.point_index(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                coords = quadrature.point_coords(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )

                qp_weight = quadrature.point_weight(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))

                sample = Sample(
                    element_index,
                    coords,
                    qp_index,
                    qp_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                # The argument list of `integrand_func` is a placeholder, rewritten by
                # `PassFieldArgsToIntegrand` when the kernel is generated
                val = integrand_func(sample, fields, values)
                val_sum += accumulate_dtype(qp_weight * vol * val)

            block_offset = element * MAX_NODES_PER_ELEMENT + trial_node
            triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)

            # Set row and column indices
            # (only once per block, by the thread handling the first dof pair)
            if test_dof == 0 and trial_dof == 0:
                if trial_node < element_trial_node_count:
                    trial_node_index = trial.space_partition.partition_node_index(
                        trial_partition_arg,
                        trial.space.topology.element_node_index(
                            domain_arg, trial_topology_arg, element_index, trial_node
                        ),
                    )
                else:
                    trial_node_index = NULL_NODE_INDEX  # will get ignored when converting to bsr
                triplet_rows[block_offset] = test_node_index
                triplet_cols[block_offset] = trial_node_index

    return integrate_kernel_fn
942
+
943
+
944
def get_integrate_bilinear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: Struct,
    ValueStruct: Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Return the kernel function integrating a bilinear form using the test space nodes as quadrature.

    Each thread handles one ``(node, test dof, trial dof)`` combination. The same
    node-in-element index is used for the test and trial dof indices, and the row and
    column written for a node are both its partition index, i.e. one diagonal block is
    produced per node.

    Returns:
        The (not yet compiled) kernel function.
    """

    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        test_topo_arg: test.space.topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        local_node_index, test_dof, trial_dof = wp.tid()

        partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
        if partition_node_index == NULL_NODE_INDEX:
            # Mark the block as invalid with negative row and column indices
            triplet_rows[local_node_index] = -1
            triplet_cols[local_node_index] = -1
            return

        element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)

        val_sum = accumulate_dtype(0.0)

        for n in range(element_beg, element_end):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # The node index is only looked up on the first element and reused afterwards
            if n == element_beg:
                node_index = test.space.topology.element_node_index(
                    domain_arg, test_topo_arg, element_index, node_element_index.node_index_in_element
                )

            # `_get_test_arg()` is a placeholder, replaced with the test field member of
            # `fields` by `PassFieldArgsToIntegrand`
            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg().space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )

            # Skip elements for which the node coordinates are flagged as OUTSIDE
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg().space_arg,
                    element_index,
                    node_element_index.node_index_in_element,
                )

                test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
                trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)

                # The node index and node weight stand in for the quadrature point index and weight
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
        triplet_rows[local_node_index] = partition_node_index
        triplet_cols[local_node_index] = partition_node_index

    return integrate_kernel_fn
1021
+
1022
+
1023
def get_integrate_bilinear_local_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: Struct,
    ValueStruct: Struct,
    test: LocalTestField,
    trial: LocalTrialField,
):
    """Return the kernel function evaluating a bilinear form per quadrature evaluation point.

    Each thread handles one ``(evaluation point, test dof, trial dof, trial taylor dof)``
    combination and loops over the test taylor dofs, writing one weighted integrand
    value per iteration to ``result``. Nothing is accumulated in this kernel.

    Returns:
        The (not yet compiled) kernel function.
    """
    TEST_TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
    TRIAL_TAYLOR_DOF_COUNT = trial.TAYLOR_DOF_COUNT

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array4d(dtype=float),
    ):
        qp_eval_index, test_dof, trial_dof, trial_taylor_dof = wp.tid()

        # Nothing is written for evaluation points without a valid element
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)
        if element_index == NULL_ELEMENT_INDEX:
            return

        qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
        qp_vol = vol * qp_weight

        trial_dof_index = DofIndex(trial_taylor_dof, trial_dof)

        for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
            # Flattened (test taylor dof, trial taylor dof) pair, test-major
            taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof

            test_dof_index = DofIndex(test_taylor_dof, test_dof)

            sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            # The argument list of `integrand_func` is a placeholder, rewritten by
            # `PassFieldArgsToIntegrand` when the kernel is generated
            val = integrand_func(sample, fields, values)
            # Note the output layout: dof indices first, then evaluation point, then taylor dof pair
            result[test_dof, trial_dof, qp_eval_index, taylor_dof] = qp_vol * val

    return integrate_kernel_fn
1073
+
1074
+
1075
def _generate_integrate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    quadrature: Quadrature,
    arguments: IntegrandArguments,
    test: Optional[TestField],
    trial: Optional[TrialField],
    output_dtype: type,
    accumulate_dtype: type,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> Tuple[wp.Kernel, StructInstance, StructInstance]:
    """Get the integration kernel for ``integrand``, generating it if it is not cached yet.

    The kernel flavor is selected from the presence of ``test`` / ``trial`` fields
    (constant, linear or bilinear form), from whether a quadrature is given (nodal
    integration is used when ``quadrature`` is ``None``) and from whether ``test`` is a
    ``LocalTestField``.

    Args:
        integrand: Integrand to generate the kernel for.
        domain: Integration domain.
        quadrature: Quadrature formula, or ``None`` for nodal integration.
        arguments: Parsed integrand arguments.
        test: Test field, if any.
        trial: Trial field, if any.
        output_dtype: Output type; only its scalar type is used.
        accumulate_dtype: Type used for accumulating integrand values.
        kernel_options: Options forwarded to the kernel cache.

    Returns:
        A tuple ``(kernel, field_arg_values, value_struct_values)``: the kernel, and the
        values to use for its ``fields`` and ``values`` arguments. On a cache miss these
        are freshly created instances of the generated structs.
    """
    output_dtype = type_scalar_type(output_dtype)

    _notify_operator_usage(integrand, arguments.field_args)

    # Check if the kernel already exists in the cache
    field_names = tuple((k, f.name) for k, f in arguments.field_args.items())
    kernel_suffix = ("itg", field_names, cache.pod_type_key(output_dtype), cache.pod_type_key(accumulate_dtype))

    if quadrature is not None:
        kernel_suffix = (quadrature.name, *kernel_suffix)

    kernel, field_arg_values, value_struct_values = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
    )
    if kernel is not None:
        return kernel, field_arg_values, value_struct_values

    # Not found in cache, transform integrand and generate kernel
    FieldStruct = _gen_field_struct(arguments.field_args)
    ValueStruct = cache.get_argument_struct(arguments.value_args)

    _check_field_compat(integrand, arguments, domain)

    integrand_func = IntegrandTransformer.apply(integrand, arguments.field_args)

    # Without quadrature, the nodes of the test space are used as integration points
    nodal = quadrature is None

    if test is None and trial is None:
        # Constant form
        integrate_kernel_fn = get_integrate_constant_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            accumulate_dtype=accumulate_dtype,
        )
    elif trial is None:
        # Linear form
        if nodal:
            integrate_kernel_fn = get_integrate_linear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        elif isinstance(test, LocalTestField):
            integrate_kernel_fn = get_integrate_linear_local_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
            )
        else:
            integrate_kernel_fn = get_integrate_linear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
    else:
        # Bilinear form
        if nodal:
            integrate_kernel_fn = get_integrate_bilinear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        elif isinstance(test, LocalTestField):
            integrate_kernel_fn = get_integrate_bilinear_local_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                trial=trial,
            )
        else:
            integrate_kernel_fn = get_integrate_bilinear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                trial=trial,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )

    # Register the kernel; the code transformer rewrites the placeholder `integrand_func(...)`
    # call and the `_get_test_arg()` / `_get_trial_arg()` calls of the kernel source
    kernel, _FieldStruct, _ValueStruct = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=integrate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
            )
        ],
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
    )

    return kernel, FieldStruct(), ValueStruct()
1204
+
1205
+
1206
def _generate_auxiliary_kernels(
    quadrature: Quadrature,
    test: Optional[TestField],
    trial: Optional[TrialField],
    accumulate_dtype: type,
    device,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> List[Tuple[wp.Kernel, int]]:
    """Build the extra kernels required by the "dispatch" assembly strategy.

    Args:
        quadrature: Quadrature whose evaluation point count drives the tile-size heuristic
        test: Test field; auxiliary kernels are only generated when it is a :class:`LocalTestField`
        trial: Trial field, or None for a linear form
        accumulate_dtype: Scalar type forwarded to the dispatch kernel factory
        device: Target device; tiled dispatch is only considered when ``device.is_cuda`` is true
        kernel_options: Options forwarded to the dispatch kernel factory

    Returns:
        An empty sequence when no dispatch pass is needed, otherwise a one-element
        sequence holding the ``(dispatch_kernel, dispatch_tile_size)`` pair.
    """
    if test is None or not isinstance(test, LocalTestField):
        return ()

    # Dispatched assembly needs a second kernel scattering quadrature-point results to nodes.
    # Heuristic: only use tiles for "long" quadratures, otherwise fall back to a tile size of 1.
    tile_size = 32
    qp_eval_count = quadrature.evaluation_point_count()

    if trial is None:
        restriction = test.space_restriction
        use_scalar_dispatch = not device.is_cuda or (
            qp_eval_count * restriction.total_node_element_count()
            < 3 * tile_size * restriction.node_count() * test.domain.element_count()
        )
        if use_scalar_dispatch:
            tile_size = 1
        aux_kernel = make_linear_dispatch_kernel(test, quadrature, accumulate_dtype, tile_size, kernel_options)
    else:
        use_scalar_dispatch = not device.is_cuda or qp_eval_count < 3 * tile_size * test.domain.element_count()
        if use_scalar_dispatch:
            tile_size = 1
        aux_kernel = make_bilinear_dispatch_kernel(
            test, trial, quadrature, accumulate_dtype, tile_size, kernel_options
        )

    return ((aux_kernel, tile_size),)
1240
+
1241
+
1242
def _launch_integrate_kernel(
    integrand: Integrand,
    kernel: wp.Kernel,
    auxiliary_kernels: List[Tuple[wp.Kernel, int]],
    field_arg_values: StructInstance,
    value_struct_values: StructInstance,
    domain: GeometryDomain,
    quadrature: Quadrature,
    test: Optional[TestField],
    trial: Optional[TrialField],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    accumulate_dtype: type,
    temporary_store: Optional[cache.TemporaryStore],
    output_dtype: type,
    output: Optional[Union[wp.array, BsrMatrix]],
    add_to_output: bool,
    bsr_options: Optional[Dict[str, Any]],
    device,
):
    """Fill the kernel arguments and launch the integration kernel(s).

    Three cases are handled, selected by which of ``test`` / ``trial`` are provided:

    - constant form (no test, no trial): accumulates into a size-1 array;
    - linear form (test only): accumulates into a per-node array;
    - bilinear form (test and trial): assembles triplets into a BSR matrix.

    A ``quadrature`` of None selects the nodal variants of the linear and bilinear kernels.

    Args:
        integrand: Integrand being integrated; its name is used in diagnostics
        kernel: Main integration kernel
        auxiliary_kernels: ``(kernel, tile_size)`` pairs; the first entry is used as dispatch
            kernel when ``test`` is a :class:`LocalTestField`
        field_arg_values: Struct instance receiving the evaluation arguments of each field
        value_struct_values: Struct instance receiving the entries of ``values``
        domain: Integration domain
        quadrature: Quadrature formula, or None for nodal integration
        test: Test field, or None
        trial: Trial field, or None
        fields: Field arguments by integrand parameter name (may contain the domain itself)
        values: Value arguments by integrand parameter name
        accumulate_dtype: Scalar type used for accumulation
        temporary_store: Pool from which temporary arrays are borrowed
        output_dtype: Scalar type of the result when ``output`` is not provided
        output: Optional array or matrix receiving the result
        add_to_output: If True, add to ``output`` instead of overwriting it
        bsr_options: Extra keyword arguments for :func:`bsr_set_from_triplets`
        device: Device on which kernels are launched

    Returns:
        For a constant form without ``output``, the first entry of the accumulation array
        converted through ``numpy()``; otherwise the output array or matrix.

    Raises:
        RuntimeError: If the provided ``output`` has an incompatible size, shape or type.
    """
    # Set-up launch arguments
    domain_elt_arg = domain.element_arg_value(device=device)
    domain_elt_index_arg = domain.element_index_arg_value(device=device)

    if quadrature is not None:
        qp_arg = quadrature.arg_value(device=device)

    for k, v in fields.items():
        # The domain may be passed among the fields but has no evaluation argument to fill
        if not isinstance(v, GeometryDomain):
            v.fill_eval_arg(getattr(field_arg_values, k), device=device)

    cache.populate_argument_struct(value_struct_values, values, func_name=integrand.name)

    # Constant form
    if test is None and trial is None:
        if output is not None and output.dtype == accumulate_dtype:
            # Types match: accumulate directly into the user-provided array
            if output.size < 1:
                raise RuntimeError("Output array must be of size at least 1")
            accumulate_array = output
        else:
            accumulate_array = cache.borrow_temporary(
                shape=(1,),
                device=device,
                dtype=accumulate_dtype,
                temporary_store=temporary_store,
                requires_grad=output is not None and output.requires_grad,
            )

        # Keep existing content only when adding in place to the user-provided array
        if output != accumulate_array or not add_to_output:
            accumulate_array.zero_()

        qp_count = quadrature.evaluation_point_count()
        tile_size = _INTEGRATE_CONSTANT_TILE_SIZE
        block_count = (qp_count + tile_size - 1) // tile_size  # ceil division
        wp.launch(
            kernel=kernel,
            dim=(block_count, tile_size),
            block_dim=tile_size,
            inputs=[
                qp_count,
                qp_arg,
                quadrature.element_index_arg_value(device),
                domain_elt_arg,
                domain_elt_index_arg,
                field_arg_values,
                value_struct_values,
                accumulate_array,
            ],
            device=device,
        )

        if output == accumulate_array:
            return output
        if output is None:
            return accumulate_array.numpy()[0]

        if add_to_output:
            # accumulate dtype is distinct from output dtype
            array_axpy(x=accumulate_array, y=output)
        else:
            array_cast(in_array=accumulate_array, out_array=output)
        return output

    test_arg = test.space_restriction.node_arg_value(device=device)
    nodal = quadrature is None

    # Linear form
    if trial is None:
        # If an output array is provided with the correct type, accumulate directly into it
        # Otherwise, grab a temporary array
        if output is None:
            if type_size(output_dtype) == test.node_dof_count:
                output_shape = (test.space_partition.node_count(),)
            elif type_size(output_dtype) == 1:
                output_shape = (test.space_partition.node_count(), test.node_dof_count)
            else:
                raise RuntimeError(
                    f"Incompatible output type {type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
                )

            output = cache.borrow_temporary(
                temporary_store=temporary_store,
                shape=output_shape,
                dtype=output_dtype,
                device=device,
            )

        else:
            if output.shape[0] < test.space_partition.node_count():
                raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")

            output_dtype = output.dtype
            if type_size(output_dtype) != test.node_dof_count:
                if type_size(output_dtype) != 1:
                    raise RuntimeError(
                        f"Incompatible output type {type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
                    )
                # A scalar-typed output must be a 2d array with one column per node dof.
                # Both conditions are checked independently ("or"): a 1d array must be rejected
                # before indexing shape[1], and a 2d array with a wrong column count must not pass.
                if output.ndim != 2 or output.shape[1] != test.node_dof_count:
                    raise RuntimeError(
                        f"Incompatible output array shape, last dimension must be of size {test.node_dof_count}"
                    )

        # Launch the integration on the kernel on a 2d scalar view of the actual array
        if not add_to_output:
            output.zero_()

        def as_2d_array(array):
            # Reinterpret the same memory as a (node_count, node_dof_count) array of scalars,
            # recursively doing the same for the gradient array if there is one
            return wp.array(
                data=None,
                ptr=array.ptr,
                capacity=array.capacity,
                device=array.device,
                shape=(test.space_partition.node_count(), test.node_dof_count),
                dtype=type_scalar_type(output_dtype),
                grad=None if array.grad is None else as_2d_array(array.grad),
            )

        output_view = output if output.ndim == 2 else as_2d_array(output)

        if nodal:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.node_dof_count),
                inputs=[
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    test.space.topology.topo_arg_value(device),
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )
        elif isinstance(test, LocalTestField):
            # Two-pass assembly: evaluate the form at quadrature points, then dispatch to nodes
            local_result = cache.borrow_temporary(
                temporary_store=temporary_store,
                device=device,
                requires_grad=output.requires_grad,
                shape=(quadrature.evaluation_point_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
                dtype=float,
            )

            wp.launch(
                kernel=kernel,
                dim=local_result.shape,
                inputs=[
                    qp_arg,
                    quadrature.element_index_arg_value(device),
                    domain_elt_arg,
                    domain_elt_index_arg,
                    field_arg_values,
                    value_struct_values,
                    local_result,
                ],
                device=device,
            )

            if test.TAYLOR_DOF_COUNT == 0:
                warn(
                    f"Test field is never evaluated in integrand '{integrand.name}', result will be zero",
                    category=UserWarning,
                    stacklevel=2,
                )
            else:
                dispatch_kernel, dispatch_tile_size = auxiliary_kernels[0]
                wp.launch(
                    kernel=dispatch_kernel,
                    dim=(test.space_restriction.node_count(), dispatch_tile_size),
                    block_dim=dispatch_tile_size if dispatch_tile_size > 1 else 256,
                    inputs=[
                        qp_arg,
                        domain_elt_arg,
                        domain_elt_index_arg,
                        test_arg,
                        test.space.space_arg_value(device),
                        local_result,
                        output_view,
                    ],
                    device=device,
                )

            local_result.release()

        else:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.node_dof_count),
                inputs=[
                    qp_arg,
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )

        return output

    # Bilinear form

    if test.node_dof_count == 1 and trial.node_dof_count == 1:
        block_type = output_dtype
    else:
        block_type = cache.cached_mat_type(shape=(test.node_dof_count, trial.node_dof_count), dtype=output_dtype)

    # Number of triplet blocks to allocate
    if nodal:
        nnz = test.space_restriction.node_count()
    else:
        nnz = test.space_restriction.total_node_element_count() * trial.space.topology.MAX_NODES_PER_ELEMENT

    triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_values_temp = cache.borrow_temporary(
        temporary_store,
        shape=(
            nnz,
            test.node_dof_count,
            trial.node_dof_count,
        ),
        dtype=output_dtype,
        device=device,
    )
    triplet_cols = triplet_cols_temp.array
    triplet_rows = triplet_rows_temp.array
    triplet_values = triplet_values_temp.array

    if nodal:
        wp.launch(
            kernel=kernel,
            dim=triplet_values.shape,
            inputs=[
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                test.space.topology.topo_arg_value(device),
                field_arg_values,
                value_struct_values,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )
    elif isinstance(test, LocalTestField):
        # Two-pass assembly: evaluate the form at quadrature points, then dispatch to triplets
        qp_eval_count = quadrature.evaluation_point_count()
        local_result = cache.borrow_temporary(
            temporary_store=temporary_store,
            device=device,
            requires_grad=False,
            shape=(
                test.value_dof_count,
                trial.value_dof_count,
                qp_eval_count,
                test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
            ),
            dtype=float,
        )

        wp.launch(
            kernel=kernel,
            dim=(
                qp_eval_count,
                test.value_dof_count,
                trial.value_dof_count,
                trial.TAYLOR_DOF_COUNT,
            ),
            inputs=[
                qp_arg,
                quadrature.element_index_arg_value(device),
                domain_elt_arg,
                domain_elt_index_arg,
                field_arg_values,
                value_struct_values,
                local_result,
            ],
            device=device,
        )

        if test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT == 0:
            warn(
                f"Test and/or trial fields are never evaluated in integrand '{integrand.name}', result will be zero",
                category=UserWarning,
                stacklevel=2,
            )
            # The dispatch kernel is not run, so mark all triplet rows with -1
            triplet_rows.fill_(-1)
        else:
            dispatch_kernel, dispatch_tile_size = auxiliary_kernels[0]
            trial_partition_arg = trial.space_partition.partition_arg_value(device)
            trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
            wp.launch(
                kernel=dispatch_kernel,
                dim=(
                    test.space_restriction.total_node_element_count(),
                    trial.space.topology.MAX_NODES_PER_ELEMENT,
                    dispatch_tile_size,
                ),
                block_dim=dispatch_tile_size if dispatch_tile_size > 1 else 256,
                inputs=[
                    qp_arg,
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    test.space.space_arg_value(device),
                    trial_partition_arg,
                    trial_topology_arg,
                    trial.space.space_arg_value(device),
                    local_result,
                    triplet_rows,
                    triplet_cols,
                    triplet_values,
                ],
                device=device,
            )

        local_result.release()

    else:
        trial_partition_arg = trial.space_partition.partition_arg_value(device)
        trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
        wp.launch(
            kernel=kernel,
            dim=(
                test.space_restriction.node_count(),
                trial.space.topology.MAX_NODES_PER_ELEMENT,
                test.node_dof_count,
                trial.node_dof_count,
            ),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                trial_partition_arg,
                trial_topology_arg,
                field_arg_values,
                value_struct_values,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    if output is not None:
        if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
            raise RuntimeError(
                f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
            )

    if output is None or add_to_output:
        # Assemble into a fresh matrix; when adding, it is summed into `output` below
        bsr_result = bsr_zeros(
            rows_of_blocks=test.space_partition.node_count(),
            cols_of_blocks=trial.space_partition.node_count(),
            block_type=block_type,
            device=device,
        )
    else:
        bsr_result = output

    bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))

    # Do not wait for garbage collection
    triplet_values_temp.release()
    triplet_rows_temp.release()
    triplet_cols_temp.release()

    if add_to_output:
        output += bsr_result
    else:
        output = bsr_result

    return output
1638
+
1639
+
1640
def _pick_assembly_strategy(
    assembly: Optional[str], operators: Dict[str, Set[Operator]], arguments: IntegrandArguments
):
    """Validate a user-requested assembly strategy, or choose one automatically.

    Args:
        assembly: Requested strategy ("generic", "nodal" or "dispatch"), or None to pick automatically
        operators: Operators used in the integrand, keyed by integrand argument name
        arguments: Parsed integrand arguments, providing the test and trial argument names

    Returns:
        ``assembly`` itself when provided and valid. Otherwise "generic" if the ``at_node``,
        ``node_count`` or ``node_index`` operator is applied to the test or trial argument,
        and "dispatch" if not.

    Raises:
        ValueError: If ``assembly`` is not None and not one of the supported strategies.
    """
    if assembly is not None:
        if assembly not in ("generic", "nodal", "dispatch"):
            raise ValueError(f"Invalid assembly strategy '{assembly}'")
        return assembly

    test_operators = operators.get(arguments.test_name, set())
    trial_operators = operators.get(arguments.trial_name, set())

    # Non-empty when any of these node operators is applied to the test or trial argument
    uses_virtual_node_operator = {operator.at_node, operator.node_count, operator.node_index} & (
        test_operators | trial_operators
    )

    return "generic" if uses_virtual_node_operator else "dispatch"
1656
+
1657
+
1658
def integrate(
    integrand: Integrand,
    domain: Optional[GeometryDomain] = None,
    quadrature: Optional[Quadrature] = None,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    accumulate_dtype: type = wp.float64,
    output_dtype: Optional[type] = None,
    output: Optional[Union[BsrMatrix, wp.array]] = None,
    device=None,
    temporary_store: Optional[cache.TemporaryStore] = None,
    kernel_options: Optional[Dict[str, Any]] = None,
    assembly: Optional[str] = None,
    add: bool = False,
    bsr_options: Optional[Dict[str, Any]] = None,
):
    """
    Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.

    Args:
        integrand: Form to be integrated, must have :func:`integrand` decorator
        domain: Integration domain. If None, deduced from the quadrature, or else from the test field.
        quadrature: Quadrature formula. If None, deduced from domain and fields degree.
        fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
        temporary_store: Shared pool from which to allocate temporary arrays
        accumulate_dtype: Scalar type to be used for accumulating integration samples
        output: Sparse matrix or warp array into which to store the result of the integration
        output_dtype: Scalar type for returned results if `output` is not provided. If None, defaults to `accumulate_dtype`
        device: Device on which to perform the integration
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
        assembly: Specifies the strategy for assembling the integrated vector or matrix:
            - "nodal": For linear or bilinear forms, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
            - "generic": Single-pass integration and shape-function evaluation. Makes no assumption about the integrand's content, but may lead to many redundant computations.
            - "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` and `node_index` operators on test or trial functions.
            - `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
        add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
        bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
    """
    # Substitute defaults for optional arguments
    fields = {} if fields is None else fields
    values = {} if values is None else values
    device = wp.get_device() if device is None else device

    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")

    arguments = _parse_integrand_arguments(integrand, fields)
    field_args = arguments.field_args

    # Identify the test and trial fields, if any
    test = field_args[arguments.test_name] if arguments.test_name else None

    trial = None
    if arguments.trial_name:
        if test is None:
            raise ValueError("A trial field cannot be provided without a test field")
        trial = field_args[arguments.trial_name]
        if test.domain != trial.domain:
            raise ValueError("Incompatible test and trial domains")

    # Deduce the integration domain when not explicitly provided
    if domain is None:
        if quadrature is not None:
            domain = quadrature.domain
        elif test is not None:
            domain = test.domain

        if domain is None:
            raise ValueError("Must provide at least one of domain, quadrature, or test field")

    if test is not None and domain != test.domain:
        raise NotImplementedError("Mixing integration and test domain is not supported yet")

    if add and output is None:
        raise ValueError("An 'output' array or matrix needs to be provided for add=True")

    # The domain itself may be requested as an integrand argument
    if arguments.domain_name is not None:
        field_args[arguments.domain_name] = domain

    _find_integrand_operators(integrand, field_args)

    domain_operators = integrand.operators.get(arguments.domain_name, [])
    if operator.lookup in domain_operators and not domain.supports_lookup(device):
        warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")

    assembly = _pick_assembly_strategy(assembly, integrand.operators, arguments)

    if assembly == "nodal":
        if quadrature is not None:
            raise ValueError("Cannot specify quadrature for nodal integration")

        if test is None:
            raise ValueError("Nodal integration requires specifying a test function")

        if trial is not None and test.space_partition != trial.space_partition:
            raise ValueError(
                "Bilinear nodal integration requires test and trial to be defined on the same function space"
            )
    else:
        if assembly == "dispatch":
            # Dispatched assembly operates on local wrappers of the test and trial fields
            if test is not None:
                test = LocalTestField(test)
                field_args[arguments.test_name] = test
            if trial is not None:
                trial = LocalTrialField(trial)
                field_args[arguments.trial_name] = trial

        if quadrature is None:
            # Default quadrature order is the sum of the degrees of the user-provided fields
            order = sum(field.degree for field in fields.values())
            quadrature = RegularQuadrature(domain=domain, order=order)
        elif domain != quadrature.domain:
            raise ValueError("Incompatible integration and quadrature domain")

    # Canonicalize types
    accumulate_dtype = type_to_warp(accumulate_dtype)
    if output is None:
        output_dtype = accumulate_dtype if output_dtype is None else type_to_warp(output_dtype)
    elif isinstance(output, BsrMatrix):
        output_dtype = output.scalar_type
    else:
        output_dtype = output.dtype

    kernel, field_arg_values, value_struct_values = _generate_integrate_kernel(
        integrand=integrand,
        domain=domain,
        quadrature=quadrature,
        arguments=arguments,
        test=test,
        trial=trial,
        accumulate_dtype=accumulate_dtype,
        output_dtype=output_dtype,
        kernel_options=kernel_options,
    )

    auxiliary_kernels = _generate_auxiliary_kernels(
        quadrature=quadrature,
        test=test,
        trial=trial,
        accumulate_dtype=accumulate_dtype,
        device=device,
        kernel_options=kernel_options,
    )

    return _launch_integrate_kernel(
        integrand=integrand,
        kernel=kernel,
        auxiliary_kernels=auxiliary_kernels,
        field_arg_values=field_arg_values,
        value_struct_values=value_struct_values,
        domain=domain,
        quadrature=quadrature,
        test=test,
        trial=trial,
        fields=field_args,
        values=values,
        accumulate_dtype=accumulate_dtype,
        temporary_store=temporary_store,
        output_dtype=output_dtype,
        output=output,
        add_to_output=add,
        bsr_options=bsr_options,
        device=device,
    )
1827
+
1828
+
1829
def get_interpolate_to_field_function(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: Struct,
    ValueStruct: Struct,
    dest: FieldRestriction,
):
    """Build the function evaluating an integrand at one node of the destination field.

    The returned function loops over the elements associated with a node of
    ``dest.space_restriction``, evaluates ``integrand_func`` at the node's coordinates in each
    element where those coordinates are not ``OUTSIDE``, and returns the pair
    ``(val_sum, vol_sum)``: the sum of the values weighted by ``domain.element_measure``,
    and the sum of those measures.

    Args:
        integrand_func: Function called as ``integrand_func(sample, fields, values)``
        domain: Domain providing element indices and measures
        FieldStruct: Type of the ``fields`` argument of the returned function
        ValueStruct: Type of the ``values`` argument of the returned function
        dest: Field restriction whose space and space restriction define the nodes

    Returns:
        The undecorated Python function ``interpolate_to_field_fn``.
    """
    zero_value = type_zero_element(dest.space.dtype)

    def interpolate_to_field_fn(
        partition_node_index: int,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)

        # No test or trial function is involved when interpolating
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX
        node_weight = 1.0

        # Volume-weighted average across elements
        # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces

        val_sum = zero_value()
        vol_sum = float(0.0)

        for n in range(element_beg, element_end):
            node_element_index = dest.space_restriction.node_element_index(dest_node_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # The node index is computed from the first element only and reused for all samples
            if n == element_beg:
                node_index = dest.space.topology.element_node_index(
                    domain_arg, dest_eval_arg.topology_arg, element_index, node_element_index.node_index_in_element
                )

            coords = dest.space.node_coords_in_element(
                domain_arg,
                dest_eval_arg.space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )

            # Skip elements for which the node coordinates are flagged as OUTSIDE
            if coords[0] != OUTSIDE:
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                vol_sum += vol
                val_sum += vol * val

        return val_sum, vol_sum

    return interpolate_to_field_fn
1893
+
1894
+
1895
def get_interpolate_to_field_kernel(
    interpolate_to_field_fn: wp.Function,
    domain: GeometryDomain,
    FieldStruct: Struct,
    ValueStruct: Struct,
    dest: FieldRestriction,
):
    """Build the kernel function writing interpolated values to the nodes of ``dest``.

    One thread is run per node of the restriction. Each thread calls
    ``interpolate_to_field_fn`` to get a weighted value sum and a weight sum, and, when
    the weight sum is positive, stores their ratio through ``dest.field.set_node_value``.

    Args:
        interpolate_to_field_fn: Function returning ``(val_sum, vol_sum)`` for a node
        domain: Domain providing element indices
        FieldStruct: Type of the ``fields`` kernel argument
        ValueStruct: Type of the ``values`` kernel argument
        dest: Field restriction to interpolate to

    Returns:
        The undecorated Python function ``interpolate_to_field_kernel_fn``.
    """

    @wp.func
    def _find_node_in_element(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        partition_node_index: int,
    ):
        # Return the first (element, node-in-element) pair whose node coordinates are not OUTSIDE
        element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)

        for n in range(element_beg, element_end):
            node_element_index = dest.space_restriction.node_element_index(dest_node_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
            coords = dest.space.node_coords_in_element(
                domain_arg,
                dest_eval_arg.space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )
            if coords[0] != OUTSIDE:
                return element_index, node_element_index.node_index_in_element

        # No suitable element found
        return NULL_ELEMENT_INDEX, NULL_NODE_INDEX

    def interpolate_to_field_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        local_node_index = wp.tid()

        partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
        if partition_node_index == NULL_NODE_INDEX:
            return

        # NOTE(review): `interpolate_to_field_fn` is called with `local_node_index`, while
        # `_find_node_in_element` below is called with `partition_node_index`; the latter (and
        # presumably the former) passes that index to `node_element_range` -- confirm which
        # index `node_element_range` expects.
        val_sum, vol_sum = interpolate_to_field_fn(
            local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
        )

        if vol_sum > 0.0:
            # Grab first element containing node; there must be at least one since vol_sum != 0
            element_index, node_index_in_element = _find_node_in_element(
                domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, partition_node_index
            )
            # Store the weighted average
            dest.field.set_node_value(
                domain_arg,
                dest_eval_arg,
                element_index,
                node_index_in_element,
                partition_node_index,
                val_sum / vol_sum,
            )

    return interpolate_to_field_kernel_fn
1959
+
1960
+
1961
def get_interpolate_at_quadrature_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: Struct,
    ValueStruct: Struct,
    value_type: type,
):
    """Build the kernel function evaluating an integrand at each quadrature evaluation point.

    Args:
        integrand_func: Function called as ``integrand_func(sample, fields, values)``
        domain: Domain used to map domain element indices to element indices
        quadrature: Quadrature providing point coordinates, weights and indices
        FieldStruct: Type of the ``fields`` kernel argument
        ValueStruct: Type of the ``values`` kernel argument
        value_type: Element type of the ``result`` array, or None if the integrand value
            is to be discarded

    Returns:
        The "non-valued" kernel function (which calls the integrand but never writes to
        ``result``) when ``value_type`` is None, otherwise the kernel function storing the
        integrand value at ``result[qp_index]``.
    """

    def interpolate_at_quadrature_nonvalued_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=float),
    ):
        qp_eval_index = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        # Nothing to do for evaluation points without a valid element
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)
        if element_index == NULL_ELEMENT_INDEX:
            return

        # No test or trial function is involved when interpolating
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        # Return value, if any, is ignored; `result` is left untouched
        integrand_func(sample, fields, values)

    def interpolate_at_quadrature_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=value_type),
    ):
        qp_eval_index = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        # Nothing to do for evaluation points without a valid element
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)
        if element_index == NULL_ELEMENT_INDEX:
            return

        # No test or trial function is involved when interpolating
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        # Results are stored by quadrature point index, not by evaluation (thread) index
        result[qp_index] = integrand_func(sample, fields, values)

    return interpolate_at_quadrature_nonvalued_kernel_fn if value_type is None else interpolate_at_quadrature_kernel_fn
2026
+
2027
+
2028
def get_interpolate_jacobian_at_quadrature_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: Struct,
    ValueStruct: Struct,
    trial: TrialField,
    value_size: int,
    value_type: type,
):
    """Build the kernel function evaluating an integrand per trial dof at quadrature points.

    The kernel is indexed by ``(evaluation point, trial node in element, trial dof)``. For each
    thread it evaluates ``integrand_func`` with the corresponding trial dof index and stores
    the ``value_size`` coefficients of the result in the triplet block at offset
    ``qp_index * MAX_NODES_PER_ELEMENT + trial_node``. The block's row is the quadrature point
    index and its column the partition index of the trial node.

    Args:
        integrand_func: Function called as ``integrand_func(sample, fields, values)``
        domain: Domain used to map domain element indices to element indices
        quadrature: Quadrature providing point counts, coordinates, weights and indices
        FieldStruct: Type of the ``fields`` kernel argument
        ValueStruct: Type of the ``values`` kernel argument
        trial: Trial field whose space topology and partition define the columns
        value_size: Number of coefficients extracted from each integrand value
        value_type: Element type of the ``triplet_values`` array

    Returns:
        The undecorated Python function ``interpolate_jacobian_kernel_fn``.
    """
    MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT
    VALUE_SIZE = wp.constant(value_size)

    def interpolate_jacobian_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=value_type),
    ):
        qp_eval_index, trial_node, trial_dof = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        # Nothing to do for evaluation points without a valid element
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)
        if element_index == NULL_ELEMENT_INDEX:
            return

        # Skip evaluation points beyond the number of quadrature points of this element
        if qp >= quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index):
            return

        element_trial_node_count = trial.space.topology.element_node_count(
            domain_arg, trial_topology_arg, element_index
        )

        qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        # One triplet block per (quadrature point, trial node in element) pair
        block_offset = qp_index * MAX_NODES_PER_ELEMENT + trial_node

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = DofIndex(trial_node, trial_dof)

        sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        val = integrand_func(sample, fields, values)

        for k in range(VALUE_SIZE):
            triplet_values[block_offset, k, trial_dof] = basis_coefficient(val, k)

        # Row and column indices are shared by all dofs of the node; write them once
        if trial_dof == 0:
            if trial_node < element_trial_node_count:
                trial_node_index = trial.space_partition.partition_node_index(
                    trial_partition_arg,
                    trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
                )
            else:
                trial_node_index = NULL_NODE_INDEX  # will get ignored when converting to bsr
            triplet_rows[block_offset] = qp_index
            triplet_cols[block_offset] = trial_node_index

    return interpolate_jacobian_kernel_fn
2097
+
2098
+
2099
def get_interpolate_free_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: Struct,
    ValueStruct: Struct,
    value_type: type,
):
    """Build the kernel function evaluating an integrand at samples that are not tied to any element.

    Each thread builds a ``Sample`` with a null element index, ``Coords(OUTSIDE)`` coordinates,
    null test/trial dof indices and a weight of ``1 / dim``, then calls
    ``integrand_func(sample, fields, values)``.

    Args:
        integrand_func: Function called as ``integrand_func(sample, fields, values)``
        domain: Domain providing the types of the ``domain_arg`` and ``domain_index_arg`` kernel arguments
        FieldStruct: Type of the ``fields`` kernel argument
        ValueStruct: Type of the ``values`` kernel argument
        value_type: Data type of the ``result`` array, or ``None`` to discard the integrand's return value

    Returns:
        The non-valued kernel function if ``value_type`` is ``None``, otherwise the valued one
    """

    # Variant used when there is no result type: the integrand is called but its result is not stored
    def interpolate_free_nonvalued_kernel_fn(
        dim: int,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=float),
    ):
        qp_index = wp.tid()
        qp_weight = 1.0 / float(dim)
        element_index = NULL_ELEMENT_INDEX
        coords = Coords(OUTSIDE)

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        integrand_func(sample, fields, values)

    # Variant storing the integrand's return value at the sample's index in `result`
    def interpolate_free_kernel_fn(
        dim: int,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=value_type),
    ):
        qp_index = wp.tid()
        qp_weight = 1.0 / float(dim)
        element_index = NULL_ELEMENT_INDEX
        coords = Coords(OUTSIDE)

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)

        result[qp_index] = integrand_func(sample, fields, values)

    return interpolate_free_nonvalued_kernel_fn if value_type is None else interpolate_free_kernel_fn
2146
+
2147
+
2148
def _generate_interpolate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    dest: Optional[Union[FieldLike, wp.array]],
    quadrature: Optional[Quadrature],
    arguments: IntegrandArguments,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> "tuple[wp.Kernel, StructInstance, StructInstance]":
    """Fetch from the cache, or generate, the kernel used to interpolate ``integrand``.

    The kernel variant depends on the destination and sampling:

    - ``dest`` is a :class:`FieldRestriction`: interpolation to the nodes of the restricted field
    - ``quadrature`` is provided and the integrand has a trial field: per-quadrature-point jacobian triplets
    - ``quadrature`` is provided without trial field: one value per quadrature point
    - otherwise: free samples that are not tied to any element

    Args:
        integrand: Integrand to interpolate
        domain: Interpolation domain
        dest: Interpolation destination, or ``None``
        quadrature: Quadrature defining the samples, or ``None``
        arguments: Parsed integrand arguments
        kernel_options: Options forwarded to the kernel cache

    Returns:
        Tuple ``(kernel, field_arg_values, value_struct_values)``
    """
    _notify_operator_usage(integrand, arguments.field_args)

    # Check if kernel exist in cache
    field_names = tuple((k, f.name) for k, f in arguments.field_args.items())
    if isinstance(dest, FieldRestriction):
        kernel_suffix = ("itp", *field_names, dest.domain.name, dest.space_restriction.space_partition.name)
    else:
        # Compare against None explicitly: relying on truthiness would treat an
        # empty destination array as "no destination" if the array type defines __len__
        dest_dtype = dest.dtype if dest is not None else None
        type_str = cache.pod_type_key(dest_dtype) if dest_dtype else ""
        if quadrature is None:
            kernel_suffix = ("itp", *field_names, domain.name, type_str)
        else:
            kernel_suffix = ("itp", *field_names, domain.name, quadrature.name, type_str)

    kernel, field_arg_values, value_struct_values = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
    )
    if kernel is not None:
        return kernel, field_arg_values, value_struct_values

    # Generate field struct
    FieldStruct = _gen_field_struct(arguments.field_args)
    ValueStruct = cache.get_argument_struct(arguments.value_args)

    # Not found in cache, transform integrand and generate kernel
    _check_field_compat(integrand, arguments, domain)

    integrand_func = IntegrandTransformer.apply(integrand, arguments.field_args)

    # Generate interpolation kernel
    if isinstance(dest, FieldRestriction):
        # need to split into kernel + function for differentiability
        interpolate_fn = get_interpolate_to_field_function(
            integrand_func,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

        interpolate_fn = cache.get_integrand_function(
            integrand=integrand,
            func=interpolate_fn,
            suffix=kernel_suffix,
            code_transformers=[
                PassFieldArgsToIntegrand(
                    arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
                )
            ],
        )

        interpolate_kernel_fn = get_interpolate_to_field_kernel(
            interpolate_fn,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    elif quadrature is not None:
        if arguments.trial_name:
            # Jacobian with respect to the trial field; block layout is read from `dest`
            trial = arguments.field_args[arguments.trial_name]
            interpolate_kernel_fn = get_interpolate_jacobian_at_quadrature_kernel(
                integrand_func,
                domain=domain,
                quadrature=quadrature,
                FieldStruct=FieldStruct,
                ValueStruct=ValueStruct,
                trial=trial,
                value_size=dest.block_shape[0],
                value_type=dest.scalar_type,
            )
        else:
            interpolate_kernel_fn = get_interpolate_at_quadrature_kernel(
                integrand_func,
                domain=domain,
                quadrature=quadrature,
                value_type=dest_dtype,
                FieldStruct=FieldStruct,
                ValueStruct=ValueStruct,
            )
    else:
        interpolate_kernel_fn = get_interpolate_free_kernel(
            integrand_func,
            domain=domain,
            value_type=dest_dtype,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

    kernel, _FieldStruct, _ValueStruct = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=interpolate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
            )
        ],
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
    )

    return kernel, FieldStruct(), ValueStruct()
2262
+
2263
+
2264
def _launch_interpolate_kernel(
    integrand: Integrand,
    kernel: wp.Kernel,
    field_arg_values: StructInstance,
    value_struct_values: StructInstance,
    domain: GeometryDomain,
    dest: Optional[Union[FieldRestriction, wp.array]],
    quadrature: Optional[Quadrature],
    dim: Optional[int],
    trial: Optional[TrialField],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    temporary_store: Optional[cache.TemporaryStore],
    bsr_options: Optional[Dict[str, Any]],
    device,
) -> None:
    """Populate the launch arguments and run an interpolation kernel.

    The launch configuration mirrors the kernel variants:

    - ``dest`` is a :class:`FieldRestriction`: one thread per node of the restriction
    - ``quadrature`` is ``None``: ``dim`` free samples, with results written to ``dest`` if provided
    - ``quadrature`` without ``trial``: one thread per evaluation point, with results written to ``dest`` if provided
    - ``quadrature`` with ``trial``: triplets are assembled into temporaries, then set on the ``dest`` matrix

    Args:
        integrand: Integrand being interpolated (its name is used in messages)
        kernel: Kernel to launch
        field_arg_values: Field argument struct instance, filled in by this function
        value_struct_values: Value argument struct instance, filled in by this function
        domain: Interpolation domain
        dest: Interpolation destination
        quadrature: Quadrature defining the samples, or ``None``
        dim: Number of samples when ``quadrature`` is ``None``
        trial: Trial field, or ``None``
        fields: Fields passed to the integrand, keyed by argument name
        values: Values passed to the integrand, keyed by argument name
        temporary_store: Pool from which the triplet temporaries are borrowed
        bsr_options: Extra keyword arguments for ``bsr_set_from_triplets``
        device: Device on which to launch the kernel

    Raises:
        ValueError: if ``dest`` is not an array with the expected number of rows (array destinations)
        RuntimeError: if the ``dest`` matrix has an unexpected shape (trial field case)
    """
    # Set-up launch arguments
    elt_arg = domain.element_arg_value(device=device)
    elt_index_arg = domain.element_index_arg_value(device=device)

    for k, v in fields.items():
        if not isinstance(v, GeometryDomain):
            v.fill_eval_arg(getattr(field_arg_values, k), device=device)
    cache.populate_argument_struct(value_struct_values, values, func_name=integrand.name)

    if isinstance(dest, FieldRestriction):
        dest_node_arg = dest.space_restriction.node_arg_value(device=device)
        dest_eval_arg = dest.field.eval_arg_value(device=device)

        wp.launch(
            kernel=kernel,
            dim=dest.space_restriction.node_count(),
            inputs=[
                elt_arg,
                elt_index_arg,
                dest_node_arg,
                dest_eval_arg,
                field_arg_values,
                value_struct_values,
            ],
            device=device,
        )
        return

    if quadrature is None:
        if dest is not None and (not is_array(dest) or dest.shape[0] != dim):
            raise ValueError(f"dest must be a warp array with {dim} rows")

        wp.launch(
            kernel=kernel,
            dim=dim,
            inputs=[dim, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
            device=device,
        )
        return

    qp_arg = quadrature.arg_value(device)
    qp_eval_count = quadrature.evaluation_point_count()
    qp_index_count = quadrature.total_point_count()

    # Threads are launched per evaluation point but results are stored per point index
    if qp_eval_count != qp_index_count:
        warn(
            f"Quadrature used for interpolation of {integrand.name} has different number of evaluation and indexed points, this may lead to incorrect results",
            category=UserWarning,
            stacklevel=2,
        )

    qp_element_index_arg = quadrature.element_index_arg_value(device)
    if trial is None:
        if dest is not None and (not is_array(dest) or dest.shape[0] != qp_index_count):
            raise ValueError(f"dest must be a warp array with {qp_index_count} rows")

        wp.launch(
            kernel=kernel,
            dim=qp_eval_count,
            inputs=[qp_arg, qp_element_index_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
            device=device,
        )
        return

    # Trial field case: one triplet block per (evaluation point, element node) pair
    max_nodes_per_element = trial.space.topology.MAX_NODES_PER_ELEMENT
    nnz = qp_eval_count * max_nodes_per_element

    if dest.nrow != qp_index_count or dest.ncol != trial.space_partition.node_count():
        raise RuntimeError(
            f"'dest' matrix must have {qp_index_count} rows and {trial.space_partition.node_count()} columns of blocks"
        )
    if dest.block_shape[1] != trial.node_dof_count:
        raise RuntimeError(f"'dest' matrix blocks must have {trial.node_dof_count} columns")

    triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_values_temp = cache.borrow_temporary(
        temporary_store,
        dtype=dest.scalar_type,
        shape=(nnz, *dest.block_shape),
        device=device,
    )
    triplet_cols = triplet_cols_temp.array
    triplet_rows = triplet_rows_temp.array
    triplet_values = triplet_values_temp.array
    # Rows left untouched by the kernel (threads that exit early) keep this -1 marker
    triplet_rows.fill_(-1)

    trial_partition_arg = trial.space_partition.partition_arg_value(device)
    trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)

    wp.launch(
        kernel=kernel,
        dim=(qp_eval_count, max_nodes_per_element, trial.node_dof_count),
        inputs=[
            qp_arg,
            qp_element_index_arg,
            elt_arg,
            elt_index_arg,
            trial_partition_arg,
            trial_topology_arg,
            field_arg_values,
            value_struct_values,
            triplet_rows,
            triplet_cols,
            triplet_values,
        ],
        device=device,
    )

    bsr_set_from_triplets(dest, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
2389
+
2390
+
2391
# Integrand returning the value of `field` at the sample `s`.
# `interpolate` substitutes it when given a field instead of an integrand, binding the
# field under the key "field" -- the parameter name below must therefore stay `field`.
@integrand
def _identity_field(field: Field, s: Sample):
    return field(s)
2394
+
2395
+
2396
def interpolate(
    integrand: Union[Integrand, FieldLike],
    dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
    quadrature: Optional[Quadrature] = None,
    dim: Optional[int] = None,
    domain: Optional[Domain] = None,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    device=None,
    kernel_options: Optional[Dict[str, Any]] = None,
    temporary_store: Optional[cache.TemporaryStore] = None,
    bsr_options: Optional[Dict[str, Any]] = None,
):
    """
    Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.

    Args:
        integrand: Function to be interpolated: either a function with :func:`warp.fem.integrand` decorator or a field.
          If a field is passed, the `fields` and `values` arguments are ignored.
        dest: Where to store the interpolation result. Can be either

            - a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
            - a normal warp ``array``, or ``None``. In this case, the interpolation samples will be determined by the `quadrature` or `dim` arguments, in that order.
            - a :class:`warp.sparse.BsrMatrix`. This is required if the integrand makes use of a trial field.
        quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
        dim: Number of interpolation samples if `dest` is not a discrete field or restriction and `quadrature` is ``None``.
          In this case, the ``Sample`` passed to the `integrand` will be invalid, but the sample point index ``s.qp_index`` can be used to define custom interpolation logic.
        domain: Interpolation domain, only used if `dest` is not a field restriction and `quadrature` is ``None``. Must be provided when `dim` is used.
        fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameters names.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
        device: Device on which to perform the interpolation
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
        temporary_store: shared pool from which to allocate temporary arrays
        bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`

    Raises:
        ValueError: if `integrand` is neither a field nor an integrand, if it uses a test field,
          if it uses a trial field without a ``BsrMatrix`` destination, or if the interpolation domain cannot be determined.
    """

    # A field passed as the integrand is interpolated through the identity integrand
    if isinstance(integrand, FieldLike):
        fields = {"field": integrand}
        values = {}
        integrand = _identity_field

    if fields is None:
        fields = {}

    if values is None:
        values = {}

    if device is None:
        device = wp.get_device()

    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @integrand decorator")

    arguments = _parse_integrand_arguments(integrand, fields)
    if arguments.test_name:
        raise ValueError(f"Test field '{arguments.test_name}' may not be used for interpolation")
    if arguments.trial_name and not isinstance(dest, BsrMatrix):
        raise ValueError(
            f"Interpolation using trial field '{arguments.trial_name}' requires 'dest' to be a `warp.sparse.BsrMatrix`"
        )

    trial = arguments.field_args.get(arguments.trial_name, None)

    if isinstance(dest, DiscreteField):
        dest = make_restriction(dest, domain=domain)

    # Resolve the interpolation domain: field restriction first, then quadrature, then explicit arguments
    if isinstance(dest, FieldRestriction):
        domain = dest.domain
    elif quadrature is not None:
        domain = quadrature.domain
    elif dim is None:
        if trial is not None:
            domain = trial.domain
        elif domain is None:
            raise ValueError(
                "Unable to determine interpolation domain, provide an explicit field restriction or quadrature"
            )

        # Default to one sample per domain element
        quadrature = RegularQuadrature(domain, order=0)
    elif domain is None:
        # Kernel generation and launch both need a domain; fail early with a clear
        # message instead of an attribute error on `None` further down
        raise ValueError("Unable to determine interpolation domain, provide an explicit `domain` when using `dim`")

    if arguments.domain_name:
        arguments.field_args[arguments.domain_name] = domain

    _find_integrand_operators(integrand, arguments.field_args)

    if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
        warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")

    kernel, field_struct, value_struct = _generate_interpolate_kernel(
        integrand=integrand,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        arguments=arguments,
        kernel_options=kernel_options,
    )

    return _launch_interpolate_kernel(
        integrand=integrand,
        kernel=kernel,
        field_arg_values=field_struct,
        value_struct_values=value_struct,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        dim=dim,
        trial=trial,
        fields=arguments.field_args,
        values=values,
        temporary_store=temporary_store,
        bsr_options=bsr_options,
        device=device,
    )