warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py ADDED
@@ -0,0 +1,1616 @@
1
+ from typing import List, Dict, Set, Optional, Any, Union
2
+
3
+ import warp as wp
4
+
5
+ import re
6
+ import ast
7
+
8
+ from warp.sparse import BsrMatrix, bsr_zeros, bsr_set_from_triplets, bsr_copy, bsr_assign
9
+ from warp.types import type_length
10
+ from warp.utils import array_cast
11
+ from warp.codegen import get_annotations
12
+
13
+ from warp.fem.domain import GeometryDomain
14
+ from warp.fem.field import (
15
+ TestField,
16
+ TrialField,
17
+ FieldLike,
18
+ DiscreteField,
19
+ FieldRestriction,
20
+ make_restriction,
21
+ )
22
+ from warp.fem.quadrature import Quadrature, RegularQuadrature
23
+ from warp.fem.operator import Operator, Integrand
24
+ from warp.fem import cache
25
+ from warp.fem.types import Domain, Field, Sample, DofIndex, NULL_DOF_INDEX, OUTSIDE, make_free_sample
26
+
27
+
28
def _resolve_path(func, node):
    """
    Resolves the object and dotted path designated by an ast name/attribute node
    (adapted from warp.codegen).

    Args:
        func: Function whose globals and closure variables are used for the lookup.
        node: AST node to resolve; only chains of ``ast.Attribute`` ending on an
            ``ast.Name`` yield a complete path.

    Returns:
        A ``(obj, path)`` tuple, where ``path`` is the list of identifiers in source
        order and ``obj`` is the resolved object, or ``None`` if the path is empty or
        cannot be evaluated.
    """

    # The ast nests attribute accesses outermost-first, so identifiers are collected backward
    backward_parts = []
    current = node
    while isinstance(current, ast.Attribute):
        backward_parts.append(current.attr)
        current = current.value

    if isinstance(current, ast.Name):
        backward_parts.append(current.id)

    path = backward_parts[::-1]

    if not path:
        return None, path

    try:
        # Variables captured by the closure take precedence over the function globals,
        # so that integrands defined inside a function can refer to its local variables
        cells = func.__closure__ or []
        captured = zip(func.__code__.co_freevars, [cell.cell_contents for cell in cells])

        scope = dict(func.__globals__)
        scope.update(captured)

        # NOTE: the evaluated expression is built only from identifiers found in the ast
        return eval(".".join(path), scope), path
    except (NameError, AttributeError):
        return None, path
67
+
68
+
69
def _path_to_ast_attribute(name: str) -> ast.Attribute:
    """
    Builds the ast node loading the dotted path `name` (e.g. ``"a.b.c"``).

    The first identifier becomes an ``ast.Name`` and each following one wraps the
    previous node in an ``ast.Attribute``. A name without any dot therefore yields a
    bare ``ast.Name`` node.
    """
    root, *attributes = name.split(".")

    node = ast.Name(id=root, ctx=ast.Load())
    for attribute in attributes:
        node = ast.Attribute(value=node, attr=attribute, ctx=ast.Load())
    return node
81
+
82
+
83
class IntegrandTransformer(ast.NodeTransformer):
    """
    Rewrites the calls found in an integrand so that they target implementations
    specialized for the fields bound to the integrand arguments.

    Three kinds of calls are rewritten:
      - ``f(x, ...)`` where ``f`` is a field argument: turned into a call to the
        ``call_operator`` of the argument's annotated type, with ``f`` prepended to the arguments
      - ``op(f, x, ...)`` where ``op`` resolves to an `Operator` and ``f`` is a field argument
      - calls to another `Integrand`, which is itself translated for the call-site fields
    """

    def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike]):
        self._integrand = integrand
        self._field_args = field_args

    def visit_Call(self, call: ast.Call):
        # Nested calls are transformed first
        call = self.generic_visit(call)

        called_name = getattr(call.func, "id", None)
        if called_name in self._field_args:
            return self._rewrite_field_call(call, called_name)

        resolved, _ = _resolve_path(self._integrand.func, call.func)

        if isinstance(resolved, Operator) and len(call.args) > 0:
            # Operator evaluated as op(field, x, ...)
            first_arg_name = getattr(call.args[0], "id", None)
            if first_arg_name in self._field_args:
                self._replace_call_func(call, resolved, self._field_args[first_arg_name])

        if isinstance(resolved, Integrand):
            # Nested integrand: point to the function specialized for the call-site fields
            specialized_key = self._translate_callee(resolved, call.args)
            call.func = ast.Attribute(value=call.func, attr=specialized_key, ctx=ast.Load())

        return call

    def _rewrite_field_call(self, call: ast.Call, field_name: str):
        # Handles the f(x, ...) shortcut: the call becomes <ArgType>.call_operator.<key>(f, x, ...)
        bound_field = self._field_args[field_name]

        annotated_type = self._integrand.argspec.annotations[field_name]
        call_operator = annotated_type.call_operator

        type_path = f"{annotated_type.__module__}.{annotated_type.__qualname__}"
        call.func = ast.Attribute(
            value=_path_to_ast_attribute(type_path),
            attr="call_operator",
            ctx=ast.Load(),
        )
        call.args = [ast.Name(id=field_name, ctx=ast.Load())] + call.args

        self._replace_call_func(call, call_operator, bound_field)

        return call

    def _replace_call_func(self, call: ast.Call, operator: Operator, field: FieldLike):
        # Resolves the operator for this field, stores the result as an attribute of the
        # operator under its key, and makes the call target that attribute
        try:
            pointer = operator.resolver(field)
            setattr(operator, pointer.key, pointer)
        except AttributeError:
            raise ValueError(f"Operator {operator.func.__name__} is not defined for field {field.name}")
        call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())

    def _translate_callee(self, callee: Integrand, args: List[ast.AST]):
        # Fields passed at the call site, in argument order
        arg_names = (getattr(arg, "id", None) for arg in args)
        remaining = [self._field_args[name] for name in arg_names if name in self._field_args]

        # Reversed so that popping from the end consumes them in call-site order
        remaining.reverse()

        # Bind them, in the same order, to the Field and Domain arguments of the callee
        callee_field_args = {}
        callee_annotations = callee.argspec.annotations
        for arg in callee.argspec.args:
            if callee_annotations[arg] in (Field, Domain):
                callee_field_args[arg] = remaining.pop()

        return _translate_integrand(callee, callee_field_args).key
157
+
158
+
159
def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike]) -> wp.Function:
    """
    Generates the function specializing `integrand` for the given field arguments.

    Arguments annotated as `Field` (resp. `Domain`) get the ``ElementEvalArg``
    (resp. ``ElementArg``) type of the corresponding entry of `field_args`; other
    annotations are kept unchanged. The generated function is also stored as an
    attribute of `integrand`, named after its key.

    Returns:
        The function registered under the generated key in the integrand's module.
    """
    argspec = integrand.argspec

    def specialized_type(arg):
        declared_type = argspec.annotations[arg]
        if declared_type == Field:
            return field_args[arg].ElementEvalArg
        if declared_type == Domain:
            return field_args[arg].ElementArg
        return declared_type

    annotations = {arg: specialized_type(arg) for arg in argspec.args}

    # Field evaluation calls in the integrand body are rewritten for the bound fields
    call_rewriter = IntegrandTransformer(integrand, field_args)

    func = cache.get_integrand_function(
        integrand=integrand,
        suffix="_".join([field.name for field in field_args.values()]),
        annotations=annotations,
        code_transformers=[call_rewriter],
    )

    key = func.key
    setattr(integrand, key, integrand.module.functions[key])

    return getattr(integrand, key)
188
+
189
+
190
def _get_integrand_field_arguments(
    integrand: Integrand,
    fields: Dict[str, FieldLike],
    domain: GeometryDomain = None,
):
    """
    Sorts the arguments of `integrand` according to their annotated type.

    Args:
        integrand: Integrand whose argument specification is inspected.
        fields: Fields to bind to the `Field` arguments, keyed by argument name.
        domain: Domain to bind to the `Domain` argument, if any.

    Returns:
        A ``(field_args, value_args, domain_name, sample_name)`` tuple: the mapping of
        `Field`/`Domain` argument names to the bound objects, the mapping of remaining
        argument names to their annotated types, and the names of the `Domain` and
        `Sample` arguments (``None`` when absent).

    Raises:
        ValueError: if a `Field` argument has no entry in `fields`.
    """
    field_args = {}
    value_args = {}
    domain_name = sample_name = None

    argspec = integrand.argspec
    for arg in argspec.args:
        arg_type = argspec.annotations[arg]

        if arg_type == Field:
            if arg not in fields:
                raise ValueError(f"Missing field for argument '{arg}'")
            field_args[arg] = fields[arg]
            continue

        if arg_type == Domain:
            # The domain is passed along with the fields
            domain_name = arg
            field_args[arg] = domain
            continue

        if arg_type == Sample:
            sample_name = arg
            continue

        value_args[arg] = arg_type

    return field_args, value_args, domain_name, sample_name
218
+
219
+
220
def _get_test_and_trial_fields(
    fields: Dict[str, FieldLike],
):
    """
    Extracts the test and trial fields, if any, from the integrand fields.

    Args:
        fields: Fields passed to the integrand, keyed by argument name.

    Returns:
        A ``(test, test_name, trial, trial_name)`` tuple; the entries for a missing
        test or trial field are ``None``.

    Raises:
        ValueError: if more than one test field or more than one trial field is
            provided, if a trial field is provided without a test field, or if the
            test and trial fields are defined over different domains.
    """
    test = None
    trial = None
    test_name = None
    trial_name = None

    for name, field in fields.items():
        if isinstance(field, TestField):
            if test is not None:
                raise ValueError("Duplicate test field argument")
            test = field
            test_name = name
        elif isinstance(field, TrialField):
            if trial is not None:
                # This branch reports a duplicate *trial* field, not a test field
                raise ValueError("Duplicate trial field argument")
            trial = field
            trial_name = name

    if trial is not None:
        if test is None:
            raise ValueError("A trial field cannot be provided without a test field")

        if test.domain != trial.domain:
            raise ValueError("Incompatible test and trial domains")

    return test, test_name, trial, trial_name
248
+
249
+
250
def _gen_field_struct(field_args: Dict[str, FieldLike]):
    """
    Generates the struct gathering the evaluation arguments of the integrand fields.

    One member is declared per entry of `field_args` that is not a `GeometryDomain`,
    typed with the ``EvalArg`` of the field. The struct is obtained through
    ``cache.get_struct`` with a suffix built from the member names and types.
    """

    class Fields:
        pass

    annotations = get_annotations(Fields)

    for name, arg in field_args.items():
        # Domains are not part of the fields struct
        if not isinstance(arg, GeometryDomain):
            setattr(Fields, name, arg.EvalArg())
            annotations[name] = arg.EvalArg

    try:
        Fields.__annotations__ = annotations
    except AttributeError:
        setattr(Fields.__dict__, "__annotations__", annotations)

    member_tags = [f"{name}_{arg_struct.cls.__qualname__}" for name, arg_struct in annotations.items()]

    return cache.get_struct(Fields, suffix="_".join(member_tags))
270
+
271
+
272
def _gen_value_struct(value_args: Dict[str, type]):
    """
    Generates the struct gathering the non-field ("value") arguments of an integrand.

    Args:
        value_args: Annotated types of the value arguments, keyed by argument name.

    Returns:
        The result of ``cache.get_struct`` for a ``Values`` class declaring one member
        per value argument, with a suffix built from the member names and type names.
    """

    class Values:
        pass

    annotations = get_annotations(Values)

    for name, arg_type in value_args.items():
        setattr(Values, name, None)
        annotations[name] = arg_type

    def arg_type_name(arg_type):
        # Warp structs are named after the class they wrap
        if isinstance(arg_type, wp.codegen.Struct):
            return arg_type_name(arg_type.cls)
        return getattr(arg_type, "__name__", str(arg_type))

    try:
        Values.__annotations__ = annotations
    except AttributeError:
        # NOTE(review): fallback kept from the original code; confirm it is reachable
        setattr(Values.__dict__, "__annotations__", annotations)

    suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])

    return cache.get_struct(Values, suffix=suffix)
300
+
301
+
302
def _get_trial_arg():
    """Placeholder; calls to a function of this name are replaced by `PassFieldArgsToIntegrand`
    with an access to the trial field member of the fields struct."""
304
+
305
+
306
def _get_test_arg():
    """Placeholder; calls to a function of this name are replaced by `PassFieldArgsToIntegrand`
    with an access to the test field member of the fields struct."""
308
+
309
+
310
class _FieldWrappers:
    """Empty attribute container; `_register_integrand_field_wrappers` stores one attribute per field on it."""
312
+
313
+
314
def _register_integrand_field_wrappers(integrand_func: wp.Function, fields: Dict[str, FieldLike]):
    """
    Attaches to `integrand_func` a ``_field_wrappers`` object exposing, for each entry
    of `fields`, the ``ElementEvalArg`` of the field under the entry's name.
    """
    wrappers = _FieldWrappers()
    integrand_func._field_wrappers = wrappers

    for name, field in fields.items():
        setattr(wrappers, name, field.ElementEvalArg)
318
+
319
+
320
class PassFieldArgsToIntegrand(ast.NodeTransformer):
    """
    Rewrites the integrand call of a kernel so that the integrand arguments are read
    from the kernel variables.

    Calls to the function named `func_name` get their arguments replaced, following
    the order of `arg_names`:
      - the domain argument becomes the variable `domain_var_name`
      - the sample argument becomes the variable `sample_var_name`
      - a field argument ``f`` becomes
        ``<func_name>.<field_wrappers_attr>.f(<domain_var_name>, <fields_var_name>.f)``
      - a value argument ``v`` becomes ``<values_var_name>.v``

    Calls to functions named like `_get_test_arg` and `_get_trial_arg` are replaced by
    the test and trial members of the fields variable.
    """

    def __init__(
        self,
        arg_names: List[str],
        field_args: Set[str],
        value_args: Set[str],
        sample_name: str,
        domain_name: str,
        test_name: str = None,
        trial_name: str = None,
        func_name: str = "integrand_func",
        fields_var_name: str = "fields",
        values_var_name: str = "values",
        domain_var_name: str = "domain_arg",
        sample_var_name: str = "sample",
        field_wrappers_attr: str = "_field_wrappers",
    ):
        # Integrand signature
        self._arg_names = arg_names
        self._field_args = field_args
        self._value_args = value_args
        self._domain_name = domain_name
        self._sample_name = sample_name
        self._test_name = test_name
        self._trial_name = trial_name

        # Names used in the kernel being transformed
        self._func_name = func_name
        self._fields_var_name = fields_var_name
        self._values_var_name = values_var_name
        self._domain_var_name = domain_var_name
        self._sample_var_name = sample_var_name
        self._field_wrappers_attr = field_wrappers_attr

    def visit_Call(self, call: ast.Call):
        call = self.generic_visit(call)

        callee = getattr(call.func, "id", None)

        if callee == self._func_name:
            # Replace function arguments with our generated structs
            call.args.clear()
            for arg in self._arg_names:
                call.args.append(self._make_argument(arg))
            return call

        if callee == _get_test_arg.__name__:
            return self._fields_member(self._test_name)

        if callee == _get_trial_arg.__name__:
            return self._fields_member(self._trial_name)

        return call

    def _make_argument(self, arg: str):
        # Builds the expression passed to the integrand for the argument `arg`
        if arg == self._domain_name:
            return ast.Name(id=self._domain_var_name, ctx=ast.Load())

        if arg == self._sample_name:
            return ast.Name(id=self._sample_var_name, ctx=ast.Load())

        if arg in self._field_args:
            # <func>.<wrappers>.<arg>(<domain>, <fields>.<arg>)
            wrappers = ast.Attribute(
                value=ast.Name(id=self._func_name, ctx=ast.Load()),
                attr=self._field_wrappers_attr,
                ctx=ast.Load(),
            )
            return ast.Call(
                func=ast.Attribute(value=wrappers, attr=arg, ctx=ast.Load()),
                args=[
                    ast.Name(id=self._domain_var_name, ctx=ast.Load()),
                    self._fields_member(arg),
                ],
                keywords=[],
            )

        if arg in self._value_args:
            return ast.Attribute(
                value=ast.Name(id=self._values_var_name, ctx=ast.Load()),
                attr=arg,
                ctx=ast.Load(),
            )

        raise RuntimeError(f"Unhandled argument {arg}")

    def _fields_member(self, member: str):
        # Expression reading `member` from the fields variable
        return ast.Attribute(
            value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
            attr=member,
            ctx=ast.Load(),
        )
418
+
419
+
420
def get_integrate_constant_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    accumulate_dtype,
):
    """
    Builds the kernel function integrating `integrand_func` to a single scalar.

    Args:
        integrand_func: Integrand function called at each quadrature point.
        domain: Domain providing the element indices and element measures.
        quadrature: Quadrature providing the points, coordinates and weights per element.
        FieldStruct: Type of the ``fields`` kernel argument.
        ValueStruct: Type of the ``values`` kernel argument.
        accumulate_dtype: Scalar type used for accumulation and for the ``result`` array.

    Returns:
        The (not yet compiled) kernel function.
    """

    # NOTE(review): the names `integrand_func`, `fields`, `values`, `domain_arg` and `sample`
    # match the defaults of PassFieldArgsToIntegrand; presumably that transformer rewrites
    # the `integrand_func(...)` call below -- confirm before renaming anything here
    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=accumulate_dtype),
    ):
        # The thread index selects the domain element to process
        element_index = domain.element_index(domain_index_arg, wp.tid())
        elem_sum = accumulate_dtype(0.0)

        # No test nor trial function: both dof indices of the sample are null
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        # Accumulate the weighted integrand values over the quadrature points of the element
        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            vol = domain.element_measure(domain_arg, sample)

            val = integrand_func(sample, fields, values)

            elem_sum += accumulate_dtype(qp_weight * vol * val)

        # All threads add their element contribution to the first entry of the result array
        wp.atomic_add(result, 0, elem_sum)

    return integrate_kernel_fn
458
+
459
+
460
def get_integrate_linear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """
    Builds the kernel function integrating `integrand_func` against a test field,
    writing one value per (node, test dof) pair.

    Args:
        integrand_func: Integrand function called at each quadrature point.
        domain: Domain providing the element indices and element measures.
        quadrature: Quadrature providing the points, coordinates and weights per element.
        FieldStruct: Type of the ``fields`` kernel argument.
        ValueStruct: Type of the ``values`` kernel argument.
        test: Test field; its space restriction provides the nodes and their elements.
        output_dtype: Scalar type of the ``result`` array.
        accumulate_dtype: Scalar type used for accumulation.

    Returns:
        The (not yet compiled) kernel function.
    """

    # NOTE(review): the names `integrand_func`, `fields`, `values` and `sample` match the
    # defaults of PassFieldArgsToIntegrand; presumably that transformer rewrites the
    # `integrand_func(...)` call below -- confirm before renaming anything here
    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        # Two-dimensional launch: one thread per (node of the restriction, test dof)
        local_node_index, test_dof = wp.tid()
        node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
        element_count = test.space_restriction.node_element_count(test_arg, local_node_index)

        # No trial function: the trial dof index of the sample is null
        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        # Loop over the elements listed for this node by the space restriction
        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)

            qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
            for k in range(qp_point_count):
                qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
                qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
                qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

                # The measure is evaluated on a sample built from the element and coordinates only
                vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

                sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(qp_weight * vol * val)

        # Each thread owns its (node, dof) entry, which is written with a plain assignment
        result[node_index, test_dof] = output_dtype(val_sum)

    return integrate_kernel_fn
509
+
510
+
511
def get_integrate_linear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Build the kernel function integrating a linear form using the test nodes as quadrature points.

    Each thread of the returned function handles one ``(local node, dof)`` pair. For
    every element listed for that node, the node coordinates and quadrature weight in
    the element are queried from ``test.space``; elements for which ``coords[0]`` equals
    ``OUTSIDE`` are skipped. The sum of ``node_weight * vol * val`` is written to
    ``result[node_index, dof]``.

    Args:
        integrand_func: Function called at each node as ``integrand_func(sample, fields, values)``
        domain: Provides ``element_index``, ``element_measure`` and the types of the domain kernel arguments
        FieldStruct: Type of the ``fields`` kernel argument
        ValueStruct: Type of the ``values`` kernel argument
        test: Test field providing the space restriction and the nodal coordinates/weights
        output_dtype: Scalar type of the 2d ``result`` array
        accumulate_dtype: Scalar type used for the summation

    Returns:
        The undecorated kernel function ``integrate_kernel_fn``
    """

    # NOTE: the body below is compiled by Warp from its source; only `#` comments are added here.
    # `_get_test_arg()` is a placeholder call that is rewritten by an AST transformer
    # (presumably `PassFieldArgsToIntegrand`) into an attribute access on the `fields` argument.
    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        # 2d launch: (node of the space restriction, degree of freedom)
        local_node_index, dof = wp.tid()

        node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
        element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)

        # Linear form: no trial function
        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Skip elements flagged with the OUTSIDE sentinel for this node
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                test_dof_index = DofIndex(node_element_index.node_index_in_element, dof)

                # The node plays the role of the quadrature point: its index and weight fill the sample
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        result[node_index, dof] = output_dtype(val_sum)

    return integrate_kernel_fn
574
+
575
+
576
def get_integrate_bilinear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    trial: TrialField,
    output_dtype,
    accumulate_dtype,
):
    """Build the kernel function integrating a bilinear form using a quadrature formula.

    Each thread of the returned function handles one
    ``(test local node, trial node in element, test dof, trial dof)`` tuple. For every
    element listed for the test node, the quadrature sum of ``qp_weight * vol * val`` is
    stored in ``triplet_values[block_offset, test_dof, trial_dof]``, where
    ``block_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT + trial_node``.
    The thread with ``test_dof == 0 and trial_dof == 0`` also writes the block row and
    column indices to ``triplet_rows`` and ``triplet_cols``.

    Args:
        integrand_func: Function called at each quadrature point as ``integrand_func(sample, fields, values)``
        domain: Provides ``element_index``, ``element_measure`` and the types of the domain kernel arguments
        quadrature: Provides the per-element quadrature point count, indices, coordinates and weights
        FieldStruct: Type of the ``fields`` kernel argument
        ValueStruct: Type of the ``values`` kernel argument
        test: Test field; its ``space_restriction`` maps nodes to their neighbouring elements
        trial: Trial field; its space topology and partition give the column node indices
        output_dtype: Scalar type of the ``triplet_values`` array
        accumulate_dtype: Scalar type used for the summation

    Returns:
        The undecorated kernel function ``integrate_kernel_fn``
    """

    # Captured by the kernel below as a constant
    NODES_PER_ELEMENT = trial.space.topology.NODES_PER_ELEMENT

    # NOTE: the body below is compiled by Warp from its source; only `#` comments are added here
    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        row_offsets: wp.array(dtype=int),
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        # 4d launch: (test node, trial node within the element, test dof, trial dof)
        test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()

        element_count = test.space_restriction.node_element_count(test_arg, test_local_node_index)
        test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)

        trial_dof_index = DofIndex(trial_node, trial_dof)

        for element in range(element_count):
            test_element_index = test.space_restriction.node_element_index(test_arg, test_local_node_index, element)
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
            qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)

            test_dof_index = DofIndex(
                test_element_index.node_index_in_element,
                test_dof,
            )

            # One accumulator per (test node, element, trial node) block entry
            val_sum = accumulate_dtype(0.0)

            for k in range(qp_point_count):
                qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
                coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)

                qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
                vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))

                sample = Sample(
                    element_index,
                    coords,
                    qp_index,
                    qp_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                val = integrand_func(sample, fields, values)
                val_sum += accumulate_dtype(qp_weight * vol * val)

            # Position of this block in the triplet arrays
            block_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT + trial_node
            triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)

            # Set row and column indices
            # (only once per block, by the thread handling the first dof pair)
            if test_dof == 0 and trial_dof == 0:
                trial_node_index = trial.space_partition.partition_node_index(
                    trial_partition_arg,
                    trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
                )
                triplet_rows[block_offset] = test_node_index
                triplet_cols[block_offset] = trial_node_index

    return integrate_kernel_fn
653
+
654
+
655
def get_integrate_bilinear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Build the kernel function integrating a bilinear form using the test nodes as quadrature points.

    Each thread of the returned function handles one ``(local node, test dof, trial dof)``
    tuple. Test and trial degrees of freedom are both taken at the same node of each
    element, so the single block written per node has identical row and column index
    (``triplet_rows`` and ``triplet_cols`` are both set to ``node_index``).

    Args:
        integrand_func: Function called at each node as ``integrand_func(sample, fields, values)``
        domain: Provides ``element_index``, ``element_measure`` and the types of the domain kernel arguments
        FieldStruct: Type of the ``fields`` kernel argument
        ValueStruct: Type of the ``values`` kernel argument
        test: Test field providing the space restriction and the nodal coordinates/weights
        output_dtype: Scalar type of the ``triplet_values`` array
        accumulate_dtype: Scalar type used for the summation

    Returns:
        The undecorated kernel function ``integrate_kernel_fn``
    """

    # NOTE: the body below is compiled by Warp from its source; only `#` comments are added here.
    # `_get_test_arg()` is a placeholder call that is rewritten by an AST transformer
    # (presumably `PassFieldArgsToIntegrand`) into an attribute access on the `fields` argument.
    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        # 3d launch: (node of the space restriction, test dof, trial dof)
        local_node_index, test_dof, trial_dof = wp.tid()

        element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)
        node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)

        val_sum = accumulate_dtype(0.0)

        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Skip elements flagged with the OUTSIDE sentinel for this node
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                # Test and trial dofs are attached to the same node of the element
                test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
                trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)

                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        # One diagonal block per node; every dof-pair thread writes the same row/column index
        triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
        triplet_rows[local_node_index] = node_index
        triplet_cols[local_node_index] = node_index

    return integrate_kernel_fn
721
+
722
+
723
def _generate_integrate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    nodal: bool,
    quadrature: Quadrature,
    test: Optional[TestField],
    test_name: str,
    trial: Optional[TrialField],
    trial_name: str,
    fields: Dict[str, FieldLike],
    output_dtype: type,
    accumulate_dtype: type,
    kernel_options: Dict[str, Any] = {},
) -> tuple:
    """Fetch from the cache, or generate, the kernel integrating ``integrand``.

    The kind of kernel is selected from the presence of ``test`` and ``trial``
    (constant, linear or bilinear form) and from the ``nodal`` flag.

    Args:
        integrand: Form to be integrated
        domain: Integration domain
        nodal: Whether the test function nodes are used as quadrature points
        quadrature: Quadrature formula; only used when ``nodal`` is False
        test: Test field, or None for a constant form
        test_name: Name of the integrand argument bound to the test field
        trial: Trial field, or None for a constant or linear form
        trial_name: Name of the integrand argument bound to the trial field
        fields: Fields passed to the integrand, keyed by argument name
        output_dtype: Type of the results; reduced to its scalar type
        accumulate_dtype: Scalar type used to accumulate integration samples
        kernel_options: Options forwarded to the kernel builder. Not mutated by this function.

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple: the kernel and the struct types
        to be instantiated for its ``fields`` and ``values`` arguments.
    """
    output_dtype = wp.types.type_scalar_type(output_dtype)

    # Extract field arguments from integrand
    field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
        integrand, fields=fields, domain=domain
    )

    FieldStruct = _gen_field_struct(field_args)
    ValueStruct = _gen_value_struct(value_args)

    # Check if kernel exist in cache
    # The suffix identifies everything the generated code depends on
    kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
    if nodal:
        kernel_suffix += "_nodal"
    else:
        kernel_suffix += quadrature.name

    if test:
        kernel_suffix += f"_test_{test.space_partition.name}_{test.space.name}"
    if trial:
        kernel_suffix += f"_trial_{trial.space_partition.name}_{trial.space.name}"

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Not found in cache, transform integrand and generate kernel

    integrand_func = _translate_integrand(
        integrand,
        field_args,
    )

    _register_integrand_field_wrappers(integrand_func, fields)

    if test is None and trial is None:
        # Constant form
        integrate_kernel_fn = get_integrate_constant_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            accumulate_dtype=accumulate_dtype,
        )
    elif trial is None:
        # Linear form
        if nodal:
            integrate_kernel_fn = get_integrate_linear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        else:
            integrate_kernel_fn = get_integrate_linear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
    else:
        # Bilinear form
        if nodal:
            integrate_kernel_fn = get_integrate_bilinear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        else:
            integrate_kernel_fn = get_integrate_bilinear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                trial=trial,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )

    # Register the new kernel; the transformer rewrites the integrand call and
    # the test/trial argument placeholders in the kernel source
    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=integrate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args,
                field_args=field_args.keys(),
                value_args=value_args.keys(),
                sample_name=sample_name,
                domain_name=domain_name,
                test_name=test_name,
                trial_name=trial_name,
            )
        ],
    )

    return kernel, FieldStruct, ValueStruct
849
+
850
+
851
def _launch_integrate_kernel(
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    nodal: bool,
    quadrature: Quadrature,
    test: Optional[TestField],
    trial: Optional[TrialField],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    accumulate_dtype: type,
    temporary_store: Optional[cache.TemporaryStore],
    output_dtype: type,
    output: Optional[Union[wp.array, BsrMatrix]],
    device,
):
    """Launch an integration kernel and return its result.

    Args:
        kernel: Kernel generated for the integrand
        FieldStruct: Struct type holding the field arguments of the kernel
        ValueStruct: Struct type holding the value arguments of the kernel
        domain: Integration domain
        nodal: Whether the kernel uses the test function nodes as quadrature points
        quadrature: Quadrature formula; may be None, in which case no quadrature argument is built
        test: Test field, or None for a constant form
        trial: Trial field, or None for a constant or linear form
        fields: Fields to pass to the integrand, keyed by argument name
        values: Additional values to pass to the integrand, keyed by argument name
        accumulate_dtype: Scalar type used to accumulate integration samples
        temporary_store: Shared pool from which to allocate temporary arrays
        output_dtype: Scalar type of the result when ``output`` is not provided
        output: Array or sparse matrix in which to store the result, or None

    Returns:
        - Constant form: ``output`` if provided, otherwise the first entry of the accumulation array read back with ``numpy()``
        - Linear form: ``output`` if provided, otherwise a detached temporary array
        - Bilinear form: ``output`` if provided, otherwise a new matrix created with ``bsr_zeros``

    Raises:
        RuntimeError: If ``output`` has an incompatible size, shape or data type
    """
    # Set-up launch arguments
    domain_elt_arg = domain.element_arg_value(device=device)
    domain_elt_index_arg = domain.element_index_arg_value(device=device)

    if quadrature is not None:
        qp_arg = quadrature.arg_value(device=device)

    field_arg_values = FieldStruct()
    for k, v in fields.items():
        setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = ValueStruct()
    for k, v in values.items():
        setattr(value_struct_values, k, v)

    # Constant form
    if test is None and trial is None:
        # Accumulate directly into the output when its type allows it, otherwise into a temporary
        if output is not None and output.dtype == accumulate_dtype:
            if output.size < 1:
                raise RuntimeError("Output array must be of size at least 1")
            accumulate_array = output
        else:
            accumulate_temporary = cache.borrow_temporary(
                shape=(1,),
                device=device,
                dtype=accumulate_dtype,
                temporary_store=temporary_store,
                requires_grad=output is not None and output.requires_grad,
            )
            accumulate_array = accumulate_temporary.array

        accumulate_array.zero_()
        wp.launch(
            kernel=kernel,
            dim=domain.element_count(),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                field_arg_values,
                value_struct_values,
                accumulate_array,
            ],
            device=device,
        )

        # Identity test: did we accumulate straight into the user-provided array?
        if output is accumulate_array:
            return output
        elif output is None:
            return accumulate_array.numpy()[0]
        else:
            array_cast(in_array=accumulate_array, out_array=output)
            return output

    test_arg = test.space_restriction.node_arg(device=device)

    # Linear form
    if trial is None:
        # If an output array is provided with the correct type, accumulate directly into it
        # Otherwise, grab a temporary array
        if output is None:
            if type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
                output_shape = (test.space_partition.node_count(),)
            elif type_length(output_dtype) == 1:
                output_shape = (test.space_partition.node_count(), test.space.VALUE_DOF_COUNT)
            else:
                raise RuntimeError(
                    f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
                )

            output_temporary = cache.borrow_temporary(
                temporary_store=temporary_store,
                shape=output_shape,
                dtype=output_dtype,
                device=device,
            )

            output = output_temporary.array

        else:
            output_temporary = None

            if output.shape[0] < test.space_partition.node_count():
                raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")

            output_dtype = output.dtype
            if type_length(output_dtype) != test.space.VALUE_DOF_COUNT:
                if type_length(output_dtype) != 1:
                    raise RuntimeError(
                        f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
                    )
                # A scalar-typed output must be 2d with one column per degree of freedom.
                # `or` (rather than `and`) so that 2d arrays are actually checked and
                # 1d arrays are rejected without indexing a missing dimension.
                if output.ndim != 2 or output.shape[1] != test.space.VALUE_DOF_COUNT:
                    raise RuntimeError(
                        f"Incompatible output array shape, last dimension must be of size {test.space.VALUE_DOF_COUNT}"
                    )

        # Launch the integration on the kernel on a 2d scalar view of the actual array
        output.zero_()

        def as_2d_array(array):
            # Non-owning reinterpretation of the same memory (and of its gradient, if any)
            return wp.array(
                data=None,
                ptr=array.ptr,
                capacity=array.capacity,
                owner=False,
                device=array.device,
                shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
                dtype=wp.types.type_scalar_type(output_dtype),
                grad=None if array.grad is None else as_2d_array(array.grad),
            )

        output_view = output if output.ndim == 2 else as_2d_array(output)

        if nodal:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
                inputs=[
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )
        else:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
                inputs=[
                    qp_arg,
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )

        if output_temporary is not None:
            return output_temporary.detach()

        return output

    # Bilinear form

    if test.space.VALUE_DOF_COUNT == 1 and trial.space.VALUE_DOF_COUNT == 1:
        block_type = output_dtype
    else:
        block_type = cache.cached_mat_type(
            shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=output_dtype
        )

    # Number of blocks written by the kernel
    if nodal:
        nnz = test.space_restriction.node_count()
    else:
        nnz = test.space_restriction.total_node_element_count() * trial.space.topology.NODES_PER_ELEMENT

    triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_values_temp = cache.borrow_temporary(
        temporary_store,
        shape=(
            nnz,
            test.space.VALUE_DOF_COUNT,
            trial.space.VALUE_DOF_COUNT,
        ),
        dtype=output_dtype,
        device=device,
    )
    triplet_cols = triplet_cols_temp.array
    triplet_rows = triplet_rows_temp.array
    triplet_values = triplet_values_temp.array

    triplet_values.zero_()

    if nodal:
        wp.launch(
            kernel=kernel,
            dim=triplet_values.shape,
            inputs=[
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                field_arg_values,
                value_struct_values,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    else:
        offsets = test.space_restriction.partition_element_offsets()

        trial_partition_arg = trial.space_partition.partition_arg_value(device)
        trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
        wp.launch(
            kernel=kernel,
            dim=(
                test.space_restriction.node_count(),
                trial.space.topology.NODES_PER_ELEMENT,
                test.space.VALUE_DOF_COUNT,
                trial.space.VALUE_DOF_COUNT,
            ),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                trial_partition_arg,
                trial_topology_arg,
                field_arg_values,
                value_struct_values,
                offsets,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    if output is not None:
        if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
            raise RuntimeError(
                f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
            )

    else:
        output = bsr_zeros(
            rows_of_blocks=test.space_partition.node_count(),
            cols_of_blocks=trial.space_partition.node_count(),
            block_type=block_type,
            device=device,
        )

    bsr_set_from_triplets(output, triplet_rows, triplet_cols, triplet_values)

    # Do not wait for garbage collection
    triplet_values_temp.release()
    triplet_rows_temp.release()
    triplet_cols_temp.release()

    return output
1117
+
1118
+
1119
def integrate(
    integrand: Integrand,
    domain: Optional[GeometryDomain] = None,
    quadrature: Optional[Quadrature] = None,
    nodal: bool = False,
    fields: Dict[str, FieldLike] = {},
    values: Dict[str, Any] = {},
    accumulate_dtype: type = wp.float64,
    output_dtype: Optional[type] = None,
    output: Optional[Union[BsrMatrix, wp.array]] = None,
    device=None,
    temporary_store: Optional[cache.TemporaryStore] = None,
    kernel_options: Dict[str, Any] = {},
):
    """
    Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.

    Args:
        integrand: Form to be integrated, must have :func:`integrand` decorator
        domain: Integration domain. If None, taken from the quadrature if provided, otherwise from the test field.
        quadrature: Quadrature formula. If None and `nodal` is False, a regular quadrature is built on the domain, with order equal to the sum of the degrees of all fields.
        nodal: For linear or bilinear form only, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
        fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
        temporary_store: shared pool from which to allocate temporary arrays
        accumulate_dtype: Scalar type to be used for accumulating integration samples
        output: Sparse matrix or warp array into which to store the result of the integration
        output_dtype: Scalar type for returned results if `output` is not provided. If None, defaults to `accumulate_dtype`
        device: Device on which to perform the integration
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
    """
    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")

    test, test_name, trial, trial_name = _get_test_and_trial_fields(fields)
    has_test = test is not None

    # Resolve the integration domain: explicit argument, then quadrature, then test field
    if domain is None:
        domain = quadrature.domain if quadrature is not None else (test.domain if has_test else None)

    if domain is None:
        raise ValueError("Must provide at least one of domain, quadrature, or test field")
    if has_test and domain != test.domain:
        raise NotImplementedError("Mixing integration and test domain is not supported yet")

    if not nodal:
        # Quadrature-based integration: build a default formula if none was given
        if quadrature is None:
            quadrature = RegularQuadrature(domain=domain, order=sum(field.degree for field in fields.values()))
        elif domain != quadrature.domain:
            raise ValueError("Incompatible integration and quadrature domain")
    else:
        # Nodal integration: the test nodes replace the quadrature formula
        if quadrature is not None:
            raise ValueError("Cannot specify quadrature for nodal integration")

        if not has_test:
            raise ValueError("Nodal integration requires specifying a test function")

        if trial is not None and test.space_partition != trial.space_partition:
            raise ValueError(
                "Bilinear nodal integration requires test and trial to be defined on the same function space"
            )

    # Canonicalize types; a provided output dictates the output scalar type
    accumulate_dtype = wp.types.type_to_warp(accumulate_dtype)
    if output is not None:
        output_dtype = output.scalar_type if isinstance(output, BsrMatrix) else output.dtype
    elif output_dtype is None:
        output_dtype = accumulate_dtype
    else:
        output_dtype = wp.types.type_to_warp(output_dtype)

    # Keyword arguments common to kernel generation and kernel launch
    shared_args = dict(
        domain=domain,
        nodal=nodal,
        quadrature=quadrature,
        test=test,
        trial=trial,
        fields=fields,
        accumulate_dtype=accumulate_dtype,
        output_dtype=output_dtype,
    )

    kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
        integrand=integrand,
        test_name=test_name,
        trial_name=trial_name,
        kernel_options=kernel_options,
        **shared_args,
    )

    return _launch_integrate_kernel(
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        values=values,
        temporary_store=temporary_store,
        output=output,
        device=device,
        **shared_args,
    )
1228
+
1229
+
1230
def get_interpolate_to_field_function(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the function evaluating an integrand at one node of a destination field.

    The returned function evaluates ``integrand_func`` at the node position in each
    element listed for that node by ``dest.space_restriction`` (skipping elements for
    which ``coords[0]`` equals ``OUTSIDE``) and returns the pair ``(val_sum, vol_sum)``:
    the sum of values weighted by the element measure, and the sum of those measures.

    Args:
        integrand_func: Function called at each node as ``integrand_func(sample, fields, values)``
        domain: Provides ``element_index``, ``element_measure`` and the types of the domain arguments
        FieldStruct: Type of the ``fields`` argument
        ValueStruct: Type of the ``values`` argument
        dest: Restriction of the destination field; its space ``dtype`` is the type of ``val_sum``

    Returns:
        The undecorated function ``interpolate_to_field_fn``
    """
    value_type = dest.space.dtype

    # NOTE: the body below is compiled by Warp from its source; only `#` comments are added here
    def interpolate_to_field_fn(
        local_node_index: int,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
        element_count = dest.space_restriction.node_element_count(dest_node_arg, local_node_index)

        # Interpolation involves neither test nor trial function; samples get a unit weight
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX
        node_weight = 1.0

        # Volume-weighted average across elements
        # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces

        val_sum = value_type(0.0)
        vol_sum = float(0.0)

        for n in range(element_count):
            node_element_index = dest.space_restriction.node_element_index(dest_node_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            coords = dest.space.node_coords_in_element(
                domain_arg,
                dest_eval_arg.space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )

            # Skip elements flagged with the OUTSIDE sentinel for this node
            if coords[0] != OUTSIDE:
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                vol_sum += vol
                val_sum += vol * val

        # The division is left to the caller, which can check that vol_sum is positive
        return val_sum, vol_sum

    return interpolate_to_field_fn
1290
+
1291
+
1292
def get_interpolate_to_field_kernel(
    interpolate_to_field_fn: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the kernel function writing interpolated values to the nodes of a destination field.

    Each thread of the returned function handles the node designated by ``wp.tid()``:
    it calls ``interpolate_to_field_fn`` and, if the returned ``vol_sum`` is strictly
    positive, sets the node value of ``dest.field`` to ``val_sum / vol_sum``.

    Args:
        interpolate_to_field_fn: Function returning ``(val_sum, vol_sum)`` for a given local node index
        domain: Provides the types of the domain kernel arguments
        FieldStruct: Type of the ``fields`` kernel argument
        ValueStruct: Type of the ``values`` kernel argument
        dest: Restriction of the destination field

    Returns:
        The undecorated kernel function ``interpolate_to_field_kernel_fn``
    """

    # NOTE: the body below is compiled by Warp from its source; only `#` comments are added here
    def interpolate_to_field_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        # One thread per node of the space restriction
        local_node_index = wp.tid()

        val_sum, vol_sum = interpolate_to_field_fn(
            local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
        )

        # Nodes that received no contribution keep their current value
        if vol_sum > 0.0:
            node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
            dest.field.set_node_value(dest_eval_arg, node_index, val_sum / vol_sum)

    return interpolate_to_field_kernel_fn
1318
+
1319
+
1320
def get_interpolate_to_array_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    value_type: type,
):
    """Build the kernel function storing the integrand value at each quadrature point into an array.

    Each thread of the returned function handles the domain element designated by
    ``wp.tid()`` and writes, for each of its quadrature points, the value returned by
    ``integrand_func`` to ``result[qp_index]``.

    Args:
        integrand_func: Function called at each quadrature point as ``integrand_func(sample, fields, values)``
        domain: Provides ``element_index``
        quadrature: Provides the quadrature points; the domain kernel argument types are taken from ``quadrature.domain``
        FieldStruct: Type of the ``fields`` kernel argument
        ValueStruct: Type of the ``values`` kernel argument
        value_type: Data type of the ``result`` array

    Returns:
        The undecorated kernel function ``interpolate_to_array_kernel_fn``
    """

    # NOTE: the body below is compiled by Warp from its source; only `#` comments are added here
    def interpolate_to_array_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=value_type),
    ):
        # One thread per domain element
        element_index = domain.element_index(domain_index_arg, wp.tid())

        # Interpolation involves neither test nor trial function
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)

            # The result array is indexed by quadrature point index
            result[qp_index] = integrand_func(sample, fields, values)

    return interpolate_to_array_kernel_fn
1352
+
1353
+
1354
def get_interpolate_nonvalued_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
):
    """Build the kernel function calling the integrand at each quadrature point, discarding its result.

    Each thread of the returned function handles the domain element designated by
    ``wp.tid()`` and calls ``integrand_func`` once per quadrature point of that element.
    Nothing is written by the kernel itself.

    Args:
        integrand_func: Function called at each quadrature point as ``integrand_func(sample, fields, values)``
        domain: Provides ``element_index``
        quadrature: Provides the quadrature points; the domain kernel argument types are taken from ``quadrature.domain``
        FieldStruct: Type of the ``fields`` kernel argument
        ValueStruct: Type of the ``values`` kernel argument

    Returns:
        The undecorated kernel function ``interpolate_nonvalued_kernel_fn``
    """

    # NOTE: the body below is compiled by Warp from its source; only `#` comments are added here
    def interpolate_nonvalued_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        # One thread per domain element
        element_index = domain.element_index(domain_index_arg, wp.tid())

        # Interpolation involves neither test nor trial function
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            # Return value, if any, is intentionally ignored
            integrand_func(sample, fields, values)

    return interpolate_nonvalued_kernel_fn
1383
+
1384
+
1385
def _generate_interpolate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    dest: Optional[Union[FieldLike, wp.array]],
    quadrature: Optional[Quadrature],
    fields: Dict[str, FieldLike],
    kernel_options: Optional[Dict[str, Any]] = None,
) -> tuple:
    """
    Generates (or fetches from the cache) the kernel interpolating `integrand` to `dest`.

    Args:
        integrand: Function to be interpolated
        domain: Domain over which the interpolation is performed
        dest: Interpolation destination; a :class:`FieldRestriction`, a warp array, or ``None``
        quadrature: Quadrature formula defining the sample points. Must be provided when `dest` is
            not a :class:`FieldRestriction`, as its name is part of the kernel cache key.
        fields: Discrete fields to be passed to the integrand
        kernel_options: Overloaded options to be passed to the kernel builder. ``None`` is
            equivalent to an empty dictionary.

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple: the interpolation kernel and the struct
        types of its ``fields`` and ``values`` arguments.
    """
    # Never share a mutable default dictionary between calls
    if kernel_options is None:
        kernel_options = {}

    # Extract field arguments from integrand
    field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
        integrand, fields=fields, domain=domain
    )

    # Translate the integrand
    integrand_func = _translate_integrand(
        integrand,
        field_args,
    )

    _register_integrand_field_wrappers(integrand_func, fields)

    # Generate field and value structs
    FieldStruct = _gen_field_struct(field_args)
    ValueStruct = _gen_value_struct(value_args)

    def make_code_transformers():
        # A fresh transformer instance is built for each use, as in the original code
        return [
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args,
                field_args=field_args.keys(),
                value_args=value_args.keys(),
                sample_name=sample_name,
                domain_name=domain_name,
            )
        ]

    # Check if the kernel exists in the cache; the suffix encodes the destination kind
    if isinstance(dest, FieldRestriction):
        kernel_suffix = (
            f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
        )
    elif wp.types.is_array(dest):
        kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{wp.types.type_repr(dest.dtype)}"
    else:
        kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}"

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Generate interpolation kernel
    if isinstance(dest, FieldRestriction):
        # need to split into kernel + function for differentiability
        interpolate_fn = get_interpolate_to_field_function(
            integrand_func,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

        interpolate_fn = cache.get_integrand_function(
            integrand=integrand,
            func=interpolate_fn,
            suffix=kernel_suffix,
            code_transformers=make_code_transformers(),
        )

        interpolate_kernel_fn = get_interpolate_to_field_kernel(
            interpolate_fn,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    elif wp.types.is_array(dest):
        interpolate_kernel_fn = get_interpolate_to_array_kernel(
            integrand_func,
            domain=domain,
            quadrature=quadrature,
            value_type=dest.dtype,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    else:
        interpolate_kernel_fn = get_interpolate_nonvalued_kernel(
            integrand_func,
            domain=domain,
            quadrature=quadrature,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=interpolate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=make_code_transformers(),
    )

    return kernel, FieldStruct, ValueStruct
1494
+
1495
+
1496
def _launch_interpolate_kernel(
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    dest: Optional[Union[FieldRestriction, wp.array]],
    quadrature: Optional[Quadrature],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    device,
) -> None:
    """
    Launches an interpolation kernel.

    Args:
        kernel: Kernel to launch
        FieldStruct: Struct type of the kernel's ``fields`` argument
        ValueStruct: Struct type of the kernel's ``values`` argument
        domain: Domain over which the interpolation is performed
        dest: Interpolation destination; a :class:`FieldRestriction`, a warp array, or ``None``
        quadrature: Quadrature formula; used (and required) when `dest` is not a :class:`FieldRestriction`
        fields: Discrete fields to be passed to the integrand, keyed by field struct member name
        values: Additional values to be passed to the integrand, keyed by value struct member name
        device: Device on which to launch the kernel
    """
    # Set-up launch arguments
    elt_arg = domain.element_arg_value(device=device)
    elt_index_arg = domain.element_index_arg_value(device=device)

    field_arg_values = FieldStruct()
    for k, v in fields.items():
        setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = ValueStruct()
    for k, v in values.items():
        setattr(value_struct_values, k, v)

    if isinstance(dest, FieldRestriction):
        # Interpolating to a field: one thread per node of the space restriction
        dest_node_arg = dest.space_restriction.node_arg(device=device)
        dest_eval_arg = dest.field.eval_arg_value(device=device)

        wp.launch(
            kernel=kernel,
            dim=dest.space_restriction.node_count(),
            inputs=[
                elt_arg,
                elt_index_arg,
                dest_node_arg,
                dest_eval_arg,
                field_arg_values,
                value_struct_values,
            ],
            device=device,
        )
        return

    # Interpolating at quadrature points: one thread per domain element.
    # The array and non-valued kernels share the same leading arguments;
    # the array variant takes the destination array as an extra trailing argument.
    qp_arg = quadrature.arg_value(device)
    launch_inputs = [qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values]
    if wp.types.is_array(dest):
        launch_inputs.append(dest)

    wp.launch(
        kernel=kernel,
        dim=domain.element_count(),
        inputs=launch_inputs,
        device=device,
    )
1552
+
1553
+
1554
def interpolate(
    integrand: Integrand,
    dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
    quadrature: Optional[Quadrature] = None,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    device=None,
    kernel_options: Optional[Dict[str, Any]] = None,
):
    """
    Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.

    Args:
        integrand: Function to be interpolated, must have :func:`integrand` decorator
        dest: Where to store the interpolation result. Can be either

            - a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
            - a normal warp array. In this case, the `quadrature` argument defining the interpolation locations must be provided and the result of the `integrand` at each quadrature point will be assigned to the array.
            - ``None``. In this case, the `quadrature` argument must also be provided and the `integrand` function is responsible for dealing with the interpolation result.
        quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
        fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameters names. Defaults to no fields.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names. Defaults to no values.
        device: Device on which to perform the interpolation
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``). Defaults to no overloaded options.

    Raises:
        ValueError: if `integrand` is not an :class:`Integrand`, if `fields` contains test or trial fields,
            or if `quadrature` is missing while `dest` is not a discrete field or field restriction.
    """
    # Use fresh dictionaries rather than mutable defaults shared between calls
    if fields is None:
        fields = {}
    if values is None:
        values = {}
    if kernel_options is None:
        kernel_options = {}

    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @integrand decorator")

    test, _, trial, __ = _get_test_and_trial_fields(fields)
    if test is not None or trial is not None:
        raise ValueError("Test or Trial fields should not be used for interpolation")

    # A full discrete field is interpolated through its restriction
    if isinstance(dest, DiscreteField):
        dest = make_restriction(dest)

    if isinstance(dest, FieldRestriction):
        domain = dest.domain
    else:
        if quadrature is None:
            raise ValueError("When not interpolating to a field, a quadrature formula must be provided")

        domain = quadrature.domain

    kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
        integrand=integrand,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        fields=fields,
        kernel_options=kernel_options,
    )

    return _launch_interpolate_kernel(
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        fields=fields,
        values=values,
        device=device,
    )